From e69f5bf5a313e12f468bff81887da53bcac3f9c7 Mon Sep 17 00:00:00 2001 From: hanchao Date: Thu, 29 Aug 2024 09:28:58 +0000 Subject: [PATCH 0001/1018] Xccl process group for Pytorch --- CMakeLists.txt | 6 + build_variables.bzl | 4 + caffe2/CMakeLists.txt | 13 + caffe2/core/macros.h.in | 1 + cmake/Dependencies.cmake | 16 + cmake/External/xccl.cmake | 17 + cmake/Modules/FindXCCL.cmake | 68 +++ cmake/Summary.cmake | 6 + setup.py | 4 + test/distributed/test_c10d_common.py | 9 +- test/distributed/test_c10d_xccl.py | 303 +++++++++++++ torch/CMakeLists.txt | 7 + torch/_C/_distributed_c10d.pyi | 9 + torch/csrc/distributed/c10d/Ops.cpp | 20 + torch/csrc/distributed/c10d/ProcessGroup.cpp | 2 + torch/csrc/distributed/c10d/ProcessGroup.hpp | 3 + .../distributed/c10d/ProcessGroupXCCL.cpp | 401 ++++++++++++++++++ .../distributed/c10d/ProcessGroupXCCL.hpp | 308 ++++++++++++++ torch/csrc/distributed/c10d/init.cpp | 22 + torch/distributed/distributed_c10d.py | 48 ++- torch/testing/_internal/common_distributed.py | 11 +- 21 files changed, 1268 insertions(+), 10 deletions(-) create mode 100644 cmake/External/xccl.cmake create mode 100644 cmake/Modules/FindXCCL.cmake create mode 100644 test/distributed/test_c10d_xccl.py create mode 100644 torch/csrc/distributed/c10d/ProcessGroupXCCL.cpp create mode 100644 torch/csrc/distributed/c10d/ProcessGroupXCCL.hpp diff --git a/CMakeLists.txt b/CMakeLists.txt index 5139c0a478e788..89ef59681bfff4 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -275,6 +275,8 @@ option(USE_NATIVE_ARCH "Use -march=native" OFF) cmake_dependent_option(USE_MPS "Use MPS for macOS build" ON "MPS_FOUND" OFF) cmake_dependent_option(USE_NCCL "Use NCCL" ON "USE_CUDA OR USE_ROCM;UNIX;NOT APPLE" OFF) +cmake_dependent_option(USE_XCCL "Use XCCL" ON + "USE_XPU;UNIX;NOT APPLE" OFF) cmake_dependent_option(USE_RCCL "Use RCCL" ON USE_NCCL OFF) cmake_dependent_option(USE_STATIC_NCCL "Use static NCCL" OFF "USE_NCCL" OFF) cmake_dependent_option(USE_SYSTEM_NCCL "Use system-wide NCCL" OFF "USE_NCCL" @@ -353,6 +355,8 @@ cmake_dependent_option(USE_C10D_GLOO "USE C10D GLOO" ON "USE_DISTRIBUTED;USE_GLOO" OFF) cmake_dependent_option(USE_C10D_NCCL "USE C10D NCCL" ON "USE_DISTRIBUTED;USE_NCCL" OFF) +cmake_dependent_option(USE_C10D_XCCL "USE C10D XCCL" ON + "USE_DISTRIBUTED;USE_XCCL" OFF) cmake_dependent_option(USE_C10D_MPI "USE C10D MPI" ON "USE_DISTRIBUTED;USE_MPI" OFF) cmake_dependent_option( @@ -365,6 +369,8 @@ cmake_dependent_option( USE_C10D_GLOO "USE C10D GLOO" ON "USE_DISTRIBUTED;USE_GLOO" OFF) cmake_dependent_option( USE_C10D_NCCL "USE C10D NCCL" ON "USE_DISTRIBUTED;USE_NCCL" OFF) +cmake_dependent_option( + USE_C10D_XCCL "USE C10D XCCL" ON "USE_DISTRIBUTED;USE_XCCL" OFF) cmake_dependent_option( USE_C10D_MPI "USE C10D MPI" ON "USE_DISTRIBUTED;USE_MPI" OFF) cmake_dependent_option( diff --git a/build_variables.bzl b/build_variables.bzl index e05c94bd83f577..98b721617b609c 100644 --- a/build_variables.bzl +++ b/build_variables.bzl @@ -700,6 +700,10 @@ libtorch_cuda_sources = libtorch_cuda_core_sources + libtorch_cuda_distributed_s "torch/csrc/cuda/nccl.cpp", ] +libtorch_xpu_distributed_extra_sources = [ + "torch/csrc/distributed/c10d/ProcessGroupXCCL.cpp", +] + torch_cpp_srcs = [ "torch/csrc/api/src/cuda.cpp", # this just forwards stuff, no real CUDA "torch/csrc/api/src/data/datasets/mnist.cpp", diff --git a/caffe2/CMakeLists.txt b/caffe2/CMakeLists.txt index 8ed93cdff0479c..d44a8da210462f 100644 --- a/caffe2/CMakeLists.txt +++ b/caffe2/CMakeLists.txt @@ -1014,6 +1014,9 @@ elseif(USE_CUDA) endif() if(USE_XPU) + if(USE_XCCL) + append_filelist("libtorch_xpu_distributed_extra_sources" Caffe2_XPU_SRCS) + endif() add_library(torch_xpu ${Caffe2_XPU_SRCS}) torch_compile_options(torch_xpu) # see cmake/public/utils.cmake target_compile_definitions(torch_xpu PRIVATE USE_XPU) @@ -1079,6 +1082,10 @@ if(USE_XPU) include_directories(SYSTEM ${ATen_XPU_INCLUDE_DIRS}) endif() + if(USE_XCCL) + target_link_libraries(torch_xpu PRIVATE torch::xccl) + target_compile_definitions(torch_xpu PRIVATE USE_XCCL) + endif() endif() if(NOT MSVC AND USE_XNNPACK) @@ -1365,6 +1372,12 @@ if(USE_DISTRIBUTED) target_compile_definitions(torch_cuda PUBLIC USE_C10D_NCCL) endif() endif() + if(USE_C10D_XCCL) + target_compile_definitions(torch_xpu PUBLIC USE_C10D_XCCL) + set_source_files_properties( + ${TORCH_SRC_DIR}/csrc/distributed/c10d/ProcessGroupXCCL.cpp + PROPERTIES COMPILE_DEFINITIONS "CCL_ENABLE_ZE;CCL_ENABLE_SYCL") + endif() if(USE_MPI AND USE_C10D_MPI) if(CMAKE_CXX_COMPILER_ID MATCHES "Clang" OR CMAKE_CXX_COMPILER_ID STREQUAL "GNU") set_source_files_properties( diff --git a/caffe2/core/macros.h.in b/caffe2/core/macros.h.in index 2929f105b31faa..e5398a83cad947 100644 --- a/caffe2/core/macros.h.in +++ b/caffe2/core/macros.h.in @@ -45,6 +45,7 @@ {"USE_CUDNN", "${USE_CUDNN}"}, \ {"CUDNN_VERSION", "${CUDNN_VERSION}"}, \ {"USE_NCCL", "${USE_NCCL}"}, \ + {"USE_XCCL", "${USE_XCCL}"}, \ {"USE_MPI", "${USE_MPI}"}, \ {"USE_GFLAGS", "${USE_GFLAGS}"}, \ {"USE_GLOG", "${USE_GLOG}"}, \ diff --git a/cmake/Dependencies.cmake b/cmake/Dependencies.cmake index ef33a3165340c1..8abea841fcf61c 100644 --- a/cmake/Dependencies.cmake +++ b/cmake/Dependencies.cmake @@ -1150,6 +1150,22 @@ if(USE_CUDA) include_directories(SYSTEM ${CUB_INCLUDE_DIRS}) endif() +# ---[ XCCL +if(USE_XCCL) + if(NOT USE_XPU) + message(WARNING + "Not using XPU, so disabling USE_XCCL. Suppress this warning with " + "-DUSE_XCCL=OFF.") + caffe2_update_option(USE_XCCL OFF) + elseif(NOT CMAKE_SYSTEM_NAME STREQUAL "Linux") + message(WARNING "USE_XCCL is currently only supported under Linux.") + caffe2_update_option(USE_XCCL OFF) + else() + include(${CMAKE_CURRENT_LIST_DIR}/External/xccl.cmake) + list(APPEND Caffe2_XPU_DEPENDENCY_LIBS torch::xccl) + endif() +endif() + if(USE_DISTRIBUTED AND USE_TENSORPIPE) if(MSVC) message(WARNING "Tensorpipe cannot be used on Windows.") diff --git a/cmake/External/xccl.cmake b/cmake/External/xccl.cmake new file mode 100644 index 00000000000000..56205b381b1324 --- /dev/null +++ b/cmake/External/xccl.cmake @@ -0,0 +1,17 @@ +if(NOT __XCCL_INCLUDED) + set(__XCCL_INCLUDED TRUE) + + if(USE_XCCL) + # XCCL_ROOT, XCCL_LIBRARY_DIR, XCCL_INCLUDE_DIR are handled by FindXCCL.cmake. + find_package(XCCL REQUIRED) + if(XCCL_FOUND) + add_library(torch::xccl INTERFACE IMPORTED) + set_property( + TARGET torch::xccl PROPERTY INTERFACE_INCLUDE_DIRECTORIES + ${XCCL_INCLUDE_DIR}) + set_property( + TARGET torch::xccl PROPERTY INTERFACE_LINK_LIBRARIES + ${XCCL_LIBRARY}) + endif() + endif() +endif() diff --git a/cmake/Modules/FindXCCL.cmake b/cmake/Modules/FindXCCL.cmake new file mode 100644 index 00000000000000..56b7fc0f7dcf32 --- /dev/null +++ b/cmake/Modules/FindXCCL.cmake @@ -0,0 +1,68 @@ +# This will define the following variables: +# XCCL_FOUND : True if the system has the XCCL library. +# XCCL_INCLUDE_DIR : Include directories needed to use XCCL. +# XCCL_LIBRARY_DIR :The path to the XCCL library. +# XCCL_LIBRARY : XCCL library fullname. + +include(FindPackageHandleStandardArgs) + +set(XCCL_ROOT "") +if(DEFINED ENV{CCL_ROOT}) + set(XCCL_ROOT $ENV{CCL_ROOT}) +endif() + +string(COMPARE EQUAL "${XCCL_ROOT}" "" nosyclfound) +if(nosyclfound) + set(XCCL_FOUND False) + set(XCCL_REASON_FAILURE "XCCL library not set!!") + set(XCCL_NOT_FOUND_MESSAGE "${XCCL_REASON_FAILURE}") + return() +endif() + +# Find include path from binary. +find_file( + XCCL_INCLUDE_DIR + NAMES include + HINTS ${XCCL_ROOT} + NO_DEFAULT_PATH +) + +# Find include/oneapi path from include path. +find_file( + XCCL_INCLUDE_ONEAPI_DIR + NAMES oneapi + HINTS ${XCCL_ROOT}/include/ + NO_DEFAULT_PATH +) + +list(APPEND XCCL_INCLUDE_DIR ${XCCL_INCLUDE_ONEAPI_DIR}) + +# Find library directory from binary. +find_file( + XCCL_LIBRARY_DIR + NAMES lib + HINTS ${XCCL_ROOT} + NO_DEFAULT_PATH +) + +# Find XCCL library fullname. +find_library( + XCCL_LIBRARY + NAMES ccl + HINTS ${XCCL_LIBRARY_DIR} + NO_DEFAULT_PATH +) + +if((NOT XCCL_INCLUDE_DIR) OR (NOT XCCL_LIBRARY_DIR) OR (NOT XCCL_LIBRARY)) + set(XCCL_FOUND False) + set(XCCL_REASON_FAILURE "XCCL library is incomplete!!") + set(XCCL_NOT_FOUND_MESSAGE "${XCCL_REASON_FAILURE}") + return() +endif() + +find_package_handle_standard_args( + XCCL + FOUND_VAR XCCL_FOUND + REQUIRED_VARS XCCL_INCLUDE_DIR XCCL_LIBRARY_DIR XCCL_LIBRARY + REASON_FAILURE_MESSAGE "${XCCL_REASON_FAILURE}" +) diff --git a/cmake/Summary.cmake b/cmake/Summary.cmake index d51c451589c2c4..229ff112ab3187 100644 --- a/cmake/Summary.cmake +++ b/cmake/Summary.cmake @@ -153,6 +153,12 @@ function(caffe2_print_configuration_summary) message(STATUS " USE_SYSTEM_UCC : ${USE_SYSTEM_UCC}") endif() message(STATUS " USE_ITT : ${USE_ITT}") + message(STATUS " USE_XCCL : ${USE_XCCL}") + if(${USE_XCCL}) + message(STATUS " USE_C10D_XCCL : ${USE_C10D_XCCL}") + message(STATUS " XCCL include path : ${XCCL_INCLUDE_DIR}") + message(STATUS " XCCL library : ${XCCL_LIBRARY}") + endif() message(STATUS " USE_NCCL : ${USE_NCCL}") if(${USE_NCCL}) message(STATUS " USE_SYSTEM_NCCL : ${USE_SYSTEM_NCCL}") diff --git a/setup.py b/setup.py index 92f1e2ddc7bcd3..ad48f4b0108633 100644 --- a/setup.py +++ b/setup.py @@ -645,6 +645,10 @@ def run(self): report("-- Building NCCL library") else: report("-- Not using NCCL") + if cmake_cache_vars["USE_XCCL"]: + report("-- Building XCCL library") + else: + report("-- Not using XCCL") if cmake_cache_vars["USE_DISTRIBUTED"]: if IS_WINDOWS: report("-- Building without distributed package") diff --git a/test/distributed/test_c10d_common.py b/test/distributed/test_c10d_common.py index 6a0621f3f49913..3e5538d57e38ae 100644 --- a/test/distributed/test_c10d_common.py +++ b/test/distributed/test_c10d_common.py @@ -66,8 +66,13 @@ def gpus_for_rank(world_size): On a single node, all visible GPUs are evenly divided to subsets, each process only uses a subset. """ - visible_devices = list(range(torch.cuda.device_count())) - gpus_per_process = torch.cuda.device_count() // world_size + device_count = ( + torch.xpu.device_count() + if torch.xpu.is_available() + else torch.cuda.device_count() + ) + visible_devices = list(range(device_count)) + gpus_per_process = device_count // world_size gpus_for_rank = [] for rank in range(world_size): gpus_for_rank.append( diff --git a/test/distributed/test_c10d_xccl.py b/test/distributed/test_c10d_xccl.py new file mode 100644 index 00000000000000..704cdd414e554b --- /dev/null +++ b/test/distributed/test_c10d_xccl.py @@ -0,0 +1,303 @@ +# Owner(s): ["oncall: distributed"] + +import math +import os +import sys +import time +from datetime import timedelta +from unittest import mock + +import torch +import torch.distributed as c10d + + +if not c10d.is_available() or not c10d.is_xccl_available(): + print("c10d XCCL not available, skipping tests", file=sys.stderr) + sys.exit(0) + +import test_c10d_common + +import torch.distributed as dist +import torch.testing._internal.common_utils as common +from torch.testing._internal.common_distributed import ( + init_multigpu_helper, + MultiProcessTestCase, + requires_xccl, +) +from torch.testing._internal.common_utils import ( + retry_on_connect_failures, + run_tests, + skip_but_pass_in_sandcastle_if, + TEST_XPU, + TestCase, +) + + +def simple_reduce_tests(rank, world_size): + tests = [ + ( + c10d.ReduceOp.SUM, + torch.tensor([rank + 1.0]), + torch.tensor([float(world_size * (world_size + 1) / 2)]), + ), + ( + c10d.ReduceOp.PRODUCT, + torch.tensor([rank + 1.0]), + torch.tensor([float(math.factorial(world_size))]), + ), + ( + c10d.ReduceOp.MIN, + torch.tensor([rank + 1.0]), + torch.tensor([1.0]), + ), + ( + c10d.ReduceOp.MAX, + torch.tensor([rank + 1.0]), + torch.tensor([world_size]), + ), + ] + + return tests + + +TEST_MULTIXPU = torch.xpu.device_count() > 1 + + +class RendezvousEnvTest(TestCase): + @retry_on_connect_failures + @requires_xccl() + @skip_but_pass_in_sandcastle_if(not TEST_XPU, "No GPUs available, skipping test") + def test_common_errors(self): + vars = { + "WORLD_SIZE": "1", + "RANK": "0", + "MASTER_ADDR": "127.0.0.1", + "MASTER_PORT": str(common.find_free_port()), + } + + class Env: + def __init__(self, vars): + self.env_patcher = mock.patch.dict(os.environ, vars, clear=True) + + def __enter__(self): + self.env_patcher.start() + + def __exit__(self, type, value, traceback): + self.env_patcher.stop() + + def without(d, key): + d = d.copy() + d.pop(key) + return d + + def withouts(d, keys): + d = d.copy() + for key in keys: + d.pop(key) + return d + + with Env(without(vars, "WORLD_SIZE")): + self.assertEqual(None, os.environ.get("WORLD_SIZE")) + with self.assertRaisesRegex(ValueError, "WORLD_SIZE expected"): + gen = c10d.rendezvous("env://") + next(gen) + c10d.init_process_group(backend="xccl", world_size=1) + self.assertEqual(c10d.get_rank(), 0) + self.assertEqual(c10d.get_world_size(), 1) + c10d.destroy_process_group() + + with Env(without(vars, "RANK")): + self.assertEqual(None, os.environ.get("RANK")) + with self.assertRaisesRegex(ValueError, "RANK expected"): + gen = c10d.rendezvous("env://") + next(gen) + c10d.init_process_group(backend="xccl", rank=0) + self.assertEqual(c10d.get_rank(), 0) + self.assertEqual(c10d.get_world_size(), 1) + c10d.destroy_process_group() + + with Env(withouts(vars, ["RANK", "WORLD_SIZE"])): + self.assertEqual(None, os.environ.get("RANK")) + self.assertEqual(None, os.environ.get("WORLD_SIZE")) + c10d.init_process_group(backend="xccl", rank=0, world_size=1) + self.assertEqual(c10d.get_rank(), 0) + self.assertEqual(c10d.get_world_size(), 1) + c10d.destroy_process_group() + + with Env(vars): + c10d.init_process_group(backend="xccl") + self.assertEqual(c10d.get_rank(), 0) + self.assertEqual(c10d.get_world_size(), 1) + c10d.destroy_process_group() + + with Env(without(vars, "MASTER_ADDR")): + self.assertEqual(None, os.environ.get("MASTER_ADDR")) + with self.assertRaisesRegex(ValueError, "MASTER_ADDR expected"): + gen = c10d.rendezvous("env://") + next(gen) + + with Env(without(vars, "MASTER_PORT")): + self.assertEqual(None, os.environ.get("MASTER_PORT")) + with self.assertRaisesRegex(ValueError, "MASTER_PORT expected"): + gen = c10d.rendezvous("env://") + next(gen) + + with Env(without(vars, "WORLD_SIZE")): + self.assertEqual(None, os.environ.get("WORLD_SIZE")) + gen = c10d.rendezvous(f"env://?world_size={1}") + _, _, size = next(gen) + self.assertEqual(size, 1) + + with Env(without(vars, "RANK")): + self.assertEqual(None, os.environ.get("RANK")) + gen = c10d.rendezvous(f"env://?rank={0}") + _, rank, _ = next(gen) + self.assertEqual(rank, 0) + + with Env(withouts(vars, ["RANK", "WORLD_SIZE"])): + self.assertEqual(None, os.environ.get("RANK")) + self.assertEqual(None, os.environ.get("WORLD_SIZE")) + gen = c10d.rendezvous(f"env://?rank={0}&world_size={1}") + _, rank, size = next(gen) + self.assertEqual(rank, 0) + self.assertEqual(size, 1) + + +class TimeoutTest(test_c10d_common.AbstractTimeoutTest, TestCase): + @requires_xccl() + @retry_on_connect_failures + @skip_but_pass_in_sandcastle_if(not TEST_XPU, "No GPUs available, skipping test") + def test_default_store_timeout_nccl(self): + self._test_default_store_timeout("xccl") + + +class ProcessGroupXCCLTest(MultiProcessTestCase): + def _create_process_group_xccl( + self, timeout=timedelta(seconds=600), device_id=None + ): + store = c10d.FileStore(self.file_name, self.world_size) + c10d.init_process_group( + "xccl", + world_size=self.world_size, + rank=self.rank, + store=store, + timeout=timeout, + device_id=device_id, + ) + pg = c10d.distributed_c10d._get_default_group() + return pg + + def setUp(self): + super().setUp() + self._spawn_processes() + + def tearDown(self): + super().tearDown() + try: + os.remove(self.file_name) + except OSError: + pass + + @property + def world_size(self): + return 2 + + @property + def rank_to_GPU(self): + # return rank to GPU map + return init_multigpu_helper(self.world_size, "xccl") + + @requires_xccl() + @skip_but_pass_in_sandcastle_if( + torch.xpu.device_count() < 2, "XCCL test requires 2+ GPUs" + ) + def test_close_multi_pg_unordered(self): + pg = self._create_process_group_xccl() + device = self.rank_to_GPU[self.rank][0] + t = torch.rand(10, 10, device=device) + # First allreduce to initialize default PG's communicator. + pg.allreduce(t).wait() + new_pg1 = c10d.new_group([0, 1]) + new_pg2 = c10d.new_group([0, 1]) + if self.rank == 0 or self.rank == 1: + t1 = torch.rand(10, 10, device=device) + t2 = torch.rand(10, 10, device=device) + new_pg1.allreduce(t1).wait() + new_pg2.allreduce(t2).wait() + if self.rank == 0: + dist.destroy_process_group(new_pg2) + # force destruction of pg2 first + del new_pg2 + dist.destroy_process_group(new_pg1) + del new_pg1 + if self.rank == 1: + c10d.destroy_process_group(new_pg1) + # force destruction of pg1 first + del new_pg1 + dist.destroy_process_group(new_pg2) + del new_pg2 + dist.destroy_process_group() + + @requires_xccl() + @skip_but_pass_in_sandcastle_if( + torch.xpu.device_count() < 2, "XCCL test requires 2+ GPUs" + ) + def test_file_store_check(self): + # self.file_name is created using "delete=False" + # e.g., self.file_name = tempfile.NamedTemporaryFile(delete=False).name + store = dist.FileStore(self.file_name, self.world_size) + dist.init_process_group( + backend="xccl", rank=self.rank, world_size=self.world_size, store=store + ) + pg = dist.distributed_c10d._get_default_group() + self.assertEqual(pg.rank(), self.rank) + self.assertEqual(pg.size(), self.world_size) + # give enough time for check() to be executed multiple times + time.sleep(2) + dist.destroy_process_group() + + @requires_xccl() + @skip_but_pass_in_sandcastle_if(not TEST_MULTIXPU, "XCCL test requires 2+ GPUs") + def test_set_process_group_desc(self): + device = torch.device(f"xpu:{self.rank}") + pg_default = self._create_process_group_xccl(device_id=device) + self.assertEqual(pg_default.group_desc, "default_pg") + pg_1 = c10d.new_group([0, 1], group_desc="test_purpose") + self.assertEqual(pg_1.group_desc, "test_purpose") + pg_2 = c10d.new_group([0, 1]) + self.assertEqual(pg_2.group_desc, "undefined") + + def _test_allreduce_basics(self, fn): + pg = self._create_process_group_xccl() + device = torch.device("xpu:" + str(self.rank)) + # Single input tests + tests = simple_reduce_tests(self.rank, self.world_size) + for op, input, expected in tests: + opts = c10d.AllreduceOptions() + opts.reduceOp = op + tensor = fn(input.to(device)) + fut = pg.allreduce([tensor], opts).get_future() + fut.wait() + result = fut.value() + self.assertEqual(expected, result[0], exact_dtype=False) + + x = fn(torch.tensor([self.rank + 1.0], device=device)) + fut = pg.allreduce(x).get_future() + fut.wait() + result = fut.value() + self.assertEqual( + torch.tensor([float(self.world_size * (self.world_size + 1) / 2)]), + result[0], + ) + + @requires_xccl() + def test_allreduce_basics(self): + self._test_allreduce_basics(lambda t: t.clone()) + + +if __name__ == "__main__": + assert ( + not torch.xpu._initialized + ), "test_distributed must not have initialized XPU context on main process" + + run_tests() diff --git a/torch/CMakeLists.txt b/torch/CMakeLists.txt index bb949a081c95e9..9a91b26d54cfb4 100644 --- a/torch/CMakeLists.txt +++ b/torch/CMakeLists.txt @@ -282,6 +282,9 @@ if(USE_DISTRIBUTED) if(USE_NCCL) list(APPEND TORCH_PYTHON_LINK_LIBRARIES __caffe2_nccl) endif() + if(USE_XCCL) + list(APPEND TORCH_PYTHON_LINK_LIBRARIES torch::xccl) + endif() # Same for MPI. if(USE_MPI) list(APPEND TORCH_PYTHON_LINK_LIBRARIES MPI::MPI_CXX) @@ -345,6 +348,10 @@ if(BUILD_LIBTORCHLESS) target_compile_definitions(torch_python PRIVATE USE_C10D_NCCL) endif() + if(USE_XPU AND USE_C10D_XCCL) + target_compile_definitions(torch_python PRIVATE USE_C10D_XCCL) + endif() + if(USE_DISTRIBUTED) target_compile_definitions(torch_python PRIVATE USE_DISTRIBUTED) endif() diff --git a/torch/_C/_distributed_c10d.pyi b/torch/_C/_distributed_c10d.pyi index 94e8578bbfff62..6033d969925972 100644 --- a/torch/_C/_distributed_c10d.pyi +++ b/torch/_C/_distributed_c10d.pyi @@ -309,6 +309,7 @@ class ProcessGroup: UNDEFINED = ... GLOO = ... NCCL = ... + XCCL = ... UCC = ... MPI = ... CUSTOM = ... @@ -697,3 +698,11 @@ class ProcessGroupCudaP2P(Backend): storage_offset: Optional[int] = 0, ) -> torch.Tensor: ... def _shutdown(self) -> None: ... + +class ProcessGroupXCCL(Backend): + def __init__( + self, + store: Store, + rank: int, + size: int, + ): ... diff --git a/torch/csrc/distributed/c10d/Ops.cpp b/torch/csrc/distributed/c10d/Ops.cpp index ae822ad3975049..699c54236f6412 100644 --- a/torch/csrc/distributed/c10d/Ops.cpp +++ b/torch/csrc/distributed/c10d/Ops.cpp @@ -79,6 +79,7 @@ namespace { } IMPL_SEND(CPU) +IMPL_SEND(XPU) IMPL_SEND(CUDA) IMPL_SEND(PrivateUse1) @@ -94,6 +95,7 @@ IMPL_SEND(PrivateUse1) } IMPL_RECV(CPU) +IMPL_RECV(XPU) IMPL_RECV(CUDA) IMPL_RECV(PrivateUse1) @@ -108,6 +110,7 @@ IMPL_RECV(PrivateUse1) } IMPL_RECV_ANY_SOURCE(CPU) +IMPL_RECV_ANY_SOURCE(XPU) IMPL_RECV_ANY_SOURCE(CUDA) IMPL_RECV_ANY_SOURCE(PrivateUse1) @@ -131,6 +134,7 @@ IMPL_RECV_ANY_SOURCE(PrivateUse1) } IMPL_REDUCE(CPU) +IMPL_REDUCE(XPU) IMPL_REDUCE(CUDA) IMPL_REDUCE(PrivateUse1) @@ -156,6 +160,7 @@ IMPL_REDUCE(PrivateUse1) } IMPL_BROADCAST(CPU) +IMPL_BROADCAST(XPU) IMPL_BROADCAST(CUDA) IMPL_BROADCAST(PrivateUse1) @@ -181,6 +186,7 @@ IMPL_BROADCAST(PrivateUse1) IMPL_ALLREDUCE(CPU) IMPL_ALLREDUCE(CUDA) +IMPL_ALLREDUCE(XPU) IMPL_ALLREDUCE(PrivateUse1) #define IMPL_ALLREDUCE_COALESCED(DEV) \ @@ -198,6 +204,7 @@ IMPL_ALLREDUCE(PrivateUse1) } IMPL_ALLREDUCE_COALESCED(CPU) +IMPL_ALLREDUCE_COALESCED(XPU) IMPL_ALLREDUCE_COALESCED(CUDA) IMPL_ALLREDUCE_COALESCED(PrivateUse1) @@ -222,6 +229,7 @@ IMPL_ALLREDUCE_COALESCED(PrivateUse1) // NOLINTBEGIN(cppcoreguidelines-pro-type-const-cast) IMPL_ALLGATHER(CPU) +IMPL_ALLGATHER(XPU) IMPL_ALLGATHER(CUDA) IMPL_ALLGATHER(PrivateUse1) @@ -242,6 +250,7 @@ IMPL_ALLGATHER(PrivateUse1) } IMPL__ALLGATHER_BASE(CPU) +IMPL__ALLGATHER_BASE(XPU) IMPL__ALLGATHER_BASE(CUDA) IMPL__ALLGATHER_BASE(PrivateUse1) @@ -258,6 +267,7 @@ IMPL__ALLGATHER_BASE(PrivateUse1) } IMPL_ALLGATHER_COALESCED(CPU) +IMPL_ALLGATHER_COALESCED(XPU) IMPL_ALLGATHER_COALESCED(CUDA) IMPL_ALLGATHER_COALESCED(PrivateUse1) @@ -273,6 +283,7 @@ IMPL_ALLGATHER_COALESCED(PrivateUse1) } IMPL_ALLGATHER_INTO_TENSOR_COALESCED(CPU) +IMPL_ALLGATHER_INTO_TENSOR_COALESCED(XPU) IMPL_ALLGATHER_INTO_TENSOR_COALESCED(CUDA) IMPL_ALLGATHER_INTO_TENSOR_COALESCED(PrivateUse1) @@ -296,6 +307,7 @@ IMPL_ALLGATHER_INTO_TENSOR_COALESCED(PrivateUse1) } IMPL_REDUCE_SCATTER(CPU) +IMPL_REDUCE_SCATTER(XPU) IMPL_REDUCE_SCATTER(CUDA) IMPL_REDUCE_SCATTER(PrivateUse1) @@ -320,6 +332,7 @@ IMPL_REDUCE_SCATTER(PrivateUse1) } IMPL__REDUCE_SCATTER_BASE(CPU) +IMPL__REDUCE_SCATTER_BASE(XPU) IMPL__REDUCE_SCATTER_BASE(CUDA) IMPL__REDUCE_SCATTER_BASE(PrivateUse1) @@ -341,6 +354,7 @@ IMPL__REDUCE_SCATTER_BASE(PrivateUse1) } IMPL_REDUCE_SCATTER_TENSOR_COALESCED(CPU) +IMPL_REDUCE_SCATTER_TENSOR_COALESCED(XPU) IMPL_REDUCE_SCATTER_TENSOR_COALESCED(CUDA) IMPL_REDUCE_SCATTER_TENSOR_COALESCED(PrivateUse1) @@ -360,6 +374,7 @@ IMPL_REDUCE_SCATTER_TENSOR_COALESCED(PrivateUse1) } IMPL_GATHER(CPU) +IMPL_GATHER(XPU) IMPL_GATHER(CUDA) IMPL_GATHER(PrivateUse1) @@ -382,6 +397,7 @@ IMPL_GATHER(PrivateUse1) } IMPL_SCATTER(CPU) +IMPL_SCATTER(XPU) IMPL_SCATTER(CUDA) IMPL_SCATTER(PrivateUse1) @@ -403,6 +419,7 @@ IMPL_SCATTER(PrivateUse1) } IMPL_ALLTOALL(CPU) +IMPL_ALLTOALL(XPU) IMPL_ALLTOALL(CUDA) IMPL_ALLTOALL(PrivateUse1) @@ -424,6 +441,7 @@ IMPL_ALLTOALL(PrivateUse1) } IMPL_ALLTOALL_BASE(CPU) +IMPL_ALLTOALL_BASE(XPU) IMPL_ALLTOALL_BASE(CUDA) IMPL_ALLTOALL_BASE(PrivateUse1) @@ -439,6 +457,7 @@ IMPL_ALLTOALL_BASE(PrivateUse1) } IMPL_BARRIER(CPU) +IMPL_BARRIER(XPU) IMPL_BARRIER(CUDA) IMPL_BARRIER(PrivateUse1) // NOLINTEND(cppcoreguidelines-pro-type-const-cast) @@ -491,6 +510,7 @@ namespace { #define REGISTER_C10D_OP(FUNC) \ REGISTER_C10D_OP1(FUNC, CPU) \ REGISTER_C10D_OP1(FUNC, CUDA) \ + REGISTER_C10D_OP1(FUNC, XPU) \ REGISTER_C10D_OP1(FUNC, PrivateUse1) // Now we start to register ops with the three device keys diff --git a/torch/csrc/distributed/c10d/ProcessGroup.cpp b/torch/csrc/distributed/c10d/ProcessGroup.cpp index 75635bc68aed4f..70356b3bf382ce 100644 --- a/torch/csrc/distributed/c10d/ProcessGroup.cpp +++ b/torch/csrc/distributed/c10d/ProcessGroup.cpp @@ -21,6 +21,8 @@ static ProcessGroup::BackendType strToBackendType(std::string_view backend) { return ProcessGroup::BackendType::GLOO; } else if (backend == "nccl") { return ProcessGroup::BackendType::NCCL; + } else if (backend == "xccl") { + return ProcessGroup::BackendType::XCCL; } else if (backend == "ucc") { return ProcessGroup::BackendType::UCC; } else if (backend == "mpi") { diff --git a/torch/csrc/distributed/c10d/ProcessGroup.hpp b/torch/csrc/distributed/c10d/ProcessGroup.hpp index acf8c9c354a76b..73fc2bda701327 100644 --- a/torch/csrc/distributed/c10d/ProcessGroup.hpp +++ b/torch/csrc/distributed/c10d/ProcessGroup.hpp @@ -70,6 +70,7 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder { UCC = 3, MPI = 4, CUSTOM = 5, + XCCL = 6, }; // Not used, set for backwards compatibility and only used for TypeDef in @@ -489,6 +490,7 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder { // TODO: HACK for backend name to get sequence number for that backend. if (backendType == ProcessGroup::BackendType::GLOO || backendType == ProcessGroup::BackendType::NCCL || + backendType == ProcessGroup::BackendType::XCCL || backendType == ProcessGroup::BackendType::UCC) { getDefaultBackend()->setSequenceNumberForGroup(); } else { @@ -510,6 +512,7 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder { // TODO: HACK for backend name to get sequence number for that backend. if (backendType == ProcessGroup::BackendType::GLOO || backendType == ProcessGroup::BackendType::NCCL || + backendType == ProcessGroup::BackendType::XCCL || backendType == ProcessGroup::BackendType::UCC) { return getDefaultBackend()->getSequenceNumberForGroup(); } else { diff --git a/torch/csrc/distributed/c10d/ProcessGroupXCCL.cpp b/torch/csrc/distributed/c10d/ProcessGroupXCCL.cpp new file mode 100644 index 00000000000000..5aeeb62bee1ece --- /dev/null +++ b/torch/csrc/distributed/c10d/ProcessGroupXCCL.cpp @@ -0,0 +1,401 @@ +#include +#include +#include +#include + +#ifdef USE_C10D_XCCL +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace c10d { + +namespace { +std::map xcclOps = { + {ReduceOp::MIN, ccl::reduction::min}, + {ReduceOp::MAX, ccl::reduction::max}, + {ReduceOp::SUM, ccl::reduction::sum}, + {ReduceOp::PRODUCT, ccl::reduction::prod}, +}; + +std::map xcclDatatypes = { + {at::kByte, ccl::datatype::uint8}, + {at::kChar, ccl::datatype::int8}, + {at::kInt, ccl::datatype::int32}, + {at::kLong, ccl::datatype::int64}, + {at::kHalf, ccl::datatype::float16}, + {at::kFloat, ccl::datatype::float32}, + {at::kDouble, ccl::datatype::float64}, + {at::kBFloat16, ccl::datatype::bfloat16}, + {at::kBool, ccl::datatype::uint8}, +}; + +XCCL_KVS kvs; +std::mutex kvs_mutex; + +XCCL_KVS get_kvs(int rank, c10d::Store& store) { + std::lock_guard lock(kvs_mutex); + if (kvs) + return kvs; + std::string storeKey = "xccl_kvs"; + + // Rank 0 broadcast the bootstrap network information to other ranks + if (rank == 0) { + kvs = ccl::create_main_kvs(); + ccl::kvs::address_type main_addr = kvs->get_address(); + auto ccl_kvs_addr = + std::vector(main_addr.begin(), main_addr.end()); + store.set(storeKey, ccl_kvs_addr); + } else { + auto ccl_kvs_addr = store.get(storeKey); + if (ccl_kvs_addr.size() != ccl::kvs::address_max_size) { + throw std::runtime_error("Unexpected ccl kvs addr from the store\n"); + } + ccl::kvs::address_type main_addr; + std::copy_n( + ccl_kvs_addr.begin(), ccl::kvs::address_max_size, main_addr.begin()); + kvs = ccl::create_kvs(main_addr); + } + + return kvs; +} + +void check_xpu_single_tensor(const at::Tensor& tensor) { + if (!tensor.is_xpu() || tensor.is_sparse()) { + C10_THROW_ERROR(ValueError, "Tensors must be XPU and dense"); + } + if (!tensor.is_contiguous(tensor.suggest_memory_format())) { + C10_THROW_ERROR(ValueError, "Tensors must be contiguous"); + } +} + +ccl::datatype getXcclDataType(at::ScalarType type) { + auto it = xcclDatatypes.find(type); + TORCH_CHECK_WITH( + TypeError, + it != xcclDatatypes.end(), + "Input tensor data type is not supported for XCCL process group: ", + type); + return it->second; +} + +ccl::reduction getXcclReduceOp(const ReduceOp& reduceOp, at::Tensor& input) { + try { + if (input.scalar_type() == at::kBool) { + if (reduceOp == ReduceOp::SUM) { + // For bool tensors, map sum to max, which both represent a bitwise or. + // This is to prevent overflow issues with sum, since we use uint8 to + // represent a bool (see xcclDatatypes mapping align with cuda). + return ccl::reduction::max; + } + } + return xcclOps.at(reduceOp); + } catch (const std::out_of_range&) { + switch (reduceOp) { + case ReduceOp::AVG: + C10_THROW_ERROR(ValueError, "Cannot use ReduceOp AVG with XCCL"); + break; + case ReduceOp::BAND: + C10_THROW_ERROR(ValueError, "Cannot use ReduceOp.BAND with XCCL"); + break; + case ReduceOp::BOR: + C10_THROW_ERROR(ValueError, "Cannot use ReduceOp.BOR with XCCL"); + break; + case ReduceOp::BXOR: + C10_THROW_ERROR(ValueError, "Cannot use ReduceOp.BXOR with XCCL"); + break; + default: + C10_THROW_ERROR(ValueError, "Unhandled ReduceOp"); + break; + } + } +} + +} // namespace + +static std::mutex xcclCommDevIdxMapMutex; +static std::unordered_map, int> xcclCommDevIdxMap; +constexpr int64_t kSynchronizeBusyWaitMillis = 10; + +ProcessGroupXCCL::WorkXCCL::WorkXCCL( + at::Device& device, + int rank, + OpType opType, + const std::optional>& inputs) + : Work(rank, opType, "profilingTitle", inputs), + device_(device), + workStartTime_(std::chrono::steady_clock::now()) { + unsigned char enable_timing = 0; + xcclEndEvent_ = std::make_shared(enable_timing); +} + +ProcessGroupXCCL::WorkXCCL::WorkXCCL(const WorkXCCL& w) + : Work(w.rank_, w.opType_), + device_(w.device_), + xcclEndEvent_(w.xcclEndEvent_), + blockingWait_(w.blockingWait_), + workStartTime_(w.workStartTime_) {} + +ProcessGroupXCCL::WorkXCCL::~WorkXCCL() = default; + +bool ProcessGroupXCCL::WorkXCCL::checkTimeout( + std::optional timeout) { + auto currentTimepoint = std::chrono::steady_clock::now(); + auto timeElapsed = std::chrono::duration_cast( + currentTimepoint - workStartTime_); + std::chrono::milliseconds opTimeout = std::chrono::milliseconds(60000); + + auto workTimeout = timeout ? *timeout : opTimeout; + + if (timeElapsed < workTimeout) + return false; + return true; +} + +bool ProcessGroupXCCL::WorkXCCL::isCompleted() { + if (xcclEndEvent_ && xcclEndEvent_->query()) { + return true; + } + return false; +} + +void ProcessGroupXCCL::WorkXCCL::synchronize() { + synchronizeInternal(kNoTimeout); +} + +void ProcessGroupXCCL::WorkXCCL::synchronizeStream() { + auto currentStream = at::xpu::getCurrentXPUStream(device_.index()); + // Block the current stream on the XCCL stream + xcclEndEvent_->block(currentStream); +} + +void ProcessGroupXCCL::WorkXCCL::synchronizeInternal( + std::chrono::milliseconds timeout) { + synchronizeStream(); + + if (blockingWait_) { + while (!isCompleted()) { + bool timedOut = checkTimeout( + timeout == kNoTimeout ? std::nullopt : std::make_optional(timeout)); + if (timedOut) { + break; + } + std::this_thread::sleep_for( + std::chrono::milliseconds(kSynchronizeBusyWaitMillis)); + } + } +} + +bool ProcessGroupXCCL::WorkXCCL::wait(std::chrono::milliseconds timeout) { + synchronizeInternal(timeout); + return true; +} + +ProcessGroupXCCL::ProcessGroupXCCL( + const c10::intrusive_ptr& store, + int rank, + int size) + : Backend(rank, size), store_(store) { + blockingWait_ = getCvarBool(TORCH_XCCL_BLOCKING_WAIT, false); + init(); + + // Intel oneCCL requires passing CCL_LOCAL_RANK and CCL_LOCAL_SIZE for non-MPI + // launchers. + if (!with_mpirun()) { + int local_rank = getXCCLEnvVar("LOCAL_RANK"); + int local_world_size = getXCCLEnvVar("LOCAL_WORLD_SIZE"); + if (local_rank == -1 || local_world_size == -1) { + local_rank = rank; + local_world_size = size; + } + setXCCLEnvVar("CCL_PROCESS_LAUNCHER", "none"); + setXCCLEnvVar("CCL_LOCAL_RANK", local_rank); + setXCCLEnvVar("CCL_LOCAL_SIZE", local_world_size); + } +} + +ProcessGroupXCCL::~ProcessGroupXCCL() = default; + +c10::intrusive_ptr ProcessGroupXCCL::initWork( + at::Device& device, + int rank, + OpType opType, + const std::vector& inputs, + const std::vector& outputs) { + auto r = c10::make_intrusive( + device, rank, opType, std::optional>(inputs)); + return r; +} + +std::shared_ptr ProcessGroupXCCL::getXCCLComm( + const std::string& deviceKey, + at::Device& device) { + if (deviceKey.empty()) { + C10_THROW_ERROR( + DistBackendError, + "Not able to create/get the XCCL Communicator since " + "the devices are empty "); + } + + { + std::lock_guard lock(mutex_); + if (devXCCLCommMap_.find(deviceKey) != devXCCLCommMap_.end()) { + return devXCCLCommMap_[deviceKey]; + } + } + + std::shared_ptr XCCLComm; + + XCCL_KVS kvs = get_kvs(rank_, *store_); + + int numRanks, rank; + numRanks = getSize(); + rank = getRank(); + + c10::impl::VirtualGuardImpl impl(device.type()); + c10::Stream stream = impl.getStream(device); + sycl::queue& q = c10::xpu::XPUStream(stream).queue(); + + auto ctx = ccl::create_context(q.get_context()); + ccl::vector_class> devs_rank; + devs_rank.emplace_back(rank, ccl::create_device(q.get_device())); + + auto comms = ccl::create_communicators(numRanks, devs_rank, ctx, kvs); + XCCLComm = std::make_shared(std::move(comms[0])); + + { + std::lock_guard lock(mutex_); + inInitializationCommMap_.emplace(deviceKey, XCCLComm); + } + + xcclStreams_.emplace(deviceKey, std::move(stream)); + + auto it = inInitializationCommMap_.find(deviceKey); + if (it != inInitializationCommMap_.end()) { + devXCCLCommMap_.emplace(deviceKey, std::move(it->second)); + inInitializationCommMap_.erase(deviceKey); + + xcclCommDevIdxMapMutex.lock(); + xcclCommDevIdxMap.emplace(XCCLComm, device.index()); + xcclCommDevIdxMapMutex.unlock(); + } + + it = devXCCLCommMap_.find(deviceKey); + TORCH_INTERNAL_ASSERT( + it != devXCCLCommMap_.end(), "Communicators not populated in cache!"); + + return it->second; +} + +template +c10::intrusive_ptr ProcessGroupXCCL::collective( + at::Tensor& input, + at::Tensor& output, + Fn fn, + PreProcess pre, + PostProcess post, + OpType opType) { + using traits = function_traits; + using attr_t = typename traits::template arg<2>::type; + attr_t attr = ccl::create_operation_attr(); + + auto device = input.device(); + const auto key = std::to_string(device.index()); + auto comm = getXCCLComm(key, device); + + auto stream = xcclStreams_.at(key); + std::vector outputs{output}; + + c10::intrusive_ptr work; + + work = initWork(device, rank_, opType); + + work->outputs_ = + std::make_shared>(std::move(outputs)); + c10::xpu::XPUCachingAllocator::recordStream( + input.storage().data_ptr(), stream); + + auto ccl_stream = ccl::create_stream(stream.queue()); + + fn(input, output, attr, *comm, ccl_stream); + + work->xcclEndEvent_->record(stream); + + std::vector streams = {stream.unwrap()}; + c10::MultiStreamGuard streamGuard(streams); + std::vector devices{device}; + work->future_ = c10::make_intrusive( + c10::ListType::create(c10::TensorType::get()), devices); + work->future_->markCompleted(at::IValue(*work->outputs_)); + work->blockingWait_ = blockingWait_; + + return work; +} + +template +c10::intrusive_ptr ProcessGroupXCCL::collective( + at::Tensor& input, + at::Tensor& output, + Fn fn, + OpType opType) { + return collective( + input, + output, + fn, + [](at::xpu::XPUStream&, + c10::intrusive_ptr& work) {}, + [](at::xpu::XPUStream&, + c10::intrusive_ptr& work) {}, + opType); +} + +c10::intrusive_ptr ProcessGroupXCCL::allreduce( + std::vector& tensors, + const AllreduceOptions& opts) { + TORCH_CHECK( + tensors.size() == 1, "Expecting one tensor only but got multiple"); + auto tensor = tensors.back(); + check_xpu_single_tensor(tensor); + return collective( + tensor, + tensor, + [&](at::Tensor& input, + at::Tensor& output, + ccl::allreduce_attr attr, + xcclComm_t& comm, + ccl::stream& stream) { + auto xcclDataType = getXcclDataType(input.scalar_type()); + auto xcclReduceOp = getXcclReduceOp(opts.reduceOp, input); + ccl::event ret_evt; + ret_evt = ccl::allreduce( + input.data_ptr(), + output.data_ptr(), + (size_t)input.numel(), + xcclDataType, + xcclReduceOp, + comm, + stream, + attr); + return ret_evt; + }, + OpType::ALLREDUCE); +} + +} // namespace c10d + +#endif // USE_C10D_XCCL diff --git a/torch/csrc/distributed/c10d/ProcessGroupXCCL.hpp b/torch/csrc/distributed/c10d/ProcessGroupXCCL.hpp new file mode 100644 index 00000000000000..14a9f398a8cbe7 --- /dev/null +++ b/torch/csrc/distributed/c10d/ProcessGroupXCCL.hpp @@ -0,0 +1,308 @@ +#pragma once + +#if defined(__linux__) +#include +#include +#include +#include +#endif + +#ifdef USE_C10D_XCCL +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +namespace c10d { + +namespace { +int getXCCLEnvVar(std::string envVarName) { + char* stringValue = std::getenv(envVarName.c_str()); + if (stringValue != nullptr) { + try { + int val = std::stoi(stringValue); + return val; + } catch (std::exception& e) { + TORCH_CHECK( + false, + "Invalid value for environment variable: " + std::string(envVarName)); + } + } else { + return -1; + } +} + +template +void setXCCLEnvVar(const std::string& envVarName, T val) { + if constexpr (std::is_same_v) { + setenv(envVarName.c_str(), std::to_string(val).c_str(), 1); + } else if constexpr (std::is_same_v) { + setenv(envVarName.c_str(), val.c_str(), 1); + } +} + +bool with_mpirun() { + return (getenv("MPI_LOCALRANKID") || getenv("MPI_LOCALNRANKS") || + getenv("PMI_RANK") || getenv("PMI_SIZE") || getenv("PMIX_RANK")) + ? true + : false; +} +} // namespace + +static std::vector TORCH_XCCL_BLOCKING_WAIT = { + "TORCH_XCCL_BLOCKING_WAIT", + "XCCL_BLOCKING_WAIT"}; + +using xcclComm_t = ccl::communicator; +using XCCL_KVS = ccl::shared_ptr_class; +constexpr const char* XCCL_BACKEND_NAME = "xccl"; + +class TORCH_API ProcessGroupXCCL : public Backend { + public: + class WorkXCCL : public Work { + public: + WorkXCCL( + at::Device& device, + int rank, + OpType opType, + const std::optional>& inputs = std::nullopt); + WorkXCCL(const WorkXCCL& w); + ~WorkXCCL() override; + + bool isCompleted() override; + + bool isSuccess() const override { + TORCH_CHECK( + false, "ProcessGroupXCCL::WorkXCCL::isSuccess not implemented"); + } + + void abort() override { + TORCH_CHECK(false, "ProcessGroupXCCL::WorkXCCL::abort not implemented"); + } + + void synchronize() override; + + void synchronizeStream(); + + bool wait(std::chrono::milliseconds timeout = kNoTimeout) override; + + c10::intrusive_ptr getFuture() override { + return future_; + } + + std::vector result() override { + TORCH_CHECK(false, "ProcessGroupXCCL::WorkXCCL::result not implemented"); + } + + bool checkTimeout( + std::optional timeout = std::nullopt); + + protected: + at::Device device_; + std::shared_ptr xcclEndEvent_; + bool blockingWait_ = false; + std::chrono::time_point workStartTime_; + + private: + void synchronizeInternal(std::chrono::milliseconds timeout); + std::shared_ptr> outputs_; + c10::intrusive_ptr future_; + friend class ProcessGroupXCCL; + }; + + ProcessGroupXCCL(const c10::intrusive_ptr& store, int rank, int size); + + C10_DEPRECATED ProcessGroupXCCL( + const c10::intrusive_ptr& store, + int rank, + int size, + const std::string& groupName) + : ProcessGroupXCCL(store, rank, size) {} + + ~ProcessGroupXCCL() override; + + const std::string getBackendName() const override { + return std::string(XCCL_BACKEND_NAME); + } + + std::shared_ptr getXCCLComm( + const std::string& deviceKey, + at::Device& device); + + virtual c10::intrusive_ptr initWork( + at::Device& device, + int rank, + OpType opType, + const std::vector& inputs = {}, + const std::vector& outputs = {}); + + template + c10::intrusive_ptr collective( + at::Tensor& input, + at::Tensor& output, + Fn fn, + OpType opType); + + template + c10::intrusive_ptr collective( + at::Tensor& input, + at::Tensor& output, + Fn fn, + PreProcess pre, + PostProcess post, + OpType opType); + + c10::intrusive_ptr allreduce( + std::vector& tensors, + const AllreduceOptions& opts = AllreduceOptions()) override; + + c10::intrusive_ptr allreduce_coalesced( + std::vector& tensors, + const AllreduceCoalescedOptions& opts = + AllreduceCoalescedOptions()) override { + TORCH_CHECK(false, "ProcessGroupXCCL::allreduce_coalesced not implemented"); + } + + c10::intrusive_ptr reduce( + std::vector& tensors, + const ReduceOptions& opts = ReduceOptions()) override { + TORCH_CHECK(false, "ProcessGroupXCCL::reduce not implemented"); + } + + c10::intrusive_ptr broadcast( + std::vector& tensors, + const BroadcastOptions& opts = BroadcastOptions()) override { + TORCH_CHECK(false, "ProcessGroupXCCL::broadcast not implemented"); + } + + c10::intrusive_ptr allgather( + std::vector>& outputTensors, + std::vector& inputTensors, + const AllgatherOptions& opts = AllgatherOptions()) override { + TORCH_CHECK(false, "ProcessGroupXCCL::allgather not implemented"); + } + + c10::intrusive_ptr _allgather_base( + at::Tensor& outputbuffer, + at::Tensor& inputbuffer, + const AllgatherOptions& opts = AllgatherOptions()) override { + TORCH_CHECK(false, "ProcessGroupXCCL::_allgather_base not implemented"); + } + + c10::intrusive_ptr allgather_coalesced( + std::vector>& outputTensorLists, + std::vector& inputTensors, + const AllgatherOptions& opts = AllgatherOptions()) override { + TORCH_CHECK(false, "ProcessGroupXCCL::allgather_coalesced not implemented"); + } + + c10::intrusive_ptr allgather_into_tensor_coalesced( + std::vector& outputs, + std::vector& inputs, + const AllgatherOptions& opts = AllgatherOptions()) override { + TORCH_CHECK( + false, + "ProcessGroupXCCL::allgather_into_tensor_coalesced not implemented"); + } + + c10::intrusive_ptr reduce_scatter( + std::vector& outputTensors, + std::vector>& inputTensors, + const ReduceScatterOptions& opts = ReduceScatterOptions()) override { + TORCH_CHECK(false, "ProcessGroupXCCL::reduce_scatter not implemented"); + } + + c10::intrusive_ptr _reduce_scatter_base( + at::Tensor& outputTensor, + at::Tensor& inputTensor, + const ReduceScatterOptions& opts = ReduceScatterOptions()) override { + TORCH_CHECK( + false, "ProcessGroupXCCL::_reduce_scatter_base not implemented"); + } + + c10::intrusive_ptr reduce_scatter_tensor_coalesced( + std::vector& outputs, + std::vector& inputs, + const ReduceScatterOptions& opts = ReduceScatterOptions()) override { + TORCH_CHECK( + false, + "ProcessGroupXCCL::reduce_scatter_tensor_coalesced not implemented"); + } + + c10::intrusive_ptr barrier( + const BarrierOptions& opts = BarrierOptions()) override { + TORCH_CHECK(false, "ProcessGroupXCCL::barrier not implemented"); + } + + c10::intrusive_ptr alltoall_base( + at::Tensor& outputTensor, + at::Tensor& inputTensor, + std::vector& outputSplitSizes, + std::vector& inputSplitSizes, + const AllToAllOptions& opts = AllToAllOptions()) override { + TORCH_CHECK(false, "ProcessGroupXCCL::alltoall_base not implemented"); + } + + c10::intrusive_ptr alltoall( + std::vector& outputTensors, + std::vector& inputTensors, + const AllToAllOptions& opts = AllToAllOptions()) override { + TORCH_CHECK(false, "ProcessGroupXCCL::alltoall not implemented"); + } + + c10::intrusive_ptr send( + std::vector& tensors, + int dstRank, + int tag) override { + TORCH_CHECK(false, "ProcessGroupXCCL::send not implemented"); + } + + c10::intrusive_ptr recv( + std::vector& tensors, + int srcRank, + int tag) override { + TORCH_CHECK(false, "ProcessGroupXCCL::recv not implemented"); + } + + c10::intrusive_ptr gather( + std::vector>& outputTensors, + std::vector& inputTensors, + const GatherOptions& opts = GatherOptions()) override { + TORCH_CHECK(false, "ProcessGroupXCCL::gather not implemented"); + } + + c10::intrusive_ptr scatter( + std::vector& outputTensors, + std::vector>& inputTensors, + const ScatterOptions& opts = ScatterOptions()) override { + TORCH_CHECK(false, "ProcessGroupXCCL::scatter not implemented"); + } + + protected: + std::unordered_map xcclStreams_; + std::unordered_map> + inInitializationCommMap_; + std::unordered_map> devXCCLCommMap_; + c10::intrusive_ptr store_; + std::mutex mutex_; + bool blockingWait_ = false; +}; +} // namespace c10d + +#endif // USE_C10D_XCCL diff --git a/torch/csrc/distributed/c10d/init.cpp b/torch/csrc/distributed/c10d/init.cpp index c8f9dff37f06e2..e3ed6d6bd4bcb4 100644 --- a/torch/csrc/distributed/c10d/init.cpp +++ b/torch/csrc/distributed/c10d/init.cpp @@ -37,6 +37,10 @@ #include #endif +#ifdef USE_C10D_XCCL +#include +#endif + #include #include #include @@ -2232,6 +2236,7 @@ The hook must have the following signature: .value("UNDEFINED", ::c10d::ProcessGroup::BackendType::UNDEFINED) .value("GLOO", ::c10d::ProcessGroup::BackendType::GLOO) .value("NCCL", ::c10d::ProcessGroup::BackendType::NCCL) + .value("XCCL", ::c10d::ProcessGroup::BackendType::XCCL) .value("UCC", ::c10d::ProcessGroup::BackendType::UCC) .value("MPI", ::c10d::ProcessGroup::BackendType::MPI) .value("CUSTOM", ::c10d::ProcessGroup::BackendType::CUSTOM) @@ -2877,6 +2882,23 @@ Example:: py::call_guard()); #endif +#ifdef USE_C10D_XCCL + auto processGroupXCCL = + intrusive_ptr_no_gil_destructor_class_<::c10d::ProcessGroupXCCL>( + module, "ProcessGroupXCCL", backend) + .def( + py::init([](const c10::intrusive_ptr<::c10d::Store>& store, + int rank, + int size) { + return c10::make_intrusive<::c10d::ProcessGroupXCCL>( + store, rank, size); + }), + py::arg("store"), + py::arg("rank"), + py::arg("size"), + py::call_guard()); +#endif + py::enum_<::c10d::OpType>(module, "OpType") .value("BROADCAST", ::c10d::OpType::BROADCAST) .value("ALLREDUCE", ::c10d::OpType::ALLREDUCE) diff --git a/torch/distributed/distributed_c10d.py b/torch/distributed/distributed_c10d.py index 45e096985143a3..9fa3224873c9fc 100644 --- a/torch/distributed/distributed_c10d.py +++ b/torch/distributed/distributed_c10d.py @@ -87,6 +87,7 @@ "is_nccl_available", "is_torchelastic_launched", "is_ucc_available", + "is_xccl_available", "isend", "monitored_barrier", "new_group", @@ -130,6 +131,7 @@ _NCCL_AVAILABLE = True _GLOO_AVAILABLE = True _UCC_AVAILABLE = True +_XCCL_AVAILABLE = True _pickler = pickle.Pickler _unpickler = pickle.Unpickler @@ -193,6 +195,14 @@ def _export_c_types() -> None: except ImportError: _UCC_AVAILABLE = False +try: + from torch._C._distributed_c10d import ProcessGroupXCCL + + ProcessGroupXCCL.__module__ = "torch.distributed.distributed_c10d" + __all__ += ["ProcessGroupXCCL"] +except ImportError: + _XCCL_AVAILABLE = False + logger = logging.getLogger(__name__) PG_WRAPPER_STORE_PREFIX = "pg_wrapper" @@ -222,7 +232,7 @@ class Backend(str): """ An enum-like class for backends. - Available backends: GLOO, NCCL, UCC, MPI, and other registered backends. + Available backends: GLOO, NCCL, UCC, MPI, XCCL, and other registered backends. The values of this class are lowercase strings, e.g., ``"gloo"``. They can be accessed as attributes, e.g., ``Backend.NCCL``. @@ -242,21 +252,24 @@ class Backend(str): NCCL = "nccl" UCC = "ucc" MPI = "mpi" + XCCL = "xccl" _BackendPlugin = namedtuple("_BackendPlugin", ["creator_fn", "extended_api"]) _plugins: Dict[str, _BackendPlugin] = {} - backend_list = [UNDEFINED, GLOO, NCCL, UCC, MPI] + backend_list = [UNDEFINED, GLOO, NCCL, XCCL, UCC, MPI] default_device_backend_map: Dict[str, str] = { "cpu": GLOO, "cuda": NCCL, + "xpu": XCCL, } backend_capability: Dict[str, List[str]] = { GLOO: ["cpu", "cuda"], NCCL: ["cuda"], + XCCL: ["xpu"], UCC: ["cpu", "cuda"], MPI: ["cpu", "cuda"], } @@ -265,6 +278,7 @@ class Backend(str): UNDEFINED: ProcessGroup.BackendType.UNDEFINED, GLOO: ProcessGroup.BackendType.GLOO, NCCL: ProcessGroup.BackendType.NCCL, + XCCL: ProcessGroup.BackendType.XCCL, UCC: ProcessGroup.BackendType.UCC, } @@ -1098,6 +1112,11 @@ def is_ucc_available() -> bool: return _UCC_AVAILABLE +def is_xccl_available() -> bool: + """Check if the XCCL backend is available.""" + return _XCCL_AVAILABLE + + def is_backend_available(backend: str) -> bool: """ Check backend availability. @@ -1350,6 +1369,10 @@ def _set_pg_timeout(timeout: timedelta, group: Optional[ProcessGroup] = None) -> backends.add(backend) # type: ignore[arg-type] elif is_gloo_available() and isinstance(backend, ProcessGroupGloo): backends.add(backend) # type: ignore[arg-type] + if torch.device("xpu") in devices and is_xccl_available(): + backend = group._get_backend(torch.device("xpu")) + if isinstance(backend, ProcessGroupXCCL): + backends.add(backend) # type: ignore[arg-type] if len(backends) == 0: warnings.warn("Set timeout is now only supported for either nccl or gloo.") for backend in backends: @@ -1385,7 +1408,7 @@ def init_process_group( Args: backend (str or Backend, optional): The backend to use. Depending on - build-time configurations, valid values include ``mpi``, ``gloo``, + build-time configurations, valid values include ``mpi``, ``gloo``, ``xccl``, ``nccl``, and ``ucc``. If the backend is not provided, then both a ``gloo`` and ``nccl`` backend will be created, see notes below for how multiple backends are managed. This field can be given as a lowercase string @@ -1651,10 +1674,13 @@ def _new_process_group_helper( "created, please use a different group name" ) - if device_id is not None and (device_id.index is None or device_id.type != "cuda"): + if device_id is not None and ( + device_id.index is None + or (device_id.type != "cuda" and device_id.type != "xpu") + ): raise ValueError( "init_process_group device_id parameter must be a cuda device with an " - "id, e.g. cuda:0, not just cuda or cpu" + "id, e.g. cuda:0, xpu, not just cuda or xpu or cpu" ) # Note: _new_process_group_helper is only called from init_process_group, which always provides a timeout value @@ -1762,7 +1788,6 @@ def _new_process_group_helper( pg_options = ProcessGroupNCCL.Options() pg_options.is_high_priority_stream = False pg_options._timeout = timeout - if split_from: pg_options.split_from = split_from pg_options.split_color = _process_group_color(global_ranks_in_group) @@ -1781,6 +1806,17 @@ def _new_process_group_helper( backend_prefix_store, group_rank, group_size, timeout=timeout ) backend_type = ProcessGroup.BackendType.UCC + elif backend_str == Backend.XCCL: + if not is_xccl_available(): + raise RuntimeError("Distributed package doesn't have XCCL built in") + if pg_options is not None: + assert isinstance( + pg_options, ProcessGroupXCCL.Options + ), "Expected pg_options argument to be of type ProcessGroupXCCL.Options" + backend_class = ProcessGroupXCCL( + backend_prefix_store, group_rank, group_size + ) + backend_type = ProcessGroup.BackendType.XCCL else: assert ( backend_str.upper() in Backend._plugins diff --git a/torch/testing/_internal/common_distributed.py b/torch/testing/_internal/common_distributed.py index d59102232f7db7..26bdcce6103120 100644 --- a/torch/testing/_internal/common_distributed.py +++ b/torch/testing/_internal/common_distributed.py @@ -180,7 +180,8 @@ def skip_if_lt_x_gpu(x): def decorator(func): @wraps(func) def wrapper(*args, **kwargs): - if torch.cuda.is_available() and torch.cuda.device_count() >= x: + if (torch.cuda.is_available() and torch.cuda.device_count() >= x) or \ + (torch.xpu.is_available() and torch.xpu.device_count() >= x): return func(*args, **kwargs) sys.exit(TEST_SKIPS[f"multi-gpu-{x}"].exit_code) @@ -320,6 +321,12 @@ def requires_nccl(): "c10d was not compiled with the NCCL backend", ) +def requires_xccl(): + return skip_but_pass_in_sandcastle_if( + not c10d.is_xccl_available(), + "c10d was not compiled with the XCCL backend", + ) + def requires_ucc(): return skip_but_pass_in_sandcastle_if( not c10d.is_ucc_available(), @@ -463,7 +470,7 @@ def init_multigpu_helper(world_size: int, backend: str): On a single node, all visible GPUs are evenly divided to subsets, each process only uses a subset. """ - nGPUs = torch.cuda.device_count() + nGPUs = torch.xpu.device_count() if torch.xpu.is_available() else torch.cuda.device_count() visible_devices = range(nGPUs) # If rank is less than or equal to number of available GPU's From 7af06345000f08d148fcbd3c4d4c7c0377510200 Mon Sep 17 00:00:00 2001 From: fduwjj Date: Sun, 25 Aug 2024 22:50:11 -0700 Subject: [PATCH 0002/1018] [FR] Fix the bug in FR script (e.g., checking all ranks dump check) (#134383) We somehow convert the rank to string which makes the ranks check fail. This fix now convert them all to int. Pull Request resolved: https://github.com/pytorch/pytorch/pull/134383 Approved by: https://github.com/c-p-i-o --- .../distributed/flight_recorder/test_fr_analysis.py | 4 ++-- tools/flight_recorder/components/builder.py | 2 +- tools/flight_recorder/components/types.py | 13 +++++++------ tools/flight_recorder/components/utils.py | 4 ++-- 4 files changed, 12 insertions(+), 11 deletions(-) diff --git a/test/distributed/flight_recorder/test_fr_analysis.py b/test/distributed/flight_recorder/test_fr_analysis.py index fd0c6df0cc9786..2e389de7694ab9 100644 --- a/test/distributed/flight_recorder/test_fr_analysis.py +++ b/test/distributed/flight_recorder/test_fr_analysis.py @@ -53,10 +53,10 @@ def test_match_one_event(self): ) e3 = create_one_event( - "alltoall", ("0", "default"), [[4, 4]], [[4, 4]], "scheduled", 1 + "all_to_all", ("0", "default"), [[4, 4]], [[4, 4]], "scheduled", 1 ) e4 = create_one_event( - "alltoall", ("0", "default"), [[4, 4]], [[4, 4]], "scheduled", 1 + "all_to_all", ("0", "default"), [[4, 4]], [[4, 4]], "scheduled", 1 ) self.assertEqual(match_one_event(e3, e4, membership), MatchState.UNDECIDED) diff --git a/tools/flight_recorder/components/builder.py b/tools/flight_recorder/components/builder.py index a0307c3b502870..b3cb9f792469a3 100644 --- a/tools/flight_recorder/components/builder.py +++ b/tools/flight_recorder/components/builder.py @@ -244,7 +244,7 @@ def build_collectives( nccl_calls.extend(reversed(reversed_calls)) else: has_undecided_case = False - errors = Set() + errors = set() for o in expected_ranks.intersection(set(other_ranks)): for i, e in enumerate(all_entries[o]): # type: ignore[index] # step over ops from other PGs diff --git a/tools/flight_recorder/components/types.py b/tools/flight_recorder/components/types.py index e55c2370f30cd0..4a33f3580e4c16 100644 --- a/tools/flight_recorder/components/types.py +++ b/tools/flight_recorder/components/types.py @@ -138,8 +138,7 @@ class Database(NamedTuple): "_reduce_scatter_base", "gather", "scatter", - "alltoall_base", - "alltoall", + "all_to_all", } P2P = { @@ -158,7 +157,7 @@ class MatchState(Enum): - COLLECTIVE_STATE_MISMATCH: The states of the collective not same, such as one finished while another just started or scheduled. - UNDECIDED: - The match status is ambiguous or cannot be determined, e.g., we might need to check all ranks for alltoall_base. + The match status is ambiguous or cannot be determined, e.g., we might need to check all ranks for all_to_all. """ FULLY_MATCHED = 1 @@ -171,6 +170,8 @@ class MatchState(Enum): def check_size_evenly_broadcasting( list1: List[Any], list2: List[Any], size: int ) -> bool: + if len(list1) != len(list2): + return False ratio = None for a, b in zip(list1, list2): current_ratio = int(a) / int(b) @@ -283,7 +284,7 @@ def match(self, other: "Op") -> MatchState: elif self.type in COLLECTIVES: if self.type != other.type: return MatchState.COLLECTIVE_TYPE_MISMATCH - if self.type in ["alltoall", "alltoall_base"]: + if self.type == "all_to_all": return MatchState.UNDECIDED if self.type != "scatter" and self.input_sizes != other.input_sizes: return MatchState.SIZE_OR_SYNTAX_MISMATCH @@ -297,14 +298,14 @@ def match(self, other: "Op") -> MatchState: "all_gather", "all_gather_base", ] and not check_size_evenly_broadcasting( - other.output_sizes, self.input_sizes, self.pg_size + other.output_sizes[0], self.input_sizes[0], self.pg_size ): return MatchState.SIZE_OR_SYNTAX_MISMATCH if self.type in [ "reduce_scatter", "_reduce_scatter_base", ] and not check_size_evenly_broadcasting( - other.input_sizes, self.output_sizes, self.pg_size + other.input_sizes[0], self.output_sizes[0], self.pg_size ): return MatchState.SIZE_OR_SYNTAX_MISMATCH # TODO: need to add more checks for gather and scatter. diff --git a/tools/flight_recorder/components/utils.py b/tools/flight_recorder/components/utils.py index ef0e9a9f1380f5..2f9d382261b913 100644 --- a/tools/flight_recorder/components/utils.py +++ b/tools/flight_recorder/components/utils.py @@ -202,8 +202,8 @@ def check_no_missing_dump_files( ) -> None: all_ranks = set() for membership in memberships: - all_ranks.add(str(membership.global_rank)) - dumps_ranks = set(entries.keys()) + all_ranks.add(int(membership.global_rank)) + dumps_ranks = {int(key) for key in entries.keys()} assert ( dumps_ranks == all_ranks ), f"Missing dump files from ranks {all_ranks - dumps_ranks}" From 4329a7b029eebc8bd80e21a7342e3585d4f607d4 Mon Sep 17 00:00:00 2001 From: "haozhe.zhu" Date: Mon, 26 Aug 2024 10:36:51 +0000 Subject: [PATCH 0003/1018] [inductor] support vec for atomic add (#131314) Depends on https://github.com/pytorch/pytorch/pull/130827 to have correct `index_expr` dtype Support vec for atomic add by scalar implementation. TestPlan: ``` python test/inductor/test_cpu_repro.py -k test_scatter_using_atomic_add_vec ``` Generated code for `test_scatter_using_atomic_add_vec` ``` cpp_fused_scatter_0 = async_compile.cpp_pybinding(['const float*', 'const int64_t*', 'const float*', 'float*'], ''' #include "/tmp/torchinductor_root/nn/cnnpkaxivwaa5rzng6qsyc4ao42vschogi3yk33ukwv3emlvxeqq.h" extern "C" void kernel(const float* in_ptr0, const int64_t* in_ptr1, const float* in_ptr2, float* out_ptr0) { { for(long x0=static_cast(0L); x0(16L); x0+=static_cast(16L)) { auto tmp0 = at::vec::Vectorized::loadu(in_ptr0 + static_cast(x0), 16); tmp0.store(out_ptr0 + static_cast(x0)); } #pragma omp simd simdlen(8) for(long x0=static_cast(16L); x0(25L); x0+=static_cast(1L)) { auto tmp0 = in_ptr0[static_cast(x0)]; out_ptr0[static_cast(x0)] = tmp0; } } { for(long x0=static_cast(0L); x0(16L); x0+=static_cast(16L)) { auto tmp0 = at::vec::VectorizedN::loadu(in_ptr1 + static_cast(x0), 16); auto tmp12 = at::vec::Vectorized::loadu(in_ptr2 + static_cast(x0), 16); auto tmp1 = 25L; auto tmp2 = c10::convert(tmp1); auto tmp3 = at::vec::VectorizedN(tmp2); auto tmp4 = tmp0 + tmp3; auto tmp5 = static_cast(0); auto tmp6 = at::vec::VectorizedN(tmp5); auto tmp7 = at::vec::VecMask(tmp0 < tmp6); auto tmp8 = decltype(tmp4)::blendv(tmp0, tmp4, tmp7.template cast()); auto tmp9 = [&] { __at_align__ std::array tmpbuf; tmp8.store(tmpbuf.data()); return tmpbuf; } () ; auto tmp10 = [&] { __at_align__ std::array tmpbuf; #pragma GCC unroll 16 for (long x0_inner = 0; x0_inner < 16; x0_inner++) { tmpbuf[x0_inner] = static_cast(tmp9[x0_inner]); } return at::vec::VectorizedN::loadu(tmpbuf.data(), 16); } () ; TORCH_CHECK((at::vec::VecMask((at::vec::VectorizedN(0) <= tmp10) & (tmp10 < at::vec::VectorizedN(25L)))).all_masked(), "index out of bounds: 0 <= tmp10 < 25L"); atomic_add_vec(out_ptr0, tmp8, tmp12); } #pragma omp simd simdlen(8) for(long x0=static_cast(16L); x0(20L); x0+=static_cast(1L)) { auto tmp0 = in_ptr1[static_cast(x0)]; auto tmp9 = in_ptr2[static_cast(x0)]; auto tmp1 = 25L; auto tmp2 = c10::convert(tmp1); auto tmp3 = decltype(tmp0)(tmp0 + tmp2); auto tmp4 = tmp0 < 0; auto tmp5 = tmp4 ? tmp3 : tmp0; auto tmp6 = tmp5; auto tmp7 = c10::convert(tmp6); TORCH_CHECK((0 <= tmp7) & (tmp7 < 25L), "index out of bounds: 0 <= tmp7 < 25L"); atomic_add(&out_ptr0[static_cast(tmp5)], static_cast(tmp9)); } } } ''') ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/131314 Approved by: https://github.com/jgong5, https://github.com/leslie-fang-intel --- test/inductor/test_cpu_repro.py | 26 +++++++++++++++++ torch/_inductor/codegen/cpp.py | 42 +++++++++++++++++++++------- torch/_inductor/codegen/cpp_prefix.h | 15 ++++++++++ 3 files changed, 73 insertions(+), 10 deletions(-) diff --git a/test/inductor/test_cpu_repro.py b/test/inductor/test_cpu_repro.py index 8c2691aaac4479..1ce3a731c788d3 100644 --- a/test/inductor/test_cpu_repro.py +++ b/test/inductor/test_cpu_repro.py @@ -1972,6 +1972,32 @@ def _internal_check( with config.patch({"cpp.dynamic_threads": True}), set_num_threads(1): _internal_check(fn, inps, "aten.scatter_reduce_") + @patch("torch.cuda.is_available", lambda: False) + @requires_vectorization + @torch._inductor.config.patch({"cpp.fallback_scatter_reduce_sum": False}) + def test_scatter_using_atomic_add_vec(self): + def fn(a, dim, index, b): + return aten.scatter(a, dim, index, b, reduce="add") + + inps = ( + torch.zeros(1, 1, 25), + 2, + torch.tensor([[[3, 5, 7, 9] * 5]]), + torch.ones(1, 1, 25), + ) + torch._dynamo.reset() + metrics.reset() + self.common(fn, inps) + assert metrics.generated_cpp_vec_kernel_count == 2 + + with set_num_threads(1), config.patch( + {"fx_graph_cache": False, "fx_graph_remote_cache": False} + ): + torch._dynamo.reset() + metrics.reset() + self.common(fn, inps) + assert metrics.generated_cpp_vec_kernel_count == 2 + @unittest.skipIf(IS_FBCODE, "Not yet runnable in fbcode") @requires_vectorization @patch("torch.cuda.is_available", lambda: False) diff --git a/torch/_inductor/codegen/cpp.py b/torch/_inductor/codegen/cpp.py index 89b4666b67fb15..08d7f3cc4b6fc7 100644 --- a/torch/_inductor/codegen/cpp.py +++ b/torch/_inductor/codegen/cpp.py @@ -2157,6 +2157,7 @@ def _load_or_store_non_contiguous( dtype: torch.dtype, buffer: Optional[IndentedBuffer] = None, store_value: Optional[Union[str, CppCSEVariable]] = None, + accu_store: bool = False, ) -> Optional[CppCSEVariable]: """ Load or store a vector in a non-contiguous way. The vector is initialized from an array that is @@ -2171,10 +2172,12 @@ def _load_or_store_non_contiguous( :param dtype: data type of `var` or `index` if `var` is None. :param buffer: the code buffer to write the generated code to. If None, we write to `self.loads`. :param store_value: the value to store. If None, we load the vector. + :param accu_store: whether accumulate the store_value to store_ptr. If True, a store_value should be provided :return: a CppCSEVariable that represents the loaded vector or None if it is a store. """ assert not store_value or var is not None, "store var must be provided" - + if accu_store: + assert store_value if buffer is None: buffer = self.loads @@ -2261,7 +2264,8 @@ def vec_to_array(vec_var: CppCSEVariable) -> CppCSEVariable: code.writeline(f"if ({load_mask})") stack.enter_context(code.indent()) if store_value: - code.writeline(f"{rhs} = tmpbuf[{itervar_inner}];") + conjunction = "+=" if accu_store else "=" + code.writeline(f"{rhs} {conjunction} tmpbuf[{itervar_inner}];") else: code.writeline(f"tmpbuf[{itervar_inner}] = {rhs};") if not store_value: @@ -2304,6 +2308,7 @@ def _get_store_line( var: str, index: sympy.Expr, dtype: torch.dtype, + accu_store: bool = False, ): """ Get a store line buffer that stores `value` into `var` at `index` of `dtype`. It handles @@ -2328,21 +2333,42 @@ def _get_store_line( code.writeline(f"{value}.store({var_expr}, {self.num_elems});") else: self._load_or_store_non_contiguous( - var, index, dtype, buffer=code, store_value=value + var, index, dtype, buffer=code, store_value=value, accu_store=accu_store ) return code def store(self, name, index, value, mode=None): assert "buf" in name - assert mode is None assert isinstance(value, CppCSEVariable), value if not value.is_vec: # this happens when we store a scalar into a vectorized buffer like "fill" value = self.broadcast(value) var = self.args.output(name) index = self.rename_indexing(index) - code = self._get_store_line(value, var, index, V.graph.get_dtype(name)) - self.stores.splice(code.map(lambda x: DeferredLine(name, x))) + dtype = V.graph.get_dtype(name) + if mode is None: + code = self._get_store_line(value, var, index, dtype) + self.stores.splice(code.map(lambda x: DeferredLine(name, x))) + elif mode == "atomic_add": + if not config.cpp.dynamic_threads and self.num_threads == 1: + code = self._get_store_line( + f"{value}", + var, + index, + dtype, + accu_store=True, + ) + self.stores.splice(code.map(lambda x: DeferredLine(name, x))) + else: + n_src = self._get_num_vectors(dtype) + n_idx = self._get_num_vectors(torch.int64) + cdtype = DTYPE_TO_CPP[dtype] + index = ops.index_expr(index, torch.int64).value + assert index.is_vec + line = f"atomic_add_vec<{cdtype}, {n_idx}, {n_src}>({var}, {index}, {value});" + self.stores.writeline(DeferredLine(name, line)) + else: + raise NotImplementedError(f"store mode={mode}") def reduction(self, dtype, src_dtype, reduction_type, value): assert reduction_type in VECTORIZABLE_RTYPES @@ -3081,10 +3107,6 @@ def store(self, name, index, value, mode=None): assert "buf" in name index = self.rename_indexing(index) - if mode: - self.disable_vec(f"store mode: {mode}") - return self.simd_vec - return self.simd_vec def reduction(self, dtype, src_dtype, reduction_type, value): diff --git a/torch/_inductor/codegen/cpp_prefix.h b/torch/_inductor/codegen/cpp_prefix.h index e1937d26dac181..d3cf6270a3ec44 100644 --- a/torch/_inductor/codegen/cpp_prefix.h +++ b/torch/_inductor/codegen/cpp_prefix.h @@ -628,6 +628,21 @@ atomic_add(volatile T *addr, T offset) { atomic_addr->fetch_add(offset, std::memory_order_relaxed); } +#if INDUCTOR_USE_VECTOR_TYPES() +template +void atomic_add_vec(T *addr, at::vec::VectorizedN index, at::vec::VectorizedN offset) { + constexpr int len = at::vec::VectorizedN::size(); + static_assert(len <= at::vec::VectorizedN::size()); + __at_align__ std::array tmpbuf; + __at_align__ std::array tmpidx; + offset.store(tmpbuf.data()); + index.store(tmpidx.data()); + for (int i = 0; i < len; i++){ + atomic_add(addr + tmpidx[i], tmpbuf[i]); + } +} +#endif + void mm_get_thread_blocking( int num_threads, int64_t M, From b7cdbd65779d96c5e6a4f19d4573f0ada4c25412 Mon Sep 17 00:00:00 2001 From: "Wang, Chuanqi" Date: Mon, 26 Aug 2024 20:53:15 +0800 Subject: [PATCH 0004/1018] [CD] fix xpu nightly wheel test env (#134395) --- .circleci/scripts/binary_linux_test.sh | 6 ------ 1 file changed, 6 deletions(-) diff --git a/.circleci/scripts/binary_linux_test.sh b/.circleci/scripts/binary_linux_test.sh index 5d92c9099bff99..c30814982ef9e6 100755 --- a/.circleci/scripts/binary_linux_test.sh +++ b/.circleci/scripts/binary_linux_test.sh @@ -116,12 +116,6 @@ if [[ "$PACKAGE_TYPE" == libtorch ]]; then cd /tmp/libtorch fi -if [[ "$GPU_ARCH_TYPE" == xpu ]]; then - # Refer https://www.intel.com/content/www/us/en/developer/articles/tool/pytorch-prerequisites-for-intel-gpu/2-5.html - source /opt/intel/oneapi/pytorch-gpu-dev-0.5/oneapi-vars.sh - source /opt/intel/oneapi/pti/latest/env/vars.sh -fi - # Test the package /builder/check_binary.sh From 02a6e5638aa667b56179eba28c41ebb849934b35 Mon Sep 17 00:00:00 2001 From: Andrey Talman Date: Mon, 26 Aug 2024 13:40:17 +0000 Subject: [PATCH 0005/1018] Revert "[CD] fix xpu nightly wheel test env (#134395)" (#134461) This reverts commit 96738c9d756fbd64e6f2eba67f711d3e18f1630c. Merged without pytorchmergebot command by mistake Pull Request resolved: https://github.com/pytorch/pytorch/pull/134461 Approved by: https://github.com/jeanschmidt --- .circleci/scripts/binary_linux_test.sh | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/.circleci/scripts/binary_linux_test.sh b/.circleci/scripts/binary_linux_test.sh index c30814982ef9e6..5d92c9099bff99 100755 --- a/.circleci/scripts/binary_linux_test.sh +++ b/.circleci/scripts/binary_linux_test.sh @@ -116,6 +116,12 @@ if [[ "$PACKAGE_TYPE" == libtorch ]]; then cd /tmp/libtorch fi +if [[ "$GPU_ARCH_TYPE" == xpu ]]; then + # Refer https://www.intel.com/content/www/us/en/developer/articles/tool/pytorch-prerequisites-for-intel-gpu/2-5.html + source /opt/intel/oneapi/pytorch-gpu-dev-0.5/oneapi-vars.sh + source /opt/intel/oneapi/pti/latest/env/vars.sh +fi + # Test the package /builder/check_binary.sh From 127228a8799ce1b33d7c5229a0ac19c741a01e3b Mon Sep 17 00:00:00 2001 From: Benjamin Glass Date: Fri, 23 Aug 2024 18:57:15 +0000 Subject: [PATCH 0006/1018] Remove unnecessary test skip (#134250) Pull Request resolved: https://github.com/pytorch/pytorch/pull/134250 Approved by: https://github.com/amjames, https://github.com/janeyx99 --- torch/testing/_internal/common_methods_invocations.py | 8 -------- 1 file changed, 8 deletions(-) diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py index 7081637d4a3e81..59f7f0e9462038 100644 --- a/torch/testing/_internal/common_methods_invocations.py +++ b/torch/testing/_internal/common_methods_invocations.py @@ -10113,14 +10113,6 @@ def __call__(self, opinfo, device, dtype, requires_grad, **kwargs): "test_meta_inplace", dtypes=integral_types_and(torch.bool), ), - # FIXME: fails check - # https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/TensorIterator.cpp#L508-L510 - DecorateInfo( - unittest.skip("Skipped!"), - "TestForeach", - "test_parity", - dtypes=(torch.bool,), - ), ), ), ForeachFuncInfo( From 2b22928d813b7fe9b938c739ec8e6d9d9bd4861e Mon Sep 17 00:00:00 2001 From: Thanh Ha Date: Mon, 26 Aug 2024 14:56:00 +0000 Subject: [PATCH 0007/1018] Migrate Windows bin jobs to runner determinator (#134231) Update Windows binary workflows to use the runner determinator script. Closes: pytorch/ci-infra#262 Pull Request resolved: https://github.com/pytorch/pytorch/pull/134231 Approved by: https://github.com/ZainRizvi --- .../windows_binary_build_workflow.yml.j2 | 22 ++- ...generated-windows-binary-conda-nightly.yml | 152 ++++++++++++------ ...ted-windows-binary-libtorch-debug-main.yml | 17 +- ...-windows-binary-libtorch-debug-nightly.yml | 44 +++-- ...d-windows-binary-libtorch-release-main.yml | 17 +- ...indows-binary-libtorch-release-nightly.yml | 44 +++-- ...generated-windows-binary-wheel-nightly.yml | 152 ++++++++++++------ 7 files changed, 317 insertions(+), 131 deletions(-) diff --git a/.github/templates/windows_binary_build_workflow.yml.j2 b/.github/templates/windows_binary_build_workflow.yml.j2 index d5aca578b9024a..06646f449fc64c 100644 --- a/.github/templates/windows_binary_build_workflow.yml.j2 +++ b/.github/templates/windows_binary_build_workflow.yml.j2 @@ -53,10 +53,20 @@ env: !{{ common.concurrency(build_environment) }} jobs: + get-label-type: + name: get-label-type + uses: ./.github/workflows/_runner-determinator.yml + with: + triggering_actor: ${{ github.triggering_actor }} + issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }} + curr_branch: ${{ github.head_ref || github.ref_name }} + curr_ref_type: ${{ github.ref_type }} + {%- for config in build_configs %} !{{ config["build_name"] }}-build: if: ${{ github.repository_owner == 'pytorch' }} - runs-on: windows.4xlarge.nonephemeral + needs: get-label-type + runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge.nonephemeral" timeout-minutes: !{{ common.timeout_minutes }} !{{ upload.binary_env(config, True) }} {%- if config.pytorch_extra_install_requirements is defined and config.pytorch_extra_install_requirements|d('')|length > 0 %} @@ -85,15 +95,17 @@ jobs: !{{ common.wait_and_kill_ssh_windows('pytorch') }} !{{ config["build_name"] }}-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} - needs: !{{ config["build_name"] }}-build + needs: + - !{{ config["build_name"] }}-build + - get-label-type {%- if config["gpu_arch_type"] == "cuda" %} {%- if branches == "nightly" %} - runs-on: windows.8xlarge.nvidia.gpu + runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.8xlarge.nvidia.gpu" {%- else %} - runs-on: windows.8xlarge.nvidia.gpu.nonephemeral + runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.8xlarge.nvidia.gpu.nonephemeral" {%- endif %} {%- else %} - runs-on: windows.4xlarge.nonephemeral + runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge.nonephemeral" {%- endif %} timeout-minutes: !{{ common.timeout_minutes }} !{{ upload.binary_env(config, True) }} diff --git a/.github/workflows/generated-windows-binary-conda-nightly.yml b/.github/workflows/generated-windows-binary-conda-nightly.yml index da0dd504c75e5f..dac8a32e14a648 100644 --- a/.github/workflows/generated-windows-binary-conda-nightly.yml +++ b/.github/workflows/generated-windows-binary-conda-nightly.yml @@ -32,9 +32,18 @@ concurrency: cancel-in-progress: true jobs: + get-label-type: + name: get-label-type + uses: ./.github/workflows/_runner-determinator.yml + with: + triggering_actor: ${{ github.triggering_actor }} + issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }} + curr_branch: ${{ github.head_ref || github.ref_name }} + curr_ref_type: ${{ github.ref_type }} conda-py3_9-cpu-build: if: ${{ github.repository_owner == 'pytorch' }} - runs-on: windows.4xlarge.nonephemeral + needs: get-label-type + runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge.nonephemeral" timeout-minutes: 240 env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch @@ -145,8 +154,10 @@ jobs: .github\scripts\kill_active_ssh_sessions.ps1 conda-py3_9-cpu-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} - needs: conda-py3_9-cpu-build - runs-on: windows.4xlarge.nonephemeral + needs: + - conda-py3_9-cpu-build + - get-label-type + runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge.nonephemeral" timeout-minutes: 240 env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch @@ -276,7 +287,8 @@ jobs: uses: ./.github/workflows/_binary-upload.yml conda-py3_9-cuda11_8-build: if: ${{ github.repository_owner == 'pytorch' }} - runs-on: windows.4xlarge.nonephemeral + needs: get-label-type + runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge.nonephemeral" timeout-minutes: 240 env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch @@ -388,8 +400,10 @@ jobs: .github\scripts\kill_active_ssh_sessions.ps1 conda-py3_9-cuda11_8-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} - needs: conda-py3_9-cuda11_8-build - runs-on: windows.8xlarge.nvidia.gpu + needs: + - conda-py3_9-cuda11_8-build + - get-label-type + runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.8xlarge.nvidia.gpu" timeout-minutes: 240 env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch @@ -521,7 +535,8 @@ jobs: uses: ./.github/workflows/_binary-upload.yml conda-py3_9-cuda12_1-build: if: ${{ github.repository_owner == 'pytorch' }} - runs-on: windows.4xlarge.nonephemeral + needs: get-label-type + runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge.nonephemeral" timeout-minutes: 240 env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch @@ -633,8 +648,10 @@ jobs: .github\scripts\kill_active_ssh_sessions.ps1 conda-py3_9-cuda12_1-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} - needs: conda-py3_9-cuda12_1-build - runs-on: windows.8xlarge.nvidia.gpu + needs: + - conda-py3_9-cuda12_1-build + - get-label-type + runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.8xlarge.nvidia.gpu" timeout-minutes: 240 env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch @@ -766,7 +783,8 @@ jobs: uses: ./.github/workflows/_binary-upload.yml conda-py3_9-cuda12_4-build: if: ${{ github.repository_owner == 'pytorch' }} - runs-on: windows.4xlarge.nonephemeral + needs: get-label-type + runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge.nonephemeral" timeout-minutes: 240 env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch @@ -878,8 +896,10 @@ jobs: .github\scripts\kill_active_ssh_sessions.ps1 conda-py3_9-cuda12_4-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} - needs: conda-py3_9-cuda12_4-build - runs-on: windows.8xlarge.nvidia.gpu + needs: + - conda-py3_9-cuda12_4-build + - get-label-type + runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.8xlarge.nvidia.gpu" timeout-minutes: 240 env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch @@ -1011,7 +1031,8 @@ jobs: uses: ./.github/workflows/_binary-upload.yml conda-py3_10-cpu-build: if: ${{ github.repository_owner == 'pytorch' }} - runs-on: windows.4xlarge.nonephemeral + needs: get-label-type + runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge.nonephemeral" timeout-minutes: 240 env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch @@ -1122,8 +1143,10 @@ jobs: .github\scripts\kill_active_ssh_sessions.ps1 conda-py3_10-cpu-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} - needs: conda-py3_10-cpu-build - runs-on: windows.4xlarge.nonephemeral + needs: + - conda-py3_10-cpu-build + - get-label-type + runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge.nonephemeral" timeout-minutes: 240 env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch @@ -1253,7 +1276,8 @@ jobs: uses: ./.github/workflows/_binary-upload.yml conda-py3_10-cuda11_8-build: if: ${{ github.repository_owner == 'pytorch' }} - runs-on: windows.4xlarge.nonephemeral + needs: get-label-type + runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge.nonephemeral" timeout-minutes: 240 env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch @@ -1365,8 +1389,10 @@ jobs: .github\scripts\kill_active_ssh_sessions.ps1 conda-py3_10-cuda11_8-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} - needs: conda-py3_10-cuda11_8-build - runs-on: windows.8xlarge.nvidia.gpu + needs: + - conda-py3_10-cuda11_8-build + - get-label-type + runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.8xlarge.nvidia.gpu" timeout-minutes: 240 env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch @@ -1498,7 +1524,8 @@ jobs: uses: ./.github/workflows/_binary-upload.yml conda-py3_10-cuda12_1-build: if: ${{ github.repository_owner == 'pytorch' }} - runs-on: windows.4xlarge.nonephemeral + needs: get-label-type + runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge.nonephemeral" timeout-minutes: 240 env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch @@ -1610,8 +1637,10 @@ jobs: .github\scripts\kill_active_ssh_sessions.ps1 conda-py3_10-cuda12_1-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} - needs: conda-py3_10-cuda12_1-build - runs-on: windows.8xlarge.nvidia.gpu + needs: + - conda-py3_10-cuda12_1-build + - get-label-type + runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.8xlarge.nvidia.gpu" timeout-minutes: 240 env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch @@ -1743,7 +1772,8 @@ jobs: uses: ./.github/workflows/_binary-upload.yml conda-py3_10-cuda12_4-build: if: ${{ github.repository_owner == 'pytorch' }} - runs-on: windows.4xlarge.nonephemeral + needs: get-label-type + runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge.nonephemeral" timeout-minutes: 240 env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch @@ -1855,8 +1885,10 @@ jobs: .github\scripts\kill_active_ssh_sessions.ps1 conda-py3_10-cuda12_4-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} - needs: conda-py3_10-cuda12_4-build - runs-on: windows.8xlarge.nvidia.gpu + needs: + - conda-py3_10-cuda12_4-build + - get-label-type + runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.8xlarge.nvidia.gpu" timeout-minutes: 240 env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch @@ -1988,7 +2020,8 @@ jobs: uses: ./.github/workflows/_binary-upload.yml conda-py3_11-cpu-build: if: ${{ github.repository_owner == 'pytorch' }} - runs-on: windows.4xlarge.nonephemeral + needs: get-label-type + runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge.nonephemeral" timeout-minutes: 240 env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch @@ -2099,8 +2132,10 @@ jobs: .github\scripts\kill_active_ssh_sessions.ps1 conda-py3_11-cpu-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} - needs: conda-py3_11-cpu-build - runs-on: windows.4xlarge.nonephemeral + needs: + - conda-py3_11-cpu-build + - get-label-type + runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge.nonephemeral" timeout-minutes: 240 env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch @@ -2230,7 +2265,8 @@ jobs: uses: ./.github/workflows/_binary-upload.yml conda-py3_11-cuda11_8-build: if: ${{ github.repository_owner == 'pytorch' }} - runs-on: windows.4xlarge.nonephemeral + needs: get-label-type + runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge.nonephemeral" timeout-minutes: 240 env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch @@ -2342,8 +2378,10 @@ jobs: .github\scripts\kill_active_ssh_sessions.ps1 conda-py3_11-cuda11_8-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} - needs: conda-py3_11-cuda11_8-build - runs-on: windows.8xlarge.nvidia.gpu + needs: + - conda-py3_11-cuda11_8-build + - get-label-type + runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.8xlarge.nvidia.gpu" timeout-minutes: 240 env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch @@ -2475,7 +2513,8 @@ jobs: uses: ./.github/workflows/_binary-upload.yml conda-py3_11-cuda12_1-build: if: ${{ github.repository_owner == 'pytorch' }} - runs-on: windows.4xlarge.nonephemeral + needs: get-label-type + runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge.nonephemeral" timeout-minutes: 240 env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch @@ -2587,8 +2626,10 @@ jobs: .github\scripts\kill_active_ssh_sessions.ps1 conda-py3_11-cuda12_1-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} - needs: conda-py3_11-cuda12_1-build - runs-on: windows.8xlarge.nvidia.gpu + needs: + - conda-py3_11-cuda12_1-build + - get-label-type + runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.8xlarge.nvidia.gpu" timeout-minutes: 240 env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch @@ -2720,7 +2761,8 @@ jobs: uses: ./.github/workflows/_binary-upload.yml conda-py3_11-cuda12_4-build: if: ${{ github.repository_owner == 'pytorch' }} - runs-on: windows.4xlarge.nonephemeral + needs: get-label-type + runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge.nonephemeral" timeout-minutes: 240 env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch @@ -2832,8 +2874,10 @@ jobs: .github\scripts\kill_active_ssh_sessions.ps1 conda-py3_11-cuda12_4-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} - needs: conda-py3_11-cuda12_4-build - runs-on: windows.8xlarge.nvidia.gpu + needs: + - conda-py3_11-cuda12_4-build + - get-label-type + runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.8xlarge.nvidia.gpu" timeout-minutes: 240 env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch @@ -2965,7 +3009,8 @@ jobs: uses: ./.github/workflows/_binary-upload.yml conda-py3_12-cpu-build: if: ${{ github.repository_owner == 'pytorch' }} - runs-on: windows.4xlarge.nonephemeral + needs: get-label-type + runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge.nonephemeral" timeout-minutes: 240 env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch @@ -3076,8 +3121,10 @@ jobs: .github\scripts\kill_active_ssh_sessions.ps1 conda-py3_12-cpu-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} - needs: conda-py3_12-cpu-build - runs-on: windows.4xlarge.nonephemeral + needs: + - conda-py3_12-cpu-build + - get-label-type + runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge.nonephemeral" timeout-minutes: 240 env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch @@ -3207,7 +3254,8 @@ jobs: uses: ./.github/workflows/_binary-upload.yml conda-py3_12-cuda11_8-build: if: ${{ github.repository_owner == 'pytorch' }} - runs-on: windows.4xlarge.nonephemeral + needs: get-label-type + runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge.nonephemeral" timeout-minutes: 240 env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch @@ -3319,8 +3367,10 @@ jobs: .github\scripts\kill_active_ssh_sessions.ps1 conda-py3_12-cuda11_8-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} - needs: conda-py3_12-cuda11_8-build - runs-on: windows.8xlarge.nvidia.gpu + needs: + - conda-py3_12-cuda11_8-build + - get-label-type + runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.8xlarge.nvidia.gpu" timeout-minutes: 240 env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch @@ -3452,7 +3502,8 @@ jobs: uses: ./.github/workflows/_binary-upload.yml conda-py3_12-cuda12_1-build: if: ${{ github.repository_owner == 'pytorch' }} - runs-on: windows.4xlarge.nonephemeral + needs: get-label-type + runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge.nonephemeral" timeout-minutes: 240 env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch @@ -3564,8 +3615,10 @@ jobs: .github\scripts\kill_active_ssh_sessions.ps1 conda-py3_12-cuda12_1-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} - needs: conda-py3_12-cuda12_1-build - runs-on: windows.8xlarge.nvidia.gpu + needs: + - conda-py3_12-cuda12_1-build + - get-label-type + runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.8xlarge.nvidia.gpu" timeout-minutes: 240 env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch @@ -3697,7 +3750,8 @@ jobs: uses: ./.github/workflows/_binary-upload.yml conda-py3_12-cuda12_4-build: if: ${{ github.repository_owner == 'pytorch' }} - runs-on: windows.4xlarge.nonephemeral + needs: get-label-type + runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge.nonephemeral" timeout-minutes: 240 env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch @@ -3809,8 +3863,10 @@ jobs: .github\scripts\kill_active_ssh_sessions.ps1 conda-py3_12-cuda12_4-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} - needs: conda-py3_12-cuda12_4-build - runs-on: windows.8xlarge.nvidia.gpu + needs: + - conda-py3_12-cuda12_4-build + - get-label-type + runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.8xlarge.nvidia.gpu" timeout-minutes: 240 env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch diff --git a/.github/workflows/generated-windows-binary-libtorch-debug-main.yml b/.github/workflows/generated-windows-binary-libtorch-debug-main.yml index 8ac413be0d65ea..6debd8b9a2dfcb 100644 --- a/.github/workflows/generated-windows-binary-libtorch-debug-main.yml +++ b/.github/workflows/generated-windows-binary-libtorch-debug-main.yml @@ -25,9 +25,18 @@ concurrency: cancel-in-progress: true jobs: + get-label-type: + name: get-label-type + uses: ./.github/workflows/_runner-determinator.yml + with: + triggering_actor: ${{ github.triggering_actor }} + issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }} + curr_branch: ${{ github.head_ref || github.ref_name }} + curr_ref_type: ${{ github.ref_type }} libtorch-cpu-shared-with-deps-debug-build: if: ${{ github.repository_owner == 'pytorch' }} - runs-on: windows.4xlarge.nonephemeral + needs: get-label-type + runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge.nonephemeral" timeout-minutes: 240 env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch @@ -142,8 +151,10 @@ jobs: .github\scripts\kill_active_ssh_sessions.ps1 libtorch-cpu-shared-with-deps-debug-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} - needs: libtorch-cpu-shared-with-deps-debug-build - runs-on: windows.4xlarge.nonephemeral + needs: + - libtorch-cpu-shared-with-deps-debug-build + - get-label-type + runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge.nonephemeral" timeout-minutes: 240 env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch diff --git a/.github/workflows/generated-windows-binary-libtorch-debug-nightly.yml b/.github/workflows/generated-windows-binary-libtorch-debug-nightly.yml index 60ba59556926f2..8f1576d39565ad 100644 --- a/.github/workflows/generated-windows-binary-libtorch-debug-nightly.yml +++ b/.github/workflows/generated-windows-binary-libtorch-debug-nightly.yml @@ -32,9 +32,18 @@ concurrency: cancel-in-progress: true jobs: + get-label-type: + name: get-label-type + uses: ./.github/workflows/_runner-determinator.yml + with: + triggering_actor: ${{ github.triggering_actor }} + issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }} + curr_branch: ${{ github.head_ref || github.ref_name }} + curr_ref_type: ${{ github.ref_type }} libtorch-cpu-shared-with-deps-debug-build: if: ${{ github.repository_owner == 'pytorch' }} - runs-on: windows.4xlarge.nonephemeral + needs: get-label-type + runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge.nonephemeral" timeout-minutes: 240 env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch @@ -149,8 +158,10 @@ jobs: .github\scripts\kill_active_ssh_sessions.ps1 libtorch-cpu-shared-with-deps-debug-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} - needs: libtorch-cpu-shared-with-deps-debug-build - runs-on: windows.4xlarge.nonephemeral + needs: + - libtorch-cpu-shared-with-deps-debug-build + - get-label-type + runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge.nonephemeral" timeout-minutes: 240 env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch @@ -288,7 +299,8 @@ jobs: uses: ./.github/workflows/_binary-upload.yml libtorch-cuda11_8-shared-with-deps-debug-build: if: ${{ github.repository_owner == 'pytorch' }} - runs-on: windows.4xlarge.nonephemeral + needs: get-label-type + runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge.nonephemeral" timeout-minutes: 240 env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch @@ -404,8 +416,10 @@ jobs: .github\scripts\kill_active_ssh_sessions.ps1 libtorch-cuda11_8-shared-with-deps-debug-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} - needs: libtorch-cuda11_8-shared-with-deps-debug-build - runs-on: windows.8xlarge.nvidia.gpu + needs: + - libtorch-cuda11_8-shared-with-deps-debug-build + - get-label-type + runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.8xlarge.nvidia.gpu" timeout-minutes: 240 env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch @@ -545,7 +559,8 @@ jobs: uses: ./.github/workflows/_binary-upload.yml libtorch-cuda12_1-shared-with-deps-debug-build: if: ${{ github.repository_owner == 'pytorch' }} - runs-on: windows.4xlarge.nonephemeral + needs: get-label-type + runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge.nonephemeral" timeout-minutes: 240 env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch @@ -661,8 +676,10 @@ jobs: .github\scripts\kill_active_ssh_sessions.ps1 libtorch-cuda12_1-shared-with-deps-debug-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} - needs: libtorch-cuda12_1-shared-with-deps-debug-build - runs-on: windows.8xlarge.nvidia.gpu + needs: + - libtorch-cuda12_1-shared-with-deps-debug-build + - get-label-type + runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.8xlarge.nvidia.gpu" timeout-minutes: 240 env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch @@ -802,7 +819,8 @@ jobs: uses: ./.github/workflows/_binary-upload.yml libtorch-cuda12_4-shared-with-deps-debug-build: if: ${{ github.repository_owner == 'pytorch' }} - runs-on: windows.4xlarge.nonephemeral + needs: get-label-type + runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge.nonephemeral" timeout-minutes: 240 env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch @@ -918,8 +936,10 @@ jobs: .github\scripts\kill_active_ssh_sessions.ps1 libtorch-cuda12_4-shared-with-deps-debug-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} - needs: libtorch-cuda12_4-shared-with-deps-debug-build - runs-on: windows.8xlarge.nvidia.gpu + needs: + - libtorch-cuda12_4-shared-with-deps-debug-build + - get-label-type + runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.8xlarge.nvidia.gpu" timeout-minutes: 240 env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch diff --git a/.github/workflows/generated-windows-binary-libtorch-release-main.yml b/.github/workflows/generated-windows-binary-libtorch-release-main.yml index ab00cdc8919ea9..b4ae0bc0130ae0 100644 --- a/.github/workflows/generated-windows-binary-libtorch-release-main.yml +++ b/.github/workflows/generated-windows-binary-libtorch-release-main.yml @@ -25,9 +25,18 @@ concurrency: cancel-in-progress: true jobs: + get-label-type: + name: get-label-type + uses: ./.github/workflows/_runner-determinator.yml + with: + triggering_actor: ${{ github.triggering_actor }} + issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }} + curr_branch: ${{ github.head_ref || github.ref_name }} + curr_ref_type: ${{ github.ref_type }} libtorch-cpu-shared-with-deps-release-build: if: ${{ github.repository_owner == 'pytorch' }} - runs-on: windows.4xlarge.nonephemeral + needs: get-label-type + runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge.nonephemeral" timeout-minutes: 240 env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch @@ -142,8 +151,10 @@ jobs: .github\scripts\kill_active_ssh_sessions.ps1 libtorch-cpu-shared-with-deps-release-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} - needs: libtorch-cpu-shared-with-deps-release-build - runs-on: windows.4xlarge.nonephemeral + needs: + - libtorch-cpu-shared-with-deps-release-build + - get-label-type + runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge.nonephemeral" timeout-minutes: 240 env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch diff --git a/.github/workflows/generated-windows-binary-libtorch-release-nightly.yml b/.github/workflows/generated-windows-binary-libtorch-release-nightly.yml index 842de97a1fbe99..c2d1b7009fff8d 100644 --- a/.github/workflows/generated-windows-binary-libtorch-release-nightly.yml +++ b/.github/workflows/generated-windows-binary-libtorch-release-nightly.yml @@ -32,9 +32,18 @@ concurrency: cancel-in-progress: true jobs: + get-label-type: + name: get-label-type + uses: ./.github/workflows/_runner-determinator.yml + with: + triggering_actor: ${{ github.triggering_actor }} + issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }} + curr_branch: ${{ github.head_ref || github.ref_name }} + curr_ref_type: ${{ github.ref_type }} libtorch-cpu-shared-with-deps-release-build: if: ${{ github.repository_owner == 'pytorch' }} - runs-on: windows.4xlarge.nonephemeral + needs: get-label-type + runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge.nonephemeral" timeout-minutes: 240 env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch @@ -149,8 +158,10 @@ jobs: .github\scripts\kill_active_ssh_sessions.ps1 libtorch-cpu-shared-with-deps-release-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} - needs: libtorch-cpu-shared-with-deps-release-build - runs-on: windows.4xlarge.nonephemeral + needs: + - libtorch-cpu-shared-with-deps-release-build + - get-label-type + runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge.nonephemeral" timeout-minutes: 240 env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch @@ -288,7 +299,8 @@ jobs: uses: ./.github/workflows/_binary-upload.yml libtorch-cuda11_8-shared-with-deps-release-build: if: ${{ github.repository_owner == 'pytorch' }} - runs-on: windows.4xlarge.nonephemeral + needs: get-label-type + runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge.nonephemeral" timeout-minutes: 240 env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch @@ -404,8 +416,10 @@ jobs: .github\scripts\kill_active_ssh_sessions.ps1 libtorch-cuda11_8-shared-with-deps-release-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} - needs: libtorch-cuda11_8-shared-with-deps-release-build - runs-on: windows.8xlarge.nvidia.gpu + needs: + - libtorch-cuda11_8-shared-with-deps-release-build + - get-label-type + runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.8xlarge.nvidia.gpu" timeout-minutes: 240 env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch @@ -545,7 +559,8 @@ jobs: uses: ./.github/workflows/_binary-upload.yml libtorch-cuda12_1-shared-with-deps-release-build: if: ${{ github.repository_owner == 'pytorch' }} - runs-on: windows.4xlarge.nonephemeral + needs: get-label-type + runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge.nonephemeral" timeout-minutes: 240 env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch @@ -661,8 +676,10 @@ jobs: .github\scripts\kill_active_ssh_sessions.ps1 libtorch-cuda12_1-shared-with-deps-release-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} - needs: libtorch-cuda12_1-shared-with-deps-release-build - runs-on: windows.8xlarge.nvidia.gpu + needs: + - libtorch-cuda12_1-shared-with-deps-release-build + - get-label-type + runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.8xlarge.nvidia.gpu" timeout-minutes: 240 env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch @@ -802,7 +819,8 @@ jobs: uses: ./.github/workflows/_binary-upload.yml libtorch-cuda12_4-shared-with-deps-release-build: if: ${{ github.repository_owner == 'pytorch' }} - runs-on: windows.4xlarge.nonephemeral + needs: get-label-type + runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge.nonephemeral" timeout-minutes: 240 env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch @@ -918,8 +936,10 @@ jobs: .github\scripts\kill_active_ssh_sessions.ps1 libtorch-cuda12_4-shared-with-deps-release-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} - needs: libtorch-cuda12_4-shared-with-deps-release-build - runs-on: windows.8xlarge.nvidia.gpu + needs: + - libtorch-cuda12_4-shared-with-deps-release-build + - get-label-type + runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.8xlarge.nvidia.gpu" timeout-minutes: 240 env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch diff --git a/.github/workflows/generated-windows-binary-wheel-nightly.yml b/.github/workflows/generated-windows-binary-wheel-nightly.yml index 98bdf0c6776838..7088e10b5c197b 100644 --- a/.github/workflows/generated-windows-binary-wheel-nightly.yml +++ b/.github/workflows/generated-windows-binary-wheel-nightly.yml @@ -32,9 +32,18 @@ concurrency: cancel-in-progress: true jobs: + get-label-type: + name: get-label-type + uses: ./.github/workflows/_runner-determinator.yml + with: + triggering_actor: ${{ github.triggering_actor }} + issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }} + curr_branch: ${{ github.head_ref || github.ref_name }} + curr_ref_type: ${{ github.ref_type }} wheel-py3_9-cpu-build: if: ${{ github.repository_owner == 'pytorch' }} - runs-on: windows.4xlarge.nonephemeral + needs: get-label-type + runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge.nonephemeral" timeout-minutes: 240 env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch @@ -146,8 +155,10 @@ jobs: .github\scripts\kill_active_ssh_sessions.ps1 wheel-py3_9-cpu-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} - needs: wheel-py3_9-cpu-build - runs-on: windows.4xlarge.nonephemeral + needs: + - wheel-py3_9-cpu-build + - get-label-type + runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge.nonephemeral" timeout-minutes: 240 env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch @@ -277,7 +288,8 @@ jobs: uses: ./.github/workflows/_binary-upload.yml wheel-py3_9-cuda11_8-build: if: ${{ github.repository_owner == 'pytorch' }} - runs-on: windows.4xlarge.nonephemeral + needs: get-label-type + runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge.nonephemeral" timeout-minutes: 240 env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch @@ -390,8 +402,10 @@ jobs: .github\scripts\kill_active_ssh_sessions.ps1 wheel-py3_9-cuda11_8-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} - needs: wheel-py3_9-cuda11_8-build - runs-on: windows.8xlarge.nvidia.gpu + needs: + - wheel-py3_9-cuda11_8-build + - get-label-type + runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.8xlarge.nvidia.gpu" timeout-minutes: 240 env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch @@ -523,7 +537,8 @@ jobs: uses: ./.github/workflows/_binary-upload.yml wheel-py3_9-cuda12_1-build: if: ${{ github.repository_owner == 'pytorch' }} - runs-on: windows.4xlarge.nonephemeral + needs: get-label-type + runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge.nonephemeral" timeout-minutes: 240 env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch @@ -636,8 +651,10 @@ jobs: .github\scripts\kill_active_ssh_sessions.ps1 wheel-py3_9-cuda12_1-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} - needs: wheel-py3_9-cuda12_1-build - runs-on: windows.8xlarge.nvidia.gpu + needs: + - wheel-py3_9-cuda12_1-build + - get-label-type + runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.8xlarge.nvidia.gpu" timeout-minutes: 240 env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch @@ -769,7 +786,8 @@ jobs: uses: ./.github/workflows/_binary-upload.yml wheel-py3_9-cuda12_4-build: if: ${{ github.repository_owner == 'pytorch' }} - runs-on: windows.4xlarge.nonephemeral + needs: get-label-type + runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge.nonephemeral" timeout-minutes: 240 env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch @@ -882,8 +900,10 @@ jobs: .github\scripts\kill_active_ssh_sessions.ps1 wheel-py3_9-cuda12_4-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} - needs: wheel-py3_9-cuda12_4-build - runs-on: windows.8xlarge.nvidia.gpu + needs: + - wheel-py3_9-cuda12_4-build + - get-label-type + runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.8xlarge.nvidia.gpu" timeout-minutes: 240 env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch @@ -1015,7 +1035,8 @@ jobs: uses: ./.github/workflows/_binary-upload.yml wheel-py3_10-cpu-build: if: ${{ github.repository_owner == 'pytorch' }} - runs-on: windows.4xlarge.nonephemeral + needs: get-label-type + runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge.nonephemeral" timeout-minutes: 240 env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch @@ -1127,8 +1148,10 @@ jobs: .github\scripts\kill_active_ssh_sessions.ps1 wheel-py3_10-cpu-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} - needs: wheel-py3_10-cpu-build - runs-on: windows.4xlarge.nonephemeral + needs: + - wheel-py3_10-cpu-build + - get-label-type + runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge.nonephemeral" timeout-minutes: 240 env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch @@ -1258,7 +1281,8 @@ jobs: uses: ./.github/workflows/_binary-upload.yml wheel-py3_10-cuda11_8-build: if: ${{ github.repository_owner == 'pytorch' }} - runs-on: windows.4xlarge.nonephemeral + needs: get-label-type + runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge.nonephemeral" timeout-minutes: 240 env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch @@ -1371,8 +1395,10 @@ jobs: .github\scripts\kill_active_ssh_sessions.ps1 wheel-py3_10-cuda11_8-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} - needs: wheel-py3_10-cuda11_8-build - runs-on: windows.8xlarge.nvidia.gpu + needs: + - wheel-py3_10-cuda11_8-build + - get-label-type + runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.8xlarge.nvidia.gpu" timeout-minutes: 240 env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch @@ -1504,7 +1530,8 @@ jobs: uses: ./.github/workflows/_binary-upload.yml wheel-py3_10-cuda12_1-build: if: ${{ github.repository_owner == 'pytorch' }} - runs-on: windows.4xlarge.nonephemeral + needs: get-label-type + runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge.nonephemeral" timeout-minutes: 240 env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch @@ -1617,8 +1644,10 @@ jobs: .github\scripts\kill_active_ssh_sessions.ps1 wheel-py3_10-cuda12_1-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} - needs: wheel-py3_10-cuda12_1-build - runs-on: windows.8xlarge.nvidia.gpu + needs: + - wheel-py3_10-cuda12_1-build + - get-label-type + runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.8xlarge.nvidia.gpu" timeout-minutes: 240 env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch @@ -1750,7 +1779,8 @@ jobs: uses: ./.github/workflows/_binary-upload.yml wheel-py3_10-cuda12_4-build: if: ${{ github.repository_owner == 'pytorch' }} - runs-on: windows.4xlarge.nonephemeral + needs: get-label-type + runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge.nonephemeral" timeout-minutes: 240 env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch @@ -1863,8 +1893,10 @@ jobs: .github\scripts\kill_active_ssh_sessions.ps1 wheel-py3_10-cuda12_4-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} - needs: wheel-py3_10-cuda12_4-build - runs-on: windows.8xlarge.nvidia.gpu + needs: + - wheel-py3_10-cuda12_4-build + - get-label-type + runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.8xlarge.nvidia.gpu" timeout-minutes: 240 env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch @@ -1996,7 +2028,8 @@ jobs: uses: ./.github/workflows/_binary-upload.yml wheel-py3_11-cpu-build: if: ${{ github.repository_owner == 'pytorch' }} - runs-on: windows.4xlarge.nonephemeral + needs: get-label-type + runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge.nonephemeral" timeout-minutes: 240 env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch @@ -2108,8 +2141,10 @@ jobs: .github\scripts\kill_active_ssh_sessions.ps1 wheel-py3_11-cpu-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} - needs: wheel-py3_11-cpu-build - runs-on: windows.4xlarge.nonephemeral + needs: + - wheel-py3_11-cpu-build + - get-label-type + runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge.nonephemeral" timeout-minutes: 240 env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch @@ -2239,7 +2274,8 @@ jobs: uses: ./.github/workflows/_binary-upload.yml wheel-py3_11-cuda11_8-build: if: ${{ github.repository_owner == 'pytorch' }} - runs-on: windows.4xlarge.nonephemeral + needs: get-label-type + runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge.nonephemeral" timeout-minutes: 240 env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch @@ -2352,8 +2388,10 @@ jobs: .github\scripts\kill_active_ssh_sessions.ps1 wheel-py3_11-cuda11_8-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} - needs: wheel-py3_11-cuda11_8-build - runs-on: windows.8xlarge.nvidia.gpu + needs: + - wheel-py3_11-cuda11_8-build + - get-label-type + runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.8xlarge.nvidia.gpu" timeout-minutes: 240 env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch @@ -2485,7 +2523,8 @@ jobs: uses: ./.github/workflows/_binary-upload.yml wheel-py3_11-cuda12_1-build: if: ${{ github.repository_owner == 'pytorch' }} - runs-on: windows.4xlarge.nonephemeral + needs: get-label-type + runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge.nonephemeral" timeout-minutes: 240 env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch @@ -2598,8 +2637,10 @@ jobs: .github\scripts\kill_active_ssh_sessions.ps1 wheel-py3_11-cuda12_1-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} - needs: wheel-py3_11-cuda12_1-build - runs-on: windows.8xlarge.nvidia.gpu + needs: + - wheel-py3_11-cuda12_1-build + - get-label-type + runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.8xlarge.nvidia.gpu" timeout-minutes: 240 env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch @@ -2731,7 +2772,8 @@ jobs: uses: ./.github/workflows/_binary-upload.yml wheel-py3_11-cuda12_4-build: if: ${{ github.repository_owner == 'pytorch' }} - runs-on: windows.4xlarge.nonephemeral + needs: get-label-type + runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge.nonephemeral" timeout-minutes: 240 env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch @@ -2844,8 +2886,10 @@ jobs: .github\scripts\kill_active_ssh_sessions.ps1 wheel-py3_11-cuda12_4-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} - needs: wheel-py3_11-cuda12_4-build - runs-on: windows.8xlarge.nvidia.gpu + needs: + - wheel-py3_11-cuda12_4-build + - get-label-type + runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.8xlarge.nvidia.gpu" timeout-minutes: 240 env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch @@ -2977,7 +3021,8 @@ jobs: uses: ./.github/workflows/_binary-upload.yml wheel-py3_12-cpu-build: if: ${{ github.repository_owner == 'pytorch' }} - runs-on: windows.4xlarge.nonephemeral + needs: get-label-type + runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge.nonephemeral" timeout-minutes: 240 env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch @@ -3089,8 +3134,10 @@ jobs: .github\scripts\kill_active_ssh_sessions.ps1 wheel-py3_12-cpu-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} - needs: wheel-py3_12-cpu-build - runs-on: windows.4xlarge.nonephemeral + needs: + - wheel-py3_12-cpu-build + - get-label-type + runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge.nonephemeral" timeout-minutes: 240 env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch @@ -3220,7 +3267,8 @@ jobs: uses: ./.github/workflows/_binary-upload.yml wheel-py3_12-cuda11_8-build: if: ${{ github.repository_owner == 'pytorch' }} - runs-on: windows.4xlarge.nonephemeral + needs: get-label-type + runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge.nonephemeral" timeout-minutes: 240 env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch @@ -3333,8 +3381,10 @@ jobs: .github\scripts\kill_active_ssh_sessions.ps1 wheel-py3_12-cuda11_8-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} - needs: wheel-py3_12-cuda11_8-build - runs-on: windows.8xlarge.nvidia.gpu + needs: + - wheel-py3_12-cuda11_8-build + - get-label-type + runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.8xlarge.nvidia.gpu" timeout-minutes: 240 env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch @@ -3466,7 +3516,8 @@ jobs: uses: ./.github/workflows/_binary-upload.yml wheel-py3_12-cuda12_1-build: if: ${{ github.repository_owner == 'pytorch' }} - runs-on: windows.4xlarge.nonephemeral + needs: get-label-type + runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge.nonephemeral" timeout-minutes: 240 env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch @@ -3579,8 +3630,10 @@ jobs: .github\scripts\kill_active_ssh_sessions.ps1 wheel-py3_12-cuda12_1-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} - needs: wheel-py3_12-cuda12_1-build - runs-on: windows.8xlarge.nvidia.gpu + needs: + - wheel-py3_12-cuda12_1-build + - get-label-type + runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.8xlarge.nvidia.gpu" timeout-minutes: 240 env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch @@ -3712,7 +3765,8 @@ jobs: uses: ./.github/workflows/_binary-upload.yml wheel-py3_12-cuda12_4-build: if: ${{ github.repository_owner == 'pytorch' }} - runs-on: windows.4xlarge.nonephemeral + needs: get-label-type + runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge.nonephemeral" timeout-minutes: 240 env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch @@ -3825,8 +3879,10 @@ jobs: .github\scripts\kill_active_ssh_sessions.ps1 wheel-py3_12-cuda12_4-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} - needs: wheel-py3_12-cuda12_4-build - runs-on: windows.8xlarge.nvidia.gpu + needs: + - wheel-py3_12-cuda12_4-build + - get-label-type + runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.8xlarge.nvidia.gpu" timeout-minutes: 240 env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch From 5f79ffd4eb195a13958245cf437b328a3aa9727c Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Mon, 26 Aug 2024 15:16:17 +0000 Subject: [PATCH 0008/1018] Revert "[dynamo] simplify implementation for `os.fspath` (#133801)" This reverts commit c5f6b72041144c00e240bcfdc783a5597c3d8928. Reverted https://github.com/pytorch/pytorch/pull/133801 on behalf of https://github.com/ZainRizvi due to Sorry, but this breaks internal tests because of using functools ([comment](https://github.com/pytorch/pytorch/pull/133778#issuecomment-2310445169)) --- torch/_dynamo/polyfills/__init__.py | 12 ++++++++++ torch/_dynamo/polyfills/loader.py | 1 - torch/_dynamo/polyfills/os.py | 36 ----------------------------- torch/_dynamo/variables/misc.py | 14 ++++++++++- 4 files changed, 25 insertions(+), 38 deletions(-) delete mode 100644 torch/_dynamo/polyfills/os.py diff --git a/torch/_dynamo/polyfills/__init__.py b/torch/_dynamo/polyfills/__init__.py index 9b465726754bbe..5e51b08ad40fd8 100644 --- a/torch/_dynamo/polyfills/__init__.py +++ b/torch/_dynamo/polyfills/__init__.py @@ -154,3 +154,15 @@ def instantiate_user_defined_class_object(cls, /, *args, **kwargs): if isinstance(obj, cls): obj.__init__(*args, **kwargs) return obj + + +def fspath(path): + # Python equivalent of os.fspath + if isinstance(path, (str, bytes)): + return path + elif hasattr(path, "__fspath__"): + return path.__fspath__() + else: + raise TypeError( + f"expected str, bytes or os.PathLike object, not {type(path).__name__}" + ) diff --git a/torch/_dynamo/polyfills/loader.py b/torch/_dynamo/polyfills/loader.py index cce1ba183ec695..a1c97ff0cc3dc1 100644 --- a/torch/_dynamo/polyfills/loader.py +++ b/torch/_dynamo/polyfills/loader.py @@ -15,7 +15,6 @@ "builtins", "functools", "itertools", - "os", ) POLYFILLED_MODULES: Tuple["ModuleType", ...] = tuple( importlib.import_module(f".{submodule}", package=polyfills.__name__) diff --git a/torch/_dynamo/polyfills/os.py b/torch/_dynamo/polyfills/os.py deleted file mode 100644 index 013cc81ac30e6f..00000000000000 --- a/torch/_dynamo/polyfills/os.py +++ /dev/null @@ -1,36 +0,0 @@ -""" -Python polyfills for os -""" - -from __future__ import annotations - -import os -from typing import AnyStr - -from ..decorators import substitute_in_graph - - -__all__ = ["fspath"] - - -# Copied from os.py in the standard library -@substitute_in_graph(os.fspath) -def fspath(path: AnyStr | os.PathLike[AnyStr]) -> AnyStr: - if isinstance(path, (str, bytes)): - return path - - path_type = type(path) - try: - path_repr = path_type.__fspath__(path) # type: ignore[arg-type] - except AttributeError: - if hasattr(path_type, "__fspath__"): - raise - raise TypeError( - f"expected str, bytes or os.PathLike object, not {path_type.__name__}", - ) from None - if isinstance(path_repr, (str, bytes)): - return path_repr # type: ignore[return-value] - raise TypeError( - f"expected {path_type.__name__}.__fspath__() to return str or bytes, " - f"not {type(path_repr).__name__}", - ) diff --git a/torch/_dynamo/variables/misc.py b/torch/_dynamo/variables/misc.py index 0e3ae4a77a3632..7dfbec067914c4 100644 --- a/torch/_dynamo/variables/misc.py +++ b/torch/_dynamo/variables/misc.py @@ -4,6 +4,7 @@ import functools import inspect import itertools +import os import random import re import sys @@ -14,7 +15,7 @@ import torch._numpy as tnp import torch.utils._pytree as pytree -from .. import config, variables +from .. import config, polyfills, variables from ..bytecode_transformation import create_call_function, create_instruction from ..create_parameter_op import do_not_convert_to_tracable_parameter from ..exc import unimplemented @@ -29,6 +30,7 @@ ) from ..utils import ( check_unspec_or_constant_args, + hashable, identity, is_tensor_base_attr_getter, proxy_args_kwargs, @@ -42,6 +44,10 @@ if TYPE_CHECKING: from torch._dynamo.symbolic_convert import InstructionTranslator +POLYFILL_SUPPORTED_PYTHON_MODULE_METHODS = { + os.fspath: polyfills.fspath, +} + class SuperVariable(VariableTracker): _nonvar_fields = { @@ -1107,6 +1113,12 @@ def var_getattr(self, tx: "InstructionTranslator", name): else: attr_value = self.value.__dict__[name] + if hashable(attr_value): + if polyfill_fn := POLYFILL_SUPPORTED_PYTHON_MODULE_METHODS.get( + attr_value, None + ): + return variables.UserFunctionVariable(polyfill_fn) + if self.source: new_source = AttrSource(self.source, name) return VariableBuilder(tx, new_source)(attr_value) From 99a1d3e94ac18780228a3f4306261ed208519d3e Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Mon, 26 Aug 2024 15:16:17 +0000 Subject: [PATCH 0009/1018] Revert "[dynamo][itertools] support `itertools.tee` (#133771)" This reverts commit 0e49b2f18e78386c8ed9ce540a8017411c7ab0cd. Reverted https://github.com/pytorch/pytorch/pull/133771 on behalf of https://github.com/ZainRizvi due to Sorry, but this breaks internal tests because of using functools ([comment](https://github.com/pytorch/pytorch/pull/133778#issuecomment-2310445169)) --- test/dynamo/test_misc.py | 16 ------------- torch/_dynamo/polyfills/itertools.py | 34 ---------------------------- torch/_dynamo/polyfills/loader.py | 1 - 3 files changed, 51 deletions(-) delete mode 100644 torch/_dynamo/polyfills/itertools.py diff --git a/test/dynamo/test_misc.py b/test/dynamo/test_misc.py index dc9e99f770111f..9689e2c12e05a4 100644 --- a/test/dynamo/test_misc.py +++ b/test/dynamo/test_misc.py @@ -10045,22 +10045,6 @@ def fn(l): self.assertEqual(eager, compiled) self.assertEqual(len(counters["graph_break"]), 0) - def test_itertools_tee(self): - counters.clear() - - def fn(l): - a, b = itertools.tee(l) - return list(a), list(b) - - l = [1, 2, 2, 3, 4, 4, 4, 1, 2] - eager = fn(l) - - compiled_fn = torch._dynamo.optimize(backend="eager", nopython=True)(fn) - compiled = compiled_fn(l) - - self.assertEqual(eager, compiled) - self.assertEqual(len(counters["graph_break"]), 0) - def test_list_iterator_contains(self): def fn(x): it = iter(["my_weight", "not_my_weight"]) diff --git a/torch/_dynamo/polyfills/itertools.py b/torch/_dynamo/polyfills/itertools.py deleted file mode 100644 index 7207acd6504b51..00000000000000 --- a/torch/_dynamo/polyfills/itertools.py +++ /dev/null @@ -1,34 +0,0 @@ -""" -Python polyfills for itertools -""" - -import itertools -from typing import Iterable, Iterator, Tuple, TypeVar - -from ..decorators import substitute_in_graph - - -__all__ = ["tee"] - - -_T = TypeVar("_T") - - -# Reference: https://docs.python.org/3/library/itertools.html#itertools.tee -@substitute_in_graph(itertools.tee) -def tee(iterable: Iterable[_T], n: int = 2, /) -> Tuple[Iterator[_T], ...]: - iterator = iter(iterable) - shared_link = [None, None] - - def _tee(link) -> Iterator[_T]: # type: ignore[no-untyped-def] - try: - while True: - if link[1] is None: - link[0] = next(iterator) - link[1] = [None, None] - value, link = link - yield value - except StopIteration: - return - - return tuple(_tee(shared_link) for _ in range(n)) diff --git a/torch/_dynamo/polyfills/loader.py b/torch/_dynamo/polyfills/loader.py index a1c97ff0cc3dc1..516ce1b9c94b01 100644 --- a/torch/_dynamo/polyfills/loader.py +++ b/torch/_dynamo/polyfills/loader.py @@ -14,7 +14,6 @@ POLYFILLED_MODULE_NAMES: Tuple[str, ...] = ( "builtins", "functools", - "itertools", ) POLYFILLED_MODULES: Tuple["ModuleType", ...] = tuple( importlib.import_module(f".{submodule}", package=polyfills.__name__) From 1190ec080ebc38c40b747a5a1b00b27c2535e1f5 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Mon, 26 Aug 2024 15:16:17 +0000 Subject: [PATCH 0010/1018] Revert "[dynamo] simplify implementation for `builtins.sum` (#133779)" This reverts commit 8d90392fb02ce5e6854e6b4dbcdc4a7bbd55f8e2. Reverted https://github.com/pytorch/pytorch/pull/133779 on behalf of https://github.com/ZainRizvi due to Sorry, but this breaks internal tests because of using functools ([comment](https://github.com/pytorch/pytorch/pull/133778#issuecomment-2310445169)) --- torch/_dynamo/polyfills/builtins.py | 13 +---------- torch/_dynamo/variables/builtin.py | 34 +++++++++++++++++++++++++++++ 2 files changed, 35 insertions(+), 12 deletions(-) diff --git a/torch/_dynamo/polyfills/builtins.py b/torch/_dynamo/polyfills/builtins.py index d9c3c2644c31ba..d04ea11fe3d7f5 100644 --- a/torch/_dynamo/polyfills/builtins.py +++ b/torch/_dynamo/polyfills/builtins.py @@ -3,9 +3,7 @@ """ import builtins -import functools -import operator -from typing import Iterable, TypeVar +from typing import Iterable from ..decorators import substitute_in_graph @@ -13,13 +11,9 @@ __all__ = [ "all", "any", - "sum", ] -_T = TypeVar("_T") - - @substitute_in_graph(builtins.all) def all(iterable: Iterable[object], /) -> bool: for elem in iterable: @@ -34,8 +28,3 @@ def any(iterable: Iterable[object], /) -> bool: if elem: return True return False - - -@substitute_in_graph(builtins.sum) # type: ignore[arg-type] -def sum(iterable: Iterable[_T], /, start: _T = 0) -> _T: # type: ignore[assignment] - return functools.reduce(operator.add, iterable, start) diff --git a/torch/_dynamo/variables/builtin.py b/torch/_dynamo/variables/builtin.py index ca972b282e387b..b4fbbbc794df1d 100644 --- a/torch/_dynamo/variables/builtin.py +++ b/torch/_dynamo/variables/builtin.py @@ -1585,6 +1585,40 @@ def call_filter(self, tx: "InstructionTranslator", fn, seq): except NotImplementedError: return + def call_sum(self, tx: "InstructionTranslator", seq, start=_SENTINEL): + # Special case for sum on tuple of floats and ints + if isinstance(seq, (variables.ListVariable, variables.TupleVariable)) and all( + isinstance(x, variables.ConstantVariable) + and isinstance(x.value, (int, float)) + for x in seq.items + ): + if start is self._SENTINEL: + return variables.ConstantVariable.create( + sum(x.value for x in seq.items), + ) + if isinstance(start, variables.ConstantVariable) and isinstance( + start.value, (int, float) + ): + return variables.ConstantVariable.create( + sum((x.value for x in seq.items), start=start.value), + ) + if seq.has_unpack_var_sequence(tx): + if start is self._SENTINEL: + start = variables.ConstantVariable.create(0) + items = seq.unpack_var_sequence(tx) + + from .builder import SourcelessBuilder + + return SourcelessBuilder.create(tx, functools.reduce).call_function( + tx, + [ + BuiltinVariable(operator.add), + variables.TupleVariable(items), + start, + ], + {}, + ) + def call_getattr( self, tx: "InstructionTranslator", From a3012e7d8b521b37f6f6b38d171c9bb80fba3d4e Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Mon, 26 Aug 2024 15:16:17 +0000 Subject: [PATCH 0011/1018] Revert "[dynamo] simplify implementation for `functools.reduce` (#133778)" This reverts commit 6c0b15e3828b8e2a0bd726a3e5d4e98c8ced5efe. Reverted https://github.com/pytorch/pytorch/pull/133778 on behalf of https://github.com/ZainRizvi due to Sorry, but this breaks internal tests because of using functools ([comment](https://github.com/pytorch/pytorch/pull/133778#issuecomment-2310445169)) --- torch/_dynamo/polyfills/functools.py | 46 ---------------------------- torch/_dynamo/polyfills/loader.py | 5 +-- torch/_dynamo/trace_rules.py | 1 + torch/_dynamo/variables/builtin.py | 18 ++++++++--- 4 files changed, 16 insertions(+), 54 deletions(-) delete mode 100644 torch/_dynamo/polyfills/functools.py diff --git a/torch/_dynamo/polyfills/functools.py b/torch/_dynamo/polyfills/functools.py deleted file mode 100644 index 329725b41e354a..00000000000000 --- a/torch/_dynamo/polyfills/functools.py +++ /dev/null @@ -1,46 +0,0 @@ -""" -Python polyfills for functools -""" - -import functools -from typing import Callable, Iterable, TypeVar - -from ..decorators import substitute_in_graph - - -__all__ = ["reduce"] - - -_T = TypeVar("_T") -_U = TypeVar("_U") - - -class _INITIAL_MISSING: - pass - - -# Reference: https://docs.python.org/3/library/functools.html#functools.reduce -@substitute_in_graph(functools.reduce) -def reduce( - function: Callable[[_U, _T], _U], - iterable: Iterable[_T], - initial: _U = _INITIAL_MISSING, # type: ignore[assignment] - /, -) -> _U: - it = iter(iterable) - - value: _U - if initial is _INITIAL_MISSING: - try: - value = next(it) # type: ignore[assignment] - except StopIteration: - raise TypeError( - "reduce() of empty iterable with no initial value", - ) from None - else: - value = initial - - for element in it: - value = function(value, element) - - return value diff --git a/torch/_dynamo/polyfills/loader.py b/torch/_dynamo/polyfills/loader.py index 516ce1b9c94b01..0a7e4aa628ada1 100644 --- a/torch/_dynamo/polyfills/loader.py +++ b/torch/_dynamo/polyfills/loader.py @@ -11,10 +11,7 @@ from types import ModuleType -POLYFILLED_MODULE_NAMES: Tuple[str, ...] = ( - "builtins", - "functools", -) +POLYFILLED_MODULE_NAMES: Tuple[str, ...] = ("builtins",) POLYFILLED_MODULES: Tuple["ModuleType", ...] = tuple( importlib.import_module(f".{submodule}", package=polyfills.__name__) for submodule in POLYFILLED_MODULE_NAMES diff --git a/torch/_dynamo/trace_rules.py b/torch/_dynamo/trace_rules.py index ccb29ab13b1479..fe026ea5037954 100644 --- a/torch/_dynamo/trace_rules.py +++ b/torch/_dynamo/trace_rules.py @@ -2999,6 +2999,7 @@ def _builtin_function_ids() -> Dict[int, str]: rv.update( { id(cast): "typing.cast", + id(functools.reduce): "functools.reduce", id(copy.deepcopy): "copy.deepcopy", } ) diff --git a/torch/_dynamo/variables/builtin.py b/torch/_dynamo/variables/builtin.py index b4fbbbc794df1d..86900dee30bc56 100644 --- a/torch/_dynamo/variables/builtin.py +++ b/torch/_dynamo/variables/builtin.py @@ -1606,10 +1606,7 @@ def call_sum(self, tx: "InstructionTranslator", seq, start=_SENTINEL): if start is self._SENTINEL: start = variables.ConstantVariable.create(0) items = seq.unpack_var_sequence(tx) - - from .builder import SourcelessBuilder - - return SourcelessBuilder.create(tx, functools.reduce).call_function( + return BuiltinVariable(functools.reduce).call_function( tx, [ BuiltinVariable(operator.add), @@ -1619,6 +1616,19 @@ def call_sum(self, tx: "InstructionTranslator", seq, start=_SENTINEL): {}, ) + def call_reduce( + self, tx: "InstructionTranslator", function, iterable, initial=_SENTINEL + ): + if iterable.has_unpack_var_sequence(tx): + items = iterable.unpack_var_sequence(tx) + if initial is self._SENTINEL: + value, items = items[0], items[1:] + else: + value = initial + for element in items: + value = function.call_function(tx, [value, element], {}) + return value + def call_getattr( self, tx: "InstructionTranslator", From cab2d973bca6fcac548138c1e2dbb2efb3bf2681 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Mon, 26 Aug 2024 15:19:26 +0000 Subject: [PATCH 0012/1018] Revert "c10d/logging: add C10D_LOCK_GUARD (#134131)" This reverts commit 4c28a0eb0ba437c1b7db559f63f8bec17bd48f69. Reverted https://github.com/pytorch/pytorch/pull/134131 on behalf of https://github.com/ZainRizvi due to Sorry but this causes formatting errors internally which make it fail to build. See D61759282 ([comment](https://github.com/pytorch/pytorch/pull/134131#issuecomment-2310455878)) --- caffe2/CMakeLists.txt | 2 - test/cpp/c10d/CMakeLists.txt | 2 - test/cpp/c10d/LoggingTest.cpp | 50 ------------------- test/cpp/c10d/ProcessGroupNCCLErrorsTest.cpp | 4 +- torch/csrc/distributed/c10d/NCCLUtils.cpp | 10 ++-- torch/csrc/distributed/c10d/NCCLUtils.hpp | 21 ++++---- .../distributed/c10d/ProcessGroupGloo.cpp | 2 +- .../distributed/c10d/ProcessGroupNCCL.cpp | 50 +++++++++---------- .../distributed/c10d/ProcessGroupNCCL.hpp | 18 +++---- torch/csrc/distributed/c10d/Work.cpp | 13 +++-- torch/csrc/distributed/c10d/Work.hpp | 4 +- torch/csrc/distributed/c10d/logging.cpp | 16 ------ torch/csrc/distributed/c10d/logging.h | 17 ------- 13 files changed, 59 insertions(+), 150 deletions(-) delete mode 100644 test/cpp/c10d/LoggingTest.cpp diff --git a/caffe2/CMakeLists.txt b/caffe2/CMakeLists.txt index d44a8da210462f..5ee9968ba484fa 100644 --- a/caffe2/CMakeLists.txt +++ b/caffe2/CMakeLists.txt @@ -927,7 +927,6 @@ elseif(USE_CUDA) set(CUDA_LINK_LIBRARIES_KEYWORD) torch_compile_options(torch_cuda) # see cmake/public/utils.cmake target_compile_definitions(torch_cuda PRIVATE USE_CUDA) - target_link_libraries(torch_cuda PRIVATE fmt::fmt-header-only) if(USE_CUFILE) target_link_libraries(torch_cuda PRIVATE torch::cufile) @@ -1334,7 +1333,6 @@ if(USE_ROCM) ${ROCM_SOURCE_DIR}/rocblas/include ${ROCM_SOURCE_DIR}/hipsparse/include ) - target_link_libraries(torch_hip PRIVATE fmt::fmt-header-only) if(USE_FLASH_ATTENTION) target_compile_definitions(torch_hip PRIVATE USE_FLASH_ATTENTION) endif() diff --git a/test/cpp/c10d/CMakeLists.txt b/test/cpp/c10d/CMakeLists.txt index c64adb4ad0521e..0874852517e33b 100644 --- a/test/cpp/c10d/CMakeLists.txt +++ b/test/cpp/c10d/CMakeLists.txt @@ -10,14 +10,12 @@ function(c10d_add_test test_src) add_executable(${test_name} "${test_src}") target_include_directories(${test_name} PRIVATE $) target_link_libraries(${test_name} ${ARGN}) - target_link_libraries(${test_name} fmt::fmt-header-only) if(NOT WIN32) target_link_libraries(${test_name} pthread) endif() add_test(NAME ${test_name} COMMAND $) endfunction() -c10d_add_test(LoggingTest.cpp torch_cpu gtest_main) c10d_add_test(BackoffTest.cpp torch_cpu gtest_main) c10d_add_test(FileStoreTest.cpp torch_cpu gtest_main) c10d_add_test(TCPStoreTest.cpp torch_cpu gtest_main) diff --git a/test/cpp/c10d/LoggingTest.cpp b/test/cpp/c10d/LoggingTest.cpp deleted file mode 100644 index 56e4ad5d3f1df4..00000000000000 --- a/test/cpp/c10d/LoggingTest.cpp +++ /dev/null @@ -1,50 +0,0 @@ -#include - -#include -#include - -#include - -TEST(LockGuard, basic) { - std::timed_mutex mutex; - - { - C10D_LOCK_GUARD(lock, mutex); - - // already locked - ASSERT_FALSE(mutex.try_lock()); - } - - ASSERT_TRUE(mutex.try_lock()); - mutex.unlock(); -} - -TEST(LockGuard, logging) { - std::timed_mutex mutex; - - mutex.lock(); - - auto loggingThread = std::async(std::launch::async, [&]() { - std::unique_lock name{mutex, std::defer_lock}; - ::c10d::detail::lockWithLogging( - name, std::chrono::milliseconds(10), "my lock", __FILE__, __LINE__); - }); - - auto deadline = std::chrono::system_clock::now() + std::chrono::seconds(10); - while (true) { - ASSERT_LT(std::chrono::system_clock::now(), deadline); - - testing::internal::CaptureStderr(); - std::this_thread::sleep_for(std::chrono::milliseconds(20)); - std::string output = testing::internal::GetCapturedStderr(); - - if (output.find("my lock: waiting for lock for 10ms") != - std::string::npos) { - break; - } - } - - mutex.unlock(); - - loggingThread.get(); -} diff --git a/test/cpp/c10d/ProcessGroupNCCLErrorsTest.cpp b/test/cpp/c10d/ProcessGroupNCCLErrorsTest.cpp index 54a929ab982514..d416847f7911a5 100644 --- a/test/cpp/c10d/ProcessGroupNCCLErrorsTest.cpp +++ b/test/cpp/c10d/ProcessGroupNCCLErrorsTest.cpp @@ -180,7 +180,7 @@ class ProcessGroupNCCLNoHeartbeatCaught : ProcessGroupNCCLTimedOutErrors(store, rank, size, opts), hasMonitorThreadCaughtError_(false) {} - std::timed_mutex& getWatchdogMutex() { + std::mutex& getWatchdogMutex() { return workMetaListMutex_; } @@ -413,7 +413,7 @@ TEST_F(ProcessGroupNCCLErrorsTest, testNCCLErrorsNoHeartbeat) { work = pg.allreduce(tensors_); { // Now run all reduce with errors. - std::lock_guard lock(pg.getWatchdogMutex()); + std::lock_guard lock(pg.getWatchdogMutex()); LOG(INFO) << "Lock watchdog thread."; // Wait long enough before monitor thread throws exceptions. std::this_thread::sleep_for( diff --git a/torch/csrc/distributed/c10d/NCCLUtils.cpp b/torch/csrc/distributed/c10d/NCCLUtils.cpp index 46a2426fd5f4c0..16cb91015ae63d 100644 --- a/torch/csrc/distributed/c10d/NCCLUtils.cpp +++ b/torch/csrc/distributed/c10d/NCCLUtils.cpp @@ -21,7 +21,7 @@ constexpr int64_t kCommInitBusyWaitMillis = 10; namespace c10d { ncclComm_t NCCLComm::getNcclComm() { - C10D_LOCK_GUARD(lock, mutex_); + std::unique_lock lock(mutex_); if (aborted_) { auto commFailureMsg = commFailureReason_ != std::nullopt ? c10::str(" Original reason for failure was: ", *commFailureReason_) @@ -391,7 +391,7 @@ std::optional NCCLTraceBuffer::record( } auto traceback = torch::CapturedTraceback::gather(true, true, capture_cpp_stack_); - C10D_LOCK_GUARD(guard, mutex_); + std::lock_guard guard(mutex_); auto te = Entry{ id_, @@ -448,7 +448,7 @@ void NCCLTraceBuffer::record_pg_ranks( if (!enabled_) { return; } - C10D_LOCK_GUARD(guard, mutex_); + std::lock_guard guard(mutex_); pg_name_to_ranks_[pg_name] = ranks; } @@ -468,7 +468,7 @@ void NCCLTraceBuffer::update_state(Entry& r) { } std::vector NCCLTraceBuffer::dump_entries() { - C10D_LOCK_GUARD(guard, mutex_); + std::lock_guard guard(mutex_); std::vector result; result.reserve(entries_.size()); result.insert(result.end(), entries_.begin() + next_, entries_.end()); @@ -493,7 +493,7 @@ void NCCLTraceBuffer::retire_id( Event* endEvent = nullptr; std::optional duration = std::nullopt; - C10D_LOCK_GUARD(guard, mutex_); + std::unique_lock guard(mutex_); Entry* entry = &entries_.at(*id % max_entries_); if (entry->id_ == *id) { diff --git a/torch/csrc/distributed/c10d/NCCLUtils.hpp b/torch/csrc/distributed/c10d/NCCLUtils.hpp index dabd5a7d21429d..4b0a197dbd536a 100644 --- a/torch/csrc/distributed/c10d/NCCLUtils.hpp +++ b/torch/csrc/distributed/c10d/NCCLUtils.hpp @@ -14,7 +14,6 @@ #include #include #include -#include #include #if defined(NCCL_MAJOR) && (NCCL_MAJOR == 2) && defined(NCCL_MINOR) && \ @@ -271,7 +270,7 @@ class NCCLComm { ~NCCLComm() noexcept { // Add lock in this destructor, as aborted_ needs to be read after memory // barrier here. - C10D_LOCK_GUARD(lock, mutex_); + std::unique_lock lock(mutex_); if (ncclComm_ && initialized_ && !aborted_) { #ifdef ENABLE_NCCL_ERROR_CHECKING // Use ncclCommAbort instead of ncclCommDestroy here since @@ -363,7 +362,7 @@ class NCCLComm { NCCLComm(NCCLComm&& other) { // Using other's lock, as it reads other's states // Can not use this.mutex_, as this object is being constructed. - C10D_LOCK_GUARD(lock, other.mutex_); + std::unique_lock lock(other.mutex_); std::swap(ncclComm_, other.ncclComm_); std::swap(aborted_, other.aborted_); std::swap(ncclAsyncErr_, other.ncclAsyncErr_); @@ -373,13 +372,13 @@ class NCCLComm { ncclComm_t getNcclComm(); std::optional getNcclCommFailureReason() const { - C10D_LOCK_GUARD(lock, mutex_); + std::unique_lock lock(mutex_); return commFailureReason_; } void ncclCommAbort( std::optional commFailureReason = std::nullopt) { - C10D_LOCK_GUARD(lock, mutex_); + std::unique_lock lock(mutex_); #ifdef ENABLE_NCCL_ERROR_CHECKING if (aborted_ && !initialized_) { // Should not abort twice. @@ -427,7 +426,7 @@ class NCCLComm { } bool isAborted() const { - C10D_LOCK_GUARD(lock, mutex_); + std::unique_lock lock(mutex_); return aborted_; } @@ -436,7 +435,7 @@ class NCCLComm { } ncclResult_t checkForNcclError() { - C10D_LOCK_GUARD(lock, mutex_); + std::unique_lock lock(mutex_); #ifdef ENABLE_NCCL_ERROR_CHECKING if (ncclAsyncErr_ != ncclSuccess) { return ncclAsyncErr_; @@ -451,7 +450,7 @@ class NCCLComm { } ncclResult_t registerSegment(void* ptr, size_t size) { - C10D_LOCK_GUARD(lock, mutex_); + std::unique_lock lock(mutex_); #ifdef NCCL_HAS_COMM_REGISTER // We register only segments from cache allocator // which are guaranteed to be with disjoint addr ranges. Thus, a ptr always @@ -482,7 +481,7 @@ class NCCLComm { } ncclResult_t deregisterSegment(void* ptr) { - C10D_LOCK_GUARD(lock, mutex_); + std::unique_lock lock(mutex_); #ifdef NCCL_HAS_COMM_REGISTER TORCH_CHECK( registeredSegmentHandles_.count(ptr) == 1, @@ -519,7 +518,7 @@ class NCCLComm { bool aborted_; uint64_t ncclCommSplitCounter_{0}; ncclResult_t ncclAsyncErr_; - mutable std::timed_mutex mutex_; + mutable std::mutex mutex_; // Rank that this communicator corresponds to. int rank_; // Optional reason for communicator failure, provided by ProcessGroupNCCL for @@ -638,7 +637,7 @@ struct NCCLTraceBuffer { bool enabled_ = false; bool capture_cpp_stack_ = false; - std::timed_mutex mutex_; + std::mutex mutex_; std::vector entries_; size_t max_entries_ = 0; size_t next_ = 0; diff --git a/torch/csrc/distributed/c10d/ProcessGroupGloo.cpp b/torch/csrc/distributed/c10d/ProcessGroupGloo.cpp index 96ea93f7413b6c..51fa248ec403b9 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupGloo.cpp +++ b/torch/csrc/distributed/c10d/ProcessGroupGloo.cpp @@ -602,7 +602,7 @@ uint64_t ProcessGroupGloo::RecvWork::getSequencenumber() const { } int ProcessGroupGloo::RecvWork::sourceRank() const { - std::lock_guard lock(mutex_); + std::lock_guard lock(mutex_); return srcRank_; } diff --git a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp index 1e19fc44b4e620..3b308a360fb733 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp +++ b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp @@ -30,7 +30,6 @@ #include #include #include -#include #include #include @@ -302,7 +301,7 @@ inline void errorIfCapturingNonCapturableNCCL(c10::cuda::CaptureStatus status) { // hooks are called outside the scope of any PG, thus we need traverse // communicators in all PGs. static std::unordered_map, int> ncclCommDevIdxMap; -static std::timed_mutex ncclCommDevIdxMapMutex; +static std::mutex ncclCommDevIdxMapMutex; static bool allocatorHooksAttached = false; std::atomic ProcessGroupNCCL::shouldDump_(false); @@ -315,7 +314,7 @@ void cacheAllocatorRegisterHook( return; } - C10D_LOCK_GUARD(lock, ncclCommDevIdxMapMutex); + std::lock_guard lock(ncclCommDevIdxMapMutex); for (auto& it : ncclCommDevIdxMap) { auto& ncclComm = it.first; auto& devIdx = it.second; @@ -333,7 +332,7 @@ void cacheAllocatorDeregisterHook( return; } - C10D_LOCK_GUARD(lock, ncclCommDevIdxMapMutex); + std::lock_guard lock(ncclCommDevIdxMapMutex); for (auto& it : ncclCommDevIdxMap) { auto& ncclComm = it.first; auto& devIdx = it.second; @@ -552,7 +551,7 @@ void ProcessGroupNCCL::WorkNCCL::checkAndSetException() { } auto exception_ptr = checkForNCCLErrors(); - C10D_LOCK_GUARD(lock, mutex_); + std::unique_lock lock(mutex_); exception_ = exception_ptr; if (exception_) { LOG(ERROR) << logPrefix() << "Collective " << *this @@ -568,7 +567,7 @@ const std::string& ProcessGroupNCCL::WorkNCCL::logPrefix() const { void ProcessGroupNCCL::WorkNCCL::setException( std::exception_ptr exception_ptr) { - C10D_LOCK_GUARD(lock, mutex_); + std::unique_lock lock(mutex_); exception_ = exception_ptr; } @@ -777,12 +776,12 @@ ProcessGroupNCCL::CUDAEventCache::CUDAEventCache() {} std::shared_ptr ProcessGroupNCCL::CUDAEventCache::create( bool timing) { auto deleter = [this, timing](at::cuda::CUDAEvent* event) { - C10D_LOCK_GUARD(lock, this->cacheMutex_); + std::lock_guard lock(this->cacheMutex_); this->eventsArray_[timing ? 1 : 0].push_back(event); }; at::cuda::CUDAEvent* event = nullptr; { - C10D_LOCK_GUARD(lock, cacheMutex_); + std::lock_guard lock(cacheMutex_); auto events = eventsArray_[timing ? 1 : 0]; if (!events.empty()) { event = events.back(); @@ -1087,9 +1086,8 @@ void ProcessGroupNCCL::waitForPendingWorks() { while (true) { { std::lock(workMetaListMutex_, completedWorkListMutex_); - std::lock_guard lockWork( - workMetaListMutex_, std::adopt_lock); - std::lock_guard lockHook( + std::lock_guard lockWork(workMetaListMutex_, std::adopt_lock); + std::lock_guard lockHook( completedWorkListMutex_, std::adopt_lock); if (workMetaList_.empty() && completedWorkList_.empty()) { @@ -1205,7 +1203,7 @@ bool ProcessGroupNCCL::abort(std::optional abortReason) { } ncclCommDevIdxMapMutex.unlock(); - C10D_LOCK_GUARD(lock, mutex_); + std::lock_guard lock(mutex_); abortCommsFromMap(devNCCLCommMap_, abortReason); abortCommsFromMap(inInitializationCommMap_, abortReason); return true; @@ -1274,8 +1272,8 @@ bool ProcessGroupNCCL::dumpDebuggingInfo() { // Serialize all calls to this function to avoid corrupting data, but allow // multiple calls in one runtime. User is responsible for preserving the // output file from an earlier call before a later call overwrites it. - static std::timed_mutex writeDebugInfoMutex; - C10D_LOCK_GUARD(lock, writeDebugInfoMutex); + static std::mutex writeDebugInfoMutex; + std::lock_guard lock(writeDebugInfoMutex); LOG(ERROR) << logPrefix() << "ProcessGroupNCCL preparing to dump debug info."; if (ncclTraceBufferSize_ > 0) { // We dump nccl trace into local disk by default and users can register @@ -1354,7 +1352,7 @@ void ProcessGroupNCCL::heartbeatMonitor() { // This won't have any lock since this lock is only used here. // Please be aware that mutex `monitorMutex_` should not be used // somewhere else to avoid the deadlock. - C10D_LOCK_GUARD(lock, monitorMutex_); + std::unique_lock lock(monitorMutex_); if (monitorWakeUpCV_.wait_for( lock, std::chrono::milliseconds(monitorPollInterval), [&] { return terminateHeartbeatMonitorThread_.load(); @@ -1679,7 +1677,7 @@ const std::vector& ProcessGroupNCCL::groupRanks() const { void ProcessGroupNCCL::addEphemeralTimeout( const std::chrono::milliseconds& timeout) { - C10D_LOCK_GUARD(timeoutLock, mtxTimeoutExtension_); + std::lock_guard timeoutLock(mtxTimeoutExtension_); ephemeralTimeoutActive_ += timeout; } @@ -1702,7 +1700,7 @@ void ProcessGroupNCCL::watchdogHandler() { std::list completedWorkList; while (!done || !terminateProcessGroup_.load()) { - C10D_LOCK_GUARD(lock, workMetaListMutex_); + std::unique_lock lock(workMetaListMutex_); // We busy-poll the work vector every kWatchdogThreadSleepMillis // milliseconds as long as the atomic is True. workMetaListCV_.wait_for( @@ -1874,7 +1872,7 @@ void ProcessGroupNCCL::watchdogHandler() { if (work.isCompleted()) { { // Reset the timeout and first work if the work is completed. - C10D_LOCK_GUARD(timeoutLock, mtxTimeoutExtension_); + std::lock_guard timeoutLock(mtxTimeoutExtension_); if (work.ownedEphermeralTimeout_.count() > 0) { ephemeralTimeoutActive_ -= work.ownedEphermeralTimeout_; ephemeralTimeoutInflight_ -= work.ownedEphermeralTimeout_; @@ -1889,7 +1887,7 @@ void ProcessGroupNCCL::watchdogHandler() { // Move Work object to completedWorkList_ to be consumed by the hook // thread { - C10D_LOCK_GUARD(lock, completedWorkListMutex_); + const std::lock_guard lock(completedWorkListMutex_); completedWorkList_.splice( completedWorkList_.end(), workMetaList_, it++); } @@ -1917,7 +1915,7 @@ void ProcessGroupNCCL::runHookLoop() { bool done = false; while (!done || !terminateProcessGroup_.load()) { - C10D_LOCK_GUARD(lock, completedWorkListMutex_); + std::unique_lock lock(completedWorkListMutex_); // We busy-poll the work vector every kWatchdogThreadSleepMillis // milliseconds as long as the atomic is True. completedWorkListCV_.wait_for( @@ -2090,7 +2088,7 @@ void ProcessGroupNCCL::broadcastUniqueNCCLID( } void ProcessGroupNCCL::destroyNCCLComms(const std::string& devNCCLCommMapKey) { - C10D_LOCK_GUARD(lock, mutex_); + std::lock_guard lock(mutex_); if (devNCCLCommMap_.find(devNCCLCommMapKey) == devNCCLCommMap_.end()) { TORCH_INTERNAL_ASSERT( false, @@ -2138,7 +2136,7 @@ std::shared_ptr ProcessGroupNCCL::getNCCLComm( usedDeviceIdxs_.insert(device.index()); { - C10D_LOCK_GUARD(lock, mutex_); + std::lock_guard lock(mutex_); if (devNCCLCommMap_.find(deviceKey) != devNCCLCommMap_.end()) { // Reuse the cached communicator if there is one. return devNCCLCommMap_[deviceKey]; @@ -2214,7 +2212,7 @@ std::shared_ptr ProcessGroupNCCL::getNCCLComm( options_->split_color != 0, "Must specify a non-zero color when splitting"); // Find a valid, healthy communicator to split from if possible. - C10D_LOCK_GUARD(lock, options_->split_from->mutex_); + std::lock_guard lock(options_->split_from->mutex_); auto& other_comms = options_->split_from->devNCCLCommMap_; auto dit = other_comms.find(getKeyFromDevice(device)); if (dit != other_comms.end()) { @@ -2268,7 +2266,7 @@ std::shared_ptr ProcessGroupNCCL::getNCCLComm( options_->is_high_priority_stream || force_high); { - C10D_LOCK_GUARD(lock, mutex_); + std::lock_guard lock(mutex_); inInitializationCommMap_.emplace(deviceKey, ncclComm); } @@ -2518,7 +2516,7 @@ void ProcessGroupNCCL::assignTimeoutToWork( const c10::intrusive_ptr& work, const c10::intrusive_ptr& option) { std::chrono::milliseconds timeout = option->timeout; - C10D_LOCK_GUARD(timeoutLock, mtxTimeoutExtension_); + std::lock_guard timeoutLock(mtxTimeoutExtension_); if (ephemeralTimeoutActive_.count() > 0) { timeout += ephemeralTimeoutActive_; } @@ -2531,7 +2529,7 @@ void ProcessGroupNCCL::assignTimeoutToWork( void ProcessGroupNCCL::workEnqueue( c10::intrusive_ptr work) { if (!terminateProcessGroup_.load()) { - C10D_LOCK_GUARD(lock, workMetaListMutex_); + std::lock_guard lock(workMetaListMutex_); // Avoid view tensors to be processed in cleanup thread. // View tensors' destruction invokes autograd_meta, which // needs to be destructed in user thread. Otherwise will diff --git a/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp b/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp index 2ba68cda9cc3ea..28615055a72bb4 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp +++ b/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp @@ -449,7 +449,7 @@ class TORCH_API ProcessGroupNCCL : public Backend { static CUDAEventCache& get(); private: - std::timed_mutex cacheMutex_; + std::mutex cacheMutex_; // NOTE: We intentionaly store raw pointers so that // we do not attempt to destroy the event objects on process exit, // because cuda may be gone. @@ -918,7 +918,7 @@ class TORCH_API ProcessGroupNCCL : public Backend { // ephemeralTimeoutActive_/ephemeralTimeoutInflight_. // TODO(fduwjj): We need to have an audit on all mutexes we are adding here. // And consolidate them if possible. - std::timed_mutex mtxTimeoutExtension_; + std::mutex mtxTimeoutExtension_; // The ephemeral timeout added on top of existing timeout for works issued // before first work finishes. @@ -978,7 +978,7 @@ class TORCH_API ProcessGroupNCCL : public Backend { inInitializationCommMap_; // Mutex to guard maps like devNCCLCommMap_. - std::timed_mutex mutex_; + std::mutex mutex_; // Heartbeat of watchdog thread. std::atomic_uint64_t heartbeat_; @@ -1039,18 +1039,18 @@ class TORCH_API ProcessGroupNCCL : public Backend { static std::atomic shouldDump_; // Mutex to Guard workMetaList_ - std::timed_mutex workMetaListMutex_; + std::mutex workMetaListMutex_; // Mutex to Guard monitorWakeUpCV_ - std::timed_mutex monitorMutex_; + std::mutex monitorMutex_; bool writeDebugInfo_ = false; // Condition Variable for watchdog thread sleep - std::condition_variable_any workMetaListCV_; + std::condition_variable workMetaListCV_; // Condition Variable for monitor thread to wake up early - std::condition_variable_any monitorWakeUpCV_; + std::condition_variable monitorWakeUpCV_; // Vector to Store WorkNCCL pointers std::list workMetaList_; @@ -1058,10 +1058,10 @@ class TORCH_API ProcessGroupNCCL : public Backend { std::chrono::time_point lastWorkListUpdateTime_; // Mutex to Guard workMetaList_ - std::timed_mutex completedWorkListMutex_; + std::mutex completedWorkListMutex_; // Condition Variable for watchdog thread sleep - std::condition_variable_any completedWorkListCV_; + std::condition_variable completedWorkListCV_; std::list completedWorkList_; diff --git a/torch/csrc/distributed/c10d/Work.cpp b/torch/csrc/distributed/c10d/Work.cpp index 6a45392bc86231..8beb8f29362080 100644 --- a/torch/csrc/distributed/c10d/Work.cpp +++ b/torch/csrc/distributed/c10d/Work.cpp @@ -1,7 +1,6 @@ #include #include -#include #include namespace c10d { @@ -46,17 +45,17 @@ OpType Work::retrieveOpType() const { Work::~Work() = default; bool Work::isCompleted() { - C10D_LOCK_GUARD(lock, mutex_); + std::lock_guard lock(mutex_); return completed_; } bool Work::isSuccess() const { - C10D_LOCK_GUARD(lock, mutex_); + std::lock_guard lock(mutex_); return !exception_; } std::exception_ptr Work::exception() const { - C10D_LOCK_GUARD(lock, mutex_); + std::lock_guard lock(mutex_); return exception_; } @@ -74,7 +73,7 @@ std::vector Work::result() { void Work::synchronize() {} bool Work::wait(std::chrono::milliseconds timeout) { - C10D_LOCK_GUARD(lock, mutex_); + std::unique_lock lock(mutex_); if (timeout == kNoTimeout) { // This waits without a timeout. cv_.wait(lock, [&] { return completed_; }); @@ -104,7 +103,7 @@ c10::intrusive_ptr Work::getFuture() { } void Work::finish(std::exception_ptr exception) { - C10D_LOCK_GUARD(lock, mutex_); + std::unique_lock lock(mutex_); completed_ = true; exception_ = std::move(exception); if (recordFunctionEndCallback_) { @@ -116,7 +115,7 @@ void Work::finish(std::exception_ptr exception) { } void Work::finishAndThrow(std::exception_ptr exception) { - C10D_LOCK_GUARD(lock, mutex_); + std::unique_lock lock(mutex_); completed_ = true; exception_ = std::move(exception); if (recordFunctionEndCallback_) { diff --git a/torch/csrc/distributed/c10d/Work.hpp b/torch/csrc/distributed/c10d/Work.hpp index ea5853aefb4049..c10e5007b9f544 100644 --- a/torch/csrc/distributed/c10d/Work.hpp +++ b/torch/csrc/distributed/c10d/Work.hpp @@ -126,8 +126,8 @@ class TORCH_API Work : public torch::CustomClassHolder { // provided by the user. void finishAndThrow(std::exception_ptr exception); - mutable std::timed_mutex mutex_; - std::condition_variable_any cv_; + mutable std::mutex mutex_; + std::condition_variable cv_; bool completed_ = false; std::exception_ptr exception_; diff --git a/torch/csrc/distributed/c10d/logging.cpp b/torch/csrc/distributed/c10d/logging.cpp index 68087312fd68cf..5d05b5a3a5a88d 100644 --- a/torch/csrc/distributed/c10d/logging.cpp +++ b/torch/csrc/distributed/c10d/logging.cpp @@ -34,20 +34,4 @@ bool isLogLevelEnabled(LogLevel level) noexcept { return false; } -void lockWithLogging( - std::unique_lock& lock, - std::chrono::milliseconds log_interval, - c10::string_view desc, - c10::string_view file, - int line) { - while (!lock.try_lock_for(log_interval)) { - C10D_WARNING( - "{}:{} {}: waiting for lock for {}ms", - file, - line, - desc, - log_interval.count()); - } -} - } // namespace c10d::detail diff --git a/torch/csrc/distributed/c10d/logging.h b/torch/csrc/distributed/c10d/logging.h index 38679a3ac7d42e..b4df02e773807e 100644 --- a/torch/csrc/distributed/c10d/logging.h +++ b/torch/csrc/distributed/c10d/logging.h @@ -6,7 +6,6 @@ #pragma once -#include #include #include @@ -25,16 +24,6 @@ std::string formatLogMessage(fmt::string_view fmt, T&&... args) { return fmt::vformat(fmt, fmt::make_format_args(args...)); } -// logWithLogging is a wrapper around std::unique_lock -// that automatically logs if the lock cannot be acquired within a given -// timeout. -TORCH_API void lockWithLogging( - std::unique_lock& lock, - std::chrono::milliseconds log_interval, - c10::string_view desc, - c10::string_view file, - int line); - } // namespace detail } // namespace c10d @@ -60,9 +49,3 @@ TORCH_API void lockWithLogging( #define C10D_TRACE(...) \ LOG_IF(INFO, c10d::detail::isLogLevelEnabled(c10d::detail::LogLevel::Trace)) \ << "[c10d - trace] " << c10d::detail::formatLogMessage(__VA_ARGS__) - -// TODO: use std::source_location() when we can use C++20 -#define C10D_LOCK_GUARD(name, mutex) \ - std::unique_lock name{mutex, std::defer_lock}; \ - ::c10d::detail::lockWithLogging( \ - name, std::chrono::seconds(30), #mutex, __FILE__, __LINE__) From b3c6f7b4a858c6dba81920fe839a64ee9792ee61 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Mon, 26 Aug 2024 15:31:23 +0000 Subject: [PATCH 0013/1018] Revert "[inductor]Let output or input_as_strided match exact strides (#130956)" This reverts commit a63efee5cd422db0aabe5d02d2fe35fef9be7978. Reverted https://github.com/pytorch/pytorch/pull/130956 on behalf of https://github.com/ZainRizvi due to sorry but this seems to cause internal tests to fail. Please see D61771533 for details ([comment](https://github.com/pytorch/pytorch/pull/130956#issuecomment-2310490049)) --- test/inductor/test_torchinductor.py | 18 --- torch/_inductor/graph.py | 63 ++++------- torch/_inductor/ir.py | 164 ++++++---------------------- 3 files changed, 51 insertions(+), 194 deletions(-) diff --git a/test/inductor/test_torchinductor.py b/test/inductor/test_torchinductor.py index bc22dc3852a121..4db4fd0a2f8dfa 100644 --- a/test/inductor/test_torchinductor.py +++ b/test/inductor/test_torchinductor.py @@ -7257,24 +7257,6 @@ def fn_channels_last(x): [torch.randn(8, 384, 20, 20).to(memory_format=torch.channels_last)], ) - def test_exact_stride(self): - full = torch.randn((16, 16), device=self.device) - view = torch.as_strided(full, (16, 8), full.stride()) - - def fn(x): - result = x + x - result_strided = torch.empty_strided( - x.size(), x.stride(), device=self.device - ) - result_strided[:] = result - return result_strided - - self.common(fn, [view]) - reference_out = fn(view) - compiled_fn = torch.compile(fn) - actual_out = compiled_fn(view) - self.assertEqual(reference_out.stride(), actual_out.stride()) - def test_like_channels_last(self): def foo(): randn = torch.randn((4, 3, 8, 8), device=self.device, dtype=torch.float32) diff --git a/torch/_inductor/graph.py b/torch/_inductor/graph.py index 514c5b2f9b2b97..47bcb068486532 100644 --- a/torch/_inductor/graph.py +++ b/torch/_inductor/graph.py @@ -186,9 +186,7 @@ def getattr_recursive( return attr_itr -def mark_nodes_dislike_padding( - g: Graph, user_visible_outputs: Optional[Dict[str, None]] -) -> None: +def mark_nodes_dislike_padding(g: Graph) -> None: """ Nodes like convolution/convolution_backward want its input to be dense. If we pad their inputs, we result in extra calls to copy kernels! On the other hand, padding usually helps reduction. @@ -235,9 +233,7 @@ def _get_overload_packet( op = _get_overload_packet(cur) if not op: continue - if op in ops_dislike_padding or ( - user_visible_outputs and cur.name in user_visible_outputs - ): + if op in ops_dislike_padding: cur.meta["dislike_padding"] = True if cur.meta.get("dislike_padding", False): @@ -419,11 +415,11 @@ def __init__( self.nodes_prefer_channels_last = ( self.find_nodes_prefer_channels_last() if self.layout_opt else OrderedSet() ) + mark_nodes_dislike_padding(gm.graph) self._warned_fallback = {"aten.convolution_backward"} self.user_visible_outputs = ( user_visible_outputs if user_visible_outputs is not None else {} ) - mark_nodes_dislike_padding(gm.graph, user_visible_outputs) self.cache_key: str = "" # This is the cache key for the compiled artifact self.cache_path: str = "" # This is the path in the filesystem where the compiled artifact is stored self.cache_linemap: List[ @@ -1356,48 +1352,27 @@ def debug(msg: str) -> None: n.meta["val"], torch.Tensor ): strides = n.meta["val"].stride() - if len(strides): - allow_padding = ( - n.name not in self.user_visible_outputs - and not is_input_for_as_strided - ) - dense = torch._prims_common.is_non_overlapping_and_dense( - n.meta["val"] - ) - unbacked_symbols_in_strides = ( - len(free_unbacked_symbols(strides)) > 0 - ) + dense = torch._prims_common.is_non_overlapping_and_dense(n.meta["val"]) + unbacked_symbols_in_strides = len(free_unbacked_symbols(strides)) > 0 + # requiring a stride order for a non-dense output wouldn't + # recreate the same strides, and would fail with view, defer for now. + if not unbacked_symbols_in_strides and dense and len(strides): + stride_order = ir.get_stride_order(strides) if ( - not unbacked_symbols_in_strides - and dense - and len(result.get_size()) == 4 + len(result.get_size()) == 4 and n in self.nodes_prefer_channels_last and n.name not in self.user_visible_outputs and not is_input_for_as_strided ): - strides = ir.FlexibleLayout.stride_ordered_for_memory_format( - result.get_size(), torch.channels_last - ) - if not unbacked_symbols_in_strides and len(strides): - # To avoid converting possible view ops to a copy kernel, we use the previous - # require_exact_strides to handle views. But ultimately it's better to require - # the right strides at the tensor definition. - if n.meta["val"]._is_view() or isinstance( - result.data, ir.BaseView - ): - result = ir.ExternKernel.require_stride_order( - result, - ir.get_stride_order(strides), - allow_padding=allow_padding, - ) - else: - strides = [ - s.node.expr if isinstance(s, torch.SymInt) else s - for s in strides - ] - result = ir.ExternKernel.require_exact_strides( - result, strides, allow_padding=allow_padding - ) + stride_order = ir.NHWC_STRIDE_ORDER + + allow_padding = ( + n.name not in self.user_visible_outputs + and not is_input_for_as_strided + ) + result = ir.ExternKernel.require_stride_order( + result, stride_order, allow_padding=allow_padding + ) # Realize if (1) any user need inputs realized, or (2) there is # already too many reads and rematerializing can be bad. diff --git a/torch/_inductor/ir.py b/torch/_inductor/ir.py index 8f5cb5d3b06290..a5a50667b38009 100644 --- a/torch/_inductor/ir.py +++ b/torch/_inductor/ir.py @@ -745,22 +745,6 @@ def welford_combine_fn( raise NotImplementedError(f"unknown reduction_type={reduction_type}") -def significant_strides_equal( - strides1: Sequence[_IntLike], strides2: Sequence[_IntLike], size: Sequence[_IntLike] -) -> bool: - """ - Returns true if the strides are equal, ignoring dimensions of size 1 . - """ - non_1_indices = [ - i - for i, dim in enumerate(size) - if V.graph.sizevars.size_hint(dim, fallback=2) != 1 - ] - strides1 = [V.graph.sizevars.size_hint(strides1[i]) for i in non_1_indices] - strides2 = [V.graph.sizevars.size_hint(strides2[i]) for i in non_1_indices] - return strides1 == strides2 - - @dataclasses.dataclass class Reduction(Loops): reduction_ranges: List[Expr] @@ -2101,7 +2085,6 @@ def as_storage_and_layout( want_contiguous: bool = False, stride_order: Optional[Sequence[Union[int, Integer]]] = None, allow_padding: bool = False, - exact_strides: Optional[Sequence[Union[int, Integer]]] = None, ) -> Tuple[StorageBox, Layout]: """ Try to simplify x into a StorageBox and a Layout. @@ -2116,7 +2099,6 @@ def as_storage_and_layout( want_contiguous=want_contiguous, stride_order=stride_order, allow_padding=allow_padding, - exact_strides=exact_strides, ) if isinstance(x, StorageBox) and isinstance(x.data, Buffer): if freeze: @@ -2127,10 +2109,6 @@ def as_storage_and_layout( x.data.freeze_layout_with_stride_order( stride_order, allow_padding=allow_padding ) - elif exact_strides is not None: - x.data.freeze_layout_with_exact_strides( - exact_strides, allow_padding=allow_padding - ) else: x.data.decide_layout() return x, x.data.layout @@ -3224,19 +3202,6 @@ def as_stride_order(self, order, allow_padding=False): self.offset, ) - def as_exact_strides(self, exact_strides, allow_padding=False): - new_stride = exact_strides - if self.should_pad_strides() and allow_padding: - new_stride = self._pad_strides(new_stride, self.size, self.dtype) - - return FixedLayout( - self.device, - self.dtype, - self.size, - new_stride, - self.offset, - ) - def as_fill_order(self, order): new_stride = self.fill_ordered(self.size, order) if self.should_pad_strides(): @@ -3463,12 +3428,6 @@ def freeze_layout_with_same_order(self, stride): assert isinstance(self.layout, FlexibleLayout) self.layout = self.layout.as_same_order(stride) - def freeze_layout_with_exact_strides(self, exact_strides, allow_padding=False): - assert isinstance(self.layout, FlexibleLayout) - self.layout = self.layout.as_exact_strides( - exact_strides, allow_padding=allow_padding - ) - def is_zero_elements(self): return V.graph.sizevars.is_expr_static_and_true(sympy.Eq(self.get_numel(), 0)) # type: ignore[arg-type] @@ -4721,93 +4680,53 @@ def require_stride1(cls, x): return cls.copy_input(x) @classmethod - def require_strides( - cls, - x, - order: Optional[Sequence[int]] = None, - exact_strides: Optional[Sequence[_IntLike]] = None, - allow_padding=False, - ): - assert order is not None or exact_strides is not None + def require_stride_order(cls, x, order, allow_padding=False): if x.get_numel() == 0: # Layout doesn't matter return x - # require x to have the layout + + # require x to have the layout as strided_ordered as order if is_storage_and_layout(x): while isinstance(x.get_layout(), NonOwningLayout): x = x.get_layout().view if isinstance(x.get_layout(), FlexibleLayout): - if order: - # If the the FlexibleLayout already has the size and stride in the required order, - # freeze it to a FixedLayout by using its current size and stride. - # The behavior of using its current size and stride or the given order can be different - # if the size and stride has ambiguilty, for example for a 4D input where the iC = 1: - # size=[s0, 1, 28, 28], stride=[784, 784, 28, 1]. If the required order is [3, 0, 2, 1] (channels last), - # the current size and stride already satisfies this order. - # However by freezing it to the required order, the layout will be changed to: - # size=[s0, 1, 28, 28], stride=[784, 1, 28, 1]), which is not actually necessary. - - # fix flexiblelayout to be FixedLayout with stride_order - as_storage_and_layout( - x, - freeze=True, - want_contiguous=False, - stride_order=get_stride_order( - V.graph.sizevars.size_hints(x.get_layout().stride) - ) - if is_stride_order_storage_and_layout(x, order) - else order, - allow_padding=allow_padding, - ) - return x - else: - # If the exact_strides is given, freeze the FlexibleLayout to a FixedLayout with the exact_strides. - as_storage_and_layout( - x, - freeze=True, - want_contiguous=False, - stride_order=None, - allow_padding=allow_padding, - exact_strides=exact_strides, - ) - return x - elif isinstance(x.get_layout(), FixedLayout) and ( - (order and x.get_layout().is_stride_ordered(order)) - or ( - exact_strides - and significant_strides_equal( - exact_strides, x.get_layout().stride, x.get_size() + # If the the FlexibleLayout already has the size and stride in the required order, + # freeze it to a FixedLayout by using its current size and stride. + # The behavior of using its current size and stride or the given order can be different + # if the size and stride has ambiguilty, for example for a 4D input where the iC = 1: + # size=[s0, 1, 28, 28], stride=[784, 784, 28, 1]. If the required order is [3, 0, 2, 1] (channels last), + # the current size and stride already satisfies this order. + # However by freezing it to the required order, the layout will be changed to: + # size=[s0, 1, 28, 28], stride=[784, 1, 28, 1]), which is not actually necessary. + + # fix flexiblelayout to be FixedLayout with stride_order + as_storage_and_layout( + x, + freeze=True, + want_contiguous=False, + stride_order=get_stride_order( + V.graph.sizevars.size_hints(x.get_layout().stride) ) + if is_stride_order_storage_and_layout(x, order) + else order, + allow_padding=allow_padding, ) - ): + return x + elif isinstance( + x.get_layout(), FixedLayout + ) and x.get_layout().is_stride_ordered(order): return x elif isinstance(x.get_layout(), MutationLayoutSHOULDREMOVE): if isinstance(x.get_layout().real_layout(), FlexibleLayout): raise AssertionError( "the MutationLayoutSHOULDREMOVE's real layout shouldn't be FlexibleLayout" ) - elif isinstance(x.get_layout().real_layout(), FixedLayout) and ( - (order and x.get_layout().real_layout().is_stride_ordered(order)) - or ( - exact_strides - and significant_strides_equal( - exact_strides, - x.get_layout().real_layout().stride, - x.get_size(), - ) - ) - ): + elif isinstance( + x.get_layout().real_layout(), FixedLayout + ) and x.get_layout().real_layout().is_stride_ordered(order): return x # TODO - Storage to InputBuffer - if isinstance(x, InputBuffer) and ( - (order and x.get_layout().is_stride_ordered(order)) - or ( - exact_strides - and significant_strides_equal( - exact_strides, x.get_layout().stride, x.get_size() - ) - ) - ): + if isinstance(x, InputBuffer) and x.get_layout().is_stride_ordered(order): return x if ( isinstance(x, TensorBox) @@ -4818,14 +4737,7 @@ def require_strides( ): try: x.data = cls.convert_to_reinterpret_view(x.data) - if order: - return cls.require_stride_order( - x, order, allow_padding=allow_padding - ) - elif exact_strides: - return cls.require_exact_strides( - x, exact_strides, allow_padding=allow_padding - ) + return cls.require_stride_order(x, order, allow_padding=allow_padding) except NotImplementedError: pass # Although this is a clone, inductor is good about fusing clones into previous @@ -4837,22 +4749,10 @@ def require_strides( want_contiguous=False, stride_order=order, allow_padding=allow_padding, - exact_strides=exact_strides, ) - if order: - assert is_stride_order_storage_and_layout(x, order) + assert is_stride_order_storage_and_layout(x, order) return x - @classmethod - def require_exact_strides(cls, x, exact_strides, allow_padding=False): - return cls.require_strides( - x, exact_strides=exact_strides, allow_padding=allow_padding - ) - - @classmethod - def require_stride_order(cls, x, order, allow_padding=False): - return cls.require_strides(x, order=order, allow_padding=allow_padding) - @classmethod def require_channels_last(cls, x): return cls.require_stride_order(x, NHWC_STRIDE_ORDER) From 886dc258da84cf0cd4ea17b6e0d30c85a03e2411 Mon Sep 17 00:00:00 2001 From: atalman Date: Mon, 26 Aug 2024 15:32:43 +0000 Subject: [PATCH 0014/1018] Add linux arm64 ephemeral runners (#134469) Should be landed with: https://github.com/pytorch/test-infra/pull/5593 Pull Request resolved: https://github.com/pytorch/pytorch/pull/134469 Approved by: https://github.com/jeanschmidt, https://github.com/clee2000 --- .github/lf-canary-scale-config.yml | 24 ++++++++++++++++++++++++ .github/lf-scale-config.yml | 24 ++++++++++++++++++++++++ 2 files changed, 48 insertions(+) diff --git a/.github/lf-canary-scale-config.yml b/.github/lf-canary-scale-config.yml index aaa0e21c92ef12..b7f7c7d9cd3efd 100644 --- a/.github/lf-canary-scale-config.yml +++ b/.github/lf-canary-scale-config.yml @@ -289,6 +289,30 @@ runner_types: ami: al2023-ami-2023.5.20240701.0-kernel-6.1-arm64 am2: ami: amzn2-ami-hvm-2.0.20240306.2-arm64-gp2 + lf.c.linux.arm64.2xlarge.ephemeral: + disk_size: 256 + instance_type: t4g.2xlarge + is_ephemeral: true + max_available: 200 + os: linux + ami: al2023-ami-2023.5.20240701.0-kernel-6.1-arm64 + variants: + amz2023: + ami: al2023-ami-2023.5.20240701.0-kernel-6.1-arm64 + am2: + ami: amzn2-ami-hvm-2.0.20240306.2-arm64-gp2 + lf.c.linux.arm64.m7g.4xlarge.ephemeral: + disk_size: 256 + instance_type: m7g.4xlarge + is_ephemeral: true + max_available: 200 + os: linux + ami: al2023-ami-2023.5.20240701.0-kernel-6.1-arm64 + variants: + amz2023: + ami: al2023-ami-2023.5.20240701.0-kernel-6.1-arm64 + am2: + ami: amzn2-ami-hvm-2.0.20240306.2-arm64-gp2 lf.c.linux.arm64.m7g.metal: disk_size: 256 instance_type: m7g.metal diff --git a/.github/lf-scale-config.yml b/.github/lf-scale-config.yml index 338dc4c534bb22..a0b6e30375be02 100644 --- a/.github/lf-scale-config.yml +++ b/.github/lf-scale-config.yml @@ -289,6 +289,30 @@ runner_types: ami: al2023-ami-2023.5.20240701.0-kernel-6.1-arm64 am2: ami: amzn2-ami-hvm-2.0.20240306.2-arm64-gp2 + lf.linux.arm64.2xlarge.ephemeral: + disk_size: 256 + instance_type: t4g.2xlarge + is_ephemeral: true + max_available: 200 + os: linux + ami: al2023-ami-2023.5.20240701.0-kernel-6.1-arm64 + variants: + amz2023: + ami: al2023-ami-2023.5.20240701.0-kernel-6.1-arm64 + am2: + ami: amzn2-ami-hvm-2.0.20240306.2-arm64-gp2 + lf.linux.arm64.m7g.4xlarge.ephemeral: + disk_size: 256 + instance_type: m7g.4xlarge + is_ephemeral: true + max_available: 200 + os: linux + ami: al2023-ami-2023.5.20240701.0-kernel-6.1-arm64 + variants: + amz2023: + ami: al2023-ami-2023.5.20240701.0-kernel-6.1-arm64 + am2: + ami: amzn2-ami-hvm-2.0.20240306.2-arm64-gp2 lf.linux.arm64.m7g.metal: disk_size: 256 instance_type: m7g.metal From 23054fb3995cf1b1ffaa880065f3d0d08a05d715 Mon Sep 17 00:00:00 2001 From: "Wang, Chuanqi" Date: Mon, 26 Aug 2024 15:35:48 +0000 Subject: [PATCH 0015/1018] [CD] fix xpu nightly wheel test env (#134395) (#134464) Due to the https://github.com/pytorch/builder/pull/1972 landed, it will source xpu env duplicated in nightly wheel test. Works for https://github.com/pytorch/pytorch/issues/114850 Realnd of #134395 to be landed with pytorchmergebot Pull Request resolved: https://github.com/pytorch/pytorch/pull/134464 Approved by: https://github.com/jeanschmidt Co-authored-by: Wang, Chuanqi --- .circleci/scripts/binary_linux_test.sh | 6 ------ 1 file changed, 6 deletions(-) diff --git a/.circleci/scripts/binary_linux_test.sh b/.circleci/scripts/binary_linux_test.sh index 5d92c9099bff99..c30814982ef9e6 100755 --- a/.circleci/scripts/binary_linux_test.sh +++ b/.circleci/scripts/binary_linux_test.sh @@ -116,12 +116,6 @@ if [[ "$PACKAGE_TYPE" == libtorch ]]; then cd /tmp/libtorch fi -if [[ "$GPU_ARCH_TYPE" == xpu ]]; then - # Refer https://www.intel.com/content/www/us/en/developer/articles/tool/pytorch-prerequisites-for-intel-gpu/2-5.html - source /opt/intel/oneapi/pytorch-gpu-dev-0.5/oneapi-vars.sh - source /opt/intel/oneapi/pti/latest/env/vars.sh -fi - # Test the package /builder/check_binary.sh From fa9aaf9d8f48ce00be6471fc06c6d91a8675d615 Mon Sep 17 00:00:00 2001 From: atalman Date: Mon, 26 Aug 2024 16:33:19 +0000 Subject: [PATCH 0016/1018] Use ephemeral runners for windows nightly builds (#134463) This is definition of windows.4xlarge: ``` windows.4xlarge: disk_size: 256 instance_type: c5d.4xlarge is_ephemeral: true max_available: 420 os: windows ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/134463 Approved by: https://github.com/jeanschmidt --- .github/actionlint.yaml | 1 + .../windows_binary_build_workflow.yml.j2 | 4 +++ ...generated-windows-binary-conda-nightly.yml | 32 +++++++++---------- ...-windows-binary-libtorch-debug-nightly.yml | 8 ++--- ...indows-binary-libtorch-release-nightly.yml | 8 ++--- ...generated-windows-binary-wheel-nightly.yml | 32 +++++++++---------- 6 files changed, 45 insertions(+), 40 deletions(-) diff --git a/.github/actionlint.yaml b/.github/actionlint.yaml index 3408ca3b1c3b88..919c4b4db7b7ed 100644 --- a/.github/actionlint.yaml +++ b/.github/actionlint.yaml @@ -60,6 +60,7 @@ self-hosted-runner: # Organization wide AWS Windows runners - windows.g4dn.xlarge - windows.g4dn.xlarge.nonephemeral + - windows.4xlarge - windows.4xlarge.nonephemeral - windows.8xlarge.nvidia.gpu - windows.8xlarge.nvidia.gpu.nonephemeral diff --git a/.github/templates/windows_binary_build_workflow.yml.j2 b/.github/templates/windows_binary_build_workflow.yml.j2 index 06646f449fc64c..3b087fb4f6ead6 100644 --- a/.github/templates/windows_binary_build_workflow.yml.j2 +++ b/.github/templates/windows_binary_build_workflow.yml.j2 @@ -66,7 +66,11 @@ jobs: !{{ config["build_name"] }}-build: if: ${{ github.repository_owner == 'pytorch' }} needs: get-label-type + {%- if branches == "nightly" %} + runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" + {%- else %} runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge.nonephemeral" + {%- endif %} timeout-minutes: !{{ common.timeout_minutes }} !{{ upload.binary_env(config, True) }} {%- if config.pytorch_extra_install_requirements is defined and config.pytorch_extra_install_requirements|d('')|length > 0 %} diff --git a/.github/workflows/generated-windows-binary-conda-nightly.yml b/.github/workflows/generated-windows-binary-conda-nightly.yml index dac8a32e14a648..8cc1b90e8e8e53 100644 --- a/.github/workflows/generated-windows-binary-conda-nightly.yml +++ b/.github/workflows/generated-windows-binary-conda-nightly.yml @@ -43,7 +43,7 @@ jobs: conda-py3_9-cpu-build: if: ${{ github.repository_owner == 'pytorch' }} needs: get-label-type - runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge.nonephemeral" + runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" timeout-minutes: 240 env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch @@ -288,7 +288,7 @@ jobs: conda-py3_9-cuda11_8-build: if: ${{ github.repository_owner == 'pytorch' }} needs: get-label-type - runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge.nonephemeral" + runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" timeout-minutes: 240 env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch @@ -536,7 +536,7 @@ jobs: conda-py3_9-cuda12_1-build: if: ${{ github.repository_owner == 'pytorch' }} needs: get-label-type - runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge.nonephemeral" + runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" timeout-minutes: 240 env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch @@ -784,7 +784,7 @@ jobs: conda-py3_9-cuda12_4-build: if: ${{ github.repository_owner == 'pytorch' }} needs: get-label-type - runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge.nonephemeral" + runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" timeout-minutes: 240 env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch @@ -1032,7 +1032,7 @@ jobs: conda-py3_10-cpu-build: if: ${{ github.repository_owner == 'pytorch' }} needs: get-label-type - runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge.nonephemeral" + runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" timeout-minutes: 240 env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch @@ -1277,7 +1277,7 @@ jobs: conda-py3_10-cuda11_8-build: if: ${{ github.repository_owner == 'pytorch' }} needs: get-label-type - runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge.nonephemeral" + runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" timeout-minutes: 240 env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch @@ -1525,7 +1525,7 @@ jobs: conda-py3_10-cuda12_1-build: if: ${{ github.repository_owner == 'pytorch' }} needs: get-label-type - runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge.nonephemeral" + runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" timeout-minutes: 240 env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch @@ -1773,7 +1773,7 @@ jobs: conda-py3_10-cuda12_4-build: if: ${{ github.repository_owner == 'pytorch' }} needs: get-label-type - runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge.nonephemeral" + runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" timeout-minutes: 240 env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch @@ -2021,7 +2021,7 @@ jobs: conda-py3_11-cpu-build: if: ${{ github.repository_owner == 'pytorch' }} needs: get-label-type - runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge.nonephemeral" + runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" timeout-minutes: 240 env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch @@ -2266,7 +2266,7 @@ jobs: conda-py3_11-cuda11_8-build: if: ${{ github.repository_owner == 'pytorch' }} needs: get-label-type - runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge.nonephemeral" + runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" timeout-minutes: 240 env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch @@ -2514,7 +2514,7 @@ jobs: conda-py3_11-cuda12_1-build: if: ${{ github.repository_owner == 'pytorch' }} needs: get-label-type - runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge.nonephemeral" + runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" timeout-minutes: 240 env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch @@ -2762,7 +2762,7 @@ jobs: conda-py3_11-cuda12_4-build: if: ${{ github.repository_owner == 'pytorch' }} needs: get-label-type - runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge.nonephemeral" + runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" timeout-minutes: 240 env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch @@ -3010,7 +3010,7 @@ jobs: conda-py3_12-cpu-build: if: ${{ github.repository_owner == 'pytorch' }} needs: get-label-type - runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge.nonephemeral" + runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" timeout-minutes: 240 env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch @@ -3255,7 +3255,7 @@ jobs: conda-py3_12-cuda11_8-build: if: ${{ github.repository_owner == 'pytorch' }} needs: get-label-type - runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge.nonephemeral" + runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" timeout-minutes: 240 env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch @@ -3503,7 +3503,7 @@ jobs: conda-py3_12-cuda12_1-build: if: ${{ github.repository_owner == 'pytorch' }} needs: get-label-type - runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge.nonephemeral" + runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" timeout-minutes: 240 env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch @@ -3751,7 +3751,7 @@ jobs: conda-py3_12-cuda12_4-build: if: ${{ github.repository_owner == 'pytorch' }} needs: get-label-type - runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge.nonephemeral" + runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" timeout-minutes: 240 env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch diff --git a/.github/workflows/generated-windows-binary-libtorch-debug-nightly.yml b/.github/workflows/generated-windows-binary-libtorch-debug-nightly.yml index 8f1576d39565ad..e59a82e63cb24a 100644 --- a/.github/workflows/generated-windows-binary-libtorch-debug-nightly.yml +++ b/.github/workflows/generated-windows-binary-libtorch-debug-nightly.yml @@ -43,7 +43,7 @@ jobs: libtorch-cpu-shared-with-deps-debug-build: if: ${{ github.repository_owner == 'pytorch' }} needs: get-label-type - runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge.nonephemeral" + runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" timeout-minutes: 240 env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch @@ -300,7 +300,7 @@ jobs: libtorch-cuda11_8-shared-with-deps-debug-build: if: ${{ github.repository_owner == 'pytorch' }} needs: get-label-type - runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge.nonephemeral" + runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" timeout-minutes: 240 env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch @@ -560,7 +560,7 @@ jobs: libtorch-cuda12_1-shared-with-deps-debug-build: if: ${{ github.repository_owner == 'pytorch' }} needs: get-label-type - runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge.nonephemeral" + runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" timeout-minutes: 240 env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch @@ -820,7 +820,7 @@ jobs: libtorch-cuda12_4-shared-with-deps-debug-build: if: ${{ github.repository_owner == 'pytorch' }} needs: get-label-type - runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge.nonephemeral" + runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" timeout-minutes: 240 env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch diff --git a/.github/workflows/generated-windows-binary-libtorch-release-nightly.yml b/.github/workflows/generated-windows-binary-libtorch-release-nightly.yml index c2d1b7009fff8d..b41d421322e286 100644 --- a/.github/workflows/generated-windows-binary-libtorch-release-nightly.yml +++ b/.github/workflows/generated-windows-binary-libtorch-release-nightly.yml @@ -43,7 +43,7 @@ jobs: libtorch-cpu-shared-with-deps-release-build: if: ${{ github.repository_owner == 'pytorch' }} needs: get-label-type - runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge.nonephemeral" + runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" timeout-minutes: 240 env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch @@ -300,7 +300,7 @@ jobs: libtorch-cuda11_8-shared-with-deps-release-build: if: ${{ github.repository_owner == 'pytorch' }} needs: get-label-type - runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge.nonephemeral" + runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" timeout-minutes: 240 env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch @@ -560,7 +560,7 @@ jobs: libtorch-cuda12_1-shared-with-deps-release-build: if: ${{ github.repository_owner == 'pytorch' }} needs: get-label-type - runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge.nonephemeral" + runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" timeout-minutes: 240 env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch @@ -820,7 +820,7 @@ jobs: libtorch-cuda12_4-shared-with-deps-release-build: if: ${{ github.repository_owner == 'pytorch' }} needs: get-label-type - runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge.nonephemeral" + runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" timeout-minutes: 240 env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch diff --git a/.github/workflows/generated-windows-binary-wheel-nightly.yml b/.github/workflows/generated-windows-binary-wheel-nightly.yml index 7088e10b5c197b..287558aa22ae23 100644 --- a/.github/workflows/generated-windows-binary-wheel-nightly.yml +++ b/.github/workflows/generated-windows-binary-wheel-nightly.yml @@ -43,7 +43,7 @@ jobs: wheel-py3_9-cpu-build: if: ${{ github.repository_owner == 'pytorch' }} needs: get-label-type - runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge.nonephemeral" + runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" timeout-minutes: 240 env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch @@ -289,7 +289,7 @@ jobs: wheel-py3_9-cuda11_8-build: if: ${{ github.repository_owner == 'pytorch' }} needs: get-label-type - runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge.nonephemeral" + runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" timeout-minutes: 240 env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch @@ -538,7 +538,7 @@ jobs: wheel-py3_9-cuda12_1-build: if: ${{ github.repository_owner == 'pytorch' }} needs: get-label-type - runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge.nonephemeral" + runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" timeout-minutes: 240 env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch @@ -787,7 +787,7 @@ jobs: wheel-py3_9-cuda12_4-build: if: ${{ github.repository_owner == 'pytorch' }} needs: get-label-type - runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge.nonephemeral" + runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" timeout-minutes: 240 env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch @@ -1036,7 +1036,7 @@ jobs: wheel-py3_10-cpu-build: if: ${{ github.repository_owner == 'pytorch' }} needs: get-label-type - runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge.nonephemeral" + runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" timeout-minutes: 240 env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch @@ -1282,7 +1282,7 @@ jobs: wheel-py3_10-cuda11_8-build: if: ${{ github.repository_owner == 'pytorch' }} needs: get-label-type - runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge.nonephemeral" + runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" timeout-minutes: 240 env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch @@ -1531,7 +1531,7 @@ jobs: wheel-py3_10-cuda12_1-build: if: ${{ github.repository_owner == 'pytorch' }} needs: get-label-type - runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge.nonephemeral" + runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" timeout-minutes: 240 env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch @@ -1780,7 +1780,7 @@ jobs: wheel-py3_10-cuda12_4-build: if: ${{ github.repository_owner == 'pytorch' }} needs: get-label-type - runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge.nonephemeral" + runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" timeout-minutes: 240 env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch @@ -2029,7 +2029,7 @@ jobs: wheel-py3_11-cpu-build: if: ${{ github.repository_owner == 'pytorch' }} needs: get-label-type - runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge.nonephemeral" + runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" timeout-minutes: 240 env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch @@ -2275,7 +2275,7 @@ jobs: wheel-py3_11-cuda11_8-build: if: ${{ github.repository_owner == 'pytorch' }} needs: get-label-type - runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge.nonephemeral" + runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" timeout-minutes: 240 env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch @@ -2524,7 +2524,7 @@ jobs: wheel-py3_11-cuda12_1-build: if: ${{ github.repository_owner == 'pytorch' }} needs: get-label-type - runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge.nonephemeral" + runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" timeout-minutes: 240 env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch @@ -2773,7 +2773,7 @@ jobs: wheel-py3_11-cuda12_4-build: if: ${{ github.repository_owner == 'pytorch' }} needs: get-label-type - runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge.nonephemeral" + runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" timeout-minutes: 240 env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch @@ -3022,7 +3022,7 @@ jobs: wheel-py3_12-cpu-build: if: ${{ github.repository_owner == 'pytorch' }} needs: get-label-type - runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge.nonephemeral" + runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" timeout-minutes: 240 env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch @@ -3268,7 +3268,7 @@ jobs: wheel-py3_12-cuda11_8-build: if: ${{ github.repository_owner == 'pytorch' }} needs: get-label-type - runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge.nonephemeral" + runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" timeout-minutes: 240 env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch @@ -3517,7 +3517,7 @@ jobs: wheel-py3_12-cuda12_1-build: if: ${{ github.repository_owner == 'pytorch' }} needs: get-label-type - runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge.nonephemeral" + runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" timeout-minutes: 240 env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch @@ -3766,7 +3766,7 @@ jobs: wheel-py3_12-cuda12_4-build: if: ${{ github.repository_owner == 'pytorch' }} needs: get-label-type - runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge.nonephemeral" + runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" timeout-minutes: 240 env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch From 69d79c12c157298a6e33c1fa106c1a2e2c6fd668 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Mon, 26 Aug 2024 16:57:53 +0000 Subject: [PATCH 0017/1018] Revert "[dynamo][guards] De-dupe DUPLICATE_INPUT guard (#134354)" This reverts commit cdb9df5efe78142b7a612ae9c938ddf8a8850d10. Reverted https://github.com/pytorch/pytorch/pull/134354 on behalf of https://github.com/ZainRizvi due to Fails internal tests ([comment](https://github.com/pytorch/pytorch/pull/134272#issuecomment-2310649115)) --- torch/_dynamo/guards.py | 9 --------- 1 file changed, 9 deletions(-) diff --git a/torch/_dynamo/guards.py b/torch/_dynamo/guards.py index d9cc7eb6d44d16..d61ec7f188c98e 100644 --- a/torch/_dynamo/guards.py +++ b/torch/_dynamo/guards.py @@ -575,8 +575,6 @@ def __init__( str, torch._C._dynamo.guards.GuardManager ] = {} - self._cached_duplicate_input_guards: Set[Tuple[str, str]] = set() - def guard_on_dict_keys_and_ignore_order(self, example_value, guard): dict_mgr = self.get_guard_manager(guard) if isinstance(dict_mgr, DictGuardManager): @@ -1665,13 +1663,6 @@ def DUPLICATE_INPUT(self, guard, source_b): self._set_guard_export_info(guard, code) if config.enable_cpp_guard_manager: - # Check that the guard has not been inserted already - key = (ref_a, ref_b) - if key in self._cached_duplicate_input_guards: - return - self._cached_duplicate_input_guards.add((ref_a, ref_b)) - self._cached_duplicate_input_guards.add((ref_b, ref_a)) - install_object_aliasing_guard( self.get_guard_manager(guard), self.get_guard_manager_from_source(source_b), From 6d22c8314fe3504d33ad0ba0fbc78549f85ae890 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Mon, 26 Aug 2024 16:57:53 +0000 Subject: [PATCH 0018/1018] Revert "[dynamo] Cache _dynamo.disable results (#134272)" This reverts commit a699bd11551e9755bb9238c6b82c369880789397. Reverted https://github.com/pytorch/pytorch/pull/134272 on behalf of https://github.com/ZainRizvi due to Fails internal tests ([comment](https://github.com/pytorch/pytorch/pull/134272#issuecomment-2310649115)) --- test/dynamo/test_decorators.py | 14 -------------- torch/_dynamo/__init__.py | 1 - torch/_dynamo/convert_frame.py | 1 - torch/_dynamo/decorators.py | 9 +-------- 4 files changed, 1 insertion(+), 24 deletions(-) diff --git a/test/dynamo/test_decorators.py b/test/dynamo/test_decorators.py index 463634e3216358..b915e0633f7c11 100644 --- a/test/dynamo/test_decorators.py +++ b/test/dynamo/test_decorators.py @@ -184,20 +184,6 @@ def hook(module, args): all(node.target is not torch.sigmoid for node in gm1.graph.nodes) ) - def test_disable_no_recompile(self): - def gn(x): - return torch.cos(x) - - @torch.compile(backend="eager") - def fn(x): - x = torch.sin(x) - x = torch._dynamo.disable(gn, recursive=True)(x) - return torch.sin(x) - - with torch._dynamo.config.patch(error_on_recompile=True): - for _ in range(5): - fn(torch.randn(4)) - def test_allow_in_graph(self): cnts = torch._dynamo.testing.CompileCounter() diff --git a/torch/_dynamo/__init__.py b/torch/_dynamo/__init__.py index 00286a32750d1e..7f58ba7f7bf7f1 100644 --- a/torch/_dynamo/__init__.py +++ b/torch/_dynamo/__init__.py @@ -107,4 +107,3 @@ def reset_code_caches() -> None: if code: reset_code(code) code_context.clear() - convert_frame.disabled_codes.clear() diff --git a/torch/_dynamo/convert_frame.py b/torch/_dynamo/convert_frame.py index 2f0da2be866db1..82ddaca2fa2ebb 100644 --- a/torch/_dynamo/convert_frame.py +++ b/torch/_dynamo/convert_frame.py @@ -160,7 +160,6 @@ def clear(self) -> None: input_codes = Tracker() output_codes = Tracker() -disabled_codes: Dict[int, Callable[..., Any]] = {} initial_global_state: Optional[GlobalStateGuard] = None diff --git a/torch/_dynamo/decorators.py b/torch/_dynamo/decorators.py index 6b557d278a65cf..c83d8d718a62a4 100644 --- a/torch/_dynamo/decorators.py +++ b/torch/_dynamo/decorators.py @@ -10,7 +10,6 @@ from . import trace_rules, variables from .comptime import comptime -from .convert_frame import disabled_codes from .eval_frame import DisableContext, innermost_fn, RunOnlyContext from .exc import IncorrectUsage from .external_utils import is_compiling @@ -56,15 +55,9 @@ def disable(fn=None, recursive=True): """ if recursive: if fn is not None: - id_fn = id(fn) - if cached_fn := disabled_codes.get(id_fn): - return cached_fn - fn = innermost_fn(fn) assert callable(fn) - out = DisableContext()(fn) - disabled_codes[id_fn] = out - return out + return DisableContext()(fn) return DisableContext() else: return skip(fn) From 374d088c77945fa3a2cea2a9f376c4fa1ac6d7c2 Mon Sep 17 00:00:00 2001 From: Catherine Lee Date: Mon, 26 Aug 2024 16:58:04 +0000 Subject: [PATCH 0019/1018] Remove color in CI (#133517) Remove color by default to make CI logs easier to read Example of color image Pull Request resolved: https://github.com/pytorch/pytorch/pull/133517 Approved by: https://github.com/ZainRizvi --- pytest.ini | 2 -- 1 file changed, 2 deletions(-) diff --git a/pytest.ini b/pytest.ini index d3b9f3a9229843..e2ab2ebd0cc164 100644 --- a/pytest.ini +++ b/pytest.ini @@ -4,8 +4,6 @@ addopts = -rEfX # Make tracebacks shorter --tb=native - # Color the output - --color=yes # capture only Python print and C++ py::print, but not C output (low-level Python errors) --capture=sys # don't suppress warnings, but don't shove them all to the end either From 0218ce9bca2b9ae2b2e1d77d06f59d23d1da65c7 Mon Sep 17 00:00:00 2001 From: mori360 Date: Mon, 26 Aug 2024 17:24:25 +0000 Subject: [PATCH 0020/1018] Memory optimization for DSD for TorchTune LoRA (#134025) Optimize memory cost at [PR#129635](https://github.com/pytorch/pytorch/pull/129635) There are 2 main part of the optimization here: 1. optimize the tensor distributing part, postpone the full_tensor generation, which avoids the memory overlap, saves around 50% peak memory at 2 param test case. 2. apply `assign=True` for the `load_state_dict`, saves memory cost at the state dict loading by assigning the input param, around 50% peak memory at loading part. Future work: Memory optimization to the opt will be conducted in the next PR Pull Request resolved: https://github.com/pytorch/pytorch/pull/134025 Approved by: https://github.com/fegin Co-authored-by: Rachel Guo --- torch/distributed/_state_dict_utils.py | 11 ++++++----- torch/distributed/checkpoint/state_dict.py | 5 +++-- 2 files changed, 9 insertions(+), 7 deletions(-) diff --git a/torch/distributed/_state_dict_utils.py b/torch/distributed/_state_dict_utils.py index 340cb43e295e8d..b2694f69d720d0 100644 --- a/torch/distributed/_state_dict_utils.py +++ b/torch/distributed/_state_dict_utils.py @@ -627,16 +627,17 @@ def _distribute_state_dict( local_state_dict[key] = value.cpu() else: assert isinstance(value, torch.Tensor) - full_tensor = value.detach().to(device) local_state = local_state_dict.get(key, None) if local_state is None: continue elif isinstance(local_state, DTensor): - local_state_dict[key] = (local_state, full_tensor) + local_state_dict[key] = distribute_tensor( + value.detach().to(device), + local_state.device_mesh, + local_state.placements, + ) else: - local_state_dict[key] = full_tensor - - _distribute_tensors(local_state_dict, [key], device, pg) + local_state_dict[key] = value.detach().to(device) # These APIs are from torch.distributed.checkpoint. diff --git a/torch/distributed/checkpoint/state_dict.py b/torch/distributed/checkpoint/state_dict.py index f7a7ea799ef03e..cf20ecbfb213a3 100644 --- a/torch/distributed/checkpoint/state_dict.py +++ b/torch/distributed/checkpoint/state_dict.py @@ -552,6 +552,7 @@ def _load_model_state_dict( state_dict[fqn_with_prefix] = state_dict.pop(fqn) local_state_dict[fqn_with_prefix] = value + assign = False if info.broadcast_from_rank0 or info.full_state_dict: device = None for key, value in local_state_dict.items(): @@ -563,7 +564,7 @@ def _load_model_state_dict( assert device is not None if device == torch.device("meta"): device = dist.distributed_c10d._get_pg_default_device() - model.to_empty(device=device) + assign = True if info.broadcast_from_rank0: _broadcast_state_dict( state_dict, local_state_dict, device=device, strict=info.strict @@ -577,7 +578,7 @@ def _load_model_state_dict( return cast( _IncompatibleKeys, _state_dict_fn(model, "load_state_dict")( - state_dict=state_dict, strict=info.strict + state_dict=state_dict, strict=info.strict, assign=assign ), ) From 56ae663c5128883752c919e862d79fa9a9189b39 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Mon, 26 Aug 2024 17:31:10 +0000 Subject: [PATCH 0021/1018] Revert "Allow mp.start_processes to create processes in parallel (#133707)" This reverts commit 3546628a2a167ace6060737eeccf8ee8fd87ddc0. Reverted https://github.com/pytorch/pytorch/pull/133707 on behalf of https://github.com/ZainRizvi due to sorry but trunk has been consistently broken since this PR was merged. See: [GH job link](https://github.com/pytorch/pytorch/actions/runs/10529617600/job/29191757055) [HUD commit link](https://hud.pytorch.org/pytorch/pytorch/commit/3546628a2a167ace6060737eeccf8ee8fd87ddc0) ([comment](https://github.com/pytorch/pytorch/pull/133707#issuecomment-2310709523)) --- .../elastic/multiprocessing/api_test.py | 285 ++++++++---------- test/test_multiprocessing_spawn.py | 107 +------ torch/multiprocessing/__init__.py | 1 - torch/multiprocessing/spawn.py | 58 +--- 4 files changed, 142 insertions(+), 309 deletions(-) diff --git a/test/distributed/elastic/multiprocessing/api_test.py b/test/distributed/elastic/multiprocessing/api_test.py index 9f55785810ca86..75e903807ff9b2 100644 --- a/test/distributed/elastic/multiprocessing/api_test.py +++ b/test/distributed/elastic/multiprocessing/api_test.py @@ -226,65 +226,36 @@ def start_processes_zombie_test( pc.close(e.sigval) -class _StartProcessesTest(TestCase): - def setUp(self): - super().setUp() - self.test_dir = tempfile.mkdtemp(prefix=f"{self.__class__.__name__}_") - self._start_methods = ["spawn"] - - def tearDown(self): - super().tearDown() - shutil.rmtree(self.test_dir) +# tests incompatible with tsan or asan +if not (TEST_WITH_DEV_DBG_ASAN or IS_WINDOWS or IS_MACOS): - def log_dir(self): - return tempfile.mkdtemp(dir=self.test_dir) + class StartProcessesTest(TestCase): + def setUp(self): + super().setUp() + self.test_dir = tempfile.mkdtemp(prefix=f"{self.__class__.__name__}_") + self._start_methods = ["spawn"] - def assert_in_file(self, expected: List[str], filename: str) -> None: - expected = [f"{line.rstrip()}\n" for line in expected] - with open(filename) as fp: - actual = fp.readlines() - for line in expected: - self.assertIn(line, actual) + def tearDown(self): + super().tearDown() + shutil.rmtree(self.test_dir) - def assert_pids_noexist(self, pids: Dict[int, int]): - for local_rank, pid in pids.items(): - with self.assertRaises( - OSError, msg=f"local_rank: {local_rank} pid: {pid} should not exist" - ): - os.kill(pid, 0) - - def _test_zombie_workflow( - self, entrypoint: Union[str, Callable], signal_to_send: signal.Signals - ) -> None: - mp_queue = mp.get_context("spawn").Queue() - child_nproc = 2 - ctx = mp.spawn( - start_processes_zombie_test, - nprocs=1, - args=(entrypoint, mp_queue, self.log_dir(), child_nproc), - join=False, - ) - total_processes = child_nproc + 1 - pids = [] - for _ in range(total_processes): - pids.append(mp_queue.get(timeout=120)) - parent_pid = pids[0] - child_pids = pids[1:] - - os.kill(parent_pid, signal.SIGTERM) - # Wait to give time for signal handlers to finish work - time.sleep(5) - for child_pid in child_pids: - # Killing parent should kill all children, we expect that each call to - # os.kill would raise OSError - with self.assertRaises(OSError): - os.kill(child_pid, 0) + def log_dir(self): + return tempfile.mkdtemp(dir=self.test_dir) + def assert_in_file(self, expected: List[str], filename: str) -> None: + expected = [f"{line.rstrip()}\n" for line in expected] + with open(filename) as fp: + actual = fp.readlines() + for line in expected: + self.assertIn(line, actual) -# tests incompatible with tsan or asan -if not (TEST_WITH_DEV_DBG_ASAN or IS_WINDOWS or IS_MACOS): + def assert_pids_noexist(self, pids: Dict[int, int]): + for local_rank, pid in pids.items(): + with self.assertRaises( + OSError, msg=f"local_rank: {local_rank} pid: {pid} should not exist" + ): + os.kill(pid, 0) - class StartProcessesAsFuncTest(_StartProcessesTest): def test_to_map(self): local_world_size = 2 self.assertEqual( @@ -378,13 +349,26 @@ def test_multiprocess_context_close(self): self.assertTrue(pc._stderr_tail.stopped()) self.assertTrue(pc._stdout_tail.stopped()) + def test_subprocess_context_close(self): + pc = start_processes( + name="sleep", + entrypoint=bin("zombie_test.py"), + args={0: (1,)}, + envs={0: {}}, + logs_specs=DefaultLogsSpecs(log_dir=self.log_dir()), + ) + + pids = pc.pids() + pc.close() + self.assert_pids_noexist(pids) + def test_function_with_tensor(self): for start_method in self._start_methods: pc = start_processes( name="dummy_compute", entrypoint=dummy_compute, - args={0: ()}, - envs={0: {}}, + args={}, + envs={}, logs_specs=DefaultLogsSpecs(log_dir=self.log_dir()), start_method=start_method, ) @@ -505,55 +489,10 @@ def test_wait_for_all_child_procs_to_exit(self): mpc._poll() self.assertEqual(4, mock_join.call_count) - @skip_but_pass_in_sandcastle_if( - NO_MULTIPROCESSING_SPAWN, - "Disabled for environments that \ - don't support multiprocessing with spawn start method", - ) - def test_multiprocessing_context_poll_raises_exception(self): - mp_context = MultiprocessContext( - name="test_mp", - entrypoint=echo0, - args={0: (0, 1)}, - envs={0: {}}, - logs_specs=DefaultLogsSpecs( - log_dir=self.log_dir(), redirects=Std.ALL, tee=Std.ALL - ), - start_method="spawn", - ) - mp_context._pc = mock.Mock() - # Using mock since we cannot just set exitcode on process - mock_process = mock.Mock() - mock_process.exitcode = -1 - mp_context._pc.processes = [mock_process] - e = mp.ProcessRaisedException(msg="test msg", error_index=0, error_pid=123) - mp_context._pc.join.side_effect = e - with mock.patch.object(mp_context, "close"): - run_result = mp_context._poll() - self.assertEqual(1, len(run_result.failures)) - failure = run_result.failures[0] - self.assertEqual( - "Signal 1 (SIGHUP) received by PID 123", failure.message - ) - - class StartProcessesAsBinaryTest(_StartProcessesTest): ######################################## # start_processes as binary tests ######################################## - def test_subprocess_context_close(self): - pc = start_processes( - name="sleep", - entrypoint=bin("zombie_test.py"), - args={0: (1,)}, - envs={0: {}}, - logs_specs=DefaultLogsSpecs(log_dir=self.log_dir()), - ) - - pids = pc.pids() - pc.close() - self.assert_pids_noexist(pids) - def test_binary_exit(self): FAIL = 138 pc = start_processes( @@ -617,11 +556,45 @@ def test_validate_full_rank(self): with self.assertRaises(RuntimeError): _validate_full_rank({}, 10, "") + @skip_but_pass_in_sandcastle_if( + NO_MULTIPROCESSING_SPAWN, + "Disabled for environments that \ + don't support multiprocessing with spawn start method", + ) + def test_multiprocessing_context_poll_raises_exception(self): + mp_context = MultiprocessContext( + name="test_mp", + entrypoint=echo0, + args={0: (0, 1)}, + envs={0: {}}, + logs_specs=DefaultLogsSpecs( + log_dir=self.log_dir(), redirects=Std.ALL, tee=Std.ALL + ), + start_method="spawn", + ) + mp_context._pc = mock.Mock() + # Using mock since we cannot just set exitcode on process + mock_process = mock.Mock() + mock_process.exitcode = -1 + mp_context._pc.processes = [mock_process] + e = mp.ProcessRaisedException(msg="test msg", error_index=0, error_pid=123) + mp_context._pc.join.side_effect = e + with mock.patch.object(mp_context, "close"): + run_result = mp_context._poll() + self.assertEqual(1, len(run_result.failures)) + failure = run_result.failures[0] + self.assertEqual( + "Signal 1 (SIGHUP) received by PID 123", failure.message + ) + # tests incompatible with tsan or asan, the redirect functionality does not work on macos or windows if not (TEST_WITH_DEV_DBG_ASAN or IS_WINDOWS or IS_MACOS): - class StartProcessesListAsFuncTest(_StartProcessesTest): + class StartProcessesListTest(StartProcessesTest): + ######################################## + # start_processes as binary tests + ######################################## def test_function(self): for start_method, redirs in product( self._start_methods, redirects_oss_test() @@ -661,10 +634,6 @@ def test_function(self): [f"hello stderr from {i}"], results.stderrs[i] ) - class StartProcessesListAsBinaryTest(_StartProcessesTest): - ######################################## - # start_processes as binary tests - ######################################## def test_binary(self): for redirs in redirects_oss_test(): with self.subTest(redirs=redirs): @@ -731,7 +700,7 @@ def test_binary_redirect_and_tee(self): # tests incompatible with tsan or asan, the redirect functionality does not work on macos or windows if not (TEST_WITH_DEV_DBG_ASAN or IS_WINDOWS or IS_MACOS or IS_CI): - class StartProcessesNotCIAsFuncTest(_StartProcessesTest): + class StartProcessesNotCITest(StartProcessesTest): @skip_if_pytest def test_wrap_bad(self): none = "" @@ -764,6 +733,32 @@ def test_wrap_bad(self): self.assert_in_file(["hello stderr from 0"], stderr_log) worker_finished_event_mock.wait.assert_called_once() + def test_binary_signal(self): + pc = start_processes( + name="echo", + entrypoint=bin("echo3.py"), + args={0: ("--segfault", "true", "foo"), 1: ("bar",)}, + envs={0: {"RANK": "0"}, 1: {"RANK": "1"}}, + logs_specs=DefaultLogsSpecs( + log_dir=self.log_dir(), + ), + ) + + results = pc.wait(period=0.1) + + self.assert_pids_noexist(pc.pids()) + self.assertTrue(results.is_failed()) + self.assertEqual(1, len(results.failures)) + + failure = results.failures[0] + self.assertNotEqual(signal.SIGSEGV, failure.exitcode) + if TEST_WITH_ASAN or TEST_WITH_TSAN: + # ASAN/TSAN exit code is 1. + self.assertEqual("", failure.signal_name()) + else: + self.assertEqual("SIGSEGV", failure.signal_name()) + self.assertEqual("", failure.error_file_data["message"]) + def test_function_redirect_and_tee(self): for start_method in self._start_methods: with self.subTest(start_method=start_method): @@ -876,60 +871,42 @@ def test_function_exit(self): self.assertTrue(pc._stderr_tail.stopped()) self.assertTrue(pc._stdout_tail.stopped()) - def test_no_zombie_process_function(self): - signals = [signal.SIGTERM, signal.SIGINT, signal.SIGHUP, signal.SIGQUIT] - for s in signals: - self._test_zombie_workflow(wait_fn, s) - - class StartProcessesNotCIAsBinaryTest(_StartProcessesTest): - def test_binary_signal(self): - pc = start_processes( - name="echo", - entrypoint=bin("echo3.py"), - args={0: ("--segfault", "true", "foo"), 1: ("bar",)}, - envs={0: {"RANK": "0"}, 1: {"RANK": "1"}}, - logs_specs=DefaultLogsSpecs( - log_dir=self.log_dir(), - ), - ) - - results = pc.wait(period=0.1) - - self.assert_pids_noexist(pc.pids()) - self.assertTrue(results.is_failed()) - self.assertEqual(1, len(results.failures)) - - failure = results.failures[0] - self.assertNotEqual(signal.SIGSEGV, failure.exitcode) - if TEST_WITH_ASAN or TEST_WITH_TSAN: - # ASAN/TSAN exit code is 1. - self.assertEqual("", failure.signal_name()) - else: - self.assertEqual("SIGSEGV", failure.signal_name()) - self.assertEqual("", failure.error_file_data["message"]) - def test_no_zombie_process_binary(self): signals = [signal.SIGTERM, signal.SIGINT, signal.SIGHUP, signal.SIGQUIT] for s in signals: self._test_zombie_workflow(bin("zombie_test.py"), s) - class ForkServerTest( - StartProcessesAsFuncTest, - StartProcessesListAsFuncTest, - StartProcessesNotCIAsFuncTest, - ): - def setUp(self): - super().setUp() - self._start_methods = ["forkserver"] - self.orig_paralell_env_val = os.environ.get(mp.ENV_VAR_PARALLEL_START) - os.environ[mp.ENV_VAR_PARALLEL_START] = "1" + def test_no_zombie_process_function(self): + signals = [signal.SIGTERM, signal.SIGINT, signal.SIGHUP, signal.SIGQUIT] + for s in signals: + self._test_zombie_workflow(wait_fn, s) - def tearDown(self): - super().tearDown() - if self.orig_paralell_env_val is None: - del os.environ[mp.ENV_VAR_PARALLEL_START] - else: - os.environ[mp.ENV_VAR_PARALLEL_START] = self.orig_paralell_env_val + def _test_zombie_workflow( + self, entrypoint: Union[str, Callable], signal_to_send: signal.Signals + ) -> None: + mp_queue = mp.get_context("spawn").Queue() + child_nproc = 2 + ctx = mp.spawn( + start_processes_zombie_test, + nprocs=1, + args=(entrypoint, mp_queue, self.log_dir(), child_nproc), + join=False, + ) + total_processes = child_nproc + 1 + pids = [] + for _ in range(total_processes): + pids.append(mp_queue.get(timeout=120)) + parent_pid = pids[0] + child_pids = pids[1:] + + os.kill(parent_pid, signal.SIGTERM) + # Wait to give time for signal handlers to finish work + time.sleep(5) + for child_pid in child_pids: + # Killing parent should kill all children, we expect that each call to + # os.kill would raise OSError + with self.assertRaises(OSError): + os.kill(child_pid, 0) if __name__ == "__main__": diff --git a/test/test_multiprocessing_spawn.py b/test/test_multiprocessing_spawn.py index df29e07b2227b2..5160056a87f7d0 100644 --- a/test/test_multiprocessing_spawn.py +++ b/test/test_multiprocessing_spawn.py @@ -8,14 +8,9 @@ import time import unittest +from torch.testing._internal.common_utils import (TestCase, run_tests, IS_WINDOWS, NO_MULTIPROCESSING_SPAWN) import torch.multiprocessing as mp -from torch.testing._internal.common_utils import ( - IS_WINDOWS, - NO_MULTIPROCESSING_SPAWN, - run_tests, - TestCase, -) def _test_success_func(i): pass @@ -92,7 +87,7 @@ def _test_nested(i, pids_queue, nested_child_sleep, start_method): # Kill self. This should take down the child processes as well. os.kill(os.getpid(), signal.SIGTERM) -class _TestMultiProcessing(TestCase): +class _TestMultiProcessing: start_method = None def test_success(self): @@ -194,11 +189,10 @@ def _test_nested(self): self.assertLess(time.time() - start, nested_child_sleep / 2) time.sleep(0.1) - @unittest.skipIf( NO_MULTIPROCESSING_SPAWN, "Disabled for environments that don't support the spawn start method") -class SpawnTest(_TestMultiProcessing): +class SpawnTest(TestCase, _TestMultiProcessing): start_method = 'spawn' def test_exception_raises(self): @@ -222,103 +216,10 @@ def _test_process_exited(self): IS_WINDOWS, "Fork is only available on Unix", ) -class ForkTest(_TestMultiProcessing): +class ForkTest(TestCase, _TestMultiProcessing): start_method = 'fork' -@unittest.skipIf( - IS_WINDOWS, - "Fork is only available on Unix", -) -class ForkServerTest(_TestMultiProcessing): - start_method = 'forkserver' - - -class _ParallelTest: - orig_paralell_env_val = None - - def setUp(self): - super().setUp() - self.orig_paralell_env_val = os.environ.get(mp.ENV_VAR_PARALLEL_START) - os.environ[mp.ENV_VAR_PARALLEL_START] = "1" - - def tearDown(self): - super().tearDown() - if self.orig_paralell_env_val is None: - del os.environ[mp.ENV_VAR_PARALLEL_START] - else: - os.environ[mp.ENV_VAR_PARALLEL_START] = self.orig_paralell_env_val - - -@unittest.skipIf( - NO_MULTIPROCESSING_SPAWN, - "Disabled for environments that don't support the spawn start method") -class ParallelSpawnShouldFallbackAndWorkTest(SpawnTest, _ParallelTest): - pass - - -@unittest.skipIf( - IS_WINDOWS, - "Fork is only available on Unix", -) -class ParallelForkShouldFallbackAndWorkTest(ForkTest, _ParallelTest): - pass - - -@unittest.skipIf( - IS_WINDOWS, - "Fork is only available on Unix", -) -class ParallelForkServerShouldWorkTest(ForkServerTest, _ParallelTest): - pass - - -@unittest.skipIf( - IS_WINDOWS, - "Fork is only available on Unix", -) -class ParallelForkServerPerfTest(TestCase): - def test_forkserver_perf(self): - start_method = 'forkserver' - expensive = Expensive() - nprocs = 6 - orig_paralell_env_val = os.environ.get(mp.ENV_VAR_PARALLEL_START) - - # test the non parallel case - os.environ[mp.ENV_VAR_PARALLEL_START] = "0" - start = time.perf_counter() - mp.start_processes(expensive.my_call, nprocs=nprocs, start_method=start_method) - elapsed = time.perf_counter() - start - # the time should be at least 6x the sleep time - self.assertGreaterEqual(elapsed, Expensive.SLEEP_SECS * nprocs) - - # test the parallel case - os.environ[mp.ENV_VAR_PARALLEL_START] = "1" - start = time.perf_counter() - mp.start_processes(expensive.my_call, nprocs=nprocs, start_method=start_method) - elapsed = time.perf_counter() - start - - # the time should be at most 1x the sleep time + small overhead - self.assertLess(elapsed, Expensive.SLEEP_SECS + 10) - - if orig_paralell_env_val is None: - del os.environ[mp.ENV_VAR_PARALLEL_START] - else: - os.environ[mp.ENV_VAR_PARALLEL_START] = orig_paralell_env_val - - -class Expensive: - SLEEP_SECS = 10 - # Simulate startup overhead such as large imports - time.sleep(SLEEP_SECS) - - def __init__(self): - self.config: str = "*" * 1000000 - - def my_call(self, *args): - pass - - class ErrorTest(TestCase): def test_errors_pickleable(self): for error in ( diff --git a/torch/multiprocessing/__init__.py b/torch/multiprocessing/__init__.py index 745c180d8c415c..f9ea0ad13bdafa 100644 --- a/torch/multiprocessing/__init__.py +++ b/torch/multiprocessing/__init__.py @@ -39,7 +39,6 @@ """Add helper function to spawn N processes and wait for completion of any of them. This depends `mp.get_context` which was added in Python 3.4.""" from .spawn import ( - ENV_VAR_PARALLEL_START, ProcessContext, ProcessExitedException, ProcessRaisedException, diff --git a/torch/multiprocessing/spawn.py b/torch/multiprocessing/spawn.py index 756e9aa670c9ca..aae2e338f7b0a0 100644 --- a/torch/multiprocessing/spawn.py +++ b/torch/multiprocessing/spawn.py @@ -9,26 +9,13 @@ import tempfile import time import warnings -from concurrent.futures import as_completed, ThreadPoolExecutor from typing import Optional from . import _prctl_pr_set_pdeathsig # type: ignore[attr-defined] -ENV_VAR_PARALLEL_START = "TORCH_MP_PARALLEL_START" - log = logging.getLogger(__name__) -__all__ = [ - "ProcessContext", - "ProcessException", - "ProcessExitedException", - "ProcessRaisedException", - "spawn", - "SpawnContext", - "start_processes", -] - class ProcessException(Exception): __slots__ = ["error_index", "error_pid"] @@ -218,31 +205,12 @@ def __init__(self, processes, error_files): # Currently we only add this API first, we can consider adding it to documentation as # needed in the future. def start_processes( - fn, - args=(), - nprocs=1, - join=True, - daemon=False, - start_method="spawn", + fn, args=(), nprocs=1, join=True, daemon=False, start_method="spawn" ): - # To speed up performance in certain cases (see https://github.com/pytorch/pytorch/issues/133010), - # this func will start processes in parallel if start_method is 'forkserver'. - # Please opt in to this perf optimization by setting env var (TORCH_MP_PARALLEL_START) to 1. - # todo: investigate why spawn does not work with threadpool and raises SIGINT - if ( - start_method == "forkserver" - and os.environ.get(ENV_VAR_PARALLEL_START, "0") == "1" - ): - start_parallel = True - else: - # Set env var TORCH_MP_PARALLEL_START to 0 to disable parallel start - start_parallel = False - mp = multiprocessing.get_context(start_method) - error_files = [None] * nprocs - processes = [None] * nprocs - - def start_process(i): + error_files = [] + processes = [] + for i in range(nprocs): # Each process is assigned a file to write tracebacks to. We # use the file being non-empty to indicate an exception # occurred (vs an expected shutdown). Note: this previously @@ -260,21 +228,9 @@ def start_process(i): daemon=daemon, ) process.start() - return i, process, tf.name - - if not start_parallel: - for i in range(nprocs): - idx, process, tf_name = start_process(i) - error_files[idx] = tf_name - processes[idx] = process - else: - with ThreadPoolExecutor(max_workers=nprocs) as executor: - futures = [executor.submit(start_process, i) for i in range(nprocs)] - for fut in as_completed(futures): - idx, process, tf_name = fut.result() - # idx and process rank needs to be the same. - error_files[idx] = tf_name - processes[idx] = process + error_files.append(tf.name) + processes.append(process) + context = ProcessContext(processes, error_files) if not join: return context From 0886787c30de8578b0f712ca9df22ff8978ce088 Mon Sep 17 00:00:00 2001 From: Xu Han Date: Mon, 26 Aug 2024 17:38:53 +0000 Subject: [PATCH 0022/1018] [inductor] calibration inductor windows uts (8/N) (#134424) enable Windows inductor UTs of `test/inductor/test_benchmark_fusion.py` Local test pass: image Pull Request resolved: https://github.com/pytorch/pytorch/pull/134424 Approved by: https://github.com/ezyang --- test/inductor/test_benchmark_fusion.py | 20 ++------------------ 1 file changed, 2 insertions(+), 18 deletions(-) diff --git a/test/inductor/test_benchmark_fusion.py b/test/inductor/test_benchmark_fusion.py index 9011e11613cd01..f6d7daba75d2c9 100644 --- a/test/inductor/test_benchmark_fusion.py +++ b/test/inductor/test_benchmark_fusion.py @@ -7,12 +7,7 @@ from torch._inductor.test_case import TestCase as InductorTestCase from torch._inductor.utils import fresh_inductor_cache, is_big_gpu, run_and_get_code from torch.testing import FileCheck -from torch.testing._internal.common_utils import ( - IS_CI, - IS_WINDOWS, - slowTest, - TEST_WITH_ASAN, -) +from torch.testing._internal.common_utils import slowTest, TEST_WITH_ASAN from torch.testing._internal.inductor_utils import HAS_CPU, HAS_CUDA @@ -23,22 +18,11 @@ import contextlib import unittest +from inductor.test_torchinductor import check_model, check_model_cuda, copy_tests from torch._inductor import config from torch._inductor.scheduler import Scheduler -if IS_WINDOWS and IS_CI: - sys.stderr.write( - "Windows CI does not have necessary dependencies for test_torchinductor yet\n" - ) - if __name__ == "__main__": - sys.exit(0) - raise unittest.SkipTest("requires sympy/functorch/filelock") - - -from inductor.test_torchinductor import check_model, check_model_cuda, copy_tests - - class TestCase(InductorTestCase): @classmethod def setUpClass(cls): From d10da3a201b769025b2e005ce1f05d0182af124c Mon Sep 17 00:00:00 2001 From: Xu Han Date: Mon, 26 Aug 2024 17:43:34 +0000 Subject: [PATCH 0023/1018] [inductor] calibration inductor windows uts (10/N) (#134426) enable Windows inductor UT of `test/inductor/test_efficient_conv_bn_eval.py` Local test pass: image Pull Request resolved: https://github.com/pytorch/pytorch/pull/134426 Approved by: https://github.com/jansel --- test/inductor/test_efficient_conv_bn_eval.py | 11 +---------- 1 file changed, 1 insertion(+), 10 deletions(-) diff --git a/test/inductor/test_efficient_conv_bn_eval.py b/test/inductor/test_efficient_conv_bn_eval.py index 55633677661753..c186416827276b 100644 --- a/test/inductor/test_efficient_conv_bn_eval.py +++ b/test/inductor/test_efficient_conv_bn_eval.py @@ -4,7 +4,6 @@ import itertools import os import sys -import unittest import torch from torch import nn @@ -17,18 +16,10 @@ from torch._dynamo.utils import counters from torch._inductor import config as inductor_config from torch._inductor.test_case import TestCase -from torch.testing._internal.common_utils import IS_CI, IS_WINDOWS, TEST_WITH_ASAN +from torch.testing._internal.common_utils import TEST_WITH_ASAN from torch.testing._internal.inductor_utils import HAS_CPU, HAS_CUDA -if IS_WINDOWS and IS_CI: - sys.stderr.write( - "Windows CI does not have necessary dependencies for test_torchinductor yet\n" - ) - if __name__ == "__main__": - sys.exit(0) - raise unittest.SkipTest("requires sympy/functorch/filelock") - importlib.import_module("functorch") importlib.import_module("filelock") From d449730926614282eb9323990742be1b6aca4791 Mon Sep 17 00:00:00 2001 From: cyy Date: Mon, 26 Aug 2024 17:43:50 +0000 Subject: [PATCH 0024/1018] [20/N] Fix clang-tidy warnings in jit (#133399) Follows #133067 Pull Request resolved: https://github.com/pytorch/pytorch/pull/133399 Approved by: https://github.com/Skylion007 --- torch/csrc/jit/codegen/fuser/arg_spec.h | 8 +- torch/csrc/jit/codegen/fuser/codegen.cpp | 12 +-- torch/csrc/jit/codegen/fuser/codegen.h | 8 +- torch/csrc/jit/codegen/fuser/compiler.cpp | 8 +- torch/csrc/jit/codegen/fuser/compiler.h | 8 +- .../jit/codegen/fuser/cuda/fused_kernel.cpp | 14 +-- .../jit/codegen/fuser/cuda/fused_kernel.h | 10 +-- .../jit/codegen/fuser/cuda/resource_strings.h | 10 +-- torch/csrc/jit/codegen/fuser/executor.cpp | 15 +--- torch/csrc/jit/codegen/fuser/executor.h | 8 +- torch/csrc/jit/codegen/fuser/fallback.cpp | 8 +- torch/csrc/jit/codegen/fuser/fallback.h | 8 +- torch/csrc/jit/codegen/fuser/fused_kernel.h | 9 +- torch/csrc/jit/codegen/fuser/interface.cpp | 6 +- torch/csrc/jit/codegen/fuser/interface.h | 6 +- torch/csrc/jit/codegen/fuser/kernel_cache.cpp | 10 +-- torch/csrc/jit/codegen/fuser/kernel_cache.h | 11 +-- torch/csrc/jit/codegen/fuser/kernel_spec.h | 24 +++--- torch/csrc/jit/codegen/fuser/partition_desc.h | 9 +- torch/csrc/jit/codegen/fuser/tensor_desc.h | 14 +-- torch/csrc/jit/codegen/fuser/tensor_info.h | 11 +-- torch/csrc/jit/cuda/cuda.h | 7 +- .../csrc/jit/frontend/script_type_parser.cpp | 13 +-- torch/csrc/jit/frontend/tree_views.h | 6 +- torch/csrc/jit/ir/alias_analysis.cpp | 5 +- torch/csrc/jit/ir/ir.h | 4 +- torch/csrc/jit/jit_log.cpp | 1 - torch/csrc/jit/mobile/import_data.cpp | 3 +- torch/csrc/jit/mobile/promoted_prim_ops.cpp | 12 +-- .../onnx/cast_all_constant_to_floating.cpp | 3 +- torch/csrc/jit/passes/onnx/peephole.cpp | 3 +- .../passes/onnx/unpack_quantized_weights.cpp | 11 +-- .../passes/quantization/insert_observers.cpp | 6 +- torch/csrc/jit/runtime/argument_spec.h | 12 +-- torch/csrc/jit/runtime/autodiff.cpp | 6 +- .../csrc/jit/runtime/interpreter/code_impl.h | 3 +- torch/csrc/jit/runtime/register_prim_ops.cpp | 86 ++++++------------- .../csrc/jit/runtime/register_special_ops.cpp | 25 ++---- torch/csrc/jit/runtime/static/fusion.cpp | 3 +- torch/csrc/jit/runtime/static/impl.cpp | 3 +- torch/csrc/jit/runtime/symbolic_script.cpp | 8 +- .../csrc/jit/tensorexpr/operators/conv2d.cpp | 14 +-- 42 files changed, 133 insertions(+), 318 deletions(-) diff --git a/torch/csrc/jit/codegen/fuser/arg_spec.h b/torch/csrc/jit/codegen/fuser/arg_spec.h index 06a72c231b84a5..7239e0391b8fce 100644 --- a/torch/csrc/jit/codegen/fuser/arg_spec.h +++ b/torch/csrc/jit/codegen/fuser/arg_spec.h @@ -8,9 +8,7 @@ #include #include -namespace torch { -namespace jit { -namespace fuser { +namespace torch::jit::fuser { // Describes the (runtime) arguments to a kernel. // ArgSpecs are also used as keys to lookup instantiated kernels, so @@ -55,6 +53,4 @@ struct TORCH_API ArgSpec { int device_; }; -} // namespace fuser -} // namespace jit -} // namespace torch +} // namespace torch::jit::fuser diff --git a/torch/csrc/jit/codegen/fuser/codegen.cpp b/torch/csrc/jit/codegen/fuser/codegen.cpp index 940444f4ce7edb..04700b905f41bd 100644 --- a/torch/csrc/jit/codegen/fuser/codegen.cpp +++ b/torch/csrc/jit/codegen/fuser/codegen.cpp @@ -18,9 +18,7 @@ #include #include -namespace torch { -namespace jit { -namespace fuser { +namespace torch::jit::fuser { // Template for computing the offset into the tensor to access a value static auto dim_calc = at::jit::CodeTemplate(R"( @@ -538,7 +536,7 @@ std::string generateKernel( // places where the constant None node is used // Note: No need to iterate over reference as n is a pointer for (const auto n : graph.nodes()) { - static_assert(std::is_pointer::value, "n must be a pointer"); + static_assert(std::is_pointer_v, "n must be a pointer"); // Note: FusedConcat nodes work by narrowing the output Tensors before the // kernel runs if (n->kind() == prim::FusedConcat) @@ -680,11 +678,9 @@ std::string generateKernel( } if (debugFuser()) { - std::cerr << "fusion code:" << code_string << std::endl; + std::cerr << "fusion code:" << code_string << '\n'; } return code_string; } -} // namespace fuser -} // namespace jit -} // namespace torch +} // namespace torch::jit::fuser diff --git a/torch/csrc/jit/codegen/fuser/codegen.h b/torch/csrc/jit/codegen/fuser/codegen.h index e42adc1314320e..cab5ccf0eb5e6f 100644 --- a/torch/csrc/jit/codegen/fuser/codegen.h +++ b/torch/csrc/jit/codegen/fuser/codegen.h @@ -9,9 +9,7 @@ #include #include -namespace torch { -namespace jit { -namespace fuser { +namespace torch::jit::fuser { // Creates a CPU or CUDA kernel for the given graph. // Returns the C++ or CUDA string implementing the kernel. @@ -23,6 +21,4 @@ TORCH_API std::string generateKernel( const std::vector>& outputs, const bool use_cuda); -} // namespace fuser -} // namespace jit -} // namespace torch +} // namespace torch::jit::fuser diff --git a/torch/csrc/jit/codegen/fuser/compiler.cpp b/torch/csrc/jit/codegen/fuser/compiler.cpp index 7e03b576d12184..c28a1b89f07699 100644 --- a/torch/csrc/jit/codegen/fuser/compiler.cpp +++ b/torch/csrc/jit/codegen/fuser/compiler.cpp @@ -30,9 +30,7 @@ std::mutex& fusionBackendLock() { } } // namespace -namespace torch { -namespace jit { -namespace fuser { +namespace torch::jit::fuser { static std::unordered_map& getFusionBackends() { @@ -297,6 +295,4 @@ std::shared_ptr compileKernel( spec.hasRandom()); } -} // namespace fuser -} // namespace jit -} // namespace torch +} // namespace torch::jit::fuser diff --git a/torch/csrc/jit/codegen/fuser/compiler.h b/torch/csrc/jit/codegen/fuser/compiler.h index 3d490f6940ab0d..a998c6d459d370 100644 --- a/torch/csrc/jit/codegen/fuser/compiler.h +++ b/torch/csrc/jit/codegen/fuser/compiler.h @@ -11,9 +11,7 @@ #include #include -namespace torch { -namespace jit { -namespace fuser { +namespace torch::jit::fuser { // Performs device-independent "upfront" compilation of the given fusion_group, // if it has not been registered already. @@ -55,6 +53,4 @@ struct TORCH_API RegisterFusionBackend { } }; -} // namespace fuser -} // namespace jit -} // namespace torch +} // namespace torch::jit::fuser diff --git a/torch/csrc/jit/codegen/fuser/cuda/fused_kernel.cpp b/torch/csrc/jit/codegen/fuser/cuda/fused_kernel.cpp index d17da0cbaa21d4..8ff804c4601522 100644 --- a/torch/csrc/jit/codegen/fuser/cuda/fused_kernel.cpp +++ b/torch/csrc/jit/codegen/fuser/cuda/fused_kernel.cpp @@ -115,14 +115,12 @@ FusedKernelCUDA::FusedKernelCUDA( // Acquires device and NVRTC properties (for compile arch and occupancy // calculations) prop_ = at::cuda::getCurrentDeviceProperties(); - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - int major, minor; + int major = 0, minor = 0; bool compile_to_sass = false; codegenOutputQuery(prop_, major, minor, compile_to_sass); // Creates the NVRTC program - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - nvrtcProgram program; + nvrtcProgram program{}; AT_CUDA_NVRTC_CHECK(nvrtc().nvrtcCreateProgram( &program, code_.c_str(), nullptr, 0, nullptr, nullptr)); @@ -144,17 +142,14 @@ FusedKernelCUDA::FusedKernelCUDA( "compute_" + #endif std::to_string(major) + std::to_string(minor); - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) const std::vector args = { "--std=c++17", compute.c_str(), "-default-device"}; #endif const auto result = nvrtc().nvrtcCompileProgram(program, args.size(), args.data()); if (result != NVRTC_SUCCESS) { - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - size_t logsize; + size_t logsize = 0; AT_CUDA_NVRTC_CHECK(nvrtc().nvrtcGetProgramLogSize(program, &logsize)); - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) std::vector log(logsize); AT_CUDA_NVRTC_CHECK(nvrtc().nvrtcGetProgramLog(program, log.data())); std::stringstream cu; @@ -164,8 +159,7 @@ FusedKernelCUDA::FusedKernelCUDA( ResourceGuard holdProgram( [&] { AT_CUDA_NVRTC_CHECK(nvrtc().nvrtcDestroyProgram(&program)); }); AT_CUDA_NVRTC_CHECK(result); - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - size_t ptx_size; + size_t ptx_size = 0; #if defined(CUDA_VERSION) && CUDA_VERSION >= 11010 // compile_to_sass determines whether we are generating SASS or PTX, hence // the different API. diff --git a/torch/csrc/jit/codegen/fuser/cuda/fused_kernel.h b/torch/csrc/jit/codegen/fuser/cuda/fused_kernel.h index 4d39660532e358..5593ad570bbb21 100644 --- a/torch/csrc/jit/codegen/fuser/cuda/fused_kernel.h +++ b/torch/csrc/jit/codegen/fuser/cuda/fused_kernel.h @@ -12,10 +12,7 @@ #include #include -namespace torch { -namespace jit { -namespace fuser { -namespace cuda { +namespace torch::jit::fuser::cuda { // query codegen output arch and target TORCH_CUDA_CU_API void codegenOutputQuery( @@ -60,7 +57,4 @@ struct TORCH_CUDA_CU_API FusedKernelCUDA CUfunction function_; }; -} // namespace cuda -} // namespace fuser -} // namespace jit -} // namespace torch +} // namespace torch::jit::fuser::cuda diff --git a/torch/csrc/jit/codegen/fuser/cuda/resource_strings.h b/torch/csrc/jit/codegen/fuser/cuda/resource_strings.h index e6114f818e3318..ff2ef1f2377ce1 100644 --- a/torch/csrc/jit/codegen/fuser/cuda/resource_strings.h +++ b/torch/csrc/jit/codegen/fuser/cuda/resource_strings.h @@ -3,10 +3,7 @@ #include #include -namespace torch { -namespace jit { -namespace fuser { -namespace cuda { +namespace torch::jit::fuser::cuda { /*with type_as not checking type of its input, a fusion group can have non-fp32 tensor as input. Correct code for this case is generated, however, nvrtc does @@ -405,7 +402,4 @@ __device__ float __bfloat162float(const __nv_bfloat16 a) { )"; #endif -} // namespace cuda -} // namespace fuser -} // namespace jit -} // namespace torch +} // namespace torch::jit::fuser::cuda diff --git a/torch/csrc/jit/codegen/fuser/executor.cpp b/torch/csrc/jit/codegen/fuser/executor.cpp index 411dbe62a2e157..fa104b7cc16bff 100644 --- a/torch/csrc/jit/codegen/fuser/executor.cpp +++ b/torch/csrc/jit/codegen/fuser/executor.cpp @@ -14,15 +14,9 @@ #include #include -#include // TODO: remove, debugging only -#include -#include -#include #include -namespace torch { -namespace jit { -namespace fuser { +namespace torch::jit::fuser { // Returns the "map size" for this run, which is the common size for all // intermediate tensors. @@ -215,8 +209,7 @@ static void launchFusion( // Computes map_size, numel from the first input at::IntArrayRef map_size; - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - uint32_t numel; + uint32_t numel = 0; std::vector keep_alive_size; if (fusion.chunkDesc()[0].isNoop()) { map_size = inputs[0].sizes(); @@ -409,6 +402,4 @@ bool runFusion(const int64_t key, Stack& stack, std::string* code_out) { return true; } -} // namespace fuser -} // namespace jit -} // namespace torch +} // namespace torch::jit::fuser diff --git a/torch/csrc/jit/codegen/fuser/executor.h b/torch/csrc/jit/codegen/fuser/executor.h index f454e3d6a1845d..188a25ed8cc65c 100644 --- a/torch/csrc/jit/codegen/fuser/executor.h +++ b/torch/csrc/jit/codegen/fuser/executor.h @@ -7,9 +7,7 @@ #include -namespace torch { -namespace jit { -namespace fuser { +namespace torch::jit::fuser { // Runs the fusion associated with the key (see registerFusion() in interface.h) // on the inputs taken from the given Stack. @@ -18,6 +16,4 @@ TORCH_API bool runFusion( Stack& stack, std::string* code_out = nullptr); -} // namespace fuser -} // namespace jit -} // namespace torch +} // namespace torch::jit::fuser diff --git a/torch/csrc/jit/codegen/fuser/fallback.cpp b/torch/csrc/jit/codegen/fuser/fallback.cpp index 60a5d72f3c439c..70d90bd94c0ae1 100644 --- a/torch/csrc/jit/codegen/fuser/fallback.cpp +++ b/torch/csrc/jit/codegen/fuser/fallback.cpp @@ -9,9 +9,7 @@ #include -namespace torch { -namespace jit { -namespace fuser { +namespace torch::jit::fuser { namespace { c10::AliasAnalysisKind aliasAnalysisIsSpecialCase() { @@ -46,6 +44,4 @@ void runFallback(int64_t key, Stack& stack) { InterpreterState{(*maybe_spec)->code()}.run(stack); } -} // namespace fuser -} // namespace jit -} // namespace torch +} // namespace torch::jit::fuser diff --git a/torch/csrc/jit/codegen/fuser/fallback.h b/torch/csrc/jit/codegen/fuser/fallback.h index 570348ec53640d..af0ff32641ce65 100644 --- a/torch/csrc/jit/codegen/fuser/fallback.h +++ b/torch/csrc/jit/codegen/fuser/fallback.h @@ -4,12 +4,8 @@ #include -namespace torch { -namespace jit { -namespace fuser { +namespace torch::jit::fuser { void runFallback(int64_t key, Stack& stack); -} // namespace fuser -} // namespace jit -} // namespace torch +} // namespace torch::jit::fuser diff --git a/torch/csrc/jit/codegen/fuser/fused_kernel.h b/torch/csrc/jit/codegen/fuser/fused_kernel.h index 29ab3e7ed51c06..de00904a749c8e 100644 --- a/torch/csrc/jit/codegen/fuser/fused_kernel.h +++ b/torch/csrc/jit/codegen/fuser/fused_kernel.h @@ -9,14 +9,11 @@ #include #include -namespace torch { -namespace jit { -namespace fuser { +namespace torch::jit::fuser { struct FusedKernel { AT_DISALLOW_COPY_AND_ASSIGN(FusedKernel); - // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) FusedKernel( std::string name, std::string code, @@ -98,6 +95,4 @@ struct FusedKernel { const bool has_random_; }; -} // namespace fuser -} // namespace jit -} // namespace torch +} // namespace torch::jit::fuser diff --git a/torch/csrc/jit/codegen/fuser/interface.cpp b/torch/csrc/jit/codegen/fuser/interface.cpp index 13f4fd57142f45..0d6f7c1facc345 100644 --- a/torch/csrc/jit/codegen/fuser/interface.cpp +++ b/torch/csrc/jit/codegen/fuser/interface.cpp @@ -8,8 +8,7 @@ #include #include -namespace torch { -namespace jit { +namespace torch::jit { namespace detail { @@ -105,5 +104,4 @@ size_t nCompiledKernels() { return fuser::nCompiledKernels(); } -} // namespace jit -} // namespace torch +} // namespace torch::jit diff --git a/torch/csrc/jit/codegen/fuser/interface.h b/torch/csrc/jit/codegen/fuser/interface.h index 2c67064fe1359c..977e90191160ce 100644 --- a/torch/csrc/jit/codegen/fuser/interface.h +++ b/torch/csrc/jit/codegen/fuser/interface.h @@ -9,8 +9,7 @@ #include #include -namespace torch { -namespace jit { +namespace torch::jit { constexpr int kCPUDevice = -1; @@ -52,5 +51,4 @@ TORCH_API std::string debugGetFusedKernelCode( TORCH_API size_t nCompiledKernels(); -} // namespace jit -} // namespace torch +} // namespace torch::jit diff --git a/torch/csrc/jit/codegen/fuser/kernel_cache.cpp b/torch/csrc/jit/codegen/fuser/kernel_cache.cpp index 349378af05ea4d..4d82e4db3bf609 100644 --- a/torch/csrc/jit/codegen/fuser/kernel_cache.cpp +++ b/torch/csrc/jit/codegen/fuser/kernel_cache.cpp @@ -6,9 +6,7 @@ #include #include -namespace torch { -namespace jit { -namespace fuser { +namespace torch::jit::fuser { struct KernelCacheImpl { // Note: std::unordered_map does not invalidate references even if rehashing @@ -76,7 +74,7 @@ std::optional retrieve(const int64_t key) { } // precondition: graph has been normalized via normalizeGraphForCache -std::optional lookupGraph(std::shared_ptr graph) { +std::optional lookupGraph(const std::shared_ptr& graph) { auto& cache = getKernelCache(); std::string repr = graph->toString(false); @@ -87,6 +85,4 @@ std::optional lookupGraph(std::shared_ptr graph) { return nolock_retrieve(cache, it->second); } -} // namespace fuser -} // namespace jit -} // namespace torch +} // namespace torch::jit::fuser diff --git a/torch/csrc/jit/codegen/fuser/kernel_cache.h b/torch/csrc/jit/codegen/fuser/kernel_cache.h index c4f340225deb7b..370f782453b1cb 100644 --- a/torch/csrc/jit/codegen/fuser/kernel_cache.h +++ b/torch/csrc/jit/codegen/fuser/kernel_cache.h @@ -8,9 +8,7 @@ #include #include -namespace torch { -namespace jit { -namespace fuser { +namespace torch::jit::fuser { // A thread-safe cache interface. @@ -22,7 +20,8 @@ TORCH_API std::shared_ptr normalizeGraphForCache( TORCH_API int64_t store(std::shared_ptr graph); // Given a graph, find a KernelSpec based on it -TORCH_API std::optional lookupGraph(std::shared_ptr graph); +TORCH_API std::optional lookupGraph( + const std::shared_ptr& graph); // Returns the graph corresponding to the given key (if it exists) TORCH_API std::optional retrieve(const int64_t key); @@ -31,6 +30,4 @@ TORCH_API std::optional retrieve(const int64_t key); // Only used for testing. TORCH_API int64_t debugNumCachedKernelSpecs(); -} // namespace fuser -} // namespace jit -} // namespace torch +} // namespace torch::jit::fuser diff --git a/torch/csrc/jit/codegen/fuser/kernel_spec.h b/torch/csrc/jit/codegen/fuser/kernel_spec.h index eacdbc7ec3f336..6b1af19a1a6a61 100644 --- a/torch/csrc/jit/codegen/fuser/kernel_spec.h +++ b/torch/csrc/jit/codegen/fuser/kernel_spec.h @@ -16,9 +16,7 @@ #include #include -namespace torch { -namespace jit { -namespace fuser { +namespace torch::jit::fuser { // Helper struct containing partition information: the number of tensors // created and the dimension the partitioning is performed on. @@ -56,20 +54,19 @@ struct TORCH_API KernelSpec { // Note: assumes the spec is a single block // Note: This is the appropriate place to generalize if you want to add other // passes to upfront compilation that walk the graph. - // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) KernelSpec(const int64_t _key, const std::shared_ptr& _graph) : key_{_key}, graph_{_graph}, code_{_graph, ""}, nInputs_{_graph->inputs().size()}, - nTensorInputs_{}, + inputBroadcastGroups_{}, inputChunks_{}, - has_random_{false}, + kernels_{} { // No need to iterate over reference since n is pointer for (const auto n : graph_->nodes()) { - static_assert(std::is_pointer::value, "n must be a pointer"); + static_assert(std::is_pointer_v, "n must be a pointer"); if (n->kind() == aten::rand_like) { has_random_ = true; break; @@ -125,8 +122,9 @@ struct TORCH_API KernelSpec { return std::nullopt; return it->second; } - void cacheKernel(const ArgSpec& arg_spec, std::shared_ptr kernel) - const { + void cacheKernel( + const ArgSpec& arg_spec, + const std::shared_ptr& kernel) const { std::lock_guard guard{mutex_}; kernels_.emplace(arg_spec, kernel); } @@ -136,16 +134,14 @@ struct TORCH_API KernelSpec { std::shared_ptr graph_; Code code_; uint64_t nInputs_; - uint64_t nTensorInputs_; + uint64_t nTensorInputs_{}; std::vector> inputBroadcastGroups_; std::vector inputChunks_; - bool has_random_; + bool has_random_{false}; mutable std::mutex mutex_; mutable std:: unordered_map, c10::hash> kernels_; }; -} // namespace fuser -} // namespace jit -} // namespace torch +} // namespace torch::jit::fuser diff --git a/torch/csrc/jit/codegen/fuser/partition_desc.h b/torch/csrc/jit/codegen/fuser/partition_desc.h index 87cf7986e5e91c..ad42609e0ddc54 100644 --- a/torch/csrc/jit/codegen/fuser/partition_desc.h +++ b/torch/csrc/jit/codegen/fuser/partition_desc.h @@ -8,9 +8,7 @@ #include #include -namespace torch { -namespace jit { -namespace fuser { +namespace torch::jit::fuser { // Descriptor for chunk-ing an input tensor into subtensors // OR concat-ing an output tensor from subtensors @@ -22,7 +20,6 @@ struct TORCH_API PartitionDesc { PartitionDesc(const TensorDesc& _desc, size_t _nSubTensors, size_t _dim) : nSubTensors_{_nSubTensors}, dim_{_dim} { AT_ASSERT(nSubTensors_ > 1); - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) std::vector cont = _desc.contiguity; if (dim_ > 0) { // when we narrow the concatenated output/chunked input @@ -59,6 +56,4 @@ struct TORCH_API PartitionDesc { subTensorDesc_; // descriptor for the subtensor, if it exists }; -} // namespace fuser -} // namespace jit -} // namespace torch +} // namespace torch::jit::fuser diff --git a/torch/csrc/jit/codegen/fuser/tensor_desc.h b/torch/csrc/jit/codegen/fuser/tensor_desc.h index ffc405244a71ef..0c5db65d54ad1a 100644 --- a/torch/csrc/jit/codegen/fuser/tensor_desc.h +++ b/torch/csrc/jit/codegen/fuser/tensor_desc.h @@ -10,20 +10,15 @@ #include #include -namespace torch { -namespace jit { -namespace fuser { +namespace torch::jit::fuser { // type information needed by the compiler for input/outputs // contiguity[i] is true if the dim i is contiguous with dim i + 1. // contiguity.back() == true means strides.back() == 1. struct TORCH_API TensorDesc { - // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) at::ScalarType scalar_type; - // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) std::vector contiguity; - // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) TensorDesc(const at::ScalarType& type, const std::vector& contiguity) : scalar_type{type}, contiguity{contiguity} { if (contiguity.empty()) { @@ -35,7 +30,6 @@ struct TORCH_API TensorDesc { } // Delegating constructors - // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) TensorDesc( const at::ScalarType& type, const at::IntArrayRef& sizes, @@ -45,7 +39,6 @@ struct TORCH_API TensorDesc { TensorDesc(const at::Tensor& t) : TensorDesc(t.scalar_type(), t.sizes(), t.strides()) {} - // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) TensorDesc(const c10::TensorTypePtr& type) : TensorDesc( type->scalarType().value(), @@ -66,7 +59,6 @@ struct TORCH_API TensorDesc { const at::IntArrayRef& sizes, const at::IntArrayRef& strides) { AT_ASSERT(sizes.size() == strides.size()); - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) std::vector cont(sizes.size()); for (size_t i = 0; i < sizes.size(); ++i) { const auto expected_stride = @@ -103,6 +95,4 @@ inline std::ostream& operator<<(std::ostream& out, const TensorDesc& d) { return out; } -} // namespace fuser -} // namespace jit -} // namespace torch +} // namespace torch::jit::fuser diff --git a/torch/csrc/jit/codegen/fuser/tensor_info.h b/torch/csrc/jit/codegen/fuser/tensor_info.h index 9f3fcca1aca2e5..77a0d8bacdf23b 100644 --- a/torch/csrc/jit/codegen/fuser/tensor_info.h +++ b/torch/csrc/jit/codegen/fuser/tensor_info.h @@ -1,11 +1,10 @@ #pragma once #include +#include #include -namespace torch { -namespace jit { -namespace fuser { +namespace torch::jit::fuser { // Host-side view of TensorInfo // Note dims[0] - we need to dynamically allocate the dims. @@ -18,12 +17,8 @@ struct TORCH_API TensorInfo { } void* data; -#pragma GCC diagnostic ignored "-Wpedantic" // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays) uint32_t sizes_strides[0]; -#pragma GCC diagnostic pop }; -} // namespace fuser -} // namespace jit -} // namespace torch +} // namespace torch::jit::fuser diff --git a/torch/csrc/jit/cuda/cuda.h b/torch/csrc/jit/cuda/cuda.h index eade3b23d7266c..a6a4d2e31522c1 100644 --- a/torch/csrc/jit/cuda/cuda.h +++ b/torch/csrc/jit/cuda/cuda.h @@ -12,7 +12,6 @@ class CUDAEvent; // c10/cuda/CUDAStream.h. class CUDAStream final : public CustomClassHolder { public: - // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) CUDAStream( std::optional device = std::nullopt, int64_t priority = 0) { @@ -22,7 +21,6 @@ class CUDAStream final : public CustomClassHolder { c10::cuda::getStreamFromPool(static_cast(priority), device_index)); } - // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) CUDAStream(c10::cuda::CUDAStream s) { stream_ = std::make_unique(s); } @@ -69,12 +67,10 @@ class CUDAStream final : public CustomClassHolder { // aten/src/ATen/cuda/CUDAEvent.h. class CUDAEvent final : public CustomClassHolder { public: - // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) CUDAEvent( bool enable_timing = false, bool blocking = false, bool interprocess = false) { - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) int flags = cudaEventDisableTiming; if (enable_timing) { flags = cudaEventDefault; @@ -95,8 +91,7 @@ class CUDAEvent final : public CustomClassHolder { } std::string ipcHandle() { - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - cudaIpcEventHandle_t handle; + cudaIpcEventHandle_t handle{}; event_->ipc_handle(&handle); std::string str_handle((const char*)&handle, sizeof(handle)); return str_handle; diff --git a/torch/csrc/jit/frontend/script_type_parser.cpp b/torch/csrc/jit/frontend/script_type_parser.cpp index f9f317d9a43068..4d7904b8707c27 100644 --- a/torch/csrc/jit/frontend/script_type_parser.cpp +++ b/torch/csrc/jit/frontend/script_type_parser.cpp @@ -38,9 +38,10 @@ TypePtr ScriptTypeParser::subscriptToType( // here. See https://docs.python.org/3/library/typing.html#typing.Tuple auto tup_literal = TupleLiteral(subscript.subscript_exprs()[0]); if (!tup_literal.inputs().empty()) { - throw ErrorReport(tup_literal.range()) + throw( + ErrorReport(tup_literal.range()) << "Tuple literal in Tuple type annotation must not " - << "have any elements!"; + << "have any elements!"); } return TupleType::create({}); } @@ -179,12 +180,12 @@ std::optional> ScriptTypeParser::parseBroadcastList( TypePtr list_ptr = ListType::create(elem_ptr->second); const char* len_c = len.c_str(); - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - char* end; + char* end = nullptr; size_t len_v = strtoull(len_c, &end, 10); if (end != len_c + len.size()) { - throw ErrorReport(subscript.subscript_exprs().range()) - << "subscript of Broadcastable list must be a positive integer"; + throw( + ErrorReport(subscript.subscript_exprs().range()) + << "subscript of Broadcastable list must be a positive integer"); } return std::pair(list_ptr, len_v); } diff --git a/torch/csrc/jit/frontend/tree_views.h b/torch/csrc/jit/frontend/tree_views.h index b4547eb85c0c5e..d7e758e34e38b8 100644 --- a/torch/csrc/jit/frontend/tree_views.h +++ b/torch/csrc/jit/frontend/tree_views.h @@ -920,13 +920,11 @@ struct Const : public Expr { double asFloatingPoint() const { // We can't pass in nullptr as the dummy pointer gets dereferenced for // Android version of strtod_c(). - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - char* dummy; + char* dummy = nullptr; return torch::jit::strtod_c(subtree(0)->stringValue().c_str(), &dummy); } c10::complex asComplex() const { - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - char* dummy; + char* dummy = nullptr; auto str = subtree(0)->stringValue(); // Complex numbers (a+bj, where a is non-zero) are parsed as an addition // between float/int a and a complex number "bj". When a is 0, a complex diff --git a/torch/csrc/jit/ir/alias_analysis.cpp b/torch/csrc/jit/ir/alias_analysis.cpp index 7b45697b0aa6d9..87af6a1a634386 100644 --- a/torch/csrc/jit/ir/alias_analysis.cpp +++ b/torch/csrc/jit/ir/alias_analysis.cpp @@ -1757,12 +1757,9 @@ bool AliasDb::tryMove( // dependencies WorkingSet workingSet(toMove, *this); - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - int direction; + auto direction = kNextDirection; if (toMove->isAfter(movePoint)) { direction = kPrevDirection; - } else { - direction = kNextDirection; } auto curNode = toMove->next_in_graph[direction]; diff --git a/torch/csrc/jit/ir/ir.h b/torch/csrc/jit/ir/ir.h index 73bbf7c10a6e38..44087074e89168 100644 --- a/torch/csrc/jit/ir/ir.h +++ b/torch/csrc/jit/ir/ir.h @@ -1743,8 +1743,8 @@ struct OperatorMap { // TODO: return iterator std::vector getAllKeysAndValues() const { - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) std::vector keys_values; + keys_values.reserve(map.size()); for (auto& symbol_mapping : map) { auto& vec = symbol_mapping.second; for (auto& pair : vec) { @@ -1819,8 +1819,8 @@ struct FunctionSchemaMap { // TODO: return iterator std::vector getAllKeysAndValues() const { - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) std::vector keys_values; + keys_values.reserve(map.size()); for (auto& symbol_mapping : map) { auto& vec = symbol_mapping.second; for (auto& pair : vec) { diff --git a/torch/csrc/jit/jit_log.cpp b/torch/csrc/jit/jit_log.cpp index ef77bb7b642745..b94fc346f5f188 100644 --- a/torch/csrc/jit/jit_log.cpp +++ b/torch/csrc/jit/jit_log.cpp @@ -31,7 +31,6 @@ class JitLoggingConfig { std::unordered_map files_to_levels; std::ostream* out; - // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) JitLoggingConfig() : out(&std::cerr) { const char* jit_log_level = std::getenv("PYTORCH_JIT_LOG_LEVEL"); logging_levels.assign(jit_log_level == nullptr ? "" : jit_log_level); diff --git a/torch/csrc/jit/mobile/import_data.cpp b/torch/csrc/jit/mobile/import_data.cpp index 8f206230957280..1bd34e4a823ae9 100644 --- a/torch/csrc/jit/mobile/import_data.cpp +++ b/torch/csrc/jit/mobile/import_data.cpp @@ -61,8 +61,7 @@ c10::IValue IValueUnpickler::readArchive( std::stringstream picklename; picklename << archive_name << ".pkl"; at::DataPtr pickle_ptr; - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - size_t pickle_size; + size_t pickle_size = 0; std::tie(pickle_ptr, pickle_size) = reader_->getRecord(picklename.str()); size_t bytes_read = 0; diff --git a/torch/csrc/jit/mobile/promoted_prim_ops.cpp b/torch/csrc/jit/mobile/promoted_prim_ops.cpp index 405136795df71d..857cb304291025 100644 --- a/torch/csrc/jit/mobile/promoted_prim_ops.cpp +++ b/torch/csrc/jit/mobile/promoted_prim_ops.cpp @@ -113,10 +113,8 @@ void layout(Stack& stack) { } void toPrimDType(Stack& stack) { - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - bool non_blocking; - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - bool copy; + bool non_blocking = false; + bool copy = false; pop(stack, non_blocking, copy); std::optional scalarType = pop(stack).toOptional(); @@ -141,10 +139,8 @@ void boolTensor(Stack& stack) { } void toList(Stack& stack) { - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - int elem_ty_val; - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - int dim_val; + int elem_ty_val = 0; + int dim_val = 0; at::Tensor t; pop(stack, elem_ty_val); diff --git a/torch/csrc/jit/passes/onnx/cast_all_constant_to_floating.cpp b/torch/csrc/jit/passes/onnx/cast_all_constant_to_floating.cpp index de719a1eabc539..65aefb9a577613 100644 --- a/torch/csrc/jit/passes/onnx/cast_all_constant_to_floating.cpp +++ b/torch/csrc/jit/passes/onnx/cast_all_constant_to_floating.cpp @@ -31,8 +31,7 @@ void CastAllConstantToFloating(Block* block) { auto val_type = TensorType::create(val); if (dtype != at::ScalarType::Double && dtype != at::ScalarType::Float && dtype != at::ScalarType::Half) { - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - int to_type; + int to_type = 0; switch (val.scalar_type()) { case at::ScalarType::Byte: case at::ScalarType::Char: diff --git a/torch/csrc/jit/passes/onnx/peephole.cpp b/torch/csrc/jit/passes/onnx/peephole.cpp index 18c31ea656610d..9ccea955aee1b7 100644 --- a/torch/csrc/jit/passes/onnx/peephole.cpp +++ b/torch/csrc/jit/passes/onnx/peephole.cpp @@ -818,8 +818,7 @@ static void fuseLogSoftmaxNllLoss(Block* b) { if (it->kind() == onnx::NegativeLogLikelihoodLoss) { auto prev = it->input(0)->node(); Node* origNllLossNode = *it; - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - Node* origLogSoftmaxNode; + Node* origLogSoftmaxNode = nullptr; // Check for patterns especially in cases with autocasting enabled // in which a cast node is inserted before the NegativeLogLikelihoodLoss diff --git a/torch/csrc/jit/passes/onnx/unpack_quantized_weights.cpp b/torch/csrc/jit/passes/onnx/unpack_quantized_weights.cpp index 850070aa0b2fa3..e34e06fbcbfaf9 100644 --- a/torch/csrc/jit/passes/onnx/unpack_quantized_weights.cpp +++ b/torch/csrc/jit/passes/onnx/unpack_quantized_weights.cpp @@ -299,10 +299,6 @@ void unpackQuantizedWeightsHelper( torch::List stride_int, padding_int, dilation_int, output_padding_int; - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - int64_t groups_int; - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - int64_t transpose_int; if (itr->second.isTuple()) { // Pre-unpacked weights. Comes from Conv/Linear weights which are @@ -415,9 +411,9 @@ void unpackQuantizedWeightsHelper( } idx++; } - groups_int = conv_params_packed[idx].item(); + auto groups_int = conv_params_packed[idx].item(); idx++; - transpose_int = conv_params_packed[idx].item(); + auto transpose_int = conv_params_packed[idx].item(); idx++; TORCH_INTERNAL_ASSERT( idx == conv_params_packed.numel(), @@ -459,11 +455,10 @@ void unpackQuantizedWeightsHelper( for (const auto& d : dilation_ivalue) { dilation_int.emplace_back(d.toTensor()[0].item()); } - groups_int = groups_ivalue.toTensor()[0].item(); + groups = groups_ivalue.toTensor()[0].item(); stride = stride_int; padding = padding_int; dilation = dilation_int; - groups = groups_int; if (expect_output_padding) { auto output_padding_ivalue = diff --git a/torch/csrc/jit/passes/quantization/insert_observers.cpp b/torch/csrc/jit/passes/quantization/insert_observers.cpp index f906efacceca7b..9aacd481a55b04 100644 --- a/torch/csrc/jit/passes/quantization/insert_observers.cpp +++ b/torch/csrc/jit/passes/quantization/insert_observers.cpp @@ -1120,8 +1120,7 @@ void InsertObserversHelper::fillBoundaryValueMap( // offset of input for the caller node, since the first // input of CallFunction is the function node and the graph // for CallFunction start with actual input - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - size_t input_offset; + size_t input_offset = 0; if (n->kind() == prim::CallMethod) { auto m_opt = getInvokedModuleOpt(module, n, self); if (!m_opt.has_value()) { @@ -1469,8 +1468,7 @@ InsertObserversHelper::insertObserversFor( if (n->kind() == prim::CallMethod || userDefinedCallFunction(n)) { script::Module m; std::shared_ptr g; - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - size_t input_offset; + size_t input_offset = 0; bool is_udf_for_subblock = is_user_defined_function; if (n->kind() == prim::CallMethod) { auto m_opt = getInvokedModuleOpt(module, n, self); diff --git a/torch/csrc/jit/runtime/argument_spec.h b/torch/csrc/jit/runtime/argument_spec.h index 7a815e815d8e96..324fc37e080c69 100644 --- a/torch/csrc/jit/runtime/argument_spec.h +++ b/torch/csrc/jit/runtime/argument_spec.h @@ -36,7 +36,6 @@ struct ArgumentInfo { return requires_grad_; } int dim() const { - // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions) return dim_; } at::ScalarType type() const { @@ -104,7 +103,6 @@ struct ArgumentSpec { if (arg.defined_) { arg.requires_grad_ = with_grad && autograd::Variable(*t).requires_grad(); arg.dim_ = t->dim(); - // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions) at::Device device = t->device(); arg.dev_type_ = // NOLINTNEXTLINE(bugprone-signed-char-misuse) @@ -117,8 +115,7 @@ struct ArgumentSpec { } void combineHash(const ArgumentInfo& arg) { - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - ArgumentInfo::plain_data_type arg_data; + ArgumentInfo::plain_data_type arg_data = 0; std::memcpy(&arg_data, &arg, sizeof(ArgumentInfo)); hash_code = c10::hash_combine(hash_code, arg_data); } @@ -242,12 +239,10 @@ static_assert( struct CompleteArgumentInfo; struct CompleteArgumentSpec { - // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) CompleteArgumentSpec(bool with_grad, at::ArrayRef inputs) : hash_code(0), ninputs(inputs.size()) { int32_t all_dims = 0; - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - const int32_t num_inputs = inputs.size(); + const auto num_inputs = inputs.size(); for (const auto i : c10::irange(num_inputs)) { if (!inputs[i].isTensor()) continue; @@ -258,7 +253,6 @@ struct CompleteArgumentSpec { data.resize(ninputs + all_dims * 2); // and reinterpret our data array as these structs - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) auto* pods = reinterpret_cast(data.data()); int64_t* next_dim = sizes_strides(); int32_t total_dims = 0; @@ -270,7 +264,6 @@ struct CompleteArgumentSpec { pod.defined = t.defined(); if (pod.defined) { pod.type = static_cast(t.scalar_type()); - // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions) at::Device device = t.device(); // NOLINTNEXTLINE(bugprone-signed-char-misuse) pod.dev_type = static_cast::type>( @@ -394,7 +387,6 @@ struct CompleteArgumentInfo { int sizes_strides_offset(int j) const { if (j == 0) return 0; - // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions) return 2 * pod(j - 1).total_dims; } const CompleteArgumentInfoPOD& pod(int j) const { diff --git a/torch/csrc/jit/runtime/autodiff.cpp b/torch/csrc/jit/runtime/autodiff.cpp index 047a35e417fff8..d525fcaecd816b 100644 --- a/torch/csrc/jit/runtime/autodiff.cpp +++ b/torch/csrc/jit/runtime/autodiff.cpp @@ -206,8 +206,7 @@ class GradientHelper { } else if (node->kind() == prim::ConstantChunk) { auto* g = node->owningGraph(); - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - Value* input_list; + Value* input_list = nullptr; if (grad_values.size() == 1 && grad_values[0]->type()->isSubtypeOf(*ListType::ofTensors())) { input_list = grad_values[0]; @@ -575,8 +574,7 @@ static void foldSizeIfNotEqual(Node* node) { // insert in front of _grad_sum_to_size WithInsertPoint guard(node); IValue ival{}; - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - Value* size; + Value* size = nullptr; if (input_size != output_size) { size = node->owningGraph()->insertConstant(*input_size); } else { diff --git a/torch/csrc/jit/runtime/interpreter/code_impl.h b/torch/csrc/jit/runtime/interpreter/code_impl.h index 9762106340b404..8517e6a94b57b6 100644 --- a/torch/csrc/jit/runtime/interpreter/code_impl.h +++ b/torch/csrc/jit/runtime/interpreter/code_impl.h @@ -340,8 +340,7 @@ struct CodeImpl { int reg = registerFor(input); bool moved = input->uses().size() == ++use_count_[input]; - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - OpCode op; + OpCode op{}; if (input->node()->kind() == prim::Constant) { op = LOADC; } else if (moved) { diff --git a/torch/csrc/jit/runtime/register_prim_ops.cpp b/torch/csrc/jit/runtime/register_prim_ops.cpp index f6eccede28bab1..ed56e7d4322f61 100644 --- a/torch/csrc/jit/runtime/register_prim_ops.cpp +++ b/torch/csrc/jit/runtime/register_prim_ops.cpp @@ -13,18 +13,11 @@ #include #include #include -#include -#include #include -#include #include -#include #include #include #include -#include -#include -#include #include #include @@ -259,8 +252,7 @@ static const std::vector opGenArgs{ TORCH_SELECTIVE_SCHEMA( "aten::__range_length(int lo, int hi, int step) -> int"), [](Stack& stack) { - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - int64_t lo, hi, step; + int64_t lo = 0, hi = 0, step = 0; pop(stack, lo, hi, step); // error handling when step_val = 0 during runtime if (step == 0) { @@ -279,8 +271,7 @@ static const std::vector opGenArgs{ TORCH_SELECTIVE_SCHEMA( "aten::__derive_index(int index, int start, int step) -> int"), [](Stack& stack) { - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - int64_t index, start, step; + int64_t index = 0, start = 0, step = 0; pop(stack, index, start, step); push(stack, start + index * step); }, @@ -336,8 +327,7 @@ static const std::vector opGenArgs{ OperatorGeneratorArgs( TORCH_SELECTIVE_SCHEMA("aten::Bool.int(int a) -> bool"), [](Stack& stack) { - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - int64_t i; + int64_t i = 0; pop(stack, i); push(stack, (bool)i); }, @@ -345,8 +335,7 @@ static const std::vector opGenArgs{ OperatorGeneratorArgs( TORCH_SELECTIVE_SCHEMA("aten::Bool.float(float a) -> bool"), [](Stack& stack) { - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - double d; + double d = 0; pop(stack, d); push(stack, (bool)d); }, @@ -362,8 +351,7 @@ static const std::vector opGenArgs{ OperatorGeneratorArgs( TORCH_SELECTIVE_SCHEMA("aten::Int.bool(bool a) -> int"), [](Stack& stack) { - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - bool b; + bool b = false; pop(stack, b); push(stack, static_cast(b)); }, @@ -371,8 +359,7 @@ static const std::vector opGenArgs{ OperatorGeneratorArgs( TORCH_SELECTIVE_SCHEMA("aten::Int.float(float a) -> int"), [](Stack& stack) { - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - double d; + double d = 0; pop(stack, d); push(stack, static_cast(d)); }, @@ -394,8 +381,7 @@ static const std::vector opGenArgs{ TORCH_SELECTIVE_SCHEMA("aten::Int.str(str a) -> int"), [](Stack& stack) { auto s = pop(stack).toString(); - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - std::string::size_type sz; + std::string::size_type sz = 0; int64_t val = static_cast(std::stoll(s->string(), &sz)); if (sz == s->string().size()) { push(stack, val); @@ -432,8 +418,7 @@ static const std::vector opGenArgs{ OperatorGeneratorArgs( TORCH_SELECTIVE_SCHEMA("aten::Float.int(int a) -> float"), [](Stack& stack) { - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - int64_t i; + int64_t i = 0; pop(stack, i); push(stack, (float)i); }, @@ -441,8 +426,7 @@ static const std::vector opGenArgs{ OperatorGeneratorArgs( TORCH_SELECTIVE_SCHEMA("aten::Float.bool(bool a) -> float"), [](Stack& stack) { - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - bool b; + bool b = false; pop(stack, b); push(stack, (float)b); }, @@ -451,8 +435,7 @@ static const std::vector opGenArgs{ TORCH_SELECTIVE_SCHEMA("aten::Float.str(str a) -> float"), [](Stack& stack) { auto s = pop(stack).toString(); - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - std::string::size_type sz; + std::string::size_type sz = 0; double b = std::stod(s->string(), &sz); if (sz == s->string().size()) { push(stack, b); @@ -1036,8 +1019,7 @@ static const std::vector opGenArgs{ OperatorGeneratorArgs( TORCH_SELECTIVE_SCHEMA("aten::pow.int_to_int(int a, int b) -> int"), [](Stack& stack) { - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - int64_t a, b; + int64_t a = 0, b = 0; pop(stack, a, b); push(stack, powWrapper(a, b)); }, @@ -1270,10 +1252,8 @@ static const std::vector opGenArgs{ TORCH_SELECTIVE_SCHEMA( "aten::to.prim_Device(Tensor(a) self, Device? device, int? dtype=None, bool non_blocking=False, bool copy=False) -> Tensor(a|b)"), [](Stack& stack) { - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - bool non_blocking; - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - bool copy; + bool non_blocking = false; + bool copy = false; pop(stack, non_blocking, copy); std::optional scalarType = pop(stack).toOptional(); @@ -1709,14 +1689,12 @@ static const std::vector dict_ops{ }; RegisterOperators reg_dict_ops(createOperators(dict_ops)); -// NOLINTNEXTLINE(clang-diagnostic-unused-function) constexpr c10::AliasAnalysisKind aliasAnalysisFromSchema() { return c10::AliasAnalysisKind::FROM_SCHEMA; } // Convert an python index (which may be negative) into an index usable for a // C++ container -// NOLINTNEXTLINE(clang-diagnostic-unused-function) int64_t normalizeIndex(int64_t idx, int64_t list_size) { if (idx < 0) { // Handle negative indexing @@ -2418,8 +2396,7 @@ static const std::vector opGenArgs1{ OperatorGeneratorArgs( TORCH_SELECTIVE_SCHEMA("prim::rangelist(int n) -> int[]"), [](Stack& stack) { - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - int64_t n; + int64_t n = 0; pop(stack, n); c10::List elems; elems.reserve(n); @@ -2458,10 +2435,8 @@ static const std::vector opGenArgs1{ "aten::to.prim_other(Tensor(a) self, bool non_blocking=False, bool copy=False) -> Tensor(a|b)"), [](Stack& stack) { at::Tensor self; - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - bool non_blocking; - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - bool copy; + bool non_blocking = false; + bool copy = false; pop(stack, self, non_blocking, copy); std::optional device = std::nullopt; std::optional scalarType = std::nullopt; @@ -3077,25 +3052,20 @@ static const std::vector opGenArgs2{ OperatorGeneratorArgs( TORCH_SELECTIVE_SCHEMA("aten::modf(float a) -> (float, float)"), [](Stack& stack) { - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - double a; + double a = 0; pop(stack, a); - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - double b, c; - b = modf(a, &c); + double c = 0; + double b = modf(a, &c); push(stack, b, c); }, aliasAnalysisFromSchema()), OperatorGeneratorArgs( TORCH_SELECTIVE_SCHEMA("aten::frexp(float a) -> (float, int)"), [](Stack& stack) { - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - double a; + double a = 0; pop(stack, a); - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - double m; - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - int e; + double m = 0; + int e = 0; m = std::frexp(a, &e); push(stack, m, e); }, @@ -3103,10 +3073,8 @@ static const std::vector opGenArgs2{ OperatorGeneratorArgs( TORCH_SELECTIVE_SCHEMA("aten::ldexp(float x, int i) -> float"), [](Stack& stack) { - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - double a; - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - int64_t b; + double a = 0; + int64_t b = 0; pop(stack, a, b); push(stack, std::ldexp(a, b)); }, @@ -3388,8 +3356,7 @@ static const std::vector opGenArgs2{ OperatorGeneratorArgs( TORCH_SELECTIVE_SCHEMA("aten::divmod.int(int x, int y) -> (int, int)"), [](Stack& stack) { - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - int64_t a, b; + int64_t a = 0, b = 0; lldiv_t divresult = {}; pop(stack, a, b); if (b == 0) { @@ -3411,8 +3378,7 @@ static const std::vector opGenArgs2{ TORCH_SELECTIVE_SCHEMA( "aten::divmod.float(float x, float y) -> (float, float)"), [](Stack& stack) { - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - double a, b; + double a = 0, b = 0; pop(stack, a, b); if (b == 0) { throw std::runtime_error("ZeroDivisionError: float divmod()"); diff --git a/torch/csrc/jit/runtime/register_special_ops.cpp b/torch/csrc/jit/runtime/register_special_ops.cpp index 63fdee6de8042c..5fa06b09274519 100644 --- a/torch/csrc/jit/runtime/register_special_ops.cpp +++ b/torch/csrc/jit/runtime/register_special_ops.cpp @@ -16,7 +16,6 @@ #include #include -#include #include namespace torch::jit { @@ -187,8 +186,7 @@ template void createTensorFromList(Stack& stack) { // torch.tensor has a fourth requires_grad arg but torch.as_tensor not, so // we use the template arg to distinguish between these two cases - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - bool requires_grad; + bool requires_grad = false; IValue data; IValue dtype; IValue device; @@ -334,10 +332,8 @@ RegisterOperators reg({ [](Stack& stack) { at::Tensor weight; at::Tensor input; - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - double max_norm; - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - double norm_type; + double max_norm = 0; + double norm_type = 0; pop(stack, weight, input, max_norm, norm_type); // TODO: remove when script supports setting grad mode @@ -402,13 +398,11 @@ RegisterOperators reg({ torch::NoGradGuard no_grad; at::Tensor tensor; - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - double a; - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - double b; std::optional generator = pop(stack).toOptional(); + double a = 0; + double b = 0; pop(stack, tensor, a, b); push(stack, tensor.uniform_(a, b, generator)); }, @@ -421,10 +415,8 @@ RegisterOperators reg({ torch::NoGradGuard no_grad; at::Tensor tensor; - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - double mean; - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - double std; + double mean = 0; + double std = 0; std::optional generator = pop(stack).toOptional(); @@ -440,8 +432,7 @@ RegisterOperators reg({ torch::NoGradGuard no_grad; at::Tensor tensor; - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - double val; + double val = 0; pop(stack, tensor, val); push(stack, at::fill_(tensor, val)); }, diff --git a/torch/csrc/jit/runtime/static/fusion.cpp b/torch/csrc/jit/runtime/static/fusion.cpp index c87d3d9b069ca2..f52c24e9f01ff6 100644 --- a/torch/csrc/jit/runtime/static/fusion.cpp +++ b/torch/csrc/jit/runtime/static/fusion.cpp @@ -271,8 +271,7 @@ void createFusionGroups(Block* block, AliasDb* aliasDb, size_t min_size) { while (any_changed) { any_changed = false; for (auto it = block->nodes().rbegin(); it != block->nodes().rend();) { - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - bool changed; + bool changed = false; std::tie(it, changed) = scanNode(*it, aliasDb); any_changed |= changed; } diff --git a/torch/csrc/jit/runtime/static/impl.cpp b/torch/csrc/jit/runtime/static/impl.cpp index 589691eeae6584..19e2e47498f39a 100644 --- a/torch/csrc/jit/runtime/static/impl.cpp +++ b/torch/csrc/jit/runtime/static/impl.cpp @@ -457,8 +457,7 @@ ManagedTensorRanges::ManagedTensorRanges( for (auto* managed_tensor : managed_tensor_values) { auto* lifetime = getLifetime(managed_tensor); DCHECK(lifetime && lifetime->end <= num_nodes); - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - Node* freeing_node; + Node* freeing_node = nullptr; if (lifetime->end == num_nodes) { freeing_node = block.return_node(); } else { diff --git a/torch/csrc/jit/runtime/symbolic_script.cpp b/torch/csrc/jit/runtime/symbolic_script.cpp index 92d901e43a5d21..beecdb770cba37 100644 --- a/torch/csrc/jit/runtime/symbolic_script.cpp +++ b/torch/csrc/jit/runtime/symbolic_script.cpp @@ -1556,12 +1556,12 @@ static void loadModule(const CompilationUnit& module) { Node* forward_tuple = pair.forward->outputs().at(0)->node(); if (forward_tuple->kind() != prim::TupleConstruct) { - throw ErrorReport(forward_tuple->sourceRange()) - << "gradient must return literal a tuple"; + throw( + ErrorReport(forward_tuple->sourceRange()) + << "gradient must return literal a tuple"); } - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - Value* context; + Value* context = nullptr; std::tie(pair.backward, context) = extractClosure(forward_tuple->inputs().back()); diff --git a/torch/csrc/jit/tensorexpr/operators/conv2d.cpp b/torch/csrc/jit/tensorexpr/operators/conv2d.cpp index bfce006d55177e..2dd46335e09f8e 100644 --- a/torch/csrc/jit/tensorexpr/operators/conv2d.cpp +++ b/torch/csrc/jit/tensorexpr/operators/conv2d.cpp @@ -5,9 +5,7 @@ #include #include -namespace torch { -namespace jit { -namespace tensorexpr { +namespace torch::jit::tensorexpr { namespace { @@ -20,8 +18,8 @@ void assert_dims_constant(const BufHandle& buf) { using InitFunc = std::function&)>; Tensor conv2d_depthwise_static( - BufHandle input, - BufHandle weight, + const BufHandle& input, + const BufHandle& weight, const InitFunc& init_func, int stride, int pad, @@ -78,14 +76,12 @@ Tensor conv2d_depthwise_static( constexpr int kLoopH = 2, kLoopW = 3; if (R == 3 && stride == 2 && pad == 1) { - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) ForPtr head, tail; auto loops = nest.getLoopStmtsFor(conv); nest.sliceHead(loops[kLoopW], 2, &head, &tail); loops = nest.getLoopStmtsFor(conv); nest.sliceHead(loops[kLoopH], 2, &head, &tail); } else if (R == 3 && stride == 1 && pad == 1) { - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) ForPtr main, peeled; auto loops = nest.getAllLoopNestsWritingToBuf(conv.buf()); main = loops[1][kLoopW]; @@ -489,6 +485,4 @@ Tensor computeMkldnnPrepackedConvRun( return Tensor(ResultBuf.node(), s); } -} // namespace tensorexpr -} // namespace jit -} // namespace torch +} // namespace torch::jit::tensorexpr From 38724bc3966127fd5daae1cb8a9d504afb3b6d4b Mon Sep 17 00:00:00 2001 From: Xu Han Date: Mon, 26 Aug 2024 17:43:57 +0000 Subject: [PATCH 0025/1018] [inductor] calibration inductor windows uts (11/N) (#134427) enable Windows inductor UTs of `test/inductor/test_inductor_freezing.py` Local test pass: image Pull Request resolved: https://github.com/pytorch/pytorch/pull/134427 Approved by: https://github.com/jansel --- test/inductor/test_inductor_freezing.py | 17 +---------------- 1 file changed, 1 insertion(+), 16 deletions(-) diff --git a/test/inductor/test_inductor_freezing.py b/test/inductor/test_inductor_freezing.py index 6346751e7e9177..2f4faad8f9f063 100644 --- a/test/inductor/test_inductor_freezing.py +++ b/test/inductor/test_inductor_freezing.py @@ -22,23 +22,8 @@ pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) sys.path.append(pytorch_test_dir) -from torch.testing._internal.common_utils import ( - IS_CI, - IS_WINDOWS, - TEST_WITH_ASAN, - TEST_WITH_ROCM, -) - - -if IS_WINDOWS and IS_CI: - sys.stderr.write( - "Windows CI does not have necessary dependencies for test_torchinductor yet\n" - ) - if __name__ == "__main__": - sys.exit(0) - raise unittest.SkipTest("requires sympy/functorch/filelock") - from inductor.test_torchinductor import check_model, check_model_cuda, copy_tests +from torch.testing._internal.common_utils import TEST_WITH_ASAN, TEST_WITH_ROCM importlib.import_module("functorch") From f6740a313a5ea1b6bfd42d4fcab42e2cc79e60de Mon Sep 17 00:00:00 2001 From: atalman Date: Mon, 26 Aug 2024 17:47:20 +0000 Subject: [PATCH 0026/1018] [CD] Use ephemeral arm64 runners for nightly and docker builds (#134473) Follow up after adding linux arm64 ephemeral instances: https://github.com/pytorch/pytorch/pull/134469 Pull Request resolved: https://github.com/pytorch/pytorch/pull/134473 Approved by: https://github.com/malfet --- .github/actionlint.yaml | 3 +++ .../templates/linux_binary_build_workflow.yml.j2 | 2 +- .github/workflows/build-manywheel-images.yml | 6 +++--- ...ed-linux-aarch64-binary-manywheel-nightly.yml | 16 ++++++++-------- 4 files changed, 15 insertions(+), 12 deletions(-) diff --git a/.github/actionlint.yaml b/.github/actionlint.yaml index 919c4b4db7b7ed..f28b57dcf560c3 100644 --- a/.github/actionlint.yaml +++ b/.github/actionlint.yaml @@ -16,7 +16,9 @@ self-hosted-runner: - linux.24xlarge - linux.24xlarge.ephemeral - linux.arm64.2xlarge + - linux.arm64.2xlarge.ephemeral - linux.arm64.m7g.4xlarge + - linux.arm64.m7g.4xlarge.ephemeral - linux.4xlarge.nvidia.gpu - linux.8xlarge.nvidia.gpu - linux.16xlarge.nvidia.gpu @@ -40,6 +42,7 @@ self-hosted-runner: - amz2023.linux.24xlarge - amz2023.linux.arm64.2xlarge - amz2023.linux.arm64.m7g.4xlarge + - amz2023.linux.arm64.m7g.4xlarge.ephemeral - amz2023.linux.4xlarge.nvidia.gpu - amz2023.linux.8xlarge.nvidia.gpu - amz2023.linux.16xlarge.nvidia.gpu diff --git a/.github/templates/linux_binary_build_workflow.yml.j2 b/.github/templates/linux_binary_build_workflow.yml.j2 index 42cfa8eb3b07ca..1cccee3e234741 100644 --- a/.github/templates/linux_binary_build_workflow.yml.j2 +++ b/.github/templates/linux_binary_build_workflow.yml.j2 @@ -69,7 +69,7 @@ jobs: with:!{{ upload.binary_env_as_input(config) }} {%- if "aarch64" in build_environment %} runner_prefix: amz2023. - runs_on: linux.arm64.m7g.4xlarge + runs_on: linux.arm64.m7g.4xlarge.ephemeral ALPINE_IMAGE: "arm64v8/alpine" {%- elif "s390x" in build_environment %} runs_on: linux.s390x diff --git a/.github/workflows/build-manywheel-images.yml b/.github/workflows/build-manywheel-images.yml index a85d0c59335cd1..6ee9395cae675f 100644 --- a/.github/workflows/build-manywheel-images.yml +++ b/.github/workflows/build-manywheel-images.yml @@ -108,7 +108,7 @@ jobs: .ci/docker/manywheel/build.sh manylinux2_28-builder:cuda${{matrix.cuda_version}} build-docker-cuda-aarch64: environment: ${{ (github.ref == 'refs/heads/main' || startsWith(github.event.ref, 'refs/tags/v')) && 'docker-build' || '' }} - runs-on: linux.arm64.2xlarge + runs-on: linux.arm64.2xlarge.ephemeral strategy: matrix: cuda_version: ["12.4"] @@ -236,7 +236,7 @@ jobs: .ci/docker/manywheel/build.sh manylinux2_28-builder:cpu build-docker-cpu-aarch64: environment: ${{ (github.ref == 'refs/heads/main' || startsWith(github.event.ref, 'refs/tags/v')) && 'docker-build' || '' }} - runs-on: linux.arm64.2xlarge + runs-on: linux.arm64.2xlarge.ephemeral env: GPU_ARCH_TYPE: cpu-aarch64 steps: @@ -267,7 +267,7 @@ jobs: .ci/docker/manywheel/build.sh manylinuxaarch64-builder:cpu-aarch64 build-docker-cpu-aarch64-2_28: environment: ${{ (github.ref == 'refs/heads/main' || startsWith(github.event.ref, 'refs/tags/v')) && 'docker-build' || '' }} - runs-on: linux.arm64.2xlarge + runs-on: linux.arm64.2xlarge.ephemeral env: GPU_ARCH_TYPE: cpu-aarch64-2_28 steps: diff --git a/.github/workflows/generated-linux-aarch64-binary-manywheel-nightly.yml b/.github/workflows/generated-linux-aarch64-binary-manywheel-nightly.yml index f0f54b38337668..90923ffeb912c0 100644 --- a/.github/workflows/generated-linux-aarch64-binary-manywheel-nightly.yml +++ b/.github/workflows/generated-linux-aarch64-binary-manywheel-nightly.yml @@ -60,7 +60,7 @@ jobs: DOCKER_IMAGE: pytorch/manylinuxaarch64-builder:cpu-aarch64-main DESIRED_PYTHON: "3.9" runner_prefix: amz2023. - runs_on: linux.arm64.m7g.4xlarge + runs_on: linux.arm64.m7g.4xlarge.ephemeral ALPINE_IMAGE: "arm64v8/alpine" build_name: manywheel-py3_9-cpu-aarch64 build_environment: linux-aarch64-binary-manywheel @@ -129,7 +129,7 @@ jobs: DESIRED_DEVTOOLSET: cxx11-abi DESIRED_PYTHON: "3.9" runner_prefix: amz2023. - runs_on: linux.arm64.m7g.4xlarge + runs_on: linux.arm64.m7g.4xlarge.ephemeral ALPINE_IMAGE: "arm64v8/alpine" build_name: manywheel-py3_9-cuda-aarch64 build_environment: linux-aarch64-binary-manywheel @@ -175,7 +175,7 @@ jobs: DOCKER_IMAGE: pytorch/manylinuxaarch64-builder:cpu-aarch64-main DESIRED_PYTHON: "3.10" runner_prefix: amz2023. - runs_on: linux.arm64.m7g.4xlarge + runs_on: linux.arm64.m7g.4xlarge.ephemeral ALPINE_IMAGE: "arm64v8/alpine" build_name: manywheel-py3_10-cpu-aarch64 build_environment: linux-aarch64-binary-manywheel @@ -244,7 +244,7 @@ jobs: DESIRED_DEVTOOLSET: cxx11-abi DESIRED_PYTHON: "3.10" runner_prefix: amz2023. - runs_on: linux.arm64.m7g.4xlarge + runs_on: linux.arm64.m7g.4xlarge.ephemeral ALPINE_IMAGE: "arm64v8/alpine" build_name: manywheel-py3_10-cuda-aarch64 build_environment: linux-aarch64-binary-manywheel @@ -290,7 +290,7 @@ jobs: DOCKER_IMAGE: pytorch/manylinuxaarch64-builder:cpu-aarch64-main DESIRED_PYTHON: "3.11" runner_prefix: amz2023. - runs_on: linux.arm64.m7g.4xlarge + runs_on: linux.arm64.m7g.4xlarge.ephemeral ALPINE_IMAGE: "arm64v8/alpine" build_name: manywheel-py3_11-cpu-aarch64 build_environment: linux-aarch64-binary-manywheel @@ -359,7 +359,7 @@ jobs: DESIRED_DEVTOOLSET: cxx11-abi DESIRED_PYTHON: "3.11" runner_prefix: amz2023. - runs_on: linux.arm64.m7g.4xlarge + runs_on: linux.arm64.m7g.4xlarge.ephemeral ALPINE_IMAGE: "arm64v8/alpine" build_name: manywheel-py3_11-cuda-aarch64 build_environment: linux-aarch64-binary-manywheel @@ -405,7 +405,7 @@ jobs: DOCKER_IMAGE: pytorch/manylinuxaarch64-builder:cpu-aarch64-main DESIRED_PYTHON: "3.12" runner_prefix: amz2023. - runs_on: linux.arm64.m7g.4xlarge + runs_on: linux.arm64.m7g.4xlarge.ephemeral ALPINE_IMAGE: "arm64v8/alpine" build_name: manywheel-py3_12-cpu-aarch64 build_environment: linux-aarch64-binary-manywheel @@ -474,7 +474,7 @@ jobs: DESIRED_DEVTOOLSET: cxx11-abi DESIRED_PYTHON: "3.12" runner_prefix: amz2023. - runs_on: linux.arm64.m7g.4xlarge + runs_on: linux.arm64.m7g.4xlarge.ephemeral ALPINE_IMAGE: "arm64v8/alpine" build_name: manywheel-py3_12-cuda-aarch64 build_environment: linux-aarch64-binary-manywheel From 4f0c6221df52806ec30272e738ae7c6e8ab28f19 Mon Sep 17 00:00:00 2001 From: Xu Han Date: Mon, 26 Aug 2024 17:48:09 +0000 Subject: [PATCH 0027/1018] [inductor] fix _maybe_subprocess_run not support Windows path (#134365) Windows file path use `\` as delimiter, it is also a escape character. We need translate all path `\` to `/`. which like Linux. Reproduce UTs: ```cmd pytest test\dynamo\test_minifier.py -v -k test_after_dynamo_cpu_accuracy_error ``` Error message: ```cmd ____________________________________________________________________________________________________________ MinifierTests.test_after_dynamo_cpu_accuracy_error _____________________________________________________________________________________________________________ Traceback (most recent call last): File "D:\xu_git\dnnl_cb\pytorch\test\dynamo\test_minifier.py", line 40, in test_after_dynamo_cpu_accuracy_error self._test_after_dynamo( File "D:\xu_git\dnnl_cb\pytorch\test\dynamo\test_minifier.py", line 27, in _test_after_dynamo self._run_full_test(run_code, "dynamo", expected_error, isolate=False) File "C:\Users\Xuhan\.conda\envs\win_mkl_static\lib\site-packages\torch\_dynamo\test_minifier_common.py", line 235, in _run_full_test self.assertIn(expected_error, test_proc.stderr.decode("utf-8")) File "C:\Users\Xuhan\.conda\envs\win_mkl_static\lib\unittest\case.py", line 1112, in assertIn self.fail(self._formatMessage(msg, standardMsg)) File "C:\Users\Xuhan\.conda\envs\win_mkl_static\lib\unittest\case.py", line 675, in fail raise self.failureException(msg) AssertionError: 'AccuracyError' not found in 'Traceback (most recent call last):\n File "C:\\Users\\Xuhan\\.conda\\envs\\win_mkl_static\\lib\\site-packages\\torch\\_dynamo\\test_minifier_common.py", line 114, in _maybe_subprocess_run\n exec(code, {"__name__": "__main__", "__compile_source__": code})\n File "", line 9\n torch._dynamo.config.debug_dir_root = "C:\\Users\\Xuhan\\AppData\\Local\\Temp\\tmpufu9t3pc"\n ^\nSyntaxError: (unicode error) \'unicodeescape\' codec can\'t decode bytes in position 2-3: truncated \\UXXXXXXXX escape\n' To execute this test, run the following from the base repo dir: python test\dynamo\test_minifier.py MinifierTests.test_after_dynamo_cpu_accuracy_error This message can be suppressed by setting PYTORCH_PRINT_REPRO_ON_FAILURE=0 --------------------------------------------------------------------------------------------------------------------------- Captured stdout call ---------------------------------------------------------------------------------------------------------------------------- test stdout: test stderr: Traceback (most recent call last): File "C:\Users\Xuhan\.conda\envs\win_mkl_static\lib\site-packages\torch\_dynamo\test_minifier_common.py", line 114, in _maybe_subprocess_run exec(code, {"__name__": "__main__", "__compile_source__": code}) File "", line 9 torch._dynamo.config.debug_dir_root = "C:\Users\Xuhan\AppData\Local\Temp\tmpufu9t3pc" ^ SyntaxError: (unicode error) 'unicodeescape' codec can't decode bytes in position 2-3: truncated \UXXXXXXXX escape --------------------------------------------------------------------------------------------------------------------------- Captured stderr call ---------------------------------------------------------------------------------------------------------------------------- running test ``` Local test passed: image Pull Request resolved: https://github.com/pytorch/pytorch/pull/134365 Approved by: https://github.com/jansel, https://github.com/jgong5 --- torch/_dynamo/test_minifier_common.py | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/torch/_dynamo/test_minifier_common.py b/torch/_dynamo/test_minifier_common.py index 4736c75785ccb7..b05542d578f43b 100644 --- a/torch/_dynamo/test_minifier_common.py +++ b/torch/_dynamo/test_minifier_common.py @@ -15,6 +15,7 @@ import torch import torch._dynamo import torch._dynamo.test_case +from torch._dynamo.trace_rules import _as_posix_path from torch.utils._traceback import report_compile_source_on_error @@ -107,8 +108,9 @@ def _maybe_subprocess_run(self, args, *, isolate, cwd=None): log = logging.getLogger("torch._dynamo") log.addHandler(log_handler) try: - prev_cwd = os.getcwd() + prev_cwd = _as_posix_path(os.getcwd()) if cwd is not None: + cwd = _as_posix_path(cwd) os.chdir(cwd) with patch("sys.argv", args), report_compile_source_on_error(): exec(code, {"__name__": "__main__", "__compile_source__": code}) @@ -135,6 +137,8 @@ def _maybe_subprocess_run(self, args, *, isolate, cwd=None): stderr.getvalue().encode("utf-8"), ) else: + if cwd is not None: + cwd = _as_posix_path(cwd) return subprocess.run(args, capture_output=True, cwd=cwd, check=False) # Run `code` in a separate python process. @@ -157,7 +161,7 @@ def _run_test_code(self, code, *, isolate): # Runs the minifier launcher script in `repro_dir` def _run_minifier_launcher(self, repro_dir, isolate, *, minifier_args=()): self.assertIsNotNone(repro_dir) - launch_file = os.path.join(repro_dir, "minifier_launcher.py") + launch_file = _as_posix_path(os.path.join(repro_dir, "minifier_launcher.py")) with open(launch_file) as f: launch_code = f.read() self.assertTrue(os.path.exists(launch_file)) @@ -176,7 +180,7 @@ def _run_minifier_launcher(self, repro_dir, isolate, *, minifier_args=()): # Runs the repro script in `repro_dir` def _run_repro(self, repro_dir, *, isolate=True): self.assertIsNotNone(repro_dir) - repro_file = os.path.join(repro_dir, "repro.py") + repro_file = _as_posix_path(os.path.join(repro_dir, "repro.py")) with open(repro_file) as f: repro_code = f.read() self.assertTrue(os.path.exists(repro_file)) @@ -196,11 +200,11 @@ def _gen_test_code(self, run_code, repro_after, repro_level): return f"""\ import torch import torch._dynamo -{torch._dynamo.config.codegen_config()} -{torch._inductor.config.codegen_config()} +{_as_posix_path(torch._dynamo.config.codegen_config())} +{_as_posix_path(torch._inductor.config.codegen_config())} torch._dynamo.config.repro_after = "{repro_after}" torch._dynamo.config.repro_level = {repro_level} -torch._dynamo.config.debug_dir_root = "{self.DEBUG_DIR}" +torch._dynamo.config.debug_dir_root = "{_as_posix_path(self.DEBUG_DIR)}" {run_code} """ From e630a451fbc3748988ed03a1575c4bbfec7dfe45 Mon Sep 17 00:00:00 2001 From: Xu Han Date: Mon, 26 Aug 2024 18:19:59 +0000 Subject: [PATCH 0028/1018] [inductor] enable clang for Windows inductor (#134444) Changes: 1. Add Windows clang-cl compiler check. 2. Add openmp config for clang-cl. 3. Preload libomp.dll when use clang. 4. Add compiler flags syntax check for `clang` and `clang++`. Pull Request resolved: https://github.com/pytorch/pytorch/pull/134444 Approved by: https://github.com/jgong5, https://github.com/jansel, https://github.com/malfet --- torch/_inductor/cpp_builder.py | 40 +++++++++++++++++++++++++++++----- 1 file changed, 35 insertions(+), 5 deletions(-) diff --git a/torch/_inductor/cpp_builder.py b/torch/_inductor/cpp_builder.py index 3ade5cf558342e..72348824bd75a1 100644 --- a/torch/_inductor/cpp_builder.py +++ b/torch/_inductor/cpp_builder.py @@ -15,6 +15,7 @@ import sys import sysconfig import warnings +from ctypes import cdll from pathlib import Path from typing import Any, List, Optional, Sequence, Tuple, Union @@ -131,6 +132,9 @@ def check_compiler_exist_windows(compiler: str) -> None: ) except FileNotFoundError as exc: raise RuntimeError(f"Compiler: {compiler} is not found.") from exc + except subprocess.SubprocessError: + # Expected that some compiler(clang, clang++) is exist, but they not support `/help` args. + pass def get_cpp_compiler() -> str: @@ -158,6 +162,13 @@ def _is_clang(cpp_compiler: str) -> bool: # Mac OS apple clang maybe named as gcc, need check compiler info. if sys.platform == "darwin": return _is_apple_clang(cpp_compiler) + elif _IS_WINDOWS: + # clang suite have many compilers, and only clang-cl is supported. + if re.search(r"((clang$)|(clang\+\+$))", cpp_compiler): + raise RuntimeError( + "Please use clang-cl, due to torch.compile only support MSVC-like CLI (compiler flags syntax)." + ) + return bool(re.search(r"(clang-cl)", cpp_compiler)) return bool(re.search(r"(clang|clang\+\+)", cpp_compiler)) @@ -761,6 +772,20 @@ def homebrew_libomp() -> Tuple[bool, str]: return False, "" +@functools.lru_cache(None) +def perload_clang_libomp_win(cpp_compiler: str, omp_name: str) -> None: + try: + output = subprocess.check_output([cpp_compiler, "-print-file-name=bin"]).decode( + "utf8" + ) + omp_path = os.path.join(output.rstrip(), omp_name) + if os.path.isfile(omp_path): + os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE" + omp_module = cdll.LoadLibrary(omp_path) + except subprocess.SubprocessError: + pass + + def _get_openmp_args( cpp_compiler: str, ) -> Tuple[List[str], List[str], List[str], List[str], List[str], List[str]]: @@ -817,11 +842,16 @@ def _get_openmp_args( # if openmp is still not available, we let the compiler to have a try, # and raise error together with instructions at compilation error later elif _IS_WINDOWS: - # /openmp, /openmp:llvm - # llvm on Windows, new openmp: https://devblogs.microsoft.com/cppblog/msvc-openmp-update/ - # msvc openmp: https://learn.microsoft.com/zh-cn/cpp/build/reference/openmp-enable-openmp-2-0-support?view=msvc-170 - cflags.append("openmp") - cflags.append("openmp:experimental") # MSVC CL + if _is_clang(cpp_compiler): + cflags.append("openmp") + libs.append("libomp") + perload_clang_libomp_win(cpp_compiler, "libomp.dll") + else: + # /openmp, /openmp:llvm + # llvm on Windows, new openmp: https://devblogs.microsoft.com/cppblog/msvc-openmp-update/ + # msvc openmp: https://learn.microsoft.com/zh-cn/cpp/build/reference/openmp-enable-openmp-2-0-support?view=msvc-170 + cflags.append("openmp") + cflags.append("openmp:experimental") # MSVC CL else: if config.is_fbcode(): include_dir_paths.append(build_paths.openmp()) From 65a2d9c4c3ebd15b299467ffb9082aec4f1be0fe Mon Sep 17 00:00:00 2001 From: Animesh Jain Date: Sun, 25 Aug 2024 22:22:05 -0700 Subject: [PATCH 0029/1018] [dynamo][super] Improve handling of getattr on super (#134039) Pull Request resolved: https://github.com/pytorch/pytorch/pull/134039 Approved by: https://github.com/yanboliang, https://github.com/jansel --- test/dynamo/test_repros.py | 49 ++++++++++++++++++++++++++++++ torch/_dynamo/variables/misc.py | 53 +++++++++++++++++++++------------ 2 files changed, 83 insertions(+), 19 deletions(-) diff --git a/test/dynamo/test_repros.py b/test/dynamo/test_repros.py index 0363ac4b9e2ed9..21f80c4a4fb4be 100644 --- a/test/dynamo/test_repros.py +++ b/test/dynamo/test_repros.py @@ -4922,6 +4922,55 @@ def fn(obj): compiled_str = str(e) self.assertEqual(orig_str, compiled_str) + def test_super_staticmethod(self): + class Parent: + @staticmethod + def greet(): + return 5 + + class Child(Parent): + @staticmethod + def greet(x): + return x * super(Child, Child).greet() + + child = Child() + + def fn(x): + return child.greet(x) + + opt_fn = torch.compile(fn, backend="eager", fullgraph=True) + x = torch.ones(4) + ref = fn(x) + res = opt_fn(x) + self.assertEqual(ref, res) + + def test_super_diamond(self): + class A: + def __init__(self): + super().__init__() + self.a = 5 + + class Nothing: + pass + + class B(Nothing, A): + def __init__(self): + super().__init__() + self.b = 10 + + def run(self, x): + return self.a * self.b * x + + def fn(x): + b = B() + return b.run(x) + + opt_fn = torch.compile(fn, backend="eager", fullgraph=True) + x = torch.randn(4) + ref = fn(x) + res = opt_fn(x) + self.assertEqual(ref, res) + def test_vc_bumped_in_inference_graph(self): @torch.compile def f(x): diff --git a/torch/_dynamo/variables/misc.py b/torch/_dynamo/variables/misc.py index 7dfbec067914c4..b2f8b9ea6cde2f 100644 --- a/torch/_dynamo/variables/misc.py +++ b/torch/_dynamo/variables/misc.py @@ -49,6 +49,10 @@ } +class NO_SUCH_SUBOBJ: + pass + + class SuperVariable(VariableTracker): _nonvar_fields = { "specialized", @@ -99,22 +103,28 @@ def _resolved_getattr_and_source(self, tx: "InstructionTranslator", name): type_to_use_source = self.objvar.source source = None - if self.objvar.source is not None: - # Walk the mro tuple to find out the actual class where the - # attribute resides. - search_mro = type_to_use.__mro__ - start_index = search_mro.index(search_type) + 1 - for index in range(start_index, len(search_mro)): - if hasattr(search_mro[index], name): + resolved_class = None + resolved_attr = None + search_mro = type_to_use.__mro__ + start_index = search_mro.index(search_type) + 1 + # Implemented based on https://github.com/python/cpython/blob/3.11/Objects/typeobject.c#L8812 + # super has its getattro implementation. The key point is that instead of calling getattr, it checks the + # attribute in the class __dict__ + for index in range(start_index, len(search_mro)): + # Dont call getattr, just check the __dict__ of the class + if resolved_getattr := search_mro[index].__dict__.get(name, NO_SUCH_SUBOBJ): + if resolved_getattr is not NO_SUCH_SUBOBJ: # Equivalent of something like type(L['self']).__mro__[1].attr_name - source = AttrSource( - GetItemSource(AttrSource(type_to_use_source, "__mro__"), index), - name, - ) - break + if type_to_use_source: + source = AttrSource( + GetItemSource( + AttrSource(type_to_use_source, "__mro__"), index + ), + name, + ) + return resolved_getattr, source - # TODO(jansel): there is a small chance this could trigger user code, prevent that - return getattr(super(search_type, type_to_use), name), source + unimplemented("Unable to resolve super getattr") def var_getattr(self, tx: "InstructionTranslator", name: str) -> "VariableTracker": # Check if getattr is a constant. If not, delay the actual work by @@ -162,12 +172,17 @@ def call_method( return tx.output.side_effects.track_object_new_from_user_defined_class( self.objvar ) - elif name == "__new__" and isinstance(inner_fn, types.FunctionType): - # __new__ is a staticmethod object, but accessing __new__ from the super object, as done in - # _resolved_getattr_and_source, results in a function object. If not specialized here, it will try to add - # the `self` arg and fail bind arg matching later. + elif isinstance(inner_fn, staticmethod) and isinstance( + inner_fn.__func__, types.FunctionType + ): return variables.UserFunctionVariable( - inner_fn, source=source + inner_fn.__func__, source=source + ).call_function(tx, args, kwargs) + elif isinstance(inner_fn, classmethod) and isinstance( + inner_fn.__func__, types.FunctionType + ): + return variables.UserMethodVariable( + inner_fn.__func__, self.objvar, source=source ).call_function(tx, args, kwargs) elif isinstance(inner_fn, types.FunctionType): return variables.UserFunctionVariable( From ccc071933a9c994a16d3f8de689683ed1384f93f Mon Sep 17 00:00:00 2001 From: Wuxun Zhang Date: Mon, 26 Aug 2024 18:21:20 +0000 Subject: [PATCH 0030/1018] [HOO] add hints_wrapper to support passing context hints (#132860) Fixes #126393 The implementation code is based on feedback here (https://github.com/pytorch/pytorch/pull/121639#issuecomment-2223948842). Hints are passed as kwargs of hints_wrapper op. It also supports nested hints. Pull Request resolved: https://github.com/pytorch/pytorch/pull/132860 Approved by: https://github.com/ydwu4, https://github.com/zou3519 --- test/dynamo/test_higher_order_ops.py | 134 +++++++++++++++++ test/export/test_export.py | 65 +++++++++ torch/_dynamo/variables/higher_order_ops.py | 83 ++++++++++- torch/_higher_order_ops/__init__.py | 2 + torch/_higher_order_ops/hints_wrap.py | 151 ++++++++++++++++++++ torch/testing/_internal/hop_db.py | 1 + 6 files changed, 435 insertions(+), 1 deletion(-) create mode 100644 torch/_higher_order_ops/hints_wrap.py diff --git a/test/dynamo/test_higher_order_ops.py b/test/dynamo/test_higher_order_ops.py index 2ec9b7f5b681ad..4c6cf74421be87 100644 --- a/test/dynamo/test_higher_order_ops.py +++ b/test/dynamo/test_higher_order_ops.py @@ -23,6 +23,7 @@ normalize_gm, ) from torch._dynamo.utils import counters, ifdynstaticdefault +from torch._higher_order_ops.hints_wrap import hints_wrapper from torch._higher_order_ops.wrap import wrap from torch.testing._internal.common_utils import ( munge_exc, @@ -2502,6 +2503,139 @@ def fn(pred, pytree_in): ): torch.compile(fn, backend="eager")(pred, pytree_in) + def test_hints_wrapper(self): + def ref_fn(x, y): + x = x + y + x = torch.relu(x) + x = x + y + return torch.abs(x) + + def fn_with_hints(x, y): + x = x + y + + def inner_body_fn(x, y): + x = torch.relu(x) + x = x + y + return x + + def outer_body_fn(x, y): + x = hints_wrapper(inner_body_fn, (x, y), {}, hints={"inner_body": True}) + x = torch.abs(x) + return x + + res = hints_wrapper(outer_body_fn, (x, y), {}, hints={"outer_body": True}) + return res + + backend = EagerAndRecordGraphs() + cnt = CompileCounterWithBackend(backend) + + x = torch.randn(2, 4) + y = torch.ones(4) + + eager_res = fn_with_hints(x, y) + compiled_res = torch.compile(fn_with_hints, backend=cnt)(x, y) + ref_res = ref_fn(x, y) + self.assertEqual(eager_res, ref_res) + self.assertEqual(compiled_res, ref_res) + self.assertEqual(len(cnt.graphs), 1) + + # Dynamic shapes produce a slightly different graph. + if check_dynamic_shape_capture(): + return + + graph = backend.graphs[0] + self.assertExpectedInline( + normalize_gm(graph.print_readable(print_output=False)), + """\ +class GraphModule(torch.nn.Module): + def forward(self, L_x_: "f32[2, 4]", L_y_: "f32[4]"): + l_x_ = L_x_ + l_y_ = L_y_ + + x: "f32[2, 4]" = l_x_ + l_y_; l_x_ = None + + hints_wrapper_body_1 = self.hints_wrapper_body_1 + hints_wrapper = torch.ops.higher_order.hints_wrapper(hints_wrapper_body_1, (x, l_y_), {}, hints = {'outer_body': True}); hints_wrapper_body_1 = x = l_y_ = None + res: "f32[2, 4]" = hints_wrapper[0]; hints_wrapper = None + return (res,) + + class hints_wrapper_body_1(torch.nn.Module): + def forward(self, x: "f32[2, 4]", l_y_: "f32[4]"): + hints_wrapper_body_0 = self.hints_wrapper_body_0 + hints_wrapper = torch.ops.higher_order.hints_wrapper(hints_wrapper_body_0, (x, l_y_), {}, hints = {'inner_body': True}); hints_wrapper_body_0 = x = l_y_ = None + x_1: "f32[2, 4]" = hints_wrapper[0]; hints_wrapper = None + + x_2: "f32[2, 4]" = torch.abs(x_1); x_1 = None + return (x_2,) + + class hints_wrapper_body_0(torch.nn.Module): + def forward(self, x: "f32[2, 4]", l_y_: "f32[4]"): + x_1: "f32[2, 4]" = torch.relu(x); x = None + + x_2: "f32[2, 4]" = x_1 + l_y_; x_1 = l_y_ = None + return (x_2,) +""", + ) + + def test_hints_wrapper_no_hints(self): + def fn_with_hints(x, y): + def outer_body_fn(x, y): + x = torch.add(x, y) + return x + + res = hints_wrapper(outer_body_fn, (x, y), {}) + return res + + backend = EagerAndRecordGraphs() + cnt = CompileCounterWithBackend(backend) + + x = torch.randn(2, 4) + y = torch.ones(4) + + msg = "hints_wrapper - key hints not provided" + with self.assertRaisesRegex(RuntimeError, msg): + compiled_res = torch.compile(fn_with_hints, backend=cnt)(x, y) + + def test_hints_wrapper_incorrect_type(self): + def fn_with_hints(x, y): + def outer_body_fn(x, y): + x = torch.add(x, y) + return x + + res = hints_wrapper(outer_body_fn, (x, y), {}, hints={"test": (True,)}) + return res + + backend = EagerAndRecordGraphs() + cnt = CompileCounterWithBackend(backend) + + x = torch.randn(2, 4) + y = torch.ones(4) + + msg = r"hints must be a dict containing int, float, bool or str value," + with self.assertRaisesRegex(RuntimeError, msg): + compiled_res = torch.compile(fn_with_hints, backend=cnt)(x, y) + + def test_hints_wrapper_pytree_inputs(self): + def fn_with_hints(x, y): + def outer_body_fn(x): + res = torch.add(x[0], x[1]["test"]) + return res + + res = hints_wrapper( + outer_body_fn, ((x, {"test": y}),), {}, hints={"test": True} + ) + return res + + backend = EagerAndRecordGraphs() + cnt = CompileCounterWithBackend(backend) + + x = torch.randn(2, 4) + y = torch.ones(4) + + msg = r"args must be a tuple of tensors, ints, floats, or bools," + with self.assertRaisesRegex(RuntimeError, msg): + fn_with_hints(x, y) + class HigherOrderOpVmapGuardTests(LoggingTestCase): @make_logging_test(recompiles=True) diff --git a/test/export/test_export.py b/test/export/test_export.py index 90422bace4c162..1a491cb8a3c37a 100644 --- a/test/export/test_export.py +++ b/test/export/test_export.py @@ -20,6 +20,7 @@ from torch import Tensor from torch._decomp import get_decompositions from torch._dynamo.test_case import TestCase +from torch._dynamo.testing import normalize_gm from torch._export.pass_base import _ExportPassBaseDeprecatedDoNotUse from torch._export.utils import ( get_buffer, @@ -28,6 +29,7 @@ is_param, register_dataclass_as_pytree_node, ) +from torch._higher_order_ops.hints_wrap import hints_wrapper from torch._inductor.compile_fx import split_const_gm from torch._subclasses import FakeTensorMode from torch.export import Dim, export, unflatten @@ -7028,6 +7030,69 @@ def forward(self, x, y): ], ) + @testing.expectedFailureNonStrict + @testing.expectedFailureTrainingIRToRunDecompNonStrict # unbacked symint not tracked? + @testing.expectedFailureSerDer # T195866111 + def test_hints_wrapper(self): + class M(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + + def forward(self, x, y): + x = x + y + + def inner_body_fn(x, y): + x = torch.relu(x) + x = x + y + return x + + def outer_body_fn(x, y): + x = hints_wrapper( + inner_body_fn, (x, y), {}, hints={"inner_body": True} + ) + x = torch.abs(x) + return x + + res = hints_wrapper( + outer_body_fn, (x, y), {}, hints={"outer_body": True} + ) + return res + + x = torch.randn(2, 4) + y = torch.ones(4) + + ep = export(M(), (x, y)) + export_res = ep.module()(x, y) + ref_res = M()(x, y) + self.assertEqual(export_res, ref_res) + self.assertExpectedInline( + normalize_gm(ep.graph_module.print_readable(print_output=False)), + """\ +class GraphModule(torch.nn.Module): + def forward(self, x: "f32[2, 4]", y: "f32[4]"): + add: "f32[2, 4]" = torch.ops.aten.add.Tensor(x, y); x = None + + hints_wrapper_body_graph_0 = self.hints_wrapper_body_graph_0 + hints_wrapper = torch.ops.higher_order.hints_wrapper(hints_wrapper_body_graph_0, (add, y), {}, hints = {'outer_body': True}); hints_wrapper_body_graph_0 = add = y = None + getitem: "f32[2, 4]" = hints_wrapper[0]; hints_wrapper = None + return (getitem,) + + class hints_wrapper_body_graph_0(torch.nn.Module): + def forward(self, arg0_1: "f32[2, 4]", arg1_1: "f32[4]"): + hints_wrapper_body_graph_0 = self.hints_wrapper_body_graph_0 + hints_wrapper = torch.ops.higher_order.hints_wrapper(hints_wrapper_body_graph_0, (arg0_1, arg1_1), {}, hints = {'inner_body': True}); hints_wrapper_body_graph_0 = arg0_1 = arg1_1 = None + getitem: "f32[2, 4]" = hints_wrapper[0]; hints_wrapper = None + abs_1: "f32[2, 4]" = torch.ops.aten.abs.default(getitem); getitem = None + return (abs_1,) + + class hints_wrapper_body_graph_0(torch.nn.Module): + def forward(self, arg0_1: "f32[2, 4]", arg1_1: "f32[4]"): + relu: "f32[2, 4]" = torch.ops.aten.relu.default(arg0_1); arg0_1 = None + add: "f32[2, 4]" = torch.ops.aten.add.Tensor(relu, arg1_1); relu = arg1_1 = None + return (add,) +""", + ) + @unittest.skipIf(not torchdynamo.is_dynamo_supported(), "dynamo isn't support") class TestOneOffModelExportResult(TestCase): diff --git a/torch/_dynamo/variables/higher_order_ops.py b/torch/_dynamo/variables/higher_order_ops.py index 7aeb279dfb9704..851c2a05af1ef7 100644 --- a/torch/_dynamo/variables/higher_order_ops.py +++ b/torch/_dynamo/variables/higher_order_ops.py @@ -24,7 +24,12 @@ from torch.utils import _pytree as pytree from .. import variables -from ..exc import UncapturedHigherOrderOpError, unimplemented, Unsupported +from ..exc import ( + IncorrectUsage, + UncapturedHigherOrderOpError, + unimplemented, + Unsupported, +) from ..source import AttrSource from ..utils import proxy_args_kwargs from .dicts import ConstDictVariable @@ -578,6 +583,8 @@ def make(value, source=None, **kwargs): return OutDtypeHigherOrderVariable(value, source, **kwargs) elif value.__name__ == "wrap": return WrapHigherOrderVariable(value, source, **kwargs) + elif value.__name__ == "hints_wrapper": + return HintsWrapperHigherOrderVariable(value, source, **kwargs) elif value.__name__ == "flex_attention": return FlexAttentionHigherOrderVariable(value, source, **kwargs) elif value.__name__ in ( @@ -1431,6 +1438,80 @@ def call_function( ) +class HintsWrapperHigherOrderVariable(TorchHigherOrderOperatorVariable): + @raise_hard_error_if_graph_break( + reason="Hints_wrapper doesn't work unless it is captured completely with torch.compile." + ) + def call_function( + self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]" + ) -> "VariableTracker": + _check_supported_callable_arg(tx, args[0], "body_fn") + + # inputs + if len(args) != 3: + unimplemented( + f"Expected 3 arguments but got {len(args)}.\n" + f"Usage: hints_wrapper(body_fn, args, kwargs, hints).\n" + f"kwargs required to be provided explicitly." + ) + + if not isinstance(args[1], (ListVariable, TupleVariable)): + unimplemented( + f"Expected a tuple but got {args[1].python_type()}", + ) + operands = args[1].unpack_var_sequence(tx) + + if not isinstance(args[2], ConstDictVariable): + unimplemented( + f"Expected a dict but got {args[2].python_type()}", + ) + + if "hints" not in kwargs: + raise IncorrectUsage("hints_wrapper - key hints not provided") + + ( + (body_r, treespec), + body_graph, + body_lifted_freevars, + ) = speculate_subgraph( + tx, + args[0], # function + operands, + args[2].as_python_constant(), + "hints_wrapper", + source_target=self.value, + should_flatten_outputs=True, + ) + + body_gmod = torch.fx.GraphModule(tx.output.nn_modules, body_graph) + body_name = add_subgraph( + tx, + "hints_wrapper_body", + body_gmod, + ) + + body_node = make_attr(tx, body_name) + + # Since, we call `speculate_subgraph` with `set_subgraph_inputs="automatic`, + # all the arguments are lifted. + lifted_args = tuple(arg for arg in body_lifted_freevars.keys()) + p_args = (body_node, lifted_args, {}) + + p_kwargs = {} + # add hints into p_kwargs + p_kwargs["hints"] = kwargs["hints"].as_python_constant() + + flat_example_value = pytree.tree_map_only( + torch.fx.Proxy, + lambda a: a.node.meta["example_value"], + body_r.as_proxy(), + ) + + return _call_function_and_unflatten_output( + tx, self.value, p_args, p_kwargs, flat_example_value, treespec + ) + + class OutDtypeHigherOrderVariable(TorchHigherOrderOperatorVariable): def call_function( self, diff --git a/torch/_higher_order_ops/__init__.py b/torch/_higher_order_ops/__init__.py index 5bfca7aa679583..72800cae7fc98c 100644 --- a/torch/_higher_order_ops/__init__.py +++ b/torch/_higher_order_ops/__init__.py @@ -3,6 +3,7 @@ flex_attention, flex_attention_backward, ) +from torch._higher_order_ops.hints_wrap import hints_wrapper from torch._higher_order_ops.while_loop import while_loop @@ -11,4 +12,5 @@ "while_loop", "flex_attention", "flex_attention_backward", + "hints_wrapper", ] diff --git a/torch/_higher_order_ops/hints_wrap.py b/torch/_higher_order_ops/hints_wrap.py new file mode 100644 index 00000000000000..c211d405614636 --- /dev/null +++ b/torch/_higher_order_ops/hints_wrap.py @@ -0,0 +1,151 @@ +# mypy: allow-untyped-defs +import torch +import torch.utils._pytree as pytree +from torch._C import DispatchKey +from torch._higher_order_ops.utils import ( + _has_potential_branch_input_alias, + _has_potential_branch_input_mutation, + autograd_not_implemented, + reenter_make_fx, + unique_graph_id, + UnsupportedAliasMutationException, +) +from torch._ops import HigherOrderOperator +from torch._subclasses.fake_tensor import FakeTensorMode +from torch.fx.experimental.proxy_tensor import ProxyTorchDispatchMode, track_tensor_tree + + +# used for wrapping a function/op with context hints +class HintsWrapper(HigherOrderOperator): + def __init__(self): + super().__init__("hints_wrapper") + + def __call__(self, body_fn, args, kwargs, hints): + r""" + Call implementation of hints_wrapper + + Args: + body_fn (Callable): A callable function that is within the scope + that is being traced. + + args (Tuple of torch.Tensor/int/float/bool): A tuple of inputs to + body_fn. + + kwargs (dict): Keyword argument to the body_fn. + + hints (dict): A dict of context hints which could be passed to + backend compiler. + """ + if not isinstance(args, tuple): + raise RuntimeError(f"args must be a tuple, got {type(args)}") + + if not all(isinstance(t, (torch.Tensor, int, float, bool)) for t in args): + raise RuntimeError( + "args must be a tuple of tensors, ints, floats, or bools, got " + f"{args}" + ) + + if not isinstance(kwargs, dict): + raise RuntimeError(f"kwargs must be a dict, got {type(kwargs)}") + + if len(kwargs) > 0: + raise RuntimeError( + f"kwargs except for hints are not supported, got {kwargs}" + ) + + if not isinstance(hints, dict): + raise RuntimeError(f"hints must be a dict, got {type(hints)}") + + for k, v in hints.items(): + if not isinstance(k, str): + raise RuntimeError(f"hints key must be a str, got {k}.") + + if not isinstance(v, (int, float, bool, str)): + raise RuntimeError( + "hints must be a dict containing int, float, bool or str " + f"value, got value {v} for key {k}." + ) + + return super().__call__(body_fn, args, kwargs, hints) + + +hints_wrapper = HintsWrapper() + + +@hints_wrapper.py_impl(DispatchKey.CompositeExplicitAutograd) +def hints_wrapper_dense(body_fn, args, kwargs, hints): + return body_fn(*args, **kwargs) + + +hints_wrapper.py_impl(DispatchKey.Autograd)( + autograd_not_implemented(hints_wrapper, deferred_error=True) +) + + +@hints_wrapper.py_impl(FakeTensorMode) +def hints_wrapper_fake_tensor_mode(mode, body_func, args, kwargs, hints): + flat_args = pytree.tree_leaves(args) + with mode: + return body_func(*flat_args, **kwargs) + + +@hints_wrapper.py_functionalize_impl +def hints_wrapper_functionalize(ctx, body_fn, args, kwargs, hints): + unwrapped_args = ctx.unwrap_tensors(args) + unwrapped_kwargs = ctx.unwrap_tensors(kwargs) + unwrapped_hints = ctx.unwrap_tensors(hints) + with ctx.redispatch_to_next(): + functional_body_fn = ctx.functionalize(body_fn) + pre_dispatch = hasattr(ctx, "mode") and ctx.mode.pre_dispatch + if _has_potential_branch_input_mutation( + functional_body_fn, unwrapped_args, pre_dispatch=pre_dispatch + ): + raise UnsupportedAliasMutationException( + "body_fn of hints_wrapper might be modifying the input!" + ) + if _has_potential_branch_input_alias( + functional_body_fn, unwrapped_args, pre_dispatch=pre_dispatch + ): + raise UnsupportedAliasMutationException( + "body_fn of hints_wrapper might be aliasing the input!" + ) + outputs = hints_wrapper( + functional_body_fn, + unwrapped_args, + unwrapped_kwargs, + unwrapped_hints, + ) + return ctx.wrap_tensors(outputs) + + +def trace_hints_wrapper(proxy_mode, hints_wrapper, body_fn, args, kwargs, hints): + flat_args = tuple(pytree.tree_leaves(args)) + body_graph = reenter_make_fx(body_fn)(*flat_args, **kwargs) + + _, body_graph_name = unique_graph_id(proxy_mode, prefix="hints_wrapper_body_graph") + proxy_mode.tracer.root.register_module(body_graph_name, body_graph) + + new_args: tuple = (body_graph, flat_args, {}) + # merge hints into kwargs + new_kwargs = {} + new_kwargs["hints"] = hints + + proxy_args = pytree.tree_map(proxy_mode.tracer.unwrap_proxy, new_args) + proxy_kwargs = pytree.tree_map(proxy_mode.tracer.unwrap_proxy, new_kwargs) + + out_proxy = proxy_mode.tracer.create_proxy( + "call_function", hints_wrapper, proxy_args, proxy_kwargs, name="hints_wrapper" + ) + + out = body_fn(*flat_args, **kwargs) + return track_tensor_tree(out, out_proxy, constant=None, tracer=proxy_mode.tracer) + + +@hints_wrapper.py_impl(ProxyTorchDispatchMode) +def inner(proxy_mode, body_fn, args, kwargs, hints): + if proxy_mode.enable_tracing: + return trace_hints_wrapper( + proxy_mode, hints_wrapper, body_fn, args, kwargs, hints + ) + else: + return hints_wrapper(body_fn, args, kwargs, hints) diff --git a/torch/testing/_internal/hop_db.py b/torch/testing/_internal/hop_db.py index 172c9efb8800ba..fa352cb5a3777e 100644 --- a/torch/testing/_internal/hop_db.py +++ b/torch/testing/_internal/hop_db.py @@ -61,6 +61,7 @@ def f2(x, y0, y1): "call_torchbind", "triton_kernel_wrapper_mutation", "triton_kernel_wrapper_functional", + "hints_wrapper", ] torch.library.define( From 3862f831fbdd905259d4f329920d1d7f46309959 Mon Sep 17 00:00:00 2001 From: Henry Tsang Date: Mon, 26 Aug 2024 18:29:40 +0000 Subject: [PATCH 0031/1018] [aoti] remove c_shim_version v1 logic (#134283) Summary: Previously, https://github.com/pytorch/pytorch/pull/132750 and https://github.com/pytorch/pytorch/pull/133105 set c_shim_version to 2 for all cases. So removing c_shim_version logic. Test Plan: ci Differential Revision: D61574695 Pull Request resolved: https://github.com/pytorch/pytorch/pull/134283 Approved by: https://github.com/desertfire --- torch/_inductor/codegen/cpp_wrapper_cpu.py | 43 +++++----------------- torch/_inductor/ir.py | 12 +----- 2 files changed, 12 insertions(+), 43 deletions(-) diff --git a/torch/_inductor/codegen/cpp_wrapper_cpu.py b/torch/_inductor/codegen/cpp_wrapper_cpu.py index 6ced8aefc50f86..be69d0429f731b 100644 --- a/torch/_inductor/codegen/cpp_wrapper_cpu.py +++ b/torch/_inductor/codegen/cpp_wrapper_cpu.py @@ -152,12 +152,9 @@ def write_header(self): ) if config.abi_compatible: - if config.c_shim_version == "1": - self.header.splice("#include ") - else: - self.header.splice( - f"#include " - ) + self.header.splice( + f"#include " + ) self.header.splice( """ #include @@ -1192,20 +1189,8 @@ def get_c_shim_func_name(self, kernel): kernel_suffix = kernel_tokens[-1] if kernel_suffix == "call": kernel_suffix = kernel_tokens[-2] - if config.c_shim_version == "1": - # For sdpa, we need the v2 version since v1 didn't consider optional arg - # FIXME: no need to do this after we switch to the torchgen-ed C shim - if kernel_suffix == "_scaled_dot_product_flash_attention": - shim_fn = "aoti_torch__scaled_dot_product_flash_attention_v2" - elif kernel_suffix == "_scaled_mm": - shim_fn = "aoti_torch__scaled_mm_v2" - elif kernel_suffix.startswith("wrapped_fbgemm"): - assert self.device == "cpu", "Using wrapped_fbgemm out of CPU!" - shim_fn = f"aoti_torch_cpu_{kernel_suffix}" - else: - shim_fn = f"aoti_torch_{kernel_suffix}" - else: - shim_fn = f"aoti_torch_{self.device}_{kernel_suffix}" + + shim_fn = f"aoti_torch_{self.device}_{kernel_suffix}" return shim_fn def generate_c_shim_extern_kernel_call(self, kernel, args): @@ -1335,19 +1320,11 @@ def generate_scatter_fallback( # No stack allocation when there is a fallback op self.allow_stack_allocation = False - # TODO: needs updates to use C shim v2 if config.abi_compatible: # call the ABI shim function instead of the ATen one - if config.c_shim_version == "1": - cpp_kernel_name = ( - "aoti_torch_scatter_reduce_out" - if python_kernel_name.startswith("aten.scatter_reduce") - else "aoti_torch_scatter_out" - ) - else: - cpp_kernel_name = self.get_c_shim_func_name(cpp_kernel_name) - # C shim only contains out-variant instead of inplace-variant - cpp_kernel_name = cpp_kernel_name.replace("__", "_") + "_out" + cpp_kernel_name = self.get_c_shim_func_name(cpp_kernel_name) + # TODO: consider remove "_out" and add missing inplace variants to fallback_ops.py + cpp_kernel_name = cpp_kernel_name.replace("__", "_") + "_out" inputs_wrapped = [ f"convert_arrayref_tensor_to_tensor({x})" if isinstance(x, str) @@ -1375,7 +1352,7 @@ def generate_index_put_fallback(self, kernel, x, indices, values, accumulate): # No stack allocation when there is a fallback op self.allow_stack_allocation = False - # TODO: needs updates to use C shim v2 + # TODO: update aoti_torch_index_put_out in ir.py to use autogen out version if config.abi_compatible: # See the comment in codegen_reinterpret_view about why having something like # RAIIAtenTensorHandle(tmp_tensor_handle_2) in a tmp array can cause the correponding @@ -2484,7 +2461,7 @@ def val_to_arg_str(self, val, type_=None) -> str: f"{self.c_type_for_prim_type(element_type)} {var_name} = {self.val_to_arg_str(val, element_type)};" ) return f"&{var_name}" - elif config.c_shim_version == "2": + else: # type_ is Optional[Tensor] # Similar to other data type, use pointer to denote optional tensor arg in v2 C shim base_handle = self.val_to_arg_str(val, element_type) diff --git a/torch/_inductor/ir.py b/torch/_inductor/ir.py index a5a50667b38009..7d08b62f51b143 100644 --- a/torch/_inductor/ir.py +++ b/torch/_inductor/ir.py @@ -5914,17 +5914,9 @@ def codegen(self, wrapper): if V.graph.cpp_wrapper: from torchgen.aoti.fallback_ops import inductor_fallback_ops - if ( - config.is_fbcode() - and kernel not in has_c_shim + if config.abi_compatible and str(kernel) not in inductor_fallback_ops: # C shim v2 is torchgen-ed, which should cover all aten ops. - # If you do hit a missed op, please update gen_aoti_c_shim.py. - and config.c_shim_version == "1" - ) or ( - config.abi_compatible - and config.c_shim_version == "2" - and str(kernel) not in inductor_fallback_ops - ): + # If you do hit a missed op, please update fallback_ops.py. log.warning( "%s is missing a c-shim implementation, using proxy executor as fallback", kernel, From 1c1c5be4da4e552a1cb84d9b240ee1ed28c143bc Mon Sep 17 00:00:00 2001 From: eqy Date: Mon, 26 Aug 2024 18:35:55 +0000 Subject: [PATCH 0032/1018] Support larger page sizes with `use_mmap_weights` (#131000) Fixes e.g., `test_large_mmaped_weights_non_abi_compatible_cuda` on machines with 64K page size CC @malfet @tinglvv @nWEIdia Pull Request resolved: https://github.com/pytorch/pytorch/pull/131000 Approved by: https://github.com/malfet --- torch/_inductor/codecache.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/torch/_inductor/codecache.py b/torch/_inductor/codecache.py index 59e0e8ba9daf5f..66ab21078dfe6e 100644 --- a/torch/_inductor/codecache.py +++ b/torch/_inductor/codecache.py @@ -1904,10 +1904,14 @@ def _pad_to_alignment(raw_bytes: bytes) -> bytes: run_command_and_check(link_cmd) if use_mmap_weights: + import resource + page_size_ = resource.getpagesize() + page_size = max(16384, page_size_) + with open(output_so, "a+b") as f_so: so_size = f_so.tell() # Page align the weights - f_so.write(b" " * (16384 - so_size % 16384)) + f_so.write(b" " * (page_size - so_size % page_size)) f_so.write(serialized_weights) f_so.write(struct.pack("q", magic_number)) From fd6b20fc4cae3a9dfa2ddacc8fdf4846a955c386 Mon Sep 17 00:00:00 2001 From: Mengwei Liu Date: Mon, 26 Aug 2024 19:26:08 +0000 Subject: [PATCH 0033/1018] Add mean.dtype_out (#133506) Give it a try and see if CI is happy Pull Request resolved: https://github.com/pytorch/pytorch/pull/133506 Approved by: https://github.com/bdhirsh --- aten/src/ATen/native/ReduceOps.cpp | 9 +++++++++ aten/src/ATen/native/native_functions.yaml | 9 ++++----- .../expect/HasDecompTest.test_aten_core_operators.expect | 1 + torch/testing/_internal/common_methods_invocations.py | 5 ++++- 4 files changed, 18 insertions(+), 6 deletions(-) diff --git a/aten/src/ATen/native/ReduceOps.cpp b/aten/src/ATen/native/ReduceOps.cpp index c73ed0cce0562b..7a3e34eb6df591 100644 --- a/aten/src/ATen/native/ReduceOps.cpp +++ b/aten/src/ATen/native/ReduceOps.cpp @@ -1414,6 +1414,15 @@ Tensor& mean_out(const Tensor& self, DimnameList dim, return at::mean_out(result, self, dimnames_to_positions(self, dim), keepdim, opt_dtype); } +Tensor& mean_dtype_out(const Tensor &self, std::optional dtype, Tensor& result) { + TORCH_CHECK( + canCast(self.scalar_type(), result.scalar_type()), + "mean.dtype_out(): input types can't be cast to the desired output type ", + result.scalar_type()); + // at::mean_out should make sure dtype and result.scalar_type() are the same + return at::mean_out(result, self, IntArrayRef{}, false, dtype); +} + // TODO(@heitorschueroff) implement custom kernels for nanmean Tensor& nanmean_out( const Tensor& self, diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index b8168d1af4347b..e46cdec768cdff 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -3920,11 +3920,10 @@ tags: core # For normal naming convention this should be `mean.out`. However since we already have `mean.out` we have to rename this. -# FIXME: fix CI jobs and re-enable this -#- func: mean.dtype_out(Tensor self, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!) -# device_check: NoCheck # TensorIterator -# dispatch: -# CompositeExplicitAutograd: mean_dtype_out +- func: mean.dtype_out(Tensor self, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!) + device_check: NoCheck # TensorIterator + dispatch: + CompositeExplicitAutograd: mean_dtype_out - func: mean.dim(Tensor self, int[1]? dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor structured_delegate: mean.out diff --git a/test/expect/HasDecompTest.test_aten_core_operators.expect b/test/expect/HasDecompTest.test_aten_core_operators.expect index 049ce10c9a123f..bf7ad0a4659cc2 100644 --- a/test/expect/HasDecompTest.test_aten_core_operators.expect +++ b/test/expect/HasDecompTest.test_aten_core_operators.expect @@ -350,6 +350,7 @@ aten::maximum aten::maximum.out aten::mean aten::mean.dim +aten::mean.dtype_out aten::mean.out aten::minimum aten::minimum.out diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py index 59f7f0e9462038..961d4024d61962 100644 --- a/torch/testing/_internal/common_methods_invocations.py +++ b/torch/testing/_internal/common_methods_invocations.py @@ -20929,7 +20929,7 @@ def sample_inputs_alias_copy(op_info, device, dtype, requires_grad, **kwargs): # FIXME: mean needs 'dim' parameter when using the 'out' overload. # Adding it with 'generate_args_kwargs' does not work, since these also get passed # onto the reference implementations. - supports_out=False, + supports_out=True, assert_autodiffed=True, assert_jit_shape_analysis=True, promotes_int_to_float=True, @@ -20937,6 +20937,9 @@ def sample_inputs_alias_copy(op_info, device, dtype, requires_grad, **kwargs): ref=reference_reduction_numpy(np.mean), error_inputs_func=error_inputs_mean, skips=( + # AssertionError: RuntimeError not raised : Expected RuntimeError when doing an unsafe cast from a result + # of dtype torch.float32 into an out= with dtype torch.long + DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_out', device_type='cuda', dtypes=[torch.float32]), # FIXME: mean does not support passing keepdim without passing dim DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_default_keepdim'), # FIXME: mean reduces all dimensions when dim=[] From 2f2b7d85b12c980945b68d76f22472560028e0ab Mon Sep 17 00:00:00 2001 From: albanD Date: Mon, 26 Aug 2024 19:39:06 +0000 Subject: [PATCH 0034/1018] Move module_tracker to logging for confused hierarchy (#134467) Fixes https://github.com/pytorch/pytorch/issues/134242 Make sure to never raise an error when confused. Logs for confusion can be enabled with `TORCH_LOGS="torch.utils.module_tracker"` or the usual python systems. Pull Request resolved: https://github.com/pytorch/pytorch/pull/134467 Approved by: https://github.com/malfet --- test/test_module_tracker.py | 40 ++++++++++++++++++++++++++++++----- torch/autograd/graph.py | 3 ++- torch/utils/module_tracker.py | 21 +++++++++++------- 3 files changed, 50 insertions(+), 14 deletions(-) diff --git a/test/test_module_tracker.py b/test/test_module_tracker.py index abbaaed4491a74..457b9648d73c95 100644 --- a/test/test_module_tracker.py +++ b/test/test_module_tracker.py @@ -3,7 +3,9 @@ from copy import copy import torch +from torch import nn from torch.testing._internal.common_utils import run_tests, TestCase, xfailIfTorchDynamo +from torch.utils.checkpoint import checkpoint from torch.utils.module_tracker import ModuleTracker @@ -14,7 +16,7 @@ def test_module_hierarchy(self): seen_fw = [] seen_bw = [] - class Foo(torch.nn.Module): + class Foo(nn.Module): def forward(self, x): x = x["a"].relu_() seen_fw.append((copy(tracker.parents), tracker.is_bw)) @@ -23,12 +25,12 @@ def forward(self, x): ) return {"a": torch.mm(x, x)} - class Mod(torch.nn.Module): + class Mod(nn.Module): def __init__(self) -> None: super().__init__() self.a = Foo() - self.b = torch.nn.ModuleDict({"nest": Foo()}) - self.c = torch.nn.ModuleList([Foo()]) + self.b = nn.ModuleDict({"nest": Foo()}) + self.c = nn.ModuleList([Foo()]) def forward(self, x): x = self.c[0](x) @@ -68,8 +70,36 @@ def forward(self, x): ], ) + def test_confused_hierarchy(self): + class MyMod(nn.Module): + def __init__(self): + super().__init__() + self.inner = nn.Linear(2, 2) + self.ran = False + + def forward(self, inp): + if not self.ran: + self.ran = True + return self(inp) + else: + self.ran = False + return self.inner(inp) + + mod = MyMod() + inp = torch.rand(1, 2, requires_grad=True) + + # Should not fail + with ModuleTracker() as tracker: + res = mod(inp) + res.sum().backward() + + # Should not fail + with ModuleTracker() as tracker: + res = checkpoint(lambda inp: mod(inp), inp) + res.sum().backward() + def test_bw_detection(self): - mod = torch.nn.Linear(2, 2) + mod = nn.Linear(2, 2) with ModuleTracker() as tracker: mod(torch.rand(2, requires_grad=True)).sum().backward() diff --git a/torch/autograd/graph.py b/torch/autograd/graph.py index aad31e223b174b..6fbf797d26889f 100644 --- a/torch/autograd/graph.py +++ b/torch/autograd/graph.py @@ -179,7 +179,8 @@ def _get_grad_fn_or_grad_acc(t: Union[torch.Tensor, "GradientEdge"]) -> Node: if isinstance(t, GradientEdge): return t.node if t.requires_grad and t.grad_fn is None: - node = t.view_as(t).grad_fn.next_functions[0][0] # type: ignore[union-attr] + with torch.enable_grad(): + node = t.view_as(t).grad_fn.next_functions[0][0] # type: ignore[union-attr] else: node = t.grad_fn assert node is not None diff --git a/torch/utils/module_tracker.py b/torch/utils/module_tracker.py index 91958127b03be8..63a6b817c42462 100644 --- a/torch/utils/module_tracker.py +++ b/torch/utils/module_tracker.py @@ -1,4 +1,5 @@ # mypy: allow-untyped-defs +import logging import weakref from typing import Set @@ -11,6 +12,9 @@ from torch.utils._pytree import tree_flatten +logger = logging.getLogger(__name__) + + __all__ = ["ModuleTracker"] @@ -93,9 +97,10 @@ def fn(*args): if is_bw: self._maybe_set_engine_callback() if name in self.parents: - print( - "The module hierarchy tracking seems to be messed up." - "Please file a bug to PyTorch." + logger.info( + "The module hierarchy tracking seems to be broken as this Module was already entered. %s during %s", + name, + "backward" if is_bw else "forward", ) self.parents.add(name) @@ -105,11 +110,11 @@ def _get_pop_fn(self, name, is_bw): def fn(*args): if name in self.parents: self.parents.remove(name) - elif not is_bw: - # Due to some input/output not requiring gradients, we cannot enforce - # proper nesting in backward - raise RuntimeError( - "The Module hierarchy tracking is wrong. Report a bug to PyTorch" + else: + logger.info( + "The Module hierarchy tracking is confused as we're exiting a Module that was never entered. %s during %s", + name, + "backward" if is_bw else "forward", ) return fn From 02a08727a523fc7618da6ac128590453b1daca5f Mon Sep 17 00:00:00 2001 From: Nikita Shulga <2453524+malfet@users.noreply.github.com> Date: Mon, 26 Aug 2024 20:13:21 +0000 Subject: [PATCH 0035/1018] Fix lint failures (#134488) Introduced by https://github.com/pytorch/pytorch/pull/131000 Fixes #ISSUE_NUMBER Pull Request resolved: https://github.com/pytorch/pytorch/pull/134488 Approved by: https://github.com/Skylion007, https://github.com/msaroufim, https://github.com/albanD, https://github.com/atalman --- torch/_inductor/codecache.py | 1 + 1 file changed, 1 insertion(+) diff --git a/torch/_inductor/codecache.py b/torch/_inductor/codecache.py index 66ab21078dfe6e..53748266b7a9c4 100644 --- a/torch/_inductor/codecache.py +++ b/torch/_inductor/codecache.py @@ -1905,6 +1905,7 @@ def _pad_to_alignment(raw_bytes: bytes) -> bytes: if use_mmap_weights: import resource + page_size_ = resource.getpagesize() page_size = max(16384, page_size_) From bc767c987b8e6abc5fb6f2b08c3cb2c437a74030 Mon Sep 17 00:00:00 2001 From: Animesh Jain Date: Mon, 26 Aug 2024 10:14:49 -0700 Subject: [PATCH 0036/1018] [dynamo][guards] De-dupe DUPLICATE_INPUT guard (#134354) Hard to write a test. Pull Request resolved: https://github.com/pytorch/pytorch/pull/134354 Approved by: https://github.com/jansel --- torch/_dynamo/guards.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/torch/_dynamo/guards.py b/torch/_dynamo/guards.py index d61ec7f188c98e..d9cc7eb6d44d16 100644 --- a/torch/_dynamo/guards.py +++ b/torch/_dynamo/guards.py @@ -575,6 +575,8 @@ def __init__( str, torch._C._dynamo.guards.GuardManager ] = {} + self._cached_duplicate_input_guards: Set[Tuple[str, str]] = set() + def guard_on_dict_keys_and_ignore_order(self, example_value, guard): dict_mgr = self.get_guard_manager(guard) if isinstance(dict_mgr, DictGuardManager): @@ -1663,6 +1665,13 @@ def DUPLICATE_INPUT(self, guard, source_b): self._set_guard_export_info(guard, code) if config.enable_cpp_guard_manager: + # Check that the guard has not been inserted already + key = (ref_a, ref_b) + if key in self._cached_duplicate_input_guards: + return + self._cached_duplicate_input_guards.add((ref_a, ref_b)) + self._cached_duplicate_input_guards.add((ref_b, ref_a)) + install_object_aliasing_guard( self.get_guard_manager(guard), self.get_guard_manager_from_source(source_b), From 3356366c056c018082e7daebd04082ab9dccfbaf Mon Sep 17 00:00:00 2001 From: Xu Han Date: Mon, 26 Aug 2024 20:57:38 +0000 Subject: [PATCH 0037/1018] [inductor] calibration inductor windows uts (9/N) (#134425) enable Windows inductor UTs of `test/inductor/test_binary_folding.py` Failed UT depends on https://github.com/pytorch/pytorch/pull/134427 Need to rebase after https://github.com/pytorch/pytorch/pull/134427 merged. ```cmd 2024-08-25T23:32:23.0905727Z Traceback (most recent call last): 2024-08-25T23:32:23.0906516Z File "C:\actions-runner\_work\pytorch\pytorch\test\inductor\test_binary_folding.py", line 18, in 2024-08-25T23:32:23.0908200Z from inductor.test_inductor_freezing import TestCase 2024-08-25T23:32:23.0909883Z File "C:\actions-runner\_work\pytorch\pytorch\test\inductor\test_inductor_freezing.py", line 39, in 2024-08-25T23:32:23.0911128Z raise unittest.SkipTest("requires sympy/functorch/filelock") 2024-08-25T23:32:23.0911801Z unittest.case.SkipTest: requires sympy/functorch/filelock 2024-08-25T23:32:23.0912370Z Got exit code 1 2024-08-25T23:32:23.0913155Z No stepcurrent file found. Either pytest didn't get to run (e.g. import error) or file got deleted (contact dev infra) ``` Local test pass: image Pull Request resolved: https://github.com/pytorch/pytorch/pull/134425 Approved by: https://github.com/ezyang, https://github.com/jansel --- test/inductor/test_binary_folding.py | 15 ++------------- 1 file changed, 2 insertions(+), 13 deletions(-) diff --git a/test/inductor/test_binary_folding.py b/test/inductor/test_binary_folding.py index 045b330d02456c..a8b39c1bba055e 100644 --- a/test/inductor/test_binary_folding.py +++ b/test/inductor/test_binary_folding.py @@ -4,7 +4,6 @@ import itertools import os import sys -import unittest import torch from torch import nn @@ -16,20 +15,10 @@ pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) sys.path.append(pytorch_test_dir) -from torch.testing._internal.common_utils import IS_CI, IS_WINDOWS, TEST_WITH_ASAN -from torch.testing._internal.inductor_utils import skipCUDAIf - - -if IS_WINDOWS and IS_CI: - sys.stderr.write( - "Windows CI does not have necessary dependencies for test_torchinductor yet\n" - ) - if __name__ == "__main__": - sys.exit(0) - raise unittest.SkipTest("requires sympy/functorch/filelock") - from inductor.test_inductor_freezing import TestCase from inductor.test_torchinductor import check_model, check_model_gpu, copy_tests +from torch.testing._internal.common_utils import TEST_WITH_ASAN +from torch.testing._internal.inductor_utils import skipCUDAIf importlib.import_module("functorch") From 4d925d43066d6ca44462e95caef3f0a46a3c9edf Mon Sep 17 00:00:00 2001 From: Aidyn-A Date: Mon, 26 Aug 2024 21:02:29 +0000 Subject: [PATCH 0038/1018] [ARM] Fix infinite recursion in unwind (#134387) Fixes #119905 The `TORCH_SHOW_CPP_STACKTRACES=1` setting on ARM causes infinite recursive unwind because on failure a `StackTraceFetcher` attempts to unwind the failed instruction: https://github.com/pytorch/pytorch/blob/5ad759ca33ba8299cf7e1a6bb1dff7c9a5555e29/torch/csrc/profiler/combined_traceback.cpp#L25 then the unwind itself fails: https://github.com/pytorch/pytorch/blob/5ad759ca33ba8299cf7e1a6bb1dff7c9a5555e29/torch/csrc/profiler/unwind/unwind.cpp#L10-L12 and it causes another attempt to unwind the failure in `unwind()`... In summary, the executed instruction is equivalent to: ```C++ std::vector unwind() { // some instructions ... return unwind(); } ``` This PR replaces `TORCH_CHECK` by `TORCH_WARN_ONCE` as it will not cause an uncontrolled recursion. The only side effect would be an empty back-trace. Huge thanks to @nWEIdia who found the root cause! Pull Request resolved: https://github.com/pytorch/pytorch/pull/134387 Approved by: https://github.com/eqy, https://github.com/nWEIdia, https://github.com/malfet --- test/run_test.py | 9 ++------- torch/csrc/Module.cpp | 2 +- torch/csrc/profiler/unwind/unwind.cpp | 16 ++++++++-------- 3 files changed, 11 insertions(+), 16 deletions(-) diff --git a/test/run_test.py b/test/run_test.py index 80a724e129a7a2..adbc947958f0e1 100755 --- a/test/run_test.py +++ b/test/run_test.py @@ -573,13 +573,8 @@ def run_test( def try_set_cpp_stack_traces(env, command, set=True): # Print full c++ stack traces during retries - # Don't do it for macos inductor tests as it makes them - # segfault for some reason - if not ( - IS_MACOS and len(command) >= 2 and command[2].startswith(INDUCTOR_TEST_PREFIX) - ): - env = env or {} - env["TORCH_SHOW_CPP_STACKTRACES"] = "1" if set else "0" + env = env or {} + env["TORCH_SHOW_CPP_STACKTRACES"] = "1" if set else "0" return env diff --git a/torch/csrc/Module.cpp b/torch/csrc/Module.cpp index 35dbb0854a8545..3dff5fb65ddef0 100644 --- a/torch/csrc/Module.cpp +++ b/torch/csrc/Module.cpp @@ -170,7 +170,7 @@ static PyObject* THPModule_initExtension( PyObject* _unused, PyObject* shm_manager_path) { HANDLE_TH_ERRORS -#if !defined(FBCODE_CAFFE2) +#if !defined(FBCODE_CAFFE2) && !defined(__aarch64__) if (torch::get_cpp_stacktraces_enabled()) { c10::SetStackTraceFetcher([]() -> std::string { auto tb = torch::CapturedTraceback::gather(false, false, true); diff --git a/torch/csrc/profiler/unwind/unwind.cpp b/torch/csrc/profiler/unwind/unwind.cpp index 8ff3910b6cd612..22ddf02d8452ea 100644 --- a/torch/csrc/profiler/unwind/unwind.cpp +++ b/torch/csrc/profiler/unwind/unwind.cpp @@ -7,29 +7,29 @@ !__has_include("ext/stdio_filebuf.h") namespace torch::unwind { std::vector unwind() { - TORCH_CHECK( - false, + TORCH_WARN_ONCE( "record_context_cpp is not support on non-linux non-x86_64 platforms"); + return {}; } std::optional> libraryFor(void* addr) { - TORCH_CHECK( - false, + TORCH_WARN_ONCE( "record_context_cpp is not support on non-linux non-x86_64 platforms"); + return {}; } #ifndef FBCODE_CAFFE2 std::vector symbolize(const std::vector& frames, Mode mode) { - TORCH_CHECK( - false, + TORCH_WARN_ONCE( "record_context_cpp is not support on non-linux non-x86_64 platforms"); + return {}; } #endif Stats stats() { - TORCH_CHECK( - false, + TORCH_WARN_ONCE( "record_context_cpp is not support on non-linux non-x86_64 platforms"); + return {}; } } // namespace torch::unwind From b7316f8c339cf0c2ffbfc5b08f915aa9b0e34d78 Mon Sep 17 00:00:00 2001 From: Animesh Jain Date: Mon, 26 Aug 2024 10:37:09 -0700 Subject: [PATCH 0039/1018] [dynamo] Cache _dynamo.disable results (#134272) Pull Request resolved: https://github.com/pytorch/pytorch/pull/134272 Approved by: https://github.com/yf225, https://github.com/jansel --- test/dynamo/test_decorators.py | 14 ++++++++++++++ torch/_dynamo/__init__.py | 1 + torch/_dynamo/convert_frame.py | 1 + torch/_dynamo/decorators.py | 15 ++++++++++++++- 4 files changed, 30 insertions(+), 1 deletion(-) diff --git a/test/dynamo/test_decorators.py b/test/dynamo/test_decorators.py index b915e0633f7c11..463634e3216358 100644 --- a/test/dynamo/test_decorators.py +++ b/test/dynamo/test_decorators.py @@ -184,6 +184,20 @@ def hook(module, args): all(node.target is not torch.sigmoid for node in gm1.graph.nodes) ) + def test_disable_no_recompile(self): + def gn(x): + return torch.cos(x) + + @torch.compile(backend="eager") + def fn(x): + x = torch.sin(x) + x = torch._dynamo.disable(gn, recursive=True)(x) + return torch.sin(x) + + with torch._dynamo.config.patch(error_on_recompile=True): + for _ in range(5): + fn(torch.randn(4)) + def test_allow_in_graph(self): cnts = torch._dynamo.testing.CompileCounter() diff --git a/torch/_dynamo/__init__.py b/torch/_dynamo/__init__.py index 7f58ba7f7bf7f1..00286a32750d1e 100644 --- a/torch/_dynamo/__init__.py +++ b/torch/_dynamo/__init__.py @@ -107,3 +107,4 @@ def reset_code_caches() -> None: if code: reset_code(code) code_context.clear() + convert_frame.disabled_codes.clear() diff --git a/torch/_dynamo/convert_frame.py b/torch/_dynamo/convert_frame.py index 82ddaca2fa2ebb..2f0da2be866db1 100644 --- a/torch/_dynamo/convert_frame.py +++ b/torch/_dynamo/convert_frame.py @@ -160,6 +160,7 @@ def clear(self) -> None: input_codes = Tracker() output_codes = Tracker() +disabled_codes: Dict[int, Callable[..., Any]] = {} initial_global_state: Optional[GlobalStateGuard] = None diff --git a/torch/_dynamo/decorators.py b/torch/_dynamo/decorators.py index c83d8d718a62a4..52ffa76cf4a4e7 100644 --- a/torch/_dynamo/decorators.py +++ b/torch/_dynamo/decorators.py @@ -10,6 +10,7 @@ from . import trace_rules, variables from .comptime import comptime +from .convert_frame import disabled_codes from .eval_frame import DisableContext, innermost_fn, RunOnlyContext from .exc import IncorrectUsage from .external_utils import is_compiling @@ -55,9 +56,21 @@ def disable(fn=None, recursive=True): """ if recursive: if fn is not None: + id_fn = id(fn) + if cached_fn := disabled_codes.get(id_fn): + return cached_fn + fn = innermost_fn(fn) assert callable(fn) - return DisableContext()(fn) + out = DisableContext()(fn) + + # Do not cache Fx graph methods. These come from Dynamo itself. Caching such graphs can increase lifetime of + # held objects. Since these are coming from Dynamo itself, there is no benefit of caching. + if not ( + inspect.ismethod(fn) and isinstance(fn.__self__, torch.fx.GraphModule) + ): + disabled_codes[id_fn] = out + return out return DisableContext() else: return skip(fn) From edaf80802cf54326220432be7cc0443926d49679 Mon Sep 17 00:00:00 2001 From: soulitzer Date: Mon, 26 Aug 2024 11:31:58 -0400 Subject: [PATCH 0040/1018] Update AC pass use_reentrant message (#134472) Pull Request resolved: https://github.com/pytorch/pytorch/pull/134472 Approved by: https://github.com/albanD --- torch/utils/checkpoint.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/torch/utils/checkpoint.py b/torch/utils/checkpoint.py index 22d616b83faa47..94a8744e5c47a2 100644 --- a/torch/utils/checkpoint.py +++ b/torch/utils/checkpoint.py @@ -433,7 +433,7 @@ def checkpoint( use_reentrant(bool): specify whether to use the activation checkpoint variant that requires reentrant autograd. This parameter should be passed - explicitly. In version 2.4 we will raise an exception if + explicitly. In version 2.5 we will raise an exception if ``use_reentrant`` is not passed. If ``use_reentrant=False``, ``checkpoint`` will use an implementation that does not require reentrant autograd. This allows ``checkpoint`` to support additional @@ -464,7 +464,7 @@ def checkpoint( if use_reentrant is None: warnings.warn( "torch.utils.checkpoint: the use_reentrant parameter should be " - "passed explicitly. In version 2.4 we will raise an exception " + "passed explicitly. In version 2.5 we will raise an exception " "if use_reentrant is not passed. use_reentrant=False is " "recommended, but if you need to preserve the current default " "behavior, you can pass use_reentrant=True. Refer to docs for more " @@ -533,7 +533,7 @@ def checkpoint_sequential(functions, segments, input, use_reentrant=None, **kwar use_reentrant(bool): specify whether to use the activation checkpoint variant that requires reentrant autograd. This parameter should be passed - explicitly. In version 2.4 we will raise an exception if + explicitly. In version 2.5 we will raise an exception if ``use_reentrant`` is not passed. If ``use_reentrant=False``, ``checkpoint`` will use an implementation that does not require reentrant autograd. This allows ``checkpoint`` to support additional @@ -553,7 +553,7 @@ def checkpoint_sequential(functions, segments, input, use_reentrant=None, **kwar warnings.warn( "torch.utils.checkpoint.checkpoint_sequential: the use_reentrant " "parameter should be passed explicitly. " - "In version 2.4 we will raise an exception if use_reentrant " + "In version 2.5 we will raise an exception if use_reentrant " "is not passed. use_reentrant=False is " "recommended, but if you need to preserve the current default " "behavior, you can pass use_reentrant=True. Refer to docs for more " From b70d6d128412dc4d46f4128d0425a325c780827b Mon Sep 17 00:00:00 2001 From: Benjamin Glass Date: Mon, 26 Aug 2024 15:45:13 +0000 Subject: [PATCH 0041/1018] Fix test_parity xfail for sigmoid (#134253) Pull Request resolved: https://github.com/pytorch/pytorch/pull/134253 Approved by: https://github.com/amjames, https://github.com/janeyx99 --- aten/src/ATen/native/cuda/ForeachUnaryOp.cu | 2 +- .../_internal/common_methods_invocations.py | 57 ------------------- 2 files changed, 1 insertion(+), 58 deletions(-) diff --git a/aten/src/ATen/native/cuda/ForeachUnaryOp.cu b/aten/src/ATen/native/cuda/ForeachUnaryOp.cu index d7a118e6a9584c..c4718194955553 100644 --- a/aten/src/ATen/native/cuda/ForeachUnaryOp.cu +++ b/aten/src/ATen/native/cuda/ForeachUnaryOp.cu @@ -304,7 +304,7 @@ struct Sign { } }; -OP_CUSTOM_FUNCTOR(floating_half_bfloat16, sigmoid, Sigmoid) +OP_CUSTOM_FUNCTOR(floating_complex_half_bfloat16, sigmoid, Sigmoid) OP_CUSTOM_FUNCTOR(floating_half_bfloat16, round, Round) OP_CUSTOM_FUNCTOR(floating_half_bfloat16, frac, Trunc) OP_CUSTOM_FUNCTOR(floating_complex_half_bfloat16, reciprocal, Reciprocal) diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py index 961d4024d61962..86fb406a1e5a6f 100644 --- a/torch/testing/_internal/common_methods_invocations.py +++ b/torch/testing/_internal/common_methods_invocations.py @@ -10572,75 +10572,18 @@ def __call__(self, opinfo, device, dtype, requires_grad, **kwargs): "test_dispatch_meta_inplace", dtypes=integral_types_and(torch.bool), ), - DecorateInfo( - unittest.expectedFailure, - "TestMeta", - "test_dispatch_meta_inplace", - device_type="cuda", - dtypes=complex_types(), - ), - DecorateInfo( - unittest.expectedFailure, - "TestMeta", - "test_dispatch_meta_outplace", - device_type="cuda", - dtypes=complex_types(), - ), DecorateInfo( unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace", dtypes=integral_types_and(torch.bool), ), - DecorateInfo( - unittest.expectedFailure, - "TestMeta", - "test_dispatch_symbolic_meta_inplace", - device_type="cuda", - dtypes=complex_types(), - ), - DecorateInfo( - unittest.expectedFailure, - "TestMeta", - "test_dispatch_symbolic_meta_outplace", - device_type="cuda", - dtypes=complex_types(), - ), DecorateInfo( unittest.expectedFailure, "TestMeta", "test_meta_inplace", dtypes=integral_types_and(torch.bool), ), - DecorateInfo( - unittest.expectedFailure, - "TestMeta", - "test_meta_inplace", - device_type="cuda", - dtypes=complex_types(), - ), - DecorateInfo( - unittest.expectedFailure, - "TestMeta", - "test_meta_outplace", - device_type="cuda", - dtypes=complex_types(), - ), - DecorateInfo( - unittest.expectedFailure, - "TestForeach", - "test_parity", - device_type="cuda", - dtypes=complex_types(), - active_if=lambda kwargs: not kwargs.get("noncontiguous", False), - ), - DecorateInfo( - unittest.expectedFailure, - "TestForeach", - "test_autodiff", - device_type="cuda", - dtypes=(torch.complex128,), - ), ), ), ForeachFuncInfo( From b3d5aa3c75a13c5e43b3eaace6c77be0c4f85c7d Mon Sep 17 00:00:00 2001 From: Benjamin Glass Date: Mon, 26 Aug 2024 15:45:13 +0000 Subject: [PATCH 0042/1018] Add the fast path for bfloat16 lgamma (#134344) Pull Request resolved: https://github.com/pytorch/pytorch/pull/134344 Approved by: https://github.com/amjames, https://github.com/janeyx99 ghstack dependencies: #134253 --- aten/src/ATen/native/cuda/ForeachUnaryOp.cu | 2 +- .../_internal/common_methods_invocations.py | 44 +------------------ 2 files changed, 2 insertions(+), 44 deletions(-) diff --git a/aten/src/ATen/native/cuda/ForeachUnaryOp.cu b/aten/src/ATen/native/cuda/ForeachUnaryOp.cu index c4718194955553..1a969cfdbdcc43 100644 --- a/aten/src/ATen/native/cuda/ForeachUnaryOp.cu +++ b/aten/src/ATen/native/cuda/ForeachUnaryOp.cu @@ -237,7 +237,7 @@ void floating_half_bfloat16_(TensorList tensors) { OP_CUSTOM_FUNCTOR(function, op_name, functor_name); OP(floating_half_bfloat16, erfc, Erfc); -OP(floating_half, lgamma, Lgamma); +OP(floating_half_bfloat16, lgamma, Lgamma); OP(floating_half_bfloat16, trunc, Truncf); OP(floating_half_bfloat16, floor, Floor); OP(floating_half_bfloat16, ceil, Ceil); diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py index 86fb406a1e5a6f..048c4139721aad 100644 --- a/torch/testing/_internal/common_methods_invocations.py +++ b/torch/testing/_internal/common_methods_invocations.py @@ -10798,84 +10798,42 @@ def __call__(self, opinfo, device, dtype, requires_grad, **kwargs): "test_dispatch_meta_inplace", dtypes=complex_types() + integral_types_and(torch.bool), ), - DecorateInfo( - unittest.expectedFailure, - "TestMeta", - "test_dispatch_meta_inplace", - device_type="cuda", - dtypes=(torch.bfloat16,), - ), DecorateInfo( unittest.expectedFailure, "TestMeta", "test_dispatch_meta_outplace", dtypes=complex_types(), ), - DecorateInfo( - unittest.expectedFailure, - "TestMeta", - "test_dispatch_meta_outplace", - device_type="cuda", - dtypes=(torch.bfloat16,), - ), DecorateInfo( unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace", dtypes=complex_types() + integral_types_and(torch.bool), ), - DecorateInfo( - unittest.expectedFailure, - "TestMeta", - "test_dispatch_symbolic_meta_inplace", - device_type="cuda", - dtypes=(torch.bfloat16,), - ), DecorateInfo( unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace", dtypes=complex_types(), ), - DecorateInfo( - unittest.expectedFailure, - "TestMeta", - "test_dispatch_symbolic_meta_outplace", - device_type="cuda", - dtypes=(torch.bfloat16,), - ), DecorateInfo( unittest.expectedFailure, "TestMeta", "test_meta_inplace", dtypes=complex_types() + integral_types_and(torch.bool), ), - DecorateInfo( - unittest.expectedFailure, - "TestMeta", - "test_meta_inplace", - device_type="cuda", - dtypes=(torch.bfloat16,), - ), DecorateInfo( unittest.expectedFailure, "TestMeta", "test_meta_outplace", dtypes=complex_types(), ), - DecorateInfo( - unittest.expectedFailure, - "TestMeta", - "test_meta_outplace", - device_type="cuda", - dtypes=(torch.bfloat16,), - ), DecorateInfo( unittest.expectedFailure, "TestForeach", "test_parity", device_type="cuda", - dtypes=complex_types() + (torch.bfloat16,), + dtypes=complex_types(), active_if=lambda kwargs: not kwargs.get("noncontiguous", False), ), DecorateInfo( From bcba5e4f44179a0fc286053b147258cc10f3d087 Mon Sep 17 00:00:00 2001 From: Benjamin Glass Date: Mon, 26 Aug 2024 15:45:14 +0000 Subject: [PATCH 0043/1018] TestForeach::test_parity: Remove check for error message text (#134251) Previously, error messages were expected to be string equivalent to error messages thrown by the ref function. This check fails for dozens of torch functions, and doesn't appear to add much value for the end user. This commit removes this check. Pull Request resolved: https://github.com/pytorch/pytorch/pull/134251 Approved by: https://github.com/amjames, https://github.com/janeyx99 ghstack dependencies: #134253, #134344 --- test/test_foreach.py | 6 +- .../_internal/common_methods_invocations.py | 104 ------------------ 2 files changed, 1 insertion(+), 109 deletions(-) diff --git a/test/test_foreach.py b/test/test_foreach.py index c3d9892ef15023..6f664392289834 100644 --- a/test/test_foreach.py +++ b/test/test_foreach.py @@ -229,11 +229,7 @@ def test_parity(self, device, dtype, op, noncontiguous, inplace): **sample.kwargs, ) except Exception as e: - with ( - self.assertRaisesRegex(type(e), re.escape(str(e).splitlines()[0])) - if not (op.has_no_in_place or not op.supports_out) - else self.assertRaises(type(e)) - ): + with self.assertRaises(type(e)): ref([ref_input, *sample.ref_args], **ref_kwargs) else: expected = ref([ref_input, *sample.ref_args], **ref_kwargs) diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py index 048c4139721aad..d3c7bd33c9d711 100644 --- a/torch/testing/_internal/common_methods_invocations.py +++ b/torch/testing/_internal/common_methods_invocations.py @@ -10158,14 +10158,6 @@ def __call__(self, opinfo, device, dtype, requires_grad, **kwargs): "test_meta_outplace", dtypes=complex_types_and(torch.bool), ), - DecorateInfo( - unittest.expectedFailure, - "TestForeach", - "test_parity", - device_type="cuda", - dtypes=complex_types(), - active_if=lambda kwargs: not kwargs.get("noncontiguous", False), - ), DecorateInfo( unittest.expectedFailure, "TestForeach", @@ -10218,14 +10210,6 @@ def __call__(self, opinfo, device, dtype, requires_grad, **kwargs): "test_meta_outplace", dtypes=complex_types(), ), - DecorateInfo( - unittest.expectedFailure, - "TestForeach", - "test_parity", - device_type="cuda", - dtypes=complex_types(), - active_if=lambda kwargs: not kwargs.get("noncontiguous", False), - ), DecorateInfo( unittest.expectedFailure, "TestForeach", @@ -10278,14 +10262,6 @@ def __call__(self, opinfo, device, dtype, requires_grad, **kwargs): "test_meta_outplace", dtypes=complex_types(), ), - DecorateInfo( - unittest.expectedFailure, - "TestForeach", - "test_parity", - device_type="cuda", - dtypes=complex_types(), - active_if=lambda kwargs: not kwargs.get("noncontiguous", False), - ), DecorateInfo( unittest.expectedFailure, "TestForeach", @@ -10366,14 +10342,6 @@ def __call__(self, opinfo, device, dtype, requires_grad, **kwargs): "test_meta_outplace", dtypes=complex_types_and(torch.bool), ), - DecorateInfo( - unittest.expectedFailure, - "TestForeach", - "test_parity", - device_type="cuda", - dtypes=complex_types(), - active_if=lambda kwargs: not kwargs.get("noncontiguous", False), - ), DecorateInfo( unittest.expectedFailure, "TestForeach", @@ -10453,14 +10421,6 @@ def __call__(self, opinfo, device, dtype, requires_grad, **kwargs): "test_meta_outplace", dtypes=complex_types_and(torch.bool), ), - DecorateInfo( - unittest.expectedFailure, - "TestForeach", - "test_parity", - device_type="cuda", - dtypes=complex_types(), - active_if=lambda kwargs: not kwargs.get("noncontiguous", False), - ), DecorateInfo( unittest.expectedFailure, "TestForeach", @@ -10513,14 +10473,6 @@ def __call__(self, opinfo, device, dtype, requires_grad, **kwargs): "test_meta_outplace", dtypes=integral_types_and(torch.bool) + complex_types(), ), - DecorateInfo( - unittest.expectedFailure, - "TestForeach", - "test_parity", - device_type="cuda", - dtypes=complex_types(), - active_if=lambda kwargs: not kwargs.get("noncontiguous", False), - ), DecorateInfo( unittest.expectedFailure, "TestForeach", @@ -10629,14 +10581,6 @@ def __call__(self, opinfo, device, dtype, requires_grad, **kwargs): "test_meta_outplace", dtypes=complex_types_and(torch.bool), ), - DecorateInfo( - unittest.expectedFailure, - "TestForeach", - "test_parity", - device_type="cuda", - dtypes=complex_types(), - active_if=lambda kwargs: not kwargs.get("noncontiguous", False), - ), DecorateInfo( unittest.expectedFailure, "TestForeach", @@ -10762,14 +10706,6 @@ def __call__(self, opinfo, device, dtype, requires_grad, **kwargs): "test_meta_outplace", dtypes=complex_types(), ), - DecorateInfo( - unittest.expectedFailure, - "TestForeach", - "test_parity", - device_type="cuda", - dtypes=complex_types(), - active_if=lambda kwargs: not kwargs.get("noncontiguous", False), - ), DecorateInfo( unittest.expectedFailure, "TestForeach", @@ -10828,14 +10764,6 @@ def __call__(self, opinfo, device, dtype, requires_grad, **kwargs): "test_meta_outplace", dtypes=complex_types(), ), - DecorateInfo( - unittest.expectedFailure, - "TestForeach", - "test_parity", - device_type="cuda", - dtypes=complex_types(), - active_if=lambda kwargs: not kwargs.get("noncontiguous", False), - ), DecorateInfo( unittest.expectedFailure, "TestForeach", @@ -10946,14 +10874,6 @@ def __call__(self, opinfo, device, dtype, requires_grad, **kwargs): DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_outplace"), DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace_all_strides"), DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace_all_strides"), - DecorateInfo( - unittest.expectedFailure, - "TestForeach", - "test_parity", - device_type="cuda", - dtypes=complex_types(), - active_if=lambda kwargs: not kwargs.get("noncontiguous", False), - ), DecorateInfo( unittest.expectedFailure, "TestForeach", @@ -10984,14 +10904,6 @@ def __call__(self, opinfo, device, dtype, requires_grad, **kwargs): DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_outplace"), DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace_all_strides"), DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace_all_strides"), - DecorateInfo( - unittest.expectedFailure, - "TestForeach", - "test_parity", - device_type="cuda", - dtypes=complex_types(), - active_if=lambda kwargs: not kwargs.get("noncontiguous", False), - ), DecorateInfo( unittest.expectedFailure, "TestForeach", @@ -11023,14 +10935,6 @@ def __call__(self, opinfo, device, dtype, requires_grad, **kwargs): DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_outplace"), DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace_all_strides"), DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace_all_strides"), - DecorateInfo( - unittest.expectedFailure, - "TestForeach", - "test_parity", - device_type="cuda", - dtypes=complex_types(), - active_if=lambda kwargs: not kwargs.get("noncontiguous", False), - ), DecorateInfo( unittest.expectedFailure, "TestForeach", @@ -11062,14 +10966,6 @@ def __call__(self, opinfo, device, dtype, requires_grad, **kwargs): DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_outplace"), DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace_all_strides"), DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace_all_strides"), - DecorateInfo( - unittest.expectedFailure, - "TestForeach", - "test_parity", - device_type="cuda", - dtypes=complex_types(), - active_if=lambda kwargs: not kwargs.get("noncontiguous", False), - ), DecorateInfo( unittest.expectedFailure, "TestForeach", From 88779a75653753771acac54ca3eb5f974524e96c Mon Sep 17 00:00:00 2001 From: Pian Pawakapan Date: Mon, 26 Aug 2024 22:44:12 +0000 Subject: [PATCH 0044/1018] [export] enumerate unsupported sympy.Functions (#134271) There's 2 concepts of unsupported sympy.Functions in symbolic_shapes: 1) unsupported by the export solver, meaning the solver doesn't know how to provide useful fixes for those functions 2) unsupported by the sympy interpreter - meaning we can't reify them into FX nodes because the functions aren't present in PythonReferenceAnalysis This splits the current call into a call for each version, with the Export solver the only user of 1). For 1), we enumerate the functions in _sympy/functions.py, and subtract the functions we know we can support. For 2) there's only 3 functions we've seen pop up in test cases. Differential Revision: D61677956 Pull Request resolved: https://github.com/pytorch/pytorch/pull/134271 Approved by: https://github.com/avikchaudhuri --- torch/fx/experimental/symbolic_shapes.py | 31 +++++++++++++++++++++--- torch/fx/passes/runtime_assert.py | 6 ++--- torch/utils/_sympy/functions.py | 7 ++++++ 3 files changed, 37 insertions(+), 7 deletions(-) diff --git a/torch/fx/experimental/symbolic_shapes.py b/torch/fx/experimental/symbolic_shapes.py index 9c7f7ea80ce399..8098d3b1465cb2 100644 --- a/torch/fx/experimental/symbolic_shapes.py +++ b/torch/fx/experimental/symbolic_shapes.py @@ -1267,13 +1267,14 @@ def _is_supported_equivalence(expr): ) return isinstance(expr, sympy.Symbol) -def _has_unsupported_sympy_function(expr) -> bool: +def _has_uninterpretable_sympy_function(expr) -> bool: + """ + Add functions that our sympy interpreter can't reify into FX nodes + """ return expr.has( torch.utils._sympy.functions.ToFloat, torch.utils._sympy.functions.TruncToInt, torch.utils._sympy.functions.CeilToInt, - # add more sympy functions that involve float<->int conversion here - # since our solver does not know what to do with them ) @dataclass(frozen=True) @@ -1675,6 +1676,14 @@ def __init__( # symbols that are marked dynamic self._marked_dynamic = marked_dynamic + # track supported sympy functions and subtract from list of all sympy functions + self._supported_sympy_functions: Set[sympy.Function] = { + Mod, + PythonMod, + FloorDiv, + } + self._enumerate_sympy_functions() + def rewrite_with_congruences(self, s, expr): """ Eliminate expressions of the form b // d and b % d while adding congruences of the form b % d == k. @@ -1741,6 +1750,20 @@ def floor_div_handler(*args): expr = expr.replace(FloorDiv, floor_div_handler) return expr + def _enumerate_sympy_functions(self): + module = torch.utils._sympy.functions + all_functions = set() + for attr in dir(module): + if isinstance(func := getattr(module, attr), sympy.FunctionClass): + all_functions.add(func) + self._unsupported_sympy_functions = all_functions.difference(self._supported_sympy_functions) + + def _has_unsupported_sympy_function(self, expr) -> bool: + """ + Tracks list of sympy.Functions the export solver doesn't know how to handle. + """ + return expr.has(*self._unsupported_sympy_functions) + def add(self, expr) -> bool: """Add an expression to the set of constraints. @@ -1757,7 +1780,7 @@ def add(self, expr) -> bool: # a fix for this issue, we delay raising such failures. See solve(). if orig_reduced == sympy.false: self._inconsistencies.append(f"{orig_expr} is inconsistent!") - if isinstance(expr, sympy.Ne) or _has_unsupported_sympy_function(expr): + if isinstance(expr, sympy.Ne) or self._has_unsupported_sympy_function(expr): # we're not going to do anything useful with these, so drop them return False free_symbols = expr.free_symbols diff --git a/torch/fx/passes/runtime_assert.py b/torch/fx/passes/runtime_assert.py index 7307fb8c835703..e2a7c1bee3db1d 100644 --- a/torch/fx/passes/runtime_assert.py +++ b/torch/fx/passes/runtime_assert.py @@ -93,7 +93,7 @@ def insert_deferred_runtime_asserts( import sympy from torch.fx.experimental.symbolic_shapes import ( - _has_unsupported_sympy_function, + _has_uninterpretable_sympy_function, CallMethodKey, cast_symbool_to_symint_guardless, ConvertIntKey, @@ -136,7 +136,7 @@ def _is_intermediate_tensor_sym_call(node: fx.Node) -> bool: (val := _get_sym_val(node)) is not None and not isinstance(val, sympy.Number) # this holds back from reifying anything in torch.utils._sympy.functions.py that's unsupported - and not _has_unsupported_sympy_function(val) + and not _has_uninterpretable_sympy_function(val) and any( isinstance(arg, fx.Node) and isinstance(_get_example_value(arg), (torch.Tensor, torch.Size)) @@ -195,7 +195,7 @@ def add_runtime_asserts(ras): and _is_bound_expr_for_symbol(ra.expr) ) # don't try to reify sympy functions we can't turn into FX nodes - or _has_unsupported_sympy_function(ra.expr) + or _has_uninterpretable_sympy_function(ra.expr) ): continue diff --git a/torch/utils/_sympy/functions.py b/torch/utils/_sympy/functions.py index d54495047e2713..4ed6b4f6b5542a 100644 --- a/torch/utils/_sympy/functions.py +++ b/torch/utils/_sympy/functions.py @@ -53,13 +53,20 @@ __all__ = [ "FloorDiv", "ModularIndexing", + "Where", + "PythonMod", + "Mod", "CleanDiv", + "CeilToInt", + "FloorToInt", "CeilDiv", "IntTrueDiv", "FloatTrueDiv", "LShift", "RShift", "IsNonOverlappingAndDenseIndicator", + "TruncToFloat", + "TruncToInt", "RoundToInt", "RoundDecimal", "ToFloat", From c354ac9a85ce75f21f6726b61d8e5a336a2d7421 Mon Sep 17 00:00:00 2001 From: Nikita Shulga <2453524+malfet@users.noreply.github.com> Date: Mon, 26 Aug 2024 22:51:43 +0000 Subject: [PATCH 0045/1018] Fix docstring for torch.signal.windows.nuttall (#134512) This partially fixes regression introduced by https://github.com/pytorch/pytorch/pull/124771 but also just improves `z_n` rendering, by using MathML In 2.3 it was [rendered](https://pytorch.org/docs/2.3/generated/torch.signal.windows.nuttall.html#torch.signal.windows.nuttall) as image With this change it'll be [rendered](https://docs-preview.pytorch.org/pytorch/pytorch/134512/generated/torch.signal.windows.nuttall.html#torch.signal.windows.nuttall) as ```math z_n = \frac{2 \pi n}{M} ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/134512 Approved by: https://github.com/kit1980, https://github.com/aorenste, https://github.com/atalman --- torch/signal/windows/windows.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch/signal/windows/windows.py b/torch/signal/windows/windows.py index 51784b111941ee..6626b6d1f3aaf2 100644 --- a/torch/signal/windows/windows.py +++ b/torch/signal/windows/windows.py @@ -750,7 +750,7 @@ def general_hamming(M, .. math:: w_n = 1 - 0.36358 \cos{(z_n)} + 0.48917 \cos{(2z_n)} - 0.13659 \cos{(3z_n)} + 0.01064 \cos{(4z_n)} -where ``z_n = 2 \u03c0 n/ M``. +where :math:`z_n = \frac{2 \pi n}{M}`. """, """ From 8e05fa6b625e77346241502b279f8bc5ead0a42c Mon Sep 17 00:00:00 2001 From: Roy Hvaara Date: Mon, 26 Aug 2024 23:48:23 +0000 Subject: [PATCH 0046/1018] [MPS] Add `test/test_nn.py` to test suite (#134184) This PR increases test coverage by including the tests in `test/test_nn.py` in the test suite of MPS. Some of the tests are decorated with `@expectedFailureMPS` for various reasons. Either that the op is not implemented, or that the outputs do not align. Those tests that contain differing results should be investigated further to rule out any live bugs. ```bash $ python test/run_test.py --mps --verbose -k TestNN Running test batch 'tests to run' cost 84.76 seconds ``` Ref #133520 Pull Request resolved: https://github.com/pytorch/pytorch/pull/134184 Approved by: https://github.com/albanD, https://github.com/malfet --- test/run_test.py | 2 +- test/test_nn.py | 82 +++++++++++++++---- torch/testing/_internal/common_device_type.py | 4 + 3 files changed, 73 insertions(+), 15 deletions(-) diff --git a/test/run_test.py b/test/run_test.py index adbc947958f0e1..231a1b2b7ca012 100755 --- a/test/run_test.py +++ b/test/run_test.py @@ -1417,7 +1417,7 @@ def get_selected_tests(options) -> List[str]: options.exclude.extend(CPP_TESTS) if options.mps: - selected_tests = ["test_mps", "test_metal", "test_modules"] + selected_tests = ["test_mps", "test_metal", "test_modules", "test_nn"] else: # Exclude all mps tests otherwise options.exclude.extend(["test_mps", "test_metal"]) diff --git a/test/test_nn.py b/test/test_nn.py index eb4ccd76515309..4bda984d736bc6 100644 --- a/test/test_nn.py +++ b/test/test_nn.py @@ -39,10 +39,11 @@ from torch.testing._internal.common_nn import NNTestCase, NewModuleTest, CriterionTest, \ module_tests, criterion_tests, loss_reference_fns, _create_basic_net, \ ctcloss_reference, new_module_tests, single_batch_reference_fn, _test_bfloat16_ops, _test_module_empty_input -from torch.testing._internal.common_device_type import instantiate_device_type_tests, dtypes, \ +from torch.testing._internal.common_device_type import dtypesIfMPS, instantiate_device_type_tests, dtypes, \ dtypesIfCUDA, precisionOverride, skipCUDAIfCudnnVersionLessThan, onlyCUDA, onlyCPU, \ skipCUDAIfRocm, skipCUDAIf, skipCUDAIfNotRocm, \ - onlyNativeDeviceTypes, deviceCountAtLeast, largeTensorTest, expectedFailureMeta, skipMeta, get_all_device_types + onlyNativeDeviceTypes, deviceCountAtLeast, largeTensorTest, expectedFailureMeta, expectedFailureMPS, \ + skipMPSVersionIfLessThan, skipMeta, get_all_device_types from hypothesis import given import torch.testing._internal.hypothesis_utils as hu @@ -7919,6 +7920,7 @@ def _test_module_empty_inputs(self, module, inputs): @unittest.skipIf((not TEST_NUMPY) or (not TEST_SCIPY) or (scipy.__version__ < '1.0.0'), "Scipy v1.0 and/or numpy not found") + @expectedFailureMPS # Unsupported Border padding mode https://github.com/pytorch/pytorch/issues/125098 @tf32_on_and_off() @bf32_on_and_off() def test_affine_2d_rotate0(self, device): @@ -7959,6 +7961,7 @@ def test_affine_2d_rotate0(self, device): @unittest.skipIf((not TEST_NUMPY) or (not TEST_SCIPY) or (scipy.__version__ < '1.0.0'), "Scipy v1.0 and/or numpy not found") + @expectedFailureMPS # Unsupported Border padding mode https://github.com/pytorch/pytorch/issues/125098 @tf32_on_and_off(0.001) @bf32_on_and_off(0.001) def test_affine_2d_rotate90(self, device): @@ -8008,6 +8011,7 @@ def test_affine_2d_rotate90(self, device): @unittest.skipIf((not TEST_NUMPY) or (not TEST_SCIPY) or (scipy.__version__ < '1.0.0'), "Scipy v1.0 and/or numpy not found") + @expectedFailureMPS # Unsupported Border padding mode https://github.com/pytorch/pytorch/issues/125098 @tf32_on_and_off(0.005) @bf32_on_and_off(0.005) def test_affine_2d_rotate45(self, device): @@ -8085,6 +8089,7 @@ def test_avg_pool_large_tensor2(self, device): @unittest.skipIf((not TEST_NUMPY) or (not TEST_SCIPY) or (scipy.__version__ < '1.0.0'), "Scipy v1.0 and/or numpy not found") + @expectedFailureMPS # Unsupported Border padding mode https://github.com/pytorch/pytorch/issues/125098 @tf32_on_and_off(0.005) @bf32_on_and_off(0.005) def test_affine_2d_rotateRandom(self, device): @@ -8137,6 +8142,7 @@ def test_affine_2d_rotateRandom(self, device): @unittest.skipIf((not TEST_NUMPY) or (not TEST_SCIPY) or (scipy.__version__ < '1.0.0'), "Scipy v1.0 and/or numpy not found") + @expectedFailureMPS # aten::grid_sampler_3d not implemented https://github.com/pytorch/pytorch/issues/77764 @tf32_on_and_off(0.005) @bf32_on_and_off(0.005) def test_affine_3d_rotateRandom(self, device): @@ -8199,7 +8205,9 @@ def test_batchnorm_large_batch(self, device, dtype): data = torch.rand(880801, 1, 1, 1, device=device, dtype=dtype) out = bn(data).sum().backward() + @skipMPSVersionIfLessThan(14, 0) # macOS 13 does not support bfloat16 @dtypesIfCUDA(torch.float, torch.double, torch.half, torch.complex128) + @dtypesIfMPS(torch.float, torch.bfloat16, torch.complex64) @dtypes(torch.float, torch.double, torch.bfloat16, torch.complex128) def test_conv_empty_input(self, device, dtype): def help(input, conv, memory_format): @@ -8576,6 +8584,7 @@ def test_ReplicationPad_empty(self, device, dtype): with self.assertRaisesRegex(RuntimeError, 'padding size is expected to be 6'): torch._C._nn.replication_pad3d(torch.randn([2]), padding=[]) + @expectedFailureMPS # TODO(hvaara): Investigate as possible bug. def test_ReplicationPad1d_large(self, device): shapes = ([2, 65736, 4], [65736, 2, 4]) pl, pr = 3, 4 @@ -8600,6 +8609,7 @@ def test_ReplicationPad1d_large(self, device): self.assertEqual(x.grad[:, :, 0], g[:, :, : pl + 1].sum(-1)) self.assertEqual(x.grad[:, :, -1], g[:, :, -pr - 1:].sum(-1)) + @expectedFailureMPS # TODO(hvaara): Investigate as possible bug. def test_ReplicationPad2d_large(self, device): shapes = ([2, 65736, 4, 4], [65736, 2, 4, 4]) pl, pr, pt, pb = 3, 4, 5, 6 @@ -8938,9 +8948,12 @@ def check_rnn_grads(rnn1, rnn2): else: self.assertEqual(hx.grad, hx_device.grad) - def test_BatchNorm_empty(self, device): + @expectedFailureMPS # TypeError: float64 https://github.com/pytorch/pytorch/issues/134423 + @dtypesIfMPS(torch.float) + @dtypes(torch.double) + def test_BatchNorm_empty(self, device, dtype): mod = torch.nn.BatchNorm2d(3).to(device) - inp = torch.randn(0, 3, 2, 2, device=device) + inp = torch.randn(0, 3, 2, 2, device=device, dtype=dtype) _test_module_empty_input(self, mod, inp) if self.device_type == 'cuda' and self.has_cudnn(): with torch.backends.cudnn.flags(enabled=False): @@ -8967,7 +8980,7 @@ def test_linear_empty(self, device): def test_one_hot(self, device): # cuda throws device assert for invalid data # xla ignores out of bound indices - if self.device_type != 'cuda' and self.device_type != 'xla': + if self.device_type not in ('cuda', 'mps', 'xla'): with self.assertRaises(RuntimeError): torch.nn.functional.one_hot(torch.tensor([3, 4, -1, 0], device=device), -1) @@ -9016,6 +9029,7 @@ def test_one_hot(self, device): with self.assertRaises(RuntimeError): torch.nn.functional.one_hot(torch.tensor([3, 4, 1, 0], device=device), -2) + @expectedFailureMPS # NotImplementedError: aten::rrelu_with_noise https://github.com/pytorch/pytorch/issues/77764 def test_nn_empty(self, device): # One off tests to ensure scalars from nn.yaml are properly applied def verify_scalars(input, output): @@ -9031,6 +9045,7 @@ def verify_scalars(input, output): output = m(input) verify_scalars(input, output) + @expectedFailureMPS # NotImplementedError: aten::rrelu_with_noise https://github.com/pytorch/pytorch/issues/77764 def test_nn_scalars(self, device): # One off tests to ensure scalars from nn.yaml are properly applied def verify_scalars(input, output): @@ -9209,6 +9224,9 @@ def func(device): # We don't want to make propagating NaN a hard requirement on ops, but for # these easy ones, we should make them do so. + # MPS: NotImplementedError: aten::rrelu_with_noise_ https://github.com/pytorch/pytorch/issues/77764 + # MPS: NotImplementedError: aten::hardshrink.out https://github.com/pytorch/pytorch/issues/77764 + @expectedFailureMPS def test_nonlinearity_propagate_nan(self, device): def test(nonlinearity, *args, **kwargs): x = torch.tensor([nan], device=device) @@ -9241,6 +9259,7 @@ def test(nonlinearity, *args, **kwargs): test('threshold', 3, 2) test('threshold', 3, 2, inplace=True) + @expectedFailureMPS # TypeError: float64 the MPS framework doesn't support float64 @parametrize_test("mode", ["nearest-exact", "nearest"]) def test_upsamplingNearest1d(self, device, mode): # Forward AD does not support XLA because XLA tensors don't have storage @@ -9319,6 +9338,7 @@ def test_upsamplingNearestExact1d_rescale(self, device): expected_out = in_t.repeat_interleave(2, dim=-1) self.assertEqual(out_t, expected_out) + @skipIfMps # Partially passes https://github.com/pytorch/pytorch/issues/134430 @parametrize_test("isize, osize", [(20, 11), (10, 15)]) def test_upsamplingNearestExact1d_correctness(self, device, isize, osize): # Here we check if output matches Scikit-Image/Scipy-like result @@ -9337,6 +9357,7 @@ def test_upsamplingNearestExact1d_correctness(self, device, isize, osize): expected_out = expected_out.to(device=device) self.assertEqual(out_t, expected_out) + @expectedFailureMPS # TypeError: the MPS framework doesn't support float64 @parametrize_test("memory_format", [torch.contiguous_format, torch.channels_last]) @parametrize_test("mode", ["nearest", "nearest-exact"]) def test_upsamplingNearest2d(self, device, memory_format, mode): @@ -9425,6 +9446,7 @@ def test_upsamplingNearest2d_correctness(self, device, memory_format, isize, osi expected_out = expected_out.to(device=device) self.assertEqual(out_t, expected_out) + @skipIfMps # Partially passes https://github.com/pytorch/pytorch/issues/134430 @parametrize_test("memory_format", [torch.contiguous_format, torch.channels_last]) @parametrize_test("isize, osize", [(20, 11), (10, 15)]) def test_upsamplingNearestExact2d_correctness(self, device, memory_format, isize, osize): @@ -9448,6 +9470,7 @@ def test_upsamplingNearestExact2d_correctness(self, device, memory_format, isize expected_out = expected_out.to(device=device) self.assertEqual(out_t, expected_out) + @expectedFailureMPS # TypeError: the MPS framework doesn't support float64 @parametrize_test("memory_format", [torch.contiguous_format, torch.channels_last_3d]) @parametrize_test("mode", ["nearest", "nearest-exact"]) def test_upsamplingNearest3d(self, device, memory_format, mode): @@ -9521,6 +9544,7 @@ def test_upsamplingNearest3d_correctness(self, device, memory_format, isize, osi expected_out = expected_out.to(device=device) self.assertEqual(out_t, expected_out) + @expectedFailureMPS # NotImplementedError: aten::_upsample_nearest_exact3d.out https://github.com/pytorch/pytorch/issues/77764 @parametrize_test("memory_format", [torch.contiguous_format, torch.channels_last_3d]) @parametrize_test("isize, osize", [(20, 11), (10, 15)]) def test_upsamplingNearestExact3d_correctness(self, device, memory_format, isize, osize): @@ -9643,6 +9667,7 @@ def test_upsamplingBiMode2d_nonsupported_dtypes(self, device, antialias, num_cha else: _ = F.interpolate(x, (12, 12), mode=mode, antialias=antialias) + @expectedFailureMPS # NotImplementedError: aten::_upsample_bilinear2d_aa.out https://github.com/pytorch/pytorch/issues/77764 @parametrize_test("memory_format", [torch.contiguous_format, torch.channels_last]) def test_upsamplingBilinear2d_aa_correctness(self, device, memory_format): # NOTE: We expand the batch dim such that `b*c` is above the maximum @@ -9663,6 +9688,8 @@ def test_upsamplingBilinear2d_aa_correctness(self, device, memory_format): t_out = F.interpolate(t_in, size=(2, 2), mode="bilinear", align_corners=False, antialias=True) self.assertEqual(expected_out.expand([*shape[:2], 2, 2]), t_out) + # Partially passes. NotImplementedError: aten::upsample_bicubic2d.out https://github.com/pytorch/pytorch/issues/77764 + @skipIfMps @parametrize_test("memory_format", [torch.contiguous_format, torch.channels_last]) @parametrize_test("mode", ["bilinear", "bicubic"]) @parametrize_test("antialias", [True, False]) @@ -9768,6 +9795,7 @@ def test_upsamplingBiLinear2d_consistency_interp_size_bug(self, device, memory_f ) torch.testing.assert_close(output_f32, output_ui8, atol=1, rtol=0) + @expectedFailureMPS # NotImplementedError: aten::upsample_bicubic2d.out https://github.com/pytorch/pytorch/issues/77764 def test_upsamplingBicubic2d_correctness(self, device): # test output against known input: align_corners=False result must match opencv in_t = torch.arange(8., device=device).view(1, 2, 2, 2) @@ -9785,6 +9813,7 @@ def test_upsamplingBicubic2d_correctness(self, device): torch.set_printoptions(precision=5) self.assertEqual(out_t, expected_out_t, atol=1e-5, rtol=0) + @expectedFailureMPS # NotImplementedError: aten::_upsample_bicubic2d_aa.out https://github.com/pytorch/pytorch/issues/77764 @parametrize_test("memory_format", [torch.contiguous_format, torch.channels_last]) def test_upsamplingBicubic2d_aa_correctness(self, device, memory_format): t_in = torch.arange(3 * 8 * 8, dtype=torch.float, device=device).reshape(1, 3, 8, 8) @@ -9801,6 +9830,7 @@ def test_upsamplingBicubic2d_aa_correctness(self, device, memory_format): t_out = F.interpolate(t_in, size=(2, 2), mode="bicubic", align_corners=False, antialias=True) self.assertEqual(expected_out, t_out) + @expectedFailureMPS # NotImplementedError: aten::upsample_trilinear3d.out https://github.com/pytorch/pytorch/issues/77764 @parametrize_test("align_corners", [True, False]) @parametrize_test("memory_format", [torch.contiguous_format, torch.channels_last_3d]) def test_upsamplingTrilinear3d(self, device, align_corners, memory_format): @@ -10483,8 +10513,8 @@ def _test_gumbel_softmax_grad(self, device, dtype): tol = 2 * torch.finfo(dtype).eps self.assertEqual(logits_soft.grad, logits_hard.grad, atol=tol, rtol=0) - @skipIfMps @dtypesIfCUDA(torch.half, torch.float, torch.double) + @dtypesIfMPS(torch.float) @dtypes(torch.float, torch.double) def test_gumbel_softmax(self, device, dtype): self._test_gumbel_softmax_st_shapes(device, dtype, shape=[5], dim=0, count_expected=1) @@ -10512,6 +10542,7 @@ def _test_rnn_retain_variables(self, device, dtype): self.assertEqual(grads, grads2) @dtypesIfCUDA(torch.half, torch.float, torch.double) + @dtypesIfMPS(torch.half, torch.float) @dtypes(torch.double) def test_rnn_retain_variables(self, device, dtype): self._test_rnn_retain_variables(device, dtype) @@ -10561,6 +10592,7 @@ def flatten_out(mod, inp): # Merge into OpInfo? @skipMeta # LSTM cell reuses output which was resized + @expectedFailureMPS # TypeError: the MPS framework doesn't support float64 @dtypes(torch.double) def test_LSTM_grad_and_gradgrad(self, device, dtype): hsize = 4 @@ -10570,6 +10602,7 @@ def test_LSTM_grad_and_gradgrad(self, device, dtype): self._test_rnn_mod(mod, inp) @skipMeta # GRU cell reuses output which was resized + @expectedFailureMPS # TypeError: the MPS framework doesn't support float64 @dtypes(torch.double) def test_GRU_grad_and_gradgrad(self, device, dtype): hsize = 4 @@ -10738,6 +10771,7 @@ def _assertEqual_list(self, expected, list_to_compare, atol=None, rtol=None): for ele in list_to_compare: self.assertEqual(expected, ele, atol=atol, rtol=rtol) + @expectedFailureMPS # NotImplementedError: aten::_ctc_loss https://github.com/pytorch/pytorch/issues/77764 @parametrize_test("reduction", ['none', 'mean', 'sum']) @parametrize_test("use_module_form", [True, False]) def test_CTCLoss_no_batch_dim(self, device, reduction, use_module_form): @@ -10868,6 +10902,7 @@ def _test_batchnorm_grad(self, device, dtype=torch.double): _assertGradAndGradgradChecks(self, F.batch_norm, (input, running_mean, running_var, weight, bias, training, 0.1, 0.0001)) + @expectedFailureMPS # TypeError: the MPS framework doesn't support float64 def test_batchnorm_grad(self, device): self._test_batchnorm_grad(device) @@ -10906,6 +10941,7 @@ def test_layernorm_weight_bias(self): out_zero_bias = torch.layer_norm(input, normalized_shape, data, bias, eps) self.assertEqual(out_none_bias, out_zero_bias) + @expectedFailureMPS # TypeError: the MPS framework doesn't support float64 def test_hardsigmoid_grad(self, device): inputs = (torch.randn(4, 16, 16, device=device, dtype=torch.double) - 0.5) * 10 inputs.requires_grad = True @@ -11112,6 +11148,7 @@ def test_grid_sample_nan_inf(self, device, dtype): padding_mode=padding_mode, align_corners=False) self.assertEqual(sample, torch.zeros([1, 1, 1, 2], device=device, dtype=dtype)) + @expectedFailureMPS # NotImplementedError aten::_ctc_loss https://github.com/pytorch/pytorch/issues/77764 def test_CTCLoss_empty_target(self, device): target_lengths = [0, 0, 0] input_lengths = [50, 50, 50] @@ -11132,6 +11169,7 @@ def test_CTCLoss_empty_target(self, device): # Merge into OpInfo? @skipCUDAIf(True, """Test is flaky on Linux and Windows, typical error message: https://github.com/pytorch/pytorch/issues/34870""") + @expectedFailureMPS # NotImplementedError aten::_ctc_loss https://github.com/pytorch/pytorch/issues/77764 def test_ctc_loss(self, device): batch_size = 64 num_labels = 101 @@ -11234,6 +11272,7 @@ def test_ctc_loss_cudnn_tensor(self, device): grad_cudnn, = torch.autograd.grad(loss_cudnn, log_probs, grad_out) self.assertEqual(grad_cudnn, grad_native, atol=1e-4, rtol=0) + @expectedFailureMPS # RuntimeError: LSTM with projections is not currently supported with MPS. @dtypesIfCUDA(torch.half, torch.float, torch.double) @dtypes(torch.float) @tf32_on_and_off(0.005) @@ -11509,6 +11548,7 @@ def test_nll_loss_empty_tensor_reduction_none(self, device): self._nll_loss_helper([2, 3, 5, 0], "none", torch.empty([2, 5, 0], device=device), device) self._nll_loss_helper([2, 3, 5, 7, 0], "none", torch.empty([2, 5, 7, 0], device=device), device) + @expectedFailureMPS # RuntimeError: [srcBuf length] > 0 INTERNAL ASSERT FAILED https://github.com/pytorch/pytorch/issues/134431 def test_nll_loss_empty_tensor_reduction_mean(self, device): nan = torch.tensor(float('nan'), device=device) self._nll_loss_helper([0, 3], "mean", nan, device) @@ -11517,6 +11557,7 @@ def test_nll_loss_empty_tensor_reduction_mean(self, device): self._nll_loss_helper([2, 3, 5, 0], "mean", nan, device) self._nll_loss_helper([2, 3, 5, 7, 0], "mean", nan, device) + @expectedFailureMPS # RuntimeError: [srcBuf length] > 0 INTERNAL ASSERT FAILED https://github.com/pytorch/pytorch/issues/134431 def test_nll_loss_empty_tensor_reduction_sum(self, device): zero = torch.tensor(0, device=device) self._nll_loss_helper([0, 3], "sum", zero, device) @@ -11525,6 +11566,7 @@ def test_nll_loss_empty_tensor_reduction_sum(self, device): self._nll_loss_helper([2, 3, 5, 0], "sum", zero, device) self._nll_loss_helper([2, 3, 5, 7, 0], "sum", zero, device) + @expectedFailureMPS # AssertionError: Expected nan but got 0.0. def test_nll_loss_total_weight_is_zero(self, device): def helper(input_size): @@ -11541,6 +11583,7 @@ def helper(input_size): helper([2, 3, 5, 7]) helper([2, 3, 5, 7, 9]) + @expectedFailureMPS # AssertionError: Expected nan but got 0.0. def test_nll_loss_all_ignored(self, device): def helper(input_size): @@ -11722,6 +11765,7 @@ def test_cross_entropy_label_smoothing_errors(self, device): r"label_smoothing must be between 0\.0"): loss(*input_arg) + @expectedFailureMPS # TypeError: the MPS framework doesn't support float64 @set_default_dtype(torch.double) def test_cross_entropy_label_smoothing_consistent_index_target_and_probs(self, device): N, C = 10, 4 @@ -11877,6 +11921,7 @@ def test_softshrink_negative(self, device): r'lambda must be greater or equal to 0, but found to be -1\.'): m(input) + @expectedFailureMPS # TypeError: the MPS framework doesn't support float64 def test_fold(self, device): def test_dtype(fn, input, dtype): input = input.detach().clone().to(dtype=dtype).requires_grad_(True) @@ -11915,6 +11960,7 @@ def test_logsigmoid_out(self, device): # Check that clip_grad_norm_ raises an error if the total norm of the # parameters' gradients is non-finite + @expectedFailureMPS # TypeError: the MPS framework doesn't support float64 def test_clip_grad_norm_error_if_nonfinite(self, device): norms_pos = [0.1, 1, 2, 3.5, inf] norms_neg = [-0.1, -1, -2, -3.5] @@ -12040,6 +12086,7 @@ def __init__(self) -> None: for p, pe in zip(test_model.parameters(), ref_model.parameters()): self.assertEqual(p.grad.to(devices[0]), pe.grad) + @skipMPSVersionIfLessThan(14, 0) # macOS 13 does not support bfloat16 def test_elu_inplace_overlap(self, device): x = torch.randn((1, 6), dtype=torch.bfloat16, device=device).expand((6, 6)) with self.assertRaisesRegex(RuntimeError, 'unsupported operation'): @@ -12082,6 +12129,7 @@ def test_softplus_inplace_overlap(self, device): with self.assertRaisesRegex(RuntimeError, 'unsupported operation'): F.softplus(x, out=x) + @expectedFailureMPS # TypeError: the MPS framework doesn't support float64 def test_softplus_low_threshold(self, device): # Ensure gradients are computed correctly with a low threshold. model = torch.nn.Softplus(threshold=1).double() @@ -12103,6 +12151,7 @@ def test_leaky_relu_inplace_overlap(self, device): F.leaky_relu_(x) # Merge into OpInfo? + @expectedFailureMPS # NotImplementedError: aten::rrelu_with_noise_ https://github.com/pytorch/pytorch/issues/77764 def test_leaky_relu_inplace_with_neg_slope(self, device): a = torch.tensor([-1., 1.], device=device, requires_grad=True) b = torch.nn.functional.leaky_relu_(a.clone(), -2) @@ -12115,6 +12164,7 @@ def test_leaky_relu_inplace_with_neg_slope(self, device): b.backward(torch.ones(2, device=device)) # Merge into OpInfo? + @skipMPSVersionIfLessThan(14, 0) # macOS 13 does not support bfloat16 def test_leaky_relu_inplace_with_zero_slope(self, device): a = torch.tensor([-2., 0., 2.], device=device, requires_grad=True) b = torch.nn.functional.leaky_relu_(a.clone(), 0.0) @@ -12238,15 +12288,13 @@ def cosine_distance(x, y): self.assertEqual(functional, modular, atol=1e-6, rtol=1e-6) self.assertEqual(traced, modular, atol=1e-6, rtol=1e-6) - def test_to_complex(self, device): + @dtypesIfMPS(torch.cfloat, torch.float) + @dtypes(torch.cfloat, torch.cdouble, torch.float) + def test_to_complex(self, device, dtype): m = nn.Linear(3, 5).to(device) self.assertIs(m, m.to(device)) - m.to(torch.cfloat) - self.assertIs(m.weight.dtype, torch.cfloat) - m.to(torch.cdouble) - self.assertIs(m.weight.dtype, torch.cdouble) - m.to(torch.float) - self.assertIs(m.weight.dtype, torch.float) + m.to(dtype) + self.assertIs(m.weight.dtype, dtype) with warnings.catch_warnings(record=True) as w: # Trigger warning m.to(torch.cfloat) @@ -12255,6 +12303,7 @@ def test_to_complex(self, device): self.assertTrue("Complex modules are a new feature" in str(w[-1].message)) @skipMeta + @dtypesIfMPS(torch.float32) @dtypes(torch.float32, torch.float64) def test_module_to_empty(self, device, dtype): class MyModule(nn.Module): @@ -12338,6 +12387,7 @@ def test_skip_init(self, device): self.assertEqual(m_initialized.weight.device, m_uninitialized.weight.device) self.assertFalse(torch.allclose(m_initialized.weight, m_uninitialized.weight)) + @skipIfMps # TODO(hvaara): Investigate as possible bug. macOS 13 passes, while 14 and 15 fails. @dtypes(torch.float) @dtypesIfCUDA(torch.double, torch.float, torch.half) def test_transformerencoderlayer(self, device, dtype): @@ -12643,6 +12693,7 @@ def perm_fn(x): with cm: _test(activation=activation, batch_first=batch_first, training=training) + @skipIfMps # RuntimeError: foreach=True was passed, but can't use the foreach API on mps tensors @parametrize_test('foreach', (False, True)) def test_clip_grad_value(self, foreach, device): if torch.device(device).type == 'xla' and foreach: @@ -12670,6 +12721,7 @@ def test_clip_grad_value(self, foreach, device): clip_grad_value_([p2], clip_value, foreach=foreach) self.assertEqual(p1.grad, p2.grad) + @skipIfMps # TypeError: the MPS framework doesn't support float64 @parametrize_test('foreach', (False, True)) @parametrize_test('norm_type', (0.5, 1.5, 2, 4, 'inf')) def test_clip_grad_norm(self, norm_type, foreach, device): @@ -12768,6 +12820,7 @@ def test_adaptiveavg_pool1d_shmem(self, device): self.assertEqual(x.grad, x_cpu.grad) @skipMeta + @expectedFailureMPS # NotImplementedError: aten::channel_shuffle https://github.com/pytorch/pytorch/issues/77764 def test_channel_shuffle(self, device): # 3D tensor x = torch.tensor( @@ -12940,7 +12993,8 @@ def __init__(self) -> None: self.assertEqual(list(state_dict.keys()), list(ddp_state_dict.keys())) self.assertEqual(list(state_dict._metadata.keys()), list(ddp_state_dict._metadata.keys())) -instantiate_device_type_tests(TestNNDeviceType, globals()) + +instantiate_device_type_tests(TestNNDeviceType, globals(), allow_mps=True) instantiate_parametrized_tests(TestNN) if __name__ == '__main__': diff --git a/torch/testing/_internal/common_device_type.py b/torch/testing/_internal/common_device_type.py index 709ce32b3ec814..379a0c24e1105e 100644 --- a/torch/testing/_internal/common_device_type.py +++ b/torch/testing/_internal/common_device_type.py @@ -1642,6 +1642,10 @@ def expectedFailureMeta(fn): return skipIfTorchDynamo()(expectedFailure("meta")(fn)) +def expectedFailureMPS(fn): + return expectedFailure("mps")(fn) + + def expectedFailureXLA(fn): return expectedFailure("xla")(fn) From faf363d4a75c6ace880ac18a29d22f4c7e42d972 Mon Sep 17 00:00:00 2001 From: "xinan.lin" Date: Sun, 25 Aug 2024 21:09:30 -0700 Subject: [PATCH 0047/1018] [Inductor UT] Refine test case `test_codegen_upcast_to_fp32_upcast` to pass on XPU. (#134474) [Inductor UT] Refine test case test_codegen_upcast_to_fp32_upcast to pass on XPU. Fix issue: #134476 Pull Request resolved: https://github.com/pytorch/pytorch/pull/134474 Approved by: https://github.com/jansel --- test/inductor/test_torchinductor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/inductor/test_torchinductor.py b/test/inductor/test_torchinductor.py index 4db4fd0a2f8dfa..be2964009a4aec 100644 --- a/test/inductor/test_torchinductor.py +++ b/test/inductor/test_torchinductor.py @@ -11912,7 +11912,7 @@ def func(a, b): with config.patch("triton.codegen_upcast_to_fp32", upcast_to_fp32): func_opt = torch._dynamo.optimize("inductor")(func) code = run_and_get_triton_code(func_opt, *inps) - fp32_cast_in_code = "float32" in code + fp32_cast_in_code = "to(tl.float32)" in code self.assertEqual(fp32_cast_in_code, upcast_to_fp32) @config.patch("triton.use_block_ptr", False) From 12fc85a9b23c2d075fa77110bc4fdb2651411036 Mon Sep 17 00:00:00 2001 From: Xinya Zhang Date: Tue, 27 Aug 2024 00:03:45 +0000 Subject: [PATCH 0048/1018] [ROCm] Prevent accidental enablement of efficient attention. (#133331) Currently Efficient attention and Flash attention share the same set of GPU kernels on ROCM and have common limitations on head sizes. Fixes https://github.com/pytorch/pytorch/issues/132004 Pull Request resolved: https://github.com/pytorch/pytorch/pull/133331 Approved by: https://github.com/malfet, https://github.com/jithunnair-amd --- aten/src/ATen/native/transformers/cuda/sdp_utils.cpp | 7 ++++++- test/test_transformers.py | 7 +++++-- 2 files changed, 11 insertions(+), 3 deletions(-) diff --git a/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp b/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp index 5580194a2aa80b..f093878ee3cda3 100644 --- a/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp +++ b/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp @@ -641,7 +641,12 @@ bool can_use_mem_efficient_attention(sdp_params const& params, bool debug) { check_all_tensors_on_device, check_mem_efficient_hardware_support, check_tensor_shapes, - check_head_dim_size_mem_efficient); +#ifdef USE_ROCM + check_head_dim_size_flash +#else + check_head_dim_size_mem_efficient +#endif + ); for (auto& constraint : general_constraints) { if (!constraint(params, debug)) { return false; diff --git a/test/test_transformers.py b/test/test_transformers.py index d0e3495a686514..f83f4fdce1eba4 100644 --- a/test/test_transformers.py +++ b/test/test_transformers.py @@ -1623,6 +1623,8 @@ def test_invalid_fused_inputs_head_dim(self, device, kernel: SDPBackend): dtype = torch.float16 make_tensor = partial(torch.rand, device=device, dtype=dtype) size = SdpaShape(2, 2, 3, 9) if kernel == SDPBackend.EFFICIENT_ATTENTION else SdpaShape(2, 2, 3, 257) + if TEST_WITH_ROCM: # On ROCM, FA and EA share the backend GPU kernels + size = SdpaShape(2, 2, 3, 257) q, k, v = make_tensor(size), make_tensor(size), make_tensor(size) self.assertRaises(RuntimeError, lambda: torch.nn.functional.scaled_dot_product_attention( q, k, v, None, 0.0, False)) @@ -1665,8 +1667,9 @@ def test_unaligned_tensors(self, device): make_tensor = partial(torch.rand, size, device=device, dtype=dtype) q, k, v = make_tensor(), make_tensor(), make_tensor() with sdpa_kernel(backends=[SDPBackend.EFFICIENT_ATTENTION]): - self.assertRaises(RuntimeError, lambda: torch.nn.functional.scaled_dot_product_attention( - q, k, v, None, 0.0, False)) + ctxmgr = self.assertRaises(RuntimeError) if not TEST_WITH_ROCM else contextlib.nullcontext() + with ctxmgr: + torch.nn.functional.scaled_dot_product_attention(q, k, v, None, 0.0, False) @onlyCUDA @unittest.skipIf(not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Does not support fused SDPA or pre-SM80 hardware") From c17ddd3c166d157c00df12a6cc02bd38de732c32 Mon Sep 17 00:00:00 2001 From: CK Luk Date: Tue, 27 Aug 2024 00:07:28 +0000 Subject: [PATCH 0049/1018] Back out "[Traceable FSDPS] Allow tracing through FSDP2 impl in trace_rules.py (#133532)" (#134478) Summary: Original commit changeset: 0215a41433e9 Original Phabricator Diff: D61432583 D61432583 causes FSDP2 stuck in PT2 compilation when applied to FB-FM-v4. With D61432583: https://www.internalfb.com/mast/job/aps-ckluk-745e763d6a After backing out D61432583: https://www.internalfb.com/mast/job/aps-ckluk-f9604ea1f9 Test Plan: hg graft D61774888 scripts/ckluk/aps/mast_joint_arch_exploration_cmf_updated_fbfm_v3_fsdp2_qps.sh Differential Revision: D61802689 Pull Request resolved: https://github.com/pytorch/pytorch/pull/134478 Approved by: https://github.com/yf225 --- torch/_dynamo/trace_rules.py | 2 -- torch/distributed/_composable/fsdp/_fsdp_state.py | 2 -- 2 files changed, 4 deletions(-) diff --git a/torch/_dynamo/trace_rules.py b/torch/_dynamo/trace_rules.py index fe026ea5037954..2049dc03e0c555 100644 --- a/torch/_dynamo/trace_rules.py +++ b/torch/_dynamo/trace_rules.py @@ -3212,7 +3212,6 @@ def _module_dir(m: types.ModuleType): # we have to add replicate to LEGACY_MOD_INLINELIST to ensure # the forward_hook won't be ignored. "torch.distributed._composable.replicate", - "torch.distributed._composable.fsdp", } @@ -3264,7 +3263,6 @@ def _module_dir(m: types.ModuleType): MOD_INLINELIST.add("torch.distributed") MOD_INLINELIST.add("torch.distributed._functional_collectives") MOD_INLINELIST.add("torch.distributed._composable.replicate") - MOD_INLINELIST.add("torch.distributed._composable.fsdp") @functools.lru_cache(None) diff --git a/torch/distributed/_composable/fsdp/_fsdp_state.py b/torch/distributed/_composable/fsdp/_fsdp_state.py index ceb480fd23939f..4b949c64d9f7f8 100644 --- a/torch/distributed/_composable/fsdp/_fsdp_state.py +++ b/torch/distributed/_composable/fsdp/_fsdp_state.py @@ -351,14 +351,12 @@ def _register_group_forward_hooks( """ modules_set = set(modules) - @disable_if_config_true @functools.wraps(pre_hook) def wrapped_pre_hook(*args: Any, **kwargs: Any): if len(modules_to_run) == 0: # first to run modules_to_run.update(modules_set) return pre_hook(*args, **kwargs) - @disable_if_config_true def get_wrapped_post_hook(module: nn.Module): @functools.wraps(post_hook) def wrapped_post_hook(*args: Any, **kwargs: Any): From 72918fc7834a9fa38f7932affc1f9fecef032ec9 Mon Sep 17 00:00:00 2001 From: Xuehai Pan Date: Tue, 27 Aug 2024 03:56:10 +0800 Subject: [PATCH 0050/1018] [dynamo][itertools] support `itertools.tee` (#133771) Pull Request resolved: https://github.com/pytorch/pytorch/pull/133771 Approved by: https://github.com/jansel --- test/dynamo/test_misc.py | 16 +++++++++++++ torch/_dynamo/polyfills/itertools.py | 34 ++++++++++++++++++++++++++++ torch/_dynamo/polyfills/loader.py | 5 +++- 3 files changed, 54 insertions(+), 1 deletion(-) create mode 100644 torch/_dynamo/polyfills/itertools.py diff --git a/test/dynamo/test_misc.py b/test/dynamo/test_misc.py index 9689e2c12e05a4..dc9e99f770111f 100644 --- a/test/dynamo/test_misc.py +++ b/test/dynamo/test_misc.py @@ -10045,6 +10045,22 @@ def fn(l): self.assertEqual(eager, compiled) self.assertEqual(len(counters["graph_break"]), 0) + def test_itertools_tee(self): + counters.clear() + + def fn(l): + a, b = itertools.tee(l) + return list(a), list(b) + + l = [1, 2, 2, 3, 4, 4, 4, 1, 2] + eager = fn(l) + + compiled_fn = torch._dynamo.optimize(backend="eager", nopython=True)(fn) + compiled = compiled_fn(l) + + self.assertEqual(eager, compiled) + self.assertEqual(len(counters["graph_break"]), 0) + def test_list_iterator_contains(self): def fn(x): it = iter(["my_weight", "not_my_weight"]) diff --git a/torch/_dynamo/polyfills/itertools.py b/torch/_dynamo/polyfills/itertools.py new file mode 100644 index 00000000000000..7207acd6504b51 --- /dev/null +++ b/torch/_dynamo/polyfills/itertools.py @@ -0,0 +1,34 @@ +""" +Python polyfills for itertools +""" + +import itertools +from typing import Iterable, Iterator, Tuple, TypeVar + +from ..decorators import substitute_in_graph + + +__all__ = ["tee"] + + +_T = TypeVar("_T") + + +# Reference: https://docs.python.org/3/library/itertools.html#itertools.tee +@substitute_in_graph(itertools.tee) +def tee(iterable: Iterable[_T], n: int = 2, /) -> Tuple[Iterator[_T], ...]: + iterator = iter(iterable) + shared_link = [None, None] + + def _tee(link) -> Iterator[_T]: # type: ignore[no-untyped-def] + try: + while True: + if link[1] is None: + link[0] = next(iterator) + link[1] = [None, None] + value, link = link + yield value + except StopIteration: + return + + return tuple(_tee(shared_link) for _ in range(n)) diff --git a/torch/_dynamo/polyfills/loader.py b/torch/_dynamo/polyfills/loader.py index 0a7e4aa628ada1..c891490b68fb32 100644 --- a/torch/_dynamo/polyfills/loader.py +++ b/torch/_dynamo/polyfills/loader.py @@ -11,7 +11,10 @@ from types import ModuleType -POLYFILLED_MODULE_NAMES: Tuple[str, ...] = ("builtins",) +POLYFILLED_MODULE_NAMES: Tuple[str, ...] = ( + "builtins", + "itertools", +) POLYFILLED_MODULES: Tuple["ModuleType", ...] = tuple( importlib.import_module(f".{submodule}", package=polyfills.__name__) for submodule in POLYFILLED_MODULE_NAMES From 593a8eee07e77829d1b3c20debf790e13b503f94 Mon Sep 17 00:00:00 2001 From: Xuehai Pan Date: Tue, 27 Aug 2024 03:56:11 +0800 Subject: [PATCH 0051/1018] [dynamo] simplify implementation for `os.fspath` (#133801) Pull Request resolved: https://github.com/pytorch/pytorch/pull/133801 Approved by: https://github.com/anijain2305 ghstack dependencies: #133771 --- torch/_dynamo/polyfills/__init__.py | 12 ---------- torch/_dynamo/polyfills/loader.py | 1 + torch/_dynamo/polyfills/os.py | 36 +++++++++++++++++++++++++++++ torch/_dynamo/variables/misc.py | 14 +---------- 4 files changed, 38 insertions(+), 25 deletions(-) create mode 100644 torch/_dynamo/polyfills/os.py diff --git a/torch/_dynamo/polyfills/__init__.py b/torch/_dynamo/polyfills/__init__.py index 5e51b08ad40fd8..9b465726754bbe 100644 --- a/torch/_dynamo/polyfills/__init__.py +++ b/torch/_dynamo/polyfills/__init__.py @@ -154,15 +154,3 @@ def instantiate_user_defined_class_object(cls, /, *args, **kwargs): if isinstance(obj, cls): obj.__init__(*args, **kwargs) return obj - - -def fspath(path): - # Python equivalent of os.fspath - if isinstance(path, (str, bytes)): - return path - elif hasattr(path, "__fspath__"): - return path.__fspath__() - else: - raise TypeError( - f"expected str, bytes or os.PathLike object, not {type(path).__name__}" - ) diff --git a/torch/_dynamo/polyfills/loader.py b/torch/_dynamo/polyfills/loader.py index c891490b68fb32..377029fbd63c86 100644 --- a/torch/_dynamo/polyfills/loader.py +++ b/torch/_dynamo/polyfills/loader.py @@ -14,6 +14,7 @@ POLYFILLED_MODULE_NAMES: Tuple[str, ...] = ( "builtins", "itertools", + "os", ) POLYFILLED_MODULES: Tuple["ModuleType", ...] = tuple( importlib.import_module(f".{submodule}", package=polyfills.__name__) diff --git a/torch/_dynamo/polyfills/os.py b/torch/_dynamo/polyfills/os.py new file mode 100644 index 00000000000000..013cc81ac30e6f --- /dev/null +++ b/torch/_dynamo/polyfills/os.py @@ -0,0 +1,36 @@ +""" +Python polyfills for os +""" + +from __future__ import annotations + +import os +from typing import AnyStr + +from ..decorators import substitute_in_graph + + +__all__ = ["fspath"] + + +# Copied from os.py in the standard library +@substitute_in_graph(os.fspath) +def fspath(path: AnyStr | os.PathLike[AnyStr]) -> AnyStr: + if isinstance(path, (str, bytes)): + return path + + path_type = type(path) + try: + path_repr = path_type.__fspath__(path) # type: ignore[arg-type] + except AttributeError: + if hasattr(path_type, "__fspath__"): + raise + raise TypeError( + f"expected str, bytes or os.PathLike object, not {path_type.__name__}", + ) from None + if isinstance(path_repr, (str, bytes)): + return path_repr # type: ignore[return-value] + raise TypeError( + f"expected {path_type.__name__}.__fspath__() to return str or bytes, " + f"not {type(path_repr).__name__}", + ) diff --git a/torch/_dynamo/variables/misc.py b/torch/_dynamo/variables/misc.py index b2f8b9ea6cde2f..5d81b1730ecf5c 100644 --- a/torch/_dynamo/variables/misc.py +++ b/torch/_dynamo/variables/misc.py @@ -4,7 +4,6 @@ import functools import inspect import itertools -import os import random import re import sys @@ -15,7 +14,7 @@ import torch._numpy as tnp import torch.utils._pytree as pytree -from .. import config, polyfills, variables +from .. import config, variables from ..bytecode_transformation import create_call_function, create_instruction from ..create_parameter_op import do_not_convert_to_tracable_parameter from ..exc import unimplemented @@ -30,7 +29,6 @@ ) from ..utils import ( check_unspec_or_constant_args, - hashable, identity, is_tensor_base_attr_getter, proxy_args_kwargs, @@ -44,10 +42,6 @@ if TYPE_CHECKING: from torch._dynamo.symbolic_convert import InstructionTranslator -POLYFILL_SUPPORTED_PYTHON_MODULE_METHODS = { - os.fspath: polyfills.fspath, -} - class NO_SUCH_SUBOBJ: pass @@ -1128,12 +1122,6 @@ def var_getattr(self, tx: "InstructionTranslator", name): else: attr_value = self.value.__dict__[name] - if hashable(attr_value): - if polyfill_fn := POLYFILL_SUPPORTED_PYTHON_MODULE_METHODS.get( - attr_value, None - ): - return variables.UserFunctionVariable(polyfill_fn) - if self.source: new_source = AttrSource(self.source, name) return VariableBuilder(tx, new_source)(attr_value) From 62ce05093c57b4bbba419c9d91e21f455c925148 Mon Sep 17 00:00:00 2001 From: Xu Han Date: Tue, 27 Aug 2024 00:16:16 +0000 Subject: [PATCH 0052/1018] [inductor] calibration inductor windows uts (13/N) (#134429) enable Windows inductor UTs for `test/inductor/test_torchinductor_dynamic_shapes.py` Local test pass: image Pull Request resolved: https://github.com/pytorch/pytorch/pull/134429 Approved by: https://github.com/jansel --- test/inductor/test_torchinductor_dynamic_shapes.py | 10 ---------- 1 file changed, 10 deletions(-) diff --git a/test/inductor/test_torchinductor_dynamic_shapes.py b/test/inductor/test_torchinductor_dynamic_shapes.py index 5f97c3c67490d9..476419779fbb02 100644 --- a/test/inductor/test_torchinductor_dynamic_shapes.py +++ b/test/inductor/test_torchinductor_dynamic_shapes.py @@ -27,8 +27,6 @@ ) from torch.testing._internal.common_utils import ( IS_ARM64, - IS_CI, - IS_WINDOWS, parametrize, TEST_CUDA_MEM_LEAK_CHECK, TEST_WITH_ASAN, @@ -37,14 +35,6 @@ from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_CPU, HAS_GPU -if IS_WINDOWS and IS_CI: - sys.stderr.write( - "Windows CI does not have necessary dependencies for test_torchinductor_dynamic_shapes yet\n" - ) - if __name__ == "__main__": - sys.exit(0) - raise unittest.SkipTest("requires sympy/functorch/filelock") - # Make the helper files in test/ importable pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) sys.path.append(pytorch_test_dir) From 005074106d5f4115c648bdec24616aa6478bb407 Mon Sep 17 00:00:00 2001 From: drisspg Date: Mon, 26 Aug 2024 13:13:07 -0700 Subject: [PATCH 0053/1018] [FlexAttention] Fix bug when checking whether to return LSE (#134495) Pull Request resolved: https://github.com/pytorch/pytorch/pull/134495 Approved by: https://github.com/yanboliang, https://github.com/Chillee, https://github.com/BoyuanFeng --- test/inductor/test_flex_attention.py | 17 +++++++++++++++++ torch/nn/attention/flex_attention.py | 17 +++++++++++------ 2 files changed, 28 insertions(+), 6 deletions(-) diff --git a/test/inductor/test_flex_attention.py b/test/inductor/test_flex_attention.py index a657a371081415..ecc2f7d8e2978a 100644 --- a/test/inductor/test_flex_attention.py +++ b/test/inductor/test_flex_attention.py @@ -1624,6 +1624,23 @@ def mask_mod(b, h, q, kv): self.run_test_with_call(attention, Q_S=S - 1, KV_S=S - 1) + @supported_platform + def test_force_write_lse(self): + make_tensor = functools.partial( + torch.randn, + (2, 2, 128, 16), + device="cuda", + dtype=torch.float32, + requires_grad=False, + ) + query, key, value = make_tensor(), make_tensor(), make_tensor() + out_eager, lse_eager = flex_attention(query, key, value, return_lse=True) + + flex_compile = torch.compile(flex_attention, fullgraph=True) + out_compiled, lse_compiled = flex_compile(query, key, value, return_lse=True) + + torch.testing.assert_close(lse_eager, lse_compiled, atol=3e-3, rtol=0) + @supported_platform def test_causal_block_non_divisible_with_captured_buffer(self): Q_S = S - 3 diff --git a/torch/nn/attention/flex_attention.py b/torch/nn/attention/flex_attention.py index 5b93b1e018f6ae..556d85179cb170 100644 --- a/torch/nn/attention/flex_attention.py +++ b/torch/nn/attention/flex_attention.py @@ -824,7 +824,9 @@ def _create_empty_block_mask(query: Tensor, key: Tensor) -> BlockMask: ) -def _apply_kernel_options(query, key, value, kernel_options): +def _apply_kernel_options( + query: Tensor, key: Tensor, value: Tensor, return_lse: bool, kernel_options +): kernel_options = {} if kernel_options is None else dict(kernel_options) kernel_options.setdefault("ROWS_GUARANTEED_SAFE", False) @@ -832,11 +834,13 @@ def _apply_kernel_options(query, key, value, kernel_options): # If foward kernel needs to return logsumexp is decided by this rule internally. assert "OUTPUT_LOGSUMEXP" not in kernel_options - any_inputs_require_grad = ( - query.requires_grad or key.requires_grad or value.requires_grad - ) - output_logsumexp = any_inputs_require_grad and torch.is_grad_enabled() - kernel_options.setdefault("OUTPUT_LOGSUMEXP", output_logsumexp) + kernel_options["OUTPUT_LOGSUMEXP"] = True + if not return_lse: + any_inputs_require_grad = ( + query.requires_grad or key.requires_grad or value.requires_grad + ) + output_logsumexp = any_inputs_require_grad and torch.is_grad_enabled() + kernel_options["OUTPUT_LOGSUMEXP"] = output_logsumexp return kernel_options @@ -957,6 +961,7 @@ def score_mod( query, key, value, + return_lse, kernel_options, ) From 263d03e651efac36aff7f68681127068a6e180e6 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Tue, 27 Aug 2024 00:44:59 +0000 Subject: [PATCH 0054/1018] Revert "[export] enumerate unsupported sympy.Functions (#134271)" This reverts commit ddd71e34797f3bb56a048058e007a2df87c5755f. Reverted https://github.com/pytorch/pytorch/pull/134271 on behalf of https://github.com/facebook-github-bot due to Diff reverted internally ([comment](https://github.com/pytorch/pytorch/pull/134271#issuecomment-2311353460)) --- torch/fx/experimental/symbolic_shapes.py | 31 +++--------------------- torch/fx/passes/runtime_assert.py | 6 ++--- torch/utils/_sympy/functions.py | 7 ------ 3 files changed, 7 insertions(+), 37 deletions(-) diff --git a/torch/fx/experimental/symbolic_shapes.py b/torch/fx/experimental/symbolic_shapes.py index 8098d3b1465cb2..9c7f7ea80ce399 100644 --- a/torch/fx/experimental/symbolic_shapes.py +++ b/torch/fx/experimental/symbolic_shapes.py @@ -1267,14 +1267,13 @@ def _is_supported_equivalence(expr): ) return isinstance(expr, sympy.Symbol) -def _has_uninterpretable_sympy_function(expr) -> bool: - """ - Add functions that our sympy interpreter can't reify into FX nodes - """ +def _has_unsupported_sympy_function(expr) -> bool: return expr.has( torch.utils._sympy.functions.ToFloat, torch.utils._sympy.functions.TruncToInt, torch.utils._sympy.functions.CeilToInt, + # add more sympy functions that involve float<->int conversion here + # since our solver does not know what to do with them ) @dataclass(frozen=True) @@ -1676,14 +1675,6 @@ def __init__( # symbols that are marked dynamic self._marked_dynamic = marked_dynamic - # track supported sympy functions and subtract from list of all sympy functions - self._supported_sympy_functions: Set[sympy.Function] = { - Mod, - PythonMod, - FloorDiv, - } - self._enumerate_sympy_functions() - def rewrite_with_congruences(self, s, expr): """ Eliminate expressions of the form b // d and b % d while adding congruences of the form b % d == k. @@ -1750,20 +1741,6 @@ def floor_div_handler(*args): expr = expr.replace(FloorDiv, floor_div_handler) return expr - def _enumerate_sympy_functions(self): - module = torch.utils._sympy.functions - all_functions = set() - for attr in dir(module): - if isinstance(func := getattr(module, attr), sympy.FunctionClass): - all_functions.add(func) - self._unsupported_sympy_functions = all_functions.difference(self._supported_sympy_functions) - - def _has_unsupported_sympy_function(self, expr) -> bool: - """ - Tracks list of sympy.Functions the export solver doesn't know how to handle. - """ - return expr.has(*self._unsupported_sympy_functions) - def add(self, expr) -> bool: """Add an expression to the set of constraints. @@ -1780,7 +1757,7 @@ def add(self, expr) -> bool: # a fix for this issue, we delay raising such failures. See solve(). if orig_reduced == sympy.false: self._inconsistencies.append(f"{orig_expr} is inconsistent!") - if isinstance(expr, sympy.Ne) or self._has_unsupported_sympy_function(expr): + if isinstance(expr, sympy.Ne) or _has_unsupported_sympy_function(expr): # we're not going to do anything useful with these, so drop them return False free_symbols = expr.free_symbols diff --git a/torch/fx/passes/runtime_assert.py b/torch/fx/passes/runtime_assert.py index e2a7c1bee3db1d..7307fb8c835703 100644 --- a/torch/fx/passes/runtime_assert.py +++ b/torch/fx/passes/runtime_assert.py @@ -93,7 +93,7 @@ def insert_deferred_runtime_asserts( import sympy from torch.fx.experimental.symbolic_shapes import ( - _has_uninterpretable_sympy_function, + _has_unsupported_sympy_function, CallMethodKey, cast_symbool_to_symint_guardless, ConvertIntKey, @@ -136,7 +136,7 @@ def _is_intermediate_tensor_sym_call(node: fx.Node) -> bool: (val := _get_sym_val(node)) is not None and not isinstance(val, sympy.Number) # this holds back from reifying anything in torch.utils._sympy.functions.py that's unsupported - and not _has_uninterpretable_sympy_function(val) + and not _has_unsupported_sympy_function(val) and any( isinstance(arg, fx.Node) and isinstance(_get_example_value(arg), (torch.Tensor, torch.Size)) @@ -195,7 +195,7 @@ def add_runtime_asserts(ras): and _is_bound_expr_for_symbol(ra.expr) ) # don't try to reify sympy functions we can't turn into FX nodes - or _has_uninterpretable_sympy_function(ra.expr) + or _has_unsupported_sympy_function(ra.expr) ): continue diff --git a/torch/utils/_sympy/functions.py b/torch/utils/_sympy/functions.py index 4ed6b4f6b5542a..d54495047e2713 100644 --- a/torch/utils/_sympy/functions.py +++ b/torch/utils/_sympy/functions.py @@ -53,20 +53,13 @@ __all__ = [ "FloorDiv", "ModularIndexing", - "Where", - "PythonMod", - "Mod", "CleanDiv", - "CeilToInt", - "FloorToInt", "CeilDiv", "IntTrueDiv", "FloatTrueDiv", "LShift", "RShift", "IsNonOverlappingAndDenseIndicator", - "TruncToFloat", - "TruncToInt", "RoundToInt", "RoundDecimal", "ToFloat", From b3351980c15ec4d97f6797795fd7f179d7f109a5 Mon Sep 17 00:00:00 2001 From: rzou Date: Mon, 26 Aug 2024 12:56:57 -0700 Subject: [PATCH 0055/1018] Unify lowerings for auto_functionalized and triton_kernel_wrapper_functional (#134466) Fixes https://github.com/pytorch/pytorch/issues/134372 The triton_kernel_wrapper_functional lowering was causing problems (it was generating small kernels with nans in it, probably from realizing aten.empty nodes. Instead of having its own manual lowering, we change triton_kernel_wrapper_functional to go the same route as auto_functionalized where we decompose the node into clone + mutation nodes. Test Plan: - new test - existing tests Pull Request resolved: https://github.com/pytorch/pytorch/pull/134466 Approved by: https://github.com/oulgen, https://github.com/eellison ghstack dependencies: #134364 --- test/inductor/test_triton_kernels.py | 52 ++++++++++++++++++++++++++ torch/_inductor/fx_passes/post_grad.py | 32 ++++++++++++++++ torch/_inductor/lowering.py | 38 +------------------ 3 files changed, 85 insertions(+), 37 deletions(-) diff --git a/test/inductor/test_triton_kernels.py b/test/inductor/test_triton_kernels.py index e2a659c0cc190a..14eec775cdec31 100644 --- a/test/inductor/test_triton_kernels.py +++ b/test/inductor/test_triton_kernels.py @@ -18,6 +18,7 @@ from torch.testing._internal import common_utils from torch.testing._internal.common_utils import skipIfRocm, skipIfXpu, TEST_WITH_ROCM from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_CUDA, HAS_GPU, HAS_XPU +from torch.testing._internal.logging_utils import logs_to_string # Defines all the kernels for tests from torch.testing._internal.triton_utils import * # noqa: F403 @@ -266,6 +267,57 @@ def call_triton_return_view(x: torch.Tensor): self.assertEqual(2 * t_view, compiled_func(t).view(16)) self.assertEqual(2 * t, compiled_func(t)) + @requires_gpu + def test_no_nan_kernels(self): + @triton.jit + def add_one_kernel( + in_ptr0, + out_ptr, + n_elements, + BLOCK_SIZE: "tl.constexpr", + ): + pid = tl.program_id(axis=0) + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_elements + x = tl.load(in_ptr0 + offsets, mask=mask) + output = x + 1 + tl.store(out_ptr + offsets, output, mask=mask) + + def add_one(x, out): + n_elements = x.numel() + add_one_kernel[(n_elements,)](x, out, n_elements, BLOCK_SIZE=4) + + class AddOne(torch.autograd.Function): + @staticmethod + def forward(ctx, x): + out = torch.empty_like(x) + add_one(x, out) + ctx.save_for_backward(out) + return out + + @staticmethod + def backward(ctx, grad): + (saved,) = ctx.saved_tensors + out = torch.empty_like(grad) + add_one(saved, out) + return out + + @torch.compile + def f(x): + return AddOne.apply(x) + + log_stream, ctx = logs_to_string("torch._inductor.codecache", "output_code") + + x = torch.randn(3, requires_grad=True, device=GPU_TYPE) + with ctx(): + y = f(x) + + output_code = "\n".join(log_stream.getvalue().strip().split("\n")[3:]).strip() + self.assertTrue(len(output_code) > 0, msg="output code is not empty") + self.assertEqual(output_code.count('float("nan")'), 0) + self.assertEqual(output_code.count("float('nan')"), 0) + @requires_gpu @common_utils.parametrize("grad_fn", [torch.no_grad, torch.enable_grad]) @common_utils.parametrize("backend", ["eager", "aot_eager", "inductor"]) diff --git a/torch/_inductor/fx_passes/post_grad.py b/torch/_inductor/fx_passes/post_grad.py index 8d2f0d5c7caad7..1d4029b852320f 100644 --- a/torch/_inductor/fx_passes/post_grad.py +++ b/torch/_inductor/fx_passes/post_grad.py @@ -798,6 +798,13 @@ def remove_noop_ops(graph: torch.fx.Graph): def decompose_auto_functionalized(graph): + """Decomposes auto_functionalized and triton_kernel_wrapper_functional + nodes into clones and the underlying mutation node. + + We assume that the reinplacing pass runs before this; the reinplacing pass + tells us (via rewriting the arguments or .meta to those nodes) which + Tensors we should clone and which Tensors are safe to reinplace. + """ graph_pass = PatternMatcherPass() @register_graph_pattern( @@ -822,11 +829,36 @@ def decomp(*flat_args): match.replace_by_example(decomp, flat_args, run_functional_passes=False) + @register_graph_pattern( + CallFunctionVarArgs(torch.ops.higher_order.triton_kernel_wrapper_functional), + pass_dict=graph_pass, + ) + def _(match: Match, *args, **kwargs): + from torch._higher_order_ops.triton_kernel_wrap import ( + triton_kernel_wrapper_functional_dense, + ) + + flat_args, spec = pytree.tree_flatten((args, kwargs)) + + # NB: we combine (args, kwargs) into flat args for replacing. + # This is replace_by_example uses make_fx which does not support + # tracing a function with kwargs. + def decomp(*flat_args): + args, kwargs = pytree.tree_unflatten(flat_args, spec) + return (triton_kernel_wrapper_functional_dense(*args, **kwargs),) + + match.replace_by_example(decomp, flat_args, run_functional_passes=False) + graph_pass.apply(graph) for node in graph.find_nodes( op="call_function", target=torch.ops.higher_order.auto_functionalized ): raise AssertionError("auto_functionalized was not removed") + for node in graph.find_nodes( + op="call_function", + target=torch.ops.higher_order.triton_kernel_wrapper_functional, + ): + raise AssertionError("triton_kernel_wrapper_functional was not removed") @register_lowering_pattern( diff --git a/torch/_inductor/lowering.py b/torch/_inductor/lowering.py index 86872e46a0b4c0..7654792cdd299f 100644 --- a/torch/_inductor/lowering.py +++ b/torch/_inductor/lowering.py @@ -18,10 +18,7 @@ import torch.fx import torch.utils._pytree as pytree from torch._higher_order_ops.associative_scan import associative_scan_op -from torch._higher_order_ops.triton_kernel_wrap import ( - triton_kernel_wrapper_functional, - triton_kernel_wrapper_mutation, -) +from torch._higher_order_ops.triton_kernel_wrap import triton_kernel_wrapper_mutation from torch._prims_common import ( canonicalize_dim, canonicalize_dims, @@ -6174,39 +6171,6 @@ def triton_kernel_wrap_(*, kernel_idx, constant_args_idx, grid, kwargs): return {key: val for key, val in kwargs.items() if isinstance(val, TensorBox)} -@register_lowering(triton_kernel_wrapper_functional) -def triton_kernel_wrap( - *, kernel_idx, constant_args_idx, grid, kwargs, tensors_to_clone -): - new_kwargs = {} - for name, value in kwargs.items(): - if isinstance(value, ir.TensorBox): - x = value.data - has_non_rv_views = False - while isinstance(x, ir.BaseView): - if not isinstance(x, ir.ReinterpretView): - has_non_rv_views = True - break - x = x.data - if has_non_rv_views: - # we realize the inputs wrapped into any view which is not - # ReinterpretView to convert them into ReinterpretView during - # realization; all views being ReinterpretView is assumed by - # the downstream code (e.g., preserving ReinterpretView in - # cloning; layout should be available in mutation marking) - value = ir.TensorBox(ir.ExternKernel.realize_input(value)) - if name in tensors_to_clone: - value = clone_preserve_reinterpret_view(value) - new_kwargs[name] = value - - return triton_kernel_wrap_( - kernel_idx=kernel_idx, - constant_args_idx=constant_args_idx, - grid=grid, - kwargs=new_kwargs, - ) - - @register_lowering(torch.ops.higher_order.cond) def cond(pred, true_fn, false_fn, operands): if is_triton(pred) or any(map(is_triton, operands)): From f0e8a5794f5dbc6929fab4d6c7ce9cfe3b527b2d Mon Sep 17 00:00:00 2001 From: "Edward Z. Yang" Date: Sat, 24 Aug 2024 17:45:31 -0700 Subject: [PATCH 0056/1018] Improve print stack/locals printing in comptime (#133651) Signed-off-by: Edward Z. Yang Pull Request resolved: https://github.com/pytorch/pytorch/pull/133651 Approved by: https://github.com/anijain2305 --- test/dynamo/test_comptime.py | 6 +++--- torch/_dynamo/comptime.py | 6 ++---- 2 files changed, 5 insertions(+), 7 deletions(-) diff --git a/test/dynamo/test_comptime.py b/test/dynamo/test_comptime.py index 61a3a088013ed3..28e8f15c737eb2 100644 --- a/test/dynamo/test_comptime.py +++ b/test/dynamo/test_comptime.py @@ -160,7 +160,7 @@ def f(x): self.assertExpectedInline( FILE.getvalue(), """\ -- TensorVariable() +- FakeTensor(..., size=(2,)) """, ) @@ -186,8 +186,8 @@ def _(ctx): self.assertExpectedInline( FILE.getvalue(), """\ -x = TensorVariable() -y = TensorVariable() +x = FakeTensor(..., size=(2,)) +y = FakeTensor(..., size=(2,)) """, ) diff --git a/torch/_dynamo/comptime.py b/torch/_dynamo/comptime.py index f1a91e358c6d96..972d79d48fa8b2 100644 --- a/torch/_dynamo/comptime.py +++ b/torch/_dynamo/comptime.py @@ -233,10 +233,9 @@ def print_value_stack(self, *, file=None, stacklevel=0): NB: Stack grows downwards in our print """ - # TODO: improve printing tx = self.__get_tx(stacklevel) for s in tx.stack: - print(f"- {s}", file=file) + print(f"- {s.debug_repr()}", file=file) def print_locals(self, *, file=None, stacklevel=0): """ @@ -244,10 +243,9 @@ def print_locals(self, *, file=None, stacklevel=0): By default this view is very limited; you can get more information about any individual local using get_local(). """ - # TODO: improve by improving the VariableTracker printing tx = self.__get_tx(stacklevel) for k, v in tx.symbolic_locals.items(): - print(f"{k} = {v}", file=file) + print(f"{k} = {v.debug_repr()}", file=file) def print_bt(self, *, file=None, stacklevel=0): """ From e8ed816a283e02829237aeefd3d743bef1f3c16f Mon Sep 17 00:00:00 2001 From: "Edward Z. Yang" Date: Sat, 24 Aug 2024 21:09:56 -0400 Subject: [PATCH 0057/1018] Remove unnecessary expect_true in split_with_sizes (#133439) Signed-off-by: Edward Z. Yang Pull Request resolved: https://github.com/pytorch/pytorch/pull/133439 Approved by: https://github.com/albanD --- torch/_decomp/decompositions.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/torch/_decomp/decompositions.py b/torch/_decomp/decompositions.py index e4eb73dc624ed0..bc1ccc6a20bb2d 100644 --- a/torch/_decomp/decompositions.py +++ b/torch/_decomp/decompositions.py @@ -1414,13 +1414,9 @@ def split_with_sizes( start_idx = 0 # Avoid importing sympy at a module level - from torch.fx.experimental.symbolic_shapes import expect_true for i in range(num_splits): length = split_sizes[i] - # We know this is true thanks to the sum, but this assertion helps - # out our internal reasoning - expect_true(start_idx + length <= self.shape[dim]) splits.append(self.narrow(dim, start_idx, length)) start_idx += length return splits From ac8e6429598c9fb00f900c154259883d63ce1c24 Mon Sep 17 00:00:00 2001 From: cyy Date: Tue, 27 Aug 2024 01:48:21 +0000 Subject: [PATCH 0058/1018] [CMake] Remove BUILDING_WITH_TORCH_LIBS (#134434) Since BUILDING_WITH_TORCH_LIBS is not used now. Pull Request resolved: https://github.com/pytorch/pytorch/pull/134434 Approved by: https://github.com/ezyang --- CMakeLists.txt | 4 ---- scripts/build_windows.bat | 4 ---- tools/setup_helpers/cmake.py | 1 - 3 files changed, 9 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 89ef59681bfff4..70ab320adff605 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -496,10 +496,6 @@ if(USE_SYSTEM_LIBS) endif() endif() -# Used when building Caffe2 through setup.py -option(BUILDING_WITH_TORCH_LIBS - "Tell cmake if Caffe2 is being built alongside torch libs" ON) - # /Z7 override option When generating debug symbols, CMake default to use the # flag /Zi. However, it is not compatible with sccache. So we rewrite it off. # But some users don't use sccache; this override is for them. diff --git a/scripts/build_windows.bat b/scripts/build_windows.bat index fbdade94353554..60bfebad08c016 100644 --- a/scripts/build_windows.bat +++ b/scripts/build_windows.bat @@ -22,10 +22,6 @@ if NOT DEFINED BUILD_SHARED_LIBS ( ) ) -IF NOT DEFINED BUILDING_WITH_TORCH_LIBS ( - set BUILDING_WITH_TORCH_LIBS=OFF -) - if NOT DEFINED CAFFE2_STATIC_LINK_CUDA ( set CAFFE2_STATIC_LINK_CUDA=OFF ) diff --git a/tools/setup_helpers/cmake.py b/tools/setup_helpers/cmake.py index d50d9d54739979..4b605fe597505a 100644 --- a/tools/setup_helpers/cmake.py +++ b/tools/setup_helpers/cmake.py @@ -204,7 +204,6 @@ def generate( "UBSAN_FLAGS", "BLAS", "WITH_BLAS", - "BUILDING_WITH_TORCH_LIBS", "CUDA_HOST_COMPILER", "CUDA_NVCC_EXECUTABLE", "CUDA_SEPARABLE_COMPILATION", From 811021b090b61e32d68658929836bb9ffb65e977 Mon Sep 17 00:00:00 2001 From: Xuehai Pan Date: Tue, 27 Aug 2024 10:26:45 +0800 Subject: [PATCH 0059/1018] [BE][Easy][dynamo] ensure `trace_rules.MOD_INLINELIST` in alphabetical order (#134246) Stack from [ghstack](https://github.com/ezyang/ghstack) (oldest at bottom): * __->__ #134246 * #133987 Pull Request resolved: https://github.com/pytorch/pytorch/pull/134246 Approved by: https://github.com/yanboliang --- torch/_dynamo/trace_rules.py | 29 ++++++++++++++--------------- 1 file changed, 14 insertions(+), 15 deletions(-) diff --git a/torch/_dynamo/trace_rules.py b/torch/_dynamo/trace_rules.py index 2049dc03e0c555..87db00ad1f9bab 100644 --- a/torch/_dynamo/trace_rules.py +++ b/torch/_dynamo/trace_rules.py @@ -3218,19 +3218,23 @@ def _module_dir(m: types.ModuleType): # Force inline functions under these modules, even they are in *_SKIPLIST. # We are using python module name instead of file or directory object to avoid circular dependency. # Please keep this sorted alphabetically. -MOD_INLINELIST = { - "torch.utils._python_dispatch", - "torch._refs", - "torch._prims", +MOD_INLINELIST = [ "torch._decomp", "torch._dynamo._trace_wrapped_higher_order_op", "torch._dynamo.comptime", "torch._dynamo.polyfills", - "torch._functorch.vmap", "torch._functorch.autograd_function", - "torch._library.custom_ops", "torch._functorch.eager_transforms", + "torch._functorch.functional_call", + "torch._functorch.vmap", + "torch._higher_order_ops.associative_scan", + "torch._higher_order_ops.strict_mode", + "torch._higher_order_ops.while_loop", "torch._inductor.test_operators", + "torch._library.custom_ops", + "torch._prims", + "torch._refs", + "torch._tensor", "torch.amp.autocast_mode", "torch.ao.nn", "torch.autograd.function", @@ -3245,24 +3249,19 @@ def _module_dir(m: types.ModuleType): "torch.random", "torch.sparse", "torch.testing", - "torch.testing._internal.hypothesis_utils", "torch.utils._content_store", "torch.utils._contextlib", "torch.utils._foreach_utils", + "torch.utils._python_dispatch", "torch.utils._pytree", "torch.utils.hooks", - "torch._tensor", - "torch._higher_order_ops.strict_mode", - "torch._higher_order_ops.while_loop", - "torch._higher_order_ops.associative_scan", - "torch._functorch.functional_call", -} +] +assert sorted(set(MOD_INLINELIST)) == MOD_INLINELIST +MOD_INLINELIST = set(MOD_INLINELIST) if torch.distributed.is_available(): MOD_INLINELIST.add("torch.distributed") - MOD_INLINELIST.add("torch.distributed._functional_collectives") - MOD_INLINELIST.add("torch.distributed._composable.replicate") @functools.lru_cache(None) From 396f26022fcb845107e744f150099123f03f1148 Mon Sep 17 00:00:00 2001 From: Xu Han Date: Tue, 27 Aug 2024 02:33:07 +0000 Subject: [PATCH 0060/1018] [inductor] Turn on UT: test_randint_int64_mod (#134510) It fixed by https://github.com/pytorch/pytorch/pull/134229, turn on it. Pull Request resolved: https://github.com/pytorch/pytorch/pull/134510 Approved by: https://github.com/ezyang --- test/inductor/test_torchinductor.py | 1 - 1 file changed, 1 deletion(-) diff --git a/test/inductor/test_torchinductor.py b/test/inductor/test_torchinductor.py index be2964009a4aec..767d68dfdbb835 100644 --- a/test/inductor/test_torchinductor.py +++ b/test/inductor/test_torchinductor.py @@ -11022,7 +11022,6 @@ def fn(a, b): actual = torch.compile(fn)(a, b) self.assertEqual(ref, actual) - @skipIfWindows(msg="torch._dynamo.exc.BackendCompilerFailed") # TODO: FIX IT def test_randint_int64_mod(self): # This used to not compile due to a wrong return type of randint64_cpu # See https://github.com/pytorch/pytorch/issues/117435 From ff881ec18a659691f7e716f65baa541c819e2a76 Mon Sep 17 00:00:00 2001 From: Shuai Yang Date: Tue, 27 Aug 2024 02:57:14 +0000 Subject: [PATCH 0061/1018] Back out "[dynamo][exception] Support raise exception from None (#134028)" (#134513) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Summary: The original diff is causing the error "attempting to assign a gradient with dtype 'c10::BFloat16' to a tensor with dtype ‘float". The context is in: https://fb.workplace.com/groups/1075192433118967/permalink/1491357138169159/ Test Plan: After reverting, the above issue is gone, details are in https://fb.workplace.com/groups/1075192433118967/permalink/1491357138169159/ Differential Revision: D61820520 Pull Request resolved: https://github.com/pytorch/pytorch/pull/134513 Approved by: https://github.com/anijain2305 --- test/dynamo/test_exceptions.py | 33 --------------- torch/_dynamo/symbolic_convert.py | 69 +++++++++++++------------------ 2 files changed, 28 insertions(+), 74 deletions(-) diff --git a/test/dynamo/test_exceptions.py b/test/dynamo/test_exceptions.py index cd95d98abf30ae..d9b364443eb132 100644 --- a/test/dynamo/test_exceptions.py +++ b/test/dynamo/test_exceptions.py @@ -330,39 +330,6 @@ def fn(x): res = opt_fn(x) self.assertEqual(ref, res) - def test_raise_from_None(self): - # Inspired from os.environ - class MyMapping: - def __init__(self, d): - self._d = d - - def __getitem__(self, key): - try: - value = self._d[key] - except KeyError: - raise KeyError(key) from None - return value - - d = MyMapping({"a": 10, "b": 20}) - - def mapping_get(obj, key, value=None): - try: - return obj.__getitem__(key) - except KeyError: - return value - - def fn(x, d, key): - x = torch.sin(x + 1) - return x, mapping_get(d, key) - - opt_fn = torch.compile(fn, backend="eager", fullgraph=True) - - x = torch.rand(2, 3) - ref = fn(x, d, "m") - res = opt_fn(x, d, "m") - self.assertEqual(ref[0], res[0]) - self.assertEqual(ref[1], res[1]) - if __name__ == "__main__": from torch._dynamo.test_case import run_tests diff --git a/torch/_dynamo/symbolic_convert.py b/torch/_dynamo/symbolic_convert.py index 76660b417af882..0b7d68be08b9e5 100644 --- a/torch/_dynamo/symbolic_convert.py +++ b/torch/_dynamo/symbolic_convert.py @@ -1359,48 +1359,34 @@ def FOR_ITER(self, inst): self.push(ConstantVariable.create(None)) self.jump(inst) - def _raise_exception_variable(self, inst): - val = self.pop() - # User can raise exception in 2 ways - # 1) raise exception type - raise NotImplementedError - # 2) raise execption instance - raise NotImplemetedError("foo") - - # 1) when user raises exception type - if isinstance(val, variables.BuiltinVariable): - # Create the instance of the exception type - # https://github.com/python/cpython/blob/3.11/Python/ceval.c#L6547-L6549 - val = val.call_function(self, [], {}) # type: ignore[arg-type] - - # Save the exception in a global data structure - self.exn_vt_stack.append(val) - - # 2) when user raises exception instance - if isinstance(val, variables.ExceptionVariable): - if val.exc_type is StopIteration: - # StopIteration is used to find the end of iteration while tracing __next__ - raise exc.ObservedUserStopIteration(f"raised exception {val}") - raise exc.ObservedException(f"raised exception {val}") - unimplemented(f"raise {exc}") - def RAISE_VARARGS(self, inst): if inst.arg == 0: unimplemented("re-raise") elif inst.arg == 1: - self._raise_exception_variable(inst) + val = self.pop() + # User can raise exception in 2 ways + # 1) raise exception type - raise NotImplementedError + # 2) raise execption instance - raise NotImplemetedError("foo") + + # 1) when user raises exception type + if isinstance(val, variables.BuiltinVariable): + # Create the instance of the exception type + # https://github.com/python/cpython/blob/3.11/Python/ceval.c#L6547-L6549 + val = val.call_function(self, [], {}) # type: ignore[arg-type] + + # Save the exception in a global data structure + self.exn_vt_stack.append(val) + + # 2) when user raises exception instance + if isinstance(val, variables.ExceptionVariable): + if val.exc_type is StopIteration: + # StopIteration is used to find the end of iteration while tracing __next__ + raise exc.ObservedUserStopIteration(f"raised exception {val}") + raise exc.ObservedException(f"raised exception {val}") + unimplemented(f"raise {exc}") else: - # Support raise .. from None ... Dynamo does not track __cause__ and other attributes of exception. So we - # ignore `from None` part. - from_vt = self.pop() - if isinstance(from_vt, ConstantVariable) and from_vt.value is None: - self._raise_exception_variable(inst) unimplemented("raise ... from ...") - def RERAISE(self, inst): - if sys.version_info >= (3, 11): - # RERAISE is currently supported in a narrow case of `raise ... from None` - self._raise_exception_variable(inst) - unimplemented("RERAISE") - def exception_handler(self, raised_exception): if sys.version_info >= (3, 11): exn_tab_entry = self.current_instruction.exn_tab_entry @@ -1414,6 +1400,10 @@ def exception_handler(self, raised_exception): # 2) if 'lasti' is true, then push the offset that the exception was raised at if exn_tab_entry.lasti: + # This is untested. Any test that tests this end-to-end + # requires supporting more bytecodes. Therefore graph + # breaking for now. + unimplemented("lasti=True while exception handling") self.push( variables.ConstantVariable(self.current_instruction.offset) ) @@ -1445,12 +1435,9 @@ def exception_handler(self, raised_exception): # https://github.com/python/cpython/blob/3.10/Python/ceval.c#L1456 self.popn(3) if len(self.block_stack) == 0: - # No handler found in this frame. Bubble the exception to the parent - # instruction translater. - self.stack.clear() - if type(self) is InstructionTranslator: - raise Unsupported("Observed exception") - raise raised_exception + unimplemented( + "exception is raised when block stack " "is empty" + ) block_stack_entry = self.block_stack.pop() if block_stack_entry.inst.opname != "SETUP_FINALLY": From 30f2cd555c8adcf835c027694f045d516290b994 Mon Sep 17 00:00:00 2001 From: Shivam Raikundalia Date: Tue, 27 Aug 2024 04:55:04 +0000 Subject: [PATCH 0062/1018] [PT2/Profiler] Add Context Info to Torch-Compiled Regions (#132765) Summary: We want to add compile IDs and frames to each Torch-Compiled Region in order to help users cross reference the section they are checking alongside data obtained from tools, such as tlparse. This diff operates on the assumption that each graph section will enter and exit a CompileContext before it is ran to either compile the graph or look it up in the cache. Based on this assuption, we can save the value of the graph section from the exited CompileContext in eval_frame.c using a Python C API. After this, we can create a new interface in cpp shim to wrap around the record_function in order to pass in the new keyword argument for "context". Test Plan: Enhance test_profiler_dynamo_compiled_region to look for kwinputs as well as a name to see that the context is now labeled. Also changed test to run graph with more contexts so that we test a wider range of profiling. Differential Revision: D60803317 Pull Request resolved: https://github.com/pytorch/pytorch/pull/132765 Approved by: https://github.com/anijain2305 --- aten/src/ATen/record_function.h | 2 +- test/dynamo/test_profiler.py | 41 +++++++++++++++++++++++++-------- torch/_C/_dynamo/eval_frame.pyi | 3 ++- torch/_guards.py | 10 ++++++++ torch/csrc/dynamo/cpp_shim.cpp | 21 +++++++++++++++-- torch/csrc/dynamo/cpp_shim.h | 4 +++- torch/csrc/dynamo/eval_frame.c | 22 ++++++++++++++++-- 7 files changed, 87 insertions(+), 16 deletions(-) diff --git a/aten/src/ATen/record_function.h b/aten/src/ATen/record_function.h index 63fbcb55e96d2b..125c1d8c491101 100644 --- a/aten/src/ATen/record_function.h +++ b/aten/src/ATen/record_function.h @@ -319,7 +319,7 @@ struct TORCH_API RecordFunction { if (!isActive()) { return; } - kwinputs_ = *kwargs; + kwinputs_ = std::unordered_map(*kwargs); before(std::move(fn), args, current_sequence_nr); } diff --git a/test/dynamo/test_profiler.py b/test/dynamo/test_profiler.py index 0824a4164b3f81..58395bb7c914fc 100644 --- a/test/dynamo/test_profiler.py +++ b/test/dynamo/test_profiler.py @@ -1,10 +1,12 @@ # Owner(s): ["module: dynamo"] +import logging from unittest.mock import patch import torch import torch._dynamo.test_case import torch._dynamo.testing import torch._dynamo.utils +import torch._logging from torch._dynamo.utils import dynamo_timed from torch.testing._internal.common_utils import TemporaryFileName @@ -163,20 +165,41 @@ def fn(x, y, z): ) def test_profiler_dynamo_compiled_region(self): - def fn(x, y, z): - return x @ y + z + torch._logging.set_logs(dynamo=logging.INFO) - opt_fn = torch._dynamo.optimize("eager")(fn) + def fn(x, y): + r = y.sum(dim=1) + print(r.shape) + return x * r - inputs = [torch.rand(4, 4) for _ in range(3)] + fn_c = torch.compile(fn) - for _ in range(2): - opt_fn(*inputs) + with torch.profiler.profile(record_shapes=True) as prof: + fn_c( + torch.randn(10), + torch.randn(10, 10), + ) - with torch.profiler.profile() as prof: - opt_fn(*inputs) + fn_c( + torch.randn(10), + torch.randn(10, 15), + ) - self.assertTrue(any(e.name == "Torch-Compiled Region" for e in prof.events())) + for e in prof.events(): + if e.name == "Torch-Compiled Region": + print(e.kwinputs) + self.assertTrue( + any( + e.name == "Torch-Compiled Region" and e.kwinputs["context"] == "0/0_1" + for e in prof.events() + ) + ) + self.assertTrue( + any( + e.name == "Torch-Compiled Region" and e.kwinputs["context"] == "1/0" + for e in prof.events() + ) + ) if __name__ == "__main__": diff --git a/torch/_C/_dynamo/eval_frame.pyi b/torch/_C/_dynamo/eval_frame.pyi index 499a66ed4643d4..6c4e4be3be6378 100644 --- a/torch/_C/_dynamo/eval_frame.pyi +++ b/torch/_C/_dynamo/eval_frame.pyi @@ -1,6 +1,6 @@ # mypy: allow-untyped-defs import types -from typing import NewType +from typing import NewType, Tuple from torch._dynamo.types import DynamoCallback, DynamoGuardHook @@ -13,6 +13,7 @@ def reset_code(code: types.CodeType) -> None: ... def unsupported(obj1: object, obj2: object) -> object: ... def skip_code(code: types.CodeType) -> None: ... def set_guard_error_hook(hook: DynamoGuardHook) -> None: ... +def set_context_frame(context: Tuple[int, int, int]) -> None: ... class _CacheEntry: def check_fn(self, *args, **kwargs): ... diff --git a/torch/_guards.py b/torch/_guards.py index ccc07735815bf4..012f26c5bb3ba3 100644 --- a/torch/_guards.py +++ b/torch/_guards.py @@ -26,6 +26,7 @@ TypeVar, ) +from torch._C._dynamo.eval_frame import set_context_frame # noqa: F401 from torch.utils import _pytree as pytree from torch.utils._traceback import CapturedTraceback from torch.utils.weak import WeakTensorKeyDictionary @@ -781,6 +782,15 @@ def compile_context(context: Optional[CompileContext]): try: yield context finally: + if context is not None: + if context.compile_id is not None: + set_context_frame( + ( + context.compile_id.frame_id, + context.compile_id.frame_compile_id, + context.attempt, + ) + ) _TLS.compile_context = old_context diff --git a/torch/csrc/dynamo/cpp_shim.cpp b/torch/csrc/dynamo/cpp_shim.cpp index 35c415fe57425a..84c6e0baaf8d95 100644 --- a/torch/csrc/dynamo/cpp_shim.cpp +++ b/torch/csrc/dynamo/cpp_shim.cpp @@ -1,6 +1,5 @@ -#include - #include +#include struct _PytorchRecordFunctionState { at::RecordFunction guard; @@ -14,6 +13,24 @@ _PytorchRecordFunctionState* _pytorch_record_function_enter(const char* name) { return state; } +static inline _PytorchRecordFunctionState* +_pytorch_record_function_enter_with_kwinputs( + const char* name, + const std::unordered_map* kwargs) { + _PytorchRecordFunctionState* state = new _PytorchRecordFunctionState(); + std::vector args; + state->guard.before(name, &args, kwargs); + return state; +} + +_PytorchRecordFunctionState* _pytorch_record_function_enter_with_context( + const char* name, + const char* context) { + auto map = std::unordered_map(); + map.insert({"context", c10::IValue(context)}); + return _pytorch_record_function_enter_with_kwinputs(name, &map); +} + void _pytorch_record_function_exit(_PytorchRecordFunctionState* state) { if (state == nullptr) { return; diff --git a/torch/csrc/dynamo/cpp_shim.h b/torch/csrc/dynamo/cpp_shim.h index 5baf67805b06c3..b5ec73a3bbfaaa 100644 --- a/torch/csrc/dynamo/cpp_shim.h +++ b/torch/csrc/dynamo/cpp_shim.h @@ -1,5 +1,4 @@ #pragma once - #ifdef __cplusplus extern "C" { #endif @@ -8,6 +7,9 @@ struct _PytorchRecordFunctionState; typedef struct _PytorchRecordFunctionState _PytorchRecordFunctionState; _PytorchRecordFunctionState* _pytorch_record_function_enter(const char* name); +_PytorchRecordFunctionState* _pytorch_record_function_enter_with_context( + const char* name, + const char* context); void _pytorch_record_function_exit(_PytorchRecordFunctionState* state); #ifdef __cplusplus diff --git a/torch/csrc/dynamo/eval_frame.c b/torch/csrc/dynamo/eval_frame.c index 35032f07770844..1803a0e57341cb 100644 --- a/torch/csrc/dynamo/eval_frame.c +++ b/torch/csrc/dynamo/eval_frame.c @@ -10,9 +10,11 @@ #include #include +#define MAX_COMPILE_CONTEXT_SIZE 100 + PyObject* guard_error_hook = NULL; const char* cache_lookup_profiler_str = "TorchDynamo Cache Lookup"; - +static char compile_context[MAX_COMPILE_CONTEXT_SIZE]; static int active_dynamo_threads = 0; static Py_tss_t eval_frame_callback_key = Py_tss_NEEDS_INIT; @@ -483,7 +485,8 @@ inline static PyObject* eval_custom_code( PyCodeObject* code, int throw_flag, int free_vars_copied) { - _PytorchRecordFunctionState* rf = _pytorch_record_function_enter("Torch-Compiled Region"); + const char* trace_id = compile_context; + _PytorchRecordFunctionState* rf = _pytorch_record_function_enter_with_context("Torch-Compiled Region", trace_id); PyObject* result = eval_custom_code_impl( tstate, frame, @@ -817,12 +820,27 @@ static PyObject* set_guard_error_hook(PyObject* dummy, PyObject* obj) { Py_RETURN_NONE; } +static PyObject* set_context_frame(PyObject* dummy, PyObject* obj) { + int frame_id, frame_compile_id, attempt; + if (!PyArg_ParseTuple(obj, "iii", &frame_id, &frame_compile_id, &attempt)) { + PyErr_SetString(PyExc_TypeError, "Expected three integers"); + return NULL; + } + if (attempt == 0) { + sprintf(compile_context, "%d/%d", frame_id, frame_compile_id); + } else { + sprintf(compile_context, "%d/%d_%d", frame_id, frame_compile_id, attempt); + } + Py_RETURN_NONE; +} + static PyMethodDef _methods[] = { {"set_eval_frame", set_eval_frame_py, METH_O, NULL}, {"reset_code", reset_code, METH_O, NULL}, {"unsupported", unsupported, METH_VARARGS, NULL}, {"skip_code", skip_code, METH_O, NULL}, {"set_guard_error_hook", set_guard_error_hook, METH_O, NULL}, + {"set_context_frame", set_context_frame, METH_O, NULL}, {NULL, NULL, 0, NULL}}; static struct PyModuleDef _module = { From 1b3fc91363a9295edc1624e9f5285a8b7338a42a Mon Sep 17 00:00:00 2001 From: Xu Han Date: Tue, 27 Aug 2024 05:43:07 +0000 Subject: [PATCH 0063/1018] [inductor] calibration inductor windows uts (12/N) (#134428) enable Windows inductor UTs for `test/inductor/test_torchinductor_codegen_dynamic_shapes.py` Failed by depends on https://github.com/pytorch/pytorch/pull/134429, need to rebase after https://github.com/pytorch/pytorch/pull/134429 merged. ```cmd 2024-08-25T23:57:23.2747794Z Windows CI does not have necessary dependencies for test_torchinductor_dynamic_shapes yet 2024-08-25T23:57:23.2748541Z Traceback (most recent call last): 2024-08-25T23:57:23.2749593Z File "C:\actions-runner\_work\pytorch\pytorch\test\inductor\test_torchinductor_codegen_dynamic_shapes.py", line 30, in 2024-08-25T23:57:23.2750688Z from inductor.test_torchinductor_dynamic_shapes import ( 2024-08-25T23:57:23.2751877Z File "C:\actions-runner\_work\pytorch\pytorch\test\inductor\test_torchinductor_dynamic_shapes.py", line 46, in 2024-08-25T23:57:23.2752876Z raise unittest.SkipTest("requires sympy/functorch/filelock") 2024-08-25T23:57:23.2753545Z unittest.case.SkipTest: requires sympy/functorch/filelock 2024-08-25T23:57:23.2754077Z Got exit code 1 2024-08-25T23:57:23.2754874Z No stepcurrent file found. Either pytest didn't get to run (e.g. import error) or file got deleted (contact dev infra) ``` Local test pass: image Pull Request resolved: https://github.com/pytorch/pytorch/pull/134428 Approved by: https://github.com/jansel --- .../test_torchinductor_codegen_dynamic_shapes.py | 16 +--------------- 1 file changed, 1 insertion(+), 15 deletions(-) diff --git a/test/inductor/test_torchinductor_codegen_dynamic_shapes.py b/test/inductor/test_torchinductor_codegen_dynamic_shapes.py index e963eabe2571b0..bfadb09344f890 100644 --- a/test/inductor/test_torchinductor_codegen_dynamic_shapes.py +++ b/test/inductor/test_torchinductor_codegen_dynamic_shapes.py @@ -2,17 +2,11 @@ import importlib import os import sys -import unittest import torch from torch._inductor.compile_fx import compile_fx from torch._inductor.test_case import TestCase -from torch.testing._internal.common_utils import ( - IS_CI, - IS_WINDOWS, - TEST_WITH_ASAN, - TEST_WITH_ROCM, -) +from torch.testing._internal.common_utils import TEST_WITH_ASAN, TEST_WITH_ROCM from torch.testing._internal.inductor_utils import ( _check_has_dynamic_shape, GPU_TYPE, @@ -21,14 +15,6 @@ ) -if IS_WINDOWS and IS_CI: - sys.stderr.write( - "Windows CI does not have necessary dependencies for test_torchinductor_codegen_dynamic_shapes yet\n" - ) - if __name__ == "__main__": - sys.exit(0) - raise unittest.SkipTest("requires sympy/functorch/filelock") - importlib.import_module("filelock") # Make the helper files in test/ importable From 136e4435005177642768d6755ba7b4ca520ad48b Mon Sep 17 00:00:00 2001 From: Sathyanarayanan Saravanamuthu Date: Tue, 27 Aug 2024 07:09:39 +0000 Subject: [PATCH 0064/1018] Adding entry-point based support for out-of-tree rendezvous plugins (#132633) Fixes #127519 Currently in torchrun rendezvous, there are only two rendezvous backends supported out of the box: `C10d` and `Etcd`. The changes in this PR enables the distributed elastic users to bring their out-of-tree rendezvous backend implementations as Python packages. #### AUTHORING NEW PLUGIN Any new plugin will be a python package exposing entry-points. For example, the structure of redis plugin is as follows: ``` plugin_root |_ pyproject.toml |_ src |_ redis |_ __init__.py |_ redis_store.py |_ redis_backend.py ``` The contents of the `pyproject.toml` should indicate that this is exposes a torchrun entry-point by mentioning the group name `torchrun.plugins`. The `pyproject.toml` for redis plugin would be as follows: ``` [project] name = "redis" version = "0.0.1" [project.entry-points.'torchrun.plugins'] redis = 'redis' ``` The `src/redis/__init__.py` file would contain functions that return the plugin name and plugin handler. The contents of `__init__.py` for redis would be as follows: ``` def getPluginHandler(): def _create_redis_handler(params: RendezvousParameters): from redis_rendezvous_backend import create_backend backend, store = create_backend(params) return create_handler(store, backend, params) return _create_redis_handler ``` The files `redis_store` and `redis_backend` contain the implementation of [Store](https://github.com/pytorch/pytorch/blob/41189b0da4f793d506c3964d70ea2fe16f06daf6/torch/_C/_distributed_c10d.pyi#L171) and [RendezvousBackend](https://github.com/pytorch/pytorch/blob/e782918b8eeffc465acff40f0c55c8d80bc20ce2/torch/distributed/elastic/rendezvous/dynamic_rendezvous.py#L61) respectively. #### USER EXPERIENCE Before using the plugin for the first time, the user has to install the plugin packages. For example, the published packages can be installed using `pip3 install ` and the plugin is in local file systemcan be installed using `pip3 install -e `. Once installed, the new backend can be used in torchrun as follows: ``` torchrun --rdzv-backend=redis --rdzv-endpoint=redis-container:6379 --nnodes=3 --nproc-per-node=1 --max-restarts=3 --rdzv-id=1 test.py ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/132633 Approved by: https://github.com/wconstab --- .../rendezvous/out_of_tree_rendezvous_test.py | 38 +++++++++++++++++++ .../out_of_tree_test_package/pyproject.toml | 6 +++ .../src/testbackend/__init__.py | 2 + .../elastic/rendezvous/__init__.py | 3 +- .../elastic/rendezvous/registry.py | 25 ++++++++++++ 5 files changed, 73 insertions(+), 1 deletion(-) create mode 100644 test/distributed/elastic/rendezvous/out_of_tree_rendezvous_test.py create mode 100644 test/distributed/elastic/rendezvous/out_of_tree_test_package/pyproject.toml create mode 100644 test/distributed/elastic/rendezvous/out_of_tree_test_package/src/testbackend/__init__.py diff --git a/test/distributed/elastic/rendezvous/out_of_tree_rendezvous_test.py b/test/distributed/elastic/rendezvous/out_of_tree_rendezvous_test.py new file mode 100644 index 00000000000000..4d304bef1bc234 --- /dev/null +++ b/test/distributed/elastic/rendezvous/out_of_tree_rendezvous_test.py @@ -0,0 +1,38 @@ +# Owner(s): ["oncall: r2p"] + +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +import pathlib +import sys +import unittest + +import torch.distributed.elastic.rendezvous as rdvz + + +BACKEND_NAME = "testbackend" +TEST_PACKAGE_PATH = "/out_of_tree_test_package/src" + + +class OutOfTreeRendezvousTest(unittest.TestCase): + def test_out_of_tree_handler_loading(self): + current_path = str(pathlib.Path(__file__).parent.resolve()) + rdvz._register_out_of_tree_handlers() + registry_dict = rdvz.rendezvous_handler_registry._registry + + # test backend should not be registered as a backend + self.assertFalse(BACKEND_NAME in registry_dict) + + # Including testbackend in python path + sys.path.append(current_path + TEST_PACKAGE_PATH) + + # Registering the out of tree handlers again + rdvz._register_out_of_tree_handlers() + + # test backend should be registered as a backend + self.assertTrue(BACKEND_NAME in registry_dict) + + # Removing testbackend from python path + sys.path.remove(current_path + TEST_PACKAGE_PATH) diff --git a/test/distributed/elastic/rendezvous/out_of_tree_test_package/pyproject.toml b/test/distributed/elastic/rendezvous/out_of_tree_test_package/pyproject.toml new file mode 100644 index 00000000000000..32e177bf73fe6b --- /dev/null +++ b/test/distributed/elastic/rendezvous/out_of_tree_test_package/pyproject.toml @@ -0,0 +1,6 @@ +[project] +name = "testbackend" +version = "0.0.1" + +[project.entry-points.'torchrun.handlers'] +testbackend = 'testbackend:test_handler' \ No newline at end of file diff --git a/test/distributed/elastic/rendezvous/out_of_tree_test_package/src/testbackend/__init__.py b/test/distributed/elastic/rendezvous/out_of_tree_test_package/src/testbackend/__init__.py new file mode 100644 index 00000000000000..1fbf5b4c2dfb10 --- /dev/null +++ b/test/distributed/elastic/rendezvous/out_of_tree_test_package/src/testbackend/__init__.py @@ -0,0 +1,2 @@ +def test_handler(): + return "" diff --git a/torch/distributed/elastic/rendezvous/__init__.py b/torch/distributed/elastic/rendezvous/__init__.py index 62a31adab27b01..22ec0c9a0f6758 100644 --- a/torch/distributed/elastic/rendezvous/__init__.py +++ b/torch/distributed/elastic/rendezvous/__init__.py @@ -143,10 +143,11 @@ class that implements the rendezvous mechanism described above. It is a backend- RendezvousStoreInfo, RendezvousTimeoutError, ) -from .registry import _register_default_handlers +from .registry import _register_default_handlers, _register_out_of_tree_handlers _register_default_handlers() +_register_out_of_tree_handlers() __all__ = [ diff --git a/torch/distributed/elastic/rendezvous/registry.py b/torch/distributed/elastic/rendezvous/registry.py index 1a91d0a8ff7946..d038ab95eabae6 100644 --- a/torch/distributed/elastic/rendezvous/registry.py +++ b/torch/distributed/elastic/rendezvous/registry.py @@ -4,6 +4,9 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +import logging +import sys + from .api import ( rendezvous_handler_registry as handler_registry, RendezvousHandler, @@ -12,6 +15,13 @@ from .dynamic_rendezvous import create_handler +if sys.version_info < (3, 10): + from importlib_metadata import entry_points +else: + from importlib.metadata import entry_points + +log = logging.getLogger(__name__) + __all__ = ["get_rendezvous_handler"] @@ -50,6 +60,21 @@ def _register_default_handlers() -> None: handler_registry.register("static", _create_static_handler) +def _register_out_of_tree_handlers() -> None: + discovered_handler_generators = entry_points(group="torchrun.handlers") + + for handler_generator in discovered_handler_generators: + try: + get_handler = discovered_handler_generators[handler_generator.name].load() + handler_registry.register(handler_generator.name, get_handler()) + except Exception: + log.warning( + "Exception while registering out of tree plugin %s: ", + handler_generator.name, + exc_info=True, + ) + + def get_rendezvous_handler(params: RendezvousParameters) -> RendezvousHandler: """ Obtain a reference to a :py:class`RendezvousHandler`. From 858c70f1b1a3f8a6f03ddbb5398d61fc5fa8ae45 Mon Sep 17 00:00:00 2001 From: wizzniu Date: Tue, 27 Aug 2024 07:13:34 +0000 Subject: [PATCH 0065/1018] Fix TypeError when itering NoneType in instantiate_device_type_tests() (#134457) Fixes #134454 Fix TypeError introduced by https://github.com/pytorch/pytorch/pull/133082, which uses iter for NoneType of default args ``except_for`` and ``only_for``. Pull Request resolved: https://github.com/pytorch/pytorch/pull/134457 Approved by: https://github.com/shink, https://github.com/albanD --- torch/testing/_internal/common_device_type.py | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/torch/testing/_internal/common_device_type.py b/torch/testing/_internal/common_device_type.py index 379a0c24e1105e..eeaa31fad1bc25 100644 --- a/torch/testing/_internal/common_device_type.py +++ b/torch/testing/_internal/common_device_type.py @@ -750,12 +750,16 @@ def filter_desired_device_types(device_type_test_bases, except_for=None, only_fo # Replace your privateuse1 backend name with 'privateuse1' if is_privateuse1_backend_available(): privateuse1_backend_name = torch._C._get_privateuse1_backend_name() - except_for = [ - "privateuse1" if x == privateuse1_backend_name else x for x in except_for - ] - only_for = [ - "privateuse1" if x == privateuse1_backend_name else x for x in only_for - ] + except_for = ( + ["privateuse1" if x == privateuse1_backend_name else x for x in except_for] + if except_for is not None + else None + ) + only_for = ( + ["privateuse1" if x == privateuse1_backend_name else x for x in only_for] + if only_for is not None + else None + ) if except_for: device_type_test_bases = filter( From e207d36a7c899ea12d49bc8d49a181fccc692740 Mon Sep 17 00:00:00 2001 From: Prashant Rawat Date: Tue, 27 Aug 2024 08:21:39 +0000 Subject: [PATCH 0066/1018] Fix the warning for cat operators with same qparams (#133999) Summary: Currently the warning is printed when the cat inputs have same qparam, leading to a flood of warnings. This diff emits the warning only when cat inputs don't have the same qparam. Test Plan: CI Reviewed By: aprotopopov Differential Revision: D60638609 Pull Request resolved: https://github.com/pytorch/pytorch/pull/133999 Approved by: https://github.com/tarun292 --- aten/src/ATen/native/quantized/cpu/TensorShape.cpp | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/aten/src/ATen/native/quantized/cpu/TensorShape.cpp b/aten/src/ATen/native/quantized/cpu/TensorShape.cpp index 5fcdd22c034a5f..33334bef8f4346 100644 --- a/aten/src/ATen/native/quantized/cpu/TensorShape.cpp +++ b/aten/src/ATen/native/quantized/cpu/TensorShape.cpp @@ -163,10 +163,11 @@ Tensor cat_quantized_cpu(const ITensorListRef& qxs, int64_t dim) { TORCH_CHECK(is_valid_quantization_scheme(materialized[0]), "Only per-tensor quantization is supported in 'cat'!"); - if (all_inputs_sharing_qparams(materialized)) { + if (!all_inputs_sharing_qparams(materialized)) { // TODO: if possible change this warning to an error T194501002 TORCH_WARN("All inputs of this cat operator must share the same quantization parameters. Otherwise large numerical inaccuracies may occur."); - } check_cat_no_zero_dim(materialized); + } + check_cat_no_zero_dim(materialized); dim = legacy_cat_wrap_dim(dim, materialized); double _scale = materialized[0].get().q_scale(); int64_t _zero_point = materialized[0].get().q_zero_point(); From f2f26142e8f16116a81408eeee8427397d029110 Mon Sep 17 00:00:00 2001 From: Zain Rizvi Date: Tue, 27 Aug 2024 08:51:42 +0000 Subject: [PATCH 0067/1018] Remove explicit Amz2023 reference from jobs (#134355) Changes jobs to go back to using the default AMI. Note: This is only a cleanup PR. It does NOT introduce any behavior changes in CI Now that the default variant uses the Amazon 2023 AMI and has been shown to be stable for a week, it's time to remove the explicit amz2023 references and go back to using the default variant. After a week or two, when this is rolled out to most people, we can remove the variants from scale config as well. Pull Request resolved: https://github.com/pytorch/pytorch/pull/134355 Approved by: https://github.com/jeanschmidt --- .../linux_binary_build_workflow.yml.j2 | 10 +- ...linux-aarch64-binary-manywheel-nightly.yml | 12 -- .../generated-linux-binary-conda-nightly.yml | 64 +++--- ...d-linux-binary-libtorch-cxx11-abi-main.yml | 4 +- ...inux-binary-libtorch-cxx11-abi-nightly.yml | 20 +- ...d-linux-binary-libtorch-pre-cxx11-main.yml | 4 +- ...inux-binary-libtorch-pre-cxx11-nightly.yml | 20 +- .../generated-linux-binary-manywheel-main.yml | 24 +-- ...nerated-linux-binary-manywheel-nightly.yml | 190 +++++++++--------- .github/workflows/inductor.yml | 140 ++++++------- .github/workflows/lint.yml | 14 +- .github/workflows/nightly.yml | 4 +- .github/workflows/periodic.yml | 56 +++--- .github/workflows/pull.yml | 179 ++++++++--------- .github/workflows/slow.yml | 36 ++-- .github/workflows/trunk.yml | 50 ++--- 16 files changed, 405 insertions(+), 422 deletions(-) diff --git a/.github/templates/linux_binary_build_workflow.yml.j2 b/.github/templates/linux_binary_build_workflow.yml.j2 index 1cccee3e234741..918151b486d661 100644 --- a/.github/templates/linux_binary_build_workflow.yml.j2 +++ b/.github/templates/linux_binary_build_workflow.yml.j2 @@ -68,17 +68,16 @@ jobs: needs: get-label-type with:!{{ upload.binary_env_as_input(config) }} {%- if "aarch64" in build_environment %} - runner_prefix: amz2023. runs_on: linux.arm64.m7g.4xlarge.ephemeral ALPINE_IMAGE: "arm64v8/alpine" {%- elif "s390x" in build_environment %} runs_on: linux.s390x ALPINE_IMAGE: "docker.io/s390x/alpine" {%- elif "conda" in build_environment and config["gpu_arch_type"] == "cuda" %} - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" runs_on: linux.24xlarge.ephemeral {%- else %} - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" {%- endif %} build_name: !{{ config["build_name"] }} build_environment: !{{ build_environment }} @@ -103,7 +102,6 @@ jobs: build_name: !{{ config["build_name"] }} build_environment: !{{ build_environment }} {%- if "aarch64" in build_environment %} - runner_prefix: amz2023. runs_on: linux.arm64.2xlarge ALPINE_IMAGE: "arm64v8/alpine" {%- elif "s390x" in build_environment %} @@ -112,10 +110,10 @@ jobs: {%- elif config["gpu_arch_type"] == "rocm" %} runs_on: linux.rocm.gpu {%- elif config["gpu_arch_type"] == "cuda" %} - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" runs_on: linux.4xlarge.nvidia.gpu {%- else %} - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" runs_on: linux.4xlarge {%- endif %} secrets: diff --git a/.github/workflows/generated-linux-aarch64-binary-manywheel-nightly.yml b/.github/workflows/generated-linux-aarch64-binary-manywheel-nightly.yml index 90923ffeb912c0..3ce438c9bc5f09 100644 --- a/.github/workflows/generated-linux-aarch64-binary-manywheel-nightly.yml +++ b/.github/workflows/generated-linux-aarch64-binary-manywheel-nightly.yml @@ -59,7 +59,6 @@ jobs: GPU_ARCH_TYPE: cpu-aarch64 DOCKER_IMAGE: pytorch/manylinuxaarch64-builder:cpu-aarch64-main DESIRED_PYTHON: "3.9" - runner_prefix: amz2023. runs_on: linux.arm64.m7g.4xlarge.ephemeral ALPINE_IMAGE: "arm64v8/alpine" build_name: manywheel-py3_9-cpu-aarch64 @@ -85,7 +84,6 @@ jobs: DESIRED_PYTHON: "3.9" build_name: manywheel-py3_9-cpu-aarch64 build_environment: linux-aarch64-binary-manywheel - runner_prefix: amz2023. runs_on: linux.arm64.2xlarge ALPINE_IMAGE: "arm64v8/alpine" secrets: @@ -128,7 +126,6 @@ jobs: DOCKER_IMAGE: pytorch/manylinuxaarch64-builder:cuda12.4-main DESIRED_DEVTOOLSET: cxx11-abi DESIRED_PYTHON: "3.9" - runner_prefix: amz2023. runs_on: linux.arm64.m7g.4xlarge.ephemeral ALPINE_IMAGE: "arm64v8/alpine" build_name: manywheel-py3_9-cuda-aarch64 @@ -174,7 +171,6 @@ jobs: GPU_ARCH_TYPE: cpu-aarch64 DOCKER_IMAGE: pytorch/manylinuxaarch64-builder:cpu-aarch64-main DESIRED_PYTHON: "3.10" - runner_prefix: amz2023. runs_on: linux.arm64.m7g.4xlarge.ephemeral ALPINE_IMAGE: "arm64v8/alpine" build_name: manywheel-py3_10-cpu-aarch64 @@ -200,7 +196,6 @@ jobs: DESIRED_PYTHON: "3.10" build_name: manywheel-py3_10-cpu-aarch64 build_environment: linux-aarch64-binary-manywheel - runner_prefix: amz2023. runs_on: linux.arm64.2xlarge ALPINE_IMAGE: "arm64v8/alpine" secrets: @@ -243,7 +238,6 @@ jobs: DOCKER_IMAGE: pytorch/manylinuxaarch64-builder:cuda12.4-main DESIRED_DEVTOOLSET: cxx11-abi DESIRED_PYTHON: "3.10" - runner_prefix: amz2023. runs_on: linux.arm64.m7g.4xlarge.ephemeral ALPINE_IMAGE: "arm64v8/alpine" build_name: manywheel-py3_10-cuda-aarch64 @@ -289,7 +283,6 @@ jobs: GPU_ARCH_TYPE: cpu-aarch64 DOCKER_IMAGE: pytorch/manylinuxaarch64-builder:cpu-aarch64-main DESIRED_PYTHON: "3.11" - runner_prefix: amz2023. runs_on: linux.arm64.m7g.4xlarge.ephemeral ALPINE_IMAGE: "arm64v8/alpine" build_name: manywheel-py3_11-cpu-aarch64 @@ -315,7 +308,6 @@ jobs: DESIRED_PYTHON: "3.11" build_name: manywheel-py3_11-cpu-aarch64 build_environment: linux-aarch64-binary-manywheel - runner_prefix: amz2023. runs_on: linux.arm64.2xlarge ALPINE_IMAGE: "arm64v8/alpine" secrets: @@ -358,7 +350,6 @@ jobs: DOCKER_IMAGE: pytorch/manylinuxaarch64-builder:cuda12.4-main DESIRED_DEVTOOLSET: cxx11-abi DESIRED_PYTHON: "3.11" - runner_prefix: amz2023. runs_on: linux.arm64.m7g.4xlarge.ephemeral ALPINE_IMAGE: "arm64v8/alpine" build_name: manywheel-py3_11-cuda-aarch64 @@ -404,7 +395,6 @@ jobs: GPU_ARCH_TYPE: cpu-aarch64 DOCKER_IMAGE: pytorch/manylinuxaarch64-builder:cpu-aarch64-main DESIRED_PYTHON: "3.12" - runner_prefix: amz2023. runs_on: linux.arm64.m7g.4xlarge.ephemeral ALPINE_IMAGE: "arm64v8/alpine" build_name: manywheel-py3_12-cpu-aarch64 @@ -430,7 +420,6 @@ jobs: DESIRED_PYTHON: "3.12" build_name: manywheel-py3_12-cpu-aarch64 build_environment: linux-aarch64-binary-manywheel - runner_prefix: amz2023. runs_on: linux.arm64.2xlarge ALPINE_IMAGE: "arm64v8/alpine" secrets: @@ -473,7 +462,6 @@ jobs: DOCKER_IMAGE: pytorch/manylinuxaarch64-builder:cuda12.4-main DESIRED_DEVTOOLSET: cxx11-abi DESIRED_PYTHON: "3.12" - runner_prefix: amz2023. runs_on: linux.arm64.m7g.4xlarge.ephemeral ALPINE_IMAGE: "arm64v8/alpine" build_name: manywheel-py3_12-cuda-aarch64 diff --git a/.github/workflows/generated-linux-binary-conda-nightly.yml b/.github/workflows/generated-linux-binary-conda-nightly.yml index b8a90c520abfa1..e4451fb1f9b74b 100644 --- a/.github/workflows/generated-linux-binary-conda-nightly.yml +++ b/.github/workflows/generated-linux-binary-conda-nightly.yml @@ -59,7 +59,7 @@ jobs: GPU_ARCH_TYPE: cpu DOCKER_IMAGE: pytorch/conda-builder:cpu-main DESIRED_PYTHON: "3.9" - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: conda-py3_9-cpu build_environment: linux-binary-conda secrets: @@ -82,7 +82,7 @@ jobs: DESIRED_PYTHON: "3.9" build_name: conda-py3_9-cpu build_environment: linux-binary-conda - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" runs_on: linux.4xlarge secrets: github-token: ${{ secrets.GITHUB_TOKEN }} @@ -124,7 +124,7 @@ jobs: GPU_ARCH_TYPE: cuda DOCKER_IMAGE: pytorch/conda-builder:cuda11.8-main DESIRED_PYTHON: "3.9" - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" runs_on: linux.24xlarge.ephemeral build_name: conda-py3_9-cuda11_8 build_environment: linux-binary-conda @@ -149,7 +149,7 @@ jobs: DESIRED_PYTHON: "3.9" build_name: conda-py3_9-cuda11_8 build_environment: linux-binary-conda - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" runs_on: linux.4xlarge.nvidia.gpu secrets: github-token: ${{ secrets.GITHUB_TOKEN }} @@ -192,7 +192,7 @@ jobs: GPU_ARCH_TYPE: cuda DOCKER_IMAGE: pytorch/conda-builder:cuda12.1-main DESIRED_PYTHON: "3.9" - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" runs_on: linux.24xlarge.ephemeral build_name: conda-py3_9-cuda12_1 build_environment: linux-binary-conda @@ -217,7 +217,7 @@ jobs: DESIRED_PYTHON: "3.9" build_name: conda-py3_9-cuda12_1 build_environment: linux-binary-conda - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" runs_on: linux.4xlarge.nvidia.gpu secrets: github-token: ${{ secrets.GITHUB_TOKEN }} @@ -260,7 +260,7 @@ jobs: GPU_ARCH_TYPE: cuda DOCKER_IMAGE: pytorch/conda-builder:cuda12.4-main DESIRED_PYTHON: "3.9" - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" runs_on: linux.24xlarge.ephemeral build_name: conda-py3_9-cuda12_4 build_environment: linux-binary-conda @@ -285,7 +285,7 @@ jobs: DESIRED_PYTHON: "3.9" build_name: conda-py3_9-cuda12_4 build_environment: linux-binary-conda - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" runs_on: linux.4xlarge.nvidia.gpu secrets: github-token: ${{ secrets.GITHUB_TOKEN }} @@ -327,7 +327,7 @@ jobs: GPU_ARCH_TYPE: cpu DOCKER_IMAGE: pytorch/conda-builder:cpu-main DESIRED_PYTHON: "3.10" - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: conda-py3_10-cpu build_environment: linux-binary-conda secrets: @@ -350,7 +350,7 @@ jobs: DESIRED_PYTHON: "3.10" build_name: conda-py3_10-cpu build_environment: linux-binary-conda - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" runs_on: linux.4xlarge secrets: github-token: ${{ secrets.GITHUB_TOKEN }} @@ -392,7 +392,7 @@ jobs: GPU_ARCH_TYPE: cuda DOCKER_IMAGE: pytorch/conda-builder:cuda11.8-main DESIRED_PYTHON: "3.10" - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" runs_on: linux.24xlarge.ephemeral build_name: conda-py3_10-cuda11_8 build_environment: linux-binary-conda @@ -417,7 +417,7 @@ jobs: DESIRED_PYTHON: "3.10" build_name: conda-py3_10-cuda11_8 build_environment: linux-binary-conda - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" runs_on: linux.4xlarge.nvidia.gpu secrets: github-token: ${{ secrets.GITHUB_TOKEN }} @@ -460,7 +460,7 @@ jobs: GPU_ARCH_TYPE: cuda DOCKER_IMAGE: pytorch/conda-builder:cuda12.1-main DESIRED_PYTHON: "3.10" - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" runs_on: linux.24xlarge.ephemeral build_name: conda-py3_10-cuda12_1 build_environment: linux-binary-conda @@ -485,7 +485,7 @@ jobs: DESIRED_PYTHON: "3.10" build_name: conda-py3_10-cuda12_1 build_environment: linux-binary-conda - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" runs_on: linux.4xlarge.nvidia.gpu secrets: github-token: ${{ secrets.GITHUB_TOKEN }} @@ -528,7 +528,7 @@ jobs: GPU_ARCH_TYPE: cuda DOCKER_IMAGE: pytorch/conda-builder:cuda12.4-main DESIRED_PYTHON: "3.10" - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" runs_on: linux.24xlarge.ephemeral build_name: conda-py3_10-cuda12_4 build_environment: linux-binary-conda @@ -553,7 +553,7 @@ jobs: DESIRED_PYTHON: "3.10" build_name: conda-py3_10-cuda12_4 build_environment: linux-binary-conda - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" runs_on: linux.4xlarge.nvidia.gpu secrets: github-token: ${{ secrets.GITHUB_TOKEN }} @@ -595,7 +595,7 @@ jobs: GPU_ARCH_TYPE: cpu DOCKER_IMAGE: pytorch/conda-builder:cpu-main DESIRED_PYTHON: "3.11" - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: conda-py3_11-cpu build_environment: linux-binary-conda secrets: @@ -618,7 +618,7 @@ jobs: DESIRED_PYTHON: "3.11" build_name: conda-py3_11-cpu build_environment: linux-binary-conda - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" runs_on: linux.4xlarge secrets: github-token: ${{ secrets.GITHUB_TOKEN }} @@ -660,7 +660,7 @@ jobs: GPU_ARCH_TYPE: cuda DOCKER_IMAGE: pytorch/conda-builder:cuda11.8-main DESIRED_PYTHON: "3.11" - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" runs_on: linux.24xlarge.ephemeral build_name: conda-py3_11-cuda11_8 build_environment: linux-binary-conda @@ -685,7 +685,7 @@ jobs: DESIRED_PYTHON: "3.11" build_name: conda-py3_11-cuda11_8 build_environment: linux-binary-conda - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" runs_on: linux.4xlarge.nvidia.gpu secrets: github-token: ${{ secrets.GITHUB_TOKEN }} @@ -728,7 +728,7 @@ jobs: GPU_ARCH_TYPE: cuda DOCKER_IMAGE: pytorch/conda-builder:cuda12.1-main DESIRED_PYTHON: "3.11" - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" runs_on: linux.24xlarge.ephemeral build_name: conda-py3_11-cuda12_1 build_environment: linux-binary-conda @@ -753,7 +753,7 @@ jobs: DESIRED_PYTHON: "3.11" build_name: conda-py3_11-cuda12_1 build_environment: linux-binary-conda - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" runs_on: linux.4xlarge.nvidia.gpu secrets: github-token: ${{ secrets.GITHUB_TOKEN }} @@ -796,7 +796,7 @@ jobs: GPU_ARCH_TYPE: cuda DOCKER_IMAGE: pytorch/conda-builder:cuda12.4-main DESIRED_PYTHON: "3.11" - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" runs_on: linux.24xlarge.ephemeral build_name: conda-py3_11-cuda12_4 build_environment: linux-binary-conda @@ -821,7 +821,7 @@ jobs: DESIRED_PYTHON: "3.11" build_name: conda-py3_11-cuda12_4 build_environment: linux-binary-conda - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" runs_on: linux.4xlarge.nvidia.gpu secrets: github-token: ${{ secrets.GITHUB_TOKEN }} @@ -863,7 +863,7 @@ jobs: GPU_ARCH_TYPE: cpu DOCKER_IMAGE: pytorch/conda-builder:cpu-main DESIRED_PYTHON: "3.12" - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: conda-py3_12-cpu build_environment: linux-binary-conda secrets: @@ -886,7 +886,7 @@ jobs: DESIRED_PYTHON: "3.12" build_name: conda-py3_12-cpu build_environment: linux-binary-conda - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" runs_on: linux.4xlarge secrets: github-token: ${{ secrets.GITHUB_TOKEN }} @@ -928,7 +928,7 @@ jobs: GPU_ARCH_TYPE: cuda DOCKER_IMAGE: pytorch/conda-builder:cuda11.8-main DESIRED_PYTHON: "3.12" - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" runs_on: linux.24xlarge.ephemeral build_name: conda-py3_12-cuda11_8 build_environment: linux-binary-conda @@ -953,7 +953,7 @@ jobs: DESIRED_PYTHON: "3.12" build_name: conda-py3_12-cuda11_8 build_environment: linux-binary-conda - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" runs_on: linux.4xlarge.nvidia.gpu secrets: github-token: ${{ secrets.GITHUB_TOKEN }} @@ -996,7 +996,7 @@ jobs: GPU_ARCH_TYPE: cuda DOCKER_IMAGE: pytorch/conda-builder:cuda12.1-main DESIRED_PYTHON: "3.12" - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" runs_on: linux.24xlarge.ephemeral build_name: conda-py3_12-cuda12_1 build_environment: linux-binary-conda @@ -1021,7 +1021,7 @@ jobs: DESIRED_PYTHON: "3.12" build_name: conda-py3_12-cuda12_1 build_environment: linux-binary-conda - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" runs_on: linux.4xlarge.nvidia.gpu secrets: github-token: ${{ secrets.GITHUB_TOKEN }} @@ -1064,7 +1064,7 @@ jobs: GPU_ARCH_TYPE: cuda DOCKER_IMAGE: pytorch/conda-builder:cuda12.4-main DESIRED_PYTHON: "3.12" - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" runs_on: linux.24xlarge.ephemeral build_name: conda-py3_12-cuda12_4 build_environment: linux-binary-conda @@ -1089,7 +1089,7 @@ jobs: DESIRED_PYTHON: "3.12" build_name: conda-py3_12-cuda12_4 build_environment: linux-binary-conda - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" runs_on: linux.4xlarge.nvidia.gpu secrets: github-token: ${{ secrets.GITHUB_TOKEN }} diff --git a/.github/workflows/generated-linux-binary-libtorch-cxx11-abi-main.yml b/.github/workflows/generated-linux-binary-libtorch-cxx11-abi-main.yml index 748d20503d4eac..ad1098bf7d1702 100644 --- a/.github/workflows/generated-linux-binary-libtorch-cxx11-abi-main.yml +++ b/.github/workflows/generated-linux-binary-libtorch-cxx11-abi-main.yml @@ -55,7 +55,7 @@ jobs: DOCKER_IMAGE: pytorch/libtorch-cxx11-builder:cpu-main LIBTORCH_VARIANT: shared-with-deps DESIRED_DEVTOOLSET: cxx11-abi - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: libtorch-cpu-shared-with-deps-cxx11-abi build_environment: linux-binary-libtorch-cxx11-abi secrets: @@ -79,7 +79,7 @@ jobs: DESIRED_DEVTOOLSET: cxx11-abi build_name: libtorch-cpu-shared-with-deps-cxx11-abi build_environment: linux-binary-libtorch-cxx11-abi - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" runs_on: linux.4xlarge secrets: github-token: ${{ secrets.GITHUB_TOKEN }} diff --git a/.github/workflows/generated-linux-binary-libtorch-cxx11-abi-nightly.yml b/.github/workflows/generated-linux-binary-libtorch-cxx11-abi-nightly.yml index 552c1e042b13cd..f4c344cbdbc99b 100644 --- a/.github/workflows/generated-linux-binary-libtorch-cxx11-abi-nightly.yml +++ b/.github/workflows/generated-linux-binary-libtorch-cxx11-abi-nightly.yml @@ -60,7 +60,7 @@ jobs: DOCKER_IMAGE: pytorch/libtorch-cxx11-builder:cpu-main LIBTORCH_VARIANT: shared-with-deps DESIRED_DEVTOOLSET: cxx11-abi - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: libtorch-cpu-shared-with-deps-cxx11-abi build_environment: linux-binary-libtorch-cxx11-abi secrets: @@ -84,7 +84,7 @@ jobs: DESIRED_DEVTOOLSET: cxx11-abi build_name: libtorch-cpu-shared-with-deps-cxx11-abi build_environment: linux-binary-libtorch-cxx11-abi - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" runs_on: linux.4xlarge secrets: github-token: ${{ secrets.GITHUB_TOKEN }} @@ -128,7 +128,7 @@ jobs: DOCKER_IMAGE: pytorch/libtorch-cxx11-builder:cuda11.8-main LIBTORCH_VARIANT: shared-with-deps DESIRED_DEVTOOLSET: cxx11-abi - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: libtorch-cuda11_8-shared-with-deps-cxx11-abi build_environment: linux-binary-libtorch-cxx11-abi secrets: @@ -153,7 +153,7 @@ jobs: DESIRED_DEVTOOLSET: cxx11-abi build_name: libtorch-cuda11_8-shared-with-deps-cxx11-abi build_environment: linux-binary-libtorch-cxx11-abi - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" runs_on: linux.4xlarge.nvidia.gpu secrets: github-token: ${{ secrets.GITHUB_TOKEN }} @@ -198,7 +198,7 @@ jobs: DOCKER_IMAGE: pytorch/libtorch-cxx11-builder:cuda12.1-main LIBTORCH_VARIANT: shared-with-deps DESIRED_DEVTOOLSET: cxx11-abi - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: libtorch-cuda12_1-shared-with-deps-cxx11-abi build_environment: linux-binary-libtorch-cxx11-abi secrets: @@ -223,7 +223,7 @@ jobs: DESIRED_DEVTOOLSET: cxx11-abi build_name: libtorch-cuda12_1-shared-with-deps-cxx11-abi build_environment: linux-binary-libtorch-cxx11-abi - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" runs_on: linux.4xlarge.nvidia.gpu secrets: github-token: ${{ secrets.GITHUB_TOKEN }} @@ -268,7 +268,7 @@ jobs: DOCKER_IMAGE: pytorch/libtorch-cxx11-builder:cuda12.4-main LIBTORCH_VARIANT: shared-with-deps DESIRED_DEVTOOLSET: cxx11-abi - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: libtorch-cuda12_4-shared-with-deps-cxx11-abi build_environment: linux-binary-libtorch-cxx11-abi secrets: @@ -293,7 +293,7 @@ jobs: DESIRED_DEVTOOLSET: cxx11-abi build_name: libtorch-cuda12_4-shared-with-deps-cxx11-abi build_environment: linux-binary-libtorch-cxx11-abi - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" runs_on: linux.4xlarge.nvidia.gpu secrets: github-token: ${{ secrets.GITHUB_TOKEN }} @@ -338,7 +338,7 @@ jobs: DOCKER_IMAGE: pytorch/libtorch-cxx11-builder:rocm6.1-main LIBTORCH_VARIANT: shared-with-deps DESIRED_DEVTOOLSET: cxx11-abi - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: libtorch-rocm6_1-shared-with-deps-cxx11-abi build_environment: linux-binary-libtorch-cxx11-abi secrets: @@ -448,7 +448,7 @@ jobs: DOCKER_IMAGE: pytorch/libtorch-cxx11-builder:rocm6.2-main LIBTORCH_VARIANT: shared-with-deps DESIRED_DEVTOOLSET: cxx11-abi - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: libtorch-rocm6_2-shared-with-deps-cxx11-abi build_environment: linux-binary-libtorch-cxx11-abi secrets: diff --git a/.github/workflows/generated-linux-binary-libtorch-pre-cxx11-main.yml b/.github/workflows/generated-linux-binary-libtorch-pre-cxx11-main.yml index 61847b43eaadfc..06c26961e9894b 100644 --- a/.github/workflows/generated-linux-binary-libtorch-pre-cxx11-main.yml +++ b/.github/workflows/generated-linux-binary-libtorch-pre-cxx11-main.yml @@ -55,7 +55,7 @@ jobs: DOCKER_IMAGE: pytorch/manylinux-builder:cpu-main LIBTORCH_VARIANT: shared-with-deps DESIRED_DEVTOOLSET: pre-cxx11 - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: libtorch-cpu-shared-with-deps-pre-cxx11 build_environment: linux-binary-libtorch-pre-cxx11 secrets: @@ -79,7 +79,7 @@ jobs: DESIRED_DEVTOOLSET: pre-cxx11 build_name: libtorch-cpu-shared-with-deps-pre-cxx11 build_environment: linux-binary-libtorch-pre-cxx11 - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" runs_on: linux.4xlarge secrets: github-token: ${{ secrets.GITHUB_TOKEN }} diff --git a/.github/workflows/generated-linux-binary-libtorch-pre-cxx11-nightly.yml b/.github/workflows/generated-linux-binary-libtorch-pre-cxx11-nightly.yml index e9f99382773c87..23580f30881442 100644 --- a/.github/workflows/generated-linux-binary-libtorch-pre-cxx11-nightly.yml +++ b/.github/workflows/generated-linux-binary-libtorch-pre-cxx11-nightly.yml @@ -60,7 +60,7 @@ jobs: DOCKER_IMAGE: pytorch/manylinux-builder:cpu-main LIBTORCH_VARIANT: shared-with-deps DESIRED_DEVTOOLSET: pre-cxx11 - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: libtorch-cpu-shared-with-deps-pre-cxx11 build_environment: linux-binary-libtorch-pre-cxx11 secrets: @@ -84,7 +84,7 @@ jobs: DESIRED_DEVTOOLSET: pre-cxx11 build_name: libtorch-cpu-shared-with-deps-pre-cxx11 build_environment: linux-binary-libtorch-pre-cxx11 - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" runs_on: linux.4xlarge secrets: github-token: ${{ secrets.GITHUB_TOKEN }} @@ -128,7 +128,7 @@ jobs: DOCKER_IMAGE: pytorch/manylinux-builder:cuda11.8-main LIBTORCH_VARIANT: shared-with-deps DESIRED_DEVTOOLSET: pre-cxx11 - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: libtorch-cuda11_8-shared-with-deps-pre-cxx11 build_environment: linux-binary-libtorch-pre-cxx11 secrets: @@ -153,7 +153,7 @@ jobs: DESIRED_DEVTOOLSET: pre-cxx11 build_name: libtorch-cuda11_8-shared-with-deps-pre-cxx11 build_environment: linux-binary-libtorch-pre-cxx11 - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" runs_on: linux.4xlarge.nvidia.gpu secrets: github-token: ${{ secrets.GITHUB_TOKEN }} @@ -198,7 +198,7 @@ jobs: DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.1-main LIBTORCH_VARIANT: shared-with-deps DESIRED_DEVTOOLSET: pre-cxx11 - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: libtorch-cuda12_1-shared-with-deps-pre-cxx11 build_environment: linux-binary-libtorch-pre-cxx11 secrets: @@ -223,7 +223,7 @@ jobs: DESIRED_DEVTOOLSET: pre-cxx11 build_name: libtorch-cuda12_1-shared-with-deps-pre-cxx11 build_environment: linux-binary-libtorch-pre-cxx11 - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" runs_on: linux.4xlarge.nvidia.gpu secrets: github-token: ${{ secrets.GITHUB_TOKEN }} @@ -268,7 +268,7 @@ jobs: DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.4-main LIBTORCH_VARIANT: shared-with-deps DESIRED_DEVTOOLSET: pre-cxx11 - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: libtorch-cuda12_4-shared-with-deps-pre-cxx11 build_environment: linux-binary-libtorch-pre-cxx11 secrets: @@ -293,7 +293,7 @@ jobs: DESIRED_DEVTOOLSET: pre-cxx11 build_name: libtorch-cuda12_4-shared-with-deps-pre-cxx11 build_environment: linux-binary-libtorch-pre-cxx11 - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" runs_on: linux.4xlarge.nvidia.gpu secrets: github-token: ${{ secrets.GITHUB_TOKEN }} @@ -338,7 +338,7 @@ jobs: DOCKER_IMAGE: pytorch/manylinux-builder:rocm6.1-main LIBTORCH_VARIANT: shared-with-deps DESIRED_DEVTOOLSET: pre-cxx11 - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: libtorch-rocm6_1-shared-with-deps-pre-cxx11 build_environment: linux-binary-libtorch-pre-cxx11 secrets: @@ -448,7 +448,7 @@ jobs: DOCKER_IMAGE: pytorch/manylinux-builder:rocm6.2-main LIBTORCH_VARIANT: shared-with-deps DESIRED_DEVTOOLSET: pre-cxx11 - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: libtorch-rocm6_2-shared-with-deps-pre-cxx11 build_environment: linux-binary-libtorch-pre-cxx11 secrets: diff --git a/.github/workflows/generated-linux-binary-manywheel-main.yml b/.github/workflows/generated-linux-binary-manywheel-main.yml index 15ae6d41909cf6..298c0b5e105ccb 100644 --- a/.github/workflows/generated-linux-binary-manywheel-main.yml +++ b/.github/workflows/generated-linux-binary-manywheel-main.yml @@ -55,7 +55,7 @@ jobs: GPU_ARCH_TYPE: cuda DOCKER_IMAGE: pytorch/manylinux-builder:cuda11.8-main DESIRED_PYTHON: "3.9" - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_9-cuda11_8 build_environment: linux-binary-manywheel PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu11==11.8.89; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu11==11.8.89; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu11==11.8.87; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu11==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu11==11.11.3.6; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu11==10.9.0.58; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu11==10.3.0.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu11==11.4.1.48; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu11==11.7.5.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu11==2.21.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu11==11.8.86; platform_system == 'Linux' and platform_machine == 'x86_64' @@ -80,7 +80,7 @@ jobs: DESIRED_PYTHON: "3.9" build_name: manywheel-py3_9-cuda11_8 build_environment: linux-binary-manywheel - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" runs_on: linux.4xlarge.nvidia.gpu secrets: github-token: ${{ secrets.GITHUB_TOKEN }} @@ -101,7 +101,7 @@ jobs: DOCKER_IMAGE: pytorch/manylinux-builder:cuda11.8-main use_split_build: True DESIRED_PYTHON: "3.9" - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_9-cuda11_8-split build_environment: linux-binary-manywheel PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu11==11.8.89; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu11==11.8.89; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu11==11.8.87; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu11==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu11==11.11.3.6; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu11==10.9.0.58; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu11==10.3.0.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu11==11.4.1.48; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu11==11.7.5.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu11==2.21.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu11==11.8.86; platform_system == 'Linux' and platform_machine == 'x86_64' @@ -127,7 +127,7 @@ jobs: DESIRED_PYTHON: "3.9" build_name: manywheel-py3_9-cuda11_8-split build_environment: linux-binary-manywheel - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" runs_on: linux.4xlarge.nvidia.gpu secrets: github-token: ${{ secrets.GITHUB_TOKEN }} @@ -147,7 +147,7 @@ jobs: GPU_ARCH_TYPE: cuda DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.1-main DESIRED_PYTHON: "3.9" - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_9-cuda12_1 build_environment: linux-binary-manywheel PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.21.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' @@ -172,7 +172,7 @@ jobs: DESIRED_PYTHON: "3.9" build_name: manywheel-py3_9-cuda12_1 build_environment: linux-binary-manywheel - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" runs_on: linux.4xlarge.nvidia.gpu secrets: github-token: ${{ secrets.GITHUB_TOKEN }} @@ -193,7 +193,7 @@ jobs: DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.1-main use_split_build: True DESIRED_PYTHON: "3.9" - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_9-cuda12_1-split build_environment: linux-binary-manywheel PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.21.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' @@ -219,7 +219,7 @@ jobs: DESIRED_PYTHON: "3.9" build_name: manywheel-py3_9-cuda12_1-split build_environment: linux-binary-manywheel - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" runs_on: linux.4xlarge.nvidia.gpu secrets: github-token: ${{ secrets.GITHUB_TOKEN }} @@ -239,7 +239,7 @@ jobs: GPU_ARCH_TYPE: cuda DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.4-main DESIRED_PYTHON: "3.9" - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_9-cuda12_4 build_environment: linux-binary-manywheel PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.4.5.8; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.2.1.3; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.5.147; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.6.1.9; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.3.1.170; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.21.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' @@ -264,7 +264,7 @@ jobs: DESIRED_PYTHON: "3.9" build_name: manywheel-py3_9-cuda12_4 build_environment: linux-binary-manywheel - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" runs_on: linux.4xlarge.nvidia.gpu secrets: github-token: ${{ secrets.GITHUB_TOKEN }} @@ -285,7 +285,7 @@ jobs: DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.4-main use_split_build: True DESIRED_PYTHON: "3.9" - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_9-cuda12_4-split build_environment: linux-binary-manywheel PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.4.5.8; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.2.1.3; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.5.147; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.6.1.9; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.3.1.170; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.21.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' @@ -311,7 +311,7 @@ jobs: DESIRED_PYTHON: "3.9" build_name: manywheel-py3_9-cuda12_4-split build_environment: linux-binary-manywheel - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" runs_on: linux.4xlarge.nvidia.gpu secrets: github-token: ${{ secrets.GITHUB_TOKEN }} diff --git a/.github/workflows/generated-linux-binary-manywheel-nightly.yml b/.github/workflows/generated-linux-binary-manywheel-nightly.yml index c769a302d6a604..e88daa947a3d19 100644 --- a/.github/workflows/generated-linux-binary-manywheel-nightly.yml +++ b/.github/workflows/generated-linux-binary-manywheel-nightly.yml @@ -59,7 +59,7 @@ jobs: GPU_ARCH_TYPE: cpu DOCKER_IMAGE: pytorch/manylinux-builder:cpu-main DESIRED_PYTHON: "3.9" - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_9-cpu build_environment: linux-binary-manywheel secrets: @@ -82,7 +82,7 @@ jobs: DESIRED_PYTHON: "3.9" build_name: manywheel-py3_9-cpu build_environment: linux-binary-manywheel - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" runs_on: linux.4xlarge secrets: github-token: ${{ secrets.GITHUB_TOKEN }} @@ -124,7 +124,7 @@ jobs: DOCKER_IMAGE: pytorch/manylinuxcxx11-abi-builder:cpu-cxx11-abi-main DESIRED_DEVTOOLSET: cxx11-abi DESIRED_PYTHON: "3.9" - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_9-cpu-cxx11-abi build_environment: linux-binary-manywheel secrets: @@ -148,7 +148,7 @@ jobs: DESIRED_PYTHON: "3.9" build_name: manywheel-py3_9-cpu-cxx11-abi build_environment: linux-binary-manywheel - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" runs_on: linux.4xlarge secrets: github-token: ${{ secrets.GITHUB_TOKEN }} @@ -191,7 +191,7 @@ jobs: GPU_ARCH_TYPE: cuda DOCKER_IMAGE: pytorch/manylinux-builder:cuda11.8-main DESIRED_PYTHON: "3.9" - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_9-cuda11_8 build_environment: linux-binary-manywheel PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu11==11.8.89; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu11==11.8.89; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu11==11.8.87; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu11==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu11==11.11.3.6; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu11==10.9.0.58; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu11==10.3.0.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu11==11.4.1.48; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu11==11.7.5.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu11==2.21.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu11==11.8.86; platform_system == 'Linux' and platform_machine == 'x86_64' @@ -216,7 +216,7 @@ jobs: DESIRED_PYTHON: "3.9" build_name: manywheel-py3_9-cuda11_8 build_environment: linux-binary-manywheel - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" runs_on: linux.4xlarge.nvidia.gpu secrets: github-token: ${{ secrets.GITHUB_TOKEN }} @@ -260,7 +260,7 @@ jobs: DOCKER_IMAGE: pytorch/manylinux-builder:cuda11.8-main use_split_build: True DESIRED_PYTHON: "3.9" - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_9-cuda11_8-split build_environment: linux-binary-manywheel PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu11==11.8.89; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu11==11.8.89; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu11==11.8.87; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu11==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu11==11.11.3.6; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu11==10.9.0.58; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu11==10.3.0.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu11==11.4.1.48; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu11==11.7.5.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu11==2.21.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu11==11.8.86; platform_system == 'Linux' and platform_machine == 'x86_64' @@ -286,7 +286,7 @@ jobs: DESIRED_PYTHON: "3.9" build_name: manywheel-py3_9-cuda11_8-split build_environment: linux-binary-manywheel - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" runs_on: linux.4xlarge.nvidia.gpu secrets: github-token: ${{ secrets.GITHUB_TOKEN }} @@ -330,7 +330,7 @@ jobs: GPU_ARCH_TYPE: cuda DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.1-main DESIRED_PYTHON: "3.9" - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_9-cuda12_1 build_environment: linux-binary-manywheel PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.21.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' @@ -355,7 +355,7 @@ jobs: DESIRED_PYTHON: "3.9" build_name: manywheel-py3_9-cuda12_1 build_environment: linux-binary-manywheel - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" runs_on: linux.4xlarge.nvidia.gpu secrets: github-token: ${{ secrets.GITHUB_TOKEN }} @@ -399,7 +399,7 @@ jobs: DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.1-main use_split_build: True DESIRED_PYTHON: "3.9" - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_9-cuda12_1-split build_environment: linux-binary-manywheel PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.21.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' @@ -425,7 +425,7 @@ jobs: DESIRED_PYTHON: "3.9" build_name: manywheel-py3_9-cuda12_1-split build_environment: linux-binary-manywheel - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" runs_on: linux.4xlarge.nvidia.gpu secrets: github-token: ${{ secrets.GITHUB_TOKEN }} @@ -469,7 +469,7 @@ jobs: GPU_ARCH_TYPE: cuda DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.4-main DESIRED_PYTHON: "3.9" - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_9-cuda12_4 build_environment: linux-binary-manywheel PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.4.5.8; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.2.1.3; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.5.147; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.6.1.9; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.3.1.170; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.21.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' @@ -494,7 +494,7 @@ jobs: DESIRED_PYTHON: "3.9" build_name: manywheel-py3_9-cuda12_4 build_environment: linux-binary-manywheel - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" runs_on: linux.4xlarge.nvidia.gpu secrets: github-token: ${{ secrets.GITHUB_TOKEN }} @@ -538,7 +538,7 @@ jobs: DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.4-main use_split_build: True DESIRED_PYTHON: "3.9" - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_9-cuda12_4-split build_environment: linux-binary-manywheel PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.4.5.8; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.2.1.3; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.5.147; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.6.1.9; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.3.1.170; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.21.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' @@ -564,7 +564,7 @@ jobs: DESIRED_PYTHON: "3.9" build_name: manywheel-py3_9-cuda12_4-split build_environment: linux-binary-manywheel - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" runs_on: linux.4xlarge.nvidia.gpu secrets: github-token: ${{ secrets.GITHUB_TOKEN }} @@ -608,7 +608,7 @@ jobs: GPU_ARCH_TYPE: rocm DOCKER_IMAGE: pytorch/manylinux-builder:rocm6.1-main DESIRED_PYTHON: "3.9" - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_9-rocm6_1 build_environment: linux-binary-manywheel secrets: @@ -715,7 +715,7 @@ jobs: GPU_ARCH_TYPE: rocm DOCKER_IMAGE: pytorch/manylinux-builder:rocm6.2-main DESIRED_PYTHON: "3.9" - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_9-rocm6_2 build_environment: linux-binary-manywheel secrets: @@ -821,7 +821,7 @@ jobs: GPU_ARCH_TYPE: xpu DOCKER_IMAGE: pytorch/manylinux2_28-builder:xpu-main DESIRED_PYTHON: "3.9" - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_9-xpu build_environment: linux-binary-manywheel secrets: @@ -934,7 +934,7 @@ jobs: GPU_ARCH_TYPE: cpu DOCKER_IMAGE: pytorch/manylinux-builder:cpu-main DESIRED_PYTHON: "3.10" - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_10-cpu build_environment: linux-binary-manywheel secrets: @@ -957,7 +957,7 @@ jobs: DESIRED_PYTHON: "3.10" build_name: manywheel-py3_10-cpu build_environment: linux-binary-manywheel - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" runs_on: linux.4xlarge secrets: github-token: ${{ secrets.GITHUB_TOKEN }} @@ -999,7 +999,7 @@ jobs: DOCKER_IMAGE: pytorch/manylinuxcxx11-abi-builder:cpu-cxx11-abi-main DESIRED_DEVTOOLSET: cxx11-abi DESIRED_PYTHON: "3.10" - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_10-cpu-cxx11-abi build_environment: linux-binary-manywheel secrets: @@ -1023,7 +1023,7 @@ jobs: DESIRED_PYTHON: "3.10" build_name: manywheel-py3_10-cpu-cxx11-abi build_environment: linux-binary-manywheel - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" runs_on: linux.4xlarge secrets: github-token: ${{ secrets.GITHUB_TOKEN }} @@ -1066,7 +1066,7 @@ jobs: GPU_ARCH_TYPE: cuda DOCKER_IMAGE: pytorch/manylinux-builder:cuda11.8-main DESIRED_PYTHON: "3.10" - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_10-cuda11_8 build_environment: linux-binary-manywheel PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu11==11.8.89; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu11==11.8.89; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu11==11.8.87; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu11==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu11==11.11.3.6; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu11==10.9.0.58; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu11==10.3.0.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu11==11.4.1.48; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu11==11.7.5.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu11==2.21.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu11==11.8.86; platform_system == 'Linux' and platform_machine == 'x86_64' @@ -1091,7 +1091,7 @@ jobs: DESIRED_PYTHON: "3.10" build_name: manywheel-py3_10-cuda11_8 build_environment: linux-binary-manywheel - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" runs_on: linux.4xlarge.nvidia.gpu secrets: github-token: ${{ secrets.GITHUB_TOKEN }} @@ -1135,7 +1135,7 @@ jobs: DOCKER_IMAGE: pytorch/manylinux-builder:cuda11.8-main use_split_build: True DESIRED_PYTHON: "3.10" - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_10-cuda11_8-split build_environment: linux-binary-manywheel PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu11==11.8.89; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu11==11.8.89; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu11==11.8.87; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu11==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu11==11.11.3.6; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu11==10.9.0.58; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu11==10.3.0.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu11==11.4.1.48; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu11==11.7.5.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu11==2.21.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu11==11.8.86; platform_system == 'Linux' and platform_machine == 'x86_64' @@ -1161,7 +1161,7 @@ jobs: DESIRED_PYTHON: "3.10" build_name: manywheel-py3_10-cuda11_8-split build_environment: linux-binary-manywheel - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" runs_on: linux.4xlarge.nvidia.gpu secrets: github-token: ${{ secrets.GITHUB_TOKEN }} @@ -1205,7 +1205,7 @@ jobs: GPU_ARCH_TYPE: cuda DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.1-main DESIRED_PYTHON: "3.10" - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_10-cuda12_1 build_environment: linux-binary-manywheel PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.21.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' @@ -1230,7 +1230,7 @@ jobs: DESIRED_PYTHON: "3.10" build_name: manywheel-py3_10-cuda12_1 build_environment: linux-binary-manywheel - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" runs_on: linux.4xlarge.nvidia.gpu secrets: github-token: ${{ secrets.GITHUB_TOKEN }} @@ -1274,7 +1274,7 @@ jobs: DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.1-main use_split_build: True DESIRED_PYTHON: "3.10" - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_10-cuda12_1-split build_environment: linux-binary-manywheel PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.21.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' @@ -1300,7 +1300,7 @@ jobs: DESIRED_PYTHON: "3.10" build_name: manywheel-py3_10-cuda12_1-split build_environment: linux-binary-manywheel - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" runs_on: linux.4xlarge.nvidia.gpu secrets: github-token: ${{ secrets.GITHUB_TOKEN }} @@ -1345,7 +1345,7 @@ jobs: DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.1-main use_split_build: False DESIRED_PYTHON: "3.10" - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_10-cuda12_1-full build_environment: linux-binary-manywheel secrets: @@ -1370,7 +1370,7 @@ jobs: DESIRED_PYTHON: "3.10" build_name: manywheel-py3_10-cuda12_1-full build_environment: linux-binary-manywheel - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" runs_on: linux.4xlarge.nvidia.gpu secrets: github-token: ${{ secrets.GITHUB_TOKEN }} @@ -1414,7 +1414,7 @@ jobs: GPU_ARCH_TYPE: cuda DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.4-main DESIRED_PYTHON: "3.10" - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_10-cuda12_4 build_environment: linux-binary-manywheel PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.4.5.8; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.2.1.3; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.5.147; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.6.1.9; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.3.1.170; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.21.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' @@ -1439,7 +1439,7 @@ jobs: DESIRED_PYTHON: "3.10" build_name: manywheel-py3_10-cuda12_4 build_environment: linux-binary-manywheel - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" runs_on: linux.4xlarge.nvidia.gpu secrets: github-token: ${{ secrets.GITHUB_TOKEN }} @@ -1483,7 +1483,7 @@ jobs: DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.4-main use_split_build: True DESIRED_PYTHON: "3.10" - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_10-cuda12_4-split build_environment: linux-binary-manywheel PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.4.5.8; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.2.1.3; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.5.147; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.6.1.9; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.3.1.170; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.21.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' @@ -1509,7 +1509,7 @@ jobs: DESIRED_PYTHON: "3.10" build_name: manywheel-py3_10-cuda12_4-split build_environment: linux-binary-manywheel - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" runs_on: linux.4xlarge.nvidia.gpu secrets: github-token: ${{ secrets.GITHUB_TOKEN }} @@ -1553,7 +1553,7 @@ jobs: GPU_ARCH_TYPE: rocm DOCKER_IMAGE: pytorch/manylinux-builder:rocm6.1-main DESIRED_PYTHON: "3.10" - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_10-rocm6_1 build_environment: linux-binary-manywheel secrets: @@ -1660,7 +1660,7 @@ jobs: GPU_ARCH_TYPE: rocm DOCKER_IMAGE: pytorch/manylinux-builder:rocm6.2-main DESIRED_PYTHON: "3.10" - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_10-rocm6_2 build_environment: linux-binary-manywheel secrets: @@ -1766,7 +1766,7 @@ jobs: GPU_ARCH_TYPE: xpu DOCKER_IMAGE: pytorch/manylinux2_28-builder:xpu-main DESIRED_PYTHON: "3.10" - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_10-xpu build_environment: linux-binary-manywheel secrets: @@ -1879,7 +1879,7 @@ jobs: GPU_ARCH_TYPE: cpu DOCKER_IMAGE: pytorch/manylinux-builder:cpu-main DESIRED_PYTHON: "3.11" - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_11-cpu build_environment: linux-binary-manywheel secrets: @@ -1902,7 +1902,7 @@ jobs: DESIRED_PYTHON: "3.11" build_name: manywheel-py3_11-cpu build_environment: linux-binary-manywheel - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" runs_on: linux.4xlarge secrets: github-token: ${{ secrets.GITHUB_TOKEN }} @@ -1944,7 +1944,7 @@ jobs: DOCKER_IMAGE: pytorch/manylinuxcxx11-abi-builder:cpu-cxx11-abi-main DESIRED_DEVTOOLSET: cxx11-abi DESIRED_PYTHON: "3.11" - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_11-cpu-cxx11-abi build_environment: linux-binary-manywheel secrets: @@ -1968,7 +1968,7 @@ jobs: DESIRED_PYTHON: "3.11" build_name: manywheel-py3_11-cpu-cxx11-abi build_environment: linux-binary-manywheel - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" runs_on: linux.4xlarge secrets: github-token: ${{ secrets.GITHUB_TOKEN }} @@ -2011,7 +2011,7 @@ jobs: GPU_ARCH_TYPE: cuda DOCKER_IMAGE: pytorch/manylinux-builder:cuda11.8-main DESIRED_PYTHON: "3.11" - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_11-cuda11_8 build_environment: linux-binary-manywheel PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu11==11.8.89; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu11==11.8.89; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu11==11.8.87; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu11==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu11==11.11.3.6; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu11==10.9.0.58; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu11==10.3.0.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu11==11.4.1.48; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu11==11.7.5.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu11==2.21.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu11==11.8.86; platform_system == 'Linux' and platform_machine == 'x86_64' @@ -2036,7 +2036,7 @@ jobs: DESIRED_PYTHON: "3.11" build_name: manywheel-py3_11-cuda11_8 build_environment: linux-binary-manywheel - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" runs_on: linux.4xlarge.nvidia.gpu secrets: github-token: ${{ secrets.GITHUB_TOKEN }} @@ -2080,7 +2080,7 @@ jobs: DOCKER_IMAGE: pytorch/manylinux-builder:cuda11.8-main use_split_build: True DESIRED_PYTHON: "3.11" - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_11-cuda11_8-split build_environment: linux-binary-manywheel PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu11==11.8.89; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu11==11.8.89; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu11==11.8.87; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu11==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu11==11.11.3.6; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu11==10.9.0.58; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu11==10.3.0.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu11==11.4.1.48; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu11==11.7.5.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu11==2.21.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu11==11.8.86; platform_system == 'Linux' and platform_machine == 'x86_64' @@ -2106,7 +2106,7 @@ jobs: DESIRED_PYTHON: "3.11" build_name: manywheel-py3_11-cuda11_8-split build_environment: linux-binary-manywheel - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" runs_on: linux.4xlarge.nvidia.gpu secrets: github-token: ${{ secrets.GITHUB_TOKEN }} @@ -2150,7 +2150,7 @@ jobs: GPU_ARCH_TYPE: cuda DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.1-main DESIRED_PYTHON: "3.11" - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_11-cuda12_1 build_environment: linux-binary-manywheel PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.21.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' @@ -2175,7 +2175,7 @@ jobs: DESIRED_PYTHON: "3.11" build_name: manywheel-py3_11-cuda12_1 build_environment: linux-binary-manywheel - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" runs_on: linux.4xlarge.nvidia.gpu secrets: github-token: ${{ secrets.GITHUB_TOKEN }} @@ -2219,7 +2219,7 @@ jobs: DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.1-main use_split_build: True DESIRED_PYTHON: "3.11" - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_11-cuda12_1-split build_environment: linux-binary-manywheel PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.21.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' @@ -2245,7 +2245,7 @@ jobs: DESIRED_PYTHON: "3.11" build_name: manywheel-py3_11-cuda12_1-split build_environment: linux-binary-manywheel - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" runs_on: linux.4xlarge.nvidia.gpu secrets: github-token: ${{ secrets.GITHUB_TOKEN }} @@ -2289,7 +2289,7 @@ jobs: GPU_ARCH_TYPE: cuda DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.4-main DESIRED_PYTHON: "3.11" - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_11-cuda12_4 build_environment: linux-binary-manywheel PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.4.5.8; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.2.1.3; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.5.147; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.6.1.9; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.3.1.170; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.21.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' @@ -2314,7 +2314,7 @@ jobs: DESIRED_PYTHON: "3.11" build_name: manywheel-py3_11-cuda12_4 build_environment: linux-binary-manywheel - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" runs_on: linux.4xlarge.nvidia.gpu secrets: github-token: ${{ secrets.GITHUB_TOKEN }} @@ -2358,7 +2358,7 @@ jobs: DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.4-main use_split_build: True DESIRED_PYTHON: "3.11" - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_11-cuda12_4-split build_environment: linux-binary-manywheel PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.4.5.8; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.2.1.3; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.5.147; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.6.1.9; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.3.1.170; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.21.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' @@ -2384,7 +2384,7 @@ jobs: DESIRED_PYTHON: "3.11" build_name: manywheel-py3_11-cuda12_4-split build_environment: linux-binary-manywheel - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" runs_on: linux.4xlarge.nvidia.gpu secrets: github-token: ${{ secrets.GITHUB_TOKEN }} @@ -2428,7 +2428,7 @@ jobs: GPU_ARCH_TYPE: rocm DOCKER_IMAGE: pytorch/manylinux-builder:rocm6.1-main DESIRED_PYTHON: "3.11" - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_11-rocm6_1 build_environment: linux-binary-manywheel secrets: @@ -2535,7 +2535,7 @@ jobs: GPU_ARCH_TYPE: rocm DOCKER_IMAGE: pytorch/manylinux-builder:rocm6.2-main DESIRED_PYTHON: "3.11" - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_11-rocm6_2 build_environment: linux-binary-manywheel secrets: @@ -2641,7 +2641,7 @@ jobs: GPU_ARCH_TYPE: xpu DOCKER_IMAGE: pytorch/manylinux2_28-builder:xpu-main DESIRED_PYTHON: "3.11" - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_11-xpu build_environment: linux-binary-manywheel secrets: @@ -2754,7 +2754,7 @@ jobs: GPU_ARCH_TYPE: cpu DOCKER_IMAGE: pytorch/manylinux-builder:cpu-main DESIRED_PYTHON: "3.12" - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_12-cpu build_environment: linux-binary-manywheel secrets: @@ -2777,7 +2777,7 @@ jobs: DESIRED_PYTHON: "3.12" build_name: manywheel-py3_12-cpu build_environment: linux-binary-manywheel - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" runs_on: linux.4xlarge secrets: github-token: ${{ secrets.GITHUB_TOKEN }} @@ -2819,7 +2819,7 @@ jobs: DOCKER_IMAGE: pytorch/manylinuxcxx11-abi-builder:cpu-cxx11-abi-main DESIRED_DEVTOOLSET: cxx11-abi DESIRED_PYTHON: "3.12" - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_12-cpu-cxx11-abi build_environment: linux-binary-manywheel secrets: @@ -2843,7 +2843,7 @@ jobs: DESIRED_PYTHON: "3.12" build_name: manywheel-py3_12-cpu-cxx11-abi build_environment: linux-binary-manywheel - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" runs_on: linux.4xlarge secrets: github-token: ${{ secrets.GITHUB_TOKEN }} @@ -2886,7 +2886,7 @@ jobs: GPU_ARCH_TYPE: cuda DOCKER_IMAGE: pytorch/manylinux-builder:cuda11.8-main DESIRED_PYTHON: "3.12" - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_12-cuda11_8 build_environment: linux-binary-manywheel PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu11==11.8.89; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu11==11.8.89; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu11==11.8.87; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu11==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu11==11.11.3.6; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu11==10.9.0.58; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu11==10.3.0.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu11==11.4.1.48; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu11==11.7.5.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu11==2.21.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu11==11.8.86; platform_system == 'Linux' and platform_machine == 'x86_64' @@ -2911,7 +2911,7 @@ jobs: DESIRED_PYTHON: "3.12" build_name: manywheel-py3_12-cuda11_8 build_environment: linux-binary-manywheel - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" runs_on: linux.4xlarge.nvidia.gpu secrets: github-token: ${{ secrets.GITHUB_TOKEN }} @@ -2955,7 +2955,7 @@ jobs: DOCKER_IMAGE: pytorch/manylinux-builder:cuda11.8-main use_split_build: True DESIRED_PYTHON: "3.12" - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_12-cuda11_8-split build_environment: linux-binary-manywheel PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu11==11.8.89; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu11==11.8.89; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu11==11.8.87; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu11==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu11==11.11.3.6; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu11==10.9.0.58; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu11==10.3.0.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu11==11.4.1.48; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu11==11.7.5.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu11==2.21.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu11==11.8.86; platform_system == 'Linux' and platform_machine == 'x86_64' @@ -2981,7 +2981,7 @@ jobs: DESIRED_PYTHON: "3.12" build_name: manywheel-py3_12-cuda11_8-split build_environment: linux-binary-manywheel - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" runs_on: linux.4xlarge.nvidia.gpu secrets: github-token: ${{ secrets.GITHUB_TOKEN }} @@ -3025,7 +3025,7 @@ jobs: GPU_ARCH_TYPE: cuda DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.1-main DESIRED_PYTHON: "3.12" - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_12-cuda12_1 build_environment: linux-binary-manywheel PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.21.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' @@ -3050,7 +3050,7 @@ jobs: DESIRED_PYTHON: "3.12" build_name: manywheel-py3_12-cuda12_1 build_environment: linux-binary-manywheel - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" runs_on: linux.4xlarge.nvidia.gpu secrets: github-token: ${{ secrets.GITHUB_TOKEN }} @@ -3094,7 +3094,7 @@ jobs: DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.1-main use_split_build: True DESIRED_PYTHON: "3.12" - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_12-cuda12_1-split build_environment: linux-binary-manywheel PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.21.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' @@ -3120,7 +3120,7 @@ jobs: DESIRED_PYTHON: "3.12" build_name: manywheel-py3_12-cuda12_1-split build_environment: linux-binary-manywheel - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" runs_on: linux.4xlarge.nvidia.gpu secrets: github-token: ${{ secrets.GITHUB_TOKEN }} @@ -3164,7 +3164,7 @@ jobs: GPU_ARCH_TYPE: cuda DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.4-main DESIRED_PYTHON: "3.12" - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_12-cuda12_4 build_environment: linux-binary-manywheel PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.4.5.8; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.2.1.3; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.5.147; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.6.1.9; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.3.1.170; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.21.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' @@ -3189,7 +3189,7 @@ jobs: DESIRED_PYTHON: "3.12" build_name: manywheel-py3_12-cuda12_4 build_environment: linux-binary-manywheel - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" runs_on: linux.4xlarge.nvidia.gpu secrets: github-token: ${{ secrets.GITHUB_TOKEN }} @@ -3233,7 +3233,7 @@ jobs: DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.4-main use_split_build: True DESIRED_PYTHON: "3.12" - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_12-cuda12_4-split build_environment: linux-binary-manywheel PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.4.5.8; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.2.1.3; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.5.147; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.6.1.9; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.3.1.170; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.21.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' @@ -3259,7 +3259,7 @@ jobs: DESIRED_PYTHON: "3.12" build_name: manywheel-py3_12-cuda12_4-split build_environment: linux-binary-manywheel - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" runs_on: linux.4xlarge.nvidia.gpu secrets: github-token: ${{ secrets.GITHUB_TOKEN }} @@ -3303,7 +3303,7 @@ jobs: GPU_ARCH_TYPE: rocm DOCKER_IMAGE: pytorch/manylinux-builder:rocm6.1-main DESIRED_PYTHON: "3.12" - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_12-rocm6_1 build_environment: linux-binary-manywheel secrets: @@ -3410,7 +3410,7 @@ jobs: GPU_ARCH_TYPE: rocm DOCKER_IMAGE: pytorch/manylinux-builder:rocm6.2-main DESIRED_PYTHON: "3.12" - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_12-rocm6_2 build_environment: linux-binary-manywheel secrets: @@ -3516,7 +3516,7 @@ jobs: GPU_ARCH_TYPE: xpu DOCKER_IMAGE: pytorch/manylinux2_28-builder:xpu-main DESIRED_PYTHON: "3.12" - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_12-xpu build_environment: linux-binary-manywheel secrets: @@ -3629,7 +3629,7 @@ jobs: GPU_ARCH_TYPE: cpu DOCKER_IMAGE: pytorch/manylinux-builder:cpu-main DESIRED_PYTHON: "3.13" - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_13-cpu build_environment: linux-binary-manywheel secrets: @@ -3652,7 +3652,7 @@ jobs: DESIRED_PYTHON: "3.13" build_name: manywheel-py3_13-cpu build_environment: linux-binary-manywheel - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" runs_on: linux.4xlarge secrets: github-token: ${{ secrets.GITHUB_TOKEN }} @@ -3694,7 +3694,7 @@ jobs: DOCKER_IMAGE: pytorch/manylinuxcxx11-abi-builder:cpu-cxx11-abi-main DESIRED_DEVTOOLSET: cxx11-abi DESIRED_PYTHON: "3.13" - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_13-cpu-cxx11-abi build_environment: linux-binary-manywheel secrets: @@ -3718,7 +3718,7 @@ jobs: DESIRED_PYTHON: "3.13" build_name: manywheel-py3_13-cpu-cxx11-abi build_environment: linux-binary-manywheel - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" runs_on: linux.4xlarge secrets: github-token: ${{ secrets.GITHUB_TOKEN }} @@ -3761,7 +3761,7 @@ jobs: GPU_ARCH_TYPE: cuda DOCKER_IMAGE: pytorch/manylinux-builder:cuda11.8-main DESIRED_PYTHON: "3.13" - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_13-cuda11_8 build_environment: linux-binary-manywheel PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu11==11.8.89; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu11==11.8.89; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu11==11.8.87; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu11==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu11==11.11.3.6; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu11==10.9.0.58; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu11==10.3.0.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu11==11.4.1.48; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu11==11.7.5.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu11==2.21.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu11==11.8.86; platform_system == 'Linux' and platform_machine == 'x86_64' @@ -3786,7 +3786,7 @@ jobs: DESIRED_PYTHON: "3.13" build_name: manywheel-py3_13-cuda11_8 build_environment: linux-binary-manywheel - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" runs_on: linux.4xlarge.nvidia.gpu secrets: github-token: ${{ secrets.GITHUB_TOKEN }} @@ -3830,7 +3830,7 @@ jobs: DOCKER_IMAGE: pytorch/manylinux-builder:cuda11.8-main use_split_build: True DESIRED_PYTHON: "3.13" - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_13-cuda11_8-split build_environment: linux-binary-manywheel PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu11==11.8.89; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu11==11.8.89; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu11==11.8.87; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu11==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu11==11.11.3.6; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu11==10.9.0.58; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu11==10.3.0.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu11==11.4.1.48; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu11==11.7.5.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu11==2.21.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu11==11.8.86; platform_system == 'Linux' and platform_machine == 'x86_64' @@ -3856,7 +3856,7 @@ jobs: DESIRED_PYTHON: "3.13" build_name: manywheel-py3_13-cuda11_8-split build_environment: linux-binary-manywheel - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" runs_on: linux.4xlarge.nvidia.gpu secrets: github-token: ${{ secrets.GITHUB_TOKEN }} @@ -3900,7 +3900,7 @@ jobs: GPU_ARCH_TYPE: cuda DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.1-main DESIRED_PYTHON: "3.13" - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_13-cuda12_1 build_environment: linux-binary-manywheel PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.21.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' @@ -3925,7 +3925,7 @@ jobs: DESIRED_PYTHON: "3.13" build_name: manywheel-py3_13-cuda12_1 build_environment: linux-binary-manywheel - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" runs_on: linux.4xlarge.nvidia.gpu secrets: github-token: ${{ secrets.GITHUB_TOKEN }} @@ -3969,7 +3969,7 @@ jobs: DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.1-main use_split_build: True DESIRED_PYTHON: "3.13" - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_13-cuda12_1-split build_environment: linux-binary-manywheel PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.21.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' @@ -3995,7 +3995,7 @@ jobs: DESIRED_PYTHON: "3.13" build_name: manywheel-py3_13-cuda12_1-split build_environment: linux-binary-manywheel - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" runs_on: linux.4xlarge.nvidia.gpu secrets: github-token: ${{ secrets.GITHUB_TOKEN }} @@ -4039,7 +4039,7 @@ jobs: GPU_ARCH_TYPE: cuda DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.4-main DESIRED_PYTHON: "3.13" - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_13-cuda12_4 build_environment: linux-binary-manywheel PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.4.5.8; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.2.1.3; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.5.147; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.6.1.9; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.3.1.170; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.21.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' @@ -4064,7 +4064,7 @@ jobs: DESIRED_PYTHON: "3.13" build_name: manywheel-py3_13-cuda12_4 build_environment: linux-binary-manywheel - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" runs_on: linux.4xlarge.nvidia.gpu secrets: github-token: ${{ secrets.GITHUB_TOKEN }} @@ -4108,7 +4108,7 @@ jobs: DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.4-main use_split_build: True DESIRED_PYTHON: "3.13" - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_13-cuda12_4-split build_environment: linux-binary-manywheel PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.4.5.8; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.2.1.3; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.5.147; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.6.1.9; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.3.1.170; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.21.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' @@ -4134,7 +4134,7 @@ jobs: DESIRED_PYTHON: "3.13" build_name: manywheel-py3_13-cuda12_4-split build_environment: linux-binary-manywheel - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" runs_on: linux.4xlarge.nvidia.gpu secrets: github-token: ${{ secrets.GITHUB_TOKEN }} @@ -4177,7 +4177,7 @@ jobs: GPU_ARCH_TYPE: xpu DOCKER_IMAGE: pytorch/manylinux2_28-builder:xpu-main DESIRED_PYTHON: "3.13" - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_13-xpu build_environment: linux-binary-manywheel secrets: diff --git a/.github/workflows/inductor.yml b/.github/workflows/inductor.yml index c9987f0f130bd3..d4c129529675e6 100644 --- a/.github/workflows/inductor.yml +++ b/.github/workflows/inductor.yml @@ -35,28 +35,28 @@ jobs: build-environment: linux-focal-cuda12.1-py3.10-gcc9-sm86 docker-image-name: pytorch-linux-focal-cuda12.1-cudnn9-py3-gcc9-inductor-benchmarks cuda-arch-list: '8.6' - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" test-matrix: | { include: [ - { config: "inductor", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.g5.4xlarge.nvidia.gpu" }, - { config: "inductor", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.g5.4xlarge.nvidia.gpu" }, - { config: "inductor_distributed", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.g5.12xlarge.nvidia.gpu" }, - { config: "inductor_huggingface", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.g5.4xlarge.nvidia.gpu" }, - { config: "inductor_timm", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.g5.4xlarge.nvidia.gpu" }, - { config: "inductor_timm", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.g5.4xlarge.nvidia.gpu" }, - { config: "inductor_torchbench", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.g5.4xlarge.nvidia.gpu" }, - { config: "inductor_torchbench", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.g5.4xlarge.nvidia.gpu" }, - { config: "dynamic_inductor_huggingface", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.g5.4xlarge.nvidia.gpu" }, - { config: "dynamic_inductor_timm", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.g5.4xlarge.nvidia.gpu" }, - { config: "dynamic_inductor_timm", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.g5.4xlarge.nvidia.gpu" }, - { config: "dynamic_inductor_torchbench", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.g5.4xlarge.nvidia.gpu" }, - { config: "dynamic_inductor_torchbench", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.g5.4xlarge.nvidia.gpu" }, - { config: "aot_inductor_huggingface", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.g5.4xlarge.nvidia.gpu" }, - { config: "aot_inductor_timm", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.g5.4xlarge.nvidia.gpu" }, - { config: "aot_inductor_timm", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.g5.4xlarge.nvidia.gpu" }, - { config: "aot_inductor_torchbench", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.g5.4xlarge.nvidia.gpu" }, - { config: "aot_inductor_torchbench", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.g5.4xlarge.nvidia.gpu" }, - { config: "inductor_cpp_wrapper_abi_compatible", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.g5.4xlarge.nvidia.gpu" }, + { config: "inductor", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g5.4xlarge.nvidia.gpu" }, + { config: "inductor", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g5.4xlarge.nvidia.gpu" }, + { config: "inductor_distributed", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g5.12xlarge.nvidia.gpu" }, + { config: "inductor_huggingface", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g5.4xlarge.nvidia.gpu" }, + { config: "inductor_timm", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g5.4xlarge.nvidia.gpu" }, + { config: "inductor_timm", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g5.4xlarge.nvidia.gpu" }, + { config: "inductor_torchbench", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g5.4xlarge.nvidia.gpu" }, + { config: "inductor_torchbench", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g5.4xlarge.nvidia.gpu" }, + { config: "dynamic_inductor_huggingface", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g5.4xlarge.nvidia.gpu" }, + { config: "dynamic_inductor_timm", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g5.4xlarge.nvidia.gpu" }, + { config: "dynamic_inductor_timm", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g5.4xlarge.nvidia.gpu" }, + { config: "dynamic_inductor_torchbench", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g5.4xlarge.nvidia.gpu" }, + { config: "dynamic_inductor_torchbench", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g5.4xlarge.nvidia.gpu" }, + { config: "aot_inductor_huggingface", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g5.4xlarge.nvidia.gpu" }, + { config: "aot_inductor_timm", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g5.4xlarge.nvidia.gpu" }, + { config: "aot_inductor_timm", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g5.4xlarge.nvidia.gpu" }, + { config: "aot_inductor_torchbench", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g5.4xlarge.nvidia.gpu" }, + { config: "aot_inductor_torchbench", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g5.4xlarge.nvidia.gpu" }, + { config: "inductor_cpp_wrapper_abi_compatible", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g5.4xlarge.nvidia.gpu" }, ]} secrets: HUGGING_FACE_HUB_TOKEN: ${{ secrets.HUGGING_FACE_HUB_TOKEN }} @@ -80,11 +80,11 @@ jobs: build-environment: linux-focal-cuda12.1-py3.12-gcc9-sm86 docker-image-name: pytorch-linux-focal-cuda12.1-cudnn9-py3.12-gcc9-inductor-benchmarks cuda-arch-list: '8.6' - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" test-matrix: | { include: [ - { config: "inductor", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.g5.4xlarge.nvidia.gpu" }, - { config: "inductor", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.g5.4xlarge.nvidia.gpu" }, + { config: "inductor", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g5.4xlarge.nvidia.gpu" }, + { config: "inductor", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g5.4xlarge.nvidia.gpu" }, ]} linux-focal-cuda12_1-py3_12-gcc9-inductor-test: @@ -103,10 +103,10 @@ jobs: with: build-environment: linux-jammy-py3.12-gcc11 docker-image-name: pytorch-linux-jammy-py3.12-halide - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" test-matrix: | { include: [ - { config: "inductor-halide", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.12xlarge" }, + { config: "inductor-halide", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.12xlarge" }, ]} linux-jammy-cpu-py3_12-inductor-halide-test: @@ -128,11 +128,11 @@ jobs: build-environment: linux-focal-cuda12.4-py3.10-gcc9-sm86 docker-image-name: pytorch-linux-focal-cuda12.4-cudnn9-py3-gcc9-inductor-benchmarks cuda-arch-list: '8.6' - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" test-matrix: | { include: [ - { config: "inductor_timm", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.g5.4xlarge.nvidia.gpu" }, - { config: "inductor_timm", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.g5.4xlarge.nvidia.gpu" }, + { config: "inductor_timm", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g5.4xlarge.nvidia.gpu" }, + { config: "inductor_timm", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g5.4xlarge.nvidia.gpu" }, ]} secrets: HUGGING_FACE_HUB_TOKEN: ${{ secrets.HUGGING_FACE_HUB_TOKEN }} @@ -156,50 +156,50 @@ jobs: with: build-environment: linux-jammy-py3.8-gcc11-build docker-image-name: pytorch-linux-jammy-py3.8-gcc11-inductor-benchmarks - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" test-matrix: | { include: [ - { config: "inductor_avx512", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.12xlarge" }, - { config: "inductor_avx512", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.12xlarge" }, - { config: "cpu_inductor_huggingface", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.12xlarge" }, - { config: "cpu_inductor_timm", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.12xlarge" }, - { config: "cpu_inductor_timm", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.12xlarge" }, - { config: "cpu_inductor_torchbench", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.12xlarge" }, - { config: "cpu_inductor_torchbench", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.12xlarge" }, - { config: "cpu_inductor_huggingface_freezing", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.12xlarge" }, - { config: "cpu_inductor_timm_freezing", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.12xlarge" }, - { config: "cpu_inductor_timm_freezing", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.12xlarge" }, - { config: "cpu_inductor_torchbench_freezing", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.12xlarge" }, - { config: "cpu_inductor_torchbench_freezing", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.12xlarge" }, - { config: "cpu_inductor_huggingface_amp_freezing", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.16xlarge.spr" }, - { config: "cpu_inductor_timm_amp_freezing", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.16xlarge.spr" }, - { config: "cpu_inductor_timm_amp_freezing", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.16xlarge.spr" }, - { config: "cpu_inductor_torchbench_amp_freezing", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.16xlarge.spr" }, - { config: "cpu_inductor_torchbench_amp_freezing", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.16xlarge.spr" }, - { config: "dynamic_cpu_inductor_huggingface", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.12xlarge" }, - { config: "dynamic_cpu_inductor_timm", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.12xlarge" }, - { config: "dynamic_cpu_inductor_timm", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.12xlarge" }, - { config: "dynamic_cpu_inductor_torchbench", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.12xlarge" }, - { config: "dynamic_cpu_inductor_torchbench", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.12xlarge" }, - { config: "cpu_aot_inductor_huggingface_freezing", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.12xlarge" }, - { config: "cpu_aot_inductor_timm_freezing", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.12xlarge" }, - { config: "cpu_aot_inductor_timm_freezing", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.12xlarge" }, - { config: "cpu_aot_inductor_torchbench_freezing", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.12xlarge" }, - { config: "cpu_aot_inductor_torchbench_freezing", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.12xlarge" }, - { config: "cpu_aot_inductor_torchbench_amp_freezing", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.12xlarge" }, - { config: "cpu_aot_inductor_torchbench_amp_freezing", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.12xlarge" }, - { config: "dynamic_cpu_aot_inductor_torchbench_freezing", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.12xlarge" }, - { config: "dynamic_cpu_aot_inductor_torchbench_freezing", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.12xlarge" }, - { config: "dynamic_cpu_aot_inductor_torchbench_amp_freezing", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.12xlarge" }, - { config: "dynamic_cpu_aot_inductor_torchbench_amp_freezing", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.12xlarge" }, - { config: "inductor_torchbench_cpu_smoketest_perf", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.24xl.spr-metal" }, - { config: "inductor_avx2", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.10xlarge.avx2" }, - { config: "inductor_avx2", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.10xlarge.avx2" }, - { config: "cpu_inductor_huggingface_freezing_avx2", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.10xlarge.avx2" }, - { config: "cpu_inductor_torchbench_freezing_avx2", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.10xlarge.avx2" }, - { config: "cpu_inductor_torchbench_freezing_avx2", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.10xlarge.avx2" }, - { config: "cpu_inductor_timm_freezing_avx2", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.10xlarge.avx2" }, - { config: "cpu_inductor_timm_freezing_avx2", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.10xlarge.avx2" }, + { config: "inductor_avx512", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.12xlarge" }, + { config: "inductor_avx512", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.12xlarge" }, + { config: "cpu_inductor_huggingface", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.12xlarge" }, + { config: "cpu_inductor_timm", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.12xlarge" }, + { config: "cpu_inductor_timm", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.12xlarge" }, + { config: "cpu_inductor_torchbench", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.12xlarge" }, + { config: "cpu_inductor_torchbench", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.12xlarge" }, + { config: "cpu_inductor_huggingface_freezing", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.12xlarge" }, + { config: "cpu_inductor_timm_freezing", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.12xlarge" }, + { config: "cpu_inductor_timm_freezing", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.12xlarge" }, + { config: "cpu_inductor_torchbench_freezing", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.12xlarge" }, + { config: "cpu_inductor_torchbench_freezing", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.12xlarge" }, + { config: "cpu_inductor_huggingface_amp_freezing", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.16xlarge.spr" }, + { config: "cpu_inductor_timm_amp_freezing", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.16xlarge.spr" }, + { config: "cpu_inductor_timm_amp_freezing", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.16xlarge.spr" }, + { config: "cpu_inductor_torchbench_amp_freezing", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.16xlarge.spr" }, + { config: "cpu_inductor_torchbench_amp_freezing", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.16xlarge.spr" }, + { config: "dynamic_cpu_inductor_huggingface", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.12xlarge" }, + { config: "dynamic_cpu_inductor_timm", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.12xlarge" }, + { config: "dynamic_cpu_inductor_timm", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.12xlarge" }, + { config: "dynamic_cpu_inductor_torchbench", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.12xlarge" }, + { config: "dynamic_cpu_inductor_torchbench", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.12xlarge" }, + { config: "cpu_aot_inductor_huggingface_freezing", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.12xlarge" }, + { config: "cpu_aot_inductor_timm_freezing", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.12xlarge" }, + { config: "cpu_aot_inductor_timm_freezing", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.12xlarge" }, + { config: "cpu_aot_inductor_torchbench_freezing", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.12xlarge" }, + { config: "cpu_aot_inductor_torchbench_freezing", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.12xlarge" }, + { config: "cpu_aot_inductor_torchbench_amp_freezing", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.12xlarge" }, + { config: "cpu_aot_inductor_torchbench_amp_freezing", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.12xlarge" }, + { config: "dynamic_cpu_aot_inductor_torchbench_freezing", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.12xlarge" }, + { config: "dynamic_cpu_aot_inductor_torchbench_freezing", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.12xlarge" }, + { config: "dynamic_cpu_aot_inductor_torchbench_amp_freezing", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.12xlarge" }, + { config: "dynamic_cpu_aot_inductor_torchbench_amp_freezing", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.12xlarge" }, + { config: "inductor_torchbench_cpu_smoketest_perf", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.24xl.spr-metal" }, + { config: "inductor_avx2", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.10xlarge.avx2" }, + { config: "inductor_avx2", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.10xlarge.avx2" }, + { config: "cpu_inductor_huggingface_freezing_avx2", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.10xlarge.avx2" }, + { config: "cpu_inductor_torchbench_freezing_avx2", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.10xlarge.avx2" }, + { config: "cpu_inductor_torchbench_freezing_avx2", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.10xlarge.avx2" }, + { config: "cpu_inductor_timm_freezing_avx2", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.10xlarge.avx2" }, + { config: "cpu_inductor_timm_freezing_avx2", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.10xlarge.avx2" }, ]} secrets: HUGGING_FACE_HUB_TOKEN: ${{ secrets.HUGGING_FACE_HUB_TOKEN }} diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index c61a7703e7c22c..f86d7270612017 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -28,7 +28,7 @@ jobs: needs: get-label-type with: timeout: 120 - runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.2xlarge" + runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" docker-image: pytorch-linux-jammy-cuda11.8-cudnn9-py3.9-linter # NB: A shallow checkout won't work here because calculate-docker-image requires a full checkout # to run git rev-parse HEAD~:.ci/docker when a new image is needed @@ -45,7 +45,7 @@ jobs: needs: get-label-type with: timeout: 120 - runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.2xlarge" + runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" docker-image: pytorch-linux-jammy-cuda11.8-cudnn9-py3.9-linter # NB: A shallow checkout won't work here because calculate-docker-image requires a full checkout # to run git rev-parse HEAD~:.ci/docker when a new image is needed @@ -60,7 +60,7 @@ jobs: uses: pytorch/test-infra/.github/workflows/linux_job.yml@main needs: get-label-type with: - runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.2xlarge" + runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" docker-image: pytorch-linux-focal-linter fetch-depth: 0 ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} @@ -95,7 +95,7 @@ jobs: pr-sanity-checks: name: pr-sanity-checks needs: get-label-type - runs-on: [self-hosted, "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.large"] + runs-on: [self-hosted, "${{ needs.get-label-type.outputs.label-type }}linux.large"] # Only run this on pull requests. This check is simple enough to be done without a Docker image if: github.event_name == 'pull_request' && !contains(github.event.pull_request.labels.*.name, 'skip-pr-sanity-checks') steps: @@ -116,7 +116,7 @@ jobs: uses: pytorch/test-infra/.github/workflows/linux_job.yml@main needs: get-label-type with: - runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.2xlarge" + runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" docker-image: pytorch-linux-focal-linter fetch-depth: -1 submodules: true @@ -153,7 +153,7 @@ jobs: uses: pytorch/test-infra/.github/workflows/linux_job.yml@main needs: get-label-type with: - runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.2xlarge" + runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" docker-image: pytorch-linux-focal-linter fetch-depth: 0 ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} @@ -192,7 +192,7 @@ jobs: uses: pytorch/test-infra/.github/workflows/linux_job.yml@main needs: get-label-type with: - runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.2xlarge" + runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" docker-image: pytorch-linux-focal-linter fetch-depth: 0 ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} diff --git a/.github/workflows/nightly.yml b/.github/workflows/nightly.yml index 35f7a3ce116881..cef5841fa9f219 100644 --- a/.github/workflows/nightly.yml +++ b/.github/workflows/nightly.yml @@ -31,7 +31,7 @@ jobs: uses: ./.github/workflows/_linux-build.yml needs: get-label-type with: - runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.2xlarge" + runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" build-environment: linux-jammy-py3.8-gcc11 docker-image-name: pytorch-linux-jammy-py3.8-gcc11 @@ -42,7 +42,7 @@ jobs: - docs-build - get-label-type with: - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build-environment: linux-jammy-py3.8-gcc11 docker-image: ${{ needs.docs-build.outputs.docker-image }} push: ${{ github.event_name == 'schedule' || github.event_name == 'workflow_dispatch' || startsWith(github.event.ref, 'refs/tags/v') }} diff --git a/.github/workflows/periodic.yml b/.github/workflows/periodic.yml index 714f1c5d641df8..b12e478518d4c5 100644 --- a/.github/workflows/periodic.yml +++ b/.github/workflows/periodic.yml @@ -52,14 +52,14 @@ jobs: uses: ./.github/workflows/_linux-build.yml needs: get-label-type with: - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build-environment: linux-focal-cuda12.1-py3.10-gcc9 docker-image-name: pytorch-linux-focal-cuda12.1-cudnn9-py3-gcc9 test-matrix: | { include: [ - { config: "nogpu_AVX512", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.2xlarge" }, - { config: "nogpu_NO_AVX2", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.2xlarge" }, - { config: "jit_legacy", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.4xlarge.nvidia.gpu" }, + { config: "nogpu_AVX512", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" }, + { config: "nogpu_NO_AVX2", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" }, + { config: "jit_legacy", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge.nvidia.gpu" }, ]} linux-focal-cuda12_1-py3_10-gcc9-test: name: linux-focal-cuda12.1-py3.10-gcc9 @@ -77,19 +77,19 @@ jobs: uses: ./.github/workflows/_linux-build.yml needs: get-label-type with: - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build-environment: linux-focal-cuda12.4-py3.10-gcc9 docker-image-name: pytorch-linux-focal-cuda12.4-cudnn9-py3-gcc9 test-matrix: | { include: [ - { config: "default", shard: 1, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.4xlarge.nvidia.gpu" }, - { config: "default", shard: 2, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.4xlarge.nvidia.gpu" }, - { config: "default", shard: 3, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.4xlarge.nvidia.gpu" }, - { config: "default", shard: 4, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.4xlarge.nvidia.gpu" }, - { config: "default", shard: 5, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.4xlarge.nvidia.gpu" }, - { config: "nogpu_AVX512", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.2xlarge" }, - { config: "nogpu_NO_AVX2", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.2xlarge" }, - { config: "jit_legacy", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.4xlarge.nvidia.gpu" }, + { config: "default", shard: 1, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge.nvidia.gpu" }, + { config: "default", shard: 2, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge.nvidia.gpu" }, + { config: "default", shard: 3, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge.nvidia.gpu" }, + { config: "default", shard: 4, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge.nvidia.gpu" }, + { config: "default", shard: 5, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge.nvidia.gpu" }, + { config: "nogpu_AVX512", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" }, + { config: "nogpu_NO_AVX2", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" }, + { config: "jit_legacy", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge.nvidia.gpu" }, ]} linux-focal-cuda12_4-py3_10-gcc9-test: @@ -109,14 +109,14 @@ jobs: uses: ./.github/workflows/_linux-build.yml needs: get-label-type with: - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build-environment: parallelnative-linux-jammy-py3.8-gcc11 docker-image-name: pytorch-linux-jammy-py3.8-gcc11 test-matrix: | { include: [ - { config: "default", shard: 1, num_shards: 3, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.2xlarge" }, - { config: "default", shard: 2, num_shards: 3, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.2xlarge" }, - { config: "default", shard: 3, num_shards: 3, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.2xlarge" }, + { config: "default", shard: 1, num_shards: 3, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" }, + { config: "default", shard: 2, num_shards: 3, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" }, + { config: "default", shard: 3, num_shards: 3, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" }, ]} parallelnative-linux-jammy-py3_8-gcc11-test: @@ -135,7 +135,7 @@ jobs: uses: ./.github/workflows/_linux-build.yml needs: get-label-type with: - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build-environment: linux-focal-cuda11.8-py3.9-gcc9 docker-image-name: pytorch-linux-focal-cuda11.8-cudnn9-py3-gcc9 cuda-arch-list: 8.6 @@ -159,7 +159,7 @@ jobs: uses: ./.github/workflows/_linux-build.yml needs: get-label-type with: - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build-environment: linux-focal-cuda11.8-py3.10-gcc9-debug docker-image-name: pytorch-linux-focal-cuda11.8-cudnn9-py3-gcc9 build-with-debug: true @@ -250,7 +250,6 @@ jobs: name: buck-build-test uses: ./.github/workflows/_buck-build-test.yml with: - runner_prefix: "amz2023." test-matrix: | { include: [ { config: "default", shard: 1, num_shards: 1, runner: "ubuntu-latest" }, @@ -260,7 +259,6 @@ jobs: name: android-emulator-build-test uses: ./.github/workflows/_run_android_tests.yml with: - runner_prefix: "amz2023." test-matrix: | { include: [ { config: 'default', @@ -278,12 +276,12 @@ jobs: uses: ./.github/workflows/_linux-build.yml needs: get-label-type with: - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build-environment: linux-vulkan-focal-py3.11-clang10 docker-image-name: pytorch-linux-focal-py3.11-clang10 test-matrix: | { include: [ - { config: "default", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.2xlarge" }, + { config: "default", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" }, ]} linux-vulkan-focal-py3_11-clang10-test: @@ -300,7 +298,7 @@ jobs: uses: ./.github/workflows/_linux-build.yml needs: get-label-type with: - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build-environment: linux-focal-rocm6.1-py3.8 docker-image-name: pytorch-linux-focal-rocm-n-py3 test-matrix: | @@ -329,15 +327,15 @@ jobs: uses: ./.github/workflows/_linux-build.yml needs: get-label-type with: - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" use_split_build: true build-environment: linux-focal-cuda12.1-py3.10-gcc9 docker-image-name: pytorch-linux-focal-cuda12.1-cudnn9-py3-gcc9 test-matrix: | { include: [ - { config: "nogpu_AVX512", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.2xlarge" }, - { config: "nogpu_NO_AVX2", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.2xlarge" }, - { config: "jit_legacy", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.4xlarge.nvidia.gpu" }, + { config: "nogpu_AVX512", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" }, + { config: "nogpu_NO_AVX2", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" }, + { config: "jit_legacy", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge.nvidia.gpu" }, ]} linux-focal-cuda12_1-py3_10-gcc9-experimental-split-build-test: @@ -357,7 +355,7 @@ jobs: uses: ./.github/workflows/_linux-build.yml needs: get-label-type with: - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" use_split_build: true build-environment: linux-focal-cuda11.8-py3.9-gcc9 docker-image-name: pytorch-linux-focal-cuda11.8-cudnn9-py3-gcc9 diff --git a/.github/workflows/pull.yml b/.github/workflows/pull.yml index 32c9a4837e5cd2..ed77a706c2ee89 100644 --- a/.github/workflows/pull.yml +++ b/.github/workflows/pull.yml @@ -48,20 +48,20 @@ jobs: uses: ./.github/workflows/_linux-build.yml needs: get-label-type with: - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build-environment: linux-jammy-py3.8-gcc11 docker-image-name: pytorch-linux-jammy-py3.8-gcc11 test-matrix: | { include: [ - { config: "default", shard: 1, num_shards: 4, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.2xlarge" }, - { config: "default", shard: 2, num_shards: 4, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.2xlarge" }, - { config: "default", shard: 3, num_shards: 4, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.2xlarge" }, - { config: "default", shard: 4, num_shards: 4, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.2xlarge" }, - { config: "docs_test", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.2xlarge" }, - { config: "jit_legacy", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.2xlarge" }, - { config: "backwards_compat", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.2xlarge" }, - { config: "distributed", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.2xlarge" }, - { config: "distributed", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.2xlarge" }, + { config: "default", shard: 1, num_shards: 4, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" }, + { config: "default", shard: 2, num_shards: 4, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" }, + { config: "default", shard: 3, num_shards: 4, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" }, + { config: "default", shard: 4, num_shards: 4, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" }, + { config: "docs_test", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" }, + { config: "jit_legacy", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" }, + { config: "backwards_compat", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" }, + { config: "distributed", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" }, + { config: "distributed", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" }, ]} secrets: inherit @@ -82,7 +82,6 @@ jobs: uses: ./.github/workflows/_docs.yml needs: linux-jammy-py3_8-gcc11-build with: - runner_prefix: amz2023. build-environment: linux-jammy-py3.8-gcc11 docker-image: ${{ needs.linux-jammy-py3_8-gcc11-build.outputs.docker-image }} secrets: inherit @@ -92,7 +91,7 @@ jobs: uses: ./.github/workflows/_linux-build.yml needs: get-label-type with: - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build-environment: linux-jammy-py3.8-gcc11-no-ops docker-image-name: pytorch-linux-jammy-py3.8-gcc11 test-matrix: | @@ -106,7 +105,7 @@ jobs: uses: ./.github/workflows/_linux-build.yml needs: get-label-type with: - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build-environment: linux-jammy-py3.8-gcc11-pch docker-image-name: pytorch-linux-jammy-py3.8-gcc11 test-matrix: | @@ -120,17 +119,17 @@ jobs: uses: ./.github/workflows/_linux-build.yml needs: get-label-type with: - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build-environment: linux-jammy-py3.10-clang15-asan docker-image-name: pytorch-linux-jammy-py3-clang15-asan test-matrix: | { include: [ - { config: "default", shard: 1, num_shards: 6, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.4xlarge" }, - { config: "default", shard: 2, num_shards: 6, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.4xlarge" }, - { config: "default", shard: 3, num_shards: 6, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.4xlarge" }, - { config: "default", shard: 4, num_shards: 6, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.4xlarge" }, - { config: "default", shard: 5, num_shards: 6, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.4xlarge" }, - { config: "default", shard: 6, num_shards: 6, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.4xlarge" }, + { config: "default", shard: 1, num_shards: 6, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge" }, + { config: "default", shard: 2, num_shards: 6, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge" }, + { config: "default", shard: 3, num_shards: 6, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge" }, + { config: "default", shard: 4, num_shards: 6, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge" }, + { config: "default", shard: 5, num_shards: 6, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge" }, + { config: "default", shard: 6, num_shards: 6, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge" }, ]} sync-tag: asan-build secrets: inherit @@ -154,13 +153,13 @@ jobs: uses: ./.github/workflows/_linux-build.yml needs: get-label-type with: - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build-environment: linux-focal-py3.8-clang10-onnx docker-image-name: pytorch-linux-focal-py3-clang10-onnx test-matrix: | { include: [ - { config: "default", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.2xlarge" }, - { config: "default", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.2xlarge" }, + { config: "default", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" }, + { config: "default", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" }, ]} secrets: inherit @@ -181,20 +180,20 @@ jobs: uses: ./.github/workflows/_linux-build.yml needs: get-label-type with: - runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.2xlarge" + runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" build-environment: linux-focal-py3.8-clang10 docker-image-name: pytorch-linux-focal-py3.8-clang10 test-matrix: | { include: [ - { config: "default", shard: 1, num_shards: 4, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.2xlarge" }, - { config: "default", shard: 2, num_shards: 4, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.2xlarge" }, - { config: "default", shard: 3, num_shards: 4, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.2xlarge" }, - { config: "default", shard: 4, num_shards: 4, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.2xlarge" }, - { config: "crossref", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.2xlarge" }, - { config: "crossref", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.2xlarge" }, - { config: "dynamo", shard: 1, num_shards: 3, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.2xlarge" }, - { config: "dynamo", shard: 2, num_shards: 3, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.2xlarge" }, - { config: "dynamo", shard: 3, num_shards: 3, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.2xlarge" }, + { config: "default", shard: 1, num_shards: 4, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" }, + { config: "default", shard: 2, num_shards: 4, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" }, + { config: "default", shard: 3, num_shards: 4, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" }, + { config: "default", shard: 4, num_shards: 4, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" }, + { config: "crossref", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" }, + { config: "crossref", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" }, + { config: "dynamo", shard: 1, num_shards: 3, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" }, + { config: "dynamo", shard: 2, num_shards: 3, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" }, + { config: "dynamo", shard: 3, num_shards: 3, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" }, ]} linux-focal-py3_8-clang10-test: name: linux-focal-py3.8-clang10 @@ -213,20 +212,20 @@ jobs: uses: ./.github/workflows/_linux-build.yml needs: get-label-type with: - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build-environment: linux-focal-py3.11-clang10 docker-image-name: pytorch-linux-focal-py3.11-clang10 test-matrix: | { include: [ - { config: "default", shard: 1, num_shards: 4, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.2xlarge" }, - { config: "default", shard: 2, num_shards: 4, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.2xlarge" }, - { config: "default", shard: 3, num_shards: 4, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.2xlarge" }, - { config: "default", shard: 4, num_shards: 4, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.2xlarge" }, - { config: "crossref", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.2xlarge" }, - { config: "crossref", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.2xlarge" }, - { config: "dynamo", shard: 1, num_shards: 3, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.2xlarge" }, - { config: "dynamo", shard: 2, num_shards: 3, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.2xlarge" }, - { config: "dynamo", shard: 3, num_shards: 3, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.2xlarge" }, + { config: "default", shard: 1, num_shards: 4, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" }, + { config: "default", shard: 2, num_shards: 4, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" }, + { config: "default", shard: 3, num_shards: 4, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" }, + { config: "default", shard: 4, num_shards: 4, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" }, + { config: "crossref", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" }, + { config: "crossref", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" }, + { config: "dynamo", shard: 1, num_shards: 3, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" }, + { config: "dynamo", shard: 2, num_shards: 3, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" }, + { config: "dynamo", shard: 3, num_shards: 3, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" }, ]} secrets: inherit @@ -247,18 +246,18 @@ jobs: uses: ./.github/workflows/_linux-build.yml needs: get-label-type with: - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build-environment: linux-focal-py3.12-clang10 docker-image-name: pytorch-linux-focal-py3.12-clang10 test-matrix: | { include: [ - { config: "default", shard: 1, num_shards: 4, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.2xlarge" }, - { config: "default", shard: 2, num_shards: 4, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.2xlarge" }, - { config: "default", shard: 3, num_shards: 4, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.2xlarge" }, - { config: "default", shard: 4, num_shards: 4, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.2xlarge" }, - { config: "dynamo", shard: 1, num_shards: 3, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.2xlarge" }, - { config: "dynamo", shard: 2, num_shards: 3, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.2xlarge" }, - { config: "dynamo", shard: 3, num_shards: 3, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.2xlarge" }, + { config: "default", shard: 1, num_shards: 4, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" }, + { config: "default", shard: 2, num_shards: 4, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" }, + { config: "default", shard: 3, num_shards: 4, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" }, + { config: "default", shard: 4, num_shards: 4, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" }, + { config: "dynamo", shard: 1, num_shards: 3, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" }, + { config: "dynamo", shard: 2, num_shards: 3, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" }, + { config: "dynamo", shard: 3, num_shards: 3, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" }, ]} secrets: inherit @@ -278,14 +277,14 @@ jobs: uses: ./.github/workflows/_linux-build.yml needs: get-label-type with: - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build-environment: linux-focal-cuda11.8-py3.10-gcc9 docker-image-name: pytorch-linux-focal-cuda11.8-cudnn9-py3-gcc9 test-matrix: | { include: [ - { config: "distributed", shard: 1, num_shards: 3, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.8xlarge.nvidia.gpu" }, - { config: "distributed", shard: 2, num_shards: 3, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.8xlarge.nvidia.gpu" }, - { config: "distributed", shard: 3, num_shards: 3, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.8xlarge.nvidia.gpu" }, + { config: "distributed", shard: 1, num_shards: 3, runner: "${{ needs.get-label-type.outputs.label-type }}linux.8xlarge.nvidia.gpu" }, + { config: "distributed", shard: 2, num_shards: 3, runner: "${{ needs.get-label-type.outputs.label-type }}linux.8xlarge.nvidia.gpu" }, + { config: "distributed", shard: 3, num_shards: 3, runner: "${{ needs.get-label-type.outputs.label-type }}linux.8xlarge.nvidia.gpu" }, ]} secrets: inherit @@ -307,16 +306,16 @@ jobs: uses: ./.github/workflows/_linux-build.yml needs: get-label-type with: - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build-environment: linux-focal-cuda12.1-py3.10-gcc9 docker-image-name: pytorch-linux-focal-cuda12.1-cudnn9-py3-gcc9 test-matrix: | { include: [ - { config: "default", shard: 1, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.4xlarge.nvidia.gpu" }, - { config: "default", shard: 2, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.4xlarge.nvidia.gpu" }, - { config: "default", shard: 3, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.4xlarge.nvidia.gpu" }, - { config: "default", shard: 4, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.4xlarge.nvidia.gpu" }, - { config: "default", shard: 5, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.4xlarge.nvidia.gpu" }, + { config: "default", shard: 1, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge.nvidia.gpu" }, + { config: "default", shard: 2, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge.nvidia.gpu" }, + { config: "default", shard: 3, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge.nvidia.gpu" }, + { config: "default", shard: 4, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge.nvidia.gpu" }, + { config: "default", shard: 5, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge.nvidia.gpu" }, ]} secrets: inherit @@ -338,7 +337,7 @@ jobs: uses: ./.github/workflows/_linux-build.yml needs: get-label-type with: - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build-environment: linux-jammy-py3-clang12-mobile-build docker-image-name: pytorch-linux-jammy-py3-clang15-asan build-generates-artifacts: false @@ -353,7 +352,7 @@ jobs: uses: ./.github/workflows/_linux-build.yml needs: get-label-type with: - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build-environment: linux-jammy-cuda11.8-cudnn9-py3.8-clang12 docker-image-name: pytorch-linux-jammy-cuda11.8-cudnn9-py3.8-clang12 test-matrix: | @@ -367,7 +366,7 @@ jobs: uses: ./.github/workflows/_linux-build.yml needs: get-label-type with: - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build-environment: linux-focal-py3-clang9-mobile-custom-build-static docker-image-name: pytorch-linux-focal-py3-clang9-android-ndk-r21e build-generates-artifacts: false @@ -382,12 +381,12 @@ jobs: uses: ./.github/workflows/_linux-build.yml needs: get-label-type with: - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build-environment: linux-focal-py3.8-clang9-xla docker-image-name: 308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/xla_base:v1.1-lite test-matrix: | { include: [ - { config: "xla", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.12xlarge" }, + { config: "xla", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.12xlarge" }, ]} secrets: inherit @@ -425,13 +424,13 @@ jobs: uses: ./.github/workflows/_bazel-build-test.yml needs: get-label-type with: - runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.large" + runner: "${{ needs.get-label-type.outputs.label-type }}linux.large" build-environment: linux-focal-cuda12.1-py3.10-gcc9-bazel-test docker-image-name: pytorch-linux-focal-cuda12.1-cudnn9-py3-gcc9 cuda-version: cpu test-matrix: | { include: [ - { config: "default", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.4xlarge" }, + { config: "default", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge" }, ]} secrets: inherit @@ -440,13 +439,13 @@ jobs: uses: ./.github/workflows/_bazel-build-test.yml needs: get-label-type with: - runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.large" + runner: "${{ needs.get-label-type.outputs.label-type }}linux.large" build-environment: linux-focal-cuda12.1-py3.10-gcc9-bazel-test docker-image-name: pytorch-linux-focal-cuda12.1-cudnn9-py3-gcc9 cuda-version: "12.1" test-matrix: | { include: [ - { config: "default", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.4xlarge.nvidia.gpu" }, + { config: "default", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge.nvidia.gpu" }, ]} secrets: inherit @@ -455,13 +454,13 @@ jobs: uses: ./.github/workflows/_bazel-build-test.yml needs: get-label-type with: - runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.large" + runner: "${{ needs.get-label-type.outputs.label-type }}linux.large" build-environment: linux-focal-cuda12.4-py3.10-gcc9-bazel-test docker-image-name: pytorch-linux-focal-cuda12.4-cudnn9-py3-gcc9 cuda-version: "12.4" test-matrix: | { include: [ - { config: "default", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.4xlarge.nvidia.gpu" }, + { config: "default", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge.nvidia.gpu" }, ]} secrets: inherit @@ -494,7 +493,7 @@ jobs: uses: ./.github/workflows/_linux-build.yml needs: get-label-type with: - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build-environment: linux-jammy-py3.8-gcc111-mobile-lightweight-dispatch-build docker-image-name: pytorch-linux-jammy-py3.8-gcc11 build-generates-artifacts: false @@ -511,7 +510,7 @@ jobs: uses: ./.github/workflows/_linux-build.yml needs: get-label-type with: - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build-environment: linux-focal-rocm6.1-py3.8 docker-image-name: pytorch-linux-focal-rocm-n-py3 sync-tag: rocm-build @@ -528,17 +527,17 @@ jobs: uses: ./.github/workflows/_linux-build.yml needs: get-label-type with: - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build-environment: linux-focal-cuda12.1-py3.10-gcc9-sm86 docker-image-name: pytorch-linux-focal-cuda12.1-cudnn9-py3-gcc9 cuda-arch-list: 8.6 test-matrix: | { include: [ - { config: "default", shard: 1, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.g5.4xlarge.nvidia.gpu" }, - { config: "default", shard: 2, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.g5.4xlarge.nvidia.gpu" }, - { config: "default", shard: 3, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.g5.4xlarge.nvidia.gpu" }, - { config: "default", shard: 4, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.g5.4xlarge.nvidia.gpu" }, - { config: "default", shard: 5, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.g5.4xlarge.nvidia.gpu" }, + { config: "default", shard: 1, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g5.4xlarge.nvidia.gpu" }, + { config: "default", shard: 2, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g5.4xlarge.nvidia.gpu" }, + { config: "default", shard: 3, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g5.4xlarge.nvidia.gpu" }, + { config: "default", shard: 4, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g5.4xlarge.nvidia.gpu" }, + { config: "default", shard: 5, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g5.4xlarge.nvidia.gpu" }, ]} secrets: inherit @@ -559,12 +558,12 @@ jobs: uses: ./.github/workflows/_linux-build.yml needs: get-label-type with: - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build-environment: linux-jammy-py3-clang12-executorch docker-image-name: pytorch-linux-jammy-py3-clang12-executorch test-matrix: | { include: [ - { config: "executorch", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.2xlarge" }, + { config: "executorch", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" }, ]} secrets: inherit @@ -583,18 +582,18 @@ jobs: uses: ./.github/workflows/_linux-build.yml needs: get-label-type with: - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" use_split_build: True build-environment: linux-focal-py3.12-clang10 docker-image-name: pytorch-linux-focal-py3.12-clang10 test-matrix: | { include: [ - { config: "default", shard: 1, num_shards: 3, runner: "amz2023.linux.2xlarge" }, - { config: "default", shard: 2, num_shards: 3, runner: "amz2023.linux.2xlarge" }, - { config: "default", shard: 3, num_shards: 3, runner: "amz2023.linux.2xlarge" }, - { config: "dynamo", shard: 1, num_shards: 3, runner: "amz2023.linux.2xlarge" }, - { config: "dynamo", shard: 2, num_shards: 3, runner: "amz2023.linux.2xlarge" }, - { config: "dynamo", shard: 3, num_shards: 3, runner: "amz2023.linux.2xlarge" }, + { config: "default", shard: 1, num_shards: 3, runner: "linux.2xlarge" }, + { config: "default", shard: 2, num_shards: 3, runner: "linux.2xlarge" }, + { config: "default", shard: 3, num_shards: 3, runner: "linux.2xlarge" }, + { config: "dynamo", shard: 1, num_shards: 3, runner: "linux.2xlarge" }, + { config: "dynamo", shard: 2, num_shards: 3, runner: "linux.2xlarge" }, + { config: "dynamo", shard: 3, num_shards: 3, runner: "linux.2xlarge" }, ]} secrets: inherit @@ -614,7 +613,7 @@ jobs: uses: ./.github/workflows/_linux-build.yml needs: get-label-type with: - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build-environment: linux-focal-cuda12.1-py3.10-gcc9-sm75 docker-image-name: pytorch-linux-focal-cuda12.1-cudnn9-py3-gcc9-inductor-benchmarks cuda-arch-list: '7.5' diff --git a/.github/workflows/slow.yml b/.github/workflows/slow.yml index 5f7179b6303267..932e15fe1230a2 100644 --- a/.github/workflows/slow.yml +++ b/.github/workflows/slow.yml @@ -50,18 +50,18 @@ jobs: uses: ./.github/workflows/_linux-build.yml needs: get-label-type with: - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023" + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build-environment: linux-focal-cuda12.1-py3-gcc9-slow-gradcheck docker-image-name: pytorch-linux-focal-cuda12.1-cudnn9-py3-gcc9 cuda-arch-list: 8.6 test-matrix: | { include: [ - { config: "default", shard: 1, num_shards: 6, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.g5.4xlarge.nvidia.gpu" }, - { config: "default", shard: 2, num_shards: 6, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.g5.4xlarge.nvidia.gpu" }, - { config: "default", shard: 3, num_shards: 6, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.g5.4xlarge.nvidia.gpu" }, - { config: "default", shard: 4, num_shards: 6, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.g5.4xlarge.nvidia.gpu" }, - { config: "default", shard: 5, num_shards: 6, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.g5.4xlarge.nvidia.gpu" }, - { config: "default", shard: 6, num_shards: 6, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.g5.4xlarge.nvidia.gpu" }, + { config: "default", shard: 1, num_shards: 6, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g5.4xlarge.nvidia.gpu" }, + { config: "default", shard: 2, num_shards: 6, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g5.4xlarge.nvidia.gpu" }, + { config: "default", shard: 3, num_shards: 6, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g5.4xlarge.nvidia.gpu" }, + { config: "default", shard: 4, num_shards: 6, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g5.4xlarge.nvidia.gpu" }, + { config: "default", shard: 5, num_shards: 6, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g5.4xlarge.nvidia.gpu" }, + { config: "default", shard: 6, num_shards: 6, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g5.4xlarge.nvidia.gpu" }, ]} linux-focal-cuda12_1-py3-gcc9-slow-gradcheck-test: @@ -81,14 +81,14 @@ jobs: uses: ./.github/workflows/_linux-build.yml needs: get-label-type with: - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build-environment: linux-focal-cuda12.1-py3.10-gcc9-sm86 docker-image-name: pytorch-linux-focal-cuda12.1-cudnn9-py3-gcc9 cuda-arch-list: 8.6 test-matrix: | { include: [ - { config: "slow", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.g5.4xlarge.nvidia.gpu" }, - { config: "slow", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.g5.4xlarge.nvidia.gpu" }, + { config: "slow", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g5.4xlarge.nvidia.gpu" }, + { config: "slow", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g5.4xlarge.nvidia.gpu" }, ]} linux-focal-cuda12_1-py3_10-gcc9-sm86-test: @@ -107,13 +107,13 @@ jobs: uses: ./.github/workflows/_linux-build.yml needs: get-label-type with: - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build-environment: linux-focal-py3.8-clang10 docker-image-name: pytorch-linux-focal-py3.8-clang10 test-matrix: | { include: [ - { config: "slow", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.2xlarge" }, - { config: "slow", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.2xlarge" }, + { config: "slow", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" }, + { config: "slow", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" }, ]} linux-focal-py3_8-clang10-test: @@ -132,7 +132,7 @@ jobs: uses: ./.github/workflows/_linux-build.yml needs: get-label-type with: - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build-environment: linux-focal-rocm6.1-py3.8 docker-image-name: pytorch-linux-focal-rocm-n-py3 test-matrix: | @@ -160,14 +160,14 @@ jobs: uses: ./.github/workflows/_linux-build.yml needs: get-label-type with: - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build-environment: linux-jammy-py3.10-clang15-asan docker-image-name: pytorch-linux-jammy-py3-clang15-asan test-matrix: | { include: [ - { config: "slow", shard: 1, num_shards: 3, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.4xlarge" }, - { config: "slow", shard: 2, num_shards: 3, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.4xlarge" }, - { config: "slow", shard: 3, num_shards: 3, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.4xlarge" }, + { config: "slow", shard: 1, num_shards: 3, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge" }, + { config: "slow", shard: 2, num_shards: 3, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge" }, + { config: "slow", shard: 3, num_shards: 3, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge" }, ]} sync-tag: asan-build secrets: inherit diff --git a/.github/workflows/trunk.yml b/.github/workflows/trunk.yml index b9361e0d2eed16..8696ba99957258 100644 --- a/.github/workflows/trunk.yml +++ b/.github/workflows/trunk.yml @@ -48,17 +48,17 @@ jobs: uses: ./.github/workflows/_linux-build.yml needs: get-label-type with: - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build-environment: linux-focal-cuda12.4-py3.10-gcc9-sm86 docker-image-name: pytorch-linux-focal-cuda12.4-cudnn9-py3-gcc9 cuda-arch-list: 8.6 test-matrix: | { include: [ - { config: "default", shard: 1, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.g5.4xlarge.nvidia.gpu" }, - { config: "default", shard: 2, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.g5.4xlarge.nvidia.gpu" }, - { config: "default", shard: 3, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.g5.4xlarge.nvidia.gpu" }, - { config: "default", shard: 4, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.g5.4xlarge.nvidia.gpu" }, - { config: "default", shard: 5, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.g5.4xlarge.nvidia.gpu" }, + { config: "default", shard: 1, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g5.4xlarge.nvidia.gpu" }, + { config: "default", shard: 2, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g5.4xlarge.nvidia.gpu" }, + { config: "default", shard: 3, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g5.4xlarge.nvidia.gpu" }, + { config: "default", shard: 4, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g5.4xlarge.nvidia.gpu" }, + { config: "default", shard: 5, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g5.4xlarge.nvidia.gpu" }, ]} linux-focal-cuda12_4-py3_10-gcc9-sm86-test: @@ -80,7 +80,7 @@ jobs: build-environment: libtorch-linux-focal-cuda12.1-py3.7-gcc9 docker-image-name: pytorch-linux-focal-cuda12.1-cudnn9-py3-gcc9 build-generates-artifacts: false - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" runner: "linux.4xlarge" test-matrix: | { include: [ @@ -93,7 +93,7 @@ jobs: uses: ./.github/workflows/_linux-build.yml needs: get-label-type with: - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build-environment: linux-focal-cuda12.1-py3.10-gcc9-no-ops docker-image-name: pytorch-linux-focal-cuda12.1-cudnn9-py3-gcc9 test-matrix: | @@ -109,7 +109,7 @@ jobs: build-environment: libtorch-linux-focal-cuda12.4-py3.7-gcc9 docker-image-name: pytorch-linux-focal-cuda12.4-cudnn9-py3-gcc9 build-generates-artifacts: false - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" runner: "linux.4xlarge" test-matrix: | { include: [ @@ -122,7 +122,7 @@ jobs: uses: ./.github/workflows/_linux-build.yml needs: get-label-type with: - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build-environment: linux-focal-cuda12.4-py3.10-gcc9-no-ops docker-image-name: pytorch-linux-focal-cuda12.4-cudnn9-py3-gcc9 test-matrix: | @@ -138,7 +138,7 @@ jobs: docker-image-name: pytorch-linux-focal-py3-clang9-android-ndk-r21e test-matrix: | { include: [ - { config: "default", shard: 1, num_shards: 1, runner: "amz2023.linux.2xlarge" }, + { config: "default", shard: 1, num_shards: 1, runner: "linux.2xlarge" }, ]} macos-py3-arm64-build: @@ -228,7 +228,7 @@ jobs: uses: ./.github/workflows/_linux-build.yml needs: get-label-type with: - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build-environment: linux-focal-rocm6.1-py3.8 docker-image-name: pytorch-linux-focal-rocm-n-py3 sync-tag: rocm-build @@ -260,20 +260,20 @@ jobs: uses: ./.github/workflows/_linux-build.yml needs: get-label-type with: - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" use_split_build: true build-environment: linux-focal-cuda12.4-py3.10-gcc9 docker-image-name: pytorch-linux-focal-cuda12.4-cudnn9-py3-gcc9 test-matrix: | { include: [ - { config: "nogpu_AVX512", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.2xlarge" }, - { config: "nogpu_NO_AVX2", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.2xlarge" }, - { config: "jit_legacy", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.4xlarge.nvidia.gpu" }, - { config: "default", shard: 1, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.4xlarge.nvidia.gpu" }, - { config: "default", shard: 2, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.4xlarge.nvidia.gpu" }, - { config: "default", shard: 3, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.4xlarge.nvidia.gpu" }, - { config: "default", shard: 4, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.4xlarge.nvidia.gpu" }, - { config: "default", shard: 5, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.4xlarge.nvidia.gpu" }, + { config: "nogpu_AVX512", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" }, + { config: "nogpu_NO_AVX2", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" }, + { config: "jit_legacy", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge.nvidia.gpu" }, + { config: "default", shard: 1, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge.nvidia.gpu" }, + { config: "default", shard: 2, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge.nvidia.gpu" }, + { config: "default", shard: 3, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge.nvidia.gpu" }, + { config: "default", shard: 4, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge.nvidia.gpu" }, + { config: "default", shard: 5, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge.nvidia.gpu" }, ]} linux-focal-cuda12_4-py3_10-gcc9-experimental-split-build-test: @@ -292,15 +292,15 @@ jobs: uses: ./.github/workflows/_linux-build.yml needs: get-label-type with: - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" use_split_build: true build-environment: linux-focal-cuda11.8-py3.10-gcc9 docker-image-name: pytorch-linux-focal-cuda11.8-cudnn9-py3-gcc9 test-matrix: | { include: [ - { config: "distributed", shard: 1, num_shards: 3, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.8xlarge.nvidia.gpu" }, - { config: "distributed", shard: 2, num_shards: 3, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.8xlarge.nvidia.gpu" }, - { config: "distributed", shard: 3, num_shards: 3, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.8xlarge.nvidia.gpu" }, + { config: "distributed", shard: 1, num_shards: 3, runner: "${{ needs.get-label-type.outputs.label-type }}linux.8xlarge.nvidia.gpu" }, + { config: "distributed", shard: 2, num_shards: 3, runner: "${{ needs.get-label-type.outputs.label-type }}linux.8xlarge.nvidia.gpu" }, + { config: "distributed", shard: 3, num_shards: 3, runner: "${{ needs.get-label-type.outputs.label-type }}linux.8xlarge.nvidia.gpu" }, ]} linux-focal-cuda11_8-py3_10-gcc9-experimental-split-build-test: From eb8f9220c437dd2c3397640f5e811780bf454139 Mon Sep 17 00:00:00 2001 From: drisspg Date: Mon, 26 Aug 2024 18:35:51 -0700 Subject: [PATCH 0068/1018] [FlexAttention] Create new variables for the subgraphs (#134507) Pull Request resolved: https://github.com/pytorch/pytorch/pull/134507 Approved by: https://github.com/yanboliang, https://github.com/BoyuanFeng ghstack dependencies: #134495 --- test/inductor/test_flex_attention.py | 19 +++++++++++++++++++ torch/_inductor/kernel/flex_attention.py | 15 +++++++++------ 2 files changed, 28 insertions(+), 6 deletions(-) diff --git a/test/inductor/test_flex_attention.py b/test/inductor/test_flex_attention.py index ecc2f7d8e2978a..0a1fe403bdd0cf 100644 --- a/test/inductor/test_flex_attention.py +++ b/test/inductor/test_flex_attention.py @@ -1641,6 +1641,25 @@ def test_force_write_lse(self): torch.testing.assert_close(lse_eager, lse_compiled, atol=3e-3, rtol=0) + @supported_platform + def test_small_q_kv_len(self): + make_tensor = functools.partial( + torch.ones, + (1, 1, 1, 16), + device="cuda", + dtype=torch.float32, + requires_grad=False, + ) + query, key, value = make_tensor(), make_tensor(), make_tensor() + kernel_options = {"FORCE_USE_FLEX_ATTENTION": True} + out_eager, lse_eager = flex_attention(query, key, value, return_lse=True, kernel_options=kernel_options) + + flex_compile = torch.compile(flex_attention, fullgraph=True) + out_compiled, lse_compiled = flex_compile(query, key, value, return_lse=True, kernel_options=kernel_options) + + assert torch.equal(out_eager, out_compiled) + assert torch.equal(lse_eager, lse_compiled) + @supported_platform def test_causal_block_non_divisible_with_captured_buffer(self): Q_S = S - 3 diff --git a/torch/_inductor/kernel/flex_attention.py b/torch/_inductor/kernel/flex_attention.py index de7ff7ee913610..71b59058e16dfc 100644 --- a/torch/_inductor/kernel/flex_attention.py +++ b/torch/_inductor/kernel/flex_attention.py @@ -439,8 +439,11 @@ def forward_block_mn( # If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements, # which is larger than the actual number of elements. To avoid access memory out of bound, # we need to mask out the elements that are out of Q_LEN & KV_LEN. - offs_m = offs_m % Q_LEN - offs_n = offs_n % KV_LEN + m = offs_m % Q_LEN + n = offs_n % KV_LEN + else: + m = offs_m + n = offs_n {{ modification( subgraph_number=0, @@ -448,8 +451,8 @@ def forward_block_mn( score="qk", b="off_z", h="off_h", - m="offs_m", - n="offs_n", + m="m", + n="n", out="qk" ) | indent_except_first(1) }} @@ -464,8 +467,8 @@ def forward_block_mn( score="qk", b="off_z", h="off_h", - m="offs_m", - n="offs_n", + m="m", + n="n", ) | indent_except_first(2) }} if CHECK_BLOCK_BOUNDARY: From daa3f9902bacec28fff1e983c97dcc9350c8fa85 Mon Sep 17 00:00:00 2001 From: drisspg Date: Mon, 26 Aug 2024 18:35:51 -0700 Subject: [PATCH 0069/1018] [FlexAttention] Remove unused code (#134511) Pull Request resolved: https://github.com/pytorch/pytorch/pull/134511 Approved by: https://github.com/yanboliang ghstack dependencies: #134495, #134507 --- test/inductor/test_flex_attention.py | 7 ++++++- torch/_inductor/kernel/flex_attention.py | 3 --- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/test/inductor/test_flex_attention.py b/test/inductor/test_flex_attention.py index 0a1fe403bdd0cf..6e42158ef4ab11 100644 --- a/test/inductor/test_flex_attention.py +++ b/test/inductor/test_flex_attention.py @@ -1648,7 +1648,7 @@ def test_small_q_kv_len(self): (1, 1, 1, 16), device="cuda", dtype=torch.float32, - requires_grad=False, + requires_grad=True, ) query, key, value = make_tensor(), make_tensor(), make_tensor() kernel_options = {"FORCE_USE_FLEX_ATTENTION": True} @@ -1660,6 +1660,11 @@ def test_small_q_kv_len(self): assert torch.equal(out_eager, out_compiled) assert torch.equal(lse_eager, lse_compiled) + grads_eager = torch.autograd.grad(out_eager.sum(), (query, key, value)) + grads_compile = torch.autograd.grad(out_compiled.sum(), (query, key, value)) + + torch.testing.assert_close(grads_eager, grads_compile) + @supported_platform def test_causal_block_non_divisible_with_captured_buffer(self): Q_S = S - 3 diff --git a/torch/_inductor/kernel/flex_attention.py b/torch/_inductor/kernel/flex_attention.py index 71b59058e16dfc..8e90eacaab72b5 100644 --- a/torch/_inductor/kernel/flex_attention.py +++ b/torch/_inductor/kernel/flex_attention.py @@ -1606,9 +1606,6 @@ def flex_attention_backward(*args, **kwargs): ] ) - if _use_flex_decoding(query, kernel_options): - raise NotImplementedError("Flex decoding backward pass is not implemented. ") - device = query.get_device() dtype = query.get_dtype() Bq, Hq, seq_len_q, qk_head_dim = query.get_size() From f0702bc5cbd3507de5b2f790eb0b91ec9a7a37ef Mon Sep 17 00:00:00 2001 From: drisspg Date: Mon, 26 Aug 2024 18:35:51 -0700 Subject: [PATCH 0070/1018] [FlexAttention] Fix Sparse block multiple to ceildiv instead for floor div (#134538) Pull Request resolved: https://github.com/pytorch/pytorch/pull/134538 Approved by: https://github.com/yanboliang ghstack dependencies: #134495, #134507, #134511 --- test/inductor/test_flex_decoding.py | 38 +++++++++++++++++++++++-- torch/_inductor/kernel/flex_decoding.py | 4 +-- 2 files changed, 38 insertions(+), 4 deletions(-) diff --git a/test/inductor/test_flex_decoding.py b/test/inductor/test_flex_decoding.py index 3cf3981b0f1782..bc41ff2fe8b16a 100644 --- a/test/inductor/test_flex_decoding.py +++ b/test/inductor/test_flex_decoding.py @@ -14,6 +14,7 @@ from torch.nn.attention.flex_attention import ( _create_empty_block_mask, _identity, + BlockMask, create_block_mask, flex_attention, ) @@ -253,7 +254,7 @@ def _check_out( def run_test( self, - score_mod: Callable, + score_mod: Optional[Callable], dtype: torch.dtype = torch.float16, Q_B: int = B, Q_H: int = Hq, @@ -263,7 +264,11 @@ def run_test( KV_H: int = Hkv, KV_S: int = S, V_D: int = D, + block_mask: Optional[BlockMask] = None, ): + assert ( + score_mod is not None or block_mask is not None + ), "Must provide score_mod or block_mask" assert Q_H % KV_H == 0 q = torch.randn( (Q_B, Q_H, Q_S, Q_D), @@ -280,7 +285,6 @@ def run_test( q_ref, k_ref, v_ref = query_key_value_clones(q, k, v) q_gold, k_gold, v_gold = query_key_value_clones(q, k, v, torch.float64) - block_mask = None sdpa_partial = create_attention( score_mod, block_mask, enable_gqa=(not Q_H == KV_H) ) @@ -943,6 +947,36 @@ def func(q, k, v, score_mod): code[0] ) + @supported_platform + def test_non_sparse_mulitple_block_size(self): + def generate_causal_offset(offset: torch.Tensor): + def causal_offset_mask(b, h, q_idx, kv_idx): + return (offset + q_idx) >= kv_idx + + return causal_offset_mask + + def noop(score, b, h, q_idx, kv_idx): + return score + + mod = generate_causal_offset( + torch.tensor(192, device="cuda", dtype=torch.int32) + ) + block_mask = create_block_mask(mod, 1, 1, 1, 65) + + self.run_test( + score_mod=None, + dtype=torch.float32, + block_mask=block_mask, + Q_B=1, + Q_H=1, + Q_S=1, + Q_D=16, + KV_B=1, + KV_H=1, + KV_S=65, + V_D=16, + ) + @supported_platform def test_do_not_trigger_dynamic_shapes_on_empty_block_mask(self): torch._dynamo.reset() diff --git a/torch/_inductor/kernel/flex_decoding.py b/torch/_inductor/kernel/flex_decoding.py index 3520b618ef0a67..870e6194aec1e1 100644 --- a/torch/_inductor/kernel/flex_decoding.py +++ b/torch/_inductor/kernel/flex_decoding.py @@ -161,7 +161,7 @@ def flex_decoding_grid(batch_size, kv_heads, gqa_group_size, n_keys, d_model, me # first kv block we're loading # last valid block according to sparse mask - block_n_last_valid = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(KV_LEN // BLOCK_N, 1)) + block_n_last_valid = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) K_block_ptr = tl.make_block_ptr( base=K + k_offset, @@ -207,7 +207,7 @@ def flex_decoding_grid(batch_size, kv_heads, gqa_group_size, n_keys, d_model, me off_n = tl.load(kv_indices + indices_idx) * SPARSE_KV_BLOCK_SIZE + off_n_block_in_sparse * BLOCK_N # last valid block according to sparse mask - block_n_last_valid = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(KV_LEN // BLOCK_N, 1)) + block_n_last_valid = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) K_block_ptr = tl.make_block_ptr( base=K + k_offset, From c4c140cc7bb0c91453c528615088c44b5a2b5da4 Mon Sep 17 00:00:00 2001 From: FEI Date: Tue, 27 Aug 2024 11:10:43 +0000 Subject: [PATCH 0071/1018] Provide default value None for the attn_bias parameter(#133981) (#133986) Fixes #133981 Pull Request resolved: https://github.com/pytorch/pytorch/pull/133986 Approved by: https://github.com/ezyang --- aten/src/ATen/native/native_functions.yaml | 2 +- tools/autograd/derivatives.yaml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index e46cdec768cdff..1f31c5cd198e6a 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -14763,7 +14763,7 @@ CPU: _scaled_dot_product_flash_attention_cpu tags: nondeterministic_seeded -- func: _scaled_dot_product_fused_attention_overrideable(Tensor query, Tensor key, Tensor value, Tensor? attn_bias, float dropout_p=0.0, bool is_causal=False, bool return_debug_mask=False, *, float? scale=None) -> (Tensor output, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, Tensor philox_seed, Tensor philox_offset, Tensor debug_attn_mask) +- func: _scaled_dot_product_fused_attention_overrideable(Tensor query, Tensor key, Tensor value, Tensor? attn_bias=None, float dropout_p=0.0, bool is_causal=False, bool return_debug_mask=False, *, float? scale=None) -> (Tensor output, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, Tensor philox_seed, Tensor philox_offset, Tensor debug_attn_mask) dispatch: CompositeExplicitAutograd: _scaled_dot_product_fused_attention_overrideable tags: nondeterministic_seeded diff --git a/tools/autograd/derivatives.yaml b/tools/autograd/derivatives.yaml index 96c55c666b678c..9f7ea3fbeb4ff4 100644 --- a/tools/autograd/derivatives.yaml +++ b/tools/autograd/derivatives.yaml @@ -2871,7 +2871,7 @@ output_differentiability: [True, False, False, False, False, False, False, False, False] query, key, value: _scaled_dot_product_cudnn_attention_backward_symint(grad, query, key, value, output, logsumexp, philox_seed, philox_offset, attn_bias, cum_seq_q, cum_seq_k, max_q, max_k, dropout_p, is_causal, scale) -- name: _scaled_dot_product_fused_attention_overrideable(Tensor query, Tensor key, Tensor value, Tensor? attn_bias, float dropout_p=0.0, bool is_causal=False, bool return_debug_mask=False, *, float? scale=None) -> (Tensor output, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, Tensor philox_seed, Tensor philox_offset, Tensor debug_attn_mask) +- name: _scaled_dot_product_fused_attention_overrideable(Tensor query, Tensor key, Tensor value, Tensor? attn_bias=None, float dropout_p=0.0, bool is_causal=False, bool return_debug_mask=False, *, float? scale=None) -> (Tensor output, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, Tensor philox_seed, Tensor philox_offset, Tensor debug_attn_mask) output_differentiability: [True, False, False, False, False, False, False, False, False] query, key, value, attn_bias: _scaled_dot_product_fused_attention_overrideable_backward_symint(grad, query, key, value, attn_bias, grad_input_mask, output, logsumexp, cum_seq_q, cum_seq_k, max_q, max_k, dropout_p, is_causal, philox_seed, philox_offset, scale) From c86074ac4ae290142e0bbe611af1d1c37f4d3fdf Mon Sep 17 00:00:00 2001 From: xinyu-intel Date: Tue, 27 Aug 2024 11:34:35 +0000 Subject: [PATCH 0072/1018] Implementation for rng ops on hpu and xpu (#133068) implementation for high_order_op::run_and_save_rng_state and high_order_op::run_with_rng_state on hpu Pull Request resolved: https://github.com/pytorch/pytorch/pull/133068 Approved by: https://github.com/jgong5, https://github.com/EikanWang, https://github.com/jansel, https://github.com/anijain2305 --- test/inductor/test_torchinductor.py | 1 - torch/_prims/rng_prims.py | 46 +++++++++++++++++++++++++++-- 2 files changed, 44 insertions(+), 3 deletions(-) diff --git a/test/inductor/test_torchinductor.py b/test/inductor/test_torchinductor.py index 767d68dfdbb835..512d606920df2d 100644 --- a/test/inductor/test_torchinductor.py +++ b/test/inductor/test_torchinductor.py @@ -7951,7 +7951,6 @@ def g(x): out = [torch.empty_like(x) for _ in range(2)] y = g(x) - @expectedFailureXPU def test_functionalize_rng_wrappers(self): # Ideally, we would like to use torch.compile for these operators. But # currently the plan is to introduce these operators at the partitioner diff --git a/torch/_prims/rng_prims.py b/torch/_prims/rng_prims.py index f94a0d06ad6b4b..a10106cacc9845 100644 --- a/torch/_prims/rng_prims.py +++ b/torch/_prims/rng_prims.py @@ -142,6 +142,10 @@ def get_device(args, kwargs): devices = {arg.device.type for arg in args if isinstance(arg, torch.Tensor)} if any(dev == "cuda" for dev in devices): return "cuda" + elif any(dev == "xpu" for dev in devices): + return "xpu" + elif any(dev == "hpu" for dev in devices): + return "hpu" elif any(dev == "cpu" for dev in devices): return "cpu" return None @@ -166,9 +170,24 @@ def impl_cuda(op, *args, **kwargs): def impl_cpu(op, *args, **kwargs): return torch.get_rng_state(), op(*args, **kwargs) + @run_and_save_rng_state.py_impl(DispatchKey.HPU) + def impl_hpu(op, *args, **kwargs): + if hasattr(torch, "hpu"): + return torch.hpu.get_rng_state(), op(*args, **kwargs) + raise RuntimeError("functionalize a hpu RNG operator is not supported.") + + @run_and_save_rng_state.py_impl(DispatchKey.XPU) + def impl_xpu(op, *args, **kwargs): + return torch.xpu.get_rng_state(), op(*args, **kwargs) + @run_and_save_rng_state.py_impl(DispatchKey.BackendSelect) def impl_backend_select(op, *args, **kwargs): - impl_map = {"cuda": impl_cuda, "cpu": impl_cpu} + impl_map = { + "cuda": impl_cuda, + "cpu": impl_cpu, + "hpu": impl_hpu, + "xpu": impl_xpu, + } device = get_device(args, kwargs) assert device in impl_map, f"Backend not supported for {device}" impl = impl_map[device] @@ -220,6 +239,24 @@ def impl_cpu(rng_state, op, *args, **kwargs): torch.set_rng_state(current_state) return out + @run_with_rng_state.py_impl(DispatchKey.HPU) + def impl_hpu(rng_state, op, *args, **kwargs): + if hasattr(torch, "hpu"): + current_state = torch.hpu.get_rng_state() + torch.hpu.set_rng_state(rng_state) + out = op(*args, **kwargs) + torch.hpu.set_rng_state(current_state) + return out + raise RuntimeError("functionalize a hpu RNG operator is not supported.") + + @run_with_rng_state.py_impl(DispatchKey.XPU) + def impl_xpu(rng_state, op, *args, **kwargs): + current_state = torch.xpu.get_rng_state() + torch.xpu.set_rng_state(rng_state) + out = op(*args, **kwargs) + torch.xpu.set_rng_state(current_state) + return out + @run_with_rng_state.py_impl(ProxyTorchDispatchMode) def impl_proxy_dispatch_mode(mode, rng_state, op, *args, **kwargs): # TODO: you don't need to do this, the dispatch here already disabled @@ -235,7 +272,12 @@ def impl_proxy_dispatch_mode(mode, rng_state, op, *args, **kwargs): @run_with_rng_state.py_impl(DispatchKey.BackendSelect) def impl_backend_select(rng_state, op, *args, **kwargs): - impl_map = {"cuda": impl_cuda, "cpu": impl_cpu} + impl_map = { + "cuda": impl_cuda, + "cpu": impl_cpu, + "hpu": impl_hpu, + "xpu": impl_xpu, + } device = get_device(args, kwargs) assert device in impl_map, f"Backend not supported for {device}" impl = impl_map[device] From b60fa8dbc444e09fc0778b5fc35a19f02871d781 Mon Sep 17 00:00:00 2001 From: rzou Date: Mon, 26 Aug 2024 17:57:44 -0700 Subject: [PATCH 0073/1018] Update partitioner's is_fusible heuristic to respect auto_functionalized (#134490) We say Node a is fusible into node b if node b is an auto_functionalized node that may reinplace node a later on. This PR also changes aten.empty to be recomputable w.r.t the Partitioner (it is, like aten.zeros, cheap to recompute and fusible into other ops). Fixes https://github.com/pytorch/pytorch/issues/134468 Test Plan: - new test Pull Request resolved: https://github.com/pytorch/pytorch/pull/134490 Approved by: https://github.com/Chillee ghstack dependencies: #134364, #134466 --- test/inductor/test_inplacing_pass.py | 49 +++++++++++++++++++++++++--- test/inductor/test_perf.py | 29 ++++++++++++++++ torch/_functorch/partitioners.py | 20 ++++++++++++ 3 files changed, 93 insertions(+), 5 deletions(-) diff --git a/test/inductor/test_inplacing_pass.py b/test/inductor/test_inplacing_pass.py index c3aa3996053dce..fef78d8cafb2c8 100644 --- a/test/inductor/test_inplacing_pass.py +++ b/test/inductor/test_inplacing_pass.py @@ -9,7 +9,12 @@ from torch._higher_order_ops.auto_functionalize import auto_functionalized from torch._inductor.fx_passes.reinplace import reinplace_inplaceable_ops_core from torch._inductor.test_case import run_tests, TestCase as InductorTestCase -from torch.testing._internal.common_utils import IS_LINUX +from torch.testing._internal.common_utils import ( + instantiate_parametrized_tests, + IS_LINUX, + parametrize, + subtest, +) from torch.testing._internal.inductor_utils import HAS_CUDA from torch.testing._internal.logging_utils import logs_to_string @@ -25,9 +30,9 @@ def num_reinplacing_failures(): return counters["inductor"]["possibly_missed_reinplacing_opportunities"] -@torch.library.custom_op("_reinplacing::sin", mutates_args={"out"}) -def sin(x: torch.Tensor, out: torch.Tensor) -> None: - out.copy_(x.sin()) +@torch.library.custom_op("_reinplacing::sin", mutates_args={"result"}) +def sin(x: torch.Tensor, result: torch.Tensor) -> None: + result.copy_(x.sin()) @torch.library.custom_op("_reinplacing::sin_cos", mutates_args={"out_sin", "out_cos"}) @@ -95,7 +100,7 @@ def test_counters(self): def f(x): out = torch.empty_like(x) - _, new_out = auto_functionalized(sin._opoverload, x=x, out=out) + _, new_out = auto_functionalized(sin._opoverload, x=x, result=out) y = out * new_out return new_out, y @@ -182,6 +187,40 @@ def f(b): # Both list inputs failed to reinplace. So we should have emitted clones for them. self.assertEqual(post_grad_graphs.count("aten.clone"), 2) + @parametrize( + "factory_op", + [ + subtest(torch.ones_like, name="ones_like"), + subtest(torch.empty_like, name="empty_like"), + ], + ) + def test_partitioner_recomputes_factory(self, factory_op): + class MySin(torch.autograd.Function): + @staticmethod + def forward(ctx, x): + out = factory_op(x) + sin(x, out) + ctx.save_for_backward(out) + return out + + @staticmethod + def backward(ctx, grad): + (saved,) = ctx.saved_tensors + out = factory_op(grad) + sin(saved, out) + return out + + @torch.compile(backend="inductor") + def f(x): + return MySin.apply(x) + + x = torch.randn(3, requires_grad=True, device=device) + y = f(x) + self.assertEqual(num_reinplacing_failures(), 0) + + +instantiate_parametrized_tests(TestReinplacingPassCorrectness) + if __name__ == "__main__": if IS_LINUX and HAS_CUDA: diff --git a/test/inductor/test_perf.py b/test/inductor/test_perf.py index a7530f71523c54..83a61df4167b58 100644 --- a/test/inductor/test_perf.py +++ b/test/inductor/test_perf.py @@ -858,6 +858,35 @@ def f(a, b): inp = (T(10, 10), TI(2, mx=5)) self.assertExpectedInline(count_numel(f, *inp), """42""") + @requires_cuda + def test_inplace_custom_op_training(self): + @torch.library.custom_op("_reinplacing::sin", mutates_args={"result"}) + def sin(x: torch.Tensor, result: torch.Tensor) -> None: + result.copy_(x.sin()) + + factory_op = torch.empty_like + + class MySin(torch.autograd.Function): + @staticmethod + def forward(ctx, x): + out = factory_op(x) + sin(x, out) + ctx.save_for_backward(out) + return out + + @staticmethod + def backward(ctx, grad): + (saved,) = ctx.saved_tensors + out = factory_op(grad) + sin(saved, out) + return out + + def f(x): + return MySin.apply(x) + + x = T(3, grad=True) + self.assertExpectedInline(count_numel_train(f, x), """9""") + @requires_cuda def test_inplace_custom_op(self): with torch.library._scoped_library("mylib", "FRAGMENT") as m: diff --git a/torch/_functorch/partitioners.py b/torch/_functorch/partitioners.py index 8716f792a4294f..42b14a66dca85c 100644 --- a/torch/_functorch/partitioners.py +++ b/torch/_functorch/partitioners.py @@ -807,11 +807,29 @@ def solve_min_cut( print("Ops banned from rematerialization: ", ops_ignored) print() + def can_fuse_into_auto_functionalized(a, b): + if b.target != torch.ops.higher_order.auto_functionalized: + return False + mutable_op = b.args[0] + mutable_arg_names = ( + torch._higher_order_ops.auto_functionalize.get_mutable_arg_names(mutable_op) + ) + for name in mutable_arg_names: + arg = b.kwargs[name] + if a is arg: + return True + if isinstance(arg, list): + if a in arg: + return True + return False + def is_fusible(a, b): # We can perform "memory fusion" into a cat, but cat cannot be a # producer to a fusion if get_aten_target(b) == aten.cat: return True + if can_fuse_into_auto_functionalized(a, b): + return True return op_types.is_fusible(a) and op_types.is_fusible(b) try: @@ -1273,6 +1291,8 @@ def get_default_op_list() -> OpTypes: aten.full, aten.as_strided, aten.zeros, + aten.empty, + aten.empty_like, aten.argmax, aten.maximum, prims.iota, From 9966472c6f43c4ffc092f953e6a51253dcb2b4d3 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Tue, 27 Aug 2024 13:05:24 +0000 Subject: [PATCH 0074/1018] Revert "[FlexAttention] Fix Sparse block multiple to ceildiv instead for floor div (#134538)" This reverts commit a34320a6f225061a3b5fe130a5a8fe35ed7a40f9. Reverted https://github.com/pytorch/pytorch/pull/134538 on behalf of https://github.com/albanD due to Broke lint due to too long line ([comment](https://github.com/pytorch/pytorch/pull/134507#issuecomment-2312505955)) --- test/inductor/test_flex_decoding.py | 38 ++----------------------- torch/_inductor/kernel/flex_decoding.py | 4 +-- 2 files changed, 4 insertions(+), 38 deletions(-) diff --git a/test/inductor/test_flex_decoding.py b/test/inductor/test_flex_decoding.py index bc41ff2fe8b16a..3cf3981b0f1782 100644 --- a/test/inductor/test_flex_decoding.py +++ b/test/inductor/test_flex_decoding.py @@ -14,7 +14,6 @@ from torch.nn.attention.flex_attention import ( _create_empty_block_mask, _identity, - BlockMask, create_block_mask, flex_attention, ) @@ -254,7 +253,7 @@ def _check_out( def run_test( self, - score_mod: Optional[Callable], + score_mod: Callable, dtype: torch.dtype = torch.float16, Q_B: int = B, Q_H: int = Hq, @@ -264,11 +263,7 @@ def run_test( KV_H: int = Hkv, KV_S: int = S, V_D: int = D, - block_mask: Optional[BlockMask] = None, ): - assert ( - score_mod is not None or block_mask is not None - ), "Must provide score_mod or block_mask" assert Q_H % KV_H == 0 q = torch.randn( (Q_B, Q_H, Q_S, Q_D), @@ -285,6 +280,7 @@ def run_test( q_ref, k_ref, v_ref = query_key_value_clones(q, k, v) q_gold, k_gold, v_gold = query_key_value_clones(q, k, v, torch.float64) + block_mask = None sdpa_partial = create_attention( score_mod, block_mask, enable_gqa=(not Q_H == KV_H) ) @@ -947,36 +943,6 @@ def func(q, k, v, score_mod): code[0] ) - @supported_platform - def test_non_sparse_mulitple_block_size(self): - def generate_causal_offset(offset: torch.Tensor): - def causal_offset_mask(b, h, q_idx, kv_idx): - return (offset + q_idx) >= kv_idx - - return causal_offset_mask - - def noop(score, b, h, q_idx, kv_idx): - return score - - mod = generate_causal_offset( - torch.tensor(192, device="cuda", dtype=torch.int32) - ) - block_mask = create_block_mask(mod, 1, 1, 1, 65) - - self.run_test( - score_mod=None, - dtype=torch.float32, - block_mask=block_mask, - Q_B=1, - Q_H=1, - Q_S=1, - Q_D=16, - KV_B=1, - KV_H=1, - KV_S=65, - V_D=16, - ) - @supported_platform def test_do_not_trigger_dynamic_shapes_on_empty_block_mask(self): torch._dynamo.reset() diff --git a/torch/_inductor/kernel/flex_decoding.py b/torch/_inductor/kernel/flex_decoding.py index 870e6194aec1e1..3520b618ef0a67 100644 --- a/torch/_inductor/kernel/flex_decoding.py +++ b/torch/_inductor/kernel/flex_decoding.py @@ -161,7 +161,7 @@ def flex_decoding_grid(batch_size, kv_heads, gqa_group_size, n_keys, d_model, me # first kv block we're loading # last valid block according to sparse mask - block_n_last_valid = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + block_n_last_valid = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(KV_LEN // BLOCK_N, 1)) K_block_ptr = tl.make_block_ptr( base=K + k_offset, @@ -207,7 +207,7 @@ def flex_decoding_grid(batch_size, kv_heads, gqa_group_size, n_keys, d_model, me off_n = tl.load(kv_indices + indices_idx) * SPARSE_KV_BLOCK_SIZE + off_n_block_in_sparse * BLOCK_N # last valid block according to sparse mask - block_n_last_valid = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + block_n_last_valid = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(KV_LEN // BLOCK_N, 1)) K_block_ptr = tl.make_block_ptr( base=K + k_offset, From a3b95a501f63f90410c4e83a46496e371cb67ee1 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Tue, 27 Aug 2024 13:05:25 +0000 Subject: [PATCH 0075/1018] Revert "[FlexAttention] Remove unused code (#134511)" This reverts commit 767c47d3c0ee3fc7804918a08de3f94874143a03. Reverted https://github.com/pytorch/pytorch/pull/134511 on behalf of https://github.com/albanD due to Broke lint due to too long line ([comment](https://github.com/pytorch/pytorch/pull/134507#issuecomment-2312505955)) --- test/inductor/test_flex_attention.py | 7 +------ torch/_inductor/kernel/flex_attention.py | 3 +++ 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/test/inductor/test_flex_attention.py b/test/inductor/test_flex_attention.py index 6e42158ef4ab11..0a1fe403bdd0cf 100644 --- a/test/inductor/test_flex_attention.py +++ b/test/inductor/test_flex_attention.py @@ -1648,7 +1648,7 @@ def test_small_q_kv_len(self): (1, 1, 1, 16), device="cuda", dtype=torch.float32, - requires_grad=True, + requires_grad=False, ) query, key, value = make_tensor(), make_tensor(), make_tensor() kernel_options = {"FORCE_USE_FLEX_ATTENTION": True} @@ -1660,11 +1660,6 @@ def test_small_q_kv_len(self): assert torch.equal(out_eager, out_compiled) assert torch.equal(lse_eager, lse_compiled) - grads_eager = torch.autograd.grad(out_eager.sum(), (query, key, value)) - grads_compile = torch.autograd.grad(out_compiled.sum(), (query, key, value)) - - torch.testing.assert_close(grads_eager, grads_compile) - @supported_platform def test_causal_block_non_divisible_with_captured_buffer(self): Q_S = S - 3 diff --git a/torch/_inductor/kernel/flex_attention.py b/torch/_inductor/kernel/flex_attention.py index 8e90eacaab72b5..71b59058e16dfc 100644 --- a/torch/_inductor/kernel/flex_attention.py +++ b/torch/_inductor/kernel/flex_attention.py @@ -1606,6 +1606,9 @@ def flex_attention_backward(*args, **kwargs): ] ) + if _use_flex_decoding(query, kernel_options): + raise NotImplementedError("Flex decoding backward pass is not implemented. ") + device = query.get_device() dtype = query.get_dtype() Bq, Hq, seq_len_q, qk_head_dim = query.get_size() From b1728e27461816f0960b6156f7e0648fb5163684 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Tue, 27 Aug 2024 13:05:25 +0000 Subject: [PATCH 0076/1018] Revert "[FlexAttention] Create new variables for the subgraphs (#134507)" This reverts commit 4d0a44d34a46af6dcc764d55269b30ac537822a0. Reverted https://github.com/pytorch/pytorch/pull/134507 on behalf of https://github.com/albanD due to Broke lint due to too long line ([comment](https://github.com/pytorch/pytorch/pull/134507#issuecomment-2312505955)) --- test/inductor/test_flex_attention.py | 19 ------------------- torch/_inductor/kernel/flex_attention.py | 15 ++++++--------- 2 files changed, 6 insertions(+), 28 deletions(-) diff --git a/test/inductor/test_flex_attention.py b/test/inductor/test_flex_attention.py index 0a1fe403bdd0cf..ecc2f7d8e2978a 100644 --- a/test/inductor/test_flex_attention.py +++ b/test/inductor/test_flex_attention.py @@ -1641,25 +1641,6 @@ def test_force_write_lse(self): torch.testing.assert_close(lse_eager, lse_compiled, atol=3e-3, rtol=0) - @supported_platform - def test_small_q_kv_len(self): - make_tensor = functools.partial( - torch.ones, - (1, 1, 1, 16), - device="cuda", - dtype=torch.float32, - requires_grad=False, - ) - query, key, value = make_tensor(), make_tensor(), make_tensor() - kernel_options = {"FORCE_USE_FLEX_ATTENTION": True} - out_eager, lse_eager = flex_attention(query, key, value, return_lse=True, kernel_options=kernel_options) - - flex_compile = torch.compile(flex_attention, fullgraph=True) - out_compiled, lse_compiled = flex_compile(query, key, value, return_lse=True, kernel_options=kernel_options) - - assert torch.equal(out_eager, out_compiled) - assert torch.equal(lse_eager, lse_compiled) - @supported_platform def test_causal_block_non_divisible_with_captured_buffer(self): Q_S = S - 3 diff --git a/torch/_inductor/kernel/flex_attention.py b/torch/_inductor/kernel/flex_attention.py index 71b59058e16dfc..de7ff7ee913610 100644 --- a/torch/_inductor/kernel/flex_attention.py +++ b/torch/_inductor/kernel/flex_attention.py @@ -439,11 +439,8 @@ def forward_block_mn( # If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements, # which is larger than the actual number of elements. To avoid access memory out of bound, # we need to mask out the elements that are out of Q_LEN & KV_LEN. - m = offs_m % Q_LEN - n = offs_n % KV_LEN - else: - m = offs_m - n = offs_n + offs_m = offs_m % Q_LEN + offs_n = offs_n % KV_LEN {{ modification( subgraph_number=0, @@ -451,8 +448,8 @@ def forward_block_mn( score="qk", b="off_z", h="off_h", - m="m", - n="n", + m="offs_m", + n="offs_n", out="qk" ) | indent_except_first(1) }} @@ -467,8 +464,8 @@ def forward_block_mn( score="qk", b="off_z", h="off_h", - m="m", - n="n", + m="offs_m", + n="offs_n", ) | indent_except_first(2) }} if CHECK_BLOCK_BOUNDARY: From ffa9f9e52e737cc512752135a2620bfb9df9e71a Mon Sep 17 00:00:00 2001 From: Nikita Shulga Date: Tue, 27 Aug 2024 14:44:10 +0000 Subject: [PATCH 0077/1018] Do not use `` on Linux (#134494) Because right now it leads to symbol conflict from binary builds. Use of `std::filesystem::file_exists` was introduced by https://github.com/pytorch/pytorch/pull/126601 and in this PR it is replaced with a very straightforward implementation that calls `stat` on the given path, which is a classic C-way of checking for the file existence. This PR should be reverted once one figures out how to keep `std::filesystem` methods linked into the binary private Fixes symptoms of https://github.com/pytorch/pytorch/issues/133437 Pull Request resolved: https://github.com/pytorch/pytorch/pull/134494 Approved by: https://github.com/atalman, https://github.com/d4l3k --- .../c10d/control_plane/WorkerServer.cpp | 21 +++++++++++++++++-- .../aoti_runner/model_container_runner.cpp | 17 ++++++++++++++- 2 files changed, 35 insertions(+), 3 deletions(-) diff --git a/torch/csrc/distributed/c10d/control_plane/WorkerServer.cpp b/torch/csrc/distributed/c10d/control_plane/WorkerServer.cpp index c4d6b2f61fcbbc..047459b965589a 100644 --- a/torch/csrc/distributed/c10d/control_plane/WorkerServer.cpp +++ b/torch/csrc/distributed/c10d/control_plane/WorkerServer.cpp @@ -1,4 +1,3 @@ -#include #include #include @@ -8,6 +7,15 @@ #include #include +// NS: TODO: Use `std::filesystem` regardless of OS when it's possible +// to use it without leaking symbols on PRECXX11 ABI Linux OSes +// See https://github.com/pytorch/pytorch/issues/133437 for more details +#ifdef _WIN32 +#include +#else +#include +#endif + namespace c10d::control_plane { namespace { @@ -70,6 +78,15 @@ std::string jsonStrEscape(const std::string& str) { } return ostream.str(); } + +bool file_exists(const std::string& path) { +#ifdef _WIN32 + return std::filesystem::exists(path); +#else + struct stat rc; + return lstat(path.c_str(), &rc) == 0; +#endif +} } // namespace WorkerServer::WorkerServer(const std::string& hostOrFile, int port) { @@ -144,7 +161,7 @@ WorkerServer::WorkerServer(const std::string& hostOrFile, int port) { // using unix sockets server_.set_address_family(AF_UNIX); - if (std::filesystem::exists(hostOrFile)) { + if (file_exists(hostOrFile)) { throw std::runtime_error(fmt::format("{} already exists", hostOrFile)); } diff --git a/torch/csrc/inductor/aoti_runner/model_container_runner.cpp b/torch/csrc/inductor/aoti_runner/model_container_runner.cpp index 4941e4e3210a45..c7c09ac3ea43d7 100644 --- a/torch/csrc/inductor/aoti_runner/model_container_runner.cpp +++ b/torch/csrc/inductor/aoti_runner/model_container_runner.cpp @@ -14,6 +14,21 @@ namespace fs = std::filesystem; namespace fs = std::experimental::filesystem; #endif +#ifndef _WIN32 +#include +#endif + +namespace { +bool file_exists(std::string& path) { +#ifdef _WIN32 + return fs::exists(path); +#else + struct stat rc; + return lstat(path.c_str(), &rc) == 0; +#endif +} +} // namespace + namespace torch::inductor { AOTIModelContainerRunner::AOTIModelContainerRunner( @@ -60,7 +75,7 @@ AOTIModelContainerRunner::AOTIModelContainerRunner( size_t lastindex = model_so_path.find_last_of("."); std::string json_filename = model_so_path.substr(0, lastindex) + ".json"; - if (fs::exists(json_filename)) { + if (file_exists(json_filename)) { proxy_executor_ = std::make_unique( json_filename, device_str == "cpu"); proxy_executor_handle_ = From 9a8fc2477e73b5400826ad9d9e93555e4abe36c1 Mon Sep 17 00:00:00 2001 From: Mikayla Gawarecki Date: Mon, 26 Aug 2024 19:19:32 -0700 Subject: [PATCH 0078/1018] Make torch.serialization.set_default_mmap_options usable as a context manager (#134371) As title Pull Request resolved: https://github.com/pytorch/pytorch/pull/134371 Approved by: https://github.com/albanD --- test/test_serialization.py | 26 ++++++++++++++++++++------ torch/serialization.py | 38 +++++++++++++++++++++++++------------- 2 files changed, 45 insertions(+), 19 deletions(-) diff --git a/test/test_serialization.py b/test/test_serialization.py index f663188a959da8..a041473d195b19 100644 --- a/test/test_serialization.py +++ b/test/test_serialization.py @@ -4052,7 +4052,7 @@ def test_serialization_warning_s390x(self): @parametrize('path_type', (str, Path)) @parametrize('weights_only', (True, False)) @unittest.skipIf(IS_WINDOWS, "NamedTemporaryFile on windows") - def test_serialization_mmap_loading(self, weights_only, path_type): + def test_serialization_mmap_loading_options(self, weights_only, path_type): class DummyModel(torch.nn.Module): def __init__(self) -> None: super().__init__() @@ -4101,7 +4101,7 @@ def forward(self, input): for v in result.values(): self.assertTrue(v.is_cuda) - def test_serialization_mmap_loading_options(self): + def test_serialization_mmap_loading(self): if IS_WINDOWS: with self.assertRaisesRegex(RuntimeError, "Changing the default mmap options is currently not supported"): torch.serialization.set_default_mmap_options(2) @@ -4111,22 +4111,36 @@ def test_serialization_mmap_loading_options(self): with tempfile.NamedTemporaryFile() as f: torch.save(sd, f) # with MmapVisibility.MAP_PRIVATE, should not be able to modify file - sd_loaded = torch.load(f.name, mmap=True) + sd_loaded = torch.load(f.name, mmap=True, weights_only=True) sd_loaded['weight'][0][0] = 0 - sd_loaded2 = torch.load(f.name, mmap=True) + sd_loaded2 = torch.load(f.name, mmap=True, weights_only=True) self.assertEqual(sd_loaded2['weight'], sd['weight']) # with MmapVisibility.MAP_SHARED, should be able to modify file torch.serialization.set_default_mmap_options(MAP_SHARED) try: - sd_loaded = torch.load(f.name, mmap=True) + sd_loaded = torch.load(f.name, mmap=True, weights_only=True) sd_loaded['weight'][0][0] = 0 - sd_loaded2 = torch.load(f.name, mmap=True) + sd_loaded2 = torch.load(f.name, mmap=True, weights_only=True) self.assertNotEqual(sd_loaded2['weight'], sd['weight']) self.assertEqual(sd_loaded2['weight'][0][0].item(), 0) self.assertEqual(sd_loaded2['weight'], sd_loaded['weight']) finally: torch.serialization.set_default_mmap_options(MAP_PRIVATE) + @unittest.skipIf(IS_WINDOWS, "mmap ctx doesn't work on Windows") + def test_serialization_mmap_loading_ctx(self): + sd = torch.nn.Linear(3, 5).state_dict() + with tempfile.NamedTemporaryFile() as f: + torch.save(sd, f) + with torch.serialization.set_default_mmap_options(MAP_SHARED): + sd_loaded = torch.load(f.name, mmap=True, weights_only=True) + sd_loaded['weight'][0][0] = 0 + sd_loaded2 = torch.load(f.name, mmap=True, weights_only=True) + self.assertNotEqual(sd_loaded2['weight'], sd['weight']) + self.assertEqual(sd_loaded2['weight'][0][0].item(), 0) + self.assertEqual(sd_loaded2['weight'], sd_loaded['weight']) + self.assertTrue(torch.serialization.get_default_mmap_options() == MAP_PRIVATE) + @parametrize('dtype', (torch.float8_e5m2, torch.float8_e4m3fn, torch.complex32)) @parametrize('weights_only', (True, False)) def test_serialization_dtype(self, dtype, weights_only): diff --git a/torch/serialization.py b/torch/serialization.py index 36e5c035ff1516..2ac36ec371fe57 100644 --- a/torch/serialization.py +++ b/torch/serialization.py @@ -54,6 +54,8 @@ "LoadEndianness", "get_default_load_endianness", "set_default_load_endianness", + "get_default_mmap_options", + "set_default_mmap_options", "clear_safe_globals", "get_safe_globals", "add_safe_globals", @@ -163,9 +165,9 @@ def get_default_mmap_options() -> int: return _default_mmap_options -def set_default_mmap_options(flags: int): +class set_default_mmap_options: """ - Set default mmap options for :func:`torch.load` with ``mmap=True`` to flags. + Context manager or function to set default mmap options for :func:`torch.load` with ``mmap=True`` to flags. For now, only either ``mmap.MAP_PRIVATE`` or ``mmap.MAP_SHARED`` are supported. Please open an issue if you need any other option to be added here. @@ -176,17 +178,27 @@ def set_default_mmap_options(flags: int): Args: flags: ``mmap.MAP_PRIVATE`` or ``mmap.MAP_SHARED`` """ - global _default_mmap_options - if IS_WINDOWS: - raise RuntimeError( - "Changing the default mmap options is currently not supported for Windows" - ) - if flags != MAP_PRIVATE and flags != MAP_SHARED: - raise ValueError( - "Invalid argument in function set_default_mmap_options, " - f"expected mmap.MAP_PRIVATE or mmap.MAP_SHARED, but got {flags}" - ) - _default_mmap_options = flags + + def __init__(self, flags: int) -> None: + if IS_WINDOWS: + raise RuntimeError( + "Changing the default mmap options is currently not supported for Windows" + ) + if flags != MAP_PRIVATE and flags != MAP_SHARED: + raise ValueError( + "Invalid argument in function set_default_mmap_options, " + f"expected mmap.MAP_PRIVATE or mmap.MAP_SHARED, but got {flags}" + ) + global _default_mmap_options + self.prev = _default_mmap_options + _default_mmap_options = flags + + def __enter__(self) -> None: + pass + + def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None: + global _default_mmap_options + _default_mmap_options = self.prev def clear_safe_globals() -> None: From c69e50341c7aa596c6b8e58fa8a6c155ecab24c1 Mon Sep 17 00:00:00 2001 From: Mikayla Gawarecki Date: Mon, 26 Aug 2024 14:28:18 -0700 Subject: [PATCH 0079/1018] Clarify error messages for NEWOBJ and BUILD in weights_only unpickler (#134346) Clarify that `add_safe_globals` will allow types for these instructions Some types do not appear as `GLOBAL` and are only caught in `BUILD`, example from hf slack is `numpy.dtypes.UInt32DType` ```python import torch import numpy as np from tempfile import TemporaryDirectory from pathlib import Path from codecs import encode torch.serialization.add_safe_globals([encode, np.dtype, np.core.multiarray._reconstruct, np.ndarray]) with TemporaryDirectory() as tempdir: p = Path(tempdir) r2 = np.random.get_state() torch.save(r2, p / "r2.pkl") torch.load(p / "r2.pkl", weights_only=True) ``` Yields (error comes from BUILD) ``` UnpicklingError: Weights only load failed. Re-running `torch.load` with `weights_only` set to `False` will likely succeed, but it can result in arbitrary code execution. Do it only if you got the file from a trusted source. Please file an issue with the following so that we can make `weights_only=True` compatible with your use case: WeightsUnpickler error: Can only build Tensor, parameter or OrderedDict objects, but got ``` The reasoning is that `numpy.dtypes.UInt32DType` is constructed via `REDUCE` with `func =` and `args= ('u4', False, True)`, clarify the error message that doing `add_safe_globals` on these will also allow them After this PR error message becomes ``` _pickle.UnpicklingError: Weights only load failed. Re-running `torch.load` with `weights_only` set to `False` will likely succeed, but it can result in arbitrary code execution. Do it only if you got the file from a trusted source. Please file an issue with the following so that we can make `weights_only=True` compatible with your use case: WeightsUnpickler error: Can only build Tensor, Parameter, OrderedDict or types allowlisted via `add_safe_globals`, but got ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/134346 Approved by: https://github.com/albanD --- torch/_weights_only_unpickler.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/torch/_weights_only_unpickler.py b/torch/_weights_only_unpickler.py index c7fff5b98c7b45..063b57d859e75d 100644 --- a/torch/_weights_only_unpickler.py +++ b/torch/_weights_only_unpickler.py @@ -261,7 +261,8 @@ def load(self): self.append(cls.__new__(cls, *args)) else: raise UnpicklingError( - f"Trying to instantiate unsupported class {cls}" + "Can only create new object for nn.Parameter or classes allowlisted " + f"via `add_safe_globals` but got {cls}" ) elif key[0] == REDUCE[0]: args = self.stack.pop() @@ -291,7 +292,8 @@ def load(self): inst.__dict__.update(state) else: raise UnpicklingError( - f"Can only build Tensor, parameter or OrderedDict objects, but got {type(inst)}" + "Can only build Tensor, Parameter, OrderedDict or types allowlisted " + f"via `add_safe_globals`, but got {type(inst)}" ) # Stack manipulation elif key[0] == APPEND[0]: From 8462340be7f23f498176693a447b6b0025e03845 Mon Sep 17 00:00:00 2001 From: Mikayla Gawarecki Date: Mon, 26 Aug 2024 21:53:44 +0000 Subject: [PATCH 0080/1018] Fix flaky GroupNorm ModuleInfo test (#133899) Fixes https://github.com/pytorch/pytorch/issues/98677 Pull Request resolved: https://github.com/pytorch/pytorch/pull/133899 Approved by: https://github.com/albanD --- torch/testing/_internal/common_modules.py | 8 -------- 1 file changed, 8 deletions(-) diff --git a/torch/testing/_internal/common_modules.py b/torch/testing/_internal/common_modules.py index bdd0375850b438..43f1059eb91b36 100644 --- a/torch/testing/_internal/common_modules.py +++ b/torch/testing/_internal/common_modules.py @@ -1718,14 +1718,6 @@ def module_inputs_torch_nn_GroupNorm(module_info, device, dtype, requires_grad, constructor_input=FunctionInput(3, 6, 1e-3), forward_input=FunctionInput(make_input((4, 6, 2, 3))), desc='2d_affine'), - ModuleInput( - constructor_input=FunctionInput(3, 6, 1e-3), - forward_input=FunctionInput(make_input((4, 6, 28, 28))), - desc='2d_affine_large_feature'), - ModuleInput( - constructor_input=FunctionInput(3, 51, 1e-5, False), - forward_input=FunctionInput(make_input((2, 51, 28, 28))), - desc='2d_no_affine_large_feature'), ModuleInput( constructor_input=FunctionInput(3, 3, 1e-3, False), forward_input=FunctionInput(make_input((4, 3, 2, 3))), From c6246c91fe3b424f5ee36fabf12911d95e653580 Mon Sep 17 00:00:00 2001 From: wz337 Date: Tue, 27 Aug 2024 14:51:19 +0000 Subject: [PATCH 0081/1018] [DeviceMesh] Add get_all_submeshes in _MeshEnv (#134275) Adding a private helper method for Shampoo HSDP use cases. Pull Request resolved: https://github.com/pytorch/pytorch/pull/134275 Approved by: https://github.com/XilunWu --- test/distributed/test_device_mesh.py | 11 +++++++++++ torch/distributed/device_mesh.py | 29 ++++++++++++++++++++++++++++ 2 files changed, 40 insertions(+) diff --git a/test/distributed/test_device_mesh.py b/test/distributed/test_device_mesh.py index 09db4ac06dab35..972c0ae89028be 100644 --- a/test/distributed/test_device_mesh.py +++ b/test/distributed/test_device_mesh.py @@ -692,6 +692,17 @@ def test_get_mesh_dim_by_name(self): self.assertEqual(_mesh_resources.get_mesh_dim_by_name(mesh_2d, "DP"), 0) self.assertEqual(_mesh_resources.get_mesh_dim_by_name(mesh_2d, "TP"), 1) + @with_comms + def test_get_all_submeshes(self): + mesh_2d = init_device_mesh( + self.device_type, (2, 4), mesh_dim_names=("replicate", "shard") + ) + all_submeshes = _mesh_resources._get_all_submeshes(mesh_2d, "replicate") + self.assertEqual(len(all_submeshes), 4) + self.assertEqual( + all(submesh.mesh.numel() == 2 for submesh in all_submeshes), True + ) + class DeviceMeshCollectiveTest(DTensorTestBase): @property diff --git a/torch/distributed/device_mesh.py b/torch/distributed/device_mesh.py index 14aa38be42b34c..0ad17536ef1832 100644 --- a/torch/distributed/device_mesh.py +++ b/torch/distributed/device_mesh.py @@ -335,6 +335,35 @@ def _get_slice_mesh_dims( return slice_mesh_dims + def _get_all_submeshes( + self, device_mesh: "DeviceMesh", mesh_dim_name: str + ) -> List["DeviceMesh"]: + """ + Return all the submeshes of a given mesh dimension of the device mesh. + """ + mesh_dim = self.get_mesh_dim_by_name(device_mesh, mesh_dim_name) + pg_ranks_by_dim = device_mesh.mesh.swapdims(-1, mesh_dim).reshape( + -1, device_mesh.mesh.size(mesh_dim) + ) + + cur_rank = device_mesh.get_rank() + res_submeshes = [] + for mesh_1d in pg_ranks_by_dim: + submesh = DeviceMesh( + device_mesh.device_type, + mesh_1d, + mesh_dim_names=(mesh_dim_name,), + _init_backend=False, + ) + submesh._dim_group_infos = ( + [device_mesh._dim_group_infos[mesh_dim]] + if cur_rank in mesh_1d + else [] + ) + res_submeshes.append(submesh) + + return res_submeshes + _mesh_resources: _MeshEnv = _MeshEnv() def _get_device_handle(device_type: str = "cuda"): From 78c2b9e92c3b3b1a2097ddaafb19371c12ceec50 Mon Sep 17 00:00:00 2001 From: rzou Date: Mon, 26 Aug 2024 17:57:45 -0700 Subject: [PATCH 0082/1018] Update partitioner's is_fusible heuristic to respect triton kernels (#134491) mutated arguments to triton kernels are fusible into the triton kernel. Test Plan: - new test Pull Request resolved: https://github.com/pytorch/pytorch/pull/134491 Approved by: https://github.com/Chillee ghstack dependencies: #134364, #134466, #134490 --- test/inductor/test_inplacing_pass.py | 44 +++++++++++++++++++++++--- test/inductor/test_perf.py | 47 ++++++++++++++++++++++++++++ torch/_functorch/partitioners.py | 12 +++++++ 3 files changed, 99 insertions(+), 4 deletions(-) diff --git a/test/inductor/test_inplacing_pass.py b/test/inductor/test_inplacing_pass.py index fef78d8cafb2c8..bce4fa329fe497 100644 --- a/test/inductor/test_inplacing_pass.py +++ b/test/inductor/test_inplacing_pass.py @@ -15,7 +15,7 @@ parametrize, subtest, ) -from torch.testing._internal.inductor_utils import HAS_CUDA +from torch.testing._internal.inductor_utils import HAS_CUDA, HAS_GPU from torch.testing._internal.logging_utils import logs_to_string @@ -41,6 +41,35 @@ def sin_cos(x: torch.Tensor, out_sin: torch.Tensor, out_cos: torch.Tensor) -> No out_cos.copy_(x.cos()) +if HAS_GPU: + import triton + import triton.language as tl + + @triton.jit + def sin_kernel( + in_ptr0, + out_ptr, + n_elements, + BLOCK_SIZE: "tl.constexpr", + ): + pid = tl.program_id(axis=0) + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_elements + x = tl.load(in_ptr0 + offsets, mask=mask) + output = tl.sin(x) + tl.store(out_ptr + offsets, output, mask=mask) + + def sin_triton(x, out): + n_elements = x.numel() + sin_kernel[(n_elements,)](x, out, n_elements, BLOCK_SIZE=4) + +else: + + def sin_triton(x, out): + return + + class TestReinplacingPassCorrectness(InductorTestCase): def setUp(self): counters.clear() @@ -194,12 +223,19 @@ def f(b): subtest(torch.empty_like, name="empty_like"), ], ) - def test_partitioner_recomputes_factory(self, factory_op): + @parametrize( + "sin_op", + [ + subtest(sin, name="sin_op"), + subtest(sin_triton, name="sin_triton"), + ], + ) + def test_partitioner_recomputes_factory(self, factory_op, sin_op): class MySin(torch.autograd.Function): @staticmethod def forward(ctx, x): out = factory_op(x) - sin(x, out) + sin_op(x, out) ctx.save_for_backward(out) return out @@ -207,7 +243,7 @@ def forward(ctx, x): def backward(ctx, grad): (saved,) = ctx.saved_tensors out = factory_op(grad) - sin(saved, out) + sin_op(saved, out) return out @torch.compile(backend="inductor") diff --git a/test/inductor/test_perf.py b/test/inductor/test_perf.py index 83a61df4167b58..14694e54161d10 100644 --- a/test/inductor/test_perf.py +++ b/test/inductor/test_perf.py @@ -32,6 +32,9 @@ if HAS_CUDA: + import triton + import triton.language as tl + from torch.testing._internal.triton_utils import add_kernel aten = torch.ops.aten @@ -858,6 +861,50 @@ def f(a, b): inp = (T(10, 10), TI(2, mx=5)) self.assertExpectedInline(count_numel(f, *inp), """42""") + @requires_cuda + def test_inplace_triton_kernel_training(self): + @triton.jit + def sin_kernel( + in_ptr0, + out_ptr, + n_elements, + BLOCK_SIZE: "tl.constexpr", + ): + pid = tl.program_id(axis=0) + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_elements + x = tl.load(in_ptr0 + offsets, mask=mask) + output = tl.sin(x) + tl.store(out_ptr + offsets, output, mask=mask) + + def sin_triton(x, out): + n_elements = x.numel() + sin_kernel[(n_elements,)](x, out, n_elements, BLOCK_SIZE=4) + + factory_op = torch.empty_like + + class MySin(torch.autograd.Function): + @staticmethod + def forward(ctx, x): + out = factory_op(x) + sin_triton(x, out) + ctx.save_for_backward(out) + return out + + @staticmethod + def backward(ctx, grad): + (saved,) = ctx.saved_tensors + out = factory_op(grad) + sin_triton(saved, out) + return out + + def f(x): + return MySin.apply(x) + + x = T(3, grad=True) + self.assertExpectedInline(count_numel_train(f, x), """9""") + @requires_cuda def test_inplace_custom_op_training(self): @torch.library.custom_op("_reinplacing::sin", mutates_args={"result"}) diff --git a/torch/_functorch/partitioners.py b/torch/_functorch/partitioners.py index 42b14a66dca85c..df6ca4e0e998a0 100644 --- a/torch/_functorch/partitioners.py +++ b/torch/_functorch/partitioners.py @@ -823,6 +823,16 @@ def can_fuse_into_auto_functionalized(a, b): return True return False + def can_fuse_into_triton_kernel_wrapper_functional(a, b): + if b.target != torch.ops.higher_order.triton_kernel_wrapper_functional: + return False + mutable_arg_names = b.kwargs["tensors_to_clone"] + for name in mutable_arg_names: + arg = b.kwargs["kwargs"][name] + if a is arg: + return True + return False + def is_fusible(a, b): # We can perform "memory fusion" into a cat, but cat cannot be a # producer to a fusion @@ -830,6 +840,8 @@ def is_fusible(a, b): return True if can_fuse_into_auto_functionalized(a, b): return True + if can_fuse_into_triton_kernel_wrapper_functional(a, b): + return True return op_types.is_fusible(a) and op_types.is_fusible(b) try: From 281c5d2ff4499c002e853a903ad4ac573174076c Mon Sep 17 00:00:00 2001 From: Ke Wen Date: Mon, 26 Aug 2024 14:29:12 -0700 Subject: [PATCH 0083/1018] [1/N] Move NaN check onto NCCL stream (#134300) So that the tensor's lifetime management is the same as the management built for the NCCL, pre and post kernels. Also so that on visualizers, they show up in the NCCL stream line. Otherwise if they show up in the compute line, user may get confused (my code does not have these kernels). The check is thus moved after the point where we depend NCCL stream from the last compute kernel. Also moved declaration of `checkForNan` from Utils.hpp to NCCLUtils.hpp, and renamed Utils.cu to NCCLUtils.cu. Pull Request resolved: https://github.com/pytorch/pytorch/pull/134300 Approved by: https://github.com/shuqiangzhang, https://github.com/wconstab --- BUILD.bazel | 4 ++-- build_variables.bzl | 2 +- .../distributed/c10d/{Utils.cu => NCCLUtils.cu} | 6 +++--- torch/csrc/distributed/c10d/NCCLUtils.hpp | 6 ++++++ torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp | 14 ++++++++------ torch/csrc/distributed/c10d/Utils.hpp | 2 -- 6 files changed, 20 insertions(+), 14 deletions(-) rename torch/csrc/distributed/c10d/{Utils.cu => NCCLUtils.cu} (86%) diff --git a/BUILD.bazel b/BUILD.bazel index f079c76a7f7204..523f2ad2ce60a4 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -574,7 +574,7 @@ cu_library( name = "torch_cuda", srcs = [ "torch/csrc/distributed/c10d/intra_node_comm.cu", - "torch/csrc/distributed/c10d/Utils.cu", + "torch/csrc/distributed/c10d/NCCLUtils.cu", "torch/csrc/distributed/c10d/quantization/quantization_gpu.cu", ], copts = torch_cuda_half_options, @@ -722,7 +722,7 @@ cc_library( "torch/csrc/distributed/c10d/intra_node_comm.cu", "torch/csrc/distributed/c10d/CUDASymmetricMemory.cu", "torch/csrc/distributed/c10d/CUDASymmetricMemoryOps.cu", - "torch/csrc/distributed/c10d/Utils.cu", + "torch/csrc/distributed/c10d/NCCLUtils.cu", "torch/csrc/distributed/c10d/quantization/quantization_gpu.cu", ], )) + torch_sources, diff --git a/build_variables.bzl b/build_variables.bzl index 98b721617b609c..bf6d6cb14bd364 100644 --- a/build_variables.bzl +++ b/build_variables.bzl @@ -689,7 +689,7 @@ libtorch_cuda_distributed_extra_sources = [ "torch/csrc/distributed/c10d/intra_node_comm.cu", "torch/csrc/distributed/c10d/CUDASymmetricMemory.cu", "torch/csrc/distributed/c10d/CUDASymmetricMemoryOps.cu", - "torch/csrc/distributed/c10d/Utils.cu", + "torch/csrc/distributed/c10d/NCCLUtils.cu", "torch/csrc/distributed/rpc/tensorpipe_cuda.cpp", "torch/csrc/distributed/c10d/quantization/quantization_gpu.cu", ] diff --git a/torch/csrc/distributed/c10d/Utils.cu b/torch/csrc/distributed/c10d/NCCLUtils.cu similarity index 86% rename from torch/csrc/distributed/c10d/Utils.cu rename to torch/csrc/distributed/c10d/NCCLUtils.cu index ae2017efdf8d97..5dcf01fb98b153 100644 --- a/torch/csrc/distributed/c10d/Utils.cu +++ b/torch/csrc/distributed/c10d/NCCLUtils.cu @@ -1,7 +1,7 @@ #include #include #include -#include +#include #include #include @@ -20,7 +20,7 @@ __global__ void checkForNaN(T* data, size_t size) { } // CHECK if a Tensor contains NAN in any of its element -void checkForNan(const at::Tensor& tensor) { +void checkForNan(const at::Tensor& tensor, at::cuda::CUDAStream& stream) { // skip check for non float types if (!torch::is_floating_point(tensor)) { return; @@ -40,7 +40,7 @@ void checkForNan(const at::Tensor& tensor) { tensor.scalar_type(), "checkForNaN", [&] { - checkForNaN<<>>( + checkForNaN<<>>( tensor.data_ptr(), tensor.numel()); C10_CUDA_KERNEL_LAUNCH_CHECK(); }); diff --git a/torch/csrc/distributed/c10d/NCCLUtils.hpp b/torch/csrc/distributed/c10d/NCCLUtils.hpp index 4b0a197dbd536a..c80742d90a4db5 100644 --- a/torch/csrc/distributed/c10d/NCCLUtils.hpp +++ b/torch/csrc/distributed/c10d/NCCLUtils.hpp @@ -10,6 +10,7 @@ #include #include +#include #include #include #include @@ -713,6 +714,11 @@ struct NCCLTraceBuffer { bool includeStackTraces, bool onlyActive); }; + +// Check for NaNs in a tensor on a given stream. If any are found, throw a +// device-side error. +void checkForNan(const at::Tensor& tensor, at::cuda::CUDAStream& stream); + } // namespace c10d #endif // USE_C10D_NCCL diff --git a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp index 3b308a360fb733..be42c74c4814f9 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp +++ b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp @@ -2659,9 +2659,6 @@ c10::intrusive_ptr ProcessGroupNCCL::collective( OpType opType, const char* profilingTitle, bool avoidRecordStreams) { - if (enableNanCheck_) { - checkForNan(input); - } // Environment setting by the user may add onto collective call's option avoidRecordStreams |= avoidRecordStreams_; c10::cuda::CaptureStatus capture_status = @@ -2717,6 +2714,10 @@ c10::intrusive_ptr ProcessGroupNCCL::collective( at::cuda::OptionalCUDAGuard gpuGuard; + if (enableNanCheck_) { + checkForNan(input, ncclStream); + } + // Start event should only be recorded before the ncclGroupStart() if (work->timingEnabled_) { work->ncclStartEvent_->record(ncclStream); @@ -3018,9 +3019,6 @@ c10::intrusive_ptr ProcessGroupNCCL::pointToPoint( PreProcess pre, PostProcess post, const char* profilingTitle) { - if (enableNanCheck_) { - checkForNan(tensor); - } // avoidRecordStreams_ note: // send, recv, and irecv should be ok with avoidRecordStreams, // However, for isend, I don't think the API requires the user @@ -3149,6 +3147,10 @@ c10::intrusive_ptr ProcessGroupNCCL::pointToPoint( // is gpuGuard needed for the if block below, or can i swap them at::cuda::OptionalCUDAGuard gpuGuard; + if (enableNanCheck_) { + checkForNan(tensor, ncclStream); + } + if (!coalescing_state_) { // Start event should only be recorded before the ncclGroupStart() if (work->timingEnabled_) { diff --git a/torch/csrc/distributed/c10d/Utils.hpp b/torch/csrc/distributed/c10d/Utils.hpp index 5c7365393404fc..ea4a4653bc35fc 100644 --- a/torch/csrc/distributed/c10d/Utils.hpp +++ b/torch/csrc/distributed/c10d/Utils.hpp @@ -611,8 +611,6 @@ using SizeType = uint64_t; // Since SOCKET_ERROR = -1 in MSVC, so also leverage SYSCHECK_ERR_RETURN_NEG1 #define SYSCHECK_ERR_RETURN_NEG1(expr) SYSCHECK(expr, __output != -1) -void checkForNan(const at::Tensor& tensor); - namespace tcputil { // Send and receive From c4e80b4ddcaa590c690bf33f9f71acc112497bfc Mon Sep 17 00:00:00 2001 From: "Colin L. Rice" Date: Tue, 27 Aug 2024 16:07:33 +0000 Subject: [PATCH 0084/1018] Create a JustknobConfig for use in config (#134161) This is designed to be a more ergonomic interface on top of justknob_feature (see https://github.com/pytorch/pytorch/pull/134151 for just the PR with the base commits). The idea is that people stop having to think about this as much, and can just do JustkobsConfig("//the:thing", "FORCE_THING") and it'll do the right thing. Primarily sending this to see how people feel about the API, and using it for new config changes. Pull Request resolved: https://github.com/pytorch/pytorch/pull/134161 Approved by: https://github.com/ezyang --- test/test_utils_internal.py | 138 ++++++++++++++++++++++++++++++++++++ torch/_utils_internal.py | 105 +++++++++++++++++++++++++++ 2 files changed, 243 insertions(+) create mode 100644 test/test_utils_internal.py diff --git a/test/test_utils_internal.py b/test/test_utils_internal.py new file mode 100644 index 00000000000000..155afe5f9f6cd2 --- /dev/null +++ b/test/test_utils_internal.py @@ -0,0 +1,138 @@ +# Owner(s): ["module: unknown"] + +import os + +from torch._utils_internal import justknobs_feature, JustKnobsConfig +from torch.testing._internal.common_utils import ( # type: ignore[attr-defined] + load_tests, +) + + +# load_tests from torch.testing._internal.common_utils is used to automatically filter tests for +# sharding on sandcastle. This line silences flake warnings +load_tests = load_tests + +from torch.testing._internal.common_utils import run_tests, TestCase + + +class TestJustKnob(TestCase): + def test_justknob_config(self): + with self.subTest("Returns True"): + a = JustKnobsConfig() + self.assertTrue(a.get()) + with self.subTest("Returns False"): + a = JustKnobsConfig(name="fake_name", default=False) + self.assertFalse(a.get()) + with self.subTest("Returns True via config"): + a = JustKnobsConfig(name="fake_name", default=False) + a.set(True) + self.assertTrue(a.get()) + with self.subTest("Returns True via env"): + os.environ["FAKE_FEATURE"] = "1" + a = JustKnobsConfig( + name="fake_name", env_name="FAKE_FEATURE", default=False + ) + self.assertTrue(a.get()) + with self.subTest("Returns same value consistently"): + a = JustKnobsConfig(name="fake_name", default=False) + a.set(True) + self.assertTrue(a.get()) + a.set(False) + self.assertTrue(a.get()) + + def test_justknob_feature(self): + with self.subTest("OSS is True"): + self.assertTrue(justknobs_feature("testname")) + with self.subTest("OSS default=True"): + self.assertTrue(justknobs_feature("testname", default=True)) + with self.subTest("OSS default=False"): + self.assertFalse(justknobs_feature("testname", default=False)) + with self.subTest("OSS config=True, default=False"): + self.assertTrue( + justknobs_feature("testname", config_value=True, default=False) + ) + with self.subTest("OSS config=None, default=False"): + self.assertFalse( + justknobs_feature("testname", config_value=None, default=False) + ) + with self.subTest("OSS config=False, default=True"): + self.assertFalse( + justknobs_feature("testname", config_value=False, default=True) + ) + with self.subTest("OSS env is missing, config=False, default=True"): + self.assertFalse( + justknobs_feature( + "testname", config_value=False, env_name="NOTDEFINED", default=False + ) + ) + with self.subTest("OSS env is missing, default=False"): + self.assertFalse( + justknobs_feature("testname", env_name="NOTDEFINED", default=False) + ) + with self.subTest( + "OSS config overrides env, config=True, env=False, default=False" + ): + os.environ["FEATURE_ENV"] = "0" + self.assertTrue( + justknobs_feature( + "testname", + config_value=True, + env_name="FEATURE_ENV", + default=False, + ) + ) + with self.subTest("OSS env overrides default, , default=False"): + os.environ["FEATURE_ENV"] = "1" + self.assertTrue( + justknobs_feature("testname", env_name="FEATURE_ENV", default=False) + ) + with self.subTest("OSS env truthy, config=False, default=False"): + os.environ["FEATURE_ENV"] = "1" + self.assertTrue( + justknobs_feature( + "testname", + env_name="FEATURE_ENV", + default=False, + ) + ) + os.environ["FEATURE_ENV"] = "true" + self.assertTrue( + justknobs_feature( + "testname", + env_name="FEATURE_ENV", + default=False, + ) + ) + os.environ["FEATURE_ENV"] = "TRUE" + self.assertTrue( + justknobs_feature( + "testname", + env_name="FEATURE_ENV", + default=False, + ) + ) + os.environ["FEATURE_ENV"] = "very weird true" + self.assertTrue( + justknobs_feature( + "testname", + env_name="FEATURE_ENV", + default=False, + ) + ) + with self.subTest("OSS env false, default=True"): + os.environ["FEATURE_ENV"] = "0" + self.assertFalse( + justknobs_feature("testname", env_name="FEATURE_ENV", default=True) + ) + os.environ["FEATURE_ENV"] = "false" + self.assertFalse( + justknobs_feature("testname", env_name="FEATURE_ENV", default=True) + ) + os.environ["FEATURE_ENV"] = "FALSE" + self.assertFalse( + justknobs_feature("testname", env_name="FEATURE_ENV", default=True) + ) + + +if __name__ == "__main__": + run_tests() diff --git a/torch/_utils_internal.py b/torch/_utils_internal.py index cb6230618bf04b..59b579913ae9c3 100644 --- a/torch/_utils_internal.py +++ b/torch/_utils_internal.py @@ -155,6 +155,111 @@ def capture_pre_autograd_graph_using_training_ir() -> bool: return False +class JustKnobsConfig: + """Represents a lazily loaded config + + This is designed to be used to specify a value in a config. + + i.e. foo.bar = JustknobsConfig(name="//foo:bar", env_name="FORCE_FOO_BAR") + + Call .get() in order to access the value + i.e. if foo.bar.get(): + + Note that the value is fetched once, and then not allowed to change. This + means less suprises, at the downside that you may have to restart a job + to pick up an update. + + It can also be set explicitly via set - i.e. + foo.bar = JustknobsConfig(name="//foo:bar") + foo.bar.set(True) + + Note that this does allow for no JK name (so that you can use this to replace old configurations). + """ + + def __init__( + self, *, name: Optional[str] = None, env_name=None, default: bool = True + ): + self.name = name + self.env_name = env_name + self.default = default + self.value: Optional[bool] = None + self.executed_value = None + + def set(self, value: bool): + self.value = value + + def get(self): + if self.executed_value is None: + self.executed_value = justknobs_feature( + self.name, + config_value=self.value, + env_name=self.env_name, + default=self.default, + ) + return self.executed_value + + def __str__(self): + v = bool(self) + return f"JustknobsConfig(name={self.name}, env_name={self.env_name}, default={self.default} - evals_to={v})" + + +def justknobs_feature( + name: Optional[str], config_value=None, env_name=None, default: bool = True +): + """Returns whether or not a specific justknob feature is enabled. + + This is a slightly higher level API then justknobs_check, designed to make it "easy" to do the right thing. + The primary thing it does, is allow configuration to override JK by default, while retaining some features to force this + the other way during sevs. + + The preference order (i.e. who wins first) in OSS (and FB) is + - Config if specified + - Environment Variable if specified + - JK (FB), or default (OSS) + + + Quickstart + Have a config variable + Make a JK which is set to your "enabled" value (generally true). + Use this feature to check it (if you set the JK to be false, change the default). + If you have an env variable, also use the function to check it. + + Arguments: + name - This should correspond 1:1 to a JK name internally to FB. + env_name - If this is set, we'll try and read the value from environment variables + config_value - If this is set to anything other than None, we'll use this value by + default. Note that within FB, there is some functionality to force override these + configs + default - This is the value to return in OSS. This avoids having to write weird double + negatives within justknobs and the config code, if you just want to have the + killswitch work by having feature return True to turn off features + + Requirements: + Don't use this at import time - Simply pass in the existing config + """ + if config_value is not None: + return config_value + if env_name is not None and ((env := os.getenv(env_name)) is not None): + env = env.upper() + if env in ("1", "TRUE"): + return True + if env in ("0", "FALSE"): + return False + log.error( + "Difficulty parsing env variable %s=%s for feature %s - Assuming env variable means true and returning True", + env_name, + env, + name, + ) + # We could return default here, but that was confusing to log. + return True + if name is None: + return True + if not default: + return not justknobs_check(name) + return justknobs_check(name) + + def justknobs_check(name: str) -> bool: """ This function can be used to killswitch functionality in FB prod, From 12ebcfc70dd3124c70bb60e52dd1f208820254ff Mon Sep 17 00:00:00 2001 From: Ke Wen Date: Mon, 26 Aug 2024 15:03:49 -0700 Subject: [PATCH 0085/1018] [2/N] Add flag to control which rank should perform NaN check (#134345) Fixes https://github.com/pytorch/pytorch/issues/134062. For example, in case of broadcast / scatter, only the root rank should perform the NaN check. Pull Request resolved: https://github.com/pytorch/pytorch/pull/134345 Approved by: https://github.com/shuqiangzhang, https://github.com/wconstab ghstack dependencies: #134300 --- test/distributed/test_c10d_nccl.py | 22 +++++++++++ .../distributed/c10d/ProcessGroupNCCL.cpp | 39 +++++++++++++------ .../distributed/c10d/ProcessGroupNCCL.hpp | 6 ++- 3 files changed, 54 insertions(+), 13 deletions(-) diff --git a/test/distributed/test_c10d_nccl.py b/test/distributed/test_c10d_nccl.py index 64c51f791fddd0..55b785a3552b5a 100644 --- a/test/distributed/test_c10d_nccl.py +++ b/test/distributed/test_c10d_nccl.py @@ -366,6 +366,28 @@ def test_nan_assert(self, type): # reset env os.environ["TORCH_NCCL_NAN_CHECK"] = "0" + @requires_nccl() + @skip_if_lt_x_gpu(2) + def test_nan_p2p(self): + # Putting NaN at recv buffer, program should not fail as NaN checker + # should not check on receive buffer + os.environ["TORCH_NCCL_NAN_CHECK"] = "1" + store = c10d.FileStore(self.file_name, self.world_size) + device = torch.device("cuda:%d" % self.rank) + c10d.init_process_group( + backend="nccl", store=store, rank=self.rank, world_size=self.world_size + ) + t = torch.ones(3, 4, dtype=torch.bfloat16, device=device) + if self.rank == 0: + c10d.send(t, 1) + elif self.rank == 1: + # Putting NaN at recv buffer + t[1, 1] = float("nan") + c10d.recv(t, 0) + c10d.destroy_process_group() + # reset env + os.environ["TORCH_NCCL_NAN_CHECK"] = "0" + @requires_nccl() @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs") def test_destruct_before_terminate_pg(self): diff --git a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp index be42c74c4814f9..94f23d07b6322e 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp +++ b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp @@ -2658,9 +2658,12 @@ c10::intrusive_ptr ProcessGroupNCCL::collective( PostProcess post, OpType opType, const char* profilingTitle, - bool avoidRecordStreams) { + bool avoidRecordStreams, + bool nanCheck) { // Environment setting by the user may add onto collective call's option avoidRecordStreams |= avoidRecordStreams_; + nanCheck &= enableNanCheck_; + c10::cuda::CaptureStatus capture_status = c10::cuda::currentStreamCaptureStatusMayInitCtx(); errorIfCapturingNonCapturableNCCL(capture_status); @@ -2714,7 +2717,7 @@ c10::intrusive_ptr ProcessGroupNCCL::collective( at::cuda::OptionalCUDAGuard gpuGuard; - if (enableNanCheck_) { + if (nanCheck) { checkForNan(input, ncclStream); } @@ -3147,7 +3150,9 @@ c10::intrusive_ptr ProcessGroupNCCL::pointToPoint( // is gpuGuard needed for the if block below, or can i swap them at::cuda::OptionalCUDAGuard gpuGuard; - if (enableNanCheck_) { + // Only check for NaN for send ops, for recv ops `tensor` can be a random + // placeholder + if (enableNanCheck_ && opType == OpType::SEND) { checkForNan(tensor, ncclStream); } @@ -3244,7 +3249,8 @@ c10::intrusive_ptr ProcessGroupNCCL::collective( Fn fn, OpType opType, const char* profilingTitle, - bool avoidRecordStreams) { + bool avoidRecordStreams, + bool nanCheck) { return collective( input, output, @@ -3255,7 +3261,8 @@ c10::intrusive_ptr ProcessGroupNCCL::collective( c10::intrusive_ptr& work) {}, opType, profilingTitle, - avoidRecordStreams); + avoidRecordStreams, + nanCheck); } template @@ -3505,6 +3512,9 @@ c10::intrusive_ptr ProcessGroupNCCL::broadcast( // avoidRecordStreams_ note: collective() will stash tensors. bool avoidRecordStreams = avoidRecordStreams_ || (!opts.asyncOp); + const auto root = opts.rootRank + opts.rootTensor; + bool nanCheck = (root == rank_); + return collective( tensor, tensor, @@ -3512,7 +3522,6 @@ c10::intrusive_ptr ProcessGroupNCCL::broadcast( at::Tensor& output, ncclComm_t comm, at::cuda::CUDAStream& stream) { - const auto root = opts.rootRank + opts.rootTensor; return ncclBcast( input.data_ptr(), input.numel(), @@ -3523,7 +3532,8 @@ c10::intrusive_ptr ProcessGroupNCCL::broadcast( }, OpType::BROADCAST, "nccl:broadcast", - avoidRecordStreams); + avoidRecordStreams, + nanCheck); } // _broadcast_oop adds an out-of-place broadcast in PGNCCL @@ -3543,6 +3553,9 @@ c10::intrusive_ptr ProcessGroupNCCL::_broadcast_oop( "Tensor input and output of _broadcast_oop must have the same number of elements "); } + const auto root = opts.rootRank + opts.rootTensor; + bool nanCheck = (root == rank_); + return collective( inputTensor, outputTensor, @@ -3550,7 +3563,6 @@ c10::intrusive_ptr ProcessGroupNCCL::_broadcast_oop( at::Tensor& output, ncclComm_t comm, at::cuda::CUDAStream& stream) { - const auto root = opts.rootRank + opts.rootTensor; return ncclBroadcast( input.data_ptr(), output.data_ptr(), @@ -3561,7 +3573,9 @@ c10::intrusive_ptr ProcessGroupNCCL::_broadcast_oop( stream.stream()); }, OpType::BROADCAST, - "nccl:_broadcast_oop"); + "nccl:_broadcast_oop", + /*avoidRecordStreams=*/false, + nanCheck); } c10::intrusive_ptr ProcessGroupNCCL::reduce( @@ -4500,6 +4514,9 @@ c10::intrusive_ptr ProcessGroupNCCL::scatter( // inputs, which == inputTensors[0] on the root rank where it matters. bool avoidRecordStreams = avoidRecordStreams_ || (!opts.asyncOp); + const auto root = opts.rootRank; + bool nanCheck = (rank_ == root); + return collective( outputTensor, inputs[0], // just to fit the collective interface @@ -4507,7 +4524,6 @@ c10::intrusive_ptr ProcessGroupNCCL::scatter( at::Tensor& /* unused */, ncclComm_t comm, at::cuda::CUDAStream& stream) { - const auto root = opts.rootRank; if (getRank() == root) { if (!avoidRecordStreams) { for (auto input : inputs) { @@ -4521,7 +4537,8 @@ c10::intrusive_ptr ProcessGroupNCCL::scatter( }, OpType::SCATTER, "nccl:scatter", - avoidRecordStreams); + avoidRecordStreams, + nanCheck); } c10::intrusive_ptr ProcessGroupNCCL::recvAnysource( diff --git a/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp b/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp index 28615055a72bb4..e099a75de58879 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp +++ b/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp @@ -762,7 +762,8 @@ class TORCH_API ProcessGroupNCCL : public Backend { Fn fn, OpType opType, const char* profilingTitle = nullptr, - bool avoidRecordStreams = false); + bool avoidRecordStreams = false, + bool nanCheck = true); template c10::intrusive_ptr collective( @@ -773,7 +774,8 @@ class TORCH_API ProcessGroupNCCL : public Backend { PostProcess post, OpType opType, const char* profilingTitle = nullptr, - bool avoidRecordStreams = false); + bool avoidRecordStreams = false, + bool nanCheck = true); template c10::intrusive_ptr collectiveCoalesced( From 7dde90d50e0582b201464b68f31d43db4f7ba1ba Mon Sep 17 00:00:00 2001 From: Ke Wen Date: Mon, 26 Aug 2024 15:03:49 -0700 Subject: [PATCH 0086/1018] [3/N] Set correct device to CUDA guards (#134357) In `collective()`, `pointToPoint()` and `collectiveCoalesced()`, CUDA guards were created with an unset (default) CUDA device. This is the reason for the IMA facing the NaN checker in issue https://github.com/pytorch/pytorch/issues/134062. With this fix, `torch.cuda.set_device(device)` is not needed to work around the IMA. Also refactored a couple places where the guard is created -- preferably we create the guard with a known device, rather than setting the device later. Pull Request resolved: https://github.com/pytorch/pytorch/pull/134357 Approved by: https://github.com/wconstab, https://github.com/shuqiangzhang ghstack dependencies: #134300, #134345 --- test/distributed/test_c10d_nccl.py | 19 +++++++++++++++++++ .../distributed/c10d/ProcessGroupNCCL.cpp | 18 ++++++++++-------- 2 files changed, 29 insertions(+), 8 deletions(-) diff --git a/test/distributed/test_c10d_nccl.py b/test/distributed/test_c10d_nccl.py index 55b785a3552b5a..3a575729cdff78 100644 --- a/test/distributed/test_c10d_nccl.py +++ b/test/distributed/test_c10d_nccl.py @@ -350,6 +350,7 @@ def test_close_pg(self): @parametrize("type", [torch.float16, torch.float32, torch.float64, torch.bfloat16]) @skip_if_rocm def test_nan_assert(self, type): + # Expecting a device-side error when NaN is detected os.environ["TORCH_NCCL_NAN_CHECK"] = "1" store = c10d.FileStore(self.file_name, self.world_size) pg = self._create_process_group_nccl(store, self.opts()) @@ -388,6 +389,24 @@ def test_nan_p2p(self): # reset env os.environ["TORCH_NCCL_NAN_CHECK"] = "0" + @requires_nccl() + @skip_if_lt_x_gpu(2) + def test_nan_check(self): + # Not expecting an error, NaN check should not make legit code fail + os.environ["TORCH_NCCL_NAN_CHECK"] = "1" + store = c10d.FileStore(self.file_name, self.world_size) + device = torch.device("cuda:%d" % self.rank) + c10d.init_process_group( + backend="nccl", store=store, rank=self.rank, world_size=self.world_size + ) + x = torch.ones((10,), dtype=torch.bfloat16, device=device) * self.rank + t = torch.ones(3, 4, dtype=torch.bfloat16, device=device) + c10d.broadcast(x, src=0) + c10d.all_reduce(t) + c10d.destroy_process_group() + # reset env + os.environ["TORCH_NCCL_NAN_CHECK"] = "0" + @requires_nccl() @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs") def test_destruct_before_terminate_pg(self): diff --git a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp index 94f23d07b6322e..0d3b16eab127ad 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp +++ b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp @@ -1162,6 +1162,10 @@ void ProcessGroupNCCL::abortCommsFromMap( at::cuda::OptionalCUDAGuard gpuGuard; at::DeviceIndex deviceIndex = getIndexFromDeviceKey(devName); if (deviceIndex >= 0) { + // For P2P comms, the deviceIndex could be -1 (invalid), as the keys in + // the map could be non deviceIndex, but rank to rank numbers. So we + // indeed need to check if deviceIndex >= 0 + // TODO: fix `getIndexFromDeviceKey` or fix `DeviceKey` gpuGuard.set_index(deviceIndex); } LOG(INFO) << logPrefix() << "ProcessGroupNCCL destroying ncclComm_ " @@ -2162,7 +2166,9 @@ std::shared_ptr ProcessGroupNCCL::getNCCLComm( bool batchP2P = ncclActiveGroupCounter_ > 0; bool singleP2POp = isP2POp(opType, batchP2P); - at::cuda::OptionalCUDAGuard gpuGuard; + // Get the device index + auto deviceIndex = device.index(); + at::cuda::OptionalCUDAGuard gpuGuard(device); // [Group Start/End Note] This is used to ensure that nccl communicator will // be created before communication primitives are called. Let's look at this @@ -2202,10 +2208,6 @@ std::shared_ptr ProcessGroupNCCL::getNCCLComm( rank = p2pRank; } - // Get the device index - auto deviceIndex = device.index(); - gpuGuard.set_index(deviceIndex); - #ifdef NCCL_HAS_COMM_SPLIT if (options_->split_from) { TORCH_CHECK( @@ -2715,7 +2717,7 @@ c10::intrusive_ptr ProcessGroupNCCL::collective( work->stashed_for_allocator_safety_->push_back(input); } - at::cuda::OptionalCUDAGuard gpuGuard; + at::cuda::OptionalCUDAGuard gpuGuard(device); if (nanCheck) { checkForNan(input, ncclStream); @@ -2880,7 +2882,7 @@ c10::intrusive_ptr ProcessGroupNCCL::collectiveCoalesced( std::make_shared>(inputs); } - at::cuda::OptionalCUDAGuard gpuGuard; + at::cuda::OptionalCUDAGuard gpuGuard(device); // Start event should only be recorded before the ncclGroupStart() (which // happens inside AutoNcclGroup guard below) @@ -3148,7 +3150,7 @@ c10::intrusive_ptr ProcessGroupNCCL::pointToPoint( } // is gpuGuard needed for the if block below, or can i swap them - at::cuda::OptionalCUDAGuard gpuGuard; + at::cuda::OptionalCUDAGuard gpuGuard(device); // Only check for NaN for send ops, for recv ops `tensor` can be a random // placeholder From 586d06bb8ae11c7f2ae69313cd83cd59f1b04ef5 Mon Sep 17 00:00:00 2001 From: Colin Peppler Date: Mon, 26 Aug 2024 23:48:46 +0000 Subject: [PATCH 0087/1018] [aotinductor][UserDefinedTritonKernel] fix case with non-constexpr params declared after autotuned params (#134520) ## Context In some user Triton kernels, we have this set-up for whatever reason. ``` @triton.jit def mykernel( param0, param1, param2, param3: tl.constexpr, # autotuned param4, # non-constexpr ): ... ``` This is an edge case because it's a general practice to declare all constexprs params at the end. And this will be an issue for AOTI because it fails to codegen all 4 params. That will surface as a device-side error: CUDA IMA, invalid argument... ``` > void* kernel_args_var_0[] = {&var_0, &var_1, &var_2}; --- < CUdeviceptr var_3; < AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_get_data_ptr(buf0, reinterpret_cast(&var_3))); < void* kernel_args_var_0[] = {&var_0, &var_1, &var_2, &var_3}; ``` ## Root-cause * `kernel.constexpr` from the Kernel side-table contains the indices for all `constexpr` params that includes autotuned params. * `raw_args`, that gets passed to wrapper codegen, excludes autotuned args. * In the wrapper codegen, we try to find non-constexpr args using `kernel.constexpr` & `raw_args`. This is okay unless there's a `raw_arg` after an autotuned param in the function signature. https://github.com/pytorch/pytorch/blob/79b7fff18820e4f62a02534a961d5930040e3475/torch/_inductor/codegen/cpp_wrapper_cuda.py#L118-L126 ## Fix We try to fix this, by calculating the right constexprs wrt `raw_args`. An illustration ``` raw_args: [arg0, arg1, arg2, arg4] kernel.arg_names: [param0, param1, param2, param3, param4] kernel.constexprs: [3] # param3 is autotuned; this is correct wrt kernel.arg_names constexpr_indices: [] # this is correct wrt raw_args ``` Differential Revision: [D61831625](https://our.internmc.facebook.com/intern/diff/D61831625) Pull Request resolved: https://github.com/pytorch/pytorch/pull/134520 Approved by: https://github.com/oulgen --- test/inductor/test_aot_inductor.py | 22 ++++++++++++++++++++++ torch/_inductor/ir.py | 11 ++++++++++- torch/testing/_internal/triton_utils.py | 25 +++++++++++++++++++++++++ 3 files changed, 57 insertions(+), 1 deletion(-) diff --git a/test/inductor/test_aot_inductor.py b/test/inductor/test_aot_inductor.py index 03c80c074cb5f5..4e19b2e6ed74f9 100644 --- a/test/inductor/test_aot_inductor.py +++ b/test/inductor/test_aot_inductor.py @@ -51,6 +51,7 @@ add_kernel, add_kernel_2d_autotuned, add_kernel_autotuned, + add_kernel_autotuned_weird_param_order, add_kernel_with_optional_param, add_kernel_with_scaling, mul2_inplace_kernel, @@ -2236,6 +2237,27 @@ def forward(self, x, y): dynamic_shapes=dynamic_shapes, ) + def test_triton_kernel_weird_param_order(self): + if self.device != "cuda": + raise unittest.SkipTest("requires CUDA") + + class Model(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + + def forward(self, x): + out = torch.empty_like(x) + add_kernel_autotuned_weird_param_order[16,]( + in_ptr0=x, + in_ptr1=x, + n_elements=x.numel(), + out_ptr=out, + ) + return out + + x = torch.randn(16, 16, device=self.device) + self.check_model(Model(), (x,)) + def test_shifted_constraint_ranges(self): class Model(torch.nn.Module): def __init__(self) -> None: diff --git a/torch/_inductor/ir.py b/torch/_inductor/ir.py index 7d08b62f51b143..ec33f94e70d95a 100644 --- a/torch/_inductor/ir.py +++ b/torch/_inductor/ir.py @@ -5103,10 +5103,19 @@ def codegen(self, wrapper): raw_args = [ self.get_kwargs_value(k) for k in self.ordered_kwargs_for_cpp_kernel ] + + # NOTE: raw_args doesn't include autotuned args. + # But, kernel.constexprs includes indices of autotuned args. + # So, let's recalculate constexpr indices wrt to raw_args. + constexpr_indices = [] + for idx, kwarg in enumerate(self.ordered_kwargs_for_cpp_kernel): + if kernel.arg_names.index(kwarg) in kernel.constexprs: + constexpr_indices.append(idx) + # Call to kernel self.codegen_comment(wrapper) wrapper.generate_user_defined_triton_kernel( - new_name, raw_args, self.grid, configs, triton_meta, kernel.constexprs + new_name, raw_args, self.grid, configs, triton_meta, constexpr_indices ) def get_unbacked_symbol_uses(self) -> OrderedSet[sympy.Symbol]: diff --git a/torch/testing/_internal/triton_utils.py b/torch/testing/_internal/triton_utils.py index 1ac2dfec09d6bd..d3a8065f294047 100644 --- a/torch/testing/_internal/triton_utils.py +++ b/torch/testing/_internal/triton_utils.py @@ -78,6 +78,31 @@ def add_kernel_autotuned( output = x + y tl.store(out_ptr + offsets, output, mask=mask) + @triton.autotune( + configs=[ + triton.Config({"BLOCK_SIZE": 16}, num_stages=2, num_warps=2), + ], + key=[], + ) + @triton.jit + def add_kernel_autotuned_weird_param_order( + in_ptr0, + in_ptr1, + n_elements, + BLOCK_SIZE: "tl.constexpr", + out_ptr, + ): + # out_ptr is after an autotuned param that's declared as tl.constexpr. + # This param ordering can create bugs if not handled correctly. + pid = tl.program_id(axis=0) + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_elements + x = tl.load(in_ptr0 + offsets, mask=mask) + y = tl.load(in_ptr1 + offsets, mask=mask) + output = x + y + tl.store(out_ptr + offsets, output, mask=mask) + @triton.autotune( configs=[ triton.Config( From 1fa44d7313b3b7d0a32f17fc19eed57f28f115ba Mon Sep 17 00:00:00 2001 From: Jessica Vandebon Date: Tue, 27 Aug 2024 17:31:11 +0000 Subject: [PATCH 0088/1018] Integrate device agnostic APIs in FSDP library [1/n] (#134337) Summary: For MTIA FSDP support, we need to ensure the FSDP library code handles accelerator devices not limited to CUDA. Test Plan: CI Reviewed By: hanzlfs Differential Revision: D60587415 Pull Request resolved: https://github.com/pytorch/pytorch/pull/134337 Approved by: https://github.com/LucasLLC, https://github.com/awgu --- torch/distributed/fsdp/_common_utils.py | 14 ++++++++++---- torch/distributed/fsdp/_limiter_utils.py | 8 ++++---- .../fsdp/fully_sharded_data_parallel.py | 2 +- torch/distributed/fsdp/sharded_grad_scaler.py | 1 + 4 files changed, 16 insertions(+), 9 deletions(-) diff --git a/torch/distributed/fsdp/_common_utils.py b/torch/distributed/fsdp/_common_utils.py index d722d5b9825999..396f058c45a1ae 100644 --- a/torch/distributed/fsdp/_common_utils.py +++ b/torch/distributed/fsdp/_common_utils.py @@ -87,13 +87,15 @@ def __init__(self, device: torch.device, backend: Any = None): @classmethod def from_device(cls, device: torch.device) -> "_FSDPDeviceHandle": """ - Return an device handle corresponding to the device, and through this handle, + Return a device handle corresponding to the device, and through this handle, operations with the same semantics as CUDA can be performed on the device. Just return torch.cuda if the device is cuda to make attribute-access faster. Custom backend must first register a module with the same name with {device.type} on torch. """ if device.type == "cuda": return cast(_FSDPDeviceHandle, torch.cuda) + elif device.type == "mtia": + return cast(_FSDPDeviceHandle, torch.mtia) return cls(device) def __getattr__(self, __name: str) -> Any: @@ -142,7 +144,7 @@ def __init__(self) -> None: self._gradient_postdivide_factor: int = 0 self._comm_hook: Optional[Callable] = None self._comm_hook_state: Optional[Any] = None - self._unshard_event: Optional[torch.cuda.Event] = None + self._unshard_event: Optional[torch.Event] = None # Abstract device handle for fsdp compute device. For now, # the compute device must implement cuda semantics used by fsdp self._device_handle: _FSDPDeviceHandle = _UninitializedDeviceHandle() @@ -533,8 +535,12 @@ def forward_post_hook(module, args, output): def _no_dispatch_record_stream(tensor: torch.Tensor, stream: torch.Stream) -> None: - # FIXME record_stream doesn't work with non-cuda tensors - if tensor.device.type not in ["cuda", torch._C._get_privateuse1_backend_name()]: + # FIXME record_stream doesn't work with non-cuda/mtia tensors + if tensor.device.type not in [ + "cuda", + "mtia", + torch._C._get_privateuse1_backend_name(), + ]: return if torch.distributed._functional_collectives.is_torchdynamo_compiling(): diff --git a/torch/distributed/fsdp/_limiter_utils.py b/torch/distributed/fsdp/_limiter_utils.py index efb5b3ba5ae1f2..5cc56b29f84d4a 100644 --- a/torch/distributed/fsdp/_limiter_utils.py +++ b/torch/distributed/fsdp/_limiter_utils.py @@ -12,20 +12,20 @@ class _FreeEventQueue: """ def __init__(self) -> None: - self._queue: Deque[torch.cuda.Event] = collections.deque() + self._queue: Deque[torch.Event] = collections.deque() self._max_num_inflight_all_gathers = 2 # empirically chosen - def enqueue(self, free_event: torch.cuda.Event) -> None: + def enqueue(self, free_event: torch.Event) -> None: """Enqueues a free event.""" self._queue.append(free_event) - def dequeue_if_needed(self) -> Optional[torch.cuda.Event]: + def dequeue_if_needed(self) -> Optional[torch.Event]: """Dequeues a single event if the limit is reached.""" if len(self._queue) >= self._max_num_inflight_all_gathers: return self._dequeue() return None - def _dequeue(self) -> Optional[torch.cuda.Event]: + def _dequeue(self) -> Optional[torch.Event]: """Dequeues a free event if possible.""" if self._queue: event = self._queue.popleft() diff --git a/torch/distributed/fsdp/fully_sharded_data_parallel.py b/torch/distributed/fsdp/fully_sharded_data_parallel.py index 92fa78e13ae9a1..bd8049b50ebf24 100644 --- a/torch/distributed/fsdp/fully_sharded_data_parallel.py +++ b/torch/distributed/fsdp/fully_sharded_data_parallel.py @@ -2044,7 +2044,7 @@ class UnshardHandle: def __init__( self, flat_param_handle: Optional[FlatParamHandle], - unshard_event: torch.cuda.Event, + unshard_event: torch.Event, ): self._flat_param_handle = flat_param_handle self._unshard_event = unshard_event diff --git a/torch/distributed/fsdp/sharded_grad_scaler.py b/torch/distributed/fsdp/sharded_grad_scaler.py index 7c1b2f83528683..c4f1f1dae4e82c 100644 --- a/torch/distributed/fsdp/sharded_grad_scaler.py +++ b/torch/distributed/fsdp/sharded_grad_scaler.py @@ -21,6 +21,7 @@ def _is_supported_device(tensor: torch.Tensor) -> bool: "xla", "cpu", "hpu", + "mtia", torch._C._get_privateuse1_backend_name(), ) From a111fe571b10d103f34b8601bc4c6386f46062bc Mon Sep 17 00:00:00 2001 From: Nikita Shulga Date: Tue, 27 Aug 2024 18:09:20 +0000 Subject: [PATCH 0089/1018] Use newer `toAccumulateType` signature in `Normalization.cpp` (#134540) Which fixes BatchNorm behavior for if called with empty tensors on MPS backed. Removed `expectedFailureMPS` in test_nn.py, deleted expected failure in `test_mps.py` and adjusted `skipIfMPS` to `expectedFailureMPS` in BatchNorm2d OpInfo decorator, but restrict it only to the memory format tests Test Plan: CI + `python3 -c "import torch; print(torch.nn.BatchNorm2d(3, device='mps')(torch.rand(0, 3, 2, 2, device='mps')))"` Fixes https://github.com/pytorch/pytorch/issues/134423 Pull Request resolved: https://github.com/pytorch/pytorch/pull/134540 Approved by: https://github.com/Skylion007, https://github.com/albanD --- aten/src/ATen/native/Normalization.cpp | 2 +- test/test_mps.py | 3 --- test/test_nn.py | 1 - torch/testing/_internal/common_modules.py | 10 +++------- 4 files changed, 4 insertions(+), 12 deletions(-) diff --git a/aten/src/ATen/native/Normalization.cpp b/aten/src/ATen/native/Normalization.cpp index 1fb64c32370509..bedf6eda03cedf 100644 --- a/aten/src/ATen/native/Normalization.cpp +++ b/aten/src/ATen/native/Normalization.cpp @@ -552,7 +552,7 @@ std::tuple _batch_norm_impl_index( if (input.sym_numel() == 0) { Tensor reserve = at::empty({0}, input.options().dtype(kByte)); auto options = input.options().dtype( - at::toAccumulateType(input.scalar_type(), /*is_cuda=*/input.is_cuda())); + at::toAccumulateType(input.scalar_type(), input.device().type())); auto save_mean = at::empty_symint(c10::SymIntArrayRef({num_features}), options); auto save_invstd = at::empty_symint(c10::SymIntArrayRef({std::move(num_features)}), options); diff --git a/test/test_mps.py b/test/test_mps.py index f4aa05177a7929..cd28de9dcae3c1 100644 --- a/test/test_mps.py +++ b/test/test_mps.py @@ -832,9 +832,6 @@ def mps_ops_modifier(ops): # Unsupported dtypes # bmm is not supported for integral types 'nn.functional.bilinear': [torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8], - # Cannot convert a MPS Tensor to float64 dtype. The tensors - # input data is created with double in common_methods_invocations.py - 'nn.functional.batch_norm': [torch.float32], 'ones_like': None, 'zeros_like': None, diff --git a/test/test_nn.py b/test/test_nn.py index 4bda984d736bc6..c9171f5bd653ce 100644 --- a/test/test_nn.py +++ b/test/test_nn.py @@ -8948,7 +8948,6 @@ def check_rnn_grads(rnn1, rnn2): else: self.assertEqual(hx.grad, hx_device.grad) - @expectedFailureMPS # TypeError: float64 https://github.com/pytorch/pytorch/issues/134423 @dtypesIfMPS(torch.float) @dtypes(torch.double) def test_BatchNorm_empty(self, device, dtype): diff --git a/torch/testing/_internal/common_modules.py b/torch/testing/_internal/common_modules.py index 43f1059eb91b36..d3629abde959f5 100644 --- a/torch/testing/_internal/common_modules.py +++ b/torch/testing/_internal/common_modules.py @@ -15,7 +15,7 @@ from torch.testing._internal.common_dtype import ( floating_types, floating_and_complex_types_and, get_all_fp_dtypes) from torch.testing._internal.common_device_type import ( - _TestParametrizer, _update_param_kwargs, toleranceOverride, tol, + _TestParametrizer, _update_param_kwargs, expectedFailureMPS, toleranceOverride, tol, skipCUDAIfCudnnVersionLessThan, skipCUDAIfRocm, precisionOverride, skipMeta, skipMPS, skipMPSVersionIfLessThan, skipCUDAVersionIn) from torch.testing._internal.common_methods_invocations import DecorateInfo @@ -3430,9 +3430,6 @@ def module_error_inputs_torch_nn_Pad3d(module_info, device, dtype, requires_grad train_and_eval_differ=True, module_inputs_func=module_inputs_torch_nn_BatchNorm1d, skips=( - # test fails on MPS backend and is being investigated. - # See https://github.com/pytorch/pytorch/issues/100914 - DecorateInfo(skipMPS), # tracking here rather than in the list in test_aotdispatch.py as eval mode passes # RuntimeError: tried to get Double out of SymInt DecorateInfo( @@ -3451,9 +3448,8 @@ def module_error_inputs_torch_nn_Pad3d(module_info, device, dtype, requires_grad train_and_eval_differ=True, module_inputs_func=module_inputs_torch_nn_BatchNorm2d, skips=( - # test fails on MPS backend and is being investigated. - # See https://github.com/pytorch/pytorch/issues/100914 - DecorateInfo(skipMPS), + # See https://github.com/pytorch/pytorch/issues/134580 + DecorateInfo(expectedFailureMPS, 'TestModule', 'test_memory_format'), # tracking here rather than in the list in test_aotdispatch.py as eval mode passes # RuntimeError: tried to get Double out of SymInt DecorateInfo( From 0e151c583a3872429d523025714bc5c1fe198d9f Mon Sep 17 00:00:00 2001 From: Xilun Wu <12968408+XilunWu@users.noreply.github.com> Date: Mon, 26 Aug 2024 16:19:38 -0700 Subject: [PATCH 0090/1018] [dtensor] add test for local_map decorator (#127752) **Summary** This PR is a follow-up of #126924 to address reviewer's comments: 1) add a test case to show the use of `local_map` as a function decorator. 2) simplify the logic of handling different data types of `out_placements`. 3) correct variable naming in test cases to match math formulas. **Test** see #126924 Pull Request resolved: https://github.com/pytorch/pytorch/pull/127752 Approved by: https://github.com/wanchaol --- .../_tensor/experimental/test_local_map.py | 113 ++++++++++-------- .../_tensor/experimental/func_map.py | 18 +-- 2 files changed, 75 insertions(+), 56 deletions(-) diff --git a/test/distributed/_tensor/experimental/test_local_map.py b/test/distributed/_tensor/experimental/test_local_map.py index b483194d6c3a82..391118b73b0897 100644 --- a/test/distributed/_tensor/experimental/test_local_map.py +++ b/test/distributed/_tensor/experimental/test_local_map.py @@ -1,5 +1,6 @@ # Copyright (c) Meta Platforms, Inc. and affiliates # Owner(s): ["oncall: distributed"] +from functools import partial import torch import torch.distributed._functional_collectives as funcol @@ -22,6 +23,11 @@ funcol_py = torch.ops.c10d_functional +row_wise = [Shard(0)] # row-wise sharding placements on 1-d mesh +col_wise = [Shard(1)] # col-wise sharding placements on 1-d mesh +replicate = [Replicate()] # replicate placements on 1-d mesh + + def equal_allgather_forward(device_mesh, X, Y): eq = torch.tensor([torch.equal(X, Y)], device=X.device) eq_gather = funcol.all_gather_tensor(eq, 0, device_mesh) @@ -42,6 +48,16 @@ def mm_allreduce_forward(device_mesh, A, B): return funcol.all_reduce(partial_sum_tensor, "sum", device_mesh).wait() +@partial( + local_map, + out_placements=replicate, + in_placements=(None, col_wise, row_wise), +) +def mm_allreduce_forward_decorated(device_mesh, A, B): + partial_sum_tensor = torch.mm(A, B) + return funcol.all_reduce(partial_sum_tensor, "sum", device_mesh).wait() + + def mul_forward(X, scalar): # no device mesh needed since we don't do collective return torch.mul(X, scalar) @@ -59,20 +75,19 @@ def test_local_map_correctness(self): ) comm_mode = CommDebugMode() - # Y = W @ X - W = torch.randn(12, 8, device=self.device_type, requires_grad=False) - X = torch.randn(8, 16, device=self.device_type, requires_grad=False) - Y = torch.mm(W, X) + # Y = X @ W + X = torch.randn(16, 8, device=self.device_type, requires_grad=False) + W = torch.randn(8, 12, device=self.device_type, requires_grad=False) + Y = torch.mm(X, W) - row_wise = [Shard(0)] # row-wise sharding placements on 1-d mesh - col_wise = [Shard(1)] # col-wise sharding placements on 1-d mesh - replicate = [Replicate()] - W_dt = distribute_tensor( - W, device_mesh, col_wise - ) # col-wisely sharded W tensor X_dt = distribute_tensor( - X, device_mesh, row_wise - ) # row-wisely sharded X tensor + X, device_mesh, col_wise + ) # col-wisely sharded X tensor + W_dt = distribute_tensor( + W, device_mesh, row_wise + ) # row-wisely sharded W tensor + + # Test 1: use the function returned from calling local_map # get the function wrapped with DTensor/Tensor convertion # mm_allreduce_forward is a function that applies to Tensors with manual collective # local_mm_allreduce_forward is the function that does the same but applies to @@ -84,7 +99,19 @@ def test_local_map_correctness(self): device_mesh=device_mesh, ) with comm_mode: - Y_dt = local_mm_allreduce_forward(device_mesh, W_dt, X_dt) + Y_dt = local_mm_allreduce_forward(device_mesh, X_dt, W_dt) + + # output redistribution to Replicate + self.assertEqual(comm_mode.get_total_counts(), 1) + # check output placements + for placement in Y_dt.placements: + self.assertTrue(placement.is_replicate()) + # check output value + self.assertEqual(Y_dt.to_local(), Y) + + # Test 2: use the local_map decorator + with comm_mode: + Y_dt = mm_allreduce_forward_decorated(device_mesh, X_dt, W_dt) # output redistribution to Replicate self.assertEqual(comm_mode.get_total_counts(), 1) @@ -106,7 +133,6 @@ def test_local_map_out_placements(self): # X.equal(Y) X = torch.randn(8, 8, device=self.device_type, requires_grad=False) Y = torch.randn(8, 8, device=self.device_type, requires_grad=False) - row_wise = [Shard(0)] X_dt = distribute_tensor(X, device_mesh, row_wise) Y_dt = distribute_tensor(Y, device_mesh, row_wise) local_equal_allgather_forward = local_map( @@ -122,7 +148,6 @@ def test_local_map_out_placements(self): # Test 2: directly return out if no argument is DTensor # matmul in DDP - replicate = [Replicate()] X = torch.randn( 4 // self.world_size, 4, device=self.device_type, requires_grad=False ) @@ -151,17 +176,15 @@ def test_local_map_in_placements(self): ) comm_mode = CommDebugMode() - # Y = W @ X - W = torch.randn(12, 8, device=self.device_type, requires_grad=False) - X = torch.randn(8, 16, device=self.device_type, requires_grad=False) - Y = torch.mm(W, X) + # Y = X @ W + X = torch.randn(16, 8, device=self.device_type, requires_grad=False) + W = torch.randn(8, 12, device=self.device_type, requires_grad=False) + Y = torch.mm(X, W) - row_wise = [Shard(0)] # row-wise sharding placements on 1-d mesh - replicate = [Replicate()] # replicate placements on 1-d mesh - W_dt = distribute_tensor( - W, device_mesh, row_wise - ) # row-wisely sharded W tensor - X_dt = distribute_tensor(X, device_mesh, replicate) # replicate X tensor + X_dt = distribute_tensor( + X, device_mesh, row_wise + ) # row-wisely sharded X tensor + W_dt = distribute_tensor(W, device_mesh, replicate) # replicate W tensor # Test 1: explicitly pass `in_placements` local_mm_forward = local_map( @@ -171,7 +194,7 @@ def test_local_map_in_placements(self): device_mesh=device_mesh, ) with comm_mode: - Y_dt = local_mm_forward(W_dt, X_dt) + Y_dt = local_mm_forward(X_dt, W_dt) # no communication should occur in this case self.assertEqual(comm_mode.get_total_counts(), 0) @@ -186,7 +209,7 @@ def test_local_map_in_placements(self): device_mesh=device_mesh, ) with comm_mode: - Y_dt = local_mm_forward(W_dt, X_dt) + Y_dt = local_mm_forward(X_dt, W_dt) self.assertEqual(comm_mode.get_total_counts(), 0) for placement in Y_dt.placements: @@ -194,15 +217,16 @@ def test_local_map_in_placements(self): self.assertEqual(Y_dt.full_tensor(), Y) # Test 3: `None` placements for non-Tensor input argument + # Y = X * 2.0 local_mul_forward = local_map( mul_forward, in_placements=(row_wise, None), out_placements=row_wise, device_mesh=device_mesh, ) - Y = torch.mul(W, 2.0) + Y = torch.mul(X, 2.0) with comm_mode: - Y_dt = local_mul_forward(W_dt, 2.0) + Y_dt = local_mul_forward(X_dt, 2.0) self.assertEqual(comm_mode.get_total_counts(), 0) for placement in Y_dt.placements: @@ -210,12 +234,6 @@ def test_local_map_in_placements(self): self.assertEqual(Y_dt.full_tensor(), Y) # Test 4: `None` placements for Tensor input argument - X = torch.randn(16, 8, device=self.device_type, requires_grad=False) - W = torch.randn(8, 12, device=self.device_type, requires_grad=False) - X_dt = distribute_tensor( - X, device_mesh, row_wise - ) # row-wisely sharded X tensor - W_dt = distribute_tensor(W, device_mesh, replicate) # replicate W tensor local_mm_forward = local_map( mm_forward, out_placements=None, @@ -265,20 +283,17 @@ def test_local_map_redistribute(self): ) comm_mode = CommDebugMode() - # Y = W @ X - W = torch.randn(12, 8, device=self.device_type, requires_grad=False) - X = torch.randn(8, 16, device=self.device_type, requires_grad=False) - Y = torch.mm(W, X) + # Y = X @ W + X = torch.randn(16, 8, device=self.device_type, requires_grad=False) + W = torch.randn(8, 12, device=self.device_type, requires_grad=False) + Y = torch.mm(X, W) - row_wise = [Shard(0)] # row-wise sharding placements on 1-d mesh - col_wise = [Shard(1)] # col-wise sharding placements on 1-d mesh - replicate = [Replicate()] - W_dt = distribute_tensor( - W, device_mesh, row_wise - ) # row-wisely sharded W tensor which will be redistributed X_dt = distribute_tensor( - X, device_mesh, col_wise - ) # col-wisely sharded X tensor which will be redistributed + X, device_mesh, row_wise + ) # row-wisely sharded X tensor which will be redistributed + W_dt = distribute_tensor( + W, device_mesh, col_wise + ) # col-wisely sharded W tensor which will be redistributed # Test 1: allow input redistribution local_mm_allreduce_forward = local_map( @@ -289,7 +304,7 @@ def test_local_map_redistribute(self): redistribute_inputs=True, ) with comm_mode: - Y_dt = local_mm_allreduce_forward(device_mesh, W_dt, X_dt) + Y_dt = local_mm_allreduce_forward(device_mesh, X_dt, W_dt) # 2 for input redistribution and 1 for output self.assertEqual(comm_mode.get_total_counts(), 3) @@ -306,7 +321,7 @@ def test_local_map_redistribute(self): redistribute_inputs=False, ) with self.assertRaisesRegex(ValueError, "set redistribute_inputs=True"): - Y_dt = local_mm_allreduce_forward(device_mesh, W_dt, X_dt) + Y_dt = local_mm_allreduce_forward(device_mesh, X_dt, W_dt) if __name__ == "__main__": diff --git a/torch/distributed/_tensor/experimental/func_map.py b/torch/distributed/_tensor/experimental/func_map.py index 667fbb484fb799..de417f2d8fbfe6 100644 --- a/torch/distributed/_tensor/experimental/func_map.py +++ b/torch/distributed/_tensor/experimental/func_map.py @@ -194,13 +194,17 @@ def wrapped(*args, **kwargs): flat_out, out_spec = pytree.tree_flatten(out) flat_dist_out = [] - for idx, out in enumerate(flat_out): - spec = ( - out_placements[idx] - if isinstance(out_placements, tuple) - else out_placements - ) - + out_placements_tuple = ( + out_placements + if isinstance(out_placements, tuple) + else (out_placements,) + ) + assert len(flat_out) == len(out_placements_tuple), ( + "local_map requires one PlacementType be provided for each output value," + f" received {len(out_placements_tuple)} out_placements but" + f" {len(flat_out)} is expected!" + ) + for out, spec in zip(flat_out, out_placements_tuple): if isinstance(out, torch.Tensor): assert not isinstance( out, DTensor From 345e25ee0e77a0396daf0c4ed482d42bebf29a5c Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Tue, 27 Aug 2024 18:23:43 +0000 Subject: [PATCH 0091/1018] Revert "fix stuck floordiv (#134150)" This reverts commit 92c4771853892193d73d87bd60eca4dc7efc51d8. Reverted https://github.com/pytorch/pytorch/pull/134150 on behalf of https://github.com/anijain2305 due to compile time regression internal ([comment](https://github.com/pytorch/pytorch/pull/134150#issuecomment-2313230404)) --- .../benchmarks/sum_floordiv_benchmark.py | 39 ------------------- torch/utils/_sympy/functions.py | 10 ++--- 2 files changed, 3 insertions(+), 46 deletions(-) delete mode 100644 benchmarks/dynamo/pr_time_benchmarks/benchmarks/sum_floordiv_benchmark.py diff --git a/benchmarks/dynamo/pr_time_benchmarks/benchmarks/sum_floordiv_benchmark.py b/benchmarks/dynamo/pr_time_benchmarks/benchmarks/sum_floordiv_benchmark.py deleted file mode 100644 index 227c0bd911ada3..00000000000000 --- a/benchmarks/dynamo/pr_time_benchmarks/benchmarks/sum_floordiv_benchmark.py +++ /dev/null @@ -1,39 +0,0 @@ -import sys - -from benchmark_base import BenchmarkBase - -import torch - - -class Benchmark(BenchmarkBase): - N = 100 - - def name(self): - return "sum_floordiv_regression" - - def description(self): - return "information at https://github.com/pytorch/pytorch/issues/134133" - - def prepare_once(self): - class M(torch.nn.Module): - def forward(self, x): - total = sum(t.item() for t in x) - return total // 2 - - self.m = M() - self.input = [torch.tensor(i + 2) for i in range(self.N)] - - def prepare(self): - torch._dynamo.reset() - - def work(self): - torch.export.export(self.m, (self.input,)) - - -def main(): - result_path = sys.argv[1] - Benchmark().enable_instruction_count().collect_all().append_results(result_path) - - -if __name__ == "__main__": - main() diff --git a/torch/utils/_sympy/functions.py b/torch/utils/_sympy/functions.py index d54495047e2713..0998b57678610a 100644 --- a/torch/utils/_sympy/functions.py +++ b/torch/utils/_sympy/functions.py @@ -186,8 +186,7 @@ def eval(cls, base, divisor): # Expands (x + y) // b into x // b + y // b. # This only works if floor is an identity, i.e. x / b is an integer. - base_args = sympy.Add.make_args(base) - for term in base_args: + for term in sympy.Add.make_args(base): quotient = term / divisor if quotient.is_integer and isinstance(divisor, sympy.Integer): # NB: this is correct even if the divisor is not an integer, but it @@ -196,11 +195,8 @@ def eval(cls, base, divisor): return FloorDiv(base - term, divisor) + quotient try: - # sympy.gcd tends to blow up on large sums, so use it on each summand instead - gcd, *gcds_ = (sympy.gcd(term, divisor) for term in base_args) - if not equal_valued(gcd, 1) and all( - equal_valued(gcd, gcd_) for gcd_ in gcds_ - ): + gcd = sympy.gcd(base, divisor) + if not equal_valued(gcd, 1): return FloorDiv( sympy.simplify(base / gcd), sympy.simplify(divisor / gcd) ) From 9c975b0bf18ef712983ec796febf26a544ad0b41 Mon Sep 17 00:00:00 2001 From: Xinya Zhang Date: Tue, 27 Aug 2024 18:24:27 +0000 Subject: [PATCH 0092/1018] [ROCM] Properly disable Flash Attention/Efficient Attention with environment variables (#133866) Now `USE_FLASH_ATTENTION=0 USE_MEM_EFF_ATTENTION=0 python setup.py` can compile correctly Fixes #125230 Pull Request resolved: https://github.com/pytorch/pytorch/pull/133866 Approved by: https://github.com/jithunnair-amd, https://github.com/jeffdaily, https://github.com/malfet --- CMakeLists.txt | 10 ++++++++++ aten/src/ATen/native/transformers/cuda/sdp_utils.cpp | 11 +++++++++++ cmake/Dependencies.cmake | 4 ---- 3 files changed, 21 insertions(+), 4 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 70ab320adff605..0318fcb4d1ec04 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -889,6 +889,16 @@ cmake_dependent_option( Will be disabled if not supported by the platform" ON "USE_CUDA OR USE_ROCM" OFF) +# +# Cannot be put into Dependencies.cmake due circular dependency: +# USE_FLASH_ATTENTION -> USE_ROCM -> Dependencies.cmake -> aotriton.cmake +# +if(USE_ROCM) + if(USE_FLASH_ATTENTION OR USE_MEM_EFF_ATTENTION) + include(cmake/External/aotriton.cmake) + endif() +endif() + if(DEBUG_CUDA) string(APPEND CMAKE_CUDA_FLAGS_DEBUG " -lineinfo") string(APPEND CMAKE_CUDA_FLAGS_RELWITHDEBINFO " -lineinfo") diff --git a/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp b/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp index f093878ee3cda3..66178a8d8791bd 100644 --- a/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp +++ b/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp @@ -25,7 +25,10 @@ #include #if USE_ROCM +#if defined(USE_FLASH_ATTENTION) || defined(USE_MEM_EFF_ATTENTION) #include +#define USE_AOTRITON 1 +#endif #endif /** @@ -208,6 +211,7 @@ bool check_flash_attention_hardware_support(sdp_params const& params, bool debug using sm80 = SMVersion<8, 0>; using sm90 = SMVersion<9, 0>; #if USE_ROCM +#if USE_AOTRITON auto stream = at::cuda::getCurrentCUDAStream().stream(); if (hipSuccess != aotriton::v2::flash::check_gpu(stream)) { auto dprops = at::cuda::getCurrentDeviceProperties(); @@ -217,6 +221,9 @@ bool check_flash_attention_hardware_support(sdp_params const& params, bool debug } return false; } +#else + return false; +#endif #else auto dprops = at::cuda::getCurrentDeviceProperties(); if (!check_sm_version(dprops)) { @@ -239,6 +246,7 @@ bool check_mem_efficient_hardware_support(sdp_params const& params, bool debug) using sm50 = SMVersion<5, 0>; using sm90 = SMVersion<9, 0>; #if USE_ROCM +#if USE_AOTRITON auto stream = at::cuda::getCurrentCUDAStream().stream(); if (hipSuccess != aotriton::v2::flash::check_gpu(stream)) { auto dprops = at::cuda::getCurrentDeviceProperties(); @@ -248,6 +256,9 @@ bool check_mem_efficient_hardware_support(sdp_params const& params, bool debug) } return false; } +#else + return false; +#endif #else auto dprops = at::cuda::getCurrentDeviceProperties(); if (!check_sm_version(dprops)) { diff --git a/cmake/Dependencies.cmake b/cmake/Dependencies.cmake index 8abea841fcf61c..21ff5f197ea836 100644 --- a/cmake/Dependencies.cmake +++ b/cmake/Dependencies.cmake @@ -1103,10 +1103,6 @@ if(USE_ROCM) message(STATUS "Disabling Kernel Assert for ROCm") endif() - include(${CMAKE_CURRENT_LIST_DIR}/External/aotriton.cmake) - if(USE_CUDA) - caffe2_update_option(USE_MEM_EFF_ATTENTION OFF) - endif() else() caffe2_update_option(USE_ROCM OFF) endif() From f9764f2dc72935fa26f16c6a017eea7f13990eb8 Mon Sep 17 00:00:00 2001 From: Bin Bao Date: Wed, 21 Aug 2024 13:39:17 -0700 Subject: [PATCH 0093/1018] [AOTI] Introduce DeferredCudaGridLine for cuda cpp wrapper (#129268) Summary: Similar to https://github.com/pytorch/pytorch/pull/129135, use DeferredCudaGridLine to create a deferred grid computation line when generating cpp wrapper. Differential Revision: [D61800622](https://our.internmc.facebook.com/intern/diff/D61800622) Pull Request resolved: https://github.com/pytorch/pytorch/pull/129268 Approved by: https://github.com/angelayi --- torch/_inductor/codegen/cpp_wrapper_cpu.py | 13 +- torch/_inductor/codegen/cpp_wrapper_cuda.py | 191 ++++++++++++++------ torch/_inductor/codegen/wrapper.py | 11 +- 3 files changed, 153 insertions(+), 62 deletions(-) diff --git a/torch/_inductor/codegen/cpp_wrapper_cpu.py b/torch/_inductor/codegen/cpp_wrapper_cpu.py index be69d0429f731b..0ae915cd927d8d 100644 --- a/torch/_inductor/codegen/cpp_wrapper_cpu.py +++ b/torch/_inductor/codegen/cpp_wrapper_cpu.py @@ -71,7 +71,7 @@ def __init__(self): def generate_kernel_call( self, - name, + kernel_name: str, call_args, grid=None, device_index=None, @@ -81,6 +81,7 @@ def generate_kernel_call( raw_args=None, grid_fn: str = "grid", triton_meta=None, + autotune_configs=None, grid_extra_kwargs="", ): """ @@ -94,14 +95,18 @@ def generate_kernel_call( """ if cuda: return super().generate_kernel_call( - name, + kernel_name, call_args, grid, device_index, cuda, triton, arg_types, + raw_args, grid_fn, + triton_meta, + autotune_configs, + grid_extra_kwargs, ) else: if config.abi_compatible: @@ -119,9 +124,9 @@ def generate_kernel_call( else: # arg is a scalar new_args.append(arg) - self.writeline(self.wrap_kernel_call(name, new_args)) + self.writeline(self.wrap_kernel_call(kernel_name, new_args)) else: - self.writeline(self.wrap_kernel_call(name, call_args)) + self.writeline(self.wrap_kernel_call(kernel_name, call_args)) def write_constant(self, name, hashed): # include a hash so our code cache gives different constants different files diff --git a/torch/_inductor/codegen/cpp_wrapper_cuda.py b/torch/_inductor/codegen/cpp_wrapper_cuda.py index d0143844f432ca..cc2baaa72c2c21 100644 --- a/torch/_inductor/codegen/cpp_wrapper_cuda.py +++ b/torch/_inductor/codegen/cpp_wrapper_cuda.py @@ -16,7 +16,7 @@ from ..virtualized import V from .aoti_hipify_utils import maybe_hipify_code_wrapper from .codegen_device_driver import cuda_kernel_driver, cuda_kernel_header -from .cpp_utils import DTYPE_TO_CPP +from .cpp_utils import cexpr, DTYPE_TO_CPP from .cpp_wrapper_cpu import CppWrapperCpu from .wrapper import SymbolicCallArg @@ -61,6 +61,98 @@ def _new_line(self, line): return DeferredCudaKernelLine(self.kernel_name, line, self.keys) +class DeferredCudaDefaultGrid: + """ + A marker to + """ + + def __init__( + self, + kernel_name: str, + grid, + grid_callable: Optional[Callable[..., Any]] = None, + **grid_extra_kwargs, + ): + self.kernel_name = kernel_name + self.grid = grid + self.grid_callable = grid_callable + self.grid_extra_kwargs = grid_extra_kwargs + + def __call__(self): + grid = self.grid + assert isinstance(grid, (list, tuple)), f"expected {grid=} to be a list" + grid = [e.inner_expr if isinstance(e, SymbolicCallArg) else e for e in grid] + grid_callable = self.grid_callable or default_grid + if not self.grid_extra_kwargs: + grid_fn = grid_callable(*grid) + else: + grid_fn = grid_callable(*grid, **self.grid_extra_kwargs) + + params = CudaKernelParamCache.get(self.kernel_name) + assert ( + params is not None + ), f"{self.kernel_name} not found in CudaKernelParamCache" + block_cfg = { + "XBLOCK": params["x_block"], + "YBLOCK": params["y_block"], + "ZBLOCK": params["z_block"], + } + return grid_fn(block_cfg) + + +class DeferredCudaGridLine(DeferredLineBase): + """ + When using cpp wrapper, CUDA kernel load and launch needs to wait for Triton kernels + to be tuned and stored as cubin files, so use a deferred line to backfill those information + """ + + def __init__( + self, + kernel_name: str, + grid_var: str, + grid, + autotune_configs, + ): + super().__init__("") + self.kernel_name = kernel_name + self.grid_var = grid_var + self.grid = grid + self.autotune_configs = autotune_configs + + def __call__(self): + params = CudaKernelParamCache.get(self.kernel_name) + assert ( + params is not None + ), f"{self.kernel_name} not found in CudaKernelParamCache" + + if self.autotune_configs is not None: + # This indicates the Triton kernel is a user-defined one. + grid = None + if len(self.grid) == 1: + grid = self.grid[0] + else: + for i, c in enumerate(self.autotune_configs): + if all(arg == params["meta"][key] for key, arg in c.kwargs.items()): + grid = self.grid[i] + break + assert grid is not None + elif isinstance(self.grid, DeferredCudaDefaultGrid): + grid = self.grid() + else: + grid = self.grid + + assert len(grid) != 0, "Grid can't be empty" + grid_args_str = ", ".join( + [cexpr(V.graph.sizevars.simplify(item)) for item in grid] + ) + return f"Grid {self.grid_var} = Grid({grid_args_str});" + + def _new_line(self, line): + return DeferredCudaGridLine( + self.kernel_name, self.grid_var, self.grid, self.autotune_configs + ) + + class CppWrapperCuda(CppWrapperCpu): """ Generates cpp wrapper for running on GPU and calls CUDA kernels @@ -116,7 +208,13 @@ def generate(self, is_inference): return super().generate(is_inference) def generate_user_defined_triton_kernel( - self, kernel_name, raw_args, grid, configs, triton_meta, constexprs + self, + kernel_name: str, + raw_args: List[Any], + grid: List[Any], + configs, + triton_meta, + constexprs, ): # in C++ wrapper, we don't pass constexpr args, as they don't # get added as parameters to the PTX code compiled from the @@ -124,20 +222,6 @@ def generate_user_defined_triton_kernel( raw_args = [ raw_arg for i, raw_arg in enumerate(raw_args) if i not in constexprs ] - - assert len(grid) != 0 - if len(grid) == 1: - grid_decision = grid[0] - else: - meta = CudaKernelParamCache.get(kernel_name) - assert meta is not None - grid_decision = None - for i, c in enumerate(configs): - if all(arg == meta["meta"][key] for key, arg in c.kwargs.items()): - grid_decision = grid[i] - break - assert grid_decision is not None - args = [self.val_to_arg_str(v) for v in raw_args] arg_types = [ arg.get_dtype() if hasattr(arg, "get_dtype") else type(arg) @@ -147,10 +231,12 @@ def generate_user_defined_triton_kernel( kernel_name, args, arg_types=arg_types, - grid=grid_decision, + raw_args=raw_args, + grid=grid, cuda=True, triton=True, triton_meta=triton_meta, + autotune_configs=configs, ) @functools.lru_cache(None) # noqa: B019 @@ -228,7 +314,7 @@ def generate_args_decl(self, call_args, arg_types): def generate_default_grid( self, - name: str, + kernel_name: str, grid: List[Any], cuda: bool = True, grid_callable: Optional[Callable[..., Any]] = None, @@ -236,31 +322,19 @@ def generate_default_grid( ): """ Generate grid configs for launching a CUDA kernel using the grid - function from triton_heuristics. + function from triton_heuristics. Because its computation needs + to read kernel config after autotune, it is done in a deferred way + using DeferredCudaDefaultGrid. """ if not cuda: return grid - assert isinstance(grid, (list, tuple)), f"expected {grid=} to be a list" - grid = [e.inner_expr if isinstance(e, SymbolicCallArg) else e for e in grid] - grid_callable = grid_callable or default_grid - if not grid_extra_kwargs: - grid_fn = grid_callable(*grid) - else: - grid_fn = grid_callable(*grid, **grid_extra_kwargs) - params = CudaKernelParamCache.get(name) - assert ( - params is not None - ), f"cuda kernel parameters for {name} should already exist at this moment, only found {CudaKernelParamCache.get_keys()}" - block_cfg = { - "XBLOCK": params["x_block"], - "YBLOCK": params["y_block"], - "ZBLOCK": params["z_block"], - } - return grid_fn(block_cfg) + return DeferredCudaDefaultGrid( + kernel_name, grid, grid_callable, **grid_extra_kwargs + ) def generate_kernel_call( self, - kernel_name, + kernel_name: str, call_args, grid=None, device_index=None, @@ -270,6 +344,7 @@ def generate_kernel_call( raw_args=None, grid_fn: str = "grid", triton_meta=None, + autotune_configs=None, grid_extra_kwargs="", ): assert arg_types is not None and len(call_args) == len( @@ -279,7 +354,18 @@ def generate_kernel_call( if not cuda: # Even in CppWrapperCuda, we may see cpp kernels return super().generate_kernel_call( - kernel_name, call_args, grid, device_index, cuda, triton, arg_types + kernel_name, + call_args, + grid, + device_index, + cuda, + triton, + arg_types, + raw_args, + grid_fn, + triton_meta, + autotune_configs, + grid_extra_kwargs, ) device_index, call_args = self.prepare_triton_kernel_call( @@ -307,33 +393,26 @@ def generate_kernel_call( if V.graph.aot_mode else self.write_get_raw_stream(device_index, V.graph) ) - grid_name = f"{kernel_name}_grid_{next(self.grid_id)}" - assert isinstance( - grid, (list, tuple) - ), f"expected grid to be a list or tuple but got: {grid=}" - - grid = [V.graph.sizevars.simplify(item) for item in grid] - grid_uses_symbolic_shapes = any(item.free_symbols for item in grid) - grid_args = [self.expr_printer(item) for item in grid] - grid_args_str = ", ".join(grid_args) - self.writeline(f"Grid {grid_name} = Grid({grid_args_str});") - - if grid_uses_symbolic_shapes: - self.writeline(f"if ({grid_name}.is_non_zero()) {{") + + grid_var = f"{kernel_name}_grid_{next(self.grid_id)}" + self.writeline( + DeferredCudaGridLine(kernel_name, grid_var, grid, autotune_configs) + ) + kernel_var_name = f"kernels.{kernel_name}" if V.graph.aot_mode else kernel_name + self.writeline(f"if ({grid_var}.is_non_zero()) {{") self.writeline( DeferredCudaKernelLine( kernel_name, r"launchKernel({}, {}, {}, {}, %s, %s, {}, {});".format( kernel_var_name, - f"{grid_name}.grid_x", - f"{grid_name}.grid_y", - f"{grid_name}.grid_z", + f"{grid_var}.grid_x", + f"{grid_var}.grid_y", + f"{grid_var}.grid_z", kernel_args_var, stream, ), ("num_warps", "shared_mem"), ), ) - if grid_uses_symbolic_shapes: - self.writeline("}") + self.writeline("}") diff --git a/torch/_inductor/codegen/wrapper.py b/torch/_inductor/codegen/wrapper.py index fd260e9cb09964..09924822023d56 100644 --- a/torch/_inductor/codegen/wrapper.py +++ b/torch/_inductor/codegen/wrapper.py @@ -788,7 +788,13 @@ def generate_extern_kernel_out( self.writeline(f"{kernel}({', '.join(args)})") def generate_user_defined_triton_kernel( - self, kernel_name, raw_args, grid, configs, triton_meta, constexprs + self, + kernel_name: str, + raw_args: List[Any], + grid: List[Any], + configs, + triton_meta, + constexprs, ): grid_fn, code = user_defined_kernel_grid_fn_code( kernel_name, configs, grid, wrapper=self @@ -1541,7 +1547,7 @@ def generate_save_uncompiled_kernels(self): def generate_default_grid( self, - name: str, + kernel_name: str, grid: List[Any], cuda: bool = True, grid_callable: Optional[Callable[..., Any]] = None, @@ -1632,6 +1638,7 @@ def generate_kernel_call( raw_args=None, grid_fn: str = "grid", triton_meta=None, + autotune_configs=None, grid_extra_kwargs="", ): """ From 4842f1e78f295e5c805dd6b50850eff1aae5915d Mon Sep 17 00:00:00 2001 From: atalman Date: Tue, 27 Aug 2024 19:31:44 +0000 Subject: [PATCH 0094/1018] [CD] Fix docker builds by installing setuptools (#134595) Seeing failures like this: ``` #49 844.6 //build_scripts/manylinux1-check.py:6: DeprecationWarning: The distutils package is deprecated and slated for removal in Python 3.12. Use setuptools or check PEP 632 for potential alternatives ..... [python 3/3] RUN bash build_scripts/build.sh && rm -r build_scripts: 846.9 ...it did, yay. 846.9 + for PYTHON in '/opt/python/*/bin/python' 846.9 + /opt/python/cpython-3.12.0/bin/python build_scripts/manylinux1-check.py 847.0 Traceback (most recent call last): 847.0 File "//build_scripts/manylinux1-check.py", line 55, in 847.0 if is_manylinux1_compatible(): 847.0 ^^^^^^^^^^^^^^^^^^^^^^^^^^ 847.0 File "//build_scripts/manylinux1-check.py", line 6, in is_manylinux1_compatible 847.0 from distutils.util import get_platform 847.0 ModuleNotFoundError: No module named 'distutils' ------ ``` PR: https://github.com/pytorch/pytorch/pull/134455 Pull Request resolved: https://github.com/pytorch/pytorch/pull/134595 Approved by: https://github.com/kit1980, https://github.com/seemethere, https://github.com/malfet --- .ci/docker/manywheel/build_scripts/build.sh | 2 ++ .github/workflows/build-manywheel-images.yml | 2 ++ 2 files changed, 4 insertions(+) diff --git a/.ci/docker/manywheel/build_scripts/build.sh b/.ci/docker/manywheel/build_scripts/build.sh index 1708b71a19b5eb..fccdea91737d55 100644 --- a/.ci/docker/manywheel/build_scripts/build.sh +++ b/.ci/docker/manywheel/build_scripts/build.sh @@ -117,6 +117,8 @@ find /opt/_internal \ -print0 | xargs -0 rm -f for PYTHON in /opt/python/*/bin/python; do + # install setuptools since python 3.12 is required to use distutils + $PYTHON -m pip install setuptools==68.2.2 # Smoke test to make sure that our Pythons work, and do indeed detect as # being manylinux compatible: $PYTHON $MY_DIR/manylinux1-check.py diff --git a/.github/workflows/build-manywheel-images.yml b/.github/workflows/build-manywheel-images.yml index 6ee9395cae675f..8e64d612d7bb83 100644 --- a/.github/workflows/build-manywheel-images.yml +++ b/.github/workflows/build-manywheel-images.yml @@ -12,11 +12,13 @@ on: - v[0-9]+.[0-9]+.[0-9]+-rc[0-9]+ paths: - '.ci/docker/manywheel/*' + - '.ci/docker/manywheel/build_scripts/*' - '.ci/docker/common/*' - .github/workflows/build-manywheel-images.yml pull_request: paths: - '.ci/docker/manywheel/*' + - '.ci/docker/manywheel/build_scripts/*' - '.ci/docker/common/*' - .github/workflows/build-manywheel-images.yml From ec30eaf9195cfbf8981b220d9f6e53747bb85ae7 Mon Sep 17 00:00:00 2001 From: Roy Hvaara Date: Tue, 27 Aug 2024 19:37:30 +0000 Subject: [PATCH 0095/1018] [MPS] Remove superfluous label/link (#134090) This was probably intended to be a comment. I removed it since the issue is already linked in the warning below. Pull Request resolved: https://github.com/pytorch/pytorch/pull/134090 Approved by: https://github.com/albanD --- aten/src/ATen/native/mps/operations/Indexing.mm | 1 - 1 file changed, 1 deletion(-) diff --git a/aten/src/ATen/native/mps/operations/Indexing.mm b/aten/src/ATen/native/mps/operations/Indexing.mm index e911281114802c..a13e660b9c857e 100644 --- a/aten/src/ATen/native/mps/operations/Indexing.mm +++ b/aten/src/ATen/native/mps/operations/Indexing.mm @@ -304,7 +304,6 @@ static Tensor nonzero_fallback(const Tensor& self) { if (!is_macos_13_or_newer(MacOSVersion::MACOS_VER_15_0_PLUS) && (self.numel() >= nonZeroMaxSize || self.is_complex())) { - https: // github.com/pytorch/pytorch/issues/122916 TORCH_WARN_ONCE("MPS: nonzero op is not natively supported for the provided input on MacOS14", "Falling back on CPU. This may have performance implications.", "See github.com/pytorch/pytorch/issues/122916 for further info"); From cd17b88ac28f1c0d7d662101307e8c82fa2633f4 Mon Sep 17 00:00:00 2001 From: Xu Han Date: Tue, 27 Aug 2024 19:40:56 +0000 Subject: [PATCH 0096/1018] [inductor] calibration inductor windows uts (14/N) (#134585) skip UT for `test/dynamo/test_exc.py` Pull Request resolved: https://github.com/pytorch/pytorch/pull/134585 Approved by: https://github.com/jansel --- test/dynamo/test_exc.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/test/dynamo/test_exc.py b/test/dynamo/test_exc.py index c729cb1d3f1b2c..11d240924eae37 100644 --- a/test/dynamo/test_exc.py +++ b/test/dynamo/test_exc.py @@ -10,7 +10,12 @@ from torch._dynamo.comptime import comptime from torch._dynamo.exc import Unsupported from torch.testing._internal.common_device_type import skipIf -from torch.testing._internal.common_utils import IS_FBCODE, munge_exc, TEST_Z3 +from torch.testing._internal.common_utils import ( + IS_FBCODE, + munge_exc, + skipIfWindows, + TEST_Z3, +) from torch.testing._internal.logging_utils import LoggingTestCase, make_logging_test @@ -200,6 +205,10 @@ def fn001(x): translation_validation=True, translation_validation_no_bisect=True, ) + @skipIfWindows( + msg='AssertionError: "tran[551 chars]s1 s2 s3) s0)\n ==> (<= (+ s1 s2) (+ s0 (* -1[511 chars][0])' # noqa: PLR0133 + != 'tran[551 chars]s1 s2) (+ s0 (* -1 s3)))\n ==> (<= (+ s1 s2) [483 chars][0])"' + ) def test_trigger_on_error(self): from torch.fx.experimental.validator import ValidationException From fb7ad77f088c8e38fc3dac8f9a30530b85db442b Mon Sep 17 00:00:00 2001 From: Xu Han Date: Tue, 27 Aug 2024 19:41:15 +0000 Subject: [PATCH 0097/1018] [inductor] calibration inductor windows uts (17/N) (#134588) skip UTs for `test/inductor/test_minifier_isolate.py` Pull Request resolved: https://github.com/pytorch/pytorch/pull/134588 Approved by: https://github.com/jansel --- test/inductor/test_minifier_isolate.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/test/inductor/test_minifier_isolate.py b/test/inductor/test_minifier_isolate.py index 34b9c9f3383666..5681498ef9cc6f 100644 --- a/test/inductor/test_minifier_isolate.py +++ b/test/inductor/test_minifier_isolate.py @@ -7,6 +7,7 @@ IS_JETSON, IS_MACOS, skipIfRocm, + skipIfWindows, TEST_WITH_ASAN, ) from torch.testing._internal.inductor_utils import HAS_CUDA @@ -33,6 +34,9 @@ def inner(x): @unittest.skipIf(IS_JETSON, "Fails on Jetson") @inductor_config.patch("cpp.inject_relu_bug_TESTING_ONLY", "runtime_error") + @skipIfWindows( + msg="Build Failed: fatal error C1083: Cannot open include file: 'Python.h': No such file or directory" + ) def test_after_aot_cpu_runtime_error(self): self._test_after_aot_runtime_error("cpu", "") From 3f456505b49e4d14e5c7646af367d1a9c41d359d Mon Sep 17 00:00:00 2001 From: Xu Han Date: Tue, 27 Aug 2024 19:45:02 +0000 Subject: [PATCH 0098/1018] [inductor] calibration inductor windows uts (16/N) (#134587) skip UT for `test/inductor/test_compiled_autograd.py` Pull Request resolved: https://github.com/pytorch/pytorch/pull/134587 Approved by: https://github.com/jansel --- test/inductor/test_compiled_autograd.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/test/inductor/test_compiled_autograd.py b/test/inductor/test_compiled_autograd.py index f1abcb9677f29c..b31c5bea6a2435 100644 --- a/test/inductor/test_compiled_autograd.py +++ b/test/inductor/test_compiled_autograd.py @@ -22,6 +22,7 @@ from torch._dynamo.utils import counters from torch._inductor import config as inductor_config from torch._inductor.test_case import run_tests, TestCase +from torch.testing._internal.common_utils import skipIfWindows from torch.testing._internal.inductor_utils import HAS_CPU, HAS_CUDA from torch.testing._internal.logging_utils import logs_to_string @@ -2341,6 +2342,7 @@ def fn(x, obj): sum(1 for e in expected_logs if e in logs.getvalue()), len(expected_logs) ) + @skipIfWindows(msg="AssertionError: Scalars are not equal!") def test_verbose_logs_cpp(self): torch._logging.set_logs(compiled_autograd_verbose=True) From 760575f81112d8b0eb32d7f7cb1a669c2144dc3b Mon Sep 17 00:00:00 2001 From: Jerry Zhang Date: Tue, 20 Aug 2024 16:10:31 -0700 Subject: [PATCH 0099/1018] Update pt2e numeric debugger to use node.meta["custom"] field (#134040) Summary: With https://github.com/pytorch/pytorch/pull/131912 we now have a "custom" field in node.meta that can be preserved in * copy/deepcopy * run_decompositions() * serialization * re-exporting So we refactored numeric debugger to use this. Test Plan: python test/test_quantization.py TestNumericDebugger Reviewers: Subscribers: Tasks: Tags: Pull Request resolved: https://github.com/pytorch/pytorch/pull/134040 Approved by: https://github.com/tarun292 --- docs/source/quantization-support.rst | 1 + .../pt2e/test_numeric_debugger.py | 40 +++++++++++++++---- torch/ao/quantization/__init__.py | 2 + torch/ao/quantization/fx/convert.py | 15 ++++--- .../ao/quantization/pt2e/_numeric_debugger.py | 20 +++++++--- torch/ao/quantization/pt2e/prepare.py | 12 ++++-- 6 files changed, 69 insertions(+), 21 deletions(-) diff --git a/docs/source/quantization-support.rst b/docs/source/quantization-support.rst index 44597d867b4970..a15e901bfd54c9 100644 --- a/docs/source/quantization-support.rst +++ b/docs/source/quantization-support.rst @@ -154,6 +154,7 @@ PT2 Export (pt2e) Numeric Debugger :template: classtemplate.rst generate_numeric_debug_handle + CUSTOM_KEY NUMERIC_DEBUG_HANDLE_KEY prepare_for_propagation_comparison extract_results_from_loggers diff --git a/test/quantization/pt2e/test_numeric_debugger.py b/test/quantization/pt2e/test_numeric_debugger.py index 2a67f6e23097fd..094193f6cf00b4 100644 --- a/test/quantization/pt2e/test_numeric_debugger.py +++ b/test/quantization/pt2e/test_numeric_debugger.py @@ -9,6 +9,7 @@ from torch._export import capture_pre_autograd_graph from torch.ao.quantization import ( compare_results, + CUSTOM_KEY, extract_results_from_loggers, generate_numeric_debug_handle, NUMERIC_DEBUG_HANDLE_KEY, @@ -27,8 +28,13 @@ def _extract_debug_handles(model) -> Dict[torch.fx.Node, int]: debug_handle_map: Dict[torch.fx.Node, int] = {} for node in model.graph.nodes: - if NUMERIC_DEBUG_HANDLE_KEY in node.meta: - debug_handle_map[str(node)] = node.meta[NUMERIC_DEBUG_HANDLE_KEY] + if ( + CUSTOM_KEY in node.meta + and NUMERIC_DEBUG_HANDLE_KEY in node.meta[CUSTOM_KEY] + ): + debug_handle_map[str(node)] = node.meta[CUSTOM_KEY][ + NUMERIC_DEBUG_HANDLE_KEY + ] return debug_handle_map @@ -47,8 +53,8 @@ def test_simple(self): unique_ids = set() count = 0 for n in m.graph.nodes: - if NUMERIC_DEBUG_HANDLE_KEY in n.meta: - unique_ids.add(n.meta[NUMERIC_DEBUG_HANDLE_KEY]) + if CUSTOM_KEY in n.meta and NUMERIC_DEBUG_HANDLE_KEY in n.meta[CUSTOM_KEY]: + unique_ids.add(n.meta[CUSTOM_KEY][NUMERIC_DEBUG_HANDLE_KEY]) count += 1 self.assertEqual(len(unique_ids), count) @@ -69,9 +75,9 @@ def test_quantize_pt2e_preserve_handle(self): m = prepare_pt2e(m, quantizer) debug_handle_map = _extract_debug_handles(m) res_counter = Counter(debug_handle_map.values()) - repeated_debug_handle_ids = [2, 3, 6] + repeated_debug_handle_ids = [3, 4, 7] # 3 ids were repeated because we copy over the id from node to its output observer - # torch.ops.aten.conv2d.default, torch.ops.aten.squeeze.dim, torch.ops.aten.conv1d.default + # torch.ops.aten.conv2d.default, torch.ops.aten.squeeze.dim and torch.ops.aten.conv1d.default for dh_id in repeated_debug_handle_ids: self.assertEqual(res_counter[dh_id], 2) @@ -81,7 +87,7 @@ def test_quantize_pt2e_preserve_handle(self): res_counter = Counter(debug_handle_map.values()) # same set of ids where repeated, because we copy over the id from observer/fake_quant to # dequantize node - repeated_debug_handle_ids = [2, 3, 6] + repeated_debug_handle_ids = [3, 4, 7] for dh_id in repeated_debug_handle_ids: self.assertEqual(res_counter[dh_id], 2) @@ -123,6 +129,26 @@ def test_re_export_preserve_handle(self): self.assertEqual(debug_handle_map, debug_handle_map_ref) + @unittest.skip( + "All nodes' meta are preserved but the first arg for the first node seems to be dropped" + ) + def test_run_decompositions_preserve_handle(self): + m = TestHelperModules.Conv2dThenConv1d() + example_inputs = m.example_inputs() + m = torch.export.export(m, example_inputs) + generate_numeric_debug_handle(m) + + debug_handle_map_ref = _extract_debug_handles(m) + + m_copy = copy.copy(m) + m_copy = m_copy.run_decompositions() + debug_handle_map = _extract_debug_handles(m_copy) + + # checking the map still has the same ids, the node may change + self.assertEqual( + set(debug_handle_map.values()), set(debug_handle_map_ref.values()) + ) + def test_prepare_for_propagation_comparison(self): m = TestHelperModules.Conv2dThenConv1d() example_inputs = m.example_inputs() diff --git a/torch/ao/quantization/__init__.py b/torch/ao/quantization/__init__.py index adf70f79e20443..840d2715c7717f 100644 --- a/torch/ao/quantization/__init__.py +++ b/torch/ao/quantization/__init__.py @@ -11,6 +11,7 @@ from .observer import * # noqa: F403 from .pt2e._numeric_debugger import ( # noqa: F401 compare_results, + CUSTOM_KEY, extract_results_from_loggers, generate_numeric_debug_handle, NUMERIC_DEBUG_HANDLE_KEY, @@ -162,6 +163,7 @@ "swap_module", "weight_observer_range_neg_127_to_127", "generate_numeric_debug_handle", + "CUSTOM_KEY", "NUMERIC_DEBUG_HANDLE_KEY", "prepare_for_propagation_comparison", "extract_results_from_loggers", diff --git a/torch/ao/quantization/fx/convert.py b/torch/ao/quantization/fx/convert.py index 4502d180cfc655..83487374845330 100644 --- a/torch/ao/quantization/fx/convert.py +++ b/torch/ao/quantization/fx/convert.py @@ -6,7 +6,7 @@ from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union import torch -from torch.ao.quantization import NUMERIC_DEBUG_HANDLE_KEY +from torch.ao.quantization import CUSTOM_KEY, NUMERIC_DEBUG_HANDLE_KEY from torch.ao.quantization.backend_config import ( BackendConfig, get_native_backend_config, @@ -229,10 +229,15 @@ def add_dequantize_op_kwargs(dequantize_op, input_node): node.replace_all_uses_with(dequantized_node) # propagate numeric debug handle from observer/fake_quant node to dequantize node - if NUMERIC_DEBUG_HANDLE_KEY in node.meta: - dequantized_node.meta[NUMERIC_DEBUG_HANDLE_KEY] = node.meta[ - NUMERIC_DEBUG_HANDLE_KEY - ] + if ( + CUSTOM_KEY in node.meta + and NUMERIC_DEBUG_HANDLE_KEY in node.meta[CUSTOM_KEY] + ): + if CUSTOM_KEY not in dequantized_node.meta: + dequantized_node.meta[CUSTOM_KEY] = {} + dequantized_node.meta[CUSTOM_KEY][NUMERIC_DEBUG_HANDLE_KEY] = node.meta[ + CUSTOM_KEY + ][NUMERIC_DEBUG_HANDLE_KEY] graph.erase_node(node) elif is_dynamic: # uint8/int8/fp16 dynamic quantization diff --git a/torch/ao/quantization/pt2e/_numeric_debugger.py b/torch/ao/quantization/pt2e/_numeric_debugger.py index 1bcc442ca441a9..cff3b6d2aa3216 100644 --- a/torch/ao/quantization/pt2e/_numeric_debugger.py +++ b/torch/ao/quantization/pt2e/_numeric_debugger.py @@ -9,7 +9,8 @@ from torch.nn import functional as F -NUMERIC_DEBUG_HANDLE_KEY = "_numeric_debug_handle" +NUMERIC_DEBUG_HANDLE_KEY = "numeric_debug_handle" +CUSTOM_KEY = "custom" log = logging.getLogger(__name__) @@ -20,8 +21,14 @@ def generate_numeric_debug_handle(graph_module: GraphModule) -> None: """ unique_id = 0 for node in graph_module.graph.nodes: - if node.op != "placeholder" and NUMERIC_DEBUG_HANDLE_KEY not in node.meta: - node.meta[NUMERIC_DEBUG_HANDLE_KEY] = unique_id + if node.op in ["output", "placehodler"]: + continue + + if CUSTOM_KEY not in node.meta: + node.meta[CUSTOM_KEY] = {} + + if NUMERIC_DEBUG_HANDLE_KEY not in node.meta[CUSTOM_KEY]: + node.meta[CUSTOM_KEY][NUMERIC_DEBUG_HANDLE_KEY] = unique_id unique_id += 1 @@ -98,9 +105,12 @@ def prepare_for_propagation_comparison(model: GraphModule) -> GraphModule: # don't change the original model model = copy.deepcopy(model) for n in model.graph.nodes: - if NUMERIC_DEBUG_HANDLE_KEY not in n.meta: + if ( + CUSTOM_KEY not in n.meta + or NUMERIC_DEBUG_HANDLE_KEY not in n.meta[CUSTOM_KEY] + ): continue - numeric_debug_handle = n.meta[NUMERIC_DEBUG_HANDLE_KEY] + numeric_debug_handle = n.meta[CUSTOM_KEY][NUMERIC_DEBUG_HANDLE_KEY] _insert_logger(model, n, numeric_debug_handle) model.recompile() diff --git a/torch/ao/quantization/pt2e/prepare.py b/torch/ao/quantization/pt2e/prepare.py index d758870017ca59..88ab2be9ff5a83 100644 --- a/torch/ao/quantization/pt2e/prepare.py +++ b/torch/ao/quantization/pt2e/prepare.py @@ -4,6 +4,7 @@ import torch from torch._subclasses import FakeTensor from torch.ao.quantization import ( + CUSTOM_KEY, NUMERIC_DEBUG_HANDLE_KEY, ObserverOrFakeQuantize, QConfigMapping, @@ -459,11 +460,14 @@ def _maybe_insert_output_observer_for_node( if ( isinstance(node, Node) and isinstance(new_output, Node) - and NUMERIC_DEBUG_HANDLE_KEY in node.meta + and CUSTOM_KEY in node.meta + and NUMERIC_DEBUG_HANDLE_KEY in node.meta[CUSTOM_KEY] ): - new_output.meta[NUMERIC_DEBUG_HANDLE_KEY] = node.meta[ - NUMERIC_DEBUG_HANDLE_KEY - ] + if CUSTOM_KEY not in new_output.meta: + new_output.meta[CUSTOM_KEY] = {} + new_output.meta[CUSTOM_KEY][NUMERIC_DEBUG_HANDLE_KEY] = node.meta[ + CUSTOM_KEY + ][NUMERIC_DEBUG_HANDLE_KEY] return new_output return None From 6ede9795a135ab67383eb4c5661a0a9e94d243df Mon Sep 17 00:00:00 2001 From: Nikita Shulga Date: Tue, 27 Aug 2024 19:52:44 +0000 Subject: [PATCH 0100/1018] [Doc] Fix rendering of the unicode characters (#134597) https://github.com/pytorch/pytorch/pull/124771 introduced unicode escape sequences inside raw strings, which were not rendered correctly. Also fix typo in `\uue0 ` escape sequence (should have been `\u00e0`) Fix it by relying on [string literal concatenation](https://docs.python.org/3/reference/lexical_analysis.html#string-literal-concatenation) to join raw and regular strings together during lexical analysis stage Fixes https://github.com/pytorch/pytorch/issues/134422 Pull Request resolved: https://github.com/pytorch/pytorch/pull/134597 Approved by: https://github.com/aorenste, https://github.com/Skylion007 --- torch/nn/modules/adaptive.py | 4 ++-- torch/nn/modules/conv.py | 36 +++++++++++++++++++++++++----------- torch/nn/modules/fold.py | 8 ++++---- 3 files changed, 31 insertions(+), 17 deletions(-) diff --git a/torch/nn/modules/adaptive.py b/torch/nn/modules/adaptive.py index 2a3b24b9b280bf..cfcc447625ea3b 100644 --- a/torch/nn/modules/adaptive.py +++ b/torch/nn/modules/adaptive.py @@ -18,13 +18,13 @@ class AdaptiveLogSoftmaxWithLoss(Module): - r"""Efficient softmax approximation. + """Efficient softmax approximation. As described in `Efficient softmax approximation for GPUs by Edouard Grave, Armand Joulin, Moustapha Ciss\u00e9, David Grangier, and Herv\u00e9 J\u00e9gou `__. - +""" r""" Adaptive softmax is an approximate strategy for training models with large output spaces. It is most effective when the label distribution is highly imbalanced, for example in natural language modelling, where the word diff --git a/torch/nn/modules/conv.py b/torch/nn/modules/conv.py index ccb628dff6a313..8501d3f79d2df6 100644 --- a/torch/nn/modules/conv.py +++ b/torch/nn/modules/conv.py @@ -241,11 +241,13 @@ class Conv1d(_ConvNd): * :attr:`padding` controls the amount of padding applied to the input. It can be either a string {{'valid', 'same'}} or a tuple of ints giving the amount of implicit padding applied on both sides. - +""" + """ * :attr:`dilation` controls the spacing between the kernel points; also - known as the \uue0 trous algorithm. It is harder to describe, but this `link`_ + known as the \u00e0 trous algorithm. It is harder to describe, but this `link`_ has a nice visualization of what :attr:`dilation` does. - +""" + r""" {groups_note} Note: @@ -404,10 +406,13 @@ class Conv2d(_ConvNd): * :attr:`padding` controls the amount of padding applied to the input. It can be either a string {{'valid', 'same'}} or an int / a tuple of ints giving the amount of implicit padding applied on both sides. - +""" + """ * :attr:`dilation` controls the spacing between the kernel points; also known as the \u00e0 trous algorithm. It is harder to describe, but this `link`_ has a nice visualization of what :attr:`dilation` does. +""" + r""" {groups_note} @@ -574,9 +579,12 @@ class Conv3d(_ConvNd): * :attr:`padding` controls the amount of padding applied to the input. It can be either a string {{'valid', 'same'}} or a tuple of ints giving the amount of implicit padding applied on both sides. - +""" + """ * :attr:`dilation` controls the spacing between the kernel points; also known as the \u00e0 trous algorithm. It is harder to describe, but this `link`_ has a nice visualization of what :attr:`dilation` does. +""" + r""" {groups_note} @@ -834,10 +842,12 @@ class ConvTranspose1d(_ConvTransposeNd): * :attr:`output_padding` controls the additional size added to one side of the output shape. See note below for details. - +""" + """ * :attr:`dilation` controls the spacing between the kernel points; also known as the \u00e0 trous algorithm. It is harder to describe, but the link `here`_ has a nice visualization of what :attr:`dilation` does. - +""" + r""" {groups_note} Note: @@ -996,10 +1006,12 @@ class ConvTranspose2d(_ConvTransposeNd): * :attr:`output_padding` controls the additional size added to one side of the output shape. See note below for details. - +""" + """ * :attr:`dilation` controls the spacing between the kernel points; also known as the \u00e0 trous algorithm. It is harder to describe, but the link `here`_ has a nice visualization of what :attr:`dilation` does. - +""" + r""" {groups_note} The parameters :attr:`kernel_size`, :attr:`stride`, :attr:`padding`, :attr:`output_padding` @@ -1184,10 +1196,12 @@ class ConvTranspose3d(_ConvTransposeNd): * :attr:`output_padding` controls the additional size added to one side of the output shape. See note below for details. - +""" + """ * :attr:`dilation` controls the spacing between the kernel points; also known as the \u00e0 trous algorithm. It is harder to describe, but the link `here`_ has a nice visualization of what :attr:`dilation` does. - +""" + r""" {groups_note} The parameters :attr:`kernel_size`, :attr:`stride`, :attr:`padding`, :attr:`output_padding` diff --git a/torch/nn/modules/fold.py b/torch/nn/modules/fold.py index 54d97d5a58bf10..58397caec32c8c 100644 --- a/torch/nn/modules/fold.py +++ b/torch/nn/modules/fold.py @@ -42,10 +42,10 @@ class Fold(Module): * :attr:`padding` controls the amount of implicit zero-paddings on both sides for :attr:`padding` number of points for each dimension before reshaping. - +""" """ * :attr:`dilation` controls the spacing between the kernel points; also known as the \u00e0 trous algorithm. It is harder to describe, but this `link`_ has a nice visualization of what :attr:`dilation` does. - +""" r""" Args: output_size (int or tuple): the shape of the spatial dimensions of the output (i.e., ``output.sizes()[2:]``) @@ -194,10 +194,10 @@ class Unfold(Module): * :attr:`padding` controls the amount of implicit zero-paddings on both sides for :attr:`padding` number of points for each dimension before reshaping. - +""" """ * :attr:`dilation` controls the spacing between the kernel points; also known as the \u00e0 trous algorithm. It is harder to describe, but this `link`_ has a nice visualization of what :attr:`dilation` does. - +""" r""" Args: kernel_size (int or tuple): the size of the sliding blocks dilation (int or tuple, optional): a parameter that controls the From aac615ebef283fd5079d26f3835c931fb063d3fc Mon Sep 17 00:00:00 2001 From: Ivan Duka <39438519+ivanduka@users.noreply.github.com> Date: Tue, 27 Aug 2024 19:56:05 +0000 Subject: [PATCH 0101/1018] Links to contributors' GitHub accounts (#133787) Maintainers have the links to their GitHub profiles, but the major contributors do not have them. I added the links to the contributors' GitHub accounts in case anyone wants to follow them. Pull Request resolved: https://github.com/pytorch/pytorch/pull/133787 Approved by: https://github.com/albanD --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index d2945520cd36f5..d7567419acb77b 100644 --- a/README.md +++ b/README.md @@ -471,7 +471,7 @@ To learn more about making a contribution to Pytorch, please see our [Contributi PyTorch is a community-driven project with several skillful engineers and researchers contributing to it. PyTorch is currently maintained by [Soumith Chintala](http://soumith.ch), [Gregory Chanan](https://github.com/gchanan), [Dmytro Dzhulgakov](https://github.com/dzhulgakov), [Edward Yang](https://github.com/ezyang), and [Nikita Shulga](https://github.com/malfet) with major contributions coming from hundreds of talented individuals in various forms and means. -A non-exhaustive but growing list needs to mention: Trevor Killeen, Sasank Chilamkurthy, Sergey Zagoruyko, Adam Lerer, Francisco Massa, Alykhan Tejani, Luca Antiga, Alban Desmaison, Andreas Koepf, James Bradbury, Zeming Lin, Yuandong Tian, Guillaume Lample, Marat Dukhan, Natalia Gimelshein, Christian Sarofeen, Martin Raison, Edward Yang, Zachary Devito. +A non-exhaustive but growing list needs to mention: [Trevor Killeen](https://github.com/killeent), [Sasank Chilamkurthy](https://github.com/chsasank), [Sergey Zagoruyko](https://github.com/szagoruyko), [Adam Lerer](https://github.com/adamlerer), [Francisco Massa](https://github.com/fmassa), [Alykhan Tejani](https://github.com/alykhantejani), [Luca Antiga](https://github.com/lantiga), [Alban Desmaison](https://github.com/albanD), [Andreas Koepf](https://github.com/andreaskoepf), [James Bradbury](https://github.com/jamesb93), [Zeming Lin](https://github.com/ebetica), [Yuandong Tian](https://github.com/yuandong-tian), [Guillaume Lample](https://github.com/glample), [Marat Dukhan](https://github.com/Maratyszcza), [Natalia Gimelshein](https://github.com/ngimel), [Christian Sarofeen](https://github.com/csarofeen), [Martin Raison](https://github.com/martinraison), [Edward Yang](https://github.com/ezyang), [Zachary Devito](https://github.com/zdevito). Note: This project is unrelated to [hughperkins/pytorch](https://github.com/hughperkins/pytorch) with the same name. Hugh is a valuable contributor to the Torch community and has helped with many things Torch and PyTorch. From 2830791350f48505227115a159024e96b8cb154f Mon Sep 17 00:00:00 2001 From: xpfjmj Date: Tue, 27 Aug 2024 19:57:55 +0000 Subject: [PATCH 0102/1018] Add a cpu_dispatch_key parameter to the cpu_fallback function (#134321) Fixes #134322 Add a cpu_dispatch_key parameter to the cpu_fallback function to support fallback, for example, to SparseCPU. Pull Request resolved: https://github.com/pytorch/pytorch/pull/134321 Approved by: https://github.com/albanD --- aten/src/ATen/native/CPUFallback.cpp | 8 ++++++-- aten/src/ATen/native/CPUFallback.h | 3 ++- 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/aten/src/ATen/native/CPUFallback.cpp b/aten/src/ATen/native/CPUFallback.cpp index c1dc1f3a5eec6b..78222317a889ac 100644 --- a/aten/src/ATen/native/CPUFallback.cpp +++ b/aten/src/ATen/native/CPUFallback.cpp @@ -87,7 +87,11 @@ static bool validate_tensor_list(const c10::List& tensorlist) { return flag; } -void cpu_fallback(const c10::OperatorHandle& op, torch::jit::Stack* stack, bool error_on_views) { +void cpu_fallback(const c10::OperatorHandle& op, torch::jit::Stack* stack, bool error_on_views, + c10::DispatchKey cpu_dispatch_key) { + TORCH_CHECK(c10::BackendComponent::CPUBit == c10::toBackendComponent(cpu_dispatch_key), + "Expected CPU backend DispatchKey but got ", + c10::toString(cpu_dispatch_key)); auto& schema_args = op.schema().arguments(); const auto num_arguments = schema_args.size(); auto arguments = torch::jit::last(stack, num_arguments); @@ -143,7 +147,7 @@ void cpu_fallback(const c10::OperatorHandle& op, torch::jit::Stack* stack, bool } // Step 2: Call the underlying CPU implementation of the operator - op.redispatchBoxed(c10::DispatchKeySet(c10::DispatchKey::CPU), stack); + op.redispatchBoxed(c10::DispatchKeySet(cpu_dispatch_key), stack); // Step 3: We need to take special care to handle mutable aliases properly: // If any input tensors are mutable aliases, we need to diff --git a/aten/src/ATen/native/CPUFallback.h b/aten/src/ATen/native/CPUFallback.h index 606901fe1926fb..44cb534b8db2c9 100644 --- a/aten/src/ATen/native/CPUFallback.h +++ b/aten/src/ATen/native/CPUFallback.h @@ -11,7 +11,8 @@ namespace at::native { // This function implements a boxed fallback to CPU. // External backends can add their own custom logging on top if it to customize their own CPU fallbacks. -TORCH_API void cpu_fallback(const c10::OperatorHandle& op, torch::jit::Stack* stack, bool error_on_views = false); +TORCH_API void cpu_fallback(const c10::OperatorHandle& op, torch::jit::Stack* stack, bool error_on_views = false, + c10::DispatchKey cpu_dispatch_key = c10::DispatchKey::CPU); // This is a helper function that backends can use to directly call their boxed CPU fallback // TODO: update and add a usage example after https://github.com/pytorch/pytorch/pull/58092 lands. From 733d4d3597f17d37934fe34c5022d68d3a5fa216 Mon Sep 17 00:00:00 2001 From: Moritz Hennen Date: Tue, 27 Aug 2024 20:09:26 +0000 Subject: [PATCH 0103/1018] Do not compute unnecessary `tensor!=0` for bool tensors in `count_nonzero` (#134254) Updated aten/src/ATen/native/TensorAdvancedIndexing.cpp to only reduce non-bool tensors before computing a sum Since I have no expertise for MPS, I did leave the MPS backend untouched. Also, in `count_nonzero_impl` for CPU, I assumed the comparison can be optimized by the compiler for boolean values? https://github.com/pytorch/pytorch/blob/90c821814ed8bfafb7724cab6c7b7e8416150dee/aten/src/ATen/native/TensorAdvancedIndexing.cpp#L2262-L2264 Fixes #133983 Pull Request resolved: https://github.com/pytorch/pytorch/pull/134254 Approved by: https://github.com/albanD --- aten/src/ATen/native/TensorAdvancedIndexing.cpp | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/aten/src/ATen/native/TensorAdvancedIndexing.cpp b/aten/src/ATen/native/TensorAdvancedIndexing.cpp index 105636c0a6b746..a3e11077a6fd5d 100644 --- a/aten/src/ATen/native/TensorAdvancedIndexing.cpp +++ b/aten/src/ATen/native/TensorAdvancedIndexing.cpp @@ -2284,12 +2284,20 @@ int64_t count_nonzero_impl(TensorIteratorBase& iter, Range range) { } Tensor count_nonzero_cuda(const Tensor& self, IntArrayRef dims){ - return (self != 0).sum(dims); + auto reduce = self; + if (reduce.scalar_type() != kBool) { + reduce = reduce != 0; + } + return reduce.sum(dims); } Tensor count_nonzero_cpu(const Tensor& self, IntArrayRef dims){ if (!dims.empty()) { - return (self != 0).sum(dims); + auto reduce = self; + if (reduce.scalar_type() != kBool) { + reduce = reduce != 0; + } + return reduce.sum(dims); } // Optimized all-reduce From ae98256eccfdce400d1ce505a2adc640dfbe80d0 Mon Sep 17 00:00:00 2001 From: Colin Peppler Date: Tue, 27 Aug 2024 20:15:09 +0000 Subject: [PATCH 0104/1018] [easy] rm duplicate definition for inductor in TORCH_LOGS documentation (#134480) already defined in https://github.com/pytorch/pytorch/blob/2eb9339b71b0d43477ba33e235c3c1a2fad562cf/torch/_logging/_internal.py#L286-L287 Test Plan: Sandcastle run Differential Revision: D61806088 Pull Request resolved: https://github.com/pytorch/pytorch/pull/134480 Approved by: https://github.com/eellison, https://github.com/mlazos --- torch/_logging/_internal.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/torch/_logging/_internal.py b/torch/_logging/_internal.py index 1cf37a2a57b084..f079eb86804282 100644 --- a/torch/_logging/_internal.py +++ b/torch/_logging/_internal.py @@ -322,9 +322,6 @@ def set_logs( aot_joint_graph (:class:`bool`): Whether to emit the joint forward-backward graph generated by AOTAutograd. Default: ``False`` - inductor (:class:`Optional[int]`): - Whether to log information from inductor cudagraphs. Default: ``logging.WARN`` - ddp_graphs (:class:`bool`): Whether to emit graphs generated by DDPOptimizer. Default: ``False`` From c4debae8185c7fc250fe855ee08dadd2d40d2ce3 Mon Sep 17 00:00:00 2001 From: Yuanhao Ji Date: Tue, 27 Aug 2024 20:28:27 +0000 Subject: [PATCH 0105/1018] [Fix] Check name when registering privateuse1 backend (#134071) do some checks when registering privateuse1 backend to avoid using in-tree deivce names Pull Request resolved: https://github.com/pytorch/pytorch/pull/134071 Approved by: https://github.com/albanD --- c10/core/DeviceType.cpp | 8 ++++++++ c10/test/core/Device_test.cpp | 9 +++++++++ 2 files changed, 17 insertions(+) diff --git a/c10/core/DeviceType.cpp b/c10/core/DeviceType.cpp index 1543db0a39bcf9..70d3d7489e548d 100644 --- a/c10/core/DeviceType.cpp +++ b/c10/core/DeviceType.cpp @@ -1,5 +1,6 @@ #include #include +#include #include #include @@ -145,6 +146,13 @@ void register_privateuse1_backend(const std::string& backend_name) { "torch.register_privateuse1_backend() has already been set! Current backend: ", privateuse1_backend_name); + static const std::array types = { + "cpu", "cuda", "hip", "mps", "xpu", "mtia"}; + TORCH_CHECK( + std::find(types.begin(), types.end(), backend_name) == types.end(), + "Cannot register privateuse1 backend with in-tree device name: ", + backend_name); + privateuse1_backend_name = backend_name; // Invariant: once this flag is set, privateuse1_backend_name is NEVER written // to. diff --git a/c10/test/core/Device_test.cpp b/c10/test/core/Device_test.cpp index 0524ad08e8ec37..3874ea98f2d379 100644 --- a/c10/test/core/Device_test.cpp +++ b/c10/test/core/Device_test.cpp @@ -54,3 +54,12 @@ TEST(DeviceTypeTest, PrivateUseOneDeviceType) { ASSERT_EQ(c10::get_privateuse1_backend(true), "my_privateuse1_backend"); ASSERT_EQ(c10::get_privateuse1_backend(false), "MY_PRIVATEUSE1_BACKEND"); } + +TEST(DeviceTypeTest, PrivateUseOneRegister) { + ASSERT_THROW(c10::register_privateuse1_backend("cpu"), c10::Error); + ASSERT_THROW(c10::register_privateuse1_backend("cuda"), c10::Error); + ASSERT_THROW(c10::register_privateuse1_backend("hip"), c10::Error); + ASSERT_THROW(c10::register_privateuse1_backend("mps"), c10::Error); + ASSERT_THROW(c10::register_privateuse1_backend("xpu"), c10::Error); + ASSERT_THROW(c10::register_privateuse1_backend("mtia"), c10::Error); +} From a3f8bf29ca1b209dde3d0e1b3355c39bd7770169 Mon Sep 17 00:00:00 2001 From: Bo Li <110066325+BLOrange-AMD@users.noreply.github.com> Date: Tue, 27 Aug 2024 20:40:53 +0000 Subject: [PATCH 0106/1018] Exclude test_transformers and unit tests which require recent GPU arch (#132895) This PR is to exclude test_transformers on ROCm temporarily and skip some unit tests which require recent GPU arch. Pull Request resolved: https://github.com/pytorch/pytorch/pull/132895 Approved by: https://github.com/jithunnair-amd, https://github.com/pruthvistony, https://github.com/malfet --- .../tensor/parallel/test_micro_pipeline_tp.py | 4 ++++ test/run_test.py | 1 + torch/testing/_internal/common_utils.py | 17 +++++++++++++++++ 3 files changed, 22 insertions(+) diff --git a/test/distributed/tensor/parallel/test_micro_pipeline_tp.py b/test/distributed/tensor/parallel/test_micro_pipeline_tp.py index 327359a5807cb5..951e77188364c3 100644 --- a/test/distributed/tensor/parallel/test_micro_pipeline_tp.py +++ b/test/distributed/tensor/parallel/test_micro_pipeline_tp.py @@ -29,8 +29,10 @@ ) from torch.testing._internal.common_utils import ( # type: ignore[attr-defined] instantiate_parametrized_tests, + MI300_ARCH, parametrize, run_tests, + runOnRocmArch, TestCase, ) from torch.testing._internal.distributed._tensor.common_dtensor import MLPModule @@ -228,6 +230,7 @@ def func(A_shard: torch.Tensor, B: torch.Tensor) -> torch.Tensor: self.assertIn("fused_all_gather_matmul", code) self.assertNotIn("all_gather_into_tensor", code) + @runOnRocmArch(MI300_ARCH) @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") @parametrize("A_dims", [2, 3]) @parametrize("gather_dim", [0, 1, 2]) @@ -324,6 +327,7 @@ def func(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor: self.assertIn("fused_matmul_reduce_scatter", code) self.assertNotIn("reduce_scatter_tensor", code) + @runOnRocmArch(MI300_ARCH) @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") @parametrize("A_dims", [2, 3]) @parametrize("scatter_dim", [0, 1, 2]) diff --git a/test/run_test.py b/test/run_test.py index 231a1b2b7ca012..b459b577947693 100755 --- a/test/run_test.py +++ b/test/run_test.py @@ -183,6 +183,7 @@ def __contains__(self, item): "test_cuda_nvml_based_avail", "test_jit_cuda_fuser", "distributed/_tensor/test_attention", + "test_transformers", ] XPU_BLOCKLIST = [ diff --git a/torch/testing/_internal/common_utils.py b/torch/testing/_internal/common_utils.py index 764772cc767989..4fb13abb5b96d3 100644 --- a/torch/testing/_internal/common_utils.py +++ b/torch/testing/_internal/common_utils.py @@ -103,6 +103,10 @@ except ImportError: has_pytest = False + +MI300_ARCH = ("gfx940", "gfx941", "gfx942") + + def freeze_rng_state(*args, **kwargs): return torch.testing._utils.freeze_rng_state(*args, **kwargs) @@ -1759,6 +1763,19 @@ def wrapper(*args, **kwargs): raise unittest.SkipTest("test currently only works on the ROCm stack") return wrapper +def runOnRocmArch(arch: Tuple[str, ...]): + def dec_fn(fn): + @wraps(fn) + def wrap_fn(self, *args, **kwargs): + if TEST_WITH_ROCM: + prop = torch.cuda.get_device_properties(0) + if prop.gcnArchName.split(":")[0] not in arch: + reason = f"skipIfRocm: test only runs on {arch}" + raise unittest.SkipTest(reason) + return fn(self, *args, **kwargs) + return wrap_fn + return dec_fn + def skipIfXpu(func=None, *, msg="test doesn't currently work on the XPU stack"): def dec_fn(fn): reason = f"skipIfXpu: {msg}" From 5c6dbe3a1eefffe1908598ae39f3df2731b89c6e Mon Sep 17 00:00:00 2001 From: Animesh Jain Date: Tue, 27 Aug 2024 09:47:02 -0700 Subject: [PATCH 0107/1018] [dynamo][dicts] Support hasattr on dicts (#134590) Fixes - https://github.com/pytorch/pytorch/issues/134577 Pull Request resolved: https://github.com/pytorch/pytorch/pull/134590 Approved by: https://github.com/Skylion007 ghstack dependencies: #134039 --- test/dynamo/test_functions.py | 16 ++++++++++++++++ torch/_dynamo/variables/dicts.py | 7 +++++++ 2 files changed, 23 insertions(+) diff --git a/test/dynamo/test_functions.py b/test/dynamo/test_functions.py index 67ab47dfac9093..f5cbfb8760c307 100644 --- a/test/dynamo/test_functions.py +++ b/test/dynamo/test_functions.py @@ -1683,6 +1683,22 @@ def test_dict_sorted(x): tmp = {1: "D", 10: "B", 3: "E", 0: "F"} return x + 1, sorted(tmp), sorted(tmp, reverse=True) + def test_dict_hasattr(self): + def fn(x): + if hasattr(x, "to"): + return x.to("cpu") + if hasattr(x, "items"): + return torch.cos(x["a"]) + return x + + opt_fn = torch.compile(fn, backend="eager", fullgraph=True) + + x = dict(a=torch.randn(3)) + self.assertEqual(fn(x), opt_fn(x)) + + x = torch.randn(4) + self.assertEqual(fn(x), opt_fn(x)) + @make_test def test_list_clear(a, b): tmp = [a + 1, a + 2] diff --git a/torch/_dynamo/variables/dicts.py b/torch/_dynamo/variables/dicts.py index 2d4c015bf3ac1b..7bcae6887cf6cc 100644 --- a/torch/_dynamo/variables/dicts.py +++ b/torch/_dynamo/variables/dicts.py @@ -355,6 +355,13 @@ def call_method( def unpack_var_sequence(self, tx): return [x.vt for x in self.items.keys()] + def call_hasattr(self, tx, name): + # dict/OrderedDict do not allow setting arbitrary attributes. To check for hasattr, we can just check the + # __dict__ of the cls. + if name in self.user_cls.__dict__: + return ConstantVariable.create(True) + return ConstantVariable.create(False) + class DefaultDictVariable(ConstDictVariable): def __init__(self, items, user_cls, default_factory=None, **kwargs) -> None: From 10463467cbf775841e50d95d68c057524b983105 Mon Sep 17 00:00:00 2001 From: cyy Date: Tue, 27 Aug 2024 20:44:01 +0000 Subject: [PATCH 0108/1018] Remove unused imported names in python files (#134438) Fixes #ISSUE_NUMBER Pull Request resolved: https://github.com/pytorch/pytorch/pull/134438 Approved by: https://github.com/zou3519 --- torch/_awaits/__init__.py | 2 +- torch/_deploy.py | 2 -- torch/_lazy/__init__.py | 1 - torch/_prims/__init__.py | 7 ++----- torch/_prims_common/__init__.py | 3 +-- torch/_refs/linalg/__init__.py | 3 +-- torch/linalg/__init__.py | 2 -- 7 files changed, 5 insertions(+), 15 deletions(-) diff --git a/torch/_awaits/__init__.py b/torch/_awaits/__init__.py index 3770557d6f263d..b08067bdcf45a1 100644 --- a/torch/_awaits/__init__.py +++ b/torch/_awaits/__init__.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import cast, Callable, Generic, Type, TypeVar +from typing import Generic, TypeVar import torch diff --git a/torch/_deploy.py b/torch/_deploy.py index 3f8adc420672db..0443a2447d00dd 100644 --- a/torch/_deploy.py +++ b/torch/_deploy.py @@ -24,10 +24,8 @@ def persistent_id(obj): if isinstance(obj, torch.storage.TypedStorage): # TODO: Once we decide to break serialization FC, we can # remove this case - storage = obj._untyped_storage dtype = obj.dtype else: - storage = obj dtype = torch.uint8 serialized_storages.append(obj) diff --git a/torch/_lazy/__init__.py b/torch/_lazy/__init__.py index c074abd143723e..8d90efa40e5884 100644 --- a/torch/_lazy/__init__.py +++ b/torch/_lazy/__init__.py @@ -1,5 +1,4 @@ # mypy: allow-untyped-defs -import threading import torch._C._lazy from torch.utils._pytree import tree_flatten, tree_unflatten diff --git a/torch/_prims/__init__.py b/torch/_prims/__init__.py index 977d664754982e..26b3f298e36187 100644 --- a/torch/_prims/__init__.py +++ b/torch/_prims/__init__.py @@ -1,16 +1,13 @@ # mypy: allow-untyped-defs -import contextlib -import itertools import operator -import weakref from enum import Enum from functools import partial, reduce -from typing import Any, Callable, List, Optional, Sequence, Tuple, Type, Union +from typing import Callable, List, Optional, Sequence, Tuple, Type, Union import torch import torch._prims_common as utils import torch.library -from torch import sym_float, Tensor, TypedStorage +from torch import sym_float, Tensor from torch._C import _get_default_device from torch._higher_order_ops.effects import new_token_tensor from torch._library.utils import is_functional_schema diff --git a/torch/_prims_common/__init__.py b/torch/_prims_common/__init__.py index 065ec2eff2384d..61d0ba13b88f15 100644 --- a/torch/_prims_common/__init__.py +++ b/torch/_prims_common/__init__.py @@ -3,10 +3,9 @@ import operator import warnings -import weakref from contextlib import nullcontext from enum import Enum -from functools import cmp_to_key, reduce +from functools import reduce from typing import ( Any, Callable, diff --git a/torch/_refs/linalg/__init__.py b/torch/_refs/linalg/__init__.py index 77a635ea9ef69b..6585f57e3d643b 100644 --- a/torch/_refs/linalg/__init__.py +++ b/torch/_refs/linalg/__init__.py @@ -1,6 +1,6 @@ # mypy: allow-untyped-defs from functools import partial -from typing import List, Optional, Tuple, Union +from typing import Optional, Tuple, Union import torch import torch._prims as prims @@ -15,7 +15,6 @@ DimsType, ELEMENTWISE_TYPE_PROMOTION_KIND, IntLike, - NumberType, TensorLikeType, ) from torch._prims_common.wrappers import ( diff --git a/torch/linalg/__init__.py b/torch/linalg/__init__.py index 3377d30c665477..f62569f7a9557e 100644 --- a/torch/linalg/__init__.py +++ b/torch/linalg/__init__.py @@ -1,5 +1,3 @@ -import sys - import torch from torch._C import _add_docstr, _linalg # type: ignore[attr-defined] From 19676afdc847bff9dc2d517087a0798832a79a5b Mon Sep 17 00:00:00 2001 From: rzou Date: Tue, 27 Aug 2024 07:57:33 -0700 Subject: [PATCH 0109/1018] Update label_to_label with oncall: pt2 hierarchy. (#134582) Test Plan: - None Pull Request resolved: https://github.com/pytorch/pytorch/pull/134582 Approved by: https://github.com/clee2000 --- .github/label_to_label.yml | 35 ++++++++++++++++++++++++++++++++++- 1 file changed, 34 insertions(+), 1 deletion(-) diff --git a/.github/label_to_label.yml b/.github/label_to_label.yml index e6c66a5e56cf68..9f3ed07a544c70 100644 --- a/.github/label_to_label.yml +++ b/.github/label_to_label.yml @@ -1,13 +1,46 @@ # Use this to auto apply labels based on other labels. Applies to both PRs and # issues. Currently only supports any and all - any: - - "module: custom operators" + - "module: opcheck" + then: + - "module: custom-operators" +- any: + - "module: custom-operators" + - "module: functionalization" - "module: aotdispatch" + - "module: higher order operators" + - "module: fakeTensor" + - "module: ProxyTensor" + - "module: library" + - "module: reinplacing" then: - "module: pt2-dispatcher" +- any: + - "module: vmap" + then: + - "module: functorch" +- any: + - "module: reinplacing" + then: + - "module: inductor" +- any: + - "module: pt2 optimizer" + then: + - "module: dynamo" +- any: + - "module: flex attention" + then: + - "module: higher order operators" - any: - "module: dynamo" - "module: pt2-dispatcher" - "module: inductor" + - "module: aotinductor" + - "module: cudagraphs" + - "oncall: export" + - "module: startup-tracing-compile" + - "module: compiled autograd" + - "module: flex attention" + - "module: dynamic shapes" then: - "oncall: pt2" From 84f8420e75a7881ec4f0fe576e093209a718e344 Mon Sep 17 00:00:00 2001 From: David Berard Date: Mon, 26 Aug 2024 11:15:58 -0700 Subject: [PATCH 0110/1018] Move attention kernels back from fake_impls to meta_registrations (#134288) See #121528 for additional context. In #120682, we moved the attention kernels from meta_registrations to fake_impls with the intent of fixing the device handling for seed/offset: these are typically on CPU. We needed to put the registrations in fake_impls to do this because meta_registrations doesn't have a way to specify device, whereas fake_impls does. But when we tried to actually fix the device types (#120839), we had to revert the PR because it broke cudagraph handling (during which seed/offset _are_ on CUDA). Now, we want to put the registrations back in meta_registrations so that we can call these kernels with meta tensors. The use case is later in this stack - we want to be able to use the flop counter with these kernels. Also - I specifically skip the `compare_tensor_meta()` check in test_fake / test_fake_autocast tests for the `_efficient_attention_forward` and `_flash_attention_forward` kernels, which fails because of the device mismatch from the seed/offset tensors. Then we can un-skip these opinfos. I verified that the efficient_attention_forward bug (#120842) is now caught by these opinfos if I revert the fix from this PR. Differential Revision: [D61687369](https://our.internmc.facebook.com/intern/diff/D61687369) Pull Request resolved: https://github.com/pytorch/pytorch/pull/134288 Approved by: https://github.com/drisspg --- test/test_ops.py | 13 +- torch/_meta_registrations.py | 258 +++++++++++++++ torch/_subclasses/fake_impls.py | 313 ------------------ torch/_subclasses/fake_utils.py | 6 + .../_internal/common_methods_invocations.py | 67 ++-- 5 files changed, 301 insertions(+), 356 deletions(-) diff --git a/test/test_ops.py b/test/test_ops.py index e3627d62ab94b3..e1b5ebfe87d248 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -2476,9 +2476,16 @@ def map_to_fake(e): # if you see a shape exception here, you may need to add # a `dynamic_output_shape` tag to an operator - # prims/decomps must correctly model strides, - # see https://github.com/pytorch/pytorch/issues/78050#issuecomment-1253950325 - prims.utils.compare_tensor_meta(fake_out, real_out, True) + if op.op not in [ + torch.ops.aten._efficient_attention_forward, + torch.ops.aten._flash_attention_forward, + ]: + # prims/decomps must correctly model strides, + # see https://github.com/pytorch/pytorch/issues/78050#issuecomment-1253950325 + + # note: the excluded ops have intentionally incorrect device; + # see "Note [Seed and Offset]" (_meta_registrations.py) + prims.utils.compare_tensor_meta(fake_out, real_out, True) if name not in aliasing_failures: fake_aliasing = outputs_alias_inputs( diff --git a/torch/_meta_registrations.py b/torch/_meta_registrations.py index 80641f53b7a72a..022e8010265ba0 100644 --- a/torch/_meta_registrations.py +++ b/torch/_meta_registrations.py @@ -5041,6 +5041,106 @@ def meta_scatter_(self, dim, index, src_or_value, reduce=None): return self +@register_meta([aten._scaled_dot_product_flash_attention]) +def meta__scaled_dot_product_flash_attention( + query: Tensor, + key: Tensor, + value: Tensor, + dropout_p: float = 0.0, + is_causal: bool = False, + return_debug_mask: bool = False, + scale: Optional[float] = None, +): + batch_size = query.size(0) + num_heads = query.size(1) + max_seqlen_batch_q = query.size(2) + head_dim = query.size(3) + max_seqlen_batch_k = key.size(2) + + query_t = query.transpose(1, 2) + attention = torch.empty_like(query_t).transpose(1, 2) + logsumexp = torch.empty( + (batch_size, num_heads, max_seqlen_batch_q), + dtype=torch.float, + device=query.device, + ) + + if return_debug_mask: + blocksize_c = 128 if head_dim > 64 else 256 + max_seqlen_k = math.ceil(max_seqlen_batch_q / blocksize_c) + if max_seqlen_batch_k <= 128: + max_seqlen_k = 128 + elif max_seqlen_batch_k <= 256: + max_seqlen_k = 256 + debug_mask = torch.empty( + (batch_size, num_heads, max_seqlen_batch_q, max_seqlen_k), + dtype=query.dtype, + device=query.device, + ) + else: + debug_mask = torch.empty(0, dtype=query.dtype, device=query.device) + + # Note [Seed and Offset]: device for seed and offset below depends on whether we are + # capturing or not, but at the time of tracing we don't know if we + # are going to use cudagraphs or not, so we return meta tensors here + # it's possible we'll need to have some special handling in inductor for sdpa + + return ( + attention, + logsumexp, + None, + None, + max_seqlen_batch_q, + max_seqlen_batch_k, + torch.empty((), dtype=torch.long, device="meta"), + torch.empty((), dtype=torch.long, device="meta"), + debug_mask, + ) + + +@register_meta([aten._scaled_dot_product_cudnn_attention]) +def meta__scaled_dot_product_cudnn_attention( + query: Tensor, + key: Tensor, + value: Tensor, + attn_bias: Optional[Tensor], + compute_log_sumexp: bool, + dropout_p: float = 0.0, + is_causal: bool = False, + return_debug_mask: bool = False, + scale: Optional[float] = None, +): + B = query.size(0) + H = query.size(1) + S_Q = query.size(2) + S_KV = key.size(2) + D_QK = query.size(-1) + D_V = value.size(-1) + + res = torch.empty((B, H, S_Q, D_V), dtype=query.dtype, device=query.device) + logsum_exp = torch.empty( + (B, H, S_Q), + dtype=torch.float, + device=query.device, + ) + + # See Note [Seed and Offset] + seed = torch.empty((), dtype=torch.long, device="meta") + offset = torch.empty((), dtype=torch.long, device="meta") + + return ( + res, + logsum_exp, + None, + None, + S_Q, + S_KV, + seed, + offset, + None, + ) + + @register_meta( [ aten._scaled_dot_product_flash_attention_backward, @@ -5155,6 +5255,46 @@ def meta__scaled_dot_product_flash_attention_for_cpu_backward( return grad_q, grad_k, grad_v +@register_meta([aten._scaled_dot_product_efficient_attention]) +def meta__scaled_dot_product_efficient_attention( + query: Tensor, + key: Tensor, + value: Tensor, + attn_bias: Optional[Tensor], + compute_log_sumexp: bool, + dropout_p=0.0, + is_causal: bool = False, + scale: Optional[float] = None, +): + query = query.transpose(1, 2) + key = key.transpose(1, 2) + value = value.transpose(1, 2) + + B = query.size(0) + M = query.size(1) + N = key.size(1) + num_heads = query.size(-2) + K = query.size(-1) + Kv = value.size(-1) + + res = torch.empty(B, M, num_heads, Kv, dtype=query.dtype, device=query.device) + + logsumexp_dim = math.ceil(M / 32) * 32 if compute_log_sumexp else 0 + logsum_exp = torch.empty( + (B, num_heads, logsumexp_dim), + dtype=torch.float, + device=query.device, + ) + + res = res.transpose(1, 2) + + # See Note [Seed and Offset]: + seed = torch.empty((), dtype=torch.long, device="meta") + offset = torch.empty((), dtype=torch.long, device="meta") + + return res, logsum_exp, seed, offset + + @register_meta( [ aten._scaled_dot_product_efficient_attention_backward, @@ -5244,6 +5384,71 @@ def meta__scaled_dot_product_cudnn_backward( return grad_q, grad_k, grad_v +@register_meta( + [ + aten._flash_attention_forward, + ] +) +def meta__flash_attention_forward( + query: Tensor, + key: Tensor, + value: Tensor, + cum_seq_q: Optional[Tensor], + cum_seq_k: Optional[Tensor], + max_q: int, + max_k: int, + dropout_p: float, + is_causal: bool, + return_debug_mask: bool, + scale: Optional[float] = None, + window_size_left: Optional[int] = None, + window_size_right: Optional[int] = None, + seqused_k: Optional[Tensor] = None, + alibi_slopes: Optional[Tensor] = None, +): + # NB: there are two underlying paths: + # 1. normal dense path; expect 4D inputs of shape (batch_size, seqlen, num_heads, head_dim) + # 2. varseqlen path; expect 3D inputs of shape (total, num_heads, head_dim) where total + # includes all batch item sequences. cum_seq_q / cum_seq_k contain offsets into total + batch_size = query.size(0) if cum_seq_q is None else cum_seq_q.numel() - 1 + max_seqlen_batch_q = query.size(1) if cum_seq_q is None else max_q + max_seqlen_batch_k = key.size(1) if cum_seq_k is None else max_k + num_heads = query.size(-2) + head_dim = query.size(-1) + + # Cuda Path + attention = torch.empty_like(query) + logsumexp = torch.empty( + (batch_size, num_heads, max_seqlen_batch_q), + dtype=torch.float, + device=query.device, + ) + + if return_debug_mask: + blocksize_c = 128 if head_dim > 64 else 256 + max_seqlen_k = math.ceil(max_seqlen_batch_q / blocksize_c) + if max_seqlen_batch_k <= 128: + max_seqlen_k = 128 + elif max_seqlen_batch_k <= 256: + max_seqlen_k = 256 + debug_mask = torch.empty( + (batch_size, num_heads, max_seqlen_batch_q, max_seqlen_k), + dtype=query.dtype, + device=query.device, + ) + else: + debug_mask = torch.empty(0, dtype=query.dtype, device=query.device) + + # See Note [Seed and Offset]: + return ( + attention, + logsumexp, + torch.empty((), dtype=torch.long, device="meta"), + torch.empty((), dtype=torch.long, device="meta"), + debug_mask, + ) + + @register_meta( [ aten._flash_attention_backward, @@ -5275,6 +5480,59 @@ def meta__flash_attention_backward( return grad_query, grad_key, grad_value +@register_meta( + [ + aten._efficient_attention_forward, + ] +) +def meta__efficient_attention_forward( + query: Tensor, + key: Tensor, + value: Tensor, + bias: Optional[Tensor], + cu_seqlens_q: Optional[Tensor], + cu_seqlens_k: Optional[Tensor], + max_seqlen_q: Optional[int], + max_seqlen_k: Optional[int], + dropout_p: float, + custom_mask_type: int, + compute_log_sumexp: bool = False, + scale: Optional[float] = None, + causal_diagonal: Optional[Tensor] = None, + seqlen_k: Optional[Tensor] = None, + window_size: Optional[int] = None, +): + B = query.size(0) + M = query.size(1) + N = key.size(1) + num_heads = query.size(-2) + K = query.size(-1) + Kv = value.size(-1) + + res = torch.empty(B, M, num_heads, Kv, dtype=query.dtype, device=query.device) + + logsumexp_batch_dim = cu_seqlens_q.size(0) - 1 if (cu_seqlens_q is not None) else B + actual_max_seqlen_q = M + if cu_seqlens_q is not None: + assert max_seqlen_q is not None + actual_max_seqlen_q = max_seqlen_q + actual_max_seqlen_k = max_seqlen_k if max_seqlen_k is not None else N + logsumexp_dim = ( + math.ceil(actual_max_seqlen_q / 32) * 32 if compute_log_sumexp else 0 + ) + logsum_exp = torch.empty( + (logsumexp_batch_dim, num_heads, logsumexp_dim), + dtype=torch.float, + device=query.device, + ) + + # See Note [Seed and Offset]: + seed = torch.empty((), dtype=torch.long, device="meta") + offset = torch.empty((), dtype=torch.long, device="meta") + + return res, logsum_exp, seed, offset, actual_max_seqlen_q, actual_max_seqlen_k + + @register_meta( [ aten._efficient_attention_backward, diff --git a/torch/_subclasses/fake_impls.py b/torch/_subclasses/fake_impls.py index 8c6a8521d56719..b0c9ee1a716aa3 100644 --- a/torch/_subclasses/fake_impls.py +++ b/torch/_subclasses/fake_impls.py @@ -2,7 +2,6 @@ import functools import itertools -import math import sys from typing import Callable, Union @@ -704,318 +703,6 @@ def convert(t, mem_fmt): ) -@register_op_impl(aten._scaled_dot_product_flash_attention.default) -def meta__scaled_dot_product_flash(fake_mode, func, *args, **kwargs): - _, kwargs = normalize_function( - func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True - ) - - query = kwargs["query"] - key = kwargs["key"] - return_debug_mask = kwargs["return_debug_mask"] - # unused: value, dropout_p, is_causal, scale - - def convert_tensor(t, device): - return FakeTensor(fake_mode, t, device) - - batch_size = query.size(0) - num_heads = query.size(1) - max_seqlen_batch_q = query.size(2) - head_dim = query.size(3) - max_seqlen_batch_k = key.size(2) - - query_t = query.transpose(1, 2) - # empty_like already returns a fake tensor so we don't need to convert it - attention = torch.empty_like(query_t).transpose(1, 2) - logsumexp = convert_tensor( - torch.empty( - (batch_size, num_heads, max_seqlen_batch_q), - dtype=torch.float, - device="meta", - ), - device=query.device, - ) - - if return_debug_mask: - blocksize_c = 128 if head_dim > 64 else 256 - max_seqlen_k = math.ceil(max_seqlen_batch_q / blocksize_c) - if max_seqlen_batch_k <= 128: - max_seqlen_k = 128 - elif max_seqlen_batch_k <= 256: - max_seqlen_k = 256 - debug_mask = convert_tensor( - torch.empty( - (batch_size, num_heads, max_seqlen_batch_q, max_seqlen_k), - dtype=query.dtype, - device="meta", - ), - device=query.device, - ) - else: - debug_mask = convert_tensor( - torch.empty(0, dtype=query.dtype, device="meta"), - query.device, - ) - - # Note [Seed and Offset]: device for seed and offset below depends on whether we are - # capturing or not, but at the time of tracing we don't know if we - # are going to use cudagraphs or not, so we return meta tensors here - # it's possible we'll need to have some special handling in inductor for sdpa - - return ( - attention, - logsumexp, - None, - None, - max_seqlen_batch_q, - max_seqlen_batch_k, - convert_tensor(torch.empty((), dtype=torch.long, device="meta"), query.device), - convert_tensor(torch.empty((), dtype=torch.long, device="meta"), query.device), - debug_mask, - ) - - -@register_op_impl(aten._scaled_dot_product_cudnn_attention.default) -def meta__scaled_dot_product_cudnn(fake_mode, func, *args, **kwargs): - _, kwargs = normalize_function( - func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True - ) - - query = kwargs["query"] - key = kwargs["key"] - value = kwargs["value"] - compute_log_sumexp = kwargs["compute_log_sumexp"] - # unused: attn_bias, dropout_p, is_causal, return_debug_mask, scale - - def convert_tensor(t, device): - return FakeTensor(fake_mode, t, device) - - B = query.size(0) - H = query.size(1) - S_Q = query.size(2) - S_KV = key.size(2) - D_QK = query.size(-1) - D_V = value.size(-1) - - res = convert_tensor( - torch.empty(B, H, S_Q, D_V, dtype=query.dtype, device="meta"), - query.device, - ) - - logsum_exp = convert_tensor( - torch.empty( - (B, H, S_Q), - dtype=torch.float, - device="meta", - ), - query.device, - ) - - # See Note [Seed and Offset]: - seed = convert_tensor( - torch.empty((), dtype=torch.long, device="meta"), query.device - ) - offset = convert_tensor( - torch.empty((), dtype=torch.long, device="meta"), query.device - ) - return ( - res, - logsum_exp, - None, - None, - S_Q, - S_KV, - seed, - offset, - None, - ) - - -@register_op_impl(aten._scaled_dot_product_efficient_attention.default) -def meta__scaled_dot_product_efficient(fake_mode, func, *args, **kwargs): - _, kwargs = normalize_function( - func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True - ) - - query = kwargs["query"] - key = kwargs["key"] - value = kwargs["value"] - compute_log_sumexp = kwargs["compute_log_sumexp"] - # unused: attn_bias, dropout_p, is_causal, scale - - def convert_tensor(t, device): - return FakeTensor(fake_mode, t, device) - - query = query.transpose(1, 2) - key = key.transpose(1, 2) - value = value.transpose(1, 2) - - B = query.size(0) - M = query.size(1) - N = key.size(1) - num_heads = query.size(-2) - K = query.size(-1) - Kv = value.size(-1) - - res = convert_tensor( - torch.empty(B, M, num_heads, Kv, dtype=query.dtype, device="meta"), - query.device, - ) - - logsumexp_dim = math.ceil(M / 32) * 32 if compute_log_sumexp else 0 - logsum_exp = convert_tensor( - torch.empty( - (B, num_heads, logsumexp_dim), - dtype=torch.float, - device="meta", - ), - query.device, - ) - - res = res.transpose(1, 2) - - # See Note [Seed and Offset]: - seed = convert_tensor( - torch.empty((), dtype=torch.long, device="meta"), query.device - ) - offset = convert_tensor( - torch.empty((), dtype=torch.long, device="meta"), query.device - ) - - return res, logsum_exp, seed, offset - - -@register_op_impl(aten._flash_attention_forward.default) -def meta__flash_attention_forward(fake_mode, func, *args, **kwargs): - _, kwargs = normalize_function( - func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True - ) - - query = kwargs["query"] - key = kwargs["key"] - cum_seq_q = kwargs["cum_seq_q"] - cum_seq_k = kwargs["cum_seq_k"] - max_q = kwargs["max_q"] - max_k = kwargs["max_k"] - return_debug_mask = kwargs["return_debug_mask"] - # unused: value, dropout_p, is_causal, scale - # unused: seqused_k, alibi_slopes, window_size_left, window_size_right - - def convert_tensor(t, device): - return FakeTensor(fake_mode, t, device) - - # NB: there are two underlying paths: - # 1. normal dense path; expect 4D inputs of shape (batch_size, seqlen, num_heads, head_dim) - # 2. varseqlen path; expect 3D inputs of shape (total, num_heads, head_dim) where total - # includes all batch item sequences. cum_seq_q / cum_seq_k contain offsets into total - batch_size = query.size(0) if cum_seq_q is None else cum_seq_q.numel() - 1 - max_seqlen_batch_q = query.size(1) if cum_seq_q is None else max_q - max_seqlen_batch_k = key.size(1) if cum_seq_k is None else max_k - num_heads = query.size(-2) - head_dim = query.size(-1) - - # Cuda Path - # note: empty_like already returns a fake tensor, we don't need to wrap it - attention = torch.empty_like(query) - logsumexp = convert_tensor( - torch.empty( - (batch_size, num_heads, max_seqlen_batch_q), - dtype=torch.float, - device="meta", - ), - device=query.device, - ) - - if return_debug_mask: - blocksize_c = 128 if head_dim > 64 else 256 - max_seqlen_k = math.ceil(max_seqlen_batch_q / blocksize_c) - if max_seqlen_batch_k <= 128: - max_seqlen_k = 128 - elif max_seqlen_batch_k <= 256: - max_seqlen_k = 256 - debug_mask = convert_tensor( - torch.empty( - (batch_size, num_heads, max_seqlen_batch_q, max_seqlen_k), - dtype=query.dtype, - device="meta", - ), - query.device, - ) - else: - debug_mask = convert_tensor( - torch.empty(0, dtype=query.dtype, device="meta"), - query.device, - ) - - # See Note [Seed and Offset]: - return ( - attention, - logsumexp, - convert_tensor(torch.empty((), dtype=torch.long, device="meta"), query.device), - convert_tensor(torch.empty((), dtype=torch.long, device="meta"), query.device), - debug_mask, - ) - - -@register_op_impl(aten._efficient_attention_forward.default) -def meta__efficient_attention_forward(fake_mode, func, *args, **kwargs): - _, kwargs = normalize_function( - func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True - ) - - query = kwargs["query"] - key = kwargs["key"] - value = kwargs["value"] - cu_seqlens_q = kwargs["cu_seqlens_q"] - max_seqlen_q = kwargs["max_seqlen_q"] - max_seqlen_k = kwargs["max_seqlen_k"] - compute_log_sumexp = kwargs["compute_log_sumexp"] - # unused: bias, cu_seqlens_k, dropout_p, custom_mask_type, scale, seqlen_k - - def convert_tensor(t, device): - return FakeTensor(fake_mode, t, device) - - B = query.size(0) - M = query.size(1) - N = key.size(1) - num_heads = query.size(-2) - K = query.size(-1) - Kv = value.size(-1) - - res = convert_tensor( - torch.empty(B, M, num_heads, Kv, dtype=query.dtype, device="meta"), - query.device, - ) - - logsumexp_batch_dim = cu_seqlens_q.size(0) - 1 if (cu_seqlens_q is not None) else B - actual_max_seqlen_q = M - if cu_seqlens_q is not None: - assert max_seqlen_q is not None - actual_max_seqlen_q = max_seqlen_q - actual_max_seqlen_k = max_seqlen_k if max_seqlen_k is not None else N - logsumexp_dim = ( - math.ceil(actual_max_seqlen_q / 32) * 32 if compute_log_sumexp else 0 - ) - logsum_exp = convert_tensor( - torch.empty( - (logsumexp_batch_dim, num_heads, logsumexp_dim), - dtype=torch.float, - device="meta", - ), - query.device, - ) - - # See Note [Seed and Offset]: - seed = convert_tensor( - torch.empty((), dtype=torch.long, device="meta"), query.device - ) - offset = convert_tensor( - torch.empty((), dtype=torch.long, device="meta"), query.device - ) - - return res, logsum_exp, seed, offset, actual_max_seqlen_q, actual_max_seqlen_k - - @register_op_impl(torch.ops.aten._pack_padded_sequence.default) def _pack_padded_sequence(fake_mode, func, inputs, lengths, batch_first): if ( diff --git a/torch/_subclasses/fake_utils.py b/torch/_subclasses/fake_utils.py index d0a2c13b395df3..28fc7a4028917f 100644 --- a/torch/_subclasses/fake_utils.py +++ b/torch/_subclasses/fake_utils.py @@ -66,6 +66,12 @@ def is_sdpa_error(func, idx, e): and "Devices" in repr(e) ): return True + if ( + func is aten._scaled_dot_product_cudnn_attention.default + and idx in (6, 7) + and "Devices" in repr(e) + ): + return True return False diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py index d3c7bd33c9d711..3049cd8d8f5753 100644 --- a/torch/testing/_internal/common_methods_invocations.py +++ b/torch/testing/_internal/common_methods_invocations.py @@ -28,7 +28,7 @@ skipCPUIfNoMklSparse, toleranceOverride, tol) from torch.testing._internal.common_cuda import ( - PLATFORM_SUPPORTS_FLASH_ATTENTION, PLATFORM_SUPPORTS_FUSED_ATTENTION, PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, + PLATFORM_SUPPORTS_FLASH_ATTENTION, PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, SM53OrLater, SM80OrLater, SM90OrLater, with_tf32_off, TEST_CUDNN, _get_torch_cuda_version, _get_torch_rocm_version, ) @@ -16050,12 +16050,6 @@ def sample_inputs_alias_copy(op_info, device, dtype, requires_grad, **kwargs): # skip for sm < 80 DecorateInfo(unittest.skip("Skipped!"), 'TestSchemaCheckModeOpInfo', 'test_schema_correctness', device_type='cuda', dtypes=(torch.bfloat16,), active_if=not SM80OrLater), - DecorateInfo(unittest.skip("Skipped!"), 'TestMeta', 'test_meta_outplace', - device_type='cuda', dtypes=(torch.bfloat16,), active_if=not SM80OrLater), - DecorateInfo(unittest.skip("Skipped!"), 'TestMeta', 'test_dispatch_meta_outplace', - device_type='cuda', dtypes=(torch.bfloat16,), active_if=not SM80OrLater), - DecorateInfo(unittest.skip("Skipped!"), 'TestMeta', 'test_dispatch_symbolic_meta_outplace', - device_type='cuda', dtypes=(torch.bfloat16,), active_if=not SM80OrLater), # FIXME DecorateInfo(unittest.skip('test_cow_input does not work with efficient attention on ROCM'), 'TestCompositeCompliance', 'test_cow_input', @@ -16069,24 +16063,17 @@ def sample_inputs_alias_copy(op_info, device, dtype, requires_grad, **kwargs): 'TestFakeTensor', 'test_fake_crossref_backward_no_amp', device_type='cuda', dtypes=(torch.bfloat16, torch.float16, torch.float32), active_if=TEST_WITH_ROCM and PLATFORM_SUPPORTS_MEM_EFF_ATTENTION), - # registered in fake_impls.py instead of _meta_registrations.py, so meta kernels will fail. - # However, for implementations that fall back to the constituent ops, the meta kernels may not - # fail. Fused kernels will fail, whereas unfused kernels will not fail. - # All fused kernels support bf16 and fp16 - so if fused attention is supported, the test will fail. - # mem_eff_attention also supports fp32 - so if it is supported the test will fail. - DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_outplace", - dtypes=(torch.bfloat16, torch.float16), active_if=PLATFORM_SUPPORTS_FUSED_ATTENTION), - # TODO: float32 support in ROCM efficient attention - DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_outplace", - dtypes=(torch.float32,), active_if=PLATFORM_SUPPORTS_MEM_EFF_ATTENTION), - DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace", - dtypes=(torch.bfloat16, torch.float16), active_if=PLATFORM_SUPPORTS_FUSED_ATTENTION), - DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace", - dtypes=(torch.float32,), active_if=PLATFORM_SUPPORTS_MEM_EFF_ATTENTION), - DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace_all_strides", - dtypes=(torch.bfloat16, torch.float16,), active_if=PLATFORM_SUPPORTS_FUSED_ATTENTION), + # for element 1, was torch.Size([4, 4, 0]) but real shape was torch.Size([16, 3, 0]) + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_outplace", device_type="cuda", + dtypes=[torch.float16, torch.bfloat16, torch.float32], + active_if=TEST_WITH_ROCM and PLATFORM_SUPPORTS_FLASH_ATTENTION), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace", device_type="cuda", + dtypes=[torch.float16, torch.bfloat16, torch.float32], + active_if=TEST_WITH_ROCM and PLATFORM_SUPPORTS_FLASH_ATTENTION), + # for element 1, was torch.Size([4, 4, 11]) but real shape was torch.Size([16, 11]) DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace_all_strides", - dtypes=(torch.float32,), active_if=PLATFORM_SUPPORTS_MEM_EFF_ATTENTION),), + device_type="cuda", dtypes=[torch.float32], + active_if=TEST_WITH_ROCM and PLATFORM_SUPPORTS_FLASH_ATTENTION),), ), OpInfo( 'torch.ops.aten._flash_attention_forward', @@ -16102,14 +16089,14 @@ def sample_inputs_alias_copy(op_info, device, dtype, requires_grad, **kwargs): check_batched_forward_grad=False, decorators=[skipCUDAIf(not PLATFORM_SUPPORTS_FLASH_ATTENTION, "This platform doesn't support Flash Attention")], skips=( - # Device mismatch due to philox seed and offset - DecorateInfo(unittest.expectedFailure, 'TestFakeTensor', 'test_fake_autocast', device_type='cuda'), - DecorateInfo(unittest.expectedFailure, 'TestFakeTensor', 'test_fake', device_type='cuda'), - # meta implementation is in fake_impls.py instead of being a meta registration - DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_inplace"), - DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_outplace"), - DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_outplace"), - DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace"), + # for element 1, was torch.Size([4, 4, 11]) but real shape was torch.Size([16, 11]) + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_outplace", device_type="cuda", + dtypes=[torch.float16, torch.bfloat16], active_if=TEST_WITH_ROCM and PLATFORM_SUPPORTS_FLASH_ATTENTION), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace", device_type="cuda", + dtypes=[torch.float16, torch.bfloat16], active_if=TEST_WITH_ROCM and PLATFORM_SUPPORTS_FLASH_ATTENTION), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_outplace", device_type="cuda", + dtypes=[torch.float16, torch.bfloat16], active_if=TEST_WITH_ROCM and PLATFORM_SUPPORTS_FLASH_ATTENTION), + # Checking the scalar value of the philox seed and offset # Checking the scalar value of the philox seed and offset DecorateInfo(unittest.expectedFailure, 'TestCompositeCompliance', 'test_operator', device_type='cuda'), DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_noncontiguous_samples', device_type='cuda'), @@ -16137,14 +16124,14 @@ def sample_inputs_alias_copy(op_info, device, dtype, requires_grad, **kwargs): skipCUDAIf(not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, "This platform doesn't support efficient attention"), skipCUDAIf(TEST_WITH_ROCM, "Efficient attention on ROCM doesn't support custom_mask_type==2")], skips=( - # Device mismatch due to philox seed and offset - DecorateInfo(unittest.expectedFailure, 'TestFakeTensor', 'test_fake_autocast', device_type='cuda'), - DecorateInfo(unittest.expectedFailure, 'TestFakeTensor', 'test_fake', device_type='cuda'), - # meta implementation is in fake_impls.py instead of being a meta registration - DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_inplace"), - DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_outplace"), - DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_outplace"), - DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace"), + # for element 1, was torch.Size([4, 4, 11]) but real shape was torch.Size([16, 11]) + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_outplace", device_type="cuda", + dtypes=[torch.float16, torch.bfloat16, torch.float32], + active_if=TEST_WITH_ROCM and PLATFORM_SUPPORTS_FLASH_ATTENTION), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace", device_type="cuda", + dtypes=[torch.float16, torch.bfloat16], active_if=TEST_WITH_ROCM and PLATFORM_SUPPORTS_FLASH_ATTENTION), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_outplace", device_type="cuda", + dtypes=[torch.float16, torch.bfloat16], active_if=TEST_WITH_ROCM and PLATFORM_SUPPORTS_FLASH_ATTENTION), # Checking the scaler value of the philox seed and offset DecorateInfo(unittest.expectedFailure, 'TestCompositeCompliance', 'test_operator', device_type='cuda'), DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_noncontiguous_samples', device_type='cuda'), From f35f1dfccdeb7a8d62cf40c2ea4dea6b2f46936a Mon Sep 17 00:00:00 2001 From: PaliC Date: Tue, 27 Aug 2024 21:12:56 +0000 Subject: [PATCH 0111/1018] [split build] fix distributed problems (#134502) Should fix the issue where USE_C10D_NCCL was not getting propagated to libtorch_python.so Pull Request resolved: https://github.com/pytorch/pytorch/pull/134502 Approved by: https://github.com/yifuwang --- torch/CMakeLists.txt | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/torch/CMakeLists.txt b/torch/CMakeLists.txt index 9a91b26d54cfb4..b8dfb8b706ba1a 100644 --- a/torch/CMakeLists.txt +++ b/torch/CMakeLists.txt @@ -340,11 +340,17 @@ endif() # in case of the split build we need to add compile definitions if(BUILD_LIBTORCHLESS) + + if(USE_UCC) + target_link_libraries(torch_python PRIVATE __caffe2_ucc) + target_compile_definitions(torch_python PRIVATE USE_UCC) + endif() + if(USE_UCC AND USE_C10D_UCC) target_compile_definitions(torch_python PRIVATE USE_C10D_UCC) endif() - if(USE_UCC AND USE_C10D_NCCL) + if(USE_NCCL AND USE_C10D_NCCL) target_compile_definitions(torch_python PRIVATE USE_C10D_NCCL) endif() From 712acb8c24251facf48f6203151a9c706ca796f9 Mon Sep 17 00:00:00 2001 From: Bob Ren Date: Tue, 27 Aug 2024 09:50:31 -0700 Subject: [PATCH 0112/1018] Add test case test_arange_length_with_float32_dtype (#134415) Adding a test as a followup from https://github.com/pytorch/pytorch/pull/134296 Pull Request resolved: https://github.com/pytorch/pytorch/pull/134415 Approved by: https://github.com/ezyang --- test/dynamo/test_misc.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/test/dynamo/test_misc.py b/test/dynamo/test_misc.py index dc9e99f770111f..2bb11b7109799f 100644 --- a/test/dynamo/test_misc.py +++ b/test/dynamo/test_misc.py @@ -1616,6 +1616,22 @@ def _fn(a, b, func=func): expected_ops_dynamic=ifdynstaticdefault(1, 5), ) + @torch._dynamo.config.patch(capture_scalar_outputs=True) + def test_arange_length_with_float32_dtype(self): + @torch.compile(fullgraph=True) + def f(x): + y = x.item() + torch._check_is_size(y) + r = torch.arange(y, dtype=torch.float32) + + if r.size(0) == y: + return r + 1 + + return r + + x = torch.tensor([300]) + r = f(x) + @torch._dynamo.config.patch(capture_scalar_outputs=True) def test_torch_check(self): cnts = torch._dynamo.testing.CompileCounter() From 8c034c2af7b8f8039b91900de982d6f31a374392 Mon Sep 17 00:00:00 2001 From: Nikita Shulga <2453524+malfet@users.noreply.github.com> Date: Tue, 27 Aug 2024 21:51:03 +0000 Subject: [PATCH 0113/1018] [EZ] Restore `test_unicode_comments` (#134589) This reverts changes introduced by test_jit.py by https://github.com/pytorch/pytorch/commit/43737bd78a4c7a120d8175941d54e8348e706068 and adds lint suppression for this it As test name suggests it should have an unicode comment to make sure our parser can handle it Part of the fix for https://github.com/pytorch/pytorch/issues/134422 Pull Request resolved: https://github.com/pytorch/pytorch/pull/134589 Approved by: https://github.com/aorenste, https://github.com/Skylion007 --- test/test_jit.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/test_jit.py b/test/test_jit.py index 5b3a40973acac3..c3af8bc9f48fcf 100644 --- a/test/test_jit.py +++ b/test/test_jit.py @@ -15636,7 +15636,7 @@ def fn(string): def test_unicode_comments(self): @torch.jit.script def test(self, a): - # shrug + # 🤷🤷🤷🤷 return torch.nn.functional.relu(a) def test_get_set_state_with_tensors(self): From 93ea4eb18aa9921b34d875d00bf67b94bfca4737 Mon Sep 17 00:00:00 2001 From: drisspg Date: Tue, 27 Aug 2024 10:39:13 -0700 Subject: [PATCH 0114/1018] [FlexAttention] Create new variables for the subgraphs (#134507) Pull Request resolved: https://github.com/pytorch/pytorch/pull/134507 Approved by: https://github.com/yanboliang, https://github.com/BoyuanFeng --- test/inductor/test_flex_attention.py | 23 +++++++++++++++++++++++ torch/_inductor/kernel/flex_attention.py | 15 +++++++++------ 2 files changed, 32 insertions(+), 6 deletions(-) diff --git a/test/inductor/test_flex_attention.py b/test/inductor/test_flex_attention.py index ecc2f7d8e2978a..e0c7882a9a1ee1 100644 --- a/test/inductor/test_flex_attention.py +++ b/test/inductor/test_flex_attention.py @@ -1641,6 +1641,29 @@ def test_force_write_lse(self): torch.testing.assert_close(lse_eager, lse_compiled, atol=3e-3, rtol=0) + @supported_platform + def test_small_q_kv_len(self): + make_tensor = functools.partial( + torch.ones, + (1, 1, 1, 16), + device="cuda", + dtype=torch.float32, + requires_grad=False, + ) + query, key, value = make_tensor(), make_tensor(), make_tensor() + kernel_options = {"FORCE_USE_FLEX_ATTENTION": True} + out_eager, lse_eager = flex_attention( + query, key, value, return_lse=True, kernel_options=kernel_options + ) + + flex_compile = torch.compile(flex_attention, fullgraph=True) + out_compiled, lse_compiled = flex_compile( + query, key, value, return_lse=True, kernel_options=kernel_options + ) + + assert torch.equal(out_eager, out_compiled) + assert torch.equal(lse_eager, lse_compiled) + @supported_platform def test_causal_block_non_divisible_with_captured_buffer(self): Q_S = S - 3 diff --git a/torch/_inductor/kernel/flex_attention.py b/torch/_inductor/kernel/flex_attention.py index de7ff7ee913610..71b59058e16dfc 100644 --- a/torch/_inductor/kernel/flex_attention.py +++ b/torch/_inductor/kernel/flex_attention.py @@ -439,8 +439,11 @@ def forward_block_mn( # If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements, # which is larger than the actual number of elements. To avoid access memory out of bound, # we need to mask out the elements that are out of Q_LEN & KV_LEN. - offs_m = offs_m % Q_LEN - offs_n = offs_n % KV_LEN + m = offs_m % Q_LEN + n = offs_n % KV_LEN + else: + m = offs_m + n = offs_n {{ modification( subgraph_number=0, @@ -448,8 +451,8 @@ def forward_block_mn( score="qk", b="off_z", h="off_h", - m="offs_m", - n="offs_n", + m="m", + n="n", out="qk" ) | indent_except_first(1) }} @@ -464,8 +467,8 @@ def forward_block_mn( score="qk", b="off_z", h="off_h", - m="offs_m", - n="offs_n", + m="m", + n="n", ) | indent_except_first(2) }} if CHECK_BLOCK_BOUNDARY: From 4c1449dea65debdd81c2e1ab982a4d25bcc21a56 Mon Sep 17 00:00:00 2001 From: drisspg Date: Tue, 27 Aug 2024 10:39:13 -0700 Subject: [PATCH 0115/1018] [FlexAttention] Remove unused code (#134511) Pull Request resolved: https://github.com/pytorch/pytorch/pull/134511 Approved by: https://github.com/yanboliang ghstack dependencies: #134507 --- test/inductor/test_flex_attention.py | 7 ++++++- torch/_inductor/kernel/flex_attention.py | 3 --- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/test/inductor/test_flex_attention.py b/test/inductor/test_flex_attention.py index e0c7882a9a1ee1..a5ee0525a9818a 100644 --- a/test/inductor/test_flex_attention.py +++ b/test/inductor/test_flex_attention.py @@ -1648,7 +1648,7 @@ def test_small_q_kv_len(self): (1, 1, 1, 16), device="cuda", dtype=torch.float32, - requires_grad=False, + requires_grad=True, ) query, key, value = make_tensor(), make_tensor(), make_tensor() kernel_options = {"FORCE_USE_FLEX_ATTENTION": True} @@ -1664,6 +1664,11 @@ def test_small_q_kv_len(self): assert torch.equal(out_eager, out_compiled) assert torch.equal(lse_eager, lse_compiled) + grads_eager = torch.autograd.grad(out_eager.sum(), (query, key, value)) + grads_compile = torch.autograd.grad(out_compiled.sum(), (query, key, value)) + + torch.testing.assert_close(grads_eager, grads_compile) + @supported_platform def test_causal_block_non_divisible_with_captured_buffer(self): Q_S = S - 3 diff --git a/torch/_inductor/kernel/flex_attention.py b/torch/_inductor/kernel/flex_attention.py index 71b59058e16dfc..8e90eacaab72b5 100644 --- a/torch/_inductor/kernel/flex_attention.py +++ b/torch/_inductor/kernel/flex_attention.py @@ -1606,9 +1606,6 @@ def flex_attention_backward(*args, **kwargs): ] ) - if _use_flex_decoding(query, kernel_options): - raise NotImplementedError("Flex decoding backward pass is not implemented. ") - device = query.get_device() dtype = query.get_dtype() Bq, Hq, seq_len_q, qk_head_dim = query.get_size() From e89544bfe3c932bf1888fc9c7bcf83420af1e446 Mon Sep 17 00:00:00 2001 From: drisspg Date: Tue, 27 Aug 2024 10:39:14 -0700 Subject: [PATCH 0116/1018] [FlexAttention] Fix Sparse block multiple to ceildiv instead for floor div (#134538) Pull Request resolved: https://github.com/pytorch/pytorch/pull/134538 Approved by: https://github.com/yanboliang ghstack dependencies: #134507, #134511 --- test/inductor/test_flex_decoding.py | 38 ++++++++++++++++++++++-- torch/_inductor/kernel/flex_attention.py | 8 ++--- torch/_inductor/kernel/flex_decoding.py | 4 +-- 3 files changed, 42 insertions(+), 8 deletions(-) diff --git a/test/inductor/test_flex_decoding.py b/test/inductor/test_flex_decoding.py index 3cf3981b0f1782..bc41ff2fe8b16a 100644 --- a/test/inductor/test_flex_decoding.py +++ b/test/inductor/test_flex_decoding.py @@ -14,6 +14,7 @@ from torch.nn.attention.flex_attention import ( _create_empty_block_mask, _identity, + BlockMask, create_block_mask, flex_attention, ) @@ -253,7 +254,7 @@ def _check_out( def run_test( self, - score_mod: Callable, + score_mod: Optional[Callable], dtype: torch.dtype = torch.float16, Q_B: int = B, Q_H: int = Hq, @@ -263,7 +264,11 @@ def run_test( KV_H: int = Hkv, KV_S: int = S, V_D: int = D, + block_mask: Optional[BlockMask] = None, ): + assert ( + score_mod is not None or block_mask is not None + ), "Must provide score_mod or block_mask" assert Q_H % KV_H == 0 q = torch.randn( (Q_B, Q_H, Q_S, Q_D), @@ -280,7 +285,6 @@ def run_test( q_ref, k_ref, v_ref = query_key_value_clones(q, k, v) q_gold, k_gold, v_gold = query_key_value_clones(q, k, v, torch.float64) - block_mask = None sdpa_partial = create_attention( score_mod, block_mask, enable_gqa=(not Q_H == KV_H) ) @@ -943,6 +947,36 @@ def func(q, k, v, score_mod): code[0] ) + @supported_platform + def test_non_sparse_mulitple_block_size(self): + def generate_causal_offset(offset: torch.Tensor): + def causal_offset_mask(b, h, q_idx, kv_idx): + return (offset + q_idx) >= kv_idx + + return causal_offset_mask + + def noop(score, b, h, q_idx, kv_idx): + return score + + mod = generate_causal_offset( + torch.tensor(192, device="cuda", dtype=torch.int32) + ) + block_mask = create_block_mask(mod, 1, 1, 1, 65) + + self.run_test( + score_mod=None, + dtype=torch.float32, + block_mask=block_mask, + Q_B=1, + Q_H=1, + Q_S=1, + Q_D=16, + KV_B=1, + KV_H=1, + KV_S=65, + V_D=16, + ) + @supported_platform def test_do_not_trigger_dynamic_shapes_on_empty_block_mask(self): torch._dynamo.reset() diff --git a/torch/_inductor/kernel/flex_attention.py b/torch/_inductor/kernel/flex_attention.py index 8e90eacaab72b5..24ba8ba1ec1191 100644 --- a/torch/_inductor/kernel/flex_attention.py +++ b/torch/_inductor/kernel/flex_attention.py @@ -239,7 +239,7 @@ def get_offset_for_next_block(loop_iter, col_indices, total_blocks, SPARSE_BLOCK kv_indices = KV_IDX + sparse_kv_idx_offset kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) - block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(KV_LEN // BLOCK_N, 1)) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) K_block_ptr = tl.make_block_ptr( base=K, @@ -278,7 +278,7 @@ def get_offset_for_next_block(loop_iter, col_indices, total_blocks, SPARSE_BLOCK kv_indices = FULL_KV_IDX + sparse_kv_idx_offset kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) - block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(KV_LEN // BLOCK_N, 1)) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) K_block_ptr = tl.make_block_ptr( base=K, @@ -1195,7 +1195,7 @@ def bwd_dq_inner( # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work. tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0) - hi = tl.minimum(sparse_kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(KV_LEN // BLOCK_N2, 1)) + hi = tl.minimum(sparse_kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N2), 1)) if not IS_DIVISIBLE: if hi >= 1: for start_n in range(0, hi - 1): @@ -1376,7 +1376,7 @@ def bwd_dkdv_inner( do_ptrs = DO + offs_m1[:, None] * stride_dom + offs_v[None, :] * stride_dod # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work. tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0) - hi = tl.minimum(sparse_q_num_blocks * SPARSE_Q_MULTIPLE, tl.maximum(Q_LEN // BLOCK_M1, 1)) + hi = tl.minimum(sparse_q_num_blocks * SPARSE_Q_MULTIPLE, tl.maximum(tl.cdiv(Q_LEN, BLOCK_M1), 1)) if not IS_DIVISIBLE: if hi >= 1: diff --git a/torch/_inductor/kernel/flex_decoding.py b/torch/_inductor/kernel/flex_decoding.py index 3520b618ef0a67..870e6194aec1e1 100644 --- a/torch/_inductor/kernel/flex_decoding.py +++ b/torch/_inductor/kernel/flex_decoding.py @@ -161,7 +161,7 @@ def flex_decoding_grid(batch_size, kv_heads, gqa_group_size, n_keys, d_model, me # first kv block we're loading # last valid block according to sparse mask - block_n_last_valid = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(KV_LEN // BLOCK_N, 1)) + block_n_last_valid = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) K_block_ptr = tl.make_block_ptr( base=K + k_offset, @@ -207,7 +207,7 @@ def flex_decoding_grid(batch_size, kv_heads, gqa_group_size, n_keys, d_model, me off_n = tl.load(kv_indices + indices_idx) * SPARSE_KV_BLOCK_SIZE + off_n_block_in_sparse * BLOCK_N # last valid block according to sparse mask - block_n_last_valid = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(KV_LEN // BLOCK_N, 1)) + block_n_last_valid = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) K_block_ptr = tl.make_block_ptr( base=K + k_offset, From 5c5705daa3595cf1a6f807824375cd81dfc5f619 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Tue, 27 Aug 2024 22:52:52 +0000 Subject: [PATCH 0117/1018] Revert "[dynamo][dicts] Support hasattr on dicts (#134590)" This reverts commit d23c0150f3ba5fd1162358e9e7b0e72e7308c87e. Reverted https://github.com/pytorch/pytorch/pull/134590 on behalf of https://github.com/anijain2305 due to causing trunk CI failures ([comment](https://github.com/pytorch/pytorch/pull/134590#issuecomment-2313705582)) --- test/dynamo/test_functions.py | 16 ---------------- torch/_dynamo/variables/dicts.py | 7 ------- 2 files changed, 23 deletions(-) diff --git a/test/dynamo/test_functions.py b/test/dynamo/test_functions.py index f5cbfb8760c307..67ab47dfac9093 100644 --- a/test/dynamo/test_functions.py +++ b/test/dynamo/test_functions.py @@ -1683,22 +1683,6 @@ def test_dict_sorted(x): tmp = {1: "D", 10: "B", 3: "E", 0: "F"} return x + 1, sorted(tmp), sorted(tmp, reverse=True) - def test_dict_hasattr(self): - def fn(x): - if hasattr(x, "to"): - return x.to("cpu") - if hasattr(x, "items"): - return torch.cos(x["a"]) - return x - - opt_fn = torch.compile(fn, backend="eager", fullgraph=True) - - x = dict(a=torch.randn(3)) - self.assertEqual(fn(x), opt_fn(x)) - - x = torch.randn(4) - self.assertEqual(fn(x), opt_fn(x)) - @make_test def test_list_clear(a, b): tmp = [a + 1, a + 2] diff --git a/torch/_dynamo/variables/dicts.py b/torch/_dynamo/variables/dicts.py index 7bcae6887cf6cc..2d4c015bf3ac1b 100644 --- a/torch/_dynamo/variables/dicts.py +++ b/torch/_dynamo/variables/dicts.py @@ -355,13 +355,6 @@ def call_method( def unpack_var_sequence(self, tx): return [x.vt for x in self.items.keys()] - def call_hasattr(self, tx, name): - # dict/OrderedDict do not allow setting arbitrary attributes. To check for hasattr, we can just check the - # __dict__ of the cls. - if name in self.user_cls.__dict__: - return ConstantVariable.create(True) - return ConstantVariable.create(False) - class DefaultDictVariable(ConstDictVariable): def __init__(self, items, user_cls, default_factory=None, **kwargs) -> None: From cfd629c9811a3e91ecdc16ded9fb097b707a1ead Mon Sep 17 00:00:00 2001 From: Pian Pawakapan Date: Tue, 27 Aug 2024 23:00:26 +0000 Subject: [PATCH 0118/1018] [export] don't duck size for DIM.AUTO (#134486) Summary: apparently DIM.AUTO leads to duck sizing, I didn't catch this. Doing the least intrusive fix possible by using `torch._dynamo.maybe_mark_dynamic()` under the hood. Test Plan: added test Differential Revision: D61809344 Pull Request resolved: https://github.com/pytorch/pytorch/pull/134486 Approved by: https://github.com/avikchaudhuri --- test/export/test_export.py | 23 +++++++++++++++++++++++ torch/export/dynamic_shapes.py | 4 ++++ 2 files changed, 27 insertions(+) diff --git a/test/export/test_export.py b/test/export/test_export.py index 1a491cb8a3c37a..2393443b5f0b0f 100644 --- a/test/export/test_export.py +++ b/test/export/test_export.py @@ -2492,6 +2492,29 @@ def forward(self, x, y): gm(torch.randn(33, 4), torch.randn(32, 4)) gm(torch.randn(128, 4), torch.randn(128, 4)) + def test_dont_duck_size_for_auto_dynamic(self): + # for this use case, mark_dynamic() and AUTO should have same effect. + # check that same symbol gets allocated to both dims without raising constraint violation. + from torch.export.dynamic_shapes import DIM + + AUTO, STATIC = DIM.AUTO, DIM.STATIC + + class Foo(torch.nn.Module): + def forward(self, x, y): + # x: [s0, s1], y: [s0 + 1, 4] + assert y.shape[1] == 4 + assert x.shape[0] == y.shape[0] - 1 + return x * 2, y * 2 + + # duck sizing would make all static based on these sample inputs + inputs = (torch.randn(4, 4), torch.randn(5, 4)) + shapes = { + "x": (AUTO, AUTO), + "y": (AUTO, AUTO), + } + ep = export(Foo(), inputs, dynamic_shapes=shapes) + ep.module()(torch.randn(6, 3), torch.randn(7, 4)) + @testing.expectedFailureRetraceability # T183144629 def test_map(self): class Module(torch.nn.Module): diff --git a/torch/export/dynamic_shapes.py b/torch/export/dynamic_shapes.py index 77b2c1e34a9772..48438ee8cb2a55 100644 --- a/torch/export/dynamic_shapes.py +++ b/torch/export/dynamic_shapes.py @@ -835,6 +835,8 @@ def _marked_dynamic(tensor, i): if _marked_dynamic(tensor, i) or dim == DIM.AUTO: # don't have to specify anything if dynamic # None also works, since assume_static_by_default=False + if dim == DIM.AUTO: + torch._dynamo.maybe_mark_dynamic(tensor, i) # avoid duck sizing continue elif isinstance(dim, _Dim): out[i] = dim @@ -851,6 +853,8 @@ def _marked_dynamic(tensor, i): for i, val in enumerate(tensor.shape): dim = shape[i] if _marked_dynamic(tensor, i) or dim == DIM.AUTO: + if dim == DIM.AUTO: + torch._dynamo.maybe_mark_dynamic(tensor, i) # avoid duck sizing out.append(None) elif isinstance(dim, _Dim): out.append(dim) From 16c2413e9fbc0ee5eb03f4aa0d8a27755fa87436 Mon Sep 17 00:00:00 2001 From: Max Podkorytov <4273004+tenpercent@users.noreply.github.com> Date: Tue, 27 Aug 2024 23:25:34 +0000 Subject: [PATCH 0119/1018] [ROCm][Inductor][CK] Fix codegen after ck signature change (#134483) MakeArgument signature was changed in https://github.com/ROCm/composable_kernel/pull/1453 adding splitK argument to universal gemm templates which are used to codegen addmm and matmul (part of the series started at #125453 ) # Testing `pytest test/inductor/test_ck_backend.py` Pull Request resolved: https://github.com/pytorch/pytorch/pull/134483 Approved by: https://github.com/ColinPeppler --- torch/_inductor/codegen/rocm/ck_universal_gemm_template.py | 1 + 1 file changed, 1 insertion(+) diff --git a/torch/_inductor/codegen/rocm/ck_universal_gemm_template.py b/torch/_inductor/codegen/rocm/ck_universal_gemm_template.py index fff132a9ab7c71..d247103a9d4019 100644 --- a/torch/_inductor/codegen/rocm/ck_universal_gemm_template.py +++ b/torch/_inductor/codegen/rocm/ck_universal_gemm_template.py @@ -57,6 +57,7 @@ class CKGemmTemplate(CKTemplate): LDB, std::array{ {{'LDD' if has_bias else ''}} }, LDC, + 1, // kBatch PassThrough {}, // a_elementwise_op PassThrough {}, // b_elementwise_op {{epilogue}} // c_elementwise_op From c059e153695e4f28c12ef7a21ff9c7708b27cc16 Mon Sep 17 00:00:00 2001 From: Laith Sakka Date: Tue, 27 Aug 2024 11:17:01 -0700 Subject: [PATCH 0120/1018] Add compile time instruction count metric (#133834) PYTHONPATH=$(pwd) python benchmarks/update_hint_benchmark.py out as of this diff, compile_time_instruction_count counts the number of instruction from within convert_frame.compile_inner ``` update_hint_regression,compile_time_instruction_count,10522459165 ``` will add result from CI once populated. Pull Request resolved: https://github.com/pytorch/pytorch/pull/133834 Approved by: https://github.com/aorenste --- .../pr_time_benchmarks/benchmark_base.py | 78 +++++++++++++++---- .../benchmarks/update_hint_benchmark.py | 13 ++-- torch/_C/_instruction_counter.pyi | 4 + torch/_dynamo/config.py | 4 + torch/_dynamo/convert_frame.py | 4 +- torch/_dynamo/utils.py | 39 ++++++++++ 6 files changed, 122 insertions(+), 20 deletions(-) create mode 100644 torch/_C/_instruction_counter.pyi diff --git a/benchmarks/dynamo/pr_time_benchmarks/benchmark_base.py b/benchmarks/dynamo/pr_time_benchmarks/benchmark_base.py index e9da5b3a45c519..225f51c5b314d7 100644 --- a/benchmarks/dynamo/pr_time_benchmarks/benchmark_base.py +++ b/benchmarks/dynamo/pr_time_benchmarks/benchmark_base.py @@ -4,6 +4,8 @@ from fbscribelogger import make_scribe_logger import torch._C._instruction_counter as i_counter +import torch._dynamo.config as config +from torch._dynamo.utils import CompileTimeInstructionCounter scribe_log_torch_benchmark_compile_time = make_scribe_logger( @@ -51,10 +53,19 @@ class BenchmarkBase(ABC): - _instruction_count = False + # measure total number of instruction spent in _work. + _enable_instruction_count = False + + # measure total number of instruction spent in convert_frame.compile_inner + # TODO is there other parts we need to add ? + _enable_compile_time_instruction_count = False def enable_instruction_count(self): - self._instruction_count = True + self._enable_instruction_count = True + return self + + def enable_compile_time_instruction_count(self): + self._enable_compile_time_instruction_count = True return self def name(self): @@ -64,29 +75,44 @@ def description(self): return "" @abstractmethod - def prepare(self): + def _prepare(self): pass @abstractmethod - def work(self): + def _work(self): pass - def prepare_once(self): # noqa: B027 + def _prepare_once(self): # noqa: B027 pass - def count_instructions(self): + def _count_instructions(self): print(f"collecting instruction count for {self.name()}") - self.prepare_once() - results = [] for i in range(10): - self.prepare() + self._prepare() id = i_counter.start() - self.work() + self._work() count = i_counter.end(id) print(f"instruction count for iteration {i} is {count}") - if i != 0: - results.append(count) + results.append(count) + return min(results) + + def _count_compile_time_instructions(self): + print(f"collecting compile time instruction count for {self.name()}") + config.record_compile_time_instruction_count = True + + results = [] + for i in range(10): + self._prepare() + # CompileTimeInstructionCounter.record is only called on convert_frame._compile_inner + # hence this will only count instruction count spent in compile_inner. + CompileTimeInstructionCounter.clear() + self._work() + count = CompileTimeInstructionCounter.value() + print(f"compile time instruction count for iteration {i} is {count}") + results.append(count) + + config.record_compile_time_instruction_count = False return min(results) def append_results(self, path): @@ -102,12 +128,36 @@ def print(self): print(f"{entry[0]},{entry[1]},{entry[2]}") def collect_all(self): + self._prepare_once() self.results = [] - if self._instruction_count: - r = self.count_instructions() + if ( + self._enable_instruction_count + and self._enable_compile_time_instruction_count + ): + raise RuntimeError( + "not supported until we update the logger, both logs to the same field now" + ) + + if self._enable_instruction_count: + r = self._count_instructions() self.results.append((self.name(), "instruction_count", r)) scribe_log_torch_benchmark_compile_time( name=self.name(), instruction_count=r, ) + if self._enable_compile_time_instruction_count: + r = self._count_compile_time_instructions() + + self.results.append( + ( + self.name(), + "compile_time_instruction_count", + r, + ) + ) + # TODO add a new field compile_time_instruction_count to the logger. + scribe_log_torch_benchmark_compile_time( + name=self.name(), + instruction_count=r, + ) return self diff --git a/benchmarks/dynamo/pr_time_benchmarks/benchmarks/update_hint_benchmark.py b/benchmarks/dynamo/pr_time_benchmarks/benchmarks/update_hint_benchmark.py index 92a83c609a1f2d..694bcea5985771 100644 --- a/benchmarks/dynamo/pr_time_benchmarks/benchmarks/update_hint_benchmark.py +++ b/benchmarks/dynamo/pr_time_benchmarks/benchmarks/update_hint_benchmark.py @@ -15,17 +15,17 @@ def name(self): def description(self): return "information at https://github.com/pytorch/pytorch/pull/129893" - def prepare_once(self): + def _prepare_once(self): torch._dynamo.config.capture_scalar_outputs = True random.seed(42) self.splits = torch.randint(10, (self.N,)) sz = self.splits.sum().item() self.input = torch.randn(sz) - def prepare(self): + def _prepare(self): torch._dynamo.reset() - def work(self): + def _work(self): @torch.compile(fullgraph=True) def f(a, b): xs = b.tolist() @@ -34,12 +34,15 @@ def f(a, b): torch._check(x <= self.N) return a.split(xs) - f(self.input, self.splits) + for i in range(1000): + f(self.input, self.splits) def main(): result_path = sys.argv[1] - Benchmark().enable_instruction_count().collect_all().append_results(result_path) + Benchmark().enable_compile_time_instruction_count().collect_all().append_results( + result_path + ) if __name__ == "__main__": diff --git a/torch/_C/_instruction_counter.pyi b/torch/_C/_instruction_counter.pyi new file mode 100644 index 00000000000000..4e3c27567eb228 --- /dev/null +++ b/torch/_C/_instruction_counter.pyi @@ -0,0 +1,4 @@ +# Defined in torch/csrc/instruction_counter/Module.cpp + +def start() -> int: ... +def end(id: int) -> int: ... diff --git a/torch/_dynamo/config.py b/torch/_dynamo/config.py index 88081a87290070..8885ca33cc0765 100644 --- a/torch/_dynamo/config.py +++ b/torch/_dynamo/config.py @@ -374,6 +374,10 @@ def _get_optimize_ddp_mode(): # Inline inbuilt nn modules inline_inbuilt_nn_modules = not is_fbcode() +# When set, total compile time instruction count is recorded using +# torch._dynamo.utilsCompileTimeInstructionCounter. +record_compile_time_instruction_count = False + def default_debug_dir_root(): # [@compile_ignored: debug] diff --git a/torch/_dynamo/convert_frame.py b/torch/_dynamo/convert_frame.py index 2f0da2be866db1..c1fadedd614405 100644 --- a/torch/_dynamo/convert_frame.py +++ b/torch/_dynamo/convert_frame.py @@ -27,6 +27,7 @@ import torch._logging from torch._C._dynamo.guards import GlobalStateGuard from torch._dynamo.distributed import get_compile_pg +from torch._dynamo.utils import CompileTimeInstructionCounter from torch._guards import compile_context, CompileContext, CompileId, tracing from torch._logging import structured from torch._utils_internal import ( @@ -652,7 +653,8 @@ def compile_inner( transform: Callable[[List[Instruction], Dict[str, Any]], Any], ) -> Optional[GuardedCode]: with dynamo_timed("_compile.compile_inner", phase_name="entire_frame_compile"): - return _compile_inner(code, one_graph, hooks, transform) + with CompileTimeInstructionCounter.record(): + return _compile_inner(code, one_graph, hooks, transform) @compile_time_strobelight_meta(phase_name="compile_inner") @maybe_cprofile diff --git a/torch/_dynamo/utils.py b/torch/_dynamo/utils.py index c8b2ac7a953af5..ae78f22374665a 100644 --- a/torch/_dynamo/utils.py +++ b/torch/_dynamo/utils.py @@ -64,6 +64,7 @@ from torch import fx from torch._C import ( _get_function_stack_at, + _instruction_counter, _len_torch_function_stack, _pop_torch_function_stack, _push_on_torch_function_stack, @@ -3203,3 +3204,41 @@ def get_user_object_from_id(obj_id): def store_user_object_weakref(obj): obj_id = id(obj) user_obj_id_to_weakref[obj_id] = weakref.ref(obj) + + +class CompileTimeInstructionCounter: + _counter: int = 0 + _id: int = -1 + _depth = 0 + + @classmethod + def start(cls) -> None: + cls._depth = cls._depth + 1 + if cls._depth == 1: + cls._id = _instruction_counter.start() + + @classmethod + def end(cls) -> None: + cls._depth = cls._depth - 1 + if cls._depth == 0: + cls._counter += _instruction_counter.end(cls._id) + cls._id = -1 + + @classmethod + def clear(cls) -> None: + cls._counter = 0 + + @classmethod + def value(cls) -> int: + return cls._counter + + @classmethod + @contextmanager + def record(cls): + try: + if config.record_compile_time_instruction_count: + cls.start() + yield + finally: + if config.record_compile_time_instruction_count: + cls.end() From 2ff76a61ed36c3652c48d3e4c6468ab0df77b2a4 Mon Sep 17 00:00:00 2001 From: albanD Date: Tue, 27 Aug 2024 16:12:01 -0400 Subject: [PATCH 0121/1018] Add device daemon (#131814) Base implementation aiming towards https://github.com/pytorch/rfcs/pull/64 Details of the implementation and next steps in https://github.com/pytorch/pytorch/blob/gh/albanD/3/head/test/cpp_extensions/open_registration_extension/README.md Pull Request resolved: https://github.com/pytorch/pytorch/pull/131814 Approved by: https://github.com/ezyang --- .flake8 | 2 +- .../open_registration_extension/README.md | 23 ++- .../pytorch_openreg/__init__.py | 33 +--- .../pytorch_openreg/_aten_impl.py | 153 ++++++++++++++++ .../pytorch_openreg/_device_daemon.py | 168 ++++++++++++++++++ .../pytorch_openreg/_meta_parser.py | 104 +++++++++++ .../pytorch_openreg/csrc/OpenReg.h | 1 + .../pytorch_openreg/csrc/OpenRegHooks.cpp | 12 +- .../pytorch_openreg/csrc/OpenRegMem.cpp | 125 +++++++++++++ .../test/test_openreg.py | 34 +++- 10 files changed, 614 insertions(+), 41 deletions(-) create mode 100644 test/cpp_extensions/open_registration_extension/pytorch_openreg/_aten_impl.py create mode 100644 test/cpp_extensions/open_registration_extension/pytorch_openreg/_device_daemon.py create mode 100644 test/cpp_extensions/open_registration_extension/pytorch_openreg/_meta_parser.py create mode 100644 test/cpp_extensions/open_registration_extension/pytorch_openreg/csrc/OpenRegMem.cpp diff --git a/.flake8 b/.flake8 index c789819f477bae..4e1cb4642d418f 100644 --- a/.flake8 +++ b/.flake8 @@ -57,7 +57,7 @@ per-file-ignores = torch/distributed/_tensor/_collective_utils.py: TOR901 # This is a full package that happen to live within the test # folder, so ok to skip - test/cpp_extensions/open_registration_extension/pytorch_openreg/__init__.py: TOR901 + test/cpp_extensions/open_registration_extension/pytorch_openreg/_aten_impl.py: TOR901 optional-ascii-coding = True exclude = ./.git, diff --git a/test/cpp_extensions/open_registration_extension/README.md b/test/cpp_extensions/open_registration_extension/README.md index f304a96ecb1a87..07f1f98d915a76 100644 --- a/test/cpp_extensions/open_registration_extension/README.md +++ b/test/cpp_extensions/open_registration_extension/README.md @@ -1,4 +1,4 @@ -This folder contains a self-contained example of a PyTorch out-of-tree backend leveraging the "PrivateUse1" backend in core. +This folder contains a self-contained example of a PyTorch out-of-tree backend leveraging the "PrivateUse1" backend from core. ## How to use Install as standalone with `python setup.py develop` (or install) from this folder. @@ -8,6 +8,23 @@ You can run test via `python test/test_openreg.py`. For simplicity anything that can be implemented from python is done so. A real implementation will most likely want to call these different APIs from c++ directly. -The current version send everything back to python and is missing most implementations in python. The only one available is the one used by the autograd engine to check how many workers to spawn. +The current version sends everything back to python and contains enough implementation to run basic model, transfer host/device and printing. -Next step is to create the device daemon so we can actually provide and allocator and create memory, then start using features and re-route all missing methods to daemon as appropriate. +The codebase is split as follows: +- `pytorch_openreg/__init__.py` imports torch to get core state initialized, imports `._aten_impl` to register our aten op implementations to torch, imports `.C` to load our c++ extension that registers more ops, allocator and hooks and finally renames the PrivateUse1 backend and register our python-side module. +- `pytorch_openreg/_aten_impl.py` does two main things. Use the `_register_same_name()` function to register hooks from c++ (like getDevice, getStream, etc) and send them to our device daemon. Define a new `torch.Library` that registers a fallback that will be called whenever a backend kernel for PrivateUse1 is called. It contains the logic to handle all kind of native functions, computing the output metadata, allocating it and only calling into the device daemon to perform computation +- `pytorch_openreg/_device_daemon.py` contains the Allocator (responsible for allocating memory on the device side, as int8 buffers, and recreating nice looking Tensors on the device side to be able to use aten ops to run code there), `run_op` that is the logic running on the device side to perform compute (for simplicity of coverage, we are re-building full blown Tensors here and calling aten ops on them). It also contains the Daemon responsible for the device worker process and sending data back and forth. +- `pytorch_openreg/_meta_parser.py` mainly contain utilities to send objects over the wire from the user process to the device process. The main class there is `OpenRegTensorMeta` that contains all the metadata sent to the device which should be enough for it to populate the output Tensor. + +## Next steps + +Currently, the autograd test is disabled because it's missing the getStream implementation. +The main next step would be to: +- Split the daemon into a proper user-process driver vs device-process executor. The main goal would be to better mimick which information is held on the user-process side and when we're actually communicating with the device. In particular current device or stream should be user-process informations. +- Add Stream/Event system. Most likely by having multiple requests queue that go to the device from the driver. +- Add RNG Generator. +- Add Pinned memory and HostAllocator. + +Longer term: +- Replace the current `open_registration_extension.cpp` test in PyTorch CI with this. +- Build this module in the CI environment and enable Device-generic tests on this device. diff --git a/test/cpp_extensions/open_registration_extension/pytorch_openreg/__init__.py b/test/cpp_extensions/open_registration_extension/pytorch_openreg/__init__.py index fa2a286f22c91b..fa231cff5b9d36 100644 --- a/test/cpp_extensions/open_registration_extension/pytorch_openreg/__init__.py +++ b/test/cpp_extensions/open_registration_extension/pytorch_openreg/__init__.py @@ -1,26 +1,13 @@ import torch - -# Global properties of our device -NUM_DEVICES = 7 - # Create our python implementation dict so that the C++ module # can access it during its initialization -_IMPL_REGISTRY = {} - -# Load the C++ Module -import pytorch_openreg._C # noqa: F401 - - -# Define all the implementations in the registry -def register(fn): - _IMPL_REGISTRY[fn.__name__[1:]] = fn - return fn +# Also register aten impls +from ._aten_impl import _IMPL_REGISTRY as _IMPL_REGISTRY # noqa: F401 -@register -def _deviceCount(): - return NUM_DEVICES +# Load the C++ Module +import pytorch_openreg._C # noqa: F401 # usort: skip # Module used for our backend @@ -31,15 +18,3 @@ class _OpenRegMod: # Set all the appropriate state on PyTorch torch.utils.rename_privateuse1_backend("openreg") torch._register_device_module("openreg", _OpenRegMod()) - -_openreg_lib = torch.library.Library("_", "IMPL") # ignore TOR901 - - -def _openreg_kernel_fallback(op, *args, **kwargs): - print("Calling ", op) - assert op is torch.ops.aten.empty.memory_format - # FIXME: this returns a cpu Tensor which is NOT ok. - return torch.empty(args[0]) - - -_openreg_lib.fallback(_openreg_kernel_fallback, dispatch_key="PrivateUse1") diff --git a/test/cpp_extensions/open_registration_extension/pytorch_openreg/_aten_impl.py b/test/cpp_extensions/open_registration_extension/pytorch_openreg/_aten_impl.py new file mode 100644 index 00000000000000..534b810f9999c5 --- /dev/null +++ b/test/cpp_extensions/open_registration_extension/pytorch_openreg/_aten_impl.py @@ -0,0 +1,153 @@ +import logging + +import torch +from torch.utils._pytree import tree_any + + +log = logging.getLogger(__name__) + +from ._device_daemon import daemon +from ._meta_parser import prepare_for_sending, to_device_no_copy + + +_IMPL_REGISTRY = {} + + +# Define all the implementations in the registry +def _register_same_name(name, with_log=False): + def _(*args, **kwargs): + if with_log: + log.info("Calling hook %s", name) + return daemon.exec(name, *args, **kwargs) + + _IMPL_REGISTRY[name] = _ + + +_register_same_name("deviceCount") +_register_same_name("getDevice") +_register_same_name("uncheckedSetDevice") +_register_same_name("exchangeDevice") +_register_same_name("malloc", True) +_register_same_name("free", True) + +_openreg_lib = torch.library.Library("_", "IMPL") + + +def _openreg_kernel_fallback(op, *args, **kwargs): + log.info("Calling kernel %s", op) + + # Special ops needed to avoid infinite recursion + if op is torch.ops.aten._copy_from.default: + from_, to_ = args + if from_.device.type == to_.device.type: + assert from_.device.type == "openreg" + op = torch.ops.aten.copy_.default + # handled below as a regular copy + elif from_.device.type == "openreg": + args, _ = prepare_for_sending((from_,), {}) + host_mem = daemon.exec("send_data", *args) + return to_.copy_(host_mem) + elif to_.device.type == "openreg": + args, _ = prepare_for_sending((to_,), {}) + daemon.exec("recv_data", from_, *args) + return to_ + else: + raise RuntimeError("Should not happen") + elif op is torch.ops.aten.set_.source_Tensor: + return torch.ops.aten.set_.source_Storage_storage_offset( + args[0], + args[1].untyped_storage(), + args[1].storage_offset(), + args[1].size(), + args[1].stride(), + ) + elif op is torch.ops.aten._local_scalar_dense.default: + args, _ = prepare_for_sending(args, {}) + host_mem = daemon.exec("send_data", *args) + return host_mem.item() + + op_name = None + post_process = None + if "out" in op._overloadname: + # Note that all structured native op will call here + if isinstance(kwargs["out"], tuple): + raise RuntimeError(f"out= variant {op} with tuple out= not supported") + if kwargs["out"].nelement() == 0: + # Out variant that needs a resize, convert to an out of place + # and handle generically below + orig_out = kwargs["out"] + del kwargs["out"] + if op._overloadname != "out": + raise RuntimeError( + "Cannot retranslate non-default out= variant form 0 size" + ) + op = op.overloadpacket.default + + def _post_process(): + nonlocal real_res + orig_out.set_(real_res) + real_res = orig_out + + post_process = _post_process + + else: + # No metadata update to do, just run the op on the device + op_name = op.overloadpacket._qualified_op_name + real_res = kwargs["out"] + elif not tree_any(lambda obj: isinstance(obj, torch.Tensor), (args, kwargs)): + # No Tensor argument means factory function + # They should decompose and be handled in our c++ side directly + raise RuntimeError(f"{op} not handled yet.") + elif op._schema.is_mutable or op is torch.ops.aten._copy_from.default: + # Only handle inplace ops returning their first arg + assert len(args) >= 1, f"Inplace {op} needs at least one arg" + assert ( + len(op._schema.returns) == 1 + ), f"NYI Inplace {op} with more than one return" + op_name = op.overloadpacket._qualified_op_name + real_res = args[0] + elif any(r.alias_info is not None for r in op._schema.returns): + # View ops + if op is torch.ops.aten.view.default: + return torch.ops.aten._unsafe_view(*args, **kwargs) + raise RuntimeError(f"{op} view op is not handled yet") + + if op_name is None: + # 1. Compute updated metadata + if torch.Tag.dynamic_output_shape not in op.tags: + # Usual case: run the meta op to see the output metadata + meta_args, meta_kwargs = to_device_no_copy("meta", args, kwargs) + meta_res = op(*meta_args, **meta_kwargs) + + # 2. Allocate the output + real_res, _ = to_device_no_copy("openreg", meta_res, {}) + else: + # Slow version for data-dependent functions: + # Run the op on the device just to get the output shape + args_, kwargs_ = prepare_for_sending(args, kwargs) + shape = daemon.exec( + "get_op_output_shape", + op.overloadpacket._qualified_op_name, + args_, + kwargs_, + ) + + # 2. Allocate the output + real_res = args[0].new(shape) + + # 3. Move to out variant + kwargs["out"] = real_res + # Let overload resolution find the out= overload + op_name = op.overloadpacket._qualified_op_name + + # 4. Run the compute and populate the output on the device + args, kwargs = prepare_for_sending(args, kwargs) + daemon.exec("run_op", op_name, args, kwargs) + + if post_process is not None: + post_process() + + return real_res + + +_openreg_lib.fallback(_openreg_kernel_fallback, dispatch_key="PrivateUse1") diff --git a/test/cpp_extensions/open_registration_extension/pytorch_openreg/_device_daemon.py b/test/cpp_extensions/open_registration_extension/pytorch_openreg/_device_daemon.py new file mode 100644 index 00000000000000..8cf727cfa2c826 --- /dev/null +++ b/test/cpp_extensions/open_registration_extension/pytorch_openreg/_device_daemon.py @@ -0,0 +1,168 @@ +import logging + +import torch + +from ._meta_parser import ( + OpenRegTensorData, + receive_after_sending, + safe_str, + validate_send_queue_args, +) + + +log = logging.getLogger(__name__) +mp_context = torch.multiprocessing.get_context("spawn") + +# Constant properties of our device +NUM_DEVICES = 7 + +# Global state of our driver +CURR_DEVICE_IDX = 0 +CURR_STREAM = 0 + + +# Our allocator +class Allocator: + def __init__(self): + self.allocated = {} + + def malloc(self, size): + new_data = torch.empty(size, dtype=torch.uint8) + ptr = new_data.data_ptr() + self.allocated[ptr] = new_data + return ptr + + def free(self, ptr): + if ptr not in self.allocated: + return False + else: + del self.allocated[ptr] + return True + + def tensor_from_meta(self, meta): + # Usual case, we're receiving a known Tensor + found_base = self.allocated.get(meta.data_ptr, None) + + # Might be a rewrap of another storage at a different offset + # Slow path to try and find the corresponding storage + if found_base is None: + for tag, t in self.allocated.items(): + # t is always a 1D uint8 storage! + if meta.data_ptr > tag and meta.data_ptr < tag + t.nelement(): + # Blame @ngimel for this + slice_size = t.nelement() - (meta.data_ptr - tag) + found_base = torch.tensor((), dtype=torch.uint8).set_( + t.untyped_storage()[meta.data_ptr - tag :], + size=(slice_size,), + stride=(1,), + storage_offset=0, + ) + + # This pointer is not allocated here, segfault ! + if found_base is None: + log.info("Currently allocated blocks:\n %s", safe_str(self.allocated)) + log.info("Trying to access %s", meta) + raise RuntimeError("SEGFAULT!") + + # Raw 1d uint8 data + raw = found_base + # Slice the right storage part + raw_slice = raw.narrow(0, 0, meta.nelem_in_bytes) + # Reinterpret cast in the right dtype + as_dtype = raw_slice.view(dtype=meta.dtype) + # View to the right shape/stride/offset + view = as_dtype.as_strided(meta.size, meta.stride, meta.storage_offset) + return view + + +def run_op(allocator, op_name, args, kwargs): + op, _ = torch._C._jit_get_operation(op_name) + args, kwargs = receive_after_sending(allocator, args, kwargs) + return op(*args, **kwargs) + + +class _Daemon: + def __init__(self): + super().__init__() + self.is_initialized = False + + def _lazy_init(self): + if self.is_initialized: + return + self.req_queue = mp_context.Queue() + self.ans_queue = mp_context.Queue() + + self.runner = mp_context.Process( + target=self.run_forever, args=(self.req_queue, self.ans_queue), daemon=True + ) + self.runner.start() + self.is_initialized = True + + def exec(self, cmd, *args): + self._lazy_init() + log.info("Main process launched: %s(*%s)", cmd, safe_str(args)) + validate_send_queue_args(cmd, args) + self.req_queue.put((cmd,) + args) + res = self.ans_queue.get() + log.info("Main process result for %s received: %s", cmd, safe_str(res)) + if res == "ERROR": + raise RuntimeError(f"Error in daemon while executing {cmd}, see logs") + else: + return res + + @staticmethod + def run_forever(req_queue, ans_queue): + # Initialize our device + global CURR_DEVICE_IDX + empty_res = object() + allocator = Allocator() + + # Serve all requests + while True: + cmd, *args = req_queue.get() + log.info("Worker executing: %s", cmd) + res = empty_res + if cmd == "deviceCount": + assert len(args) == 0 + res = NUM_DEVICES + elif cmd == "getDevice": + res = CURR_DEVICE_IDX + elif cmd == "uncheckedSetDevice": + assert len(args) == 1 + CURR_DEVICE_IDX = int(args[0]) + res = None + elif cmd == "exchangeDevice": + assert len(args) == 1 + res = CURR_DEVICE_IDX + CURR_DEVICE_IDX = int(args[0]) + elif cmd == "malloc": + res = allocator.malloc(*args) + elif cmd == "free": + res = allocator.free(*args) + elif cmd == "run_op": + op_name, args, kwargs = args + run_op(allocator, op_name, args, kwargs) + res = None + elif cmd == "send_data": + assert len(args) == 1 + res = OpenRegTensorData.from_meta(allocator, args[0]) + elif cmd == "recv_data": + assert len(args) == 2 + host_tensor, dev_mem = args + dev_tensor = OpenRegTensorData.from_meta(allocator, dev_mem) + dev_tensor.copy_(host_tensor) + res = None + elif cmd == "get_op_output_shape": + op_name, args, kwargs = args + res = run_op(allocator, op_name, args, kwargs).size() + else: + log.warning("Bad command in worker") + res = "ERROR" + + if res == empty_res: + raise RuntimeError("Bad impl didn't return anything") + log.info("Worker answering to: %s", cmd) + ans_queue.put(res) + + +daemon = _Daemon() diff --git a/test/cpp_extensions/open_registration_extension/pytorch_openreg/_meta_parser.py b/test/cpp_extensions/open_registration_extension/pytorch_openreg/_meta_parser.py new file mode 100644 index 00000000000000..18b3c1842fc027 --- /dev/null +++ b/test/cpp_extensions/open_registration_extension/pytorch_openreg/_meta_parser.py @@ -0,0 +1,104 @@ +import pprint + +import torch +from torch.utils._pytree import tree_map, tree_map_only + + +class OpenRegTensorMeta: + def __init__(self, tensor, checked=True): + if checked and not tensor.device.type == "openreg": + raise RuntimeError( + "Creating OpenRegTensorMeta is only for Tensors on openreg device" + ) + self.data_ptr = tensor.untyped_storage().data_ptr() + self.size = tensor.size() + self.stride = tensor.stride() + self.storage_offset = tensor.storage_offset() + self.dtype = tensor.dtype + self.nelem_in_bytes = tensor.nelement() * tensor.element_size() + + def __repr__(self): + return ( + f"OpenRegTensorMeta({self.data_ptr=}, {self.size=}, {self.stride=}, " + f"{self.storage_offset=}, {self.dtype=}, {self.nelem_in_bytes=})" + ) + + +class OpenRegTensorData(torch.Tensor): + @staticmethod + def from_meta(allocator, tensor_meta): + return OpenRegTensorData(allocator.tensor_from_meta(tensor_meta)) + + +VALID_QUEUE_TYPES_IN = {torch.Tensor, int, float} + +VALID_QUEUE_TYPES_OUT = {OpenRegTensorMeta, int, float, str} + + +def safe_str(args): + def convert(obj): + if isinstance(obj, torch.Tensor): + return str(OpenRegTensorMeta(obj, checked=False)) + else: + return obj + + new_args = tree_map(convert, args) + return pprint.pformat(new_args) + + +def validate_send_queue_args(cmd, args): + def check(obj): + if type(obj) not in VALID_QUEUE_TYPES_OUT: + if ( + cmd == "recv_data" + and type(obj) is torch.Tensor + and obj.device.type == "cpu" + ): + # Only HtoD copy command can send cpu Tensors over + return + raise RuntimeError( + f"Trying to send invalid object through queue: {type(obj)}" + ) + + tree_map(check, args) + + +def prepare_for_sending(args, kwargs): + def convert(obj): + if type(obj) not in VALID_QUEUE_TYPES_IN: + raise RuntimeError( + f"Cannot send object of type {type(obj)} " "over openreg device pipe." + ) + + if isinstance(obj, torch.Tensor): + return OpenRegTensorMeta(obj) + else: + return obj + + return tree_map(convert, (args, kwargs)) + + +def receive_after_sending(allocator, args, kwargs): + def convert(obj): + if type(obj) not in VALID_QUEUE_TYPES_OUT: + raise RuntimeError( + f"Received invalid object of type {type(obj)} " + "over openreg device pipe." + ) + + if isinstance(obj, OpenRegTensorMeta): + return allocator.tensor_from_meta(obj) + else: + return obj + + return tree_map(convert, (args, kwargs)) + + +def to_device_no_copy(device, args, kwargs): + def safe_to(t): + if device == "meta": + return t.to(device=device) + else: + return torch.empty_like(t, device=device) + + return tree_map_only(torch.Tensor, safe_to, (args, kwargs)) diff --git a/test/cpp_extensions/open_registration_extension/pytorch_openreg/csrc/OpenReg.h b/test/cpp_extensions/open_registration_extension/pytorch_openreg/csrc/OpenReg.h index 68bfeef1fdc889..6cf2b429bea076 100644 --- a/test/cpp_extensions/open_registration_extension/pytorch_openreg/csrc/OpenReg.h +++ b/test/cpp_extensions/open_registration_extension/pytorch_openreg/csrc/OpenReg.h @@ -6,5 +6,6 @@ namespace openreg { void set_impl_registry(PyObject* registry); + py::function get_method(const char* name); } \ No newline at end of file diff --git a/test/cpp_extensions/open_registration_extension/pytorch_openreg/csrc/OpenRegHooks.cpp b/test/cpp_extensions/open_registration_extension/pytorch_openreg/csrc/OpenRegHooks.cpp index 238ed8fddaf4ba..dae7b3b1db9608 100644 --- a/test/cpp_extensions/open_registration_extension/pytorch_openreg/csrc/OpenRegHooks.cpp +++ b/test/cpp_extensions/open_registration_extension/pytorch_openreg/csrc/OpenRegHooks.cpp @@ -11,10 +11,6 @@ namespace { // Python dictionary where real implementations can be found PyObject* py_registry; -py::function get_method(const char* name) { - return py::cast(py_registry)[name]; -} - // C++ hooks implementation struct OpenRegHooksArgs : public at::PrivateUse1HooksArgs {}; @@ -243,4 +239,12 @@ C10_REGISTER_GUARD_IMPL(PrivateUse1, OpenRegGuardImpl); void set_impl_registry(PyObject* registry) { py_registry = registry; } + +py::function get_method(const char* name) { + auto dict = py::cast(py_registry); + TORCH_CHECK(dict.contains(name), "OpenReg registry does not contain ", + "an implementation for '", name, "' make sure to add it in the __init__.py " + "file and register it.") + return dict[name]; +} } // openreg \ No newline at end of file diff --git a/test/cpp_extensions/open_registration_extension/pytorch_openreg/csrc/OpenRegMem.cpp b/test/cpp_extensions/open_registration_extension/pytorch_openreg/csrc/OpenRegMem.cpp new file mode 100644 index 00000000000000..91b2dfcec4b4cd --- /dev/null +++ b/test/cpp_extensions/open_registration_extension/pytorch_openreg/csrc/OpenRegMem.cpp @@ -0,0 +1,125 @@ +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace openreg { + +namespace { + +using openreg_ptr_t = uint64_t; + +// A dummy allocator for our custom device, that secretly uses the CPU +struct OpenRegAllocator final : at::Allocator { + OpenRegAllocator() = default; + + at::DataPtr allocate(size_t nbytes) override { + py::gil_scoped_acquire acquire; + auto curr_device_idx = get_method("getDevice")().cast(); + auto curr_device = c10::Device(c10::DeviceType::PrivateUse1, curr_device_idx); + void* data = nullptr; + if (nbytes > 0) { + data = reinterpret_cast(get_method("malloc")(nbytes).cast()); + TORCH_CHECK(data, "Failed to allocator ", nbytes, " bytes on openreg device."); + } + return {data, data, &ReportAndDelete, curr_device}; + } + + static void ReportAndDelete(void* ptr) { + if (!ptr) { + return; + } + py::gil_scoped_acquire acquire; + TORCH_CHECK( + get_method("free")(reinterpret_cast(ptr)).cast(), + "Failed to free memory pointer at ", ptr + ); + } + + at::DeleterFnPtr raw_deleter() const override { + return &ReportAndDelete; + } + + void copy_data(void* dest, const void* src, std::size_t count) const final { + py::gil_scoped_acquire acquire; + get_method("copy_data")(reinterpret_cast(dest), reinterpret_cast(src), count); + } +}; + +// Register our dummy allocator +static OpenRegAllocator global_openreg_alloc; +REGISTER_ALLOCATOR(c10::DeviceType::PrivateUse1, &global_openreg_alloc); + + +// Empty op needs C++ code and cannot be handled by python side fallback +at::Tensor empty_openreg( + c10::IntArrayRef size, + std::optional dtype_opt, + std::optional layout_opt, + std::optional device_opt, + std::optional pin_memory_opt, + std::optional memory_format_opt) { + const auto device = c10::device_or_default(device_opt); + const auto dtype = c10::dtype_or_default(dtype_opt); + TORCH_CHECK(device.is_privateuseone()); + TORCH_CHECK(c10::layout_or_default(layout_opt) == c10::Layout::Strided, "Non strided layout not supported"); + TORCH_CHECK(!c10::pinned_memory_or_default(pin_memory_opt), "Pin memory can only be on CPU"); + const c10::DeviceGuard device_guard(device); + constexpr c10::DispatchKeySet pu1_dks(c10::DispatchKey::PrivateUse1); + return at::detail::empty_generic( + size, &global_openreg_alloc, pu1_dks, dtype, memory_format_opt); +} + +at::Tensor empty_strided_openreg( + c10::IntArrayRef size, + c10::IntArrayRef stride, + std::optional dtype_opt, + std::optional layout_opt, + std::optional device_opt, + std::optional pin_memory_opt) { + const auto device = c10::device_or_default(device_opt); + const auto dtype = c10::dtype_or_default(dtype_opt); + TORCH_CHECK(device.is_privateuseone()); + TORCH_CHECK(c10::layout_or_default(layout_opt) == c10::Layout::Strided, "Non strided layout not supported"); + TORCH_CHECK(!c10::pinned_memory_or_default(pin_memory_opt), "Pin memory can only be on CPU"); + const c10::DeviceGuard device_guard(device); + constexpr c10::DispatchKeySet pu1_dks(c10::DispatchKey::PrivateUse1); + return at::detail::empty_strided_generic( + size, stride, &global_openreg_alloc, pu1_dks, dtype); +} + +at::Tensor as_strided_openreg( + const at::Tensor& self, + c10::IntArrayRef size, + c10::IntArrayRef stride, + std::optional storage_offset_) { + // Metadata-only change so we re-use the cpu impl + return at::cpu::as_strided(self, size, stride, storage_offset_); +} + +at::Tensor& set_openreg( + at::Tensor& result, + at::Storage storage, + int64_t storage_offset, + c10::IntArrayRef size, + c10::IntArrayRef stride) { + return at::cpu::set_(result, storage, storage_offset, size, stride); +} + + +TORCH_LIBRARY_IMPL(aten, PrivateUse1, m) { + m.impl("empty.memory_format", empty_openreg); + m.impl("empty_strided", empty_strided_openreg); + m.impl("as_strided", as_strided_openreg); + m.impl("set_.source_Storage_storage_offset", set_openreg); +} + +} // anonymous namspaces + +} // openreg diff --git a/test/cpp_extensions/open_registration_extension/test/test_openreg.py b/test/cpp_extensions/open_registration_extension/test/test_openreg.py index 87c35edf695ca3..18484e61bef04b 100644 --- a/test/cpp_extensions/open_registration_extension/test/test_openreg.py +++ b/test/cpp_extensions/open_registration_extension/test/test_openreg.py @@ -7,17 +7,17 @@ import pytorch_openreg import torch -from torch.testing._internal.common_utils import IS_LINUX, run_tests, TestCase +from torch.testing._internal.common_utils import run_tests, TestCase class TestOpenReg(TestCase): def test_initializes(self): self.assertEqual(torch._C._get_privateuse1_backend_name(), "openreg") - @unittest.skipIf(not IS_LINUX, "Only works on linux") + @unittest.SkipTest def test_autograd_init(self): # Make sure autograd is initialized - torch.rand(2, requires_grad=True, device="openreg").sum().backward() + torch.ones(2, requires_grad=True, device="openreg").sum().backward() pid = os.getpid() task_path = f"/proc/{pid}/task" @@ -30,9 +30,35 @@ def test_autograd_init(self): thread_name = file.read().strip() all_thread_names.add(thread_name) - for i in range(pytorch_openreg.NUM_DEVICES): + for i in range(pytorch_openreg._device_daemon.NUM_DEVICES): self.assertIn(f"pt_autograd_{i}", all_thread_names) + def test_factory(self): + a = torch.empty(50, device="openreg") + self.assertEqual(a.device.type, "openreg") + + a.fill_(3.5) + + self.assertTrue(a.eq(3.5).all()) + + def test_printing(self): + a = torch.ones(20, device="openreg") + # Does not crash! + str(a) + + def test_cross_device_copy(self): + a = torch.rand(10) + b = a.to(device="openreg").add(2).to(device="cpu") + self.assertEqual(b, a + 2) + + def test_data_dependent_output(self): + cpu_a = torch.randn(10) + a = cpu_a.to(device="openreg") + mask = a.gt(0) + out = torch.masked_select(a, mask) + + self.assertEqual(out, cpu_a.masked_select(cpu_a.gt(0))) + if __name__ == "__main__": run_tests() From 3f4714a82649ad93a92bfce5c08087634cd5e3ea Mon Sep 17 00:00:00 2001 From: Pian Pawakapan Date: Wed, 28 Aug 2024 00:34:38 +0000 Subject: [PATCH 0122/1018] [export] enumerate unsupported sympy.Functions (#134271) (#134598) Summary: There's 2 concepts of unsupported sympy.Functions in symbolic_shapes: 1) unsupported by the export solver, meaning the solver doesn't know how to provide useful fixes for those functions 2) unsupported by the sympy interpreter - meaning we can't reify them into FX nodes because the functions aren't present in PythonReferenceAnalysis This splits the current call into a call for each version, with the Export solver the only user of 1). For 1), we enumerate the functions in _sympy/functions.py, and subtract the functions we know we can support. For 2) there's only 3 functions we've seen pop up in test cases. cc jgong5 mingfeima XiaobingSuper sanchitintel ashokei jingxu10 Differential Revision: D61863394 Pulled By: pianpwk Pull Request resolved: https://github.com/pytorch/pytorch/pull/134598 Approved by: https://github.com/angelayi --- torch/fx/experimental/symbolic_shapes.py | 34 ++++++++++++++++++++---- torch/fx/passes/runtime_assert.py | 6 ++--- torch/utils/_sympy/functions.py | 7 +++++ 3 files changed, 39 insertions(+), 8 deletions(-) diff --git a/torch/fx/experimental/symbolic_shapes.py b/torch/fx/experimental/symbolic_shapes.py index 9c7f7ea80ce399..bc6645c617a479 100644 --- a/torch/fx/experimental/symbolic_shapes.py +++ b/torch/fx/experimental/symbolic_shapes.py @@ -63,7 +63,7 @@ from torch._guards import ShapeGuard, Source, TracingContext from torch.utils._python_dispatch import is_traceable_wrapper_subclass from torch.utils._sympy.functions import ( - FloorDiv, Mod, PythonMod, IsNonOverlappingAndDenseIndicator, CleanDiv, FloorToInt, CeilToInt + Application, FloorDiv, Mod, PythonMod, IsNonOverlappingAndDenseIndicator, CleanDiv, FloorToInt, CeilToInt ) from torch.utils._sympy.solve import try_solve from torch.utils._sympy.numbers import int_oo @@ -1267,13 +1267,14 @@ def _is_supported_equivalence(expr): ) return isinstance(expr, sympy.Symbol) -def _has_unsupported_sympy_function(expr) -> bool: +def _has_uninterpretable_sympy_function(expr) -> bool: + """ + Add functions that our sympy interpreter can't reify into FX nodes + """ return expr.has( torch.utils._sympy.functions.ToFloat, torch.utils._sympy.functions.TruncToInt, torch.utils._sympy.functions.CeilToInt, - # add more sympy functions that involve float<->int conversion here - # since our solver does not know what to do with them ) @dataclass(frozen=True) @@ -1675,6 +1676,15 @@ def __init__( # symbols that are marked dynamic self._marked_dynamic = marked_dynamic + # track supported sympy functions and subtract from list of all sympy functions + self._supported_sympy_functions: Set[sympy.Function] = { + Application, + Mod, + PythonMod, + FloorDiv, + } + self._enumerate_sympy_functions() + def rewrite_with_congruences(self, s, expr): """ Eliminate expressions of the form b // d and b % d while adding congruences of the form b % d == k. @@ -1741,6 +1751,20 @@ def floor_div_handler(*args): expr = expr.replace(FloorDiv, floor_div_handler) return expr + def _enumerate_sympy_functions(self): + module = torch.utils._sympy.functions + all_functions = set() + for attr in dir(module): + if isinstance(func := getattr(module, attr), sympy.FunctionClass): + all_functions.add(func) + self._unsupported_sympy_functions = all_functions.difference(self._supported_sympy_functions) + + def _has_unsupported_sympy_function(self, expr) -> bool: + """ + Tracks list of sympy.Functions the export solver doesn't know how to handle. + """ + return expr.has(*self._unsupported_sympy_functions) + def add(self, expr) -> bool: """Add an expression to the set of constraints. @@ -1757,7 +1781,7 @@ def add(self, expr) -> bool: # a fix for this issue, we delay raising such failures. See solve(). if orig_reduced == sympy.false: self._inconsistencies.append(f"{orig_expr} is inconsistent!") - if isinstance(expr, sympy.Ne) or _has_unsupported_sympy_function(expr): + if isinstance(expr, sympy.Ne) or self._has_unsupported_sympy_function(expr): # we're not going to do anything useful with these, so drop them return False free_symbols = expr.free_symbols diff --git a/torch/fx/passes/runtime_assert.py b/torch/fx/passes/runtime_assert.py index 7307fb8c835703..e2a7c1bee3db1d 100644 --- a/torch/fx/passes/runtime_assert.py +++ b/torch/fx/passes/runtime_assert.py @@ -93,7 +93,7 @@ def insert_deferred_runtime_asserts( import sympy from torch.fx.experimental.symbolic_shapes import ( - _has_unsupported_sympy_function, + _has_uninterpretable_sympy_function, CallMethodKey, cast_symbool_to_symint_guardless, ConvertIntKey, @@ -136,7 +136,7 @@ def _is_intermediate_tensor_sym_call(node: fx.Node) -> bool: (val := _get_sym_val(node)) is not None and not isinstance(val, sympy.Number) # this holds back from reifying anything in torch.utils._sympy.functions.py that's unsupported - and not _has_unsupported_sympy_function(val) + and not _has_uninterpretable_sympy_function(val) and any( isinstance(arg, fx.Node) and isinstance(_get_example_value(arg), (torch.Tensor, torch.Size)) @@ -195,7 +195,7 @@ def add_runtime_asserts(ras): and _is_bound_expr_for_symbol(ra.expr) ) # don't try to reify sympy functions we can't turn into FX nodes - or _has_unsupported_sympy_function(ra.expr) + or _has_uninterpretable_sympy_function(ra.expr) ): continue diff --git a/torch/utils/_sympy/functions.py b/torch/utils/_sympy/functions.py index 0998b57678610a..f9602dbd63976f 100644 --- a/torch/utils/_sympy/functions.py +++ b/torch/utils/_sympy/functions.py @@ -53,13 +53,20 @@ __all__ = [ "FloorDiv", "ModularIndexing", + "Where", + "PythonMod", + "Mod", "CleanDiv", + "CeilToInt", + "FloorToInt", "CeilDiv", "IntTrueDiv", "FloatTrueDiv", "LShift", "RShift", "IsNonOverlappingAndDenseIndicator", + "TruncToFloat", + "TruncToInt", "RoundToInt", "RoundDecimal", "ToFloat", From b6aa10c938336b86105d712b123b4d43f586e905 Mon Sep 17 00:00:00 2001 From: FFFrog Date: Tue, 27 Aug 2024 21:02:45 +0000 Subject: [PATCH 0123/1018] using new device-agnostic api instead of old api like torch.cpu or torch.cuda (#134448) Pull Request resolved: https://github.com/pytorch/pytorch/pull/134448 Approved by: https://github.com/guangyey, https://github.com/shink, https://github.com/albanD --- test/test_autocast.py | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/test/test_autocast.py b/test/test_autocast.py index 24f87944990d8f..4c3f031a88c153 100644 --- a/test/test_autocast.py +++ b/test/test_autocast.py @@ -45,9 +45,9 @@ def cast(val, to_type): if add_kwargs is None: add_kwargs = {} - self.assertFalse(torch.is_autocast_cpu_enabled()) - with torch.cpu.amp.autocast(dtype=amp_dtype): - self.assertTrue(torch.is_autocast_cpu_enabled()) + self.assertFalse(torch.is_autocast_enabled(device_type="cpu")) + with torch.amp.autocast(device_type="cpu", dtype=amp_dtype): + self.assertTrue(torch.is_autocast_enabled(device_type="cpu")) out_type = out_type if out_type is not None else run_as_type output = output_method = None @@ -94,8 +94,8 @@ def compare(first, second): # Compare numerics to Python-side "autocasting" that (we expect) does the same thing # as the C++-side autocasting, and should be bitwise accurate. output_to_compare = output if output is not None else output_method - with torch.cpu.amp.autocast(enabled=False): - self.assertFalse(torch.is_autocast_cpu_enabled()) + with torch.amp.autocast(device_type="cpu", enabled=False): + self.assertFalse(torch.is_autocast_enabled(device_type="cpu")) if module is not None and hasattr(module, op): control = getattr(module, op)( @@ -108,8 +108,8 @@ def compare(first, second): self.assertTrue(type(output_to_compare) == type(control)) comparison = compare(output_to_compare, control) self.assertTrue(comparison, f"torch.{op} result did not match control") - self.assertTrue(torch.is_autocast_cpu_enabled()) - self.assertFalse(torch.is_autocast_cpu_enabled()) + self.assertTrue(torch.is_autocast_enabled(device_type="cpu")) + self.assertFalse(torch.is_autocast_enabled(device_type="cpu")) def args_maybe_kwargs(self, op_with_args): if len(op_with_args) == 2: @@ -237,7 +237,7 @@ def test_autocast_rnn(self): m(x, (hx, cx)) # Should be able to run the below case with autocast - with torch.cpu.amp.autocast(): + with torch.amp.autocast(device_type="cpu"): m(x, (hx, cx)) def test_autocast_disabled_with_fp32_dtype(self): @@ -249,7 +249,7 @@ def test_generic_autocast(self): op, args, maybe_kwargs = self.args_maybe_kwargs(op_with_args) with torch.amp.autocast(device_type="cpu"): generic_autocast_output = getattr(torch, op)(*args, **maybe_kwargs) - with torch.cpu.amp.autocast(): + with torch.amp.autocast(device_type="cpu"): cpu_autocast_output = getattr(torch, op)(*args, **maybe_kwargs) self.assertEqual(generic_autocast_output, cpu_autocast_output) @@ -346,8 +346,8 @@ def test_cache_disabled(self): class TestTorchAutocast(TestCase): def test_autocast_fast_dtype(self): - gpu_fast_dtype = torch.get_autocast_gpu_dtype() - cpu_fast_dtype = torch.get_autocast_cpu_dtype() + gpu_fast_dtype = torch.get_autocast_dtype(device_type="cuda") + cpu_fast_dtype = torch.get_autocast_dtype(device_type="cpu") self.assertEqual(gpu_fast_dtype, torch.half) self.assertEqual(cpu_fast_dtype, torch.bfloat16) From 822b1cd8900e78daecd29db5ea1bb2e02e4ef9c1 Mon Sep 17 00:00:00 2001 From: "Sun, Jiayi" Date: Mon, 26 Aug 2024 23:54:10 -0700 Subject: [PATCH 0124/1018] allow conv_bn mixed dtype folding in post-grad (#133968) This PR relaxes the condition to allow conv_bn mixed dtype folding in post-grad. Pull Request resolved: https://github.com/pytorch/pytorch/pull/133968 Approved by: https://github.com/leslie-fang-intel, https://github.com/jansel --- test/inductor/test_inductor_freezing.py | 61 +++++++++++++++++++++ torch/_inductor/fx_passes/binary_folding.py | 5 +- 2 files changed, 62 insertions(+), 4 deletions(-) diff --git a/test/inductor/test_inductor_freezing.py b/test/inductor/test_inductor_freezing.py index 2f4faad8f9f063..9ea952ca02bd2a 100644 --- a/test/inductor/test_inductor_freezing.py +++ b/test/inductor/test_inductor_freezing.py @@ -10,6 +10,7 @@ import torch from torch import nn +from torch._dynamo.utils import counters from torch._inductor import config from torch._inductor.test_case import TestCase as InductorTestCase from torch._inductor.utils import override_lowering, run_and_get_code @@ -79,6 +80,17 @@ def forward(self, x): return self.bn(self.conv(x)) +class ConvBNHardswish(torch.nn.Module): + def __init__(self, in_channels, out_channels, bias=False, **kwargs): + super().__init__() + self.conv = torch.nn.Conv2d(in_channels, out_channels, bias=bias, **kwargs) + self.bn = torch.nn.BatchNorm2d(out_channels, eps=0.001, dtype=torch.float) + self.hardswish = nn.Hardswish(inplace=True) + + def forward(self, x): + return self.hardswish(self.bn(self.conv(x))) + + class ConvFunctionalBN(torch.nn.Module): def __init__( self, @@ -437,6 +449,54 @@ def test_folded_conv_bn(self): x = torch.rand(3, 3, 32, 32).to(self.device).to(dtype) + torch._dynamo.reset() + counters.clear() + + @torch.compile() + def foo(mod, x): + return mod(x) + + # TODO - bias is separate kernel right now, we should only unfuse it + # from conv if it can be fused + + with torch.no_grad(): + out_eager = mod(x) + out_optimized_for_infernece, code = run_and_get_code(foo, mod, x) + + # we unfuse the conv bias, but it should only have one constant in the kernel + if self.device == "cuda": + FileCheck().check_not(".run(").check("conv").check(".run(").check_same( + "frozen_param" + ).check_not("frozen_param").check_next("return").run(code[0]) + + self.assertEqual( + out_optimized_for_infernece, out_eager, atol=1e-2, rtol=1e-2 + ) + self.assertEqual(counters["inductor"]["binary_folding"], 4) + + @torch._inductor.config.patch(layout_optimization=False) + def test_folded_conv_bn_hardswish(self): + for use_bias, dtype in itertools.product( + [True, False], [torch.float16, torch.bfloat16, torch.float32] + ): + if self.device == "cpu" and dtype == torch.float16: + continue + + if self.device == "cuda" and dtype == torch.bfloat16 and not SM80OrLater: + continue + + mod = ( + ConvBNHardswish(3, 32, bias=use_bias, kernel_size=3, stride=2) + .eval() + .to(self.device) + .to(dtype) + ) + + x = torch.rand(3, 3, 32, 32).to(self.device).to(dtype) + + torch._dynamo.reset() + counters.clear() + @torch.compile() def foo(mod, x): return mod(x) @@ -457,6 +517,7 @@ def foo(mod, x): self.assertEqual( out_optimized_for_infernece, out_eager, atol=1e-2, rtol=1e-2 ) + self.assertEqual(counters["inductor"]["binary_folding"], 4) @torch._inductor.config.patch(layout_optimization=False) def test_folded_conv_bn_with_module_sharing(self): diff --git a/torch/_inductor/fx_passes/binary_folding.py b/torch/_inductor/fx_passes/binary_folding.py index c360ae9ec6d87f..bd1e0736d510fa 100644 --- a/torch/_inductor/fx_passes/binary_folding.py +++ b/torch/_inductor/fx_passes/binary_folding.py @@ -34,10 +34,7 @@ def mark_mixed_dtype_conv(conv): conv_user = next(iter(conv_user.users.keys())) - if not ( - conv_user.target == prims.convert_element_type.default - and conv_user.args[1] == conv_dtype - ): + if conv_user.target != prims.convert_element_type.default: return conv.meta["_allow_conv_mixed_dtype_folding"] = conv_dtype From 4400af34acc204f422c6820d200d25941faec1c1 Mon Sep 17 00:00:00 2001 From: atalman Date: Wed, 28 Aug 2024 01:17:41 +0000 Subject: [PATCH 0125/1018] [CD] Fix docker builds by installing setuptools after python build (#134631) Follow up after https://github.com/pytorch/pytorch/pull/134595 Same error happens silently before the error addressed in the above PR (and build continues and builds invalid Docker): ``` #47 457.5 Traceback (most recent call last): #47 457.5 File "", line 1, in #47 457.5 File "/opt/_internal/cpython-3.12.0/lib/python3.12/site-packages/wheel/pep425tags.py", line 3, in #47 457.5 import distutils.util #47 457.5 ModuleNotFoundError: No module named 'distutils' #47 457.5 + local abi_tag= #47 457.5 + ln -s /opt/_internal/cpython-3.12.0 /opt/python/ #47 457.5 + rm -f Python-3.12.0.tgz ``` The fix in https://github.com/pytorch/pytorch/pull/134595 is no longer needed since we will install setuptools right after python installation. Link: https://github.com/pytorch/pytorch/actions/runs/10584642913/job/29329366729#step:6:6041 Pull Request resolved: https://github.com/pytorch/pytorch/pull/134631 Approved by: https://github.com/kit1980 --- .ci/docker/common/install_cpython.sh | 3 ++- .ci/docker/manywheel/build_scripts/build.sh | 2 -- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/.ci/docker/common/install_cpython.sh b/.ci/docker/common/install_cpython.sh index 1bd25fb2de91f0..5e0c0f2bb2815d 100755 --- a/.ci/docker/common/install_cpython.sh +++ b/.ci/docker/common/install_cpython.sh @@ -58,7 +58,8 @@ function do_cpython_build { if [ -e ${prefix}/bin/pip3 ] && [ ! -e ${prefix}/bin/pip ]; then ln -s pip3 ${prefix}/bin/pip fi - ${prefix}/bin/pip install wheel==0.34.2 + # install setuptools since python 3.12 is required to use distutils + ${prefix}/bin/pip install wheel==0.34.2 setuptools==68.2.2 local abi_tag=$(${prefix}/bin/python -c "from wheel.pep425tags import get_abbr_impl, get_impl_ver, get_abi_tag; print('{0}{1}-{2}'.format(get_abbr_impl(), get_impl_ver(), get_abi_tag()))") ln -s ${prefix} /opt/python/${abi_tag} } diff --git a/.ci/docker/manywheel/build_scripts/build.sh b/.ci/docker/manywheel/build_scripts/build.sh index fccdea91737d55..1708b71a19b5eb 100644 --- a/.ci/docker/manywheel/build_scripts/build.sh +++ b/.ci/docker/manywheel/build_scripts/build.sh @@ -117,8 +117,6 @@ find /opt/_internal \ -print0 | xargs -0 rm -f for PYTHON in /opt/python/*/bin/python; do - # install setuptools since python 3.12 is required to use distutils - $PYTHON -m pip install setuptools==68.2.2 # Smoke test to make sure that our Pythons work, and do indeed detect as # being manylinux compatible: $PYTHON $MY_DIR/manylinux1-check.py From fcf196e937603e8f2d83c5d8e6b8627e1309a108 Mon Sep 17 00:00:00 2001 From: "Yu, Guangye" Date: Mon, 26 Aug 2024 15:36:04 +0000 Subject: [PATCH 0126/1018] Refactor caching device allocator utils (#130923) # Motivation Following [[RFC] Intel GPU Runtime Upstreaming for Allocator ](https://github.com/pytorch/pytorch/issues/116322), this PR aims to refactor caching device allocator utils to improve code reuse usage. This is the first PR, we could prepare some follow-up PRs continuing to refactor the device caching allocator. Pull Request resolved: https://github.com/pytorch/pytorch/pull/130923 Approved by: https://github.com/EikanWang, https://github.com/gujinghui, https://github.com/albanD, https://github.com/eqy --- c10/core/CachingDeviceAllocator.h | 131 +++++++++++++++++ c10/cuda/CUDACachingAllocator.cpp | 163 +++++++-------------- c10/cuda/CUDACachingAllocator.h | 78 +--------- c10/cuda/CUDAMallocAsyncAllocator.cpp | 2 + torch/csrc/cuda/CUDAPluggableAllocator.cpp | 4 +- torch/csrc/cuda/CUDAPluggableAllocator.h | 2 +- torch/csrc/cuda/Module.cpp | 8 +- 7 files changed, 198 insertions(+), 190 deletions(-) create mode 100644 c10/core/CachingDeviceAllocator.h diff --git a/c10/core/CachingDeviceAllocator.h b/c10/core/CachingDeviceAllocator.h new file mode 100644 index 00000000000000..8724ecf88ae0f9 --- /dev/null +++ b/c10/core/CachingDeviceAllocator.h @@ -0,0 +1,131 @@ +#pragma once + +#include +#include + +#include + +namespace c10::CachingDeviceAllocator { + +struct Stat { + void increase(size_t amount) { + current += static_cast(amount); + peak = std::max(current, peak); + allocated += static_cast(amount); + } + + void decrease(size_t amount) { + current -= static_cast(amount); + TORCH_INTERNAL_ASSERT_DEBUG_ONLY( + current >= 0, + "Negative tracked stat in device allocator (likely logic error)."); + freed += static_cast(amount); + } + + void reset_accumulated() { + allocated = 0; + freed = 0; + } + + void reset_peak() { + peak = current; + } + + int64_t current = 0; + int64_t peak = 0; + int64_t allocated = 0; + int64_t freed = 0; +}; + +enum struct StatType : uint64_t { + AGGREGATE = 0, + SMALL_POOL = 1, + LARGE_POOL = 2, + NUM_TYPES = 3 // remember to update this whenever a new stat type is added +}; + +using StatArray = std::array(StatType::NUM_TYPES)>; +using StatTypes = std::array(StatType::NUM_TYPES)>; + +template +void for_each_selected_stat_type(const StatTypes& stat_types, Func f) { + for (const auto stat_type : c10::irange(stat_types.size())) { + if (stat_types[stat_type]) { + f(stat_type); + } + } +} + +// Struct containing memory allocator summary statistics for a device. +struct DeviceStats { + // COUNT: allocations requested by client code + StatArray allocation; + // COUNT: number of allocated segments from device memory allocation. + StatArray segment; + // COUNT: number of active memory blocks (allocated or used by stream) + StatArray active; + // COUNT: number of inactive, split memory blocks (unallocated but can't be + // released via device memory deallocation) + StatArray inactive_split; + + // SUM: bytes allocated by this memory alocator + StatArray allocated_bytes; + // SUM: bytes reserved by this memory allocator (both free and used) + StatArray reserved_bytes; + // SUM: bytes within active memory blocks + StatArray active_bytes; + // SUM: bytes within inactive, split memory blocks + StatArray inactive_split_bytes; + // SUM: bytes requested by client code + StatArray requested_bytes; + + // COUNT: total number of failed calls to device malloc necessitating cache + // flushes. + int64_t num_alloc_retries = 0; + + // COUNT: total number of OOMs (i.e. failed calls to device memory allocation + // after cache flush) + int64_t num_ooms = 0; + + // COUNT: total number of oversize blocks allocated from pool + Stat oversize_allocations; + + // COUNT: total number of oversize blocks requiring malloc + Stat oversize_segments; + + // COUNT: total number of synchronize_and_free_events() calls + int64_t num_sync_all_streams = 0; + + // COUNT: total number of device memory allocation calls. This includes both + // mapped and malloced memory. + int64_t num_device_alloc = 0; + + // COUNT: total number of device memory deallocation calls. This includes both + // un-mapped and free memory. + int64_t num_device_free = 0; + + // SIZE: maximum block size that is allowed to be split. + int64_t max_split_size = 0; +}; + +// Size pretty-printer +inline std::string format_size(uint64_t size) { + std::ostringstream os; + os.precision(2); + os << std::fixed; + if (size <= 1024) { + os << size << " bytes"; + } else if (size <= 1048576) { + os << (static_cast(size) / 1024.0); + os << " KiB"; + } else if (size <= 1073741824ULL) { + os << static_cast(size) / 1048576.0; + os << " MiB"; + } else { + os << static_cast(size) / 1073741824.0; + os << " GiB"; + } + return os.str(); +} + +} // namespace c10::CachingDeviceAllocator diff --git a/c10/cuda/CUDACachingAllocator.cpp b/c10/cuda/CUDACachingAllocator.cpp index 4a3a2c7545c54c..a38f202394e5a3 100644 --- a/c10/cuda/CUDACachingAllocator.cpp +++ b/c10/cuda/CUDACachingAllocator.cpp @@ -10,7 +10,6 @@ #include #include #include -#include #include #include @@ -27,7 +26,6 @@ #include #include #include -#include #include #include #include @@ -44,6 +42,8 @@ C10_DEFINE_REGISTRY(FreeCudaMemoryCallbacksRegistry, FreeMemoryCallback); namespace cuda::CUDACachingAllocator { +using namespace c10::CachingDeviceAllocator; + // Included here as this is externally used in CUDAAllocatorConfig const size_t kLargeBuffer = 20971520; // "large" allocations may be packed in 20 MiB blocks @@ -134,47 +134,13 @@ namespace { using stream_set = ska::flat_hash_set; -using StatTypes = std::array(StatType::NUM_TYPES)>; - -void increase_stat(Stat& stat, size_t amount) { - stat.current += static_cast(amount); - stat.peak = std::max(stat.current, stat.peak); - stat.allocated += static_cast(amount); -} - -void decrease_stat(Stat& stat, size_t amount) { - stat.current -= static_cast(amount); - TORCH_INTERNAL_ASSERT_DEBUG_ONLY( - stat.current >= 0, - "Negative tracked stat in CUDA allocator (likely logic error)."); - stat.freed += static_cast(amount); -} - -void reset_accumulated_stat(Stat& stat) { - stat.allocated = 0; - stat.freed = 0; -} - -void reset_peak_stat(Stat& stat) { - stat.peak = stat.current; -} - -template -void for_each_selected_stat_type(const StatTypes& stat_types, Func f) { - for (const auto stat_type : c10::irange(stat_types.size())) { - if (stat_types[stat_type]) { - f(stat_type); - } - } -} - void decrease_stat_array( StatArray& stat_array, size_t amount, const StatTypes& stat_types) { for_each_selected_stat_type( stat_types, [&stat_array, amount](size_t stat_type) { - decrease_stat(stat_array[stat_type], amount); + stat_array[stat_type].decrease(amount); }); } @@ -1424,16 +1390,16 @@ class DeviceCachingAllocator { // A new split inactive block is being created from a previously unsplit // block, size remaining->size bytes. for_each_selected_stat_type(params.stat_types, [&](size_t stat_type) { - increase_stat(stats.inactive_split_bytes[stat_type], remaining->size); - increase_stat(stats.inactive_split[stat_type], 1); + stats.inactive_split_bytes[stat_type].increase(remaining->size); + stats.inactive_split[stat_type].increase(1); }); } } else if (already_split && !block->expandable_segment_) { // An already-split block is becoming active for_each_selected_stat_type(params.stat_types, [&](size_t stat_type) { - decrease_stat(stats.inactive_split_bytes[stat_type], block->size); - decrease_stat(stats.inactive_split[stat_type], 1); + stats.inactive_split_bytes[stat_type].decrease(block->size); + stats.inactive_split[stat_type].decrease(1); }); } @@ -1454,14 +1420,14 @@ class DeviceCachingAllocator { TORCH_INTERNAL_ASSERT_DEBUG_ONLY(inserted); for_each_selected_stat_type(params.stat_types, [&](size_t stat_type) { - increase_stat(stats.allocation[stat_type], 1); - increase_stat(stats.allocated_bytes[stat_type], block->size); - increase_stat(stats.active[stat_type], 1); - increase_stat(stats.active_bytes[stat_type], block->size); - increase_stat(stats.requested_bytes[stat_type], block->requested_size); + stats.allocation[stat_type].increase(1); + stats.allocated_bytes[stat_type].increase(block->size); + stats.active[stat_type].increase(1); + stats.active_bytes[stat_type].increase(block->size); + stats.requested_bytes[stat_type].increase(block->requested_size); }); if (block->size >= CUDAAllocatorConfig::max_split_size()) - increase_stat(stats.oversize_allocations, 1); + stats.oversize_allocations.increase(1); c10::reportMemoryUsageToProfiler( block->ptr, @@ -1487,8 +1453,8 @@ class DeviceCachingAllocator { StatTypes stat_types = get_stat_types_for_pool(*block->pool); for_each_selected_stat_type(stat_types, [&](size_t stat_type) { - decrease_stat(stats.allocation[stat_type], 1); - decrease_stat(stats.allocated_bytes[stat_type], block->size); + stats.allocation[stat_type].decrease(1); + stats.allocated_bytes[stat_type].decrease(block->size); }); record_trace( @@ -1500,7 +1466,7 @@ class DeviceCachingAllocator { context ? context : block->context_when_allocated); if (block->size >= CUDAAllocatorConfig::max_split_size()) - decrease_stat(stats.oversize_allocations, 1); + stats.oversize_allocations.decrease(1); if (!block->stream_uses.empty()) { if (C10_UNLIKELY(!captures_underway.empty())) { @@ -1628,15 +1594,15 @@ class DeviceCachingAllocator { for (const auto statType : c10::irange(static_cast(StatType::NUM_TYPES))) { - reset_accumulated_stat(stats.allocation[statType]); - reset_accumulated_stat(stats.segment[statType]); - reset_accumulated_stat(stats.active[statType]); - reset_accumulated_stat(stats.inactive_split[statType]); - reset_accumulated_stat(stats.allocated_bytes[statType]); - reset_accumulated_stat(stats.reserved_bytes[statType]); - reset_accumulated_stat(stats.active_bytes[statType]); - reset_accumulated_stat(stats.inactive_split_bytes[statType]); - reset_accumulated_stat(stats.requested_bytes[statType]); + stats.allocation[statType].reset_accumulated(); + stats.segment[statType].reset_accumulated(); + stats.active[statType].reset_accumulated(); + stats.inactive_split[statType].reset_accumulated(); + stats.allocated_bytes[statType].reset_accumulated(); + stats.reserved_bytes[statType].reset_accumulated(); + stats.active_bytes[statType].reset_accumulated(); + stats.inactive_split_bytes[statType].reset_accumulated(); + stats.requested_bytes[statType].reset_accumulated(); } stats.num_alloc_retries = 0; @@ -1644,8 +1610,8 @@ class DeviceCachingAllocator { stats.num_sync_all_streams = 0; stats.num_device_alloc = 0; stats.num_device_free = 0; - reset_accumulated_stat(stats.oversize_allocations); - reset_accumulated_stat(stats.oversize_segments); + stats.oversize_allocations.reset_accumulated(); + stats.oversize_segments.reset_accumulated(); } /** Resets the historical peak stats for the device **/ @@ -1654,18 +1620,18 @@ class DeviceCachingAllocator { for (const auto statType : c10::irange(static_cast(StatType::NUM_TYPES))) { - reset_peak_stat(stats.allocation[statType]); - reset_peak_stat(stats.segment[statType]); - reset_peak_stat(stats.active[statType]); - reset_peak_stat(stats.inactive_split[statType]); - reset_peak_stat(stats.allocated_bytes[statType]); - reset_peak_stat(stats.reserved_bytes[statType]); - reset_peak_stat(stats.active_bytes[statType]); - reset_peak_stat(stats.inactive_split_bytes[statType]); - reset_peak_stat(stats.requested_bytes[statType]); + stats.allocation[statType].reset_peak(); + stats.segment[statType].reset_peak(); + stats.active[statType].reset_peak(); + stats.inactive_split[statType].reset_peak(); + stats.allocated_bytes[statType].reset_peak(); + stats.reserved_bytes[statType].reset_peak(); + stats.active_bytes[statType].reset_peak(); + stats.inactive_split_bytes[statType].reset_peak(); + stats.requested_bytes[statType].reset_peak(); } - reset_peak_stat(stats.oversize_allocations); - reset_peak_stat(stats.oversize_segments); + stats.oversize_allocations.reset_peak(); + stats.oversize_segments.reset_peak(); } /* Checkpoint the state of a private pool necessary to return it to its @@ -2277,7 +2243,7 @@ class DeviceCachingAllocator { total_allocated_memory += mapped_range.size; StatTypes stat_types = get_stat_types_for_pool(*to_map->pool); for_each_selected_stat_type(stat_types, [&](size_t stat_type) { - increase_stat(stats.reserved_bytes[stat_type], mapped_range.size); + stats.reserved_bytes[stat_type].increase(mapped_range.size); }); stats.num_device_alloc++; @@ -2384,27 +2350,23 @@ class DeviceCachingAllocator { // inactive_split if (!block->expandable_segment_) { if (net_change_inactive_split_blocks > 0) { - increase_stat( - stats.inactive_split[stat_type], + stats.inactive_split[stat_type].increase( static_cast(net_change_inactive_split_blocks)); } else if (net_change_inactive_split_blocks < 0) { - decrease_stat( - stats.inactive_split[stat_type], + stats.inactive_split[stat_type].decrease( static_cast(-net_change_inactive_split_blocks)); } if (net_change_inactive_split_size > 0) { - increase_stat( - stats.inactive_split_bytes[stat_type], + stats.inactive_split_bytes[stat_type].increase( static_cast(net_change_inactive_split_size)); } else if (net_change_inactive_split_size < 0) { - decrease_stat( - stats.inactive_split_bytes[stat_type], + stats.inactive_split_bytes[stat_type].decrease( static_cast(-net_change_inactive_split_size)); } } - decrease_stat(stats.active[stat_type], 1); - decrease_stat(stats.active_bytes[stat_type], original_block_size); - decrease_stat(stats.requested_bytes[stat_type], requested_size); + stats.active[stat_type].decrease(1); + stats.active_bytes[stat_type].decrease(original_block_size); + stats.requested_bytes[stat_type].decrease(requested_size); }); } @@ -2711,11 +2673,11 @@ class DeviceCachingAllocator { total_allocated_memory += size; p.block = new Block(p.device(), p.stream(), size, p.pool, (char*)ptr); for_each_selected_stat_type(p.stat_types, [&](size_t stat_type) { - increase_stat(stats.segment[stat_type], 1); - increase_stat(stats.reserved_bytes[stat_type], size); + stats.segment[stat_type].increase(1); + stats.reserved_bytes[stat_type].increase(size); }); if (size >= CUDAAllocatorConfig::max_split_size()) - increase_stat(stats.oversize_segments, 1); + stats.oversize_segments.increase(1); // p.block came from new, not cudaMalloc. It should not be nullptr here. TORCH_INTERNAL_ASSERT(p.block != nullptr && p.block->ptr != nullptr); @@ -2846,12 +2808,12 @@ class DeviceCachingAllocator { StatTypes stat_types = get_stat_types_for_pool(*pool); for_each_selected_stat_type(stat_types, [&](size_t stat_type) { - decrease_stat(stats.segment[stat_type], 1); - decrease_stat(stats.reserved_bytes[stat_type], block->size); + stats.segment[stat_type].decrease(1); + stats.reserved_bytes[stat_type].decrease(block->size); }); if (block->size >= CUDAAllocatorConfig::max_split_size()) - decrease_stat(stats.oversize_segments, 1); + stats.oversize_segments.decrease(1); pool->blocks.erase(block); delete block; } @@ -2903,7 +2865,7 @@ class DeviceCachingAllocator { total_allocated_memory -= unmapped.size; StatTypes stat_types = get_stat_types_for_pool(*block->pool); for_each_selected_stat_type(stat_types, [&](size_t stat_type) { - decrease_stat(stats.reserved_bytes[stat_type], unmapped.size); + stats.reserved_bytes[stat_type].decrease(unmapped.size); }); if (block->pool->owner_PrivatePool) { @@ -3732,25 +3694,6 @@ void local_raw_delete(void* ptr) { } } // namespace Native -// Size pretty-printer -std::string format_size(uint64_t size) { - std::ostringstream os; - os.precision(2); - os << std::fixed; - if (size <= 1024) { - os << size << " bytes"; - } else if (size <= 1048576) { - os << (static_cast(size) / 1024.0); - os << " KiB"; - } else if (size <= 1073741824ULL) { - os << static_cast(size) / 1048576.0; - os << " MiB"; - } else { - os << static_cast(size) / 1073741824.0; - os << " GiB"; - } - return os.str(); -} namespace CudaMallocAsync { // If this is put in its own header file, it gets incorrectly renamed in HIPify. diff --git a/c10/cuda/CUDACachingAllocator.h b/c10/cuda/CUDACachingAllocator.h index 72617bcaf3a944..9d1a631f3750ba 100644 --- a/c10/cuda/CUDACachingAllocator.h +++ b/c10/cuda/CUDACachingAllocator.h @@ -1,6 +1,6 @@ #pragma once -#include +#include #include #include #include @@ -50,73 +50,6 @@ namespace c10::cuda::CUDACachingAllocator { extern const size_t kLargeBuffer; -struct Stat { - int64_t current = 0; - int64_t peak = 0; - int64_t allocated = 0; - int64_t freed = 0; -}; - -enum struct StatType : uint64_t { - AGGREGATE = 0, - SMALL_POOL = 1, - LARGE_POOL = 2, - NUM_TYPES = 3 // remember to update this whenever a new stat type is added -}; - -typedef std::array(StatType::NUM_TYPES)> StatArray; - -// Struct containing memory allocator summary statistics for a device. -struct DeviceStats { - // COUNT: allocations requested by client code - StatArray allocation; - // COUNT: number of allocated segments from cudaMalloc(). - StatArray segment; - // COUNT: number of active memory blocks (allocated or used by stream) - StatArray active; - // COUNT: number of inactive, split memory blocks (unallocated but can't be - // released via cudaFree) - StatArray inactive_split; - - // SUM: bytes allocated by this memory alocator - StatArray allocated_bytes; - // SUM: bytes reserved by this memory allocator (both free and used) - StatArray reserved_bytes; - // SUM: bytes within active memory blocks - StatArray active_bytes; - // SUM: bytes within inactive, split memory blocks - StatArray inactive_split_bytes; - // SUM: bytes requested by client code - StatArray requested_bytes; - - // COUNT: total number of failed calls to CUDA malloc necessitating cache - // flushes. - int64_t num_alloc_retries = 0; - - // COUNT: total number of OOMs (i.e. failed calls to CUDA after cache flush) - int64_t num_ooms = 0; - - // COUNT: total number of oversize blocks allocated from pool - Stat oversize_allocations; - - // COUNT: total number of oversize blocks requiring malloc - Stat oversize_segments; - - // COUNT: total number of synchronize_and_free_events() calls - int64_t num_sync_all_streams = 0; - - // COUNT: total number of CUDA allocation calls. This includes both cuMemMap - // and cudaMalloc. - int64_t num_device_alloc = 0; - - // COUNT: total number of CUDA free calls. This includes both cuMemUnmap - // and cudaFree. - int64_t num_device_free = 0; - - // SIZE: maximum block size that is allowed to be split. - int64_t max_split_size = 0; -}; - typedef std::shared_ptr (*CreateContextFn)(); // Struct containing info of an allocation block (i.e. a fractional part of a @@ -247,9 +180,6 @@ enum struct RecordContext { ALL = 3, // additionally record stacks for when something is freed }; -// Size pretty-printer -std::string format_size(uint64_t size); - using OutOfMemoryObserver = std::functionrecordStream(dataPtr, stream); } -inline DeviceStats getDeviceStats(c10::DeviceIndex device) { +inline c10::CachingDeviceAllocator::DeviceStats getDeviceStats( + c10::DeviceIndex device) { return get()->getDeviceStats(device); } diff --git a/c10/cuda/CUDAMallocAsyncAllocator.cpp b/c10/cuda/CUDAMallocAsyncAllocator.cpp index 3fe414b55e30a2..3a7b485ebce223 100644 --- a/c10/cuda/CUDAMallocAsyncAllocator.cpp +++ b/c10/cuda/CUDAMallocAsyncAllocator.cpp @@ -11,6 +11,8 @@ namespace c10::cuda::CUDACachingAllocator::CudaMallocAsync { +using namespace c10::CachingDeviceAllocator; + #if CUDA_VERSION >= 11040 // CUDA device allocator that uses cudaMallocAsync to implement // the same interface as CUDACachingAllocator.cpp. diff --git a/torch/csrc/cuda/CUDAPluggableAllocator.cpp b/torch/csrc/cuda/CUDAPluggableAllocator.cpp index c6af163481481e..5220e86233bd67 100644 --- a/torch/csrc/cuda/CUDAPluggableAllocator.cpp +++ b/torch/csrc/cuda/CUDAPluggableAllocator.cpp @@ -210,8 +210,8 @@ void CUDAPluggableAllocator::recordStream( } } -c10::cuda::CUDACachingAllocator::DeviceStats CUDAPluggableAllocator:: - getDeviceStats(c10::DeviceIndex device) { +c10::CachingDeviceAllocator::DeviceStats CUDAPluggableAllocator::getDeviceStats( + c10::DeviceIndex device) { TORCH_CHECK( false, "CUDAPluggableAllocator does not yet support getDeviceStats. " diff --git a/torch/csrc/cuda/CUDAPluggableAllocator.h b/torch/csrc/cuda/CUDAPluggableAllocator.h index 70014b2fa0b37b..8652ef0f2bfde8 100644 --- a/torch/csrc/cuda/CUDAPluggableAllocator.h +++ b/torch/csrc/cuda/CUDAPluggableAllocator.h @@ -115,7 +115,7 @@ struct TORCH_CUDA_CPP_API CUDAPluggableAllocator void recordStream(const c10::DataPtr&, streamType stream) override; - c10::cuda::CUDACachingAllocator::DeviceStats getDeviceStats( + c10::CachingDeviceAllocator::DeviceStats getDeviceStats( c10::DeviceIndex device) override; void resetAccumulatedStats(c10::DeviceIndex device) override; void resetPeakStats(c10::DeviceIndex device) override; diff --git a/torch/csrc/cuda/Module.cpp b/torch/csrc/cuda/Module.cpp index 1cc6f7378e80c4..0106eebb1cd6ea 100644 --- a/torch/csrc/cuda/Module.cpp +++ b/torch/csrc/cuda/Module.cpp @@ -554,10 +554,10 @@ PyObject* THCPModule_memoryStats(PyObject* _unused, PyObject* arg) { TORCH_CHECK(THPUtils_checkLong(arg), "invalid argument to memory_allocated"); const auto device_index = THPUtils_unpackDeviceIndex(arg); - using c10::cuda::CUDACachingAllocator::DeviceStats; - using c10::cuda::CUDACachingAllocator::Stat; - using c10::cuda::CUDACachingAllocator::StatArray; - using c10::cuda::CUDACachingAllocator::StatType; + using c10::CachingDeviceAllocator::DeviceStats; + using c10::CachingDeviceAllocator::Stat; + using c10::CachingDeviceAllocator::StatArray; + using c10::CachingDeviceAllocator::StatType; const auto statToDict = [](const Stat& stat) { py::dict dict; From 0323141fed065e2d6dbdaab41c93ea979ce8004e Mon Sep 17 00:00:00 2001 From: Tristan Rice Date: Wed, 28 Aug 2024 01:40:42 +0000 Subject: [PATCH 0127/1018] c10d/logging: add C10D_LOCK_GUARD (#134131) This adds logs if we can't acquire locks in NCCLUtils and ProcessGroupNCCL for 30s. This is motivated by some deadlocks were seeing and it's unclear if it's in NCCL or on the PyTorch side of things. This required replacing most `std::mutex` with `std::timed_mutex` and `std::condition_variable_any` as appropriate. Test plan: existing CI for regressions will add unit tests on `C10D_LOCK_GUARD` Pull Request resolved: https://github.com/pytorch/pytorch/pull/134131 Approved by: https://github.com/c-p-i-o, https://github.com/fduwjj --- build_variables.bzl | 1 + caffe2/CMakeLists.txt | 2 + test/cpp/c10d/CMakeLists.txt | 2 + test/cpp/c10d/LoggingTest.cpp | 54 +++++++++++++++++++ test/cpp/c10d/ProcessGroupNCCLErrorsTest.cpp | 4 +- torch/csrc/distributed/c10d/LockGuard.cpp | 29 ++++++++++ torch/csrc/distributed/c10d/LockGuard.hpp | 32 +++++++++++ torch/csrc/distributed/c10d/NCCLUtils.cpp | 10 ++-- torch/csrc/distributed/c10d/NCCLUtils.hpp | 21 ++++---- .../distributed/c10d/ProcessGroupGloo.cpp | 2 +- .../distributed/c10d/ProcessGroupNCCL.cpp | 50 ++++++++--------- .../distributed/c10d/ProcessGroupNCCL.hpp | 18 +++---- torch/csrc/distributed/c10d/Work.cpp | 13 ++--- torch/csrc/distributed/c10d/Work.hpp | 4 +- 14 files changed, 183 insertions(+), 59 deletions(-) create mode 100644 test/cpp/c10d/LoggingTest.cpp create mode 100644 torch/csrc/distributed/c10d/LockGuard.cpp create mode 100644 torch/csrc/distributed/c10d/LockGuard.hpp diff --git a/build_variables.bzl b/build_variables.bzl index bf6d6cb14bd364..85e7031d3c9098 100644 --- a/build_variables.bzl +++ b/build_variables.bzl @@ -495,6 +495,7 @@ libtorch_distributed_base_sources = [ "torch/csrc/distributed/c10d/Functional.cpp", "torch/csrc/distributed/c10d/GlooDeviceFactory.cpp", "torch/csrc/distributed/c10d/GroupRegistry.cpp", + "torch/csrc/distributed/c10d/LockGuard.cpp", "torch/csrc/distributed/c10d/Ops.cpp", "torch/csrc/distributed/c10d/ParamCommsUtils.cpp", "torch/csrc/distributed/c10d/PrefixStore.cpp", diff --git a/caffe2/CMakeLists.txt b/caffe2/CMakeLists.txt index 5ee9968ba484fa..d44a8da210462f 100644 --- a/caffe2/CMakeLists.txt +++ b/caffe2/CMakeLists.txt @@ -927,6 +927,7 @@ elseif(USE_CUDA) set(CUDA_LINK_LIBRARIES_KEYWORD) torch_compile_options(torch_cuda) # see cmake/public/utils.cmake target_compile_definitions(torch_cuda PRIVATE USE_CUDA) + target_link_libraries(torch_cuda PRIVATE fmt::fmt-header-only) if(USE_CUFILE) target_link_libraries(torch_cuda PRIVATE torch::cufile) @@ -1333,6 +1334,7 @@ if(USE_ROCM) ${ROCM_SOURCE_DIR}/rocblas/include ${ROCM_SOURCE_DIR}/hipsparse/include ) + target_link_libraries(torch_hip PRIVATE fmt::fmt-header-only) if(USE_FLASH_ATTENTION) target_compile_definitions(torch_hip PRIVATE USE_FLASH_ATTENTION) endif() diff --git a/test/cpp/c10d/CMakeLists.txt b/test/cpp/c10d/CMakeLists.txt index 0874852517e33b..c64adb4ad0521e 100644 --- a/test/cpp/c10d/CMakeLists.txt +++ b/test/cpp/c10d/CMakeLists.txt @@ -10,12 +10,14 @@ function(c10d_add_test test_src) add_executable(${test_name} "${test_src}") target_include_directories(${test_name} PRIVATE $) target_link_libraries(${test_name} ${ARGN}) + target_link_libraries(${test_name} fmt::fmt-header-only) if(NOT WIN32) target_link_libraries(${test_name} pthread) endif() add_test(NAME ${test_name} COMMAND $) endfunction() +c10d_add_test(LoggingTest.cpp torch_cpu gtest_main) c10d_add_test(BackoffTest.cpp torch_cpu gtest_main) c10d_add_test(FileStoreTest.cpp torch_cpu gtest_main) c10d_add_test(TCPStoreTest.cpp torch_cpu gtest_main) diff --git a/test/cpp/c10d/LoggingTest.cpp b/test/cpp/c10d/LoggingTest.cpp new file mode 100644 index 00000000000000..1b8664fb4c780c --- /dev/null +++ b/test/cpp/c10d/LoggingTest.cpp @@ -0,0 +1,54 @@ +#include + +#include +#include + +#include +#include + +TEST(LockGuard, basic) { + std::timed_mutex mutex; + + { + C10D_LOCK_GUARD(lock, mutex); + + // already locked + ASSERT_FALSE(mutex.try_lock()); + } + + ASSERT_TRUE(mutex.try_lock()); + mutex.unlock(); +} + +TEST(LockGuard, logging) { + // set log level to INFO + FLAGS_caffe2_log_level = 0; + + std::timed_mutex mutex; + + mutex.lock(); + + auto loggingThread = std::async(std::launch::async, [&]() { + std::unique_lock name{mutex, std::defer_lock}; + ::c10d::detail::lockWithLogging( + name, std::chrono::milliseconds(10), "my lock", __FILE__, __LINE__); + }); + + auto deadline = std::chrono::system_clock::now() + std::chrono::seconds(10); + while (true) { + ASSERT_LT(std::chrono::system_clock::now(), deadline); + + testing::internal::CaptureStderr(); + std::this_thread::sleep_for(std::chrono::milliseconds(20)); + std::string output = testing::internal::GetCapturedStderr(); + + if (output.find("my lock: waiting for lock for 10ms") != + std::string::npos) { + break; + } + } + + mutex.unlock(); + + loggingThread.get(); +} diff --git a/test/cpp/c10d/ProcessGroupNCCLErrorsTest.cpp b/test/cpp/c10d/ProcessGroupNCCLErrorsTest.cpp index d416847f7911a5..54a929ab982514 100644 --- a/test/cpp/c10d/ProcessGroupNCCLErrorsTest.cpp +++ b/test/cpp/c10d/ProcessGroupNCCLErrorsTest.cpp @@ -180,7 +180,7 @@ class ProcessGroupNCCLNoHeartbeatCaught : ProcessGroupNCCLTimedOutErrors(store, rank, size, opts), hasMonitorThreadCaughtError_(false) {} - std::mutex& getWatchdogMutex() { + std::timed_mutex& getWatchdogMutex() { return workMetaListMutex_; } @@ -413,7 +413,7 @@ TEST_F(ProcessGroupNCCLErrorsTest, testNCCLErrorsNoHeartbeat) { work = pg.allreduce(tensors_); { // Now run all reduce with errors. - std::lock_guard lock(pg.getWatchdogMutex()); + std::lock_guard lock(pg.getWatchdogMutex()); LOG(INFO) << "Lock watchdog thread."; // Wait long enough before monitor thread throws exceptions. std::this_thread::sleep_for( diff --git a/torch/csrc/distributed/c10d/LockGuard.cpp b/torch/csrc/distributed/c10d/LockGuard.cpp new file mode 100644 index 00000000000000..6ce3ea736f3b3f --- /dev/null +++ b/torch/csrc/distributed/c10d/LockGuard.cpp @@ -0,0 +1,29 @@ +// Copyright (c) Meta Platforms, Inc. and its affiliates. +// All rights reserved. +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + +#include + +#include + +namespace c10d::detail { + +void lockWithLogging( + std::unique_lock& lock, + std::chrono::milliseconds log_interval, + const char* desc, + const char* file, + int line) { + while (!lock.try_lock_for(log_interval)) { + C10D_WARNING( + "{}:{} {}: waiting for lock for {}ms", + file, + line, + desc, + log_interval.count()); + } +} + +} // namespace c10d::detail diff --git a/torch/csrc/distributed/c10d/LockGuard.hpp b/torch/csrc/distributed/c10d/LockGuard.hpp new file mode 100644 index 00000000000000..d2a4dd9b3743d8 --- /dev/null +++ b/torch/csrc/distributed/c10d/LockGuard.hpp @@ -0,0 +1,32 @@ +// Copyright (c) Meta Platforms, Inc. and its affiliates. +// All rights reserved. +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + +#pragma once + +#include +#include + +#include + +namespace c10d::detail { + +// logWithLogging is a wrapper around std::unique_lock +// that automatically logs if the lock cannot be acquired within a given +// timeout. +TORCH_API void lockWithLogging( + std::unique_lock& lock, + std::chrono::milliseconds log_interval, + const char* desc, + const char* file, + int line); + +} // namespace c10d::detail + +// TODO: use std::source_location() when we can use C++20 +#define C10D_LOCK_GUARD(name, mutex) \ + std::unique_lock name{mutex, std::defer_lock}; \ + ::c10d::detail::lockWithLogging( \ + name, std::chrono::seconds(30), #mutex, __FILE__, __LINE__) diff --git a/torch/csrc/distributed/c10d/NCCLUtils.cpp b/torch/csrc/distributed/c10d/NCCLUtils.cpp index 16cb91015ae63d..46a2426fd5f4c0 100644 --- a/torch/csrc/distributed/c10d/NCCLUtils.cpp +++ b/torch/csrc/distributed/c10d/NCCLUtils.cpp @@ -21,7 +21,7 @@ constexpr int64_t kCommInitBusyWaitMillis = 10; namespace c10d { ncclComm_t NCCLComm::getNcclComm() { - std::unique_lock lock(mutex_); + C10D_LOCK_GUARD(lock, mutex_); if (aborted_) { auto commFailureMsg = commFailureReason_ != std::nullopt ? c10::str(" Original reason for failure was: ", *commFailureReason_) @@ -391,7 +391,7 @@ std::optional NCCLTraceBuffer::record( } auto traceback = torch::CapturedTraceback::gather(true, true, capture_cpp_stack_); - std::lock_guard guard(mutex_); + C10D_LOCK_GUARD(guard, mutex_); auto te = Entry{ id_, @@ -448,7 +448,7 @@ void NCCLTraceBuffer::record_pg_ranks( if (!enabled_) { return; } - std::lock_guard guard(mutex_); + C10D_LOCK_GUARD(guard, mutex_); pg_name_to_ranks_[pg_name] = ranks; } @@ -468,7 +468,7 @@ void NCCLTraceBuffer::update_state(Entry& r) { } std::vector NCCLTraceBuffer::dump_entries() { - std::lock_guard guard(mutex_); + C10D_LOCK_GUARD(guard, mutex_); std::vector result; result.reserve(entries_.size()); result.insert(result.end(), entries_.begin() + next_, entries_.end()); @@ -493,7 +493,7 @@ void NCCLTraceBuffer::retire_id( Event* endEvent = nullptr; std::optional duration = std::nullopt; - std::unique_lock guard(mutex_); + C10D_LOCK_GUARD(guard, mutex_); Entry* entry = &entries_.at(*id % max_entries_); if (entry->id_ == *id) { diff --git a/torch/csrc/distributed/c10d/NCCLUtils.hpp b/torch/csrc/distributed/c10d/NCCLUtils.hpp index c80742d90a4db5..e45d9d09b25379 100644 --- a/torch/csrc/distributed/c10d/NCCLUtils.hpp +++ b/torch/csrc/distributed/c10d/NCCLUtils.hpp @@ -14,6 +14,7 @@ #include #include #include +#include #include #include @@ -271,7 +272,7 @@ class NCCLComm { ~NCCLComm() noexcept { // Add lock in this destructor, as aborted_ needs to be read after memory // barrier here. - std::unique_lock lock(mutex_); + C10D_LOCK_GUARD(lock, mutex_); if (ncclComm_ && initialized_ && !aborted_) { #ifdef ENABLE_NCCL_ERROR_CHECKING // Use ncclCommAbort instead of ncclCommDestroy here since @@ -363,7 +364,7 @@ class NCCLComm { NCCLComm(NCCLComm&& other) { // Using other's lock, as it reads other's states // Can not use this.mutex_, as this object is being constructed. - std::unique_lock lock(other.mutex_); + C10D_LOCK_GUARD(lock, other.mutex_); std::swap(ncclComm_, other.ncclComm_); std::swap(aborted_, other.aborted_); std::swap(ncclAsyncErr_, other.ncclAsyncErr_); @@ -373,13 +374,13 @@ class NCCLComm { ncclComm_t getNcclComm(); std::optional getNcclCommFailureReason() const { - std::unique_lock lock(mutex_); + C10D_LOCK_GUARD(lock, mutex_); return commFailureReason_; } void ncclCommAbort( std::optional commFailureReason = std::nullopt) { - std::unique_lock lock(mutex_); + C10D_LOCK_GUARD(lock, mutex_); #ifdef ENABLE_NCCL_ERROR_CHECKING if (aborted_ && !initialized_) { // Should not abort twice. @@ -427,7 +428,7 @@ class NCCLComm { } bool isAborted() const { - std::unique_lock lock(mutex_); + C10D_LOCK_GUARD(lock, mutex_); return aborted_; } @@ -436,7 +437,7 @@ class NCCLComm { } ncclResult_t checkForNcclError() { - std::unique_lock lock(mutex_); + C10D_LOCK_GUARD(lock, mutex_); #ifdef ENABLE_NCCL_ERROR_CHECKING if (ncclAsyncErr_ != ncclSuccess) { return ncclAsyncErr_; @@ -451,7 +452,7 @@ class NCCLComm { } ncclResult_t registerSegment(void* ptr, size_t size) { - std::unique_lock lock(mutex_); + C10D_LOCK_GUARD(lock, mutex_); #ifdef NCCL_HAS_COMM_REGISTER // We register only segments from cache allocator // which are guaranteed to be with disjoint addr ranges. Thus, a ptr always @@ -482,7 +483,7 @@ class NCCLComm { } ncclResult_t deregisterSegment(void* ptr) { - std::unique_lock lock(mutex_); + C10D_LOCK_GUARD(lock, mutex_); #ifdef NCCL_HAS_COMM_REGISTER TORCH_CHECK( registeredSegmentHandles_.count(ptr) == 1, @@ -519,7 +520,7 @@ class NCCLComm { bool aborted_; uint64_t ncclCommSplitCounter_{0}; ncclResult_t ncclAsyncErr_; - mutable std::mutex mutex_; + mutable std::timed_mutex mutex_; // Rank that this communicator corresponds to. int rank_; // Optional reason for communicator failure, provided by ProcessGroupNCCL for @@ -638,7 +639,7 @@ struct NCCLTraceBuffer { bool enabled_ = false; bool capture_cpp_stack_ = false; - std::mutex mutex_; + std::timed_mutex mutex_; std::vector entries_; size_t max_entries_ = 0; size_t next_ = 0; diff --git a/torch/csrc/distributed/c10d/ProcessGroupGloo.cpp b/torch/csrc/distributed/c10d/ProcessGroupGloo.cpp index 51fa248ec403b9..96ea93f7413b6c 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupGloo.cpp +++ b/torch/csrc/distributed/c10d/ProcessGroupGloo.cpp @@ -602,7 +602,7 @@ uint64_t ProcessGroupGloo::RecvWork::getSequencenumber() const { } int ProcessGroupGloo::RecvWork::sourceRank() const { - std::lock_guard lock(mutex_); + std::lock_guard lock(mutex_); return srcRank_; } diff --git a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp index 0d3b16eab127ad..ce2f60f4bc0a95 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp +++ b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp @@ -23,6 +23,7 @@ #include #include #include +#include #include #include #include @@ -301,7 +302,7 @@ inline void errorIfCapturingNonCapturableNCCL(c10::cuda::CaptureStatus status) { // hooks are called outside the scope of any PG, thus we need traverse // communicators in all PGs. static std::unordered_map, int> ncclCommDevIdxMap; -static std::mutex ncclCommDevIdxMapMutex; +static std::timed_mutex ncclCommDevIdxMapMutex; static bool allocatorHooksAttached = false; std::atomic ProcessGroupNCCL::shouldDump_(false); @@ -314,7 +315,7 @@ void cacheAllocatorRegisterHook( return; } - std::lock_guard lock(ncclCommDevIdxMapMutex); + C10D_LOCK_GUARD(lock, ncclCommDevIdxMapMutex); for (auto& it : ncclCommDevIdxMap) { auto& ncclComm = it.first; auto& devIdx = it.second; @@ -332,7 +333,7 @@ void cacheAllocatorDeregisterHook( return; } - std::lock_guard lock(ncclCommDevIdxMapMutex); + C10D_LOCK_GUARD(lock, ncclCommDevIdxMapMutex); for (auto& it : ncclCommDevIdxMap) { auto& ncclComm = it.first; auto& devIdx = it.second; @@ -551,7 +552,7 @@ void ProcessGroupNCCL::WorkNCCL::checkAndSetException() { } auto exception_ptr = checkForNCCLErrors(); - std::unique_lock lock(mutex_); + C10D_LOCK_GUARD(lock, mutex_); exception_ = exception_ptr; if (exception_) { LOG(ERROR) << logPrefix() << "Collective " << *this @@ -567,7 +568,7 @@ const std::string& ProcessGroupNCCL::WorkNCCL::logPrefix() const { void ProcessGroupNCCL::WorkNCCL::setException( std::exception_ptr exception_ptr) { - std::unique_lock lock(mutex_); + C10D_LOCK_GUARD(lock, mutex_); exception_ = exception_ptr; } @@ -776,12 +777,12 @@ ProcessGroupNCCL::CUDAEventCache::CUDAEventCache() {} std::shared_ptr ProcessGroupNCCL::CUDAEventCache::create( bool timing) { auto deleter = [this, timing](at::cuda::CUDAEvent* event) { - std::lock_guard lock(this->cacheMutex_); + C10D_LOCK_GUARD(lock, this->cacheMutex_); this->eventsArray_[timing ? 1 : 0].push_back(event); }; at::cuda::CUDAEvent* event = nullptr; { - std::lock_guard lock(cacheMutex_); + C10D_LOCK_GUARD(lock, cacheMutex_); auto events = eventsArray_[timing ? 1 : 0]; if (!events.empty()) { event = events.back(); @@ -1086,8 +1087,9 @@ void ProcessGroupNCCL::waitForPendingWorks() { while (true) { { std::lock(workMetaListMutex_, completedWorkListMutex_); - std::lock_guard lockWork(workMetaListMutex_, std::adopt_lock); - std::lock_guard lockHook( + std::lock_guard lockWork( + workMetaListMutex_, std::adopt_lock); + std::lock_guard lockHook( completedWorkListMutex_, std::adopt_lock); if (workMetaList_.empty() && completedWorkList_.empty()) { @@ -1207,7 +1209,7 @@ bool ProcessGroupNCCL::abort(std::optional abortReason) { } ncclCommDevIdxMapMutex.unlock(); - std::lock_guard lock(mutex_); + C10D_LOCK_GUARD(lock, mutex_); abortCommsFromMap(devNCCLCommMap_, abortReason); abortCommsFromMap(inInitializationCommMap_, abortReason); return true; @@ -1276,8 +1278,8 @@ bool ProcessGroupNCCL::dumpDebuggingInfo() { // Serialize all calls to this function to avoid corrupting data, but allow // multiple calls in one runtime. User is responsible for preserving the // output file from an earlier call before a later call overwrites it. - static std::mutex writeDebugInfoMutex; - std::lock_guard lock(writeDebugInfoMutex); + static std::timed_mutex writeDebugInfoMutex; + C10D_LOCK_GUARD(lock, writeDebugInfoMutex); LOG(ERROR) << logPrefix() << "ProcessGroupNCCL preparing to dump debug info."; if (ncclTraceBufferSize_ > 0) { // We dump nccl trace into local disk by default and users can register @@ -1356,7 +1358,7 @@ void ProcessGroupNCCL::heartbeatMonitor() { // This won't have any lock since this lock is only used here. // Please be aware that mutex `monitorMutex_` should not be used // somewhere else to avoid the deadlock. - std::unique_lock lock(monitorMutex_); + C10D_LOCK_GUARD(lock, monitorMutex_); if (monitorWakeUpCV_.wait_for( lock, std::chrono::milliseconds(monitorPollInterval), [&] { return terminateHeartbeatMonitorThread_.load(); @@ -1681,7 +1683,7 @@ const std::vector& ProcessGroupNCCL::groupRanks() const { void ProcessGroupNCCL::addEphemeralTimeout( const std::chrono::milliseconds& timeout) { - std::lock_guard timeoutLock(mtxTimeoutExtension_); + C10D_LOCK_GUARD(timeoutLock, mtxTimeoutExtension_); ephemeralTimeoutActive_ += timeout; } @@ -1704,7 +1706,7 @@ void ProcessGroupNCCL::watchdogHandler() { std::list completedWorkList; while (!done || !terminateProcessGroup_.load()) { - std::unique_lock lock(workMetaListMutex_); + C10D_LOCK_GUARD(lock, workMetaListMutex_); // We busy-poll the work vector every kWatchdogThreadSleepMillis // milliseconds as long as the atomic is True. workMetaListCV_.wait_for( @@ -1876,7 +1878,7 @@ void ProcessGroupNCCL::watchdogHandler() { if (work.isCompleted()) { { // Reset the timeout and first work if the work is completed. - std::lock_guard timeoutLock(mtxTimeoutExtension_); + C10D_LOCK_GUARD(timeoutLock, mtxTimeoutExtension_); if (work.ownedEphermeralTimeout_.count() > 0) { ephemeralTimeoutActive_ -= work.ownedEphermeralTimeout_; ephemeralTimeoutInflight_ -= work.ownedEphermeralTimeout_; @@ -1891,7 +1893,7 @@ void ProcessGroupNCCL::watchdogHandler() { // Move Work object to completedWorkList_ to be consumed by the hook // thread { - const std::lock_guard lock(completedWorkListMutex_); + C10D_LOCK_GUARD(lock, completedWorkListMutex_); completedWorkList_.splice( completedWorkList_.end(), workMetaList_, it++); } @@ -1919,7 +1921,7 @@ void ProcessGroupNCCL::runHookLoop() { bool done = false; while (!done || !terminateProcessGroup_.load()) { - std::unique_lock lock(completedWorkListMutex_); + C10D_LOCK_GUARD(lock, completedWorkListMutex_); // We busy-poll the work vector every kWatchdogThreadSleepMillis // milliseconds as long as the atomic is True. completedWorkListCV_.wait_for( @@ -2092,7 +2094,7 @@ void ProcessGroupNCCL::broadcastUniqueNCCLID( } void ProcessGroupNCCL::destroyNCCLComms(const std::string& devNCCLCommMapKey) { - std::lock_guard lock(mutex_); + C10D_LOCK_GUARD(lock, mutex_); if (devNCCLCommMap_.find(devNCCLCommMapKey) == devNCCLCommMap_.end()) { TORCH_INTERNAL_ASSERT( false, @@ -2140,7 +2142,7 @@ std::shared_ptr ProcessGroupNCCL::getNCCLComm( usedDeviceIdxs_.insert(device.index()); { - std::lock_guard lock(mutex_); + C10D_LOCK_GUARD(lock, mutex_); if (devNCCLCommMap_.find(deviceKey) != devNCCLCommMap_.end()) { // Reuse the cached communicator if there is one. return devNCCLCommMap_[deviceKey]; @@ -2214,7 +2216,7 @@ std::shared_ptr ProcessGroupNCCL::getNCCLComm( options_->split_color != 0, "Must specify a non-zero color when splitting"); // Find a valid, healthy communicator to split from if possible. - std::lock_guard lock(options_->split_from->mutex_); + C10D_LOCK_GUARD(lock, options_->split_from->mutex_); auto& other_comms = options_->split_from->devNCCLCommMap_; auto dit = other_comms.find(getKeyFromDevice(device)); if (dit != other_comms.end()) { @@ -2268,7 +2270,7 @@ std::shared_ptr ProcessGroupNCCL::getNCCLComm( options_->is_high_priority_stream || force_high); { - std::lock_guard lock(mutex_); + C10D_LOCK_GUARD(lock, mutex_); inInitializationCommMap_.emplace(deviceKey, ncclComm); } @@ -2518,7 +2520,7 @@ void ProcessGroupNCCL::assignTimeoutToWork( const c10::intrusive_ptr& work, const c10::intrusive_ptr& option) { std::chrono::milliseconds timeout = option->timeout; - std::lock_guard timeoutLock(mtxTimeoutExtension_); + C10D_LOCK_GUARD(timeoutLock, mtxTimeoutExtension_); if (ephemeralTimeoutActive_.count() > 0) { timeout += ephemeralTimeoutActive_; } @@ -2531,7 +2533,7 @@ void ProcessGroupNCCL::assignTimeoutToWork( void ProcessGroupNCCL::workEnqueue( c10::intrusive_ptr work) { if (!terminateProcessGroup_.load()) { - std::lock_guard lock(workMetaListMutex_); + C10D_LOCK_GUARD(lock, workMetaListMutex_); // Avoid view tensors to be processed in cleanup thread. // View tensors' destruction invokes autograd_meta, which // needs to be destructed in user thread. Otherwise will diff --git a/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp b/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp index e099a75de58879..b6635157d75e1b 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp +++ b/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp @@ -449,7 +449,7 @@ class TORCH_API ProcessGroupNCCL : public Backend { static CUDAEventCache& get(); private: - std::mutex cacheMutex_; + std::timed_mutex cacheMutex_; // NOTE: We intentionaly store raw pointers so that // we do not attempt to destroy the event objects on process exit, // because cuda may be gone. @@ -920,7 +920,7 @@ class TORCH_API ProcessGroupNCCL : public Backend { // ephemeralTimeoutActive_/ephemeralTimeoutInflight_. // TODO(fduwjj): We need to have an audit on all mutexes we are adding here. // And consolidate them if possible. - std::mutex mtxTimeoutExtension_; + std::timed_mutex mtxTimeoutExtension_; // The ephemeral timeout added on top of existing timeout for works issued // before first work finishes. @@ -980,7 +980,7 @@ class TORCH_API ProcessGroupNCCL : public Backend { inInitializationCommMap_; // Mutex to guard maps like devNCCLCommMap_. - std::mutex mutex_; + std::timed_mutex mutex_; // Heartbeat of watchdog thread. std::atomic_uint64_t heartbeat_; @@ -1041,18 +1041,18 @@ class TORCH_API ProcessGroupNCCL : public Backend { static std::atomic shouldDump_; // Mutex to Guard workMetaList_ - std::mutex workMetaListMutex_; + std::timed_mutex workMetaListMutex_; // Mutex to Guard monitorWakeUpCV_ - std::mutex monitorMutex_; + std::timed_mutex monitorMutex_; bool writeDebugInfo_ = false; // Condition Variable for watchdog thread sleep - std::condition_variable workMetaListCV_; + std::condition_variable_any workMetaListCV_; // Condition Variable for monitor thread to wake up early - std::condition_variable monitorWakeUpCV_; + std::condition_variable_any monitorWakeUpCV_; // Vector to Store WorkNCCL pointers std::list workMetaList_; @@ -1060,10 +1060,10 @@ class TORCH_API ProcessGroupNCCL : public Backend { std::chrono::time_point lastWorkListUpdateTime_; // Mutex to Guard workMetaList_ - std::mutex completedWorkListMutex_; + std::timed_mutex completedWorkListMutex_; // Condition Variable for watchdog thread sleep - std::condition_variable completedWorkListCV_; + std::condition_variable_any completedWorkListCV_; std::list completedWorkList_; diff --git a/torch/csrc/distributed/c10d/Work.cpp b/torch/csrc/distributed/c10d/Work.cpp index 8beb8f29362080..11453f0454357a 100644 --- a/torch/csrc/distributed/c10d/Work.cpp +++ b/torch/csrc/distributed/c10d/Work.cpp @@ -1,5 +1,6 @@ #include +#include #include #include @@ -45,17 +46,17 @@ OpType Work::retrieveOpType() const { Work::~Work() = default; bool Work::isCompleted() { - std::lock_guard lock(mutex_); + C10D_LOCK_GUARD(lock, mutex_); return completed_; } bool Work::isSuccess() const { - std::lock_guard lock(mutex_); + C10D_LOCK_GUARD(lock, mutex_); return !exception_; } std::exception_ptr Work::exception() const { - std::lock_guard lock(mutex_); + C10D_LOCK_GUARD(lock, mutex_); return exception_; } @@ -73,7 +74,7 @@ std::vector Work::result() { void Work::synchronize() {} bool Work::wait(std::chrono::milliseconds timeout) { - std::unique_lock lock(mutex_); + C10D_LOCK_GUARD(lock, mutex_); if (timeout == kNoTimeout) { // This waits without a timeout. cv_.wait(lock, [&] { return completed_; }); @@ -103,7 +104,7 @@ c10::intrusive_ptr Work::getFuture() { } void Work::finish(std::exception_ptr exception) { - std::unique_lock lock(mutex_); + C10D_LOCK_GUARD(lock, mutex_); completed_ = true; exception_ = std::move(exception); if (recordFunctionEndCallback_) { @@ -115,7 +116,7 @@ void Work::finish(std::exception_ptr exception) { } void Work::finishAndThrow(std::exception_ptr exception) { - std::unique_lock lock(mutex_); + C10D_LOCK_GUARD(lock, mutex_); completed_ = true; exception_ = std::move(exception); if (recordFunctionEndCallback_) { diff --git a/torch/csrc/distributed/c10d/Work.hpp b/torch/csrc/distributed/c10d/Work.hpp index c10e5007b9f544..ea5853aefb4049 100644 --- a/torch/csrc/distributed/c10d/Work.hpp +++ b/torch/csrc/distributed/c10d/Work.hpp @@ -126,8 +126,8 @@ class TORCH_API Work : public torch::CustomClassHolder { // provided by the user. void finishAndThrow(std::exception_ptr exception); - mutable std::mutex mutex_; - std::condition_variable cv_; + mutable std::timed_mutex mutex_; + std::condition_variable_any cv_; bool completed_ = false; std::exception_ptr exception_; From 77d88d7c8c06a1fdda63c7f1dfbc6c64dc142842 Mon Sep 17 00:00:00 2001 From: Ke Wen Date: Fri, 23 Aug 2024 12:32:02 -0700 Subject: [PATCH 0128/1018] [BE] no need to print stream in comm abort (#134362) Strictly speaking, NCCL communicator has nothing to do with CUDA streams. Thus, we don't need to print stream in comm abort's message. Pull Request resolved: https://github.com/pytorch/pytorch/pull/134362 Approved by: https://github.com/fduwjj, https://github.com/wconstab --- torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp | 9 +-------- 1 file changed, 1 insertion(+), 8 deletions(-) diff --git a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp index ce2f60f4bc0a95..7b877de3d6de35 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp +++ b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp @@ -1183,15 +1183,8 @@ void ProcessGroupNCCL::abortCommsFromMap( // their responsibility to destroy the process group and recreate // it to recover from errors. - c10::StreamId streamId = -1; - if (ncclStreams_.find(devName) != ncclStreams_.end()) { - auto stream = ncclStreams_.at(devName); - streamId = stream.id(); - } - LOG(INFO) << logPrefix() << "ProcessGroupNCCL destroyed " - << " communicator on CUDA device: " << devName - << " with stream: " << streamId; + << " communicator on CUDA device: " << devName; } } From 93f148926d15f00e639ce3aee5bc038ca28c5231 Mon Sep 17 00:00:00 2001 From: cyy Date: Wed, 28 Aug 2024 03:22:01 +0000 Subject: [PATCH 0129/1018] [21/N] Fix clang-tidy warnings in jit (#134537) Follows #133399 Pull Request resolved: https://github.com/pytorch/pytorch/pull/134537 Approved by: https://github.com/Skylion007 --- torch/csrc/jit/api/function_impl.h | 1 - .../csrc/jit/codegen/fuser/cuda/fused_kernel.cpp | 16 ++++------------ torch/csrc/jit/codegen/fuser/cuda/fused_kernel.h | 9 ++++----- torch/csrc/jit/codegen/fuser/partition_desc.h | 3 +-- torch/csrc/jit/runtime/jit_trace.cpp | 10 ++++++---- torch/csrc/jit/runtime/jit_trace.h | 2 +- torch/csrc/jit/runtime/register_ops_utils.cpp | 6 +++--- torch/csrc/jit/runtime/register_prim_ops.cpp | 2 +- torch/csrc/jit/runtime/script_profile.cpp | 2 +- .../csrc/jit/runtime/symbolic_shape_registry.cpp | 4 ++-- torch/csrc/jit/runtime/symbolic_shape_registry.h | 2 +- torch/csrc/jit/serialization/pickler.h | 2 -- 12 files changed, 24 insertions(+), 35 deletions(-) diff --git a/torch/csrc/jit/api/function_impl.h b/torch/csrc/jit/api/function_impl.h index 01e7a3c98e3024..b5d336db2b6ec2 100644 --- a/torch/csrc/jit/api/function_impl.h +++ b/torch/csrc/jit/api/function_impl.h @@ -7,7 +7,6 @@ namespace torch::jit { struct TORCH_API GraphFunction : public Function { - // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) GraphFunction( c10::QualifiedName name, std::shared_ptr graph, diff --git a/torch/csrc/jit/codegen/fuser/cuda/fused_kernel.cpp b/torch/csrc/jit/codegen/fuser/cuda/fused_kernel.cpp index 8ff804c4601522..fd9fca230b9668 100644 --- a/torch/csrc/jit/codegen/fuser/cuda/fused_kernel.cpp +++ b/torch/csrc/jit/codegen/fuser/cuda/fused_kernel.cpp @@ -16,13 +16,9 @@ #include #include #include -#include #include -namespace torch { -namespace jit { -namespace fuser { -namespace cuda { +namespace torch::jit::fuser::cuda { // See NOTE [ USE OF NVRTC AND DRIVER API ] const at::cuda::NVRTC& nvrtc() { @@ -85,7 +81,6 @@ void codegenOutputQuery( } // Compiles the specified kernel and stores the metadata required to run it -// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) FusedKernelCUDA::FusedKernelCUDA( at::DeviceIndex device, std::string name, @@ -114,6 +109,7 @@ FusedKernelCUDA::FusedKernelCUDA( // Acquires device and NVRTC properties (for compile arch and occupancy // calculations) + // NOLINTNEXTLINE(cppcoreguidelines-prefer-member-initializer) prop_ = at::cuda::getCurrentDeviceProperties(); int major = 0, minor = 0; bool compile_to_sass = false; @@ -197,8 +193,7 @@ static int ceilDiv(const int a, const int b) { void FusedKernelCUDA::launch_raw( const uint32_t numel, std::vector& arguments) const { - // NOLINTNEXTLINE(bugprone-unused-raii) - at::cuda::CUDAGuard{device_}; + at::cuda::CUDAGuard guard{device_}; // Hacked at::DeviceGuard (see note above) const auto prior_device = at::cuda::current_device(); at::cuda::set_device(device_); @@ -269,7 +264,4 @@ static std::shared_ptr createFusionKernel( RegisterFusionBackend reg(DeviceType::CUDA, createFusionKernel); -} // namespace cuda -} // namespace fuser -} // namespace jit -} // namespace torch +} // namespace torch::jit::fuser::cuda diff --git a/torch/csrc/jit/codegen/fuser/cuda/fused_kernel.h b/torch/csrc/jit/codegen/fuser/cuda/fused_kernel.h index 5593ad570bbb21..d635049e758a2c 100644 --- a/torch/csrc/jit/codegen/fuser/cuda/fused_kernel.h +++ b/torch/csrc/jit/codegen/fuser/cuda/fused_kernel.h @@ -1,6 +1,5 @@ #pragma once -#include #include #include @@ -50,11 +49,11 @@ struct TORCH_CUDA_CU_API FusedKernelCUDA // Note: per device to store device properties and compute launch heuristics // Acquiring these values at launch time would be too slow at::DeviceIndex device_; - int maxBlocks_; - cudaDeviceProp* prop_; + int maxBlocks_{}; + cudaDeviceProp* prop_{}; std::vector ptx_; - CUmodule module_; - CUfunction function_; + CUmodule module_{}; + CUfunction function_{}; }; } // namespace torch::jit::fuser::cuda diff --git a/torch/csrc/jit/codegen/fuser/partition_desc.h b/torch/csrc/jit/codegen/fuser/partition_desc.h index ad42609e0ddc54..964e1821364a50 100644 --- a/torch/csrc/jit/codegen/fuser/partition_desc.h +++ b/torch/csrc/jit/codegen/fuser/partition_desc.h @@ -28,8 +28,7 @@ struct TORCH_API PartitionDesc { // so dim - 1 is no longer contiguous cont[dim_ - 1] = false; } - // NOLINTNEXTLINE(modernize-make-shared) - subTensorDesc_.reset(new TensorDesc(_desc.scalar_type, cont)); + subTensorDesc_ = std::make_shared(_desc.scalar_type, cont); } bool isNoop() const { diff --git a/torch/csrc/jit/runtime/jit_trace.cpp b/torch/csrc/jit/runtime/jit_trace.cpp index cff9d5f954e813..b25088b32ecae2 100644 --- a/torch/csrc/jit/runtime/jit_trace.cpp +++ b/torch/csrc/jit/runtime/jit_trace.cpp @@ -56,8 +56,8 @@ Node* traceNode(Node* node, TracingData& td, Stack& stack) { } void eraseAllOutputs(Node* opt_pn) { - // NOLINTNEXTLINE - for (int i = opt_pn->outputs().size() - 1; i >= 0; i--) { + for (auto i = static_cast(opt_pn->outputs().size()) - 1; i >= 0; + i--) { opt_pn->eraseOutput(i); } } @@ -275,10 +275,12 @@ void insertTracingNodes(Block* block, ProfilingRecord* pr, TracingData& td) { // nodes and the outputs of the node in the scripted graph. // There are a few subtleties with tracing Ifs and Loops // discussed above -std::shared_ptr TraceGraph(std::shared_ptr graph, Stack& stack) { +std::shared_ptr TraceGraph( + const std::shared_ptr& graph, + Stack& stack) { TracingData td; GRAPH_DUMP("Before Inline:", graph); - Inline(*graph.get()); + Inline(*graph); EliminateDeadCode(graph); GRAPH_DUMP("After Inline:", graph); auto pr = ProfilingRecord::instrumentGraph(graph); diff --git a/torch/csrc/jit/runtime/jit_trace.h b/torch/csrc/jit/runtime/jit_trace.h index 12be844e35a91a..9b29501eeb3f91 100644 --- a/torch/csrc/jit/runtime/jit_trace.h +++ b/torch/csrc/jit/runtime/jit_trace.h @@ -3,6 +3,6 @@ namespace torch::jit { TORCH_API std::shared_ptr TraceGraph( - std::shared_ptr graph, + const std::shared_ptr& graph, Stack& stack); } // namespace torch::jit diff --git a/torch/csrc/jit/runtime/register_ops_utils.cpp b/torch/csrc/jit/runtime/register_ops_utils.cpp index a057367af81c96..abbdf44ec60518 100644 --- a/torch/csrc/jit/runtime/register_ops_utils.cpp +++ b/torch/csrc/jit/runtime/register_ops_utils.cpp @@ -287,12 +287,12 @@ void listAdd(Stack& stack) { c10::List ret = make_result_list(a.elementType()); if (a.use_count() == 1) { - ret = std::move(a); + ret = a; } else { ret = a.copy(); } - ret.append(std::move(b)); + ret.append(b); push(stack, std::move(ret)); } @@ -300,7 +300,7 @@ void listAdd(Stack& stack) { void listInplaceAdd(Stack& stack) { c10::List b = pop(stack).to>(); c10::List a = pop(stack).to>(); - a.append(std::move(b)); + a.append(b); push(stack, std::move(a)); } diff --git a/torch/csrc/jit/runtime/register_prim_ops.cpp b/torch/csrc/jit/runtime/register_prim_ops.cpp index ed56e7d4322f61..d20c5d6a0fec54 100644 --- a/torch/csrc/jit/runtime/register_prim_ops.cpp +++ b/torch/csrc/jit/runtime/register_prim_ops.cpp @@ -830,7 +830,7 @@ static const std::vector opGenArgs{ ss << i; } drop(stack, num_inputs); - ss << std::endl; + ss << '\n'; auto* handler = getPrintHandler(); TORCH_INTERNAL_ASSERT(handler); handler(ss.str()); diff --git a/torch/csrc/jit/runtime/script_profile.cpp b/torch/csrc/jit/runtime/script_profile.cpp index c31f27223b8b7b..3ad4716d32b598 100644 --- a/torch/csrc/jit/runtime/script_profile.cpp +++ b/torch/csrc/jit/runtime/script_profile.cpp @@ -147,7 +147,7 @@ const ScriptProfile::SourceMap& ScriptProfile::dumpStats() { for (const auto& datapoint : datapoints_) { if (const auto& source = datapoint->sourceRange.source()) { if (auto fileLineCol = datapoint->sourceRange.file_line_col()) { - auto it = sourceMap_.find(*source.get()); + auto it = sourceMap_.find(*source); if (it == sourceMap_.end()) { it = sourceMap_.emplace(SourceRef{source}, LineMap{}).first; } diff --git a/torch/csrc/jit/runtime/symbolic_shape_registry.cpp b/torch/csrc/jit/runtime/symbolic_shape_registry.cpp index 837ce21095adeb..74f87e46757eaf 100644 --- a/torch/csrc/jit/runtime/symbolic_shape_registry.cpp +++ b/torch/csrc/jit/runtime/symbolic_shape_registry.cpp @@ -219,7 +219,7 @@ void checkInputAndOutputTypes( void transformShapeFunction( const FunctionSchema* schema_string, - std::shared_ptr graph) { + const std::shared_ptr& graph) { Inline(*graph); // ATEN operators can return multiple unboxed values, this in contrast to @@ -411,7 +411,7 @@ TORCH_API std::optional boundedGraphsForSchema( void RegisterShapeComputeGraphForSchema( const FunctionSchema& schema, - std::shared_ptr g) { + const std::shared_ptr& g) { std::lock_guard guard(lock); if (cached_schema_to_graph.empty()) { loadFunctions(); diff --git a/torch/csrc/jit/runtime/symbolic_shape_registry.h b/torch/csrc/jit/runtime/symbolic_shape_registry.h index a14d327aab4291..7222fd8bca3269 100644 --- a/torch/csrc/jit/runtime/symbolic_shape_registry.h +++ b/torch/csrc/jit/runtime/symbolic_shape_registry.h @@ -52,7 +52,7 @@ struct BoundedShapeGraphs { TORCH_API void RegisterShapeComputeGraphForSchema( const FunctionSchema& schema, - std::shared_ptr g); + const std::shared_ptr& g); TORCH_API std::optional> shapeComputeGraphForSchema( const FunctionSchema& schema); diff --git a/torch/csrc/jit/serialization/pickler.h b/torch/csrc/jit/serialization/pickler.h index ffe0c8330708e1..9be9b0fb2d8c1d 100644 --- a/torch/csrc/jit/serialization/pickler.h +++ b/torch/csrc/jit/serialization/pickler.h @@ -94,7 +94,6 @@ enum class PickleOpCode : char { using ::c10::IValue; -// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) struct WriteableTensorData { const char* data() const { return static_cast(tensor_.storage().data()); @@ -140,7 +139,6 @@ class TORCH_API Pickler { memoized_class_types_(memoized_class_types), get_tensor_id_(std::move(get_tensor_id)), tag_aggregates_(tag_aggregates) {} - // NOLINTNEXTLINE(bugprone-exception-escape) ~Pickler(); // Push protocol onto the stack From 8ea10ffa986ae7fbf667c901b3a7fa59d5a417b9 Mon Sep 17 00:00:00 2001 From: Yiming Zhou Date: Wed, 28 Aug 2024 03:45:46 +0000 Subject: [PATCH 0130/1018] Back out "[pytorch][PR] [export] Schematize nn_module_stack serialization" (#134628) Summary: Breaking backward compatibilities for serialization and deserialization Differential Revision: D61888223 Pull Request resolved: https://github.com/pytorch/pytorch/pull/134628 Approved by: https://github.com/angelayi --- test/export/test_serialize.py | 30 ------------------- torch/_export/serde/serialize.py | 51 +++++++++++++++++++------------- 2 files changed, 31 insertions(+), 50 deletions(-) diff --git a/test/export/test_serialize.py b/test/export/test_serialize.py index 9a2290fcba1164..1c31f64d2e0a2e 100644 --- a/test/export/test_serialize.py +++ b/test/export/test_serialize.py @@ -441,36 +441,6 @@ def forward(self, x): if "aten.sum.dim_IntList" in node.target: self.assertEqual(node.inputs[1].arg.type, "as_ints") - def test_nn_module_stack_serde_with_commas(self) -> None: - class M(torch.nn.Module): - def __init__(self): - super().__init__() - self.linear = torch.nn.Linear(2, 2) - self.relu = torch.nn.ReLU() - - def forward(self, x): - return self.relu(self.linear(x)) - - # export the model - ep = torch.export.export(M(), (torch.randn(2, 2),)) - # modify the nn_module_stack to contain comma(,) and semicolon(;) - for node in ep.graph.nodes: - if nn_module_stack := node.meta.get("nn_module_stack"): - for k, (p, t) in nn_module_stack.items(): - nn_module_stack[k] = (p + ";semicolon", t + ",comma") - node.meta["nn_module_stack"] = nn_module_stack - # serialize and deserialize the model - buffer = io.BytesIO() - save(ep, buffer) - buffer.seek(0) - loaded_ep = load(buffer) - # check that the output is the same - inp = (torch.randn(2, 2),) - exp_out = ep.module()(*inp) - actual_out = loaded_ep.module()(*inp) - self.assertEqual(exp_out, actual_out) - self.assertEqual(exp_out.requires_grad, actual_out.requires_grad) - @unittest.skipIf(IS_WINDOWS, "Windows not supported for this test") @unittest.skipIf(not torchdynamo.is_dynamo_supported(), "dynamo doesn't support") diff --git a/torch/_export/serde/serialize.py b/torch/_export/serde/serialize.py index f77d081d5aa85c..3b8803b776f6f9 100644 --- a/torch/_export/serde/serialize.py +++ b/torch/_export/serde/serialize.py @@ -10,6 +10,7 @@ import logging import math import operator +import re import typing import traceback @@ -578,23 +579,21 @@ def serialize_metadata(self, node: torch.fx.Node) -> Dict[str, str]: ret["stack_trace"] = stack_trace if nn_module_stack := node.meta.get("nn_module_stack"): - keys, paths, tys = [], [], [] - for k, v in nn_module_stack.items(): - keys.append(k) - assert isinstance(v, tuple) and len(v) == 2 - path, ty = v + + def export_nn_module_stack(val): + assert isinstance(val, tuple) and len(val) == 2 + path, ty = val + assert isinstance(path, str) assert isinstance(ty, str) - paths.append(path) - tys.append(ty) - nn_module_stack_dict = { - "key_list": keys, - "path_list": paths, - "type_list": tys - } + return path + "," + ty - ret["nn_module_stack"] = json.dumps(nn_module_stack_dict) + # Serialize to "key,orig_path,type_str" + nn_module_list = [ + f"{k},{export_nn_module_stack(v)}" for k, v in nn_module_stack.items() + ] + ret["nn_module_stack"] = ST_DELIMITER.join(nn_module_list) if source_fn_st := node.meta.get("source_fn_stack"): source_fn_list = [ @@ -2189,13 +2188,25 @@ def deserialize_meta_func(serialized_target: str): if nn_module_stack_str := metadata.get("nn_module_stack"): # Originally serialized to "key,orig_path,type_str" - nn_module_stack_dict = json.loads(nn_module_stack_str) - nn_module_stack = { - key: (path, ty) - for key, path, ty in zip( - nn_module_stack_dict["key_list"], - nn_module_stack_dict["path_list"], - nn_module_stack_dict["type_list"])} + def import_nn_module_stack(key, path, ty): + return key, (path, ty) + + # Helper function that splits strings by commas except for those + # encapsulated by parens, which are valid traces. + # TODO: Currently this is needed due to indexing Sequential + # layers introducing names in the form "layer.slice(1, None, None)". + # If that naming is improved, this fancier splitting can probably be + # reverted to a simple split by comma. + def metadata_split(metadata): + # Remove the parentheses and commas inside them + metadata = re.sub(r'\(.*?\)', '', metadata) + # Split the string by comma, except for those inside parentheses + return re.split(r'(? Date: Wed, 28 Aug 2024 04:05:06 +0000 Subject: [PATCH 0131/1018] [Reland] [1/N] Fix clang-tidy warnings in inductor (#134544) Reland #131979 and exclude aoti_torch_index_put_out changes. Pull Request resolved: https://github.com/pytorch/pytorch/pull/134544 Approved by: https://github.com/ColinPeppler --- .../inductor/aoti_eager/kernel_holder.cpp | 59 ++++++++----------- .../inductor/aoti_eager/kernel_meta_info.cpp | 37 +++++------- .../inductor/aoti_eager/kernel_meta_info.h | 7 +-- .../aoti_runner/model_container_runner.cpp | 6 +- .../model_container_runner_cpu.cpp | 2 +- .../model_container_runner_cuda.cpp | 2 +- .../inductor/aoti_torch/mkldnn_tensor.cpp | 6 +- .../inductor/aoti_torch/tensor_converter.cpp | 6 +- torch/csrc/inductor/inductor_ops.cpp | 6 +- torch/csrc/inductor/resize_storage_bytes.cpp | 6 +- 10 files changed, 58 insertions(+), 79 deletions(-) diff --git a/torch/csrc/inductor/aoti_eager/kernel_holder.cpp b/torch/csrc/inductor/aoti_eager/kernel_holder.cpp index aa9566acd95712..c364daa82e18b3 100644 --- a/torch/csrc/inductor/aoti_eager/kernel_holder.cpp +++ b/torch/csrc/inductor/aoti_eager/kernel_holder.cpp @@ -64,8 +64,8 @@ std::vector unpack_tensors( const c10::Device& device) { std::vector inputs; for (size_t idx = 0; idx < stack.size(); idx++) { - auto ivalue = stack[idx]; - auto ivalue_arg = arguments[idx]; + const auto& ivalue = stack[idx]; + const auto& ivalue_arg = arguments[idx]; if (ivalue.isTensor()) { unpack_tensor_ivalue(ivalue, device, inputs); } else if (ivalue.isTensorList()) { @@ -117,12 +117,10 @@ std::vector unpack_input_parameters( if (stack[idx].isScalar()) { // Beyond c10::Scalar, the floating value and interger value are also // represented as Scalar. - inputs_metadata.push_back( - ParameterMetadata(stack[idx].toScalar(), arg_order)); + inputs_metadata.emplace_back(stack[idx].toScalar(), arg_order); } else if (stack[idx].isTensorList()) { // tensor list - inputs_metadata.push_back( - ParameterMetadata(stack[idx].toTensorList().vec(), arg_order)); + inputs_metadata.emplace_back(stack[idx].toTensorList().vec(), arg_order); } else if (stack[idx].isOptionalTensorList()) { // optional tensor list: std::vector> std::vector tensor_list; @@ -131,27 +129,23 @@ std::vector unpack_input_parameters( tensor_list.push_back(item.toOptional().value()); } } - inputs_metadata.push_back(ParameterMetadata(tensor_list, arg_order)); + inputs_metadata.emplace_back(tensor_list, arg_order); } else if ( *arguments[idx].real_type() == *c10::getTypePtr>()) { // optional tensor if (stack[idx].toOptional().has_value()) { - inputs_metadata.push_back(ParameterMetadata( - stack[idx].toOptional().value(), arg_order)); + inputs_metadata.emplace_back( + stack[idx].toOptional().value(), arg_order); } } else if (stack[idx].isTensor()) { - inputs_metadata.push_back( - ParameterMetadata(stack[idx].toTensor(), arg_order)); + inputs_metadata.emplace_back(stack[idx].toTensor(), arg_order); } else if (stack[idx].isString()) { - inputs_metadata.push_back( - ParameterMetadata(stack[idx].toStringRef(), arg_order)); + inputs_metadata.emplace_back(stack[idx].toStringRef(), arg_order); } else if (stack[idx].isBool()) { - inputs_metadata.push_back( - ParameterMetadata(c10::Scalar(stack[idx].toBool()), arg_order)); + inputs_metadata.emplace_back(c10::Scalar(stack[idx].toBool()), arg_order); } else if (stack[idx].isDevice()) { - inputs_metadata.push_back( - ParameterMetadata(stack[idx].toDevice(), arg_order)); + inputs_metadata.emplace_back(stack[idx].toDevice(), arg_order); } else { TORCH_CHECK_NOT_IMPLEMENTED( false, @@ -239,7 +233,7 @@ void AOTIPythonKernelHolder::cache_hit( auto outputs = aoti_kernel_metadata.kernel_runner_->run(inputs); for (auto& output : outputs) { - stack->push_back(output); + stack->emplace_back(output); } } @@ -343,8 +337,7 @@ void AOTIPythonKernelHolder::init_aoti_kernel_cache() { auto tensor_metadata = build_tensor_metadata(metadata); test_list_metadata.push_back(tensor_metadata); } - parameter_metadata_list.push_back( - ParameterMetadata(test_list_metadata, arg_idx)); + parameter_metadata_list.emplace_back(test_list_metadata, arg_idx); } else if (is_scalar) { // Scalar auto metadata = item_metadata.cast(); @@ -367,14 +360,12 @@ void AOTIPythonKernelHolder::init_aoti_kernel_cache() { dtype_value); } - parameter_metadata_list.push_back( - ParameterMetadata(c10::Scalar(scalar), arg_idx)); + parameter_metadata_list.emplace_back(c10::Scalar(scalar), arg_idx); } else if (is_string) { // String auto metadata = item_metadata.cast(); auto str_value = metadata["string_value"].cast(); - parameter_metadata_list.push_back( - ParameterMetadata(str_value, arg_idx)); + parameter_metadata_list.emplace_back(str_value, arg_idx); } else if (is_dtype) { // Dtype auto metadata = item_metadata.cast(); @@ -382,8 +373,8 @@ void AOTIPythonKernelHolder::init_aoti_kernel_cache() { TORCH_INTERNAL_ASSERT(THPDtype_Check(dtype_value_obj.ptr())); auto dtype_value = reinterpret_cast(dtype_value_obj.ptr())->scalar_type; - parameter_metadata_list.push_back(ParameterMetadata( - c10::Scalar(static_cast(dtype_value)), arg_idx)); + parameter_metadata_list.emplace_back( + c10::Scalar(static_cast(dtype_value)), arg_idx); } else if (is_device) { // Device auto metadata = item_metadata.cast(); @@ -395,21 +386,20 @@ void AOTIPythonKernelHolder::init_aoti_kernel_cache() { metadata["device_index_value"].cast(); device.set_index(device_index_value); } - parameter_metadata_list.push_back(ParameterMetadata(device, arg_idx)); + parameter_metadata_list.emplace_back(device, arg_idx); } else if (is_layout) { auto metadata = item_metadata.cast(); auto layout_value_obj = metadata["layout_value"].cast(); TORCH_INTERNAL_ASSERT(THPLayout_Check(layout_value_obj.ptr())); auto layout_value = reinterpret_cast(layout_value_obj.ptr())->layout; - parameter_metadata_list.push_back(ParameterMetadata( - c10::Scalar(static_cast(layout_value)), arg_idx)); + parameter_metadata_list.emplace_back( + c10::Scalar(static_cast(layout_value)), arg_idx); } else { // Tensor auto metadata = item_metadata.cast(); auto tensor_metadata = build_tensor_metadata(metadata); - parameter_metadata_list.push_back( - ParameterMetadata(tensor_metadata, arg_idx)); + parameter_metadata_list.emplace_back(tensor_metadata, arg_idx); } } @@ -480,9 +470,12 @@ std::string AOTIPythonKernelHolder::produce_aoti_kernel_lib( schema.overload_name().empty() ? "default" : schema.overload_name(); auto pos = qualified_name.find("::"); TORCH_INTERNAL_ASSERT(pos != std::string::npos, qualified_name); - std::string ns_str(qualified_name.begin(), qualified_name.begin() + pos); + std::string ns_str( + qualified_name.begin(), + qualified_name.begin() + static_cast(pos)); std::string func_name( - qualified_name.begin() + pos + strlen("::"), qualified_name.end()); + qualified_name.begin() + static_cast(pos + strlen("::")), + qualified_name.end()); py::gil_scoped_acquire gil; py::handle op_py_func = op.getPythonOp(pyinterpreter_, [&]() -> PyObject* { diff --git a/torch/csrc/inductor/aoti_eager/kernel_meta_info.cpp b/torch/csrc/inductor/aoti_eager/kernel_meta_info.cpp index a5ca876d45392f..6562447bcba4fa 100644 --- a/torch/csrc/inductor/aoti_eager/kernel_meta_info.cpp +++ b/torch/csrc/inductor/aoti_eager/kernel_meta_info.cpp @@ -1,6 +1,7 @@ #if !defined(C10_MOBILE) && !defined(ANDROID) #include #include +#include namespace torch::inductor { @@ -25,8 +26,8 @@ TensorMetadata::TensorMetadata( dtype_(dtype), device_(device), dispatch_key_set_(dispatch_key_set), - sizes_(sizes), - strides_(strides), + sizes_(std::move(sizes)), + strides_(std::move(strides)), requires_grad_(requires_grad) { TORCH_INTERNAL_ASSERT_DEBUG_ONLY( !is_symbolic_, "Not support symbolic shape now"); @@ -94,25 +95,24 @@ bool TensorMetadata::operator==(const TensorMetadata& other) const { std::ostream& operator<<( std::ostream& stream, const TensorMetadata& tensor_metadata) { - stream << "is_symbolic_: " << tensor_metadata.is_symbolic_ << std::endl; - stream << "dtype_: " << tensor_metadata.dtype_ << std::endl; - stream << "device_: " << tensor_metadata.device_ << std::endl; + stream << "is_symbolic_: " << tensor_metadata.is_symbolic_ << '\n'; + stream << "dtype_: " << tensor_metadata.dtype_ << '\n'; + stream << "device_: " << tensor_metadata.device_ << '\n'; stream << "sizes_: "; for (const auto& size : tensor_metadata.sizes_) { stream << size << " "; } - stream << std::endl; + stream << '\n'; stream << "strides_: "; for (const auto& stride : tensor_metadata.strides_) { stream << stride << " "; } - stream << "requires_grad_: " << tensor_metadata.requires_grad_ << std::endl; - stream << "dispatch_key_set_: " << tensor_metadata.dispatch_key_set_ - << std::endl; + stream << "requires_grad_: " << tensor_metadata.requires_grad_ << '\n'; + stream << "dispatch_key_set_: " << tensor_metadata.dispatch_key_set_ << '\n'; stream << "tensor_check_: " << tensor_metadata.tensor_check_.has_value() - << std::endl; - stream << std::endl; + << '\n'; + stream << '\n'; return stream; } @@ -138,8 +138,9 @@ ParameterMetadata::ParameterMetadata( uint64_t input_order) : tag_(TENSOR_LIST), order_(input_order) { std::vector tensor_metadata_list; + tensor_metadata_list.reserve(tensor_list.size()); for (const auto& tensor : tensor_list) { - tensor_metadata_list.push_back(TensorMetadata(tensor)); + tensor_metadata_list.emplace_back(tensor); } value_ = tensor_metadata_list; } @@ -147,23 +148,17 @@ ParameterMetadata::ParameterMetadata( ParameterMetadata::ParameterMetadata( const c10::Scalar& scalar, uint64_t input_order) - : tag_(SCALAR), order_(input_order) { - value_ = scalar; -} + : tag_(SCALAR), value_(scalar), order_(input_order) {} ParameterMetadata::ParameterMetadata( const std::string& str, uint64_t input_order) - : tag_(STRING), order_(input_order) { - value_ = str; -} + : tag_(STRING), value_(str), order_(input_order) {} ParameterMetadata::ParameterMetadata( const c10::Device& device, uint64_t input_order) - : tag_(DEVICE), order_(input_order) { - value_ = device; -} + : tag_(DEVICE), value_(device), order_(input_order) {} bool ParameterMetadata::operator==(const ParameterMetadata& other) const { // Same type diff --git a/torch/csrc/inductor/aoti_eager/kernel_meta_info.h b/torch/csrc/inductor/aoti_eager/kernel_meta_info.h index d84ea9b477ef35..24d3c05bc3505c 100644 --- a/torch/csrc/inductor/aoti_eager/kernel_meta_info.h +++ b/torch/csrc/inductor/aoti_eager/kernel_meta_info.h @@ -10,8 +10,8 @@ namespace torch::inductor { // Regarding a aten operation implemented by AOTI, the metadata of the input -// tensors will be cached on the disk to acclerate next run. TensorMetada -// structure is to represent the metadata of each input tensor. it includes +// tensors will be cached on the disk to accelerate next run. TensorMetada +// structure is to represent the metadata of each input tensor. It includes // whether the tensor is symbolic, the dtype, the device, the sizes and the // strides of the tensor. When the metadata of the input tensors is the same as // the cached metadata, the cached kernel library will be loaded and executed. @@ -51,7 +51,6 @@ struct TensorMetadata { TensorMetadata() : is_symbolic_(false), - dtype_(c10::ScalarType::Undefined), device_(c10::DeviceType::COMPILE_TIME_MAX_DEVICE_TYPES), sizes_({}), strides_({}) {} @@ -116,7 +115,7 @@ struct ParameterMetadata { // same tag. For example, an operation with two input tensors, the first // tensor is a optional tensor and the second tensor is a tensor. The first // tensor will have the order 0 and the second tensor will have the order 1. - uint64_t order_; + uint64_t order_{}; ParameterMetadata() : tag_(INVALID) {} ParameterMetadata(TensorMetadata tensor_metadata, uint64_t input_order); diff --git a/torch/csrc/inductor/aoti_runner/model_container_runner.cpp b/torch/csrc/inductor/aoti_runner/model_container_runner.cpp index c7c09ac3ea43d7..37c69ccfa813d0 100644 --- a/torch/csrc/inductor/aoti_runner/model_container_runner.cpp +++ b/torch/csrc/inductor/aoti_runner/model_container_runner.cpp @@ -72,7 +72,7 @@ AOTIModelContainerRunner::AOTIModelContainerRunner( model_so_->sym("AOTInductorModelContainerGetCallSpec")); // Hack to find the json file name from the model so file - size_t lastindex = model_so_path.find_last_of("."); + size_t lastindex = model_so_path.find_last_of('.'); std::string json_filename = model_so_path.substr(0, lastindex) + ".json"; if (file_exists(json_filename)) { @@ -189,8 +189,8 @@ void AOTIModelContainerRunner::swap_constant_buffer() { } std::vector AOTIModelContainerRunner::get_call_spec() { - const char* in_spec; - const char* out_spec; + const char* in_spec = nullptr; + const char* out_spec = nullptr; AOTI_RUNTIME_ERROR_CODE_CHECK( get_call_spec_func_(container_handle_, &in_spec, &out_spec)); return {in_spec, out_spec}; diff --git a/torch/csrc/inductor/aoti_runner/model_container_runner_cpu.cpp b/torch/csrc/inductor/aoti_runner/model_container_runner_cpu.cpp index 40eb7407ecd317..25decff00c458d 100644 --- a/torch/csrc/inductor/aoti_runner/model_container_runner_cpu.cpp +++ b/torch/csrc/inductor/aoti_runner/model_container_runner_cpu.cpp @@ -10,7 +10,7 @@ AOTIModelContainerRunnerCpu::AOTIModelContainerRunnerCpu( size_t num_models) : AOTIModelContainerRunner(model_so_path, num_models, "cpu", "") {} -AOTIModelContainerRunnerCpu::~AOTIModelContainerRunnerCpu() {} +AOTIModelContainerRunnerCpu::~AOTIModelContainerRunnerCpu() = default; std::vector AOTIModelContainerRunnerCpu::run( std::vector& inputs) { diff --git a/torch/csrc/inductor/aoti_runner/model_container_runner_cuda.cpp b/torch/csrc/inductor/aoti_runner/model_container_runner_cuda.cpp index 705a59eb3f394f..596b436bb703d1 100644 --- a/torch/csrc/inductor/aoti_runner/model_container_runner_cuda.cpp +++ b/torch/csrc/inductor/aoti_runner/model_container_runner_cuda.cpp @@ -14,7 +14,7 @@ AOTIModelContainerRunnerCuda::AOTIModelContainerRunnerCuda( device_str, cubin_dir) {} -AOTIModelContainerRunnerCuda::~AOTIModelContainerRunnerCuda() {} +AOTIModelContainerRunnerCuda::~AOTIModelContainerRunnerCuda() = default; std::vector AOTIModelContainerRunnerCuda::run( std::vector& inputs) { diff --git a/torch/csrc/inductor/aoti_torch/mkldnn_tensor.cpp b/torch/csrc/inductor/aoti_torch/mkldnn_tensor.cpp index 7f0811f0d88b50..23f15043241788 100644 --- a/torch/csrc/inductor/aoti_torch/mkldnn_tensor.cpp +++ b/torch/csrc/inductor/aoti_torch/mkldnn_tensor.cpp @@ -6,8 +6,7 @@ #include #endif -namespace torch { -namespace aot_inductor { +namespace torch::aot_inductor { #if AT_MKLDNN_ENABLED() @@ -45,5 +44,4 @@ at::Tensor mkldnn_tensor_from_data_ptr( #endif -} // namespace aot_inductor -} // namespace torch +} // namespace torch::aot_inductor diff --git a/torch/csrc/inductor/aoti_torch/tensor_converter.cpp b/torch/csrc/inductor/aoti_torch/tensor_converter.cpp index d61fa20bab8789..b53a1d8811d811 100644 --- a/torch/csrc/inductor/aoti_torch/tensor_converter.cpp +++ b/torch/csrc/inductor/aoti_torch/tensor_converter.cpp @@ -1,8 +1,7 @@ #include #include -namespace torch { -namespace aot_inductor { +namespace torch::aot_inductor { std::vector unsafe_alloc_new_handles_from_tensors( std::vector& tensors) { @@ -45,5 +44,4 @@ std::vector alloc_tensors_by_stealing_from_handles( return result; } -} // namespace aot_inductor -} // namespace torch +} // namespace torch::aot_inductor diff --git a/torch/csrc/inductor/inductor_ops.cpp b/torch/csrc/inductor/inductor_ops.cpp index ec9205f130249a..7d0e9b612343b8 100644 --- a/torch/csrc/inductor/inductor_ops.cpp +++ b/torch/csrc/inductor/inductor_ops.cpp @@ -10,8 +10,7 @@ #include -namespace torch { -namespace inductor { +namespace torch::inductor { using namespace at; Tensor _mm_plus_mm_out( @@ -111,5 +110,4 @@ TORCH_LIBRARY_FRAGMENT(inductor, m) { {at::Tag::pt2_compliant_tag}); } -} // namespace inductor -} // namespace torch +} // namespace torch::inductor diff --git a/torch/csrc/inductor/resize_storage_bytes.cpp b/torch/csrc/inductor/resize_storage_bytes.cpp index 94522b7df8c82d..018acb1a0fc5a2 100644 --- a/torch/csrc/inductor/resize_storage_bytes.cpp +++ b/torch/csrc/inductor/resize_storage_bytes.cpp @@ -7,8 +7,7 @@ #include #endif -namespace torch { -namespace inductor { +namespace torch::inductor { using namespace at; // NOLINTNEXTLINE(performance-unnecessary-value-param) @@ -63,5 +62,4 @@ TORCH_LIBRARY_IMPL(inductor, Functionalize, m) { "resize_storage_bytes_", TORCH_FN(resize_storage_bytes__functionalize)); } -} // namespace inductor -} // namespace torch +} // namespace torch::inductor From 8045787528883ee8b308b6e154698fbd5b4818b2 Mon Sep 17 00:00:00 2001 From: PyTorch UpdateBot Date: Wed, 28 Aug 2024 04:16:23 +0000 Subject: [PATCH 0132/1018] [audio hash update] update the pinned audio hash (#134632) This PR is auto-generated nightly by [this action](https://github.com/pytorch/pytorch/blob/main/.github/workflows/nightly.yml). Update the pinned audio hash. Pull Request resolved: https://github.com/pytorch/pytorch/pull/134632 Approved by: https://github.com/pytorchbot --- .github/ci_commit_pins/audio.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/ci_commit_pins/audio.txt b/.github/ci_commit_pins/audio.txt index 3973bc933bf981..d4276f969ac8ca 100644 --- a/.github/ci_commit_pins/audio.txt +++ b/.github/ci_commit_pins/audio.txt @@ -1 +1 @@ -b3f6f511f2a1082bd56b13a3f6794e7fc3ba4862 +97ed7b36b7a741253d4e41e4da3c901d83294503 From 0197f419246c113b958fbd70d4e4e8c9b842f9cd Mon Sep 17 00:00:00 2001 From: PyTorch UpdateBot Date: Wed, 28 Aug 2024 04:16:40 +0000 Subject: [PATCH 0133/1018] [executorch hash update] update the pinned executorch hash (#133386) This PR is auto-generated nightly by [this action](https://github.com/pytorch/pytorch/blob/main/.github/workflows/nightly.yml). Update the pinned executorch hash. Pull Request resolved: https://github.com/pytorch/pytorch/pull/133386 Approved by: https://github.com/pytorchbot --- .ci/docker/ci_commit_pins/executorch.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.ci/docker/ci_commit_pins/executorch.txt b/.ci/docker/ci_commit_pins/executorch.txt index f6867de3907803..3c7d077a795db0 100644 --- a/.ci/docker/ci_commit_pins/executorch.txt +++ b/.ci/docker/ci_commit_pins/executorch.txt @@ -1 +1 @@ -5e9bab8c5956249e75a0f187bf8075df97ca2555 +69472e5c43481324ad923ceb29392ab72830acee From f151cfd87f110af0c462ec4e21472a386cba8069 Mon Sep 17 00:00:00 2001 From: Rachel Guo Date: Wed, 28 Aug 2024 04:48:00 +0000 Subject: [PATCH 0134/1018] [AOTI][Tooling][4/n] Add `torch.save()` for individual intermediate tensor (#133871) Differential Revision: D61415304 Pull Request resolved: https://github.com/pytorch/pytorch/pull/133871 Approved by: https://github.com/ColinPeppler --- build_variables.bzl | 1 + test/inductor/test_aot_inductor.py | 8 +- torch/_inductor/codegen/cpp_wrapper_cpu.py | 20 +-- torch/_inductor/codegen/debug_utils.py | 122 +++++++++++++----- torch/_inductor/codegen/simd.py | 7 - torch/_inductor/codegen/wrapper.py | 2 +- torch/_inductor/config.py | 14 +- torch/csrc/inductor/aoti_torch/c/shim.h | 8 ++ .../csrc/inductor/aoti_torch/shim_common.cpp | 43 +++++- 9 files changed, 167 insertions(+), 58 deletions(-) diff --git a/build_variables.bzl b/build_variables.bzl index 85e7031d3c9098..bf7ffda0abda63 100644 --- a/build_variables.bzl +++ b/build_variables.bzl @@ -473,6 +473,7 @@ inductor_core_resources = [ "torch/csrc/inductor/aoti_torch/mkldnn_tensor.cpp", "torch/csrc/inductor/aoti_torch/oss_proxy_executor.cpp", "torch/csrc/inductor/inductor_ops.cpp", + "torch/csrc/jit/serialization/pickle.cpp", ] libtorch_core_sources = sorted( diff --git a/test/inductor/test_aot_inductor.py b/test/inductor/test_aot_inductor.py index 4e19b2e6ed74f9..20d03fc40738e7 100644 --- a/test/inductor/test_aot_inductor.py +++ b/test/inductor/test_aot_inductor.py @@ -3324,8 +3324,8 @@ def forward(self, a): ] ) - # test the default debug printing codegen - with config.patch({"aot_inductor.debug_intermediate_value_printer": 1}): + # test default debug printing all tensor values codegen + with config.patch({"aot_inductor.debug_intermediate_value_printer": "2"}): result, code = run_and_get_cpp_code( AOTIRunnerUtil.compile, model, example_inputs ) @@ -3345,11 +3345,11 @@ def forward(self, a): count, ).run(code) - # test the filtered kernel names printing codegen + # test printing selected kernel's tensor values codegen filtered_kernel_name = f"aoti_torch_{self.device}_addmm_out" with config.patch( { - "aot_inductor.debug_intermediate_value_printer": 1, + "aot_inductor.debug_intermediate_value_printer": "2", "aot_inductor.filtered_kernel_names": filtered_kernel_name, } ): diff --git a/torch/_inductor/codegen/cpp_wrapper_cpu.py b/torch/_inductor/codegen/cpp_wrapper_cpu.py index 0ae915cd927d8d..f8baacfacd9c0c 100644 --- a/torch/_inductor/codegen/cpp_wrapper_cpu.py +++ b/torch/_inductor/codegen/cpp_wrapper_cpu.py @@ -12,6 +12,7 @@ import torch import torch._inductor.async_compile # noqa: F401 required to warm up AsyncCompile pools import torch._ops +from torch._inductor.codegen.debug_utils import IntermediateValueDebuggingLevel from torch.fx.experimental.symbolic_shapes import ConvertIntKey, DivideByKey, SymTypes from .. import config, ir @@ -1205,10 +1206,12 @@ def generate_c_shim_extern_kernel_call(self, kernel, args): self.allow_stack_allocation = False wrapped_args = [] - args_to_print = None - enable_debug_printer = config.aot_inductor.debug_intermediate_value_printer - if enable_debug_printer: - args_to_print = [] + + args_to_print_or_save = None + debug_printer_manager = V.graph.wrapper_code.debug_printer + debug_printer_level = debug_printer_manager.debug_printer_level + if debug_printer_level != IntermediateValueDebuggingLevel.OFF: + args_to_print_or_save = [] for x in args: pieces = x.split(", ") @@ -1224,14 +1227,15 @@ def generate_c_shim_extern_kernel_call(self, kernel, args): ): # TODO: The current way to find a 'tensor' type arg is hacky also as mentioned above # Find a more reliable way to detect tensor kernel args for extern kernel calls - if enable_debug_printer: + if debug_printer_level != IntermediateValueDebuggingLevel.OFF: if piece.startswith(("buf", "arg")): - args_to_print.append(piece) + args_to_print_or_save.append(piece) piece = f"convert_arrayref_tensor_to_tensor({piece})" wrapped_args.append(piece) - debug_printer_manager = V.graph.wrapper_code.debug_printer - debug_printer_manager.set_printer_args(args_to_print, kernel, None, None) + debug_printer_manager.set_printer_args( + args_to_print_or_save, kernel, None, None + ) with debug_printer_manager: shim_fn = self.get_c_shim_func_name(kernel) self.writeline( diff --git a/torch/_inductor/codegen/debug_utils.py b/torch/_inductor/codegen/debug_utils.py index a44fa5d1d8b1fa..11fe548155d514 100644 --- a/torch/_inductor/codegen/debug_utils.py +++ b/torch/_inductor/codegen/debug_utils.py @@ -2,90 +2,154 @@ from __future__ import annotations import functools +from enum import Enum from typing import List, Optional from .. import config from ..virtualized import V from .common import TensorArg +from .multi_kernel import MultiKernel -class DebugPrinterManager: - DEBUG_FILTER_DEFAULT_PRINT_ALL = "default" +# AOTI debug printing related configs +class IntermediateValueDebuggingLevel(Enum): + # OFF: No intermediate tensor value debug info will be printed or saved. + OFF = "0" + # LEVEL 1: Save all intermediate tensor values to individual `.pt` files. No debug printing will be displayed. + SAVE_ONLY = "1" + # LEVEL 2: Print all intermediate tensor values by default to the console. No debug saving will be performed. + PRINT_ONLY = "2" + +class DebugPrinterManager: def __init__( self, - enable_debug_printer: bool, - args_to_print: Optional[List[str]] = None, + debug_printer_level, + args_to_print_or_save: Optional[List[str]] = None, kernel_name: str = "", kernel=None, arg_signatures: Optional[List[type]] = None, ): - self.enable_debug_printer = enable_debug_printer - if args_to_print is None: - args_to_print = [] - self.args_to_print = args_to_print + self.debug_printer_level = IntermediateValueDebuggingLevel(debug_printer_level) + if args_to_print_or_save is None: + args_to_print_or_save = [] + self.args_to_print_or_save = args_to_print_or_save self.kernel_name = kernel_name self.arg_signatures: Optional[List[type]] = None self.kernel = kernel self.filtered_kernel_names_to_print = self.get_debug_filtered_kernel_names() def __enter__(self): - if self.enable_debug_printer: - V.graph.all_codegen_kernel_names.add(self.kernel_name) - self.codegen_intermediate_tensor_value_printer( - self.args_to_print, + self._perform_debug_print_or_save_helper( + self.args_to_print_or_save, + self.kernel_name, + before_launch=True, + arg_signatures=self.arg_signatures, + ) + + def __exit__(self, args_to_print_or_save, kernel_name, arg_signatures): + self._perform_debug_print_or_save_helper( + args_to_print_or_save, + kernel_name, + before_launch=False, + arg_signatures=arg_signatures, + ) + + def _perform_debug_print_or_save_helper( + self, + args_to_print_or_save, + kernel_name, + before_launch, + arg_signatures: Optional[List[type]] = None, + ): + if self.debug_printer_level == IntermediateValueDebuggingLevel.OFF: + return + if self.debug_printer_level == IntermediateValueDebuggingLevel.SAVE_ONLY: + # by default save all the tensor values before launch + self.codegen_intermediate_tensor_value_save( + self.args_to_print_or_save, self.kernel_name, - before_launch=True, + before_launch, arg_signatures=self.arg_signatures, ) - - def __exit__(self, args_to_print, kernel_name, arg_signatures): - if self.enable_debug_printer: - self.codegen_intermediate_tensor_value_printer( - self.args_to_print, + if self.debug_printer_level == IntermediateValueDebuggingLevel.PRINT_ONLY: + # by default print all the tensor values before launch + self.codegen_intermediate_tensor_value_print( + self.args_to_print_or_save, self.kernel_name, - before_launch=False, + before_launch, arg_signatures=self.arg_signatures, ) def set_printer_args( self, - args_to_print: List[str], + args_to_print_or_save: List[str], kernel_name: str, arg_signatures: Optional[List[type]], kernel, ): - self.args_to_print = args_to_print + # Note: MultiKernel debug printing is not supported for now + if isinstance(kernel, MultiKernel): + self.debug_printer_level = IntermediateValueDebuggingLevel.OFF + self.args_to_print_or_save = args_to_print_or_save self.kernel_name = kernel_name self.arg_signatures = arg_signatures self.kernel = kernel @functools.lru_cache # noqa: B019 def get_debug_filtered_kernel_names(self) -> List[str]: + if config.aot_inductor.filtered_kernel_names is None: + return [] return [ x.strip() for x in config.aot_inductor.filtered_kernel_names.lower().split(",") ] - def codegen_intermediate_tensor_value_printer( + def codegen_intermediate_tensor_value_save( self, - args_to_print, + args_to_save, kernel_name, before_launch=True, arg_signatures: Optional[List[type]] = None, ) -> None: - for i, arg in enumerate(args_to_print): + for i, arg in enumerate(args_to_save): if arg_signatures is not None and not isinstance( arg_signatures[i], TensorArg ): continue - if ( - len(self.filtered_kernel_names_to_print) > 0 - and self.filtered_kernel_names_to_print[0] - != self.DEBUG_FILTER_DEFAULT_PRINT_ALL - and kernel_name not in self.filtered_kernel_names_to_print + launch_prefix = "before_launch" if before_launch else "after_launch" + if V.graph.cpp_wrapper: + if config.abi_compatible: + V.graph.wrapper_code.writeline( + f'aoti_torch_save_tensor_handle({arg}, "{arg}", "{launch_prefix}", "{kernel_name}");' + ) + else: + # TODO: add non-abi compatible mode debug printing info + pass + else: + # currently, not cpp wrapper codegen mode not supported. + pass + + def codegen_intermediate_tensor_value_print( + self, + args_to_print, + kernel_name, + before_launch=True, + arg_signatures: Optional[List[type]] = None, + ) -> None: + for i, arg in enumerate(args_to_print): + if arg_signatures is not None and not isinstance( + arg_signatures[i], TensorArg ): continue + if self.debug_printer_level == IntermediateValueDebuggingLevel.PRINT_ONLY: + # when debug printing is enabled i.e. DEFAULT_PRINT level, check if filtered kernel name list is provided + if ( + len(self.filtered_kernel_names_to_print) > 0 + and kernel_name not in self.filtered_kernel_names_to_print + ): + continue + launch_prefix = "before_launch" if before_launch else "after_launch" if V.graph.cpp_wrapper: if config.abi_compatible: diff --git a/torch/_inductor/codegen/simd.py b/torch/_inductor/codegen/simd.py index 9a2c7982fdbcb8..7f50ad1a98f0d1 100644 --- a/torch/_inductor/codegen/simd.py +++ b/torch/_inductor/codegen/simd.py @@ -1394,19 +1394,12 @@ def _node_has_sort(node): self.codegen_comment(node_schedule) - # debug printing values of intermediate tensors - # Note: MultiKernel debug printing is not supported for now - enable_debug_printer = ( - config.aot_inductor.debug_intermediate_value_printer - and not isinstance(final_kernel, MultiKernel) - ) _, call_args, arg_signatures, _ = ( final_kernel.args.python_argdefs() if not isinstance(final_kernel, MultiKernel) else [None, [], None, None] ) debug_printer_manager = V.graph.wrapper_code.debug_printer - debug_printer_manager.enable_debug_printer = enable_debug_printer debug_printer_manager.set_printer_args( call_args, kernel_name, arg_signatures, final_kernel ) diff --git a/torch/_inductor/codegen/wrapper.py b/torch/_inductor/codegen/wrapper.py index 09924822023d56..7626c39aec20b4 100644 --- a/torch/_inductor/codegen/wrapper.py +++ b/torch/_inductor/codegen/wrapper.py @@ -534,7 +534,7 @@ def add_import_once(line: str) -> None: # intermediate tensor value printing utility self.debug_printer = DebugPrinterManager( - enable_debug_printer=config.aot_inductor.debug_intermediate_value_printer + debug_printer_level=config.aot_inductor.debug_intermediate_value_printer ) def write_constant(self, name: str, hashed: str) -> None: diff --git a/torch/_inductor/config.py b/torch/_inductor/config.py index d8af1934eea8a1..c9757f62a95dc3 100644 --- a/torch/_inductor/config.py +++ b/torch/_inductor/config.py @@ -918,15 +918,17 @@ class aot_inductor: os.environ.get("AOT_INDUCTOR_DEBUG_DUMP_CONSTS_BIN", "0") == "1" ) - # enable debug mode for aot inductor and it will print out more information including the intermediate tensor values, etc - # for debugging purpose - debug_intermediate_value_printer = ( - os.environ.get("AOT_INDUCTOR_DEBUG_INTERMEDIATE_VALUE_PRINTER", "0") == "1" + # option for debug printing/saving for intermediate tensor values for aot inductor + # 0: disable debug dumping + # 1: enable saving intermediate tensor values + # 2: enable printing intermediate tensor values + debug_intermediate_value_printer = os.environ.get( + "AOT_INDUCTOR_DEBUG_INTERMEDIATE_VALUE_PRINTER", "0" ) - # filtered nodes to be printed for debug values. If not set, it will dump all debug tensor value info by default + # filtered nodes to be printed for debug values. Specify this option when debug_intermediate_value_printer is set to 2 filtered_kernel_names = os.environ.get( - "AOT_INDUCTOR_FILTERED_KERNELS_TO_PRINT", "default" + "AOT_INDUCTOR_FILTERED_KERNELS_TO_PRINT", None ) # Serialized tree spec for flattening inputs diff --git a/torch/csrc/inductor/aoti_torch/c/shim.h b/torch/csrc/inductor/aoti_torch/c/shim.h index 5c3869c9a4f953..e6e1c6aa770890 100644 --- a/torch/csrc/inductor/aoti_torch/c/shim.h +++ b/torch/csrc/inductor/aoti_torch/c/shim.h @@ -573,6 +573,14 @@ AOTI_TORCH_EXPORT void aoti_torch_print_tensor_handle( AtenTensorHandle self, const char* msg); +// When AOTI debug printer option is enabled, this function will be invoked to +// torch pickle save the intermediate tensor for debugging purpose. +AOTI_TORCH_EXPORT void aoti_torch_save_tensor_handle( + AtenTensorHandle self, + const char* tensor_name, + const char* launch_prefix, + const char* kernel_name); + #ifdef USE_CUDA struct CUDAGuardOpaque; diff --git a/torch/csrc/inductor/aoti_torch/shim_common.cpp b/torch/csrc/inductor/aoti_torch/shim_common.cpp index 8e9c1a4b18a3e8..77bba40c782cf1 100644 --- a/torch/csrc/inductor/aoti_torch/shim_common.cpp +++ b/torch/csrc/inductor/aoti_torch/shim_common.cpp @@ -10,8 +10,10 @@ #include #include #include +#include #include #include +#include #include #ifndef AT_PER_OPERATOR_HEADERS @@ -45,6 +47,14 @@ #endif +#if __has_include("filesystem") +#include +namespace fs = std::filesystem; +#else +#include +namespace fs = std::experimental::filesystem; +#endif + using namespace torch::aot_inductor; namespace { @@ -986,13 +996,40 @@ AOTI_TORCH_EXPORT AOTITorchError aoti_torch_view_dtype( }); } +AOTI_TORCH_EXPORT void aoti_torch_save_tensor_handle( + AtenTensorHandle self, + const char* tensor_name, + const char* launch_prefix, + const char* kernel_name) { + at::Tensor* t = tensor_handle_to_tensor_pointer(self); +#ifndef C10_MOBILE + // Save tensor to tmp .pt file for tensors and can be torch.load'ed later + std::string cwd = fs::current_path().string(); + std::string tmp_folder = cwd + "/tmp/aoti_torch/"; + if (!fs::exists(tmp_folder)) { + std::cout << "Path does not exist, creating it..." << tmp_folder + << std::endl; + + if (!fs::create_directories(tmp_folder)) { + std::cout << "Error creating directory: " << tmp_folder << std::endl; + return; + } + } + std::string tensor_filepath_to_save = tmp_folder + launch_prefix + "_" + + kernel_name + "_" + tensor_name + "_" + t->device().str() + ".pt"; + + auto bytes = torch::jit::pickle_save(c10::IValue(*t)); + std::ofstream fout(tensor_filepath_to_save, std::ios::out | std::ios::binary); + fout.write(bytes.data(), bytes.size()); + fout.close(); +#endif // !defined(C10_MOBILE) +} + AOTI_TORCH_EXPORT void aoti_torch_print_tensor_handle( AtenTensorHandle self, const char* msg) { at::Tensor* t = tensor_handle_to_tensor_pointer(self); - auto device = t->device(); - // Display message std::cout << "["; if (msg) { @@ -1014,7 +1051,7 @@ AOTI_TORCH_EXPORT void aoti_torch_print_tensor_handle( std::cout << "Min value: " << t->min().item() << std::endl; std::cout << "Max value: " << t->max().item() << std::endl; } - std::cout << "Device: " << device << std::endl; + std::cout << "Device: " << t->device() << std::endl; std::cout << "Size: " << t->sizes() << std::endl; std::cout << "Stride: " << t->strides() << std::endl; std::cout << "Dtype: " << t->dtype() << std::endl; From cad7cc1fca4d1ed4065178446b1d7ba5f65b6ae0 Mon Sep 17 00:00:00 2001 From: "haozhe.zhu" Date: Tue, 27 Aug 2024 09:29:50 +0800 Subject: [PATCH 0135/1018] fix torch.prod vectorized path for bool (#128009) Fix https://github.com/pytorch/pytorch/issues/127866. Pull Request resolved: https://github.com/pytorch/pytorch/pull/128009 Approved by: https://github.com/jgong5, https://github.com/albanD --- aten/src/ATen/cpu/vec/vec256/vec256.h | 11 +++++++++++ aten/src/ATen/cpu/vec/vec512/vec512.h | 12 ++++++++++++ aten/src/ATen/cpu/vec/vec_base.h | 11 +++++++++++ test/test_reductions.py | 8 +++++++- 4 files changed, 41 insertions(+), 1 deletion(-) diff --git a/aten/src/ATen/cpu/vec/vec256/vec256.h b/aten/src/ATen/cpu/vec/vec256/vec256.h index 6f7abf193b77c2..da38b9d26d2188 100644 --- a/aten/src/ATen/cpu/vec/vec256/vec256.h +++ b/aten/src/ATen/cpu/vec/vec256/vec256.h @@ -314,6 +314,17 @@ inline Vectorized flip(const Vectorized & v) { return flip8(v); } +inline Vectorized operator&&( + const Vectorized& self, + const Vectorized& other) { + const __m256i* self_ = reinterpret_cast(self.as_bytes()); + const __m256i* other_ = reinterpret_cast(other.as_bytes()); + __m256i out = _mm256_and_si256(*self_, *other_); + Vectorized ret; + std::memcpy(ret, &out, ret.size() * sizeof(bool)); + return ret; +} + #endif // (defined(CPU_CAPABILITY_AVX2) }} // namepsace at::vec::CPU_CAPABILITY diff --git a/aten/src/ATen/cpu/vec/vec512/vec512.h b/aten/src/ATen/cpu/vec/vec512/vec512.h index c7fa23b23a6074..d593d184c3190a 100644 --- a/aten/src/ATen/cpu/vec/vec512/vec512.h +++ b/aten/src/ATen/cpu/vec/vec512/vec512.h @@ -274,6 +274,18 @@ inline Vectorized flip(const Vectorized & v) { return flip8(v); } +inline Vectorized operator&&( + const Vectorized& self, + const Vectorized& other) { + const __m512i* self_ = reinterpret_cast(self.as_bytes()); + const __m512i* other_ = reinterpret_cast(other.as_bytes()); + __m512i out = _mm512_and_si512(*self_, *other_); + Vectorized ret; + // We do not have a constructer that takes __m512i, so we need to memcpy + std::memcpy(ret, &out, ret.size() * sizeof(bool)); + return ret; +} + #endif // defined(CPU_CAPABILITY_AVX512) }}} diff --git a/aten/src/ATen/cpu/vec/vec_base.h b/aten/src/ATen/cpu/vec/vec_base.h index af4da53793666f..86aaba490a8635 100644 --- a/aten/src/ATen/cpu/vec/vec_base.h +++ b/aten/src/ATen/cpu/vec/vec_base.h @@ -947,6 +947,17 @@ inline Vectorized fmsub(const Vectorized& a, const Vectorized& b, const return a * b - c; } +template +Vectorized inline operator&&( + const Vectorized& a, + const Vectorized& b) { + Vectorized ret; + for (int i = 0; i != Vectorized::size(); i++) { + ret[i] = a[i] && b[i]; + } + return ret; +} + template std::enable_if_t> inline gather(T const* base_addr, const Vectorized>& vindex) { diff --git a/test/test_reductions.py b/test/test_reductions.py index 8e07692cb58931..2375324d94f853 100644 --- a/test/test_reductions.py +++ b/test/test_reductions.py @@ -1498,7 +1498,13 @@ def test_prod_lowp(self, device, dtype): self.assertEqual(res1, res2.to(dtype=dtype)) def test_prod_bool(self, device): - vals = [[True, True], [True, False], [False, False], []] + vals = [ + [True, True], + [True, False], + [False, False], + [], + [False] * 256, # https://github.com/pytorch/pytorch/issues/127866 + ] for val in vals: result = torch.prod(torch.tensor(val, device=device), dtype=torch.bool).item() expect = np.prod(np.array(val), dtype=bool) From 7e5be6290ba823850b24edbefacb1f6c0d13badd Mon Sep 17 00:00:00 2001 From: xingyuan li <108672484+hoshibara@users.noreply.github.com> Date: Wed, 28 Aug 2024 06:17:43 +0000 Subject: [PATCH 0136/1018] [Inductor UT] Generalize inductor UT for intel GPU (#133309) [Inductor UT] Generalize Inductor test case for Intel GPU. - Reuse `test/inductor/test_decompose_mem_bound_mm.py` - Reuse `test/inductor/test_inplacing_pass.py` Pull Request resolved: https://github.com/pytorch/pytorch/pull/133309 Approved by: https://github.com/EikanWang, https://github.com/jansel, https://github.com/etaf --- test/inductor/test_decompose_mem_bound_mm.py | 81 +++++++++++--------- test/inductor/test_inplacing_pass.py | 6 +- 2 files changed, 46 insertions(+), 41 deletions(-) diff --git a/test/inductor/test_decompose_mem_bound_mm.py b/test/inductor/test_decompose_mem_bound_mm.py index 93798fc5aedf3e..68997635f3cfb9 100644 --- a/test/inductor/test_decompose_mem_bound_mm.py +++ b/test/inductor/test_decompose_mem_bound_mm.py @@ -1,7 +1,6 @@ # Owner(s): ["module: inductor"] import logging -import unittest import torch import torch._inductor @@ -13,15 +12,13 @@ instantiate_parametrized_tests, parametrize, ) -from torch.testing._internal.inductor_utils import HAS_CUDA - - -requires_cuda = unittest.skipUnless(HAS_CUDA, "requires cuda") +from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_CUDA +from torch.testing._internal.triton_utils import requires_gpu class MyModule(torch.nn.Module): def __init__( - self, n_input: int, n_output: int, has_bias: bool, device="cuda" + self, n_input: int, n_output: int, has_bias: bool, device=GPU_TYPE ) -> None: super().__init__() self.linear = torch.nn.Linear(n_input, n_output, bias=has_bias) @@ -48,7 +45,7 @@ def forward(self, input1, input2): return output -@requires_cuda +@requires_gpu @torch._inductor.config.patch( post_grad_fusion_options={ "decompose_mm_pass": {}, @@ -89,12 +86,12 @@ def compare_gradients(self, module, traced, rtol=1e-3, atol=1e-3): ) def test_decompose_bmm(self, b, m, n, k, should_decompose): torch._logging.set_logs(inductor=logging.DEBUG) - mat1 = torch.randn(b, m, k, device="cuda").requires_grad_(True) - mat2 = torch.randn(b, k, n, device="cuda").requires_grad_(True) + mat1 = torch.randn(b, m, k, device=GPU_TYPE).requires_grad_(True) + mat2 = torch.randn(b, k, n, device=GPU_TYPE).requires_grad_(True) counters.clear() - module = MyModule2().to("cuda") + module = MyModule2().to(GPU_TYPE) traced = torch.compile(module) input = [mat1, mat2] ref = module(*input) @@ -102,7 +99,7 @@ def test_decompose_bmm(self, b, m, n, k, should_decompose): self.compare_pred(module, traced, input) - expected_val = 1 if should_decompose else 0 + expected_val = 1 if should_decompose and HAS_CUDA else 0 self.assertEqual( counters["inductor"]["decompose_bmm"], expected_val, @@ -113,7 +110,7 @@ def test_decompose_bmm(self, b, m, n, k, should_decompose): self.compare_parameters(module, traced) self.compare_gradients(module, traced) - expected_val = 3 if should_decompose else 0 + expected_val = 3 if should_decompose and HAS_CUDA else 0 self.assertEqual( counters["inductor"]["decompose_bmm"], expected_val, @@ -127,11 +124,11 @@ def test_decompose_bmm(self, b, m, n, k, should_decompose): @parametrize("has_bias", [True, False]) def test_decompose_linear(self, m, n, k, has_bias, should_decompose): torch._logging.set_logs(inductor=logging.DEBUG) - input = torch.randn(m, k, device="cuda").requires_grad_(True) + input = torch.randn(m, k, device=GPU_TYPE).requires_grad_(True) counters.clear() - module = MyModule(k, n, has_bias).to("cuda") + module = MyModule(k, n, has_bias).to(GPU_TYPE) traced = torch.compile(module) input = [input] ref = module(*input) @@ -139,7 +136,7 @@ def test_decompose_linear(self, m, n, k, has_bias, should_decompose): self.compare_pred(module, traced, input) - expected_val = 1 if should_decompose else 0 + expected_val = 1 if should_decompose and HAS_CUDA else 0 if has_bias: self.assertEqual( counters["inductor"]["decompose_addmm"], @@ -172,13 +169,13 @@ def test_decompose_linear(self, m, n, k, has_bias, should_decompose): def test_decompose_linear_mixed_precision( self, m, n, k, has_bias, should_decompose ): - with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16): + with torch.amp.autocast(device_type=GPU_TYPE, dtype=torch.bfloat16): torch._logging.set_logs(inductor=logging.DEBUG) - input = torch.randn(m, k, device="cuda").requires_grad_(True) + input = torch.randn(m, k, device=GPU_TYPE).requires_grad_(True) counters.clear() - module = MyModule(k, n, has_bias).to("cuda") + module = MyModule(k, n, has_bias).to(GPU_TYPE) traced = torch.compile(module) input = [input] ref = module(*input) @@ -186,7 +183,7 @@ def test_decompose_linear_mixed_precision( self.compare_pred(module, traced, input) - expected_val = 1 if should_decompose else 0 + expected_val = 1 if should_decompose and HAS_CUDA else 0 if has_bias: self.assertEqual( counters["inductor"]["decompose_addmm"], @@ -218,12 +215,12 @@ def test_decompose_linear_mixed_precision( @parametrize("has_bias", [True, False]) def test_decompose_mm(self, m, n, k, has_bias, should_decompose): torch._logging.set_logs(inductor=logging.DEBUG) - mat1 = torch.randn(m, k, device="cuda").requires_grad_(True) - mat2 = torch.randn(k, n, device="cuda").requires_grad_(True) + mat1 = torch.randn(m, k, device=GPU_TYPE).requires_grad_(True) + mat2 = torch.randn(k, n, device=GPU_TYPE).requires_grad_(True) counters.clear() - module = MyModule3().to("cuda") + module = MyModule3().to(GPU_TYPE) traced = torch.compile(module) input = [mat1, mat2] ref = module(*input) @@ -231,7 +228,7 @@ def test_decompose_mm(self, m, n, k, has_bias, should_decompose): self.compare_pred(module, traced, input) - expected_val = 1 if should_decompose else 0 + expected_val = 1 if should_decompose and HAS_CUDA else 0 self.assertEqual( counters["inductor"]["decompose_mm"], expected_val, @@ -243,7 +240,7 @@ def test_decompose_mm(self, m, n, k, has_bias, should_decompose): self.compare_parameters(module, traced) self.compare_gradients(module, traced) - expected_val = 1 if should_decompose else 0 + expected_val = 1 if should_decompose and HAS_CUDA else 0 self.assertEqual( counters["inductor"]["decompose_mm"] - decompose_mm_fwd, expected_val, @@ -256,14 +253,14 @@ def test_decompose_mm(self, m, n, k, has_bias, should_decompose): ) @parametrize("has_bias", [True, False]) def test_decompose_mm_mixed_precision(self, m, n, k, has_bias, should_decompose): - with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16): + with torch.amp.autocast(device_type=GPU_TYPE, dtype=torch.bfloat16): torch._logging.set_logs(inductor=logging.DEBUG) - mat1 = torch.randn(m, k, device="cuda").requires_grad_(True) - mat2 = torch.randn(k, n, device="cuda").requires_grad_(True) + mat1 = torch.randn(m, k, device=GPU_TYPE).requires_grad_(True) + mat2 = torch.randn(k, n, device=GPU_TYPE).requires_grad_(True) counters.clear() - module = MyModule3().to("cuda") + module = MyModule3().to(GPU_TYPE) traced = torch.compile(module) input = [mat1, mat2] ref = module(*input) @@ -271,7 +268,7 @@ def test_decompose_mm_mixed_precision(self, m, n, k, has_bias, should_decompose) self.compare_pred(module, traced, input) - expected_val = 1 if should_decompose else 0 + expected_val = 1 if should_decompose and HAS_CUDA else 0 self.assertEqual( counters["inductor"]["decompose_mm"], expected_val, @@ -283,7 +280,7 @@ def test_decompose_mm_mixed_precision(self, m, n, k, has_bias, should_decompose) self.compare_parameters(module, traced) self.compare_gradients(module, traced) - expected_val = 1 if should_decompose else 0 + expected_val = 1 if should_decompose and HAS_CUDA else 0 self.assertEqual( counters["inductor"]["decompose_mm"] - decompose_mm_fwd, expected_val, @@ -294,11 +291,11 @@ def test_decompose_mm_mixed_precision(self, m, n, k, has_bias, should_decompose) @parametrize("has_bias", [True, False]) def test_dynamic_shape(self, m, n, k, has_bias, should_decompose): torch._logging.set_logs(inductor=logging.DEBUG) - input = torch.randn(m, k, device="cuda").requires_grad_(True) + input = torch.randn(m, k, device=GPU_TYPE).requires_grad_(True) counters.clear() - module = MyModule(k, n, has_bias).to("cuda") + module = MyModule(k, n, has_bias).to(GPU_TYPE) traced = torch.compile(module, dynamic=True) input = [input] ref = module(*input) @@ -306,7 +303,7 @@ def test_dynamic_shape(self, m, n, k, has_bias, should_decompose): self.compare_pred(module, traced, input) - expected_val = 1 if should_decompose else 0 + expected_val = 1 if should_decompose and HAS_CUDA else 0 if has_bias: self.assertEqual( counters["inductor"]["decompose_addmm"], @@ -319,9 +316,13 @@ def test_dynamic_shape(self, m, n, k, has_bias, should_decompose): self.compare_parameters(module, traced) self.compare_gradients(module, traced) + expected_val = 0 + if HAS_CUDA: + expected_val = 1 if has_bias else 2 + self.assertEqual( counters["inductor"]["decompose_mm"], - 1 if has_bias else 2, + expected_val, ) counters.clear() @@ -330,8 +331,8 @@ def test_realize_input(self): k = 5 n = 2 torch._logging.set_logs(inductor=logging.DEBUG) - input1 = torch.randn(m, k, device="cuda").T.contiguous() - input2 = torch.randn(k, n, device="cuda") + input1 = torch.randn(m, k, device=GPU_TYPE).T.contiguous() + input2 = torch.randn(k, n, device=GPU_TYPE) @torch.compile() def foo(x, y): @@ -339,8 +340,12 @@ def foo(x, y): out, code = run_and_get_code(foo, input1, input2) - # two kernels generated - FileCheck().check_count(".run(", 2, exactly=True).run(code[0]) + if GPU_TYPE == "xpu": + # only 1 kernel generated on the XPU stack + FileCheck().check_count(".run(", 1, exactly=True).run(code[0]) + else: + # two kernels generated + FileCheck().check_count(".run(", 2, exactly=True).run(code[0]) if __name__ == "__main__": diff --git a/test/inductor/test_inplacing_pass.py b/test/inductor/test_inplacing_pass.py index bce4fa329fe497..6164b4e975ee87 100644 --- a/test/inductor/test_inplacing_pass.py +++ b/test/inductor/test_inplacing_pass.py @@ -15,7 +15,7 @@ parametrize, subtest, ) -from torch.testing._internal.inductor_utils import HAS_CUDA, HAS_GPU +from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_GPU from torch.testing._internal.logging_utils import logs_to_string @@ -23,7 +23,7 @@ const = torch.tensor(0.0) -device = "cuda" +device = GPU_TYPE def num_reinplacing_failures(): @@ -259,5 +259,5 @@ def f(x): if __name__ == "__main__": - if IS_LINUX and HAS_CUDA: + if IS_LINUX and HAS_GPU: run_tests(needs="filelock") From add9f827cca1d1248a785859495258874150638c Mon Sep 17 00:00:00 2001 From: Animesh Jain Date: Tue, 27 Aug 2024 21:06:31 -0700 Subject: [PATCH 0137/1018] [dynamo][exceptions] Use exception subclass whenever possible (#134610) Pull Request resolved: https://github.com/pytorch/pytorch/pull/134610 Approved by: https://github.com/drisspg, https://github.com/jansel --- test/dynamo/test_exceptions.py | 23 +++++++++++++++++++++++ torch/_dynamo/symbolic_convert.py | 7 ++++--- 2 files changed, 27 insertions(+), 3 deletions(-) diff --git a/test/dynamo/test_exceptions.py b/test/dynamo/test_exceptions.py index d9b364443eb132..33c6e3d0f6121b 100644 --- a/test/dynamo/test_exceptions.py +++ b/test/dynamo/test_exceptions.py @@ -267,6 +267,29 @@ def forward(self, x): x = torch.ones(4) self.assertEqual(mod(x), opt_mod(x)) + def test_attribute_error_from_getattr(self): + class Mock: + def __init__(self): + self.a = 5 + + def __getattr__(self, name): + if name != "a": + raise AttributeError("missing") + return self.__dict__["a"] + + mock = Mock() + + def fn(x): + if hasattr(mock, "b"): + return torch.cos(x) + return torch.sin(x) + + opt_fn = torch.compile(fn, backend="eager", fullgraph=True) + x = torch.randn(4) + ref = fn(x) + res = opt_fn(x) + self.assertEqual(ref, res) + def test_stop_iteration(self): def zip_longest(*iterables, fillvalue=None): # Get the iterators for each iterable diff --git a/torch/_dynamo/symbolic_convert.py b/torch/_dynamo/symbolic_convert.py index 0b7d68be08b9e5..7e8040b09a03e4 100644 --- a/torch/_dynamo/symbolic_convert.py +++ b/torch/_dynamo/symbolic_convert.py @@ -1379,9 +1379,10 @@ def RAISE_VARARGS(self, inst): # 2) when user raises exception instance if isinstance(val, variables.ExceptionVariable): - if val.exc_type is StopIteration: - # StopIteration is used to find the end of iteration while tracing __next__ - raise exc.ObservedUserStopIteration(f"raised exception {val}") + if observed_exception_type := exc.observed_exception_map.get( + val.exc_type + ): + raise observed_exception_type(f"raised exception {val}") raise exc.ObservedException(f"raised exception {val}") unimplemented(f"raise {exc}") else: From eb231d8bfc363349b168bd917e905c74caa9534d Mon Sep 17 00:00:00 2001 From: Animesh Jain Date: Tue, 27 Aug 2024 21:06:31 -0700 Subject: [PATCH 0138/1018] [dynamo][dicts] Support hasattr on dicts (#134590) Fixes - https://github.com/pytorch/pytorch/issues/134577 Pull Request resolved: https://github.com/pytorch/pytorch/pull/134590 Approved by: https://github.com/Skylion007 ghstack dependencies: #134610 --- test/dynamo/test_functions.py | 16 ++++++++++++++++ torch/_dynamo/variables/dicts.py | 9 +++++++++ 2 files changed, 25 insertions(+) diff --git a/test/dynamo/test_functions.py b/test/dynamo/test_functions.py index 67ab47dfac9093..f5cbfb8760c307 100644 --- a/test/dynamo/test_functions.py +++ b/test/dynamo/test_functions.py @@ -1683,6 +1683,22 @@ def test_dict_sorted(x): tmp = {1: "D", 10: "B", 3: "E", 0: "F"} return x + 1, sorted(tmp), sorted(tmp, reverse=True) + def test_dict_hasattr(self): + def fn(x): + if hasattr(x, "to"): + return x.to("cpu") + if hasattr(x, "items"): + return torch.cos(x["a"]) + return x + + opt_fn = torch.compile(fn, backend="eager", fullgraph=True) + + x = dict(a=torch.randn(3)) + self.assertEqual(fn(x), opt_fn(x)) + + x = torch.randn(4) + self.assertEqual(fn(x), opt_fn(x)) + @make_test def test_list_clear(a, b): tmp = [a + 1, a + 2] diff --git a/torch/_dynamo/variables/dicts.py b/torch/_dynamo/variables/dicts.py index 2d4c015bf3ac1b..f9b74bb2f202f4 100644 --- a/torch/_dynamo/variables/dicts.py +++ b/torch/_dynamo/variables/dicts.py @@ -355,6 +355,15 @@ def call_method( def unpack_var_sequence(self, tx): return [x.vt for x in self.items.keys()] + def call_hasattr(self, tx, name): + # dict not allow setting arbitrary attributes. To check for hasattr, we can just check the __dict__ of the dict. + # OrderedDict though requires side effects tracking because it supports arbitrary setattr. + if self.user_cls is dict: + if name in self.user_cls.__dict__: + return ConstantVariable.create(True) + return ConstantVariable.create(False) + unimplemented(f"hasattr on {self.user_cls} is not supported") + class DefaultDictVariable(ConstDictVariable): def __init__(self, items, user_cls, default_factory=None, **kwargs) -> None: From 7b320bd4b4e2c7ab6c5fd837c54e3bfa7833176c Mon Sep 17 00:00:00 2001 From: Animesh Jain Date: Tue, 27 Aug 2024 21:06:32 -0700 Subject: [PATCH 0139/1018] [raland][dynamo][exceptions] Support raise from None (#134621) The PR was reverted because this PR traced more code and surfaced a latent bug. Resubmitting w/o any changes. Pull Request resolved: https://github.com/pytorch/pytorch/pull/134621 Approved by: https://github.com/jansel ghstack dependencies: #134610, #134590 --- test/dynamo/test_exceptions.py | 33 +++++++++++++++ torch/_dynamo/symbolic_convert.py | 69 ++++++++++++++++++------------- 2 files changed, 73 insertions(+), 29 deletions(-) diff --git a/test/dynamo/test_exceptions.py b/test/dynamo/test_exceptions.py index 33c6e3d0f6121b..5efa344c7ea2ed 100644 --- a/test/dynamo/test_exceptions.py +++ b/test/dynamo/test_exceptions.py @@ -353,6 +353,39 @@ def fn(x): res = opt_fn(x) self.assertEqual(ref, res) + def test_raise_from_None(self): + # Inspired from os.environ + class MyMapping: + def __init__(self, d): + self._d = d + + def __getitem__(self, key): + try: + value = self._d[key] + except KeyError: + raise KeyError(key) from None + return value + + d = MyMapping({"a": 10, "b": 20}) + + def mapping_get(obj, key, value=None): + try: + return obj.__getitem__(key) + except KeyError: + return value + + def fn(x, d, key): + x = torch.sin(x + 1) + return x, mapping_get(d, key) + + opt_fn = torch.compile(fn, backend="eager", fullgraph=True) + + x = torch.rand(2, 3) + ref = fn(x, d, "m") + res = opt_fn(x, d, "m") + self.assertEqual(ref[0], res[0]) + self.assertEqual(ref[1], res[1]) + if __name__ == "__main__": from torch._dynamo.test_case import run_tests diff --git a/torch/_dynamo/symbolic_convert.py b/torch/_dynamo/symbolic_convert.py index 7e8040b09a03e4..e605e57b6aa49c 100644 --- a/torch/_dynamo/symbolic_convert.py +++ b/torch/_dynamo/symbolic_convert.py @@ -1359,35 +1359,47 @@ def FOR_ITER(self, inst): self.push(ConstantVariable.create(None)) self.jump(inst) + def _raise_exception_variable(self, inst): + val = self.pop() + # User can raise exception in 2 ways + # 1) raise exception type - raise NotImplementedError + # 2) raise execption instance - raise NotImplemetedError("foo") + + # 1) when user raises exception type + if isinstance(val, variables.BuiltinVariable): + # Create the instance of the exception type + # https://github.com/python/cpython/blob/3.11/Python/ceval.c#L6547-L6549 + val = val.call_function(self, [], {}) # type: ignore[arg-type] + + # Save the exception in a global data structure + self.exn_vt_stack.append(val) + + # 2) when user raises exception instance + if isinstance(val, variables.ExceptionVariable): + if observed_exception_type := exc.observed_exception_map.get(val.exc_type): + raise observed_exception_type(f"raised exception {val}") + raise exc.ObservedException(f"raised exception {val}") + unimplemented(f"raise {exc}") + def RAISE_VARARGS(self, inst): if inst.arg == 0: unimplemented("re-raise") elif inst.arg == 1: - val = self.pop() - # User can raise exception in 2 ways - # 1) raise exception type - raise NotImplementedError - # 2) raise execption instance - raise NotImplemetedError("foo") - - # 1) when user raises exception type - if isinstance(val, variables.BuiltinVariable): - # Create the instance of the exception type - # https://github.com/python/cpython/blob/3.11/Python/ceval.c#L6547-L6549 - val = val.call_function(self, [], {}) # type: ignore[arg-type] - - # Save the exception in a global data structure - self.exn_vt_stack.append(val) - - # 2) when user raises exception instance - if isinstance(val, variables.ExceptionVariable): - if observed_exception_type := exc.observed_exception_map.get( - val.exc_type - ): - raise observed_exception_type(f"raised exception {val}") - raise exc.ObservedException(f"raised exception {val}") - unimplemented(f"raise {exc}") + self._raise_exception_variable(inst) else: + # Support raise .. from None ... Dynamo does not track __cause__ and other attributes of exception. So we + # ignore `from None` part. + from_vt = self.pop() + if isinstance(from_vt, ConstantVariable) and from_vt.value is None: + self._raise_exception_variable(inst) unimplemented("raise ... from ...") + def RERAISE(self, inst): + if sys.version_info >= (3, 11): + # RERAISE is currently supported in a narrow case of `raise ... from None` + self._raise_exception_variable(inst) + unimplemented("RERAISE") + def exception_handler(self, raised_exception): if sys.version_info >= (3, 11): exn_tab_entry = self.current_instruction.exn_tab_entry @@ -1401,10 +1413,6 @@ def exception_handler(self, raised_exception): # 2) if 'lasti' is true, then push the offset that the exception was raised at if exn_tab_entry.lasti: - # This is untested. Any test that tests this end-to-end - # requires supporting more bytecodes. Therefore graph - # breaking for now. - unimplemented("lasti=True while exception handling") self.push( variables.ConstantVariable(self.current_instruction.offset) ) @@ -1436,9 +1444,12 @@ def exception_handler(self, raised_exception): # https://github.com/python/cpython/blob/3.10/Python/ceval.c#L1456 self.popn(3) if len(self.block_stack) == 0: - unimplemented( - "exception is raised when block stack " "is empty" - ) + # No handler found in this frame. Bubble the exception to the parent + # instruction translater. + self.stack.clear() + if type(self) is InstructionTranslator: + raise Unsupported("Observed exception") + raise raised_exception block_stack_entry = self.block_stack.pop() if block_stack_entry.inst.opname != "SETUP_FINALLY": From f8709b4016536cc1c4f818d5925320edf45f85c8 Mon Sep 17 00:00:00 2001 From: Animesh Jain Date: Tue, 27 Aug 2024 21:06:32 -0700 Subject: [PATCH 0140/1018] [dynamo] Graph break on FSDP flat_param inconsistent tensor and grad dtype (#134614) Pull Request resolved: https://github.com/pytorch/pytorch/pull/134614 Approved by: https://github.com/awgu, https://github.com/yf225 ghstack dependencies: #134610, #134590, #134621 --- torch/_dynamo/variables/builder.py | 21 ++++++++++++++++++++- 1 file changed, 20 insertions(+), 1 deletion(-) diff --git a/torch/_dynamo/variables/builder.py b/torch/_dynamo/variables/builder.py index cfc563307f89ed..69f7d4974057fc 100644 --- a/torch/_dynamo/variables/builder.py +++ b/torch/_dynamo/variables/builder.py @@ -15,6 +15,7 @@ import re import sys import types +import warnings import weakref from typing import Any, List, MutableMapping, NamedTuple, Optional, TYPE_CHECKING, Union @@ -25,7 +26,7 @@ from torch._ops import HigherOrderOperator from torch._streambase import _EventBase, _StreamBase from torch._subclasses.fake_tensor import FakeTensor, is_fake, maybe_get_fake_mode -from torch._subclasses.meta_utils import is_sparse_any +from torch._subclasses.meta_utils import is_sparse_any, safe_grad from torch._utils_internal import justknobs_check from torch.fx.experimental._backward_state import BackwardState from torch.fx.experimental.symbolic_shapes import ( @@ -220,6 +221,12 @@ DimList = List +def safe_has_grad(t): + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", "The .grad attribute of a Tensor") + return hasattr(t, "grad") + + class _missing: pass @@ -1511,6 +1518,18 @@ def wrap_tensor(self, value: torch.Tensor): # SPARSE_TENSOR_GUARDS for guards to work propertly. unimplemented("torch.compile does not support sparse Tensors") + if ( + safe_has_grad(value) + and safe_grad(value) is not None + and value.dtype != safe_grad(value).dtype + ): + unimplemented( + "Inconsistent dtype between tensor and its gradient. " + "This can happen in FSDP and crashes meta tensor creation. " + "This is potentially a workaround. Fixing it correctly " + "requires some design around FSDP + torch.compile." + ) + tensor_variable = wrap_fx_proxy( tx=self.tx, proxy=tensor_proxy, From 3e130eedf4a56ef1877e316089cf6b58d7271b4d Mon Sep 17 00:00:00 2001 From: Ke Wen Date: Tue, 27 Aug 2024 15:05:21 -0700 Subject: [PATCH 0141/1018] [PGNCCL] Improve logic to infer device for barrier (#134617) Fixes #134391, #124714 The above issues reported that `dist.barrier()` could hang in some cases. The culprit is that ProcessGroupNCCL inferred a wrong device to perform the dummy all-reduce. After the PR, the following will be the order of device selection: - 1st choice: `opts.device_ids`, if provided by user via `barrier(opts)`. - 2nd choice: bound device id, if provided to `init_process_group` via `device_id` arg. - 3rd choice: `usedDeviceIdxs_` recorded in current PG. Will have a value from previous collectives. - 4th choice: `globalRank() % localDeviceCount_`. This can only happen when `dist.barrier()` is the first call of the PG. What's new: - Added the 2nd choice. - In the 4th choice, we use `globalRank()` instead of group-local rank, because the group-local rank can be offset wrt the device id if intra-node GPUs are sharded into multiple dimensions. Pull Request resolved: https://github.com/pytorch/pytorch/pull/134617 Approved by: https://github.com/yifuwang, https://github.com/shuqiangzhang --- .../distributed/c10d/ProcessGroupNCCL.cpp | 66 +++++++++---------- 1 file changed, 31 insertions(+), 35 deletions(-) diff --git a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp index 7b877de3d6de35..8f741185a2c584 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp +++ b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp @@ -422,22 +422,6 @@ std::future launchAsyncGilCheck() { return resultFuture; } -// Return CUDA device with ordinal given by input rank. If we aren't -// bound to a specific device, there is no strict guarantee that this -// heuristic is the correct assignment of ranks to GPUs that Python -// layers use, but in practice it tends to be. Fortunately we don't -// rely on this for correctness of any tensor operations, just for -// ancillary uses like barriers. -at::Device ProcessGroupNCCL::guessDeviceForRank() const { - TORCH_CHECK_WITH(ValueError, rank_ >= 0, "Invalid rank ", rank_); - if (getBoundDeviceId()) { - return *getBoundDeviceId(); - } else { - int16_t deviceIdx = static_cast(rank_ % localDeviceCount_); - return at::Device(at::DeviceType::CUDA, deviceIdx); - } -} - const int64_t ProcessGroupNCCL::kWatchdogThreadSleepMillis = 100; constexpr int64_t kSynchronizeBusyWaitMillis = 10; thread_local uint64_t ProcessGroupNCCL::ncclActiveGroupCounter_ = 0; @@ -4036,40 +4020,52 @@ c10::intrusive_ptr ProcessGroupNCCL::barrier(const BarrierOptions& opts) { globalRankStride, // globalRankStride this->getSize()); // worldSize - std::vector devices; + // Device to use for barrier + int barDevIdx = -1; - // Use user defined GPU device ids if provided + // Select device to use for barrier + // 1st choice: Use user defined GPU device ids if provided if (!opts.device_ids.empty()) { - for (auto device : opts.device_ids) { - devices.emplace_back(at::DeviceType::CUDA, device); - } - } else if (usedDeviceIdxs_.empty()) { + // Use the first device id because PG NCCL is single-device now + barDevIdx = opts.device_ids[0]; + } else if (getBoundDeviceId()) { + // 2nd choice: Use the bound GPU device id if available. + // Bounded device id can be passed to `init_process_group`. + barDevIdx = (*getBoundDeviceId()).index(); + } else if (!usedDeviceIdxs_.empty()) { + // 3rd choice: infer the device id from the used device ids. + barDevIdx = *usedDeviceIdxs_.begin(); + } else { // This means there is not yet a NCCL collective being called // Here we have to use the best guesses and will use a single GPU to call // allreduce to achieve barrier. // In case the multiple processes fall into the same node, we use rank to // ensure that each process is on a different GPU - auto numGPUs = at::cuda::getNumGPUs(); - int16_t deviceIdx = static_cast(rank_ % numGPUs); - LOG(INFO) + // Note: it is better to use global rank because the group-local rank can be + // offset wrt the device id if intra-node GPUs are sharded into multiple + // dimensions. + barDevIdx = static_cast(globalRank() % localDeviceCount_); + LOG(WARNING) << logPrefix() << c10::str( " using GPU ", - deviceIdx, + barDevIdx, " to perform barrier as devices used by this process are currently unknown. ", "This can potentially cause a hang if this rank to GPU mapping is incorrect.", - "Specify device_ids in barrier() to force use of a particular device."); - devices.emplace_back(guessDeviceForRank()); - } else { - for (auto usedDeviceIdx : usedDeviceIdxs_) { - devices.emplace_back(at::DeviceType::CUDA, usedDeviceIdx); - } + "Specify device_ids in barrier() to force use of a particular device,", + "or call init_process_group() with a device_id."); } - // Use one device only - auto device = devices.back(); + TORCH_CHECK_WITH( + ValueError, + barDevIdx >= 0, + "Failed to infer a GPU device id to perform barrier. "); + auto barDevice = at::Device(at::DeviceType::CUDA, barDevIdx); + + // Create a dummy tensor on the device at::Tensor barrierTensor = - at::empty({1}, at::TensorOptions().device(device).dtype(at::kFloat)); + at::empty({1}, at::TensorOptions().device(barDevice).dtype(at::kFloat)); + // All reduce to achieve the barrier auto work = allreduce_impl(barrierTensor); From 28a88880957102dee763bc0b7ba6e5ee17a6f1ff Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Tue, 27 Aug 2024 21:46:38 -0700 Subject: [PATCH 0142/1018] [Dynamo][autograd.Function] Support mark_non_differentiable (#134087) Pull Request resolved: https://github.com/pytorch/pytorch/pull/134087 Approved by: https://github.com/zou3519 --- test/dynamo/test_autograd_function.py | 129 +++++++++++++++++--- test/test_autograd.py | 4 + torch/_dynamo/variables/higher_order_ops.py | 19 ++- torch/_dynamo/variables/misc.py | 9 +- torch/_functorch/autograd_function.py | 10 ++ 5 files changed, 150 insertions(+), 21 deletions(-) diff --git a/test/dynamo/test_autograd_function.py b/test/dynamo/test_autograd_function.py index 4785a7f96cd041..e9b913ed9559db 100644 --- a/test/dynamo/test_autograd_function.py +++ b/test/dynamo/test_autograd_function.py @@ -542,7 +542,7 @@ def forward(self, L_x_: "f32[]", L_z_: "f32[]", L_weird_b: "f32[]", L_weird_c: " function_ctx = torch.autograd.function.FunctionCtx(); function_ctx = None fwd_body_0 = self.fwd_body_0 bwd_body_0 = self.bwd_body_0 - autograd_function_apply: "f32[]" = torch.ops.higher_order.autograd_function_apply(fwd_body_0, bwd_body_0, l_x_, l_z_, l_weird_b, l_weird_c, args_tensor_mask = [True, False, True]); fwd_body_0 = bwd_body_0 = l_x_ = l_z_ = l_weird_b = l_weird_c = None + autograd_function_apply: "f32[]" = torch.ops.higher_order.autograd_function_apply(fwd_body_0, bwd_body_0, l_x_, l_z_, l_weird_b, l_weird_c, args_tensor_mask = [True, False, True], non_differentiable_idx = []); fwd_body_0 = bwd_body_0 = l_x_ = l_z_ = l_weird_b = l_weird_c = None return (autograd_function_apply,) class fwd_body_0(torch.nn.Module): @@ -699,20 +699,6 @@ def forward(self, x): # In the future, we should make the Dynamo test suite actually # run on test_autograd.py (it's disabled right now) and delete these. def test_smoke_from_test_autograd(self): - class Func(torch.autograd.Function): - @staticmethod - def forward(ctx, x): - out0 = x.clone() - out1 = x.clone() - ctx.mark_non_differentiable(out1) - ctx._materialize_non_diff_grads = False - return out0, out1 - - @staticmethod - def backward(ctx, g0, g1): - assert g1 is None - return g0 - def mult1(x): return x.prod(dim=-1).prod(dim=-1) @@ -838,10 +824,6 @@ def backward(ctx, grad): return grad, None def test(): - a = torch.tensor(1.0, requires_grad=True) - out = Func.apply(a)[0] - out.backward() - x = torch.ones(2, 4, 4).requires_grad_() mult2(x) @@ -1076,6 +1058,115 @@ def foo(x): foo(torch.randn(2, requires_grad=True)) self.assertEqual(cnts.frame_count, 1) + def test_mark_non_differentiable(self): + cnt = torch._dynamo.testing.CompileCounterWithBackend("aot_eager") + from torch.autograd import Function + + class MyFunction(Function): + @staticmethod + def forward(ctx, x, y): + out1 = x.sin() + out2 = y * 2 + ctx.mark_non_differentiable(out2) + return out1, out2 + + @staticmethod + def backward(ctx, grad1, grad2): + return grad1.cos(), grad2 * 0.0 + + @torch.compile(backend=cnt, fullgraph=True) + def fn(x, y): + return MyFunction.apply(x, y) + + x = torch.tensor(10.0, requires_grad=True) + y = torch.tensor(20.0, requires_grad=True) + ref1, ref2 = MyFunction.apply(x, y) + res1, res2 = fn(x, y) + self.assertEqual(ref1, res1) + self.assertEqual(ref2, res2) + # Ensure out1 requires gradients, out2 does not. + self.assertTrue(ref1.requires_grad) + self.assertTrue(res1.requires_grad) + self.assertFalse(ref2.requires_grad) + self.assertFalse(res2.requires_grad) + res1.sum().backward() + + # check Dynamo captured graph is correct! + actual_graph = torch._dynamo.testing.normalize_gm( + cnt.graphs[0].print_readable(print_output=False) + ) + self.assertExpectedInline( + actual_graph, + """\ +class GraphModule(torch.nn.Module): + def forward(self, L_x_: "f32[]", L_y_: "f32[]"): + l_x_ = L_x_ + l_y_ = L_y_ + + function_ctx = torch.autograd.function.FunctionCtx(); function_ctx = None + fwd_body_0 = self.fwd_body_0 + bwd_body_0 = self.bwd_body_0 + autograd_function_apply = torch.ops.higher_order.autograd_function_apply(fwd_body_0, bwd_body_0, l_x_, l_y_, args_tensor_mask = [True, True], non_differentiable_idx = [1]); fwd_body_0 = bwd_body_0 = l_x_ = l_y_ = None + getitem: "f32[]" = autograd_function_apply[0] + getitem_1: "f32[]" = autograd_function_apply[1]; autograd_function_apply = None + return (getitem, getitem_1) + + class fwd_body_0(torch.nn.Module): + def forward(self, ctx, x: "f32[]", y: "f32[]"): + out1: "f32[]" = x.sin(); x = None + + out2: "f32[]" = y * 2; y = None + return ((out1, out2), []) + + class bwd_body_0(torch.nn.Module): + def forward(self, ctx, grad1: "f32[]", grad2: "f32[]"): + _set_grad_enabled = torch._C._set_grad_enabled(False); _set_grad_enabled = None + + cos: "f32[]" = grad1.cos(); grad1 = None + mul: "f32[]" = grad2 * 0.0; grad2 = None + + _set_grad_enabled_1 = torch._C._set_grad_enabled(True); _set_grad_enabled_1 = None + return (cos, mul) +""", + ) + + def test_mark_multi_output_non_differentiable(self): + from torch.autograd import Function + + class MyFunction(Function): + @staticmethod + def forward(ctx, x, y, z): + out1 = x.sin() + out2 = y * 2 + out3 = z + 3 + ctx.mark_non_differentiable(out2, out3) + return out1, out2, out3 + + @staticmethod + def backward(ctx, grad1, grad2, grad3): + return grad1.cos(), grad2, grad3 + + @torch.compile(backend="aot_eager", fullgraph=True) + def fn(x, y, z): + return MyFunction.apply(x, y, z) + + x = torch.tensor(10.0, requires_grad=True) + y = torch.tensor(20.0, requires_grad=True) + z = torch.tensor(30.0, requires_grad=True) + ref1, ref2, ref3 = MyFunction.apply(x, y, z) + res1, res2, res3 = fn(x, y, z) + self.assertEqual(ref1, res1) + self.assertEqual(ref2, res2) + self.assertEqual(ref3, res3) + # Ensure out1 requires gradients, out2 does not. + self.assertTrue(ref1.requires_grad) + self.assertTrue(res1.requires_grad) + self.assertFalse(ref2.requires_grad) + self.assertFalse(res2.requires_grad) + self.assertFalse(ref3.requires_grad) + self.assertFalse(res3.requires_grad) + res1.sum().backward() + def test_default_values(self): from torch.autograd import Function diff --git a/test/test_autograd.py b/test/test_autograd.py index a3b5badd69ab76..6c3ac3723042e2 100644 --- a/test/test_autograd.py +++ b/test/test_autograd.py @@ -367,6 +367,7 @@ def backward(ctx, grad): x = torch.ones(1, requires_grad=True) torch._C._functions.UndefinedGrad()(MyFunction.apply(x)).backward() + @skipIfTorchDynamo("compile tested in test/dynamo/test_autograd_function.py") def test_set_materialize_non_diff_grads(self): class Func(torch.autograd.Function): @staticmethod @@ -3307,6 +3308,7 @@ def backward(ctx, grad_output): y = x.masked_fill(mask, 0) y.sum().backward() + @skipIfTorchDynamo("compile tested in test/dynamo/test_autograd_function.py") def test_mark_non_differentiable_mixed(self): class MyFunction(Function): @staticmethod @@ -4097,6 +4099,7 @@ def assert_strict_equal(var1, var2): assert_strict_equal(xc, x) assert_strict_equal(yc, y) + @skipIfTorchDynamo("compile tested in test/dynamo/test_autograd_function.py") def test_dep_nograd(self): class F1(Function): @staticmethod @@ -8785,6 +8788,7 @@ def vjp(ctx, grad_out): gradcheck(Func.apply, (a,), check_forward_ad=True) + @skipIfTorchDynamo("compile tested in test/dynamo/test_autograd_function.py") def test_custom_function_forward_mode_non_differentiable(self): # returns differentiable type, marked non-differentiable class Func(torch.autograd.Function): diff --git a/torch/_dynamo/variables/higher_order_ops.py b/torch/_dynamo/variables/higher_order_ops.py index 851c2a05af1ef7..37de70a2658f2a 100644 --- a/torch/_dynamo/variables/higher_order_ops.py +++ b/torch/_dynamo/variables/higher_order_ops.py @@ -2056,6 +2056,20 @@ def is_strict_for(v: VariableTracker): # graph.output will append the output at the very end # This might be a behavior difference + # If users call ctx.mark_non_differentiable, we should capture these output tensors who + # are marked as non-differentiable and pass them to ApplyTemplate + # at torch._functorch.autograd_function.AutogradFunctionApply for reconstruction. + non_differentiable_idx = [] + if ctx.non_differentiable is not None: + non_differentiable_set = set(ctx.non_differentiable) + assert isinstance(fwd_out, variables.BaseListVariable) + for i, x in enumerate(fwd_out.items): + if ( + isinstance(x, variables.TensorVariable) + and x.as_proxy() in non_differentiable_set + ): + non_differentiable_idx.append(i) + # Rewrite the output of fwd_graph to (output, stuff_necessary_for_bwd) for node in fwd_graph.find_nodes(op="output"): fwd_graph.erase_node(node) @@ -2175,7 +2189,10 @@ def is_strict_for(v: VariableTracker): "call_function", autograd_function_apply, args=p_args, - kwargs={"args_tensor_mask": args_tensor_mask}, + kwargs={ + "args_tensor_mask": args_tensor_mask, + "non_differentiable_idx": non_differentiable_idx, + }, ), example_value=example_value, ) diff --git a/torch/_dynamo/variables/misc.py b/torch/_dynamo/variables/misc.py index 5d81b1730ecf5c..f2264667043c46 100644 --- a/torch/_dynamo/variables/misc.py +++ b/torch/_dynamo/variables/misc.py @@ -791,6 +791,7 @@ def __init__( proxy=None, saved_tensors=None, needs_input_grad=None, + non_differentiable=None, **kwargs, ) -> None: super().__init__(value=value, value_type=value_type, **kwargs) @@ -798,6 +799,7 @@ def __init__( self.proxy = proxy self.saved_tensors = saved_tensors self.needs_input_grad = needs_input_grad + self.non_differentiable = non_differentiable @staticmethod def create(tx: "InstructionTranslator", args=None, kwargs=None): @@ -840,6 +842,11 @@ def call_method( ) -> "VariableTracker": if name == "__setattr__": return super().call_method(tx, name, args, kwargs) + elif name == "mark_non_differentiable": + assert len(kwargs) == 0 + self.non_differentiable = proxy_args_kwargs(args, {})[0] + return variables.ConstantVariable.create(None) + if name != "save_for_backward": unimplemented(f"autograd.Function context method: {name}") if self.saved_tensors is None: @@ -859,7 +866,7 @@ def call_method( return variables.ConstantVariable.create(None) def var_getattr(self, tx: "InstructionTranslator", name): - if name == "save_for_backward": + if name in ["save_for_backward", "mark_non_differentiable"]: return LambdaVariable( lambda *args, **kwargs: self.call_method(tx, name, args, kwargs) ) diff --git a/torch/_functorch/autograd_function.py b/torch/_functorch/autograd_function.py index 270c1895f6fd54..cb501e2c924213 100644 --- a/torch/_functorch/autograd_function.py +++ b/torch/_functorch/autograd_function.py @@ -719,6 +719,7 @@ def __init__(self) -> None: def __call__(self, fwd, bwd, *fwd_args, **fwd_kwargs): saved_values = None args_tensor_mask = fwd_kwargs["args_tensor_mask"] + non_differentiable_idx = fwd_kwargs["non_differentiable_idx"] length_of_tensor_args = sum(args_tensor_mask) # Filter out the original tensor args from fwd_args, # lifted freevars should not be args of ApplyTemplate.apply @@ -730,6 +731,15 @@ class ApplyTemplate(torch.autograd.Function): def forward(ctx, *args): nonlocal saved_values output, saved_values = fwd(None, *fwd_args) + + # If users call ctx.mark_non_differentiable() in the original fwd function. + if len(non_differentiable_idx) > 0: + non_differentiable_output = [] + for i, x in enumerate(output): + if i in non_differentiable_idx: + non_differentiable_output.append(x) + ctx.mark_non_differentiable(*non_differentiable_output) + return output @staticmethod From b79d456fb8d363a8e587e535de035d90cf91cc4f Mon Sep 17 00:00:00 2001 From: Will Feng Date: Tue, 27 Aug 2024 22:15:56 -0700 Subject: [PATCH 0143/1018] [2nd try][Traceable FSDP2] Allow tracing through FSDP2 impl in trace_rules.py (#134539) The previous PR https://github.com/pytorch/pytorch/pull/133532 caused stuck compilation issue on internal models. In this 2nd attempt PR, we gate the trace_rules.py changes with `if not torch._dynamo.config.skip_fsdp_hooks:`, so that they don't take effect for current graph-break FSDP2 (which relies on the default config value `skip_fsdp_hooks=True`), and will only take effect when we are using Traceable FSDP2 (in which case the user needs to proactively set `skip_fsdp_hooks=False`). Pull Request resolved: https://github.com/pytorch/pytorch/pull/134539 Approved by: https://github.com/ckluk2, https://github.com/yanboliang --- torch/_dynamo/trace_rules.py | 4 ++++ torch/distributed/_composable/fsdp/_fsdp_state.py | 2 ++ 2 files changed, 6 insertions(+) diff --git a/torch/_dynamo/trace_rules.py b/torch/_dynamo/trace_rules.py index 87db00ad1f9bab..966a54e672ee39 100644 --- a/torch/_dynamo/trace_rules.py +++ b/torch/_dynamo/trace_rules.py @@ -3213,6 +3213,8 @@ def _module_dir(m: types.ModuleType): # the forward_hook won't be ignored. "torch.distributed._composable.replicate", } + if not torch._dynamo.config.skip_fsdp_hooks: + LEGACY_MOD_INLINELIST.add("torch.distributed._composable.fsdp") # Force inline functions under these modules, even they are in *_SKIPLIST. @@ -3262,6 +3264,8 @@ def _module_dir(m: types.ModuleType): if torch.distributed.is_available(): MOD_INLINELIST.add("torch.distributed") + if not torch._dynamo.config.skip_fsdp_hooks: + MOD_INLINELIST.add("torch.distributed._composable.fsdp") @functools.lru_cache(None) diff --git a/torch/distributed/_composable/fsdp/_fsdp_state.py b/torch/distributed/_composable/fsdp/_fsdp_state.py index 4b949c64d9f7f8..ceb480fd23939f 100644 --- a/torch/distributed/_composable/fsdp/_fsdp_state.py +++ b/torch/distributed/_composable/fsdp/_fsdp_state.py @@ -351,12 +351,14 @@ def _register_group_forward_hooks( """ modules_set = set(modules) + @disable_if_config_true @functools.wraps(pre_hook) def wrapped_pre_hook(*args: Any, **kwargs: Any): if len(modules_to_run) == 0: # first to run modules_to_run.update(modules_set) return pre_hook(*args, **kwargs) + @disable_if_config_true def get_wrapped_post_hook(module: nn.Module): @functools.wraps(post_hook) def wrapped_post_hook(*args: Any, **kwargs: Any): From d2c60ba3e9658413186d69cfb529242b2df1ddea Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Tue, 27 Aug 2024 22:22:20 -0700 Subject: [PATCH 0144/1018] [Dynamo] Support inspect.signature.Parameter getattr (#134636) Pull Request resolved: https://github.com/pytorch/pytorch/pull/134636 Approved by: https://github.com/Chillee, https://github.com/anijain2305 --- test/dynamo/test_misc.py | 17 +++++++++++ test/inductor/test_flex_attention.py | 10 +++++++ torch/_dynamo/variables/misc.py | 44 ++++++++++++++++++++++++---- torch/nn/attention/flex_attention.py | 1 - 4 files changed, 65 insertions(+), 7 deletions(-) diff --git a/test/dynamo/test_misc.py b/test/dynamo/test_misc.py index 2bb11b7109799f..da258c94308837 100644 --- a/test/dynamo/test_misc.py +++ b/test/dynamo/test_misc.py @@ -11115,6 +11115,23 @@ def gn(x): self.assertEqual(bound0, bound1) + def test_inspect_signature_parameters(self): + import inspect + + def fn(x, gn): + d = inspect.signature(gn).parameters + if d["a"].default is inspect.Parameter.empty: + return torch.sin(x + 1) + else: + return torch.cos(x + 1) + + def gn(a: torch.Tensor, b: int) -> torch.Tensor: + return a + b + + x = torch.randn(2, 3) + opt_fn = torch.compile(backend="eager", fullgraph=True)(fn) + self.assertEqual(fn(x, gn), opt_fn(x, gn)) + def test_grad_none(self): def fn(x, y): x.grad = torch.abs(y) diff --git a/test/inductor/test_flex_attention.py b/test/inductor/test_flex_attention.py index a5ee0525a9818a..83ec40ac7036a2 100644 --- a/test/inductor/test_flex_attention.py +++ b/test/inductor/test_flex_attention.py @@ -1858,6 +1858,16 @@ def causal_mask(b, h, q, kv): assert block_mask.q_indices.is_cuda assert block_mask.q_num_blocks.is_cuda + @supported_platform + def test_compiling_create_block_mask(self): + def mask_mod(b, h, q, kv): + return q >= kv + + block_mask = create_block_mask(mask_mod, 1, 1, 512, 512, _compile=True) + self.assertIsInstance(block_mask, BlockMask) + self.assertEqual(block_mask.kv_num_blocks.shape, torch.Size((1, 1, 4))) + self.assertEqual(block_mask.kv_indices.shape, torch.Size((1, 1, 4, 4))) + @supported_platform def test_block_mask_viz(self): def causal_mask(b, h, q, kv): diff --git a/torch/_dynamo/variables/misc.py b/torch/_dynamo/variables/misc.py index f2264667043c46..2979bd58e312fe 100644 --- a/torch/_dynamo/variables/misc.py +++ b/torch/_dynamo/variables/misc.py @@ -35,7 +35,12 @@ set_example_value, ) from .base import VariableTracker -from .functions import NestedUserFunctionVariable, UserFunctionVariable, wrap_bound_arg +from .functions import ( + NestedUserFunctionVariable, + UserFunctionVariable, + UserMethodVariable, + wrap_bound_arg, +) from .user_defined import call_random_fn, is_standard_setattr, UserDefinedObjectVariable @@ -366,6 +371,7 @@ class InspectSignatureVariable(VariableTracker): _nonvar_fields = { "signature", + "parameters", *VariableTracker._nonvar_fields, } @@ -381,18 +387,27 @@ def __init__(self, inspected: VariableTracker, **kwargs) -> None: super().__init__(**kwargs) self.inspected = inspected - if isinstance(self.inspected, UserFunctionVariable): + if isinstance(self.inspected, UserMethodVariable): + self.fn = self.inspected.get_function() + self.signature = inspect.signature(self.fn) + self.parameters = list(self.signature.parameters.items())[1:] + elif isinstance(self.inspected, UserFunctionVariable): self.fn = self.inspected.get_function() + self.signature = inspect.signature(self.fn) + self.parameters = list(self.signature.parameters.items()) else: self.fn = self.inspected.as_python_constant() - self.signature = inspect.signature(self.fn) + self.signature = inspect.signature(self.fn) + self.parameters = list(self.signature.parameters.items()) def var_getattr(self, tx: "InstructionTranslator", name: str) -> "VariableTracker": if name == "parameters": return variables.ConstDictVariable( { - variables.ConstantVariable.create(name): InspectParameterVariable() - for name in self.inspected.inspect_parameter_names() + variables.ConstantVariable.create( + param[0] + ): InspectParameterVariable(param[1]) + for param in self.parameters }, user_cls=dict, ) @@ -448,7 +463,24 @@ def reconstruct(self, codegen): class InspectParameterVariable(VariableTracker): - """This is not implemented, if used will graph break.""" + """represents inspect.Parameter(...)""" + + def __init__(self, value, **kwargs) -> None: + super().__init__(**kwargs) + self.value = value + + def var_getattr(self, tx: "InstructionTranslator", name: str) -> "VariableTracker": + from .builder import SourcelessBuilder, VariableBuilder + + try: + attr_value = getattr(self.value, name) + if self.source: + attr_source = AttrSource(self.source, name) + return VariableBuilder(tx, attr_source)(attr_value) + else: + return SourcelessBuilder.create(tx, attr_value) + except AttributeError: + unimplemented(f"getattr({self.value}, {name})") class InspectBoundArgumentsVariable(VariableTracker): diff --git a/torch/nn/attention/flex_attention.py b/torch/nn/attention/flex_attention.py index 556d85179cb170..4f99e3d903caca 100644 --- a/torch/nn/attention/flex_attention.py +++ b/torch/nn/attention/flex_attention.py @@ -51,7 +51,6 @@ class _ModificationType(Enum): UNKNOWN = 3 -@torch._dynamo.assume_constant_result def _get_mod_type(fn: Callable) -> _ModificationType: """Get the type of modification function. This function inspects the number of positional arguments of the function to determine From 55c14d2f82a77200f0f38d223c2cbb444d7d60f2 Mon Sep 17 00:00:00 2001 From: IvanKobzarev Date: Tue, 27 Aug 2024 12:04:34 -0700 Subject: [PATCH 0145/1018] [aotd] Make effects op registry WeakKeyDictionary (#134470) Op is used as a Dictionary Key, while op can be deregistered as a result this Key will be holding this op from deallocation. Pull Request resolved: https://github.com/pytorch/pytorch/pull/134470 Approved by: https://github.com/zou3519 --- torch/_higher_order_ops/effects.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/torch/_higher_order_ops/effects.py b/torch/_higher_order_ops/effects.py index e6ba7f65cf182f..bb3d93bc69eb6b 100644 --- a/torch/_higher_order_ops/effects.py +++ b/torch/_higher_order_ops/effects.py @@ -2,6 +2,7 @@ # mypy: allow-untyped-defs from enum import Enum from typing import Any, Dict, Optional, Tuple, Union +from weakref import WeakKeyDictionary import torch import torch.utils._pytree as pytree @@ -23,11 +24,12 @@ class _EffectType(Enum): OpType = Union[torch._ops.HigherOrderOperator, torch._ops.OpOverload] -# TODO(ivankobzarev): Make SIDE_EFFECTS dictionary WeakKeyDictionary as operator can go out of scope -SIDE_EFFECTS: Dict[OpType, _EffectType] = { - torch.ops.aten._print.default: _EffectType.ORDERED, - call_torchbind: _EffectType.ORDERED, -} +SIDE_EFFECTS: "WeakKeyDictionary[OpType, _EffectType]" = WeakKeyDictionary( + { + torch.ops.aten._print.default: _EffectType.ORDERED, + call_torchbind: _EffectType.ORDERED, + } +) def _register_effectful_op(op: OpType, effect: _EffectType): From 5ac92d6fa6323cbb9eebb43a3c424d42ff5f8c88 Mon Sep 17 00:00:00 2001 From: Spencer Gibson Date: Wed, 28 Aug 2024 13:35:41 +0000 Subject: [PATCH 0146/1018] Bucketize fix to include number and tensor inputs (#133652) Fixes #132222 Pull Request resolved: https://github.com/pytorch/pytorch/pull/133652 Approved by: https://github.com/ezyang --- torch/_refs/__init__.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/torch/_refs/__init__.py b/torch/_refs/__init__.py index 160a3da6a48909..2da7ff0669771c 100644 --- a/torch/_refs/__init__.py +++ b/torch/_refs/__init__.py @@ -5917,7 +5917,7 @@ def triu_indices( @register_decomposition(aten.bucketize) @out_wrapper(exact_dtype=True) def bucketize( - a: TensorLikeType, + a: TensorOrNumberLikeType, boundaries: TensorLikeType, *, out_int32: bool = False, @@ -5928,6 +5928,7 @@ def bucketize( lambda: f"boundaries tensor must be 1 dimension but got dim({boundaries.dim()})", ) + a = a if isinstance(a, torch.Tensor) else torch.tensor(a) out_dtype = torch.int32 if out_int32 else torch.int64 n_boundaries = boundaries.shape[-1] if n_boundaries == 0: From 78879ec292a9d69feb238b919e59c35f96b03752 Mon Sep 17 00:00:00 2001 From: Avik Chaudhuri Date: Wed, 28 Aug 2024 14:35:40 +0000 Subject: [PATCH 0147/1018] hang dim hint constants off Dim (#134484) Summary: Recently https://github.com/pytorch/pytorch/pull/133620 added support for automatic dynamic shapes, where a new enum, `DIM`, was introduced to provide hints like `AUTO` and `STATIC`. This PR is a nominal change where we expose the hints via the existing public `Dim` API, and remove `DIM` from the public API. The main motivation is to avoid having users need to import too many things. Test Plan: existing Differential Revision: D61807361 Pull Request resolved: https://github.com/pytorch/pytorch/pull/134484 Approved by: https://github.com/angelayi --- docs/source/export.rst | 1 - test/export/test_export.py | 32 +++++----------- torch/_export/non_strict_utils.py | 4 +- torch/export/dynamic_shapes.py | 62 +++++++++++++++++-------------- 4 files changed, 46 insertions(+), 53 deletions(-) diff --git a/docs/source/export.rst b/docs/source/export.rst index deb84548a19cdc..3908beeb9b345e 100644 --- a/docs/source/export.rst +++ b/docs/source/export.rst @@ -676,7 +676,6 @@ API Reference .. autofunction:: save .. autofunction:: load .. autofunction:: register_dataclass -.. autoclass:: torch.export.dynamic_shapes.DIM .. autofunction:: torch.export.dynamic_shapes.Dim .. autofunction:: dims .. autoclass:: torch.export.dynamic_shapes.ShapesCollection diff --git a/test/export/test_export.py b/test/export/test_export.py index 2393443b5f0b0f..86364e061e2888 100644 --- a/test/export/test_export.py +++ b/test/export/test_export.py @@ -1773,8 +1773,6 @@ def forward(self, x, y, y1, z): ) def test_static_dim_constraints(self): - from torch.export.dynamic_shapes import DIM - class Foo(torch.nn.Module): def __init__(self) -> None: super().__init__() @@ -1803,7 +1801,7 @@ def forward(self, x, y, z): ((dx, None), (dy, 4), (dz, 3)), ((None, 6), (5, None), (None, None)), ((4, 6), {0: None, 1: 4}, {0: None, 1: 3}), - (None, None, (DIM.STATIC, DIM.STATIC)), + (None, None, (Dim.STATIC, Dim.STATIC)), ]: ep = export(foo, inputs, dynamic_shapes=dynamic_shapes) self.assertEqual(foo(*inputs), ep.module()(*inputs)) @@ -1950,9 +1948,7 @@ def forward(self, inp: Inp): self.assertEqual(str(tuple(node.meta["val"].shape)), f"({sym},)") def test_mismatched_dynamic_shapes(self): - from torch.export.dynamic_shapes import DIM - - AUTO, STATIC = DIM.AUTO, DIM.STATIC + AUTO, STATIC = Dim.AUTO, Dim.STATIC class M(torch.nn.Module): def forward(self, x): @@ -1984,7 +1980,7 @@ def forward(self, x): + re.escape( "specified at `dynamic_shapes[0]['k']['k'][0]` " "(expected either a list/tuple of dimensions, or a dict mapping indices to dimensions," - " where each dimension is None, an int, a Dim, DIM.AUTO, or DIM.STATIC)" + " where each dimension is None, an int, a Dim, Dim.AUTO, or Dim.STATIC)" ), ): export(M(), inputs, dynamic_shapes=dynamic_shapes) @@ -2061,7 +2057,7 @@ def forward(self, x): with self.assertRaisesRegex( torch._dynamo.exc.UserError, re.escape( - "Specifying both `DIM.AUTO` and `Dim` or `DerivedDim` in `dynamic_shapes` is not well supported at the moment, " + "Specifying both `Dim.AUTO` and `Dim` or `DerivedDim` in `dynamic_shapes` is not well supported at the moment, " "and can easily lead to constraint violation errors or obscure errors in torch.export." ), ): @@ -2465,9 +2461,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: def test_mark_and_auto_dynamic(self): # for this use case, mark_dynamic() and AUTO should have same effect. # check that same symbol gets allocated to both dims without raising constraint violation. - from torch.export.dynamic_shapes import DIM - - AUTO, STATIC = DIM.AUTO, DIM.STATIC + AUTO, STATIC = Dim.AUTO, Dim.STATIC class Foo(torch.nn.Module): def forward(self, x, y): @@ -2495,9 +2489,7 @@ def forward(self, x, y): def test_dont_duck_size_for_auto_dynamic(self): # for this use case, mark_dynamic() and AUTO should have same effect. # check that same symbol gets allocated to both dims without raising constraint violation. - from torch.export.dynamic_shapes import DIM - - AUTO, STATIC = DIM.AUTO, DIM.STATIC + AUTO, STATIC = Dim.AUTO, Dim.STATIC class Foo(torch.nn.Module): def forward(self, x, y): @@ -6876,9 +6868,7 @@ def test_automatic_dynamic_shapes_simple_equality(self): # The next 3 test cases tests for automatic dynamic shapes specs, verifying that automatic dynamism # leads to replacement symbols being set for equalities, and inferred relationships being checked # with runtime asserts. Check that we specialize to static values when the program says so. - from torch.export.dynamic_shapes import DIM - - AUTO, STATIC = DIM.AUTO, DIM.STATIC + AUTO, STATIC = Dim.AUTO, Dim.STATIC # case 1: direct equality between symbols class SimpleEquality(torch.nn.Module): @@ -6943,9 +6933,7 @@ def forward(self, x, y, z): ) def test_automatic_dynamic_shapes_constant_relation(self): - from torch.export.dynamic_shapes import DIM - - AUTO, STATIC = DIM.AUTO, DIM.STATIC + AUTO, STATIC = Dim.AUTO, Dim.STATIC # case 2: related by constant: s0 + 4 = s1 class OffBy4(torch.nn.Module): @@ -6988,9 +6976,7 @@ def forward(self, x, y): ) def test_automatic_dynamic_shapes_linear_relation(self): - from torch.export.dynamic_shapes import DIM - - AUTO, STATIC = DIM.AUTO, DIM.STATIC + AUTO, STATIC = Dim.AUTO, Dim.STATIC # case 3: linear relation class LinearRel(torch.nn.Module): diff --git a/torch/_export/non_strict_utils.py b/torch/_export/non_strict_utils.py index 8b5ff29b6c9e02..ddb51064f9ad78 100644 --- a/torch/_export/non_strict_utils.py +++ b/torch/_export/non_strict_utils.py @@ -24,10 +24,10 @@ from torch.export.dynamic_shapes import ( _check_dynamic_shapes, _combine_args, + _DimHint, _process_dynamic_shapes, _transform_shapes_for_default_dynamic, _tree_map_with_path, - DIM, ) from torch.export.graph_signature import CustomObjArgument from torch.fx.experimental import _config as config @@ -351,7 +351,7 @@ def make_constraints( # we want the symbol, not its replacement, which could be an expression. Maybe # there's a better way to do this, e.g., by (re)computing value ranges for expressions? dim = shape_spec[i] if shape_spec else None - if dim is None or isinstance(dim, DIM): + if dim is None or isinstance(dim, _DimHint): range_constraints[d.node.expr] = shape_env.var_to_range[ d.node._expr ] diff --git a/torch/export/dynamic_shapes.py b/torch/export/dynamic_shapes.py index 48438ee8cb2a55..53266b59e04a33 100644 --- a/torch/export/dynamic_shapes.py +++ b/torch/export/dynamic_shapes.py @@ -30,20 +30,21 @@ __all__ = [ "Constraint", - "DIM", "Dim", "dims", "refine_dynamic_shapes_from_suggested_fixes", ] -class DIM(Enum): +class _DimHint(Enum): """ - Enum for automatic/static dynamic shapes. + Enum for dynamic shape hints. + - AUTO means automatic inference of shape (static or dynamic). + - STATIC means static shape (always specialized). """ - STATIC = auto() AUTO = auto() + STATIC = auto() class _Dim(type): @@ -214,6 +215,7 @@ def Dim(name: str, *, min: Optional[int] = None, max: Optional[int] = None): Returns: A type that can be used in dynamic shape specifications for tensors. """ + from torch.utils._sympy.numbers import int_oo _min = 0 if min is None else min @@ -227,6 +229,10 @@ def Dim(name: str, *, min: Optional[int] = None, max: Optional[int] = None): return dim +Dim.AUTO = _DimHint.AUTO # hint for automatic inference of shape (static or dynamic) +Dim.STATIC = _DimHint.STATIC # hint for static shape (always specialized) + + def dims(*names: str, min: Optional[int] = None, max: Optional[int] = None): """ Util to create multiple :func:`Dim` types. @@ -668,24 +674,24 @@ def check_symbols(path, tensor, shape): for i, dim in shape.items(): if isinstance(dim, _Dim): check_same_bounds(dim) - elif not (isinstance(dim, (int, DIM)) or dim is None): + elif not (isinstance(dim, (int, _DimHint)) or dim is None): raise UserError( UserErrorType.INVALID_INPUT, f"Unexpected dimension mapped to index {i} in input tensor shape {shape} " f"specified at `dynamic_shapes{keystr(path)}` " - f"(expected None, an int, a Dim, DIM.AUTO, or DIM.STATIC, but got {dim} instead)", + f"(expected None, an int, a Dim, Dim.AUTO, or Dim.STATIC, but got {dim} instead)", case_name="dynamic_shapes_validation", ) elif isinstance(shape, (tuple, list)): for i, dim in enumerate(shape): if isinstance(dim, _Dim): check_same_bounds(dim) - elif not (isinstance(dim, (int, DIM)) or dim is None): + elif not (isinstance(dim, (int, _DimHint)) or dim is None): raise UserError( UserErrorType.INVALID_INPUT, f"Unexpected dimension #{i} in input tensor shape {shape} " f"specified at `dynamic_shapes{keystr(path)}` " - f"(expected None, an int, a Dim, DIM.AUTO, or DIM.STATIC, but got {dim} instead)", + f"(expected None, an int, a Dim, Dim.AUTO, or Dim.STATIC, but got {dim} instead)", case_name="dynamic_shapes_validation", ) elif shape is not None: @@ -693,7 +699,7 @@ def check_symbols(path, tensor, shape): UserErrorType.INVALID_INPUT, f"Unexpected input tensor shape {shape} specified at `dynamic_shapes{keystr(path)}` " f"(expected either a list/tuple of dimensions, or a dict mapping indices to dimensions," - f" where each dimension is None, an int, a Dim, DIM.AUTO, or DIM.STATIC)", + f" where each dimension is None, an int, a Dim, Dim.AUTO, or Dim.STATIC)", case_name="dynamic_shapes_validation", ) @@ -740,18 +746,18 @@ def check_shape(path, t, dynamic_shape): _tree_map_with_path(check_shape, combined_args, dynamic_shapes, tree_name="inputs") - # raise user warning if both DIM.AUTO & Dims are specified in dynamic_shapes + # raise user warning if both Dim.AUTO & Dims are specified in dynamic_shapes flat_dynamic_shapes = _flatten_dynamic_shapes(combined_args, dynamic_shapes) flatter_dynamic_shapes, _ = tree_flatten(flat_dynamic_shapes) if any(isinstance(s, _Dim) for s in flatter_dynamic_shapes) and any( - s == DIM.AUTO for s in flatter_dynamic_shapes + s == _DimHint.AUTO for s in flatter_dynamic_shapes ): raise UserError( UserErrorType.INVALID_INPUT, - "Specifying both `DIM.AUTO` and `Dim` or `DerivedDim` in `dynamic_shapes` is not well supported at the moment, " + "Specifying both `Dim.AUTO` and `Dim` or `DerivedDim` in `dynamic_shapes` is not well supported at the moment, " "and can easily lead to constraint violation errors or obscure errors in torch.export. Dim/DerivedDims " - "expect all equal or related dimensions to be specified, and does not yet compose well with `DIM.AUTO`. " - "We suggest using `DIM.AUTO` mixed with `None` for auto-dynamic + static shapes, plus torch._check(dim >= min), " + "expect all equal or related dimensions to be specified, and does not yet compose well with `Dim.AUTO`. " + "We suggest using `Dim.AUTO` mixed with `None` for auto-dynamic + static shapes, plus torch._check(dim >= min), " "torch._check(dim <= max) calls in your program to specify min/max ranges, or `Dim`/`DerivedDim` mixed with `None` " "if you want to assert on the exact specification of your program's dynamic shapes behavior.", case_name="dynamic_shapes_validation", @@ -773,8 +779,8 @@ def _transform_shapes_for_default_dynamic( for all dims governed by this symbol (i.e. relations, equality, linear relations, etc.) For export.export(), historically dynamism for unspecified dims has been undesirable, so the semantics are: - - DIM.AUTO: dynamic, allocated a symbol - - None/unspecified/DIM.STATIC: static + - Dim.AUTO: dynamic, allocated a symbol + - None/unspecified/Dim.STATIC: static - Dim/DerivedDims: also a strict assertion To allow both APIs to follow the same process for producing constraints, this function converts dynamic_shapes @@ -784,8 +790,8 @@ def _transform_shapes_for_default_dynamic( An example conversion might look like, for a 3-d input tensor: input spec: { - 0: DIM.AUTO, - 1: None, # or DIM.STATIC + 0: Dim.AUTO, + 1: None, # or Dim.STATIC 2: Dim("dx"), } output spec: { @@ -832,10 +838,10 @@ def _marked_dynamic(tensor, i): out = {} for i, val in enumerate(tensor.shape): dim = shape.get(i, None) - if _marked_dynamic(tensor, i) or dim == DIM.AUTO: + if _marked_dynamic(tensor, i) or dim == _DimHint.AUTO: # don't have to specify anything if dynamic # None also works, since assume_static_by_default=False - if dim == DIM.AUTO: + if dim == _DimHint.AUTO: torch._dynamo.maybe_mark_dynamic(tensor, i) # avoid duck sizing continue elif isinstance(dim, _Dim): @@ -846,14 +852,14 @@ def _marked_dynamic(tensor, i): out[i] = dim else: # make explicitly static - assert dim is None or dim == DIM.STATIC + assert dim is None or dim == _DimHint.STATIC out[i] = val elif isinstance(shape, (tuple, list)): out = [] for i, val in enumerate(tensor.shape): dim = shape[i] - if _marked_dynamic(tensor, i) or dim == DIM.AUTO: - if dim == DIM.AUTO: + if _marked_dynamic(tensor, i) or dim == _DimHint.AUTO: + if dim == _DimHint.AUTO: torch._dynamo.maybe_mark_dynamic(tensor, i) # avoid duck sizing out.append(None) elif isinstance(dim, _Dim): @@ -861,7 +867,7 @@ def _marked_dynamic(tensor, i): elif isinstance(dim, int): out.append(dim) else: - assert dim is None or dim == DIM.STATIC + assert dim is None or dim == _DimHint.STATIC out.append(val) out = type(shape)(out) # type: ignore[assignment] else: @@ -1040,7 +1046,7 @@ def _get_dim_name_mapping( dynamic_shapes, is_leaf=lambda x: isinstance(x, _Dim), )[0]: - if isinstance(dim, (int, DIM)) or dim is None: + if isinstance(dim, (int, _DimHint)) or dim is None: continue name_to_dim[dim.__name__] = dim if isinstance(dim, _DerivedDim): @@ -1122,9 +1128,11 @@ def refine_dynamic_shapes_from_suggested_fixes( name, expr = fix.split(" = ") expr = sympy.sympify(expr) if isinstance(expr, sympy.Number): - shape_fixes[name] = int(expr) # static, integer + # static, integer + shape_fixes[name] = int(expr) # type: ignore[assignment] else: - shape_fixes[name] = expr # relation or derived dim + # relation or derived dim + shape_fixes[name] = expr name_to_dim = _get_dim_name_mapping(dynamic_shapes) From c0aa2f595e3f4e96f99118fafa99587bbe5fdcb0 Mon Sep 17 00:00:00 2001 From: Bin Bao Date: Wed, 28 Aug 2024 04:46:50 -0700 Subject: [PATCH 0148/1018] [Inductor][Refactor] Rename CPU benchmark test configs (#134639) Summary: benchmarks/dynamo/ci_expected_accuracy/update_expected.py expects a benchmark run config is named as {config}_{benchmark}, and CPU tests should follow the same naming convention. Pull Request resolved: https://github.com/pytorch/pytorch/pull/134639 Approved by: https://github.com/huydhn --- .ci/pytorch/test.sh | 4 +- .github/workflows/inductor.yml | 52 +++++++++---------- ...tor_amp_freezing_torchbench_inference.csv} | 0 ...ductor_freezing_huggingface_inference.csv} | 0 ..._aot_inductor_freezing_timm_inference.csv} | 0 ...nductor_freezing_torchbench_inference.csv} | 0 ...or_amp_freezing_huggingface_inference.csv} | 0 ..._inductor_amp_freezing_timm_inference.csv} | 0 ...tor_amp_freezing_torchbench_inference.csv} | 0 ...ductor_freezing_huggingface_inference.csv} | 0 ... cpu_inductor_freezing_timm_inference.csv} | 0 ...nductor_freezing_torchbench_inference.csv} | 0 ...tor_amp_freezing_torchbench_inference.csv} | 0 ...nductor_freezing_torchbench_inference.csv} | 0 14 files changed, 28 insertions(+), 28 deletions(-) rename benchmarks/dynamo/ci_expected_accuracy/{cpu_aot_inductor_torchbench_amp_freezing_inference.csv => cpu_aot_inductor_amp_freezing_torchbench_inference.csv} (100%) rename benchmarks/dynamo/ci_expected_accuracy/{cpu_aot_inductor_huggingface_freezing_inference.csv => cpu_aot_inductor_freezing_huggingface_inference.csv} (100%) rename benchmarks/dynamo/ci_expected_accuracy/{cpu_aot_inductor_timm_freezing_inference.csv => cpu_aot_inductor_freezing_timm_inference.csv} (100%) rename benchmarks/dynamo/ci_expected_accuracy/{cpu_aot_inductor_torchbench_freezing_inference.csv => cpu_aot_inductor_freezing_torchbench_inference.csv} (100%) rename benchmarks/dynamo/ci_expected_accuracy/{cpu_inductor_huggingface_amp_freezing_inference.csv => cpu_inductor_amp_freezing_huggingface_inference.csv} (100%) rename benchmarks/dynamo/ci_expected_accuracy/{cpu_inductor_timm_amp_freezing_inference.csv => cpu_inductor_amp_freezing_timm_inference.csv} (100%) rename benchmarks/dynamo/ci_expected_accuracy/{cpu_inductor_torchbench_amp_freezing_inference.csv => cpu_inductor_amp_freezing_torchbench_inference.csv} (100%) rename benchmarks/dynamo/ci_expected_accuracy/{cpu_inductor_huggingface_freezing_inference.csv => cpu_inductor_freezing_huggingface_inference.csv} (100%) rename benchmarks/dynamo/ci_expected_accuracy/{cpu_inductor_timm_freezing_inference.csv => cpu_inductor_freezing_timm_inference.csv} (100%) rename benchmarks/dynamo/ci_expected_accuracy/{cpu_inductor_torchbench_freezing_inference.csv => cpu_inductor_freezing_torchbench_inference.csv} (100%) rename benchmarks/dynamo/ci_expected_accuracy/{dynamic_cpu_aot_inductor_torchbench_amp_freezing_inference.csv => dynamic_cpu_aot_inductor_amp_freezing_torchbench_inference.csv} (100%) rename benchmarks/dynamo/ci_expected_accuracy/{dynamic_cpu_aot_inductor_torchbench_freezing_inference.csv => dynamic_cpu_aot_inductor_freezing_torchbench_inference.csv} (100%) diff --git a/.ci/pytorch/test.sh b/.ci/pytorch/test.sh index daf71c306849ed..dd2ed255203da6 100755 --- a/.ci/pytorch/test.sh +++ b/.ci/pytorch/test.sh @@ -575,10 +575,10 @@ test_single_dynamo_benchmark() { fi if [[ "${TEST_CONFIG}" == *_avx2* ]]; then - TEST_CONFIG=${TEST_CONFIG::-5} + TEST_CONFIG=${TEST_CONFIG//_avx2/} fi if [[ "${TEST_CONFIG}" == *_avx512* ]]; then - TEST_CONFIG=${TEST_CONFIG::-7} + TEST_CONFIG=${TEST_CONFIG//_avx512/} fi python "benchmarks/dynamo/$suite.py" \ --ci --accuracy --timing --explain \ diff --git a/.github/workflows/inductor.yml b/.github/workflows/inductor.yml index d4c129529675e6..083821e94ad311 100644 --- a/.github/workflows/inductor.yml +++ b/.github/workflows/inductor.yml @@ -166,40 +166,40 @@ jobs: { config: "cpu_inductor_timm", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.12xlarge" }, { config: "cpu_inductor_torchbench", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.12xlarge" }, { config: "cpu_inductor_torchbench", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.12xlarge" }, - { config: "cpu_inductor_huggingface_freezing", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.12xlarge" }, - { config: "cpu_inductor_timm_freezing", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.12xlarge" }, - { config: "cpu_inductor_timm_freezing", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.12xlarge" }, - { config: "cpu_inductor_torchbench_freezing", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.12xlarge" }, - { config: "cpu_inductor_torchbench_freezing", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.12xlarge" }, - { config: "cpu_inductor_huggingface_amp_freezing", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.16xlarge.spr" }, - { config: "cpu_inductor_timm_amp_freezing", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.16xlarge.spr" }, - { config: "cpu_inductor_timm_amp_freezing", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.16xlarge.spr" }, - { config: "cpu_inductor_torchbench_amp_freezing", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.16xlarge.spr" }, - { config: "cpu_inductor_torchbench_amp_freezing", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.16xlarge.spr" }, + { config: "cpu_inductor_freezing_huggingface", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.12xlarge" }, + { config: "cpu_inductor_freezing_timm", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.12xlarge" }, + { config: "cpu_inductor_freezing_timm", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.12xlarge" }, + { config: "cpu_inductor_freezing_torchbench", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.12xlarge" }, + { config: "cpu_inductor_freezing_torchbench", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.12xlarge" }, + { config: "cpu_inductor_amp_freezing_huggingface", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.16xlarge.spr" }, + { config: "cpu_inductor_amp_freezing_timm", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.16xlarge.spr" }, + { config: "cpu_inductor_amp_freezing_timm", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.16xlarge.spr" }, + { config: "cpu_inductor_amp_freezing_torchbench", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.16xlarge.spr" }, + { config: "cpu_inductor_amp_freezing_torchbench", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.16xlarge.spr" }, { config: "dynamic_cpu_inductor_huggingface", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.12xlarge" }, { config: "dynamic_cpu_inductor_timm", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.12xlarge" }, { config: "dynamic_cpu_inductor_timm", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.12xlarge" }, { config: "dynamic_cpu_inductor_torchbench", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.12xlarge" }, { config: "dynamic_cpu_inductor_torchbench", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.12xlarge" }, - { config: "cpu_aot_inductor_huggingface_freezing", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.12xlarge" }, - { config: "cpu_aot_inductor_timm_freezing", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.12xlarge" }, - { config: "cpu_aot_inductor_timm_freezing", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.12xlarge" }, - { config: "cpu_aot_inductor_torchbench_freezing", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.12xlarge" }, - { config: "cpu_aot_inductor_torchbench_freezing", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.12xlarge" }, - { config: "cpu_aot_inductor_torchbench_amp_freezing", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.12xlarge" }, - { config: "cpu_aot_inductor_torchbench_amp_freezing", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.12xlarge" }, - { config: "dynamic_cpu_aot_inductor_torchbench_freezing", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.12xlarge" }, - { config: "dynamic_cpu_aot_inductor_torchbench_freezing", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.12xlarge" }, - { config: "dynamic_cpu_aot_inductor_torchbench_amp_freezing", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.12xlarge" }, - { config: "dynamic_cpu_aot_inductor_torchbench_amp_freezing", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.12xlarge" }, + { config: "cpu_aot_inductor_freezing_huggingface", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.12xlarge" }, + { config: "cpu_aot_inductor_freezing_timm", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.12xlarge" }, + { config: "cpu_aot_inductor_freezing_timm", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.12xlarge" }, + { config: "cpu_aot_inductor_freezing_torchbench", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.12xlarge" }, + { config: "cpu_aot_inductor_freezing_torchbench", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.12xlarge" }, + { config: "cpu_aot_inductor_amp_freezing_torchbench", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.12xlarge" }, + { config: "cpu_aot_inductor_amp_freezing_torchbench", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.12xlarge" }, + { config: "dynamic_cpu_aot_inductor_freezing_torchbench", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.12xlarge" }, + { config: "dynamic_cpu_aot_inductor_freezing_torchbench", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.12xlarge" }, + { config: "dynamic_cpu_aot_inductor_amp_freezing_torchbench", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.12xlarge" }, + { config: "dynamic_cpu_aot_inductor_amp_freezing_torchbench", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.12xlarge" }, { config: "inductor_torchbench_cpu_smoketest_perf", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.24xl.spr-metal" }, { config: "inductor_avx2", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.10xlarge.avx2" }, { config: "inductor_avx2", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.10xlarge.avx2" }, - { config: "cpu_inductor_huggingface_freezing_avx2", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.10xlarge.avx2" }, - { config: "cpu_inductor_torchbench_freezing_avx2", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.10xlarge.avx2" }, - { config: "cpu_inductor_torchbench_freezing_avx2", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.10xlarge.avx2" }, - { config: "cpu_inductor_timm_freezing_avx2", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.10xlarge.avx2" }, - { config: "cpu_inductor_timm_freezing_avx2", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.10xlarge.avx2" }, + { config: "cpu_inductor_freezing_avx2_huggingface", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.10xlarge.avx2" }, + { config: "cpu_inductor_freezing_avx2_torchbench", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.10xlarge.avx2" }, + { config: "cpu_inductor_freezing_avx2_torchbench", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.10xlarge.avx2" }, + { config: "cpu_inductor_freezing_avx2_timm", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.10xlarge.avx2" }, + { config: "cpu_inductor_freezing_avx2_timm", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.10xlarge.avx2" }, ]} secrets: HUGGING_FACE_HUB_TOKEN: ${{ secrets.HUGGING_FACE_HUB_TOKEN }} diff --git a/benchmarks/dynamo/ci_expected_accuracy/cpu_aot_inductor_torchbench_amp_freezing_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/cpu_aot_inductor_amp_freezing_torchbench_inference.csv similarity index 100% rename from benchmarks/dynamo/ci_expected_accuracy/cpu_aot_inductor_torchbench_amp_freezing_inference.csv rename to benchmarks/dynamo/ci_expected_accuracy/cpu_aot_inductor_amp_freezing_torchbench_inference.csv diff --git a/benchmarks/dynamo/ci_expected_accuracy/cpu_aot_inductor_huggingface_freezing_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/cpu_aot_inductor_freezing_huggingface_inference.csv similarity index 100% rename from benchmarks/dynamo/ci_expected_accuracy/cpu_aot_inductor_huggingface_freezing_inference.csv rename to benchmarks/dynamo/ci_expected_accuracy/cpu_aot_inductor_freezing_huggingface_inference.csv diff --git a/benchmarks/dynamo/ci_expected_accuracy/cpu_aot_inductor_timm_freezing_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/cpu_aot_inductor_freezing_timm_inference.csv similarity index 100% rename from benchmarks/dynamo/ci_expected_accuracy/cpu_aot_inductor_timm_freezing_inference.csv rename to benchmarks/dynamo/ci_expected_accuracy/cpu_aot_inductor_freezing_timm_inference.csv diff --git a/benchmarks/dynamo/ci_expected_accuracy/cpu_aot_inductor_torchbench_freezing_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/cpu_aot_inductor_freezing_torchbench_inference.csv similarity index 100% rename from benchmarks/dynamo/ci_expected_accuracy/cpu_aot_inductor_torchbench_freezing_inference.csv rename to benchmarks/dynamo/ci_expected_accuracy/cpu_aot_inductor_freezing_torchbench_inference.csv diff --git a/benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_huggingface_amp_freezing_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_amp_freezing_huggingface_inference.csv similarity index 100% rename from benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_huggingface_amp_freezing_inference.csv rename to benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_amp_freezing_huggingface_inference.csv diff --git a/benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_timm_amp_freezing_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_amp_freezing_timm_inference.csv similarity index 100% rename from benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_timm_amp_freezing_inference.csv rename to benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_amp_freezing_timm_inference.csv diff --git a/benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_torchbench_amp_freezing_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_amp_freezing_torchbench_inference.csv similarity index 100% rename from benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_torchbench_amp_freezing_inference.csv rename to benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_amp_freezing_torchbench_inference.csv diff --git a/benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_huggingface_freezing_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_freezing_huggingface_inference.csv similarity index 100% rename from benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_huggingface_freezing_inference.csv rename to benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_freezing_huggingface_inference.csv diff --git a/benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_timm_freezing_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_freezing_timm_inference.csv similarity index 100% rename from benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_timm_freezing_inference.csv rename to benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_freezing_timm_inference.csv diff --git a/benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_torchbench_freezing_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_freezing_torchbench_inference.csv similarity index 100% rename from benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_torchbench_freezing_inference.csv rename to benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_freezing_torchbench_inference.csv diff --git a/benchmarks/dynamo/ci_expected_accuracy/dynamic_cpu_aot_inductor_torchbench_amp_freezing_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/dynamic_cpu_aot_inductor_amp_freezing_torchbench_inference.csv similarity index 100% rename from benchmarks/dynamo/ci_expected_accuracy/dynamic_cpu_aot_inductor_torchbench_amp_freezing_inference.csv rename to benchmarks/dynamo/ci_expected_accuracy/dynamic_cpu_aot_inductor_amp_freezing_torchbench_inference.csv diff --git a/benchmarks/dynamo/ci_expected_accuracy/dynamic_cpu_aot_inductor_torchbench_freezing_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/dynamic_cpu_aot_inductor_freezing_torchbench_inference.csv similarity index 100% rename from benchmarks/dynamo/ci_expected_accuracy/dynamic_cpu_aot_inductor_torchbench_freezing_inference.csv rename to benchmarks/dynamo/ci_expected_accuracy/dynamic_cpu_aot_inductor_freezing_torchbench_inference.csv From fae557180d1318519a09d6167e8813c52f35472d Mon Sep 17 00:00:00 2001 From: "Mao, Yunfei" Date: Wed, 28 Aug 2024 15:06:35 +0000 Subject: [PATCH 0149/1018] [Intel GPU] Remove special dispatch logic for xpu in adaptive_avg_pooling (#132217) We now align the dispatch logic for XPU with CUDA in the adaptive average pooling operation. Pull Request resolved: https://github.com/pytorch/pytorch/pull/132217 Approved by: https://github.com/EikanWang, https://github.com/atalman, https://github.com/albanD, https://github.com/malfet --- aten/src/ATen/native/AdaptiveAveragePooling.cpp | 2 +- aten/src/ATen/native/AdaptiveAveragePooling3d.cpp | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/aten/src/ATen/native/AdaptiveAveragePooling.cpp b/aten/src/ATen/native/AdaptiveAveragePooling.cpp index d252a464cd4ac1..8b1182982002fa 100644 --- a/aten/src/ATen/native/AdaptiveAveragePooling.cpp +++ b/aten/src/ATen/native/AdaptiveAveragePooling.cpp @@ -117,7 +117,7 @@ namespace { return at::mkldnn_adaptive_avg_pool2d(input, C10_AS_INTARRAYREF_SLOW(output_size)); } - if (!input.is_quantized() && output_size[0] == 1 && output_size[1] == 1 && !input.is_xpu()) { + if (!input.is_quantized() && output_size[0] == 1 && output_size[1] == 1) { // in this case, adaptive pooling is just computing mean over hw // dimensions, which can be done more efficiently #if defined(C10_MOBILE) && defined(USE_XNNPACK) diff --git a/aten/src/ATen/native/AdaptiveAveragePooling3d.cpp b/aten/src/ATen/native/AdaptiveAveragePooling3d.cpp index e5327185844fe4..4897864a378b71 100644 --- a/aten/src/ATen/native/AdaptiveAveragePooling3d.cpp +++ b/aten/src/ATen/native/AdaptiveAveragePooling3d.cpp @@ -313,7 +313,7 @@ Tensor adaptive_avg_pool3d_symint(Tensor const& input, SymIntArrayRef output_siz "adaptive_avg_pool3d: elements of output_size must be greater than or equal to 0 ", "but received {", output_size[0], ", ", output_size[1], ",", output_size[2], "}"); - if (output_size[0] == 1 && output_size[1] == 1 && output_size[2] == 1 && !input.is_xpu()) { + if (output_size[0] == 1 && output_size[1] == 1 && output_size[2] == 1) { // in this case, adaptive pooling is just computing mean over hw // dimensions, which can be done more efficiently Tensor out = input.mean({-1, -2, -3}, /* keepdim = */ true); From 788e78c4b16c8ef91f96b7ca853403409db72e60 Mon Sep 17 00:00:00 2001 From: chuanqiw Date: Wed, 28 Aug 2024 15:16:12 +0000 Subject: [PATCH 0150/1018] [CI] change conda to miniforge for XPU images (#134455) The `.ci/docker` change with `ciflow/xpu` label will trigger docker images rebuild on xpu runner, but xpu runner can't use miniconda, change to miniforge. Works for https://github.com/pytorch/pytorch/issues/114850 Pull Request resolved: https://github.com/pytorch/pytorch/pull/134455 Approved by: https://github.com/atalman --- .ci/docker/common/install_conda.sh | 22 ++++++---------------- .ci/docker/ubuntu-xpu/Dockerfile | 1 + 2 files changed, 7 insertions(+), 16 deletions(-) diff --git a/.ci/docker/common/install_conda.sh b/.ci/docker/common/install_conda.sh index 4d8a7b92e34ec6..f3c198044ffed0 100755 --- a/.ci/docker/common/install_conda.sh +++ b/.ci/docker/common/install_conda.sh @@ -5,32 +5,22 @@ set -ex # Optionally install conda if [ -n "$ANACONDA_PYTHON_VERSION" ]; then BASE_URL="https://repo.anaconda.com/miniconda" + CONDA_FILE="Miniconda3-latest-Linux-x86_64.sh" + if [[ $(uname -m) == "aarch64" ]] || [[ "$BUILD_ENVIRONMENT" == *xpu* ]]; then + BASE_URL="https://github.com/conda-forge/miniforge/releases/latest/download" + CONDA_FILE="Miniforge3-Linux-$(uname -m).sh" + fi MAJOR_PYTHON_VERSION=$(echo "$ANACONDA_PYTHON_VERSION" | cut -d . -f 1) MINOR_PYTHON_VERSION=$(echo "$ANACONDA_PYTHON_VERSION" | cut -d . -f 2) -if [[ $(uname -m) == "aarch64" ]]; then - BASE_URL="https://github.com/conda-forge/miniforge/releases/latest/download" - case "$MAJOR_PYTHON_VERSION" in - 3) - CONDA_FILE="Miniforge3-Linux-aarch64.sh" - ;; - *) - echo "Unsupported ANACONDA_PYTHON_VERSION: $ANACONDA_PYTHON_VERSION" - exit 1 - ;; - esac -else case "$MAJOR_PYTHON_VERSION" in - 3) - CONDA_FILE="Miniconda3-latest-Linux-x86_64.sh" - ;; + 3);; *) echo "Unsupported ANACONDA_PYTHON_VERSION: $ANACONDA_PYTHON_VERSION" exit 1 ;; esac -fi mkdir -p /opt/conda chown jenkins:jenkins /opt/conda diff --git a/.ci/docker/ubuntu-xpu/Dockerfile b/.ci/docker/ubuntu-xpu/Dockerfile index 02cd1133a050c3..41f690c4ab386c 100644 --- a/.ci/docker/ubuntu-xpu/Dockerfile +++ b/.ci/docker/ubuntu-xpu/Dockerfile @@ -30,6 +30,7 @@ RUN bash ./install_docs_reqs.sh && rm install_docs_reqs.sh ARG ANACONDA_PYTHON_VERSION ARG CONDA_CMAKE ARG DOCS +ARG BUILD_ENVIRONMENT ENV ANACONDA_PYTHON_VERSION=$ANACONDA_PYTHON_VERSION ENV PATH /opt/conda/envs/py_$ANACONDA_PYTHON_VERSION/bin:/opt/conda/bin:$PATH ENV DOCS=$DOCS From 5495a1598c266db5ad383eb6f58f953b8fde6baf Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Wed, 28 Aug 2024 15:49:18 +0000 Subject: [PATCH 0151/1018] Revert "Adding entry-point based support for out-of-tree rendezvous plugins (#132633)" This reverts commit 136b19b062f62c81ea3ed8fb306debe9d7720e93. Reverted https://github.com/pytorch/pytorch/pull/132633 on behalf of https://github.com/ZainRizvi due to Sorry but this is causing internal tests to fail with the error `ImportError: cannot import name '_register_out_of_tree_handlers' from 'torch.distributed.elastic.rendezvous.registry'` ([comment](https://github.com/pytorch/pytorch/pull/132633#issuecomment-2315716201)) --- .../rendezvous/out_of_tree_rendezvous_test.py | 38 ------------------- .../out_of_tree_test_package/pyproject.toml | 6 --- .../src/testbackend/__init__.py | 2 - .../elastic/rendezvous/__init__.py | 3 +- .../elastic/rendezvous/registry.py | 25 ------------ 5 files changed, 1 insertion(+), 73 deletions(-) delete mode 100644 test/distributed/elastic/rendezvous/out_of_tree_rendezvous_test.py delete mode 100644 test/distributed/elastic/rendezvous/out_of_tree_test_package/pyproject.toml delete mode 100644 test/distributed/elastic/rendezvous/out_of_tree_test_package/src/testbackend/__init__.py diff --git a/test/distributed/elastic/rendezvous/out_of_tree_rendezvous_test.py b/test/distributed/elastic/rendezvous/out_of_tree_rendezvous_test.py deleted file mode 100644 index 4d304bef1bc234..00000000000000 --- a/test/distributed/elastic/rendezvous/out_of_tree_rendezvous_test.py +++ /dev/null @@ -1,38 +0,0 @@ -# Owner(s): ["oncall: r2p"] - -# Copyright (c) Facebook, Inc. and its affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. -import pathlib -import sys -import unittest - -import torch.distributed.elastic.rendezvous as rdvz - - -BACKEND_NAME = "testbackend" -TEST_PACKAGE_PATH = "/out_of_tree_test_package/src" - - -class OutOfTreeRendezvousTest(unittest.TestCase): - def test_out_of_tree_handler_loading(self): - current_path = str(pathlib.Path(__file__).parent.resolve()) - rdvz._register_out_of_tree_handlers() - registry_dict = rdvz.rendezvous_handler_registry._registry - - # test backend should not be registered as a backend - self.assertFalse(BACKEND_NAME in registry_dict) - - # Including testbackend in python path - sys.path.append(current_path + TEST_PACKAGE_PATH) - - # Registering the out of tree handlers again - rdvz._register_out_of_tree_handlers() - - # test backend should be registered as a backend - self.assertTrue(BACKEND_NAME in registry_dict) - - # Removing testbackend from python path - sys.path.remove(current_path + TEST_PACKAGE_PATH) diff --git a/test/distributed/elastic/rendezvous/out_of_tree_test_package/pyproject.toml b/test/distributed/elastic/rendezvous/out_of_tree_test_package/pyproject.toml deleted file mode 100644 index 32e177bf73fe6b..00000000000000 --- a/test/distributed/elastic/rendezvous/out_of_tree_test_package/pyproject.toml +++ /dev/null @@ -1,6 +0,0 @@ -[project] -name = "testbackend" -version = "0.0.1" - -[project.entry-points.'torchrun.handlers'] -testbackend = 'testbackend:test_handler' \ No newline at end of file diff --git a/test/distributed/elastic/rendezvous/out_of_tree_test_package/src/testbackend/__init__.py b/test/distributed/elastic/rendezvous/out_of_tree_test_package/src/testbackend/__init__.py deleted file mode 100644 index 1fbf5b4c2dfb10..00000000000000 --- a/test/distributed/elastic/rendezvous/out_of_tree_test_package/src/testbackend/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -def test_handler(): - return "" diff --git a/torch/distributed/elastic/rendezvous/__init__.py b/torch/distributed/elastic/rendezvous/__init__.py index 22ec0c9a0f6758..62a31adab27b01 100644 --- a/torch/distributed/elastic/rendezvous/__init__.py +++ b/torch/distributed/elastic/rendezvous/__init__.py @@ -143,11 +143,10 @@ class that implements the rendezvous mechanism described above. It is a backend- RendezvousStoreInfo, RendezvousTimeoutError, ) -from .registry import _register_default_handlers, _register_out_of_tree_handlers +from .registry import _register_default_handlers _register_default_handlers() -_register_out_of_tree_handlers() __all__ = [ diff --git a/torch/distributed/elastic/rendezvous/registry.py b/torch/distributed/elastic/rendezvous/registry.py index d038ab95eabae6..1a91d0a8ff7946 100644 --- a/torch/distributed/elastic/rendezvous/registry.py +++ b/torch/distributed/elastic/rendezvous/registry.py @@ -4,9 +4,6 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -import logging -import sys - from .api import ( rendezvous_handler_registry as handler_registry, RendezvousHandler, @@ -15,13 +12,6 @@ from .dynamic_rendezvous import create_handler -if sys.version_info < (3, 10): - from importlib_metadata import entry_points -else: - from importlib.metadata import entry_points - -log = logging.getLogger(__name__) - __all__ = ["get_rendezvous_handler"] @@ -60,21 +50,6 @@ def _register_default_handlers() -> None: handler_registry.register("static", _create_static_handler) -def _register_out_of_tree_handlers() -> None: - discovered_handler_generators = entry_points(group="torchrun.handlers") - - for handler_generator in discovered_handler_generators: - try: - get_handler = discovered_handler_generators[handler_generator.name].load() - handler_registry.register(handler_generator.name, get_handler()) - except Exception: - log.warning( - "Exception while registering out of tree plugin %s: ", - handler_generator.name, - exc_info=True, - ) - - def get_rendezvous_handler(params: RendezvousParameters) -> RendezvousHandler: """ Obtain a reference to a :py:class`RendezvousHandler`. From 1b679ec6927df93cc56afd258b852bef2dfed528 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Wed, 28 Aug 2024 15:56:08 +0000 Subject: [PATCH 0152/1018] Revert "Refactor caching device allocator utils (#130923)" This reverts commit c45ca8092dddf718563a1a754de798ad25eae1ee. Reverted https://github.com/pytorch/pytorch/pull/130923 on behalf of https://github.com/ZainRizvi due to Sorry but this appears to be causing internal tests to fail with errors like `error: no type named 'DeviceStats' in namespace 'xxx::xxx:xxxAllocator'; did you mean 'DeviceStatus'?` ([comment](https://github.com/pytorch/pytorch/pull/130923#issuecomment-2315730155)) --- c10/core/CachingDeviceAllocator.h | 131 ----------------- c10/cuda/CUDACachingAllocator.cpp | 163 ++++++++++++++------- c10/cuda/CUDACachingAllocator.h | 78 +++++++++- c10/cuda/CUDAMallocAsyncAllocator.cpp | 2 - torch/csrc/cuda/CUDAPluggableAllocator.cpp | 4 +- torch/csrc/cuda/CUDAPluggableAllocator.h | 2 +- torch/csrc/cuda/Module.cpp | 8 +- 7 files changed, 190 insertions(+), 198 deletions(-) delete mode 100644 c10/core/CachingDeviceAllocator.h diff --git a/c10/core/CachingDeviceAllocator.h b/c10/core/CachingDeviceAllocator.h deleted file mode 100644 index 8724ecf88ae0f9..00000000000000 --- a/c10/core/CachingDeviceAllocator.h +++ /dev/null @@ -1,131 +0,0 @@ -#pragma once - -#include -#include - -#include - -namespace c10::CachingDeviceAllocator { - -struct Stat { - void increase(size_t amount) { - current += static_cast(amount); - peak = std::max(current, peak); - allocated += static_cast(amount); - } - - void decrease(size_t amount) { - current -= static_cast(amount); - TORCH_INTERNAL_ASSERT_DEBUG_ONLY( - current >= 0, - "Negative tracked stat in device allocator (likely logic error)."); - freed += static_cast(amount); - } - - void reset_accumulated() { - allocated = 0; - freed = 0; - } - - void reset_peak() { - peak = current; - } - - int64_t current = 0; - int64_t peak = 0; - int64_t allocated = 0; - int64_t freed = 0; -}; - -enum struct StatType : uint64_t { - AGGREGATE = 0, - SMALL_POOL = 1, - LARGE_POOL = 2, - NUM_TYPES = 3 // remember to update this whenever a new stat type is added -}; - -using StatArray = std::array(StatType::NUM_TYPES)>; -using StatTypes = std::array(StatType::NUM_TYPES)>; - -template -void for_each_selected_stat_type(const StatTypes& stat_types, Func f) { - for (const auto stat_type : c10::irange(stat_types.size())) { - if (stat_types[stat_type]) { - f(stat_type); - } - } -} - -// Struct containing memory allocator summary statistics for a device. -struct DeviceStats { - // COUNT: allocations requested by client code - StatArray allocation; - // COUNT: number of allocated segments from device memory allocation. - StatArray segment; - // COUNT: number of active memory blocks (allocated or used by stream) - StatArray active; - // COUNT: number of inactive, split memory blocks (unallocated but can't be - // released via device memory deallocation) - StatArray inactive_split; - - // SUM: bytes allocated by this memory alocator - StatArray allocated_bytes; - // SUM: bytes reserved by this memory allocator (both free and used) - StatArray reserved_bytes; - // SUM: bytes within active memory blocks - StatArray active_bytes; - // SUM: bytes within inactive, split memory blocks - StatArray inactive_split_bytes; - // SUM: bytes requested by client code - StatArray requested_bytes; - - // COUNT: total number of failed calls to device malloc necessitating cache - // flushes. - int64_t num_alloc_retries = 0; - - // COUNT: total number of OOMs (i.e. failed calls to device memory allocation - // after cache flush) - int64_t num_ooms = 0; - - // COUNT: total number of oversize blocks allocated from pool - Stat oversize_allocations; - - // COUNT: total number of oversize blocks requiring malloc - Stat oversize_segments; - - // COUNT: total number of synchronize_and_free_events() calls - int64_t num_sync_all_streams = 0; - - // COUNT: total number of device memory allocation calls. This includes both - // mapped and malloced memory. - int64_t num_device_alloc = 0; - - // COUNT: total number of device memory deallocation calls. This includes both - // un-mapped and free memory. - int64_t num_device_free = 0; - - // SIZE: maximum block size that is allowed to be split. - int64_t max_split_size = 0; -}; - -// Size pretty-printer -inline std::string format_size(uint64_t size) { - std::ostringstream os; - os.precision(2); - os << std::fixed; - if (size <= 1024) { - os << size << " bytes"; - } else if (size <= 1048576) { - os << (static_cast(size) / 1024.0); - os << " KiB"; - } else if (size <= 1073741824ULL) { - os << static_cast(size) / 1048576.0; - os << " MiB"; - } else { - os << static_cast(size) / 1073741824.0; - os << " GiB"; - } - return os.str(); -} - -} // namespace c10::CachingDeviceAllocator diff --git a/c10/cuda/CUDACachingAllocator.cpp b/c10/cuda/CUDACachingAllocator.cpp index a38f202394e5a3..4a3a2c7545c54c 100644 --- a/c10/cuda/CUDACachingAllocator.cpp +++ b/c10/cuda/CUDACachingAllocator.cpp @@ -10,6 +10,7 @@ #include #include #include +#include #include #include @@ -26,6 +27,7 @@ #include #include #include +#include #include #include #include @@ -42,8 +44,6 @@ C10_DEFINE_REGISTRY(FreeCudaMemoryCallbacksRegistry, FreeMemoryCallback); namespace cuda::CUDACachingAllocator { -using namespace c10::CachingDeviceAllocator; - // Included here as this is externally used in CUDAAllocatorConfig const size_t kLargeBuffer = 20971520; // "large" allocations may be packed in 20 MiB blocks @@ -134,13 +134,47 @@ namespace { using stream_set = ska::flat_hash_set; +using StatTypes = std::array(StatType::NUM_TYPES)>; + +void increase_stat(Stat& stat, size_t amount) { + stat.current += static_cast(amount); + stat.peak = std::max(stat.current, stat.peak); + stat.allocated += static_cast(amount); +} + +void decrease_stat(Stat& stat, size_t amount) { + stat.current -= static_cast(amount); + TORCH_INTERNAL_ASSERT_DEBUG_ONLY( + stat.current >= 0, + "Negative tracked stat in CUDA allocator (likely logic error)."); + stat.freed += static_cast(amount); +} + +void reset_accumulated_stat(Stat& stat) { + stat.allocated = 0; + stat.freed = 0; +} + +void reset_peak_stat(Stat& stat) { + stat.peak = stat.current; +} + +template +void for_each_selected_stat_type(const StatTypes& stat_types, Func f) { + for (const auto stat_type : c10::irange(stat_types.size())) { + if (stat_types[stat_type]) { + f(stat_type); + } + } +} + void decrease_stat_array( StatArray& stat_array, size_t amount, const StatTypes& stat_types) { for_each_selected_stat_type( stat_types, [&stat_array, amount](size_t stat_type) { - stat_array[stat_type].decrease(amount); + decrease_stat(stat_array[stat_type], amount); }); } @@ -1390,16 +1424,16 @@ class DeviceCachingAllocator { // A new split inactive block is being created from a previously unsplit // block, size remaining->size bytes. for_each_selected_stat_type(params.stat_types, [&](size_t stat_type) { - stats.inactive_split_bytes[stat_type].increase(remaining->size); - stats.inactive_split[stat_type].increase(1); + increase_stat(stats.inactive_split_bytes[stat_type], remaining->size); + increase_stat(stats.inactive_split[stat_type], 1); }); } } else if (already_split && !block->expandable_segment_) { // An already-split block is becoming active for_each_selected_stat_type(params.stat_types, [&](size_t stat_type) { - stats.inactive_split_bytes[stat_type].decrease(block->size); - stats.inactive_split[stat_type].decrease(1); + decrease_stat(stats.inactive_split_bytes[stat_type], block->size); + decrease_stat(stats.inactive_split[stat_type], 1); }); } @@ -1420,14 +1454,14 @@ class DeviceCachingAllocator { TORCH_INTERNAL_ASSERT_DEBUG_ONLY(inserted); for_each_selected_stat_type(params.stat_types, [&](size_t stat_type) { - stats.allocation[stat_type].increase(1); - stats.allocated_bytes[stat_type].increase(block->size); - stats.active[stat_type].increase(1); - stats.active_bytes[stat_type].increase(block->size); - stats.requested_bytes[stat_type].increase(block->requested_size); + increase_stat(stats.allocation[stat_type], 1); + increase_stat(stats.allocated_bytes[stat_type], block->size); + increase_stat(stats.active[stat_type], 1); + increase_stat(stats.active_bytes[stat_type], block->size); + increase_stat(stats.requested_bytes[stat_type], block->requested_size); }); if (block->size >= CUDAAllocatorConfig::max_split_size()) - stats.oversize_allocations.increase(1); + increase_stat(stats.oversize_allocations, 1); c10::reportMemoryUsageToProfiler( block->ptr, @@ -1453,8 +1487,8 @@ class DeviceCachingAllocator { StatTypes stat_types = get_stat_types_for_pool(*block->pool); for_each_selected_stat_type(stat_types, [&](size_t stat_type) { - stats.allocation[stat_type].decrease(1); - stats.allocated_bytes[stat_type].decrease(block->size); + decrease_stat(stats.allocation[stat_type], 1); + decrease_stat(stats.allocated_bytes[stat_type], block->size); }); record_trace( @@ -1466,7 +1500,7 @@ class DeviceCachingAllocator { context ? context : block->context_when_allocated); if (block->size >= CUDAAllocatorConfig::max_split_size()) - stats.oversize_allocations.decrease(1); + decrease_stat(stats.oversize_allocations, 1); if (!block->stream_uses.empty()) { if (C10_UNLIKELY(!captures_underway.empty())) { @@ -1594,15 +1628,15 @@ class DeviceCachingAllocator { for (const auto statType : c10::irange(static_cast(StatType::NUM_TYPES))) { - stats.allocation[statType].reset_accumulated(); - stats.segment[statType].reset_accumulated(); - stats.active[statType].reset_accumulated(); - stats.inactive_split[statType].reset_accumulated(); - stats.allocated_bytes[statType].reset_accumulated(); - stats.reserved_bytes[statType].reset_accumulated(); - stats.active_bytes[statType].reset_accumulated(); - stats.inactive_split_bytes[statType].reset_accumulated(); - stats.requested_bytes[statType].reset_accumulated(); + reset_accumulated_stat(stats.allocation[statType]); + reset_accumulated_stat(stats.segment[statType]); + reset_accumulated_stat(stats.active[statType]); + reset_accumulated_stat(stats.inactive_split[statType]); + reset_accumulated_stat(stats.allocated_bytes[statType]); + reset_accumulated_stat(stats.reserved_bytes[statType]); + reset_accumulated_stat(stats.active_bytes[statType]); + reset_accumulated_stat(stats.inactive_split_bytes[statType]); + reset_accumulated_stat(stats.requested_bytes[statType]); } stats.num_alloc_retries = 0; @@ -1610,8 +1644,8 @@ class DeviceCachingAllocator { stats.num_sync_all_streams = 0; stats.num_device_alloc = 0; stats.num_device_free = 0; - stats.oversize_allocations.reset_accumulated(); - stats.oversize_segments.reset_accumulated(); + reset_accumulated_stat(stats.oversize_allocations); + reset_accumulated_stat(stats.oversize_segments); } /** Resets the historical peak stats for the device **/ @@ -1620,18 +1654,18 @@ class DeviceCachingAllocator { for (const auto statType : c10::irange(static_cast(StatType::NUM_TYPES))) { - stats.allocation[statType].reset_peak(); - stats.segment[statType].reset_peak(); - stats.active[statType].reset_peak(); - stats.inactive_split[statType].reset_peak(); - stats.allocated_bytes[statType].reset_peak(); - stats.reserved_bytes[statType].reset_peak(); - stats.active_bytes[statType].reset_peak(); - stats.inactive_split_bytes[statType].reset_peak(); - stats.requested_bytes[statType].reset_peak(); + reset_peak_stat(stats.allocation[statType]); + reset_peak_stat(stats.segment[statType]); + reset_peak_stat(stats.active[statType]); + reset_peak_stat(stats.inactive_split[statType]); + reset_peak_stat(stats.allocated_bytes[statType]); + reset_peak_stat(stats.reserved_bytes[statType]); + reset_peak_stat(stats.active_bytes[statType]); + reset_peak_stat(stats.inactive_split_bytes[statType]); + reset_peak_stat(stats.requested_bytes[statType]); } - stats.oversize_allocations.reset_peak(); - stats.oversize_segments.reset_peak(); + reset_peak_stat(stats.oversize_allocations); + reset_peak_stat(stats.oversize_segments); } /* Checkpoint the state of a private pool necessary to return it to its @@ -2243,7 +2277,7 @@ class DeviceCachingAllocator { total_allocated_memory += mapped_range.size; StatTypes stat_types = get_stat_types_for_pool(*to_map->pool); for_each_selected_stat_type(stat_types, [&](size_t stat_type) { - stats.reserved_bytes[stat_type].increase(mapped_range.size); + increase_stat(stats.reserved_bytes[stat_type], mapped_range.size); }); stats.num_device_alloc++; @@ -2350,23 +2384,27 @@ class DeviceCachingAllocator { // inactive_split if (!block->expandable_segment_) { if (net_change_inactive_split_blocks > 0) { - stats.inactive_split[stat_type].increase( + increase_stat( + stats.inactive_split[stat_type], static_cast(net_change_inactive_split_blocks)); } else if (net_change_inactive_split_blocks < 0) { - stats.inactive_split[stat_type].decrease( + decrease_stat( + stats.inactive_split[stat_type], static_cast(-net_change_inactive_split_blocks)); } if (net_change_inactive_split_size > 0) { - stats.inactive_split_bytes[stat_type].increase( + increase_stat( + stats.inactive_split_bytes[stat_type], static_cast(net_change_inactive_split_size)); } else if (net_change_inactive_split_size < 0) { - stats.inactive_split_bytes[stat_type].decrease( + decrease_stat( + stats.inactive_split_bytes[stat_type], static_cast(-net_change_inactive_split_size)); } } - stats.active[stat_type].decrease(1); - stats.active_bytes[stat_type].decrease(original_block_size); - stats.requested_bytes[stat_type].decrease(requested_size); + decrease_stat(stats.active[stat_type], 1); + decrease_stat(stats.active_bytes[stat_type], original_block_size); + decrease_stat(stats.requested_bytes[stat_type], requested_size); }); } @@ -2673,11 +2711,11 @@ class DeviceCachingAllocator { total_allocated_memory += size; p.block = new Block(p.device(), p.stream(), size, p.pool, (char*)ptr); for_each_selected_stat_type(p.stat_types, [&](size_t stat_type) { - stats.segment[stat_type].increase(1); - stats.reserved_bytes[stat_type].increase(size); + increase_stat(stats.segment[stat_type], 1); + increase_stat(stats.reserved_bytes[stat_type], size); }); if (size >= CUDAAllocatorConfig::max_split_size()) - stats.oversize_segments.increase(1); + increase_stat(stats.oversize_segments, 1); // p.block came from new, not cudaMalloc. It should not be nullptr here. TORCH_INTERNAL_ASSERT(p.block != nullptr && p.block->ptr != nullptr); @@ -2808,12 +2846,12 @@ class DeviceCachingAllocator { StatTypes stat_types = get_stat_types_for_pool(*pool); for_each_selected_stat_type(stat_types, [&](size_t stat_type) { - stats.segment[stat_type].decrease(1); - stats.reserved_bytes[stat_type].decrease(block->size); + decrease_stat(stats.segment[stat_type], 1); + decrease_stat(stats.reserved_bytes[stat_type], block->size); }); if (block->size >= CUDAAllocatorConfig::max_split_size()) - stats.oversize_segments.decrease(1); + decrease_stat(stats.oversize_segments, 1); pool->blocks.erase(block); delete block; } @@ -2865,7 +2903,7 @@ class DeviceCachingAllocator { total_allocated_memory -= unmapped.size; StatTypes stat_types = get_stat_types_for_pool(*block->pool); for_each_selected_stat_type(stat_types, [&](size_t stat_type) { - stats.reserved_bytes[stat_type].decrease(unmapped.size); + decrease_stat(stats.reserved_bytes[stat_type], unmapped.size); }); if (block->pool->owner_PrivatePool) { @@ -3694,6 +3732,25 @@ void local_raw_delete(void* ptr) { } } // namespace Native +// Size pretty-printer +std::string format_size(uint64_t size) { + std::ostringstream os; + os.precision(2); + os << std::fixed; + if (size <= 1024) { + os << size << " bytes"; + } else if (size <= 1048576) { + os << (static_cast(size) / 1024.0); + os << " KiB"; + } else if (size <= 1073741824ULL) { + os << static_cast(size) / 1048576.0; + os << " MiB"; + } else { + os << static_cast(size) / 1073741824.0; + os << " GiB"; + } + return os.str(); +} namespace CudaMallocAsync { // If this is put in its own header file, it gets incorrectly renamed in HIPify. diff --git a/c10/cuda/CUDACachingAllocator.h b/c10/cuda/CUDACachingAllocator.h index 9d1a631f3750ba..72617bcaf3a944 100644 --- a/c10/cuda/CUDACachingAllocator.h +++ b/c10/cuda/CUDACachingAllocator.h @@ -1,6 +1,6 @@ #pragma once -#include +#include #include #include #include @@ -50,6 +50,73 @@ namespace c10::cuda::CUDACachingAllocator { extern const size_t kLargeBuffer; +struct Stat { + int64_t current = 0; + int64_t peak = 0; + int64_t allocated = 0; + int64_t freed = 0; +}; + +enum struct StatType : uint64_t { + AGGREGATE = 0, + SMALL_POOL = 1, + LARGE_POOL = 2, + NUM_TYPES = 3 // remember to update this whenever a new stat type is added +}; + +typedef std::array(StatType::NUM_TYPES)> StatArray; + +// Struct containing memory allocator summary statistics for a device. +struct DeviceStats { + // COUNT: allocations requested by client code + StatArray allocation; + // COUNT: number of allocated segments from cudaMalloc(). + StatArray segment; + // COUNT: number of active memory blocks (allocated or used by stream) + StatArray active; + // COUNT: number of inactive, split memory blocks (unallocated but can't be + // released via cudaFree) + StatArray inactive_split; + + // SUM: bytes allocated by this memory alocator + StatArray allocated_bytes; + // SUM: bytes reserved by this memory allocator (both free and used) + StatArray reserved_bytes; + // SUM: bytes within active memory blocks + StatArray active_bytes; + // SUM: bytes within inactive, split memory blocks + StatArray inactive_split_bytes; + // SUM: bytes requested by client code + StatArray requested_bytes; + + // COUNT: total number of failed calls to CUDA malloc necessitating cache + // flushes. + int64_t num_alloc_retries = 0; + + // COUNT: total number of OOMs (i.e. failed calls to CUDA after cache flush) + int64_t num_ooms = 0; + + // COUNT: total number of oversize blocks allocated from pool + Stat oversize_allocations; + + // COUNT: total number of oversize blocks requiring malloc + Stat oversize_segments; + + // COUNT: total number of synchronize_and_free_events() calls + int64_t num_sync_all_streams = 0; + + // COUNT: total number of CUDA allocation calls. This includes both cuMemMap + // and cudaMalloc. + int64_t num_device_alloc = 0; + + // COUNT: total number of CUDA free calls. This includes both cuMemUnmap + // and cudaFree. + int64_t num_device_free = 0; + + // SIZE: maximum block size that is allowed to be split. + int64_t max_split_size = 0; +}; + typedef std::shared_ptr (*CreateContextFn)(); // Struct containing info of an allocation block (i.e. a fractional part of a @@ -180,6 +247,9 @@ enum struct RecordContext { ALL = 3, // additionally record stacks for when something is freed }; +// Size pretty-printer +std::string format_size(uint64_t size); + using OutOfMemoryObserver = std::functionrecordStream(dataPtr, stream); } -inline c10::CachingDeviceAllocator::DeviceStats getDeviceStats( - c10::DeviceIndex device) { +inline DeviceStats getDeviceStats(c10::DeviceIndex device) { return get()->getDeviceStats(device); } diff --git a/c10/cuda/CUDAMallocAsyncAllocator.cpp b/c10/cuda/CUDAMallocAsyncAllocator.cpp index 3a7b485ebce223..3fe414b55e30a2 100644 --- a/c10/cuda/CUDAMallocAsyncAllocator.cpp +++ b/c10/cuda/CUDAMallocAsyncAllocator.cpp @@ -11,8 +11,6 @@ namespace c10::cuda::CUDACachingAllocator::CudaMallocAsync { -using namespace c10::CachingDeviceAllocator; - #if CUDA_VERSION >= 11040 // CUDA device allocator that uses cudaMallocAsync to implement // the same interface as CUDACachingAllocator.cpp. diff --git a/torch/csrc/cuda/CUDAPluggableAllocator.cpp b/torch/csrc/cuda/CUDAPluggableAllocator.cpp index 5220e86233bd67..c6af163481481e 100644 --- a/torch/csrc/cuda/CUDAPluggableAllocator.cpp +++ b/torch/csrc/cuda/CUDAPluggableAllocator.cpp @@ -210,8 +210,8 @@ void CUDAPluggableAllocator::recordStream( } } -c10::CachingDeviceAllocator::DeviceStats CUDAPluggableAllocator::getDeviceStats( - c10::DeviceIndex device) { +c10::cuda::CUDACachingAllocator::DeviceStats CUDAPluggableAllocator:: + getDeviceStats(c10::DeviceIndex device) { TORCH_CHECK( false, "CUDAPluggableAllocator does not yet support getDeviceStats. " diff --git a/torch/csrc/cuda/CUDAPluggableAllocator.h b/torch/csrc/cuda/CUDAPluggableAllocator.h index 8652ef0f2bfde8..70014b2fa0b37b 100644 --- a/torch/csrc/cuda/CUDAPluggableAllocator.h +++ b/torch/csrc/cuda/CUDAPluggableAllocator.h @@ -115,7 +115,7 @@ struct TORCH_CUDA_CPP_API CUDAPluggableAllocator void recordStream(const c10::DataPtr&, streamType stream) override; - c10::CachingDeviceAllocator::DeviceStats getDeviceStats( + c10::cuda::CUDACachingAllocator::DeviceStats getDeviceStats( c10::DeviceIndex device) override; void resetAccumulatedStats(c10::DeviceIndex device) override; void resetPeakStats(c10::DeviceIndex device) override; diff --git a/torch/csrc/cuda/Module.cpp b/torch/csrc/cuda/Module.cpp index 0106eebb1cd6ea..1cc6f7378e80c4 100644 --- a/torch/csrc/cuda/Module.cpp +++ b/torch/csrc/cuda/Module.cpp @@ -554,10 +554,10 @@ PyObject* THCPModule_memoryStats(PyObject* _unused, PyObject* arg) { TORCH_CHECK(THPUtils_checkLong(arg), "invalid argument to memory_allocated"); const auto device_index = THPUtils_unpackDeviceIndex(arg); - using c10::CachingDeviceAllocator::DeviceStats; - using c10::CachingDeviceAllocator::Stat; - using c10::CachingDeviceAllocator::StatArray; - using c10::CachingDeviceAllocator::StatType; + using c10::cuda::CUDACachingAllocator::DeviceStats; + using c10::cuda::CUDACachingAllocator::Stat; + using c10::cuda::CUDACachingAllocator::StatArray; + using c10::cuda::CUDACachingAllocator::StatType; const auto statToDict = [](const Stat& stat) { py::dict dict; From 89f35ad9011039b5e5bc2daaaf229c6673e04732 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Wed, 28 Aug 2024 16:05:42 +0000 Subject: [PATCH 0153/1018] Revert "hang dim hint constants off Dim (#134484)" This reverts commit c142af7209a423a05504fdec50680333f5a37629. Reverted https://github.com/pytorch/pytorch/pull/134484 on behalf of https://github.com/facebook-github-bot due to Diff reverted internally ([comment](https://github.com/pytorch/pytorch/pull/134484#issuecomment-2315749549)) --- docs/source/export.rst | 1 + test/export/test_export.py | 32 +++++++++++----- torch/_export/non_strict_utils.py | 4 +- torch/export/dynamic_shapes.py | 62 ++++++++++++++----------------- 4 files changed, 53 insertions(+), 46 deletions(-) diff --git a/docs/source/export.rst b/docs/source/export.rst index 3908beeb9b345e..deb84548a19cdc 100644 --- a/docs/source/export.rst +++ b/docs/source/export.rst @@ -676,6 +676,7 @@ API Reference .. autofunction:: save .. autofunction:: load .. autofunction:: register_dataclass +.. autoclass:: torch.export.dynamic_shapes.DIM .. autofunction:: torch.export.dynamic_shapes.Dim .. autofunction:: dims .. autoclass:: torch.export.dynamic_shapes.ShapesCollection diff --git a/test/export/test_export.py b/test/export/test_export.py index 86364e061e2888..2393443b5f0b0f 100644 --- a/test/export/test_export.py +++ b/test/export/test_export.py @@ -1773,6 +1773,8 @@ def forward(self, x, y, y1, z): ) def test_static_dim_constraints(self): + from torch.export.dynamic_shapes import DIM + class Foo(torch.nn.Module): def __init__(self) -> None: super().__init__() @@ -1801,7 +1803,7 @@ def forward(self, x, y, z): ((dx, None), (dy, 4), (dz, 3)), ((None, 6), (5, None), (None, None)), ((4, 6), {0: None, 1: 4}, {0: None, 1: 3}), - (None, None, (Dim.STATIC, Dim.STATIC)), + (None, None, (DIM.STATIC, DIM.STATIC)), ]: ep = export(foo, inputs, dynamic_shapes=dynamic_shapes) self.assertEqual(foo(*inputs), ep.module()(*inputs)) @@ -1948,7 +1950,9 @@ def forward(self, inp: Inp): self.assertEqual(str(tuple(node.meta["val"].shape)), f"({sym},)") def test_mismatched_dynamic_shapes(self): - AUTO, STATIC = Dim.AUTO, Dim.STATIC + from torch.export.dynamic_shapes import DIM + + AUTO, STATIC = DIM.AUTO, DIM.STATIC class M(torch.nn.Module): def forward(self, x): @@ -1980,7 +1984,7 @@ def forward(self, x): + re.escape( "specified at `dynamic_shapes[0]['k']['k'][0]` " "(expected either a list/tuple of dimensions, or a dict mapping indices to dimensions," - " where each dimension is None, an int, a Dim, Dim.AUTO, or Dim.STATIC)" + " where each dimension is None, an int, a Dim, DIM.AUTO, or DIM.STATIC)" ), ): export(M(), inputs, dynamic_shapes=dynamic_shapes) @@ -2057,7 +2061,7 @@ def forward(self, x): with self.assertRaisesRegex( torch._dynamo.exc.UserError, re.escape( - "Specifying both `Dim.AUTO` and `Dim` or `DerivedDim` in `dynamic_shapes` is not well supported at the moment, " + "Specifying both `DIM.AUTO` and `Dim` or `DerivedDim` in `dynamic_shapes` is not well supported at the moment, " "and can easily lead to constraint violation errors or obscure errors in torch.export." ), ): @@ -2461,7 +2465,9 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: def test_mark_and_auto_dynamic(self): # for this use case, mark_dynamic() and AUTO should have same effect. # check that same symbol gets allocated to both dims without raising constraint violation. - AUTO, STATIC = Dim.AUTO, Dim.STATIC + from torch.export.dynamic_shapes import DIM + + AUTO, STATIC = DIM.AUTO, DIM.STATIC class Foo(torch.nn.Module): def forward(self, x, y): @@ -2489,7 +2495,9 @@ def forward(self, x, y): def test_dont_duck_size_for_auto_dynamic(self): # for this use case, mark_dynamic() and AUTO should have same effect. # check that same symbol gets allocated to both dims without raising constraint violation. - AUTO, STATIC = Dim.AUTO, Dim.STATIC + from torch.export.dynamic_shapes import DIM + + AUTO, STATIC = DIM.AUTO, DIM.STATIC class Foo(torch.nn.Module): def forward(self, x, y): @@ -6868,7 +6876,9 @@ def test_automatic_dynamic_shapes_simple_equality(self): # The next 3 test cases tests for automatic dynamic shapes specs, verifying that automatic dynamism # leads to replacement symbols being set for equalities, and inferred relationships being checked # with runtime asserts. Check that we specialize to static values when the program says so. - AUTO, STATIC = Dim.AUTO, Dim.STATIC + from torch.export.dynamic_shapes import DIM + + AUTO, STATIC = DIM.AUTO, DIM.STATIC # case 1: direct equality between symbols class SimpleEquality(torch.nn.Module): @@ -6933,7 +6943,9 @@ def forward(self, x, y, z): ) def test_automatic_dynamic_shapes_constant_relation(self): - AUTO, STATIC = Dim.AUTO, Dim.STATIC + from torch.export.dynamic_shapes import DIM + + AUTO, STATIC = DIM.AUTO, DIM.STATIC # case 2: related by constant: s0 + 4 = s1 class OffBy4(torch.nn.Module): @@ -6976,7 +6988,9 @@ def forward(self, x, y): ) def test_automatic_dynamic_shapes_linear_relation(self): - AUTO, STATIC = Dim.AUTO, Dim.STATIC + from torch.export.dynamic_shapes import DIM + + AUTO, STATIC = DIM.AUTO, DIM.STATIC # case 3: linear relation class LinearRel(torch.nn.Module): diff --git a/torch/_export/non_strict_utils.py b/torch/_export/non_strict_utils.py index ddb51064f9ad78..8b5ff29b6c9e02 100644 --- a/torch/_export/non_strict_utils.py +++ b/torch/_export/non_strict_utils.py @@ -24,10 +24,10 @@ from torch.export.dynamic_shapes import ( _check_dynamic_shapes, _combine_args, - _DimHint, _process_dynamic_shapes, _transform_shapes_for_default_dynamic, _tree_map_with_path, + DIM, ) from torch.export.graph_signature import CustomObjArgument from torch.fx.experimental import _config as config @@ -351,7 +351,7 @@ def make_constraints( # we want the symbol, not its replacement, which could be an expression. Maybe # there's a better way to do this, e.g., by (re)computing value ranges for expressions? dim = shape_spec[i] if shape_spec else None - if dim is None or isinstance(dim, _DimHint): + if dim is None or isinstance(dim, DIM): range_constraints[d.node.expr] = shape_env.var_to_range[ d.node._expr ] diff --git a/torch/export/dynamic_shapes.py b/torch/export/dynamic_shapes.py index 53266b59e04a33..48438ee8cb2a55 100644 --- a/torch/export/dynamic_shapes.py +++ b/torch/export/dynamic_shapes.py @@ -30,21 +30,20 @@ __all__ = [ "Constraint", + "DIM", "Dim", "dims", "refine_dynamic_shapes_from_suggested_fixes", ] -class _DimHint(Enum): +class DIM(Enum): """ - Enum for dynamic shape hints. - - AUTO means automatic inference of shape (static or dynamic). - - STATIC means static shape (always specialized). + Enum for automatic/static dynamic shapes. """ - AUTO = auto() STATIC = auto() + AUTO = auto() class _Dim(type): @@ -215,7 +214,6 @@ def Dim(name: str, *, min: Optional[int] = None, max: Optional[int] = None): Returns: A type that can be used in dynamic shape specifications for tensors. """ - from torch.utils._sympy.numbers import int_oo _min = 0 if min is None else min @@ -229,10 +227,6 @@ def Dim(name: str, *, min: Optional[int] = None, max: Optional[int] = None): return dim -Dim.AUTO = _DimHint.AUTO # hint for automatic inference of shape (static or dynamic) -Dim.STATIC = _DimHint.STATIC # hint for static shape (always specialized) - - def dims(*names: str, min: Optional[int] = None, max: Optional[int] = None): """ Util to create multiple :func:`Dim` types. @@ -674,24 +668,24 @@ def check_symbols(path, tensor, shape): for i, dim in shape.items(): if isinstance(dim, _Dim): check_same_bounds(dim) - elif not (isinstance(dim, (int, _DimHint)) or dim is None): + elif not (isinstance(dim, (int, DIM)) or dim is None): raise UserError( UserErrorType.INVALID_INPUT, f"Unexpected dimension mapped to index {i} in input tensor shape {shape} " f"specified at `dynamic_shapes{keystr(path)}` " - f"(expected None, an int, a Dim, Dim.AUTO, or Dim.STATIC, but got {dim} instead)", + f"(expected None, an int, a Dim, DIM.AUTO, or DIM.STATIC, but got {dim} instead)", case_name="dynamic_shapes_validation", ) elif isinstance(shape, (tuple, list)): for i, dim in enumerate(shape): if isinstance(dim, _Dim): check_same_bounds(dim) - elif not (isinstance(dim, (int, _DimHint)) or dim is None): + elif not (isinstance(dim, (int, DIM)) or dim is None): raise UserError( UserErrorType.INVALID_INPUT, f"Unexpected dimension #{i} in input tensor shape {shape} " f"specified at `dynamic_shapes{keystr(path)}` " - f"(expected None, an int, a Dim, Dim.AUTO, or Dim.STATIC, but got {dim} instead)", + f"(expected None, an int, a Dim, DIM.AUTO, or DIM.STATIC, but got {dim} instead)", case_name="dynamic_shapes_validation", ) elif shape is not None: @@ -699,7 +693,7 @@ def check_symbols(path, tensor, shape): UserErrorType.INVALID_INPUT, f"Unexpected input tensor shape {shape} specified at `dynamic_shapes{keystr(path)}` " f"(expected either a list/tuple of dimensions, or a dict mapping indices to dimensions," - f" where each dimension is None, an int, a Dim, Dim.AUTO, or Dim.STATIC)", + f" where each dimension is None, an int, a Dim, DIM.AUTO, or DIM.STATIC)", case_name="dynamic_shapes_validation", ) @@ -746,18 +740,18 @@ def check_shape(path, t, dynamic_shape): _tree_map_with_path(check_shape, combined_args, dynamic_shapes, tree_name="inputs") - # raise user warning if both Dim.AUTO & Dims are specified in dynamic_shapes + # raise user warning if both DIM.AUTO & Dims are specified in dynamic_shapes flat_dynamic_shapes = _flatten_dynamic_shapes(combined_args, dynamic_shapes) flatter_dynamic_shapes, _ = tree_flatten(flat_dynamic_shapes) if any(isinstance(s, _Dim) for s in flatter_dynamic_shapes) and any( - s == _DimHint.AUTO for s in flatter_dynamic_shapes + s == DIM.AUTO for s in flatter_dynamic_shapes ): raise UserError( UserErrorType.INVALID_INPUT, - "Specifying both `Dim.AUTO` and `Dim` or `DerivedDim` in `dynamic_shapes` is not well supported at the moment, " + "Specifying both `DIM.AUTO` and `Dim` or `DerivedDim` in `dynamic_shapes` is not well supported at the moment, " "and can easily lead to constraint violation errors or obscure errors in torch.export. Dim/DerivedDims " - "expect all equal or related dimensions to be specified, and does not yet compose well with `Dim.AUTO`. " - "We suggest using `Dim.AUTO` mixed with `None` for auto-dynamic + static shapes, plus torch._check(dim >= min), " + "expect all equal or related dimensions to be specified, and does not yet compose well with `DIM.AUTO`. " + "We suggest using `DIM.AUTO` mixed with `None` for auto-dynamic + static shapes, plus torch._check(dim >= min), " "torch._check(dim <= max) calls in your program to specify min/max ranges, or `Dim`/`DerivedDim` mixed with `None` " "if you want to assert on the exact specification of your program's dynamic shapes behavior.", case_name="dynamic_shapes_validation", @@ -779,8 +773,8 @@ def _transform_shapes_for_default_dynamic( for all dims governed by this symbol (i.e. relations, equality, linear relations, etc.) For export.export(), historically dynamism for unspecified dims has been undesirable, so the semantics are: - - Dim.AUTO: dynamic, allocated a symbol - - None/unspecified/Dim.STATIC: static + - DIM.AUTO: dynamic, allocated a symbol + - None/unspecified/DIM.STATIC: static - Dim/DerivedDims: also a strict assertion To allow both APIs to follow the same process for producing constraints, this function converts dynamic_shapes @@ -790,8 +784,8 @@ def _transform_shapes_for_default_dynamic( An example conversion might look like, for a 3-d input tensor: input spec: { - 0: Dim.AUTO, - 1: None, # or Dim.STATIC + 0: DIM.AUTO, + 1: None, # or DIM.STATIC 2: Dim("dx"), } output spec: { @@ -838,10 +832,10 @@ def _marked_dynamic(tensor, i): out = {} for i, val in enumerate(tensor.shape): dim = shape.get(i, None) - if _marked_dynamic(tensor, i) or dim == _DimHint.AUTO: + if _marked_dynamic(tensor, i) or dim == DIM.AUTO: # don't have to specify anything if dynamic # None also works, since assume_static_by_default=False - if dim == _DimHint.AUTO: + if dim == DIM.AUTO: torch._dynamo.maybe_mark_dynamic(tensor, i) # avoid duck sizing continue elif isinstance(dim, _Dim): @@ -852,14 +846,14 @@ def _marked_dynamic(tensor, i): out[i] = dim else: # make explicitly static - assert dim is None or dim == _DimHint.STATIC + assert dim is None or dim == DIM.STATIC out[i] = val elif isinstance(shape, (tuple, list)): out = [] for i, val in enumerate(tensor.shape): dim = shape[i] - if _marked_dynamic(tensor, i) or dim == _DimHint.AUTO: - if dim == _DimHint.AUTO: + if _marked_dynamic(tensor, i) or dim == DIM.AUTO: + if dim == DIM.AUTO: torch._dynamo.maybe_mark_dynamic(tensor, i) # avoid duck sizing out.append(None) elif isinstance(dim, _Dim): @@ -867,7 +861,7 @@ def _marked_dynamic(tensor, i): elif isinstance(dim, int): out.append(dim) else: - assert dim is None or dim == _DimHint.STATIC + assert dim is None or dim == DIM.STATIC out.append(val) out = type(shape)(out) # type: ignore[assignment] else: @@ -1046,7 +1040,7 @@ def _get_dim_name_mapping( dynamic_shapes, is_leaf=lambda x: isinstance(x, _Dim), )[0]: - if isinstance(dim, (int, _DimHint)) or dim is None: + if isinstance(dim, (int, DIM)) or dim is None: continue name_to_dim[dim.__name__] = dim if isinstance(dim, _DerivedDim): @@ -1128,11 +1122,9 @@ def refine_dynamic_shapes_from_suggested_fixes( name, expr = fix.split(" = ") expr = sympy.sympify(expr) if isinstance(expr, sympy.Number): - # static, integer - shape_fixes[name] = int(expr) # type: ignore[assignment] + shape_fixes[name] = int(expr) # static, integer else: - # relation or derived dim - shape_fixes[name] = expr + shape_fixes[name] = expr # relation or derived dim name_to_dim = _get_dim_name_mapping(dynamic_shapes) From 2086101ac15f829daa3fafdba8964bdfac1c9283 Mon Sep 17 00:00:00 2001 From: Chien-Chin Huang Date: Mon, 26 Aug 2024 13:14:09 -0700 Subject: [PATCH 0154/1018] [DCP] Fixes the BC issue where the traversal doesn't support versions before 2.4 (#134158) The original DCP doesn't flattening all the containers, which can cause issues, https://github.com/pytorch/pytorch/pull/125335 intends to solve the issue by flattening all the dictionaries. Unfortunately, it breaks the checkpoints that are saved before 2.4. This also shows some issues of the DCP: 1. DCP should record version in the metadata. 2. DCP should have a nice way to load old state_dict. 3. DCP should unflatten all containers (map, list) not just map. This PR only addresses issue 2 to unblock users. Issue 1 and issue 3 need to be addressed in the future. @pradeepfn Please let me know if this summary matches our discussion. Fixes https://github.com/pytorch/pytorch/issues/133923 Pull Request resolved: https://github.com/pytorch/pytorch/pull/134158 Approved by: https://github.com/wz337, https://github.com/pradeepfn --- .../checkpoint/test_compatibility.py | 25 +++++++++ torch/distributed/checkpoint/_nested_dict.py | 20 ++++++- torch/distributed/checkpoint/_traverse.py | 52 ++++++++++++++++--- torch/distributed/checkpoint/_version.py | 6 +++ .../distributed/checkpoint/default_planner.py | 35 +++++++++++++ 5 files changed, 130 insertions(+), 8 deletions(-) create mode 100644 torch/distributed/checkpoint/_version.py diff --git a/test/distributed/checkpoint/test_compatibility.py b/test/distributed/checkpoint/test_compatibility.py index e64e6228986db2..c7bbe1a006c552 100644 --- a/test/distributed/checkpoint/test_compatibility.py +++ b/test/distributed/checkpoint/test_compatibility.py @@ -70,6 +70,31 @@ def test_storage_meta(self) -> None: self.assertEqual(storage_meta.save_id, writer.save_id) self.assertEqual(storage_meta.load_id, reader.load_id) + @with_temp_dir + def test_with_v_2_3(self) -> None: + sd = { + "a": torch.zeros(4, 4), + "dict": { + "dict_a": {"dict_a_1": 1, "dict_a_2": 2}, + "dict_b": {"dict_b_1": 1, "dict_b_2": 2}, + }, + "list": [0, 1, 2, 3, 4, 5], + } + load_sd = { + "a": torch.ones(4, 4), + "dict": { + "dict_a": {"dict_a_1": 2, "dict_a_2": 4}, + "dict_b": {"dict_b_1": 2, "dict_b_2": 4}, + }, + "list": [10, 11, 12, 13, 14, 15], + } + + dcp._version._act_like_version = "2_3" + dcp.save(sd, checkpoint_id=self.temp_dir) + dcp._version._act_like_version = None + dcp.load(load_sd, checkpoint_id=self.temp_dir) + self.assertEqual(sd, load_sd) + if __name__ == "__main__": run_tests() diff --git a/torch/distributed/checkpoint/_nested_dict.py b/torch/distributed/checkpoint/_nested_dict.py index 3347ea8bc432ae..b846a1d47d8434 100644 --- a/torch/distributed/checkpoint/_nested_dict.py +++ b/torch/distributed/checkpoint/_nested_dict.py @@ -3,7 +3,14 @@ from torch.distributed.checkpoint.metadata import STATE_DICT_TYPE -from ._traverse import OBJ_PATH, set_element, STATE_DICT_ITEM, traverse_state_dict +from . import _version +from ._traverse import ( + OBJ_PATH, + set_element, + STATE_DICT_ITEM, + traverse_state_dict, + traverse_state_dict_v_2_3, +) """ @@ -40,7 +47,16 @@ def flat_copy(path: OBJ_PATH, value: STATE_DICT_ITEM) -> None: flattened[new_fqn] = value mappings[new_fqn] = path - traverse_state_dict(state_dict, flat_copy) + # We started to flatten dictionary since v2.4. But in order to not break + # the checkpoints that were saved before v2.4, we need to keep the old + # traversal so that we can reconstruct those checkpoints. + use_v_2_3 = ( + _version._derived_version is not None and _version._derived_version == "2_3" + ) + if use_v_2_3: + traverse_state_dict_v_2_3(state_dict, flat_copy) + else: + traverse_state_dict(state_dict, flat_copy) return flattened, mappings diff --git a/torch/distributed/checkpoint/_traverse.py b/torch/distributed/checkpoint/_traverse.py index 8bcb832c71980f..db7957e814604c 100644 --- a/torch/distributed/checkpoint/_traverse.py +++ b/torch/distributed/checkpoint/_traverse.py @@ -40,14 +40,11 @@ def traverse_state_dict( ) -> None: """ Invoke ``visitor`` for each value recursively in ``state_dict``. - Mapping, list, and tuple will be flattened and other value types are treated - as the terminal values and will invoke ``visitor``. - Mapping is treated as non terminal node and will be flattened. - List and tuple, on the other hand, will not be flattened unless containing other - mapping containers or tensors. + Mapping will be traversed and ``visitor`` will be applied to the leaf elements. + ``visitor`` will only be applied to elements in a list or a tuple, if the + container contains tensors or mappings. """ - # a value is terminal if it has no other containers values inside it def _is_terminal(value: STATE_DICT_ITEM) -> bool: values: Collection[STATE_DICT_ITEM] if isinstance(value, Mapping): @@ -78,6 +75,49 @@ def _traverse_obj(path: OBJ_PATH, value: STATE_DICT_ITEM) -> None: _traverse_obj((str(key),), value) +def traverse_state_dict_v_2_3( + state_dict: STATE_DICT_TYPE, + visitor: Callable[[OBJ_PATH, STATE_DICT_ITEM], None], + keep_traversing: Callable[[STATE_DICT_ITEM], bool] = _keep_visiting_tensors, +) -> None: + """ + Traversal is short-circuited when if finds a collection for which ``keep_visiting_tensors`` evaluates + to false for all elements. + By default, all collections with at least one ``torch.Tensor`` element are traversed. + Visitor takes a path argument that is a tuple of the keys used to reach it. + """ + + # a value is terminal if it has no other containers values inside it + def _is_terminal(value: STATE_DICT_ITEM) -> bool: + values: Collection[STATE_DICT_ITEM] + if isinstance(value, Mapping): + values = value.values() + elif isinstance(value, list): + values = value + else: + return True + + for entry in values: + if isinstance(entry, (Mapping, list)) and not _is_terminal(entry): + return False + if keep_traversing is not None and keep_traversing(entry): + return False + return True + + def _traverse_obj(path: OBJ_PATH, value: STATE_DICT_ITEM) -> None: + if _is_terminal(value): + visitor(path, value) + elif isinstance(value, Mapping): + for k, v in value.items(): + _traverse_obj(path + (str(k),), v) + elif isinstance(value, list): + for i, v in enumerate(value): + _traverse_obj(path + (i,), v) + + for key, value in state_dict.items(): + _traverse_obj((str(key),), value) + + def set_element( root_dict: STATE_DICT_TYPE, path: OBJ_PATH, value: STATE_DICT_ITEM ) -> None: diff --git a/torch/distributed/checkpoint/_version.py b/torch/distributed/checkpoint/_version.py new file mode 100644 index 00000000000000..b3065bdfd6a2c1 --- /dev/null +++ b/torch/distributed/checkpoint/_version.py @@ -0,0 +1,6 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates + +from typing import Optional + + +_derived_version: Optional[str] = None diff --git a/torch/distributed/checkpoint/default_planner.py b/torch/distributed/checkpoint/default_planner.py index 462808217ede63..b0325b195c82d4 100644 --- a/torch/distributed/checkpoint/default_planner.py +++ b/torch/distributed/checkpoint/default_planner.py @@ -46,6 +46,8 @@ ) from torch.distributed.checkpoint.utils import find_state_dict_object +from . import _version + logger: logging.Logger = logging.getLogger(__name__) @@ -195,6 +197,39 @@ def set_up_planner( def create_local_plan(self) -> LoadPlan: assert self.metadata is not None + if self.flatten_state_dict: + # To support checkpoints that are saved before v2.4, we have to + # differentiate if the missing keys are due to old checkpoints. + # The contracts are: + # 1. There are 3 cases when we found a missing key. + # 1.1 Actual missing key, but allow_partial_load is False + # 1.2 Actual missing key, but allow_partial load is True + # 1.3 Old checkpoint, but allow_partial_load is False + # 1.4 Old checkpoint, but allow_partial_load is True + # 2. If we found a missing key, we first convert the keys back to + # the key format of v2.3 + # 3. If the previous missing keys are in the v2.3 keys, we assume + # this is a old checkpoint. + # 4. Pass the state_dict to `create_default_local_load_plan()`, + # which has the logic to check missing for allow_partial_load. + # So for 1.2 and 1.4 cases, we delegate allow_partial_load check to + # `create_default_local_load_plan()`. The logic here is to determine + # whether the checkpoint belong to 2.3 (or before) or 2.4 (or after). + current_keys = set(self.state_dict.keys()) + load_keys = set(self.metadata.state_dict_metadata.keys()) + missing_keys = load_keys - current_keys + if missing_keys: + _version._derived_version = "2_3" + old_state_dict, old_mappings = flatten_state_dict( + self.original_state_dict + ) + old_keys = set(old_state_dict.keys()) + if old_keys & missing_keys: + self.state_dict, self.mappings = old_state_dict, old_mappings + # _derived_version is only used by flatten_state_dict now. + # Set it back to None so that later we can save to a new version. + _version._derived_version = None + return create_default_local_load_plan( self.state_dict, self.metadata, not self.allow_partial_load ) From 8a4fb3746f0399e9d5e2ed1462fd8d3ad043b976 Mon Sep 17 00:00:00 2001 From: Chien-Lin Chen Date: Wed, 28 Aug 2024 16:34:03 +0000 Subject: [PATCH 0155/1018] parameterized test_graph_optims and test_graph_scaling_fused_optimizers (#133749) Fixes #123451 This is a rework of a reverted pull request, https://github.com/pytorch/pytorch/pull/125127. The test failure is fixed. Pull Request resolved: https://github.com/pytorch/pytorch/pull/133749 Approved by: https://github.com/janeyx99 --- test/inductor/test_compiled_optimizers.py | 2 + test/test_cuda.py | 382 +++++++++---------- torch/testing/_internal/common_optimizers.py | 63 ++- 3 files changed, 230 insertions(+), 217 deletions(-) diff --git a/test/inductor/test_compiled_optimizers.py b/test/inductor/test_compiled_optimizers.py index a1a38a95b9aca4..0fcba0d0db56e0 100644 --- a/test/inductor/test_compiled_optimizers.py +++ b/test/inductor/test_compiled_optimizers.py @@ -146,6 +146,8 @@ class KernelCounts(NamedTuple): "test_sgd_weight_decay_maximize_cuda": 4, "test_sgd_weight_decay_maximize_xpu": 4, "test_sgd_weight_decay_maximize_cpu": 4, + "test_sgd_weight_decay_cpu": 4, + "test_sgd_weight_decay_cuda": 4, "test_sgd_momentum_weight_decay_foreach_cuda": 2, "test_sgd_momentum_weight_decay_foreach_xpu": 2, "test_sgd_momentum_nesterov_weight_decay_foreach_cuda": 2, diff --git a/test/test_cuda.py b/test/test_cuda.py index e5e25a678e2352..bb52e8c9360ffc 100644 --- a/test/test_cuda.py +++ b/test/test_cuda.py @@ -40,7 +40,12 @@ onlyCUDA, onlyNativeDeviceTypes, ) -from torch.testing._internal.common_optimizers import optim_db, optims, TensorTracker +from torch.testing._internal.common_optimizers import ( + _get_optim_inputs_including_global_cliquey_kwargs, + optim_db, + optims, + TensorTracker, +) from torch.testing._internal.common_utils import ( EXPANDABLE_SEGMENTS, freeze_rng_state, @@ -3454,93 +3459,6 @@ def _test_graphed_optimizer( for p_control, p_graphed in zip(params_control, params_graphed): self.assertEqual(p_control, p_graphed) - @unittest.skipIf( - not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs" - ) - def test_graph_optims(self): - # Needs generalization if we want to extend this test to non-Adam-like optimizers. - cases = ( - [ - ( - optimizer_ctor, - { - "lr": 0.1, - "betas": (0.8, 0.7), - "foreach": foreach, - "decoupled_weight_decay": decoupled_weight_decay, - "weight_decay": weight_decay, - }, - ) - for optimizer_ctor, foreach, decoupled_weight_decay, weight_decay in product( - (torch.optim.NAdam, torch.optim.RAdam), - (False, True), - (False, True), - (0.0, 0.1), - ) - ] - + [ - ( - torch.optim.Rprop, - {"lr": 0.1, "foreach": foreach, "maximize": maximize}, - ) - for foreach, maximize in product( - (False, True), - (False, True), - ) - ] - + [ - ( - optimizer_ctor, - { - "lr": 0.1, - "betas": (0.8, 0.7), - "foreach": foreach, - "amsgrad": amsgrad, - }, - ) - for optimizer_ctor, foreach, amsgrad in product( - (torch.optim.Adam, torch.optim.AdamW), - (False, True), - (False, True), - ) - ] - + [ - ( - optimizer_ctor, - {"lr": 0.1, "betas": (0.8, 0.7), "fused": True, "amsgrad": amsgrad}, - ) - for optimizer_ctor, amsgrad in product( - (torch.optim.Adam, torch.optim.AdamW), (False, True) - ) - ] - + [ - ( - optimizer_ctor, - { - "lr": 0.1, - "foreach": foreach, - "maximize": maximize, - "weight_decay": weight_decay, - }, - ) - for optimizer_ctor, foreach, maximize, weight_decay in product( - ( - torch.optim.Adamax, - torch.optim.ASGD, - torch.optim.Adadelta, - torch.optim.RMSprop, - ), - (False, True), - (False, True), - (0, 0.1), - ) - ] - ) - - for optimizer_ctor, kwargs in cases: - with self.subTest(optimizer_ctor=optimizer_ctor, kwargs=kwargs): - self._test_graphed_optimizer(3, 2, optimizer_ctor, kwargs) - @unittest.skipIf( not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs" ) @@ -3612,125 +3530,6 @@ def test_graph_optims_with_explicitly_capturable_param_groups(self): self.assertEqual(ref_p1, param1) self.assertEqual(ref_p2, param2) - @unittest.skipIf( - not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs" - ) - def test_graph_scaling_fused_optimizers(self): - cases = [ - ( - optimizer_ctor, - {"lr": 0.1, "betas": (0.8, 0.7), "fused": True, "amsgrad": amsgrad}, - ) - for optimizer_ctor, amsgrad in product( - (torch.optim.Adam, torch.optim.AdamW), (False, True) - ) - ] + list( - product( - (torch.optim.SGD,), - [ - { - "lr": 0.1, - "momentum": 0.0, - "dampening": d, - "weight_decay": w, - "nesterov": n, - "fused": True, - } - for d, w, n in product((0.0, 0.5), (0.0, 0.5), (False,)) - ] - + [ - { - "lr": 0.1, - "momentum": 0.5, - "dampening": d, - "weight_decay": w, - "nesterov": n, - "fused": True, - } - for d, w, n in product((0.0,), (0.0, 0.5), (True, False)) - ], - ) - ) - - steps_warmup = 3 - steps_train = 2 - - for OptClass, kwargs in cases: - has_capturable_arg = OptClass in (torch.optim.Adam, torch.optim.AdamW) - for actually_do_graphs in (True, False) if has_capturable_arg else (True,): - params = [torch.randn((i + 5, i + 5), device="cuda") for i in range(2)] - params_control = [p.clone().requires_grad_() for p in params] - params_graphed = [p.clone().requires_grad_() for p in params] - - # `GradScaler` in-place updates gradients thus it's necessary to duplicate gradients. - grads = [ - [torch.randn_like(p) for p in params] - for _ in range(steps_warmup + steps_train) - ] - with torch.no_grad(): - grads_control = [[g.clone() for g in gs] for gs in grads] - grads_graphed = [[g.clone() for g in gs] for gs in grads] - - # Gradient Scaler - scaler_for_control = torch.amp.GradScaler( - device="cuda", init_scale=128.0 - ) - with torch.no_grad(): - scaler_for_control._lazy_init_scale_growth_tracker( - torch.device("cuda") - ) - - scaler_for_graphed = torch.amp.GradScaler(device="cuda") - scaler_for_graphed.load_state_dict(scaler_for_control.state_dict()) - with torch.no_grad(): - scaler_for_graphed._lazy_init_scale_growth_tracker( - torch.device("cuda") - ) - - # Control (capturable=False) - if has_capturable_arg: - kwargs["capturable"] = False - opt = OptClass(params_control, **kwargs) - - for i in range(steps_warmup + steps_train): - for j, p in enumerate(params_control): - p.grad = grads_control[i][j] - scaler_for_control.step(opt) - scaler_for_control.update() - - # capturable=True - if has_capturable_arg: - kwargs["capturable"] = True - opt = OptClass(params_graphed, **kwargs) - - for i in range(steps_warmup): - for j, p in enumerate(params_graphed): - p.grad = grads_graphed[i][j] - scaler_for_graphed.step(opt) - scaler_for_graphed.update() - - if actually_do_graphs: - g = torch.cuda.CUDAGraph() - with torch.cuda.graph(g): - scaler_for_graphed.step(opt) - scaler_for_graphed.update() - - for i in range(steps_train): - if actually_do_graphs: - for j, p in enumerate(params_graphed): - p.grad.copy_(grads_graphed[i + steps_warmup][j]) - g.replay() - else: - # Passing capturable=True to the constructor and running without graphs should still be - # numerically correct, even if it's not ideal for performance. - for j, p in enumerate(params_graphed): - p.grad = grads_graphed[i + steps_warmup][j] - scaler_for_graphed.step(opt) - scaler_for_graphed.update() - - for p_control, p_graphed in zip(params_control, params_graphed): - self.assertEqual(p_control, p_graphed) - @unittest.skipIf( not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs" ) @@ -5082,10 +4881,179 @@ def create_mempool_and_make_active(): self.assertEqual(len(set(active_pool_ids)), 4) +@torch.testing._internal.common_utils.markDynamoStrictTest class TestCudaOptims(TestCase): # These tests will be instantiate with instantiate_device_type_tests # to apply the new OptimizerInfo structure. + @onlyCUDA + @unittest.skipIf( + not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >=5.3 required for graphs" + ) + @optims( + [optim for optim in optim_db if optim.has_capturable_arg], + dtypes=[torch.float32], + ) + def test_graph_optims(self, device, dtype, optim_info): + optim_cls = optim_info.optim_cls + all_optim_inputs = _get_optim_inputs_including_global_cliquey_kwargs( + device, dtype, optim_info, skip=("differentiable",) + ) + + steps_warmup = 3 + steps_train = 2 + + for optim_input in all_optim_inputs: + kwargs = optim_input.kwargs + + # lr as a Tensor is not supported when capturable=False and foreach=True for torch.optim.adam + # and torch.optim.adamw + kwargs["lr"] = 0.1 + + for actually_do_graphs in (True, False): + params = [ + torch.randn((i + 5, i + 5), device=device) for i in range(2) + ] + [torch.randn((), device=device)] + params_control = [p.clone().requires_grad_() for p in params] + params_graphed = [p.clone().requires_grad_() for p in params] + + grads = [ + [torch.randn_like(p) for p in params] + for _ in range(steps_warmup + steps_train) + ] + + # Control (capturable=False) + kwargs["capturable"] = False + + opt = optim_cls(params_control, **kwargs) + for i in range(steps_warmup + steps_train): + for j, p in enumerate(params_control): + p.grad = grads[i][j] + opt.step() + + # capturable=True + kwargs["capturable"] = True + opt = optim_cls(params_graphed, **kwargs) + + for i in range(steps_warmup): + for j, p in enumerate(params_graphed): + p.grad = grads[i][j] + opt.step() + + if actually_do_graphs: + g = torch.cuda.CUDAGraph() + with torch.cuda.graph(g): + opt.step() + + for i in range(steps_train): + if actually_do_graphs: + for j, p in enumerate(params_graphed): + p.grad.copy_(grads[i + steps_warmup][j]) + g.replay() + else: + # Passing capturable=True to the constructor and running without graphs should still be + # numerically correct, even if it's not ideal for performance. + for j, p in enumerate(params_graphed): + p.grad = grads[i + steps_warmup][j] + opt.step() + + for p_control, p_graphed in zip(params_control, params_graphed): + self.assertEqual(p_control, p_graphed) + + @onlyCUDA + @unittest.skipIf( + not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs" + ) + @optims( + [ + optim + for optim in optim_db + if "fused" in optim.supported_impls and "cuda" in optim.supports_fused_on + ], + dtypes=[torch.float32], + ) + def test_graph_scaling_fused_optimizers(self, device, dtype, optim_info): + optim_cls = optim_info.optim_cls + + steps_warmup = 3 + steps_train = 2 + + optim_inputs = optim_info.optim_inputs_func(device=device) + + for optim_input in optim_inputs: + kwargs = optim_input.kwargs + kwargs["fused"] = True + + for actually_do_graphs in ( + (True, False) if optim_info.has_capturable_arg else (True,) + ): + params = [torch.randn((i + 5, i + 5), device=device) for i in range(2)] + params_control = [p.clone().requires_grad_() for p in params] + params_graphed = [p.clone().requires_grad_() for p in params] + + # `GradScaler` in-place updates gradients thus it's necessary to duplicate gradients. + grads = [ + [torch.randn_like(p) for p in params] + for _ in range(steps_warmup + steps_train) + ] + with torch.no_grad(): + grads_control = [[g.clone() for g in gs] for gs in grads] + grads_graphed = [[g.clone() for g in gs] for gs in grads] + + # Gradient Scaler + scaler_for_control = torch.cuda.amp.GradScaler(init_scale=128.0) + with torch.no_grad(): + scaler_for_control._lazy_init_scale_growth_tracker(device) + + scaler_for_graphed = torch.cuda.amp.GradScaler() + scaler_for_graphed.load_state_dict(scaler_for_control.state_dict()) + with torch.no_grad(): + scaler_for_graphed._lazy_init_scale_growth_tracker(device) + + # Control (capturable=False) + if optim_info.has_capturable_arg: + kwargs["capturable"] = False + opt = optim_cls(params_control, **kwargs) + + for i in range(steps_warmup + steps_train): + for j, p in enumerate(params_control): + p.grad = grads_control[i][j] + scaler_for_control.step(opt) + scaler_for_control.update() + + # capturable=True + if optim_info.has_capturable_arg: + kwargs["capturable"] = True + opt = optim_cls(params_graphed, **kwargs) + + for i in range(steps_warmup): + for j, p in enumerate(params_graphed): + p.grad = grads_graphed[i][j] + scaler_for_graphed.step(opt) + scaler_for_graphed.update() + + if actually_do_graphs: + g = torch.cuda.CUDAGraph() + with torch.cuda.graph(g): + scaler_for_graphed.step(opt) + scaler_for_graphed.update() + + for i in range(steps_train): + if actually_do_graphs: + for j, p in enumerate(params_graphed): + p.grad.copy_(grads_graphed[i + steps_warmup][j]) + g.replay() + else: + # Passing capturable=True to the constructor and running without graphs should still be + # numerically correct, even if it's not ideal for performance. + for j, p in enumerate(params_graphed): + p.grad = grads_graphed[i + steps_warmup][j] + scaler_for_graphed.step(opt) + scaler_for_graphed.update() + + for p_control, p_graphed in zip(params_control, params_graphed): + self.assertEqual(p_control, p_graphed) + @onlyNativeDeviceTypes @optims( [optim for optim in optim_db if "fused" in optim.supported_impls], diff --git a/torch/testing/_internal/common_optimizers.py b/torch/testing/_internal/common_optimizers.py index 7d4d225c8e4012..bd8f1f2963f525 100644 --- a/torch/testing/_internal/common_optimizers.py +++ b/torch/testing/_internal/common_optimizers.py @@ -132,6 +132,8 @@ def __init__( ), # the optim supports passing in sparse gradients as well as dense grads supports_sparse: bool = False, + # the optimizer constructor supports passing in capturable as a kwarg + has_capturable_arg: bool = False, # the optim only supports one config: sparse grads w/ dense params, see SparseAdam only_supports_sparse_grads: bool = False, # Tuple of (optimizer kwargs, schedulers_constructors) specifically for sparse tests, @@ -157,6 +159,7 @@ def __init__( self.supported_impls = supported_impls self.not_og_supported_flags = not_og_supported_flags self.supports_sparse = supports_sparse + self.has_capturable_arg = has_capturable_arg self.metadata_for_sparse = metadata_for_sparse self.only_supports_sparse_grads = only_supports_sparse_grads self.supports_complex = supports_complex @@ -330,10 +333,11 @@ def optim_inputs_func_adadelta(device, dtype=None): OptimizerInput( params=None, kwargs={"weight_decay": 0.1}, desc="nonzero weight_decay" ), + OptimizerInput(params=None, kwargs={"maximize": True}, desc="maximize"), OptimizerInput( params=None, kwargs={"weight_decay": 0.1, "maximize": True}, - desc="maximize", + desc="maximize, weight_decay", ), OptimizerInput( params=None, kwargs={"rho": 0.95, "weight_decay": 0.9}, desc="rho" @@ -631,9 +635,14 @@ def optim_inputs_func_adamax(device, dtype=None): ), OptimizerInput( params=None, - kwargs={"weight_decay": 0.1, "maximize": True}, + kwargs={"maximize": True}, desc="maximize", ), + OptimizerInput( + params=None, + kwargs={"weight_decay": 0.1, "maximize": True}, + desc="maximize, weight_decay", + ), ] + (cuda_supported_configs if _get_device_type(device) == "cuda" else []) @@ -788,9 +797,16 @@ def optim_inputs_func_nadam(device, dtype=None): ), OptimizerInput( params=None, - kwargs={"weight_decay": 0.1, "momentum_decay": 6e-3}, + kwargs={ + "weight_decay": 0.1, + }, desc="weight_decay", ), + OptimizerInput( + params=None, + kwargs={"weight_decay": 0.1, "momentum_decay": 6e-3}, + desc="weight_decay, momentum_decay", + ), OptimizerInput( params=None, kwargs={ @@ -933,11 +949,26 @@ def optim_inputs_func_rmsprop(device, dtype=None): OptimizerInput( params=None, kwargs={"weight_decay": 0.1}, desc="nonzero weight_decay" ), + OptimizerInput( + params=None, + kwargs={ + "maximize": True, + }, + desc="maximize", + ), OptimizerInput( params=None, kwargs={"weight_decay": 0.1, "centered": True}, desc="centered", ), + OptimizerInput( + params=None, + kwargs={ + "maximize": True, + "weight_decay": 0.1, + }, + desc="maximize, weight_decay", + ), OptimizerInput( params=None, kwargs={"weight_decay": 0.1, "centered": True, "momentum": 0.1}, @@ -951,7 +982,7 @@ def optim_inputs_func_rmsprop(device, dtype=None): "momentum": 0.1, "maximize": True, }, - desc="maximize", + desc="maximize, centered, weight_decay, w/ momentum", ), ] + (cuda_supported_configs if _get_device_type(device) == "cuda" else []) @@ -1022,7 +1053,15 @@ def optim_inputs_func_sgd(device, dtype=None): OptimizerInput( params=None, kwargs={"lr": torch.tensor(0.001)}, desc="tensor lr" ), + OptimizerInput( + params=None, kwargs={"weight_decay": 0.5}, desc="non-zero weight_decay" + ), OptimizerInput(params=None, kwargs={"momentum": 0.9}, desc="momentum"), + OptimizerInput( + params=None, + kwargs={"weight_decay": 0.1, "maximize": True}, + desc="maximize", + ), OptimizerInput( params=None, kwargs={"momentum": 0.9, "dampening": 0.5}, @@ -1031,18 +1070,13 @@ def optim_inputs_func_sgd(device, dtype=None): OptimizerInput( params=None, kwargs={"momentum": 0.9, "weight_decay": 0.1}, - desc="non-zero weight_decay", + desc="weight_decay w/ momentum", ), OptimizerInput( params=None, kwargs={"momentum": 0.9, "nesterov": True, "weight_decay": 0.1}, desc="nesterov", ), - OptimizerInput( - params=None, - kwargs={"weight_decay": 0.1, "maximize": True}, - desc="maximize", - ), ] @@ -1208,6 +1242,7 @@ def _get_optim_inputs_including_global_cliquey_kwargs( optim_inputs_func=optim_inputs_func_adadelta, optim_error_inputs_func=optim_error_inputs_func_adadelta, supported_impls=("foreach", "differentiable"), + has_capturable_arg=True, skips=( DecorateInfo( skipIfTorchDynamo("Fails fix point assertion on 3.8, see #97811"), @@ -1493,6 +1528,7 @@ def _get_optim_inputs_including_global_cliquey_kwargs( ), optim_error_inputs_func=optim_error_inputs_func_adam, supported_impls=("foreach", "differentiable", "fused"), + has_capturable_arg=True, not_og_supported_flags=( "foreach", "differentiable", @@ -1578,6 +1614,7 @@ def _get_optim_inputs_including_global_cliquey_kwargs( optim_inputs_func=optim_inputs_func_adamax, optim_error_inputs_func=optim_error_inputs_func_adamax, supported_impls=("foreach", "differentiable"), + has_capturable_arg=True, skips=( DecorateInfo( skipIfMps, # addcdiv doesn't work for non-contiguous, see #118115 @@ -1630,6 +1667,7 @@ def _get_optim_inputs_including_global_cliquey_kwargs( "capturable", ), supports_fused_on=("cpu", "cuda", "mps"), + has_capturable_arg=True, decorators=( # Expected error between compiled forloop and fused optimizers DecorateInfo( @@ -1710,6 +1748,7 @@ def _get_optim_inputs_including_global_cliquey_kwargs( optim_inputs_func=optim_inputs_func_asgd, optim_error_inputs_func=optim_error_inputs_func_asgd, supported_impls=("foreach", "differentiable"), + has_capturable_arg=True, skips=( DecorateInfo( skipIfTorchDynamo("Fails fix point assertion on 3.8, see #97811"), @@ -1822,6 +1861,7 @@ def _get_optim_inputs_including_global_cliquey_kwargs( optim_inputs_func=optim_inputs_func_nadam, optim_error_inputs_func=optim_error_inputs_func_nadam, supported_impls=("foreach", "differentiable"), + has_capturable_arg=True, skips=( DecorateInfo( skipIfMps, # addcdiv doesn't work for non-contiguous, see #118115 @@ -1870,6 +1910,7 @@ def _get_optim_inputs_including_global_cliquey_kwargs( optim_inputs_func=optim_inputs_func_radam, optim_error_inputs_func=optim_error_inputs_func_radam, supported_impls=("foreach", "differentiable"), + has_capturable_arg=True, skips=( DecorateInfo( skipIfTorchDynamo("Fails fix point assertion on 3.8, see #97811"), @@ -1915,6 +1956,7 @@ def _get_optim_inputs_including_global_cliquey_kwargs( optim_inputs_func=optim_inputs_func_rmsprop, optim_error_inputs_func=optim_error_inputs_func_rmsprop, supported_impls=("foreach", "differentiable"), + has_capturable_arg=True, skips=( DecorateInfo( skipIfMps, # addcdiv doesn't work for non-contiguous, see #118115 @@ -1964,6 +2006,7 @@ def _get_optim_inputs_including_global_cliquey_kwargs( optim_inputs_func=optim_inputs_func_rprop, optim_error_inputs_func=optim_error_inputs_func_rprop, supported_impls=("foreach", "differentiable"), + has_capturable_arg=True, skips=( DecorateInfo( skipIfMps, # Rprop doesn't update for non-contiguous, see #118117 From 9761f64e6e4d46afc254d672909fcbfff64cf213 Mon Sep 17 00:00:00 2001 From: Banit Agrawal Date: Wed, 28 Aug 2024 17:12:01 +0000 Subject: [PATCH 0156/1018] [PyTorch Pin Memory Allocator] Optimize the free list implementation and add lock sharding (#134154) Summary: This diff addresses the lock contention issue in free list implementation of CachingHost/Pinned allocator. We add a different data structure for free list and also add lock sharding based on allocation size. Differential Revision: D61623367 Pull Request resolved: https://github.com/pytorch/pytorch/pull/134154 Approved by: https://github.com/guangyey, https://github.com/jgong5, https://github.com/zyan0, https://github.com/EikanWang, https://github.com/jiayisuse --- aten/src/ATen/core/CachingHostAllocator.h | 88 ++++++++++++----------- 1 file changed, 45 insertions(+), 43 deletions(-) diff --git a/aten/src/ATen/core/CachingHostAllocator.h b/aten/src/ATen/core/CachingHostAllocator.h index 5af7f5f564a7dc..2ca90c914722c1 100644 --- a/aten/src/ATen/core/CachingHostAllocator.h +++ b/aten/src/ATen/core/CachingHostAllocator.h @@ -30,20 +30,17 @@ struct HostBlock { ska::flat_hash_set streams_; // streams on which the block was used }; -/** - * ComparatorSize is used for lookup support in the set of host memory blocks - * using the block size. - */ template -struct ComparatorSize { - bool operator()(const B* a, const B* b) const { - if (a->size_ != b->size_) { - return a->size_ < b->size_; - } - return (uintptr_t)a->ptr_ < (uintptr_t)b->ptr_; - } +struct alignas(64) FreeBlockList { + std::mutex mutex_; + std::deque list_; }; +namespace { + // Max cached block sizes: (1 << MAX_SIZE_INDEX) bytes + constexpr size_t MAX_SIZE_INDEX = 64; +} + /** * Note [HostAllocator design] * ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -81,8 +78,7 @@ struct ComparatorSize { * to abstract the caching mechanism. Any backend needs to provide a customized * implementation by specializing its own public functions and the related * runtime functions. Its template parameter S represents runtime Stream, E - * denotes runtime Event, B indicates the fundamental memory block, and C - * signifies the sorting compartor algorithm for the memory blocks. + * denotes runtime Event, B indicates the fundamental memory block. * * For the interface, we provide a CachingHostAllocatorInterface struct as an * interface. Any backend needs to derive its own host allocator from this @@ -111,8 +107,7 @@ struct ComparatorSize { template < typename S, typename E, - typename B = HostBlock, - typename C = ComparatorSize> + typename B = HostBlock> struct CachingHostAllocatorImpl { virtual ~CachingHostAllocatorImpl() = default; @@ -132,6 +127,7 @@ struct CachingHostAllocatorImpl { } // Round up the allocation to the nearest power of two to improve reuse. + // These power of two sizes are also used to index into the free list. size_t roundSize = c10::llvm::PowerOf2Ceil(size); void* ptr = nullptr; allocate_host_memory(roundSize, &ptr); @@ -171,8 +167,9 @@ struct CachingHostAllocatorImpl { } if (!events) { - std::lock_guard g(free_list_mutex_); - free_list_.insert(block); + auto index = size_index(block->size_); + std::lock_guard g(free_list_[index].mutex_); + free_list_[index].list_.push_back(block); } else { // restore these events that record by used streams. std::lock_guard g(events_mutex_); @@ -218,22 +215,28 @@ struct CachingHostAllocatorImpl { // Remove all elements from the free list, remove them from the blocks // list, and free the associated pinned memory allocation. This requires - // concurrently holding both the free list mutex and the blocks mutex, and + // concurrently holding both the free list mutexes and the blocks mutex, and // is the only function that concurrently holds multiple mutexes. - std::lock(free_list_mutex_, blocks_mutex_); - std::lock_guard gf(free_list_mutex_, std::adopt_lock); - std::lock_guard gb(blocks_mutex_, std::adopt_lock); - - std::vector blocks_to_remove(free_list_.begin(), free_list_.end()); - free_list_.clear(); - for (auto* block : blocks_to_remove) { - blocks_.erase(block); - ptr_to_block_.erase(block->ptr_); - free_block(block); - delete block; + for (size_t i = 0; i < free_list_.size(); ++i) { + std::lock(free_list_[i].mutex_, blocks_mutex_); + std::lock_guard gf(free_list_[i].mutex_, std::adopt_lock); + std::lock_guard gb(blocks_mutex_, std::adopt_lock); + + std::vector blocks_to_remove(free_list_[i].list_.begin(), free_list_[i].list_.end()); + free_list_[i].list_.clear(); + for (auto* block : blocks_to_remove) { + blocks_.erase(block); + ptr_to_block_.erase(block->ptr_); + free_block(block); + delete block; + } } } + inline size_t size_index(size_t size) { + return c10::llvm::Log2_64_Ceil(size); + } + virtual void copy_data(void* dest [[maybe_unused]], const void* src [[maybe_unused]], std::size_t count [[maybe_unused]]) const { TORCH_CHECK_NOT_IMPLEMENTED(false, "Not implemented for copy_data"); } @@ -246,13 +249,12 @@ struct CachingHostAllocatorImpl { } virtual B* get_free_block(size_t size) { - std::lock_guard g(free_list_mutex_); - B key(size); - auto it = free_list_.lower_bound(&key); - if (it != free_list_.end()) { - B* block = *it; + auto index = size_index(size); + std::lock_guard g(free_list_[index].mutex_); + if (free_list_[index].list_.size() > 0) { + B* block = free_list_[index].list_.back(); + free_list_[index].list_.pop_back(); block->allocated_ = true; - free_list_.erase(it); return block; } return nullptr; @@ -304,8 +306,9 @@ struct CachingHostAllocatorImpl { } if (available) { - std::lock_guard g(free_list_mutex_); - free_list_.insert(block); + auto index = size_index(block->size_); + std::lock_guard g(free_list_[index].mutex_); + free_list_[index].list_.push_back(block); } } } @@ -337,12 +340,11 @@ struct CachingHostAllocatorImpl { ska::flat_hash_set blocks_; // block list ska::flat_hash_map ptr_to_block_; - // Note: sharding this mutex seems to be profitable in heavily multi-threaded - // scenarios. - alignas(64) std::mutex free_list_mutex_; - // Note: an alternative datastructure can yield significant wins here in - // microbenchmarks. - std::set free_list_; // free list + // We keep free list as a vector of free lists, one for each power of two + // size. This allows us to quickly find a free block of the right size. + // We use deque to store per size free list and guard the list with its own + // mutex. + alignas(64) std::vector> free_list_ = std::vector>(MAX_SIZE_INDEX); alignas(64) std::mutex events_mutex_; std::deque> events_; // event queue paired with block From 1e3bba17357399d8cfe511dc0114e38bc8635f2f Mon Sep 17 00:00:00 2001 From: Laith Sakka Date: Tue, 27 Aug 2024 21:35:35 -0700 Subject: [PATCH 0157/1018] add back sum_floordiv benchmark. (#134635) Pull Request resolved: https://github.com/pytorch/pytorch/pull/134635 Approved by: https://github.com/avikchaudhuri, https://github.com/oulgen ghstack dependencies: #133834 --- .../benchmarks/sum_floordiv.py | 39 +++++++++++++++++++ 1 file changed, 39 insertions(+) create mode 100644 benchmarks/dynamo/pr_time_benchmarks/benchmarks/sum_floordiv.py diff --git a/benchmarks/dynamo/pr_time_benchmarks/benchmarks/sum_floordiv.py b/benchmarks/dynamo/pr_time_benchmarks/benchmarks/sum_floordiv.py new file mode 100644 index 00000000000000..16d2f8ed4eec16 --- /dev/null +++ b/benchmarks/dynamo/pr_time_benchmarks/benchmarks/sum_floordiv.py @@ -0,0 +1,39 @@ +import sys + +from benchmark_base import BenchmarkBase + +import torch + + +class Benchmark(BenchmarkBase): + N = 100 + + def name(self): + return "sum_floordiv_regression" + + def description(self): + return "information at https://github.com/pytorch/pytorch/issues/134133" + + def _prepare_once(self): + class M(torch.nn.Module): + def forward(self, x): + total = sum(t.item() for t in x) + return total // 2 + + self.m = M() + self.input = [torch.tensor(i + 2) for i in range(self.N)] + + def _prepare(self): + torch._dynamo.reset() + + def _work(self): + torch.export.export(self.m, (self.input,)) + + +def main(): + result_path = sys.argv[1] + Benchmark().enable_instruction_count().collect_all().append_results(result_path) + + +if __name__ == "__main__": + main() From 668bfba64e9da7de8177077d58c70cb963bac5e6 Mon Sep 17 00:00:00 2001 From: Bin Bao Date: Wed, 28 Aug 2024 17:42:19 +0000 Subject: [PATCH 0158/1018] [AOTI] Fix test_aoti_inference CPU build issue (#134675) Summary: Fixes https://github.com/pytorch/pytorch/issues/130311. We need to guard CUDA-only code in test_aoti_inference with macros so that it won't fail for CPU-only platform. Pull Request resolved: https://github.com/pytorch/pytorch/pull/134675 Approved by: https://github.com/atalman, https://github.com/chunyuan-w --- .ci/pytorch/test.sh | 2 -- test/cpp/aoti_inference/aoti_custom_class.cpp | 10 +++--- test/cpp/aoti_inference/compile_model.py | 2 +- test/cpp/aoti_inference/test.cpp | 34 ++++++++++++------- test/cpp/aoti_inference/test.py | 5 ++- 5 files changed, 32 insertions(+), 21 deletions(-) diff --git a/.ci/pytorch/test.sh b/.ci/pytorch/test.sh index dd2ed255203da6..169cb8ce5f8943 100755 --- a/.ci/pytorch/test.sh +++ b/.ci/pytorch/test.sh @@ -1480,8 +1480,6 @@ elif [[ "${TEST_CONFIG}" == *inductor* ]]; then test_inductor_shard "${SHARD_NUMBER}" if [[ "${SHARD_NUMBER}" == 1 ]]; then if [[ "${BUILD_ENVIRONMENT}" != linux-jammy-py3.8-gcc11-build ]]; then - # Temporarily skip test_inductor_aoti due to https://github.com/pytorch/pytorch/issues/130311 - test_inductor_aoti test_inductor_distributed fi fi diff --git a/test/cpp/aoti_inference/aoti_custom_class.cpp b/test/cpp/aoti_inference/aoti_custom_class.cpp index 5f3ea58620dd22..d3960d9f5074eb 100644 --- a/test/cpp/aoti_inference/aoti_custom_class.cpp +++ b/test/cpp/aoti_inference/aoti_custom_class.cpp @@ -29,12 +29,14 @@ MyAOTIClass::MyAOTIClass( const std::string& model_path, const std::string& device) : lib_path_(model_path), device_(device) { - if (device_ == "cuda") { - runner_ = std::make_unique( - model_path.c_str()); - } else if (device_ == "cpu") { + if (device_ == "cpu") { runner_ = std::make_unique( model_path.c_str()); +#if defined(USE_CUDA) || defined(USE_ROCM) + } else if (device_ == "cuda") { + runner_ = std::make_unique( + model_path.c_str()); +#endif } else { throw std::runtime_error("invalid device: " + device); } diff --git a/test/cpp/aoti_inference/compile_model.py b/test/cpp/aoti_inference/compile_model.py index 4542a9c8f9f80a..0668f9b4b91250 100644 --- a/test/cpp/aoti_inference/compile_model.py +++ b/test/cpp/aoti_inference/compile_model.py @@ -86,7 +86,7 @@ def compile_model(device, data): def main(): data = {} - for device in ["cuda", "cpu"]: + for device in ["cpu", "cuda"] if torch.cuda.is_available() else ["cpu"]: compile_model(device, data) torch.jit.script(TensorSerializer(data)).save("script_data.pt") diff --git a/test/cpp/aoti_inference/test.cpp b/test/cpp/aoti_inference/test.cpp index a74109f1a9b9a9..8a97959172c97c 100644 --- a/test/cpp/aoti_inference/test.cpp +++ b/test/cpp/aoti_inference/test.cpp @@ -35,12 +35,14 @@ void test_aoti(const std::string& device, bool use_runtime_constant_folding) { data_loader.attr(outputs_attr.c_str()).toTensorList().vec(); std::unique_ptr runner; - if (device == "cuda") { - runner = std::make_unique( - model_so_path); - } else if (device == "cpu") { + if (device == "cpu") { runner = std::make_unique( model_so_path); +#if defined(USE_CUDA) || defined(USE_ROCM) + } else if (device == "cuda") { + runner = std::make_unique( + model_so_path); +#endif } else { testing::AssertionFailure() << "unsupported device: " << device; } @@ -111,12 +113,14 @@ void test_aoti_constants_update( real_map.emplace("L__self___w_add", new at::Tensor(add_tensors)); std::unique_ptr runner; - if (device == "cuda") { - runner = std::make_unique( - model_so_path); - } else if (device == "cpu") { + if (device == "cpu") { runner = std::make_unique( model_so_path); +#if defined(USE_CUDA) || defined(USE_ROCM) + } else if (device == "cuda") { + runner = std::make_unique( + model_so_path); +#endif } else { testing::AssertionFailure() << "unsupported device: " << device; } @@ -197,12 +201,14 @@ void test_aoti_double_buffering( real_map.emplace("L__self___w_add", new at::Tensor(add_tensors)); std::unique_ptr runner; - if (device == "cuda") { - runner = std::make_unique( - model_so_path.c_str()); - } else if (device == "cpu") { + if (device == "cpu") { runner = std::make_unique( - model_so_path.c_str()); + model_so_path); +#if defined(USE_CUDA) || defined(USE_ROCM) + } else if (device == "cuda") { + runner = std::make_unique( + model_so_path); +#endif } else { testing::AssertionFailure() << "unsupported device: " << device; } @@ -241,6 +247,7 @@ void test_aoti_double_buffering( ASSERT_TRUE(torch::allclose(ref_output_tensors[0], actual_output_tensors[0])); } +#if defined(USE_CUDA) || defined(USE_ROCM) void test_aoti_double_buffering_with_tensor_constants() { torch::NoGradGuard no_grad; @@ -279,6 +286,7 @@ void test_aoti_double_buffering_with_tensor_constants() { actual_output_tensors = runner->run(input_tensors); ASSERT_TRUE(torch::allclose(ref_output_tensors[0], actual_output_tensors[0])); } +#endif } // namespace diff --git a/test/cpp/aoti_inference/test.py b/test/cpp/aoti_inference/test.py index ea3f6f042d3c17..f5e730158ccc00 100644 --- a/test/cpp/aoti_inference/test.py +++ b/test/cpp/aoti_inference/test.py @@ -35,7 +35,7 @@ def forward(self, x, y): # Basice AOTI model test generation. def generate_basic_tests(): - for device in ["cpu", "cuda"]: + for device in ["cpu", "cuda"] if torch.cuda.is_available() else ["cpu"]: for use_runtime_constant_folding in [True, False]: if device == "cpu" and use_runtime_constant_folding: # We do not test runtime const folding for cpu mode. @@ -74,6 +74,9 @@ def generate_basic_tests(): # AOTI model which will create additional tensors during autograd. def generate_test_with_additional_tensors(): + if not torch.cuda.is_available(): + return + model = NetWithTensorConstants() x = torch.randn((30, 1), device="cuda") y = torch.randn((30, 1), device="cuda") From 461a195db0bded5e3ef925fff66fd2655bbb9f61 Mon Sep 17 00:00:00 2001 From: Xilun Wu <12968408+XilunWu@users.noreply.github.com> Date: Mon, 26 Aug 2024 21:59:11 +0000 Subject: [PATCH 0159/1018] [reland][dtensor][MTPG] make sharding prop lru cache not shared among threads (#134509) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit **Summary** reland of https://github.com/pytorch/pytorch/pull/134294 Fixes #131446 Fixes #126852 Fixes #126868 Fixes #126493 The PR was reverted due to CI red signal in https://github.com/pytorch/pytorch/actions/runs/10537099590/job/29201744658. It seems that the `gaussian_nll_loss` test had been flaky before my original PR #134294 . Therefore this PR also removes the `xfail` mark on this specific test to make CI signal green. See the error message below: ``` 2024-08-24T13:42:01.3228990Z ==================================== RERUNS ==================================== 2024-08-24T13:42:01.3229530Z _ TestDTensorOpsCPU.test_dtensor_op_db_nn_functional_gaussian_nll_loss_cpu_float32 _ 2024-08-24T13:42:01.3229710Z Unexpected success 2024-08-24T13:42:01.3230235Z _ TestDTensorOpsCPU.test_dtensor_op_db_nn_functional_gaussian_nll_loss_cpu_float32 _ 2024-08-24T13:42:01.3230407Z Unexpected success 2024-08-24T13:42:01.3230594Z =================================== FAILURES =================================== 2024-08-24T13:42:01.3231128Z _ TestDTensorOpsCPU.test_dtensor_op_db_nn_functional_gaussian_nll_loss_cpu_float32 _ 2024-08-24T13:42:01.3231296Z Unexpected success ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/134509 Approved by: https://github.com/tianyu-l, https://github.com/wz337 --- test/distributed/_tensor/test_dtensor_ops.py | 2 -- torch/distributed/_tensor/_sharding_prop.py | 20 ++++++++++++++++++-- 2 files changed, 18 insertions(+), 4 deletions(-) diff --git a/test/distributed/_tensor/test_dtensor_ops.py b/test/distributed/_tensor/test_dtensor_ops.py index d13ec1724983d6..480207e17d29ca 100644 --- a/test/distributed/_tensor/test_dtensor_ops.py +++ b/test/distributed/_tensor/test_dtensor_ops.py @@ -307,7 +307,6 @@ def wrapped(fn): xfail("nn.functional.elu"), xfail("nn.functional.fractional_max_pool2d"), xfail("nn.functional.fractional_max_pool3d"), - xfail("nn.functional.gaussian_nll_loss"), xfail("nn.functional.glu"), xfail("nn.functional.grid_sample"), xfail("nn.functional.group_norm"), @@ -353,7 +352,6 @@ def wrapped(fn): xfail("nn.functional.pdist"), xfail("nn.functional.pixel_shuffle"), xfail("nn.functional.pixel_unshuffle"), - xfail("nn.functional.poisson_nll_loss"), xfail("nn.functional.prelu"), xfail("nn.functional.relu6"), xfail("nn.functional.rrelu"), diff --git a/torch/distributed/_tensor/_sharding_prop.py b/torch/distributed/_tensor/_sharding_prop.py index aefb4e41c94479..c8c8ddb9817c1c 100644 --- a/torch/distributed/_tensor/_sharding_prop.py +++ b/torch/distributed/_tensor/_sharding_prop.py @@ -1,4 +1,5 @@ # mypy: allow-untyped-defs +import threading from functools import lru_cache from itertools import chain from typing import Callable, cast, Dict, List, Optional, Sequence, Tuple, Union @@ -37,6 +38,17 @@ def _length(obj) -> int: return len(obj) +class LocalLRUCache(threading.local): + def __init__(self, user_function: Callable) -> None: + self.cache = lru_cache(None)(user_function) + + def __call__(self, *args, **kwargs) -> object: + return self.cache(*args, **kwargs) + + def cache_info(self): + return self.cache.cache_info() + + class ShardingPropagator: def __init__(self) -> None: self.op_to_rules: Dict[OpOverload, Callable[[OpSchema], OutputSharding]] = {} @@ -46,7 +58,9 @@ def __init__(self) -> None: ] = {} # op map to save static argnum to decide to reuse sharding prop cache or re-run sharding prop self.op_to_schema_info: Dict[OpOverload, RuntimeSchemaInfo] = {} - self.propagate_op_sharding = lru_cache(None)(self.propagate_op_sharding_non_cached) # type: ignore[method-assign] + self.propagate_op_sharding = LocalLRUCache( + self.propagate_op_sharding_non_cached + ) # op map to save indices of shape (and stride) args which may need to be modified in sharding prop self.op_to_shape_and_stride_idx: Dict[ OpOverload, Union[int, Tuple[int, int]] @@ -183,7 +197,9 @@ def propagate(self, op_info: OpInfo) -> None: if op_info.schema.has_symints: output_sharding = self.propagate_op_sharding_non_cached(op_info.schema) else: - output_sharding = self.propagate_op_sharding(op_info.schema) + output_sharding = cast( + OutputSharding, self.propagate_op_sharding(op_info.schema) + ) op_info.output_sharding = output_sharding def propagate_op_sharding_non_cached(self, op_schema: OpSchema) -> OutputSharding: From 8e83322ee01e03ae00f398aec895174db882e0cb Mon Sep 17 00:00:00 2001 From: Simon Fan Date: Mon, 26 Aug 2024 17:42:49 -0700 Subject: [PATCH 0160/1018] [compiled autograd][cpp node] point c++ custom autograd functions tracing error to google doc (#134514) `RuntimeError: Attempting to trace a potentially unsafe C++ autograd function: torch::autograd::CppNode. It may be possible to trace it safely, please refer to the instructions in: https://docs.google.com/document/d/11VucFBEewzqgkABIjebZIzMvrXr3BtcY1aGKpX61pJY/.` Pull Request resolved: https://github.com/pytorch/pytorch/pull/134514 Approved by: https://github.com/yf225 --- test/inductor/test_compiled_autograd.py | 48 +++++++++++++++++++++++++ torch/csrc/autograd/custom_function.h | 5 +-- 2 files changed, 51 insertions(+), 2 deletions(-) diff --git a/test/inductor/test_compiled_autograd.py b/test/inductor/test_compiled_autograd.py index b31c5bea6a2435..4f0bb76d5cea7c 100644 --- a/test/inductor/test_compiled_autograd.py +++ b/test/inductor/test_compiled_autograd.py @@ -1411,6 +1411,54 @@ def _compiler_fn(gm): f, compiler_fn=compiler_fn_with_op_check, compile_fn=False ) + def test_non_traceable_autograd_cpp_node(self): + cpp_source = """ +struct CustomOpAutogradFunction : public torch::autograd::Function { + static constexpr bool is_traceable = false; + + static torch::Tensor forward( + torch::autograd::AutogradContext* ctx, + const torch::Tensor& x) { + return x; + } + + static torch::autograd::variable_list backward( + torch::autograd::AutogradContext *ctx, + torch::autograd::variable_list grad_output) { + return grad_output; + } +}; + +torch::Tensor custom_op_backed_by_autograd_fn(torch::Tensor x) { + return CustomOpAutogradFunction::apply(x); +} + +TORCH_LIBRARY(test_non_traceable_autograd_cpp_node, m) { + m.def("custom_op_backed_by_autograd_fn", custom_op_backed_by_autograd_fn); +} + """ + + module = torch.utils.cpp_extension.load_inline( + name="test_non_traceable_autograd_cpp_node", + cpp_sources=cpp_source, + functions="custom_op_backed_by_autograd_fn", + verbose=True, + ) + + def fn(): + x = torch.ones(10, 10, requires_grad=True) + out = torch.ops.test_non_traceable_autograd_cpp_node.custom_op_backed_by_autograd_fn( + x + ) + loss = out.sum() + loss.backward() + + with self.assertRaisesRegex( + RuntimeError, + "https://docs.google.com/document/d/11VucFBEewzqgkABIjebZIzMvrXr3BtcY1aGKpX61pJY/", + ), compiled_autograd.enable(compiler_fn): + fn() + def test_autograd_cpp_node(self): cpp_source = """ struct CustomOpAutogradFunction : public torch::autograd::Function { diff --git a/torch/csrc/autograd/custom_function.h b/torch/csrc/autograd/custom_function.h index 2f53924fd9111d..a79ee3d1608777 100644 --- a/torch/csrc/autograd/custom_function.h +++ b/torch/csrc/autograd/custom_function.h @@ -194,8 +194,9 @@ struct CppNode : public Node { if (!T::is_traceable) { throw std::runtime_error( std::string( - "compiled_args not implemented for non-traceable node: ") + - name()); + "Attempting to trace a potentially unsafe C++ autograd function: ") + + name() + + ". It may be possible to trace it safely, please refer to the instructions in: https://docs.google.com/document/d/11VucFBEewzqgkABIjebZIzMvrXr3BtcY1aGKpX61pJY/."); } // although neither of the 2 methods below have uniqueness guarantees From 15999c83220d7d6589d7b1df061e2c78644e4ca8 Mon Sep 17 00:00:00 2001 From: Simon Fan Date: Tue, 27 Aug 2024 00:35:02 -0700 Subject: [PATCH 0161/1018] [compiled autograd] error instead of deadlock on reentrant autograd (#134530) reentrant calls autograd multiple times using the same thread, so it passes all the thread checks and hangs waiting for the lock it holds in another scope Pull Request resolved: https://github.com/pytorch/pytorch/pull/134530 Approved by: https://github.com/jansel ghstack dependencies: #134514 --- test/inductor/test_compiled_autograd.py | 14 ++++++++++++ .../csrc/dynamo/python_compiled_autograd.cpp | 22 +++++++++++++++++-- 2 files changed, 34 insertions(+), 2 deletions(-) diff --git a/test/inductor/test_compiled_autograd.py b/test/inductor/test_compiled_autograd.py index 4f0bb76d5cea7c..1a30d141bbf344 100644 --- a/test/inductor/test_compiled_autograd.py +++ b/test/inductor/test_compiled_autograd.py @@ -2518,6 +2518,20 @@ def tensor_hook(_): self.assertEqual(pack_count, 1) self.assertEqual(unpack_count, 1) + def test_reentrant_checkpointing(self): + def fn(x): + y = x.sin() + z = y.cos() + return (y * z).sum() + + inp = torch.rand(10, 10, requires_grad=True) + out = torch.utils.checkpoint.checkpoint(fn, inp, use_reentrant=True) + with self.assertRaisesRegex( + RuntimeError, + r"\(e.g. reentrant checkpointing\), this is not supported yet\.", + ), torch._dynamo.compiled_autograd.enable(torch.compile): + out.backward() + def load_test_module(name): testdir = Path(__file__).absolute().parent.parent diff --git a/torch/csrc/dynamo/python_compiled_autograd.cpp b/torch/csrc/dynamo/python_compiled_autograd.cpp index ba008144d68c5d..da8b2888f25281 100644 --- a/torch/csrc/dynamo/python_compiled_autograd.cpp +++ b/torch/csrc/dynamo/python_compiled_autograd.cpp @@ -685,6 +685,24 @@ CacheNode* _compiled_autograd_impl( return cache; } +struct LockGuardWithErrorLogs { + LockGuardWithErrorLogs(std::mutex& mtx) : mtx_(mtx) { + // Note: the standard allows try_lock to fail spuriously during races for + // performance reasons, but it shouldn't happen here since we: + // 1. disable multithreaded autograd + // 2. plenty of latency between backward calls + TORCH_INTERNAL_ASSERT( + mtx_.try_lock(), + "Trying to run compiled autograd within another compiled autograd call (e.g. reentrant checkpointing), this is not supported yet."); + } + + ~LockGuardWithErrorLogs() { + mtx_.unlock(); + } + + std::mutex& mtx_; +}; + variable_list compiled_autograd( const std::shared_ptr& graph_root, GraphTask& graph_task, @@ -693,8 +711,8 @@ variable_list compiled_autograd( TORCH_CHECK( c10::impl::TorchDispatchModeTLS::stack_len() == 0, "TorchDispatchMode not yet implemented for compiled autograd") - static std::mutex lock; - std::lock_guard lock_guard(lock); + static std::mutex mtx; + LockGuardWithErrorLogs lock_guard(mtx); pybind11::gil_scoped_acquire gil; at::ThreadLocalStateGuard tls_guard(graph_task.thread_locals_); From ec1af9e47006628c3fd9d47b9b71f421ff821e65 Mon Sep 17 00:00:00 2001 From: Andrew Gu Date: Wed, 28 Aug 2024 05:59:20 -0700 Subject: [PATCH 0162/1018] [FSDP] Made `clip_grad_norm_` norm compute order deterministic (#134673) Fixes https://github.com/pytorch/pytorch/issues/134393 Pull Request resolved: https://github.com/pytorch/pytorch/pull/134673 Approved by: https://github.com/weifengpy ghstack dependencies: #134152 --- .../fsdp/fully_sharded_data_parallel.py | 38 ++++++++++++------- 1 file changed, 25 insertions(+), 13 deletions(-) diff --git a/torch/distributed/fsdp/fully_sharded_data_parallel.py b/torch/distributed/fsdp/fully_sharded_data_parallel.py index bd8049b50ebf24..27f375da6f0669 100644 --- a/torch/distributed/fsdp/fully_sharded_data_parallel.py +++ b/torch/distributed/fsdp/fully_sharded_data_parallel.py @@ -1130,28 +1130,40 @@ def clip_grad_norm_( # where sharded and non-sharded parameters must be handled separately max_norm = float(max_norm) norm_type = float(norm_type) - sharded_params = set() - nonsharded_params = set() # `NO_SHARD` or not FSDP-managed + sharded_params_set = set() + nonsharded_params_set = set() # `NO_SHARD` or not FSDP-managed + # Make sure to compute the local norm using lists for deterministic + # iteration order and hence deterministic total norm computation + sharded_params = [] + nonsharded_params = [] grads: List[torch.Tensor] = [] for handle in self._all_handles: - target_set = ( - sharded_params if handle.uses_sharded_strategy else nonsharded_params - ) + if handle.uses_sharded_strategy: + target_set = sharded_params_set + target_list = sharded_params + else: + target_set = nonsharded_params_set + target_list = nonsharded_params if handle._use_orig_params: for param in handle.flat_param._params: - target_set.add(param) - if param.grad is not None: - grads.append(param.grad) + if param not in target_set: + target_set.add(param) + target_list.append(param) + if param.grad is not None: + grads.append(param.grad) else: - target_set.add(handle.flat_param) - if handle.flat_param.grad is not None: - grads.append(handle.flat_param.grad) + if handle.flat_param not in target_set: + target_set.add(handle.flat_param) + target_list.append(handle.flat_param) + if handle.flat_param.grad is not None: + grads.append(handle.flat_param.grad) for param in self.parameters(): not_fsdp_managed = ( - param not in sharded_params and param not in nonsharded_params + param not in sharded_params_set and param not in nonsharded_params_set ) if not_fsdp_managed: - nonsharded_params.add(param) + nonsharded_params_set.add(param) + nonsharded_params.append(param) if param.grad is not None: grads.append(param.grad) # Compute local norms (forced to be in FP32) From 36871d5a6209c9fd857239a6027c724a0595613a Mon Sep 17 00:00:00 2001 From: Animesh Jain Date: Tue, 27 Aug 2024 23:12:36 -0700 Subject: [PATCH 0163/1018] [dynamo][freezing] Set is_static_type to false after marking an input static (#134653) Pull Request resolved: https://github.com/pytorch/pytorch/pull/134653 Approved by: https://github.com/mlazos --- torch/_dynamo/variables/builder.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/torch/_dynamo/variables/builder.py b/torch/_dynamo/variables/builder.py index 69f7d4974057fc..b4f35dbbd404b8 100644 --- a/torch/_dynamo/variables/builder.py +++ b/torch/_dynamo/variables/builder.py @@ -1415,7 +1415,8 @@ def wrap_tensor(self, value: torch.Tensor): or (source and source.guard_source().is_unspecialized_nn_module()) ) ): - self.mark_static_input(value, guard=False) + self.mark_static_input(value, guard=is_parameter_freezing()) + is_static_input = True make_graph_attribute = is_static_input and ( not config.inline_inbuilt_nn_modules or is_parameter_freezing() From 2791f4a538868e5596ddfa7a5d43751307788fdf Mon Sep 17 00:00:00 2001 From: Gregory Comer Date: Wed, 28 Aug 2024 19:24:02 +0000 Subject: [PATCH 0164/1018] Update PyTorch for XNNPACK 87ee0b4 (#134518) Summary: Update XNNPACK library version. Test Plan: Combined diff CI is clean: D61586079 (all changes, has to be split out for export). Differential Revision: D61822610 Pull Request resolved: https://github.com/pytorch/pytorch/pull/134518 Approved by: https://github.com/mcr229 --- cmake/Dependencies.cmake | 5 + third_party/XNNPACK | 2 +- third_party/generate-xnnpack-wrappers.py | 17 +- third_party/xnnpack.buck.bzl | 56 +-- third_party/xnnpack_src_defs.bzl | 413 +++++++++++------------ third_party/xnnpack_wrapper_defs.bzl | 14 - 6 files changed, 237 insertions(+), 270 deletions(-) diff --git a/cmake/Dependencies.cmake b/cmake/Dependencies.cmake index 21ff5f197ea836..71bfcc40da9a4d 100644 --- a/cmake/Dependencies.cmake +++ b/cmake/Dependencies.cmake @@ -544,6 +544,11 @@ if(USE_XNNPACK AND NOT USE_SYSTEM_XNNPACK) # Disable I8MM For CI since clang 9 does not support neon i8mm. set(XNNPACK_ENABLE_ARM_I8MM OFF CACHE BOOL "") + # Older MSVC versions don't support AVX512FP. TODO Minimum version support? + IF(CMAKE_C_COMPILER_ID STREQUAL "MSVC") + set(XNNPACK_ENABLE_AVX512FP16 OFF CACHE BOOL "") + ENDIF() + # Conditionally disable AVX512AMX, as it requires Clang 11 or later. Note that # XNNPACK does conditionally compile this based on GCC version. Once it also does # so based on Clang version, this logic can be removed. diff --git a/third_party/XNNPACK b/third_party/XNNPACK index fcbf55af6cf28a..87ee0b46b834f6 160000 --- a/third_party/XNNPACK +++ b/third_party/XNNPACK @@ -1 +1 @@ -Subproject commit fcbf55af6cf28a4627bcd1f703ab7ad843f0f3a2 +Subproject commit 87ee0b46b834f67bad9025d4a82ed5654f3403d3 diff --git a/third_party/generate-xnnpack-wrappers.py b/third_party/generate-xnnpack-wrappers.py index 34171f676506cd..e9b23e4a784d5e 100755 --- a/third_party/generate-xnnpack-wrappers.py +++ b/third_party/generate-xnnpack-wrappers.py @@ -1,6 +1,7 @@ #!/usr/bin/env python3 from __future__ import print_function +from pathlib import Path import collections import os import sys @@ -99,12 +100,24 @@ def handle_singleline_parse(line): return key_val[0], [x[4:] for x in key_val[1:]] def update_sources(xnnpack_path, cmakefile = "XNNPACK/CMakeLists.txt"): + print(f"Updating sources from {cmakefile}") sources = collections.defaultdict(list) with open(os.path.join(xnnpack_path, cmakefile)) as cmake: lines = cmake.readlines() i = 0 while i < len(lines): line = lines[i] + + if lines[i].startswith("INCLUDE"): + file, _ = handle_singleline_parse(line) + if file.startswith("cmake/gen/"): + path = Path(xnnpack_path) / "XNNPACK" / file + local_sources = update_sources(xnnpack_path, path.absolute().as_posix()) + for k,v in local_sources.items(): + if k in sources: + sources[k] = sources[k] + local_sources[k] + else: + sources[k] = local_sources[k] if lines[i].startswith("SET") and "src/" in lines[i]: name, val = handle_singleline_parse(line) @@ -132,7 +145,7 @@ def gen_wrappers(xnnpack_path): xnnpack_sources = collections.defaultdict(list) sources = update_sources(xnnpack_path) - microkernels_sources = update_sources(xnnpack_path, "XNNPACK/cmake/microkernels.cmake") + microkernels_sources = update_sources(xnnpack_path, "XNNPACK/cmake/gen/microkernels.cmake") for key in microkernels_sources: sources[key] = microkernels_sources[key] @@ -186,6 +199,8 @@ def gen_wrappers(xnnpack_path): def main(argv): + print("Generating wrappers...") + if argv is None or len(argv) == 0: gen_wrappers(".") else: diff --git a/third_party/xnnpack.buck.bzl b/third_party/xnnpack.buck.bzl index cb351261d4039c..144dc8513ec622 100644 --- a/third_party/xnnpack.buck.bzl +++ b/third_party/xnnpack.buck.bzl @@ -4,7 +4,6 @@ load("//tools/build_defs:glob_defs.bzl", "subdir_glob") load("//tools/build_defs:platform_defs.bzl", "ANDROID", "APPLE", "APPLETVOS", "CXX", "IOS", "MACOSX", "WINDOWS") load( ":xnnpack_src_defs.bzl", - "JIT_SRCS", "LOGGING_SRCS", "OPERATOR_SRCS", "SUBGRAPH_SRCS", @@ -108,7 +107,6 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F preferred_linkage = "static", preprocessor_flags = [ "-DXNN_LOG_LEVEL=0", - "-DXNN_ENABLE_JIT=0", "-DXNN_ENABLE_SPARSE=0", "-DXNN_ENABLE_MEMOPT", ], @@ -154,37 +152,6 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F ], ) - fb_xplat_cxx_library( - name = "jit_memory", - # srcs have to include HOT_SRCS to be able to build on ARVR - srcs = JIT_SRCS, - headers = subdir_glob([ - ("XNNPACK/src", "**/*.h"), - ]), - header_namespace = "", - apple_sdks = (IOS, MACOSX, APPLETVOS), - compiler_flags = [ - "-Oz", - ], - fbobjc_preprocessor_flags = [ - "-DXNN_PRIVATE=", - "-DXNN_INTERNAL=", - ], - labels = labels, - platforms = (APPLE, ANDROID, CXX, WINDOWS), - preferred_linkage = "static", - preprocessor_flags = [ - "-DXNN_LOG_LEVEL=0", - ], - visibility = ["PUBLIC"], - windows_clang_compiler_flags_override = WINDOWS_FLAGS + WINDOWS_CLANG_COMPILER_FLAGS, - windows_compiler_flags_override = WINDOWS_FLAGS, - deps = [ - ":interface", - third_party("clog"), - ], - ) - fb_xplat_cxx_library( name = "ukernels_scalar", srcs = PROD_SCALAR_MICROKERNEL_SRCS, @@ -792,6 +759,9 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F preprocessor_flags = [ "-DXNN_LOG_LEVEL=0", ], + exported_preprocessor_flags = [ + "-DXNN_ENABLE_AVX512VNNI" + ], visibility = ["PUBLIC"], windows_clang_compiler_flags_override = WINDOWS_FLAGS + WINDOWS_CLANG_COMPILER_FLAGS + ["-mavx"], windows_compiler_flags_override = WINDOWS_FLAGS + ["-mavx"], @@ -833,6 +803,9 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F preprocessor_flags = [ "-DXNN_LOG_LEVEL=0", ], + exported_preprocessor_flags = [ + "-DXNN_ENABLE_AVX512VNNI" + ], visibility = ["PUBLIC"], windows_clang_compiler_flags_override = WINDOWS_FLAGS + WINDOWS_CLANG_COMPILER_FLAGS + ["-mavx"], windows_compiler_flags_override = WINDOWS_FLAGS + ["-mavx"], @@ -1328,6 +1301,7 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F "-mf16c", ], windows_compiler_flags_override = WINDOWS_FLAGS + [ + "/D__AVX2__", "-mavx2", "-mfma", "-mf16c", @@ -1576,6 +1550,7 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F "-mavx512bw", "-mavx512dq", "-mavx512vl", + ], deps = [ ":interface", @@ -1633,6 +1608,7 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F "-mavx512bw", "-mavx512dq", "-mavx512vl", + "/D__AVX512BW__", ], windows_srcs = PROD_AVX512SKX_MICROKERNEL_SRCS, deps = [ @@ -2463,7 +2439,6 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F windows_compiler_flags_override = WINDOWS_FLAGS, deps = [ ":interface", - ":jit_memory", third_party("FP16"), ], ) @@ -2507,7 +2482,6 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F windows_compiler_flags_override = WINDOWS_FLAGS, deps = [ ":interface", - ":jit_memory", third_party("FP16"), ], ) @@ -2519,7 +2493,6 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F preferred_linkage = "static", visibility = ["PUBLIC"], deps = [ - ":jit_memory", ":ukernels_asm_aarch64", ":ukernels_neon", ":ukernels_neon_aarch64", @@ -2581,10 +2554,13 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F ":ukernels_ssse3_ovr_win32", ":ukernels_xop_ovr_win32", ":ukernels_avx512vbmi", - ":ukernels_avx512vnni_ovr_win32", - ":ukernels_avx512vnnigfni_ovr_win32", + # ":ukernels_avx512vnni_ovr_win32", # Build crashes on Windows Clang 17.0.3, re-enable when fixed (T199959765) + # ":ukernels_avx512vnnigfni_ovr_win32", # ":ukernels_avxvnni_ovr_win32" Excluding avxvnni microkernels because they fail on older compilers ], + exported_preprocessor_flags = [ + "-DXNN_ENABLE_AVX512VNNIGFNI=0" + ] ) fb_xplat_cxx_library( @@ -2594,7 +2570,6 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F preferred_linkage = "static", visibility = ["PUBLIC"], deps = [ - ":jit_memory", ":ukernels_armsimd32", ":ukernels_asm_aarch32", ":ukernels_asm_aarch64", @@ -2622,7 +2597,6 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F preferred_linkage = "static", visibility = ["PUBLIC"], deps = [ - ":jit_memory", ":ukernels_asm_aarch32", ":ukernels_neon", ":ukernels_neon_dot", @@ -2690,7 +2664,6 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F "-DXNN_NO_X8_OPERATORS", "-DXNN_ENABLE_MEMOPT", "-DXNN_ENABLE_SPARSE=0", - "-DXNN_ENABLE_JIT=0", "-DXNN_ENABLE_ASSEMBLY", "-DXNN_ENABLE_GEMM_M_SPECIALIZATION", "-DXNN_ENABLE_ARM_DOTPROD", @@ -2712,7 +2685,6 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F "XNNPACK/src/memory.c", "XNNPACK/src/mutex.c", "XNNPACK/src/microparams-init.c", - "XNNPACK/src/operators/post-operation.c", ], visibility = ["PUBLIC"], windows_clang_compiler_flags_override = (WINDOWS_FLAGS + WINDOWS_CLANG_COMPILER_FLAGS) if XNNPACK_WINDOWS_AVX512F_ENABLED else WINDOWS_FLAGS, diff --git a/third_party/xnnpack_src_defs.bzl b/third_party/xnnpack_src_defs.bzl index 296dacb58ec4c1..e9b3b6f9a9cef8 100644 --- a/third_party/xnnpack_src_defs.bzl +++ b/third_party/xnnpack_src_defs.bzl @@ -2,16 +2,12 @@ Auto-generated by generate-wrappers.py script. Do not modify """ -PROD_SCALAR_MICROKERNEL_SRCS = [ - "XNNPACK/src/amalgam/gen/scalar.c", -] - -PROD_AVX512VNNI_MICROKERNEL_SRCS = [ - "XNNPACK/src/amalgam/gen/avx512vnni.c", +PROD_ARMSIMD32_MICROKERNEL_SRCS = [ + "XNNPACK/src/amalgam/gen/armsimd32.c", ] -PROD_AVX512F_MICROKERNEL_SRCS = [ - "XNNPACK/src/amalgam/gen/avx512f.c", +PROD_NEONFP16ARITH_AARCH64_MICROKERNEL_SRCS = [ + "XNNPACK/src/amalgam/gen/neonfp16arith-aarch64.c", ] AARCH64_ASM_MICROKERNEL_SRCS = [ @@ -240,152 +236,22 @@ AARCH64_ASM_MICROKERNEL_SRCS = [ "XNNPACK/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-4x16c4-minmax-fp32-asm-aarch64-neondot-cortex-a55.S", "XNNPACK/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-4x16c4-minmax-fp32-asm-aarch64-neondot-ld64.S", "XNNPACK/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-4x16c4-minmax-fp32-asm-aarch64-neondot-ld128.S", - "XNNPACK/src/qu8-gemm/gen/qu8-gemm-4x8c4-minmax-rndnu-asm-aarch64-neondot-cortex-a55.S", - "XNNPACK/src/qu8-gemm/gen/qu8-gemm-4x8c4-minmax-rndnu-asm-aarch64-neondot-ld128.S", "XNNPACK/src/qu8-gemm/gen/qu8-gemm-4x16-minmax-rndnu-asm-aarch64-neon-mlal-lane-cortex-a53-prfm.S", "XNNPACK/src/qu8-gemm/gen/qu8-gemm-4x16-minmax-rndnu-asm-aarch64-neon-mlal-lane-cortex-a53.S", "XNNPACK/src/qu8-gemm/gen/qu8-gemm-4x16-minmax-rndnu-asm-aarch64-neon-mlal-lane-cortex-a75-prfm.S", "XNNPACK/src/qu8-gemm/gen/qu8-gemm-4x16-minmax-rndnu-asm-aarch64-neon-mlal-lane-cortex-a75.S", "XNNPACK/src/qu8-gemm/gen/qu8-gemm-4x16-minmax-rndnu-asm-aarch64-neon-mlal-lane-ld64-prfm.S", "XNNPACK/src/qu8-gemm/gen/qu8-gemm-4x16-minmax-rndnu-asm-aarch64-neon-mlal-lane-ld64.S", - "XNNPACK/src/qu8-gemm/gen/qu8-gemm-4x16c4-minmax-fp32-asm-aarch64-neondot-cortex-a55.S", - "XNNPACK/src/qu8-gemm/gen/qu8-gemm-4x16c4-minmax-fp32-asm-aarch64-neondot-ld128.S", - "XNNPACK/src/qu8-gemm/gen/qu8-gemm-4x16c4-minmax-rndnu-asm-aarch64-neondot-cortex-a55.S", - "XNNPACK/src/qu8-gemm/gen/qu8-gemm-4x16c4-minmax-rndnu-asm-aarch64-neondot-ld128.S", - "XNNPACK/src/qu8-igemm/gen/qu8-igemm-4x8c4-minmax-rndnu-asm-aarch64-neondot-cortex-a55.S", - "XNNPACK/src/qu8-igemm/gen/qu8-igemm-4x8c4-minmax-rndnu-asm-aarch64-neondot-ld128.S", "XNNPACK/src/qu8-igemm/gen/qu8-igemm-4x16-minmax-rndnu-asm-aarch64-neon-mlal-lane-cortex-a53-prfm.S", "XNNPACK/src/qu8-igemm/gen/qu8-igemm-4x16-minmax-rndnu-asm-aarch64-neon-mlal-lane-cortex-a53.S", "XNNPACK/src/qu8-igemm/gen/qu8-igemm-4x16-minmax-rndnu-asm-aarch64-neon-mlal-lane-cortex-a75-prfm.S", "XNNPACK/src/qu8-igemm/gen/qu8-igemm-4x16-minmax-rndnu-asm-aarch64-neon-mlal-lane-cortex-a75.S", "XNNPACK/src/qu8-igemm/gen/qu8-igemm-4x16-minmax-rndnu-asm-aarch64-neon-mlal-lane-ld64-prfm.S", "XNNPACK/src/qu8-igemm/gen/qu8-igemm-4x16-minmax-rndnu-asm-aarch64-neon-mlal-lane-ld64.S", - "XNNPACK/src/qu8-igemm/gen/qu8-igemm-4x16c4-minmax-fp32-asm-aarch64-neondot-cortex-a55.S", - "XNNPACK/src/qu8-igemm/gen/qu8-igemm-4x16c4-minmax-fp32-asm-aarch64-neondot-ld128.S", - "XNNPACK/src/qu8-igemm/gen/qu8-igemm-4x16c4-minmax-rndnu-asm-aarch64-neondot-cortex-a55.S", - "XNNPACK/src/qu8-igemm/gen/qu8-igemm-4x16c4-minmax-rndnu-asm-aarch64-neondot-ld128.S", ] -PROD_NEONV8_MICROKERNEL_SRCS = [ - "XNNPACK/src/amalgam/gen/neonv8.c", -] - -PROD_AVX_MICROKERNEL_SRCS = [ - "XNNPACK/src/amalgam/gen/avx.c", -] - -LOGGING_SRCS = [ - "XNNPACK/src/enums/datatype-strings.c", - "XNNPACK/src/enums/microkernel-type.c", - "XNNPACK/src/enums/node-type.c", - "XNNPACK/src/enums/operator-type.c", - "XNNPACK/src/log.c", -] - -PROD_NEONI8MM_MICROKERNEL_SRCS = [ - "XNNPACK/src/amalgam/gen/neoni8mm.c", -] - -AARCH32_ASM_MICROKERNEL_SRCS = [ - "XNNPACK/src/cs16-bfly4/cs16-bfly4-samples1-asm-aarch32-neon-x1.S", - "XNNPACK/src/cs16-bfly4/cs16-bfly4-samples1-asm-aarch32-neon-x2.S", - "XNNPACK/src/cs16-bfly4/cs16-bfly4-samples1-asm-aarch32-neon-x4.S", - "XNNPACK/src/cs16-fftr/cs16-fftr-asm-aarch32-neon-x1.S", - "XNNPACK/src/cs16-fftr/cs16-fftr-asm-aarch32-neon-x4.S", - "XNNPACK/src/f32-gemm/gen/f32-gemm-1x8-minmax-asm-aarch32-neon-cortex-a53-prfm.S", - "XNNPACK/src/f32-gemm/gen/f32-gemm-1x8-minmax-asm-aarch32-neon-cortex-a53.S", - "XNNPACK/src/f32-gemm/gen/f32-gemm-4x4-asm-aarch32-vfp-ld64.S", - "XNNPACK/src/f32-gemm/gen/f32-gemm-4x4-minmax-asm-aarch32-vfp-ld64.S", - "XNNPACK/src/f32-gemm/gen/f32-gemm-4x8-minmax-asm-aarch32-neon-cortex-a7.S", - "XNNPACK/src/f32-gemm/gen/f32-gemm-4x8-minmax-asm-aarch32-neon-cortex-a53-prfm.S", - "XNNPACK/src/f32-gemm/gen/f32-gemm-4x8-minmax-asm-aarch32-neon-cortex-a53.S", - "XNNPACK/src/f32-gemm/gen/f32-gemm-4x8-minmax-asm-aarch32-neon-cortex-a55.S", - "XNNPACK/src/f32-gemm/gen/f32-gemm-4x8-minmax-asm-aarch32-neon-cortex-a75-prfm.S", - "XNNPACK/src/f32-gemm/gen/f32-gemm-4x8-minmax-asm-aarch32-neon-cortex-a75.S", - "XNNPACK/src/f32-gemm/gen/f32-gemm-4x8-minmax-asm-aarch32-neon-ld64.S", - "XNNPACK/src/f32-igemm/f32-igemm-4x8-minmax-asm-aarch32-neon-cortex-a55.S", - "XNNPACK/src/f32-igemm/gen/f32-igemm-1x8-minmax-asm-aarch32-neon-cortex-a53-prfm.S", - "XNNPACK/src/f32-igemm/gen/f32-igemm-1x8-minmax-asm-aarch32-neon-cortex-a53.S", - "XNNPACK/src/f32-igemm/gen/f32-igemm-4x8-minmax-asm-aarch32-neon-cortex-a7.S", - "XNNPACK/src/f32-igemm/gen/f32-igemm-4x8-minmax-asm-aarch32-neon-cortex-a53-prfm.S", - "XNNPACK/src/f32-igemm/gen/f32-igemm-4x8-minmax-asm-aarch32-neon-cortex-a53.S", - "XNNPACK/src/f32-igemm/gen/f32-igemm-4x8-minmax-asm-aarch32-neon-cortex-a75-prfm.S", - "XNNPACK/src/f32-igemm/gen/f32-igemm-4x8-minmax-asm-aarch32-neon-cortex-a75.S", - "XNNPACK/src/f32-igemm/gen/f32-igemm-4x8-minmax-asm-aarch32-neon-ld64.S", - "XNNPACK/src/qd8-f16-qc8w-gemm/gen/qd8-f16-qc8w-gemm-4x8c4-minmax-asm-aarch32-neondotfp16arith-cortex-a55.S", - "XNNPACK/src/qd8-f16-qc8w-igemm/gen/qd8-f16-qc8w-igemm-4x8c4-minmax-asm-aarch32-neondotfp16arith-cortex-a55.S", - "XNNPACK/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-4x8c4-minmax-asm-aarch32-neondot-cortex-a55.S", - "XNNPACK/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-4x8c4-minmax-asm-aarch32-neondot-cortex-a55.S", - "XNNPACK/src/qs8-qc8w-dwconv/qs8-qc8w-dwconv-3p8c-minmax-fp32-asm-aarch32-neonv8-mla8-cortex-a35.S", - "XNNPACK/src/qs8-qc8w-dwconv/qs8-qc8w-dwconv-3p16c-minmax-fp32-asm-aarch32-neonv8-mla8-cortex-a35.S", - "XNNPACK/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-1x8-minmax-fp32-asm-aarch32-neon-mlal-lane-cortex-a7-prfm.S", - "XNNPACK/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-1x8-minmax-fp32-asm-aarch32-neon-mlal-lane-cortex-a7.S", - "XNNPACK/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-1x8-minmax-fp32-asm-aarch32-neonv8-mlal-lane-cortex-a35-prfm.S", - "XNNPACK/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-1x8-minmax-fp32-asm-aarch32-neonv8-mlal-lane-cortex-a35.S", - "XNNPACK/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-4x8-minmax-fp32-asm-aarch32-neon-mlal-lane-cortex-a7-prfm.S", - "XNNPACK/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-4x8-minmax-fp32-asm-aarch32-neon-mlal-lane-cortex-a7.S", - "XNNPACK/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-4x8-minmax-fp32-asm-aarch32-neon-mlal-lane-cortex-a53-prfm.S", - "XNNPACK/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-4x8-minmax-fp32-asm-aarch32-neon-mlal-lane-cortex-a53.S", - "XNNPACK/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-4x8-minmax-fp32-asm-aarch32-neon-mlal-lane-ld64-prfm.S", - "XNNPACK/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-4x8-minmax-fp32-asm-aarch32-neon-mlal-lane-ld64.S", - "XNNPACK/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-4x8-minmax-fp32-asm-aarch32-neonv8-mlal-lane-cortex-a35-prfm.S", - "XNNPACK/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-4x8-minmax-fp32-asm-aarch32-neonv8-mlal-lane-cortex-a35.S", - "XNNPACK/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-4x8-minmax-fp32-asm-aarch32-neonv8-mlal-lane-cortex-a53-prfm.S", - "XNNPACK/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-4x8-minmax-fp32-asm-aarch32-neonv8-mlal-lane-cortex-a53.S", - "XNNPACK/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-4x8-minmax-fp32-asm-aarch32-neonv8-mlal-lane-ld64-prfm.S", - "XNNPACK/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-4x8-minmax-fp32-asm-aarch32-neonv8-mlal-lane-ld64.S", - "XNNPACK/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-4x8c4-minmax-fp32-asm-aarch32-neondot-cortex-a55.S", - "XNNPACK/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-4x8c4-minmax-fp32-asm-aarch32-neondot-ld64.S", - "XNNPACK/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-1x8-minmax-fp32-asm-aarch32-neon-mlal-lane-cortex-a7-prfm.S", - "XNNPACK/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-1x8-minmax-fp32-asm-aarch32-neon-mlal-lane-cortex-a7.S", - "XNNPACK/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-1x8-minmax-fp32-asm-aarch32-neonv8-mlal-lane-cortex-a35-prfm.S", - "XNNPACK/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-1x8-minmax-fp32-asm-aarch32-neonv8-mlal-lane-cortex-a35.S", - "XNNPACK/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-4x8-minmax-fp32-asm-aarch32-neon-mlal-lane-cortex-a7-prfm.S", - "XNNPACK/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-4x8-minmax-fp32-asm-aarch32-neon-mlal-lane-cortex-a7.S", - "XNNPACK/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-4x8-minmax-fp32-asm-aarch32-neon-mlal-lane-cortex-a53-prfm.S", - "XNNPACK/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-4x8-minmax-fp32-asm-aarch32-neon-mlal-lane-cortex-a53.S", - "XNNPACK/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-4x8-minmax-fp32-asm-aarch32-neon-mlal-lane-ld64-prfm.S", - "XNNPACK/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-4x8-minmax-fp32-asm-aarch32-neon-mlal-lane-ld64.S", - "XNNPACK/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-4x8-minmax-fp32-asm-aarch32-neonv8-mlal-lane-cortex-a35-prfm.S", - "XNNPACK/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-4x8-minmax-fp32-asm-aarch32-neonv8-mlal-lane-cortex-a35.S", - "XNNPACK/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-4x8-minmax-fp32-asm-aarch32-neonv8-mlal-lane-cortex-a53-prfm.S", - "XNNPACK/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-4x8-minmax-fp32-asm-aarch32-neonv8-mlal-lane-cortex-a53.S", - "XNNPACK/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-4x8-minmax-fp32-asm-aarch32-neonv8-mlal-lane-ld64-prfm.S", - "XNNPACK/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-4x8-minmax-fp32-asm-aarch32-neonv8-mlal-lane-ld64.S", - "XNNPACK/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-4x8c4-minmax-fp32-asm-aarch32-neondot-cortex-a55.S", - "XNNPACK/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-4x8c4-minmax-fp32-asm-aarch32-neondot-ld64.S", - "XNNPACK/src/qs16-qs8-vcvt/qs16-qs8-vcvt-asm-aarch32-neon-u16.S", - "XNNPACK/src/qu8-gemm/gen/qu8-gemm-1x8-minmax-rndnu-asm-aarch32-neon-mlal-lane-cortex-a7-prfm.S", - "XNNPACK/src/qu8-gemm/gen/qu8-gemm-1x8-minmax-rndnu-asm-aarch32-neon-mlal-lane-cortex-a7.S", - "XNNPACK/src/qu8-gemm/gen/qu8-gemm-4x8-minmax-rndnu-asm-aarch32-neon-mlal-lane-cortex-a7-prfm.S", - "XNNPACK/src/qu8-gemm/gen/qu8-gemm-4x8-minmax-rndnu-asm-aarch32-neon-mlal-lane-cortex-a7.S", - "XNNPACK/src/qu8-gemm/gen/qu8-gemm-4x8-minmax-rndnu-asm-aarch32-neon-mlal-lane-cortex-a53-prfm.S", - "XNNPACK/src/qu8-gemm/gen/qu8-gemm-4x8-minmax-rndnu-asm-aarch32-neon-mlal-lane-cortex-a53.S", - "XNNPACK/src/qu8-gemm/gen/qu8-gemm-4x8-minmax-rndnu-asm-aarch32-neon-mlal-lane-ld64-prfm.S", - "XNNPACK/src/qu8-gemm/gen/qu8-gemm-4x8-minmax-rndnu-asm-aarch32-neon-mlal-lane-ld64.S", - "XNNPACK/src/qu8-igemm/gen/qu8-igemm-1x8-minmax-rndnu-asm-aarch32-neon-mlal-lane-cortex-a7-prfm.S", - "XNNPACK/src/qu8-igemm/gen/qu8-igemm-1x8-minmax-rndnu-asm-aarch32-neon-mlal-lane-cortex-a7.S", - "XNNPACK/src/qu8-igemm/gen/qu8-igemm-4x8-minmax-rndnu-asm-aarch32-neon-mlal-lane-cortex-a7-prfm.S", - "XNNPACK/src/qu8-igemm/gen/qu8-igemm-4x8-minmax-rndnu-asm-aarch32-neon-mlal-lane-cortex-a7.S", - "XNNPACK/src/qu8-igemm/gen/qu8-igemm-4x8-minmax-rndnu-asm-aarch32-neon-mlal-lane-cortex-a53-prfm.S", - "XNNPACK/src/qu8-igemm/gen/qu8-igemm-4x8-minmax-rndnu-asm-aarch32-neon-mlal-lane-cortex-a53.S", - "XNNPACK/src/qu8-igemm/gen/qu8-igemm-4x8-minmax-rndnu-asm-aarch32-neon-mlal-lane-ld64-prfm.S", - "XNNPACK/src/qu8-igemm/gen/qu8-igemm-4x8-minmax-rndnu-asm-aarch32-neon-mlal-lane-ld64.S", - "XNNPACK/src/u32-filterbank-accumulate/u32-filterbank-accumulate-asm-aarch32-arm-x1.S", - "XNNPACK/src/u32-filterbank-accumulate/u32-filterbank-accumulate-asm-aarch32-neon-x1.S", - "XNNPACK/src/u32-filterbank-accumulate/u32-filterbank-accumulate-asm-aarch32-neon-x2.S", -] - -PROD_F16C_MICROKERNEL_SRCS = [ - "XNNPACK/src/amalgam/gen/f16c.c", -] - -PROD_XOP_MICROKERNEL_SRCS = [ - "XNNPACK/src/amalgam/gen/xop.c", -] - -PROD_RVV_MICROKERNEL_SRCS = [ - "XNNPACK/src/amalgam/gen/rvv.c", +PROD_AVXVNNI_MICROKERNEL_SRCS = [ + "XNNPACK/src/amalgam/gen/avxvnni.c", ] SUBGRAPH_SRCS = [ @@ -404,25 +270,30 @@ SUBGRAPH_SRCS = [ "XNNPACK/src/subgraph/convert.c", "XNNPACK/src/subgraph/convolution-2d.c", "XNNPACK/src/subgraph/copy.c", + "XNNPACK/src/subgraph/copysign.c", "XNNPACK/src/subgraph/deconvolution-2d.c", "XNNPACK/src/subgraph/depth-to-space-2d.c", "XNNPACK/src/subgraph/depthwise-convolution-2d.c", "XNNPACK/src/subgraph/divide.c", "XNNPACK/src/subgraph/elu.c", "XNNPACK/src/subgraph/even-split.c", + "XNNPACK/src/subgraph/exp.c", "XNNPACK/src/subgraph/floor.c", "XNNPACK/src/subgraph/fully-connected-sparse.c", "XNNPACK/src/subgraph/fully-connected.c", + "XNNPACK/src/subgraph/gelu.c", "XNNPACK/src/subgraph/global-average-pooling.c", "XNNPACK/src/subgraph/global-sum-pooling.c", "XNNPACK/src/subgraph/hardswish.c", "XNNPACK/src/subgraph/leaky-relu.c", + "XNNPACK/src/subgraph/log.c", "XNNPACK/src/subgraph/max-pooling-2d.c", "XNNPACK/src/subgraph/maximum2.c", "XNNPACK/src/subgraph/minimum2.c", "XNNPACK/src/subgraph/multiply2.c", "XNNPACK/src/subgraph/negate.c", "XNNPACK/src/subgraph/prelu.c", + "XNNPACK/src/subgraph/reciprocal-square-root.c", "XNNPACK/src/subgraph/reshape-helpers.c", "XNNPACK/src/subgraph/scaled-dot-product-attention.c", "XNNPACK/src/subgraph/sigmoid.c", @@ -444,26 +315,40 @@ SUBGRAPH_SRCS = [ "XNNPACK/src/tensor.c", ] -PROD_FMA3_MICROKERNEL_SRCS = [ - "XNNPACK/src/amalgam/gen/fma3.c", +PROD_AVX512VNNIGFNI_MICROKERNEL_SRCS = [ + "XNNPACK/src/amalgam/gen/avx512vnnigfni.c", ] -PROD_AVX512SKX_MICROKERNEL_SRCS = [ - "XNNPACK/src/amalgam/gen/avx512skx.c", +PROD_AVX512VNNI_MICROKERNEL_SRCS = [ + "XNNPACK/src/amalgam/gen/avx512vnni.c", ] -JIT_SRCS = [ - "XNNPACK/src/jit/aarch32-assembler.cc", - "XNNPACK/src/jit/aarch64-assembler.cc", - "XNNPACK/src/jit/assembler.cc", +PROD_SSE2_MICROKERNEL_SRCS = [ + "XNNPACK/src/amalgam/gen/sse2.c", ] -PROD_NEONFP16_MICROKERNEL_SRCS = [ - "XNNPACK/src/amalgam/gen/neonfp16.c", +PROD_NEONDOT_MICROKERNEL_SRCS = [ + "XNNPACK/src/amalgam/gen/neondot.c", ] -PROD_SSSE3_MICROKERNEL_SRCS = [ - "XNNPACK/src/amalgam/gen/ssse3.c", +PROD_SSE41_MICROKERNEL_SRCS = [ + "XNNPACK/src/amalgam/gen/sse41.c", +] + +PROD_SSE_MICROKERNEL_SRCS = [ + "XNNPACK/src/amalgam/gen/sse.c", +] + +PROD_NEONFP16ARITH_MICROKERNEL_SRCS = [ + "XNNPACK/src/amalgam/gen/neonfp16arith.c", +] + +PROD_NEONV8_MICROKERNEL_SRCS = [ + "XNNPACK/src/amalgam/gen/neonv8.c", +] + +PROD_NEONFP16_MICROKERNEL_SRCS = [ + "XNNPACK/src/amalgam/gen/neonfp16.c", ] XNNPACK_SRCS = [ @@ -500,69 +385,23 @@ XNNPACK_SRCS = [ "XNNPACK/src/params.c", ] -PROD_FP16ARITH_MICROKERNEL_SRCS = [ - "XNNPACK/src/amalgam/gen/fp16arith.c", -] - -TABLE_SRCS = [ - "XNNPACK/src/tables/exp2-k-over-64.c", - "XNNPACK/src/tables/exp2-k-over-2048.c", - "XNNPACK/src/tables/exp2minus-k-over-4.c", - "XNNPACK/src/tables/exp2minus-k-over-8.c", - "XNNPACK/src/tables/exp2minus-k-over-16.c", - "XNNPACK/src/tables/exp2minus-k-over-32.c", - "XNNPACK/src/tables/exp2minus-k-over-64.c", - "XNNPACK/src/tables/exp2minus-k-over-2048.c", - "XNNPACK/src/tables/vlog.c", -] - -PROD_NEON_MICROKERNEL_SRCS = [ - "XNNPACK/src/amalgam/gen/neon.c", -] - -PROD_AVXVNNI_MICROKERNEL_SRCS = [ - "XNNPACK/src/amalgam/gen/avxvnni.c", -] - -PROD_NEONFP16ARITH_MICROKERNEL_SRCS = [ - "XNNPACK/src/amalgam/gen/neonfp16arith.c", -] - -PROD_SSE_MICROKERNEL_SRCS = [ - "XNNPACK/src/amalgam/gen/sse.c", +PROD_AVX_MICROKERNEL_SRCS = [ + "XNNPACK/src/amalgam/gen/avx.c", ] -PROD_NEON_AARCH64_MICROKERNEL_SRCS = [ - "XNNPACK/src/amalgam/gen/neon-aarch64.c", - "XNNPACK/src/amalgam/gen/neonfma-aarch64.c", +PROD_AVX512SKX_MICROKERNEL_SRCS = [ + "XNNPACK/src/amalgam/gen/avx512skx.c", ] PROD_NEONDOTFP16ARITH_AARCH64_MICROKERNEL_SRCS = [ "XNNPACK/src/amalgam/gen/neondotfp16-aarch64.c", ] -PROD_NEONFMA_MICROKERNEL_SRCS = [ - "XNNPACK/src/amalgam/gen/neonfma.c", +PROD_FP16ARITH_MICROKERNEL_SRCS = [ + "XNNPACK/src/amalgam/gen/fp16arith.c", ] PROD_FMA_MICROKERNEL_SRCS = [ - "XNNPACK/src/amalgam/gen/fma.c", -] - -PROD_SSE2_MICROKERNEL_SRCS = [ - "XNNPACK/src/amalgam/gen/sse2.c", -] - -PROD_AVX512VNNIGFNI_MICROKERNEL_SRCS = [ - "XNNPACK/src/amalgam/gen/avx512vnnigfni.c", -] - -PROD_NEONFP16ARITH_AARCH64_MICROKERNEL_SRCS = [ - "XNNPACK/src/amalgam/gen/neonfp16arith-aarch64.c", -] - -PROD_AVX2_MICROKERNEL_SRCS = [ - "XNNPACK/src/amalgam/gen/avx2.c", ] OPERATOR_SRCS = [ @@ -595,26 +434,176 @@ OPERATOR_SRCS = [ "XNNPACK/src/operators/unpooling-nhwc.c", ] -PROD_AVX512VBMI_MICROKERNEL_SRCS = [ - "XNNPACK/src/amalgam/gen/avx512vbmi.c", +PROD_NEONI8MM_MICROKERNEL_SRCS = [ + "XNNPACK/src/amalgam/gen/neoni8mm.c", ] -PROD_NEONDOT_MICROKERNEL_SRCS = [ - "XNNPACK/src/amalgam/gen/neondot.c", +PROD_AVX512F_MICROKERNEL_SRCS = [ + "XNNPACK/src/amalgam/gen/avx512f.c", +] + +JIT_SRCS = [ +] + +PROD_F16C_MICROKERNEL_SRCS = [ + "XNNPACK/src/amalgam/gen/f16c.c", +] + +PROD_NEON_MICROKERNEL_SRCS = [ + "XNNPACK/src/amalgam/gen/neon.c", +] + +PROD_SCALAR_MICROKERNEL_SRCS = [ + "XNNPACK/src/amalgam/gen/scalar.c", ] PROD_NEONDOT_AARCH64_MICROKERNEL_SRCS = [ "XNNPACK/src/amalgam/gen/neondot-aarch64.c", ] -PROD_SSE41_MICROKERNEL_SRCS = [ - "XNNPACK/src/amalgam/gen/sse41.c", +PROD_FMA3_MICROKERNEL_SRCS = [ + "XNNPACK/src/amalgam/gen/fma3.c", ] -PROD_ARMSIMD32_MICROKERNEL_SRCS = [ - "XNNPACK/src/amalgam/gen/armsimd32.c", +LOGGING_SRCS = [ + "XNNPACK/src/enums/allocation-type.c", + "XNNPACK/src/enums/datatype-strings.c", + "XNNPACK/src/enums/microkernel-type.c", + "XNNPACK/src/enums/node-type.c", + "XNNPACK/src/enums/operator-type.c", + "XNNPACK/src/log.c", +] + +PROD_NEONFMA_MICROKERNEL_SRCS = [ + "XNNPACK/src/amalgam/gen/neonfma.c", +] + +PROD_AVX2_MICROKERNEL_SRCS = [ + "XNNPACK/src/amalgam/gen/avx2.c", +] + +PROD_AVX512VBMI_MICROKERNEL_SRCS = [ + "XNNPACK/src/amalgam/gen/avx512vbmi.c", +] + +PROD_RVV_MICROKERNEL_SRCS = [ + "XNNPACK/src/amalgam/gen/rvv.c", ] PROD_NEONDOTFP16ARITH_MICROKERNEL_SRCS = [ "XNNPACK/src/amalgam/gen/neondotfp16arith.c", ] + +PROD_XOP_MICROKERNEL_SRCS = [ +] + +AARCH32_ASM_MICROKERNEL_SRCS = [ + "XNNPACK/src/cs16-bfly4/cs16-bfly4-samples1-asm-aarch32-neon-x1.S", + "XNNPACK/src/cs16-bfly4/cs16-bfly4-samples1-asm-aarch32-neon-x2.S", + "XNNPACK/src/cs16-bfly4/cs16-bfly4-samples1-asm-aarch32-neon-x4.S", + "XNNPACK/src/cs16-fftr/cs16-fftr-asm-aarch32-neon-x1.S", + "XNNPACK/src/cs16-fftr/cs16-fftr-asm-aarch32-neon-x4.S", + "XNNPACK/src/f32-gemm/gen/f32-gemm-1x8-minmax-asm-aarch32-neon-cortex-a53-prfm.S", + "XNNPACK/src/f32-gemm/gen/f32-gemm-1x8-minmax-asm-aarch32-neon-cortex-a53.S", + "XNNPACK/src/f32-gemm/gen/f32-gemm-4x4-asm-aarch32-vfp-ld64.S", + "XNNPACK/src/f32-gemm/gen/f32-gemm-4x4-minmax-asm-aarch32-vfp-ld64.S", + "XNNPACK/src/f32-gemm/gen/f32-gemm-4x8-minmax-asm-aarch32-neon-cortex-a7.S", + "XNNPACK/src/f32-gemm/gen/f32-gemm-4x8-minmax-asm-aarch32-neon-cortex-a53-prfm.S", + "XNNPACK/src/f32-gemm/gen/f32-gemm-4x8-minmax-asm-aarch32-neon-cortex-a53.S", + "XNNPACK/src/f32-gemm/gen/f32-gemm-4x8-minmax-asm-aarch32-neon-cortex-a55.S", + "XNNPACK/src/f32-gemm/gen/f32-gemm-4x8-minmax-asm-aarch32-neon-cortex-a75-prfm.S", + "XNNPACK/src/f32-gemm/gen/f32-gemm-4x8-minmax-asm-aarch32-neon-cortex-a75.S", + "XNNPACK/src/f32-gemm/gen/f32-gemm-4x8-minmax-asm-aarch32-neon-ld64.S", + "XNNPACK/src/f32-igemm/f32-igemm-4x8-minmax-asm-aarch32-neon-cortex-a55.S", + "XNNPACK/src/f32-igemm/gen/f32-igemm-1x8-minmax-asm-aarch32-neon-cortex-a53-prfm.S", + "XNNPACK/src/f32-igemm/gen/f32-igemm-1x8-minmax-asm-aarch32-neon-cortex-a53.S", + "XNNPACK/src/f32-igemm/gen/f32-igemm-4x8-minmax-asm-aarch32-neon-cortex-a7.S", + "XNNPACK/src/f32-igemm/gen/f32-igemm-4x8-minmax-asm-aarch32-neon-cortex-a53-prfm.S", + "XNNPACK/src/f32-igemm/gen/f32-igemm-4x8-minmax-asm-aarch32-neon-cortex-a53.S", + "XNNPACK/src/f32-igemm/gen/f32-igemm-4x8-minmax-asm-aarch32-neon-cortex-a75-prfm.S", + "XNNPACK/src/f32-igemm/gen/f32-igemm-4x8-minmax-asm-aarch32-neon-cortex-a75.S", + "XNNPACK/src/f32-igemm/gen/f32-igemm-4x8-minmax-asm-aarch32-neon-ld64.S", + "XNNPACK/src/qd8-f16-qc8w-gemm/gen/qd8-f16-qc8w-gemm-4x8c4-minmax-asm-aarch32-neondotfp16arith-cortex-a55.S", + "XNNPACK/src/qd8-f16-qc8w-igemm/gen/qd8-f16-qc8w-igemm-4x8c4-minmax-asm-aarch32-neondotfp16arith-cortex-a55.S", + "XNNPACK/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-4x8c4-minmax-asm-aarch32-neondot-cortex-a55.S", + "XNNPACK/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-4x8c4-minmax-asm-aarch32-neondot-cortex-a55.S", + "XNNPACK/src/qs8-qc8w-dwconv/qs8-qc8w-dwconv-3p8c-minmax-fp32-asm-aarch32-neonv8-mla8-cortex-a35.S", + "XNNPACK/src/qs8-qc8w-dwconv/qs8-qc8w-dwconv-3p16c-minmax-fp32-asm-aarch32-neonv8-mla8-cortex-a35.S", + "XNNPACK/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-1x8-minmax-fp32-asm-aarch32-neon-mlal-lane-cortex-a7-prfm.S", + "XNNPACK/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-1x8-minmax-fp32-asm-aarch32-neon-mlal-lane-cortex-a7.S", + "XNNPACK/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-1x8-minmax-fp32-asm-aarch32-neonv8-mlal-lane-cortex-a35-prfm.S", + "XNNPACK/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-1x8-minmax-fp32-asm-aarch32-neonv8-mlal-lane-cortex-a35.S", + "XNNPACK/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-4x8-minmax-fp32-asm-aarch32-neon-mlal-lane-cortex-a7-prfm.S", + "XNNPACK/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-4x8-minmax-fp32-asm-aarch32-neon-mlal-lane-cortex-a7.S", + "XNNPACK/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-4x8-minmax-fp32-asm-aarch32-neon-mlal-lane-cortex-a53-prfm.S", + "XNNPACK/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-4x8-minmax-fp32-asm-aarch32-neon-mlal-lane-cortex-a53.S", + "XNNPACK/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-4x8-minmax-fp32-asm-aarch32-neon-mlal-lane-ld64-prfm.S", + "XNNPACK/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-4x8-minmax-fp32-asm-aarch32-neon-mlal-lane-ld64.S", + "XNNPACK/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-4x8-minmax-fp32-asm-aarch32-neonv8-mlal-lane-cortex-a35-prfm.S", + "XNNPACK/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-4x8-minmax-fp32-asm-aarch32-neonv8-mlal-lane-cortex-a35.S", + "XNNPACK/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-4x8-minmax-fp32-asm-aarch32-neonv8-mlal-lane-cortex-a53-prfm.S", + "XNNPACK/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-4x8-minmax-fp32-asm-aarch32-neonv8-mlal-lane-cortex-a53.S", + "XNNPACK/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-4x8-minmax-fp32-asm-aarch32-neonv8-mlal-lane-ld64-prfm.S", + "XNNPACK/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-4x8-minmax-fp32-asm-aarch32-neonv8-mlal-lane-ld64.S", + "XNNPACK/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-4x8c4-minmax-fp32-asm-aarch32-neondot-cortex-a55.S", + "XNNPACK/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-4x8c4-minmax-fp32-asm-aarch32-neondot-ld64.S", + "XNNPACK/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-1x8-minmax-fp32-asm-aarch32-neon-mlal-lane-cortex-a7-prfm.S", + "XNNPACK/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-1x8-minmax-fp32-asm-aarch32-neon-mlal-lane-cortex-a7.S", + "XNNPACK/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-1x8-minmax-fp32-asm-aarch32-neonv8-mlal-lane-cortex-a35-prfm.S", + "XNNPACK/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-1x8-minmax-fp32-asm-aarch32-neonv8-mlal-lane-cortex-a35.S", + "XNNPACK/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-4x8-minmax-fp32-asm-aarch32-neon-mlal-lane-cortex-a7-prfm.S", + "XNNPACK/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-4x8-minmax-fp32-asm-aarch32-neon-mlal-lane-cortex-a7.S", + "XNNPACK/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-4x8-minmax-fp32-asm-aarch32-neon-mlal-lane-cortex-a53-prfm.S", + "XNNPACK/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-4x8-minmax-fp32-asm-aarch32-neon-mlal-lane-cortex-a53.S", + "XNNPACK/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-4x8-minmax-fp32-asm-aarch32-neon-mlal-lane-ld64-prfm.S", + "XNNPACK/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-4x8-minmax-fp32-asm-aarch32-neon-mlal-lane-ld64.S", + "XNNPACK/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-4x8-minmax-fp32-asm-aarch32-neonv8-mlal-lane-cortex-a35-prfm.S", + "XNNPACK/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-4x8-minmax-fp32-asm-aarch32-neonv8-mlal-lane-cortex-a35.S", + "XNNPACK/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-4x8-minmax-fp32-asm-aarch32-neonv8-mlal-lane-cortex-a53-prfm.S", + "XNNPACK/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-4x8-minmax-fp32-asm-aarch32-neonv8-mlal-lane-cortex-a53.S", + "XNNPACK/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-4x8-minmax-fp32-asm-aarch32-neonv8-mlal-lane-ld64-prfm.S", + "XNNPACK/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-4x8-minmax-fp32-asm-aarch32-neonv8-mlal-lane-ld64.S", + "XNNPACK/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-4x8c4-minmax-fp32-asm-aarch32-neondot-cortex-a55.S", + "XNNPACK/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-4x8c4-minmax-fp32-asm-aarch32-neondot-ld64.S", + "XNNPACK/src/qs16-qs8-vcvt/qs16-qs8-vcvt-asm-aarch32-neon-u16.S", + "XNNPACK/src/qu8-gemm/gen/qu8-gemm-1x8-minmax-rndnu-asm-aarch32-neon-mlal-lane-cortex-a7-prfm.S", + "XNNPACK/src/qu8-gemm/gen/qu8-gemm-1x8-minmax-rndnu-asm-aarch32-neon-mlal-lane-cortex-a7.S", + "XNNPACK/src/qu8-gemm/gen/qu8-gemm-4x8-minmax-rndnu-asm-aarch32-neon-mlal-lane-cortex-a7-prfm.S", + "XNNPACK/src/qu8-gemm/gen/qu8-gemm-4x8-minmax-rndnu-asm-aarch32-neon-mlal-lane-cortex-a7.S", + "XNNPACK/src/qu8-gemm/gen/qu8-gemm-4x8-minmax-rndnu-asm-aarch32-neon-mlal-lane-cortex-a53-prfm.S", + "XNNPACK/src/qu8-gemm/gen/qu8-gemm-4x8-minmax-rndnu-asm-aarch32-neon-mlal-lane-cortex-a53.S", + "XNNPACK/src/qu8-gemm/gen/qu8-gemm-4x8-minmax-rndnu-asm-aarch32-neon-mlal-lane-ld64-prfm.S", + "XNNPACK/src/qu8-gemm/gen/qu8-gemm-4x8-minmax-rndnu-asm-aarch32-neon-mlal-lane-ld64.S", + "XNNPACK/src/qu8-igemm/gen/qu8-igemm-1x8-minmax-rndnu-asm-aarch32-neon-mlal-lane-cortex-a7-prfm.S", + "XNNPACK/src/qu8-igemm/gen/qu8-igemm-1x8-minmax-rndnu-asm-aarch32-neon-mlal-lane-cortex-a7.S", + "XNNPACK/src/qu8-igemm/gen/qu8-igemm-4x8-minmax-rndnu-asm-aarch32-neon-mlal-lane-cortex-a7-prfm.S", + "XNNPACK/src/qu8-igemm/gen/qu8-igemm-4x8-minmax-rndnu-asm-aarch32-neon-mlal-lane-cortex-a7.S", + "XNNPACK/src/qu8-igemm/gen/qu8-igemm-4x8-minmax-rndnu-asm-aarch32-neon-mlal-lane-cortex-a53-prfm.S", + "XNNPACK/src/qu8-igemm/gen/qu8-igemm-4x8-minmax-rndnu-asm-aarch32-neon-mlal-lane-cortex-a53.S", + "XNNPACK/src/qu8-igemm/gen/qu8-igemm-4x8-minmax-rndnu-asm-aarch32-neon-mlal-lane-ld64-prfm.S", + "XNNPACK/src/qu8-igemm/gen/qu8-igemm-4x8-minmax-rndnu-asm-aarch32-neon-mlal-lane-ld64.S", + "XNNPACK/src/u32-filterbank-accumulate/u32-filterbank-accumulate-asm-aarch32-arm-x1.S", + "XNNPACK/src/u32-filterbank-accumulate/u32-filterbank-accumulate-asm-aarch32-neon-x1.S", + "XNNPACK/src/u32-filterbank-accumulate/u32-filterbank-accumulate-asm-aarch32-neon-x2.S", +] + +PROD_SSSE3_MICROKERNEL_SRCS = [ + "XNNPACK/src/amalgam/gen/ssse3.c", +] + +TABLE_SRCS = [ + "XNNPACK/src/tables/exp2-k-over-64.c", + "XNNPACK/src/tables/exp2-k-over-2048.c", + "XNNPACK/src/tables/exp2minus-k-over-4.c", + "XNNPACK/src/tables/exp2minus-k-over-8.c", + "XNNPACK/src/tables/exp2minus-k-over-16.c", + "XNNPACK/src/tables/exp2minus-k-over-32.c", + "XNNPACK/src/tables/exp2minus-k-over-64.c", + "XNNPACK/src/tables/exp2minus-k-over-2048.c", + "XNNPACK/src/tables/vlog.c", +] + +PROD_NEON_AARCH64_MICROKERNEL_SRCS = [ + "XNNPACK/src/amalgam/gen/neon-aarch64.c", + "XNNPACK/src/amalgam/gen/neonfma-aarch64.c", +] diff --git a/third_party/xnnpack_wrapper_defs.bzl b/third_party/xnnpack_wrapper_defs.bzl index b92ebb88d74efc..b05cdcd5cdec15 100644 --- a/third_party/xnnpack_wrapper_defs.bzl +++ b/third_party/xnnpack_wrapper_defs.bzl @@ -7,7 +7,6 @@ PROD_SCALAR_MICROKERNEL_SRCS = [ ] PROD_FMA_MICROKERNEL_SRCS = [ - "xnnpack_wrappers/amalgam/gen/fma.c", ] PROD_ARMSIMD32_MICROKERNEL_SRCS = [ @@ -92,7 +91,6 @@ PROD_F16C_MICROKERNEL_SRCS = [ ] PROD_XOP_MICROKERNEL_SRCS = [ - "xnnpack_wrappers/amalgam/gen/xop.c", ] PROD_FMA3_MICROKERNEL_SRCS = [ @@ -447,28 +445,16 @@ AARCH64_ASM_MICROKERNEL_SRCS = [ "xnnpack_wrappers/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-4x16c4-minmax-fp32-asm-aarch64-neondot-cortex-a55.S", "xnnpack_wrappers/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-4x16c4-minmax-fp32-asm-aarch64-neondot-ld64.S", "xnnpack_wrappers/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-4x16c4-minmax-fp32-asm-aarch64-neondot-ld128.S", - "xnnpack_wrappers/qu8-gemm/gen/qu8-gemm-4x8c4-minmax-rndnu-asm-aarch64-neondot-cortex-a55.S", - "xnnpack_wrappers/qu8-gemm/gen/qu8-gemm-4x8c4-minmax-rndnu-asm-aarch64-neondot-ld128.S", "xnnpack_wrappers/qu8-gemm/gen/qu8-gemm-4x16-minmax-rndnu-asm-aarch64-neon-mlal-lane-cortex-a53-prfm.S", "xnnpack_wrappers/qu8-gemm/gen/qu8-gemm-4x16-minmax-rndnu-asm-aarch64-neon-mlal-lane-cortex-a53.S", "xnnpack_wrappers/qu8-gemm/gen/qu8-gemm-4x16-minmax-rndnu-asm-aarch64-neon-mlal-lane-cortex-a75-prfm.S", "xnnpack_wrappers/qu8-gemm/gen/qu8-gemm-4x16-minmax-rndnu-asm-aarch64-neon-mlal-lane-cortex-a75.S", "xnnpack_wrappers/qu8-gemm/gen/qu8-gemm-4x16-minmax-rndnu-asm-aarch64-neon-mlal-lane-ld64-prfm.S", "xnnpack_wrappers/qu8-gemm/gen/qu8-gemm-4x16-minmax-rndnu-asm-aarch64-neon-mlal-lane-ld64.S", - "xnnpack_wrappers/qu8-gemm/gen/qu8-gemm-4x16c4-minmax-fp32-asm-aarch64-neondot-cortex-a55.S", - "xnnpack_wrappers/qu8-gemm/gen/qu8-gemm-4x16c4-minmax-fp32-asm-aarch64-neondot-ld128.S", - "xnnpack_wrappers/qu8-gemm/gen/qu8-gemm-4x16c4-minmax-rndnu-asm-aarch64-neondot-cortex-a55.S", - "xnnpack_wrappers/qu8-gemm/gen/qu8-gemm-4x16c4-minmax-rndnu-asm-aarch64-neondot-ld128.S", - "xnnpack_wrappers/qu8-igemm/gen/qu8-igemm-4x8c4-minmax-rndnu-asm-aarch64-neondot-cortex-a55.S", - "xnnpack_wrappers/qu8-igemm/gen/qu8-igemm-4x8c4-minmax-rndnu-asm-aarch64-neondot-ld128.S", "xnnpack_wrappers/qu8-igemm/gen/qu8-igemm-4x16-minmax-rndnu-asm-aarch64-neon-mlal-lane-cortex-a53-prfm.S", "xnnpack_wrappers/qu8-igemm/gen/qu8-igemm-4x16-minmax-rndnu-asm-aarch64-neon-mlal-lane-cortex-a53.S", "xnnpack_wrappers/qu8-igemm/gen/qu8-igemm-4x16-minmax-rndnu-asm-aarch64-neon-mlal-lane-cortex-a75-prfm.S", "xnnpack_wrappers/qu8-igemm/gen/qu8-igemm-4x16-minmax-rndnu-asm-aarch64-neon-mlal-lane-cortex-a75.S", "xnnpack_wrappers/qu8-igemm/gen/qu8-igemm-4x16-minmax-rndnu-asm-aarch64-neon-mlal-lane-ld64-prfm.S", "xnnpack_wrappers/qu8-igemm/gen/qu8-igemm-4x16-minmax-rndnu-asm-aarch64-neon-mlal-lane-ld64.S", - "xnnpack_wrappers/qu8-igemm/gen/qu8-igemm-4x16c4-minmax-fp32-asm-aarch64-neondot-cortex-a55.S", - "xnnpack_wrappers/qu8-igemm/gen/qu8-igemm-4x16c4-minmax-fp32-asm-aarch64-neondot-ld128.S", - "xnnpack_wrappers/qu8-igemm/gen/qu8-igemm-4x16c4-minmax-rndnu-asm-aarch64-neondot-cortex-a55.S", - "xnnpack_wrappers/qu8-igemm/gen/qu8-igemm-4x16c4-minmax-rndnu-asm-aarch64-neondot-ld128.S", ] From 8881fb2b4c69816dd385e85fa14524025034021a Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Wed, 28 Aug 2024 19:44:55 +0000 Subject: [PATCH 0165/1018] Revert "[3/N] Set correct device to CUDA guards (#134357)" This reverts commit 13114da4ef9d14978ea1dfc0fefb236cb4000435. Reverted https://github.com/pytorch/pytorch/pull/134357 on behalf of https://github.com/facebook-github-bot due to Diff reverted internally ([comment](https://github.com/pytorch/pytorch/pull/134357#issuecomment-2316121423)) --- test/distributed/test_c10d_nccl.py | 19 ------------------- .../distributed/c10d/ProcessGroupNCCL.cpp | 18 ++++++++---------- 2 files changed, 8 insertions(+), 29 deletions(-) diff --git a/test/distributed/test_c10d_nccl.py b/test/distributed/test_c10d_nccl.py index 3a575729cdff78..55b785a3552b5a 100644 --- a/test/distributed/test_c10d_nccl.py +++ b/test/distributed/test_c10d_nccl.py @@ -350,7 +350,6 @@ def test_close_pg(self): @parametrize("type", [torch.float16, torch.float32, torch.float64, torch.bfloat16]) @skip_if_rocm def test_nan_assert(self, type): - # Expecting a device-side error when NaN is detected os.environ["TORCH_NCCL_NAN_CHECK"] = "1" store = c10d.FileStore(self.file_name, self.world_size) pg = self._create_process_group_nccl(store, self.opts()) @@ -389,24 +388,6 @@ def test_nan_p2p(self): # reset env os.environ["TORCH_NCCL_NAN_CHECK"] = "0" - @requires_nccl() - @skip_if_lt_x_gpu(2) - def test_nan_check(self): - # Not expecting an error, NaN check should not make legit code fail - os.environ["TORCH_NCCL_NAN_CHECK"] = "1" - store = c10d.FileStore(self.file_name, self.world_size) - device = torch.device("cuda:%d" % self.rank) - c10d.init_process_group( - backend="nccl", store=store, rank=self.rank, world_size=self.world_size - ) - x = torch.ones((10,), dtype=torch.bfloat16, device=device) * self.rank - t = torch.ones(3, 4, dtype=torch.bfloat16, device=device) - c10d.broadcast(x, src=0) - c10d.all_reduce(t) - c10d.destroy_process_group() - # reset env - os.environ["TORCH_NCCL_NAN_CHECK"] = "0" - @requires_nccl() @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs") def test_destruct_before_terminate_pg(self): diff --git a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp index 8f741185a2c584..291eb7993ef2fb 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp +++ b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp @@ -1148,10 +1148,6 @@ void ProcessGroupNCCL::abortCommsFromMap( at::cuda::OptionalCUDAGuard gpuGuard; at::DeviceIndex deviceIndex = getIndexFromDeviceKey(devName); if (deviceIndex >= 0) { - // For P2P comms, the deviceIndex could be -1 (invalid), as the keys in - // the map could be non deviceIndex, but rank to rank numbers. So we - // indeed need to check if deviceIndex >= 0 - // TODO: fix `getIndexFromDeviceKey` or fix `DeviceKey` gpuGuard.set_index(deviceIndex); } LOG(INFO) << logPrefix() << "ProcessGroupNCCL destroying ncclComm_ " @@ -2145,9 +2141,7 @@ std::shared_ptr ProcessGroupNCCL::getNCCLComm( bool batchP2P = ncclActiveGroupCounter_ > 0; bool singleP2POp = isP2POp(opType, batchP2P); - // Get the device index - auto deviceIndex = device.index(); - at::cuda::OptionalCUDAGuard gpuGuard(device); + at::cuda::OptionalCUDAGuard gpuGuard; // [Group Start/End Note] This is used to ensure that nccl communicator will // be created before communication primitives are called. Let's look at this @@ -2187,6 +2181,10 @@ std::shared_ptr ProcessGroupNCCL::getNCCLComm( rank = p2pRank; } + // Get the device index + auto deviceIndex = device.index(); + gpuGuard.set_index(deviceIndex); + #ifdef NCCL_HAS_COMM_SPLIT if (options_->split_from) { TORCH_CHECK( @@ -2696,7 +2694,7 @@ c10::intrusive_ptr ProcessGroupNCCL::collective( work->stashed_for_allocator_safety_->push_back(input); } - at::cuda::OptionalCUDAGuard gpuGuard(device); + at::cuda::OptionalCUDAGuard gpuGuard; if (nanCheck) { checkForNan(input, ncclStream); @@ -2861,7 +2859,7 @@ c10::intrusive_ptr ProcessGroupNCCL::collectiveCoalesced( std::make_shared>(inputs); } - at::cuda::OptionalCUDAGuard gpuGuard(device); + at::cuda::OptionalCUDAGuard gpuGuard; // Start event should only be recorded before the ncclGroupStart() (which // happens inside AutoNcclGroup guard below) @@ -3129,7 +3127,7 @@ c10::intrusive_ptr ProcessGroupNCCL::pointToPoint( } // is gpuGuard needed for the if block below, or can i swap them - at::cuda::OptionalCUDAGuard gpuGuard(device); + at::cuda::OptionalCUDAGuard gpuGuard; // Only check for NaN for send ops, for recv ops `tensor` can be a random // placeholder From 1d956e2711000a50418dc3cc8885196b3540be09 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Wed, 28 Aug 2024 19:51:59 +0000 Subject: [PATCH 0166/1018] Revert "[2/N] Add flag to control which rank should perform NaN check (#134345)" This reverts commit be7752ead3824e79f5ede6a2f59715b415a2f776. Reverted https://github.com/pytorch/pytorch/pull/134345 on behalf of https://github.com/facebook-github-bot due to Diff reverted internally ([comment](https://github.com/pytorch/pytorch/pull/134345#issuecomment-2316133024)) --- test/distributed/test_c10d_nccl.py | 22 ----------- .../distributed/c10d/ProcessGroupNCCL.cpp | 39 ++++++------------- .../distributed/c10d/ProcessGroupNCCL.hpp | 6 +-- 3 files changed, 13 insertions(+), 54 deletions(-) diff --git a/test/distributed/test_c10d_nccl.py b/test/distributed/test_c10d_nccl.py index 55b785a3552b5a..64c51f791fddd0 100644 --- a/test/distributed/test_c10d_nccl.py +++ b/test/distributed/test_c10d_nccl.py @@ -366,28 +366,6 @@ def test_nan_assert(self, type): # reset env os.environ["TORCH_NCCL_NAN_CHECK"] = "0" - @requires_nccl() - @skip_if_lt_x_gpu(2) - def test_nan_p2p(self): - # Putting NaN at recv buffer, program should not fail as NaN checker - # should not check on receive buffer - os.environ["TORCH_NCCL_NAN_CHECK"] = "1" - store = c10d.FileStore(self.file_name, self.world_size) - device = torch.device("cuda:%d" % self.rank) - c10d.init_process_group( - backend="nccl", store=store, rank=self.rank, world_size=self.world_size - ) - t = torch.ones(3, 4, dtype=torch.bfloat16, device=device) - if self.rank == 0: - c10d.send(t, 1) - elif self.rank == 1: - # Putting NaN at recv buffer - t[1, 1] = float("nan") - c10d.recv(t, 0) - c10d.destroy_process_group() - # reset env - os.environ["TORCH_NCCL_NAN_CHECK"] = "0" - @requires_nccl() @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs") def test_destruct_before_terminate_pg(self): diff --git a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp index 291eb7993ef2fb..abcf493e62fb8e 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp +++ b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp @@ -2637,12 +2637,9 @@ c10::intrusive_ptr ProcessGroupNCCL::collective( PostProcess post, OpType opType, const char* profilingTitle, - bool avoidRecordStreams, - bool nanCheck) { + bool avoidRecordStreams) { // Environment setting by the user may add onto collective call's option avoidRecordStreams |= avoidRecordStreams_; - nanCheck &= enableNanCheck_; - c10::cuda::CaptureStatus capture_status = c10::cuda::currentStreamCaptureStatusMayInitCtx(); errorIfCapturingNonCapturableNCCL(capture_status); @@ -2696,7 +2693,7 @@ c10::intrusive_ptr ProcessGroupNCCL::collective( at::cuda::OptionalCUDAGuard gpuGuard; - if (nanCheck) { + if (enableNanCheck_) { checkForNan(input, ncclStream); } @@ -3129,9 +3126,7 @@ c10::intrusive_ptr ProcessGroupNCCL::pointToPoint( // is gpuGuard needed for the if block below, or can i swap them at::cuda::OptionalCUDAGuard gpuGuard; - // Only check for NaN for send ops, for recv ops `tensor` can be a random - // placeholder - if (enableNanCheck_ && opType == OpType::SEND) { + if (enableNanCheck_) { checkForNan(tensor, ncclStream); } @@ -3228,8 +3223,7 @@ c10::intrusive_ptr ProcessGroupNCCL::collective( Fn fn, OpType opType, const char* profilingTitle, - bool avoidRecordStreams, - bool nanCheck) { + bool avoidRecordStreams) { return collective( input, output, @@ -3240,8 +3234,7 @@ c10::intrusive_ptr ProcessGroupNCCL::collective( c10::intrusive_ptr& work) {}, opType, profilingTitle, - avoidRecordStreams, - nanCheck); + avoidRecordStreams); } template @@ -3491,9 +3484,6 @@ c10::intrusive_ptr ProcessGroupNCCL::broadcast( // avoidRecordStreams_ note: collective() will stash tensors. bool avoidRecordStreams = avoidRecordStreams_ || (!opts.asyncOp); - const auto root = opts.rootRank + opts.rootTensor; - bool nanCheck = (root == rank_); - return collective( tensor, tensor, @@ -3501,6 +3491,7 @@ c10::intrusive_ptr ProcessGroupNCCL::broadcast( at::Tensor& output, ncclComm_t comm, at::cuda::CUDAStream& stream) { + const auto root = opts.rootRank + opts.rootTensor; return ncclBcast( input.data_ptr(), input.numel(), @@ -3511,8 +3502,7 @@ c10::intrusive_ptr ProcessGroupNCCL::broadcast( }, OpType::BROADCAST, "nccl:broadcast", - avoidRecordStreams, - nanCheck); + avoidRecordStreams); } // _broadcast_oop adds an out-of-place broadcast in PGNCCL @@ -3532,9 +3522,6 @@ c10::intrusive_ptr ProcessGroupNCCL::_broadcast_oop( "Tensor input and output of _broadcast_oop must have the same number of elements "); } - const auto root = opts.rootRank + opts.rootTensor; - bool nanCheck = (root == rank_); - return collective( inputTensor, outputTensor, @@ -3542,6 +3529,7 @@ c10::intrusive_ptr ProcessGroupNCCL::_broadcast_oop( at::Tensor& output, ncclComm_t comm, at::cuda::CUDAStream& stream) { + const auto root = opts.rootRank + opts.rootTensor; return ncclBroadcast( input.data_ptr(), output.data_ptr(), @@ -3552,9 +3540,7 @@ c10::intrusive_ptr ProcessGroupNCCL::_broadcast_oop( stream.stream()); }, OpType::BROADCAST, - "nccl:_broadcast_oop", - /*avoidRecordStreams=*/false, - nanCheck); + "nccl:_broadcast_oop"); } c10::intrusive_ptr ProcessGroupNCCL::reduce( @@ -4505,9 +4491,6 @@ c10::intrusive_ptr ProcessGroupNCCL::scatter( // inputs, which == inputTensors[0] on the root rank where it matters. bool avoidRecordStreams = avoidRecordStreams_ || (!opts.asyncOp); - const auto root = opts.rootRank; - bool nanCheck = (rank_ == root); - return collective( outputTensor, inputs[0], // just to fit the collective interface @@ -4515,6 +4498,7 @@ c10::intrusive_ptr ProcessGroupNCCL::scatter( at::Tensor& /* unused */, ncclComm_t comm, at::cuda::CUDAStream& stream) { + const auto root = opts.rootRank; if (getRank() == root) { if (!avoidRecordStreams) { for (auto input : inputs) { @@ -4528,8 +4512,7 @@ c10::intrusive_ptr ProcessGroupNCCL::scatter( }, OpType::SCATTER, "nccl:scatter", - avoidRecordStreams, - nanCheck); + avoidRecordStreams); } c10::intrusive_ptr ProcessGroupNCCL::recvAnysource( diff --git a/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp b/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp index b6635157d75e1b..2ba68cda9cc3ea 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp +++ b/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp @@ -762,8 +762,7 @@ class TORCH_API ProcessGroupNCCL : public Backend { Fn fn, OpType opType, const char* profilingTitle = nullptr, - bool avoidRecordStreams = false, - bool nanCheck = true); + bool avoidRecordStreams = false); template c10::intrusive_ptr collective( @@ -774,8 +773,7 @@ class TORCH_API ProcessGroupNCCL : public Backend { PostProcess post, OpType opType, const char* profilingTitle = nullptr, - bool avoidRecordStreams = false, - bool nanCheck = true); + bool avoidRecordStreams = false); template c10::intrusive_ptr collectiveCoalesced( From 4303cd8c50b753be3dced6d6af0bfddfb1970219 Mon Sep 17 00:00:00 2001 From: Yidi Wu Date: Tue, 27 Aug 2024 16:24:42 -0700 Subject: [PATCH 0167/1018] [hop] require hops to override __call__. (#134352) Fixes https://github.com/pytorch/pytorch/issues/133719 by making `__call__` of hops an abstractmethod. Pull Request resolved: https://github.com/pytorch/pytorch/pull/134352 Approved by: https://github.com/zou3519 --- test/dynamo/test_higher_order_ops.py | 3 +++ test/functorch/test_control_flow.py | 7 +++++++ test/functorch/test_eager_transforms.py | 3 +++ torch/_dynamo/_trace_wrapped_higher_order_op.py | 3 +++ torch/_export/wrappers.py | 3 +++ .../executorch_call_delegate.py | 3 +++ torch/_higher_order_ops/map.py | 8 +++++++- torch/_higher_order_ops/run_const_graph.py | 3 +++ torch/_higher_order_ops/strict_mode.py | 3 +++ torch/_higher_order_ops/torchbind.py | 3 +++ torch/_higher_order_ops/triton_kernel_wrap.py | 17 +++++++++++++++++ torch/_ops.py | 7 +++---- torch/_prims/rng_prims.py | 6 ++++++ 13 files changed, 64 insertions(+), 5 deletions(-) diff --git a/test/dynamo/test_higher_order_ops.py b/test/dynamo/test_higher_order_ops.py index 4c6cf74421be87..bb109448d28482 100644 --- a/test/dynamo/test_higher_order_ops.py +++ b/test/dynamo/test_higher_order_ops.py @@ -6441,6 +6441,9 @@ class _FallthroughTestOnly(torch._ops.HigherOrderOperator): def __init__(self): super().__init__("_fallthrough_test_only") + def __call__(self, *args, **kwargs): + return super().__call__(*args, **kwargs) + test_op = _FallthroughTestOnly() default_keys = torch._ops._HIGHER_ORDER_OP_DEFAULT_FALLTHROUGH_DISPATCH_KEYS self.assertTrue( diff --git a/test/functorch/test_control_flow.py b/test/functorch/test_control_flow.py index 2798391c4c7f49..e26c8a6f8220e2 100644 --- a/test/functorch/test_control_flow.py +++ b/test/functorch/test_control_flow.py @@ -3710,6 +3710,13 @@ def cond_fn(x): torch.compile(torch.cond, backend=cnt)(pred, fn1, fn2, (torch.randn(4, 4),)) self.assertEqual(cnt.frame_count, 3) + def test_hop_raises_if_not_overriding_call(self): + class WrongHop(torch._ops.HigherOrderOperator): + pass + + with self.assertRaisesRegex(TypeError, "WrongHop"): + wrong_hop = WrongHop("wrong_hop") + _hop_schema_test_schema_types = [ "bool", diff --git a/test/functorch/test_eager_transforms.py b/test/functorch/test_eager_transforms.py index ce2f781c08b889..ba046be3d99324 100644 --- a/test/functorch/test_eager_transforms.py +++ b/test/functorch/test_eager_transforms.py @@ -4993,6 +4993,9 @@ class MySum(HigherOrderOperator): def __init__(self): super().__init__("mysum") + def __call__(self, *args, **kwargs): + return super().__call__(*args, **kwargs) + mysum = MySum() @mysum.py_impl(torch._C._functorch.TransformType.Vmap) diff --git a/torch/_dynamo/_trace_wrapped_higher_order_op.py b/torch/_dynamo/_trace_wrapped_higher_order_op.py index 38f763f07898a1..c698ded100943a 100644 --- a/torch/_dynamo/_trace_wrapped_higher_order_op.py +++ b/torch/_dynamo/_trace_wrapped_higher_order_op.py @@ -52,6 +52,9 @@ class TraceWrapped(HigherOrderOperator): def __init__(self): super().__init__("trace_wrapped") + def __call__(self, *args, **kwargs): + return super().__call__(*args, **kwargs) + # TODO(jansel): need to ensure this does not get DCEed _trace_wrapped_op = TraceWrapped() diff --git a/torch/_export/wrappers.py b/torch/_export/wrappers.py index 5aee2f16a0ddcf..d57ff46de41c8f 100644 --- a/torch/_export/wrappers.py +++ b/torch/_export/wrappers.py @@ -16,6 +16,9 @@ class ExportTracepoint(HigherOrderOperator): def __init__(self): super().__init__("_export_tracepoint") + def __call__(self, *args, **kwargs): + return super().__call__(*args, **kwargs) + _export_tracepoint = ExportTracepoint() diff --git a/torch/_higher_order_ops/executorch_call_delegate.py b/torch/_higher_order_ops/executorch_call_delegate.py index 4ef87f36080462..a6ee5205ff4ebb 100644 --- a/torch/_higher_order_ops/executorch_call_delegate.py +++ b/torch/_higher_order_ops/executorch_call_delegate.py @@ -29,6 +29,9 @@ class ExecutorchCallDelegate(HigherOrderOperator): def __init__(self): super().__init__("executorch_call_delegate") + def __call__(self, lowered_module, *args): + return super().__call__(lowered_module, *args) + executorch_call_delegate = ExecutorchCallDelegate() executorch_call_delegate.fallthrough(torch._C.DispatchKey.PythonDispatcher) diff --git a/torch/_higher_order_ops/map.py b/torch/_higher_order_ops/map.py index b91a7a3701d848..a8dc9d94e45f0e 100644 --- a/torch/_higher_order_ops/map.py +++ b/torch/_higher_order_ops/map.py @@ -32,6 +32,9 @@ # TODO: We add this to prevent dymamo from tracing into map_wrapper, # remove the wrapper call when it's ready. class MapWrapper(HigherOrderOperator): + def __init__(self): + super().__init__("map") + def __call__(self, xs, *args): return map_wrapper(xs, *args) @@ -40,8 +43,11 @@ class MapImpl(HigherOrderOperator): def __init__(self): super().__init__("map_impl") + def __call__(self, *args, **kwargs): + return super().__call__(*args, **kwargs) + -map = MapWrapper("map") +map = MapWrapper() map_impl = MapImpl() diff --git a/torch/_higher_order_ops/run_const_graph.py b/torch/_higher_order_ops/run_const_graph.py index 774f3ebc0e2a6f..1f49ee28394a1c 100644 --- a/torch/_higher_order_ops/run_const_graph.py +++ b/torch/_higher_order_ops/run_const_graph.py @@ -12,6 +12,9 @@ class RunConstGraph(HigherOrderOperator): def __init__(self): super().__init__("run_const_graph") + def __call__(self, *args): + return super().__call__(*args) + run_const_graph = RunConstGraph() diff --git a/torch/_higher_order_ops/strict_mode.py b/torch/_higher_order_ops/strict_mode.py index c58eea5750681d..7324e20dcd4cd7 100644 --- a/torch/_higher_order_ops/strict_mode.py +++ b/torch/_higher_order_ops/strict_mode.py @@ -32,6 +32,9 @@ class StrictMode(HigherOrderOperator): def __init__(self): super().__init__("strict_mode") + def __call__(self, callable, operands): + return super().__call__(callable, operands) + strict_mode_op = StrictMode() diff --git a/torch/_higher_order_ops/torchbind.py b/torch/_higher_order_ops/torchbind.py index ef2f228bc4b3fe..b35b8d5b296d11 100644 --- a/torch/_higher_order_ops/torchbind.py +++ b/torch/_higher_order_ops/torchbind.py @@ -26,6 +26,9 @@ class CallTorchBind(HigherOrderOperator): def __init__(self): super().__init__("call_torchbind") + def __call__(self, obj, method, *args, **kwargs): + return super().__call__(obj, method, *args, **kwargs) + call_torchbind = CallTorchBind() diff --git a/torch/_higher_order_ops/triton_kernel_wrap.py b/torch/_higher_order_ops/triton_kernel_wrap.py index 65de706a529142..874dcdcb50fa28 100644 --- a/torch/_higher_order_ops/triton_kernel_wrap.py +++ b/torch/_higher_order_ops/triton_kernel_wrap.py @@ -522,6 +522,14 @@ class TritonKernelWrapperMutation(HigherOrderOperator): def __init__(self) -> None: super().__init__("triton_kernel_wrapper_mutation") + def __call__(self, kernel_idx, constant_args_idx, grid, kwargs): + return super().__call__( + kernel_idx=kernel_idx, + constant_args_idx=constant_args_idx, + grid=grid, + kwargs=kwargs, + ) + triton_kernel_wrapper_mutation = TritonKernelWrapperMutation() @@ -531,6 +539,15 @@ class TritonKernelWrapperFunctional(HigherOrderOperator): def __init__(self) -> None: super().__init__("triton_kernel_wrapper_functional") + def __call__(self, kernel_idx, constant_args_idx, grid, kwargs, tensors_to_clone): + return super().__call__( + kernel_idx=kernel_idx, + constant_args_idx=constant_args_idx, + grid=grid, + kwargs=kwargs, + tensors_to_clone=tensors_to_clone, + ) + triton_kernel_wrapper_functional = TritonKernelWrapperFunctional() diff --git a/torch/_ops.py b/torch/_ops.py index 08d0ffab5e7d70..03ed9ca9260998 100644 --- a/torch/_ops.py +++ b/torch/_ops.py @@ -1,4 +1,5 @@ # mypy: allow-untyped-defs +import abc import contextlib import ctypes import importlib @@ -238,7 +239,7 @@ def resolve_key(op: OperatorBase, k: DispatchKey): # type: ignore[valid-type] ] -class HigherOrderOperator(OperatorBase): +class HigherOrderOperator(OperatorBase, abc.ABC): # The HigherOrderOperator will appear as torch.ops.higher_order.{name} # # If you're creating a new HigherOrderOperator, please do not change the @@ -410,6 +411,7 @@ def check_overloaded(arg): assert not isinstance(kernel, DispatchKey) return kernel(*args, **kwargs) + @abc.abstractmethod def __call__(self, /, *args, **kwargs): # Dynamo already traces the body of HigherOrderOp beforehand when it # so no need to trace into it. @@ -433,9 +435,6 @@ def wrapper(): def __str__(self): return f"{self.name()}" - # def __repr__(self): - # return f"torch.ops._higher_order_ops.{self._name}" - def name(self): return self._name diff --git a/torch/_prims/rng_prims.py b/torch/_prims/rng_prims.py index a10106cacc9845..bbbdb8958f9adb 100644 --- a/torch/_prims/rng_prims.py +++ b/torch/_prims/rng_prims.py @@ -156,6 +156,9 @@ class RunAndSaveRngState(HigherOrderOperator): def __init__(self): super().__init__("run_and_save_rng_state") + def __call__(self, op, *args, **kwargs): + return super().__call__(op, *args, **kwargs) + run_and_save_rng_state = RunAndSaveRngState() run_and_save_rng_state.py_impl(DispatchKey.Autograd)( @@ -217,6 +220,9 @@ class RunWithRngState(HigherOrderOperator): def __init__(self): super().__init__("run_with_rng_state") + def __call__(self, rng_state, op, *args, **kwargs): + return super().__call__(rng_state, op, *args, **kwargs) + run_with_rng_state = RunWithRngState() run_with_rng_state.py_impl(DispatchKey.Autograd)( From 13786525aee65b646d7596692d976ad25b7620b8 Mon Sep 17 00:00:00 2001 From: Manuel Candales Date: Wed, 28 Aug 2024 19:58:37 +0000 Subject: [PATCH 0168/1018] [ET][CodeGen] Remove TORCH_API from NativeFunctions.h declarations (#134245) Summary: Remove TORCH_API from the generated executorch/kernels/portable/NativeFunctions.h declarations These generated declarations are using ET tensors. They don't need to have the TORCH_API macro prefixed to them, since in this case TORCH_API is just empty. See [codegen/macros.h](https://www.internalfb.com/code/fbsource/[d12d7d3accfb12932368e0216124f2d735c51d73]/fbcode/executorch/codegen/macros.h) Test Plan: CI Differential Revision: D61490943 Pull Request resolved: https://github.com/pytorch/pytorch/pull/134245 Approved by: https://github.com/larryliu0820 --- torchgen/gen_executorch.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/torchgen/gen_executorch.py b/torchgen/gen_executorch.py index 0e8d79cf679c2d..353302c7cd4a16 100644 --- a/torchgen/gen_executorch.py +++ b/torchgen/gen_executorch.py @@ -337,12 +337,11 @@ def compute_native_function_declaration( metadata_list = kernel_index.get_kernels(g).values() if metadata_list is None: return [] - prefix = "TORCH_API" # for kernels in lean mode, we declare two versions, one with context and one without. # In the end we will cleanup the unused one. def gen_decl(metadata: BackendMetadata, include_context: bool) -> str: - return f"{prefix} {sig.decl(name=metadata.kernel, include_context=include_context)};" + return f"{sig.decl(name=metadata.kernel, include_context=include_context)};" return [ gen_decl(metadata, include_context) @@ -499,11 +498,11 @@ def gen_headers( headers = { "headers": [ "#include // at::Tensor etc.", - "#include // TORCH_API", "#include ", ], } if use_aten_lib: + headers["headers"].append("#include // TORCH_API") cpu_fm.write( "NativeFunctions.h", lambda: dict( From 128bc99ca96d49bf0c8f8ecb5b8b7763ebb86d1b Mon Sep 17 00:00:00 2001 From: Sanket Purandare Date: Wed, 28 Aug 2024 20:06:52 +0000 Subject: [PATCH 0169/1018] Runtime Estimator for estimating GPU compute time (#134243) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This PR adds a basic Runtime Estimator for single-device models. It estimates the GPU runtime in milliseconds using various estimation methods under the ``FakeTensorMode``. It provides a ``TorchDispatchMode`` based context manager that can estimate the eager runtime of PyTorch functions. It supports two estimation modes, benchmarking (`operator-level-benchmark`) and roofline cost modeling (`operator-level-cost-model`). For modules executed under this context manager, it agggregates the forward and backward operation runtimes and records their execution orders. ``` import torch from torch import nn, optim from torch._subclasses.fake_tensor import FakeTensorMode from torch.distributed._tools.runtime_estimator import RuntimeEstimator from torch.testing._internal.distributed._tensor.common_dtensor import ( ModelArgs, Transformer, ) if __name__ == "__main__": def _train_step( model: nn.Module, optimizer: optim.Optimizer, inp: torch.Tensor, ): out = model(inp) loss = out.sum() loss.backward() optimizer.step() optimizer.zero_grad() dev = torch.cuda.current_device() vocab_size = 8192 bsz, seq_len = 32, 1024 model_args = ModelArgs( n_layers=4, n_heads=12, vocab_size=vocab_size, max_seq_len=seq_len, dim=768, dropout_p=0.1, ) runtime_estimator = RuntimeEstimator() with FakeTensorMode(): with torch.device(dev): model = Transformer(model_args) optimizer = optim.Adam(model.parameters(), lr=1e-2, foreach=True) inp = torch.randint(0, model_args.vocab_size, (bsz, model_args.max_seq_len), device=dev) with runtime_estimator("operator-level-benchmark"): _train_step(model, optimizer, inp) with runtime_estimator("operator-level-cost-model"): _train_step(model, optimizer, inp) # Actual model runtime with torch.device(dev): model = Transformer(model_args) optimizer = optim.Adam(model.parameters(), lr=1e-2, foreach=True) inp = torch.randint(0, model_args.vocab_size, (bsz, model_args.max_seq_len), device=dev) warmup_iters, actual_iters = 2, 5 start_event = torch.cuda.Event(enable_timing=True) end_event = torch.cuda.Event(enable_timing=True) for _ in range(warmup_iters): _train_step(model, optimizer, inp) start_event.record() for _ in range(actual_iters): _train_step(model, optimizer, inp) end_event.record() torch.cuda.synchronize() measured_time = start_event.elapsed_time(end_event) / actual_iters print(f"Actual total_time: {measured_time:.3f} ms") ``` Screenshot 2024-08-26 at 11 27 15 PM @weifengpy @xuanzhang816 @gnadathur Pull Request resolved: https://github.com/pytorch/pytorch/pull/134243 Approved by: https://github.com/weifengpy --- .../_tools/test_runtime_estimator.py | 197 +++++++ torch/distributed/_tools/__init__.py | 1 + torch/distributed/_tools/runtime_estimator.py | 527 ++++++++++++++++++ 3 files changed, 725 insertions(+) create mode 100644 test/distributed/_tools/test_runtime_estimator.py create mode 100644 torch/distributed/_tools/runtime_estimator.py diff --git a/test/distributed/_tools/test_runtime_estimator.py b/test/distributed/_tools/test_runtime_estimator.py new file mode 100644 index 00000000000000..400903f17673f5 --- /dev/null +++ b/test/distributed/_tools/test_runtime_estimator.py @@ -0,0 +1,197 @@ +# Owner(s): ["module: unknown"] +import unittest +from dataclasses import dataclass +from typing import Any, Callable, cast, Tuple, Union + +import torch +from torch import nn, optim +from torch._subclasses.fake_tensor import FakeTensorMode +from torch.distributed._tools.runtime_estimator import RuntimeEstimator +from torch.testing._internal.common_cuda import TEST_CUDA +from torch.testing._internal.common_utils import run_tests, skipIfTorchDynamo, TestCase +from torch.testing._internal.distributed._tensor.common_dtensor import ( + ModelArgs, + Transformer, +) + + +@dataclass +class ConvArgs: + image_size: int + num_classes: int + + +class SimpleCNN(nn.Module): + def __init__(self, conv_args: ConvArgs): + super().__init__() + image_size = conv_args.image_size + num_classes = conv_args.num_classes + self.image_size = image_size + self.conv1 = nn.Conv2d(3, 32, kernel_size=5) + self.pool = nn.MaxPool2d(2, 2) + self.conv2 = nn.Conv2d(32, 64, kernel_size=5) + self.conv3 = nn.Conv2d(64, 128, kernel_size=3) + self.conv4 = nn.Conv2d(128, 256, kernel_size=3) + self.fc1_size = self._calculate_fc1_size() + self.fc1 = nn.Linear(self.fc1_size, 512) + self.fc2 = nn.Linear(512, 256) + self.fc3 = nn.Linear(256, num_classes) + + def _calculate_fc1_size(self): + size = self.image_size + size = (size - 5 + 1) // 2 # conv1 and pool + size = (size - 5 + 1) // 2 # conv2 and pool + size = size - 3 + 1 # conv3 + size = (size - 3 + 1) // 2 # conv4 and pool + return 512 * size * size + + def forward(self, x): + x = self.pool(nn.functional.relu(self.conv1(x))) + x = self.pool(nn.functional.relu(self.conv2(x))) + x = nn.functional.relu(self.conv3(x)) + x = self.pool(nn.functional.relu(self.conv4(x))) + x = x.view(-1, self.fc1_size) + x = nn.functional.relu(self.fc1(x)) + x = nn.functional.relu(self.fc2(x)) + x = self.fc3(x) + return x + + +class TestRuntimeEstimator(TestCase): + def _train_step( + self, + model: nn.Module, + optimizer: optim.Optimizer, + inp: torch.Tensor, + ): + out = model(inp) + loss = out.sum() + loss.backward() + optimizer.step() + optimizer.zero_grad() + + def _measure_actual_cuda_time( + self, + func: Callable, + args: Tuple[Any, ...], + ) -> float: + warmup_iters, actual_iters = 2, 5 + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + for _ in range(warmup_iters): + func(*args) + start_event.record() + for _ in range(actual_iters): + func(*args) + end_event.record() + torch.cuda.synchronize() + measured_time = start_event.elapsed_time(end_event) / actual_iters + return measured_time + + def _runtime_estimate( + self, + estimate_mode: str, + func: Callable, + args: Tuple[Any, ...], + ) -> float: + # Optimizer init step + func(*args) + runtime_estimator = RuntimeEstimator() + with runtime_estimator(estimate_mode_type=estimate_mode): + func(*args) + return runtime_estimator.total_runtime + + def _init_model_and_args( + self, + model_type: str, + model_args: Union[ConvArgs, ModelArgs], + bsz: int, + ) -> Tuple[nn.Module, optim.Optimizer, torch.Tensor]: + dev = torch.cuda.current_device() + if model_type == "Transformer": + model_args = cast(ModelArgs, model_args) + with torch.device(dev): + model = Transformer(model_args) + optimizer = optim.Adam(model.parameters(), lr=1e-2, foreach=True) + inp = torch.randint( + 0, model_args.vocab_size, (bsz, model_args.max_seq_len), device=dev + ) + elif model_type == "CNN": + model_args = cast(ConvArgs, model_args) + with torch.device(dev): + model = SimpleCNN(model_args) + optimizer = optim.SGD(model.parameters(), lr=1e-2, foreach=True) + inp = torch.randn( + bsz, 3, model_args.image_size, model_args.image_size, device=dev + ) + else: + raise NotImplementedError("Only Transformer and CNN is supported") + return (model, optimizer, inp) + + @skipIfTorchDynamo("https://github.com/pytorch/pytorch/issues/115653") + @unittest.skipIf(not TEST_CUDA, "CUDA not available") + def test_transformer_runtime( + self, + ): + """Runs a basic GPT-2 model""" + vocab_size = 8192 + bsz, seq_len = 8, 1024 + model_args = ModelArgs( + n_layers=4, + n_heads=12, + vocab_size=vocab_size, + max_seq_len=seq_len, + dim=768, + dropout_p=0.1, + ) + + args = self._init_model_and_args("Transformer", model_args, bsz) + actual_runtime = self._measure_actual_cuda_time(self._train_step, args) + with FakeTensorMode(): + fake_args = self._init_model_and_args("Transformer", model_args, bsz) + benchmark_estimate = self._runtime_estimate( + "operator-level-benchmark", self._train_step, fake_args + ) + roofline_estimate = self._runtime_estimate( + "operator-level-cost-model", self._train_step, fake_args + ) + benchmark_accuracy = actual_runtime / benchmark_estimate + roofline_accuracy = actual_runtime / roofline_estimate + print( + f"Actual: {actual_runtime} Benchmark Estimate: {benchmark_estimate} Accuracy: {benchmark_accuracy}" + f"\n Actual: {actual_runtime} Roofline Estimatee: {roofline_estimate} Accuracy: {roofline_accuracy}" + ) + self.assertAlmostEqual(benchmark_accuracy, 1.0, delta=0.2) + self.assertAlmostEqual(roofline_accuracy, 1.0, delta=0.3) + + @skipIfTorchDynamo("https://github.com/pytorch/pytorch/issues/115653") + @unittest.skipIf(not TEST_CUDA, "CUDA not available") + def test_conv_model_runtime( + self, + ): + """Runs a simple CNN model""" + num_classes = 100 + bsz, img_sz = 256, 128 + model_args = ConvArgs(img_sz, num_classes) + args = self._init_model_and_args("CNN", model_args, bsz) + actual_runtime = self._measure_actual_cuda_time(self._train_step, args) + with FakeTensorMode(): + fake_args = self._init_model_and_args("CNN", model_args, bsz) + benchmark_estimate = self._runtime_estimate( + "operator-level-benchmark", self._train_step, fake_args + ) + roofline_estimate = self._runtime_estimate( + "operator-level-cost-model", self._train_step, fake_args + ) + benchmark_accuracy = actual_runtime / benchmark_estimate + roofline_accuracy = actual_runtime / roofline_estimate + print( + f"Actual: {actual_runtime} Benchmark Estimate: {benchmark_estimate} Accuracy: {benchmark_accuracy}\n" + f"Actual: {actual_runtime} Roofline Estimatee: {roofline_estimate} Accuracy: {roofline_accuracy}" + ) + self.assertAlmostEqual(benchmark_accuracy, 1.0, delta=0.2) + self.assertAlmostEqual(roofline_accuracy, 1.0, delta=0.4) + + +if __name__ == "__main__": + run_tests() diff --git a/torch/distributed/_tools/__init__.py b/torch/distributed/_tools/__init__.py index 8acd9cce2f9033..cd57eedba37517 100644 --- a/torch/distributed/_tools/__init__.py +++ b/torch/distributed/_tools/__init__.py @@ -2,3 +2,4 @@ from .mem_tracker import MemTracker from .memory_tracker import MemoryTracker from .mod_tracker import ModTracker +from .runtime_estimator import RuntimeEstimator diff --git a/torch/distributed/_tools/runtime_estimator.py b/torch/distributed/_tools/runtime_estimator.py new file mode 100644 index 00000000000000..87f4d3f36b60e0 --- /dev/null +++ b/torch/distributed/_tools/runtime_estimator.py @@ -0,0 +1,527 @@ +# Owner(s): ["module: unknown"] +import math +import os +from collections import defaultdict +from typing import Any, Callable, Dict, List, Set, Tuple +from typing_extensions import Self + +import torch +import torch.utils._pytree as pytree +from torch._guards import active_fake_mode +from torch._inductor.utils import get_device_tflops, get_gpu_dram_gbps +from torch._subclasses.fake_tensor import FakeTensorMode +from torch.distributed._tools.mod_tracker import ModTracker +from torch.utils._mode_utils import no_dispatch +from torch.utils._python_dispatch import TorchDispatchMode +from torch.utils.flop_counter import flop_registry + + +aten = torch.ops.aten + +# This value is hard-coded here: +# https://github.com/pytorch/pytorch/blob/5fba5d83f0703ff8077ab65448a998e9ad6598fd/c10/cuda/CUDACachingAllocator.cpp#L117 +_PYTORCH_MIN_ALLOCATE = ( + 2**9 if int(os.environ.get("PYTORCH_NO_CUDA_MEMORY_CACHING", 0)) == 0 else 1 +) + +# No fall-back kernel needed/exists for view ops +_VIEW_OPS = { + aten.lift_fresh, + aten.t, + aten.transpose, + aten.view, + aten.detach, + aten._unsafe_view, + aten.split, + aten.adjoint, + aten.as_strided, + aten.diagonal, + aten.expand, + aten.expand_as, + aten.movedim, + aten.permute, + aten.select, + aten.squeeze, + aten.mT, + aten.mH, + aten.real, + aten.imag, + aten.view_as, + aten.unflatten, + aten.unfold, + aten.unbind, + aten.unsqueeze, + aten.vsplit, + aten.hsplit, + aten.split_with_sizes, + aten.swapaxes, + aten.swapdims, + aten.chunk, +} +# We can ignore benchmarking tensor create ops +_CREATE_OPS = { + aten.randint, + aten.randn, + aten.rand, + aten.randn_like, + aten.rand_like, + aten.randint_like, + aten.arange, + aten.ones_like, + aten.zeros_like, +} + +_IGNORE_OPS = _VIEW_OPS | _CREATE_OPS + +__all__ = ["RuntimeEstimator"] + + +class RuntimeEstimator(TorchDispatchMode): + """ + Estimates the GPU runtime in milliseconds using various estimation methods under the ``FakeTensorMode``. + + This class provides a ``TorchDispatchMode`` based context manager that can be used to estimate the eager + runtime of PyTorch functions. It supports two estimation modes, benchmarking (`operator-level-benchmark`) and + roofline cost modeling (`operator-level-cost-model`). + For modules executed under this context manager, it agggregates the forward and backward operation runtimes + and also records their execution orders. + + Attributes: + mod_runtimes (Dict[str, Dict[str, float]]): A dictionary of module runtimes. The key to the outer dictionary + is the fully qualified name (FQN) of the module. For each module the forward and backward runtimes of the + operations are aggregated in the inner dictionary keyed by 'fw' and 'bw'. + mod_fw_pre_order (List[str]): List of module FQNs in pre-forward execution order. + mod_bw_pre_order (List[str]): List of module FQNs in pre-backward execution order. + mod_fw_post_order (List[str]): List of module FQNs in post-forward execution order. + mod_bw_post_order (List[str]): List of module FQNs in post-backward execution order. + total_runtime (float): The total estimated runtime in milliseconds. + + Note: + 1) The benchmarking estimate mode will execute kernels on GPU and assumes that every operation can run in + isolation without causing an OOM error. It is also designed to be used only under ``FakeTensorMode``. + 2) Currently wrapper tensor sub-classes such as ``DTensor`` won't produce correct estimates. We plan to support + them in future PRs. + 3) We only estimate the compute time, if your code has communication, it will not be considered. Again, we will + support this in future PRs. + + Example usage: + + .. code-block:: python + + runtime_estimator = RuntimeEstimator() + with FakeTensorMode(): + module = ... + optimizer = ... + inp = ... + with runtime_estimator(estimate_mode_type="operator-level-cost-model"): + loss = module(inp) + loss.backward() + optimizer.step() + optimizer.zero_grad() + runtime_estimator.display_modulewise_stats() + """ + + _float_types: Set[torch.dtype] = { + torch.float16, + torch.bfloat16, + torch.float32, + torch.float64, + } + _no_fallback_kernel: Set[torch._ops._OpNamespace] = set() + fake_mode: FakeTensorMode + + def __init__(self) -> None: + super().__init__() + self._estimate: Callable + self._estimate_mode_type: str + self._mod_tracker = ModTracker() + self.mod_runtimes: Dict[str, Dict[str, float]] = defaultdict( + lambda: defaultdict(lambda: 0.0) + ) + self.mod_fw_pre_order: List[str] = [] + self.mod_bw_pre_order: List[str] = [] + self.mod_fw_post_order: List[str] = [] + self.mod_bw_post_order: List[str] = [] + self.total_runtime: float = 0.0 + + # Adapted from: https://github.com/pytorch/pytorch/blob/9b902b3ee3bd608a19543362b66bf06c373dd374/torch/_subclasses/fake_tensor.py#L1969 # noqa: PGH004,B950 + # NB: returns fake tensors + @classmethod + def _maybe_run_and_benchmark_fallback_kernel( # type: ignore[no-untyped-def] + cls, + func, + args, + kwargs, + orig_not_implemented_exception, + ): + """ + Runs and benchmarks a fallback kernel for a given function. + + Args: + func (Callable): The function to benchmark. + args (Tuple): The arguments to pass to the function. + kwargs (Dict[str, Any]): The keyword arguments to pass to the function. + orig_not_implemented_exception (Exception): The original exception to raise if the fallback kernel + is not implemented. + + Returns: + Tuple[Any, float]: A tuple containing the result of the function and + the mean operation time in milliseconds. + """ + # these should all be supported, just to be safe + # avoid fallback for operators which inplace modify metadata + # because the input fake tensors would be umodified + if torch.Tag.inplace_view in func.tags: # type: ignore[attr-defined] + raise orig_not_implemented_exception + + inp_impls = {} + flat_args, args_spec = pytree.tree_flatten((args, kwargs)) + # Don't use in_kernel_invocation_manager(fake_mode) as we want to do + # REAL compute (not with meta device) + with no_dispatch(): + + def to_real_tensor(e): # type: ignore[no-untyped-def] + if cls.fake_mode.is_our_fake(e): + if e.dtype in cls._float_types: + out = torch.rand_like(e, device=e.fake_device) + else: + out = torch.ones_like(e, device=e.fake_device) + if e.is_sparse: + out._coalesced_(e.is_coalesced()) + inp_impls[id(out)] = e + return out + return e + + flat_args = [to_real_tensor(a) for a in flat_args] + args, kwargs = pytree.tree_unflatten(flat_args, args_spec) + r = func(*args, **kwargs) + warmup_iters, actual_iters = 2, 3 + for _ in range(warmup_iters): + func(*args, **kwargs) + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + start_event.record(torch.cuda.current_stream()) + for _ in range(actual_iters): + func(*args, **kwargs) + end_event.record(torch.cuda.current_stream()) + torch.cuda.synchronize() + cuda_time = start_event.elapsed_time(end_event) + mean_op_time = cuda_time / actual_iters + + storages = set() + + for e in flat_args: + if isinstance(e, torch.Tensor): + if not e.is_sparse: + storages.add(e._typed_storage()._cdata) + + # TODO: also check metadata change on inputs + # proper aliasing/metadata relationship between outputs and inputs will + # not be set up, bc of conversion to device, unless we can reuse an + # input impl + + def map_out(e): # type: ignore[no-untyped-def] + if id(e) not in inp_impls and ( + isinstance(e, torch.Tensor) + and not e.is_sparse + and e._typed_storage()._cdata in storages + ): + raise orig_not_implemented_exception + + if isinstance(e, torch.Tensor): + if id(e) in inp_impls: + return inp_impls[id(e)] + else: + return cls.fake_mode.fake_tensor_converter.from_real_tensor( + cls.fake_mode, e + ) + else: + return e + + return (pytree.tree_map(map_out, r), mean_op_time) + + @classmethod + def _benchmark_estimate(cls, func, args, kwargs) -> Tuple[Any, float]: # type: ignore[no-untyped-def] + """ + Estimates the runtime of a function using benchmarking. + + Args: + func: The function to estimate. + args: The arguments to pass to the function. + kwargs: The keyword arguments to pass to the function. + res: The result of the function. + + Returns: + Tuple[Any, float]: A tuple containing the result of the function and + the mean operation time in milliseconds. + """ + assert isinstance( + cls.fake_mode, FakeTensorMode + ), "Initialize/Assign FakeTensorMode before using this function" + mean_op_time = 0.0 + if func._overloadpacket not in _VIEW_OPS: + try: + res, mean_op_time = cls._maybe_run_and_benchmark_fallback_kernel( + func, + args, + kwargs, + NotImplementedError, + ) + return (res, mean_op_time) + except NotImplementedError: + cls._no_fallback_kernel.add(func._overloadpacket) + res = func(*args, **kwargs or {}) + return (res, mean_op_time) + + # Adapted from: https://github.com/pytorch/pytorch/blob/9b902b3ee3bd608a19543362b66bf06c373dd374/torch/_inductor/scheduler.py#L589 # noqa: PGH004,B950 + @classmethod + def _roofline_estimate(cls, func, args, kwargs) -> Tuple[Any, float]: # type: ignore[no-untyped-def] + """ + Estimates the runtime of a function using a roofline cost model. + + Args: + func: The function to estimate. + args: The arguments to pass to the function. + kwargs: The keyword arguments to pass to the function. + out: The output of the function. + + Returns: + Tuple[Any, float]: A tuple containing the result of the function and + the mean operation time in milliseconds. + """ + assert ( + torch.cuda.is_available() + ), "Roofline estimation needs to access CUDA capabilities to make estimations" + + def get_num_bytes(t: torch.Tensor) -> int: + """ + Calculates the memory consumption of a tensor. + + Args: + t (torch.Tensor): The input tensor. + + Returns: + int: The memory consumption of the tensor in bytes. + """ + num_bytes = t.untyped_storage().nbytes() + mem_consumed = ( + math.ceil(num_bytes / _PYTORCH_MIN_ALLOCATE) * _PYTORCH_MIN_ALLOCATE + ) + return mem_consumed + + def get_compute_time(func_packet, args, kwargs, out, out_dtypes) -> float: # type: ignore[no-untyped-def] + """ + Estimates the compute time of an aten operator. + + Args: + func_packet: The operator overload packet. + args: The arguments to the operator. + kwargs: The keyword arguments to the operator. + out: The output of the operator. + out_dtypes: The output data types. + + Returns: + float: The estimated compute time in nanoseconds. + """ + if func_packet in flop_registry: + assert ( + len(out_dtypes) == 1 + ), f"Only support single out dtype got {out_dtypes} for {func_packet}" + dtype = out_dtypes.pop() + # This actually gives peta-FLOPs/s hence multiply by 1e15 to get the FLOPs/s + peak_gpu_flops = get_device_tflops(dtype) * 1e15 + # We can expect to achieve 75% of theoretical peak flops + factor = 0.75 + peak_empirical_flops = factor * peak_gpu_flops + flop_count_func = flop_registry[func_packet] + # We divide by a factor of 2 to get the MACs (multiply and accumulate) + flop_count = flop_count_func(*args, **kwargs, out_val=out) / 2 + # We multiply by 1e9 to get the time in nano seconds + compute_time = (flop_count / peak_empirical_flops) * 1e9 + return compute_time + return 0.0 + + def get_transfer_time(flat_args_kwargs, flat_outs) -> float: # type: ignore[no-untyped-def] + """ + Estimates the memory transfer time of input and output tensors. + + Args: + flat_args_kwargs (List[torch.Tensor]): The flat list of arguments and keyword arguments. + flat_outs (List[torch.Tensor]): The flat list of outputs. + + Returns: + float: The estimated memory transfer time in nanoseconds. + """ + gpu_memory_bandwidth = get_gpu_dram_gbps() + read_bytes = sum( + get_num_bytes(t) + for t in flat_args_kwargs + if isinstance(t, torch.Tensor) + ) + write_bytes = sum( + get_num_bytes(t) for t in flat_outs if isinstance(t, torch.Tensor) + ) + counted_bytes = read_bytes + write_bytes + # The GPU memory bandwidth is in GB/s so the transfer time is in nanoseconds + transfer_time = counted_bytes / gpu_memory_bandwidth + return transfer_time + + # Roofline Cost Model Explanation + + # The roofline cost model estimates the execution time of an operator based on + # the device's empirical maximum FLOPs/sec (pi) and device DRAM bandwidth (beta). + + # Variables: + # - pi: Maximum empirical FLOPs/sec of the device + # - beta: Maximum empirical device DRAM bandwidth (bytes/sec) of the device + # - I: Arithmetic intensity of the operator (FLOPs/bytes) + # - op_flops: FLOPs required by the operator + # - op_bytes: Bytes transferred to and from DRAM for the operator + + # Calculation Steps: + # 1. Calculate arithmetic intensity: I = op_flops / op_bytes + # 2. Calculate estimated FLOPs/sec: est_flops_sec = min(pi, beta * I) + # 3. Calculate estimated operator time: estimated_op_time = op_flops / est_flops_sec + # This simplifies to: estimated_op_time = max(op_flops / pi, op_flops / (beta * I)) + # Further simplifying: estimated_op_time = max(op_flops / pi, op_bytes / beta) + + # Simplified Formulas: + # - compute_time = op_flops / pi + # - transfer_time = op_bytes / beta + # - estimated_op_time = max(compute_time, transfer_time) + + kwargs = kwargs if kwargs else {} + out = func(*args, **kwargs) + op_time = 0.0 + func_packet = func._overloadpacket + if func_packet not in _IGNORE_OPS: + flat_args_kwargs, args_spec = pytree.tree_flatten((args, kwargs)) + flat_outs, out_spec = pytree.tree_flatten(out) + transfer_time = get_transfer_time(flat_args_kwargs, flat_outs) + + out_dtypes = { + t.dtype + for t in flat_outs + if isinstance(t, torch.Tensor) and t.dtype in cls._float_types + } + + args, kwargs = pytree.tree_unflatten(flat_args_kwargs, args_spec) + out = pytree.tree_unflatten(flat_outs, out_spec) + + compute_time = get_compute_time(func_packet, args, kwargs, out, out_dtypes) + # We get the estimated time as the max of the transfer time and + # compute time. We divide by 1e6 to get the time in ms + op_time = max(transfer_time, compute_time) / 1e6 + + return (out, op_time) + + def display_modulewise_stats(self, depth: int = 2) -> None: + """ + Displays module-wise statistics collected by ``RuntimeEstimator``. + + Prints the pre-forward and pre-backward execution orders. + Displays the module-wise forward and backward runtimes in milliseconds. + + Args: + depth (int): The maximum depth of module hierarchy to display (default to 2). + """ + print("Pre-Forward Execution Order: ") + for mod_fqn in self.mod_fw_pre_order: + mod_depth = mod_fqn.count(".") + 1 + if mod_depth > depth: + continue + print(mod_fqn) + print("Pre-Backward Execution Order: ") + for mod_fqn in self.mod_bw_pre_order: + mod_depth = mod_fqn.count(".") + 1 + if mod_depth > depth: + continue + print(mod_fqn) + for mod_fqn, runtimes in self.mod_runtimes.items(): + mod_depth = mod_fqn.count(".") + 1 + if mod_depth > depth: + continue + print( + f"{mod_fqn} fw: {runtimes.get('fw', 0.0):.3f}ms bw: {runtimes.get('bw', 0.0):.3f}ms" + ) + + def __torch_dispatch__(self, func, types, args=..., kwargs=None): # type: ignore[no-untyped-def] + # TODO: @sanketpurandare: Flatten tensors by desugaring the tensor subclasses + # TODO: @sanketpurandare: Add logic for incorporating communication time + res, op_time = self._estimate(func, args, kwargs) + for par in self._mod_tracker.parents: + if self._mod_tracker.is_bw: + self.mod_runtimes[par]["bw"] += op_time + else: + self.mod_runtimes[par]["fw"] += op_time + self.total_runtime += op_time + return res + + def __call__(self, estimate_mode_type: str) -> Self: + """ + Sets the estimate mode type. + + Currently supported modes: + - "operator-level-benchmark": Estimates runtime using operator benchmarking. + - "operator-level-cost-model": Estimates runtime using roofline cost model. + + Args: + estimate_mode_type (str): The type of estimate mode to use. + + Returns: + RuntimeEstimator: The runtime estimator instance. + + Raises: + NotImplementedError: If the estimate mode type is not supported. + """ + if estimate_mode_type == "operator-level-benchmark": + self._estimate = RuntimeEstimator._benchmark_estimate + elif estimate_mode_type == "operator-level-cost-model": + self._estimate = RuntimeEstimator._roofline_estimate + else: + raise NotImplementedError( + f"estimate_mode_type {estimate_mode_type} not supported" + ) + self._estimate_mode_type = estimate_mode_type + return self + + def __enter__(self) -> Self: + fake_mode = active_fake_mode() + assert isinstance( + fake_mode, FakeTensorMode + ), "No FakeTensorMode found, designed to used under FakeTensorMode" + RuntimeEstimator.fake_mode = fake_mode + self.total_runtime = 0.0 + self.mod_runtimes = defaultdict(lambda: defaultdict(lambda: 0.0)) + self.mod_fw_pre_order.clear() + self.mod_bw_pre_order.clear() + self.mod_fw_post_order.clear() + self.mod_bw_post_order.clear() + self._mod_tracker.register_user_hooks( + pre_fw_hook=lambda mod, inp: self.mod_fw_pre_order.append( + self._mod_tracker.get_known_fqn(mod) + ), + pre_bw_hook=lambda mod, g_out: self.mod_bw_pre_order.append( + self._mod_tracker.get_known_fqn(mod) + ), + post_fw_hook=lambda mod, inp, out: self.mod_fw_post_order.append( + self._mod_tracker.get_known_fqn(mod) + ), + post_bw_hook=lambda mod, g_inp: self.mod_bw_post_order.append( + self._mod_tracker.get_known_fqn(mod) + ), + ) + self._mod_tracker.__enter__() + super().__enter__() + return self + + def __exit__(self, *args: Any) -> None: + print( + f"Estimated ({self._estimate_mode_type})" + f"total_time: {self.total_runtime:.3f} ms" + ) + if len(self._no_fallback_kernel) > 0: + print("no_fallback_kernel: ", list(self._no_fallback_kernel)) + super().__exit__(*args) + self._mod_tracker.clear_user_hooks() + self._mod_tracker.__exit__() From ac22e8bc7ca3828732a9d57583077890ef381334 Mon Sep 17 00:00:00 2001 From: fduwjj Date: Wed, 28 Aug 2024 11:26:39 -0700 Subject: [PATCH 0170/1018] [c10d FR analyzer] Output a meaningful debug report for users (#134528) - This PR generates a more useful output log for users: P1552399180. - It also fixes the logic when we check the all-gather size mismatch. - Add dtype check for collective input/output - We store more context information for error match_state so that we can report them in the file. - Disable the size match for alltoall because we don't log the size for all inputs/outputs. - Correct some types for func args specification. Pull Request resolved: https://github.com/pytorch/pytorch/pull/134528 Approved by: https://github.com/c-p-i-o --- .../flight_recorder/test_fr_analysis.py | 16 +++ tools/flight_recorder/components/builder.py | 55 +++++++--- tools/flight_recorder/components/types.py | 102 ++++++++++-------- tools/flight_recorder/components/utils.py | 65 +++++++++-- torch/csrc/distributed/c10d/NCCLUtils.cpp | 1 + 5 files changed, 172 insertions(+), 67 deletions(-) diff --git a/test/distributed/flight_recorder/test_fr_analysis.py b/test/distributed/flight_recorder/test_fr_analysis.py index 2e389de7694ab9..43a296a3966c61 100644 --- a/test/distributed/flight_recorder/test_fr_analysis.py +++ b/test/distributed/flight_recorder/test_fr_analysis.py @@ -25,6 +25,7 @@ def create_one_event( state="scheduled", collective_seq_id=0, p2p_seq_id=0, + output_dtypes="float32", ): return { "profiling_name": f"nccl:{collectcive_name}", @@ -32,6 +33,8 @@ def create_one_event( "process_group": pg_info, "input_sizes": input_sizes, "output_sizes": output_sizes, + "input_dtypes": "float32", + "output_dtypes": output_dtypes, "collective_seq_id": str(collective_seq_id), "p2p_seq_id": str(p2p_seq_id), } @@ -88,6 +91,19 @@ def test_match_one_event(self): match_one_event(e1, e9, membership), MatchState.COLLECTIVE_STATE_MISMATCH ) + e10 = create_one_event( + "all_reduce", + ("0", "default"), + [[4, 4]], + [[4, 4]], + "completed", + 1, + output_dtypes="float16", + ) + self.assertEqual( + match_one_event(e10, e9, membership), MatchState.COLLECTIVE_DTYPE_MISMATCH + ) + if __name__ == "__main__": run_tests() diff --git a/tools/flight_recorder/components/builder.py b/tools/flight_recorder/components/builder.py index b3cb9f792469a3..b2f7b13101800d 100644 --- a/tools/flight_recorder/components/builder.py +++ b/tools/flight_recorder/components/builder.py @@ -22,12 +22,13 @@ from tools.flight_recorder.components.utils import ( check_no_missing_dump_files, check_size_alltoall, - check_trace_from_beginning, check_version, find_coalesced_group, + format_frames, just_print_entries, match_coalesced_groups, match_one_event, + sort_trace_from_beginning, ) @@ -122,7 +123,7 @@ def build_nccl_call( def build_collectives( - all_entries: Dict[str, List[Dict[str, Any]]], + all_entries: Dict[int, List[Dict[str, Any]]], _groups: Dict[str, Group], _memberships: Dict[str, Set[Any]], ) -> Tuple[List[Traceback], List[Collective], List[NCCLCall]]: @@ -186,6 +187,11 @@ def build_collectives( pg_name, desc = entries[0]["process_group"] profiling_name = entries[0]["profiling_name"] collective_seq_id = entries[0]["collective_seq_id"] + record_id = entries[0]["record_id"] + input_sizes = entries[0]["input_sizes"] + output_sizes = entries[0]["output_sizes"] + collective_state = entries[0]["state"] + collective_frames = format_frames(entries[0]["frames"]) expected_ranks = set(_memberships[pg_name]) candidate_ranks = {first_rank} candidate_idx = {} @@ -265,15 +271,25 @@ def build_collectives( else: candidate_ranks.add(o) candidate_idx[o] = i - errors.add(match_state) + if match_state not in [ + MatchState.FULLY_MATCHED, + MatchState.UNDECIDED, + ]: + # Here we assume the current rank is not the source of the error. + # But it's possible that the current rank is the culprit, then users will + # see lots of normal ranks reported as culprit. + # TODO: we need to figure out a better way to handle the case mentioned above. + errors.add((o, match_state)) break # case one: not every rank join the collective or in the flight recorder. if (candidate_ranks | found_ranks) != expected_ranks: mismatch[pg_name] += 1 print( - f"Not all ranks joining collective for group {pg_name}:{desc} collective {profiling_name}", - f"Missing ranks are {expected_ranks - (candidate_ranks | found_ranks)}", + f"Not all ranks joining collective for group {pg_name}:{desc} collective {profiling_name} ", + f"Missing ranks are {expected_ranks - (candidate_ranks | found_ranks)} ", + f"{input_sizes} {output_sizes} {len(expected_ranks)} {collective_state} ", + f"\nCollective stack traces: \n{collective_frames}", ) elif len(candidate_ranks) == 1: # case two: alltoall or alltoall_base case. @@ -281,14 +297,20 @@ def build_collectives( alltoall_cases = [entries[0]] + [ all_entries[o][found_idx[o]] for o in found_ranks ] - check_result, input_numel, output_numel = check_size_alltoall( + fail_check, input_numel, output_numel = check_size_alltoall( alltoall_cases ) - if not check_result: + # We don't log the input/output sizes for alltoall so we don't consider the size mismatch as an error for now. + fail_check = False + if fail_check: + # When we see errors in all_to_all, it's hard to tell which rank is the source of the error. mismatch[pg_name] += 1 print( - f"Input/output mismatch in the collective for group {pg_name}:{desc} collective {profiling_name}", - f"input_numel {input_numel} output_numel{output_numel}", + f"Input/output mismatch in the collective {record_id} ", + f"for group {pg_name}:{desc} collective {profiling_name} ", + f"input_numel {input_numel} output_numel {output_numel} ", + f"{input_sizes} {output_sizes} {len(expected_ranks)} {collective_state} ", + f"\nCollective stack traces: \n{collective_frames}", ) candidate_ranks.update(found_ranks) candidate_idx.update(found_idx) @@ -306,11 +328,16 @@ def build_collectives( candidate_idx.clear() candidate_ranks.clear() # case four: mismatch cases due to not same type, size mismatch or state mismatch. - else: - error_msg = ", ".join(error.name for error in errors) + elif len(errors) > 0: + mismatch[pg_name] += 1 + error_msg = ", ".join( + f"Error rank {error[0]}, {str(error[1])}" for error in errors + ) print( - f"Collective errors for group {pg_name}:{desc} collective {profiling_name}", - f"Found errors: {error_msg}", + f"Collective {record_id} errors for group {pg_name}:{desc} collective {profiling_name} ", + f"{input_sizes} {output_sizes} {len(expected_ranks)} {collective_state} ", + f"\nFound errors: {error_msg}\n", + f"\nCollective stack traces: \n{collective_frames} ", ) candidate_ranks.update(found_ranks) candidate_idx.update(found_idx) @@ -373,7 +400,7 @@ def build_db(details: Dict[str, Dict[str, Any]], args: argparse.Namespace) -> Da pg_config[rank] = dump["pg_config"] check_version(version) - check_trace_from_beginning(entries) + entries = sort_trace_from_beginning(entries) # flattened database groups, _groups, memberships, _memberships = build_groups_memberships(pg_config) diff --git a/tools/flight_recorder/components/types.py b/tools/flight_recorder/components/types.py index 4a33f3580e4c16..2d02954577d330 100644 --- a/tools/flight_recorder/components/types.py +++ b/tools/flight_recorder/components/types.py @@ -4,7 +4,8 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -from enum import Enum +import math +from enum import auto, Enum from typing import ( # type: ignore[attr-defined] _eval_type, Any, @@ -12,6 +13,7 @@ Generic, List, NamedTuple, + Optional, Set, Tuple, Type, @@ -156,34 +158,25 @@ class MatchState(Enum): - SIZE_OR_SYNTAX_MISMATCH: There is a mismatch in input/output sizes or violation of collective syntax. - COLLECTIVE_STATE_MISMATCH: The states of the collective not same, such as one finished while another just started or scheduled. + - COLLECTIVE_DTYPE_MISMATCH: The data types of the collective input/output differ. - UNDECIDED: - The match status is ambiguous or cannot be determined, e.g., we might need to check all ranks for all_to_all. + The match status is ambiguous or cannot be determined, e.g., we might need to check all ranks for alltoall_base. """ - FULLY_MATCHED = 1 - COLLECTIVE_TYPE_MISMATCH = 2 - SIZE_OR_SYNTAX_MISMATCH = 3 - COLLECTIVE_STATE_MISMATCH = 4 - UNDECIDED = 5 - - -def check_size_evenly_broadcasting( - list1: List[Any], list2: List[Any], size: int -) -> bool: - if len(list1) != len(list2): - return False - ratio = None - for a, b in zip(list1, list2): - current_ratio = int(a) / int(b) - if current_ratio == 1: - continue - if current_ratio != size: - return False - elif ratio is None: - ratio = current_ratio - else: - return False - return True + FULLY_MATCHED = auto() + COLLECTIVE_TYPE_MISMATCH = auto() + SIZE_OR_SYNTAX_MISMATCH = auto() + COLLECTIVE_STATE_MISMATCH = auto() + COLLECTIVE_DTYPE_MISMATCH = auto() + UNDECIDED = auto() + + def __call__(self, culprit: Optional[str] = None) -> "MatchState": + # Make the enum instance callable to add culprit. + self.culprit = culprit + return self + + def __str__(self) -> str: + return f"Error type: {self.name}, Detail finding {self.culprit if self.culprit else ''}" class Op: @@ -203,9 +196,7 @@ def __init__(self, event: Dict[Any, Any], memberships: Dict[str, Set[Any]]): type = parts[0] meta = parts[1] if len(parts) == 2 else None self.state = event["state"] - self.pg_name, _ = event["process_group"] - assert type in COLLECTIVES | P2P | { "coalesced" }, f"{type} is not a supported operation" @@ -221,7 +212,6 @@ def __init__(self, event: Dict[Any, Any], memberships: Dict[str, Set[Any]]): pg_name, pg_desc = event["process_group"] self._init_global_src_dst(memberships[pg_name]) self.pg_size = len(memberships[pg_name]) - if type in P2P | COLLECTIVES: self.input_sizes = event["input_sizes"] self.output_sizes = event["output_sizes"] @@ -229,6 +219,8 @@ def __init__(self, event: Dict[Any, Any], memberships: Dict[str, Set[Any]]): self.input_sizes, self.output_sizes = None, None self.collective_seq_id = event["collective_seq_id"] self.p2p_seq_id = event["p2p_seq_id"] + self.input_dtypes = event["input_dtypes"] + self.output_dtypes = event["output_dtypes"] def _init_global_src_dst(self, pg_ranks: Set[Any]) -> None: pg_ranks = sorted(pg_ranks) @@ -283,34 +275,60 @@ def match(self, other: "Op") -> MatchState: ) elif self.type in COLLECTIVES: if self.type != other.type: - return MatchState.COLLECTIVE_TYPE_MISMATCH + return MatchState.COLLECTIVE_TYPE_MISMATCH( + f"Type '{self.type}' and '{other.type}' do not match" + ) + if self.state != other.state: + # MatchState() + return MatchState.COLLECTIVE_STATE_MISMATCH( + f"States '{self.state}' '{other.state}' do not match" + ) + if ( + other.input_dtypes != other.output_dtypes + or self.input_dtypes != other.input_dtypes + or self.output_dtypes != other.output_dtypes + ): + return MatchState.COLLECTIVE_DTYPE_MISMATCH( + f"Dtypes '{self.input_dtypes}/{other.input_dtypes}' '{self.output_dtypes}/{other.output_dtypes}' do not match" + ) if self.type == "all_to_all": return MatchState.UNDECIDED if self.type != "scatter" and self.input_sizes != other.input_sizes: - return MatchState.SIZE_OR_SYNTAX_MISMATCH + return MatchState.SIZE_OR_SYNTAX_MISMATCH( + f"Input sizes '{self.input_sizes}' '{other.input_sizes}' do not match" + ) if self.type != "gather" and self.output_sizes != other.output_sizes: - return MatchState.SIZE_OR_SYNTAX_MISMATCH + return MatchState.SIZE_OR_SYNTAX_MISMATCH( + f"Output sizes '{self.output_sizes}' '{other.output_sizes}' do not match" + ) if self.type == "all_reduce" and self.input_sizes != other.output_sizes: - return MatchState.SIZE_OR_SYNTAX_MISMATCH + return MatchState.SIZE_OR_SYNTAX_MISMATCH( + f"Input sizes '{self.input_sizes}' do not match output sizes '{other.output_sizes}'" + ) # TODO: need to consider uneven sharding for all-gather. # TODO: need to consider all_gather_into_tensor_coalesced (coalesced related) if self.type in [ "all_gather", "all_gather_base", - ] and not check_size_evenly_broadcasting( - other.output_sizes[0], self.input_sizes[0], self.pg_size + ] and not ( + math.prod(other.output_sizes[0]) + == math.prod(self.input_sizes[0]) * self.pg_size ): - return MatchState.SIZE_OR_SYNTAX_MISMATCH + return MatchState.SIZE_OR_SYNTAX_MISMATCH( + f"Input numel '{math.prod(other.input_sizes[0])} * pg size {self.pg_size}' " + f"do not match output numel '{math.prod(other.output_sizes[0])}'", + ) if self.type in [ "reduce_scatter", "_reduce_scatter_base", - ] and not check_size_evenly_broadcasting( - other.input_sizes[0], self.output_sizes[0], self.pg_size + ] and not ( + math.prod(other.input_sizes[0]) + == math.prod(self.output_sizes[0]) * self.pg_size ): - return MatchState.SIZE_OR_SYNTAX_MISMATCH - # TODO: need to add more checks for gather and scatter. - if self.state != other.state: - return MatchState.COLLECTIVE_STATE_MISMATCH + return MatchState.SIZE_OR_SYNTAX_MISMATCH( + f"Input numel '{math.prod(other.input_sizes[0])}' do not match output numel " + f"'{math.prod(other.output_sizes[0])} * pg size {self.pg_size}'", + ) elif self.type == "coalesced": return ( MatchState.FULLY_MATCHED diff --git a/tools/flight_recorder/components/utils.py b/tools/flight_recorder/components/utils.py index 2f9d382261b913..efb1799260f055 100644 --- a/tools/flight_recorder/components/utils.py +++ b/tools/flight_recorder/components/utils.py @@ -22,6 +22,20 @@ print("tabulate is not installed. Proceeding without it.") +def format_frame(frame: Dict[str, str]) -> str: + name = frame["name"] + filename = frame["filename"] + line = frame["line"] + return f"{name} at {filename}:{line}" + + +def format_frames(frames: List[Dict[str, str]]) -> str: + formatted_frames = [] + for frame in frames: + formatted_frames.append(format_frame(frame)) + return "\n".join(formatted_frames) + + def match_one_event( event_a: Dict[Any, Any], event_b: Dict[Any, Any], @@ -198,7 +212,7 @@ def just_print_entries( def check_no_missing_dump_files( - entries: Dict[str, Any], memberships: List[Membership] + entries: Dict[int, Any], memberships: List[Membership] ) -> None: all_ranks = set() for membership in memberships: @@ -216,16 +230,45 @@ def check_version(versions: Dict[str, Any]) -> None: # assert minor >= 0, f"Rank {rank} unsupported version {version}" -# TODO: We need to revisit this function to see if we still need it. -def check_trace_from_beginning(entries: Dict[str, Any]) -> bool: +def sort_trace_from_beginning( + entries: Dict[int, List[Dict[str, Any]]] +) -> Dict[int, List[Dict[str, Any]]]: + """ + Sorts the trace entries by record ID for entries. + This function takes a dictionary of rank names to lists of trace entries as input. + Each trace entry is a dictionary containing information about a collective operation, + including its unique identifier (`record_id` is monotonically increasing as we write into the ring buffer). + The function first sorts the entries in each rank by their `record_id` values. + Then, it finds the largest starting point across all ranks by taking the maximum + `record_id` value of the first entry in each rank. Finally, it filters out any + entries with `record_id` values less than the maximum starting point. + The function returns the updated dictionary of sorted and filtered trace entries. + + Args: + entries (Dict[str, List[Dict[str, Any]]]): A dictionary of rank names to lists of trace entries. + + Returns: + entries (Dict[str, List[Dict[str, Any]]]): Entries sorted by record ID and filtered by the maximum starting point. + """ + + maximum_starting_record_id = 0 for rank in entries: + # Although this is a ring buffer, we already sort the entries by `record_id` when dumping, we just + # need to find the largest starting point. For example, if the buffer has the following entries: + # Rank 0: [0, 1, 2, 3, 4, 5, 6] + # Rank 1: [1, 2, 3, 4, 5, 6, 7] + # Rank 2: [2, 3, 4, 5, 6, 7, 8] + # Rank 3: [0, 1, 2, 3, 4, 5, None] + # Then we should start from collective 2 not 0 because any collective before, + # we don't have complete records from all ranks so we need to ignore them. first_record_id = entries[rank][0]["record_id"] - # TODO add more sequence information such that analysis can proceed even without complete buffer + maximum_starting_record_id = max(maximum_starting_record_id, first_record_id) - # assert first_record_id == 0, f"Rank {rank} trace does not start at time 0 (first record is {first_record_id}." - if first_record_id != 0: - print( - f"Rank {rank} trace does not start at time 0 (first record is {first_record_id}." - ) - return False - return True + for rank in entries: + entries[rank] = [ + entry + for entry in entries[rank] + if entry["record_id"] >= maximum_starting_record_id + ] + + return entries diff --git a/torch/csrc/distributed/c10d/NCCLUtils.cpp b/torch/csrc/distributed/c10d/NCCLUtils.cpp index 46a2426fd5f4c0..648c87881c6110 100644 --- a/torch/csrc/distributed/c10d/NCCLUtils.cpp +++ b/torch/csrc/distributed/c10d/NCCLUtils.cpp @@ -534,6 +534,7 @@ const c10::List NCCLTraceBuffer::getCollectiveTrace( bool includeStacktraces, bool onlyActive) { auto entries = new_list(); + // Entries are returned in the order they were recorded auto result = dump_entries(); std::vector tracebacks; torch::SymbolizedTracebacks stracebacks; From 79347d0c3bf08d4e31565d4e5e33574b1e8447f5 Mon Sep 17 00:00:00 2001 From: Nowtryz <44437613+nowtryz@users.noreply.github.com> Date: Wed, 28 Aug 2024 21:28:21 +0000 Subject: [PATCH 0171/1018] Add MaskedTensor support to *_like API (#128637) Pull Request resolved: https://github.com/pytorch/pytorch/pull/128637 Approved by: https://github.com/cpuhrsch --- docs/source/masked.rst | 14 +++++++ test/test_maskedtensor.py | 53 ++++++++++++++++++++++++++ torch/masked/maskedtensor/__init__.py | 1 + torch/masked/maskedtensor/_ops_refs.py | 13 +++---- torch/masked/maskedtensor/like.py | 38 ++++++++++++++++++ 5 files changed, 112 insertions(+), 7 deletions(-) create mode 100644 torch/masked/maskedtensor/like.py diff --git a/docs/source/masked.rst b/docs/source/masked.rst index 60dd67f643b890..00dadd2abf6885 100644 --- a/docs/source/masked.rst +++ b/docs/source/masked.rst @@ -296,11 +296,25 @@ The following ops are currently supported: Tensor.reshape_as Tensor.view +Other functions +--------------- + +The following :mod:`torch` functions has also a specific support for :class:`~torch.masked.MaskedTensor`: + +:func:`~torch.empty_like`, +:func:`~torch.full_like`, +:func:`~torch.ones_like`, +:func:`~torch.rand_like`, +:func:`~torch.randint_like`, +:func:`~torch.randn_like`, +:func:`~torch.zeros_like`. + .. This module needs to be documented. Adding here in the meantime .. for tracking purposes .. py:module:: torch.masked.maskedtensor.binary .. py:module:: torch.masked.maskedtensor.core .. py:module:: torch.masked.maskedtensor.creation +.. py:module:: torch.masked.maskedtensor.like .. py:module:: torch.masked.maskedtensor.passthrough .. py:module:: torch.masked.maskedtensor.reductions .. py:module:: torch.masked.maskedtensor.unary \ No newline at end of file diff --git a/test/test_maskedtensor.py b/test/test_maskedtensor.py index bf42cf086376ec..0817822be457b2 100644 --- a/test/test_maskedtensor.py +++ b/test/test_maskedtensor.py @@ -19,6 +19,7 @@ binary_ufuncs, reduction_ops, unary_ufuncs, + ops_and_refs, ) from torch.masked import as_masked_tensor, masked_tensor, _combine_input_and_mask @@ -26,6 +27,7 @@ from torch.masked.maskedtensor.unary import NATIVE_INPLACE_UNARY_FNS, NATIVE_UNARY_FNS, UNARY_NAMES from torch.masked.maskedtensor.binary import NATIVE_BINARY_FNS, NATIVE_INPLACE_BINARY_FNS, BINARY_NAMES from torch.masked.maskedtensor.reductions import REDUCE_NAMES +from torch.masked.maskedtensor.like import LIKE_NAMES def _compare_mt_t(mt_result, t_result, rtol=1e-05, atol=1e-05): @@ -812,9 +814,13 @@ def is_binary(op): def is_reduction(op): return op.name in REDUCE_NAMES and op.name not in {"all", "mean", "std", "var"} +def is_like(op): + return op.name in LIKE_NAMES + mt_unary_ufuncs = [op for op in unary_ufuncs if is_unary(op)] mt_binary_ufuncs = [op for op in binary_ufuncs if is_binary(op)] mt_reduction_ufuncs = [op for op in reduction_ops if is_reduction(op)] +mt_like_funcs = [op for op in ops_and_refs if is_like(op)] MASKEDTENSOR_FLOAT_TYPES = { torch.float16, @@ -822,6 +828,16 @@ def is_reduction(op): torch.float64, } +MASKEDTENSOR_SUPPORTED_TYPES = [ + *MASKEDTENSOR_FLOAT_TYPES, + torch.bool, + torch.int8, + torch.int16, + torch.int32, + torch.int64, +] + + class TestOperators(TestCase): def _convert_mt_args(self, args, mask, layout): return [ @@ -967,6 +983,43 @@ def test_reduction_all(self, device, dtype, op, layout): self._test_reduction_equality(device, dtype, op, layout) + @ops(mt_like_funcs, allowed_dtypes=MASKEDTENSOR_SUPPORTED_TYPES) # type: ignore[arg-type] + @parametrize("layout", [torch.strided, torch.sparse_coo, torch.sparse_csr]) + def test_like(self, device, dtype, op, layout): + samples = op.sample_inputs(device, dtype, requires_grad=dtype in MASKEDTENSOR_FLOAT_TYPES) + + for sample in samples: + input = sample.input + sample_args, sample_kwargs = sample.args, sample.kwargs + mask = _create_random_mask(input.shape, device) + + if layout == torch.sparse_coo: + mask = mask.to_sparse_coo().coalesce() + input = input.sparse_mask(mask) + elif layout == torch.sparse_csr: + if input.ndim != 2 or mask.ndim != 2: + continue + mask = mask.to_sparse_csr() + input = input.sparse_mask(mask) + + mt = masked_tensor(input, mask) + + try: + t_result = op(input, *sample_args, **sample_kwargs) + except NotImplementedError: + self.skipTest(f"{op.name} is not supported for {layout}") + + mt_result = op(mt, *sample_args, **sample_kwargs) + + if 'rand' not in op.name and op.name != 'empty_like': + _compare_mt_t(mt_result, t_result.to_dense()) + + self.assertEqual(mt_result.shape, t_result.shape) + self.assertEqual(mt_result.dtype, t_result.dtype) + self.assertEqual(mt_result.layout, t_result.layout) + self.assertEqual(mt_result.device, t_result.device) + self.assertEqual(mt_result.requires_grad, t_result.requires_grad) + only_for = ("cpu", "cuda") instantiate_device_type_tests(TestOperators, globals(), only_for=only_for) diff --git a/torch/masked/maskedtensor/__init__.py b/torch/masked/maskedtensor/__init__.py index e38e03c87086cf..9c908e9a5cba6b 100644 --- a/torch/masked/maskedtensor/__init__.py +++ b/torch/masked/maskedtensor/__init__.py @@ -3,6 +3,7 @@ from .binary import _apply_native_binary, _is_native_binary from .core import is_masked_tensor, MaskedTensor +from .like import _apply_like_fn, _is_like_fn from .passthrough import _apply_pass_through_fn, _is_pass_through_fn from .reductions import _apply_reduction, _is_reduction from .unary import _apply_native_unary, _is_native_unary diff --git a/torch/masked/maskedtensor/_ops_refs.py b/torch/masked/maskedtensor/_ops_refs.py index 7344a6e801aae2..febb0a95359b32 100644 --- a/torch/masked/maskedtensor/_ops_refs.py +++ b/torch/masked/maskedtensor/_ops_refs.py @@ -14,6 +14,7 @@ is_masked_tensor, MaskedTensor, ) +from .like import _apply_like_fn, LIKE_FNS from .passthrough import _apply_pass_through_fn, PASSTHROUGH_FNS from .reductions import ( _apply_reduction, @@ -270,6 +271,11 @@ def _general_binary(func, *args, **kwargs): return _apply_native_binary(func, *args, **kwargs) +@register_dispatch_func(LIKE_FNS) +def _general_like(func, *args, **kwargs): + return _apply_like_fn(func, *args, **kwargs) + + @register_dispatch_func([torch.ops.aten.stride]) def stride(func, *args, **kwargs): return None @@ -365,13 +371,6 @@ def _softmax(func, *args, **kwargs): return MaskedTensor(result_data, mask) -@register_dispatch_func([torch.ops.aten.ones_like]) -def ones_like(func, *args, **kwargs): - _check_args_kwargs_length(args, kwargs, f"__torch_dispatch__, {func}", len_args=1) - result_data = func(_get_data(args[0]), **kwargs) - return MaskedTensor(result_data, _maybe_get_mask(args[0])) - - @register_dispatch_func([torch.ops.aten._softmax_backward_data]) def _softmax_backward_data(func, *args, **kwargs): _check_args_kwargs_length(args, kwargs, f"__torch_dispatch__, {func}", len_args=4) diff --git a/torch/masked/maskedtensor/like.py b/torch/masked/maskedtensor/like.py new file mode 100644 index 00000000000000..8503670e1ba660 --- /dev/null +++ b/torch/masked/maskedtensor/like.py @@ -0,0 +1,38 @@ +# mypy: allow-untyped-defs +""" +These are the creation ops using the ``*_like`` API that need to support the +`MaskedTensor`. The wrapper just applies the function to the masked data then convert +it to a masked tensor using the mask from the given tensor. +""" + +import torch + +from .core import _get_data, _maybe_get_mask, MaskedTensor + + +__all__ = [ + "LIKE_NAMES", + "LIKE_FNS", +] + + +LIKE_NAMES = [ + "empty_like", + "full_like", + "ones_like", + "rand_like", + "randint_like", + "randn_like", + "zeros_like", +] + +LIKE_FNS = [getattr(torch.ops.aten, name) for name in LIKE_NAMES] + + +def _is_like_fn(fn): + return fn in LIKE_FNS + + +def _apply_like_fn(func, *args, **kwargs): + result_data = func(_get_data(args[0]), *args[1:], **kwargs) + return MaskedTensor(result_data, _maybe_get_mask(args[0])) From c0efe2dc5e5f3223c48a03bc11af87e279a7139c Mon Sep 17 00:00:00 2001 From: Nowtryz <44437613+nowtryz@users.noreply.github.com> Date: Wed, 28 Aug 2024 21:30:37 +0000 Subject: [PATCH 0172/1018] Add MaskedTensor passthrough: unfold, F.Unfold, F.Fold, stack (#125262) Hi, I noticed the `unfold` operator was missing on MaskedTensor. I tested that my change works when calling unfold and backward on a `MaskedTensor` but I didn't find the tests for the dispatch of such operation. Where is it? Pull Request resolved: https://github.com/pytorch/pytorch/pull/125262 Approved by: https://github.com/cpuhrsch --- aten/src/ATen/native/Col2Im.cpp | 2 +- aten/src/ATen/native/Im2Col.cpp | 2 +- aten/src/ATen/native/cuda/Col2Im.cu | 2 +- aten/src/ATen/native/cuda/Im2Col.cu | 2 +- docs/source/masked.rst | 3 ++ test/test_maskedtensor.py | 50 ++++++++++++++++--- torch/masked/maskedtensor/passthrough.py | 5 ++ .../_internal/common_methods_invocations.py | 4 +- 8 files changed, 56 insertions(+), 14 deletions(-) diff --git a/aten/src/ATen/native/Col2Im.cpp b/aten/src/ATen/native/Col2Im.cpp index 0a98ee5c46ca77..51e005c2901b93 100644 --- a/aten/src/ATen/native/Col2Im.cpp +++ b/aten/src/ATen/native/Col2Im.cpp @@ -144,7 +144,7 @@ static void col2im_out_cpu_template( output.resize_({batch_size, n_output_plane, output_height, output_width}); - AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(kBFloat16, kHalf, + AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND3(kBFloat16, kHalf, kBool, input.scalar_type(), "col2im_out_cpu", [&] { Tensor input_n = Tensor(); Tensor output_n = Tensor(); diff --git a/aten/src/ATen/native/Im2Col.cpp b/aten/src/ATen/native/Im2Col.cpp index dac2ee6e3f103a..25eb4d6787240a 100644 --- a/aten/src/ATen/native/Im2Col.cpp +++ b/aten/src/ATen/native/Im2Col.cpp @@ -94,7 +94,7 @@ static void im2col_out_cpu_template( output.resize_({batch_size, n_output_plane, output_length}); - AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(kBFloat16, kHalf, + AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND3(kBFloat16, kHalf, kBool, input.scalar_type(), "im2col_out_cpu", [&] { Tensor input_n; Tensor output_n; diff --git a/aten/src/ATen/native/cuda/Col2Im.cu b/aten/src/ATen/native/cuda/Col2Im.cu index bb6d4748deb1f5..e390666b307aa1 100644 --- a/aten/src/ATen/native/cuda/Col2Im.cu +++ b/aten/src/ATen/native/cuda/Col2Im.cu @@ -102,7 +102,7 @@ void col2im_out_cuda_template( output.resize_({batch_size, n_output_plane, output_height, output_width}); int64_t output_batch_stride = output.stride(0); - AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(kHalf, kBFloat16, + AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND3(kHalf, kBFloat16, kBool, input.scalar_type(), "col2im_out_cuda", [&] { int64_t height_col = (output_height + 2 * pad_height - (dilation_height * (kernel_height - 1) + 1)) / diff --git a/aten/src/ATen/native/cuda/Im2Col.cu b/aten/src/ATen/native/cuda/Im2Col.cu index 312ad893c0d81e..796eda97b37331 100644 --- a/aten/src/ATen/native/cuda/Im2Col.cu +++ b/aten/src/ATen/native/cuda/Im2Col.cu @@ -103,7 +103,7 @@ static void im2col_out_cuda_template( output.resize_({batch_size, n_output_plane, output_length}); // Launch kernel - AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(kHalf, kBFloat16, + AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND3(kHalf, kBFloat16, kBool, input.scalar_type(), "im2col_out_cuda", [&] { Tensor input_n; Tensor output_n; diff --git a/docs/source/masked.rst b/docs/source/masked.rst index 00dadd2abf6885..086e1f6c8b25c2 100644 --- a/docs/source/masked.rst +++ b/docs/source/masked.rst @@ -283,9 +283,11 @@ The following ops are currently supported: kron meshgrid narrow + nn.functional.unfold ravel select split + stack t transpose vsplit @@ -294,6 +296,7 @@ The following ops are currently supported: Tensor.expand_as Tensor.reshape Tensor.reshape_as + Tensor.unfold Tensor.view Other functions diff --git a/test/test_maskedtensor.py b/test/test_maskedtensor.py index 0817822be457b2..50d03b306bb0cc 100644 --- a/test/test_maskedtensor.py +++ b/test/test_maskedtensor.py @@ -68,6 +68,18 @@ def _compare_mts(mt1, mt2, rtol=1e-05, atol=1e-08): if not _tensors_match(a, b, exact=False, rtol=rtol, atol=atol): raise ValueError("The data in MaskedTensor mt1 and MaskedTensor mt2 do not match") +def _compare_forward_backward(data, mask, fn): + mt = masked_tensor(data, mask, requires_grad=True) + masked_res = fn(mt) + masked_res.sum().backward() + + t = data.masked_fill(~mask, float("-inf")).detach().clone().requires_grad_() + tensor_res = fn(t) + tensor_res.sum().backward() + + _compare_mt_t(masked_res, tensor_res) + _compare_mt_t(mt.grad, t.grad, atol=1e-06) + def _create_random_mask(shape, device): return make_tensor(shape, device=device, dtype=torch.bool) @@ -166,15 +178,8 @@ def test_softmax(self, device): ], device=device ) - mt = masked_tensor(data, mask, requires_grad=True) - masked_res = torch.softmax(mt, -1) - masked_res.sum().backward() - xinf = data.masked_fill(~mask, float("-inf")).detach().clone().requires_grad_() - tensor_res = torch.softmax(xinf, -1) - tensor_res.sum().backward() - _compare_mt_t(masked_res, tensor_res) - _compare_mt_t(mt.grad, xinf.grad, atol=1e-06) + _compare_forward_backward(data, mask, lambda t: torch.softmax(t, -1)) def test_where(self, device): data = torch.tensor([-10.0, -5, 0, 5, 10, 50, 60, 70, 80, 90, 100], device=device) @@ -194,6 +199,35 @@ def test_where(self, device): _compare_mt_t(mx.grad, x.grad) _compare_mt_t(my.grad, y.grad) + def test_unfold(self, device): + data = torch.rand(5, 5, device=device) + mask = torch.rand(5, 5, device=device) > 0.5 + _compare_forward_backward(data, mask, lambda t: t.unfold(1, 2, 2)) + + def test_nn_unfold(self, device): + data = torch.rand(2, 5, 3, 4, device=device) + mask = torch.rand(2, 5, 3, 4, device=device) > 0.5 + _compare_forward_backward(data, mask, lambda t: torch.nn.functional.unfold(t, kernel_size=(2, 3))) + + def test_stack(self, device): + masked_tensors = [ + masked_tensor( + torch.rand(2, 5, 3, 4, device=device), + torch.rand(2, 5, 3, 4, device=device) > 0.5, + requires_grad=True, + ) for _ in range(3) + ] + + data_tensors = [mt.get_data().detach().clone().requires_grad_() for mt in masked_tensors] + masked_res = torch.stack(masked_tensors) + tensor_res = torch.stack(data_tensors) + + masked_res.sum().backward() + tensor_res.sum().backward() + _compare_mt_t(masked_res, tensor_res) + for mt, t in zip(masked_tensors, data_tensors): + _compare_mt_t(mt.grad, t.grad, atol=1e-06) + def test_to_sparse(self, device): for sample in _generate_sample_data(device=device): data = sample.input diff --git a/torch/masked/maskedtensor/passthrough.py b/torch/masked/maskedtensor/passthrough.py index 4a2e79456c86db..ba13f50c1fee9c 100644 --- a/torch/masked/maskedtensor/passthrough.py +++ b/torch/masked/maskedtensor/passthrough.py @@ -30,6 +30,11 @@ torch.ops.aten._reshape_alias, torch.ops.aten.cat, torch.ops.aten.unsqueeze, + torch.ops.aten.unfold, + torch.ops.aten.unfold_backward, + torch.ops.aten.im2col, + torch.ops.aten.col2im, + torch.ops.aten.stack, ] diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py index 3049cd8d8f5753..fecab41a741f33 100644 --- a/torch/testing/_internal/common_methods_invocations.py +++ b/torch/testing/_internal/common_methods_invocations.py @@ -15266,8 +15266,8 @@ def sample_inputs_alias_copy(op_info, device, dtype, requires_grad, **kwargs): autodiff_nonfusible_nodes=["aten::hardswish"]), OpInfo('nn.functional.unfold', aten_name='im2col', - dtypes=floating_and_complex_types_and(torch.half, torch.bfloat16), - dtypesIfCUDA=floating_and_complex_types_and(torch.half, torch.bfloat16), + dtypes=floating_and_complex_types_and(torch.half, torch.bfloat16, torch.bool), + dtypesIfCUDA=floating_and_complex_types_and(torch.half, torch.bfloat16, torch.bool), sample_inputs_func=sample_inputs_nn_unfold, # Runs very slowly on slow gradcheck - alternatively reduce input sizes gradcheck_fast_mode=True, From bedc339760e11546a705d0e02afdd83655f9dbaa Mon Sep 17 00:00:00 2001 From: Jia Li Date: Wed, 28 Aug 2024 21:34:32 +0000 Subject: [PATCH 0173/1018] Create processes in parallel in mp.start_processes for forkserver (#134629) Summary: This is to fix the pytorch issue filed https://github.com/pytorch/pytorch/issues/133010 one way to fix this problem is to enable parallel start processes in mp.start_processes. What else in the diff: refactored a test case api_test which was repeating a lot of tests due to the inheritance. added unit test for forkserver when parallel start is on. Test Plan: Added unit tests Differential Revision: D61878552 Pull Request resolved: https://github.com/pytorch/pytorch/pull/134629 Approved by: https://github.com/d4l3k --- .../elastic/multiprocessing/api_test.py | 285 ++++++++++-------- test/test_multiprocessing_spawn.py | 74 ++++- torch/multiprocessing/__init__.py | 1 + torch/multiprocessing/spawn.py | 59 +++- 4 files changed, 280 insertions(+), 139 deletions(-) diff --git a/test/distributed/elastic/multiprocessing/api_test.py b/test/distributed/elastic/multiprocessing/api_test.py index 75e903807ff9b2..9f55785810ca86 100644 --- a/test/distributed/elastic/multiprocessing/api_test.py +++ b/test/distributed/elastic/multiprocessing/api_test.py @@ -226,36 +226,65 @@ def start_processes_zombie_test( pc.close(e.sigval) -# tests incompatible with tsan or asan -if not (TEST_WITH_DEV_DBG_ASAN or IS_WINDOWS or IS_MACOS): +class _StartProcessesTest(TestCase): + def setUp(self): + super().setUp() + self.test_dir = tempfile.mkdtemp(prefix=f"{self.__class__.__name__}_") + self._start_methods = ["spawn"] - class StartProcessesTest(TestCase): - def setUp(self): - super().setUp() - self.test_dir = tempfile.mkdtemp(prefix=f"{self.__class__.__name__}_") - self._start_methods = ["spawn"] + def tearDown(self): + super().tearDown() + shutil.rmtree(self.test_dir) - def tearDown(self): - super().tearDown() - shutil.rmtree(self.test_dir) + def log_dir(self): + return tempfile.mkdtemp(dir=self.test_dir) - def log_dir(self): - return tempfile.mkdtemp(dir=self.test_dir) + def assert_in_file(self, expected: List[str], filename: str) -> None: + expected = [f"{line.rstrip()}\n" for line in expected] + with open(filename) as fp: + actual = fp.readlines() + for line in expected: + self.assertIn(line, actual) + + def assert_pids_noexist(self, pids: Dict[int, int]): + for local_rank, pid in pids.items(): + with self.assertRaises( + OSError, msg=f"local_rank: {local_rank} pid: {pid} should not exist" + ): + os.kill(pid, 0) + + def _test_zombie_workflow( + self, entrypoint: Union[str, Callable], signal_to_send: signal.Signals + ) -> None: + mp_queue = mp.get_context("spawn").Queue() + child_nproc = 2 + ctx = mp.spawn( + start_processes_zombie_test, + nprocs=1, + args=(entrypoint, mp_queue, self.log_dir(), child_nproc), + join=False, + ) + total_processes = child_nproc + 1 + pids = [] + for _ in range(total_processes): + pids.append(mp_queue.get(timeout=120)) + parent_pid = pids[0] + child_pids = pids[1:] + + os.kill(parent_pid, signal.SIGTERM) + # Wait to give time for signal handlers to finish work + time.sleep(5) + for child_pid in child_pids: + # Killing parent should kill all children, we expect that each call to + # os.kill would raise OSError + with self.assertRaises(OSError): + os.kill(child_pid, 0) - def assert_in_file(self, expected: List[str], filename: str) -> None: - expected = [f"{line.rstrip()}\n" for line in expected] - with open(filename) as fp: - actual = fp.readlines() - for line in expected: - self.assertIn(line, actual) - def assert_pids_noexist(self, pids: Dict[int, int]): - for local_rank, pid in pids.items(): - with self.assertRaises( - OSError, msg=f"local_rank: {local_rank} pid: {pid} should not exist" - ): - os.kill(pid, 0) +# tests incompatible with tsan or asan +if not (TEST_WITH_DEV_DBG_ASAN or IS_WINDOWS or IS_MACOS): + class StartProcessesAsFuncTest(_StartProcessesTest): def test_to_map(self): local_world_size = 2 self.assertEqual( @@ -349,26 +378,13 @@ def test_multiprocess_context_close(self): self.assertTrue(pc._stderr_tail.stopped()) self.assertTrue(pc._stdout_tail.stopped()) - def test_subprocess_context_close(self): - pc = start_processes( - name="sleep", - entrypoint=bin("zombie_test.py"), - args={0: (1,)}, - envs={0: {}}, - logs_specs=DefaultLogsSpecs(log_dir=self.log_dir()), - ) - - pids = pc.pids() - pc.close() - self.assert_pids_noexist(pids) - def test_function_with_tensor(self): for start_method in self._start_methods: pc = start_processes( name="dummy_compute", entrypoint=dummy_compute, - args={}, - envs={}, + args={0: ()}, + envs={0: {}}, logs_specs=DefaultLogsSpecs(log_dir=self.log_dir()), start_method=start_method, ) @@ -489,10 +505,55 @@ def test_wait_for_all_child_procs_to_exit(self): mpc._poll() self.assertEqual(4, mock_join.call_count) + @skip_but_pass_in_sandcastle_if( + NO_MULTIPROCESSING_SPAWN, + "Disabled for environments that \ + don't support multiprocessing with spawn start method", + ) + def test_multiprocessing_context_poll_raises_exception(self): + mp_context = MultiprocessContext( + name="test_mp", + entrypoint=echo0, + args={0: (0, 1)}, + envs={0: {}}, + logs_specs=DefaultLogsSpecs( + log_dir=self.log_dir(), redirects=Std.ALL, tee=Std.ALL + ), + start_method="spawn", + ) + mp_context._pc = mock.Mock() + # Using mock since we cannot just set exitcode on process + mock_process = mock.Mock() + mock_process.exitcode = -1 + mp_context._pc.processes = [mock_process] + e = mp.ProcessRaisedException(msg="test msg", error_index=0, error_pid=123) + mp_context._pc.join.side_effect = e + with mock.patch.object(mp_context, "close"): + run_result = mp_context._poll() + self.assertEqual(1, len(run_result.failures)) + failure = run_result.failures[0] + self.assertEqual( + "Signal 1 (SIGHUP) received by PID 123", failure.message + ) + + class StartProcessesAsBinaryTest(_StartProcessesTest): ######################################## # start_processes as binary tests ######################################## + def test_subprocess_context_close(self): + pc = start_processes( + name="sleep", + entrypoint=bin("zombie_test.py"), + args={0: (1,)}, + envs={0: {}}, + logs_specs=DefaultLogsSpecs(log_dir=self.log_dir()), + ) + + pids = pc.pids() + pc.close() + self.assert_pids_noexist(pids) + def test_binary_exit(self): FAIL = 138 pc = start_processes( @@ -556,45 +617,11 @@ def test_validate_full_rank(self): with self.assertRaises(RuntimeError): _validate_full_rank({}, 10, "") - @skip_but_pass_in_sandcastle_if( - NO_MULTIPROCESSING_SPAWN, - "Disabled for environments that \ - don't support multiprocessing with spawn start method", - ) - def test_multiprocessing_context_poll_raises_exception(self): - mp_context = MultiprocessContext( - name="test_mp", - entrypoint=echo0, - args={0: (0, 1)}, - envs={0: {}}, - logs_specs=DefaultLogsSpecs( - log_dir=self.log_dir(), redirects=Std.ALL, tee=Std.ALL - ), - start_method="spawn", - ) - mp_context._pc = mock.Mock() - # Using mock since we cannot just set exitcode on process - mock_process = mock.Mock() - mock_process.exitcode = -1 - mp_context._pc.processes = [mock_process] - e = mp.ProcessRaisedException(msg="test msg", error_index=0, error_pid=123) - mp_context._pc.join.side_effect = e - with mock.patch.object(mp_context, "close"): - run_result = mp_context._poll() - self.assertEqual(1, len(run_result.failures)) - failure = run_result.failures[0] - self.assertEqual( - "Signal 1 (SIGHUP) received by PID 123", failure.message - ) - # tests incompatible with tsan or asan, the redirect functionality does not work on macos or windows if not (TEST_WITH_DEV_DBG_ASAN or IS_WINDOWS or IS_MACOS): - class StartProcessesListTest(StartProcessesTest): - ######################################## - # start_processes as binary tests - ######################################## + class StartProcessesListAsFuncTest(_StartProcessesTest): def test_function(self): for start_method, redirs in product( self._start_methods, redirects_oss_test() @@ -634,6 +661,10 @@ def test_function(self): [f"hello stderr from {i}"], results.stderrs[i] ) + class StartProcessesListAsBinaryTest(_StartProcessesTest): + ######################################## + # start_processes as binary tests + ######################################## def test_binary(self): for redirs in redirects_oss_test(): with self.subTest(redirs=redirs): @@ -700,7 +731,7 @@ def test_binary_redirect_and_tee(self): # tests incompatible with tsan or asan, the redirect functionality does not work on macos or windows if not (TEST_WITH_DEV_DBG_ASAN or IS_WINDOWS or IS_MACOS or IS_CI): - class StartProcessesNotCITest(StartProcessesTest): + class StartProcessesNotCIAsFuncTest(_StartProcessesTest): @skip_if_pytest def test_wrap_bad(self): none = "" @@ -733,32 +764,6 @@ def test_wrap_bad(self): self.assert_in_file(["hello stderr from 0"], stderr_log) worker_finished_event_mock.wait.assert_called_once() - def test_binary_signal(self): - pc = start_processes( - name="echo", - entrypoint=bin("echo3.py"), - args={0: ("--segfault", "true", "foo"), 1: ("bar",)}, - envs={0: {"RANK": "0"}, 1: {"RANK": "1"}}, - logs_specs=DefaultLogsSpecs( - log_dir=self.log_dir(), - ), - ) - - results = pc.wait(period=0.1) - - self.assert_pids_noexist(pc.pids()) - self.assertTrue(results.is_failed()) - self.assertEqual(1, len(results.failures)) - - failure = results.failures[0] - self.assertNotEqual(signal.SIGSEGV, failure.exitcode) - if TEST_WITH_ASAN or TEST_WITH_TSAN: - # ASAN/TSAN exit code is 1. - self.assertEqual("", failure.signal_name()) - else: - self.assertEqual("SIGSEGV", failure.signal_name()) - self.assertEqual("", failure.error_file_data["message"]) - def test_function_redirect_and_tee(self): for start_method in self._start_methods: with self.subTest(start_method=start_method): @@ -871,42 +876,60 @@ def test_function_exit(self): self.assertTrue(pc._stderr_tail.stopped()) self.assertTrue(pc._stdout_tail.stopped()) - def test_no_zombie_process_binary(self): - signals = [signal.SIGTERM, signal.SIGINT, signal.SIGHUP, signal.SIGQUIT] - for s in signals: - self._test_zombie_workflow(bin("zombie_test.py"), s) - def test_no_zombie_process_function(self): signals = [signal.SIGTERM, signal.SIGINT, signal.SIGHUP, signal.SIGQUIT] for s in signals: self._test_zombie_workflow(wait_fn, s) - def _test_zombie_workflow( - self, entrypoint: Union[str, Callable], signal_to_send: signal.Signals - ) -> None: - mp_queue = mp.get_context("spawn").Queue() - child_nproc = 2 - ctx = mp.spawn( - start_processes_zombie_test, - nprocs=1, - args=(entrypoint, mp_queue, self.log_dir(), child_nproc), - join=False, + class StartProcessesNotCIAsBinaryTest(_StartProcessesTest): + def test_binary_signal(self): + pc = start_processes( + name="echo", + entrypoint=bin("echo3.py"), + args={0: ("--segfault", "true", "foo"), 1: ("bar",)}, + envs={0: {"RANK": "0"}, 1: {"RANK": "1"}}, + logs_specs=DefaultLogsSpecs( + log_dir=self.log_dir(), + ), ) - total_processes = child_nproc + 1 - pids = [] - for _ in range(total_processes): - pids.append(mp_queue.get(timeout=120)) - parent_pid = pids[0] - child_pids = pids[1:] - - os.kill(parent_pid, signal.SIGTERM) - # Wait to give time for signal handlers to finish work - time.sleep(5) - for child_pid in child_pids: - # Killing parent should kill all children, we expect that each call to - # os.kill would raise OSError - with self.assertRaises(OSError): - os.kill(child_pid, 0) + + results = pc.wait(period=0.1) + + self.assert_pids_noexist(pc.pids()) + self.assertTrue(results.is_failed()) + self.assertEqual(1, len(results.failures)) + + failure = results.failures[0] + self.assertNotEqual(signal.SIGSEGV, failure.exitcode) + if TEST_WITH_ASAN or TEST_WITH_TSAN: + # ASAN/TSAN exit code is 1. + self.assertEqual("", failure.signal_name()) + else: + self.assertEqual("SIGSEGV", failure.signal_name()) + self.assertEqual("", failure.error_file_data["message"]) + + def test_no_zombie_process_binary(self): + signals = [signal.SIGTERM, signal.SIGINT, signal.SIGHUP, signal.SIGQUIT] + for s in signals: + self._test_zombie_workflow(bin("zombie_test.py"), s) + + class ForkServerTest( + StartProcessesAsFuncTest, + StartProcessesListAsFuncTest, + StartProcessesNotCIAsFuncTest, + ): + def setUp(self): + super().setUp() + self._start_methods = ["forkserver"] + self.orig_paralell_env_val = os.environ.get(mp.ENV_VAR_PARALLEL_START) + os.environ[mp.ENV_VAR_PARALLEL_START] = "1" + + def tearDown(self): + super().tearDown() + if self.orig_paralell_env_val is None: + del os.environ[mp.ENV_VAR_PARALLEL_START] + else: + os.environ[mp.ENV_VAR_PARALLEL_START] = self.orig_paralell_env_val if __name__ == "__main__": diff --git a/test/test_multiprocessing_spawn.py b/test/test_multiprocessing_spawn.py index 5160056a87f7d0..acad97827ec423 100644 --- a/test/test_multiprocessing_spawn.py +++ b/test/test_multiprocessing_spawn.py @@ -8,9 +8,14 @@ import time import unittest -from torch.testing._internal.common_utils import (TestCase, run_tests, IS_WINDOWS, NO_MULTIPROCESSING_SPAWN) import torch.multiprocessing as mp +from torch.testing._internal.common_utils import ( + IS_WINDOWS, + NO_MULTIPROCESSING_SPAWN, + run_tests, + TestCase, +) def _test_success_func(i): pass @@ -220,6 +225,73 @@ class ForkTest(TestCase, _TestMultiProcessing): start_method = 'fork' +@unittest.skipIf( + IS_WINDOWS, + "Fork is only available on Unix", +) +class ParallelForkServerShouldWorkTest(TestCase, _TestMultiProcessing): + orig_paralell_env_val = None + + def setUp(self): + super().setUp() + self.orig_paralell_env_val = os.environ.get(mp.ENV_VAR_PARALLEL_START) + os.environ[mp.ENV_VAR_PARALLEL_START] = "1" + + def tearDown(self): + super().tearDown() + if self.orig_paralell_env_val is None: + del os.environ[mp.ENV_VAR_PARALLEL_START] + else: + os.environ[mp.ENV_VAR_PARALLEL_START] = self.orig_paralell_env_val + + +@unittest.skipIf( + IS_WINDOWS, + "Fork is only available on Unix", +) +class ParallelForkServerPerfTest(TestCase): + + def test_forkserver_perf(self): + + start_method = 'forkserver' + expensive = Expensive() + nprocs = 4 + orig_paralell_env_val = os.environ.get(mp.ENV_VAR_PARALLEL_START) + + # test the non parallel case + os.environ[mp.ENV_VAR_PARALLEL_START] = "0" + start = time.perf_counter() + mp.start_processes(expensive.my_call, nprocs=nprocs, start_method=start_method) + elapsed = time.perf_counter() - start + # the elapsed time should be at least {nprocs}x the sleep time + self.assertGreaterEqual(elapsed, Expensive.SLEEP_SECS * nprocs) + + # test the parallel case + os.environ[mp.ENV_VAR_PARALLEL_START] = "1" + start = time.perf_counter() + mp.start_processes(expensive.my_call, nprocs=nprocs, start_method=start_method) + elapsed = time.perf_counter() - start + # the elapsed time should be less than {nprocs}x the sleep time + self.assertLess(elapsed, Expensive.SLEEP_SECS * nprocs) + + if orig_paralell_env_val is None: + del os.environ[mp.ENV_VAR_PARALLEL_START] + else: + os.environ[mp.ENV_VAR_PARALLEL_START] = orig_paralell_env_val + + +class Expensive: + SLEEP_SECS = 5 + # Simulate startup overhead such as large imports + time.sleep(SLEEP_SECS) + + def __init__(self): + self.config: str = "*" * 1000000 + + def my_call(self, *args): + pass + + class ErrorTest(TestCase): def test_errors_pickleable(self): for error in ( diff --git a/torch/multiprocessing/__init__.py b/torch/multiprocessing/__init__.py index f9ea0ad13bdafa..745c180d8c415c 100644 --- a/torch/multiprocessing/__init__.py +++ b/torch/multiprocessing/__init__.py @@ -39,6 +39,7 @@ """Add helper function to spawn N processes and wait for completion of any of them. This depends `mp.get_context` which was added in Python 3.4.""" from .spawn import ( + ENV_VAR_PARALLEL_START, ProcessContext, ProcessExitedException, ProcessRaisedException, diff --git a/torch/multiprocessing/spawn.py b/torch/multiprocessing/spawn.py index aae2e338f7b0a0..74bdde0fd97b20 100644 --- a/torch/multiprocessing/spawn.py +++ b/torch/multiprocessing/spawn.py @@ -9,13 +9,26 @@ import tempfile import time import warnings +from concurrent.futures import as_completed, ThreadPoolExecutor from typing import Optional from . import _prctl_pr_set_pdeathsig # type: ignore[attr-defined] +ENV_VAR_PARALLEL_START = "TORCH_MP_PARALLEL_START" + log = logging.getLogger(__name__) +__all__ = [ + "ProcessContext", + "ProcessException", + "ProcessExitedException", + "ProcessRaisedException", + "spawn", + "SpawnContext", + "start_processes", +] + class ProcessException(Exception): __slots__ = ["error_index", "error_pid"] @@ -205,12 +218,32 @@ def __init__(self, processes, error_files): # Currently we only add this API first, we can consider adding it to documentation as # needed in the future. def start_processes( - fn, args=(), nprocs=1, join=True, daemon=False, start_method="spawn" + fn, + args=(), + nprocs=1, + join=True, + daemon=False, + start_method="spawn", ): + # To speed up performance in certain cases (see https://github.com/pytorch/pytorch/issues/133010), + # this func will start processes in parallel if start_method is 'forkserver'. + # Please opt in to this perf optimization by setting env var (TORCH_MP_PARALLEL_START) to 1. + # todo: investigate why spawn does not work with threadpool and raises SIGINT + if ( + start_method == "forkserver" + and os.environ.get(ENV_VAR_PARALLEL_START, "0") == "1" + ): + log.info("Starting processes in parallel.") + start_parallel = True + else: + # Set env var TORCH_MP_PARALLEL_START to 0 to disable parallel start + start_parallel = False + mp = multiprocessing.get_context(start_method) - error_files = [] - processes = [] - for i in range(nprocs): + error_files = [None] * nprocs + processes = [None] * nprocs + + def start_process(i): # Each process is assigned a file to write tracebacks to. We # use the file being non-empty to indicate an exception # occurred (vs an expected shutdown). Note: this previously @@ -228,9 +261,21 @@ def start_processes( daemon=daemon, ) process.start() - error_files.append(tf.name) - processes.append(process) - + return i, process, tf.name + + if not start_parallel: + for i in range(nprocs): + idx, process, tf_name = start_process(i) + error_files[idx] = tf_name + processes[idx] = process + else: + with ThreadPoolExecutor(max_workers=nprocs) as executor: + futures = [executor.submit(start_process, i) for i in range(nprocs)] + for fut in as_completed(futures): + idx, process, tf_name = fut.result() + # idx and process rank needs to be the same. + error_files[idx] = tf_name + processes[idx] = process context = ProcessContext(processes, error_files) if not join: return context From 344e940958c00de06b6dfc41d7c676a02df1455d Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Wed, 28 Aug 2024 21:51:43 +0000 Subject: [PATCH 0174/1018] Revert "[dynamo] Cache _dynamo.disable results (#134272)" This reverts commit dbef2b05b4d81e891f7497f92f730a22bebe445d. Reverted https://github.com/pytorch/pytorch/pull/134272 on behalf of https://github.com/anijain2305 due to Peak mem increase detected internally ([comment](https://github.com/pytorch/pytorch/pull/134272#issuecomment-2316308170)) --- test/dynamo/test_decorators.py | 14 -------------- torch/_dynamo/__init__.py | 1 - torch/_dynamo/convert_frame.py | 1 - torch/_dynamo/decorators.py | 15 +-------------- 4 files changed, 1 insertion(+), 30 deletions(-) diff --git a/test/dynamo/test_decorators.py b/test/dynamo/test_decorators.py index 463634e3216358..b915e0633f7c11 100644 --- a/test/dynamo/test_decorators.py +++ b/test/dynamo/test_decorators.py @@ -184,20 +184,6 @@ def hook(module, args): all(node.target is not torch.sigmoid for node in gm1.graph.nodes) ) - def test_disable_no_recompile(self): - def gn(x): - return torch.cos(x) - - @torch.compile(backend="eager") - def fn(x): - x = torch.sin(x) - x = torch._dynamo.disable(gn, recursive=True)(x) - return torch.sin(x) - - with torch._dynamo.config.patch(error_on_recompile=True): - for _ in range(5): - fn(torch.randn(4)) - def test_allow_in_graph(self): cnts = torch._dynamo.testing.CompileCounter() diff --git a/torch/_dynamo/__init__.py b/torch/_dynamo/__init__.py index 00286a32750d1e..7f58ba7f7bf7f1 100644 --- a/torch/_dynamo/__init__.py +++ b/torch/_dynamo/__init__.py @@ -107,4 +107,3 @@ def reset_code_caches() -> None: if code: reset_code(code) code_context.clear() - convert_frame.disabled_codes.clear() diff --git a/torch/_dynamo/convert_frame.py b/torch/_dynamo/convert_frame.py index c1fadedd614405..cddc89695d4fee 100644 --- a/torch/_dynamo/convert_frame.py +++ b/torch/_dynamo/convert_frame.py @@ -161,7 +161,6 @@ def clear(self) -> None: input_codes = Tracker() output_codes = Tracker() -disabled_codes: Dict[int, Callable[..., Any]] = {} initial_global_state: Optional[GlobalStateGuard] = None diff --git a/torch/_dynamo/decorators.py b/torch/_dynamo/decorators.py index 52ffa76cf4a4e7..c83d8d718a62a4 100644 --- a/torch/_dynamo/decorators.py +++ b/torch/_dynamo/decorators.py @@ -10,7 +10,6 @@ from . import trace_rules, variables from .comptime import comptime -from .convert_frame import disabled_codes from .eval_frame import DisableContext, innermost_fn, RunOnlyContext from .exc import IncorrectUsage from .external_utils import is_compiling @@ -56,21 +55,9 @@ def disable(fn=None, recursive=True): """ if recursive: if fn is not None: - id_fn = id(fn) - if cached_fn := disabled_codes.get(id_fn): - return cached_fn - fn = innermost_fn(fn) assert callable(fn) - out = DisableContext()(fn) - - # Do not cache Fx graph methods. These come from Dynamo itself. Caching such graphs can increase lifetime of - # held objects. Since these are coming from Dynamo itself, there is no benefit of caching. - if not ( - inspect.ismethod(fn) and isinstance(fn.__self__, torch.fx.GraphModule) - ): - disabled_codes[id_fn] = out - return out + return DisableContext()(fn) return DisableContext() else: return skip(fn) From eefe7a4916650859537a923a57126bd101f71638 Mon Sep 17 00:00:00 2001 From: Janet Yang Date: Wed, 28 Aug 2024 21:53:24 +0000 Subject: [PATCH 0175/1018] [amd][lowering] hipify shim v2 headers (#134689) Summary: The default c_shim version was switched to 2 for HIP in D60674018. This results in some linking errors where shim function symbols are missing from the compiled .so file (eg. P1551186492) when building lowering benchmark scripts since the required files aren't included. Hipify the shim v2 generated header files as well since they're needed during codegen when the buck binaries are executed. Reviewed By: frank-wei, zoranzhao, henryoier Differential Revision: D61865202 Pull Request resolved: https://github.com/pytorch/pytorch/pull/134689 Approved by: https://github.com/zoranzhao --- build.bzl | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/build.bzl b/build.bzl index 9ada64dc302ad1..dbb1866ac54823 100644 --- a/build.bzl +++ b/build.bzl @@ -76,7 +76,7 @@ def define_targets(rules): ] + (["--static_dispatch_backend CPU"] if rules.is_cpu_static_dispatch_build() else [])) gen_aten_outs_cuda = ( - GENERATED_H_CUDA + GENERATED_CPP_CUDA + + GENERATED_H_CUDA + GENERATED_CPP_CUDA + GENERATED_AOTI_CUDA_CPP + aten_ufunc_generated_cuda_sources() ) @@ -320,5 +320,8 @@ GENERATED_AUTOGRAD_CPP = [ GENERATED_AOTI_CPP = [ "torch/csrc/inductor/aoti_torch/generated/c_shim_cpu.cpp", +] + +GENERATED_AOTI_CUDA_CPP = [ "torch/csrc/inductor/aoti_torch/generated/c_shim_cuda.cpp", ] From c5c3b32fa4063497477200396e36fba9c9431715 Mon Sep 17 00:00:00 2001 From: Ke Wen Date: Mon, 26 Aug 2024 15:03:49 -0700 Subject: [PATCH 0176/1018] [2/N] Add flag to control which rank should perform NaN check (#134345) Fixes https://github.com/pytorch/pytorch/issues/134062. For example, in case of broadcast / scatter, only the root rank should perform the NaN check. Pull Request resolved: https://github.com/pytorch/pytorch/pull/134345 Approved by: https://github.com/shuqiangzhang, https://github.com/wconstab ghstack dependencies: #134300 --- test/distributed/test_c10d_nccl.py | 22 +++++++++++ .../distributed/c10d/ProcessGroupNCCL.cpp | 39 +++++++++++++------ .../distributed/c10d/ProcessGroupNCCL.hpp | 6 ++- 3 files changed, 54 insertions(+), 13 deletions(-) diff --git a/test/distributed/test_c10d_nccl.py b/test/distributed/test_c10d_nccl.py index 64c51f791fddd0..55b785a3552b5a 100644 --- a/test/distributed/test_c10d_nccl.py +++ b/test/distributed/test_c10d_nccl.py @@ -366,6 +366,28 @@ def test_nan_assert(self, type): # reset env os.environ["TORCH_NCCL_NAN_CHECK"] = "0" + @requires_nccl() + @skip_if_lt_x_gpu(2) + def test_nan_p2p(self): + # Putting NaN at recv buffer, program should not fail as NaN checker + # should not check on receive buffer + os.environ["TORCH_NCCL_NAN_CHECK"] = "1" + store = c10d.FileStore(self.file_name, self.world_size) + device = torch.device("cuda:%d" % self.rank) + c10d.init_process_group( + backend="nccl", store=store, rank=self.rank, world_size=self.world_size + ) + t = torch.ones(3, 4, dtype=torch.bfloat16, device=device) + if self.rank == 0: + c10d.send(t, 1) + elif self.rank == 1: + # Putting NaN at recv buffer + t[1, 1] = float("nan") + c10d.recv(t, 0) + c10d.destroy_process_group() + # reset env + os.environ["TORCH_NCCL_NAN_CHECK"] = "0" + @requires_nccl() @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs") def test_destruct_before_terminate_pg(self): diff --git a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp index abcf493e62fb8e..291eb7993ef2fb 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp +++ b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp @@ -2637,9 +2637,12 @@ c10::intrusive_ptr ProcessGroupNCCL::collective( PostProcess post, OpType opType, const char* profilingTitle, - bool avoidRecordStreams) { + bool avoidRecordStreams, + bool nanCheck) { // Environment setting by the user may add onto collective call's option avoidRecordStreams |= avoidRecordStreams_; + nanCheck &= enableNanCheck_; + c10::cuda::CaptureStatus capture_status = c10::cuda::currentStreamCaptureStatusMayInitCtx(); errorIfCapturingNonCapturableNCCL(capture_status); @@ -2693,7 +2696,7 @@ c10::intrusive_ptr ProcessGroupNCCL::collective( at::cuda::OptionalCUDAGuard gpuGuard; - if (enableNanCheck_) { + if (nanCheck) { checkForNan(input, ncclStream); } @@ -3126,7 +3129,9 @@ c10::intrusive_ptr ProcessGroupNCCL::pointToPoint( // is gpuGuard needed for the if block below, or can i swap them at::cuda::OptionalCUDAGuard gpuGuard; - if (enableNanCheck_) { + // Only check for NaN for send ops, for recv ops `tensor` can be a random + // placeholder + if (enableNanCheck_ && opType == OpType::SEND) { checkForNan(tensor, ncclStream); } @@ -3223,7 +3228,8 @@ c10::intrusive_ptr ProcessGroupNCCL::collective( Fn fn, OpType opType, const char* profilingTitle, - bool avoidRecordStreams) { + bool avoidRecordStreams, + bool nanCheck) { return collective( input, output, @@ -3234,7 +3240,8 @@ c10::intrusive_ptr ProcessGroupNCCL::collective( c10::intrusive_ptr& work) {}, opType, profilingTitle, - avoidRecordStreams); + avoidRecordStreams, + nanCheck); } template @@ -3484,6 +3491,9 @@ c10::intrusive_ptr ProcessGroupNCCL::broadcast( // avoidRecordStreams_ note: collective() will stash tensors. bool avoidRecordStreams = avoidRecordStreams_ || (!opts.asyncOp); + const auto root = opts.rootRank + opts.rootTensor; + bool nanCheck = (root == rank_); + return collective( tensor, tensor, @@ -3491,7 +3501,6 @@ c10::intrusive_ptr ProcessGroupNCCL::broadcast( at::Tensor& output, ncclComm_t comm, at::cuda::CUDAStream& stream) { - const auto root = opts.rootRank + opts.rootTensor; return ncclBcast( input.data_ptr(), input.numel(), @@ -3502,7 +3511,8 @@ c10::intrusive_ptr ProcessGroupNCCL::broadcast( }, OpType::BROADCAST, "nccl:broadcast", - avoidRecordStreams); + avoidRecordStreams, + nanCheck); } // _broadcast_oop adds an out-of-place broadcast in PGNCCL @@ -3522,6 +3532,9 @@ c10::intrusive_ptr ProcessGroupNCCL::_broadcast_oop( "Tensor input and output of _broadcast_oop must have the same number of elements "); } + const auto root = opts.rootRank + opts.rootTensor; + bool nanCheck = (root == rank_); + return collective( inputTensor, outputTensor, @@ -3529,7 +3542,6 @@ c10::intrusive_ptr ProcessGroupNCCL::_broadcast_oop( at::Tensor& output, ncclComm_t comm, at::cuda::CUDAStream& stream) { - const auto root = opts.rootRank + opts.rootTensor; return ncclBroadcast( input.data_ptr(), output.data_ptr(), @@ -3540,7 +3552,9 @@ c10::intrusive_ptr ProcessGroupNCCL::_broadcast_oop( stream.stream()); }, OpType::BROADCAST, - "nccl:_broadcast_oop"); + "nccl:_broadcast_oop", + /*avoidRecordStreams=*/false, + nanCheck); } c10::intrusive_ptr ProcessGroupNCCL::reduce( @@ -4491,6 +4505,9 @@ c10::intrusive_ptr ProcessGroupNCCL::scatter( // inputs, which == inputTensors[0] on the root rank where it matters. bool avoidRecordStreams = avoidRecordStreams_ || (!opts.asyncOp); + const auto root = opts.rootRank; + bool nanCheck = (rank_ == root); + return collective( outputTensor, inputs[0], // just to fit the collective interface @@ -4498,7 +4515,6 @@ c10::intrusive_ptr ProcessGroupNCCL::scatter( at::Tensor& /* unused */, ncclComm_t comm, at::cuda::CUDAStream& stream) { - const auto root = opts.rootRank; if (getRank() == root) { if (!avoidRecordStreams) { for (auto input : inputs) { @@ -4512,7 +4528,8 @@ c10::intrusive_ptr ProcessGroupNCCL::scatter( }, OpType::SCATTER, "nccl:scatter", - avoidRecordStreams); + avoidRecordStreams, + nanCheck); } c10::intrusive_ptr ProcessGroupNCCL::recvAnysource( diff --git a/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp b/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp index 2ba68cda9cc3ea..b6635157d75e1b 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp +++ b/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp @@ -762,7 +762,8 @@ class TORCH_API ProcessGroupNCCL : public Backend { Fn fn, OpType opType, const char* profilingTitle = nullptr, - bool avoidRecordStreams = false); + bool avoidRecordStreams = false, + bool nanCheck = true); template c10::intrusive_ptr collective( @@ -773,7 +774,8 @@ class TORCH_API ProcessGroupNCCL : public Backend { PostProcess post, OpType opType, const char* profilingTitle = nullptr, - bool avoidRecordStreams = false); + bool avoidRecordStreams = false, + bool nanCheck = true); template c10::intrusive_ptr collectiveCoalesced( From fa1e876bc25f0616fb2dbc1698fca228194ee554 Mon Sep 17 00:00:00 2001 From: rzou Date: Wed, 28 Aug 2024 10:17:00 -0700 Subject: [PATCH 0177/1018] Skip test_mutable_custom_op_fixed_layout2 on ROCM (#134690) ROCM doesn't trigger the layout optimization that makes the test case valid so we're going to skip the checks. Should fix the following (I'll close them later) - https://github.com/pytorch/pytorch/issues/134481 - https://github.com/pytorch/pytorch/issues/134519 Pull Request resolved: https://github.com/pytorch/pytorch/pull/134690 Approved by: https://github.com/FindHao ghstack dependencies: #134466, #134490, #134491 --- test/inductor/test_torchinductor.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/inductor/test_torchinductor.py b/test/inductor/test_torchinductor.py index 512d606920df2d..2019d378063404 100644 --- a/test/inductor/test_torchinductor.py +++ b/test/inductor/test_torchinductor.py @@ -10613,8 +10613,8 @@ def fn(x): with torch.no_grad(): self.common(fn, (inp,), check_lowp=False) - # Dynamic shapes invalidate this test case - if torch._dynamo.config.assume_static_by_default: + # Dynamic shapes and rocm invalidate this test case + if torch._dynamo.config.assume_static_by_default and not TEST_WITH_ROCM: # For this test to be valid, Inductor must have changed the conv # to be channels-last. If this assertion ever fails then we need # a new test case. From 355ee51b7a5e882ce555f56e024d08c05006decb Mon Sep 17 00:00:00 2001 From: Ke Wen Date: Mon, 26 Aug 2024 15:03:49 -0700 Subject: [PATCH 0178/1018] [3/N] Set correct device to CUDA guards (#134357) In `collective()`, `pointToPoint()` and `collectiveCoalesced()`, CUDA guards were created with an unset (default) CUDA device. This is the reason for the IMA facing the NaN checker in issue https://github.com/pytorch/pytorch/issues/134062. With this fix, `torch.cuda.set_device(device)` is not needed to work around the IMA. Also refactored a couple places where the guard is created -- preferably we create the guard with a known device, rather than setting the device later. Pull Request resolved: https://github.com/pytorch/pytorch/pull/134357 Approved by: https://github.com/wconstab, https://github.com/shuqiangzhang ghstack dependencies: #134300, #134345 --- test/distributed/test_c10d_nccl.py | 19 +++++++++++++++++++ .../distributed/c10d/ProcessGroupNCCL.cpp | 18 ++++++++++-------- 2 files changed, 29 insertions(+), 8 deletions(-) diff --git a/test/distributed/test_c10d_nccl.py b/test/distributed/test_c10d_nccl.py index 55b785a3552b5a..3a575729cdff78 100644 --- a/test/distributed/test_c10d_nccl.py +++ b/test/distributed/test_c10d_nccl.py @@ -350,6 +350,7 @@ def test_close_pg(self): @parametrize("type", [torch.float16, torch.float32, torch.float64, torch.bfloat16]) @skip_if_rocm def test_nan_assert(self, type): + # Expecting a device-side error when NaN is detected os.environ["TORCH_NCCL_NAN_CHECK"] = "1" store = c10d.FileStore(self.file_name, self.world_size) pg = self._create_process_group_nccl(store, self.opts()) @@ -388,6 +389,24 @@ def test_nan_p2p(self): # reset env os.environ["TORCH_NCCL_NAN_CHECK"] = "0" + @requires_nccl() + @skip_if_lt_x_gpu(2) + def test_nan_check(self): + # Not expecting an error, NaN check should not make legit code fail + os.environ["TORCH_NCCL_NAN_CHECK"] = "1" + store = c10d.FileStore(self.file_name, self.world_size) + device = torch.device("cuda:%d" % self.rank) + c10d.init_process_group( + backend="nccl", store=store, rank=self.rank, world_size=self.world_size + ) + x = torch.ones((10,), dtype=torch.bfloat16, device=device) * self.rank + t = torch.ones(3, 4, dtype=torch.bfloat16, device=device) + c10d.broadcast(x, src=0) + c10d.all_reduce(t) + c10d.destroy_process_group() + # reset env + os.environ["TORCH_NCCL_NAN_CHECK"] = "0" + @requires_nccl() @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs") def test_destruct_before_terminate_pg(self): diff --git a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp index 291eb7993ef2fb..8f741185a2c584 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp +++ b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp @@ -1148,6 +1148,10 @@ void ProcessGroupNCCL::abortCommsFromMap( at::cuda::OptionalCUDAGuard gpuGuard; at::DeviceIndex deviceIndex = getIndexFromDeviceKey(devName); if (deviceIndex >= 0) { + // For P2P comms, the deviceIndex could be -1 (invalid), as the keys in + // the map could be non deviceIndex, but rank to rank numbers. So we + // indeed need to check if deviceIndex >= 0 + // TODO: fix `getIndexFromDeviceKey` or fix `DeviceKey` gpuGuard.set_index(deviceIndex); } LOG(INFO) << logPrefix() << "ProcessGroupNCCL destroying ncclComm_ " @@ -2141,7 +2145,9 @@ std::shared_ptr ProcessGroupNCCL::getNCCLComm( bool batchP2P = ncclActiveGroupCounter_ > 0; bool singleP2POp = isP2POp(opType, batchP2P); - at::cuda::OptionalCUDAGuard gpuGuard; + // Get the device index + auto deviceIndex = device.index(); + at::cuda::OptionalCUDAGuard gpuGuard(device); // [Group Start/End Note] This is used to ensure that nccl communicator will // be created before communication primitives are called. Let's look at this @@ -2181,10 +2187,6 @@ std::shared_ptr ProcessGroupNCCL::getNCCLComm( rank = p2pRank; } - // Get the device index - auto deviceIndex = device.index(); - gpuGuard.set_index(deviceIndex); - #ifdef NCCL_HAS_COMM_SPLIT if (options_->split_from) { TORCH_CHECK( @@ -2694,7 +2696,7 @@ c10::intrusive_ptr ProcessGroupNCCL::collective( work->stashed_for_allocator_safety_->push_back(input); } - at::cuda::OptionalCUDAGuard gpuGuard; + at::cuda::OptionalCUDAGuard gpuGuard(device); if (nanCheck) { checkForNan(input, ncclStream); @@ -2859,7 +2861,7 @@ c10::intrusive_ptr ProcessGroupNCCL::collectiveCoalesced( std::make_shared>(inputs); } - at::cuda::OptionalCUDAGuard gpuGuard; + at::cuda::OptionalCUDAGuard gpuGuard(device); // Start event should only be recorded before the ncclGroupStart() (which // happens inside AutoNcclGroup guard below) @@ -3127,7 +3129,7 @@ c10::intrusive_ptr ProcessGroupNCCL::pointToPoint( } // is gpuGuard needed for the if block below, or can i swap them - at::cuda::OptionalCUDAGuard gpuGuard; + at::cuda::OptionalCUDAGuard gpuGuard(device); // Only check for NaN for send ops, for recv ops `tensor` can be a random // placeholder From fdfe7f5b436fc2d4d8136cbb538991970f35b89d Mon Sep 17 00:00:00 2001 From: rzou Date: Wed, 28 Aug 2024 10:17:01 -0700 Subject: [PATCH 0179/1018] Improve opcheck docs. (#134692) Fixes https://github.com/pytorch/pytorch/issues/134119 From user feedback, it's difficult to understand what the tests do. We clarify the docs more. Pull Request resolved: https://github.com/pytorch/pytorch/pull/134692 Approved by: https://github.com/albanD ghstack dependencies: #134466, #134490, #134491, #134690 --- torch/library.py | 36 +++++++++++++++++++++++++++--------- 1 file changed, 27 insertions(+), 9 deletions(-) diff --git a/torch/library.py b/torch/library.py index 3fc44f191ba383..4ac89c5259b071 100644 --- a/torch/library.py +++ b/torch/library.py @@ -1228,17 +1228,35 @@ def opcheck( ``opcheck`` tests these metadata and properties. Concretely, we test the following: - - test_schema: if the operator's schema is correct. - - test_autograd_registration: if autograd was registered correctly. + + - test_schema: If the schema matches the implementation of + the operator. For example: if the schema specifies a Tensor is mutated, + then we check the implementation mutates the Tensor. If the schema + specifies that we return a new Tensor, then we check that the + implementation returns a new Tensor (instead of an existing one or + a view of an existing one). + - test_autograd_registration: If the operator supports training + (autograd): we check that its autograd formula is registered via + torch.library.register_autograd or a manual registration to one + or more DispatchKey::Autograd keys. Any other DispatchKey-based + registrations may lead to undefined behavior. - test_faketensor: If the operator has a FakeTensor kernel - (and if it is correct). The FakeTensor kernel is necessary ( - but not sufficient) for the operator to work with PyTorch compilation - APIs (torch.compile/export/FX). + (and if it is correct). The FakeTensor kernel is necessary ( + but not sufficient) for the operator to work with PyTorch compilation + APIs (torch.compile/export/FX). We check that a FakeTensor kernel + (also sometimes known as a meta kernel) was registered for the + operator and that it is correct. This test takes the result of + running the operator on real tensors and the result of running + the operator on FakeTensors and checks that they have the same + Tensor metadata (sizes/strides/dtype/device/etc). - test_aot_dispatch_dynamic: If the operator has correct behavior - with PyTorch compilation APIs (torch.compile/export/FX). - This checks that the outputs (and gradients, if applicable) are the - same under eager-mode PyTorch and torch.compile. - This test is a superset of ``test_faketensor``. + with PyTorch compilation APIs (torch.compile/export/FX). + This checks that the outputs (and gradients, if applicable) are the + same under eager-mode PyTorch and torch.compile. + This test is a superset of ``test_faketensor`` and is an e2e test; + other things it tests are that the operator supports + functionalization and that the backward pass (if it exists) also + supports FakeTensor and functionalization. For best results, please call ``opcheck`` multiple times with a representative set of inputs. If your operator supports From fd7d5f8d1ba2ea257d5148e64a3a45396a64eafc Mon Sep 17 00:00:00 2001 From: rzou Date: Wed, 28 Aug 2024 10:17:01 -0700 Subject: [PATCH 0180/1018] Improve custom ops aliasing error message (#134688) Fixes https://github.com/pytorch/pytorch/issues/134278 Test Plan: - tested locally Pull Request resolved: https://github.com/pytorch/pytorch/pull/134688 Approved by: https://github.com/yushangdi ghstack dependencies: #134466, #134490, #134491, #134690, #134692 --- torch/_library/custom_ops.py | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/torch/_library/custom_ops.py b/torch/_library/custom_ops.py index ced50a7892a010..2eb45a78a5edce 100644 --- a/torch/_library/custom_ops.py +++ b/torch/_library/custom_ops.py @@ -332,12 +332,16 @@ def backend_impl(*args, **kwargs): fn = self._backend_fns[device_type] module = inspect.getmodule(fn) raise RuntimeError( - f"Tensors returned from custom ops (1) must not " - f"be inputs to the custom op and (2) may not alias " - f"any inputs or other returns. Please clone the " - f"the offending output tensors (e.g. output.clone()) " - f"or refactor your code. " - f"Offending op: {self._name} (with implementation in {module})" + f"{self._name} (with implementation in {module}): " + f"The output of this custom operator (1) must not " + f"also be an input to this custom operator and " + f"(2) may not alias any inputs to this custom operator " + f"or other returns. " + f"The most common way to trigger this error is if " + f"we have y = custom_op(x) and y and x are the same Tensor. " + f"Please instead return a clone of the offending output " + f"tensor(s) (e.g. return x.clone()) or refactor the custom " + f"operator to not return y." ) storages.add(key) return result From 4cd392292665b795689186684f051b77295e0ad5 Mon Sep 17 00:00:00 2001 From: Tugsbayasgalan Manlaibaatar Date: Wed, 28 Aug 2024 07:01:24 -0700 Subject: [PATCH 0181/1018] Add export_for_training as public API (#134677) Differential Revision: [D61912084](https://our.internmc.facebook.com/intern/diff/D61912084) Pull Request resolved: https://github.com/pytorch/pytorch/pull/134677 Approved by: https://github.com/avikchaudhuri, https://github.com/zhxchen17 --- test/export/test_export.py | 12 +-- .../test_export_training_ir_to_run_decomp.py | 11 +-- torch/_export/__init__.py | 8 +- torch/export/__init__.py | 86 +++++++++++++++++++ 4 files changed, 102 insertions(+), 15 deletions(-) diff --git a/test/export/test_export.py b/test/export/test_export.py index 2393443b5f0b0f..3997de36eb4b5b 100644 --- a/test/export/test_export.py +++ b/test/export/test_export.py @@ -1437,7 +1437,7 @@ def forward(self, x, y): x_linear = self.linear(x_conv) return x_linear.cos() + y_conv_1d.sum() - ep = torch.export._trace._export_for_training( + ep = torch.export.export_for_training( Foo(), (torch.randn(20, 16, 50, 100), torch.randn(20, 16, 50)) ) ep_has_linear_convd = ep.run_decompositions( @@ -1570,7 +1570,7 @@ def forward(self, x): return self.linear(x) eager_model = Foo() - ep_for_training = torch.export._trace._export_for_training( + ep_for_training = torch.export.export_for_training( eager_model, (torch.ones(2, 2),) ) self.assertExpectedInline( @@ -1609,7 +1609,7 @@ def forward(self, x): eager_model_for_export = Foo() eager_model_for_testing = Foo() - ep_for_training = torch.export._trace._export_for_training( + ep_for_training = torch.export.export_for_training( eager_model_for_export, (torch.ones(4, 4),) ) self.assertExpectedInline( @@ -1654,7 +1654,7 @@ def forward(self, x): eager_model_for_export_training = Foo() eager_model_for_export_inference = Foo() eager_model_for_testing = Foo() - ep_for_training = torch.export._trace._export_for_training( + ep_for_training = torch.export.export_for_training( eager_model_for_export_training, (torch.ones(4, 4),), dynamic_shapes=({0: Dim("x")},), @@ -1691,7 +1691,7 @@ def forward(self, container): return x + y + self.buffer.sum() eager_model = Foo() - ep_for_training = torch.export._trace._export_for_training( + ep_for_training = torch.export.export_for_training( eager_model, ([torch.ones(4, 4), torch.ones(4, 4)],), ) @@ -1717,7 +1717,7 @@ def forward(self, x): return self.linear(x) + self.buffer.sum() eager_model = Foo() - ep_for_training = torch.export._trace._export_for_training( + ep_for_training = torch.export.export_for_training( eager_model, (torch.ones(2, 2),), ) diff --git a/test/export/test_export_training_ir_to_run_decomp.py b/test/export/test_export_training_ir_to_run_decomp.py index 7ff8beaa47094c..6f780c1fec85bf 100644 --- a/test/export/test_export_training_ir_to_run_decomp.py +++ b/test/export/test_export_training_ir_to_run_decomp.py @@ -1,19 +1,20 @@ # Owner(s): ["oncall: export"] +import torch + try: from . import test_export, testing except ImportError: import test_export - import testing -from torch.export._trace import _export_for_training + import testing test_classes = {} def mocked_training_ir_to_run_decomp_export_strict(*args, **kwargs): - ep = _export_for_training(*args, **kwargs) + ep = torch.export.export_for_training(*args, **kwargs) return ep.run_decompositions( {}, _preserve_ops=testing._COMPOSITE_OPS_THAT_CAN_BE_PRESERVED_TESTING_ONLY ) @@ -21,9 +22,9 @@ def mocked_training_ir_to_run_decomp_export_strict(*args, **kwargs): def mocked_training_ir_to_run_decomp_export_non_strict(*args, **kwargs): if "strict" in kwargs: - ep = _export_for_training(*args, **kwargs) + ep = torch.export.export_for_training(*args, **kwargs) else: - ep = _export_for_training(*args, **kwargs, strict=False) + ep = torch.export.export_for_training(*args, **kwargs, strict=False) return ep.run_decompositions( {}, _preserve_ops=testing._COMPOSITE_OPS_THAT_CAN_BE_PRESERVED_TESTING_ONLY ) diff --git a/torch/_export/__init__.py b/torch/_export/__init__.py index 4e9f7433f8be2c..fea70b061a73d3 100644 --- a/torch/_export/__init__.py +++ b/torch/_export/__init__.py @@ -65,9 +65,9 @@ def capture_pre_autograd_graph_warning(): log.warning("| !!! WARNING !!! |") log.warning("+============================+") log.warning("capture_pre_autograd_graph() is deprecated and doesn't provide any function guarantee moving forward.") - log.warning("Please switch to use torch.export._trace._export_for_training instead.") + log.warning("Please switch to use torch.export.export_for_training instead.") if config.is_fbcode(): - log.warning("Unless the unittest is in the blocklist, capture_pre_autograd_graph() will fallback to torch.export._trace._export_for_training.") # noqa: B950 + log.warning("Unless the unittest is in the blocklist, capture_pre_autograd_graph() will fallback to torch.export.export_for_training.") # noqa: B950 @compatibility(is_backward_compatible=False) @@ -128,9 +128,9 @@ def capture_pre_autograd_graph( if capture_pre_autograd_graph_using_training_ir(): @lru_cache def print_export_warning(): - log.warning("Using torch.export._trace._export_for_training(...,strict=True)") + log.warning("Using torch.export.export_for_training(...,strict=True)") print_export_warning() - module = torch.export._trace._export_for_training(f, args, kwargs, dynamic_shapes=dynamic_shapes, strict=True).module() + module = torch.export.export_for_training(f, args, kwargs, dynamic_shapes=dynamic_shapes, strict=True).module() else: log_export_usage(event="export.private_api", flags={"capture_pre_autograd_graph"}) diff --git a/torch/export/__init__.py b/torch/export/__init__.py index 1e365a990c8290..2cbbc472968a83 100644 --- a/torch/export/__init__.py +++ b/torch/export/__init__.py @@ -51,6 +51,7 @@ "ModuleCallSignature", "dims", "export", + "export_for_training", "load", "register_dataclass", "save", @@ -69,6 +70,91 @@ PassType = Callable[[torch.fx.GraphModule], Optional[PassResult]] +def export_for_training( + mod: torch.nn.Module, + args: Tuple[Any, ...], + kwargs: Optional[Dict[str, Any]] = None, + *, + dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any], List[Any]]] = None, + strict: bool = True, + preserve_module_call_signature: Tuple[str, ...] = (), +) -> ExportedProgram: + """ + :func:`export_for_training` takes any nn.Module along with example inputs, and produces a traced graph representing + only the Tensor computation of the function in an Ahead-of-Time (AOT) fashion, + which can subsequently be executed with different inputs or serialized. The + traced graph (1) produces normalized operators in the all ATen operator set + (as well as any user-specified custom operators), (2) has eliminated all Python control + flow and data structures (with certain exceptions), and (3) records the set of + shape constraints needed to show that this normalization and control-flow elimination + is sound for future inputs. This API is intended for PT2 quantization training use cases + and will soon be the default IR of torch.export.export in the near future. + + **Soundness Guarantee** + + See :func:`export()` docstring for more details. + + Args: + mod: We will trace the forward method of this module. + + args: Example positional inputs. + + kwargs: Optional example keyword inputs. + + dynamic_shapes: + An optional argument where the type should either be: + 1) a dict from argument names of ``f`` to their dynamic shape specifications, + 2) a tuple that specifies dynamic shape specifications for each input in original order. + If you are specifying dynamism on keyword args, you will need to pass them in the order that + is defined in the original function signature. + + The dynamic shape of a tensor argument can be specified as either + (1) a dict from dynamic dimension indices to :func:`Dim` types, where it is + not required to include static dimension indices in this dict, but when they are, + they should be mapped to None; or (2) a tuple / list of :func:`Dim` types or None, + where the :func:`Dim` types correspond to dynamic dimensions, and static dimensions + are denoted by None. Arguments that are dicts or tuples / lists of tensors are + recursively specified by using mappings or sequences of contained specifications. + + strict: When enabled (default), the export function will trace the program through + TorchDynamo which will ensure the soundness of the resulting graph. Otherwise, the + exported program will not validate the implicit assumptions baked into the graph and + may cause behavior divergence between the original model and the exported one. This is + useful when users need to workaround bugs in the tracer, or simply want incrementally + enable safety in their models. Note that this does not affect the resulting IR spec + to be different and the model will be serialized in the same way regardless of what value + is passed here. + WARNING: This option is experimental and use this at your own risk. + + Returns: + An :class:`ExportedProgram` containing the traced callable. + + **Acceptable input/output types** + + Acceptable types of inputs (for ``args`` and ``kwargs``) and outputs include: + + - Primitive types, i.e. ``torch.Tensor``, ``int``, ``float``, ``bool`` and ``str``. + - Dataclasses, but they must be registered by calling :func:`register_dataclass` first. + - (Nested) Data structures comprising of ``dict``, ``list``, ``tuple``, ``namedtuple`` and + ``OrderedDict`` containing all above types. + + """ + from ._trace import _export_for_training + + if not isinstance(mod, torch.nn.Module): + raise ValueError( + f"Expected `mod` to be an instance of `torch.nn.Module`, got {type(mod)}." + ) + return _export_for_training( + mod, + args, + kwargs, + dynamic_shapes, + strict=strict, + preserve_module_call_signature=preserve_module_call_signature, + ) + + def export( mod: torch.nn.Module, args: Tuple[Any, ...], From afe6c543a922305271574c6afa47102b2a56017b Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Wed, 28 Aug 2024 23:10:07 +0000 Subject: [PATCH 0182/1018] Revert "Add MaskedTensor passthrough: unfold, F.Unfold, F.Fold, stack (#125262)" This reverts commit f685018ea9d08f98cbd7106028db134f967f74d3. Reverted https://github.com/pytorch/pytorch/pull/125262 on behalf of https://github.com/ZainRizvi due to Hi, this PR appears to be calling maskedtensor tests to fail on main. Please rebase your changes onto the latest trunk build to repro the failure. test_maskedtensor.py::TestOperatorsCUDA::test_like_empty_like_layout1_cuda_bool [GH job link](https://github.com/pytorch/pytorch/actions/runs/10604716811/job/29393256312) [HUD commit link](https://hud.pytorch.org/pytorch/pytorch/commit/f685018ea9d08f98cbd7106028db134f967f74d3) ([comment](https://github.com/pytorch/pytorch/pull/125262#issuecomment-2316387447)) --- aten/src/ATen/native/Col2Im.cpp | 2 +- aten/src/ATen/native/Im2Col.cpp | 2 +- aten/src/ATen/native/cuda/Col2Im.cu | 2 +- aten/src/ATen/native/cuda/Im2Col.cu | 2 +- docs/source/masked.rst | 3 -- test/test_maskedtensor.py | 50 +++---------------- torch/masked/maskedtensor/passthrough.py | 5 -- .../_internal/common_methods_invocations.py | 4 +- 8 files changed, 14 insertions(+), 56 deletions(-) diff --git a/aten/src/ATen/native/Col2Im.cpp b/aten/src/ATen/native/Col2Im.cpp index 51e005c2901b93..0a98ee5c46ca77 100644 --- a/aten/src/ATen/native/Col2Im.cpp +++ b/aten/src/ATen/native/Col2Im.cpp @@ -144,7 +144,7 @@ static void col2im_out_cpu_template( output.resize_({batch_size, n_output_plane, output_height, output_width}); - AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND3(kBFloat16, kHalf, kBool, + AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(kBFloat16, kHalf, input.scalar_type(), "col2im_out_cpu", [&] { Tensor input_n = Tensor(); Tensor output_n = Tensor(); diff --git a/aten/src/ATen/native/Im2Col.cpp b/aten/src/ATen/native/Im2Col.cpp index 25eb4d6787240a..dac2ee6e3f103a 100644 --- a/aten/src/ATen/native/Im2Col.cpp +++ b/aten/src/ATen/native/Im2Col.cpp @@ -94,7 +94,7 @@ static void im2col_out_cpu_template( output.resize_({batch_size, n_output_plane, output_length}); - AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND3(kBFloat16, kHalf, kBool, + AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(kBFloat16, kHalf, input.scalar_type(), "im2col_out_cpu", [&] { Tensor input_n; Tensor output_n; diff --git a/aten/src/ATen/native/cuda/Col2Im.cu b/aten/src/ATen/native/cuda/Col2Im.cu index e390666b307aa1..bb6d4748deb1f5 100644 --- a/aten/src/ATen/native/cuda/Col2Im.cu +++ b/aten/src/ATen/native/cuda/Col2Im.cu @@ -102,7 +102,7 @@ void col2im_out_cuda_template( output.resize_({batch_size, n_output_plane, output_height, output_width}); int64_t output_batch_stride = output.stride(0); - AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND3(kHalf, kBFloat16, kBool, + AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(kHalf, kBFloat16, input.scalar_type(), "col2im_out_cuda", [&] { int64_t height_col = (output_height + 2 * pad_height - (dilation_height * (kernel_height - 1) + 1)) / diff --git a/aten/src/ATen/native/cuda/Im2Col.cu b/aten/src/ATen/native/cuda/Im2Col.cu index 796eda97b37331..312ad893c0d81e 100644 --- a/aten/src/ATen/native/cuda/Im2Col.cu +++ b/aten/src/ATen/native/cuda/Im2Col.cu @@ -103,7 +103,7 @@ static void im2col_out_cuda_template( output.resize_({batch_size, n_output_plane, output_length}); // Launch kernel - AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND3(kHalf, kBFloat16, kBool, + AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(kHalf, kBFloat16, input.scalar_type(), "im2col_out_cuda", [&] { Tensor input_n; Tensor output_n; diff --git a/docs/source/masked.rst b/docs/source/masked.rst index 086e1f6c8b25c2..00dadd2abf6885 100644 --- a/docs/source/masked.rst +++ b/docs/source/masked.rst @@ -283,11 +283,9 @@ The following ops are currently supported: kron meshgrid narrow - nn.functional.unfold ravel select split - stack t transpose vsplit @@ -296,7 +294,6 @@ The following ops are currently supported: Tensor.expand_as Tensor.reshape Tensor.reshape_as - Tensor.unfold Tensor.view Other functions diff --git a/test/test_maskedtensor.py b/test/test_maskedtensor.py index 50d03b306bb0cc..0817822be457b2 100644 --- a/test/test_maskedtensor.py +++ b/test/test_maskedtensor.py @@ -68,18 +68,6 @@ def _compare_mts(mt1, mt2, rtol=1e-05, atol=1e-08): if not _tensors_match(a, b, exact=False, rtol=rtol, atol=atol): raise ValueError("The data in MaskedTensor mt1 and MaskedTensor mt2 do not match") -def _compare_forward_backward(data, mask, fn): - mt = masked_tensor(data, mask, requires_grad=True) - masked_res = fn(mt) - masked_res.sum().backward() - - t = data.masked_fill(~mask, float("-inf")).detach().clone().requires_grad_() - tensor_res = fn(t) - tensor_res.sum().backward() - - _compare_mt_t(masked_res, tensor_res) - _compare_mt_t(mt.grad, t.grad, atol=1e-06) - def _create_random_mask(shape, device): return make_tensor(shape, device=device, dtype=torch.bool) @@ -178,8 +166,15 @@ def test_softmax(self, device): ], device=device ) + mt = masked_tensor(data, mask, requires_grad=True) + masked_res = torch.softmax(mt, -1) + masked_res.sum().backward() + xinf = data.masked_fill(~mask, float("-inf")).detach().clone().requires_grad_() + tensor_res = torch.softmax(xinf, -1) + tensor_res.sum().backward() - _compare_forward_backward(data, mask, lambda t: torch.softmax(t, -1)) + _compare_mt_t(masked_res, tensor_res) + _compare_mt_t(mt.grad, xinf.grad, atol=1e-06) def test_where(self, device): data = torch.tensor([-10.0, -5, 0, 5, 10, 50, 60, 70, 80, 90, 100], device=device) @@ -199,35 +194,6 @@ def test_where(self, device): _compare_mt_t(mx.grad, x.grad) _compare_mt_t(my.grad, y.grad) - def test_unfold(self, device): - data = torch.rand(5, 5, device=device) - mask = torch.rand(5, 5, device=device) > 0.5 - _compare_forward_backward(data, mask, lambda t: t.unfold(1, 2, 2)) - - def test_nn_unfold(self, device): - data = torch.rand(2, 5, 3, 4, device=device) - mask = torch.rand(2, 5, 3, 4, device=device) > 0.5 - _compare_forward_backward(data, mask, lambda t: torch.nn.functional.unfold(t, kernel_size=(2, 3))) - - def test_stack(self, device): - masked_tensors = [ - masked_tensor( - torch.rand(2, 5, 3, 4, device=device), - torch.rand(2, 5, 3, 4, device=device) > 0.5, - requires_grad=True, - ) for _ in range(3) - ] - - data_tensors = [mt.get_data().detach().clone().requires_grad_() for mt in masked_tensors] - masked_res = torch.stack(masked_tensors) - tensor_res = torch.stack(data_tensors) - - masked_res.sum().backward() - tensor_res.sum().backward() - _compare_mt_t(masked_res, tensor_res) - for mt, t in zip(masked_tensors, data_tensors): - _compare_mt_t(mt.grad, t.grad, atol=1e-06) - def test_to_sparse(self, device): for sample in _generate_sample_data(device=device): data = sample.input diff --git a/torch/masked/maskedtensor/passthrough.py b/torch/masked/maskedtensor/passthrough.py index ba13f50c1fee9c..4a2e79456c86db 100644 --- a/torch/masked/maskedtensor/passthrough.py +++ b/torch/masked/maskedtensor/passthrough.py @@ -30,11 +30,6 @@ torch.ops.aten._reshape_alias, torch.ops.aten.cat, torch.ops.aten.unsqueeze, - torch.ops.aten.unfold, - torch.ops.aten.unfold_backward, - torch.ops.aten.im2col, - torch.ops.aten.col2im, - torch.ops.aten.stack, ] diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py index fecab41a741f33..3049cd8d8f5753 100644 --- a/torch/testing/_internal/common_methods_invocations.py +++ b/torch/testing/_internal/common_methods_invocations.py @@ -15266,8 +15266,8 @@ def sample_inputs_alias_copy(op_info, device, dtype, requires_grad, **kwargs): autodiff_nonfusible_nodes=["aten::hardswish"]), OpInfo('nn.functional.unfold', aten_name='im2col', - dtypes=floating_and_complex_types_and(torch.half, torch.bfloat16, torch.bool), - dtypesIfCUDA=floating_and_complex_types_and(torch.half, torch.bfloat16, torch.bool), + dtypes=floating_and_complex_types_and(torch.half, torch.bfloat16), + dtypesIfCUDA=floating_and_complex_types_and(torch.half, torch.bfloat16), sample_inputs_func=sample_inputs_nn_unfold, # Runs very slowly on slow gradcheck - alternatively reduce input sizes gradcheck_fast_mode=True, From e2b32fb5f0aeb36b0e58f383b17c6ea64d78411d Mon Sep 17 00:00:00 2001 From: Mikayla Gawarecki Date: Wed, 28 Aug 2024 13:49:35 -0700 Subject: [PATCH 0183/1018] Add torch.serialization.skip_data context manager (#134504) ## Semantic The semantic is (1) By default `torch.serialization.skip_data(materialize_fake_tensors=False)` will make `torch.save` skip writing storages (but reserve space for them in the checkpoint). ```python import torch import torch.nn as nn sd = nn.Linear(3, 5).state_dict() with torch.serialization.skip_data(): torch.save(sd, 'foo.pt') print(torch.load('foo.pt', weights_only=True)) ``` (2) With `torch.serialization.skip_data(materialize_fake_tensors=True)`If FakeTensor is passed to `torch.save` the pickler will treat these FakeTensors as being "materialized" space will be reserved in the checkpoint for the associated storage bytes, and when loading the type will be Tensor instead of FakeTensor) ```python import torch import torch.nn as nn from torch._subclasses.fake_tensor import FakeTensorMode with FakeTensorMode(): m = nn.Linear(3, 5, dtype=torch.float16, device='cuda') sd = m.state_dict() with torch.serialization.skip_data(materialize_fake_tensors=True): torch.save(sd, 'bla.pt') print(torch.load('bla.pt', weights_only=True)) # OrderedDict([('weight', tensor([[0., 0., 0.], # [0., 0., 0.], # [0., 0., 0.], # [0., 0., 0.], # [0., 0., 0.]], device='cuda:0', dtype=torch.float16)), ('bias', tensor([0., 0., 0., 0., 0.], device='cuda:0', dtype=torch.float16))]) ``` ## Follow Ups - [ ] `torch.load` semantic for skip_data context manager - [ ] Mechanism for getting offsets of storages saved via this method (for writing in a separate pass) Pull Request resolved: https://github.com/pytorch/pytorch/pull/134504 Approved by: https://github.com/albanD --- docs/source/notes/serialization.rst | 1 + ...cpp_extensions_open_device_registration.py | 12 +- test/test_serialization.py | 88 ++++++++++++++ torch/_tensor.py | 66 +++++++++-- torch/_utils.py | 6 +- torch/serialization.py | 109 ++++++++++++++++-- torch/storage.py | 4 + 7 files changed, 260 insertions(+), 26 deletions(-) diff --git a/docs/source/notes/serialization.rst b/docs/source/notes/serialization.rst index 5541d28bdcafa0..c05dc028a471c9 100644 --- a/docs/source/notes/serialization.rst +++ b/docs/source/notes/serialization.rst @@ -398,3 +398,4 @@ The following utility functions are related to serialization: .. autofunction:: clear_safe_globals .. autofunction:: get_safe_globals .. autoclass:: safe_globals +.. autoclass:: skip_data diff --git a/test/test_cpp_extensions_open_device_registration.py b/test/test_cpp_extensions_open_device_registration.py index 616a6e0f4b5518..4e86ed458b0788 100644 --- a/test/test_cpp_extensions_open_device_registration.py +++ b/test/test_cpp_extensions_open_device_registration.py @@ -540,7 +540,7 @@ def test_open_device_tensorlist_type_fallback(self): # call _fused_adamw_ with undefined tensor. self.module.fallback_with_undefined_tensor() - def test_open_device_numpy_serialization_map_location(self): + def test_open_device_numpy_serialization(self): torch.utils.rename_privateuse1_backend("foo") device = self.module.custom_device() default_protocol = torch.serialization.DEFAULT_PROTOCOL @@ -553,6 +553,7 @@ def test_open_device_numpy_serialization_map_location(self): self.assertTrue( rebuild_func is torch._utils._rebuild_device_tensor_from_numpy ) + # Test map_location with TemporaryFileName() as f: torch.save(sd, f) with safe_globals( @@ -569,6 +570,15 @@ def test_open_device_numpy_serialization_map_location(self): sd_loaded = torch.load(f, map_location="cpu") self.assertTrue(sd_loaded["x"].is_cpu) + # Test metadata_only + with TemporaryFileName() as f: + with self.assertRaisesRegex( + RuntimeError, + "Cannot serialize tensors on backends with no storage under skip_data context manager", + ): + with torch.serialization.skip_data(): + torch.save(sd, f) + if __name__ == "__main__": common.run_tests() diff --git a/test/test_serialization.py b/test/test_serialization.py index a041473d195b19..3ba96b80541d88 100644 --- a/test/test_serialization.py +++ b/test/test_serialization.py @@ -1,5 +1,6 @@ # Owner(s): ["module: serialization"] +import contextlib import copy import gc import gzip @@ -19,6 +20,7 @@ from pathlib import Path import torch +from torch._subclasses.fake_tensor import FakeTensorMode, FakeTensorConverter from torch._utils import _rebuild_tensor from torch._utils_internal import get_file_path_2 from torch.serialization import ( @@ -27,6 +29,7 @@ LoadEndianness, safe_globals, set_default_load_endianness, + skip_data, SourceChangeWarning, ) from torch.testing._internal.common_device_type import instantiate_device_type_tests @@ -4212,6 +4215,91 @@ def test_filewriter_metadata_writing(self, filename): sd_loaded_ref = torch.load(f) self.assertEqual(sd_loaded, sd_loaded_ref) + @parametrize("materialize_fake", (True, False)) + def test_skip_data_serialization(self, materialize_fake): + # Create one tensor that uses each of the paths in __reduce_ex__ that should work + t_device = "cuda" if torch.cuda.is_available() else "cpu" + t_v2 = torch.randn(2, 3, device=t_device) + t_v3 = torch.randn(2, 3, dtype=torch.complex32, device=t_device) + i = torch.tensor([[0, 1, 1], + [2, 0, 2]]) + v = torch.tensor([3, 4, 5], dtype=torch.float32) + if not materialize_fake: + # FakeTensorConverter messes up sizes of i and v for the sparse tensor + st = torch.sparse_coo_tensor(i, v, (2, 4)) + tt = TwoTensor(torch.randn(2, device=t_device), torch.randn(2, device=t_device)) + + mode, converter = FakeTensorMode(), FakeTensorConverter() + + def fn(t): + return converter.from_real_tensor(mode, t) if materialize_fake else t + + sd = {'t_v2': fn(t_v2), 't_v3': fn(t_v3), 'tt': fn(tt)} + sd_expected = { + 't_v2': torch.zeros(2, 3, device=t_device), + 't_v3': torch.zeros(2, 3, dtype=torch.complex32, device=t_device), + 'tt': TwoTensor(torch.zeros(2, device=t_device), torch.zeros(2, device=t_device)), + } + + if not materialize_fake: + sd['st'] = st + sd_expected['st'] = torch.sparse_coo_tensor(torch.zeros(2, 3), torch.zeros(3), (2, 4)) + + with BytesIOContext() as f: + with skip_data(materialize_fake_tensors=materialize_fake): + torch.save(sd, f) + f.seek(0) + with safe_globals([TwoTensor]): + sd_loaded = torch.load(f, weights_only=True) + self.assertEqual(sd_loaded, sd_expected, exact_device=True) + self.assertFalse(getattr(torch.serialization._serialization_tls, "materialize_fake_tensors", False)) + self.assertFalse(getattr(torch.serialization._serialization_tls, "skip_data", False)) + + # Test that without materialize_fake_tensor, behavior for fake_tensors is not altered by ctx + if not materialize_fake: + ft = converter.from_real_tensor(mode, torch.randn(2, device=t_device)) + with self.assertRaisesRegex(AttributeError, "Can't pickle local object 'WeakValueDictionary.__init__..remove'"): + with skip_data(), BytesIOContext() as f: + torch.save(ft, f) + + @parametrize("materialize_fake", (True, False)) + def test_skip_data_serialization_preserves_views(self, materialize_fake): + ctx = FakeTensorMode if materialize_fake else contextlib.nullcontext + with ctx(): + t = torch.randn(2, 3) + t_view = t.view(-1) + t_slice = t[1] + sd = {'t': t, 't_view': t_view, 't_slice': t_slice} + with BytesIOContext() as f: + with skip_data(materialize_fake_tensors=materialize_fake): + torch.save(sd, f) + f.seek(0) + sd_loaded = torch.load(f, weights_only=True) + self.assertTrue(id(sd_loaded['t_view'].untyped_storage()) == id(sd_loaded['t'].untyped_storage())) + self.assertTrue(id(sd_loaded['t_slice'].untyped_storage()) == id(sd_loaded['t'].untyped_storage())) + + def test_skip_data_serialization_error_cases(self): + def _save_load(t): + with BytesIOContext() as f: + with skip_data(): + torch.save(t, f) + f.seek(0) + torch.load(f, weights_only=True) + + nt = torch.nested.nested_tensor([torch.randn(2), torch.randn(3)]) + t = torch.randn(2, 3, device="meta") + with self.assertRaisesRegex(RuntimeError, "Cannot serialize nested tensor under skip_data context manager"): + _save_load(nt) + + with self.assertWarnsRegex(UserWarning, "meta device under skip_data context manager is a no-op"): + _save_load(t) + + with self.assertRaisesRegex(RuntimeError, "Please call torch.load outside the skip_data context manager"): + with skip_data(), BytesIOContext() as f: + torch.save(torch.randn(2, 3), f) + f.seek(0) + torch.load(f, weights_only=True) + def run(self, *args, **kwargs): with serialization_method(use_zip=True): return super().run(*args, **kwargs) diff --git a/torch/_tensor.py b/torch/_tensor.py index 98563aebae9aa7..61d5e3891b9e56 100644 --- a/torch/_tensor.py +++ b/torch/_tensor.py @@ -209,8 +209,16 @@ def __deepcopy__(self, memo): return new_tensor def __reduce_ex__(self, proto): + materialize_fake_tensors = ( + torch.serialization._serialization_tls.materialize_fake_tensors + ) state = torch._utils._get_obj_state(self) - if type(self) is Tensor and not state: + # Ignore all state when using FakeTensor with skip_data(materialize_fake_tensors) because FakeTensor has + # some state that cannot be pickled + if ( + type(self) is torch._subclasses.fake_tensor.FakeTensor + and materialize_fake_tensors + ) or (type(self) is Tensor and not state): # Fast path for regular tensor without Python state. return self._reduce_ex_internal(proto) if has_torch_function_unary(self): @@ -251,6 +259,12 @@ def _reduce_ex_internal(self, proto): # See Note [Don't serialize hooks] warn_if_has_hooks(self) backward_hooks: Dict[Any, Any] = OrderedDict() + + skip_data = torch.serialization._serialization_tls.skip_data + materialize_fake_tensors = ( + torch.serialization._serialization_tls.materialize_fake_tensors + ) + # Note: Numpy array is chosen to be the rebuild component for XLA, MTIA, MAIA Tensors. # We considered a few options: # 1. CPU tensor can't be used here. @@ -268,6 +282,10 @@ def _reduce_ex_internal(self, proto): # Convert BFloat16 tesors to Float32 before conversion to numpy, as numpy doesn't # support BFloat16. The rebuild tensor from numpy takes in the original self.dtype, # this would reconstruct the BFloat16 tensor from numpy. + if skip_data: + raise RuntimeError( + "Cannot serialize tensors on backends with no storage under skip_data context manager" + ) numpy_tensor = ( self.cpu().numpy() if self.dtype != torch.bfloat16 @@ -280,6 +298,10 @@ def _reduce_ex_internal(self, proto): if self.device.type == "meta": # NB: This implementation BREAKS storage sharing. Current # hypothesis is that no one cares for meta tensors. + if skip_data: + warnings.warn( + "Serializing tensors on the meta device under skip_data context manager is a no-op" + ) arg_meta = ( self.dtype, tuple(self.size()), @@ -288,6 +310,10 @@ def _reduce_ex_internal(self, proto): ) return (torch._utils._rebuild_meta_tensor_no_storage, arg_meta) if self.is_quantized: + if skip_data: + raise RuntimeError( + "Cannot serialize qtensor under skip_data context manager, file an issue if you need this feature" + ) # quantizer_params can be different type based on torch attribute quantizer_params: Union[ Tuple[torch.qscheme, float, int], Tuple[Any, Tensor, Tensor, int] @@ -369,6 +395,10 @@ def _reduce_ex_internal(self, proto): ) return (torch._utils._rebuild_sparse_tensor, args_sparse_compressed) elif self.is_nested: + if skip_data: + raise RuntimeError( + "Cannot serialize nested tensor under skip_data context manager, file an issue if you need this feature" + ) args_nested = ( # NB: values() currently returns the storage as a buffer in an unsafe way. # Ideally, we'd use a private API for this instead. TODO: Switch to this if @@ -383,14 +413,30 @@ def _reduce_ex_internal(self, proto): type(self) is not torch.Tensor and type(self).__torch_dispatch__ is not torch.Tensor.__torch_dispatch__ and ( - isinstance( - self, - ( - torch._subclasses.fake_tensor.FakeTensor, - torch._subclasses.functional_tensor.FunctionalTensor, - ), + isinstance(self, torch._subclasses.functional_tensor.FunctionalTensor) + or ( + not isinstance(self, torch._subclasses.fake_tensor.FakeTensor) + and self.data_ptr() == 0 ) - or self.data_ptr() == 0 + ) + ): + arg_wrapper_subclass = ( + type(self), + self.dtype, + tuple(self.size()), + self.stride(), + self.storage_offset(), + self.layout, + self.device, + self.requires_grad, + ) + return (torch._utils._rebuild_wrapper_subclass, arg_wrapper_subclass) + elif ( + type(self) is not torch.Tensor + and type(self).__torch_dispatch__ is not torch.Tensor.__torch_dispatch__ + and ( + isinstance(self, torch._subclasses.fake_tensor.FakeTensor) + and not (skip_data and materialize_fake_tensors) ) ): arg_wrapper_subclass = ( @@ -418,6 +464,10 @@ def _reduce_ex_internal(self, proto): dtype=self.dtype, _internal=True, ) # type: ignore[assignment] + + if isinstance(self, torch._subclasses.fake_tensor.FakeTensor) and skip_data: + storage._fake_device = self.device + args = ( storage, self.storage_offset(), diff --git a/torch/_utils.py b/torch/_utils.py index 938392fa971590..f0d38daa811490 100644 --- a/torch/_utils.py +++ b/torch/_utils.py @@ -3,7 +3,6 @@ import functools import logging import sys -import threading import traceback import warnings from collections import defaultdict @@ -109,16 +108,13 @@ def _get_async_or_non_blocking(function_name, non_blocking, kwargs): return kwargs["async"] -_thread_local_state = threading.local() - - def _get_restore_location(device): """Return the map_location location. Used for rebuild functions where the tensor device is distinct from the storage """ - map_location = getattr(_thread_local_state, "map_location", None) + map_location = torch.serialization._serialization_tls.map_location if map_location is None: return device else: diff --git a/torch/serialization.py b/torch/serialization.py index 2ac36ec371fe57..954596ade61f5e 100644 --- a/torch/serialization.py +++ b/torch/serialization.py @@ -11,6 +11,7 @@ import sys import tarfile import tempfile +import threading import warnings from contextlib import closing, contextmanager from enum import Enum @@ -60,6 +61,7 @@ "get_safe_globals", "add_safe_globals", "safe_globals", + "skip_data", ] @@ -87,6 +89,22 @@ MAP_SHARED, MAP_PRIVATE = None, None # type: ignore[assignment] +# _serialization_tls is used to store thread local state specific to serialization +# that needs to be propagated to other files, in particular we use this for +# (1) map_location (needed for wrapper subclasses/third party devices to torch._utils) +# (2) skip_data (needed for torch.Tensor.__reduce_ex__ for skip_data ctx) +# (3) materialize_fake_tensors (needed for torch.Tensor.__reduce_ex__ for skip_data ctx) +class _SerializationLocal(threading.local): + def __init__(self): + super().__init__() + self.map_location: Optional[MAP_LOCATION] = None + self.skip_data: bool = False + self.materialize_fake_tensors: bool = False + + +_serialization_tls = _SerializationLocal() + + class SourceChangeWarning(Warning): pass @@ -268,6 +286,46 @@ class safe_globals(_weights_only_unpickler._safe_globals): """ +class skip_data: + """ + Context-manager that skips writing storage bytes for ``torch.save`` calls. + + Storages will still be saved, but the space that their bytes would usually be written to + will be empty space. The storage bytes can then be populated in a separate pass. + + .. warning:: + The ``skip_data`` context manager is an early prototype and is subject to change. + + Args: + materialize_fake_tensors: Whether to materialize FakeTensors. + + Example: + >>> import tempfile + >>> t = torch.randn(2, 3) + >>> with tempfile.NamedTemporaryFile() as f: + ... with torch.serialization.skip_data(): + ... torch.save(t, f.name) + ... torch.load(f.name, weights_only=True) + tensor([[0., 0., 0.], + [0., 0., 0.]]) + """ + + def __init__(self, materialize_fake_tensors: bool = False): + self.materialize_fake_tensors = materialize_fake_tensors + + def __enter__(self): + global _serialization_tls + self._old_skip_data = _serialization_tls.skip_data + self._old_materialize_fake_tensors = _serialization_tls.materialize_fake_tensors + _serialization_tls.skip_data = True + _serialization_tls.materialize_fake_tensors = self.materialize_fake_tensors + + def __exit__(self, type, value, tb): + global _serialization_tls + _serialization_tls.skip_data = self._old_skip_data + _serialization_tls.materialize_fake_tensors = self._old_materialize_fake_tensors + + def _is_zipfile(f) -> bool: # This is a stricter implementation than zipfile.is_zipfile(). # zipfile.is_zipfile() is True if the magic number appears anywhere in the @@ -797,6 +855,11 @@ def save( ) return else: + global _serialization_tls + if _serialization_tls.skip_data: + raise RuntimeError( + "Cannot use skip_data=True with _use_new_zipfile_serialization=False" + ) with _open_file_like(f, "wb") as opened_file: _legacy_save(obj, opened_file, pickle_module, pickle_protocol) @@ -955,7 +1018,13 @@ def persistent_id(obj: Any) -> Optional[Tuple]: ) -def _save(obj, zip_file, pickle_module, pickle_protocol, _disable_byteorder_record): +def _save( + obj, + zip_file, + pickle_module, + pickle_protocol, + _disable_byteorder_record, +): serialized_storages = {} id_map: Dict[int, str] = {} @@ -990,7 +1059,7 @@ def persistent_id(obj): # If storage is allocated, ensure that any other saved storages # pointing to the same data all have the same dtype. If storage is # not allocated, don't perform this check - if storage.data_ptr() != 0: + if str(storage.device) != "meta" and storage.data_ptr() != 0: if storage.data_ptr() in storage_dtypes: if storage_dtype != storage_dtypes[storage.data_ptr()]: raise RuntimeError( @@ -1001,7 +1070,10 @@ def persistent_id(obj): storage_dtypes[storage.data_ptr()] = storage_dtype storage_key = id_map.setdefault(storage._cdata, str(len(id_map))) - location = location_tag(storage) + if hasattr(obj, "_fake_device") and obj._fake_device is not None: + location = str(obj._fake_device) + else: + location = location_tag(storage) serialized_storages[storage_key] = storage return ("storage", storage_type, storage_key, location, storage_numel) @@ -1027,14 +1099,18 @@ def persistent_id(obj): for key in sorted(serialized_storages.keys()): name = f"data/{key}" storage = serialized_storages[key] - # given that we copy things around anyway, we might use storage.cpu() - # this means to that to get tensors serialized, you need to implement - # .cpu() on the underlying Storage - if storage.device.type != "cpu": - storage = storage.cpu() - # Now that it is on the CPU we can directly copy it into the zip file num_bytes = storage.nbytes() - zip_file.write_record(name, storage, num_bytes) + global _serialization_tls + if _serialization_tls.skip_data: + zip_file.write_record_metadata(name, num_bytes) + else: + # given that we copy things around anyway, we might use storage.cpu() + # this means to that to get tensors serialized, you need to implement + # .cpu() on the underlying Storage + if storage.device.type != "cpu": + storage = storage.cpu() + # Now that it is on the CPU we can directly copy it into the zip file + zip_file.write_record(name, storage, num_bytes) def load( @@ -1184,6 +1260,14 @@ def _get_wo_message(message: str) -> str: updated_message += message return updated_message + DOCS_MESSAGE + global _serialization_tls + skip_data = _serialization_tls.skip_data + if skip_data: + raise RuntimeError( + "`torch.load` called within a torch.serialization.skip_data context manager " + "is not supported yet. Please call torch.load outside the skip_data context manager." + ) + if weights_only is None: weights_only, warn_weights_only = False, True else: @@ -1758,9 +1842,10 @@ def find_class(self, mod_name, name): unpickler.persistent_load = persistent_load # Needed for tensors where storage device and rebuild tensor device are # not connected (wrapper subclasses and tensors rebuilt using numpy) - torch._utils._thread_local_state.map_location = map_location + global _serialization_tls + _serialization_tls.map_location = map_location result = unpickler.load() - del torch._utils._thread_local_state.map_location + _serialization_tls.map_location = None torch._utils._validate_loaded_sparse_tensors() torch._C._log_api_usage_metadata( diff --git a/torch/storage.py b/torch/storage.py index b6ba608c16e5ca..8848649905f936 100644 --- a/torch/storage.py +++ b/torch/storage.py @@ -39,6 +39,8 @@ class _StorageBase: is_sparse: _bool = False is_sparse_csr: _bool = False device: torch.device + # Used when stashing FakeTensor device onto storage in torch.save(metadata_only=True) + _fake_device: _Optional[torch.device] = None def __init__(self, *args, **kwargs): pass @@ -649,6 +651,8 @@ def _get_device_from_module(module: str): class TypedStorage: is_sparse: _bool = False + # Used when stashing FakeTensor device onto storage in torch.save(metadata_only=True) + _fake_device: _Optional[torch.device] = None dtype: torch.dtype From d1963295200e4edd8b912b192001f04cb72a993f Mon Sep 17 00:00:00 2001 From: Rachel Guo Date: Wed, 28 Aug 2024 23:58:37 +0000 Subject: [PATCH 0184/1018] [AOTI][Tooling] Follow up to print location of saved file path for `torch.pickle_save()` (#134651) Summary: - Follow up to add torch.pickle_save() log for saved file path - Minor debug printer code refine Test Plan: CI Differential Revision: D61883239 Pull Request resolved: https://github.com/pytorch/pytorch/pull/134651 Approved by: https://github.com/muchulee8 --- torch/_inductor/codegen/debug_utils.py | 30 ++++++++++++------- .../csrc/inductor/aoti_torch/shim_common.cpp | 11 +++++-- 2 files changed, 27 insertions(+), 14 deletions(-) diff --git a/torch/_inductor/codegen/debug_utils.py b/torch/_inductor/codegen/debug_utils.py index 11fe548155d514..ebe46fc8cde01e 100644 --- a/torch/_inductor/codegen/debug_utils.py +++ b/torch/_inductor/codegen/debug_utils.py @@ -2,6 +2,7 @@ from __future__ import annotations import functools +import logging from enum import Enum from typing import List, Optional @@ -11,6 +12,9 @@ from .multi_kernel import MultiKernel +log = logging.getLogger(__name__) + + # AOTI debug printing related configs class IntermediateValueDebuggingLevel(Enum): # OFF: No intermediate tensor value debug info will be printed or saved. @@ -37,7 +41,7 @@ def __init__( self.kernel_name = kernel_name self.arg_signatures: Optional[List[type]] = None self.kernel = kernel - self.filtered_kernel_names_to_print = self.get_debug_filtered_kernel_names() + self.filtered_kernel_names_to_print = self._get_debug_filtered_kernel_names() def __enter__(self): self._perform_debug_print_or_save_helper( @@ -81,6 +85,15 @@ def _perform_debug_print_or_save_helper( arg_signatures=self.arg_signatures, ) + @functools.lru_cache # noqa: B019 + def _get_debug_filtered_kernel_names(self) -> List[str]: + if config.aot_inductor.filtered_kernel_names is None: + return [] + return [ + x.strip() + for x in config.aot_inductor.filtered_kernel_names.lower().split(",") + ] + def set_printer_args( self, args_to_print_or_save: List[str], @@ -90,21 +103,15 @@ def set_printer_args( ): # Note: MultiKernel debug printing is not supported for now if isinstance(kernel, MultiKernel): + log.info( + "MultiKernel type is not supported in AOTI debug printer tool yet." + ) self.debug_printer_level = IntermediateValueDebuggingLevel.OFF self.args_to_print_or_save = args_to_print_or_save self.kernel_name = kernel_name self.arg_signatures = arg_signatures self.kernel = kernel - @functools.lru_cache # noqa: B019 - def get_debug_filtered_kernel_names(self) -> List[str]: - if config.aot_inductor.filtered_kernel_names is None: - return [] - return [ - x.strip() - for x in config.aot_inductor.filtered_kernel_names.lower().split(",") - ] - def codegen_intermediate_tensor_value_save( self, args_to_save, @@ -143,7 +150,8 @@ def codegen_intermediate_tensor_value_print( ): continue if self.debug_printer_level == IntermediateValueDebuggingLevel.PRINT_ONLY: - # when debug printing is enabled i.e. DEFAULT_PRINT level, check if filtered kernel name list is provided + # when debug printing is enabled i.e. IntermediateValueDebuggingLevel.PRINT_ONLY, + # check if filtered kernel name list is provided if ( len(self.filtered_kernel_names_to_print) > 0 and kernel_name not in self.filtered_kernel_names_to_print diff --git a/torch/csrc/inductor/aoti_torch/shim_common.cpp b/torch/csrc/inductor/aoti_torch/shim_common.cpp index 77bba40c782cf1..5763470449e664 100644 --- a/torch/csrc/inductor/aoti_torch/shim_common.cpp +++ b/torch/csrc/inductor/aoti_torch/shim_common.cpp @@ -1007,11 +1007,13 @@ AOTI_TORCH_EXPORT void aoti_torch_save_tensor_handle( std::string cwd = fs::current_path().string(); std::string tmp_folder = cwd + "/tmp/aoti_torch/"; if (!fs::exists(tmp_folder)) { - std::cout << "Path does not exist, creating it..." << tmp_folder - << std::endl; + std::cout + << "aoti_torch_save_tensor_handle: Path does not exist, creating it..." + << tmp_folder << std::endl; if (!fs::create_directories(tmp_folder)) { - std::cout << "Error creating directory: " << tmp_folder << std::endl; + std::cout << "aoti_torch_save_tensor_handle: Error creating directory: " + << tmp_folder << std::endl; return; } } @@ -1022,6 +1024,9 @@ AOTI_TORCH_EXPORT void aoti_torch_save_tensor_handle( std::ofstream fout(tensor_filepath_to_save, std::ios::out | std::ios::binary); fout.write(bytes.data(), bytes.size()); fout.close(); + + std::cout << "aoti_torch_save_tensor_handle: Saved tensor to: " + << tensor_filepath_to_save << std::endl; #endif // !defined(C10_MOBILE) } From 53bfa2906f386ca6e79dee13964220b1b7cbc70c Mon Sep 17 00:00:00 2001 From: Avik Chaudhuri Date: Thu, 29 Aug 2024 01:02:01 +0000 Subject: [PATCH 0185/1018] hang dim hint constants off Dim (#134702) Summary: Retry landing https://github.com/pytorch/pytorch/pull/134484 Test Plan: (see original) Differential Revision: D61925860 Pull Request resolved: https://github.com/pytorch/pytorch/pull/134702 Approved by: https://github.com/pianpwk --- docs/source/export.rst | 1 - test/export/test_export.py | 32 +++++----------- torch/_export/non_strict_utils.py | 4 +- torch/export/dynamic_shapes.py | 62 +++++++++++++++++-------------- 4 files changed, 46 insertions(+), 53 deletions(-) diff --git a/docs/source/export.rst b/docs/source/export.rst index deb84548a19cdc..3908beeb9b345e 100644 --- a/docs/source/export.rst +++ b/docs/source/export.rst @@ -676,7 +676,6 @@ API Reference .. autofunction:: save .. autofunction:: load .. autofunction:: register_dataclass -.. autoclass:: torch.export.dynamic_shapes.DIM .. autofunction:: torch.export.dynamic_shapes.Dim .. autofunction:: dims .. autoclass:: torch.export.dynamic_shapes.ShapesCollection diff --git a/test/export/test_export.py b/test/export/test_export.py index 3997de36eb4b5b..ec6bbdc93d6c44 100644 --- a/test/export/test_export.py +++ b/test/export/test_export.py @@ -1773,8 +1773,6 @@ def forward(self, x, y, y1, z): ) def test_static_dim_constraints(self): - from torch.export.dynamic_shapes import DIM - class Foo(torch.nn.Module): def __init__(self) -> None: super().__init__() @@ -1803,7 +1801,7 @@ def forward(self, x, y, z): ((dx, None), (dy, 4), (dz, 3)), ((None, 6), (5, None), (None, None)), ((4, 6), {0: None, 1: 4}, {0: None, 1: 3}), - (None, None, (DIM.STATIC, DIM.STATIC)), + (None, None, (Dim.STATIC, Dim.STATIC)), ]: ep = export(foo, inputs, dynamic_shapes=dynamic_shapes) self.assertEqual(foo(*inputs), ep.module()(*inputs)) @@ -1950,9 +1948,7 @@ def forward(self, inp: Inp): self.assertEqual(str(tuple(node.meta["val"].shape)), f"({sym},)") def test_mismatched_dynamic_shapes(self): - from torch.export.dynamic_shapes import DIM - - AUTO, STATIC = DIM.AUTO, DIM.STATIC + AUTO, STATIC = Dim.AUTO, Dim.STATIC class M(torch.nn.Module): def forward(self, x): @@ -1984,7 +1980,7 @@ def forward(self, x): + re.escape( "specified at `dynamic_shapes[0]['k']['k'][0]` " "(expected either a list/tuple of dimensions, or a dict mapping indices to dimensions," - " where each dimension is None, an int, a Dim, DIM.AUTO, or DIM.STATIC)" + " where each dimension is None, an int, a Dim, Dim.AUTO, or Dim.STATIC)" ), ): export(M(), inputs, dynamic_shapes=dynamic_shapes) @@ -2061,7 +2057,7 @@ def forward(self, x): with self.assertRaisesRegex( torch._dynamo.exc.UserError, re.escape( - "Specifying both `DIM.AUTO` and `Dim` or `DerivedDim` in `dynamic_shapes` is not well supported at the moment, " + "Specifying both `Dim.AUTO` and `Dim` or `DerivedDim` in `dynamic_shapes` is not well supported at the moment, " "and can easily lead to constraint violation errors or obscure errors in torch.export." ), ): @@ -2465,9 +2461,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: def test_mark_and_auto_dynamic(self): # for this use case, mark_dynamic() and AUTO should have same effect. # check that same symbol gets allocated to both dims without raising constraint violation. - from torch.export.dynamic_shapes import DIM - - AUTO, STATIC = DIM.AUTO, DIM.STATIC + AUTO, STATIC = Dim.AUTO, Dim.STATIC class Foo(torch.nn.Module): def forward(self, x, y): @@ -2495,9 +2489,7 @@ def forward(self, x, y): def test_dont_duck_size_for_auto_dynamic(self): # for this use case, mark_dynamic() and AUTO should have same effect. # check that same symbol gets allocated to both dims without raising constraint violation. - from torch.export.dynamic_shapes import DIM - - AUTO, STATIC = DIM.AUTO, DIM.STATIC + AUTO, STATIC = Dim.AUTO, Dim.STATIC class Foo(torch.nn.Module): def forward(self, x, y): @@ -6876,9 +6868,7 @@ def test_automatic_dynamic_shapes_simple_equality(self): # The next 3 test cases tests for automatic dynamic shapes specs, verifying that automatic dynamism # leads to replacement symbols being set for equalities, and inferred relationships being checked # with runtime asserts. Check that we specialize to static values when the program says so. - from torch.export.dynamic_shapes import DIM - - AUTO, STATIC = DIM.AUTO, DIM.STATIC + AUTO, STATIC = Dim.AUTO, Dim.STATIC # case 1: direct equality between symbols class SimpleEquality(torch.nn.Module): @@ -6943,9 +6933,7 @@ def forward(self, x, y, z): ) def test_automatic_dynamic_shapes_constant_relation(self): - from torch.export.dynamic_shapes import DIM - - AUTO, STATIC = DIM.AUTO, DIM.STATIC + AUTO, STATIC = Dim.AUTO, Dim.STATIC # case 2: related by constant: s0 + 4 = s1 class OffBy4(torch.nn.Module): @@ -6988,9 +6976,7 @@ def forward(self, x, y): ) def test_automatic_dynamic_shapes_linear_relation(self): - from torch.export.dynamic_shapes import DIM - - AUTO, STATIC = DIM.AUTO, DIM.STATIC + AUTO, STATIC = Dim.AUTO, Dim.STATIC # case 3: linear relation class LinearRel(torch.nn.Module): diff --git a/torch/_export/non_strict_utils.py b/torch/_export/non_strict_utils.py index 8b5ff29b6c9e02..ddb51064f9ad78 100644 --- a/torch/_export/non_strict_utils.py +++ b/torch/_export/non_strict_utils.py @@ -24,10 +24,10 @@ from torch.export.dynamic_shapes import ( _check_dynamic_shapes, _combine_args, + _DimHint, _process_dynamic_shapes, _transform_shapes_for_default_dynamic, _tree_map_with_path, - DIM, ) from torch.export.graph_signature import CustomObjArgument from torch.fx.experimental import _config as config @@ -351,7 +351,7 @@ def make_constraints( # we want the symbol, not its replacement, which could be an expression. Maybe # there's a better way to do this, e.g., by (re)computing value ranges for expressions? dim = shape_spec[i] if shape_spec else None - if dim is None or isinstance(dim, DIM): + if dim is None or isinstance(dim, _DimHint): range_constraints[d.node.expr] = shape_env.var_to_range[ d.node._expr ] diff --git a/torch/export/dynamic_shapes.py b/torch/export/dynamic_shapes.py index 48438ee8cb2a55..7232114619b4f6 100644 --- a/torch/export/dynamic_shapes.py +++ b/torch/export/dynamic_shapes.py @@ -30,20 +30,21 @@ __all__ = [ "Constraint", - "DIM", "Dim", "dims", "refine_dynamic_shapes_from_suggested_fixes", ] -class DIM(Enum): +class _DimHint(Enum): """ - Enum for automatic/static dynamic shapes. + Enum for dynamic shape hints. + - AUTO means automatic inference of shape (static or dynamic). + - STATIC means static shape (always specialized). """ - STATIC = auto() AUTO = auto() + STATIC = auto() class _Dim(type): @@ -214,6 +215,7 @@ def Dim(name: str, *, min: Optional[int] = None, max: Optional[int] = None): Returns: A type that can be used in dynamic shape specifications for tensors. """ + from torch.utils._sympy.numbers import int_oo _min = 0 if min is None else min @@ -227,6 +229,10 @@ def Dim(name: str, *, min: Optional[int] = None, max: Optional[int] = None): return dim +Dim.AUTO = _DimHint.AUTO # type: ignore[attr-defined] +Dim.STATIC = _DimHint.STATIC # type: ignore[attr-defined] + + def dims(*names: str, min: Optional[int] = None, max: Optional[int] = None): """ Util to create multiple :func:`Dim` types. @@ -668,24 +674,24 @@ def check_symbols(path, tensor, shape): for i, dim in shape.items(): if isinstance(dim, _Dim): check_same_bounds(dim) - elif not (isinstance(dim, (int, DIM)) or dim is None): + elif not (isinstance(dim, (int, _DimHint)) or dim is None): raise UserError( UserErrorType.INVALID_INPUT, f"Unexpected dimension mapped to index {i} in input tensor shape {shape} " f"specified at `dynamic_shapes{keystr(path)}` " - f"(expected None, an int, a Dim, DIM.AUTO, or DIM.STATIC, but got {dim} instead)", + f"(expected None, an int, a Dim, Dim.AUTO, or Dim.STATIC, but got {dim} instead)", case_name="dynamic_shapes_validation", ) elif isinstance(shape, (tuple, list)): for i, dim in enumerate(shape): if isinstance(dim, _Dim): check_same_bounds(dim) - elif not (isinstance(dim, (int, DIM)) or dim is None): + elif not (isinstance(dim, (int, _DimHint)) or dim is None): raise UserError( UserErrorType.INVALID_INPUT, f"Unexpected dimension #{i} in input tensor shape {shape} " f"specified at `dynamic_shapes{keystr(path)}` " - f"(expected None, an int, a Dim, DIM.AUTO, or DIM.STATIC, but got {dim} instead)", + f"(expected None, an int, a Dim, Dim.AUTO, or Dim.STATIC, but got {dim} instead)", case_name="dynamic_shapes_validation", ) elif shape is not None: @@ -693,7 +699,7 @@ def check_symbols(path, tensor, shape): UserErrorType.INVALID_INPUT, f"Unexpected input tensor shape {shape} specified at `dynamic_shapes{keystr(path)}` " f"(expected either a list/tuple of dimensions, or a dict mapping indices to dimensions," - f" where each dimension is None, an int, a Dim, DIM.AUTO, or DIM.STATIC)", + f" where each dimension is None, an int, a Dim, Dim.AUTO, or Dim.STATIC)", case_name="dynamic_shapes_validation", ) @@ -740,18 +746,18 @@ def check_shape(path, t, dynamic_shape): _tree_map_with_path(check_shape, combined_args, dynamic_shapes, tree_name="inputs") - # raise user warning if both DIM.AUTO & Dims are specified in dynamic_shapes + # raise user warning if both Dim.AUTO & Dims are specified in dynamic_shapes flat_dynamic_shapes = _flatten_dynamic_shapes(combined_args, dynamic_shapes) flatter_dynamic_shapes, _ = tree_flatten(flat_dynamic_shapes) if any(isinstance(s, _Dim) for s in flatter_dynamic_shapes) and any( - s == DIM.AUTO for s in flatter_dynamic_shapes + s == _DimHint.AUTO for s in flatter_dynamic_shapes ): raise UserError( UserErrorType.INVALID_INPUT, - "Specifying both `DIM.AUTO` and `Dim` or `DerivedDim` in `dynamic_shapes` is not well supported at the moment, " + "Specifying both `Dim.AUTO` and `Dim` or `DerivedDim` in `dynamic_shapes` is not well supported at the moment, " "and can easily lead to constraint violation errors or obscure errors in torch.export. Dim/DerivedDims " - "expect all equal or related dimensions to be specified, and does not yet compose well with `DIM.AUTO`. " - "We suggest using `DIM.AUTO` mixed with `None` for auto-dynamic + static shapes, plus torch._check(dim >= min), " + "expect all equal or related dimensions to be specified, and does not yet compose well with `Dim.AUTO`. " + "We suggest using `Dim.AUTO` mixed with `None` for auto-dynamic + static shapes, plus torch._check(dim >= min), " "torch._check(dim <= max) calls in your program to specify min/max ranges, or `Dim`/`DerivedDim` mixed with `None` " "if you want to assert on the exact specification of your program's dynamic shapes behavior.", case_name="dynamic_shapes_validation", @@ -773,8 +779,8 @@ def _transform_shapes_for_default_dynamic( for all dims governed by this symbol (i.e. relations, equality, linear relations, etc.) For export.export(), historically dynamism for unspecified dims has been undesirable, so the semantics are: - - DIM.AUTO: dynamic, allocated a symbol - - None/unspecified/DIM.STATIC: static + - Dim.AUTO: dynamic, allocated a symbol + - None/unspecified/Dim.STATIC: static - Dim/DerivedDims: also a strict assertion To allow both APIs to follow the same process for producing constraints, this function converts dynamic_shapes @@ -784,8 +790,8 @@ def _transform_shapes_for_default_dynamic( An example conversion might look like, for a 3-d input tensor: input spec: { - 0: DIM.AUTO, - 1: None, # or DIM.STATIC + 0: Dim.AUTO, + 1: None, # or Dim.STATIC 2: Dim("dx"), } output spec: { @@ -832,10 +838,10 @@ def _marked_dynamic(tensor, i): out = {} for i, val in enumerate(tensor.shape): dim = shape.get(i, None) - if _marked_dynamic(tensor, i) or dim == DIM.AUTO: + if _marked_dynamic(tensor, i) or dim == _DimHint.AUTO: # don't have to specify anything if dynamic # None also works, since assume_static_by_default=False - if dim == DIM.AUTO: + if dim == _DimHint.AUTO: torch._dynamo.maybe_mark_dynamic(tensor, i) # avoid duck sizing continue elif isinstance(dim, _Dim): @@ -846,14 +852,14 @@ def _marked_dynamic(tensor, i): out[i] = dim else: # make explicitly static - assert dim is None or dim == DIM.STATIC + assert dim is None or dim == _DimHint.STATIC out[i] = val elif isinstance(shape, (tuple, list)): out = [] for i, val in enumerate(tensor.shape): dim = shape[i] - if _marked_dynamic(tensor, i) or dim == DIM.AUTO: - if dim == DIM.AUTO: + if _marked_dynamic(tensor, i) or dim == _DimHint.AUTO: + if dim == _DimHint.AUTO: torch._dynamo.maybe_mark_dynamic(tensor, i) # avoid duck sizing out.append(None) elif isinstance(dim, _Dim): @@ -861,7 +867,7 @@ def _marked_dynamic(tensor, i): elif isinstance(dim, int): out.append(dim) else: - assert dim is None or dim == DIM.STATIC + assert dim is None or dim == _DimHint.STATIC out.append(val) out = type(shape)(out) # type: ignore[assignment] else: @@ -1040,7 +1046,7 @@ def _get_dim_name_mapping( dynamic_shapes, is_leaf=lambda x: isinstance(x, _Dim), )[0]: - if isinstance(dim, (int, DIM)) or dim is None: + if isinstance(dim, (int, _DimHint)) or dim is None: continue name_to_dim[dim.__name__] = dim if isinstance(dim, _DerivedDim): @@ -1122,9 +1128,11 @@ def refine_dynamic_shapes_from_suggested_fixes( name, expr = fix.split(" = ") expr = sympy.sympify(expr) if isinstance(expr, sympy.Number): - shape_fixes[name] = int(expr) # static, integer + # static, integer + shape_fixes[name] = int(expr) # type: ignore[assignment] else: - shape_fixes[name] = expr # relation or derived dim + # relation or derived dim + shape_fixes[name] = expr name_to_dim = _get_dim_name_mapping(dynamic_shapes) From b375e67225a732ee263ffe9024de3a198980c398 Mon Sep 17 00:00:00 2001 From: "Li-Huai (Allan) Lin" Date: Wed, 28 Aug 2024 17:44:58 -0700 Subject: [PATCH 0186/1018] [MPS] Fix SDP training (#134719) Check whether the input tensors require grad. If required, then we don't get into the fast path and fall back to composite implicit. Fixes #134678 Pull Request resolved: https://github.com/pytorch/pytorch/pull/134719 Approved by: https://github.com/malfet --- .../ATen/native/transformers/attention.cpp | 3 +- test/test_mps.py | 33 ++++++++++++++++--- 2 files changed, 31 insertions(+), 5 deletions(-) diff --git a/aten/src/ATen/native/transformers/attention.cpp b/aten/src/ATen/native/transformers/attention.cpp index 5369e87d58bec3..08a6eb04cee568 100644 --- a/aten/src/ATen/native/transformers/attention.cpp +++ b/aten/src/ATen/native/transformers/attention.cpp @@ -737,7 +737,8 @@ Tensor scaled_dot_product_attention( return std::get<0>(out_lse_softmax); } case sdp::SDPBackend::math: - if (query_.device().type() == DeviceType::MPS && dropout_p == 0.0 + if ((!GradMode::is_enabled() || (!query_.requires_grad() && !key.requires_grad() && !value.requires_grad())) + && query_.device().type() == DeviceType::MPS && dropout_p == 0.0 && query_.is_contiguous() && key.is_contiguous() && value.is_contiguous() && !query_.is_nested() && !key.is_nested() && !value.is_nested()) { return std::get<0>(at::_scaled_dot_product_attention_math_for_mps( diff --git a/test/test_mps.py b/test/test_mps.py index cd28de9dcae3c1..281c68b42cae14 100644 --- a/test/test_mps.py +++ b/test/test_mps.py @@ -9360,18 +9360,37 @@ def _compare_tensors(self, y, ref): err = ((y - ref).abs() / denom).mean().item() self.assertLess(err, 0.01) - def _test_sdpa_no_mask(self, is_causal: bool, dtype: torch.dtype, L: int = 1, S: int = 72, NH: int = 32, HS: int = 128): + def _test_sdpa_no_mask( + self, + is_causal: bool, + dtype: torch.dtype, + L: int = 1, + S: int = 72, + NH: int = 32, + HS: int = 128, + requires_grad: bool = False + ): + torch.manual_seed(1729) with torch.nn.attention.sdpa_kernel([torch.nn.attention.SDPBackend.MATH]): - q = torch.randn([1, NH, L, HS], dtype=dtype, device="mps") + q = torch.randn([1, NH, L, HS], dtype=dtype, device="mps", requires_grad=requires_grad) k = torch.randn([1, NH, S, HS], dtype=q.dtype, device="mps") v = torch.randn([1, NH, S, HS], dtype=q.dtype, device="mps") + q_cpu = q.cpu().detach().cpu().requires_grad_(requires_grad) + k_cpu = k.cpu() + v_cpu = v.cpu() y = F.scaled_dot_product_attention(q, k, v, dropout_p=0.0, is_causal=is_causal) - y_ref = F.scaled_dot_product_attention(q.cpu(), k.cpu(), v.cpu(), dropout_p=0.0, is_causal=is_causal) + y_ref = F.scaled_dot_product_attention(q_cpu, k_cpu, v_cpu, dropout_p=0.0, is_causal=is_causal) self._compare_tensors(y.cpu(), y_ref) + if requires_grad and torch.is_grad_enabled(): + y.sum().backward() + y_ref.sum().backward() + + self._compare_tensors(q.grad.cpu(), q_cpu.grad) + def test_sdpa_no_mask_no_causal_fp32(self): self._test_sdpa_no_mask(False, torch.float32) @@ -9393,6 +9412,12 @@ def test_sdpa_no_mask_causal_fp16_L7_S17(self): def test_sdpa_no_mask_causal_fp16_L7_S17_NH23_HS121(self): self._test_sdpa_no_mask(True, torch.float16, 7, 17, 23, 121) + def test_sdpa_no_mask_no_causal_fp32_grad(self): + self._test_sdpa_no_mask(False, torch.float32, requires_grad=True) + + with torch.no_grad(): + self._test_sdpa_no_mask(False, torch.float32, requires_grad=True) + def _test_sdpa_mask(self, dtype: torch.dtype, L: int = 1, S: int = 72, NH: int = 32, HS: int = 128): torch.manual_seed(1729) causal_mask = torch.tril(torch.ones(S, S, dtype=torch.bool, device='mps')) @@ -9421,7 +9446,7 @@ def test_sdpa_mask_fp16_L6(self): self._test_sdpa_mask(torch.float16, 6) def test_sdpa_mask_fp16_L6_S17_NH23_HS121(self): - self._test_sdpa_no_mask(True, torch.float16, 7, 17, 23, 121) + self._test_sdpa_mask(torch.float16, 7, 17, 23, 121) class TestGatherScatter(TestCaseMPS): From 86adf0c6fbde6511a98f7e8597a1644a6683b400 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Thu, 29 Aug 2024 01:30:49 +0000 Subject: [PATCH 0187/1018] Revert "Add torch.serialization.skip_data context manager (#134504)" This reverts commit 202600bc2384cb19a29b8fca503bafc289158c32. Reverted https://github.com/pytorch/pytorch/pull/134504 on behalf of https://github.com/mikaylagawarecki due to This is breaking Windows docs tests due to NamedTemporaryFile on Windows not working well ([comment](https://github.com/pytorch/pytorch/pull/134504#issuecomment-2316543901)) --- docs/source/notes/serialization.rst | 1 - ...cpp_extensions_open_device_registration.py | 12 +- test/test_serialization.py | 88 -------------- torch/_tensor.py | 66 ++--------- torch/_utils.py | 6 +- torch/serialization.py | 109 ++---------------- torch/storage.py | 4 - 7 files changed, 26 insertions(+), 260 deletions(-) diff --git a/docs/source/notes/serialization.rst b/docs/source/notes/serialization.rst index c05dc028a471c9..5541d28bdcafa0 100644 --- a/docs/source/notes/serialization.rst +++ b/docs/source/notes/serialization.rst @@ -398,4 +398,3 @@ The following utility functions are related to serialization: .. autofunction:: clear_safe_globals .. autofunction:: get_safe_globals .. autoclass:: safe_globals -.. autoclass:: skip_data diff --git a/test/test_cpp_extensions_open_device_registration.py b/test/test_cpp_extensions_open_device_registration.py index 4e86ed458b0788..616a6e0f4b5518 100644 --- a/test/test_cpp_extensions_open_device_registration.py +++ b/test/test_cpp_extensions_open_device_registration.py @@ -540,7 +540,7 @@ def test_open_device_tensorlist_type_fallback(self): # call _fused_adamw_ with undefined tensor. self.module.fallback_with_undefined_tensor() - def test_open_device_numpy_serialization(self): + def test_open_device_numpy_serialization_map_location(self): torch.utils.rename_privateuse1_backend("foo") device = self.module.custom_device() default_protocol = torch.serialization.DEFAULT_PROTOCOL @@ -553,7 +553,6 @@ def test_open_device_numpy_serialization(self): self.assertTrue( rebuild_func is torch._utils._rebuild_device_tensor_from_numpy ) - # Test map_location with TemporaryFileName() as f: torch.save(sd, f) with safe_globals( @@ -570,15 +569,6 @@ def test_open_device_numpy_serialization(self): sd_loaded = torch.load(f, map_location="cpu") self.assertTrue(sd_loaded["x"].is_cpu) - # Test metadata_only - with TemporaryFileName() as f: - with self.assertRaisesRegex( - RuntimeError, - "Cannot serialize tensors on backends with no storage under skip_data context manager", - ): - with torch.serialization.skip_data(): - torch.save(sd, f) - if __name__ == "__main__": common.run_tests() diff --git a/test/test_serialization.py b/test/test_serialization.py index 3ba96b80541d88..a041473d195b19 100644 --- a/test/test_serialization.py +++ b/test/test_serialization.py @@ -1,6 +1,5 @@ # Owner(s): ["module: serialization"] -import contextlib import copy import gc import gzip @@ -20,7 +19,6 @@ from pathlib import Path import torch -from torch._subclasses.fake_tensor import FakeTensorMode, FakeTensorConverter from torch._utils import _rebuild_tensor from torch._utils_internal import get_file_path_2 from torch.serialization import ( @@ -29,7 +27,6 @@ LoadEndianness, safe_globals, set_default_load_endianness, - skip_data, SourceChangeWarning, ) from torch.testing._internal.common_device_type import instantiate_device_type_tests @@ -4215,91 +4212,6 @@ def test_filewriter_metadata_writing(self, filename): sd_loaded_ref = torch.load(f) self.assertEqual(sd_loaded, sd_loaded_ref) - @parametrize("materialize_fake", (True, False)) - def test_skip_data_serialization(self, materialize_fake): - # Create one tensor that uses each of the paths in __reduce_ex__ that should work - t_device = "cuda" if torch.cuda.is_available() else "cpu" - t_v2 = torch.randn(2, 3, device=t_device) - t_v3 = torch.randn(2, 3, dtype=torch.complex32, device=t_device) - i = torch.tensor([[0, 1, 1], - [2, 0, 2]]) - v = torch.tensor([3, 4, 5], dtype=torch.float32) - if not materialize_fake: - # FakeTensorConverter messes up sizes of i and v for the sparse tensor - st = torch.sparse_coo_tensor(i, v, (2, 4)) - tt = TwoTensor(torch.randn(2, device=t_device), torch.randn(2, device=t_device)) - - mode, converter = FakeTensorMode(), FakeTensorConverter() - - def fn(t): - return converter.from_real_tensor(mode, t) if materialize_fake else t - - sd = {'t_v2': fn(t_v2), 't_v3': fn(t_v3), 'tt': fn(tt)} - sd_expected = { - 't_v2': torch.zeros(2, 3, device=t_device), - 't_v3': torch.zeros(2, 3, dtype=torch.complex32, device=t_device), - 'tt': TwoTensor(torch.zeros(2, device=t_device), torch.zeros(2, device=t_device)), - } - - if not materialize_fake: - sd['st'] = st - sd_expected['st'] = torch.sparse_coo_tensor(torch.zeros(2, 3), torch.zeros(3), (2, 4)) - - with BytesIOContext() as f: - with skip_data(materialize_fake_tensors=materialize_fake): - torch.save(sd, f) - f.seek(0) - with safe_globals([TwoTensor]): - sd_loaded = torch.load(f, weights_only=True) - self.assertEqual(sd_loaded, sd_expected, exact_device=True) - self.assertFalse(getattr(torch.serialization._serialization_tls, "materialize_fake_tensors", False)) - self.assertFalse(getattr(torch.serialization._serialization_tls, "skip_data", False)) - - # Test that without materialize_fake_tensor, behavior for fake_tensors is not altered by ctx - if not materialize_fake: - ft = converter.from_real_tensor(mode, torch.randn(2, device=t_device)) - with self.assertRaisesRegex(AttributeError, "Can't pickle local object 'WeakValueDictionary.__init__..remove'"): - with skip_data(), BytesIOContext() as f: - torch.save(ft, f) - - @parametrize("materialize_fake", (True, False)) - def test_skip_data_serialization_preserves_views(self, materialize_fake): - ctx = FakeTensorMode if materialize_fake else contextlib.nullcontext - with ctx(): - t = torch.randn(2, 3) - t_view = t.view(-1) - t_slice = t[1] - sd = {'t': t, 't_view': t_view, 't_slice': t_slice} - with BytesIOContext() as f: - with skip_data(materialize_fake_tensors=materialize_fake): - torch.save(sd, f) - f.seek(0) - sd_loaded = torch.load(f, weights_only=True) - self.assertTrue(id(sd_loaded['t_view'].untyped_storage()) == id(sd_loaded['t'].untyped_storage())) - self.assertTrue(id(sd_loaded['t_slice'].untyped_storage()) == id(sd_loaded['t'].untyped_storage())) - - def test_skip_data_serialization_error_cases(self): - def _save_load(t): - with BytesIOContext() as f: - with skip_data(): - torch.save(t, f) - f.seek(0) - torch.load(f, weights_only=True) - - nt = torch.nested.nested_tensor([torch.randn(2), torch.randn(3)]) - t = torch.randn(2, 3, device="meta") - with self.assertRaisesRegex(RuntimeError, "Cannot serialize nested tensor under skip_data context manager"): - _save_load(nt) - - with self.assertWarnsRegex(UserWarning, "meta device under skip_data context manager is a no-op"): - _save_load(t) - - with self.assertRaisesRegex(RuntimeError, "Please call torch.load outside the skip_data context manager"): - with skip_data(), BytesIOContext() as f: - torch.save(torch.randn(2, 3), f) - f.seek(0) - torch.load(f, weights_only=True) - def run(self, *args, **kwargs): with serialization_method(use_zip=True): return super().run(*args, **kwargs) diff --git a/torch/_tensor.py b/torch/_tensor.py index 61d5e3891b9e56..98563aebae9aa7 100644 --- a/torch/_tensor.py +++ b/torch/_tensor.py @@ -209,16 +209,8 @@ def __deepcopy__(self, memo): return new_tensor def __reduce_ex__(self, proto): - materialize_fake_tensors = ( - torch.serialization._serialization_tls.materialize_fake_tensors - ) state = torch._utils._get_obj_state(self) - # Ignore all state when using FakeTensor with skip_data(materialize_fake_tensors) because FakeTensor has - # some state that cannot be pickled - if ( - type(self) is torch._subclasses.fake_tensor.FakeTensor - and materialize_fake_tensors - ) or (type(self) is Tensor and not state): + if type(self) is Tensor and not state: # Fast path for regular tensor without Python state. return self._reduce_ex_internal(proto) if has_torch_function_unary(self): @@ -259,12 +251,6 @@ def _reduce_ex_internal(self, proto): # See Note [Don't serialize hooks] warn_if_has_hooks(self) backward_hooks: Dict[Any, Any] = OrderedDict() - - skip_data = torch.serialization._serialization_tls.skip_data - materialize_fake_tensors = ( - torch.serialization._serialization_tls.materialize_fake_tensors - ) - # Note: Numpy array is chosen to be the rebuild component for XLA, MTIA, MAIA Tensors. # We considered a few options: # 1. CPU tensor can't be used here. @@ -282,10 +268,6 @@ def _reduce_ex_internal(self, proto): # Convert BFloat16 tesors to Float32 before conversion to numpy, as numpy doesn't # support BFloat16. The rebuild tensor from numpy takes in the original self.dtype, # this would reconstruct the BFloat16 tensor from numpy. - if skip_data: - raise RuntimeError( - "Cannot serialize tensors on backends with no storage under skip_data context manager" - ) numpy_tensor = ( self.cpu().numpy() if self.dtype != torch.bfloat16 @@ -298,10 +280,6 @@ def _reduce_ex_internal(self, proto): if self.device.type == "meta": # NB: This implementation BREAKS storage sharing. Current # hypothesis is that no one cares for meta tensors. - if skip_data: - warnings.warn( - "Serializing tensors on the meta device under skip_data context manager is a no-op" - ) arg_meta = ( self.dtype, tuple(self.size()), @@ -310,10 +288,6 @@ def _reduce_ex_internal(self, proto): ) return (torch._utils._rebuild_meta_tensor_no_storage, arg_meta) if self.is_quantized: - if skip_data: - raise RuntimeError( - "Cannot serialize qtensor under skip_data context manager, file an issue if you need this feature" - ) # quantizer_params can be different type based on torch attribute quantizer_params: Union[ Tuple[torch.qscheme, float, int], Tuple[Any, Tensor, Tensor, int] @@ -395,10 +369,6 @@ def _reduce_ex_internal(self, proto): ) return (torch._utils._rebuild_sparse_tensor, args_sparse_compressed) elif self.is_nested: - if skip_data: - raise RuntimeError( - "Cannot serialize nested tensor under skip_data context manager, file an issue if you need this feature" - ) args_nested = ( # NB: values() currently returns the storage as a buffer in an unsafe way. # Ideally, we'd use a private API for this instead. TODO: Switch to this if @@ -413,30 +383,14 @@ def _reduce_ex_internal(self, proto): type(self) is not torch.Tensor and type(self).__torch_dispatch__ is not torch.Tensor.__torch_dispatch__ and ( - isinstance(self, torch._subclasses.functional_tensor.FunctionalTensor) - or ( - not isinstance(self, torch._subclasses.fake_tensor.FakeTensor) - and self.data_ptr() == 0 + isinstance( + self, + ( + torch._subclasses.fake_tensor.FakeTensor, + torch._subclasses.functional_tensor.FunctionalTensor, + ), ) - ) - ): - arg_wrapper_subclass = ( - type(self), - self.dtype, - tuple(self.size()), - self.stride(), - self.storage_offset(), - self.layout, - self.device, - self.requires_grad, - ) - return (torch._utils._rebuild_wrapper_subclass, arg_wrapper_subclass) - elif ( - type(self) is not torch.Tensor - and type(self).__torch_dispatch__ is not torch.Tensor.__torch_dispatch__ - and ( - isinstance(self, torch._subclasses.fake_tensor.FakeTensor) - and not (skip_data and materialize_fake_tensors) + or self.data_ptr() == 0 ) ): arg_wrapper_subclass = ( @@ -464,10 +418,6 @@ def _reduce_ex_internal(self, proto): dtype=self.dtype, _internal=True, ) # type: ignore[assignment] - - if isinstance(self, torch._subclasses.fake_tensor.FakeTensor) and skip_data: - storage._fake_device = self.device - args = ( storage, self.storage_offset(), diff --git a/torch/_utils.py b/torch/_utils.py index f0d38daa811490..938392fa971590 100644 --- a/torch/_utils.py +++ b/torch/_utils.py @@ -3,6 +3,7 @@ import functools import logging import sys +import threading import traceback import warnings from collections import defaultdict @@ -108,13 +109,16 @@ def _get_async_or_non_blocking(function_name, non_blocking, kwargs): return kwargs["async"] +_thread_local_state = threading.local() + + def _get_restore_location(device): """Return the map_location location. Used for rebuild functions where the tensor device is distinct from the storage """ - map_location = torch.serialization._serialization_tls.map_location + map_location = getattr(_thread_local_state, "map_location", None) if map_location is None: return device else: diff --git a/torch/serialization.py b/torch/serialization.py index 954596ade61f5e..2ac36ec371fe57 100644 --- a/torch/serialization.py +++ b/torch/serialization.py @@ -11,7 +11,6 @@ import sys import tarfile import tempfile -import threading import warnings from contextlib import closing, contextmanager from enum import Enum @@ -61,7 +60,6 @@ "get_safe_globals", "add_safe_globals", "safe_globals", - "skip_data", ] @@ -89,22 +87,6 @@ MAP_SHARED, MAP_PRIVATE = None, None # type: ignore[assignment] -# _serialization_tls is used to store thread local state specific to serialization -# that needs to be propagated to other files, in particular we use this for -# (1) map_location (needed for wrapper subclasses/third party devices to torch._utils) -# (2) skip_data (needed for torch.Tensor.__reduce_ex__ for skip_data ctx) -# (3) materialize_fake_tensors (needed for torch.Tensor.__reduce_ex__ for skip_data ctx) -class _SerializationLocal(threading.local): - def __init__(self): - super().__init__() - self.map_location: Optional[MAP_LOCATION] = None - self.skip_data: bool = False - self.materialize_fake_tensors: bool = False - - -_serialization_tls = _SerializationLocal() - - class SourceChangeWarning(Warning): pass @@ -286,46 +268,6 @@ class safe_globals(_weights_only_unpickler._safe_globals): """ -class skip_data: - """ - Context-manager that skips writing storage bytes for ``torch.save`` calls. - - Storages will still be saved, but the space that their bytes would usually be written to - will be empty space. The storage bytes can then be populated in a separate pass. - - .. warning:: - The ``skip_data`` context manager is an early prototype and is subject to change. - - Args: - materialize_fake_tensors: Whether to materialize FakeTensors. - - Example: - >>> import tempfile - >>> t = torch.randn(2, 3) - >>> with tempfile.NamedTemporaryFile() as f: - ... with torch.serialization.skip_data(): - ... torch.save(t, f.name) - ... torch.load(f.name, weights_only=True) - tensor([[0., 0., 0.], - [0., 0., 0.]]) - """ - - def __init__(self, materialize_fake_tensors: bool = False): - self.materialize_fake_tensors = materialize_fake_tensors - - def __enter__(self): - global _serialization_tls - self._old_skip_data = _serialization_tls.skip_data - self._old_materialize_fake_tensors = _serialization_tls.materialize_fake_tensors - _serialization_tls.skip_data = True - _serialization_tls.materialize_fake_tensors = self.materialize_fake_tensors - - def __exit__(self, type, value, tb): - global _serialization_tls - _serialization_tls.skip_data = self._old_skip_data - _serialization_tls.materialize_fake_tensors = self._old_materialize_fake_tensors - - def _is_zipfile(f) -> bool: # This is a stricter implementation than zipfile.is_zipfile(). # zipfile.is_zipfile() is True if the magic number appears anywhere in the @@ -855,11 +797,6 @@ def save( ) return else: - global _serialization_tls - if _serialization_tls.skip_data: - raise RuntimeError( - "Cannot use skip_data=True with _use_new_zipfile_serialization=False" - ) with _open_file_like(f, "wb") as opened_file: _legacy_save(obj, opened_file, pickle_module, pickle_protocol) @@ -1018,13 +955,7 @@ def persistent_id(obj: Any) -> Optional[Tuple]: ) -def _save( - obj, - zip_file, - pickle_module, - pickle_protocol, - _disable_byteorder_record, -): +def _save(obj, zip_file, pickle_module, pickle_protocol, _disable_byteorder_record): serialized_storages = {} id_map: Dict[int, str] = {} @@ -1059,7 +990,7 @@ def persistent_id(obj): # If storage is allocated, ensure that any other saved storages # pointing to the same data all have the same dtype. If storage is # not allocated, don't perform this check - if str(storage.device) != "meta" and storage.data_ptr() != 0: + if storage.data_ptr() != 0: if storage.data_ptr() in storage_dtypes: if storage_dtype != storage_dtypes[storage.data_ptr()]: raise RuntimeError( @@ -1070,10 +1001,7 @@ def persistent_id(obj): storage_dtypes[storage.data_ptr()] = storage_dtype storage_key = id_map.setdefault(storage._cdata, str(len(id_map))) - if hasattr(obj, "_fake_device") and obj._fake_device is not None: - location = str(obj._fake_device) - else: - location = location_tag(storage) + location = location_tag(storage) serialized_storages[storage_key] = storage return ("storage", storage_type, storage_key, location, storage_numel) @@ -1099,18 +1027,14 @@ def persistent_id(obj): for key in sorted(serialized_storages.keys()): name = f"data/{key}" storage = serialized_storages[key] + # given that we copy things around anyway, we might use storage.cpu() + # this means to that to get tensors serialized, you need to implement + # .cpu() on the underlying Storage + if storage.device.type != "cpu": + storage = storage.cpu() + # Now that it is on the CPU we can directly copy it into the zip file num_bytes = storage.nbytes() - global _serialization_tls - if _serialization_tls.skip_data: - zip_file.write_record_metadata(name, num_bytes) - else: - # given that we copy things around anyway, we might use storage.cpu() - # this means to that to get tensors serialized, you need to implement - # .cpu() on the underlying Storage - if storage.device.type != "cpu": - storage = storage.cpu() - # Now that it is on the CPU we can directly copy it into the zip file - zip_file.write_record(name, storage, num_bytes) + zip_file.write_record(name, storage, num_bytes) def load( @@ -1260,14 +1184,6 @@ def _get_wo_message(message: str) -> str: updated_message += message return updated_message + DOCS_MESSAGE - global _serialization_tls - skip_data = _serialization_tls.skip_data - if skip_data: - raise RuntimeError( - "`torch.load` called within a torch.serialization.skip_data context manager " - "is not supported yet. Please call torch.load outside the skip_data context manager." - ) - if weights_only is None: weights_only, warn_weights_only = False, True else: @@ -1842,10 +1758,9 @@ def find_class(self, mod_name, name): unpickler.persistent_load = persistent_load # Needed for tensors where storage device and rebuild tensor device are # not connected (wrapper subclasses and tensors rebuilt using numpy) - global _serialization_tls - _serialization_tls.map_location = map_location + torch._utils._thread_local_state.map_location = map_location result = unpickler.load() - _serialization_tls.map_location = None + del torch._utils._thread_local_state.map_location torch._utils._validate_loaded_sparse_tensors() torch._C._log_api_usage_metadata( diff --git a/torch/storage.py b/torch/storage.py index 8848649905f936..b6ba608c16e5ca 100644 --- a/torch/storage.py +++ b/torch/storage.py @@ -39,8 +39,6 @@ class _StorageBase: is_sparse: _bool = False is_sparse_csr: _bool = False device: torch.device - # Used when stashing FakeTensor device onto storage in torch.save(metadata_only=True) - _fake_device: _Optional[torch.device] = None def __init__(self, *args, **kwargs): pass @@ -651,8 +649,6 @@ def _get_device_from_module(module: str): class TypedStorage: is_sparse: _bool = False - # Used when stashing FakeTensor device onto storage in torch.save(metadata_only=True) - _fake_device: _Optional[torch.device] = None dtype: torch.dtype From 18877905a451c1a8a685614d9c1f327b474dc791 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Thu, 29 Aug 2024 01:42:52 +0000 Subject: [PATCH 0188/1018] Revert "Add MaskedTensor support to *_like API (#128637)" This reverts commit b6e51711a0ea6174806e75ab6e208d2d910b45f5. Reverted https://github.com/pytorch/pytorch/pull/128637 on behalf of https://github.com/ZainRizvi due to Actually, seems like it was this commit that introduced the failure: test_maskedtensor.py::TestOperatorsCUDA::test_like_empty_like_layout1_cuda_bool [GH job link](https://github.com/pytorch/pytorch/actions/runs/10604690725/job/29392898277) [HUD commit link](https://hud.pytorch.org/pytorch/pytorch/commit/b6e51711a0ea6174806e75ab6e208d2d910b45f5) ([comment](https://github.com/pytorch/pytorch/pull/128637#issuecomment-2316554188)) --- docs/source/masked.rst | 14 ------- test/test_maskedtensor.py | 53 -------------------------- torch/masked/maskedtensor/__init__.py | 1 - torch/masked/maskedtensor/_ops_refs.py | 13 ++++--- torch/masked/maskedtensor/like.py | 38 ------------------ 5 files changed, 7 insertions(+), 112 deletions(-) delete mode 100644 torch/masked/maskedtensor/like.py diff --git a/docs/source/masked.rst b/docs/source/masked.rst index 00dadd2abf6885..60dd67f643b890 100644 --- a/docs/source/masked.rst +++ b/docs/source/masked.rst @@ -296,25 +296,11 @@ The following ops are currently supported: Tensor.reshape_as Tensor.view -Other functions ---------------- - -The following :mod:`torch` functions has also a specific support for :class:`~torch.masked.MaskedTensor`: - -:func:`~torch.empty_like`, -:func:`~torch.full_like`, -:func:`~torch.ones_like`, -:func:`~torch.rand_like`, -:func:`~torch.randint_like`, -:func:`~torch.randn_like`, -:func:`~torch.zeros_like`. - .. This module needs to be documented. Adding here in the meantime .. for tracking purposes .. py:module:: torch.masked.maskedtensor.binary .. py:module:: torch.masked.maskedtensor.core .. py:module:: torch.masked.maskedtensor.creation -.. py:module:: torch.masked.maskedtensor.like .. py:module:: torch.masked.maskedtensor.passthrough .. py:module:: torch.masked.maskedtensor.reductions .. py:module:: torch.masked.maskedtensor.unary \ No newline at end of file diff --git a/test/test_maskedtensor.py b/test/test_maskedtensor.py index 0817822be457b2..bf42cf086376ec 100644 --- a/test/test_maskedtensor.py +++ b/test/test_maskedtensor.py @@ -19,7 +19,6 @@ binary_ufuncs, reduction_ops, unary_ufuncs, - ops_and_refs, ) from torch.masked import as_masked_tensor, masked_tensor, _combine_input_and_mask @@ -27,7 +26,6 @@ from torch.masked.maskedtensor.unary import NATIVE_INPLACE_UNARY_FNS, NATIVE_UNARY_FNS, UNARY_NAMES from torch.masked.maskedtensor.binary import NATIVE_BINARY_FNS, NATIVE_INPLACE_BINARY_FNS, BINARY_NAMES from torch.masked.maskedtensor.reductions import REDUCE_NAMES -from torch.masked.maskedtensor.like import LIKE_NAMES def _compare_mt_t(mt_result, t_result, rtol=1e-05, atol=1e-05): @@ -814,13 +812,9 @@ def is_binary(op): def is_reduction(op): return op.name in REDUCE_NAMES and op.name not in {"all", "mean", "std", "var"} -def is_like(op): - return op.name in LIKE_NAMES - mt_unary_ufuncs = [op for op in unary_ufuncs if is_unary(op)] mt_binary_ufuncs = [op for op in binary_ufuncs if is_binary(op)] mt_reduction_ufuncs = [op for op in reduction_ops if is_reduction(op)] -mt_like_funcs = [op for op in ops_and_refs if is_like(op)] MASKEDTENSOR_FLOAT_TYPES = { torch.float16, @@ -828,16 +822,6 @@ def is_like(op): torch.float64, } -MASKEDTENSOR_SUPPORTED_TYPES = [ - *MASKEDTENSOR_FLOAT_TYPES, - torch.bool, - torch.int8, - torch.int16, - torch.int32, - torch.int64, -] - - class TestOperators(TestCase): def _convert_mt_args(self, args, mask, layout): return [ @@ -983,43 +967,6 @@ def test_reduction_all(self, device, dtype, op, layout): self._test_reduction_equality(device, dtype, op, layout) - @ops(mt_like_funcs, allowed_dtypes=MASKEDTENSOR_SUPPORTED_TYPES) # type: ignore[arg-type] - @parametrize("layout", [torch.strided, torch.sparse_coo, torch.sparse_csr]) - def test_like(self, device, dtype, op, layout): - samples = op.sample_inputs(device, dtype, requires_grad=dtype in MASKEDTENSOR_FLOAT_TYPES) - - for sample in samples: - input = sample.input - sample_args, sample_kwargs = sample.args, sample.kwargs - mask = _create_random_mask(input.shape, device) - - if layout == torch.sparse_coo: - mask = mask.to_sparse_coo().coalesce() - input = input.sparse_mask(mask) - elif layout == torch.sparse_csr: - if input.ndim != 2 or mask.ndim != 2: - continue - mask = mask.to_sparse_csr() - input = input.sparse_mask(mask) - - mt = masked_tensor(input, mask) - - try: - t_result = op(input, *sample_args, **sample_kwargs) - except NotImplementedError: - self.skipTest(f"{op.name} is not supported for {layout}") - - mt_result = op(mt, *sample_args, **sample_kwargs) - - if 'rand' not in op.name and op.name != 'empty_like': - _compare_mt_t(mt_result, t_result.to_dense()) - - self.assertEqual(mt_result.shape, t_result.shape) - self.assertEqual(mt_result.dtype, t_result.dtype) - self.assertEqual(mt_result.layout, t_result.layout) - self.assertEqual(mt_result.device, t_result.device) - self.assertEqual(mt_result.requires_grad, t_result.requires_grad) - only_for = ("cpu", "cuda") instantiate_device_type_tests(TestOperators, globals(), only_for=only_for) diff --git a/torch/masked/maskedtensor/__init__.py b/torch/masked/maskedtensor/__init__.py index 9c908e9a5cba6b..e38e03c87086cf 100644 --- a/torch/masked/maskedtensor/__init__.py +++ b/torch/masked/maskedtensor/__init__.py @@ -3,7 +3,6 @@ from .binary import _apply_native_binary, _is_native_binary from .core import is_masked_tensor, MaskedTensor -from .like import _apply_like_fn, _is_like_fn from .passthrough import _apply_pass_through_fn, _is_pass_through_fn from .reductions import _apply_reduction, _is_reduction from .unary import _apply_native_unary, _is_native_unary diff --git a/torch/masked/maskedtensor/_ops_refs.py b/torch/masked/maskedtensor/_ops_refs.py index febb0a95359b32..7344a6e801aae2 100644 --- a/torch/masked/maskedtensor/_ops_refs.py +++ b/torch/masked/maskedtensor/_ops_refs.py @@ -14,7 +14,6 @@ is_masked_tensor, MaskedTensor, ) -from .like import _apply_like_fn, LIKE_FNS from .passthrough import _apply_pass_through_fn, PASSTHROUGH_FNS from .reductions import ( _apply_reduction, @@ -271,11 +270,6 @@ def _general_binary(func, *args, **kwargs): return _apply_native_binary(func, *args, **kwargs) -@register_dispatch_func(LIKE_FNS) -def _general_like(func, *args, **kwargs): - return _apply_like_fn(func, *args, **kwargs) - - @register_dispatch_func([torch.ops.aten.stride]) def stride(func, *args, **kwargs): return None @@ -371,6 +365,13 @@ def _softmax(func, *args, **kwargs): return MaskedTensor(result_data, mask) +@register_dispatch_func([torch.ops.aten.ones_like]) +def ones_like(func, *args, **kwargs): + _check_args_kwargs_length(args, kwargs, f"__torch_dispatch__, {func}", len_args=1) + result_data = func(_get_data(args[0]), **kwargs) + return MaskedTensor(result_data, _maybe_get_mask(args[0])) + + @register_dispatch_func([torch.ops.aten._softmax_backward_data]) def _softmax_backward_data(func, *args, **kwargs): _check_args_kwargs_length(args, kwargs, f"__torch_dispatch__, {func}", len_args=4) diff --git a/torch/masked/maskedtensor/like.py b/torch/masked/maskedtensor/like.py deleted file mode 100644 index 8503670e1ba660..00000000000000 --- a/torch/masked/maskedtensor/like.py +++ /dev/null @@ -1,38 +0,0 @@ -# mypy: allow-untyped-defs -""" -These are the creation ops using the ``*_like`` API that need to support the -`MaskedTensor`. The wrapper just applies the function to the masked data then convert -it to a masked tensor using the mask from the given tensor. -""" - -import torch - -from .core import _get_data, _maybe_get_mask, MaskedTensor - - -__all__ = [ - "LIKE_NAMES", - "LIKE_FNS", -] - - -LIKE_NAMES = [ - "empty_like", - "full_like", - "ones_like", - "rand_like", - "randint_like", - "randn_like", - "zeros_like", -] - -LIKE_FNS = [getattr(torch.ops.aten, name) for name in LIKE_NAMES] - - -def _is_like_fn(fn): - return fn in LIKE_FNS - - -def _apply_like_fn(func, *args, **kwargs): - result_data = func(_get_data(args[0]), *args[1:], **kwargs) - return MaskedTensor(result_data, _maybe_get_mask(args[0])) From b3aeaeaa3839367cbb25d524faef76d7a0bc9b4a Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Thu, 29 Aug 2024 01:50:22 +0000 Subject: [PATCH 0189/1018] Revert "[3/N] Set correct device to CUDA guards (#134357)" This reverts commit afc76c6f2d46d7726012507ec5c67b4c04e21723. Reverted https://github.com/pytorch/pytorch/pull/134357 on behalf of https://github.com/kwen2501 due to This is breaking builds of MTIA ([comment](https://github.com/pytorch/pytorch/pull/134300#issuecomment-2316559704)) --- test/distributed/test_c10d_nccl.py | 19 ------------------- .../distributed/c10d/ProcessGroupNCCL.cpp | 18 ++++++++---------- 2 files changed, 8 insertions(+), 29 deletions(-) diff --git a/test/distributed/test_c10d_nccl.py b/test/distributed/test_c10d_nccl.py index 3a575729cdff78..55b785a3552b5a 100644 --- a/test/distributed/test_c10d_nccl.py +++ b/test/distributed/test_c10d_nccl.py @@ -350,7 +350,6 @@ def test_close_pg(self): @parametrize("type", [torch.float16, torch.float32, torch.float64, torch.bfloat16]) @skip_if_rocm def test_nan_assert(self, type): - # Expecting a device-side error when NaN is detected os.environ["TORCH_NCCL_NAN_CHECK"] = "1" store = c10d.FileStore(self.file_name, self.world_size) pg = self._create_process_group_nccl(store, self.opts()) @@ -389,24 +388,6 @@ def test_nan_p2p(self): # reset env os.environ["TORCH_NCCL_NAN_CHECK"] = "0" - @requires_nccl() - @skip_if_lt_x_gpu(2) - def test_nan_check(self): - # Not expecting an error, NaN check should not make legit code fail - os.environ["TORCH_NCCL_NAN_CHECK"] = "1" - store = c10d.FileStore(self.file_name, self.world_size) - device = torch.device("cuda:%d" % self.rank) - c10d.init_process_group( - backend="nccl", store=store, rank=self.rank, world_size=self.world_size - ) - x = torch.ones((10,), dtype=torch.bfloat16, device=device) * self.rank - t = torch.ones(3, 4, dtype=torch.bfloat16, device=device) - c10d.broadcast(x, src=0) - c10d.all_reduce(t) - c10d.destroy_process_group() - # reset env - os.environ["TORCH_NCCL_NAN_CHECK"] = "0" - @requires_nccl() @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs") def test_destruct_before_terminate_pg(self): diff --git a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp index 8f741185a2c584..291eb7993ef2fb 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp +++ b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp @@ -1148,10 +1148,6 @@ void ProcessGroupNCCL::abortCommsFromMap( at::cuda::OptionalCUDAGuard gpuGuard; at::DeviceIndex deviceIndex = getIndexFromDeviceKey(devName); if (deviceIndex >= 0) { - // For P2P comms, the deviceIndex could be -1 (invalid), as the keys in - // the map could be non deviceIndex, but rank to rank numbers. So we - // indeed need to check if deviceIndex >= 0 - // TODO: fix `getIndexFromDeviceKey` or fix `DeviceKey` gpuGuard.set_index(deviceIndex); } LOG(INFO) << logPrefix() << "ProcessGroupNCCL destroying ncclComm_ " @@ -2145,9 +2141,7 @@ std::shared_ptr ProcessGroupNCCL::getNCCLComm( bool batchP2P = ncclActiveGroupCounter_ > 0; bool singleP2POp = isP2POp(opType, batchP2P); - // Get the device index - auto deviceIndex = device.index(); - at::cuda::OptionalCUDAGuard gpuGuard(device); + at::cuda::OptionalCUDAGuard gpuGuard; // [Group Start/End Note] This is used to ensure that nccl communicator will // be created before communication primitives are called. Let's look at this @@ -2187,6 +2181,10 @@ std::shared_ptr ProcessGroupNCCL::getNCCLComm( rank = p2pRank; } + // Get the device index + auto deviceIndex = device.index(); + gpuGuard.set_index(deviceIndex); + #ifdef NCCL_HAS_COMM_SPLIT if (options_->split_from) { TORCH_CHECK( @@ -2696,7 +2694,7 @@ c10::intrusive_ptr ProcessGroupNCCL::collective( work->stashed_for_allocator_safety_->push_back(input); } - at::cuda::OptionalCUDAGuard gpuGuard(device); + at::cuda::OptionalCUDAGuard gpuGuard; if (nanCheck) { checkForNan(input, ncclStream); @@ -2861,7 +2859,7 @@ c10::intrusive_ptr ProcessGroupNCCL::collectiveCoalesced( std::make_shared>(inputs); } - at::cuda::OptionalCUDAGuard gpuGuard(device); + at::cuda::OptionalCUDAGuard gpuGuard; // Start event should only be recorded before the ncclGroupStart() (which // happens inside AutoNcclGroup guard below) @@ -3129,7 +3127,7 @@ c10::intrusive_ptr ProcessGroupNCCL::pointToPoint( } // is gpuGuard needed for the if block below, or can i swap them - at::cuda::OptionalCUDAGuard gpuGuard(device); + at::cuda::OptionalCUDAGuard gpuGuard; // Only check for NaN for send ops, for recv ops `tensor` can be a random // placeholder From 5b79f86d83a1c1a2ae69402ee2281392c820243d Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Thu, 29 Aug 2024 01:50:22 +0000 Subject: [PATCH 0190/1018] Revert "[2/N] Add flag to control which rank should perform NaN check (#134345)" This reverts commit 2fe7e332c7a61f025ccbcdbbb4875c6bf0b9afdf. Reverted https://github.com/pytorch/pytorch/pull/134345 on behalf of https://github.com/kwen2501 due to This is breaking builds of MTIA ([comment](https://github.com/pytorch/pytorch/pull/134300#issuecomment-2316559704)) --- test/distributed/test_c10d_nccl.py | 22 ----------- .../distributed/c10d/ProcessGroupNCCL.cpp | 39 ++++++------------- .../distributed/c10d/ProcessGroupNCCL.hpp | 6 +-- 3 files changed, 13 insertions(+), 54 deletions(-) diff --git a/test/distributed/test_c10d_nccl.py b/test/distributed/test_c10d_nccl.py index 55b785a3552b5a..64c51f791fddd0 100644 --- a/test/distributed/test_c10d_nccl.py +++ b/test/distributed/test_c10d_nccl.py @@ -366,28 +366,6 @@ def test_nan_assert(self, type): # reset env os.environ["TORCH_NCCL_NAN_CHECK"] = "0" - @requires_nccl() - @skip_if_lt_x_gpu(2) - def test_nan_p2p(self): - # Putting NaN at recv buffer, program should not fail as NaN checker - # should not check on receive buffer - os.environ["TORCH_NCCL_NAN_CHECK"] = "1" - store = c10d.FileStore(self.file_name, self.world_size) - device = torch.device("cuda:%d" % self.rank) - c10d.init_process_group( - backend="nccl", store=store, rank=self.rank, world_size=self.world_size - ) - t = torch.ones(3, 4, dtype=torch.bfloat16, device=device) - if self.rank == 0: - c10d.send(t, 1) - elif self.rank == 1: - # Putting NaN at recv buffer - t[1, 1] = float("nan") - c10d.recv(t, 0) - c10d.destroy_process_group() - # reset env - os.environ["TORCH_NCCL_NAN_CHECK"] = "0" - @requires_nccl() @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs") def test_destruct_before_terminate_pg(self): diff --git a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp index 291eb7993ef2fb..abcf493e62fb8e 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp +++ b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp @@ -2637,12 +2637,9 @@ c10::intrusive_ptr ProcessGroupNCCL::collective( PostProcess post, OpType opType, const char* profilingTitle, - bool avoidRecordStreams, - bool nanCheck) { + bool avoidRecordStreams) { // Environment setting by the user may add onto collective call's option avoidRecordStreams |= avoidRecordStreams_; - nanCheck &= enableNanCheck_; - c10::cuda::CaptureStatus capture_status = c10::cuda::currentStreamCaptureStatusMayInitCtx(); errorIfCapturingNonCapturableNCCL(capture_status); @@ -2696,7 +2693,7 @@ c10::intrusive_ptr ProcessGroupNCCL::collective( at::cuda::OptionalCUDAGuard gpuGuard; - if (nanCheck) { + if (enableNanCheck_) { checkForNan(input, ncclStream); } @@ -3129,9 +3126,7 @@ c10::intrusive_ptr ProcessGroupNCCL::pointToPoint( // is gpuGuard needed for the if block below, or can i swap them at::cuda::OptionalCUDAGuard gpuGuard; - // Only check for NaN for send ops, for recv ops `tensor` can be a random - // placeholder - if (enableNanCheck_ && opType == OpType::SEND) { + if (enableNanCheck_) { checkForNan(tensor, ncclStream); } @@ -3228,8 +3223,7 @@ c10::intrusive_ptr ProcessGroupNCCL::collective( Fn fn, OpType opType, const char* profilingTitle, - bool avoidRecordStreams, - bool nanCheck) { + bool avoidRecordStreams) { return collective( input, output, @@ -3240,8 +3234,7 @@ c10::intrusive_ptr ProcessGroupNCCL::collective( c10::intrusive_ptr& work) {}, opType, profilingTitle, - avoidRecordStreams, - nanCheck); + avoidRecordStreams); } template @@ -3491,9 +3484,6 @@ c10::intrusive_ptr ProcessGroupNCCL::broadcast( // avoidRecordStreams_ note: collective() will stash tensors. bool avoidRecordStreams = avoidRecordStreams_ || (!opts.asyncOp); - const auto root = opts.rootRank + opts.rootTensor; - bool nanCheck = (root == rank_); - return collective( tensor, tensor, @@ -3501,6 +3491,7 @@ c10::intrusive_ptr ProcessGroupNCCL::broadcast( at::Tensor& output, ncclComm_t comm, at::cuda::CUDAStream& stream) { + const auto root = opts.rootRank + opts.rootTensor; return ncclBcast( input.data_ptr(), input.numel(), @@ -3511,8 +3502,7 @@ c10::intrusive_ptr ProcessGroupNCCL::broadcast( }, OpType::BROADCAST, "nccl:broadcast", - avoidRecordStreams, - nanCheck); + avoidRecordStreams); } // _broadcast_oop adds an out-of-place broadcast in PGNCCL @@ -3532,9 +3522,6 @@ c10::intrusive_ptr ProcessGroupNCCL::_broadcast_oop( "Tensor input and output of _broadcast_oop must have the same number of elements "); } - const auto root = opts.rootRank + opts.rootTensor; - bool nanCheck = (root == rank_); - return collective( inputTensor, outputTensor, @@ -3542,6 +3529,7 @@ c10::intrusive_ptr ProcessGroupNCCL::_broadcast_oop( at::Tensor& output, ncclComm_t comm, at::cuda::CUDAStream& stream) { + const auto root = opts.rootRank + opts.rootTensor; return ncclBroadcast( input.data_ptr(), output.data_ptr(), @@ -3552,9 +3540,7 @@ c10::intrusive_ptr ProcessGroupNCCL::_broadcast_oop( stream.stream()); }, OpType::BROADCAST, - "nccl:_broadcast_oop", - /*avoidRecordStreams=*/false, - nanCheck); + "nccl:_broadcast_oop"); } c10::intrusive_ptr ProcessGroupNCCL::reduce( @@ -4505,9 +4491,6 @@ c10::intrusive_ptr ProcessGroupNCCL::scatter( // inputs, which == inputTensors[0] on the root rank where it matters. bool avoidRecordStreams = avoidRecordStreams_ || (!opts.asyncOp); - const auto root = opts.rootRank; - bool nanCheck = (rank_ == root); - return collective( outputTensor, inputs[0], // just to fit the collective interface @@ -4515,6 +4498,7 @@ c10::intrusive_ptr ProcessGroupNCCL::scatter( at::Tensor& /* unused */, ncclComm_t comm, at::cuda::CUDAStream& stream) { + const auto root = opts.rootRank; if (getRank() == root) { if (!avoidRecordStreams) { for (auto input : inputs) { @@ -4528,8 +4512,7 @@ c10::intrusive_ptr ProcessGroupNCCL::scatter( }, OpType::SCATTER, "nccl:scatter", - avoidRecordStreams, - nanCheck); + avoidRecordStreams); } c10::intrusive_ptr ProcessGroupNCCL::recvAnysource( diff --git a/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp b/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp index b6635157d75e1b..2ba68cda9cc3ea 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp +++ b/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp @@ -762,8 +762,7 @@ class TORCH_API ProcessGroupNCCL : public Backend { Fn fn, OpType opType, const char* profilingTitle = nullptr, - bool avoidRecordStreams = false, - bool nanCheck = true); + bool avoidRecordStreams = false); template c10::intrusive_ptr collective( @@ -774,8 +773,7 @@ class TORCH_API ProcessGroupNCCL : public Backend { PostProcess post, OpType opType, const char* profilingTitle = nullptr, - bool avoidRecordStreams = false, - bool nanCheck = true); + bool avoidRecordStreams = false); template c10::intrusive_ptr collectiveCoalesced( From 1770ef208e127c486f8a0fd417197da6b2da5e36 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Thu, 29 Aug 2024 01:50:22 +0000 Subject: [PATCH 0191/1018] Revert "[1/N] Move NaN check onto NCCL stream (#134300)" This reverts commit 94caba4899096f160eca9628acddba6032755b3b. Reverted https://github.com/pytorch/pytorch/pull/134300 on behalf of https://github.com/kwen2501 due to This is breaking builds of MTIA ([comment](https://github.com/pytorch/pytorch/pull/134300#issuecomment-2316559704)) --- BUILD.bazel | 4 ++-- build_variables.bzl | 2 +- torch/csrc/distributed/c10d/NCCLUtils.hpp | 6 ------ torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp | 14 ++++++-------- .../distributed/c10d/{NCCLUtils.cu => Utils.cu} | 6 +++--- torch/csrc/distributed/c10d/Utils.hpp | 2 ++ 6 files changed, 14 insertions(+), 20 deletions(-) rename torch/csrc/distributed/c10d/{NCCLUtils.cu => Utils.cu} (86%) diff --git a/BUILD.bazel b/BUILD.bazel index 523f2ad2ce60a4..f079c76a7f7204 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -574,7 +574,7 @@ cu_library( name = "torch_cuda", srcs = [ "torch/csrc/distributed/c10d/intra_node_comm.cu", - "torch/csrc/distributed/c10d/NCCLUtils.cu", + "torch/csrc/distributed/c10d/Utils.cu", "torch/csrc/distributed/c10d/quantization/quantization_gpu.cu", ], copts = torch_cuda_half_options, @@ -722,7 +722,7 @@ cc_library( "torch/csrc/distributed/c10d/intra_node_comm.cu", "torch/csrc/distributed/c10d/CUDASymmetricMemory.cu", "torch/csrc/distributed/c10d/CUDASymmetricMemoryOps.cu", - "torch/csrc/distributed/c10d/NCCLUtils.cu", + "torch/csrc/distributed/c10d/Utils.cu", "torch/csrc/distributed/c10d/quantization/quantization_gpu.cu", ], )) + torch_sources, diff --git a/build_variables.bzl b/build_variables.bzl index bf7ffda0abda63..c3c06c01ada6e8 100644 --- a/build_variables.bzl +++ b/build_variables.bzl @@ -691,7 +691,7 @@ libtorch_cuda_distributed_extra_sources = [ "torch/csrc/distributed/c10d/intra_node_comm.cu", "torch/csrc/distributed/c10d/CUDASymmetricMemory.cu", "torch/csrc/distributed/c10d/CUDASymmetricMemoryOps.cu", - "torch/csrc/distributed/c10d/NCCLUtils.cu", + "torch/csrc/distributed/c10d/Utils.cu", "torch/csrc/distributed/rpc/tensorpipe_cuda.cpp", "torch/csrc/distributed/c10d/quantization/quantization_gpu.cu", ] diff --git a/torch/csrc/distributed/c10d/NCCLUtils.hpp b/torch/csrc/distributed/c10d/NCCLUtils.hpp index e45d9d09b25379..b34ff33333daf0 100644 --- a/torch/csrc/distributed/c10d/NCCLUtils.hpp +++ b/torch/csrc/distributed/c10d/NCCLUtils.hpp @@ -10,7 +10,6 @@ #include #include -#include #include #include #include @@ -715,11 +714,6 @@ struct NCCLTraceBuffer { bool includeStackTraces, bool onlyActive); }; - -// Check for NaNs in a tensor on a given stream. If any are found, throw a -// device-side error. -void checkForNan(const at::Tensor& tensor, at::cuda::CUDAStream& stream); - } // namespace c10d #endif // USE_C10D_NCCL diff --git a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp index abcf493e62fb8e..71a081b5e495d5 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp +++ b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp @@ -2638,6 +2638,9 @@ c10::intrusive_ptr ProcessGroupNCCL::collective( OpType opType, const char* profilingTitle, bool avoidRecordStreams) { + if (enableNanCheck_) { + checkForNan(input); + } // Environment setting by the user may add onto collective call's option avoidRecordStreams |= avoidRecordStreams_; c10::cuda::CaptureStatus capture_status = @@ -2693,10 +2696,6 @@ c10::intrusive_ptr ProcessGroupNCCL::collective( at::cuda::OptionalCUDAGuard gpuGuard; - if (enableNanCheck_) { - checkForNan(input, ncclStream); - } - // Start event should only be recorded before the ncclGroupStart() if (work->timingEnabled_) { work->ncclStartEvent_->record(ncclStream); @@ -2998,6 +2997,9 @@ c10::intrusive_ptr ProcessGroupNCCL::pointToPoint( PreProcess pre, PostProcess post, const char* profilingTitle) { + if (enableNanCheck_) { + checkForNan(tensor); + } // avoidRecordStreams_ note: // send, recv, and irecv should be ok with avoidRecordStreams, // However, for isend, I don't think the API requires the user @@ -3126,10 +3128,6 @@ c10::intrusive_ptr ProcessGroupNCCL::pointToPoint( // is gpuGuard needed for the if block below, or can i swap them at::cuda::OptionalCUDAGuard gpuGuard; - if (enableNanCheck_) { - checkForNan(tensor, ncclStream); - } - if (!coalescing_state_) { // Start event should only be recorded before the ncclGroupStart() if (work->timingEnabled_) { diff --git a/torch/csrc/distributed/c10d/NCCLUtils.cu b/torch/csrc/distributed/c10d/Utils.cu similarity index 86% rename from torch/csrc/distributed/c10d/NCCLUtils.cu rename to torch/csrc/distributed/c10d/Utils.cu index 5dcf01fb98b153..ae2017efdf8d97 100644 --- a/torch/csrc/distributed/c10d/NCCLUtils.cu +++ b/torch/csrc/distributed/c10d/Utils.cu @@ -1,7 +1,7 @@ #include #include #include -#include +#include #include #include @@ -20,7 +20,7 @@ __global__ void checkForNaN(T* data, size_t size) { } // CHECK if a Tensor contains NAN in any of its element -void checkForNan(const at::Tensor& tensor, at::cuda::CUDAStream& stream) { +void checkForNan(const at::Tensor& tensor) { // skip check for non float types if (!torch::is_floating_point(tensor)) { return; @@ -40,7 +40,7 @@ void checkForNan(const at::Tensor& tensor, at::cuda::CUDAStream& stream) { tensor.scalar_type(), "checkForNaN", [&] { - checkForNaN<<>>( + checkForNaN<<>>( tensor.data_ptr(), tensor.numel()); C10_CUDA_KERNEL_LAUNCH_CHECK(); }); diff --git a/torch/csrc/distributed/c10d/Utils.hpp b/torch/csrc/distributed/c10d/Utils.hpp index ea4a4653bc35fc..5c7365393404fc 100644 --- a/torch/csrc/distributed/c10d/Utils.hpp +++ b/torch/csrc/distributed/c10d/Utils.hpp @@ -611,6 +611,8 @@ using SizeType = uint64_t; // Since SOCKET_ERROR = -1 in MSVC, so also leverage SYSCHECK_ERR_RETURN_NEG1 #define SYSCHECK_ERR_RETURN_NEG1(expr) SYSCHECK(expr, __output != -1) +void checkForNan(const at::Tensor& tensor); + namespace tcputil { // Send and receive From 9b3a5443fd313cb1d2081a2be6e3d85de67f590f Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Thu, 29 Aug 2024 01:59:02 +0000 Subject: [PATCH 0192/1018] Revert "[2nd try][Traceable FSDP2] Allow tracing through FSDP2 impl in trace_rules.py (#134539)" This reverts commit 26e392132d3039345de6aaf8643e7330f7fc3cbc. Reverted https://github.com/pytorch/pytorch/pull/134539 on behalf of https://github.com/ZainRizvi due to Sorry, I had to revert this in order to revert another PR ([comment](https://github.com/pytorch/pytorch/pull/134539#issuecomment-2316568257)) --- torch/_dynamo/trace_rules.py | 4 ---- torch/distributed/_composable/fsdp/_fsdp_state.py | 2 -- 2 files changed, 6 deletions(-) diff --git a/torch/_dynamo/trace_rules.py b/torch/_dynamo/trace_rules.py index 966a54e672ee39..87db00ad1f9bab 100644 --- a/torch/_dynamo/trace_rules.py +++ b/torch/_dynamo/trace_rules.py @@ -3213,8 +3213,6 @@ def _module_dir(m: types.ModuleType): # the forward_hook won't be ignored. "torch.distributed._composable.replicate", } - if not torch._dynamo.config.skip_fsdp_hooks: - LEGACY_MOD_INLINELIST.add("torch.distributed._composable.fsdp") # Force inline functions under these modules, even they are in *_SKIPLIST. @@ -3264,8 +3262,6 @@ def _module_dir(m: types.ModuleType): if torch.distributed.is_available(): MOD_INLINELIST.add("torch.distributed") - if not torch._dynamo.config.skip_fsdp_hooks: - MOD_INLINELIST.add("torch.distributed._composable.fsdp") @functools.lru_cache(None) diff --git a/torch/distributed/_composable/fsdp/_fsdp_state.py b/torch/distributed/_composable/fsdp/_fsdp_state.py index ceb480fd23939f..4b949c64d9f7f8 100644 --- a/torch/distributed/_composable/fsdp/_fsdp_state.py +++ b/torch/distributed/_composable/fsdp/_fsdp_state.py @@ -351,14 +351,12 @@ def _register_group_forward_hooks( """ modules_set = set(modules) - @disable_if_config_true @functools.wraps(pre_hook) def wrapped_pre_hook(*args: Any, **kwargs: Any): if len(modules_to_run) == 0: # first to run modules_to_run.update(modules_set) return pre_hook(*args, **kwargs) - @disable_if_config_true def get_wrapped_post_hook(module: nn.Module): @functools.wraps(post_hook) def wrapped_post_hook(*args: Any, **kwargs: Any): From 201cfb75e2c6a5add5f2a6c96e703b4cf268b3d3 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Thu, 29 Aug 2024 02:02:10 +0000 Subject: [PATCH 0193/1018] Revert "[dynamo] Graph break on FSDP flat_param inconsistent tensor and grad dtype (#134614)" This reverts commit e4a5958ab58e2f9b5b9c336a1d2a6449784d88d3. Reverted https://github.com/pytorch/pytorch/pull/134614 on behalf of https://github.com/ZainRizvi due to Sorry, I had to revert this in order to revert another PR ([comment](https://github.com/pytorch/pytorch/pull/134610#issuecomment-2316568553)) --- torch/_dynamo/variables/builder.py | 21 +-------------------- 1 file changed, 1 insertion(+), 20 deletions(-) diff --git a/torch/_dynamo/variables/builder.py b/torch/_dynamo/variables/builder.py index b4f35dbbd404b8..9cd48c9c04c85b 100644 --- a/torch/_dynamo/variables/builder.py +++ b/torch/_dynamo/variables/builder.py @@ -15,7 +15,6 @@ import re import sys import types -import warnings import weakref from typing import Any, List, MutableMapping, NamedTuple, Optional, TYPE_CHECKING, Union @@ -26,7 +25,7 @@ from torch._ops import HigherOrderOperator from torch._streambase import _EventBase, _StreamBase from torch._subclasses.fake_tensor import FakeTensor, is_fake, maybe_get_fake_mode -from torch._subclasses.meta_utils import is_sparse_any, safe_grad +from torch._subclasses.meta_utils import is_sparse_any from torch._utils_internal import justknobs_check from torch.fx.experimental._backward_state import BackwardState from torch.fx.experimental.symbolic_shapes import ( @@ -221,12 +220,6 @@ DimList = List -def safe_has_grad(t): - with warnings.catch_warnings(): - warnings.filterwarnings("ignore", "The .grad attribute of a Tensor") - return hasattr(t, "grad") - - class _missing: pass @@ -1519,18 +1512,6 @@ def wrap_tensor(self, value: torch.Tensor): # SPARSE_TENSOR_GUARDS for guards to work propertly. unimplemented("torch.compile does not support sparse Tensors") - if ( - safe_has_grad(value) - and safe_grad(value) is not None - and value.dtype != safe_grad(value).dtype - ): - unimplemented( - "Inconsistent dtype between tensor and its gradient. " - "This can happen in FSDP and crashes meta tensor creation. " - "This is potentially a workaround. Fixing it correctly " - "requires some design around FSDP + torch.compile." - ) - tensor_variable = wrap_fx_proxy( tx=self.tx, proxy=tensor_proxy, From 9a855f6f2414b2a69ca664f372005eaafbe0e626 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Thu, 29 Aug 2024 02:02:10 +0000 Subject: [PATCH 0194/1018] Revert "[raland][dynamo][exceptions] Support raise from None (#134621)" This reverts commit e96dc3665a1d48434c02e17f7faed41f779cee2c. Reverted https://github.com/pytorch/pytorch/pull/134621 on behalf of https://github.com/ZainRizvi due to Sorry, I had to revert this in order to revert another PR ([comment](https://github.com/pytorch/pytorch/pull/134610#issuecomment-2316568553)) --- test/dynamo/test_exceptions.py | 33 --------------- torch/_dynamo/symbolic_convert.py | 69 +++++++++++++------------------ 2 files changed, 29 insertions(+), 73 deletions(-) diff --git a/test/dynamo/test_exceptions.py b/test/dynamo/test_exceptions.py index 5efa344c7ea2ed..33c6e3d0f6121b 100644 --- a/test/dynamo/test_exceptions.py +++ b/test/dynamo/test_exceptions.py @@ -353,39 +353,6 @@ def fn(x): res = opt_fn(x) self.assertEqual(ref, res) - def test_raise_from_None(self): - # Inspired from os.environ - class MyMapping: - def __init__(self, d): - self._d = d - - def __getitem__(self, key): - try: - value = self._d[key] - except KeyError: - raise KeyError(key) from None - return value - - d = MyMapping({"a": 10, "b": 20}) - - def mapping_get(obj, key, value=None): - try: - return obj.__getitem__(key) - except KeyError: - return value - - def fn(x, d, key): - x = torch.sin(x + 1) - return x, mapping_get(d, key) - - opt_fn = torch.compile(fn, backend="eager", fullgraph=True) - - x = torch.rand(2, 3) - ref = fn(x, d, "m") - res = opt_fn(x, d, "m") - self.assertEqual(ref[0], res[0]) - self.assertEqual(ref[1], res[1]) - if __name__ == "__main__": from torch._dynamo.test_case import run_tests diff --git a/torch/_dynamo/symbolic_convert.py b/torch/_dynamo/symbolic_convert.py index e605e57b6aa49c..7e8040b09a03e4 100644 --- a/torch/_dynamo/symbolic_convert.py +++ b/torch/_dynamo/symbolic_convert.py @@ -1359,47 +1359,35 @@ def FOR_ITER(self, inst): self.push(ConstantVariable.create(None)) self.jump(inst) - def _raise_exception_variable(self, inst): - val = self.pop() - # User can raise exception in 2 ways - # 1) raise exception type - raise NotImplementedError - # 2) raise execption instance - raise NotImplemetedError("foo") - - # 1) when user raises exception type - if isinstance(val, variables.BuiltinVariable): - # Create the instance of the exception type - # https://github.com/python/cpython/blob/3.11/Python/ceval.c#L6547-L6549 - val = val.call_function(self, [], {}) # type: ignore[arg-type] - - # Save the exception in a global data structure - self.exn_vt_stack.append(val) - - # 2) when user raises exception instance - if isinstance(val, variables.ExceptionVariable): - if observed_exception_type := exc.observed_exception_map.get(val.exc_type): - raise observed_exception_type(f"raised exception {val}") - raise exc.ObservedException(f"raised exception {val}") - unimplemented(f"raise {exc}") - def RAISE_VARARGS(self, inst): if inst.arg == 0: unimplemented("re-raise") elif inst.arg == 1: - self._raise_exception_variable(inst) + val = self.pop() + # User can raise exception in 2 ways + # 1) raise exception type - raise NotImplementedError + # 2) raise execption instance - raise NotImplemetedError("foo") + + # 1) when user raises exception type + if isinstance(val, variables.BuiltinVariable): + # Create the instance of the exception type + # https://github.com/python/cpython/blob/3.11/Python/ceval.c#L6547-L6549 + val = val.call_function(self, [], {}) # type: ignore[arg-type] + + # Save the exception in a global data structure + self.exn_vt_stack.append(val) + + # 2) when user raises exception instance + if isinstance(val, variables.ExceptionVariable): + if observed_exception_type := exc.observed_exception_map.get( + val.exc_type + ): + raise observed_exception_type(f"raised exception {val}") + raise exc.ObservedException(f"raised exception {val}") + unimplemented(f"raise {exc}") else: - # Support raise .. from None ... Dynamo does not track __cause__ and other attributes of exception. So we - # ignore `from None` part. - from_vt = self.pop() - if isinstance(from_vt, ConstantVariable) and from_vt.value is None: - self._raise_exception_variable(inst) unimplemented("raise ... from ...") - def RERAISE(self, inst): - if sys.version_info >= (3, 11): - # RERAISE is currently supported in a narrow case of `raise ... from None` - self._raise_exception_variable(inst) - unimplemented("RERAISE") - def exception_handler(self, raised_exception): if sys.version_info >= (3, 11): exn_tab_entry = self.current_instruction.exn_tab_entry @@ -1413,6 +1401,10 @@ def exception_handler(self, raised_exception): # 2) if 'lasti' is true, then push the offset that the exception was raised at if exn_tab_entry.lasti: + # This is untested. Any test that tests this end-to-end + # requires supporting more bytecodes. Therefore graph + # breaking for now. + unimplemented("lasti=True while exception handling") self.push( variables.ConstantVariable(self.current_instruction.offset) ) @@ -1444,12 +1436,9 @@ def exception_handler(self, raised_exception): # https://github.com/python/cpython/blob/3.10/Python/ceval.c#L1456 self.popn(3) if len(self.block_stack) == 0: - # No handler found in this frame. Bubble the exception to the parent - # instruction translater. - self.stack.clear() - if type(self) is InstructionTranslator: - raise Unsupported("Observed exception") - raise raised_exception + unimplemented( + "exception is raised when block stack " "is empty" + ) block_stack_entry = self.block_stack.pop() if block_stack_entry.inst.opname != "SETUP_FINALLY": From 5f99397a4d6d50188e1600d3b7ecfcd7808e4988 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Thu, 29 Aug 2024 02:02:10 +0000 Subject: [PATCH 0195/1018] Revert "[dynamo][dicts] Support hasattr on dicts (#134590)" This reverts commit c566f2465f41b8081caed205fcf5fe973fd970b3. Reverted https://github.com/pytorch/pytorch/pull/134590 on behalf of https://github.com/ZainRizvi due to Sorry, I had to revert this in order to revert another PR ([comment](https://github.com/pytorch/pytorch/pull/134610#issuecomment-2316568553)) --- test/dynamo/test_functions.py | 16 ---------------- torch/_dynamo/variables/dicts.py | 9 --------- 2 files changed, 25 deletions(-) diff --git a/test/dynamo/test_functions.py b/test/dynamo/test_functions.py index f5cbfb8760c307..67ab47dfac9093 100644 --- a/test/dynamo/test_functions.py +++ b/test/dynamo/test_functions.py @@ -1683,22 +1683,6 @@ def test_dict_sorted(x): tmp = {1: "D", 10: "B", 3: "E", 0: "F"} return x + 1, sorted(tmp), sorted(tmp, reverse=True) - def test_dict_hasattr(self): - def fn(x): - if hasattr(x, "to"): - return x.to("cpu") - if hasattr(x, "items"): - return torch.cos(x["a"]) - return x - - opt_fn = torch.compile(fn, backend="eager", fullgraph=True) - - x = dict(a=torch.randn(3)) - self.assertEqual(fn(x), opt_fn(x)) - - x = torch.randn(4) - self.assertEqual(fn(x), opt_fn(x)) - @make_test def test_list_clear(a, b): tmp = [a + 1, a + 2] diff --git a/torch/_dynamo/variables/dicts.py b/torch/_dynamo/variables/dicts.py index f9b74bb2f202f4..2d4c015bf3ac1b 100644 --- a/torch/_dynamo/variables/dicts.py +++ b/torch/_dynamo/variables/dicts.py @@ -355,15 +355,6 @@ def call_method( def unpack_var_sequence(self, tx): return [x.vt for x in self.items.keys()] - def call_hasattr(self, tx, name): - # dict not allow setting arbitrary attributes. To check for hasattr, we can just check the __dict__ of the dict. - # OrderedDict though requires side effects tracking because it supports arbitrary setattr. - if self.user_cls is dict: - if name in self.user_cls.__dict__: - return ConstantVariable.create(True) - return ConstantVariable.create(False) - unimplemented(f"hasattr on {self.user_cls} is not supported") - class DefaultDictVariable(ConstDictVariable): def __init__(self, items, user_cls, default_factory=None, **kwargs) -> None: From 4f5a5683f587cf4233a96734f5d39c94ab3e5642 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Thu, 29 Aug 2024 02:02:10 +0000 Subject: [PATCH 0196/1018] Revert "[dynamo][exceptions] Use exception subclass whenever possible (#134610)" This reverts commit 880e3d18a406777dbea6aeaf14443b0e3a8b441c. Reverted https://github.com/pytorch/pytorch/pull/134610 on behalf of https://github.com/ZainRizvi due to Sorry, I had to revert this in order to revert another PR ([comment](https://github.com/pytorch/pytorch/pull/134610#issuecomment-2316568553)) --- test/dynamo/test_exceptions.py | 23 ----------------------- torch/_dynamo/symbolic_convert.py | 7 +++---- 2 files changed, 3 insertions(+), 27 deletions(-) diff --git a/test/dynamo/test_exceptions.py b/test/dynamo/test_exceptions.py index 33c6e3d0f6121b..d9b364443eb132 100644 --- a/test/dynamo/test_exceptions.py +++ b/test/dynamo/test_exceptions.py @@ -267,29 +267,6 @@ def forward(self, x): x = torch.ones(4) self.assertEqual(mod(x), opt_mod(x)) - def test_attribute_error_from_getattr(self): - class Mock: - def __init__(self): - self.a = 5 - - def __getattr__(self, name): - if name != "a": - raise AttributeError("missing") - return self.__dict__["a"] - - mock = Mock() - - def fn(x): - if hasattr(mock, "b"): - return torch.cos(x) - return torch.sin(x) - - opt_fn = torch.compile(fn, backend="eager", fullgraph=True) - x = torch.randn(4) - ref = fn(x) - res = opt_fn(x) - self.assertEqual(ref, res) - def test_stop_iteration(self): def zip_longest(*iterables, fillvalue=None): # Get the iterators for each iterable diff --git a/torch/_dynamo/symbolic_convert.py b/torch/_dynamo/symbolic_convert.py index 7e8040b09a03e4..0b7d68be08b9e5 100644 --- a/torch/_dynamo/symbolic_convert.py +++ b/torch/_dynamo/symbolic_convert.py @@ -1379,10 +1379,9 @@ def RAISE_VARARGS(self, inst): # 2) when user raises exception instance if isinstance(val, variables.ExceptionVariable): - if observed_exception_type := exc.observed_exception_map.get( - val.exc_type - ): - raise observed_exception_type(f"raised exception {val}") + if val.exc_type is StopIteration: + # StopIteration is used to find the end of iteration while tracing __next__ + raise exc.ObservedUserStopIteration(f"raised exception {val}") raise exc.ObservedException(f"raised exception {val}") unimplemented(f"raise {exc}") else: From fe345ddd8120ef1d047328e7ce8f6309050648cc Mon Sep 17 00:00:00 2001 From: Laith Sakka Date: Wed, 28 Aug 2024 11:12:42 -0700 Subject: [PATCH 0197/1018] fix flakiness in update_hint_benchmark.py (#134649) ``` compile time instruction count for iteration 1 is 10732129038 compile time instruction count for iteration 2 is 10719776783 compile time instruction count for iteration 3 is 10729546868 compile time instruction count for iteration 4 is 10737655132 compile time instruction count for iteration 5 is 10732564252 compile time instruction count for iteration 6 is 10728721234 compile time instruction count for iteration 7 is 10733354271 compile time instruction count for iteration 8 is 10719588972 compile time instruction count for iteration 9 is 10706311856 ``` 1. add torch.manual_seed(0), inputs was not the same across iterations 2. disable gc. 3. remove loop (not needed since compilation happen once only) Pull Request resolved: https://github.com/pytorch/pytorch/pull/134649 Approved by: https://github.com/aorenste ghstack dependencies: #133834, #134635 --- benchmarks/dynamo/pr_time_benchmarks/benchmark_base.py | 4 ++++ .../benchmarks/update_hint_benchmark.py | 10 ++++++---- 2 files changed, 10 insertions(+), 4 deletions(-) diff --git a/benchmarks/dynamo/pr_time_benchmarks/benchmark_base.py b/benchmarks/dynamo/pr_time_benchmarks/benchmark_base.py index 225f51c5b314d7..f15a87b07c10a9 100644 --- a/benchmarks/dynamo/pr_time_benchmarks/benchmark_base.py +++ b/benchmarks/dynamo/pr_time_benchmarks/benchmark_base.py @@ -109,6 +109,10 @@ def _count_compile_time_instructions(self): CompileTimeInstructionCounter.clear() self._work() count = CompileTimeInstructionCounter.value() + if count == 0: + raise RuntimeError( + "compile time instruction count is 0, please check your benchmarks" + ) print(f"compile time instruction count for iteration {i} is {count}") results.append(count) diff --git a/benchmarks/dynamo/pr_time_benchmarks/benchmarks/update_hint_benchmark.py b/benchmarks/dynamo/pr_time_benchmarks/benchmarks/update_hint_benchmark.py index 694bcea5985771..d7dfa0615d25b7 100644 --- a/benchmarks/dynamo/pr_time_benchmarks/benchmarks/update_hint_benchmark.py +++ b/benchmarks/dynamo/pr_time_benchmarks/benchmarks/update_hint_benchmark.py @@ -1,4 +1,4 @@ -import random +import gc import sys from benchmark_base import BenchmarkBase @@ -17,13 +17,16 @@ def description(self): def _prepare_once(self): torch._dynamo.config.capture_scalar_outputs = True - random.seed(42) + torch.manual_seed(0) + self.splits = torch.randint(10, (self.N,)) sz = self.splits.sum().item() self.input = torch.randn(sz) def _prepare(self): torch._dynamo.reset() + gc.collect() + gc.disable() def _work(self): @torch.compile(fullgraph=True) @@ -34,8 +37,7 @@ def f(a, b): torch._check(x <= self.N) return a.split(xs) - for i in range(1000): - f(self.input, self.splits) + f(self.input, self.splits) def main(): From f59ad4db37dfb5bdea8715bb61584f0f5f393e1a Mon Sep 17 00:00:00 2001 From: "Sun, Jiayi" Date: Tue, 27 Aug 2024 19:54:39 -0700 Subject: [PATCH 0198/1018] [Inductor] add inductor config: masked_vec (#134566) This PR adds inductor config: masked_vec to control enable/disable masked vectorization for the tail_loop, and enable by default. Pull Request resolved: https://github.com/pytorch/pytorch/pull/134566 Approved by: https://github.com/jgong5, https://github.com/jansel --- torch/_inductor/codegen/cpp.py | 8 ++++++-- torch/_inductor/config.py | 3 +++ 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/torch/_inductor/codegen/cpp.py b/torch/_inductor/codegen/cpp.py index 08d7f3cc4b6fc7..1791d23da690e5 100644 --- a/torch/_inductor/codegen/cpp.py +++ b/torch/_inductor/codegen/cpp.py @@ -3778,7 +3778,11 @@ def run(kernel): ) main_loop.set_kernel(vec_kernel) main_loop.simd_vec = True - if could_masked_vec and (tail_loop.size - tail_loop.offset) >= 4: + if ( + config.cpp.enable_loop_tail_vec + and could_masked_vec + and (tail_loop.size - tail_loop.offset) >= 4 + ): tail_loop.steps = tail_loop.size - tail_loop.offset masked_vec_kernel = codegen_kernel( CppVecKernel, @@ -3818,7 +3822,7 @@ def run(kernel): ) inner_main_loop.set_kernel(tile2d_kernel) - if could_masked_vec: + if config.cpp.enable_loop_tail_vec and could_masked_vec: ( inner_main_loop_of_outer_tail_loop, inner_tail_loop_of_outer_tail_loop, diff --git a/torch/_inductor/config.py b/torch/_inductor/config.py index c9757f62a95dc3..443bb7cacc0775 100644 --- a/torch/_inductor/config.py +++ b/torch/_inductor/config.py @@ -768,6 +768,9 @@ class cpp: # decomposed into 7x4x2 thread blocks along MxNxK of a GEMM. gemm_thread_factors = os.environ.get("TORCHINDUCTOR_CPP_GEMM_THREAD_FACTORS", None) + # Whether to enable masked vectorization for the tail_loop. + enable_loop_tail_vec = True + # config specific to codegen/triton.py class triton: From 4c565f1565c0994a06e1f41db9c252eb1a2ee939 Mon Sep 17 00:00:00 2001 From: Jason Ansel Date: Wed, 28 Aug 2024 08:19:46 -0700 Subject: [PATCH 0199/1018] [inductor] Cleanup generate_node_schedule (#134306) Pull Request resolved: https://github.com/pytorch/pytorch/pull/134306 Approved by: https://github.com/shunting314 --- torch/_inductor/codegen/simd.py | 45 ++++++++++----------------------- 1 file changed, 14 insertions(+), 31 deletions(-) diff --git a/torch/_inductor/codegen/simd.py b/torch/_inductor/codegen/simd.py index 7f50ad1a98f0d1..4958104b7105a2 100644 --- a/torch/_inductor/codegen/simd.py +++ b/torch/_inductor/codegen/simd.py @@ -1064,13 +1064,10 @@ def can_fuse(self, node1, node2): def generate_node_schedule(self, nodes, numel, rnumel): node_schedule: List[Any] = [] - current_loop_writes: OrderedSet[str] = OrderedSet() - + done: OrderedSet[scheduler.BaseSchedulerNode] = OrderedSet() # Writes with a reduced shape, meaning they are only present once the # reduction loop has ended - current_loop_reduced_writes: OrderedSet[str] = OrderedSet() - current_loop_has_writes = False - done: OrderedSet[scheduler.BaseSchedulerNode] = OrderedSet() + not_ready_yet_nodes: OrderedSet[str] = OrderedSet() def fits_in_main_body(n): _, (node_numel, node_rnumel) = n.group @@ -1083,10 +1080,8 @@ def fits_outside_reduction(n): return node_numel == numel and node_rnumel == 1 and rnumel != 1 def schedule_node_in_loop(n): - nonlocal current_loop_has_writes done.add(n) node_schedule.append(n) - current_loop_has_writes = True # A scan is modelled as a reduction in the scheduler but has a # full sized output that can be used inside the loop body if ( @@ -1095,45 +1090,33 @@ def schedule_node_in_loop(n): and isinstance(n.node, ir.ComputedBuffer) and not isinstance(n.node.data, ir.Scan) ): - current_loop_reduced_writes.add(n.get_name()) + not_ready_yet_nodes.add(n.get_name()) @contextlib.contextmanager def end_current_reduction_loop(): - nonlocal current_loop_has_writes - if current_loop_has_writes: - # flush out any other runnable nodes to reduce number of loops - for other_node in nodes[index + 1 :]: - if ( - node not in done - and fits_in_main_body(other_node) - and not (current_loop_reduced_writes & other_node.ancestors) - ): - schedule_node_in_loop(node) - if node_schedule and node_schedule[-1] is EnableReduction: node_schedule.pop() else: node_schedule.append(DisableReduction) yield node_schedule.append(EnableReduction) - current_loop_reduced_writes.clear() - current_loop_has_writes = False + not_ready_yet_nodes.clear() + + def requires_closing_previous_reduction(node, node_schedule): + if rnumel == 1: + return False + if not not_ready_yet_nodes & node.ancestors: + return False + assert node_schedule and not isinstance( + node_schedule[-1], (EnableReduction, DisableReduction) + ) + return bool(not_ready_yet_nodes) for index, node in enumerate(nodes): if node in done: continue done.add(node) - def requires_closing_previous_reduction(node, node_schedule): - if rnumel == 1: - return False - if not current_loop_reduced_writes & node.ancestors: - return False - assert node_schedule and not isinstance( - node_schedule[-1], (EnableReduction, DisableReduction) - ) - return bool(current_loop_reduced_writes) - if fits_in_main_body(node): if requires_closing_previous_reduction(node, node_schedule): with end_current_reduction_loop(): From 320a917573270a4b3204ad6036bbdaafb368cce0 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Thu, 29 Aug 2024 02:49:30 +0000 Subject: [PATCH 0200/1018] Revert "[dynamo] simplify implementation for `os.fspath` (#133801)" This reverts commit 74341e1150f10b8aaddd33a165e686724424071f. Reverted https://github.com/pytorch/pytorch/pull/133801 on behalf of https://github.com/ZainRizvi due to Sorry, have to revert this in order to be able to revert https://github.com/pytorch/pytorch/pull/133769 ([comment](https://github.com/pytorch/pytorch/pull/133771#issuecomment-2316611158)) --- torch/_dynamo/polyfills/__init__.py | 12 ++++++++++ torch/_dynamo/polyfills/loader.py | 1 - torch/_dynamo/polyfills/os.py | 36 ----------------------------- torch/_dynamo/variables/misc.py | 14 ++++++++++- 4 files changed, 25 insertions(+), 38 deletions(-) delete mode 100644 torch/_dynamo/polyfills/os.py diff --git a/torch/_dynamo/polyfills/__init__.py b/torch/_dynamo/polyfills/__init__.py index 9b465726754bbe..5e51b08ad40fd8 100644 --- a/torch/_dynamo/polyfills/__init__.py +++ b/torch/_dynamo/polyfills/__init__.py @@ -154,3 +154,15 @@ def instantiate_user_defined_class_object(cls, /, *args, **kwargs): if isinstance(obj, cls): obj.__init__(*args, **kwargs) return obj + + +def fspath(path): + # Python equivalent of os.fspath + if isinstance(path, (str, bytes)): + return path + elif hasattr(path, "__fspath__"): + return path.__fspath__() + else: + raise TypeError( + f"expected str, bytes or os.PathLike object, not {type(path).__name__}" + ) diff --git a/torch/_dynamo/polyfills/loader.py b/torch/_dynamo/polyfills/loader.py index 377029fbd63c86..c891490b68fb32 100644 --- a/torch/_dynamo/polyfills/loader.py +++ b/torch/_dynamo/polyfills/loader.py @@ -14,7 +14,6 @@ POLYFILLED_MODULE_NAMES: Tuple[str, ...] = ( "builtins", "itertools", - "os", ) POLYFILLED_MODULES: Tuple["ModuleType", ...] = tuple( importlib.import_module(f".{submodule}", package=polyfills.__name__) diff --git a/torch/_dynamo/polyfills/os.py b/torch/_dynamo/polyfills/os.py deleted file mode 100644 index 013cc81ac30e6f..00000000000000 --- a/torch/_dynamo/polyfills/os.py +++ /dev/null @@ -1,36 +0,0 @@ -""" -Python polyfills for os -""" - -from __future__ import annotations - -import os -from typing import AnyStr - -from ..decorators import substitute_in_graph - - -__all__ = ["fspath"] - - -# Copied from os.py in the standard library -@substitute_in_graph(os.fspath) -def fspath(path: AnyStr | os.PathLike[AnyStr]) -> AnyStr: - if isinstance(path, (str, bytes)): - return path - - path_type = type(path) - try: - path_repr = path_type.__fspath__(path) # type: ignore[arg-type] - except AttributeError: - if hasattr(path_type, "__fspath__"): - raise - raise TypeError( - f"expected str, bytes or os.PathLike object, not {path_type.__name__}", - ) from None - if isinstance(path_repr, (str, bytes)): - return path_repr # type: ignore[return-value] - raise TypeError( - f"expected {path_type.__name__}.__fspath__() to return str or bytes, " - f"not {type(path_repr).__name__}", - ) diff --git a/torch/_dynamo/variables/misc.py b/torch/_dynamo/variables/misc.py index 2979bd58e312fe..7a81bd37094a53 100644 --- a/torch/_dynamo/variables/misc.py +++ b/torch/_dynamo/variables/misc.py @@ -4,6 +4,7 @@ import functools import inspect import itertools +import os import random import re import sys @@ -14,7 +15,7 @@ import torch._numpy as tnp import torch.utils._pytree as pytree -from .. import config, variables +from .. import config, polyfills, variables from ..bytecode_transformation import create_call_function, create_instruction from ..create_parameter_op import do_not_convert_to_tracable_parameter from ..exc import unimplemented @@ -29,6 +30,7 @@ ) from ..utils import ( check_unspec_or_constant_args, + hashable, identity, is_tensor_base_attr_getter, proxy_args_kwargs, @@ -47,6 +49,10 @@ if TYPE_CHECKING: from torch._dynamo.symbolic_convert import InstructionTranslator +POLYFILL_SUPPORTED_PYTHON_MODULE_METHODS = { + os.fspath: polyfills.fspath, +} + class NO_SUCH_SUBOBJ: pass @@ -1161,6 +1167,12 @@ def var_getattr(self, tx: "InstructionTranslator", name): else: attr_value = self.value.__dict__[name] + if hashable(attr_value): + if polyfill_fn := POLYFILL_SUPPORTED_PYTHON_MODULE_METHODS.get( + attr_value, None + ): + return variables.UserFunctionVariable(polyfill_fn) + if self.source: new_source = AttrSource(self.source, name) return VariableBuilder(tx, new_source)(attr_value) From d4585ce0dc6abaef46b8e3c4775ae2fbf60c187a Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Thu, 29 Aug 2024 02:49:30 +0000 Subject: [PATCH 0201/1018] Revert "[dynamo][itertools] support `itertools.tee` (#133771)" This reverts commit 1dbd3476de07d7f07489e243cb7a43073e8c25c1. Reverted https://github.com/pytorch/pytorch/pull/133771 on behalf of https://github.com/ZainRizvi due to Sorry, have to revert this in order to be able to revert https://github.com/pytorch/pytorch/pull/133769 ([comment](https://github.com/pytorch/pytorch/pull/133771#issuecomment-2316611158)) --- test/dynamo/test_misc.py | 16 ------------- torch/_dynamo/polyfills/itertools.py | 34 ---------------------------- torch/_dynamo/polyfills/loader.py | 5 +--- 3 files changed, 1 insertion(+), 54 deletions(-) delete mode 100644 torch/_dynamo/polyfills/itertools.py diff --git a/test/dynamo/test_misc.py b/test/dynamo/test_misc.py index da258c94308837..74dcc2507f8888 100644 --- a/test/dynamo/test_misc.py +++ b/test/dynamo/test_misc.py @@ -10061,22 +10061,6 @@ def fn(l): self.assertEqual(eager, compiled) self.assertEqual(len(counters["graph_break"]), 0) - def test_itertools_tee(self): - counters.clear() - - def fn(l): - a, b = itertools.tee(l) - return list(a), list(b) - - l = [1, 2, 2, 3, 4, 4, 4, 1, 2] - eager = fn(l) - - compiled_fn = torch._dynamo.optimize(backend="eager", nopython=True)(fn) - compiled = compiled_fn(l) - - self.assertEqual(eager, compiled) - self.assertEqual(len(counters["graph_break"]), 0) - def test_list_iterator_contains(self): def fn(x): it = iter(["my_weight", "not_my_weight"]) diff --git a/torch/_dynamo/polyfills/itertools.py b/torch/_dynamo/polyfills/itertools.py deleted file mode 100644 index 7207acd6504b51..00000000000000 --- a/torch/_dynamo/polyfills/itertools.py +++ /dev/null @@ -1,34 +0,0 @@ -""" -Python polyfills for itertools -""" - -import itertools -from typing import Iterable, Iterator, Tuple, TypeVar - -from ..decorators import substitute_in_graph - - -__all__ = ["tee"] - - -_T = TypeVar("_T") - - -# Reference: https://docs.python.org/3/library/itertools.html#itertools.tee -@substitute_in_graph(itertools.tee) -def tee(iterable: Iterable[_T], n: int = 2, /) -> Tuple[Iterator[_T], ...]: - iterator = iter(iterable) - shared_link = [None, None] - - def _tee(link) -> Iterator[_T]: # type: ignore[no-untyped-def] - try: - while True: - if link[1] is None: - link[0] = next(iterator) - link[1] = [None, None] - value, link = link - yield value - except StopIteration: - return - - return tuple(_tee(shared_link) for _ in range(n)) diff --git a/torch/_dynamo/polyfills/loader.py b/torch/_dynamo/polyfills/loader.py index c891490b68fb32..0a7e4aa628ada1 100644 --- a/torch/_dynamo/polyfills/loader.py +++ b/torch/_dynamo/polyfills/loader.py @@ -11,10 +11,7 @@ from types import ModuleType -POLYFILLED_MODULE_NAMES: Tuple[str, ...] = ( - "builtins", - "itertools", -) +POLYFILLED_MODULE_NAMES: Tuple[str, ...] = ("builtins",) POLYFILLED_MODULES: Tuple["ModuleType", ...] = tuple( importlib.import_module(f".{submodule}", package=polyfills.__name__) for submodule in POLYFILLED_MODULE_NAMES From 3e834223df62dfcb28bd0426f69d21d3c441268e Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Thu, 29 Aug 2024 03:00:47 +0000 Subject: [PATCH 0202/1018] Revert "[dynamo] simplify polyfill registration for `builtins.all` and `builtins.any` (#133769)" This reverts commit cc3a76edbac4a48381db6ccc44a83927f80c545b. Reverted https://github.com/pytorch/pytorch/pull/133769 on behalf of https://github.com/ZainRizvi due to Sorry but this has been discovered to be causing a performance regression internally ([comment](https://github.com/pytorch/pytorch/pull/133769#issuecomment-2316620213)) --- torch/_dynamo/polyfills/__init__.py | 14 ++++++++++++++ torch/_dynamo/polyfills/builtins.py | 30 ----------------------------- torch/_dynamo/polyfills/loader.py | 2 +- torch/_dynamo/trace_rules.py | 2 -- torch/_dynamo/variables/builtin.py | 18 ++++++++++++++++- 5 files changed, 32 insertions(+), 34 deletions(-) delete mode 100644 torch/_dynamo/polyfills/builtins.py diff --git a/torch/_dynamo/polyfills/__init__.py b/torch/_dynamo/polyfills/__init__.py index 5e51b08ad40fd8..07ded8c4bd562f 100644 --- a/torch/_dynamo/polyfills/__init__.py +++ b/torch/_dynamo/polyfills/__init__.py @@ -13,6 +13,20 @@ import torch +def all(iterator): + for elem in iterator: + if not elem: + return False + return True + + +def any(iterator): + for elem in iterator: + if elem: + return True + return False + + def index(iterator, item, start=0, end=None): for i, elem in islice(enumerate(iterator), start, end): if item == elem: diff --git a/torch/_dynamo/polyfills/builtins.py b/torch/_dynamo/polyfills/builtins.py deleted file mode 100644 index d04ea11fe3d7f5..00000000000000 --- a/torch/_dynamo/polyfills/builtins.py +++ /dev/null @@ -1,30 +0,0 @@ -""" -Python polyfills for builtins -""" - -import builtins -from typing import Iterable - -from ..decorators import substitute_in_graph - - -__all__ = [ - "all", - "any", -] - - -@substitute_in_graph(builtins.all) -def all(iterable: Iterable[object], /) -> bool: - for elem in iterable: - if not elem: - return False - return True - - -@substitute_in_graph(builtins.any) -def any(iterable: Iterable[object], /) -> bool: - for elem in iterable: - if elem: - return True - return False diff --git a/torch/_dynamo/polyfills/loader.py b/torch/_dynamo/polyfills/loader.py index 0a7e4aa628ada1..04e30572d773cd 100644 --- a/torch/_dynamo/polyfills/loader.py +++ b/torch/_dynamo/polyfills/loader.py @@ -11,7 +11,7 @@ from types import ModuleType -POLYFILLED_MODULE_NAMES: Tuple[str, ...] = ("builtins",) +POLYFILLED_MODULE_NAMES: Tuple[str, ...] = () POLYFILLED_MODULES: Tuple["ModuleType", ...] = tuple( importlib.import_module(f".{submodule}", package=polyfills.__name__) for submodule in POLYFILLED_MODULE_NAMES diff --git a/torch/_dynamo/trace_rules.py b/torch/_dynamo/trace_rules.py index 87db00ad1f9bab..d924c383186253 100644 --- a/torch/_dynamo/trace_rules.py +++ b/torch/_dynamo/trace_rules.py @@ -2980,7 +2980,6 @@ def _disallowed_callable_ids() -> Dict[int, str]: @FunctionIdSet def _builtin_function_ids() -> Dict[int, str]: - # See also torch/_dynamo/polyfills/loader.py, which removes items in _builtin_function_ids rv = { id(v): f"builtins.{k}" for k, v in builtins.__dict__.items() @@ -3073,7 +3072,6 @@ def is_forbidden(obj) -> bool: def is_builtin_callable(obj) -> bool: - # See also torch/_dynamo/polyfills/loader.py, which removes items in _builtin_function_ids return id(obj) in _builtin_function_ids diff --git a/torch/_dynamo/variables/builtin.py b/torch/_dynamo/variables/builtin.py index 86900dee30bc56..84a2f6f05e4b1e 100644 --- a/torch/_dynamo/variables/builtin.py +++ b/torch/_dynamo/variables/builtin.py @@ -16,7 +16,7 @@ from torch import sym_float, sym_int from torch.utils._python_dispatch import is_traceable_wrapper_subclass -from .. import config, variables +from .. import config, polyfills, variables from ..exc import ( AttributeMutationError, unimplemented, @@ -94,6 +94,19 @@ } +def _polyfill_call_impl(name): + """Create a BuiltinVariable.call_{name} method that inlines through polyfill.{name}""" + + def call_fn(self, tx: "InstructionTranslator", *args, **kwargs): + return tx.inline_user_function_return( + variables.UserFunctionVariable(fn), args, kwargs + ) + + fn = getattr(polyfills, name) + call_fn.__name__ = f"call_{name}" + return call_fn + + class BuiltinVariable(VariableTracker): _SENTINEL = object() _nonvar_fields = { @@ -2111,6 +2124,9 @@ def call_contains( ): return a.call_method(tx, "__contains__", [b], {}) + call_all = _polyfill_call_impl("all") + call_any = _polyfill_call_impl("any") + @contextlib.contextmanager def dynamo_disable_grad(tx): From 6f1f2b46ca9c1be7af64ea06aeac10ea7f0d1f2a Mon Sep 17 00:00:00 2001 From: Yueming Hao Date: Thu, 29 Aug 2024 03:06:57 +0000 Subject: [PATCH 0203/1018] [inductor]Let output or input_as_strided match exact strides (#130956) Fixes #130394 TorchInductor doesn't respect original strides of outputs. It opens up optimization opportunities like changing up memory layout. But for some cases, such as the case in https://github.com/pytorch/pytorch/issues/130394, we do need the output match the exact stride as required. The correctness is the first priority goal. So, this PR adds a new API `ir.ExternKernel.require_exact_strides(x, exact_strides, allow_padding=False)` to fix the issue. This PR enables dense and non-dense outputs' strides follow the strides required by semantics. The comparison between the original and after this fix for the test is the below. ```python @triton.jit def triton_(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr): xnumel = 128 xoffset = tl.program_id(0) * XBLOCK xindex = xoffset + tl.arange(0, XBLOCK)[:] xmask = xindex < xnumel x0 = xindex % 8 x1 = (xindex // 8) - x2 = xindex tmp0 = tl.load(in_ptr0 + (x0 + (16*x1)), xmask) tmp1 = tmp0 + tmp0 - tl.store(out_ptr0 + (x2), tmp1, xmask) + tl.store(out_ptr0 + (x0 + (16*x1)), tmp1, xmask) def call(args): arg0_1, = args args.clear() assert_size_stride(arg0_1, (16, 8), (16, 1)) with torch.cuda._DeviceGuard(0): torch.cuda.set_device(0) - buf1 = empty_strided_cuda((16, 8), (8, 1), torch.float32) + buf1 = empty_strided_cuda((16, 8), (16, 1), torch.float32) stream0 = get_raw_stream(0) triton_poi_fused_add_copy_0.run(arg0_1, buf1, 128, grid=grid(128), stream=stream0) del arg0_1 return (buf1, ) ``` The buf1 is created with exact stride required by users, and its values are written in same stride with the input. Pull Request resolved: https://github.com/pytorch/pytorch/pull/130956 Approved by: https://github.com/eellison, https://github.com/blaine-rister, https://github.com/desertfire --- test/inductor/test_torchinductor.py | 18 +++ torch/_inductor/codegen/wrapper.py | 2 + torch/_inductor/graph.py | 62 ++++++++--- torch/_inductor/ir.py | 164 ++++++++++++++++++++++------ 4 files changed, 196 insertions(+), 50 deletions(-) diff --git a/test/inductor/test_torchinductor.py b/test/inductor/test_torchinductor.py index 2019d378063404..b354ff2b1ac903 100644 --- a/test/inductor/test_torchinductor.py +++ b/test/inductor/test_torchinductor.py @@ -7257,6 +7257,24 @@ def fn_channels_last(x): [torch.randn(8, 384, 20, 20).to(memory_format=torch.channels_last)], ) + def test_exact_stride(self): + full = torch.randn((16, 16), device=self.device) + view = torch.as_strided(full, (16, 8), full.stride()) + + def fn(x): + result = x + x + result_strided = torch.empty_strided( + x.size(), x.stride(), device=self.device + ) + result_strided[:] = result + return result_strided + + self.common(fn, [view]) + reference_out = fn(view) + compiled_fn = torch.compile(fn) + actual_out = compiled_fn(view) + self.assertEqual(reference_out.stride(), actual_out.stride()) + def test_like_channels_last(self): def foo(): randn = torch.randn((4, 3, 8, 8), device=self.device, dtype=torch.float32) diff --git a/torch/_inductor/codegen/wrapper.py b/torch/_inductor/codegen/wrapper.py index 7626c39aec20b4..9ecc942e5d4031 100644 --- a/torch/_inductor/codegen/wrapper.py +++ b/torch/_inductor/codegen/wrapper.py @@ -2014,6 +2014,8 @@ def statically_known_int_or_none(x): # _maybe_evaluate_static will return (s0 // (2 // s0)) as 2, but # the actual codegen will still generate the full expression here. return None + if isinstance(x, int): + return x val = V.graph._shape_env._maybe_evaluate_static(x) return int(val) except Exception: diff --git a/torch/_inductor/graph.py b/torch/_inductor/graph.py index 47bcb068486532..987e639657cb0c 100644 --- a/torch/_inductor/graph.py +++ b/torch/_inductor/graph.py @@ -186,7 +186,9 @@ def getattr_recursive( return attr_itr -def mark_nodes_dislike_padding(g: Graph) -> None: +def mark_nodes_dislike_padding( + g: Graph, user_visible_outputs: Optional[Dict[str, None]] +) -> None: """ Nodes like convolution/convolution_backward want its input to be dense. If we pad their inputs, we result in extra calls to copy kernels! On the other hand, padding usually helps reduction. @@ -244,6 +246,9 @@ def _get_overload_packet( continue if prior_op not in ops_like_padding: prior.meta["dislike_padding"] = True + # We only want to mark output nodes. So, move it after the above prior nodes process. + if user_visible_outputs and cur.name in user_visible_outputs: + cur.meta["dislike_padding"] = True class GraphLowering(torch.fx.Interpreter): @@ -415,11 +420,11 @@ def __init__( self.nodes_prefer_channels_last = ( self.find_nodes_prefer_channels_last() if self.layout_opt else OrderedSet() ) - mark_nodes_dislike_padding(gm.graph) self._warned_fallback = {"aten.convolution_backward"} self.user_visible_outputs = ( user_visible_outputs if user_visible_outputs is not None else {} ) + mark_nodes_dislike_padding(gm.graph, user_visible_outputs) self.cache_key: str = "" # This is the cache key for the compiled artifact self.cache_path: str = "" # This is the path in the filesystem where the compiled artifact is stored self.cache_linemap: List[ @@ -1352,27 +1357,48 @@ def debug(msg: str) -> None: n.meta["val"], torch.Tensor ): strides = n.meta["val"].stride() - dense = torch._prims_common.is_non_overlapping_and_dense(n.meta["val"]) - unbacked_symbols_in_strides = len(free_unbacked_symbols(strides)) > 0 - # requiring a stride order for a non-dense output wouldn't - # recreate the same strides, and would fail with view, defer for now. - if not unbacked_symbols_in_strides and dense and len(strides): - stride_order = ir.get_stride_order(strides) - if ( - len(result.get_size()) == 4 - and n in self.nodes_prefer_channels_last - and n.name not in self.user_visible_outputs - and not is_input_for_as_strided - ): - stride_order = ir.NHWC_STRIDE_ORDER - + if len(strides): allow_padding = ( n.name not in self.user_visible_outputs and not is_input_for_as_strided ) - result = ir.ExternKernel.require_stride_order( - result, stride_order, allow_padding=allow_padding + dense = torch._prims_common.is_non_overlapping_and_dense( + n.meta["val"] + ) + unbacked_symbols_in_strides = ( + len(free_unbacked_symbols(strides)) > 0 ) + if ( + not unbacked_symbols_in_strides + and dense + and len(result.get_size()) == 4 + and n in self.nodes_prefer_channels_last + and n.name not in self.user_visible_outputs + and not is_input_for_as_strided + ): + strides = ir.FlexibleLayout.stride_ordered_for_memory_format( + result.get_size(), torch.channels_last + ) + if not unbacked_symbols_in_strides and len(strides): + # To avoid converting possible view ops to a copy kernel, we use the previous + # require_exact_strides to handle views. But ultimately it's better to require + # the right strides at the tensor definition. + if n.meta["val"]._is_view() or isinstance( + result.data, ir.BaseView + ): + result = ir.ExternKernel.require_stride_order( + result, + ir.get_stride_order(strides), + allow_padding=allow_padding, + ) + else: + strides = [ + s.node.expr if isinstance(s, torch.SymInt) else s + for s in strides + ] + result = ir.ExternKernel.require_exact_strides( + result, strides, allow_padding=allow_padding + ) # Realize if (1) any user need inputs realized, or (2) there is # already too many reads and rematerializing can be bad. diff --git a/torch/_inductor/ir.py b/torch/_inductor/ir.py index ec33f94e70d95a..0f4cc5d1b61f9d 100644 --- a/torch/_inductor/ir.py +++ b/torch/_inductor/ir.py @@ -745,6 +745,22 @@ def welford_combine_fn( raise NotImplementedError(f"unknown reduction_type={reduction_type}") +def significant_strides_equal( + strides1: Sequence[_IntLike], strides2: Sequence[_IntLike], size: Sequence[_IntLike] +) -> bool: + """ + Returns true if the strides are equal, ignoring dimensions of size 1 . + """ + non_1_indices = [ + i + for i, dim in enumerate(size) + if V.graph.sizevars.size_hint(dim, fallback=2) != 1 + ] + strides1 = [V.graph.sizevars.size_hint(strides1[i]) for i in non_1_indices] + strides2 = [V.graph.sizevars.size_hint(strides2[i]) for i in non_1_indices] + return strides1 == strides2 + + @dataclasses.dataclass class Reduction(Loops): reduction_ranges: List[Expr] @@ -2085,6 +2101,7 @@ def as_storage_and_layout( want_contiguous: bool = False, stride_order: Optional[Sequence[Union[int, Integer]]] = None, allow_padding: bool = False, + exact_strides: Optional[Sequence[Union[int, Integer]]] = None, ) -> Tuple[StorageBox, Layout]: """ Try to simplify x into a StorageBox and a Layout. @@ -2099,6 +2116,7 @@ def as_storage_and_layout( want_contiguous=want_contiguous, stride_order=stride_order, allow_padding=allow_padding, + exact_strides=exact_strides, ) if isinstance(x, StorageBox) and isinstance(x.data, Buffer): if freeze: @@ -2109,6 +2127,10 @@ def as_storage_and_layout( x.data.freeze_layout_with_stride_order( stride_order, allow_padding=allow_padding ) + elif exact_strides is not None: + x.data.freeze_layout_with_exact_strides( + exact_strides, allow_padding=allow_padding + ) else: x.data.decide_layout() return x, x.data.layout @@ -3202,6 +3224,19 @@ def as_stride_order(self, order, allow_padding=False): self.offset, ) + def as_exact_strides(self, exact_strides, allow_padding=False): + new_stride = exact_strides + if self.should_pad_strides() and allow_padding: + new_stride = self._pad_strides(new_stride, self.size, self.dtype) + + return FixedLayout( + self.device, + self.dtype, + self.size, + new_stride, + self.offset, + ) + def as_fill_order(self, order): new_stride = self.fill_ordered(self.size, order) if self.should_pad_strides(): @@ -3428,6 +3463,12 @@ def freeze_layout_with_same_order(self, stride): assert isinstance(self.layout, FlexibleLayout) self.layout = self.layout.as_same_order(stride) + def freeze_layout_with_exact_strides(self, exact_strides, allow_padding=False): + assert isinstance(self.layout, FlexibleLayout) + self.layout = self.layout.as_exact_strides( + exact_strides, allow_padding=allow_padding + ) + def is_zero_elements(self): return V.graph.sizevars.is_expr_static_and_true(sympy.Eq(self.get_numel(), 0)) # type: ignore[arg-type] @@ -4680,53 +4721,93 @@ def require_stride1(cls, x): return cls.copy_input(x) @classmethod - def require_stride_order(cls, x, order, allow_padding=False): + def require_strides( + cls, + x, + order: Optional[Sequence[int]] = None, + exact_strides: Optional[Sequence[_IntLike]] = None, + allow_padding=False, + ): + assert order is not None or exact_strides is not None if x.get_numel() == 0: # Layout doesn't matter return x - - # require x to have the layout as strided_ordered as order + # require x to have the layout if is_storage_and_layout(x): while isinstance(x.get_layout(), NonOwningLayout): x = x.get_layout().view if isinstance(x.get_layout(), FlexibleLayout): - # If the the FlexibleLayout already has the size and stride in the required order, - # freeze it to a FixedLayout by using its current size and stride. - # The behavior of using its current size and stride or the given order can be different - # if the size and stride has ambiguilty, for example for a 4D input where the iC = 1: - # size=[s0, 1, 28, 28], stride=[784, 784, 28, 1]. If the required order is [3, 0, 2, 1] (channels last), - # the current size and stride already satisfies this order. - # However by freezing it to the required order, the layout will be changed to: - # size=[s0, 1, 28, 28], stride=[784, 1, 28, 1]), which is not actually necessary. - - # fix flexiblelayout to be FixedLayout with stride_order - as_storage_and_layout( - x, - freeze=True, - want_contiguous=False, - stride_order=get_stride_order( - V.graph.sizevars.size_hints(x.get_layout().stride) + if order: + # If the the FlexibleLayout already has the size and stride in the required order, + # freeze it to a FixedLayout by using its current size and stride. + # The behavior of using its current size and stride or the given order can be different + # if the size and stride has ambiguilty, for example for a 4D input where the iC = 1: + # size=[s0, 1, 28, 28], stride=[784, 784, 28, 1]. If the required order is [3, 0, 2, 1] (channels last), + # the current size and stride already satisfies this order. + # However by freezing it to the required order, the layout will be changed to: + # size=[s0, 1, 28, 28], stride=[784, 1, 28, 1]), which is not actually necessary. + + # fix flexiblelayout to be FixedLayout with stride_order + as_storage_and_layout( + x, + freeze=True, + want_contiguous=False, + stride_order=get_stride_order( + V.graph.sizevars.size_hints(x.get_layout().stride) + ) + if is_stride_order_storage_and_layout(x, order) + else order, + allow_padding=allow_padding, + ) + return x + else: + # If the exact_strides is given, freeze the FlexibleLayout to a FixedLayout with the exact_strides. + as_storage_and_layout( + x, + freeze=True, + want_contiguous=False, + stride_order=None, + allow_padding=allow_padding, + exact_strides=exact_strides, + ) + return x + elif isinstance(x.get_layout(), FixedLayout) and ( + (order and x.get_layout().is_stride_ordered(order)) + or ( + exact_strides + and significant_strides_equal( + exact_strides, x.get_layout().stride, x.get_size() ) - if is_stride_order_storage_and_layout(x, order) - else order, - allow_padding=allow_padding, ) - return x - elif isinstance( - x.get_layout(), FixedLayout - ) and x.get_layout().is_stride_ordered(order): + ): return x elif isinstance(x.get_layout(), MutationLayoutSHOULDREMOVE): if isinstance(x.get_layout().real_layout(), FlexibleLayout): raise AssertionError( "the MutationLayoutSHOULDREMOVE's real layout shouldn't be FlexibleLayout" ) - elif isinstance( - x.get_layout().real_layout(), FixedLayout - ) and x.get_layout().real_layout().is_stride_ordered(order): + elif isinstance(x.get_layout().real_layout(), FixedLayout) and ( + (order and x.get_layout().real_layout().is_stride_ordered(order)) + or ( + exact_strides + and significant_strides_equal( + exact_strides, + x.get_layout().real_layout().stride, + x.get_size(), + ) + ) + ): return x # TODO - Storage to InputBuffer - if isinstance(x, InputBuffer) and x.get_layout().is_stride_ordered(order): + if isinstance(x, InputBuffer) and ( + (order and x.get_layout().is_stride_ordered(order)) + or ( + exact_strides + and significant_strides_equal( + exact_strides, x.get_layout().stride, x.get_size() + ) + ) + ): return x if ( isinstance(x, TensorBox) @@ -4737,7 +4818,14 @@ def require_stride_order(cls, x, order, allow_padding=False): ): try: x.data = cls.convert_to_reinterpret_view(x.data) - return cls.require_stride_order(x, order, allow_padding=allow_padding) + if order: + return cls.require_stride_order( + x, order, allow_padding=allow_padding + ) + elif exact_strides: + return cls.require_exact_strides( + x, exact_strides, allow_padding=allow_padding + ) except NotImplementedError: pass # Although this is a clone, inductor is good about fusing clones into previous @@ -4749,10 +4837,22 @@ def require_stride_order(cls, x, order, allow_padding=False): want_contiguous=False, stride_order=order, allow_padding=allow_padding, + exact_strides=exact_strides, ) - assert is_stride_order_storage_and_layout(x, order) + if order: + assert is_stride_order_storage_and_layout(x, order) return x + @classmethod + def require_exact_strides(cls, x, exact_strides, allow_padding=False): + return cls.require_strides( + x, exact_strides=exact_strides, allow_padding=allow_padding + ) + + @classmethod + def require_stride_order(cls, x, order, allow_padding=False): + return cls.require_strides(x, order=order, allow_padding=allow_padding) + @classmethod def require_channels_last(cls, x): return cls.require_stride_order(x, NHWC_STRIDE_ORDER) From 0dc969f6015fa25ce228ca8db5a04e8e82581a8c Mon Sep 17 00:00:00 2001 From: Ivan Zaitsev Date: Thu, 29 Aug 2024 03:11:13 +0000 Subject: [PATCH 0204/1018] Reflect check_labels status as a signal (#134711) Fixes the workflow when meta-exported diff (co-dev) doesn't have the required labels, but the signal is suppressed due to job failure (e.g. [see this run](https://github.com/pytorch/pytorch/actions/runs/10590994706/job/29347663526?pr=134484)). With this change the workflow status correctly reflects the status of the check. # Testing * [illegal pr_num](https://github.com/pytorch/pytorch/actions/runs/10603163898/job/29386843591) * [successful run](https://github.com/pytorch/pytorch/actions/runs/10603279052/job/29387230110) (topic label present) * no labels: [check fails](https://github.com/pytorch/pytorch/actions/runs/10603310368/job/29387333864) Pull Request resolved: https://github.com/pytorch/pytorch/pull/134711 Approved by: https://github.com/clee2000 --- .github/scripts/check_labels.py | 11 ++++++++++- .github/scripts/test_check_labels.py | 1 + .github/workflows/check-labels.yml | 8 ++++++-- 3 files changed, 17 insertions(+), 3 deletions(-) diff --git a/.github/scripts/check_labels.py b/.github/scripts/check_labels.py index a29ef24a414c54..10be42c3fd5647 100755 --- a/.github/scripts/check_labels.py +++ b/.github/scripts/check_labels.py @@ -27,6 +27,12 @@ def parse_args() -> Any: parser = ArgumentParser("Check PR labels") parser.add_argument("pr_num", type=int) + # add a flag to return a non-zero exit code if the PR does not have the required labels + parser.add_argument( + "--exit-non-zero", + action="store_true", + help="Return a non-zero exit code if the PR does not have the required labels", + ) return parser.parse_args() @@ -41,10 +47,13 @@ def main() -> None: if not has_required_labels(pr): print(LABEL_ERR_MSG) add_label_err_comment(pr) + if args.exit_non_zero: + sys.exit(1) else: delete_all_label_err_comments(pr) except Exception as e: - pass + if args.exit_non_zero: + sys.exit(1) sys.exit(0) diff --git a/.github/scripts/test_check_labels.py b/.github/scripts/test_check_labels.py index 2b2cd7b6c5204b..1c921f2eafa9a3 100644 --- a/.github/scripts/test_check_labels.py +++ b/.github/scripts/test_check_labels.py @@ -18,6 +18,7 @@ def mock_parse_args() -> object: class Object: def __init__(self) -> None: self.pr_num = 76123 + self.exit_non_zero = False return Object() diff --git a/.github/workflows/check-labels.yml b/.github/workflows/check-labels.yml index d638d588504f2e..8ad611bd7cc917 100644 --- a/.github/workflows/check-labels.yml +++ b/.github/workflows/check-labels.yml @@ -19,6 +19,10 @@ on: branches: [gh/**/base] workflow_dispatch: + inputs: + pr_number: + description: 'PR number to check labels for' + required: true concurrency: group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.sha }}-${{ github.event_name == 'workflow_dispatch' }} @@ -54,7 +58,7 @@ jobs: - name: Check labels env: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} - PR_NUM: ${{ github.event.number }} + PR_NUM: ${{ github.event.number || github.event.inputs.pr_number }} run: | set -ex - python3 .github/scripts/check_labels.py "${PR_NUM}" + python3 .github/scripts/check_labels.py --exit-non-zero "${PR_NUM}" From d94e09c34b974b6cd9ae3295ca7271a8ca0cdab5 Mon Sep 17 00:00:00 2001 From: CaoE Date: Tue, 27 Aug 2024 20:15:41 -0700 Subject: [PATCH 0205/1018] Add oneDNN support for Half LSTM on CPU (#132607) Pull Request resolved: https://github.com/pytorch/pytorch/pull/132607 Approved by: https://github.com/jgong5, https://github.com/peterbell10 --- aten/src/ATen/native/RNN.cpp | 8 +++++- test/test_mkldnn.py | 47 +++++++++++++++++++++++++++--------- 2 files changed, 42 insertions(+), 13 deletions(-) diff --git a/aten/src/ATen/native/RNN.cpp b/aten/src/ATen/native/RNN.cpp index d1bb95f833231f..9db7b4cb7da090 100644 --- a/aten/src/ATen/native/RNN.cpp +++ b/aten/src/ATen/native/RNN.cpp @@ -15,6 +15,9 @@ #include #include #include +#if AT_MKLDNN_ENABLED() +#include +#endif #ifndef AT_PER_OPERATOR_HEADERS #include @@ -97,7 +100,10 @@ bool use_mkldnn(const Tensor& input, TensorList params, TensorList hx) { }; return input.options().backend() == at::Backend::CPU && is_cpu_backend(params) && is_cpu_backend(hx) && - (input.scalar_type() == kFloat || input.scalar_type() == kBFloat16) && + (input.scalar_type() == kFloat || + (input.scalar_type() == kBFloat16 && mkldnn_bf16_device_check()) || + (input.scalar_type() == kHalf && !at::GradMode::is_enabled() && + mkldnn_fp16_device_check())) && input.numel() != 0; #endif return false; diff --git a/test/test_mkldnn.py b/test/test_mkldnn.py index b61e50fb4f6227..5f192d7c349dfe 100644 --- a/test/test_mkldnn.py +++ b/test/test_mkldnn.py @@ -1469,24 +1469,30 @@ def _lstm_params_list(self): params_list = list(params_dict.values()) return params_list - def _cast_dtype(self, input, bf16): - if bf16: + def _cast_dtype(self, input, dtype): + if dtype == torch.bfloat16: input = input.to(torch.bfloat16) + elif dtype == torch.half: + input = input.to(torch.half) return input - @unittest.skipIf(IS_WINDOWS, "Limit support for bf16 path") def test_lstm(self): seed = 2023 torch.manual_seed(seed) params_list = self._lstm_params_list() for dtype in types: - bf16 = True if dtype == torch.bfloat16 and torch.ops.mkldnn._is_mkldnn_bf16_supported() else False + bf16 = dtype == torch.bfloat16 + fp16 = dtype == torch.half rtol = 1.3e-6 atol = 1e-5 + if bf16: rtol = 0.02 atol = 0.02 + if fp16: + rtol = 1e-3 + atol = 1e-3 for input_size, hidden_size, num_layers, bidirectional, bias, batch_first, dropout, batch_size, seq_len, training \ in itertools.product(*params_list): num_directions = 2 if bidirectional else 1 @@ -1496,7 +1502,9 @@ def test_lstm(self): input = torch.randn(seq_len, batch_size, input_size, dtype=torch.float32) h = torch.randn(num_layers * num_directions, batch_size, hidden_size, dtype=torch.float32) c = torch.randn(num_layers * num_directions, batch_size, hidden_size, dtype=torch.float32) - + if fp16: + # TODO add traing support when oneDNN support lstm FP16 training + training = False model = torch.nn.LSTM(input_size, hidden_size, num_layers, bidirectional=bidirectional, bias=bias, dropout=dropout, batch_first=batch_first).float() model.train() if training else model.eval() @@ -1510,15 +1518,25 @@ def test_lstm(self): model1 = copy.deepcopy(model) model2 = copy.deepcopy(model) - with torch.cpu.amp.autocast(enabled=bf16, dtype=torch.bfloat16), torch.no_grad() if not training else nullcontext(): + with torch.no_grad() if not training else nullcontext(): with torch.backends.mkldnn.flags(enabled=False): torch.manual_seed(seed) - output1, (hn1, cn1) = self._cast_dtype(model1, bf16)(self._cast_dtype(input1, bf16), - (self._cast_dtype(h1, bf16), - self._cast_dtype(c1, bf16))) + output1, (hn1, cn1) = self._cast_dtype(model1, dtype)( + self._cast_dtype(input1, dtype), + ( + self._cast_dtype(h1, dtype), + self._cast_dtype(c1, dtype), + ), + ) torch.manual_seed(seed) - output2, (hn2, cn2) = model2(input2, (h2, c2)) + output2, (hn2, cn2) = self._cast_dtype(model2, dtype)( + self._cast_dtype(input2, dtype), + ( + self._cast_dtype(h2, dtype), + self._cast_dtype(c2, dtype), + ), + ) self.assertEqual(output1, output2, rtol=rtol, atol=atol) self.assertEqual(hn1, hn2, rtol=rtol, atol=atol) self.assertEqual(cn1, cn2, rtol=rtol, atol=atol) @@ -1533,8 +1551,13 @@ def test_lstm(self): self.assertEqual(input1.grad, input2.grad, rtol=rtol, atol=atol) for name, para in model1.named_parameters(): - self.assertEqual(para, self._cast_dtype(getattr(model2, name), bf16)) - self.assertEqual(para.grad, self._cast_dtype(getattr(model2, name).grad, bf16), rtol=rtol, atol=atol) + self.assertEqual(para, getattr(model2, name)) + self.assertEqual( + para.grad, + getattr(model2, name).grad, + rtol=rtol, + atol=atol, + ) with torch.backends.mkldnn.flags(enabled=False): torch.manual_seed(seed) From 89c92493f08f33f20cd911f7382991d25c69b73f Mon Sep 17 00:00:00 2001 From: David Berard Date: Mon, 26 Aug 2024 11:29:22 -0700 Subject: [PATCH 0206/1018] [NJT] Support NJT SDPA + meta-device flop counting (#134289) A user wants to use the flop counter with meta devices. This previously caused problems for SDPA+NJT: 1. autocast check: `torch.is_autocast_enabled("meta")` fails because `meta` is not valid for autocasting. If we skip this, we run into the next error 2. math backend: conversion to NST requires getting concrete offsets in a list of python integers, which doesn't work on a meta tensor https://github.com/pytorch/pytorch/blob/b2eb0e8c6a817f304277dfed39c25853ff301d90/torch/nested/_internal/sdpa.py#L809-L815 3. (fixed in the previous PR, #134288) - if we force using flash attention backend for flop counting, `_flash_attention_forward` previously didn't support meta tensors. In this PR, we check specifically for FlopCounterMode, and, if it's enabled and combined with meta tensors, (a) skip autocasting and (b) force it down the flash attention path. This isn't generally safe for tracing (e.g. if you actually care which kernels you are running), but in the absence of actual device information, we have to make some assumptions. By specifically checking for FlopCounterMode, this should reduce the chance of unintended side effects for other meta tensor users. Note: fake tensor would solve a bunch of these issues, but it's not a viable solution right now for the user. Pull Request resolved: https://github.com/pytorch/pytorch/pull/134289 Approved by: https://github.com/soulitzer ghstack dependencies: #134288 --- test/test_nestedtensor.py | 32 ++++++++++++++++++++++++++++++++ torch/nested/_internal/sdpa.py | 24 +++++++++++++++++++++++- 2 files changed, 55 insertions(+), 1 deletion(-) diff --git a/test/test_nestedtensor.py b/test/test_nestedtensor.py index bd4010d81083a1..1c8f23e8c41fba 100644 --- a/test/test_nestedtensor.py +++ b/test/test_nestedtensor.py @@ -6680,6 +6680,38 @@ def get_values(): self.assertEqual(v16_dense_eager.grad, v16_nt_eager.grad) self.assertEqual(v16_dense_eager.grad, v16_nt_compile.grad) + @unittest.skipIf( + not PLATFORM_SUPPORTS_FUSED_ATTENTION, + "Platform doesn't support flash or mem-efficient attention", + ) + @skipCUDAIf(not SM70OrLater, "GPU capability is < SM70") + @skipCUDAIfRocm + @onlyCUDA + @skipIfTorchDynamo() + def test_sdpa_flop_counter(self, device): + from torch.utils.flop_counter import FlopCounterMode + + def get_flops(nt): + flop_counter = FlopCounterMode(display=False) + with flop_counter: + ret = torch.nn.functional.scaled_dot_product_attention(nt, nt, nt) + ret.values().sum().backward() + return flop_counter.get_total_flops() + + values = torch.randn( + (8 * 16, 4, 16), requires_grad=True, device=device, dtype=torch.float16 + ) + offsets = torch.arange(0, 8 * 16 + 1, 16, device=device, dtype=torch.int32) + nt = convert_jagged_to_nested_tensor(values, offsets, max_length=16) + + values_meta = torch.randn( + (8 * 16, 4, 16), requires_grad=True, device="meta", dtype=torch.float16 + ) + offsets_meta = torch.arange(0, 8 * 16 + 1, 16, device="meta", dtype=torch.int32) + nt_meta = convert_jagged_to_nested_tensor(values, offsets, max_length=16) + + self.assertEqual(get_flops(nt), get_flops(nt_meta)) + @skipIfTorchDynamo() def test_nested_tensor_activation_checkpoint(self, device): values = torch.randn( diff --git a/torch/nested/_internal/sdpa.py b/torch/nested/_internal/sdpa.py index 73d078d04c8a17..578904af946971 100644 --- a/torch/nested/_internal/sdpa.py +++ b/torch/nested/_internal/sdpa.py @@ -615,6 +615,20 @@ def _post_process_flash_output(out: torch.Tensor, og_size): return out +def _is_computing_meta_flops(x): + # Note: there's a use case of using meta tensors & the dispatch-based flop counter. + # We can use this function to check for this scenario in order to handle it specially. + if not torch.jit.is_scripting() and x.device.type == "meta": + torch_dispatch_mode_stack = ( + torch.utils._python_dispatch._get_current_dispatch_mode_stack() + ) + return any( + type(x) == torch.utils.flop_counter.FlopCounterMode + for x in torch_dispatch_mode_stack + ) + return False + + def _autocast( query: torch.Tensor, key: torch.Tensor, @@ -642,7 +656,8 @@ def _autocast( actual dtype conversions. """ device_type = query.device.type - if not torch.is_autocast_enabled(device_type): + # meta device is not supported by autocast, so break early for it + if _is_computing_meta_flops(query) or not torch.is_autocast_enabled(device_type): return query, key, value, attn_mask def cvt(x): @@ -703,6 +718,13 @@ def jagged_scaled_dot_product_attention( query, key, value, attn_mask, dropout_p, is_causal, enable_gqa ) + if _is_computing_meta_flops(query): + # Backend choice will probably not be correct if we have a meta device, + # because backend choice is device-aware. In this case, we mostly just + # want to avoid using math backend (which does a .item() call). + # Arbitrarily choose flash attention. + backend_choice = SDPBackend.FLASH_ATTENTION + if backend_choice == SDPBackend.FLASH_ATTENTION: og_size = query.size(-1) query_padded = _pad_last_dim(query, 8, False) From 64f5aba20ba4386b3b5115891075848ffa818f83 Mon Sep 17 00:00:00 2001 From: Syed Tousif Ahmed Date: Thu, 29 Aug 2024 03:56:29 +0000 Subject: [PATCH 0207/1018] Uses MemPoolContext to route allocations from CUDACachingAllocator (#134685) Re-open of https://github.com/pytorch/pytorch/pull/133599 that was mistakenly closed by issuing `ghstack land` Pull Request resolved: https://github.com/pytorch/pytorch/pull/134685 Approved by: https://github.com/ezyang --- c10/cuda/CUDACachingAllocator.cpp | 7 +++++- docs/source/cuda.rst | 3 +++ test/test_cuda.py | 33 ++++++++++++++++++++++++-- torch/_C/__init__.pyi.in | 1 + torch/csrc/cuda/Module.cpp | 7 ++++++ torch/cuda/__init__.py | 1 + torch/cuda/memory.py | 39 ++++++++++++++++++++++++++++++- 7 files changed, 87 insertions(+), 4 deletions(-) diff --git a/c10/cuda/CUDACachingAllocator.cpp b/c10/cuda/CUDACachingAllocator.cpp index 4a3a2c7545c54c..bfab4507b62b2c 100644 --- a/c10/cuda/CUDACachingAllocator.cpp +++ b/c10/cuda/CUDACachingAllocator.cpp @@ -2674,7 +2674,12 @@ class DeviceCachingAllocator { // any potential exceptions in the cudaMallocMaybeCapturing function. auto sg = c10::make_scope_exit([&]() { lock.lock(); }); lock.unlock(); - p.err = cudaMallocMaybeCapturing(&ptr, size); + } + auto active_pool = MemPoolContext::getActiveMemPool(); + if (active_pool && active_pool->allocator() && + p.pool->owner_PrivatePool) { + ptr = active_pool->allocator()->raw_alloc(size); + p.err = ptr ? cudaSuccess : cudaErrorMemoryAllocation; } else { p.err = cudaMallocMaybeCapturing(&ptr, size); } diff --git a/docs/source/cuda.rst b/docs/source/cuda.rst index 328c3cecead6e1..5bdc4e81d352cb 100644 --- a/docs/source/cuda.rst +++ b/docs/source/cuda.rst @@ -122,6 +122,9 @@ Memory management change_current_allocator MemPool MemPoolContext + +.. autoclass:: torch.cuda.use_mem_pool + .. FIXME The following doesn't seem to exist. Is it supposed to? https://github.com/pytorch/pytorch/issues/27785 .. autofunction:: reset_max_memory_reserved diff --git a/test/test_cuda.py b/test/test_cuda.py index bb52e8c9360ffc..b2f1ef3b53cf25 100644 --- a/test/test_cuda.py +++ b/test/test_cuda.py @@ -2,6 +2,7 @@ import collections import contextlib +import ctypes import gc import json import os @@ -4806,10 +4807,25 @@ def test_mempool_with_allocator(self): dummy_allocator_source = """ #include + #include + #include + extern "C" { + C10_EXPORT int called_dummy_alloc = 0; + C10_EXPORT int called_dummy_free = 0; + // Note that windows needs __declspec(dllexport): https://stackoverflow.com/a/24575865 - C10_EXPORT void* dummy_alloc(size_t size, int device, void* stream) { return nullptr; } - C10_EXPORT void dummy_free(void* ptr) { } + C10_EXPORT void* dummy_alloc(size_t size, int device, void* stream) { + called_dummy_alloc = 123; + void* ptr; + C10_CUDA_CHECK(cudaMallocManaged(&ptr, size)); + return ptr; + } + + C10_EXPORT void dummy_free(void* ptr, size_t size, int device, void* stream) { + called_dummy_free = 321; + C10_CUDA_CHECK(cudaFree(ptr)); + } } """ dummy_allocator_libname = "dummy_allocator" @@ -4819,6 +4835,7 @@ def test_mempool_with_allocator(self): is_python_module=False, keep_intermediates=False, verbose=True, + with_cuda=True, ) allocator = torch.cuda.memory.CUDAPluggableAllocator( dummy_allocator, @@ -4830,6 +4847,18 @@ def test_mempool_with_allocator(self): # pool should point to the same allocator as the one passed into it self.assertEqual(allocator.allocator(), pool.allocator) + # no allocations happened yet, so called_dummy_alloc should be 0 + alloc_lib = ctypes.CDLL(dummy_allocator) + called_dummy_alloc = ctypes.c_int.in_dll(alloc_lib, "called_dummy_alloc") + self.assertEqual(called_dummy_alloc.value, 0) + + with torch.cuda.use_mem_pool(pool): + out = torch.randn(1, device="cuda") + + # called_dummy_alloc should be 123 if dummy_alloc was used to allocate + # out tensor + self.assertEqual(called_dummy_alloc.value, 123) + def test_mempool_context(self): active_pool = torch.cuda.MemPoolContext.active_pool() diff --git a/torch/_C/__init__.pyi.in b/torch/_C/__init__.pyi.in index 1638d3891e4084..23f3bae774fec6 100644 --- a/torch/_C/__init__.pyi.in +++ b/torch/_C/__init__.pyi.in @@ -1831,6 +1831,7 @@ def _cuda_cudaHostAllocator() -> _int: ... def _cuda_cudaCachingAllocator_raw_alloc(size: _int, cuda_stream: _int) -> _int: ... def _cuda_cudaCachingAllocator_raw_delete(ptr: _int) -> None: ... def _cuda_cudaCachingAllocator_set_allocator_settings(env: str) -> None: ... +def _cuda_beginAllocateToPool(device: _int, mempool_id: Tuple[_int, _int]) -> None: ... def _cuda_beginAllocateCurrentStreamToPool(device: _int, mempool_id: Tuple[_int, _int]) -> None: ... def _cuda_endAllocateCurrentStreamToPool(device: _int, mempool_id: Tuple[_int, _int]) -> None: ... def _cuda_releasePool(device: _int, mempool_id: Tuple[_int, _int]) -> None: ... diff --git a/torch/csrc/cuda/Module.cpp b/torch/csrc/cuda/Module.cpp index 1cc6f7378e80c4..4bfcc17500b964 100644 --- a/torch/csrc/cuda/Module.cpp +++ b/torch/csrc/cuda/Module.cpp @@ -1281,6 +1281,13 @@ static void registerCudaPluggableAllocator(PyObject* module) { }); }); + m.def( + "_cuda_beginAllocateToPool", + [](c10::DeviceIndex device, at::cuda::MempoolId_t mempool_id) { + c10::cuda::CUDACachingAllocator::beginAllocateToPool( + device, mempool_id, [](cudaStream_t) { return true; }); + }); + m.def( "_cuda_endAllocateCurrentStreamToPool", [](c10::DeviceIndex device, at::cuda::MempoolId_t mempool_id) { diff --git a/torch/cuda/__init__.py b/torch/cuda/__init__.py index eb30f8f057ed25..5fb0038c23e145 100644 --- a/torch/cuda/__init__.py +++ b/torch/cuda/__init__.py @@ -1628,6 +1628,7 @@ def addmm_kernel_impl(*args, **kwargs): "memory_usage", "MemPool", "MemPoolContext", + "use_mem_pool", "temperature", "power_draw", "clock_rate", diff --git a/torch/cuda/memory.py b/torch/cuda/memory.py index 1726cbe439dc60..af2c8d480c8345 100644 --- a/torch/cuda/memory.py +++ b/torch/cuda/memory.py @@ -52,6 +52,7 @@ "change_current_allocator", "MemPool", "MemPoolContext", + "use_mem_pool", ] @@ -64,8 +65,20 @@ # Define dummy base classes torch._C.__dict__["_MemPool"] = _dummy_type("_MemPool") torch._C.__dict__["_MemPoolContext"] = _dummy_type("_MemPoolContext") + torch._C.__dict__["_cuda_beginAllocateToPool"] = _dummy_type( + "_cuda_beginAllocateToPool" + ) + torch._C.__dict__["_cuda_endAllocateCurrentStreamToPool"] = _dummy_type( + "_cuda_endAllocateCurrentStreamToPool" + ) -from torch._C import _cuda_CUDAAllocator, _MemPool, _MemPoolContext # noqa: F401 +from torch._C import ( # noqa: F401 + _cuda_beginAllocateToPool, + _cuda_CUDAAllocator, + _cuda_endAllocateCurrentStreamToPool, + _MemPool, + _MemPoolContext, +) def _host_allocator(): @@ -1002,3 +1015,27 @@ def __init__(self, pool: MemPool): def active_pool() -> Optional[_MemPool]: r"""Returns the active MemPool""" return _MemPoolContext.active_pool() + + +@contextlib.contextmanager +def use_mem_pool(pool: MemPool, device: Union[Device, int] = None): + r"""A context manager that routes allocations to a given pool. + + Args: + pool(torch.cuda.MemPool): a MemPool object to be made active so that + allocations route to this pool. + device (torch.device or int, optional): selected device. Uses MemPool on + the current device, given by :func:`~torch.cuda.current_device`, + if :attr:`device` is ``None`` (default). + + """ + ctx = MemPoolContext(pool) + device_index = ( + torch.cuda.current_device() if device is None else _get_device_index(device) + ) + _cuda_beginAllocateToPool(device_index, pool.id) + try: + yield + finally: + _cuda_endAllocateCurrentStreamToPool(device_index, pool.id) + del ctx From 97e801f9c2ba61d3c166ad0183a04852d598790f Mon Sep 17 00:00:00 2001 From: "Jennifer (Jiyue) Wang" Date: Thu, 29 Aug 2024 04:09:44 +0000 Subject: [PATCH 0208/1018] [caffe2][hipify] remove un-used flag from `pybind_utils.h` (#134404) Summary: Encountered issues related to AMD build when working on https://www.internalfb.com/diff/D60739324?dst_version_fbid=2203158110057105 (see stack trace P1545717562) Looking at the file history, seems that the flag is no longer used so I propose to remove it. Alternatively, I could change the `#ifdef` to check both `USE_C10D_NCCL` and `USE_ROCM` and include the corresponding AMD header files. Let me know what is more preferred way. Test Plan: Sandcastle Differential Revision: D61762129 Pull Request resolved: https://github.com/pytorch/pytorch/pull/134404 Approved by: https://github.com/malfet --- torch/csrc/jit/python/pybind_utils.h | 4 ---- 1 file changed, 4 deletions(-) diff --git a/torch/csrc/jit/python/pybind_utils.h b/torch/csrc/jit/python/pybind_utils.h index 555e7f50c3dd4e..96a6f23d6dad54 100644 --- a/torch/csrc/jit/python/pybind_utils.h +++ b/torch/csrc/jit/python/pybind_utils.h @@ -31,10 +31,6 @@ #include #include -#ifdef USE_C10D_NCCL -#include -#include -#endif #include #include #include From 0f05b2a8652092ac087fbe240f7f2d206bea1833 Mon Sep 17 00:00:00 2001 From: Banit Agrawal Date: Thu, 29 Aug 2024 04:46:08 +0000 Subject: [PATCH 0209/1018] [PyTorch] Use pinned memory for zero_cuda_out (#134712) Summary: This diff creates a pinned tensor for copying from device for the zero_out op. Differential Revision: D61759262 Pull Request resolved: https://github.com/pytorch/pytorch/pull/134712 Approved by: https://github.com/zyan0 --- aten/src/ATen/native/cuda/Nonzero.cu | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/aten/src/ATen/native/cuda/Nonzero.cu b/aten/src/ATen/native/cuda/Nonzero.cu index e87f46cd844eab..e5fb9230de7637 100644 --- a/aten/src/ATen/native/cuda/Nonzero.cu +++ b/aten/src/ATen/native/cuda/Nonzero.cu @@ -1,6 +1,7 @@ #define TORCH_ASSERT_ONLY_METHOD_OPERATORS #include #include +#include #include #include #include @@ -70,7 +71,16 @@ void nonzero_cuda_out_impl(const Tensor& self, Tensor& out){ auto temp_storage = allocator.allocate(temp_storage_bytes); cub::DeviceReduce::Sum(temp_storage.get(), temp_storage_bytes, itr, (int*)num_nonzeros.get(), N, stream); int num_nonzeros_h; - at::cuda::memcpy_and_sync(&num_nonzeros_h, num_nonzeros.get(), sizeof(int), cudaMemcpyDeviceToHost, stream); + auto pinned_num_nonzeros_h = at::detail::empty_cpu( + {1}, /* size */ + c10::CppTypeToScalarType(), /* dtype */ + std::nullopt, /* layout */ + std::nullopt, /* device */ + true, /* pin_memory */ + std::nullopt /* memory format */ + ); + at::cuda::memcpy_and_sync((void *)pinned_num_nonzeros_h.const_data_ptr(), num_nonzeros.get(), sizeof(int), cudaMemcpyDeviceToHost, stream); + num_nonzeros_h = (int)*(pinned_num_nonzeros_h.const_data_ptr()); //expected output size is num_nonzeros x ndim //we are producing output with size {num_nonzeros, ndim} and strides {1, num_nonzeros} (that is, transposed ndim x num_nonzeros output) //we are able to directly use passed output with this size and strides, and we can also (per contract) From a546e51eb67052ed8425833ec1e7a295de86316b Mon Sep 17 00:00:00 2001 From: Mikayla Gawarecki Date: Wed, 28 Aug 2024 18:31:38 -0700 Subject: [PATCH 0210/1018] Add torch.serialization.skip_data context manager (#134504) ## Semantic The semantic is (1) By default `torch.serialization.skip_data(materialize_fake_tensors=False)` will make `torch.save` skip writing storages (but reserve space for them in the checkpoint). ```python import torch import torch.nn as nn sd = nn.Linear(3, 5).state_dict() with torch.serialization.skip_data(): torch.save(sd, 'foo.pt') print(torch.load('foo.pt', weights_only=True)) ``` (2) With `torch.serialization.skip_data(materialize_fake_tensors=True)`If FakeTensor is passed to `torch.save` the pickler will treat these FakeTensors as being "materialized" space will be reserved in the checkpoint for the associated storage bytes, and when loading the type will be Tensor instead of FakeTensor) ```python import torch import torch.nn as nn from torch._subclasses.fake_tensor import FakeTensorMode with FakeTensorMode(): m = nn.Linear(3, 5, dtype=torch.float16, device='cuda') sd = m.state_dict() with torch.serialization.skip_data(materialize_fake_tensors=True): torch.save(sd, 'bla.pt') print(torch.load('bla.pt', weights_only=True)) # OrderedDict([('weight', tensor([[0., 0., 0.], # [0., 0., 0.], # [0., 0., 0.], # [0., 0., 0.], # [0., 0., 0.]], device='cuda:0', dtype=torch.float16)), ('bias', tensor([0., 0., 0., 0., 0.], device='cuda:0', dtype=torch.float16))]) ``` ## Follow Ups - [ ] `torch.load` semantic for skip_data context manager - [ ] Mechanism for getting offsets of storages saved via this method (for writing in a separate pass) Pull Request resolved: https://github.com/pytorch/pytorch/pull/134504 Approved by: https://github.com/albanD --- docs/source/notes/serialization.rst | 1 + ...cpp_extensions_open_device_registration.py | 12 +- test/test_serialization.py | 88 ++++++++++++++ torch/_tensor.py | 66 +++++++++-- torch/_utils.py | 6 +- torch/serialization.py | 110 ++++++++++++++++-- torch/storage.py | 4 + 7 files changed, 261 insertions(+), 26 deletions(-) diff --git a/docs/source/notes/serialization.rst b/docs/source/notes/serialization.rst index 5541d28bdcafa0..c05dc028a471c9 100644 --- a/docs/source/notes/serialization.rst +++ b/docs/source/notes/serialization.rst @@ -398,3 +398,4 @@ The following utility functions are related to serialization: .. autofunction:: clear_safe_globals .. autofunction:: get_safe_globals .. autoclass:: safe_globals +.. autoclass:: skip_data diff --git a/test/test_cpp_extensions_open_device_registration.py b/test/test_cpp_extensions_open_device_registration.py index 616a6e0f4b5518..4e86ed458b0788 100644 --- a/test/test_cpp_extensions_open_device_registration.py +++ b/test/test_cpp_extensions_open_device_registration.py @@ -540,7 +540,7 @@ def test_open_device_tensorlist_type_fallback(self): # call _fused_adamw_ with undefined tensor. self.module.fallback_with_undefined_tensor() - def test_open_device_numpy_serialization_map_location(self): + def test_open_device_numpy_serialization(self): torch.utils.rename_privateuse1_backend("foo") device = self.module.custom_device() default_protocol = torch.serialization.DEFAULT_PROTOCOL @@ -553,6 +553,7 @@ def test_open_device_numpy_serialization_map_location(self): self.assertTrue( rebuild_func is torch._utils._rebuild_device_tensor_from_numpy ) + # Test map_location with TemporaryFileName() as f: torch.save(sd, f) with safe_globals( @@ -569,6 +570,15 @@ def test_open_device_numpy_serialization_map_location(self): sd_loaded = torch.load(f, map_location="cpu") self.assertTrue(sd_loaded["x"].is_cpu) + # Test metadata_only + with TemporaryFileName() as f: + with self.assertRaisesRegex( + RuntimeError, + "Cannot serialize tensors on backends with no storage under skip_data context manager", + ): + with torch.serialization.skip_data(): + torch.save(sd, f) + if __name__ == "__main__": common.run_tests() diff --git a/test/test_serialization.py b/test/test_serialization.py index a041473d195b19..3ba96b80541d88 100644 --- a/test/test_serialization.py +++ b/test/test_serialization.py @@ -1,5 +1,6 @@ # Owner(s): ["module: serialization"] +import contextlib import copy import gc import gzip @@ -19,6 +20,7 @@ from pathlib import Path import torch +from torch._subclasses.fake_tensor import FakeTensorMode, FakeTensorConverter from torch._utils import _rebuild_tensor from torch._utils_internal import get_file_path_2 from torch.serialization import ( @@ -27,6 +29,7 @@ LoadEndianness, safe_globals, set_default_load_endianness, + skip_data, SourceChangeWarning, ) from torch.testing._internal.common_device_type import instantiate_device_type_tests @@ -4212,6 +4215,91 @@ def test_filewriter_metadata_writing(self, filename): sd_loaded_ref = torch.load(f) self.assertEqual(sd_loaded, sd_loaded_ref) + @parametrize("materialize_fake", (True, False)) + def test_skip_data_serialization(self, materialize_fake): + # Create one tensor that uses each of the paths in __reduce_ex__ that should work + t_device = "cuda" if torch.cuda.is_available() else "cpu" + t_v2 = torch.randn(2, 3, device=t_device) + t_v3 = torch.randn(2, 3, dtype=torch.complex32, device=t_device) + i = torch.tensor([[0, 1, 1], + [2, 0, 2]]) + v = torch.tensor([3, 4, 5], dtype=torch.float32) + if not materialize_fake: + # FakeTensorConverter messes up sizes of i and v for the sparse tensor + st = torch.sparse_coo_tensor(i, v, (2, 4)) + tt = TwoTensor(torch.randn(2, device=t_device), torch.randn(2, device=t_device)) + + mode, converter = FakeTensorMode(), FakeTensorConverter() + + def fn(t): + return converter.from_real_tensor(mode, t) if materialize_fake else t + + sd = {'t_v2': fn(t_v2), 't_v3': fn(t_v3), 'tt': fn(tt)} + sd_expected = { + 't_v2': torch.zeros(2, 3, device=t_device), + 't_v3': torch.zeros(2, 3, dtype=torch.complex32, device=t_device), + 'tt': TwoTensor(torch.zeros(2, device=t_device), torch.zeros(2, device=t_device)), + } + + if not materialize_fake: + sd['st'] = st + sd_expected['st'] = torch.sparse_coo_tensor(torch.zeros(2, 3), torch.zeros(3), (2, 4)) + + with BytesIOContext() as f: + with skip_data(materialize_fake_tensors=materialize_fake): + torch.save(sd, f) + f.seek(0) + with safe_globals([TwoTensor]): + sd_loaded = torch.load(f, weights_only=True) + self.assertEqual(sd_loaded, sd_expected, exact_device=True) + self.assertFalse(getattr(torch.serialization._serialization_tls, "materialize_fake_tensors", False)) + self.assertFalse(getattr(torch.serialization._serialization_tls, "skip_data", False)) + + # Test that without materialize_fake_tensor, behavior for fake_tensors is not altered by ctx + if not materialize_fake: + ft = converter.from_real_tensor(mode, torch.randn(2, device=t_device)) + with self.assertRaisesRegex(AttributeError, "Can't pickle local object 'WeakValueDictionary.__init__..remove'"): + with skip_data(), BytesIOContext() as f: + torch.save(ft, f) + + @parametrize("materialize_fake", (True, False)) + def test_skip_data_serialization_preserves_views(self, materialize_fake): + ctx = FakeTensorMode if materialize_fake else contextlib.nullcontext + with ctx(): + t = torch.randn(2, 3) + t_view = t.view(-1) + t_slice = t[1] + sd = {'t': t, 't_view': t_view, 't_slice': t_slice} + with BytesIOContext() as f: + with skip_data(materialize_fake_tensors=materialize_fake): + torch.save(sd, f) + f.seek(0) + sd_loaded = torch.load(f, weights_only=True) + self.assertTrue(id(sd_loaded['t_view'].untyped_storage()) == id(sd_loaded['t'].untyped_storage())) + self.assertTrue(id(sd_loaded['t_slice'].untyped_storage()) == id(sd_loaded['t'].untyped_storage())) + + def test_skip_data_serialization_error_cases(self): + def _save_load(t): + with BytesIOContext() as f: + with skip_data(): + torch.save(t, f) + f.seek(0) + torch.load(f, weights_only=True) + + nt = torch.nested.nested_tensor([torch.randn(2), torch.randn(3)]) + t = torch.randn(2, 3, device="meta") + with self.assertRaisesRegex(RuntimeError, "Cannot serialize nested tensor under skip_data context manager"): + _save_load(nt) + + with self.assertWarnsRegex(UserWarning, "meta device under skip_data context manager is a no-op"): + _save_load(t) + + with self.assertRaisesRegex(RuntimeError, "Please call torch.load outside the skip_data context manager"): + with skip_data(), BytesIOContext() as f: + torch.save(torch.randn(2, 3), f) + f.seek(0) + torch.load(f, weights_only=True) + def run(self, *args, **kwargs): with serialization_method(use_zip=True): return super().run(*args, **kwargs) diff --git a/torch/_tensor.py b/torch/_tensor.py index 98563aebae9aa7..61d5e3891b9e56 100644 --- a/torch/_tensor.py +++ b/torch/_tensor.py @@ -209,8 +209,16 @@ def __deepcopy__(self, memo): return new_tensor def __reduce_ex__(self, proto): + materialize_fake_tensors = ( + torch.serialization._serialization_tls.materialize_fake_tensors + ) state = torch._utils._get_obj_state(self) - if type(self) is Tensor and not state: + # Ignore all state when using FakeTensor with skip_data(materialize_fake_tensors) because FakeTensor has + # some state that cannot be pickled + if ( + type(self) is torch._subclasses.fake_tensor.FakeTensor + and materialize_fake_tensors + ) or (type(self) is Tensor and not state): # Fast path for regular tensor without Python state. return self._reduce_ex_internal(proto) if has_torch_function_unary(self): @@ -251,6 +259,12 @@ def _reduce_ex_internal(self, proto): # See Note [Don't serialize hooks] warn_if_has_hooks(self) backward_hooks: Dict[Any, Any] = OrderedDict() + + skip_data = torch.serialization._serialization_tls.skip_data + materialize_fake_tensors = ( + torch.serialization._serialization_tls.materialize_fake_tensors + ) + # Note: Numpy array is chosen to be the rebuild component for XLA, MTIA, MAIA Tensors. # We considered a few options: # 1. CPU tensor can't be used here. @@ -268,6 +282,10 @@ def _reduce_ex_internal(self, proto): # Convert BFloat16 tesors to Float32 before conversion to numpy, as numpy doesn't # support BFloat16. The rebuild tensor from numpy takes in the original self.dtype, # this would reconstruct the BFloat16 tensor from numpy. + if skip_data: + raise RuntimeError( + "Cannot serialize tensors on backends with no storage under skip_data context manager" + ) numpy_tensor = ( self.cpu().numpy() if self.dtype != torch.bfloat16 @@ -280,6 +298,10 @@ def _reduce_ex_internal(self, proto): if self.device.type == "meta": # NB: This implementation BREAKS storage sharing. Current # hypothesis is that no one cares for meta tensors. + if skip_data: + warnings.warn( + "Serializing tensors on the meta device under skip_data context manager is a no-op" + ) arg_meta = ( self.dtype, tuple(self.size()), @@ -288,6 +310,10 @@ def _reduce_ex_internal(self, proto): ) return (torch._utils._rebuild_meta_tensor_no_storage, arg_meta) if self.is_quantized: + if skip_data: + raise RuntimeError( + "Cannot serialize qtensor under skip_data context manager, file an issue if you need this feature" + ) # quantizer_params can be different type based on torch attribute quantizer_params: Union[ Tuple[torch.qscheme, float, int], Tuple[Any, Tensor, Tensor, int] @@ -369,6 +395,10 @@ def _reduce_ex_internal(self, proto): ) return (torch._utils._rebuild_sparse_tensor, args_sparse_compressed) elif self.is_nested: + if skip_data: + raise RuntimeError( + "Cannot serialize nested tensor under skip_data context manager, file an issue if you need this feature" + ) args_nested = ( # NB: values() currently returns the storage as a buffer in an unsafe way. # Ideally, we'd use a private API for this instead. TODO: Switch to this if @@ -383,14 +413,30 @@ def _reduce_ex_internal(self, proto): type(self) is not torch.Tensor and type(self).__torch_dispatch__ is not torch.Tensor.__torch_dispatch__ and ( - isinstance( - self, - ( - torch._subclasses.fake_tensor.FakeTensor, - torch._subclasses.functional_tensor.FunctionalTensor, - ), + isinstance(self, torch._subclasses.functional_tensor.FunctionalTensor) + or ( + not isinstance(self, torch._subclasses.fake_tensor.FakeTensor) + and self.data_ptr() == 0 ) - or self.data_ptr() == 0 + ) + ): + arg_wrapper_subclass = ( + type(self), + self.dtype, + tuple(self.size()), + self.stride(), + self.storage_offset(), + self.layout, + self.device, + self.requires_grad, + ) + return (torch._utils._rebuild_wrapper_subclass, arg_wrapper_subclass) + elif ( + type(self) is not torch.Tensor + and type(self).__torch_dispatch__ is not torch.Tensor.__torch_dispatch__ + and ( + isinstance(self, torch._subclasses.fake_tensor.FakeTensor) + and not (skip_data and materialize_fake_tensors) ) ): arg_wrapper_subclass = ( @@ -418,6 +464,10 @@ def _reduce_ex_internal(self, proto): dtype=self.dtype, _internal=True, ) # type: ignore[assignment] + + if isinstance(self, torch._subclasses.fake_tensor.FakeTensor) and skip_data: + storage._fake_device = self.device + args = ( storage, self.storage_offset(), diff --git a/torch/_utils.py b/torch/_utils.py index 938392fa971590..f0d38daa811490 100644 --- a/torch/_utils.py +++ b/torch/_utils.py @@ -3,7 +3,6 @@ import functools import logging import sys -import threading import traceback import warnings from collections import defaultdict @@ -109,16 +108,13 @@ def _get_async_or_non_blocking(function_name, non_blocking, kwargs): return kwargs["async"] -_thread_local_state = threading.local() - - def _get_restore_location(device): """Return the map_location location. Used for rebuild functions where the tensor device is distinct from the storage """ - map_location = getattr(_thread_local_state, "map_location", None) + map_location = torch.serialization._serialization_tls.map_location if map_location is None: return device else: diff --git a/torch/serialization.py b/torch/serialization.py index 2ac36ec371fe57..d936d31d6f5209 100644 --- a/torch/serialization.py +++ b/torch/serialization.py @@ -11,6 +11,7 @@ import sys import tarfile import tempfile +import threading import warnings from contextlib import closing, contextmanager from enum import Enum @@ -60,6 +61,7 @@ "get_safe_globals", "add_safe_globals", "safe_globals", + "skip_data", ] @@ -87,6 +89,22 @@ MAP_SHARED, MAP_PRIVATE = None, None # type: ignore[assignment] +# _serialization_tls is used to store thread local state specific to serialization +# that needs to be propagated to other files, in particular we use this for +# (1) map_location (needed for wrapper subclasses/third party devices to torch._utils) +# (2) skip_data (needed for torch.Tensor.__reduce_ex__ for skip_data ctx) +# (3) materialize_fake_tensors (needed for torch.Tensor.__reduce_ex__ for skip_data ctx) +class _SerializationLocal(threading.local): + def __init__(self): + super().__init__() + self.map_location: Optional[MAP_LOCATION] = None + self.skip_data: bool = False + self.materialize_fake_tensors: bool = False + + +_serialization_tls = _SerializationLocal() + + class SourceChangeWarning(Warning): pass @@ -268,6 +286,47 @@ class safe_globals(_weights_only_unpickler._safe_globals): """ +class skip_data: + """ + Context-manager that skips writing storage bytes for ``torch.save`` calls. + + Storages will still be saved, but the space that their bytes would usually be written to + will be empty space. The storage bytes can then be populated in a separate pass. + + .. warning:: + The ``skip_data`` context manager is an early prototype and is subject to change. + + Args: + materialize_fake_tensors: Whether to materialize FakeTensors. + + Example: + >>> # xdoctest: +SKIP("NamedTemporaryFile on Windows") + >>> import tempfile + >>> t = torch.randn(2, 3) + >>> with tempfile.NamedTemporaryFile() as f: + ... with torch.serialization.skip_data(): + ... torch.save(t, f.name) + ... torch.load(f.name, weights_only=True) + tensor([[0., 0., 0.], + [0., 0., 0.]]) + """ + + def __init__(self, materialize_fake_tensors: bool = False): + self.materialize_fake_tensors = materialize_fake_tensors + + def __enter__(self): + global _serialization_tls + self._old_skip_data = _serialization_tls.skip_data + self._old_materialize_fake_tensors = _serialization_tls.materialize_fake_tensors + _serialization_tls.skip_data = True + _serialization_tls.materialize_fake_tensors = self.materialize_fake_tensors + + def __exit__(self, type, value, tb): + global _serialization_tls + _serialization_tls.skip_data = self._old_skip_data + _serialization_tls.materialize_fake_tensors = self._old_materialize_fake_tensors + + def _is_zipfile(f) -> bool: # This is a stricter implementation than zipfile.is_zipfile(). # zipfile.is_zipfile() is True if the magic number appears anywhere in the @@ -797,6 +856,11 @@ def save( ) return else: + global _serialization_tls + if _serialization_tls.skip_data: + raise RuntimeError( + "Cannot use skip_data=True with _use_new_zipfile_serialization=False" + ) with _open_file_like(f, "wb") as opened_file: _legacy_save(obj, opened_file, pickle_module, pickle_protocol) @@ -955,7 +1019,13 @@ def persistent_id(obj: Any) -> Optional[Tuple]: ) -def _save(obj, zip_file, pickle_module, pickle_protocol, _disable_byteorder_record): +def _save( + obj, + zip_file, + pickle_module, + pickle_protocol, + _disable_byteorder_record, +): serialized_storages = {} id_map: Dict[int, str] = {} @@ -990,7 +1060,7 @@ def persistent_id(obj): # If storage is allocated, ensure that any other saved storages # pointing to the same data all have the same dtype. If storage is # not allocated, don't perform this check - if storage.data_ptr() != 0: + if str(storage.device) != "meta" and storage.data_ptr() != 0: if storage.data_ptr() in storage_dtypes: if storage_dtype != storage_dtypes[storage.data_ptr()]: raise RuntimeError( @@ -1001,7 +1071,10 @@ def persistent_id(obj): storage_dtypes[storage.data_ptr()] = storage_dtype storage_key = id_map.setdefault(storage._cdata, str(len(id_map))) - location = location_tag(storage) + if hasattr(obj, "_fake_device") and obj._fake_device is not None: + location = str(obj._fake_device) + else: + location = location_tag(storage) serialized_storages[storage_key] = storage return ("storage", storage_type, storage_key, location, storage_numel) @@ -1027,14 +1100,18 @@ def persistent_id(obj): for key in sorted(serialized_storages.keys()): name = f"data/{key}" storage = serialized_storages[key] - # given that we copy things around anyway, we might use storage.cpu() - # this means to that to get tensors serialized, you need to implement - # .cpu() on the underlying Storage - if storage.device.type != "cpu": - storage = storage.cpu() - # Now that it is on the CPU we can directly copy it into the zip file num_bytes = storage.nbytes() - zip_file.write_record(name, storage, num_bytes) + global _serialization_tls + if _serialization_tls.skip_data: + zip_file.write_record_metadata(name, num_bytes) + else: + # given that we copy things around anyway, we might use storage.cpu() + # this means to that to get tensors serialized, you need to implement + # .cpu() on the underlying Storage + if storage.device.type != "cpu": + storage = storage.cpu() + # Now that it is on the CPU we can directly copy it into the zip file + zip_file.write_record(name, storage, num_bytes) def load( @@ -1184,6 +1261,14 @@ def _get_wo_message(message: str) -> str: updated_message += message return updated_message + DOCS_MESSAGE + global _serialization_tls + skip_data = _serialization_tls.skip_data + if skip_data: + raise RuntimeError( + "`torch.load` called within a torch.serialization.skip_data context manager " + "is not supported yet. Please call torch.load outside the skip_data context manager." + ) + if weights_only is None: weights_only, warn_weights_only = False, True else: @@ -1758,9 +1843,10 @@ def find_class(self, mod_name, name): unpickler.persistent_load = persistent_load # Needed for tensors where storage device and rebuild tensor device are # not connected (wrapper subclasses and tensors rebuilt using numpy) - torch._utils._thread_local_state.map_location = map_location + global _serialization_tls + _serialization_tls.map_location = map_location result = unpickler.load() - del torch._utils._thread_local_state.map_location + _serialization_tls.map_location = None torch._utils._validate_loaded_sparse_tensors() torch._C._log_api_usage_metadata( diff --git a/torch/storage.py b/torch/storage.py index b6ba608c16e5ca..8848649905f936 100644 --- a/torch/storage.py +++ b/torch/storage.py @@ -39,6 +39,8 @@ class _StorageBase: is_sparse: _bool = False is_sparse_csr: _bool = False device: torch.device + # Used when stashing FakeTensor device onto storage in torch.save(metadata_only=True) + _fake_device: _Optional[torch.device] = None def __init__(self, *args, **kwargs): pass @@ -649,6 +651,8 @@ def _get_device_from_module(module: str): class TypedStorage: is_sparse: _bool = False + # Used when stashing FakeTensor device onto storage in torch.save(metadata_only=True) + _fake_device: _Optional[torch.device] = None dtype: torch.dtype From 5c65ecc11e4e2d74ae7deaff501f2dc5fa94a065 Mon Sep 17 00:00:00 2001 From: "Wang, Eikan" Date: Thu, 29 Aug 2024 01:18:56 +0000 Subject: [PATCH 0211/1018] Get accumulate dtype for Intel GPU (#134465) Stack from [ghstack](https://github.com/ezyang/ghstack) (oldest at bottom): There are two function variants to get accumulated dtype for a given dtype: - Func1: `c10::ScalarType toAccumulateType(c10::ScalarType type, c10::DeviceType device)` - Func2: `c10::ScalarType toAccumulateType(c10::ScalarType type, bool is_cuda)` The Func1 is general enough to support different devices, while the Func2 only supports CUDA and CPU. This PR intends to add the Intel GPU path in the Func1. And we expect users to invoke the Func1 to ensure compatibility for different devices. * __->__ #134465 Pull Request resolved: https://github.com/pytorch/pytorch/pull/134465 Approved by: https://github.com/Skylion007, https://github.com/atalman --- aten/src/ATen/AccumulateType.cpp | 2 ++ 1 file changed, 2 insertions(+) diff --git a/aten/src/ATen/AccumulateType.cpp b/aten/src/ATen/AccumulateType.cpp index c4623cc08629c7..5de47571ea6fed 100644 --- a/aten/src/ATen/AccumulateType.cpp +++ b/aten/src/ATen/AccumulateType.cpp @@ -9,6 +9,8 @@ c10::ScalarType toAccumulateType(c10::ScalarType type, c10::DeviceType device) { switch (device) { \ case DeviceType::CUDA: \ return CppTypeToScalarType>::value; \ + case DeviceType::XPU: \ + return CppTypeToScalarType>::value; \ case DeviceType::MPS: \ return CppTypeToScalarType>::value; \ default: \ From b72e82cff888d3696ac32a2d26a475d117b965e1 Mon Sep 17 00:00:00 2001 From: wz337 Date: Thu, 29 Aug 2024 06:01:12 +0000 Subject: [PATCH 0212/1018] [DTensor] Add pointwise ops strategy for aten.isinf, aten.isneginf, aten.isposinf (#134699) Fixes #ISSUE_NUMBER Need it for https://github.com/facebookresearch/optimizers/blob/main/distributed_shampoo/utils/shampoo_preconditioner_list.py#L671 Pull Request resolved: https://github.com/pytorch/pytorch/pull/134699 Approved by: https://github.com/tianyu-l --- test/distributed/_tensor/test_dtensor_ops.py | 3 --- torch/distributed/_tensor/ops/_pointwise_ops.py | 5 +++++ 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/test/distributed/_tensor/test_dtensor_ops.py b/test/distributed/_tensor/test_dtensor_ops.py index 480207e17d29ca..532bd3facae554 100644 --- a/test/distributed/_tensor/test_dtensor_ops.py +++ b/test/distributed/_tensor/test_dtensor_ops.py @@ -191,9 +191,6 @@ def wrapped(fn): xfail("index_reduce", "amin"), xfail("index_select"), xfail("isin"), - xfail("isinf"), - xfail("isneginf"), - xfail("isposinf"), xfail("kthvalue"), xfail("linalg.cholesky"), xfail("linalg.cholesky_ex"), diff --git a/torch/distributed/_tensor/ops/_pointwise_ops.py b/torch/distributed/_tensor/ops/_pointwise_ops.py index f3f32805dca065..73062441025fae 100644 --- a/torch/distributed/_tensor/ops/_pointwise_ops.py +++ b/torch/distributed/_tensor/ops/_pointwise_ops.py @@ -239,7 +239,12 @@ aten.igammac.default, aten.igammac.out, aten.igammac_.default, + aten.isinf.default, aten.isnan.default, + aten.isneginf.default, + aten.isneginf.out, + aten.isposinf.default, + aten.isposinf.out, aten.ldexp.default, aten.ldexp.out, aten.ldexp_.default, From 31f399508af5b0e67074a7357c120101b61bd5c3 Mon Sep 17 00:00:00 2001 From: "Xia, Weiwen" Date: Thu, 29 Aug 2024 06:20:50 +0000 Subject: [PATCH 0213/1018] [Inductor][mkldnn] Bug fix: incorrect codegen arg order for qconv (#134579) Fixes #133448 The arg order for mkldnn qconv IR became incorrect after PR #132367 . This PR fixes the bug. **Test plan** `python test/inductor/test_mkldnn_pattern_matcher.py -k qconv` `python test/inductor/test_cpu_cpp_wrapper.py -k qconv` Pull Request resolved: https://github.com/pytorch/pytorch/pull/134579 Approved by: https://github.com/leslie-fang-intel, https://github.com/jgong5 --- test/inductor/test_mkldnn_pattern_matcher.py | 8 +++++--- torch/_inductor/mkldnn_ir.py | 16 ++++++++++++---- 2 files changed, 17 insertions(+), 7 deletions(-) diff --git a/test/inductor/test_mkldnn_pattern_matcher.py b/test/inductor/test_mkldnn_pattern_matcher.py index 73fa12fe329778..3a98efae39b02c 100644 --- a/test/inductor/test_mkldnn_pattern_matcher.py +++ b/test/inductor/test_mkldnn_pattern_matcher.py @@ -733,7 +733,9 @@ def __init__( super().__init__() self.conv = torch.nn.Conv2d(3, 128, kernel_size=3, stride=1) self.unary_fn = copy.deepcopy(unary_op) - self.conv2 = torch.nn.Conv2d(128, 128, kernel_size=3, stride=1) + self.conv2 = torch.nn.Conv2d( + 128, 128, kernel_size=3, stride=1, bias=False + ) self.unary_fn2 = copy.deepcopy(unary_op) def forward(self, x): @@ -889,8 +891,8 @@ def __init__( self.conv2 = torch.nn.Conv2d(3, 6, kernel_size=3, stride=1) self.add_fn = add_fn self.relu = torch.nn.ReLU() - self.conv3 = torch.nn.Conv2d(6, 6, kernel_size=3, stride=1) - self.conv4 = torch.nn.Conv2d(6, 6, kernel_size=3, stride=1) + self.conv3 = torch.nn.Conv2d(6, 6, kernel_size=3, stride=1, bias=False) + self.conv4 = torch.nn.Conv2d(6, 6, kernel_size=3, stride=1, bias=False) self.add_fn2 = add_fn self.relu2 = torch.nn.ReLU() self.use_relu = use_relu diff --git a/torch/_inductor/mkldnn_ir.py b/torch/_inductor/mkldnn_ir.py index 0a8bbb9d1e6e17..eeab8fcc76c652 100644 --- a/torch/_inductor/mkldnn_ir.py +++ b/torch/_inductor/mkldnn_ir.py @@ -623,15 +623,19 @@ def codegen(self, wrapper): "scalars", "algorithm", ] + if not self.has_bias: + const_arg_names.insert(2, "bias") const_args = list(self.codegen_const_args(const_arg_names)) x = args[0] packed_weight = args[1] - bias = args[2] if self.has_bias else const_args[0] + bias = args[2] if self.has_bias else const_args[2] w_scale, w_zp = args[-2], args[-1] ( x_scale, x_zp, + ) = const_args[:2] + ( stride, padding, dilation, @@ -642,7 +646,7 @@ def codegen(self, wrapper): unary_attr, unary_scalars, unary_algorithm, - ) = const_args[-12:] + ) = const_args[-10:] codegen_args = ( x, @@ -821,17 +825,21 @@ def codegen(self, wrapper): "unary_scalars", "unary_algorithm", ] + if not self.has_bias: + const_arg_names.insert(4, "bias") const_args = list(self.codegen_const_args(const_arg_names)) x = args[0] packed_weight = args[1] - bias = args[2] if self.has_bias else const_args[0] + bias = args[2] if self.has_bias else const_args[4] accum, w_scale, w_zp = args[-3], args[-2], args[-1] ( x_scale, x_zp, accum_scale, accum_zp, + ) = const_args[:4] + ( stride, padding, dilation, @@ -844,7 +852,7 @@ def codegen(self, wrapper): unary_attr, unary_scalars, unary_algorithm, - ) = const_args[-16:] + ) = const_args[-12:] conv_args = ( x, x_scale, From c22a6766dd2beb2c1bff34d38cd97b17160fc023 Mon Sep 17 00:00:00 2001 From: Will Feng Date: Wed, 28 Aug 2024 19:16:43 -0700 Subject: [PATCH 0214/1018] [2nd try][Traceable FSDP2] Allow tracing through FSDP2 impl in trace_rules.py (#134539) The previous PR https://github.com/pytorch/pytorch/pull/133532 caused stuck compilation issue on internal models. In this 2nd attempt PR, we gate the trace_rules.py changes with `if not torch._dynamo.config.skip_fsdp_hooks:`, so that they don't take effect for current graph-break FSDP2 (which relies on the default config value `skip_fsdp_hooks=True`), and will only take effect when we are using Traceable FSDP2 (in which case the user needs to proactively set `skip_fsdp_hooks=False`). Pull Request resolved: https://github.com/pytorch/pytorch/pull/134539 Approved by: https://github.com/ckluk2, https://github.com/yanboliang --- torch/_dynamo/trace_rules.py | 4 ++++ torch/distributed/_composable/fsdp/_fsdp_state.py | 2 ++ 2 files changed, 6 insertions(+) diff --git a/torch/_dynamo/trace_rules.py b/torch/_dynamo/trace_rules.py index d924c383186253..f4ad6ff01e848c 100644 --- a/torch/_dynamo/trace_rules.py +++ b/torch/_dynamo/trace_rules.py @@ -3211,6 +3211,8 @@ def _module_dir(m: types.ModuleType): # the forward_hook won't be ignored. "torch.distributed._composable.replicate", } + if not torch._dynamo.config.skip_fsdp_hooks: + LEGACY_MOD_INLINELIST.add("torch.distributed._composable.fsdp") # Force inline functions under these modules, even they are in *_SKIPLIST. @@ -3260,6 +3262,8 @@ def _module_dir(m: types.ModuleType): if torch.distributed.is_available(): MOD_INLINELIST.add("torch.distributed") + if not torch._dynamo.config.skip_fsdp_hooks: + MOD_INLINELIST.add("torch.distributed._composable.fsdp") @functools.lru_cache(None) diff --git a/torch/distributed/_composable/fsdp/_fsdp_state.py b/torch/distributed/_composable/fsdp/_fsdp_state.py index 4b949c64d9f7f8..ceb480fd23939f 100644 --- a/torch/distributed/_composable/fsdp/_fsdp_state.py +++ b/torch/distributed/_composable/fsdp/_fsdp_state.py @@ -351,12 +351,14 @@ def _register_group_forward_hooks( """ modules_set = set(modules) + @disable_if_config_true @functools.wraps(pre_hook) def wrapped_pre_hook(*args: Any, **kwargs: Any): if len(modules_to_run) == 0: # first to run modules_to_run.update(modules_set) return pre_hook(*args, **kwargs) + @disable_if_config_true def get_wrapped_post_hook(module: nn.Module): @functools.wraps(post_hook) def wrapped_post_hook(*args: Any, **kwargs: Any): From 8fddfb644fc8ab960adeb6b3b36b61e544ebf1e2 Mon Sep 17 00:00:00 2001 From: Ke Wen Date: Wed, 28 Aug 2024 22:39:36 -0700 Subject: [PATCH 0215/1018] [1/N] Move NaN check onto NCCL stream (#134300) So that the tensor's lifetime management is the same as the management built for the NCCL, pre and post kernels. Also so that on visualizers, they show up in the NCCL stream line. Otherwise if they show up in the compute line, user may get confused (my code does not have these kernels). The check is thus moved after the point where we depend NCCL stream from the last compute kernel. Also moved declaration of `checkForNan` from Utils.hpp to NCCLUtils.hpp, and renamed Utils.cu to NCCLUtils.cu. Differential Revision: [D61957573](https://our.internmc.facebook.com/intern/diff/D61957573) Pull Request resolved: https://github.com/pytorch/pytorch/pull/134300 Approved by: https://github.com/shuqiangzhang, https://github.com/wconstab --- BUILD.bazel | 4 ++-- build_variables.bzl | 2 +- .../distributed/c10d/{Utils.cu => NanCheck.cu} | 10 +++++++--- torch/csrc/distributed/c10d/NanCheck.hpp | 16 ++++++++++++++++ torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp | 15 +++++++++------ torch/csrc/distributed/c10d/Utils.hpp | 2 -- 6 files changed, 35 insertions(+), 14 deletions(-) rename torch/csrc/distributed/c10d/{Utils.cu => NanCheck.cu} (83%) create mode 100644 torch/csrc/distributed/c10d/NanCheck.hpp diff --git a/BUILD.bazel b/BUILD.bazel index f079c76a7f7204..c71781606b7098 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -574,7 +574,7 @@ cu_library( name = "torch_cuda", srcs = [ "torch/csrc/distributed/c10d/intra_node_comm.cu", - "torch/csrc/distributed/c10d/Utils.cu", + "torch/csrc/distributed/c10d/NanCheck.cu", "torch/csrc/distributed/c10d/quantization/quantization_gpu.cu", ], copts = torch_cuda_half_options, @@ -722,7 +722,7 @@ cc_library( "torch/csrc/distributed/c10d/intra_node_comm.cu", "torch/csrc/distributed/c10d/CUDASymmetricMemory.cu", "torch/csrc/distributed/c10d/CUDASymmetricMemoryOps.cu", - "torch/csrc/distributed/c10d/Utils.cu", + "torch/csrc/distributed/c10d/NanCheck.cu", "torch/csrc/distributed/c10d/quantization/quantization_gpu.cu", ], )) + torch_sources, diff --git a/build_variables.bzl b/build_variables.bzl index c3c06c01ada6e8..aa6728eb00b883 100644 --- a/build_variables.bzl +++ b/build_variables.bzl @@ -691,7 +691,7 @@ libtorch_cuda_distributed_extra_sources = [ "torch/csrc/distributed/c10d/intra_node_comm.cu", "torch/csrc/distributed/c10d/CUDASymmetricMemory.cu", "torch/csrc/distributed/c10d/CUDASymmetricMemoryOps.cu", - "torch/csrc/distributed/c10d/Utils.cu", + "torch/csrc/distributed/c10d/NanCheck.cu", "torch/csrc/distributed/rpc/tensorpipe_cuda.cpp", "torch/csrc/distributed/c10d/quantization/quantization_gpu.cu", ] diff --git a/torch/csrc/distributed/c10d/Utils.cu b/torch/csrc/distributed/c10d/NanCheck.cu similarity index 83% rename from torch/csrc/distributed/c10d/Utils.cu rename to torch/csrc/distributed/c10d/NanCheck.cu index ae2017efdf8d97..a8d22b9a44e21e 100644 --- a/torch/csrc/distributed/c10d/Utils.cu +++ b/torch/csrc/distributed/c10d/NanCheck.cu @@ -1,9 +1,11 @@ +#ifdef USE_C10D_NCCL + #include #include #include -#include #include #include +#include namespace c10d { @@ -20,7 +22,7 @@ __global__ void checkForNaN(T* data, size_t size) { } // CHECK if a Tensor contains NAN in any of its element -void checkForNan(const at::Tensor& tensor) { +void checkForNan(const at::Tensor& tensor, at::cuda::CUDAStream& stream) { // skip check for non float types if (!torch::is_floating_point(tensor)) { return; @@ -40,10 +42,12 @@ void checkForNan(const at::Tensor& tensor) { tensor.scalar_type(), "checkForNaN", [&] { - checkForNaN<<>>( + checkForNaN<<>>( tensor.data_ptr(), tensor.numel()); C10_CUDA_KERNEL_LAUNCH_CHECK(); }); } } // namespace c10d + +#endif // USE_C10D_NCCL diff --git a/torch/csrc/distributed/c10d/NanCheck.hpp b/torch/csrc/distributed/c10d/NanCheck.hpp new file mode 100644 index 00000000000000..cc9a5867c3dd40 --- /dev/null +++ b/torch/csrc/distributed/c10d/NanCheck.hpp @@ -0,0 +1,16 @@ +#pragma once + +#ifdef USE_C10D_NCCL + +#include +#include + +namespace c10d { + +// Check for NaNs in a tensor on a given stream. If any are found, throw a +// device-side error. +void checkForNan(const at::Tensor& tensor, at::cuda::CUDAStream& stream); + +} // namespace c10d + +#endif // USE_C10D_NCCL diff --git a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp index 71a081b5e495d5..310b351eb219cc 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp +++ b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp @@ -25,6 +25,7 @@ #include #include #include +#include #include #include #include @@ -2638,9 +2639,6 @@ c10::intrusive_ptr ProcessGroupNCCL::collective( OpType opType, const char* profilingTitle, bool avoidRecordStreams) { - if (enableNanCheck_) { - checkForNan(input); - } // Environment setting by the user may add onto collective call's option avoidRecordStreams |= avoidRecordStreams_; c10::cuda::CaptureStatus capture_status = @@ -2696,6 +2694,10 @@ c10::intrusive_ptr ProcessGroupNCCL::collective( at::cuda::OptionalCUDAGuard gpuGuard; + if (enableNanCheck_) { + checkForNan(input, ncclStream); + } + // Start event should only be recorded before the ncclGroupStart() if (work->timingEnabled_) { work->ncclStartEvent_->record(ncclStream); @@ -2997,9 +2999,6 @@ c10::intrusive_ptr ProcessGroupNCCL::pointToPoint( PreProcess pre, PostProcess post, const char* profilingTitle) { - if (enableNanCheck_) { - checkForNan(tensor); - } // avoidRecordStreams_ note: // send, recv, and irecv should be ok with avoidRecordStreams, // However, for isend, I don't think the API requires the user @@ -3128,6 +3127,10 @@ c10::intrusive_ptr ProcessGroupNCCL::pointToPoint( // is gpuGuard needed for the if block below, or can i swap them at::cuda::OptionalCUDAGuard gpuGuard; + if (enableNanCheck_) { + checkForNan(tensor, ncclStream); + } + if (!coalescing_state_) { // Start event should only be recorded before the ncclGroupStart() if (work->timingEnabled_) { diff --git a/torch/csrc/distributed/c10d/Utils.hpp b/torch/csrc/distributed/c10d/Utils.hpp index 5c7365393404fc..ea4a4653bc35fc 100644 --- a/torch/csrc/distributed/c10d/Utils.hpp +++ b/torch/csrc/distributed/c10d/Utils.hpp @@ -611,8 +611,6 @@ using SizeType = uint64_t; // Since SOCKET_ERROR = -1 in MSVC, so also leverage SYSCHECK_ERR_RETURN_NEG1 #define SYSCHECK_ERR_RETURN_NEG1(expr) SYSCHECK(expr, __output != -1) -void checkForNan(const at::Tensor& tensor); - namespace tcputil { // Send and receive From a3118c0a42c8c5595e5d4e4adc84ac712727d57d Mon Sep 17 00:00:00 2001 From: wz337 Date: Thu, 29 Aug 2024 09:01:29 +0000 Subject: [PATCH 0216/1018] [DTensor] Extend implicit replication to replicate DTensor for foreach ops so model doesn't have to be fully tp-ed when using 2D (#134551) Fixes [134212](https://github.com/pytorch/pytorch/issues/134212) Currently, when we use 2D FSDP with TP, `optimizer.step()` would fail if the model were not fully tensor parallelized. If we don't have the entire model tensor parallelized when doing 2D, we would have both 1D and 2D DTensor parameters. As foreach is turned on by default, `optimizer.step()` would fail as cross mesh op is not allowed. Error as follows: ``` NotImplementedError: aten._foreach_mul_.Scalar: DTensor does not support cross-mesh operation yet!Got meshes: DeviceMesh('cuda', [[0, 1], [2, 3]], mesh_dim_names=('dp', 'tp')) DeviceMesh('cuda', [1, 3], mesh_dim_names=('dp',)) ``` In this PR, we extend implicit_replication to replicate DTensor in missing dimensions for foreach ops. If users don't want to fully tensor parallelize the model when using 2D, they have the option of using the `implicit_replication()` context manager for `optimizer.step()`. In this case, we would swap out the 1D DTensorSpec and replace it with 2D DTensorSpec. However, we don't want to turn this on by default yet, as we want the users to be aware that the tp dimension is replicated if a layer is not tp-ed. With implicit implication turning on, try replicate dtensor spec in missing dimension would work for most cases for foreach case except when the first DTensor in the list is one that also need to be replicated. This is currently a limitation, which I don't have a good solution yet. Currently, with this change, we can handle most of the cases except the case that the first DTensor's ndim is not the largest. ``` [2D_DTensor, 1D_DTensor...] ---> Implicit_replication() can handle this. [1D_DTensor, 2D_DTensor...] ---> Implicit_replication() can't handle this. ``` This change doesn't affect the existing default behavior, as `implicit_replication()` is not turned on by default. Pull Request resolved: https://github.com/pytorch/pytorch/pull/134551 Approved by: https://github.com/tianyu-l --- test/distributed/_tensor/test_dtensor.py | 38 ++- .../checkpoint/fsdp/test_fsdp_dsd.py | 228 ++++++++++-------- torch/distributed/_tensor/_dispatch.py | 66 ++++- 3 files changed, 219 insertions(+), 113 deletions(-) diff --git a/test/distributed/_tensor/test_dtensor.py b/test/distributed/_tensor/test_dtensor.py index 14f92115e7637b..e4570f7dc41bcb 100644 --- a/test/distributed/_tensor/test_dtensor.py +++ b/test/distributed/_tensor/test_dtensor.py @@ -15,6 +15,7 @@ init_device_mesh, ) from torch.distributed._tensor.debug import CommDebugMode +from torch.distributed._tensor.experimental import implicit_replication from torch.distributed._tensor.placement_types import ( DTensorSpec, Partial, @@ -778,8 +779,6 @@ def test_implicit_replication(self): local_tensor1 = torch.ones(4, 3) sharded_dtensor = DTensor.from_local(local_tensor1, mesh, [Shard(0)]) - from torch.distributed._tensor.experimental import implicit_replication - with implicit_replication(): # We put the scalar tensor as the left operand so we can test out # when a non-dtensor is a the arg in the args list. @@ -816,6 +815,41 @@ def add_scalar_tensor_with_dtensor(): (numel_1_tensor + sharded_dtensor).to_local(), numel_1_tensor + local_tensor ) + @with_comms + def test_implicit_replication_for_foreach_ops(self): + mesh = init_device_mesh( + self.device_type, (2, self.world_size // 2), mesh_dim_names=("dp", "tp") + ) + global_tensor1 = torch.randn(4, 2) + dtensor_2d = distribute_tensor(global_tensor1, mesh, [Shard(0), Shard(1)]) + self.assertEqual(dtensor_2d.full_tensor(), global_tensor1) + global_tensor2 = torch.randn(4) + dtensor_1d = distribute_tensor(global_tensor2, mesh["dp"], [Shard(0)]) + dtensor_list = [dtensor_2d, dtensor_1d] + + # Check without implicit replication, cross mesh error raises. + with self.assertRaisesRegex( + RuntimeError, "DTensor does not support cross-mesh operation yet!" + ): + torch._foreach_mul(dtensor_list, 2.0) + + # Check dtensor result matches tensor result. + with implicit_replication(): + torch._foreach_mul_(dtensor_list, 2.0) + self.assertEqual(dtensor_list[0].full_tensor(), global_tensor1 * 2.0) + self.assertEqual(dtensor_list[1].full_tensor(), global_tensor2 * 2.0) + + mesh_1d = DeviceMesh.from_group(mesh["tp"].get_group(), self.device_type) + dtensor_1d = distribute_tensor(global_tensor2, mesh_1d, [Shard(0)]) + dtensor_list = [dtensor_2d, dtensor_1d] + + # Check even with implicit replication, cross mesh error raises if different device mesh don't + # belong to the same root mesh. + with self.assertRaisesRegex( + RuntimeError, "DTensor does not support cross-mesh operation yet!" + ): + torch._foreach_mul_(dtensor_list, 2.0) + @with_comms def test_metadata_consistency_check(self): device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) diff --git a/test/distributed/checkpoint/fsdp/test_fsdp_dsd.py b/test/distributed/checkpoint/fsdp/test_fsdp_dsd.py index 887e46e2510bad..5ea55270936d7d 100644 --- a/test/distributed/checkpoint/fsdp/test_fsdp_dsd.py +++ b/test/distributed/checkpoint/fsdp/test_fsdp_dsd.py @@ -1,5 +1,6 @@ # Owner(s): ["oncall: distributed"] +import contextlib import copy import torch @@ -7,6 +8,7 @@ import torch.nn as nn from torch.distributed._composable.fsdp import fully_shard from torch.distributed._tensor import DTensor, init_device_mesh +from torch.distributed._tensor.experimental import implicit_replication from torch.distributed.checkpoint.state_dict import ( get_model_state_dict, get_optimizer_state_dict, @@ -439,8 +441,17 @@ def _get_base_model(mlp_dim: int = 2): self.assertEqual(base_osd, fsdp2_tp_full_osd) @skip_if_lt_x_gpu(4) - @with_temp_dir def test_save_with_fsdp2_tp_and_load_with_tp(self): + self.run_subtests( + {"allow_implicit_replication": [True, False]}, + self._test_save_with_fsdp2_tp_and_load_with_tp, + ) + + @skip_if_lt_x_gpu(4) + @with_temp_dir + def _test_save_with_fsdp2_tp_and_load_with_tp( + self, allow_implicit_replication: bool + ): """ Test that we can save a model with FSDP2 + TP on 2d mesh and load it with TP. """ @@ -449,6 +460,11 @@ def _get_base_model(mlp_dim: int = 2): base_model = nn.Sequential(MLP(mlp_dim), MLP(mlp_dim), MLP(mlp_dim)) return base_model + cm = ( + implicit_replication() + if allow_implicit_replication + else contextlib.nullcontext() + ) tp_parallelize_plan = { "0.in_proj": ColwiseParallel(), "0.out_proj": RowwiseParallel(), @@ -457,108 +473,124 @@ def _get_base_model(mlp_dim: int = 2): "2.in_proj": ColwiseParallel(), "2.out_proj": RowwiseParallel(), } + if allow_implicit_replication: + # intentionally pop the plans for some tp layers so that the model is not fully tensor parallelized + tp_parallelize_plan.pop("0.in_proj") + tp_parallelize_plan.pop("0.out_proj") - # init device mesh - dp_size = 2 - global_mesh_1d = init_device_mesh( - "cuda", (self.world_size,), mesh_dim_names=("tp",) - ) - global_mesh_2d = init_device_mesh( - "cuda", (dp_size, self.world_size // dp_size), mesh_dim_names=("dp", "tp") - ) - dp_mesh, tp_mesh = global_mesh_2d["dp"], global_mesh_2d["tp"] - - for save_full_state_dict in [True, False]: - # Save state dict with original model - base_model = _get_base_model().cuda() - base_optim = torch.optim.AdamW(base_model.parameters(), lr=0.1) - - # Save state dict with FSDP2 + TP model - fsdp2_tp_model = copy.deepcopy(base_model) - fsdp2_tp_model = parallelize_module( - fsdp2_tp_model, - device_mesh=tp_mesh, - parallelize_plan=tp_parallelize_plan, - ) - for module in fsdp2_tp_model: - fully_shard(module, mesh=dp_mesh) - fully_shard(fsdp2_tp_model, mesh=dp_mesh) - fsdp2_tp_optim = torch.optim.AdamW(fsdp2_tp_model.parameters(), lr=0.1) - - # one-step training to modify state dict - inp = torch.randn((2,), device=self.rank) - base_model(inp).sum().backward() - base_optim.step() - fsdp2_tp_model(inp).sum().backward() - fsdp2_tp_optim.step() - - # obtain the unsharded state dict - base_msd = get_model_state_dict( - base_model, - options=StateDictOptions(full_state_dict=True, cpu_offload=True), - ) - base_osd = get_optimizer_state_dict( - base_model, - base_optim, - options=StateDictOptions(full_state_dict=True, cpu_offload=True), - ) - - # obtain FSDP2 + TP state dict - fsdp2_tp_msd = get_model_state_dict( - fsdp2_tp_model, - options=StateDictOptions(full_state_dict=save_full_state_dict), - ) - fsdp2_tp_osd = get_optimizer_state_dict( - fsdp2_tp_model, - fsdp2_tp_optim, - options=StateDictOptions(full_state_dict=save_full_state_dict), - ) - - fsdp2_tp_state_dict = {"model": fsdp2_tp_msd, "optim": fsdp2_tp_osd} - dcp.save(fsdp2_tp_state_dict, checkpoint_id=self.temp_dir) - - fsdp2_tp_full_msd = get_model_state_dict( - fsdp2_tp_model, - options=StateDictOptions(full_state_dict=True, cpu_offload=True), - ) - fsdp2_tp_full_osd = get_optimizer_state_dict( - fsdp2_tp_model, - fsdp2_tp_optim, - options=StateDictOptions(full_state_dict=True, cpu_offload=True), - ) - - # Load state dict into model with TP applied - tp_model = _get_base_model() - tp_model = parallelize_module( - tp_model, - device_mesh=global_mesh_1d, - parallelize_plan=tp_parallelize_plan, - ) - tp_optim = torch.optim.AdamW(tp_model.parameters(), lr=0.1) - - tp_state_dict = { - "model": get_model_state_dict(tp_model), - "optim": get_optimizer_state_dict(tp_model, tp_optim), + with cm: + tp_parallelize_plan = { + "0.in_proj": ColwiseParallel(), + "0.out_proj": RowwiseParallel(), + "1.in_proj": ColwiseParallel(), + "1.out_proj": RowwiseParallel(), + "2.in_proj": ColwiseParallel(), + "2.out_proj": RowwiseParallel(), } - dcp.load(tp_state_dict, checkpoint_id=self.temp_dir) - tp_model.load_state_dict(tp_state_dict["model"]) - tp_optim.load_state_dict(tp_state_dict["optim"]) - tp_full_msd = get_model_state_dict( - tp_model, - options=StateDictOptions(full_state_dict=True, cpu_offload=True), + # init device mesh + dp_size = 2 + global_mesh_1d = init_device_mesh( + "cuda", (self.world_size,), mesh_dim_names=("tp",) ) - tp_full_osd = get_optimizer_state_dict( - tp_model, - tp_optim, - options=StateDictOptions(full_state_dict=True, cpu_offload=True), + global_mesh_2d = init_device_mesh( + "cuda", + (dp_size, self.world_size // dp_size), + mesh_dim_names=("dp", "tp"), ) - - # Compare full state dict to make sure they are the same. - self.assertEqual(base_msd, tp_full_msd) - self.assertEqual(base_osd, tp_full_osd) - self.assertEqual(fsdp2_tp_full_msd, tp_full_msd) - self.assertEqual(fsdp2_tp_full_osd, tp_full_osd) + dp_mesh, tp_mesh = global_mesh_2d["dp"], global_mesh_2d["tp"] + + for save_full_state_dict in [True, False]: + # Save state dict with original model + base_model = _get_base_model().cuda() + base_optim = torch.optim.AdamW(base_model.parameters(), lr=0.1) + + # Save state dict with FSDP2 + TP model + fsdp2_tp_model = copy.deepcopy(base_model) + fsdp2_tp_model = parallelize_module( + fsdp2_tp_model, + device_mesh=tp_mesh, + parallelize_plan=tp_parallelize_plan, + ) + for module in fsdp2_tp_model: + fully_shard(module, mesh=dp_mesh) + fully_shard(fsdp2_tp_model, mesh=dp_mesh) + fsdp2_tp_optim = torch.optim.AdamW(fsdp2_tp_model.parameters(), lr=0.1) + + # one-step training to modify state dict + inp = torch.randn((2,), device=self.rank) + base_model(inp).sum().backward() + base_optim.step() + fsdp2_tp_model(inp).sum().backward() + fsdp2_tp_optim.step() + + # obtain the unsharded state dict + base_msd = get_model_state_dict( + base_model, + options=StateDictOptions(full_state_dict=True, cpu_offload=True), + ) + base_osd = get_optimizer_state_dict( + base_model, + base_optim, + options=StateDictOptions(full_state_dict=True, cpu_offload=True), + ) + + # obtain FSDP2 + TP state dict + fsdp2_tp_msd = get_model_state_dict( + fsdp2_tp_model, + options=StateDictOptions(full_state_dict=save_full_state_dict), + ) + fsdp2_tp_osd = get_optimizer_state_dict( + fsdp2_tp_model, + fsdp2_tp_optim, + options=StateDictOptions(full_state_dict=save_full_state_dict), + ) + + fsdp2_tp_state_dict = {"model": fsdp2_tp_msd, "optim": fsdp2_tp_osd} + dcp.save(fsdp2_tp_state_dict, checkpoint_id=self.temp_dir) + + fsdp2_tp_full_msd = get_model_state_dict( + fsdp2_tp_model, + options=StateDictOptions(full_state_dict=True, cpu_offload=True), + ) + fsdp2_tp_full_osd = get_optimizer_state_dict( + fsdp2_tp_model, + fsdp2_tp_optim, + options=StateDictOptions(full_state_dict=True, cpu_offload=True), + ) + + # Load state dict into model with TP applied + tp_model = _get_base_model() + tp_model = parallelize_module( + tp_model, + device_mesh=global_mesh_1d, + parallelize_plan=tp_parallelize_plan, + ) + tp_optim = torch.optim.AdamW(tp_model.parameters(), lr=0.1) + + tp_state_dict = { + "model": get_model_state_dict(tp_model), + "optim": get_optimizer_state_dict(tp_model, tp_optim), + } + dcp.load(tp_state_dict, checkpoint_id=self.temp_dir) + tp_model.load_state_dict(tp_state_dict["model"]) + tp_optim.load_state_dict(tp_state_dict["optim"]) + + tp_full_msd = get_model_state_dict( + tp_model, + options=StateDictOptions(full_state_dict=True, cpu_offload=True), + ) + tp_full_osd = get_optimizer_state_dict( + tp_model, + tp_optim, + options=StateDictOptions(full_state_dict=True, cpu_offload=True), + ) + + # Compare full state dict to make sure they are the same. + self.assertEqual(base_msd, tp_full_msd) + self.assertEqual(base_osd, tp_full_osd) + self.assertEqual(fsdp2_tp_full_msd, tp_full_msd) + self.assertEqual(fsdp2_tp_full_osd, tp_full_osd) if __name__ == "__main__": diff --git a/torch/distributed/_tensor/_dispatch.py b/torch/distributed/_tensor/_dispatch.py index 129e556e747298..c811fc457d6f76 100644 --- a/torch/distributed/_tensor/_dispatch.py +++ b/torch/distributed/_tensor/_dispatch.py @@ -309,16 +309,20 @@ def unwrap_to_op_info( for arg in args_list: if isinstance(arg, dtensor.DTensor): - args_schema.append(arg._spec) local_args.append(arg._local_tensor) - if mesh is not None: - if mesh != arg.device_mesh: - raise NotImplementedError( - f"{op_call}: DTensor does not support cross-mesh operation yet!" - f"Got meshes: {mesh} {arg.device_mesh}" - ) + if mesh is not None and mesh != arg.device_mesh: + # TODO: try replicate dtensor spec in missing dimension would work + # for most cases for foreach case except when the first DTensor in + # the list is one that also need to be replicated. We need to revisit + # how we want to handle this corner case. For now, this case would hit + # the cross mesh error even if implicit replication is turned on. + spec = self._try_replicate_dtensor_spec_in_missing_dim( + op_call, arg, mesh + ) + args_schema.append(spec) else: mesh = arg.device_mesh + args_schema.append(arg._spec) elif isinstance(arg, torch.Tensor): mesh = mesh or try_find_mesh_from_args(op_call, args_list) args_schema.append( @@ -331,15 +335,15 @@ def unwrap_to_op_info( for k, v in kwargs.items(): if isinstance(v, dtensor.DTensor): - kwargs_schema[k] = v._spec local_kwargs[k] = v._local_tensor - if mesh is not None: - if mesh != v.device_mesh: - raise NotImplementedError( - f"{op_call}: DTensor does not support cross-mesh operation yet!" - ) + if mesh is not None and mesh != v.device_mesh: + spec = self._try_replicate_dtensor_spec_in_missing_dim( + op_call, v, mesh + ) + kwargs_schema[k] = spec else: mesh = v.device_mesh + kwargs_schema[k] = v._spec elif isinstance(v, torch.Tensor): mesh = mesh or try_find_mesh_from_args(op_call, args_list) kwargs_schema[k] = self._try_replicate_spec_for_scalar_tensor( @@ -426,3 +430,39 @@ def _try_replicate_spec_for_scalar_tensor( " torch.Tensor to DTensor before calling distributed operators!" ) return replication_spec + + def _try_replicate_dtensor_spec_in_missing_dim( + self, + op_call: torch._ops.OpOverload, + dtensor_arg: "dtensor.DTensor", + mesh: "DeviceMesh", + ) -> DTensorSpec: + # util function to produce a new spec for a DTensor arg/kwarg + # that puts Replicate() placement in the missing dimension for foreach ops + from torch.distributed.device_mesh import _mesh_resources + + cur_mesh = dtensor_arg.device_mesh + root_mesh = _mesh_resources.get_root_mesh(cur_mesh) + if ( + self._allow_implicit_replication + and "foreach" in op_call.__name__ + and root_mesh == mesh + ): + placements = [Replicate() for _ in range(root_mesh.ndim)] + cur_mesh_root_idx = _mesh_resources.get_root_mesh_dim(cur_mesh) + placements[cur_mesh_root_idx] = dtensor_arg.placements[0] # type: ignore[call-overload] + replicate_spec = DTensorSpec( + root_mesh, + tuple(placements), + tensor_meta=TensorMeta( + shape=dtensor_arg.shape, + stride=dtensor_arg.stride(), + dtype=dtensor_arg.dtype, + ), + ) + else: + raise NotImplementedError( + f"{op_call}: DTensor does not support cross-mesh operation yet! " + f"Got meshes: {mesh} {cur_mesh}" + ) + return replicate_spec From 434c177df50ffb0183e18423b2f497199f7be63a Mon Sep 17 00:00:00 2001 From: Animesh Jain Date: Wed, 28 Aug 2024 21:59:02 -0700 Subject: [PATCH 0217/1018] [dynamo][exceptions] Use exception subclass whenever possible (#134610) Pull Request resolved: https://github.com/pytorch/pytorch/pull/134610 Approved by: https://github.com/drisspg, https://github.com/jansel --- test/dynamo/test_exceptions.py | 23 +++++++++++++++++++++++ torch/_dynamo/symbolic_convert.py | 7 ++++--- 2 files changed, 27 insertions(+), 3 deletions(-) diff --git a/test/dynamo/test_exceptions.py b/test/dynamo/test_exceptions.py index d9b364443eb132..33c6e3d0f6121b 100644 --- a/test/dynamo/test_exceptions.py +++ b/test/dynamo/test_exceptions.py @@ -267,6 +267,29 @@ def forward(self, x): x = torch.ones(4) self.assertEqual(mod(x), opt_mod(x)) + def test_attribute_error_from_getattr(self): + class Mock: + def __init__(self): + self.a = 5 + + def __getattr__(self, name): + if name != "a": + raise AttributeError("missing") + return self.__dict__["a"] + + mock = Mock() + + def fn(x): + if hasattr(mock, "b"): + return torch.cos(x) + return torch.sin(x) + + opt_fn = torch.compile(fn, backend="eager", fullgraph=True) + x = torch.randn(4) + ref = fn(x) + res = opt_fn(x) + self.assertEqual(ref, res) + def test_stop_iteration(self): def zip_longest(*iterables, fillvalue=None): # Get the iterators for each iterable diff --git a/torch/_dynamo/symbolic_convert.py b/torch/_dynamo/symbolic_convert.py index 0b7d68be08b9e5..7e8040b09a03e4 100644 --- a/torch/_dynamo/symbolic_convert.py +++ b/torch/_dynamo/symbolic_convert.py @@ -1379,9 +1379,10 @@ def RAISE_VARARGS(self, inst): # 2) when user raises exception instance if isinstance(val, variables.ExceptionVariable): - if val.exc_type is StopIteration: - # StopIteration is used to find the end of iteration while tracing __next__ - raise exc.ObservedUserStopIteration(f"raised exception {val}") + if observed_exception_type := exc.observed_exception_map.get( + val.exc_type + ): + raise observed_exception_type(f"raised exception {val}") raise exc.ObservedException(f"raised exception {val}") unimplemented(f"raise {exc}") else: From 9bab9c94937b30c50a45d961d64c85210b66cde7 Mon Sep 17 00:00:00 2001 From: Animesh Jain Date: Wed, 28 Aug 2024 21:59:03 -0700 Subject: [PATCH 0218/1018] [dynamo][dicts] Support hasattr on dicts (#134590) Fixes - https://github.com/pytorch/pytorch/issues/134577 Pull Request resolved: https://github.com/pytorch/pytorch/pull/134590 Approved by: https://github.com/Skylion007 ghstack dependencies: #134610 --- test/dynamo/test_functions.py | 16 ++++++++++++++++ torch/_dynamo/variables/dicts.py | 9 +++++++++ 2 files changed, 25 insertions(+) diff --git a/test/dynamo/test_functions.py b/test/dynamo/test_functions.py index 67ab47dfac9093..f5cbfb8760c307 100644 --- a/test/dynamo/test_functions.py +++ b/test/dynamo/test_functions.py @@ -1683,6 +1683,22 @@ def test_dict_sorted(x): tmp = {1: "D", 10: "B", 3: "E", 0: "F"} return x + 1, sorted(tmp), sorted(tmp, reverse=True) + def test_dict_hasattr(self): + def fn(x): + if hasattr(x, "to"): + return x.to("cpu") + if hasattr(x, "items"): + return torch.cos(x["a"]) + return x + + opt_fn = torch.compile(fn, backend="eager", fullgraph=True) + + x = dict(a=torch.randn(3)) + self.assertEqual(fn(x), opt_fn(x)) + + x = torch.randn(4) + self.assertEqual(fn(x), opt_fn(x)) + @make_test def test_list_clear(a, b): tmp = [a + 1, a + 2] diff --git a/torch/_dynamo/variables/dicts.py b/torch/_dynamo/variables/dicts.py index 2d4c015bf3ac1b..f9b74bb2f202f4 100644 --- a/torch/_dynamo/variables/dicts.py +++ b/torch/_dynamo/variables/dicts.py @@ -355,6 +355,15 @@ def call_method( def unpack_var_sequence(self, tx): return [x.vt for x in self.items.keys()] + def call_hasattr(self, tx, name): + # dict not allow setting arbitrary attributes. To check for hasattr, we can just check the __dict__ of the dict. + # OrderedDict though requires side effects tracking because it supports arbitrary setattr. + if self.user_cls is dict: + if name in self.user_cls.__dict__: + return ConstantVariable.create(True) + return ConstantVariable.create(False) + unimplemented(f"hasattr on {self.user_cls} is not supported") + class DefaultDictVariable(ConstDictVariable): def __init__(self, items, user_cls, default_factory=None, **kwargs) -> None: From 2de164bd58e37fb2f9eccd4dc468f70610cef765 Mon Sep 17 00:00:00 2001 From: Animesh Jain Date: Wed, 28 Aug 2024 21:59:03 -0700 Subject: [PATCH 0219/1018] [raland][dynamo][exceptions] Support raise from None (#134621) The PR was reverted because this PR traced more code and surfaced a latent bug. Resubmitting w/o any changes. Pull Request resolved: https://github.com/pytorch/pytorch/pull/134621 Approved by: https://github.com/jansel ghstack dependencies: #134610, #134590 --- test/dynamo/test_exceptions.py | 33 +++++++++++++++ torch/_dynamo/symbolic_convert.py | 69 ++++++++++++++++++------------- 2 files changed, 73 insertions(+), 29 deletions(-) diff --git a/test/dynamo/test_exceptions.py b/test/dynamo/test_exceptions.py index 33c6e3d0f6121b..5efa344c7ea2ed 100644 --- a/test/dynamo/test_exceptions.py +++ b/test/dynamo/test_exceptions.py @@ -353,6 +353,39 @@ def fn(x): res = opt_fn(x) self.assertEqual(ref, res) + def test_raise_from_None(self): + # Inspired from os.environ + class MyMapping: + def __init__(self, d): + self._d = d + + def __getitem__(self, key): + try: + value = self._d[key] + except KeyError: + raise KeyError(key) from None + return value + + d = MyMapping({"a": 10, "b": 20}) + + def mapping_get(obj, key, value=None): + try: + return obj.__getitem__(key) + except KeyError: + return value + + def fn(x, d, key): + x = torch.sin(x + 1) + return x, mapping_get(d, key) + + opt_fn = torch.compile(fn, backend="eager", fullgraph=True) + + x = torch.rand(2, 3) + ref = fn(x, d, "m") + res = opt_fn(x, d, "m") + self.assertEqual(ref[0], res[0]) + self.assertEqual(ref[1], res[1]) + if __name__ == "__main__": from torch._dynamo.test_case import run_tests diff --git a/torch/_dynamo/symbolic_convert.py b/torch/_dynamo/symbolic_convert.py index 7e8040b09a03e4..e605e57b6aa49c 100644 --- a/torch/_dynamo/symbolic_convert.py +++ b/torch/_dynamo/symbolic_convert.py @@ -1359,35 +1359,47 @@ def FOR_ITER(self, inst): self.push(ConstantVariable.create(None)) self.jump(inst) + def _raise_exception_variable(self, inst): + val = self.pop() + # User can raise exception in 2 ways + # 1) raise exception type - raise NotImplementedError + # 2) raise execption instance - raise NotImplemetedError("foo") + + # 1) when user raises exception type + if isinstance(val, variables.BuiltinVariable): + # Create the instance of the exception type + # https://github.com/python/cpython/blob/3.11/Python/ceval.c#L6547-L6549 + val = val.call_function(self, [], {}) # type: ignore[arg-type] + + # Save the exception in a global data structure + self.exn_vt_stack.append(val) + + # 2) when user raises exception instance + if isinstance(val, variables.ExceptionVariable): + if observed_exception_type := exc.observed_exception_map.get(val.exc_type): + raise observed_exception_type(f"raised exception {val}") + raise exc.ObservedException(f"raised exception {val}") + unimplemented(f"raise {exc}") + def RAISE_VARARGS(self, inst): if inst.arg == 0: unimplemented("re-raise") elif inst.arg == 1: - val = self.pop() - # User can raise exception in 2 ways - # 1) raise exception type - raise NotImplementedError - # 2) raise execption instance - raise NotImplemetedError("foo") - - # 1) when user raises exception type - if isinstance(val, variables.BuiltinVariable): - # Create the instance of the exception type - # https://github.com/python/cpython/blob/3.11/Python/ceval.c#L6547-L6549 - val = val.call_function(self, [], {}) # type: ignore[arg-type] - - # Save the exception in a global data structure - self.exn_vt_stack.append(val) - - # 2) when user raises exception instance - if isinstance(val, variables.ExceptionVariable): - if observed_exception_type := exc.observed_exception_map.get( - val.exc_type - ): - raise observed_exception_type(f"raised exception {val}") - raise exc.ObservedException(f"raised exception {val}") - unimplemented(f"raise {exc}") + self._raise_exception_variable(inst) else: + # Support raise .. from None ... Dynamo does not track __cause__ and other attributes of exception. So we + # ignore `from None` part. + from_vt = self.pop() + if isinstance(from_vt, ConstantVariable) and from_vt.value is None: + self._raise_exception_variable(inst) unimplemented("raise ... from ...") + def RERAISE(self, inst): + if sys.version_info >= (3, 11): + # RERAISE is currently supported in a narrow case of `raise ... from None` + self._raise_exception_variable(inst) + unimplemented("RERAISE") + def exception_handler(self, raised_exception): if sys.version_info >= (3, 11): exn_tab_entry = self.current_instruction.exn_tab_entry @@ -1401,10 +1413,6 @@ def exception_handler(self, raised_exception): # 2) if 'lasti' is true, then push the offset that the exception was raised at if exn_tab_entry.lasti: - # This is untested. Any test that tests this end-to-end - # requires supporting more bytecodes. Therefore graph - # breaking for now. - unimplemented("lasti=True while exception handling") self.push( variables.ConstantVariable(self.current_instruction.offset) ) @@ -1436,9 +1444,12 @@ def exception_handler(self, raised_exception): # https://github.com/python/cpython/blob/3.10/Python/ceval.c#L1456 self.popn(3) if len(self.block_stack) == 0: - unimplemented( - "exception is raised when block stack " "is empty" - ) + # No handler found in this frame. Bubble the exception to the parent + # instruction translater. + self.stack.clear() + if type(self) is InstructionTranslator: + raise Unsupported("Observed exception") + raise raised_exception block_stack_entry = self.block_stack.pop() if block_stack_entry.inst.opname != "SETUP_FINALLY": From b78a9e2086fcbb7deda861c133fcd97108dae7b5 Mon Sep 17 00:00:00 2001 From: Animesh Jain Date: Wed, 28 Aug 2024 21:59:04 -0700 Subject: [PATCH 0220/1018] [dynamo] Graph break on FSDP flat_param inconsistent tensor and grad dtype (#134614) Pull Request resolved: https://github.com/pytorch/pytorch/pull/134614 Approved by: https://github.com/awgu, https://github.com/yf225 ghstack dependencies: #134610, #134590, #134621 --- torch/_dynamo/variables/builder.py | 21 ++++++++++++++++++++- 1 file changed, 20 insertions(+), 1 deletion(-) diff --git a/torch/_dynamo/variables/builder.py b/torch/_dynamo/variables/builder.py index 9cd48c9c04c85b..b4f35dbbd404b8 100644 --- a/torch/_dynamo/variables/builder.py +++ b/torch/_dynamo/variables/builder.py @@ -15,6 +15,7 @@ import re import sys import types +import warnings import weakref from typing import Any, List, MutableMapping, NamedTuple, Optional, TYPE_CHECKING, Union @@ -25,7 +26,7 @@ from torch._ops import HigherOrderOperator from torch._streambase import _EventBase, _StreamBase from torch._subclasses.fake_tensor import FakeTensor, is_fake, maybe_get_fake_mode -from torch._subclasses.meta_utils import is_sparse_any +from torch._subclasses.meta_utils import is_sparse_any, safe_grad from torch._utils_internal import justknobs_check from torch.fx.experimental._backward_state import BackwardState from torch.fx.experimental.symbolic_shapes import ( @@ -220,6 +221,12 @@ DimList = List +def safe_has_grad(t): + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", "The .grad attribute of a Tensor") + return hasattr(t, "grad") + + class _missing: pass @@ -1512,6 +1519,18 @@ def wrap_tensor(self, value: torch.Tensor): # SPARSE_TENSOR_GUARDS for guards to work propertly. unimplemented("torch.compile does not support sparse Tensors") + if ( + safe_has_grad(value) + and safe_grad(value) is not None + and value.dtype != safe_grad(value).dtype + ): + unimplemented( + "Inconsistent dtype between tensor and its gradient. " + "This can happen in FSDP and crashes meta tensor creation. " + "This is potentially a workaround. Fixing it correctly " + "requires some design around FSDP + torch.compile." + ) + tensor_variable = wrap_fx_proxy( tx=self.tx, proxy=tensor_proxy, From cdebfde97488c26654e0a7adc4b694174f1701b4 Mon Sep 17 00:00:00 2001 From: Xuehai Pan Date: Thu, 29 Aug 2024 18:11:27 +0800 Subject: [PATCH 0221/1018] [dynamo] simplify implementation for `os.fspath` (#133801) Pull Request resolved: https://github.com/pytorch/pytorch/pull/133801 Approved by: https://github.com/anijain2305 --- torch/_dynamo/polyfills/__init__.py | 12 ---------- torch/_dynamo/polyfills/builtins.py | 6 +++++ torch/_dynamo/polyfills/functools.py | 6 +++++ torch/_dynamo/polyfills/itertools.py | 6 +++++ torch/_dynamo/polyfills/loader.py | 8 ++++++- torch/_dynamo/polyfills/os.py | 36 ++++++++++++++++++++++++++++ torch/_dynamo/polyfills/sys.py | 6 +++++ torch/_dynamo/variables/misc.py | 14 +---------- 8 files changed, 68 insertions(+), 26 deletions(-) create mode 100644 torch/_dynamo/polyfills/builtins.py create mode 100644 torch/_dynamo/polyfills/functools.py create mode 100644 torch/_dynamo/polyfills/itertools.py create mode 100644 torch/_dynamo/polyfills/os.py create mode 100644 torch/_dynamo/polyfills/sys.py diff --git a/torch/_dynamo/polyfills/__init__.py b/torch/_dynamo/polyfills/__init__.py index 07ded8c4bd562f..acc4f613c8e426 100644 --- a/torch/_dynamo/polyfills/__init__.py +++ b/torch/_dynamo/polyfills/__init__.py @@ -168,15 +168,3 @@ def instantiate_user_defined_class_object(cls, /, *args, **kwargs): if isinstance(obj, cls): obj.__init__(*args, **kwargs) return obj - - -def fspath(path): - # Python equivalent of os.fspath - if isinstance(path, (str, bytes)): - return path - elif hasattr(path, "__fspath__"): - return path.__fspath__() - else: - raise TypeError( - f"expected str, bytes or os.PathLike object, not {type(path).__name__}" - ) diff --git a/torch/_dynamo/polyfills/builtins.py b/torch/_dynamo/polyfills/builtins.py new file mode 100644 index 00000000000000..04757a19afea46 --- /dev/null +++ b/torch/_dynamo/polyfills/builtins.py @@ -0,0 +1,6 @@ +""" +Python polyfills for builtins +""" + + +__all__ = [] # type: ignore[var-annotated] diff --git a/torch/_dynamo/polyfills/functools.py b/torch/_dynamo/polyfills/functools.py new file mode 100644 index 00000000000000..98fd9597773b35 --- /dev/null +++ b/torch/_dynamo/polyfills/functools.py @@ -0,0 +1,6 @@ +""" +Python polyfills for functools +""" + + +__all__ = [] # type: ignore[var-annotated] diff --git a/torch/_dynamo/polyfills/itertools.py b/torch/_dynamo/polyfills/itertools.py new file mode 100644 index 00000000000000..dae625497e42f4 --- /dev/null +++ b/torch/_dynamo/polyfills/itertools.py @@ -0,0 +1,6 @@ +""" +Python polyfills for itertools +""" + + +__all__ = [] # type: ignore[var-annotated] diff --git a/torch/_dynamo/polyfills/loader.py b/torch/_dynamo/polyfills/loader.py index 04e30572d773cd..255486b6e016dd 100644 --- a/torch/_dynamo/polyfills/loader.py +++ b/torch/_dynamo/polyfills/loader.py @@ -11,7 +11,13 @@ from types import ModuleType -POLYFILLED_MODULE_NAMES: Tuple[str, ...] = () +POLYFILLED_MODULE_NAMES: Tuple[str, ...] = ( + "builtins", + "functools", + "itertools", + "os", + "sys", +) POLYFILLED_MODULES: Tuple["ModuleType", ...] = tuple( importlib.import_module(f".{submodule}", package=polyfills.__name__) for submodule in POLYFILLED_MODULE_NAMES diff --git a/torch/_dynamo/polyfills/os.py b/torch/_dynamo/polyfills/os.py new file mode 100644 index 00000000000000..013cc81ac30e6f --- /dev/null +++ b/torch/_dynamo/polyfills/os.py @@ -0,0 +1,36 @@ +""" +Python polyfills for os +""" + +from __future__ import annotations + +import os +from typing import AnyStr + +from ..decorators import substitute_in_graph + + +__all__ = ["fspath"] + + +# Copied from os.py in the standard library +@substitute_in_graph(os.fspath) +def fspath(path: AnyStr | os.PathLike[AnyStr]) -> AnyStr: + if isinstance(path, (str, bytes)): + return path + + path_type = type(path) + try: + path_repr = path_type.__fspath__(path) # type: ignore[arg-type] + except AttributeError: + if hasattr(path_type, "__fspath__"): + raise + raise TypeError( + f"expected str, bytes or os.PathLike object, not {path_type.__name__}", + ) from None + if isinstance(path_repr, (str, bytes)): + return path_repr # type: ignore[return-value] + raise TypeError( + f"expected {path_type.__name__}.__fspath__() to return str or bytes, " + f"not {type(path_repr).__name__}", + ) diff --git a/torch/_dynamo/polyfills/sys.py b/torch/_dynamo/polyfills/sys.py new file mode 100644 index 00000000000000..e9479eda8a0862 --- /dev/null +++ b/torch/_dynamo/polyfills/sys.py @@ -0,0 +1,6 @@ +""" +Python polyfills for sys +""" + + +__all__ = [] # type: ignore[var-annotated] diff --git a/torch/_dynamo/variables/misc.py b/torch/_dynamo/variables/misc.py index 7a81bd37094a53..2979bd58e312fe 100644 --- a/torch/_dynamo/variables/misc.py +++ b/torch/_dynamo/variables/misc.py @@ -4,7 +4,6 @@ import functools import inspect import itertools -import os import random import re import sys @@ -15,7 +14,7 @@ import torch._numpy as tnp import torch.utils._pytree as pytree -from .. import config, polyfills, variables +from .. import config, variables from ..bytecode_transformation import create_call_function, create_instruction from ..create_parameter_op import do_not_convert_to_tracable_parameter from ..exc import unimplemented @@ -30,7 +29,6 @@ ) from ..utils import ( check_unspec_or_constant_args, - hashable, identity, is_tensor_base_attr_getter, proxy_args_kwargs, @@ -49,10 +47,6 @@ if TYPE_CHECKING: from torch._dynamo.symbolic_convert import InstructionTranslator -POLYFILL_SUPPORTED_PYTHON_MODULE_METHODS = { - os.fspath: polyfills.fspath, -} - class NO_SUCH_SUBOBJ: pass @@ -1167,12 +1161,6 @@ def var_getattr(self, tx: "InstructionTranslator", name): else: attr_value = self.value.__dict__[name] - if hashable(attr_value): - if polyfill_fn := POLYFILL_SUPPORTED_PYTHON_MODULE_METHODS.get( - attr_value, None - ): - return variables.UserFunctionVariable(polyfill_fn) - if self.source: new_source = AttrSource(self.source, name) return VariableBuilder(tx, new_source)(attr_value) From 77c10b920932b9b84c4c3839393bbf61adb2ef02 Mon Sep 17 00:00:00 2001 From: Xuehai Pan Date: Thu, 29 Aug 2024 18:11:28 +0800 Subject: [PATCH 0222/1018] [dynamo][itertools] support `itertools.tee` (#133771) Pull Request resolved: https://github.com/pytorch/pytorch/pull/133771 Approved by: https://github.com/jansel ghstack dependencies: #133801 --- test/dynamo/test_misc.py | 16 ++++++++++++++ torch/_dynamo/polyfills/itertools.py | 32 +++++++++++++++++++++++++++- 2 files changed, 47 insertions(+), 1 deletion(-) diff --git a/test/dynamo/test_misc.py b/test/dynamo/test_misc.py index 74dcc2507f8888..da258c94308837 100644 --- a/test/dynamo/test_misc.py +++ b/test/dynamo/test_misc.py @@ -10061,6 +10061,22 @@ def fn(l): self.assertEqual(eager, compiled) self.assertEqual(len(counters["graph_break"]), 0) + def test_itertools_tee(self): + counters.clear() + + def fn(l): + a, b = itertools.tee(l) + return list(a), list(b) + + l = [1, 2, 2, 3, 4, 4, 4, 1, 2] + eager = fn(l) + + compiled_fn = torch._dynamo.optimize(backend="eager", nopython=True)(fn) + compiled = compiled_fn(l) + + self.assertEqual(eager, compiled) + self.assertEqual(len(counters["graph_break"]), 0) + def test_list_iterator_contains(self): def fn(x): it = iter(["my_weight", "not_my_weight"]) diff --git a/torch/_dynamo/polyfills/itertools.py b/torch/_dynamo/polyfills/itertools.py index dae625497e42f4..090df0f84b5281 100644 --- a/torch/_dynamo/polyfills/itertools.py +++ b/torch/_dynamo/polyfills/itertools.py @@ -2,5 +2,35 @@ Python polyfills for itertools """ +from __future__ import annotations -__all__ = [] # type: ignore[var-annotated] +import itertools +from typing import Iterable, Iterator, TypeVar + +from ..decorators import substitute_in_graph + + +__all__ = ["tee"] + + +_T = TypeVar("_T") + + +# Reference: https://docs.python.org/3/library/itertools.html#itertools.tee +@substitute_in_graph(itertools.tee) +def tee(iterable: Iterable[_T], n: int = 2, /) -> tuple[Iterator[_T], ...]: + iterator = iter(iterable) + shared_link = [None, None] + + def _tee(link) -> Iterator[_T]: # type: ignore[no-untyped-def] + try: + while True: + if link[1] is None: + link[0] = next(iterator) + link[1] = [None, None] + value, link = link + yield value + except StopIteration: + return + + return tuple(_tee(shared_link) for _ in range(n)) From 877e99aa4178503b82aaef3c8bf6341c2e2de8e2 Mon Sep 17 00:00:00 2001 From: rzou Date: Wed, 28 Aug 2024 11:21:49 -0700 Subject: [PATCH 0223/1018] Never CSE aten.empty in the partitioner (#134703) aten.empty is almost always fusible into its consumer, so we never CSE it. This fixes a bug that looks like the following: ```py @torch.library.custom_op("_reinplacing::sin_cos", mutates_args={"out_sin", "out_cos"}) def sin_cos(x: torch.Tensor, out_sin: torch.Tensor, out_cos: torch.Tensor) -> None: out_sin.copy_(x.sin()) out_cos.copy_(x.cos()) @torch.compile def f(x): out0 = torch.empty_like(x) out1 = torch.empty_like(x) sin_cos(x, out0, out1) return x.clone(), out0, out1 x = torch.randn(3, requires_grad=True) f(x) ``` - cse would de-duplicate the empty nodes - reinplacing would add an additional clone (because it can't write to both tensors at the same time) - the clone lowers into a new buffer + a copy_ kernel - the copy_ kernel is unnecessary because "empty" is special - all reinplacing needed was an additional buffer, it doesn't matter what the values are. We could attempt to fix this on the reinplacing side but this seemed better as a partitioner heuristic and the reinplacing fix is a bit more tricky (we'd need to identify that the op never reads from the empty node). Test Plan: - new test (the old number was 27, the new number is 21, so this PR helped). Pull Request resolved: https://github.com/pytorch/pytorch/pull/134703 Approved by: https://github.com/yf225 ghstack dependencies: #134466, #134490, #134491 --- test/inductor/test_perf.py | 20 ++++++++++++++++++++ torch/_functorch/compile_utils.py | 4 ++++ 2 files changed, 24 insertions(+) diff --git a/test/inductor/test_perf.py b/test/inductor/test_perf.py index 14694e54161d10..302e246de62b19 100644 --- a/test/inductor/test_perf.py +++ b/test/inductor/test_perf.py @@ -905,6 +905,26 @@ def f(x): x = T(3, grad=True) self.assertExpectedInline(count_numel_train(f, x), """9""") + @requires_cuda + def test_inplace_custom_op_training_two_mutated_inputs(self): + @torch.library.custom_op( + "_reinplacing::sin_cos", mutates_args={"out_sin", "out_cos"} + ) + def sin_cos( + x: torch.Tensor, out_sin: torch.Tensor, out_cos: torch.Tensor + ) -> None: + out_sin.copy_(x.sin()) + out_cos.copy_(x.cos()) + + def f(x): + out0 = torch.empty_like(x) + out1 = torch.empty_like(x) + sin_cos(x, out0, out1) + return x.clone(), out0, out1 + + x = T(3, grad=True) + self.assertExpectedInline(count_numel(f, x), """21""") + @requires_cuda def test_inplace_custom_op_training(self): @torch.library.custom_op("_reinplacing::sin", mutates_args={"result"}) diff --git a/torch/_functorch/compile_utils.py b/torch/_functorch/compile_utils.py index 77aeeeb2b7f1f1..ac7e0d5fdffd34 100644 --- a/torch/_functorch/compile_utils.py +++ b/torch/_functorch/compile_utils.py @@ -58,6 +58,10 @@ def fx_graph_cse(fx_g: torch.fx.graph.Graph): or n.op == "output" or n.op == "get_attr" or get_aten_target(n) in rand_ops + # aten.empty is non-deterministic, so don't CSE it. + # Also, aten.empty is almost always fusible into its consumer, + # so it's not worth CSEing. + or get_aten_target(n) is aten.empty ): new_node = new_graph.node_copy(n, lambda x: env[x]) env[n] = new_node From 44c041e228f571c90861a18ac88939d0e41540de Mon Sep 17 00:00:00 2001 From: Avik Chaudhuri Date: Thu, 29 Aug 2024 14:53:30 +0000 Subject: [PATCH 0224/1018] preserve aten::to device in export training (#134622) Summary: With training IR, we cannot rely on trapping `to()` in `FunctionalTensor` because the regular decomposition kicks it first, and that can cause it to be optimized away. So instead we preserve it until we functionalize, and then replace it explicitly with `_to_copy()`. Test Plan: expected test failures go away Differential Revision: D61883878 Pull Request resolved: https://github.com/pytorch/pytorch/pull/134622 Approved by: https://github.com/zhxchen17, https://github.com/tugsbayasgalan --- test/export/test_export.py | 10 ------- torch/_subclasses/functional_tensor.py | 17 ++++++++++++ torch/export/exported_program.py | 38 +++++++++++++++++++++----- 3 files changed, 48 insertions(+), 17 deletions(-) diff --git a/test/export/test_export.py b/test/export/test_export.py index ec6bbdc93d6c44..a109e324d3ab2b 100644 --- a/test/export/test_export.py +++ b/test/export/test_export.py @@ -3442,8 +3442,6 @@ def _patch_config(kwargs): ): _ = export(mod, inp, strict=True) - @testing.expectedFailureTrainingIRToRunDecomp # T193700396 - @testing.expectedFailureTrainingIRToRunDecompNonStrict def test_device_to_static(self): class Module(torch.nn.Module): def forward(self, x): @@ -3458,8 +3456,6 @@ def forward(self, x): for op in ops: self.assertIn(op, (torch.ops.aten._to_copy.default,)) - @testing.expectedFailureTrainingIRToRunDecomp # T193700396 - @testing.expectedFailureTrainingIRToRunDecompNonStrict def test_device_to_dynamic(self): class Module(torch.nn.Module): def forward(self, x): @@ -3478,8 +3474,6 @@ def forward(self, x): for op in ops: self.assertIn(op, (torch.ops.aten._to_copy.default,)) - @testing.expectedFailureTrainingIRToRunDecomp # T193700396 - @testing.expectedFailureTrainingIRToRunDecompNonStrict def test_device_to_mutation(self): class Module(torch.nn.Module): def forward(self, x): @@ -3492,8 +3486,6 @@ def forward(self, x): ): export(Module(), (torch.tensor(1, device="cpu"),)) - @testing.expectedFailureTrainingIRToRunDecomp # T193700396 - @testing.expectedFailureTrainingIRToRunDecompNonStrict def test_float_conversion(self): class Module(torch.nn.Module): def forward(self, x): @@ -3508,8 +3500,6 @@ def forward(self, x): for op in ops: self.assertIn(op, (torch.ops.aten._to_copy.default,)) - @testing.expectedFailureTrainingIRToRunDecomp # T193700396 - @testing.expectedFailureTrainingIRToRunDecompNonStrict def test_device_to_mutation_float(self): class Module(torch.nn.Module): def forward(self, x): diff --git a/torch/_subclasses/functional_tensor.py b/torch/_subclasses/functional_tensor.py index dd7cba26b1cb2e..1fb64fbcf7bad9 100644 --- a/torch/_subclasses/functional_tensor.py +++ b/torch/_subclasses/functional_tensor.py @@ -342,6 +342,23 @@ def __torch_dispatch__(self, func, types, args=(), kwargs=None): if kwargs is None: kwargs = {} + if self.export: + # We need to make sure that we don't decompose to() as usual in export mode, + # because it can get optimized away. Instead we always replace it with _to_copy(). + if func == torch.ops.aten.to.dtype_layout: + kwargs.pop("copy", None) + return self.__torch_dispatch__( + torch.ops.aten._to_copy.default, types, args, kwargs + ) + if func == torch.ops.aten.to.dtype: + schema = tuple(arg.name for arg in func._schema.arguments) + for arg, name in zip(args[1:], schema[1:]): + kwargs[name] = arg + kwargs.pop("copy", None) + return self.__torch_dispatch__( + torch.ops.aten._to_copy.default, types, args[:1], kwargs + ) + unrecognized_types = [ t for t in types diff --git a/torch/export/exported_program.py b/torch/export/exported_program.py index d90ff9580f2dec..447c12e864d5fb 100644 --- a/torch/export/exported_program.py +++ b/torch/export/exported_program.py @@ -157,7 +157,7 @@ def _register_cia_to_meta(*args, **kwargs): @contextmanager -def _override_composite_implicit_decomp(ops_to_preserve, decomp_table): +def _override_composite_implicit_decomp(ops_to_preserve, decomp_table, safe=True): # This function overrides CompositeImplicitAutograd decomp for # functional composite ops that user specified. Ideally we want to not-decompose # ALL composite ops but today's C++ functinalization relies on @@ -165,6 +165,14 @@ def _override_composite_implicit_decomp(ops_to_preserve, decomp_table): # Hence we can only do it for functional ops. One caveat is that # there are some composite ops that lie about their schema (claimed to be # functional but not really aka dropout), for these cases, we just decompose. + + # When safe=False, we will assume that ops_to_preserve can be mutating/aliasing + # and their usual decompositions need to be shadowed rather than overridden. + # Thus we will avoid asserting that they are valid to preserve, and will not + # replace their CompositeImplicitAutograd kernels with NotImplemented. + # The only current users of this mode are variants of aten::to that we will + # replace with aten::_to_copy in FunctionalTensorMode.__torch_dispatch__. + saved_tables = {} patched_ops = set() removed_decomps = {} @@ -205,8 +213,9 @@ def assert_valid_to_preserve(op_overload): return True - # If we didn't error, it means we can go ahead - assert_valid_to_preserve(op_overload) + if safe: + # If we didn't error, it means we can go ahead + assert_valid_to_preserve(op_overload) saved_tables[op_overload] = op_overload.py_kernels.copy() patched_ops.add(op_overload) @@ -220,10 +229,12 @@ def assert_valid_to_preserve(op_overload): if torch._C.DispatchKey.CompositeImplicitAutograd in op_overload.py_kernels: del op_overload.py_kernels[torch._C.DispatchKey.CompositeImplicitAutograd] - def _(*args, **kwargs): - return NotImplemented + if safe: - op_overload.py_impl(torch._C.DispatchKey.CompositeImplicitAutograd)(_) + def _(*args, **kwargs): + return NotImplemented + + op_overload.py_impl(torch._C.DispatchKey.CompositeImplicitAutograd)(_) # For fake tensor prop, we do want to register meta kernel directly if torch._C.DispatchKey.Meta not in op_overload.py_kernels: @@ -247,6 +258,19 @@ def _(*args, **kwargs): decomp_table[op] = decomp +@contextmanager +def _override_decomp_aten_to_variants(): + # Preserve variants of aten::to understanding that they are mutating/aliasing + # and their CompositeImplicitAutograd kernels will not become NotImplemented. + # We will later replace them with aten._to_copy when functionalizing. + with _override_composite_implicit_decomp( + (torch.ops.aten.to.dtype_layout, torch.ops.aten.to.dtype), + {}, + safe=False, + ): + yield + + def _rename_without_collisions( name_map: Dict[str, str], orig_name: str, @@ -413,7 +437,7 @@ def _decompose_and_get_gm_with_new_signature_constants( entry ] - with fake_mode: + with fake_mode, _override_decomp_aten_to_variants(): fake_args_unwrapped = pytree.tree_unflatten(fake_args, mod._in_spec) aten_export_artifact = _export_to_aten_ir( mod, From 977cf0263c7de6d3961d9df7d8fb0aa94f907f20 Mon Sep 17 00:00:00 2001 From: Animesh Jain Date: Wed, 28 Aug 2024 22:04:48 -0700 Subject: [PATCH 0225/1018] [dynamo] Support reading attributes from pybind objects (#134630) Pull Request resolved: https://github.com/pytorch/pytorch/pull/134630 Approved by: https://github.com/jansel --- build_variables.bzl | 1 + test/dynamo/test_functions.py | 16 +++++++++ .../TestScript.test_is_after_use | 0 torch/_dynamo/variables/user_defined.py | 19 +++++++---- torch/csrc/dynamo/init.cpp | 6 ++++ torch/csrc/dynamo/utils.cpp | 33 +++++++++++++++++++ torch/csrc/dynamo/utils.h | 9 +++++ 7 files changed, 78 insertions(+), 6 deletions(-) delete mode 100644 test/dynamo_expected_failures/TestScript.test_is_after_use create mode 100644 torch/csrc/dynamo/utils.cpp diff --git a/build_variables.bzl b/build_variables.bzl index aa6728eb00b883..e19bb7afe84576 100644 --- a/build_variables.bzl +++ b/build_variables.bzl @@ -839,6 +839,7 @@ libtorch_python_core_sources = [ "torch/csrc/dynamo/extra_state.cpp", "torch/csrc/dynamo/framelocals_mapping.cpp", "torch/csrc/dynamo/guards.cpp", + "torch/csrc/dynamo/utils.cpp", "torch/csrc/dynamo/init.cpp", "torch/csrc/functorch/init.cpp", "torch/csrc/fx/node.cpp", diff --git a/test/dynamo/test_functions.py b/test/dynamo/test_functions.py index f5cbfb8760c307..84b3a923c1484e 100644 --- a/test/dynamo/test_functions.py +++ b/test/dynamo/test_functions.py @@ -3602,6 +3602,22 @@ def foo_default_str(x): dynamo_class_name = dynamo_default_str[1].split(" object at")[0] self.assertEqual(eager_class_name, dynamo_class_name) + def test_pybind_object(self): + def fn(x, pybind_obj): + if pybind_obj.result: + return torch.cos(x) + return torch.sin(x) + + opt_fn = torch.compile(fn, backend="eager", fullgraph=True) + + pybind_obj = torch._C._dynamo.guards.GuardDebugInfo(True, ["a==1"], 0) + x = torch.randn(4) + self.assertEqual(opt_fn(x, pybind_obj), fn(x, pybind_obj)) + + pybind_obj = torch._C._dynamo.guards.GuardDebugInfo(False, ["a==1"], 1) + x = torch.randn(4) + self.assertEqual(opt_fn(x, pybind_obj), fn(x, pybind_obj)) + instantiate_parametrized_tests(FunctionTests) diff --git a/test/dynamo_expected_failures/TestScript.test_is_after_use b/test/dynamo_expected_failures/TestScript.test_is_after_use deleted file mode 100644 index e69de29bb2d1d6..00000000000000 diff --git a/torch/_dynamo/variables/user_defined.py b/torch/_dynamo/variables/user_defined.py index a254bf97cf1c61..362d81d82aea04 100644 --- a/torch/_dynamo/variables/user_defined.py +++ b/torch/_dynamo/variables/user_defined.py @@ -899,6 +899,18 @@ def _check_for_getattribute(self): def _check_for_getattr(self): return get_custom_getattr(self.value) + def _is_c_defined_property(self, subobj): + if not isinstance(subobj, property): + return False + + # pybind def_readwrite is implemented via PyCFunction. At the python level, it is visible as a property whose + # fget is an instancemethod wrapper - https://docs.python.org/3/c-api/method.html#c.PyInstanceMethod_Check + + # If we have a PyCFunction, we make an assumption that there is no side effect. + return isinstance( + subobj.fget, types.BuiltinFunctionType + ) or torch._C._dynamo.utils.is_instancemethod(subobj.fget) + def _getattr_static(self, name): subobj = inspect.getattr_static(self.value, name, NO_SUCH_SUBOBJ) import _collections @@ -913,12 +925,7 @@ def _getattr_static(self, name): or ( inspect.ismemberdescriptor(subobj) and name in self.value.__slots__ ) # handle memberdecriptor and slots - or ( - isinstance(subobj, property) - and isinstance( - subobj.fget, types.BuiltinFunctionType - ) # property with C-defined fget - ) + or self._is_c_defined_property(subobj) ): # Call __getattribute__, we have already checked that this is not overridden and side-effect free. We don't # want to call getattr because it can be user-overridden. diff --git a/torch/csrc/dynamo/init.cpp b/torch/csrc/dynamo/init.cpp index bf20bfe724f4d0..3d60a75b4561d1 100644 --- a/torch/csrc/dynamo/init.cpp +++ b/torch/csrc/dynamo/init.cpp @@ -1,4 +1,5 @@ #include +#include #include #include @@ -44,6 +45,11 @@ void initDynamoBindings(PyObject* torch) { throw python_error(); } + PyObject* utils = torch_c_dynamo_utils_init(); + if (utils == nullptr || PyModule_AddObject(dynamo, "utils", utils) != 0) { + throw python_error(); + } + PyObject* guards = torch_c_dynamo_guards_init(); if (guards == nullptr || PyModule_AddObject(dynamo, "guards", guards) != 0) { throw python_error(); diff --git a/torch/csrc/dynamo/utils.cpp b/torch/csrc/dynamo/utils.cpp new file mode 100644 index 00000000000000..662b16bfb567e4 --- /dev/null +++ b/torch/csrc/dynamo/utils.cpp @@ -0,0 +1,33 @@ +#include + +namespace torch::dynamo { + +static std::array _methods = {{ + {nullptr, + nullptr, + 0, + nullptr} // Sentinel value indicating the end of the array +}}; + +bool is_instancemethod(py::object obj) { + return PyInstanceMethod_Check(obj.ptr()); +} + +static struct PyModuleDef _module = { + PyModuleDef_HEAD_INIT, + "torch._C._dynamo.utils", + "Module containing C utils", + -1, + _methods.data()}; + +PyObject* torch_c_dynamo_utils_init() { + auto m = PyModule_Create(&_module); + if (m == nullptr) + return nullptr; + + auto py_m = py::handle(m).cast(); + py_m.def("is_instancemethod", is_instancemethod); + return m; +} + +} // namespace torch::dynamo diff --git a/torch/csrc/dynamo/utils.h b/torch/csrc/dynamo/utils.h index 3fd932f0441fe0..c7bcfb5b5e643f 100644 --- a/torch/csrc/dynamo/utils.h +++ b/torch/csrc/dynamo/utils.h @@ -1,5 +1,10 @@ #pragma once +#include +// C2039 MSVC +#include +#include +#include // The visibility attribute is to avoid a warning about storing a field in the // struct that has a different visibility (from pybind) than the struct. #ifdef _WIN32 @@ -7,3 +12,7 @@ #else #define VISIBILITY_HIDDEN __attribute__((visibility("hidden"))) #endif + +namespace torch::dynamo { +PyObject* torch_c_dynamo_utils_init(); +} // namespace torch::dynamo From 7fa82004a093f7bb8df58dce350fdd6490eaca36 Mon Sep 17 00:00:00 2001 From: Stonepia Date: Thu, 29 Aug 2024 15:12:38 +0000 Subject: [PATCH 0226/1018] Fix test_sgd_weight_decay_xpu accuracy error (#134744) Fixes #134743 This PR adds `test_sgd_weight_decay_xpu` in `KERNEL_COUNT_OVERRIDES` to override. Pull Request resolved: https://github.com/pytorch/pytorch/pull/134744 Approved by: https://github.com/EikanWang, https://github.com/desertfire --- test/inductor/test_compiled_optimizers.py | 1 + 1 file changed, 1 insertion(+) diff --git a/test/inductor/test_compiled_optimizers.py b/test/inductor/test_compiled_optimizers.py index 0fcba0d0db56e0..2f54dff94cee8a 100644 --- a/test/inductor/test_compiled_optimizers.py +++ b/test/inductor/test_compiled_optimizers.py @@ -148,6 +148,7 @@ class KernelCounts(NamedTuple): "test_sgd_weight_decay_maximize_cpu": 4, "test_sgd_weight_decay_cpu": 4, "test_sgd_weight_decay_cuda": 4, + "test_sgd_weight_decay_xpu": 4, "test_sgd_momentum_weight_decay_foreach_cuda": 2, "test_sgd_momentum_weight_decay_foreach_xpu": 2, "test_sgd_momentum_nesterov_weight_decay_foreach_cuda": 2, From e839defc4d61e080186c33723de476d469904953 Mon Sep 17 00:00:00 2001 From: Valentine233 Date: Thu, 29 Aug 2024 16:04:23 +0000 Subject: [PATCH 0227/1018] [CPU][flash attention] make the stride of output align with input (#134656) Fixes #133671 Currently, the output of CPU flash attention has a fixed layout, no matter what the input is. This PR makes the stride of output align with input q/k/v, which is the same behavior as math backend. Pull Request resolved: https://github.com/pytorch/pytorch/pull/134656 Approved by: https://github.com/jgong5, https://github.com/drisspg --- .../ATen/native/transformers/attention.cpp | 3 +- test/test_transformers.py | 39 ++++++++++++++++++- torch/_meta_registrations.py | 6 +-- 3 files changed, 40 insertions(+), 8 deletions(-) diff --git a/aten/src/ATen/native/transformers/attention.cpp b/aten/src/ATen/native/transformers/attention.cpp index 08a6eb04cee568..cab9dbd520d61b 100644 --- a/aten/src/ATen/native/transformers/attention.cpp +++ b/aten/src/ATen/native/transformers/attention.cpp @@ -870,7 +870,6 @@ _scaled_dot_product_flash_attention_cpu( int64_t batchSize = query.size(0); int64_t qSize = query.size(2); int64_t num_head = query.size(1); - int64_t headSize = query.size(3); TORCH_CHECK(c10::isFloatingType(dtype), "scaled_dot_product_attention_flash_attention: Expected data type in FP32, FP64, BF16, FP16, but got ", dtype, " instead."); @@ -888,7 +887,7 @@ _scaled_dot_product_flash_attention_cpu( (attn_mask.value().dim() == 2 || attn_mask.value().dim() == 4), "scaled_dot_product_attention_flash_attention: Attention mask dim in {2, 4}"); - at::Tensor output = at::empty({batchSize, qSize, num_head, headSize}, query.options()); + at::Tensor output = at::empty_like(query, query.options()).transpose(1, 2); const auto accumulate_dtype = toOpMathType(dtype); at::Tensor logsumexp = at::empty({batchSize, qSize, num_head}, query.options().dtype(accumulate_dtype)); diff --git a/test/test_transformers.py b/test/test_transformers.py index f83f4fdce1eba4..5ccce245bf965e 100644 --- a/test/test_transformers.py +++ b/test/test_transformers.py @@ -2077,7 +2077,7 @@ def test_scaled_dot_product_fused_attention_mask_vs_math_cpu( self.assertEqual(grad_k_actual, grad_k_ref, atol=tol.atol, rtol=tol.rtol) self.assertEqual(grad_v_actual, grad_v_ref, atol=tol.atol, rtol=tol.rtol) - def test_scaled_dot_product_fused_attention_with_inf(self, device): + def test_sdpa_with_inf(self, device): # https://github.com/pytorch/pytorch/issues/127055. full = torch.full((600, 600), float("-inf"), device=device) mask = torch.triu(full, diagonal=1) + torch.tril(full, diagonal=-10) @@ -2092,6 +2092,43 @@ def test_scaled_dot_product_fused_attention_with_inf(self, device): actual = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask) self.assertEqual(math_ref, actual) + def test_sdpa_backward_with_gradient(self, device): + # https://github.com/pytorch/pytorch/issues/133671. + def sdpa_helper(): + torch.manual_seed(777) + query = ( + torch.empty(size=[2, 2, 49, 32], dtype=torch.float32, device=device) + .uniform_(-1, 1) + .requires_grad_(True) + ) + key = ( + torch.empty(size=[2, 2, 49, 32], dtype=torch.float32, device=device) + .uniform_(-1, 1) + .requires_grad_(True) + ) + value = ( + torch.empty(size=[2, 2, 49, 32], dtype=torch.float32, device=device) + .uniform_(-1, 1) + .requires_grad_(True) + ) + res = torch.nn.functional.scaled_dot_product_attention( + query, key, value, None, 0.0, False + ) + res_grad = ( + torch.empty_like(res, device=device) + .uniform_(-1, 1) + ) + res.backward(res_grad, retain_graph=True) + return res, query.grad, key.grad, value.grad + with sdpa_kernel(backends=[SDPBackend.MATH]): + res_ref, query_grad_ref, key_grad_ref, value_grad_ref = sdpa_helper() + with sdpa_kernel(backends=[SDPBackend.FLASH_ATTENTION]): + res_actual, query_grad_actual, key_grad_actual, value_grad_actual = sdpa_helper() + self.assertEqual(res_ref, res_actual) + self.assertEqual(query_grad_ref, query_grad_actual) + self.assertEqual(key_grad_ref, key_grad_actual) + self.assertEqual(value_grad_ref, value_grad_actual) + @unittest.skipIf(not PLATFORM_SUPPORTS_FUSED_ATTENTION, "Fused SDPA was not built for this system") @parametrize("backend", [SDPBackend.EFFICIENT_ATTENTION, SDPBackend.FLASH_ATTENTION]) @parametrize("seq_len", [32, 64, 128]) diff --git a/torch/_meta_registrations.py b/torch/_meta_registrations.py index 022e8010265ba0..61e01fe3b1d11e 100644 --- a/torch/_meta_registrations.py +++ b/torch/_meta_registrations.py @@ -5188,11 +5188,7 @@ def meta__scaled_dot_product_flash_attention_for_cpu( max_seqlen_batch_q = query.size(2) head_dim = query.size(3) - attention = torch.empty( - (batch_size, max_seqlen_batch_q, num_heads, head_dim), - dtype=query.dtype, - device=query.device, - ).transpose(1, 2) + attention = torch.empty_like(query) logsumexp = torch.empty( ( batch_size, From 5d527184f36f63800a532a74b9f2efc8af10ba4c Mon Sep 17 00:00:00 2001 From: Bin Bao Date: Wed, 28 Aug 2024 04:46:50 -0700 Subject: [PATCH 0228/1018] [AOTI] Switch benchmarking to use export non-strict mode (#130977) Summary: Switch the export part used by AOTInductor benchmarking from strict to non-strict, and switch it from producing torch IR to aten IR. Pull Request resolved: https://github.com/pytorch/pytorch/pull/130977 Approved by: https://github.com/angelayi ghstack dependencies: #134639 --- .../aot_eager_torchbench_training.csv | 2 +- .../aot_inductor_huggingface_inference.csv | 4 - .../aot_inductor_timm_inference.csv | 2 +- .../aot_inductor_torchbench_inference.csv | 6 +- ...ctor_amp_freezing_torchbench_inference.csv | 62 +------------- ...nductor_freezing_huggingface_inference.csv | 8 -- ...inductor_freezing_torchbench_inference.csv | 62 +------------- ...tor_amp_freezing_huggingface_inference.csv | 4 - ...ctor_amp_freezing_torchbench_inference.csv | 28 ------- ...nductor_freezing_huggingface_inference.csv | 4 - ...inductor_freezing_torchbench_inference.csv | 18 +---- .../cpu_inductor_huggingface_inference.csv | 4 - .../dynamic_aot_eager_torchbench_training.csv | 2 +- ...ctor_amp_freezing_torchbench_inference.csv | 80 +------------------ ...inductor_freezing_torchbench_inference.csv | 80 +------------------ ...mic_cpu_inductor_huggingface_inference.csv | 4 - .../ci_expected_accuracy/update_expected.py | 8 ++ benchmarks/dynamo/common.py | 9 +-- 18 files changed, 33 insertions(+), 354 deletions(-) diff --git a/benchmarks/dynamo/ci_expected_accuracy/aot_eager_torchbench_training.csv b/benchmarks/dynamo/ci_expected_accuracy/aot_eager_torchbench_training.csv index f1b119c0284058..05d52bedb3e046 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/aot_eager_torchbench_training.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/aot_eager_torchbench_training.csv @@ -282,4 +282,4 @@ vision_maskrcnn,pass,33 -yolov3,fail_accuracy,8 +yolov3,pass,8 diff --git a/benchmarks/dynamo/ci_expected_accuracy/aot_inductor_huggingface_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/aot_inductor_huggingface_inference.csv index 784d3788e33554..1cafcbe55675d3 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/aot_inductor_huggingface_inference.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/aot_inductor_huggingface_inference.csv @@ -10,10 +10,6 @@ AlbertForQuestionAnswering,pass,0 -AllenaiLongformerBase,fail_to_run,0 - - - BartForCausalLM,pass,0 diff --git a/benchmarks/dynamo/ci_expected_accuracy/aot_inductor_timm_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/aot_inductor_timm_inference.csv index c7e86a6d317eb5..dc36107e0d0292 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/aot_inductor_timm_inference.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/aot_inductor_timm_inference.csv @@ -114,7 +114,7 @@ lcnet_050,pass,0 -levit_128,pass,0 +levit_128,fail_to_run,0 diff --git a/benchmarks/dynamo/ci_expected_accuracy/aot_inductor_torchbench_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/aot_inductor_torchbench_inference.csv index 6cbf71ffeded09..fe3c67bba120b3 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/aot_inductor_torchbench_inference.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/aot_inductor_torchbench_inference.csv @@ -6,7 +6,7 @@ torchrec_dlrm,eager_fail_to_run,0 -BERT_pytorch,fail_to_run,0 +BERT_pytorch,pass,0 @@ -282,7 +282,7 @@ sam,fail_to_run,0 -sam_fast,fail_to_run,0 +sam_fast,timeout,0 @@ -346,4 +346,4 @@ vision_maskrcnn,fail_to_run,0 -yolov3,pass,0 +yolov3,fail_to_run,0 diff --git a/benchmarks/dynamo/ci_expected_accuracy/cpu_aot_inductor_amp_freezing_torchbench_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/cpu_aot_inductor_amp_freezing_torchbench_inference.csv index fae39359c6148f..e3772d03264f73 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/cpu_aot_inductor_amp_freezing_torchbench_inference.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/cpu_aot_inductor_amp_freezing_torchbench_inference.csv @@ -2,11 +2,7 @@ name,accuracy,graph_breaks -torchrec_dlrm,eager_fail_to_run,0 - - - -BERT_pytorch,fail_to_run,0 +BERT_pytorch,pass,0 @@ -14,10 +10,6 @@ Background_Matting,pass_due_to_skip,0 -DALLE2_pytorch,fail_to_run,0 - - - LearningToPaint,pass,0 @@ -138,10 +130,6 @@ hf_Bert_large,pass,0 -hf_BigBird,fail_to_run,0 - - - hf_DistilBert,pass,0 @@ -154,10 +142,6 @@ hf_GPT2_large,pass_due_to_skip,0 -hf_T5,pass,0 - - - hf_T5_base,pass,0 @@ -166,10 +150,6 @@ hf_T5_large,pass_due_to_skip,0 -hf_Whisper,pass,0 - - - hf_distil_whisper,pass,0 @@ -182,14 +162,6 @@ llama,fail_to_run,0 -llama_v2_7b_16h,model_fail_to_load,0 - - - -llava,model_fail_to_load,0 - - - maml,pass_due_to_skip,0 @@ -198,10 +170,6 @@ maml_omniglot,pass,0 -mnasnet1_0,pass,0 - - - mobilenet_v2,pass,0 @@ -214,18 +182,10 @@ mobilenet_v3_large,pass,0 -moco,fail_to_run,0 - - - moondream,pass,0 -nanogpt,pass,0 - - - nvidia_deeprecommender,pass,0 @@ -282,18 +242,6 @@ resnext50_32x4d,pass,0 -sam,fail_to_run,0 - - - -sam_fast,fail_to_run,0 - - - -shufflenet_v2_x1_0,pass,0 - - - soft_actor_critic,fail_to_run,0 @@ -302,10 +250,6 @@ squeezenet1_1,pass,0 -stable_diffusion_text_encoder,pass,0 - - - stable_diffusion_unet,pass_due_to_skip,0 @@ -346,7 +290,7 @@ torch_multimodal_clip,pass,0 -tts_angular,fail_to_run,0 +tts_angular,pass,0 @@ -358,4 +302,4 @@ vision_maskrcnn,fail_to_run,0 -yolov3,pass,0 +yolov3,fail_to_run,0 diff --git a/benchmarks/dynamo/ci_expected_accuracy/cpu_aot_inductor_freezing_huggingface_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/cpu_aot_inductor_freezing_huggingface_inference.csv index 784d3788e33554..8d3646d6533a40 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/cpu_aot_inductor_freezing_huggingface_inference.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/cpu_aot_inductor_freezing_huggingface_inference.csv @@ -10,10 +10,6 @@ AlbertForQuestionAnswering,pass,0 -AllenaiLongformerBase,fail_to_run,0 - - - BartForCausalLM,pass,0 @@ -130,10 +126,6 @@ MobileBertForQuestionAnswering,pass,0 -OPTForCausalLM,pass,0 - - - PLBartForCausalLM,pass,0 diff --git a/benchmarks/dynamo/ci_expected_accuracy/cpu_aot_inductor_freezing_torchbench_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/cpu_aot_inductor_freezing_torchbench_inference.csv index 5e5e0a0d507ff5..36064e7361d0eb 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/cpu_aot_inductor_freezing_torchbench_inference.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/cpu_aot_inductor_freezing_torchbench_inference.csv @@ -2,11 +2,7 @@ name,accuracy,graph_breaks -torchrec_dlrm,eager_fail_to_run,0 - - - -BERT_pytorch,fail_to_run,0 +BERT_pytorch,pass,0 @@ -14,10 +10,6 @@ Background_Matting,pass_due_to_skip,0 -DALLE2_pytorch,fail_to_run,0 - - - LearningToPaint,pass,0 @@ -138,10 +130,6 @@ hf_Bert_large,pass,0 -hf_BigBird,fail_to_run,0 - - - hf_DistilBert,pass,0 @@ -154,10 +142,6 @@ hf_GPT2_large,pass_due_to_skip,0 -hf_T5,pass,0 - - - hf_T5_base,pass,0 @@ -166,10 +150,6 @@ hf_T5_large,pass_due_to_skip,0 -hf_Whisper,pass,0 - - - hf_distil_whisper,pass,0 @@ -182,14 +162,6 @@ llama,fail_to_run,0 -llama_v2_7b_16h,model_fail_to_load,0 - - - -llava,model_fail_to_load,0 - - - maml,pass_due_to_skip,0 @@ -198,10 +170,6 @@ maml_omniglot,pass,0 -mnasnet1_0,pass,0 - - - mobilenet_v2,pass,0 @@ -214,18 +182,10 @@ mobilenet_v3_large,pass,0 -moco,fail_to_run,0 - - - moondream,pass,0 -nanogpt,pass,0 - - - nvidia_deeprecommender,pass,0 @@ -282,18 +242,6 @@ resnext50_32x4d,pass,0 -sam,fail_to_run,0 - - - -sam_fast,fail_to_run,0 - - - -shufflenet_v2_x1_0,pass,0 - - - soft_actor_critic,fail_to_run,0 @@ -302,10 +250,6 @@ squeezenet1_1,pass,0 -stable_diffusion_text_encoder,pass,0 - - - stable_diffusion_unet,pass_due_to_skip,0 @@ -346,7 +290,7 @@ torch_multimodal_clip,pass,0 -tts_angular,fail_to_run,0 +tts_angular,pass,0 @@ -358,4 +302,4 @@ vision_maskrcnn,fail_to_run,0 -yolov3,pass,0 +yolov3,fail_to_run,0 diff --git a/benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_amp_freezing_huggingface_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_amp_freezing_huggingface_inference.csv index 349239b058a7b7..cfc3baf0d2ef9d 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_amp_freezing_huggingface_inference.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_amp_freezing_huggingface_inference.csv @@ -130,10 +130,6 @@ MobileBertForQuestionAnswering,pass,0 -OPTForCausalLM,pass,0 - - - PLBartForCausalLM,pass,0 diff --git a/benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_amp_freezing_torchbench_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_amp_freezing_torchbench_inference.csv index c900d1768921b6..f72a083f1013b4 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_amp_freezing_torchbench_inference.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_amp_freezing_torchbench_inference.csv @@ -10,10 +10,6 @@ Background_Matting,pass_due_to_skip,0 -DALLE2_pytorch,model_fail_to_load,0 - - - LearningToPaint,pass,0 @@ -142,10 +138,6 @@ hf_Bert_large,pass,0 -hf_BigBird,pass,61 - - - hf_DistilBert,pass,0 @@ -158,18 +150,10 @@ hf_GPT2_large,pass_due_to_skip,0 -hf_Longformer,pass,4 - - - hf_Reformer,pass,5 -hf_T5,pass,0 - - - hf_T5_base,pass,0 @@ -198,10 +182,6 @@ maml_omniglot,pass,0 -mnasnet1_0,pass,0 - - - mobilenet_v2,pass,0 @@ -214,10 +194,6 @@ mobilenet_v3_large,pass,0 -moco,model_fail_to_load,0 - - - moondream,pass,0 @@ -282,10 +258,6 @@ resnext50_32x4d,pass,0 -shufflenet_v2_x1_0,pass,0 - - - soft_actor_critic,pass,0 diff --git a/benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_freezing_huggingface_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_freezing_huggingface_inference.csv index 349239b058a7b7..cfc3baf0d2ef9d 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_freezing_huggingface_inference.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_freezing_huggingface_inference.csv @@ -130,10 +130,6 @@ MobileBertForQuestionAnswering,pass,0 -OPTForCausalLM,pass,0 - - - PLBartForCausalLM,pass,0 diff --git a/benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_freezing_torchbench_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_freezing_torchbench_inference.csv index 10f3904d678b72..8197f6cb65c695 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_freezing_torchbench_inference.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_freezing_torchbench_inference.csv @@ -10,10 +10,6 @@ Background_Matting,pass_due_to_skip,0 -DALLE2_pytorch,model_fail_to_load,0 - - - LearningToPaint,pass,0 @@ -82,7 +78,7 @@ detectron2_fcos_r_50_fpn,pass,23 -detectron2_maskrcnn_r_101_c4,fail_accuracy,57 +detectron2_maskrcnn_r_101_c4,pass,57 @@ -186,10 +182,6 @@ maml_omniglot,pass,0 -mnasnet1_0,pass,0 - - - mobilenet_v2,pass,0 @@ -202,10 +194,6 @@ mobilenet_v3_large,pass,0 -moco,model_fail_to_load,0 - - - moondream,pass,0 @@ -270,10 +258,6 @@ resnext50_32x4d,pass,0 -shufflenet_v2_x1_0,pass,0 - - - soft_actor_critic,pass,0 diff --git a/benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_huggingface_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_huggingface_inference.csv index 349239b058a7b7..cfc3baf0d2ef9d 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_huggingface_inference.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_huggingface_inference.csv @@ -130,10 +130,6 @@ MobileBertForQuestionAnswering,pass,0 -OPTForCausalLM,pass,0 - - - PLBartForCausalLM,pass,0 diff --git a/benchmarks/dynamo/ci_expected_accuracy/dynamic_aot_eager_torchbench_training.csv b/benchmarks/dynamo/ci_expected_accuracy/dynamic_aot_eager_torchbench_training.csv index 38ff07764e08d2..40011a6000d61e 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/dynamic_aot_eager_torchbench_training.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/dynamic_aot_eager_torchbench_training.csv @@ -278,4 +278,4 @@ vision_maskrcnn,pass,33 -yolov3,fail_accuracy,8 +yolov3,pass,8 diff --git a/benchmarks/dynamo/ci_expected_accuracy/dynamic_cpu_aot_inductor_amp_freezing_torchbench_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/dynamic_cpu_aot_inductor_amp_freezing_torchbench_inference.csv index 4abe5ae064a93a..21cee0dc41ca15 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/dynamic_cpu_aot_inductor_amp_freezing_torchbench_inference.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/dynamic_cpu_aot_inductor_amp_freezing_torchbench_inference.csv @@ -2,11 +2,7 @@ name,accuracy,graph_breaks -torchrec_dlrm,eager_fail_to_run,0 - - - -BERT_pytorch,fail_to_run,0 +BERT_pytorch,pass,0 @@ -14,10 +10,6 @@ Background_Matting,pass_due_to_skip,0 -DALLE2_pytorch,fail_to_run,0 - - - LearningToPaint,pass,0 @@ -82,22 +74,6 @@ detectron2_fasterrcnn_r_50_fpn,fail_to_run,0 -detectron2_maskrcnn_r_101_c4,fail_to_run,0 - - - -detectron2_maskrcnn_r_101_fpn,fail_to_run,0 - - - -detectron2_maskrcnn_r_50_c4,fail_to_run,0 - - - -detectron2_maskrcnn_r_50_fpn,fail_to_run,0 - - - dlrm,fail_to_run,0 @@ -138,10 +114,6 @@ hf_Bert_large,pass,0 -hf_BigBird,fail_to_run,0 - - - hf_DistilBert,pass,0 @@ -154,10 +126,6 @@ hf_GPT2_large,pass_due_to_skip,0 -hf_T5,pass,0 - - - hf_T5_base,pass,0 @@ -166,10 +134,6 @@ hf_T5_large,pass_due_to_skip,0 -hf_Whisper,pass,0 - - - hf_distil_whisper,pass,0 @@ -182,14 +146,6 @@ llama,fail_to_run,0 -llama_v2_7b_16h,model_fail_to_load,0 - - - -llava,model_fail_to_load,0 - - - maml,pass_due_to_skip,0 @@ -198,10 +154,6 @@ maml_omniglot,pass,0 -mnasnet1_0,pass,0 - - - mobilenet_v2,pass,0 @@ -214,18 +166,10 @@ mobilenet_v3_large,pass,0 -moco,fail_to_run,0 - - - moondream,pass,0 -nanogpt,pass,0 - - - nvidia_deeprecommender,pass,0 @@ -282,18 +226,6 @@ resnext50_32x4d,pass,0 -sam,fail_to_run,0 - - - -sam_fast,fail_to_run,0 - - - -shufflenet_v2_x1_0,pass,0 - - - soft_actor_critic,fail_to_run,0 @@ -302,10 +234,6 @@ squeezenet1_1,pass,0 -stable_diffusion_text_encoder,pass,0 - - - stable_diffusion_unet,pass_due_to_skip,0 @@ -342,11 +270,11 @@ timm_vovnet,pass,0 -torch_multimodal_clip,fail_to_run,0 +torch_multimodal_clip,pass,0 -tts_angular,fail_to_run,0 +tts_angular,pass,0 @@ -358,4 +286,4 @@ vision_maskrcnn,fail_to_run,0 -yolov3,pass,0 +yolov3,fail_to_run,0 diff --git a/benchmarks/dynamo/ci_expected_accuracy/dynamic_cpu_aot_inductor_freezing_torchbench_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/dynamic_cpu_aot_inductor_freezing_torchbench_inference.csv index 4360e2858cce24..fb0c5ec473135a 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/dynamic_cpu_aot_inductor_freezing_torchbench_inference.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/dynamic_cpu_aot_inductor_freezing_torchbench_inference.csv @@ -2,11 +2,7 @@ name,accuracy,graph_breaks -torchrec_dlrm,eager_fail_to_run,0 - - - -BERT_pytorch,fail_to_run,0 +BERT_pytorch,pass,0 @@ -14,10 +10,6 @@ Background_Matting,pass_due_to_skip,0 -DALLE2_pytorch,fail_to_run,0 - - - LearningToPaint,pass,0 @@ -82,22 +74,6 @@ detectron2_fasterrcnn_r_50_fpn,fail_to_run,0 -detectron2_maskrcnn_r_101_c4,fail_to_run,0 - - - -detectron2_maskrcnn_r_101_fpn,fail_to_run,0 - - - -detectron2_maskrcnn_r_50_c4,fail_to_run,0 - - - -detectron2_maskrcnn_r_50_fpn,fail_to_run,0 - - - dlrm,fail_to_run,0 @@ -138,10 +114,6 @@ hf_Bert_large,pass,0 -hf_BigBird,fail_to_run,0 - - - hf_DistilBert,pass,0 @@ -154,10 +126,6 @@ hf_GPT2_large,pass_due_to_skip,0 -hf_T5,pass,0 - - - hf_T5_base,pass,0 @@ -166,10 +134,6 @@ hf_T5_large,pass_due_to_skip,0 -hf_Whisper,pass,0 - - - hf_distil_whisper,pass,0 @@ -182,14 +146,6 @@ llama,fail_to_run,0 -llama_v2_7b_16h,model_fail_to_load,0 - - - -llava,model_fail_to_load,0 - - - maml,pass_due_to_skip,0 @@ -198,10 +154,6 @@ maml_omniglot,pass,0 -mnasnet1_0,pass,0 - - - mobilenet_v2,pass,0 @@ -214,18 +166,10 @@ mobilenet_v3_large,pass,0 -moco,fail_to_run,0 - - - moondream,pass,0 -nanogpt,pass,0 - - - nvidia_deeprecommender,pass,0 @@ -282,18 +226,6 @@ resnext50_32x4d,pass,0 -sam,fail_to_run,0 - - - -sam_fast,fail_to_run,0 - - - -shufflenet_v2_x1_0,pass,0 - - - soft_actor_critic,fail_to_run,0 @@ -302,10 +234,6 @@ squeezenet1_1,pass,0 -stable_diffusion_text_encoder,pass,0 - - - stable_diffusion_unet,pass_due_to_skip,0 @@ -342,11 +270,11 @@ timm_vovnet,pass,0 -torch_multimodal_clip,fail_to_run,0 +torch_multimodal_clip,pass,0 -tts_angular,fail_to_run,0 +tts_angular,pass,0 @@ -358,4 +286,4 @@ vision_maskrcnn,fail_to_run,0 -yolov3,pass,0 +yolov3,fail_to_run,0 diff --git a/benchmarks/dynamo/ci_expected_accuracy/dynamic_cpu_inductor_huggingface_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/dynamic_cpu_inductor_huggingface_inference.csv index 349239b058a7b7..cfc3baf0d2ef9d 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/dynamic_cpu_inductor_huggingface_inference.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/dynamic_cpu_inductor_huggingface_inference.csv @@ -130,10 +130,6 @@ MobileBertForQuestionAnswering,pass,0 -OPTForCausalLM,pass,0 - - - PLBartForCausalLM,pass,0 diff --git a/benchmarks/dynamo/ci_expected_accuracy/update_expected.py b/benchmarks/dynamo/ci_expected_accuracy/update_expected.py index 29d204c02e218d..c6bd92ae8ed015 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/update_expected.py +++ b/benchmarks/dynamo/ci_expected_accuracy/update_expected.py @@ -152,8 +152,16 @@ def apply_lints(filename): [ "aot_eager", "aot_inductor", + "cpu_aot_inductor", + "cpu_aot_inductor_amp_freezing", + "cpu_aot_inductor_freezing", "cpu_inductor", + "cpu_inductor_amp_freezing", + "cpu_inductor_freezing", "dynamic_aot_eager", + "dynamic_cpu_aot_inductor", + "dynamic_cpu_aot_inductor_amp_freezing", + "dynamic_cpu_aot_inductor_freezing", "dynamic_cpu_inductor", "dynamic_inductor", "dynamo_eager", diff --git a/benchmarks/dynamo/common.py b/benchmarks/dynamo/common.py index 6457e1b0814d1a..bea9fc76e226a7 100644 --- a/benchmarks/dynamo/common.py +++ b/benchmarks/dynamo/common.py @@ -1364,14 +1364,13 @@ def load(cls, model, example_inputs, device): else: _register_dataclass_output_as_pytree(example_outputs) - # TODO(angelayi): change this to predispatch - # https://github.com/pytorch/pytorch/issues/127513 needs to be fixed before changing - # to predispatch to avoid performance regressions - gm = torch.export._trace._export_to_torch_ir( + gm = torch.export._trace._export( model, example_args, example_kwargs, - ) + pre_dispatch=True, + strict=False, + ).module() with torch.no_grad(): so_path = torch._inductor.aot_compile( gm, example_args, example_kwargs From 49b06ab0582dc7888a6d81aa9a85acf755be13cf Mon Sep 17 00:00:00 2001 From: Ke Wen Date: Thu, 29 Aug 2024 09:09:51 -0700 Subject: [PATCH 0229/1018] [2/N] Add flag to control which rank should perform NaN check (#134345) Fixes https://github.com/pytorch/pytorch/issues/134062. For example, in case of broadcast / scatter, only the root rank should perform the NaN check. Pull Request resolved: https://github.com/pytorch/pytorch/pull/134345 Approved by: https://github.com/shuqiangzhang, https://github.com/wconstab --- test/distributed/test_c10d_nccl.py | 22 +++++++++++ .../distributed/c10d/ProcessGroupNCCL.cpp | 39 +++++++++++++------ .../distributed/c10d/ProcessGroupNCCL.hpp | 6 ++- 3 files changed, 54 insertions(+), 13 deletions(-) diff --git a/test/distributed/test_c10d_nccl.py b/test/distributed/test_c10d_nccl.py index 64c51f791fddd0..55b785a3552b5a 100644 --- a/test/distributed/test_c10d_nccl.py +++ b/test/distributed/test_c10d_nccl.py @@ -366,6 +366,28 @@ def test_nan_assert(self, type): # reset env os.environ["TORCH_NCCL_NAN_CHECK"] = "0" + @requires_nccl() + @skip_if_lt_x_gpu(2) + def test_nan_p2p(self): + # Putting NaN at recv buffer, program should not fail as NaN checker + # should not check on receive buffer + os.environ["TORCH_NCCL_NAN_CHECK"] = "1" + store = c10d.FileStore(self.file_name, self.world_size) + device = torch.device("cuda:%d" % self.rank) + c10d.init_process_group( + backend="nccl", store=store, rank=self.rank, world_size=self.world_size + ) + t = torch.ones(3, 4, dtype=torch.bfloat16, device=device) + if self.rank == 0: + c10d.send(t, 1) + elif self.rank == 1: + # Putting NaN at recv buffer + t[1, 1] = float("nan") + c10d.recv(t, 0) + c10d.destroy_process_group() + # reset env + os.environ["TORCH_NCCL_NAN_CHECK"] = "0" + @requires_nccl() @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs") def test_destruct_before_terminate_pg(self): diff --git a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp index 310b351eb219cc..65e9fea044932d 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp +++ b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp @@ -2638,9 +2638,12 @@ c10::intrusive_ptr ProcessGroupNCCL::collective( PostProcess post, OpType opType, const char* profilingTitle, - bool avoidRecordStreams) { + bool avoidRecordStreams, + bool nanCheck) { // Environment setting by the user may add onto collective call's option avoidRecordStreams |= avoidRecordStreams_; + nanCheck &= enableNanCheck_; + c10::cuda::CaptureStatus capture_status = c10::cuda::currentStreamCaptureStatusMayInitCtx(); errorIfCapturingNonCapturableNCCL(capture_status); @@ -2694,7 +2697,7 @@ c10::intrusive_ptr ProcessGroupNCCL::collective( at::cuda::OptionalCUDAGuard gpuGuard; - if (enableNanCheck_) { + if (nanCheck) { checkForNan(input, ncclStream); } @@ -3127,7 +3130,9 @@ c10::intrusive_ptr ProcessGroupNCCL::pointToPoint( // is gpuGuard needed for the if block below, or can i swap them at::cuda::OptionalCUDAGuard gpuGuard; - if (enableNanCheck_) { + // Only check for NaN for send ops, for recv ops `tensor` can be a random + // placeholder + if (enableNanCheck_ && opType == OpType::SEND) { checkForNan(tensor, ncclStream); } @@ -3224,7 +3229,8 @@ c10::intrusive_ptr ProcessGroupNCCL::collective( Fn fn, OpType opType, const char* profilingTitle, - bool avoidRecordStreams) { + bool avoidRecordStreams, + bool nanCheck) { return collective( input, output, @@ -3235,7 +3241,8 @@ c10::intrusive_ptr ProcessGroupNCCL::collective( c10::intrusive_ptr& work) {}, opType, profilingTitle, - avoidRecordStreams); + avoidRecordStreams, + nanCheck); } template @@ -3485,6 +3492,9 @@ c10::intrusive_ptr ProcessGroupNCCL::broadcast( // avoidRecordStreams_ note: collective() will stash tensors. bool avoidRecordStreams = avoidRecordStreams_ || (!opts.asyncOp); + const auto root = opts.rootRank + opts.rootTensor; + bool nanCheck = (root == rank_); + return collective( tensor, tensor, @@ -3492,7 +3502,6 @@ c10::intrusive_ptr ProcessGroupNCCL::broadcast( at::Tensor& output, ncclComm_t comm, at::cuda::CUDAStream& stream) { - const auto root = opts.rootRank + opts.rootTensor; return ncclBcast( input.data_ptr(), input.numel(), @@ -3503,7 +3512,8 @@ c10::intrusive_ptr ProcessGroupNCCL::broadcast( }, OpType::BROADCAST, "nccl:broadcast", - avoidRecordStreams); + avoidRecordStreams, + nanCheck); } // _broadcast_oop adds an out-of-place broadcast in PGNCCL @@ -3523,6 +3533,9 @@ c10::intrusive_ptr ProcessGroupNCCL::_broadcast_oop( "Tensor input and output of _broadcast_oop must have the same number of elements "); } + const auto root = opts.rootRank + opts.rootTensor; + bool nanCheck = (root == rank_); + return collective( inputTensor, outputTensor, @@ -3530,7 +3543,6 @@ c10::intrusive_ptr ProcessGroupNCCL::_broadcast_oop( at::Tensor& output, ncclComm_t comm, at::cuda::CUDAStream& stream) { - const auto root = opts.rootRank + opts.rootTensor; return ncclBroadcast( input.data_ptr(), output.data_ptr(), @@ -3541,7 +3553,9 @@ c10::intrusive_ptr ProcessGroupNCCL::_broadcast_oop( stream.stream()); }, OpType::BROADCAST, - "nccl:_broadcast_oop"); + "nccl:_broadcast_oop", + /*avoidRecordStreams=*/false, + nanCheck); } c10::intrusive_ptr ProcessGroupNCCL::reduce( @@ -4492,6 +4506,9 @@ c10::intrusive_ptr ProcessGroupNCCL::scatter( // inputs, which == inputTensors[0] on the root rank where it matters. bool avoidRecordStreams = avoidRecordStreams_ || (!opts.asyncOp); + const auto root = opts.rootRank; + bool nanCheck = (rank_ == root); + return collective( outputTensor, inputs[0], // just to fit the collective interface @@ -4499,7 +4516,6 @@ c10::intrusive_ptr ProcessGroupNCCL::scatter( at::Tensor& /* unused */, ncclComm_t comm, at::cuda::CUDAStream& stream) { - const auto root = opts.rootRank; if (getRank() == root) { if (!avoidRecordStreams) { for (auto input : inputs) { @@ -4513,7 +4529,8 @@ c10::intrusive_ptr ProcessGroupNCCL::scatter( }, OpType::SCATTER, "nccl:scatter", - avoidRecordStreams); + avoidRecordStreams, + nanCheck); } c10::intrusive_ptr ProcessGroupNCCL::recvAnysource( diff --git a/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp b/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp index 2ba68cda9cc3ea..b6635157d75e1b 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp +++ b/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp @@ -762,7 +762,8 @@ class TORCH_API ProcessGroupNCCL : public Backend { Fn fn, OpType opType, const char* profilingTitle = nullptr, - bool avoidRecordStreams = false); + bool avoidRecordStreams = false, + bool nanCheck = true); template c10::intrusive_ptr collective( @@ -773,7 +774,8 @@ class TORCH_API ProcessGroupNCCL : public Backend { PostProcess post, OpType opType, const char* profilingTitle = nullptr, - bool avoidRecordStreams = false); + bool avoidRecordStreams = false, + bool nanCheck = true); template c10::intrusive_ptr collectiveCoalesced( From c22d1119f010348f17a9df91364796f810d2d7c0 Mon Sep 17 00:00:00 2001 From: Xu Han Date: Thu, 29 Aug 2024 16:18:39 +0000 Subject: [PATCH 0230/1018] [inductor] calibration inductor windows uts (15/N) (#134586) Fix `test_logs_out` UT on Windows. make `test/dynamo/test_logging.py` all UTs pass on Windows. Changes: 1. Close `NamedTemporaryFile` to release file handle to avoid PermissionError issue. 2. `PermissionError` setup as `delete=False`, let file not be auto deleted. 3. Open log file as "utf-8" to align with Linux. 4. Process wrap difference for Windows. 5. Delete tmp file manually. Pull Request resolved: https://github.com/pytorch/pytorch/pull/134586 Approved by: https://github.com/jansel --- test/dynamo/test_logging.py | 28 +++++++++++++++++++++++----- 1 file changed, 23 insertions(+), 5 deletions(-) diff --git a/test/dynamo/test_logging.py b/test/dynamo/test_logging.py index a43951cc719f31..8e16443abb9153 100644 --- a/test/dynamo/test_logging.py +++ b/test/dynamo/test_logging.py @@ -9,7 +9,8 @@ import torch._dynamo.test_case import torch._dynamo.testing import torch.distributed as dist -from torch._dynamo.testing import skipIfNotPy311 +from torch._dynamo.testing import empty_line_normalizer, skipIfNotPy311 +from torch._dynamo.trace_rules import _as_posix_path from torch.nn.parallel import DistributedDataParallel as DDP from torch.testing._internal.common_utils import ( find_free_port, @@ -668,10 +669,18 @@ def fn(a): def test_logs_out(self): import tempfile - with tempfile.NamedTemporaryFile() as tmp: + with tempfile.NamedTemporaryFile(delete=False) as tmp: + file_path = _as_posix_path(tmp.name) + """ + NamedTemporaryFile will include a file open operation. + On Windowsm the file is opened by NamedTemporaryFile, the + following run_process_no_exception can't access a opened file. + And then, raise a PermissionError: [Errno 13] Permission denied: [file_path] + """ + tmp.close() env = dict(os.environ) env["TORCH_LOGS"] = "dynamo" - env["TORCH_LOGS_OUT"] = tmp.name + env["TORCH_LOGS_OUT"] = file_path stdout, stderr = self.run_process_no_exception( """\ import torch @@ -683,9 +692,18 @@ def fn(a): """, env=env, ) - with open(tmp.name) as fd: + with open( + file_path, encoding="utf-8" + ) as fd: # encoding file to UTF-8 for Windows. lines = fd.read() - self.assertEqual(lines, stderr.decode("utf-8")) + fd.close() + os.remove( + file_path + ) # Delete temp file manually, due to setup NamedTemporaryFile as delete=False. + self.assertEqual( # process wrap difference: /r/n on Windows, /n on posix. + empty_line_normalizer(lines), + empty_line_normalizer(stderr.decode("utf-8")), + ) # single record tests From bb608896e9ecad03613eb453846c4992197f78fe Mon Sep 17 00:00:00 2001 From: Ke Wen Date: Thu, 29 Aug 2024 09:09:51 -0700 Subject: [PATCH 0231/1018] [3/N] Set correct device to CUDA guards (#134357) In `collective()`, `pointToPoint()` and `collectiveCoalesced()`, CUDA guards were created with an unset (default) CUDA device. This is the reason for the IMA facing the NaN checker in issue https://github.com/pytorch/pytorch/issues/134062. With this fix, `torch.cuda.set_device(device)` is not needed to work around the IMA. Also refactored a couple places where the guard is created -- preferably we create the guard with a known device, rather than setting the device later. Pull Request resolved: https://github.com/pytorch/pytorch/pull/134357 Approved by: https://github.com/wconstab, https://github.com/shuqiangzhang ghstack dependencies: #134345 --- test/distributed/test_c10d_nccl.py | 19 +++++++++++++++++++ .../distributed/c10d/ProcessGroupNCCL.cpp | 18 ++++++++++-------- 2 files changed, 29 insertions(+), 8 deletions(-) diff --git a/test/distributed/test_c10d_nccl.py b/test/distributed/test_c10d_nccl.py index 55b785a3552b5a..3a575729cdff78 100644 --- a/test/distributed/test_c10d_nccl.py +++ b/test/distributed/test_c10d_nccl.py @@ -350,6 +350,7 @@ def test_close_pg(self): @parametrize("type", [torch.float16, torch.float32, torch.float64, torch.bfloat16]) @skip_if_rocm def test_nan_assert(self, type): + # Expecting a device-side error when NaN is detected os.environ["TORCH_NCCL_NAN_CHECK"] = "1" store = c10d.FileStore(self.file_name, self.world_size) pg = self._create_process_group_nccl(store, self.opts()) @@ -388,6 +389,24 @@ def test_nan_p2p(self): # reset env os.environ["TORCH_NCCL_NAN_CHECK"] = "0" + @requires_nccl() + @skip_if_lt_x_gpu(2) + def test_nan_check(self): + # Not expecting an error, NaN check should not make legit code fail + os.environ["TORCH_NCCL_NAN_CHECK"] = "1" + store = c10d.FileStore(self.file_name, self.world_size) + device = torch.device("cuda:%d" % self.rank) + c10d.init_process_group( + backend="nccl", store=store, rank=self.rank, world_size=self.world_size + ) + x = torch.ones((10,), dtype=torch.bfloat16, device=device) * self.rank + t = torch.ones(3, 4, dtype=torch.bfloat16, device=device) + c10d.broadcast(x, src=0) + c10d.all_reduce(t) + c10d.destroy_process_group() + # reset env + os.environ["TORCH_NCCL_NAN_CHECK"] = "0" + @requires_nccl() @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs") def test_destruct_before_terminate_pg(self): diff --git a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp index 65e9fea044932d..878bb7c8bec2d2 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp +++ b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp @@ -1149,6 +1149,10 @@ void ProcessGroupNCCL::abortCommsFromMap( at::cuda::OptionalCUDAGuard gpuGuard; at::DeviceIndex deviceIndex = getIndexFromDeviceKey(devName); if (deviceIndex >= 0) { + // For P2P comms, the deviceIndex could be -1 (invalid), as the keys in + // the map could be non deviceIndex, but rank to rank numbers. So we + // indeed need to check if deviceIndex >= 0 + // TODO: fix `getIndexFromDeviceKey` or fix `DeviceKey` gpuGuard.set_index(deviceIndex); } LOG(INFO) << logPrefix() << "ProcessGroupNCCL destroying ncclComm_ " @@ -2142,7 +2146,9 @@ std::shared_ptr ProcessGroupNCCL::getNCCLComm( bool batchP2P = ncclActiveGroupCounter_ > 0; bool singleP2POp = isP2POp(opType, batchP2P); - at::cuda::OptionalCUDAGuard gpuGuard; + // Get the device index + auto deviceIndex = device.index(); + at::cuda::OptionalCUDAGuard gpuGuard(device); // [Group Start/End Note] This is used to ensure that nccl communicator will // be created before communication primitives are called. Let's look at this @@ -2182,10 +2188,6 @@ std::shared_ptr ProcessGroupNCCL::getNCCLComm( rank = p2pRank; } - // Get the device index - auto deviceIndex = device.index(); - gpuGuard.set_index(deviceIndex); - #ifdef NCCL_HAS_COMM_SPLIT if (options_->split_from) { TORCH_CHECK( @@ -2695,7 +2697,7 @@ c10::intrusive_ptr ProcessGroupNCCL::collective( work->stashed_for_allocator_safety_->push_back(input); } - at::cuda::OptionalCUDAGuard gpuGuard; + at::cuda::OptionalCUDAGuard gpuGuard(device); if (nanCheck) { checkForNan(input, ncclStream); @@ -2860,7 +2862,7 @@ c10::intrusive_ptr ProcessGroupNCCL::collectiveCoalesced( std::make_shared>(inputs); } - at::cuda::OptionalCUDAGuard gpuGuard; + at::cuda::OptionalCUDAGuard gpuGuard(device); // Start event should only be recorded before the ncclGroupStart() (which // happens inside AutoNcclGroup guard below) @@ -3128,7 +3130,7 @@ c10::intrusive_ptr ProcessGroupNCCL::pointToPoint( } // is gpuGuard needed for the if block below, or can i swap them - at::cuda::OptionalCUDAGuard gpuGuard; + at::cuda::OptionalCUDAGuard gpuGuard(device); // Only check for NaN for send ops, for recv ops `tensor` can be a random // placeholder From 50170de1b796a04d48592fdbbe9f3ec799664d95 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Thu, 29 Aug 2024 16:39:19 +0000 Subject: [PATCH 0232/1018] Revert "[reland][dtensor][MTPG] make sharding prop lru cache not shared among threads (#134509)" This reverts commit ba5aec88c678fe4b9ad101602c29726724f56e21. Reverted https://github.com/pytorch/pytorch/pull/134509 on behalf of https://github.com/ZainRizvi due to Sorry but this fails internally. For details see D61953754 ([comment](https://github.com/pytorch/pytorch/pull/134509#issuecomment-2318323161)) --- test/distributed/_tensor/test_dtensor_ops.py | 2 ++ torch/distributed/_tensor/_sharding_prop.py | 20 ++------------------ 2 files changed, 4 insertions(+), 18 deletions(-) diff --git a/test/distributed/_tensor/test_dtensor_ops.py b/test/distributed/_tensor/test_dtensor_ops.py index 532bd3facae554..d3377bd47af97b 100644 --- a/test/distributed/_tensor/test_dtensor_ops.py +++ b/test/distributed/_tensor/test_dtensor_ops.py @@ -304,6 +304,7 @@ def wrapped(fn): xfail("nn.functional.elu"), xfail("nn.functional.fractional_max_pool2d"), xfail("nn.functional.fractional_max_pool3d"), + xfail("nn.functional.gaussian_nll_loss"), xfail("nn.functional.glu"), xfail("nn.functional.grid_sample"), xfail("nn.functional.group_norm"), @@ -349,6 +350,7 @@ def wrapped(fn): xfail("nn.functional.pdist"), xfail("nn.functional.pixel_shuffle"), xfail("nn.functional.pixel_unshuffle"), + xfail("nn.functional.poisson_nll_loss"), xfail("nn.functional.prelu"), xfail("nn.functional.relu6"), xfail("nn.functional.rrelu"), diff --git a/torch/distributed/_tensor/_sharding_prop.py b/torch/distributed/_tensor/_sharding_prop.py index c8c8ddb9817c1c..aefb4e41c94479 100644 --- a/torch/distributed/_tensor/_sharding_prop.py +++ b/torch/distributed/_tensor/_sharding_prop.py @@ -1,5 +1,4 @@ # mypy: allow-untyped-defs -import threading from functools import lru_cache from itertools import chain from typing import Callable, cast, Dict, List, Optional, Sequence, Tuple, Union @@ -38,17 +37,6 @@ def _length(obj) -> int: return len(obj) -class LocalLRUCache(threading.local): - def __init__(self, user_function: Callable) -> None: - self.cache = lru_cache(None)(user_function) - - def __call__(self, *args, **kwargs) -> object: - return self.cache(*args, **kwargs) - - def cache_info(self): - return self.cache.cache_info() - - class ShardingPropagator: def __init__(self) -> None: self.op_to_rules: Dict[OpOverload, Callable[[OpSchema], OutputSharding]] = {} @@ -58,9 +46,7 @@ def __init__(self) -> None: ] = {} # op map to save static argnum to decide to reuse sharding prop cache or re-run sharding prop self.op_to_schema_info: Dict[OpOverload, RuntimeSchemaInfo] = {} - self.propagate_op_sharding = LocalLRUCache( - self.propagate_op_sharding_non_cached - ) + self.propagate_op_sharding = lru_cache(None)(self.propagate_op_sharding_non_cached) # type: ignore[method-assign] # op map to save indices of shape (and stride) args which may need to be modified in sharding prop self.op_to_shape_and_stride_idx: Dict[ OpOverload, Union[int, Tuple[int, int]] @@ -197,9 +183,7 @@ def propagate(self, op_info: OpInfo) -> None: if op_info.schema.has_symints: output_sharding = self.propagate_op_sharding_non_cached(op_info.schema) else: - output_sharding = cast( - OutputSharding, self.propagate_op_sharding(op_info.schema) - ) + output_sharding = self.propagate_op_sharding(op_info.schema) op_info.output_sharding = output_sharding def propagate_op_sharding_non_cached(self, op_schema: OpSchema) -> OutputSharding: From f220d45ef1b1d1ef586424168a51073120095676 Mon Sep 17 00:00:00 2001 From: Ke Wen Date: Thu, 29 Aug 2024 09:09:52 -0700 Subject: [PATCH 0233/1018] [4/N] Test NaN checker against broadcast (#134701) Pull Request resolved: https://github.com/pytorch/pytorch/pull/134701 Approved by: https://github.com/wconstab ghstack dependencies: #134345, #134357 --- test/distributed/test_c10d_nccl.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/test/distributed/test_c10d_nccl.py b/test/distributed/test_c10d_nccl.py index 3a575729cdff78..0fb69028f93fa5 100644 --- a/test/distributed/test_c10d_nccl.py +++ b/test/distributed/test_c10d_nccl.py @@ -369,7 +369,7 @@ def test_nan_assert(self, type): @requires_nccl() @skip_if_lt_x_gpu(2) - def test_nan_p2p(self): + def test_nan_rank_filter(self): # Putting NaN at recv buffer, program should not fail as NaN checker # should not check on receive buffer os.environ["TORCH_NCCL_NAN_CHECK"] = "1" @@ -379,11 +379,15 @@ def test_nan_p2p(self): backend="nccl", store=store, rank=self.rank, world_size=self.world_size ) t = torch.ones(3, 4, dtype=torch.bfloat16, device=device) + if self.rank != 0: + # Putting NaN at recv buffer + t[1, 1] = float("nan") + # Against broadcast + c10d.broadcast(t, 0) + # Against P2P if self.rank == 0: c10d.send(t, 1) elif self.rank == 1: - # Putting NaN at recv buffer - t[1, 1] = float("nan") c10d.recv(t, 0) c10d.destroy_process_group() # reset env From dc93755cfa3c90a9cb31f9cb42c304fea0cd53d9 Mon Sep 17 00:00:00 2001 From: Pian Pawakapan Date: Thu, 29 Aug 2024 17:06:29 +0000 Subject: [PATCH 0234/1018] [export] use single FQN for param_buffer_mapping (#134500) Fixes #133252 In strict mode, we have this routine for mapping traced parameters to their FQNs using tensor ids. Currently we assume there's at least 1 unique FQN for each traced parameter, but this seems to break with parameter reuse when call_module nodes are present. Adding a test case where this breaks. Fixes this by assigning the same FQN to all traced parameters with the same tensor id. This is fine because we return the original state_dict for the EP, and the unflattener has its own routine of handling aliasing: https://github.com/pytorch/pytorch/pull/125758 Pull Request resolved: https://github.com/pytorch/pytorch/pull/134500 Approved by: https://github.com/angelayi --- test/export/test_export.py | 12 ++++++++++++ torch/export/_trace.py | 18 ++++++------------ 2 files changed, 18 insertions(+), 12 deletions(-) diff --git a/test/export/test_export.py b/test/export/test_export.py index a109e324d3ab2b..4eb5cc9fa3a55b 100644 --- a/test/export/test_export.py +++ b/test/export/test_export.py @@ -3557,6 +3557,18 @@ def forward(self, x): ) ) + def test_use_embedding_twice(self): + class Foo(torch.nn.Module): + def __init__(self): + super().__init__() + self.embed = torch.nn.Embedding(4, 4) + + def forward(self, x): + return self.embed(x) + self.embed.weight[x] + + inputs = (torch.tensor([0, 1, 2, 3]),) + ep = export(Foo(), inputs) + def test_module_with_dict_container_inp_out(self): class MyLinear(torch.nn.Module): def __init__(self) -> None: diff --git a/torch/export/_trace.py b/torch/export/_trace.py index 66c34d1047a89d..1c963691058f47 100644 --- a/torch/export/_trace.py +++ b/torch/export/_trace.py @@ -312,18 +312,12 @@ def _get_param_buffer_mapping( of a traced module to what the original module contains. """ - param_lookup: Dict[int, List[str]] = {} - buffer_lookup: Dict[int, List[str]] = {} + param_lookup: Dict[int, str] = {} + buffer_lookup: Dict[int, str] = {} for name, param in original_module.named_parameters(remove_duplicate=False): - param_lookup.setdefault(id(param), []).append(name) + param_lookup[id(param)] = name for name, buffer in original_module.named_buffers(remove_duplicate=False): - buffer_lookup.setdefault(id(buffer), []).append(name) - - # reverse lists so FQN assignment is FIFO wrt model structure - for name, fqns in param_lookup.items(): - param_lookup[name] = fqns[::-1] - for name, fqns in buffer_lookup.items(): - buffer_lookup[name] = fqns[::-1] + buffer_lookup[id(buffer)] = name param_buffer_table: Dict[str, str] = {} for dynamo_name, dynamo_param in traced_module.named_parameters( @@ -331,14 +325,14 @@ def _get_param_buffer_mapping( ): assert dynamo_name not in param_buffer_table if id(dynamo_param) in param_lookup: - param_buffer_table[dynamo_name] = param_lookup[id(dynamo_param)].pop() + param_buffer_table[dynamo_name] = param_lookup[id(dynamo_param)] for dynamo_name, dynamo_buffer in traced_module.named_buffers( remove_duplicate=False ): assert dynamo_name not in param_buffer_table if id(dynamo_buffer) in buffer_lookup: - param_buffer_table[dynamo_name] = buffer_lookup[id(dynamo_buffer)].pop() + param_buffer_table[dynamo_name] = buffer_lookup[id(dynamo_buffer)] return param_buffer_table From 3ccde1c3e5903e66705c9a0506c25210c4f4232d Mon Sep 17 00:00:00 2001 From: chilli Date: Wed, 28 Aug 2024 19:39:44 -0700 Subject: [PATCH 0235/1018] Changed masked out rows logsumexp to be -inf and not zero (#134650) Pull Request resolved: https://github.com/pytorch/pytorch/pull/134650 Approved by: https://github.com/yanboliang, https://github.com/BoyuanFeng, https://github.com/drisspg --- test/inductor/test_flex_attention.py | 82 ++++++++++++++++++++++- test/inductor/test_flex_decoding.py | 2 +- torch/_higher_order_ops/flex_attention.py | 4 +- torch/_inductor/kernel/flex_attention.py | 4 +- torch/_inductor/kernel/flex_decoding.py | 16 ++--- 5 files changed, 95 insertions(+), 13 deletions(-) diff --git a/test/inductor/test_flex_attention.py b/test/inductor/test_flex_attention.py index 83ec40ac7036a2..4652166ceb62b5 100644 --- a/test/inductor/test_flex_attention.py +++ b/test/inductor/test_flex_attention.py @@ -1507,7 +1507,7 @@ def mask_mod(b, h, q, kv): ) out, lse = flex(query, key, value, block_mask=block_mask, return_lse=True) self.assertEqual(out[:, :, M:, :].sum(), 0) - self.assertTrue((lse[:, :, M:] == 0.0).all()) + self.assertTrue((lse[:, :, M:] == -float("inf")).all()) loss = out.sum() + lse.sum() loss.backward() @@ -1641,6 +1641,86 @@ def test_force_write_lse(self): torch.testing.assert_close(lse_eager, lse_compiled, atol=3e-3, rtol=0) + @supported_platform + @common_utils.parametrize("backend", ["flex_attention", "flex_decode", "eager"]) + def test_lse_masked_output(self, backend): + if backend == "flex_decode": + kernel_options = {"FORCE_USE_FLEX_ATTENTION": False} + flex_call = torch.compile(flex_attention, fullgraph=True) + elif backend == "flex_attention": + kernel_options = {"FORCE_USE_FLEX_ATTENTION": True} + flex_call = torch.compile(flex_attention, fullgraph=True) + else: + kernel_options = {} + flex_call = flex_attention + + N_CTX = 96 + SLIDING_WINDOW = 64 + make_tensor = functools.partial( + torch.randn, + (2, 2, N_CTX, 64), + device="cuda", + dtype=torch.float32, + requires_grad=True, + ) + + def sliding_window_causal(b, h, q_idx, kv_idx): + causal_mask = q_idx >= kv_idx + window_mask = q_idx - kv_idx <= SLIDING_WINDOW + return causal_mask & window_mask + + def global_causal(b, h, q_idx, kv_idx): + causal_mask = q_idx >= kv_idx + window_mask = q_idx - kv_idx > SLIDING_WINDOW + return causal_mask & window_mask + + sliding_window_causal = torch.nn.attention.flex_attention.create_block_mask( + sliding_window_causal, B=None, H=None, Q_LEN=N_CTX, KV_LEN=N_CTX + ) + global_causal = torch.nn.attention.flex_attention.create_block_mask( + global_causal, B=None, H=None, Q_LEN=N_CTX, KV_LEN=N_CTX + ) + + local_attn = functools.partial( + flex_call, + block_mask=sliding_window_causal, + return_lse=True, + kernel_options=kernel_options, + ) + global_attn = functools.partial( + flex_call, + block_mask=global_causal, + return_lse=True, + kernel_options=kernel_options, + ) + q, k, v = make_tensor(), make_tensor(), make_tensor() + gradOut = make_tensor(requires_grad=False) + + x_local, lse_local = local_attn(q, k, v) + x_global, lse_global = global_attn(q, k, v) + + max_lse = torch.maximum(lse_local, lse_global) + lse_global = lse_global - max_lse + lse_local = lse_local - max_lse + lse_global = torch.exp(lse_global) + lse_local = torch.exp(lse_local) + x = ((x_local * lse_local[..., None]) + (x_global * lse_global[..., None])) / ( + lse_global[..., None] + lse_local[..., None] + ) + x.backward(gradOut) + flex_q_grad, flex_k_grad, flex_v_grad = q.grad, k.grad, v.grad + q.grad = None + k.grad = None + v.grad = None + + out = torch.nn.functional.scaled_dot_product_attention(q, k, v, is_causal=True) + out.backward(gradOut) + + torch.testing.assert_close(x, out, atol=3e-3, rtol=2e-3) + torch.testing.assert_close(flex_q_grad, q.grad, atol=3e-3, rtol=2e-3) + torch.testing.assert_close(flex_k_grad, k.grad, atol=3e-3, rtol=2e-3) + torch.testing.assert_close(flex_v_grad, v.grad, atol=3e-3, rtol=2e-3) + @supported_platform def test_small_q_kv_len(self): make_tensor = functools.partial( diff --git a/test/inductor/test_flex_decoding.py b/test/inductor/test_flex_decoding.py index bc41ff2fe8b16a..4d22fd877ecabb 100644 --- a/test/inductor/test_flex_decoding.py +++ b/test/inductor/test_flex_decoding.py @@ -810,7 +810,7 @@ def mask_mod(b, h, q, kv): query, key, value, block_mask=block_mask, enable_gqa=True, return_lse=True ) self.assertEqual(out[:, :, M:, :].sum(), 0) - self.assertTrue((lse[:, :, M:] == 0.0).all()) + self.assertTrue((lse[:, :, M:] == -float("inf")).all()) loss = out.sum() + lse.sum() loss.backward() diff --git a/torch/_higher_order_ops/flex_attention.py b/torch/_higher_order_ops/flex_attention.py index d66f8f4ba54d84..4890fa499048c3 100644 --- a/torch/_higher_order_ops/flex_attention.py +++ b/torch/_higher_order_ops/flex_attention.py @@ -207,7 +207,7 @@ def math_attention( # Set fully masked rows' sumexp to 0.0 logsumexp = post_mod_scores.logsumexp(dim=-1) masked_rows = torch.all(post_mod_scores == -float("inf"), dim=-1) - logsumexp = torch.where(masked_rows, 0.0, logsumexp) + logsumexp = torch.where(masked_rows, -float("inf"), logsumexp) post_mod_scores = torch._safe_softmax(post_mod_scores, dim=-1) @@ -705,7 +705,9 @@ def sdpa_dense_backward( score_mod_other_buffers, mask_mod_other_buffers, ) + masked_out_rows = logsumexp == -float("inf") softmax_scores = torch.exp(post_mod_scores - logsumexp.unsqueeze(-1)) + softmax_scores = torch.where(masked_out_rows.unsqueeze(-1), 0, softmax_scores) grad_value = softmax_scores.to(query.dtype).transpose(-2, -1) @ grad_out diff --git a/torch/_inductor/kernel/flex_attention.py b/torch/_inductor/kernel/flex_attention.py index 24ba8ba1ec1191..5c87cc40a8e32a 100644 --- a/torch/_inductor/kernel/flex_attention.py +++ b/torch/_inductor/kernel/flex_attention.py @@ -314,8 +314,6 @@ def get_offset_for_next_block(loop_iter, col_indices, total_blocks, SPARSE_BLOCK # Li will be the sum(e^(-inf)) == 0.0 for masked out rows, mi will be -inf. # We set Li to 1.0 which will result in lse/out = 0.0 | after the log(li) + mi(0.0) step l_i = tl.where(l_i == 0.0, 1, l_i) - masked_out_rows = (m_i == float("-inf")) - m_i = tl.where(masked_out_rows, 0, m_i) acc = acc / l_i[:, None] idx_z = tl.program_id(1) // HQ @@ -1022,6 +1020,7 @@ def flex_attention_backward_grid( else: Di = tl.load(DELTA2 + offs_m2, mask=offs_m2 < Q_LEN) lse = tl.load(LSE2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) lse = lse[:, None] # ~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -1454,6 +1453,7 @@ def bwd_dkdv_block_mn( else: qT = tl.load(qT_ptrs, mask=offs_m1[None, :] < Q_LEN) lse = tl.load(LSE + offs_m1, mask=offs_m1 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) qkT = tl.dot(k, qT) if not PRESCALE_QK: qkT *= SM_SCALE diff --git a/torch/_inductor/kernel/flex_decoding.py b/torch/_inductor/kernel/flex_decoding.py index 870e6194aec1e1..8a5b7c60eafdda 100644 --- a/torch/_inductor/kernel/flex_decoding.py +++ b/torch/_inductor/kernel/flex_decoding.py @@ -210,12 +210,12 @@ def flex_decoding_grid(batch_size, kv_heads, gqa_group_size, n_keys, d_model, me block_n_last_valid = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) K_block_ptr = tl.make_block_ptr( - base=K + k_offset, - shape=(QK_HEAD_DIM, KV_LEN), # (d, N) - strides=(stride_kk, stride_kn), - offsets=(0, off_n), - block_shape=(QK_HEAD_DIM, BLOCK_N), - order=(0, 1) + base=K + k_offset, + shape=(QK_HEAD_DIM, KV_LEN), # (d, N) + strides=(stride_kk, stride_kn), + offsets=(0, off_n), + block_shape=(QK_HEAD_DIM, BLOCK_N), + order=(0, 1) ) V_block_ptr = tl.make_block_ptr( base=V + v_offset, @@ -277,7 +277,7 @@ def flex_decoding_grid(batch_size, kv_heads, gqa_group_size, n_keys, d_model, me idx_hq = off_hkv*G + off_g[:, None, None] idx_m = off_m[None, :, None] idx_d = offs_vd[None, None, :] - # TODO generalize and add proper mask support + mask = (idx_m < Q_LEN) acc = acc.reshape(G, BLOCK_M_PER_HQ, V_HEAD_DIM) {{store_output(("idx_z", "idx_t", "idx_hq", "idx_m", "idx_d"), "acc", "mask")}} @@ -546,8 +546,8 @@ def create_flex_decoding_kernel(*args, **kwargs): # See [Note] Handle fully masked out rows: # g_M Is the global max among split kv blocks. masked_rows = lowerings[aten.eq](g_M, -float("inf")) - g_M = lowerings[aten.where](masked_rows, 0.0, g_M) adj_M = lowerings[aten.sub](buf_M, g_M) + adj_M = lowerings[aten.where](masked_rows, 0, adj_M) alpha = lowerings[aten.exp2](adj_M) buf_L = lowerings[aten.mul](buf_L, alpha) From 4853785234401499f0f492adc06fa7ec67febbb6 Mon Sep 17 00:00:00 2001 From: Xintong Hu Date: Thu, 29 Aug 2024 18:32:04 +0000 Subject: [PATCH 0236/1018] [PT2] Fix node metadata setting in group_batch_fusion_aten (#134543) Summary: Current impl results in `meta` missing fields like`val`, use `FakeTensorProp` to update the information Differential Revision: D61832932 Pull Request resolved: https://github.com/pytorch/pytorch/pull/134543 Approved by: https://github.com/frank-wei --- torch/_inductor/fx_passes/pre_grad.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/torch/_inductor/fx_passes/pre_grad.py b/torch/_inductor/fx_passes/pre_grad.py index eef058ab0b5aa2..e9cc56b2b0f6b4 100644 --- a/torch/_inductor/fx_passes/pre_grad.py +++ b/torch/_inductor/fx_passes/pre_grad.py @@ -12,6 +12,7 @@ matches_module_pattern, replace_node_module, ) +from torch.fx.passes.fake_tensor_prop import FakeTensorProp from torch.fx.passes.graph_transform_observer import GraphTransformObserver from torch.fx.passes.shape_prop import ShapeProp from torch.nn import functional as F @@ -168,6 +169,10 @@ def shape_prop(mod) -> None: example_inputs, "[Pre grad(predispatch IR)] Apply group_batch_fusion", ) + # update node.meta after group batch fusion + FakeTensorProp(module=gm, mode=detect_fake_mode(example_inputs)).propagate( + *example_inputs + ) pass_execution_and_save( normalize_node_kwargs_pass, gm, From 5203340b2ba60ea1f490ced1a95e286b8550f937 Mon Sep 17 00:00:00 2001 From: Aaron Gokaslan Date: Thu, 29 Aug 2024 18:35:44 +0000 Subject: [PATCH 0237/1018] [BE][Ez]: Update ruff to 0.6.3 (#134769) Mostly bugfix release, updating because it fixes an edgecase with a rule we are using Pull Request resolved: https://github.com/pytorch/pytorch/pull/134769 Approved by: https://github.com/albanD --- .lintrunner.toml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.lintrunner.toml b/.lintrunner.toml index dafe6207a7c9a6..55849ff4695a03 100644 --- a/.lintrunner.toml +++ b/.lintrunner.toml @@ -1482,7 +1482,7 @@ init_command = [ 'black==23.12.1', 'usort==1.0.8.post1', 'isort==5.13.2', - 'ruff==0.6.0', # sync with RUFF + 'ruff==0.6.3', # sync with RUFF ] is_formatter = true @@ -1567,7 +1567,7 @@ init_command = [ 'python3', 'tools/linter/adapters/pip_init.py', '--dry-run={{DRYRUN}}', - 'ruff==0.6.0', # sync with PYFMT + 'ruff==0.6.3', # sync with PYFMT ] is_formatter = true From 650a84a1bbbf000477152b1f4af96425fe9974d5 Mon Sep 17 00:00:00 2001 From: Rachel Guo Date: Thu, 29 Aug 2024 18:38:45 +0000 Subject: [PATCH 0238/1018] [AOTI] Fix cosmetic indentation issue in cuda cpp wrapper codegen for DeferredCudaKernelLine/GridLine (#134705) Summary: Follow up fix for D61018114, D61800622 Increase indentation for `loadKernel` `launchKernel` and `Grid` lines. Test Plan: ``` TORCH_LOGS="+graph, inductor, +schedule, output_code" buck2 run -c fbcode.enable_gpu_sections=true -c fbcode.nvcc_arch=h100 @//mode/opt fbcode//caffe2/test/inductor:test_aot_inductor -- -r test_zero_grid_with_unbacked_symbols_abi_compatible_cuda ``` ``` TORCH_LOGS="+graph, inductor, +schedule, output_code" buck2 run -c fbcode.enable_gpu_sections=true -c fbcode.nvcc_arch=h100 @//mode/opt fbcode//caffe2/test/inductor:test_aot_inductor -- -r test_zero_grid_with_backed_symbols_abi_compatible_cuda ``` Differential Revision: D61927248 Pull Request resolved: https://github.com/pytorch/pytorch/pull/134705 Approved by: https://github.com/ColinPeppler --- torch/_inductor/codegen/cpp_wrapper_cuda.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/torch/_inductor/codegen/cpp_wrapper_cuda.py b/torch/_inductor/codegen/cpp_wrapper_cuda.py index cc2baaa72c2c21..5aabf6bd4c1d7d 100644 --- a/torch/_inductor/codegen/cpp_wrapper_cuda.py +++ b/torch/_inductor/codegen/cpp_wrapper_cuda.py @@ -145,7 +145,7 @@ def __call__(self): grid_args_str = ", ".join( [cexpr(V.graph.sizevars.simplify(item)) for item in grid] ) - return f"Grid {self.grid_var} = Grid({grid_args_str});" + return f" Grid {self.grid_var} = Grid({grid_args_str});" def _new_line(self, line): return DeferredCudaGridLine( @@ -251,9 +251,13 @@ def generate_load_kernel_once( self.writeline( DeferredCudaKernelLine( kernel_name, - kernel_var_name + """ = loadKernel("%s", "%s", %s, this->cubin_dir_);""" + """ """ + + kernel_var_name + + """ = loadKernel("%s", "%s", %s, this->cubin_dir_);""" if V.graph.aot_mode - else kernel_var_name + """ = loadKernel("%s", "%s", %s);""", + else """ """ + + kernel_var_name + + """ = loadKernel("%s", "%s", %s);""", keys, ) ) @@ -404,7 +408,7 @@ def generate_kernel_call( self.writeline( DeferredCudaKernelLine( kernel_name, - r"launchKernel({}, {}, {}, {}, %s, %s, {}, {});".format( + r" launchKernel({}, {}, {}, {}, %s, %s, {}, {});".format( kernel_var_name, f"{grid_var}.grid_x", f"{grid_var}.grid_y", From d42a199bf99fa0204f9289cc30f96fb9a5893e98 Mon Sep 17 00:00:00 2001 From: zdevito Date: Thu, 29 Aug 2024 09:19:05 -0700 Subject: [PATCH 0239/1018] expandable_segments <-> other allocator options (#134338) Previously setting garbage_collection_threshold or max_split_size_mb along with expandable_segments:True could cause the allocator to hit assert failures when running nearly out of memory. This PR ensures garbage_collection and max_split freeing do not accidentally try to release expandable segments. Pull Request resolved: https://github.com/pytorch/pytorch/pull/134338 Approved by: https://github.com/ezyang --- c10/cuda/CUDACachingAllocator.cpp | 16 +++++---- test/test_cuda.py | 56 +++++++++++++++++++++++++++++++ 2 files changed, 66 insertions(+), 6 deletions(-) diff --git a/c10/cuda/CUDACachingAllocator.cpp b/c10/cuda/CUDACachingAllocator.cpp index bfab4507b62b2c..fa26a9f3e0cca7 100644 --- a/c10/cuda/CUDACachingAllocator.cpp +++ b/c10/cuda/CUDACachingAllocator.cpp @@ -2611,7 +2611,7 @@ class DeviceCachingAllocator { while (it != large_blocks.blocks.end()) { Block* block = *it; ++it; - if (!block->is_split() && + if (!block->is_split() && !block->expandable_segment_ && static_cast(block->gc_count()) >= age_threshold) { block_freed = true; gc_reclaimed += block->size; @@ -2754,7 +2754,8 @@ class DeviceCachingAllocator { ? CUDAAllocatorConfig::max_split_size() : key.size; auto it = pool.blocks.lower_bound(&key); - if (it == pool.blocks.end() || (*it)->stream != p.stream()) { + if (it == pool.blocks.end() || (*it)->stream != p.stream() || + (*it)->expandable_segment_) { // No single block is large enough; free multiple oversize blocks, // starting with the largest if (it == pool.blocks.begin()) @@ -2766,12 +2767,15 @@ class DeviceCachingAllocator { ((*it)->size >= CUDAAllocatorConfig::max_split_size()) && ((*it)->stream == p.stream())) { auto cur = it; - totalReleased += (*it)->size; - if (it != pool.blocks.begin()) { + bool is_first = cur == pool.blocks.begin(); + if (!is_first) { --it; + } + if (!(*cur)->expandable_segment_) { release_block(*cur, context); - } else { - release_block(*cur, context); + totalReleased += (*cur)->size; + } + if (is_first) { break; } } diff --git a/test/test_cuda.py b/test/test_cuda.py index b2f1ef3b53cf25..f53ae5f61da150 100644 --- a/test/test_cuda.py +++ b/test/test_cuda.py @@ -4098,6 +4098,62 @@ def foo(): finally: torch.cuda.memory._record_memory_history(None) + def test_max_split_expandable(self): + torch.cuda.memory.empty_cache() + mb = 1024 * 1024 + _, all_memory = torch.cuda.memory.mem_get_info() + total_allowed = 120 * mb + fraction_allowed = total_allowed / all_memory + assert int(fraction_allowed * all_memory) == total_allowed + torch.cuda.memory.set_per_process_memory_fraction(fraction_allowed) + + def alloc(n): + return torch.ones(n * mb, dtype=torch.int8, device="cuda") + + torch.cuda.memory._set_allocator_settings( + "expandable_segments:False,max_split_size_mb:40" + ) + a = alloc(40) + torch.cuda.memory._set_allocator_settings( + "expandable_segments:True,max_split_size_mb:40" + ) + b = alloc(40) + torch.cuda.memory._set_allocator_settings( + "expandable_segments:False,max_split_size_mb:40" + ) + c = alloc(40) + with self.assertRaises(torch.OutOfMemoryError): + alloc(40) + del a, b, c + # force release_cached_blocks to run with some expandable segments in the free list + alloc(120) + + def test_garbage_collect_expandable(self): + torch.cuda.memory.empty_cache() + mb = 1024 * 1024 + _, all_memory = torch.cuda.memory.mem_get_info() + total_allowed = 120 * mb + fraction_allowed = total_allowed / all_memory + assert int(fraction_allowed * all_memory) == total_allowed + torch.cuda.memory.set_per_process_memory_fraction(fraction_allowed) + + def alloc(n): + return torch.ones(n * mb, dtype=torch.int8, device="cuda") + + torch.cuda.memory._set_allocator_settings( + "expandable_segments:False,garbage_collection_threshold:0.5" + ) + a = alloc(40) + torch.cuda.memory._set_allocator_settings( + "expandable_segments:True,garbage_collection_threshold:0.5" + ) + b = alloc(40) + del a, b + # causes GC to run. The expandable segment block will be split + # so GC would not attempt to free it anyway, but this at least makes sure + # expandable_segment blocks can be in the free list when this is called. + alloc(80) + def test_allocator_settings(self): def power2_div(size, div_factor): pow2 = 1 From 656474474c12c9fa22095ee3993cb971d4be2148 Mon Sep 17 00:00:00 2001 From: Ke Wen Date: Thu, 29 Aug 2024 09:09:52 -0700 Subject: [PATCH 0240/1018] [5/N] Reconcile barrier and NaN checker (#134707) By using a zeros() tensor instead of empty() tensor. Pull Request resolved: https://github.com/pytorch/pytorch/pull/134707 Approved by: https://github.com/shuqiangzhang, https://github.com/wconstab ghstack dependencies: #134345, #134357, #134701 --- test/distributed/test_c10d_nccl.py | 1 + torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp | 4 +++- 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/test/distributed/test_c10d_nccl.py b/test/distributed/test_c10d_nccl.py index 0fb69028f93fa5..36c3e28e16065a 100644 --- a/test/distributed/test_c10d_nccl.py +++ b/test/distributed/test_c10d_nccl.py @@ -407,6 +407,7 @@ def test_nan_check(self): t = torch.ones(3, 4, dtype=torch.bfloat16, device=device) c10d.broadcast(x, src=0) c10d.all_reduce(t) + c10d.barrier() c10d.destroy_process_group() # reset env os.environ["TORCH_NCCL_NAN_CHECK"] = "0" diff --git a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp index 878bb7c8bec2d2..3fe8aab1764aa6 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp +++ b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp @@ -4064,8 +4064,10 @@ c10::intrusive_ptr ProcessGroupNCCL::barrier(const BarrierOptions& opts) { auto barDevice = at::Device(at::DeviceType::CUDA, barDevIdx); // Create a dummy tensor on the device + // Note: we use zeros() instead of empty() to prevent barrier from triggering + // alarm when NaN checker is enabled. at::Tensor barrierTensor = - at::empty({1}, at::TensorOptions().device(barDevice).dtype(at::kFloat)); + at::zeros({1}, at::TensorOptions().device(barDevice).dtype(at::kFloat)); // All reduce to achieve the barrier auto work = allreduce_impl(barrierTensor); From 31c5ec1877c74449ff3d915bdd597fc31669c7fb Mon Sep 17 00:00:00 2001 From: Brian Vaughan Date: Wed, 28 Aug 2024 16:05:26 -0400 Subject: [PATCH 0241/1018] register all entry_point backends on first attempt (#132546) fixes: https://github.com/pytorch/pytorch/issues/131360 Pull Request resolved: https://github.com/pytorch/pytorch/pull/132546 Approved by: https://github.com/jansel --- test/dynamo/test_backends.py | 36 ++++++++++++++++++ torch/_dynamo/backends/registry.py | 60 +++++++++++++++++------------- 2 files changed, 71 insertions(+), 25 deletions(-) diff --git a/test/dynamo/test_backends.py b/test/dynamo/test_backends.py index d2aacd15f5e3d0..e9d94d4f67be06 100644 --- a/test/dynamo/test_backends.py +++ b/test/dynamo/test_backends.py @@ -1,8 +1,11 @@ # Owner(s): ["module: dynamo"] +import sys import unittest +from unittest.mock import MagicMock, patch import torch import torch._dynamo +import torch._dynamo.backends import torch._dynamo.test_case from torch._dynamo.backends.debugging import ExplainWithBackend from torch._dynamo.backends.onnxrt import has_onnxruntime @@ -294,6 +297,39 @@ def f(x): opt_f(torch.randn(3, 3)) self.assertTrue(backend_run) + def test_lookup_custom_backend(self): + from torch._dynamo import list_backends + + backends_group = "torch_dynamo_backends" + name = "mycustombackend" + + mock_3_9 = MagicMock() + mock_3_9.load.return_value = lambda: "mocked 3.9" + mock_3_9.name = name + + mock_3_10 = MagicMock() + mock_3_10.load.return_value = lambda: "mocked 3.10" + + def mock_eps(group=None): + if sys.version_info < (3, 10): + return {backends_group: [mock_3_9]} + else: + assert group == backends_group, group + mock_group = MagicMock() + mock_group.names = [name] + mock_group[name] = mock_3_10 + # mock_group[name].load.return_value = lambda: "mocked 3.10" + return mock_group + + with patch("importlib.metadata.entry_points", mock_eps): + from torch._dynamo.backends import registry + + registry._lazy_import.cache_clear() + registry._discover_entrypoint_backends.cache_clear() + + backends = list_backends() + assert name in backends, (name, backends) + if __name__ == "__main__": from torch._dynamo.test_case import run_tests diff --git a/torch/_dynamo/backends/registry.py b/torch/_dynamo/backends/registry.py index e3538a4f6730f6..749d11937ea330 100644 --- a/torch/_dynamo/backends/registry.py +++ b/torch/_dynamo/backends/registry.py @@ -1,13 +1,18 @@ # mypy: ignore-errors import functools +import logging import sys +from importlib.metadata import EntryPoint from typing import Callable, Dict, List, Optional, Protocol, Sequence, Tuple import torch from torch import fx +log = logging.getLogger(__name__) + + class CompiledFn(Protocol): def __call__(self, *args: torch.Tensor) -> Tuple[torch.Tensor, ...]: ... @@ -15,7 +20,8 @@ def __call__(self, *args: torch.Tensor) -> Tuple[torch.Tensor, ...]: CompilerFn = Callable[[fx.GraphModule, List[torch.Tensor]], CompiledFn] -_BACKENDS: Dict[str, CompilerFn] = {} +_BACKENDS: Dict[str, Optional[EntryPoint]] = {} +_COMPILER_FNS: Dict[str, CompilerFn] = {} def register_backend( @@ -39,8 +45,10 @@ def register_backend( return functools.partial(register_backend, name=name, tags=tags) assert callable(compiler_fn) name = name or compiler_fn.__name__ - assert name not in _BACKENDS, f"duplicate name: {name}" - _BACKENDS[name] = compiler_fn + assert name not in _COMPILER_FNS, f"duplicate name: {name}" + if compiler_fn not in _BACKENDS: + _BACKENDS[name] = None + _COMPILER_FNS[name] = compiler_fn compiler_fn._tags = tuple(tags) return compiler_fn @@ -56,13 +64,15 @@ def lookup_backend(compiler_fn): if isinstance(compiler_fn, str): if compiler_fn not in _BACKENDS: _lazy_import() - if compiler_fn not in _BACKENDS: - _lazy_import_entry_point(compiler_fn) if compiler_fn not in _BACKENDS: from ..exc import InvalidBackend raise InvalidBackend(name=compiler_fn) - compiler_fn = _BACKENDS[compiler_fn] + + if compiler_fn not in _COMPILER_FNS: + entry_point = _BACKENDS[compiler_fn] + register_backend(compiler_fn=entry_point.load(), name=compiler_fn) + compiler_fn = _COMPILER_FNS[compiler_fn] return compiler_fn @@ -74,13 +84,14 @@ def list_backends(exclude_tags=("debug", "experimental")) -> List[str]: """ _lazy_import() exclude_tags = set(exclude_tags or ()) - return sorted( - [ - name - for name, backend in _BACKENDS.items() - if not exclude_tags.intersection(backend._tags) - ] - ) + + backends = [ + name + for name in _BACKENDS.keys() + if name not in _COMPILER_FNS + or not exclude_tags.intersection(_COMPILER_FNS[name]._tags) + ] + return sorted(backends) @functools.lru_cache(None) @@ -94,22 +105,21 @@ def _lazy_import(): assert dynamo_minifier_backend is not None + _discover_entrypoint_backends() + @functools.lru_cache(None) -def _lazy_import_entry_point(backend_name: str): +def _discover_entrypoint_backends(): + # importing here so it will pick up the mocked version in test_backends.py from importlib.metadata import entry_points - compiler_fn = None group_name = "torch_dynamo_backends" if sys.version_info < (3, 10): - backend_eps = entry_points() - eps = [ep for ep in backend_eps.get(group_name, ()) if ep.name == backend_name] - if len(eps) > 0: - compiler_fn = eps[0].load() + eps = entry_points() + eps = eps[group_name] if group_name in eps else [] + eps = {ep.name: ep for ep in eps} else: - backend_eps = entry_points(group=group_name) - if backend_name in backend_eps.names: - compiler_fn = backend_eps[backend_name].load() - - if compiler_fn is not None and backend_name not in list_backends(()): - register_backend(compiler_fn=compiler_fn, name=backend_name) + eps = entry_points(group=group_name) + eps = {name: eps[name] for name in eps.names} + for backend_name in eps: + _BACKENDS[backend_name] = eps[backend_name] From 66acfd4a4f166efb0a9fa2aa07b7609c76271bac Mon Sep 17 00:00:00 2001 From: Jason Ansel Date: Wed, 28 Aug 2024 21:47:29 -0700 Subject: [PATCH 0242/1018] [inductor] Fix error in debug_str_extra (#134747) Pull Request resolved: https://github.com/pytorch/pytorch/pull/134747 Approved by: https://github.com/Skylion007, https://github.com/shunting314 --- torch/_inductor/scheduler.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/torch/_inductor/scheduler.py b/torch/_inductor/scheduler.py index 97953f7866d016..38142f42495273 100644 --- a/torch/_inductor/scheduler.py +++ b/torch/_inductor/scheduler.py @@ -1044,10 +1044,10 @@ def debug_str_extra(self) -> str: for i, node in enumerate(self.snodes) ] node = self.snodes[0].node - assert node is not None - device = node.get_device() - if ir.is_triton(device): - lines.extend(debug_triton_code(self)) + if node is not None: + device = node.get_device() + if ir.is_triton(device): + lines.extend(debug_triton_code(self)) return textwrap.indent("\n".join(lines).rstrip(), " ") From fcda8e50d1314bdd6af8bfac90f214e84a16d3a9 Mon Sep 17 00:00:00 2001 From: atalman Date: Thu, 29 Aug 2024 19:15:59 +0000 Subject: [PATCH 0243/1018] Move py 3.8->3.9 pull, trunk, inductor, prerioric CI tests (#133624) Part of Deprecation of python 3.8 and moving to 3.9. Related to: https://github.com/pytorch/pytorch/issues/120718 Except XPU and ROCM jobs Pull Request resolved: https://github.com/pytorch/pytorch/pull/133624 Approved by: https://github.com/Skylion007, https://github.com/malfet, https://github.com/ZainRizvi --- .ci/docker/build.sh | 24 ++-- .ci/docker/common/install_onnx.sh | 2 +- .ci/docker/requirements-ci.txt | 2 +- .ci/pytorch/test.sh | 2 +- .github/merge_rules.yaml | 4 +- .github/workflows/docker-builds.yml | 8 +- .../inductor-perf-test-nightly-x86.yml | 32 ++--- .github/workflows/inductor.yml | 20 +-- .github/workflows/nightly.yml | 6 +- .github/workflows/periodic.yml | 20 +-- .github/workflows/pull.yml | 114 +++++++++--------- .github/workflows/slow.yml | 20 +-- scripts/compile_tests/download_reports.py | 10 +- scripts/compile_tests/update_failures.py | 4 +- 14 files changed, 134 insertions(+), 134 deletions(-) diff --git a/.ci/docker/build.sh b/.ci/docker/build.sh index decc0caeac3514..61f7acfe7633fe 100755 --- a/.ci/docker/build.sh +++ b/.ci/docker/build.sh @@ -236,7 +236,7 @@ case "$image" in TRITON=yes ;; pytorch-linux-focal-py3-clang10-onnx) - ANACONDA_PYTHON_VERSION=3.8 + ANACONDA_PYTHON_VERSION=3.9 CLANG_VERSION=10 PROTOBUF=yes DB=yes @@ -245,7 +245,7 @@ case "$image" in ONNX=yes ;; pytorch-linux-focal-py3-clang9-android-ndk-r21e) - ANACONDA_PYTHON_VERSION=3.8 + ANACONDA_PYTHON_VERSION=3.9 CLANG_VERSION=9 LLVMDEV=yes PROTOBUF=yes @@ -254,8 +254,8 @@ case "$image" in GRADLE_VERSION=6.8.3 NINJA_VERSION=1.9.0 ;; - pytorch-linux-focal-py3.8-clang10) - ANACONDA_PYTHON_VERSION=3.8 + pytorch-linux-focal-py3.9-clang10) + ANACONDA_PYTHON_VERSION=3.9 CLANG_VERSION=10 PROTOBUF=yes DB=yes @@ -276,8 +276,8 @@ case "$image" in CONDA_CMAKE=yes TRITON=yes ;; - pytorch-linux-focal-py3.8-gcc9) - ANACONDA_PYTHON_VERSION=3.8 + pytorch-linux-focal-py3.9-gcc9) + ANACONDA_PYTHON_VERSION=3.9 GCC_VERSION=9 PROTOBUF=yes DB=yes @@ -318,8 +318,8 @@ case "$image" in CONDA_CMAKE=yes TRITON=yes ;; - pytorch-linux-jammy-py3.8-gcc11-inductor-benchmarks) - ANACONDA_PYTHON_VERSION=3.8 + pytorch-linux-jammy-py3.9-gcc11-inductor-benchmarks) + ANACONDA_PYTHON_VERSION=3.9 GCC_VERSION=11 PROTOBUF=yes DB=yes @@ -330,8 +330,8 @@ case "$image" in DOCS=yes INDUCTOR_BENCHMARKS=yes ;; - pytorch-linux-jammy-cuda11.8-cudnn9-py3.8-clang12) - ANACONDA_PYTHON_VERSION=3.8 + pytorch-linux-jammy-cuda11.8-cudnn9-py3.9-clang12) + ANACONDA_PYTHON_VERSION=3.9 CUDA_VERSION=11.8 CUDNN_VERSION=9 CLANG_VERSION=12 @@ -355,8 +355,8 @@ case "$image" in CONDA_CMAKE=yes VISION=yes ;; - pytorch-linux-jammy-py3.8-gcc11) - ANACONDA_PYTHON_VERSION=3.8 + pytorch-linux-jammy-py3.9-gcc11) + ANACONDA_PYTHON_VERSION=3.9 GCC_VERSION=11 PROTOBUF=yes DB=yes diff --git a/.ci/docker/common/install_onnx.sh b/.ci/docker/common/install_onnx.sh index 1d384233163d8d..65907a99f2572a 100755 --- a/.ci/docker/common/install_onnx.sh +++ b/.ci/docker/common/install_onnx.sh @@ -15,7 +15,7 @@ pip_install \ flatbuffers==2.0 \ mock==5.0.1 \ ninja==1.10.2 \ - networkx==2.0 \ + networkx==2.5 \ numpy==1.24.2 # ONNXRuntime should be installed before installing diff --git a/.ci/docker/requirements-ci.txt b/.ci/docker/requirements-ci.txt index 6e0445b1a74a28..9a9b568dd7e2b9 100644 --- a/.ci/docker/requirements-ci.txt +++ b/.ci/docker/requirements-ci.txt @@ -104,7 +104,7 @@ networkx==2.8.8 #test that import: run_test.py, test_cpp_extensions_aot.py,test_determination.py numba==0.49.0 ; python_version < "3.9" -numba==0.54.1 ; python_version == "3.9" +numba==0.55.2 ; python_version == "3.9" numba==0.55.2 ; python_version == "3.10" #Description: Just-In-Time Compiler for Numerical Functions #Pinned versions: 0.54.1, 0.49.0, <=0.49.1 diff --git a/.ci/pytorch/test.sh b/.ci/pytorch/test.sh index 169cb8ce5f8943..53970cf346658c 100755 --- a/.ci/pytorch/test.sh +++ b/.ci/pytorch/test.sh @@ -1479,7 +1479,7 @@ elif [[ "${TEST_CONFIG}" == *inductor* ]]; then install_torchvision test_inductor_shard "${SHARD_NUMBER}" if [[ "${SHARD_NUMBER}" == 1 ]]; then - if [[ "${BUILD_ENVIRONMENT}" != linux-jammy-py3.8-gcc11-build ]]; then + if [[ "${BUILD_ENVIRONMENT}" != linux-jammy-py3.9-gcc11-build ]]; then test_inductor_distributed fi fi diff --git a/.github/merge_rules.yaml b/.github/merge_rules.yaml index 7d0976f2bdddc1..bf80880e604144 100644 --- a/.github/merge_rules.yaml +++ b/.github/merge_rules.yaml @@ -107,8 +107,8 @@ mandatory_checks_name: - EasyCLA - Lint - - pull / linux-focal-py3_8-clang9-xla / build - - pull / linux-focal-py3_8-clang9-xla / test (xla, 1, 1, linux.12xlarge) + - pull / linux-focal-py3_9-clang9-xla / build + - pull / linux-focal-py3_9-clang9-xla / test (xla, 1, 1, linux.12xlarge) - name: Documentation patterns: diff --git a/.github/workflows/docker-builds.yml b/.github/workflows/docker-builds.yml index 5967032e876b47..8b46e314c20da1 100644 --- a/.github/workflows/docker-builds.yml +++ b/.github/workflows/docker-builds.yml @@ -45,15 +45,15 @@ jobs: pytorch-linux-focal-cuda12.1-cudnn9-py3-gcc9-inductor-benchmarks, pytorch-linux-focal-cuda12.1-cudnn9-py3.12-gcc9-inductor-benchmarks, pytorch-linux-focal-cuda11.8-cudnn9-py3-gcc9, - pytorch-linux-focal-py3.8-clang10, + pytorch-linux-focal-py3.9-clang10, pytorch-linux-focal-py3.11-clang10, pytorch-linux-focal-py3.12-clang10, pytorch-linux-focal-rocm-n-1-py3, pytorch-linux-focal-rocm-n-py3, - pytorch-linux-jammy-cuda11.8-cudnn9-py3.8-clang12, + pytorch-linux-jammy-cuda11.8-cudnn9-py3.9-clang12, pytorch-linux-focal-py3-clang9-android-ndk-r21e, - pytorch-linux-jammy-py3.8-gcc11, - pytorch-linux-jammy-py3.8-gcc11-inductor-benchmarks, + pytorch-linux-jammy-py3.9-gcc11, + pytorch-linux-jammy-py3.9-gcc11-inductor-benchmarks, pytorch-linux-jammy-py3.12-halide, pytorch-linux-jammy-xpu-2024.0-py3, pytorch-linux-jammy-py3-clang15-asan, diff --git a/.github/workflows/inductor-perf-test-nightly-x86.yml b/.github/workflows/inductor-perf-test-nightly-x86.yml index 6011a133b257a9..27a1ce5db4554a 100644 --- a/.github/workflows/inductor-perf-test-nightly-x86.yml +++ b/.github/workflows/inductor-perf-test-nightly-x86.yml @@ -48,12 +48,12 @@ concurrency: permissions: read-all jobs: - linux-jammy-cpu-py3_8-gcc11-inductor-build: - name: linux-jammy-cpu-py3.8-gcc11-inductor + linux-jammy-cpu-py3_9-gcc11-inductor-build: + name: linux-jammy-cpu-py3.9-gcc11-inductor uses: ./.github/workflows/_linux-build.yml with: - build-environment: linux-jammy-py3.8-gcc11-build - docker-image-name: pytorch-linux-jammy-py3.8-gcc11-inductor-benchmarks + build-environment: linux-jammy-py3.9-gcc11-build + docker-image-name: pytorch-linux-jammy-py3.9-gcc11-inductor-benchmarks test-matrix: | { include: [ { config: "inductor_huggingface_perf_cpu_x86", shard: 1, num_shards: 3, runner: "linux.24xl.spr-metal" }, @@ -74,32 +74,32 @@ jobs: HUGGING_FACE_HUB_TOKEN: ${{ secrets.HUGGING_FACE_HUB_TOKEN }} - linux-jammy-cpu-py3_8-gcc11-inductor-test-nightly: - name: linux-jammy-cpu-py3.8-gcc11-inductor + linux-jammy-cpu-py3_9-gcc11-inductor-test-nightly: + name: linux-jammy-cpu-py3.9-gcc11-inductor uses: ./.github/workflows/_linux-test.yml - needs: linux-jammy-cpu-py3_8-gcc11-inductor-build + needs: linux-jammy-cpu-py3_9-gcc11-inductor-build if: github.event.schedule == '0 7 * * *' with: - build-environment: linux-jammy-py3.8-gcc11-build + build-environment: linux-jammy-py3.9-gcc11-build dashboard-tag: training-false-inference-true-default-true-dynamic-true-aotinductor-true - docker-image: ${{ needs.linux-jammy-cpu-py3_8-gcc11-inductor-build.outputs.docker-image }} - test-matrix: ${{ needs.linux-jammy-cpu-py3_8-gcc11-inductor-build.outputs.test-matrix }} + docker-image: ${{ needs.linux-jammy-cpu-py3_9-gcc11-inductor-build.outputs.docker-image }} + test-matrix: ${{ needs.linux-jammy-cpu-py3_9-gcc11-inductor-build.outputs.test-matrix }} use-gha: anything-non-empty-to-use-gha timeout-minutes: 720 secrets: HUGGING_FACE_HUB_TOKEN: ${{ secrets.HUGGING_FACE_HUB_TOKEN }} - linux-jammy-cpu-py3_8-gcc11-inductor-test: - name: linux-jammy-cpu-py3.8-gcc11-inductor + linux-jammy-cpu-py3_9-gcc11-inductor-test: + name: linux-jammy-cpu-py3.9-gcc11-inductor uses: ./.github/workflows/_linux-test.yml - needs: linux-jammy-cpu-py3_8-gcc11-inductor-build + needs: linux-jammy-cpu-py3_9-gcc11-inductor-build if: github.event_name == 'workflow_dispatch' with: - build-environment: linux-jammy-py3.8-gcc11-build + build-environment: linux-jammy-py3.9-gcc11-build dashboard-tag: training-${{ inputs.training }}-inference-${{ inputs.inference }}-default-${{ inputs.default }}-dynamic-${{ inputs.dynamic }}-aotinductor-${{ inputs.aotinductor }} - docker-image: ${{ needs.linux-jammy-cpu-py3_8-gcc11-inductor-build.outputs.docker-image }} - test-matrix: ${{ needs.linux-jammy-cpu-py3_8-gcc11-inductor-build.outputs.test-matrix }} + docker-image: ${{ needs.linux-jammy-cpu-py3_9-gcc11-inductor-build.outputs.docker-image }} + test-matrix: ${{ needs.linux-jammy-cpu-py3_9-gcc11-inductor-build.outputs.test-matrix }} use-gha: anything-non-empty-to-use-gha timeout-minutes: 720 secrets: diff --git a/.github/workflows/inductor.yml b/.github/workflows/inductor.yml index 083821e94ad311..88ffe090fd8897 100644 --- a/.github/workflows/inductor.yml +++ b/.github/workflows/inductor.yml @@ -149,13 +149,13 @@ jobs: secrets: HUGGING_FACE_HUB_TOKEN: ${{ secrets.HUGGING_FACE_HUB_TOKEN }} - linux-jammy-cpu-py3_8-gcc11-inductor-build: - name: linux-jammy-cpu-py3.8-gcc11-inductor + linux-jammy-cpu-py3_9-gcc11-inductor-build: + name: linux-jammy-cpu-py3.9-gcc11-inductor uses: ./.github/workflows/_linux-build.yml needs: get-label-type with: - build-environment: linux-jammy-py3.8-gcc11-build - docker-image-name: pytorch-linux-jammy-py3.8-gcc11-inductor-benchmarks + build-environment: linux-jammy-py3.9-gcc11-build + docker-image-name: pytorch-linux-jammy-py3.9-gcc11-inductor-benchmarks runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" test-matrix: | { include: [ @@ -204,13 +204,13 @@ jobs: secrets: HUGGING_FACE_HUB_TOKEN: ${{ secrets.HUGGING_FACE_HUB_TOKEN }} - linux-jammy-cpu-py3_8-gcc11-inductor-test: - name: linux-jammy-cpu-py3.8-gcc11-inductor + linux-jammy-cpu-py3_9-gcc11-inductor-test: + name: linux-jammy-cpu-py3.9-gcc11-inductor uses: ./.github/workflows/_linux-test.yml - needs: linux-jammy-cpu-py3_8-gcc11-inductor-build + needs: linux-jammy-cpu-py3_9-gcc11-inductor-build with: - build-environment: linux-jammy-py3.8-gcc11-build - docker-image: ${{ needs.linux-jammy-cpu-py3_8-gcc11-inductor-build.outputs.docker-image }} - test-matrix: ${{ needs.linux-jammy-cpu-py3_8-gcc11-inductor-build.outputs.test-matrix }} + build-environment: linux-jammy-py3.9-gcc11-build + docker-image: ${{ needs.linux-jammy-cpu-py3_9-gcc11-inductor-build.outputs.docker-image }} + test-matrix: ${{ needs.linux-jammy-cpu-py3_9-gcc11-inductor-build.outputs.test-matrix }} secrets: HUGGING_FACE_HUB_TOKEN: ${{ secrets.HUGGING_FACE_HUB_TOKEN }} diff --git a/.github/workflows/nightly.yml b/.github/workflows/nightly.yml index cef5841fa9f219..5057e9da2d1dd6 100644 --- a/.github/workflows/nightly.yml +++ b/.github/workflows/nightly.yml @@ -32,8 +32,8 @@ jobs: needs: get-label-type with: runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" - build-environment: linux-jammy-py3.8-gcc11 - docker-image-name: pytorch-linux-jammy-py3.8-gcc11 + build-environment: linux-jammy-py3.9-gcc11 + docker-image-name: pytorch-linux-jammy-py3.9-gcc11 docs-push: name: docs push @@ -43,7 +43,7 @@ jobs: - get-label-type with: runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - build-environment: linux-jammy-py3.8-gcc11 + build-environment: linux-jammy-py3.9-gcc11 docker-image: ${{ needs.docs-build.outputs.docker-image }} push: ${{ github.event_name == 'schedule' || github.event_name == 'workflow_dispatch' || startsWith(github.event.ref, 'refs/tags/v') }} run-doxygen: true diff --git a/.github/workflows/periodic.yml b/.github/workflows/periodic.yml index b12e478518d4c5..c5f6d8c853e112 100644 --- a/.github/workflows/periodic.yml +++ b/.github/workflows/periodic.yml @@ -104,14 +104,14 @@ jobs: docker-image: ${{ needs.linux-focal-cuda12_4-py3_10-gcc9-build.outputs.docker-image }} test-matrix: ${{ needs.linux-focal-cuda12_4-py3_10-gcc9-build.outputs.test-matrix }} - parallelnative-linux-jammy-py3_8-gcc11-build: - name: parallelnative-linux-jammy-py3.8-gcc11 + parallelnative-linux-jammy-py3_9-gcc11-build: + name: parallelnative-linux-jammy-py3.9-gcc11 uses: ./.github/workflows/_linux-build.yml needs: get-label-type with: runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - build-environment: parallelnative-linux-jammy-py3.8-gcc11 - docker-image-name: pytorch-linux-jammy-py3.8-gcc11 + build-environment: parallelnative-linux-jammy-py3.9-gcc11 + docker-image-name: pytorch-linux-jammy-py3.9-gcc11 test-matrix: | { include: [ { config: "default", shard: 1, num_shards: 3, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" }, @@ -119,16 +119,16 @@ jobs: { config: "default", shard: 3, num_shards: 3, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" }, ]} - parallelnative-linux-jammy-py3_8-gcc11-test: - name: parallelnative-linux-jammy-py3.8-gcc11 + parallelnative-linux-jammy-py3_9-gcc11-test: + name: parallelnative-linux-jammy-py3.9-gcc11 uses: ./.github/workflows/_linux-test.yml needs: - - parallelnative-linux-jammy-py3_8-gcc11-build + - parallelnative-linux-jammy-py3_9-gcc11-build - target-determination with: - build-environment: parallelnative-linux-jammy-py3.8-gcc11 - docker-image: ${{ needs.parallelnative-linux-jammy-py3_8-gcc11-build.outputs.docker-image }} - test-matrix: ${{ needs.parallelnative-linux-jammy-py3_8-gcc11-build.outputs.test-matrix }} + build-environment: parallelnative-linux-jammy-py3.9-gcc11 + docker-image: ${{ needs.parallelnative-linux-jammy-py3_9-gcc11-build.outputs.docker-image }} + test-matrix: ${{ needs.parallelnative-linux-jammy-py3_9-gcc11-build.outputs.test-matrix }} linux-focal-cuda11_8-py3_9-gcc9-build: name: linux-focal-cuda11.8-py3.9-gcc9 diff --git a/.github/workflows/pull.yml b/.github/workflows/pull.yml index ed77a706c2ee89..92c5b043358e6c 100644 --- a/.github/workflows/pull.yml +++ b/.github/workflows/pull.yml @@ -43,14 +43,14 @@ jobs: issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }} curr_branch: ${{ github.head_ref || github.ref_name }} - linux-jammy-py3_8-gcc11-build: - name: linux-jammy-py3.8-gcc11 + linux-jammy-py3_9-gcc11-build: + name: linux-jammy-py3.9-gcc11 uses: ./.github/workflows/_linux-build.yml needs: get-label-type with: runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - build-environment: linux-jammy-py3.8-gcc11 - docker-image-name: pytorch-linux-jammy-py3.8-gcc11 + build-environment: linux-jammy-py3.9-gcc11 + docker-image-name: pytorch-linux-jammy-py3.9-gcc11 test-matrix: | { include: [ { config: "default", shard: 1, num_shards: 4, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" }, @@ -65,49 +65,49 @@ jobs: ]} secrets: inherit - linux-jammy-py3_8-gcc11-test: - name: linux-jammy-py3.8-gcc11 + linux-jammy-py3_9-gcc11-test: + name: linux-jammy-py3.9-gcc11 uses: ./.github/workflows/_linux-test.yml needs: - - linux-jammy-py3_8-gcc11-build + - linux-jammy-py3_9-gcc11-build - target-determination with: - build-environment: linux-jammy-py3.8-gcc11 - docker-image: ${{ needs.linux-jammy-py3_8-gcc11-build.outputs.docker-image }} - test-matrix: ${{ needs.linux-jammy-py3_8-gcc11-build.outputs.test-matrix }} + build-environment: linux-jammy-py3.9-gcc11 + docker-image: ${{ needs.linux-jammy-py3_9-gcc11-build.outputs.docker-image }} + test-matrix: ${{ needs.linux-jammy-py3_9-gcc11-build.outputs.test-matrix }} secrets: inherit linux-docs: name: linux-docs uses: ./.github/workflows/_docs.yml - needs: linux-jammy-py3_8-gcc11-build + needs: linux-jammy-py3_9-gcc11-build with: - build-environment: linux-jammy-py3.8-gcc11 - docker-image: ${{ needs.linux-jammy-py3_8-gcc11-build.outputs.docker-image }} + build-environment: linux-jammy-py3.9-gcc11 + docker-image: ${{ needs.linux-jammy-py3_9-gcc11-build.outputs.docker-image }} secrets: inherit - linux-jammy-py3_8-gcc11-no-ops: - name: linux-jammy-py3.8-gcc11-no-ops + linux-jammy-py3_9-gcc11-no-ops: + name: linux-jammy-py3.9-gcc11-no-ops uses: ./.github/workflows/_linux-build.yml needs: get-label-type with: runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - build-environment: linux-jammy-py3.8-gcc11-no-ops - docker-image-name: pytorch-linux-jammy-py3.8-gcc11 + build-environment: linux-jammy-py3.9-gcc11-no-ops + docker-image-name: pytorch-linux-jammy-py3.9-gcc11 test-matrix: | { include: [ { config: "default", shard: 1, num_shards: 1 }, ]} secrets: inherit - linux-jammy-py3_8-gcc11-pch: - name: linux-jammy-py3.8-gcc11-pch + linux-jammy-py3_9-gcc11-pch: + name: linux-jammy-py3.9-gcc11-pch uses: ./.github/workflows/_linux-build.yml needs: get-label-type with: runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - build-environment: linux-jammy-py3.8-gcc11-pch - docker-image-name: pytorch-linux-jammy-py3.8-gcc11 + build-environment: linux-jammy-py3.9-gcc11-pch + docker-image-name: pytorch-linux-jammy-py3.9-gcc11 test-matrix: | { include: [ { config: "default", shard: 1, num_shards: 1 }, @@ -148,13 +148,13 @@ jobs: sync-tag: asan-test secrets: inherit - linux-focal-py3_8-clang10-onnx-build: - name: linux-focal-py3.8-clang10-onnx + linux-focal-py3_9-clang10-onnx-build: + name: linux-focal-py3.9-clang10-onnx uses: ./.github/workflows/_linux-build.yml needs: get-label-type with: runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - build-environment: linux-focal-py3.8-clang10-onnx + build-environment: linux-focal-py3.9-clang10-onnx docker-image-name: pytorch-linux-focal-py3-clang10-onnx test-matrix: | { include: [ @@ -163,26 +163,26 @@ jobs: ]} secrets: inherit - linux-focal-py3_8-clang10-onnx-test: - name: linux-focal-py3.8-clang10-onnx + linux-focal-py3_9-clang10-onnx-test: + name: linux-focal-py3.9-clang10-onnx uses: ./.github/workflows/_linux-test.yml needs: - - linux-focal-py3_8-clang10-onnx-build + - linux-focal-py3_9-clang10-onnx-build - target-determination with: - build-environment: linux-focal-py3.8-clang10-onnx - docker-image: ${{ needs.linux-focal-py3_8-clang10-onnx-build.outputs.docker-image }} - test-matrix: ${{ needs.linux-focal-py3_8-clang10-onnx-build.outputs.test-matrix }} + build-environment: linux-focal-py3.9-clang10-onnx + docker-image: ${{ needs.linux-focal-py3_9-clang10-onnx-build.outputs.docker-image }} + test-matrix: ${{ needs.linux-focal-py3_9-clang10-onnx-build.outputs.test-matrix }} secrets: inherit - linux-focal-py3_8-clang10-build: - name: linux-focal-py3.8-clang10 + linux-focal-py3_9-clang10-build: + name: linux-focal-py3.9-clang10 uses: ./.github/workflows/_linux-build.yml needs: get-label-type with: runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" - build-environment: linux-focal-py3.8-clang10 - docker-image-name: pytorch-linux-focal-py3.8-clang10 + build-environment: linux-focal-py3.9-clang10 + docker-image-name: pytorch-linux-focal-py3.9-clang10 test-matrix: | { include: [ { config: "default", shard: 1, num_shards: 4, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" }, @@ -195,16 +195,16 @@ jobs: { config: "dynamo", shard: 2, num_shards: 3, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" }, { config: "dynamo", shard: 3, num_shards: 3, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" }, ]} - linux-focal-py3_8-clang10-test: - name: linux-focal-py3.8-clang10 + linux-focal-py3_9-clang10-test: + name: linux-focal-py3.9-clang10 uses: ./.github/workflows/_linux-test.yml needs: - - linux-focal-py3_8-clang10-build + - linux-focal-py3_9-clang10-build - target-determination with: - build-environment: linux-focal-py3.8-clang10 - docker-image: ${{ needs.linux-focal-py3_8-clang10-build.outputs.docker-image }} - test-matrix: ${{ needs.linux-focal-py3_8-clang10-build.outputs.test-matrix }} + build-environment: linux-focal-py3.9-clang10 + docker-image: ${{ needs.linux-focal-py3_9-clang10-build.outputs.docker-image }} + test-matrix: ${{ needs.linux-focal-py3_9-clang10-build.outputs.test-matrix }} secrets: inherit linux-focal-py3_11-clang10-build: @@ -347,14 +347,14 @@ jobs: ]} secrets: inherit - linux-jammy-cuda-11_8-cudnn9-py3_8-clang12-build: - name: linux-jammy-cuda11.8-cudnn9-py3.8-clang12 + linux-jammy-cuda-11_8-cudnn9-py3_9-clang12-build: + name: linux-jammy-cuda11.8-cudnn9-py3.9-clang12 uses: ./.github/workflows/_linux-build.yml needs: get-label-type with: runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - build-environment: linux-jammy-cuda11.8-cudnn9-py3.8-clang12 - docker-image-name: pytorch-linux-jammy-cuda11.8-cudnn9-py3.8-clang12 + build-environment: linux-jammy-cuda11.8-cudnn9-py3.9-clang12 + docker-image-name: pytorch-linux-jammy-cuda11.8-cudnn9-py3.9-clang12 test-matrix: | { include: [ { config: "default", shard: 1, num_shards: 1 }, @@ -376,13 +376,13 @@ jobs: ]} secrets: inherit - linux-focal-py3_8-clang9-xla-build: - name: linux-focal-py3_8-clang9-xla + linux-focal-py3_9-clang9-xla-build: + name: linux-focal-py3_9-clang9-xla uses: ./.github/workflows/_linux-build.yml needs: get-label-type with: runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - build-environment: linux-focal-py3.8-clang9-xla + build-environment: linux-focal-py3.9-clang9-xla docker-image-name: 308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/xla_base:v1.1-lite test-matrix: | { include: [ @@ -390,14 +390,14 @@ jobs: ]} secrets: inherit - linux-focal-py3_8-clang9-xla-test: - name: linux-focal-py3_8-clang9-xla + linux-focal-py3_9-clang9-xla-test: + name: linux-focal-py3_9-clang9-xla uses: ./.github/workflows/_linux-test.yml - needs: linux-focal-py3_8-clang9-xla-build + needs: linux-focal-py3_9-clang9-xla-build with: - build-environment: linux-focal-py3.8-clang9-xla - docker-image: ${{ needs.linux-focal-py3_8-clang9-xla-build.outputs.docker-image }} - test-matrix: ${{ needs.linux-focal-py3_8-clang9-xla-build.outputs.test-matrix }} + build-environment: linux-focal-py3.9-clang9-xla + docker-image: ${{ needs.linux-focal-py3_9-clang9-xla-build.outputs.docker-image }} + test-matrix: ${{ needs.linux-focal-py3_9-clang9-xla-build.outputs.test-matrix }} secrets: inherit win-vs2019-cpu-py3-build: @@ -488,14 +488,14 @@ jobs: ]} secrets: inherit - linux-jammy-py3_8-gcc11-mobile-lightweight-dispatch-build: - name: linux-jammy-py3.8-gcc11-mobile-lightweight-dispatch-build + linux-jammy-py3_9-gcc11-mobile-lightweight-dispatch-build: + name: linux-jammy-py3.9-gcc11-mobile-lightweight-dispatch-build uses: ./.github/workflows/_linux-build.yml needs: get-label-type with: runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - build-environment: linux-jammy-py3.8-gcc111-mobile-lightweight-dispatch-build - docker-image-name: pytorch-linux-jammy-py3.8-gcc11 + build-environment: linux-jammy-py3.9-gcc111-mobile-lightweight-dispatch-build + docker-image-name: pytorch-linux-jammy-py3.9-gcc11 build-generates-artifacts: false test-matrix: | { include: [ diff --git a/.github/workflows/slow.yml b/.github/workflows/slow.yml index 932e15fe1230a2..3b253bfbb89f6f 100644 --- a/.github/workflows/slow.yml +++ b/.github/workflows/slow.yml @@ -102,30 +102,30 @@ jobs: docker-image: ${{ needs.linux-focal-cuda12_1-py3_10-gcc9-sm86-build.outputs.docker-image }} test-matrix: ${{ needs.linux-focal-cuda12_1-py3_10-gcc9-sm86-build.outputs.test-matrix }} - linux-focal-py3_8-clang10-build: - name: linux-focal-py3.8-clang10 + linux-focal-py3_9-clang10-build: + name: linux-focal-py3.9-clang10 uses: ./.github/workflows/_linux-build.yml needs: get-label-type with: runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - build-environment: linux-focal-py3.8-clang10 - docker-image-name: pytorch-linux-focal-py3.8-clang10 + build-environment: linux-focal-py3.9-clang10 + docker-image-name: pytorch-linux-focal-py3.9-clang10 test-matrix: | { include: [ { config: "slow", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" }, { config: "slow", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" }, ]} - linux-focal-py3_8-clang10-test: - name: linux-focal-py3.8-clang10 + linux-focal-py3_9-clang10-test: + name: linux-focal-py3.9-clang10 uses: ./.github/workflows/_linux-test.yml needs: - - linux-focal-py3_8-clang10-build + - linux-focal-py3_9-clang10-build - target-determination with: - build-environment: linux-focal-py3.8-clang10 - docker-image: ${{ needs.linux-focal-py3_8-clang10-build.outputs.docker-image }} - test-matrix: ${{ needs.linux-focal-py3_8-clang10-build.outputs.test-matrix }} + build-environment: linux-focal-py3.9-clang10 + docker-image: ${{ needs.linux-focal-py3_9-clang10-build.outputs.docker-image }} + test-matrix: ${{ needs.linux-focal-py3_9-clang10-build.outputs.test-matrix }} linux-focal-rocm6_1-py3_8-build: name: linux-focal-rocm6.1-py3.8 diff --git a/scripts/compile_tests/download_reports.py b/scripts/compile_tests/download_reports.py index c428ddca8a04ba..7ad6521032c9ab 100644 --- a/scripts/compile_tests/download_reports.py +++ b/scripts/compile_tests/download_reports.py @@ -8,10 +8,10 @@ CONFIGS = { - "dynamo38": { - "linux-focal-py3.8-clang10 / test (dynamo, 1, 3, linux.2xlarge)", - "linux-focal-py3.8-clang10 / test (dynamo, 2, 3, linux.2xlarge)", - "linux-focal-py3.8-clang10 / test (dynamo, 3, 3, linux.2xlarge)", + "dynamo39": { + "linux-focal-py3.9-clang10 / test (dynamo, 1, 3, linux.2xlarge)", + "linux-focal-py3.9-clang10 / test (dynamo, 2, 3, linux.2xlarge)", + "linux-focal-py3.9-clang10 / test (dynamo, 3, 3, linux.2xlarge)", }, "dynamo311": { "linux-focal-py3.11-clang10 / test (dynamo, 1, 3, linux.2xlarge)", @@ -26,7 +26,7 @@ } -def download_reports(commit_sha, configs=("dynamo38", "dynamo311", "eager311")): +def download_reports(commit_sha, configs=("dynamo39", "dynamo311", "eager311")): log_dir = "tmp_test_reports_" + commit_sha def subdir_path(config): diff --git a/scripts/compile_tests/update_failures.py b/scripts/compile_tests/update_failures.py index b5511038c29edb..a56e30e99870c6 100755 --- a/scripts/compile_tests/update_failures.py +++ b/scripts/compile_tests/update_failures.py @@ -221,5 +221,5 @@ def read_test_results(directory): args = parser.parse_args() assert Path(args.filename).exists(), args.filename assert Path(args.test_dir).exists(), args.test_dir - dynamo38, dynamo311 = download_reports(args.commit, ("dynamo38", "dynamo311")) - update(args.filename, args.test_dir, dynamo38, dynamo311, args.also_remove_skips) + dynamo39, dynamo311 = download_reports(args.commit, ("dynamo39", "dynamo311")) + update(args.filename, args.test_dir, dynamo39, dynamo311, args.also_remove_skips) From 49d12214627e1b6e6551f857b1cdb38fb9fdebc7 Mon Sep 17 00:00:00 2001 From: drisspg Date: Thu, 29 Aug 2024 08:23:06 -0700 Subject: [PATCH 0244/1018] [Inductor] Fix error checking for scaled_mm lowering (#134765) Pull Request resolved: https://github.com/pytorch/pytorch/pull/134765 Approved by: https://github.com/Skylion007 --- torch/_inductor/kernel/mm_scaled.py | 23 ++++++++++++++++++----- torch/_inductor/select_algorithm.py | 2 +- 2 files changed, 19 insertions(+), 6 deletions(-) diff --git a/torch/_inductor/kernel/mm_scaled.py b/torch/_inductor/kernel/mm_scaled.py index 2372420c1d58b6..2f0d020716a199 100644 --- a/torch/_inductor/kernel/mm_scaled.py +++ b/torch/_inductor/kernel/mm_scaled.py @@ -192,6 +192,18 @@ aten__fp8_mm = ExternKernelChoice(torch._scaled_mm, "at::_scaled_mm") +def are_compatible_scales(size_a: List[int], size_b: List[int]) -> bool: + # Same sized scales are compatable + if len(size_a) == len(size_b): + return True + + # Both need to be scalars or len(1) tensors + if len(size_a) <= 1 and len(size_b) <= 1: + return True + + return False + + def scaled_mm_options( # type: ignore[no-untyped-def] config, # triton.Config sym_m: sympy.core.numbers.Integer, @@ -207,10 +219,11 @@ def scaled_mm_options( # type: ignore[no-untyped-def] sympy.gcd(sym_k, config.kwargs["BLOCK_K"]) == config.kwargs["BLOCK_K"] ) - assert len(scale_a.get_size()) == len( - scale_b.get_size() - ), "Expect inverse scale_a and scale_b to be both scalars (tensor-wise scaling) or tensors (rowwise scaling)." - + size_a, size_b = scale_a.get_size(), scale_b.get_size() + assert are_compatible_scales(size_a, size_b), ( + "Expect scale_a and scale_b to be either both scalars (including single-element tensors) " + f"or 1-dimensional tensors with the same size. Got scale_a: {len(size_a)} and scale_b: {len(size_b)}." + ) return dict( GROUP_M=8, EVEN_K=even_k_symbolic, @@ -220,7 +233,7 @@ def scaled_mm_options( # type: ignore[no-untyped-def] num_stages=config.num_stages, num_warps=config.num_warps, # tensor-wise scaling if scalar scales - SCALING_ROWWISE=len(scale_a.get_size()) != 0, + SCALING_ROWWISE=len(scale_a.get_size()) == 2, **config.kwargs, ) diff --git a/torch/_inductor/select_algorithm.py b/torch/_inductor/select_algorithm.py index 5b585753bd12e2..67d9651a1e39fa 100644 --- a/torch/_inductor/select_algorithm.py +++ b/torch/_inductor/select_algorithm.py @@ -203,7 +203,7 @@ def estimate_kernel_num_bytes(self): num_bytes = [] for i, inp in enumerate(itertools.chain(self.input_nodes, (self.output_node,))): size = V.graph.sizevars.size_hints(inp.get_size()) - numel = functools.reduce(operator.mul, size) + numel = functools.reduce(operator.mul, size, 1) dtype_size = get_dtype_size(inp.get_dtype()) num_bytes.append(numel * dtype_size * (1 + int(i < ninplace_args))) return sum(num_bytes) From 1fa4167d8accd544479f5bd11c66f28ab4b6fb27 Mon Sep 17 00:00:00 2001 From: Xuehai Pan Date: Fri, 30 Aug 2024 00:47:02 +0800 Subject: [PATCH 0245/1018] [dynamo] simplify polyfill registration for `builtins.all` and `builtins.any` (#133769) Pull Request resolved: https://github.com/pytorch/pytorch/pull/133769 Approved by: https://github.com/jansel --- torch/_dynamo/decorators.py | 6 ++++++ torch/_dynamo/polyfills/__init__.py | 14 -------------- torch/_dynamo/polyfills/builtins.py | 26 +++++++++++++++++++++++++- torch/_dynamo/polyfills/os.py | 2 +- torch/_dynamo/trace_rules.py | 2 ++ torch/_dynamo/variables/builtin.py | 18 +----------------- torch/_dynamo/variables/functions.py | 10 ++++++++++ torch/compiler/__init__.py | 6 ++++++ 8 files changed, 51 insertions(+), 33 deletions(-) diff --git a/torch/_dynamo/decorators.py b/torch/_dynamo/decorators.py index c83d8d718a62a4..d7c4cbdcb507ea 100644 --- a/torch/_dynamo/decorators.py +++ b/torch/_dynamo/decorators.py @@ -168,6 +168,7 @@ def forbid_in_graph(fn): def substitute_in_graph( original_fn: _F, *, + can_constant_fold_through: bool = False, skip_signature_check: bool = False, ) -> Callable[[_F], _F]: """ @@ -187,6 +188,10 @@ def substitute_in_graph( Args: original_fn (callable): The original function, usually a C function, to register a polyfill handler for. + can_constant_fold_through (bool, optional): Whether the polyfill handler can be constant + folded through. That is, if the polyfill handler is a pure function and its arguments + are constant, the result of the polyfill handler can be constant folded during the + compilation. Defaults to ``False``. skip_signature_check (bool, optional): Whether to skip the signature check between the original function and the polyfill handler. Defaults to ``False``. @@ -319,6 +324,7 @@ def dispatch_fn(self, value): wrapped.__torch_dynamo_original__ = original_fn # type: ignore[attr-defined] wrapped.__torch_dynamo_polyfill__ = traceable_fn # type: ignore[attr-defined] + wrapped.__torch_dynamo_can_constant_fold_through__ = can_constant_fold_through # type: ignore[attr-defined] return wrapped # type: ignore[return-value] diff --git a/torch/_dynamo/polyfills/__init__.py b/torch/_dynamo/polyfills/__init__.py index acc4f613c8e426..9b465726754bbe 100644 --- a/torch/_dynamo/polyfills/__init__.py +++ b/torch/_dynamo/polyfills/__init__.py @@ -13,20 +13,6 @@ import torch -def all(iterator): - for elem in iterator: - if not elem: - return False - return True - - -def any(iterator): - for elem in iterator: - if elem: - return True - return False - - def index(iterator, item, start=0, end=None): for i, elem in islice(enumerate(iterator), start, end): if item == elem: diff --git a/torch/_dynamo/polyfills/builtins.py b/torch/_dynamo/polyfills/builtins.py index 04757a19afea46..d2f563ca6b883f 100644 --- a/torch/_dynamo/polyfills/builtins.py +++ b/torch/_dynamo/polyfills/builtins.py @@ -2,5 +2,29 @@ Python polyfills for builtins """ +import builtins +from typing import Iterable -__all__ = [] # type: ignore[var-annotated] +from ..decorators import substitute_in_graph + + +__all__ = [ + "all", + "any", +] + + +@substitute_in_graph(builtins.all, can_constant_fold_through=True) +def all(iterable: Iterable[object], /) -> bool: + for elem in iterable: + if not elem: + return False + return True + + +@substitute_in_graph(builtins.any, can_constant_fold_through=True) +def any(iterable: Iterable[object], /) -> bool: + for elem in iterable: + if elem: + return True + return False diff --git a/torch/_dynamo/polyfills/os.py b/torch/_dynamo/polyfills/os.py index 013cc81ac30e6f..5388816b826742 100644 --- a/torch/_dynamo/polyfills/os.py +++ b/torch/_dynamo/polyfills/os.py @@ -14,7 +14,7 @@ # Copied from os.py in the standard library -@substitute_in_graph(os.fspath) +@substitute_in_graph(os.fspath, can_constant_fold_through=True) def fspath(path: AnyStr | os.PathLike[AnyStr]) -> AnyStr: if isinstance(path, (str, bytes)): return path diff --git a/torch/_dynamo/trace_rules.py b/torch/_dynamo/trace_rules.py index f4ad6ff01e848c..966a54e672ee39 100644 --- a/torch/_dynamo/trace_rules.py +++ b/torch/_dynamo/trace_rules.py @@ -2980,6 +2980,7 @@ def _disallowed_callable_ids() -> Dict[int, str]: @FunctionIdSet def _builtin_function_ids() -> Dict[int, str]: + # See also torch/_dynamo/polyfills/loader.py, which removes items in _builtin_function_ids rv = { id(v): f"builtins.{k}" for k, v in builtins.__dict__.items() @@ -3072,6 +3073,7 @@ def is_forbidden(obj) -> bool: def is_builtin_callable(obj) -> bool: + # See also torch/_dynamo/polyfills/loader.py, which removes items in _builtin_function_ids return id(obj) in _builtin_function_ids diff --git a/torch/_dynamo/variables/builtin.py b/torch/_dynamo/variables/builtin.py index 84a2f6f05e4b1e..86900dee30bc56 100644 --- a/torch/_dynamo/variables/builtin.py +++ b/torch/_dynamo/variables/builtin.py @@ -16,7 +16,7 @@ from torch import sym_float, sym_int from torch.utils._python_dispatch import is_traceable_wrapper_subclass -from .. import config, polyfills, variables +from .. import config, variables from ..exc import ( AttributeMutationError, unimplemented, @@ -94,19 +94,6 @@ } -def _polyfill_call_impl(name): - """Create a BuiltinVariable.call_{name} method that inlines through polyfill.{name}""" - - def call_fn(self, tx: "InstructionTranslator", *args, **kwargs): - return tx.inline_user_function_return( - variables.UserFunctionVariable(fn), args, kwargs - ) - - fn = getattr(polyfills, name) - call_fn.__name__ = f"call_{name}" - return call_fn - - class BuiltinVariable(VariableTracker): _SENTINEL = object() _nonvar_fields = { @@ -2124,9 +2111,6 @@ def call_contains( ): return a.call_method(tx, "__contains__", [b], {}) - call_all = _polyfill_call_impl("all") - call_any = _polyfill_call_impl("any") - @contextlib.contextmanager def dynamo_disable_grad(tx): diff --git a/torch/_dynamo/variables/functions.py b/torch/_dynamo/variables/functions.py index 694d0247a00158..a9ec80cc3de589 100644 --- a/torch/_dynamo/variables/functions.py +++ b/torch/_dynamo/variables/functions.py @@ -16,6 +16,7 @@ from ..source import AttrSource, ConstantSource, DefaultsSource, GetItemSource from ..utils import ( check_constant_args, + check_unspec_or_constant_args, identity, is_wrapper_or_member_descriptor, istype, @@ -965,6 +966,15 @@ def call_function( handler = self._get_polyfill_handlers().get(self.fn) if handler: assert callable(handler) + if getattr( + handler, "__torch_dynamo_can_constant_fold_through__", False + ) and check_unspec_or_constant_args(args, kwargs): + return ConstantVariable.create( + self.fn( # use the original function which is faster than the polyfill + *[x.as_python_constant() for x in args], + **{k: v.as_python_constant() for k, v in kwargs.items()}, + ) + ) return SourcelessBuilder.create(tx, handler).call_function(tx, args, kwargs) for candidate in ("__torch_dynamo_polyfill__", "__python_implementation__"): diff --git a/torch/compiler/__init__.py b/torch/compiler/__init__.py index 406b20dc779550..7da8e911b83b22 100644 --- a/torch/compiler/__init__.py +++ b/torch/compiler/__init__.py @@ -123,6 +123,7 @@ def fn(a): def substitute_in_graph( original_fn: _F, *, + can_constant_fold_through: bool = False, skip_signature_check: bool = False, ) -> Callable[[_F], _F]: """ @@ -142,6 +143,10 @@ def substitute_in_graph( Args: original_fn (callable): The original function, usually a C function, to register a polyfill handler for. + can_constant_fold_through (bool, optional): Whether the polyfill handler can be constant + folded through. That is, if the polyfill handler is a pure function and its arguments + are constant, the result of the polyfill handler can be constant folded during the + compilation. Defaults to ``False``. skip_signature_check (bool, optional): Whether to skip the signature check between the original function and the polyfill handler. Defaults to ``False``. @@ -173,6 +178,7 @@ def substitute_in_graph( return torch._dynamo.substitute_in_graph( original_fn, + can_constant_fold_through=can_constant_fold_through, skip_signature_check=skip_signature_check, ) From 5600e8f689c7e68eb1856b73d42365bc8ff67b57 Mon Sep 17 00:00:00 2001 From: Xuehai Pan Date: Fri, 30 Aug 2024 00:47:03 +0800 Subject: [PATCH 0246/1018] [dynamo] simplify implementation for `functools.reduce` (#133778) Pull Request resolved: https://github.com/pytorch/pytorch/pull/133778 Approved by: https://github.com/jansel, https://github.com/anijain2305 ghstack dependencies: #133769 --- torch/_dynamo/polyfills/functools.py | 42 +++++++++++++++++++++++++++- torch/_dynamo/trace_rules.py | 1 - torch/_dynamo/variables/builtin.py | 18 +++--------- 3 files changed, 45 insertions(+), 16 deletions(-) diff --git a/torch/_dynamo/polyfills/functools.py b/torch/_dynamo/polyfills/functools.py index 98fd9597773b35..329725b41e354a 100644 --- a/torch/_dynamo/polyfills/functools.py +++ b/torch/_dynamo/polyfills/functools.py @@ -2,5 +2,45 @@ Python polyfills for functools """ +import functools +from typing import Callable, Iterable, TypeVar -__all__ = [] # type: ignore[var-annotated] +from ..decorators import substitute_in_graph + + +__all__ = ["reduce"] + + +_T = TypeVar("_T") +_U = TypeVar("_U") + + +class _INITIAL_MISSING: + pass + + +# Reference: https://docs.python.org/3/library/functools.html#functools.reduce +@substitute_in_graph(functools.reduce) +def reduce( + function: Callable[[_U, _T], _U], + iterable: Iterable[_T], + initial: _U = _INITIAL_MISSING, # type: ignore[assignment] + /, +) -> _U: + it = iter(iterable) + + value: _U + if initial is _INITIAL_MISSING: + try: + value = next(it) # type: ignore[assignment] + except StopIteration: + raise TypeError( + "reduce() of empty iterable with no initial value", + ) from None + else: + value = initial + + for element in it: + value = function(value, element) + + return value diff --git a/torch/_dynamo/trace_rules.py b/torch/_dynamo/trace_rules.py index 966a54e672ee39..e4f2e3786892b8 100644 --- a/torch/_dynamo/trace_rules.py +++ b/torch/_dynamo/trace_rules.py @@ -2999,7 +2999,6 @@ def _builtin_function_ids() -> Dict[int, str]: rv.update( { id(cast): "typing.cast", - id(functools.reduce): "functools.reduce", id(copy.deepcopy): "copy.deepcopy", } ) diff --git a/torch/_dynamo/variables/builtin.py b/torch/_dynamo/variables/builtin.py index 86900dee30bc56..b4fbbbc794df1d 100644 --- a/torch/_dynamo/variables/builtin.py +++ b/torch/_dynamo/variables/builtin.py @@ -1606,7 +1606,10 @@ def call_sum(self, tx: "InstructionTranslator", seq, start=_SENTINEL): if start is self._SENTINEL: start = variables.ConstantVariable.create(0) items = seq.unpack_var_sequence(tx) - return BuiltinVariable(functools.reduce).call_function( + + from .builder import SourcelessBuilder + + return SourcelessBuilder.create(tx, functools.reduce).call_function( tx, [ BuiltinVariable(operator.add), @@ -1616,19 +1619,6 @@ def call_sum(self, tx: "InstructionTranslator", seq, start=_SENTINEL): {}, ) - def call_reduce( - self, tx: "InstructionTranslator", function, iterable, initial=_SENTINEL - ): - if iterable.has_unpack_var_sequence(tx): - items = iterable.unpack_var_sequence(tx) - if initial is self._SENTINEL: - value, items = items[0], items[1:] - else: - value = initial - for element in items: - value = function.call_function(tx, [value, element], {}) - return value - def call_getattr( self, tx: "InstructionTranslator", From 9e44e665347aed2f30de37b3ce597076802cb8ed Mon Sep 17 00:00:00 2001 From: Xuehai Pan Date: Fri, 30 Aug 2024 00:47:03 +0800 Subject: [PATCH 0247/1018] [dynamo] simplify implementation for `builtins.sum` (#133779) Pull Request resolved: https://github.com/pytorch/pytorch/pull/133779 Approved by: https://github.com/jansel, https://github.com/anijain2305 ghstack dependencies: #133769, #133778 --- torch/_dynamo/polyfills/builtins.py | 13 ++++++++++- torch/_dynamo/variables/builtin.py | 34 ----------------------------- 2 files changed, 12 insertions(+), 35 deletions(-) diff --git a/torch/_dynamo/polyfills/builtins.py b/torch/_dynamo/polyfills/builtins.py index d2f563ca6b883f..d5ce75c562bfb6 100644 --- a/torch/_dynamo/polyfills/builtins.py +++ b/torch/_dynamo/polyfills/builtins.py @@ -3,7 +3,9 @@ """ import builtins -from typing import Iterable +import functools +import operator +from typing import Iterable, TypeVar from ..decorators import substitute_in_graph @@ -11,9 +13,13 @@ __all__ = [ "all", "any", + "sum", ] +_T = TypeVar("_T") + + @substitute_in_graph(builtins.all, can_constant_fold_through=True) def all(iterable: Iterable[object], /) -> bool: for elem in iterable: @@ -28,3 +34,8 @@ def any(iterable: Iterable[object], /) -> bool: if elem: return True return False + + +@substitute_in_graph(builtins.sum, can_constant_fold_through=True) # type: ignore[arg-type] +def sum(iterable: Iterable[_T], /, start: _T = 0) -> _T: # type: ignore[assignment] + return functools.reduce(operator.add, iterable, start) diff --git a/torch/_dynamo/variables/builtin.py b/torch/_dynamo/variables/builtin.py index b4fbbbc794df1d..ca972b282e387b 100644 --- a/torch/_dynamo/variables/builtin.py +++ b/torch/_dynamo/variables/builtin.py @@ -1585,40 +1585,6 @@ def call_filter(self, tx: "InstructionTranslator", fn, seq): except NotImplementedError: return - def call_sum(self, tx: "InstructionTranslator", seq, start=_SENTINEL): - # Special case for sum on tuple of floats and ints - if isinstance(seq, (variables.ListVariable, variables.TupleVariable)) and all( - isinstance(x, variables.ConstantVariable) - and isinstance(x.value, (int, float)) - for x in seq.items - ): - if start is self._SENTINEL: - return variables.ConstantVariable.create( - sum(x.value for x in seq.items), - ) - if isinstance(start, variables.ConstantVariable) and isinstance( - start.value, (int, float) - ): - return variables.ConstantVariable.create( - sum((x.value for x in seq.items), start=start.value), - ) - if seq.has_unpack_var_sequence(tx): - if start is self._SENTINEL: - start = variables.ConstantVariable.create(0) - items = seq.unpack_var_sequence(tx) - - from .builder import SourcelessBuilder - - return SourcelessBuilder.create(tx, functools.reduce).call_function( - tx, - [ - BuiltinVariable(operator.add), - variables.TupleVariable(items), - start, - ], - {}, - ) - def call_getattr( self, tx: "InstructionTranslator", From aaed91b6d42878258e570f4d0631f2acb55c0b81 Mon Sep 17 00:00:00 2001 From: Xuehai Pan Date: Fri, 30 Aug 2024 00:47:04 +0800 Subject: [PATCH 0248/1018] [dynamo][itertools] refactor `itertools.chain` and `itertools.chain.from_iterable` to use polyfills (#133864) Pull Request resolved: https://github.com/pytorch/pytorch/pull/133864 Approved by: https://github.com/jansel ghstack dependencies: #133769, #133778, #133779 --- torch/_dynamo/decorators.py | 16 +++++++++++++++- torch/_dynamo/polyfills/itertools.py | 21 ++++++++++++++++++++- torch/_dynamo/polyfills/loader.py | 6 ++++++ torch/_dynamo/trace_rules.py | 8 ++------ torch/_dynamo/variables/builder.py | 28 ++++++++++++++++++++++++++-- torch/_dynamo/variables/builtin.py | 16 ---------------- torch/_dynamo/variables/functions.py | 22 ++++++++++++++++++++++ torch/_dynamo/variables/iter.py | 8 -------- 8 files changed, 91 insertions(+), 34 deletions(-) diff --git a/torch/_dynamo/decorators.py b/torch/_dynamo/decorators.py index d7c4cbdcb507ea..4338e3047b8902 100644 --- a/torch/_dynamo/decorators.py +++ b/torch/_dynamo/decorators.py @@ -170,6 +170,8 @@ def substitute_in_graph( *, can_constant_fold_through: bool = False, skip_signature_check: bool = False, + # type that is embedded in the Python interpreter + is_embedded_type: bool = False, # internal use only ) -> Callable[[_F], _F]: """ Register a polyfill handler for a function, usually a C function from the C extension, to be @@ -219,10 +221,22 @@ def substitute_in_graph( >>> torch.compile(operator.indexOf, fullgraph=True)([1, 2, 3, 4, 5], 3) 2 """ - if not is_function(original_fn): + if not is_function(original_fn) and not ( + is_embedded_type and inspect.isclass(original_fn) + ): raise TypeError( f"substitute_in_graph expects a function but got {type(original_fn)!r}" ) + if is_embedded_type: + if not inspect.isclass(original_fn): + raise TypeError( + f"substitute_in_graph expects a class but got {type(original_fn)!r}" + ) + + from .variables.builder import ITERTOOLS_POLYFILLED_TYPE_IDS, ITERTOOLS_TYPE_IDS + + if id(original_fn) in ITERTOOLS_TYPE_IDS: + ITERTOOLS_POLYFILLED_TYPE_IDS.add(id(original_fn)) def wrapper(traceable_fn: _F) -> _F: if not is_function(traceable_fn): diff --git a/torch/_dynamo/polyfills/itertools.py b/torch/_dynamo/polyfills/itertools.py index 090df0f84b5281..802a62a82c8cdd 100644 --- a/torch/_dynamo/polyfills/itertools.py +++ b/torch/_dynamo/polyfills/itertools.py @@ -10,12 +10,31 @@ from ..decorators import substitute_in_graph -__all__ = ["tee"] +__all__ = [ + "chain", + "chain_from_iterable", + "tee", +] _T = TypeVar("_T") +# Reference: https://docs.python.org/3/library/itertools.html#itertools.chain +@substitute_in_graph(itertools.chain, is_embedded_type=True) # type: ignore[arg-type] +def chain(*iterables: Iterable[_T]) -> Iterator[_T]: + for iterable in iterables: + yield from iterable + + +@substitute_in_graph(itertools.chain.from_iterable) # type: ignore[arg-type] +def chain_from_iterable(iterable: Iterable[Iterable[_T]], /) -> Iterator[_T]: + return itertools.chain(*iterable) + + +chain.from_iterable = chain_from_iterable # type: ignore[method-assign] + + # Reference: https://docs.python.org/3/library/itertools.html#itertools.tee @substitute_in_graph(itertools.tee) def tee(iterable: Iterable[_T], n: int = 2, /) -> tuple[Iterator[_T], ...]: diff --git a/torch/_dynamo/polyfills/loader.py b/torch/_dynamo/polyfills/loader.py index 255486b6e016dd..d9367b55bf1400 100644 --- a/torch/_dynamo/polyfills/loader.py +++ b/torch/_dynamo/polyfills/loader.py @@ -32,3 +32,9 @@ polyfill_handler = getattr(polyfill_module, polyfill_name) original_fn = polyfill_handler.__torch_dynamo_original__ trace_rules._builtin_function_ids.remove(id(original_fn)) + + # Unregister the class object if the original function is its __new__ method + if original_fn.__name__ == "__new__" and isinstance( + getattr(original_fn, "__self__", None), type + ): + trace_rules._builtin_function_ids.remove(id(original_fn.__self__)) diff --git a/torch/_dynamo/trace_rules.py b/torch/_dynamo/trace_rules.py index e4f2e3786892b8..464b411a8b9507 100644 --- a/torch/_dynamo/trace_rules.py +++ b/torch/_dynamo/trace_rules.py @@ -2993,9 +2993,7 @@ def _builtin_function_ids() -> Dict[int, str]: if not k.startswith("_") and callable(v) } ) - rv.update( - {id(v): f"itertools.{v.__name__}" for v in (itertools.chain, itertools.islice)} - ) + rv.update({id(v): f"itertools.{v.__name__}" for v in (itertools.islice,)}) rv.update( { id(cast): "typing.cast", @@ -3471,9 +3469,7 @@ def check_verbose(obj, is_inlined_call=False): # Consulte the central trace rules defined in torch._dynamo.trace_rules. reasons: Set[str] = set() - rule = torch._dynamo.trace_rules.lookup_inner( - fi.py_obj, fi.name, fi.filename, is_inlined_call, reasons - ) + rule = lookup_inner(fi.py_obj, fi.name, fi.filename, is_inlined_call, reasons) if issubclass(rule, (UserFunctionVariable, PolyfilledFunctionVariable)): return SkipResult( False, diff --git a/torch/_dynamo/variables/builder.py b/torch/_dynamo/variables/builder.py index b4f35dbbd404b8..e0d355a7fb2b34 100644 --- a/torch/_dynamo/variables/builder.py +++ b/torch/_dynamo/variables/builder.py @@ -17,7 +17,17 @@ import types import warnings import weakref -from typing import Any, List, MutableMapping, NamedTuple, Optional, TYPE_CHECKING, Union +from typing import ( + Any, + FrozenSet, + List, + MutableMapping, + NamedTuple, + Optional, + Set, + TYPE_CHECKING, + Union, +) import torch from torch import SymInt @@ -319,6 +329,17 @@ class FrameStateSizeEntry: stride: Optional[List[int]] +# All class-based iterators in itertools +# NOTE: use id() because some objects are not hashable, it will raise error during lookup +ITERTOOLS_TYPE_IDS: FrozenSet[int] = frozenset( + id(member) + for name, member in vars(itertools).items() + if not name.startswith("_") and inspect.isclass(member) +) +# Will be updated later in substitute_in_graph in torch/_dynamo/polyfills/itertools.py +ITERTOOLS_POLYFILLED_TYPE_IDS: Set[int] = set() + + class VariableBuilder: """Wrap a python value in a VariableTracker() instance""" @@ -874,7 +895,10 @@ def build_key_value(i, k, v): value, source=self.source, ) - elif istype(value, type) and value in itertools.__dict__.values(): + elif ( + id(value) in ITERTOOLS_TYPE_IDS + and id(value) not in ITERTOOLS_POLYFILLED_TYPE_IDS + ): self.install_guards(GuardBuilder.FUNCTION_MATCH) return ItertoolsVariable(value, source=self.source) elif isinstance(value, torch.SymBool): diff --git a/torch/_dynamo/variables/builtin.py b/torch/_dynamo/variables/builtin.py index ca972b282e387b..f01b4b6efbe596 100644 --- a/torch/_dynamo/variables/builtin.py +++ b/torch/_dynamo/variables/builtin.py @@ -994,15 +994,6 @@ def call_method( ) if self.fn is dict and name == "fromkeys": return BuiltinVariable.call_custom_dict_fromkeys(tx, dict, *args, **kwargs) - if self.fn is itertools.chain and name == "from_iterable": - assert len(args) == 1 - assert len(kwargs) == 0 - obj = args[0] - items = [] - for item in obj.unpack_var_sequence(tx): - items.extend(item.unpack_var_sequence(tx)) - return variables.TupleVariable(items) - return super().call_method(tx, name, args, kwargs) def _call_int_float(self, tx: "InstructionTranslator", arg): @@ -1898,13 +1889,6 @@ def call_sorted(self, tx: "InstructionTranslator", obj: VariableTracker, **kwarg ) return variables.ListVariable(items) - def call_chain(self, tx: "InstructionTranslator", *args): - if all(obj.has_unpack_var_sequence(tx) for obj in args): - items = [] - for obj in args: - items.extend(obj.unpack_var_sequence(tx)) - return variables.TupleVariable(items) - def call_islice(self, tx: "InstructionTranslator", iterable, *args): if iterable.has_unpack_var_sequence(tx) and all( x.is_python_constant() for x in args diff --git a/torch/_dynamo/variables/functions.py b/torch/_dynamo/variables/functions.py index a9ec80cc3de589..239b720a5a7530 100644 --- a/torch/_dynamo/variables/functions.py +++ b/torch/_dynamo/variables/functions.py @@ -18,6 +18,7 @@ check_constant_args, check_unspec_or_constant_args, identity, + is_function, is_wrapper_or_member_descriptor, istype, make_cell, @@ -992,6 +993,27 @@ def call_function( handler, ).call_function(tx, args, kwargs) + return super().call_function(tx, args, kwargs) + + def call_method( + self, + tx, + name, + args: "List[VariableTracker]", + kwargs: "Dict[str, VariableTracker]", + ) -> "VariableTracker": + if name == "__call__": + return self.call_function(tx, args, kwargs) + + method = getattr(self.fn, name, None) + assert method is not None, f"Member {name} not found in {self.fn}" + assert is_function(method), f"Member {name} is not callable in {self.fn}" + options = {} + if self.source: + options["source"] = AttrSource(self.source, name) + member_variable = PolyfilledFunctionVariable(method, **options) + return member_variable.call_function(tx, args, kwargs) + def as_python_constant(self): return self.fn diff --git a/torch/_dynamo/variables/iter.py b/torch/_dynamo/variables/iter.py index 6687611cf0aabd..ed3ae786634e92 100644 --- a/torch/_dynamo/variables/iter.py +++ b/torch/_dynamo/variables/iter.py @@ -52,14 +52,6 @@ def call_function( for item in itertools.product(*seqs): items.append(variables.TupleVariable(list(item))) return variables.ListIteratorVariable(items, mutable_local=MutableLocal()) - elif ( - self.value is itertools.chain - and not kwargs - and all(arg.has_unpack_var_sequence(tx) for arg in args) - ): - seqs = [arg.unpack_var_sequence(tx) for arg in args] - items = list(itertools.chain.from_iterable(seqs)) - return variables.ListIteratorVariable(items, mutable_local=MutableLocal()) elif self.value is itertools.accumulate: from .builtin import BuiltinVariable From fcd99e6c96b22b56e4645875c1f559da2ee57b44 Mon Sep 17 00:00:00 2001 From: Xuehai Pan Date: Fri, 30 Aug 2024 00:47:05 +0800 Subject: [PATCH 0249/1018] [dynamo] refactor `builtins.enumerate` to use polyfill (#133894) Pull Request resolved: https://github.com/pytorch/pytorch/pull/133894 Approved by: https://github.com/jansel ghstack dependencies: #133769, #133778, #133779, #133864 --- torch/_dynamo/polyfills/builtins.py | 15 +++++++++++++++ torch/_dynamo/variables/builtin.py | 16 ---------------- 2 files changed, 15 insertions(+), 16 deletions(-) diff --git a/torch/_dynamo/polyfills/builtins.py b/torch/_dynamo/polyfills/builtins.py index d5ce75c562bfb6..62305086b804fe 100644 --- a/torch/_dynamo/polyfills/builtins.py +++ b/torch/_dynamo/polyfills/builtins.py @@ -2,6 +2,8 @@ Python polyfills for builtins """ +from __future__ import annotations + import builtins import functools import operator @@ -13,6 +15,7 @@ __all__ = [ "all", "any", + "enumerate", "sum", ] @@ -36,6 +39,18 @@ def any(iterable: Iterable[object], /) -> bool: return False +@substitute_in_graph(builtins.enumerate, is_embedded_type=True) # type: ignore[arg-type] +def enumerate(iterable: Iterable[_T], start: int = 0) -> Iterable[tuple[int, _T]]: + if not isinstance(start, int): + raise TypeError( + f"{type(start).__name__!r} object cannot be interpreted as an integer" + ) + + for x in iterable: + yield start, x + start += 1 + + @substitute_in_graph(builtins.sum, can_constant_fold_through=True) # type: ignore[arg-type] def sum(iterable: Iterable[_T], /, start: _T = 0) -> _T: # type: ignore[assignment] return functools.reduce(operator.add, iterable, start) diff --git a/torch/_dynamo/variables/builtin.py b/torch/_dynamo/variables/builtin.py index f01b4b6efbe596..279981465b2713 100644 --- a/torch/_dynamo/variables/builtin.py +++ b/torch/_dynamo/variables/builtin.py @@ -1442,22 +1442,6 @@ def call_zip(self, tx: "InstructionTranslator", *args, **kwargs): items = [variables.TupleVariable(list(item)) for item in zip(*unpacked)] return variables.TupleVariable(items) - def call_enumerate(self, tx: "InstructionTranslator", *args): - if len(args) == 1: - start = 0 - else: - assert len(args) == 2 - assert isinstance(args[1], variables.ConstantVariable) - start = args[1].as_python_constant() - if args[0].has_unpack_var_sequence(tx): - items = [ - variables.TupleVariable( - [variables.ConstantVariable.create(idx), var], - ) - for idx, var in enumerate(args[0].unpack_var_sequence(tx), start) - ] - return variables.TupleVariable(items) - def call_len(self, tx: "InstructionTranslator", *args, **kwargs): return args[0].call_method(tx, "__len__", args[1:], kwargs) From 31df249fba03d303eadf760a6c54da553cbe3feb Mon Sep 17 00:00:00 2001 From: Xuehai Pan Date: Fri, 30 Aug 2024 00:47:06 +0800 Subject: [PATCH 0250/1018] [dynamo][itertools] refactor `itertools.islice` to use polyfill (#133876) Pull Request resolved: https://github.com/pytorch/pytorch/pull/133876 Approved by: https://github.com/jansel ghstack dependencies: #133769, #133778, #133779, #133864, #133894 --- torch/_dynamo/polyfills/__init__.py | 27 +++---------------------- torch/_dynamo/polyfills/itertools.py | 30 ++++++++++++++++++++++++++++ torch/_dynamo/trace_rules.py | 2 -- torch/_dynamo/variables/builtin.py | 9 --------- torch/_dynamo/variables/iter.py | 6 ------ 5 files changed, 33 insertions(+), 41 deletions(-) diff --git a/torch/_dynamo/polyfills/__init__.py b/torch/_dynamo/polyfills/__init__.py index 9b465726754bbe..69145ce86d8c0d 100644 --- a/torch/_dynamo/polyfills/__init__.py +++ b/torch/_dynamo/polyfills/__init__.py @@ -14,36 +14,15 @@ def index(iterator, item, start=0, end=None): - for i, elem in islice(enumerate(iterator), start, end): + import itertools + + for i, elem in itertools.islice(enumerate(iterator), start, end): if item == elem: return i # This will not run in dynamo raise ValueError(f"{item} is not in {type(iterator)}") -def islice(iterator, start=0, end=None, step=1): - if start < 0 or (end is not None and end < 0) or step < 0: - raise ValueError("Indices must be non-negative") - if step == 0: - raise ValueError("Step cannot be 0") - - it = iter(iterator) - - for _ in range(start): - next(it) - - if end is None: - for i, element in enumerate(it): - if i % step == 0: - yield element - else: - for i, element in enumerate(it): - if i % step == 0 and i + start < end - start: - yield element - elif i + start >= end - start: - break - - def repeat(item, count): for i in range(count): yield item diff --git a/torch/_dynamo/polyfills/itertools.py b/torch/_dynamo/polyfills/itertools.py index 802a62a82c8cdd..a829d2e9b88081 100644 --- a/torch/_dynamo/polyfills/itertools.py +++ b/torch/_dynamo/polyfills/itertools.py @@ -13,6 +13,7 @@ __all__ = [ "chain", "chain_from_iterable", + "islice", "tee", ] @@ -35,6 +36,35 @@ def chain_from_iterable(iterable: Iterable[Iterable[_T]], /) -> Iterator[_T]: chain.from_iterable = chain_from_iterable # type: ignore[method-assign] +# Reference: https://docs.python.org/3/library/itertools.html#itertools.islice +@substitute_in_graph(itertools.islice, is_embedded_type=True) # type: ignore[arg-type] +def islice(iterable: Iterable[_T], /, *args: int | None) -> Iterator[_T]: + s = slice(*args) + start = 0 if s.start is None else s.start + stop = s.stop + step = 1 if s.step is None else s.step + if start < 0 or (stop is not None and stop < 0) or step <= 0: + raise ValueError( + "Indices for islice() must be None or an integer: 0 <= x <= sys.maxsize.", + ) + + if stop is None: + # TODO: use indices = itertools.count() and merge implementation with the else branch + # when we support infinite iterators + next_i = start + for i, element in enumerate(iterable): + if i == next_i: + yield element + next_i += step + else: + indices = range(max(start, stop)) + next_i = start + for i, element in zip(indices, iterable): + if i == next_i: + yield element + next_i += step + + # Reference: https://docs.python.org/3/library/itertools.html#itertools.tee @substitute_in_graph(itertools.tee) def tee(iterable: Iterable[_T], n: int = 2, /) -> tuple[Iterator[_T], ...]: diff --git a/torch/_dynamo/trace_rules.py b/torch/_dynamo/trace_rules.py index 464b411a8b9507..e8692fe47891a3 100644 --- a/torch/_dynamo/trace_rules.py +++ b/torch/_dynamo/trace_rules.py @@ -12,7 +12,6 @@ import functools import importlib import inspect -import itertools import linecache import logging import multiprocessing @@ -2993,7 +2992,6 @@ def _builtin_function_ids() -> Dict[int, str]: if not k.startswith("_") and callable(v) } ) - rv.update({id(v): f"itertools.{v.__name__}" for v in (itertools.islice,)}) rv.update( { id(cast): "typing.cast", diff --git a/torch/_dynamo/variables/builtin.py b/torch/_dynamo/variables/builtin.py index 279981465b2713..919daf6fbd2b59 100644 --- a/torch/_dynamo/variables/builtin.py +++ b/torch/_dynamo/variables/builtin.py @@ -1873,15 +1873,6 @@ def call_sorted(self, tx: "InstructionTranslator", obj: VariableTracker, **kwarg ) return variables.ListVariable(items) - def call_islice(self, tx: "InstructionTranslator", iterable, *args): - if iterable.has_unpack_var_sequence(tx) and all( - x.is_python_constant() for x in args - ): - const_args = [x.as_python_constant() for x in args] - items = iterable.unpack_var_sequence(tx) - items = list(itertools.islice(items, *const_args)) - return variables.TupleVariable(items) - # neg is a constant fold function, so we only get here if constant fold is not valid def call_neg(self, tx: "InstructionTranslator", a): if isinstance(a, SymNodeVariable): diff --git a/torch/_dynamo/variables/iter.py b/torch/_dynamo/variables/iter.py index ed3ae786634e92..34c2354026a7bd 100644 --- a/torch/_dynamo/variables/iter.py +++ b/torch/_dynamo/variables/iter.py @@ -166,12 +166,6 @@ def retrieve_const_key(key): from_exc=e, ) return variables.ListIteratorVariable(result, mutable_local=MutableLocal()) - elif self.value is itertools.islice: - from .builder import SourcelessBuilder - - return tx.inline_user_function_return( - SourcelessBuilder.create(tx, polyfills.islice), args, kwargs - ) elif self.value is itertools.repeat: if len(args) < 2: return variables.RepeatIteratorVariable( From 188d37e9ceba0e30e286e3f6ed78ec604ba65db7 Mon Sep 17 00:00:00 2001 From: min-jean-cho Date: Thu, 29 Aug 2024 20:58:54 +0000 Subject: [PATCH 0251/1018] [Windows][XPU] Disable Kineto PTI on Windows only (#134620) Disable Kineto + XPU PTI on Windows only. Pull Request resolved: https://github.com/pytorch/pytorch/pull/134620 Approved by: https://github.com/guangyey, https://github.com/malfet Co-authored-by: Nikita Shulga <2453524+malfet@users.noreply.github.com> --- cmake/Dependencies.cmake | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cmake/Dependencies.cmake b/cmake/Dependencies.cmake index 71bfcc40da9a4d..3e59b813d31381 100644 --- a/cmake/Dependencies.cmake +++ b/cmake/Dependencies.cmake @@ -1581,7 +1581,7 @@ if(USE_KINETO) message(STATUS "Using Kineto with Roctracer support") endif() - if(NOT USE_XPU) + if((NOT USE_XPU) OR WIN32) set(LIBKINETO_NOXPUPTI ON CACHE STRING "" FORCE) else() set(LIBKINETO_NOXPUPTI OFF CACHE STRING "") From 7a51f491545dbe09b8d6bd7c50b6d2178ca1f37c Mon Sep 17 00:00:00 2001 From: Shunting Zhang Date: Thu, 29 Aug 2024 11:31:04 -0700 Subject: [PATCH 0252/1018] [inductor] move loop ordering after fusion (#126254) Restart the work from PR https://github.com/pytorch/pytorch/pull/100331 in this new PR since it's hard to rebase. It would be expected that some code is copy/pasted from the previous PR and main idea is the same. Previously we see relatively large compilation time increase due to too many loop orders being considered. This PR tries to continue the work by doing pruning and only considering loop orders that we know for sure are relevant (i.e. do it on demand). Some manually created cases that loop ordering matters are added as unit tests. The PR can make sure inductor does not miss fusion opportunities for them. This PR should solve the not-able to fusion problem in https://github.com/pytorch/pytorch/issues/130015 Right now there is still significant increase of compilation time. I'll disable the feature by default. Later on after the compilation time issue is resolved, I'll enable it by default. Pull Request resolved: https://github.com/pytorch/pytorch/pull/126254 Approved by: https://github.com/jansel --- test/dynamo/test_logging.py | 1 + test/inductor/test_benchmark_fusion.py | 9 + test/inductor/test_loop_ordering.py | 343 +++++++++++++++++- torch/_dynamo/utils.py | 20 + .../_inductor/codegen/cpp_template_kernel.py | 8 +- torch/_inductor/config.py | 3 + torch/_inductor/dependencies.py | 151 ++++++-- torch/_inductor/ir.py | 170 ++++++++- torch/_inductor/metrics.py | 4 + torch/_inductor/scheduler.py | 251 ++++++++++++- torch/_logging/_registrations.py | 5 + 11 files changed, 905 insertions(+), 60 deletions(-) diff --git a/test/dynamo/test_logging.py b/test/dynamo/test_logging.py index 8e16443abb9153..b71f8b7fce95bf 100644 --- a/test/dynamo/test_logging.py +++ b/test/dynamo/test_logging.py @@ -741,6 +741,7 @@ def fn(a): "trace_shape_events", "cudagraph_static_inputs", "benchmarking", + "loop_ordering", } for name in torch._logging._internal.log_registry.artifact_names: if name not in exclusions: diff --git a/test/inductor/test_benchmark_fusion.py b/test/inductor/test_benchmark_fusion.py index f6d7daba75d2c9..9cd98735f7c10e 100644 --- a/test/inductor/test_benchmark_fusion.py +++ b/test/inductor/test_benchmark_fusion.py @@ -5,6 +5,7 @@ import torch from torch._inductor.test_case import TestCase as InductorTestCase +from torch._inductor.test_operators import realize from torch._inductor.utils import fresh_inductor_cache, is_big_gpu, run_and_get_code from torch.testing import FileCheck from torch.testing._internal.common_utils import slowTest, TEST_WITH_ASAN @@ -166,6 +167,14 @@ def foo(m, inp): "empty_strided_cuda", 2, exactly=True ).check("return").run(c) + def test_tield_kernel_fusion(self): + def f(x): + y = realize(x + x.t()) + return y + 1 + + x = torch.randn(1024, 1024, device=self.device) + self.common(f, (x,)) + if HAS_CUDA and not TEST_WITH_ASAN: diff --git a/test/inductor/test_loop_ordering.py b/test/inductor/test_loop_ordering.py index 0e90b4a4c708d7..35f0cc55c8d6d1 100644 --- a/test/inductor/test_loop_ordering.py +++ b/test/inductor/test_loop_ordering.py @@ -1,29 +1,203 @@ # Owner(s): ["module: inductor"] +import contextlib +import unittest + +import numpy as np + import torch +from torch import nn from torch._dynamo.testing import rand_strided from torch._dynamo.utils import same -from torch._inductor import config as inductor_config, metrics +from torch._inductor import config as inductor_config, ir, metrics +from torch._inductor.codegen.triton import TritonScheduling +from torch._inductor.graph import GraphLowering +from torch._inductor.scheduler import SchedulerNode from torch._inductor.test_case import run_tests, TestCase +from torch._inductor.test_operators import realize +from torch._inductor.utils import sympy_index_symbol +from torch._inductor.virtualized import ops, V +from torch.testing._internal.common_cuda import SM90OrLater from torch.testing._internal.inductor_utils import HAS_CUDA +from torch.utils._pytree import tree_map +from torch.utils._sympy.functions import ModularIndexing if HAS_CUDA: torch.set_default_device("cuda") +class MockScheduler: + available_buffer_names = () + + @staticmethod + def get_backend(cls, *args): + return TritonScheduling(cls) + + +@inductor_config.patch(loop_ordering_after_fusion=True) +class ImplDetailTest(TestCase): + _exit_stack = None + + @classmethod + def setUpClass(cls): + super().setUpClass() + + gm = torch.fx.symbolic_trace(lambda: 0) + graph = GraphLowering(gm) + graph.scheduler = MockScheduler + cls._exit_stack = contextlib.ExitStack() + cls._exit_stack.enter_context(V.set_graph_handler(graph)) + + @classmethod + def tearDownClass(cls): + super().tearDownClass() + cls._exit_stack.close() + + @staticmethod + def _get_snode_body_sym_prefix(snode): + body = snode._body + prefix = "" + + for var in body.var_ranges: + prefix = str(var)[0] + break + + assert prefix + return prefix + + @staticmethod + def _create_computed_buffer_ax2(sizes=(32, 64), strides=None): + """ + Create a ComputedBuffer for 'a x 2' + """ + if strides is None: + strides = ir.FlexibleLayout.contiguous_strides(sizes) + + box_a = ir.TensorBox.create( + ir.Buffer( + "a", ir.FixedLayout(torch.device("cuda"), torch.float32, sizes, strides) + ) + ) + box_a_loader = box_a.make_loader() + + def inner_fn(index): + return box_a_loader(index) * 2 + + buf = ir.Pointwise.create( + device=box_a.get_device(), + dtype=box_a.get_dtype(), + inner_fn=inner_fn, + ranges=box_a.get_size(), + ) + buf.realize() + computed_buf = buf.data.data + computed_buf.decide_layout() + return computed_buf + + def test_reorder_twice(self): + """ + This may happen in practice if we pick a order when fusing A and B. + Then we pick another order for AB when we fusion C into it. + + E.g. happens for BertForMaskedLM. + """ + + buf = self._create_computed_buffer_ax2() + snode = SchedulerNode(V.graph.scheduler, buf) + snode.apply_new_loop_order([1, 0]) + prefix1 = self._get_snode_body_sym_prefix(snode) + self.assertTrue(prefix1 == "z") + snode.apply_new_loop_order([1, 0]) + prefix2 = self._get_snode_body_sym_prefix(snode) + self.assertTrue(prefix2 == "z") + + def test_reorder_and_merge_loops(self): + sizes = (1024, 2048) + strides = (1, 1024) + buf = self._create_computed_buffer_ax2(sizes, strides) + old_sizes, old_body = buf.simplify_and_reorder() + + # Make sure loop reordering happens here + self.assertTrue(tuple(old_sizes[0]) == tuple(reversed(sizes)), f"{old_sizes=}") + new_body = old_body.merge_loops() + new_sizes = new_body.sizes + self.assertTrue(tuple(new_sizes[0]) == (np.prod(sizes),), f"{new_sizes=}") + + def test_reorder_modular_indexing(self): + """ + There was a bug that we wrongly map i0 to the dimension with size 49 + when reordering the loop and cause ModularIndexing get optimized away + as an no-op. + """ + + def _create_computed_buffer(): + def inner_fn(index): + i0, i1, i2, i3 = index + return ops.load( + "primal", i3 + 49 * i2 + 2401 * ModularIndexing(i0, 1, 64) + ) + + buf = ir.Pointwise.create( + device=torch.device("cuda"), + dtype=torch.float32, + inner_fn=inner_fn, + ranges=[128, 4, 49, 49], + ) + buf.realize() + cbuf = buf.data.data + cbuf.decide_layout() + return cbuf + + buf = _create_computed_buffer() + _, body = buf.simplify_and_reorder() + new_body = body.reorder_iter_loops([1, 2, 3, 0]) + + z0, z1, z2, z3 = (sympy_index_symbol(f"z{i}") for i in range(4)) + self.assertEqual(body.var_ranges, {z0: 128, z1: 4, z2: 49, z3: 49}) + self.assertEqual( + body.indexing_exprs["index0"], + z3 + 49 * z2 + 2401 * ModularIndexing(z0, 1, 64), + ) + self.assertEqual(new_body.var_ranges, {z0: 4, z1: 49, z2: 49, z3: 128}) + self.assertEqual( + new_body.indexing_exprs["index0"], + z2 + 49 * z1 + 2401 * ModularIndexing(z3, 1, 64), + ) + + @inductor_config.patch( { "benchmark_kernel": True, + "loop_ordering_after_fusion": True, "triton.unique_kernel_names": True, } ) class LoopOrderingTest(TestCase): - def do_acc_test(self, f, *args): + def do_acc_test(self, f, *args, cast_fp8=True): expect = f(*args) actual = torch.compile(f)(*args) + + if cast_fp8: + + def _cast(x): + if isinstance(x, torch.Tensor) and x.dtype in ( + torch.float8_e5m2, + torch.float8_e4m3fn, + ): + return x.to(torch.float32) + return x + + # Wordaround the issue that call allclose on fp8 tensor triggers error + # RuntimeError: "mul_cuda" not implemented for 'Float8_e4m3fn' + expect = tree_map(_cast, expect) + actual = tree_map(_cast, actual) self.assertTrue(same(expect, actual, tol=1e-3)) + def setUp(self): + super().setUp() + metrics.reset() + def test_for_reordering_reindex(self): """ ComputedBuffer.iter_reoredering_reindex can cause some fusion @@ -54,6 +228,171 @@ def f(x, y): expected_num_bytes *= x.itemsize self.assertEqual(expected_num_bytes, metrics.num_bytes_accessed) + def test_apbt_realize(self): + M = 1024 + N = 2048 + + def f(x, y): + """ + There will be 2 kernels being generated without loop ordering after fusion: + https://gist.github.com/shunting314/44df83f71de2c110232c50ac6638ed69 + """ + x = realize(x * 2) + y = realize(y * 3) + return x + y + + x = torch.randn(M, N) + y = torch.randn(N, M).t() + + self.do_acc_test(f, x, y) + self.assertEqual(1, metrics.generated_kernel_count) + + def test_sum_and_t(self): + N = 1024 + + def f(x): + return x.sum(dim=-1), x.t().contiguous() + + x = torch.randn(N, N * 2) + self.do_acc_test(f, x) + self.assertEqual(1, metrics.generated_kernel_count) + + def test_pw_outer_red(self): + def f(x): + x = realize(x + 1) + return x.sum(dim=[0, 1]) + + # make the first 2 dimension small so we don't split the reduction + x = torch.randn(2, 4, 512) + self.do_acc_test(f, x) + self.assertEqual(1, metrics.generated_kernel_count) + + def test_pw_outer_red_2(self): + """ + The pointwise kernel is a fused kernel + """ + + def f(x): + x = realize(x + 1) + x = realize(x - 2) + x = realize(x * 3) + return x.sum(dim=[0, 1]) + + # make the first 2 dimension small so we don't split the reduction + x = torch.randn(2, 4, 512) + self.do_acc_test(f, x) + self.assertEqual(1, metrics.generated_kernel_count) + + @inductor_config.patch(split_reductions=False) + def test_different_reduction_order(self): + """ + We should not reorder loops in this case. Since reordering loops does + not help! + """ + + def f(x): + return x.sum(dim=0), x.sum(dim=1) + + x = torch.randn(1024, 2048) + self.do_acc_test(f, x) + self.assertEqual(2, metrics.generated_kernel_count) + self.assertEqual(0, metrics.num_loop_reordering) + + def test_keep_fake_dep(self): + """ + In this model, there are fake dependencies (StarDep) between Scatter + and a following mutation kernel that computes the gradients of + the embedding tables. + + When we do loop reordering for the mutation kernel, we re-analyze + the node's dependencies. But the analysis result does not contains + those fake dependencies. Have to add them back manually. + """ + V = 2048 + hidden_size = 64 + max_seqlen = 512 + batch_size = 8 + + class Model(nn.Module): + def __init__(self): + super().__init__() + self.word_embeddings = nn.Embedding(V, hidden_size) + self.position_embeddings = nn.Embedding(max_seqlen, hidden_size) + self.layer_norm = nn.LayerNorm(hidden_size) + + def forward(self, input_ids, labels, position_ids): + emb = self.word_embeddings(input_ids) + self.position_embeddings( + position_ids + ) + return self.layer_norm(emb) + + m = Model() + + @torch.compile + def f(*args): + m(*args).sum().backward() + + input_ids = torch.randint(0, V, (batch_size, max_seqlen)) + labels = torch.randint(0, V, (batch_size, max_seqlen)) + position_ids = torch.arange(max_seqlen)[None, :] + # Make sure this line does not raise exceptions. If we miss + # fake dependencies after loop reordering, we may get exception that + # some buffer is used before being defined. + f(input_ids, labels, position_ids) + + def test_different_broadcast_shapes(self): + def f(x, y, c): + return x + c, y + c + + x = torch.randn(4, 256, 1024) + y = torch.randn(2, 512, 1024) + c = torch.randn(1024) + self.do_acc_test(f, x, y, c) + + # The two kernels are not fused due to c is broadcasted + self.assertEqual(2, metrics.generated_kernel_count) + + def test_view(self): + """ + Passing this test relies that we compare normalized MemoryDep. + Normlaization here means merging contiguous loops. + + To make loop reordering work, we don't merge loops when creating + SchedulerNode. Thus we need explicitly normalize MemoryDep when + we check if two MemeoryDep matches. + """ + + def f(x): + y = x.sin() + x = realize(x.view(10, 10)) + return x, y + + x = torch.randn(100) + self.do_acc_test(f, x) + self.assertEqual(1, metrics.generated_kernel_count) + + @unittest.skipIf(not SM90OrLater, "FP8 requires H100+") + def test_fp8_cast_and_t(self): + """ + This test repros the not able to fuses issue in + https://github.com/pytorch/pytorch/issues/130015 + for fp8 cast and transpose + """ + + def f(x, scale): + x = x * scale + x = x.clamp(-1 * E4M3_MAX_POS, E4M3_MAX_POS) + x = x.to(torch.float8_e4m3fn) + x_t = x.t().contiguous().t() + return x, x_t + + x = torch.randn(4096, 4096, dtype=torch.bfloat16) + scale = torch.Tensor([10.0]).cuda() + E4M3_MAX_POS = torch.finfo(torch.float8_e4m3fn).max + + self.do_acc_test(f, x, scale) + self.assertEqual(1, metrics.generated_kernel_count) + if __name__ == "__main__": if HAS_CUDA: diff --git a/torch/_dynamo/utils.py b/torch/_dynamo/utils.py index ae78f22374665a..fa4103d9c3211c 100644 --- a/torch/_dynamo/utils.py +++ b/torch/_dynamo/utils.py @@ -1728,6 +1728,26 @@ def to_tensor(t): # Check error from fp64 version if fp64_ref.dtype == torch.float64: + # Fix a corner case that res and fp64_ref does not contains NaN and match (with loose tolerance) + # while the ref contains NaN. In this case, RMSE should not match any ways. + # But res is 'BETTER' than ref so we count it pass. + # + # This happens for Super_SloMo when loop ordering after fusion is enabled: + # https://gist.github.com/shunting314/11f235c70f7db0d52718d26f4a701cab + loose_tol = 1e-2 * 4 + if ( + not fp64_ref.isnan().any() + and not res.isnan().any() + and ref.isnan().any() + and torch.allclose( + fp64_ref.to(dtype=res.dtype), + res, + atol=loose_tol, + rtol=loose_tol, + equal_nan=equal_nan, + ) + ): + return True ref_error = rmse(fp64_ref, ref).item() # ref unable to produce this with stable numerics in this precision, ignore if math.isnan(ref_error): diff --git a/torch/_inductor/codegen/cpp_template_kernel.py b/torch/_inductor/codegen/cpp_template_kernel.py index 4db3e7b699ec63..e9993e22cf63fe 100644 --- a/torch/_inductor/codegen/cpp_template_kernel.py +++ b/torch/_inductor/codegen/cpp_template_kernel.py @@ -239,7 +239,13 @@ def fn(*args): node.make_loader()(new_args).value, ) - body = ir.LoopBody(fn, (list(var_ranges.keys()), ()), var_ranges) + body = ir.LoopBody( + fn, + (list(var_ranges.keys()), ()), + var_ranges, + list(var_ranges.keys()), + tuple(), + ) bodies.append(body) var_sizes_list.append(var_sizes) diff --git a/torch/_inductor/config.py b/torch/_inductor/config.py index 443bb7cacc0775..ddf156df300a40 100644 --- a/torch/_inductor/config.py +++ b/torch/_inductor/config.py @@ -411,6 +411,9 @@ def use_autoheuristic(name: str) -> bool: debug_fusion = os.environ.get("TORCHINDUCTOR_DEBUG_FUSION") == "1" benchmark_fusion = os.environ.get("TORCHINDUCTOR_BENCHMARK_FUSION") == "1" enabled_metric_tables = os.environ.get("TORCHINDUCTOR_ENABLED_METRIC_TABLES", "") +loop_ordering_after_fusion = ( + os.environ.get("TORCHINDUCTOR_LOOP_ORDERING_AFTER_FUSION", "0") == "1" +) # For Triton Templates, select fastest of best template + epilogue vs best template + separate epilogue kernel benchmark_epilogue_fusion = ( diff --git a/torch/_inductor/dependencies.py b/torch/_inductor/dependencies.py index 5fab371d4406de..851ff2d648fe35 100644 --- a/torch/_inductor/dependencies.py +++ b/torch/_inductor/dependencies.py @@ -6,7 +6,7 @@ import logging import re import typing -from typing import Any, Callable, Dict, List, Optional, Tuple, Union +from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union from unittest.mock import patch import sympy @@ -55,6 +55,9 @@ def has_unbacked_symbols(self) -> bool: def is_contiguous(self) -> bool: pass + def normalize_with_stride_order(self, prefix="t"): + return self + @dataclasses.dataclass(frozen=True) class MemoryDep(Dep): @@ -67,12 +70,87 @@ class MemoryDep(Dep): def __repr__(self) -> str: return f"MemoryDep({self.name!r}, {self.index}, {self.ranges}, {self.mode})" + @property + def num_vars(self): + return len(self.var_names) + + def decide_loop_order_to_match(self, other): + """ + Can return None if not able to decide loop orders. + """ + assert self.num_vars == other.num_vars + + # ignore broadcast for now since broadcast causes extra 0 strides + # which makes it hard to decide the correct loop orders. + if self.num_vars != len(self.index.free_symbols): + return None + if other.num_vars != len(other.index.free_symbols): + return None + + # bail out if any size is 0 or 1 + # For size == 0, it's an empty tensor, any strides for that dimension + # are equivalent. Skip for simplicity and it may not matter that much. + # + # For size == 1, it cause cause tie for strides of different dimensions. + # Also when we first time create LoopBody in ComputedBuffer.simplify_and_reorder + # we can dependencies.index_vars_squeeze which should already sqeeuze + # the size == 1 dimensions. + if any(s == 0 or s == 1 for s in itertools.chain(self.size, other.size)): + return None + + # Extract strides for both expression + self_strides = V.graph.sizevars.stride_hints(self.index, self.var_names) + other_strides = V.graph.sizevars.stride_hints(other.index, other.var_names) + + # Even if the shape contains no 0/1, some complex index expression may + # still have duplicate stride values. Here is an example: + # https://gist.github.com/shunting314/511a7e1ec88aa2e1a8ec85d8445ab129 + # We don't reorder the loop for these cases for now, but in theory + # we could improve the algorithm to detect the correct loop orders. + if len(set(self_strides)) != len(self_strides) or len( + set(other_strides) + ) != len(other_strides): + log.debug( + "unable to decide loop order. self_dep=%s v.s. other_dep=%s, self_strides=%s v.s. other_strides=%s", + self, + other, + self_strides, + other_strides, + ) + return None + + # May hanppen if self and other are as follows + # MemoryDep('addmm_6', 393216*d0 + 768*d1 + d2, {d0: 16, d1: 512, d2: 768}, None) + # MemoryDep('addmm_6', 98304*d0 + d1 + 768*d2, {d0: 64, d1: 768, d2: 128}, None) + if set(self_strides) != set(other_strides): + return None + + stride_to_index = {s: i for i, s in enumerate(self_strides)} + order = [] + for s in other_strides: + order.append(stride_to_index[s]) + + assert set(order) == set(range(0, self.num_vars)) + return order + def get_offset(self): """ Return the offset by setting every variable to be 0. """ return sympy_subs(self.index, dict.fromkeys(self.var_names, 0)) + def normalize(self) -> "MemoryDep": + """ + Normalize by merging loops. The different to normalize_with_stride_order is, + this method does not reorder loops while normalize_with_stride_order reorder + loops based on stride order. + """ + return MemoryDep( + self.name, + *_RecordLoadStoreInner._normalize(self.index, self.ranges), # type: ignore[arg-type] + self.mode, + ) + def normalize_with_stride_order(self, prefix="t"): r""" Used to decide if two MemoryDep does not equal due to different loop orders. @@ -292,10 +370,12 @@ def rename(self, renames: typing.Dict[str, str]) -> "ReadWrites": op_counts=self.op_counts, ) - def with_read(self, dep: Dep) -> "ReadWrites": - assert isinstance(dep, (WeakDep, StarDep)) + def with_read(self, dep: Union[Dep, Set[Dep]]) -> "ReadWrites": + assert isinstance(dep, (WeakDep, StarDep, set)) + if not isinstance(dep, set): + dep = {dep} return ReadWrites( - OrderedSet.union(self.reads, [dep]), + OrderedSet.union(self.reads, dep), self.writes, self.index_exprs, self.range_vars, @@ -358,31 +438,32 @@ def __init__(self, var_ranges: VarRanges, normalize: bool) -> None: self._writes: OrderedSet[MemoryDep] = OrderedSet() self._index_exprs: OrderedSet[IndexExprDep] = OrderedSet() self._var_ranges: VarRanges = var_ranges - self._normalize: bool = normalize + self._should_normalize: bool = normalize - def canonicalize( - self, index: sympy.Expr - ) -> Tuple[sympy.Expr, Tuple[sympy.Symbol, ...], Tuple[sympy.Expr, ...]]: - if not self._normalize: - sizes = [V.graph.sizevars.simplify(x) for x in self._var_ranges.values()] - var_names = tuple( - k for k, v in zip(self._var_ranges.keys(), sizes) if v != 1 - ) - sizes = tuple(v for v in sizes if v != 1) - return index, var_names, sizes # type: ignore[return-value] + @staticmethod + def drop_unused_symbols(index, var_names, sizes): + """ + Reduction has last (reduced) dim in its sizes, but + downstream users won't. Normalize this away. + """ + if not isinstance(index, sympy.Expr): + # index can be an int + return + free_symbols = index.free_symbols + while var_names and var_names[-1] not in free_symbols: + var_names.pop() + sizes.pop() + @classmethod + def _normalize( + cls, index: sympy.Expr, var_ranges: VarRanges + ) -> Tuple[sympy.Expr, Tuple[sympy.Symbol, ...], Tuple[sympy.Expr, ...]]: # Try to further simplify the indexes even if simplify_loops didn't # convert it to the simplest form because of the interference from # different indexing formulas. free_symbols = index.free_symbols - var_ranges = { - k: V.graph.sizevars.simplify(v) - for k, v in self._var_ranges.items() - # TODO(jansel): explore this further normalization - # if k in free_symbols - } index_vars = [*var_ranges.keys()] - sizes = tuple(var_ranges.values()) + sizes = tuple(var_ranges.values()) # type: ignore[assignment] new_sizes, reindex, prune = V.graph.sizevars._simplify_loops( index_vars, sizes, @@ -397,14 +478,28 @@ def canonicalize( new_vars = [*new_vars.keys()] new_sizes = [*new_sizes] - free_symbols = index.free_symbols - while new_vars and new_vars[-1] not in free_symbols: - # Reduction has last (reduced) dim in its sizes, but - # downstream users won't. Normalize this away. - new_vars.pop() - new_sizes.pop() + cls.drop_unused_symbols(index, new_vars, new_sizes) return index, tuple(new_vars), tuple(new_sizes) # type: ignore[arg-type] + def canonicalize( + self, index: sympy.Expr + ) -> Tuple[sympy.Expr, Tuple[sympy.Symbol, ...], Tuple[sympy.Expr, ...]]: + if not self._should_normalize: + sizes = [V.graph.sizevars.simplify(x) for x in self._var_ranges.values()] + var_names = [k for k, v in zip(self._var_ranges.keys(), sizes) if v != 1] + sizes = [v for v in sizes if v != 1] + + self.drop_unused_symbols(index, var_names, sizes) + + return index, tuple(var_names), tuple(sizes) # type: ignore[return-value, arg-type] + var_ranges = { + k: V.graph.sizevars.simplify(v) + for k, v in self._var_ranges.items() + # TODO(jansel): explore this further normalization + # if k in free_symbols + } + return self._normalize(index, var_ranges) + def load(self, name: str, index: sympy.Expr) -> str: self._reads.add(MemoryDep(name, *self.canonicalize(index))) return f"load({name}, {sympy_str(index)})" diff --git a/torch/_inductor/ir.py b/torch/_inductor/ir.py index 0f4cc5d1b61f9d..55d43674807e0d 100644 --- a/torch/_inductor/ir.py +++ b/torch/_inductor/ir.py @@ -3719,6 +3719,7 @@ def get_default_sizes_body(self): self.get_store_function(), (args if self.get_reduction_type() else args[:1]), var_ranges, + *args, ) index_vars = [] reduce_vars: List[Any] = [] @@ -3786,36 +3787,55 @@ def simplify_and_reorder( if not V.graph.has_feature(self, BackendFeature.PREFER_STORE_LOOP_ORDER): memory_addrs.extend(body.reads_name2expr.values()) - def simplify_and_reorder(x_vars, support_vars, sizes): + def simplify_and_reorder(x_vars, support_vars, sizes, simplify_loops): sizes, reindex0, reindex1 = self._apply_loop_reordering( x_vars, support_vars, sizes, memory_addrs ) # for NHWC: reindex0([0,1,2,3]) = [0,2,3,1], reindex1([0,1,2,3]) = [0,3,2,1] x_vars = reindex0(x_vars) - sizes, reindex2, prune = V.graph.sizevars._simplify_loops( - x_vars, - sizes, - index_prevent_reordering(index_formulas, x_vars, sizes), - ) - reindex = fuse_reindexing(reindex1, reindex2) + + if simplify_loops: + sizes, reindex2, prune = V.graph.sizevars._simplify_loops( + x_vars, + sizes, + index_prevent_reordering(index_formulas, x_vars, sizes), + ) + reindex = fuse_reindexing(reindex1, reindex2) + else: + reindex = reindex1 return sizes, reindex, reindex1 support_vars = index_vars + reduce_vars + should_merge_loops = ( + self.get_device().type != "cuda" or not config.loop_ordering_after_fusion + ) iter_ranges, iter_reindex, _ = simplify_and_reorder( index_vars, support_vars, index_size, + should_merge_loops, ) + + # Like iteration dimensions, we may also want to delay merging reduction dimensions. + # E.g., if we reduce a tensor [M, N, K] for its M and N dimensions followed by a pointwise + # kernel, merging M and N dimension too early makes it hard to decide what loop order + # we should pick for the piontwise kernel so that it is fusible with the reduction. reduce_ranges, reduce_reindex, _ = simplify_and_reorder( - reduce_vars, support_vars, reduce_size + reduce_vars, support_vars, reduce_size, should_merge_loops ) # retrace the loop body with simplification and reordering applied (iter_vars, reduce_vars), var_ranges = dependencies.index_vars_no_squeeze( - iter_ranges, reduce_ranges, prefix="z" + iter_ranges, + reduce_ranges, + prefix="z", ) body = LoopBody( - body, [iter_reindex(iter_vars), reduce_reindex(reduce_vars)], var_ranges + body, + [iter_reindex(iter_vars), reduce_reindex(reduce_vars)], + var_ranges, + iter_vars, + reduce_vars, ) return (iter_ranges, reduce_ranges), body @@ -3886,9 +3906,9 @@ def __init__(self, layout, inputs, make_kernel_render): V.graph.register_operation(self) def get_read_writes(self): - return self.normalized_read_writes() + return self.extract_read_writes(normalize=True) - def normalized_read_writes(self): + def extract_read_writes(self, normalize): name = self.get_name() indexer = self.layout.make_indexer() @@ -3897,7 +3917,7 @@ def dummy(index, rindex): return ops.store(name, indexer(index), "fake") deps = dependencies.extract_read_writes( - dummy, self.get_size(), (), normalize=True + dummy, self.get_size(), (), normalize=normalize ) deps.reads = OrderedSet(dependencies.StarDep(x.get_name()) for x in self.inputs) return deps @@ -6787,9 +6807,19 @@ class LoopBody: indexing simplifications and makes it easier to analyze loop bodies. """ - def __init__(self, fn, args, var_ranges): + def __init__(self, fn, args, var_ranges, iter_vars, reduce_vars): super().__init__() + + _flat_sizes = tuple(var_ranges.values()) + self.sizes = ( + _flat_sizes[: len(iter_vars)], + _flat_sizes[len(iter_vars) :], + ) + + self.iter_vars = iter_vars + self.reduce_vars = reduce_vars self.var_ranges = var_ranges + self.indexing_exprs = {} self.indexing_exprs_name = {} self.reads = [] @@ -6804,6 +6834,112 @@ def __init__(self, fn, args, var_ranges): self.root_block = LoopBodyBlock(self, fn, args) self.indexing = None + def merge_loops(self) -> LoopBody: + """ + Merge both iteration and reduction loops and return a new LoopBody. + """ + old_body = self + old_sizes = self.sizes + old_iter_vars, old_reduce_vars = old_body.vars + old_iter_sizes, old_reduce_sizes = old_sizes + + index_exprs = [*old_body.indexing_exprs.values()] + + iter_sizes, iter_reindex, _ = V.graph.sizevars._simplify_loops( + old_iter_vars, + old_iter_sizes, + index_prevent_reordering(index_exprs, old_iter_vars, old_iter_sizes), + ) + + reduce_sizes, reduce_reindex, _ = V.graph.sizevars._simplify_loops( + old_reduce_vars, + old_reduce_sizes, + index_prevent_reordering(index_exprs, old_reduce_vars, old_reduce_sizes), + ) + + # if iter_sizes == old_iter_sizes: + # # no dimensions get merged. + # return old_sizes, old_body + + # Note: if no dimension get merges, the symbol prefix will + # remain 'y'. But if we merge dimensions, we change prefix to + # 'z'. If this is an issue, we can always retrace the LoopBody + # to change symbol prefix to 'z'. + # + # There is indeed an issue due to symbol name conflicting. + # y0 maybe reused for the y dimension later. + ( + iter_vars, + reduce_vars, + ), var_ranges = dependencies.index_vars_no_squeeze( + iter_sizes, reduce_sizes, prefix="t" + ) + new_body = LoopBody( + old_body, + [iter_reindex(iter_vars), reduce_reindex(reduce_vars)], + var_ranges, + iter_vars, + reduce_vars, + ) + + # use the original symbol prefix + # Can try to optimize if this is a bottleneck for compilation time + (iter_vars2, reduce_vars2), var_ranges2 = dependencies.index_vars_no_squeeze( + iter_sizes, reduce_sizes, prefix="z" + ) + new_body2 = LoopBody( + new_body, (iter_vars2, reduce_vars2), var_ranges2, iter_vars2, reduce_vars2 + ) + return new_body2 + + def reorder_iter_loops(self, new_order) -> LoopBody: + """ + Reorder iteration loops and return a new LoopBody. + """ + old_body = self + old_sizes = self.sizes + assert len(old_sizes[0]) == len(new_order) + reorder_fn = same_reorder(new_order) + + iter_size, reduce_size = old_sizes + new_iter_size = reorder_fn(iter_size) + + new_sizes = (new_iter_size, reduce_size) + + (iter_vars, reduce_vars), var_ranges = dependencies.index_vars_no_squeeze( + *new_sizes, prefix="t" # type: ignore[arg-type] + ) + + inverse_order = {b: a for a, b in enumerate(new_order)} + inverse_order = [inverse_order[i] for i in range(len(new_order))] + + def new_body(*indices: Sequence[sympy.Expr]) -> Any: + index = list(itertools.chain(*indices)) + assert len(index) == len(iter_size) + len(reduce_size) + iter_idx = index[: len(iter_size)] + reduce_idx = index[len(iter_size) :] + iter_idx = [iter_idx[i] for i in inverse_order] + return old_body(iter_idx, reduce_idx) + + loop_body = LoopBody( + new_body, (iter_vars, reduce_vars), var_ranges, iter_vars, reduce_vars + ) + + # use the original symbol prefix so we can do multiple round of reordering + (iter_vars2, reduce_vars2), var_ranges2 = dependencies.index_vars_no_squeeze( + *new_sizes, prefix="z" # type: ignore[arg-type] + ) + new_body = LoopBody( + loop_body, (iter_vars2, reduce_vars2), var_ranges2, iter_vars2, reduce_vars2 + ) + return new_body + + @property + def vars(self): + assert self.iter_vars is not None + assert self.reduce_vars is not None + return self.iter_vars, self.reduce_vars + @cache_on_self def get_nodes(self): all_graphs = itertools.chain( @@ -6832,6 +6968,8 @@ def debug_str(self): ) return "\n".join(lines) + __repr__ = debug_str + def add_index_expr(self, expr: sympy.Expr, category, buf_name): getattr(self, category).append(expr) if buf_name is not None: @@ -6872,7 +7010,9 @@ def get_index(self, name): def indexing_from_args(self, indices): index = [*itertools.chain.from_iterable(indices)] assert len(index) == len(self.var_ranges), (index, self.var_ranges) - assert all(v not in self.var_ranges for v in index) + assert all( + v not in self.var_ranges for v in index + ), f"{self.var_ranges=}, {indices=}" replacements = dict(zip(self.var_ranges.keys(), index)) return { name: sympy_subs(expr, replacements) diff --git a/torch/_inductor/metrics.py b/torch/_inductor/metrics.py index 5c26e322f12633..fe77279800e3da 100644 --- a/torch/_inductor/metrics.py +++ b/torch/_inductor/metrics.py @@ -49,6 +49,8 @@ class CppOuterLoopFusedCount: num_comprehensive_padding = 0 num_matches_for_scatter_upon_const_tensor = 0 +num_loop_reordering = 0 + # reset all counters def reset(): @@ -60,6 +62,7 @@ def reset(): global cpp_outer_loop_fused_inner_counts global num_comprehensive_padding global num_matches_for_scatter_upon_const_tensor + global num_loop_reordering generated_kernel_count = 0 generated_cpp_vec_kernel_count = 0 @@ -71,6 +74,7 @@ def reset(): cpp_outer_loop_fused_inner_counts.clear() num_comprehensive_padding = 0 num_matches_for_scatter_upon_const_tensor = 0 + num_loop_reordering = 0 @dataclass diff --git a/torch/_inductor/scheduler.py b/torch/_inductor/scheduler.py index 38142f42495273..7b2450f37366bc 100644 --- a/torch/_inductor/scheduler.py +++ b/torch/_inductor/scheduler.py @@ -65,6 +65,7 @@ log = logging.getLogger(__name__) fusion_log = torch._logging.getArtifactLogger(__name__, "fusion") +loop_ordering_log = torch._logging.getArtifactLogger(__name__, "loop_ordering") @dataclasses.dataclass @@ -240,6 +241,11 @@ def log_details(self) -> None: self.read_writes.writes, ) + def reorder_loops_by_dep_pair( + self, self_dep: MemoryDep, other_dep: MemoryDep + ) -> None: + return + def update_mutated_names(self, renames: Dict[str, str]) -> None: self.set_read_writes(self.read_writes.rename(renames)) @@ -830,12 +836,21 @@ def _compute_attrs( group_fn = self.scheduler.get_backend(self.node.get_device()).group_fn self.group = (self.node.get_device(), group_fn(self._sizes)) + # Don't normalize since normalization will merge loops which + # makes it hard to decide new loop orders. + should_normalize = ( + not config.loop_ordering_after_fusion + or self.node.get_device().type != "cuda" + ) + if isinstance(self.node, ir.TemplateBuffer): - self.set_read_writes(self.node.normalized_read_writes()) + self.set_read_writes( + self.node.extract_read_writes(normalize=should_normalize) + ) else: self.set_read_writes( dependencies.extract_read_writes( - self._body, *self._sizes, normalize=True + self._body, *self._sizes, normalize=should_normalize ) ) @@ -844,6 +859,49 @@ def recompute_size_and_body( ) -> None: self._compute_attrs(extra_indexing_constraints=extra_indexing_constraints) + def refresh_dependencies(self, normalize: bool) -> None: + # Fake dependencies are added manually. They can not be analyzed from + # extract_read_writes. Find them out and apply manually. + fake_deps = { + dep for dep in self.read_writes.reads if isinstance(dep, (WeakDep, StarDep)) + } + + # don't normalize since the loop order may need to be further changed + # later + self.set_read_writes( + dependencies.extract_read_writes( + self._body, *self._sizes, normalize=normalize + ).with_read(fake_deps) + ) + + def apply_new_loop_order(self, new_order: Sequence[int]) -> None: + self._body = self._body.reorder_iter_loops( + new_order, + ) + self._sizes = self._body.sizes + + self.refresh_dependencies(normalize=False) + + def reorder_loops_by_dep_pair( + self, self_dep: MemoryDep, other_dep: MemoryDep + ) -> None: + new_order = None + self_sizes = self._sizes[0] + if len(self_sizes) == self_dep.num_vars == other_dep.num_vars: + new_order = self_dep.decide_loop_order_to_match(other_dep) + + if new_order: + metrics.num_loop_reordering += 1 + loop_ordering_log.debug( + "Reorder loops for %s with order %s", self.get_name(), new_order + ) + self.apply_new_loop_order(new_order) + else: + loop_ordering_log.debug( + "Don't reordering %s because we can not decide the suitable loop order", + self.get_name(), + ) + def debug_str_extra(self) -> str: name = self.get_name() lines = [ @@ -963,6 +1021,22 @@ def _get_atomic_add_buffers(self) -> OrderedSet[str]: return buffers_store_as_atomic_add +def refresh_group_node_dependencies(group_snode: BaseSchedulerNode) -> None: + snodes = group_snode.snodes # type: ignore[attr-defined] + group_snode.set_read_writes( + dependencies.ReadWrites.merge_list([x.read_writes for x in snodes]) + ) + + group_snode.unmet_dependencies = ( + OrderedSet( + dep + for dep in OrderedSet.union(*[x.unmet_dependencies for x in snodes]) + if dep.name not in group_snode.get_buffer_names() + ) + - group_snode.read_writes.writes + ) + + def init_group_node( group_snode: BaseSchedulerNode, scheduler: Scheduler, @@ -976,18 +1050,7 @@ def init_group_node( *[x.ancestors for x in snodes if x.ancestors is not None] ) - group_snode.set_read_writes( - dependencies.ReadWrites.merge_list([x.read_writes for x in snodes]) - ) - - group_snode.unmet_dependencies = ( - OrderedSet( - dep - for dep in OrderedSet.union(*[x.unmet_dependencies for x in snodes]) - if dep.name not in group_snode.get_buffer_names() - ) - - group_snode.read_writes.writes - ) + refresh_group_node_dependencies(group_snode) group_snode.min_order = min(x.min_order for x in group_snode.snodes) group_snode.max_order = max(x.max_order for x in group_snode.snodes) @@ -1015,6 +1078,43 @@ def fuse( nodes = list(itertools.chain(node1.get_nodes(), node2.get_nodes())) return cls(node1.scheduler, nodes) + def reorder_loops_by_dep_pair( + self, self_dep: MemoryDep, other_dep: MemoryDep + ) -> None: + if self.is_template(): + # We can not really reorder loops for a triton template + return + self_sizes = None + for snode in self.snodes: + assert isinstance(snode, SchedulerNode) + if self_sizes is not None and self_sizes != snode._sizes[0]: + loop_ordering_log.debug( + "Can not reorder fused node due to different sizes" + ) + return + self_sizes = snode._sizes[0] + new_order = None + + assert self_sizes is not None + if len(self_sizes) == self_dep.num_vars == other_dep.num_vars: + new_order = self_dep.decide_loop_order_to_match(other_dep) + + if not new_order: + loop_ordering_log.debug( + "Dont reordering fused node %s because we can not decide the suitable loop order", + self.get_name(), + ) + return + metrics.num_loop_reordering += 1 + loop_ordering_log.debug( + "Reorder loops for fused node %s with order %s", self.get_name(), new_order + ) + for snode in self.snodes: + assert isinstance(snode, SchedulerNode) + snode.apply_new_loop_order(new_order) # type: ignore[arg-type] + + refresh_group_node_dependencies(self) + def __init__(self, scheduler: Scheduler, snodes: List[BaseSchedulerNode]) -> None: # NB: No need to call super().__init__() because we don't need to re-use any of its logic. init_group_node(self, scheduler, snodes) @@ -1692,6 +1792,7 @@ def _init(self, nodes: List[ir.Operation]) -> None: if config._pre_fusion_custom_pass is not None: self.nodes = config._pre_fusion_custom_pass(self.nodes) self.nodes = self.fuse_nodes(self.nodes) + self.merge_loops() self.finalize_multi_template_buffers() if config.reorder_for_compute_comm_overlap: self.nodes = comms.reorder_compute_and_comm_for_overlap(self.nodes) @@ -2122,6 +2223,39 @@ def compute_ancestors(self) -> None: node.min_order = order node.max_order = order + def merge_loops(self) -> None: + for node in self.nodes: + if not config.loop_ordering_after_fusion: + continue + + # Even for CPU, if we are using the halide backend, we still need + # the merge loops steps below + if not isinstance(node, (SchedulerNode, FusedSchedulerNode)) or ( + node.get_device().type != "cuda" and config.cpu_backend != "halide" + ): + continue + for snode in node.get_nodes(): + # merge loops for the scheduler node + if not isinstance(snode, SchedulerNode) or snode.is_template(): + continue + + snode._body = snode._body.merge_loops() + snode._sizes = snode._body.sizes + + # merge_loops is called after loop reordering. + # We still need retain fake dependencies since codegen the + # estimated amount of memory access rely on them. + snode.refresh_dependencies(normalize=True) + + # Note that for CPU backend, merging loops will change + # snode.group. It's fine for Triton backend. + # But if we simplify update snode.group like this: + # group_fn = self.get_backend(snode.node.get_device()).group_fn + # snode.group = (snode.node.get_device(), group_fn(snode._sizes)) + # There is still an issue due to different snode in a + # FusedSchedulerNode having different merged loops. + # Skip CPU backend for now. + def fuse_nodes(self, nodes: List[BaseSchedulerNode]) -> List[BaseSchedulerNode]: """ Combine eligible nodes into FusedSchedulerNodes. @@ -2664,6 +2798,78 @@ def decide_fusion_fail_reason( return str(reasons) + def has_shared_data_after_reordering_loop( + self, node1: BaseSchedulerNode, node2: BaseSchedulerNode + ) -> bool: + """ + Right now just greedily reorder the loop of node1 to be compatible with node2, + but ideally we should have some heuristics to reorder the loop for node2 + to be compatibile with node1 if that's more efficient. + """ + + # TODO Don't do loop reordering for CPU for now. + # Should debug more why it does not work for CPU codegen + if not config.loop_ordering_after_fusion or any( + n.get_device().type == "cpu" for n in [node1, node2] + ): + return False + + node1_buffer_names = node1.read_writes.buffer_names() + node2_buffer_names = node2.read_writes.buffer_names() + # Fast path: no common buffers. + common_buffer_names = node1_buffer_names & node2_buffer_names + if not common_buffer_names: + return False + + node1_name2dep = {dep.name: dep for dep in node1.read_writes.reads_and_writes()} + node2_name2dep = {dep.name: dep for dep in node2.read_writes.reads_and_writes()} + + # Find the commons buffers that has different loop orders + candidates = [] + for buffer_name in common_buffer_names: + lhs_dep = node1_name2dep[buffer_name] + rhs_dep = node2_name2dep[buffer_name] + if ( + lhs_dep.normalize_with_stride_order() + == rhs_dep.normalize_with_stride_order() + ): + candidates.append( + ( + V.graph.sizevars.size_hint(lhs_dep.get_numel(), fallback=0), + lhs_dep, + rhs_dep, + ) + ) + + if len(candidates) == 0: + return False + + # Pick the largest buffer to guide the loop reordering + numel, lhs_dep, rhs_dep = sorted(candidates, reverse=True, key=lambda x: x[0])[ + 0 + ] + + if lhs_dep.num_vars != rhs_dep.num_vars: + # this can happen due to we don't merge loops. + # We can not do loop reordering in this case right now + # Simply returning true if the two Deps are the same after + # normalization (merging loops) + return lhs_dep.normalize() == rhs_dep.normalize() + + # Only reorder loops for pointwise for now + if not node1.is_reduction(): + node1.reorder_loops_by_dep_pair(lhs_dep, rhs_dep) + elif not node2.is_reduction(): + node2.reorder_loops_by_dep_pair(rhs_dep, lhs_dep) + else: + loop_ordering_log.debug( + "Don't reorder loops since both nodes are reductions: %s v.s. %s", + node1.get_name(), + node2.get_name(), + ) + + return self.score_fusion_memory(node1, node2) > 0 + def can_fuse(self, node1: BaseSchedulerNode, node2: BaseSchedulerNode) -> bool: """ Determine if it is possible to combine node1 and node2 into a @@ -2722,6 +2928,17 @@ def can_fuse(self, node1: BaseSchedulerNode, node2: BaseSchedulerNode) -> bool: del device2 no_shared_data = self.score_fusion_memory(node1, node2) == 0 + if no_shared_data: + no_shared_data = not self.has_shared_data_after_reordering_loop( + node1, node2 + ) + + loop_ordering_log.debug( + "%s and %s has%s shared data", + node1.get_name(), + node2.get_name(), + " no" if no_shared_data else "", + ) if no_shared_data and ( not config.aggressive_fusion or node1.is_reduction() or node2.is_reduction() ): @@ -2857,6 +3074,12 @@ def fusable_read_and_write(self, read: Dep, write: MemoryDep) -> bool: read_name = read.name if read_name in self.mutation_renames: read_name = self.mutation_renames[read_name] + + if read.num_vars != write.num_vars: + # merge loops + read = read.normalize() + write = write.normalize() + return ( read_name == write.name and not free_symbol_is_type(read.index, SymT.TMP) diff --git a/torch/_logging/_registrations.py b/torch/_logging/_registrations.py index 69af10f4a4a5d5..b85ef6d880ae44 100644 --- a/torch/_logging/_registrations.py +++ b/torch/_logging/_registrations.py @@ -157,6 +157,11 @@ "Detailed Inductor fusion decisions. More detailed than 'schedule'", off_by_default=True, ) +register_artifact( + "loop_ordering", + "Logs related to loop ordering", + off_by_default=True, +) register_artifact( "overlap", "Detailed Inductor compute/comm overlap decisions", From b379a6f1add8a4e07371f2bad8464ee569e89497 Mon Sep 17 00:00:00 2001 From: Zhuoran Zhao Date: Thu, 29 Aug 2024 21:59:04 +0000 Subject: [PATCH 0253/1018] Fix AOTInductor complication on ROCM (#134522) Summary: Original PR (https://github.com/pytorch/pytorch/pull/124123) is broken by cpp_builder refactoring So resubmit it to fix Test Plan: Test with command here: https://www.internalfb.com/phabricator/paste/view/P1549765548 Differential Revision: D61827208 Pull Request resolved: https://github.com/pytorch/pytorch/pull/134522 Approved by: https://github.com/frank-wei --- torch/_inductor/cpp_builder.py | 28 +++++++++++++++++++------- torch/_inductor/fx_passes/reinplace.py | 11 ++++++++-- 2 files changed, 30 insertions(+), 9 deletions(-) diff --git a/torch/_inductor/cpp_builder.py b/torch/_inductor/cpp_builder.py index 72348824bd75a1..f78bc354cc90cc 100644 --- a/torch/_inductor/cpp_builder.py +++ b/torch/_inductor/cpp_builder.py @@ -143,7 +143,9 @@ def get_cpp_compiler() -> str: check_compiler_exist_windows(compiler) else: if config.is_fbcode(): - return build_paths.cc() + return ( + build_paths.cc() if torch.version.hip is None else build_paths.clang() + ) if isinstance(config.cpp.cxx, (list, tuple)): search = tuple(config.cpp.cxx) else: @@ -633,10 +635,20 @@ def _setup_standard_sys_libs( if config.is_fbcode(): cflags.append("nostdinc") - include_dirs.append(build_paths.sleef()) - include_dirs.append(build_paths.cc_include()) - include_dirs.append(build_paths.libgcc()) - include_dirs.append(build_paths.libgcc_arch()) + # Note that the order of include paths do matter, as a result + # we need to have several branches interleaved here + if torch.version.hip is None: + include_dirs.append(build_paths.sleef()) + include_dirs.append(build_paths.openmp()) + include_dirs.append(build_paths.python()) + if torch.version.hip is not None: + include_dirs.append(build_paths.clang_include()) + include_dirs.append(build_paths.gcc_include()) + include_dirs.append(build_paths.gcc_install_tools_include()) + else: + include_dirs.append(build_paths.cc_include()) + include_dirs.append(build_paths.libgcc()) + include_dirs.append(build_paths.libgcc_arch()) include_dirs.append(build_paths.libgcc_backward()) include_dirs.append(build_paths.glibc()) include_dirs.append(build_paths.linux_kernel()) @@ -1070,7 +1082,9 @@ def get_cpp_torch_cuda_options( and "CUDA_HOME" not in os.environ and "CUDA_PATH" not in os.environ ): - os.environ["CUDA_HOME"] = build_paths.cuda() + os.environ["CUDA_HOME"] = ( + build_paths.rocm() if torch.version.hip else build_paths.cuda() + ) _set_gpu_runtime_env() from torch.utils import cpp_extension @@ -1086,7 +1100,7 @@ def get_cpp_torch_cuda_options( libraries += ["amdhip64"] else: libraries += ["c10_hip", "torch_hip"] - definations.append(" __HIP_PLATFORM_AMD__") + definations.append(" __HIP_PLATFORM_AMD__") else: if config.is_fbcode(): libraries += ["cuda"] diff --git a/torch/_inductor/fx_passes/reinplace.py b/torch/_inductor/fx_passes/reinplace.py index a7cd01dc0043e3..2618c7591b605b 100644 --- a/torch/_inductor/fx_passes/reinplace.py +++ b/torch/_inductor/fx_passes/reinplace.py @@ -11,7 +11,7 @@ kernel_side_table, triton_kernel_wrapper_functional, ) -from torch._inductor import inductor_prims +from torch._inductor import config, inductor_prims from torch._inductor.fx_utils import get_node_storage, is_node_realized from torch._inductor.lowering import ( inplaceable_foreach_ops as inplaceable_foreach_ops_lowerings, @@ -591,7 +591,14 @@ def tensor_with_same_storage_already_reinplaced(arg): if isinstance(kernel, JITFunction): kernel_name = kernel.fn.__name__ elif isinstance(kernel, Autotuner): - kernel_name = kernel.base_fn.__name__ + if config.is_fbcode(): + # Autotuner has different implementations for AMD and NV + if torch.version.hip is None: + kernel_name = kernel.base_fn.__name__ + else: + kernel_name = kernel.fn.__name__ + else: + kernel_name = kernel.base_fn.__name__ else: raise AssertionError("Unknown triton kernel type") From 09ef3f64f5cbb38dc589876b3b740a5a91ee9749 Mon Sep 17 00:00:00 2001 From: fduwjj Date: Wed, 28 Aug 2024 15:35:16 -0700 Subject: [PATCH 0254/1018] [c10d] Remove Option for ProcessGroup and Expose backend Options to reflect the correct code structure (#132931) We introduced the dispatchable backend for a ProcessGroup and collective in https://github.com/pytorch/pytorch/issues/86225. This PR is a follow-up cleanup to clean up the option of a ProcessGroup and ask users to either set timeout or backend later on or directly create backend after creating a PG. Also PGNCCL is using option class from ProcessGroup but we actually should use Option from backend class. So this PR is to make the type or name to be aligned with what we are doing in cpp side. I don't change the signature for the public API, so they still use args named "pg_options" We need to make changes to the test to make it aligned with the change. Pull Request resolved: https://github.com/pytorch/pytorch/pull/132931 Approved by: https://github.com/H-Huang --- test/distributed/test_c10d_common.py | 8 +- test/distributed/test_device_mesh.py | 3 +- torch/_C/_distributed_c10d.pyi | 15 +--- torch/csrc/distributed/c10d/ProcessGroup.cpp | 24 +----- torch/csrc/distributed/c10d/ProcessGroup.hpp | 78 +++++++++++++------- torch/csrc/distributed/c10d/init.cpp | 60 ++++++++------- torch/distributed/device_mesh.py | 5 +- torch/distributed/distributed_c10d.py | 66 ++++++++++------- 8 files changed, 140 insertions(+), 119 deletions(-) diff --git a/test/distributed/test_c10d_common.py b/test/distributed/test_c10d_common.py index 3e5538d57e38ae..d96abb1ca82675 100644 --- a/test/distributed/test_c10d_common.py +++ b/test/distributed/test_c10d_common.py @@ -1820,6 +1820,7 @@ def test_init_process_group_optional_backend(self): def test_init_process_group_for_all_backends(self): for backend in dist.Backend.backend_list: + excepted_backend = backend # skip if the backend is not available on the system if backend == dist.Backend.UNDEFINED: continue @@ -1835,6 +1836,11 @@ def test_init_process_group_for_all_backends(self): elif backend == dist.Backend.UCC: if not dist.is_ucc_available(): continue + # Multi-threaded PG is defined as a pure python class. + # Its pg.name() does not going through Pybind, so its backend name + # is still "threaded" instead of "custom". + elif backend != "threaded": + excepted_backend = "custom" with tempfile.NamedTemporaryFile(delete=False) as f: store = dist.FileStore(f.name, self.world_size) @@ -1847,7 +1853,7 @@ def test_init_process_group_for_all_backends(self): pg = c10d._get_default_group() self.assertEqual(pg.rank(), self.rank) self.assertEqual(pg.size(), self.world_size) - self.assertEqual(pg.name(), str(backend)) + self.assertEqual(pg.name(), str(excepted_backend)) dist.destroy_process_group() diff --git a/test/distributed/test_device_mesh.py b/test/distributed/test_device_mesh.py index 972c0ae89028be..08ee59dbff0a53 100644 --- a/test/distributed/test_device_mesh.py +++ b/test/distributed/test_device_mesh.py @@ -232,7 +232,8 @@ def test_set_mesh_dim_group_options(self): mesh_tensor = torch.arange(4).reshape(2, 2) mesh = DeviceMesh(device_type, mesh_tensor) - self.assertEqual(mesh.get_group(1)._get_backend_name(), "fake") + # Fake pg only have BackendType as BackendType::CUSTOM. + self.assertEqual(mesh.get_group(1)._get_backend_name(), "custom") class DeviceMeshTestNDim(DTensorTestBase): diff --git a/torch/_C/_distributed_c10d.pyi b/torch/_C/_distributed_c10d.pyi index 6033d969925972..f89f0b50c8582e 100644 --- a/torch/_C/_distributed_c10d.pyi +++ b/torch/_C/_distributed_c10d.pyi @@ -296,15 +296,6 @@ class Backend: def _set_default_timeout(self, timeout: timedelta) -> None: ... class ProcessGroup: - class Options: - def __init__(self, backend: str, timeout: timedelta = ...) -> None: ... - @property - def backend(self) -> str: ... - @property - def _timeout(self) -> timedelta: ... - @_timeout.setter - def _timeout(self, val: timedelta) -> None: ... - class BackendType(Enum): UNDEFINED = ... GLOO = ... @@ -319,7 +310,6 @@ class ProcessGroup: store: Store, rank: int, size: int, - options: Options, ) -> None: ... def rank(self) -> int: ... def size(self) -> int: ... @@ -509,6 +499,7 @@ class ProcessGroup: @property def _device_types(self) -> list[torch.device]: ... def _get_backend(self, device: torch.device) -> Backend: ... + def _set_default_backend(self, backend_type: BackendType) -> None: ... def _register_backend( self, device: torch.device, @@ -533,7 +524,7 @@ class ProcessGroup: class ProcessGroupGloo(Backend): class Device: ... - class Options(ProcessGroup.Options): + class Options(Backend.Options): devices: list[ProcessGroupGloo.Device] threads: int @@ -563,7 +554,7 @@ class ProcessGroupNCCL(Backend): min_ctas: int max_ctas: int - class Options(ProcessGroup.Options): + class Options(Backend.Options): config: ProcessGroupNCCL.NCCLConfig is_high_priority_stream: bool split_from: ProcessGroupNCCL diff --git a/torch/csrc/distributed/c10d/ProcessGroup.cpp b/torch/csrc/distributed/c10d/ProcessGroup.cpp index 70356b3bf382ce..f565de2013260c 100644 --- a/torch/csrc/distributed/c10d/ProcessGroup.cpp +++ b/torch/csrc/distributed/c10d/ProcessGroup.cpp @@ -14,24 +14,6 @@ namespace c10d { -static ProcessGroup::BackendType strToBackendType(std::string_view backend) { - if (backend == "undefined") { - return ProcessGroup::BackendType::UNDEFINED; - } else if (backend == "gloo") { - return ProcessGroup::BackendType::GLOO; - } else if (backend == "nccl") { - return ProcessGroup::BackendType::NCCL; - } else if (backend == "xccl") { - return ProcessGroup::BackendType::XCCL; - } else if (backend == "ucc") { - return ProcessGroup::BackendType::UCC; - } else if (backend == "mpi") { - return ProcessGroup::BackendType::MPI; - } else { - return ProcessGroup::BackendType::CUSTOM; - } -} - std::string opTypeToString(OpType opType) { switch (opType) { case OpType::BROADCAST: @@ -121,13 +103,11 @@ c10::intrusive_ptr ProcessGroup::getBackend( ProcessGroup::ProcessGroup( const c10::intrusive_ptr<::c10d::Store>& store, int rank, - int size, - c10::intrusive_ptr options) + int size) : store_(store), rank_(rank), size_(size), - options_(std::move(options)), - backendType_(strToBackendType(options_->backend)), + backendType_(BackendType::UNDEFINED), dist_debug_level_(debug_level()) { C10_LOG_API_USAGE_ONCE("c10d.process_group"); } diff --git a/torch/csrc/distributed/c10d/ProcessGroup.hpp b/torch/csrc/distributed/c10d/ProcessGroup.hpp index 73fc2bda701327..83d2729fc43d43 100644 --- a/torch/csrc/distributed/c10d/ProcessGroup.hpp +++ b/torch/csrc/distributed/c10d/ProcessGroup.hpp @@ -45,24 +45,6 @@ namespace c10d { // class TORCH_API ProcessGroup : public torch::CustomClassHolder { public: - // ProcessGroup Options is a base struct that defines the basic options - // when constructing a ProcessGroup. Each ProcessGroup subclass should - // extend this struct and define its options if it wants to provide more - // config options (beyond basic ones defined here) to end user. - struct TORCH_API Options : torch::CustomClassHolder { - explicit Options( - std::string backend, - std::chrono::milliseconds timeout = kProcessGroupDefaultTimeout) - : timeout(timeout), backend(std::move(backend)) {} - ~Options() override = default; - - std::chrono::milliseconds timeout; - - // backend name - // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members) - const std::string backend; - }; - enum BackendType : uint8_t { UNDEFINED = 0, GLOO = 1, @@ -73,6 +55,45 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder { XCCL = 6, }; + static std::string backendTypeToString(const BackendType& type) { + switch (type) { + case BackendType::GLOO: + return "gloo"; + case BackendType::NCCL: + return "nccl"; + case BackendType::XCCL: + return "xccl"; + case BackendType::UCC: + return "ucc"; + case BackendType::MPI: + return "mpi"; + case BackendType::UNDEFINED: + return "undefined"; + case BackendType::CUSTOM: + return "custom"; + default: + TORCH_CHECK(false, "THis should never happen!"); + } + }; + + static BackendType strToBackendType(const std::string& backend) { + if (backend == "undefined") { + return BackendType::UNDEFINED; + } else if (backend == "gloo") { + return BackendType::GLOO; + } else if (backend == "nccl") { + return BackendType::NCCL; + } else if (backend == "xccl") { + return BackendType::XCCL; + } else if (backend == "ucc") { + return BackendType::UCC; + } else if (backend == "mpi") { + return BackendType::MPI; + } else { + return BackendType::CUSTOM; + } + }; + // Not used, set for backwards compatibility and only used for TypeDef in // Ops.cpp explicit ProcessGroup(int rank, int size); @@ -80,8 +101,7 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder { explicit ProcessGroup( const c10::intrusive_ptr<::c10d::Store>& store, int rank, - int size, - c10::intrusive_ptr options); + int size); ~ProcessGroup() override; int getRank() const { @@ -104,7 +124,7 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder { } virtual const std::string getBackendName() const { - return options_->backend; + return backendTypeToString(backendType_); }; BackendType getBackendType() const { @@ -612,10 +632,6 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder { opts.timeout.count()); } - c10::intrusive_ptr getOptions() { - return options_; - } - bool hasBackends() { return !deviceTypeToBackendType_.empty(); } @@ -656,6 +672,14 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder { return backendTypeToBackend_.at(backendType_); } + void setDefaultBackend(const BackendType& backendType) { + backendType_ = backendType; + } + + void setDefaultBackend(const std::string& backend) { + backendType_ = strToBackendType(backend); + } + c10::intrusive_ptr getBackend(c10::DeviceType deviceType); c10::intrusive_ptr getBackend(BackendType backendType) const { @@ -728,9 +752,7 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder { // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members) const int size_; // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members) - const c10::intrusive_ptr options_; - // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members) - const BackendType backendType_; + BackendType backendType_; std::string pg_desc_; // Debug level setting. It is parsed once when ProcessGroup is constructed and diff --git a/torch/csrc/distributed/c10d/init.cpp b/torch/csrc/distributed/c10d/init.cpp index e3ed6d6bd4bcb4..62b2916e5f7af8 100644 --- a/torch/csrc/distributed/c10d/init.cpp +++ b/torch/csrc/distributed/c10d/init.cpp @@ -1818,8 +1818,7 @@ communication mechanism. py::init< const c10::intrusive_ptr<::c10d::Store>&, int, - int, - c10::intrusive_ptr<::c10d::ProcessGroup::Options>>(), + int>(), py::call_guard()) .def("rank", &::c10d::ProcessGroup::getRank) .def("size", &::c10d::ProcessGroup::getSize) @@ -1829,7 +1828,6 @@ communication mechanism. "_backend_id", &::c10d::ProcessGroup::getBackendID, py::arg("backend_type")) - .def_property_readonly("options", &::c10d::ProcessGroup::getOptions) .def( "broadcast", &::c10d::ProcessGroup::broadcast, @@ -2139,6 +2137,14 @@ communication mechanism. }, py::arg("device"), py::call_guard()) + .def( + "_set_default_backend", + [](const c10::intrusive_ptr<::c10d::ProcessGroup>& self, + const ::c10d::ProcessGroup::BackendType& backendType) { + return self->setDefaultBackend(backendType); + }, + py::arg("backend_type"), + py::call_guard()) .def( "_register_on_completion_hook", [](const c10::intrusive_ptr<::c10d::ProcessGroup>& self, @@ -2242,27 +2248,6 @@ The hook must have the following signature: .value("CUSTOM", ::c10d::ProcessGroup::BackendType::CUSTOM) .export_values(); - // base ProcessGroup::Options binding - auto processGroupOptions = - intrusive_ptr_class_<::c10d::ProcessGroup::Options>( - processGroup, - "Options", - R"( -Base class for all processes group options implementations, such as the nccl -options :class:`~torch.distributed.ProcessGroupNCCL.Options`). -)") - .def( - py::init([](const std::string& backend, - const std::chrono::milliseconds& timeout) { - return c10::make_intrusive<::c10d::ProcessGroup::Options>( - backend, timeout); - }), - py::arg("backend"), - py::arg("timeout") = kProcessGroupDefaultTimeout, - py::call_guard()) - .def_readonly("backend", &::c10d::ProcessGroup::Options::backend) - .def_readwrite("_timeout", &::c10d::ProcessGroup::Options::timeout); - // TODO: The collection definitions handles direct instantiation of // ProcessGroup subclasses (e.g. dist.ProcessGroupGloo). This is not supported // and should be removed once all tests are transitioned @@ -2561,6 +2546,29 @@ options :class:`~torch.distributed.ProcessGroupNCCL.Options`). &::c10d::Backend::endCoalescing, py::call_guard()); + // base Backend::Options binding + // TODO: Maybe we can consider how to merge this with + // `DistributedBackendOptions`. + auto backendOptions = + intrusive_ptr_class_<::c10d::Backend::Options>( + backend, + "Options", + R"( +Base class for all backend options implementations, such as the nccl +options :class:`~torch.distributed.ProcessGroupNCCL.Options`). +)") + .def( + py::init([](const std::string& backend, + const std::chrono::milliseconds& timeout) { + return c10::make_intrusive<::c10d::Backend::Options>( + backend, timeout); + }), + py::arg("backend"), + py::arg("timeout") = kProcessGroupDefaultTimeout, + py::call_guard()) + .def_readonly("backend", &::c10d::Backend::Options::backend) + .def_readwrite("_timeout", &::c10d::Backend::Options::timeout); + #ifdef USE_C10D_GLOO static const std::string GLOO_SOCKET_IFNAME_ENV = "GLOO_SOCKET_IFNAME"; @@ -2572,7 +2580,7 @@ options :class:`~torch.distributed.ProcessGroupNCCL.Options`). shared_ptr_class_<::gloo::transport::Device>(processGroupGloo, "Device"); intrusive_ptr_class_<::c10d::ProcessGroupGloo::Options>( - processGroupGloo, "_Options", processGroupOptions) + processGroupGloo, "_Options", backendOptions) .def(py::init<>()) .def_readwrite("_devices", &::c10d::ProcessGroupGloo::Options::devices) .def_readwrite("_threads", &::c10d::ProcessGroupGloo::Options::threads); @@ -2799,7 +2807,7 @@ for details. intrusive_ptr_class_<::c10d::ProcessGroupNCCL::Options>( processGroupNCCL, "Options", - processGroupOptions, + backendOptions, R"( ProcessGroup options for the NCCL backend diff --git a/torch/distributed/device_mesh.py b/torch/distributed/device_mesh.py index 0ad17536ef1832..f6e77fe166c655 100644 --- a/torch/distributed/device_mesh.py +++ b/torch/distributed/device_mesh.py @@ -36,6 +36,7 @@ def _init_device_mesh_stub(): else: + from torch._C._distributed_c10d import Backend as C10dBackend from torch.distributed.distributed_c10d import ( _find_pg_by_ranks_and_tag, _get_default_group, @@ -66,7 +67,7 @@ def __init__(self) -> None: self.mesh_stack: List[DeviceMesh] = [] self.child_to_root_mapping: Dict[DeviceMesh, DeviceMesh] = {} self.mesh_dim_group_options: Dict[ - int, Tuple[str, Optional[ProcessGroup.Options]] + int, Tuple[str, Optional[C10dBackend.Options]] ] = {} self.root_to_flatten_mapping: Dict[DeviceMesh, Dict[str, DeviceMesh]] = {} # Record flatten mesh name to its mesh dim index in root mesh. @@ -279,7 +280,7 @@ def _set_mesh_dim_group_options( self, dim: int, backend: str, - pg_options: Optional[ProcessGroup.Options] = None, + pg_options: Optional[C10dBackend.Options] = None, ) -> None: self.mesh_dim_group_options[dim] = (backend, pg_options) diff --git a/torch/distributed/distributed_c10d.py b/torch/distributed/distributed_c10d.py index 9fa3224873c9fc..2d9357bbd15a44 100644 --- a/torch/distributed/distributed_c10d.py +++ b/torch/distributed/distributed_c10d.py @@ -280,6 +280,7 @@ class Backend(str): NCCL: ProcessGroup.BackendType.NCCL, XCCL: ProcessGroup.BackendType.XCCL, UCC: ProcessGroup.BackendType.UCC, + MPI: ProcessGroup.BackendType.MPI, } def __new__(cls, name: str): @@ -1554,7 +1555,7 @@ def init_process_group( backend, store, group_name, - pg_options=pg_options, + backend_options=pg_options, timeout=timeout, device_id=device_id, group_desc="default_pg", @@ -1651,7 +1652,7 @@ def _new_process_group_helper( backend, store, group_name, - pg_options=None, + backend_options=None, timeout=None, pg_tag=None, device_id=None, @@ -1730,11 +1731,17 @@ def _new_process_group_helper( return GroupMember.NON_GROUP_MEMBER, None prefix_store = PrefixStore(f"{group_name}/", store) - base_pg_options = ProcessGroup.Options(backend=str(backend)) - base_pg_options._timeout = timeout + # The backend for PG will be set later based on what's inside BackendConfig + # and timeout are set in each backend's option. pg: ProcessGroup = ProcessGroup( - prefix_store, group_rank, group_size, base_pg_options + prefix_store, + group_rank, + group_size, ) + # Set the default backend when only single backend is passed in. + if "," not in str(backend) and ":" not in str(backend): + assert backend in Backend.backend_type_map, f"Unknown backend type {backend}" + pg._set_default_backend(Backend.backend_type_map[backend]) if device_id: pg.bound_device_id = device_id backend_config = BackendConfig(backend) @@ -1761,8 +1768,8 @@ def _new_process_group_helper( backend_prefix_store, backend_class.rank(), backend_class.size(), - base_pg_options, ) + pg._set_default_backend(backend_type) elif backend_str == Backend.GLOO: # TODO: remove this check after lazy initialization is supported # if pg_options is not None: @@ -1774,27 +1781,30 @@ def _new_process_group_helper( elif backend_str == Backend.NCCL: if not is_nccl_available(): raise RuntimeError("Distributed package doesn't have NCCL built in") - if pg_options is not None: + if backend_options is not None: assert isinstance( - pg_options, ProcessGroupNCCL.Options - ), "Expected pg_options argument to be of type ProcessGroupNCCL.Options" - if pg_options._timeout != timeout: + backend_options, ProcessGroupNCCL.Options + ), "Expected backend_options argument to be of type ProcessGroupNCCL.Options" + if backend_options._timeout != timeout: warnings.warn( - "pg_options._timeout was specified, " + "backend_options._timeout was specified, " "but timeout kwarg has a default value that will always override it. " ) else: - # default pg_options for NCCL - pg_options = ProcessGroupNCCL.Options() - pg_options.is_high_priority_stream = False - pg_options._timeout = timeout + # default backend_options for NCCL + backend_options = ProcessGroupNCCL.Options() + backend_options.is_high_priority_stream = False + backend_options._timeout = timeout + if split_from: - pg_options.split_from = split_from - pg_options.split_color = _process_group_color(global_ranks_in_group) - pg_options.global_ranks_in_group = global_ranks_in_group - pg_options.group_name = group_name + backend_options.split_from = split_from + backend_options.split_color = _process_group_color( + global_ranks_in_group + ) + backend_options.global_ranks_in_group = global_ranks_in_group + backend_options.group_name = group_name backend_class = ProcessGroupNCCL( - backend_prefix_store, group_rank, group_size, pg_options + backend_prefix_store, group_rank, group_size, backend_options ) backend_type = ProcessGroup.BackendType.NCCL elif backend_str == Backend.UCC and is_ucc_available(): @@ -1840,7 +1850,7 @@ def _new_process_group_helper( dist_backend_opts.group_id = group_name dist_backend_opts.global_ranks_in_group = global_ranks_in_group - backend_class = creator_fn(dist_backend_opts, pg_options) + backend_class = creator_fn(dist_backend_opts, backend_options) # Set sequence numbers for gloo and nccl backends. if backend_str == Backend.GLOO: @@ -4475,12 +4485,15 @@ def split_group( global_ranks_in_my_group = [parent_group_to_global_ranks[rank] for rank in my_group] prefix_store = PrefixStore(f"{group_name}/", default_store) - base_pg_options = ProcessGroup.Options(backend=str(backend)) - base_pg_options._timeout = timeout + # We register the backend after initializing and timeout is set in pg_options. pg: ProcessGroup = ProcessGroup( - prefix_store, group_rank, len(my_group), base_pg_options + prefix_store, + group_rank, + len(my_group), ) + backend_type = ProcessGroup.BackendType.NCCL pg.bound_device_id = device_id + pg._set_default_backend(backend_type) pg_options._timeout = timeout pg_options.split_from = parent_backend @@ -4490,7 +4503,6 @@ def split_group( backend_class = ProcessGroupNCCL( prefix_store, group_rank, len(my_group), pg_options ) - backend_type = ProcessGroup.BackendType.NCCL backend_class._set_sequence_number_for_group() pg._register_backend(torch.device("cuda"), backend_type, backend_class) @@ -4613,7 +4625,7 @@ def _new_group_with_tag( ranks=None, timeout=None, backend=None, - pg_options=None, + backend_options=None, pg_tag=None, use_local_synchronization=False, group_desc=None, @@ -4688,7 +4700,7 @@ def _new_group_with_tag( backend, default_store, group_name, - pg_options=pg_options, + backend_options=backend_options, timeout=timeout, pg_tag=pg_tag, device_id=device_id, From e4dec2209c344b46786313a10ee8e7be6fc5a653 Mon Sep 17 00:00:00 2001 From: Laith Sakka Date: Tue, 27 Aug 2024 22:59:43 -0700 Subject: [PATCH 0255/1018] add add_loop benchmarks (#134652) This benchmark measure the cost of compiling the following function in eager and inductor its basically two benchmarks. ``` @torch.compile(backend=self.backend, fullgraph=True) def f(a, b): result = a.clone() for i in range(1000): if i % 3 == 0: result = result + b elif i % 3 == 1: result = result + 8 * b else: result = result.sin() return result ``` PYTHONPATH=$(pwd) python benchmarks/add_loop.py out ``` collecting compile time instruction count for add_loop_eager compile time instruction count for iteration 0 is 8286649663 compile time instruction count for iteration 1 is 2838971338 compile time instruction count for iteration 2 is 2834263023 compile time instruction count for iteration 3 is 2829447493 compile time instruction count for iteration 4 is 2830904231 compile time instruction count for iteration 5 is 2830281077 compile time instruction count for iteration 6 is 2831466595 compile time instruction count for iteration 7 is 2830732164 compile time instruction count for iteration 8 is 2831088056 compile time instruction count for iteration 9 is 2831204407 collecting compile time instruction count for add_loop_inductor compile time instruction count for iteration 0 is 32585687849 compile time instruction count for iteration 1 is 11747553436 compile time instruction count for iteration 2 is 11746959875 compile time instruction count for iteration 3 is 11749479461 compile time instruction count for iteration 4 is 11750053711 compile time instruction count for iteration 5 is 11750793958 compile time instruction count for iteration 6 is 11751673576 compile time instruction count for iteration 7 is 11754552912 compile time instruction count for iteration 8 is 11753723127 compile time instruction count for iteration 9 is 11759059942 ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/134652 Approved by: https://github.com/anijain2305 ghstack dependencies: #133834, #134635, #134649 --- .../pr_time_benchmarks/benchmarks/add_loop.py | 55 +++++++++++++++++++ 1 file changed, 55 insertions(+) create mode 100644 benchmarks/dynamo/pr_time_benchmarks/benchmarks/add_loop.py diff --git a/benchmarks/dynamo/pr_time_benchmarks/benchmarks/add_loop.py b/benchmarks/dynamo/pr_time_benchmarks/benchmarks/add_loop.py new file mode 100644 index 00000000000000..9959ab452db62c --- /dev/null +++ b/benchmarks/dynamo/pr_time_benchmarks/benchmarks/add_loop.py @@ -0,0 +1,55 @@ +import gc +import sys + +from benchmark_base import BenchmarkBase + +import torch + + +class Benchmark(BenchmarkBase): + def __init__(self, backend): + self.backend = backend + + def name(self): + return f"add_loop_{self.backend}" + + def description(self): + return "a loop over 100 add node" + + def _prepare_once(self, dynamic=True): + self.a = torch.ones(1000) + self.b = torch.torch.ones(1000) + + def _prepare(self): + torch._dynamo.reset() + gc.collect() + gc.disable() + + def _work(self): + @torch.compile(backend=self.backend, fullgraph=True) + def f(a, b): + result = a.clone() + for i in range(1000): + if i % 3 == 0: + result = result + b + elif i % 3 == 1: + result = result + 8 * b + else: + result = result.sin() + return result + + f(self.a, self.b) + + +def main(): + result_path = sys.argv[1] + Benchmark( + "eager" + ).enable_compile_time_instruction_count().collect_all().append_results(result_path) + Benchmark( + "inductor" + ).enable_compile_time_instruction_count().collect_all().append_results(result_path) + + +if __name__ == "__main__": + main() From d194d399da6452d8456dbf393b365d7023995268 Mon Sep 17 00:00:00 2001 From: Lucian Grijincu Date: Thu, 29 Aug 2024 23:15:11 +0000 Subject: [PATCH 0256/1018] pytorch: llvm_codegen: prefix JIT generated functions with 8B of data so jitted code can be called from ASAN+UBSAN on LLVM17 (llvm/llvm-project#65253) (#134572) Summary: Similar workaround was already applied elsewhere in pytorch https://github.com/pytorch/pytorch/pull/133623 {D61348865} LLVM17 UBSAN change discussion https://github.com/llvm/llvm-project/issues/104505 Here we also have to associate the data with the function with `setPrefixData(dummyPrefixData)` to prevent this workaround being disabled by the `optimize(*module_);` call which could change layout/remove the unused variable/etc. Differential Revision: D61845799 Pull Request resolved: https://github.com/pytorch/pytorch/pull/134572 Approved by: https://github.com/atalman --- torch/csrc/jit/tensorexpr/llvm_codegen.cpp | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/torch/csrc/jit/tensorexpr/llvm_codegen.cpp b/torch/csrc/jit/tensorexpr/llvm_codegen.cpp index 334ed836ebe739..a9f24139e029f3 100644 --- a/torch/csrc/jit/tensorexpr/llvm_codegen.cpp +++ b/torch/csrc/jit/tensorexpr/llvm_codegen.cpp @@ -626,6 +626,20 @@ void LLVMCodeGenImpl::emitWrapper(const std::vector& params) { module_.get()); #endif + { + // Work around UBSAN crashes which reads 8 byte in front of every function. + // Otherwise, if the function was placed at the beginning of a page, reading + // 8B before the page could trigger a wild-addr-read ASAN failure if the + // page before this function was not mapped. + // - https://reviews.llvm.org/D148665 + // - https://github.com/llvm/llvm-project/issues/65253 + // Place the variable just before the function, + // the optimizer might otherwise disable this workaround. + // https://llvm.org/docs/LangRef.html#prefix-data + wrapper->setPrefixData(llvm::Constant::getNullValue( + llvm::ArrayType::get(llvm::Type::getInt8Ty(getContext()), 8))); + } + auto wrapBB = llvm::BasicBlock::Create(getContext(), "wrapBB", wrapper); irb_.SetInsertPoint(wrapBB); llvm::SmallVector wrappedArgs; From ad3a7113ca490cf67150ab21a40a4e30e78c28d1 Mon Sep 17 00:00:00 2001 From: Shuqiang Zhang Date: Thu, 29 Aug 2024 11:34:04 -0700 Subject: [PATCH 0257/1018] [c10d] release gil lock during eager init (#134779) Summary: We found that if we init the pG in a background thread, it would block the main thread till init is complete. This is because in the pybinding we never release the GIL lock Test Plan: existing CI on eager init Tags: Pull Request resolved: https://github.com/pytorch/pytorch/pull/134779 Approved by: https://github.com/c-p-i-o --- torch/csrc/distributed/c10d/init.cpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/torch/csrc/distributed/c10d/init.cpp b/torch/csrc/distributed/c10d/init.cpp index 62b2916e5f7af8..5bc1b15cfaacb4 100644 --- a/torch/csrc/distributed/c10d/init.cpp +++ b/torch/csrc/distributed/c10d/init.cpp @@ -2532,7 +2532,8 @@ The hook must have the following signature: py::call_guard()) .def( "eager_connect_single_device", - &::c10d::Backend::eagerConnectSingleDevice) + &::c10d::Backend::eagerConnectSingleDevice, + py::call_guard()) .def( "_get_backend_name", &::c10d::Backend::getBackendName, From 6c357146c339f1adca4a1b18c668e5040b382bf7 Mon Sep 17 00:00:00 2001 From: leslie-fang-intel Date: Wed, 28 Aug 2024 17:14:02 -0700 Subject: [PATCH 0258/1018] [Inductor][CPP] Fix Index name error (#134645) **Summary** Fix the comment: https://github.com/pytorch/pytorch/pull/122961#issuecomment-2313930242. For all of the cases we see in the 3 test suits (TorchBench, Timms, Huggingface) we expect: * `_node` is a FX Node with target in ["index_expr", "load", "store"] * `_node.args[1 if _node.target == "index_expr" else 2]` is another FX node with target `get_index` * `_node.args[1 if _node.target == "index_expr" else 2].args[0]` is a str for the name of this index expression It seems not true in some FB internal testcase from the failure log posted in above link. So, add the condition check to work around it. Pull Request resolved: https://github.com/pytorch/pytorch/pull/134645 Approved by: https://github.com/jgong5, https://github.com/masnesral --- torch/_inductor/codegen/cpp.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/torch/_inductor/codegen/cpp.py b/torch/_inductor/codegen/cpp.py index 1791d23da690e5..ab480e066e68c3 100644 --- a/torch/_inductor/codegen/cpp.py +++ b/torch/_inductor/codegen/cpp.py @@ -3372,13 +3372,10 @@ def _is_valid_indices( for _node in sub_block.graph.nodes: if _node.target in ["index_expr", "load", "store"]: # get the index and replace prefix from z to x + arg_idx = 1 if _node.target == "index_expr" else 2 index = sub_block.body.indexing_from_args( (vars, reduction_vars) - )[ - _node.args[ - 1 if _node.target == "index_expr" else 2 - ].args[0] - ] + )[_node.args[arg_idx].args[0]] if _is_valid_indices(itervars, tiling_indices): stride = _try_get_stride( index, itervars, tiling_factor, tiling_indices From af8d07b2f10f20dc76297fc0686cfcee54d0877a Mon Sep 17 00:00:00 2001 From: "Yu, Guangye" Date: Thu, 22 Aug 2024 14:46:55 +0000 Subject: [PATCH 0259/1018] fix windows xpu build issue (#133845) # Motivation If build XPU via oneAPI 2024.2, it will fail because `sycl-preview.lib` exists in windows. And linking the unexpected lib results in `error LNK2019: unresolved external symbol`. # Solution Use explicitly `sycl-preview` in linux build only. # Additional Context For `find_library`, please note that the variable will not be updated if it has been stored. ``` If the library is found the result is stored in the variable and the search will not be repeated unless the variable is cleared. ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/133845 Approved by: https://github.com/min-jean-cho, https://github.com/EikanWang, https://github.com/atalman, https://github.com/malfet --- cmake/Modules/FindSYCLToolkit.cmake | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/cmake/Modules/FindSYCLToolkit.cmake b/cmake/Modules/FindSYCLToolkit.cmake index 5b249e5e444fcf..d96b4c8d45cfef 100644 --- a/cmake/Modules/FindSYCLToolkit.cmake +++ b/cmake/Modules/FindSYCLToolkit.cmake @@ -49,12 +49,14 @@ find_file( ) # Find SYCL library fullname. -find_library( - SYCL_LIBRARY - NAMES sycl-preview - HINTS ${SYCL_LIBRARY_DIR} - NO_DEFAULT_PATH -) +if(LINUX) + find_library( + SYCL_LIBRARY + NAMES sycl-preview + HINTS ${SYCL_LIBRARY_DIR} + NO_DEFAULT_PATH + ) +endif() # On Windows, currently there's no sycl.lib. Only sycl7.lib with version suffix, # where the current version of the SYCL runtime is 7. # Until oneAPI adds support to sycl.lib without the version suffix, From ce7371c5f5bf9479490b040d6dd7fc3dfde4c200 Mon Sep 17 00:00:00 2001 From: Ivan Zaitsev Date: Thu, 29 Aug 2024 23:58:12 +0000 Subject: [PATCH 0260/1018] Correctly detect "Rate limit exceeded" error (#134785) Currently all 403 errors are treated as "Rate limit exceeded": https://github.com/pytorch/pytorch/actions/runs/10622019167/job/29445336924 [Github docs](https://docs.github.com/en/rest/using-the-rest-api/rate-limits-for-the-rest-api?apiVersion=2022-11-28#exceeding-the-rate-limit) claim: > If you exceed your primary rate limit, you will receive a 403 or 429 response, and the x-ratelimit-remaining header will be 0. You should not retry your request until after the time specified by the x-ratelimit-reset header. After this change: https://github.com/pytorch/pytorch/actions/runs/10622365327/job/29446456395 Note, the 403 error in the jobs above is a separate issue, this PR addresses only the logging. Pull Request resolved: https://github.com/pytorch/pytorch/pull/134785 Approved by: https://github.com/clee2000 --- .github/scripts/github_utils.py | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/.github/scripts/github_utils.py b/.github/scripts/github_utils.py index 5acef33903ba52..79f6c4732ba558 100644 --- a/.github/scripts/github_utils.py +++ b/.github/scripts/github_utils.py @@ -46,16 +46,24 @@ def gh_fetch_url_and_headers( with urlopen(Request(url, headers=headers, data=data_, method=method)) as conn: return conn.headers, reader(conn) except HTTPError as err: - if err.code == 403 and all( - key in err.headers for key in ["X-RateLimit-Limit", "X-RateLimit-Used"] + if ( + err.code == 403 + and all( + key in err.headers + for key in ["X-RateLimit-Limit", "X-RateLimit-Remaining"] + ) + and int(err.headers["X-RateLimit-Remaining"]) == 0 ): print( - f"""Rate limit exceeded: + f"""{url} + Rate limit exceeded: Used: {err.headers['X-RateLimit-Used']} Limit: {err.headers['X-RateLimit-Limit']} Remaining: {err.headers['X-RateLimit-Remaining']} Resets at: {err.headers['x-RateLimit-Reset']}""" ) + else: + print(f"Error fetching {url} {err}") raise From 893a4a5b98adbc63a2b7006042cf35bee4781409 Mon Sep 17 00:00:00 2001 From: "Colin L. Rice" Date: Fri, 30 Aug 2024 00:35:23 +0000 Subject: [PATCH 0261/1018] dynamo: Only log if we've disabled eval_frame once. (#134529) This spams logs pretty badly otherwise Pull Request resolved: https://github.com/pytorch/pytorch/pull/134529 Approved by: https://github.com/chuanhaozhuge, https://github.com/oulgen --- torch/_dynamo/eval_frame.py | 2 +- torch/_utils_internal.py | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/torch/_dynamo/eval_frame.py b/torch/_dynamo/eval_frame.py index ed0f89097987f9..c04e4ccb00a97a 100644 --- a/torch/_dynamo/eval_frame.py +++ b/torch/_dynamo/eval_frame.py @@ -104,7 +104,7 @@ def _maybe_set_eval_frame(callback: DynamoCallback): from torch._C._dynamo.eval_frame import set_eval_frame if not justknobs_check("pytorch/compiler:enable_compiler_set_eval_frame"): - log.warning( + torch._dynamo.utils.warn_once( "Dynamo disabled by Justknob: enable_compiler_set_eval_frame, skipping set_eval_frame" ) return callback diff --git a/torch/_utils_internal.py b/torch/_utils_internal.py index 59b579913ae9c3..32b9c7ff9fca74 100644 --- a/torch/_utils_internal.py +++ b/torch/_utils_internal.py @@ -235,7 +235,8 @@ def justknobs_feature( killswitch work by having feature return True to turn off features Requirements: - Don't use this at import time - Simply pass in the existing config + WARNING - Don't use this at import time - Simply pass in the existing config. + If you want to use this at config time, use JustKnobsConfig """ if config_value is not None: return config_value From 34cbf0ae5e21b34710240004620d978db0375e05 Mon Sep 17 00:00:00 2001 From: Chen Haifeng Date: Fri, 30 Aug 2024 01:57:47 +0000 Subject: [PATCH 0262/1018] Support __class__ attr for tuple and list variables (#134099) Fixes #134086 This supports __class__ attribute for TupleVariable and ListVariable. And allows to construct a tuple or list by using __class__ attribute. This patch also fix a bug in NamedTupleVariable which misses a return on calling super var_getattr. Pull Request resolved: https://github.com/pytorch/pytorch/pull/134099 Approved by: https://github.com/anijain2305, https://github.com/jansel --- test/dynamo/test_misc.py | 66 ++++++++++++++++++++++++++++++++ torch/_dynamo/variables/lists.py | 22 ++++++++++- 2 files changed, 87 insertions(+), 1 deletion(-) diff --git a/test/dynamo/test_misc.py b/test/dynamo/test_misc.py index da258c94308837..1bfb15d1c5e81f 100644 --- a/test/dynamo/test_misc.py +++ b/test/dynamo/test_misc.py @@ -11536,6 +11536,72 @@ def fn(x): fn(torch.randn(4)) + def test_tuple_class(self): + cnts = torch._dynamo.testing.CompileCounter() + + def fn(x): + updated_x = [] + for v in x: + updated_x.append(v + 1) + return x.__class__(updated_x) + + opt_fn = torch.compile(fn, backend=cnts, fullgraph=True) + + d1 = torch.zeros(2, 2) + d2 = torch.ones(2, 2) + + r = opt_fn((d1, d2)) + self.assertEqual(r.__class__, tuple) + r1, r2 = r + self.assertEqual(r1, torch.ones(2, 2)) + self.assertEqual(r2, torch.ones(2, 2) + 1) + self.assertEqual(cnts.frame_count, 1) + + def test_list_class(self): + cnts = torch._dynamo.testing.CompileCounter() + + def fn(x): + updated_x = [] + for v in x: + updated_x.append(v + 1) + return x.__class__(updated_x) + + opt_fn = torch.compile(fn, backend=cnts, fullgraph=True) + + d1 = torch.zeros(2, 2) + d2 = torch.ones(2, 2) + + r = opt_fn([d1, d2]) + self.assertEqual(r.__class__, list) + self.assertEqual(len(r), 2) + self.assertEqual(r[0], torch.ones(2, 2)) + self.assertEqual(r[1], torch.ones(2, 2) + 1) + self.assertEqual(cnts.frame_count, 1) + + def test_namedtuple_class(self): + import collections + + cnts = torch._dynamo.testing.CompileCounter() + + def fn(x): + updated_x = [] + for v in x: + updated_x.append(v + 1) + return x.__class__(*updated_x) + + opt_fn = torch.compile(fn, backend=cnts, fullgraph=True) + + d1 = torch.zeros(2, 2) + d2 = torch.ones(2, 2) + point = collections.namedtuple("Point", ["x", "y"]) + p = point(d1, d2) + + r = opt_fn(p) + self.assertEqual(r.__class__, point) + self.assertEqual(r.x, torch.ones(2, 2)) + self.assertEqual(r.y, torch.ones(2, 2) + 1) + self.assertEqual(cnts.frame_count, 1) + class TestTracer(JitTestCase): def test_jit_save(self): diff --git a/torch/_dynamo/variables/lists.py b/torch/_dynamo/variables/lists.py index 83be6fa767392a..1fed456a8f7d62 100644 --- a/torch/_dynamo/variables/lists.py +++ b/torch/_dynamo/variables/lists.py @@ -433,6 +433,16 @@ def call_method( else: return super().call_method(tx, name, args, kwargs) + def var_getattr(self, tx, name): + if name == "__class__": + source = AttrSource(self.source, name) if self.source else None + class_type = self.python_type() + if class_type is list: + return variables.BuiltinVariable(class_type, source=source) + else: + return variables.UserDefinedClassVariable(class_type, source=source) + return super().var_getattr(tx, name) + def call_hasattr(self, tx: "InstructionTranslator", name: str) -> "VariableTracker": if self.python_type() is not list: return super().call_hasattr(tx, name) @@ -525,6 +535,16 @@ def call_method( ) -> "VariableTracker": return super().call_method(tx, name, args, kwargs) + def var_getattr(self, tx, name): + if name == "__class__": + source = AttrSource(self.source, name) if self.source else None + class_type = self.python_type() + if class_type is tuple: + return variables.BuiltinVariable(class_type, source=source) + else: + return variables.UserDefinedClassVariable(class_type, source=source) + return super().var_getattr(tx, name) + def call_hasattr(self, tx: "InstructionTranslator", name: str) -> "VariableTracker": if self.python_type() is not tuple: return super().call_hasattr(tx, name) @@ -730,7 +750,7 @@ def check_and_create_method(): if name not in fields: method = check_and_create_method() if not method: - super().var_getattr(tx, name) + return super().var_getattr(tx, name) return method return self.items[fields.index(name)] From ace2871cfc9d59e76733c472bec8e9665b303d0d Mon Sep 17 00:00:00 2001 From: Xilun Wu <12968408+XilunWu@users.noreply.github.com> Date: Mon, 26 Aug 2024 21:59:11 +0000 Subject: [PATCH 0263/1018] [reland][dtensor][MTPG] make sharding prop lru cache not shared among threads (#134509) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit **Summary** reland of https://github.com/pytorch/pytorch/pull/134294 Fixes #131446 Fixes #126852 Fixes #126868 Fixes #126493 The PR was reverted due to CI red signal in https://github.com/pytorch/pytorch/actions/runs/10537099590/job/29201744658. It seems that the `gaussian_nll_loss` test had been flaky before my original PR #134294 . Therefore this PR also removes the `xfail` mark on this specific test to make CI signal green. See the error message below: ``` 2024-08-24T13:42:01.3228990Z ==================================== RERUNS ==================================== 2024-08-24T13:42:01.3229530Z _ TestDTensorOpsCPU.test_dtensor_op_db_nn_functional_gaussian_nll_loss_cpu_float32 _ 2024-08-24T13:42:01.3229710Z Unexpected success 2024-08-24T13:42:01.3230235Z _ TestDTensorOpsCPU.test_dtensor_op_db_nn_functional_gaussian_nll_loss_cpu_float32 _ 2024-08-24T13:42:01.3230407Z Unexpected success 2024-08-24T13:42:01.3230594Z =================================== FAILURES =================================== 2024-08-24T13:42:01.3231128Z _ TestDTensorOpsCPU.test_dtensor_op_db_nn_functional_gaussian_nll_loss_cpu_float32 _ 2024-08-24T13:42:01.3231296Z Unexpected success ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/134509 Approved by: https://github.com/tianyu-l, https://github.com/wz337 --- test/distributed/_tensor/test_dtensor_ops.py | 2 -- torch/distributed/_tensor/_sharding_prop.py | 20 ++++++++++++++++++-- 2 files changed, 18 insertions(+), 4 deletions(-) diff --git a/test/distributed/_tensor/test_dtensor_ops.py b/test/distributed/_tensor/test_dtensor_ops.py index d3377bd47af97b..532bd3facae554 100644 --- a/test/distributed/_tensor/test_dtensor_ops.py +++ b/test/distributed/_tensor/test_dtensor_ops.py @@ -304,7 +304,6 @@ def wrapped(fn): xfail("nn.functional.elu"), xfail("nn.functional.fractional_max_pool2d"), xfail("nn.functional.fractional_max_pool3d"), - xfail("nn.functional.gaussian_nll_loss"), xfail("nn.functional.glu"), xfail("nn.functional.grid_sample"), xfail("nn.functional.group_norm"), @@ -350,7 +349,6 @@ def wrapped(fn): xfail("nn.functional.pdist"), xfail("nn.functional.pixel_shuffle"), xfail("nn.functional.pixel_unshuffle"), - xfail("nn.functional.poisson_nll_loss"), xfail("nn.functional.prelu"), xfail("nn.functional.relu6"), xfail("nn.functional.rrelu"), diff --git a/torch/distributed/_tensor/_sharding_prop.py b/torch/distributed/_tensor/_sharding_prop.py index aefb4e41c94479..c8c8ddb9817c1c 100644 --- a/torch/distributed/_tensor/_sharding_prop.py +++ b/torch/distributed/_tensor/_sharding_prop.py @@ -1,4 +1,5 @@ # mypy: allow-untyped-defs +import threading from functools import lru_cache from itertools import chain from typing import Callable, cast, Dict, List, Optional, Sequence, Tuple, Union @@ -37,6 +38,17 @@ def _length(obj) -> int: return len(obj) +class LocalLRUCache(threading.local): + def __init__(self, user_function: Callable) -> None: + self.cache = lru_cache(None)(user_function) + + def __call__(self, *args, **kwargs) -> object: + return self.cache(*args, **kwargs) + + def cache_info(self): + return self.cache.cache_info() + + class ShardingPropagator: def __init__(self) -> None: self.op_to_rules: Dict[OpOverload, Callable[[OpSchema], OutputSharding]] = {} @@ -46,7 +58,9 @@ def __init__(self) -> None: ] = {} # op map to save static argnum to decide to reuse sharding prop cache or re-run sharding prop self.op_to_schema_info: Dict[OpOverload, RuntimeSchemaInfo] = {} - self.propagate_op_sharding = lru_cache(None)(self.propagate_op_sharding_non_cached) # type: ignore[method-assign] + self.propagate_op_sharding = LocalLRUCache( + self.propagate_op_sharding_non_cached + ) # op map to save indices of shape (and stride) args which may need to be modified in sharding prop self.op_to_shape_and_stride_idx: Dict[ OpOverload, Union[int, Tuple[int, int]] @@ -183,7 +197,9 @@ def propagate(self, op_info: OpInfo) -> None: if op_info.schema.has_symints: output_sharding = self.propagate_op_sharding_non_cached(op_info.schema) else: - output_sharding = self.propagate_op_sharding(op_info.schema) + output_sharding = cast( + OutputSharding, self.propagate_op_sharding(op_info.schema) + ) op_info.output_sharding = output_sharding def propagate_op_sharding_non_cached(self, op_schema: OpSchema) -> OutputSharding: From 56a667502d11a2cd93e4668e9c517f7e4fc6ad00 Mon Sep 17 00:00:00 2001 From: Laith Sakka Date: Tue, 27 Aug 2024 23:34:24 -0700 Subject: [PATCH 0264/1018] add basic nn modules diff time benchmarks (#134658) benchmarks several shapes of basic nn modules. in both eager and inductor ``` collecting compile time instruction count for basic_modules_ListOfLinears_inductor compile time instruction count for iteration 0 is 48602516013 compile time instruction count for iteration 1 is 20424350269 compile time instruction count for iteration 2 is 20440350455 compile time instruction count for iteration 3 is 20419269999 compile time instruction count for iteration 4 is 20430782200 compile time instruction count for iteration 5 is 20455049622 compile time instruction count for iteration 6 is 20157290712 compile time instruction count for iteration 7 is 20455324001 compile time instruction count for iteration 8 is 20450158317 compile time instruction count for iteration 9 is 20492987748 collecting compile time instruction count for basic_modules_ListOfLinears_eager compile time instruction count for iteration 0 is 961328334 compile time instruction count for iteration 1 is 958887896 compile time instruction count for iteration 2 is 958792214 compile time instruction count for iteration 3 is 958375977 compile time instruction count for iteration 4 is 958568525 compile time instruction count for iteration 5 is 958152305 compile time instruction count for iteration 6 is 959322800 compile time instruction count for iteration 7 is 958332703 compile time instruction count for iteration 8 is 958092100 compile time instruction count for iteration 9 is 958095277 collecting compile time instruction count for basic_modules_ModuleForwardHasGraphBreak_inductor compile time instruction count for iteration 0 is 3572145793 compile time instruction count for iteration 1 is 3503323973 compile time instruction count for iteration 2 is 3501962432 compile time instruction count for iteration 3 is 3501746084 compile time instruction count for iteration 4 is 3500687361 compile time instruction count for iteration 5 is 3822254676 compile time instruction count for iteration 6 is 3498356846 compile time instruction count for iteration 7 is 3499019157 compile time instruction count for iteration 8 is 3500780314 compile time instruction count for iteration 9 is 3500257458 collecting compile time instruction count for basic_modules_ModuleForwardHasGraphBreak_eager compile time instruction count for iteration 0 is 1844838754 compile time instruction count for iteration 1 is 1843476862 compile time instruction count for iteration 2 is 1844761450 compile time instruction count for iteration 3 is 1845371742 compile time instruction count for iteration 4 is 1845159665 compile time instruction count for iteration 5 is 1845035802 compile time instruction count for iteration 6 is 1844895007 compile time instruction count for iteration 7 is 1844697922 compile time instruction count for iteration 8 is 1844780885 compile time instruction count for iteration 9 is 1844493990 collecting compile time instruction count for basic_modules_SequentialWithDuplicatedModule_inductor compile time instruction count for iteration 0 is 1597839479 compile time instruction count for iteration 1 is 1348225351 compile time instruction count for iteration 2 is 1347340818 compile time instruction count for iteration 3 is 1348170800 compile time instruction count for iteration 4 is 1348637747 compile time instruction count for iteration 5 is 1678366444 compile time instruction count for iteration 6 is 1348412420 compile time instruction count for iteration 7 is 1348461578 compile time instruction count for iteration 8 is 1347420149 compile time instruction count for iteration 9 is 1349748195 collecting compile time instruction count for basic_modules_SequentialWithDuplicatedModule_eager compile time instruction count for iteration 0 is 137721777 compile time instruction count for iteration 1 is 139065517 compile time instruction count for iteration 2 is 137130552 compile time instruction count for iteration 3 is 137506030 compile time instruction count for iteration 4 is 137089838 compile time instruction count for iteration 5 is 137477395 compile time instruction count for iteration 6 is 138550452 compile time instruction count for iteration 7 is 137568409 compile time instruction count for iteration 8 is 136968468 compile time instruction count for iteration 9 is 137481664 collecting compile time instruction count for basic_modules_ModuleComparison_inductor compile time instruction count for iteration 0 is 917209684 compile time instruction count for iteration 1 is 899154426 compile time instruction count for iteration 2 is 898145079 compile time instruction count for iteration 3 is 899817018 compile time instruction count for iteration 4 is 899184687 compile time instruction count for iteration 5 is 898172885 compile time instruction count for iteration 6 is 899958951 compile time instruction count for iteration 7 is 899348186 compile time instruction count for iteration 8 is 897745404 compile time instruction count for iteration 9 is 899581123 collecting compile time instruction count for basic_modules_ModuleComparison_eager compile time instruction count for iteration 0 is 113165302 compile time instruction count for iteration 1 is 112724376 compile time instruction count for iteration 2 is 112774611 compile time instruction count for iteration 3 is 114465211 compile time instruction count for iteration 4 is 112689572 compile time instruction count for iteration 5 is 112726465 compile time instruction count for iteration 6 is 112853691 compile time instruction count for iteration 7 is 112295238 compile time instruction count for iteration 8 is 114022136 compile time instruction count for iteration 9 is 112664932 ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/134658 Approved by: https://github.com/anijain2305 ghstack dependencies: #133834, #134635, #134649, #134652 --- .../benchmarks/basic_modules_benchmarks.py | 164 ++++++++++++++++++ 1 file changed, 164 insertions(+) create mode 100644 benchmarks/dynamo/pr_time_benchmarks/benchmarks/basic_modules_benchmarks.py diff --git a/benchmarks/dynamo/pr_time_benchmarks/benchmarks/basic_modules_benchmarks.py b/benchmarks/dynamo/pr_time_benchmarks/benchmarks/basic_modules_benchmarks.py new file mode 100644 index 00000000000000..be545cdedd185e --- /dev/null +++ b/benchmarks/dynamo/pr_time_benchmarks/benchmarks/basic_modules_benchmarks.py @@ -0,0 +1,164 @@ +import sys + +from benchmark_base import BenchmarkBase + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch._inductor.utils import fresh_inductor_cache + + +class ListOfLinears(nn.Module): + def __init__(self): + super().__init__() + self.linears = nn.ModuleList([nn.Linear(10, 10) for i in range(20)]) + + def forward(self, x): + # ModuleList can act as an iterable, or be indexed using ints + for i, l in enumerate(self.linears): + x = self.linears[i // 2](x) + l(x) + return x + + +class BasicModule(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.linear1 = torch.nn.Linear(10, 10) + self.scale = torch.randn(1, 10) + + def forward(self, x): + return F.relu(self.linear1(x)) * self.scale + + +class ModuleForwardHasGraphBreak(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.layer1 = BasicModule() + self.layer2 = BasicModule() + self.layer3 = torch.nn.Sequential(BasicModule(), BasicModule()) + self.layer4 = torch.nn.ModuleList( + [ + torch.nn.Linear(10, 10), + torch.nn.ReLU(), + torch.nn.Linear(10, 10), + torch.nn.ReLU(), + ] + ) + self.layer5 = torch.nn.ModuleDict( + { + "0": torch.nn.Linear(10, 10), + } + ) + self.scale = torch.randn(1, 10) + + def forward(self, x): + """ + This is used to test if the results of functions like `named_parameters` + can be reconstructed correctly after graph break. + + https://github.com/pytorch/torchdynamo/issues/1931 + """ + x = self.layer1(x) + params1 = dict(self.named_parameters()) + params2 = list(self.parameters()) + buffers1 = dict(self.named_buffers()) + buffers2 = list(self.buffers()) + modules1 = dict(self.named_modules()) + modules2 = list(self.modules()) + torch._dynamo.graph_break() + y = modules2 + y = modules1 + y = buffers2 + y = buffers1 + y = params2 + y = params1 + x = ( + self.layer2(x) + + y["layer3.1.linear1.weight"] + + y["layer4.2.weight"] + + y["layer5.0.weight"] + ) + return x * self.scale + + +class SequentialWithDuplicatedModule(torch.nn.Module): + # Sequential module(self.layer) contains three duplicated ReLU module. + def __init__(self) -> None: + super().__init__() + self.relu = torch.nn.ReLU() + self.layer = torch.nn.Sequential( + torch.nn.Linear(10, 20), + self.relu, + torch.nn.Linear(20, 20), + self.relu, + torch.nn.Linear(20, 10), + self.relu, + ) + + def forward(self, x): + return self.layer(x) + + +class ModuleComparison(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.layer0 = torch.nn.Linear(10, 10) + self.layer1 = torch.nn.Linear(10, 10) + self.layer2 = torch.nn.Linear(10, 10) + + @property + def encoder_layers(self): + return [self.layer0, self.layer1, self.layer2] + + def forward(self, x): + for layer in self.encoder_layers: + output = layer(x) + if layer is None or layer == self.layer0: + output = F.relu6(output) + else: + output = F.relu(output) + return output + + +class Benchmark(BenchmarkBase): + def __init__(self, ModuleClass, backend): + self.ModuleClass = ModuleClass + self.backend = backend + self._name = ModuleClass.__name__ + + def name(self): + return f"basic_modules_{self._name}_{self.backend}" + + def _prepare_once(self): + self.m = self.ModuleClass() + self.input = torch.ones(10) + + def _prepare(self): + torch._dynamo.reset() + + def _work(self): + with fresh_inductor_cache(): + opt_m = torch.compile(backend=self.backend)(self.m) + opt_m(self.input) + + +def main(): + result_path = sys.argv[1] + benchmarks = [ + Benchmark(ListOfLinears, "inductor"), + Benchmark(ListOfLinears, "eager"), + Benchmark(ModuleForwardHasGraphBreak, "inductor"), + Benchmark(ModuleForwardHasGraphBreak, "eager"), + Benchmark(SequentialWithDuplicatedModule, "inductor"), + Benchmark(SequentialWithDuplicatedModule, "eager"), + Benchmark(ModuleComparison, "inductor"), + Benchmark(ModuleComparison, "eager"), + ] + for b in benchmarks: + b.enable_compile_time_instruction_count().collect_all().append_results( + result_path + ) + + +if __name__ == "__main__": + main() From ef2af0cf450055d49840d8268a436ea12736f147 Mon Sep 17 00:00:00 2001 From: Xu Han Date: Fri, 30 Aug 2024 03:16:03 +0000 Subject: [PATCH 0265/1018] [inductor] Install `tlparse` for test\dynamo\test_structured_trace.py UTs. (#134806) Install tlparse for test\dynamo\test_structured_trace.py UTs. Pull Request resolved: https://github.com/pytorch/pytorch/pull/134806 Approved by: https://github.com/ezyang --- .ci/pytorch/win-test.sh | 3 +++ 1 file changed, 3 insertions(+) diff --git a/.ci/pytorch/win-test.sh b/.ci/pytorch/win-test.sh index 39fcb65132f676..69f0c88aa30744 100755 --- a/.ci/pytorch/win-test.sh +++ b/.ci/pytorch/win-test.sh @@ -40,6 +40,9 @@ python -m pip install pytest-rerunfailures==10.3 pytest-cpp==2.3.0 tensorboard== # Install Z3 optional dependency for Windows builds. python -m pip install z3-solver==4.12.2.0 +# Install tlparse for test\dynamo\test_structured_trace.py UTs. +python -m pip install tlparse==0.3.25 + run_tests() { # Run nvidia-smi if available for path in '/c/Program Files/NVIDIA Corporation/NVSMI/nvidia-smi.exe' /c/Windows/System32/nvidia-smi.exe; do From 573e63c2f83e704178610f9d5ffd65550fc62f13 Mon Sep 17 00:00:00 2001 From: Jack Taylor <108682042+jataylo@users.noreply.github.com> Date: Fri, 30 Aug 2024 03:38:35 +0000 Subject: [PATCH 0266/1018] [ROCm] remove triton-rocm commit pin and merge pins with triton.txt (#133438) Pull Request resolved: https://github.com/pytorch/pytorch/pull/133438 Approved by: https://github.com/jithunnair-amd, https://github.com/malfet Co-authored-by: Jithun Nair <37884920+jithunnair-amd@users.noreply.github.com> --- .ci/docker/centos-rocm/Dockerfile | 4 ++-- .ci/docker/ci_commit_pins/triton-rocm.txt | 1 - .ci/docker/ci_commit_pins/triton.txt | 2 +- .ci/docker/common/install_triton.sh | 5 +---- .ci/docker/ubuntu-rocm/Dockerfile | 4 ++-- .circleci/scripts/binary_populate_env.sh | 2 +- .github/scripts/build_triton_wheel.py | 26 +---------------------- .github/workflows/build-triton-wheel.yml | 2 -- CODEOWNERS | 1 - 9 files changed, 8 insertions(+), 39 deletions(-) delete mode 100644 .ci/docker/ci_commit_pins/triton-rocm.txt diff --git a/.ci/docker/centos-rocm/Dockerfile b/.ci/docker/centos-rocm/Dockerfile index bfac9ddd859084..30ce1406e3f81c 100644 --- a/.ci/docker/centos-rocm/Dockerfile +++ b/.ci/docker/centos-rocm/Dockerfile @@ -108,10 +108,10 @@ ENV CMAKE_C_COMPILER cc ENV CMAKE_CXX_COMPILER c++ COPY ./common/install_triton.sh install_triton.sh COPY ./common/common_utils.sh common_utils.sh -COPY ci_commit_pins/triton-rocm.txt triton-rocm.txt +COPY ci_commit_pins/triton.txt triton.txt COPY triton_version.txt triton_version.txt RUN if [ -n "${TRITON}" ]; then bash ./install_triton.sh; fi -RUN rm install_triton.sh common_utils.sh triton-rocm.txt triton_version.txt +RUN rm install_triton.sh common_utils.sh triton.txt triton_version.txt # Install AOTriton (Early fail) COPY ./aotriton_version.txt aotriton_version.txt diff --git a/.ci/docker/ci_commit_pins/triton-rocm.txt b/.ci/docker/ci_commit_pins/triton-rocm.txt deleted file mode 100644 index 0cb336acccb5b1..00000000000000 --- a/.ci/docker/ci_commit_pins/triton-rocm.txt +++ /dev/null @@ -1 +0,0 @@ -21eae954efa5bf584da70324b640288c3ee7aede diff --git a/.ci/docker/ci_commit_pins/triton.txt b/.ci/docker/ci_commit_pins/triton.txt index 41c8d1602b6bef..378d625784fbc5 100644 --- a/.ci/docker/ci_commit_pins/triton.txt +++ b/.ci/docker/ci_commit_pins/triton.txt @@ -1 +1 @@ -dedb7bdf339a3546896d4820366ca562c586bfa0 +757b6a61e7df814ba806f498f8bb3160f84b120c diff --git a/.ci/docker/common/install_triton.sh b/.ci/docker/common/install_triton.sh index d4a0dc80ee177e..6e5fb8839c686e 100755 --- a/.ci/docker/common/install_triton.sh +++ b/.ci/docker/common/install_triton.sh @@ -12,10 +12,7 @@ conda_reinstall() { as_jenkins conda install -q -n py_$ANACONDA_PYTHON_VERSION -y --force-reinstall $* } -if [ -n "${ROCM_VERSION}" ]; then - TRITON_REPO="https://github.com/openai/triton" - TRITON_TEXT_FILE="triton-rocm" -elif [ -n "${XPU_VERSION}" ]; then +if [ -n "${XPU_VERSION}" ]; then TRITON_REPO="https://github.com/intel/intel-xpu-backend-for-triton" TRITON_TEXT_FILE="triton-xpu" else diff --git a/.ci/docker/ubuntu-rocm/Dockerfile b/.ci/docker/ubuntu-rocm/Dockerfile index ee9ede8ba611b6..07e25f533a71f9 100644 --- a/.ci/docker/ubuntu-rocm/Dockerfile +++ b/.ci/docker/ubuntu-rocm/Dockerfile @@ -100,10 +100,10 @@ ARG TRITON # try to reach out to S3, which docker build runners don't have access COPY ./common/install_triton.sh install_triton.sh COPY ./common/common_utils.sh common_utils.sh -COPY ci_commit_pins/triton-rocm.txt triton-rocm.txt +COPY ci_commit_pins/triton.txt triton.txt COPY triton_version.txt triton_version.txt RUN if [ -n "${TRITON}" ]; then bash ./install_triton.sh; fi -RUN rm install_triton.sh common_utils.sh triton-rocm.txt triton_version.txt +RUN rm install_triton.sh common_utils.sh triton.txt triton_version.txt # Install AOTriton COPY ./aotriton_version.txt aotriton_version.txt diff --git a/.circleci/scripts/binary_populate_env.sh b/.circleci/scripts/binary_populate_env.sh index e918635922a45e..106d0917ca68c5 100755 --- a/.circleci/scripts/binary_populate_env.sh +++ b/.circleci/scripts/binary_populate_env.sh @@ -90,7 +90,7 @@ fi if [[ "$PACKAGE_TYPE" =~ .*wheel.* && -n "$PYTORCH_BUILD_VERSION" && "$PYTORCH_BUILD_VERSION" =~ .*rocm.* && $(uname) == "Linux" ]]; then TRITON_REQUIREMENT="pytorch-triton-rocm==${TRITON_VERSION}; ${TRITON_CONSTRAINT}" if [[ -n "$PYTORCH_BUILD_VERSION" && "$PYTORCH_BUILD_VERSION" =~ .*dev.* ]]; then - TRITON_SHORTHASH=$(cut -c1-10 $PYTORCH_ROOT/.ci/docker/ci_commit_pins/triton-rocm.txt) + TRITON_SHORTHASH=$(cut -c1-10 $PYTORCH_ROOT/.ci/docker/ci_commit_pins/triton.txt) TRITON_REQUIREMENT="pytorch-triton-rocm==${TRITON_VERSION}+${TRITON_SHORTHASH}; ${TRITON_CONSTRAINT}" fi if [[ -z "${PYTORCH_EXTRA_INSTALL_REQUIREMENTS:-}" ]]; then diff --git a/.github/scripts/build_triton_wheel.py b/.github/scripts/build_triton_wheel.py index 7ee2cb4b8e61b5..7e70c49ff8b1bb 100644 --- a/.github/scripts/build_triton_wheel.py +++ b/.github/scripts/build_triton_wheel.py @@ -15,9 +15,7 @@ def read_triton_pin(device: str = "cuda") -> str: triton_file = "triton.txt" - if device == "rocm": - triton_file = "triton-rocm.txt" - elif device == "xpu": + if device == "xpu": triton_file = "triton-xpu.txt" with open(REPO_DIR / ".ci" / "docker" / "ci_commit_pins" / triton_file) as f: return f.read().strip() @@ -50,25 +48,6 @@ def patch_init_py( f.write(orig) -# TODO: remove patch_setup_py() once we have a proper fix for https://github.com/triton-lang/triton/issues/4527 -def patch_setup_py(path: Path) -> None: - with open(path) as f: - orig = f.read() - try: - orig = check_and_replace( - orig, - "https://tritonlang.blob.core.windows.net/llvm-builds/", - "https://oaitriton.blob.core.windows.net/public/llvm-builds/", - ) - with open(path, "w") as f: - f.write(orig) - except RuntimeError as e: - print( - f"Applying patch_setup_py() for llvm-build package failed: {e}.", - "If you are trying to build a newer version of Triton, you can ignore this.", - ) - - def build_triton( *, version: str, @@ -110,9 +89,6 @@ def build_triton( else: check_call(["git", "checkout", commit_hash], cwd=triton_basedir) - # TODO: remove this and patch_setup_py() once we have a proper fix for https://github.com/triton-lang/triton/issues/4527 - patch_setup_py(triton_pythondir / "setup.py") - if build_conda: with open(triton_basedir / "meta.yaml", "w") as meta: print( diff --git a/.github/workflows/build-triton-wheel.yml b/.github/workflows/build-triton-wheel.yml index 01f06cdd286cf4..1aa777c1167c4f 100644 --- a/.github/workflows/build-triton-wheel.yml +++ b/.github/workflows/build-triton-wheel.yml @@ -13,7 +13,6 @@ on: - .github/scripts/build_triton_wheel.py - .github/ci_commit_pins/triton.txt - .ci/docker/ci_commit_pins/triton.txt - - .ci/docker/ci_commit_pins/triton-rocm.txt - .ci/docker/ci_commit_pins/triton-xpu.txt pull_request: paths: @@ -21,7 +20,6 @@ on: - .github/scripts/build_triton_wheel.py - .github/ci_commit_pins/triton.txt - .ci/docker/ci_commit_pins/triton.txt - - .ci/docker/ci_commit_pins/triton-rocm.txt - .ci/docker/ci_commit_pins/triton-xpu.txt concurrency: diff --git a/CODEOWNERS b/CODEOWNERS index 7b9db26104a9d5..bafce8f6f53521 100644 --- a/CODEOWNERS +++ b/CODEOWNERS @@ -57,7 +57,6 @@ nn/qat/ @jerryzh168 # Docker /.ci/docker/ @jeffdaily /.ci/docker/ci_commit_pins/triton.txt @desertfire @Chillee @eellison @shunting314 @bertmaher @jeffdaily @jataylo @jithunnair-amd @pruthvistony -/.ci/docker/ci_commit_pins/triton-rocm.txt @jeffdaily @jataylo @jithunnair-amd @pruthvistony /.ci/docker/ci_commit_pins/triton-xpu.txt @EikanWang @gujinghui # Github Actions From 278974d242a55b7d5fafb59381a0153905610f1a Mon Sep 17 00:00:00 2001 From: PyTorch UpdateBot Date: Fri, 30 Aug 2024 04:11:49 +0000 Subject: [PATCH 0267/1018] [executorch hash update] update the pinned executorch hash (#134736) This PR is auto-generated nightly by [this action](https://github.com/pytorch/pytorch/blob/main/.github/workflows/nightly.yml). Update the pinned executorch hash. Pull Request resolved: https://github.com/pytorch/pytorch/pull/134736 Approved by: https://github.com/pytorchbot --- .ci/docker/ci_commit_pins/executorch.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.ci/docker/ci_commit_pins/executorch.txt b/.ci/docker/ci_commit_pins/executorch.txt index 3c7d077a795db0..ddcee95fd3ea17 100644 --- a/.ci/docker/ci_commit_pins/executorch.txt +++ b/.ci/docker/ci_commit_pins/executorch.txt @@ -1 +1 @@ -69472e5c43481324ad923ceb29392ab72830acee +7608ab8fa7c37f334fdaa9a946730fdbb1c97d34 From 190111cfd288fb84f7acea116acf14c55552671b Mon Sep 17 00:00:00 2001 From: Animesh Jain Date: Thu, 29 Aug 2024 15:23:11 -0700 Subject: [PATCH 0268/1018] [justknobs] Override __bool__ method (#134799) Pull Request resolved: https://github.com/pytorch/pytorch/pull/134799 Approved by: https://github.com/ezyang --- test/test_utils_internal.py | 5 +++++ torch/_utils_internal.py | 3 +++ 2 files changed, 8 insertions(+) diff --git a/test/test_utils_internal.py b/test/test_utils_internal.py index 155afe5f9f6cd2..9d0b4d4d57d342 100644 --- a/test/test_utils_internal.py +++ b/test/test_utils_internal.py @@ -39,6 +39,11 @@ def test_justknob_config(self): self.assertTrue(a.get()) a.set(False) self.assertTrue(a.get()) + with self.subTest("Checks __bool__"): + a = JustKnobsConfig(name="fake_name", default=False) + if a: + raise RuntimeError("Should not be true") + self.assertFalse(a) def test_justknob_feature(self): with self.subTest("OSS is True"): diff --git a/torch/_utils_internal.py b/torch/_utils_internal.py index 32b9c7ff9fca74..877856243545a0 100644 --- a/torch/_utils_internal.py +++ b/torch/_utils_internal.py @@ -202,6 +202,9 @@ def __str__(self): v = bool(self) return f"JustknobsConfig(name={self.name}, env_name={self.env_name}, default={self.default} - evals_to={v})" + def __bool__(self): + return self.get() + def justknobs_feature( name: Optional[str], config_value=None, env_name=None, default: bool = True From 39276464fc442bd80b52c258a7d01ac92bb1cb62 Mon Sep 17 00:00:00 2001 From: Ankur Neog Date: Fri, 30 Aug 2024 05:02:45 +0000 Subject: [PATCH 0269/1018] Generalize devices specific UTs for dynamo (#130714) ## Motivation This is follow up to PR:https://github.com/pytorch/pytorch/pull/126970, adding facility to run content for Intel Gaudi devices. We intend to extend similar generalization for the rest of the content in test/dynamo which is currently being written to work specifically for cuda devices. Other devices can add onto it if support is available. ## Changes carve out bert related content to another class use instantiate_device_type utility to instantiate this class for devices which support the functionality Pull Request resolved: https://github.com/pytorch/pytorch/pull/130714 Approved by: https://github.com/anijain2305 --- test/dynamo/test_model_output.py | 58 ++++++++++++++++++-------------- 1 file changed, 33 insertions(+), 25 deletions(-) diff --git a/test/dynamo/test_model_output.py b/test/dynamo/test_model_output.py index b6d6fd050670c7..4dfa93bdc14911 100644 --- a/test/dynamo/test_model_output.py +++ b/test/dynamo/test_model_output.py @@ -6,6 +6,8 @@ import torch._dynamo.test_case import torch._dynamo.testing from torch._dynamo.testing import same +from torch.testing._internal.common_device_type import instantiate_device_type_tests +from torch.testing._internal.common_utils import TEST_HPU, TestCase try: @@ -238,9 +240,34 @@ def fn(inp): self.assertEqual(fn(inp).attentions, opt_fn(inp).attentions) @maybe_skip - def test_HF_bert_model_output(self): - device = "cuda" if torch.cuda.is_available() else "cpu" + def test_none(self): + class Model(torch.nn.Module): + def forward(self, x): + x = x + 1 + return CausalLMOutputWithPast(loss=None, logits=x)[0] + + model = Model() + opt_model = torch.compile(model, backend="eager", fullgraph=True) + x = torch.randn(1, 1, 1, 1) + + self.assertTrue(same(model(x), opt_model(x))) + + @maybe_skip + def test_reconstruction(self): + class Model(torch.nn.Module): + def forward(self, x): + x = x + 1 + return CausalLMOutputWithPast(loss=x, logits=None) + + model = Model() + x = torch.randn(1, 1, 1, 1) + eo = torch._dynamo.export(Model(), aten_graph=True)(x) + self.assertTrue(same(model(x), eo.graph_module(x))) + +class TestModelOutputBert(TestCase): + @maybe_skip + def test_HF_bert_model_output(self, device): class BertPooler(torch.nn.Module): def __init__(self) -> None: super().__init__() @@ -316,31 +343,12 @@ def forward( torch.allclose(orig_result.pooler_output, compiled_result.pooler_output) ) - @maybe_skip - def test_none(self): - class Model(torch.nn.Module): - def forward(self, x): - x = x + 1 - return CausalLMOutputWithPast(loss=None, logits=x)[0] - - model = Model() - opt_model = torch.compile(model, backend="eager", fullgraph=True) - x = torch.randn(1, 1, 1, 1) - - self.assertTrue(same(model(x), opt_model(x))) - - @maybe_skip - def test_reconstruction(self): - class Model(torch.nn.Module): - def forward(self, x): - x = x + 1 - return CausalLMOutputWithPast(loss=x, logits=None) - model = Model() - x = torch.randn(1, 1, 1, 1) - eo = torch._dynamo.export(Model(), aten_graph=True)(x) - self.assertTrue(same(model(x), eo.graph_module(x))) +devices = ["cpu", "cuda"] +if TEST_HPU: + devices.append("hpu") +instantiate_device_type_tests(TestModelOutputBert, globals(), only_for=devices) if __name__ == "__main__": from torch._dynamo.test_case import run_tests From c5c629f1506c67f69730cbe2b84f960c9ce40902 Mon Sep 17 00:00:00 2001 From: Jason Ansel Date: Thu, 29 Aug 2024 20:21:57 -0700 Subject: [PATCH 0270/1018] [dynamo] Improve minifier error message when fp64 not supported (#134737) Pull Request resolved: https://github.com/pytorch/pytorch/pull/134737 Approved by: https://github.com/anijain2305 --- torch/_dynamo/debug_utils.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/torch/_dynamo/debug_utils.py b/torch/_dynamo/debug_utils.py index cfec563729a480..30fe8346079d47 100644 --- a/torch/_dynamo/debug_utils.py +++ b/torch/_dynamo/debug_utils.py @@ -361,7 +361,9 @@ def same_two_models( fp64_ref = run_fwd_maybe_bwd(fp64_model, fp64_examples, only_fwd) except Exception: if require_fp64: - raise RuntimeError("Could not generate fp64 outputs") # noqa: B904 + raise RuntimeError( # noqa: B904 + "Could not generate fp64 outputs, workaround with torch._dynamo.config.same_two_models_use_fp64 = False" + ) log.warning("Could not generate fp64 outputs") try: From 5cf09ba48b11c34f228bb8a3e87e4d6e5164baad Mon Sep 17 00:00:00 2001 From: Animesh Jain Date: Thu, 29 Aug 2024 20:37:32 -0700 Subject: [PATCH 0271/1018] [dynamo] mark_static_nn_module (#134713) Fixes issue seen in https://github.com/pytorch/pytorch/issues/132872#issuecomment-2314574656 With this API, we can mark the offending module as static in detectron2. Today's world - Consider user defined nn module int attributes automatic dynamic. Use the API in this PR to make them static if you want. Alternative work - Consider all int attributes of any user defined nn module class static. And then introduce an API - `torch._dynamo.mark_nn_module_attribute_dynamic`. The default being static is worrying if users have `counter` in their model which is updated in each forward invocation. Pull Request resolved: https://github.com/pytorch/pytorch/pull/134713 Approved by: https://github.com/jansel ghstack dependencies: #134653 --- test/dynamo/test_decorators.py | 27 ++++++++++++++++++++++++++ torch/_dynamo/decorators.py | 31 +++++++++++++++++++++++++++--- torch/_dynamo/variables/builder.py | 13 ++++++++++++- 3 files changed, 67 insertions(+), 4 deletions(-) diff --git a/test/dynamo/test_decorators.py b/test/dynamo/test_decorators.py index b915e0633f7c11..0472702fadca61 100644 --- a/test/dynamo/test_decorators.py +++ b/test/dynamo/test_decorators.py @@ -597,6 +597,33 @@ def fn(x, y): self.assertEqual(fn(x, y), torch.compile(fn)(x, y)) + @torch._dynamo.config.patch("inline_inbuilt_nn_modules", True) + def test_mark_static_nn_module(self): + @torch._dynamo.mark_static + class Mock(torch.nn.Module): + def __init__(self, c): + super().__init__() + self.c = c + + def forward(self, x): + return x * self.c + + cnts = torch._dynamo.testing.CompileCounter() + mod1 = Mock(10) + mod2 = Mock(20) + mod3 = Mock(30) + opt_mod1 = torch.compile(mod1, backend=cnts, fullgraph=True) + opt_mod2 = torch.compile(mod2, backend=cnts, fullgraph=True) + opt_mod3 = torch.compile(mod3, backend=cnts, fullgraph=True) + + x = torch.randn(4, 4) + opt_mod1(x) + opt_mod2(x) + opt_mod3(x) + + # Must be 3 compilations. If not marked static there would be 2, because self.c would be converted to symints. + self.assertEqual(cnts.frame_count, 3) + if __name__ == "__main__": from torch._dynamo.test_case import run_tests diff --git a/torch/_dynamo/decorators.py b/torch/_dynamo/decorators.py index 4338e3047b8902..4e97422fa605e9 100644 --- a/torch/_dynamo/decorators.py +++ b/torch/_dynamo/decorators.py @@ -467,8 +467,10 @@ def maybe_mark_dynamic(t, index): def mark_static(t, index=None): """ - Mark a tensor as having a static dim. + Mark a tensor as having a static dim or mark a nn module class as static. + For tensors + =========== This will prevent us from attempting to compile it dynamically when dynamic=True; this can improve trace-time performance. @@ -476,6 +478,20 @@ def mark_static(t, index=None): Unlike mark_dynamic, this can be done inside a graph, in which case it induces specialization on the tensor. + + For nn.Module classes + ===================== + For static nn.Module classes, TorchDynamo assumes that the module instance + attributes will not be modified after compilation. This will ensure that + TorchDynamo keeps integer attributes CONSTANT and not symints. + + From TorchDynamo implementation side, the instances of static-marked + nn.Module class will be converted to UnspecializedBuiltinNNModuleVariable, + which have the same properties. + + Note that we still have to guard on the attributes, because different + instances of the nn.Module can have different values of the attributes. The + key point here is that the attributes are static. """ if is_compiling(): if index is None: @@ -490,11 +506,20 @@ def mark_static(t, index=None): # TODO: Make this configurable via a supported public API _apply_func_to_inner_tensors_of_same_dim(mark_static, t, index) + if not isinstance(t, torch.Tensor) and issubclass(t, torch.nn.Module): + t._dynamo_marked_static = True + return t + + if not isinstance(t, torch.Tensor): + raise TypeError( + f"mark_static expects a tensor/nn.Module class but recieved {type(t)}" + ) + if isinstance(index, int): if not hasattr(t, "_dynamo_static_indices"): - t._dynamo_static_indices = set() + t._dynamo_static_indices = set() # type: ignore[attr-defined] # TODO(voz): Should we bounds check? - t._dynamo_static_indices.add(index) + t._dynamo_static_indices.add(index) # type: ignore[attr-defined] elif index is None: for i in range(t.dim()): mark_static(t, i) diff --git a/torch/_dynamo/variables/builder.py b/torch/_dynamo/variables/builder.py index e0d355a7fb2b34..bdd84a3dcaebfa 100644 --- a/torch/_dynamo/variables/builder.py +++ b/torch/_dynamo/variables/builder.py @@ -1355,7 +1355,9 @@ def wrap_module(self, value: torch.nn.Module): # this will get cleaned up once compile ends self.tx.output.nn_modules[self.name] = value - if value.__module__.startswith(("torch.nn.", "torch.ao.")): + if value.__module__.startswith(("torch.nn.", "torch.ao.")) or getattr( + value.__class__, "_dynamo_marked_static", False + ): result = UnspecializedBuiltinNNModuleVariable(value, source=self.source) else: result = UnspecializedNNModuleVariable(value, source=self.source) @@ -1719,6 +1721,15 @@ def update_frame_state(value): value, frame_state_entry.scalar, ) + if self.source.guard_source().is_unspecialized_nn_module(): + log.info( + "%s", + ( + f"{name} is converted to a symbolic integer. It is an attribute of a " + "user defined nn module class. If you wish to keep it static, you can " + "mark the nn module class as `torch._dynamo.mark_static`." + ), + ) frame_state_entry.scalar = None self.tx.output.frame_state[name] = frame_state_entry From e7b18160fb32767bb7203e7e3dcf321095107d60 Mon Sep 17 00:00:00 2001 From: wz337 Date: Tue, 27 Aug 2024 13:35:02 -0700 Subject: [PATCH 0272/1018] [DeviceMesh][Test] Add a unit test for get_local_rank for flattened mesh (#134603) Pull Request resolved: https://github.com/pytorch/pytorch/pull/134603 Approved by: https://github.com/fduwjj ghstack dependencies: #133838, #133839, #134048 --- test/distributed/test_device_mesh.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/test/distributed/test_device_mesh.py b/test/distributed/test_device_mesh.py index 08ee59dbff0a53..7107b104416166 100644 --- a/test/distributed/test_device_mesh.py +++ b/test/distributed/test_device_mesh.py @@ -136,6 +136,10 @@ def test_get_local_rank(self): self.assertEqual(dp_mesh.get_local_rank(), mesh_2d.get_local_rank("dp")) self.assertEqual(tp_mesh.get_local_rank(), mesh_2d.get_local_rank("tp")) + # Verify flattened mesh local rank correctness. + flattened_mesh = mesh_2d["dp", "tp"]._flatten() + self.assertEqual(flattened_mesh.get_local_rank(), self.rank) + @with_comms def test_device_mesh_2d(self): mesh_tensor = torch.arange(4).reshape(2, 2) From 98728e2d048183c0244a5291b174d423a09a2310 Mon Sep 17 00:00:00 2001 From: Joona Havukainen Date: Fri, 30 Aug 2024 14:08:43 +0000 Subject: [PATCH 0273/1018] Enable batch matmul for result sizes > 2**32 the tensor can be split along batch axis (#133430) Fixes #131865. Addresses the issue seen when running llama v3.1 8B parameter model on MPS backend where the batch matmul output size can go over the 32-bit indexing limit of MPS tensors, causing an assert. Test case to reproduce the issue with the dimensions encountered in llama v3.1 and verify this fix works around it: ``` import torch device='mps' a = torch.randn([32, 20064, 128], dtype=torch.float32,device=device) b = torch.randn([32, 128, 20064], dtype=torch.float32, device=device) res = torch.bmm(a, b) ``` Notably the current change only works as long as the individual output matrix in the bmm does not exceed the number of elements 2**32. This lets us split up the computation along the batch axis to avoid going over the limit. Added a TORCH_CHECK to raise an error if the individual matrix dimensions are too large to handle for this op until a more general workaround tiling the matmuls is available. Pull Request resolved: https://github.com/pytorch/pytorch/pull/133430 Approved by: https://github.com/malfet Co-authored-by: Nikita Shulga <2453524+malfet@users.noreply.github.com> --- .../native/mps/operations/LinearAlgebra.mm | 120 ++++++++++++++++++ test/test_mps.py | 14 ++ 2 files changed, 134 insertions(+) diff --git a/aten/src/ATen/native/mps/operations/LinearAlgebra.mm b/aten/src/ATen/native/mps/operations/LinearAlgebra.mm index b1158865c4e688..e40454307ac97f 100644 --- a/aten/src/ATen/native/mps/operations/LinearAlgebra.mm +++ b/aten/src/ATen/native/mps/operations/LinearAlgebra.mm @@ -5,6 +5,7 @@ #include #include // For MTLLanguageVersion_3_1 +#include #include #include @@ -509,11 +510,123 @@ static void linalg_lu_factor_out_mps_impl(const Tensor& A, bool pivot, Tensor& L return output; } +static Tensor& tiled_bmm_out_mps_impl(const Tensor& batch1, const Tensor& batch2, Tensor& result) { + if (is_macos_13_or_newer(MacOSVersion::MACOS_VER_15_0_PLUS)) { + using namespace mps; + + id aBuffer = getMTLBufferStorage(batch1); + id bBuffer = getMTLBufferStorage(batch2); + id resBuffer = getMTLBufferStorage(result); + + MPSStream* mpsStream = getCurrentMPSStream(); + id device = MPSDevice::getInstance()->device(); + id computeEncoder = mpsStream->commandEncoder(); + + dispatch_sync_with_rethrow(mpsStream->queue(), ^() { + @autoreleasepool { + mpsStream->endKernelCoalescing(); + + uint64_t originalBatchSize = batch1.sizes().size() > 2 ? batch1.size(0) : 1; + uint64_t aRows = batch1.size(-2); + uint64_t bRows = batch2.size(-2); + uint64_t resRows = result.size(-2); + uint64_t aCols = batch1.size(-1); + uint64_t bCols = batch2.size(-1); + uint64_t resCols = result.size(-1); + uint64_t aElemSize = batch1.element_size(); + uint64_t bElemSize = batch2.element_size(); + uint64_t resElemSize = result.element_size(); + MPSDataType dtype = getMPSDataType(batch1); + + uint64_t elemInMatrix = resRows * resCols; + uint64_t largestSupportedBatchSize = floor(pow(2, 32) / elemInMatrix); + uint64_t batchSize = std::min(largestSupportedBatchSize, originalBatchSize); + uint64_t lastBatchSize = originalBatchSize % batchSize; + + id commandBuffer = mpsStream->commandBuffer(); + + auto matmul = [[MPSNDArrayMatrixMultiplication alloc] initWithDevice:device sourceCount:2]; + + MPSShape* aShape = @[ @(batchSize), @(aRows), @(aCols) ]; + MPSShape* bShape = @[ @(batchSize), @(bRows), @(bCols) ]; + MPSShape* resShape = @[ @(batchSize), @(resRows), @(resCols) ]; + auto aDesc_ = [MPSNDArrayDescriptor descriptorWithDataType:dtype shape:aShape]; + aDesc_.preferPackedRows = true; + auto bDesc_ = [MPSNDArrayDescriptor descriptorWithDataType:dtype shape:bShape]; + bDesc_.preferPackedRows = true; + + auto resDesc_ = [MPSNDArrayDescriptor descriptorWithDataType:dtype shape:resShape]; + resDesc_.preferPackedRows = true; + + getMPSProfiler().beginProfileKernel(matmul, " tiled_bmm_mps", {batch1, batch2}); + + // Descriptors to use for last batch if it exists + //.matrices is a readonly property so we need a separate descriptor. + MPSNDArrayDescriptor *aDescLastBatch_, *bDescLastBatch_, *resDescLastBatch_; + if (lastBatchSize != 0) { + aDescLastBatch_ = [MPSNDArrayDescriptor descriptorWithDataType:dtype + shape:@[ @(lastBatchSize), @(aRows), @(aCols) ]]; + aDescLastBatch_.preferPackedRows = true; + bDescLastBatch_ = [MPSNDArrayDescriptor descriptorWithDataType:dtype + shape:@[ @(lastBatchSize), @(bRows), @(bCols) ]]; + bDescLastBatch_.preferPackedRows = true; + resDescLastBatch_ = + [MPSNDArrayDescriptor descriptorWithDataType:dtype shape:@[ @(lastBatchSize), @(resRows), @(resCols) ]]; + resDescLastBatch_.preferPackedRows = true; + } + + uint64_t requiredIterations = ceil(float(originalBatchSize) / batchSize); + auto aDesc = aDesc_; + auto bDesc = bDesc_; + auto resDesc = resDesc_; + for (const auto i : c10::irange(requiredIterations)) { + if (i == requiredIterations - 1 && lastBatchSize != 0) { + aDesc = aDescLastBatch_; + bDesc = bDescLastBatch_; + resDesc = resDescLastBatch_; + } + const uint64_t aArrayOffset = i * batchSize * aRows * aCols; + const uint64_t bArrayOffset = i * batchSize * bRows * bCols; + const uint64_t resArrayOffset = i * batchSize * resRows * resCols; + + auto aMatrix = [[[MPSNDArray alloc] initWithBuffer:aBuffer + offset:(batch1.storage_offset() + aArrayOffset) * aElemSize + descriptor:aDesc] autorelease]; + auto bMatrix = [[[MPSNDArray alloc] initWithBuffer:bBuffer + offset:(batch2.storage_offset() + bArrayOffset) * bElemSize + descriptor:bDesc] autorelease]; + auto resMatrix = [[[MPSNDArray alloc] initWithBuffer:resBuffer + offset:(result.storage_offset() + resArrayOffset) * resElemSize + descriptor:resDesc] autorelease]; + + [matmul encodeToCommandEncoder:computeEncoder + commandBuffer:commandBuffer + sourceArrays:@[ aMatrix, bMatrix ] + destinationArray:resMatrix]; + } + } + }); + return result; + } else { + TORCH_CHECK(false, "Tiling of batch matmul for larger than 2**32 entries only available from MacOS15 onwards"); + } +} + static Tensor& bmm_out_mps_impl(const Tensor& batch1, const Tensor& batch2, Tensor& result) { using namespace mps; TORCH_CHECK(supportedFloatingOrComplexType(batch1), "MPS device does not support bmm for non-float inputs"); + // Currently unsupported if the matmul output goes over the 32-bit indexing limit + TORCH_CHECK( + batch1.size(1) * batch2.size(2) <= pow(2, 32), + "Output size of the matrix multiplication is larger than currently supported by the MPS backend: ", + batch1.size(1), + ",", + batch2.size(2), + ", needs to be less than 2**32 elements.", + "File a feature request for this use case against the MPS backend at https://github.com/pytorch/pytorch/issues"); + if (batch1.numel() == 0 || batch2.numel() == 0) { result.zero_(); return result; @@ -543,6 +656,13 @@ static void linalg_lu_factor_out_mps_impl(const Tensor& A, bool pivot, Tensor& L } } + // Check if we need to split the batch to do the computation + uint64_t resultSize = batch1.size(0) * batch1.size(1) * batch2.size(2); + if (resultSize > pow(2, 32)) { + result = tiled_bmm_out_mps_impl(batch1, batch2, result); + return result; + } + MPSStream* stream = getCurrentMPSStream(); struct CachedGraph : public mps::MPSCachedGraph { diff --git a/test/test_mps.py b/test/test_mps.py index 281c68b42cae14..f7f36e57c82adb 100644 --- a/test/test_mps.py +++ b/test/test_mps.py @@ -1914,6 +1914,20 @@ def test_bmm(self): self.assertEqual(output_cpu, output_mps) self.assertEqual(output_cpu.size(), output_mps.size()) + @xfailIf(product_version < 15.0) + @parametrize("dtype", [torch.float16, torch.bfloat16]) + def test_large_bmm(self, dtype): + batch1 = torch.randn(11, 20064, 128, dtype=dtype, device='mps') + batch2 = torch.randn(11, 128, 20064, dtype=dtype, device='mps') + output_cpu = torch.bmm(batch1.cpu(), batch2.cpu()) + output_mps = torch.bmm(batch1, batch2) + + # Using the low precision comparison for FP16 + tol = 1e-2 if dtype == torch.float16 else None + self.assertEqual(output_cpu, output_mps, atol=tol, rtol=tol) + self.assertEqual(output_cpu.size(), output_mps.size()) + + def test_addr(self): A = torch.ones(5, 10).to("mps") B = torch.ones(5).to("mps") From c613f5475ee80e58e609c8ca2885863df5ce22ba Mon Sep 17 00:00:00 2001 From: eqy Date: Fri, 30 Aug 2024 14:08:53 +0000 Subject: [PATCH 0274/1018] [CUDA][P2P] Check device capability in `requires_cuda_p2p_access` (#134523) Tests seem to fail on e.g., Volta without this given the compile time meacros used e.g., in https://github.com/pytorch/pytorch/blob/79b7fff18820e4f62a02534a961d5930040e3475/torch/csrc/distributed/c10d/intra_node_comm.cu#L487 Pull Request resolved: https://github.com/pytorch/pytorch/pull/134523 Approved by: https://github.com/yifuwang, https://github.com/Skylion007 --- test/distributed/test_symmetric_memory.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/test/distributed/test_symmetric_memory.py b/test/distributed/test_symmetric_memory.py index c27ad2f10f9ed0..ec6aa7f903bf7b 100644 --- a/test/distributed/test_symmetric_memory.py +++ b/test/distributed/test_symmetric_memory.py @@ -28,7 +28,9 @@ def requires_cuda_p2p_access(): cuda_p2p_access_available = ( - torch.cuda.is_available() and torch.cuda.device_count() >= 2 + torch.cuda.is_available() + and torch.cuda.get_device_capability() >= (8, 0) + and torch.cuda.device_count() >= 2 ) num_devices = torch.cuda.device_count() for i in range(num_devices - 1): From a0d55b87bbcc44c37576fc419565d1aa959cc5aa Mon Sep 17 00:00:00 2001 From: Yidi Wu Date: Thu, 29 Aug 2024 11:25:35 -0700 Subject: [PATCH 0275/1018] make make_fx collective test single threaded (#134775) make_fx is not thread-safe due to mutating and patching global states. It's difficult and low roi to make it thread-safe so just turn the tracing test into a single-thread test. Pull Request resolved: https://github.com/pytorch/pytorch/pull/134775 Approved by: https://github.com/yifuwang --- test/distributed/test_functional_api.py | 21 ++++++++++++--------- 1 file changed, 12 insertions(+), 9 deletions(-) diff --git a/test/distributed/test_functional_api.py b/test/distributed/test_functional_api.py index 2876bdd275a54b..c4972a46405862 100644 --- a/test/distributed/test_functional_api.py +++ b/test/distributed/test_functional_api.py @@ -384,20 +384,23 @@ def test_all_reduce(self): self.assertIsNone(x.grad) -class TestMakeFx(MultiThreadedTestCase): - @property - def world_size(self): - return 2 - +class TestMakeFx(TestCase): def setUp(self): - super().setUp() - self._spawn_threads() + # make_fx is not thread-safe due to patching nd mutating global states + # so create a fake_pg. + self.rank = 0 + self.world_size = 2 + store = FakeStore() + dist.init_process_group( + backend="fake", + world_size=self.world_size, + rank=self.rank, + store=store, + ) def tearDown(self): super().tearDown() - # race condition with threads causes is_fx_tracing flag to be set incorrectly. - torch.fx._symbolic_trace._is_fx_tracing_flag = False self.assertFalse(torch.fx._symbolic_trace.is_fx_tracing()) def test_all_reduce_tracing(self): From 965df527320144dc25976c6dd1877137d8346b23 Mon Sep 17 00:00:00 2001 From: Yidi Wu Date: Thu, 29 Aug 2024 11:25:36 -0700 Subject: [PATCH 0276/1018] [HOP] fix export x inline_inbuilt_nn_modules (#133731) TLDR; this PR supports exporting cond x inine_inbuilt nn modules flag by inling into tracing code in proxy_tensor.py _symbolic_trace.py (internally, the pattern is make_fx(record_module_stack)(torch.compile(f))). We have two special treatments for following cases: 1. _ModuleStackTracer will wrap all the nn modules into _AttrProxy. This _AttrProxy has several subtiles which make it hard to inline in dynamo like overriding _modules with a property method and overrides the `__getattr__`, which mutates captured states when calling `__getattr__`. Solution to this is that we unwrap the _AttrProxy and get its corresponding nn_module (a 1-1 correspondence). So that dynamo symbolically traces the original nn module instead of tracing _AttrProxy. 2. The tracer applies a bunch of patches the `__getattr__` and `__call__` of nn.Module for tracking reasons. This doesn't work well with dynamo. The immediate error we see is `torch._dynamo.exc.Unsupported: 'inline in skipfiles: WeakKeyDictionary.__contains__ | __contains__ /home/yidi/.conda/envs/pytorch/lib/python3.10/weakref.py` caused by a weakdict in PythonKeyTracer. Solution to this is that we remove the patches during dynamo symbolic convert temporally. So that dynamo has a clean environment. make_fx will be trace the transformed bytecode of dynamo and patches nn modules there instead. Pull Request resolved: https://github.com/pytorch/pytorch/pull/133731 Approved by: https://github.com/anijain2305 ghstack dependencies: #134775 --- test/export/test_export.py | 1 - torch/_dynamo/convert_frame.py | 7 +++ torch/_dynamo/guards.py | 9 ++++ torch/_dynamo/source.py | 11 ++++ torch/_dynamo/trace_rules.py | 2 + torch/_dynamo/variables/builder.py | 8 +++ torch/_higher_order_ops/utils.py | 8 --- torch/fx/_symbolic_trace.py | 73 ++++++++++++++++++++++++--- torch/fx/experimental/proxy_tensor.py | 3 ++ 9 files changed, 106 insertions(+), 16 deletions(-) diff --git a/test/export/test_export.py b/test/export/test_export.py index 4eb5cc9fa3a55b..a7a9b2c5a31c8e 100644 --- a/test/export/test_export.py +++ b/test/export/test_export.py @@ -5566,7 +5566,6 @@ def forward(self, x): ) ) - # Guard validation upsets the guard def test_cond_with_module_stack_export_with(self): class Bar(torch.nn.Module): def __init__(self) -> None: diff --git a/torch/_dynamo/convert_frame.py b/torch/_dynamo/convert_frame.py index cddc89695d4fee..36c889a17d7cfc 100644 --- a/torch/_dynamo/convert_frame.py +++ b/torch/_dynamo/convert_frame.py @@ -2,6 +2,7 @@ from __future__ import annotations import collections +import contextlib import cProfile import dis import functools @@ -206,10 +207,16 @@ def _fn(*args: _P.args, **kwargs: _P.kwargs) -> _T: prior_fwd_from_src = torch.fx.graph_module._forward_from_src torch.fx.graph_module._forward_from_src = fx_forward_from_src_skip_result cleanup = setup_compile_debug() + + exit_stack = contextlib.ExitStack() + exit_stack.enter_context( + torch.fx._symbolic_trace._maybe_revert_all_patches() + ) try: return fn(*args, **kwargs) finally: cleanup.close() + exit_stack.close() torch._C._set_grad_enabled(prior_grad_mode) torch.autograd.grad_mode._enter_inference_mode(prior_inference_mode) torch.use_deterministic_algorithms( diff --git a/torch/_dynamo/guards.py b/torch/_dynamo/guards.py index d9cc7eb6d44d16..fccea9cbd14c27 100644 --- a/torch/_dynamo/guards.py +++ b/torch/_dynamo/guards.py @@ -76,6 +76,7 @@ from . import config, convert_frame, exc, mutation_guard from .eval_frame import set_guard_error_hook from .source import ( + AttrProxySource, AttrSource, ChainedSource, ConstDictKeySource, @@ -1064,6 +1065,14 @@ def get_guard_manager_from_source(self, source): example_value=example_value, guard_manager_enum=guard_manager_enum, ) + elif istype(source, AttrProxySource): + assert base_guard_manager # to make mypy happy + out = base_guard_manager.lambda_manager( + python_lambda=lambda x: x.get_base(), + source=source_name, + example_value=example_value, + guard_manager_enum=guard_manager_enum, + ) elif istype(source, TupleIteratorGetItemSource): assert base_guard_manager # to make mypy happy out = base_guard_manager.tuple_iterator_getitem_manager( diff --git a/torch/_dynamo/source.py b/torch/_dynamo/source.py index 7a3ad7a9b05ca8..2d3a4424167da9 100644 --- a/torch/_dynamo/source.py +++ b/torch/_dynamo/source.py @@ -382,6 +382,17 @@ def name(self): return f"{self.base.name()}._type().qualified_name()" +class AttrProxySource(ChainedSource): + def reconstruct(self, codegen): + self.base.reconstruct(codegen) + + def guard_source(self): + return self.base.guard_source() + + def name(self): + return f"{self.base.name()}.get_base()" + + @dataclasses.dataclass(frozen=True) class DefaultsSource(ChainedSource): idx_key: Union[int, str] diff --git a/torch/_dynamo/trace_rules.py b/torch/_dynamo/trace_rules.py index e8692fe47891a3..5535273b7dc2ea 100644 --- a/torch/_dynamo/trace_rules.py +++ b/torch/_dynamo/trace_rules.py @@ -3240,6 +3240,8 @@ def _module_dir(m: types.ModuleType): "torch.distributions", "torch.export._tree_utils", "torch.fx._pytree", + "torch.fx._symbolic_trace", + "torch.fx.experimental.proxy_tensor", "torch.fx.passes.shape_prop", "torch.nn", "torch.overrides", diff --git a/torch/_dynamo/variables/builder.py b/torch/_dynamo/variables/builder.py index bdd84a3dcaebfa..b21fa54369cd32 100644 --- a/torch/_dynamo/variables/builder.py +++ b/torch/_dynamo/variables/builder.py @@ -58,6 +58,7 @@ from ..guards import GuardBuilder, install_guard, make_dupe_guard from ..side_effects import SideEffects from ..source import ( + AttrProxySource, AttrSource, CallMethodItemSource, ConstantSource, @@ -1340,6 +1341,13 @@ def wrap_module(self, value: torch.nn.Module): return self.tx.output.side_effects.track_object_existing(value, result) elif mutation_guard.is_dynamic_nn_module(value, self.tx.export): # created dynamically, don't specialize on it + + # Note [Tracing a torch.compiled function] + # when make_fx tracing a compiled function, we need + if isinstance(value, torch.fx.experimental.proxy_tensor._AttrProxy): + value = value.get_base() + self.source = AttrProxySource(self.source) + self.install_guards(GuardBuilder.TYPE_MATCH) if torch._dynamo.config.inline_inbuilt_nn_modules: freezing = is_parameter_freezing() diff --git a/torch/_higher_order_ops/utils.py b/torch/_higher_order_ops/utils.py index b0729b25ab64e5..3b88d5f7dbac8a 100644 --- a/torch/_higher_order_ops/utils.py +++ b/torch/_higher_order_ops/utils.py @@ -105,21 +105,13 @@ def _maybe_reenter_make_fx(fn): @contextmanager def _set_compilation_env(): _old_is_tracing = torch.fx._symbolic_trace._is_fx_tracing_flag - _old_is_inlining = torch._dynamo.config.inline_inbuilt_nn_modules try: # We need to turn off the is_fx_tracing_flag. Remove this flag check from dyanmo # once we are confident fx tracing works with dynamo. torch.fx._symbolic_trace._is_fx_tracing_flag = False - - # TODO(anijain2305, export-team) For non-strict export with module - # stack info, the codepatch forces the nn module __getattr__ to - # ProxyAttr __getattr__ downstream. To circumvent the issue for now, - # skip inlining inbuilt nn modules for cond. - torch._dynamo.config.inline_inbuilt_nn_modules = False yield finally: torch.fx._symbolic_trace._is_fx_tracing_flag = _old_is_tracing - torch._dynamo.config.inline_inbuilt_nn_modules = _old_is_inlining def _has_potential_branch_input_mutation(branch, inputs, pre_dispatch=False): diff --git a/torch/fx/_symbolic_trace.py b/torch/fx/_symbolic_trace.py index bdbd4b04429d31..6693863386513c 100644 --- a/torch/fx/_symbolic_trace.py +++ b/torch/fx/_symbolic_trace.py @@ -1,6 +1,7 @@ # mypy: allow-untyped-defs import builtins import copy +import contextlib import functools import inspect import math @@ -793,13 +794,13 @@ def forward(*args, **kwargs): return _orig_module_call(mod, *args, **kwargs) _autowrap_check( - patcher, + patcher, # type: ignore[has-type] getattr(getattr(mod, "forward", mod), "__globals__", {}), self._autowrap_function_ids, ) return self.call_module(mod, forward, args, kwargs) - with _Patcher() as patcher: + with _new_patcher() as patcher: # allow duplicate patches to support the case of nested calls patcher.patch_method( torch.nn.Module, @@ -990,25 +991,36 @@ class _PatchedFn(NamedTuple): frame_dict: Any fn_name: str orig_fn: Any + new_fn: Any def revert(self): raise NotImplementedError + def patch(self): + raise NotImplementedError + class _PatchedFnSetItem(_PatchedFn): def revert(self): self.frame_dict[self.fn_name] = self.orig_fn + def patch(self): + self.frame_dict[self.fn_name] = self.new_fn class _PatchedFnDel(_PatchedFn): def revert(self): del self.frame_dict[self.fn_name] + def patch(self): + self.frame_dict[self.fn_name] = self.new_fn + class _PatchedFnSetAttr(_PatchedFn): def revert(self): setattr(self.frame_dict, self.fn_name, self.orig_fn) + def patch(self): + setattr(self.frame_dict, self.fn_name, self.new_fn) class _Patcher: def __init__(self) -> None: @@ -1028,14 +1040,15 @@ def patch( """ new_fn.__fx_already_patched = deduplicate # type: ignore[attr-defined] if name not in frame_dict and hasattr(builtins, name): - self.patches_made.append(_PatchedFnDel(frame_dict, name, None)) + self.patches_made.append(_PatchedFnDel(frame_dict, name, None, new_fn)) + self.patches_made[-1].patch() elif getattr(frame_dict[name], "__fx_already_patched", False): return # already patched, no need to do it again else: self.patches_made.append( - _PatchedFnSetItem(frame_dict, name, frame_dict[name]) + _PatchedFnSetItem(frame_dict, name, frame_dict[name], new_fn) ) - frame_dict[name] = new_fn + self.patches_made[-1].patch() def patch_method( self, cls: type, name: str, new_fn: Callable, deduplicate: bool = True @@ -1047,8 +1060,8 @@ def patch_method( orig_fn = getattr(cls, name) if getattr(orig_fn, "__fx_already_patched", False): return # already patched, no need to do it again - self.patches_made.append(_PatchedFnSetAttr(cls, name, orig_fn)) - setattr(cls, name, new_fn) + self.patches_made.append(_PatchedFnSetAttr(cls, name, orig_fn, new_fn)) + self.patches_made[-1].patch() def visit_once(self, thing: Any): """Return True on the first call to with thing, otherwise false""" @@ -1058,6 +1071,22 @@ def visit_once(self, thing: Any): self.visited.add(idx) return True + def revert_all_patches(self): + """ + Remove all the stored patcheds. It doesn't modify patches_made. + """ + for patch in self.patches_made: + patch.revert() + return self.patches_made + + def reapply_all_patches(self): + """ + Patch all the stored patcheds. It doesn't modify patches_made. + """ + for patch in self.patches_made: + patch.patch() + return self.patches_made + def __enter__(self): return self @@ -1071,6 +1100,36 @@ def __exit__(self, exc_type, exc_val, exc_tb): self.visited.clear() +CURRENT_PATCHER: Optional[_Patcher] = None + +@contextlib.contextmanager +def _new_patcher(): + global CURRENT_PATCHER + prior_patcher = CURRENT_PATCHER + try: + CURRENT_PATCHER = _Patcher() + yield CURRENT_PATCHER + finally: + # Clear all the patches made by when using current patcher. + assert CURRENT_PATCHER is not None + CURRENT_PATCHER.revert_all_patches() + CURRENT_PATCHER = prior_patcher + + +@contextlib.contextmanager +def _maybe_revert_all_patches(): + current_patcher = CURRENT_PATCHER + patches_made = None + patches_removed = None + try: + if current_patcher is not None: + patches_removed = current_patcher.revert_all_patches() + yield + finally: + if current_patcher is not None: + patches_made = current_patcher.reapply_all_patches() + assert patches_made == patches_removed, "CURRENT_PATCHER was changed during a revert_all_patches" + def _patch_wrapped_functions(patcher: _Patcher): """ Go through ``_wrapped_fn_patch_table`` and, for each frame object, wrap diff --git a/torch/fx/experimental/proxy_tensor.py b/torch/fx/experimental/proxy_tensor.py index ab9ef8c0c7a73c..7a3ece886506c3 100644 --- a/torch/fx/experimental/proxy_tensor.py +++ b/torch/fx/experimental/proxy_tensor.py @@ -1566,6 +1566,9 @@ def __getattr__(self, name: str) -> AttrProxy: ) return tracer.attr_proxy_map[attr_val] + def get_base(self) -> Module: + return tracer.proxy_modules[self] + @property def _modules(self) -> Dict[str, AttrProxy]: assert "_modules" in self.__dict__ From 18e7b23fd640347ed53ce9e6e3e3dee7d88b4fa0 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Fri, 30 Aug 2024 16:06:09 +0000 Subject: [PATCH 0277/1018] Revert "[dynamo][itertools] refactor `itertools.islice` to use polyfill (#133876)" This reverts commit 7d12e6dceb94a221288f21c0e79ce8ca667d657a. Reverted https://github.com/pytorch/pytorch/pull/133876 on behalf of https://github.com/ZainRizvi due to This is still failing internally with the same error about 'Graph break due to unsupported builtin _functools.reduce' ([comment](https://github.com/pytorch/pytorch/pull/133778#issuecomment-2321787968)) --- torch/_dynamo/polyfills/__init__.py | 27 ++++++++++++++++++++++--- torch/_dynamo/polyfills/itertools.py | 30 ---------------------------- torch/_dynamo/trace_rules.py | 2 ++ torch/_dynamo/variables/builtin.py | 9 +++++++++ torch/_dynamo/variables/iter.py | 6 ++++++ 5 files changed, 41 insertions(+), 33 deletions(-) diff --git a/torch/_dynamo/polyfills/__init__.py b/torch/_dynamo/polyfills/__init__.py index 69145ce86d8c0d..9b465726754bbe 100644 --- a/torch/_dynamo/polyfills/__init__.py +++ b/torch/_dynamo/polyfills/__init__.py @@ -14,15 +14,36 @@ def index(iterator, item, start=0, end=None): - import itertools - - for i, elem in itertools.islice(enumerate(iterator), start, end): + for i, elem in islice(enumerate(iterator), start, end): if item == elem: return i # This will not run in dynamo raise ValueError(f"{item} is not in {type(iterator)}") +def islice(iterator, start=0, end=None, step=1): + if start < 0 or (end is not None and end < 0) or step < 0: + raise ValueError("Indices must be non-negative") + if step == 0: + raise ValueError("Step cannot be 0") + + it = iter(iterator) + + for _ in range(start): + next(it) + + if end is None: + for i, element in enumerate(it): + if i % step == 0: + yield element + else: + for i, element in enumerate(it): + if i % step == 0 and i + start < end - start: + yield element + elif i + start >= end - start: + break + + def repeat(item, count): for i in range(count): yield item diff --git a/torch/_dynamo/polyfills/itertools.py b/torch/_dynamo/polyfills/itertools.py index a829d2e9b88081..802a62a82c8cdd 100644 --- a/torch/_dynamo/polyfills/itertools.py +++ b/torch/_dynamo/polyfills/itertools.py @@ -13,7 +13,6 @@ __all__ = [ "chain", "chain_from_iterable", - "islice", "tee", ] @@ -36,35 +35,6 @@ def chain_from_iterable(iterable: Iterable[Iterable[_T]], /) -> Iterator[_T]: chain.from_iterable = chain_from_iterable # type: ignore[method-assign] -# Reference: https://docs.python.org/3/library/itertools.html#itertools.islice -@substitute_in_graph(itertools.islice, is_embedded_type=True) # type: ignore[arg-type] -def islice(iterable: Iterable[_T], /, *args: int | None) -> Iterator[_T]: - s = slice(*args) - start = 0 if s.start is None else s.start - stop = s.stop - step = 1 if s.step is None else s.step - if start < 0 or (stop is not None and stop < 0) or step <= 0: - raise ValueError( - "Indices for islice() must be None or an integer: 0 <= x <= sys.maxsize.", - ) - - if stop is None: - # TODO: use indices = itertools.count() and merge implementation with the else branch - # when we support infinite iterators - next_i = start - for i, element in enumerate(iterable): - if i == next_i: - yield element - next_i += step - else: - indices = range(max(start, stop)) - next_i = start - for i, element in zip(indices, iterable): - if i == next_i: - yield element - next_i += step - - # Reference: https://docs.python.org/3/library/itertools.html#itertools.tee @substitute_in_graph(itertools.tee) def tee(iterable: Iterable[_T], n: int = 2, /) -> tuple[Iterator[_T], ...]: diff --git a/torch/_dynamo/trace_rules.py b/torch/_dynamo/trace_rules.py index 5535273b7dc2ea..29ad4847f27d28 100644 --- a/torch/_dynamo/trace_rules.py +++ b/torch/_dynamo/trace_rules.py @@ -12,6 +12,7 @@ import functools import importlib import inspect +import itertools import linecache import logging import multiprocessing @@ -2992,6 +2993,7 @@ def _builtin_function_ids() -> Dict[int, str]: if not k.startswith("_") and callable(v) } ) + rv.update({id(v): f"itertools.{v.__name__}" for v in (itertools.islice,)}) rv.update( { id(cast): "typing.cast", diff --git a/torch/_dynamo/variables/builtin.py b/torch/_dynamo/variables/builtin.py index 919daf6fbd2b59..279981465b2713 100644 --- a/torch/_dynamo/variables/builtin.py +++ b/torch/_dynamo/variables/builtin.py @@ -1873,6 +1873,15 @@ def call_sorted(self, tx: "InstructionTranslator", obj: VariableTracker, **kwarg ) return variables.ListVariable(items) + def call_islice(self, tx: "InstructionTranslator", iterable, *args): + if iterable.has_unpack_var_sequence(tx) and all( + x.is_python_constant() for x in args + ): + const_args = [x.as_python_constant() for x in args] + items = iterable.unpack_var_sequence(tx) + items = list(itertools.islice(items, *const_args)) + return variables.TupleVariable(items) + # neg is a constant fold function, so we only get here if constant fold is not valid def call_neg(self, tx: "InstructionTranslator", a): if isinstance(a, SymNodeVariable): diff --git a/torch/_dynamo/variables/iter.py b/torch/_dynamo/variables/iter.py index 34c2354026a7bd..ed3ae786634e92 100644 --- a/torch/_dynamo/variables/iter.py +++ b/torch/_dynamo/variables/iter.py @@ -166,6 +166,12 @@ def retrieve_const_key(key): from_exc=e, ) return variables.ListIteratorVariable(result, mutable_local=MutableLocal()) + elif self.value is itertools.islice: + from .builder import SourcelessBuilder + + return tx.inline_user_function_return( + SourcelessBuilder.create(tx, polyfills.islice), args, kwargs + ) elif self.value is itertools.repeat: if len(args) < 2: return variables.RepeatIteratorVariable( From bf0763235ace710a93ea32ef9caf783b058d4056 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Fri, 30 Aug 2024 16:06:09 +0000 Subject: [PATCH 0278/1018] Revert "[dynamo] refactor `builtins.enumerate` to use polyfill (#133894)" This reverts commit a2566adfb6064235db6d950568435fb6ef58a598. Reverted https://github.com/pytorch/pytorch/pull/133894 on behalf of https://github.com/ZainRizvi due to This is still failing internally with the same error about 'Graph break due to unsupported builtin _functools.reduce' ([comment](https://github.com/pytorch/pytorch/pull/133778#issuecomment-2321787968)) --- torch/_dynamo/polyfills/builtins.py | 15 --------------- torch/_dynamo/variables/builtin.py | 16 ++++++++++++++++ 2 files changed, 16 insertions(+), 15 deletions(-) diff --git a/torch/_dynamo/polyfills/builtins.py b/torch/_dynamo/polyfills/builtins.py index 62305086b804fe..d5ce75c562bfb6 100644 --- a/torch/_dynamo/polyfills/builtins.py +++ b/torch/_dynamo/polyfills/builtins.py @@ -2,8 +2,6 @@ Python polyfills for builtins """ -from __future__ import annotations - import builtins import functools import operator @@ -15,7 +13,6 @@ __all__ = [ "all", "any", - "enumerate", "sum", ] @@ -39,18 +36,6 @@ def any(iterable: Iterable[object], /) -> bool: return False -@substitute_in_graph(builtins.enumerate, is_embedded_type=True) # type: ignore[arg-type] -def enumerate(iterable: Iterable[_T], start: int = 0) -> Iterable[tuple[int, _T]]: - if not isinstance(start, int): - raise TypeError( - f"{type(start).__name__!r} object cannot be interpreted as an integer" - ) - - for x in iterable: - yield start, x - start += 1 - - @substitute_in_graph(builtins.sum, can_constant_fold_through=True) # type: ignore[arg-type] def sum(iterable: Iterable[_T], /, start: _T = 0) -> _T: # type: ignore[assignment] return functools.reduce(operator.add, iterable, start) diff --git a/torch/_dynamo/variables/builtin.py b/torch/_dynamo/variables/builtin.py index 279981465b2713..f01b4b6efbe596 100644 --- a/torch/_dynamo/variables/builtin.py +++ b/torch/_dynamo/variables/builtin.py @@ -1442,6 +1442,22 @@ def call_zip(self, tx: "InstructionTranslator", *args, **kwargs): items = [variables.TupleVariable(list(item)) for item in zip(*unpacked)] return variables.TupleVariable(items) + def call_enumerate(self, tx: "InstructionTranslator", *args): + if len(args) == 1: + start = 0 + else: + assert len(args) == 2 + assert isinstance(args[1], variables.ConstantVariable) + start = args[1].as_python_constant() + if args[0].has_unpack_var_sequence(tx): + items = [ + variables.TupleVariable( + [variables.ConstantVariable.create(idx), var], + ) + for idx, var in enumerate(args[0].unpack_var_sequence(tx), start) + ] + return variables.TupleVariable(items) + def call_len(self, tx: "InstructionTranslator", *args, **kwargs): return args[0].call_method(tx, "__len__", args[1:], kwargs) From 10f990933d4da37388b455ced11b6f0c77496ed5 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Fri, 30 Aug 2024 16:06:10 +0000 Subject: [PATCH 0279/1018] Revert "[dynamo][itertools] refactor `itertools.chain` and `itertools.chain.from_iterable` to use polyfills (#133864)" This reverts commit 1b703669576223024eb84a76c53b7ec5ed8bb270. Reverted https://github.com/pytorch/pytorch/pull/133864 on behalf of https://github.com/ZainRizvi due to This is still failing internally with the same error about 'Graph break due to unsupported builtin _functools.reduce' ([comment](https://github.com/pytorch/pytorch/pull/133778#issuecomment-2321787968)) --- torch/_dynamo/decorators.py | 16 +--------------- torch/_dynamo/polyfills/itertools.py | 21 +-------------------- torch/_dynamo/polyfills/loader.py | 6 ------ torch/_dynamo/trace_rules.py | 8 ++++++-- torch/_dynamo/variables/builder.py | 28 ++-------------------------- torch/_dynamo/variables/builtin.py | 16 ++++++++++++++++ torch/_dynamo/variables/functions.py | 22 ---------------------- torch/_dynamo/variables/iter.py | 8 ++++++++ 8 files changed, 34 insertions(+), 91 deletions(-) diff --git a/torch/_dynamo/decorators.py b/torch/_dynamo/decorators.py index 4e97422fa605e9..e73aecad34ce3a 100644 --- a/torch/_dynamo/decorators.py +++ b/torch/_dynamo/decorators.py @@ -170,8 +170,6 @@ def substitute_in_graph( *, can_constant_fold_through: bool = False, skip_signature_check: bool = False, - # type that is embedded in the Python interpreter - is_embedded_type: bool = False, # internal use only ) -> Callable[[_F], _F]: """ Register a polyfill handler for a function, usually a C function from the C extension, to be @@ -221,22 +219,10 @@ def substitute_in_graph( >>> torch.compile(operator.indexOf, fullgraph=True)([1, 2, 3, 4, 5], 3) 2 """ - if not is_function(original_fn) and not ( - is_embedded_type and inspect.isclass(original_fn) - ): + if not is_function(original_fn): raise TypeError( f"substitute_in_graph expects a function but got {type(original_fn)!r}" ) - if is_embedded_type: - if not inspect.isclass(original_fn): - raise TypeError( - f"substitute_in_graph expects a class but got {type(original_fn)!r}" - ) - - from .variables.builder import ITERTOOLS_POLYFILLED_TYPE_IDS, ITERTOOLS_TYPE_IDS - - if id(original_fn) in ITERTOOLS_TYPE_IDS: - ITERTOOLS_POLYFILLED_TYPE_IDS.add(id(original_fn)) def wrapper(traceable_fn: _F) -> _F: if not is_function(traceable_fn): diff --git a/torch/_dynamo/polyfills/itertools.py b/torch/_dynamo/polyfills/itertools.py index 802a62a82c8cdd..090df0f84b5281 100644 --- a/torch/_dynamo/polyfills/itertools.py +++ b/torch/_dynamo/polyfills/itertools.py @@ -10,31 +10,12 @@ from ..decorators import substitute_in_graph -__all__ = [ - "chain", - "chain_from_iterable", - "tee", -] +__all__ = ["tee"] _T = TypeVar("_T") -# Reference: https://docs.python.org/3/library/itertools.html#itertools.chain -@substitute_in_graph(itertools.chain, is_embedded_type=True) # type: ignore[arg-type] -def chain(*iterables: Iterable[_T]) -> Iterator[_T]: - for iterable in iterables: - yield from iterable - - -@substitute_in_graph(itertools.chain.from_iterable) # type: ignore[arg-type] -def chain_from_iterable(iterable: Iterable[Iterable[_T]], /) -> Iterator[_T]: - return itertools.chain(*iterable) - - -chain.from_iterable = chain_from_iterable # type: ignore[method-assign] - - # Reference: https://docs.python.org/3/library/itertools.html#itertools.tee @substitute_in_graph(itertools.tee) def tee(iterable: Iterable[_T], n: int = 2, /) -> tuple[Iterator[_T], ...]: diff --git a/torch/_dynamo/polyfills/loader.py b/torch/_dynamo/polyfills/loader.py index d9367b55bf1400..255486b6e016dd 100644 --- a/torch/_dynamo/polyfills/loader.py +++ b/torch/_dynamo/polyfills/loader.py @@ -32,9 +32,3 @@ polyfill_handler = getattr(polyfill_module, polyfill_name) original_fn = polyfill_handler.__torch_dynamo_original__ trace_rules._builtin_function_ids.remove(id(original_fn)) - - # Unregister the class object if the original function is its __new__ method - if original_fn.__name__ == "__new__" and isinstance( - getattr(original_fn, "__self__", None), type - ): - trace_rules._builtin_function_ids.remove(id(original_fn.__self__)) diff --git a/torch/_dynamo/trace_rules.py b/torch/_dynamo/trace_rules.py index 29ad4847f27d28..f42c4653093cd7 100644 --- a/torch/_dynamo/trace_rules.py +++ b/torch/_dynamo/trace_rules.py @@ -2993,7 +2993,9 @@ def _builtin_function_ids() -> Dict[int, str]: if not k.startswith("_") and callable(v) } ) - rv.update({id(v): f"itertools.{v.__name__}" for v in (itertools.islice,)}) + rv.update( + {id(v): f"itertools.{v.__name__}" for v in (itertools.chain, itertools.islice)} + ) rv.update( { id(cast): "typing.cast", @@ -3471,7 +3473,9 @@ def check_verbose(obj, is_inlined_call=False): # Consulte the central trace rules defined in torch._dynamo.trace_rules. reasons: Set[str] = set() - rule = lookup_inner(fi.py_obj, fi.name, fi.filename, is_inlined_call, reasons) + rule = torch._dynamo.trace_rules.lookup_inner( + fi.py_obj, fi.name, fi.filename, is_inlined_call, reasons + ) if issubclass(rule, (UserFunctionVariable, PolyfilledFunctionVariable)): return SkipResult( False, diff --git a/torch/_dynamo/variables/builder.py b/torch/_dynamo/variables/builder.py index b21fa54369cd32..ea71db1e5c0992 100644 --- a/torch/_dynamo/variables/builder.py +++ b/torch/_dynamo/variables/builder.py @@ -17,17 +17,7 @@ import types import warnings import weakref -from typing import ( - Any, - FrozenSet, - List, - MutableMapping, - NamedTuple, - Optional, - Set, - TYPE_CHECKING, - Union, -) +from typing import Any, List, MutableMapping, NamedTuple, Optional, TYPE_CHECKING, Union import torch from torch import SymInt @@ -330,17 +320,6 @@ class FrameStateSizeEntry: stride: Optional[List[int]] -# All class-based iterators in itertools -# NOTE: use id() because some objects are not hashable, it will raise error during lookup -ITERTOOLS_TYPE_IDS: FrozenSet[int] = frozenset( - id(member) - for name, member in vars(itertools).items() - if not name.startswith("_") and inspect.isclass(member) -) -# Will be updated later in substitute_in_graph in torch/_dynamo/polyfills/itertools.py -ITERTOOLS_POLYFILLED_TYPE_IDS: Set[int] = set() - - class VariableBuilder: """Wrap a python value in a VariableTracker() instance""" @@ -896,10 +875,7 @@ def build_key_value(i, k, v): value, source=self.source, ) - elif ( - id(value) in ITERTOOLS_TYPE_IDS - and id(value) not in ITERTOOLS_POLYFILLED_TYPE_IDS - ): + elif istype(value, type) and value in itertools.__dict__.values(): self.install_guards(GuardBuilder.FUNCTION_MATCH) return ItertoolsVariable(value, source=self.source) elif isinstance(value, torch.SymBool): diff --git a/torch/_dynamo/variables/builtin.py b/torch/_dynamo/variables/builtin.py index f01b4b6efbe596..ca972b282e387b 100644 --- a/torch/_dynamo/variables/builtin.py +++ b/torch/_dynamo/variables/builtin.py @@ -994,6 +994,15 @@ def call_method( ) if self.fn is dict and name == "fromkeys": return BuiltinVariable.call_custom_dict_fromkeys(tx, dict, *args, **kwargs) + if self.fn is itertools.chain and name == "from_iterable": + assert len(args) == 1 + assert len(kwargs) == 0 + obj = args[0] + items = [] + for item in obj.unpack_var_sequence(tx): + items.extend(item.unpack_var_sequence(tx)) + return variables.TupleVariable(items) + return super().call_method(tx, name, args, kwargs) def _call_int_float(self, tx: "InstructionTranslator", arg): @@ -1889,6 +1898,13 @@ def call_sorted(self, tx: "InstructionTranslator", obj: VariableTracker, **kwarg ) return variables.ListVariable(items) + def call_chain(self, tx: "InstructionTranslator", *args): + if all(obj.has_unpack_var_sequence(tx) for obj in args): + items = [] + for obj in args: + items.extend(obj.unpack_var_sequence(tx)) + return variables.TupleVariable(items) + def call_islice(self, tx: "InstructionTranslator", iterable, *args): if iterable.has_unpack_var_sequence(tx) and all( x.is_python_constant() for x in args diff --git a/torch/_dynamo/variables/functions.py b/torch/_dynamo/variables/functions.py index 239b720a5a7530..a9ec80cc3de589 100644 --- a/torch/_dynamo/variables/functions.py +++ b/torch/_dynamo/variables/functions.py @@ -18,7 +18,6 @@ check_constant_args, check_unspec_or_constant_args, identity, - is_function, is_wrapper_or_member_descriptor, istype, make_cell, @@ -993,27 +992,6 @@ def call_function( handler, ).call_function(tx, args, kwargs) - return super().call_function(tx, args, kwargs) - - def call_method( - self, - tx, - name, - args: "List[VariableTracker]", - kwargs: "Dict[str, VariableTracker]", - ) -> "VariableTracker": - if name == "__call__": - return self.call_function(tx, args, kwargs) - - method = getattr(self.fn, name, None) - assert method is not None, f"Member {name} not found in {self.fn}" - assert is_function(method), f"Member {name} is not callable in {self.fn}" - options = {} - if self.source: - options["source"] = AttrSource(self.source, name) - member_variable = PolyfilledFunctionVariable(method, **options) - return member_variable.call_function(tx, args, kwargs) - def as_python_constant(self): return self.fn diff --git a/torch/_dynamo/variables/iter.py b/torch/_dynamo/variables/iter.py index ed3ae786634e92..6687611cf0aabd 100644 --- a/torch/_dynamo/variables/iter.py +++ b/torch/_dynamo/variables/iter.py @@ -52,6 +52,14 @@ def call_function( for item in itertools.product(*seqs): items.append(variables.TupleVariable(list(item))) return variables.ListIteratorVariable(items, mutable_local=MutableLocal()) + elif ( + self.value is itertools.chain + and not kwargs + and all(arg.has_unpack_var_sequence(tx) for arg in args) + ): + seqs = [arg.unpack_var_sequence(tx) for arg in args] + items = list(itertools.chain.from_iterable(seqs)) + return variables.ListIteratorVariable(items, mutable_local=MutableLocal()) elif self.value is itertools.accumulate: from .builtin import BuiltinVariable From ce9595bdf28a716410ccd2110956892700b3a1c6 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Fri, 30 Aug 2024 16:06:10 +0000 Subject: [PATCH 0280/1018] Revert "[dynamo] simplify implementation for `builtins.sum` (#133779)" This reverts commit eaa449fbf0fe528a0827ee9b5bcfcd307a7c658d. Reverted https://github.com/pytorch/pytorch/pull/133779 on behalf of https://github.com/ZainRizvi due to This is still failing internally with the same error about 'Graph break due to unsupported builtin _functools.reduce' ([comment](https://github.com/pytorch/pytorch/pull/133778#issuecomment-2321787968)) --- torch/_dynamo/polyfills/builtins.py | 13 +---------- torch/_dynamo/variables/builtin.py | 34 +++++++++++++++++++++++++++++ 2 files changed, 35 insertions(+), 12 deletions(-) diff --git a/torch/_dynamo/polyfills/builtins.py b/torch/_dynamo/polyfills/builtins.py index d5ce75c562bfb6..d2f563ca6b883f 100644 --- a/torch/_dynamo/polyfills/builtins.py +++ b/torch/_dynamo/polyfills/builtins.py @@ -3,9 +3,7 @@ """ import builtins -import functools -import operator -from typing import Iterable, TypeVar +from typing import Iterable from ..decorators import substitute_in_graph @@ -13,13 +11,9 @@ __all__ = [ "all", "any", - "sum", ] -_T = TypeVar("_T") - - @substitute_in_graph(builtins.all, can_constant_fold_through=True) def all(iterable: Iterable[object], /) -> bool: for elem in iterable: @@ -34,8 +28,3 @@ def any(iterable: Iterable[object], /) -> bool: if elem: return True return False - - -@substitute_in_graph(builtins.sum, can_constant_fold_through=True) # type: ignore[arg-type] -def sum(iterable: Iterable[_T], /, start: _T = 0) -> _T: # type: ignore[assignment] - return functools.reduce(operator.add, iterable, start) diff --git a/torch/_dynamo/variables/builtin.py b/torch/_dynamo/variables/builtin.py index ca972b282e387b..b4fbbbc794df1d 100644 --- a/torch/_dynamo/variables/builtin.py +++ b/torch/_dynamo/variables/builtin.py @@ -1585,6 +1585,40 @@ def call_filter(self, tx: "InstructionTranslator", fn, seq): except NotImplementedError: return + def call_sum(self, tx: "InstructionTranslator", seq, start=_SENTINEL): + # Special case for sum on tuple of floats and ints + if isinstance(seq, (variables.ListVariable, variables.TupleVariable)) and all( + isinstance(x, variables.ConstantVariable) + and isinstance(x.value, (int, float)) + for x in seq.items + ): + if start is self._SENTINEL: + return variables.ConstantVariable.create( + sum(x.value for x in seq.items), + ) + if isinstance(start, variables.ConstantVariable) and isinstance( + start.value, (int, float) + ): + return variables.ConstantVariable.create( + sum((x.value for x in seq.items), start=start.value), + ) + if seq.has_unpack_var_sequence(tx): + if start is self._SENTINEL: + start = variables.ConstantVariable.create(0) + items = seq.unpack_var_sequence(tx) + + from .builder import SourcelessBuilder + + return SourcelessBuilder.create(tx, functools.reduce).call_function( + tx, + [ + BuiltinVariable(operator.add), + variables.TupleVariable(items), + start, + ], + {}, + ) + def call_getattr( self, tx: "InstructionTranslator", From 2b8871b8c661d9ad39d7dd8ec79a9e6739a45c49 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Fri, 30 Aug 2024 16:06:10 +0000 Subject: [PATCH 0281/1018] Revert "[dynamo] simplify implementation for `functools.reduce` (#133778)" This reverts commit b5f1ffa7ab0988184497788f2738e1769888ab7d. Reverted https://github.com/pytorch/pytorch/pull/133778 on behalf of https://github.com/ZainRizvi due to This is still failing internally with the same error about 'Graph break due to unsupported builtin _functools.reduce' ([comment](https://github.com/pytorch/pytorch/pull/133778#issuecomment-2321787968)) --- torch/_dynamo/polyfills/functools.py | 42 +--------------------------- torch/_dynamo/trace_rules.py | 1 + torch/_dynamo/variables/builtin.py | 18 +++++++++--- 3 files changed, 16 insertions(+), 45 deletions(-) diff --git a/torch/_dynamo/polyfills/functools.py b/torch/_dynamo/polyfills/functools.py index 329725b41e354a..98fd9597773b35 100644 --- a/torch/_dynamo/polyfills/functools.py +++ b/torch/_dynamo/polyfills/functools.py @@ -2,45 +2,5 @@ Python polyfills for functools """ -import functools -from typing import Callable, Iterable, TypeVar -from ..decorators import substitute_in_graph - - -__all__ = ["reduce"] - - -_T = TypeVar("_T") -_U = TypeVar("_U") - - -class _INITIAL_MISSING: - pass - - -# Reference: https://docs.python.org/3/library/functools.html#functools.reduce -@substitute_in_graph(functools.reduce) -def reduce( - function: Callable[[_U, _T], _U], - iterable: Iterable[_T], - initial: _U = _INITIAL_MISSING, # type: ignore[assignment] - /, -) -> _U: - it = iter(iterable) - - value: _U - if initial is _INITIAL_MISSING: - try: - value = next(it) # type: ignore[assignment] - except StopIteration: - raise TypeError( - "reduce() of empty iterable with no initial value", - ) from None - else: - value = initial - - for element in it: - value = function(value, element) - - return value +__all__ = [] # type: ignore[var-annotated] diff --git a/torch/_dynamo/trace_rules.py b/torch/_dynamo/trace_rules.py index f42c4653093cd7..12c035caf3cadb 100644 --- a/torch/_dynamo/trace_rules.py +++ b/torch/_dynamo/trace_rules.py @@ -2999,6 +2999,7 @@ def _builtin_function_ids() -> Dict[int, str]: rv.update( { id(cast): "typing.cast", + id(functools.reduce): "functools.reduce", id(copy.deepcopy): "copy.deepcopy", } ) diff --git a/torch/_dynamo/variables/builtin.py b/torch/_dynamo/variables/builtin.py index b4fbbbc794df1d..86900dee30bc56 100644 --- a/torch/_dynamo/variables/builtin.py +++ b/torch/_dynamo/variables/builtin.py @@ -1606,10 +1606,7 @@ def call_sum(self, tx: "InstructionTranslator", seq, start=_SENTINEL): if start is self._SENTINEL: start = variables.ConstantVariable.create(0) items = seq.unpack_var_sequence(tx) - - from .builder import SourcelessBuilder - - return SourcelessBuilder.create(tx, functools.reduce).call_function( + return BuiltinVariable(functools.reduce).call_function( tx, [ BuiltinVariable(operator.add), @@ -1619,6 +1616,19 @@ def call_sum(self, tx: "InstructionTranslator", seq, start=_SENTINEL): {}, ) + def call_reduce( + self, tx: "InstructionTranslator", function, iterable, initial=_SENTINEL + ): + if iterable.has_unpack_var_sequence(tx): + items = iterable.unpack_var_sequence(tx) + if initial is self._SENTINEL: + value, items = items[0], items[1:] + else: + value = initial + for element in items: + value = function.call_function(tx, [value, element], {}) + return value + def call_getattr( self, tx: "InstructionTranslator", From cf6ba952b933467089ba265bdca08592d477a70a Mon Sep 17 00:00:00 2001 From: Thomas Bohnstingl Date: Fri, 30 Aug 2024 16:09:52 +0000 Subject: [PATCH 0282/1018] Improvements for associative_scan - combine_mode (#133012) This is part of a series of PRs to improve the functionality of the `associatve_scan` functionality. This specific PR introduces a `combine_mode`, which can be either `pointwise` (default) or `generic`. In case of `generic`, the `associative_scan` is more flexible and allows also to perform non-pointwise functions. This PR has been derived from https://github.com/pytorch/pytorch/pull/129307. @ydwu4 @Chillee @zou3519 Pull Request resolved: https://github.com/pytorch/pytorch/pull/133012 Approved by: https://github.com/ydwu4 --- test/dynamo/test_repros.py | 102 +++++++++ test/functorch/test_control_flow.py | 225 ++++++++++++++++++-- test/inductor/test_control_flow.py | 65 ++++-- torch/_dynamo/variables/higher_order_ops.py | 25 ++- torch/_higher_order_ops/associative_scan.py | 209 ++++++++++++++---- 5 files changed, 555 insertions(+), 71 deletions(-) diff --git a/test/dynamo/test_repros.py b/test/dynamo/test_repros.py index 21f80c4a4fb4be..f2a96131952f6e 100644 --- a/test/dynamo/test_repros.py +++ b/test/dynamo/test_repros.py @@ -5752,6 +5752,108 @@ def fn(x): fn(torch.randn(4)) + @requires_cuda + # This test will fail as flip in combination with particular input lenghts + # produces weird results. + # This is under investigations in + # https://github.com/pytorch/pytorch/issues/131805 + @unittest.skip("Skip this flip test for the moment. It is under investigation") + def test_flip_bad_accuracy(self): + import torch + import torch._dynamo.config + import torch._functorch.config + import torch._inductor.config + import torch._inductor.inductor_prims + import torch.fx.experimental._config + + class Repro(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, arg0_1): + rev = torch.ops.prims.rev.default(arg0_1, [0]) + arg0_1 = None + slice_1 = torch.ops.aten.slice.Tensor(rev, 0, 0, -1, 2) + slice_2 = torch.ops.aten.slice.Tensor(rev, 0, 1, 9223372036854775807, 2) + add_1 = torch.ops.aten.add.Tensor(slice_1, slice_2) + slice_1 = slice_2 = None + slice_3 = torch.ops.aten.slice.Tensor(add_1, 0, 0, -1, 2) + slice_4 = torch.ops.aten.slice.Tensor( + add_1, 0, 1, 9223372036854775807, 2 + ) + add_2 = torch.ops.aten.add.Tensor(slice_3, slice_4) + slice_3 = slice_4 = None + slice_5 = torch.ops.aten.slice.Tensor(add_2, 0, 0, -1, 2) + slice_6 = torch.ops.aten.slice.Tensor( + add_2, 0, 1, 9223372036854775807, 2 + ) + add_3 = torch.ops.aten.add.Tensor(slice_5, slice_6) + slice_5 = slice_6 = None + slice_9 = torch.ops.aten.slice.Tensor(add_2, 0, 0, 1) + add_2 = None + unsqueeze = torch.ops.aten.unsqueeze.default(slice_9, 1) + slice_9 = None + unsqueeze_1 = torch.ops.aten.unsqueeze.default(add_3, 1) + add_3 = None + cat = torch.ops.aten.cat.default([unsqueeze, unsqueeze_1], 1) + unsqueeze = unsqueeze_1 = None + view = torch.ops.aten.view.default(cat, [2]) + cat = None + slice_10 = torch.ops.aten.slice.Tensor(view, 0, 0, -1) + slice_11 = torch.ops.aten.slice.Tensor( + add_1, 0, 2, 9223372036854775807, 2 + ) + add_5 = torch.ops.aten.add.Tensor(slice_10, slice_11) + slice_10 = slice_11 = None + slice_12 = torch.ops.aten.slice.Tensor(add_1, 0, 0, 1) + add_1 = None + cat_1 = torch.ops.aten.cat.default([slice_12, add_5]) + slice_12 = add_5 = None + unsqueeze_2 = torch.ops.aten.unsqueeze.default(cat_1, 1) + cat_1 = None + unsqueeze_3 = torch.ops.aten.unsqueeze.default(view, 1) + view = None + cat_2 = torch.ops.aten.cat.default([unsqueeze_2, unsqueeze_3], 1) + unsqueeze_2 = unsqueeze_3 = None + view_1 = torch.ops.aten.view.default(cat_2, [4]) + cat_2 = None + slice_13 = torch.ops.aten.slice.Tensor( + rev, 0, 2, 9223372036854775807, 2 + ) + add_6 = torch.ops.aten.add.Tensor(view_1, slice_13) + slice_13 = None + slice_14 = torch.ops.aten.slice.Tensor(rev, 0, 0, 1) + rev = None + cat_3 = torch.ops.aten.cat.default([slice_14, add_6]) + slice_14 = add_6 = None + constant_pad_nd = torch.ops.aten.constant_pad_nd.default( + view_1, [0, 1], 0.0 + ) + view_1 = None + unsqueeze_4 = torch.ops.aten.unsqueeze.default(cat_3, 1) + cat_3 = None + unsqueeze_5 = torch.ops.aten.unsqueeze.default(constant_pad_nd, 1) + constant_pad_nd = None + cat_4 = torch.ops.aten.cat.default([unsqueeze_4, unsqueeze_5], 1) + unsqueeze_4 = unsqueeze_5 = None + view_2 = torch.ops.aten.view.default(cat_4, [10]) + cat_4 = None + slice_15 = torch.ops.aten.slice.Tensor(view_2, 0, 0, 9) + view_2 = None + rev_1 = torch.ops.prims.rev.default(slice_15, [0]) + slice_15 = None + return (rev_1,) + + mod = Repro() + x = torch.arange(9, device=torch.device("cuda")) + + @torch.compile + def f(x): + return mod(x) + + out = f(x) + self.assertEqual(torch.flip(torch.cumsum(torch.flip(x, [0]), 0), [0]), out[0]) + instantiate_parametrized_tests(ReproTests) diff --git a/test/functorch/test_control_flow.py b/test/functorch/test_control_flow.py index e26c8a6f8220e2..ba57a91a9e34a4 100644 --- a/test/functorch/test_control_flow.py +++ b/test/functorch/test_control_flow.py @@ -19,6 +19,7 @@ from torch.testing._internal.common_cuda import SM70OrLater from torch.testing._internal.common_quantization import skipIfNoDynamoSupport from torch.testing._internal.common_utils import ( + decorateIf, instantiate_parametrized_tests, IS_WINDOWS, parametrize, @@ -1207,8 +1208,18 @@ def fwbw(map_op, f, x, y): @unittest.skipIf(not SM70OrLater, "triton") @unittest.skipIf(not torch.cuda.is_available(), "Test requires CUDA.") @parametrize("reverse", [False, True]) - @parametrize("device", [torch.device("cuda")]) - def test_pointwise_associative_scan_reverse_simple(self, reverse, device): + @parametrize("combine_mode", ["pointwise", "generic"]) + @parametrize("device", [torch.device("cpu"), torch.device("cuda")]) + # Skipping the combination of combine_mode=pointwise and device=cpu + # as the current implementation of pointwise does only support CUDA device + @decorateIf( + unittest.skip, + lambda params: ( + params["combine_mode"] == "pointwise" + and params["device"] == torch.device("cpu") + ), + ) + def test_pointwise_associative_scan_simple(self, reverse, combine_mode, device): def add(x: torch.Tensor, y: torch.Tensor): return x + y @@ -1216,17 +1227,19 @@ def mul(x: torch.Tensor, y: torch.Tensor): return x * y x = torch.randn(3, 10, 2, device=device) + for op, op_pt in [(add, torch.cumsum), (mul, torch.cumprod)]: - result = associative_scan(op, x, 0, reverse=reverse) + result = associative_scan( + op, x, 0, reverse=reverse, combine_mode=combine_mode + ) result_exp = _fake_associative_scan(op, x, 0, reverse=reverse) self.assertEqual(result, result_exp) - if not reverse: - result_exp_PT = op_pt(x, 0) - self.assertEqual(result, result_exp_PT) # Jax Examples x = torch.arange(0, 4, device=device) - cumsum1 = associative_scan(add, x, 0, reverse=reverse) + cumsum1 = associative_scan( + add, x, 0, reverse=reverse, combine_mode=combine_mode + ) cumsum_exp = _fake_associative_scan(add, x, 0, reverse=reverse) if not reverse: self.assertEqual( @@ -1241,8 +1254,18 @@ def mul(x: torch.Tensor, y: torch.Tensor): @unittest.skipIf(not SM70OrLater, "triton") @unittest.skipIf(not torch.cuda.is_available(), "Test requires CUDA.") @parametrize("reverse", [False, True]) - @parametrize("device", [torch.device("cuda")]) - def test_pointwise_associative_scan_reverse_dim(self, reverse, device): + @parametrize("combine_mode", ["pointwise", "generic"]) + @parametrize("device", [torch.device("cpu"), torch.device("cuda")]) + # Skipping the combination of combine_mode=pointwise and device=cpu + # as the current implementation of pointwise does only support CUDA device + @decorateIf( + unittest.skip, + lambda params: ( + params["combine_mode"] == "pointwise" + and params["device"] == torch.device("cpu") + ), + ) + def test_pointwise_associative_scan_dim(self, reverse, combine_mode, device): import random def add(x: torch.Tensor, y: torch.Tensor): @@ -1258,7 +1281,9 @@ def mul(x: torch.Tensor, y: torch.Tensor): x = torch.randn(*shapes, device=device) for op, op_pt in [(add, torch.cumsum), (mul, torch.cumprod)]: - result = associative_scan(op, x, rnd_scan_dim, reverse=reverse) + result = associative_scan( + op, x, rnd_scan_dim, reverse=reverse, combine_mode=combine_mode + ) result_exp = _fake_associative_scan( op, x, rnd_scan_dim, reverse=reverse ) @@ -1270,10 +1295,20 @@ def mul(x: torch.Tensor, y: torch.Tensor): @unittest.skipIf(not SM70OrLater, "triton") @unittest.skipIf(not torch.cuda.is_available(), "Test requires CUDA.") @parametrize("reverse", [False, True]) + @parametrize("combine_mode", ["pointwise", "generic"]) @parametrize("compile_mode", ["compile", "compile_dynamic_shape"]) - @parametrize("device", [torch.device("cuda")]) - def test_pointwise_associative_scan_reverse_compile( - self, reverse, compile_mode, device + @parametrize("device", [torch.device("cpu"), torch.device("cuda")]) + # Skipping the combination of combine_mode=pointwise and device=cpu + # as the current implementation of pointwise does only support CUDA device + @decorateIf( + unittest.skip, + lambda params: ( + params["combine_mode"] == "pointwise" + and params["device"] == torch.device("cpu") + ), + ) + def test_pointwise_associative_scan_compile( + self, reverse, combine_mode, compile_mode, device ): def add(x: torch.Tensor, y: torch.Tensor): return x + y @@ -1293,7 +1328,9 @@ def mul(x: torch.Tensor, y: torch.Tensor): ) for op, op_pt in [(add, torch.cumsum), (mul, torch.cumprod)]: - result = associative_scan_fct(op, x, 0, reverse=reverse) + result = associative_scan_fct( + op, x, 0, reverse=reverse, combine_mode=combine_mode + ) result_exp = _fake_associative_scan(op, x, 0, reverse=reverse) self.assertEqual(result, result_exp) if not reverse: @@ -1302,7 +1339,9 @@ def mul(x: torch.Tensor, y: torch.Tensor): # Jax Examples x = torch.arange(0, 4, device=device) - cumsum1 = associative_scan_fct(add, x, 0, reverse=reverse) + cumsum1 = associative_scan( + add, x, 0, reverse=reverse, combine_mode=combine_mode + ) cumsum_exp = _fake_associative_scan(add, x, 0, reverse=reverse) if not reverse: self.assertEqual( @@ -1314,6 +1353,162 @@ def mul(x: torch.Tensor, y: torch.Tensor): ) self.assertEqual(cumsum1, cumsum_exp) + @unittest.skipIf(not SM70OrLater, "triton") + @unittest.skipIf(not torch.cuda.is_available(), "Test requires CUDA.") + @parametrize("reverse", [False, True]) + @parametrize("combine_mode", ["pointwise", "generic"]) + @parametrize("device", [torch.device("cpu"), torch.device("cuda")]) + # Skipping the combination of combine_mode=pointwise and device=cpu + # as the current implementation of pointwise does only support CUDA device + @decorateIf( + unittest.skip, + lambda params: ( + params["combine_mode"] == "pointwise" + and params["device"] == torch.device("cpu") + ), + ) + def test_pointwise_associative_scan_binary_operator( + self, reverse, combine_mode, device + ): + def fct(x, y): + A_i, Bu_i = x + A_j, Bu_j = y + return A_j * A_i, A_j * Bu_i + Bu_j + + torch.compiler.reset() + associative_scan1 = torch.compile(associative_scan, fullgraph=True) + associative_scan2 = associative_scan + + state_dim = 20 + timesteps = 10 + projected_inputs = torch.randn( + timesteps, state_dim, requires_grad=True, device=device + ) + A = torch.randn(state_dim, requires_grad=True, device=device) + elements = (A.repeat((timesteps, 1)), projected_inputs) + + result1 = associative_scan1( + fct, elements, 0, combine_mode=combine_mode, reverse=reverse + ) + result2 = associative_scan2( + fct, elements, 0, combine_mode=combine_mode, reverse=reverse + ) + expected_result = _fake_associative_scan(fct, elements, 0, reverse=reverse) + self.assertEqual( + result1, + expected_result, + ) + self.assertEqual([r.device.type for r in result1], [device.type] * len(result1)) + self.assertEqual( + result2, + expected_result, + ) + self.assertEqual([r.device.type for r in result2], [device.type] * len(result2)) + + @unittest.skipIf(not SM70OrLater, "triton") + @unittest.skipIf(not torch.cuda.is_available(), "Test requires CUDA.") + @parametrize("reverse", [False, True]) + @parametrize("combine_mode", ["pointwise", "generic"]) + @parametrize("device", [torch.device("cpu"), torch.device("cuda")]) + # Skipping the combination of combine_mode=pointwise and device=cpu + # as the current implementation of pointwise does only support CUDA device + @decorateIf( + unittest.skip, + lambda params: ( + params["combine_mode"] == "pointwise" + and params["device"] == torch.device("cpu") + ), + ) + def test_pointwise_associative_scan_tuple(self, reverse, combine_mode, device): + def fct(x, y): + return (x[0] + y[0], x[1] * y[1]) + + x = torch.randn(3, 2, 2, device=device, requires_grad=True) + y = torch.randn(3, 2, 2, device=device, requires_grad=True) + inp = (x, y) + + result1 = associative_scan( + fct, inp, 0, reverse=reverse, combine_mode=combine_mode + ) + expected_result = _fake_associative_scan(fct, inp, 0, reverse=reverse) + self.assertEqual(result1, expected_result) + + @unittest.skipIf(not SM70OrLater, "triton") + @unittest.skipIf(not torch.cuda.is_available(), "Test requires CUDA.") + @parametrize("reverse", [False, True]) + @parametrize("combine_mode", ["pointwise", "generic"]) + @parametrize("device", [torch.device("cpu"), torch.device("cuda")]) + # Skipping the combination of combine_mode=pointwise and device=cpu + # as the current implementation of pointwise does only support CUDA device + @decorateIf( + unittest.skip, + lambda params: ( + params["combine_mode"] == "pointwise" + and params["device"] == torch.device("cpu") + ), + ) + def test_pointwise_associative_scan_complex_pytree( + self, reverse, combine_mode, device + ): + def fct_wrong_pytree(x, y): + return { + "i": x["i"] * y["j"][0][0], + "k": 0.0, + "j": ([x["j"][1][0]["o"]], [{"o": torch.sin(x["i"])}]), + } + + def fct_pointwise(x, y): + return { + "i": x["i"] * y["i"], + "j": ( + [x["j"][0][0] * y["j"][0][0]], + [{"o": x["j"][1][0]["o"] + y["j"][1][0]["o"]}], + ), + } + + x = torch.randn(3, 2, 2, device=device, requires_grad=True) + y = torch.randn(3, 2, 2, device=device, requires_grad=True) + z = torch.randn(3, 2, 2, device=device, requires_grad=True) + inp = {"i": x, "j": ([y], [{"o": z}])} + + with self.assertRaisesRegex(Exception, r"."): + result = associative_scan(fct_wrong_pytree, inp, 0, combine_mode="generic") + + torch.compiler.reset() + associative_scan1 = torch.compile(associative_scan, fullgraph=True) + associative_scan2 = associative_scan + + result1 = associative_scan1( + fct_pointwise, inp, 0, combine_mode=combine_mode, reverse=reverse + ) + result2 = associative_scan2( + fct_pointwise, inp, 0, combine_mode=combine_mode, reverse=reverse + ) + expected_result = _fake_associative_scan(fct_pointwise, inp, 0, reverse=reverse) + self.assertEqual(result1, expected_result) + self.assertEqual(result2, expected_result) + + @unittest.skipIf(not SM70OrLater, "triton") + @unittest.skipIf(not torch.cuda.is_available(), "Test requires CUDA.") + @parametrize("reverse", [False, True]) + @parametrize("device", [torch.device("cpu"), torch.device("cuda")]) + def test_generic_associative_scan_generic_simple(self, reverse, device): + def non_pointwise(x: torch.Tensor, y: torch.Tensor): + W = torch.diag(torch.ones(2, device=device)) + return x @ W + y @ W + + x = torch.randn(3, 10, 2, device=device) + with self.assertRaisesRegex(Exception, ".*"): + out = associative_scan( + non_pointwise, x, 0, reverse=reverse, combine_mode="pointwise" + ) + + result1 = associative_scan( + non_pointwise, x, 0, reverse=reverse, combine_mode="generic" + ) + result_expected = _fake_associative_scan(non_pointwise, x, 0, reverse=reverse) + self.assertEqual(result1, result_expected) + @unittest.skipIf(IS_WINDOWS, "Windows not supported for this test") @skipIfNoDynamoSupport diff --git a/test/inductor/test_control_flow.py b/test/inductor/test_control_flow.py index ced9637bbb5310..d4580a9956e8e7 100644 --- a/test/inductor/test_control_flow.py +++ b/test/inductor/test_control_flow.py @@ -1,11 +1,13 @@ # Owner(s): ["module: inductor"] import itertools +import unittest import torch import torch._dynamo.testing from torch._higher_order_ops.associative_scan import associative_scan from torch._inductor.test_case import TestCase from torch.testing._internal.common_utils import ( + decorateIf, instantiate_parametrized_tests, parametrize, ) @@ -698,9 +700,15 @@ def test_while_loop_with_outer_buffers(self, device, dynamic): class AssociativeScanTests(TestCase): @requires_gpu - @parametrize("device", [GPU_TYPE]) + @parametrize("combine_mode", ["pointwise", "generic"]) @parametrize("backend", ["inductor"]) - def test_pointwise_associative_scan_CUDA_flip(self, device, backend): + @parametrize("device", [torch.device("cpu"), GPU_TYPE]) + # This test will fail as flip in combination with particular input lenghts + # produces weird results. + # This is under investigations in + # https://github.com/pytorch/pytorch/issues/131805 + @decorateIf(unittest.skip, lambda params: params["device"] == GPU_TYPE) + def test_associative_scan_CUDA_flip(self, combine_mode, backend, device): def fct(x: torch.Tensor, y: torch.Tensor): return x + y @@ -712,17 +720,35 @@ def fct(x: torch.Tensor, y: torch.Tensor): ) associative_scan2 = associative_scan - result1 = associative_scan1(fct, x, 0, reverse=False) - result2 = associative_scan2(fct, x, 0, reverse=False) + if combine_mode == "pointwise" and device == torch.device("cpu"): + with self.assertRaisesRegex(Exception, r"."): + associative_scan1( + fct, x, 0, reverse=False, combine_mode=combine_mode + ) + + # Skipping test because combine_mode currently only suppors CUDA tensors + return + + result1 = associative_scan1( + fct, x, 0, reverse=False, combine_mode=combine_mode + ) + result2 = associative_scan2( + fct, x, 0, reverse=False, combine_mode=combine_mode + ) result3 = torch.cumsum(x, 0) self.assertEqual(result1, result2) self.assertEqual(result1, result3) # Flip only non-compiled and compare with compiled reverse=True - result1 = associative_scan1(fct, x, 0, reverse=True) + result1 = associative_scan1( + fct, x, 0, reverse=True, combine_mode=combine_mode + ) result2 = torch.flip( - associative_scan2(fct, torch.flip(x, [0]), 0, reverse=False), [0] + associative_scan2( + fct, torch.flip(x, [0]), 0, reverse=False, combine_mode=combine_mode + ), + [0], ) result3 = torch.flip(torch.cumsum(torch.flip(x, [0]), 0), [0]) @@ -731,9 +757,14 @@ def fct(x: torch.Tensor, y: torch.Tensor): # Flip only compiled and compare with non-compiled reverse=True result1 = torch.flip( - associative_scan1(fct, torch.flip(x, [0]), 0, reverse=False), [0] + associative_scan1( + fct, torch.flip(x, [0]), 0, reverse=False, combine_mode=combine_mode + ), + [0], + ) + result2 = associative_scan2( + fct, x, 0, reverse=True, combine_mode=combine_mode ) - result2 = associative_scan2(fct, x, 0, reverse=True) result3 = torch.flip(torch.cumsum(torch.flip(x, [0]), 0), [0]) self.assertEqual(result1, result2) @@ -741,10 +772,16 @@ def fct(x: torch.Tensor, y: torch.Tensor): # Use reverse=False, but flip both results before and after result1 = torch.flip( - associative_scan1(fct, torch.flip(x, [0]), 0, reverse=False), [0] + associative_scan1( + fct, torch.flip(x, [0]), 0, reverse=False, combine_mode=combine_mode + ), + [0], ) result2 = torch.flip( - associative_scan2(fct, torch.flip(x, [0]), 0, reverse=False), [0] + associative_scan2( + fct, torch.flip(x, [0]), 0, reverse=False, combine_mode=combine_mode + ), + [0], ) result3 = torch.flip(torch.cumsum(torch.flip(x, [0]), 0), [0]) @@ -752,8 +789,12 @@ def fct(x: torch.Tensor, y: torch.Tensor): self.assertEqual(result1, result3) # Reverse=True - result1 = associative_scan1(fct, x, 0, reverse=True) - result2 = associative_scan2(fct, x, 0, reverse=True) + result1 = associative_scan1( + fct, x, 0, reverse=True, combine_mode=combine_mode + ) + result2 = associative_scan2( + fct, x, 0, reverse=True, combine_mode=combine_mode + ) result3 = torch.flip(torch.cumsum(torch.flip(x, [0]), 0), [0]) self.assertEqual(result1, result2) diff --git a/torch/_dynamo/variables/higher_order_ops.py b/torch/_dynamo/variables/higher_order_ops.py index 37de70a2658f2a..92579cf4fa16cd 100644 --- a/torch/_dynamo/variables/higher_order_ops.py +++ b/torch/_dynamo/variables/higher_order_ops.py @@ -1033,9 +1033,25 @@ def arg_extractor(combine_fn, input, dim): # Trace the subgraph # TODO: Fix these pointless new_empty calls appearing in the dynamo output graph. - null_shape = SourcelessBuilder.create(tx, ()) sub_args = [ - leaf.call_method(tx, "new_empty", args=(null_shape,), kwargs={}) + leaf.call_method( + tx, + "new_empty", + args=( + SourcelessBuilder.create( + tx, + leaf.size + if leaf.size is not None + else BuiltinVariable(getattr) + .call_function(tx, [leaf, ConstantVariable.create("shape")], {}) + .items, + ), + ), + kwargs={ + "dtype": SourcelessBuilder.create(tx, leaf.dtype), + "requires_grad": SourcelessBuilder.create(tx, leaf.requires_grad), + }, + ) for leaf in itertools.chain(input.items, input.items) ] ( @@ -1078,11 +1094,6 @@ def arg_extractor(combine_fn, input, dim): + f"got {combine_result_meta.dtype}" ) - if combine_result_meta.shape != (): - unimplemented( - f"Expected combine_fn to return a tensor with shape () but got {combine_result_meta.shape}" - ) - combine_gm = torch.fx.GraphModule(dict(tx.output.nn_modules), combine_graph) combine_fn_name = add_subgraph(tx, "scan_combine", combine_gm) diff --git a/torch/_higher_order_ops/associative_scan.py b/torch/_higher_order_ops/associative_scan.py index ef57af9c0d63b4..3616574efe77d6 100644 --- a/torch/_higher_order_ops/associative_scan.py +++ b/torch/_higher_order_ops/associative_scan.py @@ -8,13 +8,13 @@ import torch._subclasses.functional_tensor import torch.utils._pytree as pytree from torch._C import DispatchKey -from torch._C._functorch import _add_batch_dim, get_unwrapped, maybe_get_bdim from torch._higher_order_ops.utils import ( _set_compilation_env, autograd_not_implemented, reenter_make_fx, unique_graph_id, ) +from torch._inductor.utils import is_pointwise_use from torch._ops import HigherOrderOperator from torch._subclasses.fake_tensor import FakeTensorMode from torch.fx.experimental.proxy_tensor import ( @@ -37,6 +37,37 @@ def wrap_combine_fn_flat(*args, combine_fn, spec, num_leaves): return combined_leaves +def _interleave(a, b, dim): + # https://stackoverflow.com/questions/60869537/how-can-i-interleave-5-pytorch-tensors + if b_trunc := (a.shape[dim] == b.shape[dim] + 1): + pad = ( + [0] * ((b.ndim - dim - 1) * 2 + 1) + + [1] + + [0] * (b.ndim * 2 - ((b.ndim - dim - 1) * 2 + 2)) + ) + b = torch.nn.functional.pad(b, pad) + + stacked = torch.stack([a, b], dim=dim + 1) + interleaved = torch.flatten(stacked, start_dim=dim, end_dim=dim + 1) + if b_trunc: + # TODO: find torch alternative for slice_along dim for torch.jit.script to work + interleaved = aten.slice(interleaved, dim, 0, b.shape[dim] + a.shape[dim] - 1) + return interleaved + + +def safe_map(f, *args): + args = list(map(list, args)) + n = len(args[0]) + for arg in args[1:]: + if len(arg) != n: + raise ValueError("length mismatch: {list(map(len, args))}") + + def nf(a): + return f(*a) + + return list(map(nf, zip(*args))) + + class AssociativeScanOp(HigherOrderOperator): def __init__(self): super().__init__("associative_scan") @@ -53,6 +84,7 @@ def associative_scan( input: pytree.PyTree, dim: int, reverse: bool = False, + combine_mode: str = "pointwise", ) -> torch.Tensor: r""" Performs an inclusive scan with an associative pointwise combine function. @@ -74,6 +106,11 @@ def associative_scan( All inputs are expected to have the same shape. dim (int): the dimension to scan over reverse (bool): A boolean stating if the scan should be reversed with respect to the dimension. + combine_mode (str): A string indicating whether the ``combine_fn`` is ``pointwise`` or ``generic``. + If ``combine_mode=pointwise``, ``combine_fn`` must be pure, may only contain pointwise operations + and ``input`` must be CUDA tensors. + In all other cases ``combine_mode=generic`` should be used. + Note: ``combine_mode=pointwise`` is more efficient than ``combine_mode=generic``. Example:: @@ -86,22 +123,29 @@ def add(x: torch.Tensor, y: torch.Tensor): """ assert callable(combine_fn), "combine_fn must be a callable, but got {combine_fn}" assert isinstance(dim, int), "dim must be an int, but got {type(dim)}" + assert combine_mode in ["pointwise", "generic"] if not torch._dynamo.is_compiling(): with _set_compilation_env(), torch._dynamo.utils.disable_cache_limit(): return torch.compile(associative_scan, fullgraph=True)( - combine_fn, input, dim, reverse=reverse + combine_fn, input, dim, reverse=reverse, combine_mode=combine_mode ) leaves, spec = pytree.tree_flatten(input) - if reverse: - leaves = [torch.flip(elem, [dim]) for elem in leaves] + if combine_mode == "pointwise" and not all(l.device.type == "cuda" for l in leaves): + raise ValueError( + "For combine_mode='pointwise', all input tensors need to be on CUDA" + ) assert len(leaves) >= 1, "expected at least 1 input leaf" assert all( isinstance(x, torch.Tensor) for x in leaves ), "input leaves must be a Tensor" + + if reverse: + leaves = [torch.flip(elem, [dim]) for elem in leaves] + shape = leaves[0].shape ndim = len(shape) dim = utils.canonicalize_dim(ndim, dim) @@ -109,11 +153,27 @@ def add(x: torch.Tensor, y: torch.Tensor): for x in leaves[1:]: assert x.shape == shape, "All input tensors must have the same shape" + out = combine_fn( + pytree.tree_unflatten(leaves, spec), + pytree.tree_unflatten(leaves, spec), + ) + out_leaves, tree_out = pytree.tree_flatten(out) + assert len(leaves) == len( + out_leaves + ), "The pytree of the output of the operator needs to match the input pytree" + for x in out_leaves: + assert ( + x.shape == shape + ), "The pytree of the output of the operator needs to match the input pytree" + combine_fn = functools.partial( wrap_combine_fn_flat, combine_fn=combine_fn, spec=spec, num_leaves=len(leaves) ) - result_flat = associative_scan_op(combine_fn, leaves, dim) + if combine_mode == "generic": + result_flat = generic_associative_scan(combine_fn, leaves, dim) + else: + result_flat = associative_scan_op(combine_fn, leaves, dim) if reverse: result_flat = [torch.flip(elem, [dim]) for elem in result_flat] @@ -121,12 +181,113 @@ def add(x: torch.Tensor, y: torch.Tensor): return pytree.tree_unflatten(result_flat, spec) +def generic_associative_scan(operator, elems_flat, dim=0): + r""" + This function performs the associative_scan operation. + The algorithm works by recursively collecting neighbours of ``elems_flat`` and subsequently + applying the ``operator`` on all pairs in parallel along ``dim``. + The results of the recursive calls are later combined. + + Args: + operator (Callable): A binary callable with type ``(Tensor, Tensor) -> Tensor``, + or if input is a pytree ``(pytree, pytree) -> pytree``. + This function must be pure, pointwise, and satisfy the associative property. + elems_flat (torch.Tensor): A list of torch.Tensors converted from the pytree of + ``input`` provided to ``associative_scan``. + All inputs are expected to have the same shape. + dim (int): the dimension to scan over + + + Example:: + + def add(x: torch.Tensor, y: torch.Tensor): + return x + y + + elems_flat = torch.tensor([0.0, 1.0, 2.0, 3.0]) + + First iteration of _scan -> + # odd_elems -> apply operator on all neighbours + # odd_elems = operator([torch.tensor([0.0, 2.0])], + # [torch.tensor([1.0, 3.0])]) + odd_elems = torch.tensor([1.0, 5.0]) + Second iteration of _scan -> + # odd_elems = operator([torch.tensor([1.0])], + # [torch.tensor([5.0])]) + odd_elems = torch.tensor([6.0]) + # even_elems -> apply operator on all odd_elems and + # every second element of ``elems``, starting from the second element. + # even_elems is expanded with the first element of ``elems`` + even_elems = [1.0] + # Merges odd_elems and even_elems + res = torch.tensor([1.0, 6.0]) + # even_elems -> apply operator on all odd_elems and + # every second element of ``elems``, starting from the second element. + # even_elems is expanded with the first element of ``elems`` + even_elems = [0.0, 3.0] + # Merges odd_elems and even_elems + res = torch.tensor([0.0, 1.0, 3.0, 6.0]) + + """ + + def _scan(elems): + """Perform the actual recursive scan on ``elems``.""" + num_elems = elems[0].shape[dim] + + if num_elems < 2: + return elems + + reduced_elems = operator( + *[aten.slice(elem, dim, 0, -1, 2) for elem in elems], + *[aten.slice(elem, dim, 1, None, 2) for elem in elems], + ) + + # Recursively compute scan for partially reduced tensors. + odd_elems = _scan(reduced_elems) + + if num_elems % 2 == 0: + even_elems = operator( + *[aten.slice(e, dim, 0, -1) for e in odd_elems], + *[aten.slice(e, dim, 2, None, 2) for e in elems], + ) + else: + even_elems = operator( + *odd_elems, + *[aten.slice(e, dim, 2, None, 2) for e in elems], + ) + + # The first element of a scan is the same as the first element + # of the original `elems`. + even_elems = [ + torch.cat([aten.slice(elem, dim, 0, 1), result], dim=dim) + if result.shape.numel() > 0 and elem.shape[dim] > 0 + else result + if result.shape.numel() > 0 + else aten.slice( + elem, dim, 0, 1 + ) # Jax allows/ignores concat with 0-dim, Pytorch does not + for (elem, result) in zip(elems, even_elems) + ] + + return list( + safe_map(functools.partial(_interleave, dim=dim), even_elems, odd_elems) + ) + + scans = _scan(elems_flat) + + return scans + + def trace_associative_scan( proxy_mode, func_overload, combine_fn: Callable, input: List[torch.Tensor], dim: int ): with disable_proxy_modes_tracing(): sample_inputs = [ - torch.full((), False, dtype=x.dtype, device=x.device) + torch.empty_like( + x, + dtype=x.dtype, + device=x.device, + requires_grad=x.requires_grad, + ) for x in itertools.chain(input, input) ] combine_graph = reenter_make_fx(combine_fn)(*sample_inputs) @@ -138,6 +299,11 @@ def trace_associative_scan( assert len(node.args) == 1 outputs = node.args[0] + if not all(is_pointwise_use(use) or use.op == "output" for use in node.users): + raise ValueError( + "For combine_mode='pointwise', the combine_fn needs to be pointwise" + ) + assert outputs is not None assert len(outputs) == len( input @@ -149,12 +315,6 @@ def trace_associative_scan( f"combine_fn output type mismatch, expected {i.dtype} " + f"but got {o_meta.dtype}" ) - assert ( - o_meta.shape == () - ), f"combine_fn must return a scalar tensor but got shape {o_meta.shape}" - assert ( - o_meta.shape == () - ), f"combine_fn must return a scalar tensor but got shape {o_meta.shape}" _, combine_graph_name = unique_graph_id(proxy_mode, prefix="scan_combine_graph") @@ -200,28 +360,3 @@ def associative_scan_functionalize(ctx, combine_fn, input, dim): functional_combine_fn = ctx.functionalize(combine_fn) ret = associative_scan_op(functional_combine_fn, unwrapped_input, dim) return ctx.wrap_tensors(ret) - - -@associative_scan_op.py_impl(torch._C._functorch.TransformType.Vmap) -def associative_scan_batch_rule(interpreter, input, dim, combine_fn): - input_ = [get_unwrapped(x) for x in input] - input_bdims = [maybe_get_bdim(x) for x in input] - - batch_size = None - for inp, bdim in zip(input, input_bdims): - if bdim is not None: - batch_size = get_unwrapped(inp).shape[bdim] - - assert batch_size - input_unwrapped = [] - for x, bdim in zip(input, input_bdims): - unwrap = get_unwrapped(x) - if dim is None: - unwrap = unwrap.unsqueeze(0).expand(batch_size, *x.shape) - else: - unwrap = unwrap.movedim(bdim, 0) - input_unwrapped.append(unwrap) - - res = associative_scan_op(combine_fn, input_unwrapped, dim + 1) - lvl = interpreter.level() - return [_add_batch_dim(x, 0, lvl) for x in res] From f59dbb5345d0068439d81edb8dfb1f8da92dd2e2 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Fri, 30 Aug 2024 16:27:40 +0000 Subject: [PATCH 0283/1018] Revert "[c10d] Remove Option for ProcessGroup and Expose backend Options to reflect the correct code structure (#132931)" This reverts commit 65864d01341d006955579b145f78547314ceb14b. Reverted https://github.com/pytorch/pytorch/pull/132931 on behalf of https://github.com/ZainRizvi due to This PR is breaking builds internally due to the removal of ProcessGroup::Options ([comment](https://github.com/pytorch/pytorch/pull/132931#issuecomment-2321862402)) --- test/distributed/test_c10d_common.py | 8 +- test/distributed/test_device_mesh.py | 3 +- torch/_C/_distributed_c10d.pyi | 15 +- torch/csrc/distributed/c10d/ProcessGroup.cpp | 22 +- torch/csrc/distributed/c10d/ProcessGroup.hpp | 5324 +++++++++++++++--- torch/csrc/distributed/c10d/init.cpp | 60 +- torch/distributed/device_mesh.py | 5 +- torch/distributed/distributed_c10d.py | 65 +- 8 files changed, 4715 insertions(+), 787 deletions(-) diff --git a/test/distributed/test_c10d_common.py b/test/distributed/test_c10d_common.py index d96abb1ca82675..3e5538d57e38ae 100644 --- a/test/distributed/test_c10d_common.py +++ b/test/distributed/test_c10d_common.py @@ -1820,7 +1820,6 @@ def test_init_process_group_optional_backend(self): def test_init_process_group_for_all_backends(self): for backend in dist.Backend.backend_list: - excepted_backend = backend # skip if the backend is not available on the system if backend == dist.Backend.UNDEFINED: continue @@ -1836,11 +1835,6 @@ def test_init_process_group_for_all_backends(self): elif backend == dist.Backend.UCC: if not dist.is_ucc_available(): continue - # Multi-threaded PG is defined as a pure python class. - # Its pg.name() does not going through Pybind, so its backend name - # is still "threaded" instead of "custom". - elif backend != "threaded": - excepted_backend = "custom" with tempfile.NamedTemporaryFile(delete=False) as f: store = dist.FileStore(f.name, self.world_size) @@ -1853,7 +1847,7 @@ def test_init_process_group_for_all_backends(self): pg = c10d._get_default_group() self.assertEqual(pg.rank(), self.rank) self.assertEqual(pg.size(), self.world_size) - self.assertEqual(pg.name(), str(excepted_backend)) + self.assertEqual(pg.name(), str(backend)) dist.destroy_process_group() diff --git a/test/distributed/test_device_mesh.py b/test/distributed/test_device_mesh.py index 7107b104416166..b1c5b4d5883de6 100644 --- a/test/distributed/test_device_mesh.py +++ b/test/distributed/test_device_mesh.py @@ -236,8 +236,7 @@ def test_set_mesh_dim_group_options(self): mesh_tensor = torch.arange(4).reshape(2, 2) mesh = DeviceMesh(device_type, mesh_tensor) - # Fake pg only have BackendType as BackendType::CUSTOM. - self.assertEqual(mesh.get_group(1)._get_backend_name(), "custom") + self.assertEqual(mesh.get_group(1)._get_backend_name(), "fake") class DeviceMeshTestNDim(DTensorTestBase): diff --git a/torch/_C/_distributed_c10d.pyi b/torch/_C/_distributed_c10d.pyi index f89f0b50c8582e..6033d969925972 100644 --- a/torch/_C/_distributed_c10d.pyi +++ b/torch/_C/_distributed_c10d.pyi @@ -296,6 +296,15 @@ class Backend: def _set_default_timeout(self, timeout: timedelta) -> None: ... class ProcessGroup: + class Options: + def __init__(self, backend: str, timeout: timedelta = ...) -> None: ... + @property + def backend(self) -> str: ... + @property + def _timeout(self) -> timedelta: ... + @_timeout.setter + def _timeout(self, val: timedelta) -> None: ... + class BackendType(Enum): UNDEFINED = ... GLOO = ... @@ -310,6 +319,7 @@ class ProcessGroup: store: Store, rank: int, size: int, + options: Options, ) -> None: ... def rank(self) -> int: ... def size(self) -> int: ... @@ -499,7 +509,6 @@ class ProcessGroup: @property def _device_types(self) -> list[torch.device]: ... def _get_backend(self, device: torch.device) -> Backend: ... - def _set_default_backend(self, backend_type: BackendType) -> None: ... def _register_backend( self, device: torch.device, @@ -524,7 +533,7 @@ class ProcessGroup: class ProcessGroupGloo(Backend): class Device: ... - class Options(Backend.Options): + class Options(ProcessGroup.Options): devices: list[ProcessGroupGloo.Device] threads: int @@ -554,7 +563,7 @@ class ProcessGroupNCCL(Backend): min_ctas: int max_ctas: int - class Options(Backend.Options): + class Options(ProcessGroup.Options): config: ProcessGroupNCCL.NCCLConfig is_high_priority_stream: bool split_from: ProcessGroupNCCL diff --git a/torch/csrc/distributed/c10d/ProcessGroup.cpp b/torch/csrc/distributed/c10d/ProcessGroup.cpp index f565de2013260c..75635bc68aed4f 100644 --- a/torch/csrc/distributed/c10d/ProcessGroup.cpp +++ b/torch/csrc/distributed/c10d/ProcessGroup.cpp @@ -14,6 +14,22 @@ namespace c10d { +static ProcessGroup::BackendType strToBackendType(std::string_view backend) { + if (backend == "undefined") { + return ProcessGroup::BackendType::UNDEFINED; + } else if (backend == "gloo") { + return ProcessGroup::BackendType::GLOO; + } else if (backend == "nccl") { + return ProcessGroup::BackendType::NCCL; + } else if (backend == "ucc") { + return ProcessGroup::BackendType::UCC; + } else if (backend == "mpi") { + return ProcessGroup::BackendType::MPI; + } else { + return ProcessGroup::BackendType::CUSTOM; + } +} + std::string opTypeToString(OpType opType) { switch (opType) { case OpType::BROADCAST: @@ -103,11 +119,13 @@ c10::intrusive_ptr ProcessGroup::getBackend( ProcessGroup::ProcessGroup( const c10::intrusive_ptr<::c10d::Store>& store, int rank, - int size) + int size, + c10::intrusive_ptr options) : store_(store), rank_(rank), size_(size), - backendType_(BackendType::UNDEFINED), + options_(std::move(options)), + backendType_(strToBackendType(options_->backend)), dist_debug_level_(debug_level()) { C10_LOG_API_USAGE_ONCE("c10d.process_group"); } diff --git a/torch/csrc/distributed/c10d/ProcessGroup.hpp b/torch/csrc/distributed/c10d/ProcessGroup.hpp index 83d2729fc43d43..d8c4409c28d4b6 100644 --- a/torch/csrc/distributed/c10d/ProcessGroup.hpp +++ b/torch/csrc/distributed/c10d/ProcessGroup.hpp @@ -1,773 +1,4701 @@ -#pragma once +#ifdef USE_C10D_NCCL -#include -#include -#include +#include +#include +#include +#include +#include +#include +#include +#include #include -#include -#include -#include -#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include -#include -// ************************************************************************* -// PROCESS GROUP collective communication API IS BEING CHANGED BETWEEN -// versions 1.7 and 1.8. -// PLEASE DO NOT ADD ANY DEPENDENCIES. -// SEE RFC: https://github.com/pytorch/pytorch/issues/39662 -// ************************************************************************* +namespace c10d { + +constexpr const char* const kNCCLAbortedCommStoreKey = "NCCLABORTEDCOMM"; + +namespace { + +#if defined(NCCL_MAJOR) && \ + ((NCCL_MAJOR > 2) || (NCCL_MAJOR == 2) && (NCCL_MINOR >= 10)) +#define NCCL_HAS_AVG 1 +#endif + +// NCCL op mapping +const std::map ncclOp = { + {ReduceOp::MIN, ncclMin}, + {ReduceOp::MAX, ncclMax}, + {ReduceOp::SUM, ncclSum}, + {ReduceOp::PRODUCT, ncclProd}, +#ifdef NCCL_HAS_AVG + {ReduceOp::AVG, ncclAvg}, +#endif +}; + +// NCCL type typing +std::map ncclDataType = { + {at::kChar, ncclInt8}, + {at::kByte, ncclUint8}, + {at::kFloat, ncclFloat}, + {at::kDouble, ncclDouble}, + {at::kInt, ncclInt32}, + {at::kLong, ncclInt64}, + {at::kHalf, ncclHalf}, + {at::kBool, ncclUint8}, + {at::kFloat8_e5m2, ncclUint8}, + {at::kFloat8_e4m3fn, ncclUint8}, + {at::kFloat8_e4m3fnuz, ncclUint8}, + {at::kFloat8_e5m2fnuz, ncclUint8}, +#if HAS_NCCL_BF16_DATATYPE + {at::kBFloat16, ncclBfloat16}, +#endif +}; + +// Helper function that gets the data type and issues error if not supported +ncclDataType_t getNcclDataType(at::ScalarType type) { + auto it = ncclDataType.find(type); + TORCH_CHECK_WITH( + TypeError, + it != ncclDataType.end(), + "Input tensor data type is not supported for NCCL process group: ", + type); + return it->second; +} + +bool complexViewAsRealAllowed(const ReduceOp reduceOp) { + switch (reduceOp) { + case ReduceOp::SUM: + return true; + case ReduceOp::AVG: + return true; + case ReduceOp::PREMUL_SUM: + return true; + case ReduceOp::UNUSED: + return true; + default: + return false; + } + return false; +} + +#ifdef ENABLE_NCCL_PREMUL_SUM_SUPPORT +template +ncclRedOpRAII unpackPreMulSum( + const ReduceOp& reduceOp, + const ncclComm_t& comm) { + const auto* preMulSupplement = + reinterpret_cast(reduceOp.supplement_.get()); + ncclRedOp_t preMulSum; + bool has_tensor = preMulSupplement->tensor_factor.defined(); + auto residence = has_tensor ? ncclScalarDevice : ncclScalarHostImmediate; + const T* ptr_factor = has_tensor + ? preMulSupplement->tensor_factor.const_data_ptr() + : nullptr; + T scalar_factor = T(preMulSupplement->double_factor); + ncclRedOpCreatePreMulSum( + &preMulSum, + // https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/api/ops.html#ncclredopcreatepremulsum + // tells us that the scalar input is strictly a multiplier. + /*scalar=*/has_tensor ? const_cast(ptr_factor) : &scalar_factor, + dataType, + residence, + comm); + return ncclRedOpRAII(preMulSum, comm); +} +#endif + +ncclRedOpRAII getNcclReduceOp( + const ReduceOp& reduceOp, + at::Tensor& input, + const ncclDataType_t& dataType, + const ncclComm_t& comm) { + try { + if (input.scalar_type() == at::kBool) { + if (reduceOp == ReduceOp::SUM) { + // For bool tensors, map sum to max, which both represent a bitwise or. + // This is to prevent overflow issues with sum, since we use uint8 to + // represent a bool (see ncclDataType mapping). + return ncclMax; + } +#ifdef NCCL_HAS_AVG + if (reduceOp == ReduceOp::AVG) { + C10_THROW_ERROR( + TypeError, "Cannot use ReduceOp.AVG with boolean inputs"); + } +#endif + } + if (reduceOp == ReduceOp::PREMUL_SUM) { +#ifdef ENABLE_NCCL_PREMUL_SUM_SUPPORT + switch (dataType) { + case ncclHalf: + return unpackPreMulSum(reduceOp, comm); + case ncclFloat: + return unpackPreMulSum(reduceOp, comm); + case ncclDouble: + return unpackPreMulSum(reduceOp, comm); + default: + C10_THROW_ERROR( + TypeError, "PreMulSum Data type must be half, float, or double"); + ncclRedOp_t unused; + return unused; + } +#else + C10_THROW_ERROR(ValueError, "PreMulSum requires NCCL>=2.11.1"); +#endif + } + return ncclOp.at(reduceOp); + } catch (const std::out_of_range&) { + switch (reduceOp) { + case ReduceOp::AVG: + C10_THROW_ERROR( + ValueError, + c10::str( + "AVG requires NCCL 2.10+. The current version is ", + NCCL_MAJOR, + ".", + NCCL_MINOR)); + break; + case ReduceOp::BAND: + C10_THROW_ERROR(ValueError, "Cannot use ReduceOp.BAND with NCCL"); + break; + case ReduceOp::BOR: + C10_THROW_ERROR(ValueError, "Cannot use ReduceOp.BOR with NCCL"); + break; + case ReduceOp::BXOR: + C10_THROW_ERROR(ValueError, "Cannot use ReduceOp.BXOR with NCCL"); + break; + default: + C10_THROW_ERROR(ValueError, "Unhandled ReduceOp"); + break; + } + } +} + +// Get a key string from device +inline std::string getKeyFromDevice(at::Device& device) { + return std::to_string(device.index()); +} + +inline at::DeviceIndex getIndexFromDeviceKey(const std::string& deviceKey) { + // initialize the device index to -1, which is an invalid value. + int index = -1; + try { + index = std::stoi(deviceKey); + } catch (const std::invalid_argument& e) { + LOG(ERROR) << c10::str( + "Invalid deviceKey: ", deviceKey, ",", e.what(), "."); + } catch (const std::out_of_range& e) { + LOG(ERROR) << "Out of range: " << e.what(); + } + return static_cast(index); +} + +std::string getKeySendRecv(int myRank, int peer) { + int lowRank = myRank < peer ? myRank : peer; + int highRank = myRank < peer ? peer : myRank; + std::string sendRecvPair = + std::to_string(lowRank) + ":" + std::to_string(highRank); + return sendRecvPair; +} + +// Get device from tensor +inline at::Device getDevice(at::Tensor& tensor) { + return tensor.device(); +} + +// [Sync Streams] Helper that lets the input ncclStreams to wait for the current +// stream. NCCL communications run on ncclStreams, but input tensors are +// allocated on different streams (i.e., current streams). Communications on +// ncclStreams cannot start before pending input tensor ops on current streams +// finish. Otherwise, ops on two streams might read/write same tensors +// concurrently. +// +// The synchronization above alone is not enough. We also need to make sure +// input tensors are not freed before their usages on ncclStreams finish. This +// can be achieved by calling c10::cuda::CUDACachingAllocator::recordStream, +// which remembers the usage stream (ncclStream), creates an event on the usage +// stream when GC attempts to free the input tensor, and delays GC until that +// event is done. +void syncStream( + at::Device& device, + at::cuda::CUDAEvent& ncclEvent, + at::cuda::CUDAStream& ncclStream) { + ncclEvent.record(at::cuda::getCurrentCUDAStream(device.index())); + ncclEvent.block(ncclStream); +} + +// Given a ncclUniqueId, convert it to a string representation that can be put +// in the store. +std::string buildNcclUniqueIdStr(const ncclUniqueId& ncclID) { + const uint8_t* bytes = reinterpret_cast(&ncclID); + std::ostringstream oss; + for (const auto i : c10::irange(NCCL_UNIQUE_ID_BYTES)) { + oss << std::hex << static_cast(bytes[i]); + } + return oss.str(); +} + +std::string getNcclAbortedCommStoreKey(const std::string ncclIdStr) { + return std::string(kNCCLAbortedCommStoreKey) + ":" + ncclIdStr; +} + +// Returns exception's what() given an exception_ptr instance. +std::string getExceptionMsgFromExceptionPtr( + const std::exception_ptr& exceptionPtr) { + TORCH_CHECK(exceptionPtr != nullptr); + try { + std::rethrow_exception(exceptionPtr); + } catch (const std::exception& e) { + return e.what(); + } catch (...) { + return "Unknown exception type"; + } +} + +inline void errorIfCapturingNonCapturableNCCL(c10::cuda::CaptureStatus status) { + // parentheses avoid some compiler warnings + static const uint64_t min_version = + (((uint64_t)2) << 32) + (((uint64_t)9) << 16) + ((uint64_t)6); + static const uint64_t cur_version = torch::cuda::nccl::version(); + if (cur_version < min_version) { + TORCH_CHECK_WITH( + NotImplementedError, + status == c10::cuda::CaptureStatus::None, + "Capturing NCCL collectives is only allowed with NCCL >= 2.9.6"); + } +} + +} // namespace + +// Map from each communicator to its device index. +// This map is used when register/deregister cache segments from cache +// allocator. See design notes below: +// - Each segment should be registered only to the communicator on the +// same device. +// - We cannot reuse devNCCLCommMap_ in each ProcessGroup because the key may be +// ranks rather than device in point-to-point case. +// - This map has also to be maintained as global variable since the register +// hooks are called outside the scope of any PG, thus we need traverse +// communicators in all PGs. +static std::unordered_map, int> ncclCommDevIdxMap; +static std::mutex ncclCommDevIdxMapMutex; +static bool allocatorHooksAttached = false; + +std::atomic ProcessGroupNCCL::shouldDump_(false); + +void cacheAllocatorRegisterHook( + const c10::cuda::CUDACachingAllocator::TraceEntry& te) { + // Register after SEGMENT_ALLOC + if (te.action_ != + c10::cuda::CUDACachingAllocator::TraceEntry::Action::SEGMENT_ALLOC) { + return; + } + + std::lock_guard lock(ncclCommDevIdxMapMutex); + for (auto& it : ncclCommDevIdxMap) { + auto& ncclComm = it.first; + auto& devIdx = it.second; + if (te.device_ == devIdx) { + ncclComm->registerSegment(reinterpret_cast(te.addr_), te.size_); + } + } +} + +void cacheAllocatorDeregisterHook( + const c10::cuda::CUDACachingAllocator::TraceEntry& te) { + // deregister before SEGMENT_FREE + if (te.action_ != + c10::cuda::CUDACachingAllocator::TraceEntry::Action::SEGMENT_FREE) { + return; + } + + std::lock_guard lock(ncclCommDevIdxMapMutex); + for (auto& it : ncclCommDevIdxMap) { + auto& ncclComm = it.first; + auto& devIdx = it.second; + if (te.device_ == devIdx) { + ncclComm->deregisterSegment(reinterpret_cast(te.addr_)); + } + } +} + +std::unordered_map> +getNCCLCommDumpMap() { +#if defined(IS_NCCLX) && defined(NCCL_COMM_DUMP) + std::unordered_map< + std::string /* ncclUniqueID */, + std::unordered_map /* dump from this comm */> + ncclDumpMap; + // dump_nccl_trace is only called from the default PG (local_id_=0), but we + // want to dump from all comms so we need to iterate over ncclCommDevIdxMap, + // which is static + std::vector> allNCCLComms; + // within the critical section, we don't want to dump while holding the lock + // as dump might hang + ncclCommDevIdxMapMutex.lock(); + for (auto& [ncclComm, _] : ncclCommDevIdxMap) { + allNCCLComms.push_back(ncclComm); + } + ncclCommDevIdxMapMutex.unlock(); + for (auto& ncclComm : allNCCLComms) { + std::string ncclUniqueIDStr = buildNcclUniqueIdStr(ncclComm->getNcclId()); + ncclDumpMap[ncclUniqueIDStr] = ncclComm->ncclCommDump(); + } + return ncclDumpMap; +#else + return std::unordered_map< + std::string, + std::unordered_map>(); +#endif +} + +std::string dump_nccl_trace( + bool includeCollectives, + bool includeStackTraces, + bool onlyActive) { + auto ncclDumpMap = getNCCLCommDumpMap(); + return NCCLTraceBuffer::get()->dump( + ncclDumpMap, includeCollectives, includeStackTraces, onlyActive); +} + +std::string dump_nccl_trace_json(bool includeCollectives, bool onlyActive) { + auto ncclDumpMap = getNCCLCommDumpMap(); + return NCCLTraceBuffer::get()->dump_json( + ncclDumpMap, includeCollectives, onlyActive); +} + +std::optional)>>& +get_cpp_trace_dumper() { + static std::optional< + std::function)>> + dumper(std::nullopt); + return dumper; +} + +gil_checker_t& get_gil_checker() { + static gil_checker_t gil_checker = nullptr; + return gil_checker; +} + +std::future launchAsyncGilCheck() { + std::promise resultPromise; + std::future resultFuture = resultPromise.get_future(); + TORCH_CHECK(get_gil_checker(), "Can't check GIL with null GIL checker"); + std::thread workerThread([promise = std::move(resultPromise)]() mutable { + c10::setThreadName("pt_nccl_gil_chk"); + + try { + auto& gil_checker = get_gil_checker(); + promise.set_value((*gil_checker)()); + } catch (...) { + promise.set_exception(std::current_exception()); + } + }); + + // Detach the thread to allow it to run independently + workerThread.detach(); + + return resultFuture; +} + +const int64_t ProcessGroupNCCL::kWatchdogThreadSleepMillis = 100; +constexpr int64_t kSynchronizeBusyWaitMillis = 10; +thread_local uint64_t ProcessGroupNCCL::ncclActiveGroupCounter_ = 0; + +std::ostream& operator<<( + std::ostream& output, + const ProcessGroupNCCL::WorkNCCL& workNCCL) { + std::string workInfo; + workInfo = c10::str( + "WorkNCCL(", + "SeqNum=", + workNCCL.seq_, + ", OpType=", + opTypeToString(workNCCL.opType_), + ", NumelIn=", + workNCCL.numelIn_, + ", NumelOut=", + workNCCL.numelOut_, + ", Timeout(ms)=", + workNCCL.opTimeout_.count(), + ")"); + return output << workInfo; +} + +ProcessGroupNCCL::WorkNCCL::WorkNCCL( + const std::string& pgUID, + const std::string& pgDesc, + at::Device& device, + int rank, + OpType opType, + uint64_t seq, + const char* profilingTitle, + const std::optional>& inputs, + bool desyncDebug, + bool enableTiming, + bool cudaEventCacheEnabled, + DebugLevel distDebugLevel) + : Work(rank, opType, profilingTitle, inputs), + pgUID_(pgUID), + pgDesc_(pgDesc), + device_(device), + workStartTime_(std::chrono::steady_clock::now()), + seq_(seq), + timingEnabled_(enableTiming), + distDebugLevel_(distDebugLevel) { + // Creates the CUDA event wrappers + // Note: The actual events are lazily created when first recorded to with + // DEFAULT_FLAGS = cudaEventDisableTiming. + if (cudaEventCacheEnabled) { + ncclStartEvent_ = enableTiming + ? ProcessGroupNCCL::CUDAEventCache::get().create(enableTiming) + : nullptr; + ncclEndEvent_ = + ProcessGroupNCCL::CUDAEventCache::get().create(enableTiming); + } else { + ncclStartEvent_ = enableTiming + ? std::make_shared(cudaEventDefault) + : nullptr; + ncclEndEvent_ = std::make_shared( + enableTiming ? cudaEventDefault : cudaEventDisableTiming); + } +} + +ProcessGroupNCCL::WorkNCCL::WorkNCCL(const WorkNCCL& w) + : Work(w.rank_, w.opType_), + std::enable_shared_from_this(w), + pgUID_(w.pgUID_), + pgDesc_(w.pgDesc_), + device_(w.device_), + ncclStartEvent_(w.ncclStartEvent_), + ncclEndEvent_(w.ncclEndEvent_), + ncclComm_(w.ncclComm_), + blockingWait_(w.blockingWait_), + opTimeout_(w.opTimeout_), + ownedEphermeralTimeout_(w.ownedEphermeralTimeout_), + workStartTime_(w.workStartTime_), + seq_(w.seq_), + startTraceUpdated_(w.startTraceUpdated_), + numelIn_(w.numelIn_), + numelOut_(w.numelOut_), + store_(w.store_), + timingEnabled_(w.timingEnabled_), + trace_id_(w.trace_id_), + distDebugLevel_(w.distDebugLevel_) { + exception_ = w.exception_; +} + +ProcessGroupNCCL::WorkNCCL::~WorkNCCL() = default; + +bool ProcessGroupNCCL::WorkNCCL::isCompleted() { + if (!ncclComm_->isAborted()) { + checkAndSetException(); + } + return exception() || finishedGPUExecutionInternal(); +} + +bool ProcessGroupNCCL::WorkNCCL::isStarted() { + if (!ncclComm_->isAborted()) { + checkAndSetException(); + } + return exception() || startedGPUExecutionInternal(); +} + +bool ProcessGroupNCCL::WorkNCCL::isSuccess() const { + C10_THROW_ERROR(NotImplementedError, "WorkNCCL::isSuccess() is deprecated"); +} + +void ProcessGroupNCCL::WorkNCCL::checkAndSetException() { + if (exception()) { + // We already have an exception. + return; + } + + auto exception_ptr = checkForNCCLErrors(); + std::unique_lock lock(mutex_); + exception_ = exception_ptr; + if (exception_) { + LOG(ERROR) << logPrefix() << "Collective " << *this + << " raised the following async exception: " + << getExceptionMsgFromExceptionPtr(exception_); + } +} + +const std::string& ProcessGroupNCCL::WorkNCCL::logPrefix() const { + static std::string prefix = c10::str("[Rank ", rank_, "] "); + return prefix; +} + +void ProcessGroupNCCL::WorkNCCL::setException( + std::exception_ptr exception_ptr) { + std::unique_lock lock(mutex_); + exception_ = exception_ptr; +} + +// Helper that checks if the NCCL kernels are completed on the GPUs +bool ProcessGroupNCCL::WorkNCCL::finishedGPUExecution() { + checkAndSetException(); + return finishedGPUExecutionInternal(); +} + +bool ProcessGroupNCCL::WorkNCCL::startedGPUExecutionInternal() const { + // if timing is disabled we won't have allocated start events + if (!timingEnabled_) { + return false; + } + // Checking the work's corresponding CUDA event's status + if (!ncclStartEvent_->query()) { + return false; + } + return true; +} + +bool ProcessGroupNCCL::WorkNCCL::finishedGPUExecutionInternal() const { + // Checking the work's corresponding CUDA event's status + // It calls `cudaEventQuery` eventually. Although this seems to be a + // non-blocking call, but we did notice hangs in the past. It can + // hang if another thread is holding the CUDA global context lock. For + // example, when doing a `cudaDeviceSynchronize` or even + // `cudaStreamSynchronize`. + if (!ncclEndEvent_->query()) { + return false; + } + return true; +} + +bool ProcessGroupNCCL::WorkNCCL::checkTimeout( + std::optional timeout) { + STATIC_SCOPED_WAIT_COUNTER( + pytorch.wait_counter.ProcessGroupNCCL__checkTimeout); + auto currentTimepoint = std::chrono::steady_clock::now(); + auto timeElapsed = std::chrono::duration_cast( + currentTimepoint - workStartTime_); + auto workTimeout = timeout ? *timeout : opTimeout_; + + if (timeElapsed < workTimeout) + return false; + + // Timed out + + // There is already an error, we don't override it + if (exception()) + return true; + + std::string exceptionMsg = c10::str( + logPrefix(), + "Watchdog caught collective operation timeout: ", + *this, + " ran for ", + timeElapsed.count(), + " milliseconds before timing out."); + + LOG(ERROR) << exceptionMsg; + std::exception_ptr exception_ptr = + std::make_exception_ptr(C10_BUILD_ERROR(DistBackendError, exceptionMsg)); + setException(exception_ptr); + return true; +} + +void ProcessGroupNCCL::WorkNCCL::handleException( + ErrorHandlingMode errorHandling) { + if (exception_) { + auto exceptionMsg = c10::str( + "Some NCCL operations have failed or timed out. Due to the ", + "asynchronous nature of CUDA kernels, subsequent GPU operations ", + "might run on corrupted/incomplete data."); + LOG(ERROR) << logPrefix() << exceptionMsg; + C10_LOG_API_USAGE_ONCE("ProcessGroupNCCL.WorkNCCL.handleException"); + + if (SHOULD_TEAR_DOWN(errorHandling)) { + auto tearDownMsg = c10::str( + "To avoid data inconsistency, we are taking the entire process down."); + LOG(ERROR) << logPrefix() << tearDownMsg; + std::rethrow_exception(exception_); + } + } +} + +void ProcessGroupNCCL::WorkNCCL::synchronize() { + // Call Synchronize without a timeout. We use this method to avoid adding a + // timeout argument to the public synchronize API. + synchronizeInternal(kNoTimeout); +} + +void ProcessGroupNCCL::WorkNCCL::synchronizeStream() { + auto currentStream = at::cuda::getCurrentCUDAStream(device_.index()); + // Block the current stream on the NCCL stream + ncclEndEvent_->block(currentStream); + + if (avoidRecordStreams_) { + stashed_for_allocator_safety_->clear(); + } +} + +// Waiting on the work's corresponding CUDA events +void ProcessGroupNCCL::WorkNCCL::synchronizeInternal( + std::chrono::milliseconds timeout) { + synchronizeStream(); + + // In case of blocking, wait for the operation to complete. + if (blockingWait_) { + while (!isCompleted()) { + bool timedOut = checkTimeout( + timeout == kNoTimeout ? std::nullopt : std::make_optional(timeout)); + // Explicitly abort ncclComms here before throwing this timed out + // exception to users. + // If throwing timed out excepiton without aborting nccl communicators + // here, it was observed that CUDA GPU will have 100% utilization and + // can not run new events successfully. + if (timedOut) { + std::string exceptionMsg = c10::str( + logPrefix(), + "Work ", + (*this), + " timed out in blocking wait (TORCH_NCCL_BLOCKING_WAIT=1)."); + LOG(ERROR) << exceptionMsg; + break; + } + // Yield + std::this_thread::sleep_for( + std::chrono::milliseconds(kSynchronizeBusyWaitMillis)); + } + // exception() includes timeout and error during blocking wait + if (exception()) { + // Abort NCCL communicators + abort(); + // Throw exception (from main thread here) + handleException(TearDown); + } + } + + // Device synchronize only after we've completed timeout checks. + if (barrierTensor_.defined()) { + // If we use the work to do barrier, we should block here + // `dist.barrier()` only requires all CPU processes to enter this + // function, hence we only need to make sure the dummy all-reduce has + // completed. So we would only need to sync the **current stream** back to + // host, and do not need to synchronize the entire device (which may have + // kernels running on other streams). + // Using `cudaStreamSynchronize` instead of `cudaDeviceSynchronize` can: + // - lower chance of hang; + // - CurrentCUDAStream is usually the context of the next operation in + // Python, thus blocking current stream would already block the next + // compute kernel; + // - achieve better barrier performance. + auto currentStream = at::cuda::getCurrentCUDAStream(device_.index()); + // CUDAStream wrapper will correctly use a DeviceGuard here + currentStream.synchronize(); + } +} + +// Same as calling synchronize(). +bool ProcessGroupNCCL::WorkNCCL::wait(std::chrono::milliseconds timeout) { + RECORD_PARAM_COMMS( + static_cast(this->seq_), // seq + std::make_tuple(pgUID_, pgDesc_), // PG name tuple + rank_, // rank + "wait", // collective name + 0, // inNelems + 0, // outNelems + at::kByte, // dType + std::vector(), // inSplitSizes + std::vector(), // outSplitSizes + -1, + -1, + static_cast(1)); // number of device? + synchronizeInternal(timeout); + // TODO(kwen2501): this should be moved to c10d tests, to qualify a NCCL + // upgrade. Once a NCCL version is qualified, this code should not be needed + // at runtime. +#ifdef PGNCCL_ENABLE_HASH + if (distDebugLevel_ >= DebugLevel::Detail) { + auto numel = getTensorsNumel(*outputs_); + auto hashValue = hashTensors(*outputs_); + PRINT_COLLECTIVE_HASH_SIGNATURE( + "output", opTypeToString(opType_), numel, hashValue); + } +#endif + // Always return true, because abort API is not implemented. + return true; +} + +void ProcessGroupNCCL::WorkNCCL::abort() { + // Abort all communicators of this work + ncclComm_->ncclCommAbort(); + + ncclCommDevIdxMapMutex.lock(); + ncclCommDevIdxMap.erase(ncclComm_); + ncclCommDevIdxMapMutex.unlock(); +} + +ProcessGroupNCCL::CUDAEventCache::CUDAEventCache() {} + +// CUDA event is used to record the start/end of one Work. +// Instead of let the CUDA event gets destroyed, we now reuse it after the Work +// has been erased from workMetaList_. +// This is to avoid the potential deadlock caused by CudaEventDestroy. +std::shared_ptr ProcessGroupNCCL::CUDAEventCache::create( + bool timing) { + auto deleter = [this, timing](at::cuda::CUDAEvent* event) { + std::lock_guard lock(this->cacheMutex_); + this->eventsArray_[timing ? 1 : 0].push_back(event); + }; + at::cuda::CUDAEvent* event = nullptr; + { + std::lock_guard lock(cacheMutex_); + auto events = eventsArray_[timing ? 1 : 0]; + if (!events.empty()) { + event = events.back(); + events.pop_back(); + } + } + if (!event) { + event = new at::cuda::CUDAEvent( + timing ? cudaEventDefault : cudaEventDisableTiming); + } + return std::shared_ptr(event, std::move(deleter)); +} + +ProcessGroupNCCL::CUDAEventCache& ProcessGroupNCCL::CUDAEventCache::get() { + static ProcessGroupNCCL::CUDAEventCache cache; + return cache; +} + +static std::atomic process_group_id = 0; + +constexpr const char* MULTI_DEVICE_ERROR_MSG = + "Expecting one tensor only but got multiple. You are probably using multiple " + "devices under one thread. The support for such usage has been deprecated. " + "For details, please refer to " + "https://pytorch.org/docs/stable/distributed.html#multi-gpu-collective-functions. " + "ProcessGroupNCCL continues supporting multi-process and multi-thread modes."; + +ProcessGroupNCCL::ProcessGroupNCCL( + const c10::intrusive_ptr& store, + int rank, + int size, + c10::intrusive_ptr options) + : Backend(rank, size), + store_(store), + options_(options), + ncclCommCounter_(0), + traceKeyStart_(getTraceStartKey("NCCL", rank)), + traceKeyEnd_(getTraceEndKey("NCCL", rank)), + terminateProcessGroup_(false), + terminateHeartbeatMonitorThread_(false), + collectiveDebugInfoMode_(false), + local_id_(process_group_id++), + intraNodeComm_(initIntraNodeComm()) { + TORCH_CHECK_WITH( + ValueError, + at::cuda::getNumGPUs() != 0, + "ProcessGroupNCCL is only supported with GPUs, no GPUs found!"); + + // getNcclVersion needs to get called before launching threads which can + // potentially call getenv. getNcclVersion internally calls setenv to set some + // environment variables from config file, which can race with getenv from + // other threads and cause segfaults. + const auto ncclVersion = getNcclVersion(); + this->setGroupUid(options_->group_name); + this->localDeviceCount_ = at::cuda::getNumGPUs(); + logPrefix_ = createLogPrefix(); + blockingWait_ = getCvarBool(TORCH_NCCL_BLOCKING_WAIT, false); + asyncErrorHandling_ = static_cast( + getCvarInt(TORCH_NCCL_ASYNC_ERROR_HANDLING, 3 /*SkipCleanUp*/)); + desyncDebug_ = getCvarBool(TORCH_NCCL_DESYNC_DEBUG, false) || + (dist_debug_level_ >= DebugLevel::Detail); + rethrowCUDAErrors_ = getCvarBool(TORCH_NCCL_RETHROW_CUDA_ERRORS, true); + // TODO, we should either deprecate TORCH_NCCL_DUMP_ON_TIMEOUT + // or change its name to reflect that dump happens on exception including + // both timeout and other errors. + dumpOnTimeoutOrEx_ = getCvarBool(TORCH_NCCL_DUMP_ON_TIMEOUT, false) || + (dist_debug_level_ >= DebugLevel::Detail); + sleepAfterException_ = getCvarBool(TORCH_NCCL_SLEEP_AFTER_EXCEPTION, false); + // logging C++ stack isn't safe. Introduce a variable to control it. + logCppStackOnUncleanShutdown_ = + getCvarBool(TORCH_NCCL_LOG_CPP_STACK_ON_UNCLEAN_SHUTDOWN, true); + enableNanCheck_ = getCvarBool(TORCH_NCCL_NAN_CHECK, false); + heartbeat_ = 1ULL; + monitorThreadEnabled_.store(getCvarBool(TORCH_NCCL_ENABLE_MONITORING, true)); + cudaEventCacheEnabled_.store(getCvarBool(TORCH_NCCL_CUDA_EVENT_CACHE, false)); + heartbeatTimeoutInSec_ = + getCvarInt(TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC, 60 * 8 /*8 Mins*/); + waitTimeoutDumpInMilSec_ = + getCvarInt(TORCH_NCCL_WAIT_TIMEOUT_DUMP_MILSEC, 60 * 1000 /*60 Sec*/); + coordCheckIntervalMilSec_ = getCvarInt(TORCH_NCCL_COORD_CHECK_MILSEC, 1000); + ncclTraceBufferSize_ = getCvarInt(TORCH_NCCL_TRACE_BUFFER_SIZE, 0); + enableCollecticeHashDebug_ = (dist_debug_level_ >= DebugLevel::Detail); + // store_ usually is wrapped with PrefixStore and the prefix is different + // across different ProcessGroupNCCL(PG) instances. We need to get the + // underlying non-PrefixStore for sharing global information shared across + // different PGs. + PrefixStore* prefixStore = dynamic_cast(store_.get()); + globalStore_ = + prefixStore ? prefixStore->getUnderlyingNonPrefixStore() : store_; +#ifdef ENABLE_NCCL_ERROR_CHECKING + enableTiming_.store( + getCvarBool(TORCH_NCCL_ENABLE_TIMING, false) || desyncDebug_); +#endif + avoidRecordStreams_ = getCvarBool(TORCH_NCCL_AVOID_RECORD_STREAMS, false); +#ifdef NCCL_HAS_COMM_REGISTER + useTensorRegisterAllocatorHook_ = + getCvarBool(TORCH_NCCL_USE_TENSOR_REGISTER_ALLOCATOR_HOOK, false); + if (c10::cuda::CUDACachingAllocator::CUDAAllocatorConfig:: + expandable_segments()) { + useTensorRegisterAllocatorHook_ = false; + LOG(INFO) + << logPrefix() + << "disables TORCH_NCCL_USE_TENSOR_REGISTER_ALLOCATOR_HOOK because it is not compatible with CUDA allocator expandable segments mode."; + } +#endif + + if (blockingWait_) { + if (asyncErrorHandling_ != NoHandling || desyncDebug_) { + LOG(INFO) + << logPrefix() << "TORCH_NCCL_BLOCKING_WAIT and " + << "TORCH_NCCL_ASYNC_ERROR_HANDLING|TORCH_NCCL_DESYNC_DEBUG" + << "should not both be enabled. " + << "Only TORCH_NCCL_BLOCKING_WAIT is being used in this process."; + asyncErrorHandling_ = NoHandling; + desyncDebug_ = false; + } + } else { + if (desyncDebug_ && asyncErrorHandling_ == NoHandling) { + LOG(INFO) + << logPrefix() + << "TORCH_NCCL_DESYNC_DEBUG and TORCH_NCCL_ASYNC_ERROR_HANDLING " + << "must both be enabled. " + << "Enabling TORCH_NCCL_ASYNC_ERROR_HANDLING."; + asyncErrorHandling_ = SkipCleanUp; + } + } + +#ifdef ENABLE_NCCL_ERROR_CHECKING + ncclCommWatchdogThread_ = + std::thread(&ProcessGroupNCCL::ncclCommWatchdog, this); +#endif + + init(); + const std::string OFF = "OFF"; + std::string torch_distributed_debug = + getCvarString({"TORCH_DISTRIBUTED_DEBUG"}, OFF.c_str()); + LOG(INFO) << logPrefix() << "ProcessGroupNCCL initialization options: " + << "size: " << size << ", global rank: " << globalRank() + << ", TIMEOUT(ms): " << options_->timeout.count() + << ", USE_HIGH_PRIORITY_STREAM: " + << options_->is_high_priority_stream + << ", SPLIT_FROM: " << options_->split_from + << ", SPLIT_COLOR: " << options_->split_color + << ", PG Name: " << options_->group_name; + + LOG(INFO) << logPrefix() << "ProcessGroupNCCL environments: " + << "NCCL version: " << ncclVersion + << ", TORCH_NCCL_ASYNC_ERROR_HANDLING: " << asyncErrorHandling_ + << ", TORCH_NCCL_DUMP_ON_TIMEOUT: " << dumpOnTimeoutOrEx_ + << ", TORCH_NCCL_WAIT_TIMEOUT_DUMP_MILSEC: " + << waitTimeoutDumpInMilSec_ + << ", TORCH_NCCL_DESYNC_DEBUG: " << desyncDebug_ + << ", TORCH_NCCL_ENABLE_TIMING: " << enableTiming_.load() + << ", TORCH_NCCL_BLOCKING_WAIT: " << blockingWait_ + << ", TORCH_DISTRIBUTED_DEBUG: " << torch_distributed_debug +#ifdef NCCL_HAS_COMM_REGISTER + << ", TORCH_NCCL_USE_TENSOR_REGISTER_ALLOCATOR_HOOK: " + << useTensorRegisterAllocatorHook_ +#endif + << ", TORCH_NCCL_ENABLE_MONITORING: " + << monitorThreadEnabled_.load() + << ", TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC: " << heartbeatTimeoutInSec_ + << ", TORCH_NCCL_TRACE_BUFFER_SIZE: " << ncclTraceBufferSize_ + << ", TORCH_NCCL_COORD_CHECK_MILSEC: " << coordCheckIntervalMilSec_ + << ", TORCH_NCCL_NAN_CHECK: " << enableNanCheck_ + << ", TORCH_NCCL_CUDA_EVENT_CACHE: " << cudaEventCacheEnabled_ + << ", TORCH_NCCL_LOG_CPP_STACK_ON_UNCLEAN_SHUTDOWN: " + << logCppStackOnUncleanShutdown_; + + if (options_->global_ranks_in_group.empty()) { + this->globalRankStart = 0; + } else { + this->globalRankStart = options_->global_ranks_in_group[0]; + } + + if (options_->global_ranks_in_group.empty()) { + this->globalRankStride = 1; + } else if (options_->global_ranks_in_group.size() == 1) { + this->globalRankStride = 0; + } else { + bool ranksAreStrided = true; + int startRank = options_->global_ranks_in_group[0]; + int stride = + options_->global_ranks_in_group[1] - options_->global_ranks_in_group[0]; + for (std::vector::size_type i = 0; + i < options_->global_ranks_in_group.size(); + i++) { + if (options_->global_ranks_in_group[i] != startRank + i * stride) { + ranksAreStrided = false; + break; + } + } + + if (ranksAreStrided) { + this->globalRankStride = options_->global_ranks_in_group[1] - + options_->global_ranks_in_group[0]; + } else { + this->globalRankStride = -1; + } + } + + // Attach hooks to cache allocator to trigger the hooks whenever a traced + // action is called. In the following hooks, we register a newly allocated + // segment when SEGMENT_ALLOC action occurs, and deregister a segment when + // SEGMENT_FREE action occurs. + // We attach hooks only once at the first PG creation. + // Attaching hooks fails if CUDACachingAllocator is not initialized, so + // lazyInitCUDA is called (and is a no-op if CUDA is already initialized). + if (useTensorRegisterAllocatorHook_ && !allocatorHooksAttached) { + at::globalContext().lazyInitCUDA(); + c10::cuda::CUDACachingAllocator::attachAllocatorTraceTracker( + &cacheAllocatorRegisterHook); + c10::cuda::CUDACachingAllocator::attachAllocatorTraceTracker( + &cacheAllocatorDeregisterHook); + allocatorHooksAttached = true; + } +} + +void ProcessGroupNCCL::eagerConnectSingleDevice(at::Device device) { + const auto key = getKeyFromDevice(device); + LOG(INFO) << logPrefix() << "Eagerly connecting nccl backend with device " + << device; + getNCCLComm(key, device, OpType::ALLREDUCE); +} + +void ProcessGroupNCCL::performNocolorSplit(at::Device device) { + // If our backend doesn't support splitting, this is a no-op for + // ranks not in the new subgroup (and ranks that would be in it will + // just use a new communicator rather than split). +#ifdef NCCL_HAS_COMM_SPLIT + const auto key = getKeyFromDevice(device); + LOG(INFO) << logPrefix() << "Performing nocolor split on backend device " + << device << ", key " << key << ", i am " << this; + auto comm = getNCCLComm(key, device, OpType::ALLREDUCE); + NCCLComm::split( + comm.get(), + NCCL_SPLIT_NOCOLOR, + rank_, + options_->config, + options_->global_ranks_in_group); +#endif +} + +c10::intrusive_ptr ProcessGroupNCCL:: + initIntraNodeComm() { + using IntraNodeComm = intra_node_comm::IntraNodeComm; + if (!IntraNodeComm::isEnabled()) { + return nullptr; + } + auto prefixStore = c10::make_intrusive("IntraNodeComm", store_); + auto comm = c10::make_intrusive(prefixStore, rank_, size_); + if (comm->rendezvous()) { + return comm; + } else { + return nullptr; + } +} + +void ProcessGroupNCCL::setSequenceNumberForGroup() { +} // NCCL just starts sequence numbers at 0. + +uint64_t ProcessGroupNCCL::getSequenceNumberForGroup() { + return seqCollective_; +} + +void ProcessGroupNCCL::registerOnCompletionHook( + std::function)>&& hook) { + TORCH_CHECK_WITH( + DistBackendError, + onCompletionHook_ == nullptr, + "ProcessGroupNCCL OnCompletion hook already registered"); + + TORCH_CHECK_WITH( + ValueError, + enableTiming_.load(), + "ProcessGroupNCCL OnCompletion hook requires recording start and end " + "events which require setting TORCH_NCCL_ENABLE_TIMING environment variable. " + "This is only available for NCCL version >= 2.4."); + onCompletionHook_ = std::move(hook); + onCompletionHookThread_ = std::thread(&ProcessGroupNCCL::runHookLoop, this); +} + +// must release GIL when calling this method +void ProcessGroupNCCL::waitForPendingWorks() { + // Reasoning about hook completion: + // 1. waitForPendingWorks should be called after user code has finished + // calling + // all collectives. This means, when we got here, all of the collectives + // are either in workMetaList_ or has been erased from workMetaList_. + // 2. The watchdog thread grabs both locks to move Work object from the + // workMetaList_ to the completedWorkList_, and the hook thread only erases + // a Work object after the hook is returned. Therefore, after user code + // calls a collective, its Work object is either in workMetaList_ or in + // completedWorkList_ before it finishes. + // 3. We have three threads and two locks. + // a. main thread (this function) grabs two locks atomically + // b. watchdog thread (watchdogHandler function) always grabs + // workMetaListMutex_ + // first and then grabs completedWorkListMutex_. + // c. hook thread (runHookLoop function) only grabs + // completedWorkListMutex_. Therefore, locks are always acquired in the + // same order and hence no deadlocks. + while (true) { + { + std::lock(workMetaListMutex_, completedWorkListMutex_); + std::lock_guard lockWork(workMetaListMutex_, std::adopt_lock); + std::lock_guard lockHook( + completedWorkListMutex_, std::adopt_lock); + + if (workMetaList_.empty() && completedWorkList_.empty()) { + return; + } + } + + std::this_thread::sleep_for( + std::chrono::milliseconds(kWatchdogThreadSleepMillis)); + } +} + +void ProcessGroupNCCL::enableCollectivesTiming() { + enableTiming_.store(true); +} + +void ProcessGroupNCCL::waitForFutureOrTimeout( + std::future& fut, + const std::chrono::milliseconds& timeOutMilSec, + const std::string& futDescription, + bool throwException, + bool log) { + std::string errorMsg; + + ::c10d::C10dLoggingData data; + if (log) { + data.integers["pg_id"] = local_id_; + data.integers["rank"] = rank_; + data.integers["global_rank"] = globalRank(); + data.strings["flight_recorder_version"] = c10d::version_val_str; + } + + TORCH_CHECK(fut.valid(), "Expected a valid future"); + std::future_status status = fut.wait_for(timeOutMilSec); + if (status == std::future_status::ready) { + // Calling .get() will re-raise any exception from the future, and we don't + // care about the retval + try { + bool result = fut.get(); + if (result) { + LOG(INFO) << logPrefix() + << "future is successfully executed for: " << futDescription; + if (log) { + data.strings["status"] = "SUCCESS"; + } + } + } catch (const std::exception& e) { + errorMsg = c10::str( + logPrefix(), + "Exception thrown when waiting for future ", + futDescription, + ": ", + e.what()); + if (log) { + data.strings["status"] = "EXCEPTION"; + data.strings["exception"] = e.what(); + } + LOG(ERROR) << errorMsg; + } catch (...) { + errorMsg = c10::str( + logPrefix(), + "Unknown exception thrown when waiting for future ", + futDescription); + if (log) { + data.strings["status"] = "EXCEPTION"; + data.strings["exception"] = "Unknown exception"; + } + LOG(ERROR) << errorMsg; + } + } else { + errorMsg = c10::str( + logPrefix(), + "Future for ", + futDescription, + " timed out after ", + timeOutMilSec.count(), + " ms"); + data.strings["status"] = "TIMEOUT"; + LOG(ERROR) << errorMsg; + } + if (log) { + auto logger = c10d::C10dLogger::getLogger(); + if (logger) { + logger->log(data); + } + } + if (throwException && !errorMsg.empty()) { + C10_THROW_ERROR(DistBackendError, errorMsg); + } +} + +void ProcessGroupNCCL::abortCommsFromMap( + std::unordered_map>& ncclCommsMap, + std::optional abortReason) { + // The process may control multiple devices, loop through the communicators on + // each device + for (auto& it : ncclCommsMap) { + auto& devName = it.first; + auto& ncclComm = it.second; + at::cuda::OptionalCUDAGuard gpuGuard; + at::DeviceIndex deviceIndex = getIndexFromDeviceKey(devName); + if (deviceIndex >= 0) { + // For P2P comms, the deviceIndex could be -1 (invalid), as the keys in + // the map could be non deviceIndex, but rank to rank numbers. So we + // indeed need to check if deviceIndex >= 0 + // TODO: fix `getIndexFromDeviceKey` or fix `DeviceKey` + gpuGuard.set_index(deviceIndex); + } + LOG(INFO) << logPrefix() << "ProcessGroupNCCL destroying ncclComm_ " + << ncclComm->ncclComm_ << " on CUDA device: " << devName; + ncclComm->ncclCommAbort(abortReason); + // Note that we don't remove the aborted communicators from the + // cache. The reason is that if we do remove the communicator + // from the cache, it is possible that a new collective operation + // calls `ncclCommInitRank` to create a new communicator whereas + // other ranks might have failed/timed out and didn't enter + // `ncclCommInitRank`. As a result, when there is a failure on + // a communicator the application receives an exception and its + // their responsibility to destroy the process group and recreate + // it to recover from errors. + + LOG(INFO) << logPrefix() << "ProcessGroupNCCL destroyed " + << " communicator on CUDA device: " << devName; + } +} + +// Abort all communicators on this rank +bool ProcessGroupNCCL::abort(std::optional abortReason) { + // This will log counter for how long the abort actually takes. + STATIC_SCOPED_WAIT_COUNTER(pytorch.ProcessGroupNCCL__abort); + // Remove record from global ncclCommDevIdxMapMutex before aboarting, + // so that a new cache segment would not register to already aborded + // communicators. Note that ncclCommDevIdxMap is a global container which may + // contain other PG's communicators, thus we need to only erase communicators + // for the current PG. + ncclCommDevIdxMapMutex.lock(); + for (auto& it : devNCCLCommMap_) { + auto& ncclComm = it.second; + ncclCommDevIdxMap.erase(ncclComm); + } + ncclCommDevIdxMapMutex.unlock(); + + std::lock_guard lock(mutex_); + abortCommsFromMap(devNCCLCommMap_, abortReason); + abortCommsFromMap(inInitializationCommMap_, abortReason); + return true; +} + +void ProcessGroupNCCL::shutdown(std::optional reason) { + // Don't join threads here since the purpose of this method is to abort all + // communicators and signal the threads to exit. Joining on the threads could + // potentially block and hence avoid it in this method. + terminateProcessGroup_.store(true); + workMetaListCV_.notify_one(); + + // lauch abort asynchrounously and wait for it to complete or timeout + LOG(INFO) << logPrefix() + << "Launching ProcessGroupNCCL abort asynchrounously."; + std::future fut = std::async( + std::launch::async, [this, &reason]() { return this->abort(reason); }); + + waitForFutureOrTimeout( + fut, options_->timeout, "ProcessGroup abort", true, false); + LOG(INFO) << logPrefix() << "ProcessGroupNCCL aborts successfully."; + + // We need to wait for abort to finish before we can safely shut down + // heartbeat monitoring thread. + terminateHeartbeatMonitorThread_.store(true); + monitorWakeUpCV_.notify_one(); +} + +ProcessGroupNCCL::~ProcessGroupNCCL() { + LOG(INFO) << logPrefix() << "ProcessGroupNCCL destructor entered."; + + if (!terminateProcessGroup_.load()) { + if (rank_ % localDeviceCount_ == 0) { + TORCH_WARN_ONCE( + "WARNING: process group has NOT been destroyed before we destruct ProcessGroupNCCL. ", + "On normal program exit, the application should call destroy_process_group to ", + "ensure that any pending NCCL operations have finished in this process. " + "In rare cases this process can exit before this point and block the progress of " + "another member of the process group. This constraint has always been present, " + " but this warning has only been added since PyTorch 2.4"); + } + // If user haven't explicitly destroy/shutdown process group, destructor + // needs to do so + shutdown(); + } + + // Wait for all threads to finish before returning +#ifdef ENABLE_NCCL_ERROR_CHECKING + if (ncclCommWatchdogThread_.joinable()) { + ncclCommWatchdogThread_.join(); + LOG(INFO) << logPrefix() << "ProcessGroupNCCL watchdog thread joined."; + } + if (ncclHeartbeatMonitorThread_.joinable()) { + ncclHeartbeatMonitorThread_.join(); + LOG(INFO) << logPrefix() + << "ProcessGroupNCCL heart beat monitor thread joined."; + } +#endif + if (onCompletionHookThread_.joinable()) { + onCompletionHookThread_.join(); + LOG(INFO) << logPrefix() + << "ProcessGroupNCCL onCompletionHookThread thread joined."; + } +} + +bool ProcessGroupNCCL::dumpDebuggingInfo() { + // Serialize all calls to this function to avoid corrupting data, but allow + // multiple calls in one runtime. User is responsible for preserving the + // output file from an earlier call before a later call overwrites it. + static std::mutex writeDebugInfoMutex; + std::lock_guard lock(writeDebugInfoMutex); + LOG(ERROR) << logPrefix() << "ProcessGroupNCCL preparing to dump debug info."; + if (ncclTraceBufferSize_ > 0) { + // We dump nccl trace into local disk by default and users can register + // their customized writer by inheriting `DebugInfoWriter` via + // `registerDebugInfoWriter`. + auto ncclTrace = dump_nccl_trace(true, true, false); + DebugInfoWriter& writer = DebugInfoWriter::getWriter(globalRank()); + LOG(INFO) << logPrefix() << "ProcessGroupNCCL dumping nccl trace to " + << writer.getWriterTarget(); + writer.write(ncclTrace); + return true; + } + return false; +} + +void ProcessGroupNCCL::terminateProcess(std::string errMsg) { + // Logging with `FATAL`, after errMsg printed, it calls `std::abort()` + // to terminate the program execution. + LOG(FATAL) << logPrefix() << errMsg; +} + +int computeDeltaMS( + std::chrono::time_point start, + std::chrono::time_point end) { + return std::chrono::duration_cast(end - start) + .count(); +} + +std::string ProcessGroupNCCL::getNCCLWatchdogTimeoutErrorMsg( + const std::string& extraMsg) { + return c10::str( + logPrefix(), + "Received a dump signal due to a collective timeout from ", + extraMsg, + " and we will try our best to dump the debug info. ", + "Last enqueued NCCL work: ", + pgStatus_->lastEnqueuedSeq, + ", last completed NCCL work: ", + pgStatus_->lastCompletedSeq, + ".", + "This is most likely caused by incorrect usages of collectives, e.g., wrong ", + "sizes used across ranks, the order of collectives is not same for all ranks ", + "or the scheduled collective, for some reason, didn't run. Additionally, ", + "this can be caused by GIL deadlock or other reasons such as network errors or ", + "bugs in the communications library (e.g. NCCL), etc. "); +} + +std::string ProcessGroupNCCL::getNCCLWatchdogTimeoutExitMsg( + const std::string& exitReason) { + return c10::str( + logPrefix(), + "Terminating the process after attempting to dump debug info, due to ", + exitReason, + "."); +} + +void ProcessGroupNCCL::heartbeatMonitor() { + c10::setThreadName("pt_nccl_heartbt"); + + uint64_t heartBeatCounter = 0ULL; + std::string errorMsg; + std::string exitReason; + bool checkDumpSignal = (dumpOnTimeoutOrEx_ && local_id_ == 0); + int monitorPollInterval = checkDumpSignal ? coordCheckIntervalMilSec_ + : heartbeatTimeoutInSec_ * 1000; + auto lastTimePollStore = std::chrono::steady_clock::now(); + auto lastTimeHeartBeatCheck = std::chrono::steady_clock::now(); + std::optional dumpPipe = std::nullopt; + if (local_id_ == 0) { + // DumpPipe is one per-trainer process, and its convenient to name them + // after 'global' ranks in the system, So we assume processgroup (uid)==0 is + // the global PG and has globally unique rank ids across trainers. + dumpPipe.emplace(rank_); + } + while (true) { + // This won't have any lock since this lock is only used here. + // Please be aware that mutex `monitorMutex_` should not be used + // somewhere else to avoid the deadlock. + std::unique_lock lock(monitorMutex_); + if (monitorWakeUpCV_.wait_for( + lock, std::chrono::milliseconds(monitorPollInterval), [&] { + return terminateHeartbeatMonitorThread_.load(); + })) { + // For the normal complete or user interception, monitorWakeUpCV_ + // will get notified, we early return and exit heartbeatMonitor. + return; + } + auto currentTime = std::chrono::steady_clock::now(); + + // We put extra functionality in the thread for the default PG (aka, + // local_id_=0) because the signal is same across different PGs. We only + // need to run once per process to avoid duplicate things performed in too + // many separate threads. For example, we check a global flag on the + // TCPStore periodically to see if any PG on any rank observed a timeout and + // signaled peers to dump debugging info, and we avoid hammering the + // TCPStore from all PGs on the same rank. + if (checkDumpSignal) { + // There are two scenarios where monitor thread will dump on timeout: + // 1. The current rank is the first to observe a timeout in watchdog. + // (shouldDump_ was set to true by the watchdog thread). + // 2. Other ranks detected the timeout and signal the current rank to + // dump. In addtion, monitor threads will dump if watchdog threads has no + // heartbeat or dumpPipe is not empty. + if (shouldDump_.load()) { + errorMsg = getNCCLWatchdogTimeoutErrorMsg("this local rank"); + exitReason = "collective timeout or exception"; + break; + } + // We poll store to see if some ranks have flagged a timeout when + // we haven't polled for `heartbeat_timeout` seconds and there haven't + // any work added or removed for `watchdog_timeout` seconds. + if (computeDeltaMS(lastWorkListUpdateTime_, currentTime) >= + kWatchdogThreadSleepMillis && + computeDeltaMS(lastTimePollStore, currentTime) >= + coordCheckIntervalMilSec_) { + lastTimePollStore = currentTime; + // Wrap globalStore_->check() in a try-catch block to avoid crashing if + // the store is not available. + bool checkExceptionDump = false; + try { + checkExceptionDump = + globalStore_->check({std::string(EXCEPTION_DUMP)}); + } catch (const std::exception& e) { + LOG(WARNING) + << logPrefix() + << "Failed to check the \"should dump\" flag on TCPStore, " + << "(maybe TCPStore server has shut down too early), with error: " + << e.what(); + // We give up for now assuming TCPStore has been torn down. + return; + } + + if (checkExceptionDump) { + int timeOutRank = -1; + if (!shouldDump_.load()) { + LOG(ERROR) + << logPrefix() + << "Observed flight recorder dump signal from another rank via TCPStore."; + } + shouldDump_.store(true); + try { + auto vec = globalStore_->get(std::string(EXCEPTION_DUMP)); + TORCH_CHECK_WITH( + DistBackendError, + vec.size() == sizeof(int), + "Invalid size for the timeout rank ID"); + std::memcpy(&timeOutRank, vec.data(), vec.size()); + } catch (const std::exception& e) { + LOG(ERROR) << logPrefix() + << "Failed to get timeout rank ID from TCPStore." + << e.what(); + } + errorMsg = + getNCCLWatchdogTimeoutErrorMsg(c10::str(" rank ", timeOutRank)); + exitReason = "collective timeout or exception"; + break; + } + } + } + + if (computeDeltaMS(lastTimeHeartBeatCheck, currentTime) >= + heartbeatTimeoutInSec_ * 1000) { + // Check the heart beat of watchdog thread. + lastTimeHeartBeatCheck = currentTime; + auto heartbeat = heartbeat_.load(); + if (heartbeat != heartBeatCounter) { + heartBeatCounter = heartbeat; + } else { + shouldDump_.store(true); + // Watchdog heartbeat timeout. + errorMsg = c10::str( + logPrefix(), + "ProcessGroupNCCL's watchdog got stuck for ", + heartbeatTimeoutInSec_, + " seconds without making progress in monitoring enqueued collectives. ", + "This typically indicates a NCCL/CUDA API (e.g., CudaEventDestroy) hang blocking the watchdog, ", + "and could be triggered by another thread holding the GIL inside a ", + "CUDA api (for example, CudaEventDestroy), or other deadlock-prone behaviors.", + "If you suspect the watchdog is not actually stuck and a longer timeout would help, ", + "you can either increase the timeout (TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC) to a larger value " + "or disable the heartbeat monitor (TORCH_NCCL_ENABLE_MONITORING=0)." + "If either of aforementioned helps, feel free to file an issue to PyTorch about the short timeout " + "or false positive abort; otherwise, please attempt to debug the hang. "); + exitReason = "ProcessGroupNCCL watchdog hang"; + break; + } + } + // process a request to dump the trace. only PG uid 0 will respond to dump + // requests, but this is fine since all PG's feed into the same flight + // recorder and dump. After dump, the training should continue. + if (dumpPipe.has_value() && dumpPipe->shouldDump()) { + // best effort dump, not waiting for the dump here + std::future fut = std::async( + std::launch::async, [this]() { return this->dumpDebuggingInfo(); }); + } + } + LOG(ERROR) << errorMsg; + + auto& cpp_dumper = get_cpp_trace_dumper(); + if (logCppStackOnUncleanShutdown_ && cpp_dumper.has_value()) { + LOG(INFO) << "Dumping c++ stacktraces:"; + cpp_dumper.value()([](const std::string& line) { LOG(INFO) << line; }); + } + + if (checkDumpSignal && shouldDump_.load()) { + // Store debug info to storage if no other thread does it. (By default to + // local disk) + std::future asyncDebugDump = std::async( + std::launch::async, [this]() { return this->dumpDebuggingInfo(); }); + + // wait for the dump until timeout - log data + waitForFutureOrTimeout( + asyncDebugDump, + std::chrono::milliseconds(waitTimeoutDumpInMilSec_), + "Flight recorder dump in heartbeatMonitor", + false, + true); + } + + if (get_gil_checker() != nullptr) { + auto fut = launchAsyncGilCheck(); + auto kGilCheckTimeout = std::chrono::milliseconds(300); + auto futStatus = fut.wait_for(kGilCheckTimeout); + if (futStatus != std::future_status::ready) { + TORCH_CHECK( + futStatus != std::future_status::deferred, + "Expected the future to have been launched eagerly."); + LOG(ERROR) + << "Could not acquire GIL within 300 ms on exit, possible GIL induced hang"; + } + } else { + LOG(INFO) + << "GIL checker was not registered, perhaps this is a no-python build?"; + } + + // There are two possible cases for the watchdog thread exit: + // Case one: desync report runs quickly, and it follows the step: + // collective timeout -> desync -> exception handling -> destructors + // -> set terminateHeartbeatMonitorThread_ -> notify monitorWakeUpCV_. + // So the code either early returns above or will skip the sleep below. + // Case two: desync might be slow or get stuck. Or we get stuck in + // destructors, we will sleep for some time before calling std::abort() to + // kill the whole process. + if ((terminateProcessGroup_.load() || collectiveDebugInfoMode_.load() || + shouldDump_.load()) && + !terminateHeartbeatMonitorThread_.load()) { + // Leave another two mins for desync report generation or process group + // destroy. + std::this_thread::sleep_for(std::chrono::seconds(heartbeatTimeoutInSec_)); + LOG(INFO) << logPrefix() << "slept for " << heartbeatTimeoutInSec_ + << " waiting for desync report or process group destroy."; + } + + // At this point, we either already sleep for another `heartbeatTimeoutInSec_` + // or the thread has finished. Because we don't want to block the monitor + // thread, so We mark the thread detach and the dump of debug info becomes + // "best effort". If the process exit normally, marking it detach also makes + // sense because we don't really care about dumping the debug info. + + // We already log completion inside the thread, so it may not be necessary to + // check the return value here. We mainly use a future so we can exit early + // if done. + + if (!terminateHeartbeatMonitorThread_.load()) { + // Create a error message reported from MonitorThread, so + // we throw exception and make the whole process to be killed. + // TODO(fduwjj): After having a hang debug wiki, we need to update the wiki + // url here. + if (monitorThreadEnabled_.load()) { + terminateProcess(getNCCLWatchdogTimeoutExitMsg(exitReason)); + } else { + // Ideally we want to merge this one with the above one, but we are going + // to remove the kill switch for monitor thread soon, so we keep this one + // for now. + LOG(ERROR) + << logPrefix() + << "ProcessGroupNCCL monitor thread is disabled, but would have terminated the process" + << "after attempting to dump debug info, due to " << exitReason + << "."; + } + } +} + +void ProcessGroupNCCL::ncclCommWatchdog() { + c10::setThreadName("pt_nccl_watchdg"); + + try { + VLOG(2) << logPrefix() << "Process group watchdog thread started!"; + ncclHeartbeatMonitorThread_ = + std::thread(&ProcessGroupNCCL::heartbeatMonitor, this); + watchdogHandler(); + VLOG(2) << logPrefix() + << "Process group watchdog thread terminated normally"; + } catch (std::exception& e) { + if (std::string(e.what()).find("driver shutting down") != + std::string::npos) { + LOG(INFO) + << logPrefix() + << "main process destroyed cuda before watchdog loop exited, terminating watchdog." + << " (Watchdog caught exception: " << e.what(); + + } else { + // Append error message reported from watchdogHandler + const auto exitMsg = c10::str( + logPrefix(), + "Process group watchdog thread terminated with exception: ", + e.what()); + LOG(ERROR) << exitMsg; + if (C10_LIKELY(rethrowCUDAErrors_) || + !(std::string(e.what()).find("CUDA Error"))) { + // TODO(whc) clean up the rethrow - why is it stored in a class var and + // rethrown? + watchDogException_ = + std::make_exception_ptr(C10_BUILD_ERROR(DistBackendError, exitMsg)); + std::rethrow_exception(watchDogException_); + } + } + } catch (...) { + const auto exitMsg = c10::str( + logPrefix(), + "Process group watchdog thread terminated with exception: unknown"); + LOG(ERROR) << exitMsg; + watchDogException_ = + std::make_exception_ptr(C10_BUILD_ERROR(DistBackendError, exitMsg)); + std::rethrow_exception(watchDogException_); + } +} + +void ProcessGroupNCCL::logWorkStart(WorkNCCL& work) { + if (work.startTraceUpdated_) + return; + + if (terminateProcessGroup_.load() || storeError_) + return; + + work.startTraceUpdated_ = true; + storeError_ = !c10d::traceUpdate( + store_, traceKeyStart_, work.seq_, opTypeToString(work.opType_)); +} + +void ProcessGroupNCCL::logWorkEnd(WorkNCCL& work) { + if (terminateProcessGroup_.load() || storeError_) + return; + + // In case the start of the work hasn't been logged + if (!work.startTraceUpdated_) { + logWorkStart(work); + } + + storeError_ = !c10d::traceUpdate( + store_, traceKeyEnd_, work.seq_, opTypeToString(work.opType_)); +} + +std::string ProcessGroupNCCL::getNCCLWatchdogDebugInfo() { + return retrieveDesyncReport(store_, "NCCL", rank_, size_); +} + +// We want to have both PG ID and global unique ID (guid) for the logging +// prefix. PG ID records how many ProcessGroupNCCL objects were created on a +// specific rank and is a stable index across ranks, which lets users reason +// about, for example, the second PG we initialized on this rank is for FSDP, +// and corresponds with PG ID = 1 on other ranks as well. Unlike PG ID, guid (or +// group name) is a global unique ID across ranks. The guid is either a hash of +// all the ranks in the group or a counter of how many times +// `_process_group_name` is called, essentially it means how many times we +// have PGs users have created. Before using split_group, even if +// we are creating a new sub-PG, all ranks have to call the API at the same +// time, and this makes `group_name` a unique identifier for a group (PG). +std::string ProcessGroupNCCL::createLogPrefix() const { + if (!pg_desc_.empty() && pg_desc_ != "undefined") { + return c10::str( + "[PG ID ", + local_id_, + " PG GUID ", + pg_uid_, + "(", + pg_desc_, + ") Rank ", + rank_, + "] "); + } + return c10::str( + "[PG ID ", local_id_, " PG GUID ", pg_uid_, " Rank ", rank_, "] "); +} + +const std::string& ProcessGroupNCCL::logPrefix() const { + return logPrefix_; +} + +const int& ProcessGroupNCCL::globalRank() const { + static int globalRank = rank_; + return globalRank; +} + +const std::vector& ProcessGroupNCCL::groupRanks() const { + if (options_->global_ranks_in_group.empty() && local_id_ == 0) { + static std::vector globalRanks(size_); + std::iota(globalRanks.begin(), globalRanks.end(), 0); + return globalRanks; + } + return options_->global_ranks_in_group; +} + +void ProcessGroupNCCL::addEphemeralTimeout( + const std::chrono::milliseconds& timeout) { + std::lock_guard timeoutLock(mtxTimeoutExtension_); + ephemeralTimeoutActive_ += timeout; +} + +bool ProcessGroupNCCL::verifyWorkTimeoutForTest( + const c10::intrusive_ptr work, + const std::chrono::milliseconds& timeout) { + // Since collective returns a c10d::Work, we need to cast it to WorkNCCL. + if (auto workNCCL = c10::dynamic_intrusive_pointer_cast(work)) { + // workNCCL is now a c10::intrusive_ptr + return workNCCL->opTimeout_ == timeout; + } + C10_THROW_ERROR( + DistBackendError, "Non c10d::WorkNCCL object returned from collective"); +} + +void ProcessGroupNCCL::watchdogHandler() { + bool done = false; + lastWorkListUpdateTime_ = std::chrono::steady_clock::now(); + auto lastStatusUpdateTime = std::chrono::steady_clock::now(); + std::list completedWorkList; + + while (!done || !terminateProcessGroup_.load()) { + std::unique_lock lock(workMetaListMutex_); + // We busy-poll the work vector every kWatchdogThreadSleepMillis + // milliseconds as long as the atomic is True. + workMetaListCV_.wait_for( + lock, + std::chrono::milliseconds(kWatchdogThreadSleepMillis), + [&]() -> bool { return terminateProcessGroup_.load(); }); + // Bump up heart beat by one. + heartbeat_++; + +// Some versions of GLOG support less-spammy version of LOG_EVERY_MS +// in which case we don't want to spam the logs. +#ifdef LOG_EVERY_MS + // Log the progress of this PG periodically + C10_LOG_EVERY_MS(INFO, kWorkStatusUpdatePeriodMs) << c10::str( + logPrefix(), + "NCCL Work update periodically: ", + "last enqueued NCCL work: ", + pgStatus_->lastEnqueuedSeq, + ", last completed NCCL work: ", + pgStatus_->lastCompletedSeq, + "."); +#endif + auto logger = ::c10d::C10dLogger::getLogger(); + if (logger && + computeDeltaMS( + lastStatusUpdateTime, std::chrono::steady_clock::now()) >= + kWorkStatusUpdatePeriodMs) { + ::c10d::C10dLoggingData data; + // logging integers + data.integers["pg_id"] = local_id_; + data.integers["rank"] = rank_; + data.integers["global_rank"] = globalRank(); + data.integers["last_enqueued_work"] = pgStatus_->lastEnqueuedSeq; + data.integers["last_started_work"] = pgStatus_->lastStartedSeq; + data.integers["last_completed_work"] = pgStatus_->lastCompletedSeq; + data.integers["last_enqueued_numel_in"] = pgStatus_->lastEnqueuedNumelIn; + data.integers["last_enqueued_numel_out"] = + pgStatus_->lastEnqueuedNumelOut; + data.integers["last_completed_numel_in"] = + pgStatus_->lastCompletedNumelIn; + data.integers["last_completed_numel_out"] = + pgStatus_->lastCompletedNumelOut; + // logging strings + data.strings["last_enqueued_work_name"] = pgStatus_->lastEnqueuedWorkName; + data.strings["last_started_work_name"] = pgStatus_->lastStartedWorkName; + data.strings["last_completed_work_name"] = + pgStatus_->lastCompletedWorkName; + data.strings["pg_name"] = pg_uid_; + data.strings["pg_desc"] = pg_desc_; + logger->log(data); + lastStatusUpdateTime = std::chrono::steady_clock::now(); + } + + for (auto it = workMetaList_.begin(); it != workMetaList_.end(); + /* no increment */) { + auto& work = *it; + // When terminateProcessGroup_ is true, communicators have already been + // aborted, So cannot check exception based on them. But watchdog needs to + // finish the check for the works that have already been enqueued to + // workMetaList_ + if (!terminateProcessGroup_.load()) { + work.checkAndSetException(); + } + bool timedOut = work.checkTimeout(); + + // If work hits an exception (either an error or timeout) + if (work.exception()) { + // log as soon as exception is detected + LOG(ERROR) << c10::str( + logPrefix(), + "Exception (either an error or timeout) detected by watchdog at work: ", + work.seq_, + ", last enqueued NCCL work: ", + pgStatus_->lastEnqueuedSeq, + ", last completed NCCL work: ", + pgStatus_->lastCompletedSeq, + "."); + // try to notify other ranks via global TCPStore to dump the flight + // recorder when a collective timeout or exception happens. Flight + // recorder behavior is independent of desync Debug. + if (dumpOnTimeoutOrEx_) { + try { + auto rank = globalRank(); + auto vec = std::vector( + reinterpret_cast(&rank), + reinterpret_cast(&rank) + sizeof(rank)); + globalStore_->set(std::string(EXCEPTION_DUMP), vec); + if (!shouldDump_.load()) { + LOG(ERROR) + << logPrefix() + << "Broadcasting flight-recorder dump signal to other processes via TCPStore."; + } + // signal the monitor thread on PG0 to start dumping + shouldDump_.store(true); + if (sleepAfterException_) { + // This sleep is used to give time for dumping before throwing + // exception + std::this_thread::sleep_for( + std::chrono::seconds(heartbeatTimeoutInSec_)); + LOG(INFO) << logPrefix() << "slept for " << heartbeatTimeoutInSec_ + << " giving time for flight recorder dumps to finish."; + } + } catch (const std::exception& e) { + LOG(ERROR) << logPrefix() + << "Failed to set dump signal in tcpstore. " + << "Error: " << e.what(); + } + } + + if (SHOULD_CLEAN_UP(asyncErrorHandling_)) { + // Abort work and corresponding communicators + work.abort(); + // PG level abort, which would abort all other communicators on this + // rank + abort(); + } + + // Report desync state in case of timeout + if (timedOut) { + LOG(ERROR) << c10::str( + logPrefix(), + "Timeout at NCCL work: ", + work.seq_, + ", last enqueued NCCL work: ", + pgStatus_->lastEnqueuedSeq, + ", last completed NCCL work: ", + pgStatus_->lastCompletedSeq, + "."); + if (desyncDebug_) { + try { + collectiveDebugInfoMode_.store(true); + auto desyncMsg = getNCCLWatchdogDebugInfo(); + LOG(ERROR) << logPrefix() << desyncMsg; + } catch (const std::exception& e) { + LOG(ERROR) + << logPrefix() + << "Failed to retrieve TORCH_NCCL_DESYNC_DEBUG report. " + << " Please file an issue. Error: " << e.what(); + } catch (...) { + LOG(ERROR) + << logPrefix() + << "Failed to rerieve TORCH_NCCL_DESYNC_DEBUG report with unknown error." + << " Please file an issue."; + } + } + } + // Throw exception + work.handleException(asyncErrorHandling_); + } + + // Work status logging for desync debug + if (desyncDebug_) { + if (work.isStarted()) { + logWorkStart(work); + } + if (work.isCompleted()) { + logWorkEnd(work); + } + } + + // a work could be started but not completed, so we should not update + // lastStartedSeq and lastStartedOpName if the work state is checked + // multiple times after the start + if (pgStatus_->lastStartedSeq < static_cast(work.seq_) && + work.isStarted()) { + pgStatus_->lastStartedSeq = work.seq_; + pgStatus_->lastStartedWorkName = opTypeToString(work.opType_); + } + + // Clean up completed work + if (work.isCompleted()) { + { + // Reset the timeout and first work if the work is completed. + std::lock_guard timeoutLock(mtxTimeoutExtension_); + if (work.ownedEphermeralTimeout_.count() > 0) { + ephemeralTimeoutActive_ -= work.ownedEphermeralTimeout_; + ephemeralTimeoutInflight_ -= work.ownedEphermeralTimeout_; + } + } + pgStatus_->lastCompletedSeq = work.seq_; + pgStatus_->lastCompletedWorkName = opTypeToString(work.opType_); + pgStatus_->lastCompletedNumelIn = work.numelIn_; + pgStatus_->lastCompletedNumelOut = work.numelOut_; + NCCLTraceBuffer::get()->retire_id(work.trace_id_, true); + if (onCompletionHook_) { + // Move Work object to completedWorkList_ to be consumed by the hook + // thread + { + const std::lock_guard lock(completedWorkListMutex_); + completedWorkList_.splice( + completedWorkList_.end(), workMetaList_, it++); + } + completedWorkListCV_.notify_one(); + } else { + it = workMetaList_.erase(it); + lastWorkListUpdateTime_ = std::chrono::steady_clock::now(); + } + at::cuda::CUDAGraph::dec_pending_event_queries(); + } else { + // Increment the iterator if the current WorkNCCL object is not + // completed. + ++it; + } + // Increment heartbeat after each work processed, + // in case processing is slowed down (but not hung) by cuda api contention + heartbeat_++; + } + done = workMetaList_.empty(); + } +} + +void ProcessGroupNCCL::runHookLoop() { + c10::setThreadName("pt_nccl_runhook"); + + bool done = false; + while (!done || !terminateProcessGroup_.load()) { + std::unique_lock lock(completedWorkListMutex_); + // We busy-poll the work vector every kWatchdogThreadSleepMillis + // milliseconds as long as the atomic is True. + completedWorkListCV_.wait_for( + lock, + std::chrono::milliseconds(kWatchdogThreadSleepMillis), + [&]() -> bool { + return !completedWorkList_.empty() || terminateProcessGroup_.load(); + }); + + try { + for (auto it = completedWorkList_.begin(); it != completedWorkList_.end(); + /* no increment */) { + const WorkNCCL& work = *it; + // Hook might grab GIL, unlock first to prevent deadlock + lock.unlock(); + + auto timeStarted = + std::chrono::system_clock::now() + + std::chrono::duration_cast( + work.workStartTime_ - std::chrono::steady_clock::now()); + onCompletionHook_(std::make_shared( + work.retrieveOpType(), // OpType + work.getSequencenumber(), // seq + timeStarted, // timeStarted + std::chrono::system_clock::now(), // timeFinished + std::chrono::duration( + work.getDuration()) // activeDuration + )); + + lock.lock(); + it = completedWorkList_.erase(it); + } + } catch (std::exception& e) { + if (std::string(e.what()).find("driver shutting down") != + std::string::npos) { + LOG(INFO) + << logPrefix() + << "main process destroyed cuda before runHookLoop exited, terminating runHookLoop." + << " (runHookLoop caught exception: " << e.what(); + + } else { + // PythonOnCompletionHook has already extracted Python exception message + // and wrapped it with a cpp one. So we no longer need to acquire GIL + // here. + const auto errorStr = c10::str( + "Caught exception on rank ", + rank_, + " while running onCompletion hook for ProcessGroupNCCL: ", + e.what(), + ". Aborting all communicators."); + + // No need to call abort() on WorkNCCL here as that collective has + // already finished successfully at this point. We just need to abort + // the process Abort all NCCL Communicators on this ProcessGroupNCCL + // instance. + abort(errorStr); + } + } + + // Lock is still acquired at this point + done = completedWorkList_.empty(); + } +} + +std::exception_ptr ProcessGroupNCCL::WorkNCCL::checkForNCCLErrors() { + return checkForNCCLErrorsInternal(ncclComm_); +} + +std::exception_ptr ProcessGroupNCCL::checkForNCCLErrors( + std::shared_ptr& ncclComm) { + return checkForNCCLErrorsInternal(ncclComm); +} + +std::exception_ptr ProcessGroupNCCL::checkForNCCLErrorsInternal( + std::shared_ptr& ncclComm) { + // Prioritize commFailureReason over checkForNcclError() result if + // commFailureReason is set. + auto commFailureReason = ncclComm->getNcclCommFailureReason(); + if (commFailureReason != std::nullopt) { + return std::make_exception_ptr(C10_BUILD_ERROR( + DistBackendError, + c10::str( + "NCCL communicator encountered error set by ProcessGroupNCCL: ", + *commFailureReason))); + } + ncclResult_t ncclAsyncErr = ncclComm->checkForNcclError(); + // When nonblocking mode is enabled by TORCH_NCCL_USE_COMM_NONBLOCKING, + // ncclInProgress could be returned when there are pending NCCL calls. + // In this case, no exception should be thrown +#ifdef NCCL_HAS_COMM_NONBLOCKING + // ncclInProgress is defined only if NCCL_HAS_COMM_NONBLOCKING is defined + if (ncclAsyncErr != ncclSuccess && ncclAsyncErr != ncclInProgress) { +#else + if (ncclAsyncErr != ncclSuccess) { +#endif + return std::make_exception_ptr(C10_BUILD_ERROR( + DistBackendError, + "NCCL error: " + ncclGetErrorWithVersion(ncclAsyncErr) + "\n" + + getNcclErrorDetailStr(ncclAsyncErr))); + } + + return nullptr; +} + +void ProcessGroupNCCL::broadcastUniqueNCCLID( + ncclUniqueId* ncclID, + bool isSingleP2POp, + const std::string& p2pKey, + int p2pRank) { + // For collective operations: + // For every NCCL communicator that we create we need to broadcast + // a unique ID from rank 0 to all other ranks. This broadcast is + // done by rank 0 setting a key in the store and all other ranks + // retrieving the contents of that key. A single process group + // may create multiple NCCL communicators, so we use a sequence + // number to differentiate between them. + // For single point-to-point operations: + // The sequence number will only be increased on 2 out of all the + // processes in a Process Group. So all following collective + // operations will see different sequence numbers which will cause + // runtime errors. To avoid that, use the src:target pair instead + // of sequence number for p2p communications. + + std::string storeKey; + if (!isSingleP2POp) { + storeKey = std::to_string(ncclCommCounter_++); + } else { + storeKey = p2pKey; + } + if (rank_ == 0 || (isSingleP2POp && p2pRank == 0)) { + auto vec = std::vector( + reinterpret_cast(ncclID), + reinterpret_cast(ncclID) + NCCL_UNIQUE_ID_BYTES); + store_->set(storeKey, vec); + } else { + try { + auto vec = store_->get(storeKey); + TORCH_CHECK_WITH( + DistBackendError, + vec.size() == NCCL_UNIQUE_ID_BYTES, + "Invalid size for ncclUniqueId"); + std::memcpy(ncclID, vec.data(), vec.size()); + } catch (const std::exception& e) { + std::string exceptionMsg = c10::str( + "[", + rank_, + "] is setting up NCCL communicator and " + "retrieving ncclUniqueId from [0] via c10d key-value store by key '", + storeKey, + "', but store->get('", + storeKey, + "') got error: "); + C10_THROW_ERROR( + DistBackendError, + exceptionMsg + e.what() + + ". This may indicate a possible application crash on rank 0 or a network set up issue."); + } catch (...) { + C10_THROW_ERROR( + DistBackendError, + c10::str( + "Unknown exception while [", + rank_, + "] is setting up NCCL communicator and " + "retrieving ncclUniqueId from [0] via c10d key-value store by key '", + storeKey, + "'", + ". This may indicate a possible application crash on rank 0 or a network set up issue.")); + } + } +} + +void ProcessGroupNCCL::destroyNCCLComms(const std::string& devNCCLCommMapKey) { + std::lock_guard lock(mutex_); + if (devNCCLCommMap_.find(devNCCLCommMapKey) == devNCCLCommMap_.end()) { + TORCH_INTERNAL_ASSERT( + false, + "Expected to find key ", + devNCCLCommMapKey, + " in NCCL communicator map."); + } + std::shared_ptr& ncclComm = devNCCLCommMap_[devNCCLCommMapKey]; + // ncclCommDestroy(comm->getNcclComm()) results in segfault when PG is being + // destroyed, so using ncclCommAbort here. + ncclComm->ncclCommAbort(); + // Remove communicators from the cache. + devNCCLCommMap_.erase(devNCCLCommMapKey); + // Clear used device indices. + usedDeviceIdxs_.clear(); + + ncclCommDevIdxMapMutex.lock(); + ncclCommDevIdxMap.erase(ncclComm); + ncclCommDevIdxMapMutex.unlock(); +} + +std::shared_ptr ProcessGroupNCCL::getNCCLComm( + const std::string& deviceKey, + at::Device& device, + OpType opType, + int p2pRank, + bool isSendRecvSelf) { + // Sanity check + if (deviceKey.empty()) { + C10_THROW_ERROR( + DistBackendError, + "Not able to create/get the NCCL Communicator since " + "the GPU devices are not known"); + } + if (bound_device_id_) { + if (*bound_device_id_ != device) { + LOG(ERROR) << logPrefix() << "Tensor found on device " << device + << " but backend constrained to " << *bound_device_id_; + C10_THROW_ERROR( + DistBackendError, + "Attempt to perform collective on tensor not on device passed to init_process_group"); + } + } + + usedDeviceIdxs_.insert(device.index()); + + { + std::lock_guard lock(mutex_); + if (devNCCLCommMap_.find(deviceKey) != devNCCLCommMap_.end()) { + // Reuse the cached communicator if there is one. + return devNCCLCommMap_[deviceKey]; + } + } + + // NCCL communicator not cached, create a new entry + std::shared_ptr ncclComm; + + // Create the unique NCCL ID and broadcast it + ncclUniqueId ncclID; + + // reset log prefix to include group_desc + logPrefix_ = createLogPrefix(); + +#ifdef NCCL_COMM_DESCRIPTION + // Pass process group name and description to NCCL communicator + std::string commDesc = pg_desc_ + ':' + pg_uid_; + options_->config.commDesc = strdup(commDesc.c_str()); +#endif + + // For batch_isend_irecv, ncclGroupStart() would be called upfront + bool batchP2P = ncclActiveGroupCounter_ > 0; + bool singleP2POp = isP2POp(opType, batchP2P); + + // Get the device index + auto deviceIndex = device.index(); + at::cuda::OptionalCUDAGuard gpuGuard(device); + + // [Group Start/End Note] This is used to ensure that nccl communicator will + // be created before communication primitives are called. Let's look at this + // example: Using the batch_isend_irecv to send a tensor to a target process. + // On the sender side, the corresponding underlying NCCL calls will look like + // ncclGroupStart() // This is in batch_isend_irecv + // ncclCommInitRank() // Inside NCCLComm::create + // ncclSend() + // ncclGroupEnd() // This is in batch_isend_irecv + // With this pattern, the nccl communicator will be created in the last + // ncclGroupEnd which means when ncclSend is processed, the passed + // communicator argument is NULL which will lead to runtime error. So we need + // to "close" all active nccl groups to ensure nccl communicator is actually + // created before encountering any communication calls. This is why we need + // the following for loop. + for (const auto i : c10::irange(ncclActiveGroupCounter_)) { + (void)i; + // comms have not been initiated yet, so can only check in blocking-way + C10D_NCCL_CHECK(ncclGroupEnd(), std::nullopt); + } + + // GPU world size and GPU rank + int numRanks, rank; + + if (!singleP2POp) { + // Collective, all-to-all, or batch P2P + numRanks = getSize(); + rank = getRank(); + } else if (isSendRecvSelf) { + // Same process send and recv. + numRanks = 1; + rank = 0; + } else { + // For single point-to-point operation, there are only 2 processes + // involved so the GPU rank is either 0 or 1. + numRanks = 2; + rank = p2pRank; + } + +#ifdef NCCL_HAS_COMM_SPLIT + if (options_->split_from) { + TORCH_CHECK( + options_->split_color != 0, + "Must specify a non-zero color when splitting"); + // Find a valid, healthy communicator to split from if possible. + std::lock_guard lock(options_->split_from->mutex_); + auto& other_comms = options_->split_from->devNCCLCommMap_; + auto dit = other_comms.find(getKeyFromDevice(device)); + if (dit != other_comms.end()) { + auto& parentComm = dit->second; + if (parentComm != nullptr && !parentComm->isAborted()) { + ncclComm = NCCLComm::split( + parentComm.get(), + options_->split_color, + rank, + options_->config, + options_->global_ranks_in_group); + } + } + } +#endif + + // To simplify conditional nesting, just create the ncclComms[i] + // entry if it hasn't been yet rather than untangling the + // conditions that might have resulted in a split above. + if (!ncclComm) { + if (getCvarBool(TORCH_NCCL_BCAST_UNIQUEID, true) && !isSendRecvSelf) { + // For point-to-point communication, lower rank of the two will get unique + // id. + if (rank_ == 0 || (singleP2POp && p2pRank == 0)) { + C10D_NCCL_CHECK(ncclGetUniqueId(&ncclID), std::nullopt); + } + + // Broadcast so that each process can have a unique NCCL ID + auto timeStarted = std::chrono::steady_clock::now(); + broadcastUniqueNCCLID(&ncclID, singleP2POp, deviceKey, p2pRank); + auto timerDeltaMs = + std::chrono::duration_cast>( + std::chrono::steady_clock::now() - timeStarted) + .count() * + 1000; + LOG(INFO) << logPrefix() + << "ProcessGroupNCCL broadcast unique ID through store took " + << timerDeltaMs << " ms"; + } + +#ifdef NCCL_HAS_COMM_NONBLOCKING + ncclComm = NCCLComm::create(numRanks, rank, ncclID, options_->config); +#else + ncclComm = NCCLComm::create(numRanks, rank, ncclID); +#endif + } + + // Creates the NCCL streams + bool force_high = getCvarBool(TORCH_NCCL_HIGH_PRIORITY, false); + auto streamVal = at::cuda::getStreamFromPool( + options_->is_high_priority_stream || force_high); + + { + std::lock_guard lock(mutex_); + inInitializationCommMap_.emplace(deviceKey, ncclComm); + } + + NCCLTraceBuffer::get()->record_pg_ranks( + std::make_tuple(pg_uid_, pg_desc_), groupRanks()); + + RECORD_PARAM_COMMS( + 0, // seq + std::make_tuple(pg_uid_, pg_desc_), // PG name tuple + rank, // rank + "init", // collective name + 0, // inNelems + 0, // outNelems + at::kByte, // dType + std::vector(), // inSplitSizes + std::vector(), // outSplitSizes + globalRankStart, // globalRankStart + globalRankStride, // globalRankStride + size_); // worldSize + + LOG(INFO) << logPrefix() << "ProcessGroupNCCL created ncclComm_ " + << ncclComm->ncclComm_ << " on CUDA device: " << deviceIndex; + + // At this point NCCL should have been initialized, hence we can accurately + // get the env value even if NCCL sets it by reading from nccl.conf file + LOG(INFO) << logPrefix() + << "NCCL_DEBUG: " << getCvarString({"NCCL_DEBUG"}, "N/A"); + + // See [Group Start/End Note] + for (const auto i : c10::irange(ncclActiveGroupCounter_)) { + (void)i; + C10D_NCCL_CHECK(ncclGroupStart(), std::nullopt); + } + + ncclStreams_.emplace(deviceKey, std::move(streamVal)); + + // Note: these events are created with the (default) cudaEventDisableTiming + // flag This flag provides the best performance when used with + // cudaStreamWaitEvent() and cudaEventQuery(). Since we here don't measure the + // performance using cudaEvent, this should be set. + // TODO(kwen2501): is ncclEvents_ used anywhere else? + ncclEvents_.emplace(deviceKey, at::cuda::CUDAEvent(cudaEventDisableTiming)); + + // Move the NCCL resource to cache + auto it = inInitializationCommMap_.find(deviceKey); + // A previous thread could've already removed devicesKey from + // inInitializationCommMap_ and added it to devNCCLCommMap_ + if (it != inInitializationCommMap_.end()) { + devNCCLCommMap_.emplace(deviceKey, std::move(it->second)); + inInitializationCommMap_.erase(deviceKey); + + // Now ncclComms are fully initialized. + // Register all active CUDA memory segments in cache allocator to + // the new NCCL communicators + if (useTensorRegisterAllocatorHook_) { + auto snapshot = c10::cuda::CUDACachingAllocator::snapshot(); + // Register the segment to a new NCCL communicator if on the same device + for (const auto& segmentInfo : snapshot.segments) { + TORCH_INTERNAL_ASSERT( + segmentInfo.device == device.index(), + "Mismatch between CUDA memory segment device and current device"); + ncclComm->registerSegment( + reinterpret_cast(segmentInfo.address), + segmentInfo.total_size); + } + } + // Record the mapping between ncclComm and device index so that later + // register hook can register a newly allocated segment to communicators + // on the same device. + // NOTE: we need remove the communicator from this map when it is + // destroyed, otherwise may register onto an invalid communicator. + ncclCommDevIdxMapMutex.lock(); + ncclCommDevIdxMap.emplace(ncclComm, device.index()); + ncclCommDevIdxMapMutex.unlock(); + } -constexpr auto kProcessGroupDefaultTimeout = - std::chrono::milliseconds(30 * 60 * 1000); + it = devNCCLCommMap_.find(deviceKey); + TORCH_INTERNAL_ASSERT( + it != devNCCLCommMap_.end(), "Communicators not populated in cache!"); -namespace c10d { + return it->second; +} -// ProcessGroup is a base class that captures collective and point to -// point communication in a fixed set of processes. -// -// The functions specified in the class below describe the API alone; -// implementations are provided in subclasses. -// -// Every function that performs I/O is executed asynchronously by a -// thread pool owned by the ProcessGroup (by default). They return an -// object that can be used to wait for completion or error. -// -// The ProcessGroup can instantiate subgroups with fewer or an equal -// number of members. Implementations must take care that multiple -// process groups can be used in parallel and synchronize accordingly. -// -// The ProcessGroup assumes a fixed set of processes. If the set -// changes, existing instances must be destructed and instantiation -// and initialization must start from scratch. For members of the -// process group to find each other (referred to as rendezvous from -// hereon) -// -class TORCH_API ProcessGroup : public torch::CustomClassHolder { - public: - enum BackendType : uint8_t { - UNDEFINED = 0, - GLOO = 1, - NCCL = 2, - UCC = 3, - MPI = 4, - CUSTOM = 5, - XCCL = 6, - }; +uint64_t ProcessGroupNCCL::getCommSplitCounter() const { + uint64_t ret = 0; + for (const auto& i : devNCCLCommMap_) { + auto& ncclComm = i.second; + ret += ncclComm->getCommSplitCounter(); + } + return ret; +} - static std::string backendTypeToString(const BackendType& type) { - switch (type) { - case BackendType::GLOO: - return "gloo"; - case BackendType::NCCL: - return "nccl"; - case BackendType::XCCL: - return "xccl"; - case BackendType::UCC: - return "ucc"; - case BackendType::MPI: - return "mpi"; - case BackendType::UNDEFINED: - return "undefined"; - case BackendType::CUSTOM: - return "custom"; - default: - TORCH_CHECK(false, "THis should never happen!"); - } - }; +namespace { - static BackendType strToBackendType(const std::string& backend) { - if (backend == "undefined") { - return BackendType::UNDEFINED; - } else if (backend == "gloo") { - return BackendType::GLOO; - } else if (backend == "nccl") { - return BackendType::NCCL; - } else if (backend == "xccl") { - return BackendType::XCCL; - } else if (backend == "ucc") { - return BackendType::UCC; - } else if (backend == "mpi") { - return BackendType::MPI; +// Check validity of tensor +void check_gpu_single_tensor( + const at::Tensor& tensor, + const bool p2p = false // whether operation is a P2P operation +) { + if (!tensor.is_cuda() || tensor.is_sparse()) { + C10_THROW_ERROR(ValueError, "Tensors must be CUDA and dense"); + } + // Skip the following requirements for P2P operations + if (!tensor.is_contiguous(tensor.suggest_memory_format())) { + if (p2p) { + TORCH_WARN_ONCE( + "Detected non-contiguous tensor in P2P operations. It is user " + "responsibility to guarantee that source and destination tensors have " + "the same contiguity format."); } else { - return BackendType::CUSTOM; + C10_THROW_ERROR(ValueError, "Tensors must be contiguous"); } - }; + } +} + +// Checks that all `tensors' have the same type and shape and reside on the same +// GPU. +// TODO: test_c10d_nccl.py should consider adding tests for the error conditions +// here, ie, that deliberately pass invalid tensors and check the right +// exception is thrown. The "Expected list of tensors on the same device" +// condition may be a challenge because the test would need to pass tensors on +// different devices in the same process. +int64_t check_gpu_tensors_same_device(const std::vector& tensors) { + if (tensors.size() == 0) { + C10_THROW_ERROR(ValueError, "Tensor list must be nonempty"); + } + + const auto& first = tensors.front(); - // Not used, set for backwards compatibility and only used for TypeDef in - // Ops.cpp - explicit ProcessGroup(int rank, int size); + int64_t total_numel = 0; + for (const auto& t : tensors) { + if (!t.is_cuda() || t.is_sparse()) { + C10_THROW_ERROR(ValueError, "Tensors must be CUDA and dense"); + } + if (t.scalar_type() != first.scalar_type()) { + C10_THROW_ERROR(TypeError, "Tensors must have identical type"); + } + if (!t.is_non_overlapping_and_dense()) { + C10_THROW_ERROR(ValueError, "Tensors must be non-overlapping and dense"); + } + // If we're in this function, the user called a _coalesced collective + // on a set of tensors with potentially different sizes and strides. + // Therefore, we don't check for matching sizes and strides, + // but we do double-check tensors are on the same device. + TORCH_CHECK_WITH( + ValueError, + t.get_device() == tensors[0].get_device(), + "Expected list of tensors on the same device"); + total_numel += t.numel(); + } - explicit ProcessGroup( - const c10::intrusive_ptr<::c10d::Store>& store, - int rank, - int size); - ~ProcessGroup() override; + return total_numel; +} - int getRank() const { - return rank_; +bool check_same_size(const std::vector& input_tensors) { + for (const auto& input_tensor : input_tensors) { + if (!input_tensors[0].is_same_size(input_tensor)) { + return false; + } } + return true; +} + +} // namespace - int getSize() const { - return size_; +c10::intrusive_ptr ProcessGroupNCCL::initWork( + at::Device& device, + int rank, + OpType opType, + const char* profilingTitle, + const std::vector& inputs, + const std::vector& outputs, // TODO(kwen2501): necessary? + bool record) { + auto r = c10::make_intrusive( + pg_uid_, + pg_desc_, + device, + rank, + opType, + seqCollective_, + profilingTitle, + profilingTitle != nullptr ? std::optional>(inputs) + : std::nullopt, + desyncDebug_, + enableTiming_.load(), + cudaEventCacheEnabled_.load(), + dist_debug_level_); + if (record) { + bool isP2P = isP2POp(opType); + // Ideally record every work that we enqueue, rather than every work we + // create. + // - at the time of this PR we do not currently enqueue every created work + // - but it is unsafe to steal refs to start/end cuda events from Works that + // may go out of scope before flight recorder has retired them, + // so we must ensure that any work that is initialized via initWork will + // be enqueued + // - initially, moved record() into workEnqueue(), but found that makes it + // hard to get access to profilingTitle, + // inputs, and outputs for metadata recording, and we don't want to attach + // these objects to the Work becuase it has implications for keeping those + // tensors alive longer and adds overhead when copying Work objects + // between threads + r->trace_id_ = NCCLTraceBuffer::get()->record( + local_id_, + std::make_tuple(pg_uid_, pg_desc_), + seqCollective_, + seqP2P_, + op_id_, + profilingTitle ? profilingTitle : "", + inputs, + outputs, + r->ncclStartEvent_.get(), + r->ncclEndEvent_.get(), + options_->timeout, + pgStatus_, + isP2P); } + return r; +} + +// TODO(kwen2501): deprecate +std::vector ProcessGroupNCCL::WorkNCCL::result() { + return *outputs_; +} + +c10::intrusive_ptr ProcessGroupNCCL::WorkNCCL:: + getFuture() { + return future_; +} - // Returns an unique opaque ID of this process group object. - int64_t getID() const { - return reinterpret_cast(this); +float ProcessGroupNCCL::WorkNCCL::getDuration() const { + TORCH_CHECK(timingEnabled_, "getDuration only works if timing was enabled"); + TORCH_CHECK( + ncclStartEvent_, + "getDuration only works if ncclStartEvents_ is populated, true if timing enabled"); + TORCH_CHECK( + ncclEndEvent_, + "getDuration only works if ncclEndEvents_ is populated, which should always be true"); + return ncclStartEvent_->elapsed_time(*ncclEndEvent_); +} + +uint64_t ProcessGroupNCCL::WorkNCCL::getSequencenumber() const { + return seq_; +} + +void ProcessGroupNCCL::assignTimeoutToWork( + const c10::intrusive_ptr& work, + const c10::intrusive_ptr& option) { + std::chrono::milliseconds timeout = option->timeout; + std::lock_guard timeoutLock(mtxTimeoutExtension_); + if (ephemeralTimeoutActive_.count() > 0) { + timeout += ephemeralTimeoutActive_; } + work->opTimeout_ = timeout; + work->ownedEphermeralTimeout_ = + ephemeralTimeoutActive_ - ephemeralTimeoutInflight_; + ephemeralTimeoutInflight_ = ephemeralTimeoutActive_; +} - // Returns an unique opaque ID of a backend for the specific backend type - // that can correlate with this process group's collectives. - int64_t getBackendID(BackendType backend_type) const { - return reinterpret_cast(getBackend(backend_type).get()); +void ProcessGroupNCCL::workEnqueue( + c10::intrusive_ptr work) { + if (!terminateProcessGroup_.load()) { + std::lock_guard lock(workMetaListMutex_); + // Avoid view tensors to be processed in cleanup thread. + // View tensors' destruction invokes autograd_meta, which + // needs to be destructed in user thread. Otherwise will + // get deadlock. Here we enqueue work without outputs_. + workMetaList_.emplace_back(*work); + // update the PG status related to the last enqueued work + pgStatus_->lastEnqueuedSeq = work->seq_; + pgStatus_->lastEnqueuedWorkName = opTypeToString(work->opType_); + pgStatus_->lastEnqueuedNumelIn = work->numelIn_; + pgStatus_->lastEnqueuedNumelOut = work->numelOut_; + lastWorkListUpdateTime_ = std::chrono::steady_clock::now(); } +} - virtual const std::string getBackendName() const { - return backendTypeToString(backendType_); - }; +ProcessGroupNCCL::Options::Options(bool is_high_priority_stream) + : Backend::Options(NCCL_BACKEND_NAME, kProcessGroupNCCLDefaultTimeout), + is_high_priority_stream(is_high_priority_stream) {} - BackendType getBackendType() const { - return backendType_; - }; +static constexpr int CoalActive = 0x01, CoalColl = 0x02, CoalP2P = 0x04; - virtual void startCoalescing(c10::DeviceType deviceType) { - // only nccl has implemented startCoalescing so only execute for nccl - // backends - auto backend = getBackend(deviceType); - backend->startCoalescing(); +void ProcessGroupNCCL::startCoalescing() { + // Other collective ops bump seq_ before creating a work. Thus, if coalesced + // ops bump seq_ only after initing a work they will collide with (reuse) the + // seq_ of the last non-coalesced collective. Previously, seq_ was bumped + // inside endCoalescing, but before initWork. Since we now record individual + // ops from a coalesce group into the flight recorder, we want to have the + // same seq_ for those ops and its 'endCoalescing' op. Hence we bump during + // start, which has one minor downside- we burn a seq_ if someone ever does a + // 'start' and 'end' coalescing region without doing an operation inbetween. + + // Don't bump op_id_ here, because startCoalescing isn't a logical operation. + // Bump it for each logical op inside the coalescing group. + if (coalescing_state_ & CoalP2P) { + seqP2P_++; + } else { + seqCollective_++; } - virtual c10::intrusive_ptr endCoalescing(c10::DeviceType deviceType) { - // only nccl has implemented endCoalescing so only execute for nccl - // backends - auto backend = getBackend(deviceType); - auto work = backend->endCoalescing(); - return work; + coalescedDevice_.set_index(-1); + coalescedComm_ = nullptr; + coalescing_state_ |= CoalActive; + groupStart(); +} + +// `optype` is for specifying a composite optype, such as ALLGATHER and +// REDUCE_SCATTER +c10::intrusive_ptr ProcessGroupNCCL::endCoalescing(OpType optype) { + if (coalescedComm_ == nullptr) { + // There is no actual work being coalesced, return here + groupEnd(); + coalescing_state_ = 0; + return nullptr; + } + TORCH_CHECK( + coalescedDevice_.index() >= 0, + "Somthing went wrong. Did you call end_coalescing before start_coalescing?"); + + // `coalescedComm_` should have same set of comms across collectives + auto comm = coalescedComm_; + // `coalescedDevice_` should have same set of devices across collectives + auto device = coalescedDevice_; + + // `getKeyFromDevice` is how we get keys for both collectives and batch P2P + const auto key = getKeyFromDevice(device); + auto ncclStream = ncclStreams_.at(key); + + // Create Work object + c10::cuda::CaptureStatus capture_status = + c10::cuda::currentStreamCaptureStatusMayInitCtx(); + bool enqueue = + (coalescing_state_) && capture_status == c10::cuda::CaptureStatus::None; + auto work = + initWork(device, rank_, optype, "nccl:coalesced", {}, {}, enqueue); + work->ncclComm_ = comm; + work->blockingWait_ = blockingWait_; + work->avoidRecordStreams_ = avoidRecordStreams_; + work->store_ = store_; + assignTimeoutToWork(work, options_); + + // Record start before ncclGroupEnd + if (work->timingEnabled_) { + work->ncclStartEvent_->record(ncclStream); + } + + if (nccl_use_nonblocking()) { + groupEndNonblocking(comm); + } else { + groupEnd(); + } + + // Record end after ncclGroupEnd + // TODO(eqy): is this still necessary if avoidRecordStreams_ is set? + work->ncclEndEvent_->record(ncclStream); + + if (avoidRecordStreams_) { + // other functions expect an initialized ptr if avoidRecordStreams_ is set + work->stashed_for_allocator_safety_ = + std::make_shared>(); + } + + // Notify graphs before we check the capture status preemptively + at::cuda::CUDAGraph::inc_pending_event_queries(); + + if (enqueue) { + workEnqueue(work); + } else { + at::cuda::CUDAGraph::dec_pending_event_queries(); } - virtual c10::intrusive_ptr broadcast( - std::vector& tensors, - const BroadcastOptions& opts = BroadcastOptions()) { - static auto op = - c10::Dispatcher::singleton() - .findSchemaOrThrow("c10d::broadcast_", "") - .typed< - std::tuple, c10::intrusive_ptr>( - at::TensorList, - const c10::intrusive_ptr<::c10d::ProcessGroup>&, - int64_t, - int64_t, - bool, - int64_t)>(); - // It's awakward to unbox the opts here and box them again in the custom C++ - // op. But it's also complicated to make opts as a CustomClassHolder. Leave - // it as it is now. - return std::get<1>(op.call( - tensors, - c10::intrusive_ptr::unsafe_reclaim_from_nonowning(this), - opts.rootRank, - opts.rootTensor, - opts.asyncOp, - opts.timeout.count())); - } - - virtual c10::intrusive_ptr allreduce( - std::vector& tensors, - const AllreduceOptions& opts = AllreduceOptions()) { - static auto op = - c10::Dispatcher::singleton() - .findSchemaOrThrow("c10d::allreduce_", "") - .typed< - std::tuple, c10::intrusive_ptr>( - at::TensorList, - const c10::intrusive_ptr<::c10d::ProcessGroup>&, - const c10::intrusive_ptr<::c10d::ReduceOp>&, - const std::optional& sparse_indices, - int64_t)>(); - - return std::get<1>(op.call( - tensors, - c10::intrusive_ptr::unsafe_reclaim_from_nonowning(this), - c10::make_intrusive(opts.reduceOp), - opts.sparseIndices, - opts.timeout.count())); - } - - virtual c10::intrusive_ptr allreduce_coalesced( - std::vector& tensors, - const AllreduceCoalescedOptions& opts = AllreduceCoalescedOptions()) { - static auto op = c10::Dispatcher::singleton() - .findSchemaOrThrow("c10d::allreduce_coalesced_", "") - .typed( - at::TensorList, - const c10::intrusive_ptr<::c10d::ProcessGroup>&, - const c10::intrusive_ptr<::c10d::ReduceOp>&, - int64_t)>(); - - return op.call( - tensors, - c10::intrusive_ptr::unsafe_reclaim_from_nonowning(this), - c10::make_intrusive(opts.reduceOp), - opts.timeout.count()); - } - - virtual c10::intrusive_ptr reduce( - std::vector& tensors, - const ReduceOptions& opts = ReduceOptions()) { - static auto op = c10::Dispatcher::singleton() - .findSchemaOrThrow("c10d::reduce_", "") - .typed( - at::TensorList, - const c10::intrusive_ptr<::c10d::ProcessGroup>&, - const c10::intrusive_ptr<::c10d::ReduceOp>&, - int64_t, - int64_t, - int64_t)>(); - return op.call( - tensors, - c10::intrusive_ptr::unsafe_reclaim_from_nonowning(this), - c10::make_intrusive(opts.reduceOp), - opts.rootRank, - opts.rootTensor, - opts.timeout.count()); - } - - virtual c10::intrusive_ptr allgather( - std::vector>& outputTensors, - std::vector& inputTensors, - const AllgatherOptions& opts = AllgatherOptions()) { - static auto op = c10::Dispatcher::singleton() - .findSchemaOrThrow("c10d::allgather_", "") - .typed>, - c10::intrusive_ptr>( - const std::vector>&, - at::TensorList, - const c10::intrusive_ptr<::c10d::ProcessGroup>&, - int64_t)>(); - - return std::get<1>(op.call( - outputTensors, - inputTensors, - c10::intrusive_ptr::unsafe_reclaim_from_nonowning(this), - opts.timeout.count())); - } - - // Gathers a single tensor inputBuffer into a single buffer outputBuffer that - // is interpreted as a contiguous collection of size inputBuffer * WORLD_SIZE. - // For implementers of ProcessGroup API and advanced users only. - // Note: this function will be deprecated in near future. - virtual c10::intrusive_ptr _allgather_base( - at::Tensor& outputBuffer, - at::Tensor& inputBuffer, - const AllgatherOptions& opts = AllgatherOptions()) { - static auto op = - c10::Dispatcher::singleton() - .findSchemaOrThrow("c10d::_allgather_base_", "") - .typed>( - at::Tensor&, - at::Tensor&, - const c10::intrusive_ptr<::c10d::ProcessGroup>&, - bool, - int64_t)>(); - - return std::get<1>(op.call( - outputBuffer, - inputBuffer, - c10::intrusive_ptr::unsafe_reclaim_from_nonowning(this), - opts.asyncOp, - opts.timeout.count())); - } - - // This function is deprecated and will be moved out of ProcessGroup to comms: - // * do not add dependencies on this function, - // * do not implement it in your ProcessGroup, implement _allgather_base - // instead. - virtual c10::intrusive_ptr allgather_coalesced( - std::vector>& outputTensorLists, - std::vector& inputTensors, - const AllgatherOptions& opts = AllgatherOptions()) { - static auto op = - c10::Dispatcher::singleton() - .findSchemaOrThrow("c10d::allgather_coalesced_", "") - .typed( - const std::vector>&, - const at::TensorList&, - const c10::intrusive_ptr<::c10d::ProcessGroup>&)>(); - - return op.call( - outputTensorLists, - inputTensors, - c10::intrusive_ptr::unsafe_reclaim_from_nonowning(this)); - } - - // This function is a coalesced version of `allgather_into_tensor` (currently - // still named as `_allgather_base`). Each tensor in the vector corresponds to - // an input/output of one `allgather_into_tensor` operation. - virtual c10::intrusive_ptr allgather_into_tensor_coalesced( - std::vector& outputTensors, - std::vector& inputTensors, - const AllgatherOptions& opts = AllgatherOptions()) { - static auto op = - c10::Dispatcher::singleton() - .findSchemaOrThrow("c10d::allgather_into_tensor_coalesced_", "") - .typed( - const at::TensorList, - const at::TensorList, - const c10::intrusive_ptr<::c10d::ProcessGroup>&)>(); - - return op.call( - outputTensors, - inputTensors, - c10::intrusive_ptr::unsafe_reclaim_from_nonowning(this)); - } - - virtual c10::intrusive_ptr gather( - std::vector>& outputTensors, - std::vector& inputTensors, - const GatherOptions& opts = GatherOptions()) { - static auto op = c10::Dispatcher::singleton() - .findSchemaOrThrow("c10d::gather_", "") - .typed( - const std::vector>&, - const at::TensorList&, - const c10::intrusive_ptr<::c10d::ProcessGroup>&, - int64_t, - int64_t)>(); - return op.call( - outputTensors, - inputTensors, - c10::intrusive_ptr::unsafe_reclaim_from_nonowning(this), - opts.rootRank, - opts.timeout.count()); - } - - virtual c10::intrusive_ptr scatter( - std::vector& outputTensors, - std::vector>& inputTensors, - const ScatterOptions& opts = ScatterOptions()) { - static auto op = - c10::Dispatcher::singleton() - .findSchemaOrThrow("c10d::scatter_", "") - .typed< - std::tuple, c10::intrusive_ptr>( - const at::TensorList&, - const std::vector>&, - const c10::intrusive_ptr<::c10d::ProcessGroup>&, - int64_t, - bool, - int64_t)>(); - return std::get<1>(op.call( - outputTensors, - inputTensors, - c10::intrusive_ptr::unsafe_reclaim_from_nonowning(this), - opts.rootRank, - opts.asyncOp, - opts.timeout.count())); - } - - virtual c10::intrusive_ptr reduce_scatter( - std::vector& outputTensors, - std::vector>& inputTensors, - const ReduceScatterOptions& opts = ReduceScatterOptions()) { - static auto op = - c10::Dispatcher::singleton() - .findSchemaOrThrow("c10d::reduce_scatter_", "") - .typed< - std::tuple, c10::intrusive_ptr>( - const at::TensorList&, - const std::vector>&, - const c10::intrusive_ptr<::c10d::ProcessGroup>&, - const c10::intrusive_ptr<::c10d::ReduceOp>&, - int64_t)>(); - return std::get<1>(op.call( - outputTensors, - inputTensors, - c10::intrusive_ptr::unsafe_reclaim_from_nonowning(this), - c10::make_intrusive<::c10d::ReduceOp>(opts.reduceOp), - opts.timeout.count())); - } - - virtual c10::intrusive_ptr _reduce_scatter_base( - at::Tensor& outputBuffer, - at::Tensor& inputBuffer, - const ReduceScatterOptions& opts = ReduceScatterOptions()) { - static auto op = - c10::Dispatcher::singleton() - .findSchemaOrThrow("c10d::_reduce_scatter_base_", "") - .typed>( - at::Tensor&, - at::Tensor&, - const c10::intrusive_ptr<::c10d::ProcessGroup>&, - const c10::intrusive_ptr<::c10d::ReduceOp>&, - bool, - int64_t)>(); - return std::get<1>(op.call( - outputBuffer, - inputBuffer, - c10::intrusive_ptr::unsafe_reclaim_from_nonowning(this), - c10::make_intrusive<::c10d::ReduceOp>(opts.reduceOp), - opts.asyncOp, - opts.timeout.count())); - } - - // This function is a coalesced version of `reduce_scatter_tensor` (currently - // still named as `_reduce_scatter_base`). Each tensor in the vector - // corresponds to an input/output of one `reduce_scatter_tensor` operation. - virtual c10::intrusive_ptr reduce_scatter_tensor_coalesced( - std::vector& outputTensors, - std::vector& inputTensors, - const ReduceScatterOptions& opts = ReduceScatterOptions()) { - static auto op = - c10::Dispatcher::singleton() - .findSchemaOrThrow("c10d::reduce_scatter_tensor_coalesced_", "") - .typed( - const at::TensorList, - const at::TensorList, - const c10::intrusive_ptr<::c10d::ProcessGroup>&, - const c10::intrusive_ptr<::c10d::ReduceOp>&, - int64_t)>(); - - return op.call( - outputTensors, - inputTensors, - c10::intrusive_ptr::unsafe_reclaim_from_nonowning(this), - c10::make_intrusive<::c10d::ReduceOp>(opts.reduceOp), - opts.timeout.count()); - } - - virtual c10::intrusive_ptr alltoall_base( - at::Tensor& outputBuffer, - at::Tensor& inputBuffer, - std::vector& outputSplitSizes, - std::vector& inputSplitSizes, - const AllToAllOptions& opts = AllToAllOptions()) { - static auto op = c10::Dispatcher::singleton() - .findSchemaOrThrow("c10d::alltoall_base_", "") - .typed( - at::Tensor&, - at::Tensor&, - const c10::intrusive_ptr<::c10d::ProcessGroup>&, - std::vector, - std::vector, - int64_t)>(); - return op.call( - outputBuffer, - inputBuffer, - c10::intrusive_ptr::unsafe_reclaim_from_nonowning(this), - outputSplitSizes, - inputSplitSizes, - opts.timeout.count()); - } - - virtual c10::intrusive_ptr alltoall( - std::vector& outputTensors, - std::vector& inputTensors, - const AllToAllOptions& opts = AllToAllOptions()) { - static auto op = - c10::Dispatcher::singleton() - .findSchemaOrThrow("c10d::alltoall_", "") - .typed< - std::tuple, c10::intrusive_ptr>( - const at::TensorList&, - const at::TensorList&, - const c10::intrusive_ptr<::c10d::ProcessGroup>&, - int64_t)>(); - return std::get<1>(op.call( - outputTensors, - inputTensors, - c10::intrusive_ptr::unsafe_reclaim_from_nonowning(this), - opts.timeout.count())); - } - - virtual void monitoredBarrier( - const BarrierOptions& opts, - bool wait_all_ranks = false) { - static auto op = c10::Dispatcher::singleton() - .findSchemaOrThrow("c10d::monitored_barrier_", "") - .typed&, - const std::vector&, - int64_t, - bool)>(); - // Default to using cpu implementation, monitored barrier is only for GLOO - at::Tensor tensor = at::empty({0}, at::TensorOptions().device(at::kCPU)); - op.call( - tensor, - c10::intrusive_ptr::unsafe_reclaim_from_nonowning(this), - opts.device_ids, - opts.timeout.count(), - wait_all_ranks); - } - - // Agrees on an initial sequence number for the whole group by having rank 0 - // create it and broadcast it to other ranks using the store. Only implemented - // for GLOO and NCCL backends currently. - virtual void setSequenceNumberForGroup() { - auto backendType = getBackendType(); - // TODO: HACK for backend name to get sequence number for that backend. - if (backendType == ProcessGroup::BackendType::GLOO || - backendType == ProcessGroup::BackendType::NCCL || - backendType == ProcessGroup::BackendType::XCCL || - backendType == ProcessGroup::BackendType::UCC) { - getDefaultBackend()->setSequenceNumberForGroup(); + coalescing_state_ = 0; + coalescedComm_ = nullptr; + return work; +} + +c10::intrusive_ptr ProcessGroupNCCL::endCoalescing() { + // Default OpType to COALESCED if not specified + return endCoalescing(OpType::COALESCED); +} + +template +c10::intrusive_ptr ProcessGroupNCCL::collective( + std::vector& inputs, + std::vector& outputs, + Fn fn, + PreProcess pre, + PostProcess post, + OpType opType, + const char* profilingTitle, + bool avoidRecordStreams, + bool nanCheck) { + // Environment setting by the user may add onto collective call's option + avoidRecordStreams |= avoidRecordStreams_; + nanCheck &= enableNanCheck_; + + c10::cuda::CaptureStatus capture_status = + c10::cuda::currentStreamCaptureStatusMayInitCtx(); + errorIfCapturingNonCapturableNCCL(capture_status); + + // Bump collective counter + seqCollective_++; + op_id_++; + + auto device = getDevice(inputs[0]); + const auto key = getKeyFromDevice(device); + auto ncclComm = getNCCLComm(key, device, opType); + + if (coalescing_state_ & CoalActive) { + coalescing_state_ |= CoalColl; + if (coalescedDevice_.index() < 0) { + coalescedDevice_ = device; } else { TORCH_CHECK( - false, - c10::str( - "ProcessGroup ", - getBackendName(), - " does not yet support sequence numbers.")); + coalescedDevice_.index() == device.index(), MULTI_DEVICE_ERROR_MSG); + } + if (coalescedComm_ == nullptr) { + coalescedComm_ = ncclComm; + } else { + TORCH_CHECK(coalescedComm_ == ncclComm, MULTI_DEVICE_ERROR_MSG); + } + } + + // Used many times below, so we stash the unordered_map lookup + auto ncclStream = ncclStreams_.at(key); + + // First let NCCL streams wait for input tensors allocation streams + syncStream(device, ncclEvents_[key], ncclStream); + + bool enqueue = + !coalescing_state_ && capture_status == c10::cuda::CaptureStatus::None; + auto work = + initWork(device, rank_, opType, profilingTitle, inputs, outputs, enqueue); + + // Store references to outputs to be used by WorkNCCL::result and operator<<. + work->outputs_ = std::make_shared>(outputs); + + if (avoidRecordStreams) { + work->stashed_for_allocator_safety_ = + std::make_shared>(inputs); + } + + at::cuda::OptionalCUDAGuard gpuGuard(device); + + if (nanCheck) { + for (const auto& input : inputs) { + checkForNan(input, ncclStream); + } + } + + // Start event should only be recorded before the ncclGroupStart() + if (work->timingEnabled_) { + work->ncclStartEvent_->record(ncclStream); + } + + pre(ncclStream, work); + + ncclComm_t comm = ncclComm->getNcclComm(); + + // Both `inputs' and `outputs' are created on a worker stream and used in + // different ncclStreams. Hence, both must record the ncclStream to + // prevent being freed before the collective finishes. + // + // We only record `inputs' here, and leave recording `outputs' to `fn' for + // operations where `inputs' and `outputs' are not the same. + // + // See [Sync Streams]. + if (!avoidRecordStreams) { + for (const auto& input : inputs) { + if (!input.is_sparse()) { + c10::cuda::CUDACachingAllocator::recordStream( + input.storage().data_ptr(), ncclStream); + } else { + // for sparse input case record streams on both index and value + // tensors + c10::cuda::CUDACachingAllocator::recordStream( + input.values().storage().data_ptr(), ncclStream); + c10::cuda::CUDACachingAllocator::recordStream( + input.indices().storage().data_ptr(), ncclStream); + } + } + } + +// Not all collectives have the same signature, e.g, all-reduce take in a Tensor +// as the input and output while all-to-all take in a vector of Tensors as input +// and output. Because we define the signature of the fn to take only single +// tensor as input and output, we need to do a hack to get the first element in +// the vector and pass it to fn. +// TODO: we should clean up this in future (by either entirely removing lambda's +// or removing input and output from lambda's signature). +#ifndef NCCL_HAS_COMM_NONBLOCKING + C10D_NCCL_CHECK( + fn(inputs[0], outputs[0], comm, ncclStream), + ncclComm->getNcclCommFailureReason()); +#else + C10D_NCCL_CHECK_TIMEOUT( + fn(inputs[0], outputs[0], comm, ncclStream), + comm, + ncclComm->getNcclCommFailureReason()); +#endif + + post(ncclStream, work); + + // End event should only be recorded after the ncclGroupEnd() + if (!coalescing_state_) { + work->ncclEndEvent_->record(ncclStream); + } + work->ncclComm_ = ncclComm; + + { + c10::cuda::CUDAMultiStreamGuard streamGuard(ncclStream); + std::vector devices{device}; + work->future_ = c10::make_intrusive( + c10::ListType::create(c10::TensorType::get()), devices); + + // Add a callback that runs profiling end callbacks. wrapCallback() in CUDA + // future blocks the stream this callback runs on the corresponding + // ncclEndEvents_ ensuring appropriate synchronization. + if (work->recordFunctionEndCallback_) { + work->future_->addCallback( + [work](at::ivalue::Future& /* unused */) { + work->recordFunctionEndCallback_(); + }, + // uses_future = false allows us to skip synchronization in + // ivalue::Future, but is only valid as long as the lambda doesn't use + // the "Future" argument. + /*uses_future=*/false); } + work->future_->markCompleted(at::IValue(*work->outputs_)); + } + + // Set appropriate work parameters. + work->blockingWait_ = blockingWait_; + work->avoidRecordStreams_ = avoidRecordStreams; + work->store_ = store_; + assignTimeoutToWork(work, options_); + // Record size info for debug. We only record the size on the first device as + // multi-device per process is deprecated + work->numelIn_ = 0; + work->numelOut_ = 0; + for (const auto& input : inputs) { + work->numelIn_ += input.numel(); + } + for (const auto& output : outputs) { + work->numelOut_ += output.numel(); + } + + // Notify graphs before we check the capture status preemptively + at::cuda::CUDAGraph::inc_pending_event_queries(); + if (enqueue) { + workEnqueue(work); + } else { + at::cuda::CUDAGraph::dec_pending_event_queries(); } - // Retrieves the current sequence number for the whole group, which should be - // in sync. If the returned number is not consistent across the group, it - // may indicate that there is some sort of collective desynchronization. - virtual uint64_t getSequenceNumberForGroup() { - auto backendType = getBackendType(); + return work; +} - // TODO: HACK for backend name to get sequence number for that backend. - if (backendType == ProcessGroup::BackendType::GLOO || - backendType == ProcessGroup::BackendType::NCCL || - backendType == ProcessGroup::BackendType::XCCL || - backendType == ProcessGroup::BackendType::UCC) { - return getDefaultBackend()->getSequenceNumberForGroup(); +template +c10::intrusive_ptr ProcessGroupNCCL::collectiveCoalesced( + std::vector& inputs, + std::vector& outputs, + Fn fn, + OpType opType, + const char* profilingTitle, + bool avoidRecordStreams) { + // Environment setting by the user may add onto collective call's option + avoidRecordStreams |= avoidRecordStreams_; + c10::cuda::CaptureStatus capture_status = + c10::cuda::currentStreamCaptureStatusMayInitCtx(); + errorIfCapturingNonCapturableNCCL(capture_status); + + // Bump collective counter + seqCollective_++; + + // For coalescingManager collectives, there is no individual c++ call per + // collective so there is no flight record and we increment seq*_ and op_id_ + // together. Compare this to startCoalesing/endCoalescing flow where we + // increment seq_ once per group and increment op_id_ once per indvidual + // operation within the group + op_id_++; + + // Currently, the API permits one scenario where inputs.size() and + // outputs.size() are > 0. + // 1. If the call was a _coalesced call, all inputs must be on the same + // device. + // The group of nccl calls applies the collective separately to each input, + // but the group as a whole should be efficient, and might even execute as + // a single fused kernel. + auto device = getDevice(inputs[0]); + const auto key = getKeyFromDevice(device); + auto ncclComm = getNCCLComm(key, device, opType); + + if (coalescing_state_ & CoalActive) { + coalescing_state_ |= CoalColl; + if (coalescedDevice_.index() < 0) { + coalescedDevice_ = device; } else { TORCH_CHECK( - false, - c10::str( - "ProcessGroup ", - getBackendName(), - " does not yet support sequence numbers.")); - } - } - - virtual c10::intrusive_ptr send( - std::vector& tensors, - int dstRank, - int tag) { - static auto op = c10::Dispatcher::singleton() - .findSchemaOrThrow("c10d::send", "") - .typed( - at::TensorList, - const c10::intrusive_ptr<::c10d::ProcessGroup>&, - int64_t, - int64_t)>(); - return op.call( - tensors, - c10::intrusive_ptr::unsafe_reclaim_from_nonowning(this), - dstRank, - tag); - } - - virtual c10::intrusive_ptr recv( - std::vector& tensors, - int srcRank, - int tag) { - static auto op = c10::Dispatcher::singleton() - .findSchemaOrThrow("c10d::recv_", "") - .typed( - at::TensorList, - const c10::intrusive_ptr<::c10d::ProcessGroup>&, - int64_t, - int64_t)>(); - return op.call( - tensors, - c10::intrusive_ptr::unsafe_reclaim_from_nonowning(this), - srcRank, - tag); - } - - virtual c10::intrusive_ptr recvAnysource( - std::vector& tensors, - int tag) { - static auto op = c10::Dispatcher::singleton() - .findSchemaOrThrow("c10d::recv_any_source_", "") - .typed( - at::TensorList, - const c10::intrusive_ptr<::c10d::ProcessGroup>&, - int64_t)>(); - return op.call( - tensors, - c10::intrusive_ptr::unsafe_reclaim_from_nonowning(this), - tag); - } - - virtual c10::intrusive_ptr barrier( - const BarrierOptions& opts = BarrierOptions()) { - static at::Tensor tensor; - // TODO: if nccl was specified then use it - auto device = opts.device; - if (device.has_value()) { - // set device tensor from argument - tensor = at::empty( - {1}, at::TensorOptions().device(device.value()).dtype(at::kByte)); - } else if (backendType_ == c10d::ProcessGroup::BackendType::NCCL) { - // set cuda tensor - tensor = at::empty( - {1}, - at::TensorOptions().device(at::DeviceType::CUDA).dtype(at::kByte)); + coalescedDevice_.index() == device.index(), MULTI_DEVICE_ERROR_MSG); + } + if (coalescedComm_ == nullptr) { + coalescedComm_ = ncclComm; + } else { + TORCH_CHECK(coalescedComm_ == ncclComm, MULTI_DEVICE_ERROR_MSG); + } + } + + // Used many times below, so we stash the unordered_map lookup + auto ncclStream = ncclStreams_.at(key); + + // First let NCCL streams wait for input tensors allocation streams + syncStream(device, ncclEvents_[key], ncclStream); + + auto work = initWork( + device, rank_, opType, profilingTitle, inputs, outputs, /*record=*/true); + + // Store references to outputs to be used by WorkNCCL::result and operator<<. + work->outputs_ = std::make_shared>(outputs); + + if (avoidRecordStreams) { + work->stashed_for_allocator_safety_ = + std::make_shared>(inputs); + } + + at::cuda::OptionalCUDAGuard gpuGuard(device); + + // Start event should only be recorded before the ncclGroupStart() (which + // happens inside AutoNcclGroup guard below) + if (work->timingEnabled_) { + work->ncclStartEvent_->record(ncclStream); + } + + ncclComm_t comm = ncclComm->getNcclComm(); + +// TODO(kwen2501): this should be moved to c10d tests, to qualify a NCCL +// upgrade. Once a NCCL version is qualified, this code should not be needed at +// runtime. +#ifdef PGNCCL_ENABLE_HASH + if (enableCollecticeHashDebug_.load()) { + auto numel = getTensorsNumel(inputs); + auto hashValue = hashTensors(inputs); + PRINT_COLLECTIVE_HASH_SIGNATURE( + "input", opTypeToString(opType), numel, hashValue); + } +#endif + + { + torch::cuda::nccl::AutoNcclGroup nccl_group_guard( + comm, nccl_use_nonblocking()); + for (const auto i : c10::irange(inputs.size())) { + // Both `inputs' and `outputs' are created on a worker stream and used in + // different ncclStreams. Hence, both must record the ncclStream to + // prevent being freed before the collective finishes. + // + // We only record `inputs' here, and leave recording `outputs' to `fn' for + // operations where `inputs' and `outputs' are not the same. + // + // See [Sync Streams]. + if (!avoidRecordStreams) { + if (!inputs[i].is_sparse()) { + c10::cuda::CUDACachingAllocator::recordStream( + inputs[i].storage().data_ptr(), ncclStream); + } else { + // for sparse input case record streams on both index and value + // tensors + c10::cuda::CUDACachingAllocator::recordStream( + inputs[i].values().storage().data_ptr(), ncclStream); + c10::cuda::CUDACachingAllocator::recordStream( + inputs[i].indices().storage().data_ptr(), ncclStream); + } + } +#ifndef NCCL_HAS_COMM_NONBLOCKING + C10D_NCCL_CHECK( + fn(inputs[i], outputs[i], comm, ncclStream), + ncclComm->getNcclCommFailureReason()); +#else + C10D_NCCL_CHECK_TIMEOUT( + fn(inputs[i], outputs[i], comm, ncclStream), + comm, + ncclComm->getNcclCommFailureReason()); +#endif + } + } + + work->ncclEndEvent_->record(ncclStream); + work->ncclComm_ = ncclComm; + + { + c10::cuda::CUDAMultiStreamGuard streamGuard(ncclStream); + std::vector devices{device}; + work->future_ = c10::make_intrusive( + c10::ListType::create(c10::TensorType::get()), devices); + + // Add a callback that runs profiling end callbacks. wrapCallback() in CUDA + // future blocks the stream this callback runs on the corresponding + // ncclEndEvents_ ensuring appropriate synchronization. + if (work->recordFunctionEndCallback_) { + work->future_->addCallback( + [work](at::ivalue::Future& /* unused */) { + work->recordFunctionEndCallback_(); + }, + // uses_future = false allows us to skip synchronization in + // ivalue::Future, but is only valid as long as the lambda doesn't use + // the "Future" argument. + /*uses_future=*/false); + } + work->future_->markCompleted(at::IValue(*work->outputs_)); + } + + // Set appropriate work parameters. + work->blockingWait_ = blockingWait_; + work->avoidRecordStreams_ = avoidRecordStreams; + work->store_ = store_; + assignTimeoutToWork(work, options_); + // Record size info for debug. We only record the size on the first device as + // multi-device per process is deprecated + work->numelIn_ = inputs[0].numel(); + work->numelOut_ = outputs[0].numel(); + + /* Note [cuda graph capture and workEnqueue] + + Normal behavior of the C10D watchdog is to query cuda events on work objects + periodically, but when cuda graph recording is active these event queries + would crash or mess up the recording. + + To ensure we do not enqueue a work object to the watchdog when cuda graph + capture is active, we use a one-way sync. We increment a flag pre-emptively, + indicating our intent to enqueue a work object. Then we check capture_status + to see if (a) capturing is already in progress (we cannot enqueue in this + case), (b) capturing hasn't started yet, so we can trust that no capture will + start (since a pre-condition of starting a capture is to check the event query + count is 0). + + If we are not able to enqueue the work due to capture-in-progress, we finally + decrement the counter. + + For this reason we cannot easily move the increment inside workEnqueue unless + we also change the semantic of workEnqueue to 'maybeWorkEnqueue'. + + TODO: + - Is our design for flight recorder safe in this context? are we recording + any FR events during cudagraph capture? if so, they won't be safe to poll for + completion status. + */ + at::cuda::CUDAGraph::inc_pending_event_queries(); + if (capture_status == c10::cuda::CaptureStatus::None) { + workEnqueue(work); + } else { + at::cuda::CUDAGraph::dec_pending_event_queries(); + } + // TODO(whc) if the work isn't enqueued, I don't feel great about returning + // it, since interactions with it by usercode won't behave normally - they + // won't observe work completion, for instance. Will this lead to silent + // problems during capture? + return work; +} + +template +c10::intrusive_ptr ProcessGroupNCCL::pointToPoint( + at::Tensor& tensor, + Fn fn, + int peer, + OpType opType, + PreProcess pre, + PostProcess post, + const char* profilingTitle) { + // avoidRecordStreams_ note: + // send, recv, and irecv should be ok with avoidRecordStreams, + // However, for isend, I don't think the API requires the user + // to wait() on the returned handle, so ProcessGroupNCCL can't know + // when it's safe to release the input back to the allocator, + // and the present call has no way to know it's not an isend. + // Therefore, we warn and fall back to the typical recordStream logic: + if (avoidRecordStreams_) { + TORCH_WARN_ONCE( + "TORCH_NCCL_AVOID_RECORD_STREAMS=1 has no effect for point-to-point " + "collectives."); + } + + auto device = getDevice(tensor); + std::string key; + int p2pRank = 0, p2pTargetRank = 0; + bool isSendRecvSelf = false; + // For batch_isend_irecv, ncclGroupStart() would be called upfront + bool batchP2P = ncclActiveGroupCounter_ > 0; + if (batchP2P) { + // For batch P2P, we need to treat it like a collective when selecting + // communicator, because other ranks can call into this batch other than my + // rank and my peer + key = getKeyFromDevice(device); + p2pRank = rank_; + p2pTargetRank = peer; + } else { + // For single P2P, preserve the old two-rank behavior (to avoid perf diff) + key = getKeySendRecv(rank_, peer); + p2pRank = rank_ <= peer ? 0 : 1; + isSendRecvSelf = rank_ == peer; + p2pTargetRank = isSendRecvSelf ? 0 : 1 - p2pRank; + + if (!coalescing_state_) { + // Bump P2P sequence number. Don't do so if it's a batch P2P, it will be + // bumped in `startCoalescing`. + seqP2P_++; + } + } + + // Bump the logical operation counter regardless of whether this op is + // coalesced or individual + op_id_++; + + auto ncclComm = getNCCLComm(key, device, opType, p2pRank, isSendRecvSelf); + + if (coalescing_state_ & CoalActive) { + coalescing_state_ |= CoalP2P; + if (coalescedDevice_.index() < 0) { + coalescedDevice_ = device; } else { - // Default to using cpu implementation - tensor = at::empty( - {1}, - at::TensorOptions().device(at::DeviceType::CPU).dtype(at::kByte)); - } - - static auto op = c10::Dispatcher::singleton() - .findSchemaOrThrow("c10d::barrier", "") - .typed( - at::Tensor, - const c10::intrusive_ptr<::c10d::ProcessGroup>&, - const std::vector&, - int64_t)>(); - - return op.call( - tensor, - c10::intrusive_ptr::unsafe_reclaim_from_nonowning(this), - opts.device_ids, - opts.timeout.count()); - } - - bool hasBackends() { - return !deviceTypeToBackendType_.empty(); - } - - void setBackend( - c10::DeviceType deviceType, - BackendType backendType, - const std::optional>& backend) { - // TODO: should we add these entries after the backend setting succeeds? - deviceTypeToBackendType_[deviceType] = backendType; - deviceTypes_.insert(deviceType); - // if the backendType is already set then reuse it for this device - if (backendTypeToBackend_.find(backendType) != - backendTypeToBackend_.end()) { - auto existingBackend = backendTypeToBackend_.at(backendType); - deviceTypeToBackend_[deviceType] = existingBackend; TORCH_CHECK( - existingBackend->getBoundDeviceId() == - (*backend)->getBoundDeviceId()); + coalescedDevice_.index() == device.index(), MULTI_DEVICE_ERROR_MSG); + } + if (coalescedComm_ == nullptr) { + coalescedComm_ = ncclComm; } else { - // check if backend has value - if (backend.has_value()) { - deviceTypeToBackend_[deviceType] = backend.value(); - backendTypeToBackend_[backendType] = backend.value(); - (*backend)->setBoundDeviceId(bound_device_id_); - } + TORCH_CHECK(coalescedComm_ == ncclComm, MULTI_DEVICE_ERROR_MSG); + } + } + + // Used many times below, so we stash the unordered_map lookup + auto ncclStream = ncclStreams_.at(key); + // First let NCCL streams wait for input tensors allocation streams + syncStream(device, ncclEvents_[key], ncclStream); + + // Work itself will create the CUDA events on all GPUs of tensors + c10::intrusive_ptr work; + if (coalescing_state_) { + // When coalescing, we record events per op that lack timing/state + // information becuase there is no 'work' associated with them, and then + // later in endCoalescing we record a 'coalesced' Work which has + // timing/state updates via watchdog thread, but lacks op metadata such as + // input/output sizes and profilingTitle per-op in the group. + auto trace_id = NCCLTraceBuffer::get()->record( + local_id_, + std::make_tuple(pg_uid_, pg_desc_), + seqCollective_, + seqP2P_, + op_id_, + profilingTitle, + {tensor}, + {tensor}, + nullptr, + nullptr, + options_->timeout, + pgStatus_, + /*isP2P=*/true); + // TODO(whc) if we want to make the per-p2p-op flightrecorder entries get + // their timings/states updated by proxy when the Work obj representing the + // coalesce group gets its update, we could accumulate these trace_ids + // together and ask FlightRecorder to take the update from one Work and + // apply it to multiple entries + (void)trace_id; + } else { + // Store references to outputs to be used by WorkNCCL::result and + // operator<<. Note that these outputs are only valid for recv(), as send() + // does not modify the inputs but we still create these outputs for use + // cases such as profiling. + + work = initWork( + device, rank_, opType, profilingTitle, {tensor}, {}, /*record=*/false); + // This bypasses something in Work() that crashes if {tensor} is given as + // output, not sure what + work->outputs_ = std::make_shared>(); + work->outputs_->push_back(tensor); + // TODO(whc) because we don't pass output {tensor} to initWork, we tell + // initWork to not record, and then we manually call record passing all the + // information it wants. + work->trace_id_ = NCCLTraceBuffer::get()->record( + local_id_, + std::make_tuple(pg_uid_, pg_desc_), + seqCollective_, + seqP2P_, + op_id_, + profilingTitle, + {tensor}, + {tensor}, + work->ncclStartEvent_.get(), + work->ncclEndEvent_.get(), + options_->timeout, + pgStatus_, + /*isP2P=*/true); + } + + // is gpuGuard needed for the if block below, or can i swap them + at::cuda::OptionalCUDAGuard gpuGuard(device); + + // Only check for NaN for send ops, for recv ops `tensor` can be a random + // placeholder + if (enableNanCheck_ && opType == OpType::SEND) { + checkForNan(tensor, ncclStream); + } + + if (!coalescing_state_) { + // Start event should only be recorded before the ncclGroupStart() + if (work->timingEnabled_) { + work->ncclStartEvent_->record(ncclStream); + } + + pre(ncclStream, work); + } + + // Both send tensor and recv tensor are created on a worker stream and used + // in different ncclStreams. Hence, both must record the ncclStream to + // prevent being freed before the collective finishes. + // + // See [Sync Streams]. + c10::cuda::CUDACachingAllocator::recordStream( + tensor.storage().data_ptr(), ncclStream); + + // This part seems common to both p2p and coalesced-p2p usage? + ncclComm_t comm_ = ncclComm->getNcclComm(); + +#ifndef NCCL_HAS_COMM_NONBLOCKING + C10D_NCCL_CHECK( + fn(tensor, comm_, ncclStream, p2pTargetRank), + ncclComm->getNcclCommFailureReason()); +#else + C10D_NCCL_CHECK_TIMEOUT( + fn(tensor, comm_, ncclStream, p2pTargetRank), + ncclComm->getNcclComm(), + ncclComm->getNcclCommFailureReason()); +#endif + + if (!coalescing_state_) { + post(ncclStream); + + // End event should only be recorded after the ncclGroupEnd() + work->ncclEndEvent_->record(ncclStream); + work->ncclComm_ = ncclComm; + work->blockingWait_ = blockingWait_; + work->store_ = store_; + assignTimeoutToWork(work, options_); + // Record size info for debug. We only record the size on the first device + // as multi-device per process is deprecated + work->numelIn_ = work->numelOut_ = tensor.numel(); + + // Future only needs to be created and marked completed with outputs for + // recv(), but still create future for use cases such as profiling even for + // send(). + { + c10::cuda::CUDAMultiStreamGuard streamGuard(ncclStream); + std::vector devices{device}; + work->future_ = c10::make_intrusive( + c10::ListType::create(c10::TensorType::get()), devices); + work->future_->markCompleted(at::IValue(*work->outputs_)); } + + // Add a callback that runs profiling end callbacks. wrapCallback() in CUDA + // future blocks the stream this callback runs on the corresponding + // ncclEndEvents_ ensuring appropriate synchronization. + if (work->recordFunctionEndCallback_) { + work->future_->addCallback( + [work](at::ivalue::Future& /* unused */) { + work->recordFunctionEndCallback_(); + }, + // uses_future = false allows us to skip synchronization in + // ivalue::Future, but is only valid as long as the lambda doesn't use + // the "Future" argument. + /*uses_future=*/false); + } + } + + // Enqueue P2P op so that it can be cancelled by NCCL watchdog + c10::cuda::CaptureStatus capture_status = + c10::cuda::currentStreamCaptureStatusMayInitCtx(); + + // Notify graphs before we check the capture status preemptively + at::cuda::CUDAGraph::inc_pending_event_queries(); + + if (!coalescing_state_ && capture_status == c10::cuda::CaptureStatus::None) { + workEnqueue(work); + return work; + } else { + at::cuda::CUDAGraph::dec_pending_event_queries(); + return nullptr; } +} + +template +c10::intrusive_ptr ProcessGroupNCCL::collective( + at::Tensor& input, + at::Tensor& output, + Fn fn, + PreProcess pre, + PostProcess post, + OpType opType, + const char* profilingTitle, + bool avoidRecordStreams, + bool nanCheck) { + auto inputs = std::vector{input}; + auto outputs = std::vector{output}; + return collective( + inputs, + outputs, + fn, + pre, + post, + opType, + profilingTitle, + avoidRecordStreams, + nanCheck); +} + +template +c10::intrusive_ptr ProcessGroupNCCL::collective( + at::Tensor& input, + at::Tensor& output, + Fn fn, + OpType opType, + const char* profilingTitle, + bool avoidRecordStreams, + bool nanCheck) { + auto inputs = std::vector{input}; + auto outputs = std::vector{output}; + return collective( + inputs, + outputs, + fn, + [](at::cuda::CUDAStream&, + c10::intrusive_ptr& work) {}, + [](at::cuda::CUDAStream&, + c10::intrusive_ptr& work) {}, + opType, + profilingTitle, + avoidRecordStreams, + nanCheck); +} + +template +c10::intrusive_ptr ProcessGroupNCCL::pointToPoint( + at::Tensor& tensor, + Fn fn, + int peer, + OpType opType, + const char* profilingTitle) { + return pointToPoint( + tensor, + fn, + peer, + opType, + [](at::cuda::CUDAStream&, + c10::intrusive_ptr& work) {}, + [](at::cuda::CUDAStream&) {}, + profilingTitle); +} + +c10::intrusive_ptr ProcessGroupNCCL::allreduce_sparse( + std::vector& tensors, + const AllreduceOptions& opts) { + TORCH_CHECK(tensors.size() == 1, MULTI_DEVICE_ERROR_MSG); + auto tensor = tensors.back(); + TORCH_CHECK( + !isFloat8Type(tensor.scalar_type()), + "Float8 dtypes are not currenlty supported for NCCL reductions"); +#ifdef IS_NCCLX + tensor = tensor.coalesce(); + at::Tensor outputTensor = + torch::zeros(tensor.sizes(), tensor.options().layout(torch::kStrided)); + auto work = collective( + tensor, + outputTensor, + [&](at::Tensor& input, + at::Tensor& output, + ncclComm_t comm, + at::cuda::CUDAStream& stream) { + auto ncclDataType = getNcclDataType(input.scalar_type()); + auto ncclReduceOp = + getNcclReduceOp(opts.reduceOp, input, ncclDataType, comm); + + size_t num_elements = output.numel(); + auto indices = input.indices(); + auto sizes = input.sizes(); + int colSize = sizes[1]; + auto rows = indices[0]; + size_t blockCount = rows.sizes()[0]; + auto recvIndices = indices[0] * colSize; - c10::intrusive_ptr getDefaultBackend() const { + // prevent output and recvIndices from being freed + c10::cuda::CUDACachingAllocator::recordStream( + output.storage().data_ptr(), stream); + c10::cuda::CUDACachingAllocator::recordStream( + recvIndices.storage().data_ptr(), stream); + auto result = ncclAllReduceSparseBlock( + input._values().data_ptr(), // sendbuff + recvIndices.data_ptr(), // recv_indices + blockCount, // block_count + colSize, // block_length + output.data_ptr(), // recvbuff + output.numel(), // recv_count + ncclDataType, + ncclReduceOp, + comm, + stream.stream()); + return result; + }, + [](at::cuda::CUDAStream& ncclStream, + c10::intrusive_ptr& work) {}, + [&](at::cuda::CUDAStream& ncclStream, + c10::intrusive_ptr& work) { + // Convert output tensors to sparse and back into tensors. + at::cuda::CUDAStreamGuard guard(ncclStream); + if (opts.sparseIndices.has_value()) { + tensor = at::sparse_coo_tensor( + opts.sparseIndices.value(), outputTensor, tensor.sizes()); + } else { + tensor = outputTensor.to_sparse(); + } + }, + OpType::_ALLREDUCE_SPARSE, + "nccl:all_reduce_sparse"); + return work; +#else + // If the nccl branch is not "exp" then we just error + C10_THROW_ERROR( + Error, + "NCCL does not support all_reduce with sparse tensors. Please use dense tensors instead."); +#endif +} + +c10::intrusive_ptr ProcessGroupNCCL::allreduce_impl( + at::Tensor& tensor, + const AllreduceOptions& opts) { + return collective( + tensor, + tensor, + [&](at::Tensor& input, + at::Tensor& output, + ncclComm_t comm, + at::cuda::CUDAStream& stream) { + auto ncclDataType = getNcclDataType(input.scalar_type()); + auto ncclReduceOp = + getNcclReduceOp(opts.reduceOp, input, ncclDataType, comm); + return ncclAllReduce( + input.data_ptr(), + output.data_ptr(), + input.numel(), + ncclDataType, + ncclReduceOp, + comm, + stream.stream()); + }, + OpType::ALLREDUCE, + "nccl:all_reduce"); +} + +c10::intrusive_ptr ProcessGroupNCCL::allreduce( + std::vector& tensors, + const AllreduceOptions& opts) { + TORCH_CHECK(tensors.size() == 1, MULTI_DEVICE_ERROR_MSG); + auto tensor = tensors.back(); + if (tensor.is_complex()) { TORCH_CHECK( - backendTypeToBackend_.find(backendType_) != backendTypeToBackend_.end(), - "Could not find the default backend type ", - backendType_, - " for Process Group with name ", - getBackendName(), - "."); - return backendTypeToBackend_.at(backendType_); + complexViewAsRealAllowed(opts.reduceOp), + "all_reduce does not support", + opts.reduceOp, + "on complex tensors"); + tensor = at::view_as_real(tensor); } + check_gpu_single_tensor(tensor); - void setDefaultBackend(const BackendType& backendType) { - backendType_ = backendType; + if (intraNodeComm_ != nullptr && opts.reduceOp == ReduceOp::SUM) { + using namespace intra_node_comm; + auto algo = intraNodeComm_->selectAllReduceAlgo(tensor); + if (algo != intra_node_comm::AllReduceAlgo::NONE) { + intraNodeComm_->allReduce(tensor, algo); + return c10::make_intrusive(); + } } + TORCH_CHECK( + !isFloat8Type(tensor.scalar_type()), + "Float8 dtypes are not currenlty supported for NCCL reductions"); + // @lint-ignore CLANGTIDY + RECORD_PARAM_COMMS_DATA( + static_cast( + this->getSequenceNumberForGroup() + 1), // seq + 1 to match collective + std::make_tuple(pg_uid_, pg_desc_), // PG name tuple + tensors, // inputTensors + tensors, // outputTensors + rank_, // rank + "allreduce", // collective name + tensor.numel(), // inNelems + tensor.numel(), // outNelems + tensor.scalar_type(), // dType + std::vector(), // inSplitSizes + std::vector(), // outSplitSizes + globalRankStart, // globalRankStart + globalRankStride, // globalRankStride + this->getSize()); // worldSize + + // avoidRecordStreams_ note: collective() will stash tensors. + return allreduce_impl(tensor, opts); +} + +c10::intrusive_ptr ProcessGroupNCCL::allreduce_coalesced( + std::vector& tensors, + const AllreduceCoalescedOptions& opts) { + auto total_numel = check_gpu_tensors_same_device(tensors); + TORCH_CHECK( + !isFloat8Type(tensors.back().scalar_type()), + "Float8 dtypes are not currenlty supported for NCCL reductions"); + + // @lint-ignore CLANGTIDY + RECORD_PARAM_COMMS_DATA( + static_cast( + this->getSequenceNumberForGroup() + 1), // seq + 1 to match collective + std::make_tuple(pg_uid_, pg_desc_), // PG name tuple + tensors, // inputTensors + tensors, // outputTensors + rank_, // rank + "allreduce_coalesced", // collective name + total_numel, // inNelems + total_numel, // outNelems + tensors[0].scalar_type(), // dType + // I'm not sure what in,outSplitSizes mean here. + std::vector(), // inSplitSizes + std::vector(), // outSplitSizes + globalRankStart, // globalRankStart + globalRankStride, // globalRankStride + this->getSize()); // worldSize + + // avoidRecordStreams_ note: collective() will stash tensors. + return collectiveCoalesced( + tensors, + tensors, + [&](at::Tensor& input, + at::Tensor& output, + ncclComm_t comm, + at::cuda::CUDAStream& stream) { + auto ncclDataType = getNcclDataType(input.scalar_type()); + auto ncclReduceOp = + getNcclReduceOp(opts.reduceOp, input, ncclDataType, comm); + return ncclAllReduce( + input.data_ptr(), + output.data_ptr(), + input.numel(), + ncclDataType, + ncclReduceOp, + comm, + stream.stream()); + }, + OpType::COALESCED, + "nccl:allreduce_coalesced"); +} - void setDefaultBackend(const std::string& backend) { - backendType_ = strToBackendType(backend); +c10::intrusive_ptr ProcessGroupNCCL::broadcast( + std::vector& tensors, + const BroadcastOptions& opts) { + TORCH_CHECK(tensors.size() == 1, MULTI_DEVICE_ERROR_MSG); + auto tensor = tensors.back(); + if (tensor.is_complex()) { + tensor = at::view_as_real(tensor); } + check_gpu_single_tensor(tensor); - c10::intrusive_ptr getBackend(c10::DeviceType deviceType); + // @lint-ignore CLANGTIDY + RECORD_PARAM_COMMS_DATA( + static_cast( + this->getSequenceNumberForGroup() + 1), // seq + 1 to match collective + std::make_tuple(pg_uid_, pg_desc_), // PG name tuple + tensors, // inputTensors + tensors, // outputTensors + opts.rootRank, // root rank + "broadcast", // collective name + tensor.numel(), // inNelems + tensor.numel(), // outNelems + tensor.scalar_type(), // dType + std::vector(), // inSplitSizes + std::vector(), // outSplitSizes + globalRankStart, // globalRankStart + globalRankStride, // globalRankStride + this->getSize()); // worldSize - c10::intrusive_ptr getBackend(BackendType backendType) const { + // avoidRecordStreams_ note: collective() will stash tensors. + bool avoidRecordStreams = avoidRecordStreams_ || (!opts.asyncOp); + + const auto root = opts.rootRank + opts.rootTensor; + bool nanCheck = (root == rank_); + + return collective( + tensor, + tensor, + [&](at::Tensor& input, + at::Tensor& output, + ncclComm_t comm, + at::cuda::CUDAStream& stream) { + return ncclBcast( + input.data_ptr(), + input.numel(), + getNcclDataType(input.scalar_type()), + root, + comm, + stream.stream()); + }, + OpType::BROADCAST, + "nccl:broadcast", + avoidRecordStreams, + nanCheck); +} + +// _broadcast_oop adds an out-of-place broadcast in PGNCCL +// Custom collectives may be implemented by coalescing broadcast operations +// One use-case is implementing a vector all_gather (all_gather_v) +// where unevenly sized inputs are gathered among participating ranks +// Since all_gather provides an out-of-place API, an all_gather_v +// semantic implemented inside pg_nccl.all_gather also needs to support +// out-of-place, for which an out-of-place broadcast is required to be added +c10::intrusive_ptr ProcessGroupNCCL::_broadcast_oop( + at::Tensor& outputTensor, + at::Tensor& inputTensor, + const BroadcastOptions& opts) { + if (outputTensor.numel() != inputTensor.numel()) { + C10_THROW_ERROR( + ValueError, + "Tensor input and output of _broadcast_oop must have the same number of elements "); + } + const auto root = opts.rootRank + opts.rootTensor; + bool nanCheck = (root == rank_); + return collective( + inputTensor, + outputTensor, + [&](at::Tensor& input, + at::Tensor& output, + ncclComm_t comm, + at::cuda::CUDAStream& stream) { + return ncclBroadcast( + input.data_ptr(), + output.data_ptr(), + input.numel(), + getNcclDataType(input.scalar_type()), + root, + comm, + stream.stream()); + }, + OpType::BROADCAST, + "nccl:_broadcast_oop", + /*avoidRecordStreams=*/false, + nanCheck); +} + +c10::intrusive_ptr ProcessGroupNCCL::reduce( + std::vector& tensors, + const ReduceOptions& opts) { + TORCH_CHECK(tensors.size() == 1, MULTI_DEVICE_ERROR_MSG); + // @lint-ignore CLANGTIDY + auto tensor = tensors.back(); + if (tensor.is_complex()) { TORCH_CHECK( - backendTypeToBackend_.find(backendType) != backendTypeToBackend_.end(), - "Could not find backend type ", - backendType, - "."); - return backendTypeToBackend_.at(backendType); + complexViewAsRealAllowed(opts.reduceOp), + "reduce does not support", + opts.reduceOp, + "on complex tensors"); + tensor = at::view_as_real(tensor); + } + check_gpu_single_tensor(tensor); + RECORD_PARAM_COMMS_DATA( + static_cast( + this->getSequenceNumberForGroup() + 1), // seq + 1 to match collective + std::make_tuple(pg_uid_, pg_desc_), // PG name tuple + tensors, // inputTensors + tensors, // outputTensors + opts.rootRank, // root rank + "reduce", // collective name + tensor.numel(), // inNelems + tensor.numel(), // outNelems + tensor.scalar_type(), // dType + std::vector(), // inSplitSizes + std::vector(), // outSplitSizes + globalRankStart, // globalRankStart + globalRankStride, // globalRankStride + this->getSize()); // worldSize + + // avoidRecordStreams_ note: collective() will stash tensors. + return collective( + tensor, + tensor, + [&](at::Tensor& input, + at::Tensor& output, + ncclComm_t comm, + at::cuda::CUDAStream& stream) { + const auto root = opts.rootRank + opts.rootTensor; + auto ncclDataType = getNcclDataType(input.scalar_type()); + auto ncclReduceOp = + getNcclReduceOp(opts.reduceOp, input, ncclDataType, comm); + return ncclReduce( + input.data_ptr(), + output.data_ptr(), + input.numel(), + ncclDataType, + ncclReduceOp, + root, + comm, + stream.stream()); + }, + OpType::REDUCE, + "nccl:reduce"); +} + +// _reduce_oop exposes an out-of-place reduce from PGNCCL +// Custom collectives may be implemented by coalescing reduce operations +// One use-case is implementing a vector reduce_scatter (reduce_scatter_v) +// where inputs are reduced and scattered unevenly among participating ranks +// Since reduce_scatter provides an out-of-place API, a reduce_scatter_v +// semantic implemented inside pg_nccl.reduce_scatter also needs to support +// out-of-place, for which an out-of-place reduce is required to be added +c10::intrusive_ptr ProcessGroupNCCL::_reduce_oop( + at::Tensor& outputTensor, + at::Tensor& inputTensor, + const ReduceOptions& opts) { + if (outputTensor.numel() != inputTensor.numel()) { + C10_THROW_ERROR( + ValueError, + "Tensor input and output of _reduce_oop must have the same number of elements "); + } + return collective( + inputTensor, + outputTensor, + [&](at::Tensor& input, + at::Tensor& output, + ncclComm_t comm, + at::cuda::CUDAStream& stream) { + const auto root = opts.rootRank + opts.rootTensor; + const auto ncclDataType = getNcclDataType(input.scalar_type()); + const auto ncclReduceOp = + getNcclReduceOp(opts.reduceOp, input, ncclDataType, comm); + return ncclReduce( + input.data_ptr(), + output.data_ptr(), + input.numel(), + ncclDataType, + ncclReduceOp, + (int)root, + comm, + stream.stream()); + }, + OpType::REDUCE, + "nccl:_reduce_oop"); +} + +c10::intrusive_ptr ProcessGroupNCCL::allgather( + std::vector>& outputTensors, + std::vector& inputTensors, + const AllgatherOptions& opts) { + TORCH_CHECK(inputTensors.size() == 1, MULTI_DEVICE_ERROR_MSG); + // @lint-ignore CLANGTIDY + auto inputTensor = inputTensors.back(); + check_gpu_single_tensor(inputTensor); + // @lint-ignore CLANGTIDY + auto outputTensors_ = outputTensors.back(); + + RECORD_PARAM_COMMS_DATA( + static_cast( + this->getSequenceNumberForGroup() + 1), // seq + 1 to match collective + std::make_tuple(pg_uid_, pg_desc_), // PG name tuple + inputTensors, // inputTensors + outputTensors, // outputTensors + rank_, // rank + "all_gather", // collective name + inputTensor.numel(), // inNelems + inputTensor.numel() * // outNelems + this->getSize(), + inputTensor.scalar_type(), // dType + std::vector(), // inSplitSizes + std::vector(), // outSplitSize + globalRankStart, // globalRankStart + globalRankStride, // globalRankStride + this->getSize()); // worldSize + + bool same_size = check_same_size(outputTensors_); + if (same_size) { + // Flatten a vector of tensors into a single, stacked tensor. + at::Tensor outputFlattened = newLikeFlat(outputTensors_); + + return collective( + inputTensor, + outputFlattened, + [&](at::Tensor& input, + at::Tensor& output, + ncclComm_t comm, + at::cuda::CUDAStream& stream) { + if (!avoidRecordStreams_) { + c10::cuda::CUDACachingAllocator::recordStream( + output.storage().data_ptr(), stream); + } + return ncclAllGather( + input.data_ptr(), + output.data_ptr(), + input.numel(), + getNcclDataType(input.scalar_type()), + comm, + stream.stream()); + }, + [](at::cuda::CUDAStream& ncclStream, + c10::intrusive_ptr& work) { + // avoidRecordStreams_ note: We actually don't need to stash anything + // here. + // - inputTensors is stashed onto work->stashed_for_allocator_safety_ + // in collective(). + // - outputFlattened is stashed onto work->outputs_ in collective(). + // - User-facing outputTensors should be held by the user until after + // waiting on work_, or the call makes no sense. + // So all participating tensors are accounted for, and won't be + // released back to their allocation streams until after work_ is + // waited on. + }, + [&](at::cuda::CUDAStream& ncclStream, + c10::intrusive_ptr& work) { + // Copy the flattened output tensors to the outputs. + at::cuda::CUDAStreamGuard guard(ncclStream); + for (const auto j : c10::irange(outputTensors_.size())) { + // See [Sync Streams]. + if (!avoidRecordStreams_) { + c10::cuda::CUDACachingAllocator::recordStream( + outputTensors_[j].storage().data_ptr(), ncclStream); + } + outputTensors_[j].copy_(outputFlattened[j], true); + } + }, + OpType::ALLGATHER, + "nccl:all_gather"); + } else { + const auto num_reduces = outputTensors_.size(); + startCoalescing(); + for (const int i : c10::irange(num_reduces)) { + auto& output = outputTensors_[i]; + auto& input = (i == rank_) ? inputTensor : output; + auto broadcastOpts = BroadcastOptions{ + static_cast(i), static_cast(0), opts.timeout}; + _broadcast_oop(output, input, broadcastOpts); + } + auto work = endCoalescing(OpType::ALLGATHER); + return work; } +} - // Return device types supported by this ProcessGroup. - // Note: the return type is `Device` rather than `DeviceType` for the purpose - // of easy comparison at Python level. The `Device` will have default index - // (-1). - std::vector getDeviceTypes() const { - std::vector devices; - devices.reserve(deviceTypes_.size()); - for (auto& dt : deviceTypes_) { - devices.emplace_back(dt); +c10::intrusive_ptr ProcessGroupNCCL::allgather_coalesced( + std::vector>& /* unused */, + std::vector& /* unused */, + const AllgatherOptions& /* unused */) { + C10_THROW_ERROR( + NotImplementedError, + "ProcessGroupNCCL does not support allgather_coalesced"); +} + +c10::intrusive_ptr ProcessGroupNCCL::allgather_into_tensor_coalesced( + std::vector& outputs, + std::vector& inputs, + const AllgatherOptions& opts) { + return collectiveCoalesced( + inputs, + outputs, + [&](at::Tensor& input, + at::Tensor& output, + ncclComm_t comm, + at::cuda::CUDAStream& stream) { + return ncclAllGather( + input.data_ptr(), + output.data_ptr(), + input.numel(), + getNcclDataType(input.scalar_type()), + comm, + stream.stream()); + }, + OpType::COALESCED, + "nccl:all_gather_into_tensor_coalesced"); +} + +c10::intrusive_ptr ProcessGroupNCCL::reduce_scatter( + std::vector& outputTensors, + std::vector>& inputTensors, + const ReduceScatterOptions& opts) { + TORCH_CHECK(outputTensors.size() == 1, MULTI_DEVICE_ERROR_MSG); + // @lint-ignore CLANGTIDY + auto outputTensor = outputTensors.back(); + check_gpu_single_tensor(outputTensor); + // @lint-ignore CLANGTIDY + auto inputTensors_ = inputTensors.back(); + TORCH_CHECK( + !isFloat8Type(outputTensor.scalar_type()), + "Float8 dtypes are not currenlty supported for NCCL reductions"); + + RECORD_PARAM_COMMS_DATA( + static_cast( + this->getSequenceNumberForGroup() + 1), // seq + 1 to match collective + std::make_tuple(pg_uid_, pg_desc_), // PG name tuple + inputTensors, // inputTensors + outputTensors, // outputTensors + rank_, // rank + "reduce_scatter", // collective name + outputTensor.numel() * this->getSize(), // inNelems + outputTensor.numel(), // outNelems + outputTensor.scalar_type(), // dType + std::vector(), // inSplitSizes + std::vector(), // outSplitSizes + globalRankStart, // globalRankStart + globalRankStride, // globalRankStride + this->getSize()); // worldSize + + bool same_size = check_same_size(inputTensors_); + if (same_size) { + // Flatten a vector of tensors into a single, stacked tensor. + at::Tensor inputFlattened = newLikeFlat(inputTensors_); + + return collective( + inputFlattened, + outputTensor, + [&](at::Tensor& input, + at::Tensor& output, + ncclComm_t comm, + at::cuda::CUDAStream& stream) { + if (!avoidRecordStreams_) { + c10::cuda::CUDACachingAllocator::recordStream( + output.storage().data_ptr(), stream); + } + const auto ncclDataType = getNcclDataType(input.scalar_type()); + const auto ncclReduceOp = + getNcclReduceOp(opts.reduceOp, input, ncclDataType, comm); + return ncclReduceScatter( + input.data_ptr(), + output.data_ptr(), + output.numel(), + ncclDataType, + ncclReduceOp, + comm, + stream.stream()); + }, + [&](at::cuda::CUDAStream& ncclStream, + c10::intrusive_ptr& work) { + if (avoidRecordStreams_) { + // We only need to stash inputTensors. + // - inputFlattened is stashed onto + // work->stashed_for_allocator_safety_ + // in collective(). + // - User-facing outputTensors is stashed onto work->outputs_ in + // collective(), + // and should also be held by the user until after waiting on + // work_. + auto& v = work->stashed_for_allocator_safety_; + v->insert(v->end(), inputTensors_.begin(), inputTensors_.end()); + } + + // Copy the input tensors to the flattened inputs. + at::cuda::CUDAStreamGuard guard(ncclStream); + for (const auto j : c10::irange(inputTensors_.size())) { + // See [Sync Streams]. + if (!avoidRecordStreams_) { + c10::cuda::CUDACachingAllocator::recordStream( + inputTensors_[j].storage().data_ptr(), ncclStream); + } + inputFlattened[j].copy_(inputTensors_[j], true); + } + }, + [&](at::cuda::CUDAStream&, + c10::intrusive_ptr& work) {}, + OpType::REDUCE_SCATTER, + "nccl:reduce_scatter"); + } else { + const auto num_reduces = inputTensors_.size(); + startCoalescing(); + for (const int i : c10::irange(num_reduces)) { + auto& input = inputTensors_[i]; + auto& output = (i == rank_) ? outputTensor : input; + auto reduceOpts = ReduceOptions{ + opts.reduceOp, + static_cast(i), + static_cast(0), + opts.timeout}; + _reduce_oop(output, input, reduceOpts); } - return devices; + auto work = endCoalescing(OpType::REDUCE_SCATTER); + return work; + } +} + +c10::intrusive_ptr ProcessGroupNCCL::_reduce_scatter_base( + at::Tensor& outputTensor, + at::Tensor& inputTensor, + const ReduceScatterOptions& opts) { + if (inputTensor.dtype() != outputTensor.dtype()) { + C10_THROW_ERROR( + TypeError, "input tensor must be the same type as the output tensor."); + } + + if (inputTensor.numel() != outputTensor.numel() * size_) { + C10_THROW_ERROR( + ValueError, + "input tensor must be the same size as output size times world size"); + } + + // @lint-ignore CLANGTIDY + const auto& tensor = outputTensor; + TORCH_CHECK( + !isFloat8Type(tensor.scalar_type()), + "Float8 dtypes are not currenlty supported for NCCL reductions"); + RECORD_PARAM_COMMS_DATA( + static_cast( + this->getSequenceNumberForGroup() + 1), // seq + 1 to match collective + std::make_tuple(pg_uid_, pg_desc_), // PG name tuple + inputTensor, // inputTensor + outputTensor, // outputTensor + rank_, // rank + "_reduce_scatter_base", // collective name + inputTensor.numel(), // inNelems + tensor.numel(), // outNelems + tensor.scalar_type(), // dtype + std::vector(), // inSplitSizes + std::vector(), // outSplitSizes + globalRankStart, // globalRankStart + globalRankStride, // globalRankStride + this->getSize()); // worldSize + + // avoidRecordStreams_ note: collective() will stash inputs and outputs. + // Note 2: for asyncOp = false, we don't want to record streams because we + // know that the NCCL stream will join back to the "current" stream right + // after this op. So we might just as well keep the stream ownership of the + // input/output tensors unchanged. The benefit would be that the + // allocation/free of the tensors would look deterministic to the "current" + // stream so that the caching allocator can reuse memory pool for this stream + // in a clever way. This setting is added for libraries like FSDP which uses + // `reduce_scatter_tensor`. + bool avoidRecordStreams = avoidRecordStreams_ || (!opts.asyncOp); + + return collective( + inputTensor, + outputTensor, + [&](at::Tensor& input, + at::Tensor& output, + ncclComm_t comm, + at::cuda::CUDAStream& stream) { + if (!avoidRecordStreams) { + c10::cuda::CUDACachingAllocator::recordStream( + output.storage().data_ptr(), stream); + } + auto ncclDataType = getNcclDataType(input.scalar_type()); + auto ncclReduceOp = + getNcclReduceOp(opts.reduceOp, input, ncclDataType, comm); + return ncclReduceScatter( + input.data_ptr(), + output.data_ptr(), + output.numel(), + ncclDataType, + ncclReduceOp, + comm, + stream.stream()); + }, + OpType::_REDUCE_SCATTER_BASE, + "nccl:_reduce_scatter_base", + avoidRecordStreams); +} + +c10::intrusive_ptr ProcessGroupNCCL::reduce_scatter_tensor_coalesced( + std::vector& outputs, + std::vector& inputs, + const ReduceScatterOptions& opts) { + TORCH_CHECK( + !isFloat8Type(inputs.back().scalar_type()), + "Float8 dtypes are not currenlty supported for NCCL reductions"); + return collectiveCoalesced( + inputs, + outputs, + [&](at::Tensor& input, + at::Tensor& output, + ncclComm_t comm, + at::cuda::CUDAStream& stream) { + if (!avoidRecordStreams_) { + c10::cuda::CUDACachingAllocator::recordStream( + output.storage().data_ptr(), stream); + } + auto ncclDataType = getNcclDataType(input.scalar_type()); + auto ncclReduceOp = + getNcclReduceOp(opts.reduceOp, input, ncclDataType, comm); + return ncclReduceScatter( + input.data_ptr(), + output.data_ptr(), + output.numel(), + ncclDataType, + ncclReduceOp, + comm, + stream.stream()); + }, + OpType::COALESCED, + "nccl:reduce_scatter_tensor_coalesced"); +} + +c10::intrusive_ptr ProcessGroupNCCL::barrier(const BarrierOptions& opts) { + RECORD_PARAM_COMMS( + static_cast( + this->getSequenceNumberForGroup() + 1), // seq + 1 to match collective + std::make_tuple(pg_uid_, pg_desc_), // PG name tuple + rank_, // rank + "barrier", // collective name + 0, // inNelems + 0, // outNelems + at::kByte, // dType + std::vector(), // inSplitSizes + std::vector(), // outSplitSizes + globalRankStart, // globalRankStart + globalRankStride, // globalRankStride + this->getSize()); // worldSize + + // Device to use for barrier + int barDevIdx = -1; + + // Select device to use for barrier + // 1st choice: Use user defined GPU device ids if provided + if (!opts.device_ids.empty()) { + // Use the first device id because PG NCCL is single-device now + barDevIdx = opts.device_ids[0]; + } else if (getBoundDeviceId()) { + // 2nd choice: Use the bound GPU device id if available. + // Bounded device id can be passed to `init_process_group`. + barDevIdx = (*getBoundDeviceId()).index(); + } else if (!usedDeviceIdxs_.empty()) { + // 3rd choice: infer the device id from the used device ids. + barDevIdx = *usedDeviceIdxs_.begin(); + } else { + // This means there is not yet a NCCL collective being called + // Here we have to use the best guesses and will use a single GPU to call + // allreduce to achieve barrier. + // In case the multiple processes fall into the same node, we use rank to + // ensure that each process is on a different GPU + // Note: it is better to use global rank because the group-local rank can be + // offset wrt the device id if intra-node GPUs are sharded into multiple + // dimensions. + barDevIdx = static_cast(globalRank() % localDeviceCount_); + LOG(WARNING) + << logPrefix() + << c10::str( + " using GPU ", + barDevIdx, + " to perform barrier as devices used by this process are currently unknown. ", + "This can potentially cause a hang if this rank to GPU mapping is incorrect.", + "Specify device_ids in barrier() to force use of a particular device,", + "or call init_process_group() with a device_id."); } - void registerOnCompletionHook( - std::function)>&& hook) { - getDefaultBackend()->registerOnCompletionHook(std::move(hook)); + TORCH_CHECK_WITH( + ValueError, + barDevIdx >= 0, + "Failed to infer a GPU device id to perform barrier. "); + auto barDevice = at::Device(at::DeviceType::CUDA, barDevIdx); + + // Create a dummy tensor on the device + // Note: we use zeros() instead of empty() to prevent barrier from triggering + // alarm when NaN checker is enabled. + at::Tensor barrierTensor = + at::zeros({1}, at::TensorOptions().device(barDevice).dtype(at::kFloat)); + + // All reduce to achieve the barrier + auto work = allreduce_impl(barrierTensor); + + // Work will take over barrierTensors + auto ncclWork = dynamic_cast(work.get()); + TORCH_CHECK(ncclWork); + ncclWork->barrierTensor_ = std::move(barrierTensor); + return work; +} + +c10::intrusive_ptr ProcessGroupNCCL::alltoall_base( + at::Tensor& outputTensor, + at::Tensor& inputTensor, + std::vector& outputSplitSizes, + std::vector& inputSplitSizes, + const AllToAllOptions& /* unused */) { + check_gpu_single_tensor(outputTensor, true); + check_gpu_single_tensor(inputTensor, true); + if (outputSplitSizes.size() == 0 && inputSplitSizes.size() == 0) { + RECORD_PARAM_COMMS_DATA( + static_cast( + this->getSequenceNumberForGroup() + + 1), // seq + 1 to match collective + std::make_tuple(pg_uid_, pg_desc_), // PG name tuple + inputTensor, // inputTensor + outputTensor, // outputTensor + rank_, // rank + "all_to_all", // collective name + inputTensor.numel(), // inNelems + outputTensor.numel(), // outNelems + inputTensor.scalar_type(), // dType + std::vector(), // inSplitSizes + std::vector(), // outSplitSizes + globalRankStart, // globalRankStart + globalRankStride, // globalRankStride + this->getSize()); // worldSize + + // avoidRecordStreams_ note: collective() will stash inputTensors and + // outputTensors. + return collective( + inputTensor, + outputTensor, + [&](at::Tensor& input, + at::Tensor& output, + ncclComm_t comm, + at::cuda::CUDAStream& stream) { + // See [Sync Streams]. + if (!avoidRecordStreams_) { + c10::cuda::CUDACachingAllocator::recordStream( + output.storage().data_ptr(), stream); + } + torch::cuda::nccl::all2all_single_equal_split( + input, output, this->getSize(), comm, stream); + return ncclSuccess; + }, + OpType::ALLTOALL_BASE, + "nccl:all_to_all"); + } else { + c10d::checkSplitSizes(inputSplitSizes, inputTensor, size_); + c10d::checkSplitSizes(outputSplitSizes, outputTensor, size_); + + RECORD_PARAM_COMMS_DATA( + static_cast( + this->getSequenceNumberForGroup() + + 1), // seq + 1 to match collective + std::make_tuple(pg_uid_, pg_desc_), // PG name tuple + inputTensor, // inputTensor + outputTensor, // outputTensor + rank_, // rank + "all_to_allv", // collective name + inputTensor.numel(), // inNelems + outputTensor.numel(), // outNelems + inputTensor.scalar_type(), // dType + inputSplitSizes, // inSplitSizes + outputSplitSizes, // outSplitSizes + globalRankStart, // globalRankStart + globalRankStride, // globalRankStride + this->getSize()); // worldSize + + // avoidRecordStreams_ note: collective() will stash inputTensors and + // outputTensors. + return collective( + inputTensor, + outputTensor, + [&](at::Tensor& input, + at::Tensor& output, + ncclComm_t comm, + at::cuda::CUDAStream& stream) { + std::vector send_lengths(size_); + std::vector recv_lengths(size_); + std::vector send_offsets(size_); + std::vector recv_offsets(size_); + c10d::computeLengthsAndOffsets( + inputSplitSizes, input, &send_lengths, &send_offsets); + c10d::computeLengthsAndOffsets( + outputSplitSizes, output, &recv_lengths, &recv_offsets); + // See [Sync Streams]. + if (!avoidRecordStreams_) { + c10::cuda::CUDACachingAllocator::recordStream( + output.storage().data_ptr(), stream); + } + torch::cuda::nccl::all2all_single_unequal_split( + input.data_ptr(), + send_lengths.data(), + send_offsets.data(), + output.data_ptr(), + recv_lengths.data(), + recv_offsets.data(), + input.element_size(), + input.scalar_type(), + comm, + stream); + return ncclSuccess; + }, + OpType::ALLTOALL_BASE, + "nccl:all_to_all"); } +} - void waitForPendingWorks() { - getDefaultBackend()->waitForPendingWorks(); +c10::intrusive_ptr ProcessGroupNCCL::alltoall( + std::vector& outputTensors, + std::vector& inputTensors, + const AllToAllOptions& /* unused */) { + std::vector inSplitSizes; + std::vector outSplitSizes; + int64_t total_numel = 0; + + auto device = outputTensors[0].device(); + for (const auto r : c10::irange(outputTensors.size())) { + check_gpu_single_tensor(outputTensors[r], true); + check_gpu_single_tensor(inputTensors[r], true); + TORCH_CHECK( + device == outputTensors[r].device() && + device == inputTensors[r].device(), + "Tensors must be on the same device") + inSplitSizes.push_back(inputTensors[r].numel()); + outSplitSizes.push_back(outputTensors[r].numel()); + total_numel += inputTensors[r].numel(); } - bool hasHooks() const { - return getDefaultBackend()->hasHooks(); + RECORD_PARAM_COMMS_DATA( + static_cast( + this->getSequenceNumberForGroup() + 1), // seq + 1 to match collective + std::make_tuple(pg_uid_, pg_desc_), // PG name tuple + inputTensors, // inputTensors + outputTensors, // outputTensors + rank_, // rank + "all_to_all", // collective name + total_numel, // inNelems + total_numel, // outNelems + inputTensors.front().scalar_type(), // dType + inSplitSizes, // inSplitSizes + outSplitSizes, // outSplitSizes + globalRankStart, // globalRankStart + globalRankStride, // globalRankStride + this->getSize()); // worldSize + + return collective( + inputTensors, + outputTensors, + [&](at::Tensor& /* unused */, + at::Tensor& /* unused */, + ncclComm_t comm, + at::cuda::CUDAStream& stream) { + torch::cuda::nccl::all2all(outputTensors, inputTensors, comm, stream); + return ncclSuccess; + }, + [&](at::cuda::CUDAStream&, + c10::intrusive_ptr& work) { + if (avoidRecordStreams_) { + // inputTensor0 and outputTensor0 are stashed redundantly by + // collective(), but that's ok. + auto& v = work->stashed_for_allocator_safety_; + v->insert(v->end(), inputTensors.begin(), inputTensors.end()); + v->insert(v->end(), outputTensors.begin(), outputTensors.end()); + } + }, + [](at::cuda::CUDAStream&, + c10::intrusive_ptr& work) {}, + OpType::ALLTOALL, + "nccl:all_to_all"); +} + +c10::intrusive_ptr ProcessGroupNCCL::send( + std::vector& tensors, + int dstRank, + int /* unused */) { + TORCH_CHECK(tensors.size() == 1, MULTI_DEVICE_ERROR_MSG); + // @lint-ignore CLANGTIDY + auto tensor = tensors.back(); + check_gpu_single_tensor(tensor, true); + + RECORD_PARAM_COMMS_DATA( + static_cast( + this->getSequenceNumberForGroup() + 1), // seq + 1 to match collective + std::make_tuple(pg_uid_, pg_desc_), // PG name tuple + tensors, // inputTensors + tensors, // outputTensors + dstRank, // dst rank + "send", // collective name + tensor.numel(), // inNelems + tensor.numel(), // outNelems + tensor.scalar_type(), // dType + std::vector(), // inSplitSizes + std::vector(), // outSplitSizes + globalRankStart, // globalRankStart + globalRankStride, // globalRankStride + this->getSize()); // worldSize + + auto ret = pointToPoint( + tensor, + [&](at::Tensor& input, + ncclComm_t comm, + at::cuda::CUDAStream& stream, + int dst) { + torch::cuda::nccl::send(input, comm, stream, dst); + return ncclSuccess; + }, + dstRank, + OpType::SEND, + c10::str("nccl:send ", rank_, "->", dstRank).c_str()); + return ret; +} + +c10::intrusive_ptr ProcessGroupNCCL::recv( + std::vector& tensors, + int srcRank, + int /* unused */) { + TORCH_CHECK(tensors.size() == 1, MULTI_DEVICE_ERROR_MSG); + // @lint-ignore CLANGTIDY + auto tensor = tensors.back(); + check_gpu_single_tensor(tensor, true); + + RECORD_PARAM_COMMS_DATA( + static_cast( + this->getSequenceNumberForGroup() + 1), // seq + 1 to match collective + std::make_tuple(pg_uid_, pg_desc_), // PG name tuple + tensors, // inputTensors + tensors, // outputTensors + srcRank, // src rank + "recv", // collective name + tensor.numel(), // inNelems + tensor.numel(), // outNelems + tensor.scalar_type(), // dType + std::vector(), // inSplitSizes + std::vector(), // outSplitSizes + globalRankStart, // globalRankStart + globalRankStride, // globalRankStride + this->getSize()); // worldSize + + auto ret = pointToPoint( + tensor, + [&](at::Tensor& output, + ncclComm_t comm, + at::cuda::CUDAStream& stream, + int src) { + torch::cuda::nccl::recv(output, comm, stream, src); + return ncclSuccess; + }, + srcRank, + OpType::RECV, + c10::str("nccl:recv ", rank_, "<-", srcRank).c_str()); + return ret; +} + +void ProcessGroupNCCL::groupStart() { + C10D_NCCL_CHECK(ncclGroupStart(), std::nullopt); + ++ncclActiveGroupCounter_; +} + +void ProcessGroupNCCL::groupEnd() { + C10D_NCCL_CHECK(ncclGroupEnd(), std::nullopt); + --ncclActiveGroupCounter_; +} + +void ProcessGroupNCCL::groupEndNonblocking(std::shared_ptr comm) { +#ifndef NCCL_HAS_COMM_NONBLOCKING + C10D_NCCL_CHECK(ncclGroupEnd(), std::nullopt); +#else + if (!nccl_use_nonblocking()) { + C10D_NCCL_CHECK(ncclGroupEnd(), std::nullopt); + } else { + C10D_NCCL_CHECK_TIMEOUT_GROUPEND(ncclGroupEnd(), comm, std::nullopt); } +#endif + --ncclActiveGroupCounter_; +} + +c10::intrusive_ptr ProcessGroupNCCL::gather( + std::vector>& outputTensors, + std::vector& inputTensors, + const GatherOptions& opts) { + static auto invalidArgument = [](const std::string& msg) { + C10_THROW_ERROR(ValueError, "ProcessGroupNCCL::gather: " + msg); + }; - const std::string& getGroupName() const; - void setGroupName(const std::string& name); - const std::string& getGroupDesc() const; - void setGroupDesc(const std::string& name); - void enableCollectivesTiming(); + assertRootRank(invalidArgument, opts.rootRank, size_); - void release_resources() override; + TORCH_CHECK(inputTensors.size() == 1, MULTI_DEVICE_ERROR_MSG); + // @lint-ignore CLANGTIDY + auto inputTensor = inputTensors.back(); - // ProcessGroups optionally can be "bound" to a specific device. - // Currently this is only for nccl and allows for some opt-in - // optimizations such as automatic use of ncclCommSplit. The device - // is specified in `init_process_group` and eventually makes it - // here and then down into the actual backend instances. - std::optional getBoundDeviceId() const { - return bound_device_id_; + std::vector outputs; + + if (getRank() == opts.rootRank) { + if (outputTensors.size() != 1) { + std::stringstream ss; + ss << "requires a single-element output list containing a list with " + << getSize() << " tensors."; + invalidArgument(ss.str()); + } else if (outputTensors[0].size() != static_cast(getSize())) { + std::stringstream ss; + ss << "Incorrect output list size " << outputTensors[0].size() + << ". Output list size should be " << getSize() + << ", same as size of the process group."; + invalidArgument(ss.str()); + } + + const auto& options = inputTensor.options(); + const auto& sizes = inputTensor.sizes(); + assertTypeAndSizesMatch(invalidArgument, outputTensors[0], options, sizes); + outputs = outputTensors[0]; + } else { + // if not in the root rank, initialize outputs as empty list + if (outputTensors.size() != 0) { + invalidArgument("requires empty output on non-root"); + } + outputs = {}; + // append a empty tensor to the list, we don't use it but the + // `collective` template function requires it to invoke its function + outputs.emplace_back(); } - void setBoundDeviceId(std::optional device) { - if (device) { - TORCH_CHECK(device->has_index(), "setBoundDeviceId must have an index"); + RECORD_PARAM_COMMS_DATA( + static_cast( + this->getSequenceNumberForGroup() + 1), // seq + 1 to match collective + std::make_tuple(pg_uid_, pg_desc_), // PG name tuple + inputTensors, // inputTensors + outputTensors, // outputTensors + opts.rootRank, // root rank + "gather", // collective name + inputTensor.numel(), // inNelems + inputTensor.numel() * this->getSize(), // outNelems + inputTensor.scalar_type(), // dType + std::vector(), // inSplitSizes + std::vector(), // outSplitSize + globalRankStart, // globalRankStart + globalRankStride, // globalRankStride + this->getSize()); // worldSize + + // avoidRecordStreams_ note: collective() will stash inputTensors and + // outputs, which == outputTensors[0] on the root rank where it matters. + + auto inputs = std::vector{inputTensor}; + return collective( + inputs, + outputs, // just to fit the collective interface + [&](at::Tensor& /* unused */, + at::Tensor& /* unused */, + ncclComm_t comm, + at::cuda::CUDAStream& stream) { + const auto root = opts.rootRank; + if (getRank() == root) { + if (!avoidRecordStreams_) { + for (auto output : outputs) { + c10::cuda::CUDACachingAllocator::recordStream( + output.storage().data_ptr(), stream); + } + } + } + torch::cuda::nccl::gather(inputTensor, outputs, comm, stream, root); + return ncclSuccess; + }, + [](at::cuda::CUDAStream&, + c10::intrusive_ptr& work) {}, + [](at::cuda::CUDAStream&, + c10::intrusive_ptr& work) {}, + OpType::GATHER, + "nccl:gather"); +} + +c10::intrusive_ptr ProcessGroupNCCL::scatter( + std::vector& outputTensors, + std::vector>& inputTensors, + const ScatterOptions& opts) { + static auto invalidArgument = [](const std::string& msg) { + C10_THROW_ERROR(ValueError, "ProcessGroupNCCL::scatter: " + msg); + }; + + assertRootRank(invalidArgument, opts.rootRank, size_); + + TORCH_CHECK(outputTensors.size() == 1, MULTI_DEVICE_ERROR_MSG); + auto outputTensor = outputTensors.back(); + + std::vector inputs; + + if (getRank() == opts.rootRank) { + if (inputTensors.size() != 1) { + std::stringstream ss; + ss << "requires a single-element input list containing a list with " + << getSize() << " tensors."; + invalidArgument(ss.str()); + } else if (inputTensors[0].size() != static_cast(getSize())) { + std::stringstream ss; + ss << "Incorrect input list size " << inputTensors[0].size() + << ". Input list size should be " << getSize() + << ", same as size of the process group."; + invalidArgument(ss.str()); } - bound_device_id_ = device; + + const auto& options = outputTensor.options(); + const auto& sizes = outputTensor.sizes(); + assertTypeAndSizesMatch(invalidArgument, inputTensors[0], options, sizes); + inputs = inputTensors[0]; + } else { + // if not in the root rank, initialize inputTensors as empty place holder + // with an empty list + if (inputTensors.size() != 0) { + invalidArgument("requires empty input on non-root"); + } + inputs = {}; + // append a empty tensor to the list, we don't use it but the + // `collective` template function requires it to invoke its function + inputs.emplace_back(); } - protected: - // Implementations of this interface need to call this to setup - // appropriate logging etc. - void init(); + RECORD_PARAM_COMMS_DATA( + static_cast( + this->getSequenceNumberForGroup() + 1), // seq + 1 to match collective + std::make_tuple(pg_uid_, pg_desc_), // PG name tuple + inputTensors, // inputTensors + outputTensors, // outputTensors + opts.rootRank, // root rank + "scatter", // collective name + outputTensor.numel() * this->getSize(), // inNelems + outputTensor.numel(), // outNelems + outputTensor.scalar_type(), // dType + std::vector(), // inSplitSizes + std::vector(), // outSplitSize + globalRankStart, // globalRankStart + globalRankStride, // globalRankStride + this->getSize()); // worldSize - c10::intrusive_ptr store_; - // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members) - const int rank_; - // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members) - const int size_; - // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members) - BackendType backendType_; - std::string pg_desc_; + // avoidRecordStreams_ note: collective() will stash outputTensors and + // inputs, which == inputTensors[0] on the root rank where it matters. + bool avoidRecordStreams = avoidRecordStreams_ || (!opts.asyncOp); - // Debug level setting. It is parsed once when ProcessGroup is constructed and - // remains the same across use of this process group. - DebugLevel dist_debug_level_{DebugLevel::Off}; + const auto root = opts.rootRank; + bool nanCheck = (rank_ == root); - // Backend classes for this ProcessGroup - std::unordered_set deviceTypes_; - std::unordered_map deviceTypeToBackendType_; - std::unordered_map> - deviceTypeToBackend_; - std::unordered_map> - backendTypeToBackend_; + auto outputs = std::vector{outputTensor}; + return collective( + outputs, + inputs, // just to fit the collective interface + [&](at::Tensor& /* unused */, + at::Tensor& /* unused */, + ncclComm_t comm, + at::cuda::CUDAStream& stream) { + if (getRank() == root) { + if (!avoidRecordStreams) { + for (auto input : inputs) { + c10::cuda::CUDACachingAllocator::recordStream( + input.storage().data_ptr(), stream); + } + } + } + torch::cuda::nccl::scatter(inputs, outputTensor, comm, stream, root); + return ncclSuccess; + }, + [](at::cuda::CUDAStream&, + c10::intrusive_ptr& work) {}, + [](at::cuda::CUDAStream&, + c10::intrusive_ptr& work) {}, + OpType::SCATTER, + "nccl:scatter", + avoidRecordStreams, + nanCheck); +} - std::optional bound_device_id_; -}; +c10::intrusive_ptr ProcessGroupNCCL::recvAnysource( + std::vector& /* unused */, + int /* unused */) { + C10_THROW_ERROR( + NotImplementedError, "ProcessGroupNCCL does not support recvAnysource"); +} + +c10::intrusive_ptr ProcessGroupNCCL::_allgather_base( + at::Tensor& output_tensor, + at::Tensor& input_tensor, + const AllgatherOptions& opts) { + check_gpu_single_tensor(input_tensor); + check_gpu_single_tensor(output_tensor); + + if (input_tensor.dtype() != output_tensor.dtype()) { + C10_THROW_ERROR( + TypeError, "output tensor must have the same type as input tensor"); + } + + if (input_tensor.numel() * size_ != output_tensor.numel()) { + C10_THROW_ERROR( + ValueError, + "output tensor size must be equal to world_size times input tensor size"); + } + + RECORD_PARAM_COMMS_DATA( + static_cast( + this->getSequenceNumberForGroup() + 1), // seq + 1 to match collective + std::make_tuple(pg_uid_, pg_desc_), // PG name tuple + input_tensor, // inputTensors + output_tensor, // outputTensors + rank_, // rank + "_allgather_base", // collective name + input_tensor.numel(), // inNelems + output_tensor.numel(), // outNelems + output_tensor.scalar_type(), // dType + std::vector(), // inSplitSizes + std::vector(), // outSplitSize + globalRankStart, // globalRankStart + globalRankStride, // globalRankStride + this->getSize()); // worldSize + + // avoidRecordStreams_ note: collective() will stash inputs and outputs. + // Note 2: for asyncOp = false, we don't want to record streams because we + // know that the NCCL stream will join back to the "current" stream right + // after this op. So we might just as well keep the stream ownership of the + // input/output tensors unchanged. The benefit would be that the + // allocation/free of the tensors would look deterministic to the "current" + // stream so that the caching allocator can reuse memory pool for this stream + // in a clever way. This setting is added for libraries like FSDP which uses + // `all_gather_into_tensor`. + bool avoidRecordStreams = avoidRecordStreams_ || (!opts.asyncOp); + + return collective( + input_tensor, + output_tensor, + [&](at::Tensor& input, + at::Tensor& output, + ncclComm_t comm, + at::cuda::CUDAStream& stream) { + if (!avoidRecordStreams) { + c10::cuda::CUDACachingAllocator::recordStream( + output.storage().data_ptr(), stream); + } + return ncclAllGather( + input.data_ptr(), + output.data_ptr(), + input.numel(), + getNcclDataType(input.scalar_type()), + comm, + stream.stream()); + }, + OpType::_ALLGATHER_BASE, + "nccl:_all_gather_base", + avoidRecordStreams); +} } // namespace c10d + +#endif // USE_C10D_NCCL \ No newline at end of file diff --git a/torch/csrc/distributed/c10d/init.cpp b/torch/csrc/distributed/c10d/init.cpp index 5bc1b15cfaacb4..eced2e685fcf85 100644 --- a/torch/csrc/distributed/c10d/init.cpp +++ b/torch/csrc/distributed/c10d/init.cpp @@ -1818,7 +1818,8 @@ communication mechanism. py::init< const c10::intrusive_ptr<::c10d::Store>&, int, - int>(), + int, + c10::intrusive_ptr<::c10d::ProcessGroup::Options>>(), py::call_guard()) .def("rank", &::c10d::ProcessGroup::getRank) .def("size", &::c10d::ProcessGroup::getSize) @@ -1828,6 +1829,7 @@ communication mechanism. "_backend_id", &::c10d::ProcessGroup::getBackendID, py::arg("backend_type")) + .def_property_readonly("options", &::c10d::ProcessGroup::getOptions) .def( "broadcast", &::c10d::ProcessGroup::broadcast, @@ -2137,14 +2139,6 @@ communication mechanism. }, py::arg("device"), py::call_guard()) - .def( - "_set_default_backend", - [](const c10::intrusive_ptr<::c10d::ProcessGroup>& self, - const ::c10d::ProcessGroup::BackendType& backendType) { - return self->setDefaultBackend(backendType); - }, - py::arg("backend_type"), - py::call_guard()) .def( "_register_on_completion_hook", [](const c10::intrusive_ptr<::c10d::ProcessGroup>& self, @@ -2248,6 +2242,27 @@ The hook must have the following signature: .value("CUSTOM", ::c10d::ProcessGroup::BackendType::CUSTOM) .export_values(); + // base ProcessGroup::Options binding + auto processGroupOptions = + intrusive_ptr_class_<::c10d::ProcessGroup::Options>( + processGroup, + "Options", + R"( +Base class for all processes group options implementations, such as the nccl +options :class:`~torch.distributed.ProcessGroupNCCL.Options`). +)") + .def( + py::init([](const std::string& backend, + const std::chrono::milliseconds& timeout) { + return c10::make_intrusive<::c10d::ProcessGroup::Options>( + backend, timeout); + }), + py::arg("backend"), + py::arg("timeout") = kProcessGroupDefaultTimeout, + py::call_guard()) + .def_readonly("backend", &::c10d::ProcessGroup::Options::backend) + .def_readwrite("_timeout", &::c10d::ProcessGroup::Options::timeout); + // TODO: The collection definitions handles direct instantiation of // ProcessGroup subclasses (e.g. dist.ProcessGroupGloo). This is not supported // and should be removed once all tests are transitioned @@ -2547,29 +2562,6 @@ The hook must have the following signature: &::c10d::Backend::endCoalescing, py::call_guard()); - // base Backend::Options binding - // TODO: Maybe we can consider how to merge this with - // `DistributedBackendOptions`. - auto backendOptions = - intrusive_ptr_class_<::c10d::Backend::Options>( - backend, - "Options", - R"( -Base class for all backend options implementations, such as the nccl -options :class:`~torch.distributed.ProcessGroupNCCL.Options`). -)") - .def( - py::init([](const std::string& backend, - const std::chrono::milliseconds& timeout) { - return c10::make_intrusive<::c10d::Backend::Options>( - backend, timeout); - }), - py::arg("backend"), - py::arg("timeout") = kProcessGroupDefaultTimeout, - py::call_guard()) - .def_readonly("backend", &::c10d::Backend::Options::backend) - .def_readwrite("_timeout", &::c10d::Backend::Options::timeout); - #ifdef USE_C10D_GLOO static const std::string GLOO_SOCKET_IFNAME_ENV = "GLOO_SOCKET_IFNAME"; @@ -2581,7 +2573,7 @@ options :class:`~torch.distributed.ProcessGroupNCCL.Options`). shared_ptr_class_<::gloo::transport::Device>(processGroupGloo, "Device"); intrusive_ptr_class_<::c10d::ProcessGroupGloo::Options>( - processGroupGloo, "_Options", backendOptions) + processGroupGloo, "_Options", processGroupOptions) .def(py::init<>()) .def_readwrite("_devices", &::c10d::ProcessGroupGloo::Options::devices) .def_readwrite("_threads", &::c10d::ProcessGroupGloo::Options::threads); @@ -2808,7 +2800,7 @@ for details. intrusive_ptr_class_<::c10d::ProcessGroupNCCL::Options>( processGroupNCCL, "Options", - backendOptions, + processGroupOptions, R"( ProcessGroup options for the NCCL backend diff --git a/torch/distributed/device_mesh.py b/torch/distributed/device_mesh.py index f6e77fe166c655..0ad17536ef1832 100644 --- a/torch/distributed/device_mesh.py +++ b/torch/distributed/device_mesh.py @@ -36,7 +36,6 @@ def _init_device_mesh_stub(): else: - from torch._C._distributed_c10d import Backend as C10dBackend from torch.distributed.distributed_c10d import ( _find_pg_by_ranks_and_tag, _get_default_group, @@ -67,7 +66,7 @@ def __init__(self) -> None: self.mesh_stack: List[DeviceMesh] = [] self.child_to_root_mapping: Dict[DeviceMesh, DeviceMesh] = {} self.mesh_dim_group_options: Dict[ - int, Tuple[str, Optional[C10dBackend.Options]] + int, Tuple[str, Optional[ProcessGroup.Options]] ] = {} self.root_to_flatten_mapping: Dict[DeviceMesh, Dict[str, DeviceMesh]] = {} # Record flatten mesh name to its mesh dim index in root mesh. @@ -280,7 +279,7 @@ def _set_mesh_dim_group_options( self, dim: int, backend: str, - pg_options: Optional[C10dBackend.Options] = None, + pg_options: Optional[ProcessGroup.Options] = None, ) -> None: self.mesh_dim_group_options[dim] = (backend, pg_options) diff --git a/torch/distributed/distributed_c10d.py b/torch/distributed/distributed_c10d.py index 2d9357bbd15a44..bf3566338fd3ea 100644 --- a/torch/distributed/distributed_c10d.py +++ b/torch/distributed/distributed_c10d.py @@ -280,7 +280,6 @@ class Backend(str): NCCL: ProcessGroup.BackendType.NCCL, XCCL: ProcessGroup.BackendType.XCCL, UCC: ProcessGroup.BackendType.UCC, - MPI: ProcessGroup.BackendType.MPI, } def __new__(cls, name: str): @@ -1555,7 +1554,7 @@ def init_process_group( backend, store, group_name, - backend_options=pg_options, + pg_options=pg_options, timeout=timeout, device_id=device_id, group_desc="default_pg", @@ -1652,7 +1651,7 @@ def _new_process_group_helper( backend, store, group_name, - backend_options=None, + pg_options=None, timeout=None, pg_tag=None, device_id=None, @@ -1731,17 +1730,11 @@ def _new_process_group_helper( return GroupMember.NON_GROUP_MEMBER, None prefix_store = PrefixStore(f"{group_name}/", store) - # The backend for PG will be set later based on what's inside BackendConfig - # and timeout are set in each backend's option. + base_pg_options = ProcessGroup.Options(backend=str(backend)) + base_pg_options._timeout = timeout pg: ProcessGroup = ProcessGroup( - prefix_store, - group_rank, - group_size, + prefix_store, group_rank, group_size, base_pg_options ) - # Set the default backend when only single backend is passed in. - if "," not in str(backend) and ":" not in str(backend): - assert backend in Backend.backend_type_map, f"Unknown backend type {backend}" - pg._set_default_backend(Backend.backend_type_map[backend]) if device_id: pg.bound_device_id = device_id backend_config = BackendConfig(backend) @@ -1768,8 +1761,8 @@ def _new_process_group_helper( backend_prefix_store, backend_class.rank(), backend_class.size(), + base_pg_options, ) - pg._set_default_backend(backend_type) elif backend_str == Backend.GLOO: # TODO: remove this check after lazy initialization is supported # if pg_options is not None: @@ -1781,30 +1774,28 @@ def _new_process_group_helper( elif backend_str == Backend.NCCL: if not is_nccl_available(): raise RuntimeError("Distributed package doesn't have NCCL built in") - if backend_options is not None: + if pg_options is not None: assert isinstance( - backend_options, ProcessGroupNCCL.Options - ), "Expected backend_options argument to be of type ProcessGroupNCCL.Options" - if backend_options._timeout != timeout: + pg_options, ProcessGroupNCCL.Options + ), "Expected pg_options argument to be of type ProcessGroupNCCL.Options" + if pg_options._timeout != timeout: warnings.warn( - "backend_options._timeout was specified, " + "pg_options._timeout was specified, " "but timeout kwarg has a default value that will always override it. " ) else: - # default backend_options for NCCL - backend_options = ProcessGroupNCCL.Options() - backend_options.is_high_priority_stream = False - backend_options._timeout = timeout + # default pg_options for NCCL + pg_options = ProcessGroupNCCL.Options() + pg_options.is_high_priority_stream = False + pg_options._timeout = timeout if split_from: - backend_options.split_from = split_from - backend_options.split_color = _process_group_color( - global_ranks_in_group - ) - backend_options.global_ranks_in_group = global_ranks_in_group - backend_options.group_name = group_name + pg_options.split_from = split_from + pg_options.split_color = _process_group_color(global_ranks_in_group) + pg_options.global_ranks_in_group = global_ranks_in_group + pg_options.group_name = group_name backend_class = ProcessGroupNCCL( - backend_prefix_store, group_rank, group_size, backend_options + backend_prefix_store, group_rank, group_size, pg_options ) backend_type = ProcessGroup.BackendType.NCCL elif backend_str == Backend.UCC and is_ucc_available(): @@ -1850,7 +1841,7 @@ def _new_process_group_helper( dist_backend_opts.group_id = group_name dist_backend_opts.global_ranks_in_group = global_ranks_in_group - backend_class = creator_fn(dist_backend_opts, backend_options) + backend_class = creator_fn(dist_backend_opts, pg_options) # Set sequence numbers for gloo and nccl backends. if backend_str == Backend.GLOO: @@ -4485,15 +4476,12 @@ def split_group( global_ranks_in_my_group = [parent_group_to_global_ranks[rank] for rank in my_group] prefix_store = PrefixStore(f"{group_name}/", default_store) - # We register the backend after initializing and timeout is set in pg_options. + base_pg_options = ProcessGroup.Options(backend=str(backend)) + base_pg_options._timeout = timeout pg: ProcessGroup = ProcessGroup( - prefix_store, - group_rank, - len(my_group), + prefix_store, group_rank, len(my_group), base_pg_options ) - backend_type = ProcessGroup.BackendType.NCCL pg.bound_device_id = device_id - pg._set_default_backend(backend_type) pg_options._timeout = timeout pg_options.split_from = parent_backend @@ -4503,6 +4491,7 @@ def split_group( backend_class = ProcessGroupNCCL( prefix_store, group_rank, len(my_group), pg_options ) + backend_type = ProcessGroup.BackendType.NCCL backend_class._set_sequence_number_for_group() pg._register_backend(torch.device("cuda"), backend_type, backend_class) @@ -4625,7 +4614,7 @@ def _new_group_with_tag( ranks=None, timeout=None, backend=None, - backend_options=None, + pg_options=None, pg_tag=None, use_local_synchronization=False, group_desc=None, @@ -4700,7 +4689,7 @@ def _new_group_with_tag( backend, default_store, group_name, - backend_options=backend_options, + pg_options=pg_options, timeout=timeout, pg_tag=pg_tag, device_id=device_id, From 8850339e1dec98eecdcda00a884faacfaee32331 Mon Sep 17 00:00:00 2001 From: Laith Sakka Date: Thu, 29 Aug 2024 22:33:26 -0700 Subject: [PATCH 0284/1018] Rewrite `unsafe_remove_auto_functionalized_pass` using `decompose_auto_functionalized` (#134831) `unsafe_remove_auto_functionalized_pass` can be written as using `decompose_auto_functionalized`, this way we do not have to update it each time we do a change to `auto_functionalize` (Ex https://github.com/pytorch/pytorch/pull/134409) , and we avoid duplicate logics implemented in two different ways. Pull Request resolved: https://github.com/pytorch/pytorch/pull/134831 Approved by: https://github.com/zou3519 --- test/export/test_passes.py | 27 ++++--- torch/_inductor/pattern_matcher.py | 8 +- .../_remove_auto_functionalized_pass.py | 76 +++++-------------- 3 files changed, 38 insertions(+), 73 deletions(-) diff --git a/test/export/test_passes.py b/test/export/test_passes.py index 7bcd50a7b40cf1..624b528290c9a8 100644 --- a/test/export/test_passes.py +++ b/test/export/test_passes.py @@ -1166,20 +1166,19 @@ def forward(self, x): x = torch.randn([3, 3]) ep = export(mod, (x,)) inplace_ep = unsafe_remove_auto_functionalized_pass(ep) - - nodes = inplace_ep.graph.nodes - getitems = 0 - for node in nodes: - if node.op == "call_function": - self.assertFalse(node.target is auto_functionalized) - if node.target is operator.getitem: - getitems += 1 - self.assertEqual(getitems, 2) # tuple return of len 2 - - out_specs = inplace_ep.graph_signature.output_specs - self.assertEqual(out_specs[0].arg.name, "b_state") # state - self.assertEqual(out_specs[1].arg.name, "getitem") # tuple return 1 - self.assertEqual(out_specs[2].arg.name, "getitem_1") # tuple return 2 + graph_text = str(inplace_ep.graph) + self.assertExpectedInline( + graph_text, + """\ +graph(): + %b_state : [num_users=2] = placeholder[target=b_state] + %x : [num_users=1] = placeholder[target=x] + %custom_mutator_tuple_default : [num_users=2] = call_function[target=torch.ops.DO_NOT_USE_TEST_ONLY.custom_mutator_tuple.\ +default](args = (%x, %b_state), kwargs = {}) + %getitem_3 : [num_users=1] = call_function[target=operator.getitem](args = (%custom_mutator_tuple_default, 0), kwargs = {}) + %getitem_4 : [num_users=1] = call_function[target=operator.getitem](args = (%custom_mutator_tuple_default, 1), kwargs = {}) + return (b_state, getitem_3, getitem_4)""", + ) @unittest.skipIf(not TEST_CUDA, "requires cuda") def test_move_to_device_pass(self): diff --git a/torch/_inductor/pattern_matcher.py b/torch/_inductor/pattern_matcher.py index 9c78d09ae9ab0b..c31d7aa9b668ff 100644 --- a/torch/_inductor/pattern_matcher.py +++ b/torch/_inductor/pattern_matcher.py @@ -236,9 +236,13 @@ def replace_by_example( replacement graph. """ - from torch._inductor.virtualized import V + from torch._inductor.virtualized import NullHandler, V - context = V.fake_mode if V.fake_mode is not None else contextlib.nullcontext + context = ( + V.fake_mode + if (not isinstance(V.fake_mode, NullHandler) or (V.fake_mode is None)) + else contextlib.nullcontext() + ) with context: if trace_fn is None: diff --git a/torch/export/_remove_auto_functionalized_pass.py b/torch/export/_remove_auto_functionalized_pass.py index 930915f96f9bd1..76e254119458b6 100644 --- a/torch/export/_remove_auto_functionalized_pass.py +++ b/torch/export/_remove_auto_functionalized_pass.py @@ -5,64 +5,21 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -import operator -from typing import List import torch from torch._higher_order_ops.auto_functionalize import ( auto_functionalized, get_mutable_arg_names, ) +from torch._inductor.fx_passes.post_grad import decompose_auto_functionalized from torch.export import ExportedProgram -def _remove_auto_functionalization_from_graph_helper(ep, auto_functionalize_nodes): - # Update every use of the HOP - for node in reversed(auto_functionalize_nodes): - func = node.args[0] - original_kwargs = node.kwargs - assert isinstance(func, torch._ops.OpOverload) - - with ep.graph.inserting_before(node): - # This makes the call_function refer to every arg as a kwarg, this is weird but probably fine? - new_node = ep.graph.call_function(func, kwargs=node.kwargs) - for k, v in node.meta.items(): - new_node.meta[k] = v - - # Replace auto_functionalize(func, args) with just func(args) - node.replace_all_uses_with(new_node) - - mutable_args_names = get_mutable_arg_names(new_node.target) - - # update the users of the auto_func node (the getitem nodes) - for user in list(new_node.users.keys()): - assert user.target == operator.getitem - # getitem corresponding to a mutated input, just replace all uses with the original input - if user.args[1] >= len(func._schema.returns): - assert user.args[1] <= len(func._schema.returns) + len( - mutable_args_names - ) - - # If the result of getitem was used in an output node, update the output spec with the correct name - adjusted_index = user.args[1] - len(func._schema.returns) - original_arg = original_kwargs[mutable_args_names[adjusted_index]] - - # This is a little fragile/implementation dependent, but the order of the mutable args is the same as the order - # of the getitem calls following the HOP. - user.replace_all_uses_with(original_arg) - - if len(func._schema.returns) == 1: - # If the function has 1 return then it will just directly return the - # result -- we don't need a getitem. So we can replace all the - # getitem(auto_functionalized, 0) with just the note itself. - for user in list(new_node.users.keys()): - if user.args[1] == 0: - user.replace_all_uses_with(new_node) - - new_node.meta["val"] = node.meta["val"][: len(func._schema.returns)] - ep.graph.erase_node(node) - - ep.graph.eliminate_dead_code() +def remove_self_clone(graph: torch.fx.Graph): + for node in graph.nodes: + if node.target == torch.ops.aten.copy_.default and node.args[0] == node.args[1]: + node.replace_all_uses_with(node.args[0]) + graph.erase_node(node) def unsafe_remove_auto_functionalized_pass( @@ -73,15 +30,20 @@ def unsafe_remove_auto_functionalized_pass( and modifies the calling EP inplace to have the original mutator op. This pass doesn't perform safety checks to make sure that this inplace mutation is safe. """ - auto_functionalize_nodes: List[torch.fx.Node] = [] - for module in ep.graph_module.modules(): - if not isinstance(module, torch.fx.GraphModule): - continue - for node in ep.graph.nodes: - if node.op == "call_function" and node.target is auto_functionalized: - auto_functionalize_nodes.append(node) with ep.graph_module._set_replace_hook(ep.graph_signature.get_replace_hook()): - _remove_auto_functionalization_from_graph_helper(ep, auto_functionalize_nodes) + for module in ep.graph_module.modules(): + if not isinstance(module, torch.fx.GraphModule): + continue + for node in ep.graph.nodes: + if node.op == "call_function" and node.target is auto_functionalized: + func = node.args[0] + assert isinstance(func, torch._ops.OpOverload) + mutable_args_names = get_mutable_arg_names(func) + # re-inplace everything + node.meta["only_clone_these_tensors"] = [] + decompose_auto_functionalized(ep.graph) + remove_self_clone(ep.graph) + ep.graph.eliminate_dead_code() return ep From 73dca41304d6bcb8da9edc5c338f7fc9daac5a65 Mon Sep 17 00:00:00 2001 From: Zhengxu Chen Date: Fri, 30 Aug 2024 16:47:05 +0000 Subject: [PATCH 0285/1018] [export] Make sure getitem replacement are synced with module call graph. (#134830) Summary: When we are placing nodes in the graph, we should also replace the references in module_call_graph. Test Plan: buck2 run 'fbcode//mode/opt' torchrec/fb/ir/tests:test_serializer -- --filter-regex test_serialize_deserialize_vlea buck2 test 'fbcode//mode/opt' fbcode//torchrec/fb/ir/tests:test_serializer -- --exact 'torchrec/fb/ir/tests:test_serializer - torchrec.fb.ir.tests.test_serializer.TestSerializer: test_serialize_empty_value_vlea' --run-disabled buck2 test 'fbcode//mode/opt' fbcode//torchrec/fb/ir/tests:test_serializer -- --exact 'torchrec/fb/ir/tests:test_serializer - torchrec.fb.ir.tests.test_serializer.TestSerializer: test_deserialized_device_vle' --run-disabled Differential Revision: D62014035 Pull Request resolved: https://github.com/pytorch/pytorch/pull/134830 Approved by: https://github.com/angelayi --- torch/_export/verifier.py | 20 ++++++++++++++++++++ torch/export/exported_program.py | 23 ++++++++++++++++++++--- torch/export/unflatten.py | 17 +++++++---------- 3 files changed, 47 insertions(+), 13 deletions(-) diff --git a/torch/_export/verifier.py b/torch/_export/verifier.py index b272621f8c94b8..68c5bcaae39af6 100644 --- a/torch/_export/verifier.py +++ b/torch/_export/verifier.py @@ -153,6 +153,7 @@ def check_additional(self, gm: GraphModule) -> None: @final def check(self, ep: "ExportedProgram") -> None: self._check_graph_module(ep.graph_module) + _verify_exported_program_module_call_graph(ep) _verify_exported_program_signature(ep) @final @@ -271,6 +272,25 @@ class TrainingIRVerifier(Verifier): dialect = "TRAINING" +def _verify_exported_program_module_call_graph(exported_program) -> None: + module_call_graph = exported_program.module_call_graph + nodes = { + node.name for node in exported_program.graph.nodes + } + for entry in module_call_graph: + if entry.signature is not None: + for arg in entry.signature.inputs: + if arg.name and arg.name not in nodes: + raise SpecViolationError( + f"Input {arg.name} does not exist in the graph." + ) + for arg in entry.signature.outputs: + if arg.name and arg.name not in nodes: + raise SpecViolationError( + f"Output {arg.name} does not exist in the graph." + ) + + def _verify_exported_program_signature(exported_program) -> None: # Check ExportedProgram signature matches gs = exported_program.graph_signature diff --git a/torch/export/exported_program.py b/torch/export/exported_program.py index 447c12e864d5fb..cf547eb3f7e6e0 100644 --- a/torch/export/exported_program.py +++ b/torch/export/exported_program.py @@ -83,6 +83,14 @@ class ModuleCallSignature: in_spec: pytree.TreeSpec out_spec: pytree.TreeSpec + def replace_all_uses_with(self, original_node, new_node): + for i in self.inputs: + if i.name == original_node.name: + i.name = new_node.name + for o in self.outputs: + if o.name == original_node.name: + o.name = new_node.name + @dataclasses.dataclass class ModuleCallEntry: @@ -655,7 +663,9 @@ def update_arg(old_arg, new_ph): return gm, new_graph_signature -def _common_getitem_elimination_pass(gm: torch.fx.GraphModule, graph_signature): +def _common_getitem_elimination_pass( + gm: torch.fx.GraphModule, graph_signature, module_call_graph +): with gm._set_replace_hook(graph_signature.get_replace_hook()): for module in gm.modules(): if not isinstance(module, torch.fx.GraphModule): @@ -663,12 +673,17 @@ def _common_getitem_elimination_pass(gm: torch.fx.GraphModule, graph_signature): node_id: Dict[torch.fx.Node, str] = {} getitems: Dict[str, torch.fx.Node] = {} - for node in module.graph.nodes: + for node in list(module.graph.nodes): if node.op == "call_function" and node.target == operator.getitem: source, idx = node.args new_id = f"{node_id[source]}.{idx}" if new_id in getitems: node.replace_all_uses_with(getitems[new_id]) + for entry in module_call_graph: + if entry.signature is not None: + entry.signature.replace_all_uses_with( + node, getitems[new_id] + ) module.graph.erase_node(node) else: getitems[new_id] = node @@ -751,7 +766,9 @@ def __init__( if isinstance(root, torch.fx.GraphModule): self._graph_module.meta.update(root.meta) - _common_getitem_elimination_pass(self._graph_module, graph_signature) + _common_getitem_elimination_pass( + self._graph_module, graph_signature, module_call_graph + ) self._graph_signature: ExportGraphSignature = graph_signature self._state_dict: Dict[str, Any] = state_dict self._range_constraints: Dict[sympy.Symbol, ValueRanges] = range_constraints diff --git a/torch/export/unflatten.py b/torch/export/unflatten.py index 45f7b992827e79..6c0e75e4e0076a 100644 --- a/torch/export/unflatten.py +++ b/torch/export/unflatten.py @@ -407,6 +407,7 @@ def check_module_inputs(module, scope): assert [fqn for fqn, _ in self.named_modules(remove_duplicate=False)] == list( fqn_order.keys() ) + self.graph.lint() def _print_graph(self): for fqn, mod in self.named_modules(): @@ -838,14 +839,8 @@ def copy_sym_call_function(self, x): # To avoid this we copy these call_function nodes with sym_type results. # This should however only be done for sym_type nodes - call_function nodes on tensors # should not be deduplicated in the first place. - args = tuple( - self.remap_input(_x) if isinstance(_x, torch.fx.Node) else _x - for _x in x.args - ) - kwargs = { - k: self.remap_input(_x) if isinstance(_x, torch.fx.Node) else _x - for k, _x in x.kwargs.items() - } + args = pytree.tree_map_only(torch.fx.Node, self.remap_input, x.args) + kwargs = pytree.tree_map_only(torch.fx.Node, self.remap_input, x.kwargs) node = self.graph.call_function(x.target, args, kwargs) node.meta = copy.copy(x.meta) self.node_map[x] = node @@ -870,7 +865,10 @@ def remap_input(self, x): with self.parent.graph.inserting_before(self.parent_call_module): self.parent_call_module.insert_arg(0, self.parent.remap_input(x)) return self.node_to_placeholder[x] - elif x.op == "call_function": + elif x.op == "call_function" and ( + x.target == torch.ops.aten.sym_size.int + or (hasattr(x.target, "__module__") and x.target.__module__ == "_operator") + ): # export deduplicates sym_size nodes, and may need to re-copy them # if module call signature needs to be preserved self.copy_sym_call_function(x) @@ -879,7 +877,6 @@ def remap_input(self, x): raise RuntimeError( f"Could not run remap_input() on op type: {x.op} for node {x}" ) - return self.node_to_placeholder[x] def finalize_outputs(self): orig_outputs = [] From 3935d4828572c53533e1a157873fbe6db762ba11 Mon Sep 17 00:00:00 2001 From: Yifu Wang Date: Fri, 30 Aug 2024 14:52:40 +0000 Subject: [PATCH 0286/1018] Add support for 32KB multi_tensor_apply kernel arguments (#134373) ## Benchmark On H100 SXM (HBM2e, 500W TDP), CUDA Toolkit=12.2, Driver Version=535.154.05, with [this script](https://gist.github.com/yifuwang/178c1f4bf951c5794ea79c04d90e44fa) (`torch._foreach_copy_`): **Baseline** ``` https://interncache-all.fbcdn.net/manifold/perfetto-artifacts/tree/ui/index.html#!/?url=https://interncache-all.fbcdn.net/manifold/perfetto_internal_traces/tree/shared_trace/yifu_tmp0g_x4sys device ms: 0.891, cpu ms: 7.200 memory bandwidth: 1457.727 GB/s ``` Single iteration trace: image **This PR** ``` https://interncache-all.fbcdn.net/manifold/perfetto-artifacts/tree/ui/index.html#!/?url=https://interncache-all.fbcdn.net/manifold/perfetto_internal_traces/tree/shared_trace/yifu_tmp3jqiugli device ms: 0.683, cpu ms: 6.745 memory bandwidth: 1902.010 GB/s ``` Single iteration trace: image ## Binary Size and Kernel Specialization The binary size for `libtorch_cuda.so` increased 6MB (243MB -> 249MB). ``` // NOTE: [32KB kernel argument size support] // 32KB kernel argument size support has three requirements: // - CUDART_VERSION >= 12010 // - Driver version >= 530 // - GPU arch >= VOLTA // // Due to minor version compatibility, it possible for binaries built with // CUDART_VERSION >= 12010 to run with driver version < 530. Since driver // version can only be checked at runtime, if CUDART_VERSION >= 12010, we have // to build both 4KB and 32KB kernels and determine the appropriate kernel to // dispatch at runtime. // // - If CUDART_VERSION < 12010, only 4KB kernels will be instantiated. // // - If CUDART_VERSION >= 12010: // - Host code: // - We always instantiate the launching stub for both 4KB and 32KB kernels. // - Device code: // - If __CUDA_ARCH__ >= 700, we always instantiate both 4KB and 32KB // kernels. // - If __CUDA_ARCH__ < 700, it's not possible to even compile an empty // 32KB kernel (formal parameter space overflowed). Thus, we only // instantiate a declaration for 32KB kernels. This is valid as long as the // declaration-only kernel is not launched. // // - At runtime, we dispatch to the 32KB kernel if driver version >= 530 and // GPU arch >= VOLTA. // // - TODO(yifu): once there's a CUDART version that is not compatible with any // driver version below 530, we can determine at compile time to not compile // the kernels for 4KB kernel argument size. // // https://developer.nvidia.com/blog/cuda-12-1-supports-large-kernel-parameters/ ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/134373 Approved by: https://github.com/eqy, https://github.com/crcrpar, https://github.com/janeyx99 --- aten/src/ATen/native/cuda/AmpKernels.cu | 33 +-- .../ATen/native/cuda/ForeachBinaryOpList.cu | 88 ++++--- .../ATen/native/cuda/ForeachBinaryOpScalar.cu | 42 ++-- .../native/cuda/ForeachBinaryOpScalarList.cu | 44 ++-- .../cuda/ForeachBinaryOpScalarTensor.cu | 46 ++-- aten/src/ATen/native/cuda/ForeachFunctors.cuh | 117 ++++++--- .../ATen/native/cuda/ForeachPointwiseOp.cu | 84 ++++--- aten/src/ATen/native/cuda/ForeachReduceOp.cu | 81 ++++-- aten/src/ATen/native/cuda/ForeachTernaryOp.cu | 80 +++--- aten/src/ATen/native/cuda/ForeachUnaryOp.cu | 55 ++-- aten/src/ATen/native/cuda/FusedSgdKernel.cu | 117 +++++---- .../src/ATen/native/cuda/MultiTensorApply.cpp | 25 ++ .../src/ATen/native/cuda/MultiTensorApply.cuh | 235 +++++++++++++----- .../native/cuda/fused_adam_amsgrad_impl.cu | 66 +++-- aten/src/ATen/native/cuda/fused_adam_impl.cu | 66 +++-- .../src/ATen/native/cuda/fused_adam_utils.cuh | 10 +- .../native/cuda/fused_adamw_amsgrad_impl.cu | 66 +++-- aten/src/ATen/native/cuda/fused_adamw_impl.cu | 66 +++-- build_variables.bzl | 1 + 19 files changed, 852 insertions(+), 470 deletions(-) create mode 100644 aten/src/ATen/native/cuda/MultiTensorApply.cpp diff --git a/aten/src/ATen/native/cuda/AmpKernels.cu b/aten/src/ATen/native/cuda/AmpKernels.cu index 8c161ca627229d..07cb84f95924ec 100644 --- a/aten/src/ATen/native/cuda/AmpKernels.cu +++ b/aten/src/ATen/native/cuda/AmpKernels.cu @@ -157,21 +157,24 @@ void _amp_foreach_non_finite_check_and_unscale_cuda_(TensorList scaled_grads, using opmath_t = at::opmath_type; // multi_tensor_apply guards onto tensor_lists[0][0], no need to guard explicitly. - multi_tensor_apply<1>(tensor_lists, - UnaryOpFunctor(), - [found_inf_ptr, inv_scale_ptr] GPU_LAMBDA (opmath_t val) -> opmath_t { - // There is a slight asymmetry here with the TensorIterator kernel above. - // MTA Functors ensure val comes in as opmath_t rather than scalar_t. - if (!isfinite_ensure_cuda_math(val)) { - *found_inf_ptr = 1.f; - } - // Every thread accesses inv_scale, but it will hit in cache. - const auto inv_scale_val = *inv_scale_ptr; - return static_cast(inv_scale_val == 1.f ? val : val * inv_scale_val); - }); + DISPATCH_MULTI_TENSOR_APPLY([&]() { + multi_tensor_apply<1>(tensor_lists, + UnaryOpFunctor(), + [found_inf_ptr, inv_scale_ptr] GPU_LAMBDA (opmath_t val) -> opmath_t { + // There is a slight asymmetry here with the TensorIterator kernel above. + // MTA Functors ensure val comes in as opmath_t rather than scalar_t. + if (!isfinite_ensure_cuda_math(val)) { + *found_inf_ptr = 1.f; + } + // Every thread accesses inv_scale, but it will hit in cache. + const auto inv_scale_val = *inv_scale_ptr; + return static_cast(inv_scale_val == 1.f ? val : val * inv_scale_val); + }); + }); }); } diff --git a/aten/src/ATen/native/cuda/ForeachBinaryOpList.cu b/aten/src/ATen/native/cuda/ForeachBinaryOpList.cu index 533aa38c04cf5e..90c180cba3119d 100644 --- a/aten/src/ATen/native/cuda/ForeachBinaryOpList.cu +++ b/aten/src/ATen/native/cuda/ForeachBinaryOpList.cu @@ -41,15 +41,18 @@ std::vector foreach_tensor_list_op( tensor_lists.emplace_back(std::move(vec_res)); using opmath_t = at::opmath_type; - multi_tensor_apply<3>( - tensor_lists, - BinaryOpListAlphaFunctor< - T, - /* depth */ 3, - /* r_args_depth */ 2, - /* res_arg_index */ 2>(), - Op(), - alpha.to()); + DISPATCH_MULTI_TENSOR_APPLY([&]() { + multi_tensor_apply<3>( + tensor_lists, + BinaryOpListAlphaFunctor< + T, + /* depth */ 3, + /* r_args_depth */ 2, + /* res_arg_index */ 2, + large_kernel_arg>(), + Op(), + alpha.to()); + }); return tensor_lists[2]; } @@ -64,15 +67,18 @@ void foreach_tensor_list_op_( tensor_lists.emplace_back(tensors2.vec()); using opmath_t = at::opmath_type; - multi_tensor_apply<2>( - tensor_lists, - BinaryOpListAlphaFunctor< - T, - /* depth */ 2, - /* r_args_depth */ 2, - /* res_arg_index */ 0>(), - Op(), - alpha.to()); + DISPATCH_MULTI_TENSOR_APPLY([&]() { + multi_tensor_apply<2>( + tensor_lists, + BinaryOpListAlphaFunctor< + T, + /* depth */ 2, + /* r_args_depth */ 2, + /* res_arg_index */ 0, + large_kernel_arg>(), + Op(), + alpha.to()); + }); increment_version(tensors1); } @@ -331,13 +337,15 @@ template < typename src_t, int depth, int r_args_depth, - int res_arg_index> + int res_arg_index, + bool large_kernel_arg> struct CopyFunctor { + static constexpr bool use_large_kernel_arg = large_kernel_arg; static_assert(depth == 2 && r_args_depth == 1 && res_arg_index == 1); template __device__ __forceinline__ void operator()( int chunk_size, - TensorListMetadata& tl, + TensorListMetadata& tl, Op op) { const auto tensor_loc = tl.block_to_tensor[blockIdx.x]; const auto chunk_idx = tl.block_to_chunk[blockIdx.x]; @@ -420,14 +428,17 @@ void foreach_tensor_copy_list_kernel_cuda_( using opmath_t = at::opmath_type; AT_DISPATCH_SOURCE_TYPES(src[0].scalar_type(), "foreach_tensor_copy", [&] { if constexpr (std::is_same_v) { - multi_tensor_apply<2>( - tensor_lists, - UnaryOpFunctor< - scalar_t, - /* depth */ 2, - /* r_args_depth */ 1, - /* res_arg_index */ 1>(), - Copy()); + DISPATCH_MULTI_TENSOR_APPLY([&]() { + multi_tensor_apply<2>( + tensor_lists, + UnaryOpFunctor< + scalar_t, + /* depth */ 2, + /* r_args_depth */ 1, + /* res_arg_index */ 1, + large_kernel_arg>(), + Copy()); + }); } else { // Ref: // https://github.com/pytorch/pytorch/blob/656134c38f4737d13c3f43fc5c59470bc23c1d2f/aten/src/ATen/native/Copy.cpp#L299-L301 @@ -435,15 +446,18 @@ void foreach_tensor_copy_list_kernel_cuda_( TORCH_WARN_ONCE( "Casting complex values to real discards the imaginary part"); } - multi_tensor_apply<2>( - tensor_lists, - CopyFunctor< - scalar_t, - src_t, - /* depth */ 2, - /* r_args_depth */ 1, - /* res_arg_index */ 1>(), - Copy()); + DISPATCH_MULTI_TENSOR_APPLY([&]() { + multi_tensor_apply<2>( + tensor_lists, + CopyFunctor< + scalar_t, + src_t, + /* depth */ 2, + /* r_args_depth */ 1, + /* res_arg_index */ 1, + large_kernel_arg>(), + Copy()); + }); } }); }); diff --git a/aten/src/ATen/native/cuda/ForeachBinaryOpScalar.cu b/aten/src/ATen/native/cuda/ForeachBinaryOpScalar.cu index 80d748dd3579b5..1b0045daec09c2 100644 --- a/aten/src/ATen/native/cuda/ForeachBinaryOpScalar.cu +++ b/aten/src/ATen/native/cuda/ForeachBinaryOpScalar.cu @@ -36,15 +36,18 @@ std::vector foreach_binary_op( tensor_lists.emplace_back(std::move(vec_res)); using opmath_t = at::opmath_type; - multi_tensor_apply<2>( - tensor_lists, - BinaryOpScalarFunctor< - T, - /* depth */ 2, - /* r_args_depth */ 1, - /* res_arg_index */ 1>(), - Op(), - scalar.to()); + DISPATCH_MULTI_TENSOR_APPLY([&]() { + multi_tensor_apply<2>( + tensor_lists, + BinaryOpScalarFunctor< + T, + /* depth */ 2, + /* r_args_depth */ 1, + /* res_arg_index */ 1, + large_kernel_arg>(), + Op(), + scalar.to()); + }); return tensor_lists[1]; } @@ -54,15 +57,18 @@ void foreach_binary_op_(TensorList tensors, const Scalar& scalar) { tensor_lists.emplace_back(tensors.vec()); using opmath_t = at::opmath_type; - multi_tensor_apply<1>( - tensor_lists, - BinaryOpScalarFunctor< - T, - /* depth */ 1, - /* r_args_depth */ 1, - /* res_arg_index */ 0>(), - Op(), - scalar.to()); + DISPATCH_MULTI_TENSOR_APPLY([&]() { + multi_tensor_apply<1>( + tensor_lists, + BinaryOpScalarFunctor< + T, + /* depth */ 1, + /* r_args_depth */ 1, + /* res_arg_index */ 0, + large_kernel_arg>(), + Op(), + scalar.to()); + }); increment_version(tensors); } diff --git a/aten/src/ATen/native/cuda/ForeachBinaryOpScalarList.cu b/aten/src/ATen/native/cuda/ForeachBinaryOpScalarList.cu index dcb93188b5e690..c423006719294d 100644 --- a/aten/src/ATen/native/cuda/ForeachBinaryOpScalarList.cu +++ b/aten/src/ATen/native/cuda/ForeachBinaryOpScalarList.cu @@ -36,16 +36,19 @@ std::vector foreach_binary_op( tensor_lists.emplace_back(vec_res); using opmath_t = at::opmath_type; - multi_tensor_apply<2, opmath_t>( - tensor_lists, - scalars, - BinaryOpScalarListFunctor< - T, - /* depth */ 2, - /* r_args_depth */ 1, - /* res_arg_index */ 1>(), - - Op()); + DISPATCH_MULTI_TENSOR_APPLY([&]() { + multi_tensor_apply<2, opmath_t>( + tensor_lists, + scalars, + BinaryOpScalarListFunctor< + T, + /* depth */ 2, + /* r_args_depth */ 1, + /* res_arg_index */ 1, + large_kernel_arg>(), + + Op()); + }); return tensor_lists[1]; } @@ -55,15 +58,18 @@ void foreach_binary_op_(TensorList tensors, at::ArrayRef scalars) { tensor_lists.emplace_back(tensors.vec()); using opmath_t = at::opmath_type; - multi_tensor_apply<1, opmath_t>( - tensor_lists, - scalars, - BinaryOpScalarListFunctor< - T, - /* depth */ 1, - /* r_args_depth */ 1, - /* res_arg_index */ 0>(), - Op()); + DISPATCH_MULTI_TENSOR_APPLY([&]() { + multi_tensor_apply<1, opmath_t>( + tensor_lists, + scalars, + BinaryOpScalarListFunctor< + T, + /* depth */ 1, + /* r_args_depth */ 1, + /* res_arg_index */ 0, + large_kernel_arg>(), + Op()); + }); increment_version(tensors); } diff --git a/aten/src/ATen/native/cuda/ForeachBinaryOpScalarTensor.cu b/aten/src/ATen/native/cuda/ForeachBinaryOpScalarTensor.cu index ad5eeee5ebec41..4163b30b2889eb 100644 --- a/aten/src/ATen/native/cuda/ForeachBinaryOpScalarTensor.cu +++ b/aten/src/ATen/native/cuda/ForeachBinaryOpScalarTensor.cu @@ -46,16 +46,19 @@ std::vector foreach_binary_op( tensor_lists.emplace_back(std::move(vec_res)); using opmath_t = at::opmath_type; - multi_tensor_apply<2>( - tensor_lists, - BinaryOpScalarTensorFunctor< - T, - /* depth */ 2, - /* r_args_depth */ 1, - /* res_arg_index */ 1>(), - Op(), - scalar.data_ptr(), - alpha.to()); + DISPATCH_MULTI_TENSOR_APPLY([&]() { + multi_tensor_apply<2>( + tensor_lists, + BinaryOpScalarTensorFunctor< + T, + /* depth */ 2, + /* r_args_depth */ 1, + /* res_arg_index */ 1, + large_kernel_arg>(), + Op(), + scalar.data_ptr(), + alpha.to()); + }); return tensor_lists[1]; } @@ -81,16 +84,19 @@ void foreach_binary_op_( tensor_lists.emplace_back(tensors.vec()); using opmath_t = at::opmath_type; - multi_tensor_apply<1>( - tensor_lists, - BinaryOpScalarTensorFunctor< - T, - /* depth */ 1, - /* r_args_depth */ 1, - /* res_arg_index */ 0>(), - Op(), - scalar.data_ptr(), - alpha.to()); + DISPATCH_MULTI_TENSOR_APPLY([&]() { + multi_tensor_apply<1>( + tensor_lists, + BinaryOpScalarTensorFunctor< + T, + /* depth */ 1, + /* r_args_depth */ 1, + /* res_arg_index */ 0, + large_kernel_arg>(), + Op(), + scalar.data_ptr(), + alpha.to()); + }); increment_version(tensors); } diff --git a/aten/src/ATen/native/cuda/ForeachFunctors.cuh b/aten/src/ATen/native/cuda/ForeachFunctors.cuh index 55e4fd7a598907..fb8d8492b039e0 100644 --- a/aten/src/ATen/native/cuda/ForeachFunctors.cuh +++ b/aten/src/ATen/native/cuda/ForeachFunctors.cuh @@ -18,10 +18,10 @@ inline void increment_version(TensorList tensors) { } // Initializes args and checks if all args are aligned -template +template __device__ bool init_args( T** args, - TensorListMetadata& tl, + TensorListMetadata& tl, const int64_t chunk_idx, const int64_t chunk_size, const int64_t tensor_loc) { @@ -38,10 +38,10 @@ __device__ bool init_args( } // Initializes args and checks if all args are aligned -template +template __device__ bool init_args( T** args, - TensorListScalarListMetadata& tl, + TensorListScalarListMetadata& tl, const int64_t chunk_idx, const int64_t chunk_size, const int64_t tensor_loc) { @@ -57,10 +57,10 @@ __device__ bool init_args( return all_aligned; } -template +template __device__ bool init_args( T** args, - FusedOptimizerTensorListMetadata& tl, + FusedOptimizerTensorListMetadata& tl, const int64_t chunk_idx, const int64_t chunk_size, const int64_t tensor_loc) { @@ -203,13 +203,19 @@ __device__ __forceinline__ void pointwise_op_scalar( // // Binary Functors // -template +template < + typename T, + int depth, + int r_args_depth, + int res_arg_index, + bool large_kernel_arg> struct BinaryOpScalarFunctor { + static constexpr bool use_large_kernel_arg = large_kernel_arg; using opmath_t = at::opmath_type; template __device__ __forceinline__ void operator()( int chunk_size, - TensorListMetadata& tl, + TensorListMetadata& tl, Op op, opmath_t scalar) { const int tensor_loc = tl.block_to_tensor[blockIdx.x]; @@ -227,13 +233,19 @@ struct BinaryOpScalarFunctor { } }; -template +template < + typename T, + int depth, + int r_args_depth, + int res_arg_index, + bool large_kernel_arg> struct BinaryOpScalarListFunctor { + static constexpr bool use_large_kernel_arg = large_kernel_arg; using opmath_t = at::opmath_type; template __device__ __forceinline__ void operator()( int chunk_size, - TensorListScalarListMetadata& tl, + TensorListScalarListMetadata& tl, Op op) { const auto tensor_loc = tl.block_to_tensor[blockIdx.x]; const auto chunk_idx = tl.block_to_chunk[blockIdx.x]; @@ -251,13 +263,19 @@ struct BinaryOpScalarListFunctor { } }; -template +template < + typename T, + int depth, + int r_args_depth, + int res_arg_index, + bool large_kernel_arg> struct BinaryOpListAlphaFunctor { + static constexpr bool use_large_kernel_arg = large_kernel_arg; using opmath_t = at::opmath_type; template __device__ __forceinline__ void operator()( int chunk_size, - TensorListMetadata& tl, + TensorListMetadata& tl, Op op, opmath_t alpha) { const auto tensor_loc = tl.block_to_tensor[blockIdx.x]; @@ -303,13 +321,19 @@ struct BinaryOpListAlphaFunctor { } }; -template +template < + typename T, + int depth, + int r_args_depth, + int res_arg_index, + bool large_kernel_arg> struct BinaryOpScalarTensorFunctor { + static constexpr bool use_large_kernel_arg = large_kernel_arg; using opmath_t = at::opmath_type; template __device__ __forceinline__ void operator()( int chunk_size, - TensorListMetadata& tl, + TensorListMetadata& tl, Op op, T* scalar, opmath_t alpha) { @@ -361,11 +385,17 @@ struct BinaryOpScalarTensorFunctor { // Unary Functors // -template +template < + typename T, + int depth, + int r_args_depth, + int res_arg_index, + bool large_kernel_arg> struct ZeroFunctor { + static constexpr bool use_large_kernel_arg = large_kernel_arg; __device__ __forceinline__ void operator()( int chunk_size, - TensorListMetadata<1>& tl) { + TensorListMetadata<1, large_kernel_arg>& tl) { const auto tensor_loc = tl.block_to_tensor[blockIdx.x]; const auto chunk_idx = tl.block_to_chunk[blockIdx.x]; auto n = tl.numel_for_tensor[tensor_loc]; @@ -401,13 +431,19 @@ struct ZeroFunctor { } }; -template +template < + typename T, + int depth, + int r_args_depth, + int res_arg_index, + bool large_kernel_arg> struct UnaryOpFunctor { + static constexpr bool use_large_kernel_arg = large_kernel_arg; using opmath_t = at::opmath_type; template __device__ __forceinline__ void operator()( int chunk_size, - TensorListMetadata& tl, + TensorListMetadata& tl, Op op) { const auto tensor_loc = tl.block_to_tensor[blockIdx.x]; const auto chunk_idx = tl.block_to_chunk[blockIdx.x]; @@ -453,13 +489,19 @@ struct UnaryOpFunctor { // Pointwise Functors // -template +template < + typename T, + int depth, + int r_args_depth, + int res_arg_index, + bool large_kernel_arg> struct PointwiseOpScalarFunctor { + static constexpr bool use_large_kernel_arg = large_kernel_arg; using opmath_t = at::opmath_type; template __device__ __forceinline__ void operator()( int chunk_size, - TensorListMetadata& tl, + TensorListMetadata& tl, Op op, opmath_t scalar) { const auto tensor_loc = tl.block_to_tensor[blockIdx.x]; @@ -477,13 +519,19 @@ struct PointwiseOpScalarFunctor { } }; -template +template < + typename T, + int depth, + int r_args_depth, + int res_arg_index, + bool large_kernel_arg> struct PointwiseOpScalarListFunctor { + static constexpr bool use_large_kernel_arg = large_kernel_arg; using opmath_t = at::opmath_type; template __device__ __forceinline__ void operator()( int chunk_size, - TensorListScalarListMetadata& tl, + TensorListScalarListMetadata& tl, Op op) { const auto tensor_loc = tl.block_to_tensor[blockIdx.x]; const auto chunk_idx = tl.block_to_chunk[blockIdx.x]; @@ -501,13 +549,14 @@ struct PointwiseOpScalarListFunctor { } }; -template +template struct PointwiseOpListFunctor { + static constexpr bool use_large_kernel_arg = large_kernel_arg; using opmath_t = at::opmath_type; template __device__ __forceinline__ void operator()( int chunk_size, - TensorListMetadata& tl, + TensorListMetadata& tl, Op op) { const auto tensor_loc = tl.block_to_tensor[blockIdx.x]; const auto chunk_idx = tl.block_to_chunk[blockIdx.x]; @@ -552,13 +601,19 @@ struct PointwiseOpListFunctor { } }; -template +template < + typename T, + int depth, + int r_args_depth, + int res_arg_index, + bool large_kernel_arg> struct TernaryOpListFunctor { + static constexpr bool use_large_kernel_arg = large_kernel_arg; using opmath_t = at::opmath_type; template __device__ __forceinline__ void operator()( int chunk_size, - TensorListMetadata& tl, + TensorListMetadata& tl, Op op) { static_assert(depth == 3 || depth == 4, ""); static_assert(depth >= r_args_depth, ""); @@ -606,13 +661,19 @@ struct TernaryOpListFunctor { } }; -template +template < + typename T, + int depth, + int r_args_depth, + int res_arg_index, + bool large_kernel_arg> struct TernaryOpScalarFunctor { + static constexpr bool use_large_kernel_arg = large_kernel_arg; using opmath_t = at::opmath_type; template __device__ __forceinline__ void operator()( int chunk_size, - TensorListMetadata& tl, + TensorListMetadata& tl, Op op, opmath_t alpha) { static_assert(depth == 2 || depth == 3, ""); diff --git a/aten/src/ATen/native/cuda/ForeachPointwiseOp.cu b/aten/src/ATen/native/cuda/ForeachPointwiseOp.cu index 7a3276c44750a6..bad152c322878b 100644 --- a/aten/src/ATen/native/cuda/ForeachPointwiseOp.cu +++ b/aten/src/ATen/native/cuda/ForeachPointwiseOp.cu @@ -46,15 +46,18 @@ std::vector foreach_pointwise_op( "foreach_pointwise_op_cuda", [&]() { using opmath_t = at::opmath_type; - multi_tensor_apply<4>( - tensor_lists, - PointwiseOpScalarFunctor< - scalar_t, - /* depth */ 4, - /* r_args_depth */ 3, - /* res_arg_index */ 3>(), - Op(), - scalar.to()); + DISPATCH_MULTI_TENSOR_APPLY([&]() { + multi_tensor_apply<4>( + tensor_lists, + PointwiseOpScalarFunctor< + scalar_t, + /* depth */ 4, + /* r_args_depth */ 3, + /* res_arg_index */ 3, + large_kernel_arg>(), + Op(), + scalar.to()); + }); }); return tensor_lists[3]; @@ -78,15 +81,18 @@ void foreach_pointwise_op_( "foreach_pointwise_op__cuda", [&]() { using opmath_t = at::opmath_type; - multi_tensor_apply<3>( - tensor_lists, - PointwiseOpScalarFunctor< - scalar_t, - /* depth */ 3, - /* r_args_depth */ 3, - /* res_arg_index */ 0>(), - Op(), - scalar.to()); + DISPATCH_MULTI_TENSOR_APPLY([&]() { + multi_tensor_apply<3>( + tensor_lists, + PointwiseOpScalarFunctor< + scalar_t, + /* depth */ 3, + /* r_args_depth */ 3, + /* res_arg_index */ 0, + large_kernel_arg>(), + Op(), + scalar.to()); + }); }); increment_version(input); } @@ -110,15 +116,18 @@ void foreach_pointwise_op_( "foreach_pointwise_op__cuda", [&]() { using opmath_t = at::opmath_type; - multi_tensor_apply<3, opmath_t>( - tensor_lists, - scalars, - PointwiseOpScalarListFunctor< - scalar_t, - /* depth */ 3, - /* r_args_depth */ 3, - /* res_arg_index */ 0>(), - Op()); + DISPATCH_MULTI_TENSOR_APPLY([&]() { + multi_tensor_apply<3, opmath_t>( + tensor_lists, + scalars, + PointwiseOpScalarListFunctor< + scalar_t, + /* depth */ 3, + /* r_args_depth */ 3, + /* res_arg_index */ 0, + large_kernel_arg>(), + Op()); + }); }); increment_version(input); } @@ -149,15 +158,18 @@ std::vector foreach_pointwise_op( "foreach_pointwise_op_cuda", [&]() { using opmath_t = at::opmath_type; - multi_tensor_apply<4, opmath_t>( - tensor_lists, - scalars, - PointwiseOpScalarListFunctor< - scalar_t, - /* depth */ 4, - /* r_args_depth */ 3, - /* res_arg_index */ 3>(), - Op()); + DISPATCH_MULTI_TENSOR_APPLY([&]() { + multi_tensor_apply<4, opmath_t>( + tensor_lists, + scalars, + PointwiseOpScalarListFunctor< + scalar_t, + /* depth */ 4, + /* r_args_depth */ 3, + /* res_arg_index */ 3, + large_kernel_arg>(), + Op()); + }); }); return tensor_lists[3]; diff --git a/aten/src/ATen/native/cuda/ForeachReduceOp.cu b/aten/src/ATen/native/cuda/ForeachReduceOp.cu index 61793fd5f9e08c..1ce18b71aed3e4 100644 --- a/aten/src/ATen/native/cuda/ForeachReduceOp.cu +++ b/aten/src/ATen/native/cuda/ForeachReduceOp.cu @@ -50,11 +50,13 @@ template < typename T, int depth = 1, int r_args_depth = 1, - int res_arg_index = 0> + int res_arg_index = 0, + bool large_kernel_arg = false> struct LpMaxFunctor { + static constexpr bool use_large_kernel_arg = large_kernel_arg; __device__ __forceinline__ void operator()( int chunk_size, - TensorListMetadata& tl, + TensorListMetadata& tl, T* output_per_tensor_ptr, const int max_chunks_per_tensor) { const auto tensor_loc = tl.block_to_tensor[blockIdx.x]; @@ -178,11 +180,13 @@ std::vector foreach_tensor_max_cuda(TensorList tensors) { tensor_lists[0][0].scalar_type(), "foreach_tensor_max_cuda_scalar_type", [&]() { - multi_tensor_apply<1>( - tensor_lists, - LpMaxFunctor(), - output_per_tensor.mutable_data_ptr(), - max_chunks_per_tensor); + DISPATCH_MULTI_TENSOR_APPLY([&]() { + multi_tensor_apply<1>( + tensor_lists, + LpMaxFunctor(), + output_per_tensor.mutable_data_ptr(), + max_chunks_per_tensor); + }); C10_CUDA_KERNEL_LAUNCH_CHECK(); const at::cuda::OptionalCUDAGuard device_guard( @@ -239,12 +243,14 @@ template < typename out_t, int depth = 1, int r_args_depth = 1, - int res_arg_index = 0> + int res_arg_index = 0, + bool large_kernel_arg = false> struct LpNormFunctor { + static constexpr bool use_large_kernel_arg = large_kernel_arg; using out_opmath_t = typename at::opmath_type; __device__ __forceinline__ void operator()( int chunk_size, - TensorListMetadata& tl, + TensorListMetadata& tl, out_opmath_t* output_per_tensor_ptr, const int max_chunks_per_tensor) { const auto tensor_loc = tl.block_to_tensor[blockIdx.x]; @@ -476,23 +482,50 @@ std::vector foreach_tensor_norm_cuda( output_dtype, "foreach_tensor_norm_cuda_out_dtype", [&]() { using out_opmath_t = typename at::opmath_type; if (p == static_cast(1)) { - multi_tensor_apply<1>( - tensor_lists, - LpNormFunctor(), - output_per_tensor.mutable_data_ptr(), - max_chunks_per_tensor); + DISPATCH_MULTI_TENSOR_APPLY([&]() { + multi_tensor_apply<1>( + tensor_lists, + LpNormFunctor< + scalar_t, + NormType::L1, + out_t, + 1, + 1, + 0, + large_kernel_arg>(), + output_per_tensor.mutable_data_ptr(), + max_chunks_per_tensor); + }); } else if (p == static_cast(2)) { - multi_tensor_apply<1>( - tensor_lists, - LpNormFunctor(), - output_per_tensor.mutable_data_ptr(), - max_chunks_per_tensor); + DISPATCH_MULTI_TENSOR_APPLY([&]() { + multi_tensor_apply<1>( + tensor_lists, + LpNormFunctor< + scalar_t, + NormType::L2, + out_t, + 1, + 1, + 0, + large_kernel_arg>(), + output_per_tensor.mutable_data_ptr(), + max_chunks_per_tensor); + }); } else if (p == std::numeric_limits::infinity()) { - multi_tensor_apply<1>( - tensor_lists, - LpNormFunctor(), - output_per_tensor.mutable_data_ptr(), - max_chunks_per_tensor); + DISPATCH_MULTI_TENSOR_APPLY([&]() { + multi_tensor_apply<1>( + tensor_lists, + LpNormFunctor< + scalar_t, + NormType::LInf, + out_t, + 1, + 1, + 0, + large_kernel_arg>(), + output_per_tensor.mutable_data_ptr(), + max_chunks_per_tensor); + }); } C10_CUDA_KERNEL_LAUNCH_CHECK(); const at::cuda::OptionalCUDAGuard device_guard( diff --git a/aten/src/ATen/native/cuda/ForeachTernaryOp.cu b/aten/src/ATen/native/cuda/ForeachTernaryOp.cu index e13f2015f1d819..e73862d3be6692 100644 --- a/aten/src/ATen/native/cuda/ForeachTernaryOp.cu +++ b/aten/src/ATen/native/cuda/ForeachTernaryOp.cu @@ -46,14 +46,17 @@ std::vector foreach_tensor_lerp_ternary_cuda( "foreach_tensor_lerp_ternary_cuda", [&]() { using opmath_t = typename at::opmath_type; - multi_tensor_apply<4>( - tensor_lists, - TernaryOpListFunctor< - scalar_t, - /* depth */ 4, - /* r_args_depth */ 3, - /* res_arg_index */ 3>(), - LerpFunctor()); + DISPATCH_MULTI_TENSOR_APPLY([&]() { + multi_tensor_apply<4>( + tensor_lists, + TernaryOpListFunctor< + scalar_t, + /* depth */ 4, + /* r_args_depth */ 3, + /* res_arg_index */ 3, + large_kernel_arg>(), + LerpFunctor()); + }); }); return tensor_lists[3]; @@ -77,14 +80,17 @@ void foreach_tensor_lerp_ternary_cuda_( "foreach_tensor_lerp_ternary_cuda_", [&]() { using opmath_t = typename at::opmath_type; - multi_tensor_apply<3>( - tensor_lists, - TernaryOpListFunctor< - scalar_t, - /* depth */ 3, - /* r_args_depth */ 3, - /* res_arg_index */ 0>(), - LerpFunctor()); + DISPATCH_MULTI_TENSOR_APPLY([&]() { + multi_tensor_apply<3>( + tensor_lists, + TernaryOpListFunctor< + scalar_t, + /* depth */ 3, + /* r_args_depth */ 3, + /* res_arg_index */ 0, + large_kernel_arg>(), + LerpFunctor()); + }); }); increment_version(tensors1); } @@ -113,15 +119,18 @@ std::vector foreach_tensor_lerp_list_cuda( "foreach_tensor_lerp_scalar_cuda", [&]() { using opmath_t = typename at::opmath_type; - multi_tensor_apply<3>( - tensor_lists, - TernaryOpScalarFunctor< - scalar_t, - /* depth */ 3, - /* r_args_depth */ 2, - /* res_arg_index */ 2>(), - LerpFunctor(), - weight.to()); + DISPATCH_MULTI_TENSOR_APPLY([&]() { + multi_tensor_apply<3>( + tensor_lists, + TernaryOpScalarFunctor< + scalar_t, + /* depth */ 3, + /* r_args_depth */ 2, + /* res_arg_index */ 2, + large_kernel_arg>(), + LerpFunctor(), + weight.to()); + }); }); return tensor_lists[2]; @@ -145,15 +154,18 @@ void foreach_tensor_lerp_list_cuda_( "foreach_tensor_lerp_scalar_cuda_", [&]() { using opmath_t = typename at::opmath_type; - multi_tensor_apply<2>( - tensor_lists, - TernaryOpScalarFunctor< - scalar_t, - /* depth */ 2, - /* r_args_depth */ 2, - /* res_arg_index */ 0>(), - LerpFunctor(), - weight.to()); + DISPATCH_MULTI_TENSOR_APPLY([&]() { + multi_tensor_apply<2>( + tensor_lists, + TernaryOpScalarFunctor< + scalar_t, + /* depth */ 2, + /* r_args_depth */ 2, + /* res_arg_index */ 0, + large_kernel_arg>(), + LerpFunctor(), + weight.to()); + }); }); } } // namespace at::native diff --git a/aten/src/ATen/native/cuda/ForeachUnaryOp.cu b/aten/src/ATen/native/cuda/ForeachUnaryOp.cu index 1a969cfdbdcc43..fa333b4327a943 100644 --- a/aten/src/ATen/native/cuda/ForeachUnaryOp.cu +++ b/aten/src/ATen/native/cuda/ForeachUnaryOp.cu @@ -56,14 +56,17 @@ std::vector foreach_unary_op(TensorList tensors) { tensor_lists.emplace_back(std::move(vec_res)); using opmath_t = typename at::opmath_type; - multi_tensor_apply<2>( - tensor_lists, - UnaryOpFunctor< - scalar_t, - /* depth */ 2, - /* r_args_depth */ 1, - /* res_arg_index */ 1>(), - Op()); + DISPATCH_MULTI_TENSOR_APPLY([&]() { + multi_tensor_apply<2>( + tensor_lists, + UnaryOpFunctor< + scalar_t, + /* depth */ 2, + /* r_args_depth */ 1, + /* res_arg_index */ 1, + large_kernel_arg>(), + Op()); + }); return tensor_lists[1]; } @@ -73,14 +76,17 @@ void foreach_unary_op_(TensorList tensors) { std::vector> tensor_lists; tensor_lists.emplace_back(tensors.vec()); using opmath_t = typename at::opmath_type; - multi_tensor_apply<1>( - tensor_lists, - UnaryOpFunctor< - scalar_t, - /* depth */ 1, - /* r_args_depth */ 1, - /* res_arg_index */ 0>(), - Op()); + DISPATCH_MULTI_TENSOR_APPLY([&]() { + multi_tensor_apply<1>( + tensor_lists, + UnaryOpFunctor< + scalar_t, + /* depth */ 1, + /* r_args_depth */ 1, + /* res_arg_index */ 0, + large_kernel_arg>(), + Op()); + }); increment_version(tensors); } @@ -395,13 +401,16 @@ void foreach_tensor_zero_cuda_(TensorList tensors) { tensors[0].scalar_type(), "foreach_zero_cuda_", [&]() { - multi_tensor_apply<1>( - tensor_lists, - ZeroFunctor< - scalar_t, - /* depth */ 1, - /* r_args_depth */ 1, - /* res_arg_index */ 0>()); + DISPATCH_MULTI_TENSOR_APPLY([&]() { + multi_tensor_apply<1>( + tensor_lists, + ZeroFunctor< + scalar_t, + /* depth */ 1, + /* r_args_depth */ 1, + /* res_arg_index */ 0, + large_kernel_arg>()); + }); }); } diff --git a/aten/src/ATen/native/cuda/FusedSgdKernel.cu b/aten/src/ATen/native/cuda/FusedSgdKernel.cu index beca6f75563b7b..b8ab3af9bd7d9f 100644 --- a/aten/src/ATen/native/cuda/FusedSgdKernel.cu +++ b/aten/src/ATen/native/cuda/FusedSgdKernel.cu @@ -56,14 +56,15 @@ C10_DEVICE __forceinline__ void sgd_math( } } -template +template struct FusedSgdMathFunctor { + static constexpr bool use_large_kernel_arg = large_kernel_arg; static_assert( depth == 2 || depth == 3, "depth of 2 for SGD w/ momentum == 0, 3 for SGD w/ momentum != 0"); C10_DEVICE __forceinline__ void operator()( const int chunk_size, - TensorListMetadata& tl, + TensorListMetadata& tl, const double weight_decay, const double momentum, const float* lr_ptr, @@ -172,19 +173,21 @@ void _fused_sgd_with_momentum_kernel_cuda_( params[0].scalar_type(), "fused_sgd_with_momentum_kernel_cuda", [&]() { - multi_tensor_apply<3>( - tensor_lists, - FusedSgdMathFunctor(), - weight_decay, - momentum, - lr_ptr, - lr, - dampening, - nesterov, - maximize, - is_first_step, - grad_scale_ptr, - found_inf_ptr); + DISPATCH_MULTI_TENSOR_APPLY([&]() { + multi_tensor_apply<3>( + tensor_lists, + FusedSgdMathFunctor(), + weight_decay, + momentum, + lr_ptr, + lr, + dampening, + nesterov, + maximize, + is_first_step, + grad_scale_ptr, + found_inf_ptr); + }); }); } @@ -246,19 +249,21 @@ void _fused_sgd_with_momentum_kernel_cuda_( params[0].scalar_type(), "fused_sgd_with_momentum_kernel_cuda", [&]() { - multi_tensor_apply<3>( - tensor_lists, - FusedSgdMathFunctor(), - weight_decay, - momentum, - lr.data_ptr(), - 1.0, - dampening, - nesterov, - maximize, - is_first_step, - grad_scale_ptr, - found_inf_ptr); + DISPATCH_MULTI_TENSOR_APPLY([&]() { + multi_tensor_apply<3>( + tensor_lists, + FusedSgdMathFunctor(), + weight_decay, + momentum, + lr.data_ptr(), + 1.0, + dampening, + nesterov, + maximize, + is_first_step, + grad_scale_ptr, + found_inf_ptr); + }); }); } @@ -312,19 +317,21 @@ void _fused_sgd_kernel_cuda_( params[0].scalar_type(), "fused_sgd_kernel_cuda", [&]() { - multi_tensor_apply<2>( - tensor_lists, - FusedSgdMathFunctor(), - weight_decay, - momentum, - lr_ptr, - lr, - dampening, - nesterov, - maximize, - /* is_first_step */ false, - grad_scale_ptr, - found_inf_ptr); + DISPATCH_MULTI_TENSOR_APPLY([&]() { + multi_tensor_apply<2>( + tensor_lists, + FusedSgdMathFunctor(), + weight_decay, + momentum, + lr_ptr, + lr, + dampening, + nesterov, + maximize, + /* is_first_step */ false, + grad_scale_ptr, + found_inf_ptr); + }); }); } @@ -404,19 +411,21 @@ void _fused_sgd_kernel_cuda_( params[0].scalar_type(), "fused_sgd_kernel_cuda", [&]() { - multi_tensor_apply<2>( - tensor_lists, - FusedSgdMathFunctor(), - weight_decay, - momentum, - lr.data_ptr(), - 1.0, - dampening, - nesterov, - maximize, - /* is_first_step */ false, - grad_scale_ptr, - found_inf_ptr); + DISPATCH_MULTI_TENSOR_APPLY([&]() { + multi_tensor_apply<2>( + tensor_lists, + FusedSgdMathFunctor(), + weight_decay, + momentum, + lr.data_ptr(), + 1.0, + dampening, + nesterov, + maximize, + /* is_first_step */ false, + grad_scale_ptr, + found_inf_ptr); + }); }); } diff --git a/aten/src/ATen/native/cuda/MultiTensorApply.cpp b/aten/src/ATen/native/cuda/MultiTensorApply.cpp new file mode 100644 index 00000000000000..208adb0d63fcfa --- /dev/null +++ b/aten/src/ATen/native/cuda/MultiTensorApply.cpp @@ -0,0 +1,25 @@ +#include +#include + +#include + +namespace at::native { + +bool supports_large_kernel_arg() { +#if !defined(USE_ROCM) && !defined(_WIN32) && defined(CUDART_VERSION) && CUDART_VERSION >= 12010 + static std::optional supports_large_kernel_arg_ = std::nullopt; + if (!supports_large_kernel_arg_.has_value()) { + int driver_ver = 0; + AT_CUDA_CHECK(cudaDriverGetVersion(&driver_ver)); + cudaDeviceProp* prop = at::cuda::getCurrentDeviceProperties(); + supports_large_kernel_arg_ = (driver_ver >= 12010) && prop->major >= 7; + } + const bool is_capturing = at::cuda::currentStreamCaptureStatusMayInitCtx() != + at::cuda::CaptureStatus::None; + return !is_capturing && *supports_large_kernel_arg_; +#else + return false; +#endif +} + +} // namespace at::native diff --git a/aten/src/ATen/native/cuda/MultiTensorApply.cuh b/aten/src/ATen/native/cuda/MultiTensorApply.cuh index 17f14444abd14a..9b9c3ee7cc93f0 100644 --- a/aten/src/ATen/native/cuda/MultiTensorApply.cuh +++ b/aten/src/ATen/native/cuda/MultiTensorApply.cuh @@ -8,20 +8,105 @@ namespace at::native { +// NOTE: [32KB kernel argument size support] +// 32KB kernel argument size support has three requirements: +// - CUDART_VERSION >= 12010 +// - Driver version >= 530 +// - GPU arch >= VOLTA +// +// Due to minor version compatibility, it possible for binaries built with +// CUDART_VERSION >= 12010 to run with driver version < 530. Since driver +// version can only be checked at runtime, if CUDART_VERSION >= 12010, we have +// to build both 4KB and 32KB kernels and determine the appropriate kernel to +// dispatch at runtime. +// +// - If CUDART_VERSION < 12010, only 4KB kernels will be instantiated. +// +// - If CUDART_VERSION >= 12010: +// - Host code: +// - We always instantiate the launching stub for both 4KB and 32KB kernels. +// - Device code: +// - If __CUDA_ARCH__ >= 700, we always instantiate both 4KB and 32KB +// kernels. +// - If __CUDA_ARCH__ < 700, it's not possible to even compile an empty +// 32KB kernel (formal parameter space overflowed). Thus, we only +// instantiate a declaration for 32KB kernels. This is valid as long as the +// declaration-only kernel is not launched. +// +// - At runtime, we dispatch to the 32KB kernel if driver version >= 530 and +// GPU arch >= VOLTA. +// +// - TODO(yifu): once there's a CUDART version that is not compatible with any +// driver version below 530, we can determine at compile time to not compile +// the kernels for 4KB kernel argument size. +// +// https://developer.nvidia.com/blog/cuda-12-1-supports-large-kernel-parameters/ +bool supports_large_kernel_arg(); + namespace { static constexpr int64_t kILP = 4; static constexpr int64_t kChunkSize = 65536; static constexpr int64_t kBlockSize = 512; -// TODO(crcrpar): Add `n>5` for `low prec params & their higher prec copy` -// TensorListMetadata has to be < 4KB - the limit for kernel launch argument -static constexpr int depth_to_max_tensors[5] = {110, 64, 48, 36, 30}; -static constexpr int depth_to_max_blocks[5] = {320, 320, 320, 320, 320}; -static constexpr int depth_to_max_tensors_scalarlist[5] = {96, 64, 48, 36, 30}; -static constexpr int depth_to_max_tensors_scalarlist_of_complex_double[2] = { - 72, - 60}; +// MSVC has a problem with constexpr and can't handle passing them to templates +// as arguments. We need to replace it with const static. +// https://github.com/state-spaces/mamba/issues/12#issuecomment-1848835662 +#if !defined(_WIN32) +#define SWITCH_TYPE constexpr bool +#else +#define SWITCH_TYPE const static bool +#endif + +#if !defined(USE_ROCM) && !defined(_WIN32) && defined(CUDART_VERSION) && \ + CUDART_VERSION >= 12010 +#define DISPATCH_MULTI_TENSOR_APPLY(...) \ + if (at::native::supports_large_kernel_arg()) { \ + SWITCH_TYPE large_kernel_arg C10_UNUSED = true; \ + __VA_ARGS__(); \ + } else { \ + SWITCH_TYPE large_kernel_arg C10_UNUSED = false; \ + __VA_ARGS__(); \ + } +#else +#define DISPATCH_MULTI_TENSOR_APPLY(...) \ + do { \ + SWITCH_TYPE large_kernel_arg C10_UNUSED = false; \ + __VA_ARGS__(); \ + } while (0); +#endif + +template +struct DepthToMaxConfig; + +template <> +struct DepthToMaxConfig { + static constexpr int depth_to_max_tensors[5] = {110, 64, 48, 36, 30}; + static constexpr int depth_to_max_blocks[5] = {320, 320, 320, 320, 320}; + static constexpr int depth_to_max_tensors_scalarlist[5] = + {96, 64, 48, 36, 30}; + static constexpr int depth_to_max_tensors_scalarlist_of_complex_double[2] = { + 72, + 60}; + using TensorIdxType = unsigned char; +}; + +template <> +struct DepthToMaxConfig { + // TODO(yifu): These values are not yet optimally tuned. I simply multiplied + // the values tuned for 4KB kernel argument size limit by 7 (the kernel + // argument size limit increased by 8x but we need to change the type of + // block_to_tensor from unsigned char to uint16_t to support larger number of + // tensors). + static constexpr int depth_to_max_tensors[5] = {770, 448, 336, 252, 210}; + static constexpr int depth_to_max_blocks[5] = {2240, 2240, 2240, 2240, 2240}; + static constexpr int depth_to_max_tensors_scalarlist[5] = + {672, 448, 336, 252, 210}; + static constexpr int depth_to_max_tensors_scalarlist_of_complex_double[2] = { + 504, + 420}; + using TensorIdxType = uint16_t; +}; template __device__ __forceinline__ bool is_aligned(T* p) { @@ -38,73 +123,101 @@ __device__ __forceinline__ void load_store( ((LT*)dst)[dst_offset] = ((LT*)src)[src_offset]; } -template +template struct TensorListMetadata { - const void* addresses[n][depth_to_max_tensors[n - 1]]; - int64_t numel_for_tensor[depth_to_max_tensors[n - 1]]; - unsigned char block_to_tensor[depth_to_max_blocks[n - 1]]; - int block_to_chunk[depth_to_max_blocks[n - 1]]; + using Conf = DepthToMaxConfig; + const void* addresses[n][Conf::depth_to_max_tensors[n - 1]]; + int64_t numel_for_tensor[Conf::depth_to_max_tensors[n - 1]]; + typename Conf::TensorIdxType + block_to_tensor[Conf::depth_to_max_blocks[n - 1]]; + int block_to_chunk[Conf::depth_to_max_blocks[n - 1]]; int start_tensor_this_launch; }; -template +template struct TensorListScalarListMetadata { - const void* addresses[n][depth_to_max_tensors_scalarlist[n - 1]]; - int64_t numel_for_tensor[depth_to_max_tensors_scalarlist[n - 1]]; - scalar_vals_t scalar_vals[depth_to_max_tensors_scalarlist[n - 1]]; - unsigned char block_to_tensor[depth_to_max_blocks[n - 1]]; - int block_to_chunk[depth_to_max_blocks[n - 1]]; + using Conf = DepthToMaxConfig; + const void* addresses[n][Conf::depth_to_max_tensors_scalarlist[n - 1]]; + int64_t numel_for_tensor[Conf::depth_to_max_tensors_scalarlist[n - 1]]; + scalar_vals_t scalar_vals[Conf::depth_to_max_tensors_scalarlist[n - 1]]; + typename Conf::TensorIdxType + block_to_tensor[Conf::depth_to_max_blocks[n - 1]]; + int block_to_chunk[Conf::depth_to_max_blocks[n - 1]]; }; // note(mkozuki): `n` of 1&2 violate the limit of cuda kernel argument size of // 4kb with `c10::complex` -template <> -struct TensorListScalarListMetadata, 1> { - const void* addresses[1] - [depth_to_max_tensors_scalarlist_of_complex_double[0]]; - int64_t - numel_for_tensor[depth_to_max_tensors_scalarlist_of_complex_double[0]]; +template +struct TensorListScalarListMetadata, 1, large_kernel_arg> { + using Conf = DepthToMaxConfig; + const void* + addresses[1][Conf::depth_to_max_tensors_scalarlist_of_complex_double[0]]; + int64_t numel_for_tensor + [Conf::depth_to_max_tensors_scalarlist_of_complex_double[0]]; c10::complex - scalar_vals[depth_to_max_tensors_scalarlist_of_complex_double[0]]; - unsigned char block_to_tensor[depth_to_max_blocks[1 - 1]]; - int block_to_chunk[depth_to_max_blocks[1 - 1]]; + scalar_vals[Conf::depth_to_max_tensors_scalarlist_of_complex_double[0]]; + typename Conf::TensorIdxType + block_to_tensor[Conf::depth_to_max_blocks[1 - 1]]; + int block_to_chunk[Conf::depth_to_max_blocks[1 - 1]]; }; -template <> -struct TensorListScalarListMetadata, 2> { - const void* addresses[2] - [depth_to_max_tensors_scalarlist_of_complex_double[1]]; - int64_t - numel_for_tensor[depth_to_max_tensors_scalarlist_of_complex_double[1]]; +template +struct TensorListScalarListMetadata, 2, large_kernel_arg> { + using Conf = DepthToMaxConfig; + const void* + addresses[2][Conf::depth_to_max_tensors_scalarlist_of_complex_double[1]]; + int64_t numel_for_tensor + [Conf::depth_to_max_tensors_scalarlist_of_complex_double[1]]; c10::complex - scalar_vals[depth_to_max_tensors_scalarlist_of_complex_double[1]]; - unsigned char block_to_tensor[depth_to_max_blocks[2 - 1]]; - int block_to_chunk[depth_to_max_blocks[2 - 1]]; + scalar_vals[Conf::depth_to_max_tensors_scalarlist_of_complex_double[1]]; + typename Conf::TensorIdxType + block_to_tensor[Conf::depth_to_max_blocks[2 - 1]]; + int block_to_chunk[Conf::depth_to_max_blocks[2 - 1]]; }; // NOTE(crcrpar): This is a conservative resolution to handle `state_steps` // whose each element is `at::Tensor` of 1 element representing the number of // `step`s called so far. -template +template struct FusedOptimizerTensorListMetadata { - const void* addresses[n][depth_to_max_tensors[n - 1]]; - int64_t numel_for_tensor[depth_to_max_tensors[n - 1]]; - const void* state_steps_addresses[depth_to_max_tensors_scalarlist[n - 1]]; - unsigned char block_to_tensor[depth_to_max_blocks[n - 1]]; - int block_to_chunk[depth_to_max_blocks[n - 1]]; + using Conf = DepthToMaxConfig; + const void* addresses[n][Conf::depth_to_max_tensors[n - 1]]; + int64_t numel_for_tensor[Conf::depth_to_max_tensors[n - 1]]; + const void* + state_steps_addresses[Conf::depth_to_max_tensors_scalarlist[n - 1]]; + typename Conf::TensorIdxType + block_to_tensor[Conf::depth_to_max_blocks[n - 1]]; + int block_to_chunk[Conf::depth_to_max_blocks[n - 1]]; int start_tensor_this_launch; }; +#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 700 template C10_LAUNCH_BOUNDS_1(kBlockSize) -__global__ void multi_tensor_apply_kernel( - T tensorListMeta, - U callable, - ArgTypes... args) { +__global__ typename std::enable_if::type + multi_tensor_apply_kernel(T tensorListMeta, U callable, ArgTypes... args) { // Hand the chunk information to the user-supplied functor to process however // it likes. callable(kChunkSize, tensorListMeta, args...); } +#else +// When compiling device code with __CUDA_ARCH__ < 700, we only instantiate a +// declaration for the 32KB kernels. +// For details see: [32KB kernel argument size support] +#pragma nv_diag_suppress 114 // Function was referenced but not defined +template +C10_LAUNCH_BOUNDS_1(kBlockSize) +__global__ typename std::enable_if::type + multi_tensor_apply_kernel(T tensorListMeta, U callable, ArgTypes... args); +#pragma nv_diag_default 114 // Function was referenced but not defined +#endif + +template +C10_LAUNCH_BOUNDS_1(kBlockSize) +__global__ typename std::enable_if::type + multi_tensor_apply_kernel(T tensorListMeta, U callable, ArgTypes... args) { + callable(kChunkSize, tensorListMeta, args...); +} } // namespace @@ -133,7 +246,10 @@ void multi_tensor_apply( "Number of tensor lists has to match the depth."); const size_t n_tensors = tensor_lists[0].size(); using scalar_vals_t = typename T::opmath_t; - TensorListScalarListMetadata tensorListMeta; + TensorListScalarListMetadata + tensorListMeta; + + using Conf = DepthToMaxConfig; int loc_block_info = 0; int loc_tensor_info = 0; @@ -167,10 +283,11 @@ void multi_tensor_apply( // a tensor is not considered full unless all its chunks have been // processed const bool tensors_full = - (loc_tensor_info == depth_to_max_tensors_scalarlist[depth - 1] && + (loc_tensor_info == + Conf::depth_to_max_tensors_scalarlist[depth - 1] && chunk == chunks - 1); const bool blocks_full = - (loc_block_info == depth_to_max_blocks[depth - 1]); + (loc_block_info == Conf::depth_to_max_blocks[depth - 1]); if (tensors_full || blocks_full) { multi_tensor_apply_kernel<<< @@ -223,9 +340,11 @@ void multi_tensor_apply( tensor_lists.size() == depth, "Number of tensor lists has to match the depth."); const size_t n_tensors = tensor_lists[0].size(); - TensorListMetadata tensorListMeta; + TensorListMetadata tensorListMeta; tensorListMeta.start_tensor_this_launch = 0; + using Conf = DepthToMaxConfig; + int loc_block_info = 0; int loc_tensor_info = 0; for (size_t t = 0; t < n_tensors; t++) { @@ -250,10 +369,10 @@ void multi_tensor_apply( loc_block_info++; const bool tensors_full = - (loc_tensor_info == depth_to_max_tensors[depth - 1] && + (loc_tensor_info == Conf::depth_to_max_tensors[depth - 1] && chunk == chunks - 1); const bool blocks_full = - (loc_block_info == depth_to_max_blocks[depth - 1]); + (loc_block_info == Conf::depth_to_max_blocks[depth - 1]); if (tensors_full || blocks_full) { multi_tensor_apply_kernel<<< @@ -304,7 +423,10 @@ void multi_tensor_apply_for_fused_optimizer( tensor_lists.size() == depth, "Number of tensor lists has to match the depth"); const auto num_tensors = tensor_lists[0].size(); - FusedOptimizerTensorListMetadata tensorListMeta; + FusedOptimizerTensorListMetadata + tensorListMeta; + + using Conf = DepthToMaxConfig; int loc_block_info = 0; int loc_tensor_info = 0; @@ -333,9 +455,10 @@ void multi_tensor_apply_for_fused_optimizer( loc_block_info++; const auto tensor_full = - (loc_tensor_info == depth_to_max_tensors[depth - 1] && + (loc_tensor_info == Conf::depth_to_max_tensors[depth - 1] && chunk == chunks - 1); - const auto blocks_full = loc_block_info == depth_to_max_blocks[depth - 1]; + const auto blocks_full = + loc_block_info == Conf::depth_to_max_blocks[depth - 1]; if (tensor_full || blocks_full) { multi_tensor_apply_kernel<<< diff --git a/aten/src/ATen/native/cuda/fused_adam_amsgrad_impl.cu b/aten/src/ATen/native/cuda/fused_adam_amsgrad_impl.cu index cef07de1b41f99..0c6318d606b9aa 100644 --- a/aten/src/ATen/native/cuda/fused_adam_amsgrad_impl.cu +++ b/aten/src/ATen/native/cuda/fused_adam_amsgrad_impl.cu @@ -42,19 +42,26 @@ void _fused_adam_amsgrad_cuda_impl_( params[0].scalar_type(), "fused_adam_kernel_cuda", [&]() { - multi_tensor_apply_for_fused_optimizer<5>( - tensor_lists, - state_steps, - FusedAdamMathFunctor(), - lr_ptr, // unused - lr, - beta1, - beta2, - weight_decay, - eps, - maximize, - grad_scale_ptr, - found_inf_ptr); + DISPATCH_MULTI_TENSOR_APPLY([&]() { + multi_tensor_apply_for_fused_optimizer<5>( + tensor_lists, + state_steps, + FusedAdamMathFunctor< + scalar_t, + 5, + ADAM_MODE::ORIGINAL, + true, + large_kernel_arg>(), + lr_ptr, // unused + lr, + beta1, + beta2, + weight_decay, + eps, + maximize, + grad_scale_ptr, + found_inf_ptr); + }); }); } @@ -93,19 +100,26 @@ void _fused_adam_amsgrad_cuda_impl_( params[0].scalar_type(), "fused_adam_kernel_cuda", [&]() { - multi_tensor_apply_for_fused_optimizer<5>( - tensor_lists, - state_steps, - FusedAdamMathFunctor(), - lr_ptr, - 1.0, // unused - beta1, - beta2, - weight_decay, - eps, - maximize, - grad_scale_ptr, - found_inf_ptr); + DISPATCH_MULTI_TENSOR_APPLY([&]() { + multi_tensor_apply_for_fused_optimizer<5>( + tensor_lists, + state_steps, + FusedAdamMathFunctor< + scalar_t, + 5, + ADAM_MODE::ORIGINAL, + true, + large_kernel_arg>(), + lr_ptr, + 1.0, // unused + beta1, + beta2, + weight_decay, + eps, + maximize, + grad_scale_ptr, + found_inf_ptr); + }); }); } diff --git a/aten/src/ATen/native/cuda/fused_adam_impl.cu b/aten/src/ATen/native/cuda/fused_adam_impl.cu index 2c1f5ce0d6d572..43a4cbfa2f791d 100644 --- a/aten/src/ATen/native/cuda/fused_adam_impl.cu +++ b/aten/src/ATen/native/cuda/fused_adam_impl.cu @@ -37,19 +37,26 @@ void _fused_adam_cuda_impl_( params[0].scalar_type(), "fused_adam_kernel_cuda", [&]() { - multi_tensor_apply_for_fused_optimizer<4>( - tensor_lists, - state_steps, - FusedAdamMathFunctor(), - lr_ptr, // unused - lr, - beta1, - beta2, - weight_decay, - eps, - maximize, - grad_scale_ptr, - found_inf_ptr); + DISPATCH_MULTI_TENSOR_APPLY([&]() { + multi_tensor_apply_for_fused_optimizer<4>( + tensor_lists, + state_steps, + FusedAdamMathFunctor< + scalar_t, + 4, + ADAM_MODE::ORIGINAL, + false, + large_kernel_arg>(), + lr_ptr, // unused + lr, + beta1, + beta2, + weight_decay, + eps, + maximize, + grad_scale_ptr, + found_inf_ptr); + }); }); } @@ -83,19 +90,26 @@ void _fused_adam_cuda_impl_( params[0].scalar_type(), "fused_adam_kernel_cuda", [&]() { - multi_tensor_apply_for_fused_optimizer<4>( - tensor_lists, - state_steps, - FusedAdamMathFunctor(), - lr_ptr, - 1.0, // unused - beta1, - beta2, - weight_decay, - eps, - maximize, - grad_scale_ptr, - found_inf_ptr); + DISPATCH_MULTI_TENSOR_APPLY([&]() { + multi_tensor_apply_for_fused_optimizer<4>( + tensor_lists, + state_steps, + FusedAdamMathFunctor< + scalar_t, + 4, + ADAM_MODE::ORIGINAL, + false, + large_kernel_arg>(), + lr_ptr, + 1.0, // unused + beta1, + beta2, + weight_decay, + eps, + maximize, + grad_scale_ptr, + found_inf_ptr); + }); }); } diff --git a/aten/src/ATen/native/cuda/fused_adam_utils.cuh b/aten/src/ATen/native/cuda/fused_adam_utils.cuh index 182195969ed9a3..43627fc32148b3 100644 --- a/aten/src/ATen/native/cuda/fused_adam_utils.cuh +++ b/aten/src/ATen/native/cuda/fused_adam_utils.cuh @@ -102,15 +102,21 @@ C10_DEVICE inline void adam_math( // parameter updates accordingly. To be functionally on par with `torch.optim` // optimizers and `_multi_tensor` ones, the kernel below writes out gradients // only when `grad_scale_ptr != nullptr. -template +template < + typename scalar_type, + int depth, + ADAM_MODE adam_mode, + bool amsgrad, + bool large_kernel_arg> struct FusedAdamMathFunctor { + static constexpr bool use_large_kernel_arg = large_kernel_arg; static_assert( depth == 4 || depth == 5, "depth of 4 for Adam, depth of 5 for Adam with AMSGrad."); using opmath_t = at::opmath_type; C10_DEVICE __forceinline__ void operator()( int chunk_size, - FusedOptimizerTensorListMetadata& tl, + FusedOptimizerTensorListMetadata& tl, const float* lr_ptr, const double& lr, const double& beta1, diff --git a/aten/src/ATen/native/cuda/fused_adamw_amsgrad_impl.cu b/aten/src/ATen/native/cuda/fused_adamw_amsgrad_impl.cu index 8a22b57a47e8b0..3b3b7322b769c4 100644 --- a/aten/src/ATen/native/cuda/fused_adamw_amsgrad_impl.cu +++ b/aten/src/ATen/native/cuda/fused_adamw_amsgrad_impl.cu @@ -43,19 +43,26 @@ void _fused_adamw_amsgrad_cuda_impl_( params[0].scalar_type(), "fused_adamw_kernel_cuda", [&]() { - multi_tensor_apply_for_fused_optimizer<5>( - tensor_lists, - state_steps, - FusedAdamMathFunctor(), - lr_ptr, // unused - lr, - beta1, - beta2, - weight_decay, - eps, - maximize, - grad_scale_ptr, - found_inf_ptr); + DISPATCH_MULTI_TENSOR_APPLY([&]() { + multi_tensor_apply_for_fused_optimizer<5>( + tensor_lists, + state_steps, + FusedAdamMathFunctor< + scalar_t, + 5, + ADAM_MODE::ADAMW, + true, + large_kernel_arg>(), + lr_ptr, // unused + lr, + beta1, + beta2, + weight_decay, + eps, + maximize, + grad_scale_ptr, + found_inf_ptr); + }); }); } @@ -94,19 +101,26 @@ void _fused_adamw_amsgrad_cuda_impl_( params[0].scalar_type(), "fused_adamw_kernel_cuda", [&]() { - multi_tensor_apply_for_fused_optimizer<5>( - tensor_lists, - state_steps, - FusedAdamMathFunctor(), - lr_ptr, - 1.0, // unused - beta1, - beta2, - weight_decay, - eps, - maximize, - grad_scale_ptr, - found_inf_ptr); + DISPATCH_MULTI_TENSOR_APPLY([&]() { + multi_tensor_apply_for_fused_optimizer<5>( + tensor_lists, + state_steps, + FusedAdamMathFunctor< + scalar_t, + 5, + ADAM_MODE::ADAMW, + true, + large_kernel_arg>(), + lr_ptr, + 1.0, // unused + beta1, + beta2, + weight_decay, + eps, + maximize, + grad_scale_ptr, + found_inf_ptr); + }); }); } diff --git a/aten/src/ATen/native/cuda/fused_adamw_impl.cu b/aten/src/ATen/native/cuda/fused_adamw_impl.cu index b0f9dc6db6aff6..ff657687b5f920 100644 --- a/aten/src/ATen/native/cuda/fused_adamw_impl.cu +++ b/aten/src/ATen/native/cuda/fused_adamw_impl.cu @@ -38,19 +38,26 @@ void _fused_adamw_cuda_impl_( params[0].scalar_type(), "fused_adamw_kernel_cuda", [&]() { - multi_tensor_apply_for_fused_optimizer<4>( - tensor_lists, - state_steps, - FusedAdamMathFunctor(), - lr_ptr, // unused - lr, - beta1, - beta2, - weight_decay, - eps, - maximize, - grad_scale_ptr, - found_inf_ptr); + DISPATCH_MULTI_TENSOR_APPLY([&]() { + multi_tensor_apply_for_fused_optimizer<4>( + tensor_lists, + state_steps, + FusedAdamMathFunctor< + scalar_t, + 4, + ADAM_MODE::ADAMW, + false, + large_kernel_arg>(), + lr_ptr, // unused + lr, + beta1, + beta2, + weight_decay, + eps, + maximize, + grad_scale_ptr, + found_inf_ptr); + }); }); } @@ -84,19 +91,26 @@ void _fused_adamw_cuda_impl_( params[0].scalar_type(), "fused_adamw_kernel_cuda", [&]() { - multi_tensor_apply_for_fused_optimizer<4>( - tensor_lists, - state_steps, - FusedAdamMathFunctor(), - lr_ptr, - 1.0, // unused - beta1, - beta2, - weight_decay, - eps, - maximize, - grad_scale_ptr, - found_inf_ptr); + DISPATCH_MULTI_TENSOR_APPLY([&]() { + multi_tensor_apply_for_fused_optimizer<4>( + tensor_lists, + state_steps, + FusedAdamMathFunctor< + scalar_t, + 4, + ADAM_MODE::ADAMW, + false, + large_kernel_arg>(), + lr_ptr, + 1.0, // unused + beta1, + beta2, + weight_decay, + eps, + maximize, + grad_scale_ptr, + found_inf_ptr); + }); }); } diff --git a/build_variables.bzl b/build_variables.bzl index e19bb7afe84576..57c9caefbdbe1c 100644 --- a/build_variables.bzl +++ b/build_variables.bzl @@ -1460,6 +1460,7 @@ aten_cuda_cu_source_list = [ "aten/src/ATen/native/cuda/Equal.cpp", "aten/src/ATen/native/cuda/GridSampler.cpp", "aten/src/ATen/native/cuda/IndexKernel.cpp", + "aten/src/ATen/native/cuda/MultiTensorApply.cpp", "aten/src/ATen/native/cuda/ReduceOps.cpp", "aten/src/ATen/native/cuda/ScanKernels.cpp", "aten/src/ATen/native/cuda/Sort.cpp", From e7ac5dfc0135aa61759e9bda37f7fd6153f29a12 Mon Sep 17 00:00:00 2001 From: Animesh Jain Date: Thu, 29 Aug 2024 20:37:32 -0700 Subject: [PATCH 0287/1018] [dynamo][dynamic][heuristic] Mark tuple getitem integers as static (#134734) Fixes issue seen in https://github.com/pytorch/pytorch/issues/132872#issuecomment-2314574656 Pull Request resolved: https://github.com/pytorch/pytorch/pull/134734 Approved by: https://github.com/jansel ghstack dependencies: #134653, #134713 --- .../aot_eager_torchbench_inference.csv | 2 +- .../aot_eager_torchbench_training.csv | 2 +- ...ctor_amp_freezing_torchbench_inference.csv | 2 +- ...inductor_freezing_torchbench_inference.csv | 2 +- .../cpu_inductor_torchbench_inference.csv | 2 +- ...dynamic_aot_eager_torchbench_inference.csv | 2 +- .../dynamic_aot_eager_torchbench_training.csv | 2 +- ...amic_cpu_inductor_torchbench_inference.csv | 2 +- .../dynamic_inductor_torchbench_inference.csv | 2 +- .../dynamic_inductor_torchbench_training.csv | 2 +- .../dynamo_eager_torchbench_inference.csv | 2 +- .../dynamo_eager_torchbench_training.csv | 2 +- .../inductor_torchbench_inference.csv | 2 +- .../inductor_torchbench_training.csv | 2 +- test/dynamo/test_modules.py | 34 +++++++++++++++++++ torch/_dynamo/variables/builder.py | 16 ++++++++- 16 files changed, 63 insertions(+), 15 deletions(-) diff --git a/benchmarks/dynamo/ci_expected_accuracy/aot_eager_torchbench_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/aot_eager_torchbench_inference.csv index 904ba00fbdd759..c96684bc79462c 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/aot_eager_torchbench_inference.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/aot_eager_torchbench_inference.csv @@ -370,7 +370,7 @@ vgg16,pass,0 -vision_maskrcnn,pass,16 +vision_maskrcnn,pass,18 diff --git a/benchmarks/dynamo/ci_expected_accuracy/aot_eager_torchbench_training.csv b/benchmarks/dynamo/ci_expected_accuracy/aot_eager_torchbench_training.csv index 05d52bedb3e046..223aae9b97234b 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/aot_eager_torchbench_training.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/aot_eager_torchbench_training.csv @@ -278,7 +278,7 @@ vgg16,pass,6 -vision_maskrcnn,pass,33 +vision_maskrcnn,pass,35 diff --git a/benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_amp_freezing_torchbench_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_amp_freezing_torchbench_inference.csv index f72a083f1013b4..c24f8e1b3fbbc3 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_amp_freezing_torchbench_inference.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_amp_freezing_torchbench_inference.csv @@ -318,7 +318,7 @@ vgg16,pass,0 -vision_maskrcnn,fail_accuracy,28 +vision_maskrcnn,fail_accuracy,30 diff --git a/benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_freezing_torchbench_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_freezing_torchbench_inference.csv index 8197f6cb65c695..71f6ccce68c437 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_freezing_torchbench_inference.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_freezing_torchbench_inference.csv @@ -318,7 +318,7 @@ vgg16,pass,0 -vision_maskrcnn,pass,27 +vision_maskrcnn,pass,29 diff --git a/benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_torchbench_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_torchbench_inference.csv index dc589b9c49c10a..8b3868107df51e 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_torchbench_inference.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_torchbench_inference.csv @@ -330,7 +330,7 @@ vgg16,pass,0 -vision_maskrcnn,pass,27 +vision_maskrcnn,pass,29 diff --git a/benchmarks/dynamo/ci_expected_accuracy/dynamic_aot_eager_torchbench_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/dynamic_aot_eager_torchbench_inference.csv index 7a14abcf8e2910..9075a4adfd3a1a 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/dynamic_aot_eager_torchbench_inference.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/dynamic_aot_eager_torchbench_inference.csv @@ -366,7 +366,7 @@ vgg16,pass,0 -vision_maskrcnn,pass,16 +vision_maskrcnn,pass,18 diff --git a/benchmarks/dynamo/ci_expected_accuracy/dynamic_aot_eager_torchbench_training.csv b/benchmarks/dynamo/ci_expected_accuracy/dynamic_aot_eager_torchbench_training.csv index 40011a6000d61e..47894a2b48c372 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/dynamic_aot_eager_torchbench_training.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/dynamic_aot_eager_torchbench_training.csv @@ -274,7 +274,7 @@ vgg16,pass,6 -vision_maskrcnn,pass,33 +vision_maskrcnn,pass,35 diff --git a/benchmarks/dynamo/ci_expected_accuracy/dynamic_cpu_inductor_torchbench_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/dynamic_cpu_inductor_torchbench_inference.csv index cba443a815dd97..449d34d42f68a2 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/dynamic_cpu_inductor_torchbench_inference.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/dynamic_cpu_inductor_torchbench_inference.csv @@ -314,7 +314,7 @@ vgg16,pass,0 -vision_maskrcnn,pass,27 +vision_maskrcnn,pass,29 diff --git a/benchmarks/dynamo/ci_expected_accuracy/dynamic_inductor_torchbench_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/dynamic_inductor_torchbench_inference.csv index 7a14abcf8e2910..9075a4adfd3a1a 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/dynamic_inductor_torchbench_inference.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/dynamic_inductor_torchbench_inference.csv @@ -366,7 +366,7 @@ vgg16,pass,0 -vision_maskrcnn,pass,16 +vision_maskrcnn,pass,18 diff --git a/benchmarks/dynamo/ci_expected_accuracy/dynamic_inductor_torchbench_training.csv b/benchmarks/dynamo/ci_expected_accuracy/dynamic_inductor_torchbench_training.csv index 66b96ee134c9b7..3d3548d3ac5aee 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/dynamic_inductor_torchbench_training.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/dynamic_inductor_torchbench_training.csv @@ -274,7 +274,7 @@ vgg16,pass,6 -vision_maskrcnn,pass,33 +vision_maskrcnn,pass,35 diff --git a/benchmarks/dynamo/ci_expected_accuracy/dynamo_eager_torchbench_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/dynamo_eager_torchbench_inference.csv index 904ba00fbdd759..c96684bc79462c 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/dynamo_eager_torchbench_inference.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/dynamo_eager_torchbench_inference.csv @@ -370,7 +370,7 @@ vgg16,pass,0 -vision_maskrcnn,pass,16 +vision_maskrcnn,pass,18 diff --git a/benchmarks/dynamo/ci_expected_accuracy/dynamo_eager_torchbench_training.csv b/benchmarks/dynamo/ci_expected_accuracy/dynamo_eager_torchbench_training.csv index 05d52bedb3e046..223aae9b97234b 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/dynamo_eager_torchbench_training.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/dynamo_eager_torchbench_training.csv @@ -278,7 +278,7 @@ vgg16,pass,6 -vision_maskrcnn,pass,33 +vision_maskrcnn,pass,35 diff --git a/benchmarks/dynamo/ci_expected_accuracy/inductor_torchbench_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/inductor_torchbench_inference.csv index aea126aa270cd8..f21050a3d3d959 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/inductor_torchbench_inference.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/inductor_torchbench_inference.csv @@ -370,7 +370,7 @@ vgg16,pass,0 -vision_maskrcnn,pass,16 +vision_maskrcnn,pass,18 diff --git a/benchmarks/dynamo/ci_expected_accuracy/inductor_torchbench_training.csv b/benchmarks/dynamo/ci_expected_accuracy/inductor_torchbench_training.csv index 05d52bedb3e046..223aae9b97234b 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/inductor_torchbench_training.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/inductor_torchbench_training.csv @@ -278,7 +278,7 @@ vgg16,pass,6 -vision_maskrcnn,pass,33 +vision_maskrcnn,pass,35 diff --git a/test/dynamo/test_modules.py b/test/dynamo/test_modules.py index 6e63389b0724b5..d390e0bb2d857b 100644 --- a/test/dynamo/test_modules.py +++ b/test/dynamo/test_modules.py @@ -3013,6 +3013,40 @@ def fn(x): with torch._dynamo.config.patch(inline_inbuilt_nn_modules=True): helper() + def test_user_defined_nn_module_dynamic(self): + class Conv2d(torch.nn.Conv2d): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def forward(self, x): + x = torch.nn.functional.conv2d( + x, + self.weight, + self.bias, + self.stride, + self.padding, + self.dilation, + self.groups, + ) + return x + + cnts = torch._dynamo.testing.CompileCounter() + mod1 = Conv2d(64, 64, kernel_size=(2, 2), stride=(1, 1)) + mod2 = Conv2d(64, 64, kernel_size=(2, 2), stride=(2, 2)) + mod3 = Conv2d(64, 64, kernel_size=(2, 2), stride=(3, 3)) + + opt_mod1 = torch.compile(mod1, backend=cnts, fullgraph=True) + opt_mod2 = torch.compile(mod2, backend=cnts, fullgraph=True) + opt_mod3 = torch.compile(mod3, backend=cnts, fullgraph=True) + + x = torch.randn(1, 64, 64, 64) + opt_mod1(x) + opt_mod2(x) + opt_mod3(x) + + # Must be 3 compilations. If not marked static there would be 2, because strides would be converted to symints. + self.assertEqual(cnts.frame_count, 3) + if __name__ == "__main__": from torch._dynamo.test_case import run_tests diff --git a/torch/_dynamo/variables/builder.py b/torch/_dynamo/variables/builder.py index ea71db1e5c0992..be45ab0a692ee8 100644 --- a/torch/_dynamo/variables/builder.py +++ b/torch/_dynamo/variables/builder.py @@ -1157,8 +1157,22 @@ def wrap_listlike(self, value: Union[tuple, list, odict_values, NamedTuple]): if item is value: unimplemented("list elements are pointing to the list itself") + # Tuples are immutable objects, so we should mark its items static. This + # avoids wrapping of tuple items as symints. This helps for nn module + # attributes like conv2d strides, dilations. + if ( + istype(value, tuple) + and all(ConstantVariable.is_literal(item) for item in value) + and self.source.guard_source().is_unspecialized_nn_module() + ): + self.install_guards(GuardBuilder.CONSTANT_MATCH) + return TupleVariable([ConstantVariable.create(item) for item in value]) + output = [ - LazyVariableTracker.create(item, source=GetItemSource(self.get_source(), i)) + LazyVariableTracker.create( + item, + source=GetItemSource(self.get_source(), i), + ) for i, item in enumerate(value) ] From 3d3f9fa3fe5179ae66b76008006c99a2eb664d4e Mon Sep 17 00:00:00 2001 From: Masaki Kozuki Date: Fri, 30 Aug 2024 17:30:47 +0000 Subject: [PATCH 0288/1018] Update `ForeachfuncInfo.sample_inputs_func` to yield scalars & scalarlists that are more friendly to test_meta (#134552) for `test_meta.py` to see more "PASSED" instead of "XFAIL". `pytest test_meta.py -k "_foreach_"` ran 6400 test cases and: - This PR: 4702 passed, 260 skipped, 73732 deselected, 1698 xfailed - main (92c4771853892193d73d87bd60eca4dc7efc51d8): 3906 passed, 260 skipped, 73732 deselected, 2494 xfailed Pull Request resolved: https://github.com/pytorch/pytorch/pull/134552 Approved by: https://github.com/janeyx99 --- test/test_foreach.py | 88 +++++- .../_internal/common_methods_invocations.py | 279 ++++++++++-------- 2 files changed, 222 insertions(+), 145 deletions(-) diff --git a/test/test_foreach.py b/test/test_foreach.py index 6f664392289834..11c6273a1471fd 100644 --- a/test/test_foreach.py +++ b/test/test_foreach.py @@ -206,7 +206,9 @@ def test_parity(self, device, dtype, op, noncontiguous, inplace): _, _, func, ref = self._get_funcs(op) else: func, ref, _, _ = self._get_funcs(op) - for sample in op.sample_inputs(device, dtype, noncontiguous=noncontiguous): + for sample in op.sample_inputs( + device, dtype, noncontiguous=noncontiguous, allow_higher_dtype_scalars=True + ): ref_kwargs = sample.kwargs # div promotes ints to floats, so we cannot go on the fastpath there div_slowpath = ( @@ -305,7 +307,12 @@ def clone(arg): scalar_self_arg_test_complete = False for i, sample in enumerate( - op.sample_inputs(device, dtype, noncontiguous=not is_fastpath) + op.sample_inputs( + device, + dtype, + noncontiguous=not is_fastpath, + allow_higher_dtype_scalars=True, + ) ): (rhs_arg,) = sample.args kwargs = {} or sample.kwargs @@ -347,7 +354,12 @@ def clone(arg): def test_pointwise_op_with_tensor_of_scalarlist_overload( self, device, dtype, op, is_fastpath ): - for sample in op.sample_inputs(device, dtype, noncontiguous=not is_fastpath): + for sample in op.sample_inputs( + device, + dtype, + noncontiguous=not is_fastpath, + allow_higher_dtype_scalars=True, + ): assert isinstance(sample.args, tuple) assert len(sample.args) == 2 inputs = [sample.input, *sample.args] @@ -816,7 +828,14 @@ def test_unary_op_tensors_on_different_devices(self, device, dtype, op): method, ref, inplace_method, ref_inplace = self._get_funcs(op) # tensors: ['cuda', 'cpu] tensors = next( - iter(op.sample_inputs(device, dtype, num_input_tensors=[2])) + iter( + op.sample_inputs( + device, + dtype, + num_input_tensors=[2], + allow_higher_dtype_scalars=True, + ) + ) ).input tensors[1] = tensors[1].to("cpu") if not op.supports_out: @@ -844,10 +863,26 @@ def test_unary_op_tensors_on_different_devices(self, device, dtype, op): @ops(filter(lambda op: op.supports_out, foreach_binary_op_db)) def test_binary_op_tensors_on_different_devices(self, device, dtype, op): _cuda_tensors = next( - iter(op.sample_inputs(device, dtype, num_input_tensors=[2], same_size=True)) + iter( + op.sample_inputs( + device, + dtype, + num_input_tensors=[2], + same_size=True, + allow_higher_dtype_scalars=True, + ) + ) ).input _cpu_tensors = next( - iter(op.sample_inputs("cpu", dtype, num_input_tensors=[2], same_size=True)) + iter( + op.sample_inputs( + "cpu", + dtype, + num_input_tensors=[2], + same_size=True, + allow_higher_dtype_scalars=True, + ) + ) ).input tensors1, tensors2 = list(zip(_cuda_tensors, _cpu_tensors)) @@ -877,10 +912,24 @@ def test_pointwise_op_tensors_on_different_devices(self, device, dtype, op): # tensors3: ['cuda', 'cpu] # first tensorlist is zero-size when float32 _cuda_tensors = list( - op.sample_inputs(device, dtype, num_input_tensors=[3], same_size=True) + op.sample_inputs( + device, + dtype, + num_input_tensors=[3], + same_size=True, + allow_higher_dtype_scalars=True, + ) )[int(dtype == torch.float32)].input _cpu_tensors = next( - iter(op.sample_inputs("cpu", dtype, num_input_tensors=[3], same_size=True)) + iter( + op.sample_inputs( + "cpu", + dtype, + num_input_tensors=[3], + same_size=True, + allow_higher_dtype_scalars=True, + ) + ) ).input tensors1, tensors2, tensors3 = list(zip(_cuda_tensors, _cpu_tensors)) @@ -1220,7 +1269,9 @@ def test_foreach_copy_with_multi_device_inputs(self, device, dtype, op): foreach_copy_ = op.inplace_variant copy_ = op.ref_inplace for non_blocking in (False, True): - for sample in op.sample_inputs(device, dtype, noncontiguous=False): + for sample in op.sample_inputs( + device, dtype, noncontiguous=False, allow_higher_dtype_scalars=True + ): with torch.no_grad(): ref_input = [t.clone().detach() for t in sample.input] foreach_copy_(sample.input, sample.args[0], non_blocking) @@ -1240,7 +1291,9 @@ def test_foreach_copy_with_multi_device_inputs(self, device, dtype, op): def test_foreach_copy_with_multi_dtypes(self, device, dtype, op): # check (a) multi_tensor_apply is called and (b) numerical parity with for-loop and Tensor.copy_ foreach_copy_ = ForeachFuncWrapper(op.inplace_variant) - for sample in op.sample_inputs(device, dtype, noncontiguous=False): + for sample in op.sample_inputs( + device, dtype, noncontiguous=False, allow_higher_dtype_scalars=True + ): for src_dtype in floating_types_and(torch.half, torch.bfloat16): if src_dtype == dtype: continue @@ -1307,6 +1360,7 @@ def test_autodiff(self, device, dtype, op, inplace): dtype, requires_grad=True, num_input_tensors=[5], + allow_higher_dtype_scalars=True, **value_range, ): # Skip `_foreach_pow.ScalarAndTensor(Scalar, Tensor[])` @@ -1420,10 +1474,10 @@ def check_autodiff_sample(op, sample, dtype, is_inplace): return False, "In-place abs is not supported for complex tensors." if op.name == "_foreach_sub" and ( ( - isinstance(sample.args[0], list) - and any(isinstance(a, bool) for a in sample.args[0]) + isinstance(sample.args[-1], list) + and any(isinstance(a, bool) for a in sample.args[-1]) ) - or isinstance(sample.args[0], bool) + or isinstance(sample.args[-1], bool) ): return False, _BOOL_SUB_ERR_MSG if op.name == "_foreach_norm" and (not is_inplace): @@ -1434,10 +1488,10 @@ def check_autodiff_sample(op, sample, dtype, is_inplace): ) rhs_arg_has_complex_number = sample.args and ( ( - isinstance(sample.args[0], list) - and any(isinstance(a, complex) for a in sample.args[0]) + isinstance(sample.args[-1], list) + and any(isinstance(a, complex) for a in sample.args[-1]) ) - or (isinstance(sample.args[0], complex)) + or (isinstance(sample.args[-1], complex)) ) if rhs_arg_has_complex_number and dtype == torch.float64: if op.name in ( @@ -1447,6 +1501,8 @@ def check_autodiff_sample(op, sample, dtype, is_inplace): "_foreach_minimum", ): return False, "clamp is not supported for complex types" + if op.name == "_foreach_lerp" and is_inplace: + return False, "value cannot be converted to type double without overflow" if not is_inplace: return False, "" else: diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py index 3049cd8d8f5753..a64cb62cedeb32 100644 --- a/torch/testing/_internal/common_methods_invocations.py +++ b/torch/testing/_internal/common_methods_invocations.py @@ -9337,7 +9337,16 @@ def _set_rightmost_arg_types( if rightmost_supports_tensor: self._rightmost_arg_types.append(ForeachRightmostArgType.Tensor) - def _sample_rightmost_arg(self, opinfo, rightmost_arg_type, device, dtype, num_tensors, **_foreach_inputs_kwargs): + def _sample_rightmost_arg( + self, + opinfo, + rightmost_arg_type, + device, + dtype, + num_tensors, + allow_higher_dtype_scalars, + **_foreach_inputs_kwargs, + ): if rightmost_arg_type == ForeachRightmostArgType.TensorList: return [sample_inputs_foreach(None, device, dtype, num_tensors, **_foreach_inputs_kwargs)] if rightmost_arg_type == ForeachRightmostArgType.Tensor: @@ -9357,21 +9366,25 @@ def sample_float(): high = 2 if should_use_simpler_scalars else 9 if rightmost_arg_type == ForeachRightmostArgType.ScalarList: - return [ - [random.randint(0, high) + 1 for _ in range(num_tensors)], - [sample_float() for _ in range(num_tensors)], - [complex(sample_float(), sample_float()) for _ in range(num_tensors)], - [True for _ in range(num_tensors)], - [1, 2.0, 3.0 + 4.5j] + [3.0 for _ in range(num_tensors - 3)], - [True, 1, 2.0, 3.0 + 4.5j] + [3.0 for _ in range(num_tensors - 4)], - ] + scalarlist_list = [] + scalarlist_list.append([random.randint(0, high) + 1 for _ in range(num_tensors)]) + + if allow_higher_dtype_scalars or dtype.is_floating_point: + scalarlist_list.append([sample_float() for _ in range(num_tensors)]) + if allow_higher_dtype_scalars or dtype.is_complex: + scalarlist_list.append([complex(sample_float(), sample_float()) for _ in range(num_tensors)]) + scalarlist_list.append([1, 2.0, 3.0 + 4.5j] + [3.0 for _ in range(num_tensors - 3)]) + scalarlist_list.append([True, 1, 2.0, 3.0 + 4.5j] + [3.0 for _ in range(num_tensors - 4)]) + return scalarlist_list if rightmost_arg_type == ForeachRightmostArgType.Scalar: - return ( - random.randint(1, high + 1), - sample_float(), - True, - complex(sample_float(), sample_float()), - ) + scalars = [] + scalars.append(random.randint(1, high + 1)) + if allow_higher_dtype_scalars or dtype.is_floating_point: + scalars.append(sample_float()) + if allow_higher_dtype_scalars or dtype.is_complex: + scalars.append(complex(sample_float(), sample_float())) + scalars.append(True) + return scalars raise AssertionError(f"Invalid rightmost_arg_type of {rightmost_arg_type}") def _should_disable_fastpath(self, opinfo, rightmost_arg, rightmost_arg_type, dtype): @@ -9451,6 +9464,7 @@ def sample_zero_size_tensor_inputs(self, opinfo, device, dtype, requires_grad, * assert "num_input_tensors" not in kwargs _foreach_inputs_kwargs = {k: kwargs.pop(k, v) for k, v in _foreach_inputs_default_kwargs.items()} _foreach_inputs_kwargs["requires_grad"] = requires_grad + allow_higher_dtype_scalars = kwargs.pop("allow_higher_dtype_scalars", False) for rightmost_arg_type in self._rightmost_arg_types: zero_size_foreach_inputs_kwargs = copy.deepcopy(_foreach_inputs_kwargs) zero_size_foreach_inputs_kwargs["zero_size"] = True @@ -9462,8 +9476,14 @@ def sample_zero_size_tensor_inputs(self, opinfo, device, dtype, requires_grad, * ] args.append( self._sample_rightmost_arg( - opinfo, ForeachRightmostArgType.TensorList, device, dtype, NUM_SIZE0_TENSORS, - **zero_size_foreach_inputs_kwargs)[0]) + opinfo, + ForeachRightmostArgType.TensorList, + device, + dtype, + NUM_SIZE0_TENSORS, + allow_higher_dtype_scalars=allow_higher_dtype_scalars, + **zero_size_foreach_inputs_kwargs, + )[0]) kwargs = self._sample_kwargs( opinfo, args[-1], ForeachRightmostArgType.TensorList, dtype) else: @@ -9482,6 +9502,7 @@ def __call__(self, opinfo, device, dtype, requires_grad, **kwargs): _foreach_inputs_kwargs = {k: kwargs.pop(k, v) for k, v in _foreach_inputs_default_kwargs.items()} _foreach_inputs_kwargs["requires_grad"] = requires_grad _foreach_inputs_kwargs["zero_size"] = False + allow_higher_dtype_scalars = kwargs.pop("allow_higher_dtype_scalars", False) # add empty tensor interspersion to test fully fixing #100701 for num_tensors, rightmost_arg_type, intersperse_empty_tensors in itertools.product( @@ -9500,7 +9521,7 @@ def __call__(self, opinfo, device, dtype, requires_grad, **kwargs): for _ in range(self.arity - 2) ] rightmost_arg_list = self._sample_rightmost_arg( - opinfo, rightmost_arg_type, device, dtype, num_tensors, + opinfo, rightmost_arg_type, device, dtype, num_tensors, allow_higher_dtype_scalars, **_foreach_inputs_kwargs) for rightmost_arg in rightmost_arg_list: args.append(rightmost_arg) @@ -9554,6 +9575,7 @@ def __call__(self, opinfo, device, dtype, requires_grad, **kwargs): assert isinstance(num_input_tensors, list) _foreach_inputs_kwargs = {k: kwargs.pop(k, v) for k, v in _foreach_inputs_default_kwargs.items()} _foreach_inputs_kwargs["requires_grad"] = requires_grad + _allow_higher_dtype_scalars = kwargs.pop("allow_higher_dtype_scalars", False) for num_tensors, ord, out_dtype in product( num_input_tensors, @@ -9584,24 +9606,6 @@ def __call__(self, opinfo, device, dtype, requires_grad, **kwargs): yield ForeachSampleInput([x], ord=ord, disable_fastpath=disable_fastpath) -class foreach_lerp_sample_func(foreach_inputs_sample_func): - def _sample_rightmost_arg(self, opinfo, rightmost_arg_type, device, dtype, num_tensors, **_foreach_inputs_kwargs): - if rightmost_arg_type == ForeachRightmostArgType.TensorList: - return [sample_inputs_foreach(None, device, dtype, num_tensors, **_foreach_inputs_kwargs)] - if rightmost_arg_type == ForeachRightmostArgType.ScalarList: - return [ - [random.randint(0, 9) + 1 for _ in range(num_tensors)], - [1.0 - random.random() for _ in range(num_tensors)], - [complex(1.0 - random.random(), 1.0 - random.random()) for _ in range(num_tensors)], - [True for _ in range(num_tensors)], - [1, 2.0, 3.0 + 4.5j] + [3.0 for _ in range(num_tensors - 3)], - [True, 1, 2.0, 3.0 + 4.5j] + [3.0 for _ in range(num_tensors - 4)], - ] - if rightmost_arg_type == ForeachRightmostArgType.Scalar: - return [random.random()] - raise AssertionError(f"Invalid rightmost_arg_type of {rightmost_arg_type}") - - class foreach_pointwise_sample_func(foreach_inputs_sample_func): def __init__( @@ -9636,6 +9640,7 @@ def __call__(self, opinfo, device, dtype, requires_grad, **kwargs): assert isinstance(num_input_tensors, list) _foreach_inputs_kwargs = {k: kwargs.pop(k, v) for k, v in _foreach_inputs_default_kwargs.items()} _foreach_inputs_kwargs["requires_grad"] = requires_grad + allow_higher_dtype_scalars = kwargs.pop("allow_higher_dtype_scalars", False) for num_tensors, rightmost_arg_type in itertools.product(num_input_tensors, self._rightmost_arg_types): input = sample_inputs_foreach(None, device, dtype, num_tensors, zero_size=False, **_foreach_inputs_kwargs) @@ -9644,7 +9649,15 @@ def __call__(self, opinfo, device, dtype, requires_grad, **kwargs): for _ in range(2 - int(rightmost_arg_type == ForeachRightmostArgType.TensorList)) ] rightmost_arg_list = self._sample_rightmost_arg( - opinfo, rightmost_arg_type, device, dtype, num_tensors, zero_size=False, **_foreach_inputs_kwargs) + opinfo, + rightmost_arg_type, + device, + dtype, + num_tensors, + zero_size=False, + allow_higher_dtype_scalars=allow_higher_dtype_scalars, + **_foreach_inputs_kwargs, + ) for rightmost_arg in rightmost_arg_list: kwargs = {} if rightmost_arg_type == ForeachRightmostArgType.TensorList: @@ -10786,14 +10799,15 @@ def __call__(self, opinfo, device, dtype, requires_grad, **kwargs): decorators=( # These tests fail with aten._local_scalar_dense not being implemented. DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_outplace"), - DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_inplace"), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_inplace", + dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16)), # Samples have complex types and inplace only works if the dtype is complex. DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_inplace", - dtypes=all_types_and(torch.bool, torch.bfloat16, torch.float16)), + dtypes=(torch.bool,)), DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace", - dtypes=all_types_and(torch.bool, torch.bfloat16, torch.float16)), + dtypes=(torch.bool,)), DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace_all_strides", - dtypes=all_types_and(torch.bool, torch.bfloat16, torch.float16)), + dtypes=integral_types() + complex_types_and(torch.bool, torch.bfloat16, torch.float16, torch.float64)), ), ), ForeachFuncInfo( @@ -10823,13 +10837,12 @@ def __call__(self, opinfo, device, dtype, requires_grad, **kwargs): decorators=( # Samples have complex types and inplace only works if the dtype is complex. DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_inplace", - dtypes=all_types_and(torch.bool, torch.bfloat16, torch.float16)), + dtypes=(torch.bool,)), DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace", - dtypes=all_types_and(torch.bool, torch.bfloat16, torch.float16)), - DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_inplace", - dtypes=all_types_and(torch.bool, torch.bfloat16, torch.float16)), + dtypes=(torch.bool,)), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_inplace", dtypes=(torch.bool,)), DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace_all_strides", - dtypes=all_types_and(torch.bool, torch.bfloat16, torch.float16)), + dtypes=(torch.bool,)), ), ), ForeachFuncInfo( @@ -10841,22 +10854,13 @@ def __call__(self, opinfo, device, dtype, requires_grad, **kwargs): decorators=( # Samples have complex types and inplace only works if the dtype is complex. DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_inplace", - dtypes=all_types_and(torch.bool, torch.bfloat16, torch.float16)), + dtypes=integral_types_and(torch.bool)), DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace", - dtypes=all_types_and(torch.bool, torch.bfloat16, torch.float16)), + dtypes=integral_types_and(torch.bool)), DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_inplace", - dtypes=all_types_and(torch.bool, torch.bfloat16, torch.float16)), + dtypes=integral_types_and(torch.bool)), DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace_all_strides", - dtypes=all_types_and(torch.bool, torch.bfloat16, torch.float16)), - # fails with div_cpu is not implemented with ComplexHalf - DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_outplace", - dtypes=(torch.float16,), device_type='cpu'), - DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace", - dtypes=(torch.float16,), device_type='cpu'), - DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_outplace", - dtypes=(torch.float16,), device_type='cpu'), - DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace_all_strides", - dtypes=(torch.float16,), device_type='cpu'), + dtypes=integral_types_and(torch.bool)), ), ), ForeachFuncInfo( @@ -10866,14 +10870,22 @@ def __call__(self, opinfo, device, dtype, requires_grad, **kwargs): supports_inplace_autograd=True, supports_forward_ad=True, decorators=( - DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_inplace"), - DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace"), - DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_inplace"), - DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_outplace"), - DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace"), - DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_outplace"), - DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace_all_strides"), - DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace_all_strides"), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_inplace", + dtypes=complex_types_and(torch.bool)), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace", + dtypes=complex_types_and(torch.bool)), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_inplace", + dtypes=complex_types_and(torch.bool)), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_outplace", + dtypes=complex_types_and(torch.bool)), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace", + dtypes=complex_types_and(torch.bool)), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_outplace", + dtypes=complex_types_and(torch.bool)), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace_all_strides", + dtypes=complex_types_and(torch.bool)), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace_all_strides", + dtypes=complex_types_and(torch.bool)), DecorateInfo( unittest.expectedFailure, "TestForeach", @@ -10896,14 +10908,22 @@ def __call__(self, opinfo, device, dtype, requires_grad, **kwargs): supports_inplace_autograd=True, supports_forward_ad=True, decorators=( - DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_inplace"), - DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace"), - DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_inplace"), - DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_outplace"), - DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace"), - DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_outplace"), - DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace_all_strides"), - DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace_all_strides"), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_inplace", + dtypes=complex_types_and(torch.bool)), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace", + dtypes=complex_types_and(torch.bool)), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_inplace", + dtypes=complex_types_and(torch.bool)), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_outplace", + dtypes=complex_types_and(torch.bool)), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace", + dtypes=complex_types_and(torch.bool)), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_outplace", + dtypes=complex_types_and(torch.bool)), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace_all_strides", + dtypes=complex_types_and(torch.bool)), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace_all_strides", + dtypes=complex_types_and(torch.bool)), DecorateInfo( unittest.expectedFailure, "TestForeach", @@ -10927,14 +10947,22 @@ def __call__(self, opinfo, device, dtype, requires_grad, **kwargs): supports_inplace_autograd=False, supports_forward_ad=False, decorators=( - DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_inplace"), - DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace"), - DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_inplace"), - DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_outplace"), - DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace"), - DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_outplace"), - DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace_all_strides"), - DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace_all_strides"), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_inplace", + dtypes=complex_types_and(torch.bool)), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace", + dtypes=complex_types_and(torch.bool)), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_inplace", + dtypes=complex_types_and(torch.bool)), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_outplace", + dtypes=complex_types_and(torch.bool)), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace", + dtypes=complex_types_and(torch.bool)), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_outplace", + dtypes=complex_types_and(torch.bool)), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace_all_strides", + dtypes=complex_types_and(torch.bool)), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace_all_strides", + dtypes=complex_types_and(torch.bool)), DecorateInfo( unittest.expectedFailure, "TestForeach", @@ -10958,14 +10986,22 @@ def __call__(self, opinfo, device, dtype, requires_grad, **kwargs): supports_forward_ad=False, supports_inplace_autograd=False, decorators=( - DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_inplace"), - DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace"), - DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_inplace"), - DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_outplace"), - DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace"), - DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_outplace"), - DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace_all_strides"), - DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace_all_strides"), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_inplace", + dtypes=complex_types_and(torch.bool)), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace", + dtypes=complex_types_and(torch.bool)), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_inplace", + dtypes=complex_types_and(torch.bool)), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_outplace", + dtypes=complex_types_and(torch.bool)), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace", + dtypes=complex_types_and(torch.bool)), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_outplace", + dtypes=complex_types_and(torch.bool)), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace_all_strides", + dtypes=complex_types_and(torch.bool)), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace_all_strides", + dtypes=complex_types_and(torch.bool)), DecorateInfo( unittest.expectedFailure, "TestForeach", @@ -10990,30 +11026,19 @@ def __call__(self, opinfo, device, dtype, requires_grad, **kwargs): supports_inplace_autograd=True, supports_forward_ad=True, decorators=( - DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_inplace", - dtypes=all_types_and(torch.bool, torch.bfloat16, torch.float16)), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_inplace", dtypes=(torch.bool,)), DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace", - dtypes=all_types_and(torch.bool, torch.bfloat16, torch.float16)), - DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_inplace", - dtypes=all_types_and(torch.bool, torch.bfloat16, torch.float16)), - DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace_all_strides", - dtypes=all_types_and(torch.bool, torch.bfloat16, torch.float16)), - DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_outplace", dtypes=(torch.bool,)), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_inplace", dtypes=(torch.bool,)), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_outplace", dtypes=(torch.bool,)), DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_outplace", dtypes=(torch.bool,)), DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace", dtypes=(torch.bool,)), - DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace_all_strides", - dtypes=(torch.bool,)), - DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_outplace", - dtypes=(torch.half,), device_type="cpu"), DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_outplace", - dtypes=(torch.half,), device_type="cpu"), + dtypes=(torch.bool,)), DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace", - dtypes=(torch.half,), device_type="cpu"), - DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace_all_strides", - dtypes=(torch.half,), device_type="cpu"), + dtypes=(torch.bool,),), DecorateInfo(unittest.skip("flaky"), "TestForeach", "test_parity", device_type="cpu", dtypes=(torch.complex64,)), DecorateInfo( unittest.expectedFailure, @@ -11044,23 +11069,19 @@ def __call__(self, opinfo, device, dtype, requires_grad, **kwargs): supports_inplace_autograd=True, supports_forward_ad=True, decorators=( - DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_outplace", - dtypes=all_types_and(torch.bool, torch.bfloat16, torch.float16)), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_outplace", dtypes=(torch.bool,)), DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace", - dtypes=all_types_and(torch.bool, torch.bfloat16, torch.float16)), - DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_outplace", - dtypes=all_types_and(torch.bool, torch.bfloat16, torch.float16)), + dtypes=(torch.bool,)), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_outplace", dtypes=(torch.bool,)), DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace_all_strides", - dtypes=all_types_and(torch.bool, torch.bfloat16, torch.float16)), - # Samples have complex types and inplace only works if the dtype is complex. - DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_inplace", - dtypes=all_types_and(torch.bool, torch.bfloat16, torch.float16)), + dtypes=(torch.bool,)), + # # Samples have complex types and inplace only works if the dtype is complex. + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_inplace", dtypes=(torch.bool,)), DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace", - dtypes=all_types_and(torch.bool, torch.bfloat16, torch.float16)), - DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_inplace", - dtypes=all_types_and(torch.bool, torch.bfloat16, torch.float16)), + dtypes=(torch.bool,)), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_inplace", dtypes=(torch.bool,)), DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace_all_strides", - dtypes=all_types_and(torch.bool, torch.bfloat16, torch.float16)), + dtypes=integral_types() + complex_types_and(torch.bool)), ), ), ForeachFuncInfo( @@ -11072,22 +11093,22 @@ def __call__(self, opinfo, device, dtype, requires_grad, **kwargs): decorators=( # Samples have complex types and inplace only works if the dtype is complex. DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_inplace", - dtypes=all_types_and(torch.bool, torch.bfloat16, torch.float16)), + dtypes=integral_types_and(torch.bool)), DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace", - dtypes=all_types_and(torch.bool, torch.bfloat16, torch.float16)), + dtypes=integral_types_and(torch.bool)), DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_inplace", - dtypes=all_types_and(torch.bool, torch.bfloat16, torch.float16)), + dtypes=integral_types_and(torch.bool)), DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace_all_strides", - dtypes=all_types_and(torch.bool, torch.bfloat16, torch.float16)), + dtypes=integral_types() + complex_types_and(torch.bool)), # fails with div_cpu is not implemented with ComplexHalf DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_outplace", - dtypes=all_types_and(torch.bool, torch.bfloat16, torch.float16)), + dtypes=integral_types_and(torch.bool)), DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace", - dtypes=all_types_and(torch.bool, torch.bfloat16, torch.float16)), + dtypes=integral_types_and(torch.bool)), DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_outplace", - dtypes=all_types_and(torch.bool, torch.bfloat16, torch.float16)), + dtypes=integral_types_and(torch.bool)), DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace_all_strides", - dtypes=all_types_and(torch.bool, torch.bfloat16, torch.float16)), + dtypes=integral_types() + complex_types_and(torch.bool)), ), ), ] @@ -11176,7 +11197,7 @@ def __call__(self, opinfo, device, dtype, requires_grad, **kwargs): foreach_other_op_db: List[ForeachFuncInfo] = [ ForeachFuncInfo( "lerp", - sample_inputs_func=foreach_lerp_sample_func(3, True, False), + sample_inputs_func=foreach_inputs_sample_func(3, True, False), supports_autograd=True, supports_inplace_autograd=True, supports_forward_ad=True, From c41e494016ef6b86de5d581290ebfbda393d2321 Mon Sep 17 00:00:00 2001 From: David Berard Date: Thu, 29 Aug 2024 14:10:39 -0700 Subject: [PATCH 0289/1018] [inductor] don't allow triton config pre_hook (#134633) The caching autotuner caches triton configs, and it doesn't try to hash or save the pre_hook from the config if it exists. If we had a config that had a pre_hook, then we might autotune -> save the config (without the pre_config) -> later load the saved config and try to run it, but this time without the pre_hook. So this PR adds an assert and deletes the pre_hook handling. We can be confident that we didn't have functional pre_hooks, because the pre_hook handling tries to use `self.arg_name`, which doesn't exist. Pull Request resolved: https://github.com/pytorch/pytorch/pull/134633 Approved by: https://github.com/shunting314, https://github.com/jansel --- test/inductor/test_triton_heuristics.py | 63 +++++++++++++++++++- torch/_inductor/runtime/runtime_utils.py | 11 ++++ torch/_inductor/runtime/triton_heuristics.py | 15 ++--- 3 files changed, 77 insertions(+), 12 deletions(-) diff --git a/test/inductor/test_triton_heuristics.py b/test/inductor/test_triton_heuristics.py index bcb0584b18b0fe..c4add081aa4e26 100644 --- a/test/inductor/test_triton_heuristics.py +++ b/test/inductor/test_triton_heuristics.py @@ -10,14 +10,20 @@ try: import triton # noqa: F401 + import triton.language as tl except ImportError: if __name__ == "__main__": sys.exit(0) raise unittest.SkipTest("requires triton") # noqa: B904 from torch._inductor import config -from torch._inductor.runtime.hints import TRITON_MAX_BLOCK -from torch._inductor.runtime.triton_heuristics import triton_config +from torch._inductor.runtime.hints import ( + DeviceProperties, + HeuristicType, + TRITON_MAX_BLOCK, +) +from torch._inductor.runtime.triton_helpers import math as tl_math +from torch._inductor.runtime.triton_heuristics import CachingAutotuner, triton_config from torch._inductor.test_case import run_tests, TestCase @@ -81,6 +87,59 @@ def test_artificial_zgrid(self): def test_artificial_grid_cpp_wrapper(self): self._test_artificial_zgrid() + def _get_cos_kernel_caching_autotuner_args(self): + from triton.compiler.compiler import AttrsDescriptor + + @triton.jit + def triton_(in_ptr0, out_ptr0, xnumel, XBLOCK: tl.constexpr): + xnumel = 16 + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + xmask = xindex < xnumel + x0 = xindex + tmp0 = tl.load(in_ptr0 + (x0), xmask) + tmp1 = tl_math.cos(tmp0) + tl.store(out_ptr0 + (x0), tmp1, xmask) + + triton_meta = { + "signature": {0: "*fp32", 1: "*fp32", 2: "i32"}, + "device": DeviceProperties.create(torch.device("cuda")), + "constants": {}, + "configs": [AttrsDescriptor(divisible_by_16=(0, 1, 2), equal_to_1=())], + } + + configs = [ + triton_config([16], 64), + triton_config([256], 64), + ] + + inductor_meta = {} + + return { + "fn": triton_, + "triton_meta": triton_meta, + "configs": configs, + "save_cache_hook": False, + "mutated_arg_names": [], + "heuristic_type": HeuristicType.POINTWISE, + "inductor_meta": inductor_meta, + } + + @skipIfXpu + def test_pre_hook_assert(self): + # assert if any of the configs passed to the CachingAutotuner have pre-hooks + args = self._get_cos_kernel_caching_autotuner_args() + + def pre_hook(kwargs): + if "in_ptr0" in kwargs: + kwargs["in_ptr0"].zero_() + + for cfg in args["configs"]: + cfg.pre_hook = pre_hook + + with self.assertRaisesRegex(AssertionError, "pre_hook"): + autotuner = CachingAutotuner(**args) + if __name__ == "__main__": if IS_LINUX and HAS_GPU: diff --git a/torch/_inductor/runtime/runtime_utils.py b/torch/_inductor/runtime/runtime_utils.py index 40e2678e089b54..446dbc71c61d1d 100644 --- a/torch/_inductor/runtime/runtime_utils.py +++ b/torch/_inductor/runtime/runtime_utils.py @@ -65,6 +65,17 @@ def triton_config_to_hashable(cfg): return tuple(items) +def validate_triton_config(cfg): + # [Note: Triton pre_hook in inductor] + # pre-hook is a lambda function, which we don't attempt to serialize. + # right now, if a pre-hook is attached to the config, it will not be saved; + # and then it won't be used when the config is loaded from cache. + # So we assert - if we do get a pre_hook, it might get ignored after caching. + assert ( + getattr(cfg, "pre_hook", None) is None + ), "triton configs with pre_hooks not supported" + + def create_bandwidth_info_str(ms, num_gb, gb_per_s, prefix="", suffix="", color=True): info_str = f"{prefix}{ms:.3f}ms \t{num_gb:.3f} GB \t {gb_per_s:7.2f}GB/s{suffix}" slow = ms > 0.012 and gb_per_s < 650 diff --git a/torch/_inductor/runtime/triton_heuristics.py b/torch/_inductor/runtime/triton_heuristics.py index 824ecf241d5dd4..81887a9fda88a6 100644 --- a/torch/_inductor/runtime/triton_heuristics.py +++ b/torch/_inductor/runtime/triton_heuristics.py @@ -40,6 +40,7 @@ get_num_bytes, next_power_of_2, triton_config_to_hashable, + validate_triton_config, ) @@ -182,6 +183,10 @@ def __init__( super().__init__() assert len(configs) > 0, "Non-empty TritonConfig list required for compiling" + # makes sure there are no pre-hooks on any of the triton configs + for cfg in configs: + validate_triton_config(cfg) + self.fn = fn self.device_props: DeviceProperties = triton_meta["device"] self.triton_meta = { @@ -655,11 +660,6 @@ def bench(self, launcher, *args, grid, with_profiler=False, **kwargs): stream = device_interface.get_raw_stream(device_interface.current_device()) def kernel_call(): - if launcher.config.pre_hook is not None: - launcher.config.pre_hook( - {**dict(zip(self.arg_names, args)), **launcher.config.kwargs} - ) - cloned_args, cloned_kwargs = self.clone_args(*args, **kwargs) launcher( *cloned_args, @@ -848,11 +848,6 @@ def run(self, *args, grid, stream, **kwargs): if launcher.store_cubin: self.save_gpu_kernel(grid, stream, launcher) - if launcher.config.pre_hook is not None: - launcher.config.pre_hook( - {**dict(zip(self.arg_names, args)), **launcher.config.kwargs, **kwargs} - ) - if os.environ.get("TORCHINDUCTOR_DUMP_LAUNCH_PARAMS", 0) == "1": _dump_launch_params(args, kwargs, launcher, self.fn.__name__) From 948d825656695701bb7fdd858d8de25b839d090c Mon Sep 17 00:00:00 2001 From: Xu Han Date: Fri, 30 Aug 2024 17:51:46 +0000 Subject: [PATCH 0290/1018] [inductor] enable Intel Compiler(icx-cl) for inductor windows (#134772) This PR is enable Intel Compiler (`icx-cl`) for Windows inductor, likes previous PR: https://github.com/pytorch/pytorch/pull/134444 which enable clang. Changes: 1. Fix icx-cl crash by wrong decode args, the right decode should be "utf-8". 2. Add intel compiler check, and intel compiler Windows drivers check(icx-cl). 3. Add Intel compiler openmp args config. 4. Add intel compiler openmp binary preload. For intel compiler openmp binary path: image For performance, Intel compiler(`icx-cl`) is much better performance than MSVC(`cl`): image Append `clang-cl` performance data: image Pull Request resolved: https://github.com/pytorch/pytorch/pull/134772 Approved by: https://github.com/jgong5, https://github.com/jansel --- torch/_inductor/cpp_builder.py | 66 +++++++++++++++++++++++++++++++++- 1 file changed, 65 insertions(+), 1 deletion(-) diff --git a/torch/_inductor/cpp_builder.py b/torch/_inductor/cpp_builder.py index f78bc354cc90cc..693f9b06418285 100644 --- a/torch/_inductor/cpp_builder.py +++ b/torch/_inductor/cpp_builder.py @@ -58,7 +58,7 @@ def use_global_cache() -> bool: _IS_MACOS = sys.platform.startswith("darwin") _IS_WINDOWS = sys.platform == "win32" -SUBPROCESS_DECODE_ARGS = ("oem",) if _IS_WINDOWS else () +SUBPROCESS_DECODE_ARGS = ("utf-8",) if _IS_WINDOWS else () log = logging.getLogger(__name__) @@ -198,6 +198,33 @@ def _is_msvc_cl(cpp_compiler: str) -> bool: return False +@functools.lru_cache(None) +def _is_intel_compiler(cpp_compiler: str) -> bool: + try: + output_msg = ( + subprocess.check_output( + [cpp_compiler, "--version"], stderr=subprocess.DEVNULL + ) + .strip() + .decode(*SUBPROCESS_DECODE_ARGS) + ) + is_intel_compiler = "Intel" in output_msg.splitlines()[0] + if is_intel_compiler: + if _IS_WINDOWS: + if re.search(r"((icx$)|(icx-cc$))", cpp_compiler): + raise RuntimeError( + "Please use icx-cl, due to torch.compile only support MSVC-like CLI (compiler flags syntax)." + ) + return is_intel_compiler + except FileNotFoundError as exc: + return False + except subprocess.SubprocessError: + # --version args not support. + return False + + return False + + @functools.lru_cache(None) def is_gcc() -> bool: return _is_gcc(get_cpp_compiler()) @@ -208,6 +235,11 @@ def is_clang() -> bool: return _is_clang(get_cpp_compiler()) +@functools.lru_cache(None) +def is_intel_compiler() -> bool: + return _is_intel_compiler(get_cpp_compiler()) + + @functools.lru_cache(None) def is_apple_clang() -> bool: return _is_apple_clang(get_cpp_compiler()) @@ -798,6 +830,20 @@ def perload_clang_libomp_win(cpp_compiler: str, omp_name: str) -> None: pass +@functools.lru_cache(None) +def perload_icx_libomp_win(cpp_compiler: str) -> None: + try: + output = subprocess.check_output( + [cpp_compiler, "-print-file-name=libiomp5md.dll"], stderr=subprocess.DEVNULL + ).decode(*SUBPROCESS_DECODE_ARGS) + omp_path = output.rstrip() + if os.path.isfile(omp_path): + os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE" + omp_module = cdll.LoadLibrary(omp_path) + except subprocess.SubprocessError: + pass + + def _get_openmp_args( cpp_compiler: str, ) -> Tuple[List[str], List[str], List[str], List[str], List[str], List[str]]: @@ -854,10 +900,28 @@ def _get_openmp_args( # if openmp is still not available, we let the compiler to have a try, # and raise error together with instructions at compilation error later elif _IS_WINDOWS: + """ + On Windows, `clang` and `icx` have their specific openmp implenmention. + And the openmp lib is in compiler's some sub-directory. + For dynamic library(DLL) load, the Windows native APIs are `LoadLibraryA` and `LoadLibraryExA`, and their search + dependencies have some rules: + https://learn.microsoft.com/en-us/windows/win32/api/libloaderapi/nf-libloaderapi-loadlibraryexa#searching-for-dlls-and-dependencies + In some case, the rules may not include compiler's sub-directories. + So, it can't search and load compiler's openmp library correctly. + And then, the whole application would be broken. + + To avoid the openmp load failed, we can automatic locate the openmp binary and preload it. + 1. For clang, the function is `perload_clang_libomp_win`. + 2. For icx, the function is `perload_icx_libomp_win`. + """ if _is_clang(cpp_compiler): cflags.append("openmp") libs.append("libomp") perload_clang_libomp_win(cpp_compiler, "libomp.dll") + elif _is_intel_compiler(cpp_compiler): + cflags.append("Qiopenmp") + libs.append("libiomp5md") + perload_icx_libomp_win(cpp_compiler) else: # /openmp, /openmp:llvm # llvm on Windows, new openmp: https://devblogs.microsoft.com/cppblog/msvc-openmp-update/ From 1e9cf1a8cbf3725cdf50945cafb507409ad28733 Mon Sep 17 00:00:00 2001 From: fduwjj Date: Thu, 29 Aug 2024 20:25:14 -0700 Subject: [PATCH 0291/1018] [FR] Make pg_name unique, show P2P collective status and fix bugs when running the script as command (#134780) Fixes a bunches of bugs in the script when running with the generated command and 3D parallel. Pull Request resolved: https://github.com/pytorch/pytorch/pull/134780 Approved by: https://github.com/c-p-i-o ghstack dependencies: #134528 --- .../flight_recorder/test_fr_analysis.py | 21 ++++--- tools/flight_recorder/components/builder.py | 59 ++++++++++++++----- .../components/config_manager.py | 7 ++- tools/flight_recorder/components/loader.py | 4 +- tools/flight_recorder/components/types.py | 20 +++---- tools/flight_recorder/components/utils.py | 57 +++++++++++++----- tools/flight_recorder/fr_trace.py | 9 +-- 7 files changed, 118 insertions(+), 59 deletions(-) diff --git a/test/distributed/flight_recorder/test_fr_analysis.py b/test/distributed/flight_recorder/test_fr_analysis.py index 43a296a3966c61..bc5f010e927bbd 100644 --- a/test/distributed/flight_recorder/test_fr_analysis.py +++ b/test/distributed/flight_recorder/test_fr_analysis.py @@ -46,13 +46,16 @@ def test_match_one_event(self): "all_reduce", ("0", "default"), [[4, 4]], [[4, 4]], "scheduled", 1 ) membership = {"0": {0, 1}} - self.assertEqual(match_one_event(e1, e1, membership), MatchState.FULLY_MATCHED) + self.assertEqual( + match_one_event(e1, e1, membership, "0"), MatchState.FULLY_MATCHED + ) e2 = create_one_event( "all_gather", ("0", "default"), [[4, 4]], [[4, 4]], "scheduled", 1 ) self.assertEqual( - match_one_event(e1, e2, membership), MatchState.COLLECTIVE_TYPE_MISMATCH + match_one_event(e1, e2, membership, "0"), + MatchState.COLLECTIVE_TYPE_MISMATCH, ) e3 = create_one_event( @@ -61,34 +64,35 @@ def test_match_one_event(self): e4 = create_one_event( "all_to_all", ("0", "default"), [[4, 4]], [[4, 4]], "scheduled", 1 ) - self.assertEqual(match_one_event(e3, e4, membership), MatchState.UNDECIDED) + self.assertEqual(match_one_event(e3, e4, membership, "0"), MatchState.UNDECIDED) e5 = create_one_event( "all_reduce", ("0", "default"), [[5, 4]], [[4, 4]], "scheduled", 1, 1 ) self.assertEqual( - match_one_event(e1, e5, membership), MatchState.SIZE_OR_SYNTAX_MISMATCH + match_one_event(e1, e5, membership, "0"), MatchState.SIZE_OR_SYNTAX_MISMATCH ) e6 = create_one_event( "all_reduce", ("0", "default"), [[4, 4]], [[5, 4]], "scheduled", 1, 2 ) self.assertEqual( - match_one_event(e1, e6, membership), MatchState.SIZE_OR_SYNTAX_MISMATCH + match_one_event(e1, e6, membership, "0"), MatchState.SIZE_OR_SYNTAX_MISMATCH ) e7 = create_one_event( "all_reduce", ("0", "default"), [[4, 4]], [[5, 4]], "scheduled", 2 ) self.assertEqual( - match_one_event(e7, e7, membership), MatchState.SIZE_OR_SYNTAX_MISMATCH + match_one_event(e7, e7, membership, "0"), MatchState.SIZE_OR_SYNTAX_MISMATCH ) e9 = create_one_event( "all_reduce", ("0", "default"), [[4, 4]], [[4, 4]], "completed", 1 ) self.assertEqual( - match_one_event(e1, e9, membership), MatchState.COLLECTIVE_STATE_MISMATCH + match_one_event(e1, e9, membership, "0"), + MatchState.COLLECTIVE_STATE_MISMATCH, ) e10 = create_one_event( @@ -101,7 +105,8 @@ def test_match_one_event(self): output_dtypes="float16", ) self.assertEqual( - match_one_event(e10, e9, membership), MatchState.COLLECTIVE_DTYPE_MISMATCH + match_one_event(e10, e9, membership, "0"), + MatchState.COLLECTIVE_DTYPE_MISMATCH, ) diff --git a/tools/flight_recorder/components/builder.py b/tools/flight_recorder/components/builder.py index b2f7b13101800d..b797d405afe754 100644 --- a/tools/flight_recorder/components/builder.py +++ b/tools/flight_recorder/components/builder.py @@ -49,7 +49,13 @@ def tabulate(data: Any, headers: Any = None) -> Any: # type: ignore[misc] def build_groups_memberships( pg_config: Any, -) -> Tuple[List[Group], Dict[Any, Group], List[Membership], Dict[str, Set[Any]]]: +) -> Tuple[ + List[Group], + Dict[Any, Group], + List[Membership], + Dict[str, Set[Any]], + Dict[Tuple[str, int], str], +]: """ pg_config: { global_rank: { @@ -71,6 +77,7 @@ def build_groups_memberships( `_groups`: a dict that is indexed by pg_guid with Group namedtuple as value. `memberships`: a membership table where each row is a Membership namedtuple. `_memberships`: a dict that is indexed by pg_guid with set of ranks (int) as value. + `_pg_guids`: a dict that is indexed by (pg_uid, global_rank) with pg_guid as value. """ # flat lists for return groups = [] @@ -79,10 +86,16 @@ def build_groups_memberships( # dicts for faster cross-rank validation _groups = {} _memberships = {} + _pg_guids = {} for global_rank in pg_config: - for pg_guid in pg_config[global_rank]: - desc = pg_config[global_rank][pg_guid]["desc"] - ranks = ast.literal_eval(pg_config[global_rank][pg_guid]["ranks"]) + for pg_uid in pg_config[global_rank]: + desc = pg_config[global_rank][pg_uid]["desc"] + ranks = ast.literal_eval(pg_config[global_rank][pg_uid]["ranks"]) + # With the adoption of the split_group API, we can have multiple PGs with the same pg_guid (PG Name) + # So we need to add the hash of all its ranks within the PG as well. + # Also guid must be a string because `_process_group_name` returns a string. + pg_guid = pg_uid + str(hash(frozenset(ranks))) + _pg_guids[(pg_uid, global_rank)] = pg_guid if isinstance(ranks, str): # TODO Bug in FR data format? ranks is '[0, 1,...]' ranks = eval(ranks) @@ -97,18 +110,18 @@ def build_groups_memberships( # validation across ranks assert ( _groups[pg_guid].desc == desc - ), f"mismatch in desc {_groups[pg_guid].desc} vs {desc}" + ), f"mismatch in desc {_groups[pg_guid].desc} vs {desc} for group {pg_guid}" assert _memberships[pg_guid] == set( ranks - ), f"mismatch in membership {_memberships[pg_guid]} vs {set(ranks)}" - return groups, _groups, memberships, _memberships + ), f"mismatch in membership for group {pg_guid} {_memberships[pg_guid]} vs {set(ranks)}" + return groups, _groups, memberships, _memberships, _pg_guids def build_nccl_call( entry: Dict[Any, Any], id: int, collective_id: Any, - group_id: int, + group_id: str, global_rank: Any, ) -> NCCLCall: return NCCLCall( @@ -126,6 +139,7 @@ def build_collectives( all_entries: Dict[int, List[Dict[str, Any]]], _groups: Dict[str, Group], _memberships: Dict[str, Set[Any]], + _pg_guids: Dict[Tuple[str, int], str], ) -> Tuple[List[Traceback], List[Collective], List[NCCLCall]]: """ groups, memberships are the non-flat dicts that are indexable @@ -186,7 +200,14 @@ def build_collectives( entries = all_entries[first_rank] pg_name, desc = entries[0]["process_group"] profiling_name = entries[0]["profiling_name"] + pg_name = _pg_guids[(pg_name, first_rank)] collective_seq_id = entries[0]["collective_seq_id"] + print( + "collective_seq_id ", + collective_seq_id, + " p2p_seq_id ", + entries[0]["p2p_seq_id"], + ) record_id = entries[0]["record_id"] input_sizes = entries[0]["input_sizes"] output_sizes = entries[0]["output_sizes"] @@ -198,7 +219,7 @@ def build_collectives( found_ranks = set() found_idx = {} - if find_coalesced_group(pg_name, entries): + if find_coalesced_group(pg_name, entries, _pg_guids, first_rank): expected_ranks.add(first_rank) done_ranks = set() all_coalesced_entries = {} @@ -206,13 +227,13 @@ def build_collectives( curr = expected_ranks.pop() done_ranks.add(curr) grp = ( - find_coalesced_group(pg_name, all_entries[curr]) # type: ignore[index] + find_coalesced_group(pg_name, all_entries[curr], _pg_guids, curr) # type: ignore[index] if curr in all_entries # type: ignore[comparison-overlap] else [] ) all_coalesced_entries[curr] = grp for index, entry in grp: - op = Op(entry, _memberships) + op = Op(entry, _memberships, pg_name) peer = None if op.type == "send": assert op._src_g == curr, (op._src_g, curr) @@ -228,6 +249,7 @@ def build_collectives( group_size=_groups[pg_name].size, groups=_groups, memberships=_memberships, + _pg_guids=_pg_guids, ) if match and mismatch[pg_name] == 0: @@ -256,10 +278,13 @@ def build_collectives( # step over ops from other PGs # only check match state when seq_id matches if ( - e["process_group"] == (pg_name, desc) + _pg_guids[(e["process_group"][0], o)] == pg_name + and e["process_group"][1] == desc and e["collective_seq_id"] == collective_seq_id ): - match_state = match_one_event(entries[0], e, _memberships) + match_state = match_one_event( + entries[0], e, _memberships, pg_name + ) if ( match_state in [MatchState.FULLY_MATCHED, MatchState.UNDECIDED] @@ -403,17 +428,19 @@ def build_db(details: Dict[str, Dict[str, Any]], args: argparse.Namespace) -> Da entries = sort_trace_from_beginning(entries) # flattened database - groups, _groups, memberships, _memberships = build_groups_memberships(pg_config) + groups, _groups, memberships, _memberships, _pg_guids = build_groups_memberships( + pg_config + ) print("built groups, memberships") check_no_missing_dump_files(entries, memberships) if args.just_print_entries: - just_print_entries(entries, _groups, _memberships) + just_print_entries(entries, _groups, _memberships, _pg_guids) sys.exit(0) tracebacks, collectives, nccl_calls = build_collectives( - entries, _groups, _memberships + entries, _groups, _memberships, _pg_guids ) print("built collectives, nccl_calls") if args.verbose: diff --git a/tools/flight_recorder/components/config_manager.py b/tools/flight_recorder/components/config_manager.py index d296bd981b811c..41155fe7c777d0 100644 --- a/tools/flight_recorder/components/config_manager.py +++ b/tools/flight_recorder/components/config_manager.py @@ -5,6 +5,7 @@ # LICENSE file in the root directory of this source tree. import argparse +from typing import Optional, Sequence class JobConfig: @@ -30,5 +31,7 @@ def __init__(self: "JobConfig"): self.parser.add_argument("-j", "--just_print_entries", action="store_true") self.parser.add_argument("-v", "--verbose", action="store_true") - def parse_args(self: "JobConfig") -> argparse.Namespace: - return self.parser.parse_args() + def parse_args( + self: "JobConfig", args: Optional[Sequence[str]] + ) -> argparse.Namespace: + return self.parser.parse_args(args) diff --git a/tools/flight_recorder/components/loader.py b/tools/flight_recorder/components/loader.py index e8129cc5a215fc..854c6034906e20 100644 --- a/tools/flight_recorder/components/loader.py +++ b/tools/flight_recorder/components/loader.py @@ -41,9 +41,7 @@ def read_dir(prefix: str, folder: str) -> Dict[str, Dict[str, Any]]: t0 = time.time() for root, _, files in os.walk(folder): for f in files: - ta = time.time() details[f] = read_dump(prefix, os.path.join(root, f)) - tb = time.time() - # print(f"read file {f} in {tb - ta}s") + tb = time.time() print(f"loaded {len(files)} files in {tb - t0}s") return details diff --git a/tools/flight_recorder/components/types.py b/tools/flight_recorder/components/types.py index 2d02954577d330..1f2b75a05eb735 100644 --- a/tools/flight_recorder/components/types.py +++ b/tools/flight_recorder/components/types.py @@ -68,13 +68,13 @@ def from_type(cls, c: T) -> "TypeInfo": class Group(NamedTuple): - id: int + id: str desc: str size: int class Membership(NamedTuple): - group_id: Ref[Group] + group_id: str global_rank: int @@ -85,13 +85,13 @@ class Traceback(NamedTuple): class Collective(NamedTuple): id: int - group_id: Ref[Group] + group_id: str class NCCLCall(NamedTuple): id: int collective_id: Ref[Collective] - group_id: Ref[Group] + group_id: str global_rank: int # technically Ref[Process] once we have it traceback_id: Ref[Traceback] collective_type: str @@ -188,7 +188,9 @@ class Op: nccl:recv 3<-0 """ - def __init__(self, event: Dict[Any, Any], memberships: Dict[str, Set[Any]]): + def __init__( + self, event: Dict[Any, Any], memberships: Dict[str, Set[Any]], pg_name: str + ): profiling_name = event["profiling_name"] nccl, name = profiling_name.split(":") assert nccl == "nccl", f"name formatting error? {nccl} != 'nccl'" @@ -209,7 +211,7 @@ def __init__(self, event: Dict[Any, Any], memberships: Dict[str, Set[Any]]): self._dst, self._src = int(d), int(s) else: self._src, self._dst = -1, -1 - pg_name, pg_desc = event["process_group"] + _, pg_desc = event["process_group"] self._init_global_src_dst(memberships[pg_name]) self.pg_size = len(memberships[pg_name]) if type in P2P | COLLECTIVES: @@ -239,10 +241,8 @@ def dst(self) -> int: def __repr__(self) -> str: if self.type in P2P: - return ( - f"{self.type}(s={self._src_g} d={self._dst_g}, sz={self.input_sizes})" - ) - return f"{self.type}(input_sizes={self.input_sizes}, {self.state})" + return f"{self.type}(s={self._src_g} d={self._dst_g}, sz={self.input_sizes}, state={self.state})" + return f"{self.type}(input_sizes={self.input_sizes}, state={self.state})" def match(self, other: "Op") -> MatchState: # TODO: I think this can validly not match, diff --git a/tools/flight_recorder/components/utils.py b/tools/flight_recorder/components/utils.py index efb1799260f055..5725359291e69e 100644 --- a/tools/flight_recorder/components/utils.py +++ b/tools/flight_recorder/components/utils.py @@ -5,7 +5,7 @@ # LICENSE file in the root directory of this source tree. import math -from typing import Any, Dict, List, Set, Tuple # type: ignore[attr-defined] +from typing import Any, Dict, List, Set, Tuple from tools.flight_recorder.components.types import ( Group, @@ -40,9 +40,10 @@ def match_one_event( event_a: Dict[Any, Any], event_b: Dict[Any, Any], memberships: Dict[str, Set[Any]], + pg_name: str, ) -> MatchState: - op_a = Op(event_a, memberships) - op_b = Op(event_b, memberships) + op_a = Op(event_a, memberships, pg_name) + op_b = Op(event_b, memberships, pg_name) return op_a.match(op_b) @@ -51,6 +52,7 @@ def match_coalesced_groups( group_size: int, groups: Dict[str, Group], memberships: Dict[str, Set[Any]], + _pg_guids: Dict[Tuple[str, int], str], ) -> bool: """ all_rank_events: { @@ -76,13 +78,22 @@ def match_coalesced_groups( rank1 [recv:0 (1000B), recv:0 (100B)] —> not okay """ all_ops = { - rank: [Op(e, memberships) for i, e in all_rank_events[rank]] + rank: [ + Op(e, memberships, _pg_guids[(e["process_group"][0], rank)]) + for i, e in all_rank_events[rank] + ] for rank in all_rank_events } - def visualize_ops(match: bool) -> None: + def visualize_ops( + match: bool, + _pg_guids: Dict[Tuple[str, int], str], + ) -> None: all_ops = { - rank: [Op(e, memberships) for i, e in all_rank_events[rank]] + rank: [ + Op(e, memberships, _pg_guids[(e["process_group"][0], rank)]) + for i, e in all_rank_events[rank] + ] for rank in all_rank_events } @@ -94,8 +105,14 @@ def visualize_ops(match: bool) -> None: progress = False for r in all_ops: if len(all_ops[r]) > i: - _, event = all_rank_events[r][i] - row.append(Op(event, memberships)) + rank, event = all_rank_events[r][i] + row.append( + Op( + event, + memberships, + _pg_guids[(event["process_group"][0], rank)], + ) + ) progress = True else: row.append(None) # type: ignore[arg-type] @@ -144,10 +161,10 @@ def visualize_ops(match: bool) -> None: my_ops.pop(0) peer_ops.pop(match_idx) else: - visualize_ops(False) + visualize_ops(False, _pg_guids) return False - visualize_ops(True) + visualize_ops(True, _pg_guids) return True @@ -161,21 +178,27 @@ def check_size_alltoall(alltoall_cases: List[Dict[str, Any]]) -> Tuple[bool, int def find_coalesced_group( - pg_name: str, entries: List[Dict[str, Any]] + pg_name: str, + entries: List[Dict[str, Any]], + _pg_guids: Dict[Tuple[str, int], str], + rank: int, ) -> List[Tuple[int, Dict[str, Any]]]: """Given a list of entries, if the collective_seq_id of the first entry matches that of subsequent ones, build an return a list of entries terminating in a 'coalesced' op entry all sharing a collective_seq_id - TODO: handle p2p_seq_id v/s collective_seq_id separately here. """ found = [] collective_seq_id = None for i, e in enumerate(entries): - if e["process_group"][0] != pg_name: + if _pg_guids[(e["process_group"][0], rank)] != pg_name: continue elif collective_seq_id is None: - collective_seq_id = e["collective_seq_id"] + collective_seq_id = ( + e["p2p_seq_id"] if e["is_p2p"] else e["collective_seq_id"] + ) + found.append((i, e)) + elif not e["is_p2p"] and e["collective_seq_id"] == collective_seq_id: found.append((i, e)) - elif e["collective_seq_id"] == collective_seq_id: + elif e["is_p2p"] and e["p2p_seq_id"] == collective_seq_id: found.append((i, e)) else: break @@ -190,6 +213,7 @@ def just_print_entries( all_entries: Dict[int, List[Dict[str, Any]]], _groups: Dict[str, Group], _memberships: Dict[str, Set[Any]], + _pg_guids: Dict[Tuple[str, int], str], ) -> None: rows = [] ranks = sorted(all_entries.keys()) @@ -203,7 +227,8 @@ def just_print_entries( row.append("") else: entry = all_entries[rank].pop(0) - row.append(str(Op(entry, _memberships))) + pg_name = _pg_guids[(entry["process_group"][0], rank)] + row.append(str(Op(entry, _memberships, pg_name))) progress = True if progress: rows.append(row) diff --git a/tools/flight_recorder/fr_trace.py b/tools/flight_recorder/fr_trace.py index bf58aee2a0f103..aa5797c3bec870 100644 --- a/tools/flight_recorder/fr_trace.py +++ b/tools/flight_recorder/fr_trace.py @@ -28,8 +28,8 @@ - This script is versioned so that we can ensure our future changes to flight recorder are backwards compatible. """ -import argparse import pickle +from typing import Optional, Sequence from tools.flight_recorder.components.builder import build_db from tools.flight_recorder.components.config_manager import JobConfig @@ -37,7 +37,9 @@ from tools.flight_recorder.components.types import types -def main(args: argparse.Namespace) -> None: +def main(args: Optional[Sequence[str]] = None) -> None: + config = JobConfig() + args = config.parse_args(args) details = read_dir(args.prefix, args.dir) db = build_db(details, args) if args.output: @@ -46,5 +48,4 @@ def main(args: argparse.Namespace) -> None: if __name__ == "__main__": - config = JobConfig() - main(config.parse_args()) + main() From b135c33e73a1361376092c7a3bef18ee5ebda2a5 Mon Sep 17 00:00:00 2001 From: Andrew Gu Date: Fri, 30 Aug 2024 08:39:18 -0700 Subject: [PATCH 0292/1018] Used `torch.equal` in `test_foreach_copy_with_multi_dtypes` (#134861) `self.assertEqual` allows some tolerance, but here, we want to show that `_foreach_copy_` gives bitwise equivalent results. Let us use `torch.equal` then. Pull Request resolved: https://github.com/pytorch/pytorch/pull/134861 Approved by: https://github.com/Skylion007, https://github.com/janeyx99, https://github.com/crcrpar --- test/test_foreach.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/test/test_foreach.py b/test/test_foreach.py index 11c6273a1471fd..d668cb30f24c8e 100644 --- a/test/test_foreach.py +++ b/test/test_foreach.py @@ -1302,13 +1302,12 @@ def test_foreach_copy_with_multi_dtypes(self, device, dtype, op): out = foreach_copy_( (self_tensors, src_tensors), is_cuda=True, expect_fastpath=True ) - self.assertEqual( - out, - [ - torch.empty_like(t).copy_(s) - for t, s in zip(self_tensors, src_tensors) - ], - ) + ref_out = [ + torch.empty_like(t).copy_(s) + for t, s in zip(self_tensors, src_tensors) + ] + for t, ref_t in zip(out, ref_out): + self.assertTrue(torch.equal(t, ref_t)) # Test reverse-mode & forward-mode AD if supported. @onlyCUDA From 762c456109048235c795b14fee6541e0ffb0258d Mon Sep 17 00:00:00 2001 From: fduwjj Date: Thu, 29 Aug 2024 20:25:14 -0700 Subject: [PATCH 0293/1018] [RFC] Make fr trace script a console scripts (#134729) We want to make fr analyzer script a command after users `pip install torch`, that's why we want to mimic what torchrun is doing. Pull Request resolved: https://github.com/pytorch/pytorch/pull/134729 Approved by: https://github.com/c-p-i-o, https://github.com/malfet ghstack dependencies: #134528, #134780 --- setup.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/setup.py b/setup.py index ad48f4b0108633..5eaddb894df61e 100644 --- a/setup.py +++ b/setup.py @@ -1107,6 +1107,12 @@ def make_relative_rpath_args(path): "default = torch.distributed.elastic.multiprocessing:DefaultLogsSpecs", ], } + + if cmake_cache_vars["USE_DISTRIBUTED"]: + # Only enable fr_trace command if distributed is enabled + entry_points["console_scripts"].append( + "torchfrtrace = tools.flight_recorder.fr_trace:main", + ) return extensions, cmdclass, packages, entry_points, extra_install_requires From e12cc732e8cff27ddd0bce5386c5d455172c0bac Mon Sep 17 00:00:00 2001 From: Gabriel Ferns Date: Fri, 30 Aug 2024 18:48:37 +0000 Subject: [PATCH 0294/1018] Enable cudagraphs in cpp wrapper (#133885) Fixes https://github.com/pytorch/pytorch/issues/130878 Summary: Enables cudagraphs in cpp wrapper by clearing inputs. Generated, non-cpp wrapper code: ```python def call(args): arg0_1, = args args.clear() assert_size_stride(arg0_1, (10, ), (1, )) with torch.cuda._DeviceGuard(0): torch.cuda.set_device(0) buf0 = empty_strided_cuda((10, ), (1, ), torch.float32) # Topologically Sorted Source Nodes: [sin], Original ATen: [aten.sin] stream0 = get_raw_stream(0) triton_poi_fused_sin_0.run(arg0_1, buf0, 10, grid=grid(10), stream=stream0) del arg0_1 return (buf0, ) ``` vs generated cpp wrapper code: ```python def _wrap_func(f): def g(args): input_tensors = [arg if isinstance(arg, torch.Tensor) else torch.tensor(arg) for arg in args] input_handles = torch._C._aoti.unsafe_alloc_void_ptrs_from_tensors(input_tensors) # new: args.clear() # end new output_handles = f(input_handles) output_tensors = torch._C._aoti.alloc_tensors_by_stealing_from_void_ptrs(output_handles) return output_tensors return g ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/133885 Approved by: https://github.com/eellison, https://github.com/desertfire --- test/inductor/test_cudagraph_trees.py | 11 +++++++++++ torch/_inductor/codegen/cpp_wrapper_cpu.py | 4 ++++ 2 files changed, 15 insertions(+) diff --git a/test/inductor/test_cudagraph_trees.py b/test/inductor/test_cudagraph_trees.py index cfff3baa62765c..a2675ed3231dd8 100644 --- a/test/inductor/test_cudagraph_trees.py +++ b/test/inductor/test_cudagraph_trees.py @@ -2445,6 +2445,17 @@ def iter(batch_size: int, mod: torch.nn.Module): exactly=True, ).run("\n".join(captured_output)) + @torch._inductor.config.patch("cpp_wrapper", 1) + def test_cpp_wrapper(self): + def f(x): + return torch.sin(x) + + compiled = torch.compile(f, mode="reduce-overhead") + example_input = torch.randn(10, device="cuda") + compiled_result = self.run_twc(compiled, example_input) + eager_result = f(example_input) + self.assertEqual(compiled_result, eager_result) + instantiate_parametrized_tests(CudaGraphTreeTests) if __name__ == "__main__": diff --git a/torch/_inductor/codegen/cpp_wrapper_cpu.py b/torch/_inductor/codegen/cpp_wrapper_cpu.py index f8baacfacd9c0c..f25c2021d36b9e 100644 --- a/torch/_inductor/codegen/cpp_wrapper_cpu.py +++ b/torch/_inductor/codegen/cpp_wrapper_cpu.py @@ -1155,6 +1155,10 @@ def generate_end(self, result): wrapper_body += """ input_handles = torch._C._aoti.unsafe_alloc_void_ptrs_from_tensors(input_tensors) """ + # Release the inputs for memory reuse. + wrapper_body += """ + args.clear() + """ # unwrap output tensor back to python scalar if all(x for x in self.output_is_tensor.values()): From 0ef05979013a588e5ffecc6a3b61257b74f2a3fa Mon Sep 17 00:00:00 2001 From: Wouter Devriendt Date: Fri, 30 Aug 2024 19:53:20 +0000 Subject: [PATCH 0295/1018] regenerate ci workflows for binary builds with new g4dn runners (#133404) Fixes #103104 Pull Request resolved: https://github.com/pytorch/pytorch/pull/133404 Approved by: https://github.com/ZainRizvi --- .../windows_binary_build_workflow.yml.j2 | 4 ++-- ...generated-windows-binary-conda-nightly.yml | 24 +++++++++---------- ...-windows-binary-libtorch-debug-nightly.yml | 6 ++--- ...indows-binary-libtorch-release-nightly.yml | 6 ++--- ...generated-windows-binary-wheel-nightly.yml | 24 +++++++++---------- 5 files changed, 32 insertions(+), 32 deletions(-) diff --git a/.github/templates/windows_binary_build_workflow.yml.j2 b/.github/templates/windows_binary_build_workflow.yml.j2 index 3b087fb4f6ead6..9ba9af06a2ef45 100644 --- a/.github/templates/windows_binary_build_workflow.yml.j2 +++ b/.github/templates/windows_binary_build_workflow.yml.j2 @@ -104,9 +104,9 @@ jobs: - get-label-type {%- if config["gpu_arch_type"] == "cuda" %} {%- if branches == "nightly" %} - runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.8xlarge.nvidia.gpu" + runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.g4dn.xlarge" {%- else %} - runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.8xlarge.nvidia.gpu.nonephemeral" + runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.g4dn.xlarge.nonephemeral" {%- endif %} {%- else %} runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge.nonephemeral" diff --git a/.github/workflows/generated-windows-binary-conda-nightly.yml b/.github/workflows/generated-windows-binary-conda-nightly.yml index 8cc1b90e8e8e53..85b66c8555034c 100644 --- a/.github/workflows/generated-windows-binary-conda-nightly.yml +++ b/.github/workflows/generated-windows-binary-conda-nightly.yml @@ -403,7 +403,7 @@ jobs: needs: - conda-py3_9-cuda11_8-build - get-label-type - runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.8xlarge.nvidia.gpu" + runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.g4dn.xlarge" timeout-minutes: 240 env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch @@ -651,7 +651,7 @@ jobs: needs: - conda-py3_9-cuda12_1-build - get-label-type - runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.8xlarge.nvidia.gpu" + runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.g4dn.xlarge" timeout-minutes: 240 env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch @@ -899,7 +899,7 @@ jobs: needs: - conda-py3_9-cuda12_4-build - get-label-type - runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.8xlarge.nvidia.gpu" + runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.g4dn.xlarge" timeout-minutes: 240 env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch @@ -1392,7 +1392,7 @@ jobs: needs: - conda-py3_10-cuda11_8-build - get-label-type - runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.8xlarge.nvidia.gpu" + runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.g4dn.xlarge" timeout-minutes: 240 env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch @@ -1640,7 +1640,7 @@ jobs: needs: - conda-py3_10-cuda12_1-build - get-label-type - runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.8xlarge.nvidia.gpu" + runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.g4dn.xlarge" timeout-minutes: 240 env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch @@ -1888,7 +1888,7 @@ jobs: needs: - conda-py3_10-cuda12_4-build - get-label-type - runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.8xlarge.nvidia.gpu" + runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.g4dn.xlarge" timeout-minutes: 240 env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch @@ -2381,7 +2381,7 @@ jobs: needs: - conda-py3_11-cuda11_8-build - get-label-type - runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.8xlarge.nvidia.gpu" + runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.g4dn.xlarge" timeout-minutes: 240 env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch @@ -2629,7 +2629,7 @@ jobs: needs: - conda-py3_11-cuda12_1-build - get-label-type - runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.8xlarge.nvidia.gpu" + runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.g4dn.xlarge" timeout-minutes: 240 env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch @@ -2877,7 +2877,7 @@ jobs: needs: - conda-py3_11-cuda12_4-build - get-label-type - runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.8xlarge.nvidia.gpu" + runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.g4dn.xlarge" timeout-minutes: 240 env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch @@ -3370,7 +3370,7 @@ jobs: needs: - conda-py3_12-cuda11_8-build - get-label-type - runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.8xlarge.nvidia.gpu" + runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.g4dn.xlarge" timeout-minutes: 240 env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch @@ -3618,7 +3618,7 @@ jobs: needs: - conda-py3_12-cuda12_1-build - get-label-type - runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.8xlarge.nvidia.gpu" + runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.g4dn.xlarge" timeout-minutes: 240 env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch @@ -3866,7 +3866,7 @@ jobs: needs: - conda-py3_12-cuda12_4-build - get-label-type - runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.8xlarge.nvidia.gpu" + runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.g4dn.xlarge" timeout-minutes: 240 env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch diff --git a/.github/workflows/generated-windows-binary-libtorch-debug-nightly.yml b/.github/workflows/generated-windows-binary-libtorch-debug-nightly.yml index e59a82e63cb24a..fef73a900b15ae 100644 --- a/.github/workflows/generated-windows-binary-libtorch-debug-nightly.yml +++ b/.github/workflows/generated-windows-binary-libtorch-debug-nightly.yml @@ -419,7 +419,7 @@ jobs: needs: - libtorch-cuda11_8-shared-with-deps-debug-build - get-label-type - runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.8xlarge.nvidia.gpu" + runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.g4dn.xlarge" timeout-minutes: 240 env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch @@ -679,7 +679,7 @@ jobs: needs: - libtorch-cuda12_1-shared-with-deps-debug-build - get-label-type - runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.8xlarge.nvidia.gpu" + runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.g4dn.xlarge" timeout-minutes: 240 env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch @@ -939,7 +939,7 @@ jobs: needs: - libtorch-cuda12_4-shared-with-deps-debug-build - get-label-type - runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.8xlarge.nvidia.gpu" + runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.g4dn.xlarge" timeout-minutes: 240 env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch diff --git a/.github/workflows/generated-windows-binary-libtorch-release-nightly.yml b/.github/workflows/generated-windows-binary-libtorch-release-nightly.yml index b41d421322e286..c327839e5ec1ae 100644 --- a/.github/workflows/generated-windows-binary-libtorch-release-nightly.yml +++ b/.github/workflows/generated-windows-binary-libtorch-release-nightly.yml @@ -419,7 +419,7 @@ jobs: needs: - libtorch-cuda11_8-shared-with-deps-release-build - get-label-type - runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.8xlarge.nvidia.gpu" + runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.g4dn.xlarge" timeout-minutes: 240 env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch @@ -679,7 +679,7 @@ jobs: needs: - libtorch-cuda12_1-shared-with-deps-release-build - get-label-type - runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.8xlarge.nvidia.gpu" + runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.g4dn.xlarge" timeout-minutes: 240 env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch @@ -939,7 +939,7 @@ jobs: needs: - libtorch-cuda12_4-shared-with-deps-release-build - get-label-type - runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.8xlarge.nvidia.gpu" + runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.g4dn.xlarge" timeout-minutes: 240 env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch diff --git a/.github/workflows/generated-windows-binary-wheel-nightly.yml b/.github/workflows/generated-windows-binary-wheel-nightly.yml index 287558aa22ae23..a48c6b938da647 100644 --- a/.github/workflows/generated-windows-binary-wheel-nightly.yml +++ b/.github/workflows/generated-windows-binary-wheel-nightly.yml @@ -405,7 +405,7 @@ jobs: needs: - wheel-py3_9-cuda11_8-build - get-label-type - runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.8xlarge.nvidia.gpu" + runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.g4dn.xlarge" timeout-minutes: 240 env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch @@ -654,7 +654,7 @@ jobs: needs: - wheel-py3_9-cuda12_1-build - get-label-type - runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.8xlarge.nvidia.gpu" + runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.g4dn.xlarge" timeout-minutes: 240 env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch @@ -903,7 +903,7 @@ jobs: needs: - wheel-py3_9-cuda12_4-build - get-label-type - runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.8xlarge.nvidia.gpu" + runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.g4dn.xlarge" timeout-minutes: 240 env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch @@ -1398,7 +1398,7 @@ jobs: needs: - wheel-py3_10-cuda11_8-build - get-label-type - runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.8xlarge.nvidia.gpu" + runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.g4dn.xlarge" timeout-minutes: 240 env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch @@ -1647,7 +1647,7 @@ jobs: needs: - wheel-py3_10-cuda12_1-build - get-label-type - runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.8xlarge.nvidia.gpu" + runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.g4dn.xlarge" timeout-minutes: 240 env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch @@ -1896,7 +1896,7 @@ jobs: needs: - wheel-py3_10-cuda12_4-build - get-label-type - runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.8xlarge.nvidia.gpu" + runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.g4dn.xlarge" timeout-minutes: 240 env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch @@ -2391,7 +2391,7 @@ jobs: needs: - wheel-py3_11-cuda11_8-build - get-label-type - runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.8xlarge.nvidia.gpu" + runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.g4dn.xlarge" timeout-minutes: 240 env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch @@ -2640,7 +2640,7 @@ jobs: needs: - wheel-py3_11-cuda12_1-build - get-label-type - runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.8xlarge.nvidia.gpu" + runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.g4dn.xlarge" timeout-minutes: 240 env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch @@ -2889,7 +2889,7 @@ jobs: needs: - wheel-py3_11-cuda12_4-build - get-label-type - runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.8xlarge.nvidia.gpu" + runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.g4dn.xlarge" timeout-minutes: 240 env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch @@ -3384,7 +3384,7 @@ jobs: needs: - wheel-py3_12-cuda11_8-build - get-label-type - runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.8xlarge.nvidia.gpu" + runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.g4dn.xlarge" timeout-minutes: 240 env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch @@ -3633,7 +3633,7 @@ jobs: needs: - wheel-py3_12-cuda12_1-build - get-label-type - runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.8xlarge.nvidia.gpu" + runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.g4dn.xlarge" timeout-minutes: 240 env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch @@ -3882,7 +3882,7 @@ jobs: needs: - wheel-py3_12-cuda12_4-build - get-label-type - runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.8xlarge.nvidia.gpu" + runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.g4dn.xlarge" timeout-minutes: 240 env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch From 9fef98190996c2bb999f1c894b802668235d28ce Mon Sep 17 00:00:00 2001 From: qchip Date: Fri, 30 Aug 2024 19:58:18 +0000 Subject: [PATCH 0296/1018] dynamic shapes for combo_kenel/foreach_kernel (#134477) This PR add dynamic shapes support to foreach and combo kernels for horizontal fusion. A flag `combo_kernel_foreach_dynamic_shapes` (default False to avoid disturb production workflows) is added to _inductor/config.py. Setting it to True enables automatic dynamic shapes for foreach kernels. It is always enabled for combo kernels cases. Added unit cases. This PR also fixes a flaky test case for [T198833257](https://www.internalfb.com/intern/tasks/?t=198833257) Pull Request resolved: https://github.com/pytorch/pytorch/pull/134477 Approved by: https://github.com/mlazos --- test/inductor/test_combo_kernels.py | 249 +++++++++++++++ test/inductor/test_cuda_cpp_wrapper.py | 12 + test/inductor/test_foreach.py | 37 +++ torch/_inductor/codegen/cpp_wrapper_cuda.py | 10 +- .../_inductor/codegen/triton_combo_kernel.py | 294 ++++++++++++++---- torch/_inductor/codegen/wrapper.py | 18 +- torch/_inductor/config.py | 2 + torch/_inductor/lowering.py | 4 +- torch/_inductor/runtime/triton_heuristics.py | 38 ++- torch/_inductor/scheduler.py | 6 +- 10 files changed, 602 insertions(+), 68 deletions(-) diff --git a/test/inductor/test_combo_kernels.py b/test/inductor/test_combo_kernels.py index d026f7a3ba0f45..ddf2052703818f 100644 --- a/test/inductor/test_combo_kernels.py +++ b/test/inductor/test_combo_kernels.py @@ -227,6 +227,30 @@ def test_mutated(a, b, c, d): out_eager = test_mutated(*inps) out_compiled = torch.compile(test_mutated)(*inps) + self.assertEqual(out_eager, out_compiled) + self.assertTrue(torch._inductor.metrics.generated_kernel_count in [6, 9]) + + @requires_cuda + def test_round_robin_dispatch(self): + # combo kernel dispatch strategy: round robin + def test_mutated(a, b, c, d): + a.add_(1) + b.sigmoid_() + c = torch.add(c, 5) + d.tanh_() + + return a, b, c, d + + inps = [ + torch.rand(10, 10, device="cuda"), + torch.rand(20, 5, device="cuda"), + torch.rand(10, 10, device="cuda"), + torch.rand(5, 18, device="cuda"), + ] + + out_eager = test_mutated(*inps) + out_compiled = torch.compile(test_mutated)(*inps) + self.assertEqual(out_eager, out_compiled) self.assertEqual(torch._inductor.metrics.generated_kernel_count, 6) @@ -252,6 +276,231 @@ def fn(a0, a1, a2, b0, b1, b2): self.assertTrue(7 <= torch._inductor.metrics.generated_kernel_count <= 8) + @requires_cuda + def test_persistent_reduction_no_x_dim(self): + def fn(x, y): + return x.sum(1), y.sum(1) + + inps = ( + torch.rand(16, 256, device="cuda"), + torch.rand(32, 256, device="cuda"), + ) + torch._dynamo.mark_dynamic(inps[0], 0, min=1, max=256) + torch._dynamo.mark_dynamic(inps[1], 0, min=1, max=256) + out_eager = fn(*inps) + out_compiled = torch.compile(fn)(*inps) + + self.assertEqual(out_eager, out_compiled) + self.assertEqual(torch._inductor.metrics.generated_kernel_count, 4) + + +@instantiate_parametrized_tests +class ComboKernelDynamicShapesTests(TestCase): + check_model_cuda = check_model_cuda + check_model_cpu = check_model + check_kernel_count = True + + def setUp(self): + super().setUp() + torch._inductor.metrics.reset() + torch._inductor.config.combo_kernels = True + torch._inductor.config.benchmark_combo_kernel = True + torch._dynamo.config.automatic_dynamic_shapes = False + torch._dynamo.config.assume_static_by_default = False + + def tearDown(self): + super().tearDown() + torch._inductor.metrics.reset() + + @requires_cuda + def test_dynamic_shapes_activations(self): + def test_activations(a, b, c): + a1 = torch.nn.functional.relu(a) + b1 = torch.nn.functional.sigmoid(b) + c1 = torch.nn.functional.tanh(c) + return a1, b1, c1 + + inps = [ + torch.rand(10, 10, device="cuda"), + torch.rand(20, 20, device="cuda"), + torch.rand(10, 10, device="cuda"), + ] + + out_eager = test_activations(*inps) + out_compiled = torch.compile(test_activations)(*inps) + + self.assertEqual(out_eager, out_compiled) + self.assertEqual(torch._inductor.metrics.generated_kernel_count, 5) + + @requires_cuda + def test_dynamic_shapes_2d_blocking(self): + def fn(a0, a1, a2, b0, b1, b2): + c0 = torch.add(a0, b0) + c1 = torch.add(a1, b1) + c2 = torch.add(a2, b2) + return c0, c1, c2 + + self.check_model_cuda( + fn, + ( + torch.rand(30, 20, device="cuda"), + torch.rand(40, 30, device="cuda"), + torch.rand(36, 40, device="cuda"), + torch.rand(30, 20, device="cuda"), + torch.rand(30, 40, device="cuda").t(), + torch.rand(40, 36, device="cuda").t(), + ), + ) + + self.assertTrue(7 <= torch._inductor.metrics.generated_kernel_count <= 8) + + @requires_cuda + def test_dynamic_shapes_reduce(self): + def test_reduce(a, b, c, d): + a1 = torch.sum(a, dim=0) + b1 = torch.max(b, dim=0) + c1 = torch.min(c, dim=0) + d1 = torch.nn.functional.tanh(d) + + return a1, b1, c1, d1 + + inps = [ + torch.rand(10, 10, device="cuda"), + torch.rand(20, 20, device="cuda"), + torch.rand(10, 10, device="cuda"), + torch.rand(30, 8, device="cuda"), + ] + + out_eager = test_reduce(*inps) + out_compiled = torch.compile(test_reduce)(*inps) + + self.assertEqual(out_eager, out_compiled) + self.assertTrue(4 < torch._inductor.metrics.generated_kernel_count <= 10) + + @requires_cuda + def test_dynamic_shapes_mutated(self): + # combo kernel dispatch strategy: round robin + def test_mutated(a, b, c, d): + a.add_(1) + b.sigmoid_() + c = torch.add(c, 5) + d.tanh_() + + return a, b, c, d + + inps = [ + torch.rand(10, 10, device="cuda"), + torch.rand(20, 5, device="cuda"), + torch.rand(10, 10, device="cuda"), + torch.rand(5, 18, device="cuda"), + ] + + out_eager = test_mutated(*inps) + out_compiled = torch.compile(test_mutated)(*inps) + + self.assertEqual(out_eager, out_compiled) + self.assertEqual(torch._inductor.metrics.generated_kernel_count, 6) + + @requires_cuda + @torch._inductor.config.patch("combo_kernels_autotune", 0) + def test_dynamic_shapes_activations_no_autotune(self): + def test_activations(a, b, c): + a1 = torch.nn.functional.relu(a) + b1 = torch.nn.functional.sigmoid(b) + c1 = torch.nn.functional.tanh(c) + return a1, b1, c1 + + inps = [ + torch.rand(10, 10, device="cuda"), + torch.rand(20, 20, device="cuda"), + torch.rand(10, 10, device="cuda"), + ] + + out_eager = test_activations(*inps) + out_compiled = torch.compile(test_activations)(*inps) + + self.assertEqual(out_eager, out_compiled) + self.assertEqual(torch._inductor.metrics.generated_kernel_count, 5) + + @requires_cuda + @torch._dynamo.config.patch("automatic_dynamic_shapes", True) + @torch._dynamo.config.patch("assume_static_by_default", True) + def test_dynamic_shapes_persistent_reduction_no_x_dim(self): + def fn(x, y): + return x.sum(1), y.sum(1) + + inps = ( + torch.rand(16, 256, device="cuda"), + torch.rand(32, 256, device="cuda"), + ) + torch._dynamo.mark_dynamic(inps[0], 0, min=1, max=256) + torch._dynamo.mark_dynamic(inps[1], 0, min=1, max=256) + out_eager = fn(*inps) + out_compiled = torch.compile(fn)(*inps) + + self.assertEqual(out_eager, out_compiled) + self.assertEqual(torch._inductor.metrics.generated_kernel_count, 4) + + @requires_cuda + @torch._dynamo.config.patch("automatic_dynamic_shapes", True) + @torch._dynamo.config.patch("assume_static_by_default", True) + def test_dynamic_shapes_2d_blocking_round_robin(self): + def fn(a0, a1, a2, b0, b1, b2): + c0 = torch.add(a0, b0) + c1 = torch.add(a1, b1) + c2 = torch.add(a2, b2) + return c0, c1, c2 + + inps = ( + torch.rand(20, 30, device="cuda"), + torch.rand(30, 30, device="cuda"), + torch.rand(40, 32, device="cuda"), + torch.rand(30, 20, device="cuda").t(), + torch.rand(30, 30, device="cuda").t(), + torch.rand(32, 40, device="cuda").t(), + ) + + out_eager = fn(*inps) + compiled = torch.compile(fn) + out_compiled = compiled(*inps) + self.assertEqual(out_eager, out_compiled) + self.assertTrue(5 <= torch._inductor.metrics.generated_kernel_count <= 6) + torch._inductor.metrics.reset() + + inps = ( + torch.rand(24, 30, device="cuda"), + torch.rand(32, 30, device="cuda"), + torch.rand(48, 32, device="cuda"), + torch.rand(30, 24, device="cuda").t(), + torch.rand(30, 32, device="cuda").t(), + torch.rand(32, 48, device="cuda").t(), + ) + out_compiled = compiled(*inps) + out_eager = fn(*inps) + self.assertEqual(out_eager, out_compiled) + self.assertTrue(5 <= torch._inductor.metrics.generated_kernel_count <= 6) + + @requires_cuda + @torch._dynamo.config.patch("automatic_dynamic_shapes", True) + @torch._dynamo.config.patch("assume_static_by_default", True) + @torch._inductor.config.patch("triton.autotune_at_compile_time", True) + def test_dynamic_shapes_persistent_reduction_mixed_x_dim_cuda(self): + def fn(x, y, z): + return x.sum(1), y.mean(1), z.max(1) + + inps = ( + torch.rand(16, 128, device="cuda"), + torch.rand(32, 128, device="cuda"), + torch.rand(32, 256, device="cuda"), + ) + torch._dynamo.mark_dynamic(inps[0], 0, min=1, max=256) + torch._dynamo.mark_dynamic(inps[1], 0, min=1, max=256) + torch._dynamo.mark_dynamic(inps[2], 0, min=1, max=256) + out_eager = fn(*inps) + out_compiled = torch.compile(fn)(*inps) + + self.assertEqual(out_eager, out_compiled) + if __name__ == "__main__": from torch._dynamo.test_case import run_tests diff --git a/test/inductor/test_cuda_cpp_wrapper.py b/test/inductor/test_cuda_cpp_wrapper.py index 1ccec000f1c7d2..136483d27f9ece 100644 --- a/test/inductor/test_cuda_cpp_wrapper.py +++ b/test/inductor/test_cuda_cpp_wrapper.py @@ -16,6 +16,7 @@ try: try: from . import ( + test_combo_kernels, test_foreach, test_pattern_matcher, test_select_algorithm, @@ -23,6 +24,8 @@ test_torchinductor_dynamic_shapes, ) except ImportError: + import test_combo_kernels + import test_foreach import test_pattern_matcher import test_select_algorithm @@ -191,6 +194,14 @@ class BaseTest(NamedTuple): "test_foreach_cpp_wrapper", tests=test_foreach.ForeachTests(), ), # test foreach + BaseTest( + "test_enable_dynamic_shapes_cpp_wrapper", + tests=test_foreach.ForeachTests(), + ), + BaseTest( + "test_dynamic_shapes_persistent_reduction_mixed_x_dim", + tests=test_combo_kernels.ComboKernelDynamicShapesTests(), + ), BaseTest( "test_cat_slice_cat", tests=test_pattern_matcher.TestPatternMatcher(), @@ -253,5 +264,6 @@ class BaseTest(NamedTuple): if __name__ == "__main__": from torch._inductor.test_case import run_tests + print(f"FS: run_cuda {RUN_CUDA}") if RUN_CUDA: run_tests(needs="filelock") diff --git a/test/inductor/test_foreach.py b/test/inductor/test_foreach.py index af15fdcfd38532..5d30af0f79da69 100644 --- a/test/inductor/test_foreach.py +++ b/test/inductor/test_foreach.py @@ -487,6 +487,43 @@ def fn(a0, a1, b0, b1): self.assertEqual(torch._inductor.metrics.generated_kernel_count, 2) + @requires_cuda + @torch._dynamo.config.patch("automatic_dynamic_shapes", False) + @torch._dynamo.config.patch("assume_static_by_default", False) + @torch._inductor.config.patch("combo_kernel_foreach_dynamic_shapes", True) + def test_enable_dynamic_shapes_python_wrapper(self, op=torch._foreach_add): + def fn(a0, a1, b0, b1): + return op([a0, a1], [b0, b1]) + + inputs = ( + torch.rand(10, 10, device="cuda:0"), + torch.rand(20, 20, device="cuda:0"), + torch.rand(10, 10, device="cuda:0"), + torch.rand(20, 20, device="cuda:0"), + ) + + self.check_model_cuda(fn, inputs) + + self.assertEqual(torch._inductor.metrics.generated_kernel_count, 1) + + @requires_cuda + @torch._dynamo.config.patch("automatic_dynamic_shapes", False) + @torch._dynamo.config.patch("assume_static_by_default", False) + @torch._inductor.config.patch("combo_kernel_foreach_dynamic_shapes", True) + @torch._inductor.config.patch("cpp_wrapper", True) + def test_enable_dynamic_shapes_cpp_wrapper_cuda(self, op=torch._foreach_add): + def fn(a0, a1, b0, b1): + return op([a0, a1], [b0, b1]) + + inputs = ( + torch.rand(10, 10, device="cuda:0"), + torch.rand(20, 20, device="cuda:0"), + torch.rand(10, 10, device="cuda:0"), + torch.rand(20, 20, device="cuda:0"), + ) + + self.check_model_cuda(fn, inputs) + @unittest.skipIf(IS_FBCODE, "cpp compile not supported in fbcode") @bin_ops def test_cpu_cpp_fallback(self, op): diff --git a/torch/_inductor/codegen/cpp_wrapper_cuda.py b/torch/_inductor/codegen/cpp_wrapper_cuda.py index 5aabf6bd4c1d7d..05016c387b1fa7 100644 --- a/torch/_inductor/codegen/cpp_wrapper_cuda.py +++ b/torch/_inductor/codegen/cpp_wrapper_cuda.py @@ -2,7 +2,7 @@ import functools import os from itertools import chain, count -from typing import Any, Callable, List, Optional, Tuple, TYPE_CHECKING +from typing import Any, Callable, List, Optional, Tuple, TYPE_CHECKING, Union import sympy @@ -78,10 +78,16 @@ def __init__( self.grid_callable = grid_callable self.grid_extra_kwargs = grid_extra_kwargs + def _process_grid(self, grid: Union[List[Any], Tuple[Any, ...]]): + if isinstance(grid, (list, tuple)): + return [self._process_grid(e) for e in grid] + else: + return grid.inner_expr if isinstance(grid, SymbolicCallArg) else grid + def __call__(self): grid = self.grid assert isinstance(grid, (list, tuple)), f"expected {grid=} to be a list" - grid = [e.inner_expr if isinstance(e, SymbolicCallArg) else e for e in grid] + grid = self._process_grid(grid) grid_callable = self.grid_callable or default_grid if not self.grid_extra_kwargs: grid_fn = grid_callable(*grid) diff --git a/torch/_inductor/codegen/triton_combo_kernel.py b/torch/_inductor/codegen/triton_combo_kernel.py index 4d0f5165e1c56f..abfd1dc76c698e 100644 --- a/torch/_inductor/codegen/triton_combo_kernel.py +++ b/torch/_inductor/codegen/triton_combo_kernel.py @@ -3,9 +3,20 @@ import textwrap from collections import defaultdict from dataclasses import dataclass -from typing import Any, Callable, cast, Dict, List, Optional, Tuple, Type, Union - -from sympy import Integer +from typing import ( + Any, + Callable, + cast, + Dict, + Iterable, + List, + Optional, + Tuple, + Type, + Union, +) + +from sympy import Integer, Symbol from torch.utils._ordered_set import OrderedSet @@ -16,7 +27,14 @@ from ..scheduler import BaseSchedulerNode from ..utils import Placeholder from ..virtualized import V -from .common import DeferredLine, IndentedBuffer, Kernel, PythonPrinter, SizeArg +from .common import ( + DeferredLine, + IndentedBuffer, + Kernel, + PythonPrinter, + SizeArg, + WorkspaceArg, +) from .simd import SIMDScheduling from .triton import gen_common_triton_imports, TritonKernel from .triton_utils import config_of, signature_to_meta @@ -71,7 +89,9 @@ def _default_custom_combo_kernel_horizontal_partition( not_reduction = [n for n in group_per_dim if n not in reduction] # rnumel > 2048 usually has long execution time # BaseSchedulerNode.group[-1][-1] is rnumel for reduction nodes - long_reduction = [n for n in reduction if cast(Integer, n.group[-1][-1]) > 2048] + long_reduction = [ + n for n in reduction if V.graph.sizevars.size_hint(n.group[-1][-1]) > 2048 + ] short_reduction = [n for n in reduction if n not in long_reduction] if long_reduction: log.warning( @@ -286,15 +306,18 @@ def _calculate_xblocks( ) -> None: x_numels_list = kernel.x_numels_list for i in range(len(x_numels_list)): - xnumels = ( - x_numels_list[i] - if x_numels_list[i] > 0 - else kernel.min_x_blocks_list[i] + xnumels, no_x_dim = ( + (x_numels_list[i], False) + if isinstance(x_numels_list[i], str) + and cast(str, x_numels_list[i])[0] != "-" + or ( + isinstance(x_numels_list[i], int) + and cast(int, x_numels_list[i]) > 0 + ) + else (kernel.min_x_blocks_list[i], True) ) xblock_str = ( - f"tl.cdiv({xnumels}, XBLOCK)" - if x_numels_list[i] > 0 - else f"{xnumels}" + f"tl.cdiv({xnumels}, XBLOCK)" if not no_x_dim else f"{xnumels}" ) if i == 0: code.splice(f"num_xblocks_{i} = {xblock_str}") @@ -303,19 +326,30 @@ def _calculate_xblocks( @classmethod def grid( - cls, sub_kernel_numels: List[List[int]], x_blocks_list: List[int] + cls, + sub_kernel_numels: List[List[int]], + x_blocks_list: List[Union[str, int]], + dynamic_shape: bool, ) -> Tuple[Any, ...]: - xnumel = x_blocks_list - ynumel = [e[-2] if len(e) > 1 else None for e in sub_kernel_numels] - znumel = [e[-3] if len(e) > 2 else None for e in sub_kernel_numels] + xnumel = list(x_blocks_list) + ynumel: Any = [e[-2] if len(e) > 1 else None for e in sub_kernel_numels] + znumel: Any = [e[-3] if len(e) > 2 else None for e in sub_kernel_numels] - # TODO: improve 1d/2d mixed cases - ynumel = ( - None if any(e is None for e in ynumel) else max(cast(List[int], ynumel)) - ) - znumel = ( - None if any(e is None for e in znumel) else max(cast(List[int], znumel)) - ) + if dynamic_shape: + ynumel = None if None in ynumel else ynumel + znumel = None if None in znumel else znumel + else: + # TODO: improve 1d/2d mixed cases + ynumel = ( + None + if any(e is None for e in cast(List[Any], ynumel)) + else max(cast(Iterable[int], ynumel)) + ) + znumel = ( + None + if any(e is None for e in cast(List[Any], znumel)) + else max(cast(Iterable[int], znumel)) + ) numels = ( (xnumel,) @@ -352,21 +386,38 @@ def codegen_pid_range( @classmethod def grid( - cls, sub_kernel_numels: List[List[int]], x_blocks_list: List[int] + cls, + sub_kernel_numels: List[List[int]], + x_blocks_list: List[Union[str, int]], + dynamic_shape: bool, ) -> Tuple[Any, ...]: - xnumel = [e[-1] if len(e) > 0 else None for e in sub_kernel_numels] + xnumel = x_blocks_list + # set no_x_dim xnumels to 0 + xnumel_x_dim = [max(e, 0) for e in xnumel] ynumel = [e[-2] if len(e) > 1 else None for e in sub_kernel_numels] znumel = [e[-3] if len(e) > 2 else None for e in sub_kernel_numels] # TODO: support 1d/2d mixed cases xnumel = ( - None if any(e is None for e in xnumel) else max(cast(List[int], xnumel)) + None + if any(e is None for e in xnumel) + else xnumel + if dynamic_shape + else max(xnumel_x_dim) # type: ignore[type-var, arg-type] ) ynumel = ( - None if any(e is None for e in ynumel) else max(cast(List[int], ynumel)) + None + if any(e is None for e in ynumel) + else ynumel + if dynamic_shape + else max(ynumel) # type: ignore[type-var, arg-type] ) znumel = ( - None if any(e is None for e in znumel) else max(cast(List[int], znumel)) + None + if any(e is None for e in znumel) + else znumel + if dynamic_shape + else max(znumel) # type: ignore[type-var, arg-type] ) numels = ( @@ -385,8 +436,8 @@ def __init__( self.sub_kernels: List[TritonKernel] = [] self.iter_vars_count = itertools.count() self.grids: List[List[int]] = [] - self.min_x_blocks_list: List[int] = [] - self.x_numels_list: List[int] = [] + self.min_x_blocks_list: List[Union[int, str]] = [] + self.x_numels_list: List[Union[int, str]] = [] self.enable_autotune = enable_autotune self.mixed_sizes = mixed_sizes self.dispatch_class: Optional[ @@ -401,6 +452,7 @@ def __init__( self.block_size_2d = 32 self.num_warps = 8 self.block_size_reduce = 256 + self.dynamic_shape_args: List[str] = [] def create_sub_kernel(self, triton_kernel: TritonKernel) -> TritonKernel: sub_kernel = triton_kernel @@ -419,6 +471,10 @@ def create_triton_kernel( reduction_hint: ReductionHint, optimize_mask: bool, ) -> TritonKernel: + """ + Only allow optimize_mask=True when 1) sequential dispatch is used, + 2) numels except x dimension are the same for each sub kernel. + """ return TritonKernel( *groups, index_dtype=index_dtype, @@ -451,16 +507,25 @@ def KERNEL_NAME(in_ptr0, in_ptr1, out_ptr2, xnumel, rnumel, XBLOCK : tl.constexp uniquify_block_sizes = [] for tree in sub_kernel.range_trees: simplified_tree_numel = V.graph.sizevars.simplify(tree.numel) - code.writeline(f"{tree.prefix}numel = {int(simplified_tree_numel)}") + if isinstance(simplified_tree_numel, (Integer, int)): + code.writeline(f"{tree.prefix}numel = {int(simplified_tree_numel)}") + else: + assert f"{tree.prefix}numel_{num}" in self.dynamic_shape_args + uniquify_block_sizes.append(f"{tree.prefix}numel") if tree.prefix != "r": - grid.append(int(simplified_tree_numel)) + if isinstance(simplified_tree_numel, (Integer, int)): + grid.append(int(simplified_tree_numel)) + else: + grid.append(f"{tree.prefix}numel_{num}") if tree.prefix == "r" and sub_kernel.persistent_reduction: if isinstance(simplified_tree_numel, (Integer, int)): val = int(simplified_tree_numel) else: - continue + raise RuntimeError( + "Dynamic shape on reduction dimension is not supported" + ) val = next_power_of_2(val) code.writeline(f"RBLOCK_{num}: tl.constexpr = {val}") uniquify_block_sizes.append("RBLOCK") @@ -476,16 +541,27 @@ def min_x_blocks_sub_kernel(self, sub_kernel: TritonKernel, num: int) -> None: Kernels with no_x_dim being true has no tunable XBLOCK. They have a fixed number of X blocks. Grid calculation needs to make sure that they are assigned with enough number of blocks. """ - min_x_blocks = 0 - x_numels = 0 + min_x_blocks: Union[int, str] = 0 + x_numels: Union[int, str] = 0 for tree in sub_kernel.range_trees: simplified_tree_numel = V.graph.sizevars.simplify(tree.numel) if tree.prefix == "x": + if isinstance(simplified_tree_numel, (Integer, int)): + x_numels = int(simplified_tree_numel) + else: + x_numels = f"{tree.prefix}numel_{num}" if sub_kernel.no_x_dim: - min_x_blocks = int(simplified_tree_numel) - x_numels = -min_x_blocks + min_x_blocks = x_numels + x_numels = ( + -min_x_blocks + if isinstance(x_numels, int) + else "-" + cast(str, x_numels) + ) else: - x_numels = int(simplified_tree_numel) + if isinstance(simplified_tree_numel, (Integer, int)): + x_numels = int(simplified_tree_numel) + else: + x_numels = f"{tree.prefix}numel_{num}" self.min_x_blocks_list.append(min_x_blocks) self.x_numels_list.append(x_numels) @@ -562,12 +638,15 @@ def get_mutated_args_sub_kernels(self) -> List[str]: def select_dispatch_strategy(self) -> None: if self.dispatch_class is not None: return - if not self.mixed_sizes: + # mixed_sizes is used for optimize_mask, so it only allows sequential dispatch + # Not mixed sizes on y dim technically is ok to use round robin as wells. + if not self.mixed_sizes or any(isinstance(e, str) for e in self.x_numels_list): + # str in min_x_blocks_list means a dynamic shape self.dispatch_class = ComboKernel.SequentialDispatch return # A negative x_blocks_list element means the kernel is not tunable, # i.e., no_x_dim = True - x_numels_list = [abs(e) for e in self.x_numels_list] + x_numels_list = [abs(cast(int, e)) for e in self.x_numels_list] total = max(x_numels_list) * len(x_numels_list) needed = sum(x_numels_list) if needed / total > BLOCK_UTILIZATION: @@ -582,10 +661,12 @@ def jit_line( size_hints: List[int], selected_kernel: TritonKernel, pointwise_with_reduce: bool = False, + signature: Optional[List[Any]] = None, ) -> str: can_use_32bit = all(k.index_dtype == "tl.int32" for k in self.sub_kernels) size_dtype = "tl.int32" if can_use_32bit else "tl.int64" - _, _, signature, _ = self.args.python_argdefs() + if signature is None: + _, _, signature, _ = self.args.python_argdefs() for i, sub in enumerate(self.sub_kernels): self.min_x_blocks_sub_kernel(sub, i) self.select_dispatch_strategy() @@ -622,7 +703,7 @@ def jit_line( reduction_hint={reduction_hint}, filename=__file__, triton_meta={triton_meta!r}, - inductor_meta={inductor_meta} + inductor_meta={inductor_meta!r} ) @triton.jit """ @@ -678,6 +759,66 @@ def add_blockd_to_args(self, argdefs: List[str]) -> List[str]: self.block_args = list(block_names.keys()) return argdefs + def add_numel_to_args(self, argdefs: List[str], signature: List[Any]) -> List[str]: + for num, sub_kernel in enumerate(self.sub_kernels): + for tree in sub_kernel.active_range_trees(): + if not isinstance(tree.numel, (Integer, int)): + # only if it is a dynamic shape + sizearg = SizeArg(f"{tree.prefix}numel_{num}", tree.numel) + signature.append(sizearg) + argdefs.append(f"{tree.prefix}numel_{num}") + self.dynamic_shape_args.append(f"{tree.prefix}numel_{num}") + return argdefs + + def add_numel_to_call_args_and_grid( + self, name: str, call_args: List[Any], arg_types: List[Any], grid: List[Any] + ) -> None: + for num, sub_kernel in enumerate(self.sub_kernels): + for i, tree in enumerate(sub_kernel.range_trees): + numel_name = f"{tree.prefix}numel_{num}" + if numel_name not in self.dynamic_shape_args: + continue + if isinstance(tree.numel, (Integer, Symbol)): + expr = tree.numel + else: + expr = V.graph.wrapper_code.generate_numel_expr( + name, tree, suffix=str(num) + ) + if tree.prefix != "r": + assert isinstance( + grid[i][num], str + ), f"Grid {grid[i][num]} should be a dynamic shape." + numel_sign = grid[i][num][0] if grid[i][num][0] == "-" else "" + assert ( + grid[i][num] == numel_sign + numel_name + ), f"numel args mismatch: {grid[i][num]} vs {numel_name}" + grid[i][num] = -expr if numel_sign == "-" else expr + + if tree.prefix != "r" or sub_kernel.inside_reduction: + call_args.append(expr) + arg_types.append(type(expr)) + + def add_numel_to_call_args_and_grid_benchmark( + self, extra_args: List[Any], grid: Union[List[Any], Tuple[Any, ...]] + ) -> None: + for num, sub_kernel in enumerate(self.sub_kernels): + for i, tree in enumerate(sub_kernel.range_trees): + numel_name = f"{tree.prefix}numel_{num}" + if numel_name not in self.dynamic_shape_args: + continue + expr = V.graph.sizevars.size_hint(tree.numel) + if tree.prefix != "r": + assert isinstance( + grid[i][num], str + ), f"Grid {grid[i][num]} should be a dynamic shape." + numel_sign = grid[i][num][0] if grid[i][num][0] == "-" else "" + assert ( + grid[i][num] == numel_sign + numel_name + ), f"grid mismatch: {grid[i][num]} vs {numel_name}" + grid[i][num] = -expr if numel_sign == "-" else expr + if tree.prefix != "r" or sub_kernel.inside_reduction: + extra_args.append(expr) + def codegen_kernel(self, name: Optional[str] = None) -> str: # TODO: is it correct to use the first sub kernel's heuristics? heuristics_list, size_hints_list = [], [] @@ -699,7 +840,8 @@ def codegen_kernel(self, name: Optional[str] = None) -> str: if config.benchmark_combo_kernel: code.splice(self.imports_for_benchmark_kernel()) - argdefs, _, _, _ = self.args.python_argdefs() + argdefs, _, signature, _ = self.args.python_argdefs() + argdefs = self.add_numel_to_args(argdefs, signature) argdefs = self.add_blockd_to_args(argdefs) code.splice( self.jit_line( @@ -707,6 +849,7 @@ def codegen_kernel(self, name: Optional[str] = None) -> str: size_hints, selected_kernel, pointwise_with_reduce=pointwise_with_reduction, + signature=signature, ) ) code.writeline( @@ -752,7 +895,7 @@ def codegen_kernel_benchmark( var_names = [] for arg_name, arg_sig in zip(call_args, signature): var_name = f"arg_{next(name_cnt)}" - buf = V.graph.get_buffer(arg_name) + buf = V.graph.try_get_buffer(arg_name) if buf: result.writeline( f"{var_name} = rand_strided({V.graph.sizevars.size_hints(buf.get_size())}, {V.graph.sizevars.size_hints(buf.get_stride())}, device='{buf.get_device()}', dtype={buf.get_dtype()})" # noqa: B950 line too long @@ -772,6 +915,12 @@ def codegen_kernel_benchmark( if "seed_offset" in arg_sig.name: symval_hint = 0 result.writeline(f"{var_name} = {symval_hint}") + elif isinstance(arg_sig, WorkspaceArg): + device = V.graph.scheduler.get_current_device_or_throw() + nbytes = V.graph.sizevars.size_hint(arg_sig.nbytes) + result.writeline( + f"{var_name} = torch.zeros({nbytes}, device='{device}', dtype=torch.uint8)" + ) else: raise KeyError( f"Don't find the buffer or const tensor for {arg_name}" @@ -782,15 +931,31 @@ def codegen_kernel_benchmark( result.writelines(["\n", "\n", "def call(args):"]) if grid is None: assert self.dispatch_class is not None - grid_tuple = self.dispatch_class.grid(self.grids, self.x_numels_list) + dynamic_shape = self.dynamic_shape_args != [] + grid_tuple = self.dispatch_class.grid( + self.grids, self.x_numels_list, dynamic_shape + ) + extra_args_str = "" + extra_args: List[Any] = [] + if dynamic_shape: + self.add_numel_to_call_args_and_grid_benchmark(extra_args, grid_tuple) + # convert nested list to list of str + grid_tuple = tuple( + "[" + ", ".join(pexpr(item) for item in e) + ",]" + for e in grid_tuple + ) + extra_args_str = ", ".join(map(str, extra_args)) + ", " + min_blocks = None + else: + min_blocks = max(self.min_x_blocks_list) * len(self.sub_kernels) grid_str = ", ".join(pexpr(item) for item in grid_tuple) grid_extra_kwargs = ( f"num_kernels={len(self.sub_kernels)}, " - f"min_blocks={max(self.min_x_blocks_list) * len(self.sub_kernels)}, " + f"min_blocks={min_blocks}, " f"is_sequential={self.dispatch_class is self.SequentialDispatch}" ) grid_str = f"{grid_str}, {grid_extra_kwargs}" - grid_arg = f"grid=grid_combo_kernels({grid_str})" + grid_arg = f"{extra_args_str}grid=grid_combo_kernels({grid_str})" else: grid_arg = f"grid={grid}" index = V.graph.scheduler.get_current_device_or_throw().index @@ -882,13 +1047,22 @@ def call_kernel(self, code: IndentedBuffer, name: str) -> None: wrapper = V.graph.wrapper_code assert self.dispatch_class is not None - grid = self.dispatch_class.grid(self.grids, self.x_numels_list) + dynamic_shape = self.dynamic_shape_args != [] + grid = list( + self.dispatch_class.grid(self.grids, self.x_numels_list, dynamic_shape) + ) num_kernels = len(self.sub_kernels) - min_blocks = max(self.min_x_blocks_list) * num_kernels + min_blocks = ( + max(self.min_x_blocks_list) * num_kernels if not dynamic_shape else None + ) is_sequential = self.dispatch_class is self.SequentialDispatch - if not self.enable_autotune: + if dynamic_shape: + self.add_numel_to_call_args_and_grid(name, call_args, arg_types, grid) + # convert nested list to list of str + # grid = tuple("["+", ".join(pexpr(item) for item in e)+",]" for e in grid) + if not self.enable_autotune and not dynamic_shape: launch_grid = self.grid_no_autotune( - grid, num_kernels, min_blocks, is_sequential + grid, num_kernels, cast(int, min_blocks), is_sequential ) V.graph.wrapper_code.generate_kernel_call( name, @@ -906,6 +1080,7 @@ def call_kernel(self, code: IndentedBuffer, name: str) -> None: num_kernels=num_kernels, min_blocks=min_blocks, is_sequential=is_sequential, + default_meta=None if self.enable_autotune else self.get_default_meta(), ) wrapper.generate_kernel_call( name, @@ -919,21 +1094,19 @@ def call_kernel(self, code: IndentedBuffer, name: str) -> None: grid_extra_kwargs=( f"num_kernels={num_kernels}, " f"min_blocks={min_blocks}, " - f"is_sequential={is_sequential}" + f"is_sequential={is_sequential}, " + f"default_meta={None if self.enable_autotune else self.get_default_meta()}" ), ) def grid_no_autotune( self, - grid: Tuple[Any], + grid: Union[Tuple[Any], List[Any]], num_kernels: int, min_blocks: int, is_sequential: bool, ) -> List[int]: - if "YBLOCK" in self.block_args: - meta = {"XBLOCK": self.block_size_2d, "YBLOCK": self.block_size_2d} - else: - meta = {"XBLOCK": self.block_size_1d} + meta = self.get_default_meta() grid_func = grid_combo_kernels( *grid, num_kernels=num_kernels, @@ -941,3 +1114,10 @@ def grid_no_autotune( is_sequential=is_sequential, ) return grid_func(meta) + + def get_default_meta(self) -> Dict[str, int]: + if "YBLOCK" in self.block_args: + meta = {"XBLOCK": self.block_size_2d, "YBLOCK": self.block_size_2d} + else: + meta = {"XBLOCK": self.block_size_1d} + return meta diff --git a/torch/_inductor/codegen/wrapper.py b/torch/_inductor/codegen/wrapper.py index 9ecc942e5d4031..30b073ca5f7015 100644 --- a/torch/_inductor/codegen/wrapper.py +++ b/torch/_inductor/codegen/wrapper.py @@ -1467,8 +1467,10 @@ def traverse(cur_kernel): ) return name, triton_meta - def generate_numel_expr(self, kernel_name: str, tree): + def generate_numel_expr(self, kernel_name: str, tree, suffix: Optional[str] = None): expr = f"{kernel_name}_{tree.prefix}numel" + if suffix is not None: + expr += f"_{suffix}" if (expr, V.graph) not in self.kernel_numel_expr: # declare expr once in each graph (scope) self.kernel_numel_expr.add((expr, V.graph)) @@ -1622,10 +1624,20 @@ def generate_example_arg_value(self, arg, arg_type, raw_arg=None, index=None): ) elif isinstance(arg, (str, int, float, bool)): return str(arg) + elif isinstance(arg, list): + return f"[{', '.join(self.generate_example_arg_value(a, type(a)) for a in arg)}]" else: breakpoint() raise NotImplementedError(f"Unsupported type {type(arg)}") + def _grid_dim_str(self, grid_per_dim): + if isinstance(grid_per_dim, list): + return ( + "[" + ", ".join(self._grid_dim_str(item) for item in grid_per_dim) + "]" + ) + else: + return pexpr(grid_per_dim) + def generate_kernel_call( self, kernel_name, @@ -1661,7 +1673,7 @@ def generate_kernel_call( if grid is None: grid_str = grid_fn else: - grid_str = ", ".join(pexpr(item) for item in grid) + grid_str = ", ".join(self._grid_dim_str(item) for item in grid) if grid_extra_kwargs: grid_str = f"{grid_str}, {grid_extra_kwargs}" grid_str = f"{grid_fn}({grid_str})" @@ -1715,6 +1727,8 @@ def generate_kernel_call( grid_str = ", ".join( self.generate_example_arg_value(g, type(g)) for g in grid ) + if grid_extra_kwargs: + grid_str = f"{grid_str}, {grid_extra_kwargs}" grid_str = f"{grid_fn}({grid_str})" self.kernel_autotune_calls.writeline( diff --git a/torch/_inductor/config.py b/torch/_inductor/config.py index ddf156df300a40..f6599989f31a4b 100644 --- a/torch/_inductor/config.py +++ b/torch/_inductor/config.py @@ -468,6 +468,8 @@ def use_autoheuristic(name: str) -> bool: # Enable masking for combining kernels of mixed sizes: 0 - disable, 1 - enable # for all except for foreach, 2 - enable for all combo_kernel_allow_mixed_sizes = 1 +# Enable dynamic shapes for foreach kernels +combo_kernel_foreach_dynamic_shapes = False # constant folding on the joint graph joint_graph_constant_folding = True diff --git a/torch/_inductor/lowering.py b/torch/_inductor/lowering.py index 7654792cdd299f..437a5abac3f339 100644 --- a/torch/_inductor/lowering.py +++ b/torch/_inductor/lowering.py @@ -533,7 +533,9 @@ def inner(*inputs: List[List[TensorBox]], alpha=1): def group_args(arg_pairs): out = defaultdict(list) for i, args in enumerate(arg_pairs): - use_foreach = not is_dynamic(*args) + use_foreach = ( + not is_dynamic(*args) or config.combo_kernel_foreach_dynamic_shapes + ) device = None for t in args: if isinstance(t, TensorBox): diff --git a/torch/_inductor/runtime/triton_heuristics.py b/torch/_inductor/runtime/triton_heuristics.py index 81887a9fda88a6..8421dba05b98d5 100644 --- a/torch/_inductor/runtime/triton_heuristics.py +++ b/torch/_inductor/runtime/triton_heuristics.py @@ -1869,16 +1869,38 @@ def grid_fn(meta): return grid_fn -def grid_combo_kernels(*numels, num_kernels, min_blocks, is_sequential): +def grid_combo_kernels( + *numels, num_kernels, min_blocks, is_sequential, default_meta=None +): """min_blocks is the minimal size of the grid x dimension""" if not is_sequential: # round robin dispatch - kernel_grid_fn = grid(*numels) + numels_agg = list(numels) + for i in range(len(numels_agg)): + if isinstance(numels_agg[i], (list, tuple)): + numels_agg[i] = max(max(numels_agg[i]), 0) # noqa: PLW3301 + kernel_grid_fn = grid(*numels_agg) + + if isinstance(numels[-1], (list, tuple)): + min_blocks_d = max(-min(numels[-1]), 0) * num_kernels + else: + min_blocks_d = None + if min_blocks is None: + assert min_blocks_d is not None + min_blocks = min_blocks_d + else: + assert ( + min_blocks_d is None or min_blocks == min_blocks_d + ), f"inconsistent min_blocks {min_blocks} vs x grid {numels[-1]}" else: # sequential dispatch seq_numels = list(numels) # x numels are not used here, just a place holder seq_numels[-1] = 1024 + for i in range(len(seq_numels) - 1): + if isinstance(seq_numels[i], (list, tuple)): + seq_numels[i] = max(seq_numels[i]) + kernel_grid_fn = grid(*seq_numels) def get_grid_dim(numel, block): @@ -1889,6 +1911,7 @@ def get_grid_dim(numel, block): return ceildiv(numel, block) def grid_fn(meta): + assert min_blocks is not None, "min_blocks must be a number" cuda_grid = list(kernel_grid_fn(meta)) cuda_grid[0] = max(num_kernels * cuda_grid[0], min_blocks) return tuple(cuda_grid) @@ -1905,4 +1928,13 @@ def seq_grid_fn(meta): cuda_grid[0] = x_grid return tuple(cuda_grid) - return grid_fn if not is_sequential else seq_grid_fn + def grid_fn_default_meta(meta): + return grid_fn(default_meta) + + def seq_grid_fn_default_meta(meta): + return seq_grid_fn(default_meta) + + if default_meta is None: + return grid_fn if not is_sequential else seq_grid_fn + else: + return grid_fn_default_meta if not is_sequential else seq_grid_fn_default_meta diff --git a/torch/_inductor/scheduler.py b/torch/_inductor/scheduler.py index 7b2450f37366bc..efe93253de6800 100644 --- a/torch/_inductor/scheduler.py +++ b/torch/_inductor/scheduler.py @@ -3523,7 +3523,7 @@ def speedup_by_combo_kernel(self, nodes: List[BaseSchedulerNode]) -> bool: raise # small kernels are very likely to have speedup but hard to benchmark. So we skip benchmarking. - small_kernel = ms2_clone / ms2 > 0.6 and ms2 - ms2_clone < 0.2 + small_kernel = ms2 - ms2_clone < 0.3 or ms1 < 0.3 if fusion_log.isEnabledFor(logging.DEBUG): if ms1 > ms2 or small_kernel: fusion_log.debug( @@ -3535,8 +3535,8 @@ def speedup_by_combo_kernel(self, nodes: List[BaseSchedulerNode]) -> bool: "cannot fuse (benchmark): fusing causes %sx slowdown", red_text(f"{ms1 / ms2:.3f}"), ) - - return ms2 < ms1 or small_kernel + # ms1 returned by benchmark_fused_nodes discounted clone time + return ms2 - ms2_clone < ms1 or small_kernel def get_buffer_layout(self, buf_name: str) -> ir.Layout: buf = self.name_to_buf[buf_name] From c568fc2b4f3a4ca3e6f3b8cbc5fad3d61265489b Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Fri, 30 Aug 2024 20:00:41 +0000 Subject: [PATCH 0297/1018] Revert "[ROCm] remove triton-rocm commit pin and merge pins with triton.txt (#133438)" This reverts commit 5e8bf29148a590318f678620f84be8f4d5ffff5c. Reverted https://github.com/pytorch/pytorch/pull/133438 on behalf of https://github.com/ZainRizvi due to This still breaks linux binary builds. Added the appropriate labels to ensure tests can pass. See [GH job link](https://github.com/pytorch/pytorch/actions/runs/10626427003/job/29460479554) [HUD commit link](https://hud.pytorch.org/pytorch/pytorch/commit/5e8bf29148a590318f678620f84be8f4d5ffff5c) ([comment](https://github.com/pytorch/pytorch/pull/133438#issuecomment-2322246198)) --- .ci/docker/centos-rocm/Dockerfile | 4 ++-- .ci/docker/ci_commit_pins/triton-rocm.txt | 1 + .ci/docker/ci_commit_pins/triton.txt | 2 +- .ci/docker/common/install_triton.sh | 5 ++++- .ci/docker/ubuntu-rocm/Dockerfile | 4 ++-- .circleci/scripts/binary_populate_env.sh | 2 +- .github/scripts/build_triton_wheel.py | 26 ++++++++++++++++++++++- .github/workflows/build-triton-wheel.yml | 2 ++ CODEOWNERS | 1 + 9 files changed, 39 insertions(+), 8 deletions(-) create mode 100644 .ci/docker/ci_commit_pins/triton-rocm.txt diff --git a/.ci/docker/centos-rocm/Dockerfile b/.ci/docker/centos-rocm/Dockerfile index 30ce1406e3f81c..bfac9ddd859084 100644 --- a/.ci/docker/centos-rocm/Dockerfile +++ b/.ci/docker/centos-rocm/Dockerfile @@ -108,10 +108,10 @@ ENV CMAKE_C_COMPILER cc ENV CMAKE_CXX_COMPILER c++ COPY ./common/install_triton.sh install_triton.sh COPY ./common/common_utils.sh common_utils.sh -COPY ci_commit_pins/triton.txt triton.txt +COPY ci_commit_pins/triton-rocm.txt triton-rocm.txt COPY triton_version.txt triton_version.txt RUN if [ -n "${TRITON}" ]; then bash ./install_triton.sh; fi -RUN rm install_triton.sh common_utils.sh triton.txt triton_version.txt +RUN rm install_triton.sh common_utils.sh triton-rocm.txt triton_version.txt # Install AOTriton (Early fail) COPY ./aotriton_version.txt aotriton_version.txt diff --git a/.ci/docker/ci_commit_pins/triton-rocm.txt b/.ci/docker/ci_commit_pins/triton-rocm.txt new file mode 100644 index 00000000000000..0cb336acccb5b1 --- /dev/null +++ b/.ci/docker/ci_commit_pins/triton-rocm.txt @@ -0,0 +1 @@ +21eae954efa5bf584da70324b640288c3ee7aede diff --git a/.ci/docker/ci_commit_pins/triton.txt b/.ci/docker/ci_commit_pins/triton.txt index 378d625784fbc5..41c8d1602b6bef 100644 --- a/.ci/docker/ci_commit_pins/triton.txt +++ b/.ci/docker/ci_commit_pins/triton.txt @@ -1 +1 @@ -757b6a61e7df814ba806f498f8bb3160f84b120c +dedb7bdf339a3546896d4820366ca562c586bfa0 diff --git a/.ci/docker/common/install_triton.sh b/.ci/docker/common/install_triton.sh index 6e5fb8839c686e..d4a0dc80ee177e 100755 --- a/.ci/docker/common/install_triton.sh +++ b/.ci/docker/common/install_triton.sh @@ -12,7 +12,10 @@ conda_reinstall() { as_jenkins conda install -q -n py_$ANACONDA_PYTHON_VERSION -y --force-reinstall $* } -if [ -n "${XPU_VERSION}" ]; then +if [ -n "${ROCM_VERSION}" ]; then + TRITON_REPO="https://github.com/openai/triton" + TRITON_TEXT_FILE="triton-rocm" +elif [ -n "${XPU_VERSION}" ]; then TRITON_REPO="https://github.com/intel/intel-xpu-backend-for-triton" TRITON_TEXT_FILE="triton-xpu" else diff --git a/.ci/docker/ubuntu-rocm/Dockerfile b/.ci/docker/ubuntu-rocm/Dockerfile index 07e25f533a71f9..ee9ede8ba611b6 100644 --- a/.ci/docker/ubuntu-rocm/Dockerfile +++ b/.ci/docker/ubuntu-rocm/Dockerfile @@ -100,10 +100,10 @@ ARG TRITON # try to reach out to S3, which docker build runners don't have access COPY ./common/install_triton.sh install_triton.sh COPY ./common/common_utils.sh common_utils.sh -COPY ci_commit_pins/triton.txt triton.txt +COPY ci_commit_pins/triton-rocm.txt triton-rocm.txt COPY triton_version.txt triton_version.txt RUN if [ -n "${TRITON}" ]; then bash ./install_triton.sh; fi -RUN rm install_triton.sh common_utils.sh triton.txt triton_version.txt +RUN rm install_triton.sh common_utils.sh triton-rocm.txt triton_version.txt # Install AOTriton COPY ./aotriton_version.txt aotriton_version.txt diff --git a/.circleci/scripts/binary_populate_env.sh b/.circleci/scripts/binary_populate_env.sh index 106d0917ca68c5..e918635922a45e 100755 --- a/.circleci/scripts/binary_populate_env.sh +++ b/.circleci/scripts/binary_populate_env.sh @@ -90,7 +90,7 @@ fi if [[ "$PACKAGE_TYPE" =~ .*wheel.* && -n "$PYTORCH_BUILD_VERSION" && "$PYTORCH_BUILD_VERSION" =~ .*rocm.* && $(uname) == "Linux" ]]; then TRITON_REQUIREMENT="pytorch-triton-rocm==${TRITON_VERSION}; ${TRITON_CONSTRAINT}" if [[ -n "$PYTORCH_BUILD_VERSION" && "$PYTORCH_BUILD_VERSION" =~ .*dev.* ]]; then - TRITON_SHORTHASH=$(cut -c1-10 $PYTORCH_ROOT/.ci/docker/ci_commit_pins/triton.txt) + TRITON_SHORTHASH=$(cut -c1-10 $PYTORCH_ROOT/.ci/docker/ci_commit_pins/triton-rocm.txt) TRITON_REQUIREMENT="pytorch-triton-rocm==${TRITON_VERSION}+${TRITON_SHORTHASH}; ${TRITON_CONSTRAINT}" fi if [[ -z "${PYTORCH_EXTRA_INSTALL_REQUIREMENTS:-}" ]]; then diff --git a/.github/scripts/build_triton_wheel.py b/.github/scripts/build_triton_wheel.py index 7e70c49ff8b1bb..7ee2cb4b8e61b5 100644 --- a/.github/scripts/build_triton_wheel.py +++ b/.github/scripts/build_triton_wheel.py @@ -15,7 +15,9 @@ def read_triton_pin(device: str = "cuda") -> str: triton_file = "triton.txt" - if device == "xpu": + if device == "rocm": + triton_file = "triton-rocm.txt" + elif device == "xpu": triton_file = "triton-xpu.txt" with open(REPO_DIR / ".ci" / "docker" / "ci_commit_pins" / triton_file) as f: return f.read().strip() @@ -48,6 +50,25 @@ def patch_init_py( f.write(orig) +# TODO: remove patch_setup_py() once we have a proper fix for https://github.com/triton-lang/triton/issues/4527 +def patch_setup_py(path: Path) -> None: + with open(path) as f: + orig = f.read() + try: + orig = check_and_replace( + orig, + "https://tritonlang.blob.core.windows.net/llvm-builds/", + "https://oaitriton.blob.core.windows.net/public/llvm-builds/", + ) + with open(path, "w") as f: + f.write(orig) + except RuntimeError as e: + print( + f"Applying patch_setup_py() for llvm-build package failed: {e}.", + "If you are trying to build a newer version of Triton, you can ignore this.", + ) + + def build_triton( *, version: str, @@ -89,6 +110,9 @@ def build_triton( else: check_call(["git", "checkout", commit_hash], cwd=triton_basedir) + # TODO: remove this and patch_setup_py() once we have a proper fix for https://github.com/triton-lang/triton/issues/4527 + patch_setup_py(triton_pythondir / "setup.py") + if build_conda: with open(triton_basedir / "meta.yaml", "w") as meta: print( diff --git a/.github/workflows/build-triton-wheel.yml b/.github/workflows/build-triton-wheel.yml index 1aa777c1167c4f..01f06cdd286cf4 100644 --- a/.github/workflows/build-triton-wheel.yml +++ b/.github/workflows/build-triton-wheel.yml @@ -13,6 +13,7 @@ on: - .github/scripts/build_triton_wheel.py - .github/ci_commit_pins/triton.txt - .ci/docker/ci_commit_pins/triton.txt + - .ci/docker/ci_commit_pins/triton-rocm.txt - .ci/docker/ci_commit_pins/triton-xpu.txt pull_request: paths: @@ -20,6 +21,7 @@ on: - .github/scripts/build_triton_wheel.py - .github/ci_commit_pins/triton.txt - .ci/docker/ci_commit_pins/triton.txt + - .ci/docker/ci_commit_pins/triton-rocm.txt - .ci/docker/ci_commit_pins/triton-xpu.txt concurrency: diff --git a/CODEOWNERS b/CODEOWNERS index bafce8f6f53521..7b9db26104a9d5 100644 --- a/CODEOWNERS +++ b/CODEOWNERS @@ -57,6 +57,7 @@ nn/qat/ @jerryzh168 # Docker /.ci/docker/ @jeffdaily /.ci/docker/ci_commit_pins/triton.txt @desertfire @Chillee @eellison @shunting314 @bertmaher @jeffdaily @jataylo @jithunnair-amd @pruthvistony +/.ci/docker/ci_commit_pins/triton-rocm.txt @jeffdaily @jataylo @jithunnair-amd @pruthvistony /.ci/docker/ci_commit_pins/triton-xpu.txt @EikanWang @gujinghui # Github Actions From 494b82ca02565511cfc06945a9a928ae17b8acee Mon Sep 17 00:00:00 2001 From: Blaine Burton Rister <145300525+blaine-rister@users.noreply.github.com> Date: Fri, 30 Aug 2024 20:34:09 +0000 Subject: [PATCH 0298/1018] [Inductor] Allow customizing the padding format (#133939) Based on https://github.com/pytorch/pytorch/pull/130956. Inductor already supports padding through the `config.comprehensive_padding` option, but the padding format involves a few heuristics that are specific to Nvidia GPUs: - When we pad, it is always aligned to the next multiple of 128 bytes. - Strides smaller than 1024 are not padded. - Only intermediate values are padded, not outputs. The last of these is not really GPU-specific, but there are certain cases where we may want to override it. For example, padding outputs is useful on hardware accelerators with specific memory alignment requirements, or for applications where performance is more important than conformity with eager mode. This PR surfaces padding parameters up to Inductor's config module, so the user can control them. - `config.pad_outputs`: choose whether to pad outputs (default: `False`) - `config.padding_alignment_bytes`: choose the alignment size for padding (default: `128`) - `config.padding_stride_threshold`: choose the smallest stride that we will pad. For example, setting this to 0 will pad all unaligned strides. (default: `1024`) **Test plan** Added a new test in `test_padding.py` which tries various combinations of these options, checking that the output strides match our expectations. These changes should not affect perf, because the defaults are identical to Inductor's current behavior. Pull Request resolved: https://github.com/pytorch/pytorch/pull/133939 Approved by: https://github.com/shunting314 Co-authored-by: Yueming Hao --- test/inductor/test_padding.py | 58 +++++++++++++++++++++++++++++++++-- torch/_inductor/config.py | 28 +++++++++++++++++ torch/_inductor/graph.py | 11 ++++--- torch/_inductor/ir.py | 26 ++-------------- 4 files changed, 93 insertions(+), 30 deletions(-) diff --git a/test/inductor/test_padding.py b/test/inductor/test_padding.py index c8170c92327166..9ae3dd3a125df8 100644 --- a/test/inductor/test_padding.py +++ b/test/inductor/test_padding.py @@ -3,6 +3,7 @@ import functools import os import unittest +from typing import Tuple import torch from torch import nn, Tensor @@ -12,8 +13,13 @@ from torch._inductor import config, ir, metrics from torch._inductor.fx_passes import pad_mm as pad_mm_pass from torch._inductor.runtime.benchmarking import benchmarker -from torch._inductor.utils import run_and_get_code -from torch.testing._internal.common_utils import requires_cuda, serialTest +from torch._inductor.utils import ceildiv, run_and_get_code +from torch.testing._internal.common_utils import ( + instantiate_parametrized_tests, + parametrize, + requires_cuda, + serialTest, +) from torch.testing._internal.inductor_utils import HAS_CUDA @@ -362,6 +368,7 @@ def test_longformer_small_bs(self): self.test_longformer(bs=2) +@instantiate_parametrized_tests class PaddingTest(TestCaseBase): @unittest.skipIf(not DO_PERF_TEST, "Perf test not enabled") def test_mm_padding_perf(self): @@ -654,6 +661,53 @@ def test_pad_channels_last(self): out_strides = ir.Layout._pad_strides(in_strides, t.shape, torch.float32) self.assertTrue(in_strides == out_strides) + @parametrize("alignment_bytes", (32, 128)) + @parametrize("shape", [(21, 19), (3, 5, 71)]) + @parametrize("dtype", (torch.float16, torch.float32)) + def test_pad_outputs( + self, dtype: torch.dtype, shape: Tuple[int], alignment_bytes: int + ): + """ + Tests padding output tensors to a specific alignment. + This is enabled by a config flag. + """ + func = torch.add + inputs = tuple(torch.randn(*shape, dtype=dtype) for input_idx in range(2)) + + # Compile and run + with config.patch( + { + "comprehensive_padding": True, + "padding_alignment_bytes": alignment_bytes, + "padding_stride_threshold": 0, + "pad_outputs": True, + } + ): + compiled_func = torch.compile(func) + compiled_out = compiled_func(*inputs) + + # Check numerics + eager_out = func(*inputs) + self.check_close(eager_out, compiled_out) + + # Compute the expected padding + element_size = torch.tensor([], dtype=dtype).element_size() + self.assertGreater(alignment_bytes, element_size) + self.assertEqual(alignment_bytes % element_size, 0) + alignment_elements = alignment_bytes // element_size + contiguous_stride = inputs[0].stride() + expected_stride = [1] + for dim in reversed(shape[1:]): + slice_size = dim * expected_stride[0] + new_stride = alignment_elements * ceildiv(slice_size, alignment_elements) + expected_stride.insert(0, new_stride) + expected_stride = tuple(expected_stride) + self.assertNotEqual(expected_stride, contiguous_stride) + + # Check strides + self.assertFalse(compiled_out.is_contiguous()) + self.assertEqual(compiled_out.stride(), expected_stride) + if __name__ == "__main__": if HAS_CUDA: diff --git a/torch/_inductor/config.py b/torch/_inductor/config.py index f6599989f31a4b..919cddb773d7aa 100644 --- a/torch/_inductor/config.py +++ b/torch/_inductor/config.py @@ -595,6 +595,34 @@ def decide_compile_threads() -> int: ) pad_channels_last = False +# The width of comprehensive padding, in bytes. +# CUDA max memory transaction size is 128 bytes for a warp. +padding_alignment_bytes = 128 + +# Threshold on the minimum stride that will be padded. +# +# Don't align a too small stride since that causes too much memory increase. +# Pad too small stride may also cause perf loss. We may result in many tiny data blocks +# with gaps in between. That causes less coalesced GPU memory access! +# +# Initially we pick 320 as the threshold since for alignement=16, +# that results in at most 5% memory cost. +# +# But later on we raise the threshold to 1024 to avoid interfere with persistent reduction. +# Let's say an inner reduction has a row size 513. Inductor will generate +# persistent reduction code. +# If we do padding, the strides are not contiguous any more. Inductor +# uses a much smaller threshold for persistent reduction in this case and +# generates potentially worse non-persistent reduction code. +# +# This change turns HF AllenaiLongformerBase amp training from a loss of 1.09x to a win of 1.05x. +# (baseline: 71.09ms, padding w/o this change: 77.38ms, padding with this change: 67.77ms) +padding_stride_threshold = 1024 + +# Enable padding outputs, even if they would not be padded in eager mode. +# By default, we use the same strides as eager mode. +pad_outputs = False + # Whether to treat output of the backward graph as user visible. # For user visible outputs, inductor will make sure the stride matches with eager. bw_outputs_user_visible = True diff --git a/torch/_inductor/graph.py b/torch/_inductor/graph.py index 987e639657cb0c..c9da83e7584e62 100644 --- a/torch/_inductor/graph.py +++ b/torch/_inductor/graph.py @@ -247,7 +247,11 @@ def _get_overload_packet( if prior_op not in ops_like_padding: prior.meta["dislike_padding"] = True # We only want to mark output nodes. So, move it after the above prior nodes process. - if user_visible_outputs and cur.name in user_visible_outputs: + if ( + not config.pad_outputs + and user_visible_outputs + and cur.name in user_visible_outputs + ): cur.meta["dislike_padding"] = True @@ -1359,9 +1363,8 @@ def debug(msg: str) -> None: strides = n.meta["val"].stride() if len(strides): allow_padding = ( - n.name not in self.user_visible_outputs - and not is_input_for_as_strided - ) + config.pad_outputs or n.name not in self.user_visible_outputs + ) and not is_input_for_as_strided dense = torch._prims_common.is_non_overlapping_and_dense( n.meta["val"] ) diff --git a/torch/_inductor/ir.py b/torch/_inductor/ir.py index 55d43674807e0d..18cdac11ac8077 100644 --- a/torch/_inductor/ir.py +++ b/torch/_inductor/ir.py @@ -2872,12 +2872,7 @@ def is_contiguous_strides_for_shape( def get_align_for_dtype(dtype: torch.dtype) -> int: - """ - CUDA max memory transaction size is 128 bytes for a warp. - We pick `128 // dtype.itemsize` as alighment so GPU can do coalesced - memory access. - """ - return 128 // dtype.itemsize + return config.padding_alignment_bytes // dtype.itemsize @dataclasses.dataclass @@ -3016,29 +3011,12 @@ def _pad_strides(in_strides, size, dtype): # smallest stride to be 1. new_strides[fill_order[0]] = 1 - # Don't align a too small stride since that causes too much memory increase. - # Pad too small stride may also cause perf loss. We may result in many tiny data blocks - # with gaps in between. That causes less coalesced GPU memory access! - # - # Initially we pick 320 as the threshold since for alignement=16, - # that results in at most 5% memory cost. - # - # But later on we raise the threshold to 1024 to avoid interfere with persistent reduction. - # Let's say an inner reduction has a row size 513. Inductor will generate - # persistent reduction code. - # If we do padding, the strides are not contiguous any more. Inductor - # uses a much smaller threshold for persistent reduction in this case and - # generates potentially worse non-persistent reduction code. - # - # This change turns HF AllenaiLongformerBase amp training from a loss of 1.09x to a win of 1.05x. - # (baseline: 71.09ms, padding w/o this change: 77.38ms, padding with this change: 67.77ms) - align_stride_threshold = 1024 padded = False for rank, idx in enumerate(fill_order[1:], start=1): prev_idx = fill_order[rank - 1] stride = new_strides[prev_idx] * size[prev_idx] - if stride > align_stride_threshold and stride % align != 0: + if stride > config.padding_stride_threshold and stride % align != 0: stride = ceildiv(stride, align) * align padded = True new_strides[idx] = stride From 9cc6a36cf42771cc0788b71717ed04b54ceac77f Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Fri, 30 Aug 2024 11:35:14 -0700 Subject: [PATCH 0299/1018] [Dynamo][autograd.Function] Trace fwd graph under no_grad mode (#134872) Fixes #134820 Pull Request resolved: https://github.com/pytorch/pytorch/pull/134872 Approved by: https://github.com/zou3519 --- test/dynamo/test_autograd_function.py | 31 +++++++++++++++++++++ torch/_dynamo/variables/higher_order_ops.py | 1 + 2 files changed, 32 insertions(+) diff --git a/test/dynamo/test_autograd_function.py b/test/dynamo/test_autograd_function.py index e9b913ed9559db..629e21e0daf948 100644 --- a/test/dynamo/test_autograd_function.py +++ b/test/dynamo/test_autograd_function.py @@ -439,6 +439,29 @@ def f(x): self.assertEqual(result, Foo.apply(x)) self.assertEqual(cnt.frame_count, 1) + def test_fwd_no_grad(self): + # autograd.Function.forward should be traced and called under no_grad mode. + # torch.exp with out=... arguments don't support automatic differentiation, + # so can't be traced/called under grad mode (throwing RuntimeError), + # therefore this unit test ensures fwd is under no_grad mode. + class Foo(torch.autograd.Function): + @staticmethod + def forward(ctx, inputs): + torch.exp(inputs, out=inputs) + return inputs + + @staticmethod + def backward(ctx, grad_output): + return None + + @torch.compile(backend="eager", fullgraph=True) + def f(x): + return Foo.apply(x) + + x1 = torch.randn(2, 3, requires_grad=True) + x2 = x1.clone() + self.assertEqual(f(x1), Foo.apply(x2)) + def test_amp_custom_fwd_bwd(self): torch._dynamo.utils.counters.clear() cnt = torch._dynamo.testing.CompileCounter() @@ -547,9 +570,13 @@ def forward(self, L_x_: "f32[]", L_z_: "f32[]", L_weird_b: "f32[]", L_weird_c: " class fwd_body_0(torch.nn.Module): def forward(self, ctx, x: "f32[]", z: "f32[]", l_weird_b: "f32[]", l_weird_c: "f32[]"): + _set_grad_enabled = torch._C._set_grad_enabled(False); _set_grad_enabled = None + mul: "f32[]" = l_weird_b * l_weird_c clone: "f32[]" = x.clone(); x = None mul_1: "f32[]" = mul * clone; mul = clone = None + + _set_grad_enabled_1 = torch._C._set_grad_enabled(True); _set_grad_enabled_1 = None return (mul_1, [l_weird_b, l_weird_c]) class bwd_body_0(torch.nn.Module): @@ -1113,9 +1140,13 @@ def forward(self, L_x_: "f32[]", L_y_: "f32[]"): class fwd_body_0(torch.nn.Module): def forward(self, ctx, x: "f32[]", y: "f32[]"): + _set_grad_enabled = torch._C._set_grad_enabled(False); _set_grad_enabled = None + out1: "f32[]" = x.sin(); x = None out2: "f32[]" = y * 2; y = None + + _set_grad_enabled_1 = torch._C._set_grad_enabled(True); _set_grad_enabled_1 = None return ((out1, out2), []) class bwd_body_0(torch.nn.Module): diff --git a/torch/_dynamo/variables/higher_order_ops.py b/torch/_dynamo/variables/higher_order_ops.py index 92579cf4fa16cd..35d78fd304dae8 100644 --- a/torch/_dynamo/variables/higher_order_ops.py +++ b/torch/_dynamo/variables/higher_order_ops.py @@ -1996,6 +1996,7 @@ def bwd(ctx, grad, x): fwd_args, kwargs, "autograd.Function", + enable_grad=False, set_subgraph_inputs="semi_automatic", restore_side_effects=False, tracer=fwd_tracer, From 8b85494069e24eb3666f75a211792180ee403f83 Mon Sep 17 00:00:00 2001 From: Ratnam Parikh <114774508+ratnampa@users.noreply.github.com> Date: Fri, 30 Aug 2024 23:51:37 +0000 Subject: [PATCH 0300/1018] [XPU] Fix Windows XPU build (#134276) Linker flag check doesn't work correctly with MSVC and linking torch_xpu with torch_cpu_library for windows MSVC works without any errors Pull Request resolved: https://github.com/pytorch/pytorch/pull/134276 Approved by: https://github.com/EikanWang, https://github.com/atalman --- caffe2/CMakeLists.txt | 24 ++++++++++++++---------- 1 file changed, 14 insertions(+), 10 deletions(-) diff --git a/caffe2/CMakeLists.txt b/caffe2/CMakeLists.txt index d44a8da210462f..63c2fef1e195ef 100644 --- a/caffe2/CMakeLists.txt +++ b/caffe2/CMakeLists.txt @@ -1563,20 +1563,24 @@ if(USE_XPU) target_link_libraries( torch_xpu PRIVATE ${Caffe2_XPU_DEPENDENCY_LIBS}) - include(CheckLinkerFlag) - - # Check whether the compiler supports '--no-as-needed' and '--as-needed' - check_linker_flag(CXX "-Wl,--no-as-needed" HAVE_NO_AS_NEEDED) - check_linker_flag(CXX "-Wl,--as-needed" HAVE_AS_NEEDED) - # Ensure that torch_cpu is ready before being linked by torch_xpu. add_dependencies(torch_xpu torch_cpu) - if(HAVE_NO_AS_NEEDED AND HAVE_AS_NEEDED) - target_link_libraries(torch_xpu PRIVATE - "-Wl,--no-as-needed,\"$\" -Wl,--as-needed") + if(MSVC) + target_link_libraries(torch_xpu PUBLIC torch_cpu_library) else() - target_link_libraries(torch_xpu PRIVATE "$") + include(CheckLinkerFlag) + + # Check whether the compiler supports '--no-as-needed' and '--as-needed' + check_linker_flag(CXX "-Wl,--no-as-needed" HAVE_NO_AS_NEEDED) + check_linker_flag(CXX "-Wl,--as-needed" HAVE_AS_NEEDED) + + if(HAVE_NO_AS_NEEDED AND HAVE_AS_NEEDED) + target_link_libraries(torch_xpu PRIVATE + "-Wl,--no-as-needed,\"$\" -Wl,--as-needed") + else() + target_link_libraries(torch_xpu PRIVATE "$") + endif() endif() endif() From 1963d1d79f0837518583ae6d5314faecef0efe6e Mon Sep 17 00:00:00 2001 From: Yichen Yan Date: Sat, 31 Aug 2024 00:06:28 +0000 Subject: [PATCH 0301/1018] [ONNX][DORT] Lazy-import `onnxruntime` (#134662) Currently, if installed, `onnxruntime` will be imported when importing `torch._inductor` (which will be imported by some other library, e.g. transformer-engine): ``` /mnt/c.py(53)() -> from torch._inductor.utils import maybe_profile /usr/local/lib/python3.10/site-packages/torch/_inductor/utils.py(49)() -> import torch._export /usr/local/lib/python3.10/site-packages/torch/_export/__init__.py(25)() -> import torch._dynamo /usr/local/lib/python3.10/site-packages/torch/_dynamo/__init__.py(2)() -> from . import convert_frame, eval_frame, resume_execution /usr/local/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py(48)() -> from . import config, exc, trace_rules /usr/local/lib/python3.10/site-packages/torch/_dynamo/trace_rules.py(52)() -> from .variables import ( /usr/local/lib/python3.10/site-packages/torch/_dynamo/variables/__init__.py(38)() -> from .higher_order_ops import ( /usr/local/lib/python3.10/site-packages/torch/_dynamo/variables/higher_order_ops.py(14)() -> import torch.onnx.operators /usr/local/lib/python3.10/site-packages/torch/onnx/__init__.py(62)() -> from ._internal.onnxruntime import ( /usr/local/lib/python3.10/site-packages/torch/onnx/_internal/onnxruntime.py(37)() -> import onnxruntime # type: ignore[import] ``` This issue breaks generated triton kernel because it imported torch, and unexpected runtime libraries as well. I've also added a test for this specific case under `test/onnx`, perhaps we should add more somewhere else? Related issue: https://github.com/huggingface/accelerate/pull/3056 Pull Request resolved: https://github.com/pytorch/pytorch/pull/134662 Approved by: https://github.com/justinchuby --- .../test_dynamo_with_onnxruntime_backend.py | 8 +- test/onnx/test_lazy_import.py | 37 ++++++++ torch/onnx/_internal/onnxruntime.py | 85 ++++++++++++++----- 3 files changed, 106 insertions(+), 24 deletions(-) create mode 100644 test/onnx/test_lazy_import.py diff --git a/test/onnx/dynamo/test_dynamo_with_onnxruntime_backend.py b/test/onnx/dynamo/test_dynamo_with_onnxruntime_backend.py index 46980f2ce64aaf..1d280dfa03450b 100644 --- a/test/onnx/dynamo/test_dynamo_with_onnxruntime_backend.py +++ b/test/onnx/dynamo/test_dynamo_with_onnxruntime_backend.py @@ -51,17 +51,19 @@ def tearDown(self): OrtBackend.clear_cached_instances() def test_get_ort_device_type(self): + from onnxruntime.capi import _pybind_state as ORTC + self.assertEqual( torch.onnx._internal.onnxruntime._get_ort_device_type("cuda"), - torch.onnx._internal.onnxruntime.ORTC.OrtDevice.cuda(), + ORTC.OrtDevice.cuda(), ) self.assertEqual( torch.onnx._internal.onnxruntime._get_ort_device_type("cpu"), - torch.onnx._internal.onnxruntime.ORTC.OrtDevice.cpu(), + ORTC.OrtDevice.cpu(), ) self.assertEqual( torch.onnx._internal.onnxruntime._get_ort_device_type("maia"), - torch.onnx._internal.onnxruntime.ORTC.OrtDevice.npu(), + ORTC.OrtDevice.npu(), ) def test_torch_compile_backend_registration(self): diff --git a/test/onnx/test_lazy_import.py b/test/onnx/test_lazy_import.py new file mode 100644 index 00000000000000..bce4b892095ec8 --- /dev/null +++ b/test/onnx/test_lazy_import.py @@ -0,0 +1,37 @@ +# Owner(s): ["module: onnx"] + +import subprocess +import sys +import tempfile + +import pytorch_test_common + +from torch.testing._internal import common_utils + + +class TestLazyONNXPackages(pytorch_test_common.ExportTestCase): + def _test_package_is_lazily_imported(self, pkg, torch_pkg="torch.onnx"): + with tempfile.TemporaryDirectory() as wd: + r = subprocess.run( + [sys.executable, "-Ximporttime", "-c", "import torch.onnx"], + capture_output=True, + text=True, + cwd=wd, + check=True, + ) + + # The extra space makes sure we're checking the package, not any package containing its name. + self.assertTrue( + f" {pkg}" not in r.stderr, + f"`{pkg}` should not be imported, full importtime: {r.stderr}", + ) + + def test_onnxruntime_is_lazily_imported(self): + self._test_package_is_lazily_imported("onnxruntime") + + def test_onnxscript_is_lazily_imported(self): + self._test_package_is_lazily_imported("onnxscript") + + +if __name__ == "__main__": + common_utils.run_tests() diff --git a/torch/onnx/_internal/onnxruntime.py b/torch/onnx/_internal/onnxruntime.py index 97626dcc70bb58..accc755dd8840c 100644 --- a/torch/onnx/_internal/onnxruntime.py +++ b/torch/onnx/_internal/onnxruntime.py @@ -34,32 +34,18 @@ if TYPE_CHECKING: import onnx - -try: - # Use try-except to initialize package-dependent global variables. - import onnxruntime # type: ignore[import] - from onnxruntime.capi import _pybind_state as ORTC # type: ignore[import] - - # This is not use directly in DORT but needed by underlying exporter, - # so we still need to check if it exists. - importlib.import_module("onnxscript") + import onnxruntime + from onnxruntime.capi import _pybind_state as ORTC import torch.onnx import torch.onnx._internal import torch.onnx._internal._exporter_legacy import torch.onnx._internal.diagnostics import torch.onnx._internal.fx.decomposition_table - import torch.onnx._internal.fx.passes - from torch.onnx._internal.fx import fx_onnx_interpreter - from torch.onnx._internal.fx.type_utils import ( - _TORCH_DTYPE_TO_NUMPY_DTYPE, - _TORCH_DTYPE_TO_ONNX_TENSOR_ELEMENT_TYPE, - from_python_type_to_onnx_tensor_element_type, - ) + import torch.onnx._internal.fx.passes # noqa: TCH004 - _SUPPORT_ONNXRT = True -except ImportError: - _SUPPORT_ONNXRT = False + +_SUPPORT_ONNXRT: Optional[bool] = None __all__ = [ "is_onnxrt_backend_supported", @@ -87,6 +73,35 @@ def is_onnxrt_backend_supported() -> bool: ... print("pip install onnx onnxscript onnxruntime") ... """ + global _SUPPORT_ONNXRT + + if _SUPPORT_ONNXRT is None: + # `onnxruntime` might import a lot of other runtime packages, + # e.g. apex, deepspeed, transformers. + # So lazy-importing onnxruntime to avoid possible circular import. + try: + importlib.import_module("onnxruntime") + importlib.import_module("onnxruntime.capi._pybind_state") + + # This is not use directly in DORT but needed by underlying exporter, + # so we still need to check if it exists. + importlib.import_module("onnxscript") + + import torch.onnx # noqa: F401 + import torch.onnx._internal # noqa: F401 + import torch.onnx._internal._exporter_legacy # noqa: F401 + import torch.onnx._internal.diagnostics # noqa: F401 + from torch.onnx._internal.fx import ( # noqa: F401 + decomposition_table, + fx_onnx_interpreter, + passes, + type_utils, + ) + + _SUPPORT_ONNXRT = True + except ImportError: + _SUPPORT_ONNXRT = False + return _SUPPORT_ONNXRT @@ -143,6 +158,8 @@ def _nvtx_range_pop(): def _get_ort_device_type(device_type: str): + from onnxruntime.capi import _pybind_state as ORTC + if device_type == "cuda": return ORTC.OrtDevice.cuda() if device_type == "cpu": @@ -305,6 +322,8 @@ def _get_onnx_devices( ..., ], ) -> Tuple["ORTC.OrtDevice", ...]: + from onnxruntime.capi import _pybind_state as ORTC + def _device_id_or_zero(device_id: int) -> int: return device_id or 0 @@ -338,6 +357,10 @@ def _map_tensor_or_sym_to_device( def _get_ortvalues_from_torch_tensors( tensors: Tuple[torch.Tensor, ...], devices: Tuple["ORTC.OrtDevice", ...] ) -> Tuple[torch.Tensor, ...]: + from onnxruntime.capi import _pybind_state as ORTC + + from torch.onnx._internal.fx.type_utils import _TORCH_DTYPE_TO_NUMPY_DTYPE + ortvalues = ORTC.OrtValueVector() ortvalues.reserve(len(tensors)) dtypes = [] @@ -436,6 +459,9 @@ def _run_onnx_session_with_ortvaluevector( ..., ], ) -> Tuple[Union[torch.Tensor, int, float, bool], ...]: + import onnxruntime + from onnxruntime.capi import _pybind_state as ORTC + _nvtx_range_push("contiguous") inputs = tuple( _adjust_scalar_from_fx_to_onnx(arg, value_info) @@ -514,6 +540,8 @@ def _run_onnx_session_with_fetch( ..., ], ) -> Tuple[Union[torch.Tensor, int, float, bool], ...]: + import onnxruntime + inputs = tuple( _adjust_scalar_from_fx_to_onnx(arg, value_info) for arg, value_info in zip(inputs, input_value_infos) @@ -570,6 +598,11 @@ def __init__( ) def is_supported(self, *args): + from torch.onnx._internal.fx.type_utils import ( + _TORCH_DTYPE_TO_ONNX_TENSOR_ELEMENT_TYPE, + from_python_type_to_onnx_tensor_element_type, + ) + # Compare the args and the input schema in ONNX model and # return the first match. if len(args) != len(self.input_value_infos): @@ -728,6 +761,12 @@ class OrtBackend: """ def __init__(self, options: Optional[OrtBackendOptions] = None): + from onnxruntime.capi import _pybind_state as ORTC + + import torch.onnx + import torch.onnx._internal._exporter_legacy + import torch.onnx._internal.fx.decomposition_table + self._options: Final = OrtBackendOptions() if options is None else options # options.export_options contains information shared between exporter and DORT. @@ -849,6 +888,10 @@ def _ort_acclerated_call(self, graph_module: torch.fx.GraphModule, *args, **kwar it means we delegate the computation to _ort_acclerated_call and therefore onnxruntime.InferenceSession. """ + import onnxruntime + + from torch.onnx._internal.fx import fx_onnx_interpreter, passes + cached_execution_info_per_session = ( self._all_ort_execution_info.search_reusable_session_execution_info( graph_module, *args @@ -867,7 +910,7 @@ def _ort_acclerated_call(self, graph_module: torch.fx.GraphModule, *args, **kwar # It's first time seeing such as graph. Let's make a new session # (type: onnxruntime.InferenceSession) for it. - graph_module = torch.onnx._internal.fx.passes.MovePlaceholderToFront( + graph_module = passes.MovePlaceholderToFront( self._resolved_onnx_exporter_options.diagnostic_context, graph_module, ).run() @@ -915,7 +958,7 @@ def maybe_map_to_meta_val(value): # Cast FX variables if they will result schema-mismatch when searching # for ONNX operator. E.g., add(double_tensor, int_tensor) is fine in PyTorch, # but ONNX expects add(double_tensor, double_tensor). - graph_module = torch.onnx._internal.fx.passes.InsertTypePromotion( + graph_module = passes.InsertTypePromotion( self._resolved_onnx_exporter_options.diagnostic_context, graph_module ).run() # Start the per-node exporting process. It's conceptually a for loop From 792efb7ed52780de71e1a83c3895a3378194e8a3 Mon Sep 17 00:00:00 2001 From: Xuehai Pan Date: Sat, 31 Aug 2024 04:39:47 +0800 Subject: [PATCH 0302/1018] [dynamo][itertools] refactor `itertools.chain` and `itertools.chain.from_iterable` to use polyfills (#133864) Pull Request resolved: https://github.com/pytorch/pytorch/pull/133864 Approved by: https://github.com/jansel --- torch/_dynamo/decorators.py | 16 +++++++++++++++- torch/_dynamo/polyfills/__init__.py | 19 +++++++++++++++++-- torch/_dynamo/polyfills/itertools.py | 21 ++++++++++++++++++++- torch/_dynamo/polyfills/loader.py | 1 + torch/_dynamo/trace_rules.py | 8 ++------ torch/_dynamo/variables/builder.py | 28 ++++++++++++++++++++++++++-- torch/_dynamo/variables/builtin.py | 16 ---------------- torch/_dynamo/variables/functions.py | 22 ++++++++++++++++++++++ torch/_dynamo/variables/iter.py | 8 -------- 9 files changed, 103 insertions(+), 36 deletions(-) diff --git a/torch/_dynamo/decorators.py b/torch/_dynamo/decorators.py index e73aecad34ce3a..4e97422fa605e9 100644 --- a/torch/_dynamo/decorators.py +++ b/torch/_dynamo/decorators.py @@ -170,6 +170,8 @@ def substitute_in_graph( *, can_constant_fold_through: bool = False, skip_signature_check: bool = False, + # type that is embedded in the Python interpreter + is_embedded_type: bool = False, # internal use only ) -> Callable[[_F], _F]: """ Register a polyfill handler for a function, usually a C function from the C extension, to be @@ -219,10 +221,22 @@ def substitute_in_graph( >>> torch.compile(operator.indexOf, fullgraph=True)([1, 2, 3, 4, 5], 3) 2 """ - if not is_function(original_fn): + if not is_function(original_fn) and not ( + is_embedded_type and inspect.isclass(original_fn) + ): raise TypeError( f"substitute_in_graph expects a function but got {type(original_fn)!r}" ) + if is_embedded_type: + if not inspect.isclass(original_fn): + raise TypeError( + f"substitute_in_graph expects a class but got {type(original_fn)!r}" + ) + + from .variables.builder import ITERTOOLS_POLYFILLED_TYPE_IDS, ITERTOOLS_TYPE_IDS + + if id(original_fn) in ITERTOOLS_TYPE_IDS: + ITERTOOLS_POLYFILLED_TYPE_IDS.add(id(original_fn)) def wrapper(traceable_fn: _F) -> _F: if not is_function(traceable_fn): diff --git a/torch/_dynamo/polyfills/__init__.py b/torch/_dynamo/polyfills/__init__.py index 9b465726754bbe..fd206c3847e6e1 100644 --- a/torch/_dynamo/polyfills/__init__.py +++ b/torch/_dynamo/polyfills/__init__.py @@ -4,15 +4,28 @@ # NOTE: 1. Please do not import any submodule in the directory here to avoid circular imports. # 2. While adding a new polyfill module, also add it to POLYFILLED_MODULE_NAMES in loader.py. +# Add it in the TYPE_CHECKING block below as well. # mypy: allow-untyped-defs -import math -from typing import Any, Callable, Sequence +from typing import Any, Callable, Sequence, TYPE_CHECKING import torch +if TYPE_CHECKING: + # Load by torch._dynamo.polyfills.loader + # See also the POLYFILLED_MODULE_NAMES in torch/_dynamo/polyfills/loader.py + # Put the submodules here to avoid circular imports + from . import ( + builtins as builtins, + functools as functools, + itertools as itertools, + os as os, + sys as sys, + ) + + def index(iterator, item, start=0, end=None): for i, elem in islice(enumerate(iterator), start, end): if item == elem: @@ -50,6 +63,8 @@ def repeat(item, count): def radians(x): + import math + return math.pi / 180.0 * x diff --git a/torch/_dynamo/polyfills/itertools.py b/torch/_dynamo/polyfills/itertools.py index 090df0f84b5281..802a62a82c8cdd 100644 --- a/torch/_dynamo/polyfills/itertools.py +++ b/torch/_dynamo/polyfills/itertools.py @@ -10,12 +10,31 @@ from ..decorators import substitute_in_graph -__all__ = ["tee"] +__all__ = [ + "chain", + "chain_from_iterable", + "tee", +] _T = TypeVar("_T") +# Reference: https://docs.python.org/3/library/itertools.html#itertools.chain +@substitute_in_graph(itertools.chain, is_embedded_type=True) # type: ignore[arg-type] +def chain(*iterables: Iterable[_T]) -> Iterator[_T]: + for iterable in iterables: + yield from iterable + + +@substitute_in_graph(itertools.chain.from_iterable) # type: ignore[arg-type] +def chain_from_iterable(iterable: Iterable[Iterable[_T]], /) -> Iterator[_T]: + return itertools.chain(*iterable) + + +chain.from_iterable = chain_from_iterable # type: ignore[method-assign] + + # Reference: https://docs.python.org/3/library/itertools.html#itertools.tee @substitute_in_graph(itertools.tee) def tee(iterable: Iterable[_T], n: int = 2, /) -> tuple[Iterator[_T], ...]: diff --git a/torch/_dynamo/polyfills/loader.py b/torch/_dynamo/polyfills/loader.py index 255486b6e016dd..24478e1b5a0f96 100644 --- a/torch/_dynamo/polyfills/loader.py +++ b/torch/_dynamo/polyfills/loader.py @@ -11,6 +11,7 @@ from types import ModuleType +# See also the TYPE_CHECKING block in torch/_dynamo/polyfills/__init__.py POLYFILLED_MODULE_NAMES: Tuple[str, ...] = ( "builtins", "functools", diff --git a/torch/_dynamo/trace_rules.py b/torch/_dynamo/trace_rules.py index 12c035caf3cadb..d5c682afcdfd92 100644 --- a/torch/_dynamo/trace_rules.py +++ b/torch/_dynamo/trace_rules.py @@ -2993,9 +2993,7 @@ def _builtin_function_ids() -> Dict[int, str]: if not k.startswith("_") and callable(v) } ) - rv.update( - {id(v): f"itertools.{v.__name__}" for v in (itertools.chain, itertools.islice)} - ) + rv.update({id(v): f"itertools.{v.__name__}" for v in (itertools.islice,)}) rv.update( { id(cast): "typing.cast", @@ -3474,9 +3472,7 @@ def check_verbose(obj, is_inlined_call=False): # Consulte the central trace rules defined in torch._dynamo.trace_rules. reasons: Set[str] = set() - rule = torch._dynamo.trace_rules.lookup_inner( - fi.py_obj, fi.name, fi.filename, is_inlined_call, reasons - ) + rule = lookup_inner(fi.py_obj, fi.name, fi.filename, is_inlined_call, reasons) if issubclass(rule, (UserFunctionVariable, PolyfilledFunctionVariable)): return SkipResult( False, diff --git a/torch/_dynamo/variables/builder.py b/torch/_dynamo/variables/builder.py index be45ab0a692ee8..16eab36028eca3 100644 --- a/torch/_dynamo/variables/builder.py +++ b/torch/_dynamo/variables/builder.py @@ -17,7 +17,17 @@ import types import warnings import weakref -from typing import Any, List, MutableMapping, NamedTuple, Optional, TYPE_CHECKING, Union +from typing import ( + Any, + FrozenSet, + List, + MutableMapping, + NamedTuple, + Optional, + Set, + TYPE_CHECKING, + Union, +) import torch from torch import SymInt @@ -320,6 +330,17 @@ class FrameStateSizeEntry: stride: Optional[List[int]] +# All class-based iterators in itertools +# NOTE: use id() because some objects are not hashable, it will raise error during lookup +ITERTOOLS_TYPE_IDS: FrozenSet[int] = frozenset( + id(member) + for name, member in vars(itertools).items() + if not name.startswith("_") and inspect.isclass(member) +) +# Will be updated later in substitute_in_graph in torch/_dynamo/polyfills/itertools.py +ITERTOOLS_POLYFILLED_TYPE_IDS: Set[int] = set() + + class VariableBuilder: """Wrap a python value in a VariableTracker() instance""" @@ -875,7 +896,10 @@ def build_key_value(i, k, v): value, source=self.source, ) - elif istype(value, type) and value in itertools.__dict__.values(): + elif ( + id(value) in ITERTOOLS_TYPE_IDS + and id(value) not in ITERTOOLS_POLYFILLED_TYPE_IDS + ): self.install_guards(GuardBuilder.FUNCTION_MATCH) return ItertoolsVariable(value, source=self.source) elif isinstance(value, torch.SymBool): diff --git a/torch/_dynamo/variables/builtin.py b/torch/_dynamo/variables/builtin.py index 86900dee30bc56..3053ee4bba6ce5 100644 --- a/torch/_dynamo/variables/builtin.py +++ b/torch/_dynamo/variables/builtin.py @@ -994,15 +994,6 @@ def call_method( ) if self.fn is dict and name == "fromkeys": return BuiltinVariable.call_custom_dict_fromkeys(tx, dict, *args, **kwargs) - if self.fn is itertools.chain and name == "from_iterable": - assert len(args) == 1 - assert len(kwargs) == 0 - obj = args[0] - items = [] - for item in obj.unpack_var_sequence(tx): - items.extend(item.unpack_var_sequence(tx)) - return variables.TupleVariable(items) - return super().call_method(tx, name, args, kwargs) def _call_int_float(self, tx: "InstructionTranslator", arg): @@ -1942,13 +1933,6 @@ def call_sorted(self, tx: "InstructionTranslator", obj: VariableTracker, **kwarg ) return variables.ListVariable(items) - def call_chain(self, tx: "InstructionTranslator", *args): - if all(obj.has_unpack_var_sequence(tx) for obj in args): - items = [] - for obj in args: - items.extend(obj.unpack_var_sequence(tx)) - return variables.TupleVariable(items) - def call_islice(self, tx: "InstructionTranslator", iterable, *args): if iterable.has_unpack_var_sequence(tx) and all( x.is_python_constant() for x in args diff --git a/torch/_dynamo/variables/functions.py b/torch/_dynamo/variables/functions.py index a9ec80cc3de589..239b720a5a7530 100644 --- a/torch/_dynamo/variables/functions.py +++ b/torch/_dynamo/variables/functions.py @@ -18,6 +18,7 @@ check_constant_args, check_unspec_or_constant_args, identity, + is_function, is_wrapper_or_member_descriptor, istype, make_cell, @@ -992,6 +993,27 @@ def call_function( handler, ).call_function(tx, args, kwargs) + return super().call_function(tx, args, kwargs) + + def call_method( + self, + tx, + name, + args: "List[VariableTracker]", + kwargs: "Dict[str, VariableTracker]", + ) -> "VariableTracker": + if name == "__call__": + return self.call_function(tx, args, kwargs) + + method = getattr(self.fn, name, None) + assert method is not None, f"Member {name} not found in {self.fn}" + assert is_function(method), f"Member {name} is not callable in {self.fn}" + options = {} + if self.source: + options["source"] = AttrSource(self.source, name) + member_variable = PolyfilledFunctionVariable(method, **options) + return member_variable.call_function(tx, args, kwargs) + def as_python_constant(self): return self.fn diff --git a/torch/_dynamo/variables/iter.py b/torch/_dynamo/variables/iter.py index 6687611cf0aabd..ed3ae786634e92 100644 --- a/torch/_dynamo/variables/iter.py +++ b/torch/_dynamo/variables/iter.py @@ -52,14 +52,6 @@ def call_function( for item in itertools.product(*seqs): items.append(variables.TupleVariable(list(item))) return variables.ListIteratorVariable(items, mutable_local=MutableLocal()) - elif ( - self.value is itertools.chain - and not kwargs - and all(arg.has_unpack_var_sequence(tx) for arg in args) - ): - seqs = [arg.unpack_var_sequence(tx) for arg in args] - items = list(itertools.chain.from_iterable(seqs)) - return variables.ListIteratorVariable(items, mutable_local=MutableLocal()) elif self.value is itertools.accumulate: from .builtin import BuiltinVariable From 48c6eea8386a1ea9f4f6b78599f63f8246548d9a Mon Sep 17 00:00:00 2001 From: Xuehai Pan Date: Sat, 31 Aug 2024 04:39:48 +0800 Subject: [PATCH 0303/1018] [dynamo] refactor `builtins.enumerate` to use polyfill (#133894) Pull Request resolved: https://github.com/pytorch/pytorch/pull/133894 Approved by: https://github.com/jansel ghstack dependencies: #133864 --- torch/_dynamo/polyfills/builtins.py | 20 +++++++++++++++++++- torch/_dynamo/variables/builtin.py | 16 ---------------- 2 files changed, 19 insertions(+), 17 deletions(-) diff --git a/torch/_dynamo/polyfills/builtins.py b/torch/_dynamo/polyfills/builtins.py index d2f563ca6b883f..737b7c13809183 100644 --- a/torch/_dynamo/polyfills/builtins.py +++ b/torch/_dynamo/polyfills/builtins.py @@ -2,8 +2,10 @@ Python polyfills for builtins """ +from __future__ import annotations + import builtins -from typing import Iterable +from typing import Iterable, TypeVar from ..decorators import substitute_in_graph @@ -11,9 +13,13 @@ __all__ = [ "all", "any", + "enumerate", ] +_T = TypeVar("_T") + + @substitute_in_graph(builtins.all, can_constant_fold_through=True) def all(iterable: Iterable[object], /) -> bool: for elem in iterable: @@ -28,3 +34,15 @@ def any(iterable: Iterable[object], /) -> bool: if elem: return True return False + + +@substitute_in_graph(builtins.enumerate, is_embedded_type=True) # type: ignore[arg-type] +def enumerate(iterable: Iterable[_T], start: int = 0) -> Iterable[tuple[int, _T]]: + if not isinstance(start, int): + raise TypeError( + f"{type(start).__name__!r} object cannot be interpreted as an integer" + ) + + for x in iterable: + yield start, x + start += 1 diff --git a/torch/_dynamo/variables/builtin.py b/torch/_dynamo/variables/builtin.py index 3053ee4bba6ce5..928543b1fe8478 100644 --- a/torch/_dynamo/variables/builtin.py +++ b/torch/_dynamo/variables/builtin.py @@ -1442,22 +1442,6 @@ def call_zip(self, tx: "InstructionTranslator", *args, **kwargs): items = [variables.TupleVariable(list(item)) for item in zip(*unpacked)] return variables.TupleVariable(items) - def call_enumerate(self, tx: "InstructionTranslator", *args): - if len(args) == 1: - start = 0 - else: - assert len(args) == 2 - assert isinstance(args[1], variables.ConstantVariable) - start = args[1].as_python_constant() - if args[0].has_unpack_var_sequence(tx): - items = [ - variables.TupleVariable( - [variables.ConstantVariable.create(idx), var], - ) - for idx, var in enumerate(args[0].unpack_var_sequence(tx), start) - ] - return variables.TupleVariable(items) - def call_len(self, tx: "InstructionTranslator", *args, **kwargs): return args[0].call_method(tx, "__len__", args[1:], kwargs) From 55455acaaea02b0a9bb318fad48c13933e01bf53 Mon Sep 17 00:00:00 2001 From: Yiming Zhou Date: Sat, 31 Aug 2024 00:28:20 +0000 Subject: [PATCH 0304/1018] [benchmark] Add to torchbench relative path search (#134871) Add to relative path search in benchmark. This enables user to run `torchbench.py` inside the `pytorch/benchmark/dynamo` folder when `torchbench` repo is cloned in the same level as `pytorch` Pull Request resolved: https://github.com/pytorch/pytorch/pull/134871 Approved by: https://github.com/FindHao --- benchmarks/dynamo/torchbench.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/benchmarks/dynamo/torchbench.py b/benchmarks/dynamo/torchbench.py index dcbde79139b7af..6233ecb4e9ae9b 100755 --- a/benchmarks/dynamo/torchbench.py +++ b/benchmarks/dynamo/torchbench.py @@ -56,6 +56,9 @@ def setup_torchbench_cwd(): "../../torchbenchmark", "../../torchbench", "../../benchmark", + "../../../torchbenchmark", + "../../../torchbench", + "../../../benchmark", ): if exists(torchbench_dir): break From ed38f6f4bda7c4aa1493062466b702dbf0386806 Mon Sep 17 00:00:00 2001 From: Nikita Shulga Date: Sat, 31 Aug 2024 00:29:47 +0000 Subject: [PATCH 0305/1018] [BE][MPS] Prefer xfail to skip (#134858) This essentially undoes large skips on everything but MacOS Sequoia to nn.modules made by https://github.com/pytorch/pytorch/pull/128393 Instead it uses existing `xfail`, but guards it on `_macos15_or_newer` boolean Before the change if run on MacOS 14: ``` % python3 ../test/test_modules.py -v -k Hardswish 2>&1|tail -n3 Ran 57 tests in 0.053s OK (skipped=32) ``` After ``` % python3 ../test/test_modules.py -v -k Hardswish 2>&1|tail -n3 Ran 57 tests in 0.229s OK (skipped=10, expected failures=2) ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/134858 Approved by: https://github.com/janeyx99 --- test/test_nn.py | 15 +++---- torch/testing/_internal/common_device_type.py | 18 -------- torch/testing/_internal/common_modules.py | 45 ++++++++++++++----- 3 files changed, 41 insertions(+), 37 deletions(-) diff --git a/test/test_nn.py b/test/test_nn.py index c9171f5bd653ce..813197b4499b67 100644 --- a/test/test_nn.py +++ b/test/test_nn.py @@ -43,7 +43,7 @@ dtypesIfCUDA, precisionOverride, skipCUDAIfCudnnVersionLessThan, onlyCUDA, onlyCPU, \ skipCUDAIfRocm, skipCUDAIf, skipCUDAIfNotRocm, \ onlyNativeDeviceTypes, deviceCountAtLeast, largeTensorTest, expectedFailureMeta, expectedFailureMPS, \ - skipMPSVersionIfLessThan, skipMeta, get_all_device_types + skipMeta, get_all_device_types from hypothesis import given import torch.testing._internal.hypothesis_utils as hu @@ -8205,9 +8205,8 @@ def test_batchnorm_large_batch(self, device, dtype): data = torch.rand(880801, 1, 1, 1, device=device, dtype=dtype) out = bn(data).sum().backward() - @skipMPSVersionIfLessThan(14, 0) # macOS 13 does not support bfloat16 @dtypesIfCUDA(torch.float, torch.double, torch.half, torch.complex128) - @dtypesIfMPS(torch.float, torch.bfloat16, torch.complex64) + @dtypesIfMPS(torch.float, torch.half, torch.complex64) @dtypes(torch.float, torch.double, torch.bfloat16, torch.complex128) def test_conv_empty_input(self, device, dtype): def help(input, conv, memory_format): @@ -12085,9 +12084,9 @@ def __init__(self) -> None: for p, pe in zip(test_model.parameters(), ref_model.parameters()): self.assertEqual(p.grad.to(devices[0]), pe.grad) - @skipMPSVersionIfLessThan(14, 0) # macOS 13 does not support bfloat16 def test_elu_inplace_overlap(self, device): - x = torch.randn((1, 6), dtype=torch.bfloat16, device=device).expand((6, 6)) + dtype = torch.bfloat16 if device != 'mps:0' else torch.float16 + x = torch.randn((1, 6), dtype=dtype, device=device).expand((6, 6)) with self.assertRaisesRegex(RuntimeError, 'unsupported operation'): F.elu(x, inplace=True) with self.assertRaisesRegex(RuntimeError, 'unsupported operation'): @@ -12163,7 +12162,6 @@ def test_leaky_relu_inplace_with_neg_slope(self, device): b.backward(torch.ones(2, device=device)) # Merge into OpInfo? - @skipMPSVersionIfLessThan(14, 0) # macOS 13 does not support bfloat16 def test_leaky_relu_inplace_with_zero_slope(self, device): a = torch.tensor([-2., 0., 2.], device=device, requires_grad=True) b = torch.nn.functional.leaky_relu_(a.clone(), 0.0) @@ -12171,10 +12169,11 @@ def test_leaky_relu_inplace_with_zero_slope(self, device): expected = torch.tensor([0., 0., 1.], device=device) self.assertEqual(a.grad, expected) - a_bf16 = torch.tensor([-2., 0., 2.], device=device, dtype=torch.bfloat16, requires_grad=True) + dtype = torch.bfloat16 if device != 'mps:0' else torch.float16 + a_bf16 = torch.tensor([-2., 0., 2.], device=device, dtype=dtype, requires_grad=True) b_bf16 = torch.nn.functional.leaky_relu_(a_bf16.clone(), 0.0) b_bf16.backward(torch.ones(3, device=device)) - expected_bf16 = torch.tensor([0., 0., 1.], device=device, dtype=torch.bfloat16) + expected_bf16 = torch.tensor([0., 0., 1.], device=device, dtype=dtype) self.assertEqual(a_bf16.grad, expected_bf16) @onlyCPU diff --git a/torch/testing/_internal/common_device_type.py b/torch/testing/_internal/common_device_type.py index eeaa31fad1bc25..d762156e61a91a 100644 --- a/torch/testing/_internal/common_device_type.py +++ b/torch/testing/_internal/common_device_type.py @@ -1891,24 +1891,6 @@ def skipMPS(fn): return skipMPSIf(True, "test doesn't work on MPS backend")(fn) -def skipMPSVersionIfLessThan(major: int, minor: int): - def dec_fn(fn): - @wraps(fn) - def wrap_fn(self, *args, **kwargs): - if self.device_type == "mps": - if not torch.backends.mps.is_macos_or_newer(major, minor): - reason = ( - f"MPS test is skipped for MacOS versions < {major}.{minor} " - ) - raise unittest.SkipTest(reason) - - return fn(self, *args, **kwargs) - - return wrap_fn - - return dec_fn - - def skipHPU(fn): return skipHPUIf(True, "test doesn't work on HPU backend")(fn) diff --git a/torch/testing/_internal/common_modules.py b/torch/testing/_internal/common_modules.py index d3629abde959f5..02b621640a9a3f 100644 --- a/torch/testing/_internal/common_modules.py +++ b/torch/testing/_internal/common_modules.py @@ -16,7 +16,7 @@ floating_types, floating_and_complex_types_and, get_all_fp_dtypes) from torch.testing._internal.common_device_type import ( _TestParametrizer, _update_param_kwargs, expectedFailureMPS, toleranceOverride, tol, - skipCUDAIfCudnnVersionLessThan, skipCUDAIfRocm, precisionOverride, skipMeta, skipMPS, skipMPSVersionIfLessThan, + skipCUDAIfCudnnVersionLessThan, skipCUDAIfRocm, precisionOverride, skipMeta, skipMPS, skipCUDAVersionIn) from torch.testing._internal.common_methods_invocations import DecorateInfo from torch.testing._internal.common_nn import ( @@ -3353,6 +3353,9 @@ def module_error_inputs_torch_nn_Pad3d(module_info, device, dtype, requires_grad ] +_macos15_or_newer = torch.backends.mps.is_available() and torch.backends.mps.is_macos_or_newer(15, 0) + + # Database of ModuleInfo entries in alphabetical order. module_db: List[ModuleInfo] = [ ModuleInfo(torch.nn.AdaptiveAvgPool1d, @@ -4057,11 +4060,15 @@ def module_error_inputs_torch_nn_Pad3d(module_info, device, dtype, requires_grad ), ModuleInfo(torch.nn.Hardswish, module_inputs_func=module_inputs_torch_nn_Hardswish, - skips=( + skips=None if _macos15_or_newer else ( # Fails on backward check on MPS # See https://github.com/pytorch/pytorch/issues/107214 DecorateInfo( - skipMPSVersionIfLessThan(15, 0) + unittest.expectedFailure, + 'TestModule', + 'test_memory_format', + active_if=operator.itemgetter('training'), + device_type='mps', ),), supports_gradgrad=False), ModuleInfo(torch.nn.Hardtanh, @@ -4176,11 +4183,15 @@ def module_error_inputs_torch_nn_Pad3d(module_info, device, dtype, requires_grad ), ModuleInfo(torch.nn.ReLU, module_inputs_func=module_inputs_torch_nn_ReLU, - skips=( + skips=None if _macos15_or_newer else ( # Fails on backward check on MPS # See https://github.com/pytorch/pytorch/issues/107214 DecorateInfo( - skipMPSVersionIfLessThan(15, 0) + unittest.expectedFailure, + 'TestModule', + 'test_memory_format', + active_if=operator.itemgetter('training'), + device_type='mps', ),) ), ModuleInfo(torch.nn.LeakyReLU, @@ -4214,11 +4225,15 @@ def module_error_inputs_torch_nn_Pad3d(module_info, device, dtype, requires_grad ), ModuleInfo(torch.nn.Sigmoid, module_inputs_func=module_inputs_torch_nn_Sigmoid, - skips=( + skips=None if _macos15_or_newer else ( # Fails on backward check on MPS # See https://github.com/pytorch/pytorch/issues/107214 DecorateInfo( - skipMPSVersionIfLessThan(15, 0) + unittest.expectedFailure, + 'TestModule', + 'test_memory_format', + active_if=operator.itemgetter('training'), + device_type='mps', ),) ), ModuleInfo(torch.nn.LogSigmoid, @@ -4273,20 +4288,28 @@ def module_error_inputs_torch_nn_Pad3d(module_info, device, dtype, requires_grad ), ModuleInfo(torch.nn.Tanh, module_inputs_func=module_inputs_torch_nn_Tanh, - skips=( + skips=None if _macos15_or_newer else ( # Fails on backward check on MPS # See https://github.com/pytorch/pytorch/issues/107214 DecorateInfo( - skipMPSVersionIfLessThan(15, 0) + unittest.expectedFailure, + 'TestModule', + 'test_memory_format', + active_if=operator.itemgetter('training'), + device_type='mps', ),) ), ModuleInfo(torch.nn.Tanhshrink, module_inputs_func=module_inputs_torch_nn_Tanhshrink, - skips=( + skips=None if _macos15_or_newer else ( # Fails on backward check on MPS # See https://github.com/pytorch/pytorch/issues/107214 DecorateInfo( - skipMPSVersionIfLessThan(15, 0) + unittest.expectedFailure, + 'TestModule', + 'test_memory_format', + active_if=operator.itemgetter('training'), + device_type='mps', ),) ), ModuleInfo(torch.nn.Threshold, From cf07663f87477ff06b4013382c58888b37dc513c Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Sat, 31 Aug 2024 00:38:30 +0000 Subject: [PATCH 0306/1018] Revert "[Inductor] Allow customizing the padding format (#133939)" This reverts commit 8b258b3b14408986a1d4142cff5a153c798ceecc. Reverted https://github.com/pytorch/pytorch/pull/133939 on behalf of https://github.com/ZainRizvi due to sorry but this PR is causing issues with diff train imports reverting it for now but it can be merged back in as-is ([comment](https://github.com/pytorch/pytorch/pull/133939#issuecomment-2322635388)) --- test/inductor/test_padding.py | 58 ++--------------------------------- torch/_inductor/config.py | 28 ----------------- torch/_inductor/graph.py | 11 +++---- torch/_inductor/ir.py | 26 ++++++++++++++-- 4 files changed, 30 insertions(+), 93 deletions(-) diff --git a/test/inductor/test_padding.py b/test/inductor/test_padding.py index 9ae3dd3a125df8..c8170c92327166 100644 --- a/test/inductor/test_padding.py +++ b/test/inductor/test_padding.py @@ -3,7 +3,6 @@ import functools import os import unittest -from typing import Tuple import torch from torch import nn, Tensor @@ -13,13 +12,8 @@ from torch._inductor import config, ir, metrics from torch._inductor.fx_passes import pad_mm as pad_mm_pass from torch._inductor.runtime.benchmarking import benchmarker -from torch._inductor.utils import ceildiv, run_and_get_code -from torch.testing._internal.common_utils import ( - instantiate_parametrized_tests, - parametrize, - requires_cuda, - serialTest, -) +from torch._inductor.utils import run_and_get_code +from torch.testing._internal.common_utils import requires_cuda, serialTest from torch.testing._internal.inductor_utils import HAS_CUDA @@ -368,7 +362,6 @@ def test_longformer_small_bs(self): self.test_longformer(bs=2) -@instantiate_parametrized_tests class PaddingTest(TestCaseBase): @unittest.skipIf(not DO_PERF_TEST, "Perf test not enabled") def test_mm_padding_perf(self): @@ -661,53 +654,6 @@ def test_pad_channels_last(self): out_strides = ir.Layout._pad_strides(in_strides, t.shape, torch.float32) self.assertTrue(in_strides == out_strides) - @parametrize("alignment_bytes", (32, 128)) - @parametrize("shape", [(21, 19), (3, 5, 71)]) - @parametrize("dtype", (torch.float16, torch.float32)) - def test_pad_outputs( - self, dtype: torch.dtype, shape: Tuple[int], alignment_bytes: int - ): - """ - Tests padding output tensors to a specific alignment. - This is enabled by a config flag. - """ - func = torch.add - inputs = tuple(torch.randn(*shape, dtype=dtype) for input_idx in range(2)) - - # Compile and run - with config.patch( - { - "comprehensive_padding": True, - "padding_alignment_bytes": alignment_bytes, - "padding_stride_threshold": 0, - "pad_outputs": True, - } - ): - compiled_func = torch.compile(func) - compiled_out = compiled_func(*inputs) - - # Check numerics - eager_out = func(*inputs) - self.check_close(eager_out, compiled_out) - - # Compute the expected padding - element_size = torch.tensor([], dtype=dtype).element_size() - self.assertGreater(alignment_bytes, element_size) - self.assertEqual(alignment_bytes % element_size, 0) - alignment_elements = alignment_bytes // element_size - contiguous_stride = inputs[0].stride() - expected_stride = [1] - for dim in reversed(shape[1:]): - slice_size = dim * expected_stride[0] - new_stride = alignment_elements * ceildiv(slice_size, alignment_elements) - expected_stride.insert(0, new_stride) - expected_stride = tuple(expected_stride) - self.assertNotEqual(expected_stride, contiguous_stride) - - # Check strides - self.assertFalse(compiled_out.is_contiguous()) - self.assertEqual(compiled_out.stride(), expected_stride) - if __name__ == "__main__": if HAS_CUDA: diff --git a/torch/_inductor/config.py b/torch/_inductor/config.py index 919cddb773d7aa..f6599989f31a4b 100644 --- a/torch/_inductor/config.py +++ b/torch/_inductor/config.py @@ -595,34 +595,6 @@ def decide_compile_threads() -> int: ) pad_channels_last = False -# The width of comprehensive padding, in bytes. -# CUDA max memory transaction size is 128 bytes for a warp. -padding_alignment_bytes = 128 - -# Threshold on the minimum stride that will be padded. -# -# Don't align a too small stride since that causes too much memory increase. -# Pad too small stride may also cause perf loss. We may result in many tiny data blocks -# with gaps in between. That causes less coalesced GPU memory access! -# -# Initially we pick 320 as the threshold since for alignement=16, -# that results in at most 5% memory cost. -# -# But later on we raise the threshold to 1024 to avoid interfere with persistent reduction. -# Let's say an inner reduction has a row size 513. Inductor will generate -# persistent reduction code. -# If we do padding, the strides are not contiguous any more. Inductor -# uses a much smaller threshold for persistent reduction in this case and -# generates potentially worse non-persistent reduction code. -# -# This change turns HF AllenaiLongformerBase amp training from a loss of 1.09x to a win of 1.05x. -# (baseline: 71.09ms, padding w/o this change: 77.38ms, padding with this change: 67.77ms) -padding_stride_threshold = 1024 - -# Enable padding outputs, even if they would not be padded in eager mode. -# By default, we use the same strides as eager mode. -pad_outputs = False - # Whether to treat output of the backward graph as user visible. # For user visible outputs, inductor will make sure the stride matches with eager. bw_outputs_user_visible = True diff --git a/torch/_inductor/graph.py b/torch/_inductor/graph.py index c9da83e7584e62..987e639657cb0c 100644 --- a/torch/_inductor/graph.py +++ b/torch/_inductor/graph.py @@ -247,11 +247,7 @@ def _get_overload_packet( if prior_op not in ops_like_padding: prior.meta["dislike_padding"] = True # We only want to mark output nodes. So, move it after the above prior nodes process. - if ( - not config.pad_outputs - and user_visible_outputs - and cur.name in user_visible_outputs - ): + if user_visible_outputs and cur.name in user_visible_outputs: cur.meta["dislike_padding"] = True @@ -1363,8 +1359,9 @@ def debug(msg: str) -> None: strides = n.meta["val"].stride() if len(strides): allow_padding = ( - config.pad_outputs or n.name not in self.user_visible_outputs - ) and not is_input_for_as_strided + n.name not in self.user_visible_outputs + and not is_input_for_as_strided + ) dense = torch._prims_common.is_non_overlapping_and_dense( n.meta["val"] ) diff --git a/torch/_inductor/ir.py b/torch/_inductor/ir.py index 18cdac11ac8077..55d43674807e0d 100644 --- a/torch/_inductor/ir.py +++ b/torch/_inductor/ir.py @@ -2872,7 +2872,12 @@ def is_contiguous_strides_for_shape( def get_align_for_dtype(dtype: torch.dtype) -> int: - return config.padding_alignment_bytes // dtype.itemsize + """ + CUDA max memory transaction size is 128 bytes for a warp. + We pick `128 // dtype.itemsize` as alighment so GPU can do coalesced + memory access. + """ + return 128 // dtype.itemsize @dataclasses.dataclass @@ -3011,12 +3016,29 @@ def _pad_strides(in_strides, size, dtype): # smallest stride to be 1. new_strides[fill_order[0]] = 1 + # Don't align a too small stride since that causes too much memory increase. + # Pad too small stride may also cause perf loss. We may result in many tiny data blocks + # with gaps in between. That causes less coalesced GPU memory access! + # + # Initially we pick 320 as the threshold since for alignement=16, + # that results in at most 5% memory cost. + # + # But later on we raise the threshold to 1024 to avoid interfere with persistent reduction. + # Let's say an inner reduction has a row size 513. Inductor will generate + # persistent reduction code. + # If we do padding, the strides are not contiguous any more. Inductor + # uses a much smaller threshold for persistent reduction in this case and + # generates potentially worse non-persistent reduction code. + # + # This change turns HF AllenaiLongformerBase amp training from a loss of 1.09x to a win of 1.05x. + # (baseline: 71.09ms, padding w/o this change: 77.38ms, padding with this change: 67.77ms) + align_stride_threshold = 1024 padded = False for rank, idx in enumerate(fill_order[1:], start=1): prev_idx = fill_order[rank - 1] stride = new_strides[prev_idx] * size[prev_idx] - if stride > config.padding_stride_threshold and stride % align != 0: + if stride > align_stride_threshold and stride % align != 0: stride = ceildiv(stride, align) * align padded = True new_strides[idx] = stride From e805a56f26637a9d8fa18020d3c64150fb95b0ff Mon Sep 17 00:00:00 2001 From: Eddie Yan Date: Sat, 31 Aug 2024 02:50:02 +0000 Subject: [PATCH 0307/1018] Cleanup unused `pytorch.version` (#134381) This file doesn't seem to be used anywhere? checking CI... Pull Request resolved: https://github.com/pytorch/pytorch/pull/134381 Approved by: https://github.com/zou3519 --- tools/pytorch.version | 31 ------------------------------- 1 file changed, 31 deletions(-) delete mode 100644 tools/pytorch.version diff --git a/tools/pytorch.version b/tools/pytorch.version deleted file mode 100644 index 3488ccfc56b1df..00000000000000 --- a/tools/pytorch.version +++ /dev/null @@ -1,31 +0,0 @@ -{ - global: - _TH*; - __TH*; - TH*; - *THP*; - *THCP*; - PyInit*; - init*; - state; - _ZGVZN2at*; - _ZN2at*; - _ZNK2at*Type*; - _ZNK2at*Tensor*; - _ZNK2at*Storage*; - _ZNK2at*Scalar*; - _ZNK2at*CUDA*; - *2at7Context*; - _ZTIN2at*; - _ZTIZN2at*; - _ZTSN2at*; - _ZTSPN2at*; - _ZTSZN2at*; - _ZTVN2at*; - _ZZN2at*; - _Z*torch*; - _Z*Tensor*; - _Z*tensor*; - local: - *; - }; From 38e10c60ebe2e2e4922b2e1cb6fe1ca290c6d0b2 Mon Sep 17 00:00:00 2001 From: Alex Baden Date: Sat, 31 Aug 2024 03:38:07 +0000 Subject: [PATCH 0308/1018] [Inductor] Support passing module map parameter to Triton make_ir API (#134774) In https://github.com/triton-lang/triton/pull/4539 the `make_ir` API was modified to accept a new `module_map` parameter. Update the Inductor callsite accordingly, preserving backwards compatibility following the existing code. Fixes #134674 Pull Request resolved: https://github.com/pytorch/pytorch/pull/134774 Approved by: https://github.com/EikanWang, https://github.com/zou3519, https://github.com/jansel --- torch/_higher_order_ops/triton_kernel_wrap.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/torch/_higher_order_ops/triton_kernel_wrap.py b/torch/_higher_order_ops/triton_kernel_wrap.py index 874dcdcb50fa28..37f7f05c641486 100644 --- a/torch/_higher_order_ops/triton_kernel_wrap.py +++ b/torch/_higher_order_ops/triton_kernel_wrap.py @@ -179,13 +179,18 @@ def generate_ttir(kernel, kwargs): src = ASTSource(kernel, signature, constants, specialization) - # Triton changes ASTSource.make_ir to take 3 arguments. Handle + # Triton changes ASTSource.make_ir to take 3/4 arguments. Handle # backward compatibility here. - if len(inspect.signature(src.make_ir).parameters) == 2: + make_ir_sig_params = len(inspect.signature(src.make_ir).parameters) + if make_ir_sig_params == 2: ttir_module = src.make_ir(options, context) - else: + elif make_ir_sig_params == 3: codegen_fns = backend.get_codegen_implementation() ttir_module = src.make_ir(options, codegen_fns, context) + else: + codegen_fns = backend.get_codegen_implementation() + module_map = backend.get_module_map() + ttir_module = src.make_ir(options, codegen_fns, module_map, context) if not ttir_module.verify(): raise RuntimeError("Verification for TTIR module has failed") From 8bffb3b465f3e3837cfba1370e7b03f104a23b36 Mon Sep 17 00:00:00 2001 From: PyTorch UpdateBot Date: Sat, 31 Aug 2024 04:16:38 +0000 Subject: [PATCH 0309/1018] [executorch hash update] update the pinned executorch hash (#134894) This PR is auto-generated nightly by [this action](https://github.com/pytorch/pytorch/blob/main/.github/workflows/nightly.yml). Update the pinned executorch hash. Pull Request resolved: https://github.com/pytorch/pytorch/pull/134894 Approved by: https://github.com/pytorchbot --- .ci/docker/ci_commit_pins/executorch.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.ci/docker/ci_commit_pins/executorch.txt b/.ci/docker/ci_commit_pins/executorch.txt index ddcee95fd3ea17..cf11cdd5db311b 100644 --- a/.ci/docker/ci_commit_pins/executorch.txt +++ b/.ci/docker/ci_commit_pins/executorch.txt @@ -1 +1 @@ -7608ab8fa7c37f334fdaa9a946730fdbb1c97d34 +e7e86478b4807d18ec0ac68c91f6d7e5c6d0f14e From ead5cee2be1a7f9d9140473aabca074c375ac541 Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Fri, 30 Aug 2024 15:53:22 -0700 Subject: [PATCH 0310/1018] [Dynamo][autograd.Function][vmap] support torch._C._are_functorch_transforms_active (#134889) Pull Request resolved: https://github.com/pytorch/pytorch/pull/134889 Approved by: https://github.com/jansel --- test/dynamo/test_functions.py | 7 +++++++ torch/_dynamo/variables/builder.py | 1 + 2 files changed, 8 insertions(+) diff --git a/test/dynamo/test_functions.py b/test/dynamo/test_functions.py index 84b3a923c1484e..8ec61ade010fa7 100644 --- a/test/dynamo/test_functions.py +++ b/test/dynamo/test_functions.py @@ -2088,6 +2088,13 @@ def test_in_not_in(x): assert 6 not in myotherlist return sum(mylist) + @make_test + def test_are_functorch_transforms_active(x): + if torch._C._are_functorch_transforms_active(): + return x + 1 + else: + return x - 1 + @make_test def test_partials_udf_kwarg(x): par_mul = functools.partial(udf_mul, y=torch.ones(10, 10)) diff --git a/torch/_dynamo/variables/builder.py b/torch/_dynamo/variables/builder.py index 16eab36028eca3..9bef4e40ca22ab 100644 --- a/torch/_dynamo/variables/builder.py +++ b/torch/_dynamo/variables/builder.py @@ -2305,6 +2305,7 @@ def _clone_input(value): set_example_value(proxy.node, example_value) return SDPAParamsVariable(proxy, **options) elif isinstance(example_value, bool) and proxy.node.target in [ + torch._C._are_functorch_transforms_active, torch.backends.cuda.is_flash_attention_available, torch.backends.cuda.can_use_flash_attention, torch.backends.cuda.can_use_efficient_attention, From df92a5244bd91032c041eaf7b62c289b38ed36b4 Mon Sep 17 00:00:00 2001 From: Xu Han Date: Sat, 31 Aug 2024 04:50:31 +0000 Subject: [PATCH 0311/1018] [inductor] preload icx built in math libs (#134870) Intel Compiler implenmented more math libraries than clang, for performance proposal. We need preload them like openmp library. reproduce UT: ```cmd pytest test/inductor/test_cpu_cpp_wrapper.py -v -k test_silu_cpu_dynamic_shapes_cpp_wrapper ``` Depends of module: Image Local test pass: image Pull Request resolved: https://github.com/pytorch/pytorch/pull/134870 Approved by: https://github.com/jansel --- torch/_inductor/cpp_builder.py | 37 +++++++++++++++++++++++++--------- 1 file changed, 27 insertions(+), 10 deletions(-) diff --git a/torch/_inductor/cpp_builder.py b/torch/_inductor/cpp_builder.py index 693f9b06418285..45aaa2f399daef 100644 --- a/torch/_inductor/cpp_builder.py +++ b/torch/_inductor/cpp_builder.py @@ -832,16 +832,33 @@ def perload_clang_libomp_win(cpp_compiler: str, omp_name: str) -> None: @functools.lru_cache(None) def perload_icx_libomp_win(cpp_compiler: str) -> None: - try: - output = subprocess.check_output( - [cpp_compiler, "-print-file-name=libiomp5md.dll"], stderr=subprocess.DEVNULL - ).decode(*SUBPROCESS_DECODE_ARGS) - omp_path = output.rstrip() - if os.path.isfile(omp_path): - os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE" - omp_module = cdll.LoadLibrary(omp_path) - except subprocess.SubprocessError: - pass + def _load_icx_built_in_lib_by_name(cpp_compiler: str, lib_name: str) -> bool: + try: + output = subprocess.check_output( + [cpp_compiler, f"-print-file-name={lib_name}"], + stderr=subprocess.DEVNULL, + ).decode(*SUBPROCESS_DECODE_ARGS) + omp_path = output.rstrip() + if os.path.isfile(omp_path): + os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE" + omp_module = cdll.LoadLibrary(omp_path) + return True + except subprocess.SubprocessError: + pass + return False + + """ + Intel Compiler implenmented more math libraries than clang, for performance proposal. + We need preload them like openmp library. + """ + preload_list = [ + "libiomp5md.dll", # openmp + "svml_dispmd.dll", # svml library + "libmmd.dll", # libm + ] + + for lib_name in preload_list: + _load_icx_built_in_lib_by_name(cpp_compiler, lib_name) def _get_openmp_args( From 4cb5503436bce7f53af976edb725ecf333bdf8c8 Mon Sep 17 00:00:00 2001 From: cyyever Date: Sat, 31 Aug 2024 08:42:30 +0000 Subject: [PATCH 0312/1018] [jit] Change argument names (#134828) It seems like a bug. Pull Request resolved: https://github.com/pytorch/pytorch/pull/134828 Approved by: https://github.com/janeyx99 Co-authored-by: Jane (Yuan) Xu <31798555+janeyx99@users.noreply.github.com> --- torch/csrc/jit/tensorexpr/external_functions.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torch/csrc/jit/tensorexpr/external_functions.cpp b/torch/csrc/jit/tensorexpr/external_functions.cpp index decfe0bceb3215..5fe00d7959c408 100644 --- a/torch/csrc/jit/tensorexpr/external_functions.cpp +++ b/torch/csrc/jit/tensorexpr/external_functions.cpp @@ -1365,10 +1365,10 @@ void nnc_aten_addmm( const at::Tensor& y = tensors[2]; const at::Tensor& z = tensors[3]; // TODO: handle other alpha and beta dtypes, e.g. alpha=0.6, beta=0.2 - int64_t alpha = extra_args[0], beta = extra_args[1]; + int64_t beta = extra_args[0], alpha = extra_args[1]; try { - at::addmm_out(r, x, y, z, alpha, beta); + at::addmm_out(r, x, y, z, beta, alpha); } catch (...) { } } From 2185dcad0a3f3a07e3ed643b2256816a3c5d077d Mon Sep 17 00:00:00 2001 From: Xuehai Pan Date: Fri, 30 Aug 2024 16:03:08 +0800 Subject: [PATCH 0313/1018] [dynamo] reduce overhead for `PolyfilledFunctionVariable.call_function` (#134842) Pull Request resolved: https://github.com/pytorch/pytorch/pull/134842 Approved by: https://github.com/jansel --- torch/_dynamo/decorators.py | 15 ++-- torch/_dynamo/trace_rules.py | 6 +- torch/_dynamo/variables/builder.py | 6 +- torch/_dynamo/variables/builtin.py | 2 +- torch/_dynamo/variables/functions.py | 103 +++++++++++++++------------ torch/_dynamo/variables/torch.py | 5 +- 6 files changed, 76 insertions(+), 61 deletions(-) diff --git a/torch/_dynamo/decorators.py b/torch/_dynamo/decorators.py index 4e97422fa605e9..235db547dc7b1c 100644 --- a/torch/_dynamo/decorators.py +++ b/torch/_dynamo/decorators.py @@ -3,7 +3,7 @@ import functools import inspect from dataclasses import dataclass -from typing import Any, Callable, TYPE_CHECKING, TypeVar +from typing import Any, Callable, Dict, Type, TYPE_CHECKING, TypeVar import torch from torch.utils._python_dispatch import is_traceable_wrapper_subclass @@ -17,6 +17,8 @@ if TYPE_CHECKING: + from types import FunctionType + from torch._C._dynamo.eval_frame import ( # noqa: F401 reset_code, set_eval_frame, @@ -24,6 +26,8 @@ skip_code, unsupported, ) + + from .variables import VariableTracker else: for name in dir(torch._C._dynamo.eval_frame): if name.startswith("__"): @@ -305,13 +309,14 @@ def sig_ident(sig): "already registered in VariableBuilder's id dispatch map" ) - rule_map = get_torch_obj_rule_map() + rule_map: Dict[Any, Type[VariableTracker]] = get_torch_obj_rule_map() if original_fn in rule_map: raise ValueError( f"Duplicate object {original_fn} with different rules: " f"{PolyfilledFunctionVariable}, {rule_map[original_fn]}" ) + polyfill_handlers: Dict[Callable[..., Any], FunctionType] polyfill_handlers = PolyfilledFunctionVariable._get_polyfill_handlers() if original_fn in polyfill_handlers: raise ValueError( @@ -325,16 +330,16 @@ def sig_ident(sig): def wrapped(*args, **kwargs): return original_fn(*args, **kwargs) - def dispatch_fn(self, value): + def dispatch_fn(self, value: _F) -> PolyfilledFunctionVariable: return PolyfilledFunctionVariable( value, source=self.source, - **self.install_guards(GuardBuilder.CLOSURE_MATCH), + **self.install_guards(GuardBuilder.FUNCTION_MATCH), ) id_dispatch_map[id(original_fn)] = id_dispatch_map[id(wrapped)] = dispatch_fn rule_map[original_fn] = rule_map[wrapped] = PolyfilledFunctionVariable - polyfill_handlers[original_fn] = polyfill_handlers[wrapped] = traceable_fn + polyfill_handlers[original_fn] = polyfill_handlers[wrapped] = wrapped # type: ignore[assignment] wrapped.__torch_dynamo_original__ = original_fn # type: ignore[attr-defined] wrapped.__torch_dynamo_polyfill__ = traceable_fn # type: ignore[attr-defined] diff --git a/torch/_dynamo/trace_rules.py b/torch/_dynamo/trace_rules.py index d5c682afcdfd92..89715d6356eefb 100644 --- a/torch/_dynamo/trace_rules.py +++ b/torch/_dynamo/trace_rules.py @@ -34,7 +34,7 @@ import weakref from collections import defaultdict from pathlib import Path -from typing import Any, Callable, cast, Dict, List, Optional, Set, Union +from typing import Any, Callable, cast, Dict, List, Optional, Set, Type, Union import torch import torch._inductor.test_operators @@ -2847,8 +2847,8 @@ @functools.lru_cache(None) -def get_torch_obj_rule_map(): - d: Dict[Any, VariableTracker] = {} +def get_torch_obj_rule_map() -> Dict[Any, Type["VariableTracker"]]: + d: Dict[Any, Type[VariableTracker]] = {} for m in torch_name_rule_map: for k, v in m.items(): # type: ignore[attr-defined] if ".py#" not in k: diff --git a/torch/_dynamo/variables/builder.py b/torch/_dynamo/variables/builder.py index 9bef4e40ca22ab..fbe6e90d3eea48 100644 --- a/torch/_dynamo/variables/builder.py +++ b/torch/_dynamo/variables/builder.py @@ -19,6 +19,8 @@ import weakref from typing import ( Any, + Callable, + Dict, FrozenSet, List, MutableMapping, @@ -484,7 +486,9 @@ def wrap_jit_function(self, value): @classmethod @functools.lru_cache(None) - def _id_dispatch(cls): + def _id_dispatch( + cls, + ) -> Dict[int, Callable[["VariableBuilder", Any], VariableTracker]]: from ..comptime import comptime entries = [ diff --git a/torch/_dynamo/variables/builtin.py b/torch/_dynamo/variables/builtin.py index 928543b1fe8478..f27beae33cfe59 100644 --- a/torch/_dynamo/variables/builtin.py +++ b/torch/_dynamo/variables/builtin.py @@ -104,7 +104,7 @@ class BuiltinVariable(VariableTracker): @classmethod def create_with_source(cls, value, source): install_guard(source.make_guard(GuardBuilder.BUILTIN_MATCH)) - return BuiltinVariable(value, source=source) + return cls(value, source=source) @staticmethod @functools.lru_cache(None) diff --git a/torch/_dynamo/variables/functions.py b/torch/_dynamo/variables/functions.py index 239b720a5a7530..3f881693c8be62 100644 --- a/torch/_dynamo/variables/functions.py +++ b/torch/_dynamo/variables/functions.py @@ -5,7 +5,7 @@ import inspect import itertools import types -from typing import Dict, List, Optional, TYPE_CHECKING, Union +from typing import Any, Callable, Dict, List, Optional, TYPE_CHECKING, TypeVar, Union import torch @@ -38,6 +38,9 @@ from torch._guards import Source +_F = TypeVar("_F", bound=Callable) + + def wrap_bound_arg(tx: "InstructionTranslator", val, source=None): # Source propagation is best effort since not every object we encounter has a source to begin with. if isinstance(val, VariableTracker): @@ -136,10 +139,7 @@ class UserFunctionVariable(BaseUserFunctionVariable): @classmethod def create_with_source(cls, value, source): install_guard(source.make_guard(GuardBuilder.CLOSURE_MATCH)) - return cls( - value, - source=source, - ) + return cls(value, source=source) def __init__(self, fn, is_constant=False, **kwargs) -> None: super().__init__(**kwargs) @@ -639,10 +639,7 @@ def create_with_source(cls, value, source): # attribute lookup. They are unlikely to be changed, so we can skip # guarding them. install_guard(source.make_guard(GuardBuilder.FUNCTION_MATCH)) - return cls( - value, - source=source, - ) + return cls(value, source=source) @staticmethod @functools.lru_cache(None) @@ -934,24 +931,53 @@ def guard_as_python_constant(self): class PolyfilledFunctionVariable(VariableTracker): _nonvar_fields = { "fn", - *BaseUserFunctionVariable._nonvar_fields, + "wrapped_fn", + "traceable_fn", + *VariableTracker._nonvar_fields, } @classmethod @functools.lru_cache(None) - def _get_polyfill_handlers(cls): + def _get_polyfill_handlers(cls) -> Dict[Callable[..., Any], types.FunctionType]: return {} @classmethod def create_with_source(cls, value, source): - return cls( - value, - source=source, - ) + install_guard(source.make_guard(GuardBuilder.FUNCTION_MATCH)) + + return cls(value, source=source) - def __init__(self, fn: VariableTracker, **kwargs) -> None: + def __init__(self, fn: _F, **kwargs) -> None: super().__init__(**kwargs) - self.fn = fn + self.fn: _F = fn + + handler = self._get_polyfill_handlers().get(fn, fn) + assert callable(handler), f"Polyfill handler {handler} is not callable for {fn}" + for candidate_attr in ( + "__torch_dynamo_polyfill__", # registered polyfill + "__python_implementation__", # self handler from third-party libraries + ): + candidate = getattr(handler, candidate_attr, None) + if candidate: + assert callable(candidate) + traceable_fn = candidate + break + else: + raise RuntimeError( + f"Polyfill handler {handler} does not have a traceable function" + ) + + self.wrapped_fn: _F = handler + self.traceable_fn: _F = traceable_fn + + @property + def polyfill_fn(self) -> _F: + return self.traceable_fn + + def can_constant_fold_through(self): + return getattr( + self.wrapped_fn, "__torch_dynamo_can_constant_fold_through__", False + ) def get_function(self): return self.as_python_constant() @@ -964,36 +990,19 @@ def call_function( ) -> "VariableTracker": from torch._dynamo.variables.builder import SourcelessBuilder - handler = self._get_polyfill_handlers().get(self.fn) - if handler: - assert callable(handler) - if getattr( - handler, "__torch_dynamo_can_constant_fold_through__", False - ) and check_unspec_or_constant_args(args, kwargs): - return ConstantVariable.create( - self.fn( # use the original function which is faster than the polyfill - *[x.as_python_constant() for x in args], - **{k: v.as_python_constant() for k, v in kwargs.items()}, - ) + if self.can_constant_fold_through() and check_unspec_or_constant_args( + args, kwargs + ): + result = ( + self.fn( # use the original function which is faster than the polyfill + *[x.as_python_constant() for x in args], + **{k: v.as_python_constant() for k, v in kwargs.items()}, ) - return SourcelessBuilder.create(tx, handler).call_function(tx, args, kwargs) - - for candidate in ("__torch_dynamo_polyfill__", "__python_implementation__"): - handler = getattr(self.fn, candidate, None) - if handler: - assert callable(handler) - if self.source: - source = AttrSource(self.source, candidate) - return UserFunctionVariable.create_with_source( - handler, - source=source, - ).call_function(tx, args, kwargs) - return SourcelessBuilder.create( - tx, - handler, - ).call_function(tx, args, kwargs) + ) + return SourcelessBuilder.create(tx, result) - return super().call_function(tx, args, kwargs) + traceable_function_variable = SourcelessBuilder.create(tx, self.traceable_fn) + return traceable_function_variable.call_function(tx, args, kwargs) def call_method( self, @@ -1011,8 +1020,8 @@ def call_method( options = {} if self.source: options["source"] = AttrSource(self.source, name) - member_variable = PolyfilledFunctionVariable(method, **options) - return member_variable.call_function(tx, args, kwargs) + polyfilled_method_variable = PolyfilledFunctionVariable(method, **options) + return polyfilled_method_variable.call_function(tx, args, kwargs) def as_python_constant(self): return self.fn diff --git a/torch/_dynamo/variables/torch.py b/torch/_dynamo/variables/torch.py index 1d86c16a6e09ec..7fbd6865026e3a 100644 --- a/torch/_dynamo/variables/torch.py +++ b/torch/_dynamo/variables/torch.py @@ -154,10 +154,7 @@ class BaseTorchVariable(VariableTracker): @classmethod def create_with_source(cls, value, source): install_guard(source.make_guard(GuardBuilder.FUNCTION_MATCH)) - return cls( - value, - source=source, - ) + return cls(value, source=source) def __init__(self, value, **kwargs) -> None: super().__init__(**kwargs) From f9f4436f721be3960b901f2357e041bcc4384a28 Mon Sep 17 00:00:00 2001 From: Xuehai Pan Date: Sat, 31 Aug 2024 07:04:53 +0800 Subject: [PATCH 0314/1018] [dynamo][itertools] refactor `itertools.islice` to use polyfill (#133876) Pull Request resolved: https://github.com/pytorch/pytorch/pull/133876 Approved by: https://github.com/jansel ghstack dependencies: #133864, #133894 --- torch/_dynamo/polyfills/__init__.py | 25 ++--------------------- torch/_dynamo/polyfills/itertools.py | 30 ++++++++++++++++++++++++++++ torch/_dynamo/trace_rules.py | 2 -- torch/_dynamo/variables/builtin.py | 9 --------- torch/_dynamo/variables/iter.py | 6 ------ 5 files changed, 32 insertions(+), 40 deletions(-) diff --git a/torch/_dynamo/polyfills/__init__.py b/torch/_dynamo/polyfills/__init__.py index fd206c3847e6e1..e1ce07413f486f 100644 --- a/torch/_dynamo/polyfills/__init__.py +++ b/torch/_dynamo/polyfills/__init__.py @@ -27,6 +27,8 @@ def index(iterator, item, start=0, end=None): + from itertools import islice + for i, elem in islice(enumerate(iterator), start, end): if item == elem: return i @@ -34,29 +36,6 @@ def index(iterator, item, start=0, end=None): raise ValueError(f"{item} is not in {type(iterator)}") -def islice(iterator, start=0, end=None, step=1): - if start < 0 or (end is not None and end < 0) or step < 0: - raise ValueError("Indices must be non-negative") - if step == 0: - raise ValueError("Step cannot be 0") - - it = iter(iterator) - - for _ in range(start): - next(it) - - if end is None: - for i, element in enumerate(it): - if i % step == 0: - yield element - else: - for i, element in enumerate(it): - if i % step == 0 and i + start < end - start: - yield element - elif i + start >= end - start: - break - - def repeat(item, count): for i in range(count): yield item diff --git a/torch/_dynamo/polyfills/itertools.py b/torch/_dynamo/polyfills/itertools.py index 802a62a82c8cdd..a829d2e9b88081 100644 --- a/torch/_dynamo/polyfills/itertools.py +++ b/torch/_dynamo/polyfills/itertools.py @@ -13,6 +13,7 @@ __all__ = [ "chain", "chain_from_iterable", + "islice", "tee", ] @@ -35,6 +36,35 @@ def chain_from_iterable(iterable: Iterable[Iterable[_T]], /) -> Iterator[_T]: chain.from_iterable = chain_from_iterable # type: ignore[method-assign] +# Reference: https://docs.python.org/3/library/itertools.html#itertools.islice +@substitute_in_graph(itertools.islice, is_embedded_type=True) # type: ignore[arg-type] +def islice(iterable: Iterable[_T], /, *args: int | None) -> Iterator[_T]: + s = slice(*args) + start = 0 if s.start is None else s.start + stop = s.stop + step = 1 if s.step is None else s.step + if start < 0 or (stop is not None and stop < 0) or step <= 0: + raise ValueError( + "Indices for islice() must be None or an integer: 0 <= x <= sys.maxsize.", + ) + + if stop is None: + # TODO: use indices = itertools.count() and merge implementation with the else branch + # when we support infinite iterators + next_i = start + for i, element in enumerate(iterable): + if i == next_i: + yield element + next_i += step + else: + indices = range(max(start, stop)) + next_i = start + for i, element in zip(indices, iterable): + if i == next_i: + yield element + next_i += step + + # Reference: https://docs.python.org/3/library/itertools.html#itertools.tee @substitute_in_graph(itertools.tee) def tee(iterable: Iterable[_T], n: int = 2, /) -> tuple[Iterator[_T], ...]: diff --git a/torch/_dynamo/trace_rules.py b/torch/_dynamo/trace_rules.py index 89715d6356eefb..0740970457de7b 100644 --- a/torch/_dynamo/trace_rules.py +++ b/torch/_dynamo/trace_rules.py @@ -12,7 +12,6 @@ import functools import importlib import inspect -import itertools import linecache import logging import multiprocessing @@ -2993,7 +2992,6 @@ def _builtin_function_ids() -> Dict[int, str]: if not k.startswith("_") and callable(v) } ) - rv.update({id(v): f"itertools.{v.__name__}" for v in (itertools.islice,)}) rv.update( { id(cast): "typing.cast", diff --git a/torch/_dynamo/variables/builtin.py b/torch/_dynamo/variables/builtin.py index f27beae33cfe59..b19bee7e1b2f81 100644 --- a/torch/_dynamo/variables/builtin.py +++ b/torch/_dynamo/variables/builtin.py @@ -1917,15 +1917,6 @@ def call_sorted(self, tx: "InstructionTranslator", obj: VariableTracker, **kwarg ) return variables.ListVariable(items) - def call_islice(self, tx: "InstructionTranslator", iterable, *args): - if iterable.has_unpack_var_sequence(tx) and all( - x.is_python_constant() for x in args - ): - const_args = [x.as_python_constant() for x in args] - items = iterable.unpack_var_sequence(tx) - items = list(itertools.islice(items, *const_args)) - return variables.TupleVariable(items) - # neg is a constant fold function, so we only get here if constant fold is not valid def call_neg(self, tx: "InstructionTranslator", a): if isinstance(a, SymNodeVariable): diff --git a/torch/_dynamo/variables/iter.py b/torch/_dynamo/variables/iter.py index ed3ae786634e92..34c2354026a7bd 100644 --- a/torch/_dynamo/variables/iter.py +++ b/torch/_dynamo/variables/iter.py @@ -166,12 +166,6 @@ def retrieve_const_key(key): from_exc=e, ) return variables.ListIteratorVariable(result, mutable_local=MutableLocal()) - elif self.value is itertools.islice: - from .builder import SourcelessBuilder - - return tx.inline_user_function_return( - SourcelessBuilder.create(tx, polyfills.islice), args, kwargs - ) elif self.value is itertools.repeat: if len(args) < 2: return variables.RepeatIteratorVariable( From 117b42b8e7639ee7e01c675432e5edc7c78bd4b2 Mon Sep 17 00:00:00 2001 From: Michael Lazos Date: Fri, 30 Aug 2024 07:33:52 -0700 Subject: [PATCH 0315/1018] [dynamo] Rewrite foreach_lerp to avoid aten item call (#134166) Context: Adding support for the beta parameters to be tensors Details: In order to add support for the beta params to be tensors without graph breaks in the Adam family of optimizers it is necessary to support foreach_lerp(x, y, s) where s is a scalar tensor. Today, this isn't possible because when `s` is a scalar, internally the aten op calls item() on it to extract the value and distribute it to each of the ops on the individual list indices. To support this in dynamo without graph breaks, I decompose the lerp into its constituent ops which support a scalar tensor in the list argument positions which do not result in an item() call. To be clear the item() call is more performant for eager I think and for BC I don't think we can modify that API, so this allows us to have performance in eager and no graph breaks in compile. Pull Request resolved: https://github.com/pytorch/pytorch/pull/134166 Approved by: https://github.com/anijain2305 --- test/dynamo/test_functions.py | 20 ++++++++++++++++++++ torch/_dynamo/polyfills/__init__.py | 8 ++++++++ torch/_dynamo/variables/torch.py | 11 +++++++++++ 3 files changed, 39 insertions(+) diff --git a/test/dynamo/test_functions.py b/test/dynamo/test_functions.py index 8ec61ade010fa7..ccd8029e738d2c 100644 --- a/test/dynamo/test_functions.py +++ b/test/dynamo/test_functions.py @@ -156,6 +156,26 @@ def test_is_not_null(a, b): if a is not None and b is not None: return a + b + def test_foreach_lerp_(self): + def fn(x, y, s): + return torch._foreach_lerp_(x, y, s) + + cnt = torch._dynamo.testing.CompileCounter() + + fn_opt = torch.compile(backend=cnt, fullgraph=True)(fn) + expected = fn( + [torch.ones(2, 2) * 4.26, torch.ones(2, 2) * 3.14], + [torch.ones(2, 2), torch.ones(2, 2)], + torch.tensor(0.5), + ) + + actual = fn_opt( + [torch.ones(2, 2) * 4.26, torch.ones(2, 2) * 3.14], + [torch.ones(2, 2), torch.ones(2, 2)], + torch.tensor(0.5), + ) + self.assertTrue(same(expected, actual)) + @make_test def test_functools_partial(a, b): return clip01(a + b) diff --git a/torch/_dynamo/polyfills/__init__.py b/torch/_dynamo/polyfills/__init__.py index e1ce07413f486f..8132f73e712f69 100644 --- a/torch/_dynamo/polyfills/__init__.py +++ b/torch/_dynamo/polyfills/__init__.py @@ -148,3 +148,11 @@ def instantiate_user_defined_class_object(cls, /, *args, **kwargs): if isinstance(obj, cls): obj.__init__(*args, **kwargs) return obj + + +def foreach_lerp_inplace(self, end, weight): + # decompose foreach lerp into constituent ops, prevents a graph break due to + # converting a value to a scalar when arg[2] is a single tensor + result = torch._foreach_sub(end, self) + result = torch._foreach_mul(result, weight) + return torch._foreach_add_(self, result) diff --git a/torch/_dynamo/variables/torch.py b/torch/_dynamo/variables/torch.py index 7fbd6865026e3a..43c5dc22d35011 100644 --- a/torch/_dynamo/variables/torch.py +++ b/torch/_dynamo/variables/torch.py @@ -578,6 +578,17 @@ def handle_addcdiv(self, tx: "InstructionTranslator", *args, **kwargs): tx, [args[0], result], {} ) + @register(torch._foreach_lerp_) + def handle_inplace_foreach_lerp_scalar( + self, tx: "InstructionTranslator", *args, **kwargs + ): + if len(args) == 3 and not isinstance(args[2], ListVariable) and not kwargs: + return tx.inline_user_function_return( + SourcelessBuilder.create(tx, polyfills.foreach_lerp_inplace), + args, + kwargs, + ) + @register(torch._assert) def handle_assert(self, tx: "InstructionTranslator", condition, message): if (condition.is_python_constant() and condition.as_python_constant()) or ( From 51d406398ab1c216828d1eda8886bfef5ee23f22 Mon Sep 17 00:00:00 2001 From: Michael Lazos Date: Fri, 30 Aug 2024 07:33:53 -0700 Subject: [PATCH 0316/1018] [dynamo] Rewrite foreach pow to broadcast scalar argument (#134167) Context: Adding support for the beta parameters to be tensors Details: In this PR similarly to the previous, foreach_pow calls item() on the first argument when it is a scalar tensor. In this case, we broadcast that scalar tensor into a list of aliases of that tensor to avoid the item() call, and this results in a device copy of the scalar tensor. Once again, I dont think we can change the foreach_pow API due to BC concerns, so this op rewrite allows us to avoid a graph break, generate semantically the same code, and not affect eager. Pull Request resolved: https://github.com/pytorch/pytorch/pull/134167 Approved by: https://github.com/anijain2305 ghstack dependencies: #134166 --- test/dynamo/test_functions.py | 16 ++++++++++++++++ torch/_dynamo/polyfills/__init__.py | 4 ++++ torch/_dynamo/variables/torch.py | 13 +++++++++++++ 3 files changed, 33 insertions(+) diff --git a/test/dynamo/test_functions.py b/test/dynamo/test_functions.py index ccd8029e738d2c..6a5b782f878302 100644 --- a/test/dynamo/test_functions.py +++ b/test/dynamo/test_functions.py @@ -176,6 +176,22 @@ def fn(x, y, s): ) self.assertTrue(same(expected, actual)) + def test_broadcast_foreach_pow(self): + from torch._dynamo.utils import same + + def fn(x, y): + return torch._foreach_pow(x, y) + + cnt = torch._dynamo.testing.CompileCounter() + + fn_opt = torch.compile(backend=cnt, fullgraph=True)(fn) + inps = (torch.tensor(0.80), [torch.tensor(3.4), torch.tensor(7.8)]) + + actual = fn_opt(*inps) + expected = fn(*inps) + self.assertTrue(same(actual, expected)) + self.assertTrue(cnt.frame_count, 1) + @make_test def test_functools_partial(a, b): return clip01(a + b) diff --git a/torch/_dynamo/polyfills/__init__.py b/torch/_dynamo/polyfills/__init__.py index 8132f73e712f69..04ba7013657e84 100644 --- a/torch/_dynamo/polyfills/__init__.py +++ b/torch/_dynamo/polyfills/__init__.py @@ -156,3 +156,7 @@ def foreach_lerp_inplace(self, end, weight): result = torch._foreach_sub(end, self) result = torch._foreach_mul(result, weight) return torch._foreach_add_(self, result) + + +def foreach_pow_scalar(scalar, exps): + return torch._foreach_pow([scalar for _ in exps], exps) diff --git a/torch/_dynamo/variables/torch.py b/torch/_dynamo/variables/torch.py index 43c5dc22d35011..30dcdb84afffc0 100644 --- a/torch/_dynamo/variables/torch.py +++ b/torch/_dynamo/variables/torch.py @@ -589,6 +589,19 @@ def handle_inplace_foreach_lerp_scalar( kwargs, ) + @register(torch._foreach_pow) + def handle_foreach_pow_scalar( + self, tx: "InstructionTranslator", *args, **kwargs + ): + # In eager it's more performant to call item() from within the C op implementation + # in compile, it's more performant to not graph break. + if len(args) == 2 and isinstance(args[0], TensorVariable) and not kwargs: + return tx.inline_user_function_return( + SourcelessBuilder.create(tx, polyfills.foreach_pow_scalar), + args, + kwargs, + ) + @register(torch._assert) def handle_assert(self, tx: "InstructionTranslator", condition, message): if (condition.is_python_constant() and condition.as_python_constant()) or ( From 71f4c55c073979971cc386747d0927f7fed118e5 Mon Sep 17 00:00:00 2001 From: Michael Lazos Date: Fri, 30 Aug 2024 07:33:53 -0700 Subject: [PATCH 0317/1018] [dynamo] rewrite addcmul_ to remove graph break (#134168) Context: Adding support for the beta parameters to be tensors Details: Similarly to the previous two PRs addcmul_ is used with the tensor betas as the value argument. When this occurs, an item() call is invoked in the aten op. To avoid this graph break, addcmul_ is decomposed into its constrituent ops to avoid this. Pull Request resolved: https://github.com/pytorch/pytorch/pull/134168 Approved by: https://github.com/anijain2305 ghstack dependencies: #134166, #134167 --- test/dynamo/test_functions.py | 22 ++++++++++++++++++++++ torch/_dynamo/polyfills/__init__.py | 4 ++++ torch/_dynamo/variables/tensor.py | 14 ++++++++++++++ 3 files changed, 40 insertions(+) diff --git a/test/dynamo/test_functions.py b/test/dynamo/test_functions.py index 6a5b782f878302..3359b6128f4133 100644 --- a/test/dynamo/test_functions.py +++ b/test/dynamo/test_functions.py @@ -192,6 +192,28 @@ def fn(x, y): self.assertTrue(same(actual, expected)) self.assertTrue(cnt.frame_count, 1) + def test_addcmul_(self): + from copy import deepcopy + + from torch._dynamo.utils import same + + def fn(x, y, z, s): + return x.addcmul_(y, z, value=s) + + cnt = torch._dynamo.testing.CompileCounter() + fn_opt = torch.compile(backend=cnt, fullgraph=True)(fn) + inps = ( + torch.ones(2, 2), + torch.ones(2, 2) + 1, + torch.rand(2, 2), + torch.tensor(0.3), + ) + inps_2 = deepcopy(inps) + actual = fn_opt(*inps) + expected = fn(*inps_2) + self.assertTrue(same(actual, expected)) + self.assertEqual(cnt.frame_count, 1) + @make_test def test_functools_partial(a, b): return clip01(a + b) diff --git a/torch/_dynamo/polyfills/__init__.py b/torch/_dynamo/polyfills/__init__.py index 04ba7013657e84..2b3f38920a0c2a 100644 --- a/torch/_dynamo/polyfills/__init__.py +++ b/torch/_dynamo/polyfills/__init__.py @@ -160,3 +160,7 @@ def foreach_lerp_inplace(self, end, weight): def foreach_pow_scalar(scalar, exps): return torch._foreach_pow([scalar for _ in exps], exps) + + +def addcmul_inplace(self, tensor1, tensor2, value): + return self.add_(tensor1 * tensor2 * value) diff --git a/torch/_dynamo/variables/tensor.py b/torch/_dynamo/variables/tensor.py index ea941890b5d129..5d1a8a6bda799b 100644 --- a/torch/_dynamo/variables/tensor.py +++ b/torch/_dynamo/variables/tensor.py @@ -795,6 +795,20 @@ def method___len__(self): tx = InstructionTranslator.current_tx() return self.call_method(tx, "size", [ConstantVariable.create(0)], {}) + def method_addcmul_(self, tensor1, tensor2, *, value=None): + from ..symbolic_convert import InstructionTranslator + + tx = InstructionTranslator.current_tx() + if value is not None: + from .. import polyfills + from .builder import SourcelessBuilder + + return tx.inline_user_function_return( + SourcelessBuilder.create(tx, polyfills.addcmul_inplace), + [self, tensor1, tensor2, value], + {}, + ) + def method___setitem__(self, key, value): def has_bool_key(v): if isinstance(v, TensorVariable): From 7d89c16b165d201a796ef5725df3ce1d49a09de6 Mon Sep 17 00:00:00 2001 From: Michael Lazos Date: Fri, 30 Aug 2024 07:33:53 -0700 Subject: [PATCH 0318/1018] Update compiled optimizer tests for tensor betas (#134169) Pull Request resolved: https://github.com/pytorch/pytorch/pull/134169 Approved by: https://github.com/anijain2305, https://github.com/eellison ghstack dependencies: #134166, #134167, #134168 --- test/inductor/test_compiled_optimizers.py | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/test/inductor/test_compiled_optimizers.py b/test/inductor/test_compiled_optimizers.py index 2f54dff94cee8a..abd62fbd584b03 100644 --- a/test/inductor/test_compiled_optimizers.py +++ b/test/inductor/test_compiled_optimizers.py @@ -217,7 +217,7 @@ def build_opt_kwarg_db(): has_tensor_lr = False for key, val in kwargs.items(): - if not key == "lr" and ( + if (not key == "lr" and not key == "betas") and ( not isinstance(val, bool) or (isinstance(val, bool) and val) ): name += "_" + key @@ -226,6 +226,9 @@ def build_opt_kwarg_db(): has_tensor_lr = True name += "_tensor_lr" + if key == "betas" and isinstance(kwargs["betas"][0], torch.Tensor): + name += "_tensor_betas" + name += f"_{device}" kwargs["device"] = device @@ -367,6 +370,16 @@ def test_fn(self): kwargs["lr"] = kwargs["lr"].to(device) kwargs_compiled["lr"] = kwargs_compiled["lr"].to(device) + if "betas" in kwargs and isinstance(kwargs["betas"][0], torch.Tensor): + kwargs["betas"] = ( + kwargs["betas"][0].to(device), + kwargs["betas"][1].to(device), + ) + kwargs_compiled["betas"] = ( + kwargs_compiled["betas"][0].to(device), + kwargs_compiled["betas"][1].to(device), + ) + torch._dynamo.reset() torch._inductor.metrics.reset() input = torch.ones([10, 10], device=device) From 56f3e531e01ab168e8466379dac0bc900afb25b8 Mon Sep 17 00:00:00 2001 From: "haozhe.zhu" Date: Thu, 29 Aug 2024 23:32:53 -0400 Subject: [PATCH 0319/1018] [Inductor] Remove VecChecker and fallback non-supported Vec op to Scalar impl with a for loop (#134569) Fall back non-vectorized op by scalar impl + for loop. Example code: ``` cpp_fused_igammac_0 = async_compile.cpp_pybinding(['const double*', 'const double*', 'double*'], ''' #include "/tmp/torchinductor_root/z4/cz4j2mmotlx3z2b7u4fbjtdt4x6plhd67ljwzg5bk7ekv4xz6y7q.h" extern "C" void kernel(const double* in_ptr0, const double* in_ptr1, double* out_ptr0) { { for(int64_t x0=static_cast(0L); x0(48L); x0+=static_cast(8L)) { auto tmp0 = at::vec::VectorizedN::loadu(in_ptr0 + static_cast(x0), 8); auto tmp1 = in_ptr1[static_cast(0L)]; auto tmp2 = at::vec::VectorizedN(tmp1); auto tmp3 = [&]() { __at_align__ std::array tmpbuf0; tmp0.store(tmpbuf0.data(), 8); __at_align__ std::array tmpbuf1; tmp2.store(tmpbuf1.data(), 8); __at_align__ std::array tmpbuf_out; for (int i = 0; i < 8; i++) { tmpbuf_out[i] = calc_igammac(tmpbuf0[i], tmpbuf1[i]); } return at::vec::VectorizedN::loadu(tmpbuf_out.data(), 8); } () ; tmp3.store(out_ptr0 + static_cast(x0), 8); } #pragma omp simd simdlen(4) for(int64_t x0=static_cast(48L); x0(50L); x0+=static_cast(1L)) { auto tmp0 = in_ptr0[static_cast(x0)]; auto tmp1 = in_ptr1[static_cast(0L)]; auto tmp2 = calc_igammac(tmp0, tmp1); out_ptr0[static_cast(x0)] = tmp2; } } } ''') ``` `frexp` are difficult to be handled by common `fallback` since it returns two `cse_var` https://github.com/pytorch/pytorch/blob/2ba60a1618813c91b26ca0860242626b22832295/torch/_inductor/codegen/cpp.py#L752-L766 So we added a special function to do that. ``` cpp_fused_frexp_0 = async_compile.cpp_pybinding(['const double*', 'double*', 'int32_t*'], ''' #include "/tmp/torchinductor_root/z4/cz4j2mmotlx3z2b7u4fbjtdt4x6plhd67ljwzg5bk7ekv4xz6y7q.h" extern "C" void kernel(const double* in_ptr0, double* out_ptr0, int32_t* out_ptr1) { { for(int64_t x0=static_cast(0L); x0(16L); x0+=static_cast(8L)) { auto tmp0 = at::vec::VectorizedN::loadu(in_ptr0 + static_cast(x0), 8); at::vec::Vectorized tmp1; at::vec::VectorizedN tmp2; [&]() { __at_align__ std::array tmpbuf; tmp0.store(tmpbuf.data(), 8); __at_align__ std::array tmpbuf_exponent; __at_align__ std::array tmpbuf_mantissa; for (int i = 0; i < 8; i++) { tmpbuf_mantissa[i] = std::frexp(tmpbuf[i], &tmpbuf_exponent[i]); } tmp1 = at::vec::Vectorized::loadu(tmpbuf_exponent.data(), 8); tmp2 = at::vec::VectorizedN::loadu(tmpbuf_mantissa.data(), 8); } (); tmp2.store(out_ptr0 + static_cast(x0), 8); tmp1.store(out_ptr1 + static_cast(x0), 8); } #pragma omp simd simdlen(4) for(int64_t x0=static_cast(16L); x0(20L); x0+=static_cast(1L)) { auto tmp0 = in_ptr0[static_cast(x0)]; int32_t tmp1; auto tmp2 = std::frexp(tmp0, &tmp1); out_ptr0[static_cast(x0)] = tmp2; out_ptr1[static_cast(x0)] = tmp1; } } } ''') ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/134569 Approved by: https://github.com/jgong5, https://github.com/jansel --- test/inductor/test_cpu_repro.py | 236 +-------------- torch/_inductor/codegen/common.py | 4 +- torch/_inductor/codegen/cpp.py | 429 ++++++++++----------------- torch/_inductor/codegen/cpp_utils.py | 34 +++ 4 files changed, 191 insertions(+), 512 deletions(-) diff --git a/test/inductor/test_cpu_repro.py b/test/inductor/test_cpu_repro.py index 1ce3a731c788d3..7dbe35f766f9a7 100644 --- a/test/inductor/test_cpu_repro.py +++ b/test/inductor/test_cpu_repro.py @@ -10,30 +10,20 @@ from typing import Callable from unittest.mock import patch -import numpy as np -import sympy - import torch from torch import nn from torch._C import FileCheck from torch._dynamo.testing import rand_strided from torch._dynamo.utils import same from torch._inductor import config, cpu_vec_isa, metrics, test_operators -from torch._inductor.codegen.common import OptimizationContext -from torch._inductor.codegen.cpp import ( - CppOverrides, - CppVecKernelChecker, - CppVecOverrides, -) +from torch._inductor.codegen.cpp import CppOverrides, CppVecOverrides from torch._inductor.compile_fx import ( compile_fx, compile_fx_inner, complex_memory_overlap, ) from torch._inductor.graph import GraphLowering -from torch._inductor.ir import InterpreterShim from torch._inductor.utils import timed -from torch._inductor.virtualized import V from torch._prims_common import is_float_dtype from torch.fx.experimental.proxy_tensor import make_fx from torch.nn import functional as F @@ -2375,230 +2365,6 @@ def fn(x): assert metrics.cpp_to_dtype_count == 2 check_metrics_vec_kernel_count(1) - @requires_vectorization - @patch("torch.cuda.is_available", lambda: False) - def test_cpp_vec_constant_checker(self): - _graph: torch.fx.Graph = torch.fx.Graph() - a: torch.fx.Node = _graph.create_node("placeholder", "ops") - iv: torch.fx.Node = _graph.create_node("placeholder", "iv") - fv: torch.fx.Node = _graph.create_node("placeholder", "fv") - b: torch.fx.Node = _graph.create_node( - "call_method", - "constant", - args=( - a, - iv, - torch.int64, - ), - ) - c: torch.fx.Node = _graph.create_node( - "call_method", - "constant", - args=( - a, - fv, - torch.double, - ), - ) - d: torch.fx.Node = _graph.create_node( - "call_method", - "ge", - args=( - a, - b, - b, - ), - ) - _graph.output((d, c)) - - def get_index(): - return "" - - submodules = {"get_index": get_index} - - graph_lowering = GraphLowering( - torch.fx.GraphModule(submodules, _graph), - shape_env=None, - ) - - def set_opt_dtype(graph): - for node in graph.nodes: - if node.target == "constant": - if OptimizationContext.key in node.meta: - opt_ctx = node.meta[OptimizationContext.key] - else: - opt_ctx = OptimizationContext() - opt_ctx.dtype = node.args[-1] - node.meta[OptimizationContext.key] = opt_ctx - - with patch.object(graph_lowering, "wrapper_code", ""), V.set_graph_handler( - graph_lowering - ): - # The moset inner loop variable is used in the index_expr - tiling_factor = cpu_vec_isa.pick_vec_isa().nelements(dtype=torch.float) - with CppVecKernelChecker( - args=None, num_threads=1, tiling_factor=tiling_factor - ) as vec_checker: - i32_iinfo = np.iinfo(np.int32) - f32_iinfo = np.finfo(np.float32) - set_opt_dtype(_graph) - InterpreterShim(_graph, submodules).run( - V.get_ops_handler(), i32_iinfo.max, f32_iinfo.max - ) - self.assertTrue(vec_checker.simd_vec) - - vec_checker.simd_vec = True - set_opt_dtype(_graph) - InterpreterShim(_graph, submodules).run( - V.get_ops_handler(), i32_iinfo.min, f32_iinfo.min - ) - self.assertTrue(vec_checker.simd_vec) - - vec_checker.simd_vec = True - set_opt_dtype(_graph) - InterpreterShim(_graph, submodules).run( - V.get_ops_handler(), i32_iinfo.min, np.inf - ) - self.assertTrue(vec_checker.simd_vec) - - vec_checker.simd_vec = True - set_opt_dtype(_graph) - InterpreterShim(_graph, submodules).run( - V.get_ops_handler(), i32_iinfo.min, -np.inf - ) - self.assertTrue(vec_checker.simd_vec) - - vec_checker.simd_vec = True - set_opt_dtype(_graph) - InterpreterShim(_graph, submodules).run( - V.get_ops_handler(), i32_iinfo.min - 1, f32_iinfo.min - ) - self.assertTrue(vec_checker.simd_vec) - - vec_checker.simd_vec = True - set_opt_dtype(_graph) - InterpreterShim(_graph, submodules).run( - V.get_ops_handler(), i32_iinfo.max + 1, f32_iinfo.max - ) - self.assertTrue(vec_checker.simd_vec) - - vec_checker.simd_vec = True - set_opt_dtype(_graph) - InterpreterShim(_graph, submodules).run( - V.get_ops_handler(), i32_iinfo.min, f32_iinfo.min * (1 + 1e-5) - ) - self.assertTrue(vec_checker.simd_vec) - - vec_checker.simd_vec = True - set_opt_dtype(_graph) - InterpreterShim(_graph, submodules).run( - V.get_ops_handler(), i32_iinfo.max, f32_iinfo.max * (1 + 1e-5) - ) - self.assertTrue(vec_checker.simd_vec) - - @requires_vectorization - @patch("torch.cuda.is_available", lambda: False) - def test_cpp_vec_index_expr_checker(self): - _graph: torch.fx.Graph = torch.fx.Graph() - a: torch.fx.Node = _graph.create_node("placeholder", "ops") - b: torch.fx.Node = _graph.create_node("call_module", "get_index", args=()) - c: torch.fx.Node = _graph.create_node( - "call_method", - "index_expr", - args=( - a, - b, - torch.int64, - ), - ) - d: torch.fx.Node = _graph.create_node( - "call_method", - "ge", - args=( - a, - c, - c, - ), - ) - _graph.output(d) - - def get_index(): - return "" - - submodules = {"get_index": get_index} - graph_lowering = GraphLowering( - torch.fx.GraphModule(submodules, _graph), - shape_env=None, - ) - with patch.object(graph_lowering, "wrapper_code", ""), V.set_graph_handler( - graph_lowering - ): - itervars = [sympy.Symbol("i"), sympy.Symbol("j"), sympy.Symbol("k")] - - tiling_factor = cpu_vec_isa.pick_vec_isa().nelements(dtype=torch.float) - # The most inner loop variable is used in the index_expr - with CppVecKernelChecker( - args=None, num_threads=1, tiling_factor=tiling_factor - ) as vec_checker: - - def get_index(): - return -itervars[0] ** 2 + 2 * itervars[0] + itervars[1] - - ranges = [0, 100, 200] - vec_checker.itervars = itervars[:2] - vec_checker.ranges = ranges[:2] - submodules = {"get_index": get_index} - InterpreterShim(_graph, submodules).run(V.get_ops_handler()) - self.assertTrue(vec_checker.simd_vec) - - # Most inner loop variable irrevalant - with CppVecKernelChecker( - args=None, num_threads=1, tiling_factor=tiling_factor - ) as vec_checker: - - def get_index(): - return -itervars[0] ** 2 + 2 * itervars[0] + itervars[1] - - ranges = [0, 100, 200] - vec_checker.itervars = itervars - vec_checker.ranges = ranges - submodules = {"get_index": get_index} - InterpreterShim(_graph, submodules).run(V.get_ops_handler()) - self.assertTrue(vec_checker.simd_vec) - - i32_iinfo = np.iinfo(np.int32) - _max_value = i32_iinfo.max + 1 - ranges = [_max_value, _max_value, _max_value] - # Most inner loop variable irrevalant but max value is greater than - # the max value of INT32 - with CppVecKernelChecker( - args=None, num_threads=1, tiling_factor=tiling_factor - ) as vec_checker: - - def get_index(): - return itervars[0] - - submodules = {"get_index": get_index} - vec_checker.itervars = itervars - vec_checker.ranges = ranges - InterpreterShim(_graph, submodules).run(V.get_ops_handler()) - self.assertTrue(vec_checker.simd_vec) - - # Most inner loop variable irrevalant but min value is greater than - # the min value of INT32 - with CppVecKernelChecker( - args=None, num_threads=1, tiling_factor=tiling_factor - ) as vec_checker: - - def get_index(): - return -itervars[0] - 2 - - submodules = {"get_index": get_index} - vec_checker.itervars = itervars - vec_checker.ranges = ranges - InterpreterShim(_graph, submodules).run(V.get_ops_handler()) - self.assertTrue(vec_checker.simd_vec) - @requires_vectorization @patch("torch.cuda.is_available", lambda: False) def test_maxpool2d_cpu_only(self): diff --git a/torch/_inductor/codegen/common.py b/torch/_inductor/codegen/common.py index 4f5e65f5a25b25..e27b847d0f4ca3 100644 --- a/torch/_inductor/codegen/common.py +++ b/torch/_inductor/codegen/common.py @@ -271,8 +271,8 @@ def get_device_op_overrides(device: str): @functools.lru_cache(None) def boolean_ops(): return ( - "is_inf", - "is_nan", + "isinf", + "isnan", "logical_not", "signbit", "le", diff --git a/torch/_inductor/codegen/cpp.py b/torch/_inductor/codegen/cpp.py index ab480e066e68c3..6663de7639e8e1 100644 --- a/torch/_inductor/codegen/cpp.py +++ b/torch/_inductor/codegen/cpp.py @@ -3,7 +3,6 @@ import dataclasses import functools import itertools -import logging import math import re import sys @@ -17,13 +16,11 @@ import torch.fx from torch._inductor import dependencies from torch._prims_common import is_float_dtype, is_integer_dtype -from torch.utils import _pytree as pytree from torch.utils._sympy.functions import CeilDiv, FloorDiv, ModularIndexing from torch.utils._sympy.symbol import free_symbol_is_type, symbol_is_type, SymT from ..._dynamo.utils import counters from .. import codecache, config, cpp_builder, cpu_vec_isa, ir, metrics -from ..codegen.wrapper import WrapperCodeGen from ..scheduler import ( BaseSchedulerNode, BaseScheduling, @@ -62,6 +59,8 @@ OptimizationContext, ) from .cpp_utils import ( + _get_dtype_from_loopbodies, + _get_loop_body, cexpr, cexpr_index, codegen_rand, @@ -133,6 +132,26 @@ def get_export_declaration(): torch.float16, ] +VECTORIZABLE_DTYPES: List[torch.dtype] = [ + torch.float64, + torch.float, + torch.bfloat16, + torch.float16, + torch.bool, + torch.uint8, + torch.int8, + torch.int32, + torch.int64, +] + +MASKED_VECTORIZABLE_DTYPES: List[torch.dtype] = [ + torch.float, + torch.bfloat16, + torch.float16, + torch.uint8, + torch.int8, +] + def reduction_init(reduction_type, dtype): if dtype in DTYPE_LOWP_FP: @@ -1012,6 +1031,7 @@ def wrapper(*args, **kwargs): "index_expr", ]: setattr(self, name, wrap(method.__func__)) + return self @staticmethod @@ -1540,8 +1560,118 @@ def index_expr(expr, dtype): csevar.update_on_args("index_expr", (expr, dtype), {}) return csevar + @staticmethod + def frexp(x): + cache_keys = f"frexp({x})[0]", f"frexp({x})[1]" + if all(cache_key in V.kernel.cse.cache for cache_key in cache_keys): + return tuple(V.kernel.cse.cache[cache_key] for cache_key in cache_keys) + + cdtype = DTYPE_TO_CPP[x.dtype] + size = V.kernel.tail_size if V.kernel.tail_size else V.kernel.tiling_factor + code = BracesBuffer() + exponent = V.kernel.cse.newvar() + mantissa = V.kernel.cse.newvar() + exponent.update_on_args("frexp", (x,), kwargs={}) + mantissa.update_on_args("frexp", (x,), kwargs={}) + n_vec = V.kernel._get_num_vectors(x.dtype) + mantissa_t = ( + f"at::vec::Vectorized<{cdtype}>" + if n_vec == 1 + else f"at::vec::VectorizedN<{cdtype}, {n_vec}>" + ) + code.writeline(f"at::vec::Vectorized {exponent};") + code.writeline(f"{mantissa_t} {mantissa};") + code.writeline("[&]()") + with code.indent(): + code.writeline(f"__at_align__ std::array<{cdtype}, {size}> tmpbuf;") + code.writeline(f"{x}.store(tmpbuf.data(), {size});") + code.writeline(f"__at_align__ std::array tmpbuf_exponent;") + code.writeline( + f"__at_align__ std::array<{cdtype}, {size}> tmpbuf_mantissa;" + ) + code.writeline(f"for (int i = 0; i < {size}; i++)") + with code.indent(): + code.writeline( + "tmpbuf_mantissa[i] = std::frexp(tmpbuf[i], &tmpbuf_exponent[i]);" + ) + code.writeline( + f"{exponent} = at::vec::Vectorized::loadu(tmpbuf_exponent.data(), {size});" + ) + code.writeline( + f"{mantissa} = {mantissa_t}::loadu(tmpbuf_mantissa.data(), {size});" + ) + code.writeline("();") + V.kernel.compute.splice(code) + cse_vars = (mantissa, exponent) + for cache_key, cse_var in zip(cache_keys, cse_vars): + V.kernel.cse.cache[cache_key] = cse_var + return mantissa, exponent + + @classmethod + def scalarize(cls, scalar_func): + def inner(*args, **kwargs): + assert not kwargs + kernel = V.kernel + assert isinstance(kernel, CppVecKernel) + code = BracesBuffer() + code.writeline("[&]()") + vec_dtype = args[0].dtype + n_vec = kernel._get_num_vectors(vec_dtype) + size = kernel.tail_size if kernel.tail_size else kernel.tiling_factor + scalar_args = [] + cdtype = DTYPE_TO_CPP[vec_dtype] + output_mask = scalar_func.__name__ in ( + "isinf", + "isnan", + "signbit", + ) + octype = "bool" if output_mask else cdtype + with code.indent(): + for argidx, arg in enumerate(args): + if isinstance(arg, CppCSEVariable): + assert arg.is_vec + assert arg.dtype == vec_dtype + code.writeline( + f"__at_align__ std::array<{cdtype}, {size}> tmpbuf{argidx};" + ) + code.writeline(f"{arg}.store(tmpbuf{argidx}.data(), {size});") + scalar_args.append(f"tmpbuf{argidx}[i]") + else: + scalar_args.append(arg) + code.writeline(f"__at_align__ std::array<{octype}, {size}> tmpbuf_out;") + res = scalar_func(*scalar_args) + code.writeline(f"for (int i = 0; i < {size}; i++)") + with code.indent(): + code.writeline(f"tmpbuf_out[i] = {res};") + if output_mask: + assert not kernel.tail_size + load_args = "tmpbuf_out.data()" + load_fn = f"at::vec::VecMask<{cdtype},{n_vec}>::from" + else: + load_args = f"tmpbuf_out.data(), {size}" + if n_vec == 1: + load_fn = f"at::vec::Vectorized<{cdtype}>::loadu" + else: + load_fn = f" at::vec::VectorizedN<{cdtype}, {n_vec}>::loadu" + code.writeline(f"return {load_fn}({load_args});") + code.writeline("()") + return code + + return inner + + @classmethod + def _initialize_scalarize(cls): + for name, method in vars(CppOverrides).items(): + if getattr(method, "__class__", None) == staticmethod and name not in vars( + CppVecOverrides + ): + func = cls.scalarize(method.__func__) + func.__name__ = name + setattr(cls, name, staticmethod(func)) + CppVecOverrides._initialize_pointwise_overrides("cppvec") +CppVecOverrides._initialize_scalarize() class CppTile2DOverrides(CppVecOverrides): @@ -2554,9 +2684,11 @@ def store_reduction(self, name, index, value): # Vertical reduction if out_dtype != dtype: converted_value = f"{DTYPE_TO_CPP[out_dtype]}_{value}" - code.writeline( - f"auto {converted_value} = at::vec::convert<{DTYPE_TO_CPP[out_dtype]}>({value});" - ) + if out_dtype == torch.bool: + convert = f"{value}.template cast()" + else: + convert = f"at::vec::convert<{DTYPE_TO_CPP[out_dtype]}>({value})" + code.writeline(f"auto {converted_value} = {convert};") value = converted_value code.splice(self._get_store_line(value, var, index, out_dtype)) self.reduction_suffix.splice(code.map(lambda x: DeferredLine(name, x))) @@ -2995,235 +3127,6 @@ def transform_indexing(self, index: sympy.Expr) -> sympy.Expr: ) -class CppVecKernelChecker(CppVecKernel): - def __init__(self, args, num_threads, tiling_factor, tiling_idx=-1): - super().__init__(args, num_threads, tiling_factor, tiling_idx) - - # Since this kernel is only for checker but does not generate any - # code, so we need to decrease the kernel count. - metrics.generated_kernel_count -= 1 - - # Used to record the graph wrapper code as the wrapper_code status could be - # changed during graph run. - self._orig_wrapper_code = None - - self.simd_vec = True - - self.simd_masked_vec = True - - self.fast_vec_list = [] - for k, v in CppVecOverrides.__dict__.items(): - if isinstance(v, staticmethod): - self.fast_vec_list.append(k) - self.exit_stack = contextlib.ExitStack() - - # Cache all the load result - self.supported_dtypes: List[torch.dtype] = [ - torch.float64, - torch.float, - torch.bfloat16, - torch.float16, - torch.bool, - torch.uint8, - torch.int8, - torch.int32, - torch.int64, - ] - - # TODO: remove it after all data types support masked vectorization. - self.supported_dtypes_for_masked_vec: List[torch.dtype] = [ - torch.float, - torch.bfloat16, - torch.float16, - torch.uint8, - torch.int8, - ] - - def disable_vec(self, msg=None): - if schedule_log.isEnabledFor(logging.DEBUG): - schedule_log.debug("Disabled vectorization: %s", msg) - self.simd_vec = False - self.simd_masked_vec = False - - def disable_masked_vec(self, msg=None): - if schedule_log.isEnabledFor(logging.DEBUG): - schedule_log.debug("Disabled masked vectorization: %s", msg) - self.simd_masked_vec = False - - def load(self, name: str, index: sympy.Expr): - with RecordOptimizationContext(__name__) as node_ctx: - load_dtype = V.graph.get_dtype(name) - opt_ctx: OptimizationContext = node_ctx.get_opt_ctx() - assert opt_ctx - - opt_ctx.dtype = load_dtype - var = self.cse.newvar() - - if load_dtype not in self.supported_dtypes_for_masked_vec: - self.disable_masked_vec( - f"{load_dtype} not supported by masked vectorization" - ) - - if has_free_symbols(self.ranges): - self.disable_masked_vec("Symbolic ranges not supported by masked load") - - if len(self.itervars) == 0: - self.disable_vec("not a loop") - return var - - if load_dtype not in self.supported_dtypes and ( - index.has(self.itervars[self.tiling_idx]) - or free_symbol_is_type(index, SymT.TMP) - ): - self.disable_vec(f"{load_dtype} not supported by load") - return var - - return var - - def store(self, name, index, value, mode=None): - with RecordOptimizationContext(__name__) as node_ctx: - store_dtype = V.graph.get_dtype(name) - - if store_dtype not in self.supported_dtypes_for_masked_vec: - self.disable_masked_vec( - f"{store_dtype} not supported by masked vectorization" - ) - - if has_free_symbols(self.ranges): - self.disable_masked_vec("Symbolic ranges not supported by masked store") - - if len(self.itervars) == 0: - self.disable_vec("not a loop") - return self.simd_vec - - opt_ctx: OptimizationContext = node_ctx.get_opt_ctx() - assert opt_ctx - opt_ctx.dtype = store_dtype - - if store_dtype not in self.supported_dtypes: - self.disable_vec(f"{store_dtype} not supported by store") - return self.simd_vec - - assert "buf" in name - index = self.rename_indexing(index) - - return self.simd_vec - - def reduction(self, dtype, src_dtype, reduction_type, value): - if has_free_symbols(self.ranges): - self.disable_masked_vec("Symbolic ranges not supported by masked reduction") - - if reduction_type not in VECTORIZABLE_RTYPES: - self.disable_vec( - f"reduction: dtype {dtype}, src_dtype {src_dtype}, reduction_type {reduction_type}" - ) - if is_welford_reduction(reduction_type): - return tuple([self.simd_vec] * 3) - return self.simd_vec - - def check_bounds( - self, expr: sympy.Expr, size: sympy.Expr, lower: bool, upper: bool - ): - return self.simd_vec - - def store_reduction(self, name, index, value): - return self.simd_vec - - def __exit__(self, exc_type, exc_val, exc_tb): - # Restore the wrapper_code - V.graph.wrapper_code = self._orig_wrapper_code # type: ignore[assignment] - self.exit_stack.__exit__(exc_type, exc_val, exc_tb) - - def __enter__(self): - # Record the graph wrapper code. The wrapper_code status could be - # changed during graph run. Regarding this checker, we also need to - # run the graph but we don't expect to change any status that would - # impact the code generation. Hence, we record the graph wrapper code - # and replace it with a dummy wrapper_code and then restore to the - # original one as long as the checker is finished. - self._orig_wrapper_code = V.graph.wrapper_code - V.graph.wrapper_code = WrapperCodeGen() - - parent_handler = V.MockHandler() - - class VecCheckerProxy: - @staticmethod - def __getattr__(name): # type: ignore[misc] - def inner(*args, **kwargs): - if name not in self.fast_vec_list: - self.disable_vec(f"op: {name}") - - parent_val = getattr(parent_handler, name)(*args, **kwargs) - return pytree.tree_map(lambda _: self.simd_vec, parent_val) - - return inner - - @staticmethod - def load(name: str, index: sympy.Expr): - return self.load(name, index) - - @staticmethod - def store(name, index, value, mode=None): - return self.store(name, index, value, mode=mode) - - @staticmethod - def reduction(dtype, src_dtype, reduction_type, value): - return self.reduction(dtype, src_dtype, reduction_type, value) - - @staticmethod - def store_reduction(name, index, value): - return self.store_reduction(name, index, value) - - @staticmethod - def check_bounds( - expr: sympy.Expr, size: sympy.Expr, lower: bool, upper: bool - ): - return self.check_bounds(expr, size, lower, upper) - - @staticmethod - def constant(val, dtype): - with RecordOptimizationContext(__name__) as node_ctx: - opt_ctx: OptimizationContext = node_ctx.get_opt_ctx() - assert opt_ctx - - if opt_ctx.dtype not in self.supported_dtypes_for_masked_vec: - self.disable_masked_vec( - f"{opt_ctx.dtype} not supported by masked vectorization" - ) - - if opt_ctx.dtype not in self.supported_dtypes: - self.disable_vec(f"constant dtype: {opt_ctx.dtype}") - return val - - @staticmethod - def index_expr(expr, dtype): - return self.cse.newvar() - - @staticmethod - def indirect_indexing(index_var, size, check=True, wrap_neg=True): - return sympy_index_symbol(str(index_var)) - - @staticmethod - def masked(mask, body, other): - body() - return self.cse.newvar() - - @staticmethod - def to_dtype(x, dtype, src_dtype=None, use_compute_types=True): - if dtype not in self.supported_dtypes_for_masked_vec: - self.disable_masked_vec( - f"{dtype} not supported by masked vectorization" - ) - - if dtype not in self.supported_dtypes: - self.disable_vec(f"to_dtype: {dtype}") - return x - - self.exit_stack.enter_context(V.set_ops_handler(VecCheckerProxy())) - self.exit_stack.enter_context(V.set_kernel_handler(self)) - return self - - def get_loop_body_lowp_fp(_body: ir.LoopBody) -> Optional[torch.dtype]: """ Returns the low precision float data type (torch.float16/torch.bfloat16) if all the @@ -3283,24 +3186,11 @@ def select_tiling( var_sizes_list, ) -> Tuple[List[int], List[int]]: # TODO(jgong5): support alternative tiling factors and data types - loop_bodies = None - if all(isinstance(fn, ir.LoopBody) for fn in fn_list): - loop_bodies = fn_list - else: - if hasattr(fn_list[0], "original_fn"): - # For the case of local buffer, we wrap the fn with localize_function - assert all(hasattr(fn, "original_fn") for fn in fn_list) - assert all( - isinstance(fn.original_fn.args[0]._body, ir.LoopBody) - for fn in fn_list - ) - loop_bodies = [fn.original_fn.args[0]._body for fn in fn_list] - else: - assert all(isinstance(fn, functools.partial) for fn in fn_list) - assert all(isinstance(fn.args[0]._body, ir.LoopBody) for fn in fn_list) - loop_bodies = [fn.args[0]._body for fn in fn_list] - assert loop_bodies is not None - + loop_bodies = _get_loop_body(fn_list) + all_dtypes = _get_dtype_from_loopbodies(loop_bodies) + assert all_dtypes + if any(dtype not in VECTORIZABLE_DTYPES for dtype in all_dtypes): + return [], [] dtype = torch.float _lowp_fp_dtype = get_loop_body_lowp_fp(loop_bodies[0]) if _lowp_fp_dtype and all( @@ -3730,6 +3620,10 @@ def run(kernel): if not self.picked_vec_isa: return + if not self.itervars: + # not a loop + return + # Kernels share the same global contexts like V.graph.wrapper_code, V.kernel.args. # But the generated scalar kernel has updated these global contexts. Hence, the other kernels # should not do this again to avoid context conflict. By now, we only control the @@ -3741,29 +3635,14 @@ def run(kernel): ) assert len(tiling_factors) == len(tiling_indices) # This should be removed after full support for vectorization is implemented. - could_vec = True could_masked_vec = True - if tiling_factors and tiling_indices: - tiling_factor = tiling_factors[0] - assert all( - _tiling_factor == tiling_factor for _tiling_factor in tiling_factors - ) - for tiling_indice in tiling_indices: - with CppVecKernelChecker( - deepcopy(self.kernel_group.args), - parallel_num_threads(), - tiling_factor, - tiling_indice, - ) as vec_checker: - run(vec_checker) - could_vec = could_vec and vec_checker.simd_vec - could_masked_vec = ( - could_masked_vec and vec_checker.simd_masked_vec - ) - if not could_vec: - tiling_factors = [] - tiling_indices = [] - break + all_dtypes = _get_dtype_from_loopbodies(_get_loop_body(fn_list)) + if any(dtype not in MASKED_VECTORIZABLE_DTYPES for dtype in all_dtypes): + # can be removed after masked vectorizable dtype are same with vectorizable dtype + could_masked_vec = False + if has_free_symbols(self.ranges): + # can be removed after masked vec support dynamic ranges + could_masked_vec = False if len(tiling_indices) == 1: vec_kernel = codegen_kernel( diff --git a/torch/_inductor/codegen/cpp_utils.py b/torch/_inductor/codegen/cpp_utils.py index 8dfa3489e7673d..46ec1ac9644891 100644 --- a/torch/_inductor/codegen/cpp_utils.py +++ b/torch/_inductor/codegen/cpp_utils.py @@ -880,3 +880,37 @@ def inner_fn(index): inner_fn=inner_fn, ranges=input_buffer.get_size(), ) + + +def _get_loop_body(fn_list): + loop_bodies = None + if all(isinstance(fn, ir.LoopBody) for fn in fn_list): + loop_bodies = fn_list + else: + if hasattr(fn_list[0], "original_fn"): + # For the case of local buffer, we wrap the fn with localize_function + assert all(hasattr(fn, "original_fn") for fn in fn_list) + assert all( + isinstance(fn.original_fn.args[0]._body, ir.LoopBody) for fn in fn_list + ) + loop_bodies = [fn.original_fn.args[0]._body for fn in fn_list] + else: + assert all(isinstance(fn, functools.partial) for fn in fn_list) + assert all(isinstance(fn.args[0]._body, ir.LoopBody) for fn in fn_list) + loop_bodies = [fn.args[0]._body for fn in fn_list] + assert loop_bodies is not None + return loop_bodies + + +def _get_dtype_from_loopbodies(loop_bodies): + dtypes = set() + for loop_body in loop_bodies: + graphs = [loop_body.root_block.graph] + [ + body.graph for body in list(loop_body.subblocks.values()) + ] + for graph in graphs: + for node in graph.nodes: + if node.op != "call_method": + continue + dtypes.add(node.meta[OptimizationContext.key].dtype) + return dtypes From 40b0c844502e7312bb82c7f5976a6d25b3a991ca Mon Sep 17 00:00:00 2001 From: Xu Han Date: Sat, 31 Aug 2024 18:08:11 +0000 Subject: [PATCH 0320/1018] [inductor] move test_fuse_large_params to slow test. (#134900) Move `test_fuse_large_params` to slow test. This case spend about 1.5 minutes. image Pull Request resolved: https://github.com/pytorch/pytorch/pull/134900 Approved by: https://github.com/jansel --- test/inductor/test_torchinductor.py | 1 + 1 file changed, 1 insertion(+) diff --git a/test/inductor/test_torchinductor.py b/test/inductor/test_torchinductor.py index b354ff2b1ac903..3127fae65976c8 100644 --- a/test/inductor/test_torchinductor.py +++ b/test/inductor/test_torchinductor.py @@ -10874,6 +10874,7 @@ def fn(x, y): fn(a, b) # Skipped on ROCm until https://github.com/ROCm/triton/issues/443 resolved + @slowTest def test_fuse_large_params(self): def pt2_optimizer_step(optimizer): @torch.compile() From 8d3a8679b4caf62bdbb32639949436282c3e7d1c Mon Sep 17 00:00:00 2001 From: Aaron Orenstein Date: Sat, 31 Aug 2024 20:18:59 +0000 Subject: [PATCH 0321/1018] Clean up RemoteCache classes (#134032) Summary: The existing RemoteCacheBackend classes were a bit haphazard - some of them accepted bytes only, some accepted objects, some returned different types of objects than were passed in. Update them to be more consistent: 1. RemoteCacheBackend is an implementation of a backend: Redis, Memcache, Manifold, LocalFile 2. RemoteCacheSerde is an implementation of a serde protocol - to turn structured objects (dict, list, etc) into bytes: RemoteCacheJsonSerde (json encoding), RemoteCachePassthroughSerde (strictly bytes only) 3. RemoteCache is the cache implementation itself, mixing a RemoteCacheBackend along with an RemoteCacheSerde to provide structured caching. Other than simply reorganizing the existing cache code this also fixes the Redis autotune caching for OSS. Test Plan: unit tests Reviewed By: oulgen Differential Revision: D61178859 Pull Request resolved: https://github.com/pytorch/pytorch/pull/134032 Approved by: https://github.com/oulgen, https://github.com/bhack --- test/inductor/mock_cache.py | 281 +++++++----------- test/inductor/test_codecache.py | 34 +-- test/inductor/test_max_autotune.py | 30 +- torch/_VF.py | 6 +- .../_aot_autograd/autograd_cache.py | 2 +- torch/_inductor/codecache.py | 70 +++-- torch/_inductor/remote_cache.py | 173 ++++++++++- torch/_inductor/runtime/triton_heuristics.py | 25 +- 8 files changed, 355 insertions(+), 266 deletions(-) diff --git a/test/inductor/mock_cache.py b/test/inductor/mock_cache.py index b5effefc1fa3af..8db9cc2ba733c0 100644 --- a/test/inductor/mock_cache.py +++ b/test/inductor/mock_cache.py @@ -1,127 +1,110 @@ # Owner(s): ["module: inductor"] +from __future__ import annotations + import contextlib import dataclasses import sys import threading -import unittest.mock -from types import TracebackType -from typing import Callable, Generator, Optional, Tuple, Type, Union +from typing import Any, Callable, Dict, Generator, Optional, Type, TYPE_CHECKING from typing_extensions import override, Self +from unittest.mock import patch import torch from torch._inductor import config from torch._inductor.remote_cache import RemoteCacheBackend -# The cache state is thread-local so if we're running multiple tests at once -# they won't cross contaminate. However - it needs to be "global" because we -# allow code to create new cache clients which refer to the same cache (because -# it's a remote cache). -class _MockCacheState(threading.local): - def __init__(self, name: str): - self.reset() - self._name = name - self._cache = {} - self._clients = {} # Used for Manifold +if TYPE_CHECKING: + from types import TracebackType + - def reset(self): - self.num_init = 0 +@dataclasses.dataclass +class Stats: + num_put: int = 0 + num_get_hit: int = 0 + num_get_miss: int = 0 + + def __iadd__(self, other: Stats) -> Self: + self.num_put += other.num_put + self.num_get_hit += other.num_get_hit + self.num_get_miss += other.num_get_miss + return self + + def reset(self) -> None: self.num_put = 0 self.num_get_hit = 0 self.num_get_miss = 0 - def report(self): - print( - "".join( - [ - f"{self._name} cache: ", - f"init: {self.num_init}, ", - f"puts: {self.num_put}, ", - f"misses: {self.num_get_miss}, ", - f"hits: {self.num_get_hit}, ", - ] - ), - file=sys.stderr, + def __str__(self) -> str: + return "".join( + ( + f"puts: {self.num_put}, ", + f"misses: {self.num_get_miss}, ", + f"hits: {self.num_get_hit}, ", + ) ) -class _MockLocalAutotuneCacheBackend(RemoteCacheBackend): - _state = _MockCacheState("Local") +# The cache states are thread-local so if we're running multiple tests at once +# they won't cross contaminate. However - it needs to be "global" because we +# allow code to create new cache clients which refer to the same cache (because +# it's a remote cache). - def __init__(self): - state = self._state - state.num_init += 1 - @override - def get(self, key: str) -> Optional[bytes]: - assert isinstance(key, str) +class _GlobalStats(Stats, threading.local): + def __init__(self) -> None: + self.autotune = Stats() + self.fx_graph = Stats() + self.triton = Stats() - state = self._state - if key in state._cache: - state.num_get_hit += 1 - return state._cache[key] - else: - state.num_get_miss += 1 + def reset(self) -> None: + self.autotune.reset() + self.fx_graph.reset() + self.triton.reset() - @override - def put(self, key: str, data: bytes) -> None: - assert isinstance(key, str) - assert isinstance(data, bytes) + def update(self, name: str, delta: Stats) -> None: + stat = getattr(self, name) + stat += delta - state = self._state - state.num_put += 1 - state._cache[key] = data + def report(self): + print("Cache Stats:", file=sys.stderr) + print(f" autotune: {self.autotune}", file=sys.stderr) + print(f" fx_graph: {self.fx_graph}", file=sys.stderr) + print(f" triton: {self.triton}", file=sys.stderr) -class _MockRedisRemoteCache: - _state = _MockCacheState("Redis") +global_stats = _GlobalStats() - def __init__(self, *args, **kwargs): - state = self._state - state.num_init += 1 - def get(self, key: Union[bytes, str]) -> Optional[Union[bytes, str, int, float]]: - assert isinstance(key, (bytes, str)) +class MockBackend(RemoteCacheBackend[Any]): + def __init__(self, name: str, cache: Dict[str, object]) -> None: + self._cache = cache + self._name = name - state = self._state + @staticmethod + def with_name(name: str) -> Callable[[], MockBackend]: + cache = {} - if key in state._cache: - state.num_get_hit += 1 - else: - state.num_get_miss += 1 - return state._cache.get(key) + def wrapper() -> MockBackend: + return MockBackend(name, cache) - def set(self, key: Union[bytes, str], data: Union[bytes, str, int, float]) -> None: - assert isinstance(key, (bytes, str)) - assert isinstance(data, (bytes, str, int, float)), type(data) + return wrapper - state = self._state + @override + def get(self, key: str) -> Optional[Any]: + if key in self._cache: + global_stats.update(self._name, Stats(num_get_hit=1)) + return self._cache.get(key) + else: + global_stats.update(self._name, Stats(num_get_miss=1)) + return None - # According to https://redis-py.readthedocs.io/en/stable/commands.html#redis.commands.core.CoreCommands.set - # redis accepts Union[bytes, memoryview, str, int, float] - state.num_put += 1 - state._cache[key] = data + @override + def put(self, key: str, data: Any) -> None: + global_stats.update(self._name, Stats(num_put=1)) + self._cache[key] = data -@dataclasses.dataclass -class CacheDecl: - qname: str - cls: Type[object] - f: Optional[Callable[..., object]] = None - - def patch(self) -> contextlib.AbstractContextManager: - return unittest.mock.patch(self.qname, self.f or self.cls) - - -_CACHES = ( - CacheDecl( - "torch._inductor.runtime.triton_heuristics.LocalAutotuneCache", - _MockLocalAutotuneCacheBackend, - ), - # This causes any mocking test to require 'redis'. - CacheDecl("redis.Redis", _MockRedisRemoteCache), -) - # List of configs for each cache _CACHE_CONFIG_EN = ( "fx_graph_cache", @@ -133,52 +116,6 @@ def patch(self) -> contextlib.AbstractContextManager: class PatchCaches(contextlib.AbstractContextManager): - num_init = 0 - num_put = 0 - num_get_miss = 0 - num_get_hit = 0 - _savedCacheState = {} - - @staticmethod - def get_caches() -> Tuple[CacheDecl, ...]: - if config.is_fbcode(): - from .fb.mock_cache import FB_CACHES - - return _CACHES + FB_CACHES - else: - return _CACHES - - def __init__(self): - self._contexts = [] - for decl in self.get_caches(): - self._contexts.append(decl.patch()) - - @classmethod - def reset(cls): - """ - Reset the patched cache states as well as the PatchCaches - aggregation. - """ - cls.num_init = 0 - cls.num_put = 0 - cls.num_get_miss = 0 - cls.num_get_hit = 0 - - for decl in cls.get_caches(): - decl.cls._state.reset() - - @classmethod - def update(cls): - """ - Update PatchCaches' state with the values from all the patched caches. - """ - cls.num_init = sum(decl.cls._state.num_init for decl in cls.get_caches()) - cls.num_put = sum(decl.cls._state.num_put for decl in cls.get_caches()) - cls.num_get_miss = sum( - decl.cls._state.num_get_miss for decl in cls.get_caches() - ) - cls.num_get_hit = sum(decl.cls._state.num_get_hit for decl in cls.get_caches()) - @classmethod def setUp(cls): # If this test is using PatchCaches then disable all the caches by @@ -190,50 +127,52 @@ def setUp(cls): cls._savedCacheState[name] = getattr(config, name) setattr(config, name, False) - for decl in cls.get_caches(): - if hasattr(decl.cls, "setUp"): - decl.cls.setUp() - @classmethod def tearDown(cls): - for decl in cls.get_caches()[::-1]: - if hasattr(decl.cls, "tearDown"): - decl.cls.tearDown() - # Restore cache defaults for name in _CACHE_CONFIG_EN: delattr(config, name) if name in cls._savedCacheState: setattr(config, name, cls._savedCacheState[name]) - @classmethod - def report(cls): - """ - Report cache state for all patched caches. - """ - for decl in cls.get_caches(): - decl.cls._state.report() - print( - "".join( - [ - "All caches: ", - f"init: {cls.num_init}, ", - f"puts: {cls.num_put}, ", - f"misses: {cls.num_get_miss}, ", - f"hits: {cls.num_get_hit}", - ] - ), - file=sys.stderr, - ) + def __init__(self) -> None: + self._stack = contextlib.ExitStack() def __enter__(self) -> Self: - """ - Start mocking the patched caches. - """ - self.reset() + global_stats.reset() + self._stack.__enter__() + + ctx = patch( + "torch._inductor.remote_cache.RemoteAutotuneCache.backend_override_cls", + MockBackend.with_name("autotune"), + ) + self._stack.enter_context(ctx) + + ctx = patch( + "torch._inductor.remote_cache.RemoteFxGraphCache.backend_override_cls", + MockBackend.with_name("fx_graph"), + ) + self._stack.enter_context(ctx) + + if config.is_fbcode(): + ctx = patch( + "torch._inductor.fb.remote_cache.FbRemoteAutotuneCache.backend_override_cls", + MockBackend.with_name("autotune"), + ) + self._stack.enter_context(ctx) + + ctx = patch( + "torch._inductor.fb.remote_cache.FbRemoteFxGraphCache.backend_override_cls", + MockBackend.with_name("fx_graph"), + ) + self._stack.enter_context(ctx) + + ctx = patch( + "triton.fb.fb_memcache.FbMemcacheRemoteKernelCache.backend_override_cls", + MockBackend.with_name("triton"), + ) + self._stack.enter_context(ctx) - for ctx in self._contexts: - ctx.__enter__() return self def __exit__( @@ -242,13 +181,7 @@ def __exit__( exc_value: Optional[BaseException], traceback: Optional[TracebackType], ) -> None: - """ - Stop mocking the patched caches. - """ - for ctx in self._contexts[::-1]: - ctx.__exit__(exc_type, exc_value, traceback) - - self.update() + self._stack.__exit__(exc_type, exc_value, traceback) @contextlib.contextmanager diff --git a/test/inductor/test_codecache.py b/test/inductor/test_codecache.py index 4b7f1039611d36..516d16fba7869d 100644 --- a/test/inductor/test_codecache.py +++ b/test/inductor/test_codecache.py @@ -41,9 +41,9 @@ try: - from .mock_cache import patch_fbcode, PatchCaches + from .mock_cache import global_stats, patch_fbcode, PatchCaches except ImportError: - from mock_cache import PatchCaches # @manual + from mock_cache import global_stats, patch_fbcode, PatchCaches # @manual HAS_TRITON = has_triton() @@ -160,6 +160,7 @@ def fn(x, y): self.assertEqual(counters["inductor"]["fxgraph_lookup_write_file"], 1) @requires_triton() + @config.patch({"fx_graph_remote_cache": True}) @parametrize("device", (GPU_TYPE, "cpu")) @parametrize("dtype", (torch.float32, torch.bfloat16)) @parametrize("dynamic", (False, True)) @@ -179,7 +180,6 @@ def fn(x, y): with config.patch( { - "fx_graph_cache": False, "fx_graph_remote_cache": True, } ), patch.dict(os.environ), PatchCaches(): @@ -190,10 +190,10 @@ def fn(x, y): self.assertEqual(fn(a, b), compiled_fn(a, b)) reset() - PatchCaches.report() - self.assertEqual(PatchCaches.num_get_hit, 3) - self.assertEqual(PatchCaches.num_get_miss, 1) - self.assertEqual(PatchCaches.num_put, 1) + global_stats.report() + self.assertEqual(global_stats.fx_graph.num_get_hit, 3) + self.assertEqual(global_stats.fx_graph.num_get_miss, 1) + self.assertEqual(global_stats.fx_graph.num_put, 1) @requires_triton() @config.patch({"fx_graph_cache": True}) @@ -793,6 +793,8 @@ def reset(self): torch._dynamo.reset() clear_inductor_caches() + @unittest.skipIf(not HAS_CUDA, "Requires CUDA") + @unittest.skipIf(not SM80OrLater, "Requires SM80+") @config.patch({"fx_graph_cache": False}) @config.patch({"fx_graph_remote_cache": False}) @config.patch({"autotune_local_cache": False}) @@ -800,9 +802,6 @@ def reset(self): @config.patch({"max_autotune": True}) @parametrize("fbcode", (False,) + (True,) * config.is_fbcode()) def test_autotune_cache(self, fbcode: bool): - if not fbcode: - self.skipTest("Redis for autotune is currently broken") - class Model(torch.nn.Module): def forward(self, x, y, a, b): return x + y, a + b @@ -819,18 +818,17 @@ def f(x, y, a, b): with PatchCaches(), patch_fbcode(fbcode): f_compiled(x, y, a, b) - PatchCaches.update() - self.assertEqual(PatchCaches.num_get_hit, 0) - self.assertEqual(PatchCaches.num_get_miss, 2) - self.assertEqual(PatchCaches.num_put, 2) + self.assertEqual(global_stats.autotune.num_get_hit, 0) + self.assertEqual(global_stats.autotune.num_get_miss, 2) + self.assertEqual(global_stats.autotune.num_put, 2) self.reset() f_compiled(x, y, a, b) - PatchCaches.report() - self.assertEqual(PatchCaches.num_get_hit, 2) - self.assertEqual(PatchCaches.num_get_miss, 2) - self.assertEqual(PatchCaches.num_put, 2) + global_stats.report() + self.assertEqual(global_stats.autotune.num_get_hit, 2) + self.assertEqual(global_stats.autotune.num_get_miss, 2) + self.assertEqual(global_stats.autotune.num_put, 2) class TestUtils(TestCase): diff --git a/test/inductor/test_max_autotune.py b/test/inductor/test_max_autotune.py index 61e21189a9b51a..5a74bb3048490c 100644 --- a/test/inductor/test_max_autotune.py +++ b/test/inductor/test_max_autotune.py @@ -35,9 +35,9 @@ try: - from .mock_cache import PatchCaches + from .mock_cache import global_stats, PatchCaches except ImportError: - from mock_cache import PatchCaches # @manual + from mock_cache import global_stats, PatchCaches # @manual torch.set_float32_matmul_precision("high") @@ -258,7 +258,7 @@ def no_lookup( op: str, inputs: str, benchmark: Callable[[Any], Dict[ChoiceCaller, float]], - ) -> Dict[ChoiceCaller, float]: + ) -> Optional[Dict[ChoiceCaller, float]]: if benchmark is not None: return benchmark(choices) @@ -723,9 +723,6 @@ def tearDown(self): def test_max_autotune_remote_caching(self, dynamic: bool): from unittest.mock import patch - if not config.is_fbcode(): - self.skipTest("Redis for autotune is currently broken") - def mm(a, b): a = torch.sin(a) return a @ b @@ -756,22 +753,20 @@ def f(x, y): torch.compile(mm, dynamic=dynamic)(a, b) reset() - PatchCaches.update() - PatchCaches.report() - self.assertEqual(PatchCaches.num_get_hit, 3) - self.assertEqual(PatchCaches.num_get_miss, 1) - self.assertEqual(PatchCaches.num_put, 1) + global_stats.report() + self.assertEqual(global_stats.autotune.num_get_hit, 3) + self.assertEqual(global_stats.autotune.num_get_miss, 1) + self.assertEqual(global_stats.autotune.num_put, 1) - PatchCaches.reset() + global_stats.reset() for _ in range(4): with fresh_inductor_cache(): torch.compile(f, dynamic=dynamic)(x, y) reset() - PatchCaches.update() - PatchCaches.report() - self.assertEqual(PatchCaches.num_get_hit, 3) - self.assertEqual(PatchCaches.num_get_miss, 1) - self.assertEqual(PatchCaches.num_put, 1) + global_stats.report() + self.assertEqual(global_stats.autotune.num_get_hit, 3) + self.assertEqual(global_stats.autotune.num_get_miss, 1) + self.assertEqual(global_stats.autotune.num_put, 1) class TestBenchmarkRequest(BenchmarkRequest): @@ -796,6 +791,7 @@ def benchmark( if not self.multi_device: assert visible_devices == self.parent_visible_devices else: + assert self.parent_visible_devices is not None valid_devices = self.parent_visible_devices.split(",") assert visible_devices in valid_devices diff --git a/torch/_VF.py b/torch/_VF.py index 13fe080e6f41e5..94166b51f17865 100644 --- a/torch/_VF.py +++ b/torch/_VF.py @@ -20,12 +20,12 @@ class VFModule(types.ModuleType): vf: types.ModuleType - def __init__(self, name): + def __init__(self, name: str): super().__init__(name) self.vf = torch._C._VariableFunctions - def __getattr__(self, attr): - return getattr(self.vf, attr) + def __getattr__(self, name: str) -> object: + return getattr(self.vf, name) sys.modules[__name__] = VFModule(__name__) diff --git a/torch/_functorch/_aot_autograd/autograd_cache.py b/torch/_functorch/_aot_autograd/autograd_cache.py index d280f67d2fd516..2db43bc865f0bf 100644 --- a/torch/_functorch/_aot_autograd/autograd_cache.py +++ b/torch/_functorch/_aot_autograd/autograd_cache.py @@ -262,7 +262,7 @@ def load(self, example_inputs, fx_config: Dict[str, BoxedBool]) -> CompiledFxGra # That is, AOTAutograd and Inductor never create new guards based on symints with different sources # than those passed to it by inductor. result = FxGraphCache._lookup_graph( - self.fx_graph_cache_key, example_inputs, local=True, remote_cache=False + self.fx_graph_cache_key, example_inputs, local=True, remote_cache=None ) if result is None: log.info("FXGraphCache cache miss for key %s", self.fx_graph_cache_key) diff --git a/torch/_inductor/codecache.py b/torch/_inductor/codecache.py index 53748266b7a9c4..af86b71f5673ab 100644 --- a/torch/_inductor/codecache.py +++ b/torch/_inductor/codecache.py @@ -68,7 +68,7 @@ if TYPE_CHECKING: from collections.abc import KeysView - from .remote_cache import RemoteCacheBackend + from .remote_cache import JsonDataTy, RemoteCache """ @@ -142,16 +142,16 @@ ) else: - def log_global_cache_errors(*args: Any, **kwargs: Any) -> None: + def log_global_cache_errors(*args: Any, **kwargs: Any) -> None: # type: ignore[misc] pass - def log_global_cache_stats(*args: Any, **kwargs: Any) -> None: + def log_global_cache_stats(*args: Any, **kwargs: Any) -> None: # type: ignore[misc] pass - def log_global_cache_vals(*args: Any, **kwargs: Any) -> None: + def log_global_cache_vals(*args: Any, **kwargs: Any) -> None: # type: ignore[misc] pass - def use_global_cache() -> bool: + def use_global_cache() -> bool: # type: ignore[misc] return False @@ -443,7 +443,7 @@ def write_text(text: str) -> str: def write_atomic( - path: str, + path_: str, content: Union[str, bytes], make_dirs: bool = False, encode_utf_8: bool = False, @@ -453,7 +453,7 @@ def write_atomic( assert isinstance( content, (str, bytes) ), "Only strings and byte arrays can be saved in the cache" - path = Path(path) + path = Path(path_) if make_dirs: path.parent.mkdir(parents=True, exist_ok=True) tmp_path = path.parent / f".{os.getpid()}.{threading.get_ident()}.tmp" @@ -694,7 +694,7 @@ def get_code_hash(root: str) -> bytes: from libfb.py import parutil - return parutil.get_file_contents("torch/src_hash.txt").rstrip() + return parutil.get_file_contents("torch/src_hash.txt").rstrip().encode("ascii") def get_inductor_root() -> str: @@ -837,8 +837,10 @@ def cudagraph_post_compile( from .compile_fx import cudagraphify + current_callable = compiled_graph.current_callable + assert current_callable is not None compiled_graph.current_callable = cudagraphify( - compiled_graph.current_callable, + current_callable, static_input_idxs=static_input_idxs, device_index=next(iter(compiled_graph.device_idxs)), stack_traces=stack_traces, @@ -1000,7 +1002,7 @@ def _lookup_graph( key: str, example_inputs: List[torch.Tensor], local: bool, - remote_cache: Optional[Any], + remote_cache: Optional[RemoteCache[JsonDataTy]], ) -> Optional[CompiledFxGraph]: """ Lookup a compiled graph in the cache by key. On a hit, return the @@ -1028,8 +1030,12 @@ def iterate_over_candidates() -> Generator[CompiledFxGraph, None, None]: if remote_cache: try: - if (data := remote_cache.get(key)) is not None: - yield pickle.loads(data) + if (cache_data := remote_cache.get(key)) is not None: + assert isinstance(cache_data, dict) + data = cache_data["data"] + assert isinstance(data, (str, bytes)) + content = base64.b64decode(data) + yield pickle.loads(content) except Exception: log.warning( "fx graph cache unable to load compiled graph", exc_info=True @@ -1179,7 +1185,7 @@ def _save_graph( compiled_graph: CompiledFxGraph, example_inputs: List[torch.Tensor], local: bool, - remote_cache: Optional[RemoteCacheBackend], + remote_cache: Optional[RemoteCache[JsonDataTy]], ) -> None: """ Store a serialized CompiledFxGraph on disk. @@ -1227,15 +1233,11 @@ def _save_graph( if remote_cache: time_taken_ms = int((disk_compiled_graph._time_taken_ns or 0) // 1e6) - cache_data = ( - { - "data": content, - "time_taken_ms": time_taken_ms, - } - if config.is_fbcode() - else content - ) - remote_cache.put(key, cache_data) # type: ignore[arg-type] + cache_data: JsonDataTy = { + "data": base64.b64encode(content).decode("ascii"), + "time_taken_ms": time_taken_ms, + } + remote_cache.put(key, cache_data) except Exception: log.warning("fx graph unable to write to cache", exc_info=True) counters["inductor"]["fxgraph_cache_write_error"] += 1 @@ -1296,20 +1298,22 @@ def load( # type: ignore[no-untyped-def] cache_info["key"] = key cache_info["components"] = debug_lines - remote_cache: Optional[RemoteCacheBackend] = None + remote_cache: Optional[RemoteCache[JsonDataTy]] = None if remote: cache_id = "fx-graph-v1" try: if config.is_fbcode(): - from torch._inductor.fb.remote_cache import ( - FbRemoteFxGraphCacheBackend, - ) + from torch._inductor.fb.remote_cache import FbRemoteFxGraphCache - remote_cache = FbRemoteFxGraphCacheBackend(cache_id) + remote_cache = FbRemoteFxGraphCache(cache_id) else: - from torch._inductor.remote_cache import RedisRemoteCacheBackend + from torch._inductor.remote_cache import RemoteFxGraphCache - remote_cache = RedisRemoteCacheBackend(cache_id) + remote_cache = RemoteFxGraphCache(cache_id) + except ModuleNotFoundError as e: + # No need for a stack trace on this error + remote_cache = None + log.warning("Unable to create a remote cache: %s", e) except Exception: remote_cache = None log.warning("Unable to create a remote cache", exc_info=True) @@ -1458,14 +1462,18 @@ def __init__( self.metrics_deltas = metrics_deltas self.counter_deltas = counter_deltas self.guards_expr = None + self.cudagraph_info = None + self.fx_kwargs = {} + self.inputs_to_check = () + self.boxed_forward_device_index = None def __call__(self, inputs: List[Any]) -> Any: assert self.current_callable is not None return self.current_callable(inputs) -def run_command_and_check(cmd: str) -> None: - cmd = shlex.split(cmd) +def run_command_and_check(cmd_: str) -> None: + cmd = shlex.split(cmd_) try: subprocess.check_call(cmd) except subprocess.CalledProcessError as e: diff --git a/torch/_inductor/remote_cache.py b/torch/_inductor/remote_cache.py index 391192eaac6832..14963a1bf5d435 100644 --- a/torch/_inductor/remote_cache.py +++ b/torch/_inductor/remote_cache.py @@ -1,32 +1,140 @@ +from __future__ import annotations + +import json import os +import typing from abc import abstractmethod -from typing import Optional +from typing import Any, Callable, Dict, Generic, List, Optional, Type, TypeVar, Union +from typing_extensions import override, TypeAlias + +from torch._inductor import config + + +try: + import redis +except ImportError: + redis = None # type: ignore[assignment] + + +if config.is_fbcode(): + from rfe.scubadata.scubadata_py3 import ( # type: ignore[import-not-found] + Sample as Sample_, + ) + Sample: TypeAlias = Sample_ +else: + Sample: TypeAlias = Type[object] # type: ignore[misc,no-redef] -class RemoteCacheBackend: + +_T = TypeVar("_T") +_U = TypeVar("_U") + + +class RemoteCacheBackend(Generic[_T]): """ - A backend implementation for accessing a remote/distributed cache. + A backend implementation for accessing a remote/distributed cache. Only + works with bytes in/out. For structured data use a RemoteCache. """ - def __init__(self, cache_id: str) -> None: + @abstractmethod + def get(self, key: str) -> Optional[_T]: pass @abstractmethod - def get(self, key: str) -> Optional[object]: + def put(self, key: str, data: _T) -> None: pass + +# Serde that encodes from _T to _U and decodes from _U to _T. +class RemoteCacheSerde(Generic[_T, _U]): @abstractmethod - def put(self, key: str, data: bytes) -> None: + def encode(self, data: _T) -> _U: + pass + + @abstractmethod + def decode(self, data: _U) -> _T: + pass + + +JsonDataTy = Optional[ + Union[int, float, str, bool, Dict[str, "JsonDataTy"], List["JsonDataTy"]] +] + + +class RemoteCacheJsonSerde(RemoteCacheSerde[JsonDataTy, bytes]): + def encode(self, data: JsonDataTy) -> bytes: + return bytes(json.dumps(data), "ascii") + + def decode(self, data: bytes) -> JsonDataTy: + return json.loads(data) + + +class RemoteCachePassthroughSerde(RemoteCacheSerde[_T, _T]): + def encode(self, data: _T) -> _T: + return data + + def decode(self, data: _T) -> _T: + return data + + +class RemoteCache(Generic[_T]): + backend_override_cls: Optional[Callable[[], RemoteCacheBackend[Any]]] = None + + def __init__( + self, backend: RemoteCacheBackend[_U], serde: RemoteCacheSerde[_T, _U] + ) -> None: + # Support for testing. + if (override_cls := self.__class__.backend_override_cls) is not None: + self.backend = override_cls() + else: + self.backend = backend + self.serde = serde + + def get(self, key: str) -> Optional[_T]: + sample = self._create_sample() + result = self._get(key, sample) + self._log_sample(sample) + return result + + def put(self, key: str, value: _T) -> None: + sample = self._create_sample() + self._put(key, value, sample) + self._log_sample(sample) + + def _decode(self, data: _U, sample: Optional[Sample]) -> _T: + return self.serde.decode(data) + + def _encode(self, value: _T, sample: Optional[Sample]) -> Any: # returns _U + return self.serde.encode(value) + + def _get(self, key: str, sample: Optional[Sample]) -> Optional[_T]: + if data := self.backend.get(key): + return self._decode(data, sample) + return None + + def _put(self, key: str, value: _T, sample: Optional[Sample]) -> None: + data = self._encode(value, sample) + self.backend.put(key, data) + + def _create_sample(self) -> Optional[Sample]: + return None + + def _log_sample(self, sample: Optional[Sample]) -> None: pass -class RedisRemoteCacheBackend(RemoteCacheBackend): +class RedisRemoteCacheBackend(RemoteCacheBackend[bytes]): """ A Redis implementation of a remote/distributed cache. """ + _key_fmt: str + _redis: Optional[redis.Redis] = None + def __init__(self, cache_id: str) -> None: - import redis + if not redis: + # We had trouble importing redis - just skip init. + return self._key_fmt = f"pt2:{cache_id}:{{key}}" self._redis = redis.Redis( @@ -34,14 +142,57 @@ def __init__(self, cache_id: str) -> None: port=int(os.environ.get("TORCHINDUCTOR_REDIS_PORT", 6379)), ) - def _get_key(self, key: str) -> str: + def __get_key(self, key: str) -> str: return self._key_fmt.format(key=key) + @override def get(self, key: str) -> Optional[bytes]: - value = self._redis.get(self._get_key(key)) + if not self._redis: + # Either redis wasn't found or we already had some trouble... + return None + + try: + value = self._redis.get(self.__get_key(key)) + except redis.exceptions.ConnectionError: + # Redis is lazy and doesn't actually attempt to connect until the + # first use. Mark is as unavailable now. + self._redis = None + return None + # In theory redis.get() can return an Awaitable as well... assert value is None or isinstance(value, bytes) return value + @override def put(self, key: str, data: bytes) -> None: - self._redis.set(self._get_key(key), data) + if not self._redis: + # Either redis wasn't found or we already had some trouble... + return + + try: + self._redis.set(self.__get_key(key), data) + except redis.exceptions.ConnectionError: + # Redis is lazy and doesn't actually attempt to connect until the + # first use. Mark is as unavailable now. + self._redis = None + + +class RedisRemoteCache(RemoteCache[JsonDataTy]): + def __init__(self, key: str) -> None: + # Special test handling: If we're just going to override the backend + # anyway don't require redis + if self.__class__.backend_override_cls: + # This is totally bogus but it works for now... + backend = typing.cast(RemoteCacheBackend[bytes], None) + else: + backend = RedisRemoteCacheBackend(key) + serde = RemoteCacheJsonSerde() + super().__init__(backend, serde) + + +class RemoteAutotuneCache(RedisRemoteCache): + pass + + +class RemoteFxGraphCache(RedisRemoteCache): + pass diff --git a/torch/_inductor/runtime/triton_heuristics.py b/torch/_inductor/runtime/triton_heuristics.py index 8421dba05b98d5..288a44c6296961 100644 --- a/torch/_inductor/runtime/triton_heuristics.py +++ b/torch/_inductor/runtime/triton_heuristics.py @@ -1,4 +1,6 @@ # mypy: allow-untyped-defs +from __future__ import annotations + import builtins import copy import functools @@ -14,7 +16,7 @@ import sys import threading import time -from typing import Any, Callable, Dict, List, Optional, Set, Tuple, TYPE_CHECKING +from typing import Any, Callable, Dict, List, Optional, Set, Tuple, TYPE_CHECKING, Union import torch @@ -45,7 +47,8 @@ if TYPE_CHECKING: - from ..remote_cache import RemoteCacheBackend + from ..remote_cache import JsonDataTy, RemoteCache + try: import triton @@ -1037,7 +1040,7 @@ def load_cached_autotuning( return matching_configs[0] -def should_use_remote_autotune_cache(inductor_meta): +def _should_use_remote_autotune_cache(inductor_meta): if inductor_meta.get("autotune_remote_cache") is not None: return inductor_meta.get("autotune_remote_cache") if not inductor_meta.get("is_fbcode"): @@ -1095,14 +1098,16 @@ def cached_autotune( local_cache = None cache_filename = None - remote_cache: Optional[RemoteCacheBackend] = None + remote_cache: Optional[ + Union[RemoteCache[JsonDataTy], RemoteCache[object]] + ] = None remote_cache_key = None best_config = None if not inductor_meta.get("force_disable_caches", False): if inductor_meta.get("autotune_local_cache", True): local_cache = LocalAutotuneCache() cache_filename = os.path.splitext(filename)[0] + ".best_config" - if should_use_remote_autotune_cache(inductor_meta): + if _should_use_remote_autotune_cache(inductor_meta): backend_hash = inductor_meta.get("backend_hash", None) if backend_hash is not None: key = backend_hash + configs_hash + "autotune-best-config-v2" @@ -1111,16 +1116,14 @@ def cached_autotune( try: if inductor_meta.get("is_fbcode"): from torch._inductor.fb.remote_cache import ( - FbRemoteAutotuneCacheBackend, + FbRemoteAutotuneCache, ) - remote_cache = FbRemoteAutotuneCacheBackend(key) + remote_cache = FbRemoteAutotuneCache(key) else: - from torch._inductor.remote_cache import ( - RedisRemoteCacheBackend, - ) + from torch._inductor.remote_cache import RemoteAutotuneCache - remote_cache = RedisRemoteCacheBackend(key) + remote_cache = RemoteAutotuneCache(key) except Exception: remote_cache = None log.warning("Unable to create a remote cache", exc_info=True) From b7ed18b639fdc489962ac769e84a583ec78258af Mon Sep 17 00:00:00 2001 From: Natalia Gimelshein Date: Sun, 1 Sep 2024 00:05:53 +0000 Subject: [PATCH 0322/1018] drop gil in couple places (leads to deadlocks) (#134910) Per title Pull Request resolved: https://github.com/pytorch/pytorch/pull/134910 Approved by: https://github.com/eqy --- torch/csrc/cuda/Event.cpp | 10 ++++++---- torch/csrc/cuda/Module.cpp | 6 ++++-- 2 files changed, 10 insertions(+), 6 deletions(-) diff --git a/torch/csrc/cuda/Event.cpp b/torch/csrc/cuda/Event.cpp index b73316cfc1778f..0bb76907ee0f7c 100644 --- a/torch/csrc/cuda/Event.cpp +++ b/torch/csrc/cuda/Event.cpp @@ -123,10 +123,12 @@ static PyObject* THCPEvent_get_device(THCPEvent* self, void* unused) { } static PyObject* THCPEvent_record(PyObject* _self, PyObject* _stream) { - HANDLE_TH_ERRORS - auto self = (THCPEvent*)_self; - auto stream = (THCPStream*)_stream; - self->cuda_event.record(stream->cuda_stream); + HANDLE_TH_ERRORS { + auto self = (THCPEvent*)_self; + auto stream = (THCPStream*)_stream; + pybind11::gil_scoped_release no_gil{}; + self->cuda_event.record(stream->cuda_stream); + } Py_RETURN_NONE; END_HANDLE_TH_ERRORS } diff --git a/torch/csrc/cuda/Module.cpp b/torch/csrc/cuda/Module.cpp index 4bfcc17500b964..451717dbeccdce 100644 --- a/torch/csrc/cuda/Module.cpp +++ b/torch/csrc/cuda/Module.cpp @@ -543,8 +543,10 @@ PyObject* THCPModule_setMemoryFraction(PyObject* _unused, PyObject* args) { } PyObject* THCPModule_emptyCache(PyObject* _unused, PyObject* noargs) { - HANDLE_TH_ERRORS - c10::cuda::CUDACachingAllocator::emptyCache(); + HANDLE_TH_ERRORS { + pybind11::gil_scoped_release no_gil; + c10::cuda::CUDACachingAllocator::emptyCache(); + } END_HANDLE_TH_ERRORS Py_RETURN_NONE; } From 7e4bcecff5145a900f3f84985464709e187fbc4f Mon Sep 17 00:00:00 2001 From: Manuel Candales Date: Sun, 1 Sep 2024 01:33:43 +0000 Subject: [PATCH 0323/1018] [ET] codegen: bool array as array ref (#134886) Test Plan: CI Differential Revision: D62046959 Pull Request resolved: https://github.com/pytorch/pytorch/pull/134886 Approved by: https://github.com/larryliu0820 --- tools/test/test_executorch_types.py | 5 +++++ torchgen/executorch/api/et_cpp.py | 3 +-- torchgen/executorch/api/unboxing.py | 11 ++++++++++- 3 files changed, 16 insertions(+), 3 deletions(-) diff --git a/tools/test/test_executorch_types.py b/tools/test/test_executorch_types.py index c00a02cd500e1b..dedb19e21f3e6d 100644 --- a/tools/test/test_executorch_types.py +++ b/tools/test/test_executorch_types.py @@ -3,6 +3,7 @@ from torchgen import local from torchgen.api.types import ( BaseCType, + boolT, ConstRefCType, CType, longT, @@ -64,6 +65,10 @@ def test_argumenttype_type(self) -> None: "int[]? dims", NamedCType("dims", OptionalCType(ArrayRefCType(BaseCType(longT)))), ), + ( + "bool[3] output_mask", + NamedCType("output_mask", ArrayRefCType(BaseCType(boolT))), + ), ] for d in data: self._test_argumenttype_type(*d) diff --git a/torchgen/executorch/api/et_cpp.py b/torchgen/executorch/api/et_cpp.py index 0bdf28acfa6418..76cebcd0f0f1dc 100644 --- a/torchgen/executorch/api/et_cpp.py +++ b/torchgen/executorch/api/et_cpp.py @@ -5,7 +5,6 @@ from torchgen import local from torchgen.api.types import ( ArgName, - ArrayCType, BaseCType, Binding, ConstRefCType, @@ -88,7 +87,7 @@ def valuetype_type( if str(t.elem) == "bool": assert t.size is not None return NamedCType( - binds, ArrayCType(BaseCType(BaseTypeToCppMapping[BaseTy.bool]), t.size) + binds, ArrayRefCType(BaseCType(BaseTypeToCppMapping[BaseTy.bool])) ) else: return None diff --git a/torchgen/executorch/api/unboxing.py b/torchgen/executorch/api/unboxing.py index f206980af44fd7..6845e72a22a5d8 100644 --- a/torchgen/executorch/api/unboxing.py +++ b/torchgen/executorch/api/unboxing.py @@ -173,7 +173,16 @@ def _gen_code_list_type( # handle list type with size, e.g., bool[4] code.extend( f""" - auto {out_name} = {arg_name}.toBoolList(); +#ifdef USE_ATEN_LIB +std::array {out_name}; +auto {in_name} = {arg_name}.toBoolList(); +size_t _i = 0; +for (auto {elem_name}: {in_name}) {{ + {out_name}[_i++] = {elem_name}; +}} +#else +auto {out_name} = {arg_name}.toBoolList(); +#endif """.split( "\n" ) From 6c1bc7287bd2aa0d665dfdd1d108422791f691e1 Mon Sep 17 00:00:00 2001 From: Natalia Gimelshein Date: Sun, 1 Sep 2024 09:07:25 +0000 Subject: [PATCH 0324/1018] =?UTF-8?q?expose=20host=5FemptyCache=20to=20pyt?= =?UTF-8?q?hon,=20fix=20a=20bug=20in=20freeing=20cudaHostRegist=E2=80=A6?= =?UTF-8?q?=20(#134919)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit …ered memory Fixes #ISSUE_NUMBER Pull Request resolved: https://github.com/pytorch/pytorch/pull/134919 Approved by: https://github.com/eqy --- aten/src/ATen/cuda/CachingHostAllocator.cpp | 4 ++-- test/test_cuda.py | 14 ++++++++++++++ torch/csrc/cuda/Module.cpp | 10 ++++++++++ 3 files changed, 26 insertions(+), 2 deletions(-) diff --git a/aten/src/ATen/cuda/CachingHostAllocator.cpp b/aten/src/ATen/cuda/CachingHostAllocator.cpp index c88aaa04826ac7..b430fb984a426d 100644 --- a/aten/src/ATen/cuda/CachingHostAllocator.cpp +++ b/aten/src/ATen/cuda/CachingHostAllocator.cpp @@ -98,7 +98,7 @@ struct CUDACachingHostAllocatorImpl pinned_use_cuda_host_register()) { void* ptr = block->ptr_; AT_CUDA_CHECK(cudaHostUnregister(ptr)); - free(ptr); + std::free(ptr); } else { AT_CUDA_CHECK(cudaFreeHost(block->ptr_)); } @@ -175,7 +175,7 @@ struct CUDACachingHostAllocatorImpl // Here we do regular allocation, pre-fault/map the pages, and then do // cudaHostRegister with GPU mapping flags to lock the pages, so we // can minimize the cost for the cuda global lock. - *ptr = malloc(roundSize); + *ptr = std::malloc(roundSize); // Parallelize the mapping/registering of pages to reduce wall time size_t pageSize = (1 << 12); // 4kB pages diff --git a/test/test_cuda.py b/test/test_cuda.py index f53ae5f61da150..c6a61e1218ecd2 100644 --- a/test/test_cuda.py +++ b/test/test_cuda.py @@ -157,6 +157,20 @@ def test_pinned_memory_with_cudaregister_multithread(self): for thread in threads: thread.join() + def test_pinned_memory_empty_cache(self): + for alloc_settings in (True, False): + torch.cuda.memory._set_allocator_settings( + f"pinned_use_cuda_host_register:{alloc_settings}" + ) + try: + t = torch.ones(1024 * 1024, pin_memory=True) + self.assertTrue(t.is_pinned()) + del t + torch._C._host_emptyCache() + except RuntimeError as e: + # Some GPUs don't support same address space on host and device side + pass + def test_cudart_register(self): t = torch.ones(20) self.assertFalse(t.is_pinned()) diff --git a/torch/csrc/cuda/Module.cpp b/torch/csrc/cuda/Module.cpp index 451717dbeccdce..134758b2c06814 100644 --- a/torch/csrc/cuda/Module.cpp +++ b/torch/csrc/cuda/Module.cpp @@ -542,6 +542,15 @@ PyObject* THCPModule_setMemoryFraction(PyObject* _unused, PyObject* args) { Py_RETURN_NONE; } +PyObject* THCPModule_hostEmptyCache(PyObject* _unused, PyObject* noargs) { + HANDLE_TH_ERRORS { + pybind11::gil_scoped_release no_gil; + at::cuda::CachingHostAllocator_emptyCache(); + } + END_HANDLE_TH_ERRORS + Py_RETURN_NONE; +} + PyObject* THCPModule_emptyCache(PyObject* _unused, PyObject* noargs) { HANDLE_TH_ERRORS { pybind11::gil_scoped_release no_gil; @@ -1839,6 +1848,7 @@ static struct PyMethodDef _THCPModule_methods[] = { THCPModule_cudaHostAllocator, METH_NOARGS, nullptr}, + {"_host_emptyCache", THCPModule_hostEmptyCache, METH_NOARGS, nullptr}, {"_cuda_cudaCachingAllocator_raw_alloc", THCPModule_cudaCachingAllocator_raw_alloc, METH_VARARGS, From 063ec360b1836b7e170174174a48a4feac1584ca Mon Sep 17 00:00:00 2001 From: cyy Date: Sun, 1 Sep 2024 09:38:35 +0000 Subject: [PATCH 0325/1018] Check function declarations of Vulkan code (#134550) Fixes #ISSUE_NUMBER Pull Request resolved: https://github.com/pytorch/pytorch/pull/134550 Approved by: https://github.com/ezyang --- aten/src/ATen/native/vulkan/api/Pipeline.cpp | 2 +- aten/src/ATen/native/vulkan/api/QueryPool.cpp | 8 +- aten/src/ATen/native/vulkan/api/Resource.cpp | 2 +- aten/src/ATen/native/vulkan/api/Tensor.cpp | 4 +- aten/src/ATen/native/vulkan/impl/Packing.cpp | 2 +- aten/src/ATen/native/vulkan/ops/Batchnorm.cpp | 2 +- aten/src/ATen/native/vulkan/ops/BinaryOp.cpp | 75 +++++++++-------- .../ATen/native/vulkan/ops/Convolution.cpp | 83 ++----------------- aten/src/ATen/native/vulkan/ops/Copy.cpp | 2 +- aten/src/ATen/native/vulkan/ops/Factory.cpp | 4 +- aten/src/ATen/native/vulkan/ops/Gru.cpp | 3 +- aten/src/ATen/native/vulkan/ops/Layernorm.cpp | 2 +- aten/src/ATen/native/vulkan/ops/Lstm.cpp | 2 +- .../native/vulkan/ops/QuantizedTensor.cpp | 10 +-- aten/src/ATen/native/vulkan/ops/Random.cpp | 12 +-- aten/src/ATen/native/vulkan/ops/Shape.cpp | 4 +- aten/src/ATen/native/vulkan/ops/Upsample.cpp | 4 +- aten/src/ATen/native/vulkan/ops/Utils.cpp | 1 + caffe2/CMakeLists.txt | 2 +- 19 files changed, 76 insertions(+), 148 deletions(-) diff --git a/aten/src/ATen/native/vulkan/api/Pipeline.cpp b/aten/src/ATen/native/vulkan/api/Pipeline.cpp index 114fad05afaa51..93a631c31e3acd 100644 --- a/aten/src/ATen/native/vulkan/api/Pipeline.cpp +++ b/aten/src/ATen/native/vulkan/api/Pipeline.cpp @@ -230,7 +230,7 @@ void swap(ComputePipeline& lhs, ComputePipeline& rhs) noexcept { rhs.handle_ = tmp_handle; } -bool operator==( +static bool operator==( const ComputePipeline::Descriptor& _1, const ComputePipeline::Descriptor& _2) { return ( diff --git a/aten/src/ATen/native/vulkan/api/QueryPool.cpp b/aten/src/ATen/native/vulkan/api/QueryPool.cpp index e74b7431b0f795..9c0c7fb2ea860c 100644 --- a/aten/src/ATen/native/vulkan/api/QueryPool.cpp +++ b/aten/src/ATen/native/vulkan/api/QueryPool.cpp @@ -170,13 +170,7 @@ void QueryPool::extract_results() { results_pending_ = false; } -std::ostream& operator<<(std::ostream& os, const VkExtent3D& extents) { - os << "{" << extents.width << ", " << extents.height << ", " << extents.depth - << "}"; - return os; -} - -std::string stringize(const VkExtent3D& extents) { +static std::string stringize(const VkExtent3D& extents) { std::stringstream ss; ss << "{" << extents.width << ", " << extents.height << ", " << extents.depth << "}"; diff --git a/aten/src/ATen/native/vulkan/api/Resource.cpp b/aten/src/ATen/native/vulkan/api/Resource.cpp index ebed09d47e7d4e..4981f54accfd3e 100644 --- a/aten/src/ATen/native/vulkan/api/Resource.cpp +++ b/aten/src/ATen/native/vulkan/api/Resource.cpp @@ -253,7 +253,7 @@ BufferMemoryBarrier::BufferMemoryBarrier( // ImageSampler // -bool operator==( +static bool operator==( const ImageSampler::Properties& _1, const ImageSampler::Properties& _2) { return ( diff --git a/aten/src/ATen/native/vulkan/api/Tensor.cpp b/aten/src/ATen/native/vulkan/api/Tensor.cpp index f89cd0fb9aaca0..bbaaef1d1ddb8f 100644 --- a/aten/src/ATen/native/vulkan/api/Tensor.cpp +++ b/aten/src/ATen/native/vulkan/api/Tensor.cpp @@ -460,7 +460,7 @@ void vTensor::virtual_resize(const std::vector& new_sizes) { // vTensorStorage // -api::VulkanImage allocate_image( +static api::VulkanImage allocate_image( api::Context* const context_ptr, api::utils::uvec3& extents, const api::StorageType storage_type, @@ -505,7 +505,7 @@ api::VulkanImage allocate_image( /*allocate_memory = */ allocate_memory); } -api::VulkanBuffer allocate_buffer( +static api::VulkanBuffer allocate_buffer( api::Context* const context_ptr, const int64_t numel, const api::StorageType storage_type, diff --git a/aten/src/ATen/native/vulkan/impl/Packing.cpp b/aten/src/ATen/native/vulkan/impl/Packing.cpp index 3c47208c4805fc..10d918ffbbfd20 100644 --- a/aten/src/ATen/native/vulkan/impl/Packing.cpp +++ b/aten/src/ATen/native/vulkan/impl/Packing.cpp @@ -296,7 +296,7 @@ bool record_buffer_to_nchw_op( v_src.buffer_metadata()); } -vTensor channel_image_repacking( +static vTensor channel_image_repacking( const vTensor& v_input, api::GPUMemoryLayout target_layout, const api::ShaderInfo& shader_descriptor) { diff --git a/aten/src/ATen/native/vulkan/ops/Batchnorm.cpp b/aten/src/ATen/native/vulkan/ops/Batchnorm.cpp index e12e69c4ebec2a..8fbcc54bd47400 100644 --- a/aten/src/ATen/native/vulkan/ops/Batchnorm.cpp +++ b/aten/src/ATen/native/vulkan/ops/Batchnorm.cpp @@ -15,7 +15,7 @@ struct Params final { float eps; }; -void record_op( +static void record_op( api::Context* const context, vTensor& v_output, const vTensor& v_input, diff --git a/aten/src/ATen/native/vulkan/ops/BinaryOp.cpp b/aten/src/ATen/native/vulkan/ops/BinaryOp.cpp index e1445f40ac5f87..bb4e5187a4c8e5 100644 --- a/aten/src/ATen/native/vulkan/ops/BinaryOp.cpp +++ b/aten/src/ATen/native/vulkan/ops/BinaryOp.cpp @@ -1,9 +1,9 @@ +#ifdef USE_VULKAN_API #include #include #include #include #include -#include namespace at { namespace native { @@ -12,7 +12,7 @@ namespace ops { using namespace api::utils; -Tensor binary_op_scalar( +static Tensor binary_op_scalar( const Tensor& self_arg, const Scalar& other, const std::optional& alpha_arg, @@ -66,7 +66,7 @@ Tensor binary_op_scalar( return convert(v_output); } -Tensor binary_op_preprocess_other_arg(const Tensor& other_arg) { +static Tensor binary_op_preprocess_other_arg(const Tensor& other_arg) { // Similar to binary_op_scalar where tensors is mapped to float, we // also map known integer types (but not quant types) tensor to float. @@ -99,7 +99,7 @@ Tensor binary_op_preprocess_other_arg(const Tensor& other_arg) { return other; } -Tensor& binary_op_scalar_( +static Tensor& binary_op_scalar_( Tensor& self_arg, const Scalar& other, const std::optional& alpha_arg, @@ -149,7 +149,7 @@ Tensor& binary_op_scalar_( return self_arg; } -Tensor binary_op_tensor( +static Tensor binary_op_tensor( const Tensor& self_arg, const Tensor& other_arg, const std::optional& alpha_arg, @@ -222,7 +222,7 @@ Tensor binary_op_tensor( return convert(v_output); } -Tensor quantized_binary_op_tensor( +static Tensor quantized_binary_op_tensor( const Tensor& self_arg, const Tensor& other_arg, const double scale, @@ -310,7 +310,7 @@ Tensor quantized_binary_op_tensor( return convert_quantized(v_output); } -Tensor& binary_op_tensor_( +static Tensor& binary_op_tensor_( Tensor& self_arg, const Tensor& other_arg, const std::optional& alpha_arg, @@ -384,7 +384,7 @@ Tensor& binary_op_tensor_( return self_arg; } -Tensor add_scalar( +static Tensor add_scalar( const Tensor& self_arg, const Scalar& other, const Scalar& alpha) { @@ -392,7 +392,10 @@ Tensor add_scalar( self_arg, other, std::optional(alpha), VK_KERNEL(add_scalar)); } -Tensor& add_scalar_(Tensor& self, const Scalar& other, const Scalar& alpha) { +static Tensor& add_scalar_( + Tensor& self, + const Scalar& other, + const Scalar& alpha) { return binary_op_scalar_( self, other, std::optional(alpha), VK_KERNEL(add_scalar_inplace)); } @@ -433,7 +436,7 @@ Tensor quantized_div( self_arg, other_arg, scale, zero_point, VK_KERNEL(quantized_div)); } -Tensor add_tensor( +static Tensor add_tensor( const Tensor& self_arg, const Tensor& other_arg, const Scalar& alpha) { @@ -441,7 +444,7 @@ Tensor add_tensor( self_arg, other_arg, std::optional(alpha), VK_KERNEL(add)); } -Tensor& add_tensor_( +static Tensor& add_tensor_( Tensor& self, const Tensor& other_arg, const Scalar& alpha) { @@ -449,7 +452,7 @@ Tensor& add_tensor_( self, other_arg, std::optional(alpha), VK_KERNEL(add_inplace)); } -Tensor sub_scalar( +static Tensor sub_scalar( const Tensor& self_arg, const Scalar& other, const Scalar& alpha) { @@ -460,7 +463,10 @@ Tensor sub_scalar( VK_KERNEL(add_scalar)); } -Tensor& sub_scalar_(Tensor& self, const Scalar& other, const Scalar& alpha) { +static Tensor& sub_scalar_( + Tensor& self, + const Scalar& other, + const Scalar& alpha) { return binary_op_scalar_( self, other, @@ -468,7 +474,7 @@ Tensor& sub_scalar_(Tensor& self, const Scalar& other, const Scalar& alpha) { VK_KERNEL(add_scalar_inplace)); } -Tensor sub_tensor( +static Tensor sub_tensor( const Tensor& self_arg, const Tensor& other_arg, const Scalar& alpha) { @@ -476,7 +482,7 @@ Tensor sub_tensor( self_arg, other_arg, std::optional(alpha), VK_KERNEL(sub)); } -Tensor& sub_tensor_( +static Tensor& sub_tensor_( Tensor& self, const Tensor& other_arg, const Scalar& alpha) { @@ -484,27 +490,27 @@ Tensor& sub_tensor_( self, other_arg, std::optional(alpha), VK_KERNEL(sub_inplace)); } -Tensor mul_scalar(const Tensor& self_arg, const Scalar& other) { +static Tensor mul_scalar(const Tensor& self_arg, const Scalar& other) { return binary_op_scalar( self_arg, other, std::optional(), VK_KERNEL(mul_scalar)); } -Tensor& mul_scalar_(Tensor& self, const Scalar& other) { +static Tensor& mul_scalar_(Tensor& self, const Scalar& other) { return binary_op_scalar_( self, other, std::optional(), VK_KERNEL(mul_scalar_inplace)); } -Tensor mul_tensor(const Tensor& self_arg, const Tensor& other_arg) { +static Tensor mul_tensor(const Tensor& self_arg, const Tensor& other_arg) { return binary_op_tensor( self_arg, other_arg, std::optional(), VK_KERNEL(mul)); } -Tensor& mul_tensor_(Tensor& self, const Tensor& other_arg) { +static Tensor& mul_tensor_(Tensor& self, const Tensor& other_arg) { return binary_op_tensor_( self, other_arg, std::optional(), VK_KERNEL(mul_inplace)); } -Tensor div_scalar(const Tensor& self_arg, const Scalar& other) { +static Tensor div_scalar(const Tensor& self_arg, const Scalar& other) { return binary_op_scalar( self_arg, 1.0 / other.to(), @@ -512,7 +518,7 @@ Tensor div_scalar(const Tensor& self_arg, const Scalar& other) { VK_KERNEL(mul_scalar)); } -Tensor& div_scalar_(Tensor& self, const Scalar& other) { +static Tensor& div_scalar_(Tensor& self, const Scalar& other) { return binary_op_scalar_( self, 1.0 / other.to(), @@ -520,31 +526,31 @@ Tensor& div_scalar_(Tensor& self, const Scalar& other) { VK_KERNEL(mul_scalar_inplace)); } -Tensor div_tensor(const Tensor& self_arg, const Tensor& other_arg) { +static Tensor div_tensor(const Tensor& self_arg, const Tensor& other_arg) { return binary_op_tensor( self_arg, other_arg, std::optional(), VK_KERNEL(div)); } -Tensor& div_tensor_(Tensor& self, const Tensor& other_arg) { +static Tensor& div_tensor_(Tensor& self, const Tensor& other_arg) { return binary_op_tensor_( self, other_arg, std::optional(), VK_KERNEL(div_inplace)); } -Tensor pow(const Tensor& self, const Tensor& other) { +static Tensor pow(const Tensor& self, const Tensor& other) { return binary_op_tensor(self, other, std::optional(), VK_KERNEL(pow)); } -Tensor& pow_(Tensor& self, const Tensor& other) { +static Tensor& pow_(Tensor& self, const Tensor& other) { return binary_op_tensor_( self, other, std::optional(), VK_KERNEL(pow_inplace)); } -Tensor pow_tensor_scalar(const Tensor& self, const Scalar& other) { +static Tensor pow_tensor_scalar(const Tensor& self, const Scalar& other) { return binary_op_scalar( self, other, std::optional(), VK_KERNEL(pow_tensor_scalar)); } -Tensor& pow_tensor_scalar_(Tensor& self, const Scalar& other) { +static Tensor& pow_tensor_scalar_(Tensor& self, const Scalar& other) { return binary_op_scalar_( self, other, @@ -552,12 +558,12 @@ Tensor& pow_tensor_scalar_(Tensor& self, const Scalar& other) { VK_KERNEL(pow_tensor_scalar_inplace)); } -Tensor pow_scalar_tensor(const Scalar& self, const Tensor& other) { +static Tensor pow_scalar_tensor(const Scalar& self, const Tensor& other) { return binary_op_scalar( other, self, std::optional(), VK_KERNEL(pow_scalar_tensor)); } -Tensor floor_divide_scalar(const Tensor& self, const Scalar& other) { +static Tensor floor_divide_scalar(const Tensor& self, const Scalar& other) { TORCH_CHECK( other.to() != 0.0f, "floor_divide_scalar: can't divide by zero"); return binary_op_scalar( @@ -567,7 +573,7 @@ Tensor floor_divide_scalar(const Tensor& self, const Scalar& other) { VK_KERNEL(floor_mul_scalar)); } -Tensor& floor_divide_scalar_(Tensor& self, const Scalar& other) { +static Tensor& floor_divide_scalar_(Tensor& self, const Scalar& other) { TORCH_CHECK( other.to() != 0.0f, "floor_divide_scalar_: can't divide by zero"); return binary_op_scalar_( @@ -577,12 +583,12 @@ Tensor& floor_divide_scalar_(Tensor& self, const Scalar& other) { VK_KERNEL(floor_mul_scalar_inplace)); } -Tensor floor_divide_tensor(const Tensor& self, const Tensor& other) { +static Tensor floor_divide_tensor(const Tensor& self, const Tensor& other) { return binary_op_tensor( self, other, std::optional(), VK_KERNEL(floor_divide)); } -Tensor& floor_divide_tensor_(Tensor& self, const Tensor& other_arg) { +static Tensor& floor_divide_tensor_(Tensor& self, const Tensor& other_arg) { return binary_op_tensor_( self, other_arg, @@ -590,8 +596,6 @@ Tensor& floor_divide_tensor_(Tensor& self, const Tensor& other_arg) { VK_KERNEL(floor_divide_inplace)); } -#ifdef USE_VULKAN_API - TORCH_LIBRARY_IMPL(aten, Vulkan, m) { m.impl(TORCH_SELECTIVE_NAME("aten::add.Scalar"), TORCH_FN(add_scalar)); m.impl(TORCH_SELECTIVE_NAME("aten::add_.Scalar"), TORCH_FN(add_scalar_)); @@ -631,9 +635,8 @@ TORCH_LIBRARY_IMPL(aten, Vulkan, m) { TORCH_FN(floor_divide_tensor_)); } -#endif /* USE_VULKAN_API */ - } // namespace ops } // namespace vulkan } // namespace native } // namespace at +#endif /* USE_VULKAN_API */ diff --git a/aten/src/ATen/native/vulkan/ops/Convolution.cpp b/aten/src/ATen/native/vulkan/ops/Convolution.cpp index f210c253800b19..3831cce8937367 100644 --- a/aten/src/ATen/native/vulkan/ops/Convolution.cpp +++ b/aten/src/ATen/native/vulkan/ops/Convolution.cpp @@ -53,7 +53,7 @@ inline bool is_pointwise(const IntArrayRef weight_size) { return true; } -Conv2dMethod determine_method( +static Conv2dMethod determine_method( const IntArrayRef weight_size, const IntArrayRef stride, const IntArrayRef padding, @@ -359,7 +359,7 @@ struct Params final { api::utils::vec2 clamp; }; -void record_op( +static void record_op( api::Context* const context, api::ShaderInfo& compute_shader, vTensor& v_output, @@ -432,7 +432,7 @@ struct QParams final { api::utils::vec2 clamp; }; -void record_quantized_op( +static void record_quantized_op( api::Context* const context, api::ShaderInfo& compute_shader, vTensor& v_output, @@ -787,56 +787,11 @@ Tensor convolution( input, c10::make_intrusive(conv_context)); } -Tensor quantized_convolution( - const Tensor& input, - const Tensor& weight, - const std::optional& bias, - const IntArrayRef stride, - const IntArrayRef padding, - const IntArrayRef dilation, - const bool transposed, - const IntArrayRef output_padding, - const int64_t groups, - const double out_scale, - const int64_t out_zero_point) { - if (transposed) { - return run_tconv2d_context( - input, - c10::make_intrusive(Conv2dPackedContext( - weight, - bias, - stride, - padding, - dilation, - transposed, - false, - output_padding, - groups))); - } - - Conv2dPackedContext conv_context = Conv2dPackedContext( - weight, - bias, - stride, - padding, - dilation, - transposed, - true, - output_padding, - groups); - - return run_qconv2d_context( - input, - out_scale, - out_zero_point, - c10::make_intrusive(conv_context)); -} - } // namespace namespace conv1d { -vTensor pack_weights_using_width_packing(const Tensor& weight_arg) { +static vTensor pack_weights_using_width_packing(const Tensor& weight_arg) { Tensor weight = weight_arg; if (weight.is_cpu()) { @@ -862,7 +817,7 @@ vTensor pack_weights_using_width_packing(const Tensor& weight_arg) { * This is a full implementation. For algorithm details, refer to the shader * kernel code. */ -Tensor run_conv1d_context_impl( +static Tensor run_conv1d_context_impl( const Tensor& input_arg, const Tensor& weight_arg, const std::optional& bias_arg_opt, @@ -1150,7 +1105,7 @@ c10::intrusive_ptr create_qtconv2d_context( output_max)); } -Tensor run_conv2d_context_impl( +static Tensor run_conv2d_context_impl( const Tensor& input_arg, const c10::intrusive_ptr& conv_context, double scale, @@ -1291,30 +1246,6 @@ Tensor run_qconv2d_context( return run_conv2d_context_impl(input_arg, conv_context, scale, zero_point); } -Tensor quantized_conv2d( - const Tensor& input, - const Tensor& weight, - const std::optional& bias, - IntArrayRef stride, - IntArrayRef padding, - IntArrayRef dilation, - int64_t groups, - double out_scale, - int64_t out_zero_point) { - return quantized_convolution( - input, - weight, - bias, - stride, - padding, - dilation, - false, - {{0, 0}}, - groups, - out_scale, - out_zero_point); -} - /* Backwards compatibility */ Conv2dOpContext::Conv2dOpContext(Conv2dPackedContext conv_context) : conv_context_{std::move(conv_context)} {} @@ -1444,7 +1375,7 @@ c10::intrusive_ptr create_conv1d_context( Conv1dPackedContext(weight, bias, stride, padding, dilation, groups)); } -Tensor convolution1d( +static Tensor convolution1d( const Tensor& input, const Tensor& weight, const std::optional& bias, diff --git a/aten/src/ATen/native/vulkan/ops/Copy.cpp b/aten/src/ATen/native/vulkan/ops/Copy.cpp index 60bc3a341ba0df..d8dcee9391fbe2 100644 --- a/aten/src/ATen/native/vulkan/ops/Copy.cpp +++ b/aten/src/ATen/native/vulkan/ops/Copy.cpp @@ -123,7 +123,7 @@ void transfer_vulkan_to_cpu(vTensor& v_src, Tensor& dst) { .to(convert_dtype(v_src.dtype())); } -void transfer_vulkan_to_vulkan(vTensor& src, vTensor& dst) { +static void transfer_vulkan_to_vulkan(vTensor& src, vTensor& dst) { api::Context* const context = api::context(); api::PipelineBarrier pipeline_barrier{}; diff --git a/aten/src/ATen/native/vulkan/ops/Factory.cpp b/aten/src/ATen/native/vulkan/ops/Factory.cpp index 153a6448eaf4ab..d8cd21eb659f3e 100644 --- a/aten/src/ATen/native/vulkan/ops/Factory.cpp +++ b/aten/src/ATen/native/vulkan/ops/Factory.cpp @@ -28,7 +28,7 @@ Tensor _empty_affine_quantized( }); } -Tensor empty_memory_format( +static Tensor empty_memory_format( const IntArrayRef sizes, const std::optional dtype, const std::optional layout, @@ -46,7 +46,7 @@ Tensor empty_memory_format( }); } -Tensor empty_strided( +static Tensor empty_strided( const IntArrayRef sizes, const IntArrayRef /* strides */, const std::optional dtype, diff --git a/aten/src/ATen/native/vulkan/ops/Gru.cpp b/aten/src/ATen/native/vulkan/ops/Gru.cpp index a66c69b134cee1..e803e9b9686f92 100644 --- a/aten/src/ATen/native/vulkan/ops/Gru.cpp +++ b/aten/src/ATen/native/vulkan/ops/Gru.cpp @@ -135,7 +135,8 @@ TORCH_LIBRARY_IMPL(aten, Vulkan, m) { } // namespace -std::vector> pack_linear_op_contexts( +static std::vector> +pack_linear_op_contexts( const std::vector& params_cpu, int64_t num_layers) { TORCH_CHECK( diff --git a/aten/src/ATen/native/vulkan/ops/Layernorm.cpp b/aten/src/ATen/native/vulkan/ops/Layernorm.cpp index 6b6a4b866c700c..f2d285e736f4a5 100644 --- a/aten/src/ATen/native/vulkan/ops/Layernorm.cpp +++ b/aten/src/ATen/native/vulkan/ops/Layernorm.cpp @@ -78,7 +78,7 @@ Tensor run_layernorm_context( return std::get<0>(native_layer_norm_output); } -Tensor layer_norm( +static Tensor layer_norm( const at::Tensor& input_arg, IntArrayRef normalized_shape, const std::optional& weight_opt /* optional */, diff --git a/aten/src/ATen/native/vulkan/ops/Lstm.cpp b/aten/src/ATen/native/vulkan/ops/Lstm.cpp index 7e8000370346aa..63f17f15eb2c14 100644 --- a/aten/src/ATen/native/vulkan/ops/Lstm.cpp +++ b/aten/src/ATen/native/vulkan/ops/Lstm.cpp @@ -171,7 +171,7 @@ TORCH_LIBRARY_IMPL(aten, Vulkan, m) { } // namespace -std::vector> +static std::vector> pack_lstm_linear_op_contexts( const std::vector& params_cpu, int64_t num_layers) { diff --git a/aten/src/ATen/native/vulkan/ops/QuantizedTensor.cpp b/aten/src/ATen/native/vulkan/ops/QuantizedTensor.cpp index 81f1a9c0197ab2..228a1ed5262a67 100644 --- a/aten/src/ATen/native/vulkan/ops/QuantizedTensor.cpp +++ b/aten/src/ATen/native/vulkan/ops/QuantizedTensor.cpp @@ -1,3 +1,4 @@ +#ifdef USE_VULKAN_API #include #include #include @@ -161,13 +162,13 @@ Tensor dequantize_helper( return convert(v_output); } -double q_scale(const Tensor& self) { +static double q_scale(const Tensor& self) { TORCH_CHECK(self.is_vulkan(), "Expecting a vulkan tensor for q_scale"); const vTensor& v_input = convert(self); return v_input.get_scale(); } -int64_t q_zero_point(const Tensor& self) { +static int64_t q_zero_point(const Tensor& self) { TORCH_CHECK(self.is_vulkan(), "Expecting a vulkan tensor for q_zero_point"); const vTensor& v_input = convert(self); return v_input.get_zero_point(); @@ -179,8 +180,6 @@ Tensor dequantize(const Tensor& self) { return dequantize_helper(self, q_scale, zero_point, kFloat); } -#ifdef USE_VULKAN_API - TORCH_LIBRARY_IMPL(aten, Vulkan, m) { m.impl( TORCH_SELECTIVE_NAME("aten::quantize_per_tensor"), quantize_per_tensor); @@ -192,9 +191,8 @@ TORCH_LIBRARY_IMPL(aten, Vulkan, m) { m.impl(TORCH_SELECTIVE_NAME("aten::dequantize.self"), dequantize); } -#endif /* USE_VULKAN_API */ - } // namespace ops } // namespace vulkan } // namespace native } // namespace at +#endif /* USE_VULKAN_API */ diff --git a/aten/src/ATen/native/vulkan/ops/Random.cpp b/aten/src/ATen/native/vulkan/ops/Random.cpp index 3103f7fe6f58d1..49199b48cb9709 100644 --- a/aten/src/ATen/native/vulkan/ops/Random.cpp +++ b/aten/src/ATen/native/vulkan/ops/Random.cpp @@ -12,7 +12,9 @@ namespace ops { using namespace api::utils; -Tensor& uniform_( +#ifdef USE_VULKAN_API + +static Tensor& uniform_( Tensor& self, const double from, const double to, @@ -57,7 +59,7 @@ Tensor& uniform_( return self; } -Tensor rand_like( +static Tensor rand_like( const at::Tensor& input_arg, const std::optional /* not implemented */, const std::optional /* not implemented */, @@ -71,7 +73,7 @@ Tensor rand_like( return input_arg.clone().detach().uniform_(0.0, 1.0); } -Tensor& normal_( +static Tensor& normal_( Tensor& self, const double mean, const double std, @@ -118,7 +120,7 @@ Tensor& normal_( return self; } -Tensor randn_like( +static Tensor randn_like( const at::Tensor& input_arg, const std::optional /* not implemented */, const std::optional /* not implemented */, @@ -130,8 +132,6 @@ Tensor randn_like( return input_arg.clone().detach().normal_(0.0, 1.0); } -#ifdef USE_VULKAN_API - TORCH_LIBRARY_IMPL(aten, Vulkan, m) { m.impl(TORCH_SELECTIVE_NAME("aten::uniform_"), TORCH_FN(uniform_)); m.impl(TORCH_SELECTIVE_NAME("aten::rand_like"), TORCH_FN(rand_like)); diff --git a/aten/src/ATen/native/vulkan/ops/Shape.cpp b/aten/src/ATen/native/vulkan/ops/Shape.cpp index 2a13979523f638..fd10a05a7c87f2 100644 --- a/aten/src/ATen/native/vulkan/ops/Shape.cpp +++ b/aten/src/ATen/native/vulkan/ops/Shape.cpp @@ -8,7 +8,7 @@ namespace native { namespace vulkan { namespace ops { -Tensor view_internal(const Tensor& self_arg, const IntArrayRef shape) { +static Tensor view_internal(const Tensor& self_arg, const IntArrayRef shape) { api::Context* const context = api::context(); Tensor self = self_arg.is_vulkan() ? self_arg : self_arg.vulkan(); @@ -52,7 +52,7 @@ inline Tensor view(const Tensor& self_arg, IntArrayRef shape) { return view_internal(self_arg, shape); } -Tensor _reshape_alias( +static Tensor _reshape_alias( const Tensor& self_arg, const IntArrayRef shape, const IntArrayRef strides) { diff --git a/aten/src/ATen/native/vulkan/ops/Upsample.cpp b/aten/src/ATen/native/vulkan/ops/Upsample.cpp index 7e3a2ead2d632f..fc426e2da73838 100644 --- a/aten/src/ATen/native/vulkan/ops/Upsample.cpp +++ b/aten/src/ATen/native/vulkan/ops/Upsample.cpp @@ -9,7 +9,7 @@ namespace vulkan { namespace ops { using namespace api::utils; -Tensor upsample_nearest2d( +static Tensor upsample_nearest2d( const Tensor& input_arg, const IntArrayRef output_sizes, const std::optional scales_h, @@ -94,7 +94,7 @@ Tensor upsample_nearest2d( return convert(v_output); } -Tensor upsample_bilinear2d( +static Tensor upsample_bilinear2d( const Tensor& input_arg, const IntArrayRef output_sizes, bool align_corners, diff --git a/aten/src/ATen/native/vulkan/ops/Utils.cpp b/aten/src/ATen/native/vulkan/ops/Utils.cpp index 1e6c18cfa43b1e..36d4221c666bf6 100644 --- a/aten/src/ATen/native/vulkan/ops/Utils.cpp +++ b/aten/src/ATen/native/vulkan/ops/Utils.cpp @@ -1,5 +1,6 @@ #include #include +#include #ifndef AT_PER_OPERATOR_HEADERS #include diff --git a/caffe2/CMakeLists.txt b/caffe2/CMakeLists.txt index 63c2fef1e195ef..b3e9e3ca721027 100644 --- a/caffe2/CMakeLists.txt +++ b/caffe2/CMakeLists.txt @@ -767,7 +767,7 @@ if(NOT MSVC) set_source_files_properties(${PROJECT_SOURCE_DIR}/torch/csrc/distributed/c10d/socket.cpp PROPERTIES COMPILE_OPTIONS "-Wno-error=deprecated") endif() -if("${CMAKE_CXX_COMPILER_ID}" MATCHES "Clang" AND NOT USE_VULKAN AND NOT USE_IOS AND NOT USE_COREML_DELEGATE) +if("${CMAKE_CXX_COMPILER_ID}" MATCHES "Clang" AND NOT USE_IOS AND NOT USE_COREML_DELEGATE) target_compile_options_if_supported(torch_cpu "-Wmissing-prototypes") target_compile_options_if_supported(torch_cpu "-Werror=missing-prototypes") get_target_property(TORCH_CPU_SOURCES torch_cpu SOURCES) From 3c5919fa6dec503c465ba6a8f68906bf04576656 Mon Sep 17 00:00:00 2001 From: eqy Date: Sun, 1 Sep 2024 10:48:04 +0000 Subject: [PATCH 0326/1018] Abate `-Wsign-compare` warning spam in `Indexing.cu` (#134805) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Fix for warning spam like ``` warning: comparison of integer expressions of different signedness: ‘long int’ and ‘uint64_t’ {aka ‘long unsigned int’} [-Wsign-compare] ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/134805 Approved by: https://github.com/janeyx99 --- aten/src/ATen/native/cuda/Indexing.cu | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/aten/src/ATen/native/cuda/Indexing.cu b/aten/src/ATen/native/cuda/Indexing.cu index 613ecf9b9d3f98..ee83ee5c6d3b8a 100644 --- a/aten/src/ATen/native/cuda/Indexing.cu +++ b/aten/src/ATen/native/cuda/Indexing.cu @@ -1350,7 +1350,7 @@ template void index_select_out_cuda_impl( Tensor& out, const Tensor& self, - long dim, + uint64_t dim, const Tensor& index) { uint64_t numIndices = index.numel(); uint64_t selfDims = self.dim() == 0 ? 1 : self.dim(); @@ -1511,13 +1511,13 @@ Tensor& index_select_out_cuda( self.qscheme() == kPerTensorAffine, "Only per_tensor quantized quantized tensors are supported by index_select.") AT_DISPATCH_QINT_TYPES(out.scalar_type(), "index_select_quant_cuda", [&] { - index_select_out_cuda_impl(out, self, dim, index); + index_select_out_cuda_impl(out, self, (uint64_t) dim, index); }); } else { AT_DISPATCH_V2( out.scalar_type(), "index_select_cuda", - AT_WRAP([&] { index_select_out_cuda_impl(out, self, dim, index); }), + AT_WRAP([&] { index_select_out_cuda_impl(out, self, (uint64_t) dim, index); }), AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES), kComplexHalf, kHalf, From 7207ff88e221a0044751554449250a2096e5fb3b Mon Sep 17 00:00:00 2001 From: cyy Date: Sun, 1 Sep 2024 15:15:38 +0000 Subject: [PATCH 0327/1018] [Reland] [Torchgen] Pass mutable to cpp.valuetype_type (#134549) Reland of #121415 Pull Request resolved: https://github.com/pytorch/pytorch/pull/134549 Approved by: https://github.com/ezyang --- torchgen/api/cpp.py | 6 ++++-- torchgen/api/structured.py | 2 +- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/torchgen/api/cpp.py b/torchgen/api/cpp.py index 19e5a11b8cdf2a..c657570ee3e249 100644 --- a/torchgen/api/cpp.py +++ b/torchgen/api/cpp.py @@ -94,6 +94,7 @@ def valuetype_type( t: Type, *, binds: ArgName, + mutable: bool = True, remove_non_owning_ref_types: bool = False, symint: bool = False, ) -> NamedCType | None: @@ -113,7 +114,7 @@ def valuetype_type( # All other BaseType currently map directly to BaseCppTypes. return NamedCType(binds, BaseCType(BaseTypeToCppMapping[t.name])) elif isinstance(t, OptionalType): - elem = valuetype_type(t.elem, binds=binds, symint=symint) + elem = valuetype_type(t.elem, binds=binds, mutable=mutable, symint=symint) if elem is None: return None return NamedCType(binds, OptionalCType(elem.type)) @@ -143,6 +144,7 @@ def argumenttype_type( r = valuetype_type( t, binds=binds, + mutable=mutable, symint=symint, remove_non_owning_ref_types=remove_non_owning_ref_types, ) @@ -231,7 +233,7 @@ def returntype_type(t: Type, *, mutable: bool, symint: bool = False) -> CType: # placeholder is ignored # NB: symint is ALWAYS respected for return types. So symint argument # here is IGNORED - r = valuetype_type(t, binds="__placeholder__", symint=True) + r = valuetype_type(t, binds="__placeholder__", mutable=mutable, symint=True) if r is not None: return r.type diff --git a/torchgen/api/structured.py b/torchgen/api/structured.py index a93d666114de64..93a72eb2b4a5c1 100644 --- a/torchgen/api/structured.py +++ b/torchgen/api/structured.py @@ -48,7 +48,7 @@ def argumenttype_type(t: Type, *, mutable: bool, binds: ArgName) -> NamedCType: # CompositeExplicitAutograd and the meta function (which could # hypothetically be SymInt), but for simplicity we plan for these to just # be handled in Python - r = cpp.valuetype_type(t, symint=False, binds=binds) + r = cpp.valuetype_type(t, symint=False, binds=binds, mutable=mutable) if r is not None: return r From 301e0180c0bcf393a76e8c3ed8e54fe129c36eaf Mon Sep 17 00:00:00 2001 From: Oguz Ulgen Date: Sun, 1 Sep 2024 19:02:09 +0000 Subject: [PATCH 0328/1018] Log fx graph cache bypass reasons (#134792) Summary: Lets track when we bypass and why Test Plan: unit tests Differential Revision: D61994739 Pull Request resolved: https://github.com/pytorch/pytorch/pull/134792 Approved by: https://github.com/jamesjwu --- torch/_inductor/codecache.py | 3 +++ torch/_utils_internal.py | 4 ++++ 2 files changed, 7 insertions(+) diff --git a/torch/_inductor/codecache.py b/torch/_inductor/codecache.py index af86b71f5673ab..e3c4faedcb9f73 100644 --- a/torch/_inductor/codecache.py +++ b/torch/_inductor/codecache.py @@ -60,6 +60,7 @@ rocm_compile_command, rocm_compiler, ) +from torch._utils_internal import log_cache_bypass T = TypeVar("T") @@ -1359,6 +1360,8 @@ def load( # type: ignore[no-untyped-def] cache_state = "bypass" log.info("Bypassing FX Graph Cache because '%s'", e) cache_info["cache_bypass_reason"] = str(e) + if remote: + log_cache_bypass("bypass_fx_graph", str(e)) cache_event_time = time_ns() if not compiled_graph: compiled_graph = compile_fx_fn( diff --git a/torch/_utils_internal.py b/torch/_utils_internal.py index 877856243545a0..f254217452061f 100644 --- a/torch/_utils_internal.py +++ b/torch/_utils_internal.py @@ -132,6 +132,10 @@ def log_trace_structured_event(*args, **kwargs) -> None: pass +def log_cache_bypass(*args, **kwargs) -> None: + pass + + def log_torchscript_usage(api: str, **kwargs): _ = api return From e94a2788aa9162204e3dc978629a89a45e500fe5 Mon Sep 17 00:00:00 2001 From: "Edward Z. Yang" Date: Mon, 26 Aug 2024 13:16:52 -0400 Subject: [PATCH 0329/1018] Move FunctionSchema implementations to cpp file (#133856) Signed-off-by: Edward Z. Yang Pull Request resolved: https://github.com/pytorch/pytorch/pull/133856 Approved by: https://github.com/bdhirsh, https://github.com/albanD --- aten/src/ATen/core/function_schema.cpp | 407 ++++++++++++++++++++++ aten/src/ATen/core/function_schema.h | 4 +- aten/src/ATen/core/function_schema_inl.h | 408 ----------------------- 3 files changed, 409 insertions(+), 410 deletions(-) diff --git a/aten/src/ATen/core/function_schema.cpp b/aten/src/ATen/core/function_schema.cpp index 83d22242e8629d..cebab59d0664eb 100644 --- a/aten/src/ATen/core/function_schema.cpp +++ b/aten/src/ATen/core/function_schema.cpp @@ -197,4 +197,411 @@ bool FunctionSchema::may_contain_alias(const SchemaArgument& lhs, const SchemaAr return rhsWildcard || canAliasTypeSetsAlias(lhsContainedTypes, rhsContainedTypes); } } + +std::ostream& operator<<(std::ostream& out, const FunctionSchema& schema) { + // eventually this should look almost identical to python arg parser, but + // it is simpler for now to work directly on this schema + + out << schema.name(); + if (!schema.overload_name().empty()) { + out << "." << schema.overload_name(); + } + out << "("; + + bool seen_kwarg_only = false; + for (const auto i : c10::irange(schema.arguments().size())) { + if (i > 0) out << ", "; + if (schema.arguments()[i].kwarg_only() && !seen_kwarg_only) { + out << "*, "; + seen_kwarg_only = true; + } + out << schema.arguments()[i]; + } + + if(schema.is_vararg()) { + if(!schema.arguments().empty()) + out << ", "; + out << "..."; + } + + out << ") -> "; + + const auto& returns = schema.returns(); + + /* + * We should skip parenthesis if we return a single item and it's not varret, + * or we return nothing but varret. + * + * Need special handling for schema + * aten::items.str(Dict(str, t) self) -> (str,t)[] + * Even though this schema returns a single item, we need add parenthesis. + * The is necessary so the printed schema can be parsed by the C++ SchemaParser + * Without the extra parenthesis, the parser sees the first parenthesis in '(str,t)' and mistakenly + * treat the return type as a tuple. An alternative is to enhance the Lexer + * to lookahead multiple tokens to accurately decide if the return type is + * a tuple. + */ + bool need_paren = !( + (returns.size() == 1 && !schema.is_varret()) || + (returns.empty() && schema.is_varret())); + + if (returns.size() == 1 && !schema.is_varret()) { + std::stringstream return_ss; + return_ss << returns.at(0); + auto return_str = return_ss.str(); + + // enclosing the single return item with parenthesis if the return type + // starts with a left parenthesis. + // + // There are 2 cases + // 1. something like 'aten::items.str(Dict(str, t) self) -> ((str, t)[])'. + // without the extra parenthesis, the c++ schem parser can not parse it. + // 2. something like '-> ((str, str))'. Need extra parenthesis so the return + // type is a single tuple rather than two strings. + // PR (https://github.com/pytorch/pytorch/pull/23204) has more context about + // this. test_serialize_and_deserialize (https://github.com/pytorch/pytorch/blob/master/test/test_function_schema.py#L15) + // also covers this case. + if (!return_str.empty() && return_str.front() == '(') { + need_paren = true; + } + } + + if (need_paren) { + out << "("; + } + for (const auto i : c10::irange(returns.size())) { + if (i > 0) { + out << ", "; + } + out << returns.at(i); + } + if (schema.is_varret()) { + if (!returns.empty()) { + out << ", "; + } + out << "..."; + } + if (need_paren) { + out << ")"; + } + return out; +} + +static size_t findFirstOutArg(const std::vector& args) { + // find the start of out args in the schema + for (const auto out_start_idx : c10::irange(args.size())) { + if (args.at(out_start_idx).is_out()) { + return out_start_idx; + } + } + return args.size(); +} + +bool Argument::isBackwardCompatibleWith( + const Argument& old, + std::ostream* why_not) const { + const Argument* lhs = this; + const Argument* rhs = &old; + if (!(lhs->name() == rhs->name() + && lhs->N() == rhs->N() + && (lhs->alias_info() == rhs->alias_info() + || (lhs->alias_info() != nullptr && rhs->alias_info() != nullptr + && *lhs->alias_info() == *rhs->alias_info())))) { + return false; + } + if (lhs->kwarg_only() && !rhs->kwarg_only()) { + return false; + } + if (!rhs->type()->isSubtypeOfExt(*lhs->type(), why_not)) { + return false; + } + if (rhs->default_value().has_value() && + lhs->default_value() != rhs->default_value()) { + return false; + } + return true; +} + +bool Argument::isForwardCompatibleWith( + const Argument& old, + std::ostream* why_not) const { + const Argument* lhs = this; + const Argument* rhs = &old; + if (!(lhs->name() == rhs->name() + && lhs->N() == rhs->N() + && (lhs->alias_info() == rhs->alias_info() + || (lhs->alias_info() != nullptr && rhs->alias_info() != nullptr + && *lhs->alias_info() == *rhs->alias_info())))) { + return false; + } + if (lhs->kwarg_only() && !rhs->kwarg_only()) { + return false; + } + if (!lhs->type()->isSubtypeOfExt(rhs->type(), why_not)) { + return false; + } + if (rhs->default_value().has_value() && + lhs->default_value() != rhs->default_value()) { + return false; + } + if (lhs->default_value().has_value() && !rhs->default_value().has_value()) { + return false; + } + return true; +} + +std::string FunctionSchema::formatTypeMismatchMsg( + const Argument& expected, + const std::string& actual_type, + std::optional position, + std::optional value) const { + std::string position_str; + if (position) { + position_str = c10::str("Position: ", *position, "\n"); + } + std::string value_str; + if (value) { + value_str = c10::str("Value: ", *value, "\n"); + } + return c10::str( + name(), + "() ", + expected.formatTypeMismatchMsg(actual_type), + position_str, + value_str, + "Declaration: ", + *this); +} + +bool FunctionSchema::isBackwardCompatibleWith( + const FunctionSchema& old, + std::ostream* why_not) const { + if (!(name() == old.name() + && overload_name() == old.overload_name() + // we are conservative on is_vararg and is_varret, + // since they are only used by internal operators + && is_vararg() == old.is_vararg() + && is_varret() == old.is_varret() + && returns().size() == old.returns().size() + && arguments().size() >= old.arguments().size())) { + return false; + } + for (const auto i : c10::irange(returns().size())) { + // Backwards compatibility requires covariance on argument types + // (i.e. more generic), and contravariance on return types (i.e. + // more specific). + if (!old.returns().at(i).isBackwardCompatibleWith( + returns().at(i), + why_not)) { + return false; + } + } + + // we want to test both out and default args separately + size_t old_out_start_idx = findFirstOutArg(old.arguments()); + size_t new_out_start_idx = findFirstOutArg(arguments()); + + // make sure among the default args, they are backward compatible + for (const auto i : c10::irange(old_out_start_idx)) { + if (!arguments().at(i).isBackwardCompatibleWith( + old.arguments().at(i), why_not)) { + return false; + } + } + + // Validate that all new arguments provided has a default value + for (const auto i : c10::irange(old_out_start_idx, new_out_start_idx)) { + if (!arguments().at(i).default_value()) { + if (why_not) { + *why_not + << "Function schema not backward compatible since the new argument '" + << arguments().at(i).name() << "' of type " + << arguments().at(i).type()->str() + << " did not provide a default value."; + } + return false; + } + } + + // now compare the out args + for (const auto i : c10::irange(old_out_start_idx, old.arguments().size())) { + if (!arguments() + .at(i - old_out_start_idx + new_out_start_idx) + .isBackwardCompatibleWith(old.arguments().at(i), why_not)) { + return false; + } + } + + return true; +} + +bool FunctionSchema::isForwardCompatibleWith( + const FunctionSchema& old, + std::ostringstream& why_not) const { + if (!(name() == old.name() && + overload_name() == old.overload_name() + // we are conservative on is_vararg and is_varret, + // since they are only used by internal operators + && is_vararg() == old.is_vararg() && is_varret() == old.is_varret() && + returns().size() == old.returns().size())) { + return false; + } + + // we want to test both out and default args separately + size_t old_out_start_idx = findFirstOutArg(old.arguments()); + size_t new_out_start_idx = findFirstOutArg(arguments()); + + if (old.arguments().size() - old_out_start_idx != + arguments().size() - new_out_start_idx) { + if (why_not) { + why_not << "Function schema should have the " + << "same number of out arguments"; + } + return false; + } + + // make sure among the default args, they are forward compatible + for (size_t i = 0; i < std::min(old_out_start_idx, new_out_start_idx); i++) { + if (!arguments().at(i).isForwardCompatibleWith(old.arguments().at(i))) { + if (why_not) { + why_not + << "'" << arguments().at(i).name() << "'" + << " is not forward compatible with the older version of the schema"; + } + return false; + } + } + + // Validate that all new arguments provided has a default value + for (size_t i = old_out_start_idx; i < new_out_start_idx; ++i) { + if (!arguments().at(i).default_value()) { + if (why_not) { + why_not + << "Function schema is not forward compatible since the new argument '" + << arguments().at(i).name() << "' of type " + << arguments().at(i).type()->str() + << " did not provide a default value."; + } + return false; + } + + auto default_val = arguments().at(i).default_value().value(); + if (default_val.isList() || default_val.isGenericDict()) { + if (why_not) { + why_not + << "Function schema is not forward compatible since the new argument '" + << arguments().at(i).name() << "' of type " + << arguments().at(i).type()->str() << " has a container type " + << "as its default value."; + } + return false; + } + } + + // now compare the out args + for (size_t i = old_out_start_idx; i < old.arguments().size(); i++) { + if (!arguments() + .at(i - old_out_start_idx + new_out_start_idx) + .isForwardCompatibleWith(old.arguments().at(i))) { + if (why_not) { + why_not << "Out argument '" + << "'" << arguments().at(i).name() + << " is not FC with the older version of the schema"; + } + return false; + } + } + + return true; +} + +std::string FunctionSchema::findErrorInKwargs(const std::vector& kwargs) const { + // First check if any of the kwargs are unknown, i.e. don't match the name of + // any argument in the schema. + for (const auto& kwarg : kwargs) { + if (!std::count_if( + arguments().begin(), + arguments().end(), + [&kwarg](const Argument& argument) { + return argument.name() == kwarg; + })) { + return c10::str( + "Unknown keyword argument '", + kwarg, + "' for operator '", + name(), + "'. Schema: ", + *this); + } + } + // If there are unconsumed kwargs but none of them were unknown, the first + // positional argument present in the kwargs is duplicated. + for (const auto& argument : arguments()) { + if (std::find(kwargs.begin(), kwargs.end(), argument.name()) != kwargs.end()) { + AT_ASSERT(!argument.default_value()); + return c10::str( + "Argument '", + argument.name(), + "' specified both as positional and ", + "keyword argument. Schema: ", + *this); + } + } + return ""; +} + + +FunctionSchema FunctionSchema::cloneWithRemappedTypes( + const std::function type_map) const { + auto update_args = [&](const std::vector& args) { + std::vector new_args; + new_args.reserve(args.size()); + for(const Argument& arg : args) { + new_args.emplace_back(arg.cloneWithType(type_map(arg.type()))); + } + return new_args; + }; + return FunctionSchema( + name(), + overload_name(), + update_args(arguments()), + update_args(returns()), + is_vararg(), + is_varret()); +} + +// covariant subtyping of list of Arguments +static bool isSubtypeOfList( + ArrayRef child, + ArrayRef parent, + std::ostream* why_not) { + if (child.size() != parent.size()) { + return false; + } + for (const auto i : c10::irange(child.size())) { + const Argument& c = child[i]; + const Argument& p = parent[i]; + if (c.name() != p.name()) { + return false; + } + if (!c.type()->isSubtypeOfExt(*p.type(), why_not)) { + return false; + } + } + return true; +} + +bool FunctionSchema::isSubtypeOf( + const FunctionSchema& rhs, + bool as_method, + std::ostream* why_not) const { + size_t start = as_method ? 1 : 0; + // functions are contravariant in arguments but covariant in returns + return isSubtypeOfList( + ArrayRef(rhs.arguments()).slice(start), + ArrayRef(arguments()).slice(start), + why_not) && + isSubtypeOfList(returns(), rhs.returns(), why_not); +} + } // namespace c10 diff --git a/aten/src/ATen/core/function_schema.h b/aten/src/ATen/core/function_schema.h index de5247e0eaabca..8dab896b1411d7 100644 --- a/aten/src/ATen/core/function_schema.h +++ b/aten/src/ATen/core/function_schema.h @@ -25,7 +25,7 @@ using AliasTypeSet = std::vector; bool operator==(const Argument& lhs, const Argument& rhs); -struct Argument { +struct TORCH_API Argument { Argument( std::string name = "", const TypePtr& type = nullptr, @@ -622,7 +622,7 @@ inline std::ostream& operator<<(std::ostream& out, const Argument& arg) { return out; } -inline std::ostream& operator<<(std::ostream& out, const FunctionSchema& schema); +TORCH_API std::ostream& operator<<(std::ostream& out, const FunctionSchema& schema); inline std::string toString(const FunctionSchema& schema) { std::ostringstream str; diff --git a/aten/src/ATen/core/function_schema_inl.h b/aten/src/ATen/core/function_schema_inl.h index e81103ea928a17..a2fff1c130cb57 100644 --- a/aten/src/ATen/core/function_schema_inl.h +++ b/aten/src/ATen/core/function_schema_inl.h @@ -2,328 +2,8 @@ #include #include -// note: windows build doesn't find symbols in operator files unless -// this is a header file - namespace c10 { -inline std::ostream& operator<<(std::ostream& out, const FunctionSchema& schema) { - // eventually this should look almost identical to python arg parser, but - // it is simpler for now to work directly on this schema - - out << schema.name(); - if (!schema.overload_name().empty()) { - out << "." << schema.overload_name(); - } - out << "("; - - bool seen_kwarg_only = false; - for (const auto i : c10::irange(schema.arguments().size())) { - if (i > 0) out << ", "; - if (schema.arguments()[i].kwarg_only() && !seen_kwarg_only) { - out << "*, "; - seen_kwarg_only = true; - } - out << schema.arguments()[i]; - } - - if(schema.is_vararg()) { - if(!schema.arguments().empty()) - out << ", "; - out << "..."; - } - - out << ") -> "; - - const auto& returns = schema.returns(); - - /* - * We should skip parenthesis if we return a single item and it's not varret, - * or we return nothing but varret. - * - * Need special handling for schema - * aten::items.str(Dict(str, t) self) -> (str,t)[] - * Even though this schema returns a single item, we need add parenthesis. - * The is necessary so the printed schema can be parsed by the C++ SchemaParser - * Without the extra parenthesis, the parser sees the first parenthesis in '(str,t)' and mistakenly - * treat the return type as a tuple. An alternative is to enhance the Lexer - * to lookahead multiple tokens to accurately decide if the return type is - * a tuple. - */ - bool need_paren = !( - (returns.size() == 1 && !schema.is_varret()) || - (returns.empty() && schema.is_varret())); - - if (returns.size() == 1 && !schema.is_varret()) { - std::stringstream return_ss; - return_ss << returns.at(0); - auto return_str = return_ss.str(); - - // enclosing the single return item with parenthesis if the return type - // starts with a left parenthesis. - // - // There are 2 cases - // 1. something like 'aten::items.str(Dict(str, t) self) -> ((str, t)[])'. - // without the extra parenthesis, the c++ schem parser can not parse it. - // 2. something like '-> ((str, str))'. Need extra parenthesis so the return - // type is a single tuple rather than two strings. - // PR (https://github.com/pytorch/pytorch/pull/23204) has more context about - // this. test_serialize_and_deserialize (https://github.com/pytorch/pytorch/blob/master/test/test_function_schema.py#L15) - // also covers this case. - if (!return_str.empty() && return_str.front() == '(') { - need_paren = true; - } - } - - if (need_paren) { - out << "("; - } - for (const auto i : c10::irange(returns.size())) { - if (i > 0) { - out << ", "; - } - out << returns.at(i); - } - if (schema.is_varret()) { - if (!returns.empty()) { - out << ", "; - } - out << "..."; - } - if (need_paren) { - out << ")"; - } - return out; -} - -inline size_t findFirstOutArg(const std::vector& args) { - // find the start of out args in the schema - for (const auto out_start_idx : c10::irange(args.size())) { - if (args.at(out_start_idx).is_out()) { - return out_start_idx; - } - } - return args.size(); -} - -inline bool Argument::isBackwardCompatibleWith( - const Argument& old, - std::ostream* why_not) const { - const Argument* lhs = this; - const Argument* rhs = &old; - if (!(lhs->name() == rhs->name() - && lhs->N() == rhs->N() - && (lhs->alias_info() == rhs->alias_info() - || (lhs->alias_info() != nullptr && rhs->alias_info() != nullptr - && *lhs->alias_info() == *rhs->alias_info())))) { - return false; - } - if (lhs->kwarg_only() && !rhs->kwarg_only()) { - return false; - } - if (!rhs->type()->isSubtypeOfExt(*lhs->type(), why_not)) { - return false; - } - if (rhs->default_value().has_value() && - lhs->default_value() != rhs->default_value()) { - return false; - } - return true; -} - -inline bool Argument::isForwardCompatibleWith( - const Argument& old, - std::ostream* why_not) const { - const Argument* lhs = this; - const Argument* rhs = &old; - if (!(lhs->name() == rhs->name() - && lhs->N() == rhs->N() - && (lhs->alias_info() == rhs->alias_info() - || (lhs->alias_info() != nullptr && rhs->alias_info() != nullptr - && *lhs->alias_info() == *rhs->alias_info())))) { - return false; - } - if (lhs->kwarg_only() && !rhs->kwarg_only()) { - return false; - } - if (!lhs->type()->isSubtypeOfExt(rhs->type(), why_not)) { - return false; - } - if (rhs->default_value().has_value() && - lhs->default_value() != rhs->default_value()) { - return false; - } - if (lhs->default_value().has_value() && !rhs->default_value().has_value()) { - return false; - } - return true; -} - -inline std::string FunctionSchema::formatTypeMismatchMsg( - const Argument& expected, - const std::string& actual_type, - std::optional position, - std::optional value) const { - std::string position_str; - if (position) { - position_str = c10::str("Position: ", *position, "\n"); - } - std::string value_str; - if (value) { - value_str = c10::str("Value: ", *value, "\n"); - } - return c10::str( - name(), - "() ", - expected.formatTypeMismatchMsg(actual_type), - position_str, - value_str, - "Declaration: ", - *this); -} - -inline bool FunctionSchema::isBackwardCompatibleWith( - const FunctionSchema& old, - std::ostream* why_not) const { - if (!(name() == old.name() - && overload_name() == old.overload_name() - // we are conservative on is_vararg and is_varret, - // since they are only used by internal operators - && is_vararg() == old.is_vararg() - && is_varret() == old.is_varret() - && returns().size() == old.returns().size() - && arguments().size() >= old.arguments().size())) { - return false; - } - for (const auto i : c10::irange(returns().size())) { - // Backwards compatibility requires covariance on argument types - // (i.e. more generic), and contravariance on return types (i.e. - // more specific). - if (!old.returns().at(i).isBackwardCompatibleWith( - returns().at(i), - why_not)) { - return false; - } - } - - // we want to test both out and default args separately - size_t old_out_start_idx = findFirstOutArg(old.arguments()); - size_t new_out_start_idx = findFirstOutArg(arguments()); - - // make sure among the default args, they are backward compatible - for (const auto i : c10::irange(old_out_start_idx)) { - if (!arguments().at(i).isBackwardCompatibleWith( - old.arguments().at(i), why_not)) { - return false; - } - } - - // Validate that all new arguments provided has a default value - for (const auto i : c10::irange(old_out_start_idx, new_out_start_idx)) { - if (!arguments().at(i).default_value()) { - if (why_not) { - *why_not - << "Function schema not backward compatible since the new argument '" - << arguments().at(i).name() << "' of type " - << arguments().at(i).type()->str() - << " did not provide a default value."; - } - return false; - } - } - - // now compare the out args - for (const auto i : c10::irange(old_out_start_idx, old.arguments().size())) { - if (!arguments() - .at(i - old_out_start_idx + new_out_start_idx) - .isBackwardCompatibleWith(old.arguments().at(i), why_not)) { - return false; - } - } - - return true; -} - -inline bool FunctionSchema::isForwardCompatibleWith( - const FunctionSchema& old, - std::ostringstream& why_not) const { - if (!(name() == old.name() && - overload_name() == old.overload_name() - // we are conservative on is_vararg and is_varret, - // since they are only used by internal operators - && is_vararg() == old.is_vararg() && is_varret() == old.is_varret() && - returns().size() == old.returns().size())) { - return false; - } - - // we want to test both out and default args separately - size_t old_out_start_idx = findFirstOutArg(old.arguments()); - size_t new_out_start_idx = findFirstOutArg(arguments()); - - if (old.arguments().size() - old_out_start_idx != - arguments().size() - new_out_start_idx) { - if (why_not) { - why_not << "Function schema should have the " - << "same number of out arguments"; - } - return false; - } - - // make sure among the default args, they are forward compatible - for (size_t i = 0; i < std::min(old_out_start_idx, new_out_start_idx); i++) { - if (!arguments().at(i).isForwardCompatibleWith(old.arguments().at(i))) { - if (why_not) { - why_not - << "'" << arguments().at(i).name() << "'" - << " is not forward compatible with the older version of the schema"; - } - return false; - } - } - - // Validate that all new arguments provided has a default value - for (size_t i = old_out_start_idx; i < new_out_start_idx; ++i) { - if (!arguments().at(i).default_value()) { - if (why_not) { - why_not - << "Function schema is not forward compatible since the new argument '" - << arguments().at(i).name() << "' of type " - << arguments().at(i).type()->str() - << " did not provide a default value."; - } - return false; - } - - auto default_val = arguments().at(i).default_value().value(); - if (default_val.isList() || default_val.isGenericDict()) { - if (why_not) { - why_not - << "Function schema is not forward compatible since the new argument '" - << arguments().at(i).name() << "' of type " - << arguments().at(i).type()->str() << " has a container type " - << "as its default value."; - } - return false; - } - } - - // now compare the out args - for (size_t i = old_out_start_idx; i < old.arguments().size(); i++) { - if (!arguments() - .at(i - old_out_start_idx + new_out_start_idx) - .isForwardCompatibleWith(old.arguments().at(i))) { - if (why_not) { - why_not << "Out argument '" - << "'" << arguments().at(i).name() - << " is not FC with the older version of the schema"; - } - return false; - } - } - - return true; -} - template inline void FunctionSchema::checkArg( const IValue& value, @@ -341,41 +21,6 @@ inline void FunctionSchema::checkArg( } } -inline std::string FunctionSchema::findErrorInKwargs(const std::vector& kwargs) const { - // First check if any of the kwargs are unknown, i.e. don't match the name of - // any argument in the schema. - for (const auto& kwarg : kwargs) { - if (!std::count_if( - arguments().begin(), - arguments().end(), - [&kwarg](const Argument& argument) { - return argument.name() == kwarg; - })) { - return c10::str( - "Unknown keyword argument '", - kwarg, - "' for operator '", - name(), - "'. Schema: ", - *this); - } - } - // If there are unconsumed kwargs but none of them were unknown, the first - // positional argument present in the kwargs is duplicated. - for (const auto& argument : arguments()) { - if (std::find(kwargs.begin(), kwargs.end(), argument.name()) != kwargs.end()) { - AT_ASSERT(!argument.default_value()); - return c10::str( - "Argument '", - argument.name(), - "' specified both as positional and ", - "keyword argument. Schema: ", - *this); - } - } - return ""; -} - template inline void FunctionSchema::checkAndNormalizeInputs( std::vector& inputs, @@ -427,57 +72,4 @@ inline void FunctionSchema::checkAndNormalizeInputs( } } -inline FunctionSchema FunctionSchema::cloneWithRemappedTypes( - const std::function type_map) const { - auto update_args = [&](const std::vector& args) { - std::vector new_args; - new_args.reserve(args.size()); - for(const Argument& arg : args) { - new_args.emplace_back(arg.cloneWithType(type_map(arg.type()))); - } - return new_args; - }; - return FunctionSchema( - name(), - overload_name(), - update_args(arguments()), - update_args(returns()), - is_vararg(), - is_varret()); -} - -// covariant subtyping of list of Arguments -inline bool isSubtypeOfList( - ArrayRef child, - ArrayRef parent, - std::ostream* why_not) { - if (child.size() != parent.size()) { - return false; - } - for (const auto i : c10::irange(child.size())) { - const Argument& c = child[i]; - const Argument& p = parent[i]; - if (c.name() != p.name()) { - return false; - } - if (!c.type()->isSubtypeOfExt(*p.type(), why_not)) { - return false; - } - } - return true; -} - -inline bool FunctionSchema::isSubtypeOf( - const FunctionSchema& rhs, - bool as_method, - std::ostream* why_not) const { - size_t start = as_method ? 1 : 0; - // functions are contravariant in arguments but covariant in returns - return isSubtypeOfList( - ArrayRef(rhs.arguments()).slice(start), - ArrayRef(arguments()).slice(start), - why_not) && - isSubtypeOfList(returns(), rhs.returns(), why_not); -} - } // namespace c10 From 847157adb67dff7da69183bc13dcf92d58728edc Mon Sep 17 00:00:00 2001 From: "Edward Z. Yang" Date: Sun, 1 Sep 2024 12:36:16 -0700 Subject: [PATCH 0330/1018] Don't setup try-except handler when Dynamo compiling (#133239) The reraise is not supported and so this just gunks up our actual exception handling. You can trigger this by hitting an exception inside of an NN module that has hooks on it. You end up graph breaking on the reraise here, and losing the inner stack trace from the actual exception that was raised. This might be kind of controversial. An alternate strategy is to support reraises in Dynamo or something but IDK this doesn't feel like the right place to apply force. Signed-off-by: Edward Z. Yang Pull Request resolved: https://github.com/pytorch/pytorch/pull/133239 Approved by: https://github.com/anijain2305 --- test/dynamo/test_exceptions.py | 16 ++++++++++++++++ torch/nn/modules/module.py | 22 +++++++++++++++++++--- 2 files changed, 35 insertions(+), 3 deletions(-) diff --git a/test/dynamo/test_exceptions.py b/test/dynamo/test_exceptions.py index 5efa344c7ea2ed..d6613d84560af0 100644 --- a/test/dynamo/test_exceptions.py +++ b/test/dynamo/test_exceptions.py @@ -4,6 +4,7 @@ import torch._dynamo.config import torch._dynamo.test_case import torch._functorch.config +import torch.nn import torch.utils.checkpoint @@ -317,6 +318,21 @@ def fn(x, y): res = opt_fn(x, y) self.assertEqual(ref, res) + def test_nn_reraise(self): + class M(torch.nn.Module): + def forward(self, x): + raise ValueError("woof") + return x + 2 + + m = M() + m.register_forward_pre_hook(lambda m, go: None) + + torch._dynamo.utils.clear_compilation_metrics() + opt_call = torch.compile(lambda x: m(x), backend="eager") + self.assertRaises(ValueError, lambda: opt_call(torch.randn(3))) + metrics = torch._dynamo.utils.get_compilation_metrics() + self.assertEqual(metrics[0].fail_reason, "Observed exception") + def test_key_error(self): def fn(x, d): try: diff --git a/torch/nn/modules/module.py b/torch/nn/modules/module.py index a15850553f1f56..f4796e50e415a1 100644 --- a/torch/nn/modules/module.py +++ b/torch/nn/modules/module.py @@ -1746,9 +1746,11 @@ def _call_impl(self, *args, **kwargs): or _global_forward_hooks or _global_forward_pre_hooks): return forward_call(*args, **kwargs) - try: - result = None - called_always_called_hooks = set() + result = None + called_always_called_hooks = set() + + def inner(): + nonlocal result, args, kwargs full_backward_hooks, non_full_backward_hooks = [], [] backward_pre_hooks = [] @@ -1826,6 +1828,20 @@ def _call_impl(self, *args, **kwargs): return result + from torch.compiler import is_compiling + + # This is technically not behavior equivalent when compiling, but it's + # incredibly unlikely we will ever support throwing an exception in NN + # module, and then catching it here, and then reraising it, and then + # catching it again, and expecting the resulting frame to be compiled. + # The reraise here just gunks up our exception handling for no good + # reason. Don't try to run the always called hooks in event of + # exception. + if is_compiling(): + return inner() + + try: + return inner() except Exception: # run always called hooks if they have not already been run # For now only forward hooks have the always_call option but perhaps From 07e045eb663173a92e3320957017c510cdb9c8df Mon Sep 17 00:00:00 2001 From: "Edward Z. Yang" Date: Sun, 1 Sep 2024 12:41:13 -0700 Subject: [PATCH 0331/1018] Stop adding useless prefix to error message here, you're pushing the important info off the screen. (#133108) Signed-off-by: Edward Z. Yang Pull Request resolved: https://github.com/pytorch/pytorch/pull/133108 Approved by: https://github.com/Skylion007 --- test/export/test_export.py | 5 +---- torch/_dynamo/utils.py | 5 +---- 2 files changed, 2 insertions(+), 8 deletions(-) diff --git a/test/export/test_export.py b/test/export/test_export.py index a7a9b2c5a31c8e..e93a721eb97505 100644 --- a/test/export/test_export.py +++ b/test/export/test_export.py @@ -4362,12 +4362,9 @@ def forward(self, x): f = Module() if is_non_strict_test(self._testMethodName): error = torch.fx.experimental.symbolic_shapes.GuardOnDataDependentSymNode - error_msg = r"Could not guard on data-dependent expression" else: error = torch._dynamo.exc.UserError - error_msg = ( - r"Tried to use data-dependent value in the subsequent computation" - ) + error_msg = r"Could not guard on data-dependent expression" with self.assertRaisesRegex(error, error_msg): _ = export(f, (torch.tensor(6),)) diff --git a/torch/_dynamo/utils.py b/torch/_dynamo/utils.py index fa4103d9c3211c..2b9d4b96081471 100644 --- a/torch/_dynamo/utils.py +++ b/torch/_dynamo/utils.py @@ -2163,10 +2163,7 @@ def get_fake_value(node, tx, allow_non_graph_fake=False): ): raise UserError( # noqa: B904 UserErrorType.CONSTRAINT_VIOLATION, - "Tried to use data-dependent value in the subsequent computation. " - "This can happen when we encounter unbounded dynamic value that is unknown during tracing time. " - "You will need to explicitly give hint to the compiler. Please take a look at " - f"torch._check OR torch._check_is_size APIs. {cause}", + str(cause), case_name="constrain_as_size_example", ) elif isinstance(cause, ValueRangeError): From ce10395b5552eddef0d4b57ea47f04acfe145310 Mon Sep 17 00:00:00 2001 From: Aaron Orenstein Date: Mon, 2 Sep 2024 00:27:40 +0000 Subject: [PATCH 0332/1018] Reorg cache code to make it simpler (#134911) Summary: Pull the big nested function out of the middle of cached_autotune() into its own class. Also refactor creating the autotune cache itself out - which gets shared in the next diff. Test Plan: unit tests Differential Revision: D60677501 Pull Request resolved: https://github.com/pytorch/pytorch/pull/134911 Approved by: https://github.com/oulgen --- torch/_inductor/runtime/autotune_cache.py | 237 +++++++++++++++++++ torch/_inductor/runtime/triton_heuristics.py | 169 ++----------- 2 files changed, 253 insertions(+), 153 deletions(-) create mode 100644 torch/_inductor/runtime/autotune_cache.py diff --git a/torch/_inductor/runtime/autotune_cache.py b/torch/_inductor/runtime/autotune_cache.py new file mode 100644 index 00000000000000..65dfc73d63d72d --- /dev/null +++ b/torch/_inductor/runtime/autotune_cache.py @@ -0,0 +1,237 @@ +from __future__ import annotations + +import dataclasses +import hashlib +import logging +import os +import os.path +from typing import Dict, List, Optional, Tuple +from typing_extensions import override + +import torch +from torch.utils._triton import has_triton_package + +from ..remote_cache import ( + JsonDataTy, + RemoteCache, + RemoteCacheBackend, + RemoteCacheJsonSerde, +) + + +if has_triton_package(): + from triton import Config + +log = logging.getLogger(__name__) + + +_InductorMetaTy = Dict[str, object] + + +@dataclasses.dataclass +class AutotuneCache: + configs_hash: str + filename: str + local_cache: Optional[Tuple[RemoteCache[JsonDataTy], str]] = None + remote_cache: Optional[Tuple[RemoteCache[JsonDataTy], str]] = None + + # Create a AutotuneCache. Returns None if none of the caches can be used. + @staticmethod + def create( + inductor_meta: _InductorMetaTy, filename: str, configs_hash: str + ) -> Optional[AutotuneCache]: + cache = AutotuneCache(configs_hash, filename) + cache._setup_local_cache(inductor_meta, filename) + cache._setup_remote_autotune_cache(inductor_meta, filename) + if cache.local_cache or cache.remote_cache: + return cache + else: + return None + + # Read the best config options from the most local cache and return it. + def _read(self, inductor_meta: _InductorMetaTy) -> Optional[Dict[str, JsonDataTy]]: + if local_cache := self.local_cache: + cache, key = local_cache + if best_config := cache.get(key): + if isinstance(best_config, dict): + return best_config + + if remote_cache := self.remote_cache: + cache, key = remote_cache + if best_config := cache.get(key): + if isinstance(best_config, dict): + return best_config + + return None + + # Read the best config options from the most local cache and figure out + # which `configs` represents that option. + def read_best( + self, inductor_meta: _InductorMetaTy, configs: List[Config] + ) -> Optional[Config]: + if best := self._read(inductor_meta): + return _load_cached_autotuning( + best, self.configs_hash, configs, inductor_meta + ) + return None + + # Set up local filesystem caching information + def _setup_local_cache(self, inductor_meta: _InductorMetaTy, filename: str) -> None: + if not inductor_meta.get("autotune_local_cache", True): + return + + cache_filename = os.path.splitext(filename)[0] + ".best_config" + local_cache = RemoteCache(_LocalAutotuneCacheBackend(), RemoteCacheJsonSerde()) + self.local_cache = (local_cache, cache_filename) + + # Set up remote caching information + def _setup_remote_autotune_cache( + self, inductor_meta: _InductorMetaTy, filename: str + ) -> None: + if not _should_use_remote_autotune_cache(inductor_meta): + return + + remote_cache = _create_cache( + inductor_meta, + self.configs_hash, + "FbRemoteAutotuneCache", + "RemoteAutotuneCache", + "autotune-best-config-v2", + ) + if not remote_cache: + return + + # we already sha256 hash the source contents + remote_cache_key = os.path.basename(filename) + self.remote_cache = (remote_cache, remote_cache_key) + + # Save the config in the caches + def save( + self, config: Config, time_taken_ns: int, found_by_coordesc: bool = False + ) -> None: + data = { + **config.kwargs, + "num_warps": config.num_warps, + "num_stages": config.num_stages, + "configs_hash": self.configs_hash, + "found_by_coordesc": found_by_coordesc, + "time_taken_ms": time_taken_ns // 1000000, # Convert from NS to MS + } + + if local_cache := self.local_cache: + cache, key = local_cache + cache.put(key, data) + + if log.isEnabledFor(logging.DEBUG): + type_str = "coordesc" if found_by_coordesc else "heuristic" + log.debug("Save %s tuning result to %s", type_str, key) + + if remote_cache := self.remote_cache: + cache, key = remote_cache + cache.put(key, data) + + +def _should_use_remote_autotune_cache(inductor_meta: Dict[str, object]) -> bool: + if (config := inductor_meta.get("autotune_remote_cache")) is not None: + return bool(config) + if not inductor_meta.get("is_fbcode"): + return False + if torch._utils_internal.is_fb_unit_test(): + return False + if inductor_meta.get("is_hip"): + return False + + try: + from torch._inductor.fb.remote_cache import REMOTE_CACHE_VERSION + except ModuleNotFoundError: + return False + + return REMOTE_CACHE_VERSION >= torch._utils_internal.justknobs_getval_int( + "pytorch/remote_cache:autotune_memcache_version" + ) + + +def _load_cached_autotuning( + best_config: Dict[str, JsonDataTy], + configs_hash: str, + configs: List[Config], + inductor_meta: Dict[str, object], +) -> Optional[Config]: + if best_config is None: + return None + if best_config.pop("configs_hash", None) != configs_hash: + return None + + # Remove time taken for comparison + best_config.pop("time_taken_ms", None) + + if inductor_meta.get("coordinate_descent_tuning") and best_config.pop( + "found_by_coordesc", False + ): + num_warps = best_config.pop("num_warps") + num_stages = best_config.pop("num_stages") + triton_config = Config(best_config, num_warps=num_warps, num_stages=num_stages) + triton_config.found_by_coordesc = True + return triton_config + + matching_configs = [ + cfg + for cfg in configs + if all(val == best_config.get(key) for key, val in cfg.kwargs.items()) + and cfg.num_warps == best_config.get("num_warps") + and cfg.num_stages == best_config.get("num_stages") + ] + if len(matching_configs) != 1: + return None + + return matching_configs[0] + + +def _create_cache( + inductor_meta: Dict[str, object], + configs_hash: str, + fb_cache_cls: str, + oss_cache_cls: str, + salt: str, +) -> Optional[RemoteCache[JsonDataTy]]: + backend_hash = inductor_meta.get("backend_hash", None) + if backend_hash is None: + log.debug( + "backend_hash is not passed on the inductor_meta, unable to use autotune remote cache" + ) + return None + + assert isinstance(backend_hash, str) + + key = backend_hash + configs_hash + salt + key = hashlib.sha256(key.encode("utf-8")).hexdigest() + + try: + if inductor_meta.get("is_fbcode"): + import torch._inductor.fb.remote_cache + + cache_cls = getattr(torch._inductor.fb.remote_cache, fb_cache_cls) + return cache_cls(key) + else: + import torch._inductor.remote_cache + + cache_cls = getattr(torch._inductor.remote_cache, oss_cache_cls) + return cache_cls(key) + except Exception: + log.warning("Unable to create a remote cache", exc_info=True) + return None + + +class _LocalAutotuneCacheBackend(RemoteCacheBackend[bytes]): + @override + def get(self, key: str) -> Optional[bytes]: + try: + with open(key, "rb") as fd: + return fd.read() + except FileNotFoundError: + return None + + @override + def put(self, key: str, data: bytes) -> None: + with open(key, "wb") as fd: + fd.write(data) diff --git a/torch/_inductor/runtime/triton_heuristics.py b/torch/_inductor/runtime/triton_heuristics.py index 288a44c6296961..c5a6bd69fd7263 100644 --- a/torch/_inductor/runtime/triton_heuristics.py +++ b/torch/_inductor/runtime/triton_heuristics.py @@ -6,7 +6,6 @@ import functools import hashlib import inspect -import json import logging import math import operator @@ -16,10 +15,11 @@ import sys import threading import time -from typing import Any, Callable, Dict, List, Optional, Set, Tuple, TYPE_CHECKING, Union +from typing import Any, Dict, List, Optional, Set, Tuple import torch +from .autotune_cache import AutotuneCache from .benchmarking import benchmarker from .coordinate_descent_tuner import CoordescTuner from .hints import ( @@ -46,10 +46,6 @@ ) -if TYPE_CHECKING: - from ..remote_cache import JsonDataTy, RemoteCache - - try: import triton except ImportError: @@ -1004,74 +1000,6 @@ def hash_configs(configs: List[Config]): return hasher.hexdigest() -def load_cached_autotuning( - best_config, - configs_hash: str, - configs: List[Config], - inductor_meta: Dict[str, Any], -): - if best_config is None: - return None - if best_config.pop("configs_hash", None) != configs_hash: - return None - - # Remove time taken for comparison - best_config.pop("time_taken_ms", None) - - if inductor_meta.get("coordinate_descent_tuning") and best_config.pop( - "found_by_coordesc", False - ): - num_warps = best_config.pop("num_warps") - num_stages = best_config.pop("num_stages") - triton_config = Config(best_config, num_warps=num_warps, num_stages=num_stages) - triton_config.found_by_coordesc = True - return triton_config - - matching_configs = [ - cfg - for cfg in configs - if all(val == best_config.get(key) for key, val in cfg.kwargs.items()) - and cfg.num_warps == best_config.get("num_warps") - and cfg.num_stages == best_config.get("num_stages") - ] - if len(matching_configs) != 1: - return None - - return matching_configs[0] - - -def _should_use_remote_autotune_cache(inductor_meta): - if inductor_meta.get("autotune_remote_cache") is not None: - return inductor_meta.get("autotune_remote_cache") - if not inductor_meta.get("is_fbcode"): - return False - if torch._utils_internal.is_fb_unit_test(): - return False - if inductor_meta.get("is_hip"): - return False - - try: - from torch._inductor.fb.remote_cache import REMOTE_CACHE_VERSION - except ModuleNotFoundError: - return False - - return REMOTE_CACHE_VERSION >= torch._utils_internal.justknobs_getval_int( - "pytorch/remote_cache:autotune_memcache_version" - ) - - -class LocalAutotuneCache: - def get(self, filename): - if os.path.exists(filename): - with open(filename) as fd: - return json.loads(fd.read()) - return None - - def put(self, filename, data): - with open(filename, "w") as fd: - fd.write(json.dumps(data)) - - def cached_autotune( size_hints: Optional[List[int]], configs: List[Config], @@ -1087,92 +1015,27 @@ def cached_autotune( """ configs = unique_configs(configs) assert len(configs) == 1 or filename - save_cache_hook: Optional[Callable[[Any, Any, Any], Any]] inductor_meta = {} if inductor_meta is None else inductor_meta + disabled = inductor_meta.get("force_disable_caches", False) + # on disk caching logic and/or remote caching - if filename is not None and ( - len(configs) > 1 or inductor_meta.get("coordinate_descent_tuning") + autotune_cache = None + if ( + not disabled + and filename is not None + and (len(configs) > 1 or inductor_meta.get("coordinate_descent_tuning")) ): configs_hash = hash_configs(configs) - local_cache = None - cache_filename = None - remote_cache: Optional[ - Union[RemoteCache[JsonDataTy], RemoteCache[object]] - ] = None - remote_cache_key = None - best_config = None - if not inductor_meta.get("force_disable_caches", False): - if inductor_meta.get("autotune_local_cache", True): - local_cache = LocalAutotuneCache() - cache_filename = os.path.splitext(filename)[0] + ".best_config" - if _should_use_remote_autotune_cache(inductor_meta): - backend_hash = inductor_meta.get("backend_hash", None) - if backend_hash is not None: - key = backend_hash + configs_hash + "autotune-best-config-v2" - key = hashlib.sha256(key.encode("utf-8")).hexdigest() - - try: - if inductor_meta.get("is_fbcode"): - from torch._inductor.fb.remote_cache import ( - FbRemoteAutotuneCache, - ) - - remote_cache = FbRemoteAutotuneCache(key) - else: - from torch._inductor.remote_cache import RemoteAutotuneCache - - remote_cache = RemoteAutotuneCache(key) - except Exception: - remote_cache = None - log.warning("Unable to create a remote cache", exc_info=True) - # we already sha256 hash the source contents - remote_cache_key = os.path.basename(filename) - else: - log.debug( - "backend_hash is not passed on the inductor_meta, unable to use autotune remote cache" - ) - - best_config = None - if local_cache is not None and cache_filename is not None: - best_config = local_cache.get(cache_filename) - if ( - remote_cache is not None - and remote_cache_key is not None - and best_config is None - ): - best_config = remote_cache.get(remote_cache_key) - - best_config = load_cached_autotuning( - best_config, configs_hash, configs, inductor_meta - ) - if best_config: + autotune_cache = AutotuneCache.create(inductor_meta, filename, configs_hash) + if autotune_cache: + if best_config := autotune_cache.read_best(inductor_meta, configs): configs = [best_config] - else: - log.debug("autotune caching is disabled by config.force_disable_caches") - - def save_cache_hook(cfg, time_taken_ns, found_by_coordesc=False): - data = { - **cfg.kwargs, - "num_warps": cfg.num_warps, - "num_stages": cfg.num_stages, - "configs_hash": configs_hash, - "found_by_coordesc": found_by_coordesc, - "time_taken_ms": time_taken_ns // 1000000, # Convert from NS to MS - } - if local_cache is not None and cache_filename is not None: - local_cache.put(cache_filename, data) - if remote_cache is not None and remote_cache_key is not None: - remote_cache.put(remote_cache_key, data) # type: ignore[arg-type] - - if log.isEnabledFor(logging.DEBUG): - type_str = "coordesc" if found_by_coordesc else "heuristic" - log.debug("Save %s tuning result to %s", type_str, cache_filename) - else: - save_cache_hook = None + if disabled: + log.debug("autotune caching is disabled by config.force_disable_caches") mutated_arg_names = inductor_meta.pop("mutated_arg_names", ()) @@ -1199,7 +1062,7 @@ def decorator(fn): "profile_bandwidth_with_do_bench_using_profiling" ], configs=configs, - save_cache_hook=save_cache_hook, + save_cache_hook=autotune_cache and autotune_cache.save, mutated_arg_names=mutated_arg_names, heuristic_type=heuristic_type, size_hints=size_hints, @@ -1211,7 +1074,7 @@ def decorator(fn): triton_meta=triton_meta, inductor_meta=inductor_meta, configs=configs, - save_cache_hook=save_cache_hook, + save_cache_hook=autotune_cache and autotune_cache.save, mutated_arg_names=mutated_arg_names, heuristic_type=heuristic_type, size_hints=size_hints, From 1308f3738dbb15cda68314d325584b10c10f81ba Mon Sep 17 00:00:00 2001 From: "Sun, Jiayi" Date: Fri, 30 Aug 2024 19:46:46 -0700 Subject: [PATCH 0333/1018] [Inductor] Apply loop split optimization in codegen_node (#132389) This PR applies loop split optimization in codegen_node to avoid non-contiguous load. When the vector is loaded in a non-contiguous manner due to a division in the index, we eliminate the division by splitting the loop to avoid non-contiguous load. Example: ``` import torch import torch.nn as nn class GNReLU(torch.nn.Module): def __init__(self, num_groups, num_channels): super(GNReLU, self).__init__() self.gn = nn.GroupNorm(num_groups, num_channels) def forward(self, x): return torch.nn.functional.relu(self.gn(x)) input = torch.randn(2, 960, 96, 96).to(memory_format=torch.channels_last) m = GNReLU(32, 960).eval() compiled_m = torch.compile(m) with torch.no_grad(): compiled_m(input) ``` Generated code: - Before: ``` cpp_fused_native_group_norm_relu_0 = async_compile.cpp_pybinding(['const float*', 'const float*', 'const float*', 'float*', 'float*', 'float*'], ''' #include "/tmp/torchinductor_jiayisun/vu/cvuckxaygqfovv2zu2byqhcmiejbke7mdhf2rpgpr5mlscdev2hg.h" extern "C" void kernel(const float* in_ptr0, const float* in_ptr1, const float* in_ptr2, float* out_ptr0, float* out_ptr1, float* out_ptr2) { #pragma omp parallel num_threads(56) { int tid = omp_get_thread_num(); { #pragma omp for collapse(2) for(long x0=static_cast(0L); x0(2L); x0+=static_cast(1L)) { for(long x1=static_cast(0L); x1(32L); x1+=static_cast(1L)) { { Welford tmp_acc0 = Welford(); Welford> tmp_acc0_vec = Welford>(); Welford> masked_tmp_acc0_vec = Welford>(); static WeightRecp> wrecps0(static_cast(17280L)); for(long x2=static_cast(0L); x2(9216L); x2+=static_cast(1L)) { for(long x3=static_cast(0L); x3(16L); x3+=static_cast(16L)) { auto tmp0 = at::vec::Vectorized::loadu(in_ptr0 + static_cast(x3 + (30L*x1) + (960L*x2) + (8847360L*x0)), 16); tmp_acc0_vec = welford_combine(tmp_acc0_vec, tmp0, &wrecps0); } for(long x3=static_cast(16L); x3(30L); x3+=static_cast(14L)) { auto tmp0 = at::vec::Vectorized::loadu(in_ptr0 + static_cast(x3 + (30L*x1) + (960L*x2) + (8847360L*x0)), 14); masked_tmp_acc0_vec = welford_combine(masked_tmp_acc0_vec, tmp0, 14, &wrecps0); } } tmp_acc0 = welford_combine(tmp_acc0, welford_vec_reduce_all(masked_tmp_acc0_vec)); tmp_acc0 = welford_combine(tmp_acc0, welford_vec_reduce_all(tmp_acc0_vec)); out_ptr0[static_cast(x1 + (32L*x0))] = static_cast(tmp_acc0.mean); out_ptr1[static_cast(x1 + (32L*x0))] = static_cast(tmp_acc0.m2); } } } } { #pragma omp for collapse(2) for(long x0=static_cast(0L); x0(2L); x0+=static_cast(1L)) { for(long x1=static_cast(0L); x1(9216L); x1+=static_cast(1L)) { for(long x2=static_cast(0L); x2(960L); x2+=static_cast(16L)) { auto tmp0 = at::vec::Vectorized::loadu(in_ptr0 + static_cast(x2 + (960L*x1) + (8847360L*x0)), 16); auto tmp1 = [&] { __at_align__ std::array tmpbuf; #pragma GCC unroll 16 for (long x2_inner = 0; x2_inner < 16; x2_inner++) { tmpbuf[x2_inner] = out_ptr0[static_cast((32L*x0) + (c10::div_floor_integer((x2 + x2_inner), 30L)))]; } return at::vec::Vectorized::loadu(tmpbuf.data(), 16); } () ; auto tmp3 = [&] { __at_align__ std::array tmpbuf; #pragma GCC unroll 16 for (long x2_inner = 0; x2_inner < 16; x2_inner++) { tmpbuf[x2_inner] = out_ptr1[static_cast((32L*x0) + (c10::div_floor_integer((x2 + x2_inner), 30L)))]; } return at::vec::Vectorized::loadu(tmpbuf.data(), 16); } () ; auto tmp12 = at::vec::Vectorized::loadu(in_ptr1 + static_cast(x2), 16); auto tmp14 = at::vec::Vectorized::loadu(in_ptr2 + static_cast(x2), 16); auto tmp2 = tmp0 - tmp1; auto tmp4 = static_cast(276480.0); auto tmp5 = at::vec::Vectorized(tmp4); auto tmp6 = tmp3 / tmp5; auto tmp7 = static_cast(1e-05); auto tmp8 = at::vec::Vectorized(tmp7); auto tmp9 = tmp6 + tmp8; auto tmp10 = tmp9.rsqrt(); auto tmp11 = tmp2 * tmp10; auto tmp13 = tmp11 * tmp12; auto tmp15 = tmp13 + tmp14; auto tmp16 = at::vec::clamp_min(tmp15, decltype(tmp15)(0)); tmp16.store(out_ptr2 + static_cast(x2 + (960L*x1) + (8847360L*x0))); } } } } } } ''') async_compile.wait(globals()) del async_compile def call(args): arg2_1, = args args.clear() assert_size_stride(arg2_1, (2, 960, 96, 96), (8847360, 1, 92160, 960)) buf0 = empty_strided_cpu((2, 32, 1, 1), (32, 1, 64, 64), torch.float32) buf1 = empty_strided_cpu((2, 32, 1, 1), (32, 1, 64, 64), torch.float32) buf3 = empty_strided_cpu((2, 960, 96, 96), (8847360, 1, 92160, 960), torch.float32) cpp_fused_native_group_norm_relu_0(arg2_1, _frozen_param3, _frozen_param2, buf0, buf1, buf3) del arg2_1 return (buf3, ) ``` - After: ``` cpp_fused_native_group_norm_relu_0 = async_compile.cpp_pybinding(['const float*', 'const float*', 'const float*', 'float*', 'float*', 'float*'], ''' #include "/tmp/torchinductor_jiayisun/vu/cvuckxaygqfovv2zu2byqhcmiejbke7mdhf2rpgpr5mlscdev2hg.h" extern "C" void kernel(const float* in_ptr0, const float* in_ptr1, const float* in_ptr2, float* out_ptr0, float* out_ptr1, float* out_ptr2) { #pragma omp parallel num_threads(56) { int tid = omp_get_thread_num(); { #pragma omp for collapse(2) for(long x0=static_cast(0L); x0(2L); x0+=static_cast(1L)) { for(long x1=static_cast(0L); x1(32L); x1+=static_cast(1L)) { { Welford tmp_acc0 = Welford(); Welford> tmp_acc0_vec = Welford>(); Welford> masked_tmp_acc0_vec = Welford>(); static WeightRecp> wrecps0(static_cast(17280L)); for(long x2=static_cast(0L); x2(9216L); x2+=static_cast(1L)) { for(long x3=static_cast(0L); x3(16L); x3+=static_cast(16L)) { auto tmp0 = at::vec::Vectorized::loadu(in_ptr0 + static_cast(x3 + (30L*x1) + (960L*x2) + (8847360L*x0)), 16); tmp_acc0_vec = welford_combine(tmp_acc0_vec, tmp0, &wrecps0); } for(long x3=static_cast(16L); x3(30L); x3+=static_cast(14L)) { auto tmp0 = at::vec::Vectorized::loadu(in_ptr0 + static_cast(x3 + (30L*x1) + (960L*x2) + (8847360L*x0)), 14); masked_tmp_acc0_vec = welford_combine(masked_tmp_acc0_vec, tmp0, 14, &wrecps0); } } tmp_acc0 = welford_combine(tmp_acc0, welford_vec_reduce_all(masked_tmp_acc0_vec)); tmp_acc0 = welford_combine(tmp_acc0, welford_vec_reduce_all(tmp_acc0_vec)); out_ptr0[static_cast(x1 + (32L*x0))] = static_cast(tmp_acc0.mean); out_ptr1[static_cast(x1 + (32L*x0))] = static_cast(tmp_acc0.m2); } } } } { #pragma omp for collapse(2) for(long x0=static_cast(0L); x0(2L); x0+=static_cast(1L)) { for(long x1=static_cast(0L); x1(9216L); x1+=static_cast(1L)) { #pragma GCC ivdep for(long x2=static_cast(0L); x2(32L); x2+=static_cast(1L)) { for(long x3=static_cast(0L); x3(16L); x3+=static_cast(16L)) { auto tmp0 = at::vec::Vectorized::loadu(in_ptr0 + static_cast(x3 + (30L*x2) + (960L*x1) + (8847360L*x0)), 16); auto tmp1 = out_ptr0[static_cast(x2 + (32L*x0))]; auto tmp4 = out_ptr1[static_cast(x2 + (32L*x0))]; auto tmp12 = at::vec::Vectorized::loadu(in_ptr1 + static_cast(x3 + (30L*x2)), 16); auto tmp14 = at::vec::Vectorized::loadu(in_ptr2 + static_cast(x3 + (30L*x2)), 16); auto tmp2 = at::vec::Vectorized(tmp1); auto tmp3 = tmp0 - tmp2; auto tmp5 = static_cast(276480.0); auto tmp6 = tmp4 / tmp5; auto tmp7 = static_cast(1e-05); auto tmp8 = decltype(tmp6)(tmp6 + tmp7); auto tmp9 = 1 / std::sqrt(tmp8); auto tmp10 = at::vec::Vectorized(tmp9); auto tmp11 = tmp3 * tmp10; auto tmp13 = tmp11 * tmp12; auto tmp15 = tmp13 + tmp14; auto tmp16 = at::vec::clamp_min(tmp15, decltype(tmp15)(0)); tmp16.store(out_ptr2 + static_cast(x3 + (30L*x2) + (960L*x1) + (8847360L*x0))); } for(long x3=static_cast(16L); x3(30L); x3+=static_cast(14L)) { auto tmp0 = at::vec::Vectorized::loadu(in_ptr0 + static_cast(x3 + (30L*x2) + (960L*x1) + (8847360L*x0)), 14); auto tmp1 = out_ptr0[static_cast(x2 + (32L*x0))]; auto tmp4 = out_ptr1[static_cast(x2 + (32L*x0))]; auto tmp12 = at::vec::Vectorized::loadu(in_ptr1 + static_cast(x3 + (30L*x2)), 14); auto tmp14 = at::vec::Vectorized::loadu(in_ptr2 + static_cast(x3 + (30L*x2)), 14); auto tmp2 = at::vec::Vectorized(tmp1); auto tmp3 = tmp0 - tmp2; auto tmp5 = static_cast(276480.0); auto tmp6 = tmp4 / tmp5; auto tmp7 = static_cast(1e-05); auto tmp8 = decltype(tmp6)(tmp6 + tmp7); auto tmp9 = 1 / std::sqrt(tmp8); auto tmp10 = at::vec::Vectorized(tmp9); auto tmp11 = tmp3 * tmp10; auto tmp13 = tmp11 * tmp12; auto tmp15 = tmp13 + tmp14; auto tmp16 = at::vec::clamp_min(tmp15, decltype(tmp15)(0)); tmp16.store(out_ptr2 + static_cast(x3 + (30L*x2) + (960L*x1) + (8847360L*x0)), 14); } } } } } } } ''') async_compile.wait(globals()) del async_compile def call(args): arg2_1, = args args.clear() assert_size_stride(arg2_1, (2, 960, 96, 96), (8847360, 1, 92160, 960)) buf0 = empty_strided_cpu((2, 32, 1, 1), (32, 1, 64, 64), torch.float32) buf1 = empty_strided_cpu((2, 32, 1, 1), (32, 1, 64, 64), torch.float32) buf3 = empty_strided_cpu((2, 960, 96, 96), (8847360, 1, 92160, 960), torch.float32) cpp_fused_native_group_norm_relu_0(arg2_1, _frozen_param3, _frozen_param2, buf0, buf1, buf3) del arg2_1 return (buf3, ) ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/132389 Approved by: https://github.com/leslie-fang-intel, https://github.com/jansel Co-authored-by: Jiong Gong --- test/inductor/test_cpu_repro.py | 12 ++++ torch/_inductor/codegen/cpp.py | 116 +++++++++++++++++++++++++++++++- torch/_inductor/ir.py | 13 ++++ torch/_inductor/scheduler.py | 13 +++- 4 files changed, 148 insertions(+), 6 deletions(-) diff --git a/test/inductor/test_cpu_repro.py b/test/inductor/test_cpu_repro.py index 7dbe35f766f9a7..d1f5a82cdf6c74 100644 --- a/test/inductor/test_cpu_repro.py +++ b/test/inductor/test_cpu_repro.py @@ -3435,6 +3435,18 @@ def forward(self, x): # 2 generated kernels (one for var_mean, the other for result) check_metrics_vec_kernel_count(2) + # check loop split optimization + if fmt == torch.channels_last: + torch._dynamo.reset() + metrics.reset() + with torch.no_grad(): + opt_mod = torch.compile(mod) + _, code = run_and_get_cpp_code(opt_mod, x) + # check that there are no non_contiguous loads + FileCheck().check_count("__at_align__ std::array", 0, exactly=True).run( + code + ) + def test_int_div_vec(self): def fn(x, y, mode): return torch.div(x, y, rounding_mode=mode) diff --git a/torch/_inductor/codegen/cpp.py b/torch/_inductor/codegen/cpp.py index 6663de7639e8e1..999df248fc6110 100644 --- a/torch/_inductor/codegen/cpp.py +++ b/torch/_inductor/codegen/cpp.py @@ -313,9 +313,12 @@ def visit_modular_indexing(divisor, modulus): @functools.lru_cache -def stride_at_vec_range(index: sympy.Expr, var: sympy.Symbol, vec_length: int): - index_vec_simplified = simplify_index_in_vec_range(index, var, vec_length) - return stride_at(index_vec_simplified, var) +def stride_at_vec_range( + index: sympy.Expr, var: sympy.Symbol, vec_length: Optional[int] = None +): + if vec_length: + index = simplify_index_in_vec_range(index, var, vec_length) + return stride_at(index, var) class OuterLoopFusedSchedulerNode(FusedSchedulerNode): @@ -4085,6 +4088,112 @@ def can_fuse_vertical(self, node1, node2): self._can_fuse_horizontal_impl(node1, node2) and not node1.is_reduction() ) or self.can_fuse_vertical_outer_loop(node1, node2) + def try_loop_split(self, nodes: List[SchedulerNode]): + """ + Apply loop split optimization. + When one of the indexing_exprs contains a division, we eliminate the division by splitting the loop + to avoid non-contiguous loads, subject to the following conditions: + 1. No reduction and no mudular index for all nodes. + 2. Only one node's one indexing_exprs contains a division, according to this indexing_exprs, + we can get the dimension that needs to be split, and the split dimension is contiguous + in all other indexing_exprs. + + For example, if the node's var_ranges: {z0: 2, z1: 9216, z2: 960} and indexing_exprs: + {'index0': 8847360*z0 + 960*z1 + z2, 'index1': 32*z0 + (z2//30), 'index2': z2}, + we will split z2 -> 30*z2 + z3, then the node's var_ranges will be changed to + {z0: 2, z1: 9216, z2: 32, z3: 30} and indexing_exprs will be changed to + {'index0': 8847360*z0 + 960*z1 + 30*z2 + z3, 'index1': 32*z0 + z2, 'index2': 30*z2 + z3}. + """ + + # No reduction and no mudular + if any( + len(node.group[1][1]) != 0 + or any( + expr.has(ModularIndexing) for expr in node._body.indexing_exprs.values() + ) + for node in nodes + ): + return nodes + + split_var = None + split_number = None + divide_index_name = None + num_div = 0 + match_div = False + matched_node = None + + for node in nodes: + assert isinstance(node.node, ir.ComputedBuffer) + _, original_body, _ = node.node.get_default_sizes_body() + for name, expr in original_body.indexing_exprs.items(): + num_div += expr.count(FloorDiv) + if num_div > 1: + return nodes + if expr.count(FloorDiv) == 1: + div_expr = expr.find(FloorDiv).pop() + split_var = div_expr.args[0] + split_number = div_expr.args[1] + divide_index_name = name + if ( + isinstance(split_number, sympy.core.numbers.Integer) + and isinstance(split_var, sympy.core.symbol.Symbol) + and divide_index_name is not None + and all( + stride_at_vec_range(expr, split_var) == 1 + for name, expr in original_body.indexing_exprs.items() + if name != divide_index_name + ) + ): + match_div = True + matched_node = node + + # Only one node contains a division, and the split dimension is contiguous in all other indexing_exprs. + if not match_div: + return nodes + + extra_indexing_constraints = None + + def loop_split(sizes, body, vars): + index_size, reduce_size = sizes + index_vars, reduce_vars = vars + split_idx = index_vars.index(split_var) + new_index_size = index_size.copy() + new_index_size[split_idx] = index_size[split_idx] // split_number + new_index_size.insert(split_idx + 1, split_number) + (new_index_vars, _), var_ranges = dependencies.index_vars_no_squeeze( + new_index_size, reduce_size, prefix="y" + ) + iter_vars = new_index_vars.copy() + divisor_var = iter_vars.pop(split_idx + 1) + iter_vars[split_idx] = split_number * iter_vars[split_idx] + divisor_var + body = ir.LoopBody( + body, [iter_vars, reduce_vars], var_ranges, new_index_vars, reduce_vars + ) + nonlocal extra_indexing_constraints + if not extra_indexing_constraints: + extra_indexing_constraints = ( + body.var_ranges, + list(body.indexing_exprs.values()), + ) + return ( + (new_index_size, reduce_size), + body, + (new_index_vars, reduce_vars), + ) + + # Here decide the final loop order + for node in nodes: + if node == matched_node: + node.recompute_size_and_body(recompute_sizes_body_func=loop_split) + for node in nodes: + if node != matched_node: + node.recompute_size_and_body( + extra_indexing_constraints=extra_indexing_constraints, + recompute_sizes_body_func=loop_split, + ) + + return nodes + def codegen_outer_loop_node( self, node: OuterLoopFusedSchedulerNode, @@ -4287,6 +4396,7 @@ def codegen_node( self.codegen_outer_loop_node(node) else: nodes: List[SchedulerNode] = node.get_nodes() # type: ignore[assignment] + nodes = self.try_loop_split(nodes) cpp_kernel_proxy = CppKernelProxy(kernel_group) cpp_kernel_proxy.codegen_nodes(nodes) kernel_group.finalize_kernel(cpp_kernel_proxy, nodes) diff --git a/torch/_inductor/ir.py b/torch/_inductor/ir.py index 55d43674807e0d..121a52d7405fb3 100644 --- a/torch/_inductor/ir.py +++ b/torch/_inductor/ir.py @@ -3739,6 +3739,7 @@ def get_default_sizes_body(self): def simplify_and_reorder( self, extra_indexing_constraints: Optional[Tuple[Dict[Any, Any], List[Any]]] = None, + recompute_sizes_body_func: Optional[Callable[..., Any]] = None, ): """ This is a main place where we do loop transformations in a @@ -3754,6 +3755,8 @@ def simplify_and_reorder( to fuse scheduler nodes with compatible ranges, e.g. (s0*s1*...,) and (s0, s1, s2, ...) on CPU by preventing indexing simplifications and obtaining index/reduce ranges for the scheduler node compatible with other nodes. + Optional argument recompute_sizes_body_func can be used to recompute sizes and body + on the default body. This can be useful to append additional loop transformations. """ ( (index_size, reduce_size), @@ -3761,6 +3764,15 @@ def simplify_and_reorder( (index_vars, reduce_vars), ) = self.get_default_sizes_body() + if recompute_sizes_body_func: + ( + (index_size, reduce_size), + body, + (index_vars, reduce_vars), + ) = recompute_sizes_body_func( + (index_size, reduce_size), body, (index_vars, reduce_vars) + ) + index_formulas = [*body.indexing_exprs.values()] if extra_indexing_constraints is not None: assert ( @@ -3937,6 +3949,7 @@ def should_allocate(self): def simplify_and_reorder( self, extra_indexing_constraints: Optional[Tuple[Dict[Any, Any], List[Any]]] = None, + recompute_sizes_body_func: Optional[Callable[..., Any]] = None, ): return ( ( diff --git a/torch/_inductor/scheduler.py b/torch/_inductor/scheduler.py index efe93253de6800..d30231d63c9cc1 100644 --- a/torch/_inductor/scheduler.py +++ b/torch/_inductor/scheduler.py @@ -827,10 +827,12 @@ def __init__( def _compute_attrs( self, extra_indexing_constraints: Optional[Tuple[Dict[Any, Any], List[Any]]] = None, + recompute_sizes_body_func: Optional[Callable[..., Any]] = None, ) -> None: assert isinstance(self.node, (ir.ComputedBuffer, ir.TemplateBuffer)) self._sizes, self._body = self.node.simplify_and_reorder( - extra_indexing_constraints=extra_indexing_constraints + extra_indexing_constraints=extra_indexing_constraints, + recompute_sizes_body_func=recompute_sizes_body_func, ) group_fn = self.scheduler.get_backend(self.node.get_device()).group_fn @@ -855,9 +857,14 @@ def _compute_attrs( ) def recompute_size_and_body( - self, extra_indexing_constraints: Tuple[Dict[Any, Any], List[Any]] + self, + extra_indexing_constraints: Optional[Tuple[Dict[Any, Any], List[Any]]] = None, + recompute_sizes_body_func: Optional[Callable[..., Any]] = None, ) -> None: - self._compute_attrs(extra_indexing_constraints=extra_indexing_constraints) + self._compute_attrs( + extra_indexing_constraints=extra_indexing_constraints, + recompute_sizes_body_func=recompute_sizes_body_func, + ) def refresh_dependencies(self, normalize: bool) -> None: # Fake dependencies are added manually. They can not be analyzed from From d8c72cbba4699692e356473c2a2aa1764a065323 Mon Sep 17 00:00:00 2001 From: PyTorch UpdateBot Date: Mon, 2 Sep 2024 03:52:39 +0000 Subject: [PATCH 0334/1018] [executorch hash update] update the pinned executorch hash (#134914) This PR is auto-generated nightly by [this action](https://github.com/pytorch/pytorch/blob/main/.github/workflows/nightly.yml). Update the pinned executorch hash. Pull Request resolved: https://github.com/pytorch/pytorch/pull/134914 Approved by: https://github.com/pytorchbot --- .ci/docker/ci_commit_pins/executorch.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.ci/docker/ci_commit_pins/executorch.txt b/.ci/docker/ci_commit_pins/executorch.txt index cf11cdd5db311b..6d1222c950e8de 100644 --- a/.ci/docker/ci_commit_pins/executorch.txt +++ b/.ci/docker/ci_commit_pins/executorch.txt @@ -1 +1 @@ -e7e86478b4807d18ec0ac68c91f6d7e5c6d0f14e +61ddee54058bf3fc7e17eb539d61d39d12cc86f8 From 7e871eb0a5303313cf90b7f93a866598f7932fe4 Mon Sep 17 00:00:00 2001 From: Blaine Burton Rister <145300525+blaine-rister@users.noreply.github.com> Date: Mon, 2 Sep 2024 05:56:33 +0000 Subject: [PATCH 0335/1018] [Inductor] Allow customizing the padding format (#133939) Based on https://github.com/pytorch/pytorch/pull/130956. Inductor already supports padding through the `config.comprehensive_padding` option, but the padding format involves a few heuristics that are specific to Nvidia GPUs: - When we pad, it is always aligned to the next multiple of 128 bytes. - Strides smaller than 1024 are not padded. - Only intermediate values are padded, not outputs. The last of these is not really GPU-specific, but there are certain cases where we may want to override it. For example, padding outputs is useful on hardware accelerators with specific memory alignment requirements, or for applications where performance is more important than conformity with eager mode. This PR surfaces padding parameters up to Inductor's config module, so the user can control them. - `config.pad_outputs`: choose whether to pad outputs (default: `False`) - `config.padding_alignment_bytes`: choose the alignment size for padding (default: `128`) - `config.padding_stride_threshold`: choose the smallest stride that we will pad. For example, setting this to 0 will pad all unaligned strides. (default: `1024`) **Test plan** Added a new test in `test_padding.py` which tries various combinations of these options, checking that the output strides match our expectations. These changes should not affect perf, because the defaults are identical to Inductor's current behavior. Pull Request resolved: https://github.com/pytorch/pytorch/pull/133939 Approved by: https://github.com/shunting314 Co-authored-by: Yueming Hao --- test/inductor/test_padding.py | 58 +++++++++++++++++++++++++++++++++-- torch/_inductor/config.py | 28 +++++++++++++++++ torch/_inductor/graph.py | 11 ++++--- torch/_inductor/ir.py | 26 ++-------------- 4 files changed, 93 insertions(+), 30 deletions(-) diff --git a/test/inductor/test_padding.py b/test/inductor/test_padding.py index c8170c92327166..9ae3dd3a125df8 100644 --- a/test/inductor/test_padding.py +++ b/test/inductor/test_padding.py @@ -3,6 +3,7 @@ import functools import os import unittest +from typing import Tuple import torch from torch import nn, Tensor @@ -12,8 +13,13 @@ from torch._inductor import config, ir, metrics from torch._inductor.fx_passes import pad_mm as pad_mm_pass from torch._inductor.runtime.benchmarking import benchmarker -from torch._inductor.utils import run_and_get_code -from torch.testing._internal.common_utils import requires_cuda, serialTest +from torch._inductor.utils import ceildiv, run_and_get_code +from torch.testing._internal.common_utils import ( + instantiate_parametrized_tests, + parametrize, + requires_cuda, + serialTest, +) from torch.testing._internal.inductor_utils import HAS_CUDA @@ -362,6 +368,7 @@ def test_longformer_small_bs(self): self.test_longformer(bs=2) +@instantiate_parametrized_tests class PaddingTest(TestCaseBase): @unittest.skipIf(not DO_PERF_TEST, "Perf test not enabled") def test_mm_padding_perf(self): @@ -654,6 +661,53 @@ def test_pad_channels_last(self): out_strides = ir.Layout._pad_strides(in_strides, t.shape, torch.float32) self.assertTrue(in_strides == out_strides) + @parametrize("alignment_bytes", (32, 128)) + @parametrize("shape", [(21, 19), (3, 5, 71)]) + @parametrize("dtype", (torch.float16, torch.float32)) + def test_pad_outputs( + self, dtype: torch.dtype, shape: Tuple[int], alignment_bytes: int + ): + """ + Tests padding output tensors to a specific alignment. + This is enabled by a config flag. + """ + func = torch.add + inputs = tuple(torch.randn(*shape, dtype=dtype) for input_idx in range(2)) + + # Compile and run + with config.patch( + { + "comprehensive_padding": True, + "padding_alignment_bytes": alignment_bytes, + "padding_stride_threshold": 0, + "pad_outputs": True, + } + ): + compiled_func = torch.compile(func) + compiled_out = compiled_func(*inputs) + + # Check numerics + eager_out = func(*inputs) + self.check_close(eager_out, compiled_out) + + # Compute the expected padding + element_size = torch.tensor([], dtype=dtype).element_size() + self.assertGreater(alignment_bytes, element_size) + self.assertEqual(alignment_bytes % element_size, 0) + alignment_elements = alignment_bytes // element_size + contiguous_stride = inputs[0].stride() + expected_stride = [1] + for dim in reversed(shape[1:]): + slice_size = dim * expected_stride[0] + new_stride = alignment_elements * ceildiv(slice_size, alignment_elements) + expected_stride.insert(0, new_stride) + expected_stride = tuple(expected_stride) + self.assertNotEqual(expected_stride, contiguous_stride) + + # Check strides + self.assertFalse(compiled_out.is_contiguous()) + self.assertEqual(compiled_out.stride(), expected_stride) + if __name__ == "__main__": if HAS_CUDA: diff --git a/torch/_inductor/config.py b/torch/_inductor/config.py index f6599989f31a4b..919cddb773d7aa 100644 --- a/torch/_inductor/config.py +++ b/torch/_inductor/config.py @@ -595,6 +595,34 @@ def decide_compile_threads() -> int: ) pad_channels_last = False +# The width of comprehensive padding, in bytes. +# CUDA max memory transaction size is 128 bytes for a warp. +padding_alignment_bytes = 128 + +# Threshold on the minimum stride that will be padded. +# +# Don't align a too small stride since that causes too much memory increase. +# Pad too small stride may also cause perf loss. We may result in many tiny data blocks +# with gaps in between. That causes less coalesced GPU memory access! +# +# Initially we pick 320 as the threshold since for alignement=16, +# that results in at most 5% memory cost. +# +# But later on we raise the threshold to 1024 to avoid interfere with persistent reduction. +# Let's say an inner reduction has a row size 513. Inductor will generate +# persistent reduction code. +# If we do padding, the strides are not contiguous any more. Inductor +# uses a much smaller threshold for persistent reduction in this case and +# generates potentially worse non-persistent reduction code. +# +# This change turns HF AllenaiLongformerBase amp training from a loss of 1.09x to a win of 1.05x. +# (baseline: 71.09ms, padding w/o this change: 77.38ms, padding with this change: 67.77ms) +padding_stride_threshold = 1024 + +# Enable padding outputs, even if they would not be padded in eager mode. +# By default, we use the same strides as eager mode. +pad_outputs = False + # Whether to treat output of the backward graph as user visible. # For user visible outputs, inductor will make sure the stride matches with eager. bw_outputs_user_visible = True diff --git a/torch/_inductor/graph.py b/torch/_inductor/graph.py index 987e639657cb0c..c9da83e7584e62 100644 --- a/torch/_inductor/graph.py +++ b/torch/_inductor/graph.py @@ -247,7 +247,11 @@ def _get_overload_packet( if prior_op not in ops_like_padding: prior.meta["dislike_padding"] = True # We only want to mark output nodes. So, move it after the above prior nodes process. - if user_visible_outputs and cur.name in user_visible_outputs: + if ( + not config.pad_outputs + and user_visible_outputs + and cur.name in user_visible_outputs + ): cur.meta["dislike_padding"] = True @@ -1359,9 +1363,8 @@ def debug(msg: str) -> None: strides = n.meta["val"].stride() if len(strides): allow_padding = ( - n.name not in self.user_visible_outputs - and not is_input_for_as_strided - ) + config.pad_outputs or n.name not in self.user_visible_outputs + ) and not is_input_for_as_strided dense = torch._prims_common.is_non_overlapping_and_dense( n.meta["val"] ) diff --git a/torch/_inductor/ir.py b/torch/_inductor/ir.py index 121a52d7405fb3..b9f08fedfe847c 100644 --- a/torch/_inductor/ir.py +++ b/torch/_inductor/ir.py @@ -2872,12 +2872,7 @@ def is_contiguous_strides_for_shape( def get_align_for_dtype(dtype: torch.dtype) -> int: - """ - CUDA max memory transaction size is 128 bytes for a warp. - We pick `128 // dtype.itemsize` as alighment so GPU can do coalesced - memory access. - """ - return 128 // dtype.itemsize + return config.padding_alignment_bytes // dtype.itemsize @dataclasses.dataclass @@ -3016,29 +3011,12 @@ def _pad_strides(in_strides, size, dtype): # smallest stride to be 1. new_strides[fill_order[0]] = 1 - # Don't align a too small stride since that causes too much memory increase. - # Pad too small stride may also cause perf loss. We may result in many tiny data blocks - # with gaps in between. That causes less coalesced GPU memory access! - # - # Initially we pick 320 as the threshold since for alignement=16, - # that results in at most 5% memory cost. - # - # But later on we raise the threshold to 1024 to avoid interfere with persistent reduction. - # Let's say an inner reduction has a row size 513. Inductor will generate - # persistent reduction code. - # If we do padding, the strides are not contiguous any more. Inductor - # uses a much smaller threshold for persistent reduction in this case and - # generates potentially worse non-persistent reduction code. - # - # This change turns HF AllenaiLongformerBase amp training from a loss of 1.09x to a win of 1.05x. - # (baseline: 71.09ms, padding w/o this change: 77.38ms, padding with this change: 67.77ms) - align_stride_threshold = 1024 padded = False for rank, idx in enumerate(fill_order[1:], start=1): prev_idx = fill_order[rank - 1] stride = new_strides[prev_idx] * size[prev_idx] - if stride > align_stride_threshold and stride % align != 0: + if stride > config.padding_stride_threshold and stride % align != 0: stride = ceildiv(stride, align) * align padded = True new_strides[idx] = stride From 081f0cd39cdabdb0e9d0bbad3dd698864a4bac3e Mon Sep 17 00:00:00 2001 From: FEI Date: Mon, 2 Sep 2024 06:05:10 +0000 Subject: [PATCH 0336/1018] fix tensor_repr(at::Tensor) (#134762) (#134764) Fixes #134762 @ezyang @antocuni Pull Request resolved: https://github.com/pytorch/pytorch/pull/134764 Approved by: https://github.com/ezyang Co-authored-by: Edward Z. Yang --- torch/csrc/utils.cpp | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/torch/csrc/utils.cpp b/torch/csrc/utils.cpp index fda2c45a9c88b3..cbac2367500b3f 100644 --- a/torch/csrc/utils.cpp +++ b/torch/csrc/utils.cpp @@ -267,7 +267,14 @@ char* tensor_repr(at::Tensor tensor) { const char* buf = nullptr; char* result = nullptr; - pytensor = THPVariable_Wrap(std::move(tensor)); + // NB: It's important not to move the tensor into THPVariable_Wrap, + // because this function is only called from our gdb macros, and + // we want to avoid accidentally moving out the tensor. In principle, + // the Tensor signature above should induce a copy, but we've + // observed that sometimes gdb passes the outer Tensor address exactly as is + // into this function. + // See https://github.com/pytorch/pytorch/issues/134762 + pytensor = THPVariable_Wrap(tensor); if (!pytensor) // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto) goto error; From 83d12386d60b92a0100953e22ccff1f782540889 Mon Sep 17 00:00:00 2001 From: "Edward Z. Yang" Date: Sun, 1 Sep 2024 17:30:02 +0000 Subject: [PATCH 0337/1018] Add a test to avoid decorator based regression for cprofile traces (#133086) Signed-off-by: Edward Z. Yang Pull Request resolved: https://github.com/pytorch/pytorch/pull/133086 Approved by: https://github.com/aorenste --- test/inductor/test_torchinductor.py | 11 +++++++++++ torch/_inductor/config.py | 6 ++++++ torch/_inductor/scheduler.py | 21 +++++++++++++++++++++ 3 files changed, 38 insertions(+) diff --git a/test/inductor/test_torchinductor.py b/test/inductor/test_torchinductor.py index 3127fae65976c8..698f3bfde5d9ae 100644 --- a/test/inductor/test_torchinductor.py +++ b/test/inductor/test_torchinductor.py @@ -8134,6 +8134,17 @@ def fn(x): self.common(fn, [torch.zeros([20, 20])]) + @config.patch(check_stack_no_cycles_TESTING_ONLY=True) + def test_check_stack_no_cycles(self): + @torch.compile() + def fn(x): + return x * 3 + + r = fn(torch.randn(2, device=self.device, requires_grad=True)) + # Backward compilation isn't hooked into cprofile, it probably + # should... + # r.sum().backward() + def test_like_rands2(self): # rand_like with kwargs `device` of str type d = self.device diff --git a/torch/_inductor/config.py b/torch/_inductor/config.py index 919cddb773d7aa..c5429761abf452 100644 --- a/torch/_inductor/config.py +++ b/torch/_inductor/config.py @@ -704,6 +704,12 @@ def decide_compile_threads() -> int: # set unless you know what you're doing. unsafe_ignore_unsupported_triton_autotune_args: bool = False +# When True, we will check in scheduler.py _codegen that there are no "loops" +# in the call stack; that is to say, the same frame multiple times. This +# ensures that a cProfile trace to this frame will be a straight line without +# any cycles. +check_stack_no_cycles_TESTING_ONLY: bool = False + # config specific to codegen/cpp.py class cpp: diff --git a/torch/_inductor/scheduler.py b/torch/_inductor/scheduler.py index d30231d63c9cc1..07d8cf4f421daf 100644 --- a/torch/_inductor/scheduler.py +++ b/torch/_inductor/scheduler.py @@ -11,6 +11,7 @@ import os import pprint import textwrap +import traceback import typing from typing import ( Any, @@ -3379,6 +3380,26 @@ def codegen(self) -> None: return self._codegen() def _codegen(self) -> None: + if config.check_stack_no_cycles_TESTING_ONLY: + import torch._dynamo.convert_frame + + stack = traceback.extract_stack() + seen = set() + for frame in reversed(stack): + # This is where maybe_cprofile is + if ( + frame.name == "_compile_inner" + and frame.filename == torch._dynamo.convert_frame.__file__ + ): + break + key = (frame.filename, frame.lineno) + assert key not in seen, ( + f"Duplicate stack frame {frame.filename}:{frame.lineno}; " + "did you add a decorator to one of the functions in this stack " + "trace? If so, try using a context manager instead." + ) + seen.add(key) + for node in self.nodes: try: log.debug( From 6ce77ab5d03d2cd493bcdffa9dbd59a015d7695c Mon Sep 17 00:00:00 2001 From: Jonathan Wenger Date: Mon, 2 Sep 2024 13:30:30 +0000 Subject: [PATCH 0338/1018] Corrected docstring of ``solve_triangular`` (#129766) **Description** The arguments docstring of [torch.linalg.solve_triangular](https://pytorch.org/docs/stable/generated/torch.linalg.solve_triangular.html#torch.linalg.solve_triangular) incorrectly describes the shape of the ``A`` argument if the option ``left=True``. The argument ``A`` should have shape $k \times k$ if ``left=False`` in line with the rest of the docstring and the implementation. Pull Request resolved: https://github.com/pytorch/pytorch/pull/129766 Approved by: https://github.com/lezcano --- torch/linalg/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch/linalg/__init__.py b/torch/linalg/__init__.py index f62569f7a9557e..cef76fec1107d5 100644 --- a/torch/linalg/__init__.py +++ b/torch/linalg/__init__.py @@ -2232,7 +2232,7 @@ equations with a unique solution. Args: - A (Tensor): tensor of shape `(*, n, n)` (or `(*, k, k)` if :attr:`left`\ `= True`) + A (Tensor): tensor of shape `(*, n, n)` (or `(*, k, k)` if :attr:`left`\ `= False`) where `*` is zero or more batch dimensions. B (Tensor): right-hand side tensor of shape `(*, n, k)`. From 4983076f461a3fb5345b278ae9b07753ec75ce77 Mon Sep 17 00:00:00 2001 From: Xuehai Pan Date: Mon, 2 Sep 2024 00:43:16 +0800 Subject: [PATCH 0339/1018] [BE] better type annotation for `torch.types` (#129559) Closes #129525 - #129525 Pull Request resolved: https://github.com/pytorch/pytorch/pull/129559 Approved by: https://github.com/ezyang --- torch/types.py | 96 ++++++++++++++++++++++++++++---------------------- 1 file changed, 54 insertions(+), 42 deletions(-) diff --git a/torch/types.py b/torch/types.py index 7725716909196e..536af3eef8c9d1 100644 --- a/torch/types.py +++ b/torch/types.py @@ -1,7 +1,5 @@ # mypy: allow-untyped-defs -import builtins - # In some cases, these basic types are shadowed by corresponding # top-level values. The underscore variants let us refer to these # types. See https://github.com/python/mypy/issues/4146 for why these @@ -14,100 +12,114 @@ int as _int, str as _str, ) -from typing import Any, List, Optional, Sequence, Tuple, TYPE_CHECKING, Union - -import torch -from torch import SymBool, SymFloat, SymInt +from typing import Any, Dict, List, Sequence, Tuple, TYPE_CHECKING, Union +from typing_extensions import TypeAlias + +# `as` imports have better static analysis support than assignment `ExposedType: TypeAlias = HiddenType` +from torch import ( # noqa: F401 + device as _device, + DispatchKey as DispatchKey, + dtype as _dtype, + layout as _layout, + qscheme as _qscheme, + Size as Size, + SymBool as SymBool, + SymFloat as SymFloat, + SymInt as SymInt, + Tensor as Tensor, +) if TYPE_CHECKING: from torch.autograd.graph import GradientEdge + __all__ = ["Number", "Device", "Storage"] # Convenience aliases for common composite types that we need # to talk about in PyTorch - -_TensorOrTensors = Union[torch.Tensor, Sequence[torch.Tensor]] -_TensorOrTensorsOrGradEdge = Union[ - torch.Tensor, - Sequence[torch.Tensor], +_TensorOrTensors: TypeAlias = Union[Tensor, Sequence[Tensor]] # noqa: PYI047 +_TensorOrTensorsOrGradEdge: TypeAlias = Union[ # noqa: PYI047 + Tensor, + Sequence[Tensor], "GradientEdge", Sequence["GradientEdge"], ] -_dtype = torch.dtype -_device = torch.device -_qscheme = torch.qscheme -_layout = torch.layout -_size = Union[torch.Size, List[builtins.int], Tuple[builtins.int, ...]] -_symsize = Union[torch.Size, Sequence[Union[_int, SymInt]]] -_dispatchkey = Union[builtins.str, torch._C.DispatchKey] +_size: TypeAlias = Union[Size, List[int], Tuple[int, ...]] # noqa: PYI042,PYI047 +_symsize: TypeAlias = Union[Size, Sequence[Union[int, SymInt]]] # noqa: PYI042,PYI047 +_dispatchkey: TypeAlias = Union[str, DispatchKey] # noqa: PYI042,PYI047 # int or SymInt -IntLikeType = Union[_int, torch.SymInt] +IntLikeType: TypeAlias = Union[int, SymInt] +# float or SymFloat +FloatLikeType: TypeAlias = Union[float, SymFloat] +# bool or SymBool +BoolLikeType: TypeAlias = Union[bool, SymBool] py_sym_types = (SymInt, SymFloat, SymBool) -PySymType = Union[SymInt, SymFloat, SymBool] +PySymType: TypeAlias = Union[SymInt, SymFloat, SymBool] # Meta-type for "numeric" things; matches our docs -Number = Union[builtins.int, builtins.float, builtins.bool] +Number: TypeAlias = Union[int, float, bool] # Meta-type for "device-like" things. Not to be confused with 'device' (a # literal device object). This nomenclature is consistent with PythonArgParser. # None means use the default device (typically CPU) -Device = Optional[Union[_device, builtins.str, builtins.int]] -del Optional - -# Storage protocol implemented by ${Type}StorageBase classes +Device: TypeAlias = Union[_device, str, int, None] +# Storage protocol implemented by ${Type}StorageBase classes class Storage: - _cdata: _int - device: torch.device - dtype: torch.dtype - _torch_load_uninitialized: _bool + _cdata: int + device: _device + dtype: _dtype + _torch_load_uninitialized: bool - def __deepcopy__(self, memo: dict) -> "Storage": + def __deepcopy__(self, memo: Dict[int, Any]) -> "Storage": raise NotImplementedError - def _new_shared(self, size: _int) -> "Storage": + def _new_shared(self, size: int) -> "Storage": raise NotImplementedError def _write_file( self, f: Any, - is_real_file: _bool, - save_size: _bool, - element_size: _int, + is_real_file: bool, + save_size: bool, + element_size: int, ) -> None: raise NotImplementedError - def element_size(self) -> _int: + def element_size(self) -> int: raise NotImplementedError - def is_shared(self) -> _bool: + def is_shared(self) -> bool: raise NotImplementedError def share_memory_(self) -> "Storage": raise NotImplementedError - def nbytes(self) -> _int: + def nbytes(self) -> int: raise NotImplementedError def cpu(self) -> "Storage": raise NotImplementedError - def data_ptr(self) -> _int: + def data_ptr(self) -> int: raise NotImplementedError def from_file( self, - filename: _str, - shared: _bool = False, - nbytes: _int = 0, + filename: str, + shared: bool = False, + nbytes: int = 0, ) -> "Storage": raise NotImplementedError - def _new_with_file(self, f: Any, element_size: _int) -> "Storage": + def _new_with_file( + self, + f: Any, + element_size: int, + ) -> "Storage": raise NotImplementedError From 5149b575cbe43a731d2ad4cf94e942c25d0d092e Mon Sep 17 00:00:00 2001 From: titaiwangms Date: Mon, 2 Sep 2024 23:38:39 +0000 Subject: [PATCH 0340/1018] [ONNX] Delete `op_level_debug` from `torch.onnx.ExportOptions` (#134961) op_level_debug helped to identify missing operators, and wrongly implemented operators at the time that dynamo exporter relied on nearest matching and torchlib was just created. However, right now, with dispatcher logic improved and torchlib becomes mature, we no longer need it. PS: op-level-debug diagnostics rule is not deleted in this PR, as it auto generates lint error code, and need more time to fix. We can delete it when we retire sarif. Pull Request resolved: https://github.com/pytorch/pytorch/pull/134961 Approved by: https://github.com/justinchuby --- test/onnx/test_fx_to_onnx.py | 4 +- torch/onnx/_internal/_exporter_legacy.py | 10 - .../onnx/_internal/fx/fx_onnx_interpreter.py | 32 +- torch/onnx/_internal/fx/op_validation.py | 381 ------------------ torch/onnx/_internal/onnxruntime.py | 3 - 5 files changed, 2 insertions(+), 428 deletions(-) delete mode 100644 torch/onnx/_internal/fx/op_validation.py diff --git a/test/onnx/test_fx_to_onnx.py b/test/onnx/test_fx_to_onnx.py index 747c922a549542..04f132abc9c9aa 100644 --- a/test/onnx/test_fx_to_onnx.py +++ b/test/onnx/test_fx_to_onnx.py @@ -192,9 +192,7 @@ def forward(self, input): return self.conv2(input) x = torch.randn(20, 16, 50, 50) - onnx_program = dynamo_export( - TraceModel(), x, export_options=ExportOptions(op_level_debug=False) - ) + onnx_program = dynamo_export(TraceModel(), x) assert_has_diagnostics( onnx_program.diagnostic_context, diagnostics.rules.find_opschema_matched_symbolic_function, diff --git a/torch/onnx/_internal/_exporter_legacy.py b/torch/onnx/_internal/_exporter_legacy.py index df1ae10c9c26e2..52f37460592cd4 100644 --- a/torch/onnx/_internal/_exporter_legacy.py +++ b/torch/onnx/_internal/_exporter_legacy.py @@ -263,7 +263,6 @@ class ExportOptions: When ``None``, the exporter determines the most compatible setting. When ``True``, all input shapes are considered dynamic. When ``False``, all input shapes are considered static. - op_level_debug: Whether to export the model with op-level debug information diagnostic_options: The diagnostic options for the exporter. fake_context: The fake context used for symbolic tracing. onnx_registry: The ONNX registry used to register ATen operators to ONNX functions. @@ -277,9 +276,6 @@ class ExportOptions: - ``False``: all input shapes are considered static. """ - op_level_debug: bool | None = None - """When True export the model with op-level debug running ops through ONNX Runtime.""" - diagnostic_options: DiagnosticOptions """The diagnostic options for the exporter.""" @@ -293,13 +289,11 @@ def __init__( self, *, dynamic_shapes: bool | None = None, - op_level_debug: bool | None = None, fake_context: ONNXFakeContext | None = None, onnx_registry: OnnxRegistry | None = None, diagnostic_options: DiagnosticOptions | None = None, ): self.dynamic_shapes = dynamic_shapes - self.op_level_debug = op_level_debug self.fake_context = fake_context self.onnx_registry = onnx_registry self.diagnostic_options = diagnostic_options or DiagnosticOptions() @@ -313,7 +307,6 @@ class ResolvedExportOptions(ExportOptions): # Public attributes MUST be redefined below without ``Optional[]`` from ``ExportOptions`` dynamic_shapes: bool - op_level_debug: bool diagnostic_options: DiagnosticOptions fake_context: ONNXFakeContext onnx_registry: OnnxRegistry @@ -347,7 +340,6 @@ def __init__( if isinstance(options, ResolvedExportOptions): self.dynamic_shapes = options.dynamic_shapes - self.op_level_debug = options.op_level_debug self.diagnostic_options = options.diagnostic_options self.fake_context = options.fake_context # private @@ -400,7 +392,6 @@ def resolve(value: T | None, fallback: T | Callable[[], T]) -> T: from torch.onnx._internal.fx import onnxfunction_dispatcher - self.op_level_debug = resolve(options.op_level_debug, False) self.onnxfunction_dispatcher = ( onnxfunction_dispatcher.OnnxFunctionDispatcher( self.onnx_registry, @@ -1213,7 +1204,6 @@ def export(self) -> ONNXProgram: onnxscript_graph = fx_interpreter.run( fx_graph_module=graph_module, onnxfunction_dispatcher=self.options.onnxfunction_dispatcher, - op_level_debug=self.options.op_level_debug, ) # NOTE: Filter out the initializers with fake tensors when it's fake_mode exporting. diff --git a/torch/onnx/_internal/fx/fx_onnx_interpreter.py b/torch/onnx/_internal/fx/fx_onnx_interpreter.py index 01b851a49421d3..b81c7254751ba8 100644 --- a/torch/onnx/_internal/fx/fx_onnx_interpreter.py +++ b/torch/onnx/_internal/fx/fx_onnx_interpreter.py @@ -5,7 +5,6 @@ import logging import operator import re -import types from typing import Callable, Sequence import onnxscript # type: ignore[import] @@ -20,7 +19,6 @@ _pass, diagnostics, onnxfunction_dispatcher, - op_validation, type_utils as fx_type_utils, ) from torch.utils import _pytree @@ -396,7 +394,6 @@ def run_node( node, fx_graph_module: torch.fx.GraphModule, onnxfunction_dispatcher: onnxfunction_dispatcher.OnnxFunctionDispatcher, - op_level_debug: bool, onnxscript_graph: onnxscript_graph_building.TorchScriptGraph, onnxscript_tracer: onnxscript_graph_building.TorchScriptTracingEvaluator, fx_name_to_onnxscript_value: dict[ @@ -411,7 +408,6 @@ def run_node( node: The FX node to be translated. fx_graph_module: The FX graph module containing the node. onnxfunction_dispatcher: The dispatcher to find the best matched ONNX op. - op_level_debug (bool): Whether to enable op level debug. onnxscript_graph: The ONNX graph to be populated. onnxscript_tracer: The tracer to trace the ONNX graph. fx_name_to_onnxscript_value: The mapping from FX node name to ONNX Script value. @@ -446,7 +442,6 @@ def run_node( onnxscript_tracer, fx_name_to_onnxscript_value, onnxfunction_dispatcher, - op_level_debug, fx_graph_module, ) elif node.op == "call_method": @@ -459,7 +454,6 @@ def run_node( onnxscript_tracer, fx_graph_module, onnxfunction_dispatcher, - op_level_debug, ) elif node.op == "output": self.output(node, onnxscript_graph, fx_name_to_onnxscript_value) @@ -474,7 +468,6 @@ def run( self, fx_graph_module: torch.fx.GraphModule, onnxfunction_dispatcher: onnxfunction_dispatcher.OnnxFunctionDispatcher, - op_level_debug: bool, parent_onnxscript_graph: onnxscript_graph_building.TorchScriptGraph | None = None, ) -> onnxscript_graph_building.TorchScriptGraph: @@ -483,7 +476,6 @@ def run( Args: fx_graph_module: FX graph module to be translated. onnxfunction_dispatcher: ONNX function dispatcher. - op_level_debug: Whether to enable op-level debug. parent_onnxscript_graph: The parent TorchScript graph. Must be provided if `fx_graph_module` is a submodule. If not provided, `fx_graph_module` is assumed to be the root module. @@ -541,13 +533,11 @@ def run( # ONNX exporter used in _ts_graph_to_onnx_model_in_protobuf is not compatible # with FakeTensorMode. with torch.utils._mode_utils.no_dispatch(): - # node_fixed_shape is only used on op_level_debug purpose. for node in fx_graph_module.graph.nodes: self.run_node( node, fx_graph_module, onnxfunction_dispatcher, - op_level_debug, onnxscript_graph, onnxscript_tracer, fx_name_to_onnxscript_value, @@ -621,7 +611,6 @@ def call_function( | tuple[onnxscript_graph_building.TorchScriptTensor, ...], ], onnxfunction_dispatcher: onnxfunction_dispatcher.OnnxFunctionDispatcher, - op_level_debug: bool, fx_graph_module: torch.fx.GraphModule, ): # aten ops and other stateless functions. @@ -675,23 +664,6 @@ def call_function( assert isinstance( output, (onnxscript_graph_building.TorchScriptTensor, tuple) ), type(output) - # NOTE(titaiwang): We bypass two kinds of ops as it's not meaningful to - # validate them with op level debug. - # 1. aten::sym_size: The op is simply get item from a list of tensors. - # 2. BuiltinFunction: It doesn't supported tensor - if ( - op_level_debug - and node.target != torch.ops.aten.sym_size - and not isinstance(node.target, types.BuiltinFunctionType) - ): - op_validation.validate_op_between_ort_torch( - self.diagnostic_context, - node, - symbolic_fn, - fx_args, - fx_kwargs, - fx_graph_module, - ) fx_name_to_onnxscript_value[node.name] = output def output( @@ -734,7 +706,6 @@ def call_module( tracer: onnxscript_graph_building.TorchScriptTracingEvaluator, root_fx_graph_module: torch.fx.GraphModule, onnxfunction_dispatcher: onnxfunction_dispatcher.OnnxFunctionDispatcher, - op_level_debug: bool, ) -> None: """Export a fx.GraphModule submodule to ONNXScript graph. @@ -753,7 +724,6 @@ def call_module( tracer: The tracer used to trace the ONNXScript graph. root_fx_graph_module: The root FX module. onnxfunction_dispatcher: The dispatcher. - op_level_debug: Whether to enable op-level debug. """ assert isinstance( node.target, str @@ -766,7 +736,7 @@ def call_module( ), f"sub_module must be a torch.fx.GraphModule, not {type(sub_module)} for node {node}." sub_onnxscript_graph = self.run( - sub_module, onnxfunction_dispatcher, op_level_debug, parent_onnxscript_graph + sub_module, onnxfunction_dispatcher, parent_onnxscript_graph ) onnx_args, _ = _wrap_fx_args_as_onnxscript_args( diff --git a/torch/onnx/_internal/fx/op_validation.py b/torch/onnx/_internal/fx/op_validation.py deleted file mode 100644 index 834e1157b20d4d..00000000000000 --- a/torch/onnx/_internal/fx/op_validation.py +++ /dev/null @@ -1,381 +0,0 @@ -# mypy: allow-untyped-defs -"""Module for handling op-level validation during exporting.""" - -from __future__ import annotations - -import logging -from typing import Any, Callable, Sequence - -import onnxscript # type: ignore[import] -from onnxscript import evaluator # type: ignore[import] - -import torch -import torch.fx -from torch.fx.experimental import symbolic_shapes -from torch.onnx import _constants, _type_utils as jit_type_utils -from torch.onnx._internal.fx import ( - diagnostics, - fx_onnx_interpreter, - type_utils as fx_type_utils, -) -from torch.utils import _pytree - - -def _op_level_debug_message_formatter( - fn: Callable, - self, - node: torch.fx.Node, - symbolic_fn: onnxscript.OnnxFunction | onnxscript.TracedOnnxFunction, - *args, - **kwargs, -) -> str: - return ( - f"FX Node: {node.op}::{node.target}[name={node.name}]. \n" - f"ONNX Node: {symbolic_fn.name}[opset={symbolic_fn.opset}]." - ) - - -@diagnostics.diagnose_call( - diagnostics.rules.op_level_debugging, - diagnostic_message_formatter=_op_level_debug_message_formatter, -) -def validate_op_between_ort_torch( - diagnostic_context: diagnostics.DiagnosticContext, - node: torch.fx.Node, - symbolic_fn: onnxscript.OnnxFunction | onnxscript.TracedOnnxFunction, - fx_args: list[fx_type_utils.Argument], - fx_kwargs: dict[str, fx_type_utils.Argument], - fx_graph_module: torch.fx.GraphModule, -): - """Validate the op between ONNX Runtime and PyTorch. - - The function will run the op in ONNX Runtime and PyTorch and compare the - results. It doesn't break the exporting process, but saves each op validated - result into SARIF, under the section of `fx_onnx_interpreter`. - - There are three signs can be found: - 1. Blue: Pass - 2. Yellow: Bypass - - Args: - node (torch.fx.Node): The validated fx.node - symbolic_fn (Union[onnxscript.OnnxFunction, onnxscript.TracedOnnxFunction]): The corresponded ONNX node - torch_args (list): torch argument inputs - torch_kwargs (dict): torch keyword argument inputs - fx_graph_module (torch.fx.GraphModule): The fx.GraphModule that contains the nodes - """ - # op-level validation - # Symbolic_fn should have the same output as node.target (torch ops) - - try: - torch_args, torch_kwargs = _wrap_fx_args_as_torch_args( - fx_args, fx_kwargs, fx_graph_module - ) - except ValueError as value_error: - diagnostic = diagnostic_context.inflight_diagnostic() - with diagnostic.log_section( - logging.WARNING, "Op level debug fails due to unsupported input types" - ): - diagnostic.log_source_exception(logging.WARNING, value_error) - diagnostic.level = diagnostics.levels.WARNING - return - - with evaluator.default_as(evaluator.ort_evaluator): - try: - expected_outputs = node.target(*torch_args, **torch_kwargs) # type: ignore[operator] - # NOTE: randomly generating indices/dim: INT64 could go out of bounds - except IndexError as index_error: - # TODO(titaiwang): How to bound indices/dim: INT64 - diagnostic = diagnostic_context.inflight_diagnostic() - with diagnostic.log_section(logging.WARNING, "Op level debug is bypassed"): - diagnostic.log_source_exception(logging.WARNING, index_error) - diagnostic.level = diagnostics.levels.WARNING - return - # NOTE: Error in torch ops with random inputs generated from FakTensors - except RuntimeError as runtime_error: - diagnostic = diagnostic_context.inflight_diagnostic() - with diagnostic.log_section( - logging.WARNING, "Op level debug fails on PyTorch" - ): - diagnostic.log_source_exception(logging.WARNING, runtime_error) - diagnostic.level = diagnostics.levels.WARNING - return - - try: - ( - function_eager_inputs, - function_eager_attributes, - ) = _convert_torch_args_to_onnxfunction_args( - symbolic_fn.param_schemas(), - torch_args, - torch_kwargs, - allow_extra_kwargs=True, - ) - # NOTE: Apply kwargs preprocessing AFTER they are split - function_eager_attributes = ( - fx_onnx_interpreter.filter_incompatible_and_dtype_convert_kwargs( - function_eager_attributes - ) - ) - # NOTE: Incompatible kwargs or missing required args - except TypeError as type_error: - diagnostic = diagnostic_context.inflight_diagnostic() - with diagnostic.log_section(logging.WARNING, "Op level debug is bypassed"): - diagnostic.log_source_exception(logging.WARNING, type_error) - diagnostic.level = diagnostics.levels.WARNING - return - try: - ort_outputs = symbolic_fn( - *function_eager_inputs, **function_eager_attributes - ) - # NOTE: Error in ONNX Runtime with random inputs generated from FakTensors - except RuntimeError as runtime_error: - diagnostic = diagnostic_context.inflight_diagnostic() - with diagnostic.log_section( - logging.WARNING, "Op level debug fails on ONNXRUNTIME" - ): - diagnostic.log_source_exception(logging.WARNING, runtime_error) - diagnostic.level = diagnostics.levels.WARNING - return - - flattened_torch_outputs, _ = _pytree.tree_flatten(expected_outputs) - flattened_function_outputs, _ = _pytree.tree_flatten(ort_outputs) - - assert flattened_torch_outputs - assert len(flattened_torch_outputs) == len(flattened_function_outputs) - - for torch_output, function_output in zip( - flattened_torch_outputs, flattened_function_outputs - ): - if isinstance( - torch_output, torch.Tensor - ) and fx_type_utils.is_torch_complex_dtype(torch_output.dtype): - torch_output = torch.view_as_real(torch_output.resolve_conj()) - try: - if isinstance(function_output, onnxscript.tensor.Tensor): - function_output = function_output.value - - # Use torch.testing as opposed to np.testing to ensure dtypes and shapes match - torch.testing.assert_close( - torch.tensor(function_output).cpu(), - torch_output.cpu() - if isinstance(torch_output, torch.Tensor) - else torch.tensor(torch_output).cpu(), - rtol=1e-4, - atol=1e-3, - ) - except AssertionError as e: - diagnostic = diagnostic_context.inflight_diagnostic() - with diagnostic.log_section(logging.WARNING, "Validation failed"): - diagnostic.log_source_exception(logging.WARNING, e) - diagnostic.level = diagnostics.levels.WARNING - - -def _convert_symint_to_int_in_shape(shape: torch.Size) -> torch.Size: - """Convert SymInt to int in shape - - Args: - shape (torch.Size): The shape of a tensor - Raises: - ValueError: When SymInt is found in shape - Returns: - torch.Size: The shape of a tensor with SymInt converted to int - - """ - list_int_shape = [] - for dim in shape: - if isinstance(dim, torch.SymInt): - if symbolic_shapes.has_hint(dim): - list_int_shape.append(symbolic_shapes.hint_int(dim)) - else: - raise ValueError( - f"An unbacked SymInt found in shape. SymInt: {dim}; " - f"torch.Size: {shape}. There is no hint for SymInt." - ) - else: - list_int_shape.append(dim) - return torch.Size(list_int_shape) - - -def generate_random_tensors(shape: torch.Size, dtype: torch.dtype): - shape = _convert_symint_to_int_in_shape(shape) - - if dtype == torch.uint8: - return torch.randint( - low=_constants.UINT8_MIN, high=_constants.UINT8_MAX, size=shape, dtype=dtype - ) - if dtype == torch.int8: - return torch.randint( - low=_constants.INT8_MIN, high=_constants.INT8_MAX, size=shape, dtype=dtype - ) - if dtype == torch.int16: - return torch.randint( - low=_constants.INT16_MIN, high=_constants.INT16_MAX, size=shape, dtype=dtype - ) - if dtype == torch.int32: - return torch.randint( - low=_constants.INT32_MIN, high=_constants.INT32_MAX, size=shape, dtype=dtype - ) - if dtype == torch.int64: - return torch.randint( - low=_constants.INT64_MIN, high=_constants.INT64_MAX, size=shape, dtype=dtype - ) - if dtype == torch.bool: - random_numbers = torch.rand(shape) - return torch.where( - random_numbers > 0.5, torch.tensor(True), torch.tensor(False) - ) - if fx_type_utils.is_torch_complex_dtype(dtype): - # ONNX does not support complex values, but supports their real representation - return torch.view_as_complex( - torch.randn((*shape, 2), dtype=fx_type_utils.from_complex_to_float(dtype)) - ) - return torch.randn(shape, dtype=dtype) - - -def _fx_args_to_torch_args( - fx_args: list[fx_type_utils.Argument], fx_graph_module: torch.fx.GraphModule -) -> list[fx_type_utils.Argument]: - """Recursively convert fx args to torch args""" - wrapped_args: list[fx_type_utils.Argument] = [] - for arg in fx_args: - if isinstance(arg, torch.fx.Node): - fake_tensor = arg.meta.get("val") - if fake_tensor is None and arg.op == "get_attr": - fake_tensor = getattr(fx_graph_module, arg.target) # type: ignore[operator, arg-type] - # NOTE: Currently, we are aware of - # FakeTensor/Tensor/SymInt/SymFloat/Symbool/int/float/bool could be in - # arg.meta["val"]/get_attr. - if isinstance(fake_tensor, torch.Tensor): - real_tensor = generate_random_tensors( - fake_tensor.shape, fake_tensor.dtype - ) - wrapped_args.append(real_tensor) - elif isinstance(fake_tensor, (int, float, bool)): - wrapped_args.append(fake_tensor) - elif symbolic_shapes.has_hint(fake_tensor): # type: ignore[arg-type] - wrapped_args.append(symbolic_shapes.hint_int(fake_tensor)) # type: ignore[arg-type] - else: - raise ValueError( - f"Unexpected input argument type found inside fx.Node. arg: {arg}; " - f"arg.meta['val']/get_attr: {fake_tensor}; type(arg.meta['val']/get_attr): " - f"{type(fake_tensor)}." - ) - elif isinstance(arg, Sequence): - wrapped_args.append(_fx_args_to_torch_args(arg, fx_graph_module)) # type: ignore[arg-type] - elif isinstance(arg, (int, float, torch.dtype)) or arg is None: - wrapped_args.append(arg) - elif isinstance(arg, torch.device): - wrapped_args.append(str(arg)) - else: - raise ValueError( - f"Unexpected input argument type is found in node arguments. arg: {arg}; " - ) - - return wrapped_args - - -def _wrap_fx_args_as_torch_args( - fx_args: list[fx_type_utils.Argument], - fx_kwargs: dict[str, fx_type_utils.Argument], - fx_graph_module: torch.fx.GraphModule, -) -> tuple[list[fx_type_utils.Argument], dict[str, fx_type_utils.Argument]]: - """Prepare torch format args and kwargs for op-level validation by using fake tensor to create real tensor to feed in ops""" - - # NOTE: This function only supports FakeTensor with concrete shapes - torch_args: list[fx_type_utils.Argument] = _fx_args_to_torch_args( - fx_args, fx_graph_module - ) - return torch_args, fx_kwargs - - -# NOTE: Referenced from onnxscript internal function: _tag_arguments_with_param_schemas. -def _convert_torch_args_to_onnxfunction_args( - param_schemas: Sequence[onnxscript.values.ParamSchema], - args: list[fx_type_utils.Argument], - kwargs: dict[str, fx_type_utils.Argument], - allow_extra_kwargs: bool = False, -) -> tuple[list[Any], dict[str, Any]]: - """Convert Python args and kwargs to OnnxFunction acceptable with matching ONNX ParamSchema. - - NOTE: This is different from the param_schema separating in dispatcher, since at this point - we are already sure that the args and kwargs are in order and matched. - - Args: - param_schemas: The parameter schemas of an Op or a OnnxFunction. - args: The Python positional arguments supplied by the caller. - kwargs: The Python keyword arguments supplied by the caller. - allow_extra_kwargs: Whether to allow extra keyword arguments. - When set to True, extra/unknown arguments will be ignored. - - Returns: - A tuple of two elements: - - A list of Python positional argument. - - An ordered dictionary of Python keyword argument names and its values. - - Raises: - TypeError: When allow_extra_kwargs is False and there are unknown kwargs. - TypeError: When a required input is not provided. - """ - # args, kwargs and param_schemas should be all in order - # user may not specify all inputs or attributes - - all_param_names = {param.name for param in param_schemas} - extra_kwargs = set(kwargs).difference(all_param_names) - if extra_kwargs and not allow_extra_kwargs: - raise TypeError(f"Unexpected keyword arguments '{extra_kwargs}'") - - tagged_args: list[Any] = [] - tagged_kwargs: dict[str, Any] = {} - - for i, param in enumerate(param_schemas): - if param.is_variadic_input: - # Exhaust all remaining args - tagged_args.extend(arg for arg in args[i:]) - args = [] - continue - if i < len(args): - if param.is_input or isinstance(args[i], torch.dtype): - tagged_args.append(_convert_tensor_to_numpy(args[i])) - else: - tagged_args.append(args[i]) - elif param.name in kwargs: - if param.is_input: - tagged_kwargs[param.name] = _convert_tensor_to_numpy(kwargs[param.name]) - else: - tagged_kwargs[param.name] = kwargs[param.name] - elif param.required: - raise TypeError(f"Required input/attribute '{param}' was not provided") - - return tagged_args, tagged_kwargs - - -def _convert_tensor_to_numpy(input: fx_type_utils.Argument) -> Any: - try: - import numpy as np - except ModuleNotFoundError as exc: - raise ModuleNotFoundError( - f"{__name__} needs numpy, but it's not installed." - ) from exc - - if isinstance(input, torch.Tensor): - if torch.is_complex(input): - # from complex to real representation - input = torch.view_as_real(input.resolve_conj()) - return input.detach().cpu().numpy() - if isinstance(input, torch.dtype): - return int(jit_type_utils.JitScalarType.from_dtype(input).onnx_type()) # type: ignore[union-attr,call-overload] - if isinstance(input, (tuple, list)): - if len(input) == 0: - return np.array((), dtype=np.int64) - if isinstance(input[0], torch.Tensor): - return [_convert_tensor_to_numpy(x) for x in input] - if isinstance(input[0], bool): - return np.array(input, dtype=np.bool_) - - # Just a sequence of numbers - if isinstance(input[0], int): - return np.array(input, dtype=np.int64) - if isinstance(input[0], float): - return np.array(input) - return input diff --git a/torch/onnx/_internal/onnxruntime.py b/torch/onnx/_internal/onnxruntime.py index accc755dd8840c..6a8fd4d4fefaf3 100644 --- a/torch/onnx/_internal/onnxruntime.py +++ b/torch/onnx/_internal/onnxruntime.py @@ -966,7 +966,6 @@ def maybe_map_to_meta_val(value): exported = fx_interpreter.run( fx_graph_module=graph_module, onnxfunction_dispatcher=self._resolved_onnx_exporter_options.onnxfunction_dispatcher, - op_level_debug=self._resolved_onnx_exporter_options.op_level_debug, ) # Convert the exported result to ONNX ModelProto. onnx_model = exported.to_model_proto( @@ -1211,8 +1210,6 @@ def reusable(a: OrtBackendOptions, b: OrtBackendOptions): if a.export_options is not None and b.export_options is not None: return ( a.export_options.dynamic_shapes == b.export_options.dynamic_shapes - and a.export_options.op_level_debug - == b.export_options.op_level_debug and a.export_options.diagnostic_options == b.export_options.diagnostic_options and a.export_options.onnx_registry is b.export_options.onnx_registry From e780b48cb61f2aca6f2384ab925e9c7b62d80e2a Mon Sep 17 00:00:00 2001 From: CaoE Date: Sun, 1 Sep 2024 03:57:22 -0700 Subject: [PATCH 0341/1018] Add specializations for vectorized conversion between float and BF16/FP16 (#126500) Pull Request resolved: https://github.com/pytorch/pytorch/pull/126500 Approved by: https://github.com/jgong5, https://github.com/jansel --- aten/src/ATen/cpu/vec/vec256/vec256_convert.h | 38 +++++++++++++++++++ aten/src/ATen/cpu/vec/vec512/vec512_convert.h | 38 +++++++++++++++++++ 2 files changed, 76 insertions(+) diff --git a/aten/src/ATen/cpu/vec/vec256/vec256_convert.h b/aten/src/ATen/cpu/vec/vec256/vec256_convert.h index 7242e6facacaef..b0f109fc875026 100644 --- a/aten/src/ATen/cpu/vec/vec256/vec256_convert.h +++ b/aten/src/ATen/cpu/vec/vec256/vec256_convert.h @@ -43,6 +43,26 @@ struct VecConvert { } }; +template <> +struct VecConvert { + static inline VectorizedN apply( + const VectorizedN& src) { + VectorizedN result; + result[0] = convert_float_bfloat16(src[0], src[1]); + return result; + } +}; + +template <> +struct VecConvert { + static inline VectorizedN apply( + const VectorizedN& src) { + VectorizedN result; + std::tie(result[0], result[1]) = convert_bfloat16_float(src[0]); + return result; + } +}; + template <> struct VecConvert { static inline VectorizedN apply(const VectorizedN& src) { @@ -52,6 +72,24 @@ struct VecConvert { } }; +template <> +struct VecConvert { + static inline VectorizedN apply(const VectorizedN& src) { + VectorizedN result; + result[0] = convert_float_half(src[0], src[1]); + return result; + } +}; + +template <> +struct VecConvert { + static inline VectorizedN apply(const VectorizedN& src) { + VectorizedN result; + std::tie(result[0], result[1]) = convert_half_float(src[0]); + return result; + } +}; + template <> inline Vectorized convert_to_fp_of_same_size( const Vectorized& src); diff --git a/aten/src/ATen/cpu/vec/vec512/vec512_convert.h b/aten/src/ATen/cpu/vec/vec512/vec512_convert.h index fcdfa3d5934c97..78c7045fb30e3f 100644 --- a/aten/src/ATen/cpu/vec/vec512/vec512_convert.h +++ b/aten/src/ATen/cpu/vec/vec512/vec512_convert.h @@ -43,6 +43,26 @@ struct VecConvert { } }; +template <> +struct VecConvert { + static inline VectorizedN apply( + const VectorizedN& src) { + VectorizedN result; + result[0] = convert_float_bfloat16(src[0], src[1]); + return result; + } +}; + +template <> +struct VecConvert { + static inline VectorizedN apply( + const VectorizedN& src) { + VectorizedN result; + std::tie(result[0], result[1]) = convert_bfloat16_float(src[0]); + return result; + } +}; + template <> struct VecConvert { static inline VectorizedN apply(const VectorizedN& src) { @@ -52,6 +72,24 @@ struct VecConvert { } }; +template <> +struct VecConvert { + static inline VectorizedN apply(const VectorizedN& src) { + VectorizedN result; + result[0] = convert_float_half(src[0], src[1]); + return result; + } +}; + +template <> +struct VecConvert { + static inline VectorizedN apply(const VectorizedN& src) { + VectorizedN result; + std::tie(result[0], result[1]) = convert_half_float(src[0]); + return result; + } +}; + template <> struct VecConvert { static inline VectorizedN apply( From 473e73149eb8f7f0ee11626a6a0e8c960789f029 Mon Sep 17 00:00:00 2001 From: leslie-fang-intel Date: Mon, 2 Sep 2024 17:59:05 -0700 Subject: [PATCH 0342/1018] [Inductor][CPP] Turns on inline_inbuilt_nn_modules for CPP GEMM template testing (#132487) **Summary** The CPP GEMM template testing has been skipped with turning on `inline_inbuilt_nn_modules ` as in https://github.com/pytorch/pytorch/issues/131929. Since https://github.com/pytorch/pytorch/pull/132334 has landed to fix the issues. Turn on this flag back since it's default. Pull Request resolved: https://github.com/pytorch/pytorch/pull/132487 Approved by: https://github.com/anijain2305, https://github.com/jgong5 --- test/inductor/test_cpu_select_algorithm.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/test/inductor/test_cpu_select_algorithm.py b/test/inductor/test_cpu_select_algorithm.py index 9a696d8735918d..0797e63d9aa974 100644 --- a/test/inductor/test_cpu_select_algorithm.py +++ b/test/inductor/test_cpu_select_algorithm.py @@ -57,8 +57,7 @@ def skip_cache(self, choices, name, key, benchmark): for patcher in [ dynamo_config.patch(verbose=True), - # Fails due to https://github.com/pytorch/pytorch/issues/131929 - dynamo_config.patch(inline_inbuilt_nn_modules=False), + dynamo_config.patch(inline_inbuilt_nn_modules=True), inductor_config.patch( debug=True, max_autotune=True, @@ -223,6 +222,7 @@ def forward(self, x): ), ) @dtypes(torch.float, torch.bfloat16, torch.half) + @torch.fx.experimental._config.patch(use_duck_shape=False) def test_linear_with_pointwise( self, batch_size, in_features, out_features, bias, epilogue, dtype ): @@ -253,6 +253,13 @@ def forward(self, x): and epilogue != "mul" and epilogue != "div" or (dtype == torch.half and epilogue == "add" and not bias) + or ( + dtype == torch.float32 + and epilogue == "add" + and not bias + and dynamo_config.dynamic_shapes + and not dynamo_config.assume_static_by_default + ) ): # Several scenarios where epilogue fusion is not counted in: # 1. For bfloat16, the epilogue fusion is part of the template, @@ -261,6 +268,8 @@ def forward(self, x): # div fusion which is not supported for oneDNN linear. # 2. For float16, since oneDNN linear is not applied, linear w/o bias # plus epilogue add is treated as linear w/ bias. + # 3. For float32, when dynamic shapes is enabled, mkl linear is not applied. + # and linear w/o bias plus epilogue add is treated as addmm. self.assertEqual(counters["inductor"]["cpp_epilogue_fusion_counter"], 0) else: self.assertEqual(counters["inductor"]["cpp_epilogue_fusion_counter"], 1) From 117e5ed47ba46cf75bb18c90ecdab4e367c05a2f Mon Sep 17 00:00:00 2001 From: chilli Date: Mon, 2 Sep 2024 19:17:26 -0700 Subject: [PATCH 0343/1018] add msg to _assert_async (#134813) Pull Request resolved: https://github.com/pytorch/pytorch/pull/134813 Approved by: https://github.com/ezyang, https://github.com/eqy, https://github.com/albanD --- aten/src/ATen/native/cuda/TensorCompare.cu | 30 ++++++++++++++-------- c10/macros/Macros.h | 17 ++++++++++++ 2 files changed, 36 insertions(+), 11 deletions(-) diff --git a/aten/src/ATen/native/cuda/TensorCompare.cu b/aten/src/ATen/native/cuda/TensorCompare.cu index f6956405174261..1a3ee09ac931c3 100644 --- a/aten/src/ATen/native/cuda/TensorCompare.cu +++ b/aten/src/ATen/native/cuda/TensorCompare.cu @@ -101,33 +101,41 @@ REGISTER_DISPATCH(clamp_scalar_stub, &clamp_scalar_kernel_impl); REGISTER_DISPATCH(clamp_min_scalar_stub, &clamp_min_scalar_kernel_impl); REGISTER_DISPATCH(clamp_max_scalar_stub, &clamp_max_scalar_kernel_impl); +struct Msg { + static constexpr size_t MAX_MSG_LENGTH = 256; + char msg[MAX_MSG_LENGTH]; +}; template -__global__ void _assert_async_cuda_kernel(const scalar_t* input) { - CUDA_KERNEL_ASSERT(input[0] != 0); +__global__ void _assert_async_cuda_kernel(const scalar_t* input, Msg msg) { + CUDA_KERNEL_ASSERT_MSG(input[0] != 0, msg.msg); } -__global__ void _assert_async_cuda_kernel(const c10::complex* input) { - CUDA_KERNEL_ASSERT(input[0] != c10::complex(0, 0)); +__global__ void _assert_async_cuda_kernel(const c10::complex* input, Msg msg) { + CUDA_KERNEL_ASSERT_MSG(input[0] != c10::complex(0, 0), msg.msg); } -__global__ void _assert_async_cuda_kernel(const c10::complex* input) { - CUDA_KERNEL_ASSERT(input[0] != c10::complex(0, 0)); +__global__ void _assert_async_cuda_kernel(const c10::complex* input, Msg msg) { + CUDA_KERNEL_ASSERT_MSG(input[0] != c10::complex(0, 0), msg.msg); } -void _assert_async_cuda(const Tensor& self_tensor) { +void _assert_async_msg_cuda(const Tensor& self_tensor, c10::string_view assert_msg) { const TensorBase &self = get_tensor_base(self_tensor); auto n = self.numel(); TORCH_CHECK(n != 0, "Boolean value of Tensor with no values is ambiguous"); TORCH_CHECK(n < 2, "Boolean value of Tensor with more than one value is ambiguous"); auto stream = at::cuda::getCurrentCUDAStream(); + Msg msg; + size_t copy_length = assert_msg.length(); + TORCH_CHECK(copy_length < Msg::MAX_MSG_LENGTH - 1, "Message length must be smaller than " + std::to_string(Msg::MAX_MSG_LENGTH - 1)); + std::copy_n(assert_msg.data(), copy_length, msg.msg); + msg.msg[copy_length] = '\0'; // Ensure null-termination AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(at::ScalarType::Half, at::ScalarType::Bool, at::ScalarType::BFloat16, self.scalar_type(), "_assert_async_cuda", [&] { - _assert_async_cuda_kernel<<<1, 1, 0, stream>>>(self.const_data_ptr()); + _assert_async_cuda_kernel<<<1, 1, 0, stream>>>(self.const_data_ptr(), msg); C10_CUDA_KERNEL_LAUNCH_CHECK(); }); } -// TODO (tmanlaibaatar) Ignore assert msg for now -void _assert_async_msg_cuda(const Tensor& self_tensor, c10::string_view assert_msg) { - _assert_async_cuda(self_tensor); +void _assert_async_cuda(const Tensor& self_tensor) { + _assert_async_msg_cuda(self_tensor, ""); } } // namespace at::native diff --git a/c10/macros/Macros.h b/c10/macros/Macros.h index 51b54105297082..ab6f2b38cf6be7 100644 --- a/c10/macros/Macros.h +++ b/c10/macros/Macros.h @@ -345,6 +345,7 @@ constexpr uint32_t CUDA_THREADS_PER_BLOCK_FALLBACK = 256; #if defined(__ANDROID__) || defined(__APPLE__) || defined(__FreeBSD__) // Those platforms do not support assert() #define CUDA_KERNEL_ASSERT(cond) +#define CUDA_KERNEL_ASSERT_MSG(cond, msg) #define SYCL_KERNEL_ASSERT(cond) #elif defined(_MSC_VER) #if defined(NDEBUG) @@ -372,6 +373,16 @@ __host__ __device__ static_cast(__LINE__)), \ 0); \ } +// TODO: This doesn't assert the message because I (chilli) couldn't figure out +// a nice way to convert a char* to a wchar_t* +#define CUDA_KERNEL_ASSERT_MSG(cond, msg) \ + if (C10_UNLIKELY(!(cond))) { \ + (void)(_wassert( \ + _CRT_WIDE(#cond), \ + _CRT_WIDE(__FILE__), \ + static_cast(__LINE__)), \ + 0); \ + } #define SYCL_KERNEL_ASSERT(cond) \ if (C10_UNLIKELY(!(cond))) { \ (void)(_wassert( \ @@ -413,6 +424,7 @@ __host__ __device__ // ROCm disable kernel assert by default #if !defined(C10_USE_ROCM_KERNEL_ASSERT) and defined(USE_ROCM) #define CUDA_KERNEL_ASSERT(cond) +#define CUDA_KERNEL_ASSERT_MSG(cond, msg) #define SYCL_KERNEL_ASSERT(cond) #else #define CUDA_KERNEL_ASSERT(cond) \ @@ -420,6 +432,11 @@ __host__ __device__ __assert_fail( \ #cond, __FILE__, static_cast(__LINE__), __func__); \ } +#define CUDA_KERNEL_ASSERT_MSG(cond, msg) \ + if (C10_UNLIKELY(!(cond))) { \ + __assert_fail( \ + msg, __FILE__, static_cast(__LINE__), __func__); \ + } #define SYCL_KERNEL_ASSERT(cond) \ if (C10_UNLIKELY(!(cond))) { \ __assert_fail( \ From 1692f942c0c35903004dab16e4d8ca477e2db8bc Mon Sep 17 00:00:00 2001 From: chilli Date: Mon, 2 Sep 2024 19:17:26 -0700 Subject: [PATCH 0344/1018] change multinomial to use async asserts instead of a synchronization (#134818) Fixes https://github.com/pytorch/pytorch/issues/134442 Pull Request resolved: https://github.com/pytorch/pytorch/pull/134818 Approved by: https://github.com/ezyang ghstack dependencies: #134813 --- aten/src/ATen/native/Distributions.cpp | 17 +++---- .../_internal/common_methods_invocations.py | 45 ++++++++++--------- 2 files changed, 30 insertions(+), 32 deletions(-) diff --git a/aten/src/ATen/native/Distributions.cpp b/aten/src/ATen/native/Distributions.cpp index 0de9632ddb1756..bd1d974e9bee50 100644 --- a/aten/src/ATen/native/Distributions.cpp +++ b/aten/src/ATen/native/Distributions.cpp @@ -23,6 +23,7 @@ #include #include #include +#include #include #include #include @@ -585,19 +586,15 @@ Tensor& multinomial_out(const Tensor& self, // https://github.com/pytorch/pytorch/issues/11931#issuecomment-625882503 if (!with_replacement || n_sample == 1) { // Sanity checks on `self`. - auto is_valid = ((self.max() < INFINITY) & (self.min() >= 0)).item(); - TORCH_CHECK( - is_valid.to(), - "probability tensor contains either `inf`, `nan` or element < 0"); - bool zero_prob_condition = false; + auto is_valid = ((self.max() < INFINITY) & (self.min() >= 0)); + at::_assert_async(is_valid, "probability tensor contains either `inf`, `nan` or element < 0"); + at::Tensor zero_prob_condition; if (self.dim() == 1){ - zero_prob_condition = (self.sum() == 0).item().to(); + zero_prob_condition = (self.sum() == 0); } else { - zero_prob_condition = (self.sum(1) == 0).sum().item().to(); + zero_prob_condition = (self.sum(1) == 0).any(); } - TORCH_CHECK( - !zero_prob_condition, - "invalid multinomial distribution (sum of probabilities <= 0)"); + at::_assert_async(~zero_prob_condition, "invalid multinomial distribution (sum of probabilities <= 0)"); // The algorithm is from gumbel softmax. // s = argmax( logp - log(-log(eps)) ) where eps ~ U(0, 1) diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py index a64cb62cedeb32..2120f4f019e863 100644 --- a/torch/testing/_internal/common_methods_invocations.py +++ b/torch/testing/_internal/common_methods_invocations.py @@ -2849,28 +2849,29 @@ def error_inputs_multinomial(op_info, device, **kwargs): rep_arg = (False, True) if torch.device(device).type == 'cpu' else (False,) - for rep in rep_arg: - kwargs = {'num_samples': 2, 'replacement': rep} - - for shape in inputs: - # error case when input tensor contains `inf`, `nan` or negative element - yield ErrorInput(SampleInput(torch.tensor(shape), kwargs=kwargs), - error_regex=err_msg1 if rep is False else err_msg2) - - # error case for the invalid multinomial distribution (sum of probabilities <= 0), 1-D input - x = torch.zeros(3, device=device) - yield ErrorInput(SampleInput(x, kwargs=kwargs), - error_regex=err_msg2) - - # error case for the invalid multinomial distribution (sum of probabilities <= 0), 2-D input - x = torch.zeros(3, 3, device=device) - yield ErrorInput(SampleInput(x, kwargs=kwargs), - error_regex=err_msg2) - - # error case for the invalid multinomial distribution - x[1, :] = 1 - yield ErrorInput(SampleInput(x, kwargs=kwargs), - error_regex=err_msg2) + if torch.device(device).type == 'cpu': + for rep in rep_arg: + kwargs = {'num_samples': 2, 'replacement': rep} + + for shape in inputs: + # error case when input tensor contains `inf`, `nan` or negative element + yield ErrorInput(SampleInput(torch.tensor(shape), kwargs=kwargs), + error_regex=err_msg1 if rep is False else err_msg2) + + # error case for the invalid multinomial distribution (sum of probabilities <= 0), 1-D input + x = torch.zeros(3, device=device) + yield ErrorInput(SampleInput(x, kwargs=kwargs), + error_regex=err_msg2) + + # error case for the invalid multinomial distribution (sum of probabilities <= 0), 2-D input + x = torch.zeros(3, 3, device=device) + yield ErrorInput(SampleInput(x, kwargs=kwargs), + error_regex=err_msg2) + + # error case for the invalid multinomial distribution + x[1, :] = 1 + yield ErrorInput(SampleInput(x, kwargs=kwargs), + error_regex=err_msg2) def error_inputs_gradient(op_info, device, **kwargs): for dtype in [torch.long, torch.float32, torch.complex64]: From ea198c6c4a14c261b6c8d424273db461f4afa1ea Mon Sep 17 00:00:00 2001 From: chilli Date: Mon, 2 Sep 2024 19:17:27 -0700 Subject: [PATCH 0345/1018] Switch nanmedian to not cuda synchronize (#134819) Generally, this seems to be faster. ![image](https://github.com/user-attachments/assets/43a86c6f-236d-4ba1-aae0-14e3d88ae401) And as an added benefit, it works great with cudagraphs and such :) Pull Request resolved: https://github.com/pytorch/pytorch/pull/134819 Approved by: https://github.com/Skylion007, https://github.com/eqy ghstack dependencies: #134813, #134818 --- aten/src/ATen/native/cuda/Sorting.cpp | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/aten/src/ATen/native/cuda/Sorting.cpp b/aten/src/ATen/native/cuda/Sorting.cpp index 9381c0e4f02a4c..310db22aa0571f 100644 --- a/aten/src/ATen/native/cuda/Sorting.cpp +++ b/aten/src/ATen/native/cuda/Sorting.cpp @@ -21,6 +21,9 @@ #include #include #include +#include +#include +#include #endif namespace at::native { @@ -146,8 +149,8 @@ Tensor median_impl(const Tensor& self, bool ignore_nan) { return at::where(sorted[-1].isnan(), sorted[-1], sorted[k]); } else { // For torch.nanmedian return the middle element among the non-nan values - int64_t k = ((size - 1) - sorted.isnan().sum().item()) / 2; - return sorted[k].clone(); // Clone so we aren't keeping `sorted` alive + Tensor k = at::div(at::rsub(sorted.isnan().sum(), (size - 1)), 2).to(kLong); + return at::index(sorted, {k}); } } From b44dd60efcbd39da95d11bbbe62646c1062bdb26 Mon Sep 17 00:00:00 2001 From: chilli Date: Mon, 2 Sep 2024 19:17:27 -0700 Subject: [PATCH 0346/1018] Changed addmv to be a decomposition and not a fallback (#134823) Overall seems to be faster ![image](https://github.com/user-attachments/assets/0cbea76e-fb78-4634-9265-047de0291549) Pull Request resolved: https://github.com/pytorch/pytorch/pull/134823 Approved by: https://github.com/jansel ghstack dependencies: #134813, #134818, #134819 --- torch/_inductor/decomposition.py | 1 + torch/_inductor/lowering.py | 1 - 2 files changed, 1 insertion(+), 1 deletion(-) diff --git a/torch/_inductor/decomposition.py b/torch/_inductor/decomposition.py index d369c2a664459a..105a1e96f3da59 100644 --- a/torch/_inductor/decomposition.py +++ b/torch/_inductor/decomposition.py @@ -48,6 +48,7 @@ inductor_decompositions = get_decompositions( [ aten._adaptive_avg_pool2d_backward, + aten.addmv, aten.arange, aten.bitwise_and_, aten.bitwise_or_, diff --git a/torch/_inductor/lowering.py b/torch/_inductor/lowering.py index 437a5abac3f339..3219e3a427ee66 100644 --- a/torch/_inductor/lowering.py +++ b/torch/_inductor/lowering.py @@ -2212,7 +2212,6 @@ def is_aligned(x): # Need templated kernel make_fallback(aten.addbmm) -make_fallback(aten.addmv, warn=False) make_fallback(aten._addmm_activation, warn=False) # Need templated kernel. Probably impossible to write efficiently From fa9cb1f95ebb30f13880858328d115d382b63c50 Mon Sep 17 00:00:00 2001 From: Nikita Shulga Date: Tue, 3 Sep 2024 09:29:59 +0000 Subject: [PATCH 0347/1018] Update cpuinfo submodule (#134891) Last time it was done in June by https://github.com/pytorch/pytorch/pull/127505 Pull Request resolved: https://github.com/pytorch/pytorch/pull/134891 Approved by: https://github.com/Skylion007 --- third_party/cpuinfo | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/third_party/cpuinfo b/third_party/cpuinfo index 3c8b1533ac03dd..fa1c679da8d19e 160000 --- a/third_party/cpuinfo +++ b/third_party/cpuinfo @@ -1 +1 @@ -Subproject commit 3c8b1533ac03dd6531ab6e7b9245d488f13a82a5 +Subproject commit fa1c679da8d19e1d87f20175ae1ec10995cd3dd3 From 51ae6e316f343d7671dd6702be8f28ef4ccabb38 Mon Sep 17 00:00:00 2001 From: Feng Yuan Date: Tue, 3 Sep 2024 12:14:37 +0000 Subject: [PATCH 0348/1018] Update torch-xpu-ops pin (ATen XPU implementation) (#134983) Release cycle for PyTorch 2.5 1. Enable Windows build in latest torch-xpu-ops. Resolved large bin issue. 2. Refine test infrastructure for compatibility on different HW platforms. Pull Request resolved: https://github.com/pytorch/pytorch/pull/134983 Approved by: https://github.com/EikanWang --- third_party/xpu.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/third_party/xpu.txt b/third_party/xpu.txt index c34e261f0e73c4..e71ba25bc5f08f 100644 --- a/third_party/xpu.txt +++ b/third_party/xpu.txt @@ -1 +1 @@ -7eb52196954e5eaf32f791252506b0640f16f200 +78e33d0afc3376a45e6184132f1e9fc154db3cae From e48d2ce02985458c243c45d807bf6a95bb84805d Mon Sep 17 00:00:00 2001 From: "Edward Z. Yang" Date: Sun, 1 Sep 2024 15:49:37 -0400 Subject: [PATCH 0349/1018] Fix set_unbacked_bindings when list of Tensors is returned (#133585) Signed-off-by: Edward Z. Yang Pull Request resolved: https://github.com/pytorch/pytorch/pull/133585 Approved by: https://github.com/albanD --- test/export/test_export.py | 34 +++++++++++++++++++++++++++ torch/fx/experimental/proxy_tensor.py | 8 ++++--- 2 files changed, 39 insertions(+), 3 deletions(-) diff --git a/test/export/test_export.py b/test/export/test_export.py index e93a721eb97505..426a136aaf3ec2 100644 --- a/test/export/test_export.py +++ b/test/export/test_export.py @@ -652,6 +652,40 @@ def forward(self, x, c): foo, bad_example_inp, dynamic_shapes=dynamic_shapes, strict=False ) + def test_unbacked_to_cond(self): + class M(torch.nn.Module): + def forward(self, a): + az = a.nonzero() + + def true_fn(x): + return (x + 1).sum() + + def false_fn(x): + return (x + 3).sum() + + r = torch.cond(az.size(0) > 3, true_fn, false_fn, (az,)) + return r * 2 + + M()(torch.randn(7)) + torch.export.export(M(), (torch.randn(7),)) + + def test_unbacked_to_cond_passthrough(self): + class M(torch.nn.Module): + def forward(self, a): + az = a.nonzero() + + def true_fn(x): + return x + 1 + + def false_fn(x): + return x + 3 + + r = torch.cond(az.size(0) > 3, true_fn, false_fn, (az,)) + return r * 2 + + M()(torch.randn(7)) + torch.export.export(M(), (torch.randn(7),)) + def test_state_tensors(self): class M(torch.nn.Module): # simple with register buffer def __init__(self) -> None: diff --git a/torch/fx/experimental/proxy_tensor.py b/torch/fx/experimental/proxy_tensor.py index 7a3ece886506c3..12c4f632118570 100644 --- a/torch/fx/experimental/proxy_tensor.py +++ b/torch/fx/experimental/proxy_tensor.py @@ -596,8 +596,6 @@ def track_tensor_tree( constant: Optional[_NestedTensors], tracer: _ProxyTracer, ) -> T: - _set_unbacked_bindings(inner_res, proxy_res) - def wrap_with_proxy( e: object, proxy: _NestedProxys, constant: Optional[_NestedTensors] ) -> None: @@ -606,11 +604,13 @@ def wrap_with_proxy( assert constant is None or isinstance(constant, Tensor) track_tensor(e, proxy, tracer=tracer, constant=constant) set_meta(proxy, e) + _set_unbacked_bindings(e, proxy) elif isinstance(e, py_sym_types): assert isinstance(proxy, Proxy) # NB: eagerly set meta here, so that the numbering is in order set_meta(proxy, e) set_proxy_slot(e, tracer, thunkify(tracer, lambda: proxy)) + _set_unbacked_bindings(e, proxy) elif isinstance(e, _AnyScriptObject): assert isinstance(proxy, Proxy) set_proxy_slot(e, tracer, proxy) @@ -2188,6 +2188,8 @@ def _set_unbacked_bindings(out: object, out_proxy: _NestedProxys) -> None: """A helper function for setting up unbacked_bindings on the destination FX graph.""" from .symbolic_shapes import compute_unbacked_bindings + log.debug("_set_unbacked_bindings %s", out_proxy) + # Can't use detect_fake_mode here, # # python test/distributed/_tensor/test_dtensor_compile.py -k @@ -2198,5 +2200,5 @@ def _set_unbacked_bindings(out: object, out_proxy: _NestedProxys) -> None: fake_mode = torch._C._get_dispatch_mode(torch._C._TorchDispatchModeKey.FAKE) if fake_mode and fake_mode.shape_env: if symbol_to_path := compute_unbacked_bindings(fake_mode.shape_env, out): - assert isinstance(out_proxy, Proxy) + assert isinstance(out_proxy, Proxy), out_proxy out_proxy.node.meta["unbacked_bindings"] = symbol_to_path From b553a9fd5dd121deffb4c53b14f5d48f4d58d535 Mon Sep 17 00:00:00 2001 From: IvanKobzarev Date: Fri, 30 Aug 2024 07:49:06 -0700 Subject: [PATCH 0350/1018] [subclasses] Do not fakeTensor const prop subclass args (#134855) The issue: Const propagation checks only if arguments do not have FakeTensor. If argument is Subclass, it will pass this condition. As a result Const Propogation execution happens without FakeTensorMode and having tensor factories inside Subclass.__torch_dispatch__ results that this Tensor is not Fakified. Solution: If we have subclasses arguments, do not count that const propagation is doable Pull Request resolved: https://github.com/pytorch/pytorch/pull/134855 Approved by: https://github.com/zou3519 --- test/functorch/test_aotdispatch.py | 12 ++++++ torch/_subclasses/fake_tensor.py | 2 + torch/testing/_internal/common_subclass.py | 48 ++++++++++++++++++++++ 3 files changed, 62 insertions(+) diff --git a/test/functorch/test_aotdispatch.py b/test/functorch/test_aotdispatch.py index 55383073c4d3d6..f62d46df470bac 100644 --- a/test/functorch/test_aotdispatch.py +++ b/test/functorch/test_aotdispatch.py @@ -6010,6 +6010,18 @@ def forward(self, x): with self.assertRaisesRegex(AssertionError, "Unexpected fake"): aot_module_simplified(MockModule(), (fake_x,), nop) + def test_aot_test_subclasses_with_tensor_factories(self): + from torch.testing._internal.common_subclass import SubclassWithTensorFactory + + inp = SubclassWithTensorFactory(torch.zeros(3, 5)) + + def fn(x): + return 2 * x + + ref_out = fn(inp) + out = torch.compile(fn, backend="aot_eager", fullgraph=True)(inp) + self.assertEqual(ref_out, out) + # entries in here don't work and need to be fixed. # Each one of these is a bug (or needs to be investigated) diff --git a/torch/_subclasses/fake_tensor.py b/torch/_subclasses/fake_tensor.py index 0450d4d19b5fb0..51f726f53bacc3 100644 --- a/torch/_subclasses/fake_tensor.py +++ b/torch/_subclasses/fake_tensor.py @@ -1719,6 +1719,7 @@ def _dispatch_impl( has_symbolic_sizes = any( i._has_symbolic_sizes_strides for i in flat_arg_fake_tensors ) or any(isinstance(a, SymInt) for a in flat_args) + has_subclasses = any(is_traceable_wrapper_subclass(a) for a in flat_args) converter = self.fake_tensor_converter @@ -1736,6 +1737,7 @@ def _dispatch_impl( should_allow_numbers_as_tensors(func) and not has_symbolic_sizes and not flat_arg_fake_tensors + and not has_subclasses ): assert all( t.constant is not None for t in flat_arg_fake_tensors diff --git a/torch/testing/_internal/common_subclass.py b/torch/testing/_internal/common_subclass.py index f6a8ed065cb819..3c76e19fab4eb6 100644 --- a/torch/testing/_internal/common_subclass.py +++ b/torch/testing/_internal/common_subclass.py @@ -3,6 +3,8 @@ import torch from copy import deepcopy from torch.utils._pytree import tree_map +import torch.utils._pytree as pytree + # TODO: Move LoggingTensor here. from torch.testing._internal.logging_tensor import LoggingTensor @@ -216,3 +218,49 @@ def __init__(self, name, create_fn, closed_under_ops=True): closed_under_ops=False # sparse semantics ), } + +class SubclassWithTensorFactory(torch.Tensor): + @staticmethod + def __new__(cls, src): + shape = src.shape + kwargs = {} + kwargs["strides"] = src.stride() + kwargs["storage_offset"] = src.storage_offset() + kwargs["device"] = src.device + kwargs["layout"] = src.layout + kwargs["requires_grad"] = src.requires_grad + kwargs["dtype"] = src.dtype + out = torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) + return out + + def __init__(self, src): + self.src = src + + def __repr__(self): + return f"{self.__class__.__name__}" + + def __tensor_flatten__(self): + return ["src"], None + + @classmethod + def __tensor_unflatten__(cls, inner_tensors, meta, outer_size, outer_stride): + src = inner_tensors["src"] + return cls(src) + + @classmethod + def __torch_dispatch__(cls, func, types, args, kwargs): + if kwargs is None: + kwargs = {} + + def _fn(x): + return x.src * torch.ones(x.src.shape) if x.src.dtype == torch.float32 else x.src + + _args = pytree.tree_map_only(cls, _fn, args) + _kwargs = pytree.tree_map_only(cls, _fn, kwargs) + + _out = func(*_args, **_kwargs) + + _out_flat, _out_spec = pytree.tree_flatten(_out) + + out_flat = [cls(o) if isinstance(o, torch.Tensor) else o for o in _out_flat] + return pytree.tree_unflatten(out_flat, _out_spec) From 0f48edb9debedcbd5dee178bc69c429ed08ef543 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Tue, 3 Sep 2024 15:33:09 +0000 Subject: [PATCH 0351/1018] [ONNX] Bump onnxscript version in CI; temporarily remove op test (#133748) Bump onnxscript version in CI to 0.1.0.dev20240831, and temporarily remove the fx consistency test. We will add a better version back later. Pull Request resolved: https://github.com/pytorch/pytorch/pull/133748 Approved by: https://github.com/titaiwangms --- .ci/docker/common/install_onnx.sh | 7 +- test/onnx/dynamo/test_registry_dispatcher.py | 223 +- test/onnx/test_fx_op_consistency.py | 2076 ----------------- test/onnx/test_fx_to_onnx.py | 28 - test/onnx/test_fx_to_onnx_with_onnxruntime.py | 3 + torch/onnx/_internal/_exporter_legacy.py | 50 +- 6 files changed, 24 insertions(+), 2363 deletions(-) delete mode 100644 test/onnx/test_fx_op_consistency.py diff --git a/.ci/docker/common/install_onnx.sh b/.ci/docker/common/install_onnx.sh index 65907a99f2572a..b99ecac283d646 100755 --- a/.ci/docker/common/install_onnx.sh +++ b/.ci/docker/common/install_onnx.sh @@ -30,10 +30,9 @@ pip_install \ pip_install coloredlogs packaging -pip_install onnxruntime==1.18 -pip_install onnx==1.16.0 -# pip_install "onnxscript@git+https://github.com/microsoft/onnxscript@3e869ef8ccf19b5ebd21c10d3e9c267c9a9fa729" --no-deps -pip_install onnxscript==0.1.0.dev20240613 --no-deps +pip_install onnxruntime==1.18.1 +pip_install onnx==1.16.2 +pip_install onnxscript==0.1.0.dev20240831 --no-deps # required by onnxscript pip_install ml_dtypes diff --git a/test/onnx/dynamo/test_registry_dispatcher.py b/test/onnx/dynamo/test_registry_dispatcher.py index d782f94dd311d9..8605b098d30055 100644 --- a/test/onnx/dynamo/test_registry_dispatcher.py +++ b/test/onnx/dynamo/test_registry_dispatcher.py @@ -8,18 +8,11 @@ import onnxscript # type: ignore[import] from onnxscript import BFLOAT16, DOUBLE, FLOAT, FLOAT16 # type: ignore[import] -from onnxscript.function_libs.torch_lib import ops # type: ignore[import] from onnxscript.onnx_opset import opset15 as op # type: ignore[import] import torch import torch.fx -from torch.onnx._internal.diagnostics import infra -from torch.onnx._internal.fx import ( - analysis, - diagnostics, - onnxfunction_dispatcher, - registration, -) +from torch.onnx._internal.fx import diagnostics, onnxfunction_dispatcher, registration from torch.testing._internal import common_utils @@ -84,60 +77,6 @@ def test_custom(x, y): [test_original, test_custom], ) - def test_unsupported_nodes_analysis_with_missing_aten_op(self): - # NOTE: simulate unsupported nodes - aten_mul_tensor = registration.OpName.from_name_parts( - namespace="aten", op_name="mul", overload="Tensor" - ) - aten_mul_default = registration.OpName.from_name_parts( - namespace="aten", op_name="mul" - ) - aten_add_tensor = registration.OpName.from_name_parts( - namespace="aten", op_name="add", overload="Tensor" - ) - aten_add_default = registration.OpName.from_name_parts( - namespace="aten", op_name="add" - ) - - self.registry._registry.pop(aten_mul_tensor) - self.registry._registry.pop(aten_mul_default) - self.registry._registry.pop(aten_add_tensor) - self.registry._registry.pop(aten_add_default) - - diagnostic_context = diagnostics.DiagnosticContext( - "torch.onnx.dynamo_export", torch.__version__ - ) - dispatcher = onnxfunction_dispatcher.OnnxFunctionDispatcher( - self.registry, diagnostic_context - ) - - graph: torch.fx.Graph = torch.fx.Graph() - x: torch.fx.Node = graph.create_node("placeholder", "x") - x.meta["val"] = torch.tensor(3.0) - b: torch.fx.Node = graph.create_node( - "call_function", target=torch.ops.aten.mul.Tensor, args=(x, x) - ) - c: torch.fx.Node = graph.create_node( - "call_function", target=torch.ops.aten.add.Tensor, args=(b, b) - ) - output: torch.fx.Node = graph.output(c) - module = torch.fx.GraphModule(torch.nn.Module(), graph) - - with self.assertRaises(infra.RuntimeErrorWithDiagnostic): - analysis.UnsupportedFxNodesAnalysis( - diagnostic_context, module, dispatcher - ).analyze(infra.levels.ERROR) - - try: - analysis.UnsupportedFxNodesAnalysis( - diagnostic_context, module, dispatcher - ).analyze(infra.levels.ERROR) - except infra.RuntimeErrorWithDiagnostic as e: - self.assertIn( - "Unsupported FX nodes: {'call_function': ['aten.mul.Tensor', 'aten.add.Tensor']}.", - e.diagnostic.message, - ) - @common_utils.instantiate_parametrized_tests class TestDispatcher(common_utils.TestCase): @@ -488,165 +427,5 @@ def test_first_custom_op( self.assertEqual(symbolic_fn, test_third_custom_op) -@common_utils.instantiate_parametrized_tests -class TestOpSchemaWrapper(common_utils.TestCase): - def setUp(self): - # overload type: optional dtype - self.onnx_function_new_full = ops.core.aten_new_full - self.onnx_function_new_full_dtype = ops.core.aten_new_full_dtype - - @common_utils.parametrize( - "inputs, attributes, assertion", - [ - common_utils.subtest( - ([torch.randn(3, 4), torch.randn(3, 4)], {"alpha": 2.0}, True), - name="perfect_match_with_kwargs", - ), - common_utils.subtest( - (["A", "B"], {}, False), - name="non_perfect_match_due_to_non_tensor_inputs", - ), - common_utils.subtest( - ([torch.randn(3, 4), torch.randn(3, 4), torch.randn(3, 4)], {}, False), - name="non_perfect_match_due_to_too_many_inputs", - ), - common_utils.subtest( - ([torch.randn(3, 4), torch.randn(3, 4)], {"wrong_kwargs": 2.0}, False), - name="non_perfect_match_due_to_wrong_kwargs", - ), - ], - ) - def test_perfect_match_inputs(self, inputs, attributes, assertion): - # OnnxFunction with default attributes - dummy_diagnostic = diagnostics.Diagnostic( - rule=diagnostics.rules.find_opschema_matched_symbolic_function, - level=diagnostics.levels.WARNING, - ) - op_schema_wrapper_add = onnxfunction_dispatcher._OnnxSchemaChecker( - ops.core.aten_add - ) - self.assertEqual( - op_schema_wrapper_add.perfect_match_inputs( - dummy_diagnostic, inputs, attributes - ), - assertion, - ) - - @common_utils.parametrize( - "inputs, kwargs, op, score", - [ - common_utils.subtest( - ([torch.randn(3, 4), torch.randn(3, 4)], {}, ops.core.aten_mul, 2), - name="match_2_inputs", - ), - common_utils.subtest( - ( - [ - torch.randint(0, 2, size=(3, 4), dtype=torch.int).bool(), - torch.randint(0, 2, size=(3, 4), dtype=torch.int).bool(), - ], - {}, - ops.core.aten_mul, - 0, - ), - name="match_0_inputs", - ), - common_utils.subtest( - ([torch.randn(3, 4), torch.randn(3, 4)], {}, ops.core.aten_mul_bool, 0), - name="match_0_inputs_bool", - ), - common_utils.subtest( - ( - [ - torch.randint(0, 2, size=(3, 4), dtype=torch.int).bool(), - torch.randint(0, 2, size=(3, 4), dtype=torch.int).bool(), - ], - {}, - ops.core.aten_mul_bool, - 2, - ), - name="match_2_inputs_bool", - ), - ], - ) - def test_matching_score_system_on_overload_dtypes(self, inputs, kwargs, op, score): - op_schema_wrapper = onnxfunction_dispatcher._OnnxSchemaChecker(op) - op_schema_wrapper._record_matching_score(inputs, kwargs) - self.assertEqual(op_schema_wrapper.match_score, score) - - @common_utils.parametrize( - "inputs, kwargs, op, score", - [ - common_utils.subtest( - ([torch.randn(3, 4), torch.tensor(3)], {}, ops.core.aten_new_full, 2), - name="match_2_inputs", - ), - common_utils.subtest( - ( - [torch.randn(3, 4), torch.tensor(3)], - {"dtype": 2}, # at this point, dtype should be converted to int - ops.core.aten_new_full_dtype, - 2, - ), - name="match_2_input_and_match_1_kwargs_optional", - ), - ], - ) - def test_matching_score_system_on_optional_dtypes(self, inputs, kwargs, op, score): - op_schema_wrapper = onnxfunction_dispatcher._OnnxSchemaChecker(op) - op_schema_wrapper._record_matching_score(inputs, kwargs) - self.assertEqual(op_schema_wrapper.match_score, score) - - @common_utils.parametrize( - "value, expected_onnx_str_dtype", - [ - common_utils.subtest( - (1, {"tensor(int64)", "tensor(int16)", "tensor(int32)"}), - name="all_ints", - ), - common_utils.subtest( - (1.0, {"tensor(float)", "tensor(double)", "tensor(float16)"}), - name="all_floats", - ), - common_utils.subtest( - (torch.tensor([True]), {"tensor(bool)"}), - name="bool", - ), - common_utils.subtest( - (torch.tensor([1], dtype=torch.int64), {"tensor(int64)"}), - name="int64", - ), - common_utils.subtest( - (torch.tensor([1], dtype=torch.int32), {"tensor(int32)"}), - name="int32", - ), - common_utils.subtest( - (torch.tensor([1], dtype=torch.int16), {"tensor(int16)"}), - name="int16", - ), - common_utils.subtest( - (torch.tensor([1], dtype=torch.float), {"tensor(float)"}), - name="float", - ), - common_utils.subtest( - (torch.tensor([1], dtype=torch.float16), {"tensor(float16)"}), - name="float16", - ), - common_utils.subtest( - (torch.tensor([1], dtype=torch.double), {"tensor(double)"}), - name="double", - ), - common_utils.subtest((None, set()), name="None"), # None allows no dtype - common_utils.subtest( - ([], set()), name="empaty_list" - ), # Empty list allows no dtype - ], - ) - def test_find_onnx_data_type(self, value, expected_onnx_str_dtype): - self.assertEqual( - onnxfunction_dispatcher._find_onnx_data_type(value), expected_onnx_str_dtype - ) - - if __name__ == "__main__": common_utils.run_tests() diff --git a/test/onnx/test_fx_op_consistency.py b/test/onnx/test_fx_op_consistency.py deleted file mode 100644 index 9972ea901722c7..00000000000000 --- a/test/onnx/test_fx_op_consistency.py +++ /dev/null @@ -1,2076 +0,0 @@ -# Owner(s): ["module: onnx"] - -"""Test consistency between the output values of torch.onnx FX exported operators -and torch operators given the same inputs. - -Usage: - - 1. Test all operators: - - pytest test/onnx/test_fx_op_consistency.py - - 2. To run tests on a specific operator (e.g. torch.ceil): - - pytest test/onnx/test_fx_op_consistency.py -k ceil - pytest test/onnx/test_fx_op_consistency.py -k nn_functional_scaled_dot_product_attention - - 3. Set `CREATE_REPRODUCTION_REPORT=1` to create markdown files for reproduction of errors. E.g. - - CREATE_REPRODUCTION_REPORT=1 python -m pytest test/onnx/test_fx_op_consistency.py -k div_mode_int - - NOTE: Read more on Running and writing tests: - https://github.com/pytorch/pytorch/wiki/Running-and-writing-tests - -Note: - - 1. Please make sure pytest-subtests is installed. Otherwise, the sub-tests will be ignored. - - 2. Install pytest-xdist to run tests in parallel if runng all tests is the goal. - - 3. When new ops are supported, please scroll down to modify the EXPECTED_SKIPS_OR_FAILS_WITH_DTYPES and - TESTED_OPS lists. See "Modify this section" - -""" - -from __future__ import annotations - -import copy -import itertools -import os -from typing import ( - Any, - Callable, - Collection, - List, - Mapping, - Optional, - Tuple, - Type, - TYPE_CHECKING, - Union, -) - -import error_reproduction -import onnx_test_common -import parameterized -import pytest -import pytorch_test_common -from onnx_test_common import skip, skip_slow, xfail - -import torch -from torch.onnx._internal.diagnostics import _rules -from torch.testing._internal import ( - common_device_type, - common_methods_invocations, - common_utils, -) - - -if TYPE_CHECKING: - from torch.testing._internal.opinfo import core as opinfo_core - - -# NOTE: For ATen signature modifications that will break ONNX export, -# use **xfail_torchlib_forward_compatibility** and **skip_torchlib_forward_compatibility** instead of xfail or skip -# to make the signal apparent for maintainers. -def xfail_torchlib_forward_compatibility( - op_name: str, - variant_name: str = "", - *, - reason: str, - github_issue: str, - opsets: Optional[Collection[Union[int, Callable[[int], bool]]]] = None, - dtypes: Optional[Collection[torch.dtype]] = None, - matcher: Optional[Callable[[Any], bool]] = None, - enabled_if: bool = True, -): - """Prefer using this (xfail) over skip when possible. - - Only skip when the test is not failing consistently. - """ - return xfail( - op_name, - variant_name=variant_name, - reason=f"{reason}. GitHub Issue: {github_issue}", - opsets=opsets, - dtypes=dtypes, - matcher=matcher, - enabled_if=enabled_if, - ) - - -def skip_torchlib_forward_compatibility( - op_name: str, - variant_name: str = "", - *, - reason: str, - github_issue: str, - opsets: Optional[Collection[Union[int, Callable[[int], bool]]]] = None, - dtypes: Optional[Collection[torch.dtype]] = None, - matcher: Optional[Callable[[Any], Any]] = None, - enabled_if: bool = True, -): - """Prefer using xfail_torchlib_forward_compatibility over this (skip) when possible. - - Only skip when the test is not failing consistently. - """ - return skip( - op_name, - variant_name=variant_name, - reason=f"{reason}. GitHub Issue: {github_issue}", - opsets=opsets, - dtypes=dtypes, - matcher=matcher, - enabled_if=enabled_if, - ) - - -# fmt: off -# Turn off black formatting to keep the list compact - -# Expected failures for onnx export. -# The list should be sorted alphabetically by op name. -# Q: When should I use fixme vs vs skip vs xfail? -# A: Prefer xfail over skip when possible. -# 2a. If a test is now failing because of xpass, because some previous errors -# are now fixed, removed the corresponding xfail. -# 2b. If a test is not failing consistently, use skip. -# NOTE: EXPECTED_SKIPS_OR_FAILS_WITH_DTYPES only supports dtypes. If a matcher or model_type -# is needed, use the SKIP_XFAIL_SUBTESTS_WITH_MATCHER_AND_MODEL_TYPE list further down below. -EXPECTED_SKIPS_OR_FAILS_WITH_DTYPES: Tuple[onnx_test_common.DecorateMeta, ...] = ( - xfail( - "__getitem__", - reason="io_adaper doesn't support __getitem__ input slice(0, 3, None)", - ), - xfail( - "__radd__", - dtypes=onnx_test_common.BOOL_TYPES, - reason=onnx_test_common.reason_onnx_script_does_not_support("Add", "bool"), - ), - xfail( - "__rmatmul__", - dtypes=(torch.float16,), - reason="fixme: Assertion error: result mismatch", - ), - xfail( - "__rpow__", - dtypes=onnx_test_common.INT_TYPES, - reason=onnx_test_common.reason_onnx_does_not_support("Pow", "int"), - ), - skip( - "_native_batch_norm_legit", - reason=onnx_test_common.reason_onnx_script_does_not_support("cpu is not supported: \ - https://github.com/microsoft/onnxscript/pull/1289") - ), - skip( - "_batch_norm_with_update", - dtypes=(torch.float16,), - reason="fixme: Assertion error: result mismatch and type error", - ), - xfail( - "_softmax_backward_data", - dtypes=(torch.float16,), - reason="fixme: Assertion error: result mismatch", - ), - xfail( - "_unsafe_masked_index", - dtypes=onnx_test_common.BOOL_TYPES, - reason=onnx_test_common.reason_onnx_runtime_does_not_support("Where", "bool"), - ), - xfail( - "_unsafe_masked_index", - dtypes=onnx_test_common.COMPLEX_TYPES, - reason=onnx_test_common.reason_onnx_runtime_does_not_support("_unsafe_masked_index", "complex64"), - ), - xfail( - "_unsafe_masked_index_put_accumulate", - reason="fixme: Status Message: updates tensor should have shape equal to " - "indices.shape[:-1] + data.shape[indices.shape[-1]:]", - ), - xfail( - "_unsafe_masked_index_put_accumulate", - dtypes=onnx_test_common.BOOL_TYPES, - reason=onnx_test_common.reason_onnx_runtime_does_not_support("Where", "bool"), - ), - xfail( - "add", dtypes=onnx_test_common.BOOL_TYPES, - reason=onnx_test_common.reason_onnx_does_not_support("Add") - ), - xfail( - "add", - dtypes=(torch.uint8, torch.int8, torch.int16,), - reason=onnx_test_common.reason_onnx_script_does_not_support( - "Add", "int8, int16, uint8 have type issue." - ), - ), - xfail( - "addbmm", - dtypes=onnx_test_common.COMPLEX_TYPES, - reason=onnx_test_common.reason_dynamo_does_not_support("Addbmm", "complex64") - ), - xfail( - "addmm", dtypes=onnx_test_common.BOOL_TYPES, - reason=onnx_test_common.reason_onnx_does_not_support("Addmm") - ), - xfail( - "addmm", - variant_name="decomposed", - dtypes=onnx_test_common.BOOL_TYPES + onnx_test_common.INT_TYPES, - reason=onnx_test_common.reason_onnx_does_not_support("Addmm") - ), - skip( - "addmm", dtypes=onnx_test_common.COMPLEX_TYPES, - reason=onnx_test_common.reason_dynamo_does_not_support("Addmm", "complex64 (core dump)") - ), - skip( - "addmm", - variant_name="decomposed", - dtypes=onnx_test_common.COMPLEX_TYPES, - reason=onnx_test_common.reason_dynamo_does_not_support("Addmm", "complex64 (core dump)") - ), - xfail( - "addr", - dtypes=onnx_test_common.BOOL_TYPES, - reason=onnx_test_common.reason_onnx_script_does_not_support( - "Addr", "bool" - ), - ), - xfail( - "addr", - dtypes=onnx_test_common.COMPLEX_TYPES, - reason=onnx_test_common.reason_dynamo_does_not_support("Addr", "complex64") - ), - xfail( - "alias_copy", - dtypes=(torch.int8, torch.uint8, torch.int16, torch.float64), - reason="OnnxExporterError: Failed to export model", - ), - xfail( - "allclose", - reason=onnx_test_common.reason_dynamo_does_not_support("Allclose") - ), - xfail( - "amax", - dtypes=(torch.int16, *onnx_test_common.BOOL_TYPES), - reason=onnx_test_common.reason_onnx_does_not_support("ReduceMin", "bool, int16"), - ), - xfail( - "amin", dtypes=(torch.int16, *onnx_test_common.BOOL_TYPES), - reason=onnx_test_common.reason_dynamo_does_not_support("ReduceMin", "bool, int16") - ), - xfail( - "aminmax", - dtypes=(torch.int16, *onnx_test_common.BOOL_TYPES), - reason=onnx_test_common.reason_onnx_does_not_support("ReduceMin", "bool, int16"), - ), - xfail( - "arange", - dtypes=(torch.uint8,), - reason=onnx_test_common.reason_onnx_script_does_not_support("Arange", "uint8, int8"), - ), - xfail( - "arange", - dtypes=(torch.int16, torch.int32), - reason="AssertionError: The values for attribute 'shape' do not match", - ), - xfail( - "argmax", - dtypes=( - torch.int16, - torch.int64, - ), - reason=onnx_test_common.reason_onnx_runtime_does_not_support( - "ArgMax", "int16, int64" - ), - ), - xfail( - "argmin", - dtypes=( - torch.uint8, - torch.int8, - torch.int16, - torch.int64, - ), - reason=onnx_test_common.reason_onnx_runtime_does_not_support( - "ArgMin", "uint8, int8, int16, int64" - ), - ), - xfail( - "argwhere", - reason="fixme: Assertion error: result mismatch", - ), - skip( - "as_strided", - variant_name="partial_views", - reason="ONNX doesn't have partial view for tensor; [PostInline][ORT] segfaults", - ), - xfail( - "atan2", - dtypes=onnx_test_common.BOOL_TYPES + onnx_test_common.INT_TYPES, - reason="fixme: Assertion error: result mismatch", - ), - xfail( - "baddbmm", - dtypes=( - torch.uint8, - torch.int8, - torch.int16, - ), - reason=onnx_test_common.reason_onnx_runtime_does_not_support( - "Matmul", "uint8, int8, int16" - ), - ), - xfail( - "baddbmm", - dtypes=onnx_test_common.COMPLEX_TYPES, - reason=onnx_test_common.reason_dynamo_does_not_support("baddbmm", "complex64") - ), - xfail( - "bernoulli", - reason=onnx_test_common.reason_dynamo_does_not_support("wrapper_set_seed"), - ), - xfail( - "bfloat16", - reason="fixme: ORT errors with RuntimeError: No corresponding Numpy type for Tensor Type.", - ), - xfail( - "bincount", - reason=onnx_test_common.reason_dynamo_does_not_support("aten.bincount.default"), - ), - xfail( - "block_diag", - dtypes=onnx_test_common.COMPLEX_TYPES, - reason=onnx_test_common.reason_onnx_runtime_does_not_support("Block_diag", "complex"), - ), - xfail( - "bmm", - dtypes=( - torch.uint8, - torch.int8, - torch.int16, - ), - reason=onnx_test_common.reason_onnx_runtime_does_not_support( - "Matmul", "uint8, int8, int16" - ), - ), - xfail( - "broadcast_shapes", - reason=onnx_test_common.reason_dynamo_does_not_support("output is int"), - ), - xfail( - "cauchy", - reason=onnx_test_common.reason_dynamo_does_not_support("wrapper_set_seed"), - ), - skip( - "ceil", dtypes=onnx_test_common.BOOL_TYPES + onnx_test_common.INT_TYPES, - reason=onnx_test_common.reason_onnx_does_not_support("Ceil", "bool and int") - ), - xfail( - "chalf", - reason="fixme: ONNX shape type inference error: Invalid tensor data type 0." - ), - xfail( - "chunk", - dtypes=(torch.uint8, torch.int8, torch.int16,), - reason=onnx_test_common.reason_onnx_runtime_does_not_support( - "Chunk", "uint8, int8, int16" - ), - ), - xfail( - "clamp", - dtypes=(torch.uint8, torch.int8, torch.int16,), - reason=onnx_test_common.reason_onnx_runtime_does_not_support( - "Max", "uint8, int8, int16" - ), - ), - xfail( - "clamp_max", dtypes=onnx_test_common.BOOL_TYPES, - reason=onnx_test_common.reason_onnx_script_does_not_support("Clamp_max", "bool") - ), - xfail( - "clamp_max", - dtypes=(torch.uint8, torch.int8, torch.int16,), - reason=onnx_test_common.reason_onnx_runtime_does_not_support( - "Max", "uint8, int8, int16" - ), - ), - xfail( - "clamp_min", - dtypes=(torch.uint8, torch.int8, torch.int16,), - reason=onnx_test_common.reason_onnx_runtime_does_not_support( - "Max", "uint8, int8, int16" - ), - ), - xfail( - "clamp_min", dtypes=onnx_test_common.BOOL_TYPES, - reason=onnx_test_common.reason_onnx_script_does_not_support("Clamp_min", "bool") - ), - xfail( - "constant_pad_nd", - dtypes=(torch.int16,), - reason=onnx_test_common.reason_onnx_runtime_does_not_support( - "Constant_pad_nd", "int16" - ), - ), - xfail( - "constant_pad_nd", - dtypes=onnx_test_common.COMPLEX_TYPES, - reason=onnx_test_common.reason_dynamo_does_not_support( - "Constant_pad_nd", "complex64" - ), - ), - xfail( - "corrcoef", - reason=onnx_test_common.reason_dynamo_does_not_support( - "aten.equal.default" - ), - ), - xfail( - "cov", - reason=onnx_test_common.reason_dynamo_does_not_support( - "aten.equal.default" - ), - ), - xfail( - "cumsum", dtypes=onnx_test_common.BOOL_TYPES + (torch.uint8, torch.int8, torch.int16,), - reason=onnx_test_common.reason_onnx_does_not_support("Cumsum", "bool, uint8, int8, int16") - ), - xfail( - "combinations", - reason=onnx_test_common.reason_dynamo_does_not_support("aten.masked.select"), - ), - xfail( - "diag", - dtypes=onnx_test_common.BOOL_TYPES, - reason=onnx_test_common.reason_onnx_runtime_does_not_support("Diagonal", "bool"), - ), - xfail( - "diagonal_copy", - dtypes=onnx_test_common.BOOL_TYPES, - reason=onnx_test_common.reason_onnx_runtime_does_not_support("Diagonal", "bool"), - ), - xfail( - "dot", dtypes=(torch.uint8, torch.int8, torch.int16,), - reason=onnx_test_common.reason_onnx_does_not_support("MatMul", "uint8, int8, int16") - ), - skip( - "dot", - dtypes=onnx_test_common.COMPLEX_TYPES, - reason=onnx_test_common.reason_dynamo_does_not_support("Dot", "complex64(core dump)"), - ), - xfail( - "empty", - dtypes=onnx_test_common.COMPLEX_TYPES, - reason="fixme: kwargs dtpye=complex64 is not supported in ONNX." - ), - xfail( - "empty_strided", - reason=onnx_test_common.reason_dynamo_does_not_support("wrapper_set_seed"), - ), - xfail( - "eq", - dtypes=(torch.uint8, torch.int8, torch.int16,), - reason=onnx_test_common.reason_onnx_runtime_does_not_support("Equal", "uint8, int8, int16"), - ), - xfail( - "equal", - reason=onnx_test_common.reason_dynamo_does_not_support("aten.equal.default") - ), - xfail( - "exponential", - reason=onnx_test_common.reason_dynamo_does_not_support("exponential"), - ), - xfail( - "fft.fft", - reason="fixme: Assertion error: result mismatch", - ), - xfail( - "fft.fft2", - reason="fixme: Assertion error: result mismatch", - ), - xfail( - "fft.fftn", - reason="fixme: Assertion error: result mismatch", - ), - xfail( - "fft.ifft", - reason="fixme: Assertion error: result mismatch", - ), - xfail( - "fft.ifft2", - reason="fixme: Assertion error: result mismatch", - ), - xfail( - "fft.ifftn", - reason="fixme: Assertion error: result mismatch", - ), - xfail( - "fft.irfft", - reason="fixme: Assertion error: result mismatch", - ), - xfail( - "fft.irfft2", - reason="fixme: Assertion error: result mismatch", - ), - xfail( - "fft.irfftn", - reason=onnx_test_common.reason_onnx_script_does_not_support("aten._fft_r2c.default"), - ), - xfail( - "fft.rfft", - reason=onnx_test_common.reason_onnx_script_does_not_support("aten._fft_r2c.default"), - ), - xfail( - "fft.rfftn", - reason=onnx_test_common.reason_onnx_script_does_not_support("aten._fft_r2c.default"), - ), - xfail( - "fft.rfft2", - reason=onnx_test_common.reason_onnx_script_does_not_support("aten._fft_r2c.default"), - ), - xfail( - "floor", - dtypes=onnx_test_common.BOOL_TYPES + onnx_test_common.INT_TYPES, - reason=onnx_test_common.reason_onnx_does_not_support("Floor", "bool, int"), - ), - xfail( - "floor_divide", - dtypes=onnx_test_common.BOOL_TYPES + onnx_test_common.INT_TYPES, - reason=onnx_test_common.reason_onnx_does_not_support("Floor", "bool, int"), - ), - xfail( - "full", - dtypes=onnx_test_common.COMPLEX_TYPES, - reason=onnx_test_common.reason_dynamo_does_not_support("full", "complex64") - ), - xfail( - "full_like", - dtypes=onnx_test_common.COMPLEX_TYPES, - reason=onnx_test_common.reason_dynamo_does_not_support("full_like", "complex64") - ), - xfail( - "gather", - reason="GatherElements op: Rank of input 'data' needs to be equal to rank of input 'indices'" - ), - xfail( - "geometric", - reason=onnx_test_common.reason_dynamo_does_not_support("wrapper_set_seed"), - ), - xfail( - "heaviside", - dtypes=onnx_test_common.BOOL_TYPES, - reason=onnx_test_common.reason_onnx_script_does_not_support("Heaviside", "bool"), - ), - xfail( - "index_add", - dtypes=(torch.float16,), - reason=onnx_test_common.reason_onnx_runtime_does_not_support("ScatterND", "int64, int32, bool"), - ), - xfail( - "index_fill", - dtypes=onnx_test_common.COMPLEX_TYPES, - reason=onnx_test_common.reason_dynamo_does_not_support("index_fill", "complex64") - ), - xfail( - "index_fill", - dtypes=onnx_test_common.INT_TYPES + onnx_test_common.BOOL_TYPES + onnx_test_common.FLOAT_TYPES, - reason="fixme: Constant input list has None. ONNXScript does not support None in constant list." - ), - xfail( - "index_put", - dtypes=onnx_test_common.BOOL_TYPES + (torch.float16,), - reason=onnx_test_common.reason_onnx_script_does_not_support("index_put", "bool"), - ), - xfail( - "index_put", - dtypes=(torch.uint8, torch.int8, torch.int16,), - reason=onnx_test_common.reason_onnx_script_does_not_support("Add", "int8, int16"), - ), - xfail( - "index_put", - dtypes=(torch.float16,), - reason=onnx_test_common.reason_onnx_runtime_does_not_support("ScatterND", "float16"), - ), - xfail( - "isnan", - dtypes=onnx_test_common.INT_TYPES + onnx_test_common.BOOL_TYPES, - reason=onnx_test_common.reason_onnx_does_not_support("IsNaN", "int, bool"), - ), - xfail( - "istft", - reason=onnx_test_common.reason_dynamo_does_not_support("data-dependent"), - ), - xfail( - "item", - reason=onnx_test_common.reason_dynamo_does_not_support("wrapper_set_seed"), - ), - xfail( - "lerp", - dtypes=onnx_test_common.COMPLEX_TYPES, - reason=onnx_test_common.reason_dynamo_does_not_support("lerp", "complex64") - ), - xfail( - "linalg.lstsq", - reason=onnx_test_common.reason_dynamo_does_not_support("aten.linalg_lstsq.default"), - ), - xfail( - "linalg.lstsq", - variant_name="grad_oriented", - reason=onnx_test_common.reason_dynamo_does_not_support("aten.linalg_lstsq.default"), - ), - xfail( - "linalg.matrix_power", - reason="fixme: The values for attribute 'shape' do not match: torch.Size([2, 2]) != torch.Size([2, 2, 2])." - ), - xfail( - "linalg.norm", - reason="fixme: Assertion error: result mismatch", - ), - xfail( - "linalg.norm", - variant_name="subgradients_at_zero", - reason="fixme: Assertion error: result mismatch", - ), - xfail( - "linalg.vecdot", - reason="fixme: Assertion error: result shape mismatch", - ), - xfail( - "linspace", - dtypes=(torch.int64, torch.int32,), - reason="fixme: Results do not match with PyTorch. https://github.com/microsoft/onnxscript/issues/854", - ), - xfail( - "linspace", - variant_name="tensor_overload", - dtypes=(torch.int64, torch.int32,), - reason="fixme: Results do not match with PyTorch. https://github.com/microsoft/onnxscript/issues/854", - ), - xfail( - "linspace", - dtypes=onnx_test_common.COMPLEX_TYPES, - reason=onnx_test_common.reason_onnx_script_does_not_support("linspace", "complex64") - ), - xfail( - "linspace", - variant_name="tensor_overload", - dtypes=onnx_test_common.COMPLEX_TYPES, - reason=onnx_test_common.reason_onnx_script_does_not_support("linspace", "complex64") - ), - xfail( - "log_normal", - reason=onnx_test_common.reason_dynamo_does_not_support("wrapper_set_seed"), - ), - xfail( - "log_softmax", - dtypes=(torch.float16,), - reason="fixme: ORT optimizer error: https://github.com/microsoft/onnxruntime/issues/16438", - ), - xfail( - "log_softmax", - variant_name="with_dtype", - dtypes=(torch.float16,), - reason="fixme: ORT optimizer error: https://github.com/microsoft/onnxruntime/issues/16438", - ), - xfail( - "logical_and", - dtypes=onnx_test_common.FLOAT_TYPES + onnx_test_common.INT_TYPES, - reason=onnx_test_common.reason_onnx_script_does_not_support("And", "float, int"), - ), - xfail( - "logical_not", - dtypes=onnx_test_common.FLOAT_TYPES + onnx_test_common.INT_TYPES, - reason=onnx_test_common.reason_onnx_script_does_not_support("Not", "float, int"), - ), - xfail( - "logical_or", - dtypes=onnx_test_common.FLOAT_TYPES + onnx_test_common.INT_TYPES, - reason=onnx_test_common.reason_onnx_script_does_not_support("Or", "float, int"), - ), - xfail( - "logical_xor", - dtypes=onnx_test_common.FLOAT_TYPES + onnx_test_common.INT_TYPES, - reason=onnx_test_common.reason_onnx_script_does_not_support("Xor", "float, int"), - ), - skip( - "masked.logsumexp", - reason="fixme: https://github.com/onnx/onnx/issues/4986", - ), - xfail( - "masked.amax", - reason="fixme: ORT optimizer error: https://github.com/microsoft/onnxruntime/issues/16438", - ), - xfail( - "masked.amin", - reason="fixme: ORT optimizer error: https://github.com/microsoft/onnxruntime/issues/16438", - ), - xfail( - "masked.argmin", - dtypes=onnx_test_common.BOOL_TYPES + onnx_test_common.FLOAT_TYPES + (torch.int64,), - reason="fixme: Assertion error: result mismatch", - ), - xfail( - "masked.argmax", - dtypes=onnx_test_common.BOOL_TYPES + onnx_test_common.FLOAT_TYPES + (torch.int64,), - reason="fixme: Assertion error: result mismatch", - ), - xfail( - "masked_fill", - dtypes=onnx_test_common.BOOL_TYPES, - reason=onnx_test_common.reason_onnx_runtime_does_not_support("Where", "bool"), - ), - xfail( - "masked.sum", - dtypes=onnx_test_common.BOOL_TYPES, - reason=onnx_test_common.reason_onnx_runtime_does_not_support("Where", "bool"), - ), - xfail( - "masked.log_softmax", - dtypes=(torch.float16,), - reason="fixme: ORT optimizer error: https://github.com/microsoft/onnxruntime/issues/16438", - ), - xfail( - "masked.mean", - dtypes=onnx_test_common.BOOL_TYPES, - reason=onnx_test_common.reason_onnx_does_not_support("ReduceMean", "bool"), - ), - xfail( - "masked.norm", - reason="fixme: Assertion error: result mismatch", - ), - xfail( - "masked.prod", - dtypes=onnx_test_common.BOOL_TYPES, - reason=onnx_test_common.reason_onnx_runtime_does_not_support("Where", "bool"), - ), - xfail( - "masked_select", - reason=onnx_test_common.reason_dynamo_does_not_support("aten.masked_select.default"), - ), - xfail( - "max", - variant_name="reduction_no_dim", - dtypes=onnx_test_common.BOOL_TYPES, - reason=onnx_test_common.reason_onnx_runtime_does_not_support("ReduceMax", "bool"), - ), - xfail( - "max", - variant_name="reduction_with_dim", - dtypes=onnx_test_common.BOOL_TYPES, - reason=onnx_test_common.reason_onnx_runtime_does_not_support("ReduceMax", "bool"), - ), - xfail( - "max", - variant_name="reduction_with_dim", - dtypes=(torch.int64,), - reason="https://github.com/onnx/onnx/issues/4986", - ), - xfail( - "min", - variant_name="reduction_no_dim", - dtypes=onnx_test_common.BOOL_TYPES, - reason=onnx_test_common.reason_onnx_runtime_does_not_support("ReduceMin", "bool"), - ), - xfail( - "min", - variant_name="reduction_with_dim", - dtypes=onnx_test_common.BOOL_TYPES + (torch.int64,), - reason=onnx_test_common.reason_onnx_runtime_does_not_support("ReduceMin", "bool"), - ), - skip( - "mm", - dtypes=onnx_test_common.COMPLEX_TYPES, - reason=onnx_test_common.reason_dynamo_does_not_support("MM", "complex64(core dump)"), - ), - xfail( - "multinomial", - reason=onnx_test_common.reason_dynamo_does_not_support("wrapper_set_seed"), - ), - xfail( - "nanquantile", - reason=onnx_test_common.reason_dynamo_does_not_support("aten.equal.default") - ), - xfail( - "nansum", - dtypes=onnx_test_common.INT_TYPES + onnx_test_common.BOOL_TYPES, - reason=onnx_test_common.reason_onnx_runtime_does_not_support("IsNaN", "int, bool"), - ), - xfail( - "narrow", - reason=onnx_test_common.reason_dynamo_does_not_support("data-dependent"), - ), - skip( - "native_batch_norm", - reason=onnx_test_common.reason_onnx_script_does_not_support("cpu is not supported: \ - https://github.com/microsoft/onnxscript/pull/1289") - ), - xfail( - "native_layer_norm", - dtypes=(torch.float16,), - reason="fixme: ORT optimizer error: https://github.com/microsoft/onnxruntime/issues/16438", - ), - xfail( - "new_full", - dtypes=onnx_test_common.COMPLEX_TYPES, - reason=onnx_test_common.reason_dynamo_does_not_support("new_full", "complex64") - ), - xfail( - "nn.functional.adaptive_avg_pool2d", - reason=onnx_test_common.reason_onnx_script_does_not_support("RecursionError: \ - maximum recursion depth exceeded while calling a Python object"), - ), - xfail( - "nn.functional.adaptive_avg_pool3d", - reason=onnx_test_common.reason_onnx_script_does_not_support("aten._adaptive_avg_pool3d.default"), - ), - xfail( - "nn.functional.alpha_dropout", - reason=onnx_test_common.reason_dynamo_does_not_support("wrapper_set_seed"), - ), - xfail( - "nn.functional.avg_pool1d", - dtypes=onnx_test_common.INT_TYPES, - reason=onnx_test_common.reason_onnx_does_not_support("AveragePool", "int"), - ), - xfail( - "nn.functional.avg_pool2d", - dtypes=onnx_test_common.INT_TYPES, - reason=onnx_test_common.reason_onnx_does_not_support("AveragePool", "int"), - ), - xfail( - "nn.functional.avg_pool3d", - dtypes=onnx_test_common.INT_TYPES, - reason=onnx_test_common.reason_onnx_does_not_support("AveragePool", "int"), - ), - xfail( - "nn.functional.batch_norm", - dtypes=(torch.float16,), - reason="fixme: https://github.com/microsoft/onnxscript/issues/1270", - ), - xfail( - "nn.functional.conv_transpose1d", - dtypes=(torch.int64,), - reason=onnx_test_common.reason_onnx_does_not_support("Conv1d", "int64"), - ), - xfail( - "nn.functional.conv_transpose2d", - dtypes=(torch.int64,), - reason=onnx_test_common.reason_onnx_does_not_support("Conv2d", "int64"), - ), - xfail( - "nn.functional.conv_transpose3d", - dtypes=(torch.int64,), - reason=onnx_test_common.reason_onnx_does_not_support("Conv3d", "int64"), - ), - skip( - "nn.functional.conv_transpose1d", - reason="fixme: Assertion error: result mismatch", - ), - skip( - "nn.functional.conv_transpose2d", - reason="fixme: Assertion error: result mismatch", - ), - skip( - "nn.functional.conv_transpose3d", - reason="fixme: Assertion error: result mismatch", - ), - xfail( - "nn.functional.conv1d", - dtypes=(torch.int64,), - reason=onnx_test_common.reason_onnx_does_not_support("Conv1d", "int64"), - ), - xfail( - "nn.functional.conv2d", - dtypes=(torch.int64,), - reason=onnx_test_common.reason_onnx_does_not_support("Conv2d", "int64"), - ), - xfail( - "nn.functional.conv2d", - dtypes=onnx_test_common.COMPLEX_TYPES, - reason="fixme: Assertion error: result mismatch", - ), - xfail( - "nn.functional.conv3d", - dtypes=(torch.int64,), - reason=onnx_test_common.reason_onnx_does_not_support("Conv3d", "int64"), - ), - xfail( - "nn.functional.conv3d", - dtypes=onnx_test_common.COMPLEX_TYPES, - reason="fixme: Assertion error: result mismatch", - ), - xfail( - "nn.functional.cosine_embedding_loss", - dtypes=onnx_test_common.BOOL_TYPES, - reason=onnx_test_common.reason_onnx_runtime_does_not_support("CosineEmbeddingLoss", "bool"), - ), - xfail( - "nn.functional.ctc_loss", - reason=onnx_test_common.reason_dynamo_does_not_support("aten.ctc_loss.default"), - ), - xfail( - "nn.functional.dropout", - reason=onnx_test_common.reason_dynamo_does_not_support("wrapper_set_seed"), - ), - xfail( - "nn.functional.dropout2d", - reason=onnx_test_common.reason_dynamo_does_not_support("wrapper_set_seed"), - ), - xfail( - "nn.functional.dropout3d", - reason=onnx_test_common.reason_dynamo_does_not_support("wrapper_set_seed"), - ), - xfail( - "nn.functional.feature_alpha_dropout", - variant_name="with_train", - reason=onnx_test_common.reason_dynamo_does_not_support("wrapper_set_seed"), - ), - xfail( - "nn.functional.feature_alpha_dropout", - variant_name="without_train", - reason=onnx_test_common.reason_dynamo_does_not_support("wrapper_set_seed"), - ), - xfail( - "nn.functional.fractional_max_pool2d", - reason=onnx_test_common.reason_dynamo_does_not_support("wrapper_set_seed"), - ), - xfail( - "nn.functional.fractional_max_pool3d", - reason=onnx_test_common.reason_dynamo_does_not_support("wrapper_set_seed"), - ), - xfail( - "nn.functional.gaussian_nll_loss", - reason=onnx_test_common.reason_dynamo_does_not_support("aten.gaussian_nll_loss"), - ), - xfail( - "nn.functional.grid_sample", - reason="fixme: Assertion error: result mismatch", - ), - xfail( - "nn.functional.group_norm", - dtypes=(torch.float16,), - reason=onnx_test_common.reason_onnx_runtime_does_not_support("GroupNormalization", "float16"), - ), - xfail( - "nn.functional.local_response_norm", - dtypes=(torch.int64,), - reason=onnx_test_common.reason_onnx_runtime_does_not_support("avgpool", "int64"), - ), - xfail( - "nn.functional.linear", - dtypes=onnx_test_common.INT_TYPES, - reason=onnx_test_common.reason_onnx_does_not_support("Gemm", "int"), - ), - xfail( - "nn.functional.max_pool2d", - dtypes=onnx_test_common.BOOL_TYPES + onnx_test_common.INT_TYPES, - reason=onnx_test_common.reason_onnx_does_not_support("Max_pool2d"), - ), - xfail( - "nn.functional.max_pool3d", - dtypes=onnx_test_common.BOOL_TYPES + onnx_test_common.INT_TYPES, - reason=onnx_test_common.reason_onnx_does_not_support("Max_pool3d"), - ), - xfail( - "nn.functional.multi_head_attention_forward", - reason=onnx_test_common.reason_dynamo_does_not_support("wrapper_set_seed"), - ), - xfail( - "nn.functional.one_hot", - reason=onnx_test_common.reason_dynamo_does_not_support("data-dependent"), - ), - xfail( - "nn.functional.pad", - variant_name="replicate", - reason="fixme: ORT error: padding size", - ), - xfail( - "nn.functional.pad", - variant_name="replicate_negative", - reason="fixme: Assertion error: result mismatch", - ), - xfail( - "nn.functional.pad", - variant_name="reflect", - reason="fixme: Assertion error: result mismatch", - ), - xfail( - "nn.functional.pixel_shuffle", - dtypes=(torch.int32, torch.int64) + onnx_test_common.BOOL_TYPES, - reason="fixme: ONNX Runtime does not support int32/64 inputs", - ), - xfail( - "nn.functional.pixel_unshuffle", - reason=onnx_test_common.reason_onnx_script_does_not_support("aten.pixel_unshuffle.default"), - ), - xfail( - "nn.functional.poisson_nll_loss", - dtypes=onnx_test_common.BOOL_TYPES + onnx_test_common.INT_TYPES, - reason="fixme: result mismatch with NaN.", - ), - xfail( - "nn.functional.rrelu", - reason=onnx_test_common.reason_dynamo_does_not_support("wrapper_set_seed"), - ), - xfail( - "nn.functional.rrelu", - dtypes=(torch.int64,), - reason=onnx_test_common.reason_onnx_runtime_does_not_support("Relu", "int64"), - ), - skip( - "nn.functional.scaled_dot_product_attention", - matcher=lambda sample: sample.kwargs.get("dropout_p") != 0.0, - reason="dropout is random so the results do not match", - ), - xfail( - "nn.functional.scaled_dot_product_attention", - dtypes=(torch.float16,), - reason="fixme: ORT failed. https://github.com/microsoft/onnxruntime/issues/16438", - ), - xfail( - "nn.functional.selu", - reason="fixme: nn.functional.selu is not in torch._decomp.decomposition_table", - ), - xfail( - "nn.functional.soft_margin_loss", - dtypes=(torch.float16,), - reason="fixme: Assertion error: result mismatch", - ), - xfail( - "nonzero", - dtypes=(torch.int8, torch.int16), - reason=onnx_test_common.reason_onnx_runtime_does_not_support("NonZero", "int8, int16"), - ), - xfail( - "normal", - reason=onnx_test_common.reason_dynamo_does_not_support("wrapper_set_seed"), - ), - xfail( - "normal", - variant_name="in_place", - reason=onnx_test_common.reason_dynamo_does_not_support("wrapper_set_seed"), - ), - xfail( - "normal", - variant_name="number_mean", - reason=onnx_test_common.reason_dynamo_does_not_support("wrapper_set_seed"), - ), - xfail( - "ones", - dtypes=onnx_test_common.COMPLEX_TYPES, - reason="fixme: kwargs dtpye=complex64 is not supported in ONNX." - ), - xfail( - "pca_lowrank", - reason=onnx_test_common.reason_dynamo_does_not_support("wrapper_set_seed"), - ), - xfail( - "quantile", - reason=onnx_test_common.reason_dynamo_does_not_support("aten.equal.default") - ), - xfail( - "rand_like", - reason=onnx_test_common.reason_dynamo_does_not_support("wrapper_set_seed"), - ), - xfail( - "randint", - reason=onnx_test_common.reason_dynamo_does_not_support("wrapper_set_seed"), - ), - xfail( - "randint_like", - reason=onnx_test_common.reason_dynamo_does_not_support("wrapper_set_seed"), - ), - xfail( - "randn", - reason=onnx_test_common.reason_dynamo_does_not_support("wrapper_set_seed"), - ), - xfail( - "randn_like", - reason=onnx_test_common.reason_dynamo_does_not_support("wrapper_set_seed"), - ), - xfail( - "resize_", - reason=onnx_test_common.reason_dynamo_does_not_support("resize_as_") - ), - xfail( - "resize_as_", - reason=onnx_test_common.reason_dynamo_does_not_support("resize_as_") - ), - xfail( - "round", - dtypes=onnx_test_common.INT_TYPES, - reason=onnx_test_common.reason_onnx_runtime_does_not_support("Round", "int"), - ), - xfail( - "rsub", - dtypes=(torch.uint8, torch.int8, torch.int16), - reason=onnx_test_common.reason_onnx_runtime_does_not_support( - "Mul", "uint8, int8, int16" - ), - ), - xfail( - "scatter_add", - dtypes=(torch.float16,), - reason=onnx_test_common.reason_onnx_runtime_does_not_support("ScatterElements reduction=sum", "float16"), - ), - xfail( - "scatter_reduce", - variant_name="sum", - dtypes=(torch.float16,), - reason=onnx_test_common.reason_onnx_runtime_does_not_support("ScatterElements reduction=sum", "float16"), - ), - xfail( - "scatter_reduce", - variant_name="prod", - dtypes=(torch.float16,), - reason=onnx_test_common.reason_onnx_runtime_does_not_support("ScatterElements reduction=prod", "float16"), - ), - xfail( - "scatter_reduce", - variant_name="amin", - dtypes=onnx_test_common.BOOL_TYPES + (torch.float16,), - reason=onnx_test_common.reason_onnx_runtime_does_not_support("ScatterElements reduction=amin", "float16"), - ), - xfail( - "scatter_reduce", - variant_name="amax", - dtypes=onnx_test_common.BOOL_TYPES + (torch.float16,), - reason=onnx_test_common.reason_onnx_runtime_does_not_support("ScatterElements reduction=amax", "float16"), - ), - xfail( - "scatter_reduce", - variant_name="mean", - reason="ONNX doesn't support reduce='mean' option", - ), - xfail( - "sgn", - dtypes=onnx_test_common.BOOL_TYPES, - reason=onnx_test_common.reason_onnx_script_does_not_support("Sign", "bool"), - ), - xfail( - "sign", - dtypes=onnx_test_common.BOOL_TYPES, - reason=onnx_test_common.reason_onnx_script_does_not_support("Sign", "bool"), - ), - xfail( - "signal.windows.kaiser", - reason=onnx_test_common.reason_dynamo_does_not_support("functionalization"), - ), - xfail( - "softmax", - dtypes=(torch.float16,), - reason="ORT error: https://github.com/microsoft/onnxruntime/issues/16438" - ), - xfail( - "sparse.mm", - variant_name="reduce", - reason=onnx_test_common.reason_dynamo_does_not_support("InternalTorchDynamoError: Sparse CSR tensors do not have strides"), - ), - xfail( - "sparse.sampled_addmm", - reason=onnx_test_common.reason_dynamo_does_not_support("InternalTorchDynamoError: Sparse CSR tensors do not have strides"), - ), - xfail( - "special.erfcx", - dtypes=onnx_test_common.INT_TYPES + onnx_test_common.BOOL_TYPES, - reason=onnx_test_common.reason_onnx_runtime_does_not_support("Erf", "int, bool"), - ), - xfail( - "special.erfcx", - dtypes=onnx_test_common.FLOAT_TYPES, - reason=onnx_test_common.reason_onnx_script_does_not_support("Erfcx"), - ), - xfail( - "special.log_ndtr", - dtypes=onnx_test_common.INT_TYPES + onnx_test_common.FLOAT_TYPES, - reason="fixme: Assertion error: result mismatch", - ), - xfail( - "special.ndtr", - dtypes=(torch.float16,), - reason="fixme: Assertion error: result mismatch", - ), - xfail( - "square", - dtypes=(torch.int8, torch.uint8, torch.int16), - reason=onnx_test_common.reason_onnx_runtime_does_not_support("Pow", "int8, uint8, int16"), - ), - xfail( - "squeeze", - variant_name="multiple", - reason="fixme: https://github.com/microsoft/onnxscript/issues/1264", - ), - xfail( - "svd_lowrank", - reason=onnx_test_common.reason_dynamo_does_not_support("wrapper_set_seed"), - ), - xfail( - "stft", - reason=onnx_test_common.reason_dynamo_does_not_support("aten._fft_r2c.default"), - ), - xfail( - "sub", - dtypes=(torch.uint8, torch.int8, torch.int16), - reason=onnx_test_common.reason_onnx_runtime_does_not_support( - "Mul", "uint8, int8, int16" - ), - ), - xfail( - "take", - reason=onnx_test_common.reason_dynamo_does_not_support("data-dependent"), - ), - xfail( - "tensor_split", - reason=onnx_test_common.reason_dynamo_does_not_support("data-dependent"), - ), - xfail( - "topk", - dtypes=(torch.int64, torch.int32, torch.float16), - reason="fixme: Assertion error: result mismatch", - ), - xfail( - "tril", - dtypes=onnx_test_common.BOOL_TYPES + (torch.int32,), - reason=onnx_test_common.reason_onnx_runtime_does_not_support("trilu", "bool, int32"), - ), - xfail( - "triu", - dtypes=onnx_test_common.BOOL_TYPES + (torch.int32,), - reason=onnx_test_common.reason_onnx_runtime_does_not_support("trilu", "bool, int32"), - ), - xfail( - "trunc", - dtypes=onnx_test_common.INT_TYPES, - reason=onnx_test_common.reason_onnx_does_not_support("Floor", "int"), - ), - xfail( - "unflatten", - dtypes=onnx_test_common.BOOL_TYPES, - reason=onnx_test_common.reason_onnx_does_not_support("Unflatten") - ), - xfail( - "uniform", - reason=onnx_test_common.reason_dynamo_does_not_support("wrapper_set_seed"), - ), - xfail( - "unique", - reason=onnx_test_common.reason_dynamo_does_not_support("aten.unique_consecutive.default"), - ), - xfail( - "unique_consecutive", - reason=onnx_test_common.reason_dynamo_does_not_support("aten.unique_consecutive.default"), - ), - xfail( - "unravel_index", - dtypes=onnx_test_common.BOOL_TYPES + onnx_test_common.INT_TYPES, - reason=onnx_test_common.reason_onnx_script_does_not_support("Floor", "bool, int"), - ), - xfail( - "unsqueeze_copy", - reason="OnnxExporterError: Failed to export model", - dtypes=(torch.int8, torch.uint8, torch.int16), - ), - xfail( - "where", - dtypes=onnx_test_common.BOOL_TYPES, - reason=onnx_test_common.reason_onnx_runtime_does_not_support("Where", "bool"), - ), - xfail( - "zeros", - dtypes=onnx_test_common.COMPLEX_TYPES, - reason="fixme: kwargs dtpye=complex64 is not supported in ONNX." - ), - # SLOW TESTS (All are xfails if we run them) - # TODO: https://github.com/pytorch/pytorch/issues/117118 - skip_slow( - "cdist", - reason="fixme: Test sets are too many.", - ), - skip_slow( - "histogram", - reason="fixme: Test sets are too many.", - ), - skip_slow( - "histogramdd", - reason="fixme: Test sets are too many.", - ), - skip_slow( - "linalg.lu_solve", - reason="fixme: Test sets are too many.", - ), - skip_slow( - "linalg.solve_triangular", - reason="fixme: Test sets are too many.", - ), - skip_slow( - "linalg.svd", - reason="fixme: Test sets are too many.", - ), - skip_slow( - "logspace", - reason="fixme: Test sets are too many.", - ), - skip_slow( - "logspace", - variant_name="tensor_overload", - reason="fixme: Test sets are too many.", - ), - skip_slow( - "max_pool2d_with_indices_backward", - reason="fixme: Test sets are too many.", - ), - skip_slow( - "nn.functional.interpolate", - variant_name="bicubic", - reason="fixme: Test sets are too many.", - ), - skip_slow( - "nn.functional.max_unpool1d", - reason="fixme: Test sets are too many.", - ), - skip_slow( - "nn.functional.max_unpool2d", - reason="fixme: Test sets are too many.", - ), - skip_slow( - "nn.functional.max_unpool3d", - reason="fixme: Test sets are too many.", - ), - skip_slow( - "nn.functional.max_pool1d", - reason="fixme: Test sets are too many.", - ), - skip_slow( - "nn.functional.max_pool2d", - reason="fixme: Test sets are too many.", - ), - skip_slow( - "nn.functional.max_pool3d", - reason="fixme: Test sets are too many.", - ), - skip_slow( - "nn.functional.unfold", - reason="fixme: Test sets are too many.", - ), - skip_slow( - "ormqr", - reason="fixme: Test sets are too many.", - ), - skip_slow( - "searchsorted", - reason="fixme: Test sets are too many.", - ), - skip_slow( - "svd", - reason="fixme: Test sets are too many.", - ), -) -# fmt: on - -# NOTE: The xfail and skip with a matcher function or model_type should be -# at under the `SKIP_XFAIL_SUBTESTS_WITH_MATCHER_AND_MODEL_TYPE` section. -SKIP_XFAIL_SUBTESTS_WITH_MATCHER_AND_MODEL_TYPE: tuple[ - onnx_test_common.DecorateMeta, ... -] = ( - skip( - "_native_batch_norm_legit", - model_type=pytorch_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM, - reason="https://github.com/pytorch/pytorch/issues/115106", - ), - skip( - "_batch_norm_with_update", - model_type=pytorch_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM, - reason="https://github.com/pytorch/pytorch/issues/115106", - ), - # TODO: This test currently fails only for certain inputs, e.g. shape([3, 1]). - # Numerically the ONNX program is correct, but the output shapes for `save_mean` - # and `save_var` were tensor(-2.1268) instead of the correct tensor([-2.1268]) - # for example. - skip( - "_batch_norm_with_update", - model_type=pytorch_test_common.TorchModelType.TORCH_NN_MODULE, - reason="not supported yet", - ), - xfail( - "addmm", # xfail can't only use dtypes to catch all cases - matcher=lambda sample: sample.input.dtype - in (torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64), - reason=onnx_test_common.reason_onnx_runtime_does_not_support( - "Gemm", "uint8, int8, int16, int32, int64" - ), - ), - xfail( - "addmm", - matcher=lambda sample: sample.args[0].numel() == 0, - reason="ONNX Runtime does not support empty tensors multiplication", - ), - xfail( - "addmm", - variant_name="decomposed", - matcher=lambda sample: sample.args[0].numel() == 0, - reason="ONNX Runtime does not support empty tensors multiplication", - ), - xfail( - "amax", - matcher=lambda sample: len(sample.input.shape) == 0 - and (sample.kwargs.get("dim") is not None and sample.kwargs.get("dim") != ()), - reason="Op (ReduceMax) [ShapeInferenceError] axis must be in [-rank, rank-1]. input rank was 0", - ), - xfail( - "amin", - matcher=lambda sample: len(sample.input.shape) == 0 - and (sample.kwargs.get("dim") is not None and sample.kwargs.get("dim") != ()), - reason="Op (ReduceMin) [ShapeInferenceError] axis must be in [-rank, rank-1]. input rank was 0", - ), - xfail( - "aminmax", - matcher=lambda sample: len(sample.input.shape) == 0 - and sample.kwargs.get("dim") is not None, - reason="Op (ReduceMin) [ShapeInferenceError] axis must be in [-rank, rank-1]. input rank was 0", - ), - skip( - "cat", - matcher=lambda sample: sample.input[0].equal(torch.tensor([])), - reason="core dump - cat does not support zero-dim tensors yet", - ), - xfail( - "index_add", - matcher=lambda sample: len(sample.input.shape) == 0, - reason=onnx_test_common.reason_onnx_runtime_does_not_support( - "ScatterND", "0-D tensor" - ), - ), - xfail( - "index_add", - matcher=lambda sample: isinstance(sample.args[0], int) and sample.args[0] == -1, - reason="fixme: aten::index_put indices contains None when dim is -1", - ), - xfail( - "index_copy", - matcher=lambda sample: len(sample.input.shape) == 0, - reason=onnx_test_common.reason_onnx_runtime_does_not_support( - "ScatterND", "0-D tensor" - ), - ), - xfail( - "index_copy", - matcher=lambda sample: isinstance(sample.args[0], int) and sample.args[0] == -1, - reason="fixme: aten::index_put indices contains None when dim is -1", - ), - xfail( - "index_put", - matcher=lambda sample: (sample.args[0][0].dtype == torch.bool) - and (sample.kwargs.get("accumulate") is False), - reason=onnx_test_common.reason_dynamo_does_not_support( - "https://github.com/pytorch/pytorch/issues/101150" - ), - ), - skip( - "linalg.multi_dot", - matcher=lambda sample: sum(torch.numel(input) for input in sample.input) == 0, - reason="fixme: Undefined", - ), - skip( - "log_softmax", - matcher=lambda sample: len(sample.input.shape) == 0, - reason="fixme: LogSoftMax does not support empty tensor as input", - ), - skip( - "log_softmax", - variant_name="with_dtype", - matcher=lambda sample: len(sample.input.shape) == 0, - reason="fixme: LogSoftMax does not support empty tensor as input", - ), - skip( - "masked.log_softmax", - matcher=lambda sample: len(sample.input.shape) == 0, - reason="fixme: LogSoftMax does not support empty tensor as input", - ), - skip( - "matmul", - matcher=lambda sample: torch.numel(sample.input) == 0, - reason="values of matmul of [m, 0] and [0, n] matrices are undefined", - ), - skip( - "mm", - matcher=lambda sample: torch.numel(sample.input) == 0, - reason="values of matmul of [m, 0] and [0, n] matrices are undefined", - ), - xfail( - "native_batch_norm", - matcher=lambda sample: sample.args[-3] is True - and any(arg is not None for arg in sample.args[2:4]), - model_type=pytorch_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM, - reason="https://github.com/pytorch/pytorch/issues/115106", - ), - xfail( - "nn.functional.avg_pool1d", - matcher=lambda sample: (sample.kwargs.get("ceil_mode") is True) - and ( - sample.kwargs.get("count_include_pad") is True - or sample.input.shape[2] - % ( - sample.args[0][0] - if isinstance(sample.args[0], tuple) - else sample.args[0] - ) - != 0 - ), - reason="fixme: ORT doesn't match PyTorch when ceil_mode=True until opset 19", - ), - xfail( - "nn.functional.avg_pool2d", - matcher=lambda sample: (len(sample.args) > 5 and sample.args[5] is not None) - or (sample.kwargs.get("divisor_override") is not None), - reason="ONNX doesn't support divisor_override argument", - ), - xfail( - "nn.functional.avg_pool3d", - matcher=lambda sample: sample.kwargs.get("ceil_mode") is True, - reason="fixme: ORT doesn't match PyTorch when ceil_mode=True until opset 19", - ), - xfail( - "nn.functional.avg_pool3d", - matcher=lambda sample: (len(sample.args) > 5 and sample.args[5] is not None) - or (sample.kwargs.get("divisor_override") is not None), - reason="ONNX doesn't support divisor_override argument", - ), - xfail( - "nn.functional.batch_norm", - matcher=lambda sample: sample.kwargs.get("training") is True - and any(arg is not None for arg in sample.args[2:4]), - reason="Flaky failure: https://github.com/pytorch/pytorch/issues/115106", - ), - xfail( - "nn.functional.conv2d", - matcher=lambda sample: sample.kwargs.get("padding") == "valid", - reason="fixme: https://github.com/pytorch/pytorch/issues/117054", - ), - xfail( - "nn.functional.conv3d", - matcher=lambda sample: sample.kwargs.get("padding") == "valid", - reason="fixme: https://github.com/pytorch/pytorch/issues/117054", - ), - skip( - "nn.functional.cross_entropy", - matcher=lambda sample: not isinstance(sample.kwargs.get("weight"), int), - reason="ONNX SoftmaxCrossEntropyLoss op only accept argument[weight] is int type", - ), - xfail( - "nn.functional.embedding", - matcher=lambda sample: sample.kwargs.get("max_norm") is not None, - model_type=pytorch_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM, - reason="https://github.com/pytorch/pytorch/issues/115106", - ), - skip_torchlib_forward_compatibility( - "nn.functional.embedding_bag", - matcher=lambda sample: sample.kwargs.get("padding_idx") is not None or True, - reason=onnx_test_common.reason_onnx_script_does_not_support( - "'padding_idx' overload for _embedding_bag and _embedding_bag_forward_only. " - "'padding_idx=-1' is emitted for aten op when 'padding_idx' is not provided" - ), - github_issue="https://github.com/microsoft/onnxscript/issues/1056", - ), - xfail( - "nn.functional.group_norm", - matcher=lambda sample: torch.numel(sample.input) == 0, - reason=onnx_test_common.reason_onnx_runtime_does_not_support( - "Reshape", "empty tensor" - ), - ), - xfail( - "nn.functional.instance_norm", - model_type=pytorch_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM, - matcher=lambda sample: sample.kwargs.get("running_mean") is not None, - reason="fixme: KeyError: 'self___kwargs__running_mean'", - ), - xfail( - "nn.functional.max_pool3d", - matcher=lambda sample: sample.kwargs.get("ceil_mode") is True - and sample.kwargs.get("padding") == 1, - reason="FIXME: After https://github.com/microsoft/onnxruntime/issues/15446 is fixed", - ), - xfail( - "nn.functional.pixel_shuffle", - matcher=lambda sample: sample.input.numel() == 0, - reason="fixme: ORT does not support empty tensor as input", - ), - xfail( - "nonzero", - matcher=lambda sample: len(sample.input.shape) == 0 - and sample.kwargs.get("as_tuple", False) is False, - reason="Output 'shape' do not match: torch.Size([0, 1]) != torch.Size([0, 0]).", - model_type=pytorch_test_common.TorchModelType.TORCH_NN_MODULE, - ), - xfail( - "scatter_add", - matcher=lambda sample: len(sample.input.shape) == 0, - reason="fixme: Rank(0) input will lead ORT failed due to different rank(result) in if-else branch", - ), - skip( - "scatter_reduce", - variant_name="amax", - # ONNX has not include_self parameter and default is include_self=True mode - matcher=lambda sample: sample.kwargs.get("include_self") is False, - reason="ONNX does't support include_self=False option", - ), - skip( - "scatter_reduce", - variant_name="amin", - # ONNX has not include_self parameter and default is include_self=True mode - matcher=lambda sample: sample.kwargs.get("include_self") is False, - reason="ONNX does't support include_self=False option", - ), - skip( - "scatter_reduce", - variant_name="prod", - # ONNX has not include_self parameter and default is include_self=True mode - matcher=lambda sample: sample.kwargs.get("include_self") is False, - reason="ONNX does't support include_self=False option", - ), - skip( - "scatter_reduce", - variant_name="sum", - # ONNX has not include_self parameter and default is include_self=True mode - matcher=lambda sample: sample.kwargs.get("include_self") is False, - reason="ONNX does't support include_self=False option", - ), - skip( - "softmax", - matcher=lambda sample: len(sample.input.shape) == 0, - reason="fixme: LogSoftMax does not support empty tensor as input", - ), - xfail( - "unflatten", - reason="Logic not implemented for size 0 inputs in op.Reshape", - matcher=lambda sample: any(dim == 0 for dim in sample.input.shape), - ), - skip( - "signal.windows.hamming", - model_type=pytorch_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM, - reason="does not match node name", - ), - skip( - "signal.windows.general_hamming", - model_type=pytorch_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM, - reason="does not match node name", - ), - skip( - "signal.windows.blackman", - model_type=pytorch_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM, - reason="does not match node name", - ), - skip( - "signal.windows.general_cosine", - model_type=pytorch_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM, - reason="does not match node name", - ), - skip( - "signal.windows.hann", - model_type=pytorch_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM, - reason="does not match node name", - ), - skip( - "signal.windows.nuttall", - model_type=pytorch_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM, - reason="does not match node name", - ), -) - -OPS_DB = copy.deepcopy(common_methods_invocations.op_db) -OP_WITH_SKIPPED_XFAIL_SUBTESTS = frozenset( - meta.op_name for meta in SKIP_XFAIL_SUBTESTS_WITH_MATCHER_AND_MODEL_TYPE -) -ALL_OPS_IN_DB = frozenset(op_info.name for op_info in OPS_DB) - - -def _torch_size_flatten_spec(d: List[Any], spec: Any) -> List[Any]: - return [d[i] for i in range(spec.num_children)] - - -torch.fx._pytree.register_pytree_flatten_spec( - torch.Size, - _torch_size_flatten_spec, -) - - -class SingleOpModel(torch.nn.Module): - """Test model to wrap around a single op for export.""" - - def __init__(self, op, kwargs): - super().__init__() - self.operator = op - self.kwargs = kwargs - - def forward(self, *args): - return self.operator(*args, **self.kwargs) - - -def _should_skip_xfail_test_sample( - op_name: str, - variant_test_name: str, - sample, - model_type: pytorch_test_common.TorchModelType, -) -> Tuple[Optional[str], Optional[str]]: - """Check if the test sample should be skipped or xfailed. - - If the xfail/skip decorator meta is matched with its op_name and model_type, - return the test_behavior and reason. Otherwise, return None, None. Note that - if the matcher is None, the test is decorator_meta is meant to skip/xfail all model types. - - Args: - op_name: The name of the op. - sample: The test sample. - model_type: The model type of the test. - - Returns: - A tuple of (test_behavior, reason). test_behavior is either "skip" or "xfail". - reason is the reason for the test_behavior. - """ - - if op_name not in OP_WITH_SKIPPED_XFAIL_SUBTESTS: - return None, None - for decorator_meta in SKIP_XFAIL_SUBTESTS_WITH_MATCHER_AND_MODEL_TYPE: - # Linear search on ops_test_data.SKIP_XFAIL_SUBTESTS_WITH_MATCHER_AND_MODEL_TYPE. That's fine because the list is small. - # NOTE: If model_type is None, the test is decorator_meta is meant to skip/xfail all model types. - if ( - decorator_meta.op_name == op_name - and decorator_meta.variant_name == variant_test_name - ) and ( - model_type == decorator_meta.model_type or decorator_meta.model_type is None - ): - if decorator_meta.matcher is None and decorator_meta.model_type is None: - raise TypeError( - "Either Matcher or model_type must be defined in sub xfail and skip." - ) - if decorator_meta.matcher is not None and decorator_meta.matcher(sample): - return decorator_meta.test_behavior, decorator_meta.reason - elif decorator_meta.matcher is None: - # xfail/skip the whole test of the model type without matcher - return decorator_meta.test_behavior, decorator_meta.reason - return None, None - - -def _compare_onnx_and_torch_exported_program( - torch_exported_program, - onnx_exported_program, - input_args, - input_kwargs=None, - test_name=None, - sample_num=None, - sample_kwargs=None, - rtol=1e-03, - atol=1e-07, - only_check_shape=False, -): - # avoid mutable default argument - if input_kwargs is None: - input_kwargs = {} - - # NOTE: ONNXProgram holds a reference (not copy) to the original ref_model, including its state_dict. - # Thus, ONNXProgram() must run before ref_model() to prevent ref_model.forward() from changing the state_dict. - # Otherwise, the ref_model can change buffers on state_dict which would be used by ONNXProgram.__call__() - onnx_outputs = onnx_exported_program(*input_args, **input_kwargs) - if isinstance(torch_exported_program, torch.export.ExportedProgram): - torch_outputs = torch_exported_program.module()(*input_args, **input_kwargs) - else: - torch_outputs = torch_exported_program(*input_args, **input_kwargs) - torch_outputs_onnx_format = onnx_exported_program.adapt_torch_outputs_to_onnx( - torch_outputs - ) - if len(torch_outputs_onnx_format) != len(onnx_outputs): - raise AssertionError( - f"Expected {len(torch_outputs_onnx_format)} outputs, got {len(onnx_outputs)}" - ) - - for j, (torch_output, onnx_output) in enumerate( - zip(torch_outputs_onnx_format, onnx_outputs) - ): - if only_check_shape: - assert torch_output.shape == onnx_output.shape - else: - try: - torch.testing.assert_close( - torch.tensor(onnx_output), - torch_output, - rtol=rtol, - atol=atol, - equal_nan=True, - ) - except AssertionError as e: - if os.environ.get("CREATE_REPRODUCTION_REPORT") == "1": - error_reproduction.create_mismatch_report( - test_name, - sample_num, - onnx_exported_program.model_proto, - input_args, - sample_kwargs, - torch.tensor(onnx_output), - torch_output, - e, - ) - if len(torch_outputs_onnx_format) > 1: - raise AssertionError(f"Output {j} mismatch") from e - raise - - -def _run_test_output_match( - test_suite: onnx_test_common._TestONNXRuntime, - device: str, - dtype: torch.dtype, - op: opinfo_core.OpInfo, -): - # device is provided by instantiate_device_type_tests, but we only want to run in cpu. - assert device == "cpu" - samples = op.sample_inputs( - device, - dtype, - requires_grad=False, - ) - for i, cpu_sample in enumerate(samples): - inputs = (cpu_sample.input, *cpu_sample.args) - # Provide the repr to subtest because tensors are not serializable in parallel test runs - - with test_suite.subTest( - opset=test_suite.opset_version, - sample_num=i, - inputs=repr(inputs), - kwargs=repr(cpu_sample.kwargs), - ): - test_behavior, reason = _should_skip_xfail_test_sample( - op.name, op.variant_test_name, cpu_sample, test_suite.model_type - ) - with onnx_test_common.normal_xfail_skip_test_behaviors( - test_behavior, reason - ): - model = SingleOpModel(op.op, cpu_sample.kwargs) - model.eval() - - if ( - dtype == torch.float32 - and op.name in test_suite.fp32_low_precision_dict - ): - rtol = test_suite.fp32_low_precision_dict[op.name][0] - atol = test_suite.fp32_low_precision_dict[op.name][1] - elif dtype == torch.float32: - # Relax atol and rtol for float32 based on empirical results - rtol = 1e-5 - atol = 2e-5 - elif ( - dtype == torch.float16 - and (op.name, op.variant_test_name) - in test_suite.fp16_low_precision_variant_dict - ): - rtol = test_suite.fp16_low_precision_variant_dict[ - (op.name, op.variant_test_name) - ][0] - atol = test_suite.fp16_low_precision_variant_dict[ - (op.name, op.variant_test_name) - ][1] - elif ( - dtype == torch.float16 - and op.name in test_suite.fp16_low_precision_dict - ): - rtol = test_suite.fp16_low_precision_dict[op.name][0] - atol = test_suite.fp16_low_precision_dict[op.name][1] - else: - rtol = None - atol = None - - if ( - test_suite.model_type - == pytorch_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM - ): - try: - # TODO (tugsbayasgalan) Migrate to pre-dispatch IR - # BUG1: python test/onnx/test_fx_op_consistency.py -k test_output_match_triu_cpu_int32 - # has unexpected success, but don't know how to remove from xfail list - # BUG2: User output to_sparse is not in the correct order or is not found in the - # exported program's user_output list (https://github.com/pytorch/pytorch/issues/124328) - # python test/onnx/test_fx_op_consistency.py -k test_output_match_to_sparse_cpu_float32 - # BUG3: [ShapeInferenceError] Inference error(s): (op_type:aten_view, node name: aten_view_4): - # [ShapeInferenceError] - # Inference error(s): (op_type:Reshape, node name: n1): [ShapeInferenceError] Invalid position of 0. - # python test/onnx/test_fx_op_consistency.py -k test_output_match_stack_cpu_int32 - from torch.export import _trace - - model = _trace._export(model, inputs, pre_dispatch=False) - - except AssertionError as e: - # NOTE: avoid fake_mode detection bug in torch.export.export - pytest.xfail( - onnx_test_common.reason_dynamo_does_not_support(str(e)) - ) - - try: - onnx_program = torch.onnx.dynamo_export( - model, - *inputs, - ) - except torch.onnx.OnnxExporterError as e: - # NOTE: If the model has unsupported nodes, we will skip the test - # with non-strict xfail. Otherwise, we will raise the error. - if hasattr( - e.__cause__, "diagnostic" - ) and e.__cause__.diagnostic.rule in ( - _rules._POERules.no_symbolic_function_for_call_function, - _rules._POERules.unsupported_fx_node_analysis, - ): - pytest.xfail( - onnx_test_common.reason_onnx_script_does_not_support(str(e)) - ) - else: - raise e - _compare_onnx_and_torch_exported_program( - model, - onnx_program, - inputs, - test_name=test_suite.id(), - sample_num=i, - sample_kwargs=cpu_sample.kwargs, - rtol=rtol, - atol=atol, - only_check_shape=(op.name in test_suite.only_shape_check_list), - ) - - -def _parameterized_class_attrs_and_values(): - input_values = [] - input_values.extend( - itertools.product( - (opset for opset in onnx_test_common.FX_TESTED_OPSETS), - ( - pytorch_test_common.TorchModelType.TORCH_NN_MODULE, - pytorch_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM, - ), - ) - ) - return { - "attrs": ["opset_version", "model_type"], - "input_values": input_values, - } - - -def _parameterize_class_name(cls: Type, idx: int, input_dicts: Mapping[Any, Any]): - """Combine class name with the parameterized arguments. - - This function is passed to `parameterized.parameterized_class` as the - `class_name_func` argument. - """ - suffixes = [] - for k, v in input_dicts.items(): - suffixes.append(f"{k}_{v}") - return f"{cls.__name__}_{'_'.join(suffixes)}" - - -@parameterized.parameterized_class( - **_parameterized_class_attrs_and_values(), - class_name_func=_parameterize_class_name, -) -class TestOnnxModelOutputConsistency(onnx_test_common._TestONNXRuntime): - """Test output consistency between exported ONNX models and PyTorch eager mode. - - This is a parameterized test suite. - """ - - opset_version = -1 - dynamic_shapes: bool = False - model_type: pytorch_test_common.TorchModelType = ( - pytorch_test_common.TorchModelType.TORCH_NN_MODULE - ) - - # NOTE: Follow torchlib settings in ops_test_data.py - only_shape_check_list = [ - "empty", - "empty_like", - "empty_strided", - "new_empty", - "new_empty_strided", - ] - - fp32_low_precision_dict = { - "native_layer_norm": [2e-4, 7e-4], - } - - fp16_low_precision_dict = { - "addbmm": [2e-1, 2e-2], - "addcdiv": [3e-2, 1.4e-3], - "addcmul": [3e-2, 1e-3], - "addmv": [5e-2, 3e-2], - "addr": [3e-3, 4e-3], - "baddbmm": [3e-2, 1e-3], - "cumulative_trapezoid": [3e-2, 1e-3], - "cross": [3e-2, 2e-2], - "diff": [1e-2, 5e-2], - "div": [5e-3, 1e-3], - "gradient": [3e-3, 4e-3], - "linalg.cross": [1e-3, 2e-2], - "linalg.multi_dot": [3e-2, 1e-3], - "linalg.vecdot": [1e-2, 2e-2], - "linspace": [2e-2, 2e-3], - "masked.std": [2e-2, 2e-3], - "masked.var": [2e-2, 2e-2], - "matmul": [2e-2, 6e-2], - "mv": [9e-3, 1e-5], - "nn.functional.batch_norm": [3e-2, 1e-3], - "nn.functional.binary_cross_entropy": [3e-2, 1e-3], - "nn.functional.binary_cross_entropy_with_logits": [4e-2, 4e-3], - "nn.functional.cosine_similarity": [3e-2, 1e-3], - "nn.functional.cosine_embedding_loss": [1e-2, 1e-3], - "nn.functional.hardsigmoid": [1e-3, 5e-3], - "nn.functional.hardswish": [1e-3, 5e-3], - "nn.functional.hinge_embedding_loss": [4e-1, 3e-3], - "nn.functional.huber_loss": [1e-2, 1e-1], - "nn.functional.instance_norm": [1e-2, 1e-3], - "nn.functional.interpolate": [1e-2, 1e-3], - "nn.functional.kl_div": [2e-3, 2e-4], - "nn.functional.multilabel_soft_margin_loss": [4e-2, 5e-3], - "nn.functional.local_response_norm": [1e-2, 5e-3], - "nn.functional.poisson_nll_loss": [4e-2, 6e-3], - "nn.functional.nll_loss": [3e-2, 1e-3], - "nn.functional.triplet_margin_loss": [2e-2, 1e-2], - "nn.functional.triplet_margin_with_distance_loss": [3e-2, 1e-2], - "native_batch_norm": [3e-2, 1e-3], - "norm": [1e-2, 1e-2], - "dot": [3e-2, 1e-3], - "logit": [3e-2, 1e-3], - "rsub": [3e-2, 1e-3], - "sinc": [2e-1, 6e-4], - "sub": [3e-2, 1e-3], - "trapezoid": [1e-3, 7e-3], - "trapz": [1e-3, 7e-3], - "vdot": [1e-3, 1e-2], - } - - fp16_low_precision_variant_dict = { - ("nn.functional.interpolate", "trilinear"): [3e-2, 3e-3], - ("nn.functional.interpolate", "linear"): [3e-2, 3e-3], - } - - @common_device_type.ops( - [op for op in OPS_DB if op.name in ALL_OPS_IN_DB], - allowed_dtypes=onnx_test_common.TESTED_DTYPES, - ) - def test_output_match(self, device: str, dtype: torch.dtype, op): - """Test the ONNX exporter.""" - _run_test_output_match(self, device, dtype, op) - - -for opset in onnx_test_common.FX_TESTED_OPSETS: - for model_type in pytorch_test_common.TorchModelType: - # The name needs to match the parameterized_class name. - test_class_name = f"TestOnnxModelOutputConsistency_opset_version_{opset}_model_type_TorchModelType.{model_type.name}" - onnx_test_common.add_decorate_info( - OPS_DB, - test_class_name, - "test_output_match", - opset=opset, - skip_or_xfails=EXPECTED_SKIPS_OR_FAILS_WITH_DTYPES, - ) - - common_device_type.instantiate_device_type_tests( - globals()[test_class_name], globals(), only_for="cpu" - ) - -if __name__ == "__main__": - common_utils.run_tests() diff --git a/test/onnx/test_fx_to_onnx.py b/test/onnx/test_fx_to_onnx.py index 04f132abc9c9aa..5de711f181f4d3 100644 --- a/test/onnx/test_fx_to_onnx.py +++ b/test/onnx/test_fx_to_onnx.py @@ -200,34 +200,6 @@ def forward(self, input): expected_node="aten.convolution.default", ) - def test_dispatch_overload_fall_back_default_raise_diagnostic_warning(self): - class TraceModel(torch.nn.Module): - def forward(self, input): - return torch.ops.aten.add.Tensor(input, input) - - onnx_registry = torch.onnx.OnnxRegistry() - self.assertTrue( - onnx_registry.is_registered_op( - namespace="aten", op_name="add", overload="Tensor" - ) - ) - - aten_add_Tensor = registration.OpName.from_name_parts( - namespace="aten", op_name="add", overload="Tensor" - ) - onnx_registry._registry.pop(aten_add_Tensor) - - x = torch.tensor(3) - onnx_program = dynamo_export( - TraceModel(), x, export_options=ExportOptions(onnx_registry=onnx_registry) - ) - assert_has_diagnostics( - onnx_program.diagnostic_context, - diagnostics.rules.find_operator_overloads_in_onnx_registry, - diagnostics.levels.WARNING, - expected_node="aten.add.Tensor", - ) - def test_aten_clone_does_not_raise_warning_of_lack_of_memory_format(self): class CustomModule(torch.nn.Module): def forward(self, input): diff --git a/test/onnx/test_fx_to_onnx_with_onnxruntime.py b/test/onnx/test_fx_to_onnx_with_onnxruntime.py index 150edfef3412d1..dd53ded23cdbe3 100644 --- a/test/onnx/test_fx_to_onnx_with_onnxruntime.py +++ b/test/onnx/test_fx_to_onnx_with_onnxruntime.py @@ -478,6 +478,9 @@ def forward(self, x): MutationModel(), (torch.randn(12),), has_mutation=True ) + @unittest.skip( + "Fixme: arange in torchlib does not support dynamic start and end yet." + ) def test_arange(self): class ArangeModel(torch.nn.Module): def forward(self, input): diff --git a/torch/onnx/_internal/_exporter_legacy.py b/torch/onnx/_internal/_exporter_legacy.py index 52f37460592cd4..c66c3492e98ba4 100644 --- a/torch/onnx/_internal/_exporter_legacy.py +++ b/torch/onnx/_internal/_exporter_legacy.py @@ -46,11 +46,8 @@ import onnx - import onnxruntime # type: ignore[import] - import onnxscript # type: ignore[import] - from onnxscript.function_libs.torch_lib import ( # type: ignore[import] - registration as torchlib_registry, - ) + import onnxruntime + import onnxscript from torch._subclasses import fake_tensor from torch.onnx._internal.fx import diagnostics @@ -110,10 +107,6 @@ def __init__(self) -> None: self._registry: dict[registration.OpName, list[registration.ONNXFunction]] = ( defaultdict(list) ) - # FIXME: Avoid importing onnxscript into torch - from onnxscript.function_libs.torch_lib import ( # type: ignore[import] # noqa: F401 - registration, - ) # opset_version is unused for now, since torchlib only supports opset18. # TODO: get opset version from torchlib @@ -123,8 +116,7 @@ def __init__(self) -> None: "different opset version, please register them with register_custom_op." ) - # Initialize registry from torchlib - self._initiate_registry_from_torchlib(registration.default_registry) + self._initiate_registry_from_torchlib() @property def opset_version(self) -> int: @@ -134,33 +126,25 @@ def opset_version(self) -> int: return self._opset_version - def _initiate_registry_from_torchlib( - self, torchlib_registry: torchlib_registry.Registry - ): + def _initiate_registry_from_torchlib(self) -> None: """Populates the registry with ATen functions from torchlib. Args: torchlib_registry: The torchlib registry to use for populating the registry. """ - for aten_name, aten_overloads_func in torchlib_registry.items(): - internal_name_instance = registration.OpName.from_qualified_name(aten_name) - for overload_func in aten_overloads_func.overloads: - symbolic_function = registration.ONNXFunction( - onnx_function=overload_func, - op_full_name=internal_name_instance.qualified_name(), - is_custom=False, - is_complex=False, - ) - self._register(internal_name_instance, symbolic_function) - - for complex_func in aten_overloads_func.complex: - symbolic_function = registration.ONNXFunction( - onnx_function=complex_func, - op_full_name=internal_name_instance.qualified_name(), - is_custom=False, - is_complex=True, - ) - self._register(internal_name_instance, symbolic_function) + import onnxscript._framework_apis.torch_2_5 as onnxscript_apis + + for meta in onnxscript_apis.get_torchlib_ops(): + internal_name_instance = registration.OpName.from_qualified_name( + meta.qualified_name + ) + symbolic_function = registration.ONNXFunction( + onnx_function=meta.function, # type: ignore[arg-type] + op_full_name=internal_name_instance.qualified_name(), + is_custom=False, + is_complex=meta.is_complex, + ) + self._register(internal_name_instance, symbolic_function) def _register( self, From 2febb21d07b9f751c90021f7e005856db4c17bb5 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Tue, 3 Sep 2024 15:40:45 +0000 Subject: [PATCH 0352/1018] Revert "[Inductor] Apply loop split optimization in codegen_node (#132389)" This reverts commit 3cb5d251224b3fb59b5a10c6fefbb4c84eb565a6. Reverted https://github.com/pytorch/pytorch/pull/132389 on behalf of https://github.com/ZainRizvi due to Hi, this seems to be breaking in trunk. See test_dataloader.py::TestDataLoader::test_segfault [GH job link](https://github.com/pytorch/pytorch/actions/runs/10660461216/job/29556282081) [HUD commit link](https://hud.pytorch.org/pytorch/pytorch/commit/de3a641476c857414d1f07ada685ca83e0097d64) ([comment](https://github.com/pytorch/pytorch/pull/132389#issuecomment-2326843129)) --- test/inductor/test_cpu_repro.py | 12 ---- torch/_inductor/codegen/cpp.py | 116 +------------------------------- torch/_inductor/ir.py | 13 ---- torch/_inductor/scheduler.py | 13 +--- 4 files changed, 6 insertions(+), 148 deletions(-) diff --git a/test/inductor/test_cpu_repro.py b/test/inductor/test_cpu_repro.py index d1f5a82cdf6c74..7dbe35f766f9a7 100644 --- a/test/inductor/test_cpu_repro.py +++ b/test/inductor/test_cpu_repro.py @@ -3435,18 +3435,6 @@ def forward(self, x): # 2 generated kernels (one for var_mean, the other for result) check_metrics_vec_kernel_count(2) - # check loop split optimization - if fmt == torch.channels_last: - torch._dynamo.reset() - metrics.reset() - with torch.no_grad(): - opt_mod = torch.compile(mod) - _, code = run_and_get_cpp_code(opt_mod, x) - # check that there are no non_contiguous loads - FileCheck().check_count("__at_align__ std::array", 0, exactly=True).run( - code - ) - def test_int_div_vec(self): def fn(x, y, mode): return torch.div(x, y, rounding_mode=mode) diff --git a/torch/_inductor/codegen/cpp.py b/torch/_inductor/codegen/cpp.py index 999df248fc6110..6663de7639e8e1 100644 --- a/torch/_inductor/codegen/cpp.py +++ b/torch/_inductor/codegen/cpp.py @@ -313,12 +313,9 @@ def visit_modular_indexing(divisor, modulus): @functools.lru_cache -def stride_at_vec_range( - index: sympy.Expr, var: sympy.Symbol, vec_length: Optional[int] = None -): - if vec_length: - index = simplify_index_in_vec_range(index, var, vec_length) - return stride_at(index, var) +def stride_at_vec_range(index: sympy.Expr, var: sympy.Symbol, vec_length: int): + index_vec_simplified = simplify_index_in_vec_range(index, var, vec_length) + return stride_at(index_vec_simplified, var) class OuterLoopFusedSchedulerNode(FusedSchedulerNode): @@ -4088,112 +4085,6 @@ def can_fuse_vertical(self, node1, node2): self._can_fuse_horizontal_impl(node1, node2) and not node1.is_reduction() ) or self.can_fuse_vertical_outer_loop(node1, node2) - def try_loop_split(self, nodes: List[SchedulerNode]): - """ - Apply loop split optimization. - When one of the indexing_exprs contains a division, we eliminate the division by splitting the loop - to avoid non-contiguous loads, subject to the following conditions: - 1. No reduction and no mudular index for all nodes. - 2. Only one node's one indexing_exprs contains a division, according to this indexing_exprs, - we can get the dimension that needs to be split, and the split dimension is contiguous - in all other indexing_exprs. - - For example, if the node's var_ranges: {z0: 2, z1: 9216, z2: 960} and indexing_exprs: - {'index0': 8847360*z0 + 960*z1 + z2, 'index1': 32*z0 + (z2//30), 'index2': z2}, - we will split z2 -> 30*z2 + z3, then the node's var_ranges will be changed to - {z0: 2, z1: 9216, z2: 32, z3: 30} and indexing_exprs will be changed to - {'index0': 8847360*z0 + 960*z1 + 30*z2 + z3, 'index1': 32*z0 + z2, 'index2': 30*z2 + z3}. - """ - - # No reduction and no mudular - if any( - len(node.group[1][1]) != 0 - or any( - expr.has(ModularIndexing) for expr in node._body.indexing_exprs.values() - ) - for node in nodes - ): - return nodes - - split_var = None - split_number = None - divide_index_name = None - num_div = 0 - match_div = False - matched_node = None - - for node in nodes: - assert isinstance(node.node, ir.ComputedBuffer) - _, original_body, _ = node.node.get_default_sizes_body() - for name, expr in original_body.indexing_exprs.items(): - num_div += expr.count(FloorDiv) - if num_div > 1: - return nodes - if expr.count(FloorDiv) == 1: - div_expr = expr.find(FloorDiv).pop() - split_var = div_expr.args[0] - split_number = div_expr.args[1] - divide_index_name = name - if ( - isinstance(split_number, sympy.core.numbers.Integer) - and isinstance(split_var, sympy.core.symbol.Symbol) - and divide_index_name is not None - and all( - stride_at_vec_range(expr, split_var) == 1 - for name, expr in original_body.indexing_exprs.items() - if name != divide_index_name - ) - ): - match_div = True - matched_node = node - - # Only one node contains a division, and the split dimension is contiguous in all other indexing_exprs. - if not match_div: - return nodes - - extra_indexing_constraints = None - - def loop_split(sizes, body, vars): - index_size, reduce_size = sizes - index_vars, reduce_vars = vars - split_idx = index_vars.index(split_var) - new_index_size = index_size.copy() - new_index_size[split_idx] = index_size[split_idx] // split_number - new_index_size.insert(split_idx + 1, split_number) - (new_index_vars, _), var_ranges = dependencies.index_vars_no_squeeze( - new_index_size, reduce_size, prefix="y" - ) - iter_vars = new_index_vars.copy() - divisor_var = iter_vars.pop(split_idx + 1) - iter_vars[split_idx] = split_number * iter_vars[split_idx] + divisor_var - body = ir.LoopBody( - body, [iter_vars, reduce_vars], var_ranges, new_index_vars, reduce_vars - ) - nonlocal extra_indexing_constraints - if not extra_indexing_constraints: - extra_indexing_constraints = ( - body.var_ranges, - list(body.indexing_exprs.values()), - ) - return ( - (new_index_size, reduce_size), - body, - (new_index_vars, reduce_vars), - ) - - # Here decide the final loop order - for node in nodes: - if node == matched_node: - node.recompute_size_and_body(recompute_sizes_body_func=loop_split) - for node in nodes: - if node != matched_node: - node.recompute_size_and_body( - extra_indexing_constraints=extra_indexing_constraints, - recompute_sizes_body_func=loop_split, - ) - - return nodes - def codegen_outer_loop_node( self, node: OuterLoopFusedSchedulerNode, @@ -4396,7 +4287,6 @@ def codegen_node( self.codegen_outer_loop_node(node) else: nodes: List[SchedulerNode] = node.get_nodes() # type: ignore[assignment] - nodes = self.try_loop_split(nodes) cpp_kernel_proxy = CppKernelProxy(kernel_group) cpp_kernel_proxy.codegen_nodes(nodes) kernel_group.finalize_kernel(cpp_kernel_proxy, nodes) diff --git a/torch/_inductor/ir.py b/torch/_inductor/ir.py index b9f08fedfe847c..18cdac11ac8077 100644 --- a/torch/_inductor/ir.py +++ b/torch/_inductor/ir.py @@ -3717,7 +3717,6 @@ def get_default_sizes_body(self): def simplify_and_reorder( self, extra_indexing_constraints: Optional[Tuple[Dict[Any, Any], List[Any]]] = None, - recompute_sizes_body_func: Optional[Callable[..., Any]] = None, ): """ This is a main place where we do loop transformations in a @@ -3733,8 +3732,6 @@ def simplify_and_reorder( to fuse scheduler nodes with compatible ranges, e.g. (s0*s1*...,) and (s0, s1, s2, ...) on CPU by preventing indexing simplifications and obtaining index/reduce ranges for the scheduler node compatible with other nodes. - Optional argument recompute_sizes_body_func can be used to recompute sizes and body - on the default body. This can be useful to append additional loop transformations. """ ( (index_size, reduce_size), @@ -3742,15 +3739,6 @@ def simplify_and_reorder( (index_vars, reduce_vars), ) = self.get_default_sizes_body() - if recompute_sizes_body_func: - ( - (index_size, reduce_size), - body, - (index_vars, reduce_vars), - ) = recompute_sizes_body_func( - (index_size, reduce_size), body, (index_vars, reduce_vars) - ) - index_formulas = [*body.indexing_exprs.values()] if extra_indexing_constraints is not None: assert ( @@ -3927,7 +3915,6 @@ def should_allocate(self): def simplify_and_reorder( self, extra_indexing_constraints: Optional[Tuple[Dict[Any, Any], List[Any]]] = None, - recompute_sizes_body_func: Optional[Callable[..., Any]] = None, ): return ( ( diff --git a/torch/_inductor/scheduler.py b/torch/_inductor/scheduler.py index 07d8cf4f421daf..360e4809371443 100644 --- a/torch/_inductor/scheduler.py +++ b/torch/_inductor/scheduler.py @@ -828,12 +828,10 @@ def __init__( def _compute_attrs( self, extra_indexing_constraints: Optional[Tuple[Dict[Any, Any], List[Any]]] = None, - recompute_sizes_body_func: Optional[Callable[..., Any]] = None, ) -> None: assert isinstance(self.node, (ir.ComputedBuffer, ir.TemplateBuffer)) self._sizes, self._body = self.node.simplify_and_reorder( - extra_indexing_constraints=extra_indexing_constraints, - recompute_sizes_body_func=recompute_sizes_body_func, + extra_indexing_constraints=extra_indexing_constraints ) group_fn = self.scheduler.get_backend(self.node.get_device()).group_fn @@ -858,14 +856,9 @@ def _compute_attrs( ) def recompute_size_and_body( - self, - extra_indexing_constraints: Optional[Tuple[Dict[Any, Any], List[Any]]] = None, - recompute_sizes_body_func: Optional[Callable[..., Any]] = None, + self, extra_indexing_constraints: Tuple[Dict[Any, Any], List[Any]] ) -> None: - self._compute_attrs( - extra_indexing_constraints=extra_indexing_constraints, - recompute_sizes_body_func=recompute_sizes_body_func, - ) + self._compute_attrs(extra_indexing_constraints=extra_indexing_constraints) def refresh_dependencies(self, normalize: bool) -> None: # Fake dependencies are added manually. They can not be analyzed from From 7ba31de5ea5be1de104ea9d627c5a45fb4cc4b9f Mon Sep 17 00:00:00 2001 From: "Edward Z. Yang" Date: Sun, 1 Sep 2024 11:13:39 -0700 Subject: [PATCH 0353/1018] Properly handle empty CPUINFO variable (#134916) Fixes https://github.com/pytorch/pytorch/issues/134915 But I did not root cause why CPUINFO is totally empty to begin with... Signed-off-by: Edward Z. Yang Pull Request resolved: https://github.com/pytorch/pytorch/pull/134916 Approved by: https://github.com/Skylion007 --- cmake/Modules/FindARM.cmake | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/cmake/Modules/FindARM.cmake b/cmake/Modules/FindARM.cmake index ffd588cf2737df..c7985d0111f024 100644 --- a/cmake/Modules/FindARM.cmake +++ b/cmake/Modules/FindARM.cmake @@ -5,7 +5,7 @@ IF(CMAKE_SYSTEM_NAME MATCHES "Linux") EXECUTE_PROCESS(COMMAND cat /proc/cpuinfo OUTPUT_VARIABLE CPUINFO) #neon instruction can be found on the majority part of modern ARM processor - STRING(REGEX REPLACE "^.*(neon).*$" "\\1" NEON_THERE ${CPUINFO}) + STRING(REGEX REPLACE "^.*(neon).*$" "\\1" NEON_THERE "${CPUINFO}") STRING(COMPARE EQUAL "neon" "${NEON_THERE}" NEON_TRUE) IF (NEON_TRUE) set(NEON_FOUND true CACHE BOOL "NEON available on host") @@ -14,7 +14,7 @@ IF(CMAKE_SYSTEM_NAME MATCHES "Linux") ENDIF (NEON_TRUE) # on ARMv8, neon is inherit and instead listed as 'asimd' in /proc/cpuinfo - STRING(REGEX REPLACE "^.*(asimd).*$" "\\1" ASIMD_THERE ${CPUINFO}) + STRING(REGEX REPLACE "^.*(asimd).*$" "\\1" ASIMD_THERE "${CPUINFO}") STRING(COMPARE EQUAL "asimd" "${ASIMD_THERE}" ASIMD_TRUE) IF (ASIMD_TRUE) set(ASIMD_FOUND true CACHE BOOL "ASIMD/NEON available on host") @@ -23,7 +23,7 @@ IF(CMAKE_SYSTEM_NAME MATCHES "Linux") ENDIF (ASIMD_TRUE) #Find the processor type (for now OMAP3 or OMAP4) - STRING(REGEX REPLACE "^.*(OMAP3).*$" "\\1" OMAP3_THERE ${CPUINFO}) + STRING(REGEX REPLACE "^.*(OMAP3).*$" "\\1" OMAP3_THERE "${CPUINFO}") STRING(COMPARE EQUAL "OMAP3" "${OMAP3_THERE}" OMAP3_TRUE) IF (OMAP3_TRUE) set(CORTEXA8_FOUND true CACHE BOOL "OMAP3 available on host") @@ -32,7 +32,7 @@ IF(CMAKE_SYSTEM_NAME MATCHES "Linux") ENDIF (OMAP3_TRUE) #Find the processor type (for now OMAP3 or OMAP4) - STRING(REGEX REPLACE "^.*(OMAP4).*$" "\\1" OMAP4_THERE ${CPUINFO}) + STRING(REGEX REPLACE "^.*(OMAP4).*$" "\\1" OMAP4_THERE "${CPUINFO}") STRING(COMPARE EQUAL "OMAP4" "${OMAP4_THERE}" OMAP4_TRUE) IF (OMAP4_TRUE) set(CORTEXA9_FOUND true CACHE BOOL "OMAP4 available on host") @@ -49,7 +49,7 @@ ELSEIF(CMAKE_SYSTEM_NAME MATCHES "Darwin") IF(NOT CPUINFO STREQUAL "") #neon instruction can be found on the majority part of modern ARM processor - STRING(REGEX REPLACE "^.*(neon).*$" "\\1" NEON_THERE ${CPUINFO}) + STRING(REGEX REPLACE "^.*(neon).*$" "\\1" NEON_THERE "${CPUINFO}") STRING(COMPARE EQUAL "neon" "${NEON_THERE}" NEON_TRUE) IF (NEON_TRUE) set(NEON_FOUND true CACHE BOOL "NEON available on host") From 4b2c113eaf6ead85487348c11ba047c3e98327ef Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Tue, 3 Sep 2024 16:19:47 +0000 Subject: [PATCH 0354/1018] Revert "[ONNX] Bump onnxscript version in CI; temporarily remove op test (#133748)" This reverts commit 6eed63c8b9c4f54a573bb51960d252cd42bfab0c. Reverted https://github.com/pytorch/pytorch/pull/133748 on behalf of https://github.com/ZainRizvi due to The version bump appears to be pulling in an unavailable numpy version? [GH job link](https://github.com/pytorch/pytorch/actions/runs/10686076754/job/29620426371) [HUD commit link](https://hud.pytorch.org/pytorch/pytorch/commit/6eed63c8b9c4f54a573bb51960d252cd42bfab0c) ([comment](https://github.com/pytorch/pytorch/pull/133748#issuecomment-2326932868)) --- .ci/docker/common/install_onnx.sh | 7 +- test/onnx/dynamo/test_registry_dispatcher.py | 223 +- test/onnx/test_fx_op_consistency.py | 2076 +++++++++++++++++ test/onnx/test_fx_to_onnx.py | 28 + test/onnx/test_fx_to_onnx_with_onnxruntime.py | 3 - torch/onnx/_internal/_exporter_legacy.py | 50 +- 6 files changed, 2363 insertions(+), 24 deletions(-) create mode 100644 test/onnx/test_fx_op_consistency.py diff --git a/.ci/docker/common/install_onnx.sh b/.ci/docker/common/install_onnx.sh index b99ecac283d646..65907a99f2572a 100755 --- a/.ci/docker/common/install_onnx.sh +++ b/.ci/docker/common/install_onnx.sh @@ -30,9 +30,10 @@ pip_install \ pip_install coloredlogs packaging -pip_install onnxruntime==1.18.1 -pip_install onnx==1.16.2 -pip_install onnxscript==0.1.0.dev20240831 --no-deps +pip_install onnxruntime==1.18 +pip_install onnx==1.16.0 +# pip_install "onnxscript@git+https://github.com/microsoft/onnxscript@3e869ef8ccf19b5ebd21c10d3e9c267c9a9fa729" --no-deps +pip_install onnxscript==0.1.0.dev20240613 --no-deps # required by onnxscript pip_install ml_dtypes diff --git a/test/onnx/dynamo/test_registry_dispatcher.py b/test/onnx/dynamo/test_registry_dispatcher.py index 8605b098d30055..d782f94dd311d9 100644 --- a/test/onnx/dynamo/test_registry_dispatcher.py +++ b/test/onnx/dynamo/test_registry_dispatcher.py @@ -8,11 +8,18 @@ import onnxscript # type: ignore[import] from onnxscript import BFLOAT16, DOUBLE, FLOAT, FLOAT16 # type: ignore[import] +from onnxscript.function_libs.torch_lib import ops # type: ignore[import] from onnxscript.onnx_opset import opset15 as op # type: ignore[import] import torch import torch.fx -from torch.onnx._internal.fx import diagnostics, onnxfunction_dispatcher, registration +from torch.onnx._internal.diagnostics import infra +from torch.onnx._internal.fx import ( + analysis, + diagnostics, + onnxfunction_dispatcher, + registration, +) from torch.testing._internal import common_utils @@ -77,6 +84,60 @@ def test_custom(x, y): [test_original, test_custom], ) + def test_unsupported_nodes_analysis_with_missing_aten_op(self): + # NOTE: simulate unsupported nodes + aten_mul_tensor = registration.OpName.from_name_parts( + namespace="aten", op_name="mul", overload="Tensor" + ) + aten_mul_default = registration.OpName.from_name_parts( + namespace="aten", op_name="mul" + ) + aten_add_tensor = registration.OpName.from_name_parts( + namespace="aten", op_name="add", overload="Tensor" + ) + aten_add_default = registration.OpName.from_name_parts( + namespace="aten", op_name="add" + ) + + self.registry._registry.pop(aten_mul_tensor) + self.registry._registry.pop(aten_mul_default) + self.registry._registry.pop(aten_add_tensor) + self.registry._registry.pop(aten_add_default) + + diagnostic_context = diagnostics.DiagnosticContext( + "torch.onnx.dynamo_export", torch.__version__ + ) + dispatcher = onnxfunction_dispatcher.OnnxFunctionDispatcher( + self.registry, diagnostic_context + ) + + graph: torch.fx.Graph = torch.fx.Graph() + x: torch.fx.Node = graph.create_node("placeholder", "x") + x.meta["val"] = torch.tensor(3.0) + b: torch.fx.Node = graph.create_node( + "call_function", target=torch.ops.aten.mul.Tensor, args=(x, x) + ) + c: torch.fx.Node = graph.create_node( + "call_function", target=torch.ops.aten.add.Tensor, args=(b, b) + ) + output: torch.fx.Node = graph.output(c) + module = torch.fx.GraphModule(torch.nn.Module(), graph) + + with self.assertRaises(infra.RuntimeErrorWithDiagnostic): + analysis.UnsupportedFxNodesAnalysis( + diagnostic_context, module, dispatcher + ).analyze(infra.levels.ERROR) + + try: + analysis.UnsupportedFxNodesAnalysis( + diagnostic_context, module, dispatcher + ).analyze(infra.levels.ERROR) + except infra.RuntimeErrorWithDiagnostic as e: + self.assertIn( + "Unsupported FX nodes: {'call_function': ['aten.mul.Tensor', 'aten.add.Tensor']}.", + e.diagnostic.message, + ) + @common_utils.instantiate_parametrized_tests class TestDispatcher(common_utils.TestCase): @@ -427,5 +488,165 @@ def test_first_custom_op( self.assertEqual(symbolic_fn, test_third_custom_op) +@common_utils.instantiate_parametrized_tests +class TestOpSchemaWrapper(common_utils.TestCase): + def setUp(self): + # overload type: optional dtype + self.onnx_function_new_full = ops.core.aten_new_full + self.onnx_function_new_full_dtype = ops.core.aten_new_full_dtype + + @common_utils.parametrize( + "inputs, attributes, assertion", + [ + common_utils.subtest( + ([torch.randn(3, 4), torch.randn(3, 4)], {"alpha": 2.0}, True), + name="perfect_match_with_kwargs", + ), + common_utils.subtest( + (["A", "B"], {}, False), + name="non_perfect_match_due_to_non_tensor_inputs", + ), + common_utils.subtest( + ([torch.randn(3, 4), torch.randn(3, 4), torch.randn(3, 4)], {}, False), + name="non_perfect_match_due_to_too_many_inputs", + ), + common_utils.subtest( + ([torch.randn(3, 4), torch.randn(3, 4)], {"wrong_kwargs": 2.0}, False), + name="non_perfect_match_due_to_wrong_kwargs", + ), + ], + ) + def test_perfect_match_inputs(self, inputs, attributes, assertion): + # OnnxFunction with default attributes + dummy_diagnostic = diagnostics.Diagnostic( + rule=diagnostics.rules.find_opschema_matched_symbolic_function, + level=diagnostics.levels.WARNING, + ) + op_schema_wrapper_add = onnxfunction_dispatcher._OnnxSchemaChecker( + ops.core.aten_add + ) + self.assertEqual( + op_schema_wrapper_add.perfect_match_inputs( + dummy_diagnostic, inputs, attributes + ), + assertion, + ) + + @common_utils.parametrize( + "inputs, kwargs, op, score", + [ + common_utils.subtest( + ([torch.randn(3, 4), torch.randn(3, 4)], {}, ops.core.aten_mul, 2), + name="match_2_inputs", + ), + common_utils.subtest( + ( + [ + torch.randint(0, 2, size=(3, 4), dtype=torch.int).bool(), + torch.randint(0, 2, size=(3, 4), dtype=torch.int).bool(), + ], + {}, + ops.core.aten_mul, + 0, + ), + name="match_0_inputs", + ), + common_utils.subtest( + ([torch.randn(3, 4), torch.randn(3, 4)], {}, ops.core.aten_mul_bool, 0), + name="match_0_inputs_bool", + ), + common_utils.subtest( + ( + [ + torch.randint(0, 2, size=(3, 4), dtype=torch.int).bool(), + torch.randint(0, 2, size=(3, 4), dtype=torch.int).bool(), + ], + {}, + ops.core.aten_mul_bool, + 2, + ), + name="match_2_inputs_bool", + ), + ], + ) + def test_matching_score_system_on_overload_dtypes(self, inputs, kwargs, op, score): + op_schema_wrapper = onnxfunction_dispatcher._OnnxSchemaChecker(op) + op_schema_wrapper._record_matching_score(inputs, kwargs) + self.assertEqual(op_schema_wrapper.match_score, score) + + @common_utils.parametrize( + "inputs, kwargs, op, score", + [ + common_utils.subtest( + ([torch.randn(3, 4), torch.tensor(3)], {}, ops.core.aten_new_full, 2), + name="match_2_inputs", + ), + common_utils.subtest( + ( + [torch.randn(3, 4), torch.tensor(3)], + {"dtype": 2}, # at this point, dtype should be converted to int + ops.core.aten_new_full_dtype, + 2, + ), + name="match_2_input_and_match_1_kwargs_optional", + ), + ], + ) + def test_matching_score_system_on_optional_dtypes(self, inputs, kwargs, op, score): + op_schema_wrapper = onnxfunction_dispatcher._OnnxSchemaChecker(op) + op_schema_wrapper._record_matching_score(inputs, kwargs) + self.assertEqual(op_schema_wrapper.match_score, score) + + @common_utils.parametrize( + "value, expected_onnx_str_dtype", + [ + common_utils.subtest( + (1, {"tensor(int64)", "tensor(int16)", "tensor(int32)"}), + name="all_ints", + ), + common_utils.subtest( + (1.0, {"tensor(float)", "tensor(double)", "tensor(float16)"}), + name="all_floats", + ), + common_utils.subtest( + (torch.tensor([True]), {"tensor(bool)"}), + name="bool", + ), + common_utils.subtest( + (torch.tensor([1], dtype=torch.int64), {"tensor(int64)"}), + name="int64", + ), + common_utils.subtest( + (torch.tensor([1], dtype=torch.int32), {"tensor(int32)"}), + name="int32", + ), + common_utils.subtest( + (torch.tensor([1], dtype=torch.int16), {"tensor(int16)"}), + name="int16", + ), + common_utils.subtest( + (torch.tensor([1], dtype=torch.float), {"tensor(float)"}), + name="float", + ), + common_utils.subtest( + (torch.tensor([1], dtype=torch.float16), {"tensor(float16)"}), + name="float16", + ), + common_utils.subtest( + (torch.tensor([1], dtype=torch.double), {"tensor(double)"}), + name="double", + ), + common_utils.subtest((None, set()), name="None"), # None allows no dtype + common_utils.subtest( + ([], set()), name="empaty_list" + ), # Empty list allows no dtype + ], + ) + def test_find_onnx_data_type(self, value, expected_onnx_str_dtype): + self.assertEqual( + onnxfunction_dispatcher._find_onnx_data_type(value), expected_onnx_str_dtype + ) + + if __name__ == "__main__": common_utils.run_tests() diff --git a/test/onnx/test_fx_op_consistency.py b/test/onnx/test_fx_op_consistency.py new file mode 100644 index 00000000000000..9972ea901722c7 --- /dev/null +++ b/test/onnx/test_fx_op_consistency.py @@ -0,0 +1,2076 @@ +# Owner(s): ["module: onnx"] + +"""Test consistency between the output values of torch.onnx FX exported operators +and torch operators given the same inputs. + +Usage: + + 1. Test all operators: + + pytest test/onnx/test_fx_op_consistency.py + + 2. To run tests on a specific operator (e.g. torch.ceil): + + pytest test/onnx/test_fx_op_consistency.py -k ceil + pytest test/onnx/test_fx_op_consistency.py -k nn_functional_scaled_dot_product_attention + + 3. Set `CREATE_REPRODUCTION_REPORT=1` to create markdown files for reproduction of errors. E.g. + + CREATE_REPRODUCTION_REPORT=1 python -m pytest test/onnx/test_fx_op_consistency.py -k div_mode_int + + NOTE: Read more on Running and writing tests: + https://github.com/pytorch/pytorch/wiki/Running-and-writing-tests + +Note: + + 1. Please make sure pytest-subtests is installed. Otherwise, the sub-tests will be ignored. + + 2. Install pytest-xdist to run tests in parallel if runng all tests is the goal. + + 3. When new ops are supported, please scroll down to modify the EXPECTED_SKIPS_OR_FAILS_WITH_DTYPES and + TESTED_OPS lists. See "Modify this section" + +""" + +from __future__ import annotations + +import copy +import itertools +import os +from typing import ( + Any, + Callable, + Collection, + List, + Mapping, + Optional, + Tuple, + Type, + TYPE_CHECKING, + Union, +) + +import error_reproduction +import onnx_test_common +import parameterized +import pytest +import pytorch_test_common +from onnx_test_common import skip, skip_slow, xfail + +import torch +from torch.onnx._internal.diagnostics import _rules +from torch.testing._internal import ( + common_device_type, + common_methods_invocations, + common_utils, +) + + +if TYPE_CHECKING: + from torch.testing._internal.opinfo import core as opinfo_core + + +# NOTE: For ATen signature modifications that will break ONNX export, +# use **xfail_torchlib_forward_compatibility** and **skip_torchlib_forward_compatibility** instead of xfail or skip +# to make the signal apparent for maintainers. +def xfail_torchlib_forward_compatibility( + op_name: str, + variant_name: str = "", + *, + reason: str, + github_issue: str, + opsets: Optional[Collection[Union[int, Callable[[int], bool]]]] = None, + dtypes: Optional[Collection[torch.dtype]] = None, + matcher: Optional[Callable[[Any], bool]] = None, + enabled_if: bool = True, +): + """Prefer using this (xfail) over skip when possible. + + Only skip when the test is not failing consistently. + """ + return xfail( + op_name, + variant_name=variant_name, + reason=f"{reason}. GitHub Issue: {github_issue}", + opsets=opsets, + dtypes=dtypes, + matcher=matcher, + enabled_if=enabled_if, + ) + + +def skip_torchlib_forward_compatibility( + op_name: str, + variant_name: str = "", + *, + reason: str, + github_issue: str, + opsets: Optional[Collection[Union[int, Callable[[int], bool]]]] = None, + dtypes: Optional[Collection[torch.dtype]] = None, + matcher: Optional[Callable[[Any], Any]] = None, + enabled_if: bool = True, +): + """Prefer using xfail_torchlib_forward_compatibility over this (skip) when possible. + + Only skip when the test is not failing consistently. + """ + return skip( + op_name, + variant_name=variant_name, + reason=f"{reason}. GitHub Issue: {github_issue}", + opsets=opsets, + dtypes=dtypes, + matcher=matcher, + enabled_if=enabled_if, + ) + + +# fmt: off +# Turn off black formatting to keep the list compact + +# Expected failures for onnx export. +# The list should be sorted alphabetically by op name. +# Q: When should I use fixme vs vs skip vs xfail? +# A: Prefer xfail over skip when possible. +# 2a. If a test is now failing because of xpass, because some previous errors +# are now fixed, removed the corresponding xfail. +# 2b. If a test is not failing consistently, use skip. +# NOTE: EXPECTED_SKIPS_OR_FAILS_WITH_DTYPES only supports dtypes. If a matcher or model_type +# is needed, use the SKIP_XFAIL_SUBTESTS_WITH_MATCHER_AND_MODEL_TYPE list further down below. +EXPECTED_SKIPS_OR_FAILS_WITH_DTYPES: Tuple[onnx_test_common.DecorateMeta, ...] = ( + xfail( + "__getitem__", + reason="io_adaper doesn't support __getitem__ input slice(0, 3, None)", + ), + xfail( + "__radd__", + dtypes=onnx_test_common.BOOL_TYPES, + reason=onnx_test_common.reason_onnx_script_does_not_support("Add", "bool"), + ), + xfail( + "__rmatmul__", + dtypes=(torch.float16,), + reason="fixme: Assertion error: result mismatch", + ), + xfail( + "__rpow__", + dtypes=onnx_test_common.INT_TYPES, + reason=onnx_test_common.reason_onnx_does_not_support("Pow", "int"), + ), + skip( + "_native_batch_norm_legit", + reason=onnx_test_common.reason_onnx_script_does_not_support("cpu is not supported: \ + https://github.com/microsoft/onnxscript/pull/1289") + ), + skip( + "_batch_norm_with_update", + dtypes=(torch.float16,), + reason="fixme: Assertion error: result mismatch and type error", + ), + xfail( + "_softmax_backward_data", + dtypes=(torch.float16,), + reason="fixme: Assertion error: result mismatch", + ), + xfail( + "_unsafe_masked_index", + dtypes=onnx_test_common.BOOL_TYPES, + reason=onnx_test_common.reason_onnx_runtime_does_not_support("Where", "bool"), + ), + xfail( + "_unsafe_masked_index", + dtypes=onnx_test_common.COMPLEX_TYPES, + reason=onnx_test_common.reason_onnx_runtime_does_not_support("_unsafe_masked_index", "complex64"), + ), + xfail( + "_unsafe_masked_index_put_accumulate", + reason="fixme: Status Message: updates tensor should have shape equal to " + "indices.shape[:-1] + data.shape[indices.shape[-1]:]", + ), + xfail( + "_unsafe_masked_index_put_accumulate", + dtypes=onnx_test_common.BOOL_TYPES, + reason=onnx_test_common.reason_onnx_runtime_does_not_support("Where", "bool"), + ), + xfail( + "add", dtypes=onnx_test_common.BOOL_TYPES, + reason=onnx_test_common.reason_onnx_does_not_support("Add") + ), + xfail( + "add", + dtypes=(torch.uint8, torch.int8, torch.int16,), + reason=onnx_test_common.reason_onnx_script_does_not_support( + "Add", "int8, int16, uint8 have type issue." + ), + ), + xfail( + "addbmm", + dtypes=onnx_test_common.COMPLEX_TYPES, + reason=onnx_test_common.reason_dynamo_does_not_support("Addbmm", "complex64") + ), + xfail( + "addmm", dtypes=onnx_test_common.BOOL_TYPES, + reason=onnx_test_common.reason_onnx_does_not_support("Addmm") + ), + xfail( + "addmm", + variant_name="decomposed", + dtypes=onnx_test_common.BOOL_TYPES + onnx_test_common.INT_TYPES, + reason=onnx_test_common.reason_onnx_does_not_support("Addmm") + ), + skip( + "addmm", dtypes=onnx_test_common.COMPLEX_TYPES, + reason=onnx_test_common.reason_dynamo_does_not_support("Addmm", "complex64 (core dump)") + ), + skip( + "addmm", + variant_name="decomposed", + dtypes=onnx_test_common.COMPLEX_TYPES, + reason=onnx_test_common.reason_dynamo_does_not_support("Addmm", "complex64 (core dump)") + ), + xfail( + "addr", + dtypes=onnx_test_common.BOOL_TYPES, + reason=onnx_test_common.reason_onnx_script_does_not_support( + "Addr", "bool" + ), + ), + xfail( + "addr", + dtypes=onnx_test_common.COMPLEX_TYPES, + reason=onnx_test_common.reason_dynamo_does_not_support("Addr", "complex64") + ), + xfail( + "alias_copy", + dtypes=(torch.int8, torch.uint8, torch.int16, torch.float64), + reason="OnnxExporterError: Failed to export model", + ), + xfail( + "allclose", + reason=onnx_test_common.reason_dynamo_does_not_support("Allclose") + ), + xfail( + "amax", + dtypes=(torch.int16, *onnx_test_common.BOOL_TYPES), + reason=onnx_test_common.reason_onnx_does_not_support("ReduceMin", "bool, int16"), + ), + xfail( + "amin", dtypes=(torch.int16, *onnx_test_common.BOOL_TYPES), + reason=onnx_test_common.reason_dynamo_does_not_support("ReduceMin", "bool, int16") + ), + xfail( + "aminmax", + dtypes=(torch.int16, *onnx_test_common.BOOL_TYPES), + reason=onnx_test_common.reason_onnx_does_not_support("ReduceMin", "bool, int16"), + ), + xfail( + "arange", + dtypes=(torch.uint8,), + reason=onnx_test_common.reason_onnx_script_does_not_support("Arange", "uint8, int8"), + ), + xfail( + "arange", + dtypes=(torch.int16, torch.int32), + reason="AssertionError: The values for attribute 'shape' do not match", + ), + xfail( + "argmax", + dtypes=( + torch.int16, + torch.int64, + ), + reason=onnx_test_common.reason_onnx_runtime_does_not_support( + "ArgMax", "int16, int64" + ), + ), + xfail( + "argmin", + dtypes=( + torch.uint8, + torch.int8, + torch.int16, + torch.int64, + ), + reason=onnx_test_common.reason_onnx_runtime_does_not_support( + "ArgMin", "uint8, int8, int16, int64" + ), + ), + xfail( + "argwhere", + reason="fixme: Assertion error: result mismatch", + ), + skip( + "as_strided", + variant_name="partial_views", + reason="ONNX doesn't have partial view for tensor; [PostInline][ORT] segfaults", + ), + xfail( + "atan2", + dtypes=onnx_test_common.BOOL_TYPES + onnx_test_common.INT_TYPES, + reason="fixme: Assertion error: result mismatch", + ), + xfail( + "baddbmm", + dtypes=( + torch.uint8, + torch.int8, + torch.int16, + ), + reason=onnx_test_common.reason_onnx_runtime_does_not_support( + "Matmul", "uint8, int8, int16" + ), + ), + xfail( + "baddbmm", + dtypes=onnx_test_common.COMPLEX_TYPES, + reason=onnx_test_common.reason_dynamo_does_not_support("baddbmm", "complex64") + ), + xfail( + "bernoulli", + reason=onnx_test_common.reason_dynamo_does_not_support("wrapper_set_seed"), + ), + xfail( + "bfloat16", + reason="fixme: ORT errors with RuntimeError: No corresponding Numpy type for Tensor Type.", + ), + xfail( + "bincount", + reason=onnx_test_common.reason_dynamo_does_not_support("aten.bincount.default"), + ), + xfail( + "block_diag", + dtypes=onnx_test_common.COMPLEX_TYPES, + reason=onnx_test_common.reason_onnx_runtime_does_not_support("Block_diag", "complex"), + ), + xfail( + "bmm", + dtypes=( + torch.uint8, + torch.int8, + torch.int16, + ), + reason=onnx_test_common.reason_onnx_runtime_does_not_support( + "Matmul", "uint8, int8, int16" + ), + ), + xfail( + "broadcast_shapes", + reason=onnx_test_common.reason_dynamo_does_not_support("output is int"), + ), + xfail( + "cauchy", + reason=onnx_test_common.reason_dynamo_does_not_support("wrapper_set_seed"), + ), + skip( + "ceil", dtypes=onnx_test_common.BOOL_TYPES + onnx_test_common.INT_TYPES, + reason=onnx_test_common.reason_onnx_does_not_support("Ceil", "bool and int") + ), + xfail( + "chalf", + reason="fixme: ONNX shape type inference error: Invalid tensor data type 0." + ), + xfail( + "chunk", + dtypes=(torch.uint8, torch.int8, torch.int16,), + reason=onnx_test_common.reason_onnx_runtime_does_not_support( + "Chunk", "uint8, int8, int16" + ), + ), + xfail( + "clamp", + dtypes=(torch.uint8, torch.int8, torch.int16,), + reason=onnx_test_common.reason_onnx_runtime_does_not_support( + "Max", "uint8, int8, int16" + ), + ), + xfail( + "clamp_max", dtypes=onnx_test_common.BOOL_TYPES, + reason=onnx_test_common.reason_onnx_script_does_not_support("Clamp_max", "bool") + ), + xfail( + "clamp_max", + dtypes=(torch.uint8, torch.int8, torch.int16,), + reason=onnx_test_common.reason_onnx_runtime_does_not_support( + "Max", "uint8, int8, int16" + ), + ), + xfail( + "clamp_min", + dtypes=(torch.uint8, torch.int8, torch.int16,), + reason=onnx_test_common.reason_onnx_runtime_does_not_support( + "Max", "uint8, int8, int16" + ), + ), + xfail( + "clamp_min", dtypes=onnx_test_common.BOOL_TYPES, + reason=onnx_test_common.reason_onnx_script_does_not_support("Clamp_min", "bool") + ), + xfail( + "constant_pad_nd", + dtypes=(torch.int16,), + reason=onnx_test_common.reason_onnx_runtime_does_not_support( + "Constant_pad_nd", "int16" + ), + ), + xfail( + "constant_pad_nd", + dtypes=onnx_test_common.COMPLEX_TYPES, + reason=onnx_test_common.reason_dynamo_does_not_support( + "Constant_pad_nd", "complex64" + ), + ), + xfail( + "corrcoef", + reason=onnx_test_common.reason_dynamo_does_not_support( + "aten.equal.default" + ), + ), + xfail( + "cov", + reason=onnx_test_common.reason_dynamo_does_not_support( + "aten.equal.default" + ), + ), + xfail( + "cumsum", dtypes=onnx_test_common.BOOL_TYPES + (torch.uint8, torch.int8, torch.int16,), + reason=onnx_test_common.reason_onnx_does_not_support("Cumsum", "bool, uint8, int8, int16") + ), + xfail( + "combinations", + reason=onnx_test_common.reason_dynamo_does_not_support("aten.masked.select"), + ), + xfail( + "diag", + dtypes=onnx_test_common.BOOL_TYPES, + reason=onnx_test_common.reason_onnx_runtime_does_not_support("Diagonal", "bool"), + ), + xfail( + "diagonal_copy", + dtypes=onnx_test_common.BOOL_TYPES, + reason=onnx_test_common.reason_onnx_runtime_does_not_support("Diagonal", "bool"), + ), + xfail( + "dot", dtypes=(torch.uint8, torch.int8, torch.int16,), + reason=onnx_test_common.reason_onnx_does_not_support("MatMul", "uint8, int8, int16") + ), + skip( + "dot", + dtypes=onnx_test_common.COMPLEX_TYPES, + reason=onnx_test_common.reason_dynamo_does_not_support("Dot", "complex64(core dump)"), + ), + xfail( + "empty", + dtypes=onnx_test_common.COMPLEX_TYPES, + reason="fixme: kwargs dtpye=complex64 is not supported in ONNX." + ), + xfail( + "empty_strided", + reason=onnx_test_common.reason_dynamo_does_not_support("wrapper_set_seed"), + ), + xfail( + "eq", + dtypes=(torch.uint8, torch.int8, torch.int16,), + reason=onnx_test_common.reason_onnx_runtime_does_not_support("Equal", "uint8, int8, int16"), + ), + xfail( + "equal", + reason=onnx_test_common.reason_dynamo_does_not_support("aten.equal.default") + ), + xfail( + "exponential", + reason=onnx_test_common.reason_dynamo_does_not_support("exponential"), + ), + xfail( + "fft.fft", + reason="fixme: Assertion error: result mismatch", + ), + xfail( + "fft.fft2", + reason="fixme: Assertion error: result mismatch", + ), + xfail( + "fft.fftn", + reason="fixme: Assertion error: result mismatch", + ), + xfail( + "fft.ifft", + reason="fixme: Assertion error: result mismatch", + ), + xfail( + "fft.ifft2", + reason="fixme: Assertion error: result mismatch", + ), + xfail( + "fft.ifftn", + reason="fixme: Assertion error: result mismatch", + ), + xfail( + "fft.irfft", + reason="fixme: Assertion error: result mismatch", + ), + xfail( + "fft.irfft2", + reason="fixme: Assertion error: result mismatch", + ), + xfail( + "fft.irfftn", + reason=onnx_test_common.reason_onnx_script_does_not_support("aten._fft_r2c.default"), + ), + xfail( + "fft.rfft", + reason=onnx_test_common.reason_onnx_script_does_not_support("aten._fft_r2c.default"), + ), + xfail( + "fft.rfftn", + reason=onnx_test_common.reason_onnx_script_does_not_support("aten._fft_r2c.default"), + ), + xfail( + "fft.rfft2", + reason=onnx_test_common.reason_onnx_script_does_not_support("aten._fft_r2c.default"), + ), + xfail( + "floor", + dtypes=onnx_test_common.BOOL_TYPES + onnx_test_common.INT_TYPES, + reason=onnx_test_common.reason_onnx_does_not_support("Floor", "bool, int"), + ), + xfail( + "floor_divide", + dtypes=onnx_test_common.BOOL_TYPES + onnx_test_common.INT_TYPES, + reason=onnx_test_common.reason_onnx_does_not_support("Floor", "bool, int"), + ), + xfail( + "full", + dtypes=onnx_test_common.COMPLEX_TYPES, + reason=onnx_test_common.reason_dynamo_does_not_support("full", "complex64") + ), + xfail( + "full_like", + dtypes=onnx_test_common.COMPLEX_TYPES, + reason=onnx_test_common.reason_dynamo_does_not_support("full_like", "complex64") + ), + xfail( + "gather", + reason="GatherElements op: Rank of input 'data' needs to be equal to rank of input 'indices'" + ), + xfail( + "geometric", + reason=onnx_test_common.reason_dynamo_does_not_support("wrapper_set_seed"), + ), + xfail( + "heaviside", + dtypes=onnx_test_common.BOOL_TYPES, + reason=onnx_test_common.reason_onnx_script_does_not_support("Heaviside", "bool"), + ), + xfail( + "index_add", + dtypes=(torch.float16,), + reason=onnx_test_common.reason_onnx_runtime_does_not_support("ScatterND", "int64, int32, bool"), + ), + xfail( + "index_fill", + dtypes=onnx_test_common.COMPLEX_TYPES, + reason=onnx_test_common.reason_dynamo_does_not_support("index_fill", "complex64") + ), + xfail( + "index_fill", + dtypes=onnx_test_common.INT_TYPES + onnx_test_common.BOOL_TYPES + onnx_test_common.FLOAT_TYPES, + reason="fixme: Constant input list has None. ONNXScript does not support None in constant list." + ), + xfail( + "index_put", + dtypes=onnx_test_common.BOOL_TYPES + (torch.float16,), + reason=onnx_test_common.reason_onnx_script_does_not_support("index_put", "bool"), + ), + xfail( + "index_put", + dtypes=(torch.uint8, torch.int8, torch.int16,), + reason=onnx_test_common.reason_onnx_script_does_not_support("Add", "int8, int16"), + ), + xfail( + "index_put", + dtypes=(torch.float16,), + reason=onnx_test_common.reason_onnx_runtime_does_not_support("ScatterND", "float16"), + ), + xfail( + "isnan", + dtypes=onnx_test_common.INT_TYPES + onnx_test_common.BOOL_TYPES, + reason=onnx_test_common.reason_onnx_does_not_support("IsNaN", "int, bool"), + ), + xfail( + "istft", + reason=onnx_test_common.reason_dynamo_does_not_support("data-dependent"), + ), + xfail( + "item", + reason=onnx_test_common.reason_dynamo_does_not_support("wrapper_set_seed"), + ), + xfail( + "lerp", + dtypes=onnx_test_common.COMPLEX_TYPES, + reason=onnx_test_common.reason_dynamo_does_not_support("lerp", "complex64") + ), + xfail( + "linalg.lstsq", + reason=onnx_test_common.reason_dynamo_does_not_support("aten.linalg_lstsq.default"), + ), + xfail( + "linalg.lstsq", + variant_name="grad_oriented", + reason=onnx_test_common.reason_dynamo_does_not_support("aten.linalg_lstsq.default"), + ), + xfail( + "linalg.matrix_power", + reason="fixme: The values for attribute 'shape' do not match: torch.Size([2, 2]) != torch.Size([2, 2, 2])." + ), + xfail( + "linalg.norm", + reason="fixme: Assertion error: result mismatch", + ), + xfail( + "linalg.norm", + variant_name="subgradients_at_zero", + reason="fixme: Assertion error: result mismatch", + ), + xfail( + "linalg.vecdot", + reason="fixme: Assertion error: result shape mismatch", + ), + xfail( + "linspace", + dtypes=(torch.int64, torch.int32,), + reason="fixme: Results do not match with PyTorch. https://github.com/microsoft/onnxscript/issues/854", + ), + xfail( + "linspace", + variant_name="tensor_overload", + dtypes=(torch.int64, torch.int32,), + reason="fixme: Results do not match with PyTorch. https://github.com/microsoft/onnxscript/issues/854", + ), + xfail( + "linspace", + dtypes=onnx_test_common.COMPLEX_TYPES, + reason=onnx_test_common.reason_onnx_script_does_not_support("linspace", "complex64") + ), + xfail( + "linspace", + variant_name="tensor_overload", + dtypes=onnx_test_common.COMPLEX_TYPES, + reason=onnx_test_common.reason_onnx_script_does_not_support("linspace", "complex64") + ), + xfail( + "log_normal", + reason=onnx_test_common.reason_dynamo_does_not_support("wrapper_set_seed"), + ), + xfail( + "log_softmax", + dtypes=(torch.float16,), + reason="fixme: ORT optimizer error: https://github.com/microsoft/onnxruntime/issues/16438", + ), + xfail( + "log_softmax", + variant_name="with_dtype", + dtypes=(torch.float16,), + reason="fixme: ORT optimizer error: https://github.com/microsoft/onnxruntime/issues/16438", + ), + xfail( + "logical_and", + dtypes=onnx_test_common.FLOAT_TYPES + onnx_test_common.INT_TYPES, + reason=onnx_test_common.reason_onnx_script_does_not_support("And", "float, int"), + ), + xfail( + "logical_not", + dtypes=onnx_test_common.FLOAT_TYPES + onnx_test_common.INT_TYPES, + reason=onnx_test_common.reason_onnx_script_does_not_support("Not", "float, int"), + ), + xfail( + "logical_or", + dtypes=onnx_test_common.FLOAT_TYPES + onnx_test_common.INT_TYPES, + reason=onnx_test_common.reason_onnx_script_does_not_support("Or", "float, int"), + ), + xfail( + "logical_xor", + dtypes=onnx_test_common.FLOAT_TYPES + onnx_test_common.INT_TYPES, + reason=onnx_test_common.reason_onnx_script_does_not_support("Xor", "float, int"), + ), + skip( + "masked.logsumexp", + reason="fixme: https://github.com/onnx/onnx/issues/4986", + ), + xfail( + "masked.amax", + reason="fixme: ORT optimizer error: https://github.com/microsoft/onnxruntime/issues/16438", + ), + xfail( + "masked.amin", + reason="fixme: ORT optimizer error: https://github.com/microsoft/onnxruntime/issues/16438", + ), + xfail( + "masked.argmin", + dtypes=onnx_test_common.BOOL_TYPES + onnx_test_common.FLOAT_TYPES + (torch.int64,), + reason="fixme: Assertion error: result mismatch", + ), + xfail( + "masked.argmax", + dtypes=onnx_test_common.BOOL_TYPES + onnx_test_common.FLOAT_TYPES + (torch.int64,), + reason="fixme: Assertion error: result mismatch", + ), + xfail( + "masked_fill", + dtypes=onnx_test_common.BOOL_TYPES, + reason=onnx_test_common.reason_onnx_runtime_does_not_support("Where", "bool"), + ), + xfail( + "masked.sum", + dtypes=onnx_test_common.BOOL_TYPES, + reason=onnx_test_common.reason_onnx_runtime_does_not_support("Where", "bool"), + ), + xfail( + "masked.log_softmax", + dtypes=(torch.float16,), + reason="fixme: ORT optimizer error: https://github.com/microsoft/onnxruntime/issues/16438", + ), + xfail( + "masked.mean", + dtypes=onnx_test_common.BOOL_TYPES, + reason=onnx_test_common.reason_onnx_does_not_support("ReduceMean", "bool"), + ), + xfail( + "masked.norm", + reason="fixme: Assertion error: result mismatch", + ), + xfail( + "masked.prod", + dtypes=onnx_test_common.BOOL_TYPES, + reason=onnx_test_common.reason_onnx_runtime_does_not_support("Where", "bool"), + ), + xfail( + "masked_select", + reason=onnx_test_common.reason_dynamo_does_not_support("aten.masked_select.default"), + ), + xfail( + "max", + variant_name="reduction_no_dim", + dtypes=onnx_test_common.BOOL_TYPES, + reason=onnx_test_common.reason_onnx_runtime_does_not_support("ReduceMax", "bool"), + ), + xfail( + "max", + variant_name="reduction_with_dim", + dtypes=onnx_test_common.BOOL_TYPES, + reason=onnx_test_common.reason_onnx_runtime_does_not_support("ReduceMax", "bool"), + ), + xfail( + "max", + variant_name="reduction_with_dim", + dtypes=(torch.int64,), + reason="https://github.com/onnx/onnx/issues/4986", + ), + xfail( + "min", + variant_name="reduction_no_dim", + dtypes=onnx_test_common.BOOL_TYPES, + reason=onnx_test_common.reason_onnx_runtime_does_not_support("ReduceMin", "bool"), + ), + xfail( + "min", + variant_name="reduction_with_dim", + dtypes=onnx_test_common.BOOL_TYPES + (torch.int64,), + reason=onnx_test_common.reason_onnx_runtime_does_not_support("ReduceMin", "bool"), + ), + skip( + "mm", + dtypes=onnx_test_common.COMPLEX_TYPES, + reason=onnx_test_common.reason_dynamo_does_not_support("MM", "complex64(core dump)"), + ), + xfail( + "multinomial", + reason=onnx_test_common.reason_dynamo_does_not_support("wrapper_set_seed"), + ), + xfail( + "nanquantile", + reason=onnx_test_common.reason_dynamo_does_not_support("aten.equal.default") + ), + xfail( + "nansum", + dtypes=onnx_test_common.INT_TYPES + onnx_test_common.BOOL_TYPES, + reason=onnx_test_common.reason_onnx_runtime_does_not_support("IsNaN", "int, bool"), + ), + xfail( + "narrow", + reason=onnx_test_common.reason_dynamo_does_not_support("data-dependent"), + ), + skip( + "native_batch_norm", + reason=onnx_test_common.reason_onnx_script_does_not_support("cpu is not supported: \ + https://github.com/microsoft/onnxscript/pull/1289") + ), + xfail( + "native_layer_norm", + dtypes=(torch.float16,), + reason="fixme: ORT optimizer error: https://github.com/microsoft/onnxruntime/issues/16438", + ), + xfail( + "new_full", + dtypes=onnx_test_common.COMPLEX_TYPES, + reason=onnx_test_common.reason_dynamo_does_not_support("new_full", "complex64") + ), + xfail( + "nn.functional.adaptive_avg_pool2d", + reason=onnx_test_common.reason_onnx_script_does_not_support("RecursionError: \ + maximum recursion depth exceeded while calling a Python object"), + ), + xfail( + "nn.functional.adaptive_avg_pool3d", + reason=onnx_test_common.reason_onnx_script_does_not_support("aten._adaptive_avg_pool3d.default"), + ), + xfail( + "nn.functional.alpha_dropout", + reason=onnx_test_common.reason_dynamo_does_not_support("wrapper_set_seed"), + ), + xfail( + "nn.functional.avg_pool1d", + dtypes=onnx_test_common.INT_TYPES, + reason=onnx_test_common.reason_onnx_does_not_support("AveragePool", "int"), + ), + xfail( + "nn.functional.avg_pool2d", + dtypes=onnx_test_common.INT_TYPES, + reason=onnx_test_common.reason_onnx_does_not_support("AveragePool", "int"), + ), + xfail( + "nn.functional.avg_pool3d", + dtypes=onnx_test_common.INT_TYPES, + reason=onnx_test_common.reason_onnx_does_not_support("AveragePool", "int"), + ), + xfail( + "nn.functional.batch_norm", + dtypes=(torch.float16,), + reason="fixme: https://github.com/microsoft/onnxscript/issues/1270", + ), + xfail( + "nn.functional.conv_transpose1d", + dtypes=(torch.int64,), + reason=onnx_test_common.reason_onnx_does_not_support("Conv1d", "int64"), + ), + xfail( + "nn.functional.conv_transpose2d", + dtypes=(torch.int64,), + reason=onnx_test_common.reason_onnx_does_not_support("Conv2d", "int64"), + ), + xfail( + "nn.functional.conv_transpose3d", + dtypes=(torch.int64,), + reason=onnx_test_common.reason_onnx_does_not_support("Conv3d", "int64"), + ), + skip( + "nn.functional.conv_transpose1d", + reason="fixme: Assertion error: result mismatch", + ), + skip( + "nn.functional.conv_transpose2d", + reason="fixme: Assertion error: result mismatch", + ), + skip( + "nn.functional.conv_transpose3d", + reason="fixme: Assertion error: result mismatch", + ), + xfail( + "nn.functional.conv1d", + dtypes=(torch.int64,), + reason=onnx_test_common.reason_onnx_does_not_support("Conv1d", "int64"), + ), + xfail( + "nn.functional.conv2d", + dtypes=(torch.int64,), + reason=onnx_test_common.reason_onnx_does_not_support("Conv2d", "int64"), + ), + xfail( + "nn.functional.conv2d", + dtypes=onnx_test_common.COMPLEX_TYPES, + reason="fixme: Assertion error: result mismatch", + ), + xfail( + "nn.functional.conv3d", + dtypes=(torch.int64,), + reason=onnx_test_common.reason_onnx_does_not_support("Conv3d", "int64"), + ), + xfail( + "nn.functional.conv3d", + dtypes=onnx_test_common.COMPLEX_TYPES, + reason="fixme: Assertion error: result mismatch", + ), + xfail( + "nn.functional.cosine_embedding_loss", + dtypes=onnx_test_common.BOOL_TYPES, + reason=onnx_test_common.reason_onnx_runtime_does_not_support("CosineEmbeddingLoss", "bool"), + ), + xfail( + "nn.functional.ctc_loss", + reason=onnx_test_common.reason_dynamo_does_not_support("aten.ctc_loss.default"), + ), + xfail( + "nn.functional.dropout", + reason=onnx_test_common.reason_dynamo_does_not_support("wrapper_set_seed"), + ), + xfail( + "nn.functional.dropout2d", + reason=onnx_test_common.reason_dynamo_does_not_support("wrapper_set_seed"), + ), + xfail( + "nn.functional.dropout3d", + reason=onnx_test_common.reason_dynamo_does_not_support("wrapper_set_seed"), + ), + xfail( + "nn.functional.feature_alpha_dropout", + variant_name="with_train", + reason=onnx_test_common.reason_dynamo_does_not_support("wrapper_set_seed"), + ), + xfail( + "nn.functional.feature_alpha_dropout", + variant_name="without_train", + reason=onnx_test_common.reason_dynamo_does_not_support("wrapper_set_seed"), + ), + xfail( + "nn.functional.fractional_max_pool2d", + reason=onnx_test_common.reason_dynamo_does_not_support("wrapper_set_seed"), + ), + xfail( + "nn.functional.fractional_max_pool3d", + reason=onnx_test_common.reason_dynamo_does_not_support("wrapper_set_seed"), + ), + xfail( + "nn.functional.gaussian_nll_loss", + reason=onnx_test_common.reason_dynamo_does_not_support("aten.gaussian_nll_loss"), + ), + xfail( + "nn.functional.grid_sample", + reason="fixme: Assertion error: result mismatch", + ), + xfail( + "nn.functional.group_norm", + dtypes=(torch.float16,), + reason=onnx_test_common.reason_onnx_runtime_does_not_support("GroupNormalization", "float16"), + ), + xfail( + "nn.functional.local_response_norm", + dtypes=(torch.int64,), + reason=onnx_test_common.reason_onnx_runtime_does_not_support("avgpool", "int64"), + ), + xfail( + "nn.functional.linear", + dtypes=onnx_test_common.INT_TYPES, + reason=onnx_test_common.reason_onnx_does_not_support("Gemm", "int"), + ), + xfail( + "nn.functional.max_pool2d", + dtypes=onnx_test_common.BOOL_TYPES + onnx_test_common.INT_TYPES, + reason=onnx_test_common.reason_onnx_does_not_support("Max_pool2d"), + ), + xfail( + "nn.functional.max_pool3d", + dtypes=onnx_test_common.BOOL_TYPES + onnx_test_common.INT_TYPES, + reason=onnx_test_common.reason_onnx_does_not_support("Max_pool3d"), + ), + xfail( + "nn.functional.multi_head_attention_forward", + reason=onnx_test_common.reason_dynamo_does_not_support("wrapper_set_seed"), + ), + xfail( + "nn.functional.one_hot", + reason=onnx_test_common.reason_dynamo_does_not_support("data-dependent"), + ), + xfail( + "nn.functional.pad", + variant_name="replicate", + reason="fixme: ORT error: padding size", + ), + xfail( + "nn.functional.pad", + variant_name="replicate_negative", + reason="fixme: Assertion error: result mismatch", + ), + xfail( + "nn.functional.pad", + variant_name="reflect", + reason="fixme: Assertion error: result mismatch", + ), + xfail( + "nn.functional.pixel_shuffle", + dtypes=(torch.int32, torch.int64) + onnx_test_common.BOOL_TYPES, + reason="fixme: ONNX Runtime does not support int32/64 inputs", + ), + xfail( + "nn.functional.pixel_unshuffle", + reason=onnx_test_common.reason_onnx_script_does_not_support("aten.pixel_unshuffle.default"), + ), + xfail( + "nn.functional.poisson_nll_loss", + dtypes=onnx_test_common.BOOL_TYPES + onnx_test_common.INT_TYPES, + reason="fixme: result mismatch with NaN.", + ), + xfail( + "nn.functional.rrelu", + reason=onnx_test_common.reason_dynamo_does_not_support("wrapper_set_seed"), + ), + xfail( + "nn.functional.rrelu", + dtypes=(torch.int64,), + reason=onnx_test_common.reason_onnx_runtime_does_not_support("Relu", "int64"), + ), + skip( + "nn.functional.scaled_dot_product_attention", + matcher=lambda sample: sample.kwargs.get("dropout_p") != 0.0, + reason="dropout is random so the results do not match", + ), + xfail( + "nn.functional.scaled_dot_product_attention", + dtypes=(torch.float16,), + reason="fixme: ORT failed. https://github.com/microsoft/onnxruntime/issues/16438", + ), + xfail( + "nn.functional.selu", + reason="fixme: nn.functional.selu is not in torch._decomp.decomposition_table", + ), + xfail( + "nn.functional.soft_margin_loss", + dtypes=(torch.float16,), + reason="fixme: Assertion error: result mismatch", + ), + xfail( + "nonzero", + dtypes=(torch.int8, torch.int16), + reason=onnx_test_common.reason_onnx_runtime_does_not_support("NonZero", "int8, int16"), + ), + xfail( + "normal", + reason=onnx_test_common.reason_dynamo_does_not_support("wrapper_set_seed"), + ), + xfail( + "normal", + variant_name="in_place", + reason=onnx_test_common.reason_dynamo_does_not_support("wrapper_set_seed"), + ), + xfail( + "normal", + variant_name="number_mean", + reason=onnx_test_common.reason_dynamo_does_not_support("wrapper_set_seed"), + ), + xfail( + "ones", + dtypes=onnx_test_common.COMPLEX_TYPES, + reason="fixme: kwargs dtpye=complex64 is not supported in ONNX." + ), + xfail( + "pca_lowrank", + reason=onnx_test_common.reason_dynamo_does_not_support("wrapper_set_seed"), + ), + xfail( + "quantile", + reason=onnx_test_common.reason_dynamo_does_not_support("aten.equal.default") + ), + xfail( + "rand_like", + reason=onnx_test_common.reason_dynamo_does_not_support("wrapper_set_seed"), + ), + xfail( + "randint", + reason=onnx_test_common.reason_dynamo_does_not_support("wrapper_set_seed"), + ), + xfail( + "randint_like", + reason=onnx_test_common.reason_dynamo_does_not_support("wrapper_set_seed"), + ), + xfail( + "randn", + reason=onnx_test_common.reason_dynamo_does_not_support("wrapper_set_seed"), + ), + xfail( + "randn_like", + reason=onnx_test_common.reason_dynamo_does_not_support("wrapper_set_seed"), + ), + xfail( + "resize_", + reason=onnx_test_common.reason_dynamo_does_not_support("resize_as_") + ), + xfail( + "resize_as_", + reason=onnx_test_common.reason_dynamo_does_not_support("resize_as_") + ), + xfail( + "round", + dtypes=onnx_test_common.INT_TYPES, + reason=onnx_test_common.reason_onnx_runtime_does_not_support("Round", "int"), + ), + xfail( + "rsub", + dtypes=(torch.uint8, torch.int8, torch.int16), + reason=onnx_test_common.reason_onnx_runtime_does_not_support( + "Mul", "uint8, int8, int16" + ), + ), + xfail( + "scatter_add", + dtypes=(torch.float16,), + reason=onnx_test_common.reason_onnx_runtime_does_not_support("ScatterElements reduction=sum", "float16"), + ), + xfail( + "scatter_reduce", + variant_name="sum", + dtypes=(torch.float16,), + reason=onnx_test_common.reason_onnx_runtime_does_not_support("ScatterElements reduction=sum", "float16"), + ), + xfail( + "scatter_reduce", + variant_name="prod", + dtypes=(torch.float16,), + reason=onnx_test_common.reason_onnx_runtime_does_not_support("ScatterElements reduction=prod", "float16"), + ), + xfail( + "scatter_reduce", + variant_name="amin", + dtypes=onnx_test_common.BOOL_TYPES + (torch.float16,), + reason=onnx_test_common.reason_onnx_runtime_does_not_support("ScatterElements reduction=amin", "float16"), + ), + xfail( + "scatter_reduce", + variant_name="amax", + dtypes=onnx_test_common.BOOL_TYPES + (torch.float16,), + reason=onnx_test_common.reason_onnx_runtime_does_not_support("ScatterElements reduction=amax", "float16"), + ), + xfail( + "scatter_reduce", + variant_name="mean", + reason="ONNX doesn't support reduce='mean' option", + ), + xfail( + "sgn", + dtypes=onnx_test_common.BOOL_TYPES, + reason=onnx_test_common.reason_onnx_script_does_not_support("Sign", "bool"), + ), + xfail( + "sign", + dtypes=onnx_test_common.BOOL_TYPES, + reason=onnx_test_common.reason_onnx_script_does_not_support("Sign", "bool"), + ), + xfail( + "signal.windows.kaiser", + reason=onnx_test_common.reason_dynamo_does_not_support("functionalization"), + ), + xfail( + "softmax", + dtypes=(torch.float16,), + reason="ORT error: https://github.com/microsoft/onnxruntime/issues/16438" + ), + xfail( + "sparse.mm", + variant_name="reduce", + reason=onnx_test_common.reason_dynamo_does_not_support("InternalTorchDynamoError: Sparse CSR tensors do not have strides"), + ), + xfail( + "sparse.sampled_addmm", + reason=onnx_test_common.reason_dynamo_does_not_support("InternalTorchDynamoError: Sparse CSR tensors do not have strides"), + ), + xfail( + "special.erfcx", + dtypes=onnx_test_common.INT_TYPES + onnx_test_common.BOOL_TYPES, + reason=onnx_test_common.reason_onnx_runtime_does_not_support("Erf", "int, bool"), + ), + xfail( + "special.erfcx", + dtypes=onnx_test_common.FLOAT_TYPES, + reason=onnx_test_common.reason_onnx_script_does_not_support("Erfcx"), + ), + xfail( + "special.log_ndtr", + dtypes=onnx_test_common.INT_TYPES + onnx_test_common.FLOAT_TYPES, + reason="fixme: Assertion error: result mismatch", + ), + xfail( + "special.ndtr", + dtypes=(torch.float16,), + reason="fixme: Assertion error: result mismatch", + ), + xfail( + "square", + dtypes=(torch.int8, torch.uint8, torch.int16), + reason=onnx_test_common.reason_onnx_runtime_does_not_support("Pow", "int8, uint8, int16"), + ), + xfail( + "squeeze", + variant_name="multiple", + reason="fixme: https://github.com/microsoft/onnxscript/issues/1264", + ), + xfail( + "svd_lowrank", + reason=onnx_test_common.reason_dynamo_does_not_support("wrapper_set_seed"), + ), + xfail( + "stft", + reason=onnx_test_common.reason_dynamo_does_not_support("aten._fft_r2c.default"), + ), + xfail( + "sub", + dtypes=(torch.uint8, torch.int8, torch.int16), + reason=onnx_test_common.reason_onnx_runtime_does_not_support( + "Mul", "uint8, int8, int16" + ), + ), + xfail( + "take", + reason=onnx_test_common.reason_dynamo_does_not_support("data-dependent"), + ), + xfail( + "tensor_split", + reason=onnx_test_common.reason_dynamo_does_not_support("data-dependent"), + ), + xfail( + "topk", + dtypes=(torch.int64, torch.int32, torch.float16), + reason="fixme: Assertion error: result mismatch", + ), + xfail( + "tril", + dtypes=onnx_test_common.BOOL_TYPES + (torch.int32,), + reason=onnx_test_common.reason_onnx_runtime_does_not_support("trilu", "bool, int32"), + ), + xfail( + "triu", + dtypes=onnx_test_common.BOOL_TYPES + (torch.int32,), + reason=onnx_test_common.reason_onnx_runtime_does_not_support("trilu", "bool, int32"), + ), + xfail( + "trunc", + dtypes=onnx_test_common.INT_TYPES, + reason=onnx_test_common.reason_onnx_does_not_support("Floor", "int"), + ), + xfail( + "unflatten", + dtypes=onnx_test_common.BOOL_TYPES, + reason=onnx_test_common.reason_onnx_does_not_support("Unflatten") + ), + xfail( + "uniform", + reason=onnx_test_common.reason_dynamo_does_not_support("wrapper_set_seed"), + ), + xfail( + "unique", + reason=onnx_test_common.reason_dynamo_does_not_support("aten.unique_consecutive.default"), + ), + xfail( + "unique_consecutive", + reason=onnx_test_common.reason_dynamo_does_not_support("aten.unique_consecutive.default"), + ), + xfail( + "unravel_index", + dtypes=onnx_test_common.BOOL_TYPES + onnx_test_common.INT_TYPES, + reason=onnx_test_common.reason_onnx_script_does_not_support("Floor", "bool, int"), + ), + xfail( + "unsqueeze_copy", + reason="OnnxExporterError: Failed to export model", + dtypes=(torch.int8, torch.uint8, torch.int16), + ), + xfail( + "where", + dtypes=onnx_test_common.BOOL_TYPES, + reason=onnx_test_common.reason_onnx_runtime_does_not_support("Where", "bool"), + ), + xfail( + "zeros", + dtypes=onnx_test_common.COMPLEX_TYPES, + reason="fixme: kwargs dtpye=complex64 is not supported in ONNX." + ), + # SLOW TESTS (All are xfails if we run them) + # TODO: https://github.com/pytorch/pytorch/issues/117118 + skip_slow( + "cdist", + reason="fixme: Test sets are too many.", + ), + skip_slow( + "histogram", + reason="fixme: Test sets are too many.", + ), + skip_slow( + "histogramdd", + reason="fixme: Test sets are too many.", + ), + skip_slow( + "linalg.lu_solve", + reason="fixme: Test sets are too many.", + ), + skip_slow( + "linalg.solve_triangular", + reason="fixme: Test sets are too many.", + ), + skip_slow( + "linalg.svd", + reason="fixme: Test sets are too many.", + ), + skip_slow( + "logspace", + reason="fixme: Test sets are too many.", + ), + skip_slow( + "logspace", + variant_name="tensor_overload", + reason="fixme: Test sets are too many.", + ), + skip_slow( + "max_pool2d_with_indices_backward", + reason="fixme: Test sets are too many.", + ), + skip_slow( + "nn.functional.interpolate", + variant_name="bicubic", + reason="fixme: Test sets are too many.", + ), + skip_slow( + "nn.functional.max_unpool1d", + reason="fixme: Test sets are too many.", + ), + skip_slow( + "nn.functional.max_unpool2d", + reason="fixme: Test sets are too many.", + ), + skip_slow( + "nn.functional.max_unpool3d", + reason="fixme: Test sets are too many.", + ), + skip_slow( + "nn.functional.max_pool1d", + reason="fixme: Test sets are too many.", + ), + skip_slow( + "nn.functional.max_pool2d", + reason="fixme: Test sets are too many.", + ), + skip_slow( + "nn.functional.max_pool3d", + reason="fixme: Test sets are too many.", + ), + skip_slow( + "nn.functional.unfold", + reason="fixme: Test sets are too many.", + ), + skip_slow( + "ormqr", + reason="fixme: Test sets are too many.", + ), + skip_slow( + "searchsorted", + reason="fixme: Test sets are too many.", + ), + skip_slow( + "svd", + reason="fixme: Test sets are too many.", + ), +) +# fmt: on + +# NOTE: The xfail and skip with a matcher function or model_type should be +# at under the `SKIP_XFAIL_SUBTESTS_WITH_MATCHER_AND_MODEL_TYPE` section. +SKIP_XFAIL_SUBTESTS_WITH_MATCHER_AND_MODEL_TYPE: tuple[ + onnx_test_common.DecorateMeta, ... +] = ( + skip( + "_native_batch_norm_legit", + model_type=pytorch_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM, + reason="https://github.com/pytorch/pytorch/issues/115106", + ), + skip( + "_batch_norm_with_update", + model_type=pytorch_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM, + reason="https://github.com/pytorch/pytorch/issues/115106", + ), + # TODO: This test currently fails only for certain inputs, e.g. shape([3, 1]). + # Numerically the ONNX program is correct, but the output shapes for `save_mean` + # and `save_var` were tensor(-2.1268) instead of the correct tensor([-2.1268]) + # for example. + skip( + "_batch_norm_with_update", + model_type=pytorch_test_common.TorchModelType.TORCH_NN_MODULE, + reason="not supported yet", + ), + xfail( + "addmm", # xfail can't only use dtypes to catch all cases + matcher=lambda sample: sample.input.dtype + in (torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64), + reason=onnx_test_common.reason_onnx_runtime_does_not_support( + "Gemm", "uint8, int8, int16, int32, int64" + ), + ), + xfail( + "addmm", + matcher=lambda sample: sample.args[0].numel() == 0, + reason="ONNX Runtime does not support empty tensors multiplication", + ), + xfail( + "addmm", + variant_name="decomposed", + matcher=lambda sample: sample.args[0].numel() == 0, + reason="ONNX Runtime does not support empty tensors multiplication", + ), + xfail( + "amax", + matcher=lambda sample: len(sample.input.shape) == 0 + and (sample.kwargs.get("dim") is not None and sample.kwargs.get("dim") != ()), + reason="Op (ReduceMax) [ShapeInferenceError] axis must be in [-rank, rank-1]. input rank was 0", + ), + xfail( + "amin", + matcher=lambda sample: len(sample.input.shape) == 0 + and (sample.kwargs.get("dim") is not None and sample.kwargs.get("dim") != ()), + reason="Op (ReduceMin) [ShapeInferenceError] axis must be in [-rank, rank-1]. input rank was 0", + ), + xfail( + "aminmax", + matcher=lambda sample: len(sample.input.shape) == 0 + and sample.kwargs.get("dim") is not None, + reason="Op (ReduceMin) [ShapeInferenceError] axis must be in [-rank, rank-1]. input rank was 0", + ), + skip( + "cat", + matcher=lambda sample: sample.input[0].equal(torch.tensor([])), + reason="core dump - cat does not support zero-dim tensors yet", + ), + xfail( + "index_add", + matcher=lambda sample: len(sample.input.shape) == 0, + reason=onnx_test_common.reason_onnx_runtime_does_not_support( + "ScatterND", "0-D tensor" + ), + ), + xfail( + "index_add", + matcher=lambda sample: isinstance(sample.args[0], int) and sample.args[0] == -1, + reason="fixme: aten::index_put indices contains None when dim is -1", + ), + xfail( + "index_copy", + matcher=lambda sample: len(sample.input.shape) == 0, + reason=onnx_test_common.reason_onnx_runtime_does_not_support( + "ScatterND", "0-D tensor" + ), + ), + xfail( + "index_copy", + matcher=lambda sample: isinstance(sample.args[0], int) and sample.args[0] == -1, + reason="fixme: aten::index_put indices contains None when dim is -1", + ), + xfail( + "index_put", + matcher=lambda sample: (sample.args[0][0].dtype == torch.bool) + and (sample.kwargs.get("accumulate") is False), + reason=onnx_test_common.reason_dynamo_does_not_support( + "https://github.com/pytorch/pytorch/issues/101150" + ), + ), + skip( + "linalg.multi_dot", + matcher=lambda sample: sum(torch.numel(input) for input in sample.input) == 0, + reason="fixme: Undefined", + ), + skip( + "log_softmax", + matcher=lambda sample: len(sample.input.shape) == 0, + reason="fixme: LogSoftMax does not support empty tensor as input", + ), + skip( + "log_softmax", + variant_name="with_dtype", + matcher=lambda sample: len(sample.input.shape) == 0, + reason="fixme: LogSoftMax does not support empty tensor as input", + ), + skip( + "masked.log_softmax", + matcher=lambda sample: len(sample.input.shape) == 0, + reason="fixme: LogSoftMax does not support empty tensor as input", + ), + skip( + "matmul", + matcher=lambda sample: torch.numel(sample.input) == 0, + reason="values of matmul of [m, 0] and [0, n] matrices are undefined", + ), + skip( + "mm", + matcher=lambda sample: torch.numel(sample.input) == 0, + reason="values of matmul of [m, 0] and [0, n] matrices are undefined", + ), + xfail( + "native_batch_norm", + matcher=lambda sample: sample.args[-3] is True + and any(arg is not None for arg in sample.args[2:4]), + model_type=pytorch_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM, + reason="https://github.com/pytorch/pytorch/issues/115106", + ), + xfail( + "nn.functional.avg_pool1d", + matcher=lambda sample: (sample.kwargs.get("ceil_mode") is True) + and ( + sample.kwargs.get("count_include_pad") is True + or sample.input.shape[2] + % ( + sample.args[0][0] + if isinstance(sample.args[0], tuple) + else sample.args[0] + ) + != 0 + ), + reason="fixme: ORT doesn't match PyTorch when ceil_mode=True until opset 19", + ), + xfail( + "nn.functional.avg_pool2d", + matcher=lambda sample: (len(sample.args) > 5 and sample.args[5] is not None) + or (sample.kwargs.get("divisor_override") is not None), + reason="ONNX doesn't support divisor_override argument", + ), + xfail( + "nn.functional.avg_pool3d", + matcher=lambda sample: sample.kwargs.get("ceil_mode") is True, + reason="fixme: ORT doesn't match PyTorch when ceil_mode=True until opset 19", + ), + xfail( + "nn.functional.avg_pool3d", + matcher=lambda sample: (len(sample.args) > 5 and sample.args[5] is not None) + or (sample.kwargs.get("divisor_override") is not None), + reason="ONNX doesn't support divisor_override argument", + ), + xfail( + "nn.functional.batch_norm", + matcher=lambda sample: sample.kwargs.get("training") is True + and any(arg is not None for arg in sample.args[2:4]), + reason="Flaky failure: https://github.com/pytorch/pytorch/issues/115106", + ), + xfail( + "nn.functional.conv2d", + matcher=lambda sample: sample.kwargs.get("padding") == "valid", + reason="fixme: https://github.com/pytorch/pytorch/issues/117054", + ), + xfail( + "nn.functional.conv3d", + matcher=lambda sample: sample.kwargs.get("padding") == "valid", + reason="fixme: https://github.com/pytorch/pytorch/issues/117054", + ), + skip( + "nn.functional.cross_entropy", + matcher=lambda sample: not isinstance(sample.kwargs.get("weight"), int), + reason="ONNX SoftmaxCrossEntropyLoss op only accept argument[weight] is int type", + ), + xfail( + "nn.functional.embedding", + matcher=lambda sample: sample.kwargs.get("max_norm") is not None, + model_type=pytorch_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM, + reason="https://github.com/pytorch/pytorch/issues/115106", + ), + skip_torchlib_forward_compatibility( + "nn.functional.embedding_bag", + matcher=lambda sample: sample.kwargs.get("padding_idx") is not None or True, + reason=onnx_test_common.reason_onnx_script_does_not_support( + "'padding_idx' overload for _embedding_bag and _embedding_bag_forward_only. " + "'padding_idx=-1' is emitted for aten op when 'padding_idx' is not provided" + ), + github_issue="https://github.com/microsoft/onnxscript/issues/1056", + ), + xfail( + "nn.functional.group_norm", + matcher=lambda sample: torch.numel(sample.input) == 0, + reason=onnx_test_common.reason_onnx_runtime_does_not_support( + "Reshape", "empty tensor" + ), + ), + xfail( + "nn.functional.instance_norm", + model_type=pytorch_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM, + matcher=lambda sample: sample.kwargs.get("running_mean") is not None, + reason="fixme: KeyError: 'self___kwargs__running_mean'", + ), + xfail( + "nn.functional.max_pool3d", + matcher=lambda sample: sample.kwargs.get("ceil_mode") is True + and sample.kwargs.get("padding") == 1, + reason="FIXME: After https://github.com/microsoft/onnxruntime/issues/15446 is fixed", + ), + xfail( + "nn.functional.pixel_shuffle", + matcher=lambda sample: sample.input.numel() == 0, + reason="fixme: ORT does not support empty tensor as input", + ), + xfail( + "nonzero", + matcher=lambda sample: len(sample.input.shape) == 0 + and sample.kwargs.get("as_tuple", False) is False, + reason="Output 'shape' do not match: torch.Size([0, 1]) != torch.Size([0, 0]).", + model_type=pytorch_test_common.TorchModelType.TORCH_NN_MODULE, + ), + xfail( + "scatter_add", + matcher=lambda sample: len(sample.input.shape) == 0, + reason="fixme: Rank(0) input will lead ORT failed due to different rank(result) in if-else branch", + ), + skip( + "scatter_reduce", + variant_name="amax", + # ONNX has not include_self parameter and default is include_self=True mode + matcher=lambda sample: sample.kwargs.get("include_self") is False, + reason="ONNX does't support include_self=False option", + ), + skip( + "scatter_reduce", + variant_name="amin", + # ONNX has not include_self parameter and default is include_self=True mode + matcher=lambda sample: sample.kwargs.get("include_self") is False, + reason="ONNX does't support include_self=False option", + ), + skip( + "scatter_reduce", + variant_name="prod", + # ONNX has not include_self parameter and default is include_self=True mode + matcher=lambda sample: sample.kwargs.get("include_self") is False, + reason="ONNX does't support include_self=False option", + ), + skip( + "scatter_reduce", + variant_name="sum", + # ONNX has not include_self parameter and default is include_self=True mode + matcher=lambda sample: sample.kwargs.get("include_self") is False, + reason="ONNX does't support include_self=False option", + ), + skip( + "softmax", + matcher=lambda sample: len(sample.input.shape) == 0, + reason="fixme: LogSoftMax does not support empty tensor as input", + ), + xfail( + "unflatten", + reason="Logic not implemented for size 0 inputs in op.Reshape", + matcher=lambda sample: any(dim == 0 for dim in sample.input.shape), + ), + skip( + "signal.windows.hamming", + model_type=pytorch_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM, + reason="does not match node name", + ), + skip( + "signal.windows.general_hamming", + model_type=pytorch_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM, + reason="does not match node name", + ), + skip( + "signal.windows.blackman", + model_type=pytorch_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM, + reason="does not match node name", + ), + skip( + "signal.windows.general_cosine", + model_type=pytorch_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM, + reason="does not match node name", + ), + skip( + "signal.windows.hann", + model_type=pytorch_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM, + reason="does not match node name", + ), + skip( + "signal.windows.nuttall", + model_type=pytorch_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM, + reason="does not match node name", + ), +) + +OPS_DB = copy.deepcopy(common_methods_invocations.op_db) +OP_WITH_SKIPPED_XFAIL_SUBTESTS = frozenset( + meta.op_name for meta in SKIP_XFAIL_SUBTESTS_WITH_MATCHER_AND_MODEL_TYPE +) +ALL_OPS_IN_DB = frozenset(op_info.name for op_info in OPS_DB) + + +def _torch_size_flatten_spec(d: List[Any], spec: Any) -> List[Any]: + return [d[i] for i in range(spec.num_children)] + + +torch.fx._pytree.register_pytree_flatten_spec( + torch.Size, + _torch_size_flatten_spec, +) + + +class SingleOpModel(torch.nn.Module): + """Test model to wrap around a single op for export.""" + + def __init__(self, op, kwargs): + super().__init__() + self.operator = op + self.kwargs = kwargs + + def forward(self, *args): + return self.operator(*args, **self.kwargs) + + +def _should_skip_xfail_test_sample( + op_name: str, + variant_test_name: str, + sample, + model_type: pytorch_test_common.TorchModelType, +) -> Tuple[Optional[str], Optional[str]]: + """Check if the test sample should be skipped or xfailed. + + If the xfail/skip decorator meta is matched with its op_name and model_type, + return the test_behavior and reason. Otherwise, return None, None. Note that + if the matcher is None, the test is decorator_meta is meant to skip/xfail all model types. + + Args: + op_name: The name of the op. + sample: The test sample. + model_type: The model type of the test. + + Returns: + A tuple of (test_behavior, reason). test_behavior is either "skip" or "xfail". + reason is the reason for the test_behavior. + """ + + if op_name not in OP_WITH_SKIPPED_XFAIL_SUBTESTS: + return None, None + for decorator_meta in SKIP_XFAIL_SUBTESTS_WITH_MATCHER_AND_MODEL_TYPE: + # Linear search on ops_test_data.SKIP_XFAIL_SUBTESTS_WITH_MATCHER_AND_MODEL_TYPE. That's fine because the list is small. + # NOTE: If model_type is None, the test is decorator_meta is meant to skip/xfail all model types. + if ( + decorator_meta.op_name == op_name + and decorator_meta.variant_name == variant_test_name + ) and ( + model_type == decorator_meta.model_type or decorator_meta.model_type is None + ): + if decorator_meta.matcher is None and decorator_meta.model_type is None: + raise TypeError( + "Either Matcher or model_type must be defined in sub xfail and skip." + ) + if decorator_meta.matcher is not None and decorator_meta.matcher(sample): + return decorator_meta.test_behavior, decorator_meta.reason + elif decorator_meta.matcher is None: + # xfail/skip the whole test of the model type without matcher + return decorator_meta.test_behavior, decorator_meta.reason + return None, None + + +def _compare_onnx_and_torch_exported_program( + torch_exported_program, + onnx_exported_program, + input_args, + input_kwargs=None, + test_name=None, + sample_num=None, + sample_kwargs=None, + rtol=1e-03, + atol=1e-07, + only_check_shape=False, +): + # avoid mutable default argument + if input_kwargs is None: + input_kwargs = {} + + # NOTE: ONNXProgram holds a reference (not copy) to the original ref_model, including its state_dict. + # Thus, ONNXProgram() must run before ref_model() to prevent ref_model.forward() from changing the state_dict. + # Otherwise, the ref_model can change buffers on state_dict which would be used by ONNXProgram.__call__() + onnx_outputs = onnx_exported_program(*input_args, **input_kwargs) + if isinstance(torch_exported_program, torch.export.ExportedProgram): + torch_outputs = torch_exported_program.module()(*input_args, **input_kwargs) + else: + torch_outputs = torch_exported_program(*input_args, **input_kwargs) + torch_outputs_onnx_format = onnx_exported_program.adapt_torch_outputs_to_onnx( + torch_outputs + ) + if len(torch_outputs_onnx_format) != len(onnx_outputs): + raise AssertionError( + f"Expected {len(torch_outputs_onnx_format)} outputs, got {len(onnx_outputs)}" + ) + + for j, (torch_output, onnx_output) in enumerate( + zip(torch_outputs_onnx_format, onnx_outputs) + ): + if only_check_shape: + assert torch_output.shape == onnx_output.shape + else: + try: + torch.testing.assert_close( + torch.tensor(onnx_output), + torch_output, + rtol=rtol, + atol=atol, + equal_nan=True, + ) + except AssertionError as e: + if os.environ.get("CREATE_REPRODUCTION_REPORT") == "1": + error_reproduction.create_mismatch_report( + test_name, + sample_num, + onnx_exported_program.model_proto, + input_args, + sample_kwargs, + torch.tensor(onnx_output), + torch_output, + e, + ) + if len(torch_outputs_onnx_format) > 1: + raise AssertionError(f"Output {j} mismatch") from e + raise + + +def _run_test_output_match( + test_suite: onnx_test_common._TestONNXRuntime, + device: str, + dtype: torch.dtype, + op: opinfo_core.OpInfo, +): + # device is provided by instantiate_device_type_tests, but we only want to run in cpu. + assert device == "cpu" + samples = op.sample_inputs( + device, + dtype, + requires_grad=False, + ) + for i, cpu_sample in enumerate(samples): + inputs = (cpu_sample.input, *cpu_sample.args) + # Provide the repr to subtest because tensors are not serializable in parallel test runs + + with test_suite.subTest( + opset=test_suite.opset_version, + sample_num=i, + inputs=repr(inputs), + kwargs=repr(cpu_sample.kwargs), + ): + test_behavior, reason = _should_skip_xfail_test_sample( + op.name, op.variant_test_name, cpu_sample, test_suite.model_type + ) + with onnx_test_common.normal_xfail_skip_test_behaviors( + test_behavior, reason + ): + model = SingleOpModel(op.op, cpu_sample.kwargs) + model.eval() + + if ( + dtype == torch.float32 + and op.name in test_suite.fp32_low_precision_dict + ): + rtol = test_suite.fp32_low_precision_dict[op.name][0] + atol = test_suite.fp32_low_precision_dict[op.name][1] + elif dtype == torch.float32: + # Relax atol and rtol for float32 based on empirical results + rtol = 1e-5 + atol = 2e-5 + elif ( + dtype == torch.float16 + and (op.name, op.variant_test_name) + in test_suite.fp16_low_precision_variant_dict + ): + rtol = test_suite.fp16_low_precision_variant_dict[ + (op.name, op.variant_test_name) + ][0] + atol = test_suite.fp16_low_precision_variant_dict[ + (op.name, op.variant_test_name) + ][1] + elif ( + dtype == torch.float16 + and op.name in test_suite.fp16_low_precision_dict + ): + rtol = test_suite.fp16_low_precision_dict[op.name][0] + atol = test_suite.fp16_low_precision_dict[op.name][1] + else: + rtol = None + atol = None + + if ( + test_suite.model_type + == pytorch_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM + ): + try: + # TODO (tugsbayasgalan) Migrate to pre-dispatch IR + # BUG1: python test/onnx/test_fx_op_consistency.py -k test_output_match_triu_cpu_int32 + # has unexpected success, but don't know how to remove from xfail list + # BUG2: User output to_sparse is not in the correct order or is not found in the + # exported program's user_output list (https://github.com/pytorch/pytorch/issues/124328) + # python test/onnx/test_fx_op_consistency.py -k test_output_match_to_sparse_cpu_float32 + # BUG3: [ShapeInferenceError] Inference error(s): (op_type:aten_view, node name: aten_view_4): + # [ShapeInferenceError] + # Inference error(s): (op_type:Reshape, node name: n1): [ShapeInferenceError] Invalid position of 0. + # python test/onnx/test_fx_op_consistency.py -k test_output_match_stack_cpu_int32 + from torch.export import _trace + + model = _trace._export(model, inputs, pre_dispatch=False) + + except AssertionError as e: + # NOTE: avoid fake_mode detection bug in torch.export.export + pytest.xfail( + onnx_test_common.reason_dynamo_does_not_support(str(e)) + ) + + try: + onnx_program = torch.onnx.dynamo_export( + model, + *inputs, + ) + except torch.onnx.OnnxExporterError as e: + # NOTE: If the model has unsupported nodes, we will skip the test + # with non-strict xfail. Otherwise, we will raise the error. + if hasattr( + e.__cause__, "diagnostic" + ) and e.__cause__.diagnostic.rule in ( + _rules._POERules.no_symbolic_function_for_call_function, + _rules._POERules.unsupported_fx_node_analysis, + ): + pytest.xfail( + onnx_test_common.reason_onnx_script_does_not_support(str(e)) + ) + else: + raise e + _compare_onnx_and_torch_exported_program( + model, + onnx_program, + inputs, + test_name=test_suite.id(), + sample_num=i, + sample_kwargs=cpu_sample.kwargs, + rtol=rtol, + atol=atol, + only_check_shape=(op.name in test_suite.only_shape_check_list), + ) + + +def _parameterized_class_attrs_and_values(): + input_values = [] + input_values.extend( + itertools.product( + (opset for opset in onnx_test_common.FX_TESTED_OPSETS), + ( + pytorch_test_common.TorchModelType.TORCH_NN_MODULE, + pytorch_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM, + ), + ) + ) + return { + "attrs": ["opset_version", "model_type"], + "input_values": input_values, + } + + +def _parameterize_class_name(cls: Type, idx: int, input_dicts: Mapping[Any, Any]): + """Combine class name with the parameterized arguments. + + This function is passed to `parameterized.parameterized_class` as the + `class_name_func` argument. + """ + suffixes = [] + for k, v in input_dicts.items(): + suffixes.append(f"{k}_{v}") + return f"{cls.__name__}_{'_'.join(suffixes)}" + + +@parameterized.parameterized_class( + **_parameterized_class_attrs_and_values(), + class_name_func=_parameterize_class_name, +) +class TestOnnxModelOutputConsistency(onnx_test_common._TestONNXRuntime): + """Test output consistency between exported ONNX models and PyTorch eager mode. + + This is a parameterized test suite. + """ + + opset_version = -1 + dynamic_shapes: bool = False + model_type: pytorch_test_common.TorchModelType = ( + pytorch_test_common.TorchModelType.TORCH_NN_MODULE + ) + + # NOTE: Follow torchlib settings in ops_test_data.py + only_shape_check_list = [ + "empty", + "empty_like", + "empty_strided", + "new_empty", + "new_empty_strided", + ] + + fp32_low_precision_dict = { + "native_layer_norm": [2e-4, 7e-4], + } + + fp16_low_precision_dict = { + "addbmm": [2e-1, 2e-2], + "addcdiv": [3e-2, 1.4e-3], + "addcmul": [3e-2, 1e-3], + "addmv": [5e-2, 3e-2], + "addr": [3e-3, 4e-3], + "baddbmm": [3e-2, 1e-3], + "cumulative_trapezoid": [3e-2, 1e-3], + "cross": [3e-2, 2e-2], + "diff": [1e-2, 5e-2], + "div": [5e-3, 1e-3], + "gradient": [3e-3, 4e-3], + "linalg.cross": [1e-3, 2e-2], + "linalg.multi_dot": [3e-2, 1e-3], + "linalg.vecdot": [1e-2, 2e-2], + "linspace": [2e-2, 2e-3], + "masked.std": [2e-2, 2e-3], + "masked.var": [2e-2, 2e-2], + "matmul": [2e-2, 6e-2], + "mv": [9e-3, 1e-5], + "nn.functional.batch_norm": [3e-2, 1e-3], + "nn.functional.binary_cross_entropy": [3e-2, 1e-3], + "nn.functional.binary_cross_entropy_with_logits": [4e-2, 4e-3], + "nn.functional.cosine_similarity": [3e-2, 1e-3], + "nn.functional.cosine_embedding_loss": [1e-2, 1e-3], + "nn.functional.hardsigmoid": [1e-3, 5e-3], + "nn.functional.hardswish": [1e-3, 5e-3], + "nn.functional.hinge_embedding_loss": [4e-1, 3e-3], + "nn.functional.huber_loss": [1e-2, 1e-1], + "nn.functional.instance_norm": [1e-2, 1e-3], + "nn.functional.interpolate": [1e-2, 1e-3], + "nn.functional.kl_div": [2e-3, 2e-4], + "nn.functional.multilabel_soft_margin_loss": [4e-2, 5e-3], + "nn.functional.local_response_norm": [1e-2, 5e-3], + "nn.functional.poisson_nll_loss": [4e-2, 6e-3], + "nn.functional.nll_loss": [3e-2, 1e-3], + "nn.functional.triplet_margin_loss": [2e-2, 1e-2], + "nn.functional.triplet_margin_with_distance_loss": [3e-2, 1e-2], + "native_batch_norm": [3e-2, 1e-3], + "norm": [1e-2, 1e-2], + "dot": [3e-2, 1e-3], + "logit": [3e-2, 1e-3], + "rsub": [3e-2, 1e-3], + "sinc": [2e-1, 6e-4], + "sub": [3e-2, 1e-3], + "trapezoid": [1e-3, 7e-3], + "trapz": [1e-3, 7e-3], + "vdot": [1e-3, 1e-2], + } + + fp16_low_precision_variant_dict = { + ("nn.functional.interpolate", "trilinear"): [3e-2, 3e-3], + ("nn.functional.interpolate", "linear"): [3e-2, 3e-3], + } + + @common_device_type.ops( + [op for op in OPS_DB if op.name in ALL_OPS_IN_DB], + allowed_dtypes=onnx_test_common.TESTED_DTYPES, + ) + def test_output_match(self, device: str, dtype: torch.dtype, op): + """Test the ONNX exporter.""" + _run_test_output_match(self, device, dtype, op) + + +for opset in onnx_test_common.FX_TESTED_OPSETS: + for model_type in pytorch_test_common.TorchModelType: + # The name needs to match the parameterized_class name. + test_class_name = f"TestOnnxModelOutputConsistency_opset_version_{opset}_model_type_TorchModelType.{model_type.name}" + onnx_test_common.add_decorate_info( + OPS_DB, + test_class_name, + "test_output_match", + opset=opset, + skip_or_xfails=EXPECTED_SKIPS_OR_FAILS_WITH_DTYPES, + ) + + common_device_type.instantiate_device_type_tests( + globals()[test_class_name], globals(), only_for="cpu" + ) + +if __name__ == "__main__": + common_utils.run_tests() diff --git a/test/onnx/test_fx_to_onnx.py b/test/onnx/test_fx_to_onnx.py index 5de711f181f4d3..04f132abc9c9aa 100644 --- a/test/onnx/test_fx_to_onnx.py +++ b/test/onnx/test_fx_to_onnx.py @@ -200,6 +200,34 @@ def forward(self, input): expected_node="aten.convolution.default", ) + def test_dispatch_overload_fall_back_default_raise_diagnostic_warning(self): + class TraceModel(torch.nn.Module): + def forward(self, input): + return torch.ops.aten.add.Tensor(input, input) + + onnx_registry = torch.onnx.OnnxRegistry() + self.assertTrue( + onnx_registry.is_registered_op( + namespace="aten", op_name="add", overload="Tensor" + ) + ) + + aten_add_Tensor = registration.OpName.from_name_parts( + namespace="aten", op_name="add", overload="Tensor" + ) + onnx_registry._registry.pop(aten_add_Tensor) + + x = torch.tensor(3) + onnx_program = dynamo_export( + TraceModel(), x, export_options=ExportOptions(onnx_registry=onnx_registry) + ) + assert_has_diagnostics( + onnx_program.diagnostic_context, + diagnostics.rules.find_operator_overloads_in_onnx_registry, + diagnostics.levels.WARNING, + expected_node="aten.add.Tensor", + ) + def test_aten_clone_does_not_raise_warning_of_lack_of_memory_format(self): class CustomModule(torch.nn.Module): def forward(self, input): diff --git a/test/onnx/test_fx_to_onnx_with_onnxruntime.py b/test/onnx/test_fx_to_onnx_with_onnxruntime.py index dd53ded23cdbe3..150edfef3412d1 100644 --- a/test/onnx/test_fx_to_onnx_with_onnxruntime.py +++ b/test/onnx/test_fx_to_onnx_with_onnxruntime.py @@ -478,9 +478,6 @@ def forward(self, x): MutationModel(), (torch.randn(12),), has_mutation=True ) - @unittest.skip( - "Fixme: arange in torchlib does not support dynamic start and end yet." - ) def test_arange(self): class ArangeModel(torch.nn.Module): def forward(self, input): diff --git a/torch/onnx/_internal/_exporter_legacy.py b/torch/onnx/_internal/_exporter_legacy.py index c66c3492e98ba4..52f37460592cd4 100644 --- a/torch/onnx/_internal/_exporter_legacy.py +++ b/torch/onnx/_internal/_exporter_legacy.py @@ -46,8 +46,11 @@ import onnx - import onnxruntime - import onnxscript + import onnxruntime # type: ignore[import] + import onnxscript # type: ignore[import] + from onnxscript.function_libs.torch_lib import ( # type: ignore[import] + registration as torchlib_registry, + ) from torch._subclasses import fake_tensor from torch.onnx._internal.fx import diagnostics @@ -107,6 +110,10 @@ def __init__(self) -> None: self._registry: dict[registration.OpName, list[registration.ONNXFunction]] = ( defaultdict(list) ) + # FIXME: Avoid importing onnxscript into torch + from onnxscript.function_libs.torch_lib import ( # type: ignore[import] # noqa: F401 + registration, + ) # opset_version is unused for now, since torchlib only supports opset18. # TODO: get opset version from torchlib @@ -116,7 +123,8 @@ def __init__(self) -> None: "different opset version, please register them with register_custom_op." ) - self._initiate_registry_from_torchlib() + # Initialize registry from torchlib + self._initiate_registry_from_torchlib(registration.default_registry) @property def opset_version(self) -> int: @@ -126,25 +134,33 @@ def opset_version(self) -> int: return self._opset_version - def _initiate_registry_from_torchlib(self) -> None: + def _initiate_registry_from_torchlib( + self, torchlib_registry: torchlib_registry.Registry + ): """Populates the registry with ATen functions from torchlib. Args: torchlib_registry: The torchlib registry to use for populating the registry. """ - import onnxscript._framework_apis.torch_2_5 as onnxscript_apis - - for meta in onnxscript_apis.get_torchlib_ops(): - internal_name_instance = registration.OpName.from_qualified_name( - meta.qualified_name - ) - symbolic_function = registration.ONNXFunction( - onnx_function=meta.function, # type: ignore[arg-type] - op_full_name=internal_name_instance.qualified_name(), - is_custom=False, - is_complex=meta.is_complex, - ) - self._register(internal_name_instance, symbolic_function) + for aten_name, aten_overloads_func in torchlib_registry.items(): + internal_name_instance = registration.OpName.from_qualified_name(aten_name) + for overload_func in aten_overloads_func.overloads: + symbolic_function = registration.ONNXFunction( + onnx_function=overload_func, + op_full_name=internal_name_instance.qualified_name(), + is_custom=False, + is_complex=False, + ) + self._register(internal_name_instance, symbolic_function) + + for complex_func in aten_overloads_func.complex: + symbolic_function = registration.ONNXFunction( + onnx_function=complex_func, + op_full_name=internal_name_instance.qualified_name(), + is_custom=False, + is_complex=True, + ) + self._register(internal_name_instance, symbolic_function) def _register( self, From 69b954613089d0a38777fbbff7c0d1bd467408af Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Tue, 3 Sep 2024 16:30:07 +0000 Subject: [PATCH 0355/1018] [ONNX] Bump onnxscript version in CI; temporarily remove op test (#133748) Bump onnxscript version in CI to 0.1.0.dev20240831, and temporarily remove the fx consistency test. We will add a better version back later. Pull Request resolved: https://github.com/pytorch/pytorch/pull/133748 Approved by: https://github.com/titaiwangms --- .ci/docker/common/install_onnx.sh | 7 +- test/onnx/dynamo/test_registry_dispatcher.py | 223 +- test/onnx/test_fx_op_consistency.py | 2076 ----------------- test/onnx/test_fx_to_onnx.py | 28 - test/onnx/test_fx_to_onnx_with_onnxruntime.py | 3 + torch/onnx/_internal/_exporter_legacy.py | 50 +- 6 files changed, 24 insertions(+), 2363 deletions(-) delete mode 100644 test/onnx/test_fx_op_consistency.py diff --git a/.ci/docker/common/install_onnx.sh b/.ci/docker/common/install_onnx.sh index 65907a99f2572a..b99ecac283d646 100755 --- a/.ci/docker/common/install_onnx.sh +++ b/.ci/docker/common/install_onnx.sh @@ -30,10 +30,9 @@ pip_install \ pip_install coloredlogs packaging -pip_install onnxruntime==1.18 -pip_install onnx==1.16.0 -# pip_install "onnxscript@git+https://github.com/microsoft/onnxscript@3e869ef8ccf19b5ebd21c10d3e9c267c9a9fa729" --no-deps -pip_install onnxscript==0.1.0.dev20240613 --no-deps +pip_install onnxruntime==1.18.1 +pip_install onnx==1.16.2 +pip_install onnxscript==0.1.0.dev20240831 --no-deps # required by onnxscript pip_install ml_dtypes diff --git a/test/onnx/dynamo/test_registry_dispatcher.py b/test/onnx/dynamo/test_registry_dispatcher.py index d782f94dd311d9..8605b098d30055 100644 --- a/test/onnx/dynamo/test_registry_dispatcher.py +++ b/test/onnx/dynamo/test_registry_dispatcher.py @@ -8,18 +8,11 @@ import onnxscript # type: ignore[import] from onnxscript import BFLOAT16, DOUBLE, FLOAT, FLOAT16 # type: ignore[import] -from onnxscript.function_libs.torch_lib import ops # type: ignore[import] from onnxscript.onnx_opset import opset15 as op # type: ignore[import] import torch import torch.fx -from torch.onnx._internal.diagnostics import infra -from torch.onnx._internal.fx import ( - analysis, - diagnostics, - onnxfunction_dispatcher, - registration, -) +from torch.onnx._internal.fx import diagnostics, onnxfunction_dispatcher, registration from torch.testing._internal import common_utils @@ -84,60 +77,6 @@ def test_custom(x, y): [test_original, test_custom], ) - def test_unsupported_nodes_analysis_with_missing_aten_op(self): - # NOTE: simulate unsupported nodes - aten_mul_tensor = registration.OpName.from_name_parts( - namespace="aten", op_name="mul", overload="Tensor" - ) - aten_mul_default = registration.OpName.from_name_parts( - namespace="aten", op_name="mul" - ) - aten_add_tensor = registration.OpName.from_name_parts( - namespace="aten", op_name="add", overload="Tensor" - ) - aten_add_default = registration.OpName.from_name_parts( - namespace="aten", op_name="add" - ) - - self.registry._registry.pop(aten_mul_tensor) - self.registry._registry.pop(aten_mul_default) - self.registry._registry.pop(aten_add_tensor) - self.registry._registry.pop(aten_add_default) - - diagnostic_context = diagnostics.DiagnosticContext( - "torch.onnx.dynamo_export", torch.__version__ - ) - dispatcher = onnxfunction_dispatcher.OnnxFunctionDispatcher( - self.registry, diagnostic_context - ) - - graph: torch.fx.Graph = torch.fx.Graph() - x: torch.fx.Node = graph.create_node("placeholder", "x") - x.meta["val"] = torch.tensor(3.0) - b: torch.fx.Node = graph.create_node( - "call_function", target=torch.ops.aten.mul.Tensor, args=(x, x) - ) - c: torch.fx.Node = graph.create_node( - "call_function", target=torch.ops.aten.add.Tensor, args=(b, b) - ) - output: torch.fx.Node = graph.output(c) - module = torch.fx.GraphModule(torch.nn.Module(), graph) - - with self.assertRaises(infra.RuntimeErrorWithDiagnostic): - analysis.UnsupportedFxNodesAnalysis( - diagnostic_context, module, dispatcher - ).analyze(infra.levels.ERROR) - - try: - analysis.UnsupportedFxNodesAnalysis( - diagnostic_context, module, dispatcher - ).analyze(infra.levels.ERROR) - except infra.RuntimeErrorWithDiagnostic as e: - self.assertIn( - "Unsupported FX nodes: {'call_function': ['aten.mul.Tensor', 'aten.add.Tensor']}.", - e.diagnostic.message, - ) - @common_utils.instantiate_parametrized_tests class TestDispatcher(common_utils.TestCase): @@ -488,165 +427,5 @@ def test_first_custom_op( self.assertEqual(symbolic_fn, test_third_custom_op) -@common_utils.instantiate_parametrized_tests -class TestOpSchemaWrapper(common_utils.TestCase): - def setUp(self): - # overload type: optional dtype - self.onnx_function_new_full = ops.core.aten_new_full - self.onnx_function_new_full_dtype = ops.core.aten_new_full_dtype - - @common_utils.parametrize( - "inputs, attributes, assertion", - [ - common_utils.subtest( - ([torch.randn(3, 4), torch.randn(3, 4)], {"alpha": 2.0}, True), - name="perfect_match_with_kwargs", - ), - common_utils.subtest( - (["A", "B"], {}, False), - name="non_perfect_match_due_to_non_tensor_inputs", - ), - common_utils.subtest( - ([torch.randn(3, 4), torch.randn(3, 4), torch.randn(3, 4)], {}, False), - name="non_perfect_match_due_to_too_many_inputs", - ), - common_utils.subtest( - ([torch.randn(3, 4), torch.randn(3, 4)], {"wrong_kwargs": 2.0}, False), - name="non_perfect_match_due_to_wrong_kwargs", - ), - ], - ) - def test_perfect_match_inputs(self, inputs, attributes, assertion): - # OnnxFunction with default attributes - dummy_diagnostic = diagnostics.Diagnostic( - rule=diagnostics.rules.find_opschema_matched_symbolic_function, - level=diagnostics.levels.WARNING, - ) - op_schema_wrapper_add = onnxfunction_dispatcher._OnnxSchemaChecker( - ops.core.aten_add - ) - self.assertEqual( - op_schema_wrapper_add.perfect_match_inputs( - dummy_diagnostic, inputs, attributes - ), - assertion, - ) - - @common_utils.parametrize( - "inputs, kwargs, op, score", - [ - common_utils.subtest( - ([torch.randn(3, 4), torch.randn(3, 4)], {}, ops.core.aten_mul, 2), - name="match_2_inputs", - ), - common_utils.subtest( - ( - [ - torch.randint(0, 2, size=(3, 4), dtype=torch.int).bool(), - torch.randint(0, 2, size=(3, 4), dtype=torch.int).bool(), - ], - {}, - ops.core.aten_mul, - 0, - ), - name="match_0_inputs", - ), - common_utils.subtest( - ([torch.randn(3, 4), torch.randn(3, 4)], {}, ops.core.aten_mul_bool, 0), - name="match_0_inputs_bool", - ), - common_utils.subtest( - ( - [ - torch.randint(0, 2, size=(3, 4), dtype=torch.int).bool(), - torch.randint(0, 2, size=(3, 4), dtype=torch.int).bool(), - ], - {}, - ops.core.aten_mul_bool, - 2, - ), - name="match_2_inputs_bool", - ), - ], - ) - def test_matching_score_system_on_overload_dtypes(self, inputs, kwargs, op, score): - op_schema_wrapper = onnxfunction_dispatcher._OnnxSchemaChecker(op) - op_schema_wrapper._record_matching_score(inputs, kwargs) - self.assertEqual(op_schema_wrapper.match_score, score) - - @common_utils.parametrize( - "inputs, kwargs, op, score", - [ - common_utils.subtest( - ([torch.randn(3, 4), torch.tensor(3)], {}, ops.core.aten_new_full, 2), - name="match_2_inputs", - ), - common_utils.subtest( - ( - [torch.randn(3, 4), torch.tensor(3)], - {"dtype": 2}, # at this point, dtype should be converted to int - ops.core.aten_new_full_dtype, - 2, - ), - name="match_2_input_and_match_1_kwargs_optional", - ), - ], - ) - def test_matching_score_system_on_optional_dtypes(self, inputs, kwargs, op, score): - op_schema_wrapper = onnxfunction_dispatcher._OnnxSchemaChecker(op) - op_schema_wrapper._record_matching_score(inputs, kwargs) - self.assertEqual(op_schema_wrapper.match_score, score) - - @common_utils.parametrize( - "value, expected_onnx_str_dtype", - [ - common_utils.subtest( - (1, {"tensor(int64)", "tensor(int16)", "tensor(int32)"}), - name="all_ints", - ), - common_utils.subtest( - (1.0, {"tensor(float)", "tensor(double)", "tensor(float16)"}), - name="all_floats", - ), - common_utils.subtest( - (torch.tensor([True]), {"tensor(bool)"}), - name="bool", - ), - common_utils.subtest( - (torch.tensor([1], dtype=torch.int64), {"tensor(int64)"}), - name="int64", - ), - common_utils.subtest( - (torch.tensor([1], dtype=torch.int32), {"tensor(int32)"}), - name="int32", - ), - common_utils.subtest( - (torch.tensor([1], dtype=torch.int16), {"tensor(int16)"}), - name="int16", - ), - common_utils.subtest( - (torch.tensor([1], dtype=torch.float), {"tensor(float)"}), - name="float", - ), - common_utils.subtest( - (torch.tensor([1], dtype=torch.float16), {"tensor(float16)"}), - name="float16", - ), - common_utils.subtest( - (torch.tensor([1], dtype=torch.double), {"tensor(double)"}), - name="double", - ), - common_utils.subtest((None, set()), name="None"), # None allows no dtype - common_utils.subtest( - ([], set()), name="empaty_list" - ), # Empty list allows no dtype - ], - ) - def test_find_onnx_data_type(self, value, expected_onnx_str_dtype): - self.assertEqual( - onnxfunction_dispatcher._find_onnx_data_type(value), expected_onnx_str_dtype - ) - - if __name__ == "__main__": common_utils.run_tests() diff --git a/test/onnx/test_fx_op_consistency.py b/test/onnx/test_fx_op_consistency.py deleted file mode 100644 index 9972ea901722c7..00000000000000 --- a/test/onnx/test_fx_op_consistency.py +++ /dev/null @@ -1,2076 +0,0 @@ -# Owner(s): ["module: onnx"] - -"""Test consistency between the output values of torch.onnx FX exported operators -and torch operators given the same inputs. - -Usage: - - 1. Test all operators: - - pytest test/onnx/test_fx_op_consistency.py - - 2. To run tests on a specific operator (e.g. torch.ceil): - - pytest test/onnx/test_fx_op_consistency.py -k ceil - pytest test/onnx/test_fx_op_consistency.py -k nn_functional_scaled_dot_product_attention - - 3. Set `CREATE_REPRODUCTION_REPORT=1` to create markdown files for reproduction of errors. E.g. - - CREATE_REPRODUCTION_REPORT=1 python -m pytest test/onnx/test_fx_op_consistency.py -k div_mode_int - - NOTE: Read more on Running and writing tests: - https://github.com/pytorch/pytorch/wiki/Running-and-writing-tests - -Note: - - 1. Please make sure pytest-subtests is installed. Otherwise, the sub-tests will be ignored. - - 2. Install pytest-xdist to run tests in parallel if runng all tests is the goal. - - 3. When new ops are supported, please scroll down to modify the EXPECTED_SKIPS_OR_FAILS_WITH_DTYPES and - TESTED_OPS lists. See "Modify this section" - -""" - -from __future__ import annotations - -import copy -import itertools -import os -from typing import ( - Any, - Callable, - Collection, - List, - Mapping, - Optional, - Tuple, - Type, - TYPE_CHECKING, - Union, -) - -import error_reproduction -import onnx_test_common -import parameterized -import pytest -import pytorch_test_common -from onnx_test_common import skip, skip_slow, xfail - -import torch -from torch.onnx._internal.diagnostics import _rules -from torch.testing._internal import ( - common_device_type, - common_methods_invocations, - common_utils, -) - - -if TYPE_CHECKING: - from torch.testing._internal.opinfo import core as opinfo_core - - -# NOTE: For ATen signature modifications that will break ONNX export, -# use **xfail_torchlib_forward_compatibility** and **skip_torchlib_forward_compatibility** instead of xfail or skip -# to make the signal apparent for maintainers. -def xfail_torchlib_forward_compatibility( - op_name: str, - variant_name: str = "", - *, - reason: str, - github_issue: str, - opsets: Optional[Collection[Union[int, Callable[[int], bool]]]] = None, - dtypes: Optional[Collection[torch.dtype]] = None, - matcher: Optional[Callable[[Any], bool]] = None, - enabled_if: bool = True, -): - """Prefer using this (xfail) over skip when possible. - - Only skip when the test is not failing consistently. - """ - return xfail( - op_name, - variant_name=variant_name, - reason=f"{reason}. GitHub Issue: {github_issue}", - opsets=opsets, - dtypes=dtypes, - matcher=matcher, - enabled_if=enabled_if, - ) - - -def skip_torchlib_forward_compatibility( - op_name: str, - variant_name: str = "", - *, - reason: str, - github_issue: str, - opsets: Optional[Collection[Union[int, Callable[[int], bool]]]] = None, - dtypes: Optional[Collection[torch.dtype]] = None, - matcher: Optional[Callable[[Any], Any]] = None, - enabled_if: bool = True, -): - """Prefer using xfail_torchlib_forward_compatibility over this (skip) when possible. - - Only skip when the test is not failing consistently. - """ - return skip( - op_name, - variant_name=variant_name, - reason=f"{reason}. GitHub Issue: {github_issue}", - opsets=opsets, - dtypes=dtypes, - matcher=matcher, - enabled_if=enabled_if, - ) - - -# fmt: off -# Turn off black formatting to keep the list compact - -# Expected failures for onnx export. -# The list should be sorted alphabetically by op name. -# Q: When should I use fixme vs vs skip vs xfail? -# A: Prefer xfail over skip when possible. -# 2a. If a test is now failing because of xpass, because some previous errors -# are now fixed, removed the corresponding xfail. -# 2b. If a test is not failing consistently, use skip. -# NOTE: EXPECTED_SKIPS_OR_FAILS_WITH_DTYPES only supports dtypes. If a matcher or model_type -# is needed, use the SKIP_XFAIL_SUBTESTS_WITH_MATCHER_AND_MODEL_TYPE list further down below. -EXPECTED_SKIPS_OR_FAILS_WITH_DTYPES: Tuple[onnx_test_common.DecorateMeta, ...] = ( - xfail( - "__getitem__", - reason="io_adaper doesn't support __getitem__ input slice(0, 3, None)", - ), - xfail( - "__radd__", - dtypes=onnx_test_common.BOOL_TYPES, - reason=onnx_test_common.reason_onnx_script_does_not_support("Add", "bool"), - ), - xfail( - "__rmatmul__", - dtypes=(torch.float16,), - reason="fixme: Assertion error: result mismatch", - ), - xfail( - "__rpow__", - dtypes=onnx_test_common.INT_TYPES, - reason=onnx_test_common.reason_onnx_does_not_support("Pow", "int"), - ), - skip( - "_native_batch_norm_legit", - reason=onnx_test_common.reason_onnx_script_does_not_support("cpu is not supported: \ - https://github.com/microsoft/onnxscript/pull/1289") - ), - skip( - "_batch_norm_with_update", - dtypes=(torch.float16,), - reason="fixme: Assertion error: result mismatch and type error", - ), - xfail( - "_softmax_backward_data", - dtypes=(torch.float16,), - reason="fixme: Assertion error: result mismatch", - ), - xfail( - "_unsafe_masked_index", - dtypes=onnx_test_common.BOOL_TYPES, - reason=onnx_test_common.reason_onnx_runtime_does_not_support("Where", "bool"), - ), - xfail( - "_unsafe_masked_index", - dtypes=onnx_test_common.COMPLEX_TYPES, - reason=onnx_test_common.reason_onnx_runtime_does_not_support("_unsafe_masked_index", "complex64"), - ), - xfail( - "_unsafe_masked_index_put_accumulate", - reason="fixme: Status Message: updates tensor should have shape equal to " - "indices.shape[:-1] + data.shape[indices.shape[-1]:]", - ), - xfail( - "_unsafe_masked_index_put_accumulate", - dtypes=onnx_test_common.BOOL_TYPES, - reason=onnx_test_common.reason_onnx_runtime_does_not_support("Where", "bool"), - ), - xfail( - "add", dtypes=onnx_test_common.BOOL_TYPES, - reason=onnx_test_common.reason_onnx_does_not_support("Add") - ), - xfail( - "add", - dtypes=(torch.uint8, torch.int8, torch.int16,), - reason=onnx_test_common.reason_onnx_script_does_not_support( - "Add", "int8, int16, uint8 have type issue." - ), - ), - xfail( - "addbmm", - dtypes=onnx_test_common.COMPLEX_TYPES, - reason=onnx_test_common.reason_dynamo_does_not_support("Addbmm", "complex64") - ), - xfail( - "addmm", dtypes=onnx_test_common.BOOL_TYPES, - reason=onnx_test_common.reason_onnx_does_not_support("Addmm") - ), - xfail( - "addmm", - variant_name="decomposed", - dtypes=onnx_test_common.BOOL_TYPES + onnx_test_common.INT_TYPES, - reason=onnx_test_common.reason_onnx_does_not_support("Addmm") - ), - skip( - "addmm", dtypes=onnx_test_common.COMPLEX_TYPES, - reason=onnx_test_common.reason_dynamo_does_not_support("Addmm", "complex64 (core dump)") - ), - skip( - "addmm", - variant_name="decomposed", - dtypes=onnx_test_common.COMPLEX_TYPES, - reason=onnx_test_common.reason_dynamo_does_not_support("Addmm", "complex64 (core dump)") - ), - xfail( - "addr", - dtypes=onnx_test_common.BOOL_TYPES, - reason=onnx_test_common.reason_onnx_script_does_not_support( - "Addr", "bool" - ), - ), - xfail( - "addr", - dtypes=onnx_test_common.COMPLEX_TYPES, - reason=onnx_test_common.reason_dynamo_does_not_support("Addr", "complex64") - ), - xfail( - "alias_copy", - dtypes=(torch.int8, torch.uint8, torch.int16, torch.float64), - reason="OnnxExporterError: Failed to export model", - ), - xfail( - "allclose", - reason=onnx_test_common.reason_dynamo_does_not_support("Allclose") - ), - xfail( - "amax", - dtypes=(torch.int16, *onnx_test_common.BOOL_TYPES), - reason=onnx_test_common.reason_onnx_does_not_support("ReduceMin", "bool, int16"), - ), - xfail( - "amin", dtypes=(torch.int16, *onnx_test_common.BOOL_TYPES), - reason=onnx_test_common.reason_dynamo_does_not_support("ReduceMin", "bool, int16") - ), - xfail( - "aminmax", - dtypes=(torch.int16, *onnx_test_common.BOOL_TYPES), - reason=onnx_test_common.reason_onnx_does_not_support("ReduceMin", "bool, int16"), - ), - xfail( - "arange", - dtypes=(torch.uint8,), - reason=onnx_test_common.reason_onnx_script_does_not_support("Arange", "uint8, int8"), - ), - xfail( - "arange", - dtypes=(torch.int16, torch.int32), - reason="AssertionError: The values for attribute 'shape' do not match", - ), - xfail( - "argmax", - dtypes=( - torch.int16, - torch.int64, - ), - reason=onnx_test_common.reason_onnx_runtime_does_not_support( - "ArgMax", "int16, int64" - ), - ), - xfail( - "argmin", - dtypes=( - torch.uint8, - torch.int8, - torch.int16, - torch.int64, - ), - reason=onnx_test_common.reason_onnx_runtime_does_not_support( - "ArgMin", "uint8, int8, int16, int64" - ), - ), - xfail( - "argwhere", - reason="fixme: Assertion error: result mismatch", - ), - skip( - "as_strided", - variant_name="partial_views", - reason="ONNX doesn't have partial view for tensor; [PostInline][ORT] segfaults", - ), - xfail( - "atan2", - dtypes=onnx_test_common.BOOL_TYPES + onnx_test_common.INT_TYPES, - reason="fixme: Assertion error: result mismatch", - ), - xfail( - "baddbmm", - dtypes=( - torch.uint8, - torch.int8, - torch.int16, - ), - reason=onnx_test_common.reason_onnx_runtime_does_not_support( - "Matmul", "uint8, int8, int16" - ), - ), - xfail( - "baddbmm", - dtypes=onnx_test_common.COMPLEX_TYPES, - reason=onnx_test_common.reason_dynamo_does_not_support("baddbmm", "complex64") - ), - xfail( - "bernoulli", - reason=onnx_test_common.reason_dynamo_does_not_support("wrapper_set_seed"), - ), - xfail( - "bfloat16", - reason="fixme: ORT errors with RuntimeError: No corresponding Numpy type for Tensor Type.", - ), - xfail( - "bincount", - reason=onnx_test_common.reason_dynamo_does_not_support("aten.bincount.default"), - ), - xfail( - "block_diag", - dtypes=onnx_test_common.COMPLEX_TYPES, - reason=onnx_test_common.reason_onnx_runtime_does_not_support("Block_diag", "complex"), - ), - xfail( - "bmm", - dtypes=( - torch.uint8, - torch.int8, - torch.int16, - ), - reason=onnx_test_common.reason_onnx_runtime_does_not_support( - "Matmul", "uint8, int8, int16" - ), - ), - xfail( - "broadcast_shapes", - reason=onnx_test_common.reason_dynamo_does_not_support("output is int"), - ), - xfail( - "cauchy", - reason=onnx_test_common.reason_dynamo_does_not_support("wrapper_set_seed"), - ), - skip( - "ceil", dtypes=onnx_test_common.BOOL_TYPES + onnx_test_common.INT_TYPES, - reason=onnx_test_common.reason_onnx_does_not_support("Ceil", "bool and int") - ), - xfail( - "chalf", - reason="fixme: ONNX shape type inference error: Invalid tensor data type 0." - ), - xfail( - "chunk", - dtypes=(torch.uint8, torch.int8, torch.int16,), - reason=onnx_test_common.reason_onnx_runtime_does_not_support( - "Chunk", "uint8, int8, int16" - ), - ), - xfail( - "clamp", - dtypes=(torch.uint8, torch.int8, torch.int16,), - reason=onnx_test_common.reason_onnx_runtime_does_not_support( - "Max", "uint8, int8, int16" - ), - ), - xfail( - "clamp_max", dtypes=onnx_test_common.BOOL_TYPES, - reason=onnx_test_common.reason_onnx_script_does_not_support("Clamp_max", "bool") - ), - xfail( - "clamp_max", - dtypes=(torch.uint8, torch.int8, torch.int16,), - reason=onnx_test_common.reason_onnx_runtime_does_not_support( - "Max", "uint8, int8, int16" - ), - ), - xfail( - "clamp_min", - dtypes=(torch.uint8, torch.int8, torch.int16,), - reason=onnx_test_common.reason_onnx_runtime_does_not_support( - "Max", "uint8, int8, int16" - ), - ), - xfail( - "clamp_min", dtypes=onnx_test_common.BOOL_TYPES, - reason=onnx_test_common.reason_onnx_script_does_not_support("Clamp_min", "bool") - ), - xfail( - "constant_pad_nd", - dtypes=(torch.int16,), - reason=onnx_test_common.reason_onnx_runtime_does_not_support( - "Constant_pad_nd", "int16" - ), - ), - xfail( - "constant_pad_nd", - dtypes=onnx_test_common.COMPLEX_TYPES, - reason=onnx_test_common.reason_dynamo_does_not_support( - "Constant_pad_nd", "complex64" - ), - ), - xfail( - "corrcoef", - reason=onnx_test_common.reason_dynamo_does_not_support( - "aten.equal.default" - ), - ), - xfail( - "cov", - reason=onnx_test_common.reason_dynamo_does_not_support( - "aten.equal.default" - ), - ), - xfail( - "cumsum", dtypes=onnx_test_common.BOOL_TYPES + (torch.uint8, torch.int8, torch.int16,), - reason=onnx_test_common.reason_onnx_does_not_support("Cumsum", "bool, uint8, int8, int16") - ), - xfail( - "combinations", - reason=onnx_test_common.reason_dynamo_does_not_support("aten.masked.select"), - ), - xfail( - "diag", - dtypes=onnx_test_common.BOOL_TYPES, - reason=onnx_test_common.reason_onnx_runtime_does_not_support("Diagonal", "bool"), - ), - xfail( - "diagonal_copy", - dtypes=onnx_test_common.BOOL_TYPES, - reason=onnx_test_common.reason_onnx_runtime_does_not_support("Diagonal", "bool"), - ), - xfail( - "dot", dtypes=(torch.uint8, torch.int8, torch.int16,), - reason=onnx_test_common.reason_onnx_does_not_support("MatMul", "uint8, int8, int16") - ), - skip( - "dot", - dtypes=onnx_test_common.COMPLEX_TYPES, - reason=onnx_test_common.reason_dynamo_does_not_support("Dot", "complex64(core dump)"), - ), - xfail( - "empty", - dtypes=onnx_test_common.COMPLEX_TYPES, - reason="fixme: kwargs dtpye=complex64 is not supported in ONNX." - ), - xfail( - "empty_strided", - reason=onnx_test_common.reason_dynamo_does_not_support("wrapper_set_seed"), - ), - xfail( - "eq", - dtypes=(torch.uint8, torch.int8, torch.int16,), - reason=onnx_test_common.reason_onnx_runtime_does_not_support("Equal", "uint8, int8, int16"), - ), - xfail( - "equal", - reason=onnx_test_common.reason_dynamo_does_not_support("aten.equal.default") - ), - xfail( - "exponential", - reason=onnx_test_common.reason_dynamo_does_not_support("exponential"), - ), - xfail( - "fft.fft", - reason="fixme: Assertion error: result mismatch", - ), - xfail( - "fft.fft2", - reason="fixme: Assertion error: result mismatch", - ), - xfail( - "fft.fftn", - reason="fixme: Assertion error: result mismatch", - ), - xfail( - "fft.ifft", - reason="fixme: Assertion error: result mismatch", - ), - xfail( - "fft.ifft2", - reason="fixme: Assertion error: result mismatch", - ), - xfail( - "fft.ifftn", - reason="fixme: Assertion error: result mismatch", - ), - xfail( - "fft.irfft", - reason="fixme: Assertion error: result mismatch", - ), - xfail( - "fft.irfft2", - reason="fixme: Assertion error: result mismatch", - ), - xfail( - "fft.irfftn", - reason=onnx_test_common.reason_onnx_script_does_not_support("aten._fft_r2c.default"), - ), - xfail( - "fft.rfft", - reason=onnx_test_common.reason_onnx_script_does_not_support("aten._fft_r2c.default"), - ), - xfail( - "fft.rfftn", - reason=onnx_test_common.reason_onnx_script_does_not_support("aten._fft_r2c.default"), - ), - xfail( - "fft.rfft2", - reason=onnx_test_common.reason_onnx_script_does_not_support("aten._fft_r2c.default"), - ), - xfail( - "floor", - dtypes=onnx_test_common.BOOL_TYPES + onnx_test_common.INT_TYPES, - reason=onnx_test_common.reason_onnx_does_not_support("Floor", "bool, int"), - ), - xfail( - "floor_divide", - dtypes=onnx_test_common.BOOL_TYPES + onnx_test_common.INT_TYPES, - reason=onnx_test_common.reason_onnx_does_not_support("Floor", "bool, int"), - ), - xfail( - "full", - dtypes=onnx_test_common.COMPLEX_TYPES, - reason=onnx_test_common.reason_dynamo_does_not_support("full", "complex64") - ), - xfail( - "full_like", - dtypes=onnx_test_common.COMPLEX_TYPES, - reason=onnx_test_common.reason_dynamo_does_not_support("full_like", "complex64") - ), - xfail( - "gather", - reason="GatherElements op: Rank of input 'data' needs to be equal to rank of input 'indices'" - ), - xfail( - "geometric", - reason=onnx_test_common.reason_dynamo_does_not_support("wrapper_set_seed"), - ), - xfail( - "heaviside", - dtypes=onnx_test_common.BOOL_TYPES, - reason=onnx_test_common.reason_onnx_script_does_not_support("Heaviside", "bool"), - ), - xfail( - "index_add", - dtypes=(torch.float16,), - reason=onnx_test_common.reason_onnx_runtime_does_not_support("ScatterND", "int64, int32, bool"), - ), - xfail( - "index_fill", - dtypes=onnx_test_common.COMPLEX_TYPES, - reason=onnx_test_common.reason_dynamo_does_not_support("index_fill", "complex64") - ), - xfail( - "index_fill", - dtypes=onnx_test_common.INT_TYPES + onnx_test_common.BOOL_TYPES + onnx_test_common.FLOAT_TYPES, - reason="fixme: Constant input list has None. ONNXScript does not support None in constant list." - ), - xfail( - "index_put", - dtypes=onnx_test_common.BOOL_TYPES + (torch.float16,), - reason=onnx_test_common.reason_onnx_script_does_not_support("index_put", "bool"), - ), - xfail( - "index_put", - dtypes=(torch.uint8, torch.int8, torch.int16,), - reason=onnx_test_common.reason_onnx_script_does_not_support("Add", "int8, int16"), - ), - xfail( - "index_put", - dtypes=(torch.float16,), - reason=onnx_test_common.reason_onnx_runtime_does_not_support("ScatterND", "float16"), - ), - xfail( - "isnan", - dtypes=onnx_test_common.INT_TYPES + onnx_test_common.BOOL_TYPES, - reason=onnx_test_common.reason_onnx_does_not_support("IsNaN", "int, bool"), - ), - xfail( - "istft", - reason=onnx_test_common.reason_dynamo_does_not_support("data-dependent"), - ), - xfail( - "item", - reason=onnx_test_common.reason_dynamo_does_not_support("wrapper_set_seed"), - ), - xfail( - "lerp", - dtypes=onnx_test_common.COMPLEX_TYPES, - reason=onnx_test_common.reason_dynamo_does_not_support("lerp", "complex64") - ), - xfail( - "linalg.lstsq", - reason=onnx_test_common.reason_dynamo_does_not_support("aten.linalg_lstsq.default"), - ), - xfail( - "linalg.lstsq", - variant_name="grad_oriented", - reason=onnx_test_common.reason_dynamo_does_not_support("aten.linalg_lstsq.default"), - ), - xfail( - "linalg.matrix_power", - reason="fixme: The values for attribute 'shape' do not match: torch.Size([2, 2]) != torch.Size([2, 2, 2])." - ), - xfail( - "linalg.norm", - reason="fixme: Assertion error: result mismatch", - ), - xfail( - "linalg.norm", - variant_name="subgradients_at_zero", - reason="fixme: Assertion error: result mismatch", - ), - xfail( - "linalg.vecdot", - reason="fixme: Assertion error: result shape mismatch", - ), - xfail( - "linspace", - dtypes=(torch.int64, torch.int32,), - reason="fixme: Results do not match with PyTorch. https://github.com/microsoft/onnxscript/issues/854", - ), - xfail( - "linspace", - variant_name="tensor_overload", - dtypes=(torch.int64, torch.int32,), - reason="fixme: Results do not match with PyTorch. https://github.com/microsoft/onnxscript/issues/854", - ), - xfail( - "linspace", - dtypes=onnx_test_common.COMPLEX_TYPES, - reason=onnx_test_common.reason_onnx_script_does_not_support("linspace", "complex64") - ), - xfail( - "linspace", - variant_name="tensor_overload", - dtypes=onnx_test_common.COMPLEX_TYPES, - reason=onnx_test_common.reason_onnx_script_does_not_support("linspace", "complex64") - ), - xfail( - "log_normal", - reason=onnx_test_common.reason_dynamo_does_not_support("wrapper_set_seed"), - ), - xfail( - "log_softmax", - dtypes=(torch.float16,), - reason="fixme: ORT optimizer error: https://github.com/microsoft/onnxruntime/issues/16438", - ), - xfail( - "log_softmax", - variant_name="with_dtype", - dtypes=(torch.float16,), - reason="fixme: ORT optimizer error: https://github.com/microsoft/onnxruntime/issues/16438", - ), - xfail( - "logical_and", - dtypes=onnx_test_common.FLOAT_TYPES + onnx_test_common.INT_TYPES, - reason=onnx_test_common.reason_onnx_script_does_not_support("And", "float, int"), - ), - xfail( - "logical_not", - dtypes=onnx_test_common.FLOAT_TYPES + onnx_test_common.INT_TYPES, - reason=onnx_test_common.reason_onnx_script_does_not_support("Not", "float, int"), - ), - xfail( - "logical_or", - dtypes=onnx_test_common.FLOAT_TYPES + onnx_test_common.INT_TYPES, - reason=onnx_test_common.reason_onnx_script_does_not_support("Or", "float, int"), - ), - xfail( - "logical_xor", - dtypes=onnx_test_common.FLOAT_TYPES + onnx_test_common.INT_TYPES, - reason=onnx_test_common.reason_onnx_script_does_not_support("Xor", "float, int"), - ), - skip( - "masked.logsumexp", - reason="fixme: https://github.com/onnx/onnx/issues/4986", - ), - xfail( - "masked.amax", - reason="fixme: ORT optimizer error: https://github.com/microsoft/onnxruntime/issues/16438", - ), - xfail( - "masked.amin", - reason="fixme: ORT optimizer error: https://github.com/microsoft/onnxruntime/issues/16438", - ), - xfail( - "masked.argmin", - dtypes=onnx_test_common.BOOL_TYPES + onnx_test_common.FLOAT_TYPES + (torch.int64,), - reason="fixme: Assertion error: result mismatch", - ), - xfail( - "masked.argmax", - dtypes=onnx_test_common.BOOL_TYPES + onnx_test_common.FLOAT_TYPES + (torch.int64,), - reason="fixme: Assertion error: result mismatch", - ), - xfail( - "masked_fill", - dtypes=onnx_test_common.BOOL_TYPES, - reason=onnx_test_common.reason_onnx_runtime_does_not_support("Where", "bool"), - ), - xfail( - "masked.sum", - dtypes=onnx_test_common.BOOL_TYPES, - reason=onnx_test_common.reason_onnx_runtime_does_not_support("Where", "bool"), - ), - xfail( - "masked.log_softmax", - dtypes=(torch.float16,), - reason="fixme: ORT optimizer error: https://github.com/microsoft/onnxruntime/issues/16438", - ), - xfail( - "masked.mean", - dtypes=onnx_test_common.BOOL_TYPES, - reason=onnx_test_common.reason_onnx_does_not_support("ReduceMean", "bool"), - ), - xfail( - "masked.norm", - reason="fixme: Assertion error: result mismatch", - ), - xfail( - "masked.prod", - dtypes=onnx_test_common.BOOL_TYPES, - reason=onnx_test_common.reason_onnx_runtime_does_not_support("Where", "bool"), - ), - xfail( - "masked_select", - reason=onnx_test_common.reason_dynamo_does_not_support("aten.masked_select.default"), - ), - xfail( - "max", - variant_name="reduction_no_dim", - dtypes=onnx_test_common.BOOL_TYPES, - reason=onnx_test_common.reason_onnx_runtime_does_not_support("ReduceMax", "bool"), - ), - xfail( - "max", - variant_name="reduction_with_dim", - dtypes=onnx_test_common.BOOL_TYPES, - reason=onnx_test_common.reason_onnx_runtime_does_not_support("ReduceMax", "bool"), - ), - xfail( - "max", - variant_name="reduction_with_dim", - dtypes=(torch.int64,), - reason="https://github.com/onnx/onnx/issues/4986", - ), - xfail( - "min", - variant_name="reduction_no_dim", - dtypes=onnx_test_common.BOOL_TYPES, - reason=onnx_test_common.reason_onnx_runtime_does_not_support("ReduceMin", "bool"), - ), - xfail( - "min", - variant_name="reduction_with_dim", - dtypes=onnx_test_common.BOOL_TYPES + (torch.int64,), - reason=onnx_test_common.reason_onnx_runtime_does_not_support("ReduceMin", "bool"), - ), - skip( - "mm", - dtypes=onnx_test_common.COMPLEX_TYPES, - reason=onnx_test_common.reason_dynamo_does_not_support("MM", "complex64(core dump)"), - ), - xfail( - "multinomial", - reason=onnx_test_common.reason_dynamo_does_not_support("wrapper_set_seed"), - ), - xfail( - "nanquantile", - reason=onnx_test_common.reason_dynamo_does_not_support("aten.equal.default") - ), - xfail( - "nansum", - dtypes=onnx_test_common.INT_TYPES + onnx_test_common.BOOL_TYPES, - reason=onnx_test_common.reason_onnx_runtime_does_not_support("IsNaN", "int, bool"), - ), - xfail( - "narrow", - reason=onnx_test_common.reason_dynamo_does_not_support("data-dependent"), - ), - skip( - "native_batch_norm", - reason=onnx_test_common.reason_onnx_script_does_not_support("cpu is not supported: \ - https://github.com/microsoft/onnxscript/pull/1289") - ), - xfail( - "native_layer_norm", - dtypes=(torch.float16,), - reason="fixme: ORT optimizer error: https://github.com/microsoft/onnxruntime/issues/16438", - ), - xfail( - "new_full", - dtypes=onnx_test_common.COMPLEX_TYPES, - reason=onnx_test_common.reason_dynamo_does_not_support("new_full", "complex64") - ), - xfail( - "nn.functional.adaptive_avg_pool2d", - reason=onnx_test_common.reason_onnx_script_does_not_support("RecursionError: \ - maximum recursion depth exceeded while calling a Python object"), - ), - xfail( - "nn.functional.adaptive_avg_pool3d", - reason=onnx_test_common.reason_onnx_script_does_not_support("aten._adaptive_avg_pool3d.default"), - ), - xfail( - "nn.functional.alpha_dropout", - reason=onnx_test_common.reason_dynamo_does_not_support("wrapper_set_seed"), - ), - xfail( - "nn.functional.avg_pool1d", - dtypes=onnx_test_common.INT_TYPES, - reason=onnx_test_common.reason_onnx_does_not_support("AveragePool", "int"), - ), - xfail( - "nn.functional.avg_pool2d", - dtypes=onnx_test_common.INT_TYPES, - reason=onnx_test_common.reason_onnx_does_not_support("AveragePool", "int"), - ), - xfail( - "nn.functional.avg_pool3d", - dtypes=onnx_test_common.INT_TYPES, - reason=onnx_test_common.reason_onnx_does_not_support("AveragePool", "int"), - ), - xfail( - "nn.functional.batch_norm", - dtypes=(torch.float16,), - reason="fixme: https://github.com/microsoft/onnxscript/issues/1270", - ), - xfail( - "nn.functional.conv_transpose1d", - dtypes=(torch.int64,), - reason=onnx_test_common.reason_onnx_does_not_support("Conv1d", "int64"), - ), - xfail( - "nn.functional.conv_transpose2d", - dtypes=(torch.int64,), - reason=onnx_test_common.reason_onnx_does_not_support("Conv2d", "int64"), - ), - xfail( - "nn.functional.conv_transpose3d", - dtypes=(torch.int64,), - reason=onnx_test_common.reason_onnx_does_not_support("Conv3d", "int64"), - ), - skip( - "nn.functional.conv_transpose1d", - reason="fixme: Assertion error: result mismatch", - ), - skip( - "nn.functional.conv_transpose2d", - reason="fixme: Assertion error: result mismatch", - ), - skip( - "nn.functional.conv_transpose3d", - reason="fixme: Assertion error: result mismatch", - ), - xfail( - "nn.functional.conv1d", - dtypes=(torch.int64,), - reason=onnx_test_common.reason_onnx_does_not_support("Conv1d", "int64"), - ), - xfail( - "nn.functional.conv2d", - dtypes=(torch.int64,), - reason=onnx_test_common.reason_onnx_does_not_support("Conv2d", "int64"), - ), - xfail( - "nn.functional.conv2d", - dtypes=onnx_test_common.COMPLEX_TYPES, - reason="fixme: Assertion error: result mismatch", - ), - xfail( - "nn.functional.conv3d", - dtypes=(torch.int64,), - reason=onnx_test_common.reason_onnx_does_not_support("Conv3d", "int64"), - ), - xfail( - "nn.functional.conv3d", - dtypes=onnx_test_common.COMPLEX_TYPES, - reason="fixme: Assertion error: result mismatch", - ), - xfail( - "nn.functional.cosine_embedding_loss", - dtypes=onnx_test_common.BOOL_TYPES, - reason=onnx_test_common.reason_onnx_runtime_does_not_support("CosineEmbeddingLoss", "bool"), - ), - xfail( - "nn.functional.ctc_loss", - reason=onnx_test_common.reason_dynamo_does_not_support("aten.ctc_loss.default"), - ), - xfail( - "nn.functional.dropout", - reason=onnx_test_common.reason_dynamo_does_not_support("wrapper_set_seed"), - ), - xfail( - "nn.functional.dropout2d", - reason=onnx_test_common.reason_dynamo_does_not_support("wrapper_set_seed"), - ), - xfail( - "nn.functional.dropout3d", - reason=onnx_test_common.reason_dynamo_does_not_support("wrapper_set_seed"), - ), - xfail( - "nn.functional.feature_alpha_dropout", - variant_name="with_train", - reason=onnx_test_common.reason_dynamo_does_not_support("wrapper_set_seed"), - ), - xfail( - "nn.functional.feature_alpha_dropout", - variant_name="without_train", - reason=onnx_test_common.reason_dynamo_does_not_support("wrapper_set_seed"), - ), - xfail( - "nn.functional.fractional_max_pool2d", - reason=onnx_test_common.reason_dynamo_does_not_support("wrapper_set_seed"), - ), - xfail( - "nn.functional.fractional_max_pool3d", - reason=onnx_test_common.reason_dynamo_does_not_support("wrapper_set_seed"), - ), - xfail( - "nn.functional.gaussian_nll_loss", - reason=onnx_test_common.reason_dynamo_does_not_support("aten.gaussian_nll_loss"), - ), - xfail( - "nn.functional.grid_sample", - reason="fixme: Assertion error: result mismatch", - ), - xfail( - "nn.functional.group_norm", - dtypes=(torch.float16,), - reason=onnx_test_common.reason_onnx_runtime_does_not_support("GroupNormalization", "float16"), - ), - xfail( - "nn.functional.local_response_norm", - dtypes=(torch.int64,), - reason=onnx_test_common.reason_onnx_runtime_does_not_support("avgpool", "int64"), - ), - xfail( - "nn.functional.linear", - dtypes=onnx_test_common.INT_TYPES, - reason=onnx_test_common.reason_onnx_does_not_support("Gemm", "int"), - ), - xfail( - "nn.functional.max_pool2d", - dtypes=onnx_test_common.BOOL_TYPES + onnx_test_common.INT_TYPES, - reason=onnx_test_common.reason_onnx_does_not_support("Max_pool2d"), - ), - xfail( - "nn.functional.max_pool3d", - dtypes=onnx_test_common.BOOL_TYPES + onnx_test_common.INT_TYPES, - reason=onnx_test_common.reason_onnx_does_not_support("Max_pool3d"), - ), - xfail( - "nn.functional.multi_head_attention_forward", - reason=onnx_test_common.reason_dynamo_does_not_support("wrapper_set_seed"), - ), - xfail( - "nn.functional.one_hot", - reason=onnx_test_common.reason_dynamo_does_not_support("data-dependent"), - ), - xfail( - "nn.functional.pad", - variant_name="replicate", - reason="fixme: ORT error: padding size", - ), - xfail( - "nn.functional.pad", - variant_name="replicate_negative", - reason="fixme: Assertion error: result mismatch", - ), - xfail( - "nn.functional.pad", - variant_name="reflect", - reason="fixme: Assertion error: result mismatch", - ), - xfail( - "nn.functional.pixel_shuffle", - dtypes=(torch.int32, torch.int64) + onnx_test_common.BOOL_TYPES, - reason="fixme: ONNX Runtime does not support int32/64 inputs", - ), - xfail( - "nn.functional.pixel_unshuffle", - reason=onnx_test_common.reason_onnx_script_does_not_support("aten.pixel_unshuffle.default"), - ), - xfail( - "nn.functional.poisson_nll_loss", - dtypes=onnx_test_common.BOOL_TYPES + onnx_test_common.INT_TYPES, - reason="fixme: result mismatch with NaN.", - ), - xfail( - "nn.functional.rrelu", - reason=onnx_test_common.reason_dynamo_does_not_support("wrapper_set_seed"), - ), - xfail( - "nn.functional.rrelu", - dtypes=(torch.int64,), - reason=onnx_test_common.reason_onnx_runtime_does_not_support("Relu", "int64"), - ), - skip( - "nn.functional.scaled_dot_product_attention", - matcher=lambda sample: sample.kwargs.get("dropout_p") != 0.0, - reason="dropout is random so the results do not match", - ), - xfail( - "nn.functional.scaled_dot_product_attention", - dtypes=(torch.float16,), - reason="fixme: ORT failed. https://github.com/microsoft/onnxruntime/issues/16438", - ), - xfail( - "nn.functional.selu", - reason="fixme: nn.functional.selu is not in torch._decomp.decomposition_table", - ), - xfail( - "nn.functional.soft_margin_loss", - dtypes=(torch.float16,), - reason="fixme: Assertion error: result mismatch", - ), - xfail( - "nonzero", - dtypes=(torch.int8, torch.int16), - reason=onnx_test_common.reason_onnx_runtime_does_not_support("NonZero", "int8, int16"), - ), - xfail( - "normal", - reason=onnx_test_common.reason_dynamo_does_not_support("wrapper_set_seed"), - ), - xfail( - "normal", - variant_name="in_place", - reason=onnx_test_common.reason_dynamo_does_not_support("wrapper_set_seed"), - ), - xfail( - "normal", - variant_name="number_mean", - reason=onnx_test_common.reason_dynamo_does_not_support("wrapper_set_seed"), - ), - xfail( - "ones", - dtypes=onnx_test_common.COMPLEX_TYPES, - reason="fixme: kwargs dtpye=complex64 is not supported in ONNX." - ), - xfail( - "pca_lowrank", - reason=onnx_test_common.reason_dynamo_does_not_support("wrapper_set_seed"), - ), - xfail( - "quantile", - reason=onnx_test_common.reason_dynamo_does_not_support("aten.equal.default") - ), - xfail( - "rand_like", - reason=onnx_test_common.reason_dynamo_does_not_support("wrapper_set_seed"), - ), - xfail( - "randint", - reason=onnx_test_common.reason_dynamo_does_not_support("wrapper_set_seed"), - ), - xfail( - "randint_like", - reason=onnx_test_common.reason_dynamo_does_not_support("wrapper_set_seed"), - ), - xfail( - "randn", - reason=onnx_test_common.reason_dynamo_does_not_support("wrapper_set_seed"), - ), - xfail( - "randn_like", - reason=onnx_test_common.reason_dynamo_does_not_support("wrapper_set_seed"), - ), - xfail( - "resize_", - reason=onnx_test_common.reason_dynamo_does_not_support("resize_as_") - ), - xfail( - "resize_as_", - reason=onnx_test_common.reason_dynamo_does_not_support("resize_as_") - ), - xfail( - "round", - dtypes=onnx_test_common.INT_TYPES, - reason=onnx_test_common.reason_onnx_runtime_does_not_support("Round", "int"), - ), - xfail( - "rsub", - dtypes=(torch.uint8, torch.int8, torch.int16), - reason=onnx_test_common.reason_onnx_runtime_does_not_support( - "Mul", "uint8, int8, int16" - ), - ), - xfail( - "scatter_add", - dtypes=(torch.float16,), - reason=onnx_test_common.reason_onnx_runtime_does_not_support("ScatterElements reduction=sum", "float16"), - ), - xfail( - "scatter_reduce", - variant_name="sum", - dtypes=(torch.float16,), - reason=onnx_test_common.reason_onnx_runtime_does_not_support("ScatterElements reduction=sum", "float16"), - ), - xfail( - "scatter_reduce", - variant_name="prod", - dtypes=(torch.float16,), - reason=onnx_test_common.reason_onnx_runtime_does_not_support("ScatterElements reduction=prod", "float16"), - ), - xfail( - "scatter_reduce", - variant_name="amin", - dtypes=onnx_test_common.BOOL_TYPES + (torch.float16,), - reason=onnx_test_common.reason_onnx_runtime_does_not_support("ScatterElements reduction=amin", "float16"), - ), - xfail( - "scatter_reduce", - variant_name="amax", - dtypes=onnx_test_common.BOOL_TYPES + (torch.float16,), - reason=onnx_test_common.reason_onnx_runtime_does_not_support("ScatterElements reduction=amax", "float16"), - ), - xfail( - "scatter_reduce", - variant_name="mean", - reason="ONNX doesn't support reduce='mean' option", - ), - xfail( - "sgn", - dtypes=onnx_test_common.BOOL_TYPES, - reason=onnx_test_common.reason_onnx_script_does_not_support("Sign", "bool"), - ), - xfail( - "sign", - dtypes=onnx_test_common.BOOL_TYPES, - reason=onnx_test_common.reason_onnx_script_does_not_support("Sign", "bool"), - ), - xfail( - "signal.windows.kaiser", - reason=onnx_test_common.reason_dynamo_does_not_support("functionalization"), - ), - xfail( - "softmax", - dtypes=(torch.float16,), - reason="ORT error: https://github.com/microsoft/onnxruntime/issues/16438" - ), - xfail( - "sparse.mm", - variant_name="reduce", - reason=onnx_test_common.reason_dynamo_does_not_support("InternalTorchDynamoError: Sparse CSR tensors do not have strides"), - ), - xfail( - "sparse.sampled_addmm", - reason=onnx_test_common.reason_dynamo_does_not_support("InternalTorchDynamoError: Sparse CSR tensors do not have strides"), - ), - xfail( - "special.erfcx", - dtypes=onnx_test_common.INT_TYPES + onnx_test_common.BOOL_TYPES, - reason=onnx_test_common.reason_onnx_runtime_does_not_support("Erf", "int, bool"), - ), - xfail( - "special.erfcx", - dtypes=onnx_test_common.FLOAT_TYPES, - reason=onnx_test_common.reason_onnx_script_does_not_support("Erfcx"), - ), - xfail( - "special.log_ndtr", - dtypes=onnx_test_common.INT_TYPES + onnx_test_common.FLOAT_TYPES, - reason="fixme: Assertion error: result mismatch", - ), - xfail( - "special.ndtr", - dtypes=(torch.float16,), - reason="fixme: Assertion error: result mismatch", - ), - xfail( - "square", - dtypes=(torch.int8, torch.uint8, torch.int16), - reason=onnx_test_common.reason_onnx_runtime_does_not_support("Pow", "int8, uint8, int16"), - ), - xfail( - "squeeze", - variant_name="multiple", - reason="fixme: https://github.com/microsoft/onnxscript/issues/1264", - ), - xfail( - "svd_lowrank", - reason=onnx_test_common.reason_dynamo_does_not_support("wrapper_set_seed"), - ), - xfail( - "stft", - reason=onnx_test_common.reason_dynamo_does_not_support("aten._fft_r2c.default"), - ), - xfail( - "sub", - dtypes=(torch.uint8, torch.int8, torch.int16), - reason=onnx_test_common.reason_onnx_runtime_does_not_support( - "Mul", "uint8, int8, int16" - ), - ), - xfail( - "take", - reason=onnx_test_common.reason_dynamo_does_not_support("data-dependent"), - ), - xfail( - "tensor_split", - reason=onnx_test_common.reason_dynamo_does_not_support("data-dependent"), - ), - xfail( - "topk", - dtypes=(torch.int64, torch.int32, torch.float16), - reason="fixme: Assertion error: result mismatch", - ), - xfail( - "tril", - dtypes=onnx_test_common.BOOL_TYPES + (torch.int32,), - reason=onnx_test_common.reason_onnx_runtime_does_not_support("trilu", "bool, int32"), - ), - xfail( - "triu", - dtypes=onnx_test_common.BOOL_TYPES + (torch.int32,), - reason=onnx_test_common.reason_onnx_runtime_does_not_support("trilu", "bool, int32"), - ), - xfail( - "trunc", - dtypes=onnx_test_common.INT_TYPES, - reason=onnx_test_common.reason_onnx_does_not_support("Floor", "int"), - ), - xfail( - "unflatten", - dtypes=onnx_test_common.BOOL_TYPES, - reason=onnx_test_common.reason_onnx_does_not_support("Unflatten") - ), - xfail( - "uniform", - reason=onnx_test_common.reason_dynamo_does_not_support("wrapper_set_seed"), - ), - xfail( - "unique", - reason=onnx_test_common.reason_dynamo_does_not_support("aten.unique_consecutive.default"), - ), - xfail( - "unique_consecutive", - reason=onnx_test_common.reason_dynamo_does_not_support("aten.unique_consecutive.default"), - ), - xfail( - "unravel_index", - dtypes=onnx_test_common.BOOL_TYPES + onnx_test_common.INT_TYPES, - reason=onnx_test_common.reason_onnx_script_does_not_support("Floor", "bool, int"), - ), - xfail( - "unsqueeze_copy", - reason="OnnxExporterError: Failed to export model", - dtypes=(torch.int8, torch.uint8, torch.int16), - ), - xfail( - "where", - dtypes=onnx_test_common.BOOL_TYPES, - reason=onnx_test_common.reason_onnx_runtime_does_not_support("Where", "bool"), - ), - xfail( - "zeros", - dtypes=onnx_test_common.COMPLEX_TYPES, - reason="fixme: kwargs dtpye=complex64 is not supported in ONNX." - ), - # SLOW TESTS (All are xfails if we run them) - # TODO: https://github.com/pytorch/pytorch/issues/117118 - skip_slow( - "cdist", - reason="fixme: Test sets are too many.", - ), - skip_slow( - "histogram", - reason="fixme: Test sets are too many.", - ), - skip_slow( - "histogramdd", - reason="fixme: Test sets are too many.", - ), - skip_slow( - "linalg.lu_solve", - reason="fixme: Test sets are too many.", - ), - skip_slow( - "linalg.solve_triangular", - reason="fixme: Test sets are too many.", - ), - skip_slow( - "linalg.svd", - reason="fixme: Test sets are too many.", - ), - skip_slow( - "logspace", - reason="fixme: Test sets are too many.", - ), - skip_slow( - "logspace", - variant_name="tensor_overload", - reason="fixme: Test sets are too many.", - ), - skip_slow( - "max_pool2d_with_indices_backward", - reason="fixme: Test sets are too many.", - ), - skip_slow( - "nn.functional.interpolate", - variant_name="bicubic", - reason="fixme: Test sets are too many.", - ), - skip_slow( - "nn.functional.max_unpool1d", - reason="fixme: Test sets are too many.", - ), - skip_slow( - "nn.functional.max_unpool2d", - reason="fixme: Test sets are too many.", - ), - skip_slow( - "nn.functional.max_unpool3d", - reason="fixme: Test sets are too many.", - ), - skip_slow( - "nn.functional.max_pool1d", - reason="fixme: Test sets are too many.", - ), - skip_slow( - "nn.functional.max_pool2d", - reason="fixme: Test sets are too many.", - ), - skip_slow( - "nn.functional.max_pool3d", - reason="fixme: Test sets are too many.", - ), - skip_slow( - "nn.functional.unfold", - reason="fixme: Test sets are too many.", - ), - skip_slow( - "ormqr", - reason="fixme: Test sets are too many.", - ), - skip_slow( - "searchsorted", - reason="fixme: Test sets are too many.", - ), - skip_slow( - "svd", - reason="fixme: Test sets are too many.", - ), -) -# fmt: on - -# NOTE: The xfail and skip with a matcher function or model_type should be -# at under the `SKIP_XFAIL_SUBTESTS_WITH_MATCHER_AND_MODEL_TYPE` section. -SKIP_XFAIL_SUBTESTS_WITH_MATCHER_AND_MODEL_TYPE: tuple[ - onnx_test_common.DecorateMeta, ... -] = ( - skip( - "_native_batch_norm_legit", - model_type=pytorch_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM, - reason="https://github.com/pytorch/pytorch/issues/115106", - ), - skip( - "_batch_norm_with_update", - model_type=pytorch_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM, - reason="https://github.com/pytorch/pytorch/issues/115106", - ), - # TODO: This test currently fails only for certain inputs, e.g. shape([3, 1]). - # Numerically the ONNX program is correct, but the output shapes for `save_mean` - # and `save_var` were tensor(-2.1268) instead of the correct tensor([-2.1268]) - # for example. - skip( - "_batch_norm_with_update", - model_type=pytorch_test_common.TorchModelType.TORCH_NN_MODULE, - reason="not supported yet", - ), - xfail( - "addmm", # xfail can't only use dtypes to catch all cases - matcher=lambda sample: sample.input.dtype - in (torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64), - reason=onnx_test_common.reason_onnx_runtime_does_not_support( - "Gemm", "uint8, int8, int16, int32, int64" - ), - ), - xfail( - "addmm", - matcher=lambda sample: sample.args[0].numel() == 0, - reason="ONNX Runtime does not support empty tensors multiplication", - ), - xfail( - "addmm", - variant_name="decomposed", - matcher=lambda sample: sample.args[0].numel() == 0, - reason="ONNX Runtime does not support empty tensors multiplication", - ), - xfail( - "amax", - matcher=lambda sample: len(sample.input.shape) == 0 - and (sample.kwargs.get("dim") is not None and sample.kwargs.get("dim") != ()), - reason="Op (ReduceMax) [ShapeInferenceError] axis must be in [-rank, rank-1]. input rank was 0", - ), - xfail( - "amin", - matcher=lambda sample: len(sample.input.shape) == 0 - and (sample.kwargs.get("dim") is not None and sample.kwargs.get("dim") != ()), - reason="Op (ReduceMin) [ShapeInferenceError] axis must be in [-rank, rank-1]. input rank was 0", - ), - xfail( - "aminmax", - matcher=lambda sample: len(sample.input.shape) == 0 - and sample.kwargs.get("dim") is not None, - reason="Op (ReduceMin) [ShapeInferenceError] axis must be in [-rank, rank-1]. input rank was 0", - ), - skip( - "cat", - matcher=lambda sample: sample.input[0].equal(torch.tensor([])), - reason="core dump - cat does not support zero-dim tensors yet", - ), - xfail( - "index_add", - matcher=lambda sample: len(sample.input.shape) == 0, - reason=onnx_test_common.reason_onnx_runtime_does_not_support( - "ScatterND", "0-D tensor" - ), - ), - xfail( - "index_add", - matcher=lambda sample: isinstance(sample.args[0], int) and sample.args[0] == -1, - reason="fixme: aten::index_put indices contains None when dim is -1", - ), - xfail( - "index_copy", - matcher=lambda sample: len(sample.input.shape) == 0, - reason=onnx_test_common.reason_onnx_runtime_does_not_support( - "ScatterND", "0-D tensor" - ), - ), - xfail( - "index_copy", - matcher=lambda sample: isinstance(sample.args[0], int) and sample.args[0] == -1, - reason="fixme: aten::index_put indices contains None when dim is -1", - ), - xfail( - "index_put", - matcher=lambda sample: (sample.args[0][0].dtype == torch.bool) - and (sample.kwargs.get("accumulate") is False), - reason=onnx_test_common.reason_dynamo_does_not_support( - "https://github.com/pytorch/pytorch/issues/101150" - ), - ), - skip( - "linalg.multi_dot", - matcher=lambda sample: sum(torch.numel(input) for input in sample.input) == 0, - reason="fixme: Undefined", - ), - skip( - "log_softmax", - matcher=lambda sample: len(sample.input.shape) == 0, - reason="fixme: LogSoftMax does not support empty tensor as input", - ), - skip( - "log_softmax", - variant_name="with_dtype", - matcher=lambda sample: len(sample.input.shape) == 0, - reason="fixme: LogSoftMax does not support empty tensor as input", - ), - skip( - "masked.log_softmax", - matcher=lambda sample: len(sample.input.shape) == 0, - reason="fixme: LogSoftMax does not support empty tensor as input", - ), - skip( - "matmul", - matcher=lambda sample: torch.numel(sample.input) == 0, - reason="values of matmul of [m, 0] and [0, n] matrices are undefined", - ), - skip( - "mm", - matcher=lambda sample: torch.numel(sample.input) == 0, - reason="values of matmul of [m, 0] and [0, n] matrices are undefined", - ), - xfail( - "native_batch_norm", - matcher=lambda sample: sample.args[-3] is True - and any(arg is not None for arg in sample.args[2:4]), - model_type=pytorch_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM, - reason="https://github.com/pytorch/pytorch/issues/115106", - ), - xfail( - "nn.functional.avg_pool1d", - matcher=lambda sample: (sample.kwargs.get("ceil_mode") is True) - and ( - sample.kwargs.get("count_include_pad") is True - or sample.input.shape[2] - % ( - sample.args[0][0] - if isinstance(sample.args[0], tuple) - else sample.args[0] - ) - != 0 - ), - reason="fixme: ORT doesn't match PyTorch when ceil_mode=True until opset 19", - ), - xfail( - "nn.functional.avg_pool2d", - matcher=lambda sample: (len(sample.args) > 5 and sample.args[5] is not None) - or (sample.kwargs.get("divisor_override") is not None), - reason="ONNX doesn't support divisor_override argument", - ), - xfail( - "nn.functional.avg_pool3d", - matcher=lambda sample: sample.kwargs.get("ceil_mode") is True, - reason="fixme: ORT doesn't match PyTorch when ceil_mode=True until opset 19", - ), - xfail( - "nn.functional.avg_pool3d", - matcher=lambda sample: (len(sample.args) > 5 and sample.args[5] is not None) - or (sample.kwargs.get("divisor_override") is not None), - reason="ONNX doesn't support divisor_override argument", - ), - xfail( - "nn.functional.batch_norm", - matcher=lambda sample: sample.kwargs.get("training") is True - and any(arg is not None for arg in sample.args[2:4]), - reason="Flaky failure: https://github.com/pytorch/pytorch/issues/115106", - ), - xfail( - "nn.functional.conv2d", - matcher=lambda sample: sample.kwargs.get("padding") == "valid", - reason="fixme: https://github.com/pytorch/pytorch/issues/117054", - ), - xfail( - "nn.functional.conv3d", - matcher=lambda sample: sample.kwargs.get("padding") == "valid", - reason="fixme: https://github.com/pytorch/pytorch/issues/117054", - ), - skip( - "nn.functional.cross_entropy", - matcher=lambda sample: not isinstance(sample.kwargs.get("weight"), int), - reason="ONNX SoftmaxCrossEntropyLoss op only accept argument[weight] is int type", - ), - xfail( - "nn.functional.embedding", - matcher=lambda sample: sample.kwargs.get("max_norm") is not None, - model_type=pytorch_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM, - reason="https://github.com/pytorch/pytorch/issues/115106", - ), - skip_torchlib_forward_compatibility( - "nn.functional.embedding_bag", - matcher=lambda sample: sample.kwargs.get("padding_idx") is not None or True, - reason=onnx_test_common.reason_onnx_script_does_not_support( - "'padding_idx' overload for _embedding_bag and _embedding_bag_forward_only. " - "'padding_idx=-1' is emitted for aten op when 'padding_idx' is not provided" - ), - github_issue="https://github.com/microsoft/onnxscript/issues/1056", - ), - xfail( - "nn.functional.group_norm", - matcher=lambda sample: torch.numel(sample.input) == 0, - reason=onnx_test_common.reason_onnx_runtime_does_not_support( - "Reshape", "empty tensor" - ), - ), - xfail( - "nn.functional.instance_norm", - model_type=pytorch_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM, - matcher=lambda sample: sample.kwargs.get("running_mean") is not None, - reason="fixme: KeyError: 'self___kwargs__running_mean'", - ), - xfail( - "nn.functional.max_pool3d", - matcher=lambda sample: sample.kwargs.get("ceil_mode") is True - and sample.kwargs.get("padding") == 1, - reason="FIXME: After https://github.com/microsoft/onnxruntime/issues/15446 is fixed", - ), - xfail( - "nn.functional.pixel_shuffle", - matcher=lambda sample: sample.input.numel() == 0, - reason="fixme: ORT does not support empty tensor as input", - ), - xfail( - "nonzero", - matcher=lambda sample: len(sample.input.shape) == 0 - and sample.kwargs.get("as_tuple", False) is False, - reason="Output 'shape' do not match: torch.Size([0, 1]) != torch.Size([0, 0]).", - model_type=pytorch_test_common.TorchModelType.TORCH_NN_MODULE, - ), - xfail( - "scatter_add", - matcher=lambda sample: len(sample.input.shape) == 0, - reason="fixme: Rank(0) input will lead ORT failed due to different rank(result) in if-else branch", - ), - skip( - "scatter_reduce", - variant_name="amax", - # ONNX has not include_self parameter and default is include_self=True mode - matcher=lambda sample: sample.kwargs.get("include_self") is False, - reason="ONNX does't support include_self=False option", - ), - skip( - "scatter_reduce", - variant_name="amin", - # ONNX has not include_self parameter and default is include_self=True mode - matcher=lambda sample: sample.kwargs.get("include_self") is False, - reason="ONNX does't support include_self=False option", - ), - skip( - "scatter_reduce", - variant_name="prod", - # ONNX has not include_self parameter and default is include_self=True mode - matcher=lambda sample: sample.kwargs.get("include_self") is False, - reason="ONNX does't support include_self=False option", - ), - skip( - "scatter_reduce", - variant_name="sum", - # ONNX has not include_self parameter and default is include_self=True mode - matcher=lambda sample: sample.kwargs.get("include_self") is False, - reason="ONNX does't support include_self=False option", - ), - skip( - "softmax", - matcher=lambda sample: len(sample.input.shape) == 0, - reason="fixme: LogSoftMax does not support empty tensor as input", - ), - xfail( - "unflatten", - reason="Logic not implemented for size 0 inputs in op.Reshape", - matcher=lambda sample: any(dim == 0 for dim in sample.input.shape), - ), - skip( - "signal.windows.hamming", - model_type=pytorch_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM, - reason="does not match node name", - ), - skip( - "signal.windows.general_hamming", - model_type=pytorch_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM, - reason="does not match node name", - ), - skip( - "signal.windows.blackman", - model_type=pytorch_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM, - reason="does not match node name", - ), - skip( - "signal.windows.general_cosine", - model_type=pytorch_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM, - reason="does not match node name", - ), - skip( - "signal.windows.hann", - model_type=pytorch_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM, - reason="does not match node name", - ), - skip( - "signal.windows.nuttall", - model_type=pytorch_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM, - reason="does not match node name", - ), -) - -OPS_DB = copy.deepcopy(common_methods_invocations.op_db) -OP_WITH_SKIPPED_XFAIL_SUBTESTS = frozenset( - meta.op_name for meta in SKIP_XFAIL_SUBTESTS_WITH_MATCHER_AND_MODEL_TYPE -) -ALL_OPS_IN_DB = frozenset(op_info.name for op_info in OPS_DB) - - -def _torch_size_flatten_spec(d: List[Any], spec: Any) -> List[Any]: - return [d[i] for i in range(spec.num_children)] - - -torch.fx._pytree.register_pytree_flatten_spec( - torch.Size, - _torch_size_flatten_spec, -) - - -class SingleOpModel(torch.nn.Module): - """Test model to wrap around a single op for export.""" - - def __init__(self, op, kwargs): - super().__init__() - self.operator = op - self.kwargs = kwargs - - def forward(self, *args): - return self.operator(*args, **self.kwargs) - - -def _should_skip_xfail_test_sample( - op_name: str, - variant_test_name: str, - sample, - model_type: pytorch_test_common.TorchModelType, -) -> Tuple[Optional[str], Optional[str]]: - """Check if the test sample should be skipped or xfailed. - - If the xfail/skip decorator meta is matched with its op_name and model_type, - return the test_behavior and reason. Otherwise, return None, None. Note that - if the matcher is None, the test is decorator_meta is meant to skip/xfail all model types. - - Args: - op_name: The name of the op. - sample: The test sample. - model_type: The model type of the test. - - Returns: - A tuple of (test_behavior, reason). test_behavior is either "skip" or "xfail". - reason is the reason for the test_behavior. - """ - - if op_name not in OP_WITH_SKIPPED_XFAIL_SUBTESTS: - return None, None - for decorator_meta in SKIP_XFAIL_SUBTESTS_WITH_MATCHER_AND_MODEL_TYPE: - # Linear search on ops_test_data.SKIP_XFAIL_SUBTESTS_WITH_MATCHER_AND_MODEL_TYPE. That's fine because the list is small. - # NOTE: If model_type is None, the test is decorator_meta is meant to skip/xfail all model types. - if ( - decorator_meta.op_name == op_name - and decorator_meta.variant_name == variant_test_name - ) and ( - model_type == decorator_meta.model_type or decorator_meta.model_type is None - ): - if decorator_meta.matcher is None and decorator_meta.model_type is None: - raise TypeError( - "Either Matcher or model_type must be defined in sub xfail and skip." - ) - if decorator_meta.matcher is not None and decorator_meta.matcher(sample): - return decorator_meta.test_behavior, decorator_meta.reason - elif decorator_meta.matcher is None: - # xfail/skip the whole test of the model type without matcher - return decorator_meta.test_behavior, decorator_meta.reason - return None, None - - -def _compare_onnx_and_torch_exported_program( - torch_exported_program, - onnx_exported_program, - input_args, - input_kwargs=None, - test_name=None, - sample_num=None, - sample_kwargs=None, - rtol=1e-03, - atol=1e-07, - only_check_shape=False, -): - # avoid mutable default argument - if input_kwargs is None: - input_kwargs = {} - - # NOTE: ONNXProgram holds a reference (not copy) to the original ref_model, including its state_dict. - # Thus, ONNXProgram() must run before ref_model() to prevent ref_model.forward() from changing the state_dict. - # Otherwise, the ref_model can change buffers on state_dict which would be used by ONNXProgram.__call__() - onnx_outputs = onnx_exported_program(*input_args, **input_kwargs) - if isinstance(torch_exported_program, torch.export.ExportedProgram): - torch_outputs = torch_exported_program.module()(*input_args, **input_kwargs) - else: - torch_outputs = torch_exported_program(*input_args, **input_kwargs) - torch_outputs_onnx_format = onnx_exported_program.adapt_torch_outputs_to_onnx( - torch_outputs - ) - if len(torch_outputs_onnx_format) != len(onnx_outputs): - raise AssertionError( - f"Expected {len(torch_outputs_onnx_format)} outputs, got {len(onnx_outputs)}" - ) - - for j, (torch_output, onnx_output) in enumerate( - zip(torch_outputs_onnx_format, onnx_outputs) - ): - if only_check_shape: - assert torch_output.shape == onnx_output.shape - else: - try: - torch.testing.assert_close( - torch.tensor(onnx_output), - torch_output, - rtol=rtol, - atol=atol, - equal_nan=True, - ) - except AssertionError as e: - if os.environ.get("CREATE_REPRODUCTION_REPORT") == "1": - error_reproduction.create_mismatch_report( - test_name, - sample_num, - onnx_exported_program.model_proto, - input_args, - sample_kwargs, - torch.tensor(onnx_output), - torch_output, - e, - ) - if len(torch_outputs_onnx_format) > 1: - raise AssertionError(f"Output {j} mismatch") from e - raise - - -def _run_test_output_match( - test_suite: onnx_test_common._TestONNXRuntime, - device: str, - dtype: torch.dtype, - op: opinfo_core.OpInfo, -): - # device is provided by instantiate_device_type_tests, but we only want to run in cpu. - assert device == "cpu" - samples = op.sample_inputs( - device, - dtype, - requires_grad=False, - ) - for i, cpu_sample in enumerate(samples): - inputs = (cpu_sample.input, *cpu_sample.args) - # Provide the repr to subtest because tensors are not serializable in parallel test runs - - with test_suite.subTest( - opset=test_suite.opset_version, - sample_num=i, - inputs=repr(inputs), - kwargs=repr(cpu_sample.kwargs), - ): - test_behavior, reason = _should_skip_xfail_test_sample( - op.name, op.variant_test_name, cpu_sample, test_suite.model_type - ) - with onnx_test_common.normal_xfail_skip_test_behaviors( - test_behavior, reason - ): - model = SingleOpModel(op.op, cpu_sample.kwargs) - model.eval() - - if ( - dtype == torch.float32 - and op.name in test_suite.fp32_low_precision_dict - ): - rtol = test_suite.fp32_low_precision_dict[op.name][0] - atol = test_suite.fp32_low_precision_dict[op.name][1] - elif dtype == torch.float32: - # Relax atol and rtol for float32 based on empirical results - rtol = 1e-5 - atol = 2e-5 - elif ( - dtype == torch.float16 - and (op.name, op.variant_test_name) - in test_suite.fp16_low_precision_variant_dict - ): - rtol = test_suite.fp16_low_precision_variant_dict[ - (op.name, op.variant_test_name) - ][0] - atol = test_suite.fp16_low_precision_variant_dict[ - (op.name, op.variant_test_name) - ][1] - elif ( - dtype == torch.float16 - and op.name in test_suite.fp16_low_precision_dict - ): - rtol = test_suite.fp16_low_precision_dict[op.name][0] - atol = test_suite.fp16_low_precision_dict[op.name][1] - else: - rtol = None - atol = None - - if ( - test_suite.model_type - == pytorch_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM - ): - try: - # TODO (tugsbayasgalan) Migrate to pre-dispatch IR - # BUG1: python test/onnx/test_fx_op_consistency.py -k test_output_match_triu_cpu_int32 - # has unexpected success, but don't know how to remove from xfail list - # BUG2: User output to_sparse is not in the correct order or is not found in the - # exported program's user_output list (https://github.com/pytorch/pytorch/issues/124328) - # python test/onnx/test_fx_op_consistency.py -k test_output_match_to_sparse_cpu_float32 - # BUG3: [ShapeInferenceError] Inference error(s): (op_type:aten_view, node name: aten_view_4): - # [ShapeInferenceError] - # Inference error(s): (op_type:Reshape, node name: n1): [ShapeInferenceError] Invalid position of 0. - # python test/onnx/test_fx_op_consistency.py -k test_output_match_stack_cpu_int32 - from torch.export import _trace - - model = _trace._export(model, inputs, pre_dispatch=False) - - except AssertionError as e: - # NOTE: avoid fake_mode detection bug in torch.export.export - pytest.xfail( - onnx_test_common.reason_dynamo_does_not_support(str(e)) - ) - - try: - onnx_program = torch.onnx.dynamo_export( - model, - *inputs, - ) - except torch.onnx.OnnxExporterError as e: - # NOTE: If the model has unsupported nodes, we will skip the test - # with non-strict xfail. Otherwise, we will raise the error. - if hasattr( - e.__cause__, "diagnostic" - ) and e.__cause__.diagnostic.rule in ( - _rules._POERules.no_symbolic_function_for_call_function, - _rules._POERules.unsupported_fx_node_analysis, - ): - pytest.xfail( - onnx_test_common.reason_onnx_script_does_not_support(str(e)) - ) - else: - raise e - _compare_onnx_and_torch_exported_program( - model, - onnx_program, - inputs, - test_name=test_suite.id(), - sample_num=i, - sample_kwargs=cpu_sample.kwargs, - rtol=rtol, - atol=atol, - only_check_shape=(op.name in test_suite.only_shape_check_list), - ) - - -def _parameterized_class_attrs_and_values(): - input_values = [] - input_values.extend( - itertools.product( - (opset for opset in onnx_test_common.FX_TESTED_OPSETS), - ( - pytorch_test_common.TorchModelType.TORCH_NN_MODULE, - pytorch_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM, - ), - ) - ) - return { - "attrs": ["opset_version", "model_type"], - "input_values": input_values, - } - - -def _parameterize_class_name(cls: Type, idx: int, input_dicts: Mapping[Any, Any]): - """Combine class name with the parameterized arguments. - - This function is passed to `parameterized.parameterized_class` as the - `class_name_func` argument. - """ - suffixes = [] - for k, v in input_dicts.items(): - suffixes.append(f"{k}_{v}") - return f"{cls.__name__}_{'_'.join(suffixes)}" - - -@parameterized.parameterized_class( - **_parameterized_class_attrs_and_values(), - class_name_func=_parameterize_class_name, -) -class TestOnnxModelOutputConsistency(onnx_test_common._TestONNXRuntime): - """Test output consistency between exported ONNX models and PyTorch eager mode. - - This is a parameterized test suite. - """ - - opset_version = -1 - dynamic_shapes: bool = False - model_type: pytorch_test_common.TorchModelType = ( - pytorch_test_common.TorchModelType.TORCH_NN_MODULE - ) - - # NOTE: Follow torchlib settings in ops_test_data.py - only_shape_check_list = [ - "empty", - "empty_like", - "empty_strided", - "new_empty", - "new_empty_strided", - ] - - fp32_low_precision_dict = { - "native_layer_norm": [2e-4, 7e-4], - } - - fp16_low_precision_dict = { - "addbmm": [2e-1, 2e-2], - "addcdiv": [3e-2, 1.4e-3], - "addcmul": [3e-2, 1e-3], - "addmv": [5e-2, 3e-2], - "addr": [3e-3, 4e-3], - "baddbmm": [3e-2, 1e-3], - "cumulative_trapezoid": [3e-2, 1e-3], - "cross": [3e-2, 2e-2], - "diff": [1e-2, 5e-2], - "div": [5e-3, 1e-3], - "gradient": [3e-3, 4e-3], - "linalg.cross": [1e-3, 2e-2], - "linalg.multi_dot": [3e-2, 1e-3], - "linalg.vecdot": [1e-2, 2e-2], - "linspace": [2e-2, 2e-3], - "masked.std": [2e-2, 2e-3], - "masked.var": [2e-2, 2e-2], - "matmul": [2e-2, 6e-2], - "mv": [9e-3, 1e-5], - "nn.functional.batch_norm": [3e-2, 1e-3], - "nn.functional.binary_cross_entropy": [3e-2, 1e-3], - "nn.functional.binary_cross_entropy_with_logits": [4e-2, 4e-3], - "nn.functional.cosine_similarity": [3e-2, 1e-3], - "nn.functional.cosine_embedding_loss": [1e-2, 1e-3], - "nn.functional.hardsigmoid": [1e-3, 5e-3], - "nn.functional.hardswish": [1e-3, 5e-3], - "nn.functional.hinge_embedding_loss": [4e-1, 3e-3], - "nn.functional.huber_loss": [1e-2, 1e-1], - "nn.functional.instance_norm": [1e-2, 1e-3], - "nn.functional.interpolate": [1e-2, 1e-3], - "nn.functional.kl_div": [2e-3, 2e-4], - "nn.functional.multilabel_soft_margin_loss": [4e-2, 5e-3], - "nn.functional.local_response_norm": [1e-2, 5e-3], - "nn.functional.poisson_nll_loss": [4e-2, 6e-3], - "nn.functional.nll_loss": [3e-2, 1e-3], - "nn.functional.triplet_margin_loss": [2e-2, 1e-2], - "nn.functional.triplet_margin_with_distance_loss": [3e-2, 1e-2], - "native_batch_norm": [3e-2, 1e-3], - "norm": [1e-2, 1e-2], - "dot": [3e-2, 1e-3], - "logit": [3e-2, 1e-3], - "rsub": [3e-2, 1e-3], - "sinc": [2e-1, 6e-4], - "sub": [3e-2, 1e-3], - "trapezoid": [1e-3, 7e-3], - "trapz": [1e-3, 7e-3], - "vdot": [1e-3, 1e-2], - } - - fp16_low_precision_variant_dict = { - ("nn.functional.interpolate", "trilinear"): [3e-2, 3e-3], - ("nn.functional.interpolate", "linear"): [3e-2, 3e-3], - } - - @common_device_type.ops( - [op for op in OPS_DB if op.name in ALL_OPS_IN_DB], - allowed_dtypes=onnx_test_common.TESTED_DTYPES, - ) - def test_output_match(self, device: str, dtype: torch.dtype, op): - """Test the ONNX exporter.""" - _run_test_output_match(self, device, dtype, op) - - -for opset in onnx_test_common.FX_TESTED_OPSETS: - for model_type in pytorch_test_common.TorchModelType: - # The name needs to match the parameterized_class name. - test_class_name = f"TestOnnxModelOutputConsistency_opset_version_{opset}_model_type_TorchModelType.{model_type.name}" - onnx_test_common.add_decorate_info( - OPS_DB, - test_class_name, - "test_output_match", - opset=opset, - skip_or_xfails=EXPECTED_SKIPS_OR_FAILS_WITH_DTYPES, - ) - - common_device_type.instantiate_device_type_tests( - globals()[test_class_name], globals(), only_for="cpu" - ) - -if __name__ == "__main__": - common_utils.run_tests() diff --git a/test/onnx/test_fx_to_onnx.py b/test/onnx/test_fx_to_onnx.py index 04f132abc9c9aa..5de711f181f4d3 100644 --- a/test/onnx/test_fx_to_onnx.py +++ b/test/onnx/test_fx_to_onnx.py @@ -200,34 +200,6 @@ def forward(self, input): expected_node="aten.convolution.default", ) - def test_dispatch_overload_fall_back_default_raise_diagnostic_warning(self): - class TraceModel(torch.nn.Module): - def forward(self, input): - return torch.ops.aten.add.Tensor(input, input) - - onnx_registry = torch.onnx.OnnxRegistry() - self.assertTrue( - onnx_registry.is_registered_op( - namespace="aten", op_name="add", overload="Tensor" - ) - ) - - aten_add_Tensor = registration.OpName.from_name_parts( - namespace="aten", op_name="add", overload="Tensor" - ) - onnx_registry._registry.pop(aten_add_Tensor) - - x = torch.tensor(3) - onnx_program = dynamo_export( - TraceModel(), x, export_options=ExportOptions(onnx_registry=onnx_registry) - ) - assert_has_diagnostics( - onnx_program.diagnostic_context, - diagnostics.rules.find_operator_overloads_in_onnx_registry, - diagnostics.levels.WARNING, - expected_node="aten.add.Tensor", - ) - def test_aten_clone_does_not_raise_warning_of_lack_of_memory_format(self): class CustomModule(torch.nn.Module): def forward(self, input): diff --git a/test/onnx/test_fx_to_onnx_with_onnxruntime.py b/test/onnx/test_fx_to_onnx_with_onnxruntime.py index 150edfef3412d1..dd53ded23cdbe3 100644 --- a/test/onnx/test_fx_to_onnx_with_onnxruntime.py +++ b/test/onnx/test_fx_to_onnx_with_onnxruntime.py @@ -478,6 +478,9 @@ def forward(self, x): MutationModel(), (torch.randn(12),), has_mutation=True ) + @unittest.skip( + "Fixme: arange in torchlib does not support dynamic start and end yet." + ) def test_arange(self): class ArangeModel(torch.nn.Module): def forward(self, input): diff --git a/torch/onnx/_internal/_exporter_legacy.py b/torch/onnx/_internal/_exporter_legacy.py index 52f37460592cd4..c66c3492e98ba4 100644 --- a/torch/onnx/_internal/_exporter_legacy.py +++ b/torch/onnx/_internal/_exporter_legacy.py @@ -46,11 +46,8 @@ import onnx - import onnxruntime # type: ignore[import] - import onnxscript # type: ignore[import] - from onnxscript.function_libs.torch_lib import ( # type: ignore[import] - registration as torchlib_registry, - ) + import onnxruntime + import onnxscript from torch._subclasses import fake_tensor from torch.onnx._internal.fx import diagnostics @@ -110,10 +107,6 @@ def __init__(self) -> None: self._registry: dict[registration.OpName, list[registration.ONNXFunction]] = ( defaultdict(list) ) - # FIXME: Avoid importing onnxscript into torch - from onnxscript.function_libs.torch_lib import ( # type: ignore[import] # noqa: F401 - registration, - ) # opset_version is unused for now, since torchlib only supports opset18. # TODO: get opset version from torchlib @@ -123,8 +116,7 @@ def __init__(self) -> None: "different opset version, please register them with register_custom_op." ) - # Initialize registry from torchlib - self._initiate_registry_from_torchlib(registration.default_registry) + self._initiate_registry_from_torchlib() @property def opset_version(self) -> int: @@ -134,33 +126,25 @@ def opset_version(self) -> int: return self._opset_version - def _initiate_registry_from_torchlib( - self, torchlib_registry: torchlib_registry.Registry - ): + def _initiate_registry_from_torchlib(self) -> None: """Populates the registry with ATen functions from torchlib. Args: torchlib_registry: The torchlib registry to use for populating the registry. """ - for aten_name, aten_overloads_func in torchlib_registry.items(): - internal_name_instance = registration.OpName.from_qualified_name(aten_name) - for overload_func in aten_overloads_func.overloads: - symbolic_function = registration.ONNXFunction( - onnx_function=overload_func, - op_full_name=internal_name_instance.qualified_name(), - is_custom=False, - is_complex=False, - ) - self._register(internal_name_instance, symbolic_function) - - for complex_func in aten_overloads_func.complex: - symbolic_function = registration.ONNXFunction( - onnx_function=complex_func, - op_full_name=internal_name_instance.qualified_name(), - is_custom=False, - is_complex=True, - ) - self._register(internal_name_instance, symbolic_function) + import onnxscript._framework_apis.torch_2_5 as onnxscript_apis + + for meta in onnxscript_apis.get_torchlib_ops(): + internal_name_instance = registration.OpName.from_qualified_name( + meta.qualified_name + ) + symbolic_function = registration.ONNXFunction( + onnx_function=meta.function, # type: ignore[arg-type] + op_full_name=internal_name_instance.qualified_name(), + is_custom=False, + is_complex=meta.is_complex, + ) + self._register(internal_name_instance, symbolic_function) def _register( self, From cc2ca8b719253e5b853c73ed0fc5aeefbd36e16d Mon Sep 17 00:00:00 2001 From: Nikita Shulga <2453524+malfet@users.noreply.github.com> Date: Tue, 3 Sep 2024 17:02:35 +0000 Subject: [PATCH 0356/1018] [BE] Update numpy version to 2.0.2 (#134875) It's long time to abandon pre-release version Partially addresses https://github.com/pytorch/pytorch/issues/134868 Pull Request resolved: https://github.com/pytorch/pytorch/pull/134875 Approved by: https://github.com/justinchuby, https://github.com/clee2000, https://github.com/kit1980, https://github.com/atalman, https://github.com/Skylion007 --- .ci/pytorch/build.sh | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/.ci/pytorch/build.sh b/.ci/pytorch/build.sh index 9c2b4096d169de..a9662bcac2cefd 100755 --- a/.ci/pytorch/build.sh +++ b/.ci/pytorch/build.sh @@ -285,9 +285,8 @@ else if [[ "$BUILD_ENVIRONMENT" != *rocm* && "$BUILD_ENVIRONMENT" != *xla* ]]; then if [[ "$BUILD_ENVIRONMENT" != *py3.8* ]]; then - # Install numpy-2.0 release candidate for builds - # Which should be backward compatible with Numpy-1.X - python -mpip install --pre numpy==2.0.0rc1 + # Install numpy-2.0.2 for builds which are backward compatible with 1.X + python -mpip install --pre numpy==2.0.2 fi WERROR=1 python setup.py clean From 9b38c38a590fbddd6e7a74c35c7ec45ee7a5f60f Mon Sep 17 00:00:00 2001 From: Haibo Chen Date: Tue, 3 Sep 2024 17:08:45 +0000 Subject: [PATCH 0357/1018] add a new Guage API with an empty backend to PyTorch core (#134883) Summary: The current use case is to continuously measure the total allocated and reserved CUDA memory size from CUDACachingAllocator, and export their distribution (min, max, p90 etc) over time as timeseries. The current callback-based API does not work because the backend decides when the measurement is taken, so data points between two measurements may not be recorded. The distribution (e.g. max) as such will not be accurate. This new API closely follow the design of the existing WaitCounter API otherwise. This is not quite a synchronous version of DynamicCounter, as summing multiple data points does not make sense to my use case Test Plan: CI Differential Revision: D61837528 Pull Request resolved: https://github.com/pytorch/pytorch/pull/134883 Approved by: https://github.com/c-p-i-o --- c10/util/Gauge.cpp | 79 ++++++++++++++++++++++++++++++++++++++++++++++ c10/util/Gauge.h | 48 ++++++++++++++++++++++++++++ 2 files changed, 127 insertions(+) create mode 100644 c10/util/Gauge.cpp create mode 100644 c10/util/Gauge.h diff --git a/c10/util/Gauge.cpp b/c10/util/Gauge.cpp new file mode 100644 index 00000000000000..4f3a2214e725be --- /dev/null +++ b/c10/util/Gauge.cpp @@ -0,0 +1,79 @@ +#include + +#include + +#include +#include +#include +#include +#include + +namespace c10::monitor { + +namespace detail { +namespace { +using GaugeBackendFactories = + std::vector>; + +Synchronized& gaugeBackendFactories() { + static auto instance = new Synchronized(); + return *instance; +} +} // namespace + +class GaugeImpl { + public: + static GaugeImpl& getInstance(std::string_view key) { + static auto& implMapSynchronized = *new Synchronized< + std::unordered_map>>(); + + return *implMapSynchronized.withLock([&](auto& implMap) { + if (auto implIt = implMap.find(std::string(key)); + implIt != implMap.end()) { + return implIt->second.get(); + } + + auto [implIt, emplaceSuccess] = implMap.emplace( + std::string{key}, std::unique_ptr(new GaugeImpl(key))); + + assert(emplaceSuccess); + + return implIt->second.get(); + }); + } + + void record(int64_t value) { + for (auto& backend : backends_) { + backend->record(value); + } + } + + private: + explicit GaugeImpl(std::string_view key) { + auto factoriesCopy = gaugeBackendFactories().withLock( + [](auto& factories) { return factories; }); + for (const auto& factory : factoriesCopy) { + if (auto backend = factory->create(key)) { + backends_.push_back(std::move(backend)); + } + } + } + + SmallVector> backends_; +}; + +void registerGaugeBackend(std::unique_ptr backend) { + gaugeBackendFactories().withLock( + [&](auto& backends) { backends.push_back(std::move(backend)); }); +} + +} // namespace detail + +GaugeHandle::GaugeHandle(std::string_view key) + : impl_(detail::GaugeImpl::getInstance(key)) {} + +void GaugeHandle::record(int64_t value) { + impl_.record(value); +} + +} // namespace c10::monitor diff --git a/c10/util/Gauge.h b/c10/util/Gauge.h new file mode 100644 index 00000000000000..f92ecd986bee1b --- /dev/null +++ b/c10/util/Gauge.h @@ -0,0 +1,48 @@ +#pragma once + +#include +#include + +#include +#include + +namespace c10::monitor { +namespace detail { + +class GaugeImpl; + +class GaugeBackendIf { + public: + virtual ~GaugeBackendIf() = default; + virtual void record(int64_t value) noexcept = 0; +}; + +class GaugeBackendFactoryIf { + public: + virtual ~GaugeBackendFactoryIf() = default; + + // May return nullptr if the gauge will be ignored by the given backend. + virtual std::unique_ptr create( + std::string_view key) noexcept = 0; +}; + +void C10_API registerGaugeBackend(std::unique_ptr); +} // namespace detail + +// A handle to a Gauge. +class C10_API GaugeHandle { + public: + explicit GaugeHandle(std::string_view key); + void record(int64_t value); + + private: + detail::GaugeImpl& impl_; +}; + +} // namespace c10::monitor + +#define STATIC_GAUGE(_key) \ + []() -> ::c10::monitor::GaugeHandle& { \ + static ::c10::monitor::GaugeHandle handle(#_key); \ + return handle; \ + }() From 20c078375d867e6fd8a761d627eac77493cdfce4 Mon Sep 17 00:00:00 2001 From: Laith Sakka Date: Thu, 29 Aug 2024 23:09:24 -0700 Subject: [PATCH 0358/1018] Move auto functionalize tests in their own test file (#134834) title + use `with torch.library._scoped_library as lib` when needed. Pull Request resolved: https://github.com/pytorch/pytorch/pull/134834 Approved by: https://github.com/zou3519 ghstack dependencies: #134831 --- test/dynamo/test_misc.py | 396 +---------------------- test/inductor/test_auto_functionalize.py | 363 +++++++++++++++++++++ 2 files changed, 367 insertions(+), 392 deletions(-) create mode 100644 test/inductor/test_auto_functionalize.py diff --git a/test/dynamo/test_misc.py b/test/dynamo/test_misc.py index 1bfb15d1c5e81f..0057383430ff14 100644 --- a/test/dynamo/test_misc.py +++ b/test/dynamo/test_misc.py @@ -112,16 +112,6 @@ def wrapper(*args, **kwargs): return wrapper -def cleanup_op(opname): - ns, name = opname.split("::") - if not hasattr(torch.ops, ns): - return - actual_ns = getattr(torch.ops, ns) - if not hasattr(actual_ns, name): - return - delattr(actual_ns, name) - - class MyPickledModule(torch.nn.Module): def __init__(self, z): super().__init__() @@ -509,8 +499,7 @@ def fn(x): @torch._dynamo.config.patch(only_allow_pt2_compliant_ops=True) def test_pt2_compliant_ops_are_allowed(self): - lib = torch.library.Library("mylib", "FRAGMENT") - try: + with torch.library._scoped_library("mylib", "FRAGMENT") as lib: torch.library.define( "mylib::bar", "(Tensor x) -> Tensor", @@ -538,14 +527,10 @@ def g(x): optimized_g = torch._dynamo.optimize(counts, nopython=True)(f) _ = optimized_g(x) - finally: - cleanup_op("mylib::bar") - del lib @torch._dynamo.config.patch(only_allow_pt2_compliant_ops=True) def test_non_pt2_compliant_ops_graph_break(self): - lib = torch.library.Library("mylib", "FRAGMENT") - try: + with torch.library._scoped_library("mylib", "FRAGMENT") as lib: torch.library.define("mylib::bar2", "(Tensor x) -> Tensor", lib=lib) torch.library.impl( "mylib::bar2", "CompositeImplicitAutograd", torch.sin, lib=lib @@ -574,14 +559,10 @@ def g(x): ): optimized_g = torch._dynamo.optimize(counts, nopython=True)(f) y = optimized_g(x) - finally: - cleanup_op("mylib::bar2") - del lib @torch._dynamo.config.patch(only_allow_pt2_compliant_ops=True) def test_pt2_compliant_overload(self): - lib = torch.library.Library("mylib", "FRAGMENT") - try: + with torch.library._scoped_library("mylib", "FRAGMENT") as lib: torch.library.define( "mylib::bar3.tensor", "(Tensor x) -> Tensor", @@ -630,70 +611,6 @@ def h(x): with self.assertRaisesRegex(torch._dynamo.exc.Unsupported, "failed to"): y = optimized_h(x) - finally: - cleanup_op("mylib::bar3") - del lib - - def test_auto_functionalize_can_with_default(self): - lib = torch.library.Library("mylib", "FRAGMENT") - torch.library.define( - "mylib::foo", - "(Tensor a, int b, Tensor(d!)? c=None, Tensor? d=None, int e=-1) -> ()", - tags=torch.Tag.pt2_compliant_tag, - lib=lib, - ) - - @torch.library.impl("mylib::foo", "cpu", lib=lib) - def foo_impl(a, b, c=None, d=None, e=-1): - a + b - return - - def f(a, mode): - return torch.ops.mylib.foo( - a, - 0, - ) - - a = torch.tensor([10, 10, 10], dtype=torch.int64) - - torch.compile(f)(a, 0) - - cleanup_op("mylib::foo") - del lib - - def test_auto_functionalize_can_with_none_return(self): - with torch.library._scoped_library("mylib", "FRAGMENT") as lib: - lib.define("foo(Tensor x, Tensor(a!) out) -> None") - - def foo_impl(x, out): - out.copy_(x) - - lib.impl("foo", foo_impl, "CompositeExplicitAutograd") - x = torch.randn(3) - out = torch.zeros(3) - - @torch.compile - def f(x, out): - torch.ops.mylib.foo(x, out) - - f(x, out) - - def test_auto_functionalize_self_as_mutate_arg(self): - with torch.library._scoped_library("mylib", "FRAGMENT") as lib: - lib.define("foo(Tensor(a!) self) -> None") - - def foo_impl(self: torch.Tensor) -> None: - self.sin_() - - x = torch.randn(3) - lib.impl("foo", foo_impl, "CompositeExplicitAutograd") - - @torch.compile(backend="inductor", fullgraph=True) - def f(x): - torch.ops.mylib.foo(x) - - f(x) - def test_user_defined_setattr1(self): @torch.compile(backend="eager", fullgraph=True) def fn(obj): @@ -742,8 +659,7 @@ def fn(x, other_fn): self.assertEqual(cnt.frame_count, 2) def test_generate_trivial_abstract_impl(self): - try: - lib = torch.library.Library("mylib", "FRAGMENT") + with torch.library._scoped_library("mylib", "FRAGMENT") as lib: torch.library.define( "mylib::foo", "(Tensor x, Tensor[] y, Tensor(a!)? z, SymInt w) -> ()", @@ -768,291 +684,6 @@ def f(x, y, z, w): output = torch.compile(f, backend="eager", fullgraph=True)(*args) self.assertEqual(output, None) - finally: - cleanup_op("mylib::foo") - del lib - - def test_can_auto_functionalize(self): - from torch._higher_order_ops.auto_functionalize import can_auto_functionalize - - expected_true = [ - "(Tensor(a!) x) -> ()", - "(Tensor(a!) x, Tensor y, Tensor(b!) z, SymInt w, Tensor(c!)? n) -> ()", - "(Tensor(a!) x, Tensor[] y, Tensor(b!) z, SymInt w, Tensor(c!)? n) -> ()", - "(Tensor(a!) x, Tensor y, Tensor(b!)[] z, SymInt w) -> ()", - "(Tensor(a!) x, Tensor y, Tensor(b!) z, SymInt w, Tensor(c!)? n) -> Tensor", - "(Tensor(a!) x, Tensor y, Tensor(b!) z, SymInt w, Tensor(c!)? n) -> (Tensor, Tensor)", - ] - expected_false = [ - "(Tensor x) -> ()", - "(Tensor(a) x) -> Tensor(a)", - "(Tensor(a!) x) -> Tensor(a!)", - "(Tensor(a!) x, Tensor y, Tensor(b!) z, SymInt w, Tensor(c!)? n) -> Tensor(a)", - "(Tensor(a!) x, Tensor y, Tensor(b!) z, SymInt w, Tensor(c!)? n) -> (Tensor, Tensor(a))", - "(Tensor(a) x, Tensor y, Tensor(b!) z, SymInt w, Tensor(c!)? n) -> (Tensor, Tensor(a))", - "(Tensor(a!) x, Tensor y, Tensor(b!) z, SymInt w, Tensor(c!)? n) -> (Tensor, Tensor[])", - ] - for schema in expected_true: - try: - lib = torch.library.Library("mylib", "FRAGMENT") - torch.library.define("mylib::a", schema, lib=lib) - self.assertTrue( - can_auto_functionalize(torch.ops.mylib.a.default), msg=schema - ) - self.assertFalse(can_auto_functionalize(torch.ops.mylib.a)) - finally: - cleanup_op("mylib::a") - del lib - for schema in expected_false: - try: - lib = torch.library.Library("mylib", "FRAGMENT") - torch.library.define("mylib::a", schema, lib=lib) - self.assertFalse( - can_auto_functionalize(torch.ops.mylib.a.default), msg=schema - ) - self.assertFalse(can_auto_functionalize(torch.ops.mylib.a)) - finally: - cleanup_op("mylib::a") - del lib - - def test_auto_functionalize(self): - try: - lib = torch.library.Library("mylib", "FRAGMENT") - torch.library.define( - "mylib::foo", - "(Tensor(a!) x, Tensor[] y, Tensor(b!) z, SymInt w, Tensor n) -> ()", - tags=torch.Tag.pt2_compliant_tag, - lib=lib, - ) - - @torch.library.impl("mylib::foo", "cpu", lib=lib) - @torch._dynamo.disable - def foo_impl(x, y, z, w, n): - x.add_(y[0] + w) - z.add_(y[1] + n) - - def f(x, y, z, n): - torch.ops.mylib.foo(x, y, z, 2, n) - - x = torch.randn(3) - y = (torch.randn(3), torch.randn(3)) - z = torch.randn(3) - n = torch.randn(3) - orig_args = (x, y, z, n) - - compiled_args = pytree.tree_map_only(torch.Tensor, torch.clone, orig_args) - - log_stream, ctx = logs_to_string( - "torch._inductor.compile_fx", "post_grad_graphs" - ) - with ctx(): - torch.compile(f, backend="inductor", fullgraph=True)(*compiled_args) - - post_grad_graphs = "\n".join( - log_stream.getvalue().strip().split("\n")[3:] - ).strip() - - # Check the graph under static shapes - if torch._dynamo.config.assume_static_by_default: - self.assertExpectedInline( - post_grad_graphs, - """\ -def forward(self, arg0_1: "f32[3][1]cpu", arg1_1: "f32[3][1]cpu", arg2_1: "f32[3][1]cpu", arg3_1: "f32[3][1]cpu", arg4_1: "f32[3][1]cpu"): - # No stacktrace found for following nodes - foo_default = torch.ops.mylib.foo.default(arg4_1, [arg2_1, arg3_1], arg1_1, 2, arg0_1); arg4_1 = arg2_1 = arg3_1 = arg1_1 = arg0_1 = foo_default = None - return ()""", - ) - - eager_args = pytree.tree_map_only(torch.Tensor, torch.clone, orig_args) - f(*eager_args) - self.assertEqual(compiled_args, eager_args) - finally: - cleanup_op("mylib::foo") - del lib - - def test_auto_functionalize_with_returns(self): - try: - lib = torch.library.Library("mylib", "FRAGMENT") - torch.library.define( - "mylib::foo", - "(Tensor(a!) x, Tensor[] y, Tensor(b!) z, SymInt w, Tensor n) -> (Tensor, Tensor)", - tags=torch.Tag.pt2_compliant_tag, - lib=lib, - ) - - @torch.library.impl("mylib::foo", "cpu", lib=lib) - @torch._dynamo.disable - def foo_impl(x, y, z, w, n): - x.add_(y[0] + w) - z.add_(y[1] + n) - return y[0] + w, y[1] + n - - @torch.library.impl_abstract("mylib::foo", lib=lib) - def foo_abstract(x, y, z, w, n): - return y[0] + w, y[1] + n - - def f(x, y, z, n): - return torch.ops.mylib.foo(x, y, z, 2, n) - - x = torch.randn(3) - y = (torch.randn(3), torch.randn(3)) - z = torch.randn(3) - n = torch.randn(3) - orig_args = (x, y, z, n) - - compiled_args = pytree.tree_map_only(torch.Tensor, torch.clone, orig_args) - log_stream, ctx = logs_to_string( - "torch._inductor.compile_fx", "post_grad_graphs" - ) - with ctx(): - compiled_out = torch.compile(f, backend="inductor", fullgraph=True)( - *compiled_args - ) - - if torch._dynamo.config.assume_static_by_default: - post_grad_graphs = "\n".join( - log_stream.getvalue().strip().split("\n")[3:] - ).strip() - self.assertExpectedInline( - post_grad_graphs, - """\ -def forward(self, arg0_1: "f32[3][1]cpu", arg1_1: "f32[3][1]cpu", arg2_1: "f32[3][1]cpu", arg3_1: "f32[3][1]cpu", arg4_1: "f32[3][1]cpu"): - # No stacktrace found for following nodes - foo_default = torch.ops.mylib.foo.default(arg4_1, [arg2_1, arg3_1], arg1_1, 2, arg0_1); arg4_1 = arg2_1 = arg3_1 = arg1_1 = arg0_1 = None - getitem_4: "f32[3][1]cpu" = foo_default[0] - getitem_5: "f32[3][1]cpu" = foo_default[1]; foo_default = None - return (getitem_4, getitem_5)""", - ) - - eager_args = pytree.tree_map_only(torch.Tensor, torch.clone, orig_args) - eager_out = f(*eager_args) - self.assertEqual(compiled_args, eager_args) - self.assertEqual(compiled_out, eager_out) - finally: - cleanup_op("mylib::foo") - del lib - - def test_auto_functionalize_on_view(self): - try: - lib = torch.library.Library("mylib", "FRAGMENT") - torch.library.define( - "mylib::foo", - "(Tensor(a!) x) -> ()", - tags=torch.Tag.pt2_compliant_tag, - lib=lib, - ) - - @torch.library.impl("mylib::foo", "cpu", lib=lib) - @torch._dynamo.disable - def foo_impl(x): - x_np = x.detach().numpy() # view - np.sin(x_np, out=x_np) - return - - x = torch.randn(3) - expected = x.sin() - torch.ops.mylib.foo(x) - assert torch.allclose(x, expected) - - @torch.compile(backend="aot_eager_decomp_partition", fullgraph=True) - def f(x): - x = x.clone() - y = x[:] - torch.ops.mylib.foo(y) - return x - - y = f(x) - self.assertEqual(y, x.sin()) - finally: - cleanup_op("mylib::foo") - del lib - - def test_auto_functionalize_optional(self): - try: - lib = torch.library.Library("mylib", "FRAGMENT") - torch.library.define( - "mylib::foo", - "(Tensor(a!)? x, Tensor[] y, Tensor(b!)? z, SymInt w, Tensor n) -> ()", - tags=torch.Tag.pt2_compliant_tag, - lib=lib, - ) - - @torch.library.impl("mylib::foo", "cpu", lib=lib) - @torch._dynamo.disable - def foo_impl(x, y, z, w, n): - if x is not None: - x.add_(y[0] + w) - if z is not None: - z.add_(y[1] + n) - - def f(x, y, z, n): - torch.ops.mylib.foo(x, y, z, 2, n) - - x = None - y = (torch.randn(3), torch.randn(3)) - z = torch.randn(3) - n = torch.randn(3) - orig_args = (x, y, z, n) - - compiled_args = pytree.tree_map_only(torch.Tensor, torch.clone, orig_args) - log_stream, ctx = logs_to_string( - "torch._inductor.compile_fx", "post_grad_graphs" - ) - with ctx(): - torch.compile(f, backend="inductor", fullgraph=True)(*compiled_args) - - if torch._dynamo.config.assume_static_by_default: - post_grad_graphs = "\n".join( - log_stream.getvalue().strip().split("\n")[3:] - ).strip() - self.assertExpectedInline( - post_grad_graphs, - """\ -def forward(self, arg0_1: "f32[3][1]cpu", arg1_1: "f32[3][1]cpu", arg2_1: "f32[3][1]cpu", arg3_1: "f32[3][1]cpu"): - # No stacktrace found for following nodes - foo_default = torch.ops.mylib.foo.default(None, [arg2_1, arg3_1], arg1_1, 2, arg0_1); arg2_1 = arg3_1 = arg1_1 = arg0_1 = foo_default = None - return ()""", - ) - - eager_args = pytree.tree_map_only(torch.Tensor, torch.clone, orig_args) - f(*eager_args) - self.assertEqual(compiled_args, eager_args) - finally: - cleanup_op("mylib::foo") - del lib - - def test_auto_functionalize_tensorlist(self): - with torch.library._scoped_library("mylib", "FRAGMENT") as lib: - torch.library.define( - "mylib::foo", - "(Tensor all_gather_output, SymInt[] all_gather_input_split_sizes, int dim, Tensor(a!)[] out) -> ()", - tags=torch.Tag.pt2_compliant_tag, - lib=lib, - ) - - @torch.library.impl("mylib::foo", "cpu", lib=lib) - @torch._dynamo.disable - def foo_impl(all_gather_output, all_gather_input_split_sizes, dim, out): - for o in out: - o.copy_(all_gather_output) - - def f(all_gather_output, all_gather_input_split_sizes, dim, out): - torch.ops.mylib.foo( - all_gather_output, all_gather_input_split_sizes, dim, out - ) - - a = torch.ones(4) - b = [2, 3] - c = 0 - d = [torch.empty(4) for _ in range(2)] - orig_args = (a, b, c, d) - - compiled_args = pytree.tree_map_only(torch.Tensor, torch.clone, orig_args) - torch.compile(f, backend="inductor", fullgraph=True)(*compiled_args) - - eager_args = pytree.tree_map_only(torch.Tensor, torch.clone, orig_args) - f(*eager_args) - self.assertEqual(compiled_args, eager_args) def test_shape_int_inplace_binops(self): def fn(x): @@ -8961,25 +8592,6 @@ def f(lengths, values): @torch._dynamo.config.patch( capture_scalar_outputs=True, capture_dynamic_output_shape_ops=True ) - def test_unbacked_auto_functionalize_op(self): - @torch.library.custom_op( - "mylib::mk_image", mutates_args=("decoder",), device_types=["cpu"] - ) - def mk_image(decoder: Tensor) -> Tensor: - return torch.randn(2, 3, 4, 5) - - @torch.library.register_fake("mylib::mk_image") - def _(decoder: Tensor) -> Tensor: - image_size = [torch.library.get_ctx().new_dynamic_size() for _ in range(4)] - return torch.empty(image_size) - - @torch.compile(fullgraph=True) - def f(x): - return torch.ops.mylib.mk_image.default(x) - - x = torch.zeros(100, dtype=torch.int64) - f(x) - def test_out_variant_custom_op(self): with torch.library._scoped_library("mylib", "FRAGMENT") as lib: lib.define( diff --git a/test/inductor/test_auto_functionalize.py b/test/inductor/test_auto_functionalize.py new file mode 100644 index 00000000000000..a007c3451896e8 --- /dev/null +++ b/test/inductor/test_auto_functionalize.py @@ -0,0 +1,363 @@ +# Owner(s): ["module: functionalization"] + +import numpy as np + +import torch +import torch._dynamo.testing +import torch._inductor.test_case +import torch.onnx.operators +import torch.utils._pytree as pytree +import torch.utils.cpp_extension +from torch import Tensor +from torch.testing._internal.logging_utils import logs_to_string + + +class AutoFunctionalizeTests(torch._inductor.test_case.TestCase): + def test_auto_functionalize_can_with_default(self): + with torch.library._scoped_library("mylib", "FRAGMENT") as lib: + torch.library.define( + "mylib::foo", + "(Tensor a, int b, Tensor(d!)? c=None, Tensor? d=None, int e=-1) -> ()", + tags=torch.Tag.pt2_compliant_tag, + lib=lib, + ) + + @torch.library.impl("mylib::foo", "cpu", lib=lib) + def foo_impl(a, b, c=None, d=None, e=-1): + a + b + return + + def f(a, mode): + return torch.ops.mylib.foo( + a, + 0, + ) + + a = torch.tensor([10, 10, 10], dtype=torch.int64) + + torch.compile(f)(a, 0) + + def test_auto_functionalize_can_with_none_return(self): + with torch.library._scoped_library("mylib", "FRAGMENT") as lib: + lib.define("foo(Tensor x, Tensor(a!) out) -> None") + + def foo_impl(x, out): + out.copy_(x) + + lib.impl("foo", foo_impl, "CompositeExplicitAutograd") + x = torch.randn(3) + out = torch.zeros(3) + + @torch.compile + def f(x, out): + torch.ops.mylib.foo(x, out) + + f(x, out) + + def test_auto_functionalize_self_as_mutate_arg(self): + with torch.library._scoped_library("mylib", "FRAGMENT") as lib: + lib.define("foo(Tensor(a!) self) -> None") + + def foo_impl(self: torch.Tensor) -> None: + self.sin_() + + x = torch.randn(3) + lib.impl("foo", foo_impl, "CompositeExplicitAutograd") + + @torch.compile(backend="inductor", fullgraph=True) + def f(x): + torch.ops.mylib.foo(x) + + f(x) + + def test_auto_functionalize_tensorlist(self): + with torch.library._scoped_library("mylib", "FRAGMENT") as lib: + torch.library.define( + "mylib::foo", + "(Tensor all_gather_output, SymInt[] all_gather_input_split_sizes, int dim, Tensor(a!)[] out) -> ()", + tags=torch.Tag.pt2_compliant_tag, + lib=lib, + ) + + @torch.library.impl("mylib::foo", "cpu", lib=lib) + @torch._dynamo.disable + def foo_impl(all_gather_output, all_gather_input_split_sizes, dim, out): + for o in out: + o.copy_(all_gather_output) + + def f(all_gather_output, all_gather_input_split_sizes, dim, out): + torch.ops.mylib.foo( + all_gather_output, all_gather_input_split_sizes, dim, out + ) + + a = torch.ones(4) + b = [2, 3] + c = 0 + d = [torch.empty(4) for _ in range(2)] + orig_args = (a, b, c, d) + + compiled_args = pytree.tree_map_only(torch.Tensor, torch.clone, orig_args) + torch.compile(f, backend="inductor", fullgraph=True)(*compiled_args) + + eager_args = pytree.tree_map_only(torch.Tensor, torch.clone, orig_args) + f(*eager_args) + self.assertEqual(compiled_args, eager_args) + + def test_can_auto_functionalize(self): + from torch._higher_order_ops.auto_functionalize import can_auto_functionalize + + expected_true = [ + "(Tensor(a!) x) -> ()", + "(Tensor(a!) x, Tensor y, Tensor(b!) z, SymInt w, Tensor(c!)? n) -> ()", + "(Tensor(a!) x, Tensor[] y, Tensor(b!) z, SymInt w, Tensor(c!)? n) -> ()", + "(Tensor(a!) x, Tensor y, Tensor(b!)[] z, SymInt w) -> ()", + "(Tensor(a!) x, Tensor y, Tensor(b!) z, SymInt w, Tensor(c!)? n) -> Tensor", + "(Tensor(a!) x, Tensor y, Tensor(b!) z, SymInt w, Tensor(c!)? n) -> (Tensor, Tensor)", + ] + expected_false = [ + "(Tensor x) -> ()", + "(Tensor(a) x) -> Tensor(a)", + "(Tensor(a!) x) -> Tensor(a!)", + "(Tensor(a!) x, Tensor y, Tensor(b!) z, SymInt w, Tensor(c!)? n) -> Tensor(a)", + "(Tensor(a!) x, Tensor y, Tensor(b!) z, SymInt w, Tensor(c!)? n) -> (Tensor, Tensor(a))", + "(Tensor(a) x, Tensor y, Tensor(b!) z, SymInt w, Tensor(c!)? n) -> (Tensor, Tensor(a))", + "(Tensor(a!) x, Tensor y, Tensor(b!) z, SymInt w, Tensor(c!)? n) -> (Tensor, Tensor[])", + ] + for schema in expected_true: + with torch.library._scoped_library("mylib", "FRAGMENT") as lib: + torch.library.define("mylib::a", schema, lib=lib) + self.assertTrue( + can_auto_functionalize(torch.ops.mylib.a.default), msg=schema + ) + self.assertFalse(can_auto_functionalize(torch.ops.mylib.a)) + + for schema in expected_false: + with torch.library._scoped_library("mylib", "FRAGMENT") as lib: + torch.library.define("mylib::a", schema, lib=lib) + self.assertFalse( + can_auto_functionalize(torch.ops.mylib.a.default), msg=schema + ) + self.assertFalse(can_auto_functionalize(torch.ops.mylib.a)) + + def test_auto_functionalize(self): + with torch.library._scoped_library("mylib", "FRAGMENT") as lib: + torch.library.define( + "mylib::foo", + "(Tensor(a!) x, Tensor[] y, Tensor(b!) z, SymInt w, Tensor n) -> ()", + tags=torch.Tag.pt2_compliant_tag, + lib=lib, + ) + + @torch.library.impl("mylib::foo", "cpu", lib=lib) + @torch._dynamo.disable + def foo_impl(x, y, z, w, n): + x.add_(y[0] + w) + z.add_(y[1] + n) + + def f(x, y, z, n): + torch.ops.mylib.foo(x, y, z, 2, n) + + x = torch.randn(3) + y = (torch.randn(3), torch.randn(3)) + z = torch.randn(3) + n = torch.randn(3) + orig_args = (x, y, z, n) + + compiled_args = pytree.tree_map_only(torch.Tensor, torch.clone, orig_args) + + log_stream, ctx = logs_to_string( + "torch._inductor.compile_fx", "post_grad_graphs" + ) + with ctx(): + torch.compile(f, backend="inductor", fullgraph=True)(*compiled_args) + + post_grad_graphs = "\n".join( + log_stream.getvalue().strip().split("\n")[3:] + ).strip() + + # Check the graph under static shapes + if torch._dynamo.config.assume_static_by_default: + self.assertExpectedInline( + post_grad_graphs, + """\ +def forward(self, arg0_1: "f32[3][1]cpu", arg1_1: "f32[3][1]cpu", arg2_1: "f32[3][1]cpu", arg3_1: \ +"f32[3][1]cpu", arg4_1: "f32[3][1]cpu"): + # No stacktrace found for following nodes + foo_default = torch.ops.mylib.foo.default(arg4_1, [arg2_1, arg3_1], arg1_1, 2, arg0_1); arg4_1 = arg2_1 = \ +arg3_1 = arg1_1 = arg0_1 = foo_default = None + return ()""", + ) + + eager_args = pytree.tree_map_only(torch.Tensor, torch.clone, orig_args) + f(*eager_args) + self.assertEqual(compiled_args, eager_args) + + def test_auto_functionalize_with_returns(self): + with torch.library._scoped_library("mylib", "FRAGMENT") as lib: + torch.library.define( + "mylib::foo", + "(Tensor(a!) x, Tensor[] y, Tensor(b!) z, SymInt w, Tensor n) -> (Tensor, Tensor)", + tags=torch.Tag.pt2_compliant_tag, + lib=lib, + ) + + @torch.library.impl("mylib::foo", "cpu", lib=lib) + @torch._dynamo.disable + def foo_impl(x, y, z, w, n): + x.add_(y[0] + w) + z.add_(y[1] + n) + return y[0] + w, y[1] + n + + @torch.library.impl_abstract("mylib::foo", lib=lib) + def foo_abstract(x, y, z, w, n): + return y[0] + w, y[1] + n + + def f(x, y, z, n): + return torch.ops.mylib.foo(x, y, z, 2, n) + + x = torch.randn(3) + y = (torch.randn(3), torch.randn(3)) + z = torch.randn(3) + n = torch.randn(3) + orig_args = (x, y, z, n) + + compiled_args = pytree.tree_map_only(torch.Tensor, torch.clone, orig_args) + log_stream, ctx = logs_to_string( + "torch._inductor.compile_fx", "post_grad_graphs" + ) + with ctx(): + compiled_out = torch.compile(f, backend="inductor", fullgraph=True)( + *compiled_args + ) + + if torch._dynamo.config.assume_static_by_default: + post_grad_graphs = "\n".join( + log_stream.getvalue().strip().split("\n")[3:] + ).strip() + self.assertExpectedInline( + post_grad_graphs, + """\ +def forward(self, arg0_1: "f32[3][1]cpu", arg1_1: "f32[3][1]cpu", arg2_1: "f32[3][1]cpu", \ +arg3_1: "f32[3][1]cpu", arg4_1: "f32[3][1]cpu"): + # No stacktrace found for following nodes + foo_default = torch.ops.mylib.foo.default(arg4_1, [arg2_1, arg3_1], arg1_1, 2, arg0_1); \ +arg4_1 = arg2_1 = arg3_1 = arg1_1 = arg0_1 = None + getitem_4: "f32[3][1]cpu" = foo_default[0] + getitem_5: "f32[3][1]cpu" = foo_default[1]; foo_default = None + return (getitem_4, getitem_5)""", + ) + + eager_args = pytree.tree_map_only(torch.Tensor, torch.clone, orig_args) + eager_out = f(*eager_args) + self.assertEqual(compiled_args, eager_args) + self.assertEqual(compiled_out, eager_out) + + def test_auto_functionalize_on_view(self): + with torch.library._scoped_library("mylib", "FRAGMENT") as lib: + torch.library.define( + "mylib::foo", + "(Tensor(a!) x) -> ()", + tags=torch.Tag.pt2_compliant_tag, + lib=lib, + ) + + @torch.library.impl("mylib::foo", "cpu", lib=lib) + @torch._dynamo.disable + def foo_impl(x): + x_np = x.detach().numpy() # view + np.sin(x_np, out=x_np) + return + + x = torch.randn(3) + expected = x.sin() + torch.ops.mylib.foo(x) + assert torch.allclose(x, expected) + + @torch.compile(backend="aot_eager_decomp_partition", fullgraph=True) + def f(x): + x = x.clone() + y = x[:] + torch.ops.mylib.foo(y) + return x + + y = f(x) + self.assertEqual(y, x.sin()) + + def test_auto_functionalize_optional(self): + with torch.library._scoped_library("mylib", "FRAGMENT") as lib: + torch.library.define( + "mylib::foo", + "(Tensor(a!)? x, Tensor[] y, Tensor(b!)? z, SymInt w, Tensor n) -> ()", + tags=torch.Tag.pt2_compliant_tag, + lib=lib, + ) + + @torch.library.impl("mylib::foo", "cpu", lib=lib) + @torch._dynamo.disable + def foo_impl(x, y, z, w, n): + if x is not None: + x.add_(y[0] + w) + if z is not None: + z.add_(y[1] + n) + + def f(x, y, z, n): + torch.ops.mylib.foo(x, y, z, 2, n) + + x = None + y = (torch.randn(3), torch.randn(3)) + z = torch.randn(3) + n = torch.randn(3) + orig_args = (x, y, z, n) + + compiled_args = pytree.tree_map_only(torch.Tensor, torch.clone, orig_args) + log_stream, ctx = logs_to_string( + "torch._inductor.compile_fx", "post_grad_graphs" + ) + with ctx(): + torch.compile(f, backend="inductor", fullgraph=True)(*compiled_args) + + if torch._dynamo.config.assume_static_by_default: + post_grad_graphs = "\n".join( + log_stream.getvalue().strip().split("\n")[3:] + ).strip() + self.assertExpectedInline( + post_grad_graphs, + """\ +def forward(self, arg0_1: "f32[3][1]cpu", arg1_1: "f32[3][1]cpu", arg2_1: "f32[3][1]cpu", arg3_1: "f32[3][1]cpu"): + # No stacktrace found for following nodes + foo_default = torch.ops.mylib.foo.default(None, [arg2_1, arg3_1], arg1_1, 2, arg0_1); \ +arg2_1 = arg3_1 = arg1_1 = arg0_1 = foo_default = None + return ()""", + ) + + eager_args = pytree.tree_map_only(torch.Tensor, torch.clone, orig_args) + f(*eager_args) + self.assertEqual(compiled_args, eager_args) + + @torch._dynamo.config.patch( + capture_scalar_outputs=True, capture_dynamic_output_shape_ops=True + ) + def test_unbacked_auto_functionalize_op(self): + @torch.library.custom_op( + "mylib::mk_image", mutates_args=("decoder",), device_types=["cpu"] + ) + def mk_image(decoder: Tensor) -> Tensor: + return torch.randn(2, 3, 4, 5) + + @torch.library.register_fake("mylib::mk_image") + def _(decoder: Tensor) -> Tensor: + image_size = [torch.library.get_ctx().new_dynamic_size() for _ in range(4)] + return torch.empty(image_size) + + @torch.compile(fullgraph=True) + def f(x): + return torch.ops.mylib.mk_image.default(x) + + x = torch.zeros(100, dtype=torch.int64) + f(x) + + +if __name__ == "__main__": + from torch._inductor.test_case import run_tests + + run_tests() From 35491fac375084aab97f3c5b0da6748097b760a6 Mon Sep 17 00:00:00 2001 From: Tobias Ringwald Date: Tue, 3 Sep 2024 17:28:36 +0000 Subject: [PATCH 0359/1018] Added complex support for `torch.logsumexp` (#133187) Added complex support for `torch.logsumexp`. Implemented complex backward pass for `torch.logsumexp`. Fixes #133047 Pull Request resolved: https://github.com/pytorch/pytorch/pull/133187 Approved by: https://github.com/amjames, https://github.com/lezcano --- aten/src/ATen/native/ReduceOps.cpp | 7 +++++-- test/test_mps.py | 14 ++++++++++++++ test/test_reductions.py | 15 +++++++++++---- tools/autograd/gen_variable_type.py | 1 + torch/_refs/__init__.py | 2 +- torch/csrc/autograd/FunctionsManual.cpp | 5 +++-- torch/masked/_ops.py | 13 ++++++++++--- .../_internal/common_methods_invocations.py | 4 ++-- .../_internal/opinfo/definitions/_masked.py | 2 +- 9 files changed, 48 insertions(+), 15 deletions(-) diff --git a/aten/src/ATen/native/ReduceOps.cpp b/aten/src/ATen/native/ReduceOps.cpp index 7a3e34eb6df591..bbc50c3c2fca75 100644 --- a/aten/src/ATen/native/ReduceOps.cpp +++ b/aten/src/ATen/native/ReduceOps.cpp @@ -1456,7 +1456,9 @@ Tensor nanmean( static Tensor& logsumexp_out_impl(Tensor& result, const Tensor& self, IntArrayRef dims, bool keepdim) { // can't take max of empty tensor if (self.numel() != 0) { - auto maxes = at::amax(self, dims, true); + // For complex numbers, use the real part to calculate the max. Based on + // https://scicomp.stackexchange.com/questions/34273/log-sum-exp-trick-for-signed-complex-numbers + auto maxes = at::amax(at::real(self), dims, true); auto maxes_squeezed = (keepdim ? maxes : at::squeeze(maxes, dims)); maxes_squeezed.masked_fill_(maxes_squeezed.abs() == INFINITY, 0); at::sum_out(result, (self - maxes).exp_(), dims, keepdim); @@ -1469,7 +1471,8 @@ static Tensor& logsumexp_out_impl(Tensor& result, const Tensor& self, IntArrayRe } Tensor& logsumexp_out(const Tensor& self, IntArrayRef dims, bool keepdim, Tensor& result) { - TORCH_CHECK(at::isFloatingType(result.scalar_type()), + // Complex type implies floating point type + TORCH_CHECK(at::isFloatingType(result.scalar_type()) || at::isComplexType(result.scalar_type()), "logsumexp(): Expected floating point type for result tensor, but got: ", result.scalar_type()); { diff --git a/test/test_mps.py b/test/test_mps.py index f7f36e57c82adb..f3dd38dda7cfae 100644 --- a/test/test_mps.py +++ b/test/test_mps.py @@ -438,6 +438,7 @@ def mps_ops_modifier(ops): 'logical_not', 'logical_or', 'logical_xor', + 'logsumexp', 'long', 'masked_fill', 'masked.mean', @@ -445,6 +446,7 @@ def mps_ops_modifier(ops): 'masked.std', 'masked.sum', 'masked.var', + 'masked.logsumexp', 'matmul', 'mean', 'mm', @@ -6540,6 +6542,18 @@ def helper(shape): helper((2, 8, 4, 5)) + def test_logsumexp(self): + def helper(shape): + cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False) + x = cpu_x.detach().clone().to('mps') + + log_result = torch.logsumexp(x, -1) + log_result_cpu = torch.logsumexp(cpu_x, -1) + + self.assertEqual(log_result, log_result_cpu) + + helper((2, 8, 4, 5)) + # Test concat forward def test_cat2(self): diff --git a/test/test_reductions.py b/test/test_reductions.py index 2375324d94f853..c31408d8de6f19 100644 --- a/test/test_reductions.py +++ b/test/test_reductions.py @@ -487,10 +487,14 @@ def test_dim_reduction_lastdim(self, device, dtype): self.assertEqual(y, y2) @skipIfNoSciPy - def test_logsumexp(self, device): + @dtypes(torch.float32, torch.double, torch.complex64, torch.complex128) + def test_logsumexp(self, device, dtype): from scipy.special import logsumexp - a = torch.randn(5, 4, device=device) - a[0, 0] = inf + a = torch.randn(5, 4, device=device, dtype=dtype) + # torch.exp(complex(inf, 0)) yields inf+nan*j instead of inf+0*j on CPU which disagrees with CUDA, C++ std::exp, + # numpy and scipy. Skip inf testing on CPU. Related to https://github.com/pytorch/pytorch/issues/95740 + if torch.device(device) != torch.device('cpu'): + a[0, 0] = inf a[1, :] = -inf actual = a.logsumexp(1) expected = logsumexp(a.cpu().numpy(), 1) @@ -498,11 +502,14 @@ def test_logsumexp(self, device): self.assertEqual(expected, actual) # check that out is actually inplace - b = torch.zeros(5, 2, device=device) + b = torch.zeros(5, 2, device=device, dtype=dtype) c = b[:, 0] torch.logsumexp(a, 1, out=c) self.assertEqual(expected, b[:, 0]) + @skipIfNoSciPy + def test_logsumexp_integral_promotion(self, device): + from scipy.special import logsumexp # check integral inputs is promoted to floating point e = torch.randint(-100, 100, [5, 4], device=device) actual = e.logsumexp(1).to(torch.float64) diff --git a/tools/autograd/gen_variable_type.py b/tools/autograd/gen_variable_type.py index 7b908425cf8660..4bec1871ae483e 100644 --- a/tools/autograd/gen_variable_type.py +++ b/tools/autograd/gen_variable_type.py @@ -259,6 +259,7 @@ "log1p", "log2", "logaddexp", + "logsumexp", "logcumsumexp", "reciprocal", "tan", diff --git a/torch/_refs/__init__.py b/torch/_refs/__init__.py index 2da7ff0669771c..673f5f1dea8295 100644 --- a/torch/_refs/__init__.py +++ b/torch/_refs/__init__.py @@ -804,7 +804,7 @@ def logsumexp( dim = (dim,) if self.numel() == 0: return torch.sum(torch.exp(self), dim, keepdim).log() - maxes = torch.amax(self, dim, keepdim=True) + maxes = torch.amax(torch.real(self), dim, keepdim=True) maxes = torch.masked_fill(maxes, maxes.abs() == float("inf"), 0) maxes_squeezed = maxes if keepdim else torch.squeeze(maxes, dim) result = torch.sum(torch.exp(self - maxes), dim, keepdim) diff --git a/torch/csrc/autograd/FunctionsManual.cpp b/torch/csrc/autograd/FunctionsManual.cpp index 37b1a7bd5983f6..3f24c6ecb40951 100644 --- a/torch/csrc/autograd/FunctionsManual.cpp +++ b/torch/csrc/autograd/FunctionsManual.cpp @@ -897,7 +897,7 @@ Tensor logsumexp_backward( grad = unsqueeze_multiple(grad, dim, self.sym_sizes().size()); result = unsqueeze_multiple(result, dim, self.sym_sizes().size()); } - return grad * (self - result).exp(); + return grad * (self - result).exp().conj(); } Tensor logcumsumexp_backward( @@ -6689,7 +6689,8 @@ Tensor logsumexp_jvp( // forward auto self_p_exp = [&self_p, &dim]() { if (self_p.sym_numel() > 0) { - return (self_p - at::amax(self_p, dim, true)) + // Use only the real part for complex tensors + return (self_p - at::amax(at::real(self_p), dim, true)) .exp(); // Use the exp-normalize trick } else { // amax fails if numel() == 0, in which case it doesn't matter anyway diff --git a/torch/masked/_ops.py b/torch/masked/_ops.py index 72f8249785d288..3d8151197a207e 100644 --- a/torch/masked/_ops.py +++ b/torch/masked/_ops.py @@ -418,11 +418,18 @@ def _reduction_identity(op_name: str, input: Tensor, *args): return torch.tensor(0, dtype=dtype, device=device) elif op_name in {"prod", "cumprod"}: return torch.tensor(1, dtype=dtype, device=device) - elif op_name in {"amax", "argmax", "logsumexp"}: + elif op_name in {"amax", "argmax", "logaddexp"}: if torch.is_floating_point(input): return torch.tensor(-torch.inf, dtype=dtype, device=device) elif torch.is_signed(input) or dtype == torch.uint8: return torch.tensor(torch.iinfo(dtype).min, dtype=dtype, device=device) + elif op_name in {"logsumexp"}: + if torch.is_floating_point(input): + return torch.tensor(-torch.inf, dtype=dtype, device=device) + elif torch.is_complex(input): + return torch.tensor(-torch.inf + 0j, dtype=dtype, device=device) + elif torch.is_signed(input) or dtype == torch.uint8: + return torch.tensor(torch.iinfo(dtype).min, dtype=dtype, device=device) elif op_name in {"amin", "argmin"}: if torch.is_floating_point(input): return torch.tensor(torch.inf, dtype=dtype, device=device) @@ -1523,8 +1530,8 @@ def logaddexp( if dtype is None: dtype = input.dtype if input.layout == torch.strided and other.layout == torch.strided: - mask_input = _combine_input_and_mask(logsumexp, input, input_mask) - mask_other = _combine_input_and_mask(logsumexp, other, other_mask) + mask_input = _combine_input_and_mask(logaddexp, input, input_mask) + mask_other = _combine_input_and_mask(logaddexp, other, other_mask) return torch.logaddexp(mask_input, mask_other).to(dtype=dtype) else: raise ValueError( diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py index 2120f4f019e863..3b81da0a55872c 100644 --- a/torch/testing/_internal/common_methods_invocations.py +++ b/torch/testing/_internal/common_methods_invocations.py @@ -1573,7 +1573,7 @@ def sample_inputs_logsumexp(self, device, dtype, requires_grad, **kwargs): ((S, S), (0, 1), False), ) # Test large inputs to check numerical stability - lows = (None, 1e3, 1e6) if dtype in (torch.float32, torch.float64) else (None,) + lows = (None, 1e3, 1e6) if dtype in (torch.float32, torch.float64, torch.complex64, torch.complex128) else (None,) for low in lows: high = low * 2 if low is not None else None for shape, dim, keepdim in inputs: @@ -19593,7 +19593,7 @@ def sample_inputs_alias_copy(op_info, device, dtype, requires_grad, **kwargs): sample_inputs_func=sample_inputs_zero_), OpInfo('logsumexp', aliases=('special.logsumexp',), - dtypes=all_types_and(torch.bool, torch.half, torch.bfloat16), + dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16), assert_autodiffed=True, supports_forward_ad=True, supports_fwgrad_bwgrad=True, diff --git a/torch/testing/_internal/opinfo/definitions/_masked.py b/torch/testing/_internal/opinfo/definitions/_masked.py index f5290c7643ce14..eda339ebfe68a6 100644 --- a/torch/testing/_internal/opinfo/definitions/_masked.py +++ b/torch/testing/_internal/opinfo/definitions/_masked.py @@ -1182,7 +1182,7 @@ def sample_inputs_masked_normalize(op_info, device, dtype, requires_grad, **kwar ), ReductionOpInfo( "masked.logsumexp", - dtypes=all_types_and(torch.half, torch.bfloat16), + dtypes=all_types_and_complex_and(torch.half, torch.bfloat16), method_variant=None, nan_policy="propagate", supports_out=False, From 92b7bd5afd8675bf75f8769124d2bf02fd5a5ee4 Mon Sep 17 00:00:00 2001 From: Nikita Shulga Date: Tue, 3 Sep 2024 19:20:11 +0000 Subject: [PATCH 0360/1018] [MPS] Fix bachnorm_2d for channels last (#134618) By skipping gather of input tensor if memory_layout is channels_last, which is a first step towards fixing https://github.com/pytorch/pytorch/issues/134580 Though underlying problem is much more interesting, i.e. MPS does not have a generic support for channels last, but `c10::is_contiguoius()` is true for channels last layout. Pull Request resolved: https://github.com/pytorch/pytorch/pull/134618 Approved by: https://github.com/albanD --- aten/src/ATen/native/mps/operations/Normalization.mm | 7 +++++-- torch/testing/_internal/common_modules.py | 2 +- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/aten/src/ATen/native/mps/operations/Normalization.mm b/aten/src/ATen/native/mps/operations/Normalization.mm index 72337b8d340d7d..1d060f78ee2337 100644 --- a/aten/src/ATen/native/mps/operations/Normalization.mm +++ b/aten/src/ATen/native/mps/operations/Normalization.mm @@ -296,7 +296,9 @@ Check if running mean exists (maybe do this check before making graph) newCachedGraph->runningVarInplaceUpdate_ = runningVarInplaceUpdate; }); - auto inputPlaceholder = Placeholder(cachedGraph->inputTensor_, self, input_shape); + const auto needs_gather = memory_format != MemoryFormat::ChannelsLast; + auto inputPlaceholder = + Placeholder(cachedGraph->inputTensor_, self, input_shape, needs_gather, MPSDataTypeInvalid, needs_gather); auto weightPlaceholder = Placeholder(); if (has_weight) weightPlaceholder = Placeholder(cachedGraph->weightTensor_, weight_opt.value(), new_mean_shape); @@ -319,7 +321,8 @@ Check if running mean exists (maybe do this check before making graph) runningVarInplaceUpdatePlaceholder = Placeholder(cachedGraph->runningVarInplaceUpdate_, running_var_opt.value()); } - auto outputPlaceholder = Placeholder(cachedGraph->outputTensor_, output, input_shape, false); + auto outputPlaceholder = + Placeholder(cachedGraph->outputTensor_, output, input_shape, false, MPSDataTypeInvalid, needs_gather); auto saveMeanPlaceholder = Placeholder(cachedGraph->saveMeanTensor_, save_mean); auto saveVarPlaceholder = Placeholder(cachedGraph->saveVarTensor_, save_var); diff --git a/torch/testing/_internal/common_modules.py b/torch/testing/_internal/common_modules.py index 02b621640a9a3f..00251ca264f840 100644 --- a/torch/testing/_internal/common_modules.py +++ b/torch/testing/_internal/common_modules.py @@ -3452,7 +3452,7 @@ def module_error_inputs_torch_nn_Pad3d(module_info, device, dtype, requires_grad module_inputs_func=module_inputs_torch_nn_BatchNorm2d, skips=( # See https://github.com/pytorch/pytorch/issues/134580 - DecorateInfo(expectedFailureMPS, 'TestModule', 'test_memory_format'), + DecorateInfo(expectedFailureMPS, 'TestModule', 'test_memory_format', active_if=operator.itemgetter('training')), # tracking here rather than in the list in test_aotdispatch.py as eval mode passes # RuntimeError: tried to get Double out of SymInt DecorateInfo( From b9e7b4966816b67f9fab1c14cefe256c649ab92c Mon Sep 17 00:00:00 2001 From: Xilun Wu <12968408+XilunWu@users.noreply.github.com> Date: Fri, 30 Aug 2024 14:02:28 -0700 Subject: [PATCH 0361/1018] [TorchElastic] make torch elastic not have to realize TCPStore backend type and rely on c10d to decide which backend to use (#134882) D53335860 and D56435815 added an option to torch elastic allowing users to choose a TCPStore backend type to use via 1) explicit argument passing in user code when instantiating `MastRendezvousHandler` 2) pass `--use_libuv` command line argument to `torchrun`. The motivation was to offer a quick way to roll back to non-libuv TCPStore backend since we were making libuv the default in `c10d` code. Now we think that it's better to have torch elastic to not realize the TCPStore backend type but rely on `c10d`'s mechanism to decide which backend to use for torch elastic as well. In this sense, the TCPStore backend type used by torch elastic will be identical to that in pytorch. PyTorch TCPStore uses the environment variable `USE_LIBUV` to determine the backend type: when `USE_LIBUV="0"`, the non-libuv backend will be used. when `USE_LIBUV="1"`, the libuv backend will be used. And this is the default option. Differential Revision: [D58259590](https://our.internmc.facebook.com/intern/diff/D58259590/) Pull Request resolved: https://github.com/pytorch/pytorch/pull/134882 Approved by: https://github.com/shuqiangzhang --- .../rendezvous/c10d_rendezvous_backend_test.py | 10 ---------- .../elastic/utils/distributed_test.py | 6 ++++-- .../rendezvous/c10d_rendezvous_backend.py | 3 --- .../elastic/rendezvous/dynamic_rendezvous.py | 1 - .../elastic/rendezvous/static_tcp_rendezvous.py | 1 + torch/distributed/elastic/utils/distributed.py | 17 +++++++++-------- 6 files changed, 14 insertions(+), 24 deletions(-) diff --git a/test/distributed/elastic/rendezvous/c10d_rendezvous_backend_test.py b/test/distributed/elastic/rendezvous/c10d_rendezvous_backend_test.py index 2838d918964021..2175e8b17e9522 100644 --- a/test/distributed/elastic/rendezvous/c10d_rendezvous_backend_test.py +++ b/test/distributed/elastic/rendezvous/c10d_rendezvous_backend_test.py @@ -206,16 +206,6 @@ def test_create_backend_returns_backend_if_read_timeout_is_not_specified( self._assert_create_backend_returns_backend() - def test_create_backend_returns_backend_with_libuv(self) -> None: - self._params.config["use_libuv"] = "true" - - self._assert_create_backend_returns_backend() - - def test_create_backend_returns_backend_without_libuv(self) -> None: - self._params.config["use_libuv"] = "false" - - self._assert_create_backend_returns_backend() - def test_create_backend_raises_error_if_store_is_unreachable(self) -> None: self._params.config["is_host"] = "false" self._params.config["read_timeout"] = "2" diff --git a/test/distributed/elastic/utils/distributed_test.py b/test/distributed/elastic/utils/distributed_test.py index ded2c7e9d458f0..c0562a2a3e775a 100644 --- a/test/distributed/elastic/utils/distributed_test.py +++ b/test/distributed/elastic/utils/distributed_test.py @@ -131,6 +131,7 @@ def test_create_store_with_libuv_support(self): wait_for_workers = False localhost = socket.gethostname() + os.environ["USE_LIBUV"] = "0" store = create_c10d_store( is_server=True, server_addr=localhost, @@ -138,10 +139,12 @@ def test_create_store_with_libuv_support(self): timeout=2, world_size=world_size, wait_for_workers=wait_for_workers, - use_libuv=False, ) self.assertFalse(store.libuvBackend) + del os.environ["USE_LIBUV"] + assert "USE_LIBUV" not in os.environ + # libuv backend is enabled by default store = create_c10d_store( is_server=True, server_addr=localhost, @@ -149,7 +152,6 @@ def test_create_store_with_libuv_support(self): timeout=2, world_size=world_size, wait_for_workers=wait_for_workers, - use_libuv=True, ) self.assertTrue(store.libuvBackend) diff --git a/torch/distributed/elastic/rendezvous/c10d_rendezvous_backend.py b/torch/distributed/elastic/rendezvous/c10d_rendezvous_backend.py index 26c3153d9785b6..d49f4724a0ea76 100644 --- a/torch/distributed/elastic/rendezvous/c10d_rendezvous_backend.py +++ b/torch/distributed/elastic/rendezvous/c10d_rendezvous_backend.py @@ -144,8 +144,6 @@ def _create_tcp_store(params: RendezvousParameters) -> TCPStore: else: is_host = _matches_machine_hostname(host) - use_libuv = params.get_as_bool("use_libuv", False) - # The timeout read_timeout = cast(int, params.get_as_int("read_timeout", 60)) if read_timeout <= 0: @@ -161,7 +159,6 @@ def _create_tcp_store(params: RendezvousParameters) -> TCPStore: is_master=is_server, multi_tenant=True, timeout=timedelta(seconds=read_timeout), - use_libuv=use_libuv, ) if is_server: diff --git a/torch/distributed/elastic/rendezvous/dynamic_rendezvous.py b/torch/distributed/elastic/rendezvous/dynamic_rendezvous.py index 31627cf0a0b27c..a2d2011f8173ea 100644 --- a/torch/distributed/elastic/rendezvous/dynamic_rendezvous.py +++ b/torch/distributed/elastic/rendezvous/dynamic_rendezvous.py @@ -1115,7 +1115,6 @@ def _create_tcp_store_server(self, bootstrap_store_info) -> dist.TCPStore: bootstrap_store_info.master_port, is_master=True, multi_tenant=True, - use_libuv=True, ) @property diff --git a/torch/distributed/elastic/rendezvous/static_tcp_rendezvous.py b/torch/distributed/elastic/rendezvous/static_tcp_rendezvous.py index 5d2679d9fb4a0c..e6395b70be2b43 100644 --- a/torch/distributed/elastic/rendezvous/static_tcp_rendezvous.py +++ b/torch/distributed/elastic/rendezvous/static_tcp_rendezvous.py @@ -122,6 +122,7 @@ def create_rdzv_handler(params: RendezvousParameters) -> RendezvousHandler: timeout = int(params.config["timeout"]) else: timeout = _default_timeout_seconds + return StaticTCPRendezvous( master_addr, master_port, rank, world_size, run_id, timeout ) diff --git a/torch/distributed/elastic/utils/distributed.py b/torch/distributed/elastic/utils/distributed.py index 1a7ea81451f7c1..4d64fccb698942 100644 --- a/torch/distributed/elastic/utils/distributed.py +++ b/torch/distributed/elastic/utils/distributed.py @@ -7,7 +7,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. import datetime -import functools +import os import socket from contextlib import closing from typing import Optional @@ -37,6 +37,12 @@ def create_c10d_store( retries=3, use_libuv: Optional[bool] = None, ): + if use_libuv is not None: + logger.warning("argument use_libuv is deprecated and ignored.") + + # check os.environ for use_libuv + use_libuv = os.environ.get("USE_LIBUV", "1") == "1" # libuv is the default option + if server_port == -1 and world_size > 1: raise ValueError( f"server_port must be specified when world_size > 1, got server_port={server_port}, world_size={world_size}" @@ -68,20 +74,15 @@ def create_c10d_store( ) try: - store_builder = functools.partial( - dist.TCPStore, + store = dist.TCPStore( host_name=server_addr, port=port, world_size=world_size, is_master=is_server, timeout=datetime.timedelta(seconds=timeout), wait_for_workers=wait_for_workers, + use_libuv=use_libuv, ) - if use_libuv is None: - # TCPStore default backend may change, don't specify it unless we explicity told to do so. - store = store_builder() - else: - store = store_builder(use_libuv=use_libuv) # skips full rank check when we don't have to wait for all workers if wait_for_workers: _check_full_rank(store, world_size, timeout=timeout) From 7517fd8d4603af9dc7f40c4105f5611c598b36ca Mon Sep 17 00:00:00 2001 From: Driss Guessous Date: Tue, 3 Sep 2024 20:09:19 +0000 Subject: [PATCH 0362/1018] [FlexAttention] Update tolerance for failing test (#135035) Summary: Address: T198937061 Test Plan: buck2 test 'fbcode//mode/opt' fbcode//caffe2/test/inductor:flex_attention -- --exact 'caffe2/test/inductor:flex_attention - test_no_q_info_compile_False (caffe2.test.inductor.test_flex_attention.TestBlockMask)' --run-disabled Differential Revision: D62137797 Pull Request resolved: https://github.com/pytorch/pytorch/pull/135035 Approved by: https://github.com/Chillee --- test/inductor/test_flex_attention.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/inductor/test_flex_attention.py b/test/inductor/test_flex_attention.py index 4652166ceb62b5..2776b8b89bfa72 100644 --- a/test/inductor/test_flex_attention.py +++ b/test/inductor/test_flex_attention.py @@ -2172,7 +2172,7 @@ def causal_mask(b, h, q_idx, kv_idx): *inputs, is_causal=True ) - torch.testing.assert_close(causal_mask_out, sdpa_mask_out, atol=1e-3, rtol=0.0) + torch.testing.assert_close(causal_mask_out, sdpa_mask_out, atol=5e-3, rtol=0.0) common_utils.instantiate_parametrized_tests(TestFlexAttention) From 7a82433b51908499dd6b9eaa6d8f94b093cb6e4e Mon Sep 17 00:00:00 2001 From: Xu Han Date: Tue, 3 Sep 2024 20:10:19 +0000 Subject: [PATCH 0363/1018] [inductor] add openmp config for intel conpiler on Linux. (#134973) Config `openmp` for Intel Compiler on Linux. Base on this PR, we can confirm the Intel optimized libraries are work built well. image Pull Request resolved: https://github.com/pytorch/pytorch/pull/134973 Approved by: https://github.com/jgong5, https://github.com/jansel --- torch/_inductor/cpp_builder.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/torch/_inductor/cpp_builder.py b/torch/_inductor/cpp_builder.py index 45aaa2f399daef..7cfefbc1d84af9 100644 --- a/torch/_inductor/cpp_builder.py +++ b/torch/_inductor/cpp_builder.py @@ -959,6 +959,8 @@ def _get_openmp_args( # TODO: fix issue, can't find omp.h cflags.append("fopenmp") libs.append("gomp") + elif _is_intel_compiler(cpp_compiler): + cflags.append("fiopenmp") else: cflags.append("fopenmp") libs.append("gomp") From fb065399f04b87bee1905f93ca4754f94bfba10e Mon Sep 17 00:00:00 2001 From: rzou Date: Tue, 3 Sep 2024 11:27:53 -0700 Subject: [PATCH 0364/1018] autolabel aotinductor->export (#135040) "module: aotinductor" will automatically add "oncall: export". Test Plan: - none Pull Request resolved: https://github.com/pytorch/pytorch/pull/135040 Approved by: https://github.com/ydwu4 --- .github/label_to_label.yml | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/.github/label_to_label.yml b/.github/label_to_label.yml index 9f3ed07a544c70..5d6544a2f50f0a 100644 --- a/.github/label_to_label.yml +++ b/.github/label_to_label.yml @@ -31,6 +31,10 @@ - "module: flex attention" then: - "module: higher order operators" +- any: + - "module: aotinductor" + then: + - "oncall: export" - any: - "module: dynamo" - "module: pt2-dispatcher" From 604e5c23bf98469979ed0592e196e019c05f51c4 Mon Sep 17 00:00:00 2001 From: drisspg Date: Tue, 3 Sep 2024 10:54:51 -0700 Subject: [PATCH 0365/1018] [Docs] Fix call to deprecated function (#135037) Pull Request resolved: https://github.com/pytorch/pytorch/pull/135037 Approved by: https://github.com/janeyx99, https://github.com/jbschlosser --- torch/nn/functional.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch/nn/functional.py b/torch/nn/functional.py index 4bae0cfafb1799..3072ac0fef0f29 100644 --- a/torch/nn/functional.py +++ b/torch/nn/functional.py @@ -5745,7 +5745,7 @@ def forward(self, ...): >>> query = torch.rand(32, 8, 128, 64, dtype=torch.float16, device="cuda") >>> key = torch.rand(32, 8, 128, 64, dtype=torch.float16, device="cuda") >>> value = torch.rand(32, 8, 128, 64, dtype=torch.float16, device="cuda") - >>> with torch.backends.cuda.sdp_kernel(enable_math=False): + >>> with sdpa_kernel(backends=[SDPBackend.FLASH_ATTENTION]): >>> F.scaled_dot_product_attention(query,key,value) From 9461c954de44c29fc6240309984b19ffad7edbe2 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Tue, 3 Sep 2024 22:21:27 +0000 Subject: [PATCH 0366/1018] Revert "Add torch.serialization.skip_data context manager (#134504)" This reverts commit 94db935749b8de99d8c3ab23fb880c67c8f3e67a. Reverted https://github.com/pytorch/pytorch/pull/134504 on behalf of https://github.com/kit1980 due to See D62082697 ([comment](https://github.com/pytorch/pytorch/pull/134504#issuecomment-2327542276)) --- docs/source/notes/serialization.rst | 1 - ...cpp_extensions_open_device_registration.py | 12 +- test/test_serialization.py | 88 -------------- torch/_tensor.py | 66 ++--------- torch/_utils.py | 6 +- torch/serialization.py | 110 ++---------------- torch/storage.py | 4 - 7 files changed, 26 insertions(+), 261 deletions(-) diff --git a/docs/source/notes/serialization.rst b/docs/source/notes/serialization.rst index c05dc028a471c9..5541d28bdcafa0 100644 --- a/docs/source/notes/serialization.rst +++ b/docs/source/notes/serialization.rst @@ -398,4 +398,3 @@ The following utility functions are related to serialization: .. autofunction:: clear_safe_globals .. autofunction:: get_safe_globals .. autoclass:: safe_globals -.. autoclass:: skip_data diff --git a/test/test_cpp_extensions_open_device_registration.py b/test/test_cpp_extensions_open_device_registration.py index 4e86ed458b0788..616a6e0f4b5518 100644 --- a/test/test_cpp_extensions_open_device_registration.py +++ b/test/test_cpp_extensions_open_device_registration.py @@ -540,7 +540,7 @@ def test_open_device_tensorlist_type_fallback(self): # call _fused_adamw_ with undefined tensor. self.module.fallback_with_undefined_tensor() - def test_open_device_numpy_serialization(self): + def test_open_device_numpy_serialization_map_location(self): torch.utils.rename_privateuse1_backend("foo") device = self.module.custom_device() default_protocol = torch.serialization.DEFAULT_PROTOCOL @@ -553,7 +553,6 @@ def test_open_device_numpy_serialization(self): self.assertTrue( rebuild_func is torch._utils._rebuild_device_tensor_from_numpy ) - # Test map_location with TemporaryFileName() as f: torch.save(sd, f) with safe_globals( @@ -570,15 +569,6 @@ def test_open_device_numpy_serialization(self): sd_loaded = torch.load(f, map_location="cpu") self.assertTrue(sd_loaded["x"].is_cpu) - # Test metadata_only - with TemporaryFileName() as f: - with self.assertRaisesRegex( - RuntimeError, - "Cannot serialize tensors on backends with no storage under skip_data context manager", - ): - with torch.serialization.skip_data(): - torch.save(sd, f) - if __name__ == "__main__": common.run_tests() diff --git a/test/test_serialization.py b/test/test_serialization.py index 3ba96b80541d88..a041473d195b19 100644 --- a/test/test_serialization.py +++ b/test/test_serialization.py @@ -1,6 +1,5 @@ # Owner(s): ["module: serialization"] -import contextlib import copy import gc import gzip @@ -20,7 +19,6 @@ from pathlib import Path import torch -from torch._subclasses.fake_tensor import FakeTensorMode, FakeTensorConverter from torch._utils import _rebuild_tensor from torch._utils_internal import get_file_path_2 from torch.serialization import ( @@ -29,7 +27,6 @@ LoadEndianness, safe_globals, set_default_load_endianness, - skip_data, SourceChangeWarning, ) from torch.testing._internal.common_device_type import instantiate_device_type_tests @@ -4215,91 +4212,6 @@ def test_filewriter_metadata_writing(self, filename): sd_loaded_ref = torch.load(f) self.assertEqual(sd_loaded, sd_loaded_ref) - @parametrize("materialize_fake", (True, False)) - def test_skip_data_serialization(self, materialize_fake): - # Create one tensor that uses each of the paths in __reduce_ex__ that should work - t_device = "cuda" if torch.cuda.is_available() else "cpu" - t_v2 = torch.randn(2, 3, device=t_device) - t_v3 = torch.randn(2, 3, dtype=torch.complex32, device=t_device) - i = torch.tensor([[0, 1, 1], - [2, 0, 2]]) - v = torch.tensor([3, 4, 5], dtype=torch.float32) - if not materialize_fake: - # FakeTensorConverter messes up sizes of i and v for the sparse tensor - st = torch.sparse_coo_tensor(i, v, (2, 4)) - tt = TwoTensor(torch.randn(2, device=t_device), torch.randn(2, device=t_device)) - - mode, converter = FakeTensorMode(), FakeTensorConverter() - - def fn(t): - return converter.from_real_tensor(mode, t) if materialize_fake else t - - sd = {'t_v2': fn(t_v2), 't_v3': fn(t_v3), 'tt': fn(tt)} - sd_expected = { - 't_v2': torch.zeros(2, 3, device=t_device), - 't_v3': torch.zeros(2, 3, dtype=torch.complex32, device=t_device), - 'tt': TwoTensor(torch.zeros(2, device=t_device), torch.zeros(2, device=t_device)), - } - - if not materialize_fake: - sd['st'] = st - sd_expected['st'] = torch.sparse_coo_tensor(torch.zeros(2, 3), torch.zeros(3), (2, 4)) - - with BytesIOContext() as f: - with skip_data(materialize_fake_tensors=materialize_fake): - torch.save(sd, f) - f.seek(0) - with safe_globals([TwoTensor]): - sd_loaded = torch.load(f, weights_only=True) - self.assertEqual(sd_loaded, sd_expected, exact_device=True) - self.assertFalse(getattr(torch.serialization._serialization_tls, "materialize_fake_tensors", False)) - self.assertFalse(getattr(torch.serialization._serialization_tls, "skip_data", False)) - - # Test that without materialize_fake_tensor, behavior for fake_tensors is not altered by ctx - if not materialize_fake: - ft = converter.from_real_tensor(mode, torch.randn(2, device=t_device)) - with self.assertRaisesRegex(AttributeError, "Can't pickle local object 'WeakValueDictionary.__init__..remove'"): - with skip_data(), BytesIOContext() as f: - torch.save(ft, f) - - @parametrize("materialize_fake", (True, False)) - def test_skip_data_serialization_preserves_views(self, materialize_fake): - ctx = FakeTensorMode if materialize_fake else contextlib.nullcontext - with ctx(): - t = torch.randn(2, 3) - t_view = t.view(-1) - t_slice = t[1] - sd = {'t': t, 't_view': t_view, 't_slice': t_slice} - with BytesIOContext() as f: - with skip_data(materialize_fake_tensors=materialize_fake): - torch.save(sd, f) - f.seek(0) - sd_loaded = torch.load(f, weights_only=True) - self.assertTrue(id(sd_loaded['t_view'].untyped_storage()) == id(sd_loaded['t'].untyped_storage())) - self.assertTrue(id(sd_loaded['t_slice'].untyped_storage()) == id(sd_loaded['t'].untyped_storage())) - - def test_skip_data_serialization_error_cases(self): - def _save_load(t): - with BytesIOContext() as f: - with skip_data(): - torch.save(t, f) - f.seek(0) - torch.load(f, weights_only=True) - - nt = torch.nested.nested_tensor([torch.randn(2), torch.randn(3)]) - t = torch.randn(2, 3, device="meta") - with self.assertRaisesRegex(RuntimeError, "Cannot serialize nested tensor under skip_data context manager"): - _save_load(nt) - - with self.assertWarnsRegex(UserWarning, "meta device under skip_data context manager is a no-op"): - _save_load(t) - - with self.assertRaisesRegex(RuntimeError, "Please call torch.load outside the skip_data context manager"): - with skip_data(), BytesIOContext() as f: - torch.save(torch.randn(2, 3), f) - f.seek(0) - torch.load(f, weights_only=True) - def run(self, *args, **kwargs): with serialization_method(use_zip=True): return super().run(*args, **kwargs) diff --git a/torch/_tensor.py b/torch/_tensor.py index 61d5e3891b9e56..98563aebae9aa7 100644 --- a/torch/_tensor.py +++ b/torch/_tensor.py @@ -209,16 +209,8 @@ def __deepcopy__(self, memo): return new_tensor def __reduce_ex__(self, proto): - materialize_fake_tensors = ( - torch.serialization._serialization_tls.materialize_fake_tensors - ) state = torch._utils._get_obj_state(self) - # Ignore all state when using FakeTensor with skip_data(materialize_fake_tensors) because FakeTensor has - # some state that cannot be pickled - if ( - type(self) is torch._subclasses.fake_tensor.FakeTensor - and materialize_fake_tensors - ) or (type(self) is Tensor and not state): + if type(self) is Tensor and not state: # Fast path for regular tensor without Python state. return self._reduce_ex_internal(proto) if has_torch_function_unary(self): @@ -259,12 +251,6 @@ def _reduce_ex_internal(self, proto): # See Note [Don't serialize hooks] warn_if_has_hooks(self) backward_hooks: Dict[Any, Any] = OrderedDict() - - skip_data = torch.serialization._serialization_tls.skip_data - materialize_fake_tensors = ( - torch.serialization._serialization_tls.materialize_fake_tensors - ) - # Note: Numpy array is chosen to be the rebuild component for XLA, MTIA, MAIA Tensors. # We considered a few options: # 1. CPU tensor can't be used here. @@ -282,10 +268,6 @@ def _reduce_ex_internal(self, proto): # Convert BFloat16 tesors to Float32 before conversion to numpy, as numpy doesn't # support BFloat16. The rebuild tensor from numpy takes in the original self.dtype, # this would reconstruct the BFloat16 tensor from numpy. - if skip_data: - raise RuntimeError( - "Cannot serialize tensors on backends with no storage under skip_data context manager" - ) numpy_tensor = ( self.cpu().numpy() if self.dtype != torch.bfloat16 @@ -298,10 +280,6 @@ def _reduce_ex_internal(self, proto): if self.device.type == "meta": # NB: This implementation BREAKS storage sharing. Current # hypothesis is that no one cares for meta tensors. - if skip_data: - warnings.warn( - "Serializing tensors on the meta device under skip_data context manager is a no-op" - ) arg_meta = ( self.dtype, tuple(self.size()), @@ -310,10 +288,6 @@ def _reduce_ex_internal(self, proto): ) return (torch._utils._rebuild_meta_tensor_no_storage, arg_meta) if self.is_quantized: - if skip_data: - raise RuntimeError( - "Cannot serialize qtensor under skip_data context manager, file an issue if you need this feature" - ) # quantizer_params can be different type based on torch attribute quantizer_params: Union[ Tuple[torch.qscheme, float, int], Tuple[Any, Tensor, Tensor, int] @@ -395,10 +369,6 @@ def _reduce_ex_internal(self, proto): ) return (torch._utils._rebuild_sparse_tensor, args_sparse_compressed) elif self.is_nested: - if skip_data: - raise RuntimeError( - "Cannot serialize nested tensor under skip_data context manager, file an issue if you need this feature" - ) args_nested = ( # NB: values() currently returns the storage as a buffer in an unsafe way. # Ideally, we'd use a private API for this instead. TODO: Switch to this if @@ -413,30 +383,14 @@ def _reduce_ex_internal(self, proto): type(self) is not torch.Tensor and type(self).__torch_dispatch__ is not torch.Tensor.__torch_dispatch__ and ( - isinstance(self, torch._subclasses.functional_tensor.FunctionalTensor) - or ( - not isinstance(self, torch._subclasses.fake_tensor.FakeTensor) - and self.data_ptr() == 0 + isinstance( + self, + ( + torch._subclasses.fake_tensor.FakeTensor, + torch._subclasses.functional_tensor.FunctionalTensor, + ), ) - ) - ): - arg_wrapper_subclass = ( - type(self), - self.dtype, - tuple(self.size()), - self.stride(), - self.storage_offset(), - self.layout, - self.device, - self.requires_grad, - ) - return (torch._utils._rebuild_wrapper_subclass, arg_wrapper_subclass) - elif ( - type(self) is not torch.Tensor - and type(self).__torch_dispatch__ is not torch.Tensor.__torch_dispatch__ - and ( - isinstance(self, torch._subclasses.fake_tensor.FakeTensor) - and not (skip_data and materialize_fake_tensors) + or self.data_ptr() == 0 ) ): arg_wrapper_subclass = ( @@ -464,10 +418,6 @@ def _reduce_ex_internal(self, proto): dtype=self.dtype, _internal=True, ) # type: ignore[assignment] - - if isinstance(self, torch._subclasses.fake_tensor.FakeTensor) and skip_data: - storage._fake_device = self.device - args = ( storage, self.storage_offset(), diff --git a/torch/_utils.py b/torch/_utils.py index f0d38daa811490..938392fa971590 100644 --- a/torch/_utils.py +++ b/torch/_utils.py @@ -3,6 +3,7 @@ import functools import logging import sys +import threading import traceback import warnings from collections import defaultdict @@ -108,13 +109,16 @@ def _get_async_or_non_blocking(function_name, non_blocking, kwargs): return kwargs["async"] +_thread_local_state = threading.local() + + def _get_restore_location(device): """Return the map_location location. Used for rebuild functions where the tensor device is distinct from the storage """ - map_location = torch.serialization._serialization_tls.map_location + map_location = getattr(_thread_local_state, "map_location", None) if map_location is None: return device else: diff --git a/torch/serialization.py b/torch/serialization.py index d936d31d6f5209..2ac36ec371fe57 100644 --- a/torch/serialization.py +++ b/torch/serialization.py @@ -11,7 +11,6 @@ import sys import tarfile import tempfile -import threading import warnings from contextlib import closing, contextmanager from enum import Enum @@ -61,7 +60,6 @@ "get_safe_globals", "add_safe_globals", "safe_globals", - "skip_data", ] @@ -89,22 +87,6 @@ MAP_SHARED, MAP_PRIVATE = None, None # type: ignore[assignment] -# _serialization_tls is used to store thread local state specific to serialization -# that needs to be propagated to other files, in particular we use this for -# (1) map_location (needed for wrapper subclasses/third party devices to torch._utils) -# (2) skip_data (needed for torch.Tensor.__reduce_ex__ for skip_data ctx) -# (3) materialize_fake_tensors (needed for torch.Tensor.__reduce_ex__ for skip_data ctx) -class _SerializationLocal(threading.local): - def __init__(self): - super().__init__() - self.map_location: Optional[MAP_LOCATION] = None - self.skip_data: bool = False - self.materialize_fake_tensors: bool = False - - -_serialization_tls = _SerializationLocal() - - class SourceChangeWarning(Warning): pass @@ -286,47 +268,6 @@ class safe_globals(_weights_only_unpickler._safe_globals): """ -class skip_data: - """ - Context-manager that skips writing storage bytes for ``torch.save`` calls. - - Storages will still be saved, but the space that their bytes would usually be written to - will be empty space. The storage bytes can then be populated in a separate pass. - - .. warning:: - The ``skip_data`` context manager is an early prototype and is subject to change. - - Args: - materialize_fake_tensors: Whether to materialize FakeTensors. - - Example: - >>> # xdoctest: +SKIP("NamedTemporaryFile on Windows") - >>> import tempfile - >>> t = torch.randn(2, 3) - >>> with tempfile.NamedTemporaryFile() as f: - ... with torch.serialization.skip_data(): - ... torch.save(t, f.name) - ... torch.load(f.name, weights_only=True) - tensor([[0., 0., 0.], - [0., 0., 0.]]) - """ - - def __init__(self, materialize_fake_tensors: bool = False): - self.materialize_fake_tensors = materialize_fake_tensors - - def __enter__(self): - global _serialization_tls - self._old_skip_data = _serialization_tls.skip_data - self._old_materialize_fake_tensors = _serialization_tls.materialize_fake_tensors - _serialization_tls.skip_data = True - _serialization_tls.materialize_fake_tensors = self.materialize_fake_tensors - - def __exit__(self, type, value, tb): - global _serialization_tls - _serialization_tls.skip_data = self._old_skip_data - _serialization_tls.materialize_fake_tensors = self._old_materialize_fake_tensors - - def _is_zipfile(f) -> bool: # This is a stricter implementation than zipfile.is_zipfile(). # zipfile.is_zipfile() is True if the magic number appears anywhere in the @@ -856,11 +797,6 @@ def save( ) return else: - global _serialization_tls - if _serialization_tls.skip_data: - raise RuntimeError( - "Cannot use skip_data=True with _use_new_zipfile_serialization=False" - ) with _open_file_like(f, "wb") as opened_file: _legacy_save(obj, opened_file, pickle_module, pickle_protocol) @@ -1019,13 +955,7 @@ def persistent_id(obj: Any) -> Optional[Tuple]: ) -def _save( - obj, - zip_file, - pickle_module, - pickle_protocol, - _disable_byteorder_record, -): +def _save(obj, zip_file, pickle_module, pickle_protocol, _disable_byteorder_record): serialized_storages = {} id_map: Dict[int, str] = {} @@ -1060,7 +990,7 @@ def persistent_id(obj): # If storage is allocated, ensure that any other saved storages # pointing to the same data all have the same dtype. If storage is # not allocated, don't perform this check - if str(storage.device) != "meta" and storage.data_ptr() != 0: + if storage.data_ptr() != 0: if storage.data_ptr() in storage_dtypes: if storage_dtype != storage_dtypes[storage.data_ptr()]: raise RuntimeError( @@ -1071,10 +1001,7 @@ def persistent_id(obj): storage_dtypes[storage.data_ptr()] = storage_dtype storage_key = id_map.setdefault(storage._cdata, str(len(id_map))) - if hasattr(obj, "_fake_device") and obj._fake_device is not None: - location = str(obj._fake_device) - else: - location = location_tag(storage) + location = location_tag(storage) serialized_storages[storage_key] = storage return ("storage", storage_type, storage_key, location, storage_numel) @@ -1100,18 +1027,14 @@ def persistent_id(obj): for key in sorted(serialized_storages.keys()): name = f"data/{key}" storage = serialized_storages[key] + # given that we copy things around anyway, we might use storage.cpu() + # this means to that to get tensors serialized, you need to implement + # .cpu() on the underlying Storage + if storage.device.type != "cpu": + storage = storage.cpu() + # Now that it is on the CPU we can directly copy it into the zip file num_bytes = storage.nbytes() - global _serialization_tls - if _serialization_tls.skip_data: - zip_file.write_record_metadata(name, num_bytes) - else: - # given that we copy things around anyway, we might use storage.cpu() - # this means to that to get tensors serialized, you need to implement - # .cpu() on the underlying Storage - if storage.device.type != "cpu": - storage = storage.cpu() - # Now that it is on the CPU we can directly copy it into the zip file - zip_file.write_record(name, storage, num_bytes) + zip_file.write_record(name, storage, num_bytes) def load( @@ -1261,14 +1184,6 @@ def _get_wo_message(message: str) -> str: updated_message += message return updated_message + DOCS_MESSAGE - global _serialization_tls - skip_data = _serialization_tls.skip_data - if skip_data: - raise RuntimeError( - "`torch.load` called within a torch.serialization.skip_data context manager " - "is not supported yet. Please call torch.load outside the skip_data context manager." - ) - if weights_only is None: weights_only, warn_weights_only = False, True else: @@ -1843,10 +1758,9 @@ def find_class(self, mod_name, name): unpickler.persistent_load = persistent_load # Needed for tensors where storage device and rebuild tensor device are # not connected (wrapper subclasses and tensors rebuilt using numpy) - global _serialization_tls - _serialization_tls.map_location = map_location + torch._utils._thread_local_state.map_location = map_location result = unpickler.load() - _serialization_tls.map_location = None + del torch._utils._thread_local_state.map_location torch._utils._validate_loaded_sparse_tensors() torch._C._log_api_usage_metadata( diff --git a/torch/storage.py b/torch/storage.py index 8848649905f936..b6ba608c16e5ca 100644 --- a/torch/storage.py +++ b/torch/storage.py @@ -39,8 +39,6 @@ class _StorageBase: is_sparse: _bool = False is_sparse_csr: _bool = False device: torch.device - # Used when stashing FakeTensor device onto storage in torch.save(metadata_only=True) - _fake_device: _Optional[torch.device] = None def __init__(self, *args, **kwargs): pass @@ -651,8 +649,6 @@ def _get_device_from_module(module: str): class TypedStorage: is_sparse: _bool = False - # Used when stashing FakeTensor device onto storage in torch.save(metadata_only=True) - _fake_device: _Optional[torch.device] = None dtype: torch.dtype From 7cfe2ae9ccf5b8a68dc671aa545ab59ebe6389c3 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Tue, 3 Sep 2024 22:35:14 +0000 Subject: [PATCH 0367/1018] Revert "c10d/logging: add C10D_LOCK_GUARD (#134131)" This reverts commit f33bcbe5fd67e6b18be259ad2f0dc11c74157075. Reverted https://github.com/pytorch/pytorch/pull/134131 on behalf of https://github.com/kit1980 due to See D61985186 ([comment](https://github.com/pytorch/pytorch/pull/134131#issuecomment-2327556381)) --- build_variables.bzl | 1 - caffe2/CMakeLists.txt | 2 - test/cpp/c10d/CMakeLists.txt | 2 - test/cpp/c10d/LoggingTest.cpp | 54 ------------------- test/cpp/c10d/ProcessGroupNCCLErrorsTest.cpp | 4 +- torch/csrc/distributed/c10d/LockGuard.cpp | 29 ---------- torch/csrc/distributed/c10d/LockGuard.hpp | 32 ----------- torch/csrc/distributed/c10d/NCCLUtils.cpp | 10 ++-- torch/csrc/distributed/c10d/NCCLUtils.hpp | 21 ++++---- .../distributed/c10d/ProcessGroupGloo.cpp | 2 +- .../distributed/c10d/ProcessGroupNCCL.cpp | 50 +++++++++-------- .../distributed/c10d/ProcessGroupNCCL.hpp | 18 +++---- torch/csrc/distributed/c10d/Work.cpp | 13 +++-- torch/csrc/distributed/c10d/Work.hpp | 4 +- 14 files changed, 59 insertions(+), 183 deletions(-) delete mode 100644 test/cpp/c10d/LoggingTest.cpp delete mode 100644 torch/csrc/distributed/c10d/LockGuard.cpp delete mode 100644 torch/csrc/distributed/c10d/LockGuard.hpp diff --git a/build_variables.bzl b/build_variables.bzl index 57c9caefbdbe1c..ff1e6083218c78 100644 --- a/build_variables.bzl +++ b/build_variables.bzl @@ -496,7 +496,6 @@ libtorch_distributed_base_sources = [ "torch/csrc/distributed/c10d/Functional.cpp", "torch/csrc/distributed/c10d/GlooDeviceFactory.cpp", "torch/csrc/distributed/c10d/GroupRegistry.cpp", - "torch/csrc/distributed/c10d/LockGuard.cpp", "torch/csrc/distributed/c10d/Ops.cpp", "torch/csrc/distributed/c10d/ParamCommsUtils.cpp", "torch/csrc/distributed/c10d/PrefixStore.cpp", diff --git a/caffe2/CMakeLists.txt b/caffe2/CMakeLists.txt index b3e9e3ca721027..8552dab106e4cf 100644 --- a/caffe2/CMakeLists.txt +++ b/caffe2/CMakeLists.txt @@ -927,7 +927,6 @@ elseif(USE_CUDA) set(CUDA_LINK_LIBRARIES_KEYWORD) torch_compile_options(torch_cuda) # see cmake/public/utils.cmake target_compile_definitions(torch_cuda PRIVATE USE_CUDA) - target_link_libraries(torch_cuda PRIVATE fmt::fmt-header-only) if(USE_CUFILE) target_link_libraries(torch_cuda PRIVATE torch::cufile) @@ -1334,7 +1333,6 @@ if(USE_ROCM) ${ROCM_SOURCE_DIR}/rocblas/include ${ROCM_SOURCE_DIR}/hipsparse/include ) - target_link_libraries(torch_hip PRIVATE fmt::fmt-header-only) if(USE_FLASH_ATTENTION) target_compile_definitions(torch_hip PRIVATE USE_FLASH_ATTENTION) endif() diff --git a/test/cpp/c10d/CMakeLists.txt b/test/cpp/c10d/CMakeLists.txt index c64adb4ad0521e..0874852517e33b 100644 --- a/test/cpp/c10d/CMakeLists.txt +++ b/test/cpp/c10d/CMakeLists.txt @@ -10,14 +10,12 @@ function(c10d_add_test test_src) add_executable(${test_name} "${test_src}") target_include_directories(${test_name} PRIVATE $) target_link_libraries(${test_name} ${ARGN}) - target_link_libraries(${test_name} fmt::fmt-header-only) if(NOT WIN32) target_link_libraries(${test_name} pthread) endif() add_test(NAME ${test_name} COMMAND $) endfunction() -c10d_add_test(LoggingTest.cpp torch_cpu gtest_main) c10d_add_test(BackoffTest.cpp torch_cpu gtest_main) c10d_add_test(FileStoreTest.cpp torch_cpu gtest_main) c10d_add_test(TCPStoreTest.cpp torch_cpu gtest_main) diff --git a/test/cpp/c10d/LoggingTest.cpp b/test/cpp/c10d/LoggingTest.cpp deleted file mode 100644 index 1b8664fb4c780c..00000000000000 --- a/test/cpp/c10d/LoggingTest.cpp +++ /dev/null @@ -1,54 +0,0 @@ -#include - -#include -#include - -#include -#include - -TEST(LockGuard, basic) { - std::timed_mutex mutex; - - { - C10D_LOCK_GUARD(lock, mutex); - - // already locked - ASSERT_FALSE(mutex.try_lock()); - } - - ASSERT_TRUE(mutex.try_lock()); - mutex.unlock(); -} - -TEST(LockGuard, logging) { - // set log level to INFO - FLAGS_caffe2_log_level = 0; - - std::timed_mutex mutex; - - mutex.lock(); - - auto loggingThread = std::async(std::launch::async, [&]() { - std::unique_lock name{mutex, std::defer_lock}; - ::c10d::detail::lockWithLogging( - name, std::chrono::milliseconds(10), "my lock", __FILE__, __LINE__); - }); - - auto deadline = std::chrono::system_clock::now() + std::chrono::seconds(10); - while (true) { - ASSERT_LT(std::chrono::system_clock::now(), deadline); - - testing::internal::CaptureStderr(); - std::this_thread::sleep_for(std::chrono::milliseconds(20)); - std::string output = testing::internal::GetCapturedStderr(); - - if (output.find("my lock: waiting for lock for 10ms") != - std::string::npos) { - break; - } - } - - mutex.unlock(); - - loggingThread.get(); -} diff --git a/test/cpp/c10d/ProcessGroupNCCLErrorsTest.cpp b/test/cpp/c10d/ProcessGroupNCCLErrorsTest.cpp index 54a929ab982514..d416847f7911a5 100644 --- a/test/cpp/c10d/ProcessGroupNCCLErrorsTest.cpp +++ b/test/cpp/c10d/ProcessGroupNCCLErrorsTest.cpp @@ -180,7 +180,7 @@ class ProcessGroupNCCLNoHeartbeatCaught : ProcessGroupNCCLTimedOutErrors(store, rank, size, opts), hasMonitorThreadCaughtError_(false) {} - std::timed_mutex& getWatchdogMutex() { + std::mutex& getWatchdogMutex() { return workMetaListMutex_; } @@ -413,7 +413,7 @@ TEST_F(ProcessGroupNCCLErrorsTest, testNCCLErrorsNoHeartbeat) { work = pg.allreduce(tensors_); { // Now run all reduce with errors. - std::lock_guard lock(pg.getWatchdogMutex()); + std::lock_guard lock(pg.getWatchdogMutex()); LOG(INFO) << "Lock watchdog thread."; // Wait long enough before monitor thread throws exceptions. std::this_thread::sleep_for( diff --git a/torch/csrc/distributed/c10d/LockGuard.cpp b/torch/csrc/distributed/c10d/LockGuard.cpp deleted file mode 100644 index 6ce3ea736f3b3f..00000000000000 --- a/torch/csrc/distributed/c10d/LockGuard.cpp +++ /dev/null @@ -1,29 +0,0 @@ -// Copyright (c) Meta Platforms, Inc. and its affiliates. -// All rights reserved. -// -// This source code is licensed under the BSD-style license found in the -// LICENSE file in the root directory of this source tree. - -#include - -#include - -namespace c10d::detail { - -void lockWithLogging( - std::unique_lock& lock, - std::chrono::milliseconds log_interval, - const char* desc, - const char* file, - int line) { - while (!lock.try_lock_for(log_interval)) { - C10D_WARNING( - "{}:{} {}: waiting for lock for {}ms", - file, - line, - desc, - log_interval.count()); - } -} - -} // namespace c10d::detail diff --git a/torch/csrc/distributed/c10d/LockGuard.hpp b/torch/csrc/distributed/c10d/LockGuard.hpp deleted file mode 100644 index d2a4dd9b3743d8..00000000000000 --- a/torch/csrc/distributed/c10d/LockGuard.hpp +++ /dev/null @@ -1,32 +0,0 @@ -// Copyright (c) Meta Platforms, Inc. and its affiliates. -// All rights reserved. -// -// This source code is licensed under the BSD-style license found in the -// LICENSE file in the root directory of this source tree. - -#pragma once - -#include -#include - -#include - -namespace c10d::detail { - -// logWithLogging is a wrapper around std::unique_lock -// that automatically logs if the lock cannot be acquired within a given -// timeout. -TORCH_API void lockWithLogging( - std::unique_lock& lock, - std::chrono::milliseconds log_interval, - const char* desc, - const char* file, - int line); - -} // namespace c10d::detail - -// TODO: use std::source_location() when we can use C++20 -#define C10D_LOCK_GUARD(name, mutex) \ - std::unique_lock name{mutex, std::defer_lock}; \ - ::c10d::detail::lockWithLogging( \ - name, std::chrono::seconds(30), #mutex, __FILE__, __LINE__) diff --git a/torch/csrc/distributed/c10d/NCCLUtils.cpp b/torch/csrc/distributed/c10d/NCCLUtils.cpp index 648c87881c6110..b11728e0ba8be7 100644 --- a/torch/csrc/distributed/c10d/NCCLUtils.cpp +++ b/torch/csrc/distributed/c10d/NCCLUtils.cpp @@ -21,7 +21,7 @@ constexpr int64_t kCommInitBusyWaitMillis = 10; namespace c10d { ncclComm_t NCCLComm::getNcclComm() { - C10D_LOCK_GUARD(lock, mutex_); + std::unique_lock lock(mutex_); if (aborted_) { auto commFailureMsg = commFailureReason_ != std::nullopt ? c10::str(" Original reason for failure was: ", *commFailureReason_) @@ -391,7 +391,7 @@ std::optional NCCLTraceBuffer::record( } auto traceback = torch::CapturedTraceback::gather(true, true, capture_cpp_stack_); - C10D_LOCK_GUARD(guard, mutex_); + std::lock_guard guard(mutex_); auto te = Entry{ id_, @@ -448,7 +448,7 @@ void NCCLTraceBuffer::record_pg_ranks( if (!enabled_) { return; } - C10D_LOCK_GUARD(guard, mutex_); + std::lock_guard guard(mutex_); pg_name_to_ranks_[pg_name] = ranks; } @@ -468,7 +468,7 @@ void NCCLTraceBuffer::update_state(Entry& r) { } std::vector NCCLTraceBuffer::dump_entries() { - C10D_LOCK_GUARD(guard, mutex_); + std::lock_guard guard(mutex_); std::vector result; result.reserve(entries_.size()); result.insert(result.end(), entries_.begin() + next_, entries_.end()); @@ -493,7 +493,7 @@ void NCCLTraceBuffer::retire_id( Event* endEvent = nullptr; std::optional duration = std::nullopt; - C10D_LOCK_GUARD(guard, mutex_); + std::unique_lock guard(mutex_); Entry* entry = &entries_.at(*id % max_entries_); if (entry->id_ == *id) { diff --git a/torch/csrc/distributed/c10d/NCCLUtils.hpp b/torch/csrc/distributed/c10d/NCCLUtils.hpp index b34ff33333daf0..4b0a197dbd536a 100644 --- a/torch/csrc/distributed/c10d/NCCLUtils.hpp +++ b/torch/csrc/distributed/c10d/NCCLUtils.hpp @@ -13,7 +13,6 @@ #include #include #include -#include #include #include @@ -271,7 +270,7 @@ class NCCLComm { ~NCCLComm() noexcept { // Add lock in this destructor, as aborted_ needs to be read after memory // barrier here. - C10D_LOCK_GUARD(lock, mutex_); + std::unique_lock lock(mutex_); if (ncclComm_ && initialized_ && !aborted_) { #ifdef ENABLE_NCCL_ERROR_CHECKING // Use ncclCommAbort instead of ncclCommDestroy here since @@ -363,7 +362,7 @@ class NCCLComm { NCCLComm(NCCLComm&& other) { // Using other's lock, as it reads other's states // Can not use this.mutex_, as this object is being constructed. - C10D_LOCK_GUARD(lock, other.mutex_); + std::unique_lock lock(other.mutex_); std::swap(ncclComm_, other.ncclComm_); std::swap(aborted_, other.aborted_); std::swap(ncclAsyncErr_, other.ncclAsyncErr_); @@ -373,13 +372,13 @@ class NCCLComm { ncclComm_t getNcclComm(); std::optional getNcclCommFailureReason() const { - C10D_LOCK_GUARD(lock, mutex_); + std::unique_lock lock(mutex_); return commFailureReason_; } void ncclCommAbort( std::optional commFailureReason = std::nullopt) { - C10D_LOCK_GUARD(lock, mutex_); + std::unique_lock lock(mutex_); #ifdef ENABLE_NCCL_ERROR_CHECKING if (aborted_ && !initialized_) { // Should not abort twice. @@ -427,7 +426,7 @@ class NCCLComm { } bool isAborted() const { - C10D_LOCK_GUARD(lock, mutex_); + std::unique_lock lock(mutex_); return aborted_; } @@ -436,7 +435,7 @@ class NCCLComm { } ncclResult_t checkForNcclError() { - C10D_LOCK_GUARD(lock, mutex_); + std::unique_lock lock(mutex_); #ifdef ENABLE_NCCL_ERROR_CHECKING if (ncclAsyncErr_ != ncclSuccess) { return ncclAsyncErr_; @@ -451,7 +450,7 @@ class NCCLComm { } ncclResult_t registerSegment(void* ptr, size_t size) { - C10D_LOCK_GUARD(lock, mutex_); + std::unique_lock lock(mutex_); #ifdef NCCL_HAS_COMM_REGISTER // We register only segments from cache allocator // which are guaranteed to be with disjoint addr ranges. Thus, a ptr always @@ -482,7 +481,7 @@ class NCCLComm { } ncclResult_t deregisterSegment(void* ptr) { - C10D_LOCK_GUARD(lock, mutex_); + std::unique_lock lock(mutex_); #ifdef NCCL_HAS_COMM_REGISTER TORCH_CHECK( registeredSegmentHandles_.count(ptr) == 1, @@ -519,7 +518,7 @@ class NCCLComm { bool aborted_; uint64_t ncclCommSplitCounter_{0}; ncclResult_t ncclAsyncErr_; - mutable std::timed_mutex mutex_; + mutable std::mutex mutex_; // Rank that this communicator corresponds to. int rank_; // Optional reason for communicator failure, provided by ProcessGroupNCCL for @@ -638,7 +637,7 @@ struct NCCLTraceBuffer { bool enabled_ = false; bool capture_cpp_stack_ = false; - std::timed_mutex mutex_; + std::mutex mutex_; std::vector entries_; size_t max_entries_ = 0; size_t next_ = 0; diff --git a/torch/csrc/distributed/c10d/ProcessGroupGloo.cpp b/torch/csrc/distributed/c10d/ProcessGroupGloo.cpp index 96ea93f7413b6c..51fa248ec403b9 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupGloo.cpp +++ b/torch/csrc/distributed/c10d/ProcessGroupGloo.cpp @@ -602,7 +602,7 @@ uint64_t ProcessGroupGloo::RecvWork::getSequencenumber() const { } int ProcessGroupGloo::RecvWork::sourceRank() const { - std::lock_guard lock(mutex_); + std::lock_guard lock(mutex_); return srcRank_; } diff --git a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp index 3fe8aab1764aa6..ebf92808e801e4 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp +++ b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp @@ -23,7 +23,6 @@ #include #include #include -#include #include #include #include @@ -303,7 +302,7 @@ inline void errorIfCapturingNonCapturableNCCL(c10::cuda::CaptureStatus status) { // hooks are called outside the scope of any PG, thus we need traverse // communicators in all PGs. static std::unordered_map, int> ncclCommDevIdxMap; -static std::timed_mutex ncclCommDevIdxMapMutex; +static std::mutex ncclCommDevIdxMapMutex; static bool allocatorHooksAttached = false; std::atomic ProcessGroupNCCL::shouldDump_(false); @@ -316,7 +315,7 @@ void cacheAllocatorRegisterHook( return; } - C10D_LOCK_GUARD(lock, ncclCommDevIdxMapMutex); + std::lock_guard lock(ncclCommDevIdxMapMutex); for (auto& it : ncclCommDevIdxMap) { auto& ncclComm = it.first; auto& devIdx = it.second; @@ -334,7 +333,7 @@ void cacheAllocatorDeregisterHook( return; } - C10D_LOCK_GUARD(lock, ncclCommDevIdxMapMutex); + std::lock_guard lock(ncclCommDevIdxMapMutex); for (auto& it : ncclCommDevIdxMap) { auto& ncclComm = it.first; auto& devIdx = it.second; @@ -537,7 +536,7 @@ void ProcessGroupNCCL::WorkNCCL::checkAndSetException() { } auto exception_ptr = checkForNCCLErrors(); - C10D_LOCK_GUARD(lock, mutex_); + std::unique_lock lock(mutex_); exception_ = exception_ptr; if (exception_) { LOG(ERROR) << logPrefix() << "Collective " << *this @@ -553,7 +552,7 @@ const std::string& ProcessGroupNCCL::WorkNCCL::logPrefix() const { void ProcessGroupNCCL::WorkNCCL::setException( std::exception_ptr exception_ptr) { - C10D_LOCK_GUARD(lock, mutex_); + std::unique_lock lock(mutex_); exception_ = exception_ptr; } @@ -762,12 +761,12 @@ ProcessGroupNCCL::CUDAEventCache::CUDAEventCache() {} std::shared_ptr ProcessGroupNCCL::CUDAEventCache::create( bool timing) { auto deleter = [this, timing](at::cuda::CUDAEvent* event) { - C10D_LOCK_GUARD(lock, this->cacheMutex_); + std::lock_guard lock(this->cacheMutex_); this->eventsArray_[timing ? 1 : 0].push_back(event); }; at::cuda::CUDAEvent* event = nullptr; { - C10D_LOCK_GUARD(lock, cacheMutex_); + std::lock_guard lock(cacheMutex_); auto events = eventsArray_[timing ? 1 : 0]; if (!events.empty()) { event = events.back(); @@ -1072,9 +1071,8 @@ void ProcessGroupNCCL::waitForPendingWorks() { while (true) { { std::lock(workMetaListMutex_, completedWorkListMutex_); - std::lock_guard lockWork( - workMetaListMutex_, std::adopt_lock); - std::lock_guard lockHook( + std::lock_guard lockWork(workMetaListMutex_, std::adopt_lock); + std::lock_guard lockHook( completedWorkListMutex_, std::adopt_lock); if (workMetaList_.empty() && completedWorkList_.empty()) { @@ -1187,7 +1185,7 @@ bool ProcessGroupNCCL::abort(std::optional abortReason) { } ncclCommDevIdxMapMutex.unlock(); - C10D_LOCK_GUARD(lock, mutex_); + std::lock_guard lock(mutex_); abortCommsFromMap(devNCCLCommMap_, abortReason); abortCommsFromMap(inInitializationCommMap_, abortReason); return true; @@ -1256,8 +1254,8 @@ bool ProcessGroupNCCL::dumpDebuggingInfo() { // Serialize all calls to this function to avoid corrupting data, but allow // multiple calls in one runtime. User is responsible for preserving the // output file from an earlier call before a later call overwrites it. - static std::timed_mutex writeDebugInfoMutex; - C10D_LOCK_GUARD(lock, writeDebugInfoMutex); + static std::mutex writeDebugInfoMutex; + std::lock_guard lock(writeDebugInfoMutex); LOG(ERROR) << logPrefix() << "ProcessGroupNCCL preparing to dump debug info."; if (ncclTraceBufferSize_ > 0) { // We dump nccl trace into local disk by default and users can register @@ -1336,7 +1334,7 @@ void ProcessGroupNCCL::heartbeatMonitor() { // This won't have any lock since this lock is only used here. // Please be aware that mutex `monitorMutex_` should not be used // somewhere else to avoid the deadlock. - C10D_LOCK_GUARD(lock, monitorMutex_); + std::unique_lock lock(monitorMutex_); if (monitorWakeUpCV_.wait_for( lock, std::chrono::milliseconds(monitorPollInterval), [&] { return terminateHeartbeatMonitorThread_.load(); @@ -1661,7 +1659,7 @@ const std::vector& ProcessGroupNCCL::groupRanks() const { void ProcessGroupNCCL::addEphemeralTimeout( const std::chrono::milliseconds& timeout) { - C10D_LOCK_GUARD(timeoutLock, mtxTimeoutExtension_); + std::lock_guard timeoutLock(mtxTimeoutExtension_); ephemeralTimeoutActive_ += timeout; } @@ -1684,7 +1682,7 @@ void ProcessGroupNCCL::watchdogHandler() { std::list completedWorkList; while (!done || !terminateProcessGroup_.load()) { - C10D_LOCK_GUARD(lock, workMetaListMutex_); + std::unique_lock lock(workMetaListMutex_); // We busy-poll the work vector every kWatchdogThreadSleepMillis // milliseconds as long as the atomic is True. workMetaListCV_.wait_for( @@ -1856,7 +1854,7 @@ void ProcessGroupNCCL::watchdogHandler() { if (work.isCompleted()) { { // Reset the timeout and first work if the work is completed. - C10D_LOCK_GUARD(timeoutLock, mtxTimeoutExtension_); + std::lock_guard timeoutLock(mtxTimeoutExtension_); if (work.ownedEphermeralTimeout_.count() > 0) { ephemeralTimeoutActive_ -= work.ownedEphermeralTimeout_; ephemeralTimeoutInflight_ -= work.ownedEphermeralTimeout_; @@ -1871,7 +1869,7 @@ void ProcessGroupNCCL::watchdogHandler() { // Move Work object to completedWorkList_ to be consumed by the hook // thread { - C10D_LOCK_GUARD(lock, completedWorkListMutex_); + const std::lock_guard lock(completedWorkListMutex_); completedWorkList_.splice( completedWorkList_.end(), workMetaList_, it++); } @@ -1899,7 +1897,7 @@ void ProcessGroupNCCL::runHookLoop() { bool done = false; while (!done || !terminateProcessGroup_.load()) { - C10D_LOCK_GUARD(lock, completedWorkListMutex_); + std::unique_lock lock(completedWorkListMutex_); // We busy-poll the work vector every kWatchdogThreadSleepMillis // milliseconds as long as the atomic is True. completedWorkListCV_.wait_for( @@ -2072,7 +2070,7 @@ void ProcessGroupNCCL::broadcastUniqueNCCLID( } void ProcessGroupNCCL::destroyNCCLComms(const std::string& devNCCLCommMapKey) { - C10D_LOCK_GUARD(lock, mutex_); + std::lock_guard lock(mutex_); if (devNCCLCommMap_.find(devNCCLCommMapKey) == devNCCLCommMap_.end()) { TORCH_INTERNAL_ASSERT( false, @@ -2120,7 +2118,7 @@ std::shared_ptr ProcessGroupNCCL::getNCCLComm( usedDeviceIdxs_.insert(device.index()); { - C10D_LOCK_GUARD(lock, mutex_); + std::lock_guard lock(mutex_); if (devNCCLCommMap_.find(deviceKey) != devNCCLCommMap_.end()) { // Reuse the cached communicator if there is one. return devNCCLCommMap_[deviceKey]; @@ -2194,7 +2192,7 @@ std::shared_ptr ProcessGroupNCCL::getNCCLComm( options_->split_color != 0, "Must specify a non-zero color when splitting"); // Find a valid, healthy communicator to split from if possible. - C10D_LOCK_GUARD(lock, options_->split_from->mutex_); + std::lock_guard lock(options_->split_from->mutex_); auto& other_comms = options_->split_from->devNCCLCommMap_; auto dit = other_comms.find(getKeyFromDevice(device)); if (dit != other_comms.end()) { @@ -2248,7 +2246,7 @@ std::shared_ptr ProcessGroupNCCL::getNCCLComm( options_->is_high_priority_stream || force_high); { - C10D_LOCK_GUARD(lock, mutex_); + std::lock_guard lock(mutex_); inInitializationCommMap_.emplace(deviceKey, ncclComm); } @@ -2498,7 +2496,7 @@ void ProcessGroupNCCL::assignTimeoutToWork( const c10::intrusive_ptr& work, const c10::intrusive_ptr& option) { std::chrono::milliseconds timeout = option->timeout; - C10D_LOCK_GUARD(timeoutLock, mtxTimeoutExtension_); + std::lock_guard timeoutLock(mtxTimeoutExtension_); if (ephemeralTimeoutActive_.count() > 0) { timeout += ephemeralTimeoutActive_; } @@ -2511,7 +2509,7 @@ void ProcessGroupNCCL::assignTimeoutToWork( void ProcessGroupNCCL::workEnqueue( c10::intrusive_ptr work) { if (!terminateProcessGroup_.load()) { - C10D_LOCK_GUARD(lock, workMetaListMutex_); + std::lock_guard lock(workMetaListMutex_); // Avoid view tensors to be processed in cleanup thread. // View tensors' destruction invokes autograd_meta, which // needs to be destructed in user thread. Otherwise will diff --git a/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp b/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp index b6635157d75e1b..e099a75de58879 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp +++ b/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp @@ -449,7 +449,7 @@ class TORCH_API ProcessGroupNCCL : public Backend { static CUDAEventCache& get(); private: - std::timed_mutex cacheMutex_; + std::mutex cacheMutex_; // NOTE: We intentionaly store raw pointers so that // we do not attempt to destroy the event objects on process exit, // because cuda may be gone. @@ -920,7 +920,7 @@ class TORCH_API ProcessGroupNCCL : public Backend { // ephemeralTimeoutActive_/ephemeralTimeoutInflight_. // TODO(fduwjj): We need to have an audit on all mutexes we are adding here. // And consolidate them if possible. - std::timed_mutex mtxTimeoutExtension_; + std::mutex mtxTimeoutExtension_; // The ephemeral timeout added on top of existing timeout for works issued // before first work finishes. @@ -980,7 +980,7 @@ class TORCH_API ProcessGroupNCCL : public Backend { inInitializationCommMap_; // Mutex to guard maps like devNCCLCommMap_. - std::timed_mutex mutex_; + std::mutex mutex_; // Heartbeat of watchdog thread. std::atomic_uint64_t heartbeat_; @@ -1041,18 +1041,18 @@ class TORCH_API ProcessGroupNCCL : public Backend { static std::atomic shouldDump_; // Mutex to Guard workMetaList_ - std::timed_mutex workMetaListMutex_; + std::mutex workMetaListMutex_; // Mutex to Guard monitorWakeUpCV_ - std::timed_mutex monitorMutex_; + std::mutex monitorMutex_; bool writeDebugInfo_ = false; // Condition Variable for watchdog thread sleep - std::condition_variable_any workMetaListCV_; + std::condition_variable workMetaListCV_; // Condition Variable for monitor thread to wake up early - std::condition_variable_any monitorWakeUpCV_; + std::condition_variable monitorWakeUpCV_; // Vector to Store WorkNCCL pointers std::list workMetaList_; @@ -1060,10 +1060,10 @@ class TORCH_API ProcessGroupNCCL : public Backend { std::chrono::time_point lastWorkListUpdateTime_; // Mutex to Guard workMetaList_ - std::timed_mutex completedWorkListMutex_; + std::mutex completedWorkListMutex_; // Condition Variable for watchdog thread sleep - std::condition_variable_any completedWorkListCV_; + std::condition_variable completedWorkListCV_; std::list completedWorkList_; diff --git a/torch/csrc/distributed/c10d/Work.cpp b/torch/csrc/distributed/c10d/Work.cpp index 11453f0454357a..8beb8f29362080 100644 --- a/torch/csrc/distributed/c10d/Work.cpp +++ b/torch/csrc/distributed/c10d/Work.cpp @@ -1,6 +1,5 @@ #include -#include #include #include @@ -46,17 +45,17 @@ OpType Work::retrieveOpType() const { Work::~Work() = default; bool Work::isCompleted() { - C10D_LOCK_GUARD(lock, mutex_); + std::lock_guard lock(mutex_); return completed_; } bool Work::isSuccess() const { - C10D_LOCK_GUARD(lock, mutex_); + std::lock_guard lock(mutex_); return !exception_; } std::exception_ptr Work::exception() const { - C10D_LOCK_GUARD(lock, mutex_); + std::lock_guard lock(mutex_); return exception_; } @@ -74,7 +73,7 @@ std::vector Work::result() { void Work::synchronize() {} bool Work::wait(std::chrono::milliseconds timeout) { - C10D_LOCK_GUARD(lock, mutex_); + std::unique_lock lock(mutex_); if (timeout == kNoTimeout) { // This waits without a timeout. cv_.wait(lock, [&] { return completed_; }); @@ -104,7 +103,7 @@ c10::intrusive_ptr Work::getFuture() { } void Work::finish(std::exception_ptr exception) { - C10D_LOCK_GUARD(lock, mutex_); + std::unique_lock lock(mutex_); completed_ = true; exception_ = std::move(exception); if (recordFunctionEndCallback_) { @@ -116,7 +115,7 @@ void Work::finish(std::exception_ptr exception) { } void Work::finishAndThrow(std::exception_ptr exception) { - C10D_LOCK_GUARD(lock, mutex_); + std::unique_lock lock(mutex_); completed_ = true; exception_ = std::move(exception); if (recordFunctionEndCallback_) { diff --git a/torch/csrc/distributed/c10d/Work.hpp b/torch/csrc/distributed/c10d/Work.hpp index ea5853aefb4049..c10e5007b9f544 100644 --- a/torch/csrc/distributed/c10d/Work.hpp +++ b/torch/csrc/distributed/c10d/Work.hpp @@ -126,8 +126,8 @@ class TORCH_API Work : public torch::CustomClassHolder { // provided by the user. void finishAndThrow(std::exception_ptr exception); - mutable std::timed_mutex mutex_; - std::condition_variable_any cv_; + mutable std::mutex mutex_; + std::condition_variable cv_; bool completed_ = false; std::exception_ptr exception_; From e107152920021688b10368e6bbbb82f2b7c24e31 Mon Sep 17 00:00:00 2001 From: Zain Rizvi Date: Fri, 30 Aug 2024 16:45:54 -0500 Subject: [PATCH 0368/1018] Refactor runner determinator (#134796) Some minor refactorings to make the code easier to parse and easier to add unit tests for. Keeping this as a separate PR for ease of review, since it should have zero functional behavior changes Pull Request resolved: https://github.com/pytorch/pytorch/pull/134796 Approved by: https://github.com/zxiiro, https://github.com/PaliC --- .github/scripts/runner_determinator.py | 65 +++++++++++++++------ .github/workflows/_runner-determinator.yml | 66 ++++++++++++++++------ 2 files changed, 97 insertions(+), 34 deletions(-) diff --git a/.github/scripts/runner_determinator.py b/.github/scripts/runner_determinator.py index 7aa71e7c688595..e52b19c64b4e3c 100644 --- a/.github/scripts/runner_determinator.py +++ b/.github/scripts/runner_determinator.py @@ -137,11 +137,14 @@ def get_issue(gh: Github, repo: str, issue_num: int) -> Issue: def get_potential_pr_author( - gh: Github, repo: str, username: str, ref_type: str, ref_name: str + github_token: str, repo: str, username: str, ref_type: str, ref_name: str ) -> str: # If the trigger was a new tag added by a bot, this is a ciflow case # Fetch the actual username from the original PR. The PR number is # embedded in the tag name: ciflow// + + gh = get_gh_client(github_token) + if username == "pytorch-bot[bot]" and ref_type == "tag": split_tag = ref_name.split("/") if ( @@ -163,23 +166,32 @@ def get_potential_pr_author( def is_exception_branch(branch: str) -> bool: + """ + Branches that get opted out of all experiments and should always use Meta runners + """ return branch.split("/")[0] in {"main", "nightly", "release", "landchecks"} -def get_workflow_type(issue: Issue, workflow_requestors: Iterable[str]) -> str: - try: - first_comment = issue.get_comments()[0].body.strip("\n\t ") +def get_fleet(rollout_state: str, workflow_requestors: Iterable[str]) -> str: + """ + Determines if the job should run on the LF fleet or the Meta fleet + + Returns: + The appropriate label prefix for the runner, corresponding to the fleet to use. + This gets prefixed to the very start of the runner label. + """ - if first_comment[0] == "!": + try: + if rollout_state[0] == "!": log.info("LF Workflows are disabled for everyone. Using meta runners.") return WORKFLOW_LABEL_META - elif first_comment[0] == "*": + elif rollout_state[0] == "*": log.info("LF Workflows are enabled for everyone. Using LF runners.") return WORKFLOW_LABEL_LF else: all_opted_in_users = { usr_raw.strip("\n\t@ ").split(",")[0] - for usr_raw in first_comment.split() + for usr_raw in rollout_state.split() } opted_in_requestors = { usr for usr in workflow_requestors if usr in all_opted_in_users @@ -203,11 +215,17 @@ def get_workflow_type(issue: Issue, workflow_requestors: Iterable[str]) -> str: def get_optin_feature( - issue: Issue, workflow_requestors: Iterable[str], feature: str, fallback: str + rollout_state: str, workflow_requestors: Iterable[str], feature: str, fallback: str ) -> str: + """ + Used to dynamically opt in jobs to specific runner-type variants. + + Returns: + The runner-type's variant name if the user has opted in to the feature, otherwise returns an empty string. + This variant name is prefixed to the runner-type in the label. + """ try: - first_comment = issue.get_comments()[0].body.strip("\n\t ") - userlist = {u.lstrip("#").strip("\n\t@ ") for u in first_comment.split()} + userlist = {u.lstrip("#").strip("\n\t@ ") for u in rollout_state.split()} all_opted_in_users = set() for user in userlist: for i in user.split(","): @@ -235,6 +253,17 @@ def get_optin_feature( return fallback +def get_rollout_state_from_issue(github_token: str, repo: str, issue_num: int) -> str: + """ + Gets the first comment of the issue, which contains the desired rollout state. + + The default issue we use - https://github.com/pytorch/test-infra/issues/5132 + """ + gh = get_gh_client(github_token) + issue = get_issue(gh, repo, issue_num) + return str(issue.get_comments()[0].body.strip("\n\t ")) + + def main() -> None: args = parse_args() @@ -244,25 +273,27 @@ def main() -> None: runner_ami = RUNNER_AMI_LEGACY else: try: - gh = get_gh_client(args.github_token) - # The default issue we use - https://github.com/pytorch/test-infra/issues/5132 - issue = get_issue(gh, args.github_issue_repo, args.github_issue) + rollout_state = get_rollout_state_from_issue( + args.github_token, args.github_issue_repo, args.github_issue + ) + username = get_potential_pr_author( - gh, + args.github_token, args.github_repo, args.github_actor, args.github_ref_type, args.github_branch, ) - label_type = get_workflow_type( - issue, + + label_type = get_fleet( + rollout_state, ( args.github_issue_owner, username, ), ) runner_ami = get_optin_feature( - issue=issue, + rollout_state=rollout_state, workflow_requestors=( args.github_issue_owner, username, diff --git a/.github/workflows/_runner-determinator.yml b/.github/workflows/_runner-determinator.yml index 9d6563e2ea5af0..d1694ef85e16ca 100644 --- a/.github/workflows/_runner-determinator.yml +++ b/.github/workflows/_runner-determinator.yml @@ -196,11 +196,14 @@ jobs: def get_potential_pr_author( - gh: Github, repo: str, username: str, ref_type: str, ref_name: str + github_token: str, repo: str, username: str, ref_type: str, ref_name: str ) -> str: # If the trigger was a new tag added by a bot, this is a ciflow case # Fetch the actual username from the original PR. The PR number is # embedded in the tag name: ciflow// + + gh = get_gh_client(github_token) + if username == "pytorch-bot[bot]" and ref_type == "tag": split_tag = ref_name.split("/") if ( @@ -222,23 +225,32 @@ jobs: def is_exception_branch(branch: str) -> bool: + """ + Branches that get opted out of all experiments and should always use Meta runners + """ return branch.split("/")[0] in {"main", "nightly", "release", "landchecks"} - def get_workflow_type(issue: Issue, workflow_requestors: Iterable[str]) -> str: - try: - first_comment = issue.get_comments()[0].body.strip("\n\t ") + def get_fleet(rollout_state: str, workflow_requestors: Iterable[str]) -> str: + """ + Determines if the job should run on the LF fleet or the Meta fleet - if first_comment[0] == "!": + Returns: + The appropriate label prefix for the runner, corresponding to the fleet to use. + This gets prefixed to the very start of the runner label. + """ + + try: + if rollout_state[0] == "!": log.info("LF Workflows are disabled for everyone. Using meta runners.") return WORKFLOW_LABEL_META - elif first_comment[0] == "*": + elif rollout_state[0] == "*": log.info("LF Workflows are enabled for everyone. Using LF runners.") return WORKFLOW_LABEL_LF else: all_opted_in_users = { usr_raw.strip("\n\t@ ").split(",")[0] - for usr_raw in first_comment.split() + for usr_raw in rollout_state.split() } opted_in_requestors = { usr for usr in workflow_requestors if usr in all_opted_in_users @@ -262,11 +274,17 @@ jobs: def get_optin_feature( - issue: Issue, workflow_requestors: Iterable[str], feature: str, fallback: str + rollout_state: str, workflow_requestors: Iterable[str], feature: str, fallback: str ) -> str: + """ + Used to dynamically opt in jobs to specific runner-type variants. + + Returns: + The runner-type's variant name if the user has opted in to the feature, otherwise returns an empty string. + This variant name is prefixed to the runner-type in the label. + """ try: - first_comment = issue.get_comments()[0].body.strip("\n\t ") - userlist = {u.lstrip("#").strip("\n\t@ ") for u in first_comment.split()} + userlist = {u.lstrip("#").strip("\n\t@ ") for u in rollout_state.split()} all_opted_in_users = set() for user in userlist: for i in user.split(","): @@ -294,6 +312,17 @@ jobs: return fallback + def get_rollout_state_from_issue(github_token: str, repo: str, issue_num: int) -> str: + """ + Gets the first comment of the issue, which contains the desired rollout state. + + The default issue we use - https://github.com/pytorch/test-infra/issues/5132 + """ + gh = get_gh_client(github_token) + issue = get_issue(gh, repo, issue_num) + return issue.get_comments()[0].body.strip("\n\t ") + + def main() -> None: args = parse_args() @@ -303,25 +332,27 @@ jobs: runner_ami = RUNNER_AMI_LEGACY else: try: - gh = get_gh_client(args.github_token) - # The default issue we use - https://github.com/pytorch/test-infra/issues/5132 - issue = get_issue(gh, args.github_issue_repo, args.github_issue) + rollout_state = get_rollout_state_from_issue( + args.github_token, args.github_issue_repo, args.github_issue + ) + username = get_potential_pr_author( - gh, + args.github_token, args.github_repo, args.github_actor, args.github_ref_type, args.github_branch, ) - label_type = get_workflow_type( - issue, + + label_type = get_fleet( + rollout_state, ( args.github_issue_owner, username, ), ) runner_ami = get_optin_feature( - issue=issue, + rollout_state=rollout_state, workflow_requestors=( args.github_issue_owner, username, @@ -346,6 +377,7 @@ jobs: if __name__ == "__main__": main() + EOF cat runner_determinator.py From ed1ef69e125f3d04b984d7aa14153908d895514d Mon Sep 17 00:00:00 2001 From: Zain Rizvi Date: Fri, 30 Aug 2024 16:57:28 -0500 Subject: [PATCH 0369/1018] Add validator to ensure runner determinator script is kept in sync (#134800) We keep two copies of the runner-determinator script: 1. In runner_determinator.py, for ease of testing. This however is not actually executed during CI 2. Embedded in _runner-determinator.yml. This is what CI uses. Why the duplication? Short version: Because of how github CI works, during a given CI run the workflow yml files could actually come from the main branch, while the remaining files get read from the local commit. This can lead to a newer version of _runner-determinator.yml trying to invoke an older version of runner_determintor.py than it was actually designed for. Chaos ensues. We mitigate this by embedding the script into the yml file. But we still keep the script around because it's much easier to run tests against. This workflow's job is to ensure that if one edits the script in one of those two locations then they remember to update it in the other location as well Pull Request resolved: https://github.com/pytorch/pytorch/pull/134800 Approved by: https://github.com/zxiiro, https://github.com/PaliC ghstack dependencies: #134796 --- .github/workflows/_runner-determinator.yml | 2 +- .../runner-determinator-validator.yml | 40 +++++++++++++++++++ 2 files changed, 41 insertions(+), 1 deletion(-) create mode 100644 .github/workflows/runner-determinator-validator.yml diff --git a/.github/workflows/_runner-determinator.yml b/.github/workflows/_runner-determinator.yml index d1694ef85e16ca..8ba709358c5ea1 100644 --- a/.github/workflows/_runner-determinator.yml +++ b/.github/workflows/_runner-determinator.yml @@ -320,7 +320,7 @@ jobs: """ gh = get_gh_client(github_token) issue = get_issue(gh, repo, issue_num) - return issue.get_comments()[0].body.strip("\n\t ") + return str(issue.get_comments()[0].body.strip("\n\t ")) def main() -> None: diff --git a/.github/workflows/runner-determinator-validator.yml b/.github/workflows/runner-determinator-validator.yml new file mode 100644 index 00000000000000..976fd2cccad5da --- /dev/null +++ b/.github/workflows/runner-determinator-validator.yml @@ -0,0 +1,40 @@ +name: Validate Runner Determinator Script is in Sync + +on: + # Run on PRs when the runner-determinator script is updated to ensure it's copies are kept in sync + pull_request: + paths: + - .github/workflows/_runner-determinator.yml + - .github/workflows/runner-determinator-validator.yml + - .github/scripts/runner_determinator.py + workflow_dispatch: + +concurrency: + group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.sha }}-${{ github.event_name == 'workflow_dispatch' }}-${{ github.event_name == 'schedule' }} + cancel-in-progress: true + +jobs: + check-runner-determinator: + runs-on: ubuntu-latest + + steps: + - name: Checkout repository + uses: actions/checkout@v2 + + - name: Run Hardcode runner-determinator script + id: hardcode-script + run: | + # Extract the script content from _runner-determinator.yml and skip the first 10 spaces of each line + script_content=$(awk '/cat < runner_determinator.py/{flag=1;next}/EOF$/{flag=0}flag{print substr($0, 11)}' .github/workflows/_runner-determinator.yml) + + # Write the extracted script content to runner_determinator.py + echo "$script_content" > runner_determinator_workflow.py + + - name: Compare runner-determinator script embedded in workflow with checked in script + run: | + # Compare the extracted runner_determinator script with the existing one + # If this check fails, then make sure the contents of .github/scripts/runner_determinator.py is in sync with the + # version embedded into .github/workflows/_runner-determinator.yml + diff runner_determinator_workflow.py .github/scripts/runner_determinator.py + # Fail the job if the scripts are not identical + continue-on-error: false \ No newline at end of file From 11aee1d859d4bbf01d98338af8addc306da51678 Mon Sep 17 00:00:00 2001 From: Shivam Raikundalia Date: Tue, 3 Sep 2024 23:46:38 +0000 Subject: [PATCH 0370/1018] [Profiler] Handle Tensor Sizes/Strides Parsing Error (#134862) Summary: Currently some jobs are encountering the following trace, P1539415198. This suggests that when we are parsing through tensors the path is prone to encountering an invalid address. This is is possibly occurring because for some reason the sizes() and strides() of a Tensor seem to not be of the same dimensions. We assume such when iterating through the shapes to get the Ivalue generator. When browsing some of the tensor implementations, I found that some of the size and stride paths are different which could be the cause of this issue. Regardless, the profiler should be flexible enough to handle such issues without bringing down the whole main thread. If the crashes still persist, it will still give us a data point as to where they are occurring and we can rule out the strides/sizes as the culprit Test Plan: This change doesn't break anything in the happy path, just makes sure the bad path is not exited abruptly. We should use this in order to debug what the events are having mismatching dimensions between sizes and strides. Differential Revision: D62008788 Pull Request resolved: https://github.com/pytorch/pytorch/pull/134862 Approved by: https://github.com/aaronenyeshi --- torch/csrc/profiler/collection.cpp | 30 +++++++++++++++++++++++++---- torch/csrc/profiler/collection.h | 3 ++- torch/csrc/profiler/python/init.cpp | 2 +- 3 files changed, 29 insertions(+), 6 deletions(-) diff --git a/torch/csrc/profiler/collection.cpp b/torch/csrc/profiler/collection.cpp index 3bfd0a4b8f3d9b..a4f53a5e10f99f 100644 --- a/torch/csrc/profiler/collection.cpp +++ b/torch/csrc/profiler/collection.cpp @@ -33,11 +33,18 @@ RawTensorMetadataBase::RawTensorMetadataBase(const at::Tensor& t) : data_{t.has_storage() ? t.storage().data() : nullptr}, dtype_{t.scalar_type()}, layout_{t.layout()}, - dim_{static_cast(t.sizes().size())} { + size_dim_{static_cast(t.sizes().size())}, + stride_dim_{static_cast(t.strides().size())} { TORCH_INTERNAL_ASSERT_DEBUG_ONLY( t.sizes().size() <= std::numeric_limits::max(), "Cannot profile Tensors of size > uint32 max. Got dim: ", t.sizes().size()); + TORCH_INTERNAL_ASSERT_DEBUG_ONLY( + t.sizes().size() != t.strides().size(), + "Tensor has mismatching sizes and strides. Sizes: ", + t.sizes().size(), + " Strides: ", + t.strides().size()); } RawTensorMetadata::RawTensorMetadata(const at::Tensor& t) @@ -181,14 +188,29 @@ auto InputOutputEncoder::getIValueGenerator(const IOType& io_type) { ivals_it = ivalues_.begin(), io_type]() mutable { auto decode_tensor = [&]() -> TensorMetadata { - const auto& raw_metadata = *tensor_metadata_it++; std::vector sizes; std::vector strides; - for (C10_UNUSED const auto _ : c10::irange(raw_metadata.dim_)) { + if (tensor_metadata_it.exhausted()) { + LOG(WARNING) + << "Tensor metadata exhausted prematurely. Reported shapes may be inaccurate!"; + return {RawTensorMetadata(), sizes, strides}; + } + const auto& raw_metadata = *tensor_metadata_it++; + for (C10_UNUSED const auto _ : c10::irange(raw_metadata.size_dim_)) { + if (tensor_size_strides_it.exhausted()) { + LOG(WARNING) + << "Expected Tensor Size mismatch with raw Tensor metadata. Reported shapes may be inaccurate!"; + return {raw_metadata, sizes, strides}; + } sizes.push_back(*tensor_size_strides_it++); } if (raw_metadata.layout_ == at::kStrided) { - for (C10_UNUSED const auto _ : c10::irange(raw_metadata.dim_)) { + for (C10_UNUSED const auto _ : c10::irange(raw_metadata.stride_dim_)) { + if (tensor_size_strides_it.exhausted()) { + LOG(WARNING) + << "Expected Tensor Strides mismatch with raw Tensor metadata. Reported shapes may be inaccurate!"; + return {raw_metadata, sizes, strides}; + } strides.push_back(*tensor_size_strides_it++); } } diff --git a/torch/csrc/profiler/collection.h b/torch/csrc/profiler/collection.h index 716fdb910c01ea..e2bb603c387cc5 100644 --- a/torch/csrc/profiler/collection.h +++ b/torch/csrc/profiler/collection.h @@ -47,7 +47,8 @@ struct TORCH_API RawTensorMetadataBase { StorageImplData data_; c10::ScalarType dtype_{c10::ScalarType::Undefined}; c10::Layout layout_{c10::Layout::Strided}; - uint32_t dim_{0}; + uint32_t size_dim_{0}; + uint32_t stride_dim_{0}; }; // Collected during profiling. diff --git a/torch/csrc/profiler/python/init.cpp b/torch/csrc/profiler/python/init.cpp index 25f93a2663dfb5..661646920632e2 100644 --- a/torch/csrc/profiler/python/init.cpp +++ b/torch/csrc/profiler/python/init.cpp @@ -441,7 +441,7 @@ void initPythonBindings(PyObject* module) { return py::reinterpret_borrow( torch::autograd::utils::wrap(metadata.dtype_)); }) - .def_readonly("dim", &TensorMetadata::dim_) + .def_readonly("dim", &TensorMetadata::size_dim_) .def_readonly("sizes", &TensorMetadata::sizes_) .def_readonly("strides", &TensorMetadata::strides_); From 5a0d67c551a2c732477e9d06cc42e10f09c8c31e Mon Sep 17 00:00:00 2001 From: Menglu Yu Date: Wed, 4 Sep 2024 00:05:51 +0000 Subject: [PATCH 0371/1018] [PT2][Optimus] Skip meta update on symblic shape (#134975) Summary: We noticed that there will be runtime error to do the dim broadcast when the meta example value has symbolic shape, thus we skip it. Test Plan: ``` buck2 run mode/opt //caffe2/benchmarks/dynamo/fb:torchbench_run_ads_dhen_5x_training -- -m ads_dhen_5x -t training ``` P1559019921 Differential Revision: D62115015 Pull Request resolved: https://github.com/pytorch/pytorch/pull/134975 Approved by: https://github.com/xuzhao9 --- .../_inductor/fx_passes/group_batch_fusion.py | 23 +++++++++++++------ 1 file changed, 16 insertions(+), 7 deletions(-) diff --git a/torch/_inductor/fx_passes/group_batch_fusion.py b/torch/_inductor/fx_passes/group_batch_fusion.py index 55b2690edf9f02..5311c9789fa59e 100644 --- a/torch/_inductor/fx_passes/group_batch_fusion.py +++ b/torch/_inductor/fx_passes/group_batch_fusion.py @@ -707,15 +707,24 @@ def fuse(self, graph: torch.fx.GraphModule, subset: List[torch.fx.Node]): torch.baddbmm, args=(unsqueeze_biases, stack_inputs, transpose_weight), ) - bmm.meta["example_value"] = torch.baddbmm( - unsqueeze_biases.meta["example_value"], - stack_inputs.meta["example_value"], - transpose_weight.meta["example_value"], - ) - bmm_meta = bmm.meta["example_value"] + try: + # it will have runtime error to broadcast when it has dynamic shape included + # in the meta data, so we need to skip the update meta data + bmm.meta["example_value"] = torch.baddbmm( + unsqueeze_biases.meta["example_value"], + stack_inputs.meta["example_value"], + transpose_weight.meta["example_value"], + ) + bmm_meta = bmm.meta["example_value"] + except Exception as e: + log.debug( + f" exception when update bmm meta data with stack error tracekey {e}" # noqa: G004 + ) + bmm_meta = None bmm = graph.call_function(torch.unbind, args=(bmm,), kwargs={"dim": 0}) - bmm.meta["example_value"] = torch.unbind(bmm_meta, dim=0) + if bmm_meta is not None: + bmm.meta["example_value"] = torch.unbind(bmm_meta, dim=0) for i, linear in enumerate(batch_nodes): with graph.inserting_after(bmm): getitem = graph.call_function(operator.getitem, args=(bmm, i)) From fb9064a52b2197d785cf353fe35a7305787c92e3 Mon Sep 17 00:00:00 2001 From: Moritz Marseu Date: Wed, 4 Sep 2024 00:21:22 +0000 Subject: [PATCH 0372/1018] Fix license metadata in setup.py (#129219) Package metadata in setup.py lists license as BSD-3 which is not a valid SPDX id. The correct id would be BSD-3-Clause. Specifying an SPDX id is beneficial to license compliance scanning. *Taking up #129123 from my personal account.* Pull Request resolved: https://github.com/pytorch/pytorch/pull/129219 Approved by: https://github.com/malfet, https://github.com/kit1980 --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 5eaddb894df61e..89097e9fdba3e4 100644 --- a/setup.py +++ b/setup.py @@ -1511,7 +1511,7 @@ def main(): f"Programming Language :: Python :: 3.{i}" for i in range(python_min_version[1], version_range_max) ], - license="BSD-3", + license="BSD-3-Clause", keywords="pytorch, machine learning", ) if EMIT_BUILD_WARNING: From 8bc3f64fd77012fd05c714a7b1967b930d972770 Mon Sep 17 00:00:00 2001 From: CaoE Date: Wed, 4 Sep 2024 00:56:56 +0000 Subject: [PATCH 0373/1018] Turn on expanded index path for Half on CPU (#133553) Pull Request resolved: https://github.com/pytorch/pytorch/pull/133553 Approved by: https://github.com/yanbing-j, https://github.com/jgong5, https://github.com/peterbell10 --- aten/src/ATen/native/TensorAdvancedIndexing.cpp | 2 +- aten/src/ATen/native/cpu/ScatterGatherKernel.cpp | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/aten/src/ATen/native/TensorAdvancedIndexing.cpp b/aten/src/ATen/native/TensorAdvancedIndexing.cpp index a3e11077a6fd5d..699f9393be2af5 100644 --- a/aten/src/ATen/native/TensorAdvancedIndexing.cpp +++ b/aten/src/ATen/native/TensorAdvancedIndexing.cpp @@ -1588,7 +1588,7 @@ static bool can_use_expanded_index_path( } const auto st = self.scalar_type(); - if (!(c10::isFloatingType(st)) || st == ScalarType::Half) { + if (!(c10::isFloatingType(st))) { return false; } diff --git a/aten/src/ATen/native/cpu/ScatterGatherKernel.cpp b/aten/src/ATen/native/cpu/ScatterGatherKernel.cpp index 95119b5ac08580..6af22033c805e0 100644 --- a/aten/src/ATen/native/cpu/ScatterGatherKernel.cpp +++ b/aten/src/ATen/native/cpu/ScatterGatherKernel.cpp @@ -867,8 +867,8 @@ void scatter_reduce_expanded_index_kernel( } void gather_expanded_index_kernel(const Tensor& result, const Tensor& self, const Tensor& index) { - AT_DISPATCH_FLOATING_TYPES_AND( - ScalarType::BFloat16, self.scalar_type(), "gather_expanded_index", [&] { + AT_DISPATCH_FLOATING_TYPES_AND2( + ScalarType::BFloat16, ScalarType::Half, self.scalar_type(), "gather_expanded_index", [&] { cpu_gather_expanded_index_kernel(result, index, self); }); } From c2749e7c7bc2637c990c0b6c28e6ed828588102a Mon Sep 17 00:00:00 2001 From: "Edward Z. Yang" Date: Tue, 3 Sep 2024 20:58:58 +0000 Subject: [PATCH 0374/1018] Fix dim mismatch logic automatic dynamic not working with compiler collectives (#135025) Fixes https://fb.workplace.com/groups/3095840833991792/permalink/3810738595835342/ Signed-off-by: Edward Z. Yang Pull Request resolved: https://github.com/pytorch/pytorch/pull/135025 Approved by: https://github.com/albanD --- test/distributed/test_dynamo_distributed.py | 32 +++++++++++++++++++++ torch/_dynamo/variables/builder.py | 11 ++++--- 2 files changed, 39 insertions(+), 4 deletions(-) diff --git a/test/distributed/test_dynamo_distributed.py b/test/distributed/test_dynamo_distributed.py index afe299106b65f8..134370e41703c3 100644 --- a/test/distributed/test_dynamo_distributed.py +++ b/test/distributed/test_dynamo_distributed.py @@ -913,6 +913,38 @@ def f(x, y): for r in res[1:]: self.assertEqual(res[0], r) + @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") + @config.patch(enable_compiler_collectives=True) + def test_compiler_collectives_dim_mismatch(self): + with _dynamo_dist_per_rank_init(self.rank, self.world_size): + torch._dynamo.utils.clear_compilation_metrics() + + # TODO: This should be possible to do inside the function, but + device = f"cuda:{self.rank}" + + @torch.compile() + def f(x, y): + zx = x.shape + zy = y.shape + return x.sum() + y.sum() + + if self.rank == 0: + dataloader = [[4, 2]] + else: + dataloader = [[3]] + + for data in dataloader: + f( + torch.randn(data, device=self.rank), + torch.randn(data, device=self.rank), + ) + + metrics = torch._dynamo.utils.get_compilation_metrics() + res = [None] * self.world_size + torch.distributed.all_gather_object(res, len(metrics)) + for r in res[1:]: + self.assertEqual(res[0], r) + @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") @patch.object(torch._inductor.config, "fx_graph_cache", False) @patch.object(torch._inductor.config, "fx_graph_remote_cache", False) diff --git a/torch/_dynamo/variables/builder.py b/torch/_dynamo/variables/builder.py index fbe6e90d3eea48..f2e981276fcc86 100644 --- a/torch/_dynamo/variables/builder.py +++ b/torch/_dynamo/variables/builder.py @@ -2429,6 +2429,9 @@ def _automatic_dynamic( # Prep for automatic dynamic def update_frame_state(size, stride): + # Intentionally shadow e from parent scope so it is not accidentally + # called + e = None frame_state_entry = None if name not in tx.output.frame_state: # If there is no entry for this source, add the tensor to frame state with its current static size. @@ -2439,13 +2442,13 @@ def update_frame_state(size, stride): else: frame_state_entry = tx.output.frame_state[name] if frame_state_entry.size is not None: - if e.ndim != len(frame_state_entry.size): + if len(size) != len(frame_state_entry.size): # If there is already an entry, and the dim mismatches, replace the frame state entry with None. # E.g. {"x": [2, 3, 4]} -> {"x": None} log.debug( "automatic dynamic %s dim %s != %s", name, - e.ndim, + len(size), frame_state_entry.size, ) frame_state_entry.size = None @@ -2462,7 +2465,7 @@ def update_frame_state(size, stride): "automatic dynamic %s size(%s) %s != %s", name, i, - e.size(i), + size[i], dim, ) frame_state_entry.size[i] = None @@ -2492,7 +2495,7 @@ def update_frame_state(size, stride): "automatic dynamic %s stride(%s) %s != %s", name, i, - e.stride(i), + stride[i], dim, ) frame_state_entry.stride[i] = None From b642ae02b6bdae1c7b0056052bca67c3b47c5d5a Mon Sep 17 00:00:00 2001 From: "Edward Z. Yang" Date: Tue, 3 Sep 2024 08:20:34 -0700 Subject: [PATCH 0375/1018] [EASY] Typofix (#135022) Signed-off-by: Edward Z. Yang Pull Request resolved: https://github.com/pytorch/pytorch/pull/135022 Approved by: https://github.com/albanD --- torch/_inductor/cudagraph_trees.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch/_inductor/cudagraph_trees.py b/torch/_inductor/cudagraph_trees.py index 422be37e78a0d2..5a33de0e366899 100644 --- a/torch/_inductor/cudagraph_trees.py +++ b/torch/_inductor/cudagraph_trees.py @@ -1578,7 +1578,7 @@ def _allocate_and_copy_recording_inputs( self, inputs: List[InputType] ) -> List[Union[torch.Tensor, int]]: """ - Allocate inputs for non static, non cudagraph managraphed managed tensors in the memory pool + Allocate inputs for non static, non cudagraph managed tensors in the memory pool and copy over the tensor values. """ From 0733718b1dfa582df72a721e8d1049958d806497 Mon Sep 17 00:00:00 2001 From: Bob Ren Date: Tue, 3 Sep 2024 16:07:37 -0700 Subject: [PATCH 0376/1018] Remove unused comment (#135034) As part of my rampup I've been reading through some of @ezyang's diffs. I noticed in https://github.com/pytorch/pytorch/pull/133439 there was a comment that he forgot to remove. This diff removes that comment. Pull Request resolved: https://github.com/pytorch/pytorch/pull/135034 Approved by: https://github.com/albanD --- torch/_decomp/decompositions.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/torch/_decomp/decompositions.py b/torch/_decomp/decompositions.py index bc1ccc6a20bb2d..c35d7a72774f16 100644 --- a/torch/_decomp/decompositions.py +++ b/torch/_decomp/decompositions.py @@ -1413,8 +1413,6 @@ def split_with_sizes( splits = [] start_idx = 0 - # Avoid importing sympy at a module level - for i in range(num_splits): length = split_sizes[i] splits.append(self.narrow(dim, start_idx, length)) From 6f5489fd46aa32c312b0f596741298481f31794e Mon Sep 17 00:00:00 2001 From: Rachel Guo Date: Wed, 4 Sep 2024 02:41:30 +0000 Subject: [PATCH 0377/1018] [AOTI][Tooling][5/n] Refactor the debug printer call to a level lower (#134789) Summary: 1. Move the debug printer call a level lower -> at here :https://www.internalfb.com/code/fbsource/[931d7bbb9e7cf2dcb926f42718f56fc940903eec]/fbcode/caffe2/torch/_inductor/codegen/cpp_wrapper_cuda.py?lines=335 2. Add UT for validating debug printer for user defined triton kernel codegen The benefit of having the debug printer call happens at a more centralized place is 1) reduce the duplicate debug printer related logic code scattered everywhere in the codebase 2) it can handle more triton kernel codegen path as long as it invokes this `generate_kernel_call()` for example, it can automatically handle/support user_defined_kernel 's debug printing which is a pretty common use case we encounter in debugging Test Plan: ```AOT_INDUCTOR_DEBUG_INTERMEDIATE_VALUE_PRINTER=2 TORCHINDUCTOR_FORCE_DISABLE_CACHES=1 TORCHINDUCTOR_ABI_COMPATIBLE=1 TORCH_COMPILE_DEBUG=1 TORCH_LOGS="+graph, inductor, +schedule, output_code" buck2 run -c fbcode.enable_gpu_sections=true -c fbcode.nvcc_arch=h100 @//mode/opt fbcode//caffe2/test/inductor:test_aot_inductor -- -r test_aoti_debug_printer_user_defined_triton_kernel_abi_compatible_cuda``` Also verified that templateKernel codegen path still works Differential Revision: D61949020 Pull Request resolved: https://github.com/pytorch/pytorch/pull/134789 Approved by: https://github.com/ColinPeppler --- test/inductor/test_aot_inductor.py | 53 ++++++++++++++++++++- torch/_inductor/codegen/cpp_wrapper_cpu.py | 11 +++-- torch/_inductor/codegen/cpp_wrapper_cuda.py | 38 ++++++++------- torch/_inductor/codegen/debug_utils.py | 9 ++-- torch/_inductor/codegen/simd.py | 23 +-------- 5 files changed, 88 insertions(+), 46 deletions(-) diff --git a/test/inductor/test_aot_inductor.py b/test/inductor/test_aot_inductor.py index 20d03fc40738e7..784c52a50dd990 100644 --- a/test/inductor/test_aot_inductor.py +++ b/test/inductor/test_aot_inductor.py @@ -3378,6 +3378,49 @@ def forward(self, a): FileCheck().check_not(f"before_launch - {kernel_name}").run(code) FileCheck().check_not(f"after_launch - {kernel_name}").run(code) + def test_aoti_debug_printer_user_defined_triton_kernel(self): + if self.device != "cuda": + raise unittest.SkipTest("requires CUDA") + + class Model(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + + def forward(self, x, y): + out = torch.zeros_like(x) + add_kernel[(4,)](x, y, out, n_elements=4, BLOCK_SIZE=16) + return out + + example_inputs = ( + torch.randn(4, 4, device=self.device), + torch.randn(4, 4, device=self.device), + ) + + kernel_calls = [ + ("add_kernel_0", 3), + ] + + # test the default debug printing codegen + with config.patch({"aot_inductor.debug_intermediate_value_printer": "2"}): + result, code = run_and_get_cpp_code( + AOTIRunnerUtil.compile, Model(), example_inputs + ) + + # check the c shim print_tensor_handle call is triggered by the config and injected the cpp output code as expected + self.assertEqual("aoti_torch_print_tensor_handle" in code, True) + + # check the codegen for debug printing around the actual kernel call is expected + + for kernel_call, count in kernel_calls: + FileCheck().check_count( + f"before_launch - {kernel_call}", + count, + ).run(code) + FileCheck().check_count( + f"after_launch - {kernel_call}", + count, + ).run(code) + common_utils.instantiate_parametrized_tests(AOTInductorTestsTemplate) @@ -3584,8 +3627,6 @@ def fail_non_abi_compatible_cuda(is_skip=False): "test_custom_op_add": fail_non_abi_compatible_cuda(is_skip=True), # fp8 to be re-enabled for AOTI "test_fp8": fail_cuda(is_skip=True), - # non-abi compatible mode debug printer is not supported yet - "test_aoti_debug_printer_codegen": fail_non_abi_compatible_cuda(is_skip=True), "test_custom_op_all_inputs": fail_non_abi_compatible_cuda(is_skip=True), "test_custom_op_missing_arg_with_default_value": fail_non_abi_compatible_cuda( is_skip=True @@ -3595,6 +3636,11 @@ def fail_non_abi_compatible_cuda(is_skip=False): is_skip=True ), "test_custom_op_with_multiple_outputs": fail_non_abi_compatible_cuda(is_skip=True), + # non-abi compatible mode aoti debug printer is not supported yet + "test_aoti_debug_printer_codegen": fail_non_abi_compatible_cuda(is_skip=True), + "test_aoti_debug_printer_user_defined_triton_kernel": fail_non_abi_compatible_cuda( + is_skip=True + ), } @@ -3651,6 +3697,9 @@ def fail_non_abi_compatible_cuda(is_skip=False): CUDA_TEST_FAILURES.update( { "test_aoti_debug_printer_codegen": fail_cuda(is_skip=True), + "test_aoti_debug_printer_user_defined_triton_kernel": fail_cuda( + is_skip=True + ), } ) diff --git a/torch/_inductor/codegen/cpp_wrapper_cpu.py b/torch/_inductor/codegen/cpp_wrapper_cpu.py index f25c2021d36b9e..af2bd731f5c324 100644 --- a/torch/_inductor/codegen/cpp_wrapper_cpu.py +++ b/torch/_inductor/codegen/cpp_wrapper_cpu.py @@ -1213,8 +1213,10 @@ def generate_c_shim_extern_kernel_call(self, kernel, args): args_to_print_or_save = None debug_printer_manager = V.graph.wrapper_code.debug_printer - debug_printer_level = debug_printer_manager.debug_printer_level - if debug_printer_level != IntermediateValueDebuggingLevel.OFF: + if ( + debug_printer_manager.debug_printer_level + != IntermediateValueDebuggingLevel.OFF + ): args_to_print_or_save = [] for x in args: @@ -1231,7 +1233,10 @@ def generate_c_shim_extern_kernel_call(self, kernel, args): ): # TODO: The current way to find a 'tensor' type arg is hacky also as mentioned above # Find a more reliable way to detect tensor kernel args for extern kernel calls - if debug_printer_level != IntermediateValueDebuggingLevel.OFF: + if ( + debug_printer_manager.debug_printer_level + != IntermediateValueDebuggingLevel.OFF + ): if piece.startswith(("buf", "arg")): args_to_print_or_save.append(piece) piece = f"convert_arrayref_tensor_to_tensor({piece})" diff --git a/torch/_inductor/codegen/cpp_wrapper_cuda.py b/torch/_inductor/codegen/cpp_wrapper_cuda.py index 05016c387b1fa7..cc406e4dac4ec9 100644 --- a/torch/_inductor/codegen/cpp_wrapper_cuda.py +++ b/torch/_inductor/codegen/cpp_wrapper_cuda.py @@ -395,9 +395,9 @@ def generate_kernel_call( call_args = [arg for i, arg in enumerate(call_args) if i not in equal_to_1] arg_types = [t for i, t in enumerate(arg_types) if i not in equal_to_1] - call_args = self.generate_args_decl(call_args, arg_types) + call_args_str = self.generate_args_decl(call_args, arg_types) kernel_args_var = f"kernel_args_var_{next(self.kernel_callsite_id)}" - self.writeline(f"void* {kernel_args_var}[] = {{{call_args}}};") + self.writeline(f"void* {kernel_args_var}[] = {{{call_args_str}}};") stream = ( "stream" if V.graph.aot_mode @@ -410,19 +410,23 @@ def generate_kernel_call( ) kernel_var_name = f"kernels.{kernel_name}" if V.graph.aot_mode else kernel_name - self.writeline(f"if ({grid_var}.is_non_zero()) {{") - self.writeline( - DeferredCudaKernelLine( - kernel_name, - r" launchKernel({}, {}, {}, {}, %s, %s, {}, {});".format( - kernel_var_name, - f"{grid_var}.grid_x", - f"{grid_var}.grid_y", - f"{grid_var}.grid_z", - kernel_args_var, - stream, + # add debug printer code for all triton kernel related calls + debug_printer_manager = V.graph.wrapper_code.debug_printer + debug_printer_manager.set_printer_args(call_args, kernel_name, arg_types, None) + with debug_printer_manager: + self.writeline(f"if ({grid_var}.is_non_zero()) {{") + self.writeline( + DeferredCudaKernelLine( + kernel_name, + r" launchKernel({}, {}, {}, {}, %s, %s, {}, {});".format( + kernel_var_name, + f"{grid_var}.grid_x", + f"{grid_var}.grid_y", + f"{grid_var}.grid_z", + kernel_args_var, + stream, + ), + ("num_warps", "shared_mem"), ), - ("num_warps", "shared_mem"), - ), - ) - self.writeline("}") + ) + self.writeline("}") diff --git a/torch/_inductor/codegen/debug_utils.py b/torch/_inductor/codegen/debug_utils.py index ebe46fc8cde01e..7b3dcd8eea2cf5 100644 --- a/torch/_inductor/codegen/debug_utils.py +++ b/torch/_inductor/codegen/debug_utils.py @@ -6,9 +6,10 @@ from enum import Enum from typing import List, Optional +from torch import dtype as torch_dtype + from .. import config from ..virtualized import V -from .common import TensorArg from .multi_kernel import MultiKernel @@ -121,8 +122,9 @@ def codegen_intermediate_tensor_value_save( ) -> None: for i, arg in enumerate(args_to_save): if arg_signatures is not None and not isinstance( - arg_signatures[i], TensorArg + arg_signatures[i], torch_dtype ): + # infer from the arg data type (has torch.dtype) to see if it is a tensor type continue launch_prefix = "before_launch" if before_launch else "after_launch" if V.graph.cpp_wrapper: @@ -146,8 +148,9 @@ def codegen_intermediate_tensor_value_print( ) -> None: for i, arg in enumerate(args_to_print): if arg_signatures is not None and not isinstance( - arg_signatures[i], TensorArg + arg_signatures[i], torch_dtype ): + # infer from the arg data type (has torch.dtype) to see if it is a tensor type continue if self.debug_printer_level == IntermediateValueDebuggingLevel.PRINT_ONLY: # when debug printing is enabled i.e. IntermediateValueDebuggingLevel.PRINT_ONLY, diff --git a/torch/_inductor/codegen/simd.py b/torch/_inductor/codegen/simd.py index 4958104b7105a2..59c126b993e3d4 100644 --- a/torch/_inductor/codegen/simd.py +++ b/torch/_inductor/codegen/simd.py @@ -1376,18 +1376,7 @@ def _node_has_sort(node): node.mark_run() self.codegen_comment(node_schedule) - - _, call_args, arg_signatures, _ = ( - final_kernel.args.python_argdefs() - if not isinstance(final_kernel, MultiKernel) - else [None, [], None, None] - ) - debug_printer_manager = V.graph.wrapper_code.debug_printer - debug_printer_manager.set_printer_args( - call_args, kernel_name, arg_signatures, final_kernel - ) - with debug_printer_manager: - final_kernel.call_kernel(final_kernel.kernel_name) + final_kernel.call_kernel(final_kernel.kernel_name) if config.nan_asserts: final_kernel.codegen_nan_check() @@ -1510,15 +1499,7 @@ def codegen_template( kernel_name = self.define_kernel(src_code, node_schedule, kernel) self.codegen_comment(node_schedule) - - # debug printing values of intermediate tensors - _, call_args, arg_signatures, _ = kernel.args.python_argdefs() - debug_printer_manager = V.graph.wrapper_code.debug_printer - debug_printer_manager.set_printer_args( - call_args, kernel_name, arg_signatures, kernel - ) - with debug_printer_manager: - kernel.call_kernel(kernel_name, template_node.node) + kernel.call_kernel(kernel_name, template_node.node) V.graph.removed_buffers |= kernel.removed_buffers V.graph.inplaced_to_remove |= kernel.inplaced_to_remove From 5ab4df902fca8f58be810ae2cf96b536106ff959 Mon Sep 17 00:00:00 2001 From: Jason Ansel Date: Mon, 2 Sep 2024 20:16:16 +0000 Subject: [PATCH 0378/1018] [halide-backend] Update CI pin (#130258) Pull Request resolved: https://github.com/pytorch/pytorch/pull/130258 Approved by: https://github.com/eellison --- .ci/docker/ci_commit_pins/halide.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.ci/docker/ci_commit_pins/halide.txt b/.ci/docker/ci_commit_pins/halide.txt index 17ff17c31524e6..aaf88275edad42 100644 --- a/.ci/docker/ci_commit_pins/halide.txt +++ b/.ci/docker/ci_commit_pins/halide.txt @@ -1 +1 @@ -340136fec6d3ebc73e7a19eba1663e9b0ba8ab2d \ No newline at end of file +461c12871f336fe6f57b55d6a297f13ef209161b \ No newline at end of file From e3ea4733630605f5d915f9a19af1e78ea153dc53 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Tue, 3 Sep 2024 18:31:18 -0700 Subject: [PATCH 0379/1018] [ONNX] Use the stable APIs in onnxscript and sync the latest logic (#134782) Use the stable apis from onnxscript: https://github.com/microsoft/onnxscript/issues/1827 Sync with torch-onnx at https://github.com/justinchuby/torch-onnx/compare/v0.1.12...v0.1.15. Pull Request resolved: https://github.com/pytorch/pytorch/pull/134782 Approved by: https://github.com/titaiwangms --- torch/onnx/_internal/_lazy_import.py | 2 + torch/onnx/_internal/exporter/_analysis.py | 10 +- torch/onnx/_internal/exporter/_compat.py | 62 ++++------- torch/onnx/_internal/exporter/_core.py | 61 ++++------- torch/onnx/_internal/exporter/_dispatching.py | 33 ++++-- .../onnx/_internal/exporter/_onnx_program.py | 101 ++++++++---------- .../onnx/_internal/exporter/_registration.py | 65 +++++------ torch/onnx/_internal/exporter/_schemas.py | 21 +++- .../onnx/_internal/exporter/_verification.py | 5 +- torch/onnx/utils.py | 2 +- 10 files changed, 158 insertions(+), 204 deletions(-) diff --git a/torch/onnx/_internal/_lazy_import.py b/torch/onnx/_internal/_lazy_import.py index 382bf7d78cd4cf..b12e53ef292624 100644 --- a/torch/onnx/_internal/_lazy_import.py +++ b/torch/onnx/_internal/_lazy_import.py @@ -30,6 +30,7 @@ def __getattr__(self, attr): if TYPE_CHECKING: import onnx import onnxscript + import onnxscript._framework_apis.torch_2_5 as onnxscript_apis onnxscript_ir = onnxscript.ir @@ -37,3 +38,4 @@ def __getattr__(self, attr): onnx = _LazyModule("onnx") onnxscript = _LazyModule("onnxscript") onnxscript_ir = _LazyModule("onnxscript.ir") + onnxscript_apis = _LazyModule("onnxscript._framework_apis.torch_2_5") diff --git a/torch/onnx/_internal/exporter/_analysis.py b/torch/onnx/_internal/exporter/_analysis.py index 43a6519482663f..2eb5adab3a7d9d 100644 --- a/torch/onnx/_internal/exporter/_analysis.py +++ b/torch/onnx/_internal/exporter/_analysis.py @@ -10,8 +10,6 @@ from collections import defaultdict from typing import TYPE_CHECKING -import onnxscript - import torch import torch._export.serde.schema from torch.export import graph_signature @@ -203,13 +201,7 @@ def analyze( model_info.outputs = outputs if registry is None: - # Trigger op registration - from onnxscript.function_libs.torch_lib import ops # noqa: F401 - - del ops - registry = _registration.ONNXRegistry.from_torchlib( - onnxscript.function_libs.torch_lib.registration.default_registry # type: ignore[arg-type] - ) + registry = _registration.ONNXRegistry.from_torchlib() # Try to find ops for every node in the graph for node in exported_program.graph.nodes: diff --git a/torch/onnx/_internal/exporter/_compat.py b/torch/onnx/_internal/exporter/_compat.py index 642f768d7285cb..7d38c2ce1415e9 100644 --- a/torch/onnx/_internal/exporter/_compat.py +++ b/torch/onnx/_internal/exporter/_compat.py @@ -8,10 +8,9 @@ import logging from typing import Any, Mapping, Sequence, TYPE_CHECKING -import onnx - import torch import torch.export +from torch.onnx._internal._lazy_import import onnxscript_apis, onnxscript_ir as ir from torch.onnx._internal.exporter import _core, _onnx_program @@ -108,13 +107,6 @@ def _get_torch_export_args( return args, kwargs -def _convert_version(path: str | os.PathLike, opset_version: int) -> None: - """Convert the ONNX file to a specific version.""" - model = onnx.load(path, load_external_data=False) - model = onnx.version_converter.convert_version(model, opset_version) - onnx.save(model, path) - - def export_compat( model: torch.nn.Module | torch.export.ExportedProgram @@ -142,9 +134,13 @@ def export_compat( artifacts_dir: str | os.PathLike = ".", fallback: bool = False, **_, -) -> _onnx_program.ONNXProgram | None: +) -> _onnx_program.ONNXProgram: + if opset_version is None: + # TODO(justinchuby): Change the hardcoded opset version for it to be flexible + opset_version = 18 + if isinstance(model, torch.export.ExportedProgram): - # We the model is already exported program, so the args, kwargs, and dynamic_shapes + # We know the model is already exported program, so the args, kwargs, and dynamic_shapes # are not used dynamic_shapes = dynamic_shapes or {} else: @@ -154,8 +150,6 @@ def export_compat( model, dynamic_axes, input_names ) - should_convert_version = False - try: onnx_program = _core.export( model, @@ -173,20 +167,6 @@ def export_compat( verbose=verbose, ) - if f is not None: - # Always save the initializers as external data to reduce the size of the ONNX file - onnx_program.save( - f, - include_initializers=export_params, - keep_initializers_as_inputs=keep_initializers_as_inputs, - external_data=external_data, - ) - if ( - opset_version is not None - and opset_version != onnx_program.model.opset_imports.get("") - ): - should_convert_version = True - except Exception as e: if fallback: if verbose is not False: @@ -194,6 +174,8 @@ def export_compat( "[torch.onnx] Falling back to legacy torch.onnx.export due " f"to the following error: {e}", ) + if f is None: + raise TypeError("f must be provided when fallback is enabled") from e torch.onnx.utils.export( model, # type: ignore[arg-type] args, @@ -206,20 +188,22 @@ def export_compat( dynamic_axes=dynamic_axes, keep_initializers_as_inputs=keep_initializers_as_inputs, ) - onnx_program = None - if opset_version is None: - opset_version = 18 - if opset_version != 17: - should_convert_version = True + onnx_program = _onnx_program.ONNXProgram(ir.load(f), None) else: raise - if f is not None and should_convert_version: - assert opset_version is not None - if verbose is not False: - print( - f"[torch.onnx] Converting the ONNX file to opset version {opset_version}..." - ) - _convert_version(f, opset_version) + # Converter opset version and optimize + onnx_program.model = onnxscript_apis.convert_version( + onnx_program.model, opset_version + ) + onnx_program.model = onnxscript_apis.optimize(onnx_program.model) + + if f is not None: + onnx_program.save( + f, + include_initializers=export_params, + keep_initializers_as_inputs=keep_initializers_as_inputs, + external_data=external_data, + ) return onnx_program diff --git a/torch/onnx/_internal/exporter/_core.py b/torch/onnx/_internal/exporter/_core.py index e775571d583327..5994db2c4e0e55 100644 --- a/torch/onnx/_internal/exporter/_core.py +++ b/torch/onnx/_internal/exporter/_core.py @@ -14,19 +14,17 @@ import typing from typing import Any, Callable, Literal, Sequence -import onnx - import onnxscript import onnxscript.evaluator import onnxscript.function_libs import onnxscript.function_libs.torch_lib -import onnxscript.function_libs.torch_lib.registration from onnxscript import ir from onnxscript.ir import convenience as ir_convenience import torch import torch.fx from torch.export import graph_signature +from torch.onnx._internal._lazy_import import onnxscript_apis from torch.onnx._internal.exporter import ( _analysis, _building, @@ -34,7 +32,6 @@ _dispatching, _fx_passes, _ir_passes, - _isolated, _onnx_program, _registration, _reporting, @@ -165,7 +162,11 @@ def _set_shape_types( def _set_shape_type( value: ir.Value, - meta_val: torch.Tensor | tuple[torch.Tensor], + meta_val: torch.Tensor + | torch.SymBool + | torch.SymInt + | torch.SymFloat + | tuple[torch.Tensor], complex_to_float: bool, ) -> None: # TODO: Consider using meta["tensor_meta"] for this? Would it be faster? @@ -691,13 +692,7 @@ def exported_program_to_ir( registry: The registry of all ONNX Script decomposition. """ if registry is None: - # Trigger op registration - from onnxscript.function_libs.torch_lib import ops # noqa: F401 - - del ops - registry = _registration.ONNXRegistry.from_torchlib( - onnxscript.function_libs.torch_lib.registration.default_registry # type: ignore[arg-type] - ) + registry = _registration.ONNXRegistry.from_torchlib() if lower != "none": exported_program = _prepare_exported_program_for_export( exported_program, registry=registry @@ -768,7 +763,7 @@ def _exported_program_to_onnx_program( }, ), ir_version=9, - producer_name="torch", + producer_name="pytorch", producer_version=torch.__version__, ) @@ -981,6 +976,8 @@ def export( program: torch.export.ExportedProgram | None = None # Step 1: Export the model with torch.export.export if the model is not already an ExportedProgram if isinstance(model, torch.export.ExportedProgram): + # We know the model is already exported program, so the args, kwargs, and dynamic_shapes + # are not used. program = model export_status.torch_export = True else: @@ -1071,13 +1068,7 @@ def export( try: # Build the ONNX function registry if registry is None: - # Trigger op registration - from onnxscript.function_libs.torch_lib import ops - - del ops - registry = _registration.ONNXRegistry.from_torchlib( - onnxscript.function_libs.torch_lib.registration.default_registry # type: ignore[arg-type] - ) + registry = _registration.ONNXRegistry.from_torchlib() # Process the exported program to run decompositions and type promotions etc. decomposed_program = _prepare_exported_program_for_export( @@ -1197,11 +1188,12 @@ def export( if not failed_results else _format_exceptions_for_all_strategies(failed_results), onnx_program.exported_program, - profile_result=profile_result, - export_status=export_status, decomp_comparison=_reporting.format_decomp_comparison( pre_decomp_unique_ops, post_decomp_unique_ops ), + export_status=export_status, + profile_result=profile_result, + model=onnx_program.model, registry=registry, ) verbose_print(f"Export report has been saved to '{report_path}'.") @@ -1214,28 +1206,13 @@ def export( # Step 3: (verify=True) Check the ONNX model with ONNX checker try: - verbose_print("Run `onnx.checker` on the ONNX model...") - - # TODO: Handle when model is >2GB - - model_proto = onnx_program.model_proto - byte_size = model_proto.ByteSize() - if byte_size < 2 * 1024 * 1024 * 1024: - # The checker may segfault so we need to run it in a separate process - _isolated.safe_call( - onnx.checker.check_model, # type:ignore[attr-defined] - onnx_program.model_proto, - full_check=True, - ) - export_status.onnx_checker = True - verbose_print("Run `onnx.checker` on the ONNX model... ✅") - else: - verbose_print( - f"Run `onnx.checker` on the ONNX model... ⚠️ Skipped because model is too large ({byte_size})." - ) + verbose_print("Check the ONNX model...") + onnxscript_apis.check_model(onnx_program.model) + export_status.onnx_checker = True + verbose_print("Check the ONNX model... ✅") except Exception as e: export_status.onnx_checker = False - verbose_print("Run `onnx.checker` on the ONNX model... ❌") + verbose_print("Check the ONNX model... ❌") if report: try: assert pre_decomp_unique_ops is not None diff --git a/torch/onnx/_internal/exporter/_dispatching.py b/torch/onnx/_internal/exporter/_dispatching.py index b8aecfaa93793a..11ed1af17aaade 100644 --- a/torch/onnx/_internal/exporter/_dispatching.py +++ b/torch/onnx/_internal/exporter/_dispatching.py @@ -2,9 +2,8 @@ from __future__ import annotations import logging -from typing import Sequence +from typing import Callable, Sequence -import onnxscript from onnxscript import ir import torch @@ -163,10 +162,22 @@ def _param_type_compatible_with_arg( def _get_type_from_tensor( - tensor: torch.Tensor | Sequence[torch.Tensor], + tensor: torch.Tensor + | torch.SymBool + | torch.SymInt + | torch.SymFloat + | Sequence[torch.Tensor], ) -> ir.TypeProtocol: if isinstance(tensor, torch.Tensor): return ir.TensorType(_torch_dtype_to_onnx_compatible_dtype(tensor.dtype)) + if isinstance(tensor, torch.SymBool): + return ir.TensorType(ir.DataType.BOOL) + if isinstance(tensor, torch.SymInt): + return ir.TensorType(ir.DataType.INT64) + if isinstance(tensor, torch.SymFloat): + return ir.TensorType(ir.DataType.FLOAT) + + # Handle sequences first_tensor = next((item for item in tensor if item is not None), None) if first_tensor is None: return ir.SequenceType(ir.TensorType(ir.DataType.UNDEFINED)) @@ -189,7 +200,7 @@ def _get_first_tensor_in_node_list( def _get_named_fx_node_args(node: torch.fx.Node) -> dict[str, torch.fx.node.Argument]: - # FIXME: node.target may not have a schema + assert hasattr(node.target, "_schema") torch_schema: torch.FunctionSchema = node.target._schema # type: ignore[union-attr] node_args = {} for arg, schema_arg in zip(node.args, torch_schema.arguments): @@ -201,8 +212,8 @@ def _get_named_fx_node_args(node: torch.fx.Node) -> dict[str, torch.fx.node.Argu def get_matching_overload( node: torch.fx.Node, - overloads: Sequence[onnxscript.OnnxFunction | onnxscript.TracedOnnxFunction], -) -> tuple[onnxscript.OnnxFunction | onnxscript.TracedOnnxFunction | None, str]: + overloads: Sequence[Callable], +) -> tuple[Callable | None, str]: """Get the overload that matches the node's arguments. Args: @@ -212,8 +223,14 @@ def get_matching_overload( Returns: A tuple containing the matched overload and a string describing the reason for failure or success. """ + if not hasattr(node.target, "_schema"): + # FIXME(justinchuby): When the target is a builtin, we should instead + # Match only the inputs positionally. Figure out how to do that as right + # now we assume all inputs are named. + return overloads[ + 0 + ], "The node target does not have a schema. Return the first one." named_args = _get_named_fx_node_args(node) - # FIXME: node.target may and builtin and not have a schema # FIXME: Handle when we don't know the names of the arguments schema_args: dict[str, torch.Argument] = { arg.name: arg @@ -308,7 +325,7 @@ def _arg_has_complex_dtype(arg) -> bool: def dispatch( node: torch.fx.Node, registry: _registration.ONNXRegistry -) -> tuple[onnxscript.OnnxFunction | onnxscript.TracedOnnxFunction | None, str]: +) -> tuple[Callable | None, str]: """Dispatch a node to an ONNX function based on the node's target and the ONNX registry. Args: diff --git a/torch/onnx/_internal/exporter/_onnx_program.py b/torch/onnx/_internal/exporter/_onnx_program.py index 646b2e7d494f1d..49172e772c8a1c 100644 --- a/torch/onnx/_internal/exporter/_onnx_program.py +++ b/torch/onnx/_internal/exporter/_onnx_program.py @@ -5,26 +5,24 @@ __all__ = ["ONNXProgram"] +import copy import gc import logging import os -import pathlib import tempfile import textwrap -from typing import Callable, IO, Sequence, TYPE_CHECKING +from typing import Callable, Sequence, TYPE_CHECKING import torch -from torch.onnx._internal import _lazy_import -from torch.utils import _pytree as pytree - - -onnx = _lazy_import.onnx -ir = _lazy_import.onnxscript_ir +from torch.onnx._internal._lazy_import import onnx, onnxscript_apis, onnxscript_ir as ir +from torch.utils import _pytree if TYPE_CHECKING: import onnxruntime as ort +_LARGE_MODEL_THRESHOLD = 1536 * 1024 * 1024 # 1536MB + logger = logging.getLogger(__name__) @@ -47,6 +45,15 @@ def _ort_session_initializer(model: str | bytes) -> ort.InferenceSession: ) +def _count_initializer_size(graph: ir.Graph) -> int: + """Count the total size of the initializers in bytes.""" + return sum( + v.const_value.nbytes + for v in graph.initializers.values() + if v.const_value is not None + ) + + class ONNXProgram: """A class to represent an ONNX program that is callable with torch tensors.""" @@ -105,7 +112,7 @@ def model_proto(self) -> onnx.ModelProto: def save( self, - destination: str | os.PathLike | IO[bytes], + destination: str | os.PathLike, *, include_initializers: bool = True, keep_initializers_as_inputs: bool = False, @@ -128,44 +135,30 @@ def save( Raises: TypeError: If `external_data` is `True` and `destination` is not a file path. """ + original_initializers = copy.copy(self.model.graph.initializers) + original_inputs = copy.copy(self.model.graph.inputs) + + # Adjust the model based on options if not include_initializers: self.model.graph.initializers.clear() - logger.warning( - "The initializers have been removed from the model. This is destructive. " - "Developers: Please implement ir.Model copy() and remove initializers on the copied model." - ) if keep_initializers_as_inputs: self.model.graph.inputs.extend(self.model.graph.initializers.values()) # type: ignore[arg-type] - logger.warning( - "The initializers have been added as inputs to the model. This is destructive. " - "Developers: Please implement ir.Model copy() and remove initializers on the copied model." - ) - proto = ir.serde.serialize_model(self.model) - byte_size = proto.ByteSize() - model_too_large = (byte_size) >= 1 << 31 - if external_data or model_too_large: - # TODO: Create an IR pass to handle external tensors conversion - if model_too_large: - logger.warning( - "The serialized ONNX model is larger than 2GB (%s). " - "Saving the weights as external data in a separate file.", - byte_size, - ) - if not isinstance(destination, (str, os.PathLike)): - raise TypeError( - "Saving the weights as external data is only supported when destination is a file path" - ) - destination_path = pathlib.Path(destination) - # Create the directory if it does not exist - data_path = f"{destination_path.name}.data" - onnx.save_model( - proto, - destination, - save_as_external_data=True, - location=data_path, - ) + + # Save the model to disk + if ( + external_data + or _count_initializer_size(self.model.graph) > _LARGE_MODEL_THRESHOLD + ): + onnxscript_apis.save_model_with_external_data(self.model, destination) else: - onnx.save_model(proto, destination) + ir.save(self.model, destination) + + # Revert the changes to the model + if not include_initializers: + self.model.graph.initializers.update(original_initializers) + if keep_initializers_as_inputs: + self.model.graph.inputs.clear() + self.model.graph.inputs.extend(original_inputs) def initialize_inference_session( self, @@ -182,27 +175,17 @@ def initialize_inference_session( """ # TODO(justinchuby): Allow different inference options logger.debug("Initializing the inference session.") - proto = ir.serde.serialize_model(self.model) - byte_size = proto.ByteSize() - model_too_large = (byte_size) >= 1 << 31 - - if model_too_large: - logger.debug( - "The serialized ONNX model is larger than 2GB (%s).", byte_size - ) + if ( + byte_size := _count_initializer_size(self.model.graph) + ) > _LARGE_MODEL_THRESHOLD: + logger.debug("The model initializers is larger than 1.5GB (%s).", byte_size) # Save the model to a temporary file if too large self._tempdir = tempfile.TemporaryDirectory(ignore_cleanup_errors=True) model_path = os.path.join(self._tempdir.name, "model.onnx") - data_path = "model.onnx.data" - onnx.save_model( - proto, - model_path, - save_as_external_data=True, - location=data_path, - ) + self.save(model_path, external_data=True) model = model_path else: - model = proto.SerializeToString() # type: ignore[assignment] + model = self.model_proto.SerializeToString() # type: ignore[assignment] self._inference_session = initializer(model) logger.debug("Inference session initialized.") @@ -231,7 +214,7 @@ def _process_args(args, kwargs) -> tuple[torch.Tensor, ...]: def _flatten_inputs(model_args, model_kwargs): - flattened_args, _ = pytree.tree_flatten((model_args, model_kwargs)) + flattened_args, _ = _pytree.tree_flatten((model_args, model_kwargs)) return flattened_args diff --git a/torch/onnx/_internal/exporter/_registration.py b/torch/onnx/_internal/exporter/_registration.py index b649188c264eac..0afa084b06c80d 100644 --- a/torch/onnx/_internal/exporter/_registration.py +++ b/torch/onnx/_internal/exporter/_registration.py @@ -18,17 +18,17 @@ import operator import types import typing -from typing import Callable, Literal, Mapping, Union +from typing import Callable, Literal, Union from typing_extensions import TypeAlias import torch import torch._ops +from torch.onnx._internal._lazy_import import onnxscript_apis from torch.onnx._internal.exporter import _schemas if typing.TYPE_CHECKING: import onnxscript - from onnxscript.function_libs.torch_lib import registration as torchlib_registration _DEFAULT_OPSET_VERSION = 18 @@ -49,7 +49,7 @@ class OnnxDecompMeta: device: The device the function is registered to. If None, it is registered to all devices. """ - onnx_function: onnxscript.OnnxFunction | onnxscript.TracedOnnxFunction + onnx_function: Callable fx_target: TorchOp is_custom: bool = False is_complex: bool = False @@ -135,24 +135,21 @@ def opset_version(self) -> int: return self._opset_version @classmethod - def from_torchlib( - cls, - torchlib_registry: Mapping[str, torchlib_registration.OverloadedFunction] - | None = None, - ) -> ONNXRegistry: + def from_torchlib(cls) -> ONNXRegistry: """Populates the registry with ATen functions from torchlib. Args: torchlib_registry: The torchlib registry to use for populating the registry. """ registry = cls() - if torchlib_registry is None: - from onnxscript.function_libs.torch_lib import ( - registration as torchlib_registration, - ) - torchlib_registry = torchlib_registration.default_registry # type: ignore[assignment] - for qualified_name, aten_overloads_func in torchlib_registry.items(): # type: ignore[union-attr] + torchlib_ops = onnxscript_apis.get_torchlib_ops() + + for meta in torchlib_ops: + qualified_name = meta.qualified_name + overload_func = meta.function + domain = meta.domain + name = meta.name try: # NOTE: This is heavily guarded with try-except because we don't want # to fail the entire registry population if one function fails. @@ -162,33 +159,19 @@ def from_torchlib( target = _get_overload(qualified_name) if target is None: continue - for overload_func in aten_overloads_func.overloads: - overload_func.signature = _schemas.OpSignature.from_function( - overload_func, - overload_func.function_ir.domain, - overload_func.name, - ) - onnx_decomposition = OnnxDecompMeta( - onnx_function=overload_func, - fx_target=target, - is_custom=False, - is_complex=False, - ) - registry._register(target, onnx_decomposition) - - for complex_func in aten_overloads_func.complex: - overload_func.signature = _schemas.OpSignature.from_function( - overload_func, - overload_func.function_ir.domain, - overload_func.name, - ) - onnx_decomposition = OnnxDecompMeta( - onnx_function=complex_func, - fx_target=target, - is_custom=False, - is_complex=True, - ) - registry._register(target, onnx_decomposition) + + overload_func.signature = _schemas.OpSignature.from_function( # type: ignore[attr-defined] + overload_func, + domain, + name, + ) + onnx_decomposition = OnnxDecompMeta( + onnx_function=overload_func, + fx_target=target, + is_custom=False, + is_complex=meta.is_complex, + ) + registry._register(target, onnx_decomposition) except Exception: logger.exception("Failed to register '%s'. Skipped", qualified_name) continue diff --git a/torch/onnx/_internal/exporter/_schemas.py b/torch/onnx/_internal/exporter/_schemas.py index 12e2f3d2f44df2..fbc406053fbee8 100644 --- a/torch/onnx/_internal/exporter/_schemas.py +++ b/torch/onnx/_internal/exporter/_schemas.py @@ -121,6 +121,8 @@ def has_default(self) -> bool: @dataclasses.dataclass(frozen=True) class AttributeParameter: + """A parameter in the function signature that represents an ONNX attribute.""" + name: str type: ir.AttributeType required: bool @@ -435,7 +437,7 @@ def from_function( # https://github.com/python/cpython/issues/102405 type_hints = typing.get_type_hints(func) - params = [] + params: list[Parameter | AttributeParameter] = [] # Create a mapping from type to a unique name type_constraints: dict[str, TypeConstraintParam] = {} @@ -446,8 +448,19 @@ def from_function( param.name, py_signature, ) - type_constraints[param.name] = TypeConstraintParam.any_value( - f"T_{param.name}" + type_constraint = TypeConstraintParam.any_value(f"T_{param.name}") + type_constraints[param.name] = type_constraint + params.append( + Parameter( + name=param.name, + type_constraint=type_constraint, + required=param.default is inspect.Parameter.empty, + # TODO: Handle variadic + variadic=False, + default=param.default + if param.default is not inspect.Parameter.empty + else _EMPTY_DEFAULT, + ) ) else: type_ = type_hints[param.name] @@ -490,7 +503,7 @@ def from_function( type_constraints[type_constraint_name] = type_constraint # 4. Create Parameter params.append( - Parameter( # type: ignore[arg-type] + Parameter( name=param.name, type_constraint=type_constraint, required=param.default is inspect.Parameter.empty, diff --git a/torch/onnx/_internal/exporter/_verification.py b/torch/onnx/_internal/exporter/_verification.py index 10560414616fbf..c4eec16da4990f 100644 --- a/torch/onnx/_internal/exporter/_verification.py +++ b/torch/onnx/_internal/exporter/_verification.py @@ -91,7 +91,10 @@ def verify_onnx_program( ) abs_diff = abs_diff.flatten() rel_diff = rel_diff.flatten() - bins = torch.tensor([0.0, 1e-6, 1e-5, 1e-4, 1e-3, 1e-2, 1e-1, 1.0, 10]) + bins = torch.tensor( + [0.0, 1e-6, 1e-5, 1e-4, 1e-3, 1e-2, 1e-1, 1.0, 10, 1000000], + dtype=abs_diff.dtype, + ) abs_diff_hist = torch.histogram(abs_diff, bins=bins) rel_diff_hist = torch.histogram(rel_diff, bins=bins) results.append( diff --git a/torch/onnx/utils.py b/torch/onnx/utils.py index 7370411ad3d221..9398b92cd9a43b 100644 --- a/torch/onnx/utils.py +++ b/torch/onnx/utils.py @@ -175,7 +175,7 @@ def _get_torch_export_args( def export( model: torch.nn.Module | torch.jit.ScriptModule | torch.jit.ScriptFunction, args: tuple[Any, ...] | torch.Tensor, - f: str | None = None, + f: str, *, kwargs: dict[str, Any] | None = None, export_params: bool = True, From ce13ad62cf0c7f91a49847f3d9041888602a5c79 Mon Sep 17 00:00:00 2001 From: Sam Larsen Date: Thu, 15 Aug 2024 12:29:57 -0700 Subject: [PATCH 0380/1018] [inductor] Skip the sub-process pool until it's ready (#133508) Summary: Torch-compiling a quick script can be a bit slower than it needs to be: even though we initialize the subprocess pool early, it still might not be ready by the time we try to compile the first Triton kernel. Instead, let's use the single-threaded path until the pool has successfully completed a no-op job. Pull Request resolved: https://github.com/pytorch/pytorch/pull/133508 Approved by: https://github.com/Chillee --- torch/_inductor/async_compile.py | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/torch/_inductor/async_compile.py b/torch/_inductor/async_compile.py index 1ba70135bd7f8f..ce265460958439 100644 --- a/torch/_inductor/async_compile.py +++ b/torch/_inductor/async_compile.py @@ -127,6 +127,11 @@ def pool() -> ThreadPoolExecutor: assert config.compile_threads > 1 return ThreadPoolExecutor(config.compile_threads) + @staticmethod + def _get_ready(): + """No-op function to help mark when the subprocess pool is ready.""" + return "ready" + @staticmethod @functools.lru_cache(1) def process_pool() -> AnyPool: @@ -149,6 +154,8 @@ def process_pool() -> AnyPool: # kill the worker thread that sends the shutdown message to the workers... multiprocessing.util.Finalize(None, pool.shutdown, exitpriority=sys.maxsize) + # Set an attribute we can check to see if the pool is ready. + pool.ready_future = pool.submit(AsyncCompile._get_ready) # type: ignore[union-attr] _pool_set.add(pool) return pool @@ -166,13 +173,19 @@ def submit(cls, task: Callable[..., Any]) -> Any: return task() return cls.pool().submit(task) + def _use_process_pool(self): + return ( + config.compile_threads > 1 + and self.process_pool().ready_future.done() # type: ignore[union-attr] + ) + def triton(self, kernel_name: str, source_code: str, device_str: str = "cuda"): kernel_code_log.info("Triton Kernel:\n%s", source_code) _compile_start() _set_triton_ptxas_path() kernel = TritonCodeCache.load(kernel_name, source_code) - if config.compile_threads > 1: + if self._use_process_pool(): # We want to support changing these env vars after (and while) the # process pool is running, so pass them to the subprocess to reset. env_vars = ["TORCHINDUCTOR_CACHE_DIR", "TRITON_CACHE_DIR"] From 17d1a9947c97b293361f9aec7dc19b604fb09e71 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Tue, 3 Sep 2024 20:33:57 -0700 Subject: [PATCH 0381/1018] Bump actions/download-artifact from 2 to 4.1.7 in /.github/workflows (#135068) Bumps [actions/download-artifact](https://github.com/actions/download-artifact) from 2 to 4.1.7. - [Release notes](https://github.com/actions/download-artifact/releases) - [Commits](https://github.com/actions/download-artifact/compare/v2...v4.1.7) --- updated-dependencies: - dependency-name: actions/download-artifact dependency-type: direct:production ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- .github/workflows/_binary-test-linux.yml | 2 +- .github/workflows/_binary-upload.yml | 2 +- .github/workflows/_ios-build-test.yml | 2 +- .github/workflows/build-triton-wheel.yml | 4 +-- .github/workflows/create_release.yml | 2 +- ...inux-binary-libtorch-cxx11-abi-nightly.yml | 4 +-- ...inux-binary-libtorch-pre-cxx11-nightly.yml | 4 +-- ...nerated-linux-binary-manywheel-nightly.yml | 26 +++++++-------- ...generated-windows-binary-conda-nightly.yml | 32 +++++++++---------- ...ted-windows-binary-libtorch-debug-main.yml | 2 +- ...-windows-binary-libtorch-debug-nightly.yml | 8 ++--- ...d-windows-binary-libtorch-release-main.yml | 2 +- ...indows-binary-libtorch-release-nightly.yml | 8 ++--- ...generated-windows-binary-wheel-nightly.yml | 32 +++++++++---------- 14 files changed, 65 insertions(+), 65 deletions(-) diff --git a/.github/workflows/_binary-test-linux.yml b/.github/workflows/_binary-test-linux.yml index fa590499d6d52b..5123889fb01bf8 100644 --- a/.github/workflows/_binary-test-linux.yml +++ b/.github/workflows/_binary-test-linux.yml @@ -210,7 +210,7 @@ jobs: - name: Download Build Artifacts if: ${{ steps.filter.outputs.is-test-matrix-empty == 'False' }} - uses: actions/download-artifact@v3 + uses: actions/download-artifact@v4.1.7 with: name: ${{ inputs.build_name }} path: "${{ runner.temp }}/artifacts/" diff --git a/.github/workflows/_binary-upload.yml b/.github/workflows/_binary-upload.yml index 1c5a9294598b78..927f72c8d83897 100644 --- a/.github/workflows/_binary-upload.yml +++ b/.github/workflows/_binary-upload.yml @@ -126,7 +126,7 @@ jobs: # NB: When the previous build job is skipped, there won't be any artifacts and # this step will fail. Binary build jobs can only be skipped on CI, not nightly continue-on-error: true - uses: actions/download-artifact@v3 + uses: actions/download-artifact@v4.1.7 with: name: ${{ inputs.build_name }} path: "${{ runner.temp }}/artifacts/" diff --git a/.github/workflows/_ios-build-test.yml b/.github/workflows/_ios-build-test.yml index bb91b43e319d25..95fe6bd1a3b50a 100644 --- a/.github/workflows/_ios-build-test.yml +++ b/.github/workflows/_ios-build-test.yml @@ -292,7 +292,7 @@ jobs: bundler-cache: true - name: Download arm64 artifacts - uses: actions/download-artifact@v3 + uses: actions/download-artifact@v4.1.7 with: name: pytorch-ios-build-artifacts-arm64 diff --git a/.github/workflows/build-triton-wheel.yml b/.github/workflows/build-triton-wheel.yml index 01f06cdd286cf4..d6a83eaf3654cc 100644 --- a/.github/workflows/build-triton-wheel.yml +++ b/.github/workflows/build-triton-wheel.yml @@ -157,7 +157,7 @@ jobs: aws-region: us-east-1 - name: Download Build Artifacts - uses: actions/download-artifact@v3 + uses: actions/download-artifact@v4.1.7 with: # Download all available artifacts path: ${{ runner.temp }}/artifacts-all @@ -273,7 +273,7 @@ jobs: - uses: actions/checkout@v3 - name: Download Build Artifacts - uses: actions/download-artifact@v3 + uses: actions/download-artifact@v4.1.7 with: # Download all available artifacts path: ${{ runner.temp }}/artifacts-all diff --git a/.github/workflows/create_release.yml b/.github/workflows/create_release.yml index a9836734bfe5f9..ab8275ad263f69 100644 --- a/.github/workflows/create_release.yml +++ b/.github/workflows/create_release.yml @@ -80,7 +80,7 @@ jobs: id-token: write needs: release steps: - - uses: actions/download-artifact@v2 + - uses: actions/download-artifact@v4.1.7 with: name: ${{ needs.release.outputs.pt_release_name }} - name: Configure AWS credentials(PyTorch account) diff --git a/.github/workflows/generated-linux-binary-libtorch-cxx11-abi-nightly.yml b/.github/workflows/generated-linux-binary-libtorch-cxx11-abi-nightly.yml index f4c344cbdbc99b..408106d0096aba 100644 --- a/.github/workflows/generated-linux-binary-libtorch-cxx11-abi-nightly.yml +++ b/.github/workflows/generated-linux-binary-libtorch-cxx11-abi-nightly.yml @@ -366,7 +366,7 @@ jobs: steps: - name: Setup ROCm uses: ./.github/actions/setup-rocm - - uses: actions/download-artifact@v3 + - uses: actions/download-artifact@v4.1.7 name: Download Build Artifacts with: name: libtorch-rocm6_1-shared-with-deps-cxx11-abi @@ -476,7 +476,7 @@ jobs: steps: - name: Setup ROCm uses: ./.github/actions/setup-rocm - - uses: actions/download-artifact@v3 + - uses: actions/download-artifact@v4.1.7 name: Download Build Artifacts with: name: libtorch-rocm6_2-shared-with-deps-cxx11-abi diff --git a/.github/workflows/generated-linux-binary-libtorch-pre-cxx11-nightly.yml b/.github/workflows/generated-linux-binary-libtorch-pre-cxx11-nightly.yml index 23580f30881442..ee9f94c8ac6c25 100644 --- a/.github/workflows/generated-linux-binary-libtorch-pre-cxx11-nightly.yml +++ b/.github/workflows/generated-linux-binary-libtorch-pre-cxx11-nightly.yml @@ -366,7 +366,7 @@ jobs: steps: - name: Setup ROCm uses: ./.github/actions/setup-rocm - - uses: actions/download-artifact@v3 + - uses: actions/download-artifact@v4.1.7 name: Download Build Artifacts with: name: libtorch-rocm6_1-shared-with-deps-pre-cxx11 @@ -476,7 +476,7 @@ jobs: steps: - name: Setup ROCm uses: ./.github/actions/setup-rocm - - uses: actions/download-artifact@v3 + - uses: actions/download-artifact@v4.1.7 name: Download Build Artifacts with: name: libtorch-rocm6_2-shared-with-deps-pre-cxx11 diff --git a/.github/workflows/generated-linux-binary-manywheel-nightly.yml b/.github/workflows/generated-linux-binary-manywheel-nightly.yml index e88daa947a3d19..4fff1e0bacd2bf 100644 --- a/.github/workflows/generated-linux-binary-manywheel-nightly.yml +++ b/.github/workflows/generated-linux-binary-manywheel-nightly.yml @@ -635,7 +635,7 @@ jobs: steps: - name: Setup ROCm uses: ./.github/actions/setup-rocm - - uses: actions/download-artifact@v3 + - uses: actions/download-artifact@v4.1.7 name: Download Build Artifacts with: name: manywheel-py3_9-rocm6_1 @@ -742,7 +742,7 @@ jobs: steps: - name: Setup ROCm uses: ./.github/actions/setup-rocm - - uses: actions/download-artifact@v3 + - uses: actions/download-artifact@v4.1.7 name: Download Build Artifacts with: name: manywheel-py3_9-rocm6_2 @@ -859,7 +859,7 @@ jobs: - name: Login to Amazon ECR id: login-ecr uses: aws-actions/amazon-ecr-login@v2 - - uses: actions/download-artifact@v3 + - uses: actions/download-artifact@v4.1.7 name: Download Build Artifacts with: name: manywheel-py3_9-xpu @@ -1580,7 +1580,7 @@ jobs: steps: - name: Setup ROCm uses: ./.github/actions/setup-rocm - - uses: actions/download-artifact@v3 + - uses: actions/download-artifact@v4.1.7 name: Download Build Artifacts with: name: manywheel-py3_10-rocm6_1 @@ -1687,7 +1687,7 @@ jobs: steps: - name: Setup ROCm uses: ./.github/actions/setup-rocm - - uses: actions/download-artifact@v3 + - uses: actions/download-artifact@v4.1.7 name: Download Build Artifacts with: name: manywheel-py3_10-rocm6_2 @@ -1804,7 +1804,7 @@ jobs: - name: Login to Amazon ECR id: login-ecr uses: aws-actions/amazon-ecr-login@v2 - - uses: actions/download-artifact@v3 + - uses: actions/download-artifact@v4.1.7 name: Download Build Artifacts with: name: manywheel-py3_10-xpu @@ -2455,7 +2455,7 @@ jobs: steps: - name: Setup ROCm uses: ./.github/actions/setup-rocm - - uses: actions/download-artifact@v3 + - uses: actions/download-artifact@v4.1.7 name: Download Build Artifacts with: name: manywheel-py3_11-rocm6_1 @@ -2562,7 +2562,7 @@ jobs: steps: - name: Setup ROCm uses: ./.github/actions/setup-rocm - - uses: actions/download-artifact@v3 + - uses: actions/download-artifact@v4.1.7 name: Download Build Artifacts with: name: manywheel-py3_11-rocm6_2 @@ -2679,7 +2679,7 @@ jobs: - name: Login to Amazon ECR id: login-ecr uses: aws-actions/amazon-ecr-login@v2 - - uses: actions/download-artifact@v3 + - uses: actions/download-artifact@v4.1.7 name: Download Build Artifacts with: name: manywheel-py3_11-xpu @@ -3330,7 +3330,7 @@ jobs: steps: - name: Setup ROCm uses: ./.github/actions/setup-rocm - - uses: actions/download-artifact@v3 + - uses: actions/download-artifact@v4.1.7 name: Download Build Artifacts with: name: manywheel-py3_12-rocm6_1 @@ -3437,7 +3437,7 @@ jobs: steps: - name: Setup ROCm uses: ./.github/actions/setup-rocm - - uses: actions/download-artifact@v3 + - uses: actions/download-artifact@v4.1.7 name: Download Build Artifacts with: name: manywheel-py3_12-rocm6_2 @@ -3554,7 +3554,7 @@ jobs: - name: Login to Amazon ECR id: login-ecr uses: aws-actions/amazon-ecr-login@v2 - - uses: actions/download-artifact@v3 + - uses: actions/download-artifact@v4.1.7 name: Download Build Artifacts with: name: manywheel-py3_12-xpu @@ -4215,7 +4215,7 @@ jobs: - name: Login to Amazon ECR id: login-ecr uses: aws-actions/amazon-ecr-login@v2 - - uses: actions/download-artifact@v3 + - uses: actions/download-artifact@v4.1.7 name: Download Build Artifacts with: name: manywheel-py3_13-xpu diff --git a/.github/workflows/generated-windows-binary-conda-nightly.yml b/.github/workflows/generated-windows-binary-conda-nightly.yml index 85b66c8555034c..5d3925fb98eaf9 100644 --- a/.github/workflows/generated-windows-binary-conda-nightly.yml +++ b/.github/workflows/generated-windows-binary-conda-nightly.yml @@ -213,7 +213,7 @@ jobs: echo "BINARY_ENV_FILE=${RUNNER_TEMP}/env" >> "${GITHUB_ENV}" echo "PYTORCH_FINAL_PACKAGE_DIR=${RUNNER_TEMP}/artifacts" >> "${GITHUB_ENV}" echo "WIN_PACKAGE_WORK_DIR=${RUNNER_TEMP}" - - uses: actions/download-artifact@v3 + - uses: actions/download-artifact@v4.1.7 name: Download Build Artifacts with: name: conda-py3_9-cpu @@ -460,7 +460,7 @@ jobs: echo "BINARY_ENV_FILE=${RUNNER_TEMP}/env" >> "${GITHUB_ENV}" echo "PYTORCH_FINAL_PACKAGE_DIR=${RUNNER_TEMP}/artifacts" >> "${GITHUB_ENV}" echo "WIN_PACKAGE_WORK_DIR=${RUNNER_TEMP}" - - uses: actions/download-artifact@v3 + - uses: actions/download-artifact@v4.1.7 name: Download Build Artifacts with: name: conda-py3_9-cuda11_8 @@ -708,7 +708,7 @@ jobs: echo "BINARY_ENV_FILE=${RUNNER_TEMP}/env" >> "${GITHUB_ENV}" echo "PYTORCH_FINAL_PACKAGE_DIR=${RUNNER_TEMP}/artifacts" >> "${GITHUB_ENV}" echo "WIN_PACKAGE_WORK_DIR=${RUNNER_TEMP}" - - uses: actions/download-artifact@v3 + - uses: actions/download-artifact@v4.1.7 name: Download Build Artifacts with: name: conda-py3_9-cuda12_1 @@ -956,7 +956,7 @@ jobs: echo "BINARY_ENV_FILE=${RUNNER_TEMP}/env" >> "${GITHUB_ENV}" echo "PYTORCH_FINAL_PACKAGE_DIR=${RUNNER_TEMP}/artifacts" >> "${GITHUB_ENV}" echo "WIN_PACKAGE_WORK_DIR=${RUNNER_TEMP}" - - uses: actions/download-artifact@v3 + - uses: actions/download-artifact@v4.1.7 name: Download Build Artifacts with: name: conda-py3_9-cuda12_4 @@ -1202,7 +1202,7 @@ jobs: echo "BINARY_ENV_FILE=${RUNNER_TEMP}/env" >> "${GITHUB_ENV}" echo "PYTORCH_FINAL_PACKAGE_DIR=${RUNNER_TEMP}/artifacts" >> "${GITHUB_ENV}" echo "WIN_PACKAGE_WORK_DIR=${RUNNER_TEMP}" - - uses: actions/download-artifact@v3 + - uses: actions/download-artifact@v4.1.7 name: Download Build Artifacts with: name: conda-py3_10-cpu @@ -1449,7 +1449,7 @@ jobs: echo "BINARY_ENV_FILE=${RUNNER_TEMP}/env" >> "${GITHUB_ENV}" echo "PYTORCH_FINAL_PACKAGE_DIR=${RUNNER_TEMP}/artifacts" >> "${GITHUB_ENV}" echo "WIN_PACKAGE_WORK_DIR=${RUNNER_TEMP}" - - uses: actions/download-artifact@v3 + - uses: actions/download-artifact@v4.1.7 name: Download Build Artifacts with: name: conda-py3_10-cuda11_8 @@ -1697,7 +1697,7 @@ jobs: echo "BINARY_ENV_FILE=${RUNNER_TEMP}/env" >> "${GITHUB_ENV}" echo "PYTORCH_FINAL_PACKAGE_DIR=${RUNNER_TEMP}/artifacts" >> "${GITHUB_ENV}" echo "WIN_PACKAGE_WORK_DIR=${RUNNER_TEMP}" - - uses: actions/download-artifact@v3 + - uses: actions/download-artifact@v4.1.7 name: Download Build Artifacts with: name: conda-py3_10-cuda12_1 @@ -1945,7 +1945,7 @@ jobs: echo "BINARY_ENV_FILE=${RUNNER_TEMP}/env" >> "${GITHUB_ENV}" echo "PYTORCH_FINAL_PACKAGE_DIR=${RUNNER_TEMP}/artifacts" >> "${GITHUB_ENV}" echo "WIN_PACKAGE_WORK_DIR=${RUNNER_TEMP}" - - uses: actions/download-artifact@v3 + - uses: actions/download-artifact@v4.1.7 name: Download Build Artifacts with: name: conda-py3_10-cuda12_4 @@ -2191,7 +2191,7 @@ jobs: echo "BINARY_ENV_FILE=${RUNNER_TEMP}/env" >> "${GITHUB_ENV}" echo "PYTORCH_FINAL_PACKAGE_DIR=${RUNNER_TEMP}/artifacts" >> "${GITHUB_ENV}" echo "WIN_PACKAGE_WORK_DIR=${RUNNER_TEMP}" - - uses: actions/download-artifact@v3 + - uses: actions/download-artifact@v4.1.7 name: Download Build Artifacts with: name: conda-py3_11-cpu @@ -2438,7 +2438,7 @@ jobs: echo "BINARY_ENV_FILE=${RUNNER_TEMP}/env" >> "${GITHUB_ENV}" echo "PYTORCH_FINAL_PACKAGE_DIR=${RUNNER_TEMP}/artifacts" >> "${GITHUB_ENV}" echo "WIN_PACKAGE_WORK_DIR=${RUNNER_TEMP}" - - uses: actions/download-artifact@v3 + - uses: actions/download-artifact@v4.1.7 name: Download Build Artifacts with: name: conda-py3_11-cuda11_8 @@ -2686,7 +2686,7 @@ jobs: echo "BINARY_ENV_FILE=${RUNNER_TEMP}/env" >> "${GITHUB_ENV}" echo "PYTORCH_FINAL_PACKAGE_DIR=${RUNNER_TEMP}/artifacts" >> "${GITHUB_ENV}" echo "WIN_PACKAGE_WORK_DIR=${RUNNER_TEMP}" - - uses: actions/download-artifact@v3 + - uses: actions/download-artifact@v4.1.7 name: Download Build Artifacts with: name: conda-py3_11-cuda12_1 @@ -2934,7 +2934,7 @@ jobs: echo "BINARY_ENV_FILE=${RUNNER_TEMP}/env" >> "${GITHUB_ENV}" echo "PYTORCH_FINAL_PACKAGE_DIR=${RUNNER_TEMP}/artifacts" >> "${GITHUB_ENV}" echo "WIN_PACKAGE_WORK_DIR=${RUNNER_TEMP}" - - uses: actions/download-artifact@v3 + - uses: actions/download-artifact@v4.1.7 name: Download Build Artifacts with: name: conda-py3_11-cuda12_4 @@ -3180,7 +3180,7 @@ jobs: echo "BINARY_ENV_FILE=${RUNNER_TEMP}/env" >> "${GITHUB_ENV}" echo "PYTORCH_FINAL_PACKAGE_DIR=${RUNNER_TEMP}/artifacts" >> "${GITHUB_ENV}" echo "WIN_PACKAGE_WORK_DIR=${RUNNER_TEMP}" - - uses: actions/download-artifact@v3 + - uses: actions/download-artifact@v4.1.7 name: Download Build Artifacts with: name: conda-py3_12-cpu @@ -3427,7 +3427,7 @@ jobs: echo "BINARY_ENV_FILE=${RUNNER_TEMP}/env" >> "${GITHUB_ENV}" echo "PYTORCH_FINAL_PACKAGE_DIR=${RUNNER_TEMP}/artifacts" >> "${GITHUB_ENV}" echo "WIN_PACKAGE_WORK_DIR=${RUNNER_TEMP}" - - uses: actions/download-artifact@v3 + - uses: actions/download-artifact@v4.1.7 name: Download Build Artifacts with: name: conda-py3_12-cuda11_8 @@ -3675,7 +3675,7 @@ jobs: echo "BINARY_ENV_FILE=${RUNNER_TEMP}/env" >> "${GITHUB_ENV}" echo "PYTORCH_FINAL_PACKAGE_DIR=${RUNNER_TEMP}/artifacts" >> "${GITHUB_ENV}" echo "WIN_PACKAGE_WORK_DIR=${RUNNER_TEMP}" - - uses: actions/download-artifact@v3 + - uses: actions/download-artifact@v4.1.7 name: Download Build Artifacts with: name: conda-py3_12-cuda12_1 @@ -3923,7 +3923,7 @@ jobs: echo "BINARY_ENV_FILE=${RUNNER_TEMP}/env" >> "${GITHUB_ENV}" echo "PYTORCH_FINAL_PACKAGE_DIR=${RUNNER_TEMP}/artifacts" >> "${GITHUB_ENV}" echo "WIN_PACKAGE_WORK_DIR=${RUNNER_TEMP}" - - uses: actions/download-artifact@v3 + - uses: actions/download-artifact@v4.1.7 name: Download Build Artifacts with: name: conda-py3_12-cuda12_4 diff --git a/.github/workflows/generated-windows-binary-libtorch-debug-main.yml b/.github/workflows/generated-windows-binary-libtorch-debug-main.yml index 6debd8b9a2dfcb..58a10b5c0ad6d6 100644 --- a/.github/workflows/generated-windows-binary-libtorch-debug-main.yml +++ b/.github/workflows/generated-windows-binary-libtorch-debug-main.yml @@ -214,7 +214,7 @@ jobs: echo "BINARY_ENV_FILE=${RUNNER_TEMP}/env" >> "${GITHUB_ENV}" echo "PYTORCH_FINAL_PACKAGE_DIR=${RUNNER_TEMP}/artifacts" >> "${GITHUB_ENV}" echo "WIN_PACKAGE_WORK_DIR=${RUNNER_TEMP}" - - uses: actions/download-artifact@v3 + - uses: actions/download-artifact@v4.1.7 name: Download Build Artifacts with: name: libtorch-cpu-shared-with-deps-debug diff --git a/.github/workflows/generated-windows-binary-libtorch-debug-nightly.yml b/.github/workflows/generated-windows-binary-libtorch-debug-nightly.yml index fef73a900b15ae..b9f75fc865ad17 100644 --- a/.github/workflows/generated-windows-binary-libtorch-debug-nightly.yml +++ b/.github/workflows/generated-windows-binary-libtorch-debug-nightly.yml @@ -221,7 +221,7 @@ jobs: echo "BINARY_ENV_FILE=${RUNNER_TEMP}/env" >> "${GITHUB_ENV}" echo "PYTORCH_FINAL_PACKAGE_DIR=${RUNNER_TEMP}/artifacts" >> "${GITHUB_ENV}" echo "WIN_PACKAGE_WORK_DIR=${RUNNER_TEMP}" - - uses: actions/download-artifact@v3 + - uses: actions/download-artifact@v4.1.7 name: Download Build Artifacts with: name: libtorch-cpu-shared-with-deps-debug @@ -480,7 +480,7 @@ jobs: echo "BINARY_ENV_FILE=${RUNNER_TEMP}/env" >> "${GITHUB_ENV}" echo "PYTORCH_FINAL_PACKAGE_DIR=${RUNNER_TEMP}/artifacts" >> "${GITHUB_ENV}" echo "WIN_PACKAGE_WORK_DIR=${RUNNER_TEMP}" - - uses: actions/download-artifact@v3 + - uses: actions/download-artifact@v4.1.7 name: Download Build Artifacts with: name: libtorch-cuda11_8-shared-with-deps-debug @@ -740,7 +740,7 @@ jobs: echo "BINARY_ENV_FILE=${RUNNER_TEMP}/env" >> "${GITHUB_ENV}" echo "PYTORCH_FINAL_PACKAGE_DIR=${RUNNER_TEMP}/artifacts" >> "${GITHUB_ENV}" echo "WIN_PACKAGE_WORK_DIR=${RUNNER_TEMP}" - - uses: actions/download-artifact@v3 + - uses: actions/download-artifact@v4.1.7 name: Download Build Artifacts with: name: libtorch-cuda12_1-shared-with-deps-debug @@ -1000,7 +1000,7 @@ jobs: echo "BINARY_ENV_FILE=${RUNNER_TEMP}/env" >> "${GITHUB_ENV}" echo "PYTORCH_FINAL_PACKAGE_DIR=${RUNNER_TEMP}/artifacts" >> "${GITHUB_ENV}" echo "WIN_PACKAGE_WORK_DIR=${RUNNER_TEMP}" - - uses: actions/download-artifact@v3 + - uses: actions/download-artifact@v4.1.7 name: Download Build Artifacts with: name: libtorch-cuda12_4-shared-with-deps-debug diff --git a/.github/workflows/generated-windows-binary-libtorch-release-main.yml b/.github/workflows/generated-windows-binary-libtorch-release-main.yml index b4ae0bc0130ae0..2897998822d65d 100644 --- a/.github/workflows/generated-windows-binary-libtorch-release-main.yml +++ b/.github/workflows/generated-windows-binary-libtorch-release-main.yml @@ -214,7 +214,7 @@ jobs: echo "BINARY_ENV_FILE=${RUNNER_TEMP}/env" >> "${GITHUB_ENV}" echo "PYTORCH_FINAL_PACKAGE_DIR=${RUNNER_TEMP}/artifacts" >> "${GITHUB_ENV}" echo "WIN_PACKAGE_WORK_DIR=${RUNNER_TEMP}" - - uses: actions/download-artifact@v3 + - uses: actions/download-artifact@v4.1.7 name: Download Build Artifacts with: name: libtorch-cpu-shared-with-deps-release diff --git a/.github/workflows/generated-windows-binary-libtorch-release-nightly.yml b/.github/workflows/generated-windows-binary-libtorch-release-nightly.yml index c327839e5ec1ae..885992be575a25 100644 --- a/.github/workflows/generated-windows-binary-libtorch-release-nightly.yml +++ b/.github/workflows/generated-windows-binary-libtorch-release-nightly.yml @@ -221,7 +221,7 @@ jobs: echo "BINARY_ENV_FILE=${RUNNER_TEMP}/env" >> "${GITHUB_ENV}" echo "PYTORCH_FINAL_PACKAGE_DIR=${RUNNER_TEMP}/artifacts" >> "${GITHUB_ENV}" echo "WIN_PACKAGE_WORK_DIR=${RUNNER_TEMP}" - - uses: actions/download-artifact@v3 + - uses: actions/download-artifact@v4.1.7 name: Download Build Artifacts with: name: libtorch-cpu-shared-with-deps-release @@ -480,7 +480,7 @@ jobs: echo "BINARY_ENV_FILE=${RUNNER_TEMP}/env" >> "${GITHUB_ENV}" echo "PYTORCH_FINAL_PACKAGE_DIR=${RUNNER_TEMP}/artifacts" >> "${GITHUB_ENV}" echo "WIN_PACKAGE_WORK_DIR=${RUNNER_TEMP}" - - uses: actions/download-artifact@v3 + - uses: actions/download-artifact@v4.1.7 name: Download Build Artifacts with: name: libtorch-cuda11_8-shared-with-deps-release @@ -740,7 +740,7 @@ jobs: echo "BINARY_ENV_FILE=${RUNNER_TEMP}/env" >> "${GITHUB_ENV}" echo "PYTORCH_FINAL_PACKAGE_DIR=${RUNNER_TEMP}/artifacts" >> "${GITHUB_ENV}" echo "WIN_PACKAGE_WORK_DIR=${RUNNER_TEMP}" - - uses: actions/download-artifact@v3 + - uses: actions/download-artifact@v4.1.7 name: Download Build Artifacts with: name: libtorch-cuda12_1-shared-with-deps-release @@ -1000,7 +1000,7 @@ jobs: echo "BINARY_ENV_FILE=${RUNNER_TEMP}/env" >> "${GITHUB_ENV}" echo "PYTORCH_FINAL_PACKAGE_DIR=${RUNNER_TEMP}/artifacts" >> "${GITHUB_ENV}" echo "WIN_PACKAGE_WORK_DIR=${RUNNER_TEMP}" - - uses: actions/download-artifact@v3 + - uses: actions/download-artifact@v4.1.7 name: Download Build Artifacts with: name: libtorch-cuda12_4-shared-with-deps-release diff --git a/.github/workflows/generated-windows-binary-wheel-nightly.yml b/.github/workflows/generated-windows-binary-wheel-nightly.yml index a48c6b938da647..c042b16be7a893 100644 --- a/.github/workflows/generated-windows-binary-wheel-nightly.yml +++ b/.github/workflows/generated-windows-binary-wheel-nightly.yml @@ -214,7 +214,7 @@ jobs: echo "BINARY_ENV_FILE=${RUNNER_TEMP}/env" >> "${GITHUB_ENV}" echo "PYTORCH_FINAL_PACKAGE_DIR=${RUNNER_TEMP}/artifacts" >> "${GITHUB_ENV}" echo "WIN_PACKAGE_WORK_DIR=${RUNNER_TEMP}" - - uses: actions/download-artifact@v3 + - uses: actions/download-artifact@v4.1.7 name: Download Build Artifacts with: name: wheel-py3_9-cpu @@ -462,7 +462,7 @@ jobs: echo "BINARY_ENV_FILE=${RUNNER_TEMP}/env" >> "${GITHUB_ENV}" echo "PYTORCH_FINAL_PACKAGE_DIR=${RUNNER_TEMP}/artifacts" >> "${GITHUB_ENV}" echo "WIN_PACKAGE_WORK_DIR=${RUNNER_TEMP}" - - uses: actions/download-artifact@v3 + - uses: actions/download-artifact@v4.1.7 name: Download Build Artifacts with: name: wheel-py3_9-cuda11_8 @@ -711,7 +711,7 @@ jobs: echo "BINARY_ENV_FILE=${RUNNER_TEMP}/env" >> "${GITHUB_ENV}" echo "PYTORCH_FINAL_PACKAGE_DIR=${RUNNER_TEMP}/artifacts" >> "${GITHUB_ENV}" echo "WIN_PACKAGE_WORK_DIR=${RUNNER_TEMP}" - - uses: actions/download-artifact@v3 + - uses: actions/download-artifact@v4.1.7 name: Download Build Artifacts with: name: wheel-py3_9-cuda12_1 @@ -960,7 +960,7 @@ jobs: echo "BINARY_ENV_FILE=${RUNNER_TEMP}/env" >> "${GITHUB_ENV}" echo "PYTORCH_FINAL_PACKAGE_DIR=${RUNNER_TEMP}/artifacts" >> "${GITHUB_ENV}" echo "WIN_PACKAGE_WORK_DIR=${RUNNER_TEMP}" - - uses: actions/download-artifact@v3 + - uses: actions/download-artifact@v4.1.7 name: Download Build Artifacts with: name: wheel-py3_9-cuda12_4 @@ -1207,7 +1207,7 @@ jobs: echo "BINARY_ENV_FILE=${RUNNER_TEMP}/env" >> "${GITHUB_ENV}" echo "PYTORCH_FINAL_PACKAGE_DIR=${RUNNER_TEMP}/artifacts" >> "${GITHUB_ENV}" echo "WIN_PACKAGE_WORK_DIR=${RUNNER_TEMP}" - - uses: actions/download-artifact@v3 + - uses: actions/download-artifact@v4.1.7 name: Download Build Artifacts with: name: wheel-py3_10-cpu @@ -1455,7 +1455,7 @@ jobs: echo "BINARY_ENV_FILE=${RUNNER_TEMP}/env" >> "${GITHUB_ENV}" echo "PYTORCH_FINAL_PACKAGE_DIR=${RUNNER_TEMP}/artifacts" >> "${GITHUB_ENV}" echo "WIN_PACKAGE_WORK_DIR=${RUNNER_TEMP}" - - uses: actions/download-artifact@v3 + - uses: actions/download-artifact@v4.1.7 name: Download Build Artifacts with: name: wheel-py3_10-cuda11_8 @@ -1704,7 +1704,7 @@ jobs: echo "BINARY_ENV_FILE=${RUNNER_TEMP}/env" >> "${GITHUB_ENV}" echo "PYTORCH_FINAL_PACKAGE_DIR=${RUNNER_TEMP}/artifacts" >> "${GITHUB_ENV}" echo "WIN_PACKAGE_WORK_DIR=${RUNNER_TEMP}" - - uses: actions/download-artifact@v3 + - uses: actions/download-artifact@v4.1.7 name: Download Build Artifacts with: name: wheel-py3_10-cuda12_1 @@ -1953,7 +1953,7 @@ jobs: echo "BINARY_ENV_FILE=${RUNNER_TEMP}/env" >> "${GITHUB_ENV}" echo "PYTORCH_FINAL_PACKAGE_DIR=${RUNNER_TEMP}/artifacts" >> "${GITHUB_ENV}" echo "WIN_PACKAGE_WORK_DIR=${RUNNER_TEMP}" - - uses: actions/download-artifact@v3 + - uses: actions/download-artifact@v4.1.7 name: Download Build Artifacts with: name: wheel-py3_10-cuda12_4 @@ -2200,7 +2200,7 @@ jobs: echo "BINARY_ENV_FILE=${RUNNER_TEMP}/env" >> "${GITHUB_ENV}" echo "PYTORCH_FINAL_PACKAGE_DIR=${RUNNER_TEMP}/artifacts" >> "${GITHUB_ENV}" echo "WIN_PACKAGE_WORK_DIR=${RUNNER_TEMP}" - - uses: actions/download-artifact@v3 + - uses: actions/download-artifact@v4.1.7 name: Download Build Artifacts with: name: wheel-py3_11-cpu @@ -2448,7 +2448,7 @@ jobs: echo "BINARY_ENV_FILE=${RUNNER_TEMP}/env" >> "${GITHUB_ENV}" echo "PYTORCH_FINAL_PACKAGE_DIR=${RUNNER_TEMP}/artifacts" >> "${GITHUB_ENV}" echo "WIN_PACKAGE_WORK_DIR=${RUNNER_TEMP}" - - uses: actions/download-artifact@v3 + - uses: actions/download-artifact@v4.1.7 name: Download Build Artifacts with: name: wheel-py3_11-cuda11_8 @@ -2697,7 +2697,7 @@ jobs: echo "BINARY_ENV_FILE=${RUNNER_TEMP}/env" >> "${GITHUB_ENV}" echo "PYTORCH_FINAL_PACKAGE_DIR=${RUNNER_TEMP}/artifacts" >> "${GITHUB_ENV}" echo "WIN_PACKAGE_WORK_DIR=${RUNNER_TEMP}" - - uses: actions/download-artifact@v3 + - uses: actions/download-artifact@v4.1.7 name: Download Build Artifacts with: name: wheel-py3_11-cuda12_1 @@ -2946,7 +2946,7 @@ jobs: echo "BINARY_ENV_FILE=${RUNNER_TEMP}/env" >> "${GITHUB_ENV}" echo "PYTORCH_FINAL_PACKAGE_DIR=${RUNNER_TEMP}/artifacts" >> "${GITHUB_ENV}" echo "WIN_PACKAGE_WORK_DIR=${RUNNER_TEMP}" - - uses: actions/download-artifact@v3 + - uses: actions/download-artifact@v4.1.7 name: Download Build Artifacts with: name: wheel-py3_11-cuda12_4 @@ -3193,7 +3193,7 @@ jobs: echo "BINARY_ENV_FILE=${RUNNER_TEMP}/env" >> "${GITHUB_ENV}" echo "PYTORCH_FINAL_PACKAGE_DIR=${RUNNER_TEMP}/artifacts" >> "${GITHUB_ENV}" echo "WIN_PACKAGE_WORK_DIR=${RUNNER_TEMP}" - - uses: actions/download-artifact@v3 + - uses: actions/download-artifact@v4.1.7 name: Download Build Artifacts with: name: wheel-py3_12-cpu @@ -3441,7 +3441,7 @@ jobs: echo "BINARY_ENV_FILE=${RUNNER_TEMP}/env" >> "${GITHUB_ENV}" echo "PYTORCH_FINAL_PACKAGE_DIR=${RUNNER_TEMP}/artifacts" >> "${GITHUB_ENV}" echo "WIN_PACKAGE_WORK_DIR=${RUNNER_TEMP}" - - uses: actions/download-artifact@v3 + - uses: actions/download-artifact@v4.1.7 name: Download Build Artifacts with: name: wheel-py3_12-cuda11_8 @@ -3690,7 +3690,7 @@ jobs: echo "BINARY_ENV_FILE=${RUNNER_TEMP}/env" >> "${GITHUB_ENV}" echo "PYTORCH_FINAL_PACKAGE_DIR=${RUNNER_TEMP}/artifacts" >> "${GITHUB_ENV}" echo "WIN_PACKAGE_WORK_DIR=${RUNNER_TEMP}" - - uses: actions/download-artifact@v3 + - uses: actions/download-artifact@v4.1.7 name: Download Build Artifacts with: name: wheel-py3_12-cuda12_1 @@ -3939,7 +3939,7 @@ jobs: echo "BINARY_ENV_FILE=${RUNNER_TEMP}/env" >> "${GITHUB_ENV}" echo "PYTORCH_FINAL_PACKAGE_DIR=${RUNNER_TEMP}/artifacts" >> "${GITHUB_ENV}" echo "WIN_PACKAGE_WORK_DIR=${RUNNER_TEMP}" - - uses: actions/download-artifact@v3 + - uses: actions/download-artifact@v4.1.7 name: Download Build Artifacts with: name: wheel-py3_12-cuda12_4 From 1d60780bad8cb146e6965034d3a18e08db598557 Mon Sep 17 00:00:00 2001 From: Animesh Jain Date: Tue, 3 Sep 2024 16:01:47 -0700 Subject: [PATCH 0382/1018] [dynamo] Use __eq__ for backend match (#135039) Fixes https://github.com/pytorch/pytorch/issues/131150 Pull Request resolved: https://github.com/pytorch/pytorch/pull/135039 Approved by: https://github.com/jansel --- test/dynamo/test_backends.py | 16 ++++++++++++++++ torch/__init__.py | 4 ---- torch/csrc/dynamo/cache_entry.cpp | 8 +++----- torch/csrc/dynamo/extra_state.cpp | 11 +++++++++-- torch/csrc/dynamo/extra_state.h | 2 +- 5 files changed, 29 insertions(+), 12 deletions(-) diff --git a/test/dynamo/test_backends.py b/test/dynamo/test_backends.py index e9d94d4f67be06..bf386bbf164926 100644 --- a/test/dynamo/test_backends.py +++ b/test/dynamo/test_backends.py @@ -330,6 +330,22 @@ def mock_eps(group=None): backends = list_backends() assert name in backends, (name, backends) + def test_backend_recompilation(self): + def fn(x): + return x + x + + input = torch.tensor(2.0) + + opt_fn = torch.compile( + fn, backend="inductor", options={"_raise_error_for_testing": False} + ) + opt_fn(input) + with self.assertRaises(torch._dynamo.exc.BackendCompilerFailed): + opt_fn = torch.compile( + fn, backend="inductor", options={"_raise_error_for_testing": True} + ) + opt_fn(input) + if __name__ == "__main__": from torch._dynamo.test_case import run_tests diff --git a/torch/__init__.py b/torch/__init__.py index 0a0776ce459f4f..d5a4644a85490f 100644 --- a/torch/__init__.py +++ b/torch/__init__.py @@ -2180,10 +2180,6 @@ def __init__(self, mode, options, dynamic): self.apply_mode(mode) self.apply_options(options) - # Stash the compiler_fn to be used for backend match guard. - from torch._inductor.compile_fx import compile_fx - - self.compiler_fn = compile_fx if self.config.get("triton.cudagraphs", False): os.environ["DISABLE_CUPTI_LAZY_REINIT"] = "1" # FIXME: CUDA Graph does not work well with CUPTI teardown. diff --git a/torch/csrc/dynamo/cache_entry.cpp b/torch/csrc/dynamo/cache_entry.cpp index 8a5ab17e57a490..1142187ffc21d1 100644 --- a/torch/csrc/dynamo/cache_entry.cpp +++ b/torch/csrc/dynamo/cache_entry.cpp @@ -4,11 +4,11 @@ #include #include -CacheEntry::CacheEntry(const py::handle& guarded_code, PyObject* backend) { +CacheEntry::CacheEntry(const py::handle& guarded_code, PyObject* backend) + : backend(backend) { this->check_fn = guarded_code.attr("check_fn"); this->code = guarded_code.attr("code"); this->compile_id = guarded_code.attr("compile_id"); - this->backend = backend; // TODO - clean this up when enable_cpp_guard_manager is True by default if (py::hasattr(this->check_fn, "root")) { this->root_mgr = torch::dynamo::convert_to_root_guard_manager( @@ -16,6 +16,7 @@ CacheEntry::CacheEntry(const py::handle& guarded_code, PyObject* backend) { } } +// NOLINTNEXTLINE(bugprone-exception-escape) CacheEntry::~CacheEntry() { // prevent check_fn from use-after-free when invalidating this->check_fn.attr("cache_entry") = py::none(); @@ -48,8 +49,5 @@ PyObject* get_backend(PyObject* callback) { while (py::hasattr(handle, "_torchdynamo_orig_callable")) { handle = handle.attr("_torchdynamo_orig_callable"); } - if (py::hasattr(handle, "compiler_fn")) { - handle = handle.attr("compiler_fn"); - } return handle.ptr(); } diff --git a/torch/csrc/dynamo/extra_state.cpp b/torch/csrc/dynamo/extra_state.cpp index 01d29eab519716..9142470315cd9b 100644 --- a/torch/csrc/dynamo/extra_state.cpp +++ b/torch/csrc/dynamo/extra_state.cpp @@ -86,13 +86,20 @@ ExtraState* init_and_set_extra_state(PyCodeObject* code) { PyObject* lookup( ExtraState* extra_state, PyObject* f_locals, - const PyObject* backend) { + PyObject* backend) { size_t index = 0; CacheEntry* found = nullptr; py::handle locals(f_locals); for (CacheEntry& cache_entry : extra_state->cache_entry_list) { // Check backend. Py_False means run only mode. - bool valid = backend == Py_False || cache_entry.backend == backend; + + // The Py_TYPE check should not be required but there is a pre-existing + // issue where backend is possibly deallocated (or nullptr) and causes + // segfaults. Check test - test_inplace_custom_op_intermediate + bool valid = backend == Py_False || + (Py_TYPE(cache_entry.backend) == Py_TYPE(backend) && + PyObject_RichCompareBool(cache_entry.backend, backend, Py_EQ)); + if (valid) { try { // TODO(anijain2305) - Clean this up when enable_cpp_guard_manager is diff --git a/torch/csrc/dynamo/extra_state.h b/torch/csrc/dynamo/extra_state.h index 48784f9e4a6e00..22997d0e1a252c 100644 --- a/torch/csrc/dynamo/extra_state.h +++ b/torch/csrc/dynamo/extra_state.h @@ -127,7 +127,7 @@ ExtraState* init_and_set_extra_state(PyCodeObject* code); PyObject* lookup( ExtraState* extra_state, PyObject* f_locals, - const PyObject* backend); + PyObject* backend); // Create a new cache entry at extra_state holding on to guarded_code. // Ownership contract From 27c8ec956a3d9d1be0172d0e53d6050c990ce87f Mon Sep 17 00:00:00 2001 From: cyy Date: Wed, 4 Sep 2024 03:47:56 +0000 Subject: [PATCH 0383/1018] Remove Caffe2 code from tool scripts (#134941) Fixes #ISSUE_NUMBER Pull Request resolved: https://github.com/pytorch/pytorch/pull/134941 Approved by: https://github.com/ezyang --- tools/amd_build/build_amd.py | 13 ------------- tools/build_pytorch_libs.py | 6 ------ 2 files changed, 19 deletions(-) diff --git a/tools/amd_build/build_amd.py b/tools/amd_build/build_amd.py index 967f63ae2c8f15..b06799702e6258 100755 --- a/tools/amd_build/build_amd.py +++ b/tools/amd_build/build_amd.py @@ -64,21 +64,10 @@ out_dir = args.output_directory includes = [ - "caffe2/operators/*", - "caffe2/sgd/*", - "caffe2/image/*", - "caffe2/transforms/*", - "caffe2/video/*", - "caffe2/distributed/*", - "caffe2/queue/*", - "caffe2/contrib/aten/*", "binaries/*", "caffe2/**/*_test*", "caffe2/core/*", - "caffe2/db/*", "caffe2/utils/*", - "caffe2/contrib/gloo/*", - "caffe2/contrib/nccl/*", "c10/cuda/*", "c10/cuda/test/CMakeLists.txt", "modules/*", @@ -118,8 +107,6 @@ includes.append(abs_new_dir) ignores = [ - "caffe2/operators/depthwise_3x3_conv_op_cudnn.cu", - "caffe2/operators/pool_op_cudnn.cu", "*/hip/*", # These files are compatible with both cuda and hip "aten/src/ATen/core/*", diff --git a/tools/build_pytorch_libs.py b/tools/build_pytorch_libs.py index 64e7132a691017..3c5893be97e7e9 100644 --- a/tools/build_pytorch_libs.py +++ b/tools/build_pytorch_libs.py @@ -2,7 +2,6 @@ import os import platform -import shutil from glob import glob from setuptools import distutils # type: ignore[import] @@ -87,8 +86,3 @@ def build_caffe2( if cmake_only: return cmake.build(my_env) - if build_python: - caffe2_proto_dir = os.path.join(cmake.build_dir, "caffe2", "proto") - for proto_file in glob(os.path.join(caffe2_proto_dir, "*.py")): - if proto_file != os.path.join(caffe2_proto_dir, "__init__.py"): - shutil.copy(proto_file, os.path.join("caffe2", "proto")) From 591aab9a7103e0c42a43df1d7f18a5a7f59c0714 Mon Sep 17 00:00:00 2001 From: CK Luk Date: Wed, 4 Sep 2024 04:50:32 +0000 Subject: [PATCH 0384/1018] Back out "[FSDP2] Set `ctx.set_materialize_grads(False)` for post-backward (#133498)" (#135059) Summary: Original commit changeset: 96513cbc425f Original Phabricator Diff: D61291210 There is some evidence that FB-FM-v4 has better NE with Set ctx.set_materialize_grads(False), especially when pairing up with prefetching. See https://www.internalfb.com/intern/anp/view/?id=5732259 Test Plan: export NUM_WORKERS=128 export BATCH_SIZE=1024 export CONFIG_FILE="mast_joint_arch_exploration_cmf_updated_fbfm_v3_fsdp2.yaml" export ENTITLEMENT=ads_global_tc_2k_training_large_short buck2 run mode/opt //aps_models/ads/icvr:icvr_launcher -c fbcode.platform010_cuda_version=12 -c hpc_comms.use_nccl=2.17.1 -- mode=${CONFIG_FILE} launcher.tags='[ads_ranking_taxonomy_monetization_genai]' launcher.data_project=pytorch_at_scale launcher.max_retries=10 launcher.fbl_entitl ement=${ENTITLEMENT} launcher.oncall=pytorch_training_enablement launcher.hardware=GRANDTETON launcher.num_workers=${NUM_WORKERS} data_loader.dataset.batch_size=${BATCH_SIZE} training.planner.proposer=dynamic_col_dim training.planner.proposer.optim_target=h bm 2>&1| tee ~/tmp/log.mast Differential Revision: D62009163 Pull Request resolved: https://github.com/pytorch/pytorch/pull/135059 Approved by: https://github.com/awgu --- torch/distributed/_composable/fsdp/_fsdp_param_group.py | 1 - 1 file changed, 1 deletion(-) diff --git a/torch/distributed/_composable/fsdp/_fsdp_param_group.py b/torch/distributed/_composable/fsdp/_fsdp_param_group.py index 32f65998d6ce4c..3c2f639607e14f 100644 --- a/torch/distributed/_composable/fsdp/_fsdp_param_group.py +++ b/torch/distributed/_composable/fsdp/_fsdp_param_group.py @@ -606,7 +606,6 @@ class RegisterPostBackwardFunction(torch.autograd.Function): def forward(ctx, param_group: FSDPParamGroup, *inputs: torch.Tensor): # All tensors in `inputs` should require gradient ctx.param_group = param_group - ctx.set_materialize_grads(False) return inputs @staticmethod From 9e0cdc27cd49fe2b0725207285671a3e893ed0e2 Mon Sep 17 00:00:00 2001 From: CaoE Date: Sun, 1 Sep 2024 03:57:22 -0700 Subject: [PATCH 0385/1018] Add specializations for VecMaskLoad and VecMaskCast (#126501) Pull Request resolved: https://github.com/pytorch/pytorch/pull/126501 Approved by: https://github.com/jgong5, https://github.com/jansel ghstack dependencies: #126500 --- aten/src/ATen/cpu/vec/vec256/vec256_mask.h | 249 ++++++++++++++-- aten/src/ATen/cpu/vec/vec512/vec512_mask.h | 326 ++++++++++++++++++--- aten/src/ATen/cpu/vec/vec_mask.h | 40 ++- aten/src/ATen/cpu/vec/vec_n.h | 3 + aten/src/ATen/test/vec_test_all_types.cpp | 159 ++++++---- aten/src/ATen/test/vec_test_all_types.h | 26 +- 6 files changed, 659 insertions(+), 144 deletions(-) diff --git a/aten/src/ATen/cpu/vec/vec256/vec256_mask.h b/aten/src/ATen/cpu/vec/vec256/vec256_mask.h index dd6a8c52d82655..3460abe17e159d 100644 --- a/aten/src/ATen/cpu/vec/vec256/vec256_mask.h +++ b/aten/src/ATen/cpu/vec/vec256/vec256_mask.h @@ -9,49 +9,218 @@ inline namespace CPU_CAPABILITY { #if defined(CPU_CAPABILITY_AVX2) && !defined(_MSC_VER) +template +struct VecMaskLoad< + T, + dst_n, + mask_t, + mask_n, + typename std::enable_if_t< + (mask_n == dst_n * 2 && dst_n >= 1) && + (std::is_same_v || std::is_same_v), + void>> { + static inline VectorizedN apply( + const T* ptr, + const VecMask& vec_mask) { + VectorizedN tmp_vec; + VectorizedN result; + for (int i = 0; i < dst_n; i++) { + tmp_vec[0] = vec_mask[2 * i]; + tmp_vec[1] = vec_mask[2 * i + 1]; + auto int64_mask = VecMask(tmp_vec).template cast(); + auto int_mask = int64_mask.template cast()[0]; + if constexpr (std::is_same_v) { + result[i] = Vectorized( + _mm256_maskload_ps(ptr + i * Vectorized::size(), int_mask)); + } else { + result[i] = Vectorized( + _mm256_maskload_epi32(ptr + i * Vectorized::size(), int_mask)); + } + } + return result; + } +}; + +template +struct VecMaskLoad< + T, + dst_n, + mask_t, + dst_n, + typename std::enable_if_t< + std::is_same_v || std::is_same_v, + void>> { + static inline VectorizedN apply( + const T* ptr, + const VecMask& vec_mask) { + VectorizedN result; +#ifndef _MSC_VER +#pragma unroll +#endif + for (int i = 0; i < dst_n; i++) { + auto tmp_mask = VecMask(vec_mask[i]); + auto int_mask = tmp_mask.template cast()[0]; + if constexpr (std::is_same_v) { + result[i] = Vectorized( + _mm256_maskload_ps(ptr + i * Vectorized::size(), int_mask)); + } else { + result[i] = Vectorized( + _mm256_maskload_epi32(ptr + i * Vectorized::size(), int_mask)); + } + } + return result; + } +}; + template struct VecMaskLoad< T, - 1, + 2, mask_t, 1, typename std::enable_if_t< - std::is_same_v || std::is_same_v || - std::is_same_v, - void>> { - static inline VectorizedN apply( + std::is_same_v || std::is_same_v>> { + static inline VectorizedN apply( const T* ptr, const VecMask& vec_mask) { - auto int_mask = vec_mask.template cast()[0]; - if constexpr (std::is_same_v) { - return Vectorized(_mm256_maskload_ps(ptr, int_mask)); + auto int64_mask = vec_mask.template cast(); + auto result = at::vec::VectorizedN(); + if constexpr (std::is_same_v) { + result[0] = _mm256_maskload_pd(ptr, int64_mask[0]); + result[1] = _mm256_maskload_pd( + ptr + at::vec::Vectorized::size(), int64_mask[1]); } else { - return Vectorized(_mm256_maskload_epi32(ptr, int_mask)); + result[0] = _mm256_maskload_epi64( + reinterpret_cast(ptr), int64_mask[0]); + result[1] = _mm256_maskload_epi64( + reinterpret_cast( + ptr + at::vec::Vectorized::size()), + int64_mask[1]); } + return result; } }; // TODO: add specialization of VecMaskLoad for bfloat16/half and int8/uint8 -template <> -struct VecMaskCast { - static inline VecMask apply(const VecMask& vec_mask) { - return Vectorized(_mm256_castsi256_ps(vec_mask[0])); +template +struct VecMaskCast { + static inline VecMask apply(const VecMask& vec_mask) { + VectorizedN result; +#ifndef _MSC_VER +#pragma unroll +#endif + for (int i = 0; i < N; ++i) { + result[i] = _mm256_castsi256_ps(vec_mask[i]); + } + return result; } }; -template <> -struct VecMaskCast { - static inline VecMask apply(const VecMask& vec_mask) { - return Vectorized(_mm256_castps_si256(vec_mask[0])); +template +struct VecMaskCast { + static inline VecMask apply(const VecMask& vec_mask) { + VectorizedN result; +#ifndef _MSC_VER +#pragma unroll +#endif + for (int i = 0; i < N; ++i) { + result[i] = _mm256_castps_si256(vec_mask[i]); + } + return result; } }; -template -struct VecMaskCast { - static inline VecMask apply(const VecMask& vec_mask) { - auto int_vec = convert(VectorizedN(vec_mask)); - return VecMask(int_vec).cast(); +template +struct VecMaskCast { + static inline VecMask apply(const VecMask& vec_mask) { + VectorizedN result; +#ifndef _MSC_VER +#pragma unroll +#endif + for (int i = 0; i < N; ++i) { + result[i] = _mm256_castpd_si256(vec_mask[i]); + } + return result; + } +}; + +template +struct VecMaskCast { + static inline VecMask apply(const VecMask& vec_mask) { + VectorizedN result; +#ifndef _MSC_VER +#pragma unroll +#endif + for (int i = 0; i < N; ++i) { + result[i] = _mm256_castsi256_pd(vec_mask[i]); + } + return result; + } +}; + +template +struct VecMaskCast< + int64_t, + dst_n, + mask_t, + mask_n, + typename std::enable_if_t< + (dst_n == 2 * mask_n) && + (std::is_same_v || std::is_same_v), + void>> { + static inline VecMask apply( + const VecMask& vec_mask) { + VectorizedN result; + auto int_mask = vec_mask.template cast(); +#ifndef _MSC_VER +#pragma unroll +#endif + for (int i = 0; i < mask_n; ++i) { + auto int64_vec = + convert(VectorizedN(int_mask[i])); + result[2 * i] = int64_vec[0]; + result[2 * i + 1] = int64_vec[1]; + } + return VecMask(result); + } +}; + +template +struct VecMaskCast< + dst_t, + dst_n, + int64_t, + mask_n, + typename std::enable_if_t< + (mask_n == 2 * dst_n) && + (std::is_same_v || std::is_same_v), + void>> { + static inline VecMask apply( + const VecMask& vec_mask) { + VectorizedN result; + VectorizedN int64_vec; + for (int i = 0; i < dst_n; ++i) { + int64_vec[0] = vec_mask[2 * i]; + int64_vec[1] = vec_mask[2 * i + 1]; + result[i] = convert(int64_vec); + } + return VecMask(result).template cast(); + } +}; + +template <> +struct VecMaskCast { + static inline VecMask apply(const VecMask& vec_mask) { + auto int64_mask = VecMaskCast::apply(vec_mask); + return VecMaskCast::apply(int64_mask); + } +}; +template <> +struct VecMaskCast { + static inline VecMask apply(const VecMask& vec_mask) { + auto int64_mask = VecMaskCast::apply(vec_mask); + return VecMaskCast::apply(int64_mask); } }; @@ -71,6 +240,42 @@ inline bool VecMask::all_masked() const { return mask == 0xff; } +template +struct VecMaskCheck { + static inline bool all_zero(const VectorizedN& vec_mask) { + bool all_zero = true; + for (int i = 0; i < N; ++i) { + all_zero = all_zero && (_mm256_testz_si256(vec_mask[i], vec_mask[i]) > 0); + if (!all_zero) { + return all_zero; + } + } + return all_zero; + } + + static inline bool is_masked(const VectorizedN& vec_mask, int i) { + for (int j = 0; j < N; ++j) { + if (i < (j + 1) * 4) { + return _mm256_movemask_pd(_mm256_castsi256_pd(vec_mask[j])) & + (1 << (i - j * 4)); + } + } + return false; + } + + static inline bool all_masked(const VectorizedN& vec_mask) { + bool all_masked = true; + for (int i = 0; i < N; ++i) { + all_masked = all_masked && + (_mm256_movemask_pd(_mm256_castsi256_pd(vec_mask[i])) == 0x0f); + if (!all_masked) { + return all_masked; + } + } + return all_masked; + } +}; + #define VEC_MASK_METHOD_WITH_CAST_TO_INT( \ T, N, return_type, method, args_def, args) \ template <> \ diff --git a/aten/src/ATen/cpu/vec/vec512/vec512_mask.h b/aten/src/ATen/cpu/vec/vec512/vec512_mask.h index 9ba1b18372eb54..cdb433af252541 100644 --- a/aten/src/ATen/cpu/vec/vec512/vec512_mask.h +++ b/aten/src/ATen/cpu/vec/vec512/vec512_mask.h @@ -9,50 +9,139 @@ inline namespace CPU_CAPABILITY { #if defined(CPU_CAPABILITY_AVX512) && !defined(_MSC_VER) -template +template struct VecMaskLoad< T, - 1, + dst_n, mask_t, - 1, + mask_n, typename std::enable_if_t< - std::is_same_v || std::is_same_v || - std::is_same_v, + (mask_n == dst_n * 2 && dst_n >= 1) && + (std::is_same_v || std::is_same_v), void>> { - static inline VectorizedN apply( + static inline VectorizedN apply( const T* ptr, - const VecMask& vec_mask) { + const VecMask& vec_mask) { at::vec::Vectorized zero_vec(0); auto all_ones = _mm512_set1_epi32(0xFFFFFFFF); - auto int_mask = vec_mask.template cast()[0]; - auto mmask = _mm512_cmp_epi32_mask(int_mask, all_ones, _MM_CMPINT_EQ); - if constexpr (std::is_same_v) { - return Vectorized(_mm512_mask_loadu_ps(zero_vec, mmask, ptr)); - } else { - return Vectorized(_mm512_mask_loadu_epi32(zero_vec, mmask, ptr)); + VectorizedN tmp_vec; + VectorizedN result; + for (int i = 0; i < dst_n; i++) { + tmp_vec[0] = vec_mask[2 * i]; + tmp_vec[1] = vec_mask[2 * i + 1]; + auto int64_mask = VecMask(tmp_vec).template cast(); + auto int_mask = int64_mask.template cast()[0]; + auto mmask = _mm512_cmp_epi32_mask(int_mask, all_ones, _MM_CMPINT_EQ); + if constexpr (std::is_same_v) { + result[i] = Vectorized(_mm512_mask_loadu_ps( + zero_vec, mmask, ptr + i * Vectorized::size())); + } else { + result[i] = Vectorized(_mm512_mask_loadu_epi32( + zero_vec, mmask, ptr + i * Vectorized::size())); + } } + return result; } }; -template +template +struct VecMaskLoad< + T, + dst_n, + mask_t, + dst_n, + typename std::enable_if_t< + std::is_same_v || std::is_same_v, + void>> { + static inline VectorizedN apply( + const T* ptr, + const VecMask& vec_mask) { + at::vec::Vectorized zero_vec(0); + auto all_ones = _mm512_set1_epi32(0xFFFFFFFF); + VectorizedN result; +#ifndef _MSC_VER +#pragma unroll +#endif + for (int i = 0; i < dst_n; i++) { + auto tmp_mask = VecMask(vec_mask[i]); + auto int_mask = tmp_mask.template cast()[0]; + auto mmask = _mm512_cmp_epi32_mask(int_mask, all_ones, _MM_CMPINT_EQ); + if constexpr (std::is_same_v) { + result[i] = Vectorized(_mm512_mask_loadu_ps( + zero_vec, mmask, ptr + i * Vectorized::size())); + } else { + result[i] = Vectorized(_mm512_mask_loadu_epi32( + zero_vec, mmask, ptr + i * Vectorized::size())); + } + } + return result; + } +}; + +template struct VecMaskLoad< data_t, - 1, + dst_n, mask_t, - 1, + dst_n, typename std::enable_if< std::is_same_v || std::is_same_v>::type> { - static inline VectorizedN apply( + static inline VectorizedN apply( const data_t* ptr, - const VecMask& vec_mask) { + const VecMask& vec_mask) { auto all_ones = _mm512_set1_epi32(0xFFFFFFFF); - auto int_mask = vec_mask.template cast()[0]; - auto mmask = _mm512_cmp_epi32_mask(int_mask, all_ones, _MM_CMPINT_EQ); - auto zero = _mm256_set1_epi16(0); - auto temp = _mm256_mask_loadu_epi16(zero, mmask, ptr); - return Vectorized( - _mm512_inserti32x8(_mm512_castsi256_si512(temp), zero, 1)); + VectorizedN result; +#ifndef _MSC_VER +#pragma unroll +#endif + for (int i = 0; i < dst_n; i++) { + auto tmp_mask = VecMask(vec_mask[i]); + auto int_mask = tmp_mask.template cast(); + auto mmask0 = _mm512_cmp_epi32_mask(int_mask[0], all_ones, _MM_CMPINT_EQ); + auto mmask1 = _mm512_cmp_epi32_mask(int_mask[1], all_ones, _MM_CMPINT_EQ); + auto zero = _mm256_set1_epi16(0); + auto temp0 = _mm256_mask_loadu_epi16( + zero, mmask0, ptr + (2 * i) * Vectorized::size()); + auto temp1 = _mm256_mask_loadu_epi16( + zero, mmask1, ptr + (2 * i + 1) * Vectorized::size()); + result[i] = Vectorized( + _mm512_inserti32x8(_mm512_castsi256_si512(temp0), temp1, 1)); + } + return result; + } +}; + +template +struct VecMaskLoad< + data_t, + dst_n, + mask_t, + mask_n, + typename std::enable_if_t< + (mask_n == 2 * dst_n && dst_n >= 1) && + (std::is_same_v || std::is_same_v)>> { + static inline VectorizedN apply( + const data_t* ptr, + const VecMask& vec_mask) { + auto all_ones = _mm512_set1_epi32(0xFFFFFFFF); + VectorizedN result; + VectorizedN tmp_vec; + for (int i = 0; i < dst_n; i++) { + tmp_vec[0] = vec_mask[2 * i]; + tmp_vec[1] = vec_mask[2 * i + 1]; + auto int_mask = VecMask(tmp_vec).template cast(); + auto mmask0 = _mm512_cmp_epi32_mask(int_mask[0], all_ones, _MM_CMPINT_EQ); + auto mmask1 = _mm512_cmp_epi32_mask(int_mask[1], all_ones, _MM_CMPINT_EQ); + auto zero = _mm256_set1_epi16(0); + auto temp0 = _mm256_mask_loadu_epi16( + zero, mmask0, ptr + (2 * i) * Vectorized::size()); + auto temp1 = _mm256_mask_loadu_epi16( + zero, mmask1, ptr + (2 * i + 1) * Vectorized::size()); + result[i] = Vectorized( + _mm512_inserti32x8(_mm512_castsi256_si512(temp0), temp1, 1)); + } + return result; } }; @@ -78,41 +167,155 @@ struct VecMaskLoad< } }; -template -struct VecMaskLoad { - static inline VectorizedN apply( - const int64_t* ptr, +template +struct VecMaskLoad< + data_t, + 2, + mask_t, + 1, + typename std::enable_if< + std::is_same_v || + std::is_same_v>::type> { + static inline VectorizedN apply( + const data_t* ptr, const VecMask& vec_mask) { auto all_ones = _mm512_set1_epi32(0xFFFFFFFF); - auto zero = _mm512_set1_epi64(0); + at::vec::Vectorized zero_vec(0); auto int_mask = vec_mask.template cast()[0]; auto mmask = _mm512_cmp_epi32_mask(int_mask, all_ones, _MM_CMPINT_EQ); - at::vec::VectorizedN result; - result[0] = _mm512_mask_loadu_epi64(zero, (__mmask8)mmask, ptr); - result[1] = _mm512_mask_loadu_epi64(zero, (__mmask8)(mmask >> 8), ptr + 8); + at::vec::VectorizedN result; + if constexpr (std::is_same_v) { + result[0] = _mm512_mask_loadu_pd(zero_vec, (__mmask8)mmask, ptr); + result[1] = + _mm512_mask_loadu_pd(zero_vec, (__mmask8)(mmask >> 8), ptr + 8); + } else { + result[0] = _mm512_mask_loadu_epi64(zero_vec, (__mmask8)mmask, ptr); + result[1] = + _mm512_mask_loadu_epi64(zero_vec, (__mmask8)(mmask >> 8), ptr + 8); + } return result; } }; -template <> -struct VecMaskCast { - static inline VecMask apply(const VecMask& vec_mask) { - return Vectorized(_mm512_castsi512_ps(vec_mask[0])); +template +struct VecMaskCast { + static inline VecMask apply(const VecMask& vec_mask) { + VectorizedN result; +#ifndef _MSC_VER +#pragma unroll +#endif + for (int i = 0; i < N; ++i) { + result[i] = _mm512_castsi512_ps(vec_mask[i]); + } + return result; + } +}; + +template +struct VecMaskCast { + static inline VecMask apply(const VecMask& vec_mask) { + VectorizedN result; +#ifndef _MSC_VER +#pragma unroll +#endif + for (int i = 0; i < N; ++i) { + result[i] = _mm512_castps_si512(vec_mask[i]); + } + return result; + } +}; + +template +struct VecMaskCast { + static inline VecMask apply(const VecMask& vec_mask) { + VectorizedN result; +#ifndef _MSC_VER +#pragma unroll +#endif + for (int i = 0; i < N; ++i) { + result[i] = _mm512_castpd_si512(vec_mask[i]); + } + return result; + } +}; + +template +struct VecMaskCast { + static inline VecMask apply(const VecMask& vec_mask) { + VectorizedN result; +#ifndef _MSC_VER +#pragma unroll +#endif + for (int i = 0; i < N; ++i) { + result[i] = _mm512_castsi512_pd(vec_mask[i]); + } + return result; + } +}; + +template +struct VecMaskCast< + int64_t, + dst_n, + mask_t, + mask_n, + typename std::enable_if_t< + (dst_n == 2 * mask_n) && + (std::is_same_v || std::is_same_v), + void>> { + static inline VecMask apply( + const VecMask& vec_mask) { + VectorizedN result; + auto int_mask = vec_mask.template cast(); +#ifndef _MSC_VER +#pragma unroll +#endif + for (int i = 0; i < mask_n; ++i) { + auto int64_vec = + convert(VectorizedN(int_mask[i])); + result[2 * i] = int64_vec[0]; + result[2 * i + 1] = int64_vec[1]; + } + return VecMask(result); + } +}; + +template +struct VecMaskCast< + dst_t, + dst_n, + int64_t, + mask_n, + typename std::enable_if_t< + (mask_n == 2 * dst_n) && + (std::is_same_v || std::is_same_v), + void>> { + static inline VecMask apply( + const VecMask& vec_mask) { + VectorizedN result; + VectorizedN int64_vec; + for (int i = 0; i < dst_n; ++i) { + int64_vec[0] = vec_mask[2 * i]; + int64_vec[1] = vec_mask[2 * i + 1]; + result[i] = convert(int64_vec); + } + return VecMask(result).template cast(); } }; template <> -struct VecMaskCast { - static inline VecMask apply(const VecMask& vec_mask) { - return Vectorized(_mm512_castps_si512(vec_mask[0])); +struct VecMaskCast { + static inline VecMask apply(const VecMask& vec_mask) { + auto int64_mask = VecMaskCast::apply(vec_mask); + return VecMaskCast::apply(int64_mask); } }; -template -struct VecMaskCast { - static inline VecMask apply(const VecMask& vec_mask) { - auto int_vec = convert(VectorizedN(vec_mask)); - return VecMask(int_vec).cast(); +template <> +struct VecMaskCast { + static inline VecMask apply(const VecMask& vec_mask) { + auto int64_mask = VecMaskCast::apply(vec_mask); + return VecMaskCast::apply(int64_mask); } }; @@ -133,6 +336,41 @@ inline bool VecMask::all_masked() const { return mask == 0xffff; } +template +struct VecMaskCheck { + static inline bool all_zero(const VectorizedN& vec_mask) { + bool all_zero = true; + for (int i = 0; i < N; ++i) { + all_zero = + all_zero && (_mm512_test_epi64_mask(vec_mask[i], vec_mask[i]) == 0); + if (!all_zero) { + return all_zero; + } + } + return all_zero; + } + + static inline bool is_masked(const VectorizedN& vec_mask, int i) { + for (int j = 0; j < N; ++j) { + if (i < (j + 1) * 8) { + return _mm512_movepi64_mask(vec_mask[j]) & (1 << (i - j * 8)); + } + } + return false; + } + + static inline bool all_masked(const VectorizedN& vec_mask) { + bool all_masked = true; + for (int i = 0; i < N; ++i) { + all_masked = all_masked && (_mm512_movepi64_mask(vec_mask[i]) == 0xff); + if (!all_masked) { + return all_masked; + } + } + return all_masked; + } +}; + #define VEC_MASK_METHOD_WITH_CAST_TO_INT( \ T, N, return_type, method, args_def, args) \ template <> \ diff --git a/aten/src/ATen/cpu/vec/vec_mask.h b/aten/src/ATen/cpu/vec/vec_mask.h index ebec8d4a3e3c5d..2783417845492c 100644 --- a/aten/src/ATen/cpu/vec/vec_mask.h +++ b/aten/src/ATen/cpu/vec/vec_mask.h @@ -2,7 +2,6 @@ #include #include - namespace at::vec { inline namespace CPU_CAPABILITY { @@ -69,7 +68,7 @@ struct VecMaskTo { } }; -template +template struct VecMaskCast { static inline VecMask apply( const VecMask& vec_mask) { @@ -84,6 +83,29 @@ struct VecMaskCast { } }; +template +struct VecMaskCheck { + static inline bool all_zero(const VectorizedN& vec_mask) { + __at_align__ T mask[VectorizedN::size()]; + vec_mask.store(mask); + return std::all_of( + mask, mask + VectorizedN::size(), [](T m) { return m == static_cast(0); }); + } + + static inline bool all_masked(const VectorizedN& vec_mask) { + __at_align__ T mask[VectorizedN::size()]; + vec_mask.store(mask); + return std::all_of( + mask, mask + VectorizedN::size(), [](T m) { return m != static_cast(0); }); + } + + static inline bool is_masked(const VectorizedN& vec_mask, int i) { + __at_align__ T mask[VectorizedN::size()]; + vec_mask.store(mask); + return mask[i] != static_cast(0); + } +}; + template class VecMask { public: @@ -170,23 +192,15 @@ class VecMask { } inline bool all_zero() const { - __at_align__ T mask[size()]; - mask_.store(mask); - return std::all_of( - mask, mask + size(), [](T m) { return m == static_cast(0); }); + return VecMaskCheck::all_zero(mask_); } inline bool all_masked() const { - __at_align__ T mask[size()]; - mask_.store(mask); - return std::all_of( - mask, mask + size(), [](T m) { return m != static_cast(0); }); + return VecMaskCheck::all_masked(mask_); } inline bool is_masked(int i) const { - __at_align__ T mask[size()]; - mask_.store(mask); - return mask[i] != static_cast(0); + return VecMaskCheck::is_masked(mask_, i); } inline operator VectorizedN() const { diff --git a/aten/src/ATen/cpu/vec/vec_n.h b/aten/src/ATen/cpu/vec/vec_n.h index 2c0f2eef4b7ec6..8c4e622682a285 100644 --- a/aten/src/ATen/cpu/vec/vec_n.h +++ b/aten/src/ATen/cpu/vec/vec_n.h @@ -88,6 +88,9 @@ class VectorizedN { template = 0> VectorizedN(const Vectorized& val) : values({val}) {} + template = 0> + VectorizedN(const Vectorized& val_0, const Vectorized& val_1) : values({val_0, val_1}) {} + template = 0> inline operator Vectorized() const { return values[0]; diff --git a/aten/src/ATen/test/vec_test_all_types.cpp b/aten/src/ATen/test/vec_test_all_types.cpp index f9a0557f8bdfff..de1f70e97ca5a8 100644 --- a/aten/src/ATen/test/vec_test_all_types.cpp +++ b/aten/src/ATen/test/vec_test_all_types.cpp @@ -1660,40 +1660,84 @@ namespace { } TYPED_TEST(VecMaskTests, MaskedLoad) { using vec = TypeParam; - using VT = ValueType; - constexpr auto N = vec::size(); - CACHE_ALIGN VT x[N]; - CACHE_ALIGN VT y[N]; - CACHE_ALIGN VT ref[N]; - auto seed = TestSeed(); - ValueGen generator(VT(-100), VT(100), seed); - for (const auto i : c10::irange(N)) { - x[i] = generator.get(); - } - auto vec_mask = generate_vec_mask(seed); - auto x_vec = vec_mask.template loadu(x); - x_vec.store(y); - for (const auto i : c10::irange(N)) { - if (vec_mask.is_masked(i)) { - ref[i] = x[i]; - } else { - ref[i] = 0; - } - } - for (const auto i : c10::irange(N)) { - ASSERT_EQ(y[i], ref[i]) - << "Failure Details:\nTest Seed to reproduce: " << seed; - } + using src_t = ValueType; + constexpr auto size = vec::size(); + + #define TEST_MASK_LOAD(dst_t, mask_t, mask_n) \ + do { \ + CACHE_ALIGN dst_t x[mask_n * size]; \ + CACHE_ALIGN dst_t y[mask_n * size]; \ + CACHE_ALIGN dst_t ref[mask_n * size]; \ + auto seed = TestSeed(); \ + ValueGen generator(dst_t(-100), dst_t(100), seed); \ + for (const auto i : c10::irange(mask_n * size)) { \ + x[i] = generator.get(); \ + } \ + auto vec_mask = generate_vec_mask(seed); \ + constexpr int dst_size = at::vec::Vectorized::size(); \ + constexpr int dst_n = mask_n * size / dst_size; \ + constexpr int rnd_n = (mask_n * size + dst_size - 1) / dst_size; \ + if constexpr(dst_n * dst_size >= mask_n * size) { \ + auto x_vec = vec_mask.template loadu(x); \ + x_vec.store(y); \ + for (const auto i : c10::irange(mask_n * size)) { \ + if (vec_mask.is_masked(i)) { \ + ref[i] = x[i]; \ + } else { \ + ref[i] = 0; \ + } \ + } \ + for (const auto i : c10::irange(mask_n * size)) { \ + ASSERT_EQ(y[i], ref[i]) \ + << "Failure Details:\nTest Seed to reproduce: " << seed; \ + } \ + } \ + } while (0) + + + #define TEST_MASK_LOAD_N(N) \ + TEST_MASK_LOAD(int8_t, src_t, N); \ + TEST_MASK_LOAD(uint8_t, src_t, N); \ + TEST_MASK_LOAD(int16_t, src_t, N); \ + TEST_MASK_LOAD(uint16_t, src_t, N); \ + TEST_MASK_LOAD(int32_t, src_t, N); \ + TEST_MASK_LOAD(uint32_t, src_t, N); \ + TEST_MASK_LOAD(int64_t, src_t, N); \ + TEST_MASK_LOAD(uint64_t, src_t, N); \ + TEST_MASK_LOAD(c10::BFloat16, src_t, N); \ + TEST_MASK_LOAD(c10::Half, src_t, N); \ + TEST_MASK_LOAD(float, src_t, N); \ + TEST_MASK_LOAD(double, src_t, N); + + TEST_MASK_LOAD_N(1) + TEST_MASK_LOAD_N(2) + TEST_MASK_LOAD_N(4) + + #undef TEST_MASK_LOAD + #undef TEST_MASK_LOAD_N } TYPED_TEST(VecMaskTests, MaskedCheck) { using VT = ValueType; - auto vec_mask = create_vec_mask(0); - ASSERT_TRUE(vec_mask.all_zero()) << "all_zero check failed"; - vec_mask = create_vec_mask(-1); - ASSERT_TRUE(vec_mask.all_masked()) << "all_masked check failed"; - vec_mask = create_vec_mask(2); - ASSERT_TRUE(vec_mask.is_masked(1)) << "is_masked(1) check failed"; - ASSERT_TRUE(!vec_mask.is_masked(0)) << "!is_masked(0) check failed"; + using vec = TypeParam; + constexpr auto size = vec::size(); + #define TEST_MASK_CHECK_N(N) \ + do { \ + auto vec_mask = create_vec_mask(0); \ + ASSERT_TRUE(vec_mask.all_zero()) << "all_zero check failed"; \ + vec_mask = create_vec_mask(-1); \ + ASSERT_TRUE(vec_mask.all_masked()) << "all_masked check failed"; \ + vec_mask = create_vec_mask(2); \ + for (int i = 0; i < N; i ++) { \ + ASSERT_TRUE(vec_mask.is_masked(1 + i * size)) << "is_masked(1) check failed"; \ + ASSERT_TRUE(!vec_mask.is_masked(0 + i * size)) << "!is_masked(0) check failed"; \ + } \ + } while (0) + + TEST_MASK_CHECK_N(1); + TEST_MASK_CHECK_N(2); + TEST_MASK_CHECK_N(4); + + #undef TEST_MASK_CHECK_N } TYPED_TEST(VecMaskTests, ToFrom) { using vec = TypeParam; @@ -1723,37 +1767,46 @@ namespace { TYPED_TEST(VecMaskTests, Cast) { using vec = TypeParam; using src_t = ValueType; - constexpr auto N = vec::size(); - #define TEST_MASK_CAST(dst_t) \ + constexpr auto size = vec::size(); + + #define TEST_MASK_CAST(dst_t, mask_t, mask_n) \ do { \ - CACHE_ALIGN src_t x[N]; \ - CACHE_ALIGN dst_t y[N]; \ + CACHE_ALIGN mask_t x[mask_n * size]; \ + CACHE_ALIGN dst_t y[mask_n * size]; \ auto seed = TestSeed(); \ - auto vec_mask = generate_vec_mask(seed); \ + auto vec_mask = generate_vec_mask(seed); \ constexpr int num_dst_elements = \ - std::min(N, at::vec::Vectorized::size()); \ - constexpr int dst_n = N / num_dst_elements; \ + std::min(size, at::vec::Vectorized::size()); \ + constexpr int dst_n = mask_n * size / num_dst_elements; \ auto vec_mask_new = vec_mask.template cast(); \ - vec_mask.template to().store(x); \ - vec_mask_new.template to().store(y, N); \ - for (const auto i : c10::irange(N)) { \ + vec_mask.template to().store(x); \ + vec_mask_new.template to().store(y); \ + for (const auto i : c10::irange(mask_n * size)) { \ ASSERT_EQ(y[i], x[i]) \ << "Failure Details:\nTest Seed to reproduce: " << seed; \ } \ } while (0) - TEST_MASK_CAST(int8_t); - TEST_MASK_CAST(uint8_t); - TEST_MASK_CAST(int16_t); - TEST_MASK_CAST(uint16_t); - TEST_MASK_CAST(int32_t); - TEST_MASK_CAST(uint32_t); - TEST_MASK_CAST(int64_t); - TEST_MASK_CAST(uint64_t); - TEST_MASK_CAST(c10::BFloat16); - TEST_MASK_CAST(c10::Half); - TEST_MASK_CAST(float); - TEST_MASK_CAST(double); + + #define TEST_MASK_CAST_N(N) \ + TEST_MASK_CAST(int8_t, src_t, N); \ + TEST_MASK_CAST(uint8_t, src_t, N); \ + TEST_MASK_CAST(int16_t, src_t, N); \ + TEST_MASK_CAST(uint16_t, src_t, N); \ + TEST_MASK_CAST(int32_t, src_t, N); \ + TEST_MASK_CAST(uint32_t, src_t, N); \ + TEST_MASK_CAST(int64_t, src_t, N); \ + TEST_MASK_CAST(uint64_t, src_t, N); \ + TEST_MASK_CAST(c10::BFloat16, src_t, N); \ + TEST_MASK_CAST(c10::Half, src_t, N); \ + TEST_MASK_CAST(float, src_t, N); \ + TEST_MASK_CAST(double, src_t, N); + + TEST_MASK_CAST_N(1) + TEST_MASK_CAST_N(2) + TEST_MASK_CAST_N(4) + #undef TEST_MASK_CAST + #undef TEST_MASK_CAST_N } #else #error GTEST does not have TYPED_TEST diff --git a/aten/src/ATen/test/vec_test_all_types.h b/aten/src/ATen/test/vec_test_all_types.h index 91788fcd56039d..98fda78885aecf 100644 --- a/aten/src/ATen/test/vec_test_all_types.h +++ b/aten/src/ATen/test/vec_test_all_types.h @@ -1434,22 +1434,24 @@ double getDefaultTolerance() { return 1.e-9; } -template -at::vec::VecMask create_vec_mask(uint64_t bitmask) { - constexpr auto N = at::vec::Vectorized::size(); - std::array mask; - for (int i = 0; i < N; i++) { - mask[i] = (bitmask >> i) & 1; +template +at::vec::VecMask create_vec_mask(uint64_t bitmask) { + constexpr auto size = at::vec::Vectorized::size(); + std::array mask; + for (int n = 0; n < N; n++) { + for (int i = 0; i < size; i++) { + mask[n * size + i] = (bitmask >> i) & 1; + } } - return at::vec::VecMask::from(mask.data()); + return at::vec::VecMask::from(mask.data()); } -template -at::vec::VecMask generate_vec_mask(int seed) { - constexpr auto N = at::vec::Vectorized::size(); - ValueGen generator(0, (1ULL << N) - 1, seed); +template +at::vec::VecMask generate_vec_mask(int seed) { + constexpr auto size = at::vec::Vectorized::size(); + ValueGen generator(0, (1ULL << size) - 1, seed); auto bitmask = generator.get(); - return create_vec_mask(bitmask); + return create_vec_mask(bitmask); } template From 58a915ec6a6b5b444d7bbe5bc21e87708bc9c2ee Mon Sep 17 00:00:00 2001 From: PyTorch UpdateBot Date: Wed, 4 Sep 2024 05:17:27 +0000 Subject: [PATCH 0386/1018] [executorch hash update] update the pinned executorch hash (#135077) This PR is auto-generated nightly by [this action](https://github.com/pytorch/pytorch/blob/main/.github/workflows/nightly.yml). Update the pinned executorch hash. Pull Request resolved: https://github.com/pytorch/pytorch/pull/135077 Approved by: https://github.com/pytorchbot --- .ci/docker/ci_commit_pins/executorch.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.ci/docker/ci_commit_pins/executorch.txt b/.ci/docker/ci_commit_pins/executorch.txt index 6d1222c950e8de..e7bb12db3dc0a3 100644 --- a/.ci/docker/ci_commit_pins/executorch.txt +++ b/.ci/docker/ci_commit_pins/executorch.txt @@ -1 +1 @@ -61ddee54058bf3fc7e17eb539d61d39d12cc86f8 +d519b4d3a1ffdc81b45e2b1d4733423ce0577813 From cccd75e88b8a34cbb0795058cb36ad17c9bbe5fa Mon Sep 17 00:00:00 2001 From: Xu Han Date: Wed, 4 Sep 2024 05:29:08 +0000 Subject: [PATCH 0387/1018] [inductor] clean up cpp_builder code. (#134909) Clean up cpp_builder duplication code. Hi @henrylhtsang , could you please help on land internally? Pull Request resolved: https://github.com/pytorch/pytorch/pull/134909 Approved by: https://github.com/henrylhtsang --- torch/_inductor/cpp_builder.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/torch/_inductor/cpp_builder.py b/torch/_inductor/cpp_builder.py index 7cfefbc1d84af9..dade21d4303e58 100644 --- a/torch/_inductor/cpp_builder.py +++ b/torch/_inductor/cpp_builder.py @@ -1188,10 +1188,7 @@ def get_cpp_torch_cuda_options( if config.is_fbcode(): libraries += ["cuda"] else: - if config.is_fbcode(): - libraries += ["cuda"] - else: - libraries += ["c10_cuda", "cuda", "torch_cuda"] + libraries += ["c10_cuda", "cuda", "torch_cuda"] if aot_mode: if config.is_fbcode(): From 828b95723e0d87a09a45e47da1394a1c7d34b131 Mon Sep 17 00:00:00 2001 From: "Yu, Guangye" Date: Wed, 4 Sep 2024 09:18:43 +0000 Subject: [PATCH 0388/1018] [Reland] Refactor caching device allocator utils (#130923) # Motivation Following [[RFC] Intel GPU Runtime Upstreaming for Allocator ](https://github.com/pytorch/pytorch/issues/116322), this PR aims to refactor caching device allocator utils to improve code reuse usage. This is the first PR, we could prepare some follow-up PRs continuing to refactor the device caching allocator. Pull Request resolved: https://github.com/pytorch/pytorch/pull/130923 Approved by: https://github.com/EikanWang, https://github.com/gujinghui, https://github.com/albanD, https://github.com/eqy --- c10/core/CachingDeviceAllocator.h | 131 +++++++++++++++++ c10/cuda/CUDACachingAllocator.cpp | 163 +++++++-------------- c10/cuda/CUDACachingAllocator.h | 82 ++--------- c10/cuda/CUDAMallocAsyncAllocator.cpp | 2 + torch/csrc/cuda/CUDAPluggableAllocator.cpp | 4 +- torch/csrc/cuda/CUDAPluggableAllocator.h | 2 +- torch/csrc/cuda/Module.cpp | 8 +- 7 files changed, 202 insertions(+), 190 deletions(-) create mode 100644 c10/core/CachingDeviceAllocator.h diff --git a/c10/core/CachingDeviceAllocator.h b/c10/core/CachingDeviceAllocator.h new file mode 100644 index 00000000000000..8724ecf88ae0f9 --- /dev/null +++ b/c10/core/CachingDeviceAllocator.h @@ -0,0 +1,131 @@ +#pragma once + +#include +#include + +#include + +namespace c10::CachingDeviceAllocator { + +struct Stat { + void increase(size_t amount) { + current += static_cast(amount); + peak = std::max(current, peak); + allocated += static_cast(amount); + } + + void decrease(size_t amount) { + current -= static_cast(amount); + TORCH_INTERNAL_ASSERT_DEBUG_ONLY( + current >= 0, + "Negative tracked stat in device allocator (likely logic error)."); + freed += static_cast(amount); + } + + void reset_accumulated() { + allocated = 0; + freed = 0; + } + + void reset_peak() { + peak = current; + } + + int64_t current = 0; + int64_t peak = 0; + int64_t allocated = 0; + int64_t freed = 0; +}; + +enum struct StatType : uint64_t { + AGGREGATE = 0, + SMALL_POOL = 1, + LARGE_POOL = 2, + NUM_TYPES = 3 // remember to update this whenever a new stat type is added +}; + +using StatArray = std::array(StatType::NUM_TYPES)>; +using StatTypes = std::array(StatType::NUM_TYPES)>; + +template +void for_each_selected_stat_type(const StatTypes& stat_types, Func f) { + for (const auto stat_type : c10::irange(stat_types.size())) { + if (stat_types[stat_type]) { + f(stat_type); + } + } +} + +// Struct containing memory allocator summary statistics for a device. +struct DeviceStats { + // COUNT: allocations requested by client code + StatArray allocation; + // COUNT: number of allocated segments from device memory allocation. + StatArray segment; + // COUNT: number of active memory blocks (allocated or used by stream) + StatArray active; + // COUNT: number of inactive, split memory blocks (unallocated but can't be + // released via device memory deallocation) + StatArray inactive_split; + + // SUM: bytes allocated by this memory alocator + StatArray allocated_bytes; + // SUM: bytes reserved by this memory allocator (both free and used) + StatArray reserved_bytes; + // SUM: bytes within active memory blocks + StatArray active_bytes; + // SUM: bytes within inactive, split memory blocks + StatArray inactive_split_bytes; + // SUM: bytes requested by client code + StatArray requested_bytes; + + // COUNT: total number of failed calls to device malloc necessitating cache + // flushes. + int64_t num_alloc_retries = 0; + + // COUNT: total number of OOMs (i.e. failed calls to device memory allocation + // after cache flush) + int64_t num_ooms = 0; + + // COUNT: total number of oversize blocks allocated from pool + Stat oversize_allocations; + + // COUNT: total number of oversize blocks requiring malloc + Stat oversize_segments; + + // COUNT: total number of synchronize_and_free_events() calls + int64_t num_sync_all_streams = 0; + + // COUNT: total number of device memory allocation calls. This includes both + // mapped and malloced memory. + int64_t num_device_alloc = 0; + + // COUNT: total number of device memory deallocation calls. This includes both + // un-mapped and free memory. + int64_t num_device_free = 0; + + // SIZE: maximum block size that is allowed to be split. + int64_t max_split_size = 0; +}; + +// Size pretty-printer +inline std::string format_size(uint64_t size) { + std::ostringstream os; + os.precision(2); + os << std::fixed; + if (size <= 1024) { + os << size << " bytes"; + } else if (size <= 1048576) { + os << (static_cast(size) / 1024.0); + os << " KiB"; + } else if (size <= 1073741824ULL) { + os << static_cast(size) / 1048576.0; + os << " MiB"; + } else { + os << static_cast(size) / 1073741824.0; + os << " GiB"; + } + return os.str(); +} + +} // namespace c10::CachingDeviceAllocator diff --git a/c10/cuda/CUDACachingAllocator.cpp b/c10/cuda/CUDACachingAllocator.cpp index fa26a9f3e0cca7..470264321fffd6 100644 --- a/c10/cuda/CUDACachingAllocator.cpp +++ b/c10/cuda/CUDACachingAllocator.cpp @@ -10,7 +10,6 @@ #include #include #include -#include #include #include @@ -27,7 +26,6 @@ #include #include #include -#include #include #include #include @@ -44,6 +42,8 @@ C10_DEFINE_REGISTRY(FreeCudaMemoryCallbacksRegistry, FreeMemoryCallback); namespace cuda::CUDACachingAllocator { +using namespace c10::CachingDeviceAllocator; + // Included here as this is externally used in CUDAAllocatorConfig const size_t kLargeBuffer = 20971520; // "large" allocations may be packed in 20 MiB blocks @@ -134,47 +134,13 @@ namespace { using stream_set = ska::flat_hash_set; -using StatTypes = std::array(StatType::NUM_TYPES)>; - -void increase_stat(Stat& stat, size_t amount) { - stat.current += static_cast(amount); - stat.peak = std::max(stat.current, stat.peak); - stat.allocated += static_cast(amount); -} - -void decrease_stat(Stat& stat, size_t amount) { - stat.current -= static_cast(amount); - TORCH_INTERNAL_ASSERT_DEBUG_ONLY( - stat.current >= 0, - "Negative tracked stat in CUDA allocator (likely logic error)."); - stat.freed += static_cast(amount); -} - -void reset_accumulated_stat(Stat& stat) { - stat.allocated = 0; - stat.freed = 0; -} - -void reset_peak_stat(Stat& stat) { - stat.peak = stat.current; -} - -template -void for_each_selected_stat_type(const StatTypes& stat_types, Func f) { - for (const auto stat_type : c10::irange(stat_types.size())) { - if (stat_types[stat_type]) { - f(stat_type); - } - } -} - void decrease_stat_array( StatArray& stat_array, size_t amount, const StatTypes& stat_types) { for_each_selected_stat_type( stat_types, [&stat_array, amount](size_t stat_type) { - decrease_stat(stat_array[stat_type], amount); + stat_array[stat_type].decrease(amount); }); } @@ -1424,16 +1390,16 @@ class DeviceCachingAllocator { // A new split inactive block is being created from a previously unsplit // block, size remaining->size bytes. for_each_selected_stat_type(params.stat_types, [&](size_t stat_type) { - increase_stat(stats.inactive_split_bytes[stat_type], remaining->size); - increase_stat(stats.inactive_split[stat_type], 1); + stats.inactive_split_bytes[stat_type].increase(remaining->size); + stats.inactive_split[stat_type].increase(1); }); } } else if (already_split && !block->expandable_segment_) { // An already-split block is becoming active for_each_selected_stat_type(params.stat_types, [&](size_t stat_type) { - decrease_stat(stats.inactive_split_bytes[stat_type], block->size); - decrease_stat(stats.inactive_split[stat_type], 1); + stats.inactive_split_bytes[stat_type].decrease(block->size); + stats.inactive_split[stat_type].decrease(1); }); } @@ -1454,14 +1420,14 @@ class DeviceCachingAllocator { TORCH_INTERNAL_ASSERT_DEBUG_ONLY(inserted); for_each_selected_stat_type(params.stat_types, [&](size_t stat_type) { - increase_stat(stats.allocation[stat_type], 1); - increase_stat(stats.allocated_bytes[stat_type], block->size); - increase_stat(stats.active[stat_type], 1); - increase_stat(stats.active_bytes[stat_type], block->size); - increase_stat(stats.requested_bytes[stat_type], block->requested_size); + stats.allocation[stat_type].increase(1); + stats.allocated_bytes[stat_type].increase(block->size); + stats.active[stat_type].increase(1); + stats.active_bytes[stat_type].increase(block->size); + stats.requested_bytes[stat_type].increase(block->requested_size); }); if (block->size >= CUDAAllocatorConfig::max_split_size()) - increase_stat(stats.oversize_allocations, 1); + stats.oversize_allocations.increase(1); c10::reportMemoryUsageToProfiler( block->ptr, @@ -1487,8 +1453,8 @@ class DeviceCachingAllocator { StatTypes stat_types = get_stat_types_for_pool(*block->pool); for_each_selected_stat_type(stat_types, [&](size_t stat_type) { - decrease_stat(stats.allocation[stat_type], 1); - decrease_stat(stats.allocated_bytes[stat_type], block->size); + stats.allocation[stat_type].decrease(1); + stats.allocated_bytes[stat_type].decrease(block->size); }); record_trace( @@ -1500,7 +1466,7 @@ class DeviceCachingAllocator { context ? context : block->context_when_allocated); if (block->size >= CUDAAllocatorConfig::max_split_size()) - decrease_stat(stats.oversize_allocations, 1); + stats.oversize_allocations.decrease(1); if (!block->stream_uses.empty()) { if (C10_UNLIKELY(!captures_underway.empty())) { @@ -1628,15 +1594,15 @@ class DeviceCachingAllocator { for (const auto statType : c10::irange(static_cast(StatType::NUM_TYPES))) { - reset_accumulated_stat(stats.allocation[statType]); - reset_accumulated_stat(stats.segment[statType]); - reset_accumulated_stat(stats.active[statType]); - reset_accumulated_stat(stats.inactive_split[statType]); - reset_accumulated_stat(stats.allocated_bytes[statType]); - reset_accumulated_stat(stats.reserved_bytes[statType]); - reset_accumulated_stat(stats.active_bytes[statType]); - reset_accumulated_stat(stats.inactive_split_bytes[statType]); - reset_accumulated_stat(stats.requested_bytes[statType]); + stats.allocation[statType].reset_accumulated(); + stats.segment[statType].reset_accumulated(); + stats.active[statType].reset_accumulated(); + stats.inactive_split[statType].reset_accumulated(); + stats.allocated_bytes[statType].reset_accumulated(); + stats.reserved_bytes[statType].reset_accumulated(); + stats.active_bytes[statType].reset_accumulated(); + stats.inactive_split_bytes[statType].reset_accumulated(); + stats.requested_bytes[statType].reset_accumulated(); } stats.num_alloc_retries = 0; @@ -1644,8 +1610,8 @@ class DeviceCachingAllocator { stats.num_sync_all_streams = 0; stats.num_device_alloc = 0; stats.num_device_free = 0; - reset_accumulated_stat(stats.oversize_allocations); - reset_accumulated_stat(stats.oversize_segments); + stats.oversize_allocations.reset_accumulated(); + stats.oversize_segments.reset_accumulated(); } /** Resets the historical peak stats for the device **/ @@ -1654,18 +1620,18 @@ class DeviceCachingAllocator { for (const auto statType : c10::irange(static_cast(StatType::NUM_TYPES))) { - reset_peak_stat(stats.allocation[statType]); - reset_peak_stat(stats.segment[statType]); - reset_peak_stat(stats.active[statType]); - reset_peak_stat(stats.inactive_split[statType]); - reset_peak_stat(stats.allocated_bytes[statType]); - reset_peak_stat(stats.reserved_bytes[statType]); - reset_peak_stat(stats.active_bytes[statType]); - reset_peak_stat(stats.inactive_split_bytes[statType]); - reset_peak_stat(stats.requested_bytes[statType]); + stats.allocation[statType].reset_peak(); + stats.segment[statType].reset_peak(); + stats.active[statType].reset_peak(); + stats.inactive_split[statType].reset_peak(); + stats.allocated_bytes[statType].reset_peak(); + stats.reserved_bytes[statType].reset_peak(); + stats.active_bytes[statType].reset_peak(); + stats.inactive_split_bytes[statType].reset_peak(); + stats.requested_bytes[statType].reset_peak(); } - reset_peak_stat(stats.oversize_allocations); - reset_peak_stat(stats.oversize_segments); + stats.oversize_allocations.reset_peak(); + stats.oversize_segments.reset_peak(); } /* Checkpoint the state of a private pool necessary to return it to its @@ -2277,7 +2243,7 @@ class DeviceCachingAllocator { total_allocated_memory += mapped_range.size; StatTypes stat_types = get_stat_types_for_pool(*to_map->pool); for_each_selected_stat_type(stat_types, [&](size_t stat_type) { - increase_stat(stats.reserved_bytes[stat_type], mapped_range.size); + stats.reserved_bytes[stat_type].increase(mapped_range.size); }); stats.num_device_alloc++; @@ -2384,27 +2350,23 @@ class DeviceCachingAllocator { // inactive_split if (!block->expandable_segment_) { if (net_change_inactive_split_blocks > 0) { - increase_stat( - stats.inactive_split[stat_type], + stats.inactive_split[stat_type].increase( static_cast(net_change_inactive_split_blocks)); } else if (net_change_inactive_split_blocks < 0) { - decrease_stat( - stats.inactive_split[stat_type], + stats.inactive_split[stat_type].decrease( static_cast(-net_change_inactive_split_blocks)); } if (net_change_inactive_split_size > 0) { - increase_stat( - stats.inactive_split_bytes[stat_type], + stats.inactive_split_bytes[stat_type].increase( static_cast(net_change_inactive_split_size)); } else if (net_change_inactive_split_size < 0) { - decrease_stat( - stats.inactive_split_bytes[stat_type], + stats.inactive_split_bytes[stat_type].decrease( static_cast(-net_change_inactive_split_size)); } } - decrease_stat(stats.active[stat_type], 1); - decrease_stat(stats.active_bytes[stat_type], original_block_size); - decrease_stat(stats.requested_bytes[stat_type], requested_size); + stats.active[stat_type].decrease(1); + stats.active_bytes[stat_type].decrease(original_block_size); + stats.requested_bytes[stat_type].decrease(requested_size); }); } @@ -2716,11 +2678,11 @@ class DeviceCachingAllocator { total_allocated_memory += size; p.block = new Block(p.device(), p.stream(), size, p.pool, (char*)ptr); for_each_selected_stat_type(p.stat_types, [&](size_t stat_type) { - increase_stat(stats.segment[stat_type], 1); - increase_stat(stats.reserved_bytes[stat_type], size); + stats.segment[stat_type].increase(1); + stats.reserved_bytes[stat_type].increase(size); }); if (size >= CUDAAllocatorConfig::max_split_size()) - increase_stat(stats.oversize_segments, 1); + stats.oversize_segments.increase(1); // p.block came from new, not cudaMalloc. It should not be nullptr here. TORCH_INTERNAL_ASSERT(p.block != nullptr && p.block->ptr != nullptr); @@ -2855,12 +2817,12 @@ class DeviceCachingAllocator { StatTypes stat_types = get_stat_types_for_pool(*pool); for_each_selected_stat_type(stat_types, [&](size_t stat_type) { - decrease_stat(stats.segment[stat_type], 1); - decrease_stat(stats.reserved_bytes[stat_type], block->size); + stats.segment[stat_type].decrease(1); + stats.reserved_bytes[stat_type].decrease(block->size); }); if (block->size >= CUDAAllocatorConfig::max_split_size()) - decrease_stat(stats.oversize_segments, 1); + stats.oversize_segments.decrease(1); pool->blocks.erase(block); delete block; } @@ -2912,7 +2874,7 @@ class DeviceCachingAllocator { total_allocated_memory -= unmapped.size; StatTypes stat_types = get_stat_types_for_pool(*block->pool); for_each_selected_stat_type(stat_types, [&](size_t stat_type) { - decrease_stat(stats.reserved_bytes[stat_type], unmapped.size); + stats.reserved_bytes[stat_type].decrease(unmapped.size); }); if (block->pool->owner_PrivatePool) { @@ -3741,25 +3703,6 @@ void local_raw_delete(void* ptr) { } } // namespace Native -// Size pretty-printer -std::string format_size(uint64_t size) { - std::ostringstream os; - os.precision(2); - os << std::fixed; - if (size <= 1024) { - os << size << " bytes"; - } else if (size <= 1048576) { - os << (static_cast(size) / 1024.0); - os << " KiB"; - } else if (size <= 1073741824ULL) { - os << static_cast(size) / 1048576.0; - os << " MiB"; - } else { - os << static_cast(size) / 1073741824.0; - os << " GiB"; - } - return os.str(); -} namespace CudaMallocAsync { // If this is put in its own header file, it gets incorrectly renamed in HIPify. diff --git a/c10/cuda/CUDACachingAllocator.h b/c10/cuda/CUDACachingAllocator.h index 72617bcaf3a944..70385654201f50 100644 --- a/c10/cuda/CUDACachingAllocator.h +++ b/c10/cuda/CUDACachingAllocator.h @@ -1,6 +1,6 @@ #pragma once -#include +#include #include #include #include @@ -48,74 +48,11 @@ C10_DECLARE_REGISTRY(FreeCudaMemoryCallbacksRegistry, FreeMemoryCallback); namespace c10::cuda::CUDACachingAllocator { -extern const size_t kLargeBuffer; - -struct Stat { - int64_t current = 0; - int64_t peak = 0; - int64_t allocated = 0; - int64_t freed = 0; -}; - -enum struct StatType : uint64_t { - AGGREGATE = 0, - SMALL_POOL = 1, - LARGE_POOL = 2, - NUM_TYPES = 3 // remember to update this whenever a new stat type is added -}; +// Preserved only for BC reasons +// NOLINTNEXTLINE(misc-unused-using-decls) +using c10::CachingDeviceAllocator::DeviceStats; -typedef std::array(StatType::NUM_TYPES)> StatArray; - -// Struct containing memory allocator summary statistics for a device. -struct DeviceStats { - // COUNT: allocations requested by client code - StatArray allocation; - // COUNT: number of allocated segments from cudaMalloc(). - StatArray segment; - // COUNT: number of active memory blocks (allocated or used by stream) - StatArray active; - // COUNT: number of inactive, split memory blocks (unallocated but can't be - // released via cudaFree) - StatArray inactive_split; - - // SUM: bytes allocated by this memory alocator - StatArray allocated_bytes; - // SUM: bytes reserved by this memory allocator (both free and used) - StatArray reserved_bytes; - // SUM: bytes within active memory blocks - StatArray active_bytes; - // SUM: bytes within inactive, split memory blocks - StatArray inactive_split_bytes; - // SUM: bytes requested by client code - StatArray requested_bytes; - - // COUNT: total number of failed calls to CUDA malloc necessitating cache - // flushes. - int64_t num_alloc_retries = 0; - - // COUNT: total number of OOMs (i.e. failed calls to CUDA after cache flush) - int64_t num_ooms = 0; - - // COUNT: total number of oversize blocks allocated from pool - Stat oversize_allocations; - - // COUNT: total number of oversize blocks requiring malloc - Stat oversize_segments; - - // COUNT: total number of synchronize_and_free_events() calls - int64_t num_sync_all_streams = 0; - - // COUNT: total number of CUDA allocation calls. This includes both cuMemMap - // and cudaMalloc. - int64_t num_device_alloc = 0; - - // COUNT: total number of CUDA free calls. This includes both cuMemUnmap - // and cudaFree. - int64_t num_device_free = 0; - - // SIZE: maximum block size that is allowed to be split. - int64_t max_split_size = 0; -}; +extern const size_t kLargeBuffer; typedef std::shared_ptr (*CreateContextFn)(); @@ -247,9 +184,6 @@ enum struct RecordContext { ALL = 3, // additionally record stacks for when something is freed }; -// Size pretty-printer -std::string format_size(uint64_t size); - using OutOfMemoryObserver = std::functionrecordStream(dataPtr, stream); } -inline DeviceStats getDeviceStats(c10::DeviceIndex device) { +inline c10::CachingDeviceAllocator::DeviceStats getDeviceStats( + c10::DeviceIndex device) { return get()->getDeviceStats(device); } diff --git a/c10/cuda/CUDAMallocAsyncAllocator.cpp b/c10/cuda/CUDAMallocAsyncAllocator.cpp index 3fe414b55e30a2..3a7b485ebce223 100644 --- a/c10/cuda/CUDAMallocAsyncAllocator.cpp +++ b/c10/cuda/CUDAMallocAsyncAllocator.cpp @@ -11,6 +11,8 @@ namespace c10::cuda::CUDACachingAllocator::CudaMallocAsync { +using namespace c10::CachingDeviceAllocator; + #if CUDA_VERSION >= 11040 // CUDA device allocator that uses cudaMallocAsync to implement // the same interface as CUDACachingAllocator.cpp. diff --git a/torch/csrc/cuda/CUDAPluggableAllocator.cpp b/torch/csrc/cuda/CUDAPluggableAllocator.cpp index c6af163481481e..5220e86233bd67 100644 --- a/torch/csrc/cuda/CUDAPluggableAllocator.cpp +++ b/torch/csrc/cuda/CUDAPluggableAllocator.cpp @@ -210,8 +210,8 @@ void CUDAPluggableAllocator::recordStream( } } -c10::cuda::CUDACachingAllocator::DeviceStats CUDAPluggableAllocator:: - getDeviceStats(c10::DeviceIndex device) { +c10::CachingDeviceAllocator::DeviceStats CUDAPluggableAllocator::getDeviceStats( + c10::DeviceIndex device) { TORCH_CHECK( false, "CUDAPluggableAllocator does not yet support getDeviceStats. " diff --git a/torch/csrc/cuda/CUDAPluggableAllocator.h b/torch/csrc/cuda/CUDAPluggableAllocator.h index 70014b2fa0b37b..8652ef0f2bfde8 100644 --- a/torch/csrc/cuda/CUDAPluggableAllocator.h +++ b/torch/csrc/cuda/CUDAPluggableAllocator.h @@ -115,7 +115,7 @@ struct TORCH_CUDA_CPP_API CUDAPluggableAllocator void recordStream(const c10::DataPtr&, streamType stream) override; - c10::cuda::CUDACachingAllocator::DeviceStats getDeviceStats( + c10::CachingDeviceAllocator::DeviceStats getDeviceStats( c10::DeviceIndex device) override; void resetAccumulatedStats(c10::DeviceIndex device) override; void resetPeakStats(c10::DeviceIndex device) override; diff --git a/torch/csrc/cuda/Module.cpp b/torch/csrc/cuda/Module.cpp index 134758b2c06814..22cbadc42fc09f 100644 --- a/torch/csrc/cuda/Module.cpp +++ b/torch/csrc/cuda/Module.cpp @@ -565,10 +565,10 @@ PyObject* THCPModule_memoryStats(PyObject* _unused, PyObject* arg) { TORCH_CHECK(THPUtils_checkLong(arg), "invalid argument to memory_allocated"); const auto device_index = THPUtils_unpackDeviceIndex(arg); - using c10::cuda::CUDACachingAllocator::DeviceStats; - using c10::cuda::CUDACachingAllocator::Stat; - using c10::cuda::CUDACachingAllocator::StatArray; - using c10::cuda::CUDACachingAllocator::StatType; + using c10::CachingDeviceAllocator::DeviceStats; + using c10::CachingDeviceAllocator::Stat; + using c10::CachingDeviceAllocator::StatArray; + using c10::CachingDeviceAllocator::StatType; const auto statToDict = [](const Stat& stat) { py::dict dict; From ac1e0a56891d811c480af75daf962c9c9b4e62fe Mon Sep 17 00:00:00 2001 From: Avik Chaudhuri Date: Wed, 4 Sep 2024 05:34:26 +0000 Subject: [PATCH 0389/1018] rationalize STATIC vs. None (#134877) Summary: A bit of refactoring to prepare to remove `None` as a way to specify static dimensions in dynamic shapes, given we already have `Dim.STATIC` for the same purpose. We will now warn whenever this happens. However no tests were modified because problematic uses of `None` still need to behave as they do today, until we are ready to remove support. It should be easy to port tests by replacing the warning function to raise instead. Note that other uses of `None`, such as for entire values (tensor or non-tensor) remain as is. Moving forward this should be the only purpose of `None` (at least externally). Finally, there's a bit of confusion in our representation now because `AUTO` also internally transforms to `None`. Renamed dynamic_shapes to transformed_dynamic_shapes where this happens. Overall the two forms (pre and post transformation) have different properties so should probably not be represented in the same format in the future. Test Plan: existing Differential Revision: D62040729 Pull Request resolved: https://github.com/pytorch/pytorch/pull/134877 Approved by: https://github.com/pianpwk --- test/export/test_export.py | 2 +- torch/_export/non_strict_utils.py | 13 +++++++--- torch/_logging/_registrations.py | 1 + torch/export/_trace.py | 7 +++--- torch/export/dynamic_shapes.py | 41 +++++++++++++++++++++++++------ 5 files changed, 50 insertions(+), 14 deletions(-) diff --git a/test/export/test_export.py b/test/export/test_export.py index 426a136aaf3ec2..887be93dbf75f4 100644 --- a/test/export/test_export.py +++ b/test/export/test_export.py @@ -2014,7 +2014,7 @@ def forward(self, x): + re.escape( "specified at `dynamic_shapes[0]['k']['k'][0]` " "(expected either a list/tuple of dimensions, or a dict mapping indices to dimensions," - " where each dimension is None, an int, a Dim, Dim.AUTO, or Dim.STATIC)" + " where each dimension is an int, a Dim, Dim.AUTO, or Dim.STATIC)" ), ): export(M(), inputs, dynamic_shapes=dynamic_shapes) diff --git a/torch/_export/non_strict_utils.py b/torch/_export/non_strict_utils.py index ddb51064f9ad78..5c7331659d26a0 100644 --- a/torch/_export/non_strict_utils.py +++ b/torch/_export/non_strict_utils.py @@ -136,10 +136,10 @@ def make_fake_inputs( combined_args = _combine_args(nn_module, args, kwargs) _check_dynamic_shapes(combined_args, dynamic_shapes) - _dynamic_shapes = _transform_shapes_for_default_dynamic( + transformed_dynamic_shapes = _transform_shapes_for_default_dynamic( combined_args, dynamic_shapes ) - constraints = _process_dynamic_shapes(combined_args, _dynamic_shapes) + constraints = _process_dynamic_shapes(combined_args, transformed_dynamic_shapes) t_constraints: Dict[int, Dict[int, Constraint]] = defaultdict(dict) for constraint in constraints: t_constraints[constraint.t_id][constraint.dim] = constraint @@ -216,7 +216,14 @@ def make_fake_inputs( phantom_symbols=list(phantom_symbols.values()), warn_only=False, ) - return fake_mode, fake_args, fake_kwargs, equalities_inputs, original_signature + return ( + fake_mode, + fake_args, + fake_kwargs, + equalities_inputs, + original_signature, + transformed_dynamic_shapes, + ) def _flatten_dynamic_shapes( diff --git a/torch/_logging/_registrations.py b/torch/_logging/_registrations.py index b85ef6d880ae44..e2bdedf349a0ac 100644 --- a/torch/_logging/_registrations.py +++ b/torch/_logging/_registrations.py @@ -42,6 +42,7 @@ [ "torch._dynamo", "torch.export", + "torch.export.dynamic_shapes", *DYNAMIC, "torch._export.converter", "torch._export.non_strict_utils", diff --git a/torch/export/_trace.py b/torch/export/_trace.py index 1c963691058f47..8b9b732b0119d8 100644 --- a/torch/export/_trace.py +++ b/torch/export/_trace.py @@ -537,7 +537,7 @@ def _export_to_torch_ir( kwargs = kwargs or {} combined_args = _combine_args(f, args, kwargs) _check_dynamic_shapes(combined_args, dynamic_shapes) - _dynamic_shapes = _transform_shapes_for_default_dynamic( + transformed_dynamic_shapes = _transform_shapes_for_default_dynamic( combined_args, dynamic_shapes ) @@ -549,7 +549,7 @@ def _export_to_torch_ir( ), _ignore_backend_decomps(): gm_torch_level, _ = torch._dynamo.export( f, - dynamic_shapes=_dynamic_shapes, # type: ignore[arg-type] + dynamic_shapes=transformed_dynamic_shapes, # type: ignore[arg-type] tracing_mode="symbolic", disable_constraint_solver=disable_constraint_solver, # currently the following 2 flags are tied together for export purposes, @@ -1676,6 +1676,7 @@ def forward(self, *args, **kwargs): fake_kwargs, equalities_inputs, original_signature, + transformed_dynamic_shapes, ) = make_fake_inputs( mod, args, @@ -1691,7 +1692,7 @@ def _produce_guards_callback(gm): return produce_guards_and_solve_constraints( fake_mode=fake_mode, gm=gm, - dynamic_shapes=dynamic_shapes, + dynamic_shapes=transformed_dynamic_shapes, equalities_inputs=equalities_inputs, original_signature=original_signature, _is_torch_jit_trace=_is_torch_jit_trace, diff --git a/torch/export/dynamic_shapes.py b/torch/export/dynamic_shapes.py index 7232114619b4f6..18fc1e73f0cf05 100644 --- a/torch/export/dynamic_shapes.py +++ b/torch/export/dynamic_shapes.py @@ -1,6 +1,7 @@ # mypy: allow-untyped-defs import dataclasses import inspect +import logging import sys from collections import defaultdict from enum import auto, Enum @@ -36,6 +37,9 @@ ] +log = logging.getLogger(__name__) + + class _DimHint(Enum): """ Enum for dynamic shape hints. @@ -637,6 +641,15 @@ def find_shape(path, t): return dynamic_shapes +def _warn_on_None_dynamic_shape_dimension(): + msg = ( + "Using None as a dynamic shape dimension is deprecated. " + "Please use Dim.STATIC instead" + ) + # TODO(avik): raise an error in the future + log.warning(msg) + + def _check_dynamic_shapes( combined_args: Dict[str, Any], dynamic_shapes: Union[Dict[str, Any], Tuple[Any], List[Any], None], @@ -674,7 +687,9 @@ def check_symbols(path, tensor, shape): for i, dim in shape.items(): if isinstance(dim, _Dim): check_same_bounds(dim) - elif not (isinstance(dim, (int, _DimHint)) or dim is None): + elif dim is None: + _warn_on_None_dynamic_shape_dimension() + elif not (isinstance(dim, (int, _DimHint))): raise UserError( UserErrorType.INVALID_INPUT, f"Unexpected dimension mapped to index {i} in input tensor shape {shape} " @@ -686,7 +701,9 @@ def check_symbols(path, tensor, shape): for i, dim in enumerate(shape): if isinstance(dim, _Dim): check_same_bounds(dim) - elif not (isinstance(dim, (int, _DimHint)) or dim is None): + elif dim is None: + _warn_on_None_dynamic_shape_dimension() + elif not (isinstance(dim, (int, _DimHint))): raise UserError( UserErrorType.INVALID_INPUT, f"Unexpected dimension #{i} in input tensor shape {shape} " @@ -699,7 +716,7 @@ def check_symbols(path, tensor, shape): UserErrorType.INVALID_INPUT, f"Unexpected input tensor shape {shape} specified at `dynamic_shapes{keystr(path)}` " f"(expected either a list/tuple of dimensions, or a dict mapping indices to dimensions," - f" where each dimension is None, an int, a Dim, Dim.AUTO, or Dim.STATIC)", + f" where each dimension is an int, a Dim, Dim.AUTO, or Dim.STATIC)", case_name="dynamic_shapes_validation", ) @@ -837,7 +854,7 @@ def _marked_dynamic(tensor, i): if isinstance(shape, dict): out = {} for i, val in enumerate(tensor.shape): - dim = shape.get(i, None) + dim = shape.get(i, _DimHint.STATIC) if _marked_dynamic(tensor, i) or dim == _DimHint.AUTO: # don't have to specify anything if dynamic # None also works, since assume_static_by_default=False @@ -850,9 +867,12 @@ def _marked_dynamic(tensor, i): # important that this is dim and not val, # so we can raise error if user-specified dim != val out[i] = dim + elif dim is None: + _warn_on_None_dynamic_shape_dimension() + out[i] = val else: # make explicitly static - assert dim is None or dim == _DimHint.STATIC + assert dim == _DimHint.STATIC out[i] = val elif isinstance(shape, (tuple, list)): out = [] @@ -866,8 +886,11 @@ def _marked_dynamic(tensor, i): out.append(dim) elif isinstance(dim, int): out.append(dim) + elif dim is None: + _warn_on_None_dynamic_shape_dimension() + out.append(val) else: - assert dim is None or dim == _DimHint.STATIC + assert dim == _DimHint.STATIC out.append(val) out = type(shape)(out) # type: ignore[assignment] else: @@ -1046,8 +1069,12 @@ def _get_dim_name_mapping( dynamic_shapes, is_leaf=lambda x: isinstance(x, _Dim), )[0]: - if isinstance(dim, (int, _DimHint)) or dim is None: + if dim is None: + # NOTE: this must denote a non-Tensor or automatic at this point. + continue + if isinstance(dim, int): continue + assert isinstance(dim, _Dim) # dim hints should have boiled away name_to_dim[dim.__name__] = dim if isinstance(dim, _DerivedDim): name_to_dim[dim.root.__name__] = dim.root # type: ignore[attr-defined] From a8156cfe785f7968cf28003f41de83e9d2ab5e34 Mon Sep 17 00:00:00 2001 From: Pian Pawakapan Date: Wed, 4 Sep 2024 05:56:28 +0000 Subject: [PATCH 0390/1018] restore CSE'd node metadata in runtime asserts pass (#134516) Adds val, and optionally stack_trace & nn_module_stack metadata back to SymInt compute nodes that we CSE, with a hook on `graph.create_node()`. Not sure if there's other metadata we want to populate here? Pull Request resolved: https://github.com/pytorch/pytorch/pull/134516 Approved by: https://github.com/ezyang --- test/distributed/test_dynamo_distributed.py | 1 - torch/fx/passes/runtime_assert.py | 102 +++++++++++++++----- 2 files changed, 79 insertions(+), 24 deletions(-) diff --git a/test/distributed/test_dynamo_distributed.py b/test/distributed/test_dynamo_distributed.py index 134370e41703c3..3ed2bf258e1cb4 100644 --- a/test/distributed/test_dynamo_distributed.py +++ b/test/distributed/test_dynamo_distributed.py @@ -423,7 +423,6 @@ def forward(self, x, y): opt_model = torch.compile(dynamic=True)(model) opt_model(torch.randn(20, 512), torch.tensor([12, 13])) - @unittest.expectedFailure # https://github.com/pytorch/pytorch/issues/130534" @config.patch(optimize_ddp=True, capture_dynamic_output_shape_ops=True) def test_unbacked_symbol_splitting_no_binding(self): class Model(nn.Module): diff --git a/torch/fx/passes/runtime_assert.py b/torch/fx/passes/runtime_assert.py index e2a7c1bee3db1d..ba77d6ad809e22 100644 --- a/torch/fx/passes/runtime_assert.py +++ b/torch/fx/passes/runtime_assert.py @@ -1,4 +1,5 @@ # mypy: allow-untyped-defs +import functools import logging import operator import sys @@ -43,6 +44,18 @@ def _get_example_value(node: fx.Node) -> Optional[str]: return None +def _get_example_value_key(node: fx.Node) -> Optional[str]: + """ + actually just run this once at start of pass, based on first node, and constantly use that. + """ + if "example_value" in node.meta: + return "example_value" + elif "val" in node.meta: + return "val" + else: + return None + + def _get_sym_val(node: fx.Node) -> Optional["sympy.Expr"]: val = _get_example_value(node) if isinstance(val, py_sym_types): @@ -92,6 +105,7 @@ def insert_deferred_runtime_asserts( # Import sympy locally import sympy + from torch._export.passes._node_metadata_hook import _set_node_metadata_hook from torch.fx.experimental.symbolic_shapes import ( _has_uninterpretable_sympy_function, CallMethodKey, @@ -145,6 +159,30 @@ def _is_intermediate_tensor_sym_call(node: fx.Node) -> bool: ) ) + # Figure out what key to use, val or example_value + val_key = "val" + for node in graph.nodes: + if "example_value" in node.meta: + val_key = "example_value" + break + elif "val" in node.meta: + break + + def _node_metadata_hook( + node: torch.fx.Node, + stack_trace: Optional[str] = None, + nn_module_stack: Optional[Dict[str, Any]] = None, + ) -> None: + fake_args = [ + _get_example_value(arg) if isinstance(arg, torch.fx.Node) else arg + for arg in node.args + ] + node.meta[val_key] = node.target(*fake_args) # type: ignore[operator] + if stack_trace is not None: + node.meta["stack_trace"] = stack_trace + if nn_module_stack is not None: + node.meta["nn_module_stack"] = nn_module_stack + # Track asserts/checks we've added added_asserts: Set[sympy.Expr] = set() constrained_unbacked_symbols: Set[sympy.Symbol] = set() @@ -247,7 +285,8 @@ def match_symbol(symint, cb): and isinstance(s := symint.node.expr, sympy.Symbol) and s not in expr_to_proxy ): - expr_to_proxy[s] = fx.Proxy(cb()) + with _set_node_metadata_hook(gm, _node_metadata_hook): + expr_to_proxy[s] = fx.Proxy(cb()) log.debug("expr_to_proxy[%s] = %s", s, expr_to_proxy[s]) match_symbol(example_value, lambda: node) @@ -331,7 +370,15 @@ def match_symbol(symint, cb): if _is_intermediate_tensor_sym_call( node ): # reify from input shapes - expr_to_proxy[sym_expr] = _sympy_interp(expr_to_proxy, sym_expr) # type: ignore[arg-type] + with _set_node_metadata_hook( + gm, + functools.partial( + _node_metadata_hook, + stack_trace=node.meta.get("stack_trace"), + nn_module_stack=node.meta.get("nn_module_stack"), + ), + ): + expr_to_proxy[sym_expr] = _sympy_interp(expr_to_proxy, sym_expr) # type: ignore[arg-type] # won't try DCE-ing tensor compute here hash_node = expr_to_proxy[sym_expr].node # type: ignore[arg-type] node.replace_all_uses_with(hash_node) @@ -436,7 +483,8 @@ def go(node, keypath): raise AssertionError(f"unrecognized keypath {keypath}") if s not in expr_to_proxy: - expr_to_proxy[s] = fx.Proxy(go(node, keypath)) + with _set_node_metadata_hook(gm, _node_metadata_hook): + expr_to_proxy[s] = fx.Proxy(go(node, keypath)) log.debug("expr_to_proxy[%s] = %s", s, expr_to_proxy[s]) for i0 in defs: @@ -519,26 +567,34 @@ def convert(s): # TODO(pianpwk): calling sym_constrain_range_for_size or adding bound asserts # raises AOTAutograd errors on cast_symbool_to_symint_guardless - if (min_val := convert(vr.lower)) is not None: - ge = _sympy_interp(expr_to_proxy, i0 >= min_val).node - graph.call_function( - torch.ops.aten._assert_scalar.default, - ( - ge, - f"Runtime assertion failed for expression {i0 >= min_val} on node '{ge}'", - ), - ) - added_asserts.add(i0 >= min_val) - if (max_val := convert(vr.upper)) is not None: - le = _sympy_interp(expr_to_proxy, i0 <= max_val).node - graph.call_function( - torch.ops.aten._assert_scalar.default, - ( - le, - f"Runtime assertion failed for expression {i0 <= max_val} on node '{le}'", - ), - ) - added_asserts.add(i0 <= max_val) + with _set_node_metadata_hook( + gm, + functools.partial( + _node_metadata_hook, + stack_trace=node.meta.get("stack_trace"), + nn_module_stack=node.meta.get("nn_module_stack"), + ), + ): + if (min_val := convert(vr.lower)) is not None: + ge = _sympy_interp(expr_to_proxy, i0 >= min_val).node + graph.call_function( + torch.ops.aten._assert_scalar.default, + ( + ge, + f"Runtime assertion failed for expression {i0 >= min_val} on node '{ge}'", + ), + ) + added_asserts.add(i0 >= min_val) + if (max_val := convert(vr.upper)) is not None: + le = _sympy_interp(expr_to_proxy, i0 <= max_val).node + graph.call_function( + torch.ops.aten._assert_scalar.default, + ( + le, + f"Runtime assertion failed for expression {i0 <= max_val} on node '{le}'", + ), + ) + added_asserts.add(i0 <= max_val) constrained_unbacked_symbols.add(i0) add_runtime_asserts(ras) From 542cb402169b614502b15cb64257204b133d81fc Mon Sep 17 00:00:00 2001 From: Gregory Comer Date: Wed, 4 Sep 2024 08:45:46 +0000 Subject: [PATCH 0391/1018] Update generate-xnnpack-wrappers.py parsing to handle build identifier (#134724) Fixes an issue after updating XNNPACK where parsing the XNNPACK CMakeLists breaks. I'm just ignored the generated build identifier for now, since it's not used and we would need to update the buck build to generate it at build time. Remove unused ukernels_xop XNNPACK target as it has no sources (after the recent update) and causes buck1 to complain. Pull Request resolved: https://github.com/pytorch/pytorch/pull/134724 Approved by: https://github.com/mcr229 --- third_party/generate-xnnpack-wrappers.py | 11 ++- third_party/xnnpack.buck.bzl | 99 ------------------------ 2 files changed, 9 insertions(+), 101 deletions(-) diff --git a/third_party/generate-xnnpack-wrappers.py b/third_party/generate-xnnpack-wrappers.py index e9b23e4a784d5e..ee47aebc85bea9 100755 --- a/third_party/generate-xnnpack-wrappers.py +++ b/third_party/generate-xnnpack-wrappers.py @@ -92,6 +92,11 @@ # add non-prod microkernel sources here: } +# Source files not needed in buck build. +IGNORED_SOURCES = set(( + "\"${PROJECT_BINARY_DIR}/build_identifier.c\"", # Not currently used and requires build-time codegen. +)) + def handle_singleline_parse(line): start_index = line.find("(") end_index = line.find(")") @@ -131,12 +136,14 @@ def update_sources(xnnpack_path, cmakefile = "XNNPACK/CMakeLists.txt"): while i < len(lines) and len(lines[i]) > 0 and ')' not in lines[i]: # remove "src/" at the beginning, remove whitespaces and newline value = lines[i].strip(' \t\n\r') - sources[name].append(value[4:]) + if value not in IGNORED_SOURCES: + sources[name].append(value[4:]) i += 1 if i < len(lines) and len(lines[i]) > 4: # remove "src/" at the beginning, possibly ')' at the end value = lines[i].strip(' \t\n\r)') - sources[name].append(value[4:]) + if value not in IGNORED_SOURCES: + sources[name].append(value[4:]) else: i += 1 return sources diff --git a/third_party/xnnpack.buck.bzl b/third_party/xnnpack.buck.bzl index 144dc8513ec622..6ce938b28ad685 100644 --- a/third_party/xnnpack.buck.bzl +++ b/third_party/xnnpack.buck.bzl @@ -993,103 +993,6 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F ], ) - fb_xplat_cxx_library( - name = "ukernels_xop", - srcs = PROD_XOP_MICROKERNEL_SRCS if is_arvr_mode() else [], - headers = subdir_glob([ - ("XNNPACK/src", "**/*.h"), - ("XNNPACK/src", "**/*.c"), - ]), - header_namespace = "", - apple_sdks = (IOS, MACOSX, APPLETVOS), - compiler_flags = [ - "-O2", - ] + select({ - "DEFAULT": [], - "ovr_config//cpu:x86_32": [ - "-mxop", - ], - "ovr_config//cpu:x86_64": [ - "-mxop", - ], - }), - platform_compiler_flags = [ - ( - "x86|x86_64|platform009|platform010", - [ - "-mxop", - ], - ), - ], - fbobjc_preprocessor_flags = [ - "-DXNN_PRIVATE=", - "-DXNN_INTERNAL=", - ], - labels = labels, - platform_preprocessor_flags = [ - ( - "windows-x86_64", - [ - "-Drestrict=", - ], - ), - ], - platform_srcs = ([ - ( - "x86|x86_64|platform009|platform010", - PROD_XOP_MICROKERNEL_SRCS, - ), - ] if not is_arvr_mode() else []), - preferred_linkage = "static", - preprocessor_flags = [ - "-DXNN_LOG_LEVEL=0", - ], - visibility = ["PUBLIC"], - windows_clang_compiler_flags_override = WINDOWS_FLAGS + WINDOWS_CLANG_COMPILER_FLAGS + ["-mxop"], - windows_compiler_flags_override = WINDOWS_FLAGS + ["-mxop"], - deps = [ - ":interface", - ], - ) - - fb_xplat_cxx_library( - name = "ukernels_xop_ovr_win32", - headers = subdir_glob([ - ("XNNPACK/src", "**/*.h"), - ("XNNPACK/src", "**/*.c"), - ]), - header_namespace = "", - apple_sdks = (IOS, MACOSX, APPLETVOS), - compiler_flags = [ - "-O2", - "-mxop", - ], - fbobjc_preprocessor_flags = [ - "-DXNN_PRIVATE=", - "-DXNN_INTERNAL=", - ], - labels = labels, - platform_preprocessor_flags = [ - ( - "windows-x86_64", - [ - "-Drestrict=", - ], - ), - ], - preferred_linkage = "static", - preprocessor_flags = [ - "-DXNN_LOG_LEVEL=0", - ], - visibility = ["PUBLIC"], - windows_clang_compiler_flags_override = WINDOWS_FLAGS + WINDOWS_CLANG_COMPILER_FLAGS + ["-mxop"], - windows_compiler_flags_override = WINDOWS_FLAGS + ["-mxop"], - windows_srcs = PROD_XOP_MICROKERNEL_SRCS, - deps = [ - ":interface", - ], - ) - fb_xplat_cxx_library( name = "ukernels_fma3", srcs = PROD_FMA3_MICROKERNEL_SRCS if is_arvr_mode() else [], @@ -2527,7 +2430,6 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F ":ukernels_sse2", ":ukernels_sse41", ":ukernels_ssse3", - ":ukernels_xop", ":ukernels_avx512vbmi", ":ukernels_avx512vnni", ":ukernels_avx512vnnigfni", @@ -2552,7 +2454,6 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F ":ukernels_sse41_ovr_win32", ":ukernels_sse_ovr_win32", ":ukernels_ssse3_ovr_win32", - ":ukernels_xop_ovr_win32", ":ukernels_avx512vbmi", # ":ukernels_avx512vnni_ovr_win32", # Build crashes on Windows Clang 17.0.3, re-enable when fixed (T199959765) # ":ukernels_avx512vnnigfni_ovr_win32", From 71e0c6044bb92e5650e71c9d09ca59bc6e92cdbe Mon Sep 17 00:00:00 2001 From: Luca Wehrstedt Date: Fri, 30 Aug 2024 12:52:48 +0000 Subject: [PATCH 0392/1018] [fp8 rowwise] Support transposing operands in order to change output layout (#134773) On some occasion, a column-major output layout is more efficient (it's unclear if it's because of better store coalescing for some tile shapes, or whether it's just that it's CUTLASS's default and thus it's better optimized). At this stage I only add a flag that allows to transpose, but the hardest will be deciding on a new heuristic to turn it on selectively. This will be in a follow-up PR. Pull Request resolved: https://github.com/pytorch/pytorch/pull/134773 Approved by: https://github.com/drisspg --- aten/src/ATen/native/cuda/RowwiseScaledMM.cu | 62 ++++++++++++++++++-- 1 file changed, 56 insertions(+), 6 deletions(-) diff --git a/aten/src/ATen/native/cuda/RowwiseScaledMM.cu b/aten/src/ATen/native/cuda/RowwiseScaledMM.cu index 6c3e7075bc6a8f..d8b18bce85541e 100644 --- a/aten/src/ATen/native/cuda/RowwiseScaledMM.cu +++ b/aten/src/ATen/native/cuda/RowwiseScaledMM.cu @@ -120,6 +120,7 @@ template < typename TileShape, typename ClusterShape, typename PingPong, + typename Transposed, typename FastAccum, typename DtypeA, typename DtypeB, @@ -150,7 +151,10 @@ void f8f8bf16_rowwise_impl( using LayoutInputB = cutlass::layout::ColumnMajor; constexpr int AlignmentInputB = 16 / sizeof(DtypeB); - using LayoutOutput = cutlass::layout::RowMajor; + using LayoutOutput = std::conditional_t< + Transposed::value, + cutlass::layout::ColumnMajor, + cutlass::layout::RowMajor>; constexpr int AlignmentOutput = 16 / sizeof(DtypeOutput); // Tag indicating the minimum SM that supports the intended feature @@ -167,8 +171,12 @@ void f8f8bf16_rowwise_impl( using WScale = cutlass::epilogue::fusion:: Sm90RowBroadcast; - using Bias = cutlass::epilogue::fusion:: - Sm90RowBroadcast; + using Bias = std::conditional_t< + Transposed::value, + cutlass::epilogue::fusion:: + Sm90ColBroadcast, + cutlass::epilogue::fusion:: + Sm90RowBroadcast>; using Accum = cutlass::epilogue::fusion::Sm90AccFetch; @@ -284,6 +292,45 @@ void f8f8bf16_rowwise_impl( C10_CUDA_KERNEL_LAUNCH_CHECK(); } +template < + typename TileShape, + typename ClusterShape, + typename PingPong, + typename Transposed, + typename FastAccum, + typename DtypeA, + typename DtypeB, + typename DtypeBias> +void handle_transposition( + at::Tensor XQ, + at::Tensor WQ, + at::Tensor x_scale, + at::Tensor w_scale, + std::optional bias, + at::Tensor out) { + if constexpr (!Transposed::value) { + f8f8bf16_rowwise_impl< + TileShape, + ClusterShape, + PingPong, + Transposed, + FastAccum, + DtypeA, + DtypeB, + DtypeBias>(XQ, WQ, x_scale, w_scale, bias, out); + } else { + f8f8bf16_rowwise_impl< + TileShape, + ClusterShape, + PingPong, + Transposed, + FastAccum, + DtypeB, + DtypeA, + DtypeBias>(WQ.t(), XQ.t(), w_scale.t(), x_scale.t(), bias, out.t()); + } +} + // FP8 Rowwise Cutlass kernel dispatch. enum class KernelMode { Small, Large, Default }; @@ -314,22 +361,25 @@ void dispatch_fp8_rowwise_kernel_on_tile_size( at::Tensor out) { KernelMode kernel = get_kernel_mode(XQ, WQ); if (kernel == KernelMode::Small) { - return f8f8bf16_rowwise_impl< + return handle_transposition< /*TileShape=*/cute::Shape, /*ClusterShape=*/cute::Shape, /*PingPong=*/std::false_type, + /*Transposed=*/std::false_type, Types...>(XQ, WQ, x_scale, w_scale, bias, out); } else if (kernel == KernelMode::Large) { - return f8f8bf16_rowwise_impl< + return handle_transposition< /*TileShape=*/cute::Shape, /*ClusterShape=*/cute::Shape, /*PingPong=*/std::true_type, + /*Transposed=*/std::false_type, Types...>(XQ, WQ, x_scale, w_scale, bias, out); } else { - return f8f8bf16_rowwise_impl< + return handle_transposition< /*TileShape=*/cute::Shape, /*ClusterShape=*/cute::Shape, /*PingPong=*/std::false_type, + /*Transposed=*/std::false_type, Types...>(XQ, WQ, x_scale, w_scale, bias, out); } } From d90448c74865e9c92a24d821c9a4e6f4fcce01f4 Mon Sep 17 00:00:00 2001 From: Luca Wehrstedt Date: Fri, 30 Aug 2024 12:52:48 +0000 Subject: [PATCH 0393/1018] [fp8 rowwise] Retune the tile heuristics to increase perf (#134781) I propose a new heuristic function to select tile tile size, cluster size, and transposition given M, N and K. It improves the performance across the board (on average) while remaining simple and relying only on a handful of kernels (to limit build time and binary size). Across the shapes I benchmarked, the new heuristic gives a (geometric) mean speedup of +16.5%. Some shapes worsen, but 98.6% of the shapes retain their old performance (up to 5% to allow for noise) or improve it. ![image](https://github.com/user-attachments/assets/bca30583-ac32-4af6-a4f9-37164bdb2430) I benchmarked on over 5.4k different shapes: - For M and N I swept across all values which are the sums of two powers of 2 (limited to multiples of 64, capped at 16,384) - For K I only used powers of 2 between 1,024 and 8,192 (based on the intuition that the optimal config doesn't depend on K, which turned out to be the case) Here's the detailed speedup for each shape ![image](https://github.com/user-attachments/assets/acac4318-9ee0-455d-861b-c764b8c13d22)
This is the code I used to benchmark ``` import torch import torch.utils.benchmark s = set() for i in range(6, 15): s.add(2**i) for j in range(6, i): s.add(2**i + 2**j) ms = [i for i in sorted(s) if i <= 2**14] ns = [i for i in sorted(s) if i <= 2**14] ks = [2**i for i in range(10, 14)] def make_graph(n_iters, f): g = torch.cuda.CUDAGraph() with torch.cuda.graph(g): for _ in range(n_iters): f() return g def rowwise_scale(t, dtype_t): min_v, max_v = torch.finfo(dtype_t).min, torch.finfo(dtype_t).max scale_t = torch.clamp(t.abs().amax(dim=-1, keepdim=True).float(), min=1e-12) / max_v t_fp8 = (t / scale_t).clamp(min=min_v, max=max_v).to(dtype_t) return t_fp8, scale_t for m in ms: for n in ns: for k in ks: a = torch.randn((m, k), device="cuda", dtype=torch.float) b_t = torch.randn((n, k), device="cuda", dtype=torch.float) a_fp8, scale_a = rowwise_scale(a, torch.float8_e4m3fn) b_t_fp8, scale_b_t = rowwise_scale(b_t, torch.float8_e4m3fn) func = lambda: torch._scaled_mm( a_fp8, b_t_fp8.t(), scale_a=scale_a, scale_b=scale_b_t.t(), bias=None, use_fast_accum=True, out_dtype=torch.bfloat16 ) print(f"{m=},{n=},{k=}") print(torch.utils.benchmark.Timer("g.replay()", globals={"g": make_graph(1000, func)}).blocked_autorange(min_run_time=1).mean / 1000) ```
This is the code I used for the plots ``` from itertools import islice import pandas as pd import matplotlib.pyplot as plt from matplotlib.cm import ScalarMappable from matplotlib.colors import FuncNorm from mpl_toolkits.axes_grid1 import ImageGrid def batched(iterable, n): iterator = iter(iterable) while batch := tuple(islice(iterator, n)): yield batch def try_to_convert(v): if v == "False": return False if v == "True": return True return int(v) def get_from_paste(filename): text = open(filename, "rt").read() headers = [] data = [] for config, value in batched(text.splitlines(), 2): config_elems = config.split(",") if not headers: headers = [e.partition("=")[0] for e in config_elems] data.append((*(try_to_convert(e.partition("=")[-1]) for e in config_elems), float(value))) return pd.DataFrame(data, columns=headers + ["latency"]) old_latencies = get_from_paste(...) new_latencies = get_from_paste(...) ratios = pd.merge(new_latencies, old_latencies, how="left", left_on=["m", "n", "k"], right_on=["m", "n", "k"], suffixes=("_new", "_old")) ratios = ratios.assign(ratio=ratios.latency_old / ratios.latency_new) fig = plt.figure(figsize=(40.0, 10.0)) grid = ImageGrid( fig, 111, nrows_ncols=(1, 4), axes_pad=0.5, share_all=True, cbar_location="right", cbar_mode="single", cbar_size="7%", cbar_pad=0.15, ) log_amax = np.max(np.abs(np.log(ratios.ratio.to_numpy()))) for K, ax in zip([1024, 2048, 4096, 8192], grid): pivoted = ratios[(ratios.k == K)].pivot_table(index="m", columns="n", values="ratio") im = ax.imshow(np.log(pivoted.to_numpy()), origin="lower", vmin=-log_amax, vmax=log_amax, cmap="PiYG") m_vals, n_vals = pivoted.axes ax.set_xticks(np.arange(len(n_vals)), labels=[f"N={i}" for i in n_vals.values], fontsize=12) ax.set_yticks(np.arange(len(m_vals)), labels=[f"M={i}" for i in m_vals.values], fontsize=12) plt.setp(ax.get_xticklabels(), rotation=90, ha="right", rotation_mode="anchor") ax.grid(False) ax.set_title(f"K={K}", fontsize=20) norm = FuncNorm((lambda x: np.log(x), lambda x: np.exp(x)), np.exp(-log_amax), np.exp(log_amax)) ax.cax.colorbar(ScalarMappable(norm=norm, cmap="PiYG")) plt.show() counts, bins = np.histogram(np.log(ratios.ratio.to_numpy()), bins=500) plt.stairs(counts, np.exp(bins), fill=True) plt.xscale("function", functions=(lambda x: np.log(x), lambda x: np.exp(x))) ```
I only benchmarked fast_accum=True and out_dtype=torch.bfloat16 supposing that these are the most commonly-used flags (e.g., with fast_accum=False row-wise scaling is much slower than tensor-wise scaling hence unpractical). Pull Request resolved: https://github.com/pytorch/pytorch/pull/134781 Approved by: https://github.com/drisspg, https://github.com/eqy ghstack dependencies: #134773 --- aten/src/ATen/native/cuda/RowwiseScaledMM.cu | 161 +++++++++++++------ 1 file changed, 115 insertions(+), 46 deletions(-) diff --git a/aten/src/ATen/native/cuda/RowwiseScaledMM.cu b/aten/src/ATen/native/cuda/RowwiseScaledMM.cu index d8b18bce85541e..f76d6bfb66a727 100644 --- a/aten/src/ATen/native/cuda/RowwiseScaledMM.cu +++ b/aten/src/ATen/native/cuda/RowwiseScaledMM.cu @@ -69,6 +69,8 @@ static CUresult CUDAAPI nvrtc_cuTensorMapEncodeTiled( namespace { +constexpr int kNumSMsForH100 = 132; + using DtypeScale = float; using DtypeAccum = float; using DtypeEpilogue = float; @@ -115,6 +117,14 @@ struct Schedule { using type = cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum; }; +int ceildiv(int a, int b) { + return (a + b - 1) / b; +} + +int round_up_to_nearest_multiple(int a, int b) { + return ceildiv(a, b) * b; +} + // Cutlass rowwise kernel template < typename TileShape, @@ -292,10 +302,39 @@ void f8f8bf16_rowwise_impl( C10_CUDA_KERNEL_LAUNCH_CHECK(); } +template +void dispatch_fp8_rowwise_kernel_on_tile_size( + at::Tensor XQ, + at::Tensor WQ, + at::Tensor x_scale, + at::Tensor w_scale, + std::optional bias, + at::Tensor out) { + int M = XQ.size(0); + int N = WQ.size(1); + + // We prefer to use smaller tiles (less wasted compute in case of padding), + // but if this causes us to have more CUDA blocks than there are SMs on the + // GPU then we'll hit wave quantization, hence we'll switch to larger tiles. + if (ceildiv(M, 64 * cute::get<0>(ClusterShape{})) * + ceildiv(N, 128 * cute::get<1>(ClusterShape{})) <= + kNumSMsForH100 / cute::size(ClusterShape{})) { + return f8f8bf16_rowwise_impl< + /*TileShape=*/cute::Shape, + ClusterShape, + /*PingPong=*/std::false_type, + Types...>(XQ, WQ, x_scale, w_scale, bias, out); + } else { + return f8f8bf16_rowwise_impl< + /*TileShape=*/cute::Shape, + ClusterShape, + /*PingPong=*/std::true_type, + Types...>(XQ, WQ, x_scale, w_scale, bias, out); + } +} + template < - typename TileShape, typename ClusterShape, - typename PingPong, typename Transposed, typename FastAccum, typename DtypeA, @@ -309,20 +348,16 @@ void handle_transposition( std::optional bias, at::Tensor out) { if constexpr (!Transposed::value) { - f8f8bf16_rowwise_impl< - TileShape, + dispatch_fp8_rowwise_kernel_on_tile_size< ClusterShape, - PingPong, Transposed, FastAccum, DtypeA, DtypeB, DtypeBias>(XQ, WQ, x_scale, w_scale, bias, out); } else { - f8f8bf16_rowwise_impl< - TileShape, + dispatch_fp8_rowwise_kernel_on_tile_size< ClusterShape, - PingPong, Transposed, FastAccum, DtypeB, @@ -331,57 +366,89 @@ void handle_transposition( } } -// FP8 Rowwise Cutlass kernel dispatch. -enum class KernelMode { Small, Large, Default }; - -KernelMode get_kernel_mode(at::Tensor XQ, at::Tensor WQ) { - auto M = XQ.size(0); - auto K = XQ.size(1); - auto N = WQ.size(0); - // Use a large kernel if at least two shapes are large.... - bool use_large_kernel = - ((M >= 2048 && K >= 2048) || (M >= 2048 && N >= 2048) || - (K >= 2048 && N >= 2048)); - if (M <= 128 || N <= 128) { - return KernelMode::Small; - } else if (use_large_kernel) { - return KernelMode::Large; - } else { - return KernelMode::Default; - } -} - template -void dispatch_fp8_rowwise_kernel_on_tile_size( +void dispatch_fp8_rowwise_kernel_on_cluster_size_and_transpose( at::Tensor XQ, at::Tensor WQ, at::Tensor x_scale, at::Tensor w_scale, std::optional bias, at::Tensor out) { - KernelMode kernel = get_kernel_mode(XQ, WQ); - if (kernel == KernelMode::Small) { + int M = XQ.size(0); + int N = WQ.size(1); + + // All the tiles we use have sizes which are multiples of 64, hence any + // non-multiple of 64 will get padded anyways. Let's round up to simplify. + M = round_up_to_nearest_multiple(M, 64); + N = round_up_to_nearest_multiple(N, 64); + + // Small/skinny shapes with odd multiples of 64. + if (M == 64 && N >= 3072) { return handle_transposition< - /*TileShape=*/cute::Shape, - /*ClusterShape=*/cute::Shape, - /*PingPong=*/std::false_type, + /*ClusterShape=*/cute::Shape, /*Transposed=*/std::false_type, Types...>(XQ, WQ, x_scale, w_scale, bias, out); - } else if (kernel == KernelMode::Large) { + } + if (N == 64 && M >= 3072) { return handle_transposition< - /*TileShape=*/cute::Shape, - /*ClusterShape=*/cute::Shape, - /*PingPong=*/std::true_type, - /*Transposed=*/std::false_type, + /*ClusterShape=*/cute::Shape, + /*Transposed=*/std::true_type, Types...>(XQ, WQ, x_scale, w_scale, bias, out); - } else { + } + if (M == 192 && N >= 4096) { + return handle_transposition< + /*ClusterShape=*/cute::Shape, + /*Transposed=*/std::true_type, + Types...>(XQ, WQ, x_scale, w_scale, bias, out); + } + if (N == 192 && M >= 4096) { return handle_transposition< - /*TileShape=*/cute::Shape, /*ClusterShape=*/cute::Shape, - /*PingPong=*/std::false_type, /*Transposed=*/std::false_type, Types...>(XQ, WQ, x_scale, w_scale, bias, out); } + + // Now to odd multiples of 128 (but only if not too large). + if (M * N <= 4096 * 4096) { + if (M % 256 > 0 && N % 256 == 0) { + return handle_transposition< + /*ClusterShape=*/cute::Shape, + /*Transposed=*/std::true_type, + Types...>(XQ, WQ, x_scale, w_scale, bias, out); + } + if (N % 256 > 0 && M % 256 == 0) { + return handle_transposition< + /*ClusterShape=*/cute::Shape, + /*Transposed=*/std::false_type, + Types...>(XQ, WQ, x_scale, w_scale, bias, out); + } + } + if (M % 256 > 0 && N % 256 > 0) { + if ((M <= N) ^ (M * N <= 1024 * 1024)) { + return handle_transposition< + /*ClusterShape=*/cute::Shape, + /*Transposed=*/std::true_type, + Types...>(XQ, WQ, x_scale, w_scale, bias, out); + } else { + return handle_transposition< + /*ClusterShape=*/cute::Shape, + /*Transposed=*/std::false_type, + Types...>(XQ, WQ, x_scale, w_scale, bias, out); + } + } + + // General case for large tensors. + if ((M <= N) ^ (M >= 2048 && N >= 2048)) { + return handle_transposition< + /*ClusterShape=*/cute::Shape, + /*Transposed=*/std::true_type, + Types...>(XQ, WQ, x_scale, w_scale, bias, out); + } else { + return handle_transposition< + /*ClusterShape=*/cute::Shape, + /*Transposed=*/std::true_type, + Types...>(XQ, WQ, x_scale, w_scale, bias, out); + } } template @@ -394,11 +461,13 @@ void dispatch_fp8_rowwise_kernel_on_fast_accum( bool use_fast_accum, at::Tensor out) { if (use_fast_accum) { - dispatch_fp8_rowwise_kernel_on_tile_size( - XQ, WQ, x_scale, w_scale, bias, out); + dispatch_fp8_rowwise_kernel_on_cluster_size_and_transpose< + std::true_type, + Types...>(XQ, WQ, x_scale, w_scale, bias, out); } else { - dispatch_fp8_rowwise_kernel_on_tile_size( - XQ, WQ, x_scale, w_scale, bias, out); + dispatch_fp8_rowwise_kernel_on_cluster_size_and_transpose< + std::false_type, + Types...>(XQ, WQ, x_scale, w_scale, bias, out); } } From 28ee879ef09a55355d540c77a1de39baf34fb8b5 Mon Sep 17 00:00:00 2001 From: FFFrog Date: Wed, 4 Sep 2024 07:53:09 +0000 Subject: [PATCH 0394/1018] Moving _run_autocast_outofplace to basic class named TestAutocast to reduce redundance (#134460) Pull Request resolved: https://github.com/pytorch/pytorch/pull/134460 Approved by: https://github.com/EikanWang, https://github.com/ezyang --- test/test_autocast.py | 142 +- test/test_cuda.py | 4454 ++++++++--------- test/test_xpu.py | 107 +- .../testing/_internal/autocast_test_lists.py | 104 + 4 files changed, 2377 insertions(+), 2430 deletions(-) diff --git a/test/test_autocast.py b/test/test_autocast.py index 4c3f031a88c153..8e702f92296c13 100644 --- a/test/test_autocast.py +++ b/test/test_autocast.py @@ -1,10 +1,12 @@ # Owner(s): ["module: unknown"] -import collections import unittest import torch -from torch.testing._internal.autocast_test_lists import AutocastCPUTestLists +from torch.testing._internal.autocast_test_lists import ( + AutocastCPUTestLists, + TestAutocast, +) from torch.testing._internal.common_utils import ( IS_WINDOWS, run_tests, @@ -14,7 +16,7 @@ from torch.utils._python_dispatch import TorchDispatchMode -class TestAutocastCPU(TestCase): +class TestAutocastCPU(TestAutocast): def setUp(self): super().setUp() self.autocast_lists = AutocastCPUTestLists(torch.device("cpu")) @@ -23,100 +25,6 @@ def tearDown(self): del self.autocast_lists super().tearDown() - def _run_autocast_outofplace( - self, - op, - args, - run_as_type, - out_type=None, - module=torch, - add_kwargs=None, - amp_dtype=torch.bfloat16, - ): - # helper to cast args - def cast(val, to_type): - if isinstance(val, torch.Tensor): - return val.to(to_type) if val.is_floating_point() else val - elif isinstance(val, collections.abc.Iterable): - return type(val)(cast(v, to_type) for v in val) - else: - return val - - if add_kwargs is None: - add_kwargs = {} - - self.assertFalse(torch.is_autocast_enabled(device_type="cpu")) - with torch.amp.autocast(device_type="cpu", dtype=amp_dtype): - self.assertTrue(torch.is_autocast_enabled(device_type="cpu")) - out_type = out_type if out_type is not None else run_as_type - output = output_method = None - - # Try module.* variant, if requested: - if module is not None and hasattr(module, op): - output = getattr(module, op)(*args, **add_kwargs) - if isinstance(output, torch.Tensor): - self.assertTrue( - out_type == output.dtype, - f"autocast for torch.{op} produced {output.dtype}, should produce {out_type}", - ) - # Try Tensor.* variant: - if hasattr(torch.Tensor, op): - output_method = getattr(args[0], op)(*args[1:], **add_kwargs) - if isinstance(output_method, torch.Tensor): - self.assertTrue( - out_type == output_method.dtype, - f"autocast for torch.{op} produced {output_method.dtype}, should produce torch.{out_type}", - ) - - self.assertTrue( - (output is not None) or (output_method is not None), - f"{op} not found as an attribute on either Tensor or the requested module {module}", - ) - - # Accounts for ops that return Tensors, iterables, and other non-Tensors. - # For example, lstm_cell returns a tuple and equal returns bool. - def compare(first, second): - if isinstance(first, torch.Tensor): - return torch.equal(first, second) - elif isinstance(first, collections.abc.Iterable): - return all(compare(f, s) for f, s in zip(first, second)) - else: - return first == second - - # If both torch.* and Tensor.* variants were found, check outputs are identical - if (output is not None) and (output_method is not None): - self.assertTrue(type(output) == type(output_method)) - comparison = compare(output, output_method) - self.assertTrue( - comparison, f"torch.{op} result did not match Tensor.{op} result" - ) - - # Compare numerics to Python-side "autocasting" that (we expect) does the same thing - # as the C++-side autocasting, and should be bitwise accurate. - output_to_compare = output if output is not None else output_method - with torch.amp.autocast(device_type="cpu", enabled=False): - self.assertFalse(torch.is_autocast_enabled(device_type="cpu")) - - if module is not None and hasattr(module, op): - control = getattr(module, op)( - *cast(args, run_as_type), **add_kwargs - ) - else: - control = getattr(args[0].to(run_as_type), op)( - *cast(args[1:], run_as_type), **add_kwargs - ) - self.assertTrue(type(output_to_compare) == type(control)) - comparison = compare(output_to_compare, control) - self.assertTrue(comparison, f"torch.{op} result did not match control") - self.assertTrue(torch.is_autocast_enabled(device_type="cpu")) - self.assertFalse(torch.is_autocast_enabled(device_type="cpu")) - - def args_maybe_kwargs(self, op_with_args): - if len(op_with_args) == 2: - return op_with_args[0], op_with_args[1], {} - else: - return op_with_args[0], op_with_args[1], op_with_args[2] - @skipIfTorchDynamo() def test_autocast_torch_expect_builtin_promote(self): for ( @@ -125,9 +33,16 @@ def test_autocast_torch_expect_builtin_promote(self): args2, out_type, ) in self.autocast_lists.torch_expect_builtin_promote: - self._run_autocast_outofplace(op, args1, torch.float32, out_type=out_type) self._run_autocast_outofplace( - op, args2, torch.float32, out_type=out_type, amp_dtype=torch.float16 + op, args1, torch.float32, device="cpu", out_type=out_type + ) + self._run_autocast_outofplace( + op, + args2, + torch.float32, + device="cpu", + out_type=out_type, + amp_dtype=torch.float16, ) @skipIfTorchDynamo() @@ -139,12 +54,13 @@ def test_autocast_methods_expect_builtin_promote(self): out_type, ) in self.autocast_lists.methods_expect_builtin_promote: self._run_autocast_outofplace( - op, args1, torch.float32, module=None, out_type=out_type + op, args1, torch.float32, device="cpu", module=None, out_type=out_type ) self._run_autocast_outofplace( op, args2, torch.float32, + device="cpu", module=None, out_type=out_type, amp_dtype=torch.float16, @@ -155,12 +71,13 @@ def test_autocast_torch_16(self): for op_with_args in self.autocast_lists.torch_16: op, args, maybe_kwargs = self.args_maybe_kwargs(op_with_args) self._run_autocast_outofplace( - op, args, torch.bfloat16, add_kwargs=maybe_kwargs + op, args, torch.bfloat16, device="cpu", add_kwargs=maybe_kwargs ) self._run_autocast_outofplace( op, args, torch.float16, + device="cpu", add_kwargs=maybe_kwargs, amp_dtype=torch.float16, ) @@ -170,12 +87,18 @@ def test_autocast_nn_16(self): for op_with_args in self.autocast_lists.nn_16: op, args, maybe_kwargs = self.args_maybe_kwargs(op_with_args) self._run_autocast_outofplace( - op, args, torch.bfloat16, module=torch._C._nn, add_kwargs=maybe_kwargs + op, + args, + torch.bfloat16, + device="cpu", + module=torch._C._nn, + add_kwargs=maybe_kwargs, ) self._run_autocast_outofplace( op, args, torch.float16, + device="cpu", module=torch._C._nn, add_kwargs=maybe_kwargs, amp_dtype=torch.float16, @@ -186,12 +109,13 @@ def test_autocast_torch_fp32(self): for op_with_args in self.autocast_lists.torch_fp32: op, args, maybe_kwargs = self.args_maybe_kwargs(op_with_args) self._run_autocast_outofplace( - op, args, torch.float32, add_kwargs=maybe_kwargs + op, args, torch.float32, device="cpu", add_kwargs=maybe_kwargs ) self._run_autocast_outofplace( op, args, torch.float32, + device="cpu", add_kwargs=maybe_kwargs, amp_dtype=torch.float16, ) @@ -201,12 +125,18 @@ def test_autocast_nn_fp32(self): for op_with_args in self.autocast_lists.nn_fp32: op, args, maybe_kwargs = self.args_maybe_kwargs(op_with_args) self._run_autocast_outofplace( - op, args, torch.float32, module=torch._C._nn, add_kwargs=maybe_kwargs + op, + args, + torch.float32, + device="cpu", + module=torch._C._nn, + add_kwargs=maybe_kwargs, ) self._run_autocast_outofplace( op, args, torch.float32, + device="cpu", module=torch._C._nn, add_kwargs=maybe_kwargs, amp_dtype=torch.float16, @@ -215,9 +145,9 @@ def test_autocast_nn_fp32(self): @skipIfTorchDynamo() def test_autocast_torch_need_autocast_promote(self): for op, args1, args2 in self.autocast_lists.torch_need_autocast_promote: - self._run_autocast_outofplace(op, args1, torch.float32) + self._run_autocast_outofplace(op, args1, torch.float32, device="cpu") self._run_autocast_outofplace( - op, args2, torch.float32, amp_dtype=torch.float16 + op, args2, torch.float32, device="cpu", amp_dtype=torch.float16 ) @unittest.skipIf(IS_WINDOWS, "Limit support for bf16 path") diff --git a/test/test_cuda.py b/test/test_cuda.py index c6a61e1218ecd2..2195e71d6efc1a 100644 --- a/test/test_cuda.py +++ b/test/test_cuda.py @@ -1,6 +1,5 @@ # Owner(s): ["module: cuda"] -import collections import contextlib import ctypes import gc @@ -29,7 +28,7 @@ segment_plot, trace_plot, ) -from torch.testing._internal.autocast_test_lists import AutocastTestLists +from torch.testing._internal.autocast_test_lists import AutocastTestLists, TestAutocast from torch.testing._internal.common_cuda import ( _create_scaling_case, _get_torch_cuda_version, @@ -61,7 +60,6 @@ IS_WINDOWS, load_tests, NO_MULTIPROCESSING_SPAWN, - NoTest, parametrize, run_tests, serialTest, @@ -85,10 +83,6 @@ # sharding on sandcastle. This line silences flake warnings load_tests = load_tests -if not TEST_CUDA: - print("CUDA not available, skipping tests", file=sys.stderr) - TestCase = NoTest # noqa: F811 - try: import torchvision.models # noqa: F401 from torchvision.models import resnet18 # noqa: F401 @@ -113,6 +107,7 @@ _cycles_per_ms = None +@unittest.skipIf(not TEST_CUDA, "CUDA not available, skipping tests") @torch.testing._internal.common_utils.markDynamoStrictTest class TestCuda(TestCase): _do_cuda_memory_leak_check = True @@ -121,10 +116,8 @@ class TestCuda(TestCase): def setUp(self): super().setUp() - self.autocast_lists = AutocastTestLists(torch.device("cuda:0")) def tearDown(self): - del self.autocast_lists super().tearDown() @property @@ -1558,2554 +1551,2024 @@ def _worker(t): for t in range(num_threads): self.assertEqual(results[t].sum().item(), size * size) - def _run_autocast_outofplace( - self, op, args, run_as_type, out_type=None, module=torch, add_kwargs=None - ): - # helper to cast args - def cast(val, to_type): - if isinstance(val, torch.Tensor): - return val.to(to_type) if val.is_floating_point() else val - elif isinstance(val, collections.abc.Iterable): - return type(val)(cast(v, to_type) for v in val) - else: - return val - - if add_kwargs is None: - add_kwargs = {} - fast_dtype = torch.bfloat16 if run_as_type == torch.bfloat16 else torch.float16 - self.assertFalse(torch.is_autocast_enabled()) - with torch.autocast("cuda", dtype=fast_dtype): - self.assertTrue(torch.is_autocast_enabled()) - - out_type = out_type if out_type is not None else run_as_type - output = output_method = None - - # Try module.* variant, if requested: - if module is not None and hasattr(module, op): - output = getattr(module, op)(*args, **add_kwargs) - if isinstance(output, torch.Tensor): - self.assertTrue( - out_type == output.dtype, - f"autocast for torch.{op} produced {output.dtype}, should produce {out_type}", - ) + @slowTest + @unittest.skipIf(not TEST_LARGE_TENSOR, "not enough memory") + @serialTest() + def test_max_large_axis(self): + x = torch.zeros(2**32, device="cuda", dtype=torch.int8) + x[-1] = 1 + val, idx = x.max(0) + self.assertEqual(val, 1) + self.assertEqual(idx, x.shape[0] - 1) - # Try Tensor.* variant: - if hasattr(torch.Tensor, op): - output_method = getattr(args[0], op)(*args[1:], **add_kwargs) - if isinstance(output_method, torch.Tensor): - self.assertTrue( - out_type == output_method.dtype, - f"autocast for torch.{op} produced {output_method.dtype}, should produce torch.{out_type}", - ) + @unittest.skipIf(not TEST_NUMPY, "Numpy not found") + def test_to_numpy(self): + self.assertRaises(TypeError, lambda: torch.empty(1, device="cuda").numpy()) - self.assertTrue( - (output is not None) or (output_method is not None), - f"{op} not found as an attribute on either Tensor or the requested module {module}", - ) + def test_graph_is_current_stream_capturing(self): + self.assertFalse(torch.cuda.is_current_stream_capturing()) - # Accounts for ops that return Tensors, iterables, and other non-Tensors. - # For example, lstm_cell returns a tuple and equal returns bool. - def compare(first, second): - if isinstance(first, torch.Tensor): - return torch.equal(first, second) - elif isinstance(first, collections.abc.Iterable): - return all(compare(f, s) for f, s in zip(first, second)) - else: - return first == second + if TEST_CUDA and (not TEST_WITH_ROCM): + s = torch.cuda.Stream() + with torch.cuda.stream(s): + g = torch.cuda.CUDAGraph() + self.assertFalse(torch.cuda.is_current_stream_capturing()) + g.capture_begin() + self.assertTrue(torch.cuda.is_current_stream_capturing()) + g.capture_end() - # If both torch.* and Tensor.* variants were found, check outputs are identical - if (output is not None) and (output_method is not None): - self.assertTrue(type(output) == type(output_method)) - comparison = compare(output, output_method) - self.assertTrue( - comparison, f"torch.{op} result did not match Tensor.{op} result" - ) + @unittest.skipIf( + not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs" + ) + def test_graph_capture_simple(self): + s = torch.cuda.Stream() - # Compare numerics to Python-side "autocasting" that (we expect) does the same thing - # as the C++-side autocasting, and should be bitwise accurate. - output_to_compare = output if output is not None else output_method - with torch.autocast("cuda", enabled=False): - self.assertFalse(torch.is_autocast_enabled()) + with torch.cuda.stream(s): + a = torch.full((1000,), 1, device="cuda") + g = torch.cuda.CUDAGraph() + torch.cuda.empty_cache() + g.capture_begin() + b = a + for _ in range(10): + b = b + 1 + g.capture_end() + torch.cuda.current_stream().wait_stream(s) - if module is not None and hasattr(module, op): - control = getattr(module, op)( - *cast(args, run_as_type), **add_kwargs - ) - else: - control = getattr(args[0].to(run_as_type), op)( - *cast(args[1:], run_as_type), **add_kwargs - ) - self.assertTrue(type(output_to_compare) == type(control)) - comparison = compare(output_to_compare, control) - self.assertTrue(comparison, f"torch.{op} result did not match control") - self.assertTrue(torch.is_autocast_enabled()) - self.assertFalse(torch.is_autocast_enabled()) - - def args_maybe_kwargs(self, op_with_args): - if len(op_with_args) == 2: - return op_with_args[0], op_with_args[1], {} - else: - return op_with_args[0], op_with_args[1], op_with_args[2] + g.replay() - @unittest.skipIf(not TEST_CUDNN, "CUDNN not available") - def test_autocast_torch_fp16(self): - with torch.backends.cudnn.flags(enabled=True, deterministic=True): - for op_with_args in self.autocast_lists.torch_fp16: - skip_test = False - op, args = op_with_args[0], op_with_args[1] - if len(op_with_args) == 3: - skip_test = op_with_args[2] # TEST_WITH_ROCM - if not skip_test: - self._run_autocast_outofplace(op, args, torch.float16) + self.assertTrue(b.sum().item() == 11000.0) - @unittest.skipIf(not TEST_CUDNN, "CUDNN not available") - def test_autocast_torch_bf16(self): - with torch.backends.cudnn.flags(enabled=True, deterministic=True): - for op_with_args in self.autocast_lists.torch_fp16: - skip_test = False - op, args = op_with_args[0], op_with_args[1] - if len(op_with_args) == 3: - skip_test = op_with_args[2] # TEST_WITH_ROCM - should_error_from_cudnn = "cudnn" in op and ( - "TORCH_CUDNN_V8_API_DISABLED" in os.environ - and int(os.environ["TORCH_CUDNN_V8_API_DISABLED"]) - or torch.cuda.get_device_capability() < (8, 0) - ) - should_error_from_not_implemented = should_error_from_cudnn - if not skip_test: - if should_error_from_not_implemented: - with self.assertRaises( - RuntimeError, - msg=str(op) + " should not be supported for bfloat16!", - ): - self._run_autocast_outofplace(op, args, torch.bfloat16) - else: - if torch.cuda.is_bf16_supported(): - self._run_autocast_outofplace(op, args, torch.bfloat16) - else: - with self.assertRaisesRegex( - RuntimeError, "Device does not support bfloat16" - ): - self._run_autocast_outofplace(op, args, torch.bfloat16) + @unittest.skipIf( + not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs" + ) + def test_graphsafe_set_get_rng_state(self): + # Define a function to create generator states, with optional graph registration + def create_states(generator): + """Initializes generator states and registers them with a CUDA graph if provided.""" + # Ensure the CUDA generator is initialized + torch.rand(1, device="cuda") + generator.manual_seed(0) - @unittest.skipIf(not TEST_CUDNN, "CUDNN not available") - def test_autocast_torch_fp32(self): - for op_with_args in self.autocast_lists.torch_fp32: - op, args, maybe_kwargs = self.args_maybe_kwargs(op_with_args) - self._run_autocast_outofplace( - op, args, torch.float32, add_kwargs=maybe_kwargs - ) + # Save the current state of the generator + old_state = generator.graphsafe_get_state() + # Create and save a cloned state of the generator + new_state = generator.clone_state() + # Return the original generator and its two states + return generator, old_state, new_state - @unittest.skipIf(not TEST_CUDNN, "CUDNN not available") - def test_autocast_torch_need_autocast_promote(self): - for op, args in self.autocast_lists.torch_need_autocast_promote: - self._run_autocast_outofplace(op, args, torch.float32) + def register_states_to_graph(generator_state, graph): + generator, old_state, new_state = generator_state + graph.register_generator_state(old_state) + graph.register_generator_state(new_state) - @unittest.skipIf(not TEST_CUDNN, "CUDNN not available") - def test_autocast_torch_expect_builtin_promote(self): - for op, args, out_type in self.autocast_lists.torch_expect_builtin_promote: - self._run_autocast_outofplace(op, args, torch.float32, out_type=out_type) + # Define a function to perform specific RNG actions using the generator's states + def perform_random_generation_steps(generator_state): + generator, old_state, new_state = generator_state + random_values = [] - @unittest.skipIf(not TEST_CUDNN, "CUDNN not available") - def test_autocast_nn_fp16(self): - with torch.backends.cudnn.flags(enabled=True, deterministic=True): - for op, args in self.autocast_lists.nn_fp16: - self._run_autocast_outofplace( - op, args, torch.float16, module=torch._C._nn - ) + # Generate random numbers with the new generator state + generator.graphsafe_set_state(new_state) + random_values.append(torch.rand(5, device="cuda", generator=generator)) - @unittest.skipIf(not TEST_CUDNN, "CUDNN not available") - def test_autocast_nn_bf16(self): - with torch.backends.cudnn.flags(enabled=True, deterministic=True): - for op, args in self.autocast_lists.nn_fp16: - if torch.cuda.is_bf16_supported(): - self._run_autocast_outofplace( - op, args, torch.bfloat16, module=torch._C._nn - ) - else: - with self.assertRaisesRegex( - RuntimeError, "Device does not support bfloat16" - ): - self._run_autocast_outofplace( - op, args, torch.bfloat16, module=torch._C._nn - ) + # Generate random numbers twice with the old generator state + generator.graphsafe_set_state(old_state) + random_values.extend( + [torch.rand(5, device="cuda", generator=generator) for _ in range(2)] + ) - @unittest.skipIf(not TEST_CUDNN, "CUDNN not available") - def test_autocast_nn_fp32(self): - for op, args in self.autocast_lists.nn_fp32: - self._run_autocast_outofplace(op, args, torch.float32, module=torch._C._nn) + return random_values - @unittest.skipIf(not TEST_CUDNN, "CUDNN not available") - def test_autocast_linalg_fp16(self): - with torch.backends.cudnn.flags(enabled=True, deterministic=True): - for op, args in self.autocast_lists.linalg_fp16: - self._run_autocast_outofplace( - op, args, torch.float16, module=torch._C._linalg - ) + # Define a function to retrieve the final offsets of the original and new generator states + def get_final_offsets_of_states(generator_state): + generator, old_state, new_state = generator_state + old_state_offset = old_state.get_offset() + new_state_offset = new_state.get_offset() + return old_state_offset, new_state_offset - @unittest.skipIf(not TEST_CUDNN, "CUDNN not available") - def test_autocast_methods_fp16(self): - with torch.backends.cudnn.flags(enabled=True, deterministic=True): - for op, args in self.autocast_lists.methods_fp16: - self._run_autocast_outofplace(op, args, torch.float16, module=None) + # Set up and test a new CUDA generator + generator = torch.Generator(device="cuda") + generator_state = create_states(generator) - @unittest.skipIf(not TEST_CUDNN, "CUDNN not available") - def test_autocast_methods_fp32(self): - for op, args in self.autocast_lists.methods_fp32: - self._run_autocast_outofplace(op, args, torch.float32, module=None) + # Set up and test the default CUDA generator with a CUDA Graph + g = torch.cuda.CUDAGraph() + s = torch.cuda.Stream() + default_generator = torch.cuda.default_generators[0] + default_generator_state = create_states(default_generator) + register_states_to_graph(default_generator_state, g) - @unittest.skipIf(not TEST_CUDNN, "CUDNN not available") - def test_autocast_methods_expect_builtin_promote(self): - for op, args, out_type in self.autocast_lists.methods_expect_builtin_promote: - self._run_autocast_outofplace( - op, args, torch.float32, module=None, out_type=out_type + # Perform random number generation within a CUDA graph + with torch.cuda.stream(s): + g.capture_begin() + graphed_random_values = perform_random_generation_steps( + default_generator_state ) + g.capture_end() - def test_autocast_banned(self): - with torch.autocast("cuda"): - for op, args, module in self.autocast_lists.banned: - with self.assertRaises(RuntimeError): - getattr(module, op)(*args) - - def test_autocast_ignored_types(self): - with torch.autocast("cuda"): - for ignore_type in (torch.double, torch.int32): - a_ignore = torch.ones((8, 8), dtype=ignore_type, device="cuda:0") - b_ignore = torch.ones((8, 8), dtype=ignore_type, device="cuda:0") - c_16 = torch.ones((8, 8), dtype=torch.float16, device="cuda:0") + # Synchronize the streams and replay the graph + torch.cuda.current_stream().wait_stream(s) + for _ in range(3): + random_values = perform_random_generation_steps(generator_state) + g.replay() + offset = get_final_offsets_of_states(generator_state) + graph_offset = get_final_offsets_of_states(default_generator_state) - # Tests if CastPolicy::fp16 ops ignore double and int - # Currently, no ops belonging to this policy support integer inputs. - if ignore_type is torch.double: - with self.assertRaises(RuntimeError): - torch.mm(a_ignore, c_16) - with torch.autocast("cuda", enabled=False): - type_no_autocast = torch.mm(a_ignore, b_ignore).dtype - self.assertTrue( - torch.mm(a_ignore, b_ignore).dtype is type_no_autocast - ) + # Compare the final offsets of states for both generators to ensure consistency + self.assertTrue(offset == graph_offset) + # Compare the states generated outside and inside the graph + self.assertEqual(random_values, graphed_random_values) - # Tests if CastPolicy::fp32 ops ignore double and int - with torch.autocast("cuda", enabled=False): - type_no_autocast = torch.pow(a_ignore, 2.0).dtype - self.assertTrue(torch.pow(a_ignore, 2.0).dtype is type_no_autocast) + @unittest.skipIf( + not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs" + ) + def test_memory_stats_of_multiple_generators_and_graphs(self): + # Function to clear CUDA cache and collect garbage + def clear_cuda_cache(): + gc.collect() + torch.cuda.empty_cache() - # Tests if CastPolicy::fp32_set_opt_dtype ops ignore double and int - with torch.autocast("cuda", enabled=False): - type_no_autocast = torch.sum(a_ignore).dtype - self.assertTrue(torch.sum(a_ignore).dtype is type_no_autocast) + # Executes a simple graph task which includes capturing and executing a random number generation within a CUDA graph. + def simple_graph_task(graph): + s = torch.cuda.Stream() + with torch.cuda.stream(s): + graph.capture_begin() + torch.rand(1, device="cuda") + graph.capture_end() + torch.cuda.current_stream().wait_stream(s) + graph.replay() # Replays the captured operations - # Tests if CastPolicy::fp32_append_dtype ops ignore double and int - # Currently, no ops belonging to this policy support integer inputs. - if ignore_type is torch.double: - with torch.autocast("cuda", enabled=False): - type_no_autocast = torch.norm(a_ignore).dtype - self.assertTrue(torch.norm(a_ignore).dtype is type_no_autocast) - - def test_autocast_custom_enabled(self): - class MyMM(torch.autograd.Function): - @staticmethod - @torch.amp.custom_fwd(device_type="cuda") - def forward(ctx, a, b): - self.assertTrue(a.dtype is torch.float32) - self.assertTrue(b.dtype is torch.float32) - self.assertTrue(torch.is_autocast_enabled()) - ctx.save_for_backward(a, b) - return a.mm(b) + def get_memory_stats(): + stats = torch.cuda.memory_stats() + num_blocks = stats["active.all.current"] + total_size = stats["active_bytes.all.current"] + return num_blocks, total_size - @staticmethod - @torch.amp.custom_bwd(device_type="cuda") - def backward(ctx, grad): - self.assertTrue(torch.is_autocast_enabled()) - a, b = ctx.saved_tensors - a_grad, b_grad = grad.mm(b.t()), a.t().mm(grad) - self.assertTrue(a_grad.dtype is dtype and b_grad.dtype is dtype) - return a_grad, b_grad + def test(num_graphs, num_generators): + baseline = get_memory_stats() + baseline_num_blocks, baseline_total_size = baseline - mymm = MyMM.apply + # Allocate CUDA graphs + graphs = [torch.cuda.CUDAGraph() for _ in range(num_graphs)] - x = torch.randn((8, 8), device="cuda", dtype=torch.float32, requires_grad=True) - y = torch.randn((8, 8), device="cuda", dtype=torch.float32, requires_grad=True) + # Allocate and manage generator states + default_generator = torch.cuda.default_generators[0] + generators = [default_generator.graphsafe_get_state()] - dtypes = (torch.float16, torch.bfloat16) if TEST_BF16 else (torch.float16,) - for dtype in dtypes: - with torch.cuda.amp.autocast(dtype=dtype): - output = mymm(x, y) - self.assertTrue(output.dtype is dtype) - loss = output.sum() - loss.backward() + # Starts from 1 as one state is already added + for _ in range(1, num_generators): + generators.append(default_generator.clone_state()) - def test_autocast_custom_cast_inputs(self): - class MyMM(torch.autograd.Function): - @staticmethod - @torch.amp.custom_fwd(device_type="cuda", cast_inputs=torch.float32) - def forward(ctx, a, container, expect_type): - b = container[1][0] - self.assertTrue(a.dtype is expect_type) - self.assertTrue(b.dtype is expect_type) - self.assertFalse(torch.is_autocast_enabled()) - ctx.save_for_backward(a, b) - return a.mm(b) + for graph in graphs: + for generator_state in generators: + graph.register_generator_state(generator_state) + simple_graph_task(graph) - @staticmethod - @torch.amp.custom_bwd(device_type="cuda") - def backward(ctx, grad): - self.assertFalse(torch.is_autocast_enabled()) - a, b = ctx.saved_tensors - return grad.mm(b.t()), None, None + # Assert conditions after graph tasks + num_blocks, total_size = get_memory_stats() + # The allocated blocks should only be proportional to the number of generators + expected_blocks_diff = 2 * num_generators + expected_size_diff = 2 * 512 * num_generators # Each block's size is 512 - mymm = MyMM.apply + self.assertTrue( + (num_blocks - baseline_num_blocks) == expected_blocks_diff, + "Unexpected number of active blocks.", + ) + self.assertTrue( + (total_size - baseline_total_size) == expected_size_diff, + "Unexpected total memory size.", + ) - x = torch.randn((8, 8), device="cuda", dtype=torch.float16, requires_grad=True) - # Puts one input tensor in a nested container. y's contained Tensor won't receive a gradient, - # because torch.autograd.Function can't hand gradients back to non-Tensor forward arguments. - # Sets requires_grad=False explicitly so we don't lie about expecting a gradient. - y = ( - 0, - { - 0: torch.randn( - (8, 8), device="cuda", dtype=torch.float16, requires_grad=False - ) - }, - ) + # Cleanup graphs and clear CUDA cache + while graphs: + graph = graphs.pop() + del graph + clear_cuda_cache() - with torch.autocast("cuda"): - output = mymm(x, y, torch.float32) - self.assertTrue(output.dtype is torch.float32) - loss = output.sum() - loss.backward() + # Assert that memory stats return to baseline after cleanup + self.assertTrue( + get_memory_stats() == baseline, + "Memory stats do not match baseline after cleanup.", + ) - # Tests if custom_fwd becomes a no-op when mymm runs outside an autocast-enabled region. - output = mymm(x, y, torch.float16) - self.assertTrue(output.dtype is torch.float16) - loss = output.sum() - loss.backward() + # Running the test function with different parameters + test(1, 1) + test(3, 2) + test(10, 20) - def test_autocast_custom_deprecated_warning(self): - with warnings.catch_warnings(record=True) as w: + @unittest.skipIf( + not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs" + ) + def test_graph_capture_reset_recapture(self): + s = torch.cuda.Stream() - class MyMM(torch.autograd.Function): - @staticmethod - @torch.cuda.amp.custom_fwd(cast_inputs=torch.float32) - def forward(ctx, x, y): - ctx.save_for_backward(x, y) - self.assertFalse(torch.is_autocast_enabled()) - return x + y + with torch.cuda.stream(s): + a = torch.full((1000,), 1, device="cuda") + g = torch.cuda.CUDAGraph() + torch.cuda.empty_cache() + g.capture_begin() + b = a + for _ in range(10): + b = b + 1 + g.capture_end() + torch.cuda.current_stream().wait_stream(s) - @staticmethod - @torch.cuda.amp.custom_bwd - def backward(ctx, grad): - _, _ = ctx.saved_tensors - self.assertFalse(torch.is_autocast_enabled()) - return grad, grad + g.replay() - self.assertRegex( - str(w[0].message), r"`torch.cuda.amp.custom_fwd\(args...\)` is deprecated." - ) - self.assertRegex( - str(w[1].message), r"`torch.cuda.amp.custom_bwd\(args...\)` is deprecated." - ) + self.assertTrue(b.sum().item() == 11000.0) - mymm = MyMM.apply - x = torch.randn(3, 3, requires_grad=True) - y = torch.randn(3, 3, requires_grad=True) - with torch.amp.autocast("cuda"): - output = mymm(x, y) - loss = output.sum() - loss.backward() + g.reset() - def test_autocast_cat_jit(self): - # Reported at https://github.com/pytorch/pytorch/issues/38958 + with torch.cuda.stream(s): + g.capture_begin() + b.fill_(2.0) + for _ in range(10): + b = b + 2 + g.capture_end() + torch.cuda.current_stream().wait_stream(s) - class Model(torch.nn.Module): - def forward(self): - a = torch.randn(1) - b = torch.randn(1) - c = torch.cat((a, b), 0) - d = torch.stack([c, c], 0) - return d + g.replay() + self.assertTrue(b.sum().item() == 22000.0) - # The JIT here doesn't really matter, we just need to call - # cat via the boxed API - model = Model() - model_jit_script = torch.jit.script(model) + g.reset() + del g - with torch.autocast("cuda", enabled=True): - model() - model_jit_script() + @unittest.skipIf( + not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs" + ) + def test_graph_debugdump(self): + torch.cuda.empty_cache() + x = torch.randn(10240000, device="cuda") + y = torch.rand_like(x) + g = torch.cuda.CUDAGraph() + g.enable_debug_mode() + s0 = torch.cuda.Stream() + s1 = torch.cuda.Stream() + s0.wait_stream(torch.cuda.current_stream()) + with torch.cuda.stream(s0): + g.capture_begin() + z = x + y + with torch.cuda.stream(s1): + s1.wait_stream(s0) + w = z + y + s0.wait_stream(s1) + g.capture_end() + s0.synchronize() + torch.cuda.synchronize() + with tempfile.TemporaryDirectory() as tempdir: + g.debug_dump(os.path.join(tempdir, "out_multi_stream.dot")) - # cudnn RNNs require special backend handling (weights are cast to FP16 and reflattened) - # so they get a dedicated test. - # Despite the large number of RNN cases it tries, the test takes < 15 seconds on a Titan V (similar to V100). - @unittest.skipIf(not TEST_CUDNN, "CUDNN not available") - def test_autocast_rnn(self): - with torch.backends.cudnn.flags(enabled=True, deterministic=True): - # seq, batch, features, hidden size - clses = ("RNN", "GRU", "LSTM") - T, B, F, H = 3, 4, 5, 6 - dtypes = (torch.float16, torch.float32) - input_layouts = ("seq_first", "batch_first", "packed") + @unittest.skipIf( + not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs" + ) + def test_graph_error(self): + # We need to run this test in a separate thread as the error we trigger + # puts the cuda context in a bad state + script = """ +import torch - for ( - cls, - num_layers, - bias, - input_layout, - bidirectional, - try_nonpreflattened_weights, - input_dtype, - hidden_dtype, - weight_dtype, - ) in product( - clses, - (1, 2), - (True, False), - input_layouts, - (True, False), - (True, False), - dtypes, - dtypes, - dtypes, - ): - if input_layout == "seq_first": - batch_first = False - x = torch.randn((T, B, F), device="cuda", dtype=input_dtype) - elif input_layout == "batch_first": - batch_first = True - x = torch.randn((B, T, F), device="cuda", dtype=input_dtype) - elif input_layout == "packed": - batch_first = False - x = torch.nn.utils.rnn.pack_padded_sequence( - torch.randn((T, B, F), device="cuda", dtype=input_dtype), - lengths=(3, 2, 1, 3), - enforce_sorted=False, - ) - - rnn = ( - getattr(torch.nn, cls)( - F, - H, - num_layers=num_layers, - bidirectional=bidirectional, - bias=bias, - batch_first=batch_first, - ) - .cuda() - .to(dtype=weight_dtype) +g = torch.cuda.CUDAGraph() +try: + g.capture_begin() +except RuntimeError as e: + if "CUDA graphs must be captured on a non-default stream." in str(e): + exit(0) + else: + exit(1) +exit(2) +""" + try: + a = subprocess.check_output( + [sys.executable, "-c", script], + stderr=subprocess.STDOUT, + # On Windows, opening the subprocess with the default CWD makes `import torch` + # fail, so just set CWD to this script's directory + cwd=os.path.dirname(os.path.realpath(__file__)), + ) + except subprocess.CalledProcessError as e: + if e.returncode == 1: + self.assertTrue( + False, + "Error raise by starting capture without a stream is not the expected one", ) - - if try_nonpreflattened_weights: - for p in rnn.parameters(): - with torch.no_grad(): - p.set_(p.clone()) - - h = torch.randn( - (num_layers * (2 if bidirectional else 1), B, H), - device="cuda", - dtype=hidden_dtype, + elif e.returncode == 2: + self.assertTrue( + False, + "Error raised by starting capture without a stream was not caught", ) - if cls == "LSTM": - c = torch.randn( - (num_layers * (2 if bidirectional else 1), B, H), - device="cuda", - dtype=hidden_dtype, - ) - h = (h, c) - with torch.autocast("cuda"): - out, h_out = rnn(x, h) - out = out.data if input_layout == "packed" else out - self.assertEqual(out.dtype, torch.float16) - # Autocast wrapper requires at::_cudnn_rnn is autograd-exposed. This check can't guarantee - # at::_cudnn_rnn is autograd-exposed, but if it fires, it indicates some funny business has - # occurred and we should double check that at::_cudnn_rnn remains autograd-exposed. - self.assertEqual( - out.grad_fn.name(), - "MiopenRnnBackward0" if torch.version.hip else "CudnnRnnBackward0", - ) - out.sum().backward() - grads = [p.grad.clone() for p in rnn.parameters()] + @unittest.skipIf( + (not TEST_CUDA) or TEST_WITH_ROCM or int(torch.version.cuda.split(".")[0]) < 11, + "CUDA >= 11.0 required for graphs", + ) + def test_graph_warn_if_has_zero_nodes(self): + with warnings.catch_warnings(record=True) as caught: + g = torch.cuda.CUDAGraph() + s = torch.cuda.Stream() + with torch.cuda.stream(s): + g.capture_begin() + g.capture_end() + self.assertTrue( + any("The CUDA Graph is empty" in str(w.message) for w in caught) + ) - rnn.zero_grad() + @unittest.skipIf( + not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs" + ) + @unittest.skipIf( + IS_JETSON, "oom reporting has issues on jetson igx due to partial nvml support" + ) + def test_graph_capture_oom(self): + oom_regex = ( + "would exceed allowed memory" if TEST_CUDAMALLOCASYNC else "out of memory" + ) + with self.assertRaisesRegex(RuntimeError, oom_regex): + with torch.cuda.graph(torch.cuda.CUDAGraph()): + torch.zeros(2**40, device="cuda") - if cls == "LSTM": - out_control, h_out_control = rnn.to(dtype=torch.float16)( - x.half(), (h[0].half(), h[1].half()) - ) - else: - out_control, h_out_control = rnn.to(dtype=torch.float16)( - x.half(), h.half() - ) - out_control = ( - out_control.data if input_layout == "packed" else out_control - ) - out_control.sum().backward() - grads_control = [p.grad.clone() for p in rnn.parameters()] + @unittest.skipIf( + not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs" + ) + @serialTest() + def test_repeat_graph_capture_cublas_workspace_memory(self): + (x, y, z) = 1024, 512, 64 + a = torch.rand((x, y), device="cuda") + b = torch.rand((y, z), device="cuda") - # Compares with default tolerances, even for FP16 execution. Barring nondeterminism, - # autocast and control results should be bitwise identical. - self.assertEqual(out, out_control) + # warmup + torch.mm(a, b) - if cls == "LSTM": - self.assertTrue( - h_out[0].dtype is torch.float16 - and h_out[1].dtype is torch.float16 - ) - self.assertEqual(h_out[0], h_out_control[0]) - self.assertEqual(h_out[1], h_out_control[1]) - else: - self.assertEqual(h_out.dtype, torch.float16) - self.assertEqual(h_out, h_out_control) - for grad, grad_control in zip(grads, grads_control): - self.assertEqual(grad.half(), grad_control) + free_bytes_before, total_bytes = torch.cuda.mem_get_info() + used_gb_before = (total_bytes - free_bytes_before) / 1e9 - def test_autocast_cache_leak(self): - # Reported at https://github.com/pytorch/pytorch/issues/48049 - # Test is used to check, if autocast recaches the same parameters - # when executed in a `torch.no_grad()` block. + for i in range(100): + torch_graph = torch.cuda.CUDAGraph() + with torch.cuda.graph(torch_graph): + torch.mm(a, b) + torch_graph.replay() - linear = torch.nn.Linear(10, 10).to("cuda") - data = torch.randn(1, 10, device="cuda") + free_bytes_after, _ = torch.cuda.mem_get_info() + used_gb_after = (total_bytes - free_bytes_after) / 1e9 - with torch.autocast("cuda"): - with torch.no_grad(): - out = linear(data) - first_iter_mem = torch.cuda.memory_allocated() - for _ in range(3): - out = linear(data) - self.assertTrue(first_iter_mem == torch.cuda.memory_allocated()) + self.assertFalse(used_gb_before + 0.1 < used_gb_after) - def test_autocast_checkpointing(self): - model = torch.nn.Sequential( - torch.nn.Linear(8, 8), torch.nn.Linear(8, 8), torch.nn.Linear(8, 8) - ).cuda() - input = torch.rand( - (8, 8), device="cuda", dtype=torch.float16, requires_grad=True + @unittest.skipIf( + not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs" + ) + def test_graph_rng_functional(self): + ops_with_kwargs = ( + (torch.nn.functional.dropout, {"p": 0.1}), + (torch.nn.functional.rrelu, {"training": True}), ) - for reentrant in (True, False): - with torch.autocast("cuda"): - output = checkpoint_sequential(model, 2, input, use_reentrant=reentrant) - self.assertTrue(output.requires_grad) - self.assertTrue(output.dtype is torch.float16) - output.sum().backward() - - def test_cuda_autocast_deprecated_warning(self): - with self.assertWarnsRegex( - FutureWarning, - r"`torch.cuda.amp.autocast\(args...\)` is deprecated. Please use `torch.amp.autocast\('cuda', args...\)` instead.", - ): - with torch.cuda.amp.autocast(): - _ = torch.ones(10) + size = 10000 - @slowTest - @unittest.skipIf(not TEST_LARGE_TENSOR, "not enough memory") - @serialTest() - def test_max_large_axis(self): - x = torch.zeros(2**32, device="cuda", dtype=torch.int8) - x[-1] = 1 - val, idx = x.max(0) - self.assertEqual(val, 1) - self.assertEqual(idx, x.shape[0] - 1) + def run(op, kwargs): + a = torch.randn((size,), device="cuda", dtype=torch.float) - @unittest.skipIf(not TEST_NUMPY, "Numpy not found") - def test_to_numpy(self): - self.assertRaises(TypeError, lambda: torch.empty(1, device="cuda").numpy()) + # Control + torch.cuda.manual_seed(5) + eager_out = a + for _ in range(6): + eager_out = op(eager_out, **kwargs) - def test_graph_is_current_stream_capturing(self): - self.assertFalse(torch.cuda.is_current_stream_capturing()) + graph_in = a.clone() + stream = torch.cuda.Stream() + stream.wait_stream(torch.cuda.current_stream()) + with torch.cuda.stream(stream): + torch.cuda.manual_seed(5) - if TEST_CUDA and (not TEST_WITH_ROCM): - s = torch.cuda.Stream() - with torch.cuda.stream(s): g = torch.cuda.CUDAGraph() - self.assertFalse(torch.cuda.is_current_stream_capturing()) + torch.cuda.empty_cache() g.capture_begin() - self.assertTrue(torch.cuda.is_current_stream_capturing()) + graph_out = graph_in + for _ in range(2): + graph_out = op(graph_out, **kwargs) g.capture_end() + torch.cuda.current_stream().wait_stream(stream) - @unittest.skipIf( - not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs" - ) - def test_graph_capture_simple(self): - s = torch.cuda.Stream() + # Runs a graphed->eager->graphed sequence of RNG ops. + # replay() plays 2 invocations of the op, so the sequence has 6 + # invocations total, matching Control. + # replay() reads from graph_in and writes to graph_out. + g.replay() + out = op(graph_out, **kwargs) + out = op(out, **kwargs) + graph_in.copy_(out) + g.replay() - with torch.cuda.stream(s): - a = torch.full((1000,), 1, device="cuda") - g = torch.cuda.CUDAGraph() - torch.cuda.empty_cache() - g.capture_begin() - b = a - for _ in range(10): - b = b + 1 - g.capture_end() - torch.cuda.current_stream().wait_stream(s) + # If replay() updated RNG state correctly, graph_out + # should now hold data equal to eager_out. + try: + self.assertEqual(eager_out, graph_out) + except Exception as e: + raise RuntimeError("Failed on ", op) from e - g.replay() + # Do the same operations varying seeds + seeds = [6, 128, 9999] - self.assertTrue(b.sum().item() == 11000.0) + for seed in seeds: + torch.cuda.manual_seed(seed) + graph_in.copy_(a) + for _ in range(3): + g.replay() + + # If the random seed was not updated then the graph would + # generate the same output as in previous check. + try: + self.assertNotEqual(eager_out, graph_out) + except Exception as e: + raise RuntimeError("Failed on ", op) from e + + # Now repeat the same operations in non-graphed mode. + torch.cuda.manual_seed(seed) + for _ in range(3): + eager_out.copy_(a) + eager_out = op(eager_out, **kwargs) + eager_out = op(eager_out, **kwargs) + + # In the end, graph_out and eager_out must be equal + # as they went under the same set of operations. + try: + self.assertEqual(eager_out, graph_out) + except Exception as e: + raise RuntimeError("Failed on ", op) from e + + # We hold references to all tensors used across streams up til this sync, + # so no need to call record_stream on those tensors. + torch.cuda.synchronize() + + for op, kwargs in ops_with_kwargs: + run(op, kwargs) @unittest.skipIf( not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs" ) - def test_graphsafe_set_get_rng_state(self): - # Define a function to create generator states, with optional graph registration - def create_states(generator): - """Initializes generator states and registers them with a CUDA graph if provided.""" - # Ensure the CUDA generator is initialized - torch.rand(1, device="cuda") - generator.manual_seed(0) + def test_graph_rng_distributions(self): + size = 10000 + input = torch.rand((size,), device="cuda", dtype=torch.float) + alloc = torch.empty((size,), device="cuda", dtype=torch.float) - # Save the current state of the generator - old_state = generator.graphsafe_get_state() - # Create and save a cloned state of the generator - new_state = generator.clone_state() - # Return the original generator and its two states - return generator, old_state, new_state + # Torch ops to test with sample args (tuple) and kwargs (dict) + torch_with_args = ( + ("bernoulli", (input.clone(),), {}), + # multinomial uses some uncapturable CUDA calls. + # TODO: reenable multinomial tests if/when the implementation is capturable. + # ("multinomial", (input.clone(), size, True), {}), + # ("multinomial", (input.clone(), size // 2, False), {}), + # TODO: reenable normal test, where std is a device + # tensor, when graph test failures are fixed + # ("normal", (input.clone() + 1, input.clone()), {}), + ("normal", (input.clone() + 1, 1.0), {}), + ("poisson", (input.clone(),), {}), + ("rand", (size,), {"device": "cuda", "dtype": torch.float}), + ("randint", (0, 3, (size,)), {"device": "cuda", "dtype": torch.float}), + ("randn", (size,), {"device": "cuda", "dtype": torch.float}), + ) - def register_states_to_graph(generator_state, graph): - generator, old_state, new_state = generator_state - graph.register_generator_state(old_state) - graph.register_generator_state(new_state) + # Tensor methods to test with sample args (tuple) + tensor_with_args = ( + ("bernoulli_", (input.clone(),)), + ("cauchy_", ()), + ("exponential_", ()), + ("geometric_", (0.3,)), + ("log_normal_", ()), + ("normal_", ()), + ("random_", ()), + ("uniform_", ()), + ) - # Define a function to perform specific RNG actions using the generator's states - def perform_random_generation_steps(generator_state): - generator, old_state, new_state = generator_state - random_values = [] + def run(module, op, args, kwargs): + torch.cuda.manual_seed(5) - # Generate random numbers with the new generator state - generator.graphsafe_set_state(new_state) - random_values.append(torch.rand(5, device="cuda", generator=generator)) + # Each path runs a dummy op to increment the state a bit before creating controls. + if module == "torch": + dummy = getattr(torch, op)(*args, **kwargs) + control1 = getattr(torch, op)(*args, **kwargs) + control2 = getattr(torch, op)(*args, **kwargs) + else: + dummy = alloc.clone() + control1 = alloc.clone() + control2 = alloc.clone() + getattr(dummy, op)(*args) + getattr(control1, op)(*args) + getattr(control2, op)(*args) - # Generate random numbers twice with the old generator state - generator.graphsafe_set_state(old_state) - random_values.extend( - [torch.rand(5, device="cuda", generator=generator) for _ in range(2)] - ) + stream = torch.cuda.Stream() + stream.wait_stream(torch.cuda.current_stream()) + with torch.cuda.stream(stream): + torch.cuda.manual_seed(5) - return random_values + g = torch.cuda.CUDAGraph() + torch.cuda.empty_cache() + if module == "torch": + g.capture_begin() + t1 = getattr(torch, op)(*args, **kwargs) + t2 = getattr(torch, op)(*args, **kwargs) + g.capture_end() + else: + t1 = alloc.clone() + t2 = alloc.clone() + g.capture_begin() + getattr(t1, op)(*args) + getattr(t2, op)(*args) + g.capture_end() + torch.cuda.current_stream().wait_stream(stream) - # Define a function to retrieve the final offsets of the original and new generator states - def get_final_offsets_of_states(generator_state): - generator, old_state, new_state = generator_state - old_state_offset = old_state.get_offset() - new_state_offset = new_state.get_offset() - return old_state_offset, new_state_offset + if not TEST_CUDAMALLOCASYNC: + # Makes sure values haven't been populated yet + # (in other words, makes sure capture didn't actually run ops). + # We can only try this with the native allocator, for which captured + # addresses are already backed by cudaMalloced memory. + # If we try it with cudaMallocAsync, CUDA won't event consider + # the captured addresses allocated until replay(), and if we + # access them before replay() we get IMAs. + try: + self.assertNotEqual(control1, t1) + self.assertNotEqual(control2, t2) + except Exception as e: + raise RuntimeError("Failed on " + module + "." + op) from e - # Set up and test a new CUDA generator - generator = torch.Generator(device="cuda") - generator_state = create_states(generator) + # Set a new seed to check if graph would use it + for seed in [6, 314, 271]: + torch.cuda.manual_seed(seed) + # Runs a dummy op prelude, as for controls, to make sure replay() + # picks up the dummy op's state increment. + if module == "torch": + dummy = getattr(torch, op)(*args, **kwargs) + control1 = getattr(torch, op)(*args, **kwargs) + control2 = getattr(torch, op)(*args, **kwargs) + else: + getattr(dummy, op)(*args) + getattr(control1, op)(*args) + getattr(control2, op)(*args) - # Set up and test the default CUDA generator with a CUDA Graph - g = torch.cuda.CUDAGraph() - s = torch.cuda.Stream() - default_generator = torch.cuda.default_generators[0] - default_generator_state = create_states(default_generator) - register_states_to_graph(default_generator_state, g) + torch.cuda.manual_seed(seed) + if module == "torch": + dummy = getattr(torch, op)(*args, **kwargs) + else: + getattr(dummy, op)(*args) - # Perform random number generation within a CUDA graph - with torch.cuda.stream(s): - g.capture_begin() - graphed_random_values = perform_random_generation_steps( - default_generator_state - ) - g.capture_end() + # see above comment on TEST_CUDAMALLOCASYNC + if not TEST_CUDAMALLOCASYNC: + t1.copy_(alloc) + t2.copy_(alloc) - # Synchronize the streams and replay the graph - torch.cuda.current_stream().wait_stream(s) - for _ in range(3): - random_values = perform_random_generation_steps(generator_state) - g.replay() - offset = get_final_offsets_of_states(generator_state) - graph_offset = get_final_offsets_of_states(default_generator_state) + # Runs RNG ops that fill t1 and t2. + g.replay() - # Compare the final offsets of states for both generators to ensure consistency - self.assertTrue(offset == graph_offset) - # Compare the states generated outside and inside the graph - self.assertEqual(random_values, graphed_random_values) + try: + self.assertEqual(control1, t1) + self.assertEqual(control2, t2) + except Exception as e: + raise RuntimeError("Failed on " + module + "." + op) from e + + # We hold references to all tensors used across streams up til this sync, + # so no need to call record_stream on those tensors. + torch.cuda.synchronize() + + for op_with_args in torch_with_args: + run("torch", *op_with_args) + + for meth_with_args in tensor_with_args: + # Adds an empty dict for kwargs, which none of the Tensor methods use + run("Tensor", *(meth_with_args + ({},))) @unittest.skipIf( not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs" ) - def test_memory_stats_of_multiple_generators_and_graphs(self): - # Function to clear CUDA cache and collect garbage - def clear_cuda_cache(): - gc.collect() - torch.cuda.empty_cache() + def test_graph_two_successive(self): + torch.cuda.empty_cache() - # Executes a simple graph task which includes capturing and executing a random number generation within a CUDA graph. - def simple_graph_task(graph): - s = torch.cuda.Stream() - with torch.cuda.stream(s): - graph.capture_begin() - torch.rand(1, device="cuda") - graph.capture_end() - torch.cuda.current_stream().wait_stream(s) - graph.replay() # Replays the captured operations + size = 1000 + kSmallBuffer = 2097152 - def get_memory_stats(): - stats = torch.cuda.memory_stats() - num_blocks = stats["active.all.current"] - total_size = stats["active_bytes.all.current"] - return num_blocks, total_size + def func_with_temps(t, val): + x = t.clone() + val + y = t.clone() + val + return x + y - def test(num_graphs, num_generators): - baseline = get_memory_stats() - baseline_num_blocks, baseline_total_size = baseline - - # Allocate CUDA graphs - graphs = [torch.cuda.CUDAGraph() for _ in range(num_graphs)] + s = torch.cuda.Stream() - # Allocate and manage generator states - default_generator = torch.cuda.default_generators[0] - generators = [default_generator.graphsafe_get_state()] + for share_mem in ("Don't share", "via pool()", "via graph_pool_handle()"): + g0 = torch.cuda.CUDAGraph() + g1 = torch.cuda.CUDAGraph() - # Starts from 1 as one state is already added - for _ in range(1, num_generators): - generators.append(default_generator.clone_state()) + a = torch.ones((size,), device="cuda") - for graph in graphs: - for generator_state in generators: - graph.register_generator_state(generator_state) - simple_graph_task(graph) + s.wait_stream(torch.cuda.current_stream()) + with torch.cuda.stream(s): + g0_args = ( + (torch.cuda.graph_pool_handle(),) + if share_mem == "via graph_pool_handle()" + else () + ) + g0.capture_begin(*g0_args) + b = a.clone() + for _ in range(5): + b = func_with_temps(b, 1) + g0.capture_end() - # Assert conditions after graph tasks - num_blocks, total_size = get_memory_stats() - # The allocated blocks should only be proportional to the number of generators - expected_blocks_diff = 2 * num_generators - expected_size_diff = 2 * 512 * num_generators # Each block's size is 512 + g1_args = (g0.pool(),) if share_mem == "via pool()" else g0_args + g1.capture_begin(*g1_args) + for _ in range(5): + b = func_with_temps(b, 1) + g1.capture_end() + torch.cuda.current_stream().wait_stream(s) - self.assertTrue( - (num_blocks - baseline_num_blocks) == expected_blocks_diff, - "Unexpected number of active blocks.", - ) - self.assertTrue( - (total_size - baseline_total_size) == expected_size_diff, - "Unexpected total memory size.", - ) + # mixes unrelated eager ops with replays + c = a.clone() + for _ in range(2): + c = func_with_temps(c, 3) + g0.replay() + for _ in range(2): + c = func_with_temps(c, 3) + g1.replay() + for _ in range(2): + c = func_with_temps(c, 3) - # Cleanup graphs and clear CUDA cache - while graphs: - graph = graphs.pop() - del graph - clear_cuda_cache() + self.assertEqual(b.sum().item(), size * 3070) + self.assertEqual(c.sum().item(), size * 442) - # Assert that memory stats return to baseline after cleanup - self.assertTrue( - get_memory_stats() == baseline, - "Memory stats do not match baseline after cleanup.", - ) + if not TEST_CUDAMALLOCASYNC: + # These stat checks are specific to the native allocator. + if share_mem != "Don't share": + self.assertEqual( + reserved_no_sharing # noqa: F821 + - torch.cuda.memory_stats()["reserved_bytes.all.current"], + kSmallBuffer, + ) + else: + reserved_no_sharing = torch.cuda.memory_stats()[ + "reserved_bytes.all.current" + ] - # Running the test function with different parameters - test(1, 1) - test(3, 2) - test(10, 20) + del a, b, c, g0, g1 + # Tensors used across streams (a and b) were held until just now, so no need to call record_stream on them. + torch.cuda.synchronize() + torch.cuda.empty_cache() + @unittest.skipIf( + (not TEST_CUDA_GRAPH) + or IS_WINDOWS + or ( # appears to still be broken on Windows as of 11.4+ + torch.version.cuda + and int(torch.version.cuda.split(".")[0]) == 11 + and int(torch.version.cuda.split(".")[1]) < 4 + ), + "Graph bindings disallow concurrent replay for CUDA < 11.4, see " + + "https://github.com/pytorch/pytorch/pull/57556", + ) @unittest.skipIf( not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs" ) - def test_graph_capture_reset_recapture(self): + def test_graph_concurrent_replay(self): + torch.cuda.empty_cache() + + size = 1000000 # largeish to help expose race conditions + + def func_with_temps(t, val): + x = t.clone() + val + y = t.clone() + val + return x + y + s = torch.cuda.Stream() - with torch.cuda.stream(s): - a = torch.full((1000,), 1, device="cuda") - g = torch.cuda.CUDAGraph() - torch.cuda.empty_cache() - g.capture_begin() - b = a - for _ in range(10): - b = b + 1 - g.capture_end() - torch.cuda.current_stream().wait_stream(s) + for share_mem in ("Don't share", "via pool()", "via graph_pool_handle()"): + g0 = torch.cuda.CUDAGraph() + g1 = torch.cuda.CUDAGraph() - g.replay() + s0 = torch.cuda.Stream() + s1 = torch.cuda.Stream() - self.assertTrue(b.sum().item() == 11000.0) + a = torch.ones((size,), device="cuda") - g.reset() + s.wait_stream(torch.cuda.current_stream()) + with torch.cuda.stream(s): + g0_args = ( + (torch.cuda.graph_pool_handle(),) + if share_mem == "via graph_pool_handle()" + else () + ) + g0.capture_begin(*g0_args) + b = a.clone() + for _ in range(5): + b = func_with_temps(b, 1) + g0.capture_end() - with torch.cuda.stream(s): - g.capture_begin() - b.fill_(2.0) - for _ in range(10): - b = b + 2 - g.capture_end() - torch.cuda.current_stream().wait_stream(s) + g1_args = (g0.pool(),) if share_mem == "via pool()" else g0_args + g1.capture_begin(*g1_args) + c = a.clone() + for _ in range(5): + c = func_with_temps(c, 2) + g1.capture_end() - g.replay() - self.assertTrue(b.sum().item() == 22000.0) + # To reproduce data corruption, I need g0 and g1's kernels to run concurrently. + # But replay() (especially cudaGraphLaunch) can incur significant CPU overhead. + # The following pattern helps align device-side execution of g0 and g1's kernels. + torch.cuda.synchronize() + with torch.cuda.stream(s0): + torch.cuda._sleep(1000000) + s1.wait_stream(s0) + g0.replay() + with torch.cuda.stream(s1): + g1.replay() + torch.cuda.current_stream().wait_stream(s0) + torch.cuda.current_stream().wait_stream(s1) - g.reset() - del g + if (not TEST_CUDAMALLOCASYNC) and (share_mem != "Don't share"): + # If we used the native allocator and shared mempools, + # we expect the concurrent replays corrupted each other. + self.assertNotEqual(b.sum().item(), size * 94) + self.assertNotEqual(c.sum().item(), size * 156) + else: + # If we EITHER + # - used the native allocator without sharing mempools, OR + # - used cudaMallocAsync, which ignores graph pool-sharing hints and should always be safe + # we don't expect memory corruption. + self.assertEqual(b.sum().item(), size * 94) + self.assertEqual(c.sum().item(), size * 156) + + del a, b, c, g0, g1 + # Tensors used across streams (a, b, c) were held until just now, so no need to call record_stream on them. + torch.cuda.synchronize() + torch.cuda.empty_cache() @unittest.skipIf( not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs" ) - def test_graph_debugdump(self): + def test_graph_three_successive(self): torch.cuda.empty_cache() - x = torch.randn(10240000, device="cuda") - y = torch.rand_like(x) - g = torch.cuda.CUDAGraph() - g.enable_debug_mode() - s0 = torch.cuda.Stream() - s1 = torch.cuda.Stream() - s0.wait_stream(torch.cuda.current_stream()) - with torch.cuda.stream(s0): - g.capture_begin() - z = x + y - with torch.cuda.stream(s1): - s1.wait_stream(s0) - w = z + y - s0.wait_stream(s1) - g.capture_end() - s0.synchronize() - torch.cuda.synchronize() - with tempfile.TemporaryDirectory() as tempdir: - g.debug_dump(os.path.join(tempdir, "out_multi_stream.dot")) - @unittest.skipIf( - not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs" - ) - def test_graph_error(self): - # We need to run this test in a separate thread as the error we trigger - # puts the cuda context in a bad state - script = """ -import torch + size = 1000 -g = torch.cuda.CUDAGraph() -try: - g.capture_begin() -except RuntimeError as e: - if "CUDA graphs must be captured on a non-default stream." in str(e): - exit(0) - else: - exit(1) -exit(2) -""" - try: - a = subprocess.check_output( - [sys.executable, "-c", script], - stderr=subprocess.STDOUT, - # On Windows, opening the subprocess with the default CWD makes `import torch` - # fail, so just set CWD to this script's directory - cwd=os.path.dirname(os.path.realpath(__file__)), - ) - except subprocess.CalledProcessError as e: - if e.returncode == 1: - self.assertTrue( - False, - "Error raise by starting capture without a stream is not the expected one", - ) - elif e.returncode == 2: - self.assertTrue( - False, - "Error raised by starting capture without a stream was not caught", - ) + s = torch.cuda.Stream() - @unittest.skipIf( - (not TEST_CUDA) or TEST_WITH_ROCM or int(torch.version.cuda.split(".")[0]) < 11, - "CUDA >= 11.0 required for graphs", - ) - def test_graph_warn_if_has_zero_nodes(self): - with warnings.catch_warnings(record=True) as caught: - g = torch.cuda.CUDAGraph() - s = torch.cuda.Stream() + for share_mem in ("Don't share", "via pool()", "via graph_pool_handle()"): + a = torch.ones((size,), device="cuda") + + g0 = torch.cuda.CUDAGraph() + g1 = torch.cuda.CUDAGraph() + g2 = torch.cuda.CUDAGraph() + + s.wait_stream(torch.cuda.current_stream()) with torch.cuda.stream(s): - g.capture_begin() - g.capture_end() - self.assertTrue( - any("The CUDA Graph is empty" in str(w.message) for w in caught) - ) + g0_args = ( + (torch.cuda.graph_pool_handle(),) + if share_mem == "via graph_pool_handle()" + else () + ) + g0.capture_begin(*g0_args) + b = a.clone() + c = b + 1 + d = b + 2 + g0.capture_end() - @unittest.skipIf( - not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs" - ) - @unittest.skipIf( - IS_JETSON, "oom reporting has issues on jetson igx due to partial nvml support" - ) - def test_graph_capture_oom(self): - oom_regex = ( - "would exceed allowed memory" if TEST_CUDAMALLOCASYNC else "out of memory" - ) - with self.assertRaisesRegex(RuntimeError, oom_regex): - with torch.cuda.graph(torch.cuda.CUDAGraph()): - torch.zeros(2**40, device="cuda") + args = (g0.pool(),) if share_mem == "via pool()" else g0_args - @unittest.skipIf( - not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs" - ) - @serialTest() - def test_repeat_graph_capture_cublas_workspace_memory(self): - (x, y, z) = 1024, 512, 64 - a = torch.rand((x, y), device="cuda") - b = torch.rand((y, z), device="cuda") + g1.capture_begin(*args) + e = c + 3 + del c + g1.capture_end() - # warmup - torch.mm(a, b) + g2.capture_begin(*args) + f = d + 4 + g2.capture_end() + torch.cuda.current_stream().wait_stream(s) - free_bytes_before, total_bytes = torch.cuda.mem_get_info() - used_gb_before = (total_bytes - free_bytes_before) / 1e9 + # Tests that replaying in capture order is valid + g0.replay() + g1.replay() + g2.replay() - for i in range(100): - torch_graph = torch.cuda.CUDAGraph() - with torch.cuda.graph(torch_graph): - torch.mm(a, b) - torch_graph.replay() + self.assertEqual(e.sum().item(), size * 5) + self.assertEqual(f.sum().item(), size * 7) - free_bytes_after, _ = torch.cuda.mem_get_info() - used_gb_after = (total_bytes - free_bytes_after) / 1e9 + # Tests that replaying as g0, g2, g1 is only valid if they don't share a pool + g0.replay() + g2.replay() + g1.replay() - self.assertFalse(used_gb_before + 0.1 < used_gb_after) + expect_corruption = (not TEST_CUDAMALLOCASYNC) and ( + share_mem != "Don't share" + ) + # If we used the native allocator and shared mempools, g2's capture should have reused c's memory for f. + # We replayed g2 then g1, so we expect g1's captured "e = c + 3" mistakenly filled e with "f's vals + 3". + self.assertEqual( + e.sum().item(), size * (7 + 3) if expect_corruption else size * 5 + ) + self.assertEqual(f.sum().item(), size * 7) + + del a, b, d, e, f, g0, g1, g2 + # Tensors used across streams (a, e, f) were held until just now, so no need to call record_stream on them. + torch.cuda.synchronize() + torch.cuda.empty_cache() @unittest.skipIf( - not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs" + (not TEST_CUDA_GRAPH) or TEST_CUDAMALLOCASYNC, + "CUDA >= 11.0 or ROCM >= 5.3 required for graphs", ) - def test_graph_rng_functional(self): - ops_with_kwargs = ( - (torch.nn.functional.dropout, {"p": 0.1}), - (torch.nn.functional.rrelu, {"training": True}), - ) - size = 10000 - - def run(op, kwargs): - a = torch.randn((size,), device="cuda", dtype=torch.float) - - # Control - torch.cuda.manual_seed(5) - eager_out = a - for _ in range(6): - eager_out = op(eager_out, **kwargs) + def test_graph_memory_stats_and_use_result_after_destroy_graph(self): + kSmallSize = 1048576 + kSmallBuffer = 2097152 + kLargeBuffer = 20971520 + kMinLargeAlloc = 10485760 + kRoundLarge = 2097152 - graph_in = a.clone() - stream = torch.cuda.Stream() - stream.wait_stream(torch.cuda.current_stream()) - with torch.cuda.stream(stream): - torch.cuda.manual_seed(5) + elem = 4 - g = torch.cuda.CUDAGraph() - torch.cuda.empty_cache() - g.capture_begin() - graph_out = graph_in - for _ in range(2): - graph_out = op(graph_out, **kwargs) - g.capture_end() - torch.cuda.current_stream().wait_stream(stream) + # this was annoying to write but stresses the expectations pretty rigorously + cases = ( + (512 // elem, 1, kSmallBuffer, kSmallBuffer, "small_pool"), + (kSmallSize // elem, 2, 2 * kSmallBuffer, kSmallBuffer, "small_pool"), + ((kSmallSize + 512) // elem, 1, kLargeBuffer, kLargeBuffer, "large_pool"), + ( + (kMinLargeAlloc - 512) // elem, + 2, + 2 * kLargeBuffer, + kLargeBuffer, + "large_pool", + ), + ( + (kMinLargeAlloc + 512) // elem, + 3, + 3 + * ( + kRoundLarge + * ((kMinLargeAlloc + 512 + kRoundLarge - 1) // kRoundLarge) + ), + kRoundLarge * ((kMinLargeAlloc + 512 + kRoundLarge - 1) // kRoundLarge), + "large_pool", + ), + ) - # Runs a graphed->eager->graphed sequence of RNG ops. - # replay() plays 2 invocations of the op, so the sequence has 6 - # invocations total, matching Control. - # replay() reads from graph_in and writes to graph_out. - g.replay() - out = op(graph_out, **kwargs) - out = op(out, **kwargs) - graph_in.copy_(out) - g.replay() + stats_to_check = ("segment.", "reserved_bytes.", "active.", "active_bytes.") - # If replay() updated RNG state correctly, graph_out - # should now hold data equal to eager_out. - try: - self.assertEqual(eager_out, graph_out) - except Exception as e: - raise RuntimeError("Failed on ", op) from e + gc.collect() + torch.cuda.empty_cache() - # Do the same operations varying seeds - seeds = [6, 128, 9999] + s = torch.cuda.Stream() - for seed in seeds: - torch.cuda.manual_seed(seed) - graph_in.copy_(a) - for _ in range(3): - g.replay() + for ( + numel, + delta_cudaMallocs, + delta_cudaMalloc_bytes, + delta_cudaMalloc_bytes_post_del_g, + pool_string, + ) in cases: + if pool_string == "small_pool": + delta_active_blocks = 3 # one from "b" plus a sneaky two from CUDAGraph's one-element rng seed and offset holders + delta_active_bytes = ( + numel * elem + 1024 + ) # + 1024 for CUDAGraph's rng seed and offset holders each + else: + delta_active_blocks = 1 # We only check the large pool, which isn't affected by rng offset holder + delta_active_bytes = numel * elem - # If the random seed was not updated then the graph would - # generate the same output as in previous check. - try: - self.assertNotEqual(eager_out, graph_out) - except Exception as e: - raise RuntimeError("Failed on ", op) from e + g = torch.cuda.CUDAGraph() + s.wait_stream(torch.cuda.current_stream()) + with torch.cuda.stream(s): + # Allocation stat estimates assume input is created on the same stream as capture_begin() + # (in other words, the same stream silo as the rng offset holder, which is not allocated from the + # capture's private pool). + a = torch.ones((numel,), device="cuda") - # Now repeat the same operations in non-graphed mode. - torch.cuda.manual_seed(seed) - for _ in range(3): - eager_out.copy_(a) - eager_out = op(eager_out, **kwargs) - eager_out = op(eager_out, **kwargs) + precapture_stats = torch.cuda.memory_stats() - # In the end, graph_out and eager_out must be equal - # as they went under the same set of operations. - try: - self.assertEqual(eager_out, graph_out) - except Exception as e: - raise RuntimeError("Failed on ", op) from e + g.capture_begin() + b = a.clone() + for _ in range(5): + b = b.clone() + 1 + g.capture_end() + torch.cuda.current_stream().wait_stream(s) - # We hold references to all tensors used across streams up til this sync, - # so no need to call record_stream on those tensors. - torch.cuda.synchronize() + gc.collect() - for op, kwargs in ops_with_kwargs: - run(op, kwargs) + postcapture_stats = torch.cuda.memory_stats() - @unittest.skipIf( - not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs" - ) - def test_graph_rng_distributions(self): - size = 10000 - input = torch.rand((size,), device="cuda", dtype=torch.float) - alloc = torch.empty((size,), device="cuda", dtype=torch.float) + expecteds = ( + delta_cudaMallocs, + delta_cudaMalloc_bytes, + delta_active_blocks, + delta_active_bytes, + ) + # Double checks replay and stats before and after a call to empty_cache + for i in range(2): + for stat, expected in zip(stats_to_check, expecteds): + stat = stat + pool_string + ".current" + current = postcapture_stats[stat] - precapture_stats[stat] - # Torch ops to test with sample args (tuple) and kwargs (dict) - torch_with_args = ( - ("bernoulli", (input.clone(),), {}), - # multinomial uses some uncapturable CUDA calls. - # TODO: reenable multinomial tests if/when the implementation is capturable. - # ("multinomial", (input.clone(), size, True), {}), - # ("multinomial", (input.clone(), size // 2, False), {}), - # TODO: reenable normal test, where std is a device - # tensor, when graph test failures are fixed - # ("normal", (input.clone() + 1, input.clone()), {}), - ("normal", (input.clone() + 1, 1.0), {}), - ("poisson", (input.clone(),), {}), - ("rand", (size,), {"device": "cuda", "dtype": torch.float}), - ("randint", (0, 3, (size,)), {"device": "cuda", "dtype": torch.float}), - ("randn", (size,), {"device": "cuda", "dtype": torch.float}), - ) + # There will only ever be one expandable segment in each of the small and large pools. The way the + # bookeeping is done in the allocator means that we never increment the number of segments. + if self.expandable_segments and "segment" in stat: + expected = 0 + # These two cases hit an edge case where the PyTorch allocator won't immediately unmap part of an + # expandable segment (and as a result reduce the number of reserved bytes) if the block to unmap is + # smaller than the page size + if ( + self.expandable_segments + and "reserved" in stat + and (numel == cases[3][0] or numel == cases[4][0]) + ): + expected = 2 * kLargeBuffer - # Tensor methods to test with sample args (tuple) - tensor_with_args = ( - ("bernoulli_", (input.clone(),)), - ("cauchy_", ()), - ("exponential_", ()), - ("geometric_", (0.3,)), - ("log_normal_", ()), - ("normal_", ()), - ("random_", ()), - ("uniform_", ()), - ) + self.assertEqual( + current, + expected, + "Pre to post capture delta of " + + stat + + f" = {current}, expected = {expected}, numel = {numel}", + ) - def run(module, op, args, kwargs): - torch.cuda.manual_seed(5) + g.replay() + self.assertEqual(b.sum().item(), 6 * numel) + if i == 0: + torch.cuda.empty_cache() - # Each path runs a dummy op to increment the state a bit before creating controls. - if module == "torch": - dummy = getattr(torch, op)(*args, **kwargs) - control1 = getattr(torch, op)(*args, **kwargs) - control2 = getattr(torch, op)(*args, **kwargs) - else: - dummy = alloc.clone() - control1 = alloc.clone() - control2 = alloc.clone() - getattr(dummy, op)(*args) - getattr(control1, op)(*args) - getattr(control2, op)(*args) + del g + gc.collect() + torch.cuda.empty_cache() + postdel_stats = torch.cuda.memory_stats() - stream = torch.cuda.Stream() - stream.wait_stream(torch.cuda.current_stream()) - with torch.cuda.stream(stream): - torch.cuda.manual_seed(5) + # Uses graph result b after graph has been deleted + self.assertEqual(b.sum().item(), 6 * numel) - g = torch.cuda.CUDAGraph() - torch.cuda.empty_cache() - if module == "torch": - g.capture_begin() - t1 = getattr(torch, op)(*args, **kwargs) - t2 = getattr(torch, op)(*args, **kwargs) - g.capture_end() - else: - t1 = alloc.clone() - t2 = alloc.clone() - g.capture_begin() - getattr(t1, op)(*args) - getattr(t2, op)(*args) - g.capture_end() - torch.cuda.current_stream().wait_stream(stream) + # b should be the only live reference remaining from the graph's private pool + expecteds = (1, delta_cudaMalloc_bytes_post_del_g, 1, numel * elem) + for stat, expected in zip(stats_to_check, expecteds): + stat = stat + pool_string + ".current" + current = postdel_stats[stat] - precapture_stats[stat] - if not TEST_CUDAMALLOCASYNC: - # Makes sure values haven't been populated yet - # (in other words, makes sure capture didn't actually run ops). - # We can only try this with the native allocator, for which captured - # addresses are already backed by cudaMalloced memory. - # If we try it with cudaMallocAsync, CUDA won't event consider - # the captured addresses allocated until replay(), and if we - # access them before replay() we get IMAs. - try: - self.assertNotEqual(control1, t1) - self.assertNotEqual(control2, t2) - except Exception as e: - raise RuntimeError("Failed on " + module + "." + op) from e + # There will only ever be one expandable segment in each of the small and large pools. The way the + # bookeeping is done in the allocator means that we never increment the number of segments. + if self.expandable_segments and "segment" in stat: + expected = 0 + # These two cases hit an edge case where the PyTorch allocator won't immediately unmap part of an + # expandable segment (and as a result reduce the number of reserved bytes) if the block to unmap is + # smaller than the page size + if ( + self.expandable_segments + and "reserved" in stat + and numel == cases[3][0] + ): + expected = 2 * kLargeBuffer + if ( + self.expandable_segments + and "reserved" in stat + and numel == cases[4][0] + ): + expected = kLargeBuffer - # Set a new seed to check if graph would use it - for seed in [6, 314, 271]: - torch.cuda.manual_seed(seed) - # Runs a dummy op prelude, as for controls, to make sure replay() - # picks up the dummy op's state increment. - if module == "torch": - dummy = getattr(torch, op)(*args, **kwargs) - control1 = getattr(torch, op)(*args, **kwargs) - control2 = getattr(torch, op)(*args, **kwargs) - else: - getattr(dummy, op)(*args) - getattr(control1, op)(*args) - getattr(control2, op)(*args) + self.assertEqual( + current, + expected, + "Pre capture to post graph delete delta of " + + stat + + f" = {current}, expected = {expected}, numel = {numel}", + ) - torch.cuda.manual_seed(seed) - if module == "torch": - dummy = getattr(torch, op)(*args, **kwargs) - else: - getattr(dummy, op)(*args) + # del a, b before the next case is essential, otherwise overwriting a and b in the next case + # can throw off its allocation/deallocation counts. + del a, b + # Tensors used across streams (a and b) were held until just now, so no need to call record_stream on them. + torch.cuda.synchronize() + torch.cuda.empty_cache() - # see above comment on TEST_CUDAMALLOCASYNC - if not TEST_CUDAMALLOCASYNC: - t1.copy_(alloc) - t2.copy_(alloc) + @unittest.skipIf( + not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs" + ) + def test_graph_record_stream(self): + # Makes sure graph capture defers attempting to reclaim allocations used across streams. See + # "Q. Why skip process_events if a capture might be underway?" in c10/cuda/CUDACachingAllocator.cpp + torch.cuda.empty_cache() - # Runs RNG ops that fill t1 and t2. - g.replay() + potential_problem = torch.zeros((3,), device="cuda") + a = torch.zeros((3,), device="cuda") + s0 = torch.cuda.Stream() + s1 = torch.cuda.Stream() + s2 = torch.cuda.Stream() + g = torch.cuda.CUDAGraph() - try: - self.assertEqual(control1, t1) - self.assertEqual(control2, t2) - except Exception as e: - raise RuntimeError("Failed on " + module + "." + op) from e + torch.cuda.synchronize() + with torch.cuda.stream(s0): + potential_problem.record_stream(s0) + torch.cuda._sleep(TestCuda.FIFTY_MIL_CYCLES) + potential_problem.fill_(1.0) + del potential_problem - # We hold references to all tensors used across streams up til this sync, - # so no need to call record_stream on those tensors. - torch.cuda.synchronize() + with torch.cuda.stream(s1): + g.capture_begin() + # potential_problem's allocation should still be outstanding. if DeviceCachingAllocator::malloc + # mistakenly calls process_events, it will trigger cudaEventQueries on potential_problem's end-of-life + # event, which will cause the capture to error. + b = a.clone() - for op_with_args in torch_with_args: - run("torch", *op_with_args) + # Let's also see what happens if we record_stream on a tensor during capture. + s2.wait_stream(s1) + with torch.cuda.stream(s2): + b.fill_(1.0) + b.record_stream(s2) # dummy record_stream + del b + s1.wait_stream(s2) + g.capture_end() + torch.cuda.synchronize() - for meth_with_args in tensor_with_args: - # Adds an empty dict for kwargs, which none of the Tensor methods use - run("Tensor", *(meth_with_args + ({},))) + # dummy allocation triggers process_events, Hopefully successfully processes b's end-of-life event. + c = torch.zeros((3,), device="cuda") + @skipIfRocm @unittest.skipIf( not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs" ) - def test_graph_two_successive(self): + # If this test is the first in the process to try cudnn rnns with dropout, it'll initialize + # DropoutState's long-lived internal buffer. Calling code perceives this (correct) behavior + # as a memory leak unless we skip the leak check. + @skipCUDAMemoryLeakCheckIf(True) + @serialTest() + def test_graph_cudnn_dropout(self): + # Tests the interaction of cuda graph capture with DropoutState's syncs in ATen/native/cudnn/RNN.cpp. + # In particular, if user runs a sequence of captured and noncaptured cudnn rnns, DropoutState should + # avoid syncing noncapturing streams with captured events or vice versa. torch.cuda.empty_cache() - size = 1000 - kSmallBuffer = 2097152 + model = torch.nn.LSTM(512, 512, 2, dropout=0.5).cuda() + x = torch.ones(100, 192, 512, device="cuda") - def func_with_temps(t, val): - x = t.clone() + val - y = t.clone() + val - return x + y + y = model(x) + g = torch.cuda.CUDAGraph() s = torch.cuda.Stream() + s.wait_stream(torch.cuda.current_stream()) + with torch.cuda.stream(s): + g.capture_begin() + y = model(x) + g.capture_end() + torch.cuda.current_stream().wait_stream(s) - for share_mem in ("Don't share", "via pool()", "via graph_pool_handle()"): - g0 = torch.cuda.CUDAGraph() - g1 = torch.cuda.CUDAGraph() + g.replay() - a = torch.ones((size,), device="cuda") + y = model(x) - s.wait_stream(torch.cuda.current_stream()) - with torch.cuda.stream(s): - g0_args = ( - (torch.cuda.graph_pool_handle(),) - if share_mem == "via graph_pool_handle()" - else () - ) - g0.capture_begin(*g0_args) - b = a.clone() - for _ in range(5): - b = func_with_temps(b, 1) - g0.capture_end() - - g1_args = (g0.pool(),) if share_mem == "via pool()" else g0_args - g1.capture_begin(*g1_args) - for _ in range(5): - b = func_with_temps(b, 1) - g1.capture_end() - torch.cuda.current_stream().wait_stream(s) - - # mixes unrelated eager ops with replays - c = a.clone() - for _ in range(2): - c = func_with_temps(c, 3) - g0.replay() - for _ in range(2): - c = func_with_temps(c, 3) - g1.replay() - for _ in range(2): - c = func_with_temps(c, 3) - - self.assertEqual(b.sum().item(), size * 3070) - self.assertEqual(c.sum().item(), size * 442) - - if not TEST_CUDAMALLOCASYNC: - # These stat checks are specific to the native allocator. - if share_mem != "Don't share": - self.assertEqual( - reserved_no_sharing # noqa: F821 - - torch.cuda.memory_stats()["reserved_bytes.all.current"], - kSmallBuffer, - ) - else: - reserved_no_sharing = torch.cuda.memory_stats()[ - "reserved_bytes.all.current" - ] - - del a, b, c, g0, g1 - # Tensors used across streams (a and b) were held until just now, so no need to call record_stream on them. - torch.cuda.synchronize() - torch.cuda.empty_cache() - - @unittest.skipIf( - (not TEST_CUDA_GRAPH) - or IS_WINDOWS - or ( # appears to still be broken on Windows as of 11.4+ - torch.version.cuda - and int(torch.version.cuda.split(".")[0]) == 11 - and int(torch.version.cuda.split(".")[1]) < 4 - ), - "Graph bindings disallow concurrent replay for CUDA < 11.4, see " - + "https://github.com/pytorch/pytorch/pull/57556", - ) @unittest.skipIf( not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs" ) - def test_graph_concurrent_replay(self): - torch.cuda.empty_cache() - - size = 1000000 # largeish to help expose race conditions + @parametrize( + "with_amp,cache_enabled,allow_unused_input", + [ + subtest((False, False, True), decorators=[skipIfRocm]), + subtest((True, False, True), decorators=[skipIfRocm]), + subtest((True, True, True), decorators=[unittest.expectedFailure]), + subtest((False, False, False), decorators=[unittest.expectedFailure]), + ], + name_fn=lambda x, y, z: "{}{}{}".format( + {True: "with_amp", False: "without_amp"}[x], + {True: "_cache_enabled", False: "_cache_disabled"}[y] if x else "", + {True: "_allow_unused_input", False: "_not_allow_unused_input"}[z], + ), + ) + @serialTest() + def test_graph_make_graphed_callables( + self, with_amp, cache_enabled, allow_unused_input + ): + torch.manual_seed(5) + torch.cuda.manual_seed(5) - def func_with_temps(t, val): - x = t.clone() + val - y = t.clone() + val - return x + y + N, D_in, H, D_out = 640, 4096, 2048, 1024 - s = torch.cuda.Stream() + class MLP1(torch.nn.Module): + def __init__(self, D_in: int, H: int, D_out: int): + super().__init__() + self.net_1 = torch.nn.Sequential( + torch.nn.Linear(D_in, H), torch.nn.Dropout(p=0.1) + ).cuda() + self.net_2 = torch.nn.Sequential( + torch.nn.Linear(H, D_out), torch.nn.Dropout(p=0.2) + ).cuda() - for share_mem in ("Don't share", "via pool()", "via graph_pool_handle()"): - g0 = torch.cuda.CUDAGraph() - g1 = torch.cuda.CUDAGraph() + def forward(self, input_dict: dict): + x = input_dict["x"] + return self.net_2(self.net_1(x)) - s0 = torch.cuda.Stream() - s1 = torch.cuda.Stream() + class MLP2(torch.nn.Module): + def __init__(self, D_in: int, H: int, D_out: int): + super().__init__() + self.net_1 = torch.nn.Sequential( + torch.nn.Linear(D_in, H), torch.nn.Dropout(p=0.1) + ).cuda() + self.net_2 = torch.nn.Sequential( + torch.nn.Linear(H, D_out), torch.nn.Dropout(p=0.2) + ).cuda() - a = torch.ones((size,), device="cuda") + def forward(self, x): + return self.net_2(self.net_1(x)) - s.wait_stream(torch.cuda.current_stream()) - with torch.cuda.stream(s): - g0_args = ( - (torch.cuda.graph_pool_handle(),) - if share_mem == "via graph_pool_handle()" - else () + class ParameterlessModule(torch.nn.Module): + def forward(self, x): + idx = ( + torch.arange(x.size(0), device=x.device) + .view(-1, 1) + .repeat(1, x.size(1)) ) - g0.capture_begin(*g0_args) - b = a.clone() - for _ in range(5): - b = func_with_temps(b, 1) - g0.capture_end() - - g1_args = (g0.pool(),) if share_mem == "via pool()" else g0_args - g1.capture_begin(*g1_args) - c = a.clone() - for _ in range(5): - c = func_with_temps(c, 2) - g1.capture_end() + return {"output": torch.gather(x, 0, idx)} - # To reproduce data corruption, I need g0 and g1's kernels to run concurrently. - # But replay() (especially cudaGraphLaunch) can incur significant CPU overhead. - # The following pattern helps align device-side execution of g0 and g1's kernels. - torch.cuda.synchronize() - with torch.cuda.stream(s0): - torch.cuda._sleep(1000000) - s1.wait_stream(s0) - g0.replay() - with torch.cuda.stream(s1): - g1.replay() - torch.cuda.current_stream().wait_stream(s0) - torch.cuda.current_stream().wait_stream(s1) + models = [] + for _ in range(2): + model_section1 = MLP1(D_in, H, H).cuda() + model_section2 = MLP2(H, H, D_out).cuda() + model_section3 = ParameterlessModule().cuda() + models.append( + torch.nn.Sequential(model_section1, model_section2, model_section3) + ) - if (not TEST_CUDAMALLOCASYNC) and (share_mem != "Don't share"): - # If we used the native allocator and shared mempools, - # we expect the concurrent replays corrupted each other. - self.assertNotEqual(b.sum().item(), size * 94) - self.assertNotEqual(c.sum().item(), size * 156) - else: - # If we EITHER - # - used the native allocator without sharing mempools, OR - # - used cudaMallocAsync, which ignores graph pool-sharing hints and should always be safe - # we don't expect memory corruption. - self.assertEqual(b.sum().item(), size * 94) - self.assertEqual(c.sum().item(), size * 156) + model_graphed = models[0] + model_control = models[1] - del a, b, c, g0, g1 - # Tensors used across streams (a, b, c) were held until just now, so no need to call record_stream on them. - torch.cuda.synchronize() - torch.cuda.empty_cache() + model_graphed.load_state_dict(model_control.state_dict()) - @unittest.skipIf( - not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs" - ) - def test_graph_three_successive(self): - torch.cuda.empty_cache() + opt_graphed = torch.optim.SGD(model_graphed.parameters(), lr=0.1) + opt_control = torch.optim.SGD(model_control.parameters(), lr=0.1) - size = 1000 + x = torch.randn(N, D_in, device="cuda") + h = torch.randn(N, H, device="cuda", requires_grad=True) + h2 = torch.randn(N, D_out, device="cuda", requires_grad=True) + unused_input = torch.randn(N, H, device="cuda", requires_grad=True) + y_pred = torch.randn(N, D_out, device="cuda", requires_grad=True) + y = torch.randn(N, D_out, device="cuda") - s = torch.cuda.Stream() + loss_fn_control = torch.nn.functional.mse_loss + relu_control = torch.nn.functional.relu - for share_mem in ("Don't share", "via pool()", "via graph_pool_handle()"): - a = torch.ones((size,), device="cuda") + # This is a good stress test. It graphs four callables: two Modules and two python functions. + with torch.amp.autocast( + device_type="cuda", enabled=with_amp, cache_enabled=cache_enabled + ): + ( + model_graphed[0], + model_graphed[1], + model_graphed[2], + relu_graphed, + loss_fn_graphed, + ) = torch.cuda.make_graphed_callables( + ( + model_graphed[0], + model_graphed[1], + model_graphed[2], + relu_control, + loss_fn_control, + ), + ( + ({"x": x, "unused_input": unused_input},), + (h,), + (h2,), + (y_pred,), + (y_pred, y), + ), + allow_unused_input=allow_unused_input, + ) - g0 = torch.cuda.CUDAGraph() - g1 = torch.cuda.CUDAGraph() - g2 = torch.cuda.CUDAGraph() + real_inputs = [torch.rand_like(x) for _ in range(10)] + real_targets = [torch.rand_like(y) for _ in range(10)] - s.wait_stream(torch.cuda.current_stream()) - with torch.cuda.stream(s): - g0_args = ( - (torch.cuda.graph_pool_handle(),) - if share_mem == "via graph_pool_handle()" - else () - ) - g0.capture_begin(*g0_args) - b = a.clone() - c = b + 1 - d = b + 2 - g0.capture_end() + for m, opt, relu, loss_fn in zip( + (model_graphed, model_control), + (opt_graphed, opt_control), + (relu_graphed, relu_control), + (loss_fn_graphed, loss_fn_control), + ): + # Resets RNC states before iterations for graphed and ungraphed models, + # so dropout math should be bitwise identical for both. + torch.manual_seed(5) + torch.cuda.manual_seed(5) + for data, target in zip(real_inputs, real_targets): + opt.zero_grad(set_to_none=True) + with torch.amp.autocast( + device_type="cuda", enabled=with_amp, cache_enabled=cache_enabled + ): + y_pred = m({"x": data, "unused_input": unused_input})["output"] + y_pred = relu(y_pred) + loss = loss_fn(y_pred, target) + loss.backward() + opt.step() - args = (g0.pool(),) if share_mem == "via pool()" else g0_args + for p, pc in zip(model_graphed.parameters(), model_control.parameters()): + self.assertEqual(p, pc) - g1.capture_begin(*args) - e = c + 3 - del c - g1.capture_end() + # We graphed the models in training mode. Eval should still run ungraphed. + model_graphed.eval() + model_control.eval() + self.assertEqual( + model_graphed({"x": real_inputs[0]}), model_control({"x": real_inputs[0]}) + ) - g2.capture_begin(*args) - f = d + 4 - g2.capture_end() - torch.cuda.current_stream().wait_stream(s) + @unittest.skipIf( + not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs" + ) + @parametrize( + "with_amp,cache_enabled,allow_unused_input", + [ + subtest((False, False, True), decorators=[skipIfRocm]), + subtest((True, False, True), decorators=[skipIfRocm]), + subtest((True, True, True), decorators=[unittest.expectedFailure]), + subtest((False, False, False), decorators=[skipIfRocm]), + ], + name_fn=lambda x, y, z: "{}{}{}".format( + {True: "with_amp", False: "without_amp"}[x], + {True: "_cache_enabled", False: "_cache_disabled"}[y] if x else "", + {True: "_allow_unused_input", False: "_not_allow_unused_input"}[z], + ), + ) + @serialTest() + def test_graph_make_graphed_callables_parameterless_nograd_module( + self, with_amp, cache_enabled, allow_unused_input + ): + torch.manual_seed(5) + torch.cuda.manual_seed(5) - # Tests that replaying in capture order is valid - g0.replay() - g1.replay() - g2.replay() + N, D_in, H, D_out = 640, 4096, 2048, 1024 - self.assertEqual(e.sum().item(), size * 5) - self.assertEqual(f.sum().item(), size * 7) + class ParameterlessModule(torch.nn.Module): + def forward(self, input_dict: dict): + x = input_dict["x"] + idx = ( + torch.arange(x.size(0), device=x.device) + .view(-1, 1) + .repeat(1, x.size(1)) + ) + return {"output": torch.gather(x, 0, idx)} - # Tests that replaying as g0, g2, g1 is only valid if they don't share a pool - g0.replay() - g2.replay() - g1.replay() + models = [] + for _ in range(2): + model_section1 = ParameterlessModule().cuda() + models.append(torch.nn.Sequential(model_section1)) - expect_corruption = (not TEST_CUDAMALLOCASYNC) and ( - share_mem != "Don't share" - ) - # If we used the native allocator and shared mempools, g2's capture should have reused c's memory for f. - # We replayed g2 then g1, so we expect g1's captured "e = c + 3" mistakenly filled e with "f's vals + 3". - self.assertEqual( - e.sum().item(), size * (7 + 3) if expect_corruption else size * 5 + model_graphed = models[0] + model_control = models[1] + + model_graphed.load_state_dict(model_control.state_dict()) + + x = torch.randn(N, D_in, device="cuda", requires_grad=False) + unused_input = torch.randn(N, H, device="cuda", requires_grad=False) + y_pred = torch.randn(N, D_in, device="cuda", requires_grad=False) + y = torch.randn(N, D_in, device="cuda") + + # This is a good stress test. It graphs four callables: two Modules and two python functions. + with torch.amp.autocast( + device_type="cuda", enabled=with_amp, cache_enabled=cache_enabled + ): + model_graphed[0] = torch.cuda.make_graphed_callables( + model_graphed[0], + ({"x": x, "unused_input": unused_input},), + allow_unused_input=allow_unused_input, ) - self.assertEqual(f.sum().item(), size * 7) - del a, b, d, e, f, g0, g1, g2 - # Tensors used across streams (a, e, f) were held until just now, so no need to call record_stream on them. - torch.cuda.synchronize() - torch.cuda.empty_cache() + real_inputs = [torch.rand_like(x, requires_grad=True) for _ in range(10)] + real_targets = [torch.rand_like(y) for _ in range(10)] + + for m in (model_graphed, model_control): + # Resets RNC states before iterations for graphed and ungraphed models, + # so dropout math should be bitwise identical for both. + torch.manual_seed(5) + torch.cuda.manual_seed(5) + for data, _ in zip(real_inputs, real_targets): + with torch.amp.autocast( + device_type="cuda", enabled=with_amp, cache_enabled=cache_enabled + ): + out = m({"x": data, "unused_input": unused_input})["output"] + + # We graphed the models in training mode. Eval should still run ungraphed. + model_graphed.eval() + model_control.eval() + self.assertEqual( + model_graphed({"x": real_inputs[0]}), model_control({"x": real_inputs[0]}) + ) @unittest.skipIf( - (not TEST_CUDA_GRAPH) or TEST_CUDAMALLOCASYNC, - "CUDA >= 11.0 or ROCM >= 5.3 required for graphs", + not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs" ) - def test_graph_memory_stats_and_use_result_after_destroy_graph(self): - kSmallSize = 1048576 - kSmallBuffer = 2097152 - kLargeBuffer = 20971520 - kMinLargeAlloc = 10485760 - kRoundLarge = 2097152 - - elem = 4 + def test_graph_make_graphed_callables_same_pool(self): + torch.manual_seed(5) + torch.cuda.manual_seed(5) + models = [] + num_models = 3 + for _ in range(num_models): + models.append( + torch.nn.Sequential( + torch.nn.Linear(32, 128), + torch.nn.ReLU(), + torch.nn.Linear(128, 128), + ).cuda() + ) + # we will reuse the same pool for all graph captures + mempool = torch.cuda.graph_pool_handle() + graphed_models = [] + for model in models: + x = torch.randn([64, 32], device="cuda") + graphed_model = deepcopy(model) + graphed_model = torch.cuda.make_graphed_callables( + graphed_model, (x,), pool=mempool + ) + graphed_models.append(graphed_model) - # this was annoying to write but stresses the expectations pretty rigorously - cases = ( - (512 // elem, 1, kSmallBuffer, kSmallBuffer, "small_pool"), - (kSmallSize // elem, 2, 2 * kSmallBuffer, kSmallBuffer, "small_pool"), - ((kSmallSize + 512) // elem, 1, kLargeBuffer, kLargeBuffer, "large_pool"), - ( - (kMinLargeAlloc - 512) // elem, - 2, - 2 * kLargeBuffer, - kLargeBuffer, - "large_pool", - ), - ( - (kMinLargeAlloc + 512) // elem, - 3, - 3 - * ( - kRoundLarge - * ((kMinLargeAlloc + 512 + kRoundLarge - 1) // kRoundLarge) - ), - kRoundLarge * ((kMinLargeAlloc + 512 + kRoundLarge - 1) // kRoundLarge), - "large_pool", - ), - ) + for model, graphed_model in zip(models, graphed_models): + x = torch.randn([64, 32], device="cuda") + y = model(x) + yg = graphed_model(x) + l = y.norm() + lg = yg.norm() + l.backward() + lg.backward() - stats_to_check = ("segment.", "reserved_bytes.", "active.", "active_bytes.") + self.assertEqual(y, yg) + self.assertEqual(l, lg) + for p, pg in zip(model.parameters(), graphed_model.parameters()): + self.assertEqual(p, pg) + self.assertEqual(p.grad, pg.grad) + self.assertNotEqual(p.data_ptr(), pg.data_ptr()) + self.assertNotEqual(p.grad.data_ptr(), pg.grad.data_ptr()) - gc.collect() - torch.cuda.empty_cache() + def _test_graphed_optimizer( + self, steps_warmup, steps_train, optimizer_ctor, kwargs + ): + for actually_do_graphs in (True, False): + params = [torch.randn((i + 5, i + 5), device="cuda") for i in range(2)] + [ + torch.randn((), device="cuda") + ] + params_control = [p.clone().requires_grad_() for p in params] + params_graphed = [p.clone().requires_grad_() for p in params] - s = torch.cuda.Stream() + grads = [ + [torch.randn_like(p) for p in params] + for _ in range(steps_warmup + steps_train) + ] - for ( - numel, - delta_cudaMallocs, - delta_cudaMalloc_bytes, - delta_cudaMalloc_bytes_post_del_g, - pool_string, - ) in cases: - if pool_string == "small_pool": - delta_active_blocks = 3 # one from "b" plus a sneaky two from CUDAGraph's one-element rng seed and offset holders - delta_active_bytes = ( - numel * elem + 1024 - ) # + 1024 for CUDAGraph's rng seed and offset holders each - else: - delta_active_blocks = 1 # We only check the large pool, which isn't affected by rng offset holder - delta_active_bytes = numel * elem + # Control (capturable=False) - g = torch.cuda.CUDAGraph() - s.wait_stream(torch.cuda.current_stream()) - with torch.cuda.stream(s): - # Allocation stat estimates assume input is created on the same stream as capture_begin() - # (in other words, the same stream silo as the rng offset holder, which is not allocated from the - # capture's private pool). - a = torch.ones((numel,), device="cuda") + opt = optimizer_ctor(params_control, capturable=False, **kwargs) - precapture_stats = torch.cuda.memory_stats() + for i in range(steps_warmup + steps_train): + for j, p in enumerate(params_control): + p.grad = grads[i][j] + opt.step() - g.capture_begin() - b = a.clone() - for _ in range(5): - b = b.clone() + 1 - g.capture_end() - torch.cuda.current_stream().wait_stream(s) + # capturable=True - gc.collect() + opt = optimizer_ctor(params_graphed, capturable=True, **kwargs) - postcapture_stats = torch.cuda.memory_stats() + for i in range(steps_warmup): + for j, p in enumerate(params_graphed): + p.grad = grads[i][j] + opt.step() - expecteds = ( - delta_cudaMallocs, - delta_cudaMalloc_bytes, - delta_active_blocks, - delta_active_bytes, - ) - # Double checks replay and stats before and after a call to empty_cache - for i in range(2): - for stat, expected in zip(stats_to_check, expecteds): - stat = stat + pool_string + ".current" - current = postcapture_stats[stat] - precapture_stats[stat] + if actually_do_graphs: + g = torch.cuda.CUDAGraph() + with torch.cuda.graph(g): + opt.step() - # There will only ever be one expandable segment in each of the small and large pools. The way the - # bookeeping is done in the allocator means that we never increment the number of segments. - if self.expandable_segments and "segment" in stat: - expected = 0 - # These two cases hit an edge case where the PyTorch allocator won't immediately unmap part of an - # expandable segment (and as a result reduce the number of reserved bytes) if the block to unmap is - # smaller than the page size - if ( - self.expandable_segments - and "reserved" in stat - and (numel == cases[3][0] or numel == cases[4][0]) - ): - expected = 2 * kLargeBuffer - - self.assertEqual( - current, - expected, - "Pre to post capture delta of " - + stat - + f" = {current}, expected = {expected}, numel = {numel}", - ) - - g.replay() - self.assertEqual(b.sum().item(), 6 * numel) - if i == 0: - torch.cuda.empty_cache() - - del g - gc.collect() - torch.cuda.empty_cache() - postdel_stats = torch.cuda.memory_stats() - - # Uses graph result b after graph has been deleted - self.assertEqual(b.sum().item(), 6 * numel) - - # b should be the only live reference remaining from the graph's private pool - expecteds = (1, delta_cudaMalloc_bytes_post_del_g, 1, numel * elem) - for stat, expected in zip(stats_to_check, expecteds): - stat = stat + pool_string + ".current" - current = postdel_stats[stat] - precapture_stats[stat] - - # There will only ever be one expandable segment in each of the small and large pools. The way the - # bookeeping is done in the allocator means that we never increment the number of segments. - if self.expandable_segments and "segment" in stat: - expected = 0 - # These two cases hit an edge case where the PyTorch allocator won't immediately unmap part of an - # expandable segment (and as a result reduce the number of reserved bytes) if the block to unmap is - # smaller than the page size - if ( - self.expandable_segments - and "reserved" in stat - and numel == cases[3][0] - ): - expected = 2 * kLargeBuffer - if ( - self.expandable_segments - and "reserved" in stat - and numel == cases[4][0] - ): - expected = kLargeBuffer - - self.assertEqual( - current, - expected, - "Pre capture to post graph delete delta of " - + stat - + f" = {current}, expected = {expected}, numel = {numel}", - ) + for i in range(steps_train): + if actually_do_graphs: + for j, p in enumerate(params_graphed): + p.grad.copy_(grads[i + steps_warmup][j]) + g.replay() + else: + # Passing capturable=True to the constructor and running without graphs should still be + # numerically correct, even if it's not ideal for performance. + for j, p in enumerate(params_graphed): + p.grad = grads[i + steps_warmup][j] + opt.step() - # del a, b before the next case is essential, otherwise overwriting a and b in the next case - # can throw off its allocation/deallocation counts. - del a, b - # Tensors used across streams (a and b) were held until just now, so no need to call record_stream on them. - torch.cuda.synchronize() - torch.cuda.empty_cache() + for p_control, p_graphed in zip(params_control, params_graphed): + self.assertEqual(p_control, p_graphed) @unittest.skipIf( not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs" ) - def test_graph_record_stream(self): - # Makes sure graph capture defers attempting to reclaim allocations used across streams. See - # "Q. Why skip process_events if a capture might be underway?" in c10/cuda/CUDACachingAllocator.cpp - torch.cuda.empty_cache() - - potential_problem = torch.zeros((3,), device="cuda") - a = torch.zeros((3,), device="cuda") - s0 = torch.cuda.Stream() - s1 = torch.cuda.Stream() - s2 = torch.cuda.Stream() - g = torch.cuda.CUDAGraph() + def test_graph_optims_with_explicitly_capturable_param_groups(self): + # mimicking `_test_graphed_optimizer` maladroitly to pass two param_groups to optimizer.__init__ + n_warmup, n_replay = 3, 2 + for optimizer, second_param_group_capturable in product( + ( + torch.optim.Adam, + torch.optim.AdamW, + torch.optim.ASGD, + torch.optim.Adamax, + torch.optim.NAdam, + torch.optim.RAdam, + torch.optim.Adadelta, + torch.optim.RMSprop, + torch.optim.Rprop, + ), + (True, False), + ): + ref_p1, param1 = ( + torch.nn.Parameter(torch.ones(1, device="cuda")) for _ in range(2) + ) + ref_p2, param2 = ( + torch.nn.Parameter(torch.ones(1, device="cuda")) for _ in range(2) + ) + grads1, grads2 = ( + [torch.randn_like(param1) for _ in range(n_warmup + n_replay)] + for _ in range(2) + ) + ref_grads1, ref_grads2 = ( + [t.clone() for t in tensors] for tensors in (grads1, grads2) + ) + params = [ + {"params": [param1], "capturable": True}, + {"params": [param2], "capturable": second_param_group_capturable}, + ] + opt = optimizer(params) + opt_ = optimizer( + [ + {"params": [ref_p1], "capturable": False}, + {"params": [ref_p2], "capturable": False}, + ] + ) - torch.cuda.synchronize() - with torch.cuda.stream(s0): - potential_problem.record_stream(s0) - torch.cuda._sleep(TestCuda.FIFTY_MIL_CYCLES) - potential_problem.fill_(1.0) - del potential_problem + for i in range(n_warmup + n_replay): + ref_p1.grad = ref_grads1[i] + ref_p2.grad = ref_grads2[i] + opt_.step() - with torch.cuda.stream(s1): - g.capture_begin() - # potential_problem's allocation should still be outstanding. if DeviceCachingAllocator::malloc - # mistakenly calls process_events, it will trigger cudaEventQueries on potential_problem's end-of-life - # event, which will cause the capture to error. - b = a.clone() + for i in range(n_warmup): + param1.grad = grads1[i] + param2.grad = grads2[i] + opt.step() - # Let's also see what happens if we record_stream on a tensor during capture. - s2.wait_stream(s1) - with torch.cuda.stream(s2): - b.fill_(1.0) - b.record_stream(s2) # dummy record_stream - del b - s1.wait_stream(s2) - g.capture_end() - torch.cuda.synchronize() + g = torch.cuda.CUDAGraph() + if not second_param_group_capturable: + with self.assertRaisesRegex(RuntimeError, "Attempting CUDA graph"): + with torch.cuda.graph(g): + opt.step() + else: + with torch.cuda.graph(g): + opt.step() - # dummy allocation triggers process_events, Hopefully successfully processes b's end-of-life event. - c = torch.zeros((3,), device="cuda") + for i in range(n_replay): + param1.grad.copy_(grads1[n_warmup + i]) + param2.grad.copy_(grads2[n_warmup + i]) + g.replay() + self.assertEqual(ref_p1, param1) + self.assertEqual(ref_p2, param2) - @skipIfRocm @unittest.skipIf( not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs" ) - # If this test is the first in the process to try cudnn rnns with dropout, it'll initialize - # DropoutState's long-lived internal buffer. Calling code perceives this (correct) behavior - # as a memory leak unless we skip the leak check. - @skipCUDAMemoryLeakCheckIf(True) - @serialTest() - def test_graph_cudnn_dropout(self): - # Tests the interaction of cuda graph capture with DropoutState's syncs in ATen/native/cudnn/RNN.cpp. - # In particular, if user runs a sequence of captured and noncaptured cudnn rnns, DropoutState should - # avoid syncing noncapturing streams with captured events or vice versa. - torch.cuda.empty_cache() + def test_cuda_graph_error_options(self): + def fn(): + x = torch.zeros([2000], device="cuda") + y = x + x + x + return y - model = torch.nn.LSTM(512, 512, 2, dropout=0.5).cuda() - x = torch.ones(100, 192, 512, device="cuda") + mem = None - y = model(x) + def raw_malloc(): + global mem + mem = None + stream = torch.cuda.Stream() + try: + with torch.cuda.stream(stream): + mem = torch.cuda.caching_allocator_alloc(1024) + except BaseException: + if mem is None: + return + try: + torch.cuda.caching_allocator_delete(mem) + mem = None + return None + except BaseException: + pass - g = torch.cuda.CUDAGraph() - s = torch.cuda.Stream() - s.wait_stream(torch.cuda.current_stream()) - with torch.cuda.stream(s): - g.capture_begin() - y = model(x) - g.capture_end() - torch.cuda.current_stream().wait_stream(s) + def throws_on_cuda_event(capture_error_mode): + graph = torch.cuda.CUDAGraph() + torch.cuda.synchronize() + stream = torch.cuda.Stream() + stream.wait_stream(torch.cuda.current_stream()) + with torch.cuda.stream(stream): + fn() + stream.synchronize() + torch.cuda.current_stream().wait_stream(stream) + torch.cuda.synchronize() + try: + with torch.cuda.graph( + graph, stream=stream, capture_error_mode=capture_error_mode + ): + out = fn() + thread = threading.Thread(target=raw_malloc) + thread.start() + thread.join() + except Exception: + if mem is not None: + torch.cuda.caching_allocator_delete(mem) + return True - g.replay() + return False - y = model(x) + self.assertFalse(throws_on_cuda_event("thread_local")) + self.assertFalse(throws_on_cuda_event("relaxed")) + + # Exception would Corrupt Process and make other tests fail + # self.assertTrue(throws_on_cuda_event("global")) @unittest.skipIf( not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs" ) - @parametrize( - "with_amp,cache_enabled,allow_unused_input", - [ - subtest((False, False, True), decorators=[skipIfRocm]), - subtest((True, False, True), decorators=[skipIfRocm]), - subtest((True, True, True), decorators=[unittest.expectedFailure]), - subtest((False, False, False), decorators=[unittest.expectedFailure]), - ], - name_fn=lambda x, y, z: "{}{}{}".format( - {True: "with_amp", False: "without_amp"}[x], - {True: "_cache_enabled", False: "_cache_disabled"}[y] if x else "", - {True: "_allow_unused_input", False: "_not_allow_unused_input"}[z], - ), - ) - @serialTest() - def test_graph_make_graphed_callables( - self, with_amp, cache_enabled, allow_unused_input - ): - torch.manual_seed(5) - torch.cuda.manual_seed(5) - - N, D_in, H, D_out = 640, 4096, 2048, 1024 + def test_cuda_graph_allocator_propagates_stream(self): + segments = torch.cuda.memory_snapshot() + existing_pools = {s["segment_pool_id"] for s in segments} + x = torch.randn(10240000, device="cuda") + y = torch.rand_like(x) + g = torch.cuda.CUDAGraph() + s0 = torch.cuda.Stream() + s1 = torch.cuda.Stream() + s0.wait_stream(torch.cuda.current_stream()) + with torch.cuda.stream(s0): + g.capture_begin() + z = x + y + with torch.cuda.stream(s1): + s1.wait_stream(s0) + w = z + y + s0.wait_stream(s1) + with torch.cuda.stream(s0): + g.capture_end() + segments = torch.cuda.memory_snapshot() + x = [ + s["segment_pool_id"] + for s in segments + if s["segment_pool_id"] not in existing_pools + ] + self.assertEqual(len(x), 2) + self.assertEqual(x[0], x[1]) - class MLP1(torch.nn.Module): - def __init__(self, D_in: int, H: int, D_out: int): - super().__init__() - self.net_1 = torch.nn.Sequential( - torch.nn.Linear(D_in, H), torch.nn.Dropout(p=0.1) - ).cuda() - self.net_2 = torch.nn.Sequential( - torch.nn.Linear(H, D_out), torch.nn.Dropout(p=0.2) - ).cuda() + def test_batch_norm_gather_stats(self): + input = torch.randn(1, 3, 3, 3, device="cuda") + mean, invstd = torch.batch_norm_gather_stats( + input, + mean=torch.ones(2, 3, device="cuda"), + invstd=torch.ones(2, 3, device="cuda"), + running_mean=None, + running_var=None, + momentum=0.1, + eps=1e-5, + count=2, + ) + self.assertEqual(mean, torch.ones(3, device="cuda")) + self.assertEqual(invstd, torch.ones(3, device="cuda")) - def forward(self, input_dict: dict): - x = input_dict["x"] - return self.net_2(self.net_1(x)) + def test_matmul_memory_use(self): + def get_max_used(): + torch.cuda.synchronize() + val = torch.cuda.max_memory_allocated() + torch.cuda.reset_peak_memory_stats() + return val - class MLP2(torch.nn.Module): - def __init__(self, D_in: int, H: int, D_out: int): - super().__init__() - self.net_1 = torch.nn.Sequential( - torch.nn.Linear(D_in, H), torch.nn.Dropout(p=0.1) - ).cuda() - self.net_2 = torch.nn.Sequential( - torch.nn.Linear(H, D_out), torch.nn.Dropout(p=0.2) - ).cuda() + a = torch.rand(1, 32, 32, device="cuda") + b = torch.rand(24, 32, 1, device="cuda") - def forward(self, x): - return self.net_2(self.net_1(x)) + get_max_used() - class ParameterlessModule(torch.nn.Module): - def forward(self, x): - idx = ( - torch.arange(x.size(0), device=x.device) - .view(-1, 1) - .repeat(1, x.size(1)) - ) - return {"output": torch.gather(x, 0, idx)} + torch.matmul(a, b) - models = [] - for _ in range(2): - model_section1 = MLP1(D_in, H, H).cuda() - model_section2 = MLP2(H, H, D_out).cuda() - model_section3 = ParameterlessModule().cuda() - models.append( - torch.nn.Sequential(model_section1, model_section2, model_section3) - ) + matmul_mem = get_max_used() - model_graphed = models[0] - model_control = models[1] + a = a.expand(24, 32, 32) + torch.matmul(a, b) - model_graphed.load_state_dict(model_control.state_dict()) + matmul_expand_mem = get_max_used() - opt_graphed = torch.optim.SGD(model_graphed.parameters(), lr=0.1) - opt_control = torch.optim.SGD(model_control.parameters(), lr=0.1) + torch.bmm(a, b) - x = torch.randn(N, D_in, device="cuda") - h = torch.randn(N, H, device="cuda", requires_grad=True) - h2 = torch.randn(N, D_out, device="cuda", requires_grad=True) - unused_input = torch.randn(N, H, device="cuda", requires_grad=True) - y_pred = torch.randn(N, D_out, device="cuda", requires_grad=True) - y = torch.randn(N, D_out, device="cuda") + bmm_mem = get_max_used() - loss_fn_control = torch.nn.functional.mse_loss - relu_control = torch.nn.functional.relu + self.assertEqual(matmul_expand_mem, matmul_mem) + self.assertEqual(bmm_mem, matmul_mem) - # This is a good stress test. It graphs four callables: two Modules and two python functions. - with torch.cuda.amp.autocast(with_amp, cache_enabled=cache_enabled): - ( - model_graphed[0], - model_graphed[1], - model_graphed[2], - relu_graphed, - loss_fn_graphed, - ) = torch.cuda.make_graphed_callables( - ( - model_graphed[0], - model_graphed[1], - model_graphed[2], - relu_control, - loss_fn_control, - ), - ( - ({"x": x, "unused_input": unused_input},), - (h,), - (h2,), - (y_pred,), - (y_pred, y), - ), - allow_unused_input=allow_unused_input, - ) + @unittest.skipIf(not TEST_WITH_ROCM, "ROCm-only test") + def test_rocm_backward_pass_guard(self): + # The test exercises a ROCm-specific feature. - real_inputs = [torch.rand_like(x) for _ in range(10)] - real_targets = [torch.rand_like(y) for _ in range(10)] + class MyFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, tensor, constant): + self.assertFalse(torch._C._rocm_is_backward_pass()) + ctx.constant = constant + return tensor * constant - for m, opt, relu, loss_fn in zip( - (model_graphed, model_control), - (opt_graphed, opt_control), - (relu_graphed, relu_control), - (loss_fn_graphed, loss_fn_control), - ): - # Resets RNC states before iterations for graphed and ungraphed models, - # so dropout math should be bitwise identical for both. - torch.manual_seed(5) - torch.cuda.manual_seed(5) - for data, target in zip(real_inputs, real_targets): - opt.zero_grad(set_to_none=True) - with torch.cuda.amp.autocast(with_amp, cache_enabled=cache_enabled): - y_pred = m({"x": data, "unused_input": unused_input})["output"] - y_pred = relu(y_pred) - loss = loss_fn(y_pred, target) - loss.backward() - opt.step() + @staticmethod + def backward(ctx, grad_output): + self.assertTrue(torch._C._rocm_is_backward_pass()) + return grad_output * ctx.constant, None - for p, pc in zip(model_graphed.parameters(), model_control.parameters()): - self.assertEqual(p, pc) + class MyModule(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.a = torch.nn.Parameter(torch.randn(())) - # We graphed the models in training mode. Eval should still run ungraphed. - model_graphed.eval() - model_control.eval() - self.assertEqual( - model_graphed({"x": real_inputs[0]}), model_control({"x": real_inputs[0]}) - ) + def forward(self, x): + return MyFunction.apply(x, self.a) - @unittest.skipIf( - not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs" - ) - @parametrize( - "with_amp,cache_enabled,allow_unused_input", - [ - subtest((False, False, True), decorators=[skipIfRocm]), - subtest((True, False, True), decorators=[skipIfRocm]), - subtest((True, True, True), decorators=[unittest.expectedFailure]), - subtest((False, False, False), decorators=[skipIfRocm]), - ], - name_fn=lambda x, y, z: "{}{}{}".format( - {True: "with_amp", False: "without_amp"}[x], - {True: "_cache_enabled", False: "_cache_disabled"}[y] if x else "", - {True: "_allow_unused_input", False: "_not_allow_unused_input"}[z], - ), - ) - @serialTest() - def test_graph_make_graphed_callables_parameterless_nograd_module( - self, with_amp, cache_enabled, allow_unused_input - ): - torch.manual_seed(5) - torch.cuda.manual_seed(5) + model = MyModule() + criterion = torch.nn.MSELoss(reduction="sum") + optimizer = torch.optim.SGD(model.parameters(), lr=1e-6) - N, D_in, H, D_out = 640, 4096, 2048, 1024 + x = torch.randn(5, 5) + result = model(x) + loss = criterion(result, x) + optimizer.zero_grad() + loss.backward() + optimizer.step() - class ParameterlessModule(torch.nn.Module): - def forward(self, input_dict: dict): - x = input_dict["x"] - idx = ( - torch.arange(x.size(0), device=x.device) - .view(-1, 1) - .repeat(1, x.size(1)) - ) - return {"output": torch.gather(x, 0, idx)} + def test_matmul_device_mismatch(self): + cpu = torch.rand((10, 10)) + cuda = cpu.cuda() + with self.assertRaisesRegex( + RuntimeError, "Expected all tensors to be on the same device" + ): + cpu @ cuda + with self.assertRaisesRegex( + RuntimeError, "Expected all tensors to be on the same device" + ): + cuda @ cpu - models = [] - for _ in range(2): - model_section1 = ParameterlessModule().cuda() - models.append(torch.nn.Sequential(model_section1)) - - model_graphed = models[0] - model_control = models[1] - - model_graphed.load_state_dict(model_control.state_dict()) + for s, m1, m2 in product((cpu, cuda), repeat=3): + if s.device == m1.device == m2.device: + torch.addmm(s, m1, m2) + else: + with self.assertRaisesRegex( + RuntimeError, "Expected all tensors to be on the same device" + ): + torch.addmm(s, m1, m2) - x = torch.randn(N, D_in, device="cuda", requires_grad=False) - unused_input = torch.randn(N, H, device="cuda", requires_grad=False) - y_pred = torch.randn(N, D_in, device="cuda", requires_grad=False) - y = torch.randn(N, D_in, device="cuda") + @unittest.skipIf(TEST_MULTIGPU, "Testing on one GPU is sufficient") + def test_lazy_init(self): + """Validate that no CUDA calls are made during `import torch` call""" - # This is a good stress test. It graphs four callables: two Modules and two python functions. - with torch.cuda.amp.autocast(with_amp, cache_enabled=cache_enabled): - model_graphed[0] = torch.cuda.make_graphed_callables( - model_graphed[0], - ({"x": x, "unused_input": unused_input},), - allow_unused_input=allow_unused_input, + def check_output(script: str) -> str: + return ( + subprocess.check_output([sys.executable, "-c", script]) + .decode("ascii") + .strip() ) - real_inputs = [torch.rand_like(x, requires_grad=True) for _ in range(10)] - real_targets = [torch.rand_like(y) for _ in range(10)] - - for m in (model_graphed, model_control): - # Resets RNC states before iterations for graphed and ungraphed models, - # so dropout math should be bitwise identical for both. - torch.manual_seed(5) - torch.cuda.manual_seed(5) - for data, target in zip(real_inputs, real_targets): - with torch.cuda.amp.autocast(with_amp, cache_enabled=cache_enabled): - out = m({"x": data, "unused_input": unused_input})["output"] - - # We graphed the models in training mode. Eval should still run ungraphed. - model_graphed.eval() - model_control.eval() - self.assertEqual( - model_graphed({"x": real_inputs[0]}), model_control({"x": real_inputs[0]}) + VISIBLE_DEVICES = ( + "HIP_VISIBLE_DEVICES" if TEST_WITH_ROCM else "CUDA_VISIBLE_DEVICES" ) - - @unittest.skipIf( - not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs" - ) - def test_graph_make_graphed_callables_same_pool(self): - torch.manual_seed(5) - torch.cuda.manual_seed(5) - models = [] - num_models = 3 - for _ in range(num_models): - models.append( - torch.nn.Sequential( - torch.nn.Linear(32, 128), - torch.nn.ReLU(), - torch.nn.Linear(128, 128), - ).cuda() + test_script = f"import os; import torch;os.environ['{VISIBLE_DEVICES}']='32';print(torch.cuda.device_count())" + rc = check_output(test_script) + self.assertEqual(rc, "0") + if not TEST_WITH_ROCM: + # Check that `cuInit` was not called during the import + # By using ctypes and calling cuDeviceCountGet() and expect CUDA_ERROR_NOT_INITIALIZED == 3 + # See https://github.com/pytorch/pytorch/issues/116276 for more details + libcuda_name = "libcuda.so.1" if not IS_WINDOWS else "nvcuda.dll" + cuda_driver_api_call = ( + f"ctypes.CDLL('{libcuda_name}').cuDeviceGetCount(ctypes.byref(x))" ) - # we will reuse the same pool for all graph captures - mempool = torch.cuda.graph_pool_handle() - graphed_models = [] - for model in models: - x = torch.randn([64, 32], device="cuda") - graphed_model = deepcopy(model) - graphed_model = torch.cuda.make_graphed_callables( - graphed_model, (x,), pool=mempool + rc = check_output( + f"import torch; import ctypes;x=ctypes.c_int(-1);print({cuda_driver_api_call})" ) - graphed_models.append(graphed_model) + self.assertEqual(rc, "3") - for model, graphed_model in zip(models, graphed_models): - x = torch.randn([64, 32], device="cuda") - y = model(x) - yg = graphed_model(x) - l = y.norm() - lg = yg.norm() - l.backward() - lg.backward() + @unittest.skipIf(not TEST_WITH_ROCM, "not relevant for CUDA testing") + def test_hip_device_count(self): + """Validate device_count works with both CUDA/HIP visible devices""" + test_script = """\ +import torch +import os +print(f"{torch.cuda.device_count()}") +""" + custom_envs = [ + {"CUDA_VISIBLE_DEVICES": "0", "HIP_VISIBLE_DEVICES": None}, + {"CUDA_VISIBLE_DEVICES": None, "HIP_VISIBLE_DEVICES": "0"}, + {"CUDA_VISIBLE_DEVICES": "0,1,2,3", "HIP_VISIBLE_DEVICES": "0"}, + ] - self.assertEqual(y, yg) - self.assertEqual(l, lg) - for p, pg in zip(model.parameters(), graphed_model.parameters()): - self.assertEqual(p, pg) - self.assertEqual(p.grad, pg.grad) - self.assertNotEqual(p.data_ptr(), pg.data_ptr()) - self.assertNotEqual(p.grad.data_ptr(), pg.grad.data_ptr()) + for env_config in custom_envs: + env = os.environ.copy() + for key, value in env_config.items(): + if value is None: + env.pop(key, None) + else: + env[key] = value + r = ( + subprocess.check_output([sys.executable, "-c", test_script], env=env) + .decode("ascii") + .strip() + ) + self.assertEqual("1", r) - def _test_graphed_optimizer( - self, steps_warmup, steps_train, optimizer_ctor, kwargs - ): - for actually_do_graphs in (True, False): - params = [torch.randn((i + 5, i + 5), device="cuda") for i in range(2)] + [ - torch.randn((), device="cuda") - ] - params_control = [p.clone().requires_grad_() for p in params] - params_graphed = [p.clone().requires_grad_() for p in params] + @unittest.skipIf(not TEST_MULTIGPU, "requires multiple devices") + def test_device_count_not_cached_pre_init(self): + visible_devices = ( + "HIP_VISIBLE_DEVICES" if torch.version.hip else "CUDA_VISIBLE_DEVICES" + ) + test_script = f"""\ +import torch +import os +r1 = torch.cuda.device_count() +os.environ['{visible_devices}'] = '0' +r2 = torch.cuda.device_count() +torch.empty(10, device='cuda') +print(f"{{r1}}, {{r2}}") +""" - grads = [ - [torch.randn_like(p) for p in params] - for _ in range(steps_warmup + steps_train) - ] + r = ( + subprocess.check_output([sys.executable, "-c", test_script]) + .decode("ascii") + .strip() + ) - # Control (capturable=False) + x = torch.cuda.device_count() + self.assertEqual(f"{x}, 1", r) - opt = optimizer_ctor(params_control, capturable=False, **kwargs) + @unittest.skip("Disabling as USE_CUFILE=0 by default in builds") + def test_gds_fails_in_ci(self): + if IS_WINDOWS or TEST_WITH_ROCM: + error_msg = "is not supported on this platform" + else: + error_msg = "cuFileHandleRegister failed" + with TemporaryFileName() as f: + with self.assertRaisesRegex(RuntimeError, error_msg): + file = torch.cuda.gds._GdsFile(f, os.O_CREAT | os.O_RDWR) - for i in range(steps_warmup + steps_train): - for j, p in enumerate(params_control): - p.grad = grads[i][j] - opt.step() - # capturable=True +@unittest.skipIf(not TEST_CUDA, "CUDA not available, skipping tests") +@torch.testing._internal.common_utils.markDynamoStrictTest +class TestCudaMallocAsync(TestCase): + @unittest.skipIf( + TEST_CUDAMALLOCASYNC, "setContextRecorder not supported by CUDAMallocAsync" + ) + def test_memory_snapshot(self): + try: + torch.cuda.memory.empty_cache() + torch.cuda.memory._record_memory_history("state", stacks="python") + # make x the second block in a segment + torch.rand(2 * 311, 411, device="cuda") + unused = torch.rand(310, 410, device="cuda") + x = torch.rand(311, 411, device="cuda") - opt = optimizer_ctor(params_graphed, capturable=True, **kwargs) + # create a bunch of tensors that all will tile into the + # same segment to exercise the history merging code + # 512B is the minimum block size, + # so we allocate all the tensors to this size to make sure + # they tile evenly + tensors = [torch.rand(128, device="cuda") for _ in range(1000)] + while tensors: + del tensors[randint(0, len(tensors) - 1)] - for i in range(steps_warmup): - for j, p in enumerate(params_graphed): - p.grad = grads[i][j] - opt.step() + # exercise the history trimming code + torch.rand(128 * 5, device="cuda") - if actually_do_graphs: - g = torch.cuda.CUDAGraph() - with torch.cuda.graph(g): - opt.step() + ss = torch.cuda.memory._snapshot() + found_it = False + for seg in ss["segments"]: + self.assertTrue("frames" in seg) + for b in seg["blocks"]: + if b["requested_size"] == 311 * 411 * 4: + self.assertTrue("test_cuda" in b["frames"][0]["filename"]) + found_it = True + self.assertEqual(x.untyped_storage().data_ptr(), b["address"]) + self.assertTrue(found_it) - for i in range(steps_train): - if actually_do_graphs: - for j, p in enumerate(params_graphed): - p.grad.copy_(grads[i + steps_warmup][j]) - g.replay() - else: - # Passing capturable=True to the constructor and running without graphs should still be - # numerically correct, even if it's not ideal for performance. - for j, p in enumerate(params_graphed): - p.grad = grads[i + steps_warmup][j] - opt.step() + if not IS_WINDOWS: + with tempfile.NamedTemporaryFile() as f: + torch.cuda.memory._save_segment_usage(f.name) + with open(f.name) as f2: + self.assertTrue("test_cuda.py" in f2.read()) + del unused + del x + torch.cuda.empty_cache() + ss = torch.cuda.memory._snapshot() + self.assertTrue( + ss["device_traces"][0][-1]["action"] + in ("segment_free", "segment_unmap") + ) - for p_control, p_graphed in zip(params_control, params_graphed): - self.assertEqual(p_control, p_graphed) + finally: + torch.cuda.memory._record_memory_history(None) - @unittest.skipIf( - not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs" - ) - def test_graph_optims_with_explicitly_capturable_param_groups(self): - # mimicking `_test_graphed_optimizer` maladroitly to pass two param_groups to optimizer.__init__ - n_warmup, n_replay = 3, 2 - for optimizer, second_param_group_capturable in product( - ( - torch.optim.Adam, - torch.optim.AdamW, - torch.optim.ASGD, - torch.optim.Adamax, - torch.optim.NAdam, - torch.optim.RAdam, - torch.optim.Adadelta, - torch.optim.RMSprop, - torch.optim.Rprop, - ), - (True, False), - ): - ref_p1, param1 = ( - torch.nn.Parameter(torch.ones(1, device="cuda")) for _ in range(2) - ) - ref_p2, param2 = ( - torch.nn.Parameter(torch.ones(1, device="cuda")) for _ in range(2) - ) - grads1, grads2 = ( - [torch.randn_like(param1) for _ in range(n_warmup + n_replay)] - for _ in range(2) - ) - ref_grads1, ref_grads2 = ( - [t.clone() for t in tensors] for tensors in (grads1, grads2) - ) - params = [ - {"params": [param1], "capturable": True}, - {"params": [param2], "capturable": second_param_group_capturable}, - ] - opt = optimizer(params) - opt_ = optimizer( - [ - {"params": [ref_p1], "capturable": False}, - {"params": [ref_p2], "capturable": False}, - ] - ) - - for i in range(n_warmup + n_replay): - ref_p1.grad = ref_grads1[i] - ref_p2.grad = ref_grads2[i] - opt_.step() - - for i in range(n_warmup): - param1.grad = grads1[i] - param2.grad = grads2[i] - opt.step() - - g = torch.cuda.CUDAGraph() - if not second_param_group_capturable: - with self.assertRaisesRegex(RuntimeError, "Attempting CUDA graph"): - with torch.cuda.graph(g): - opt.step() - else: - with torch.cuda.graph(g): - opt.step() + @unittest.skipIf(IS_ARM64 or not IS_LINUX, "x86 linux only cpp unwinding") + def test_direct_traceback(self): + from torch._C._profiler import gather_traceback, symbolize_tracebacks - for i in range(n_replay): - param1.grad.copy_(grads1[n_warmup + i]) - param2.grad.copy_(grads2[n_warmup + i]) - g.replay() - self.assertEqual(ref_p1, param1) - self.assertEqual(ref_p2, param2) + c = gather_traceback(True, True, True) + (r,) = symbolize_tracebacks([c]) + r = str(r) + self.assertTrue("test_cuda.py" in r) + self.assertTrue("unwind" in r) @unittest.skipIf( - not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs" + TEST_CUDAMALLOCASYNC, "setContextRecorder not supported by CUDAMallocAsync" ) - def test_cuda_graph_error_options(self): - def fn(): - x = torch.zeros([2000], device="cuda") - y = x + x + x - return y - - mem = None - - def raw_malloc(): - global mem - mem = None - stream = torch.cuda.Stream() - try: - with torch.cuda.stream(stream): - mem = torch.cuda.caching_allocator_alloc(1024) - except BaseException: - if mem is None: - return - try: - torch.cuda.caching_allocator_delete(mem) - mem = None - return None - except BaseException: - pass - - def throws_on_cuda_event(capture_error_mode): - graph = torch.cuda.CUDAGraph() - torch.cuda.synchronize() - stream = torch.cuda.Stream() - stream.wait_stream(torch.cuda.current_stream()) - with torch.cuda.stream(stream): - fn() - stream.synchronize() - torch.cuda.current_stream().wait_stream(stream) - torch.cuda.synchronize() - try: - with torch.cuda.graph( - graph, stream=stream, capture_error_mode=capture_error_mode - ): - out = fn() - thread = threading.Thread(target=raw_malloc) - thread.start() - thread.join() - except Exception: - if mem is not None: - torch.cuda.caching_allocator_delete(mem) - return True + @unittest.skipIf(IS_ARM64 or not IS_LINUX, "cpp contexts are x86 linux only") + def test_memory_snapshot_with_cpp(self): + try: + torch.cuda.memory.empty_cache() + torch.cuda.memory._record_memory_history("state", stacks="all") + x = torch.rand(311, 411, device="cuda") - return False + ss = torch.cuda.memory._snapshot()["segments"] + found_it = False + for seg in ss: + for b in seg["blocks"]: + if b["requested_size"] == 311 * 411 * 4: + self.assertTrue("::rand" in str(b["frames"])) + found_it = True + self.assertTrue(found_it) - self.assertFalse(throws_on_cuda_event("thread_local")) - self.assertFalse(throws_on_cuda_event("relaxed")) + finally: + torch.cuda.memory._record_memory_history(None) - # Exception would Corrupt Process and make other tests fail - # self.assertTrue(throws_on_cuda_event("global")) + @skipIfRocm + def test_memory_profiler_viz(self): + with torch.profiler.profile( + with_stack=True, profile_memory=True, record_shapes=True + ) as prof: + x = torch.rand(128, 128, device="cuda") + x * x + x * x + plot = profile_plot(prof) + plot = json.dumps(_profile_to_snapshot(prof)) + self.assertTrue("test_cuda.py" in plot) + self.assertTrue("test_memory_profiler_viz" in plot) + self.assertTrue("category" in plot) @unittest.skipIf( - not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs" + TEST_CUDAMALLOCASYNC, "setContextRecorder not supported by CUDAMallocAsync" ) - def test_cuda_graph_allocator_propagates_stream(self): - segments = torch.cuda.memory_snapshot() - existing_pools = {s["segment_pool_id"] for s in segments} - x = torch.randn(10240000, device="cuda") - y = torch.rand_like(x) - g = torch.cuda.CUDAGraph() - s0 = torch.cuda.Stream() - s1 = torch.cuda.Stream() - s0.wait_stream(torch.cuda.current_stream()) - with torch.cuda.stream(s0): - g.capture_begin() - z = x + y - with torch.cuda.stream(s1): - s1.wait_stream(s0) - w = z + y - s0.wait_stream(s1) - with torch.cuda.stream(s0): - g.capture_end() - segments = torch.cuda.memory_snapshot() - x = [ - s["segment_pool_id"] - for s in segments - if s["segment_pool_id"] not in existing_pools - ] - self.assertEqual(len(x), 2) - self.assertEqual(x[0], x[1]) + @unittest.skipIf(IS_ARM64 or not IS_LINUX, "cpp contexts are x86 linux only") + def test_cycles(self): + fired = False - def test_batch_norm_gather_stats(self): - input = torch.randn(1, 3, 3, 3, device="cuda") - mean, invstd = torch.batch_norm_gather_stats( - input, - mean=torch.ones(2, 3, device="cuda"), - invstd=torch.ones(2, 3, device="cuda"), - running_mean=None, - running_var=None, - momentum=0.1, - eps=1e-5, - count=2, - ) - self.assertEqual(mean, torch.ones(3, device="cuda")) - self.assertEqual(invstd, torch.ones(3, device="cuda")) + def observer(html): + nonlocal fired + fired = True + self.assertTrue("torch.Tensor" in html) + self.assertTrue("test_cuda" in html) + self.assertTrue("cell_contents" in html) - def test_matmul_memory_use(self): - def get_max_used(): - torch.cuda.synchronize() - val = torch.cuda.max_memory_allocated() - torch.cuda.reset_peak_memory_stats() - return val + disarm = observe_tensor_cycles(observer) - a = torch.rand(1, 32, 32, device="cuda") - b = torch.rand(24, 32, 1, device="cuda") + def noop(): + pass - get_max_used() + try: - torch.matmul(a, b) + def create(): + x = torch.empty(3, 4, device="cuda") - matmul_mem = get_max_used() + def foo(p): + if p: + return foo(not p) + else: + return x - a = a.expand(24, 32, 32) - torch.matmul(a, b) + return foo - matmul_expand_mem = get_max_used() + create() + gc.collect() + # the callback has to run outside of the collect + # call so it doesn't actual fire until the next + # method call after a gc.collect + noop() + self.assertTrue(fired) + finally: + disarm() - torch.bmm(a, b) + @unittest.skipIf( + TEST_CUDAMALLOCASYNC, "setContextRecorder not supported by CUDAMallocAsync" + ) + @unittest.skipIf(IS_ARM64 or not IS_LINUX, "cpp contexts are x86 linux only") + def test_memory_plots(self): + for context, stacks in ( + ("all", "all" if IS_LINUX else "python"), + ("all", "python"), + (None, "python"), + ): + try: + torch.cuda.memory.empty_cache() + torch.cuda.memory._record_memory_history( + "all", context=context, stacks=stacks + ) - bmm_mem = get_max_used() - - self.assertEqual(matmul_expand_mem, matmul_mem) - self.assertEqual(bmm_mem, matmul_mem) - - @unittest.skipIf(not TEST_WITH_ROCM, "ROCm-only test") - def test_rocm_backward_pass_guard(self): - # The test exercises a ROCm-specific feature. - - class MyFunction(torch.autograd.Function): - @staticmethod - def forward(ctx, tensor, constant): - self.assertFalse(torch._C._rocm_is_backward_pass()) - ctx.constant = constant - return tensor * constant - - @staticmethod - def backward(ctx, grad_output): - self.assertTrue(torch._C._rocm_is_backward_pass()) - return grad_output * ctx.constant, None + def run(): + x = torch.rand(128, 128, device="cuda") + x * x + x * x - class MyModule(torch.nn.Module): - def __init__(self) -> None: - super().__init__() - self.a = torch.nn.Parameter(torch.randn(())) + run() + cpp = stacks == "all" + record_context = context is not None + ss = torch.cuda.memory._snapshot() - def forward(self, x): - return MyFunction.apply(x, self.a) + tplot = trace_plot(ss) + splot = segment_plot(ss) + text = json.dumps(ss) - model = MyModule() - criterion = torch.nn.MSELoss(reduction="sum") - optimizer = torch.optim.SGD(model.parameters(), lr=1e-6) + self.assertTrue(record_context == ("test_memory_plots" in text)) + self.assertTrue(cpp == ("::rand" in text)) + self.assertTrue(str(128 * 128 * 4) in text) - x = torch.randn(5, 5) - result = model(x) - loss = criterion(result, x) - optimizer.zero_grad() - loss.backward() - optimizer.step() + finally: + torch.cuda.memory._record_memory_history(None) - def test_matmul_device_mismatch(self): - cpu = torch.rand((10, 10)) - cuda = cpu.cuda() - with self.assertRaisesRegex( - RuntimeError, "Expected all tensors to be on the same device" - ): - cpu @ cuda - with self.assertRaisesRegex( - RuntimeError, "Expected all tensors to be on the same device" - ): - cuda @ cpu + @unittest.skipIf( + TEST_CUDAMALLOCASYNC, "setContextRecorder not supported by CUDAMallocAsync" + ) + @unittest.skipIf(IS_ARM64 or not IS_LINUX, "cpp contexts are x86 linux only") + def test_memory_plots_free_stack(self): + for context in ["alloc", "all", "state"]: + try: + torch.cuda.memory.empty_cache() + torch.cuda.memory._record_memory_history(context=context) + x = None - for s, m1, m2 in product((cpu, cuda), repeat=3): - if s.device == m1.device == m2.device: - torch.addmm(s, m1, m2) - else: - with self.assertRaisesRegex( - RuntimeError, "Expected all tensors to be on the same device" - ): - torch.addmm(s, m1, m2) + def thealloc(): + nonlocal x + x = torch.rand(3, 4, device="cuda") - @unittest.skipIf(TEST_MULTIGPU, "Testing on one GPU is sufficient") - def test_lazy_init(self): - """Validate that no CUDA calls are made during `import torch` call""" + def thefree(): + nonlocal x + del x - def check_output(script: str) -> str: - return ( - subprocess.check_output([sys.executable, "-c", script]) - .decode("ascii") - .strip() - ) + thealloc() + thefree() + ss = json.dumps(torch.cuda.memory._snapshot()) + self.assertTrue(("thefree" in ss) == (context == "all")) + self.assertTrue(("thealloc" in ss) == (context != "state")) + finally: + torch.cuda.memory._record_memory_history(None) - VISIBLE_DEVICES = ( - "HIP_VISIBLE_DEVICES" if TEST_WITH_ROCM else "CUDA_VISIBLE_DEVICES" - ) - test_script = f"import os; import torch;os.environ['{VISIBLE_DEVICES}']='32';print(torch.cuda.device_count())" - rc = check_output(test_script) - self.assertEqual(rc, "0") - if not TEST_WITH_ROCM: - # Check that `cuInit` was not called during the import - # By using ctypes and calling cuDeviceCountGet() and expect CUDA_ERROR_NOT_INITIALIZED == 3 - # See https://github.com/pytorch/pytorch/issues/116276 for more details - libcuda_name = "libcuda.so.1" if not IS_WINDOWS else "nvcuda.dll" - cuda_driver_api_call = ( - f"ctypes.CDLL('{libcuda_name}').cuDeviceGetCount(ctypes.byref(x))" - ) - rc = check_output( - f"import torch; import ctypes;x=ctypes.c_int(-1);print({cuda_driver_api_call})" - ) - self.assertEqual(rc, "3") + @unittest.skipIf( + TEST_CUDAMALLOCASYNC, "setContextRecorder not supported by CUDAMallocAsync" + ) + @unittest.skipIf(IS_ARM64 or not IS_LINUX, "cpp contexts are x86 linux only") + def test_memory_plots_history_context(self): + try: + torch.cuda.memory.empty_cache() + x = None - @unittest.skipIf(not TEST_WITH_ROCM, "not relevant for CUDA testing") - def test_hip_device_count(self): - """Validate device_count works with both CUDA/HIP visible devices""" - test_script = """\ -import torch -import os -print(f"{torch.cuda.device_count()}") -""" - custom_envs = [ - {"CUDA_VISIBLE_DEVICES": "0", "HIP_VISIBLE_DEVICES": None}, - {"CUDA_VISIBLE_DEVICES": None, "HIP_VISIBLE_DEVICES": "0"}, - {"CUDA_VISIBLE_DEVICES": "0,1,2,3", "HIP_VISIBLE_DEVICES": "0"}, - ] + def should_capture1(): + nonlocal x + x = torch.rand(4, 4, device="cuda") - for env_config in custom_envs: - env = os.environ.copy() - for key, value in env_config.items(): - if value is None: - env.pop(key, None) - else: - env[key] = value - r = ( - subprocess.check_output([sys.executable, "-c", test_script], env=env) - .decode("ascii") - .strip() - ) - self.assertEqual("1", r) + def should_not_capture(): + nonlocal x + x = torch.rand(3, 4, device="cuda") - @unittest.skipIf(not TEST_MULTIGPU, "requires multiple devices") - def test_device_count_not_cached_pre_init(self): - visible_devices = ( - "HIP_VISIBLE_DEVICES" if torch.version.hip else "CUDA_VISIBLE_DEVICES" - ) - test_script = f"""\ -import torch -import os -r1 = torch.cuda.device_count() -os.environ['{visible_devices}'] = '0' -r2 = torch.cuda.device_count() -torch.empty(10, device='cuda') -print(f"{{r1}}, {{r2}}") -""" + def should_capture2(): + nonlocal x + x = torch.rand(4, 4, device="cuda") - r = ( - subprocess.check_output([sys.executable, "-c", test_script]) - .decode("ascii") - .strip() - ) + # Recording with context and python call stacks should capture the call stack. + torch.cuda.memory._record_memory_history(context="all", stacks="python") + should_capture1() + # Recording with context=None should not capture the call stack. + torch.cuda.memory._record_memory_history(context=None) + should_not_capture() + # Recording with context and python call stacks should capture the call stack. + torch.cuda.memory._record_memory_history(context="all", stacks="python") + should_capture2() - x = torch.cuda.device_count() - self.assertEqual(f"{x}, 1", r) + ss = json.dumps(torch.cuda.memory._snapshot()) + self.assertTrue("should_capture1" in ss) + self.assertTrue("should_not_capture" not in ss) + self.assertTrue("should_capture2" in ss) + finally: + torch.cuda.memory._record_memory_history(None) - @unittest.skip("Disabling as USE_CUFILE=0 by default in builds") - def test_gds_fails_in_ci(self): - if IS_WINDOWS or TEST_WITH_ROCM: - error_msg = "is not supported on this platform" - else: - error_msg = "cuFileHandleRegister failed" - with TemporaryFileName() as f: - with self.assertRaisesRegex(RuntimeError, error_msg): - file = torch.cuda.gds._GdsFile(f, os.O_CREAT | os.O_RDWR) + @unittest.skipIf( + TEST_CUDAMALLOCASYNC, "setContextRecorder not supported by CUDAMallocAsync" + ) + @unittest.skipIf(IS_ARM64 or not IS_LINUX, "cpp contexts are x86 linux only") + def test_memory_plots_free_segment_stack(self): + for context in ["alloc", "all", "state"]: + try: + torch.cuda.memory.empty_cache() + torch.cuda.memory._record_memory_history(context=context) + x = torch.rand(3, 4, device="cuda") + del x + torch.cuda.memory.empty_cache() + ss = json.dumps(torch.cuda.memory._snapshot()) + self.assertTrue(("empty_cache" in ss) == (context == "all")) + finally: + torch.cuda.memory._record_memory_history(None) -@torch.testing._internal.common_utils.markDynamoStrictTest -class TestCudaMallocAsync(TestCase): @unittest.skipIf( TEST_CUDAMALLOCASYNC, "setContextRecorder not supported by CUDAMallocAsync" ) - def test_memory_snapshot(self): + def test_memory_snapshot_script(self): try: torch.cuda.memory.empty_cache() torch.cuda.memory._record_memory_history("state", stacks="python") - # make x the second block in a segment - torch.rand(2 * 311, 411, device="cuda") - unused = torch.rand(310, 410, device="cuda") - x = torch.rand(311, 411, device="cuda") - # create a bunch of tensors that all will tile into the - # same segment to exercise the history merging code - # 512B is the minimum block size, - # so we allocate all the tensors to this size to make sure - # they tile evenly - tensors = [torch.rand(128, device="cuda") for _ in range(1000)] - while tensors: - del tensors[randint(0, len(tensors) - 1)] + @torch.jit.script + def foo(): + return torch.rand(311, 411, device="cuda") - # exercise the history trimming code - torch.rand(128 * 5, device="cuda") + x = foo() - ss = torch.cuda.memory._snapshot() + ss = torch.cuda.memory._snapshot()["segments"] found_it = False - for seg in ss["segments"]: - self.assertTrue("frames" in seg) + for seg in ss: for b in seg["blocks"]: if b["requested_size"] == 311 * 411 * 4: - self.assertTrue("test_cuda" in b["frames"][0]["filename"]) - found_it = True - self.assertEqual(x.untyped_storage().data_ptr(), b["address"]) - self.assertTrue(found_it) - - if not IS_WINDOWS: - with tempfile.NamedTemporaryFile() as f: - torch.cuda.memory._save_segment_usage(f.name) - with open(f.name) as f2: - self.assertTrue("test_cuda.py" in f2.read()) - del unused - del x - torch.cuda.empty_cache() - ss = torch.cuda.memory._snapshot() - self.assertTrue( - ss["device_traces"][0][-1]["action"] - in ("segment_free", "segment_unmap") - ) - - finally: - torch.cuda.memory._record_memory_history(None) - - @unittest.skipIf(IS_ARM64 or not IS_LINUX, "x86 linux only cpp unwinding") - def test_direct_traceback(self): - from torch._C._profiler import gather_traceback, symbolize_tracebacks - - c = gather_traceback(True, True, True) - (r,) = symbolize_tracebacks([c]) - r = str(r) - self.assertTrue("test_cuda.py" in r) - self.assertTrue("unwind" in r) - - @unittest.skipIf( - TEST_CUDAMALLOCASYNC, "setContextRecorder not supported by CUDAMallocAsync" - ) - @unittest.skipIf(IS_ARM64 or not IS_LINUX, "cpp contexts are x86 linux only") - def test_memory_snapshot_with_cpp(self): - try: - torch.cuda.memory.empty_cache() - torch.cuda.memory._record_memory_history("state", stacks="all") - x = torch.rand(311, 411, device="cuda") - - ss = torch.cuda.memory._snapshot()["segments"] - found_it = False - for seg in ss: - for b in seg["blocks"]: - if b["requested_size"] == 311 * 411 * 4: - self.assertTrue("::rand" in str(b["frames"])) - found_it = True - self.assertTrue(found_it) - - finally: - torch.cuda.memory._record_memory_history(None) - - @skipIfRocm - def test_memory_profiler_viz(self): - with torch.profiler.profile( - with_stack=True, profile_memory=True, record_shapes=True - ) as prof: - x = torch.rand(128, 128, device="cuda") - x * x + x * x - plot = profile_plot(prof) - plot = json.dumps(_profile_to_snapshot(prof)) - self.assertTrue("test_cuda.py" in plot) - self.assertTrue("test_memory_profiler_viz" in plot) - self.assertTrue("category" in plot) - - @unittest.skipIf( - TEST_CUDAMALLOCASYNC, "setContextRecorder not supported by CUDAMallocAsync" - ) - @unittest.skipIf(IS_ARM64 or not IS_LINUX, "cpp contexts are x86 linux only") - def test_cycles(self): - fired = False - - def observer(html): - nonlocal fired - fired = True - self.assertTrue("torch.Tensor" in html) - self.assertTrue("test_cuda" in html) - self.assertTrue("cell_contents" in html) - - disarm = observe_tensor_cycles(observer) - - def noop(): - pass - - try: - - def create(): - x = torch.empty(3, 4, device="cuda") - - def foo(p): - if p: - return foo(not p) - else: - return x - - return foo - - create() - gc.collect() - # the callback has to run outside of the collect - # call so it doesn't actual fire until the next - # method call after a gc.collect - noop() - self.assertTrue(fired) - finally: - disarm() - - @unittest.skipIf( - TEST_CUDAMALLOCASYNC, "setContextRecorder not supported by CUDAMallocAsync" - ) - @unittest.skipIf(IS_ARM64 or not IS_LINUX, "cpp contexts are x86 linux only") - def test_memory_plots(self): - for context, stacks in ( - ("all", "all" if IS_LINUX else "python"), - ("all", "python"), - (None, "python"), - ): - try: - torch.cuda.memory.empty_cache() - torch.cuda.memory._record_memory_history( - "all", context=context, stacks=stacks - ) - - def run(): - x = torch.rand(128, 128, device="cuda") - x * x + x * x - - run() - cpp = stacks == "all" - record_context = context is not None - ss = torch.cuda.memory._snapshot() - - tplot = trace_plot(ss) - splot = segment_plot(ss) - text = json.dumps(ss) - - self.assertTrue(record_context == ("test_memory_plots" in text)) - self.assertTrue(cpp == ("::rand" in text)) - self.assertTrue(str(128 * 128 * 4) in text) - - finally: - torch.cuda.memory._record_memory_history(None) - - @unittest.skipIf( - TEST_CUDAMALLOCASYNC, "setContextRecorder not supported by CUDAMallocAsync" - ) - @unittest.skipIf(IS_ARM64 or not IS_LINUX, "cpp contexts are x86 linux only") - def test_memory_plots_free_stack(self): - for context in ["alloc", "all", "state"]: - try: - torch.cuda.memory.empty_cache() - torch.cuda.memory._record_memory_history(context=context) - x = None - - def thealloc(): - nonlocal x - x = torch.rand(3, 4, device="cuda") - - def thefree(): - nonlocal x - del x - - thealloc() - thefree() - ss = json.dumps(torch.cuda.memory._snapshot()) - self.assertTrue(("thefree" in ss) == (context == "all")) - self.assertTrue(("thealloc" in ss) == (context != "state")) - finally: - torch.cuda.memory._record_memory_history(None) - - @unittest.skipIf( - TEST_CUDAMALLOCASYNC, "setContextRecorder not supported by CUDAMallocAsync" - ) - @unittest.skipIf(IS_ARM64 or not IS_LINUX, "cpp contexts are x86 linux only") - def test_memory_plots_history_context(self): - try: - torch.cuda.memory.empty_cache() - x = None - - def should_capture1(): - nonlocal x - x = torch.rand(4, 4, device="cuda") - - def should_not_capture(): - nonlocal x - x = torch.rand(3, 4, device="cuda") - - def should_capture2(): - nonlocal x - x = torch.rand(4, 4, device="cuda") - - # Recording with context and python call stacks should capture the call stack. - torch.cuda.memory._record_memory_history(context="all", stacks="python") - should_capture1() - # Recording with context=None should not capture the call stack. - torch.cuda.memory._record_memory_history(context=None) - should_not_capture() - # Recording with context and python call stacks should capture the call stack. - torch.cuda.memory._record_memory_history(context="all", stacks="python") - should_capture2() - - ss = json.dumps(torch.cuda.memory._snapshot()) - self.assertTrue("should_capture1" in ss) - self.assertTrue("should_not_capture" not in ss) - self.assertTrue("should_capture2" in ss) - finally: - torch.cuda.memory._record_memory_history(None) - - @unittest.skipIf( - TEST_CUDAMALLOCASYNC, "setContextRecorder not supported by CUDAMallocAsync" - ) - @unittest.skipIf(IS_ARM64 or not IS_LINUX, "cpp contexts are x86 linux only") - def test_memory_plots_free_segment_stack(self): - for context in ["alloc", "all", "state"]: - try: - torch.cuda.memory.empty_cache() - torch.cuda.memory._record_memory_history(context=context) - x = torch.rand(3, 4, device="cuda") - del x - torch.cuda.memory.empty_cache() - - ss = json.dumps(torch.cuda.memory._snapshot()) - self.assertTrue(("empty_cache" in ss) == (context == "all")) - finally: - torch.cuda.memory._record_memory_history(None) - - @unittest.skipIf( - TEST_CUDAMALLOCASYNC, "setContextRecorder not supported by CUDAMallocAsync" - ) - def test_memory_snapshot_script(self): - try: - torch.cuda.memory.empty_cache() - torch.cuda.memory._record_memory_history("state", stacks="python") - - @torch.jit.script - def foo(): - return torch.rand(311, 411, device="cuda") - - x = foo() - - ss = torch.cuda.memory._snapshot()["segments"] - found_it = False - for seg in ss: - for b in seg["blocks"]: - if b["requested_size"] == 311 * 411 * 4: - self.assertTrue(b["frames"][0]["name"] == "foo") + self.assertTrue(b["frames"][0]["name"] == "foo") found_it = True self.assertTrue(found_it) @@ -4492,7 +3955,7 @@ def reconstruct_from_tensor_metadata(metadata): return t -@unittest.skipIf(TEST_CUDAMALLOCASYNC or TEST_WITH_ROCM, "NYI") +@unittest.skipIf(not TEST_CUDA or TEST_CUDAMALLOCASYNC or TEST_WITH_ROCM, "NYI") @torch.testing._internal.common_utils.markDynamoStrictTest class TestBlockStateAbsorption(TestCase): @property @@ -4855,6 +4318,7 @@ def test_no_triton_on_import(self): self.assertEqual(rc, "False", "Triton was imported when importing torch!") +@unittest.skipIf(not TEST_CUDA, "CUDA not available, skipping tests") class TestMemPool(TestCase): def test_mempool_id(self): pool1 = torch.cuda.graph_pool_handle() @@ -4980,6 +4444,7 @@ def create_mempool_and_make_active(): self.assertEqual(len(set(active_pool_ids)), 4) +@unittest.skipIf(not TEST_CUDA, "CUDA not available, skipping tests") @torch.testing._internal.common_utils.markDynamoStrictTest class TestCudaOptims(TestCase): # These tests will be instantiate with instantiate_device_type_tests @@ -5322,6 +4787,7 @@ def test_graph_grad_scaling(self, device, dtype, optim_info, foreach, fused): self.assertEqual(scaler._growth_tracker, growth_tracker) +@unittest.skipIf(not TEST_CUDA, "CUDA not available, skipping tests") class TestGDS(TestCase): def _get_tmp_dir_fs_type(self): my_path = os.path.realpath("/tmp") @@ -5356,6 +4822,526 @@ def test_gds_read_write_tensors(self): torch.cuda.gds._gds_deregister_buffer(src2.untyped_storage()) +@unittest.skipIf(not TEST_CUDA, "CUDA not available, skipping tests") +class TestCudaAutocast(TestAutocast): + def setUp(self): + super().setUp() + self.autocast_lists = AutocastTestLists(torch.device("cuda:0")) + + def tearDown(self): + del self.autocast_lists + super().tearDown() + + @unittest.skipIf(not TEST_CUDNN, "CUDNN not available") + def test_autocast_torch_fp16(self): + with torch.backends.cudnn.flags(enabled=True, deterministic=True): + for op_with_args in self.autocast_lists.torch_fp16: + skip_test = False + op, args = op_with_args[0], op_with_args[1] + if len(op_with_args) == 3: + skip_test = op_with_args[2] # TEST_WITH_ROCM + if not skip_test: + self._run_autocast_outofplace( + op, args, torch.float16, device="cuda", amp_dtype=torch.float16 + ) + + @unittest.skipIf(not TEST_CUDNN, "CUDNN not available") + def test_autocast_torch_bf16(self): + with torch.backends.cudnn.flags(enabled=True, deterministic=True): + for op_with_args in self.autocast_lists.torch_fp16: + skip_test = False + op, args = op_with_args[0], op_with_args[1] + if len(op_with_args) == 3: + skip_test = op_with_args[2] # TEST_WITH_ROCM + should_error_from_cudnn = "cudnn" in op and ( + "TORCH_CUDNN_V8_API_DISABLED" in os.environ + and int(os.environ["TORCH_CUDNN_V8_API_DISABLED"]) + or torch.cuda.get_device_capability() < (8, 0) + ) + should_error_from_not_implemented = should_error_from_cudnn + if not skip_test: + if should_error_from_not_implemented: + with self.assertRaises( + RuntimeError, + msg=str(op) + " should not be supported for bfloat16!", + ): + self._run_autocast_outofplace( + op, args, torch.bfloat16, device="cuda" + ) + else: + if torch.cuda.is_bf16_supported(): + self._run_autocast_outofplace( + op, args, torch.bfloat16, device="cuda" + ) + else: + with self.assertRaisesRegex( + RuntimeError, "Device does not support bfloat16" + ): + self._run_autocast_outofplace( + op, args, torch.bfloat16, device="cuda" + ) + + @unittest.skipIf(not TEST_CUDNN, "CUDNN not available") + def test_autocast_torch_fp32(self): + for op_with_args in self.autocast_lists.torch_fp32: + op, args, maybe_kwargs = self.args_maybe_kwargs(op_with_args) + self._run_autocast_outofplace( + op, + args, + torch.float32, + device="cuda", + add_kwargs=maybe_kwargs, + amp_dtype=torch.float16, + ) + + @unittest.skipIf(not TEST_CUDNN, "CUDNN not available") + def test_autocast_torch_need_autocast_promote(self): + for op, args in self.autocast_lists.torch_need_autocast_promote: + self._run_autocast_outofplace( + op, args, torch.float32, device="cuda", amp_dtype=torch.float16 + ) + + @unittest.skipIf(not TEST_CUDNN, "CUDNN not available") + def test_autocast_torch_expect_builtin_promote(self): + for op, args, out_type in self.autocast_lists.torch_expect_builtin_promote: + self._run_autocast_outofplace( + op, + args, + torch.float32, + device="cuda", + out_type=out_type, + amp_dtype=torch.float16, + ) + + @unittest.skipIf(not TEST_CUDNN, "CUDNN not available") + def test_autocast_nn_fp16(self): + with torch.backends.cudnn.flags(enabled=True, deterministic=True): + for op, args in self.autocast_lists.nn_fp16: + self._run_autocast_outofplace( + op, + args, + torch.float16, + device="cuda", + module=torch._C._nn, + amp_dtype=torch.float16, + ) + + @unittest.skipIf(not TEST_CUDNN, "CUDNN not available") + def test_autocast_nn_bf16(self): + with torch.backends.cudnn.flags(enabled=True, deterministic=True): + for op, args in self.autocast_lists.nn_fp16: + if torch.cuda.is_bf16_supported(): + self._run_autocast_outofplace( + op, args, torch.bfloat16, device="cuda", module=torch._C._nn + ) + else: + with self.assertRaisesRegex( + RuntimeError, "Device does not support bfloat16" + ): + self._run_autocast_outofplace( + op, args, torch.bfloat16, device="cuda", module=torch._C._nn + ) + + @unittest.skipIf(not TEST_CUDNN, "CUDNN not available") + def test_autocast_nn_fp32(self): + for op, args in self.autocast_lists.nn_fp32: + self._run_autocast_outofplace( + op, + args, + torch.float32, + device="cuda", + module=torch._C._nn, + amp_dtype=torch.float16, + ) + + @unittest.skipIf(not TEST_CUDNN, "CUDNN not available") + def test_autocast_linalg_fp16(self): + with torch.backends.cudnn.flags(enabled=True, deterministic=True): + for op, args in self.autocast_lists.linalg_fp16: + self._run_autocast_outofplace( + op, + args, + torch.float16, + device="cuda", + module=torch._C._linalg, + amp_dtype=torch.float16, + ) + + @unittest.skipIf(not TEST_CUDNN, "CUDNN not available") + def test_autocast_methods_fp16(self): + with torch.backends.cudnn.flags(enabled=True, deterministic=True): + for op, args in self.autocast_lists.methods_fp16: + self._run_autocast_outofplace( + op, + args, + torch.float16, + device="cuda", + module=None, + amp_dtype=torch.float16, + ) + + @unittest.skipIf(not TEST_CUDNN, "CUDNN not available") + def test_autocast_methods_fp32(self): + for op, args in self.autocast_lists.methods_fp32: + self._run_autocast_outofplace( + op, + args, + torch.float32, + device="cuda", + module=None, + amp_dtype=torch.float16, + ) + + @unittest.skipIf(not TEST_CUDNN, "CUDNN not available") + def test_autocast_methods_expect_builtin_promote(self): + for op, args, out_type in self.autocast_lists.methods_expect_builtin_promote: + self._run_autocast_outofplace( + op, + args, + torch.float32, + device="cuda", + module=None, + out_type=out_type, + amp_dtype=torch.float16, + ) + + def test_autocast_banned(self): + with torch.autocast("cuda"): + for op, args, module in self.autocast_lists.banned: + with self.assertRaises(RuntimeError): + getattr(module, op)(*args) + + def test_autocast_ignored_types(self): + with torch.autocast("cuda"): + for ignore_type in (torch.double, torch.int32): + a_ignore = torch.ones((8, 8), dtype=ignore_type, device="cuda:0") + b_ignore = torch.ones((8, 8), dtype=ignore_type, device="cuda:0") + c_16 = torch.ones((8, 8), dtype=torch.float16, device="cuda:0") + + # Tests if CastPolicy::fp16 ops ignore double and int + # Currently, no ops belonging to this policy support integer inputs. + if ignore_type is torch.double: + with self.assertRaises(RuntimeError): + torch.mm(a_ignore, c_16) + with torch.autocast("cuda", enabled=False): + type_no_autocast = torch.mm(a_ignore, b_ignore).dtype + self.assertTrue( + torch.mm(a_ignore, b_ignore).dtype is type_no_autocast + ) + + # Tests if CastPolicy::fp32 ops ignore double and int + with torch.autocast("cuda", enabled=False): + type_no_autocast = torch.pow(a_ignore, 2.0).dtype + self.assertTrue(torch.pow(a_ignore, 2.0).dtype is type_no_autocast) + + # Tests if CastPolicy::fp32_set_opt_dtype ops ignore double and int + with torch.autocast("cuda", enabled=False): + type_no_autocast = torch.sum(a_ignore).dtype + self.assertTrue(torch.sum(a_ignore).dtype is type_no_autocast) + + # Tests if CastPolicy::fp32_append_dtype ops ignore double and int + # Currently, no ops belonging to this policy support integer inputs. + if ignore_type is torch.double: + with torch.autocast("cuda", enabled=False): + type_no_autocast = torch.norm(a_ignore).dtype + self.assertTrue(torch.norm(a_ignore).dtype is type_no_autocast) + + def test_autocast_custom_enabled(self): + class MyMM(torch.autograd.Function): + @staticmethod + @torch.amp.custom_fwd(device_type="cuda") + def forward(ctx, a, b): + self.assertTrue(a.dtype is torch.float32) + self.assertTrue(b.dtype is torch.float32) + self.assertTrue(torch.is_autocast_enabled()) + ctx.save_for_backward(a, b) + return a.mm(b) + + @staticmethod + @torch.amp.custom_bwd(device_type="cuda") + def backward(ctx, grad): + self.assertTrue(torch.is_autocast_enabled()) + a, b = ctx.saved_tensors + a_grad, b_grad = grad.mm(b.t()), a.t().mm(grad) + self.assertTrue(a_grad.dtype is dtype and b_grad.dtype is dtype) + return a_grad, b_grad + + mymm = MyMM.apply + + x = torch.randn((8, 8), device="cuda", dtype=torch.float32, requires_grad=True) + y = torch.randn((8, 8), device="cuda", dtype=torch.float32, requires_grad=True) + + dtypes = (torch.float16, torch.bfloat16) if TEST_BF16 else (torch.float16,) + for dtype in dtypes: + with torch.cuda.amp.autocast(dtype=dtype): + output = mymm(x, y) + self.assertTrue(output.dtype is dtype) + loss = output.sum() + loss.backward() + + def test_autocast_custom_cast_inputs(self): + class MyMM(torch.autograd.Function): + @staticmethod + @torch.amp.custom_fwd(device_type="cuda", cast_inputs=torch.float32) + def forward(ctx, a, container, expect_type): + b = container[1][0] + self.assertTrue(a.dtype is expect_type) + self.assertTrue(b.dtype is expect_type) + self.assertFalse(torch.is_autocast_enabled()) + ctx.save_for_backward(a, b) + return a.mm(b) + + @staticmethod + @torch.amp.custom_bwd(device_type="cuda") + def backward(ctx, grad): + self.assertFalse(torch.is_autocast_enabled()) + a, b = ctx.saved_tensors + return grad.mm(b.t()), None, None + + mymm = MyMM.apply + + x = torch.randn((8, 8), device="cuda", dtype=torch.float16, requires_grad=True) + # Puts one input tensor in a nested container. y's contained Tensor won't receive a gradient, + # because torch.autograd.Function can't hand gradients back to non-Tensor forward arguments. + # Sets requires_grad=False explicitly so we don't lie about expecting a gradient. + y = ( + 0, + { + 0: torch.randn( + (8, 8), device="cuda", dtype=torch.float16, requires_grad=False + ) + }, + ) + + with torch.autocast("cuda"): + output = mymm(x, y, torch.float32) + self.assertTrue(output.dtype is torch.float32) + loss = output.sum() + loss.backward() + + # Tests if custom_fwd becomes a no-op when mymm runs outside an autocast-enabled region. + output = mymm(x, y, torch.float16) + self.assertTrue(output.dtype is torch.float16) + loss = output.sum() + loss.backward() + + def test_autocast_custom_deprecated_warning(self): + with warnings.catch_warnings(record=True) as w: + + class MyMM(torch.autograd.Function): + @staticmethod + @torch.cuda.amp.custom_fwd(cast_inputs=torch.float32) + def forward(ctx, x, y): + ctx.save_for_backward(x, y) + self.assertFalse(torch.is_autocast_enabled()) + return x + y + + @staticmethod + @torch.cuda.amp.custom_bwd + def backward(ctx, grad): + _, _ = ctx.saved_tensors + self.assertFalse(torch.is_autocast_enabled()) + return grad, grad + + self.assertRegex( + str(w[0].message), r"`torch.cuda.amp.custom_fwd\(args...\)` is deprecated." + ) + self.assertRegex( + str(w[1].message), r"`torch.cuda.amp.custom_bwd\(args...\)` is deprecated." + ) + + mymm = MyMM.apply + x = torch.randn(3, 3, requires_grad=True) + y = torch.randn(3, 3, requires_grad=True) + with torch.amp.autocast("cuda"): + output = mymm(x, y) + loss = output.sum() + loss.backward() + + def test_autocast_cat_jit(self): + # Reported at https://github.com/pytorch/pytorch/issues/38958 + + class Model(torch.nn.Module): + def forward(self): + a = torch.randn(1) + b = torch.randn(1) + c = torch.cat((a, b), 0) + d = torch.stack([c, c], 0) + return d + + # The JIT here doesn't really matter, we just need to call + # cat via the boxed API + model = Model() + model_jit_script = torch.jit.script(model) + + with torch.autocast("cuda", enabled=True): + model() + model_jit_script() + + # cudnn RNNs require special backend handling (weights are cast to FP16 and reflattened) + # so they get a dedicated test. + # Despite the large number of RNN cases it tries, the test takes < 15 seconds on a Titan V (similar to V100). + @unittest.skipIf(not TEST_CUDNN, "CUDNN not available") + def test_autocast_rnn(self): + with torch.backends.cudnn.flags(enabled=True, deterministic=True): + # seq, batch, features, hidden size + clses = ("RNN", "GRU", "LSTM") + T, B, F, H = 3, 4, 5, 6 + dtypes = (torch.float16, torch.float32) + input_layouts = ("seq_first", "batch_first", "packed") + + for ( + cls, + num_layers, + bias, + input_layout, + bidirectional, + try_nonpreflattened_weights, + input_dtype, + hidden_dtype, + weight_dtype, + ) in product( + clses, + (1, 2), + (True, False), + input_layouts, + (True, False), + (True, False), + dtypes, + dtypes, + dtypes, + ): + if input_layout == "seq_first": + batch_first = False + x = torch.randn((T, B, F), device="cuda", dtype=input_dtype) + elif input_layout == "batch_first": + batch_first = True + x = torch.randn((B, T, F), device="cuda", dtype=input_dtype) + elif input_layout == "packed": + batch_first = False + x = torch.nn.utils.rnn.pack_padded_sequence( + torch.randn((T, B, F), device="cuda", dtype=input_dtype), + lengths=(3, 2, 1, 3), + enforce_sorted=False, + ) + + rnn = ( + getattr(torch.nn, cls)( + F, + H, + num_layers=num_layers, + bidirectional=bidirectional, + bias=bias, + batch_first=batch_first, + ) + .cuda() + .to(dtype=weight_dtype) + ) + + if try_nonpreflattened_weights: + for p in rnn.parameters(): + with torch.no_grad(): + p.set_(p.clone()) + + h = torch.randn( + (num_layers * (2 if bidirectional else 1), B, H), + device="cuda", + dtype=hidden_dtype, + ) + if cls == "LSTM": + c = torch.randn( + (num_layers * (2 if bidirectional else 1), B, H), + device="cuda", + dtype=hidden_dtype, + ) + h = (h, c) + + with torch.autocast("cuda"): + out, h_out = rnn(x, h) + out = out.data if input_layout == "packed" else out + self.assertEqual(out.dtype, torch.float16) + # Autocast wrapper requires at::_cudnn_rnn is autograd-exposed. This check can't guarantee + # at::_cudnn_rnn is autograd-exposed, but if it fires, it indicates some funny business has + # occurred and we should double check that at::_cudnn_rnn remains autograd-exposed. + self.assertEqual( + out.grad_fn.name(), + "MiopenRnnBackward0" if torch.version.hip else "CudnnRnnBackward0", + ) + out.sum().backward() + grads = [p.grad.clone() for p in rnn.parameters()] + + rnn.zero_grad() + + if cls == "LSTM": + out_control, h_out_control = rnn.to(dtype=torch.float16)( + x.half(), (h[0].half(), h[1].half()) + ) + else: + out_control, h_out_control = rnn.to(dtype=torch.float16)( + x.half(), h.half() + ) + out_control = ( + out_control.data if input_layout == "packed" else out_control + ) + out_control.sum().backward() + grads_control = [p.grad.clone() for p in rnn.parameters()] + + # Compares with default tolerances, even for FP16 execution. Barring nondeterminism, + # autocast and control results should be bitwise identical. + self.assertEqual(out, out_control) + + if cls == "LSTM": + self.assertTrue( + h_out[0].dtype is torch.float16 + and h_out[1].dtype is torch.float16 + ) + self.assertEqual(h_out[0], h_out_control[0]) + self.assertEqual(h_out[1], h_out_control[1]) + else: + self.assertEqual(h_out.dtype, torch.float16) + self.assertEqual(h_out, h_out_control) + for grad, grad_control in zip(grads, grads_control): + self.assertEqual(grad.half(), grad_control) + + def test_autocast_cache_leak(self): + # Reported at https://github.com/pytorch/pytorch/issues/48049 + # Test is used to check, if autocast recaches the same parameters + # when executed in a `torch.no_grad()` block. + + linear = torch.nn.Linear(10, 10).to("cuda") + data = torch.randn(1, 10, device="cuda") + + with torch.autocast("cuda"): + with torch.no_grad(): + out = linear(data) + first_iter_mem = torch.cuda.memory_allocated() + for _ in range(3): + out = linear(data) + self.assertTrue(first_iter_mem == torch.cuda.memory_allocated()) + + def test_autocast_checkpointing(self): + model = torch.nn.Sequential( + torch.nn.Linear(8, 8), torch.nn.Linear(8, 8), torch.nn.Linear(8, 8) + ).cuda() + input = torch.rand( + (8, 8), device="cuda", dtype=torch.float16, requires_grad=True + ) + for reentrant in (True, False): + with torch.autocast("cuda"): + output = checkpoint_sequential(model, 2, input, use_reentrant=reentrant) + self.assertTrue(output.requires_grad) + self.assertTrue(output.dtype is torch.float16) + output.sum().backward() + + def test_cuda_autocast_deprecated_warning(self): + with self.assertWarnsRegex( + FutureWarning, + r"`torch.cuda.amp.autocast\(args...\)` is deprecated. Please use `torch.amp.autocast\('cuda', args...\)` instead.", + ): + with torch.cuda.amp.autocast(): + _ = torch.ones(10) + + instantiate_parametrized_tests(TestCuda) instantiate_parametrized_tests(TestCudaMallocAsync) instantiate_device_type_tests(TestCudaOptims, globals()) diff --git a/test/test_xpu.py b/test/test_xpu.py index 9dde7d8a71cf3f..e77a1e758f0c9f 100644 --- a/test/test_xpu.py +++ b/test/test_xpu.py @@ -1,6 +1,5 @@ # Owner(s): ["module: intel"] -import collections import subprocess import sys import tempfile @@ -8,7 +7,7 @@ import torch import torch.xpu._gpu_trace as gpu_trace -from torch.testing._internal.autocast_test_lists import AutocastTestLists +from torch.testing._internal.autocast_test_lists import AutocastTestLists, TestAutocast from torch.testing._internal.common_device_type import ( instantiate_device_type_tests, onlyXPU, @@ -371,7 +370,7 @@ def test_serialization_array_with_empty(self): instantiate_device_type_tests(TestXpu, globals(), only_for="xpu", allow_xpu=True) -class TestXpuAutocast(TestCase): +class TestXpuAutocast(TestAutocast): # These operators are not implemented on XPU backend and we can NOT fall back # them to CPU. So we have to skip them at this moment. # TODO: remove these operators from skip list when they are implemented on XPU backend. @@ -385,89 +384,6 @@ def tearDown(self): del self.autocast_lists super().tearDown() - def _run_autocast_outofplace( - self, op, args, run_as_type, out_type=None, module=torch, add_kwargs=None - ): - # helper to cast args - def cast(val, to_type): - if isinstance(val, torch.Tensor): - return val.to(to_type) if val.is_floating_point() else val - elif isinstance(val, collections.abc.Iterable): - return type(val)(cast(v, to_type) for v in val) - else: - return val - - if add_kwargs is None: - add_kwargs = {} - fast_dtype = torch.bfloat16 if run_as_type == torch.bfloat16 else torch.float16 - self.assertFalse(torch.is_autocast_enabled("xpu")) - with torch.amp.autocast("xpu", dtype=fast_dtype): - self.assertTrue(torch.is_autocast_enabled("xpu")) - - out_type = out_type if out_type is not None else run_as_type - output = output_method = None - - # Try module.* variant, if requested: - if module is not None and hasattr(module, op): - output = getattr(module, op)(*args, **add_kwargs) - if isinstance(output, torch.Tensor): - self.assertTrue( - out_type == output.dtype, - f"autocast for torch.{op} produced {output.dtype}, should produce {out_type}", - ) - - # Try Tensor.* variant: - if hasattr(torch.Tensor, op): - output_method = getattr(args[0], op)(*args[1:], **add_kwargs) - if isinstance(output_method, torch.Tensor): - self.assertTrue( - out_type == output_method.dtype, - f"autocast for torch.{op} produced {output_method.dtype}, should produce torch.{out_type}", - ) - - self.assertTrue( - (output is not None) or (output_method is not None), - f"{op} not found as an attribute on either Tensor or the requested module {module}", - ) - - # Accounts for ops that return Tensors, iterables, and other non-Tensors. - # For example, lstm_cell returns a tuple and equal returns bool. - def compare(first, second): - if isinstance(first, torch.Tensor): - return torch.equal(first, second) - elif isinstance(first, collections.abc.Iterable): - return all(compare(f, s) for f, s in zip(first, second)) - else: - return first == second - - # If both torch.* and Tensor.* variants were found, check outputs are identical - if (output is not None) and (output_method is not None): - self.assertTrue(type(output) == type(output_method)) - comparison = compare(output, output_method) - self.assertTrue( - comparison, f"torch.{op} result did not match Tensor.{op} result" - ) - - # Compare numerics to Python-side "autocasting" that (we expect) does the same thing - # as the C++-side autocasting, and should be bitwise accurate. - output_to_compare = output if output is not None else output_method - with torch.amp.autocast("xpu", enabled=False): - self.assertFalse(torch.is_autocast_enabled("xpu")) - - if module is not None and hasattr(module, op): - control = getattr(module, op)( - *cast(args, run_as_type), **add_kwargs - ) - else: - control = getattr(args[0].to(run_as_type), op)( - *cast(args[1:], run_as_type), **add_kwargs - ) - self.assertTrue(type(output_to_compare) == type(control)) - comparison = compare(output_to_compare, control) - self.assertTrue(comparison, f"torch.{op} result did not match control") - self.assertTrue(torch.is_autocast_enabled("xpu")) - self.assertFalse(torch.is_autocast_enabled("xpu")) - def test_autocast_torch_fp16(self): for op_with_args in self.autocast_lists.torch_fp16: skip_test = False @@ -477,7 +393,9 @@ def test_autocast_torch_fp16(self): if len(op_with_args) == 3: skip_test = True # skip cudnn op if not skip_test: - self._run_autocast_outofplace(op, args, torch.float16) + self._run_autocast_outofplace( + op, args, torch.float16, device="xpu", amp_dtype=torch.float16 + ) def test_autocast_torch_bf16(self): for op_with_args in self.autocast_lists.torch_fp16: @@ -488,15 +406,24 @@ def test_autocast_torch_bf16(self): if len(op_with_args) == 3: skip_test = True # skip cudnn op if not skip_test: - self._run_autocast_outofplace(op, args, torch.bfloat16) + self._run_autocast_outofplace(op, args, torch.bfloat16, device="xpu") def test_autocast_torch_need_autocast_promote(self): for op, args in self.autocast_lists.torch_need_autocast_promote: - self._run_autocast_outofplace(op, args, torch.float32) + self._run_autocast_outofplace( + op, args, torch.float32, device="xpu", amp_dtype=torch.float16 + ) def test_autocast_torch_expect_builtin_promote(self): for op, args, out_type in self.autocast_lists.torch_expect_builtin_promote: - self._run_autocast_outofplace(op, args, torch.float32, out_type=out_type) + self._run_autocast_outofplace( + op, + args, + torch.float32, + device="xpu", + out_type=out_type, + amp_dtype=torch.float16, + ) def test_autocast_checkpointing(self): model = torch.nn.Sequential( diff --git a/torch/testing/_internal/autocast_test_lists.py b/torch/testing/_internal/autocast_test_lists.py index 8527084f4afa82..c9789f19c3dcff 100644 --- a/torch/testing/_internal/autocast_test_lists.py +++ b/torch/testing/_internal/autocast_test_lists.py @@ -1,7 +1,10 @@ # mypy: ignore-errors +import collections + import torch from torch.testing._internal.common_utils import TEST_WITH_ROCM +from torch.testing._internal.common_utils import TestCase class AutocastTestLists: @@ -234,6 +237,7 @@ def __init__(self, dev): torch.rand((n, n), device=dev, dtype=torch.float32)), torch._C._nn), ] + class AutocastCPUTestLists: # Supplies ops and arguments for test_autocast_* in test/test_cpu.py def __init__(self, dev): @@ -368,3 +372,103 @@ def __init__(self, dev): ("cat", (pointwise0_bf16 + pointwise1_fp32,), (pointwise0_fp16 + pointwise1_fp32,)), ("stack", (pointwise0_bf16 + pointwise1_fp32,), (pointwise0_fp16 + pointwise1_fp32,)), ] + + +class TestAutocast(TestCase): + def args_maybe_kwargs(self, op_with_args): + if len(op_with_args) == 2: + return op_with_args[0], op_with_args[1], {} + else: + return op_with_args[0], op_with_args[1], op_with_args[2] + + def _run_autocast_outofplace( + self, + op, + args, + run_as_type, + device, + out_type=None, + module=torch, + add_kwargs=None, + amp_dtype=torch.bfloat16, + ): + # helper to cast args + def cast(val, to_type): + if isinstance(val, torch.Tensor): + return val.to(to_type) if val.is_floating_point() else val + elif isinstance(val, collections.abc.Iterable): + return type(val)(cast(v, to_type) for v in val) + else: + return val + + if add_kwargs is None: + add_kwargs = {} + + self.assertFalse(torch.is_autocast_enabled(device_type=device)) + with torch.amp.autocast(device_type=device, dtype=amp_dtype): + self.assertTrue(torch.is_autocast_enabled(device_type=device)) + + out_type = out_type if out_type is not None else run_as_type + output = output_method = None + + # Try module.* variant, if requested: + if module is not None and hasattr(module, op): + output = getattr(module, op)(*args, **add_kwargs) + if isinstance(output, torch.Tensor): + self.assertTrue( + out_type == output.dtype, + f"autocast for torch.{op} produced {output.dtype}, should produce {out_type}", + ) + # Try Tensor.* variant: + if hasattr(torch.Tensor, op): + output_method = getattr(args[0], op)(*args[1:], **add_kwargs) + if isinstance(output_method, torch.Tensor): + self.assertTrue( + out_type == output_method.dtype, + f"autocast for torch.{op} produced {output_method.dtype}, should produce torch.{out_type}", + ) + + self.assertTrue( + (output is not None) or (output_method is not None), + f"{op} not found as an attribute on either Tensor or the requested module {module}", + ) + + # Accounts for ops that return Tensors, iterables, and other non-Tensors. + # For example, lstm_cell returns a tuple and equal returns bool. + def compare(first, second): + if isinstance(first, torch.Tensor): + return torch.equal(first, second) + elif isinstance(first, collections.abc.Iterable): + return all(compare(f, s) for f, s in zip(first, second)) + else: + return first == second + + # If both torch.* and Tensor.* variants were found, check outputs are identical + if (output is not None) and (output_method is not None): + self.assertTrue(type(output) == type(output_method)) + comparison = compare(output, output_method) + self.assertTrue( + comparison, f"torch.{op} result did not match Tensor.{op} result" + ) + + # Compare numerics to Python-side "autocasting" that (we expect) does the same thing + # as the C++-side autocasting, and should be bitwise accurate. + output_to_compare = output if output is not None else output_method + with torch.amp.autocast(device_type=device, enabled=False): + self.assertFalse( + torch.is_autocast_enabled(device_type=device) + ) + + if module is not None and hasattr(module, op): + control = getattr(module, op)( + *cast(args, run_as_type), **add_kwargs + ) + else: + control = getattr(args[0].to(run_as_type), op)( + *cast(args[1:], run_as_type), **add_kwargs + ) + self.assertTrue(type(output_to_compare) == type(control)) + comparison = compare(output_to_compare, control) + self.assertTrue(comparison, f"torch.{op} result did not match control") + self.assertTrue(torch.is_autocast_enabled(device_type=device)) + self.assertFalse(torch.is_autocast_enabled(device_type=device)) From ff13a3affe38ac3cafafa2e6e1519f1a94252efc Mon Sep 17 00:00:00 2001 From: FFFrog Date: Wed, 4 Sep 2024 08:55:24 +0000 Subject: [PATCH 0395/1018] Remove specific lazy initialization of PrivateUse1 (#135002) As the title stated, lazy initialization of PrivateUse1 can been removed because maybe_initialize_device have supported PrivateUse1 Pull Request resolved: https://github.com/pytorch/pytorch/pull/135002 Approved by: https://github.com/albanD --- .../templates/python_variable_methods.cpp | 6 --- torch/csrc/Storage.cpp | 54 ++++++++++--------- 2 files changed, 28 insertions(+), 32 deletions(-) diff --git a/tools/autograd/templates/python_variable_methods.cpp b/tools/autograd/templates/python_variable_methods.cpp index fab57c25ac9c2d..16c3b9e5efd6a6 100644 --- a/tools/autograd/templates/python_variable_methods.cpp +++ b/tools/autograd/templates/python_variable_methods.cpp @@ -999,9 +999,6 @@ static PyObject * THPVariable_to(PyObject* self, PyObject* args, PyObject* kwarg auto opt_memory_format = std::get<4>(parsed); auto& self_ = THPVariable_Unpack(self); torch::utils::maybe_initialize_device(device); - if (device && device->is_privateuseone()) { - at::globalContext().lazyInitPrivateUse1(); - } if (!device && !scalarType && !copy && !opt_memory_format.has_value()) { Py_INCREF(self); return self; @@ -1081,9 +1078,6 @@ static PyObject * THPVariable_type(PyObject* self, PyObject* args, PyObject* kwa device = at::Device(device_type); } torch::utils::maybe_initialize_device(device); - if (device.is_privateuseone()) { - at::globalContext().lazyInitPrivateUse1(); - } return THPVariable_Wrap(dispatch_to(self_, device, scalar_type, /*non_blocking=*/ r.toBool(1), /*copy=*/ false, opt_memory_format)); END_HANDLE_TH_ERRORS } diff --git a/torch/csrc/Storage.cpp b/torch/csrc/Storage.cpp index 77520b6f1cdb1f..5e029fb0b6bd88 100644 --- a/torch/csrc/Storage.cpp +++ b/torch/csrc/Storage.cpp @@ -16,6 +16,7 @@ #include #include #include +#include #include #include @@ -334,37 +335,38 @@ static PyObject* THPStorage_pynew( allocator = reinterpret_cast(allocator_opt.value()); } else if (device_opt.has_value()) { at::Device device = device_opt.value(); - if (device.type() == at::kCPU) { - allocator = c10::GetDefaultCPUAllocator(); + torch::utils::maybe_initialize_device(device); + + switch (device.type()) { + case at::kCPU: + allocator = c10::GetDefaultCPUAllocator(); + break; #ifdef USE_CUDA - } else if (device.type() == at::kCUDA) { - at::globalContext().lazyInitCUDA(); - allocator = c10::cuda::CUDACachingAllocator::get(); + case at::kCUDA: + allocator = c10::cuda::CUDACachingAllocator::get(); + break; #endif #ifdef USE_MPS - } else if (device.type() == at::kMPS) { - allocator = at::mps::GetMPSAllocator(); + case at::kMPS: + allocator = at::mps::GetMPSAllocator(); + break; #endif - // NOLINTBEGIN(bugprone-branch-clone) - } else if (device.type() == at::DeviceType::XPU) { - allocator = c10::GetAllocator(device.type()); - } else if (device.type() == at::DeviceType::HPU) { - allocator = c10::GetAllocator(device.type()); - } else if (device.type() == at::DeviceType::Meta) { - allocator = c10::GetAllocator(device.type()); - } else if (device.type() == at::DeviceType::PrivateUse1) { - at::globalContext().lazyInitPrivateUse1(); - allocator = c10::GetAllocator(device.type()); - } else if (device.type() == at::DeviceType::MAIA) { - allocator = c10::GetAllocator(device.type()); - } else { - // NOLINTEND(bugprone-branch-clone) - TORCH_CHECK( - false, - THPStorageStr, - "(): Storage device not recognized: ", - device.type()); + case at::DeviceType::XPU: + case at::DeviceType::HPU: + case at::DeviceType::Meta: + case at::DeviceType::PrivateUse1: + case at::DeviceType::MAIA: + allocator = c10::GetAllocator(device.type()); + break; + default: + // NOLINTEND(bugprone-branch-clone) + TORCH_CHECK( + false, + THPStorageStr, + "(): Storage device not recognized: ", + device.type()); } + device_guard.reset_device(device); } else { allocator = c10::GetDefaultCPUAllocator(); From b252ed7a89c643783e6436c8eaad2ab141efa025 Mon Sep 17 00:00:00 2001 From: Thanh Ha Date: Wed, 4 Sep 2024 12:57:10 +0000 Subject: [PATCH 0396/1018] Fix stale job using non-existant ARC runner (#134863) The ARC CI system has been shutdown so this job is currently using a runner that doesn't exist. Fixes #ISSUE_NUMBER Pull Request resolved: https://github.com/pytorch/pytorch/pull/134863 Approved by: https://github.com/ZainRizvi --- .github/actionlint.yaml | 2 -- .github/workflows/stale.yml | 2 +- 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/.github/actionlint.yaml b/.github/actionlint.yaml index f28b57dcf560c3..bc83f0c32ee785 100644 --- a/.github/actionlint.yaml +++ b/.github/actionlint.yaml @@ -3,8 +3,6 @@ self-hosted-runner: # GitHub hosted x86 Linux runners - linux.20_04.4x - linux.20_04.16x - # Repo-specific LF hosted ARC runners - - linux.large.arc # Organization-wide AWS Linux Runners - linux.large - linux.2xlarge diff --git a/.github/workflows/stale.yml b/.github/workflows/stale.yml index 56e349dfa1b826..047e4a47ab97aa 100644 --- a/.github/workflows/stale.yml +++ b/.github/workflows/stale.yml @@ -21,7 +21,7 @@ on: jobs: stale: if: ${{ github.repository == 'pytorch/pytorch' }} - runs-on: linux.large.arc + runs-on: linux.large permissions: contents: read pull-requests: write From 544594bcc73f09ed696b3e9b9ce82df3bc5b57d7 Mon Sep 17 00:00:00 2001 From: FFFrog Date: Wed, 4 Sep 2024 07:49:35 +0000 Subject: [PATCH 0397/1018] C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED and C10_DIAGNOST should be used in pairs (#135004) Pull Request resolved: https://github.com/pytorch/pytorch/pull/135004 Approved by: https://github.com/aaronenyeshi --- aten/src/ATen/detail/AcceleratorHooksInterface.h | 2 +- aten/src/ATen/detail/MPSHooksInterface.h | 2 +- aten/src/ATen/detail/XPUHooksInterface.h | 2 +- aten/src/ATen/native/mkl/LinearAlgebra.cpp | 2 +- c10/core/SymNodeImpl.h | 3 +-- torch/csrc/profiler/stubs/itt.cpp | 10 +++------- 6 files changed, 8 insertions(+), 13 deletions(-) diff --git a/aten/src/ATen/detail/AcceleratorHooksInterface.h b/aten/src/ATen/detail/AcceleratorHooksInterface.h index 7eefdfc7269cb6..61409db3ac6807 100644 --- a/aten/src/ATen/detail/AcceleratorHooksInterface.h +++ b/aten/src/ATen/detail/AcceleratorHooksInterface.h @@ -53,4 +53,4 @@ struct TORCH_API AcceleratorHooksInterface { }; } // namespace at -C10_CLANG_DIAGNOSTIC_POP() +C10_DIAGNOSTIC_POP() diff --git a/aten/src/ATen/detail/MPSHooksInterface.h b/aten/src/ATen/detail/MPSHooksInterface.h index 0a9004d353424d..180ff68588edd7 100644 --- a/aten/src/ATen/detail/MPSHooksInterface.h +++ b/aten/src/ATen/detail/MPSHooksInterface.h @@ -114,4 +114,4 @@ TORCH_API const MPSHooksInterface& getMPSHooks(); } // namespace detail } // namespace at -C10_CLANG_DIAGNOSTIC_POP() +C10_DIAGNOSTIC_POP() diff --git a/aten/src/ATen/detail/XPUHooksInterface.h b/aten/src/ATen/detail/XPUHooksInterface.h index 1b7b7f99e46df2..f4cd9a34b5752b 100644 --- a/aten/src/ATen/detail/XPUHooksInterface.h +++ b/aten/src/ATen/detail/XPUHooksInterface.h @@ -81,4 +81,4 @@ namespace detail { TORCH_API const XPUHooksInterface& getXPUHooks(); } // namespace detail } // namespace at -C10_CLANG_DIAGNOSTIC_POP() +C10_DIAGNOSTIC_POP() diff --git a/aten/src/ATen/native/mkl/LinearAlgebra.cpp b/aten/src/ATen/native/mkl/LinearAlgebra.cpp index 582de5f8213a03..b64ccfeb03feae 100644 --- a/aten/src/ATen/native/mkl/LinearAlgebra.cpp +++ b/aten/src/ATen/native/mkl/LinearAlgebra.cpp @@ -154,4 +154,4 @@ void mkl_gemm_f16f16f32( }} // namespace at::native #endif -C10_CLANG_DIAGNOSTIC_POP() +C10_DIAGNOSTIC_POP() diff --git a/c10/core/SymNodeImpl.h b/c10/core/SymNodeImpl.h index 0847af2fce5340..a7ab26a24804fb 100644 --- a/c10/core/SymNodeImpl.h +++ b/c10/core/SymNodeImpl.h @@ -239,5 +239,4 @@ class C10_API SymNodeImpl : public c10::intrusive_ptr_target { }; } // namespace c10 - -C10_CLANG_DIAGNOSTIC_POP() +C10_DIAGNOSTIC_POP() diff --git a/torch/csrc/profiler/stubs/itt.cpp b/torch/csrc/profiler/stubs/itt.cpp index e131e19adc4f4f..89d7c32657c278 100644 --- a/torch/csrc/profiler/stubs/itt.cpp +++ b/torch/csrc/profiler/stubs/itt.cpp @@ -4,9 +4,7 @@ C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wunused-parameter") -namespace torch { -namespace profiler { -namespace impl { +namespace torch::profiler::impl { namespace { struct ITTMethods : public ProfilerStubs { @@ -51,7 +49,5 @@ struct RegisterITTMethods { RegisterITTMethods reg; } // namespace -} // namespace impl -} // namespace profiler -} // namespace torch -C10_CLANG_DIAGNOSTIC_POP() +} // namespace torch::profiler::impl +C10_DIAGNOSTIC_POP() From 76ec60ef5b58c9a090954ddaeb00226fba25899b Mon Sep 17 00:00:00 2001 From: "Edward Z. Yang" Date: Tue, 3 Sep 2024 18:56:31 -0700 Subject: [PATCH 0398/1018] Compute and do renamings even when ignoring fresh unbacked symbols (#134407) This is a bit twisty and I don't entirely understand the situation, but here's my best explanation. In https://github.com/pytorch/pytorch/pull/133588 I am trying to fix a problem reported by user in https://fb.workplace.com/groups/6829516587176185/permalink/7705964779531357/ The summary of this problem is that when we do collect metadata analysis in AOTAutograd, we accumulate pending unbacked symbols which are going to be discarded at the end of the trace. However, if we do a recursive make_fx inside tracing, as occurs with torch.cond, we end up seeing that there are pending unbacked symbols that aren't associated with a binding, even though it's spurious (they've leaked into the inner make_fx call from the outer AOTAutograd analysis). In #133588 I tried to just prevent adding the symbols to the pending list at all in the first place. But this itself caused some problems which were fixed in https://github.com/pytorch/pytorch/pull/124785 . The problem fixed in that PR is that when we allocate tangents that have unbacked size, something prevented them from having correct unbacked SymInts when ignore fresh unbacked SymInts was enabled. So I had patched it at the time by just not suppressing pending symbols and clearing them out some other way. I think... I was wrong in that PR? That is to say, it was OK to avoid putting the fresh unbacked symbols in the pending list; the real problem was suppressing unbacked renamings. But there doesn't seem to be a good reason to suppress these; this PR shows that it doesn't actually fail any tests if you do these anyway. Intuitively, this makes sense, because you can't trigger renamings unless you're actually adding unbacked symbols to the pending set. But I don't entirely understand all the interactions. I just know that this seems to not cause tests to fail, and it should fix the internal issue (which I need to add a UT for.) Signed-off-by: Edward Z. Yang Pull Request resolved: https://github.com/pytorch/pytorch/pull/134407 Approved by: https://github.com/ydwu4 --- torch/fx/experimental/symbolic_shapes.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/torch/fx/experimental/symbolic_shapes.py b/torch/fx/experimental/symbolic_shapes.py index bc6645c617a479..b50b55fa47302b 100644 --- a/torch/fx/experimental/symbolic_shapes.py +++ b/torch/fx/experimental/symbolic_shapes.py @@ -597,8 +597,6 @@ def compute_unbacked_bindings(shape_env, example_value, old_example_value=None, """ if shape_env is None: return - if shape_env._ignore_fresh_unbacked_symbols_tls(): - return fs = shape_env.pending_fresh_unbacked_symbols pending = set(fs) if pending: @@ -2722,8 +2720,6 @@ def _rename_unbacked_to(self, orig_s: sympy.Symbol, new_s: sympy.Symbol): assert isinstance(new_s, sympy.Symbol), new_s assert free_unbacked_symbols(new_s), new_s assert free_unbacked_symbols(orig_s), orig_s - if self._ignore_fresh_unbacked_symbols_tls(): - return dest = self.replacements.get(orig_s) assert not free_unbacked_symbols(dest), f"{orig_s} -> {dest}" self._set_replacement(orig_s, new_s, "rename_unbacked_to") From 3f3c83836f7b3cbb8aeb01c9b325faaedcbeb807 Mon Sep 17 00:00:00 2001 From: "Edward Z. Yang" Date: Tue, 3 Sep 2024 18:56:32 -0700 Subject: [PATCH 0399/1018] Ignore fresh unbacked when doing recursive make_fx inside HOPs (#135053) Internal xref: https://fb.workplace.com/groups/6829516587176185/posts/7705964779531357/ I'm not sure this is the right approach though... Signed-off-by: Edward Z. Yang Pull Request resolved: https://github.com/pytorch/pytorch/pull/135053 Approved by: https://github.com/ydwu4 ghstack dependencies: #134407 --- test/export/test_export.py | 21 +++++++++++++++++++ .../collect_metadata_analysis.py | 10 ++++++--- 2 files changed, 28 insertions(+), 3 deletions(-) diff --git a/test/export/test_export.py b/test/export/test_export.py index 887be93dbf75f4..0099eda3fa9a08 100644 --- a/test/export/test_export.py +++ b/test/export/test_export.py @@ -686,6 +686,27 @@ def false_fn(x): M()(torch.randn(7)) torch.export.export(M(), (torch.randn(7),)) + @torch._dynamo.config.patch(capture_scalar_outputs=True) + def test_cond_contains_unbacked_no_escape(self): + class M(torch.nn.Module): + def forward(self, a, b1, b2, c): + def true_fn(x): + return x * b1.item() + + def false_fn(x): + return x * b2.item() + + r = torch.cond(a, true_fn, false_fn, (c,)) + return r * 2 + + args = ( + torch.tensor(True), + torch.tensor([4]), + torch.tensor([4]), + torch.randn(10, requires_grad=True), + ) + torch.export.export(M(), args) + def test_state_tensors(self): class M(torch.nn.Module): # simple with register buffer def __init__(self) -> None: diff --git a/torch/_functorch/_aot_autograd/collect_metadata_analysis.py b/torch/_functorch/_aot_autograd/collect_metadata_analysis.py index 820c7d288cf2c6..51a0aeb24ad496 100644 --- a/torch/_functorch/_aot_autograd/collect_metadata_analysis.py +++ b/torch/_functorch/_aot_autograd/collect_metadata_analysis.py @@ -9,6 +9,7 @@ """ import collections +import contextlib import logging from functools import wraps from typing import Callable, DefaultDict, Dict, List, Optional @@ -162,15 +163,18 @@ def inner(*flat_args): # It doesn't matter if we run this under predispatch or not because it is # only for figuring out metadata mode = FunctionalTensorMode(_allow_token_discovery=True) - with disable_above, mode: + suppress_pending = contextlib.nullcontext() + fake_mode = detect_fake_mode() + if fake_mode and (shape_env := fake_mode.shape_env): + suppress_pending = shape_env.ignore_fresh_unbacked_symbols() + with disable_above, mode, suppress_pending: # precondition: The passed in function already handles unflattening inputs + flattening outputs flat_f_args = pytree.tree_map(_to_fun, flat_args) flat_f_outs = f(*flat_f_args) # We didn't do any tracing, so we don't need to process the # unbacked symbols, they will just disappear into the ether. # Also, prevent memoization from applying. - if (fake_mode := detect_fake_mode()) and (shape_env := fake_mode.shape_env): - shape_env.pending_fresh_unbacked_symbols.clear() + if fake_mode: fake_mode.epoch += 1 fake_mode.reset_nt_tensor_id_counter() From b8bb933f138ee6f1e4d153de1fc0516869e96138 Mon Sep 17 00:00:00 2001 From: "Edward Z. Yang" Date: Tue, 3 Sep 2024 13:48:02 -0700 Subject: [PATCH 0400/1018] Fix Invalid NaN comparison due to infinity-zero multiply on latest sympy (#135044) Fixes https://github.com/pytorch/pytorch/issues/133735 Signed-off-by: Edward Z. Yang Pull Request resolved: https://github.com/pytorch/pytorch/pull/135044 Approved by: https://github.com/zou3519 --- test/test_dynamic_shapes.py | 14 +++++++++++--- torch/utils/_sympy/value_ranges.py | 6 +++--- 2 files changed, 14 insertions(+), 6 deletions(-) diff --git a/test/test_dynamic_shapes.py b/test/test_dynamic_shapes.py index 89d8686219eaed..8b1c1c109ddba4 100644 --- a/test/test_dynamic_shapes.py +++ b/test/test_dynamic_shapes.py @@ -191,7 +191,7 @@ def create_symbolic_tensor(name, arg, shape_env, source=None, dynamic_dims=None) ) -def create_symtype(cls, pytype, shape_env, val, duck=True): +def create_symtype(cls, pytype, shape_env, val, duck=True, **kwargs): from torch._dynamo.source import ConstantSource symbol = shape_env.create_symbol( @@ -199,13 +199,14 @@ def create_symtype(cls, pytype, shape_env, val, duck=True): source=ConstantSource(f"__testing_only{len(shape_env.var_to_val)}"), dynamic_dim=DimDynamic.DUCK if duck else DimDynamic.DYNAMIC, constraint_dim=None, + **kwargs, ) return cls(SymNode(symbol, shape_env, pytype, hint=val)) # TODO: default duck to False -def create_symint(shape_env, i: int, duck=True) -> SymInt: - return create_symtype(SymInt, int, shape_env, i, duck=duck) +def create_symint(shape_env, i: int, duck=True, **kwargs) -> SymInt: + return create_symtype(SymInt, int, shape_env, i, duck=duck, **kwargs) def create_symbool(shape_env, b: bool) -> SymBool: @@ -779,6 +780,13 @@ def test_guard_refine_range(self): self.assertTrue(statically_known_true(i0 >= 4)) self.assertTrue(statically_known_true(i0 > 3)) + def test_mul_int_oo_nan(self): + shape_env = ShapeEnv() + s0 = create_symint(shape_env, 5, duck=False) + s1 = create_symint(shape_env, 6, duck=False) + s2 = create_symint(shape_env, 5, duck=False) + bool(s0 * (s1 // s0) == s2) + def test_non_overlapping_and_dense(self): shape_env = ShapeEnv() a0 = create_symint(shape_env, 5) diff --git a/torch/utils/_sympy/value_ranges.py b/torch/utils/_sympy/value_ranges.py index 57d509323e01f4..9080800979e67b 100644 --- a/torch/utils/_sympy/value_ranges.py +++ b/torch/utils/_sympy/value_ranges.py @@ -552,9 +552,9 @@ def mul(cls, a, b): def safe_mul(a, b): # Make unknown() * wrap(0.0) == wrap(0.0) - if a == 0.0: + if a == 0.0 or a == 0: return a - elif b == 0.0: + elif b == 0.0 or b == 0: return b else: return a * b @@ -590,7 +590,7 @@ def floordiv(a, b): a = ValueRanges.wrap(a) b = ValueRanges.wrap(b) if 0 in b: - return ValueRanges.unknown() + return ValueRanges.unknown_int() products = [] for x, y in itertools.product([a.lower, a.upper], [b.lower, b.upper]): r = FloorDiv(x, y) From b2090308ed29eb33c935681f0ba4f02bba93d848 Mon Sep 17 00:00:00 2001 From: chuanqiw Date: Wed, 4 Sep 2024 14:46:36 +0000 Subject: [PATCH 0401/1018] [CD] Enable XPU nightly build on Windows (#134312) Depends on https://github.com/pytorch/builder/pull/1975 land. Works for https://github.com/pytorch/pytorch/issues/114850 Pull Request resolved: https://github.com/pytorch/pytorch/pull/134312 Approved by: https://github.com/atalman --- .circleci/scripts/binary_windows_build.sh | 5 + .circleci/scripts/binary_windows_test.sh | 4 + .../scripts/generate_binary_build_matrix.py | 4 +- ...generated-windows-binary-wheel-nightly.yml | 1270 +++++++++++++++-- 4 files changed, 1136 insertions(+), 147 deletions(-) diff --git a/.circleci/scripts/binary_windows_build.sh b/.circleci/scripts/binary_windows_build.sh index 0f8eeb8cebe05d..d62ce7d77efe44 100644 --- a/.circleci/scripts/binary_windows_build.sh +++ b/.circleci/scripts/binary_windows_build.sh @@ -10,6 +10,11 @@ export SCCACHE_BUCKET=ossci-compiler-cache export SCCACHE_IGNORE_SERVER_IO_ERROR=1 export VC_YEAR=2019 +if [[ "$DESIRED_CUDA" == 'xpu' ]]; then + export VC_YEAR=2022 + export USE_SCCACHE=0 +fi + echo "Free space on filesystem before build:" df -h diff --git a/.circleci/scripts/binary_windows_test.sh b/.circleci/scripts/binary_windows_test.sh index bbf0efbb5e52ff..b9f801fb0c50f4 100644 --- a/.circleci/scripts/binary_windows_test.sh +++ b/.circleci/scripts/binary_windows_test.sh @@ -6,6 +6,10 @@ source "${BINARY_ENV_FILE:-/c/w/env}" export CUDA_VERSION="${DESIRED_CUDA/cu/}" export VC_YEAR=2019 +if [[ "$DESIRED_CUDA" == 'xpu' ]]; then + export VC_YEAR=2022 +fi + pushd "$BUILDER_ROOT" ./windows/internal/smoke_test.bat diff --git a/.github/scripts/generate_binary_build_matrix.py b/.github/scripts/generate_binary_build_matrix.py index 9dc8ebed3c09cd..6b33924a0e5546 100644 --- a/.github/scripts/generate_binary_build_matrix.py +++ b/.github/scripts/generate_binary_build_matrix.py @@ -340,7 +340,7 @@ def generate_wheels_matrix( if os == "linux": arches += CPU_CXX11_ABI_ARCH + CUDA_ARCHES + ROCM_ARCHES + XPU_ARCHES elif os == "windows": - arches += CUDA_ARCHES + arches += CUDA_ARCHES + XPU_ARCHES elif os == "linux-aarch64": # Only want the one arch as the CPU type is different and # uses different build/test scripts @@ -462,7 +462,7 @@ def generate_wheels_matrix( ), "pytorch_extra_install_requirements": ( PYTORCH_EXTRA_INSTALL_REQUIREMENTS["12.1"] # fmt: skip - if os != "linux" + if os != "linux" and gpu_arch_type != "xpu" else "" ), } diff --git a/.github/workflows/generated-windows-binary-wheel-nightly.yml b/.github/workflows/generated-windows-binary-wheel-nightly.yml index c042b16be7a893..9da2ad1cb75842 100644 --- a/.github/workflows/generated-windows-binary-wheel-nightly.yml +++ b/.github/workflows/generated-windows-binary-wheel-nightly.yml @@ -1033,6 +1033,251 @@ jobs: conda-pytorchbot-token: ${{ secrets.CONDA_PYTORCHBOT_TOKEN }} conda-pytorchbot-token-test: ${{ secrets.CONDA_PYTORCHBOT_TOKEN_TEST }} uses: ./.github/workflows/_binary-upload.yml + wheel-py3_9-xpu-build: + if: ${{ github.repository_owner == 'pytorch' }} + needs: get-label-type + runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" + timeout-minutes: 240 + env: + PYTORCH_ROOT: ${{ github.workspace }}/pytorch + BUILDER_ROOT: ${{ github.workspace }}/builder + PACKAGE_TYPE: wheel + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: xpu + GPU_ARCH_TYPE: xpu + SKIP_ALL_TESTS: 1 + DESIRED_PYTHON: "3.9" + steps: + - name: Display EC2 information + shell: bash + run: | + set -euo pipefail + function get_ec2_metadata() { + # Pulled from instance metadata endpoint for EC2 + # see https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/instancedata-data-retrieval.html + category=$1 + curl -fsSL "http://169.254.169.254/latest/meta-data/${category}" + } + echo "ami-id: $(get_ec2_metadata ami-id)" + echo "instance-id: $(get_ec2_metadata instance-id)" + echo "instance-type: $(get_ec2_metadata instance-type)" + echo "system info $(uname -a)" + - name: "[FB EMPLOYEES] Enable SSH (Click me for login details)" + uses: pytorch/test-infra/.github/actions/setup-ssh@main + continue-on-error: true + with: + github-secret: ${{ secrets.GITHUB_TOKEN }} + # Needed for binary builds, see: https://github.com/pytorch/pytorch/issues/73339#issuecomment-1058981560 + - name: Enable long paths on Windows + shell: powershell + run: | + Set-ItemProperty -Path "HKLM:\\SYSTEM\CurrentControlSet\Control\FileSystem" -Name "LongPathsEnabled" -Value 1 + # Since it's just a defensive command, the workflow should continue even the command fails. This step can be + # removed once Windows Defender is removed from the AMI + - name: Disables Windows Defender scheduled and real-time scanning for files in directories used by PyTorch + continue-on-error: true + shell: powershell + run: | + Add-MpPreference -ExclusionPath $(Get-Location).tostring(),$Env:TEMP -ErrorAction Ignore + # Let's both exclude the path and disable Windows Defender completely just to be sure + # that it doesn't interfere + Set-MpPreference -DisableRealtimeMonitoring $True -ErrorAction Ignore + # NOTE: These environment variables are put here so that they can be applied on every job equally + # They are also here because setting them at a workflow level doesn't give us access to the + # runner.temp variable, which we need. + - name: Populate binary env + shell: bash + run: | + echo "BINARY_ENV_FILE=${RUNNER_TEMP}/env" >> "${GITHUB_ENV}" + echo "PYTORCH_FINAL_PACKAGE_DIR=${RUNNER_TEMP}/artifacts" >> "${GITHUB_ENV}" + echo "WIN_PACKAGE_WORK_DIR=${RUNNER_TEMP}" + - name: Checkout PyTorch + uses: malfet/checkout@silent-checkout + with: + ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} + submodules: recursive + path: pytorch + quiet-checkout: true + - name: Clean PyTorch checkout + run: | + # Remove any artifacts from the previous checkouts + git clean -fxd + working-directory: pytorch + - name: Checkout pytorch/builder + uses: malfet/checkout@silent-checkout + with: + ref: main + submodules: recursive + repository: pytorch/builder + path: builder + quiet-checkout: true + - name: Clean pytorch/builder checkout + run: | + # Remove any artifacts from the previous checkouts + git clean -fxd + working-directory: builder + - name: Populate binary env + shell: bash + run: | + "${PYTORCH_ROOT}/.circleci/scripts/binary_populate_env.sh" + - name: Build PyTorch binary + shell: bash + run: | + "${PYTORCH_ROOT}/.circleci/scripts/binary_windows_build.sh" + - uses: actions/upload-artifact@v3 + if: always() + with: + name: wheel-py3_9-xpu + retention-days: 14 + if-no-files-found: error + path: "${{ env.PYTORCH_FINAL_PACKAGE_DIR }}" + - name: Wait until all sessions have drained + shell: powershell + working-directory: pytorch + if: always() + timeout-minutes: 120 + run: | + .github\scripts\wait_for_ssh_to_drain.ps1 + - name: Kill active ssh sessions if still around (Useful if workflow was cancelled) + shell: powershell + working-directory: pytorch + if: always() + run: | + .github\scripts\kill_active_ssh_sessions.ps1 + wheel-py3_9-xpu-test: # Testing + if: ${{ github.repository_owner == 'pytorch' }} + needs: + - wheel-py3_9-xpu-build + - get-label-type + runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge.nonephemeral" + timeout-minutes: 240 + env: + PYTORCH_ROOT: ${{ github.workspace }}/pytorch + BUILDER_ROOT: ${{ github.workspace }}/builder + PACKAGE_TYPE: wheel + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: xpu + GPU_ARCH_TYPE: xpu + SKIP_ALL_TESTS: 1 + DESIRED_PYTHON: "3.9" + steps: + - name: Display EC2 information + shell: bash + run: | + set -euo pipefail + function get_ec2_metadata() { + # Pulled from instance metadata endpoint for EC2 + # see https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/instancedata-data-retrieval.html + category=$1 + curl -fsSL "http://169.254.169.254/latest/meta-data/${category}" + } + echo "ami-id: $(get_ec2_metadata ami-id)" + echo "instance-id: $(get_ec2_metadata instance-id)" + echo "instance-type: $(get_ec2_metadata instance-type)" + echo "system info $(uname -a)" + - name: "[FB EMPLOYEES] Enable SSH (Click me for login details)" + uses: pytorch/test-infra/.github/actions/setup-ssh@main + continue-on-error: true + with: + github-secret: ${{ secrets.GITHUB_TOKEN }} + # Needed for binary builds, see: https://github.com/pytorch/pytorch/issues/73339#issuecomment-1058981560 + - name: Enable long paths on Windows + shell: powershell + run: | + Set-ItemProperty -Path "HKLM:\\SYSTEM\CurrentControlSet\Control\FileSystem" -Name "LongPathsEnabled" -Value 1 + # Since it's just a defensive command, the workflow should continue even the command fails. This step can be + # removed once Windows Defender is removed from the AMI + - name: Disables Windows Defender scheduled and real-time scanning for files in directories used by PyTorch + continue-on-error: true + shell: powershell + run: | + Add-MpPreference -ExclusionPath $(Get-Location).tostring(),$Env:TEMP -ErrorAction Ignore + # Let's both exclude the path and disable Windows Defender completely just to be sure + # that it doesn't interfere + Set-MpPreference -DisableRealtimeMonitoring $True -ErrorAction Ignore + # NOTE: These environment variables are put here so that they can be applied on every job equally + # They are also here because setting them at a workflow level doesn't give us access to the + # runner.temp variable, which we need. + - name: Populate binary env + shell: bash + run: | + echo "BINARY_ENV_FILE=${RUNNER_TEMP}/env" >> "${GITHUB_ENV}" + echo "PYTORCH_FINAL_PACKAGE_DIR=${RUNNER_TEMP}/artifacts" >> "${GITHUB_ENV}" + echo "WIN_PACKAGE_WORK_DIR=${RUNNER_TEMP}" + - uses: actions/download-artifact@v3 + name: Download Build Artifacts + with: + name: wheel-py3_9-xpu + path: "${{ env.PYTORCH_FINAL_PACKAGE_DIR }}" + - name: Checkout PyTorch + uses: malfet/checkout@silent-checkout + with: + ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} + submodules: recursive + path: pytorch + quiet-checkout: true + - name: Clean PyTorch checkout + run: | + # Remove any artifacts from the previous checkouts + git clean -fxd + working-directory: pytorch + - name: Checkout pytorch/builder + uses: malfet/checkout@silent-checkout + with: + ref: main + submodules: recursive + repository: pytorch/builder + path: builder + quiet-checkout: true + - name: Clean pytorch/builder checkout + run: | + # Remove any artifacts from the previous checkouts + git clean -fxd + working-directory: builder + - name: Populate binary env + shell: bash + run: | + "${PYTORCH_ROOT}/.circleci/scripts/binary_populate_env.sh" + - name: Test PyTorch binary + shell: bash + run: | + "${PYTORCH_ROOT}/.circleci/scripts/binary_windows_test.sh" + - name: Wait until all sessions have drained + shell: powershell + working-directory: pytorch + if: always() + timeout-minutes: 120 + run: | + .github\scripts\wait_for_ssh_to_drain.ps1 + - name: Kill active ssh sessions if still around (Useful if workflow was cancelled) + shell: powershell + working-directory: pytorch + if: always() + run: | + .github\scripts\kill_active_ssh_sessions.ps1 + wheel-py3_9-xpu-upload: # Uploading + if: ${{ github.repository_owner == 'pytorch' }} + permissions: + id-token: write + contents: read + needs: wheel-py3_9-xpu-test + with: + PYTORCH_ROOT: ${{ github.workspace }}/pytorch + BUILDER_ROOT: ${{ github.workspace }}/builder + PACKAGE_TYPE: wheel + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: xpu + GPU_ARCH_TYPE: xpu + DESIRED_PYTHON: "3.9" + build_name: wheel-py3_9-xpu + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} + conda-pytorchbot-token: ${{ secrets.CONDA_PYTORCHBOT_TOKEN }} + conda-pytorchbot-token-test: ${{ secrets.CONDA_PYTORCHBOT_TOKEN_TEST }} + uses: ./.github/workflows/_binary-upload.yml wheel-py3_10-cpu-build: if: ${{ github.repository_owner == 'pytorch' }} needs: get-label-type @@ -1891,10 +2136,750 @@ jobs: if: always() run: | .github\scripts\kill_active_ssh_sessions.ps1 - wheel-py3_10-cuda12_4-test: # Testing + wheel-py3_10-cuda12_4-test: # Testing + if: ${{ github.repository_owner == 'pytorch' }} + needs: + - wheel-py3_10-cuda12_4-build + - get-label-type + runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.g4dn.xlarge" + timeout-minutes: 240 + env: + PYTORCH_ROOT: ${{ github.workspace }}/pytorch + BUILDER_ROOT: ${{ github.workspace }}/builder + PACKAGE_TYPE: wheel + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: cu124 + GPU_ARCH_VERSION: 12.4 + GPU_ARCH_TYPE: cuda + SKIP_ALL_TESTS: 1 + DESIRED_PYTHON: "3.10" + steps: + - name: Display EC2 information + shell: bash + run: | + set -euo pipefail + function get_ec2_metadata() { + # Pulled from instance metadata endpoint for EC2 + # see https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/instancedata-data-retrieval.html + category=$1 + curl -fsSL "http://169.254.169.254/latest/meta-data/${category}" + } + echo "ami-id: $(get_ec2_metadata ami-id)" + echo "instance-id: $(get_ec2_metadata instance-id)" + echo "instance-type: $(get_ec2_metadata instance-type)" + echo "system info $(uname -a)" + - name: "[FB EMPLOYEES] Enable SSH (Click me for login details)" + uses: pytorch/test-infra/.github/actions/setup-ssh@main + continue-on-error: true + with: + github-secret: ${{ secrets.GITHUB_TOKEN }} + # Needed for binary builds, see: https://github.com/pytorch/pytorch/issues/73339#issuecomment-1058981560 + - name: Enable long paths on Windows + shell: powershell + run: | + Set-ItemProperty -Path "HKLM:\\SYSTEM\CurrentControlSet\Control\FileSystem" -Name "LongPathsEnabled" -Value 1 + # Since it's just a defensive command, the workflow should continue even the command fails. This step can be + # removed once Windows Defender is removed from the AMI + - name: Disables Windows Defender scheduled and real-time scanning for files in directories used by PyTorch + continue-on-error: true + shell: powershell + run: | + Add-MpPreference -ExclusionPath $(Get-Location).tostring(),$Env:TEMP -ErrorAction Ignore + # Let's both exclude the path and disable Windows Defender completely just to be sure + # that it doesn't interfere + Set-MpPreference -DisableRealtimeMonitoring $True -ErrorAction Ignore + # NOTE: These environment variables are put here so that they can be applied on every job equally + # They are also here because setting them at a workflow level doesn't give us access to the + # runner.temp variable, which we need. + - name: Populate binary env + shell: bash + run: | + echo "BINARY_ENV_FILE=${RUNNER_TEMP}/env" >> "${GITHUB_ENV}" + echo "PYTORCH_FINAL_PACKAGE_DIR=${RUNNER_TEMP}/artifacts" >> "${GITHUB_ENV}" + echo "WIN_PACKAGE_WORK_DIR=${RUNNER_TEMP}" + - uses: actions/download-artifact@v4.1.7 + name: Download Build Artifacts + with: + name: wheel-py3_10-cuda12_4 + path: "${{ env.PYTORCH_FINAL_PACKAGE_DIR }}" + - name: Checkout PyTorch + uses: malfet/checkout@silent-checkout + with: + ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} + submodules: recursive + path: pytorch + quiet-checkout: true + - name: Clean PyTorch checkout + run: | + # Remove any artifacts from the previous checkouts + git clean -fxd + working-directory: pytorch + - name: Checkout pytorch/builder + uses: malfet/checkout@silent-checkout + with: + ref: main + submodules: recursive + repository: pytorch/builder + path: builder + quiet-checkout: true + - name: Clean pytorch/builder checkout + run: | + # Remove any artifacts from the previous checkouts + git clean -fxd + working-directory: builder + - name: Populate binary env + shell: bash + run: | + "${PYTORCH_ROOT}/.circleci/scripts/binary_populate_env.sh" + - name: Test PyTorch binary + shell: bash + run: | + "${PYTORCH_ROOT}/.circleci/scripts/binary_windows_test.sh" + - name: Wait until all sessions have drained + shell: powershell + working-directory: pytorch + if: always() + timeout-minutes: 120 + run: | + .github\scripts\wait_for_ssh_to_drain.ps1 + - name: Kill active ssh sessions if still around (Useful if workflow was cancelled) + shell: powershell + working-directory: pytorch + if: always() + run: | + .github\scripts\kill_active_ssh_sessions.ps1 + wheel-py3_10-cuda12_4-upload: # Uploading + if: ${{ github.repository_owner == 'pytorch' }} + permissions: + id-token: write + contents: read + needs: wheel-py3_10-cuda12_4-test + with: + PYTORCH_ROOT: ${{ github.workspace }}/pytorch + BUILDER_ROOT: ${{ github.workspace }}/builder + PACKAGE_TYPE: wheel + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: cu124 + GPU_ARCH_VERSION: 12.4 + GPU_ARCH_TYPE: cuda + DESIRED_PYTHON: "3.10" + build_name: wheel-py3_10-cuda12_4 + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} + conda-pytorchbot-token: ${{ secrets.CONDA_PYTORCHBOT_TOKEN }} + conda-pytorchbot-token-test: ${{ secrets.CONDA_PYTORCHBOT_TOKEN_TEST }} + uses: ./.github/workflows/_binary-upload.yml + wheel-py3_10-xpu-build: + if: ${{ github.repository_owner == 'pytorch' }} + needs: get-label-type + runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" + timeout-minutes: 240 + env: + PYTORCH_ROOT: ${{ github.workspace }}/pytorch + BUILDER_ROOT: ${{ github.workspace }}/builder + PACKAGE_TYPE: wheel + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: xpu + GPU_ARCH_TYPE: xpu + SKIP_ALL_TESTS: 1 + DESIRED_PYTHON: "3.10" + steps: + - name: Display EC2 information + shell: bash + run: | + set -euo pipefail + function get_ec2_metadata() { + # Pulled from instance metadata endpoint for EC2 + # see https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/instancedata-data-retrieval.html + category=$1 + curl -fsSL "http://169.254.169.254/latest/meta-data/${category}" + } + echo "ami-id: $(get_ec2_metadata ami-id)" + echo "instance-id: $(get_ec2_metadata instance-id)" + echo "instance-type: $(get_ec2_metadata instance-type)" + echo "system info $(uname -a)" + - name: "[FB EMPLOYEES] Enable SSH (Click me for login details)" + uses: pytorch/test-infra/.github/actions/setup-ssh@main + continue-on-error: true + with: + github-secret: ${{ secrets.GITHUB_TOKEN }} + # Needed for binary builds, see: https://github.com/pytorch/pytorch/issues/73339#issuecomment-1058981560 + - name: Enable long paths on Windows + shell: powershell + run: | + Set-ItemProperty -Path "HKLM:\\SYSTEM\CurrentControlSet\Control\FileSystem" -Name "LongPathsEnabled" -Value 1 + # Since it's just a defensive command, the workflow should continue even the command fails. This step can be + # removed once Windows Defender is removed from the AMI + - name: Disables Windows Defender scheduled and real-time scanning for files in directories used by PyTorch + continue-on-error: true + shell: powershell + run: | + Add-MpPreference -ExclusionPath $(Get-Location).tostring(),$Env:TEMP -ErrorAction Ignore + # Let's both exclude the path and disable Windows Defender completely just to be sure + # that it doesn't interfere + Set-MpPreference -DisableRealtimeMonitoring $True -ErrorAction Ignore + # NOTE: These environment variables are put here so that they can be applied on every job equally + # They are also here because setting them at a workflow level doesn't give us access to the + # runner.temp variable, which we need. + - name: Populate binary env + shell: bash + run: | + echo "BINARY_ENV_FILE=${RUNNER_TEMP}/env" >> "${GITHUB_ENV}" + echo "PYTORCH_FINAL_PACKAGE_DIR=${RUNNER_TEMP}/artifacts" >> "${GITHUB_ENV}" + echo "WIN_PACKAGE_WORK_DIR=${RUNNER_TEMP}" + - name: Checkout PyTorch + uses: malfet/checkout@silent-checkout + with: + ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} + submodules: recursive + path: pytorch + quiet-checkout: true + - name: Clean PyTorch checkout + run: | + # Remove any artifacts from the previous checkouts + git clean -fxd + working-directory: pytorch + - name: Checkout pytorch/builder + uses: malfet/checkout@silent-checkout + with: + ref: main + submodules: recursive + repository: pytorch/builder + path: builder + quiet-checkout: true + - name: Clean pytorch/builder checkout + run: | + # Remove any artifacts from the previous checkouts + git clean -fxd + working-directory: builder + - name: Populate binary env + shell: bash + run: | + "${PYTORCH_ROOT}/.circleci/scripts/binary_populate_env.sh" + - name: Build PyTorch binary + shell: bash + run: | + "${PYTORCH_ROOT}/.circleci/scripts/binary_windows_build.sh" + - uses: actions/upload-artifact@v3 + if: always() + with: + name: wheel-py3_10-xpu + retention-days: 14 + if-no-files-found: error + path: "${{ env.PYTORCH_FINAL_PACKAGE_DIR }}" + - name: Wait until all sessions have drained + shell: powershell + working-directory: pytorch + if: always() + timeout-minutes: 120 + run: | + .github\scripts\wait_for_ssh_to_drain.ps1 + - name: Kill active ssh sessions if still around (Useful if workflow was cancelled) + shell: powershell + working-directory: pytorch + if: always() + run: | + .github\scripts\kill_active_ssh_sessions.ps1 + wheel-py3_10-xpu-test: # Testing + if: ${{ github.repository_owner == 'pytorch' }} + needs: + - wheel-py3_10-xpu-build + - get-label-type + runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge.nonephemeral" + timeout-minutes: 240 + env: + PYTORCH_ROOT: ${{ github.workspace }}/pytorch + BUILDER_ROOT: ${{ github.workspace }}/builder + PACKAGE_TYPE: wheel + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: xpu + GPU_ARCH_TYPE: xpu + SKIP_ALL_TESTS: 1 + DESIRED_PYTHON: "3.10" + steps: + - name: Display EC2 information + shell: bash + run: | + set -euo pipefail + function get_ec2_metadata() { + # Pulled from instance metadata endpoint for EC2 + # see https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/instancedata-data-retrieval.html + category=$1 + curl -fsSL "http://169.254.169.254/latest/meta-data/${category}" + } + echo "ami-id: $(get_ec2_metadata ami-id)" + echo "instance-id: $(get_ec2_metadata instance-id)" + echo "instance-type: $(get_ec2_metadata instance-type)" + echo "system info $(uname -a)" + - name: "[FB EMPLOYEES] Enable SSH (Click me for login details)" + uses: pytorch/test-infra/.github/actions/setup-ssh@main + continue-on-error: true + with: + github-secret: ${{ secrets.GITHUB_TOKEN }} + # Needed for binary builds, see: https://github.com/pytorch/pytorch/issues/73339#issuecomment-1058981560 + - name: Enable long paths on Windows + shell: powershell + run: | + Set-ItemProperty -Path "HKLM:\\SYSTEM\CurrentControlSet\Control\FileSystem" -Name "LongPathsEnabled" -Value 1 + # Since it's just a defensive command, the workflow should continue even the command fails. This step can be + # removed once Windows Defender is removed from the AMI + - name: Disables Windows Defender scheduled and real-time scanning for files in directories used by PyTorch + continue-on-error: true + shell: powershell + run: | + Add-MpPreference -ExclusionPath $(Get-Location).tostring(),$Env:TEMP -ErrorAction Ignore + # Let's both exclude the path and disable Windows Defender completely just to be sure + # that it doesn't interfere + Set-MpPreference -DisableRealtimeMonitoring $True -ErrorAction Ignore + # NOTE: These environment variables are put here so that they can be applied on every job equally + # They are also here because setting them at a workflow level doesn't give us access to the + # runner.temp variable, which we need. + - name: Populate binary env + shell: bash + run: | + echo "BINARY_ENV_FILE=${RUNNER_TEMP}/env" >> "${GITHUB_ENV}" + echo "PYTORCH_FINAL_PACKAGE_DIR=${RUNNER_TEMP}/artifacts" >> "${GITHUB_ENV}" + echo "WIN_PACKAGE_WORK_DIR=${RUNNER_TEMP}" + - uses: actions/download-artifact@v3 + name: Download Build Artifacts + with: + name: wheel-py3_10-xpu + path: "${{ env.PYTORCH_FINAL_PACKAGE_DIR }}" + - name: Checkout PyTorch + uses: malfet/checkout@silent-checkout + with: + ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} + submodules: recursive + path: pytorch + quiet-checkout: true + - name: Clean PyTorch checkout + run: | + # Remove any artifacts from the previous checkouts + git clean -fxd + working-directory: pytorch + - name: Checkout pytorch/builder + uses: malfet/checkout@silent-checkout + with: + ref: main + submodules: recursive + repository: pytorch/builder + path: builder + quiet-checkout: true + - name: Clean pytorch/builder checkout + run: | + # Remove any artifacts from the previous checkouts + git clean -fxd + working-directory: builder + - name: Populate binary env + shell: bash + run: | + "${PYTORCH_ROOT}/.circleci/scripts/binary_populate_env.sh" + - name: Test PyTorch binary + shell: bash + run: | + "${PYTORCH_ROOT}/.circleci/scripts/binary_windows_test.sh" + - name: Wait until all sessions have drained + shell: powershell + working-directory: pytorch + if: always() + timeout-minutes: 120 + run: | + .github\scripts\wait_for_ssh_to_drain.ps1 + - name: Kill active ssh sessions if still around (Useful if workflow was cancelled) + shell: powershell + working-directory: pytorch + if: always() + run: | + .github\scripts\kill_active_ssh_sessions.ps1 + wheel-py3_10-xpu-upload: # Uploading + if: ${{ github.repository_owner == 'pytorch' }} + permissions: + id-token: write + contents: read + needs: wheel-py3_10-xpu-test + with: + PYTORCH_ROOT: ${{ github.workspace }}/pytorch + BUILDER_ROOT: ${{ github.workspace }}/builder + PACKAGE_TYPE: wheel + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: xpu + GPU_ARCH_TYPE: xpu + DESIRED_PYTHON: "3.10" + build_name: wheel-py3_10-xpu + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} + conda-pytorchbot-token: ${{ secrets.CONDA_PYTORCHBOT_TOKEN }} + conda-pytorchbot-token-test: ${{ secrets.CONDA_PYTORCHBOT_TOKEN_TEST }} + uses: ./.github/workflows/_binary-upload.yml + wheel-py3_11-cpu-build: + if: ${{ github.repository_owner == 'pytorch' }} + needs: get-label-type + runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" + timeout-minutes: 240 + env: + PYTORCH_ROOT: ${{ github.workspace }}/pytorch + BUILDER_ROOT: ${{ github.workspace }}/builder + PACKAGE_TYPE: wheel + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: cpu + GPU_ARCH_TYPE: cpu + SKIP_ALL_TESTS: 1 + DESIRED_PYTHON: "3.11" + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.21.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' + steps: + - name: Display EC2 information + shell: bash + run: | + set -euo pipefail + function get_ec2_metadata() { + # Pulled from instance metadata endpoint for EC2 + # see https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/instancedata-data-retrieval.html + category=$1 + curl -fsSL "http://169.254.169.254/latest/meta-data/${category}" + } + echo "ami-id: $(get_ec2_metadata ami-id)" + echo "instance-id: $(get_ec2_metadata instance-id)" + echo "instance-type: $(get_ec2_metadata instance-type)" + echo "system info $(uname -a)" + - name: "[FB EMPLOYEES] Enable SSH (Click me for login details)" + uses: pytorch/test-infra/.github/actions/setup-ssh@main + continue-on-error: true + with: + github-secret: ${{ secrets.GITHUB_TOKEN }} + # Needed for binary builds, see: https://github.com/pytorch/pytorch/issues/73339#issuecomment-1058981560 + - name: Enable long paths on Windows + shell: powershell + run: | + Set-ItemProperty -Path "HKLM:\\SYSTEM\CurrentControlSet\Control\FileSystem" -Name "LongPathsEnabled" -Value 1 + # Since it's just a defensive command, the workflow should continue even the command fails. This step can be + # removed once Windows Defender is removed from the AMI + - name: Disables Windows Defender scheduled and real-time scanning for files in directories used by PyTorch + continue-on-error: true + shell: powershell + run: | + Add-MpPreference -ExclusionPath $(Get-Location).tostring(),$Env:TEMP -ErrorAction Ignore + # Let's both exclude the path and disable Windows Defender completely just to be sure + # that it doesn't interfere + Set-MpPreference -DisableRealtimeMonitoring $True -ErrorAction Ignore + # NOTE: These environment variables are put here so that they can be applied on every job equally + # They are also here because setting them at a workflow level doesn't give us access to the + # runner.temp variable, which we need. + - name: Populate binary env + shell: bash + run: | + echo "BINARY_ENV_FILE=${RUNNER_TEMP}/env" >> "${GITHUB_ENV}" + echo "PYTORCH_FINAL_PACKAGE_DIR=${RUNNER_TEMP}/artifacts" >> "${GITHUB_ENV}" + echo "WIN_PACKAGE_WORK_DIR=${RUNNER_TEMP}" + - name: Checkout PyTorch + uses: malfet/checkout@silent-checkout + with: + ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} + submodules: recursive + path: pytorch + quiet-checkout: true + - name: Clean PyTorch checkout + run: | + # Remove any artifacts from the previous checkouts + git clean -fxd + working-directory: pytorch + - name: Checkout pytorch/builder + uses: malfet/checkout@silent-checkout + with: + ref: main + submodules: recursive + repository: pytorch/builder + path: builder + quiet-checkout: true + - name: Clean pytorch/builder checkout + run: | + # Remove any artifacts from the previous checkouts + git clean -fxd + working-directory: builder + - name: Populate binary env + shell: bash + run: | + "${PYTORCH_ROOT}/.circleci/scripts/binary_populate_env.sh" + - name: Build PyTorch binary + shell: bash + run: | + "${PYTORCH_ROOT}/.circleci/scripts/binary_windows_build.sh" + - uses: actions/upload-artifact@v3 + if: always() + with: + name: wheel-py3_11-cpu + retention-days: 14 + if-no-files-found: error + path: "${{ env.PYTORCH_FINAL_PACKAGE_DIR }}" + - name: Wait until all sessions have drained + shell: powershell + working-directory: pytorch + if: always() + timeout-minutes: 120 + run: | + .github\scripts\wait_for_ssh_to_drain.ps1 + - name: Kill active ssh sessions if still around (Useful if workflow was cancelled) + shell: powershell + working-directory: pytorch + if: always() + run: | + .github\scripts\kill_active_ssh_sessions.ps1 + wheel-py3_11-cpu-test: # Testing + if: ${{ github.repository_owner == 'pytorch' }} + needs: + - wheel-py3_11-cpu-build + - get-label-type + runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge.nonephemeral" + timeout-minutes: 240 + env: + PYTORCH_ROOT: ${{ github.workspace }}/pytorch + BUILDER_ROOT: ${{ github.workspace }}/builder + PACKAGE_TYPE: wheel + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: cpu + GPU_ARCH_TYPE: cpu + SKIP_ALL_TESTS: 1 + DESIRED_PYTHON: "3.11" + steps: + - name: Display EC2 information + shell: bash + run: | + set -euo pipefail + function get_ec2_metadata() { + # Pulled from instance metadata endpoint for EC2 + # see https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/instancedata-data-retrieval.html + category=$1 + curl -fsSL "http://169.254.169.254/latest/meta-data/${category}" + } + echo "ami-id: $(get_ec2_metadata ami-id)" + echo "instance-id: $(get_ec2_metadata instance-id)" + echo "instance-type: $(get_ec2_metadata instance-type)" + echo "system info $(uname -a)" + - name: "[FB EMPLOYEES] Enable SSH (Click me for login details)" + uses: pytorch/test-infra/.github/actions/setup-ssh@main + continue-on-error: true + with: + github-secret: ${{ secrets.GITHUB_TOKEN }} + # Needed for binary builds, see: https://github.com/pytorch/pytorch/issues/73339#issuecomment-1058981560 + - name: Enable long paths on Windows + shell: powershell + run: | + Set-ItemProperty -Path "HKLM:\\SYSTEM\CurrentControlSet\Control\FileSystem" -Name "LongPathsEnabled" -Value 1 + # Since it's just a defensive command, the workflow should continue even the command fails. This step can be + # removed once Windows Defender is removed from the AMI + - name: Disables Windows Defender scheduled and real-time scanning for files in directories used by PyTorch + continue-on-error: true + shell: powershell + run: | + Add-MpPreference -ExclusionPath $(Get-Location).tostring(),$Env:TEMP -ErrorAction Ignore + # Let's both exclude the path and disable Windows Defender completely just to be sure + # that it doesn't interfere + Set-MpPreference -DisableRealtimeMonitoring $True -ErrorAction Ignore + # NOTE: These environment variables are put here so that they can be applied on every job equally + # They are also here because setting them at a workflow level doesn't give us access to the + # runner.temp variable, which we need. + - name: Populate binary env + shell: bash + run: | + echo "BINARY_ENV_FILE=${RUNNER_TEMP}/env" >> "${GITHUB_ENV}" + echo "PYTORCH_FINAL_PACKAGE_DIR=${RUNNER_TEMP}/artifacts" >> "${GITHUB_ENV}" + echo "WIN_PACKAGE_WORK_DIR=${RUNNER_TEMP}" + - uses: actions/download-artifact@v4.1.7 + name: Download Build Artifacts + with: + name: wheel-py3_11-cpu + path: "${{ env.PYTORCH_FINAL_PACKAGE_DIR }}" + - name: Checkout PyTorch + uses: malfet/checkout@silent-checkout + with: + ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} + submodules: recursive + path: pytorch + quiet-checkout: true + - name: Clean PyTorch checkout + run: | + # Remove any artifacts from the previous checkouts + git clean -fxd + working-directory: pytorch + - name: Checkout pytorch/builder + uses: malfet/checkout@silent-checkout + with: + ref: main + submodules: recursive + repository: pytorch/builder + path: builder + quiet-checkout: true + - name: Clean pytorch/builder checkout + run: | + # Remove any artifacts from the previous checkouts + git clean -fxd + working-directory: builder + - name: Populate binary env + shell: bash + run: | + "${PYTORCH_ROOT}/.circleci/scripts/binary_populate_env.sh" + - name: Test PyTorch binary + shell: bash + run: | + "${PYTORCH_ROOT}/.circleci/scripts/binary_windows_test.sh" + - name: Wait until all sessions have drained + shell: powershell + working-directory: pytorch + if: always() + timeout-minutes: 120 + run: | + .github\scripts\wait_for_ssh_to_drain.ps1 + - name: Kill active ssh sessions if still around (Useful if workflow was cancelled) + shell: powershell + working-directory: pytorch + if: always() + run: | + .github\scripts\kill_active_ssh_sessions.ps1 + wheel-py3_11-cpu-upload: # Uploading + if: ${{ github.repository_owner == 'pytorch' }} + permissions: + id-token: write + contents: read + needs: wheel-py3_11-cpu-test + with: + PYTORCH_ROOT: ${{ github.workspace }}/pytorch + BUILDER_ROOT: ${{ github.workspace }}/builder + PACKAGE_TYPE: wheel + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: cpu + GPU_ARCH_TYPE: cpu + DESIRED_PYTHON: "3.11" + build_name: wheel-py3_11-cpu + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} + conda-pytorchbot-token: ${{ secrets.CONDA_PYTORCHBOT_TOKEN }} + conda-pytorchbot-token-test: ${{ secrets.CONDA_PYTORCHBOT_TOKEN_TEST }} + uses: ./.github/workflows/_binary-upload.yml + wheel-py3_11-cuda11_8-build: + if: ${{ github.repository_owner == 'pytorch' }} + needs: get-label-type + runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" + timeout-minutes: 240 + env: + PYTORCH_ROOT: ${{ github.workspace }}/pytorch + BUILDER_ROOT: ${{ github.workspace }}/builder + PACKAGE_TYPE: wheel + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: cu118 + GPU_ARCH_VERSION: 11.8 + GPU_ARCH_TYPE: cuda + SKIP_ALL_TESTS: 1 + DESIRED_PYTHON: "3.11" + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.21.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' + steps: + - name: Display EC2 information + shell: bash + run: | + set -euo pipefail + function get_ec2_metadata() { + # Pulled from instance metadata endpoint for EC2 + # see https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/instancedata-data-retrieval.html + category=$1 + curl -fsSL "http://169.254.169.254/latest/meta-data/${category}" + } + echo "ami-id: $(get_ec2_metadata ami-id)" + echo "instance-id: $(get_ec2_metadata instance-id)" + echo "instance-type: $(get_ec2_metadata instance-type)" + echo "system info $(uname -a)" + - name: "[FB EMPLOYEES] Enable SSH (Click me for login details)" + uses: pytorch/test-infra/.github/actions/setup-ssh@main + continue-on-error: true + with: + github-secret: ${{ secrets.GITHUB_TOKEN }} + # Needed for binary builds, see: https://github.com/pytorch/pytorch/issues/73339#issuecomment-1058981560 + - name: Enable long paths on Windows + shell: powershell + run: | + Set-ItemProperty -Path "HKLM:\\SYSTEM\CurrentControlSet\Control\FileSystem" -Name "LongPathsEnabled" -Value 1 + # Since it's just a defensive command, the workflow should continue even the command fails. This step can be + # removed once Windows Defender is removed from the AMI + - name: Disables Windows Defender scheduled and real-time scanning for files in directories used by PyTorch + continue-on-error: true + shell: powershell + run: | + Add-MpPreference -ExclusionPath $(Get-Location).tostring(),$Env:TEMP -ErrorAction Ignore + # Let's both exclude the path and disable Windows Defender completely just to be sure + # that it doesn't interfere + Set-MpPreference -DisableRealtimeMonitoring $True -ErrorAction Ignore + # NOTE: These environment variables are put here so that they can be applied on every job equally + # They are also here because setting them at a workflow level doesn't give us access to the + # runner.temp variable, which we need. + - name: Populate binary env + shell: bash + run: | + echo "BINARY_ENV_FILE=${RUNNER_TEMP}/env" >> "${GITHUB_ENV}" + echo "PYTORCH_FINAL_PACKAGE_DIR=${RUNNER_TEMP}/artifacts" >> "${GITHUB_ENV}" + echo "WIN_PACKAGE_WORK_DIR=${RUNNER_TEMP}" + - name: Checkout PyTorch + uses: malfet/checkout@silent-checkout + with: + ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} + submodules: recursive + path: pytorch + quiet-checkout: true + - name: Clean PyTorch checkout + run: | + # Remove any artifacts from the previous checkouts + git clean -fxd + working-directory: pytorch + - name: Checkout pytorch/builder + uses: malfet/checkout@silent-checkout + with: + ref: main + submodules: recursive + repository: pytorch/builder + path: builder + quiet-checkout: true + - name: Clean pytorch/builder checkout + run: | + # Remove any artifacts from the previous checkouts + git clean -fxd + working-directory: builder + - name: Populate binary env + shell: bash + run: | + "${PYTORCH_ROOT}/.circleci/scripts/binary_populate_env.sh" + - name: Build PyTorch binary + shell: bash + run: | + "${PYTORCH_ROOT}/.circleci/scripts/binary_windows_build.sh" + - uses: actions/upload-artifact@v3 + if: always() + with: + name: wheel-py3_11-cuda11_8 + retention-days: 14 + if-no-files-found: error + path: "${{ env.PYTORCH_FINAL_PACKAGE_DIR }}" + - name: Wait until all sessions have drained + shell: powershell + working-directory: pytorch + if: always() + timeout-minutes: 120 + run: | + .github\scripts\wait_for_ssh_to_drain.ps1 + - name: Kill active ssh sessions if still around (Useful if workflow was cancelled) + shell: powershell + working-directory: pytorch + if: always() + run: | + .github\scripts\kill_active_ssh_sessions.ps1 + wheel-py3_11-cuda11_8-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} needs: - - wheel-py3_10-cuda12_4-build + - wheel-py3_11-cuda11_8-build - get-label-type runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.g4dn.xlarge" timeout-minutes: 240 @@ -1904,11 +2889,11 @@ jobs: PACKAGE_TYPE: wheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu124 - GPU_ARCH_VERSION: 12.4 + DESIRED_CUDA: cu118 + GPU_ARCH_VERSION: 11.8 GPU_ARCH_TYPE: cuda SKIP_ALL_TESTS: 1 - DESIRED_PYTHON: "3.10" + DESIRED_PYTHON: "3.11" steps: - name: Display EC2 information shell: bash @@ -1956,7 +2941,7 @@ jobs: - uses: actions/download-artifact@v4.1.7 name: Download Build Artifacts with: - name: wheel-py3_10-cuda12_4 + name: wheel-py3_11-cuda11_8 path: "${{ env.PYTORCH_FINAL_PACKAGE_DIR }}" - name: Checkout PyTorch uses: malfet/checkout@silent-checkout @@ -2004,29 +2989,29 @@ jobs: if: always() run: | .github\scripts\kill_active_ssh_sessions.ps1 - wheel-py3_10-cuda12_4-upload: # Uploading + wheel-py3_11-cuda11_8-upload: # Uploading if: ${{ github.repository_owner == 'pytorch' }} permissions: id-token: write contents: read - needs: wheel-py3_10-cuda12_4-test + needs: wheel-py3_11-cuda11_8-test with: PYTORCH_ROOT: ${{ github.workspace }}/pytorch BUILDER_ROOT: ${{ github.workspace }}/builder PACKAGE_TYPE: wheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu124 - GPU_ARCH_VERSION: 12.4 + DESIRED_CUDA: cu118 + GPU_ARCH_VERSION: 11.8 GPU_ARCH_TYPE: cuda - DESIRED_PYTHON: "3.10" - build_name: wheel-py3_10-cuda12_4 + DESIRED_PYTHON: "3.11" + build_name: wheel-py3_11-cuda11_8 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} conda-pytorchbot-token: ${{ secrets.CONDA_PYTORCHBOT_TOKEN }} conda-pytorchbot-token-test: ${{ secrets.CONDA_PYTORCHBOT_TOKEN_TEST }} uses: ./.github/workflows/_binary-upload.yml - wheel-py3_11-cpu-build: + wheel-py3_11-cuda12_1-build: if: ${{ github.repository_owner == 'pytorch' }} needs: get-label-type runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" @@ -2037,8 +3022,9 @@ jobs: PACKAGE_TYPE: wheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cpu - GPU_ARCH_TYPE: cpu + DESIRED_CUDA: cu121 + GPU_ARCH_VERSION: 12.1 + GPU_ARCH_TYPE: cuda SKIP_ALL_TESTS: 1 DESIRED_PYTHON: "3.11" PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.21.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' @@ -2122,7 +3108,7 @@ jobs: - uses: actions/upload-artifact@v3 if: always() with: - name: wheel-py3_11-cpu + name: wheel-py3_11-cuda12_1 retention-days: 14 if-no-files-found: error path: "${{ env.PYTORCH_FINAL_PACKAGE_DIR }}" @@ -2139,12 +3125,12 @@ jobs: if: always() run: | .github\scripts\kill_active_ssh_sessions.ps1 - wheel-py3_11-cpu-test: # Testing + wheel-py3_11-cuda12_1-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} needs: - - wheel-py3_11-cpu-build + - wheel-py3_11-cuda12_1-build - get-label-type - runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge.nonephemeral" + runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.g4dn.xlarge" timeout-minutes: 240 env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch @@ -2152,8 +3138,9 @@ jobs: PACKAGE_TYPE: wheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cpu - GPU_ARCH_TYPE: cpu + DESIRED_CUDA: cu121 + GPU_ARCH_VERSION: 12.1 + GPU_ARCH_TYPE: cuda SKIP_ALL_TESTS: 1 DESIRED_PYTHON: "3.11" steps: @@ -2203,7 +3190,7 @@ jobs: - uses: actions/download-artifact@v4.1.7 name: Download Build Artifacts with: - name: wheel-py3_11-cpu + name: wheel-py3_11-cuda12_1 path: "${{ env.PYTORCH_FINAL_PACKAGE_DIR }}" - name: Checkout PyTorch uses: malfet/checkout@silent-checkout @@ -2251,28 +3238,29 @@ jobs: if: always() run: | .github\scripts\kill_active_ssh_sessions.ps1 - wheel-py3_11-cpu-upload: # Uploading + wheel-py3_11-cuda12_1-upload: # Uploading if: ${{ github.repository_owner == 'pytorch' }} permissions: id-token: write contents: read - needs: wheel-py3_11-cpu-test + needs: wheel-py3_11-cuda12_1-test with: PYTORCH_ROOT: ${{ github.workspace }}/pytorch BUILDER_ROOT: ${{ github.workspace }}/builder PACKAGE_TYPE: wheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cpu - GPU_ARCH_TYPE: cpu + DESIRED_CUDA: cu121 + GPU_ARCH_VERSION: 12.1 + GPU_ARCH_TYPE: cuda DESIRED_PYTHON: "3.11" - build_name: wheel-py3_11-cpu + build_name: wheel-py3_11-cuda12_1 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} conda-pytorchbot-token: ${{ secrets.CONDA_PYTORCHBOT_TOKEN }} conda-pytorchbot-token-test: ${{ secrets.CONDA_PYTORCHBOT_TOKEN_TEST }} uses: ./.github/workflows/_binary-upload.yml - wheel-py3_11-cuda11_8-build: + wheel-py3_11-cuda12_4-build: if: ${{ github.repository_owner == 'pytorch' }} needs: get-label-type runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" @@ -2283,8 +3271,8 @@ jobs: PACKAGE_TYPE: wheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu118 - GPU_ARCH_VERSION: 11.8 + DESIRED_CUDA: cu124 + GPU_ARCH_VERSION: 12.4 GPU_ARCH_TYPE: cuda SKIP_ALL_TESTS: 1 DESIRED_PYTHON: "3.11" @@ -2369,7 +3357,7 @@ jobs: - uses: actions/upload-artifact@v3 if: always() with: - name: wheel-py3_11-cuda11_8 + name: wheel-py3_11-cuda12_4 retention-days: 14 if-no-files-found: error path: "${{ env.PYTORCH_FINAL_PACKAGE_DIR }}" @@ -2386,10 +3374,10 @@ jobs: if: always() run: | .github\scripts\kill_active_ssh_sessions.ps1 - wheel-py3_11-cuda11_8-test: # Testing + wheel-py3_11-cuda12_4-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} needs: - - wheel-py3_11-cuda11_8-build + - wheel-py3_11-cuda12_4-build - get-label-type runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.g4dn.xlarge" timeout-minutes: 240 @@ -2399,8 +3387,8 @@ jobs: PACKAGE_TYPE: wheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu118 - GPU_ARCH_VERSION: 11.8 + DESIRED_CUDA: cu124 + GPU_ARCH_VERSION: 12.4 GPU_ARCH_TYPE: cuda SKIP_ALL_TESTS: 1 DESIRED_PYTHON: "3.11" @@ -2451,7 +3439,7 @@ jobs: - uses: actions/download-artifact@v4.1.7 name: Download Build Artifacts with: - name: wheel-py3_11-cuda11_8 + name: wheel-py3_11-cuda12_4 path: "${{ env.PYTORCH_FINAL_PACKAGE_DIR }}" - name: Checkout PyTorch uses: malfet/checkout@silent-checkout @@ -2499,29 +3487,29 @@ jobs: if: always() run: | .github\scripts\kill_active_ssh_sessions.ps1 - wheel-py3_11-cuda11_8-upload: # Uploading + wheel-py3_11-cuda12_4-upload: # Uploading if: ${{ github.repository_owner == 'pytorch' }} permissions: id-token: write contents: read - needs: wheel-py3_11-cuda11_8-test + needs: wheel-py3_11-cuda12_4-test with: PYTORCH_ROOT: ${{ github.workspace }}/pytorch BUILDER_ROOT: ${{ github.workspace }}/builder PACKAGE_TYPE: wheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu118 - GPU_ARCH_VERSION: 11.8 + DESIRED_CUDA: cu124 + GPU_ARCH_VERSION: 12.4 GPU_ARCH_TYPE: cuda DESIRED_PYTHON: "3.11" - build_name: wheel-py3_11-cuda11_8 + build_name: wheel-py3_11-cuda12_4 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} conda-pytorchbot-token: ${{ secrets.CONDA_PYTORCHBOT_TOKEN }} conda-pytorchbot-token-test: ${{ secrets.CONDA_PYTORCHBOT_TOKEN_TEST }} uses: ./.github/workflows/_binary-upload.yml - wheel-py3_11-cuda12_1-build: + wheel-py3_11-xpu-build: if: ${{ github.repository_owner == 'pytorch' }} needs: get-label-type runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" @@ -2532,12 +3520,10 @@ jobs: PACKAGE_TYPE: wheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu121 - GPU_ARCH_VERSION: 12.1 - GPU_ARCH_TYPE: cuda + DESIRED_CUDA: xpu + GPU_ARCH_TYPE: xpu SKIP_ALL_TESTS: 1 DESIRED_PYTHON: "3.11" - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.21.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' steps: - name: Display EC2 information shell: bash @@ -2618,7 +3604,7 @@ jobs: - uses: actions/upload-artifact@v3 if: always() with: - name: wheel-py3_11-cuda12_1 + name: wheel-py3_11-xpu retention-days: 14 if-no-files-found: error path: "${{ env.PYTORCH_FINAL_PACKAGE_DIR }}" @@ -2635,12 +3621,12 @@ jobs: if: always() run: | .github\scripts\kill_active_ssh_sessions.ps1 - wheel-py3_11-cuda12_1-test: # Testing + wheel-py3_11-xpu-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} needs: - - wheel-py3_11-cuda12_1-build + - wheel-py3_11-xpu-build - get-label-type - runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.g4dn.xlarge" + runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge.nonephemeral" timeout-minutes: 240 env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch @@ -2648,9 +3634,8 @@ jobs: PACKAGE_TYPE: wheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu121 - GPU_ARCH_VERSION: 12.1 - GPU_ARCH_TYPE: cuda + DESIRED_CUDA: xpu + GPU_ARCH_TYPE: xpu SKIP_ALL_TESTS: 1 DESIRED_PYTHON: "3.11" steps: @@ -2697,10 +3682,10 @@ jobs: echo "BINARY_ENV_FILE=${RUNNER_TEMP}/env" >> "${GITHUB_ENV}" echo "PYTORCH_FINAL_PACKAGE_DIR=${RUNNER_TEMP}/artifacts" >> "${GITHUB_ENV}" echo "WIN_PACKAGE_WORK_DIR=${RUNNER_TEMP}" - - uses: actions/download-artifact@v4.1.7 + - uses: actions/download-artifact@v3 name: Download Build Artifacts with: - name: wheel-py3_11-cuda12_1 + name: wheel-py3_11-xpu path: "${{ env.PYTORCH_FINAL_PACKAGE_DIR }}" - name: Checkout PyTorch uses: malfet/checkout@silent-checkout @@ -2748,29 +3733,28 @@ jobs: if: always() run: | .github\scripts\kill_active_ssh_sessions.ps1 - wheel-py3_11-cuda12_1-upload: # Uploading + wheel-py3_11-xpu-upload: # Uploading if: ${{ github.repository_owner == 'pytorch' }} permissions: id-token: write contents: read - needs: wheel-py3_11-cuda12_1-test + needs: wheel-py3_11-xpu-test with: PYTORCH_ROOT: ${{ github.workspace }}/pytorch BUILDER_ROOT: ${{ github.workspace }}/builder PACKAGE_TYPE: wheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu121 - GPU_ARCH_VERSION: 12.1 - GPU_ARCH_TYPE: cuda + DESIRED_CUDA: xpu + GPU_ARCH_TYPE: xpu DESIRED_PYTHON: "3.11" - build_name: wheel-py3_11-cuda12_1 + build_name: wheel-py3_11-xpu secrets: github-token: ${{ secrets.GITHUB_TOKEN }} conda-pytorchbot-token: ${{ secrets.CONDA_PYTORCHBOT_TOKEN }} conda-pytorchbot-token-test: ${{ secrets.CONDA_PYTORCHBOT_TOKEN_TEST }} uses: ./.github/workflows/_binary-upload.yml - wheel-py3_11-cuda12_4-build: + wheel-py3_12-cpu-build: if: ${{ github.repository_owner == 'pytorch' }} needs: get-label-type runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" @@ -2781,11 +3765,10 @@ jobs: PACKAGE_TYPE: wheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu124 - GPU_ARCH_VERSION: 12.4 - GPU_ARCH_TYPE: cuda + DESIRED_CUDA: cpu + GPU_ARCH_TYPE: cpu SKIP_ALL_TESTS: 1 - DESIRED_PYTHON: "3.11" + DESIRED_PYTHON: "3.12" PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.21.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' steps: - name: Display EC2 information @@ -2867,7 +3850,7 @@ jobs: - uses: actions/upload-artifact@v3 if: always() with: - name: wheel-py3_11-cuda12_4 + name: wheel-py3_12-cpu retention-days: 14 if-no-files-found: error path: "${{ env.PYTORCH_FINAL_PACKAGE_DIR }}" @@ -2884,12 +3867,12 @@ jobs: if: always() run: | .github\scripts\kill_active_ssh_sessions.ps1 - wheel-py3_11-cuda12_4-test: # Testing + wheel-py3_12-cpu-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} needs: - - wheel-py3_11-cuda12_4-build + - wheel-py3_12-cpu-build - get-label-type - runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.g4dn.xlarge" + runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge.nonephemeral" timeout-minutes: 240 env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch @@ -2897,11 +3880,10 @@ jobs: PACKAGE_TYPE: wheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu124 - GPU_ARCH_VERSION: 12.4 - GPU_ARCH_TYPE: cuda + DESIRED_CUDA: cpu + GPU_ARCH_TYPE: cpu SKIP_ALL_TESTS: 1 - DESIRED_PYTHON: "3.11" + DESIRED_PYTHON: "3.12" steps: - name: Display EC2 information shell: bash @@ -2949,7 +3931,7 @@ jobs: - uses: actions/download-artifact@v4.1.7 name: Download Build Artifacts with: - name: wheel-py3_11-cuda12_4 + name: wheel-py3_12-cpu path: "${{ env.PYTORCH_FINAL_PACKAGE_DIR }}" - name: Checkout PyTorch uses: malfet/checkout@silent-checkout @@ -2997,29 +3979,28 @@ jobs: if: always() run: | .github\scripts\kill_active_ssh_sessions.ps1 - wheel-py3_11-cuda12_4-upload: # Uploading + wheel-py3_12-cpu-upload: # Uploading if: ${{ github.repository_owner == 'pytorch' }} permissions: id-token: write contents: read - needs: wheel-py3_11-cuda12_4-test + needs: wheel-py3_12-cpu-test with: PYTORCH_ROOT: ${{ github.workspace }}/pytorch BUILDER_ROOT: ${{ github.workspace }}/builder PACKAGE_TYPE: wheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu124 - GPU_ARCH_VERSION: 12.4 - GPU_ARCH_TYPE: cuda - DESIRED_PYTHON: "3.11" - build_name: wheel-py3_11-cuda12_4 + DESIRED_CUDA: cpu + GPU_ARCH_TYPE: cpu + DESIRED_PYTHON: "3.12" + build_name: wheel-py3_12-cpu secrets: github-token: ${{ secrets.GITHUB_TOKEN }} conda-pytorchbot-token: ${{ secrets.CONDA_PYTORCHBOT_TOKEN }} conda-pytorchbot-token-test: ${{ secrets.CONDA_PYTORCHBOT_TOKEN_TEST }} uses: ./.github/workflows/_binary-upload.yml - wheel-py3_12-cpu-build: + wheel-py3_12-cuda11_8-build: if: ${{ github.repository_owner == 'pytorch' }} needs: get-label-type runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" @@ -3030,8 +4011,9 @@ jobs: PACKAGE_TYPE: wheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cpu - GPU_ARCH_TYPE: cpu + DESIRED_CUDA: cu118 + GPU_ARCH_VERSION: 11.8 + GPU_ARCH_TYPE: cuda SKIP_ALL_TESTS: 1 DESIRED_PYTHON: "3.12" PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.21.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' @@ -3115,7 +4097,7 @@ jobs: - uses: actions/upload-artifact@v3 if: always() with: - name: wheel-py3_12-cpu + name: wheel-py3_12-cuda11_8 retention-days: 14 if-no-files-found: error path: "${{ env.PYTORCH_FINAL_PACKAGE_DIR }}" @@ -3132,12 +4114,12 @@ jobs: if: always() run: | .github\scripts\kill_active_ssh_sessions.ps1 - wheel-py3_12-cpu-test: # Testing + wheel-py3_12-cuda11_8-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} needs: - - wheel-py3_12-cpu-build + - wheel-py3_12-cuda11_8-build - get-label-type - runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge.nonephemeral" + runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.g4dn.xlarge" timeout-minutes: 240 env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch @@ -3145,8 +4127,9 @@ jobs: PACKAGE_TYPE: wheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cpu - GPU_ARCH_TYPE: cpu + DESIRED_CUDA: cu118 + GPU_ARCH_VERSION: 11.8 + GPU_ARCH_TYPE: cuda SKIP_ALL_TESTS: 1 DESIRED_PYTHON: "3.12" steps: @@ -3196,7 +4179,7 @@ jobs: - uses: actions/download-artifact@v4.1.7 name: Download Build Artifacts with: - name: wheel-py3_12-cpu + name: wheel-py3_12-cuda11_8 path: "${{ env.PYTORCH_FINAL_PACKAGE_DIR }}" - name: Checkout PyTorch uses: malfet/checkout@silent-checkout @@ -3244,28 +4227,29 @@ jobs: if: always() run: | .github\scripts\kill_active_ssh_sessions.ps1 - wheel-py3_12-cpu-upload: # Uploading + wheel-py3_12-cuda11_8-upload: # Uploading if: ${{ github.repository_owner == 'pytorch' }} permissions: id-token: write contents: read - needs: wheel-py3_12-cpu-test + needs: wheel-py3_12-cuda11_8-test with: PYTORCH_ROOT: ${{ github.workspace }}/pytorch BUILDER_ROOT: ${{ github.workspace }}/builder PACKAGE_TYPE: wheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cpu - GPU_ARCH_TYPE: cpu + DESIRED_CUDA: cu118 + GPU_ARCH_VERSION: 11.8 + GPU_ARCH_TYPE: cuda DESIRED_PYTHON: "3.12" - build_name: wheel-py3_12-cpu + build_name: wheel-py3_12-cuda11_8 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} conda-pytorchbot-token: ${{ secrets.CONDA_PYTORCHBOT_TOKEN }} conda-pytorchbot-token-test: ${{ secrets.CONDA_PYTORCHBOT_TOKEN_TEST }} uses: ./.github/workflows/_binary-upload.yml - wheel-py3_12-cuda11_8-build: + wheel-py3_12-cuda12_1-build: if: ${{ github.repository_owner == 'pytorch' }} needs: get-label-type runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" @@ -3276,8 +4260,8 @@ jobs: PACKAGE_TYPE: wheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu118 - GPU_ARCH_VERSION: 11.8 + DESIRED_CUDA: cu121 + GPU_ARCH_VERSION: 12.1 GPU_ARCH_TYPE: cuda SKIP_ALL_TESTS: 1 DESIRED_PYTHON: "3.12" @@ -3362,7 +4346,7 @@ jobs: - uses: actions/upload-artifact@v3 if: always() with: - name: wheel-py3_12-cuda11_8 + name: wheel-py3_12-cuda12_1 retention-days: 14 if-no-files-found: error path: "${{ env.PYTORCH_FINAL_PACKAGE_DIR }}" @@ -3379,10 +4363,10 @@ jobs: if: always() run: | .github\scripts\kill_active_ssh_sessions.ps1 - wheel-py3_12-cuda11_8-test: # Testing + wheel-py3_12-cuda12_1-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} needs: - - wheel-py3_12-cuda11_8-build + - wheel-py3_12-cuda12_1-build - get-label-type runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.g4dn.xlarge" timeout-minutes: 240 @@ -3392,8 +4376,8 @@ jobs: PACKAGE_TYPE: wheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu118 - GPU_ARCH_VERSION: 11.8 + DESIRED_CUDA: cu121 + GPU_ARCH_VERSION: 12.1 GPU_ARCH_TYPE: cuda SKIP_ALL_TESTS: 1 DESIRED_PYTHON: "3.12" @@ -3444,7 +4428,7 @@ jobs: - uses: actions/download-artifact@v4.1.7 name: Download Build Artifacts with: - name: wheel-py3_12-cuda11_8 + name: wheel-py3_12-cuda12_1 path: "${{ env.PYTORCH_FINAL_PACKAGE_DIR }}" - name: Checkout PyTorch uses: malfet/checkout@silent-checkout @@ -3492,29 +4476,29 @@ jobs: if: always() run: | .github\scripts\kill_active_ssh_sessions.ps1 - wheel-py3_12-cuda11_8-upload: # Uploading + wheel-py3_12-cuda12_1-upload: # Uploading if: ${{ github.repository_owner == 'pytorch' }} permissions: id-token: write contents: read - needs: wheel-py3_12-cuda11_8-test + needs: wheel-py3_12-cuda12_1-test with: PYTORCH_ROOT: ${{ github.workspace }}/pytorch BUILDER_ROOT: ${{ github.workspace }}/builder PACKAGE_TYPE: wheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu118 - GPU_ARCH_VERSION: 11.8 + DESIRED_CUDA: cu121 + GPU_ARCH_VERSION: 12.1 GPU_ARCH_TYPE: cuda DESIRED_PYTHON: "3.12" - build_name: wheel-py3_12-cuda11_8 + build_name: wheel-py3_12-cuda12_1 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} conda-pytorchbot-token: ${{ secrets.CONDA_PYTORCHBOT_TOKEN }} conda-pytorchbot-token-test: ${{ secrets.CONDA_PYTORCHBOT_TOKEN_TEST }} uses: ./.github/workflows/_binary-upload.yml - wheel-py3_12-cuda12_1-build: + wheel-py3_12-cuda12_4-build: if: ${{ github.repository_owner == 'pytorch' }} needs: get-label-type runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" @@ -3525,8 +4509,8 @@ jobs: PACKAGE_TYPE: wheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu121 - GPU_ARCH_VERSION: 12.1 + DESIRED_CUDA: cu124 + GPU_ARCH_VERSION: 12.4 GPU_ARCH_TYPE: cuda SKIP_ALL_TESTS: 1 DESIRED_PYTHON: "3.12" @@ -3611,7 +4595,7 @@ jobs: - uses: actions/upload-artifact@v3 if: always() with: - name: wheel-py3_12-cuda12_1 + name: wheel-py3_12-cuda12_4 retention-days: 14 if-no-files-found: error path: "${{ env.PYTORCH_FINAL_PACKAGE_DIR }}" @@ -3628,10 +4612,10 @@ jobs: if: always() run: | .github\scripts\kill_active_ssh_sessions.ps1 - wheel-py3_12-cuda12_1-test: # Testing + wheel-py3_12-cuda12_4-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} needs: - - wheel-py3_12-cuda12_1-build + - wheel-py3_12-cuda12_4-build - get-label-type runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.g4dn.xlarge" timeout-minutes: 240 @@ -3641,8 +4625,8 @@ jobs: PACKAGE_TYPE: wheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu121 - GPU_ARCH_VERSION: 12.1 + DESIRED_CUDA: cu124 + GPU_ARCH_VERSION: 12.4 GPU_ARCH_TYPE: cuda SKIP_ALL_TESTS: 1 DESIRED_PYTHON: "3.12" @@ -3693,7 +4677,7 @@ jobs: - uses: actions/download-artifact@v4.1.7 name: Download Build Artifacts with: - name: wheel-py3_12-cuda12_1 + name: wheel-py3_12-cuda12_4 path: "${{ env.PYTORCH_FINAL_PACKAGE_DIR }}" - name: Checkout PyTorch uses: malfet/checkout@silent-checkout @@ -3741,29 +4725,29 @@ jobs: if: always() run: | .github\scripts\kill_active_ssh_sessions.ps1 - wheel-py3_12-cuda12_1-upload: # Uploading + wheel-py3_12-cuda12_4-upload: # Uploading if: ${{ github.repository_owner == 'pytorch' }} permissions: id-token: write contents: read - needs: wheel-py3_12-cuda12_1-test + needs: wheel-py3_12-cuda12_4-test with: PYTORCH_ROOT: ${{ github.workspace }}/pytorch BUILDER_ROOT: ${{ github.workspace }}/builder PACKAGE_TYPE: wheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu121 - GPU_ARCH_VERSION: 12.1 + DESIRED_CUDA: cu124 + GPU_ARCH_VERSION: 12.4 GPU_ARCH_TYPE: cuda DESIRED_PYTHON: "3.12" - build_name: wheel-py3_12-cuda12_1 + build_name: wheel-py3_12-cuda12_4 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} conda-pytorchbot-token: ${{ secrets.CONDA_PYTORCHBOT_TOKEN }} conda-pytorchbot-token-test: ${{ secrets.CONDA_PYTORCHBOT_TOKEN_TEST }} uses: ./.github/workflows/_binary-upload.yml - wheel-py3_12-cuda12_4-build: + wheel-py3_12-xpu-build: if: ${{ github.repository_owner == 'pytorch' }} needs: get-label-type runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" @@ -3774,12 +4758,10 @@ jobs: PACKAGE_TYPE: wheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu124 - GPU_ARCH_VERSION: 12.4 - GPU_ARCH_TYPE: cuda + DESIRED_CUDA: xpu + GPU_ARCH_TYPE: xpu SKIP_ALL_TESTS: 1 DESIRED_PYTHON: "3.12" - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.21.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' steps: - name: Display EC2 information shell: bash @@ -3860,7 +4842,7 @@ jobs: - uses: actions/upload-artifact@v3 if: always() with: - name: wheel-py3_12-cuda12_4 + name: wheel-py3_12-xpu retention-days: 14 if-no-files-found: error path: "${{ env.PYTORCH_FINAL_PACKAGE_DIR }}" @@ -3877,12 +4859,12 @@ jobs: if: always() run: | .github\scripts\kill_active_ssh_sessions.ps1 - wheel-py3_12-cuda12_4-test: # Testing + wheel-py3_12-xpu-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} needs: - - wheel-py3_12-cuda12_4-build + - wheel-py3_12-xpu-build - get-label-type - runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.g4dn.xlarge" + runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge.nonephemeral" timeout-minutes: 240 env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch @@ -3890,9 +4872,8 @@ jobs: PACKAGE_TYPE: wheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu124 - GPU_ARCH_VERSION: 12.4 - GPU_ARCH_TYPE: cuda + DESIRED_CUDA: xpu + GPU_ARCH_TYPE: xpu SKIP_ALL_TESTS: 1 DESIRED_PYTHON: "3.12" steps: @@ -3939,10 +4920,10 @@ jobs: echo "BINARY_ENV_FILE=${RUNNER_TEMP}/env" >> "${GITHUB_ENV}" echo "PYTORCH_FINAL_PACKAGE_DIR=${RUNNER_TEMP}/artifacts" >> "${GITHUB_ENV}" echo "WIN_PACKAGE_WORK_DIR=${RUNNER_TEMP}" - - uses: actions/download-artifact@v4.1.7 + - uses: actions/download-artifact@v3 name: Download Build Artifacts with: - name: wheel-py3_12-cuda12_4 + name: wheel-py3_12-xpu path: "${{ env.PYTORCH_FINAL_PACKAGE_DIR }}" - name: Checkout PyTorch uses: malfet/checkout@silent-checkout @@ -3990,23 +4971,22 @@ jobs: if: always() run: | .github\scripts\kill_active_ssh_sessions.ps1 - wheel-py3_12-cuda12_4-upload: # Uploading + wheel-py3_12-xpu-upload: # Uploading if: ${{ github.repository_owner == 'pytorch' }} permissions: id-token: write contents: read - needs: wheel-py3_12-cuda12_4-test + needs: wheel-py3_12-xpu-test with: PYTORCH_ROOT: ${{ github.workspace }}/pytorch BUILDER_ROOT: ${{ github.workspace }}/builder PACKAGE_TYPE: wheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu124 - GPU_ARCH_VERSION: 12.4 - GPU_ARCH_TYPE: cuda + DESIRED_CUDA: xpu + GPU_ARCH_TYPE: xpu DESIRED_PYTHON: "3.12" - build_name: wheel-py3_12-cuda12_4 + build_name: wheel-py3_12-xpu secrets: github-token: ${{ secrets.GITHUB_TOKEN }} conda-pytorchbot-token: ${{ secrets.CONDA_PYTORCHBOT_TOKEN }} From 1357b5e881bebdb76f5bfaa42a80b1413bd4fe7f Mon Sep 17 00:00:00 2001 From: Avik Chaudhuri Date: Wed, 4 Sep 2024 14:56:49 +0000 Subject: [PATCH 0402/1018] fast path for sympy gcd in floordiv (#134880) Summary: Re-implementation of https://github.com/pytorch/pytorch/pull/134150, which was reverted because of some internal tests hanging (case B). The original motivation was to get some other internal test unstuck (case A). The root cause is that sympy.gcd is both very clever as well as can blow up in some cases. This PR introduces a fast path with an appropriate fallback to sympy.gcd that ensures that both cases A and B go through. Test Plan: See the included test for specific examples. Also https://fb.workplace.com/groups/1075192433118967/posts/1491493248155548/?comment_id=1491938994777640&reply_comment_id=1492622821375924 Differential Revision: D62043315 Pull Request resolved: https://github.com/pytorch/pytorch/pull/134880 Approved by: https://github.com/ezyang --- test/test_sympy_utils.py | 20 ++++++++++++++- torch/utils/_sympy/functions.py | 44 ++++++++++++++++++++++++++++++++- 2 files changed, 62 insertions(+), 2 deletions(-) diff --git a/test/test_sympy_utils.py b/test/test_sympy_utils.py index 0eb4be644ede65..6be74e6493ba87 100644 --- a/test/test_sympy_utils.py +++ b/test/test_sympy_utils.py @@ -14,7 +14,7 @@ run_tests, TestCase, ) -from torch.utils._sympy.functions import FloorDiv +from torch.utils._sympy.functions import FloorDiv, simple_floordiv_gcd from torch.utils._sympy.solve import INEQUALITY_TYPES, mirror_rel_op, try_solve from torch.utils._sympy.value_ranges import ValueRangeAnalysis, ValueRanges from torch.utils._sympy.reference import ReferenceAnalysis, PythonReferenceAnalysis @@ -667,6 +667,24 @@ def test_z3_proof_floordiv_eq_simplify(self): r = solver.check() self.assertEqual(r, z3.unsat) + def test_simple_floordiv_gcd(self): + x, y, z = sympy.symbols("x y z") + + # positive tests + self.assertEqual(simple_floordiv_gcd(x, x), x) + self.assertEqual(simple_floordiv_gcd(128 * x, 2304), 128) + self.assertEqual(simple_floordiv_gcd(128 * x + 128 * y, 2304), 128) + self.assertEqual(simple_floordiv_gcd(128 * x + 128 * y + 8192 * z, 9216), 128) + self.assertEqual(simple_floordiv_gcd(49152 * x, 96 * x), 96 * x) + self.assertEqual(simple_floordiv_gcd(96 * x, 96 * x), 96 * x) + self.assertEqual(simple_floordiv_gcd(x * y, x), x) + self.assertEqual(simple_floordiv_gcd(384 * x * y, x * y), x * y) + self.assertEqual(simple_floordiv_gcd(256 * x * y, 8 * x), 8 * x) + + # negative tests + self.assertEqual(simple_floordiv_gcd(x * y + x + y + 1, x + 1), 1) + + class TestSingletonInt(TestCase): def test_basic(self): j1 = SingletonInt(1, coeff=1) diff --git a/torch/utils/_sympy/functions.py b/torch/utils/_sympy/functions.py index f9602dbd63976f..4924fff6dea6be 100644 --- a/torch/utils/_sympy/functions.py +++ b/torch/utils/_sympy/functions.py @@ -95,6 +95,46 @@ def fuzzy_eq(x, y): return x == y +def simple_floordiv_gcd(p, q): + """ + Fast path for sympy.gcd, using a simple factoring strategy. + + We try to rewrite p and q in the form n*e*p1 + n*e*p2 and n*e*q0, + where n is the greatest common integer factor and e is the largest + syntactic common factor (i.e., common sub-expression) in p and q. + Then the gcd returned is n*e, cancelling which we would be left with + p1 + p2 and q0. + + Note that further factoring of p1 + p2 and q0 might be possible with + sympy.factor (which uses domain-specific theories). E.g., we are unable + to find that x*y + x + y + 1 is divisible by x + 1. More generally, + when q is of the form q1 + q2 (instead of being already factored) it + might be necessary to fall back on sympy.gcd. + """ + + def integer_coefficient(x): + integer_coefficients = [ + abs(int(arg)) + for arg in sympy.Mul.make_args(x) + if isinstance(arg, (int, sympy.Integer)) + ] + return math.prod(integer_coefficients) + + def integer_factor(expr): + integer_factors = map(integer_coefficient, sympy.Add.make_args(expr)) + return functools.reduce(math.gcd, integer_factors) + + gcd = math.gcd(integer_factor(p), integer_factor(q)) + p, q = p / gcd, q / gcd + + base_splits = list(map(sympy.Mul.make_args, sympy.Add.make_args(p))) + divisor_split = sympy.Mul.make_args(q) + for x in divisor_split: + if all(x in base_split for base_split in base_splits): + gcd = gcd * x + return gcd + + # It would be nice to have assertions on whether or not inputs is_integer # However, with bugs like https://github.com/sympy/sympy/issues/26620 sympy # sometimes inconsistently reports floats an integers. @@ -202,7 +242,9 @@ def eval(cls, base, divisor): return FloorDiv(base - term, divisor) + quotient try: - gcd = sympy.gcd(base, divisor) + gcd = simple_floordiv_gcd(base, divisor) + if equal_valued(gcd, 1) and isinstance(divisor, sympy.Add): + gcd = sympy.gcd(base, divisor) if not equal_valued(gcd, 1): return FloorDiv( sympy.simplify(base / gcd), sympy.simplify(divisor / gcd) From aaca71e614479d07132dcb828b8bae74a7f6e77e Mon Sep 17 00:00:00 2001 From: atalman Date: Wed, 4 Sep 2024 15:26:17 +0000 Subject: [PATCH 0403/1018] Fix lint after Bump actions/download-artifact update (#135109) Fixes lint after auto-generated PR: https://github.com/pytorch/pytorch/commit/367a78495f0b401bbc0525fc3d4bb3e885436848 Pull Request resolved: https://github.com/pytorch/pytorch/pull/135109 Approved by: https://github.com/ezyang, https://github.com/huydhn --- .github/templates/common.yml.j2 | 2 +- .../workflows/generated-windows-binary-wheel-nightly.yml | 8 ++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/.github/templates/common.yml.j2 b/.github/templates/common.yml.j2 index 38b90e919c3341..1bbb339129212f 100644 --- a/.github/templates/common.yml.j2 +++ b/.github/templates/common.yml.j2 @@ -1,7 +1,7 @@ {%- set upload_artifact_s3_action = "seemethere/upload-artifact-s3@v5" -%} {%- set download_artifact_s3_action = "seemethere/download-artifact-s3@v4" -%} {%- set upload_artifact_action = "actions/upload-artifact@v3" -%} -{%- set download_artifact_action = "actions/download-artifact@v3" -%} +{%- set download_artifact_action = "actions/download-artifact@v4.1.7" -%} {%- set timeout_minutes = 240 -%} diff --git a/.github/workflows/generated-windows-binary-wheel-nightly.yml b/.github/workflows/generated-windows-binary-wheel-nightly.yml index 9da2ad1cb75842..37e62a6cdba539 100644 --- a/.github/workflows/generated-windows-binary-wheel-nightly.yml +++ b/.github/workflows/generated-windows-binary-wheel-nightly.yml @@ -1206,7 +1206,7 @@ jobs: echo "BINARY_ENV_FILE=${RUNNER_TEMP}/env" >> "${GITHUB_ENV}" echo "PYTORCH_FINAL_PACKAGE_DIR=${RUNNER_TEMP}/artifacts" >> "${GITHUB_ENV}" echo "WIN_PACKAGE_WORK_DIR=${RUNNER_TEMP}" - - uses: actions/download-artifact@v3 + - uses: actions/download-artifact@v4.1.7 name: Download Build Artifacts with: name: wheel-py3_9-xpu @@ -2444,7 +2444,7 @@ jobs: echo "BINARY_ENV_FILE=${RUNNER_TEMP}/env" >> "${GITHUB_ENV}" echo "PYTORCH_FINAL_PACKAGE_DIR=${RUNNER_TEMP}/artifacts" >> "${GITHUB_ENV}" echo "WIN_PACKAGE_WORK_DIR=${RUNNER_TEMP}" - - uses: actions/download-artifact@v3 + - uses: actions/download-artifact@v4.1.7 name: Download Build Artifacts with: name: wheel-py3_10-xpu @@ -3682,7 +3682,7 @@ jobs: echo "BINARY_ENV_FILE=${RUNNER_TEMP}/env" >> "${GITHUB_ENV}" echo "PYTORCH_FINAL_PACKAGE_DIR=${RUNNER_TEMP}/artifacts" >> "${GITHUB_ENV}" echo "WIN_PACKAGE_WORK_DIR=${RUNNER_TEMP}" - - uses: actions/download-artifact@v3 + - uses: actions/download-artifact@v4.1.7 name: Download Build Artifacts with: name: wheel-py3_11-xpu @@ -4920,7 +4920,7 @@ jobs: echo "BINARY_ENV_FILE=${RUNNER_TEMP}/env" >> "${GITHUB_ENV}" echo "PYTORCH_FINAL_PACKAGE_DIR=${RUNNER_TEMP}/artifacts" >> "${GITHUB_ENV}" echo "WIN_PACKAGE_WORK_DIR=${RUNNER_TEMP}" - - uses: actions/download-artifact@v3 + - uses: actions/download-artifact@v4.1.7 name: Download Build Artifacts with: name: wheel-py3_12-xpu From 495db3416534ab2d6f131807f340ed436a0f37f9 Mon Sep 17 00:00:00 2001 From: Shivam Raikundalia Date: Wed, 4 Sep 2024 16:34:10 +0000 Subject: [PATCH 0404/1018] [Profiler] Allow kwinputs to be non-string values (#134893) Summary: When we process keyword arguments in profiler today we assume that all values will be strings. This breaks HTA because it assumes that "stream" and other values similar to it will be ints. To fix this we will only put quotes around strings for ivalues. Test Plan: Add chrome trace export in unit tests and check that stream does not have quotes around it Differential Revision: D62056059 Pull Request resolved: https://github.com/pytorch/pytorch/pull/134893 Approved by: https://github.com/sanrise, https://github.com/izaitsevfb --- test/profiler/test_profiler.py | 2 +- torch/csrc/autograd/profiler_kineto.cpp | 3 ++- torch/csrc/profiler/util.cpp | 10 +++++++--- torch/csrc/profiler/util.h | 2 +- 4 files changed, 11 insertions(+), 6 deletions(-) diff --git a/test/profiler/test_profiler.py b/test/profiler/test_profiler.py index cf92a08d836e10..27818947ddc0bc 100644 --- a/test/profiler/test_profiler.py +++ b/test/profiler/test_profiler.py @@ -1661,7 +1661,7 @@ def test_profiler_op_event_kwargs(self): args = e["args"] self.assertTrue("stream" in args) self.assertTrue("grid" in args) - self.assertTrue(args["stream"] == "0") + self.assertTrue(args["stream"] == 0) self.assertTrue(args["grid"] == "lambda x : x + 1") self.assertTrue(args["debug"] == "None") diff --git a/torch/csrc/autograd/profiler_kineto.cpp b/torch/csrc/autograd/profiler_kineto.cpp index c7a7a9d049ec19..bb5a7958f286bd 100644 --- a/torch/csrc/autograd/profiler_kineto.cpp +++ b/torch/csrc/autograd/profiler_kineto.cpp @@ -261,7 +261,8 @@ struct AddGenericMetadata : public MetadataBase { // Add metadata for kwinputs if exist for (const auto& [key, val] : op_event.kwinputs_) { - addMetadata(key, ivalueToStr(val)); + bool isString = val.isString(); + addMetadata(key, ivalueToStr(val, isString)); } // Add extra metadata if any for (const auto& [key, val] : op_event.extra_meta_) { diff --git a/torch/csrc/profiler/util.cpp b/torch/csrc/profiler/util.cpp index 6e582b25c5ebb9..d03ffa5c27ac20 100644 --- a/torch/csrc/profiler/util.cpp +++ b/torch/csrc/profiler/util.cpp @@ -293,15 +293,19 @@ std::string strListToStr(const std::vector& types) { return "[" + rc + "]"; } } -std::string ivalueToStr(const c10::IValue& val) { +std::string ivalueToStr(const c10::IValue& val, bool isString) { std::stringstream ss; if (val.isNone()) { return "\"None\""; } else { ss.str(""); - ss << "\""; + if (isString) { + ss << "\""; + } ss << val; - ss << "\""; + if (isString) { + ss << "\""; + } std::string mystr = ss.str(); // A double quote can cause issues with the chrome tracing so force diff --git a/torch/csrc/profiler/util.h b/torch/csrc/profiler/util.h index f68d0e7b1aadb8..ba149e740b221f 100644 --- a/torch/csrc/profiler/util.h +++ b/torch/csrc/profiler/util.h @@ -92,7 +92,7 @@ TORCH_API std::string shapesToStr( TORCH_API std::string strListToStr(const std::vector& types); TORCH_API std::string inputOpIdsToStr( const std::list>& input_op_ids); -TORCH_API std::string ivalueToStr(const c10::IValue& val); +TORCH_API std::string ivalueToStr(const c10::IValue& val, bool isString); TORCH_API std::string ivalueListToStr(const std::vector& list); TORCH_API std::vector inputTypes(const at::RecordFunction& fn); From b50322508306e9efcb04b784d43376ac64606dde Mon Sep 17 00:00:00 2001 From: Laith Sakka Date: Tue, 3 Sep 2024 21:39:08 -0700 Subject: [PATCH 0405/1018] Redesign custom op functionlaization for better re-inplace (#134409) - The new implementation (auto_functionalized_v2) is enabled by default but can be disable using an inductor flag. - In export mode the old implementation is used. **Motiviation** Previous functionalization fails to re-inplace arguments when they are view over other tensors. see issue https://github.com/pytorch/pytorch/issues/131192 The new functionalization is easier to re-inplace for views. **A) Functionalizations pass** consider a program: ``` func(t) x = t[0] y = t[1] foo(x, y) # custom operator with x, y mutable return (x, y, t) ``` - To functionalize `foo` we generate a function that operates on the base tensors of the inputs; (x.base() and y.base()) and record how to regenerates the views out of the base for argument x by recording ```ViewInfo=(x.base(), x.size(), x.stride, x,storage_offset())``` - Due to some limitations on the torch.export arguments format, we have to generate alot of arguments, but this is something we can simplify in the future, for the example above we get the following function. ``` auto_functionalized = torch.ops.higher_order.auto_functionalized(torch.ops.mylib.foo.default, _x_base_index = 0, _x_size = (), _x_stride = (), _x_storage_offset = 0 , _y_base_index = 0,_y_size = (), _y_stride = (), _y_storage_offset = 1 , _all_bases = [arg0_1]) ``` - In the code above: - _all_bases[t]: refers to a unique set of bases for all foo arguments. - for each argument x we have _x_base_index, _x_size, _x_stride, _x_storage_offset that can be used to (1) regenerate x from _all_bases[_x_base_index] or a copy of a the base. - the output of auto_functionalized is foo output , followed by x tensors one for each base in _all_bases, that is a copy of the base tensor after observing the mutations of the all the arguments that are views of that base. - for each use of a base in _all_bases or a view of it , that are after the call to foo, replace it with a view of the new output for the function above after functionalization we get : ``` def forward(self, arg0_1: "f32[2][1]cpu"): auto_functionalized = torch.ops.higher_order.auto_functionalized(torch.ops.mylib.foo.default, _x_base_index = 0, _x_size = (), _x_stride = (), _x_storage_offset = 0, _y_base_index = 0, _y_size = (), _y_stride = (), _y_storage_offset = 1, _all_bases = [arg0_1]) getitem_1: "f32[2][1]cpu" = auto_functionalized[1]; auto_functionalized = None copy_: "f32[2][1]cpu" = torch.ops.aten.copy_.default(arg0_1, getitem_1); arg0_1 = copy_ = None # No stacktrace found for following nodes select_2: "f32[][]cpu" = torch.ops.aten.select.int(getitem_1, 0, 0) select_3: "f32[][]cpu" = torch.ops.aten.select.int(getitem_1, 0, 1); getitem_1 = None return (select_2, select_3) ``` **B) Semantics of auto_functionalize** The new semantics of auto_functionalize is as the following: 1. For each base in all_bases, copy the base and create all_bases copies. (if a base is inplaced we do not need to copy it) 2. For each arg, regenerate the arg from the copy of its base using the view information above. 3. return the original foo output followed by the new bases. **C) Re-inplace pass** since auto_functionalize not copy the bases, what we actually inplace is the bases. (run just like before but on the beses instead of args). 1. For each base b in _all_bases check if there is any use of base (or its aliases/views) after auto_functionalize (before its overwritten with a copy) if there is not any, then inplace it (avoid copying it in step 1 above). Pull Request resolved: https://github.com/pytorch/pytorch/pull/134409 Approved by: https://github.com/zou3519 --- test/inductor/test_auto_functionalize.py | 719 +++++++++++++++++- test/inductor/test_inplacing_pass.py | 272 ++++++- .../dispatch_and_compile_graph.py | 1 - torch/_functorch/partitioners.py | 9 +- torch/_higher_order_ops/auto_functionalize.py | 554 ++++++++++++-- torch/_inductor/config.py | 5 + torch/_inductor/fx_passes/post_grad.py | 33 +- torch/_inductor/fx_passes/reinplace.py | 74 +- torch/_subclasses/functional_tensor.py | 23 +- .../_remove_auto_functionalized_pass.py | 9 +- 10 files changed, 1517 insertions(+), 182 deletions(-) diff --git a/test/inductor/test_auto_functionalize.py b/test/inductor/test_auto_functionalize.py index a007c3451896e8..f705a4fe8a8eb2 100644 --- a/test/inductor/test_auto_functionalize.py +++ b/test/inductor/test_auto_functionalize.py @@ -4,6 +4,7 @@ import torch import torch._dynamo.testing +import torch._inductor.config as inductor_config import torch._inductor.test_case import torch.onnx.operators import torch.utils._pytree as pytree @@ -126,6 +127,7 @@ def test_can_auto_functionalize(self): for schema in expected_true: with torch.library._scoped_library("mylib", "FRAGMENT") as lib: torch.library.define("mylib::a", schema, lib=lib) + self.assertTrue( can_auto_functionalize(torch.ops.mylib.a.default), msg=schema ) @@ -139,7 +141,8 @@ def test_can_auto_functionalize(self): ) self.assertFalse(can_auto_functionalize(torch.ops.mylib.a)) - def test_auto_functionalize(self): + @torch._inductor.config.patch(enable_auto_functionalized_v2=False) + def test_auto_functionalize_old(self): with torch.library._scoped_library("mylib", "FRAGMENT") as lib: torch.library.define( "mylib::foo", @@ -162,9 +165,7 @@ def f(x, y, z, n): z = torch.randn(3) n = torch.randn(3) orig_args = (x, y, z, n) - compiled_args = pytree.tree_map_only(torch.Tensor, torch.clone, orig_args) - log_stream, ctx = logs_to_string( "torch._inductor.compile_fx", "post_grad_graphs" ) @@ -192,7 +193,8 @@ def forward(self, arg0_1: "f32[3][1]cpu", arg1_1: "f32[3][1]cpu", arg2_1: "f32[3 f(*eager_args) self.assertEqual(compiled_args, eager_args) - def test_auto_functionalize_with_returns(self): + @torch._inductor.config.patch(enable_auto_functionalized_v2=False) + def test_auto_functionalize_with_returns_old(self): with torch.library._scoped_library("mylib", "FRAGMENT") as lib: torch.library.define( "mylib::foo", @@ -237,14 +239,12 @@ def f(x, y, z, n): self.assertExpectedInline( post_grad_graphs, """\ -def forward(self, arg0_1: "f32[3][1]cpu", arg1_1: "f32[3][1]cpu", arg2_1: "f32[3][1]cpu", \ -arg3_1: "f32[3][1]cpu", arg4_1: "f32[3][1]cpu"): - # No stacktrace found for following nodes - foo_default = torch.ops.mylib.foo.default(arg4_1, [arg2_1, arg3_1], arg1_1, 2, arg0_1); \ -arg4_1 = arg2_1 = arg3_1 = arg1_1 = arg0_1 = None +def forward(self, arg0_1: "f32[3][1]cpu", arg1_1: "f32[3][1]cpu", arg2_1: "f32[3][1]cpu", arg3_1: "f32[3][1]cpu", arg4_1: "f32[3][1]cpu"): + foo_default = torch.ops.mylib.foo.default(arg4_1, [arg2_1, arg3_1], arg1_1, 2, arg0_1); arg4_1 = arg2_1 = arg3_1 = arg1_1 = arg0_1 = None getitem_4: "f32[3][1]cpu" = foo_default[0] getitem_5: "f32[3][1]cpu" = foo_default[1]; foo_default = None - return (getitem_4, getitem_5)""", + return (getitem_4, getitem_5)""", # noqa: B950 + ignore_comments=True, ) eager_args = pytree.tree_map_only(torch.Tensor, torch.clone, orig_args) @@ -253,37 +253,41 @@ def forward(self, arg0_1: "f32[3][1]cpu", arg1_1: "f32[3][1]cpu", arg2_1: "f32[3 self.assertEqual(compiled_out, eager_out) def test_auto_functionalize_on_view(self): - with torch.library._scoped_library("mylib", "FRAGMENT") as lib: - torch.library.define( - "mylib::foo", - "(Tensor(a!) x) -> ()", - tags=torch.Tag.pt2_compliant_tag, - lib=lib, - ) + for value in [True, False]: + with torch.library._scoped_library( + "mylib", "FRAGMENT" + ) as lib, inductor_config.patch({"enable_auto_functionalized_v2": value}): + torch.library.define( + "mylib::foo", + "(Tensor(a!) x) -> ()", + tags=torch.Tag.pt2_compliant_tag, + lib=lib, + ) - @torch.library.impl("mylib::foo", "cpu", lib=lib) - @torch._dynamo.disable - def foo_impl(x): - x_np = x.detach().numpy() # view - np.sin(x_np, out=x_np) - return + @torch.library.impl("mylib::foo", "cpu", lib=lib) + @torch._dynamo.disable + def foo_impl(x): + x_np = x.detach().numpy() # view + np.sin(x_np, out=x_np) + return - x = torch.randn(3) - expected = x.sin() - torch.ops.mylib.foo(x) - assert torch.allclose(x, expected) + x = torch.randn(3) + expected = x.sin() + torch.ops.mylib.foo(x) + assert torch.allclose(x, expected) - @torch.compile(backend="aot_eager_decomp_partition", fullgraph=True) - def f(x): - x = x.clone() - y = x[:] - torch.ops.mylib.foo(y) - return x + @torch.compile(backend="aot_eager_decomp_partition", fullgraph=True) + def f(x): + x = x.clone() + y = x[:] + torch.ops.mylib.foo(y) + return x - y = f(x) - self.assertEqual(y, x.sin()) + y = f(x) + self.assertEqual(y, x.sin()) - def test_auto_functionalize_optional(self): + @torch._inductor.config.patch(enable_auto_functionalized_v2=False) + def test_auto_functionalize_optional_old(self): with torch.library._scoped_library("mylib", "FRAGMENT") as lib: torch.library.define( "mylib::foo", @@ -308,14 +312,12 @@ def f(x, y, z, n): z = torch.randn(3) n = torch.randn(3) orig_args = (x, y, z, n) - compiled_args = pytree.tree_map_only(torch.Tensor, torch.clone, orig_args) log_stream, ctx = logs_to_string( "torch._inductor.compile_fx", "post_grad_graphs" ) with ctx(): torch.compile(f, backend="inductor", fullgraph=True)(*compiled_args) - if torch._dynamo.config.assume_static_by_default: post_grad_graphs = "\n".join( log_stream.getvalue().strip().split("\n")[3:] @@ -356,6 +358,647 @@ def f(x): x = torch.zeros(100, dtype=torch.int64) f(x) + @torch._inductor.config.patch(enable_auto_functionalized_v2=True) + def test_auto_functionalize_v2(self, _dynamic=False): + with torch.library._scoped_library("mylib", "FRAGMENT") as lib: + torch.library.define( + "mylib::foo", + "(Tensor(a!) x, Tensor[] y, Tensor(b!) z, SymInt w, Tensor n) -> ()", + tags=torch.Tag.pt2_compliant_tag, + lib=lib, + ) + + @torch.library.impl("mylib::foo", "cpu", lib=lib) + @torch._dynamo.disable + def foo_impl(x, y, z, w, n): + x.add_(y[0] + w) + z.add_(y[1] + n) + + def f(x, y, z, n): + torch.ops.mylib.foo(x, y, z, 2, n) + + x = torch.randn(3) + y = (torch.randn(3), torch.randn(3)) + z = torch.randn(3) + n = torch.randn(3) + orig_args = (x, y, z, n) + + compiled_args = pytree.tree_map_only(torch.Tensor, torch.clone, orig_args) + + log_stream, ctx = logs_to_string( + "torch._inductor.compile_fx", "post_grad_graphs" + ) + with ctx(): + torch.compile(f, backend="inductor", dynamic=_dynamic, fullgraph=True)( + *compiled_args + ) + + post_grad_graphs = "\n".join( + log_stream.getvalue().strip().split("\n")[3:] + ).strip() + + if torch._dynamo.config.assume_static_by_default: + if _dynamic: + self.assertExpectedInline( + post_grad_graphs, + """\ +def forward(self, arg0_1: "Sym(s0)", arg1_1: "f32[s0][1]cpu", arg2_1: "f32[s0][1]cpu", arg3_1: "f32[s0][1]cpu", arg4_1: "f32[s0][1]cpu", arg5_1: "f32[s0][1]cpu"): + foo_default = torch.ops.mylib.foo.default(arg5_1, [arg3_1, arg4_1], arg2_1, 2, arg1_1); arg3_1 = arg4_1 = arg1_1 = foo_default = None + copy_: "f32[s0][1]cpu" = torch.ops.aten.copy_.default(arg2_1, arg2_1); arg2_1 = copy_ = None + copy__1: "f32[s0][1]cpu" = torch.ops.aten.copy_.default(arg5_1, arg5_1); arg5_1 = copy__1 = None + return ()""", # noqa: B950 + ignore_comments=True, + ignore_empty_lines=True, + ) + else: + self.assertExpectedInline( + post_grad_graphs, + """\ +def forward(self, arg0_1: "f32[3][1]cpu", arg1_1: "f32[3][1]cpu", arg2_1: "f32[3][1]cpu", arg3_1: "f32[3][1]cpu", arg4_1: "f32[3][1]cpu"): + foo_default = torch.ops.mylib.foo.default(arg4_1, [arg2_1, arg3_1], arg1_1, 2, arg0_1); arg2_1 = arg3_1 = arg0_1 = foo_default = None + copy_: "f32[3][1]cpu" = torch.ops.aten.copy_.default(arg1_1, arg1_1); arg1_1 = copy_ = None + copy__1: "f32[3][1]cpu" = torch.ops.aten.copy_.default(arg4_1, arg4_1); arg4_1 = copy__1 = None + return ()""", # noqa: B950 + ignore_comments=True, + ignore_empty_lines=True, + ) + + eager_args = pytree.tree_map_only(torch.Tensor, torch.clone, orig_args) + f(*eager_args) + self.assertEqual(compiled_args, eager_args) + + def run_aot_eager(self, f, orig_args, _dynamic=False): + aot_eager_args = pytree.tree_map_only(torch.Tensor, torch.clone, orig_args) + + log_stream, ctx = logs_to_string( + "torch._functorch._aot_autograd.dispatch_and_compile_graph", "aot_graphs" + ) + + result = None + with ctx(): + result = torch.compile( + f, backend="aot_eager", fullgraph=True, dynamic=_dynamic + )(*aot_eager_args) + + graph = "\n".join(log_stream.getvalue().strip().split("\n")[4:]).strip() + return [aot_eager_args, result, graph] + + def run_inductor(self, f, orig_args, _dynamic=False): + compiled_args = pytree.tree_map_only(torch.Tensor, torch.clone, orig_args) + + log_stream, ctx = logs_to_string( + "torch._inductor.compile_fx", "post_grad_graphs" + ) + result = None + with ctx(): + result = torch.compile( + f, backend="inductor", fullgraph=True, dynamic=_dynamic + )(*compiled_args) + + graph = "\n".join(log_stream.getvalue().strip().split("\n")[3:]).strip() + + return [compiled_args, result, graph] + + @torch._inductor.config.patch(enable_auto_functionalized_v2=True) + def test_auto_functionalize_with_returns_v2(self): + with torch.library._scoped_library("mylib", "FRAGMENT") as lib: + torch.library.define( + "mylib::foo", + "(Tensor(a!) x, Tensor[] y, Tensor(b!) z, SymInt w, Tensor n) -> (Tensor, Tensor)", + tags=torch.Tag.pt2_compliant_tag, + lib=lib, + ) + + @torch.library.impl("mylib::foo", "cpu", lib=lib) + @torch._dynamo.disable + def foo_impl(x, y, z, w, n): + x.add_(y[0] + w) + z.add_(y[1] + n) + return y[0] + w, y[1] + n + + @torch.library.impl_abstract("mylib::foo", lib=lib) + def foo_abstract(x, y, z, w, n): + return y[0] + w, y[1] + n + + def f(x, y, z, n): + return torch.ops.mylib.foo(x, y, z, 2, n) + + x = torch.randn(3) + y = (torch.randn(3), torch.randn(3)) + z = torch.randn(3) + n = torch.randn(3) + orig_args = (x, y, z, n) + compiled_args = pytree.tree_map_only(torch.Tensor, torch.clone, orig_args) + log_stream, ctx = logs_to_string( + "torch._inductor.compile_fx", "post_grad_graphs" + ) + with ctx(): + compiled_out = torch.compile(f, backend="inductor", fullgraph=True)( + *compiled_args + ) + if torch._dynamo.config.assume_static_by_default: + post_grad_graphs = "\n".join( + log_stream.getvalue().strip().split("\n")[3:] + ).strip() + self.assertExpectedInline( + post_grad_graphs, + """\ +def forward(self, arg0_1: "f32[3][1]cpu", arg1_1: "f32[3][1]cpu", arg2_1: "f32[3][1]cpu", arg3_1: "f32[3][1]cpu", arg4_1: "f32[3][1]cpu"): + foo_default = torch.ops.mylib.foo.default(arg4_1, [arg2_1, arg3_1], arg1_1, 2, arg0_1); arg2_1 = arg3_1 = arg0_1 = None + getitem_4: "f32[3][1]cpu" = foo_default[0] + getitem_5: "f32[3][1]cpu" = foo_default[1]; foo_default = None + + copy_: "f32[3][1]cpu" = torch.ops.aten.copy_.default(arg1_1, arg1_1); arg1_1 = copy_ = None + copy__1: "f32[3][1]cpu" = torch.ops.aten.copy_.default(arg4_1, arg4_1); arg4_1 = copy__1 = None + return (getitem_4, getitem_5)""", # noqa: B950 + ignore_comments=True, + ignore_empty_lines=True, + ) + + eager_args = pytree.tree_map_only(torch.Tensor, torch.clone, orig_args) + eager_out = f(*eager_args) + self.assertEqual(compiled_args, eager_args) + self.assertEqual(compiled_out, eager_out) + + # foo takes two inputs that are not views. + @torch._inductor.config.patch(enable_auto_functionalized_v2=True) + def test_auto_functionalize_extra1(self, _dynamic=False): + with torch.library._scoped_library("mylib", "FRAGMENT") as lib: + torch.library.define( + "mylib::foo", + "(Tensor(a!) x, Tensor(b!) y) -> ()", + tags=torch.Tag.pt2_compliant_tag, + lib=lib, + ) + + @torch.library.impl("mylib::foo", "cpu", lib=lib) + @torch._dynamo.disable + def foo_impl(x, y): + x.sin_() + y.sin_() + + def f(x, y): + torch.ops.mylib.foo(x, y) + return x + y + + orig_args = (torch.randn(2), torch.randn(2)) + + [aot_eager_args, result1, graph_aot] = self.run_aot_eager( + f, orig_args, _dynamic + ) + [inductor_args, result2, graph_inductor] = self.run_inductor( + f, orig_args, _dynamic + ) + eager_args = pytree.tree_map_only(torch.Tensor, torch.clone, orig_args) + result3 = f(*eager_args) + + self.assertEqual(inductor_args, eager_args) + self.assertEqual(inductor_args, aot_eager_args) + + self.assertEqual(result3, result1) + self.assertEqual(result3, result2) + + if torch._dynamo.config.assume_static_by_default: + if _dynamic: + self.assertExpectedInline( + graph_aot, + """\ +def forward(self, arg0_1: "Sym(s0)", arg1_1: "f32[s0][1]cpu", arg2_1: "f32[s0][1]cpu"): + auto_functionalized_v2 = torch.ops.higher_order.auto_functionalized_v2(torch.ops.mylib.foo.default, _x_base_index = 0, _y_base_index = 1, _all_bases = [arg2_1, arg1_1]) + getitem_1: "f32[s0][1]cpu" = auto_functionalized_v2[1] + getitem_2: "f32[s0][1]cpu" = auto_functionalized_v2[2]; auto_functionalized_v2 = None + add: "f32[s0][1]cpu" = torch.ops.aten.add.Tensor(getitem_1, getitem_2) + copy_: "f32[s0][1]cpu" = torch.ops.aten.copy_.default(arg1_1, getitem_2); arg1_1 = getitem_2 = copy_ = None + copy__1: "f32[s0][1]cpu" = torch.ops.aten.copy_.default(arg2_1, getitem_1); arg2_1 = getitem_1 = copy__1 = None + return (add,)""", # noqa: B950 + ignore_comments=True, + ignore_empty_lines=True, + ) + else: + self.assertExpectedInline( + graph_aot, + """\ +def forward(self, arg0_1: "f32[2][1]cpu", arg1_1: "f32[2][1]cpu"): + auto_functionalized_v2 = torch.ops.higher_order.auto_functionalized_v2(torch.ops.mylib.foo.default, _x_base_index = 0, _y_base_index = 1, _all_bases = [arg1_1, arg0_1]) + getitem_1: "f32[2][1]cpu" = auto_functionalized_v2[1] + getitem_2: "f32[2][1]cpu" = auto_functionalized_v2[2]; auto_functionalized_v2 = None + add: "f32[2][1]cpu" = torch.ops.aten.add.Tensor(getitem_1, getitem_2) + copy_: "f32[2][1]cpu" = torch.ops.aten.copy_.default(arg0_1, getitem_2); arg0_1 = getitem_2 = copy_ = None + copy__1: "f32[2][1]cpu" = torch.ops.aten.copy_.default(arg1_1, getitem_1); arg1_1 = getitem_1 = copy__1 = None + return (add,)""", # noqa: B950 + ignore_comments=True, + ignore_empty_lines=True, + ) + + if torch._dynamo.config.assume_static_by_default: + if _dynamic: + self.assertExpectedInline( + graph_inductor, + """\ +def forward(self, arg0_1: "Sym(s0)", arg1_1: "f32[s0][1]cpu", arg2_1: "f32[s0][1]cpu"): + foo_default = torch.ops.mylib.foo.default(arg2_1, arg1_1); foo_default = None + add: "f32[s0][1]cpu" = torch.ops.aten.add.Tensor(arg2_1, arg1_1) + copy_: "f32[s0][1]cpu" = torch.ops.aten.copy_.default(arg1_1, arg1_1); arg1_1 = copy_ = None + copy__1: "f32[s0][1]cpu" = torch.ops.aten.copy_.default(arg2_1, arg2_1); arg2_1 = copy__1 = None + return (add,)""", + ignore_comments=True, + ignore_empty_lines=True, + ) + else: + self.assertExpectedInline( + graph_inductor, + """\ +def forward(self, arg0_1: "f32[2][1]cpu", arg1_1: "f32[2][1]cpu"): + foo_default = torch.ops.mylib.foo.default(arg1_1, arg0_1); foo_default = None + add: "f32[2][1]cpu" = torch.ops.aten.add.Tensor(arg1_1, arg0_1) + copy_: "f32[2][1]cpu" = torch.ops.aten.copy_.default(arg0_1, arg0_1); arg0_1 = copy_ = None + copy__1: "f32[2][1]cpu" = torch.ops.aten.copy_.default(arg1_1, arg1_1); arg1_1 = copy__1 = None + return (add,)""", + ignore_comments=True, + ignore_empty_lines=True, + ) + + # foo takes two views on the same input, function does not have return. + @torch._inductor.config.patch(enable_auto_functionalized_v2=True) + def test_auto_functionalize_extra2(self, _dynamic=False): + with torch.library._scoped_library("mylib", "FRAGMENT") as lib: + torch.library.define( + "mylib::foo", + "(Tensor(a!) x, Tensor(b!) y) -> ()", + tags=torch.Tag.pt2_compliant_tag, + lib=lib, + ) + + @torch.library.impl("mylib::foo", "cpu", lib=lib) + @torch._dynamo.disable + def foo_impl(x, y): + x.sin_() + y.sin_() + + def f(x): + a = x[0] + b = x[1] + torch.ops.mylib.foo(a, b) + return + + orig_args = [torch.randn(2)] + + [aot_eager_args, result1, graph_aot] = self.run_aot_eager( + f, orig_args, _dynamic + ) + [inductor_args, result2, graph_inductor] = self.run_inductor( + f, orig_args, _dynamic + ) + eager_args = pytree.tree_map_only(torch.Tensor, torch.clone, orig_args) + result3 = f(*eager_args) + + self.assertEqual(inductor_args, eager_args) + self.assertEqual(inductor_args, aot_eager_args) + + self.assertEqual(result3, result1) + self.assertEqual(result3, result2) + + if torch._dynamo.config.assume_static_by_default: + if _dynamic: + self.assertExpectedInline( + graph_aot, + """\ +def forward(self, arg0_1: "Sym(s0)", arg1_1: "f32[s0][1]cpu"): + auto_functionalized_v2 = torch.ops.higher_order.auto_functionalized_v2(torch.ops.mylib.foo.default, _x_base_index = 0, _x_size = (), _x_stride = (), _x_storage_offset = 0, _y_base_index = 0, _y_size = (), _y_stride = (), _y_storage_offset = 1, _all_bases = [arg1_1]) + getitem_1: "f32[s0][1]cpu" = auto_functionalized_v2[1]; auto_functionalized_v2 = None + copy_: "f32[s0][1]cpu" = torch.ops.aten.copy_.default(arg1_1, getitem_1); arg1_1 = getitem_1 = copy_ = None + return ()""", # noqa: B950 + ignore_comments=True, + ignore_empty_lines=True, + ) + else: + self.assertExpectedInline( + graph_aot, + """\ +def forward(self, arg0_1: "f32[2][1]cpu"): + auto_functionalized_v2 = torch.ops.higher_order.auto_functionalized_v2(torch.ops.mylib.foo.default, _x_base_index = 0, _x_size = (), _x_stride = (), _x_storage_offset = 0, _y_base_index = 0, _y_size = (), _y_stride = (), _y_storage_offset = 1, _all_bases = [arg0_1]) + getitem_1: "f32[2][1]cpu" = auto_functionalized_v2[1]; auto_functionalized_v2 = None + copy_: "f32[2][1]cpu" = torch.ops.aten.copy_.default(arg0_1, getitem_1); arg0_1 = getitem_1 = copy_ = None + return ()""", # noqa: B950 + ignore_comments=True, + ignore_empty_lines=True, + ) + + # 2. Run with inductor backend + + if torch._dynamo.config.assume_static_by_default: + if _dynamic: + self.assertExpectedInline( + graph_inductor, + """\ +def forward(self, arg0_1: "Sym(s0)", arg1_1: "f32[s0][1]cpu"): + as_strided_default: "f32[][]cpu" = torch.ops.aten.as_strided.default(arg1_1, [], [], 0) + as_strided_default_1: "f32[][]cpu" = torch.ops.aten.as_strided.default(arg1_1, [], [], 1) + foo_default = torch.ops.mylib.foo.default(as_strided_default, as_strided_default_1); as_strided_default = as_strided_default_1 = foo_default = None + copy_: "f32[s0][1]cpu" = torch.ops.aten.copy_.default(arg1_1, arg1_1); arg1_1 = copy_ = None + return ()""", # noqa: B950 + ignore_comments=True, + ignore_empty_lines=True, + ) + else: + self.assertExpectedInline( + graph_inductor, + """\ +def forward(self, arg0_1: "f32[2][1]cpu"): + as_strided_default: "f32[][]cpu" = torch.ops.aten.as_strided.default(arg0_1, [], [], 0) + as_strided_default_1: "f32[][]cpu" = torch.ops.aten.as_strided.default(arg0_1, [], [], 1) + foo_default = torch.ops.mylib.foo.default(as_strided_default, as_strided_default_1); as_strided_default = as_strided_default_1 = foo_default = None + copy_: "f32[2][1]cpu" = torch.ops.aten.copy_.default(arg0_1, arg0_1); arg0_1 = copy_ = None + return ()""", # noqa: B950 + ignore_comments=True, + ignore_empty_lines=True, + ) + + # foo takes two views on the same input, function returns both views and the input + @torch._inductor.config.patch(enable_auto_functionalized_v2=True) + def test_auto_functionalize_extra3(self): + with torch.library._scoped_library("mylib", "FRAGMENT") as lib: + torch.library.define( + "mylib::foo", + "(Tensor(a!) x, Tensor(b!) y) -> ()", + tags=torch.Tag.pt2_compliant_tag, + lib=lib, + ) + + @torch.library.impl("mylib::foo", "cpu", lib=lib) + @torch._dynamo.disable + def foo_impl(x, y): + x.sin_() + y.sin_() + + def f(x): + a = x[0] + b = x[1] + torch.ops.mylib.foo(a, b) + return (a, b, x) + + orig_args = [torch.randn(2)] + + [aot_eager_args, result1, graph_aot] = self.run_aot_eager(f, orig_args) + [inductor_args, result2, graph_inductor] = self.run_inductor(f, orig_args) + eager_args = pytree.tree_map_only(torch.Tensor, torch.clone, orig_args) + result3 = f(*eager_args) + + self.assertEqual(inductor_args, eager_args) + self.assertEqual(inductor_args, aot_eager_args) + + self.assertEqual(result3, result1) + self.assertEqual(result3, result2) + + if torch._dynamo.config.assume_static_by_default: + self.assertExpectedInline( + graph_aot, + """\ +def forward(self, arg0_1: "f32[2][1]cpu"): + auto_functionalized_v2 = torch.ops.higher_order.auto_functionalized_v2(torch.ops.mylib.foo.default, _x_base_index = 0, _x_size = (), _x_stride = (), _x_storage_offset = 0, _y_base_index = 0, _y_size = (), _y_stride = (), _y_storage_offset = 1, _all_bases = [arg0_1]) + getitem_1: "f32[2][1]cpu" = auto_functionalized_v2[1]; auto_functionalized_v2 = None + copy_: "f32[2][1]cpu" = torch.ops.aten.copy_.default(arg0_1, getitem_1); arg0_1 = copy_ = None + select_2: "f32[][]cpu" = torch.ops.aten.select.int(getitem_1, 0, 0) + select_3: "f32[][]cpu" = torch.ops.aten.select.int(getitem_1, 0, 1); getitem_1 = None + return (select_2, select_3)""", # noqa: B950 + ignore_comments=True, + ignore_empty_lines=True, + ) + + # 2. Run with inductor backend + + if torch._dynamo.config.assume_static_by_default: + self.assertExpectedInline( + graph_inductor, + """\ +def forward(self, arg0_1: "f32[2][1]cpu"): + as_strided_default: "f32[][]cpu" = torch.ops.aten.as_strided.default(arg0_1, [], [], 0) + as_strided_default_1: "f32[][]cpu" = torch.ops.aten.as_strided.default(arg0_1, [], [], 1) + foo_default = torch.ops.mylib.foo.default(as_strided_default, as_strided_default_1); as_strided_default = as_strided_default_1 = foo_default = None + copy_: "f32[2][1]cpu" = torch.ops.aten.copy_.default(arg0_1, arg0_1); copy_ = None + select_2: "f32[][]cpu" = torch.ops.aten.select.int(arg0_1, 0, 0) + select_3: "f32[][]cpu" = torch.ops.aten.select.int(arg0_1, 0, 1); arg0_1 = None + return (select_2, select_3)""", # noqa: B950 + ignore_comments=True, + ignore_empty_lines=True, + ) + + # foo takes a mutable list with views in addition to other args. + @torch._inductor.config.patch(enable_auto_functionalized_v2=True) + def test_auto_functionalize_extra4(self): + with torch.library._scoped_library("mylib", "FRAGMENT") as lib: + torch.library.define( + "mylib::foo", + "(Tensor(a!) x, Tensor(b!)[] y) -> ()", + tags=torch.Tag.pt2_compliant_tag, + lib=lib, + ) + + @torch.library.impl("mylib::foo", "cpu", lib=lib) + @torch._dynamo.disable + def foo_impl(x, y): + x.sin_() + y[0].sin_() + + def f(x, y, z): + a = x[0] + b = z[0] + torch.ops.mylib.foo(a, [b, y]) + + orig_args = [torch.randn(2), torch.randn(2), torch.randn(2)] + + [aot_eager_args, result1, graph_aot] = self.run_aot_eager(f, orig_args) + [inductor_args, result2, graph_inductor] = self.run_inductor(f, orig_args) + eager_args = pytree.tree_map_only(torch.Tensor, torch.clone, orig_args) + result3 = f(*eager_args) + + self.assertEqual(inductor_args[2], eager_args[2]) + self.assertEqual(inductor_args, aot_eager_args) + + self.assertEqual(result3, result1) + self.assertEqual(result3, result2) + + if torch._dynamo.config.assume_static_by_default: + self.assertExpectedInline( + graph_aot, + """\ +def forward(self, arg0_1: "f32[2][1]cpu", arg1_1: "f32[2][1]cpu", arg2_1: "f32[2][1]cpu"): + auto_functionalized_v2 = torch.ops.higher_order.auto_functionalized_v2(torch.ops.mylib.foo.default, _x_base_index = 0, _x_size = (), _x_stride = (), _x_storage_offset = 0, _y_length = 2, _y_0_base_index = 1, _y_0_size = (), _y_0_stride = (), _y_0_storage_offset = 0, _y_1_base_index = 2, _all_bases = [arg0_1, arg1_1, arg2_1]) + getitem_1: "f32[2][1]cpu" = auto_functionalized_v2[1] + getitem_2: "f32[2][1]cpu" = auto_functionalized_v2[2] + getitem_3: "f32[2][1]cpu" = auto_functionalized_v2[3]; auto_functionalized_v2 = None + copy_: "f32[2][1]cpu" = torch.ops.aten.copy_.default(arg0_1, getitem_1); arg0_1 = getitem_1 = copy_ = None + copy__1: "f32[2][1]cpu" = torch.ops.aten.copy_.default(arg1_1, getitem_2); arg1_1 = getitem_2 = copy__1 = None + copy__2: "f32[2][1]cpu" = torch.ops.aten.copy_.default(arg2_1, getitem_3); arg2_1 = getitem_3 = copy__2 = None + return ()""", # noqa: B950 + ignore_comments=True, + ignore_empty_lines=True, + ) + + # 2. Run with inductor backend + + if torch._dynamo.config.assume_static_by_default: + self.assertExpectedInline( + graph_inductor, + """\ +def forward(self, arg0_1: "f32[2][1]cpu", arg1_1: "f32[2][1]cpu", arg2_1: "f32[2][1]cpu"): + as_strided_default: "f32[][]cpu" = torch.ops.aten.as_strided.default(arg0_1, [], [], 0) + as_strided_default_1: "f32[][]cpu" = torch.ops.aten.as_strided.default(arg1_1, [], [], 0) + foo_default = torch.ops.mylib.foo.default(as_strided_default, [as_strided_default_1, arg2_1]); as_strided_default = as_strided_default_1 = foo_default = None + copy_: "f32[2][1]cpu" = torch.ops.aten.copy_.default(arg0_1, arg0_1); arg0_1 = copy_ = None + copy__1: "f32[2][1]cpu" = torch.ops.aten.copy_.default(arg1_1, arg1_1); arg1_1 = copy__1 = None + copy__2: "f32[2][1]cpu" = torch.ops.aten.copy_.default(arg2_1, arg2_1); arg2_1 = copy__2 = None + return ()""", # noqa: B950 + ignore_comments=True, + ignore_empty_lines=True, + ) + + @torch._inductor.config.patch(enable_auto_functionalized_v2=True) + def test_auto_functionalize_optional_v2(self): + with torch.library._scoped_library("mylib", "FRAGMENT") as lib: + torch.library.define( + "mylib::foo", + "(Tensor(a!)? x, Tensor[] y, Tensor(b!)? z, SymInt w, Tensor n) -> ()", + tags=torch.Tag.pt2_compliant_tag, + lib=lib, + ) + + @torch.library.impl("mylib::foo", "cpu", lib=lib) + @torch._dynamo.disable + def foo_impl(x, y, z, w, n): + if x is not None: + x.add_(y[0] + w) + if z is not None: + z.add_(y[1] + n) + + def f(x, y, z, n): + torch.ops.mylib.foo(x, y, z, 2, n) + + x = None + y = (torch.randn(3), torch.randn(3)) + z = torch.randn(3) + n = torch.randn(3) + orig_args = (x, y, z, n) + + compiled_args = pytree.tree_map_only(torch.Tensor, torch.clone, orig_args) + log_stream, ctx = logs_to_string( + "torch._inductor.compile_fx", "post_grad_graphs" + ) + with ctx(): + torch.compile(f, backend="inductor", fullgraph=True)(*compiled_args) + + if torch._dynamo.config.assume_static_by_default: + post_grad_graphs = "\n".join( + log_stream.getvalue().strip().split("\n")[3:] + ).strip() + self.assertExpectedInline( + post_grad_graphs, + """\ +def forward(self, arg0_1: "f32[3][1]cpu", arg1_1: "f32[3][1]cpu", arg2_1: "f32[3][1]cpu", arg3_1: "f32[3][1]cpu"): + foo_default = torch.ops.mylib.foo.default(None, [arg2_1, arg3_1], arg1_1, 2, arg0_1); arg2_1 = arg3_1 = arg0_1 = foo_default = None + copy_: "f32[3][1]cpu" = torch.ops.aten.copy_.default(arg1_1, arg1_1); arg1_1 = copy_ = None + return ()""", # noqa: B950 + ignore_comments=True, + ignore_empty_lines=True, + ) + + eager_args = pytree.tree_map_only(torch.Tensor, torch.clone, orig_args) + f(*eager_args) + self.assertEqual(compiled_args, eager_args) + + @torch._inductor.config.patch(enable_auto_functionalized_v2=False) + def test_inference_mode1_v2(self): + with torch.inference_mode(): + self.test_auto_functionalize_extra1() + + # In inference mode we do not support inplacing views yet. + @torch._inductor.config.patch(enable_auto_functionalized_v2=True) + def test_inference_mode2_v2(self): + with torch.inference_mode(), torch.library._scoped_library( + "mylib", "FRAGMENT" + ) as lib: + torch.library.define( + "mylib::foo", + "(Tensor(a!) x, Tensor(b!) y) -> ()", + tags=torch.Tag.pt2_compliant_tag, + lib=lib, + ) + + @torch.library.impl("mylib::foo", "cpu", lib=lib) + @torch._dynamo.disable + def foo_impl(x, y): + x.sin_() + y.sin_() + + def f(x): + a = x[0] + b = x[1] + torch.ops.mylib.foo(a, b) + return + + orig_args = [torch.randn(2)] + + [aot_eager_args, result1, graph_aot] = self.run_aot_eager(f, orig_args) + [inductor_args, result2, graph_inductor] = self.run_inductor(f, orig_args) + eager_args = pytree.tree_map_only(torch.Tensor, torch.clone, orig_args) + result3 = f(*eager_args) + + self.assertEqual(inductor_args, eager_args) + self.assertEqual(inductor_args, aot_eager_args) + + self.assertEqual(result3, result1) + self.assertEqual(result3, result2) + + if torch._dynamo.config.assume_static_by_default: + self.assertExpectedInline( + graph_aot, + """\ +def forward(self, arg0_1: "f32[2][1]cpu"): + select: "f32[][]cpu" = torch.ops.aten.select.int(arg0_1, 0, 0) + select_1: "f32[][]cpu" = torch.ops.aten.select.int(arg0_1, 0, 1) + auto_functionalized_v2 = torch.ops.higher_order.auto_functionalized_v2(torch.ops.mylib.foo.default, _x_base_index = 0, _y_base_index = 1, _all_bases = [select, select_1]); select = select_1 = None + getitem_1: "f32[][]cpu" = auto_functionalized_v2[1] + getitem_2: "f32[][]cpu" = auto_functionalized_v2[2]; auto_functionalized_v2 = None + select_scatter: "f32[2][1]cpu" = torch.ops.aten.select_scatter.default(arg0_1, getitem_1, 0, 0); getitem_1 = None + select_scatter_1: "f32[2][1]cpu" = torch.ops.aten.select_scatter.default(select_scatter, getitem_2, 0, 1); select_scatter = getitem_2 = None + copy_: "f32[2][1]cpu" = torch.ops.aten.copy_.default(arg0_1, select_scatter_1); arg0_1 = select_scatter_1 = copy_ = None + return ()""", # noqa: B950 + ignore_comments=True, + ignore_empty_lines=True, + ) + + # 2. Run with inductor backend + + if torch._dynamo.config.assume_static_by_default: + self.assertExpectedInline( + graph_inductor, + """\ +def forward(self, arg0_1: "f32[2][1]cpu"): + select: "f32[][]cpu" = torch.ops.aten.select.int(arg0_1, 0, 0) + select_1: "f32[][]cpu" = torch.ops.aten.select.int(arg0_1, 0, 1) + clone_default: "f32[][]cpu" = torch.ops.aten.clone.default(select); select = None + clone_default_1: "f32[][]cpu" = torch.ops.aten.clone.default(select_1); select_1 = None + foo_default = torch.ops.mylib.foo.default(clone_default, clone_default_1); foo_default = None + select_scatter_default: "f32[2][1]cpu" = torch.ops.aten.select_scatter.default(arg0_1, clone_default, 0, 0); clone_default = None + select_scatter_default_1: "f32[2][1]cpu" = torch.ops.aten.select_scatter.default(select_scatter_default, clone_default_1, 0, 1); select_scatter_default = clone_default_1 = None + copy_: "f32[2][1]cpu" = torch.ops.aten.copy_.default(arg0_1, select_scatter_default_1); arg0_1 = select_scatter_default_1 = copy_ = None + return ()""", # noqa: B950 + ignore_comments=True, + ignore_empty_lines=True, + ) + + @torch._inductor.config.patch(enable_auto_functionalized_v2=True) + def test_dynamic_v2(self): + self.test_auto_functionalize_v2(_dynamic=True) + + @torch._inductor.config.patch(enable_auto_functionalized_v2=True) + def test_dynamic2_v2(self): + self.test_auto_functionalize_extra1(_dynamic=True) + + @torch._inductor.config.patch(enable_auto_functionalized_v2=True) + def test_dynamic3_v2(self): + self.test_auto_functionalize_extra2(_dynamic=True) + if __name__ == "__main__": from torch._inductor.test_case import run_tests diff --git a/test/inductor/test_inplacing_pass.py b/test/inductor/test_inplacing_pass.py index 6164b4e975ee87..309e793667161f 100644 --- a/test/inductor/test_inplacing_pass.py +++ b/test/inductor/test_inplacing_pass.py @@ -3,10 +3,14 @@ from typing import List import torch +import torch._inductor.config as inductor_config from functorch import make_fx from torch import Tensor from torch._dynamo.utils import counters -from torch._higher_order_ops.auto_functionalize import auto_functionalized +from torch._higher_order_ops.auto_functionalize import ( + auto_functionalized, + auto_functionalized_v2, +) from torch._inductor.fx_passes.reinplace import reinplace_inplaceable_ops_core from torch._inductor.test_case import run_tests, TestCase as InductorTestCase from torch.testing._internal.common_utils import ( @@ -70,6 +74,11 @@ def sin_triton(x, out): return +@torch.library.custom_op("test_view::boo", mutates_args={"x"}) +def boo(x: torch.Tensor) -> None: + x.sin_() + + class TestReinplacingPassCorrectness(InductorTestCase): def setUp(self): counters.clear() @@ -124,7 +133,7 @@ def f(x, y): self._test(f) - def test_counters(self): + def test_counters_functionalize_old(self): counters.clear() def f(x): @@ -143,21 +152,176 @@ def f(x): # IF THIS NUMBER GOES TO ZERO, PLEASE FIND ANOTHER EXAMPLE self.assertEqual(num_reinplacing_failures(), 1) + def test_counters_functionalize_v2(self): + counters.clear() + + def f(x): + out = torch.empty_like(x) + _, new_out = auto_functionalized_v2( + sin._opoverload, + x=x, + _result_base_index=0, + _result_size=(3,), + _result_stride=(1,), + _result_storage_offset=0, + _all_bases=[out], + ) + y = out * new_out + return new_out, y + + x = torch.randn(3, device=device) + gm = make_fx(f, tracing_mode="fake")(x) + reinplace_inplaceable_ops_core(gm.graph) + + # We shouldn't have been able to reinplace `out` because it was used after + # auto_functionalized. Note that this usually doesn't happen in practice; + # we're artificially creating this example to test the counter. + # IF THIS NUMBER GOES TO ZERO, PLEASE FIND ANOTHER EXAMPLE + self.assertEqual(num_reinplacing_failures(), 1) + + def get_not_inplaced_count(self, graph): + counter = 0 + auto_functionalized_found = False + for node in graph.nodes: + if (node.target == torch.ops.higher_order.auto_functionalized) or ( + node.target == torch.ops.higher_order.auto_functionalized_v2 + ): + auto_functionalized_found = True + counter += len(node.meta["only_clone_these_tensors"]) + assert auto_functionalized_found + return counter + + def test_view_inplaced_functionalize_v2(self): + def f(arg0_1): + select = torch.ops.aten.select.int(arg0_1, 0, 0) + auto_functionalized = auto_functionalized_v2( + torch.ops.test_view.boo.default, + _x_base_index=0, + _x_size=(3,), + _x_stride=(1,), + _x_storage_offset=0, + _all_bases=[arg0_1], + ) + getitem_1 = auto_functionalized[1] + copy_ = torch.ops.aten.copy_.default(arg0_1, getitem_1) + return () + + x1 = torch.randn(3, device=device) + gm = make_fx(f, tracing_mode="fake")(x1) + reinplace_inplaceable_ops_core(gm.graph) + + self.assertEqual(self.get_not_inplaced_count(gm.graph), 0) + + # introduce a view another_view that is used `after` the copy + def test_view_inplaced2_functionalize_v2(self): + def f(arg0_1): + select = torch.ops.aten.select.int(arg0_1, 0, 0) + another_view = arg0_1[2] + auto_functionalized = auto_functionalized_v2( + torch.ops.test_view.boo.default, + _x_base_index=0, + _x_size=(3,), + _x_stride=(1,), + _x_storage_offset=0, + _all_bases=[arg0_1], + ) + getitem_1 = auto_functionalized[1] + copy_ = torch.ops.aten.copy_.default(arg0_1, getitem_1) + return another_view + + x1 = torch.randn(3, device=device) + gm = make_fx(f, tracing_mode="fake")(x1) + reinplace_inplaceable_ops_core(gm.graph) + + self.assertEqual(self.get_not_inplaced_count(gm.graph), 0) + + # introduce a view another_view that is used `before` the copy + def test_views_not_inplaced_functionalize_v2(self): + def f(arg0_1): + select = torch.ops.aten.select.int(arg0_1, 0, 0) + another_view = arg0_1[2] + auto_functionalized = auto_functionalized_v2( + torch.ops.test_view.boo.default, + _x_base_index=0, + _x_size=(3,), + _x_stride=(1,), + _x_storage_offset=0, + _all_bases=[arg0_1], + ) + getitem_1 = auto_functionalized[1] + use_another_view = another_view * 10 + copy_ = torch.ops.aten.copy_.default(arg0_1, getitem_1) + return use_another_view + + x1 = torch.randn(3, device=device) + gm = make_fx(f, tracing_mode="fake")(x1) + reinplace_inplaceable_ops_core(gm.graph) + + self.assertEqual(self.get_not_inplaced_count(gm.graph), 1) + + # a view over input without copy node, inplace not allowed + def test_views_not_inplaced2_functionalize_v2(self): + def f(arg0_1): + select = torch.ops.aten.select.int(arg0_1, 0, 0) + another_view = arg0_1[2] + auto_functionalized = auto_functionalized_v2( + torch.ops.test_view.boo.default, + _x_base_index=0, + _x_size=(3,), + _x_stride=(1,), + _x_storage_offset=0, + _all_bases=[arg0_1], + ) + getitem_1 = auto_functionalized[1] + return + + x1 = torch.randn(3, device=device) + gm = make_fx(f, tracing_mode="fake")(x1) + reinplace_inplaceable_ops_core(gm.graph) + + self.assertEqual(self.get_not_inplaced_count(gm.graph), 1) + + # no copy nodes, view over local, with a use for another view + def test_views_not_inplaced3_functionalize_v2(self): + def f(arg0_1): + a = torch.ones(10) + another_view = a[2] + auto_functionalized = auto_functionalized_v2( + torch.ops.test_view.boo.default, + _x_base_index=0, + _x_size=(), + _x_stride=(), + _x_storage_offset=0, + _all_bases=[a], + ) + getitem_1 = auto_functionalized[1] + return another_view + + x1 = torch.randn(3, device=device) + gm = make_fx(f, tracing_mode="fake")(x1) + reinplace_inplaceable_ops_core(gm.graph) + + self.assertEqual(self.get_not_inplaced_count(gm.graph), 1) + def test_multi_output_intermediate(self): for requires_grad in [False, True]: - counters.clear() - - def f(x): - out1 = torch.empty_like(x) - out2 = torch.empty_like(x) - sin_cos(x, out1, out2) - return out1, out2, x**2 - - x = torch.randn(3, device=device, requires_grad=requires_grad) - res1, res2, _ = torch.compile(f)(x) - self.assertEqual(res1, x.sin()) - self.assertEqual(res2, x.cos()) - self.assertEqual(num_reinplacing_failures(), 0) + for enable_v2 in [False, True]: + with inductor_config.patch( + {"enable_auto_functionalized_v2": enable_v2} + ): + counters.clear() + + def f(x): + out1 = torch.empty_like(x) + out2 = torch.empty_like(x) + sin_cos(x, out1, out2) + return out1, out2, x**2 + + x = torch.randn(3, device=device, requires_grad=requires_grad) + res1, res2, _ = torch.compile(f)(x) + self.assertEqual(res1, x.sin()) + self.assertEqual(res2, x.cos()) + self.assertEqual(num_reinplacing_failures(), 0) def test_multiple_mutations(self): counters.clear() @@ -190,31 +354,59 @@ def f(x): self.assertEqual(result, x.sin().sin().sin()) self.assertEqual(num_reinplacing_failures(), 0) - def test_lists(self): - @torch.library.custom_op("mylib::mutate_op", mutates_args={"y"}) - def mutate_op(y: List[Tensor]) -> None: - y[0].add_(2) - y[1].add_(3) - - @torch.compile(fullgraph=True, dynamic=False, backend="inductor") - def f(b): - mutate_op([b[0], b[1]]) - - x1 = torch.tensor([0.3, 0.4], device=device) - log_stream, ctx = logs_to_string( - "torch._inductor.compile_fx", "post_grad_graphs" - ) - with ctx(): - torch.compile(f, backend="inductor", fullgraph=True)(x1) - post_grad_graphs = "\n".join( - log_stream.getvalue().strip().split("\n")[3:] - ).strip() - - # Can't reinplace on views yet (1 for the "entire list" failing to reinplace) - self.assertEqual(num_reinplacing_failures(), 1) - - # Both list inputs failed to reinplace. So we should have emitted clones for them. - self.assertEqual(post_grad_graphs.count("aten.clone"), 2) + def test_lists_functionalize_v2(self): + with inductor_config.patch({"enable_auto_functionalized_v2": True}): + + @torch.library.custom_op("mylib::mutate_op", mutates_args={"y"}) + def mutate_op(y: List[Tensor]) -> None: + y[0].add_(2) + y[1].add_(3) + + @torch.compile(fullgraph=True, dynamic=False, backend="inductor") + def f(b): + mutate_op([b[0], b[1]]) + + x1 = torch.tensor([0.3, 0.4], device=device) + log_stream, ctx = logs_to_string( + "torch._inductor.compile_fx", "post_grad_graphs" + ) + with ctx(): + torch.compile(f, backend="inductor", fullgraph=True)(x1) + post_grad_graphs = "\n".join( + log_stream.getvalue().strip().split("\n")[3:] + ).strip() + + # We can inplace the base y. no clones emitted. + self.assertEqual(num_reinplacing_failures(), 0) + self.assertEqual(post_grad_graphs.count("aten.clone"), 0) + + def test_lists_old_functionalize(self): + with inductor_config.patch({"enable_auto_functionalized_v2": False}): + + @torch.library.custom_op("mylib::mutate_op", mutates_args={"y"}) + def mutate_op(y: List[Tensor]) -> None: + y[0].add_(2) + y[1].add_(3) + + @torch.compile(fullgraph=True, dynamic=False, backend="inductor") + def f(b): + mutate_op([b[0], b[1]]) + + x1 = torch.tensor([0.3, 0.4], device=device) + log_stream, ctx = logs_to_string( + "torch._inductor.compile_fx", "post_grad_graphs" + ) + with ctx(): + torch.compile(f, backend="inductor", fullgraph=True)(x1) + post_grad_graphs = "\n".join( + log_stream.getvalue().strip().split("\n")[3:] + ).strip() + + # Can't reinplace on views yet (1 for the "entire list" failing to reinplace) + self.assertEqual(num_reinplacing_failures(), 1) + + # Both list inputs failed to reinplace. So we should have emitted clones for them. + self.assertEqual(post_grad_graphs.count("aten.clone"), 2) @parametrize( "factory_op", diff --git a/torch/_functorch/_aot_autograd/dispatch_and_compile_graph.py b/torch/_functorch/_aot_autograd/dispatch_and_compile_graph.py index 2013eca1f6fb86..a12d42db7475e0 100644 --- a/torch/_functorch/_aot_autograd/dispatch_and_compile_graph.py +++ b/torch/_functorch/_aot_autograd/dispatch_and_compile_graph.py @@ -184,7 +184,6 @@ def _map_assigned_buffer_to_proxy(_mod, name, buffer): # As long as we opted to remove input mutations, then # there should be *NO* mutating ops in the graph at this point. copy_count = assert_functional_graph(fw_module.graph) - fw_module.graph.eliminate_dead_code() fw_module.recompile() diff --git a/torch/_functorch/partitioners.py b/torch/_functorch/partitioners.py index df6ca4e0e998a0..b0f609dc75615c 100644 --- a/torch/_functorch/partitioners.py +++ b/torch/_functorch/partitioners.py @@ -804,16 +804,17 @@ def solve_min_cut( if node.op == "call_function" and hasattr(node.target, "_overloadpacket") } ops_ignored = joint_module_ops - {str(i) for i in op_types.recomputable_ops} - print("Ops banned from rematerialization: ", ops_ignored) + print("Ops banned from re-materialization: ", ops_ignored) print() def can_fuse_into_auto_functionalized(a, b): if b.target != torch.ops.higher_order.auto_functionalized: return False mutable_op = b.args[0] - mutable_arg_names = ( - torch._higher_order_ops.auto_functionalize.get_mutable_arg_names(mutable_op) - ) + ( + mutable_arg_names, + _, + ) = torch._higher_order_ops.auto_functionalize.get_mutable_args(mutable_op) for name in mutable_arg_names: arg = b.kwargs[name] if a is arg: diff --git a/torch/_higher_order_ops/auto_functionalize.py b/torch/_higher_order_ops/auto_functionalize.py index bd84ec7f2eb600..d0a25564af6fd6 100644 --- a/torch/_higher_order_ops/auto_functionalize.py +++ b/torch/_higher_order_ops/auto_functionalize.py @@ -1,7 +1,8 @@ # mypy: allow-untyped-decorators # mypy: allow-untyped-defs import warnings -from typing import Any, Dict, List, Optional, Tuple, Union +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Sequence, Tuple, Union import torch import torch.utils._pytree as pytree @@ -17,6 +18,142 @@ ) +@dataclass +class ViewInfo: + base_index: int + size: Optional[Sequence[Union[int, torch.SymInt]]] = None + stride: Optional[Sequence[Union[int, torch.SymInt]]] = None + storage_offset: Optional[int] = None + # When is_view is false, the tensor is the base, and + # size, stride and storage_offset are all None. + is_view: bool = True + + def regenerate_view(self, bases_list: List[Tensor]): + if not self.is_view: + return bases_list[self.base_index] + + assert self.stride is not None + assert self.size is not None + assert self.storage_offset is not None + + return torch.as_strided( + bases_list[self.base_index], + self.size, + self.stride, + self.storage_offset, + ) + + +def write_view_information_to_args( + mutable_arg_names: List[str], + mutable_arg_types: List[torch.Type], + kwargs: Dict[str, Any], + arg_to_base_index: Dict[str, Any], +): + """ + This function writes the view information into kwargs. It reads mutable_args from kwargs. + and uses arg_to_base_index and tensor information to write ViewInfo into kwargs. + mutable_arg_names: mutable custom operator arg names. + mutable_arg_types: mutable custom operator arg types. + kwargs: the original custom operator args. + arg_to_base_index: maps mutable_arg_name to int | [int] that refers to the base tensor that + corresponds to the input tensor + """ + + def write_single_view(prefix: str, tensor: Tensor, base_index: int): + assert f"{prefix}_base_index" not in kwargs + assert f"{prefix}_size" not in kwargs + assert f"{prefix}_stride" not in kwargs + assert f"{prefix}_storage_offset" not in kwargs + + if tensor is None: + kwargs[f"{prefix}_base_index"] = None + elif tensor._base is None: + # if the tensor is the base (not view), for simplicity we do not serialize view meta. + kwargs[f"{prefix}_base_index"] = base_index + else: + kwargs[f"{prefix}_base_index"] = base_index + kwargs[f"{prefix}_size"] = tensor.size() + kwargs[f"{prefix}_stride"] = tensor.stride() + kwargs[f"{prefix}_storage_offset"] = tensor.storage_offset() + + for arg_name, arg_type in zip(mutable_arg_names, mutable_arg_types): + arg = kwargs[arg_name] + if isinstance(arg_type, torch.ListType): + if arg is None: + kwargs[f"_{arg_name}_length"] = None + + kwargs[f"_{arg_name}_length"] = len(arg) + for i, elem in enumerate(arg): + write_single_view( + f"_{arg_name}_{i}", elem, arg_to_base_index[arg_name][i] + ) + + elif isinstance(arg_type, (torch.TensorType, torch.OptionalType)): + write_single_view( + f"_{arg_name}", + kwargs[arg_name], + arg_to_base_index.get(arg_name, None), + ) + else: + raise RuntimeError(f"Unsupported type {arg_type}") + + +# Returns a dict of arg_name -> ViewInfo | [ViewInfo] +def read_view_information_from_args( + mutable_arg_names: List[str], + mutable_arg_types: List[torch.Type], + kwargs: Dict[str, Any], + all_bases: List[Tensor], +): + """ + This reads the view information added by `write_view_information_to_args` from kwargs, pop them, + and returns a dict arg_name -> ViewInfo | [ViewInfo](if the input is list). that maps each mutable arg + to its view information. + mutable_arg_names: mutable custom operator arg names. + mutable_arg_types: mutable custom operator arg types. + kwargs : args of auto_functionalize(custom_op, kwargs) + """ + + def get_arg(name): + return kwargs.pop(name) + + def read_single_view(prefix): + base_index = get_arg(f"{prefix}_base_index") + if base_index is None: + return None + elif f"{prefix}_size" not in kwargs: + assert f"{prefix}_stride" not in kwargs + assert f"{prefix}_storage_offset" not in kwargs + + # This means that the argument is the base tensor + return ViewInfo(base_index, all_bases[base_index], is_view=False) + + else: + size = get_arg(f"{prefix}_size") + stride = get_arg(f"{prefix}_stride") + storage_offset = get_arg(f"{prefix}_storage_offset") + return ViewInfo(base_index, size, stride, storage_offset, is_view=True) + + args_view_info: Dict[str, Any] = {} + for arg_name, arg_type in zip(mutable_arg_names, mutable_arg_types): + if isinstance(arg_type, torch.ListType): + length = get_arg(f"_{arg_name}_length") + if length is None: + # The whole list is None. + args_view_info[arg_name] = None + else: + args_view_info[arg_name] = [ + read_single_view(f"_{arg_name}_{i}") for i in range(length) + ] + + elif isinstance(arg_type, (torch.TensorType, torch.OptionalType)): + args_view_info[arg_name] = read_single_view(f"_{arg_name}") + else: + raise RuntimeError(f"Unsupported type {arg_type}") + return args_view_info + + # NOTE: [auto-functionalizing custom ops] # Users may wish to torch.compile custom ops that mutate their inputs. # torch.compile will automatically support this op without anyone needing @@ -34,6 +171,9 @@ # This HOP effectively runs the functional version of the op when # called: it clones inputs that will be mutated, runs the op, and # then returns (output, Tensors with the new values) +# +# auto_functionalize_v2 is an improved version of auto_functionalize that better handle +# re-inplacing views. class AutoFunctionalized(HigherOrderOperator): @@ -71,6 +211,38 @@ def __call__( auto_functionalized = AutoFunctionalized() auto_functionalized.__module__ = "torch.ops.higher_order" +auto_functionalized.fallthrough(DispatchKey.AutogradCPU) +auto_functionalized.fallthrough(DispatchKey.AutogradCUDA) + + +class AutoFunctionalizedV2(HigherOrderOperator): + """auto_functionalized_v2(_mutable_op, **kwargs) + + This HOP runs a "functional" version of _mutable_op. + Unlike AutoFunctionalized, this version is improved to better handle + view tensors. This version is only used in non export mode. + """ + + def __init__(self) -> None: + super().__init__("auto_functionalized_v2") + + def __call__( + self, + /, + _mutable_op: OpOverload, + **kwargs: Any, + ) -> Tuple[Any, Tuple[Tensor, ...]]: + assert can_auto_functionalize(_mutable_op) + assert isinstance(kwargs, dict) + return super().__call__(_mutable_op, **kwargs) + + +auto_functionalized_v2 = AutoFunctionalizedV2() +auto_functionalized_v2.__module__ = "torch.ops.higher_order" + +auto_functionalized_v2.fallthrough(DispatchKey.AutogradCPU) +auto_functionalized_v2.fallthrough(DispatchKey.AutogradCUDA) + def can_auto_functionalize(op: OperatorBase) -> bool: if not isinstance(op, OpOverload): @@ -120,85 +292,23 @@ def can_auto_functionalize(op: OperatorBase) -> bool: return True -@auto_functionalized.py_impl(DispatchKey.CompositeExplicitAutograd) -def auto_functionalized_dense( - _mutable_op: OpOverload, - _only_clone_these_tensors: Optional[Tuple[str, ...]] = None, - **kwargs: Any, -) -> Tuple[Any, Tuple[Tensor, ...]]: - new_kwargs = dict(**kwargs) - result = [] - - _mutable_args_names = get_mutable_arg_names(_mutable_op) - for name in _mutable_args_names: - if ( - _only_clone_these_tensors is not None - and name not in _only_clone_these_tensors - ): - new_kwargs[name] = kwargs[name] - else: - new_kwargs[name] = ( - [clone_preserve_strides(x) for x in kwargs[name]] - if kwargs[name] is not None and isinstance(kwargs[name], list) - else clone_preserve_strides(kwargs[name]) - if kwargs[name] is not None - else None - ) - result.append(new_kwargs[name]) - out = _mutable_op(**new_kwargs) - - if isinstance(out, tuple): - return (*out, *result) # type: ignore[return-value] - else: - return (out, *result) # type: ignore[return-value] - - -@auto_functionalized.py_impl(FakeTensorMode) -def auto_functionalized_fake( - mode, - _mutable_op: OpOverload, - **kwargs: Any, -) -> Tuple[Any, Tuple[Tensor, ...]]: - with mode: - result = auto_functionalized_dense(_mutable_op, **kwargs) - return result - - -@auto_functionalized.py_impl(ProxyTorchDispatchMode) -def auto_functionalized_proxy( - mode, - _mutable_op: OpOverload, - **kwargs: Any, -) -> Tuple[Any, Tuple[Tensor, ...]]: - with disable_proxy_modes_tracing(): - out = auto_functionalized(_mutable_op, **kwargs) - - proxy_kwargs = pytree.tree_map(mode.tracer.unwrap_proxy, kwargs) - out_proxy = mode.tracer.create_proxy( - "call_function", - auto_functionalized, - (_mutable_op,), - proxy_kwargs, - ) - result = track_tensor_tree(out, out_proxy, constant=None, tracer=mode.tracer) - return result - - -auto_functionalized.fallthrough(DispatchKey.AutogradCPU) -auto_functionalized.fallthrough(DispatchKey.AutogradCUDA) - - -def get_mutable_arg_names(op: OpOverload) -> List[str]: +def get_mutable_args(op: OpOverload) -> Tuple[List[str], List[torch.Type]]: """ Returns the list of argument names that get mutated according to the - schema. + schema and their types. """ mutable_args_names = [ arg.name for arg in op._schema.arguments if arg.alias_info is not None and arg.alias_info.is_write ] - return mutable_args_names + + mutable_args_types = [ + arg.type + for arg in op._schema.arguments + if arg.alias_info is not None and arg.alias_info.is_write + ] + return mutable_args_names, mutable_args_types def do_auto_functionalize( @@ -245,7 +355,7 @@ def do_auto_functionalize( ) # List of the name of args that get mutated (according to the schema) - mutable_args_names = get_mutable_arg_names(op) + mutable_args_names, _ = get_mutable_args(op) unwrapped_actual_out: Union[Any, Tuple[Any]] = unwrapped_outs[ : -len(mutable_args_names) @@ -290,9 +400,307 @@ def sync_update(o, orig_arg): return ctx.wrap_tensors(unwrapped_actual_out) # type: ignore[arg-type] +def do_auto_functionalize_v2( + op: OpOverload, + args: Tuple[Any, ...], + kwargs: Dict[str, Any], +) -> Any: + from torch._subclasses.functional_tensor import PythonFunctionalizeAPI + + ctx = PythonFunctionalizeAPI() + + # All of the (args, kwargs), but all as kwargs. The names for the + # args come from the schema. This makes it easier for us to work with them. + normalized_kwargs = {} + + schema = op._schema + for idx, arg in enumerate(schema.arguments): + # NB: torch_dispatch kwargs are the args defined as kwarg-only in the schema + if arg.name in kwargs: + normalized_kwargs[arg.name] = kwargs[arg.name] + elif idx < len(args): + # if its out of bounds we don't need to do anything + # as it means the the optional arg was passed with its default + # value + normalized_kwargs[arg.name] = args[idx] + else: + normalized_kwargs[arg.name] = arg.default_value + + # List of the name of args that get mutated (according to the schema) + mutable_args_names, mutable_args_types = get_mutable_args(op) + + # A list of all bases of mutable args without duplication + all_bases = [] + all_bases_addresses: list[int] = [] + + # Map arg_name to the index of its base in all_bases. + arg_to_base_index: Dict[str, Any] = {} + + def update_dict(tensor, arg_name, index=None): + base = tensor if tensor._base is None else tensor._base + + def set_result(base_index): + if index is None: + arg_to_base_index[arg_name] = base_index + else: + arg_to_base_index[arg_name][index] = base_index + + if not all_bases_addresses.__contains__(base._cdata): + all_bases_addresses.append(base._cdata) + all_bases.append(base) + set_result(len(all_bases) - 1) + else: + set_result(all_bases_addresses.index(base._cdata)) + + for arg_name in mutable_args_names: + arg = normalized_kwargs[arg_name] + if arg is None: + continue + + if isinstance(arg, list): + arg_to_base_index[arg_name] = {} + for i, tensor in enumerate(arg): + if tensor is None: + arg_to_base_index[arg_name].append(None) + continue + + update_dict(tensor, arg_name, i) + + else: + update_dict(arg, arg_name) + + # add view_meta for each args into unwrapped_kwargs. + write_view_information_to_args( + mutable_args_names, + mutable_args_types, + normalized_kwargs, + arg_to_base_index, + ) + + # remove mutated args from the kwargs (its a function of _all_bases now) + for arg_name in mutable_args_names: + del normalized_kwargs[arg_name] # type: ignore[arg-type] + + unwrapped_kwargs = ctx.unwrap_tensors(normalized_kwargs) # type: ignore[arg-type] + if "self" in unwrapped_kwargs or "self_" in unwrapped_kwargs: + warnings.warn( + "Using `self` or `self_` as an argument in the definition of custom ops may lead to ambiguous parsing. " + "Please consider using a different name for this argument to avoid potential issues." + ) + all_basis_unwrapped = ctx.unwrap_tensors(all_bases) + + with ctx.redispatch_to_next(): + unwrapped_outs = auto_functionalized_v2( + op, **dict(unwrapped_kwargs, _all_bases=all_basis_unwrapped) # type: ignore[arg-type] + ) + + unwrapped_actual_out: Union[Any, Tuple[Any]] = ( + unwrapped_outs if len(all_bases) == 0 else unwrapped_outs[: -len(all_bases)] + ) + + unwrapped_mutable_out = ( + [] if len(all_bases) == 0 else unwrapped_outs[-len(all_bases) :] + ) + + if len(op._schema.returns) == 0: + assert unwrapped_actual_out[0] is None + unwrapped_actual_out = None + elif len(op._schema.returns) == 1: + assert len(unwrapped_actual_out) == 1 + unwrapped_actual_out = unwrapped_actual_out[0] + else: + assert len(unwrapped_actual_out) == len(op._schema.returns) + + for orig_arg, unwrapped_out in zip(all_bases, unwrapped_mutable_out): + # Can be None if input was `Tensor(a!)?` + if unwrapped_out is None: + continue + + # We only handle Tensor or List[Tensor] here for now. + def sync_update(o, orig_arg): + ctx.replace(orig_arg, o) + ctx.commit_update(orig_arg) + ctx.sync(orig_arg) + + if isinstance(unwrapped_out, torch.Tensor): + sync_update(unwrapped_out, orig_arg) + elif isinstance(unwrapped_out, list) and all( + isinstance(o, torch.Tensor) for o in unwrapped_out + ): + assert len(orig_arg) == len(unwrapped_out) + for orig_a, o in zip(orig_arg, unwrapped_out): + sync_update(o, orig_a) + else: + raise RuntimeError( + f"unsupported type for auto-functionalization: {unwrapped_out}" + ) + + return ctx.wrap_tensors(unwrapped_actual_out) # type: ignore[arg-type] + + +# auto_functionalize functions +@auto_functionalized.py_impl(DispatchKey.CompositeExplicitAutograd) +def auto_functionalized_dense( + _mutable_op: OpOverload, + _only_clone_these_tensors: Optional[Tuple[str, ...]] = None, + **kwargs: Any, +) -> Tuple[Any, Tuple[Tensor, ...]]: + new_kwargs = dict(**kwargs) + result = [] + + _mutable_args_names, _ = get_mutable_args(_mutable_op) + for name in _mutable_args_names: + if ( + _only_clone_these_tensors is not None + and name not in _only_clone_these_tensors + ): + new_kwargs[name] = kwargs[name] + else: + new_kwargs[name] = ( + [clone_preserve_strides(x) for x in kwargs[name]] + if kwargs[name] is not None and isinstance(kwargs[name], list) + else clone_preserve_strides(kwargs[name]) + if kwargs[name] is not None + else None + ) + result.append(new_kwargs[name]) + out = _mutable_op(**new_kwargs) + + if isinstance(out, tuple): + return (*out, *result) # type: ignore[return-value] + else: + return (out, *result) # type: ignore[return-value] + + +@auto_functionalized.py_impl(FakeTensorMode) +def auto_functionalized_fake( + mode, + _mutable_op: OpOverload, + **kwargs: Any, +) -> Tuple[Any, Tuple[Tensor, ...]]: + with mode: + result = auto_functionalized_dense(_mutable_op, **kwargs) + return result + + +@auto_functionalized.py_impl(ProxyTorchDispatchMode) +def auto_functionalized_proxy( + mode, + _mutable_op: OpOverload, + **kwargs: Any, +) -> Tuple[Any, Tuple[Tensor, ...]]: + with disable_proxy_modes_tracing(): + out = auto_functionalized(_mutable_op, **kwargs) + + proxy_kwargs = pytree.tree_map(mode.tracer.unwrap_proxy, kwargs) + out_proxy = mode.tracer.create_proxy( + "call_function", + auto_functionalized, + (_mutable_op,), + proxy_kwargs, + ) + result = track_tensor_tree(out, out_proxy, constant=None, tracer=mode.tracer) + return result + + @auto_functionalized.py_functionalize_impl def auto_functionalized_func(ctx, _mutable_op, **kwargs): unwrapped_kwargs = ctx.unwrap_tensors(kwargs) with ctx.redispatch_to_next(): result = auto_functionalized(_mutable_op, **unwrapped_kwargs) return ctx.wrap_tensors(result) + + +# auto_functionalized_v2 functions +@auto_functionalized_v2.py_impl(DispatchKey.CompositeExplicitAutograd) +def auto_functionalized_v2_dense( + _mutable_op: OpOverload, + _only_clone_these_bases: Optional[Tuple[int, ...]] = None, + **kwargs: Any, +) -> Tuple[Any, Tuple[Tensor, ...]]: + all_bases: List[Tensor] = kwargs.pop("_all_bases", []) + mutable_args_names, mutable_args_types = get_mutable_args(_mutable_op) + args_view_info = read_view_information_from_args( + mutable_args_names, mutable_args_types, kwargs, all_bases + ) + + if _only_clone_these_bases is None: + _only_clone_these_bases = tuple(range(len(all_bases))) + + def maybe_copy(i, t): + if t is None: + return None + if i in _only_clone_these_bases: + return t.clone() + else: + return t + + all_bases_new = [maybe_copy(i, t) for i, t in enumerate(all_bases)] + + # create new args + new_kwargs = dict(**kwargs) + + # re-generate all inputs from all_bases_new using args_view_info and add them to new_kwargs. + for arg_name in mutable_args_names: + if args_view_info[arg_name] is None: + new_kwargs[arg_name] = None + elif isinstance(args_view_info[arg_name], list): + new_kwargs[arg_name] = [] + for i, elem in enumerate(args_view_info[arg_name]): + if elem is None: + new_kwargs[arg_name].append(None) + else: + view_info = args_view_info[arg_name][i] + new_kwargs[arg_name].append( + view_info.regenerate_view(all_bases_new) + ) + else: + new_kwargs[arg_name] = args_view_info[arg_name].regenerate_view( + all_bases_new + ) + + out = _mutable_op(**new_kwargs) + + if isinstance(out, tuple): + return (*out, *all_bases_new) # type: ignore[return-value] + else: + return (out, *all_bases_new) # type: ignore[return-value] + + +@auto_functionalized_v2.py_impl(FakeTensorMode) +def auto_functionalized_v2_fake( + mode, + _mutable_op: OpOverload, + **kwargs: Dict[str, Any], +) -> Tuple[Any, Tuple[Tensor, ...]]: + with mode: + result = auto_functionalized_v2_dense(_mutable_op, **kwargs) + return result + + +@auto_functionalized_v2.py_impl(ProxyTorchDispatchMode) +def auto_functionalized_v2_proxy( + mode, + _mutable_op: OpOverload, + **kwargs: Dict[str, Any], +) -> Tuple[Any, Tuple[Tensor, ...]]: + with disable_proxy_modes_tracing(): + out = auto_functionalized_v2(_mutable_op, **kwargs) + + proxy_kwargs = pytree.tree_map(mode.tracer.unwrap_proxy, kwargs) + out_proxy = mode.tracer.create_proxy( + "call_function", + auto_functionalized_v2, + (_mutable_op,), + proxy_kwargs, + ) + result = track_tensor_tree(out, out_proxy, constant=None, tracer=mode.tracer) + return result + + +@auto_functionalized_v2.py_functionalize_impl +def auto_functionalized_v2_func(ctx, _mutable_op, **kwargs): + unwrapped_kwargs = ctx.unwrap_tensors(kwargs) + with ctx.redispatch_to_next(): + result = auto_functionalized_v2(_mutable_op, **unwrapped_kwargs) + return ctx.wrap_tensors(result) diff --git a/torch/_inductor/config.py b/torch/_inductor/config.py index c5429761abf452..2332bebf2c7fda 100644 --- a/torch/_inductor/config.py +++ b/torch/_inductor/config.py @@ -25,6 +25,11 @@ def autotune_remote_cache_default() -> Optional[bool]: return None +# Enable auto_functionalized_v2 (enabled by default) +enable_auto_functionalized_v2 = ( + os.environ.get("TORCHDYNAMO_AUTO_FUNCTIONALIZED_V2", "0") == "1" +) + # add some debug printouts debug = False diff --git a/torch/_inductor/fx_passes/post_grad.py b/torch/_inductor/fx_passes/post_grad.py index 1d4029b852320f..4e73931fcef23b 100644 --- a/torch/_inductor/fx_passes/post_grad.py +++ b/torch/_inductor/fx_passes/post_grad.py @@ -811,7 +811,7 @@ def decompose_auto_functionalized(graph): CallFunctionVarArgs(torch.ops.higher_order.auto_functionalized), pass_dict=graph_pass, ) - def replacement(match: Match, *args, **kwargs): + def _(match: Match, *args, **kwargs): from torch._higher_order_ops.auto_functionalize import auto_functionalized_dense only_clone_these_tensors = tuple( @@ -849,11 +849,42 @@ def decomp(*flat_args): match.replace_by_example(decomp, flat_args, run_functional_passes=False) + @register_graph_pattern( + CallFunctionVarArgs(torch.ops.higher_order.auto_functionalized_v2), + pass_dict=graph_pass, + ) + def _(match: Match, *args, **kwargs): + from torch._higher_order_ops.auto_functionalize import ( + auto_functionalized_v2_dense, + ) + + only_clone_these_bases = tuple( + match.nodes[0].meta.get("only_clone_these_tensors", []) + ) + + flat_args, spec = pytree.tree_flatten((args, kwargs)) + + # NB: we combine (args, kwargs) into flat args for replacing. + # This is replace_by_example uses make_fx which does not support + # tracing a function with kwargs. + def decomp(*flat_args): + args, kwargs = pytree.tree_unflatten(flat_args, spec) + return auto_functionalized_v2_dense(*args, only_clone_these_bases, **kwargs) + + match.replace_by_example(decomp, flat_args, run_functional_passes=False) + graph_pass.apply(graph) + for node in graph.find_nodes( op="call_function", target=torch.ops.higher_order.auto_functionalized ): raise AssertionError("auto_functionalized was not removed") + + for node in graph.find_nodes( + op="call_function", target=torch.ops.higher_order.auto_functionalized_v2 + ): + raise AssertionError("auto_functionalized_v2 was not removed") + for node in graph.find_nodes( op="call_function", target=torch.ops.higher_order.triton_kernel_wrapper_functional, diff --git a/torch/_inductor/fx_passes/reinplace.py b/torch/_inductor/fx_passes/reinplace.py index 2618c7591b605b..59706134f85fea 100644 --- a/torch/_inductor/fx_passes/reinplace.py +++ b/torch/_inductor/fx_passes/reinplace.py @@ -467,6 +467,7 @@ def can_inplace(node, mutated_arg): if get_node_storage(mutated_arg) is None: return False shared_view_nodes = storage_to_nodes[get_node_storage(mutated_arg)] + if mutated_arg.op in ("placeholder", "get_attr"): # Get the first copy_ node that mutates the mutated_arg. copy_node = copy_nodes.get(mutated_arg, None) @@ -482,6 +483,9 @@ def can_inplace(node, mutated_arg): return True elif any(view.op in ("placeholder", "get_attr") for view in shared_view_nodes): + # This should never happen in auto_functionalize_v2 non-inference mode, + # since all mutated_arg are bases. + # If mutated arg is view of any of the inputs of the graph, # do not allow for inplacing. # This would require more sophisticated algorithm to handle @@ -491,9 +495,30 @@ def can_inplace(node, mutated_arg): node, shared_view_nodes, copy_node=None, mutated_arg=mutated_arg ) + def log_inplace_results( + node_name, + old_tensors_to_clone, + tensors_to_clone, + possibly_missed_reinplacing_opportunities, + ): + log.info( + "For node %s, attempted to reinplace %s. We were unable to reinplace %s; " + "%s (if non-empty) are possible missed reinplacing opportunities that may be bad for " + "memory usage and performance.", + node_name, + old_tensors_to_clone, + tensors_to_clone, + possibly_missed_reinplacing_opportunities, + ) + torch._dynamo.utils.counters["inductor"][ + "possibly_missed_reinplacing_opportunities" + ] += len(possibly_missed_reinplacing_opportunities) + replace_dict: Dict[torch.fx.Node, torch.fx.Node] = {} - def reinplace_and_refine_tensors_to_clone(old_tensors_to_clone, kwargs, node_name): + def reinplace_and_refine_tensors_to_clone( + old_tensors_to_clone, kwargs, node_name, auto_functionalize_v2=False + ): tensors_to_clone: List[str] = [] storage_of_reinplaced_args = set() possibly_missed_reinplacing_opportunities = [] @@ -507,6 +532,7 @@ def tensor_with_same_storage_already_reinplaced(arg): for arg in old_tensors_to_clone: assert arg in kwargs + mutated_arg = kwargs[arg] # Let's say we have: @@ -523,12 +549,18 @@ def tensor_with_same_storage_already_reinplaced(arg): mutated_arg ) if should_attempt_reinplace and can_inplace(node, mutated_arg): + # In general, we probably do not need those optimizations. copy_node = copy_args_to_copy_nodes.get((mutated_arg, node)) if copy_node is not None: replace_dict[copy_node] = copy_node.args[0] - for user in node.users: - if user.target == operator.getitem and user.args[1] == arg: - replace_dict[user] = mutated_arg + if not auto_functionalize_v2: + for user in node.users: + # For auto_functionalize_v2, arg is the index of the base, where base at index i corresponds to + # output atindex size(out)+i. + # This used to compare string with integers before for auto_functionalize_v2. Not sure + # if it was needed for inplaceable_triton_ops? + if user.target == operator.getitem and user.args[1] == arg: + replace_dict[user] = mutated_arg if isinstance(mutated_arg, (list, tuple)): for a in mutated_arg: @@ -540,18 +572,12 @@ def tensor_with_same_storage_already_reinplaced(arg): possibly_missed_reinplacing_opportunities.append(arg) tensors_to_clone.append(arg) - log.info( - "For node %s, attempted to reinplace %s. We were unable to reinplace %s; " - "%s (if non-empty) are possible missed reinplacing opportunities that may be bad for " - "memory usage and performance.", + log_inplace_results( node_name, old_tensors_to_clone, tensors_to_clone, possibly_missed_reinplacing_opportunities, ) - torch._dynamo.utils.counters["inductor"][ - "possibly_missed_reinplacing_opportunities" - ] += len(possibly_missed_reinplacing_opportunities) return tensors_to_clone for node in graph.nodes: @@ -565,17 +591,37 @@ def tensor_with_same_storage_already_reinplaced(arg): if copy_node is not None: replace_dict[copy_node] = copy_node.args[0] node.target = inplaceable_op.inplace_op + elif node.target == torch.ops.higher_order.auto_functionalized_v2: + _mutable_op = node.args[0] + kwargs = node.kwargs + + all_bases = kwargs["_all_bases"] + bases_to_clone = range(len(all_bases)) + base_tensors_dct = dict(enumerate(all_bases)) + new_bases_to_clone: List[int] = reinplace_and_refine_tensors_to_clone( + bases_to_clone, + base_tensors_dct, + node.target, + auto_functionalize_v2=True, + ) + # Stash the metadata. There is a pass later on where we decompose + # auto_functionalized into clones + a mutable op; this metadata + # tells the decomp to only clone the following inputs + node.meta["only_clone_these_tensors"] = new_bases_to_clone elif node.target == torch.ops.higher_order.auto_functionalized: _mutable_op = node.args[0] - from torch._higher_order_ops.auto_functionalize import get_mutable_arg_names + from torch._higher_order_ops.auto_functionalize import get_mutable_args - tensors_to_clone = get_mutable_arg_names(_mutable_op) + tensors_to_clone, _ = get_mutable_args(_mutable_op) # Don't try to reinplace Optional[Tensor] args that are None. tensors_to_clone = [ t for t in tensors_to_clone if node.kwargs[t] is not None ] tensors_to_clone = reinplace_and_refine_tensors_to_clone( - tensors_to_clone, node.kwargs, _mutable_op._name + tensors_to_clone, + node.kwargs, + _mutable_op._name, + auto_functionalize_v2=False, ) # Stash the metadata. There is a pass later on where we decompose diff --git a/torch/_subclasses/functional_tensor.py b/torch/_subclasses/functional_tensor.py index 1fb64fbcf7bad9..e9cbb0759b4c10 100644 --- a/torch/_subclasses/functional_tensor.py +++ b/torch/_subclasses/functional_tensor.py @@ -2,9 +2,10 @@ import contextlib import warnings from abc import ABC, abstractmethod -from typing import Any, Callable, ContextManager, Dict, Optional, Tuple, Union +from typing import Any, Callable, ContextManager, Dict, List, Optional, Tuple, Union import torch +import torch._inductor.config as inductor_config import torch.utils._pytree as pytree from torch._C import _functionalization_reapply_views_tls as _reapply_views from torch._ops import _get_dispatch_mode_pre_dispatch @@ -434,6 +435,7 @@ def unwrap(x): from torch._higher_order_ops.auto_functionalize import ( can_auto_functionalize, do_auto_functionalize, + do_auto_functionalize_v2, ) if can_auto_functionalize( @@ -444,7 +446,10 @@ def unwrap(x): # it doesn't matter what mode we use here because # the implementation of do_auto_functionalize doesn't # interact with FunctionalTensorMode at all - return do_auto_functionalize(func, args, kwargs) + if self.export or not inductor_config.enable_auto_functionalized_v2: + return do_auto_functionalize(func, args, kwargs) + else: + return do_auto_functionalize_v2(func, args, kwargs) from torch._higher_order_ops.effects import handle_effects, has_effects @@ -611,7 +616,7 @@ def wrap_tensors(self, args: Tuple[Any]) -> Tuple[Any]: @abstractmethod def unwrap_tensors( self, args: Union[torch.Tensor, Tuple[torch.Tensor, ...]] - ) -> Union[torch.Tensor, Tuple[torch.Tensor, ...]]: + ) -> Any: pass @abstractmethod @@ -654,8 +659,8 @@ def wrap_tensors(self, args: Tuple[Any]) -> Tuple[Any]: ) def unwrap_tensors( - self, args: Union[torch.Tensor, Tuple[torch.Tensor, ...]] - ) -> Union[torch.Tensor, Tuple[torch.Tensor, ...]]: + self, args: Union[torch.Tensor, Tuple[torch.Tensor, ...], List[torch.Tensor]] + ) -> Any: return torch.utils._pytree.tree_map_only( FunctionalTensor, FunctionalTensor.from_functional, args ) @@ -748,9 +753,11 @@ def unwrap_tensors( def functionalize(self, inner_f: Callable) -> Callable: return torch.func.functionalize( inner_f, - remove="mutations_and_views" - if self.interpreter.functionalize_add_back_views() - else "mutations", + remove=( + "mutations_and_views" + if self.interpreter.functionalize_add_back_views() + else "mutations" + ), ) def redispatch_to_next(self) -> ContextManager: diff --git a/torch/export/_remove_auto_functionalized_pass.py b/torch/export/_remove_auto_functionalized_pass.py index 76e254119458b6..683e89c3d14910 100644 --- a/torch/export/_remove_auto_functionalized_pass.py +++ b/torch/export/_remove_auto_functionalized_pass.py @@ -9,7 +9,7 @@ import torch from torch._higher_order_ops.auto_functionalize import ( auto_functionalized, - get_mutable_arg_names, + auto_functionalized_v2, ) from torch._inductor.fx_passes.post_grad import decompose_auto_functionalized from torch.export import ExportedProgram @@ -36,10 +36,13 @@ def unsafe_remove_auto_functionalized_pass( if not isinstance(module, torch.fx.GraphModule): continue for node in ep.graph.nodes: - if node.op == "call_function" and node.target is auto_functionalized: + if ( + node.op == "call_function" and node.target is auto_functionalized + ) or ( + node.op == "call_function" and node.target is auto_functionalized_v2 + ): func = node.args[0] assert isinstance(func, torch._ops.OpOverload) - mutable_args_names = get_mutable_arg_names(func) # re-inplace everything node.meta["only_clone_these_tensors"] = [] decompose_auto_functionalized(ep.graph) From 4a8130adba534a46f8d8be562da8822f5657756c Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Wed, 4 Sep 2024 17:18:21 +0000 Subject: [PATCH 0406/1018] Revert "Ignore fresh unbacked when doing recursive make_fx inside HOPs (#135053)" This reverts commit a178a053ad2c8e42d1b684ed38385b9646ec3b74. Reverted https://github.com/pytorch/pytorch/pull/135053 on behalf of https://github.com/ezyang due to need to back out https://github.com/pytorch/pytorch/pull/133585 ([comment](https://github.com/pytorch/pytorch/pull/134407#issuecomment-2329597388)) --- test/export/test_export.py | 21 ------------------- .../collect_metadata_analysis.py | 10 +++------ 2 files changed, 3 insertions(+), 28 deletions(-) diff --git a/test/export/test_export.py b/test/export/test_export.py index 0099eda3fa9a08..887be93dbf75f4 100644 --- a/test/export/test_export.py +++ b/test/export/test_export.py @@ -686,27 +686,6 @@ def false_fn(x): M()(torch.randn(7)) torch.export.export(M(), (torch.randn(7),)) - @torch._dynamo.config.patch(capture_scalar_outputs=True) - def test_cond_contains_unbacked_no_escape(self): - class M(torch.nn.Module): - def forward(self, a, b1, b2, c): - def true_fn(x): - return x * b1.item() - - def false_fn(x): - return x * b2.item() - - r = torch.cond(a, true_fn, false_fn, (c,)) - return r * 2 - - args = ( - torch.tensor(True), - torch.tensor([4]), - torch.tensor([4]), - torch.randn(10, requires_grad=True), - ) - torch.export.export(M(), args) - def test_state_tensors(self): class M(torch.nn.Module): # simple with register buffer def __init__(self) -> None: diff --git a/torch/_functorch/_aot_autograd/collect_metadata_analysis.py b/torch/_functorch/_aot_autograd/collect_metadata_analysis.py index 51a0aeb24ad496..820c7d288cf2c6 100644 --- a/torch/_functorch/_aot_autograd/collect_metadata_analysis.py +++ b/torch/_functorch/_aot_autograd/collect_metadata_analysis.py @@ -9,7 +9,6 @@ """ import collections -import contextlib import logging from functools import wraps from typing import Callable, DefaultDict, Dict, List, Optional @@ -163,18 +162,15 @@ def inner(*flat_args): # It doesn't matter if we run this under predispatch or not because it is # only for figuring out metadata mode = FunctionalTensorMode(_allow_token_discovery=True) - suppress_pending = contextlib.nullcontext() - fake_mode = detect_fake_mode() - if fake_mode and (shape_env := fake_mode.shape_env): - suppress_pending = shape_env.ignore_fresh_unbacked_symbols() - with disable_above, mode, suppress_pending: + with disable_above, mode: # precondition: The passed in function already handles unflattening inputs + flattening outputs flat_f_args = pytree.tree_map(_to_fun, flat_args) flat_f_outs = f(*flat_f_args) # We didn't do any tracing, so we don't need to process the # unbacked symbols, they will just disappear into the ether. # Also, prevent memoization from applying. - if fake_mode: + if (fake_mode := detect_fake_mode()) and (shape_env := fake_mode.shape_env): + shape_env.pending_fresh_unbacked_symbols.clear() fake_mode.epoch += 1 fake_mode.reset_nt_tensor_id_counter() From 31190d5737474d2a8cbc45757878fcde346d4ec0 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Wed, 4 Sep 2024 17:18:21 +0000 Subject: [PATCH 0407/1018] Revert "Compute and do renamings even when ignoring fresh unbacked symbols (#134407)" This reverts commit 46cb2af7d822681298370bab9d49b3cba5546dd5. Reverted https://github.com/pytorch/pytorch/pull/134407 on behalf of https://github.com/ezyang due to need to back out https://github.com/pytorch/pytorch/pull/133585 ([comment](https://github.com/pytorch/pytorch/pull/134407#issuecomment-2329597388)) --- torch/fx/experimental/symbolic_shapes.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/torch/fx/experimental/symbolic_shapes.py b/torch/fx/experimental/symbolic_shapes.py index b50b55fa47302b..bc6645c617a479 100644 --- a/torch/fx/experimental/symbolic_shapes.py +++ b/torch/fx/experimental/symbolic_shapes.py @@ -597,6 +597,8 @@ def compute_unbacked_bindings(shape_env, example_value, old_example_value=None, """ if shape_env is None: return + if shape_env._ignore_fresh_unbacked_symbols_tls(): + return fs = shape_env.pending_fresh_unbacked_symbols pending = set(fs) if pending: @@ -2720,6 +2722,8 @@ def _rename_unbacked_to(self, orig_s: sympy.Symbol, new_s: sympy.Symbol): assert isinstance(new_s, sympy.Symbol), new_s assert free_unbacked_symbols(new_s), new_s assert free_unbacked_symbols(orig_s), orig_s + if self._ignore_fresh_unbacked_symbols_tls(): + return dest = self.replacements.get(orig_s) assert not free_unbacked_symbols(dest), f"{orig_s} -> {dest}" self._set_replacement(orig_s, new_s, "rename_unbacked_to") From bfdc2ddad5a5ced461b47f491d211aa3ec704709 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Wed, 4 Sep 2024 17:21:32 +0000 Subject: [PATCH 0408/1018] Revert "Fix set_unbacked_bindings when list of Tensors is returned (#133585)" This reverts commit 2a49296d7563150d67bb00bd4c97bc5aafaa77df. Reverted https://github.com/pytorch/pytorch/pull/133585 on behalf of https://github.com/ezyang due to fails torchrec tests ([comment](https://github.com/pytorch/pytorch/pull/133585#issuecomment-2329602983)) --- test/export/test_export.py | 34 --------------------------- torch/fx/experimental/proxy_tensor.py | 8 +++---- 2 files changed, 3 insertions(+), 39 deletions(-) diff --git a/test/export/test_export.py b/test/export/test_export.py index 887be93dbf75f4..5688bf6509e324 100644 --- a/test/export/test_export.py +++ b/test/export/test_export.py @@ -652,40 +652,6 @@ def forward(self, x, c): foo, bad_example_inp, dynamic_shapes=dynamic_shapes, strict=False ) - def test_unbacked_to_cond(self): - class M(torch.nn.Module): - def forward(self, a): - az = a.nonzero() - - def true_fn(x): - return (x + 1).sum() - - def false_fn(x): - return (x + 3).sum() - - r = torch.cond(az.size(0) > 3, true_fn, false_fn, (az,)) - return r * 2 - - M()(torch.randn(7)) - torch.export.export(M(), (torch.randn(7),)) - - def test_unbacked_to_cond_passthrough(self): - class M(torch.nn.Module): - def forward(self, a): - az = a.nonzero() - - def true_fn(x): - return x + 1 - - def false_fn(x): - return x + 3 - - r = torch.cond(az.size(0) > 3, true_fn, false_fn, (az,)) - return r * 2 - - M()(torch.randn(7)) - torch.export.export(M(), (torch.randn(7),)) - def test_state_tensors(self): class M(torch.nn.Module): # simple with register buffer def __init__(self) -> None: diff --git a/torch/fx/experimental/proxy_tensor.py b/torch/fx/experimental/proxy_tensor.py index 12c4f632118570..7a3ece886506c3 100644 --- a/torch/fx/experimental/proxy_tensor.py +++ b/torch/fx/experimental/proxy_tensor.py @@ -596,6 +596,8 @@ def track_tensor_tree( constant: Optional[_NestedTensors], tracer: _ProxyTracer, ) -> T: + _set_unbacked_bindings(inner_res, proxy_res) + def wrap_with_proxy( e: object, proxy: _NestedProxys, constant: Optional[_NestedTensors] ) -> None: @@ -604,13 +606,11 @@ def wrap_with_proxy( assert constant is None or isinstance(constant, Tensor) track_tensor(e, proxy, tracer=tracer, constant=constant) set_meta(proxy, e) - _set_unbacked_bindings(e, proxy) elif isinstance(e, py_sym_types): assert isinstance(proxy, Proxy) # NB: eagerly set meta here, so that the numbering is in order set_meta(proxy, e) set_proxy_slot(e, tracer, thunkify(tracer, lambda: proxy)) - _set_unbacked_bindings(e, proxy) elif isinstance(e, _AnyScriptObject): assert isinstance(proxy, Proxy) set_proxy_slot(e, tracer, proxy) @@ -2188,8 +2188,6 @@ def _set_unbacked_bindings(out: object, out_proxy: _NestedProxys) -> None: """A helper function for setting up unbacked_bindings on the destination FX graph.""" from .symbolic_shapes import compute_unbacked_bindings - log.debug("_set_unbacked_bindings %s", out_proxy) - # Can't use detect_fake_mode here, # # python test/distributed/_tensor/test_dtensor_compile.py -k @@ -2200,5 +2198,5 @@ def _set_unbacked_bindings(out: object, out_proxy: _NestedProxys) -> None: fake_mode = torch._C._get_dispatch_mode(torch._C._TorchDispatchModeKey.FAKE) if fake_mode and fake_mode.shape_env: if symbol_to_path := compute_unbacked_bindings(fake_mode.shape_env, out): - assert isinstance(out_proxy, Proxy), out_proxy + assert isinstance(out_proxy, Proxy) out_proxy.node.meta["unbacked_bindings"] = symbol_to_path From 126c46725809918a3a053c32404fceb4e1eb6846 Mon Sep 17 00:00:00 2001 From: Jason Ansel Date: Wed, 4 Sep 2024 02:51:25 +0000 Subject: [PATCH 0409/1018] [inductor] Pass to fix device on index(..., [iota]) (#134748) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This pattern was preventing cudagraphs from kicking in on torch_multimodal_clip, resulting in `1.6529 → 3.3471` speedup. Pull Request resolved: https://github.com/pytorch/pytorch/pull/134748 Approved by: https://github.com/shunting314 --- test/inductor/test_cuda_repro.py | 27 +++++++++++- torch/_inductor/fx_passes/joint_graph.py | 54 ++++++++++++++++++++++++ torch/_inductor/utils.py | 19 +++++++++ 3 files changed, 99 insertions(+), 1 deletion(-) diff --git a/test/inductor/test_cuda_repro.py b/test/inductor/test_cuda_repro.py index 0af5c2ca966711..3112b3e8965154 100644 --- a/test/inductor/test_cuda_repro.py +++ b/test/inductor/test_cuda_repro.py @@ -15,7 +15,11 @@ from torch._inductor import config from torch._inductor.compile_fx import compile_fx_inner from torch._inductor.runtime.hints import DeviceProperties -from torch._inductor.utils import run_and_get_code, run_fw_bw_and_get_code +from torch._inductor.utils import ( + run_and_get_code, + run_and_get_graph_lowering, + run_fw_bw_and_get_code, +) from torch.fx.experimental.proxy_tensor import make_fx from torch.testing import FileCheck from torch.testing._internal.common_cuda import ( @@ -1262,6 +1266,27 @@ def forward(self, x): out2 = m(input_tensor) self.assertEqual(out, out2, atol=1e-3, rtol=1e-3) + @config.patch("triton.cudagraphs", True) + def test_cpu_index(self): + @torch.compile(fullgraph=True) + def fn(x): + return x[torch.arange(32)] + + result, (graph,) = run_and_get_graph_lowering( + fn, torch.randn(64, device="cuda") + ) + self.assertEqual(graph.disable_cudagraphs_reason, None) + self.assertEqual(graph.device_types, {"cuda"}) + + inp = torch.randn(64, device="cuda", requires_grad=True) + result, (graph,) = run_and_get_graph_lowering(fn, inp) + self.assertEqual(graph.disable_cudagraphs_reason, None) + self.assertEqual(graph.device_types, {"cuda"}) + + result, (graph,) = run_and_get_graph_lowering(lambda: result.sum().backward()) + self.assertEqual(graph.disable_cudagraphs_reason, None) + self.assertEqual(graph.device_types, {"cuda"}) + def test_reflection_pad_loop_order(self): def fn(x, y): a = torch.nn.functional.pad(x, (5, 5, 5, 5), mode="reflect") diff --git a/torch/_inductor/fx_passes/joint_graph.py b/torch/_inductor/fx_passes/joint_graph.py index 416125a8f90ffe..86c42d696190d2 100644 --- a/torch/_inductor/fx_passes/joint_graph.py +++ b/torch/_inductor/fx_passes/joint_graph.py @@ -14,6 +14,7 @@ from torch.fx.passes.graph_transform_observer import GraphTransformObserver from torch.multiprocessing.reductions import StorageWeakRef +from ...utils._ordered_set import OrderedSet from .. import config from ..pattern_matcher import ( CallFunction, @@ -475,6 +476,59 @@ def joint_graph_passes(graph: torch.fx.GraphModule): return graph +@register_graph_pattern( + CallFunction( + torch.ops.prims.iota.default, + KeywordArg("length"), + start=KeywordArg("start"), + step=KeywordArg("step"), + dtype=KeywordArg("dtype"), + device=KeywordArg("device"), + requires_grad=KeywordArg("requires_grad"), + ), + pass_dict=patterns, +) +def fix_iota_device(match: Match, length, start, step, dtype, device, requires_grad): + """ + Eager supports: + + aten.index(cuda_tensor, torch.arange(..., device="cpu")) + + But this results in an implicit host-device-copy and breaks cudagraphs. + Rewrite the arange to use CUDA. + """ + (node,) = match.nodes + user_devices: OrderedSet[torch.device] = OrderedSet() + for user in node.users: + if ( + user.op == "call_function" + and user.target in (aten.index.Tensor, aten.index_put.default) + and hasattr(user.meta.get("val"), "device") + ): + user_devices.add(user.meta["val"].device) # type: ignore[union-attr] + else: + return # bail out + + if len(user_devices) == 1 and "val" in node.meta: + (user_device,) = user_devices + if device.type != user_device.type: + repl = match.graph.call_function( + torch.ops.prims.iota.default, + (length,), + { + "start": start, + "step": step, + "dtype": dtype, + "device": user_device, + "requires_grad": requires_grad, + }, + ) + repl.meta.update(node.meta) + repl.meta["val"] = repl.meta["val"].to(user_device) + node.replace_all_uses_with(repl) + match.erase_nodes(match.graph) + + @register_graph_pattern( CallFunction( torch.ops.prims.convert_element_type.default, diff --git a/torch/_inductor/utils.py b/torch/_inductor/utils.py index 3c76bbe432da92..5c76dd19822a2d 100644 --- a/torch/_inductor/utils.py +++ b/torch/_inductor/utils.py @@ -1361,6 +1361,25 @@ def run_and_get_triton_code(fn, *args, **kwargs): return source_codes[0] +def run_and_get_graph_lowering(fn, *args, **kwargs): + from torch._inductor.codecache import CompiledFxGraph + from torch._inductor.graph import GraphLowering + + real_init = CompiledFxGraph.__init__ + graph_lowerings = [] + + def fake_init(*args, **kwargs): + real_init(*args, **kwargs) + graph = args[2] + assert isinstance(graph, GraphLowering) + graph_lowerings.append(graph) + + with mock.patch.object(CompiledFxGraph, "__init__", fake_init): + result = fn(*args, **kwargs) + + return result, graph_lowerings + + @contextlib.contextmanager def override_lowering(aten_op, override_fn): """ From 846969310c4778f0c117ba650535c8c9f82b6ab4 Mon Sep 17 00:00:00 2001 From: Jason Ansel Date: Wed, 4 Sep 2024 02:51:26 +0000 Subject: [PATCH 0410/1018] [inductor] Allow cudagraphs with unused CPU inputs (#134749) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This pattern was preventing cudagraphs from kicking in on torch_multimodal_clip, resulting in `1.6529 → 3.3471` speedup. Pull Request resolved: https://github.com/pytorch/pytorch/pull/134749 Approved by: https://github.com/shunting314 ghstack dependencies: #134748 --- test/inductor/test_cuda_repro.py | 14 ++++++++++++++ torch/_inductor/graph.py | 5 +++-- 2 files changed, 17 insertions(+), 2 deletions(-) diff --git a/test/inductor/test_cuda_repro.py b/test/inductor/test_cuda_repro.py index 3112b3e8965154..29653ffcda3976 100644 --- a/test/inductor/test_cuda_repro.py +++ b/test/inductor/test_cuda_repro.py @@ -1243,6 +1243,20 @@ def test_non_contiguous_unaligned_input_indices(self): idxs = remove_unaligned_input_idxs(inputs, [0, 2]) self.assertEqual(idxs, [0]) + @config.patch("triton.cudagraphs", True) + def test_unused_cpu_input_cudagraphs(self): + def fn(x, y): + return x.sin().sin().sin().sin().cos() + 1 + + fx_graph = torch.fx.symbolic_trace(fn) + inp = [torch.randn(64, device="cuda"), torch.randn(64, device="cpu")] + compiled_fn, (graph,) = run_and_get_graph_lowering( + torch._inductor.compile, fx_graph, inp + ) + self.assertEqual(graph.disable_cudagraphs_reason, None) + self.assertEqual(graph.device_types, {"cuda"}) + self.assertEqual(compiled_fn(*inp), fn(*inp)) + def test_epilogue_fusion_with_view(self): class ToyModel(torch.nn.Module): def __init__(self) -> None: diff --git a/torch/_inductor/graph.py b/torch/_inductor/graph.py index c9da83e7584e62..3448ea6eb9620a 100644 --- a/torch/_inductor/graph.py +++ b/torch/_inductor/graph.py @@ -792,8 +792,8 @@ def register_buffer(self, buffer: ir.Buffer, *, set_name: bool = False) -> str: name = self.qualify_name(f"buf{len(self.buffers)}") self.buffers.append(buffer) self.name_to_buffer[name] = buffer - # Skip empty CPU tensor so that CUDA graphs can succeed, see https://github.com/pytorch/pytorch/pull/114144 if ( + # Skip empty CPU tensor so that CUDA graphs can succeed, see https://github.com/pytorch/pytorch/pull/114144 not (isinstance(buffer, ir.ComputedBuffer) and buffer.is_zero_elements()) and buffer.get_device() is not None ): @@ -958,7 +958,8 @@ def placeholder( ) self.graph_inputs[target] = tensor self.graph_inputs_original[target] = tensor.data.data - self.add_device_info(example.device) + if self.current_node.users: # cudagraphs should work with an unused CPU input + self.add_device_info(example.device) # Note: [Input Alignment handling in Inductor] # Alignment matters for generating efficient code. Some operations, From 59962a595bae981b33736c74a1160f5acc42118a Mon Sep 17 00:00:00 2001 From: Jason Ansel Date: Wed, 4 Sep 2024 02:51:26 +0000 Subject: [PATCH 0411/1018] [inductor] Refactor simplify erase_nodes() (#134822) Pull Request resolved: https://github.com/pytorch/pytorch/pull/134822 Approved by: https://github.com/shunting314 ghstack dependencies: #134748, #134749 --- torch/_inductor/fx_passes/joint_graph.py | 7 +++---- torch/_inductor/fx_passes/mkldnn_fusion.py | 2 +- torch/_inductor/pattern_matcher.py | 10 ++++------ 3 files changed, 8 insertions(+), 11 deletions(-) diff --git a/torch/_inductor/fx_passes/joint_graph.py b/torch/_inductor/fx_passes/joint_graph.py index 86c42d696190d2..d526f1bb6a64de 100644 --- a/torch/_inductor/fx_passes/joint_graph.py +++ b/torch/_inductor/fx_passes/joint_graph.py @@ -526,7 +526,7 @@ def fix_iota_device(match: Match, length, start, step, dtype, device, requires_g repl.meta.update(node.meta) repl.meta["val"] = repl.meta["val"].to(user_device) node.replace_all_uses_with(repl) - match.erase_nodes(match.graph) + match.erase_nodes() @register_graph_pattern( @@ -552,7 +552,7 @@ def pointless_convert(match: Match, arg, dtype1: torch.dtype, dtype2: torch.dtyp ) repl.meta.update(node.meta) node.replace_all_uses_with(repl) - match.erase_nodes(graph) + match.erase_nodes() @register_graph_pattern( @@ -561,12 +561,11 @@ def pointless_convert(match: Match, arg, dtype1: torch.dtype, dtype2: torch.dtyp ) def pointless_view(match: Match, arg, size): """Remove no-op view""" - graph = match.graph node = match.output_node() arg_size = list(node.args[0].meta["val"].shape) # type: ignore[union-attr] if size == arg_size: node.replace_all_uses_with(node.args[0]) # type: ignore[arg-type] - match.erase_nodes(graph) + match.erase_nodes() # When softmax is used with temperature or other scaling, we get the pattern diff --git a/torch/_inductor/fx_passes/mkldnn_fusion.py b/torch/_inductor/fx_passes/mkldnn_fusion.py index 2ed93c2cd95d48..156760a68e7e69 100644 --- a/torch/_inductor/fx_passes/mkldnn_fusion.py +++ b/torch/_inductor/fx_passes/mkldnn_fusion.py @@ -835,7 +835,7 @@ def linear_bias_pattern(match, *args): ) repl.meta.update(add_node.meta) add_node.replace_all_uses_with(repl) - match.erase_nodes(graph) + match.erase_nodes() def _is_packable_mkldnn_rnn_layer(match): lstm_node = match.output_node() diff --git a/torch/_inductor/pattern_matcher.py b/torch/_inductor/pattern_matcher.py index c31d7aa9b668ff..e43d37fd37b1a7 100644 --- a/torch/_inductor/pattern_matcher.py +++ b/torch/_inductor/pattern_matcher.py @@ -200,7 +200,8 @@ def bundle(self) -> Match: def __repr__(self) -> str: return f"Match(..., {self.args}, {self.kwargs})" - def erase_nodes(self, graph: torch.fx.Graph) -> None: + def erase_nodes(self) -> None: + graph = self.graph for n in reversed(self.nodes): if not n._erased and not n.users: graph.erase_node(n) @@ -1009,7 +1010,7 @@ def apply(self, match: Match, graph: torch.fx.Graph, node: torch.fx.Node) -> Non replacement.meta.update(node.meta) node.replace_all_uses_with(replacement) assert match.nodes[-1] is node - match.erase_nodes(graph) + match.erase_nodes() @dataclasses.dataclass @@ -1036,9 +1037,6 @@ def replace_with_graph( replacement_graph: Union[torch.fx.Graph, torch.fx.GraphModule], args: Sequence[torch.fx.Node], ) -> None: - output_nodes = match.output_nodes() - first_node = output_nodes[0] - class Replacer(torch.fx.Interpreter): call_method = None # type: ignore[assignment] call_module = None # type: ignore[assignment] @@ -1180,7 +1178,7 @@ def replace( assert len(output_nodes) == 1 replace(output_nodes[0], replacement) - match.erase_nodes(graph) + match.erase_nodes() def apply(self, match: Match, graph: torch.fx.Graph, node: torch.fx.Node) -> None: assert match.replacement_graph is not None From 95b6de927a4d9c2702013058cdb5226e1afff6f6 Mon Sep 17 00:00:00 2001 From: Svetlana Karslioglu Date: Wed, 4 Sep 2024 17:47:14 +0000 Subject: [PATCH 0412/1018] Add ExecuTorch warning to mobile_optimizer (#134697) Preview: https://docs-preview.pytorch.org/pytorch/pytorch/134697/mobile_optimizer.html Pull Request resolved: https://github.com/pytorch/pytorch/pull/134697 Approved by: https://github.com/ali-khosh, https://github.com/malfet Co-authored-by: Nikita Shulga <2453524+malfet@users.noreply.github.com> --- docs/source/mobile_optimizer.rst | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/docs/source/mobile_optimizer.rst b/docs/source/mobile_optimizer.rst index 2624a5cb0bed9a..88eb87fb92f4d3 100644 --- a/docs/source/mobile_optimizer.rst +++ b/docs/source/mobile_optimizer.rst @@ -2,7 +2,11 @@ torch.utils.mobile_optimizer =================================== .. warning:: - This API is in beta and may change in the near future. + PyTorch Mobile is no longer actively supported. Please check out + `ExecuTorch `__, PyTorch's + all-new on-device inference library. You can also review + documentation on `XNNPACK `__ + and `Vulkan `__ delegates. Torch mobile supports ``torch.utils.mobile_optimizer.optimize_for_mobile`` utility to run a list of optimization pass with modules in eval mode. The method takes the following parameters: a torch.jit.ScriptModule object, a blocklisting optimization set, a preserved method list, and a backend. From 2a6fd78a368d01d4835e23875a35b9e85285dfa4 Mon Sep 17 00:00:00 2001 From: rzou Date: Fri, 30 Aug 2024 12:10:15 -0700 Subject: [PATCH 0413/1018] Fix tensor.data access under inference_mode and compile (#134878) Fixes https://github.com/pytorch/pytorch/issues/134798 In the regular Tensor case, when you call Tensor.data, there's a check for if inference mode is active. If it is active, then we don't set the version counter. We replicate this check for Tensor Subclasses (the bug was we were trying to set the version counter on a FakeTensor in inference_mode). Test Plan: - new test Pull Request resolved: https://github.com/pytorch/pytorch/pull/134878 Approved by: https://github.com/bdhirsh --- c10/core/TensorImpl.cpp | 4 +++- test/dynamo/test_misc.py | 11 +++++++++++ 2 files changed, 14 insertions(+), 1 deletion(-) diff --git a/c10/core/TensorImpl.cpp b/c10/core/TensorImpl.cpp index 130292aaa70d6a..40bf133d2587ee 100644 --- a/c10/core/TensorImpl.cpp +++ b/c10/core/TensorImpl.cpp @@ -509,7 +509,9 @@ c10::intrusive_ptr TensorImpl::shallow_copy_and_detach_core( r = (pyobj_slot_.load_pyobj_interpreter())->detach(this); } if (r) { - r->set_version_counter(std::forward(version_counter)); + if (!r->is_inference()) { + r->set_version_counter(std::forward(version_counter)); + } r->set_allow_tensor_metadata_change(allow_tensor_metadata_change); return r; } diff --git a/test/dynamo/test_misc.py b/test/dynamo/test_misc.py index 0057383430ff14..aaf2a846860a47 100644 --- a/test/dynamo/test_misc.py +++ b/test/dynamo/test_misc.py @@ -2000,6 +2000,17 @@ def fn(cfg, x, y): self.assertEqual(cnts.frame_count, 1) self.assertEqual(cnts.op_count, 2) + def test_data_access_in_inference_mode(self): + @torch.compile(fullgraph=True) + def f(x): + y = x.data + return y + + with torch.inference_mode(): + x = torch.randn(3) + y = f(x) + self.assertEqual(y, x) + def test_dataclass_fields(self): @dataclasses.dataclass class MyDataClass: From 9648118b5f8736b221871a2a2cf85f3c33b5b11f Mon Sep 17 00:00:00 2001 From: Shunting Zhang Date: Tue, 3 Sep 2024 16:03:46 -0700 Subject: [PATCH 0414/1018] [inductor] fix compile time regression due the (disabled) loop ordering after fusion (#135071) It's a bit surprised that the code added in Scheduler.fusable_read_and_write would increase compilation time. Here are some number I get from a H100 on BertForMaskedLM: - without the fix, cold start compilation time is around 82s - with the fix, cold start compilation time is around 76s. Pull Request resolved: https://github.com/pytorch/pytorch/pull/135071 Approved by: https://github.com/jansel --- torch/_inductor/scheduler.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/torch/_inductor/scheduler.py b/torch/_inductor/scheduler.py index 360e4809371443..7f9f363fa1da04 100644 --- a/torch/_inductor/scheduler.py +++ b/torch/_inductor/scheduler.py @@ -3076,8 +3076,10 @@ def fusable_read_and_write(self, read: Dep, write: MemoryDep) -> bool: if read_name in self.mutation_renames: read_name = self.mutation_renames[read_name] - if read.num_vars != write.num_vars: - # merge loops + if config.loop_ordering_after_fusion and read.num_vars != write.num_vars: + # Need merge loops if we do loop ordering after fusion since + # we have not merged the loops yet when creating the scheduler + # nodes. read = read.normalize() write = write.normalize() From 0eff47bdcdfccce049c2d6a165017707fce21637 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Wed, 4 Sep 2024 18:41:21 +0000 Subject: [PATCH 0415/1018] Revert "restore CSE'd node metadata in runtime asserts pass (#134516)" This reverts commit 1dfb1052395d908ed6e67288c9357e16022da272. Reverted https://github.com/pytorch/pytorch/pull/134516 on behalf of https://github.com/pianpwk due to breaking NestedTensor test ([comment](https://github.com/pytorch/pytorch/pull/134516#issuecomment-2329738450)) --- test/distributed/test_dynamo_distributed.py | 1 + torch/fx/passes/runtime_assert.py | 102 +++++--------------- 2 files changed, 24 insertions(+), 79 deletions(-) diff --git a/test/distributed/test_dynamo_distributed.py b/test/distributed/test_dynamo_distributed.py index 3ed2bf258e1cb4..134370e41703c3 100644 --- a/test/distributed/test_dynamo_distributed.py +++ b/test/distributed/test_dynamo_distributed.py @@ -423,6 +423,7 @@ def forward(self, x, y): opt_model = torch.compile(dynamic=True)(model) opt_model(torch.randn(20, 512), torch.tensor([12, 13])) + @unittest.expectedFailure # https://github.com/pytorch/pytorch/issues/130534" @config.patch(optimize_ddp=True, capture_dynamic_output_shape_ops=True) def test_unbacked_symbol_splitting_no_binding(self): class Model(nn.Module): diff --git a/torch/fx/passes/runtime_assert.py b/torch/fx/passes/runtime_assert.py index ba77d6ad809e22..e2a7c1bee3db1d 100644 --- a/torch/fx/passes/runtime_assert.py +++ b/torch/fx/passes/runtime_assert.py @@ -1,5 +1,4 @@ # mypy: allow-untyped-defs -import functools import logging import operator import sys @@ -44,18 +43,6 @@ def _get_example_value(node: fx.Node) -> Optional[str]: return None -def _get_example_value_key(node: fx.Node) -> Optional[str]: - """ - actually just run this once at start of pass, based on first node, and constantly use that. - """ - if "example_value" in node.meta: - return "example_value" - elif "val" in node.meta: - return "val" - else: - return None - - def _get_sym_val(node: fx.Node) -> Optional["sympy.Expr"]: val = _get_example_value(node) if isinstance(val, py_sym_types): @@ -105,7 +92,6 @@ def insert_deferred_runtime_asserts( # Import sympy locally import sympy - from torch._export.passes._node_metadata_hook import _set_node_metadata_hook from torch.fx.experimental.symbolic_shapes import ( _has_uninterpretable_sympy_function, CallMethodKey, @@ -159,30 +145,6 @@ def _is_intermediate_tensor_sym_call(node: fx.Node) -> bool: ) ) - # Figure out what key to use, val or example_value - val_key = "val" - for node in graph.nodes: - if "example_value" in node.meta: - val_key = "example_value" - break - elif "val" in node.meta: - break - - def _node_metadata_hook( - node: torch.fx.Node, - stack_trace: Optional[str] = None, - nn_module_stack: Optional[Dict[str, Any]] = None, - ) -> None: - fake_args = [ - _get_example_value(arg) if isinstance(arg, torch.fx.Node) else arg - for arg in node.args - ] - node.meta[val_key] = node.target(*fake_args) # type: ignore[operator] - if stack_trace is not None: - node.meta["stack_trace"] = stack_trace - if nn_module_stack is not None: - node.meta["nn_module_stack"] = nn_module_stack - # Track asserts/checks we've added added_asserts: Set[sympy.Expr] = set() constrained_unbacked_symbols: Set[sympy.Symbol] = set() @@ -285,8 +247,7 @@ def match_symbol(symint, cb): and isinstance(s := symint.node.expr, sympy.Symbol) and s not in expr_to_proxy ): - with _set_node_metadata_hook(gm, _node_metadata_hook): - expr_to_proxy[s] = fx.Proxy(cb()) + expr_to_proxy[s] = fx.Proxy(cb()) log.debug("expr_to_proxy[%s] = %s", s, expr_to_proxy[s]) match_symbol(example_value, lambda: node) @@ -370,15 +331,7 @@ def match_symbol(symint, cb): if _is_intermediate_tensor_sym_call( node ): # reify from input shapes - with _set_node_metadata_hook( - gm, - functools.partial( - _node_metadata_hook, - stack_trace=node.meta.get("stack_trace"), - nn_module_stack=node.meta.get("nn_module_stack"), - ), - ): - expr_to_proxy[sym_expr] = _sympy_interp(expr_to_proxy, sym_expr) # type: ignore[arg-type] + expr_to_proxy[sym_expr] = _sympy_interp(expr_to_proxy, sym_expr) # type: ignore[arg-type] # won't try DCE-ing tensor compute here hash_node = expr_to_proxy[sym_expr].node # type: ignore[arg-type] node.replace_all_uses_with(hash_node) @@ -483,8 +436,7 @@ def go(node, keypath): raise AssertionError(f"unrecognized keypath {keypath}") if s not in expr_to_proxy: - with _set_node_metadata_hook(gm, _node_metadata_hook): - expr_to_proxy[s] = fx.Proxy(go(node, keypath)) + expr_to_proxy[s] = fx.Proxy(go(node, keypath)) log.debug("expr_to_proxy[%s] = %s", s, expr_to_proxy[s]) for i0 in defs: @@ -567,34 +519,26 @@ def convert(s): # TODO(pianpwk): calling sym_constrain_range_for_size or adding bound asserts # raises AOTAutograd errors on cast_symbool_to_symint_guardless - with _set_node_metadata_hook( - gm, - functools.partial( - _node_metadata_hook, - stack_trace=node.meta.get("stack_trace"), - nn_module_stack=node.meta.get("nn_module_stack"), - ), - ): - if (min_val := convert(vr.lower)) is not None: - ge = _sympy_interp(expr_to_proxy, i0 >= min_val).node - graph.call_function( - torch.ops.aten._assert_scalar.default, - ( - ge, - f"Runtime assertion failed for expression {i0 >= min_val} on node '{ge}'", - ), - ) - added_asserts.add(i0 >= min_val) - if (max_val := convert(vr.upper)) is not None: - le = _sympy_interp(expr_to_proxy, i0 <= max_val).node - graph.call_function( - torch.ops.aten._assert_scalar.default, - ( - le, - f"Runtime assertion failed for expression {i0 <= max_val} on node '{le}'", - ), - ) - added_asserts.add(i0 <= max_val) + if (min_val := convert(vr.lower)) is not None: + ge = _sympy_interp(expr_to_proxy, i0 >= min_val).node + graph.call_function( + torch.ops.aten._assert_scalar.default, + ( + ge, + f"Runtime assertion failed for expression {i0 >= min_val} on node '{ge}'", + ), + ) + added_asserts.add(i0 >= min_val) + if (max_val := convert(vr.upper)) is not None: + le = _sympy_interp(expr_to_proxy, i0 <= max_val).node + graph.call_function( + torch.ops.aten._assert_scalar.default, + ( + le, + f"Runtime assertion failed for expression {i0 <= max_val} on node '{le}'", + ), + ) + added_asserts.add(i0 <= max_val) constrained_unbacked_symbols.add(i0) add_runtime_asserts(ras) From a4e298e33d3bda7ffb71d2c4dfe9eb45d8d74137 Mon Sep 17 00:00:00 2001 From: Shivam Raikundalia Date: Wed, 4 Sep 2024 18:41:48 +0000 Subject: [PATCH 0416/1018] [Profiler] Fix Raw Metadata Iterator (#135096) Summary: D62008788 added an extra parameter to the RawTensorMetadata struct. For some reason this causes some corrupted accesses in other tests as described in T200685032. Once this is removed the tests pass. Going forward we need to document how to add parameters to this portion of the code as the AppendOnlyLists seem to be very rigid. Test Plan: Ran all the tests locally and they all passed. Differential Revision: D62171089 Pull Request resolved: https://github.com/pytorch/pytorch/pull/135096 Approved by: https://github.com/aaronenyeshi --- torch/csrc/profiler/collection.cpp | 7 +++---- torch/csrc/profiler/collection.h | 1 - 2 files changed, 3 insertions(+), 5 deletions(-) diff --git a/torch/csrc/profiler/collection.cpp b/torch/csrc/profiler/collection.cpp index a4f53a5e10f99f..67cfaab3b81957 100644 --- a/torch/csrc/profiler/collection.cpp +++ b/torch/csrc/profiler/collection.cpp @@ -33,14 +33,13 @@ RawTensorMetadataBase::RawTensorMetadataBase(const at::Tensor& t) : data_{t.has_storage() ? t.storage().data() : nullptr}, dtype_{t.scalar_type()}, layout_{t.layout()}, - size_dim_{static_cast(t.sizes().size())}, - stride_dim_{static_cast(t.strides().size())} { + size_dim_{static_cast(t.sizes().size())} { TORCH_INTERNAL_ASSERT_DEBUG_ONLY( t.sizes().size() <= std::numeric_limits::max(), "Cannot profile Tensors of size > uint32 max. Got dim: ", t.sizes().size()); TORCH_INTERNAL_ASSERT_DEBUG_ONLY( - t.sizes().size() != t.strides().size(), + t.sizes().size() == t.strides().size(), "Tensor has mismatching sizes and strides. Sizes: ", t.sizes().size(), " Strides: ", @@ -205,7 +204,7 @@ auto InputOutputEncoder::getIValueGenerator(const IOType& io_type) { sizes.push_back(*tensor_size_strides_it++); } if (raw_metadata.layout_ == at::kStrided) { - for (C10_UNUSED const auto _ : c10::irange(raw_metadata.stride_dim_)) { + for (C10_UNUSED const auto _ : c10::irange(raw_metadata.size_dim_)) { if (tensor_size_strides_it.exhausted()) { LOG(WARNING) << "Expected Tensor Strides mismatch with raw Tensor metadata. Reported shapes may be inaccurate!"; diff --git a/torch/csrc/profiler/collection.h b/torch/csrc/profiler/collection.h index e2bb603c387cc5..abaa9a845082b6 100644 --- a/torch/csrc/profiler/collection.h +++ b/torch/csrc/profiler/collection.h @@ -48,7 +48,6 @@ struct TORCH_API RawTensorMetadataBase { c10::ScalarType dtype_{c10::ScalarType::Undefined}; c10::Layout layout_{c10::Layout::Strided}; uint32_t size_dim_{0}; - uint32_t stride_dim_{0}; }; // Collected during profiling. From 1ee68642b578dc56fafbe4a5a5f329f5084f2333 Mon Sep 17 00:00:00 2001 From: Saurabh Mishra Date: Wed, 4 Sep 2024 18:49:34 +0000 Subject: [PATCH 0417/1018] [AIInfra][DCP] All gather keys checkpoint utils bug fix (#135045) Summary: All gather keys checkpoint utils bug fix. Dist. get_world_size should have the process group passed in to avoid inconsistent world size in case the process group has changed. This is common in the tests. Test Plan: UTs Reviewed By: Saiteja64 Differential Revision: D61578832 Pull Request resolved: https://github.com/pytorch/pytorch/pull/135045 Approved by: https://github.com/MeetVadakkanchery, https://github.com/LucasLLC --- torch/distributed/checkpoint/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch/distributed/checkpoint/utils.py b/torch/distributed/checkpoint/utils.py index 5bc171a00dd2d9..cb90dd11912159 100644 --- a/torch/distributed/checkpoint/utils.py +++ b/torch/distributed/checkpoint/utils.py @@ -44,7 +44,7 @@ def _all_gather_keys( ) -> List[Any]: """Gathers all keys, and returns them sorted.""" keys = list(local_dict.keys()) - gathered_keys: List[List[Any]] = [None] * dist.get_world_size() # type: ignore[list-item] + gathered_keys: List[List[Any]] = [None] * dist.get_world_size(group) # type: ignore[list-item] dist.all_gather_object(gathered_keys, keys, group=group) return sorted(set(itertools.chain.from_iterable(gathered_keys))) From 10c1988c3c261c654457dba096dac8474afd3637 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Wed, 4 Sep 2024 19:44:29 +0000 Subject: [PATCH 0418/1018] Revert "Add support for 32KB multi_tensor_apply kernel arguments (#134373)" This reverts commit 08184aa85cf183198ebdf2fd7a49fe7bc4842c13. Reverted https://github.com/pytorch/pytorch/pull/134373 on behalf of https://github.com/drisspg due to See https://github.com/pytorch/pytorch/issues/135126 for more details ([comment](https://github.com/pytorch/pytorch/pull/134373#issuecomment-2329839011)) --- aten/src/ATen/native/cuda/AmpKernels.cu | 33 ++- .../ATen/native/cuda/ForeachBinaryOpList.cu | 88 +++---- .../ATen/native/cuda/ForeachBinaryOpScalar.cu | 42 ++-- .../native/cuda/ForeachBinaryOpScalarList.cu | 44 ++-- .../cuda/ForeachBinaryOpScalarTensor.cu | 46 ++-- aten/src/ATen/native/cuda/ForeachFunctors.cuh | 117 +++------ .../ATen/native/cuda/ForeachPointwiseOp.cu | 84 +++---- aten/src/ATen/native/cuda/ForeachReduceOp.cu | 81 ++---- aten/src/ATen/native/cuda/ForeachTernaryOp.cu | 80 +++--- aten/src/ATen/native/cuda/ForeachUnaryOp.cu | 55 ++-- aten/src/ATen/native/cuda/FusedSgdKernel.cu | 117 ++++----- .../src/ATen/native/cuda/MultiTensorApply.cpp | 25 -- .../src/ATen/native/cuda/MultiTensorApply.cuh | 235 +++++------------- .../native/cuda/fused_adam_amsgrad_impl.cu | 66 ++--- aten/src/ATen/native/cuda/fused_adam_impl.cu | 66 ++--- .../src/ATen/native/cuda/fused_adam_utils.cuh | 10 +- .../native/cuda/fused_adamw_amsgrad_impl.cu | 66 ++--- aten/src/ATen/native/cuda/fused_adamw_impl.cu | 66 ++--- build_variables.bzl | 1 - 19 files changed, 470 insertions(+), 852 deletions(-) delete mode 100644 aten/src/ATen/native/cuda/MultiTensorApply.cpp diff --git a/aten/src/ATen/native/cuda/AmpKernels.cu b/aten/src/ATen/native/cuda/AmpKernels.cu index 07cb84f95924ec..8c161ca627229d 100644 --- a/aten/src/ATen/native/cuda/AmpKernels.cu +++ b/aten/src/ATen/native/cuda/AmpKernels.cu @@ -157,24 +157,21 @@ void _amp_foreach_non_finite_check_and_unscale_cuda_(TensorList scaled_grads, using opmath_t = at::opmath_type; // multi_tensor_apply guards onto tensor_lists[0][0], no need to guard explicitly. - DISPATCH_MULTI_TENSOR_APPLY([&]() { - multi_tensor_apply<1>(tensor_lists, - UnaryOpFunctor(), - [found_inf_ptr, inv_scale_ptr] GPU_LAMBDA (opmath_t val) -> opmath_t { - // There is a slight asymmetry here with the TensorIterator kernel above. - // MTA Functors ensure val comes in as opmath_t rather than scalar_t. - if (!isfinite_ensure_cuda_math(val)) { - *found_inf_ptr = 1.f; - } - // Every thread accesses inv_scale, but it will hit in cache. - const auto inv_scale_val = *inv_scale_ptr; - return static_cast(inv_scale_val == 1.f ? val : val * inv_scale_val); - }); - }); + multi_tensor_apply<1>(tensor_lists, + UnaryOpFunctor(), + [found_inf_ptr, inv_scale_ptr] GPU_LAMBDA (opmath_t val) -> opmath_t { + // There is a slight asymmetry here with the TensorIterator kernel above. + // MTA Functors ensure val comes in as opmath_t rather than scalar_t. + if (!isfinite_ensure_cuda_math(val)) { + *found_inf_ptr = 1.f; + } + // Every thread accesses inv_scale, but it will hit in cache. + const auto inv_scale_val = *inv_scale_ptr; + return static_cast(inv_scale_val == 1.f ? val : val * inv_scale_val); + }); }); } diff --git a/aten/src/ATen/native/cuda/ForeachBinaryOpList.cu b/aten/src/ATen/native/cuda/ForeachBinaryOpList.cu index 90c180cba3119d..533aa38c04cf5e 100644 --- a/aten/src/ATen/native/cuda/ForeachBinaryOpList.cu +++ b/aten/src/ATen/native/cuda/ForeachBinaryOpList.cu @@ -41,18 +41,15 @@ std::vector foreach_tensor_list_op( tensor_lists.emplace_back(std::move(vec_res)); using opmath_t = at::opmath_type; - DISPATCH_MULTI_TENSOR_APPLY([&]() { - multi_tensor_apply<3>( - tensor_lists, - BinaryOpListAlphaFunctor< - T, - /* depth */ 3, - /* r_args_depth */ 2, - /* res_arg_index */ 2, - large_kernel_arg>(), - Op(), - alpha.to()); - }); + multi_tensor_apply<3>( + tensor_lists, + BinaryOpListAlphaFunctor< + T, + /* depth */ 3, + /* r_args_depth */ 2, + /* res_arg_index */ 2>(), + Op(), + alpha.to()); return tensor_lists[2]; } @@ -67,18 +64,15 @@ void foreach_tensor_list_op_( tensor_lists.emplace_back(tensors2.vec()); using opmath_t = at::opmath_type; - DISPATCH_MULTI_TENSOR_APPLY([&]() { - multi_tensor_apply<2>( - tensor_lists, - BinaryOpListAlphaFunctor< - T, - /* depth */ 2, - /* r_args_depth */ 2, - /* res_arg_index */ 0, - large_kernel_arg>(), - Op(), - alpha.to()); - }); + multi_tensor_apply<2>( + tensor_lists, + BinaryOpListAlphaFunctor< + T, + /* depth */ 2, + /* r_args_depth */ 2, + /* res_arg_index */ 0>(), + Op(), + alpha.to()); increment_version(tensors1); } @@ -337,15 +331,13 @@ template < typename src_t, int depth, int r_args_depth, - int res_arg_index, - bool large_kernel_arg> + int res_arg_index> struct CopyFunctor { - static constexpr bool use_large_kernel_arg = large_kernel_arg; static_assert(depth == 2 && r_args_depth == 1 && res_arg_index == 1); template __device__ __forceinline__ void operator()( int chunk_size, - TensorListMetadata& tl, + TensorListMetadata& tl, Op op) { const auto tensor_loc = tl.block_to_tensor[blockIdx.x]; const auto chunk_idx = tl.block_to_chunk[blockIdx.x]; @@ -428,17 +420,14 @@ void foreach_tensor_copy_list_kernel_cuda_( using opmath_t = at::opmath_type; AT_DISPATCH_SOURCE_TYPES(src[0].scalar_type(), "foreach_tensor_copy", [&] { if constexpr (std::is_same_v) { - DISPATCH_MULTI_TENSOR_APPLY([&]() { - multi_tensor_apply<2>( - tensor_lists, - UnaryOpFunctor< - scalar_t, - /* depth */ 2, - /* r_args_depth */ 1, - /* res_arg_index */ 1, - large_kernel_arg>(), - Copy()); - }); + multi_tensor_apply<2>( + tensor_lists, + UnaryOpFunctor< + scalar_t, + /* depth */ 2, + /* r_args_depth */ 1, + /* res_arg_index */ 1>(), + Copy()); } else { // Ref: // https://github.com/pytorch/pytorch/blob/656134c38f4737d13c3f43fc5c59470bc23c1d2f/aten/src/ATen/native/Copy.cpp#L299-L301 @@ -446,18 +435,15 @@ void foreach_tensor_copy_list_kernel_cuda_( TORCH_WARN_ONCE( "Casting complex values to real discards the imaginary part"); } - DISPATCH_MULTI_TENSOR_APPLY([&]() { - multi_tensor_apply<2>( - tensor_lists, - CopyFunctor< - scalar_t, - src_t, - /* depth */ 2, - /* r_args_depth */ 1, - /* res_arg_index */ 1, - large_kernel_arg>(), - Copy()); - }); + multi_tensor_apply<2>( + tensor_lists, + CopyFunctor< + scalar_t, + src_t, + /* depth */ 2, + /* r_args_depth */ 1, + /* res_arg_index */ 1>(), + Copy()); } }); }); diff --git a/aten/src/ATen/native/cuda/ForeachBinaryOpScalar.cu b/aten/src/ATen/native/cuda/ForeachBinaryOpScalar.cu index 1b0045daec09c2..80d748dd3579b5 100644 --- a/aten/src/ATen/native/cuda/ForeachBinaryOpScalar.cu +++ b/aten/src/ATen/native/cuda/ForeachBinaryOpScalar.cu @@ -36,18 +36,15 @@ std::vector foreach_binary_op( tensor_lists.emplace_back(std::move(vec_res)); using opmath_t = at::opmath_type; - DISPATCH_MULTI_TENSOR_APPLY([&]() { - multi_tensor_apply<2>( - tensor_lists, - BinaryOpScalarFunctor< - T, - /* depth */ 2, - /* r_args_depth */ 1, - /* res_arg_index */ 1, - large_kernel_arg>(), - Op(), - scalar.to()); - }); + multi_tensor_apply<2>( + tensor_lists, + BinaryOpScalarFunctor< + T, + /* depth */ 2, + /* r_args_depth */ 1, + /* res_arg_index */ 1>(), + Op(), + scalar.to()); return tensor_lists[1]; } @@ -57,18 +54,15 @@ void foreach_binary_op_(TensorList tensors, const Scalar& scalar) { tensor_lists.emplace_back(tensors.vec()); using opmath_t = at::opmath_type; - DISPATCH_MULTI_TENSOR_APPLY([&]() { - multi_tensor_apply<1>( - tensor_lists, - BinaryOpScalarFunctor< - T, - /* depth */ 1, - /* r_args_depth */ 1, - /* res_arg_index */ 0, - large_kernel_arg>(), - Op(), - scalar.to()); - }); + multi_tensor_apply<1>( + tensor_lists, + BinaryOpScalarFunctor< + T, + /* depth */ 1, + /* r_args_depth */ 1, + /* res_arg_index */ 0>(), + Op(), + scalar.to()); increment_version(tensors); } diff --git a/aten/src/ATen/native/cuda/ForeachBinaryOpScalarList.cu b/aten/src/ATen/native/cuda/ForeachBinaryOpScalarList.cu index c423006719294d..dcb93188b5e690 100644 --- a/aten/src/ATen/native/cuda/ForeachBinaryOpScalarList.cu +++ b/aten/src/ATen/native/cuda/ForeachBinaryOpScalarList.cu @@ -36,19 +36,16 @@ std::vector foreach_binary_op( tensor_lists.emplace_back(vec_res); using opmath_t = at::opmath_type; - DISPATCH_MULTI_TENSOR_APPLY([&]() { - multi_tensor_apply<2, opmath_t>( - tensor_lists, - scalars, - BinaryOpScalarListFunctor< - T, - /* depth */ 2, - /* r_args_depth */ 1, - /* res_arg_index */ 1, - large_kernel_arg>(), - - Op()); - }); + multi_tensor_apply<2, opmath_t>( + tensor_lists, + scalars, + BinaryOpScalarListFunctor< + T, + /* depth */ 2, + /* r_args_depth */ 1, + /* res_arg_index */ 1>(), + + Op()); return tensor_lists[1]; } @@ -58,18 +55,15 @@ void foreach_binary_op_(TensorList tensors, at::ArrayRef scalars) { tensor_lists.emplace_back(tensors.vec()); using opmath_t = at::opmath_type; - DISPATCH_MULTI_TENSOR_APPLY([&]() { - multi_tensor_apply<1, opmath_t>( - tensor_lists, - scalars, - BinaryOpScalarListFunctor< - T, - /* depth */ 1, - /* r_args_depth */ 1, - /* res_arg_index */ 0, - large_kernel_arg>(), - Op()); - }); + multi_tensor_apply<1, opmath_t>( + tensor_lists, + scalars, + BinaryOpScalarListFunctor< + T, + /* depth */ 1, + /* r_args_depth */ 1, + /* res_arg_index */ 0>(), + Op()); increment_version(tensors); } diff --git a/aten/src/ATen/native/cuda/ForeachBinaryOpScalarTensor.cu b/aten/src/ATen/native/cuda/ForeachBinaryOpScalarTensor.cu index 4163b30b2889eb..ad5eeee5ebec41 100644 --- a/aten/src/ATen/native/cuda/ForeachBinaryOpScalarTensor.cu +++ b/aten/src/ATen/native/cuda/ForeachBinaryOpScalarTensor.cu @@ -46,19 +46,16 @@ std::vector foreach_binary_op( tensor_lists.emplace_back(std::move(vec_res)); using opmath_t = at::opmath_type; - DISPATCH_MULTI_TENSOR_APPLY([&]() { - multi_tensor_apply<2>( - tensor_lists, - BinaryOpScalarTensorFunctor< - T, - /* depth */ 2, - /* r_args_depth */ 1, - /* res_arg_index */ 1, - large_kernel_arg>(), - Op(), - scalar.data_ptr(), - alpha.to()); - }); + multi_tensor_apply<2>( + tensor_lists, + BinaryOpScalarTensorFunctor< + T, + /* depth */ 2, + /* r_args_depth */ 1, + /* res_arg_index */ 1>(), + Op(), + scalar.data_ptr(), + alpha.to()); return tensor_lists[1]; } @@ -84,19 +81,16 @@ void foreach_binary_op_( tensor_lists.emplace_back(tensors.vec()); using opmath_t = at::opmath_type; - DISPATCH_MULTI_TENSOR_APPLY([&]() { - multi_tensor_apply<1>( - tensor_lists, - BinaryOpScalarTensorFunctor< - T, - /* depth */ 1, - /* r_args_depth */ 1, - /* res_arg_index */ 0, - large_kernel_arg>(), - Op(), - scalar.data_ptr(), - alpha.to()); - }); + multi_tensor_apply<1>( + tensor_lists, + BinaryOpScalarTensorFunctor< + T, + /* depth */ 1, + /* r_args_depth */ 1, + /* res_arg_index */ 0>(), + Op(), + scalar.data_ptr(), + alpha.to()); increment_version(tensors); } diff --git a/aten/src/ATen/native/cuda/ForeachFunctors.cuh b/aten/src/ATen/native/cuda/ForeachFunctors.cuh index fb8d8492b039e0..55e4fd7a598907 100644 --- a/aten/src/ATen/native/cuda/ForeachFunctors.cuh +++ b/aten/src/ATen/native/cuda/ForeachFunctors.cuh @@ -18,10 +18,10 @@ inline void increment_version(TensorList tensors) { } // Initializes args and checks if all args are aligned -template +template __device__ bool init_args( T** args, - TensorListMetadata& tl, + TensorListMetadata& tl, const int64_t chunk_idx, const int64_t chunk_size, const int64_t tensor_loc) { @@ -38,10 +38,10 @@ __device__ bool init_args( } // Initializes args and checks if all args are aligned -template +template __device__ bool init_args( T** args, - TensorListScalarListMetadata& tl, + TensorListScalarListMetadata& tl, const int64_t chunk_idx, const int64_t chunk_size, const int64_t tensor_loc) { @@ -57,10 +57,10 @@ __device__ bool init_args( return all_aligned; } -template +template __device__ bool init_args( T** args, - FusedOptimizerTensorListMetadata& tl, + FusedOptimizerTensorListMetadata& tl, const int64_t chunk_idx, const int64_t chunk_size, const int64_t tensor_loc) { @@ -203,19 +203,13 @@ __device__ __forceinline__ void pointwise_op_scalar( // // Binary Functors // -template < - typename T, - int depth, - int r_args_depth, - int res_arg_index, - bool large_kernel_arg> +template struct BinaryOpScalarFunctor { - static constexpr bool use_large_kernel_arg = large_kernel_arg; using opmath_t = at::opmath_type; template __device__ __forceinline__ void operator()( int chunk_size, - TensorListMetadata& tl, + TensorListMetadata& tl, Op op, opmath_t scalar) { const int tensor_loc = tl.block_to_tensor[blockIdx.x]; @@ -233,19 +227,13 @@ struct BinaryOpScalarFunctor { } }; -template < - typename T, - int depth, - int r_args_depth, - int res_arg_index, - bool large_kernel_arg> +template struct BinaryOpScalarListFunctor { - static constexpr bool use_large_kernel_arg = large_kernel_arg; using opmath_t = at::opmath_type; template __device__ __forceinline__ void operator()( int chunk_size, - TensorListScalarListMetadata& tl, + TensorListScalarListMetadata& tl, Op op) { const auto tensor_loc = tl.block_to_tensor[blockIdx.x]; const auto chunk_idx = tl.block_to_chunk[blockIdx.x]; @@ -263,19 +251,13 @@ struct BinaryOpScalarListFunctor { } }; -template < - typename T, - int depth, - int r_args_depth, - int res_arg_index, - bool large_kernel_arg> +template struct BinaryOpListAlphaFunctor { - static constexpr bool use_large_kernel_arg = large_kernel_arg; using opmath_t = at::opmath_type; template __device__ __forceinline__ void operator()( int chunk_size, - TensorListMetadata& tl, + TensorListMetadata& tl, Op op, opmath_t alpha) { const auto tensor_loc = tl.block_to_tensor[blockIdx.x]; @@ -321,19 +303,13 @@ struct BinaryOpListAlphaFunctor { } }; -template < - typename T, - int depth, - int r_args_depth, - int res_arg_index, - bool large_kernel_arg> +template struct BinaryOpScalarTensorFunctor { - static constexpr bool use_large_kernel_arg = large_kernel_arg; using opmath_t = at::opmath_type; template __device__ __forceinline__ void operator()( int chunk_size, - TensorListMetadata& tl, + TensorListMetadata& tl, Op op, T* scalar, opmath_t alpha) { @@ -385,17 +361,11 @@ struct BinaryOpScalarTensorFunctor { // Unary Functors // -template < - typename T, - int depth, - int r_args_depth, - int res_arg_index, - bool large_kernel_arg> +template struct ZeroFunctor { - static constexpr bool use_large_kernel_arg = large_kernel_arg; __device__ __forceinline__ void operator()( int chunk_size, - TensorListMetadata<1, large_kernel_arg>& tl) { + TensorListMetadata<1>& tl) { const auto tensor_loc = tl.block_to_tensor[blockIdx.x]; const auto chunk_idx = tl.block_to_chunk[blockIdx.x]; auto n = tl.numel_for_tensor[tensor_loc]; @@ -431,19 +401,13 @@ struct ZeroFunctor { } }; -template < - typename T, - int depth, - int r_args_depth, - int res_arg_index, - bool large_kernel_arg> +template struct UnaryOpFunctor { - static constexpr bool use_large_kernel_arg = large_kernel_arg; using opmath_t = at::opmath_type; template __device__ __forceinline__ void operator()( int chunk_size, - TensorListMetadata& tl, + TensorListMetadata& tl, Op op) { const auto tensor_loc = tl.block_to_tensor[blockIdx.x]; const auto chunk_idx = tl.block_to_chunk[blockIdx.x]; @@ -489,19 +453,13 @@ struct UnaryOpFunctor { // Pointwise Functors // -template < - typename T, - int depth, - int r_args_depth, - int res_arg_index, - bool large_kernel_arg> +template struct PointwiseOpScalarFunctor { - static constexpr bool use_large_kernel_arg = large_kernel_arg; using opmath_t = at::opmath_type; template __device__ __forceinline__ void operator()( int chunk_size, - TensorListMetadata& tl, + TensorListMetadata& tl, Op op, opmath_t scalar) { const auto tensor_loc = tl.block_to_tensor[blockIdx.x]; @@ -519,19 +477,13 @@ struct PointwiseOpScalarFunctor { } }; -template < - typename T, - int depth, - int r_args_depth, - int res_arg_index, - bool large_kernel_arg> +template struct PointwiseOpScalarListFunctor { - static constexpr bool use_large_kernel_arg = large_kernel_arg; using opmath_t = at::opmath_type; template __device__ __forceinline__ void operator()( int chunk_size, - TensorListScalarListMetadata& tl, + TensorListScalarListMetadata& tl, Op op) { const auto tensor_loc = tl.block_to_tensor[blockIdx.x]; const auto chunk_idx = tl.block_to_chunk[blockIdx.x]; @@ -549,14 +501,13 @@ struct PointwiseOpScalarListFunctor { } }; -template +template struct PointwiseOpListFunctor { - static constexpr bool use_large_kernel_arg = large_kernel_arg; using opmath_t = at::opmath_type; template __device__ __forceinline__ void operator()( int chunk_size, - TensorListMetadata& tl, + TensorListMetadata& tl, Op op) { const auto tensor_loc = tl.block_to_tensor[blockIdx.x]; const auto chunk_idx = tl.block_to_chunk[blockIdx.x]; @@ -601,19 +552,13 @@ struct PointwiseOpListFunctor { } }; -template < - typename T, - int depth, - int r_args_depth, - int res_arg_index, - bool large_kernel_arg> +template struct TernaryOpListFunctor { - static constexpr bool use_large_kernel_arg = large_kernel_arg; using opmath_t = at::opmath_type; template __device__ __forceinline__ void operator()( int chunk_size, - TensorListMetadata& tl, + TensorListMetadata& tl, Op op) { static_assert(depth == 3 || depth == 4, ""); static_assert(depth >= r_args_depth, ""); @@ -661,19 +606,13 @@ struct TernaryOpListFunctor { } }; -template < - typename T, - int depth, - int r_args_depth, - int res_arg_index, - bool large_kernel_arg> +template struct TernaryOpScalarFunctor { - static constexpr bool use_large_kernel_arg = large_kernel_arg; using opmath_t = at::opmath_type; template __device__ __forceinline__ void operator()( int chunk_size, - TensorListMetadata& tl, + TensorListMetadata& tl, Op op, opmath_t alpha) { static_assert(depth == 2 || depth == 3, ""); diff --git a/aten/src/ATen/native/cuda/ForeachPointwiseOp.cu b/aten/src/ATen/native/cuda/ForeachPointwiseOp.cu index bad152c322878b..7a3276c44750a6 100644 --- a/aten/src/ATen/native/cuda/ForeachPointwiseOp.cu +++ b/aten/src/ATen/native/cuda/ForeachPointwiseOp.cu @@ -46,18 +46,15 @@ std::vector foreach_pointwise_op( "foreach_pointwise_op_cuda", [&]() { using opmath_t = at::opmath_type; - DISPATCH_MULTI_TENSOR_APPLY([&]() { - multi_tensor_apply<4>( - tensor_lists, - PointwiseOpScalarFunctor< - scalar_t, - /* depth */ 4, - /* r_args_depth */ 3, - /* res_arg_index */ 3, - large_kernel_arg>(), - Op(), - scalar.to()); - }); + multi_tensor_apply<4>( + tensor_lists, + PointwiseOpScalarFunctor< + scalar_t, + /* depth */ 4, + /* r_args_depth */ 3, + /* res_arg_index */ 3>(), + Op(), + scalar.to()); }); return tensor_lists[3]; @@ -81,18 +78,15 @@ void foreach_pointwise_op_( "foreach_pointwise_op__cuda", [&]() { using opmath_t = at::opmath_type; - DISPATCH_MULTI_TENSOR_APPLY([&]() { - multi_tensor_apply<3>( - tensor_lists, - PointwiseOpScalarFunctor< - scalar_t, - /* depth */ 3, - /* r_args_depth */ 3, - /* res_arg_index */ 0, - large_kernel_arg>(), - Op(), - scalar.to()); - }); + multi_tensor_apply<3>( + tensor_lists, + PointwiseOpScalarFunctor< + scalar_t, + /* depth */ 3, + /* r_args_depth */ 3, + /* res_arg_index */ 0>(), + Op(), + scalar.to()); }); increment_version(input); } @@ -116,18 +110,15 @@ void foreach_pointwise_op_( "foreach_pointwise_op__cuda", [&]() { using opmath_t = at::opmath_type; - DISPATCH_MULTI_TENSOR_APPLY([&]() { - multi_tensor_apply<3, opmath_t>( - tensor_lists, - scalars, - PointwiseOpScalarListFunctor< - scalar_t, - /* depth */ 3, - /* r_args_depth */ 3, - /* res_arg_index */ 0, - large_kernel_arg>(), - Op()); - }); + multi_tensor_apply<3, opmath_t>( + tensor_lists, + scalars, + PointwiseOpScalarListFunctor< + scalar_t, + /* depth */ 3, + /* r_args_depth */ 3, + /* res_arg_index */ 0>(), + Op()); }); increment_version(input); } @@ -158,18 +149,15 @@ std::vector foreach_pointwise_op( "foreach_pointwise_op_cuda", [&]() { using opmath_t = at::opmath_type; - DISPATCH_MULTI_TENSOR_APPLY([&]() { - multi_tensor_apply<4, opmath_t>( - tensor_lists, - scalars, - PointwiseOpScalarListFunctor< - scalar_t, - /* depth */ 4, - /* r_args_depth */ 3, - /* res_arg_index */ 3, - large_kernel_arg>(), - Op()); - }); + multi_tensor_apply<4, opmath_t>( + tensor_lists, + scalars, + PointwiseOpScalarListFunctor< + scalar_t, + /* depth */ 4, + /* r_args_depth */ 3, + /* res_arg_index */ 3>(), + Op()); }); return tensor_lists[3]; diff --git a/aten/src/ATen/native/cuda/ForeachReduceOp.cu b/aten/src/ATen/native/cuda/ForeachReduceOp.cu index 1ce18b71aed3e4..61793fd5f9e08c 100644 --- a/aten/src/ATen/native/cuda/ForeachReduceOp.cu +++ b/aten/src/ATen/native/cuda/ForeachReduceOp.cu @@ -50,13 +50,11 @@ template < typename T, int depth = 1, int r_args_depth = 1, - int res_arg_index = 0, - bool large_kernel_arg = false> + int res_arg_index = 0> struct LpMaxFunctor { - static constexpr bool use_large_kernel_arg = large_kernel_arg; __device__ __forceinline__ void operator()( int chunk_size, - TensorListMetadata& tl, + TensorListMetadata& tl, T* output_per_tensor_ptr, const int max_chunks_per_tensor) { const auto tensor_loc = tl.block_to_tensor[blockIdx.x]; @@ -180,13 +178,11 @@ std::vector foreach_tensor_max_cuda(TensorList tensors) { tensor_lists[0][0].scalar_type(), "foreach_tensor_max_cuda_scalar_type", [&]() { - DISPATCH_MULTI_TENSOR_APPLY([&]() { - multi_tensor_apply<1>( - tensor_lists, - LpMaxFunctor(), - output_per_tensor.mutable_data_ptr(), - max_chunks_per_tensor); - }); + multi_tensor_apply<1>( + tensor_lists, + LpMaxFunctor(), + output_per_tensor.mutable_data_ptr(), + max_chunks_per_tensor); C10_CUDA_KERNEL_LAUNCH_CHECK(); const at::cuda::OptionalCUDAGuard device_guard( @@ -243,14 +239,12 @@ template < typename out_t, int depth = 1, int r_args_depth = 1, - int res_arg_index = 0, - bool large_kernel_arg = false> + int res_arg_index = 0> struct LpNormFunctor { - static constexpr bool use_large_kernel_arg = large_kernel_arg; using out_opmath_t = typename at::opmath_type; __device__ __forceinline__ void operator()( int chunk_size, - TensorListMetadata& tl, + TensorListMetadata& tl, out_opmath_t* output_per_tensor_ptr, const int max_chunks_per_tensor) { const auto tensor_loc = tl.block_to_tensor[blockIdx.x]; @@ -482,50 +476,23 @@ std::vector foreach_tensor_norm_cuda( output_dtype, "foreach_tensor_norm_cuda_out_dtype", [&]() { using out_opmath_t = typename at::opmath_type; if (p == static_cast(1)) { - DISPATCH_MULTI_TENSOR_APPLY([&]() { - multi_tensor_apply<1>( - tensor_lists, - LpNormFunctor< - scalar_t, - NormType::L1, - out_t, - 1, - 1, - 0, - large_kernel_arg>(), - output_per_tensor.mutable_data_ptr(), - max_chunks_per_tensor); - }); + multi_tensor_apply<1>( + tensor_lists, + LpNormFunctor(), + output_per_tensor.mutable_data_ptr(), + max_chunks_per_tensor); } else if (p == static_cast(2)) { - DISPATCH_MULTI_TENSOR_APPLY([&]() { - multi_tensor_apply<1>( - tensor_lists, - LpNormFunctor< - scalar_t, - NormType::L2, - out_t, - 1, - 1, - 0, - large_kernel_arg>(), - output_per_tensor.mutable_data_ptr(), - max_chunks_per_tensor); - }); + multi_tensor_apply<1>( + tensor_lists, + LpNormFunctor(), + output_per_tensor.mutable_data_ptr(), + max_chunks_per_tensor); } else if (p == std::numeric_limits::infinity()) { - DISPATCH_MULTI_TENSOR_APPLY([&]() { - multi_tensor_apply<1>( - tensor_lists, - LpNormFunctor< - scalar_t, - NormType::LInf, - out_t, - 1, - 1, - 0, - large_kernel_arg>(), - output_per_tensor.mutable_data_ptr(), - max_chunks_per_tensor); - }); + multi_tensor_apply<1>( + tensor_lists, + LpNormFunctor(), + output_per_tensor.mutable_data_ptr(), + max_chunks_per_tensor); } C10_CUDA_KERNEL_LAUNCH_CHECK(); const at::cuda::OptionalCUDAGuard device_guard( diff --git a/aten/src/ATen/native/cuda/ForeachTernaryOp.cu b/aten/src/ATen/native/cuda/ForeachTernaryOp.cu index e73862d3be6692..e13f2015f1d819 100644 --- a/aten/src/ATen/native/cuda/ForeachTernaryOp.cu +++ b/aten/src/ATen/native/cuda/ForeachTernaryOp.cu @@ -46,17 +46,14 @@ std::vector foreach_tensor_lerp_ternary_cuda( "foreach_tensor_lerp_ternary_cuda", [&]() { using opmath_t = typename at::opmath_type; - DISPATCH_MULTI_TENSOR_APPLY([&]() { - multi_tensor_apply<4>( - tensor_lists, - TernaryOpListFunctor< - scalar_t, - /* depth */ 4, - /* r_args_depth */ 3, - /* res_arg_index */ 3, - large_kernel_arg>(), - LerpFunctor()); - }); + multi_tensor_apply<4>( + tensor_lists, + TernaryOpListFunctor< + scalar_t, + /* depth */ 4, + /* r_args_depth */ 3, + /* res_arg_index */ 3>(), + LerpFunctor()); }); return tensor_lists[3]; @@ -80,17 +77,14 @@ void foreach_tensor_lerp_ternary_cuda_( "foreach_tensor_lerp_ternary_cuda_", [&]() { using opmath_t = typename at::opmath_type; - DISPATCH_MULTI_TENSOR_APPLY([&]() { - multi_tensor_apply<3>( - tensor_lists, - TernaryOpListFunctor< - scalar_t, - /* depth */ 3, - /* r_args_depth */ 3, - /* res_arg_index */ 0, - large_kernel_arg>(), - LerpFunctor()); - }); + multi_tensor_apply<3>( + tensor_lists, + TernaryOpListFunctor< + scalar_t, + /* depth */ 3, + /* r_args_depth */ 3, + /* res_arg_index */ 0>(), + LerpFunctor()); }); increment_version(tensors1); } @@ -119,18 +113,15 @@ std::vector foreach_tensor_lerp_list_cuda( "foreach_tensor_lerp_scalar_cuda", [&]() { using opmath_t = typename at::opmath_type; - DISPATCH_MULTI_TENSOR_APPLY([&]() { - multi_tensor_apply<3>( - tensor_lists, - TernaryOpScalarFunctor< - scalar_t, - /* depth */ 3, - /* r_args_depth */ 2, - /* res_arg_index */ 2, - large_kernel_arg>(), - LerpFunctor(), - weight.to()); - }); + multi_tensor_apply<3>( + tensor_lists, + TernaryOpScalarFunctor< + scalar_t, + /* depth */ 3, + /* r_args_depth */ 2, + /* res_arg_index */ 2>(), + LerpFunctor(), + weight.to()); }); return tensor_lists[2]; @@ -154,18 +145,15 @@ void foreach_tensor_lerp_list_cuda_( "foreach_tensor_lerp_scalar_cuda_", [&]() { using opmath_t = typename at::opmath_type; - DISPATCH_MULTI_TENSOR_APPLY([&]() { - multi_tensor_apply<2>( - tensor_lists, - TernaryOpScalarFunctor< - scalar_t, - /* depth */ 2, - /* r_args_depth */ 2, - /* res_arg_index */ 0, - large_kernel_arg>(), - LerpFunctor(), - weight.to()); - }); + multi_tensor_apply<2>( + tensor_lists, + TernaryOpScalarFunctor< + scalar_t, + /* depth */ 2, + /* r_args_depth */ 2, + /* res_arg_index */ 0>(), + LerpFunctor(), + weight.to()); }); } } // namespace at::native diff --git a/aten/src/ATen/native/cuda/ForeachUnaryOp.cu b/aten/src/ATen/native/cuda/ForeachUnaryOp.cu index fa333b4327a943..1a969cfdbdcc43 100644 --- a/aten/src/ATen/native/cuda/ForeachUnaryOp.cu +++ b/aten/src/ATen/native/cuda/ForeachUnaryOp.cu @@ -56,17 +56,14 @@ std::vector foreach_unary_op(TensorList tensors) { tensor_lists.emplace_back(std::move(vec_res)); using opmath_t = typename at::opmath_type; - DISPATCH_MULTI_TENSOR_APPLY([&]() { - multi_tensor_apply<2>( - tensor_lists, - UnaryOpFunctor< - scalar_t, - /* depth */ 2, - /* r_args_depth */ 1, - /* res_arg_index */ 1, - large_kernel_arg>(), - Op()); - }); + multi_tensor_apply<2>( + tensor_lists, + UnaryOpFunctor< + scalar_t, + /* depth */ 2, + /* r_args_depth */ 1, + /* res_arg_index */ 1>(), + Op()); return tensor_lists[1]; } @@ -76,17 +73,14 @@ void foreach_unary_op_(TensorList tensors) { std::vector> tensor_lists; tensor_lists.emplace_back(tensors.vec()); using opmath_t = typename at::opmath_type; - DISPATCH_MULTI_TENSOR_APPLY([&]() { - multi_tensor_apply<1>( - tensor_lists, - UnaryOpFunctor< - scalar_t, - /* depth */ 1, - /* r_args_depth */ 1, - /* res_arg_index */ 0, - large_kernel_arg>(), - Op()); - }); + multi_tensor_apply<1>( + tensor_lists, + UnaryOpFunctor< + scalar_t, + /* depth */ 1, + /* r_args_depth */ 1, + /* res_arg_index */ 0>(), + Op()); increment_version(tensors); } @@ -401,16 +395,13 @@ void foreach_tensor_zero_cuda_(TensorList tensors) { tensors[0].scalar_type(), "foreach_zero_cuda_", [&]() { - DISPATCH_MULTI_TENSOR_APPLY([&]() { - multi_tensor_apply<1>( - tensor_lists, - ZeroFunctor< - scalar_t, - /* depth */ 1, - /* r_args_depth */ 1, - /* res_arg_index */ 0, - large_kernel_arg>()); - }); + multi_tensor_apply<1>( + tensor_lists, + ZeroFunctor< + scalar_t, + /* depth */ 1, + /* r_args_depth */ 1, + /* res_arg_index */ 0>()); }); } diff --git a/aten/src/ATen/native/cuda/FusedSgdKernel.cu b/aten/src/ATen/native/cuda/FusedSgdKernel.cu index b8ab3af9bd7d9f..beca6f75563b7b 100644 --- a/aten/src/ATen/native/cuda/FusedSgdKernel.cu +++ b/aten/src/ATen/native/cuda/FusedSgdKernel.cu @@ -56,15 +56,14 @@ C10_DEVICE __forceinline__ void sgd_math( } } -template +template struct FusedSgdMathFunctor { - static constexpr bool use_large_kernel_arg = large_kernel_arg; static_assert( depth == 2 || depth == 3, "depth of 2 for SGD w/ momentum == 0, 3 for SGD w/ momentum != 0"); C10_DEVICE __forceinline__ void operator()( const int chunk_size, - TensorListMetadata& tl, + TensorListMetadata& tl, const double weight_decay, const double momentum, const float* lr_ptr, @@ -173,21 +172,19 @@ void _fused_sgd_with_momentum_kernel_cuda_( params[0].scalar_type(), "fused_sgd_with_momentum_kernel_cuda", [&]() { - DISPATCH_MULTI_TENSOR_APPLY([&]() { - multi_tensor_apply<3>( - tensor_lists, - FusedSgdMathFunctor(), - weight_decay, - momentum, - lr_ptr, - lr, - dampening, - nesterov, - maximize, - is_first_step, - grad_scale_ptr, - found_inf_ptr); - }); + multi_tensor_apply<3>( + tensor_lists, + FusedSgdMathFunctor(), + weight_decay, + momentum, + lr_ptr, + lr, + dampening, + nesterov, + maximize, + is_first_step, + grad_scale_ptr, + found_inf_ptr); }); } @@ -249,21 +246,19 @@ void _fused_sgd_with_momentum_kernel_cuda_( params[0].scalar_type(), "fused_sgd_with_momentum_kernel_cuda", [&]() { - DISPATCH_MULTI_TENSOR_APPLY([&]() { - multi_tensor_apply<3>( - tensor_lists, - FusedSgdMathFunctor(), - weight_decay, - momentum, - lr.data_ptr(), - 1.0, - dampening, - nesterov, - maximize, - is_first_step, - grad_scale_ptr, - found_inf_ptr); - }); + multi_tensor_apply<3>( + tensor_lists, + FusedSgdMathFunctor(), + weight_decay, + momentum, + lr.data_ptr(), + 1.0, + dampening, + nesterov, + maximize, + is_first_step, + grad_scale_ptr, + found_inf_ptr); }); } @@ -317,21 +312,19 @@ void _fused_sgd_kernel_cuda_( params[0].scalar_type(), "fused_sgd_kernel_cuda", [&]() { - DISPATCH_MULTI_TENSOR_APPLY([&]() { - multi_tensor_apply<2>( - tensor_lists, - FusedSgdMathFunctor(), - weight_decay, - momentum, - lr_ptr, - lr, - dampening, - nesterov, - maximize, - /* is_first_step */ false, - grad_scale_ptr, - found_inf_ptr); - }); + multi_tensor_apply<2>( + tensor_lists, + FusedSgdMathFunctor(), + weight_decay, + momentum, + lr_ptr, + lr, + dampening, + nesterov, + maximize, + /* is_first_step */ false, + grad_scale_ptr, + found_inf_ptr); }); } @@ -411,21 +404,19 @@ void _fused_sgd_kernel_cuda_( params[0].scalar_type(), "fused_sgd_kernel_cuda", [&]() { - DISPATCH_MULTI_TENSOR_APPLY([&]() { - multi_tensor_apply<2>( - tensor_lists, - FusedSgdMathFunctor(), - weight_decay, - momentum, - lr.data_ptr(), - 1.0, - dampening, - nesterov, - maximize, - /* is_first_step */ false, - grad_scale_ptr, - found_inf_ptr); - }); + multi_tensor_apply<2>( + tensor_lists, + FusedSgdMathFunctor(), + weight_decay, + momentum, + lr.data_ptr(), + 1.0, + dampening, + nesterov, + maximize, + /* is_first_step */ false, + grad_scale_ptr, + found_inf_ptr); }); } diff --git a/aten/src/ATen/native/cuda/MultiTensorApply.cpp b/aten/src/ATen/native/cuda/MultiTensorApply.cpp deleted file mode 100644 index 208adb0d63fcfa..00000000000000 --- a/aten/src/ATen/native/cuda/MultiTensorApply.cpp +++ /dev/null @@ -1,25 +0,0 @@ -#include -#include - -#include - -namespace at::native { - -bool supports_large_kernel_arg() { -#if !defined(USE_ROCM) && !defined(_WIN32) && defined(CUDART_VERSION) && CUDART_VERSION >= 12010 - static std::optional supports_large_kernel_arg_ = std::nullopt; - if (!supports_large_kernel_arg_.has_value()) { - int driver_ver = 0; - AT_CUDA_CHECK(cudaDriverGetVersion(&driver_ver)); - cudaDeviceProp* prop = at::cuda::getCurrentDeviceProperties(); - supports_large_kernel_arg_ = (driver_ver >= 12010) && prop->major >= 7; - } - const bool is_capturing = at::cuda::currentStreamCaptureStatusMayInitCtx() != - at::cuda::CaptureStatus::None; - return !is_capturing && *supports_large_kernel_arg_; -#else - return false; -#endif -} - -} // namespace at::native diff --git a/aten/src/ATen/native/cuda/MultiTensorApply.cuh b/aten/src/ATen/native/cuda/MultiTensorApply.cuh index 9b9c3ee7cc93f0..17f14444abd14a 100644 --- a/aten/src/ATen/native/cuda/MultiTensorApply.cuh +++ b/aten/src/ATen/native/cuda/MultiTensorApply.cuh @@ -8,105 +8,20 @@ namespace at::native { -// NOTE: [32KB kernel argument size support] -// 32KB kernel argument size support has three requirements: -// - CUDART_VERSION >= 12010 -// - Driver version >= 530 -// - GPU arch >= VOLTA -// -// Due to minor version compatibility, it possible for binaries built with -// CUDART_VERSION >= 12010 to run with driver version < 530. Since driver -// version can only be checked at runtime, if CUDART_VERSION >= 12010, we have -// to build both 4KB and 32KB kernels and determine the appropriate kernel to -// dispatch at runtime. -// -// - If CUDART_VERSION < 12010, only 4KB kernels will be instantiated. -// -// - If CUDART_VERSION >= 12010: -// - Host code: -// - We always instantiate the launching stub for both 4KB and 32KB kernels. -// - Device code: -// - If __CUDA_ARCH__ >= 700, we always instantiate both 4KB and 32KB -// kernels. -// - If __CUDA_ARCH__ < 700, it's not possible to even compile an empty -// 32KB kernel (formal parameter space overflowed). Thus, we only -// instantiate a declaration for 32KB kernels. This is valid as long as the -// declaration-only kernel is not launched. -// -// - At runtime, we dispatch to the 32KB kernel if driver version >= 530 and -// GPU arch >= VOLTA. -// -// - TODO(yifu): once there's a CUDART version that is not compatible with any -// driver version below 530, we can determine at compile time to not compile -// the kernels for 4KB kernel argument size. -// -// https://developer.nvidia.com/blog/cuda-12-1-supports-large-kernel-parameters/ -bool supports_large_kernel_arg(); - namespace { static constexpr int64_t kILP = 4; static constexpr int64_t kChunkSize = 65536; static constexpr int64_t kBlockSize = 512; -// MSVC has a problem with constexpr and can't handle passing them to templates -// as arguments. We need to replace it with const static. -// https://github.com/state-spaces/mamba/issues/12#issuecomment-1848835662 -#if !defined(_WIN32) -#define SWITCH_TYPE constexpr bool -#else -#define SWITCH_TYPE const static bool -#endif - -#if !defined(USE_ROCM) && !defined(_WIN32) && defined(CUDART_VERSION) && \ - CUDART_VERSION >= 12010 -#define DISPATCH_MULTI_TENSOR_APPLY(...) \ - if (at::native::supports_large_kernel_arg()) { \ - SWITCH_TYPE large_kernel_arg C10_UNUSED = true; \ - __VA_ARGS__(); \ - } else { \ - SWITCH_TYPE large_kernel_arg C10_UNUSED = false; \ - __VA_ARGS__(); \ - } -#else -#define DISPATCH_MULTI_TENSOR_APPLY(...) \ - do { \ - SWITCH_TYPE large_kernel_arg C10_UNUSED = false; \ - __VA_ARGS__(); \ - } while (0); -#endif - -template -struct DepthToMaxConfig; - -template <> -struct DepthToMaxConfig { - static constexpr int depth_to_max_tensors[5] = {110, 64, 48, 36, 30}; - static constexpr int depth_to_max_blocks[5] = {320, 320, 320, 320, 320}; - static constexpr int depth_to_max_tensors_scalarlist[5] = - {96, 64, 48, 36, 30}; - static constexpr int depth_to_max_tensors_scalarlist_of_complex_double[2] = { - 72, - 60}; - using TensorIdxType = unsigned char; -}; - -template <> -struct DepthToMaxConfig { - // TODO(yifu): These values are not yet optimally tuned. I simply multiplied - // the values tuned for 4KB kernel argument size limit by 7 (the kernel - // argument size limit increased by 8x but we need to change the type of - // block_to_tensor from unsigned char to uint16_t to support larger number of - // tensors). - static constexpr int depth_to_max_tensors[5] = {770, 448, 336, 252, 210}; - static constexpr int depth_to_max_blocks[5] = {2240, 2240, 2240, 2240, 2240}; - static constexpr int depth_to_max_tensors_scalarlist[5] = - {672, 448, 336, 252, 210}; - static constexpr int depth_to_max_tensors_scalarlist_of_complex_double[2] = { - 504, - 420}; - using TensorIdxType = uint16_t; -}; +// TODO(crcrpar): Add `n>5` for `low prec params & their higher prec copy` +// TensorListMetadata has to be < 4KB - the limit for kernel launch argument +static constexpr int depth_to_max_tensors[5] = {110, 64, 48, 36, 30}; +static constexpr int depth_to_max_blocks[5] = {320, 320, 320, 320, 320}; +static constexpr int depth_to_max_tensors_scalarlist[5] = {96, 64, 48, 36, 30}; +static constexpr int depth_to_max_tensors_scalarlist_of_complex_double[2] = { + 72, + 60}; template __device__ __forceinline__ bool is_aligned(T* p) { @@ -123,101 +38,73 @@ __device__ __forceinline__ void load_store( ((LT*)dst)[dst_offset] = ((LT*)src)[src_offset]; } -template +template struct TensorListMetadata { - using Conf = DepthToMaxConfig; - const void* addresses[n][Conf::depth_to_max_tensors[n - 1]]; - int64_t numel_for_tensor[Conf::depth_to_max_tensors[n - 1]]; - typename Conf::TensorIdxType - block_to_tensor[Conf::depth_to_max_blocks[n - 1]]; - int block_to_chunk[Conf::depth_to_max_blocks[n - 1]]; + const void* addresses[n][depth_to_max_tensors[n - 1]]; + int64_t numel_for_tensor[depth_to_max_tensors[n - 1]]; + unsigned char block_to_tensor[depth_to_max_blocks[n - 1]]; + int block_to_chunk[depth_to_max_blocks[n - 1]]; int start_tensor_this_launch; }; -template +template struct TensorListScalarListMetadata { - using Conf = DepthToMaxConfig; - const void* addresses[n][Conf::depth_to_max_tensors_scalarlist[n - 1]]; - int64_t numel_for_tensor[Conf::depth_to_max_tensors_scalarlist[n - 1]]; - scalar_vals_t scalar_vals[Conf::depth_to_max_tensors_scalarlist[n - 1]]; - typename Conf::TensorIdxType - block_to_tensor[Conf::depth_to_max_blocks[n - 1]]; - int block_to_chunk[Conf::depth_to_max_blocks[n - 1]]; + const void* addresses[n][depth_to_max_tensors_scalarlist[n - 1]]; + int64_t numel_for_tensor[depth_to_max_tensors_scalarlist[n - 1]]; + scalar_vals_t scalar_vals[depth_to_max_tensors_scalarlist[n - 1]]; + unsigned char block_to_tensor[depth_to_max_blocks[n - 1]]; + int block_to_chunk[depth_to_max_blocks[n - 1]]; }; // note(mkozuki): `n` of 1&2 violate the limit of cuda kernel argument size of // 4kb with `c10::complex` -template -struct TensorListScalarListMetadata, 1, large_kernel_arg> { - using Conf = DepthToMaxConfig; - const void* - addresses[1][Conf::depth_to_max_tensors_scalarlist_of_complex_double[0]]; - int64_t numel_for_tensor - [Conf::depth_to_max_tensors_scalarlist_of_complex_double[0]]; +template <> +struct TensorListScalarListMetadata, 1> { + const void* addresses[1] + [depth_to_max_tensors_scalarlist_of_complex_double[0]]; + int64_t + numel_for_tensor[depth_to_max_tensors_scalarlist_of_complex_double[0]]; c10::complex - scalar_vals[Conf::depth_to_max_tensors_scalarlist_of_complex_double[0]]; - typename Conf::TensorIdxType - block_to_tensor[Conf::depth_to_max_blocks[1 - 1]]; - int block_to_chunk[Conf::depth_to_max_blocks[1 - 1]]; + scalar_vals[depth_to_max_tensors_scalarlist_of_complex_double[0]]; + unsigned char block_to_tensor[depth_to_max_blocks[1 - 1]]; + int block_to_chunk[depth_to_max_blocks[1 - 1]]; }; -template -struct TensorListScalarListMetadata, 2, large_kernel_arg> { - using Conf = DepthToMaxConfig; - const void* - addresses[2][Conf::depth_to_max_tensors_scalarlist_of_complex_double[1]]; - int64_t numel_for_tensor - [Conf::depth_to_max_tensors_scalarlist_of_complex_double[1]]; +template <> +struct TensorListScalarListMetadata, 2> { + const void* addresses[2] + [depth_to_max_tensors_scalarlist_of_complex_double[1]]; + int64_t + numel_for_tensor[depth_to_max_tensors_scalarlist_of_complex_double[1]]; c10::complex - scalar_vals[Conf::depth_to_max_tensors_scalarlist_of_complex_double[1]]; - typename Conf::TensorIdxType - block_to_tensor[Conf::depth_to_max_blocks[2 - 1]]; - int block_to_chunk[Conf::depth_to_max_blocks[2 - 1]]; + scalar_vals[depth_to_max_tensors_scalarlist_of_complex_double[1]]; + unsigned char block_to_tensor[depth_to_max_blocks[2 - 1]]; + int block_to_chunk[depth_to_max_blocks[2 - 1]]; }; // NOTE(crcrpar): This is a conservative resolution to handle `state_steps` // whose each element is `at::Tensor` of 1 element representing the number of // `step`s called so far. -template +template struct FusedOptimizerTensorListMetadata { - using Conf = DepthToMaxConfig; - const void* addresses[n][Conf::depth_to_max_tensors[n - 1]]; - int64_t numel_for_tensor[Conf::depth_to_max_tensors[n - 1]]; - const void* - state_steps_addresses[Conf::depth_to_max_tensors_scalarlist[n - 1]]; - typename Conf::TensorIdxType - block_to_tensor[Conf::depth_to_max_blocks[n - 1]]; - int block_to_chunk[Conf::depth_to_max_blocks[n - 1]]; + const void* addresses[n][depth_to_max_tensors[n - 1]]; + int64_t numel_for_tensor[depth_to_max_tensors[n - 1]]; + const void* state_steps_addresses[depth_to_max_tensors_scalarlist[n - 1]]; + unsigned char block_to_tensor[depth_to_max_blocks[n - 1]]; + int block_to_chunk[depth_to_max_blocks[n - 1]]; int start_tensor_this_launch; }; -#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 700 template C10_LAUNCH_BOUNDS_1(kBlockSize) -__global__ typename std::enable_if::type - multi_tensor_apply_kernel(T tensorListMeta, U callable, ArgTypes... args) { +__global__ void multi_tensor_apply_kernel( + T tensorListMeta, + U callable, + ArgTypes... args) { // Hand the chunk information to the user-supplied functor to process however // it likes. callable(kChunkSize, tensorListMeta, args...); } -#else -// When compiling device code with __CUDA_ARCH__ < 700, we only instantiate a -// declaration for the 32KB kernels. -// For details see: [32KB kernel argument size support] -#pragma nv_diag_suppress 114 // Function was referenced but not defined -template -C10_LAUNCH_BOUNDS_1(kBlockSize) -__global__ typename std::enable_if::type - multi_tensor_apply_kernel(T tensorListMeta, U callable, ArgTypes... args); -#pragma nv_diag_default 114 // Function was referenced but not defined -#endif - -template -C10_LAUNCH_BOUNDS_1(kBlockSize) -__global__ typename std::enable_if::type - multi_tensor_apply_kernel(T tensorListMeta, U callable, ArgTypes... args) { - callable(kChunkSize, tensorListMeta, args...); -} } // namespace @@ -246,10 +133,7 @@ void multi_tensor_apply( "Number of tensor lists has to match the depth."); const size_t n_tensors = tensor_lists[0].size(); using scalar_vals_t = typename T::opmath_t; - TensorListScalarListMetadata - tensorListMeta; - - using Conf = DepthToMaxConfig; + TensorListScalarListMetadata tensorListMeta; int loc_block_info = 0; int loc_tensor_info = 0; @@ -283,11 +167,10 @@ void multi_tensor_apply( // a tensor is not considered full unless all its chunks have been // processed const bool tensors_full = - (loc_tensor_info == - Conf::depth_to_max_tensors_scalarlist[depth - 1] && + (loc_tensor_info == depth_to_max_tensors_scalarlist[depth - 1] && chunk == chunks - 1); const bool blocks_full = - (loc_block_info == Conf::depth_to_max_blocks[depth - 1]); + (loc_block_info == depth_to_max_blocks[depth - 1]); if (tensors_full || blocks_full) { multi_tensor_apply_kernel<<< @@ -340,11 +223,9 @@ void multi_tensor_apply( tensor_lists.size() == depth, "Number of tensor lists has to match the depth."); const size_t n_tensors = tensor_lists[0].size(); - TensorListMetadata tensorListMeta; + TensorListMetadata tensorListMeta; tensorListMeta.start_tensor_this_launch = 0; - using Conf = DepthToMaxConfig; - int loc_block_info = 0; int loc_tensor_info = 0; for (size_t t = 0; t < n_tensors; t++) { @@ -369,10 +250,10 @@ void multi_tensor_apply( loc_block_info++; const bool tensors_full = - (loc_tensor_info == Conf::depth_to_max_tensors[depth - 1] && + (loc_tensor_info == depth_to_max_tensors[depth - 1] && chunk == chunks - 1); const bool blocks_full = - (loc_block_info == Conf::depth_to_max_blocks[depth - 1]); + (loc_block_info == depth_to_max_blocks[depth - 1]); if (tensors_full || blocks_full) { multi_tensor_apply_kernel<<< @@ -423,10 +304,7 @@ void multi_tensor_apply_for_fused_optimizer( tensor_lists.size() == depth, "Number of tensor lists has to match the depth"); const auto num_tensors = tensor_lists[0].size(); - FusedOptimizerTensorListMetadata - tensorListMeta; - - using Conf = DepthToMaxConfig; + FusedOptimizerTensorListMetadata tensorListMeta; int loc_block_info = 0; int loc_tensor_info = 0; @@ -455,10 +333,9 @@ void multi_tensor_apply_for_fused_optimizer( loc_block_info++; const auto tensor_full = - (loc_tensor_info == Conf::depth_to_max_tensors[depth - 1] && + (loc_tensor_info == depth_to_max_tensors[depth - 1] && chunk == chunks - 1); - const auto blocks_full = - loc_block_info == Conf::depth_to_max_blocks[depth - 1]; + const auto blocks_full = loc_block_info == depth_to_max_blocks[depth - 1]; if (tensor_full || blocks_full) { multi_tensor_apply_kernel<<< diff --git a/aten/src/ATen/native/cuda/fused_adam_amsgrad_impl.cu b/aten/src/ATen/native/cuda/fused_adam_amsgrad_impl.cu index 0c6318d606b9aa..cef07de1b41f99 100644 --- a/aten/src/ATen/native/cuda/fused_adam_amsgrad_impl.cu +++ b/aten/src/ATen/native/cuda/fused_adam_amsgrad_impl.cu @@ -42,26 +42,19 @@ void _fused_adam_amsgrad_cuda_impl_( params[0].scalar_type(), "fused_adam_kernel_cuda", [&]() { - DISPATCH_MULTI_TENSOR_APPLY([&]() { - multi_tensor_apply_for_fused_optimizer<5>( - tensor_lists, - state_steps, - FusedAdamMathFunctor< - scalar_t, - 5, - ADAM_MODE::ORIGINAL, - true, - large_kernel_arg>(), - lr_ptr, // unused - lr, - beta1, - beta2, - weight_decay, - eps, - maximize, - grad_scale_ptr, - found_inf_ptr); - }); + multi_tensor_apply_for_fused_optimizer<5>( + tensor_lists, + state_steps, + FusedAdamMathFunctor(), + lr_ptr, // unused + lr, + beta1, + beta2, + weight_decay, + eps, + maximize, + grad_scale_ptr, + found_inf_ptr); }); } @@ -100,26 +93,19 @@ void _fused_adam_amsgrad_cuda_impl_( params[0].scalar_type(), "fused_adam_kernel_cuda", [&]() { - DISPATCH_MULTI_TENSOR_APPLY([&]() { - multi_tensor_apply_for_fused_optimizer<5>( - tensor_lists, - state_steps, - FusedAdamMathFunctor< - scalar_t, - 5, - ADAM_MODE::ORIGINAL, - true, - large_kernel_arg>(), - lr_ptr, - 1.0, // unused - beta1, - beta2, - weight_decay, - eps, - maximize, - grad_scale_ptr, - found_inf_ptr); - }); + multi_tensor_apply_for_fused_optimizer<5>( + tensor_lists, + state_steps, + FusedAdamMathFunctor(), + lr_ptr, + 1.0, // unused + beta1, + beta2, + weight_decay, + eps, + maximize, + grad_scale_ptr, + found_inf_ptr); }); } diff --git a/aten/src/ATen/native/cuda/fused_adam_impl.cu b/aten/src/ATen/native/cuda/fused_adam_impl.cu index 43a4cbfa2f791d..2c1f5ce0d6d572 100644 --- a/aten/src/ATen/native/cuda/fused_adam_impl.cu +++ b/aten/src/ATen/native/cuda/fused_adam_impl.cu @@ -37,26 +37,19 @@ void _fused_adam_cuda_impl_( params[0].scalar_type(), "fused_adam_kernel_cuda", [&]() { - DISPATCH_MULTI_TENSOR_APPLY([&]() { - multi_tensor_apply_for_fused_optimizer<4>( - tensor_lists, - state_steps, - FusedAdamMathFunctor< - scalar_t, - 4, - ADAM_MODE::ORIGINAL, - false, - large_kernel_arg>(), - lr_ptr, // unused - lr, - beta1, - beta2, - weight_decay, - eps, - maximize, - grad_scale_ptr, - found_inf_ptr); - }); + multi_tensor_apply_for_fused_optimizer<4>( + tensor_lists, + state_steps, + FusedAdamMathFunctor(), + lr_ptr, // unused + lr, + beta1, + beta2, + weight_decay, + eps, + maximize, + grad_scale_ptr, + found_inf_ptr); }); } @@ -90,26 +83,19 @@ void _fused_adam_cuda_impl_( params[0].scalar_type(), "fused_adam_kernel_cuda", [&]() { - DISPATCH_MULTI_TENSOR_APPLY([&]() { - multi_tensor_apply_for_fused_optimizer<4>( - tensor_lists, - state_steps, - FusedAdamMathFunctor< - scalar_t, - 4, - ADAM_MODE::ORIGINAL, - false, - large_kernel_arg>(), - lr_ptr, - 1.0, // unused - beta1, - beta2, - weight_decay, - eps, - maximize, - grad_scale_ptr, - found_inf_ptr); - }); + multi_tensor_apply_for_fused_optimizer<4>( + tensor_lists, + state_steps, + FusedAdamMathFunctor(), + lr_ptr, + 1.0, // unused + beta1, + beta2, + weight_decay, + eps, + maximize, + grad_scale_ptr, + found_inf_ptr); }); } diff --git a/aten/src/ATen/native/cuda/fused_adam_utils.cuh b/aten/src/ATen/native/cuda/fused_adam_utils.cuh index 43627fc32148b3..182195969ed9a3 100644 --- a/aten/src/ATen/native/cuda/fused_adam_utils.cuh +++ b/aten/src/ATen/native/cuda/fused_adam_utils.cuh @@ -102,21 +102,15 @@ C10_DEVICE inline void adam_math( // parameter updates accordingly. To be functionally on par with `torch.optim` // optimizers and `_multi_tensor` ones, the kernel below writes out gradients // only when `grad_scale_ptr != nullptr. -template < - typename scalar_type, - int depth, - ADAM_MODE adam_mode, - bool amsgrad, - bool large_kernel_arg> +template struct FusedAdamMathFunctor { - static constexpr bool use_large_kernel_arg = large_kernel_arg; static_assert( depth == 4 || depth == 5, "depth of 4 for Adam, depth of 5 for Adam with AMSGrad."); using opmath_t = at::opmath_type; C10_DEVICE __forceinline__ void operator()( int chunk_size, - FusedOptimizerTensorListMetadata& tl, + FusedOptimizerTensorListMetadata& tl, const float* lr_ptr, const double& lr, const double& beta1, diff --git a/aten/src/ATen/native/cuda/fused_adamw_amsgrad_impl.cu b/aten/src/ATen/native/cuda/fused_adamw_amsgrad_impl.cu index 3b3b7322b769c4..8a22b57a47e8b0 100644 --- a/aten/src/ATen/native/cuda/fused_adamw_amsgrad_impl.cu +++ b/aten/src/ATen/native/cuda/fused_adamw_amsgrad_impl.cu @@ -43,26 +43,19 @@ void _fused_adamw_amsgrad_cuda_impl_( params[0].scalar_type(), "fused_adamw_kernel_cuda", [&]() { - DISPATCH_MULTI_TENSOR_APPLY([&]() { - multi_tensor_apply_for_fused_optimizer<5>( - tensor_lists, - state_steps, - FusedAdamMathFunctor< - scalar_t, - 5, - ADAM_MODE::ADAMW, - true, - large_kernel_arg>(), - lr_ptr, // unused - lr, - beta1, - beta2, - weight_decay, - eps, - maximize, - grad_scale_ptr, - found_inf_ptr); - }); + multi_tensor_apply_for_fused_optimizer<5>( + tensor_lists, + state_steps, + FusedAdamMathFunctor(), + lr_ptr, // unused + lr, + beta1, + beta2, + weight_decay, + eps, + maximize, + grad_scale_ptr, + found_inf_ptr); }); } @@ -101,26 +94,19 @@ void _fused_adamw_amsgrad_cuda_impl_( params[0].scalar_type(), "fused_adamw_kernel_cuda", [&]() { - DISPATCH_MULTI_TENSOR_APPLY([&]() { - multi_tensor_apply_for_fused_optimizer<5>( - tensor_lists, - state_steps, - FusedAdamMathFunctor< - scalar_t, - 5, - ADAM_MODE::ADAMW, - true, - large_kernel_arg>(), - lr_ptr, - 1.0, // unused - beta1, - beta2, - weight_decay, - eps, - maximize, - grad_scale_ptr, - found_inf_ptr); - }); + multi_tensor_apply_for_fused_optimizer<5>( + tensor_lists, + state_steps, + FusedAdamMathFunctor(), + lr_ptr, + 1.0, // unused + beta1, + beta2, + weight_decay, + eps, + maximize, + grad_scale_ptr, + found_inf_ptr); }); } diff --git a/aten/src/ATen/native/cuda/fused_adamw_impl.cu b/aten/src/ATen/native/cuda/fused_adamw_impl.cu index ff657687b5f920..b0f9dc6db6aff6 100644 --- a/aten/src/ATen/native/cuda/fused_adamw_impl.cu +++ b/aten/src/ATen/native/cuda/fused_adamw_impl.cu @@ -38,26 +38,19 @@ void _fused_adamw_cuda_impl_( params[0].scalar_type(), "fused_adamw_kernel_cuda", [&]() { - DISPATCH_MULTI_TENSOR_APPLY([&]() { - multi_tensor_apply_for_fused_optimizer<4>( - tensor_lists, - state_steps, - FusedAdamMathFunctor< - scalar_t, - 4, - ADAM_MODE::ADAMW, - false, - large_kernel_arg>(), - lr_ptr, // unused - lr, - beta1, - beta2, - weight_decay, - eps, - maximize, - grad_scale_ptr, - found_inf_ptr); - }); + multi_tensor_apply_for_fused_optimizer<4>( + tensor_lists, + state_steps, + FusedAdamMathFunctor(), + lr_ptr, // unused + lr, + beta1, + beta2, + weight_decay, + eps, + maximize, + grad_scale_ptr, + found_inf_ptr); }); } @@ -91,26 +84,19 @@ void _fused_adamw_cuda_impl_( params[0].scalar_type(), "fused_adamw_kernel_cuda", [&]() { - DISPATCH_MULTI_TENSOR_APPLY([&]() { - multi_tensor_apply_for_fused_optimizer<4>( - tensor_lists, - state_steps, - FusedAdamMathFunctor< - scalar_t, - 4, - ADAM_MODE::ADAMW, - false, - large_kernel_arg>(), - lr_ptr, - 1.0, // unused - beta1, - beta2, - weight_decay, - eps, - maximize, - grad_scale_ptr, - found_inf_ptr); - }); + multi_tensor_apply_for_fused_optimizer<4>( + tensor_lists, + state_steps, + FusedAdamMathFunctor(), + lr_ptr, + 1.0, // unused + beta1, + beta2, + weight_decay, + eps, + maximize, + grad_scale_ptr, + found_inf_ptr); }); } diff --git a/build_variables.bzl b/build_variables.bzl index ff1e6083218c78..306e71b61f1d0f 100644 --- a/build_variables.bzl +++ b/build_variables.bzl @@ -1459,7 +1459,6 @@ aten_cuda_cu_source_list = [ "aten/src/ATen/native/cuda/Equal.cpp", "aten/src/ATen/native/cuda/GridSampler.cpp", "aten/src/ATen/native/cuda/IndexKernel.cpp", - "aten/src/ATen/native/cuda/MultiTensorApply.cpp", "aten/src/ATen/native/cuda/ReduceOps.cpp", "aten/src/ATen/native/cuda/ScanKernels.cpp", "aten/src/ATen/native/cuda/Sort.cpp", From d751d41c943715513712bcf858c1b3e7f0de61ed Mon Sep 17 00:00:00 2001 From: Boyuan Feng Date: Wed, 4 Sep 2024 20:09:41 +0000 Subject: [PATCH 0419/1018] [Flex Attention] update __getitem__ without tree_map_only to support compile (#134627) Adds a helper function for getting the block mask for a specific row index during decoding. We need this change to avoid the pytree + torch.compile issue #134731. Tested in gpt-fast [pr](https://github.com/pytorch-labs/gpt-fast/pull/196). Pull Request resolved: https://github.com/pytorch/pytorch/pull/134627 Approved by: https://github.com/Chillee --- test/inductor/test_flex_attention.py | 57 ++++++++++++++++++++++++++++ torch/nn/attention/flex_attention.py | 57 +++++++++++++++++++++++++--- 2 files changed, 109 insertions(+), 5 deletions(-) diff --git a/test/inductor/test_flex_attention.py b/test/inductor/test_flex_attention.py index 2776b8b89bfa72..7e3df0c75ac62a 100644 --- a/test/inductor/test_flex_attention.py +++ b/test/inductor/test_flex_attention.py @@ -1913,6 +1913,63 @@ def causal_mask(b, h, q, kv): self.assertTrue(block_mask.sparsity() < block_mask[0].sparsity()) self.assertTrue(block_mask[0].sparsity() > block_mask[1].sparsity()) + @supported_platform + def test_getitem(self): + offset = torch.zeros(8, device="cuda") + + def causal_mask(b, h, q, kv): + return (q + (offset[b] * 128)) >= kv + + block_mask = create_block_mask(causal_mask, 4, 2, 512, 512) + assert block_mask.kv_num_blocks.shape == (4, 2, 4) + assert block_mask.kv_indices.shape == (4, 2, 4, 4) + + # Index on batch dimension + new_block_mask = block_mask[0] + assert new_block_mask.kv_num_blocks.shape == (2, 4) + assert new_block_mask.kv_indices.shape == (2, 4, 4) + + # Index on batch and head dimension + new_block_mask = block_mask[0, 1] + assert new_block_mask.kv_num_blocks.shape == (4,) + assert new_block_mask.kv_indices.shape == (4, 4) + + # slicing on batch and head dimension + new_block_mask = block_mask[0:2, 1:2] + assert new_block_mask.kv_num_blocks.shape == (2, 1, 4) + assert new_block_mask.kv_indices.shape == (2, 1, 4, 4) + + # slicing on batch, head, and query dimension + new_block_mask = block_mask[0:2, 1:2, torch.tensor([1], dtype=torch.int32)] + assert new_block_mask.kv_num_blocks.shape == (2, 1, 1) + assert new_block_mask.kv_indices.shape == (2, 1, 1, 4) + + # slicing on batch, head, and query dimension + q_index = torch.tensor([0], dtype=torch.int32) + new_block_mask = block_mask[:, :, q_index] + + self.assertEqual(new_block_mask.kv_num_blocks.ndim, 3) + self.assertEqual(new_block_mask.kv_indices.ndim, 4) + torch.testing.assert_close( + new_block_mask.kv_num_blocks, + block_mask.kv_num_blocks[:, :, q_index], + ) + torch.testing.assert_close( + new_block_mask.kv_indices, block_mask.kv_indices[:, :, q_index, :] + ) + + if block_mask.full_kv_num_blocks is not None: + assert new_block_mask.full_kv_num_blocks is not None + assert new_block_mask.full_kv_indices is not None + torch.testing.assert_close( + new_block_mask.full_kv_num_blocks, + block_mask.full_kv_num_blocks[:, :, q_index], + ) + torch.testing.assert_close( + new_block_mask.full_kv_indices, + block_mask.full_kv_indices[:, :, q_index, :], + ) + @supported_platform def test_block_mask_device_change(self): offset = torch.zeros(8, device="cuda") diff --git a/torch/nn/attention/flex_attention.py b/torch/nn/attention/flex_attention.py index 4f99e3d903caca..7b0a3354642a7f 100644 --- a/torch/nn/attention/flex_attention.py +++ b/torch/nn/attention/flex_attention.py @@ -391,12 +391,59 @@ def __str__(self): return s def __getitem__(self, index) -> "BlockMask": - mapped_attributes = tree_map_only( - torch.Tensor, - lambda x: x[index], - self.as_tuple(flatten=False), + """ + Returns a new BlockMask instance by getting the mask for the given index position. + + Args: + index: Index to apply to all attributes. + + Example Usage: + .. code-block:: python + + def causal_mask(b, h, q_idx, kv_idx): + return q_idx >= kv_idx + + block_mask = create_block_mask(causal_mask, 4, 2, 512, 512, device="cuda") + assert block_mask.kv_num_blocks.shape == (4,2,4) + assert block_mask.kv_indices.shape == (4,2,4,4) + + # Index on batch dimension + new_block_mask = block_mask[0] + assert new_block_mask.kv_num_blocks.shape == (2,4) + assert new_block_mask.kv_indices.shape == (2,4,4) + + # Index on batch and head dimension + new_block_mask = block_mask[0, 1] + assert new_block_mask.kv_num_blocks.shape == (4,) + assert new_block_mask.kv_indices.shape == (4,4) + + # slicing on batch and head dimension + new_block_mask = block_mask[0:2, 1:2] + assert new_block_mask.kv_num_blocks.shape == (2,1,4) + assert new_block_mask.kv_indices.shape == (2,1,4,4) + + # slicing on batch, head, and query dimension + new_block_mask = block_mask[0:2, 1:2, torch.tensor([1], dtype=torch.int32)] + assert new_block_mask.kv_num_blocks.shape == (2,1,1) + assert new_block_mask.kv_indices.shape == (2,1,1,4) + """ + new_kv_num_blocks = self.kv_num_blocks[index] + new_kv_indices = self.kv_indices[index] + if self.full_kv_num_blocks is not None: + assert self.full_kv_indices is not None + new_full_kv_num_blocks = self.full_kv_num_blocks[index] + new_full_kv_indices = self.full_kv_indices[index] + else: + new_full_kv_num_blocks = None + new_full_kv_indices = None + return BlockMask.from_kv_blocks( + new_kv_num_blocks, + new_kv_indices, + new_full_kv_num_blocks, + new_full_kv_indices, + BLOCK_SIZE=self.BLOCK_SIZE, + mask_mod=None, ) - return BlockMask(*mapped_attributes) def __repr__(self): def shape_or_none(x: Optional[torch.Tensor]): From a90aea3cd2407778a21e2efd4c71103676a83a21 Mon Sep 17 00:00:00 2001 From: Animesh Jain Date: Wed, 4 Sep 2024 09:59:40 -0700 Subject: [PATCH 0420/1018] [dynamo][backend match] Optimize backend match for common case (#135121) Pull Request resolved: https://github.com/pytorch/pytorch/pull/135121 Approved by: https://github.com/williamwen42 ghstack dependencies: #135039 --- torch/csrc/dynamo/extra_state.cpp | 21 +++++++++++++++------ 1 file changed, 15 insertions(+), 6 deletions(-) diff --git a/torch/csrc/dynamo/extra_state.cpp b/torch/csrc/dynamo/extra_state.cpp index 9142470315cd9b..abc6951ab554c6 100644 --- a/torch/csrc/dynamo/extra_state.cpp +++ b/torch/csrc/dynamo/extra_state.cpp @@ -83,6 +83,19 @@ ExtraState* init_and_set_extra_state(PyCodeObject* code) { return extra_state; } +bool backend_match(PyObject* saved_backend, PyObject* backend) { + // Pointer equality check for common case + if (saved_backend != backend) { + // The Py_TYPE check should not be required but there is a pre-existing + // issue where backend is possibly deallocated (or nullptr) and causes + // segfaults. Check test - test_inplace_custom_op_intermediate + return ( + Py_TYPE(saved_backend) == Py_TYPE(backend) && + PyObject_RichCompareBool(saved_backend, backend, Py_EQ)); + } + return true; +} + PyObject* lookup( ExtraState* extra_state, PyObject* f_locals, @@ -93,12 +106,8 @@ PyObject* lookup( for (CacheEntry& cache_entry : extra_state->cache_entry_list) { // Check backend. Py_False means run only mode. - // The Py_TYPE check should not be required but there is a pre-existing - // issue where backend is possibly deallocated (or nullptr) and causes - // segfaults. Check test - test_inplace_custom_op_intermediate - bool valid = backend == Py_False || - (Py_TYPE(cache_entry.backend) == Py_TYPE(backend) && - PyObject_RichCompareBool(cache_entry.backend, backend, Py_EQ)); + bool valid = + backend == Py_False || backend_match(cache_entry.backend, backend); if (valid) { try { From 2b18ad68ceb8e9aa6a46173de1e1e79e8b6529bf Mon Sep 17 00:00:00 2001 From: Manuel Candales Date: Wed, 4 Sep 2024 21:20:01 +0000 Subject: [PATCH 0421/1018] [PT] Add out variant for avg_pool1d and adaptive_avg_pool1d (#135051) Test Plan: CI Differential Revision: D62148410 Pull Request resolved: https://github.com/pytorch/pytorch/pull/135051 Approved by: https://github.com/SS-JIA --- aten/src/ATen/native/native_functions.yaml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index 1f31c5cd198e6a..683b38e5f20ab4 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -537,9 +537,11 @@ - func: avg_pool1d(Tensor self, int[1] kernel_size, int[1] stride=[], int[1] padding=0, bool ceil_mode=False, bool count_include_pad=True) -> Tensor tags: core + autogen: avg_pool1d.out - func: adaptive_avg_pool1d(Tensor self, int[1] output_size) -> Tensor tags: core + autogen: adaptive_avg_pool1d.out # Return: (Tensor output, Tensor indices) - func: adaptive_max_pool1d(Tensor self, int[1] output_size) -> (Tensor, Tensor) From 5cbb818c425897a2cf8626590a51976abc43da2a Mon Sep 17 00:00:00 2001 From: rzou Date: Wed, 4 Sep 2024 11:29:09 -0700 Subject: [PATCH 0422/1018] Make Dynamo inline through torch._library.custom_ops.autograd (#135066) Fixes https://github.com/pytorch/pytorch/issues/135057 The bug was: in the situation that Dynamo graph breaks in the forward and Compiled Autograd uses Dynamo to introspect the backward, we end up running into a "Unsupported: inlining through SKIPFILES" error. The solution is to mark the entirety of this module as inlineable. Test Plan: - new test Pull Request resolved: https://github.com/pytorch/pytorch/pull/135066 Approved by: https://github.com/bdhirsh, https://github.com/williamwen42, https://github.com/yanboliang --- test/inductor/test_compiled_autograd.py | 20 ++++++++++++++++++++ torch/_dynamo/trace_rules.py | 1 + 2 files changed, 21 insertions(+) diff --git a/test/inductor/test_compiled_autograd.py b/test/inductor/test_compiled_autograd.py index 1a30d141bbf344..239698d8e29f29 100644 --- a/test/inductor/test_compiled_autograd.py +++ b/test/inductor/test_compiled_autograd.py @@ -172,6 +172,26 @@ def fn(): self.check_output_and_recompiles(fn) + def test_graph_break_custom_op(self): + @torch.library.custom_op("mylib::sin", mutates_args={}) + def sin(x: torch.Tensor) -> torch.Tensor: + return x.sin() + + def setup_context(ctx, inputs, output): + (x,) = inputs + ctx.save_for_backward(x) + + def backward(ctx, grad): + (x,) = ctx.saved_tensors + return grad * x.cos() + + sin.register_autograd(backward, setup_context=setup_context) + + x = torch.randn(3, requires_grad=True) + y = sin(x.clone()).sum() + with compiled_autograd.enable(compiler_fn): + y.backward() + def test_tensor_grad_hook1(self): def fn(): for _ in range(3): diff --git a/torch/_dynamo/trace_rules.py b/torch/_dynamo/trace_rules.py index 0740970457de7b..5eee636438fed2 100644 --- a/torch/_dynamo/trace_rules.py +++ b/torch/_dynamo/trace_rules.py @@ -3229,6 +3229,7 @@ def _module_dir(m: types.ModuleType): "torch._higher_order_ops.strict_mode", "torch._higher_order_ops.while_loop", "torch._inductor.test_operators", + "torch._library.autograd", "torch._library.custom_ops", "torch._prims", "torch._refs", From e9cc3444b8bd5313b4fb8b4dc5980c7216705ee4 Mon Sep 17 00:00:00 2001 From: Nikita Shulga Date: Wed, 4 Sep 2024 21:53:43 +0000 Subject: [PATCH 0423/1018] [BE][optim] Make pyright recognize exported symbols (#135043) Follows pattern introduced by https://github.com/pytorch/pytorch/pull/80955 which [pyright](https://github.com/microsoft/pyright) prefers over `__all__` symbol, see https://github.com/microsoft/pylance-release/issues/2953#issuecomment-1168956296 Fixes https://github.com/pytorch/pytorch/issues/134985 Pull Request resolved: https://github.com/pytorch/pytorch/pull/135043 Approved by: https://github.com/janeyx99 --- torch/optim/__init__.py | 32 ++++++++++++++++---------------- 1 file changed, 16 insertions(+), 16 deletions(-) diff --git a/torch/optim/__init__.py b/torch/optim/__init__.py index ee53d1c5fd765c..7354092dda4e02 100644 --- a/torch/optim/__init__.py +++ b/torch/optim/__init__.py @@ -6,22 +6,22 @@ future. """ -from torch.optim import lr_scheduler, swa_utils -from torch.optim._adafactor import Adafactor -from torch.optim.adadelta import Adadelta -from torch.optim.adagrad import Adagrad -from torch.optim.adam import Adam -from torch.optim.adamax import Adamax -from torch.optim.adamw import AdamW -from torch.optim.asgd import ASGD -from torch.optim.lbfgs import LBFGS -from torch.optim.nadam import NAdam -from torch.optim.optimizer import Optimizer -from torch.optim.radam import RAdam -from torch.optim.rmsprop import RMSprop -from torch.optim.rprop import Rprop -from torch.optim.sgd import SGD -from torch.optim.sparse_adam import SparseAdam +from torch.optim import lr_scheduler as lr_scheduler, swa_utils as swa_utils +from torch.optim._adafactor import Adafactor as Adafactor +from torch.optim.adadelta import Adadelta as Adadelta +from torch.optim.adagrad import Adagrad as Adagrad +from torch.optim.adam import Adam as Adam +from torch.optim.adamax import Adamax as Adamax +from torch.optim.adamw import AdamW as AdamW +from torch.optim.asgd import ASGD as ASGD +from torch.optim.lbfgs import LBFGS as LBFGS +from torch.optim.nadam import NAdam as NAdam +from torch.optim.optimizer import Optimizer as Optimizer +from torch.optim.radam import RAdam as RAdam +from torch.optim.rmsprop import RMSprop as RMSprop +from torch.optim.rprop import Rprop as Rprop +from torch.optim.sgd import SGD as SGD +from torch.optim.sparse_adam import SparseAdam as SparseAdam Adafactor.__module__ = "torch.optim" From bbe6d79eaf7195ef3699f922b0bdc7fd3eba0f4e Mon Sep 17 00:00:00 2001 From: Xilun Wu <12968408+XilunWu@users.noreply.github.com> Date: Tue, 3 Sep 2024 16:27:50 -0700 Subject: [PATCH 0424/1018] [TorchElastic] add warning when users try to pass a "use_libuv" argument to create_c10d_store (#135062) **Summary** Extend the warning message to be more self-explained Pull Request resolved: https://github.com/pytorch/pytorch/pull/135062 Approved by: https://github.com/shuqiangzhang --- torch/distributed/elastic/utils/distributed.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/torch/distributed/elastic/utils/distributed.py b/torch/distributed/elastic/utils/distributed.py index 4d64fccb698942..2e216da72c2547 100644 --- a/torch/distributed/elastic/utils/distributed.py +++ b/torch/distributed/elastic/utils/distributed.py @@ -38,7 +38,11 @@ def create_c10d_store( use_libuv: Optional[bool] = None, ): if use_libuv is not None: - logger.warning("argument use_libuv is deprecated and ignored.") + logger.warning( + "argument use_libuv is deprecated and ignored. Set USE_LIBUV environment " + 'variable to "0" to disable libuv, or "1" to enable it. If the env var ' + "is not set, libuv will be used by default." + ) # check os.environ for use_libuv use_libuv = os.environ.get("USE_LIBUV", "1") == "1" # libuv is the default option From 78870bb9b1432800651de1b6f51e4eed28ce654d Mon Sep 17 00:00:00 2001 From: Michael Lazos Date: Sun, 1 Sep 2024 13:30:12 -0700 Subject: [PATCH 0425/1018] [Dynamo] Support for proxying frozen dataclasses (#134846) Fixes https://github.com/pytorch/pytorch/issues/133858 Details: Previously Dynamo would treat dataclasses as UserDefinedVariables. This was non-desirable if we would like to proxy the value into the graph, which is needed for TensorSubclassMetadata. To rectify this, frozen dataclasses are now able to be proxied similarly to NamedTuples. We require the object to be frozen, because if arbitrary mutation were allowed, we would need to replay those mutations in the graph after construction of the object. For tracing construction of the variable, the generated `__init__` for the dataclass uses `object.__setattr__` because frozen dataclasses throw errors on the usual `__setattr__` invocation. With this treatment, no special handling is needed in dynamo for frozen dataclass construction. Pull Request resolved: https://github.com/pytorch/pytorch/pull/134846 Approved by: https://github.com/bdhirsh, https://github.com/anijain2305 --- test/dynamo/test_misc.py | 108 ++++++++++++++++++++++++ torch/_dynamo/side_effects.py | 5 +- torch/_dynamo/testing.py | 7 ++ torch/_dynamo/utils.py | 17 +++- torch/_dynamo/variables/builder.py | 6 ++ torch/_dynamo/variables/dicts.py | 10 --- torch/_dynamo/variables/user_defined.py | 84 ++++++++++++++++++ 7 files changed, 225 insertions(+), 12 deletions(-) diff --git a/test/dynamo/test_misc.py b/test/dynamo/test_misc.py index aaf2a846860a47..980e962e5d0524 100644 --- a/test/dynamo/test_misc.py +++ b/test/dynamo/test_misc.py @@ -44,6 +44,7 @@ CompileCounter, CompileCounterWithBackend, expectedFailureDynamic, + requiresPy310, same, skipIfNotPy311, unsupported, @@ -9765,6 +9766,113 @@ def forward(self, x): } self.assertEqual(expected_fqn, gm.meta["dynamo_flat_name_to_original_fqn"]) + def test_proxy_frozen_dataclass(self): + @dataclasses.dataclass(frozen=True) + class TestDataClass: + x: torch.Tensor + y: torch.Tensor + + @allow_in_graph + def inner_fn(dc): + return dc.x + dc.y + + def fn(x, y): + dc = TestDataClass(x, y) + return inner_fn(dc) + + fn_opt = torch.compile(fullgraph=True)(fn) + inps = (torch.ones(2, 2), torch.ones(2, 2)) + actual = fn_opt(*inps) + expected = fn(*inps) + + self.assertEqual(actual, expected) + + def test_reconstruct_frozen_dataclass(self): + @dataclasses.dataclass(frozen=True) + class TestDataClass: + x: torch.Tensor + y: torch.Tensor + + def fn(x, y): + dc = TestDataClass(x, y) + torch._dynamo.graph_break() + return dc.x + dc.y + + fn_opt = torch.compile()(fn) + inps = (torch.ones(2, 2), torch.ones(2, 2)) + actual = fn_opt(*inps) + expected = fn(*inps) + + def test_frozen_dataclass_default_value(self): + @dataclasses.dataclass(frozen=True) + class TestDataClass: + x: torch.Tensor + y: torch.Tensor + z: int = dataclasses.field(default=5) + a: int = 6 + + @allow_in_graph + def inner_fn(dc): + return dc.x + dc.y + dc.z + dc.a + + def fn(x, y): + dc = TestDataClass(x, y) + return inner_fn(dc) + + fn_opt = torch.compile(fullgraph=True)(fn) + inps = (torch.ones(2, 2), torch.ones(2, 2)) + actual = fn_opt(*inps) + expected = fn(*inps) + + self.assertEqual(actual, expected) + + def test_frozen_dataclass_default_factory(self): + @dataclasses.dataclass(frozen=True) + class TestDataClass: + x: torch.Tensor + y: torch.Tensor + z: int = dataclasses.field(default_factory=list) + a: int = dataclasses.field(default_factory=lambda: [5]) + + @allow_in_graph + def inner_fn(dc): + return dc.x + dc.y + dc.a[0] + + def fn(x, y): + dc = TestDataClass(x, y) + return inner_fn(dc) + + fn_opt = torch.compile(fullgraph=True)(fn) + inps = (torch.ones(2, 2), torch.ones(2, 2)) + actual = fn_opt(*inps) + expected = fn(*inps) + + self.assertEqual(actual, expected) + + @requiresPy310 + def test_frozen_dataclass_kw_only(self): + @dataclasses.dataclass(frozen=True) + class TestDataClass: + x: torch.Tensor + y: torch.Tensor + z: int = dataclasses.field(kw_only=True) + a: int = dataclasses.field(kw_only=True) + + @allow_in_graph + def inner_fn(dc): + return dc.x + dc.y + dc.a + dc.z + + def fn(x, y): + dc = TestDataClass(x, y, z=5, a=2) + return inner_fn(dc) + + fn_opt = torch.compile(fullgraph=True)(fn) + inps = (torch.ones(2, 2), torch.ones(2, 2)) + actual = fn_opt(*inps) + expected = fn(*inps) + + self.assertEqual(actual, expected) + def test_shape_env_no_recording(self): main = ShapeEnv(should_record_events=False) diff --git a/torch/_dynamo/side_effects.py b/torch/_dynamo/side_effects.py index 0f8ad871425dae..fc1bd976ff57df 100644 --- a/torch/_dynamo/side_effects.py +++ b/torch/_dynamo/side_effects.py @@ -17,13 +17,14 @@ from .codegen import PyCodegen from .exc import unimplemented from .source import GlobalSource, LocalSource, Source -from .utils import nn_module_new, object_new +from .utils import is_frozen_dataclass, nn_module_new, object_new from .variables.base import ( is_side_effect_safe, MutableLocalBase, MutableLocalSource, VariableTracker, ) +from .variables.user_defined import FrozenDataClassVariable class MutableSideEffects(MutableLocalBase): @@ -285,6 +286,8 @@ def track_object_new_from_user_defined_class( variable_cls = variables.UnspecializedNNModuleVariable elif issubclass(user_cls, MutableMapping): variable_cls = variables.MutableMappingVariable + elif is_frozen_dataclass(user_cls): + variable_cls = FrozenDataClassVariable else: variable_cls = variables.UserDefinedObjectVariable diff --git a/torch/_dynamo/testing.py b/torch/_dynamo/testing.py index 6c5f47067547e9..4922b521bada34 100644 --- a/torch/_dynamo/testing.py +++ b/torch/_dynamo/testing.py @@ -372,6 +372,13 @@ def skipIfPy312(fn): return fn +def requiresPy310(fn): + if sys.version_info >= (3, 10): + return fn + else: + unittest.skip(fn) + + # Controls tests generated in test/inductor/test_torchinductor_dynamic_shapes.py # and test/dynamo/test_dynamic_shapes.py def expectedFailureDynamic(fn): diff --git a/torch/_dynamo/utils.py b/torch/_dynamo/utils.py index 2b9d4b96081471..fc062f57026357 100644 --- a/torch/_dynamo/utils.py +++ b/torch/_dynamo/utils.py @@ -30,6 +30,7 @@ import warnings import weakref from contextlib import contextmanager +from dataclasses import is_dataclass from functools import lru_cache from types import MethodWrapperType from typing import ( @@ -2330,9 +2331,13 @@ def import_submodule(mod: types.ModuleType): def object_has_getattribute(value: Any): + return class_has_getattribute(type(value)) + + +def class_has_getattribute(cls: type): try: if isinstance( - inspect.getattr_static(type(value), "__getattribute__"), + inspect.getattr_static(cls, "__getattribute__"), types.FunctionType, ): return True @@ -2978,6 +2983,16 @@ def to_fake_tensor(t, fake_mode): ) +# NB: this works for both classes and instances +def is_frozen_dataclass(value): + return ( + not object_has_getattribute(value) + and not class_has_getattribute(value) + and is_dataclass(value) + and value.__dataclass_params__.frozen + ) + + def get_first_attr(obj, *attrs): """ Return the first available attribute or throw an exception if none is present. diff --git a/torch/_dynamo/variables/builder.py b/torch/_dynamo/variables/builder.py index f2e981276fcc86..e1eb82ac778ac7 100644 --- a/torch/_dynamo/variables/builder.py +++ b/torch/_dynamo/variables/builder.py @@ -95,6 +95,7 @@ get_fake_value, get_locals_to_steal, get_static_address_type, + is_frozen_dataclass, is_function_or_wrapper, is_lru_cache_wrapped_function, is_namedtuple, @@ -206,6 +207,7 @@ TorchFunctionModeVariable, ) from .user_defined import ( + FrozenDataClassVariable, KeyedJaggedTensorVariable, MutableMappingVariable, SourcelessGraphModuleVariable, @@ -1161,6 +1163,10 @@ def build_key_value(i, k, v): elif issubclass(type(value), MutableMapping): self.install_guards(GuardBuilder.TYPE_MATCH) return MutableMappingVariable(value, source=self.source) + elif is_frozen_dataclass(value): + self.install_guards(GuardBuilder.TYPE_MATCH) + result = FrozenDataClassVariable.create(self.tx, value, source=self.source) + return self.tx.output.side_effects.track_object_existing(value, result) else: return self.wrap_user_defined(value) diff --git a/torch/_dynamo/variables/dicts.py b/torch/_dynamo/variables/dicts.py index f9b74bb2f202f4..f6eed6f7bc3c62 100644 --- a/torch/_dynamo/variables/dicts.py +++ b/torch/_dynamo/variables/dicts.py @@ -671,16 +671,6 @@ def _call_hasattr_customobj( ) -class DataClassVariable(ConstDictVariable): - """ - This class doesn't appear to be used anywhere. - It used to be used to deal with transformers.file_utils.ModelOutput - from huggingface. - - Keeping since we wish to support dataclasses in general in the future - """ - - class CustomizedDictVariable(ConstDictVariable): @staticmethod def is_matching_cls_hf(cls): diff --git a/torch/_dynamo/variables/user_defined.py b/torch/_dynamo/variables/user_defined.py index 362d81d82aea04..8806627e71eb8d 100644 --- a/torch/_dynamo/variables/user_defined.py +++ b/torch/_dynamo/variables/user_defined.py @@ -2,6 +2,7 @@ import collections import contextlib +import dataclasses import enum import functools import inspect @@ -39,6 +40,7 @@ check_constant_args, get_custom_getattr, has_torch_function, + is_frozen_dataclass, is_namedtuple_cls, is_utils_checkpoint, is_wrapper_or_member_descriptor, @@ -452,6 +454,40 @@ def call_function( assert all(x is not None for x in items) return variables.NamedTupleVariable(items, self.value) + elif is_frozen_dataclass(self.value) and self.is_standard_new(): + from .builder import SourcelessBuilder + + fields = dataclasses.fields(self.value) + items = list(args) + items.extend([None] * (len(fields) - len(items))) + + default_kwargs = {} + for field, var_tracker in zip(fields, items): + if var_tracker is None: + if field.name in kwargs: + var_tracker = kwargs[field.name] + else: + if not field.init: + continue + + if field.default is not dataclasses.MISSING: + var_tracker = SourcelessBuilder.create(tx, field.default) + elif field.default_factory is not dataclasses.MISSING: + factory_fn = SourcelessBuilder.create( + tx, field.default_factory + ) + var_tracker = factory_fn.call_function(tx, [], {}) + else: + # if we are subclass, the constructor could possibly + # be missing args + continue + + default_kwargs[field.name] = var_tracker + kwargs.update(default_kwargs) + + var = tx.output.side_effects.track_object_new_from_user_defined_class(self) + var.call_method(tx, "__init__", args, kwargs) + return var elif ( self.is_standard_new() and SideEffects.cls_supports_mutation_side_effects(self.value) @@ -1175,6 +1211,54 @@ def odict_getitem(self, tx: "InstructionTranslator", key): )(collections.OrderedDict.__getitem__(self.value, key.as_python_constant())) +class FrozenDataClassVariable(UserDefinedObjectVariable): + @staticmethod + def create(tx, value, source): + from dataclasses import fields + + assert is_frozen_dataclass(value) + + from .builder import VariableBuilder + + field_map = {} + for field in fields(value): + if hasattr(value, field.name): + field_map[field.name] = VariableBuilder( + tx, AttrSource(source, field.name) + )(getattr(value, field.name)) + + return FrozenDataClassVariable(value, fields=field_map, source=source) + + def __init__(self, value, fields=None, **kwargs) -> None: + super().__init__(value, **kwargs) + if fields is None: + fields = {} + self.fields = fields + + def as_proxy(self): + from dataclasses import fields + + args = [] + kwargs = {} + for field in fields(self.value): + proxy = self.fields[field.name].as_proxy() + if hasattr(field, "kw_only") and field.kw_only: + kwargs[field.name] = proxy + else: + args.append(proxy) + + return self.python_type()(*args, **kwargs) + + # NB: This is called during __init__ for a frozen dataclass + # use this to accumulate the most up-to-date field values + def method_setattr_standard(self, tx: "InstructionTranslator", name, value): + self.fields[name.as_python_constant()] = value + return super().method_setattr_standard(tx, name, value) + + def __repr__(self) -> str: + return f"{self.__class__.__name__}({self.value_type.__name__})" + + class SourcelessGraphModuleVariable(UserDefinedObjectVariable): def __init__( self, From 3842b3e03ce301465ffe34bd3e4ca825c8d3acb1 Mon Sep 17 00:00:00 2001 From: Animesh Jain Date: Wed, 4 Sep 2024 11:31:30 -0700 Subject: [PATCH 0426/1018] [dynamo][super] Corner case where the class is not present in the __mro__ (#135129) I could not come up with a testcase. This was seen in https://github.com/pytorch/pytorch/issues/93633 Pull Request resolved: https://github.com/pytorch/pytorch/pull/135129 Approved by: https://github.com/yanboliang ghstack dependencies: #135039, #135121 --- torch/_dynamo/variables/misc.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/torch/_dynamo/variables/misc.py b/torch/_dynamo/variables/misc.py index 2979bd58e312fe..527474d95969c0 100644 --- a/torch/_dynamo/variables/misc.py +++ b/torch/_dynamo/variables/misc.py @@ -105,7 +105,13 @@ def _resolved_getattr_and_source(self, tx: "InstructionTranslator", name): resolved_class = None resolved_attr = None search_mro = type_to_use.__mro__ - start_index = search_mro.index(search_type) + 1 + + try: + start_index = search_mro.index(search_type) + 1 + except ValueError: + # Corner case where the typevar is not in the mro of the objvar + # https://github.com/python/cpython/blob/3.11/Objects/typeobject.c#L8843-L8844 + return getattr(super(search_type, type_to_use), name), None # Implemented based on https://github.com/python/cpython/blob/3.11/Objects/typeobject.c#L8812 # super has its getattro implementation. The key point is that instead of calling getattr, it checks the # attribute in the class __dict__ From 1965367e0b2f773c51fd08d3a54ba7266a0d7b19 Mon Sep 17 00:00:00 2001 From: "Sun, Jiayi" Date: Wed, 4 Sep 2024 05:22:25 -0700 Subject: [PATCH 0427/1018] [Inductor] Apply loop split optimization in codegen_node (#132389) This PR applies loop split optimization in codegen_node to avoid non-contiguous load. When the vector is loaded in a non-contiguous manner due to a division in the index, we eliminate the division by splitting the loop to avoid non-contiguous load. Example: ``` import torch import torch.nn as nn class GNReLU(torch.nn.Module): def __init__(self, num_groups, num_channels): super(GNReLU, self).__init__() self.gn = nn.GroupNorm(num_groups, num_channels) def forward(self, x): return torch.nn.functional.relu(self.gn(x)) input = torch.randn(2, 960, 96, 96).to(memory_format=torch.channels_last) m = GNReLU(32, 960).eval() compiled_m = torch.compile(m) with torch.no_grad(): compiled_m(input) ``` Generated code: - Before: ``` cpp_fused_native_group_norm_relu_0 = async_compile.cpp_pybinding(['const float*', 'const float*', 'const float*', 'float*', 'float*', 'float*'], ''' #include "/tmp/torchinductor_jiayisun/vu/cvuckxaygqfovv2zu2byqhcmiejbke7mdhf2rpgpr5mlscdev2hg.h" extern "C" void kernel(const float* in_ptr0, const float* in_ptr1, const float* in_ptr2, float* out_ptr0, float* out_ptr1, float* out_ptr2) { #pragma omp parallel num_threads(56) { int tid = omp_get_thread_num(); { #pragma omp for collapse(2) for(long x0=static_cast(0L); x0(2L); x0+=static_cast(1L)) { for(long x1=static_cast(0L); x1(32L); x1+=static_cast(1L)) { { Welford tmp_acc0 = Welford(); Welford> tmp_acc0_vec = Welford>(); Welford> masked_tmp_acc0_vec = Welford>(); static WeightRecp> wrecps0(static_cast(17280L)); for(long x2=static_cast(0L); x2(9216L); x2+=static_cast(1L)) { for(long x3=static_cast(0L); x3(16L); x3+=static_cast(16L)) { auto tmp0 = at::vec::Vectorized::loadu(in_ptr0 + static_cast(x3 + (30L*x1) + (960L*x2) + (8847360L*x0)), 16); tmp_acc0_vec = welford_combine(tmp_acc0_vec, tmp0, &wrecps0); } for(long x3=static_cast(16L); x3(30L); x3+=static_cast(14L)) { auto tmp0 = at::vec::Vectorized::loadu(in_ptr0 + static_cast(x3 + (30L*x1) + (960L*x2) + (8847360L*x0)), 14); masked_tmp_acc0_vec = welford_combine(masked_tmp_acc0_vec, tmp0, 14, &wrecps0); } } tmp_acc0 = welford_combine(tmp_acc0, welford_vec_reduce_all(masked_tmp_acc0_vec)); tmp_acc0 = welford_combine(tmp_acc0, welford_vec_reduce_all(tmp_acc0_vec)); out_ptr0[static_cast(x1 + (32L*x0))] = static_cast(tmp_acc0.mean); out_ptr1[static_cast(x1 + (32L*x0))] = static_cast(tmp_acc0.m2); } } } } { #pragma omp for collapse(2) for(long x0=static_cast(0L); x0(2L); x0+=static_cast(1L)) { for(long x1=static_cast(0L); x1(9216L); x1+=static_cast(1L)) { for(long x2=static_cast(0L); x2(960L); x2+=static_cast(16L)) { auto tmp0 = at::vec::Vectorized::loadu(in_ptr0 + static_cast(x2 + (960L*x1) + (8847360L*x0)), 16); auto tmp1 = [&] { __at_align__ std::array tmpbuf; #pragma GCC unroll 16 for (long x2_inner = 0; x2_inner < 16; x2_inner++) { tmpbuf[x2_inner] = out_ptr0[static_cast((32L*x0) + (c10::div_floor_integer((x2 + x2_inner), 30L)))]; } return at::vec::Vectorized::loadu(tmpbuf.data(), 16); } () ; auto tmp3 = [&] { __at_align__ std::array tmpbuf; #pragma GCC unroll 16 for (long x2_inner = 0; x2_inner < 16; x2_inner++) { tmpbuf[x2_inner] = out_ptr1[static_cast((32L*x0) + (c10::div_floor_integer((x2 + x2_inner), 30L)))]; } return at::vec::Vectorized::loadu(tmpbuf.data(), 16); } () ; auto tmp12 = at::vec::Vectorized::loadu(in_ptr1 + static_cast(x2), 16); auto tmp14 = at::vec::Vectorized::loadu(in_ptr2 + static_cast(x2), 16); auto tmp2 = tmp0 - tmp1; auto tmp4 = static_cast(276480.0); auto tmp5 = at::vec::Vectorized(tmp4); auto tmp6 = tmp3 / tmp5; auto tmp7 = static_cast(1e-05); auto tmp8 = at::vec::Vectorized(tmp7); auto tmp9 = tmp6 + tmp8; auto tmp10 = tmp9.rsqrt(); auto tmp11 = tmp2 * tmp10; auto tmp13 = tmp11 * tmp12; auto tmp15 = tmp13 + tmp14; auto tmp16 = at::vec::clamp_min(tmp15, decltype(tmp15)(0)); tmp16.store(out_ptr2 + static_cast(x2 + (960L*x1) + (8847360L*x0))); } } } } } } ''') async_compile.wait(globals()) del async_compile def call(args): arg2_1, = args args.clear() assert_size_stride(arg2_1, (2, 960, 96, 96), (8847360, 1, 92160, 960)) buf0 = empty_strided_cpu((2, 32, 1, 1), (32, 1, 64, 64), torch.float32) buf1 = empty_strided_cpu((2, 32, 1, 1), (32, 1, 64, 64), torch.float32) buf3 = empty_strided_cpu((2, 960, 96, 96), (8847360, 1, 92160, 960), torch.float32) cpp_fused_native_group_norm_relu_0(arg2_1, _frozen_param3, _frozen_param2, buf0, buf1, buf3) del arg2_1 return (buf3, ) ``` - After: ``` cpp_fused_native_group_norm_relu_0 = async_compile.cpp_pybinding(['const float*', 'const float*', 'const float*', 'float*', 'float*', 'float*'], ''' #include "/tmp/torchinductor_jiayisun/vu/cvuckxaygqfovv2zu2byqhcmiejbke7mdhf2rpgpr5mlscdev2hg.h" extern "C" void kernel(const float* in_ptr0, const float* in_ptr1, const float* in_ptr2, float* out_ptr0, float* out_ptr1, float* out_ptr2) { #pragma omp parallel num_threads(56) { int tid = omp_get_thread_num(); { #pragma omp for collapse(2) for(long x0=static_cast(0L); x0(2L); x0+=static_cast(1L)) { for(long x1=static_cast(0L); x1(32L); x1+=static_cast(1L)) { { Welford tmp_acc0 = Welford(); Welford> tmp_acc0_vec = Welford>(); Welford> masked_tmp_acc0_vec = Welford>(); static WeightRecp> wrecps0(static_cast(17280L)); for(long x2=static_cast(0L); x2(9216L); x2+=static_cast(1L)) { for(long x3=static_cast(0L); x3(16L); x3+=static_cast(16L)) { auto tmp0 = at::vec::Vectorized::loadu(in_ptr0 + static_cast(x3 + (30L*x1) + (960L*x2) + (8847360L*x0)), 16); tmp_acc0_vec = welford_combine(tmp_acc0_vec, tmp0, &wrecps0); } for(long x3=static_cast(16L); x3(30L); x3+=static_cast(14L)) { auto tmp0 = at::vec::Vectorized::loadu(in_ptr0 + static_cast(x3 + (30L*x1) + (960L*x2) + (8847360L*x0)), 14); masked_tmp_acc0_vec = welford_combine(masked_tmp_acc0_vec, tmp0, 14, &wrecps0); } } tmp_acc0 = welford_combine(tmp_acc0, welford_vec_reduce_all(masked_tmp_acc0_vec)); tmp_acc0 = welford_combine(tmp_acc0, welford_vec_reduce_all(tmp_acc0_vec)); out_ptr0[static_cast(x1 + (32L*x0))] = static_cast(tmp_acc0.mean); out_ptr1[static_cast(x1 + (32L*x0))] = static_cast(tmp_acc0.m2); } } } } { #pragma omp for collapse(2) for(long x0=static_cast(0L); x0(2L); x0+=static_cast(1L)) { for(long x1=static_cast(0L); x1(9216L); x1+=static_cast(1L)) { #pragma GCC ivdep for(long x2=static_cast(0L); x2(32L); x2+=static_cast(1L)) { for(long x3=static_cast(0L); x3(16L); x3+=static_cast(16L)) { auto tmp0 = at::vec::Vectorized::loadu(in_ptr0 + static_cast(x3 + (30L*x2) + (960L*x1) + (8847360L*x0)), 16); auto tmp1 = out_ptr0[static_cast(x2 + (32L*x0))]; auto tmp4 = out_ptr1[static_cast(x2 + (32L*x0))]; auto tmp12 = at::vec::Vectorized::loadu(in_ptr1 + static_cast(x3 + (30L*x2)), 16); auto tmp14 = at::vec::Vectorized::loadu(in_ptr2 + static_cast(x3 + (30L*x2)), 16); auto tmp2 = at::vec::Vectorized(tmp1); auto tmp3 = tmp0 - tmp2; auto tmp5 = static_cast(276480.0); auto tmp6 = tmp4 / tmp5; auto tmp7 = static_cast(1e-05); auto tmp8 = decltype(tmp6)(tmp6 + tmp7); auto tmp9 = 1 / std::sqrt(tmp8); auto tmp10 = at::vec::Vectorized(tmp9); auto tmp11 = tmp3 * tmp10; auto tmp13 = tmp11 * tmp12; auto tmp15 = tmp13 + tmp14; auto tmp16 = at::vec::clamp_min(tmp15, decltype(tmp15)(0)); tmp16.store(out_ptr2 + static_cast(x3 + (30L*x2) + (960L*x1) + (8847360L*x0))); } for(long x3=static_cast(16L); x3(30L); x3+=static_cast(14L)) { auto tmp0 = at::vec::Vectorized::loadu(in_ptr0 + static_cast(x3 + (30L*x2) + (960L*x1) + (8847360L*x0)), 14); auto tmp1 = out_ptr0[static_cast(x2 + (32L*x0))]; auto tmp4 = out_ptr1[static_cast(x2 + (32L*x0))]; auto tmp12 = at::vec::Vectorized::loadu(in_ptr1 + static_cast(x3 + (30L*x2)), 14); auto tmp14 = at::vec::Vectorized::loadu(in_ptr2 + static_cast(x3 + (30L*x2)), 14); auto tmp2 = at::vec::Vectorized(tmp1); auto tmp3 = tmp0 - tmp2; auto tmp5 = static_cast(276480.0); auto tmp6 = tmp4 / tmp5; auto tmp7 = static_cast(1e-05); auto tmp8 = decltype(tmp6)(tmp6 + tmp7); auto tmp9 = 1 / std::sqrt(tmp8); auto tmp10 = at::vec::Vectorized(tmp9); auto tmp11 = tmp3 * tmp10; auto tmp13 = tmp11 * tmp12; auto tmp15 = tmp13 + tmp14; auto tmp16 = at::vec::clamp_min(tmp15, decltype(tmp15)(0)); tmp16.store(out_ptr2 + static_cast(x3 + (30L*x2) + (960L*x1) + (8847360L*x0)), 14); } } } } } } } ''') async_compile.wait(globals()) del async_compile def call(args): arg2_1, = args args.clear() assert_size_stride(arg2_1, (2, 960, 96, 96), (8847360, 1, 92160, 960)) buf0 = empty_strided_cpu((2, 32, 1, 1), (32, 1, 64, 64), torch.float32) buf1 = empty_strided_cpu((2, 32, 1, 1), (32, 1, 64, 64), torch.float32) buf3 = empty_strided_cpu((2, 960, 96, 96), (8847360, 1, 92160, 960), torch.float32) cpp_fused_native_group_norm_relu_0(arg2_1, _frozen_param3, _frozen_param2, buf0, buf1, buf3) del arg2_1 return (buf3, ) ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/132389 Approved by: https://github.com/leslie-fang-intel, https://github.com/jansel Co-authored-by: Jiong Gong --- test/inductor/test_cpu_repro.py | 12 ++++ torch/_inductor/codegen/cpp.py | 116 +++++++++++++++++++++++++++++++- torch/_inductor/ir.py | 13 ++++ torch/_inductor/scheduler.py | 13 +++- 4 files changed, 148 insertions(+), 6 deletions(-) diff --git a/test/inductor/test_cpu_repro.py b/test/inductor/test_cpu_repro.py index 7dbe35f766f9a7..d1f5a82cdf6c74 100644 --- a/test/inductor/test_cpu_repro.py +++ b/test/inductor/test_cpu_repro.py @@ -3435,6 +3435,18 @@ def forward(self, x): # 2 generated kernels (one for var_mean, the other for result) check_metrics_vec_kernel_count(2) + # check loop split optimization + if fmt == torch.channels_last: + torch._dynamo.reset() + metrics.reset() + with torch.no_grad(): + opt_mod = torch.compile(mod) + _, code = run_and_get_cpp_code(opt_mod, x) + # check that there are no non_contiguous loads + FileCheck().check_count("__at_align__ std::array", 0, exactly=True).run( + code + ) + def test_int_div_vec(self): def fn(x, y, mode): return torch.div(x, y, rounding_mode=mode) diff --git a/torch/_inductor/codegen/cpp.py b/torch/_inductor/codegen/cpp.py index 6663de7639e8e1..999df248fc6110 100644 --- a/torch/_inductor/codegen/cpp.py +++ b/torch/_inductor/codegen/cpp.py @@ -313,9 +313,12 @@ def visit_modular_indexing(divisor, modulus): @functools.lru_cache -def stride_at_vec_range(index: sympy.Expr, var: sympy.Symbol, vec_length: int): - index_vec_simplified = simplify_index_in_vec_range(index, var, vec_length) - return stride_at(index_vec_simplified, var) +def stride_at_vec_range( + index: sympy.Expr, var: sympy.Symbol, vec_length: Optional[int] = None +): + if vec_length: + index = simplify_index_in_vec_range(index, var, vec_length) + return stride_at(index, var) class OuterLoopFusedSchedulerNode(FusedSchedulerNode): @@ -4085,6 +4088,112 @@ def can_fuse_vertical(self, node1, node2): self._can_fuse_horizontal_impl(node1, node2) and not node1.is_reduction() ) or self.can_fuse_vertical_outer_loop(node1, node2) + def try_loop_split(self, nodes: List[SchedulerNode]): + """ + Apply loop split optimization. + When one of the indexing_exprs contains a division, we eliminate the division by splitting the loop + to avoid non-contiguous loads, subject to the following conditions: + 1. No reduction and no mudular index for all nodes. + 2. Only one node's one indexing_exprs contains a division, according to this indexing_exprs, + we can get the dimension that needs to be split, and the split dimension is contiguous + in all other indexing_exprs. + + For example, if the node's var_ranges: {z0: 2, z1: 9216, z2: 960} and indexing_exprs: + {'index0': 8847360*z0 + 960*z1 + z2, 'index1': 32*z0 + (z2//30), 'index2': z2}, + we will split z2 -> 30*z2 + z3, then the node's var_ranges will be changed to + {z0: 2, z1: 9216, z2: 32, z3: 30} and indexing_exprs will be changed to + {'index0': 8847360*z0 + 960*z1 + 30*z2 + z3, 'index1': 32*z0 + z2, 'index2': 30*z2 + z3}. + """ + + # No reduction and no mudular + if any( + len(node.group[1][1]) != 0 + or any( + expr.has(ModularIndexing) for expr in node._body.indexing_exprs.values() + ) + for node in nodes + ): + return nodes + + split_var = None + split_number = None + divide_index_name = None + num_div = 0 + match_div = False + matched_node = None + + for node in nodes: + assert isinstance(node.node, ir.ComputedBuffer) + _, original_body, _ = node.node.get_default_sizes_body() + for name, expr in original_body.indexing_exprs.items(): + num_div += expr.count(FloorDiv) + if num_div > 1: + return nodes + if expr.count(FloorDiv) == 1: + div_expr = expr.find(FloorDiv).pop() + split_var = div_expr.args[0] + split_number = div_expr.args[1] + divide_index_name = name + if ( + isinstance(split_number, sympy.core.numbers.Integer) + and isinstance(split_var, sympy.core.symbol.Symbol) + and divide_index_name is not None + and all( + stride_at_vec_range(expr, split_var) == 1 + for name, expr in original_body.indexing_exprs.items() + if name != divide_index_name + ) + ): + match_div = True + matched_node = node + + # Only one node contains a division, and the split dimension is contiguous in all other indexing_exprs. + if not match_div: + return nodes + + extra_indexing_constraints = None + + def loop_split(sizes, body, vars): + index_size, reduce_size = sizes + index_vars, reduce_vars = vars + split_idx = index_vars.index(split_var) + new_index_size = index_size.copy() + new_index_size[split_idx] = index_size[split_idx] // split_number + new_index_size.insert(split_idx + 1, split_number) + (new_index_vars, _), var_ranges = dependencies.index_vars_no_squeeze( + new_index_size, reduce_size, prefix="y" + ) + iter_vars = new_index_vars.copy() + divisor_var = iter_vars.pop(split_idx + 1) + iter_vars[split_idx] = split_number * iter_vars[split_idx] + divisor_var + body = ir.LoopBody( + body, [iter_vars, reduce_vars], var_ranges, new_index_vars, reduce_vars + ) + nonlocal extra_indexing_constraints + if not extra_indexing_constraints: + extra_indexing_constraints = ( + body.var_ranges, + list(body.indexing_exprs.values()), + ) + return ( + (new_index_size, reduce_size), + body, + (new_index_vars, reduce_vars), + ) + + # Here decide the final loop order + for node in nodes: + if node == matched_node: + node.recompute_size_and_body(recompute_sizes_body_func=loop_split) + for node in nodes: + if node != matched_node: + node.recompute_size_and_body( + extra_indexing_constraints=extra_indexing_constraints, + recompute_sizes_body_func=loop_split, + ) + + return nodes + def codegen_outer_loop_node( self, node: OuterLoopFusedSchedulerNode, @@ -4287,6 +4396,7 @@ def codegen_node( self.codegen_outer_loop_node(node) else: nodes: List[SchedulerNode] = node.get_nodes() # type: ignore[assignment] + nodes = self.try_loop_split(nodes) cpp_kernel_proxy = CppKernelProxy(kernel_group) cpp_kernel_proxy.codegen_nodes(nodes) kernel_group.finalize_kernel(cpp_kernel_proxy, nodes) diff --git a/torch/_inductor/ir.py b/torch/_inductor/ir.py index 18cdac11ac8077..b9f08fedfe847c 100644 --- a/torch/_inductor/ir.py +++ b/torch/_inductor/ir.py @@ -3717,6 +3717,7 @@ def get_default_sizes_body(self): def simplify_and_reorder( self, extra_indexing_constraints: Optional[Tuple[Dict[Any, Any], List[Any]]] = None, + recompute_sizes_body_func: Optional[Callable[..., Any]] = None, ): """ This is a main place where we do loop transformations in a @@ -3732,6 +3733,8 @@ def simplify_and_reorder( to fuse scheduler nodes with compatible ranges, e.g. (s0*s1*...,) and (s0, s1, s2, ...) on CPU by preventing indexing simplifications and obtaining index/reduce ranges for the scheduler node compatible with other nodes. + Optional argument recompute_sizes_body_func can be used to recompute sizes and body + on the default body. This can be useful to append additional loop transformations. """ ( (index_size, reduce_size), @@ -3739,6 +3742,15 @@ def simplify_and_reorder( (index_vars, reduce_vars), ) = self.get_default_sizes_body() + if recompute_sizes_body_func: + ( + (index_size, reduce_size), + body, + (index_vars, reduce_vars), + ) = recompute_sizes_body_func( + (index_size, reduce_size), body, (index_vars, reduce_vars) + ) + index_formulas = [*body.indexing_exprs.values()] if extra_indexing_constraints is not None: assert ( @@ -3915,6 +3927,7 @@ def should_allocate(self): def simplify_and_reorder( self, extra_indexing_constraints: Optional[Tuple[Dict[Any, Any], List[Any]]] = None, + recompute_sizes_body_func: Optional[Callable[..., Any]] = None, ): return ( ( diff --git a/torch/_inductor/scheduler.py b/torch/_inductor/scheduler.py index 7f9f363fa1da04..179dde1386ebba 100644 --- a/torch/_inductor/scheduler.py +++ b/torch/_inductor/scheduler.py @@ -828,10 +828,12 @@ def __init__( def _compute_attrs( self, extra_indexing_constraints: Optional[Tuple[Dict[Any, Any], List[Any]]] = None, + recompute_sizes_body_func: Optional[Callable[..., Any]] = None, ) -> None: assert isinstance(self.node, (ir.ComputedBuffer, ir.TemplateBuffer)) self._sizes, self._body = self.node.simplify_and_reorder( - extra_indexing_constraints=extra_indexing_constraints + extra_indexing_constraints=extra_indexing_constraints, + recompute_sizes_body_func=recompute_sizes_body_func, ) group_fn = self.scheduler.get_backend(self.node.get_device()).group_fn @@ -856,9 +858,14 @@ def _compute_attrs( ) def recompute_size_and_body( - self, extra_indexing_constraints: Tuple[Dict[Any, Any], List[Any]] + self, + extra_indexing_constraints: Optional[Tuple[Dict[Any, Any], List[Any]]] = None, + recompute_sizes_body_func: Optional[Callable[..., Any]] = None, ) -> None: - self._compute_attrs(extra_indexing_constraints=extra_indexing_constraints) + self._compute_attrs( + extra_indexing_constraints=extra_indexing_constraints, + recompute_sizes_body_func=recompute_sizes_body_func, + ) def refresh_dependencies(self, normalize: bool) -> None: # Fake dependencies are added manually. They can not be analyzed from From ed2ff23c880d1f97b1ba89993475b1506d9a645f Mon Sep 17 00:00:00 2001 From: Animesh Jain Date: Wed, 4 Sep 2024 11:42:11 -0700 Subject: [PATCH 0428/1018] [dynamo] Search for _torchdynamo_inline only for functions (#135130) Issue seen in https://github.com/pytorch/pytorch/issues/93633 Fixes https://github.com/pytorch/pytorch/issues/93633 Unable to create a testcase Pull Request resolved: https://github.com/pytorch/pytorch/pull/135130 Approved by: https://github.com/williamwen42, https://github.com/yanboliang ghstack dependencies: #135039, #135121, #135129 --- torch/_dynamo/variables/builder.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/torch/_dynamo/variables/builder.py b/torch/_dynamo/variables/builder.py index e1eb82ac778ac7..c54c38f9bf1f83 100644 --- a/torch/_dynamo/variables/builder.py +++ b/torch/_dynamo/variables/builder.py @@ -549,7 +549,8 @@ class Autotuner: # Note - There are some nested values where types mismatch! # We want to get those out and wrap those. - value = inspect.getattr_static(value, "_torchdynamo_inline", value) + if is_function_or_wrapper(value): + value = inspect.getattr_static(value, "_torchdynamo_inline", value) # Everything else (NB: order matters!) if is_traceable_wrapper_subclass(value) or istype( From 0407bf08af8453bb4946fe992c4abafb56970dad Mon Sep 17 00:00:00 2001 From: Ke Wen Date: Wed, 4 Sep 2024 20:13:48 +0000 Subject: [PATCH 0429/1018] [PP] Go back to export instead of _export (#134299) Reverts https://github.com/pytorch/pytorch/pull/130998 because FakeTensor + real device suffice to work around the autocast issue in HF. Pull Request resolved: https://github.com/pytorch/pytorch/pull/134299 Approved by: https://github.com/lessw2020 --- torch/distributed/pipelining/_IR.py | 14 +++++--------- 1 file changed, 5 insertions(+), 9 deletions(-) diff --git a/torch/distributed/pipelining/_IR.py b/torch/distributed/pipelining/_IR.py index 7b6e39cfbe44bb..036ad6cf8621eb 100644 --- a/torch/distributed/pipelining/_IR.py +++ b/torch/distributed/pipelining/_IR.py @@ -13,7 +13,6 @@ import torch.fx as fx from torch.distributed import ProcessGroup from torch.export import ExportedProgram -from torch.export._trace import _export from torch.export.unflatten import ( _assign_attr, _AttrKind, @@ -1005,14 +1004,11 @@ def _trace_with_export( ) -> ExportedProgram: logger.info("Tracing model ...") try: - with torch.no_grad(): - ep = _export( - mod, - example_args, - example_kwargs, - strict=True, - pre_dispatch=False, - ) + ep = torch.export.export( + mod, + example_args, + example_kwargs, + ) except Exception as e: raise RuntimeError( "It seems that we cannot capture your model as a full graph. " From 9ed4b638355691647134915a8f53df0c87a65b5a Mon Sep 17 00:00:00 2001 From: Shangdi Yu Date: Wed, 4 Sep 2024 23:26:35 +0000 Subject: [PATCH 0430/1018] [export] Fix indentation (#135128) Summary: as title Test Plan: CI Differential Revision: D62195680 Pull Request resolved: https://github.com/pytorch/pytorch/pull/135128 Approved by: https://github.com/tugsbayasgalan --- torch/_export/__init__.py | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/torch/_export/__init__.py b/torch/_export/__init__.py index fea70b061a73d3..91d893e05cb89c 100644 --- a/torch/_export/__init__.py +++ b/torch/_export/__init__.py @@ -184,20 +184,20 @@ def print_export_warning(): range_constraints=range_constraints, ) - error_message = \ - """ - Calling train() or eval() is not supported for exported models. - Alternatively, you may override these methods to do custom user behavior as follows: + error_message = \ + """ + Calling train() or eval() is not supported for exported models. + Alternatively, you may override these methods to do custom user behavior as follows: - def _my_train(self, mode: bool = True): - ... + def _my_train(self, mode: bool = True): + ... - def _my_eval(self): - ... + def _my_eval(self): + ... - model.train = types.MethodType(_my_train, model) - model.eval = types.MethodType(_my_eval, model) - """ + model.train = types.MethodType(_my_train, model) + model.eval = types.MethodType(_my_eval, model) + """ def _train(self, mode: bool = True): raise NotImplementedError(error_message) From a0562c173af008b826ce7340afa90352ef5e24ed Mon Sep 17 00:00:00 2001 From: eqy Date: Wed, 4 Sep 2024 23:37:16 +0000 Subject: [PATCH 0431/1018] [CUDA][complex][TF32] Update `test_noncontiguous_samples` tolerances for `complex64` (#134526) Recent cuDNN heuristics change surfaces same TF32 issue as `float32` Pull Request resolved: https://github.com/pytorch/pytorch/pull/134526 Approved by: https://github.com/ezyang --- torch/testing/_internal/common_methods_invocations.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py index 3b81da0a55872c..869270a600bc90 100644 --- a/torch/testing/_internal/common_methods_invocations.py +++ b/torch/testing/_internal/common_methods_invocations.py @@ -15083,7 +15083,8 @@ def sample_inputs_alias_copy(op_info, device, dtype, requires_grad, **kwargs): ), # TF32 DecorateInfo( - toleranceOverride({torch.float32: tol(atol=5e-3, rtol=1e-3)}), + toleranceOverride({torch.float32: tol(atol=5e-3, rtol=1e-3), + torch.complex64: tol(atol=5e-3, rtol=1e-3)}), 'TestCommon', 'test_noncontiguous_samples', ), DecorateInfo( From 053387e8301c7ac78a135985e4fb50bd516a55f9 Mon Sep 17 00:00:00 2001 From: titaiwangms Date: Wed, 4 Sep 2024 23:42:11 +0000 Subject: [PATCH 0432/1018] [ONNX] Support output_names in dynamic_axes when dynamo=True (#135134) Previous to this PR, if output_names shows in dynamic_axes, it errors when we turn it to dynamic_shapes of torch.export, as we only recognized input_names. Pull Request resolved: https://github.com/pytorch/pytorch/pull/135134 Approved by: https://github.com/justinchuby --- test/onnx/exporter/test_api.py | 22 ++++++++++++++++++++++ torch/onnx/_internal/exporter/_compat.py | 14 +++++++++++--- 2 files changed, 33 insertions(+), 3 deletions(-) diff --git a/test/onnx/exporter/test_api.py b/test/onnx/exporter/test_api.py index 907db3d96a0e18..6f2a1f942313dd 100644 --- a/test/onnx/exporter/test_api.py +++ b/test/onnx/exporter/test_api.py @@ -72,6 +72,28 @@ def test_dynamic_axes_supports_partial_dynamic_shapes(self): }, ) + def test_dynamic_axes_supports_output_names(self): + self.assert_export( + SampleModelForDynamicShapes(), + (torch.randn(2, 2, 3), {"b": torch.randn(2, 2, 3)}), + dynamic_axes={ + "b": [0, 1, 2], + }, + ) + onnx_program = torch.onnx.export( + SampleModelForDynamicShapes(), + ( + torch.randn(2, 2, 3), + torch.randn(2, 2, 3), + ), + input_names=["x", "b"], + output_names=["x_out", "b_out"], + dynamic_axes={"b": [0, 1, 2], "b_out": [0, 1, 2]}, + dynamo=True, + ) + assert onnx_program is not None + onnx_testing.assert_onnx_program(onnx_program) + def test_saved_f_exists_after_export(self): with common_utils.TemporaryFileName(suffix=".onnx") as path: _ = torch.onnx.export( diff --git a/torch/onnx/_internal/exporter/_compat.py b/torch/onnx/_internal/exporter/_compat.py index 7d38c2ce1415e9..72d2a411c980cd 100644 --- a/torch/onnx/_internal/exporter/_compat.py +++ b/torch/onnx/_internal/exporter/_compat.py @@ -29,7 +29,9 @@ def _signature(model) -> inspect.Signature: def _from_dynamic_axes_to_dynamic_shapes( model, + *, dynamic_axes=None, + output_names: set[str], input_names: Sequence[str] | None = None, ) -> dict[str, Any] | None: """ @@ -69,10 +71,13 @@ def _from_dynamic_axes_to_dynamic_shapes( # for the exported program dynamic_shapes_to_exported_program = {} for input_name, axes in dynamic_axes.items(): - # input_name can be either from inptu_names or from the model inputs + if input_name in output_names: + # User specified an output name as a dynamic axis, so we skip it + continue + # input_name can be either from input_names or from the model inputs if input_name not in input_names_to_model_inputs: raise ValueError( - f"dynamix axis: {input_name} is not found in the input names: {input_names}" + f"dynamic axis: {input_name} is not found in the input names: {input_names}" ) model_input_name = input_names_to_model_inputs[input_name] if isinstance(axes, dict): @@ -147,7 +152,10 @@ def export_compat( args, kwargs = _get_torch_export_args(args, kwargs) if dynamic_shapes is None and dynamic_axes is not None: dynamic_shapes = _from_dynamic_axes_to_dynamic_shapes( - model, dynamic_axes, input_names + model, + dynamic_axes=dynamic_axes, + input_names=input_names, + output_names=set(output_names or ()), ) try: From 42c66b70cb77002b1e0bb47719ff4e3d931cb3f4 Mon Sep 17 00:00:00 2001 From: Benjamin Glass Date: Wed, 4 Sep 2024 21:01:52 +0000 Subject: [PATCH 0433/1018] Fix unintentional deduplication of returned tensors (#134726) When CSE was used, returned tensors that had gone through identical processing steps but were distinct from a data perspective were pruned out of the graph. This commit protects tensors which are directly output from being pruned, and adds a test for this behavior. Closes #88813 and #114344 Pull Request resolved: https://github.com/pytorch/pytorch/pull/134726 Approved by: https://github.com/amjames, https://github.com/zou3519, https://github.com/bdhirsh --- test/dynamo/test_repros.py | 56 +++++++++++++++++++++++++++++++ torch/_functorch/compile_utils.py | 33 ++++++++++++++++++ 2 files changed, 89 insertions(+) diff --git a/test/dynamo/test_repros.py b/test/dynamo/test_repros.py index f2a96131952f6e..19e9f7d9617240 100644 --- a/test/dynamo/test_repros.py +++ b/test/dynamo/test_repros.py @@ -5854,6 +5854,62 @@ def f(x): out = f(x) self.assertEqual(torch.flip(torch.cumsum(torch.flip(x, [0]), 0), [0]), out[0]) + # https://github.com/pytorch/pytorch/issues/88813 + def test_return_value_duplication_tensor(self) -> None: + def fn(val: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + return val * 2, val * 2 + + x = torch.randn(2, requires_grad=True) + + expect = fn(x) + self.assertNotEqual( + expect[0].untyped_storage().data_ptr(), + expect[1].untyped_storage().data_ptr(), + ) + + actual = torch.compile(fn, backend="aot_eager")(x) + self.assertNotEqual( + actual[0].untyped_storage().data_ptr(), + actual[1].untyped_storage().data_ptr(), + ) + + # https://github.com/pytorch/pytorch/issues/114344 + def test_return_value_duplication_mixed_grad(self) -> None: + def fn(val: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + with torch.no_grad(): + out0 = val + 1 + out1 = val + 1 + return out0, out1 + + x = torch.randn(2, requires_grad=True) + + with torch.enable_grad(): + expect = fn(x) + actual = torch.compile(fn, backend="aot_eager")(x) + + self.assertEqual(expect[0].requires_grad, actual[0].requires_grad) + self.assertEqual(expect[1].requires_grad, actual[1].requires_grad) + + # https://github.com/pytorch/pytorch/pull/134726#discussion_r1738774371 + def test_return_value_duplication_scalar(self) -> None: + def fn(val: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + x, y = val * 2, val * 2 + return x[0], y[0] + + x = torch.randn(2, requires_grad=True) + + expect = fn(x) + self.assertNotEqual( + expect[0].untyped_storage().data_ptr(), + expect[1].untyped_storage().data_ptr(), + ) + + actual = torch.compile(fn, backend="aot_eager")(x) + self.assertNotEqual( + actual[0].untyped_storage().data_ptr(), + actual[1].untyped_storage().data_ptr(), + ) + instantiate_parametrized_tests(ReproTests) diff --git a/torch/_functorch/compile_utils.py b/torch/_functorch/compile_utils.py index ac7e0d5fdffd34..3bf61f1af3bf3c 100644 --- a/torch/_functorch/compile_utils.py +++ b/torch/_functorch/compile_utils.py @@ -5,6 +5,7 @@ import torch import torch.fx as fx +from torch.multiprocessing.reductions import StorageWeakRef from torch.utils import _pytree as pytree from torch.utils._pytree import tree_flatten @@ -50,6 +51,37 @@ def fx_graph_cse(fx_g: torch.fx.graph.Graph): ) compute_mutation_region_ids(fx_g) # type: ignore[arg-type] + + # Make a set of separate storages returned from the output, which will be preserved + # when pruning. This prevents us from deduplicating returned tensors which have + # experienced identical operations, but are separate data structures in eager mode. + output_node: fx.Node = list(fx_g.nodes)[-1] + assert output_node.op == "output" + + def checkable_node(node: fx.Node) -> bool: + """We can evaluate only nodes that represent tensors with defined storage.""" + if "val" not in node.meta or not isinstance(node.meta["val"], torch.Tensor): + return False + + try: + node.meta["val"].untyped_storage() + except NotImplementedError: + return False + + return True + + output_storages = { + StorageWeakRef(n.meta["val"].untyped_storage()) + for n in output_node.all_input_nodes + if checkable_node(n) + } + nodes_that_alias_outputs = { + n + for n in fx_g.nodes + if checkable_node(n) + and StorageWeakRef(n.meta["val"].untyped_storage()) in output_storages + } + for n in fx_g.nodes: # The placeholder, output, and get_attr nodes are copied to the new graph without change # do not CSE away random operations @@ -62,6 +94,7 @@ def fx_graph_cse(fx_g: torch.fx.graph.Graph): # Also, aten.empty is almost always fusible into its consumer, # so it's not worth CSEing. or get_aten_target(n) is aten.empty + or n in nodes_that_alias_outputs ): new_node = new_graph.node_copy(n, lambda x: env[x]) env[n] = new_node From 47241dc64c855a1666ecb66fb574a752a6ba50f5 Mon Sep 17 00:00:00 2001 From: Howard Huang Date: Wed, 4 Sep 2024 13:01:16 -0700 Subject: [PATCH 0434/1018] [PP] Fix zero bubble composability with DP (#134052) Moved all the backward functions (`stage_backward_input`, `stage_backward_weight`, `stage_backward`) under the same `backward_maybe_with_nosync` function which controls the logic of the data parallel wrappers. FSDP was not working with zero bubble PP because there will be twice as many "backward" calls and we update the weight gradients after `autograd.grad` is called. As a result, we need to manually call the FSDP `post_backward_hook()` after the weights have the correct gradients. Fixes the tests: `python test/distributed/_composable/test_composability/test_pp_composability.py ComposabilityTest.test_manual_with_data_parallel_dp_type_FSDP_ScheduleClass0_use_new_runtime_False` `python test/distributed/_composable/test_composability/test_pp_composability.py ComposabilityTest.test_manual_with_data_parallel_dp_type_DDP_ScheduleClass0_use_new_runtime_False` Pull Request resolved: https://github.com/pytorch/pytorch/pull/134052 Approved by: https://github.com/kwen2501 --- .../test_pp_composability.py | 8 +- test/distributed/pipelining/test_backward.py | 6 +- torch/distributed/pipelining/_backward.py | 38 +++--- torch/distributed/pipelining/stage.py | 109 +++++++++++++----- 4 files changed, 112 insertions(+), 49 deletions(-) diff --git a/test/distributed/_composable/test_composability/test_pp_composability.py b/test/distributed/_composable/test_composability/test_pp_composability.py index 328f677be7fcfe..e173bb34e3eaa2 100644 --- a/test/distributed/_composable/test_composability/test_pp_composability.py +++ b/test/distributed/_composable/test_composability/test_pp_composability.py @@ -17,6 +17,7 @@ ScheduleFlexibleInterleaved1F1B, ScheduleGPipe, ScheduleInterleaved1F1B, + ScheduleInterleavedZeroBubble, ScheduleLoopedBFS, ) from torch.nn.parallel import DistributedDataParallel as DDP @@ -86,6 +87,7 @@ def device(self): ScheduleInterleaved1F1B, ScheduleLoopedBFS, ScheduleFlexibleInterleaved1F1B, + ScheduleInterleavedZeroBubble, ], ) @parametrize("use_new_runtime", [False, True]) @@ -233,7 +235,9 @@ def build_stage(stage_idx, num_stages): name = ".".join(parts) ref_p = ref_parameters[name] self.assertTrue(isinstance(p.grad, DTensor)) - self.assertEqual(ref_p.grad, p.grad.full_tensor()) + torch.testing.assert_close( + ref_p.grad, p.grad.full_tensor(), rtol=1e-5, atol=5e-5 + ) elif dp_type == "DDP": for partial_model, offset in zip(partial_models, offsets): for name, p in partial_model.named_parameters(): @@ -241,7 +245,7 @@ def build_stage(stage_idx, num_stages): parts[0] = str(int(parts[0]) + offset) name = ".".join(parts) ref_p = ref_parameters[name] - self.assertEqual(ref_p.grad, p.grad) + torch.testing.assert_close(ref_p.grad, p.grad, rtol=1e-5, atol=5e-5) torch.distributed.destroy_process_group() diff --git a/test/distributed/pipelining/test_backward.py b/test/distributed/pipelining/test_backward.py index b99e303ca9a61b..328eddcce50695 100644 --- a/test/distributed/pipelining/test_backward.py +++ b/test/distributed/pipelining/test_backward.py @@ -77,7 +77,7 @@ def test_stage_backward_input(self): dinputs, param_groups = stage_backward_input( stage_outputs=(loss,), output_grads=None, - stage_inputs=[x], + input_values=[x], weights=mod.parameters(), ) @@ -112,7 +112,7 @@ def test_stage_backward_weight(self): dinputs, param_groups = stage_backward_input( stage_outputs=(loss,), output_grads=None, - stage_inputs=[x], + input_values=[x], weights=mod.parameters(), ) @@ -160,7 +160,7 @@ def test_stage_backward_weight_multiple_iters(self): dinputs, param_groups = stage_backward_input( stage_outputs=(loss,), output_grads=None, - stage_inputs=[x], + input_values=[x], weights=mod.parameters(), ) diff --git a/torch/distributed/pipelining/_backward.py b/torch/distributed/pipelining/_backward.py index 5aaa6180721d97..a4c5516e83da01 100644 --- a/torch/distributed/pipelining/_backward.py +++ b/torch/distributed/pipelining/_backward.py @@ -159,7 +159,7 @@ def get_param_groups(inputs: List[Node], params: List[Node]) -> List[Dict[str, A def stage_backward_input( stage_outputs: List[torch.Tensor], output_grads: Optional[List[torch.Tensor]], - stage_inputs: List[torch.Tensor], + input_values: List[torch.Tensor], weights: Iterator[Parameter], ): """ @@ -169,7 +169,7 @@ def stage_backward_input( filter(None, map(_get_grad_fn_or_grad_acc, stage_outputs)) ) stage_input_grad_fns: List[Node] = list( - filter(None, map(_get_grad_fn_or_grad_acc, stage_inputs)) + filter(None, map(_get_grad_fn_or_grad_acc, input_values)) ) weight_grad_fns: List[Node] = list( filter(None, map(_get_grad_fn_or_grad_acc, weights)) @@ -197,7 +197,7 @@ def hook(grad_inputs): intermediate.register_prehook(get_hook(param_group, i)) # Stage 0 inputs do not require grads? Should we skip in that case? - if all(tensor.requires_grad for tensor in stage_inputs): + if all(tensor.requires_grad for tensor in input_values): if output_grads is None: # In case this is the loss and there are no output_grads, then we just use 1s output_grads = [ @@ -206,13 +206,13 @@ def hook(grad_inputs): dinputs = torch.autograd.grad( stage_outputs, - inputs=stage_inputs, + inputs=input_values, grad_outputs=output_grads, retain_graph=True, ) # update the gradients for inputs - for i, inp in enumerate(stage_inputs): + for i, inp in enumerate(input_values): if inp.grad is None: inp.grad = dinputs[i] else: @@ -225,7 +225,14 @@ def hook(grad_inputs): def stage_backward_weight( weights: Iterator[Parameter], param_groups: List[Dict[str, Any]] ): - all_dweights = dict() + # map weights to param_group_weights + grad_acc_to_weight = {} + weight_grads = [] + for index, weight in enumerate(weights): + grad_acc = _get_grad_fn_or_grad_acc(weight) + grad_acc_to_weight[grad_acc] = weight, index + weight_grads.append(weight.grad) + for param_group in param_groups: # TODO: Handle case where intermediate can have multiple outputs intermediate_edges = tuple( @@ -242,19 +249,14 @@ def stage_backward_weight( weights_edges, grad_outputs=sum(param_group["grads"], tuple()), ) - for w, dw in zip(param_group["params"], dweights): - all_dweights[w] = dw + for grad_acc, dw in zip(param_group["params"], dweights): + weight, index = grad_acc_to_weight[grad_acc] + if weight.grad is None: + weight.grad = dw + else: + weight.grad += dw # return grads in the original order weights were provided in - out = [] - for w in weights: - grad_acc = _get_grad_fn_or_grad_acc(w) - dweight = all_dweights[grad_acc] - out.append(dweight) - if w.grad is None: - w.grad = dweight - else: - w.grad += dweight - return out + return weight_grads def stage_backward( diff --git a/torch/distributed/pipelining/stage.py b/torch/distributed/pipelining/stage.py index 4761ba8dcb7787..3c4abfb7b0b352 100644 --- a/torch/distributed/pipelining/stage.py +++ b/torch/distributed/pipelining/stage.py @@ -10,7 +10,7 @@ import torch.fx as fx import torch.nn as nn from torch._subclasses.fake_tensor import FakeTensor -from torch.distributed._composable.fsdp.fully_shard import FSDPModule +from torch.distributed._composable.fsdp.fully_shard import FSDPModule, fully_shard from torch.fx.node import map_aggregate from torch.nn.parallel import DistributedDataParallel @@ -468,14 +468,40 @@ def forward_maybe_with_nosync(self, *args, **kwargs): out_val = self.submod(*args, **kwargs) return out_val - def backward_maybe_with_nosync(self, bwd_kwargs: Dict): + def backward_maybe_with_nosync(self, backward_type, bwd_kwargs: Dict): """ Whether using PP with FSDP or DDP, there are some runtime differences between the last backward step and the other steps. Namely, we need to accumulate gradients on previous steps and reduce them on the last step, but there are additional state-variables and performance considerations depending on the data parallelism used. This helper should adapt any pipeline parallel schedule to work with common/supported data parallel libraries. """ - last_backward = self._seen_bwd_chunks == self.chunks - 1 # type: ignore[operator] + full_backward = bwd_kwargs["full_backward"] + if full_backward: + last_backward = self._seen_bwd_chunks == self.chunks - 1 # type: ignore[operator] + else: + # For backwards are split into weight and input, we will see twice as many bwd_chunks + last_backward = self._seen_bwd_chunks == 2 * self.chunks - 1 # type: ignore[operator] + + def perform_backward(backward_type): + if backward_type == "full": + return lambda: stage_backward( + bwd_kwargs["stage_output"], + bwd_kwargs["output_grads"], + bwd_kwargs["input_values"], + ) + elif backward_type == "input": + return lambda: stage_backward_input( + bwd_kwargs["stage_output"], + bwd_kwargs["output_grads"], + bwd_kwargs["input_values"], + self.submod.parameters(), + ) + elif backward_type == "weight": + return lambda: stage_backward_weight( + self.submod.parameters(), bwd_kwargs["param_groups"] + ) + else: + raise RuntimeError(f"Unknown backward type: {backward_type}") # If submod is wrapped by DDP if isinstance(self.submod, DistributedDataParallel): @@ -489,21 +515,41 @@ def backward_maybe_with_nosync(self, bwd_kwargs: Dict): ) ) ) - grads_input = stage_backward(**bwd_kwargs) + result = perform_backward(backward_type)() else: with self.submod.no_sync(): # type: ignore[operator] - grads_input = stage_backward(**bwd_kwargs) + result = perform_backward(backward_type)() # If submod is a FSDP module elif isinstance(self.submod, FSDPModule): - self.submod.set_is_last_backward(last_backward) - self.submod.set_requires_gradient_sync(last_backward) - grads_input = stage_backward(**bwd_kwargs) + self.submod.set_is_last_backward(False) + self.submod.set_reshard_after_backward(False) + self.submod.set_requires_gradient_sync(False) + result = perform_backward(backward_type)() + if last_backward: + # Manually call post backward for FSDP + def run_post_backward(fsdp_module: FSDPModule) -> None: + fsdp_module.set_is_last_backward(True) + fsdp_module.set_reshard_after_backward(True) + fsdp_module.set_requires_gradient_sync(True) + fsdp_state = fully_shard.state(fsdp_module) + for state in fsdp_state._state_ctx.all_states: + if state._fsdp_param_group: + state._fsdp_param_group.post_backward() + + run_post_backward(self.submod) else: # Non-DP submodule, regular backward - grads_input = stage_backward(**bwd_kwargs) + result = perform_backward(backward_type)() self._seen_bwd_chunks += 1 - return grads_input + + if isinstance(result, tuple) and len(result) == 2: + # for stage_backward_input() + grads, param_groups = result + else: + grads, param_groups = result, None + + return grads, param_groups def forward_one_chunk( self, @@ -611,41 +657,43 @@ def backward_one_chunk( "input_values": input_values, } + # Save full_backward + bwd_kwargs["full_backward"] = full_backward + # Custom backward function if self.dw_builder: # TODO: We may want to change our semantics so we are allowed to ignore # the 'dw_builder' and call full_backward directly when it is a full_backward op. - self.grads_input = self.backward_maybe_with_nosync(bwd_kwargs) + self.grads_input, _ = self.backward_maybe_with_nosync("full", bwd_kwargs) if full_backward: self.dw_builder()() else: self.dw_runner[bwd_chunk_id] = self.dw_builder() else: if full_backward: - self.grads_input = self.backward_maybe_with_nosync(bwd_kwargs) + self.grads_input, _ = self.backward_maybe_with_nosync( + "full", bwd_kwargs + ) else: # perform the partial backwards for the inputs with a custom backward function # when the "stage_ouput" is a loss, then it is a tensor, otherwise it is a tuple of tensors if isinstance(bwd_kwargs["stage_output"], torch.Tensor): bwd_kwargs["stage_output"] = (bwd_kwargs["stage_output"],) - dinputs, param_groups = stage_backward_input( - bwd_kwargs["stage_output"], - bwd_kwargs["output_grads"], - bwd_kwargs["input_values"], - self.submod.parameters(), + grads_input, param_groups = self.backward_maybe_with_nosync( + "input", bwd_kwargs ) + # TODO: we dont need to save this, add to dw_runner? self.backward_state[bwd_chunk_id] = ( - dinputs, input_values, param_groups, bwd_kwargs["stage_output"], bwd_kwargs["output_grads"], ) - self.grads_input = dinputs - # save a dw_runner with the `stage_backward_weight` function - self.dw_runner[bwd_chunk_id] = stage_backward_weight + self.grads_input = grads_input + # Save a placeholder for the dw_runner + self.dw_runner[bwd_chunk_id] = lambda: None logger.debug("%s Backwarded chunk %s", self.log_prefix, bwd_chunk_id) def backward_weight_one_chunk(self, bwd_chunk_id: int): @@ -658,23 +706,32 @@ def backward_weight_one_chunk(self, bwd_chunk_id: int): self.dw_runner.pop(bwd_chunk_id)() else: ( - dinputs, input_values, param_groups, stage_output, output_grads, ) = self.backward_state.pop(bwd_chunk_id) + if self.stage_index != 0: - dweights = self.dw_runner.pop(bwd_chunk_id)( - self.submod.parameters(), param_groups - ) + bwd_kwargs = { + "stage_output": stage_output, + "param_groups": param_groups, + "full_backward": False, + } + weight_grads, _ = self.backward_maybe_with_nosync("weight", bwd_kwargs) else: # TODO: figure out a better way to do this: # if inputs does not require gradient, # then the parameter group will not be fully captured during stage_backward_input # in this case, we need call grad directly on the parameters # To solve: make input fn do the intersect compute and then finish it off during W - torch.autograd.backward(stage_output, grad_tensors=output_grads) + bwd_kwargs = { + "stage_output": stage_output, + "output_grads": output_grads, + "input_values": input_values, + "full_backward": False, + } + self.backward_maybe_with_nosync("full", bwd_kwargs) def _validate_fwd_input(self, args, kwargs): """Raises a RuntimeError if shapes of input args/kwargs do not match the shapes configured for this stage.""" From 381bd7c779b34f271d6eb46e8fb177670fffdabb Mon Sep 17 00:00:00 2001 From: chuanqiw Date: Thu, 5 Sep 2024 00:02:46 +0000 Subject: [PATCH 0435/1018] [CI] Build pytorch wheel with Torch XPU Operators on Windows (#133151) # Description This pipeline enables the CI build on Windows with PR labeled with ciflow/xpu. This will build torch binary with Torch XPU Operators on Windows using Vision Studio BuildTools 2022. # Changes 1. Install xpu batch file (install_xpu.bat) - Check if build machine has oneAPI in environment, and if the version of it is latest. If not, install the latest public released oneAPI in the machine. 2. GHA callable pipeline (_win-build.yml) - Set vc_year and use_xpu as parameter to set build wheel environment. 3. GHA workflow (xpu.yml) - Add a new windows build job and pass parameters to it. 4. Build wheels script (.ci/pytorch/win-test-helpers/build_pytorch.bat) - Prepare environment for building, e.g. install oneAPI bundle. # Note 1. For building wheels on Intel GPU, you need Vision Studio BuildTools version >= 2022 2. This pipeline requires to use Vision Studio BuildTools 2022 to build wheels. For now, we specify "windows.4xlarge.nonephemeral" as build machine label in the yaml file. We will request to add self-hosted runners with Intel GPU and Vision Studio BuildTools 2022 installed soon. Work for #114850 Pull Request resolved: https://github.com/pytorch/pytorch/pull/133151 Approved by: https://github.com/chuanqi129, https://github.com/atalman Co-authored-by: chuanqiw --- .../win-test-helpers/build_pytorch.bat | 16 ++++ .../installation-helpers/install_xpu.bat | 91 +++++++++++++++++++ .github/workflows/_win-build.yml | 13 ++- .github/workflows/xpu.yml | 9 ++ 4 files changed, 128 insertions(+), 1 deletion(-) create mode 100644 .ci/pytorch/win-test-helpers/installation-helpers/install_xpu.bat diff --git a/.ci/pytorch/win-test-helpers/build_pytorch.bat b/.ci/pytorch/win-test-helpers/build_pytorch.bat index 824acc09c5854a..92078d22326396 100644 --- a/.ci/pytorch/win-test-helpers/build_pytorch.bat +++ b/.ci/pytorch/win-test-helpers/build_pytorch.bat @@ -24,6 +24,12 @@ call %INSTALLER_DIR%\install_sccache.bat if errorlevel 1 goto fail if not errorlevel 0 goto fail +if "%USE_XPU%"=="1" ( + :: Install xpu support packages + call %INSTALLER_DIR%\install_xpu.bat + if errorlevel 1 exit /b 1 +) + :: Miniconda has been installed as part of the Windows AMI with all the dependencies. :: We just need to activate it here call %INSTALLER_DIR%\activate_miniconda3.bat @@ -43,6 +49,16 @@ if "%VC_VERSION%" == "" ( ) if errorlevel 1 goto fail if not errorlevel 0 goto fail + +if "%USE_XPU%"=="1" ( + :: Activate xpu environment - VS env is required for xpu + call "C:\Program Files (x86)\Intel\oneAPI\setvars.bat" + if errorlevel 1 exit /b 1 + :: Reduce build time. Only have MTL self-hosted runner now + SET TORCH_XPU_ARCH_LIST=xe-lpg + SET USE_KINETO=0 +) + @echo on popd diff --git a/.ci/pytorch/win-test-helpers/installation-helpers/install_xpu.bat b/.ci/pytorch/win-test-helpers/installation-helpers/install_xpu.bat new file mode 100644 index 00000000000000..b9fd597929dded --- /dev/null +++ b/.ci/pytorch/win-test-helpers/installation-helpers/install_xpu.bat @@ -0,0 +1,91 @@ +@echo on +REM Description: Install Intel Support Packages on Windows +REM BKM reference: https://www.intel.com/content/www/us/en/developer/articles/tool/pytorch-prerequisites-for-intel-gpu/2-5.html + +set XPU_INSTALL_MODE=%~1 +if "%XPU_INSTALL_MODE%"=="" goto xpu_bundle_install_start +if "%XPU_INSTALL_MODE%"=="bundle" goto xpu_bundle_install_start +if "%XPU_INSTALL_MODE%"=="driver" goto xpu_driver_install_start +if "%XPU_INSTALL_MODE%"=="all" goto xpu_driver_install_start + +:arg_error + +echo Illegal XPU installation mode. The value can be "bundle"/"driver"/"all" +echo If keep the value as space, will use default "bundle" mode +exit /b 1 + +:xpu_driver_install_start +:: TODO Need more testing for driver installation +set XPU_DRIVER_LINK=https://downloadmirror.intel.com/830975/gfx_win_101.5972.exe +curl -o xpu_driver.exe --retry 3 --retry-all-errors -k %XPU_DRIVER_LINK% +echo "XPU Driver installing..." +start /wait "Intel XPU Driver Installer" "xpu_driver.exe" +if errorlevel 1 exit /b 1 +del xpu_driver.exe +if "%XPU_INSTALL_MODE%"=="driver" goto xpu_install_end + +:xpu_bundle_install_start + +set XPU_BUNDLE_PARENT_DIR=C:\Program Files (x86)\Intel\oneAPI +set XPU_BUNDLE_URL=https://registrationcenter-download.intel.com/akdlm/IRC_NAS/9d1a91e2-e8b8-40a5-8c7f-5db768a6a60c/w_intel-for-pytorch-gpu-dev_p_0.5.3.37_offline.exe +set XPU_PTI_URL=https://registrationcenter-download.intel.com/akdlm/IRC_NAS/9d1a91e2-e8b8-40a5-8c7f-5db768a6a60c/w_intel-pti-dev_p_0.9.0.37_offline.exe +set XPU_BUNDLE_VERSION=0.5.3+31 +set XPU_PTI_VERSION=0.9.0+36 +set XPU_BUNDLE_PRODUCT_NAME=intel.oneapi.win.intel-for-pytorch-gpu-dev.product +set XPU_PTI_PRODUCT_NAME=intel.oneapi.win.intel-pti-dev.product +set XPU_BUNDLE_INSTALLED=0 +set XPU_PTI_INSTALLED=0 +set XPU_BUNDLE_UNINSTALL=0 +set XPU_PTI_UNINSTALL=0 + +:: Check if XPU bundle is target version or already installed +if exist "%XPU_BUNDLE_PARENT_DIR%\Installer\installer.exe" goto xpu_bundle_ver_check +goto xpu_bundle_install + +:xpu_bundle_ver_check + +"%XPU_BUNDLE_PARENT_DIR%\Installer\installer.exe" --list-products > xpu_bundle_installed_ver.log + +for /f "tokens=1,2" %%a in (xpu_bundle_installed_ver.log) do ( + if "%%a"=="%XPU_BUNDLE_PRODUCT_NAME%" ( + echo %%a Installed Version: %%b + set XPU_BUNDLE_INSTALLED=1 + if not "%XPU_BUNDLE_VERSION%"=="%%b" ( + start /wait "Installer Title" "%XPU_BUNDLE_PARENT_DIR%\Installer\installer.exe" --action=remove --eula=accept --silent --product-id %XPU_BUNDLE_PRODUCT_NAME% --product-ver %%b --log-dir uninstall_bundle + set XPU_BUNDLE_UNINSTALL=1 + ) + ) + if "%%a"=="%XPU_PTI_PRODUCT_NAME%" ( + echo %%a Installed Version: %%b + set XPU_PTI_INSTALLED=1 + if not "%XPU_PTI_VERSION%"=="%%b" ( + start /wait "Installer Title" "%XPU_BUNDLE_PARENT_DIR%\Installer\installer.exe" --action=remove --eula=accept --silent --product-id %XPU_PTI_PRODUCT_NAME% --product-ver %%b --log-dir uninstall_bundle + set XPU_PTI_UNINSTALL=1 + ) + ) +) +if errorlevel 1 exit /b 1 +if exist xpu_bundle_installed_ver.log del xpu_bundle_installed_ver.log +if "%XPU_BUNDLE_INSTALLED%"=="0" goto xpu_bundle_install +if "%XPU_BUNDLE_UNINSTALL%"=="1" goto xpu_bundle_install +if "%XPU_PTI_INSTALLED%"=="0" goto xpu_pti_install +if "%XPU_PTI_UNINSTALL%"=="1" goto xpu_pti_install +goto xpu_install_end + +:xpu_bundle_install + +curl -o xpu_bundle.exe --retry 3 --retry-all-errors -k %XPU_BUNDLE_URL% +echo "XPU Bundle installing..." +start /wait "Intel Pytorch Bundle Installer" "xpu_bundle.exe" --action=install --eula=accept --silent --log-dir install_bundle +if errorlevel 1 exit /b 1 +del xpu_bundle.exe + +:xpu_pti_install + +curl -o xpu_pti.exe --retry 3 --retry-all-errors -k %XPU_PTI_URL% +echo "XPU PTI installing..." +start /wait "Intel PTI Installer" "xpu_pti.exe" --action=install --eula=accept --silent --log-dir install_bundle +if errorlevel 1 exit /b 1 +del xpu_pti.exe + +:xpu_install_end diff --git a/.github/workflows/_win-build.yml b/.github/workflows/_win-build.yml index 6ee529f38b5c86..f0a1bb003de76d 100644 --- a/.github/workflows/_win-build.yml +++ b/.github/workflows/_win-build.yml @@ -11,6 +11,16 @@ on: required: true type: string description: What CUDA version to build with, "cpu" for none. + use-xpu: + required: false + type: boolean + default: false + description: If set, build with XPU support. + vc-year: + required: false + type: string + default: "2019" + description: The Visual Studio year to use for building. build-with-debug: required: false type: boolean @@ -141,7 +151,7 @@ jobs: SCCACHE_REGION: us-east-1 VC_PRODUCT: "BuildTools" VC_VERSION: "" - VC_YEAR: "2019" + VC_YEAR: "${{ inputs.vc-year }}" ALPINE_IMAGE: "308535385114.dkr.ecr.us-east-1.amazonaws.com/tool/alpine" AWS_DEFAULT_REGION: us-east-1 PR_NUMBER: ${{ github.event.pull_request.number }} @@ -149,6 +159,7 @@ jobs: DEBUG: ${{ inputs.build-with-debug && '1' || '0' }} TORCH_CUDA_ARCH_LIST: "8.6" USE_CUDA: ${{ inputs.cuda-version != 'cpu' && '1' || '0' }} + USE_XPU: ${{ inputs.use-xpu == true && '1' || '0' }} OUR_GITHUB_JOB_ID: ${{ steps.get-job-id.outputs.job-id }} run: | .ci/pytorch/win-build.sh diff --git a/.github/workflows/xpu.yml b/.github/workflows/xpu.yml index 81fe225c5fc48a..17fd3e4dfc6b71 100644 --- a/.github/workflows/xpu.yml +++ b/.github/workflows/xpu.yml @@ -49,3 +49,12 @@ jobs: build-environment: linux-jammy-xpu-py3.9 docker-image: ${{ needs.linux-jammy-xpu-py3_9-build.outputs.docker-image }} test-matrix: ${{ needs.linux-jammy-xpu-py3_9-build.outputs.test-matrix }} + + windows-xpu-build: + name: win-vs2022-xpu-py3 + uses: ./.github/workflows/_win-build.yml + with: + build-environment: win-vs2022-xpu-py3 + cuda-version: cpu + use-xpu: true + vc-year: '2022' From 034c0d8c13eb7caa11ada9c306688c7ec3eac189 Mon Sep 17 00:00:00 2001 From: Aidyn-A Date: Thu, 5 Sep 2024 00:07:31 +0000 Subject: [PATCH 0436/1018] [CUDA][AMP] Fix autocast_dtype (#133938) Fixes #132715 The failure in #132715 is due to `autocast_dtype` being a thread-local variable. It causes inconsistencies between `get_autocast_dtype()` among different threads. To be exact, what is happening in the following: The amp dtype is set to `bfloat16` on main thread. The `backward` call runs on a side thread, so `at::autocast::prioritize` fails because `lower_precision_fp` defaults to `float16`: https://github.com/pytorch/pytorch/blob/6f738d643462690670f9aedf7832cd4cbf480e80/aten/src/ATen/autocast_mode.h#L221-L225 This PR makes `autocast_dtype` thread-global so it consistent among all threads of forward and backward passes. Pull Request resolved: https://github.com/pytorch/pytorch/pull/133938 Approved by: https://github.com/soulitzer --- aten/src/ATen/ThreadLocalState.cpp | 16 ++++++++++++++-- aten/src/ATen/ThreadLocalState.h | 7 +++++++ test/test_autocast.py | 26 ++++++++++++++++++++++++++ 3 files changed, 47 insertions(+), 2 deletions(-) diff --git a/aten/src/ATen/ThreadLocalState.cpp b/aten/src/ATen/ThreadLocalState.cpp index c22f07866f7124..f1ec1a37bf82a6 100644 --- a/aten/src/ATen/ThreadLocalState.cpp +++ b/aten/src/ATen/ThreadLocalState.cpp @@ -1,6 +1,7 @@ #include -#if !defined(CAFFE2_IS_XPLAT_BUILD) && !defined(C10_MOBILE) +#if !defined(CAFFE2_IS_XPLAT_BUILD) && !defined(C10_MOBILE) && !defined(BUILD_LITE_INTERPRETER) +#include #include #endif @@ -18,7 +19,13 @@ ThreadLocalState::ThreadLocalState() torch_dispatch_mode_state_(c10::impl::TorchDispatchModeTLS::get_state()), python_dispatcher_state_(c10::impl::PythonDispatcherTLS::get_state()), python_torch_function_state_(at::impl::PythonTorchFunctionTLS::get_state()), saved_tensors_default_hooks_state_(at::SavedTensorDefaultHooks::get_tls_state()), functionalization_reapply_views_state_(at::functionalization::impl::getFunctionalizationReapplyViewsTLS()), - saved_objects_(at::impl::ThreadLocalPythonObjects::get_state()) {} + saved_objects_(at::impl::ThreadLocalPythonObjects::get_state()) { +#if !defined(CAFFE2_IS_XPLAT_BUILD) && !defined(C10_MOBILE) && !defined(BUILD_LITE_INTERPRETER) + for(uint8_t i=0; i(i)); + } +#endif +} void ThreadLocalState::set_grad_mode(bool enabled) { autograd_tls_.set_grad_mode(enabled); @@ -54,6 +61,11 @@ void ThreadLocalState::setThreadLocalState( at::functionalization::impl::setFunctionalizationReapplyViewsTLS(state.functionalization_reapply_views_state_); at::impl::ThreadLocalPythonObjects::set_state(state.saved_objects_); +#if !defined(CAFFE2_IS_XPLAT_BUILD) && !defined(C10_MOBILE) && !defined(BUILD_LITE_INTERPRETER) + for(uint8_t i=0; i(i), state.autocast_dtypes_[i]); + } +#endif } } // namespace at diff --git a/aten/src/ATen/ThreadLocalState.h b/aten/src/ATen/ThreadLocalState.h index 8419499c3a563c..721ea9957513bf 100644 --- a/aten/src/ATen/ThreadLocalState.h +++ b/aten/src/ATen/ThreadLocalState.h @@ -78,6 +78,13 @@ class TORCH_API ThreadLocalState { // TLS for arbitrary python objects that is registered via hooks at::impl::ThreadLocalPythonObjects saved_objects_; +#if !defined(CAFFE2_IS_XPLAT_BUILD) && !defined(C10_MOBILE) && \ + !defined(BUILD_LITE_INTERPRETER) + // TLS for autocast dtypes + std::array + autocast_dtypes_; +#endif + friend class ThreadLocalStateGuard; }; diff --git a/test/test_autocast.py b/test/test_autocast.py index 8e702f92296c13..0866c5c865bd91 100644 --- a/test/test_autocast.py +++ b/test/test_autocast.py @@ -273,6 +273,32 @@ def test_cache_disabled(self): finally: torch._C._set_cached_tensors_enabled(False) + # index_put under AMP follows a cast policy called "promote", + # https://github.com/pytorch/pytorch/blob/4fcd15a667df5b80e81db6563d8d3123a0cbd051/aten/src/ATen/autocast_mode.h#L205-L230 + # That means: + # (1) double precision is ignored, + # (2) if any argument is float, then all arguments are promoted to float, + # (3) if all arguments are of lower precision dtype, then all dtypes must be equal to the same amp autocast dtype. + # Since AMP autocast dtype is thread-local, it is not preserved across thread boundaries during autograd execution, + # and due to the multi-threaded nature of the autograd, the forward pass is being run in bfloat16, while the backward + # pass defaults to float16. The dtype mismatch leads to the error in the policy, as the criteria (3) is not satisfied. + # For more info see https://github.com/pytorch/pytorch/issues/132715. + def test_autocast_prioritize(self): + device = "cuda" + dtype = torch.bfloat16 + + with torch.autocast(device_type=device, enabled=True, dtype=dtype): + t = torch.randn([3, 4, 5], dtype=dtype, device=device, requires_grad=True) + index = torch.randint( + low=0, high=3, size=[3, 4, 5], dtype=torch.int64, device=device + ) + val = torch.randn(1, dtype=dtype, device=device) + + res = torch.index_put(t, [index], val) + + loss = res.mean() + loss.backward() + class TestTorchAutocast(TestCase): def test_autocast_fast_dtype(self): From dae51dd0304a29b1d95942394328540d8f8d36ae Mon Sep 17 00:00:00 2001 From: Laith Sakka Date: Wed, 4 Sep 2024 15:12:00 -0700 Subject: [PATCH 0437/1018] Fix: use clone_preserve_strides in auto_functionalized_v2 (#135142) Pull Request resolved: https://github.com/pytorch/pytorch/pull/135142 Approved by: https://github.com/zou3519 ghstack dependencies: #134409 --- test/inductor/test_auto_functionalize.py | 41 ++++++++++++++++--- torch/_higher_order_ops/auto_functionalize.py | 2 +- 2 files changed, 37 insertions(+), 6 deletions(-) diff --git a/test/inductor/test_auto_functionalize.py b/test/inductor/test_auto_functionalize.py index f705a4fe8a8eb2..b1e3ab1f8ad777 100644 --- a/test/inductor/test_auto_functionalize.py +++ b/test/inductor/test_auto_functionalize.py @@ -976,11 +976,15 @@ def forward(self, arg0_1: "f32[2][1]cpu"): def forward(self, arg0_1: "f32[2][1]cpu"): select: "f32[][]cpu" = torch.ops.aten.select.int(arg0_1, 0, 0) select_1: "f32[][]cpu" = torch.ops.aten.select.int(arg0_1, 0, 1) - clone_default: "f32[][]cpu" = torch.ops.aten.clone.default(select); select = None - clone_default_1: "f32[][]cpu" = torch.ops.aten.clone.default(select_1); select_1 = None - foo_default = torch.ops.mylib.foo.default(clone_default, clone_default_1); foo_default = None - select_scatter_default: "f32[2][1]cpu" = torch.ops.aten.select_scatter.default(arg0_1, clone_default, 0, 0); clone_default = None - select_scatter_default_1: "f32[2][1]cpu" = torch.ops.aten.select_scatter.default(select_scatter_default, clone_default_1, 0, 1); select_scatter_default = clone_default_1 = None + as_strided_default: "f32[1][1]cpu" = torch.ops.aten.as_strided.default(select, [1], [1], 0); select = None + clone_default: "f32[1][1]cpu" = torch.ops.aten.clone.default(as_strided_default); as_strided_default = None + as_strided_default_1: "f32[][]cpu" = torch.ops.aten.as_strided.default(clone_default, [], [], 0); clone_default = None + as_strided_default_2: "f32[2][1]cpu" = torch.ops.aten.as_strided.default(select_1, [2], [1], 0); select_1 = None + clone_default_1: "f32[2][1]cpu" = torch.ops.aten.clone.default(as_strided_default_2); as_strided_default_2 = None + as_strided_default_3: "f32[][]cpu" = torch.ops.aten.as_strided.default(clone_default_1, [], [], 1); clone_default_1 = None + foo_default = torch.ops.mylib.foo.default(as_strided_default_1, as_strided_default_3); foo_default = None + select_scatter_default: "f32[2][1]cpu" = torch.ops.aten.select_scatter.default(arg0_1, as_strided_default_1, 0, 0); as_strided_default_1 = None + select_scatter_default_1: "f32[2][1]cpu" = torch.ops.aten.select_scatter.default(select_scatter_default, as_strided_default_3, 0, 1); select_scatter_default = as_strided_default_3 = None copy_: "f32[2][1]cpu" = torch.ops.aten.copy_.default(arg0_1, select_scatter_default_1); arg0_1 = select_scatter_default_1 = copy_ = None return ()""", # noqa: B950 ignore_comments=True, @@ -999,6 +1003,33 @@ def test_dynamic2_v2(self): def test_dynamic3_v2(self): self.test_auto_functionalize_extra2(_dynamic=True) + # foo takes two views on the same input, function does not have return. + @torch._inductor.config.patch(enable_auto_functionalized_v2=True) + def test_graph_input_is_view(self): + with torch.library._scoped_library("mylib", "FRAGMENT") as lib: + torch.library.define( + "mylib::foo", + "(Tensor(a!) x) -> ()", + tags=torch.Tag.pt2_compliant_tag, + lib=lib, + ) + + @torch.library.impl("mylib::foo", "cpu", lib=lib) + @torch._dynamo.disable + def foo_impl(x): + pass + + @torch.compile(fullgraph=True, dynamic=False, backend="aot_eager") + def f(x): + a = x[0] + torch.ops.mylib.foo(a) + return + + x = torch.tensor([[1, 2], [3, 4]]) + # This would fail if auto_functionalized_v2 uses clone and not clone_preserve_strides + # to clone not-inplaced args. + f(x[1]) + if __name__ == "__main__": from torch._inductor.test_case import run_tests diff --git a/torch/_higher_order_ops/auto_functionalize.py b/torch/_higher_order_ops/auto_functionalize.py index d0a25564af6fd6..8466d42ff6f435 100644 --- a/torch/_higher_order_ops/auto_functionalize.py +++ b/torch/_higher_order_ops/auto_functionalize.py @@ -631,7 +631,7 @@ def maybe_copy(i, t): if t is None: return None if i in _only_clone_these_bases: - return t.clone() + return clone_preserve_strides(t) else: return t From d087e295d75d78deac8290246e614f539c9d09f5 Mon Sep 17 00:00:00 2001 From: Nikita Shulga Date: Thu, 5 Sep 2024 00:43:31 +0000 Subject: [PATCH 0438/1018] Fix binary builds artifact download (#135139) By upgrading upload-artifacts action to v4.4.0 As artifact store layout is different between v3 and v4 actions and artifacts uploaded by v3 can not be downloaded by v4 Should fix`Unable to download artifact(s): Artifact not found for name: libtorch-cpu-shared-with-deps-release`, which could be seen for example [here](https://github.com/pytorch/pytorch/actions/runs/10707740040/job/29690137218#step:7:29) I.e. fix regression introduced by https://github.com/pytorch/pytorch/pull/135068 Pull Request resolved: https://github.com/pytorch/pytorch/pull/135139 Approved by: https://github.com/atalman, https://github.com/huydhn --- .github/templates/common.yml.j2 | 2 +- .../macos_binary_build_workflow.yml.j2 | 2 +- .github/workflows/_binary-build-linux.yml | 2 +- ...rated-macos-arm64-binary-conda-nightly.yml | 8 ++-- ...rm64-binary-libtorch-cxx11-abi-nightly.yml | 2 +- ...rated-macos-arm64-binary-wheel-nightly.yml | 8 ++-- ...generated-windows-binary-conda-nightly.yml | 32 +++++++-------- ...ted-windows-binary-libtorch-debug-main.yml | 2 +- ...-windows-binary-libtorch-debug-nightly.yml | 8 ++-- ...d-windows-binary-libtorch-release-main.yml | 2 +- ...indows-binary-libtorch-release-nightly.yml | 8 ++-- ...generated-windows-binary-wheel-nightly.yml | 40 +++++++++---------- 12 files changed, 58 insertions(+), 58 deletions(-) diff --git a/.github/templates/common.yml.j2 b/.github/templates/common.yml.j2 index 1bbb339129212f..8db7da9456a6b9 100644 --- a/.github/templates/common.yml.j2 +++ b/.github/templates/common.yml.j2 @@ -1,6 +1,6 @@ {%- set upload_artifact_s3_action = "seemethere/upload-artifact-s3@v5" -%} {%- set download_artifact_s3_action = "seemethere/download-artifact-s3@v4" -%} -{%- set upload_artifact_action = "actions/upload-artifact@v3" -%} +{%- set upload_artifact_action = "actions/upload-artifact@v4.4.0" -%} {%- set download_artifact_action = "actions/download-artifact@v4.1.7" -%} {%- set timeout_minutes = 240 -%} diff --git a/.github/templates/macos_binary_build_workflow.yml.j2 b/.github/templates/macos_binary_build_workflow.yml.j2 index af417054ab190c..272073dd902e6d 100644 --- a/.github/templates/macos_binary_build_workflow.yml.j2 +++ b/.github/templates/macos_binary_build_workflow.yml.j2 @@ -101,7 +101,7 @@ jobs: # shellcheck disable=SC1091 source "${RUNNER_TEMP}/anaconda/bin/activate" "${PYTORCH_ROOT}/.circleci/scripts/binary_macos_build.sh" - - uses: actions/upload-artifact@v3 + - uses: actions/upload-artifact@v4.4.0 if: always() with: name: !{{ config["build_name"] }} diff --git a/.github/workflows/_binary-build-linux.yml b/.github/workflows/_binary-build-linux.yml index deec5e35dda219..509312c30bdfea 100644 --- a/.github/workflows/_binary-build-linux.yml +++ b/.github/workflows/_binary-build-linux.yml @@ -283,7 +283,7 @@ jobs: # Ensure the working directory gets chowned back to the current user docker run --rm -v "${RUNNER_TEMP}/artifacts:/v" -w /v "${ALPINE_IMAGE}" chown -R "$(id -u):$(id -g)" . - - uses: actions/upload-artifact@v3 + - uses: actions/upload-artifact@v4.4.0 if: ${{ steps.filter.outputs.is-test-matrix-empty == 'False' }} with: name: ${{ inputs.build_name }} diff --git a/.github/workflows/generated-macos-arm64-binary-conda-nightly.yml b/.github/workflows/generated-macos-arm64-binary-conda-nightly.yml index 0eabbcde433606..39579b08403dae 100644 --- a/.github/workflows/generated-macos-arm64-binary-conda-nightly.yml +++ b/.github/workflows/generated-macos-arm64-binary-conda-nightly.yml @@ -117,7 +117,7 @@ jobs: # shellcheck disable=SC1091 source "${RUNNER_TEMP}/anaconda/bin/activate" "${PYTORCH_ROOT}/.circleci/scripts/binary_macos_build.sh" - - uses: actions/upload-artifact@v3 + - uses: actions/upload-artifact@v4.4.0 if: always() with: name: conda-py3_9-cpu @@ -232,7 +232,7 @@ jobs: # shellcheck disable=SC1091 source "${RUNNER_TEMP}/anaconda/bin/activate" "${PYTORCH_ROOT}/.circleci/scripts/binary_macos_build.sh" - - uses: actions/upload-artifact@v3 + - uses: actions/upload-artifact@v4.4.0 if: always() with: name: conda-py3_10-cpu @@ -347,7 +347,7 @@ jobs: # shellcheck disable=SC1091 source "${RUNNER_TEMP}/anaconda/bin/activate" "${PYTORCH_ROOT}/.circleci/scripts/binary_macos_build.sh" - - uses: actions/upload-artifact@v3 + - uses: actions/upload-artifact@v4.4.0 if: always() with: name: conda-py3_11-cpu @@ -462,7 +462,7 @@ jobs: # shellcheck disable=SC1091 source "${RUNNER_TEMP}/anaconda/bin/activate" "${PYTORCH_ROOT}/.circleci/scripts/binary_macos_build.sh" - - uses: actions/upload-artifact@v3 + - uses: actions/upload-artifact@v4.4.0 if: always() with: name: conda-py3_12-cpu diff --git a/.github/workflows/generated-macos-arm64-binary-libtorch-cxx11-abi-nightly.yml b/.github/workflows/generated-macos-arm64-binary-libtorch-cxx11-abi-nightly.yml index a428191d4a85e9..0d0043cbe99352 100644 --- a/.github/workflows/generated-macos-arm64-binary-libtorch-cxx11-abi-nightly.yml +++ b/.github/workflows/generated-macos-arm64-binary-libtorch-cxx11-abi-nightly.yml @@ -121,7 +121,7 @@ jobs: # shellcheck disable=SC1091 source "${RUNNER_TEMP}/anaconda/bin/activate" "${PYTORCH_ROOT}/.circleci/scripts/binary_macos_build.sh" - - uses: actions/upload-artifact@v3 + - uses: actions/upload-artifact@v4.4.0 if: always() with: name: libtorch-cpu-shared-with-deps-cxx11-abi diff --git a/.github/workflows/generated-macos-arm64-binary-wheel-nightly.yml b/.github/workflows/generated-macos-arm64-binary-wheel-nightly.yml index 2900a33d5c6f0a..0a3716c7019b27 100644 --- a/.github/workflows/generated-macos-arm64-binary-wheel-nightly.yml +++ b/.github/workflows/generated-macos-arm64-binary-wheel-nightly.yml @@ -118,7 +118,7 @@ jobs: # shellcheck disable=SC1091 source "${RUNNER_TEMP}/anaconda/bin/activate" "${PYTORCH_ROOT}/.circleci/scripts/binary_macos_build.sh" - - uses: actions/upload-artifact@v3 + - uses: actions/upload-artifact@v4.4.0 if: always() with: name: wheel-py3_9-cpu @@ -234,7 +234,7 @@ jobs: # shellcheck disable=SC1091 source "${RUNNER_TEMP}/anaconda/bin/activate" "${PYTORCH_ROOT}/.circleci/scripts/binary_macos_build.sh" - - uses: actions/upload-artifact@v3 + - uses: actions/upload-artifact@v4.4.0 if: always() with: name: wheel-py3_10-cpu @@ -350,7 +350,7 @@ jobs: # shellcheck disable=SC1091 source "${RUNNER_TEMP}/anaconda/bin/activate" "${PYTORCH_ROOT}/.circleci/scripts/binary_macos_build.sh" - - uses: actions/upload-artifact@v3 + - uses: actions/upload-artifact@v4.4.0 if: always() with: name: wheel-py3_11-cpu @@ -466,7 +466,7 @@ jobs: # shellcheck disable=SC1091 source "${RUNNER_TEMP}/anaconda/bin/activate" "${PYTORCH_ROOT}/.circleci/scripts/binary_macos_build.sh" - - uses: actions/upload-artifact@v3 + - uses: actions/upload-artifact@v4.4.0 if: always() with: name: wheel-py3_12-cpu diff --git a/.github/workflows/generated-windows-binary-conda-nightly.yml b/.github/workflows/generated-windows-binary-conda-nightly.yml index 5d3925fb98eaf9..bcadb5d0fc4507 100644 --- a/.github/workflows/generated-windows-binary-conda-nightly.yml +++ b/.github/workflows/generated-windows-binary-conda-nightly.yml @@ -132,7 +132,7 @@ jobs: shell: bash run: | "${PYTORCH_ROOT}/.circleci/scripts/binary_windows_build.sh" - - uses: actions/upload-artifact@v3 + - uses: actions/upload-artifact@v4.4.0 if: always() with: name: conda-py3_9-cpu @@ -378,7 +378,7 @@ jobs: shell: bash run: | "${PYTORCH_ROOT}/.circleci/scripts/binary_windows_build.sh" - - uses: actions/upload-artifact@v3 + - uses: actions/upload-artifact@v4.4.0 if: always() with: name: conda-py3_9-cuda11_8 @@ -626,7 +626,7 @@ jobs: shell: bash run: | "${PYTORCH_ROOT}/.circleci/scripts/binary_windows_build.sh" - - uses: actions/upload-artifact@v3 + - uses: actions/upload-artifact@v4.4.0 if: always() with: name: conda-py3_9-cuda12_1 @@ -874,7 +874,7 @@ jobs: shell: bash run: | "${PYTORCH_ROOT}/.circleci/scripts/binary_windows_build.sh" - - uses: actions/upload-artifact@v3 + - uses: actions/upload-artifact@v4.4.0 if: always() with: name: conda-py3_9-cuda12_4 @@ -1121,7 +1121,7 @@ jobs: shell: bash run: | "${PYTORCH_ROOT}/.circleci/scripts/binary_windows_build.sh" - - uses: actions/upload-artifact@v3 + - uses: actions/upload-artifact@v4.4.0 if: always() with: name: conda-py3_10-cpu @@ -1367,7 +1367,7 @@ jobs: shell: bash run: | "${PYTORCH_ROOT}/.circleci/scripts/binary_windows_build.sh" - - uses: actions/upload-artifact@v3 + - uses: actions/upload-artifact@v4.4.0 if: always() with: name: conda-py3_10-cuda11_8 @@ -1615,7 +1615,7 @@ jobs: shell: bash run: | "${PYTORCH_ROOT}/.circleci/scripts/binary_windows_build.sh" - - uses: actions/upload-artifact@v3 + - uses: actions/upload-artifact@v4.4.0 if: always() with: name: conda-py3_10-cuda12_1 @@ -1863,7 +1863,7 @@ jobs: shell: bash run: | "${PYTORCH_ROOT}/.circleci/scripts/binary_windows_build.sh" - - uses: actions/upload-artifact@v3 + - uses: actions/upload-artifact@v4.4.0 if: always() with: name: conda-py3_10-cuda12_4 @@ -2110,7 +2110,7 @@ jobs: shell: bash run: | "${PYTORCH_ROOT}/.circleci/scripts/binary_windows_build.sh" - - uses: actions/upload-artifact@v3 + - uses: actions/upload-artifact@v4.4.0 if: always() with: name: conda-py3_11-cpu @@ -2356,7 +2356,7 @@ jobs: shell: bash run: | "${PYTORCH_ROOT}/.circleci/scripts/binary_windows_build.sh" - - uses: actions/upload-artifact@v3 + - uses: actions/upload-artifact@v4.4.0 if: always() with: name: conda-py3_11-cuda11_8 @@ -2604,7 +2604,7 @@ jobs: shell: bash run: | "${PYTORCH_ROOT}/.circleci/scripts/binary_windows_build.sh" - - uses: actions/upload-artifact@v3 + - uses: actions/upload-artifact@v4.4.0 if: always() with: name: conda-py3_11-cuda12_1 @@ -2852,7 +2852,7 @@ jobs: shell: bash run: | "${PYTORCH_ROOT}/.circleci/scripts/binary_windows_build.sh" - - uses: actions/upload-artifact@v3 + - uses: actions/upload-artifact@v4.4.0 if: always() with: name: conda-py3_11-cuda12_4 @@ -3099,7 +3099,7 @@ jobs: shell: bash run: | "${PYTORCH_ROOT}/.circleci/scripts/binary_windows_build.sh" - - uses: actions/upload-artifact@v3 + - uses: actions/upload-artifact@v4.4.0 if: always() with: name: conda-py3_12-cpu @@ -3345,7 +3345,7 @@ jobs: shell: bash run: | "${PYTORCH_ROOT}/.circleci/scripts/binary_windows_build.sh" - - uses: actions/upload-artifact@v3 + - uses: actions/upload-artifact@v4.4.0 if: always() with: name: conda-py3_12-cuda11_8 @@ -3593,7 +3593,7 @@ jobs: shell: bash run: | "${PYTORCH_ROOT}/.circleci/scripts/binary_windows_build.sh" - - uses: actions/upload-artifact@v3 + - uses: actions/upload-artifact@v4.4.0 if: always() with: name: conda-py3_12-cuda12_1 @@ -3841,7 +3841,7 @@ jobs: shell: bash run: | "${PYTORCH_ROOT}/.circleci/scripts/binary_windows_build.sh" - - uses: actions/upload-artifact@v3 + - uses: actions/upload-artifact@v4.4.0 if: always() with: name: conda-py3_12-cuda12_4 diff --git a/.github/workflows/generated-windows-binary-libtorch-debug-main.yml b/.github/workflows/generated-windows-binary-libtorch-debug-main.yml index 58a10b5c0ad6d6..e11b4b9bb6961c 100644 --- a/.github/workflows/generated-windows-binary-libtorch-debug-main.yml +++ b/.github/workflows/generated-windows-binary-libtorch-debug-main.yml @@ -129,7 +129,7 @@ jobs: shell: bash run: | "${PYTORCH_ROOT}/.circleci/scripts/binary_windows_build.sh" - - uses: actions/upload-artifact@v3 + - uses: actions/upload-artifact@v4.4.0 if: always() with: name: libtorch-cpu-shared-with-deps-debug diff --git a/.github/workflows/generated-windows-binary-libtorch-debug-nightly.yml b/.github/workflows/generated-windows-binary-libtorch-debug-nightly.yml index b9f75fc865ad17..44178c422a1b80 100644 --- a/.github/workflows/generated-windows-binary-libtorch-debug-nightly.yml +++ b/.github/workflows/generated-windows-binary-libtorch-debug-nightly.yml @@ -136,7 +136,7 @@ jobs: shell: bash run: | "${PYTORCH_ROOT}/.circleci/scripts/binary_windows_build.sh" - - uses: actions/upload-artifact@v3 + - uses: actions/upload-artifact@v4.4.0 if: always() with: name: libtorch-cpu-shared-with-deps-debug @@ -394,7 +394,7 @@ jobs: shell: bash run: | "${PYTORCH_ROOT}/.circleci/scripts/binary_windows_build.sh" - - uses: actions/upload-artifact@v3 + - uses: actions/upload-artifact@v4.4.0 if: always() with: name: libtorch-cuda11_8-shared-with-deps-debug @@ -654,7 +654,7 @@ jobs: shell: bash run: | "${PYTORCH_ROOT}/.circleci/scripts/binary_windows_build.sh" - - uses: actions/upload-artifact@v3 + - uses: actions/upload-artifact@v4.4.0 if: always() with: name: libtorch-cuda12_1-shared-with-deps-debug @@ -914,7 +914,7 @@ jobs: shell: bash run: | "${PYTORCH_ROOT}/.circleci/scripts/binary_windows_build.sh" - - uses: actions/upload-artifact@v3 + - uses: actions/upload-artifact@v4.4.0 if: always() with: name: libtorch-cuda12_4-shared-with-deps-debug diff --git a/.github/workflows/generated-windows-binary-libtorch-release-main.yml b/.github/workflows/generated-windows-binary-libtorch-release-main.yml index 2897998822d65d..26c35089427569 100644 --- a/.github/workflows/generated-windows-binary-libtorch-release-main.yml +++ b/.github/workflows/generated-windows-binary-libtorch-release-main.yml @@ -129,7 +129,7 @@ jobs: shell: bash run: | "${PYTORCH_ROOT}/.circleci/scripts/binary_windows_build.sh" - - uses: actions/upload-artifact@v3 + - uses: actions/upload-artifact@v4.4.0 if: always() with: name: libtorch-cpu-shared-with-deps-release diff --git a/.github/workflows/generated-windows-binary-libtorch-release-nightly.yml b/.github/workflows/generated-windows-binary-libtorch-release-nightly.yml index 885992be575a25..388c8cb5ed9379 100644 --- a/.github/workflows/generated-windows-binary-libtorch-release-nightly.yml +++ b/.github/workflows/generated-windows-binary-libtorch-release-nightly.yml @@ -136,7 +136,7 @@ jobs: shell: bash run: | "${PYTORCH_ROOT}/.circleci/scripts/binary_windows_build.sh" - - uses: actions/upload-artifact@v3 + - uses: actions/upload-artifact@v4.4.0 if: always() with: name: libtorch-cpu-shared-with-deps-release @@ -394,7 +394,7 @@ jobs: shell: bash run: | "${PYTORCH_ROOT}/.circleci/scripts/binary_windows_build.sh" - - uses: actions/upload-artifact@v3 + - uses: actions/upload-artifact@v4.4.0 if: always() with: name: libtorch-cuda11_8-shared-with-deps-release @@ -654,7 +654,7 @@ jobs: shell: bash run: | "${PYTORCH_ROOT}/.circleci/scripts/binary_windows_build.sh" - - uses: actions/upload-artifact@v3 + - uses: actions/upload-artifact@v4.4.0 if: always() with: name: libtorch-cuda12_1-shared-with-deps-release @@ -914,7 +914,7 @@ jobs: shell: bash run: | "${PYTORCH_ROOT}/.circleci/scripts/binary_windows_build.sh" - - uses: actions/upload-artifact@v3 + - uses: actions/upload-artifact@v4.4.0 if: always() with: name: libtorch-cuda12_4-shared-with-deps-release diff --git a/.github/workflows/generated-windows-binary-wheel-nightly.yml b/.github/workflows/generated-windows-binary-wheel-nightly.yml index 37e62a6cdba539..316329b46870fa 100644 --- a/.github/workflows/generated-windows-binary-wheel-nightly.yml +++ b/.github/workflows/generated-windows-binary-wheel-nightly.yml @@ -133,7 +133,7 @@ jobs: shell: bash run: | "${PYTORCH_ROOT}/.circleci/scripts/binary_windows_build.sh" - - uses: actions/upload-artifact@v3 + - uses: actions/upload-artifact@v4.4.0 if: always() with: name: wheel-py3_9-cpu @@ -380,7 +380,7 @@ jobs: shell: bash run: | "${PYTORCH_ROOT}/.circleci/scripts/binary_windows_build.sh" - - uses: actions/upload-artifact@v3 + - uses: actions/upload-artifact@v4.4.0 if: always() with: name: wheel-py3_9-cuda11_8 @@ -629,7 +629,7 @@ jobs: shell: bash run: | "${PYTORCH_ROOT}/.circleci/scripts/binary_windows_build.sh" - - uses: actions/upload-artifact@v3 + - uses: actions/upload-artifact@v4.4.0 if: always() with: name: wheel-py3_9-cuda12_1 @@ -878,7 +878,7 @@ jobs: shell: bash run: | "${PYTORCH_ROOT}/.circleci/scripts/binary_windows_build.sh" - - uses: actions/upload-artifact@v3 + - uses: actions/upload-artifact@v4.4.0 if: always() with: name: wheel-py3_9-cuda12_4 @@ -1125,7 +1125,7 @@ jobs: shell: bash run: | "${PYTORCH_ROOT}/.circleci/scripts/binary_windows_build.sh" - - uses: actions/upload-artifact@v3 + - uses: actions/upload-artifact@v4.4.0 if: always() with: name: wheel-py3_9-xpu @@ -1371,7 +1371,7 @@ jobs: shell: bash run: | "${PYTORCH_ROOT}/.circleci/scripts/binary_windows_build.sh" - - uses: actions/upload-artifact@v3 + - uses: actions/upload-artifact@v4.4.0 if: always() with: name: wheel-py3_10-cpu @@ -1618,7 +1618,7 @@ jobs: shell: bash run: | "${PYTORCH_ROOT}/.circleci/scripts/binary_windows_build.sh" - - uses: actions/upload-artifact@v3 + - uses: actions/upload-artifact@v4.4.0 if: always() with: name: wheel-py3_10-cuda11_8 @@ -1867,7 +1867,7 @@ jobs: shell: bash run: | "${PYTORCH_ROOT}/.circleci/scripts/binary_windows_build.sh" - - uses: actions/upload-artifact@v3 + - uses: actions/upload-artifact@v4.4.0 if: always() with: name: wheel-py3_10-cuda12_1 @@ -2116,7 +2116,7 @@ jobs: shell: bash run: | "${PYTORCH_ROOT}/.circleci/scripts/binary_windows_build.sh" - - uses: actions/upload-artifact@v3 + - uses: actions/upload-artifact@v4.4.0 if: always() with: name: wheel-py3_10-cuda12_4 @@ -2363,7 +2363,7 @@ jobs: shell: bash run: | "${PYTORCH_ROOT}/.circleci/scripts/binary_windows_build.sh" - - uses: actions/upload-artifact@v3 + - uses: actions/upload-artifact@v4.4.0 if: always() with: name: wheel-py3_10-xpu @@ -2609,7 +2609,7 @@ jobs: shell: bash run: | "${PYTORCH_ROOT}/.circleci/scripts/binary_windows_build.sh" - - uses: actions/upload-artifact@v3 + - uses: actions/upload-artifact@v4.4.0 if: always() with: name: wheel-py3_11-cpu @@ -2856,7 +2856,7 @@ jobs: shell: bash run: | "${PYTORCH_ROOT}/.circleci/scripts/binary_windows_build.sh" - - uses: actions/upload-artifact@v3 + - uses: actions/upload-artifact@v4.4.0 if: always() with: name: wheel-py3_11-cuda11_8 @@ -3105,7 +3105,7 @@ jobs: shell: bash run: | "${PYTORCH_ROOT}/.circleci/scripts/binary_windows_build.sh" - - uses: actions/upload-artifact@v3 + - uses: actions/upload-artifact@v4.4.0 if: always() with: name: wheel-py3_11-cuda12_1 @@ -3354,7 +3354,7 @@ jobs: shell: bash run: | "${PYTORCH_ROOT}/.circleci/scripts/binary_windows_build.sh" - - uses: actions/upload-artifact@v3 + - uses: actions/upload-artifact@v4.4.0 if: always() with: name: wheel-py3_11-cuda12_4 @@ -3601,7 +3601,7 @@ jobs: shell: bash run: | "${PYTORCH_ROOT}/.circleci/scripts/binary_windows_build.sh" - - uses: actions/upload-artifact@v3 + - uses: actions/upload-artifact@v4.4.0 if: always() with: name: wheel-py3_11-xpu @@ -3847,7 +3847,7 @@ jobs: shell: bash run: | "${PYTORCH_ROOT}/.circleci/scripts/binary_windows_build.sh" - - uses: actions/upload-artifact@v3 + - uses: actions/upload-artifact@v4.4.0 if: always() with: name: wheel-py3_12-cpu @@ -4094,7 +4094,7 @@ jobs: shell: bash run: | "${PYTORCH_ROOT}/.circleci/scripts/binary_windows_build.sh" - - uses: actions/upload-artifact@v3 + - uses: actions/upload-artifact@v4.4.0 if: always() with: name: wheel-py3_12-cuda11_8 @@ -4343,7 +4343,7 @@ jobs: shell: bash run: | "${PYTORCH_ROOT}/.circleci/scripts/binary_windows_build.sh" - - uses: actions/upload-artifact@v3 + - uses: actions/upload-artifact@v4.4.0 if: always() with: name: wheel-py3_12-cuda12_1 @@ -4592,7 +4592,7 @@ jobs: shell: bash run: | "${PYTORCH_ROOT}/.circleci/scripts/binary_windows_build.sh" - - uses: actions/upload-artifact@v3 + - uses: actions/upload-artifact@v4.4.0 if: always() with: name: wheel-py3_12-cuda12_4 @@ -4839,7 +4839,7 @@ jobs: shell: bash run: | "${PYTORCH_ROOT}/.circleci/scripts/binary_windows_build.sh" - - uses: actions/upload-artifact@v3 + - uses: actions/upload-artifact@v4.4.0 if: always() with: name: wheel-py3_12-xpu From 105d8bb71257cffd424dd7e3abe1d23d406fc89d Mon Sep 17 00:00:00 2001 From: fduwjj Date: Wed, 4 Sep 2024 15:20:04 -0700 Subject: [PATCH 0439/1018] [FR] Add version based logic to FR script and make traces print can be filtered (#135154) This PR makes version passing around the version, so that we can have different behaviors for different versions of FR dump. This PR also adds the logic of filtering to certain PG(desc) and ranks to show their traces. Some minor refactors to make the name more accurate and util function working. image Pull Request resolved: https://github.com/pytorch/pytorch/pull/135154 Approved by: https://github.com/wconstab --- tools/flight_recorder/components/builder.py | 28 ++++++++----- .../components/config_manager.py | 27 +++++++++++- tools/flight_recorder/components/loader.py | 9 ++-- tools/flight_recorder/components/utils.py | 42 ++++++++++++++----- tools/flight_recorder/fr_trace.py | 4 +- 5 files changed, 82 insertions(+), 28 deletions(-) diff --git a/tools/flight_recorder/components/builder.py b/tools/flight_recorder/components/builder.py index b797d405afe754..d9eaa4c27368ab 100644 --- a/tools/flight_recorder/components/builder.py +++ b/tools/flight_recorder/components/builder.py @@ -20,15 +20,16 @@ Traceback, ) from tools.flight_recorder.components.utils import ( + align_trace_from_beginning, check_no_missing_dump_files, check_size_alltoall, check_version, find_coalesced_group, format_frames, + get_version_detail, just_print_entries, match_coalesced_groups, match_one_event, - sort_trace_from_beginning, ) @@ -140,6 +141,7 @@ def build_collectives( _groups: Dict[str, Group], _memberships: Dict[str, Set[Any]], _pg_guids: Dict[Tuple[str, int], str], + version: str, ) -> Tuple[List[Traceback], List[Collective], List[NCCLCall]]: """ groups, memberships are the non-flat dicts that are indexable @@ -169,6 +171,7 @@ def build_collectives( ] } """ + major_v, minor_v = get_version_detail(version) tracebacks: List[Traceback] = [] collectives: List[Collective] = [] @@ -325,8 +328,10 @@ def build_collectives( fail_check, input_numel, output_numel = check_size_alltoall( alltoall_cases ) - # We don't log the input/output sizes for alltoall so we don't consider the size mismatch as an error for now. - fail_check = False + if major_v <= 2 and minor_v <= 3: + # We don't log the input/output sizes for alltoall before v2.4, + # so we don't consider the size mismatch as an error for now. + fail_check = False if fail_check: # When we see errors in all_to_all, it's hard to tell which rank is the source of the error. mismatch[pg_name] += 1 @@ -413,19 +418,22 @@ def build_collectives( return tracebacks, collectives, nccl_calls -def build_db(details: Dict[str, Dict[str, Any]], args: argparse.Namespace) -> Database: +def build_db( + details: Dict[str, Dict[str, Any]], args: argparse.Namespace, version: str +) -> Database: # temporary state used for building database entries = {} pg_config = {} - version = {} + version_by_ranks = {} for dump in details.values(): rank = dump["rank"] entries[rank] = dump["entries"] - version[rank] = dump["version"] + version_by_ranks[rank] = dump["version"] pg_config[rank] = dump["pg_config"] - check_version(version) - entries = sort_trace_from_beginning(entries) + # Ensure version is consistent across all ranks. + check_version(version_by_ranks, version) + entries = align_trace_from_beginning(entries) # flattened database groups, _groups, memberships, _memberships, _pg_guids = build_groups_memberships( @@ -436,11 +444,11 @@ def build_db(details: Dict[str, Dict[str, Any]], args: argparse.Namespace) -> Da check_no_missing_dump_files(entries, memberships) if args.just_print_entries: - just_print_entries(entries, _groups, _memberships, _pg_guids) + just_print_entries(entries, _groups, _memberships, _pg_guids, args) sys.exit(0) tracebacks, collectives, nccl_calls = build_collectives( - entries, _groups, _memberships, _pg_guids + entries, _groups, _memberships, _pg_guids, version ) print("built collectives, nccl_calls") if args.verbose: diff --git a/tools/flight_recorder/components/config_manager.py b/tools/flight_recorder/components/config_manager.py index 41155fe7c777d0..1f4511d2979192 100644 --- a/tools/flight_recorder/components/config_manager.py +++ b/tools/flight_recorder/components/config_manager.py @@ -19,7 +19,21 @@ def __init__(self: "JobConfig"): ) self.parser.add_argument( - "-d", "--dir", help="Directory with flight recorder dumps" + "-d", "--dir", required=True, help="Directory with flight recorder dumps" + ) + self.parser.add_argument( + "--selected-ranks", + default=None, + nargs="+", + type=int, + help="List of ranks we want to show traces for.", + ) + self.parser.add_argument( + "--pg-filters", + default=None, + nargs="+", + type=str, + help="List of filter strings", ) self.parser.add_argument("-o", "--output", default=None) self.parser.add_argument( @@ -34,4 +48,13 @@ def __init__(self: "JobConfig"): def parse_args( self: "JobConfig", args: Optional[Sequence[str]] ) -> argparse.Namespace: - return self.parser.parse_args(args) + args = self.parser.parse_args(args) + if args.selected_ranks is not None: + assert ( + args.just_print_entries + ), "Not support selecting ranks without printing entries" + if args.pg_filters is not None: + assert ( + args.just_print_entries + ), "Not support selecting pg filters without printing entries" + return args diff --git a/tools/flight_recorder/components/loader.py b/tools/flight_recorder/components/loader.py index 854c6034906e20..19987ddc74d004 100644 --- a/tools/flight_recorder/components/loader.py +++ b/tools/flight_recorder/components/loader.py @@ -8,7 +8,7 @@ import os import pickle import time -from typing import Any, Dict, List, Union +from typing import Any, Dict, List, Tuple, Union def read_dump(prefix: str, filename: str) -> Dict[str, Union[str, int, List[Any]]]: @@ -35,13 +35,16 @@ def read_dump(prefix: str, filename: str) -> Dict[str, Union[str, int, List[Any] } -def read_dir(prefix: str, folder: str) -> Dict[str, Dict[str, Any]]: +def read_dir(prefix: str, folder: str) -> Tuple[Dict[str, Dict[str, Any]], str]: gc.disable() details = {} t0 = time.time() + version = "" for root, _, files in os.walk(folder): for f in files: details[f] = read_dump(prefix, os.path.join(root, f)) + if not version: + version = str(details[f]["version"]) tb = time.time() print(f"loaded {len(files)} files in {tb - t0}s") - return details + return details, version diff --git a/tools/flight_recorder/components/utils.py b/tools/flight_recorder/components/utils.py index 5725359291e69e..87e3fc6a1c9664 100644 --- a/tools/flight_recorder/components/utils.py +++ b/tools/flight_recorder/components/utils.py @@ -4,6 +4,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +import argparse import math from typing import Any, Dict, List, Set, Tuple @@ -214,21 +215,34 @@ def just_print_entries( _groups: Dict[str, Group], _memberships: Dict[str, Set[Any]], _pg_guids: Dict[Tuple[str, int], str], + args: argparse.Namespace, ) -> None: rows = [] ranks = sorted(all_entries.keys()) - headers = [f"Rank {rank}" for rank in ranks] + headers = [ + f"Rank {rank}" + for rank in ranks + if args.selected_ranks is None or rank in args.selected_ranks + ] progress = True while progress: progress = False row = [] for rank in ranks: + if args.selected_ranks is not None and rank not in args.selected_ranks: + continue if len(all_entries[rank]) == 0: row.append("") else: entry = all_entries[rank].pop(0) pg_name = _pg_guids[(entry["process_group"][0], rank)] - row.append(str(Op(entry, _memberships, pg_name))) + if ( + args.pg_filters is None + or entry["process_group"][1] in args.pg_filters + ): + row.append(str(Op(entry, _memberships, pg_name))) + else: + row.append("") progress = True if progress: rows.append(row) @@ -248,23 +262,29 @@ def check_no_missing_dump_files( ), f"Missing dump files from ranks {all_ranks - dumps_ranks}" -def check_version(versions: Dict[str, Any]) -> None: - for rank, version in versions.items(): # noqa: PERF102 - major, minor = map(int, version.split(".")) - # assert major == 2, f"Rank {rank} unsupported version {version}" - # assert minor >= 0, f"Rank {rank} unsupported version {version}" +def check_version(version_by_ranks: Dict[str, str], version: str) -> None: + for rank, v in version_by_ranks.items(): + assert ( + v == version + ), f"Rank {rank} has different version {v} from the given version {version}" + + +def get_version_detail(version: str) -> Tuple[int, int]: + version = version.split(".") + assert len(version) == 2, f"Invalid version {version}" + major, minor = map(int, version) + return major, minor -def sort_trace_from_beginning( +def align_trace_from_beginning( entries: Dict[int, List[Dict[str, Any]]] ) -> Dict[int, List[Dict[str, Any]]]: """ - Sorts the trace entries by record ID for entries. + Align the trace entries by record ID for entries. This function takes a dictionary of rank names to lists of trace entries as input. Each trace entry is a dictionary containing information about a collective operation, including its unique identifier (`record_id` is monotonically increasing as we write into the ring buffer). - The function first sorts the entries in each rank by their `record_id` values. - Then, it finds the largest starting point across all ranks by taking the maximum + The function finds the largest starting point across all ranks by taking the maximum `record_id` value of the first entry in each rank. Finally, it filters out any entries with `record_id` values less than the maximum starting point. The function returns the updated dictionary of sorted and filtered trace entries. diff --git a/tools/flight_recorder/fr_trace.py b/tools/flight_recorder/fr_trace.py index aa5797c3bec870..298852d9bb4b5d 100644 --- a/tools/flight_recorder/fr_trace.py +++ b/tools/flight_recorder/fr_trace.py @@ -40,8 +40,8 @@ def main(args: Optional[Sequence[str]] = None) -> None: config = JobConfig() args = config.parse_args(args) - details = read_dir(args.prefix, args.dir) - db = build_db(details, args) + details, version = read_dir(args.prefix, args.dir) + db = build_db(details, args, version) if args.output: with open(args.output, "wb") as f: pickle.dump((types, db), f) From 6355ed0287244dd76b1968d71f9851d0b9600847 Mon Sep 17 00:00:00 2001 From: Animesh Jain Date: Wed, 4 Sep 2024 12:17:55 -0700 Subject: [PATCH 0440/1018] [dynamo] Retire CompileProfiler (#135133) Fixes confusion in https://github.com/pytorch/pytorch/issues/113443 We have TORCH_LOGS that supersedes CompileProfiler Pull Request resolved: https://github.com/pytorch/pytorch/pull/135133 Approved by: https://github.com/ezyang ghstack dependencies: #135039, #135121, #135129, #135130 --- docs/source/torch.compiler_faq.rst | 11 --- .../source/torch.compiler_troubleshooting.rst | 19 ---- test/distributed/test_dynamo_distributed.py | 18 ++-- test/dynamo/test_misc.py | 45 +-------- torch/_dynamo/utils.py | 99 ------------------- 5 files changed, 8 insertions(+), 184 deletions(-) diff --git a/docs/source/torch.compiler_faq.rst b/docs/source/torch.compiler_faq.rst index a5883ce015bef2..b7ff8bdd1aab26 100644 --- a/docs/source/torch.compiler_faq.rst +++ b/docs/source/torch.compiler_faq.rst @@ -136,17 +136,6 @@ Why is compilation slow? as long (as many iterations) as you were running when you ran into trouble, and the profiler will accumulate statistics over this duration. -.. code-block:: python - - from torch._dynamo.utils import CompileProfiler - - def my_model(): - ... - - with CompileProfiler() as prof: - profiler_model = torch.compile(my_model, backend=prof) - profiler_model() - print(prof.report()) Why are you recompiling in production? ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/docs/source/torch.compiler_troubleshooting.rst b/docs/source/torch.compiler_troubleshooting.rst index 7158149c09e190..05560789f07b45 100644 --- a/docs/source/torch.compiler_troubleshooting.rst +++ b/docs/source/torch.compiler_troubleshooting.rst @@ -663,13 +663,6 @@ recompile that function (or part) up to hitting the cache limit, you will first need to determine which guard is failing and what part of your program is triggering it. -The `compile profiler `__ automates the -process of setting TorchDynamo’s cache limit to 1 and running your -program under an observation-only 'compiler' that records the causes of -any guard failures. You should be sure to run your program for at least -as long (as many iterations) as you were running when you ran into -trouble, and the profiler will accumulate statistics over this duration. - If your program exhibits a bounded amount of dynamism, you may be able to tune the TorchDynamo cache limit to allow for each variation to be compiled and cached, but if the cache limit is too high you may find the @@ -685,18 +678,6 @@ support rank-dynamism. In the meantime, setting a specific cache limit can be used in coordination with bucketing techniques to achieve an acceptable number of recompilations for some dynamic models. -.. code-block:: python - - from torch._dynamo.utils import CompileProfiler - - def my_model(): - ... - - with CompileProfiler() as prof: - profiler_model = torch.compile(my_model, backend=prof) - profiler_model() - print(prof.report()) - Accuracy Debugging ~~~~~~~~~~~~~~~~~~ diff --git a/test/distributed/test_dynamo_distributed.py b/test/distributed/test_dynamo_distributed.py index 134370e41703c3..71f9065e7d4638 100644 --- a/test/distributed/test_dynamo_distributed.py +++ b/test/distributed/test_dynamo_distributed.py @@ -616,21 +616,17 @@ def test_fsdp_aot_eager(self): def test_fsdp_setattr(self): with _dynamo_dist_per_rank_init(self.rank, self.world_size): # Test with basic FSDP wrapping (outer wrap around whole model) + from torch._dynamo.utils import counters + + counters.clear() m, inputs, correct_outputs = get_mutating_model(f"cuda:{self.rank}") fsdp_m = FSDP(m, use_orig_params=True) - prof = torch._dynamo.utils.CompileProfiler() - fsdp_m = torch.compile(fsdp_m, backend=prof, fullgraph=False) + fsdp_m = torch.compile(fsdp_m, backend="eager", fullgraph=False) outputs = fsdp_m(inputs) self.assertTrue(same(correct_outputs, outputs)) - FileCheck().check("Torchdynamo Profiler Report").check( - "Graph Breaks" - ).check_not( - "setattr(FSDPManagedNNModuleVariable(MutatingModel), state, ...)" - ).check_not( - "setattr(FSDPManagedNNModuleVariable(FullyShardedDataParallel), _is_root, ...)" - ).run( - prof.report() - ) + self.assertEqual(len(counters["graph_break"]), 1) + first_graph_break = list(counters["graph_break"].keys())[0] # noqa: RUF015 + self.assertTrue("setattr" not in first_graph_break) @config.patch(enable_compiler_collectives=True) @skip_if_lt_x_gpu(1) diff --git a/test/dynamo/test_misc.py b/test/dynamo/test_misc.py index 980e962e5d0524..e0581ffc34b45b 100644 --- a/test/dynamo/test_misc.py +++ b/test/dynamo/test_misc.py @@ -49,7 +49,7 @@ skipIfNotPy311, unsupported, ) -from torch._dynamo.utils import CompileProfiler, counters, ifdynstaticdefault +from torch._dynamo.utils import counters, ifdynstaticdefault from torch._inductor.utils import run_and_get_code from torch.ao.quantization import MinMaxObserver from torch.ao.quantization.fake_quantize import FakeQuantize @@ -7887,49 +7887,6 @@ def fn(a, b): fn(torch.rand(2, 3), torch.rand(2, 3)) fn(torch.rand(2, 3), (1, 2, 3)) - @expectedFailureDynamic - @torch._dynamo.config.patch(automatic_dynamic_shapes=False) - def test_compile_profiler(self): - class Model(torch.nn.Module): - def forward(self, input): - return input + input - - model = Model() - prof = CompileProfiler() - compiled = torch.compile(model, backend=prof) - base_checker = ( - lambda: FileCheck() - .check("Torchdynamo Profiler Report") - .check("Graph Breaks") - .check("No graph breaks detected.") - .check("Recompilation") - ) - input = torch.rand((2, 3, 4)) - _ = compiled(input) - base_checker().check("No recompilation detected.").run(prof.report()) - - new_shape_input = torch.rand((3, 3, 4)) - _ = compiled(new_shape_input) - - # Not an exhaustive test of dynamic shapes behavior, but some sanity - if torch._dynamo.config.assume_static_by_default: - base_checker().check("Recompile Reasons").check("'forward'").check( - "cache_size_limit to 1" - ).run(prof.report()) - else: - base_checker().check("No recompilation detected.").run(prof.report()) - - new_shape_input = torch.rand((4, 3, 4)) - _ = compiled(new_shape_input) - - base_checker().check("Recompile Reasons").check("'forward'").check( - "tensor 'L['input']' size mismatch at index 0. expected 2, actual 3" - ).check( - "tensor 'L['input']' size mismatch at index 0. expected 3, actual 4" - ).run( - prof.report() - ) - def test_guards_strip_function_call(self): from torch._dynamo.guards import strip_function_call diff --git a/torch/_dynamo/utils.py b/torch/_dynamo/utils.py index fc062f57026357..627ada1ff88a1d 100644 --- a/torch/_dynamo/utils.py +++ b/torch/_dynamo/utils.py @@ -21,7 +21,6 @@ import os import re import sys -import textwrap import threading import time import types @@ -1911,104 +1910,6 @@ def disable_cache_limit(): seen_code_map = ExactWeakKeyDictionary() -class CompileProfiler: - """Utility for profiling how and what dynamo would compile. - - Can be used for - * diagnosing recompilation issues - * determining an appropriate compile cache limit - * (TODO)confirming which functions got compiled/skipped - """ - - def __init__(self): - self.frame_count = 0 - self.op_count = 0 - self.backend_ctx_ctor = disable_cache_limit - - def __call__(self, gm: torch.fx.GraphModule, example_inputs): - self.frame_count += 1 - for node in gm.graph.nodes: - if "call" in node.op: - self.op_count += 1 - return gm.forward - - # no-op __enter__ and __exit__ to preserve BC - def __enter__(self): - return self - - def __exit__(self, typ, val, traceback): - pass - - def get_metrics(self): - return {"guard_failures": guard_failures} - - def report(self): - metrics = self.get_metrics() - gf = metrics["guard_failures"] - - def num_recompiles(code): - return len(gf[code]) - - def recompile_reasons(code): - return "\n".join([str(x) for x in gf[code]]) - - summarized_gf = [ - [format_func_info(code), num_recompiles(code), recompile_reasons(code)] - for code in gf - ] - - def graph_break_report(): - if "graph_break" in counters: - graph_breaks = counters["graph_break"] - return tabulate( - [[msg, graph_breaks[msg]] for msg in graph_breaks], - headers=["Graph Break Reason", "Count"], - ) - - def recompilation_report(): - if len(gf): - max_recompiles = max(num_recompiles(code) for code in gf) - recomp_table = tabulate( - summarized_gf, - headers=["Function", "Recompiles", "Recompile Reasons"], - ) - return recomp_table + textwrap.dedent( - f""" - - Set torch._dynamo.config.cache_size_limit to {max_recompiles} to avoid being cache limited. - """ - ) - - report = textwrap.dedent( - """ - Torchdynamo Profiler Report - =========================== - - Graph Breaks - ------------ - Graph breaks happen when torchdynamo encounters code it can't safely trace. - If you want to find out why breaks are happening, check below for each break reason - You may gain additional insight by passing `fullgraph=True` to torch.compile, - to stop at the first break. - - """ - ) - report += graph_break_report() or "No graph breaks detected." - report += textwrap.dedent( - """ - - Recompilation - ------------- - These subgraphs were recompiled more than once due to guard failures - Guard failures indicate some condition assumed to be static by the tracer changed, - making it unsafe to reuse the compiled program. - - """ - ) - report += recompilation_report() or "No recompilation detected.\n" - return report - - # return same dir unless user changes config between calls @functools.lru_cache(None) def _get_debug_dir(root_dir): From 286ff30ddf992cc4a1e6d7639e1e1dd351460b4e Mon Sep 17 00:00:00 2001 From: FFFrog Date: Wed, 4 Sep 2024 14:47:24 +0000 Subject: [PATCH 0441/1018] Remove redundant code (#134955) Remove GetPrivateUse1HooksInterface Pull Request resolved: https://github.com/pytorch/pytorch/pull/134955 Approved by: https://github.com/Skylion007 --- aten/src/ATen/Context.h | 6 +++--- aten/src/ATen/EmptyTensor.cpp | 2 +- aten/src/ATen/detail/PrivateUse1HooksInterface.cpp | 7 ------- aten/src/ATen/detail/PrivateUse1HooksInterface.h | 14 +++++++------- aten/src/ATen/native/Resize.cpp | 2 +- .../cpp_extensions/open_registration_extension.cpp | 2 +- 6 files changed, 13 insertions(+), 20 deletions(-) diff --git a/aten/src/ATen/Context.h b/aten/src/ATen/Context.h index 370400b4fa036c..d46abc2e211a9f 100644 --- a/aten/src/ATen/Context.h +++ b/aten/src/ATen/Context.h @@ -52,7 +52,7 @@ class TORCH_API Context { } else if (device_type == at::kIPU) { return at::detail::getIPUHooks().getDefaultIPUGenerator(device.index()); } else if (device_type == at::kPrivateUse1) { - return at::GetPrivateUse1HooksInterface()->getDefaultGenerator( + return at::detail::getPrivateUse1Hooks().getDefaultGenerator( device.index()); } else { AT_ERROR(c10::DeviceTypeName(device_type), " device type not enabled."); @@ -91,7 +91,7 @@ class TORCH_API Context { } else if (device_type == at::kXPU) { return at::detail::getXPUHooks().getDeviceFromPtr(data); } else if (device_type == at::kPrivateUse1) { - return at::GetPrivateUse1HooksInterface()->getDeviceFromPtr(data); + return at::detail::getPrivateUse1Hooks().getDeviceFromPtr(data); } else { AT_ERROR(c10::DeviceTypeName(device_type), " device type not enabled."); } @@ -182,7 +182,7 @@ class TORCH_API Context { void lazyInitPrivateUse1() { c10::call_once(thp_init, [&] { if (isPrivateUse1HooksRegistered()) { - at::GetPrivateUse1HooksInterface()->initPrivateUse1(); + at::detail::getPrivateUse1Hooks().initPrivateUse1(); } }); } diff --git a/aten/src/ATen/EmptyTensor.cpp b/aten/src/ATen/EmptyTensor.cpp index a5e096ba2b79c7..a9d37cff78ca31 100644 --- a/aten/src/ATen/EmptyTensor.cpp +++ b/aten/src/ATen/EmptyTensor.cpp @@ -21,7 +21,7 @@ c10::Allocator* GetCPUAllocatorMaybePinned(bool pin_memory) { } else if (at::globalContext().hasXPU()) { return at::detail::getXPUHooks().getPinnedMemoryAllocator(); } else if(at::isPrivateUse1HooksRegistered()) { - return at::GetPrivateUse1HooksInterface()->getPinnedMemoryAllocator(); + return at::detail::getPrivateUse1Hooks().getPinnedMemoryAllocator(); } else { TORCH_CHECK(false, "Need to provide pin_memory allocator to use pin memory.") } diff --git a/aten/src/ATen/detail/PrivateUse1HooksInterface.cpp b/aten/src/ATen/detail/PrivateUse1HooksInterface.cpp index ff267a41506bb2..258d05f87e6d2f 100644 --- a/aten/src/ATen/detail/PrivateUse1HooksInterface.cpp +++ b/aten/src/ATen/detail/PrivateUse1HooksInterface.cpp @@ -11,13 +11,6 @@ TORCH_API void RegisterPrivateUse1HooksInterface(at::PrivateUse1HooksInterface* privateuse1_hooks = hook_; } -TORCH_API at::PrivateUse1HooksInterface* GetPrivateUse1HooksInterface() { - TORCH_CHECK( - privateuse1_hooks != nullptr, - "Please register PrivateUse1HooksInterface by `RegisterPrivateUse1HooksInterface` first."); - return privateuse1_hooks; -} - TORCH_API bool isPrivateUse1HooksRegistered() { return privateuse1_hooks != nullptr; } diff --git a/aten/src/ATen/detail/PrivateUse1HooksInterface.h b/aten/src/ATen/detail/PrivateUse1HooksInterface.h index 62fcb75d2a5ef2..e321f484deeace 100644 --- a/aten/src/ATen/detail/PrivateUse1HooksInterface.h +++ b/aten/src/ATen/detail/PrivateUse1HooksInterface.h @@ -12,7 +12,7 @@ namespace at { struct TORCH_API PrivateUse1HooksInterface : AcceleratorHooksInterface { ~PrivateUse1HooksInterface() override = default; virtual const at::Generator& getDefaultGenerator( - c10::DeviceIndex device_index) { + c10::DeviceIndex device_index) const { TORCH_CHECK_NOT_IMPLEMENTED( false, "You should register `PrivateUse1HooksInterface` for PrivateUse1 before call `getDefaultGenerator`."); @@ -24,24 +24,26 @@ struct TORCH_API PrivateUse1HooksInterface : AcceleratorHooksInterface { "You should register `PrivateUse1HooksInterface` for PrivateUse1 before call `getDeviceFromPtr`."); } - bool isPinnedPtr(const void* data) const override { + virtual bool isPinnedPtr(const void* data) const override { return false; } - Allocator* getPinnedMemoryAllocator() const override { + virtual Allocator* getPinnedMemoryAllocator() const override { TORCH_CHECK( false, "You should register `PrivateUse1HooksInterface` for PrivateUse1 before call `getPinnedMemoryAllocator`."); } - bool hasPrimaryContext(DeviceIndex device_index) const override { + virtual bool hasPrimaryContext(DeviceIndex device_index) const override { TORCH_CHECK_NOT_IMPLEMENTED( false, "You should register `PrivateUse1HooksInterface` for PrivateUse1 before call `hasPrimaryContext`."); } virtual void initPrivateUse1() const {} - virtual void resizePrivateUse1Bytes(const c10::Storage &storage, size_t newsize) const { + virtual void resizePrivateUse1Bytes( + const c10::Storage& storage, + size_t newsize) const { TORCH_CHECK_NOT_IMPLEMENTED( false, "You should register `PrivateUse1HooksInterface` for PrivateUse1 before call `resizePrivateUse1Bytes`."); @@ -53,8 +55,6 @@ struct TORCH_API PrivateUse1HooksArgs {}; TORCH_API void RegisterPrivateUse1HooksInterface( at::PrivateUse1HooksInterface* hook_); -TORCH_API at::PrivateUse1HooksInterface* GetPrivateUse1HooksInterface(); - TORCH_API bool isPrivateUse1HooksRegistered(); namespace detail { diff --git a/aten/src/ATen/native/Resize.cpp b/aten/src/ATen/native/Resize.cpp index f4d895c2c58584..e4d63de7f4d6d8 100644 --- a/aten/src/ATen/native/Resize.cpp +++ b/aten/src/ATen/native/Resize.cpp @@ -282,7 +282,7 @@ void resize_bytes_nocuda(const Storage& storage, const c10::SymInt& newsize) { } else if (device_type == at::kMeta) { at::native::resize_bytes_meta(storage.unsafeGetStorageImpl(), newsize); } else if (device_type == at::kPrivateUse1) { - at::GetPrivateUse1HooksInterface()->resizePrivateUse1Bytes( + at::detail::getPrivateUse1Hooks().resizePrivateUse1Bytes( storage, newsize.expect_int()); } else if (device_type == at::kXPU || device_type == at::kHPU) { ptrdiff_t size_bytes_i = newsize.expect_int(); diff --git a/test/cpp_extensions/open_registration_extension.cpp b/test/cpp_extensions/open_registration_extension.cpp index 0b435ddb928b44..f857aecc657cc8 100644 --- a/test/cpp_extensions/open_registration_extension.cpp +++ b/test/cpp_extensions/open_registration_extension.cpp @@ -586,7 +586,7 @@ struct FooHooksArgs : public at::PrivateUse1HooksArgs {}; struct FooHooksInterface : public at::PrivateUse1HooksInterface { FooHooksInterface(FooHooksArgs) {} ~FooHooksInterface() override = default; - const at::Generator& getDefaultGenerator(c10::DeviceIndex device_index) override { + const at::Generator& getDefaultGenerator(c10::DeviceIndex device_index) const override { static auto device_gen = make_generator_privateuse1(device_index); return device_gen; } From dd166205514de6e8d2d53f3dcd72a85a8cc60bdc Mon Sep 17 00:00:00 2001 From: drisspg Date: Wed, 4 Sep 2024 15:10:53 -0700 Subject: [PATCH 0442/1018] [FlexAttention] Fix mismatched backward strides for eager impl (#135152) # Fixes: The first repro from: https://github.com/pytorch/pytorch/issues/134888 Pull Request resolved: https://github.com/pytorch/pytorch/pull/135152 Approved by: https://github.com/Chillee --- test/inductor/test_flex_attention.py | 28 +++++++++++++++++++++++ torch/_higher_order_ops/flex_attention.py | 11 ++++++++- 2 files changed, 38 insertions(+), 1 deletion(-) diff --git a/test/inductor/test_flex_attention.py b/test/inductor/test_flex_attention.py index 7e3df0c75ac62a..0ba2d1ca08152f 100644 --- a/test/inductor/test_flex_attention.py +++ b/test/inductor/test_flex_attention.py @@ -1375,6 +1375,34 @@ def test_aot_eager_gradcheck(self, score_mod): ) ) + @supported_platform + def test_eager_backward_strides(self): + class Repro(torch.nn.Module): + def __init__(self): + super().__init__() + self.qkv_proj = torch.nn.Linear(256, 256 * 3) + self.n_head = 256 // 64 + self.d_attn = 256 + + def forward(self, x): + n_batch, n_ctx, _ = x.shape + q, k, v = self.qkv_proj(x).split( + [self.d_attn, self.d_attn, self.d_attn], dim=2 + ) + q = q.reshape(n_batch, n_ctx, self.n_head, -1) + k = k.reshape(n_batch, n_ctx, self.n_head, -1) + v = v.reshape(n_batch, n_ctx, self.n_head, -1) + q = q.transpose(1, 2) + k = k.transpose(1, 2) + v = v.transpose(1, 2) + x = torch.nn.attention.flex_attention.flex_attention(q, k, v) + return x + + model = Repro().cuda() + x = torch.randn((1, 512, 256), device="cuda", requires_grad=True) + out = torch.compile(model, backend="aot_eager")(x) + out.backward(torch.ones_like(out)) + @supported_platform def test_differentiable_logsumexp_gradcheck(self): make_tensor = functools.partial( diff --git a/torch/_higher_order_ops/flex_attention.py b/torch/_higher_order_ops/flex_attention.py index 4890fa499048c3..cd1c9d4fd92c5e 100644 --- a/torch/_higher_order_ops/flex_attention.py +++ b/torch/_higher_order_ops/flex_attention.py @@ -686,6 +686,11 @@ def sdpa_dense_backward( score_mod_other_buffers: Tuple, mask_mod_other_buffers: Tuple, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + # Get outputs before calling repeat interleave + actual_grad_query = torch.empty_like(query) + actual_grad_key = torch.empty_like(key) + actual_grad_value = torch.empty_like(value) + G = query.size(1) // key.size(1) key = torch.repeat_interleave(key, G, dim=1) value = torch.repeat_interleave(value, G, dim=1) @@ -767,7 +772,11 @@ def sdpa_dense_backward( grad_key = torch.sum(grad_key, 2, keepdim=False) grad_value = torch.sum(grad_value, 2, keepdim=False) - return grad_query.contiguous(), grad_key.contiguous(), grad_value.contiguous() + actual_grad_query.copy_(grad_query) + actual_grad_key.copy_(grad_key) + actual_grad_value.copy_(grad_value) + + return actual_grad_query, actual_grad_key, actual_grad_value def trace_flex_attention_backward( From a8afac0d3dad5c6da36904fc47230be9cbf0b41b Mon Sep 17 00:00:00 2001 From: Aleksei Nikiforov Date: Thu, 5 Sep 2024 01:17:52 +0000 Subject: [PATCH 0443/1018] Update current scripts used for setting up s390x runners (#129866) Update current scripts used for setting up s390x runners Just a documentation update. Pull Request resolved: https://github.com/pytorch/pytorch/pull/129866 Approved by: https://github.com/malfet, https://github.com/huydhn --- .github/scripts/s390x-ci/README.md | 39 +++++++-- .../actions-runner.Dockerfile | 33 ++++---- .../actions-runner@.service | 6 +- .../fs/usr/bin/actions-runner | 42 +++++++++- .../self-hosted-builder/helpers/app_token.sh | 84 +++++++++++++++++++ .../helpers/gh_token_generator.sh | 10 +++ 6 files changed, 191 insertions(+), 23 deletions(-) create mode 100755 .github/scripts/s390x-ci/self-hosted-builder/helpers/app_token.sh create mode 100755 .github/scripts/s390x-ci/self-hosted-builder/helpers/gh_token_generator.sh diff --git a/.github/scripts/s390x-ci/README.md b/.github/scripts/s390x-ci/README.md index f62b02e24aa3e9..94e4a85be43ceb 100644 --- a/.github/scripts/s390x-ci/README.md +++ b/.github/scripts/s390x-ci/README.md @@ -3,7 +3,7 @@ ## Install prerequisites. ``` -$ sudo dnf install docker +$ sudo dnf install podman podman-docker jq ``` ## Add services. @@ -27,23 +27,48 @@ $ sudo systemctl enable --now qemu-user-static ## Rebuild the image -In order to build or update the `iiilinuxibmcom/actions-runner` image, e.g. to get the -latest OS security fixes, use the following commands: +First build s390x builder image `docker.io/pytorch/manylinuxs390x-builder`, +using following commands: + +``` +$ cd ~ +$ git clone https://github.com/pytorch/pytorch +$ cd pytorch +$ git submodule update --init --recursive +$ GPU_ARCH_TYPE=cpu-s390x "$(pwd)/.ci/docker/manywheel/build.sh" manylinuxs390x-builder +$ docker image tag localhost/pytorch/manylinuxs390x-builder docker.io/pytorch/manylinuxs390x-builder:cpu-s390x +$ docker image save -o ~/manywheel-s390x.tar docker.io/pytorch/manylinuxs390x-builder:cpu-s390x +``` + +Next step is to build `actions-runner` image using: ``` $ cd self-hosted-builder $ sudo docker build \ - --build-arg repo=/ \ - --build-arg token=<***> \ --pull \ -f actions-runner.Dockerfile \ - -t iiilinuxibmcom/actions-runner \ + -t iiilinuxibmcom/actions-runner. \ . ``` -If it fails, ensure that selinux doesn't prevent it from working. +If there are failures, ensure that selinux doesn't prevent it from working. In worst case, selinux can be disabled with `setenforce 0`. +Now prepare all necessary files for runner registration: + +``` +$ sudo mkdir -p /etc/actions-runner/ +$ sudo chmod 700 /etc/actions-runner/ +$ sudo /bin/cp /etc/actions-runner//key_private.pem +$ sudo echo | sudo tee /etc/actions-runner//appid.env +$ sudo echo | sudo tee /etc/actions-runner//installid.env +$ sudo echo NAME= | sudo tee /etc/actions-runner//env +$ sudo echo ORG= | sudo tee -a /etc/actions-runner//env +$ cd self-hosted-builder +$ sudo /bin/cp helpers/*.sh /usr/local/bin/ +$ sudo chmod 755 /usr/local/bin/app_token.sh /usr/local/bin/gh_token_generator.sh +``` + ## Autostart the runner. ``` diff --git a/.github/scripts/s390x-ci/self-hosted-builder/actions-runner.Dockerfile b/.github/scripts/s390x-ci/self-hosted-builder/actions-runner.Dockerfile index 416a6d8e50df5e..ee1db829fe66ce 100644 --- a/.github/scripts/s390x-ci/self-hosted-builder/actions-runner.Dockerfile +++ b/.github/scripts/s390x-ci/self-hosted-builder/actions-runner.Dockerfile @@ -1,12 +1,12 @@ # Self-Hosted IBM Z Github Actions Runner. # Temporary image: amd64 dependencies. -FROM docker.io/amd64/ubuntu:22.04 as ld-prefix +FROM docker.io/amd64/ubuntu:23.10 as ld-prefix ENV DEBIAN_FRONTEND=noninteractive -RUN apt-get update && apt-get -y install ca-certificates libicu70 libssl3 +RUN apt-get update && apt-get -y install ca-certificates libicu72 libssl3 # Main image. -FROM docker.io/s390x/ubuntu:22.04 +FROM docker.io/s390x/ubuntu:23.10 # Packages for pytorch building and testing. ENV DEBIAN_FRONTEND=noninteractive @@ -16,6 +16,7 @@ RUN apt-get update && apt-get -y install \ gcc \ git \ jq \ + zip \ libxml2-dev \ libxslt-dev \ ninja-build \ @@ -43,24 +44,28 @@ COPY fs/ / RUN chmod +x /usr/bin/actions-runner /usr/bin/entrypoint +# install podman +RUN apt -y install podman podman-docker + # amd64 Github Actions Runner. RUN useradd -m actions-runner USER actions-runner WORKDIR /home/actions-runner -RUN curl -L https://github.com/actions/runner/releases/download/v2.309.0/actions-runner-linux-x64-2.309.0.tar.gz | tar -xz -# repository -ARG repo +# set up python virtual environment which is later used by runner. +# build workflows use "python -m pip install ...", +# and it doesn't work for non-root user +RUN virtualenv --system-site-packages venv -# repository token -ARG token +# copy prebuilt manywheel docker image for builds and tests +# build command is: +# GPU_ARCH_TYPE=cpu-s390x "$(pwd)/manywheel/build_docker.sh" +# and save command is: +# docker image save -o manywheel-s390x.tar pytorch/manylinuxs390x-builder:cpu-s390x +# +COPY --chown=actions-runner:actions-runner manywheel-s390x.tar /home/actions-runner/manywheel-s390x.tar -RUN ./config.sh \ - --unattended \ - --url "https://github.com/${repo}" \ - --token "${token}" \ - --no-default-labels \ - --labels self-hosted,linux.s390x +RUN curl -L https://github.com/actions/runner/releases/download/v2.317.0/actions-runner-linux-x64-2.317.0.tar.gz | tar -xz ENTRYPOINT ["/usr/bin/entrypoint"] CMD ["/usr/bin/actions-runner"] diff --git a/.github/scripts/s390x-ci/self-hosted-builder/actions-runner@.service b/.github/scripts/s390x-ci/self-hosted-builder/actions-runner@.service index 158be9ccb6c1d0..323b00edc178bc 100644 --- a/.github/scripts/s390x-ci/self-hosted-builder/actions-runner@.service +++ b/.github/scripts/s390x-ci/self-hosted-builder/actions-runner@.service @@ -8,12 +8,16 @@ StartLimitIntervalSec=0 Type=simple Restart=always ExecStartPre=-/usr/bin/docker rm --force actions-runner.%i +ExecStartPre=-/usr/local/bin/gh_token_generator.sh /etc/actions-runner/%i/appid.env /etc/actions-runner/%i/installid.env /etc/actions-runner/%i/key_private.pem /etc/actions-runner/%i/ghtoken.env ExecStart=/usr/bin/docker run \ + --env-file=/etc/actions-runner/%i/env \ + --env-file=/etc/actions-runner/%i/ghtoken.env \ --init \ --interactive \ --name=actions-runner.%i \ --rm \ - iiilinuxibmcom/actions-runner + --privileged \ + iiilinuxibmcom/actions-runner.%i ExecStop=/bin/sh -c "docker exec actions-runner.%i kill -INT -- -1" ExecStop=/bin/sh -c "docker wait actions-runner.%i" ExecStop=/bin/sh -c "docker rm actions-runner.%i" diff --git a/.github/scripts/s390x-ci/self-hosted-builder/fs/usr/bin/actions-runner b/.github/scripts/s390x-ci/self-hosted-builder/fs/usr/bin/actions-runner index 760784b21c3966..6d129d8656944b 100644 --- a/.github/scripts/s390x-ci/self-hosted-builder/fs/usr/bin/actions-runner +++ b/.github/scripts/s390x-ci/self-hosted-builder/fs/usr/bin/actions-runner @@ -2,5 +2,45 @@ set -e -u +# first import docker image +if [ -f ./manywheel-s390x.tar ] ; then + docker image load --input manywheel-s390x.tar + docker image tag docker.io/pytorch/manylinuxs390x-builder:cpu-s390x docker.io/pytorch/manylinuxs390x-builder:cpu-s390x-main + rm -f manywheel-s390x.tar +fi + +token_file=registration-token.json + +# Generate registration token +curl \ + -X POST \ + -H "Accept: application/vnd.github.v3+json" \ + -H "Authorization: Bearer ${ACCESS_TOKEN}" \ + "https://api.github.com/orgs/${ORG}/actions/runners/registration-token" \ + -o "$token_file" + +unset ACCESS_TOKEN + +# register runner as ephemeral runner +# it does one job, stops and unregisters +registration_token=$(jq --raw-output .token "$token_file") + +./config.sh \ + --unattended \ + --ephemeral \ + --url "https://github.com/${ORG}" \ + --token "${registration_token}" \ + --name "${NAME}" \ + --no-default-labels \ + --labels self-hosted,linux.s390x + +unset registration_token +rm -f "$token_file" + +# enter into python virtual environment. +# build workflows use "python -m pip install ...", +# and it doesn't work for non-root user +source venv/bin/activate + # Run one job. -./run.sh --once +./run.sh diff --git a/.github/scripts/s390x-ci/self-hosted-builder/helpers/app_token.sh b/.github/scripts/s390x-ci/self-hosted-builder/helpers/app_token.sh new file mode 100755 index 00000000000000..eb483c197772ad --- /dev/null +++ b/.github/scripts/s390x-ci/self-hosted-builder/helpers/app_token.sh @@ -0,0 +1,84 @@ +#!/usr/bin/env bash +# +# Request an ACCESS_TOKEN to be used by a GitHub APP +# Environment variable that need to be set up: +# * APP_ID, the GitHub's app ID +# * INSTALL_ID, the Github's app's installation ID +# * APP_PRIVATE_KEY, the content of GitHub app's private key in PEM format. +# +# https://github.com/orgs/community/discussions/24743#discussioncomment-3245300 +# + +set -o pipefail + +_GITHUB_HOST=${GITHUB_HOST:="github.com"} + +# If URL is not github.com then use the enterprise api endpoint +if [[ ${GITHUB_HOST} = "github.com" ]]; then + URI="https://api.${_GITHUB_HOST}" +else + URI="https://${_GITHUB_HOST}/api/v3" +fi + +API_VERSION=v3 +API_HEADER="Accept: application/vnd.github.${API_VERSION}+json" +CONTENT_LENGTH_HEADER="Content-Length: 0" +APP_INSTALLATIONS_URI="${URI}/app/installations" + + +# JWT parameters based off +# https://docs.github.com/en/developers/apps/building-github-apps/authenticating-with-github-apps#authenticating-as-a-github-app +# +# JWT token issuance and expiration parameters +JWT_IAT_DRIFT=60 +JWT_EXP_DELTA=600 + +JWT_JOSE_HEADER='{ + "alg": "RS256", + "typ": "JWT" +}' + + +build_jwt_payload() { + now=$(date +%s) + iat=$((now - JWT_IAT_DRIFT)) + jq -c \ + --arg iat_str "${iat}" \ + --arg exp_delta_str "${JWT_EXP_DELTA}" \ + --arg app_id_str "${APP_ID}" \ + ' + ($iat_str | tonumber) as $iat + | ($exp_delta_str | tonumber) as $exp_delta + | ($app_id_str | tonumber) as $app_id + | .iat = $iat + | .exp = ($iat + $exp_delta) + | .iss = $app_id + ' <<< "{}" | tr -d '\n' +} + +base64url() { + base64 | tr '+/' '-_' | tr -d '=\n' +} + +rs256_sign() { + openssl dgst -binary -sha256 -sign <(echo "$1") +} + +request_access_token() { + jwt_payload=$(build_jwt_payload) + encoded_jwt_parts=$(base64url <<<"${JWT_JOSE_HEADER}").$(base64url <<<"${jwt_payload}") + encoded_mac=$(echo -n "$encoded_jwt_parts" | rs256_sign "${APP_PRIVATE_KEY}" | base64url) + generated_jwt="${encoded_jwt_parts}.${encoded_mac}" + + auth_header="Authorization: Bearer ${generated_jwt}" + + app_installations_response=$(curl -sX POST \ + -H "${auth_header}" \ + -H "${API_HEADER}" \ + --header "X-GitHub-Api-Version: 2022-11-28" \ + --url "https://api.github.com/app/installations/${INSTALL_ID}/access_tokens" \ + ) + echo "$app_installations_response" | jq --raw-output '.token' +} + +request_access_token diff --git a/.github/scripts/s390x-ci/self-hosted-builder/helpers/gh_token_generator.sh b/.github/scripts/s390x-ci/self-hosted-builder/helpers/gh_token_generator.sh new file mode 100755 index 00000000000000..8f16974423dd77 --- /dev/null +++ b/.github/scripts/s390x-ci/self-hosted-builder/helpers/gh_token_generator.sh @@ -0,0 +1,10 @@ +#!/usr/bin/env bash + +SCRIPT_DIR=$(dirname "$0") +APP_ID=$1 +INSTALL_ID=$2 +APP_PRIVATE_KEY=$3 +DST_FILE="$4" + +ACCESS_TOKEN="$(APP_ID="$(<"${APP_ID}")" INSTALL_ID="$(<"${INSTALL_ID}")" APP_PRIVATE_KEY="$(<"${APP_PRIVATE_KEY}")" "${SCRIPT_DIR}/app_token.sh")" +echo "ACCESS_TOKEN=${ACCESS_TOKEN}" > "${DST_FILE}" From 8290e9e84afe854a0be61a7ed263a497319fd1c2 Mon Sep 17 00:00:00 2001 From: Wei Feng Date: Wed, 4 Sep 2024 10:08:58 -0700 Subject: [PATCH 0444/1018] [FSDP] casting input args with dataclass(frozen=True) (#135067) resolve: https://github.com/pytorch/pytorch/pull/135029 when enabling mixed precision, FSDP cast input args to desired dtype by calling `_apply_to_tensors`. When input args has `dataclass(frozen=True)`, we hit following runtime error, because of using `setattr` in `_apply_to_tensors` `dataclasses.FrozenInstanceError: cannot assign to field 'some_key'`. The fix is to use dataclasses api `dataclasses.replace` Pull Request resolved: https://github.com/pytorch/pytorch/pull/135067 Approved by: https://github.com/awgu --- test/distributed/fsdp/test_utils.py | 13 +++++++++++-- torch/distributed/utils.py | 8 ++++---- 2 files changed, 15 insertions(+), 6 deletions(-) diff --git a/test/distributed/fsdp/test_utils.py b/test/distributed/fsdp/test_utils.py index e36834617aa072..adc338dcf9a97f 100644 --- a/test/distributed/fsdp/test_utils.py +++ b/test/distributed/fsdp/test_utils.py @@ -55,7 +55,13 @@ def get_a_tensor(): return t @dataclass - class SomeDataClass: + class NonFrozenDataClass: + some_key: str + some_float: float + some_tensor: List[torch.Tensor] + + @dataclass(frozen=True) + class FrozenDataClass: some_key: str some_float: float some_tensor: List[torch.Tensor] @@ -65,7 +71,10 @@ class SomeDataClass: data.append({"key1": get_a_tensor(), "key2": {1: get_a_tensor()}, "key3": 3}) data.insert(0, {"x", get_a_tensor(), get_a_tensor()}) data.append(([1], get_a_tensor(), (1), [get_a_tensor()], {1, 2})) - data.append({"abc": SomeDataClass("some_key", 1.0, [get_a_tensor()])}) + data.append( + {"non_frozen_ds": NonFrozenDataClass("some_key", 1.0, [get_a_tensor()])} + ) + data.append({"frozen_ds": FrozenDataClass("some_key", 1.0, [get_a_tensor()])}) od = OrderedDict() od["k"] = "value" data.append(od) diff --git a/torch/distributed/utils.py b/torch/distributed/utils.py index a4502a4b52e48e..faacd059f899e1 100644 --- a/torch/distributed/utils.py +++ b/torch/distributed/utils.py @@ -232,10 +232,10 @@ def apply(x): return fn(x) elif hasattr(x, "__dataclass_fields__"): dc = dataclasses.replace(x) - for f in dataclasses.fields(dc): - name = f.name - setattr(dc, name, apply(getattr(dc, name))) - return dc + changes = { + f.name: apply(getattr(dc, f.name)) for f in dataclasses.fields(dc) + } + return dataclasses.replace(dc, **changes) elif isinstance(x, OrderedDict): od = x.__class__() for key, value in x.items(): From 4ee3c61c767511bb43861b8dd46899b7f0068d74 Mon Sep 17 00:00:00 2001 From: Bob Ren Date: Wed, 4 Sep 2024 15:17:24 -0700 Subject: [PATCH 0445/1018] Fix typo in comment (#135111) Pull Request resolved: https://github.com/pytorch/pytorch/pull/135111 Approved by: https://github.com/aorenste, https://github.com/oulgen --- torch/_functorch/_aot_autograd/autograd_cache.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch/_functorch/_aot_autograd/autograd_cache.py b/torch/_functorch/_aot_autograd/autograd_cache.py index 2db43bc865f0bf..17b619519c158d 100644 --- a/torch/_functorch/_aot_autograd/autograd_cache.py +++ b/torch/_functorch/_aot_autograd/autograd_cache.py @@ -256,7 +256,7 @@ def load(self, example_inputs, fx_config: Dict[str, BoxedBool]) -> CompiledFxGra # [Note: AOTAutogradCache and FXGraphCache Guard interactions] # As mentioned, AOTAutograd takes in the symint inputs from dynamo's list of arguments. # FXGraphCache serializes guards that are needed in the shape_env based on these symint inputs to the graph. - # he invariant that AOTAutograd uses here is that the sources for symints given to it by dynamo are exactly + # The invariant that AOTAutograd uses here is that the sources for symints given to it by dynamo are exactly # the same as the ones it passes to inductor, for both the forward and backward passes. # (This does not mean that the tensor values passed in are the same: only that their symints are). # That is, AOTAutograd and Inductor never create new guards based on symints with different sources From 6df32b089dc7377b06d9eecb58064cb4772fe8b8 Mon Sep 17 00:00:00 2001 From: ZhiweiYan-96 Date: Wed, 4 Sep 2024 14:11:37 +0000 Subject: [PATCH 0446/1018] [Intel GPU]device guard codegen for XPU (#133980) This PR is a supplement to #130082. The previous PR #130082 fulfill the basic functionality of codegen, while we found it fails to handle the device sameness check in lots of uts. Current PR is aimed to facilitate the XPU device guard code generation. With current PR, the code snippet in `RegisterXPU.cpp` is as follows, where we can see the device guard is successfully generated. ```c++ namespace { at::Tensor & wrapper_XPU_Tensor_float_out_normal_out(const at::Tensor & mean, double std, ::std::optional generator, at::Tensor & out) { std::optional common_device = std::nullopt; (void)common_device; // Suppress unused variable warning c10::impl::check_and_update_common_device(common_device, out, "wrapper_XPU_Tensor_float_out_normal_out", "out"); c10::impl::check_and_update_common_device(common_device, mean, "wrapper_XPU_Tensor_float_out_normal_out", "mean"); const OptionalDeviceGuard device_guard(device_of(out)); return at::native::normal_out(mean, std, generator, out); } } // anonymous namespace ``` Nevertheless, without current change, the generated code is ```c++ namespace { at::Tensor & wrapper_XPU_Tensor_float_out_normal_out(const at::Tensor & mean, double std, ::std::optional generator, at::Tensor & out) { // No device check // DeviceGuard omitted return at::native::normal_out(mean, std, generator, out); } } // anonymous namespace ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/133980 Approved by: https://github.com/EikanWang, https://github.com/malfet --- torchgen/gen.py | 3 ++- torchgen/model.py | 12 ++++++++++++ 2 files changed, 14 insertions(+), 1 deletion(-) diff --git a/torchgen/gen.py b/torchgen/gen.py index 27ff0c48caa3cc..e5870a24fc6684 100644 --- a/torchgen/gen.py +++ b/torchgen/gen.py @@ -59,6 +59,7 @@ is_cuda_dispatch_key, is_generic_dispatch_key, is_ufunc_dispatch_key, + is_xpu_dispatch_key, Location, NativeFunction, NativeFunctionsGroup, @@ -184,7 +185,7 @@ def parse_native_yaml_struct( use_out_as_primary=True, external=False, # Only cuda-like devices in tree require device guards - device_guard=is_cuda_dispatch_key(k), + device_guard=is_cuda_dispatch_key(k) or is_xpu_dispatch_key(k), index=v, ) return ParsedYaml(rs, indices) diff --git a/torchgen/model.py b/torchgen/model.py index 7459587e31d64a..956949343101ad 100644 --- a/torchgen/model.py +++ b/torchgen/model.py @@ -323,6 +323,18 @@ def is_cuda_dispatch_key(dk: DispatchKey) -> bool: } +# XPU specific dispatcy keys +def is_xpu_dispatch_key(dk: DispatchKey) -> bool: + return dk in { + DispatchKey.XPU, + DispatchKey.QuantizedXPU, + DispatchKey.SparseXPU, + DispatchKey.SparseCsrXPU, + DispatchKey.NestedTensorXPU, + DispatchKey.AutogradXPU, + } + + # Structured kernel generation is only supported for certain key types; # otherwise use old-style def is_structured_dispatch_key(dk: DispatchKey) -> bool: From 456ec8e4bfa4bc57cc1e94fde5bcd7c51ad476e0 Mon Sep 17 00:00:00 2001 From: Will Feng Date: Wed, 4 Sep 2024 17:28:11 -0700 Subject: [PATCH 0447/1018] [Traceable FSDP2] Skip _backward_prefetch under compile, and rely on compiler pass to have prefetching (#135163) Before this PR, when traceable FSDP2 + AC is run, an error would be thrown: ``` File "/data/users/willfeng/pytorch/torch/_dynamo/variables/builtin.py", line 1449, in call_getitem return args[0].call_method(tx, "__getitem__", args[1:], kwargs) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/data/users/willfeng/pytorch/torch/_dynamo/variables/lists.py", line 435, in call_method return super().call_method(tx, name, args, kwargs) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/data/users/willfeng/pytorch/torch/_dynamo/variables/lists.py", line 392, in call_method return super().call_method(tx, name, args, kwargs) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/data/users/willfeng/pytorch/torch/_dynamo/variables/lists.py", line 131, in call_method return self.getitem_const(tx, value) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/data/users/willfeng/pytorch/torch/_dynamo/variables/lists.py", line 106, in getitem_const return self.items[index] Error: Index out of bound from user code: File ".5", line 105, in forward aot0_trace_wrapped = torch__dynamo__trace_wrapped_higher_order_op_self_invoke(aot0_tangents_1, bw_state = aot0_primals_34); aot0_tangents_1 = None File "/data/users/willfeng/pytorch/torch/_dynamo/_trace_wrapped_higher_order_op.py", line 74, in self_invoke return _trace_wrapped_op(*args, **dyn_kwargs, **kwargs) File "/data/users/willfeng/pytorch/torch/_dynamo/external_utils.py", line 132, in call_hook_from_backward_state return getattr(bw_state, hook_name)(*args, **kwargs) File "/data/users/willfeng/pytorch/torch/distributed/_composable/fsdp/_fsdp_state.py", line 271, in _pre_backward self._fsdp_param_group.pre_backward(default_prefetch) File "/data/users/willfeng/pytorch/torch/distributed/_composable/fsdp/_fsdp_param_group.py", line 332, in pre_backward self._backward_prefetch() File "/data/users/willfeng/pytorch/torch/distributed/_composable/fsdp/_fsdp_param_group.py", line 417, in _backward_prefetch target_fsdp_param_group = self.comm_ctx.post_forward_order[target_index] ``` Since it's okay to rely on the compiler to recover the "prefetching" pattern, we will skip this `_backward_prefetch()` code path during tracing to avoid the error, and have a compiler pass (in future PR) to achieve the equivalent prefetching overlap. Pull Request resolved: https://github.com/pytorch/pytorch/pull/135163 Approved by: https://github.com/awgu --- torch/distributed/_composable/fsdp/_fsdp_param_group.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch/distributed/_composable/fsdp/_fsdp_param_group.py b/torch/distributed/_composable/fsdp/_fsdp_param_group.py index 3c2f639607e14f..3cb4a31c28d2f3 100644 --- a/torch/distributed/_composable/fsdp/_fsdp_param_group.py +++ b/torch/distributed/_composable/fsdp/_fsdp_param_group.py @@ -327,7 +327,7 @@ def pre_backward(self, default_prefetch: bool, *unused: Any): self._training_state = TrainingState.PRE_BACKWARD self.unshard() # no-op if prefetched self.wait_for_unshard() - if default_prefetch: + if default_prefetch and not ca.compiled_autograd_enabled: self._backward_prefetch() def post_backward(self, *unused: Any): From 0740d3a57d8f764dfd76d8deaf74446c969c469f Mon Sep 17 00:00:00 2001 From: PyTorch UpdateBot Date: Thu, 5 Sep 2024 04:21:48 +0000 Subject: [PATCH 0448/1018] [executorch hash update] update the pinned executorch hash (#135162) This PR is auto-generated nightly by [this action](https://github.com/pytorch/pytorch/blob/main/.github/workflows/nightly.yml). Update the pinned executorch hash. Pull Request resolved: https://github.com/pytorch/pytorch/pull/135162 Approved by: https://github.com/pytorchbot --- .ci/docker/ci_commit_pins/executorch.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.ci/docker/ci_commit_pins/executorch.txt b/.ci/docker/ci_commit_pins/executorch.txt index e7bb12db3dc0a3..639bfcdd0420e2 100644 --- a/.ci/docker/ci_commit_pins/executorch.txt +++ b/.ci/docker/ci_commit_pins/executorch.txt @@ -1 +1 @@ -d519b4d3a1ffdc81b45e2b1d4733423ce0577813 +cd1c833b079adb324871dcbbe75b43d42ffc0ade From b21110d10a785bb2fef49ae8f3dcd26342135a19 Mon Sep 17 00:00:00 2001 From: Pian Pawakapan Date: Thu, 5 Sep 2024 05:39:44 +0000 Subject: [PATCH 0449/1018] [export] dynamic_shapes serialization, load/dump (#134718) Adds utility functions `_dump_dynamic_shapes` and `_load_dynamic_shapes`. - `_dump_dynamic_shapes`: dynamic shapes spec -> serialized format: - takes in the `dynamic_shapes` pytree object you'd feed into `export()`, and dumps into serialized format - `_load_dynamic_shapes`: serialized format -> dynamic shapes spec - takes the serialized format, and produces a `dynamic_shapes` object you feed into `export()` For example with dumping: ``` dx = Dim("dx", min=4, max=16) dy = dx + 1 inputs = ( [ torch.randn(4, 4), torch.randn(5, 4), ], torch.randn(4), torch.randn(4, 4), "hello", ) dynamic_shapes = { "a": [ (dx, 4), (dy, 4), ], "b": (Dim.AUTO,), "c": None, "d": None, } out = _dump_dynamic_shapes(dynamic_shapes, inputs) ``` would generate the following output: ``` DynamicShapesSpec( dynamic_shapes=( [ ['dx', 4], ['dx + 1', 4], ], ['_DimHint.STATIC'], ['_DimHint.STATIC', '_DimHint.STATIC'], None, ), dims={ 'dx': RootDim( min=4, max=16, derived=['dx + 1'], ), }, ) ``` The serialized format contains 2 keys, `dynamic_shapes` and `dims.` - `dynamic_shapes` is the pytree structure matching the input to `export()`, with strings in place of Dim names and enums, and ints/Nones otherwise. Each tensor is represented with a list of shapes, non-tensors with Nones. - `dims` contain min/max range and derived dims info for each root dim. The test cases show some roundtrippability guarantees for these functions. Definitely taking naming suggestions for them :) Follow up: utility function to extract serializable format from ExportedProgram. Pull Request resolved: https://github.com/pytorch/pytorch/pull/134718 Approved by: https://github.com/avikchaudhuri --- test/export/test_export.py | 263 ++++++++++++++++++++- torch/_export/serde/dynamic_shapes.py | 321 ++++++++++++++++++++++++++ 2 files changed, 577 insertions(+), 7 deletions(-) create mode 100644 torch/_export/serde/dynamic_shapes.py diff --git a/test/export/test_export.py b/test/export/test_export.py index 5688bf6509e324..3a5f10c88ea944 100644 --- a/test/export/test_export.py +++ b/test/export/test_export.py @@ -293,19 +293,50 @@ def _test_export_same_as_eager(self, f, args, kwargs=None): # ) def _check_dynamic_shapes_specs_and_shapes( - self, model, inputs, specs, passing_shapes, failing_shapes + self, model, inputs, specs, passing_shapes, failing_shapes, test_serdes=False ): + from torch._export.serde.dynamic_shapes import ( + _dump_dynamic_shapes, + _load_dynamic_shapes, + ) + from torch.utils._pytree import tree_map + + def _construct_inputs(shapes): + def _is_tensor_leaf(x): + return isinstance(x, tuple) and all(isinstance(y, int) for y in x) + + return tree_map( + lambda x: torch.randn(*x) if _is_tensor_leaf(x) else x, + shapes, + is_leaf=_is_tensor_leaf, + ) + # exports with a list of equivalent dynamic shapes specs, # then tests for pass/fail on list of shapes for _specs in specs: ep = export(model, inputs, dynamic_shapes=_specs) - for shapes in passing_shapes: - test_inputs = (torch.randn(*shape) for shape in shapes) - ep.module()(*test_inputs) - for shapes in failing_shapes: - test_inputs = (torch.randn(*shape) for shape in shapes) - with self.assertRaises(RuntimeError): + eps = [ep] + if test_serdes: + # test dynamic shapes serialization + # test that behavior remains the same when exporting with ser/des specs: + # serialize + deserialize original specs, and export. + ep_serdes = export( + model, + inputs, + dynamic_shapes=_load_dynamic_shapes( + _dump_dynamic_shapes(_specs, inputs) + ), + ) + eps.append(ep_serdes) + + for ep in eps: + for shapes in passing_shapes: + test_inputs = _construct_inputs(shapes) ep.module()(*test_inputs) + for shapes in failing_shapes: + test_inputs = _construct_inputs(shapes) + with self.assertRaises(RuntimeError): + ep.module()(*test_inputs) def test_basic(self): class Module(torch.nn.Module): @@ -6893,6 +6924,7 @@ def forward(self, x, y, z): ((4, 4), (4, 4), (4, 3)), ((4, 4), (5, 4), (4, 5)), ], + test_serdes=True, ) # static s1 self._check_dynamic_shapes_specs_and_shapes( @@ -6913,6 +6945,7 @@ def forward(self, x, y, z): ((1, 1), (1, 1), (1, 1)), ((0, 9), (0, 9), (0, 9)), ], + test_serdes=True, ) # fully static self._check_dynamic_shapes_specs_and_shapes( @@ -6928,6 +6961,7 @@ def forward(self, x, y, z): ((1, 3), (1, 3), (1, 3)), ((0, 9), (0, 9), (0, 9)), ], + test_serdes=True, ) def test_automatic_dynamic_shapes_constant_relation(self): @@ -6955,6 +6989,7 @@ def forward(self, x, y): failing_shapes=[ ((10,), (13,)), ], + test_serdes=True, ) # static s1 should specialize s0 self._check_dynamic_shapes_specs_and_shapes( @@ -6971,6 +7006,7 @@ def forward(self, x, y): ((3,), (7,)), ((2,), (6,)), ], + test_serdes=True, ) def test_automatic_dynamic_shapes_linear_relation(self): @@ -7003,6 +7039,7 @@ def forward(self, x, y): ((34,), (8,)), ((22,), (5,)), ], + test_serdes=False, ) # static s1 shouldn't actually specialize s0 (guard: (s0 + 2) // 4 == 5) self._check_dynamic_shapes_specs_and_shapes( @@ -7021,6 +7058,7 @@ def forward(self, x, y): failing_shapes=[ ((33,), (8,)), ], + test_serdes=False, ) # but static s0 will definitely specialize s1 (guard: (21 + 2) // 4 == s1 -> 5 == s1) self._check_dynamic_shapes_specs_and_shapes( @@ -7035,7 +7073,218 @@ def forward(self, x, y): failing_shapes=[ ((22,), (5,)), ], + test_serdes=True, + ) + + def test_dynamic_shapes_serdes_generic(self): + from torch._export.serde.dynamic_shapes import ( + _dump_dynamic_shapes, + _load_dynamic_shapes, + ) + + class Foo(torch.nn.Module): + def forward(self, a, b, c, d): + if d == "hello": + x = a[0] + a[1][1:] + b = torch.cat([b, b], dim=0).reshape([-1, 1]) + return x + b, c * 2 + + # test de/serialization on some generic specs + dz = Dim("dz", min=4, max=16) + dx = 2 * dz + dy = dx + 1 + inputs = ( + [ + torch.randn(8, 4), + torch.randn(9, 4), + ], + torch.randn(4), + torch.randn(4, 4), + "hello", ) + dynamic_shapes = { + "a": [ + (dx, 4), + (dy, 4), + ], + "b": (dz,), + "c": None, + "d": None, + } + ep = export(Foo(), inputs, dynamic_shapes=dynamic_shapes) + self._check_dynamic_shapes_specs_and_shapes( + Foo(), + inputs, + [dynamic_shapes], + [ + ([(16, 4), (17, 4)], (8,), (4, 4), "hello"), + ([(24, 4), (25, 4)], (12,), (4, 4), "hello"), + ], + [ + ([(16, 4), (17, 4)], (8,), (5, 5), "hello"), + ], + test_serdes=True, + ) + self.assertExpectedInline( + _dump_dynamic_shapes(dynamic_shapes, inputs), + """DynamicShapesSpec(dynamic_shapes=([['2*dz', 4], ['2*dz + 1', 4]], ['dz'], ['_DimHint.STATIC', '_DimHint.STATIC'], None), dims={'dz': RootDim(min=4, max=16, derived=['2*dz', '2*dz + 1'])})""", + ) + self.assertExpectedInline( + _dump_dynamic_shapes(dynamic_shapes, inputs, to_dict=True), + """{'dynamic_shapes': ([['2*dz', 4], ['2*dz + 1', 4]], ['dz'], ['_DimHint.STATIC', '_DimHint.STATIC'], None), 'dims': {'dz': {'min': 4, 'max': 16, 'derived': ['2*dz', '2*dz + 1']}}}""", + ) + ((dx, _), (dy, _)), (dz,), (_, _), _ = _load_dynamic_shapes( + _dump_dynamic_shapes(dynamic_shapes, inputs) + ) + self.assertEqual(dx.root, dz) + self.assertEqual(dy.root, dz) + + def test_dynamic_shapes_serdes_various(self): + # serialization for dataclass inputs, Dim.AUTO/STATIC, and kwargs + from torch._export.serde.dynamic_shapes import ( + _dump_dynamic_shapes, + _load_dynamic_shapes, + ) + + auto, static = Dim.AUTO, Dim.STATIC + + @dataclass + class Input: + a: Tensor + b: Tensor + + register_dataclass_as_pytree_node( + Input, + serialized_type_name="test_dynamic_shapes_serdes_various.Input", + ) + + class Foo(torch.nn.Module): + def forward(self, x, y, z): + return x - torch.randn(4), y.a + y.b + z[1:] + + args = (torch.randn(4, 4),) + kwargs = { + "y": Input(a=torch.randn(8, 8), b=torch.randn(8, 8)), + "z": torch.randn(9, 8), + } + dynamic_shapes = { + "x": (auto, static), + "y": [(auto, auto), (auto, auto)], + "z": (auto, 8), + } + + # dump dynamic_shapes + self.assertExpectedInline( + _dump_dynamic_shapes(dynamic_shapes, args, kwargs), + """DynamicShapesSpec(dynamic_shapes=(['_DimHint.AUTO', '_DimHint.STATIC'], [['_DimHint.AUTO', '_DimHint.AUTO'], ['_DimHint.AUTO', '_DimHint.AUTO']], ['_DimHint.AUTO', 8]), dims={})""", + ) + self.assertExpectedInline( + _dump_dynamic_shapes(dynamic_shapes, args, kwargs, to_dict=True), + """{'dynamic_shapes': (['_DimHint.AUTO', '_DimHint.STATIC'], [['_DimHint.AUTO', '_DimHint.AUTO'], ['_DimHint.AUTO', '_DimHint.AUTO']], ['_DimHint.AUTO', 8]), 'dims': {}}""", + ) + + def test_dynamic_shapes_serdes_user_errors(self): + # check error messages for dynamic shapes de/serialization + from torch._export.serde.dynamic_shapes import ( + _dump_dynamic_shapes, + _load_dynamic_shapes, + DynamicShapesSpec, + RootDim, + ) + from torch._export.serde.serialize import _dataclass_to_dict + + # this stuff should be well tested in `test_mismatched_dynamic_shapes` + with self.assertRaisesRegex( + torch._dynamo.exc.UserError, + re.escape( + "Detected mismatch between the structure of `inputs` and `dynamic_shapes`: `inputs[0]['k']` " + "is a , but `dynamic_shapes[0]['k']` is a " + ), + ): + dynamic_shapes = {"x": {"k": (Dim("dx"), Dim("dy"))}} + _dump_dynamic_shapes(dynamic_shapes, ({"k": [torch.randn(4, 4)]},)) + + # loading with from_dict=True/False + spec = DynamicShapesSpec( + dynamic_shapes=[["dx"]], + dims={"dx": RootDim(min=4, max=16, derived=[])}, + ) + spec_dict = _dataclass_to_dict(spec) + with self.assertRaisesRegex( + torch._dynamo.exc.UserError, + re.escape( + "With from_dict=True, expected `spec` to be a dict, " + "got " + ), + ): + _load_dynamic_shapes(spec, from_dict=True) + + with self.assertRaisesRegex( + torch._dynamo.exc.UserError, + re.escape("Expected `spec` to be a DynamicShapesSpec, got "), + ): + _load_dynamic_shapes(spec_dict, from_dict=False) + + self.assertExpectedInline( + _load_dynamic_shapes(spec, from_dict=False), + """[[]]""", + ) + + # check incorrect info in dims + with self.assertRaisesRegex( + torch._dynamo.exc.UserError, + re.escape( + "Expected dims in `spec['dims']` to map `min` to an int, got dx: None" + ), + ): + spec = { + "dynamic_shapes": [["dx"]], + "dims": { + "dx": { + "min": None, + "max": 4, + "derived": [], + }, + }, + } + _load_dynamic_shapes(spec, from_dict=True) + + with self.assertRaisesRegex( + torch._dynamo.exc.UserError, + re.escape( + "Expected dims in `spec['dynamic_shapes']` to be tracked in `spec['dims']`, " + "got dx which is not in dict_keys(['dy'])" + ), + ): + spec = { + "dynamic_shapes": [["dx"]], + "dims": { + "dy": { + "min": 2, + "max": 4, + "derived": [], + }, + }, + } + _load_dynamic_shapes(spec, from_dict=True) + + with self.assertRaisesRegex( + torch._dynamo.exc.UserError, + re.escape( + "Expected derived expressions to be linear expressions, got dx**2 + 4" + ), + ): + spec = { + "dynamic_shapes": [["dx"]], + "dims": { + "dx": { + "min": 2, + "max": 4, + "derived": ["dx**2 + 4"], + }, + }, + } + _load_dynamic_shapes(spec, from_dict=True) @testing.expectedFailureNonStrict @testing.expectedFailureTrainingIRToRunDecompNonStrict # unbacked symint not tracked? diff --git a/torch/_export/serde/dynamic_shapes.py b/torch/_export/serde/dynamic_shapes.py new file mode 100644 index 00000000000000..f24822d9b07d68 --- /dev/null +++ b/torch/_export/serde/dynamic_shapes.py @@ -0,0 +1,321 @@ +import dataclasses +from typing import Any, Dict, List, Optional, Tuple, Union + +import torch +from torch._dynamo.exc import UserError, UserErrorType +from torch.export.dynamic_shapes import ( + _check_dynamic_shapes, + _DerivedDim, + _Dim, + _DimHint, + _tree_map_with_path, + Dim, +) +from torch.utils._pytree import tree_map + +from .serialize import _dataclass_to_dict + + +@dataclasses.dataclass +class RootDim: + """ + This represents a _Dim object. + """ + + min: int + max: Union[int, None] + derived: List[str] + + +@dataclasses.dataclass +class DynamicShapesSpec: + """ + This stores a dynamic_shapes spec for de/serialization. + """ + + dynamic_shapes: Union[Dict[str, Any], Tuple[Any], List[Any], None] + dims: Dict[str, RootDim] + + +def _postprocess_serialized_shapes( + dynamic_shapes: Union[Dict[str, Any], Tuple[Any], List[Any], None], + dims: Dict[str, Dict[str, Union[int, List[str], None]]], + to_dict: Optional[bool] = False, +) -> Union[DynamicShapesSpec, Dict[str, Any]]: + """ + Sorts dims and dumps to dictionary format. + """ + from torch.utils._sympy.numbers import int_oo + + dims = { + k: RootDim( + min=v["min"], # type: ignore[arg-type] + max=None if v["max"] is int_oo else v["max"], # type: ignore[arg-type] + derived=sorted(v["derived"]), # type: ignore[arg-type] + ) + for k, v in sorted(dims.items()) + } + spec = DynamicShapesSpec(dynamic_shapes=dynamic_shapes, dims=dims) + if to_dict: + return _dataclass_to_dict(spec) + else: + return spec + + +def _dump_dynamic_shapes( + dynamic_shapes: Union[Dict[str, Any], Tuple[Any], List[Any], None], + args: Tuple[Any], + kwargs: Optional[Dict[str, Any]] = None, + to_dict: Optional[bool] = False, +) -> Union[DynamicShapesSpec, Dict[str, Any]]: + """ + Utility function for dynamic shapes serialization, serializing a dynamic_shapes spec. + Returns a DynamicShapesSpec dataclass containing 2 fields, "dynamic_shapes" and "dims". + Uses args & kwargs to distinguish between tensor-level and dim-level specs (only for Nones). + + dynamic_shapes: A pytree structure mirroring the dynamic_shapes input to export(): + - Each tensor input is represented with a list of values, non-tensor inputs with None. + - dynamic dimensions (i.e. symbols) in tensors and Dim enums are represented with strings. + - static dimensions are represented with ints. + + dims: A dictionary mapping each symbol name to the min/max range and derived dim names. + + For example: + ``` + dx = Dim("dx", min=4, max=16) + dy = dx + 1 + + inputs = ( + [ + torch.randn(4, 4), + torch.randn(5, 4), + ], + torch.randn(4), + torch.randn(4, 4), + "hello", + ) + dynamic_shapes = { + "a": [ + (dx, 4), + (dy, 4), + ], + "b": (Dim.STATIC,), + "c": None, + "d": None, + } + out = _dump_dynamic_shapes(dynamic_shapes, inputs, to_dict=True) + ``` + would generate the following output: + ``` + { + 'dynamic_shapes': ( + [ + ['dx', 4], + ['dx + 1', 4], + ], + ['_DimHint.STATIC'], + ['_DimHint.STATIC', '_DimHint.STATIC'], + None, + ), + 'dims': { + 'dx': { + 'min': 4, + 'max': 16, + 'derived': ['dx + 1'], + }, + }, + } + ``` + """ + dims: Dict[str, Dict[str, Any]] = {} + + def _standardize_shapes(path, tensor, shape): # type: ignore[no-untyped-def] + """ + Helps standardize the dynamic_shapes tree structure we serialize, + returning lists for each tensor shape, handling tensor-level Nones. + """ + if not isinstance(tensor, torch.Tensor): + return None + if shape is None: + return [Dim.STATIC] * len(tensor.shape) # type: ignore[attr-defined] + + out = [] + if isinstance(shape, dict): + for i, s in enumerate(tensor.shape): + out.append(s if shape.get(i) is None else shape.get(i)) + else: + assert isinstance(shape, (tuple, list)) + for i, s in enumerate(tensor.shape): + out.append(s if shape[i] is None else shape[i]) + return out + + def _track_dim_from_dims( + val: Union[None, int, _DimHint, _Dim] + ) -> Union[None, int, str]: + """ + Tracks dims, ranges, derived dims from the standardized dynamic_shapes spec. + """ + if val is None or isinstance(val, int): # non-tensor input or static + return val + if isinstance(val, _DimHint): # store enum as string + return val.__class__.__name__ + "." + val.name + + assert isinstance(val, _Dim) + + # track root dim + root = val.root if isinstance(val, _DerivedDim) else val # type: ignore[attr-defined] + if root.__name__ not in dims: + dims[root.__name__] = { + "min": root.min, + "max": root.max, + "derived": set(), + } + + # track derived dims + if isinstance(val, _DerivedDim): + dims[root.__name__]["derived"].add(val.__name__) + + return val.__name__ + + if dynamic_shapes is None: + return {"dynamic_shapes": None, "dims": {}} + + # convert to tuple of specs, for each arg/kwarg + kwargs = kwargs or {} + if isinstance(dynamic_shapes, dict): + dynamic_shapes = dynamic_shapes.values() # type: ignore[assignment] + dynamic_shapes = tuple(dynamic_shapes) + combined_args = tuple(args) + tuple(kwargs.values()) + + # run same check when we're processing shapes for export - is this too lazy? + _check_dynamic_shapes(dict(enumerate(combined_args)), dynamic_shapes) # type: ignore[arg-type] + + tree_shapes = _tree_map_with_path( + _standardize_shapes, combined_args, dynamic_shapes, tree_name="inputs" + ) + serialized_shapes = tree_map(_track_dim_from_dims, tree_shapes) + return _postprocess_serialized_shapes(serialized_shapes, dims, to_dict=to_dict) + + +def _load_dynamic_shapes( + spec: Union[DynamicShapesSpec, Dict[str, Any]], + from_dict: Optional[bool] = False, +) -> Union[Dict[str, Any], Tuple[Any], List[Any], None]: + """ + Utility function for dynamic shapes serialization. + Deserializes a DynamicShapesSpec or corresponding dictionary into a dynamic_shapes input to export(). + """ + import sympy + + from torch.fx.experimental.symbolic_shapes import _is_supported_equivalence + + if from_dict: + if not isinstance(spec, dict): + raise UserError( + UserErrorType.INVALID_INPUT, + f"With from_dict=True, expected `spec` to be a dict, got {type(spec)}", + ) + if sorted(spec.keys()) != ["dims", "dynamic_shapes"]: + raise UserError( + UserErrorType.INVALID_INPUT, + "With from_dict=True, expected `spec` to have keys `dims` and `dynamic_shapes`, " + f"instead found {spec.keys()}", + ) + dims = {} + for k, v in spec["dims"].items(): + if not isinstance(k, str): + raise UserError( + UserErrorType.INVALID_INPUT, + f"Expected `spec['dims']` keys to be strings for symbols, got key {type(k)}", + ) + if sorted(v.keys()) != ["derived", "max", "min"]: + raise UserError( + UserErrorType.INVALID_INPUT, + f"Expected `spec['dims']` values to have keys `derived`, `max`, and `min`, " + f"instead found {v.keys()}", + ) + if not isinstance(v["min"], int): + raise UserError( + UserErrorType.INVALID_INPUT, + f"Expected dims in `spec['dims']` to map `min` to an int, got {k}: {v['min']}", + ) + if not isinstance(v["max"], int) or v["max"] is None: + raise UserError( + UserErrorType.INVALID_INPUT, + f"Expected dims in `spec['dims']` to map `max` to an int or None, got {k}: {v['max']}", + ) + if not isinstance(v["derived"], list) or any( + not isinstance(d, str) for d in v["derived"] + ): + raise UserError( + UserErrorType.INVALID_INPUT, + "Expected dims in `spec['dims']` to map `derived` to a list of derived expressions, " + f"got {k}: {v['derived']}", + ) + dims[k] = RootDim(**v) + dynamic_shapes = spec["dynamic_shapes"] + else: + if not isinstance(spec, DynamicShapesSpec): + raise UserError( + UserErrorType.INVALID_INPUT, + f"Expected `spec` to be a DynamicShapesSpec, got {type(spec)}", + ) + dims = spec.dims + dynamic_shapes = spec.dynamic_shapes + + if dynamic_shapes is None: + return None + + dim_cache = {} + for name, info in dims.items(): + symbol = sympy.sympify(name) + if not isinstance(symbol, sympy.Symbol): + raise UserError( + UserErrorType.INVALID_INPUT, + f"Expected `spec['dims']` keys to be symbols, got {name}", + ) + dim_cache[name] = Dim(name, min=info.min, max=info.max) # cache root dim + for _expr in info.derived: + expr = sympy.sympify(_expr) + if len(expr.free_symbols) != 1 or symbol not in expr.free_symbols: + raise UserError( + UserErrorType.INVALID_INPUT, + f"Expected derived expressions in to have {name} as the only free symbol, got {expr}", + ) + if not _is_supported_equivalence(expr): + raise UserError( + UserErrorType.INVALID_INPUT, + f"Expected derived expressions to be linear expressions, got {expr}", + ) + modulus, remainder = sympy.polys.polytools.div(expr, symbol) + ddim = dim_cache[name] + if modulus != 1: + ddim = int(modulus) * ddim + if remainder != 0: + ddim = ddim + int(remainder) + dim_cache[_expr] = ddim # cache derived dims + + def deserialize_shape( + val: Union[None, int, str] + ) -> Union[None, int, _Dim, _DimHint]: + if val is None or isinstance(val, int): + return val + elif val == "_DimHint.AUTO": + return _DimHint.AUTO + elif val == "_DimHint.STATIC": + return _DimHint.STATIC + if not isinstance(val, str): + raise UserError( + UserErrorType.INVALID_INPUT, + "Expected leaves in `spec['dynamic_shapes']` to be ints, None, Dim.AUTO/STATIC, symbols, " + f" or derived expressions, got {val}", + ) + if val not in dim_cache: + raise UserError( + UserErrorType.INVALID_INPUT, + "Expected dims in `spec['dynamic_shapes']` to be tracked in `spec['dims']`, " + f"got {val} which is not in {dims.keys()}", + ) + return dim_cache[val] + + return tree_map(deserialize_shape, dynamic_shapes) From 5e55d56c7c05e19a2ff19606638458187b7f3993 Mon Sep 17 00:00:00 2001 From: "Sun, Jiayi" Date: Wed, 4 Sep 2024 19:49:37 -0700 Subject: [PATCH 0450/1018] [Inductor] support masked vectorization for the tail_loop for dynamic shapes (#131745) Pull Request resolved: https://github.com/pytorch/pytorch/pull/131745 Approved by: https://github.com/jgong5, https://github.com/leslie-fang-intel, https://github.com/jansel --- .../src/ATen/cpu/vec/vec512/vec512_bfloat16.h | 120 +++++++--------- aten/src/ATen/cpu/vec/vec512/vec512_float.h | 39 ++++-- aten/src/ATen/cpu/vec/vec_base.h | 9 +- aten/src/ATen/cpu/vec/vec_mask.h | 11 ++ test/inductor/test_torchinductor.py | 4 +- torch/_inductor/codegen/cpp.py | 132 +++++++++++------- 6 files changed, 186 insertions(+), 129 deletions(-) diff --git a/aten/src/ATen/cpu/vec/vec512/vec512_bfloat16.h b/aten/src/ATen/cpu/vec/vec512/vec512_bfloat16.h index d2006fcf1c4755..dcdb682c56208d 100644 --- a/aten/src/ATen/cpu/vec/vec512/vec512_bfloat16.h +++ b/aten/src/ATen/cpu/vec/vec512/vec512_bfloat16.h @@ -1348,41 +1348,20 @@ static inline void _transpose_mxn_half_32_32(__m512i r[], __m512i d[]) { // Code referred to FBGEMM: // https://github.com/pytorch/FBGEMM/blob/39a423e4ad1a04b77fea81c7d09c3e6f8984fae9/src/UtilsAvx512.cc#LL19C6-L19C6 template<> -inline void transpose_mxn( - const BFloat16* src, - int64_t ld_src, - BFloat16* dst, - int64_t ld_dst) { - // Load from memory - __m512i r[32]; -#ifndef __msvc_cl__ -#pragma unroll(32) -#endif - for (int i = 0; i < 32; ++i) { - r[i] = _mm512_loadu_si512(reinterpret_cast(src + i* ld_src)); - } - - __m512i d[32]; - _transpose_mxn_half_32_32(r, d); - - // Store to dst -#ifndef __msvc_cl__ -#pragma unroll(32) -#endif - for (int i = 0; i < 32; ++i) { - _mm512_storeu_si512(dst + i* ld_dst, d[i]); - } -} - -template ::value && ((M < 32 && M != 16) || (N < 32 && N != 16)), int> = 0> -inline void transpose_mxn(const BFloat16* src, int64_t ld_src, BFloat16* dst, int64_t ld_dst) { +inline void transpose_mxn(const BFloat16* src, int64_t ld_src, BFloat16* dst, int64_t ld_dst, int M, int N) { // load from src - __mmask32 src_mask = (1 << N) - 1; + TORCH_CHECK(M <= 32 && N <= 32, "transpose_mxn expects M, N <= 32."); __m512i r[32]; int i; - for (i = 0; i < M; ++i) { - r[i] = _mm512_maskz_loadu_epi16(src_mask, &src[i * ld_src]); + if (N == 32) { + for (i = 0; i < M; ++i) { + r[i] = _mm512_loadu_si512(&src[i * ld_src]); + } + } else { + __mmask32 src_mask = (1 << N) - 1; + for (i = 0; i < M; ++i) { + r[i] = _mm512_maskz_loadu_epi16(src_mask, &src[i * ld_src]); + } } for (; i < 32; ++i) { r[i] = _mm512_setzero_si512(); @@ -1392,48 +1371,39 @@ inline void transpose_mxn(const BFloat16* src, int64_t ld_src, BFloat16* dst, in _transpose_mxn_half_32_32(r, d); // store to dst - __mmask32 dst_mask = (1 << M) - 1; - for (i = 0; i < N; ++i) { - _mm512_mask_storeu_epi16(&dst[i * ld_dst], dst_mask, d[i]); + if (M == 32) { + for (i = 0; i < N; ++i) { + _mm512_storeu_si512(&dst[i * ld_dst], d[i]); + } + } else { + __mmask32 dst_mask = (1 << M) - 1; + for (i = 0; i < N; ++i) { + _mm512_mask_storeu_epi16(&dst[i * ld_dst], dst_mask, d[i]); + } } } -template<> -inline void transpose_mxn( - const Half* src, - int64_t ld_src, - Half* dst, - int64_t ld_dst) { - // Load from memory - __m512i r[32]; -#ifndef __msvc_cl__ -#pragma unroll(32) -#endif - for (int i = 0; i < 32; ++i) { - r[i] = _mm512_loadu_si512(reinterpret_cast(src + i* ld_src)); - } - - __m512i d[32]; - _transpose_mxn_half_32_32(r, d); - - // Store to dst -#ifndef __msvc_cl__ -#pragma unroll(32) -#endif - for (int i = 0; i < 32; ++i) { - _mm512_storeu_si512(dst + i* ld_dst, d[i]); - } +template ::value && ((M <= 32 && M != 16) || (N <= 32 && N != 16)), int> = 0> +inline void transpose_mxn(const BFloat16* src, int64_t ld_src, BFloat16* dst, int64_t ld_dst) { + transpose_mxn(src, ld_src, dst, ld_dst, M, N); } -template ::value && ((M < 32 && M != 16) || (N < 32 && N != 16)), int> = 0> -inline void transpose_mxn(const Half* src, int64_t ld_src, Half* dst, int64_t ld_dst) { +template<> +inline void transpose_mxn(const Half* src, int64_t ld_src, Half* dst, int64_t ld_dst, int M, int N) { + TORCH_CHECK(M <= 32 && N <= 32, "transpose_mxn expects M, N <= 32."); // load from src - __mmask32 src_mask = (1 << N) - 1; __m512i r[32]; int i; - for (i = 0; i < M; ++i) { - r[i] = _mm512_maskz_loadu_epi16(src_mask, &src[i * ld_src]); + if (N == 32) { + for (i = 0; i < M; ++i) { + r[i] = _mm512_loadu_si512(&src[i * ld_src]); + } + } else { + __mmask32 src_mask = (1 << N) - 1; + for (i = 0; i < M; ++i) { + r[i] = _mm512_maskz_loadu_epi16(src_mask, &src[i * ld_src]); + } } for (; i < 32; ++i) { r[i] = _mm512_setzero_si512(); @@ -1443,12 +1413,24 @@ inline void transpose_mxn(const Half* src, int64_t ld_src, Half* dst, int64_t ld _transpose_mxn_half_32_32(r, d); // store to dst - __mmask32 dst_mask = (1 << M) - 1; - for (i = 0; i < N; ++i) { - _mm512_mask_storeu_epi16(&dst[i * ld_dst], dst_mask, d[i]); + if (M == 32) { + for (i = 0; i < N; ++i) { + _mm512_storeu_si512(&dst[i * ld_dst], d[i]); + } + } else { + __mmask32 dst_mask = (1 << M) - 1; + for (i = 0; i < N; ++i) { + _mm512_mask_storeu_epi16(&dst[i * ld_dst], dst_mask, d[i]); + } } } +template ::value && ((M <= 32 && M != 16) || (N <= 32 && N != 16)), int> = 0> +inline void transpose_mxn(const Half* src, int64_t ld_src, Half* dst, int64_t ld_dst) { + transpose_mxn(src, ld_src, dst, ld_dst, M, N); +} + template <> class Vectorized: public Vectorized16 { public: diff --git a/aten/src/ATen/cpu/vec/vec512/vec512_float.h b/aten/src/ATen/cpu/vec/vec512/vec512_float.h index 289f927a10689b..804e9f404297ab 100644 --- a/aten/src/ATen/cpu/vec/vec512/vec512_float.h +++ b/aten/src/ATen/cpu/vec/vec512/vec512_float.h @@ -582,15 +582,21 @@ Vectorized inline fmsub(const Vectorized& a, const Vectorized::value && M <= 16 && N <= 16, int> = 0> -inline void transpose_mxn(const float* src, int64_t ld_src, float* dst, int64_t ld_dst) { +template <> +inline void transpose_mxn(const float* src, int64_t ld_src, float* dst, int64_t ld_dst, int M, int N) { + TORCH_CHECK(M <= 16 && N <= 16, "transpose_mxn expects M, N <= 16."); // load from src to registers - __mmask16 src_mask = (1 << N) - 1; __m512 input[16]; int i; - for (i = 0; i < M; ++i) { - input[i] = _mm512_maskz_loadu_ps(src_mask, &src[i * ld_src]); + if (N == 16) { + for (i = 0; i < M; ++i) { + input[i] = _mm512_loadu_ps(&src[i * ld_src]); + } + } else { + __mmask16 src_mask = (1 << N) - 1; + for (i = 0; i < M; ++i) { + input[i] = _mm512_maskz_loadu_ps(src_mask, &src[i * ld_src]); + } } for (; i < 16; ++i) { // Not really needed but to avoid uninitialized variable warning. @@ -640,16 +646,31 @@ inline void transpose_mxn(const float* src, int64_t ld_src, float* dst, int64_t _mm512_shuffle_f32x4(input[8 * i + 3], input[8 * i + 7], 0xdd); } - // store from registers to dst - __mmask16 dst_mask = (1 << M) - 1; for (i = 0; i < N; ++i) { if (i < 8) { input[i] = _mm512_shuffle_f32x4(temp[i], temp[8 + i], 0x88); } else { input[i] = _mm512_shuffle_f32x4(temp[i - 8], temp[i], 0xdd); } - _mm512_mask_storeu_ps(&dst[i * ld_dst], dst_mask, input[i]); } + + // store from registers to dst + if (M == 16) { + for (i = 0; i < N; ++i) { + _mm512_storeu_ps(&dst[i * ld_dst], input[i]); + } + } else { + __mmask16 dst_mask = (1 << M) - 1; + for (i = 0; i < N; ++i) { + _mm512_mask_storeu_ps(&dst[i * ld_dst], dst_mask, input[i]); + } + } +} + +template ::value && M <= 16 && N <= 16, int> = 0> +inline void transpose_mxn(const float* src, int64_t ld_src, float* dst, int64_t ld_dst) { + transpose_mxn(src, ld_src, dst, ld_dst, M, N); } #endif diff --git a/aten/src/ATen/cpu/vec/vec_base.h b/aten/src/ATen/cpu/vec/vec_base.h index 86aaba490a8635..c4d8b1ccf5de49 100644 --- a/aten/src/ATen/cpu/vec/vec_base.h +++ b/aten/src/ATen/cpu/vec/vec_base.h @@ -1137,8 +1137,8 @@ inline Vectorized flip(const Vectorized & data) { // Transpose the `src` buffer of type `T` and size (M,N) into the `dst` buffer. `ld_src` is the leading // dimension of `src` and `ld_dst` is the leading dimension of `dst`. -template -inline void transpose_mxn(const T* src, int64_t ld_src, T* dst, int64_t ld_dst) { +template +inline void transpose_mxn(const T* src, int64_t ld_src, T* dst, int64_t ld_dst, int M, int N) { for (int i = 0; i < M; i++) { for (int j = 0; j < N; j++) { dst[j*ld_dst + i] = src[i*ld_src + j]; @@ -1146,6 +1146,11 @@ inline void transpose_mxn(const T* src, int64_t ld_src, T* dst, int64_t ld_dst) } } +template +inline void transpose_mxn(const T* src, int64_t ld_src, T* dst, int64_t ld_dst) { + transpose_mxn(src, ld_src, dst, ld_dst, M, N); +} + }} // namespace at::vec::CPU_CAPABILITY // additional headers for more operations that depend on vec_base diff --git a/aten/src/ATen/cpu/vec/vec_mask.h b/aten/src/ATen/cpu/vec/vec_mask.h index 2783417845492c..a39ffa3090b8eb 100644 --- a/aten/src/ATen/cpu/vec/vec_mask.h +++ b/aten/src/ATen/cpu/vec/vec_mask.h @@ -169,6 +169,17 @@ class VecMask { return result; } + static VecMask set( + const VecMask& a, + const VecMask& b, + int64_t count = size()) { + VectorizedN result = VectorizedN::set( + VectorizedN(a), + VectorizedN(b), + count); + return result; + } + void store(bool* b, int count = size()) { constexpr int L = (VectorizedN::size() + Vectorized::size() - 1)/ Vectorized::size(); auto res = this->to(); diff --git a/test/inductor/test_torchinductor.py b/test/inductor/test_torchinductor.py index 698f3bfde5d9ae..ffd4dcb0ebc3e0 100644 --- a/test/inductor/test_torchinductor.py +++ b/test/inductor/test_torchinductor.py @@ -665,7 +665,7 @@ def _run_and_assert_no_indirect_indexing( ) if has_wrapping is not None: test_case.assertTrue( - ("where" in code or "?" in code) is has_wrapping, + ("where" in code or ") ? (" in code) is has_wrapping, msg=f"Wanted {has_wrapping=} but got\n{code}", ) test_case.assertTrue( @@ -1517,7 +1517,7 @@ def test( pass # no device asserts in halide elif self.device == "cpu": _, code = run_and_get_cpp_code(fn_opt, *inps) - self.assertTrue(("?" in code or "blendv" in code) is has_wrapping) + self.assertTrue((") ? (" in code or "blendv" in code) is has_wrapping) self.assertTrue(("TORCH_CHECK" in code) is has_assert) # Assert that we always vectorize the kernel regardless of wrapping / checks self.assertTrue(("loadu" in code) is vectorize) diff --git a/torch/_inductor/codegen/cpp.py b/torch/_inductor/codegen/cpp.py index 999df248fc6110..944cd972f63f8e 100644 --- a/torch/_inductor/codegen/cpp.py +++ b/torch/_inductor/codegen/cpp.py @@ -1586,22 +1586,26 @@ def frexp(x): code.writeline(f"{mantissa_t} {mantissa};") code.writeline("[&]()") with code.indent(): - code.writeline(f"__at_align__ std::array<{cdtype}, {size}> tmpbuf;") - code.writeline(f"{x}.store(tmpbuf.data(), {size});") - code.writeline(f"__at_align__ std::array tmpbuf_exponent;") code.writeline( - f"__at_align__ std::array<{cdtype}, {size}> tmpbuf_mantissa;" + f"__at_align__ std::array<{cdtype}, {V.kernel.tiling_factor}> tmpbuf;" ) - code.writeline(f"for (int i = 0; i < {size}; i++)") + code.writeline(f"{x}.store(tmpbuf.data(), {cexpr_index(size)});") + code.writeline( + f"__at_align__ std::array tmpbuf_exponent;" + ) + code.writeline( + f"__at_align__ std::array<{cdtype}, {V.kernel.tiling_factor}> tmpbuf_mantissa;" + ) + code.writeline(f"for (int i = 0; i < {cexpr_index(size)}; i++)") with code.indent(): code.writeline( "tmpbuf_mantissa[i] = std::frexp(tmpbuf[i], &tmpbuf_exponent[i]);" ) code.writeline( - f"{exponent} = at::vec::Vectorized::loadu(tmpbuf_exponent.data(), {size});" + f"{exponent} = at::vec::Vectorized::loadu(tmpbuf_exponent.data(), {cexpr_index(size)});" ) code.writeline( - f"{mantissa} = {mantissa_t}::loadu(tmpbuf_mantissa.data(), {size});" + f"{mantissa} = {mantissa_t}::loadu(tmpbuf_mantissa.data(), {cexpr_index(size)});" ) code.writeline("();") V.kernel.compute.splice(code) @@ -1635,15 +1639,19 @@ def inner(*args, **kwargs): assert arg.is_vec assert arg.dtype == vec_dtype code.writeline( - f"__at_align__ std::array<{cdtype}, {size}> tmpbuf{argidx};" + f"__at_align__ std::array<{cdtype}, {kernel.tiling_factor}> tmpbuf{argidx};" + ) + code.writeline( + f"{arg}.store(tmpbuf{argidx}.data(), {cexpr_index(size)});" ) - code.writeline(f"{arg}.store(tmpbuf{argidx}.data(), {size});") scalar_args.append(f"tmpbuf{argidx}[i]") else: scalar_args.append(arg) - code.writeline(f"__at_align__ std::array<{octype}, {size}> tmpbuf_out;") + code.writeline( + f"__at_align__ std::array<{octype}, {kernel.tiling_factor}> tmpbuf_out;" + ) res = scalar_func(*scalar_args) - code.writeline(f"for (int i = 0; i < {size}; i++)") + code.writeline(f"for (int i = 0; i < {cexpr_index(size)}; i++)") with code.indent(): code.writeline(f"tmpbuf_out[i] = {res};") if output_mask: @@ -1651,7 +1659,7 @@ def inner(*args, **kwargs): load_args = "tmpbuf_out.data()" load_fn = f"at::vec::VecMask<{cdtype},{n_vec}>::from" else: - load_args = f"tmpbuf_out.data(), {size}" + load_args = f"tmpbuf_out.data(), {cexpr_index(size)}" if n_vec == 1: load_fn = f"at::vec::Vectorized<{cdtype}>::loadu" else: @@ -2279,7 +2287,7 @@ def _get_vec_load_line( line = ( f"{load_mask_str}.template loadu<{cpp_type},{num_vectors}>({loadbuf})" if load_mask_str - else f"{self._get_vec_type(dtype)}::loadu({loadbuf}, {self.num_elems})" + else f"{self._get_vec_type(dtype)}::loadu({loadbuf}, {cexpr_index(self.num_elems)})" ) return line @@ -2320,6 +2328,12 @@ def get_result_size(dtype: torch.dtype) -> int: else: return self.num_elems + def get_tiling_size(dtype: torch.dtype) -> int: + if dtype.itemsize < 4: + return self.tiling_factor * (4 // dtype.itemsize) + else: + return self.tiling_factor + def vec_to_array(vec_var: CppCSEVariable) -> CppCSEVariable: assert vec_var.is_vec code = BracesBuffer() @@ -2330,10 +2344,11 @@ def vec_to_array(vec_var: CppCSEVariable) -> CppCSEVariable: if vec_dtype == torch.bool: vec_dtype = torch.float result_size = get_result_size(vec_dtype) + tiling_size = get_tiling_size(vec_dtype) code.writeline( - f"__at_align__ std::array<{DTYPE_TO_CPP[vec_dtype]}, {result_size}> tmpbuf;" + f"__at_align__ std::array<{DTYPE_TO_CPP[vec_dtype]}, {tiling_size}> tmpbuf;" ) - line = f"{vec_var}.store(tmpbuf.data(), {result_size});" + line = f"{vec_var}.store(tmpbuf.data(), {cexpr_index(result_size)});" code.writeline(line) code.writeline("return tmpbuf;") code.writeline("()") @@ -2345,12 +2360,15 @@ def vec_to_array(vec_var: CppCSEVariable) -> CppCSEVariable: code.writeline("[&]") with code.indent(): result_size = get_result_size(dtype) + tiling_size = get_tiling_size(dtype) result_declare = ( - f"__at_align__ std::array<{DTYPE_TO_CPP[dtype]}, {result_size}> tmpbuf;" + f"__at_align__ std::array<{DTYPE_TO_CPP[dtype]}, {tiling_size}> tmpbuf;" ) code.writeline(result_declare) if store_value: - code.writeline(f"{store_value}.store(tmpbuf.data(), {result_size});") + code.writeline( + f"{store_value}.store(tmpbuf.data(), {cexpr_index(result_size)});" + ) itervar_inner = sympy_index_symbol( f"{self.itervars[self.tiling_idx]}_inner" ) @@ -2376,12 +2394,12 @@ def vec_to_array(vec_var: CppCSEVariable) -> CppCSEVariable: else: load_mask = f"{self._load_mask} != 0" if cpp_builder.is_gcc(): - code.writeline(f"#pragma GCC unroll {self.num_elems}") + code.writeline(f"#pragma GCC unroll {self.tiling_factor}") else: - code.writeline(f"#pragma unroll {self.num_elems}") + code.writeline(f"#pragma unroll {self.tiling_factor}") code.writeline( f"for (long {itervar_inner} = 0; " - + f"{itervar_inner} < {self.num_elems}; " + + f"{itervar_inner} < {cexpr_index(self.num_elems)}; " + f"{itervar_inner}++)" ) with code.indent(), contextlib.ExitStack() as stack: @@ -2463,7 +2481,9 @@ def _get_store_line( if dtype == torch.float and self.tail_size is None: code.writeline(f"{value}.store({var_expr});") else: - code.writeline(f"{value}.store({var_expr}, {self.num_elems});") + code.writeline( + f"{value}.store({var_expr}, {cexpr_index(self.num_elems)});" + ) else: self._load_or_store_non_contiguous( var, index, dtype, buffer=code, store_value=value, accu_store=accu_store @@ -2801,7 +2821,7 @@ def reduction_combine_vec( is_bool = src_dtype == torch.bool if reduction_type == "max": if self.tail_size: - return f"max_masked_reduce({var}, {next_value}, {self.tail_size})" + return f"max_masked_reduce({var}, {next_value}, {cexpr_index(self.tail_size)})" else: return ( f"{var} | {next_value}" @@ -2810,7 +2830,7 @@ def reduction_combine_vec( ) elif reduction_type == "min": if self.tail_size: - return f"min_masked_reduce({var}, {next_value}, {self.tail_size})" + return f"min_masked_reduce({var}, {next_value}, {cexpr_index(self.tail_size)})" else: return ( f"{var} & {next_value}" @@ -2819,29 +2839,29 @@ def reduction_combine_vec( ) elif reduction_type == "sum": if self.tail_size: - return f"sum_masked_reduce({var}, {next_value}, {self.tail_size})" + return f"sum_masked_reduce({var}, {next_value}, {cexpr_index(self.tail_size)})" else: conjunction = "|" if is_bool else "+" return f"{var} {conjunction} {next_value}" elif reduction_type == "prod": if self.tail_size: - return f"prod_masked_reduce({var}, {next_value}, {self.tail_size})" + return f"prod_masked_reduce({var}, {next_value}, {cexpr_index(self.tail_size)})" else: return f"{var} * {next_value}" elif reduction_type == "xor_sum": if self.tail_size: - return f"xor_sum_masked_reduce({var}, {next_value}, {self.tail_size})" + return f"xor_sum_masked_reduce({var}, {next_value}, {cexpr_index(self.tail_size)})" else: return f"{var} ^ {next_value}" elif reduction_type == "welford_reduce": if use_weight_recps: if self.tail_size: - return f"welford_combine({var}, {next_value}, {self.tail_size}, &{self.weight_recps_val})" + return f"welford_combine({var}, {next_value}, {cexpr_index(self.tail_size)}, &{self.weight_recps_val})" else: return f"welford_combine({var}, {next_value}, &{self.weight_recps_val})" else: if self.tail_size: - return f"welford_combine({var}, {next_value}, {self.tail_size})" + return f"welford_combine({var}, {next_value}, {cexpr_index(self.tail_size)})" else: return f"welford_combine({var}, {next_value})" elif reduction_type == "welford_combine": @@ -2852,7 +2872,7 @@ def reduction_combine_vec( # When combining intermediate accumulators we have a Welford struct mean, m2, weight = reduction_project(reduction_type, next_value) if self.tail_size: - return f"welford_combine({var}, {{{mean}, {m2}, {weight}}}, {self.tail_size})" + return f"welford_combine({var}, {{{mean}, {m2}, {weight}}}, {cexpr_index(self.tail_size)})" else: return f"welford_combine({var}, {{{mean}, {m2}, {weight}}})" elif reduction_type in ("argmin", "argmax"): @@ -2869,7 +2889,7 @@ def reduction_combine_vec( if self.tail_size: return ( f"{reduction_type}_combine_vec<{cdtype}, {n_src}, {n_idx}{t_extra}>" - f"({var}, {next_value}{arg_extra}, {self.tail_size})" + f"({var}, {next_value}{arg_extra}, {cexpr_index(self.tail_size)})" ) else: return f"{reduction_type}_combine_vec<{cdtype}, {n_src}, {n_idx}{t_extra}>({var}, {next_value}{arg_extra})" @@ -2907,6 +2927,11 @@ def indirect_assert(self, var, lower, upper, mask=None): mask = f"{self._get_mask_type(var.dtype)}({mask})" # We need not check when the mask is False cond = f"({cond}) | ~({mask})" + if self.tail_size: + cond = ( + f"{self._get_mask_type(var.dtype)}::set({self._get_mask_type(var.dtype)}::from(1)" + f", ({cond}), {cexpr_index(self.tail_size)})" + ) cond = f"({cond}).all_masked()" return f'{self.assert_function}({cond}, "index out of bounds: {cond_print}")' @@ -3010,20 +3035,29 @@ def gen_transposed_tile_load_store(self, name, var, index, is_store): src = f"{var} + {cexpr_index(index)}" dst = "__place_holder__" ld_src = f"{cexpr_index(stride_at_vec_range(index, self.itervars[self.tiling_idx], self.tiling_factor))}" - ld_dst = f"{self.num_elems}" + ld_dst = f"{cexpr_index(self.num_elems)}" if is_store: src, dst = dst, src ld_src, ld_dst = ld_dst, ld_src need_define = True if self.inner_is_tiling_idx ^ is_store: + M, N = self.inner_num_elems, self.outer_num_elems + else: + M, N = ( + self.outer_num_elems, + self.inner_num_elems, + ) + if (isinstance(M, sympy.Expr) and not M.is_number) or ( + isinstance(N, sympy.Expr) and not N.is_number + ): load_or_store = ( - f"at::vec::transpose_mxn<{DTYPE_TO_CPP[dtype]},{self.inner_num_elems},{self.outer_num_elems}>" - f"({src}, {ld_src}, {dst}, {ld_dst});" + f"at::vec::transpose_mxn<{DTYPE_TO_CPP[dtype]}>" + f"({src}, {ld_src}, {dst}, {ld_dst}, {cexpr_index(M)}, {cexpr_index(N)});" ) else: load_or_store = ( - f"at::vec::transpose_mxn<{DTYPE_TO_CPP[dtype]},{self.outer_num_elems},{self.inner_num_elems}>" + f"at::vec::transpose_mxn<{DTYPE_TO_CPP[dtype]},{cexpr_index(M)},{cexpr_index(N)}>" f"({src}, {ld_src}, {dst}, {ld_dst});" ) if is_store: @@ -3035,7 +3069,10 @@ def gen_transposed_tile_load_store(self, name, var, index, is_store): tile_var = self.cse.cache[load_or_store] if need_define: - define_line = f"alignas({factor}) {DTYPE_TO_CPP[dtype]} {tile_var}[{self.outer_num_elems}*{self.inner_num_elems}];" + define_line = ( + f"alignas({factor}) {DTYPE_TO_CPP[dtype]} {tile_var}" + f"[{cexpr_index(self.outer_num_elems * self.inner_num_elems)}];" + ) self.preloads.writeline(define_line) load_or_store = load_or_store.replace("__place_holder__", str(tile_var)) @@ -3085,7 +3122,7 @@ def store(self, name, index, value, mode=None): torch.uint8, torch.int8, ]: - line = f"{value}.store({storebuf}, {self.num_elems});" + line = f"{value}.store({storebuf}, {cexpr_index(self.num_elems)});" else: line = f"{value}.store({storebuf});" self.stores.writeline(DeferredLine(name, line)) @@ -3097,11 +3134,11 @@ def codegen_inner_loops(self, code): inner = self.inner_itervar() if self.inner_is_tiling_idx: code.writeline( - f"for (long {inner} = 0; {inner} < {self.outer_num_elems}; {inner}++)" + f"for (long {inner} = 0; {inner} < {cexpr_index(self.outer_num_elems)}; {inner}++)" ) else: code.writeline( - f"for (long {inner} = 0; {inner} < {self.inner_num_elems}; {inner}++)" + f"for (long {inner} = 0; {inner} < {cexpr_index(self.inner_num_elems)}; {inner}++)" ) def set_ranges(self, group, reduction_group): @@ -3643,9 +3680,6 @@ def run(kernel): if any(dtype not in MASKED_VECTORIZABLE_DTYPES for dtype in all_dtypes): # can be removed after masked vectorizable dtype are same with vectorizable dtype could_masked_vec = False - if has_free_symbols(self.ranges): - # can be removed after masked vec support dynamic ranges - could_masked_vec = False if len(tiling_indices) == 1: vec_kernel = codegen_kernel( @@ -3657,11 +3691,7 @@ def run(kernel): ) main_loop.set_kernel(vec_kernel) main_loop.simd_vec = True - if ( - config.cpp.enable_loop_tail_vec - and could_masked_vec - and (tail_loop.size - tail_loop.offset) >= 4 - ): + if config.cpp.enable_loop_tail_vec and could_masked_vec: tail_loop.steps = tail_loop.size - tail_loop.offset masked_vec_kernel = codegen_kernel( CppVecKernel, @@ -4779,7 +4809,15 @@ def lines(self): line1 = "" offset_str = f"{INDEX_TYPE} {self.var}={offset_expr}" size_str = f"{self.var}<{size_expr}" - steps_str = f"{self.var}+={cexpr_index(self.steps)}" + if self.steps.is_number: + steps_str = f"{self.var}+={cexpr_index(self.steps)}" + else: + # If the step size is 0, change it to 1 because a step size of 0 + # will cause floating point exception (core dump) during parallelization. + steps_str = ( + f"{self.var}+=({cexpr_index(self.steps)} == 0 ? " + f"1 : {cexpr_index(self.steps)})" + ) line2 = f"for({offset_str}; {size_str}; {steps_str})" if self.collapsed or not line1: return [line2] From 7836998ebfe72e7c214ffd301c48af207dcc9c55 Mon Sep 17 00:00:00 2001 From: Tugsbayasgalan Manlaibaatar Date: Wed, 4 Sep 2024 10:06:43 -0700 Subject: [PATCH 0451/1018] Fix decomp behaviour in export training IR (#134801) Subset of changes in https://github.com/pytorch/pytorch/pull/132901, can't land the previous one because it is too complicated. Rest of the change will be implemented as follow up after export design meeting. This part just makes the training IR -> inference IR decomp to have the same path as normal export. Differential Revision: [D62000525](https://our.internmc.facebook.com/intern/diff/D62000525) Pull Request resolved: https://github.com/pytorch/pytorch/pull/134801 Approved by: https://github.com/avikchaudhuri, https://github.com/angelayi --- test/export/test_experimental.py | 147 +++++++++++------- test/export/test_export.py | 42 +++--- torch/_export/utils.py | 246 +++++++++++++++++++++++++++++-- torch/_subclasses/fake_impls.py | 1 + torch/export/_trace.py | 92 ++++-------- torch/export/exported_program.py | 219 +++++++++------------------ 6 files changed, 445 insertions(+), 302 deletions(-) diff --git a/test/export/test_experimental.py b/test/export/test_experimental.py index 468ee296e36edb..8c416b17da2a7b 100644 --- a/test/export/test_experimental.py +++ b/test/export/test_experimental.py @@ -209,58 +209,103 @@ def forward(self, x): m(*example_inputs) ep = torch.export._trace._export(m, example_inputs, pre_dispatch=True) joint_ep = _export_forward_backward(ep) - print(joint_ep) - - """ - ExportedProgram: - class GraphModule(torch.nn.Module): - def forward(self, arg0_1: "f32[3, 3]", arg1_1: "f32[3]", arg2_1: "f32[3]", arg3_1: "f32[3]"): - # No stacktrace found for following nodes - view: "f32[1, 3]" = torch.ops.aten.view.default(arg3_1, [1, 3]); arg3_1 = None - t: "f32[3, 3]" = torch.ops.aten.t.default(arg0_1); arg0_1 = None - addmm: "f32[1, 3]" = torch.ops.aten.addmm.default(arg1_1, view, t); arg1_1 = t = None - view_1: "f32[3]" = torch.ops.aten.view.default(addmm, [3]); addmm = None - _softmax: "f32[3]" = torch.ops.aten._softmax.default(view_1, 0, False); view_1 = None - detach_1: "f32[3]" = torch.ops.aten.detach.default(_softmax) - clone: "f32[3]" = torch.ops.aten.clone.default(arg2_1); arg2_1 = None - detach_5: "f32[3]" = torch.ops.aten.detach.default(clone); clone = None - _log_softmax: "f32[3]" = torch.ops.aten._log_softmax.default(_softmax, 0, False); _softmax = None - detach_12: "f32[3]" = torch.ops.aten.detach.default(_log_softmax) - mul: "f32[3]" = torch.ops.aten.mul.Tensor(_log_softmax, detach_5); _log_softmax = None - sum_1: "f32[]" = torch.ops.aten.sum.default(mul); mul = None - neg: "f32[]" = torch.ops.aten.neg.default(sum_1); sum_1 = None - div: "f32[]" = torch.ops.aten.div.Scalar(neg, 1); neg = None - ones_like: "f32[]" = torch.ops.aten.ones_like.default(div, pin_memory = False, memory_format = torch.preserve_format) - div_1: "f32[]" = torch.ops.aten.div.Scalar(ones_like, 1); ones_like = None - neg_1: "f32[]" = torch.ops.aten.neg.default(div_1); div_1 = None - expand: "f32[3]" = torch.ops.aten.expand.default(neg_1, [3]); neg_1 = None - mul_1: "f32[3]" = torch.ops.aten.mul.Tensor(expand, detach_5); expand = detach_5 = None - _log_softmax_backward_data: "f32[3]" = torch.ops.aten._log_softmax_backward_data.default(mul_1, detach_12, 0, torch.float32); mul_1 = detach_12 = None - _softmax_backward_data: "f32[3]" = torch.ops.aten._softmax_backward_data.default(_log_softmax_backward_data, detach_1, 0, torch.float32); _log_softmax_backward_data = detach_1 = None - view_2: "f32[1, 3]" = torch.ops.aten.view.default(_softmax_backward_data, [1, 3]); _softmax_backward_data = None - t_1: "f32[3, 1]" = torch.ops.aten.t.default(view_2) - mm: "f32[3, 3]" = torch.ops.aten.mm.default(t_1, view); t_1 = view = None - t_2: "f32[3, 3]" = torch.ops.aten.t.default(mm); mm = None - sum_2: "f32[1, 3]" = torch.ops.aten.sum.dim_IntList(view_2, [0], True); view_2 = None - view_3: "f32[3]" = torch.ops.aten.view.default(sum_2, [3]); sum_2 = None - t_3: "f32[3, 3]" = torch.ops.aten.t.default(t_2); t_2 = None - return (div, t_3, view_3) - - Graph signature: ExportGraphSignature( - input_specs=[ - InputSpec(kind=, arg=TensorArgument(name='arg0_1'), target='linear.weight', persistent=None), - InputSpec(kind=, arg=TensorArgument(name='arg1_1'), target='linear.bias', persistent=None), - InputSpec(kind=, arg=TensorArgument(name='arg2_1'), target='lifted_tensor_0', persistent=None), - InputSpec(kind=, arg=TensorArgument(name='arg3_1'), target=None, persistent=None) - ], - output_specs=[ - OutputSpec(kind=, arg=TensorArgument(name='div'), target=None), - OutputSpec(kind=, arg=TensorArgument(name='t_3'), target='linear.weight'), - OutputSpec(kind=, arg=TensorArgument(name='view_3'), target='linear.bias') - ] + self.assertExpectedInline( + str(joint_ep.graph_module.code).strip(), + """\ +def forward(self, p_linear_weight, p_linear_bias, c_lifted_tensor_0, x): + view = torch.ops.aten.view.default(x, [1, 3]); x = None + permute = torch.ops.aten.permute.default(p_linear_weight, [1, 0]); p_linear_weight = None + addmm = torch.ops.aten.addmm.default(p_linear_bias, view, permute); p_linear_bias = permute = None + view_1 = torch.ops.aten.view.default(addmm, [3]); addmm = None + _softmax = torch.ops.aten._softmax.default(view_1, 0, False); view_1 = None + alias = torch.ops.aten.alias.default(_softmax) + alias_1 = torch.ops.aten.alias.default(alias); alias = None + clone = torch.ops.aten.clone.default(c_lifted_tensor_0); c_lifted_tensor_0 = None + alias_2 = torch.ops.aten.alias.default(clone); clone = None + alias_3 = torch.ops.aten.alias.default(alias_2); alias_2 = None + alias_4 = torch.ops.aten.alias.default(alias_3); alias_3 = None + _log_softmax = torch.ops.aten._log_softmax.default(_softmax, 0, False); _softmax = None + alias_5 = torch.ops.aten.alias.default(_log_softmax) + alias_6 = torch.ops.aten.alias.default(alias_5); alias_5 = None + mul = torch.ops.aten.mul.Tensor(_log_softmax, alias_4); _log_softmax = None + sum_1 = torch.ops.aten.sum.dim_IntList(mul, []); mul = None + neg = torch.ops.aten.neg.default(sum_1); sum_1 = None + div = torch.ops.aten.div.Scalar(neg, 1); neg = None + full_like = torch.ops.aten.full_like.default(div, 1, pin_memory = False, memory_format = torch.preserve_format) + div_1 = torch.ops.aten.div.Scalar(full_like, 1); full_like = None + neg_1 = torch.ops.aten.neg.default(div_1); div_1 = None + expand = torch.ops.aten.expand.default(neg_1, [3]); neg_1 = None + mul_1 = torch.ops.aten.mul.Tensor(expand, alias_4); expand = alias_4 = None + alias_7 = torch.ops.aten.alias.default(alias_6); alias_6 = None + alias_8 = torch.ops.aten.alias.default(alias_7); alias_7 = None + exp = torch.ops.aten.exp.default(alias_8); alias_8 = None + sum_2 = torch.ops.aten.sum.dim_IntList(mul_1, [0], True) + mul_2 = torch.ops.aten.mul.Tensor(exp, sum_2); exp = sum_2 = None + sub = torch.ops.aten.sub.Tensor(mul_1, mul_2); mul_1 = mul_2 = None + alias_9 = torch.ops.aten.alias.default(alias_1); alias_1 = None + alias_10 = torch.ops.aten.alias.default(alias_9); alias_9 = None + mul_3 = torch.ops.aten.mul.Tensor(sub, alias_10); sub = None + sum_3 = torch.ops.aten.sum.dim_IntList(mul_3, [0], True) + mul_4 = torch.ops.aten.mul.Tensor(alias_10, sum_3); alias_10 = sum_3 = None + sub_1 = torch.ops.aten.sub.Tensor(mul_3, mul_4); mul_3 = mul_4 = None + view_2 = torch.ops.aten.view.default(sub_1, [1, 3]); sub_1 = None + permute_1 = torch.ops.aten.permute.default(view_2, [1, 0]) + mm = torch.ops.aten.mm.default(permute_1, view); permute_1 = view = None + permute_2 = torch.ops.aten.permute.default(mm, [1, 0]); mm = None + sum_4 = torch.ops.aten.sum.dim_IntList(view_2, [0], True); view_2 = None + view_3 = torch.ops.aten.view.default(sum_4, [3]); sum_4 = None + permute_3 = torch.ops.aten.permute.default(permute_2, [1, 0]); permute_2 = None + return (div, permute_3, view_3)""", + ) + ep = joint_ep.run_decompositions() + self.assertExpectedInline( + str(ep.graph_module.code).strip(), + """\ +def forward(self, p_linear_weight, p_linear_bias, c_lifted_tensor_0, x): + view = torch.ops.aten.view.default(x, [1, 3]); x = None + permute = torch.ops.aten.permute.default(p_linear_weight, [1, 0]); p_linear_weight = None + addmm = torch.ops.aten.addmm.default(p_linear_bias, view, permute); p_linear_bias = permute = None + view_1 = torch.ops.aten.view.default(addmm, [3]); addmm = None + _softmax = torch.ops.aten._softmax.default(view_1, 0, False); view_1 = None + alias = torch.ops.aten.alias.default(_softmax) + alias_1 = torch.ops.aten.alias.default(alias); alias = None + clone = torch.ops.aten.clone.default(c_lifted_tensor_0); c_lifted_tensor_0 = None + alias_2 = torch.ops.aten.alias.default(clone); clone = None + alias_3 = torch.ops.aten.alias.default(alias_2); alias_2 = None + alias_4 = torch.ops.aten.alias.default(alias_3); alias_3 = None + _log_softmax = torch.ops.aten._log_softmax.default(_softmax, 0, False); _softmax = None + alias_5 = torch.ops.aten.alias.default(_log_softmax) + alias_6 = torch.ops.aten.alias.default(alias_5); alias_5 = None + mul = torch.ops.aten.mul.Tensor(_log_softmax, alias_4); _log_softmax = None + sum_1 = torch.ops.aten.sum.dim_IntList(mul, []); mul = None + neg = torch.ops.aten.neg.default(sum_1); sum_1 = None + div = torch.ops.aten.div.Scalar(neg, 1); neg = None + full_like = torch.ops.aten.full_like.default(div, 1, pin_memory = False, memory_format = torch.preserve_format) + div_1 = torch.ops.aten.div.Scalar(full_like, 1); full_like = None + neg_1 = torch.ops.aten.neg.default(div_1); div_1 = None + expand = torch.ops.aten.expand.default(neg_1, [3]); neg_1 = None + mul_1 = torch.ops.aten.mul.Tensor(expand, alias_4); expand = alias_4 = None + alias_7 = torch.ops.aten.alias.default(alias_6); alias_6 = None + alias_8 = torch.ops.aten.alias.default(alias_7); alias_7 = None + exp = torch.ops.aten.exp.default(alias_8); alias_8 = None + sum_2 = torch.ops.aten.sum.dim_IntList(mul_1, [0], True) + mul_2 = torch.ops.aten.mul.Tensor(exp, sum_2); exp = sum_2 = None + sub = torch.ops.aten.sub.Tensor(mul_1, mul_2); mul_1 = mul_2 = None + alias_9 = torch.ops.aten.alias.default(alias_1); alias_1 = None + alias_10 = torch.ops.aten.alias.default(alias_9); alias_9 = None + mul_3 = torch.ops.aten.mul.Tensor(sub, alias_10); sub = None + sum_3 = torch.ops.aten.sum.dim_IntList(mul_3, [0], True) + mul_4 = torch.ops.aten.mul.Tensor(alias_10, sum_3); alias_10 = sum_3 = None + sub_1 = torch.ops.aten.sub.Tensor(mul_3, mul_4); mul_3 = mul_4 = None + view_2 = torch.ops.aten.view.default(sub_1, [1, 3]); sub_1 = None + permute_1 = torch.ops.aten.permute.default(view_2, [1, 0]) + mm = torch.ops.aten.mm.default(permute_1, view); permute_1 = view = None + permute_2 = torch.ops.aten.permute.default(mm, [1, 0]); mm = None + sum_4 = torch.ops.aten.sum.dim_IntList(view_2, [0], True); view_2 = None + view_3 = torch.ops.aten.view.default(sum_4, [3]); sum_4 = None + permute_3 = torch.ops.aten.permute.default(permute_2, [1, 0]); permute_2 = None + return (div, permute_3, view_3)""", ) - Range constraints: {} - """ def test_joint_dynamic(self) -> None: from torch.export import Dim diff --git a/test/export/test_export.py b/test/export/test_export.py index 3a5f10c88ea944..cd4a9aeb9e590f 100644 --- a/test/export/test_export.py +++ b/test/export/test_export.py @@ -1475,18 +1475,16 @@ def forward(self, x, y): decomp_table={}, _preserve_ops=testing._COMPOSITE_OPS_THAT_CAN_BE_PRESERVED_TESTING_ONLY, ) + self.assertExpectedInline( str(ep_has_linear_convd.graph_module.code).strip(), """\ def forward(self, p_conv_weight, p_conv_bias, p_conv1d_weight, p_conv1d_bias, b_linear_weight, b_linear_bias, x, y): - convolution = torch.ops.aten.convolution.default(x, p_conv_weight, p_conv_bias, [1, 1], [0, 0], [1, 1], False, [0, 0], 1); x = p_conv_weight = p_conv_bias = None - convolution_1 = torch.ops.aten.convolution.default(y, p_conv1d_weight, p_conv1d_bias, [1], [0], [1], False, [0], 1); y = p_conv1d_weight = p_conv1d_bias = None - view = torch.ops.aten.view.default(convolution, [31680, 98]); convolution = None - t = torch.ops.aten.t.default(b_linear_weight); b_linear_weight = None - addmm = torch.ops.aten.addmm.default(b_linear_bias, view, t); b_linear_bias = view = t = None - view_1 = torch.ops.aten.view.default(addmm, [20, 33, 48, 20]); addmm = None - cos = torch.ops.aten.cos.default(view_1); view_1 = None - sum_1 = torch.ops.aten.sum.default(convolution_1); convolution_1 = None + conv2d = torch.ops.aten.conv2d.default(x, p_conv_weight, p_conv_bias); x = p_conv_weight = p_conv_bias = None + conv1d = torch.ops.aten.conv1d.default(y, p_conv1d_weight, p_conv1d_bias); y = p_conv1d_weight = p_conv1d_bias = None + linear = torch.ops.aten.linear.default(conv2d, b_linear_weight, b_linear_bias); conv2d = b_linear_weight = b_linear_bias = None + cos = torch.ops.aten.cos.default(linear); linear = None + sum_1 = torch.ops.aten.sum.default(conv1d); conv1d = None add = torch.ops.aten.add.Tensor(cos, sum_1); cos = sum_1 = None return (add,)""", ) @@ -1498,18 +1496,19 @@ def forward(self, p_conv_weight, p_conv_bias, p_conv1d_weight, p_conv1d_bias, b_ torch.ops.aten.conv1d.default, ], ) + self.assertExpectedInline( str(ep_has_convd.graph_module.code).strip(), """\ def forward(self, p_conv_weight, p_conv_bias, p_conv1d_weight, p_conv1d_bias, b_linear_weight, b_linear_bias, x, y): - convolution = torch.ops.aten.convolution.default(x, p_conv_weight, p_conv_bias, [1, 1], [0, 0], [1, 1], False, [0, 0], 1); x = p_conv_weight = p_conv_bias = None - convolution_1 = torch.ops.aten.convolution.default(y, p_conv1d_weight, p_conv1d_bias, [1], [0], [1], False, [0], 1); y = p_conv1d_weight = p_conv1d_bias = None - view = torch.ops.aten.view.default(convolution, [31680, 98]); convolution = None - t = torch.ops.aten.t.default(b_linear_weight); b_linear_weight = None - addmm = torch.ops.aten.addmm.default(b_linear_bias, view, t); b_linear_bias = view = t = None + conv2d = torch.ops.aten.conv2d.default(x, p_conv_weight, p_conv_bias); x = p_conv_weight = p_conv_bias = None + conv1d = torch.ops.aten.conv1d.default(y, p_conv1d_weight, p_conv1d_bias); y = p_conv1d_weight = p_conv1d_bias = None + view = torch.ops.aten.view.default(conv2d, [31680, 98]); conv2d = None + permute = torch.ops.aten.permute.default(b_linear_weight, [1, 0]); b_linear_weight = None + addmm = torch.ops.aten.addmm.default(b_linear_bias, view, permute); b_linear_bias = view = permute = None view_1 = torch.ops.aten.view.default(addmm, [20, 33, 48, 20]); addmm = None cos = torch.ops.aten.cos.default(view_1); view_1 = None - sum_1 = torch.ops.aten.sum.default(convolution_1); convolution_1 = None + sum_1 = torch.ops.aten.sum.dim_IntList(conv1d, []); conv1d = None add = torch.ops.aten.add.Tensor(cos, sum_1); cos = sum_1 = None return (add,)""", ) @@ -1517,18 +1516,19 @@ def forward(self, p_conv_weight, p_conv_bias, p_conv1d_weight, p_conv1d_bias, b_ ep_has_convd = ep_has_convd.run_decompositions( decomp_table=None, _preserve_ops=[torch.ops.aten.conv2d.default] ) + self.assertExpectedInline( str(ep_has_convd.graph_module.code).strip(), """\ def forward(self, p_conv_weight, p_conv_bias, p_conv1d_weight, p_conv1d_bias, b_linear_weight, b_linear_bias, x, y): - convolution = torch.ops.aten.convolution.default(x, p_conv_weight, p_conv_bias, [1, 1], [0, 0], [1, 1], False, [0, 0], 1); x = p_conv_weight = p_conv_bias = None - convolution_1 = torch.ops.aten.convolution.default(y, p_conv1d_weight, p_conv1d_bias, [1], [0], [1], False, [0], 1); y = p_conv1d_weight = p_conv1d_bias = None - view = torch.ops.aten.view.default(convolution, [31680, 98]); convolution = None + conv2d = torch.ops.aten.conv2d.default(x, p_conv_weight, p_conv_bias); x = p_conv_weight = p_conv_bias = None + convolution = torch.ops.aten.convolution.default(y, p_conv1d_weight, p_conv1d_bias, [1], [0], [1], False, [0], 1); y = p_conv1d_weight = p_conv1d_bias = None + view = torch.ops.aten.view.default(conv2d, [31680, 98]); conv2d = None permute = torch.ops.aten.permute.default(b_linear_weight, [1, 0]); b_linear_weight = None addmm = torch.ops.aten.addmm.default(b_linear_bias, view, permute); b_linear_bias = view = permute = None view_1 = torch.ops.aten.view.default(addmm, [20, 33, 48, 20]); addmm = None cos = torch.ops.aten.cos.default(view_1); view_1 = None - sum_1 = torch.ops.aten.sum.dim_IntList(convolution_1, []); convolution_1 = None + sum_1 = torch.ops.aten.sum.dim_IntList(convolution, []); convolution = None add = torch.ops.aten.add.Tensor(cos, sum_1); cos = sum_1 = None return (add,)""", ) @@ -1758,9 +1758,9 @@ def forward(self, x): """\ def forward(self, p_linear_weight, p_linear_bias, b_buffer, x): add = torch.ops.aten.add.Tensor(b_buffer, 5); b_buffer = None - t = torch.ops.aten.t.default(p_linear_weight); p_linear_weight = None - addmm = torch.ops.aten.addmm.default(p_linear_bias, x, t); p_linear_bias = x = t = None - sum_1 = torch.ops.aten.sum.default(add) + permute = torch.ops.aten.permute.default(p_linear_weight, [1, 0]); p_linear_weight = None + addmm = torch.ops.aten.addmm.default(p_linear_bias, x, permute); p_linear_bias = x = permute = None + sum_1 = torch.ops.aten.sum.dim_IntList(add, []) add_1 = torch.ops.aten.add.Tensor(addmm, sum_1); addmm = sum_1 = None return (add, add_1)""", ) diff --git a/torch/_export/utils.py b/torch/_export/utils.py index 00d09f94d301e8..e085e18b68a207 100644 --- a/torch/_export/utils.py +++ b/torch/_export/utils.py @@ -6,15 +6,18 @@ import operator import re from inspect import Parameter -from typing import Any, Dict, Iterable, List, Optional, Tuple, Type +from typing import Any, Dict, Iterable, List, Optional, Tuple, Type, TYPE_CHECKING import torch +from torch._guards import detect_fake_mode from torch._subclasses.fake_tensor import FakeTensor -from torch.export import ExportedProgram -from torch.export.exported_program import ( - _name_hoo_subgraph_placeholders, - _rename_without_collisions, -) + + +if TYPE_CHECKING: + from torch._export.passes.lift_constants_pass import ConstantAttrMap + from torch.export import ExportedProgram + from torch.export.graph_signature import ExportGraphSignature + from torch.export.graph_signature import InputKind, OutputKind from torch.utils._pytree import ( _register_pytree_node, @@ -42,6 +45,178 @@ } +def _collect_and_set_constant_attrs( + graph_signature, constants, mod +) -> "ConstantAttrMap": + # the exported module will store constants & non-persistent buffers such that + # retracing treats them as persistent buffers, so we inform the constants lifting pass + # and overwrite the new graph signature using the previous program. This is intended to only be used + # in run_decompositions where we still have access to original EP. + from torch._export.passes.lift_constants_pass import ConstantAttrMap + + constant_attrs = ConstantAttrMap() + non_persistent_buffers = { + spec.target + for spec in graph_signature.input_specs + if spec.kind == InputKind.BUFFER and not spec.persistent + } + for name, value in constants.items(): + if name in non_persistent_buffers: + continue + # recursive getattr + _mod = mod + *atoms, attr = name.split(".") + for atom in atoms: + _mod = getattr(_mod, atom) + # remove as buffer, reassign as constant/non-persistent buffer + _mod._buffers.pop(attr, None) + setattr(_mod, attr, value) + constant_attrs.add(value, name) + return constant_attrs + + +def _overwrite_signature_for_non_persistent_buffers( + old_sig: "ExportGraphSignature", new_sig: "ExportGraphSignature" +): + # overwrite signature for non-persistent buffers + non_persistent_buffers = { + spec.target + for spec in old_sig.input_specs + if spec.kind == InputKind.BUFFER and not spec.persistent + } + + for spec in new_sig.input_specs: + if spec.kind == InputKind.BUFFER and spec.target in non_persistent_buffers: + spec.persistent = False + return new_sig + + +def _collect_param_buffer_metadata(mod: torch.fx.GraphModule) -> Dict[str, Any]: + """ + Param/buffer metadata needs to be saved before lowering to aten IR + because aten IR lifts them, as a result, automatic preservation doesn't work. + This is intended to be called on the strict mode tracing right before lowering to + aten IR OR run_decomposition pass. + """ + params_buffers_to_node_meta = {} + + def _getattr(model: torch.fx.GraphModule, attr_name: str): + *prefix, field = attr_name.split(".") + t = model + for item in prefix: + t = getattr(t, item, None) # type: ignore[assignment] + assert t is not None + + return getattr(t, field) + + for node in mod.graph.nodes: + target = node.target + meta = node.meta + if node.op == "call_module": + submodule = _getattr(mod, target) + if isinstance(submodule, torch.nn.Module): + for name, _ in submodule.named_parameters( + recurse=True, remove_duplicate=False + ): + params_buffers_to_node_meta[target + "." + name] = meta + + for name, _ in submodule.named_buffers( + recurse=True, remove_duplicate=False + ): + params_buffers_to_node_meta[target + "." + name] = meta + + if node.op == "get_attr": + submodule = _getattr(mod, target) + if not isinstance(submodule, torch.fx.GraphModule): + params_buffers_to_node_meta[target] = meta + + # If the call_function uses param as input, we also need to update params' meta + # with this call_function node's meta. + # This is basically the same flow as torch.fx.traceback.preserve_meta() + if node.op == "call_function" and not isinstance( + node.target, torch._ops.HigherOrderOperator + ): + for arg in node._input_nodes: + if arg.op == "get_attr": + for entry in torch.fx.proxy._COPY_META_FIELDS: + if entry in meta: + params_buffers_to_node_meta[arg.target][entry] = meta[entry] + + return params_buffers_to_node_meta + + +def _populate_param_buffer_metadata_to_new_gm( + params_buffers_to_node_meta: Dict[str, Any], + gm: torch.fx.GraphModule, + new_sig: "ExportGraphSignature", +) -> None: + """ + Given that we collected param'buffer metadata before, we put them back in + newly traced graph module + """ + # Don't copy over nn_module_stack, stack_trace metadata for params/buffers nodes + for metadata in params_buffers_to_node_meta.values(): + metadata.pop("nn_module_stack", None) + metadata.pop("stack_trace", None) + + for node in gm.graph.nodes: + if node.op == "placeholder": + if node.target in new_sig.inputs_to_parameters: + param_name = new_sig.inputs_to_parameters[node.target] + if param_name in params_buffers_to_node_meta: + for k, v in params_buffers_to_node_meta[param_name].items(): + node.meta[k] = v + if node.target in new_sig.inputs_to_buffers: + buffer_name = new_sig.inputs_to_buffers[node.target] + if buffer_name in params_buffers_to_node_meta: + for k, v in params_buffers_to_node_meta[buffer_name].items(): + node.meta[k] = v + + +def _get_shape_env_from_gm(gm: torch.fx.GraphModule): + vals = [ + node.meta["val"] + for node in gm.graph.nodes + if node.meta.get("val", None) is not None + ] + + fake_mode = _detect_fake_mode_from_gm(gm) + if fake_mode is not None: + return fake_mode.shape_env + for v in vals: + if isinstance(v, torch.SymInt): + return v.node.shape_env + + +def _rename_without_collisions( + name_map: Dict[str, str], + orig_name: str, + name: str, + is_placeholder: bool = False, +): + """ + Renames nodes to avoid name collisions, with suffixing. + name_map: map from original name to new name + orig_name: mapping key + name: candidate name (potentially suffixed, e.g. mul_2) + is_placeholder: if the node is a placeholder, avoid detecting suffix + """ + if name in name_map.values(): + # non-placeholder nodes may be suffixed with the count + # instead of adding another suffix, we will try to increment it + match = re.match(r"(.*)_(\d+)", name) + if match and not is_placeholder: + name, n = match.group(1), int(match.group(2)) + else: + n = 0 + while (dup_name := f"{name}_{n + 1}") in name_map.values(): + n += 1 + name_map[orig_name] = dup_name + else: + name_map[orig_name] = name + return name_map[orig_name] + + def _check_input_constraints_for_graph( input_placeholders: List[torch.fx.Node], flat_args_with_path, range_constraints ): @@ -227,7 +402,7 @@ def default_flatten_fn_with_keys(obj: Any) -> Tuple[List[Any], Context]: ) -def is_param(program: ExportedProgram, node: torch.fx.Node) -> bool: +def is_param(program: "ExportedProgram", node: torch.fx.Node) -> bool: """ Checks if the given node is a parameter within the exported program """ @@ -236,7 +411,7 @@ def is_param(program: ExportedProgram, node: torch.fx.Node) -> bool: def get_param( - program: ExportedProgram, + program: "ExportedProgram", node: torch.fx.Node, ) -> Optional[torch.nn.Parameter]: """ @@ -251,7 +426,7 @@ def get_param( return None -def is_buffer(program: ExportedProgram, node: torch.fx.Node) -> bool: +def is_buffer(program: "ExportedProgram", node: torch.fx.Node) -> bool: """ Checks if the given node is a buffer within the exported program """ @@ -260,7 +435,7 @@ def is_buffer(program: ExportedProgram, node: torch.fx.Node) -> bool: def get_buffer( - program: ExportedProgram, + program: "ExportedProgram", node: torch.fx.Node, ) -> Optional[torch.Tensor]: """ @@ -279,7 +454,7 @@ def get_buffer( def is_lifted_tensor_constant( - program: ExportedProgram, + program: "ExportedProgram", node: torch.fx.Node, ) -> bool: """ @@ -290,7 +465,7 @@ def is_lifted_tensor_constant( def get_lifted_tensor_constant( - program: ExportedProgram, + program: "ExportedProgram", node: torch.fx.Node, ) -> Optional[torch.Tensor]: """ @@ -484,9 +659,51 @@ def _bind_signature_to_inputs(mod, fake_args, fake_kwargs): return sig.bind(*fake_args, **fake_kwargs).arguments +def _name_hoo_subgraph_placeholders(gm: torch.fx.GraphModule) -> None: + """ + Propagate placeholder names from the top-level graph into HigherOrderOp subgraphs, + and handle collisions with non-placeholders by count suffixing. + Different HOO subgraph types have different input schemas, so we first enumerate them + and gather the top-level named placeholder nodes. + """ + # gather all HOO subgraphs and their top-level named placeholder nodes + subgraph_ph_tuples: List[Tuple[torch.fx.GraphModule, List[torch.fx.Node]]] = [] + for node in gm.graph.nodes: + if node.op == "call_function" and isinstance( + node.target, torch._ops.HigherOrderOperator + ): + # HOO subgraphs have varying input schemas, so we enumerate them there + if node.target._name == "cond": + _, true_graph, false_graph, cond_args = node._args + subgraph_ph_tuples.append((getattr(gm, true_graph.target), cond_args)) + subgraph_ph_tuples.append((getattr(gm, false_graph.target), cond_args)) + elif node.target._name == "wrap_with_set_grad_enabled": + subgraph, phs = node._args[1], node._args[2:] + subgraph_ph_tuples.append((getattr(gm, subgraph.target), phs)) + elif node.target._name == "map_impl": + body_graph, array, args = node._args + subgraph_ph_tuples.append( + (getattr(gm, body_graph.target), array + args) + ) + + # propagate names + for subgraph, hoo_phs in subgraph_ph_tuples: + name_map: Dict[str, str] = {} + for i, node in enumerate(subgraph.graph.nodes): + if i < len(hoo_phs): # placeholder, retain name + name_map[node.name] = hoo_phs[i].name + node.name = node.target = hoo_phs[i].name + else: # non-placeholder, check for collisions + node.name = _rename_without_collisions(name_map, node.name, node.name) + + # recurse and recompile + _name_hoo_subgraph_placeholders(subgraph) + subgraph.recompile() + + def placeholder_naming_pass( gm: torch.fx.GraphModule, - export_graph_signature: torch.export.ExportGraphSignature, + export_graph_signature: "ExportGraphSignature", mod: torch.nn.Module, fake_args, fake_kwargs, @@ -514,6 +731,8 @@ def placeholder_naming_pass( def _strip_name(x): if x.startswith("L__self___"): x = x[len("L__self___") :] + elif x.startswith("self_"): + x = x[len("self_") :] x = re.sub(r"[^a-zA-Z0-9]", "_", x) return x @@ -652,7 +871,6 @@ def _detect_fake_mode_from_gm( Additionally, if gm doesn't have placeholders, we further look at the "example_value" or "val" of other nodes. If no fake mode is found, we return None for fake_mode. """ - from torch._guards import detect_fake_mode fake_inps: List[torch.Tensor] = [] fake_vals: List[torch.Tensor] = [] diff --git a/torch/_subclasses/fake_impls.py b/torch/_subclasses/fake_impls.py index b0c9ee1a716aa3..c8acc90567c5eb 100644 --- a/torch/_subclasses/fake_impls.py +++ b/torch/_subclasses/fake_impls.py @@ -372,6 +372,7 @@ def repeat_interleave_tensor(fake_mode, func, repeats, output_size=None): return repeats.new_empty(output_size) +@register_op_impl(torch.ops.aten.item.default) @register_op_impl(torch.ops.aten._local_scalar_dense.default) def local_scalar_dense(fake_mode, func, arg): if (r := arg.item_memo) is not None: diff --git a/torch/export/_trace.py b/torch/export/_trace.py index 8b9b732b0119d8..abd2b5405fb937 100644 --- a/torch/export/_trace.py +++ b/torch/export/_trace.py @@ -38,7 +38,13 @@ lift_constants_pass, rewrite_script_object_meta, ) -from torch._export.utils import placeholder_naming_pass, placeholder_prefixes +from torch._export.utils import ( + _collect_param_buffer_metadata, + _get_shape_env_from_gm, + _populate_param_buffer_metadata_to_new_gm, + placeholder_naming_pass, + placeholder_prefixes, +) from torch._export.verifier import SpecViolationError from torch._export.wrappers import _wrap_submodules from torch._functorch._aot_autograd.input_output_analysis import ( @@ -365,7 +371,11 @@ def _preserve_requires_grad_pass( assert spec.target is not None constant = constants[spec.target] if isinstance(constant, torch.Tensor): - node.meta["val"].requires_grad = constant.requires_grad + # If the tensor is not leaf, it should already have a correct requires grad field + if node.meta["val"].is_leaf: + node.meta["val"].requires_grad = constant.requires_grad + else: + assert node.meta["val"].requires_grad == constant.requires_grad elif spec.kind in (InputKind.CUSTOM_OBJ, InputKind.TOKEN): continue else: @@ -589,6 +599,7 @@ def _export_to_aten_ir( *, transform=lambda x: x, # TODO(zhxchen17) Revisit if this is needed later. pre_dispatch=False, + decomp_table=None, _check_autograd_state=True, _is_torch_jit_trace=False, ) -> ATenExportArtifact: @@ -628,6 +639,7 @@ def _compiling_state_context(): fake_args, trace_joint=False, pre_dispatch=pre_dispatch, + decompositions=decomp_table, kwargs=fake_kwargs, ) @@ -662,9 +674,6 @@ def _maybe_fixup_gm_and_output_node_meta(old_gm, new_gm): # Run runtime asserts pass before creating input/output specs, since size-related CSE/DCE might affect output signature. # Overwrite output specs afterwards. flat_fake_args = pytree.tree_leaves((fake_args, fake_kwargs)) - fake_mode = torch._export.utils._detect_fake_mode_from_gm(gm) - assert fake_mode is not None, "Cannot detect fake mode from graph" - if not torch._dynamo.config.do_not_emit_runtime_asserts: stack_trace = ( 'File "torch/fx/passes/runtime_assert.py", line 24, ' @@ -673,10 +682,11 @@ def _maybe_fixup_gm_and_output_node_meta(old_gm, new_gm): with _set_node_metadata_hook( gm, functools.partial(_node_metadata_hook, stack_trace=stack_trace) ): - if fake_mode: + shape_env = _get_shape_env_from_gm(gm) + if shape_env: insert_deferred_runtime_asserts( gm, - fake_mode.shape_env, # type: ignore[arg-type] + shape_env, f"exported program: {first_call_function_nn_module_stack(gm.graph)}", export=True, ) @@ -1275,43 +1285,6 @@ def _strict_export_lower_to_aten_ir( attr, static_shapes=True ) - # When aot_export lifts the params, we lose metadata (e.g. source_fn_stack, stack_trace) - # from the param nodes as they are treated as fresh inputs - # Therefore, we manually extract them before calling into aot_export - params_buffers_to_node_meta = {} - for node in gm_torch_level.graph.nodes: - target = node.target - meta = node.meta - if node.op == "call_module": - submodule = getattr(gm_torch_level, target) - if isinstance(submodule, torch.nn.Module): - for name, _ in submodule.named_parameters( - recurse=True, remove_duplicate=False - ): - params_buffers_to_node_meta[target + "." + name] = meta - - for name, _ in submodule.named_buffers( - recurse=True, remove_duplicate=False - ): - params_buffers_to_node_meta[target + "." + name] = meta - - if node.op == "get_attr": - submodule = getattr(gm_torch_level, target) - if not isinstance(submodule, torch.fx.GraphModule): - params_buffers_to_node_meta[target] = meta - - # If the call_function uses param as input, we also need to update params' meta - # with this call_function node's meta. - # This is basically the same flow as torch.fx.traceback.preserve_meta() - if node.op == "call_function" and not isinstance( - node.target, torch._ops.HigherOrderOperator - ): - for arg in node._input_nodes: - if arg.op == "get_attr": - for entry in torch.fx.proxy._COPY_META_FIELDS: - if entry in meta: - params_buffers_to_node_meta[arg.target][entry] = meta[entry] - # Fix the graph output signature to be tuple if scalar out_spec = orig_out_spec = gm_torch_level._out_spec @@ -1336,6 +1309,13 @@ def _strict_export_lower_to_aten_ir( _normalize_nn_module_stack(gm_torch_level, type(mod)) + params_buffers_to_node_meta = _collect_param_buffer_metadata(gm_torch_level) + + # When aot_export lifts the params, we lose metadata (e.g. source_fn_stack, stack_trace) + # from the param nodes as they are treated as fresh inputs + # Therefore, we manually extract them before calling into aot_export + # params_buffers_to_node_meta = _collect_param_buffer_metadata(gm_torch_level) + constant_attrs = _gather_constant_attrs(mod) param_buffer_table: Dict[str, str] = _get_param_buffer_mapping(mod, gm_torch_level) @@ -1364,26 +1344,9 @@ def _strict_export_lower_to_aten_ir( export_graph_signature = aten_export_artifact.sig constants = aten_export_artifact.constants - # Don't copy over nn_module_stack, stack_trace metadata for params/buffers nodes - for metadata in params_buffers_to_node_meta.values(): - metadata.pop("nn_module_stack", None) - metadata.pop("stack_trace", None) - - # After aot_export, set the param/buffer metadata back into placeholders - # Technically, users can still construct this data from param names - # without relying on this metadata - for node in gm.graph.nodes: - if node.op == "placeholder": - if node.target in export_graph_signature.inputs_to_parameters: - param_name = export_graph_signature.inputs_to_parameters[node.target] - if param_name in params_buffers_to_node_meta: - for k, v in params_buffers_to_node_meta[param_name].items(): - node.meta[k] = v - if node.target in export_graph_signature.inputs_to_buffers: - buffer_name = export_graph_signature.inputs_to_buffers[node.target] - if buffer_name in params_buffers_to_node_meta: - for k, v in params_buffers_to_node_meta[buffer_name].items(): - node.meta[k] = v + _populate_param_buffer_metadata_to_new_gm( + params_buffers_to_node_meta, gm, export_graph_signature + ) # Do some cleanups on the graph module to restore the state dict to the # expected form. Each of these steps should probably get fixed upstream. @@ -1737,6 +1700,7 @@ def _produce_guards_callback(gm): ) assert out_spec is not None + return ExportArtifact( aten=aten_export_artifact, out_spec=out_spec, diff --git a/torch/export/exported_program.py b/torch/export/exported_program.py index cf547eb3f7e6e0..b9c71bdc91deca 100644 --- a/torch/export/exported_program.py +++ b/torch/export/exported_program.py @@ -5,7 +5,6 @@ import dataclasses import functools import operator -import re import types import warnings from collections import namedtuple @@ -26,8 +25,10 @@ from torch._higher_order_ops.utils import autograd_not_implemented from torch._library.fake_class_registry import FakeScriptObject +from torch.fx._utils import first_call_function_nn_module_stack from torch.fx.graph import _PyTreeCodeGen, _PyTreeInfo from torch.fx.immutable_collections import immutable_dict, immutable_list +from torch.fx.passes.runtime_assert import insert_deferred_runtime_asserts if TYPE_CHECKING: @@ -41,15 +42,23 @@ import torch import torch.utils._pytree as pytree +from torch._export.utils import ( + _collect_and_set_constant_attrs, + _collect_param_buffer_metadata, + _detect_fake_mode_from_gm, + _name_hoo_subgraph_placeholders, + _overwrite_signature_for_non_persistent_buffers, + _populate_param_buffer_metadata_to_new_gm, + _rename_without_collisions, +) from torch._export.verifier import Verifier +from torch._guards import detect_fake_mode from torch._subclasses.fake_tensor import unset_fake_temporarily from torch._subclasses.functional_tensor import FunctionalTensor from torch.export._tree_utils import is_equivalent, reorder_kwargs from torch.fx._compatibility import compatibility -from torch.fx._utils import first_call_function_nn_module_stack from torch.fx.passes.infra.pass_base import PassResult from torch.fx.passes.infra.pass_manager import PassManager -from torch.fx.passes.runtime_assert import insert_deferred_runtime_asserts from .graph_signature import ( # noqa: F401 ArgumentSpec, @@ -279,77 +288,6 @@ def _override_decomp_aten_to_variants(): yield -def _rename_without_collisions( - name_map: Dict[str, str], - orig_name: str, - name: str, - is_placeholder: bool = False, -): - """ - Renames nodes to avoid name collisions, with suffixing. - name_map: map from original name to new name - orig_name: mapping key - name: candidate name (potentially suffixed, e.g. mul_2) - is_placeholder: if the node is a placeholder, avoid detecting suffix - """ - if name in name_map.values(): - # non-placeholder nodes may be suffixed with the count - # instead of adding another suffix, we will try to increment it - match = re.match(r"(.*)_(\d+)", name) - if match and not is_placeholder: - name, n = match.group(1), int(match.group(2)) - else: - n = 0 - while (dup_name := f"{name}_{n + 1}") in name_map.values(): - n += 1 - name_map[orig_name] = dup_name - else: - name_map[orig_name] = name - return name_map[orig_name] - - -def _name_hoo_subgraph_placeholders(gm: torch.fx.GraphModule) -> None: - """ - Propagate placeholder names from the top-level graph into HigherOrderOp subgraphs, - and handle collisions with non-placeholders by count suffixing. - Different HOO subgraph types have different input schemas, so we first enumerate them - and gather the top-level named placeholder nodes. - """ - # gather all HOO subgraphs and their top-level named placeholder nodes - subgraph_ph_tuples: List[Tuple[torch.fx.GraphModule, List[torch.fx.Node]]] = [] - for node in gm.graph.nodes: - if node.op == "call_function" and isinstance( - node.target, torch._ops.HigherOrderOperator - ): - # HOO subgraphs have varying input schemas, so we enumerate them there - if node.target._name == "cond": - _, true_graph, false_graph, cond_args = node._args - subgraph_ph_tuples.append((getattr(gm, true_graph.target), cond_args)) - subgraph_ph_tuples.append((getattr(gm, false_graph.target), cond_args)) - elif node.target._name == "wrap_with_set_grad_enabled": - subgraph, phs = node._args[1], node._args[2:] - subgraph_ph_tuples.append((getattr(gm, subgraph.target), phs)) - elif node.target._name == "map_impl": - body_graph, array, args = node._args - subgraph_ph_tuples.append( - (getattr(gm, body_graph.target), array + args) - ) - - # propagate names - for subgraph, hoo_phs in subgraph_ph_tuples: - name_map: Dict[str, str] = {} - for i, node in enumerate(subgraph.graph.nodes): - if i < len(hoo_phs): # placeholder, retain name - name_map[node.name] = hoo_phs[i].name - node.name = node.target = hoo_phs[i].name - else: # non-placeholder, check for collisions - node.name = _rename_without_collisions(name_map, node.name, node.name) - - # recurse and recompile - _name_hoo_subgraph_placeholders(subgraph) - subgraph.recompile() - - def _decompose_and_get_gm_with_new_signature_constants( ep, *, @@ -357,9 +295,7 @@ def _decompose_and_get_gm_with_new_signature_constants( _preserve_ops: Tuple[torch._ops.OpOverload], joint_loss_index: Optional[int], ): - from torch._export.passes.lift_constants_pass import ConstantAttrMap from torch._functorch.aot_autograd import aot_export_module - from torch._guards import detect_fake_mode from torch._subclasses.fake_tensor import FakeTensorMode from torch.export._trace import ( _export_to_aten_ir, @@ -371,12 +307,18 @@ def _decompose_and_get_gm_with_new_signature_constants( ) from torch.fx.experimental.symbolic_shapes import ShapeEnv + # TODO Merge this path with inference IR decomp, but it will require some additional work + # so I will leave it for now. T200307782 if ep.verifier.dialect == "TRAINING": mod = ep.module() - fake_args = [ - node.meta["val"] for node in mod.graph.nodes if node.op == "placeholder" - ] - fake_mode = torch._export.utils._detect_fake_mode_from_gm(mod) + + fake_args = [] + for node in mod.graph.nodes: + if node.op == "placeholder": + fake_args.append(node.meta["val"]) + + fake_args_unwrapped = pytree.tree_unflatten(fake_args, mod._in_spec) + fake_mode = _detect_fake_mode_from_gm(mod) if fake_mode is None: fake_mode = FakeTensorMode(shape_env=ShapeEnv(), export=True) @@ -402,51 +344,21 @@ def _decompose_and_get_gm_with_new_signature_constants( # the exported module will store constants & non-persistent buffers such that # retracing treats them as persistent buffers, so we inform the constants lifting pass # and overwrite the new graph signature using the previous program. - constant_attrs = ConstantAttrMap() - non_persistent_buffers = { - spec.target - for spec in ep.graph_signature.input_specs - if spec.kind == InputKind.BUFFER and not spec.persistent - } - for name, value in ep.constants.items(): - if name in non_persistent_buffers: - continue - # recursive getattr - _mod = mod - *atoms, attr = name.split(".") - for atom in atoms: - _mod = getattr(_mod, atom) - # remove as buffer, reassign as constant/non-persistent buffer - _mod._buffers.pop(attr, None) - setattr(_mod, attr, value) - constant_attrs.add(value, name) + constant_attrs = _collect_and_set_constant_attrs( + ep.graph_signature, ep.constants, mod + ) # get params & buffers after excluding constants fake_params_buffers = _fakify_params_buffers(fake_mode, mod) - params_buffers_to_node_meta = {} - for node in mod.graph.nodes: - target = node.target - meta = node.meta - if node.op == "get_attr": - params_buffers_to_node_meta[target] = meta - - # If the call_function uses param as input, we also need to update params' meta - # with this call_function node's meta. - # This is basically the same flow as torch.fx.traceback.preserve_meta() - if node.op == "call_function" and not isinstance( - node.target, torch._ops.HigherOrderOperator - ): - for arg in node._input_nodes: - if arg.op == "get_attr": - for entry in torch.fx.proxy._COPY_META_FIELDS: - if entry in meta: - params_buffers_to_node_meta[arg.target][entry] = meta[ - entry - ] - - with fake_mode, _override_decomp_aten_to_variants(): - fake_args_unwrapped = pytree.tree_unflatten(fake_args, mod._in_spec) + params_buffers_to_node_meta = _collect_param_buffer_metadata(mod) + + with _ignore_backend_decomps(), ( + fake_mode + ), _override_decomp_aten_to_variants(), _override_composite_implicit_decomp( + _preserve_ops, + decomp_table, + ): aten_export_artifact = _export_to_aten_ir( mod, # this requires empty kwargs, but not in pytree.flattened format @@ -457,49 +369,27 @@ def _decompose_and_get_gm_with_new_signature_constants( {}, fake_params_buffers, constant_attrs, + decomp_table=decomp_table, _check_autograd_state=False, ) gm = aten_export_artifact.gm new_graph_signature = aten_export_artifact.sig - for node in gm.graph.nodes: - # nn_module_stack - if node.op not in ["placeholder", "output"]: - for key, (fqn, mod_cls) in node.meta["nn_module_stack"].items(): - if isinstance(mod_cls, type): - node.meta["nn_module_stack"][key] = ( - fqn, - mod_cls.__module__ + "." + mod_cls.__qualname__, - ) - - # Don't copy over nn_module_stack, stack_trace metadata for params/buffers nodes - for metadata in params_buffers_to_node_meta.values(): - metadata.pop("nn_module_stack", None) - metadata.pop("stack_trace", None) - - for node in gm.graph.nodes: - if node.op == "placeholder": - if node.target in new_graph_signature.inputs_to_parameters: - param_name = new_graph_signature.inputs_to_parameters[node.target] - if param_name in params_buffers_to_node_meta: - for k, v in params_buffers_to_node_meta[param_name].items(): - node.meta[k] = v - if node.target in new_graph_signature.inputs_to_buffers: - buffer_name = new_graph_signature.inputs_to_buffers[node.target] - if buffer_name in params_buffers_to_node_meta: - for k, v in params_buffers_to_node_meta[buffer_name].items(): - node.meta[k] = v + _populate_param_buffer_metadata_to_new_gm( + params_buffers_to_node_meta, gm, new_graph_signature + ) # overwrite signature for non-persistent buffers - for spec in new_graph_signature.input_specs: - if spec.kind == InputKind.BUFFER and spec.target in non_persistent_buffers: - spec.persistent = False + new_graph_signature = _overwrite_signature_for_non_persistent_buffers( + ep.graph_signature, new_graph_signature + ) _verify_nn_module_stack(gm) _verify_stack_trace(gm) _verify_placeholder_names(gm, new_graph_signature) - return gm, new_graph_signature + + return _remove_unneccessary_copy_op_pass(gm, new_graph_signature) old_placeholders = [ node for node in ep.graph_module.graph.nodes if node.op == "placeholder" @@ -663,6 +553,31 @@ def update_arg(old_arg, new_ph): return gm, new_graph_signature +def _remove_unneccessary_copy_op_pass( + gm: torch.fx.GraphModule, new_graph_signature: ExportGraphSignature +) -> Tuple[torch.fx.GraphModule, ExportGraphSignature]: + """ + Removes redundant copy_ node that was introduced due to mutated buffer. + """ + with gm._set_replace_hook(new_graph_signature.get_replace_hook()): + for node in gm.graph.nodes: + if node.op == "output": + args, _ = pytree.tree_flatten(node.args) + for out in args: + if ( + isinstance(out, torch.fx.Node) + and out.name in new_graph_signature.buffers_to_mutate + ): + if ( + out.op == "call_function" + and out.target == torch.ops.aten.copy.default + ): + out.replace_all_uses_with(out.args[1]) # type: ignore[arg-type] + gm.graph.erase_node(out) + gm.recompile() + return gm, new_graph_signature + + def _common_getitem_elimination_pass( gm: torch.fx.GraphModule, graph_signature, module_call_graph ): From 793f9c5684348d400c98b17eebdb1cc496d2e666 Mon Sep 17 00:00:00 2001 From: "Wu, Chunyuan" Date: Wed, 4 Sep 2024 17:12:14 +0000 Subject: [PATCH 0452/1018] [inductor] [cpp] generate reindexer for each epilogue_node (#134984) Fixes the FP32 accuracy failure of `levit_128` in timm. Previously, we used `Y` which is the output of the final epilogue node to calculate the reindexer. We actually need to use each epilogue node to calculate the reindexer from the GEMM output to the epilogue node. Pull Request resolved: https://github.com/pytorch/pytorch/pull/134984 Approved by: https://github.com/jgong5 --- test/inductor/test_cpu_select_algorithm.py | 92 ++++++++++++++++++++ torch/_inductor/codegen/cpp_gemm_template.py | 67 +++++++------- 2 files changed, 129 insertions(+), 30 deletions(-) diff --git a/test/inductor/test_cpu_select_algorithm.py b/test/inductor/test_cpu_select_algorithm.py index 0797e63d9aa974..6f3edb8d799bc2 100644 --- a/test/inductor/test_cpu_select_algorithm.py +++ b/test/inductor/test_cpu_select_algorithm.py @@ -487,6 +487,98 @@ def forward(self, x): vec_amx = VecAMX() self._check_amx_counter(vec_amx) + @inductor_config.patch({"freezing": True}) + @patches + @torch.no_grad + @unittest.skipIf(not TEST_MKL, "Test requires MKL") + @parametrize("batch_size", (8,)) + @parametrize("in_features", (128,)) + @parametrize("in_features_2", (196,)) + @parametrize("out_features", (256,)) + @parametrize( + "bias", + (True,), + ) + @dtypes(torch.float32) + def test_linear_with_multiple_reindexers( + self, + batch_size, + in_features, + in_features_2, + out_features, + bias, + dtype, + ): + flatten_BS = int(batch_size * in_features_2) + + # Reproducer from the levit_128 model in timm + class M(torch.nn.Module): + def __init__(self, bias): + super().__init__() + self.conv = torch.nn.Conv2d( + 64, + 128, + kernel_size=3, + padding=1, + stride=2, + dilation=1, + groups=1, + ) + self.linear = torch.nn.Linear(in_features, out_features, bias=False) + self._frozen_param221 = torch.randn(out_features) + self._frozen_param389 = torch.randn(out_features) + self._frozen_param20 = torch.randn(out_features) + self._frozen_param21 = torch.randn(out_features) + + def forward(self, view_368): + _mkl_linear_57 = self.linear(view_368) + view_369 = torch.ops.aten.reshape.default( + _mkl_linear_57, [batch_size, in_features_2, out_features] + ) + _mkl_linear_57 = None + + view_370 = torch.ops.aten.reshape.default( + view_369, [flatten_BS, out_features] + ) + view_369 = None + sub_85 = torch.ops.aten.sub.Tensor(view_370, self._frozen_param221) + view_370 = _frozen_param221 = None + mul_261 = torch.ops.aten.mul.Tensor(sub_85, self._frozen_param389) + sub_85 = _frozen_param389 = None + mul_262 = torch.ops.aten.mul.Tensor(mul_261, self._frozen_param20) + mul_261 = _frozen_param20 = None + add_219 = torch.ops.aten.add.Tensor(mul_262, self._frozen_param21) + mul_262 = _frozen_param21 = None + view_371 = torch.ops.aten.reshape.default( + add_219, [batch_size, in_features_2, out_features] + ) + add_219 = None + + add_220 = torch.ops.aten.add.Tensor(view_371, 3) + clamp_min_35 = torch.ops.aten.clamp_min.default(add_220, 0) + add_220 = None + clamp_max_35 = torch.ops.aten.clamp_max.default(clamp_min_35, 6) + clamp_min_35 = None + mul_263 = torch.ops.aten.mul.Tensor(view_371, clamp_max_35) + view_371 = clamp_max_35 = None + div_51 = torch.ops.aten.div.Tensor(mul_263, 6) + mul_263 = None + + return div_51 + + view_368 = torch.randn(flatten_BS, in_features) + + mod = M(bias=bias).eval() + with verify(dtype) as (atol, rtol): + self.common( + mod, + (view_368,), + atol=atol, + rtol=rtol, + ) + self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1) + self.assertEqual(counters["inductor"]["cpp_epilogue_fusion_counter"], 2) + @inductor_config.patch({"freezing": True}) @patches @torch.no_grad diff --git a/torch/_inductor/codegen/cpp_gemm_template.py b/torch/_inductor/codegen/cpp_gemm_template.py index 6558e5ab5c8793..9942185d9e2d5c 100644 --- a/torch/_inductor/codegen/cpp_gemm_template.py +++ b/torch/_inductor/codegen/cpp_gemm_template.py @@ -874,38 +874,45 @@ def copy_inner(index): reindexers.extend([None] * len(epilogue_nodes)) Y_2d = Y else: - # From template_buffer to Y_ordered (ordered by stride decreasingly, in dense format), for example: - # template_buffer: - # size (324, 512), stride (512, 1) - # Y_ordered (ordered by stride decreasingly, in dense format): - # size (1, 18, 18, 512), stride (165888, 9216, 512, 1) - stride_order = list( - ir.get_stride_order(V.graph.sizevars.size_hints(Y.get_stride())) - ) - fill_order = ir.stride_order2fill_order(stride_order) - reversed_fill_order = list(reversed(fill_order)) - size_with_stride_ordered_decreasingly = [ - Y.get_size()[i] for i in reversed_fill_order - ] - reshape_reindex = ir.View.dynamic_reshape_indexer( - size_with_stride_ordered_decreasingly, template_buffer.get_size() - ) - # From Y_ordered (ordered by stride decreasingly, in dense format) to Y, for example: - # Y_ordered (ordered by stride decreasingly, in dense format): - # size (1, 18, 18, 512), stride (165888, 9216, 512, 1) - # Y: - # size (1, 18, 18, 512), stride (165888, 1, 9216, 512) - from_stride_ordered_decreasingly_to_Y_order = [ - (len(stride_order) - 1) - stride_order[i] - for i in range(len(stride_order)) - ] - stride_reindex = ir.same_reorder( - from_stride_ordered_decreasingly_to_Y_order - ) + def get_reindexer(epilogue_node): + # From template_buffer to epilogue_node_ordered (ordered by stride decreasingly, in dense format), for example: + # template_buffer: + # size (324, 512), stride (512, 1) + # epilogue_node_ordered (ordered by stride decreasingly, in dense format): + # size (1, 18, 18, 512), stride (165888, 9216, 512, 1) + stride_order = list( + ir.get_stride_order( + V.graph.sizevars.size_hints(epilogue_node.get_stride()) + ) + ) + fill_order = ir.stride_order2fill_order(stride_order) + reversed_fill_order = list(reversed(fill_order)) + size_with_stride_ordered_decreasingly = [ + epilogue_node.get_size()[i] for i in reversed_fill_order + ] + reshape_reindex = ir.View.dynamic_reshape_indexer( + size_with_stride_ordered_decreasingly, + template_buffer.get_size(), + ) + + # From epilogue_node_ordered (ordered by stride decreasingly, in dense format) to epilogue_node, for example: + # epilogue_node_ordered (ordered by stride decreasingly, in dense format): + # size (1, 18, 18, 512), stride (165888, 9216, 512, 1) + # epilogue_node: + # size (1, 18, 18, 512), stride (165888, 1, 9216, 512) + from_stride_ordered_decreasingly_to_epilogue_node_order = [ + (len(stride_order) - 1) - stride_order[i] + for i in range(len(stride_order)) + ] + stride_reindex = ir.same_reorder( + from_stride_ordered_decreasingly_to_epilogue_node_order + ) + + reindexer = ir.fuse_reindexing(stride_reindex, reshape_reindex) + return reindexer - reindexer = ir.fuse_reindexing(stride_reindex, reshape_reindex) - reindexers.extend([reindexer] * len(epilogue_nodes)) # type: ignore[list-item] + reindexers.extend([get_reindexer(epilogue_node) for epilogue_node in epilogue_nodes]) # type: ignore[list-item] if isinstance(Y, ir.BaseView): storage = ir.StorageBox(Y.unwrap_view()) else: From cfaabf178c339ebf0eb98b841703ebb109387021 Mon Sep 17 00:00:00 2001 From: Yan Zhiwei Date: Tue, 3 Sep 2024 05:38:52 +0000 Subject: [PATCH 0453/1018] [Intel GPU] Customized XPU behaviour in indexing, group norm (#134453) Pull Request resolved: https://github.com/pytorch/pytorch/pull/134453 Approved by: https://github.com/EikanWang, https://github.com/albanD ghstack dependencies: #133980 --- aten/src/ATen/native/TensorAdvancedIndexing.cpp | 8 ++++---- aten/src/ATen/native/group_norm.cpp | 3 +-- 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/aten/src/ATen/native/TensorAdvancedIndexing.cpp b/aten/src/ATen/native/TensorAdvancedIndexing.cpp index 699f9393be2af5..33d5d19d8b888a 100644 --- a/aten/src/ATen/native/TensorAdvancedIndexing.cpp +++ b/aten/src/ATen/native/TensorAdvancedIndexing.cpp @@ -590,9 +590,9 @@ AdvancedIndex::AdvancedIndex(const Tensor& src, TensorList indices_list) } } - // For CUDA/MPS tensors, force all index tensors to have the same striding to - // simplify the CUDA/MPS kernel. - if (indices.size() >= 2 && (this->src.device().type() == kCUDA || this->src.device().type() == kMPS)) { + // For CUDA/MPS/XPU tensors, force all index tensors to have the same striding to + // simplify the CUDA/MPS/XPU kernel. + if (indices.size() >= 2 && (this->src.device().type() == kCUDA || this->src.device().type() == kMPS || this->src.device().type() == kXPU)) { if (!all_strides_match(indices)) { for (auto & indice : indices) { indice = indice.contiguous(); @@ -1808,7 +1808,7 @@ void scatter_impl( if (index.numel() == 0) return; auto op = ReductionType::SUM; - bool deterministic = globalContext().deterministicAlgorithms() && self.device().type() == DeviceType::CUDA; + bool deterministic = globalContext().deterministicAlgorithms() && (self.device().type() == DeviceType::CUDA || self.device().type() == DeviceType::XPU); if (reduce.has_value()) { op = get_operator_enum(reduce.value(), use_new_options); diff --git a/aten/src/ATen/native/group_norm.cpp b/aten/src/ATen/native/group_norm.cpp index 22410e8690ce65..655b2c1e72173f 100644 --- a/aten/src/ATen/native/group_norm.cpp +++ b/aten/src/ATen/native/group_norm.cpp @@ -197,8 +197,7 @@ Tensor group_norm( const Tensor kEmpty; auto memory_format = input.suggest_memory_format(); - const auto& X = input.device().is_cpu() || input.device().is_xpu() ? - input.contiguous(memory_format) : input.contiguous(); + const auto& X = input.device().is_cpu() ? input.contiguous(memory_format) : input.contiguous(); const auto& gamma = weight.defined() ? weight.contiguous() : kEmpty; const auto& beta = bias.defined() ? bias.contiguous() : kEmpty; TORCH_CHECK(!gamma.defined() || gamma.sym_numel() == C); From d2bd02f195b430c6553a26d99a6daa7dca51b72d Mon Sep 17 00:00:00 2001 From: Pian Pawakapan Date: Thu, 5 Sep 2024 07:50:01 +0000 Subject: [PATCH 0454/1018] restore CSE'd node metadata in runtime asserts pass (#134516) Adds val, and optionally stack_trace & nn_module_stack metadata back to SymInt compute nodes that we CSE, with a hook on `graph.create_node()`. Not sure if there's other metadata we want to populate here? Pull Request resolved: https://github.com/pytorch/pytorch/pull/134516 Approved by: https://github.com/ezyang --- test/distributed/test_dynamo_distributed.py | 1 - torch/fx/passes/runtime_assert.py | 117 ++++++++++++++------ 2 files changed, 84 insertions(+), 34 deletions(-) diff --git a/test/distributed/test_dynamo_distributed.py b/test/distributed/test_dynamo_distributed.py index 71f9065e7d4638..5ba20a7c948d41 100644 --- a/test/distributed/test_dynamo_distributed.py +++ b/test/distributed/test_dynamo_distributed.py @@ -423,7 +423,6 @@ def forward(self, x, y): opt_model = torch.compile(dynamic=True)(model) opt_model(torch.randn(20, 512), torch.tensor([12, 13])) - @unittest.expectedFailure # https://github.com/pytorch/pytorch/issues/130534" @config.patch(optimize_ddp=True, capture_dynamic_output_shape_ops=True) def test_unbacked_symbol_splitting_no_binding(self): class Model(nn.Module): diff --git a/torch/fx/passes/runtime_assert.py b/torch/fx/passes/runtime_assert.py index e2a7c1bee3db1d..01803600d021ab 100644 --- a/torch/fx/passes/runtime_assert.py +++ b/torch/fx/passes/runtime_assert.py @@ -1,4 +1,5 @@ # mypy: allow-untyped-defs +import functools import logging import operator import sys @@ -92,6 +93,7 @@ def insert_deferred_runtime_asserts( # Import sympy locally import sympy + from torch._export.passes._node_metadata_hook import _set_node_metadata_hook from torch.fx.experimental.symbolic_shapes import ( _has_uninterpretable_sympy_function, CallMethodKey, @@ -145,6 +147,36 @@ def _is_intermediate_tensor_sym_call(node: fx.Node) -> bool: ) ) + # Figure out what key to use, val or example_value + val_key = "val" + for node in graph.nodes: + if "example_value" in node.meta: + val_key = "example_value" + break + elif "val" in node.meta: + break + + def _node_metadata_hook( + node: torch.fx.Node, + stack_trace: Optional[str] = None, + nn_module_stack: Optional[Dict[str, Any]] = None, + ) -> None: + fake_args = [ + _get_example_value(arg) if isinstance(arg, torch.fx.Node) else arg + for arg in node.args + ] + try: + node.meta[val_key] = node.target(*fake_args) # type: ignore[operator] + except NotImplementedError: + # This can happen when attempting to reify a symbol with an unsupported call_function node, + # e.g. with NestedTensors + sym_size.int via match_symbol(). + # This seems to be fine, as the node gets CSE'd and deleted later in favor of a SymInt graph input. + pass + if stack_trace is not None: + node.meta["stack_trace"] = stack_trace + if nn_module_stack is not None: + node.meta["nn_module_stack"] = nn_module_stack + # Track asserts/checks we've added added_asserts: Set[sympy.Expr] = set() constrained_unbacked_symbols: Set[sympy.Symbol] = set() @@ -211,16 +243,17 @@ def add_runtime_asserts(ras): else: # Convert the sympy expression into a sequence of FX # nodes - res = _sympy_interp(expr_to_proxy, ra.expr).node - graph.call_function( - torch.ops.aten._assert_scalar.default, - # TODO: use ra.msg here, but it's pretty - # useless right now - ( - res, - f"Runtime assertion failed for expression {ra.expr} on node '{res}'", - ), - ) + with _set_node_metadata_hook(gm, _node_metadata_hook): + res = _sympy_interp(expr_to_proxy, ra.expr).node + graph.call_function( + torch.ops.aten._assert_scalar.default, + # TODO: use ra.msg here, but it's pretty + # useless right now + ( + res, + f"Runtime assertion failed for expression {ra.expr} on node '{res}'", + ), + ) added_asserts.add(ra.expr) nodes = list(graph.nodes) @@ -247,7 +280,8 @@ def match_symbol(symint, cb): and isinstance(s := symint.node.expr, sympy.Symbol) and s not in expr_to_proxy ): - expr_to_proxy[s] = fx.Proxy(cb()) + with _set_node_metadata_hook(gm, _node_metadata_hook): + expr_to_proxy[s] = fx.Proxy(cb()) log.debug("expr_to_proxy[%s] = %s", s, expr_to_proxy[s]) match_symbol(example_value, lambda: node) @@ -331,7 +365,15 @@ def match_symbol(symint, cb): if _is_intermediate_tensor_sym_call( node ): # reify from input shapes - expr_to_proxy[sym_expr] = _sympy_interp(expr_to_proxy, sym_expr) # type: ignore[arg-type] + with _set_node_metadata_hook( + gm, + functools.partial( + _node_metadata_hook, + stack_trace=node.meta.get("stack_trace"), + nn_module_stack=node.meta.get("nn_module_stack"), + ), + ): + expr_to_proxy[sym_expr] = _sympy_interp(expr_to_proxy, sym_expr) # type: ignore[arg-type] # won't try DCE-ing tensor compute here hash_node = expr_to_proxy[sym_expr].node # type: ignore[arg-type] node.replace_all_uses_with(hash_node) @@ -436,7 +478,8 @@ def go(node, keypath): raise AssertionError(f"unrecognized keypath {keypath}") if s not in expr_to_proxy: - expr_to_proxy[s] = fx.Proxy(go(node, keypath)) + with _set_node_metadata_hook(gm, _node_metadata_hook): + expr_to_proxy[s] = fx.Proxy(go(node, keypath)) log.debug("expr_to_proxy[%s] = %s", s, expr_to_proxy[s]) for i0 in defs: @@ -519,26 +562,34 @@ def convert(s): # TODO(pianpwk): calling sym_constrain_range_for_size or adding bound asserts # raises AOTAutograd errors on cast_symbool_to_symint_guardless - if (min_val := convert(vr.lower)) is not None: - ge = _sympy_interp(expr_to_proxy, i0 >= min_val).node - graph.call_function( - torch.ops.aten._assert_scalar.default, - ( - ge, - f"Runtime assertion failed for expression {i0 >= min_val} on node '{ge}'", - ), - ) - added_asserts.add(i0 >= min_val) - if (max_val := convert(vr.upper)) is not None: - le = _sympy_interp(expr_to_proxy, i0 <= max_val).node - graph.call_function( - torch.ops.aten._assert_scalar.default, - ( - le, - f"Runtime assertion failed for expression {i0 <= max_val} on node '{le}'", - ), - ) - added_asserts.add(i0 <= max_val) + with _set_node_metadata_hook( + gm, + functools.partial( + _node_metadata_hook, + stack_trace=node.meta.get("stack_trace"), + nn_module_stack=node.meta.get("nn_module_stack"), + ), + ): + if (min_val := convert(vr.lower)) is not None: + ge = _sympy_interp(expr_to_proxy, i0 >= min_val).node + graph.call_function( + torch.ops.aten._assert_scalar.default, + ( + ge, + f"Runtime assertion failed for expression {i0 >= min_val} on node '{ge}'", + ), + ) + added_asserts.add(i0 >= min_val) + if (max_val := convert(vr.upper)) is not None: + le = _sympy_interp(expr_to_proxy, i0 <= max_val).node + graph.call_function( + torch.ops.aten._assert_scalar.default, + ( + le, + f"Runtime assertion failed for expression {i0 <= max_val} on node '{le}'", + ), + ) + added_asserts.add(i0 <= max_val) constrained_unbacked_symbols.add(i0) add_runtime_asserts(ras) From d9690ef3e6268dff5da7e08bf6a35bdcc2545642 Mon Sep 17 00:00:00 2001 From: fduwjj Date: Wed, 4 Sep 2024 22:26:01 -0700 Subject: [PATCH 0455/1018] [c10d] Change collective to take in a list of tensors so it work fully for all collectives (#135049) We found that currently, we only pass one input and output tensor to the function `collective`, and this causes NaNCheck, work numel stats and FR input/output sizes not accurate for all-to-all, scatter and reduce. So we want to let the collective take in a list of tensors to ensure it works for all collectives inside PGNCCL. This partially revert what we did in https://github.com/pytorch/pytorch/pull/119421, and down the road we will have another round of cleanup on the collective to make it cleaner. For now, at least for the sake of correctness, we changed it back. Pull Request resolved: https://github.com/pytorch/pytorch/pull/135049 Approved by: https://github.com/kwen2501 --- test/distributed/test_c10d_nccl.py | 2 +- torch/csrc/distributed/c10d/NCCLUtils.hpp | 8 +- .../distributed/c10d/ProcessGroupNCCL.cpp | 120 ++++++++++++------ .../distributed/c10d/ProcessGroupNCCL.hpp | 12 ++ 4 files changed, 101 insertions(+), 41 deletions(-) diff --git a/test/distributed/test_c10d_nccl.py b/test/distributed/test_c10d_nccl.py index 36c3e28e16065a..c12a2e733c1339 100644 --- a/test/distributed/test_c10d_nccl.py +++ b/test/distributed/test_c10d_nccl.py @@ -3598,7 +3598,7 @@ def started_or_scheduled(self, timing_enabled): class NCCLTraceTest(NCCLTraceTestBase): def _verify_trace(self, t, include_collectives, timing_enabled, is_json): ver = t["version"] - self.assertEqual(ver, "2.3") + self.assertEqual(ver, "2.4") pg_config = t["pg_config"] self.assertEqual(len(pg_config), 1) default_pg_info = pg_config["0"] diff --git a/torch/csrc/distributed/c10d/NCCLUtils.hpp b/torch/csrc/distributed/c10d/NCCLUtils.hpp index 4b0a197dbd536a..070cbd34b37970 100644 --- a/torch/csrc/distributed/c10d/NCCLUtils.hpp +++ b/torch/csrc/distributed/c10d/NCCLUtils.hpp @@ -176,14 +176,14 @@ namespace c10d { #define DEFINE_CONSTANT(name, value) \ static c10::IValue name = value; \ static std::string name##_str = value; -DEFINE_CONSTANT(entries_key, "entries"); -DEFINE_CONSTANT(nccl_comm_key, "nccl_comm_state"); -DEFINE_CONSTANT(version_key, "version"); // Update whenever changing contents or formatting of the dump // (minor when adding fields, major when changing existing fields) // Also update both JSON and Pickle dumps to make use of the newly defined // field(s). -DEFINE_CONSTANT(version_val, "2.3"); +DEFINE_CONSTANT(version_val, "2.4"); +DEFINE_CONSTANT(entries_key, "entries"); +DEFINE_CONSTANT(nccl_comm_key, "nccl_comm_state"); +DEFINE_CONSTANT(version_key, "version"); DEFINE_CONSTANT(pg_config_key, "pg_config"); DEFINE_CONSTANT(pg_status_key, "pg_status"); DEFINE_CONSTANT(record_id_key, "record_id"); diff --git a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp index ebf92808e801e4..3c0f237c5774d5 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp +++ b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp @@ -2631,8 +2631,8 @@ c10::intrusive_ptr ProcessGroupNCCL::endCoalescing() { template c10::intrusive_ptr ProcessGroupNCCL::collective( - at::Tensor& input, - at::Tensor& output, + std::vector& inputs, + std::vector& outputs, Fn fn, PreProcess pre, PostProcess post, @@ -2652,7 +2652,7 @@ c10::intrusive_ptr ProcessGroupNCCL::collective( seqCollective_++; op_id_++; - auto device = getDevice(input); + auto device = getDevice(inputs[0]); const auto key = getKeyFromDevice(device); auto ncclComm = getNCCLComm(key, device, opType); @@ -2677,28 +2677,25 @@ c10::intrusive_ptr ProcessGroupNCCL::collective( // First let NCCL streams wait for input tensors allocation streams syncStream(device, ncclEvents_[key], ncclStream); - std::vector inputs{input}; - std::vector outputs{output}; - bool enqueue = !coalescing_state_ && capture_status == c10::cuda::CaptureStatus::None; auto work = initWork(device, rank_, opType, profilingTitle, inputs, outputs, enqueue); // Store references to outputs to be used by WorkNCCL::result and operator<<. - work->outputs_ = - std::make_shared>(std::move(outputs)); + work->outputs_ = std::make_shared>(outputs); if (avoidRecordStreams) { work->stashed_for_allocator_safety_ = - std::make_shared>(); - work->stashed_for_allocator_safety_->push_back(input); + std::make_shared>(inputs); } at::cuda::OptionalCUDAGuard gpuGuard(device); if (nanCheck) { - checkForNan(input, ncclStream); + for (const auto& input : inputs) { + checkForNan(input, ncclStream); + } } // Start event should only be recorded before the ncclGroupStart() @@ -2719,25 +2716,35 @@ c10::intrusive_ptr ProcessGroupNCCL::collective( // // See [Sync Streams]. if (!avoidRecordStreams) { - if (!input.is_sparse()) { - c10::cuda::CUDACachingAllocator::recordStream( - input.storage().data_ptr(), ncclStream); - } else { - // for sparse input case record streams on both index and value - // tensors - c10::cuda::CUDACachingAllocator::recordStream( - input.values().storage().data_ptr(), ncclStream); - c10::cuda::CUDACachingAllocator::recordStream( - input.indices().storage().data_ptr(), ncclStream); + for (const auto& input : inputs) { + if (!input.is_sparse()) { + c10::cuda::CUDACachingAllocator::recordStream( + input.storage().data_ptr(), ncclStream); + } else { + // for sparse input case record streams on both index and value + // tensors + c10::cuda::CUDACachingAllocator::recordStream( + input.values().storage().data_ptr(), ncclStream); + c10::cuda::CUDACachingAllocator::recordStream( + input.indices().storage().data_ptr(), ncclStream); + } } } + +// Not all collectives have the same signature, e.g, all-reduce take in a Tensor +// as the input and output while all-to-all take in a vector of Tensors as input +// and output. Because we define the signature of the fn to take only single +// tensor as input and output, we need to do a hack to get the first element in +// the vector and pass it to fn. +// TODO: we should clean up this in future (by either entirely removing lambda's +// or removing input and output from lambda's signature). #ifndef NCCL_HAS_COMM_NONBLOCKING C10D_NCCL_CHECK( - fn(input, output, comm, ncclStream), + fn(inputs[0], outputs[0], comm, ncclStream), ncclComm->getNcclCommFailureReason()); #else C10D_NCCL_CHECK_TIMEOUT( - fn(input, output, comm, ncclStream), + fn(inputs[0], outputs[0], comm, ncclStream), comm, ncclComm->getNcclCommFailureReason()); #endif @@ -2779,8 +2786,14 @@ c10::intrusive_ptr ProcessGroupNCCL::collective( assignTimeoutToWork(work, options_); // Record size info for debug. We only record the size on the first device as // multi-device per process is deprecated - work->numelIn_ = input.numel(); - work->numelOut_ = output.numel(); + work->numelIn_ = 0; + work->numelOut_ = 0; + for (const auto& input : inputs) { + work->numelIn_ += input.numel(); + } + for (const auto& output : outputs) { + work->numelOut_ += output.numel(); + } // Notify graphs before we check the capture status preemptively at::cuda::CUDAGraph::inc_pending_event_queries(); @@ -3222,6 +3235,31 @@ c10::intrusive_ptr ProcessGroupNCCL::pointToPoint( } } +template +c10::intrusive_ptr ProcessGroupNCCL::collective( + at::Tensor& input, + at::Tensor& output, + Fn fn, + PreProcess pre, + PostProcess post, + OpType opType, + const char* profilingTitle, + bool avoidRecordStreams, + bool nanCheck) { + auto inputs = std::vector{input}; + auto outputs = std::vector{output}; + return collective( + inputs, + outputs, + fn, + pre, + post, + opType, + profilingTitle, + avoidRecordStreams, + nanCheck); +} + template c10::intrusive_ptr ProcessGroupNCCL::collective( at::Tensor& input, @@ -3231,9 +3269,11 @@ c10::intrusive_ptr ProcessGroupNCCL::collective( const char* profilingTitle, bool avoidRecordStreams, bool nanCheck) { + auto inputs = std::vector{input}; + auto outputs = std::vector{output}; return collective( - input, - output, + inputs, + outputs, fn, [](at::cuda::CUDAStream&, c10::intrusive_ptr& work) {}, @@ -3532,10 +3572,8 @@ c10::intrusive_ptr ProcessGroupNCCL::_broadcast_oop( ValueError, "Tensor input and output of _broadcast_oop must have the same number of elements "); } - const auto root = opts.rootRank + opts.rootTensor; bool nanCheck = (root == rank_); - return collective( inputTensor, outputTensor, @@ -3632,7 +3670,6 @@ c10::intrusive_ptr ProcessGroupNCCL::_reduce_oop( ValueError, "Tensor input and output of _reduce_oop must have the same number of elements "); } - return collective( inputTensor, outputTensor, @@ -4225,8 +4262,8 @@ c10::intrusive_ptr ProcessGroupNCCL::alltoall( this->getSize()); // worldSize return collective( - inputTensors[0], - outputTensors[0], + inputTensors, + outputTensors, [&](at::Tensor& /* unused */, at::Tensor& /* unused */, ncclComm_t comm, @@ -4419,9 +4456,11 @@ c10::intrusive_ptr ProcessGroupNCCL::gather( // avoidRecordStreams_ note: collective() will stash inputTensors and // outputs, which == outputTensors[0] on the root rank where it matters. + + auto inputs = std::vector{inputTensor}; return collective( - inputTensor, - outputs[0], // just to fit the collective interface + inputs, + outputs, // just to fit the collective interface [&](at::Tensor& /* unused */, at::Tensor& /* unused */, ncclComm_t comm, @@ -4438,6 +4477,10 @@ c10::intrusive_ptr ProcessGroupNCCL::gather( torch::cuda::nccl::gather(inputTensor, outputs, comm, stream, root); return ncclSuccess; }, + [](at::cuda::CUDAStream&, + c10::intrusive_ptr& work) {}, + [](at::cuda::CUDAStream&, + c10::intrusive_ptr& work) {}, OpType::GATHER, "nccl:gather"); } @@ -4511,9 +4554,10 @@ c10::intrusive_ptr ProcessGroupNCCL::scatter( const auto root = opts.rootRank; bool nanCheck = (rank_ == root); + auto outputs = std::vector{outputTensor}; return collective( - outputTensor, - inputs[0], // just to fit the collective interface + outputs, + inputs, // just to fit the collective interface [&](at::Tensor& /* unused */, at::Tensor& /* unused */, ncclComm_t comm, @@ -4529,6 +4573,10 @@ c10::intrusive_ptr ProcessGroupNCCL::scatter( torch::cuda::nccl::scatter(inputs, outputTensor, comm, stream, root); return ncclSuccess; }, + [](at::cuda::CUDAStream&, + c10::intrusive_ptr& work) {}, + [](at::cuda::CUDAStream&, + c10::intrusive_ptr& work) {}, OpType::SCATTER, "nccl:scatter", avoidRecordStreams, diff --git a/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp b/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp index e099a75de58879..cc3455e7237904 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp +++ b/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp @@ -777,6 +777,18 @@ class TORCH_API ProcessGroupNCCL : public Backend { bool avoidRecordStreams = false, bool nanCheck = true); + template + c10::intrusive_ptr collective( + std::vector& inputs, + std::vector& outputs, + Fn fn, + PreProcess pre, + PostProcess post, + OpType opType, + const char* profilingTitle = nullptr, + bool avoidRecordStreams = false, + bool nanCheck = true); + template c10::intrusive_ptr collectiveCoalesced( std::vector& input, From a322e0dd233cabbf46493b1fe2a7a15d0b0c6fa5 Mon Sep 17 00:00:00 2001 From: "Wu, Chunyuan" Date: Wed, 4 Sep 2024 17:12:15 +0000 Subject: [PATCH 0456/1018] [inductor] [cpp] use_local_acc if template_buffer_has_other_users (#135081) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Fix the compilation error of `coat_lite_mini` in timm and `YituTechConvBert` in HF: ``` /tmp/tmpuu94adg_/nf/cnf3zm677wbfjzzll522zvjp57g44udzfnj66ac2t5b2odvfqpts.cpp:239:33: error: invalid conversion from ‘const float*’ to ‘float*’ [-fpermissive] 239 | &(in_ptr2[static_cast(n_start + (192L*m_start) + (Nr*nci) + ((-1L)*Nr*nc))]), | ^~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ | | | const float* ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/135081 Approved by: https://github.com/jgong5 ghstack dependencies: #134984 --- test/inductor/test_cpu_select_algorithm.py | 290 +++++++++++++++++++ torch/_inductor/codegen/cpp_gemm_template.py | 1 + 2 files changed, 291 insertions(+) diff --git a/test/inductor/test_cpu_select_algorithm.py b/test/inductor/test_cpu_select_algorithm.py index 6f3edb8d799bc2..6cebd99704da2f 100644 --- a/test/inductor/test_cpu_select_algorithm.py +++ b/test/inductor/test_cpu_select_algorithm.py @@ -666,6 +666,296 @@ def forward(self, view_12, input_ids, view_9): self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1) self.assertEqual(counters["inductor"]["cpp_epilogue_fusion_counter"], 1) + @inductor_config.patch({"freezing": True}) + @patches + @torch.no_grad + @parametrize("batch_size", (8,)) + @parametrize("in_features", (3,)) + @parametrize("in_features2", (192,)) + @parametrize("image_size", (224,)) + @parametrize("out_features", (64,)) + @parametrize( + "bias", + (True,), + ) + @dtypes(torch.float32) + def test_linear_with_in_out_buffer( + self, + batch_size, + in_features, + in_features2, + image_size, + out_features, + bias, + dtype, + ): + # Reproducer from the coat_lite_mini model in timm + class M(torch.nn.Module): + def __init__(self, bias): + super().__init__() + self._frozen_param398 = torch.randn(batch_size, out_features, 1, 1) + self.conv = torch.nn.Conv2d( + in_features, + out_features, + kernel_size=4, + padding=0, + stride=4, + dilation=1, + groups=1, + ) + self.conv2 = torch.nn.Conv2d( + out_features, + out_features, + kernel_size=3, + padding=1, + stride=1, + dilation=1, + groups=out_features, + ) + + self.conv3 = torch.nn.Conv2d( + 16, + 16, + kernel_size=3, + padding=1, + stride=1, + dilation=1, + groups=16, + ) + + self.conv4 = torch.nn.Conv2d( + 24, + 24, + kernel_size=5, + padding=2, + stride=1, + dilation=1, + groups=24, + ) + + self.conv5 = torch.nn.Conv2d( + 24, + 24, + kernel_size=7, + padding=3, + stride=1, + dilation=1, + groups=24, + ) + + self.linear = torch.nn.Linear(out_features, in_features2, bias) + + self.linear2 = torch.nn.Linear(out_features, out_features, bias) + self._frozen_param2 = torch.randn(out_features) + self._frozen_param3 = torch.randn(out_features) + self._frozen_param7 = torch.randn(out_features) + self._frozen_param8 = torch.randn(out_features) + self._frozen_param153 = torch.randn(batch_size, 1, out_features) + + def forward(self, arg152_1): + _convolution_pointwise_default_35 = self.conv(arg152_1) + arg152_1 = None + + view_168 = torch.ops.aten.reshape.default( + _convolution_pointwise_default_35, [8, 64, 3136] + ) + _convolution_pointwise_default_35 = None + permute_97 = torch.ops.aten.permute.default(view_168, [0, 2, 1]) + view_168 = None + clone_65 = torch.ops.aten.clone.default( + permute_97, memory_format=torch.contiguous_format + ) + permute_97 = None + var_mean_21 = torch.ops.aten.var_mean.correction( + clone_65, [2], correction=0, keepdim=True + ) + getitem_90 = var_mean_21[0] + getitem_91 = var_mean_21[1] + var_mean_21 = None + add_82 = torch.ops.aten.add.Tensor(getitem_90, 1e-05) + getitem_90 = None + rsqrt_21 = torch.ops.aten.rsqrt.default(add_82) + add_82 = None + sub_29 = torch.ops.aten.sub.Tensor(clone_65, getitem_91) + clone_65 = getitem_91 = None + mul_82 = torch.ops.aten.mul.Tensor(sub_29, rsqrt_21) + sub_29 = rsqrt_21 = None + mul_83 = torch.ops.aten.mul.Tensor(mul_82, self._frozen_param2) + mul_82 = None + add_83 = torch.ops.aten.add.Tensor(mul_83, self._frozen_param3) + mul_83 = None + _frozen_param153 = self._frozen_param153 + cat_20 = torch.ops.aten.cat.default([_frozen_param153, add_83], 1) + _frozen_param153 = add_83 = None + slice_111 = torch.ops.aten.slice.Tensor(cat_20, 1, 0, 1) + slice_113 = torch.ops.aten.slice.Tensor( + cat_20, 1, 1, 9223372036854775807 + ) + cat_20 = None + permute_98 = torch.ops.aten.permute.default(slice_113, [0, 2, 1]) + slice_113 = None + view_169 = torch.ops.aten.reshape.default(permute_98, [8, 64, 56, 56]) + permute_98 = None + _convolution_pointwise_default_34 = self.conv2(view_169) + + add_84 = torch.ops.aten.add.Tensor( + _convolution_pointwise_default_34, view_169 + ) + _convolution_pointwise_default_34 = view_169 = None + view_170 = torch.ops.aten.reshape.default(add_84, [8, 64, 3136]) + add_84 = None + permute_99 = torch.ops.aten.permute.default(view_170, [0, 2, 1]) + view_170 = None + cat_21 = torch.ops.aten.cat.default([slice_111, permute_99], 1) + slice_111 = permute_99 = None + var_mean_22 = torch.ops.aten.var_mean.correction( + cat_21, [2], correction=0, keepdim=True + ) + getitem_92 = var_mean_22[0] + getitem_93 = var_mean_22[1] + var_mean_22 = None + add_85 = torch.ops.aten.add.Tensor(getitem_92, 1e-06) + getitem_92 = None + rsqrt_22 = torch.ops.aten.rsqrt.default(add_85) + add_85 = None + sub_30 = torch.ops.aten.sub.Tensor(cat_21, getitem_93) + getitem_93 = None + mul_84 = torch.ops.aten.mul.Tensor(sub_30, rsqrt_22) + sub_30 = rsqrt_22 = None + mul_85 = torch.ops.aten.mul.Tensor(mul_84, self._frozen_param7) + mul_84 = None + add_86 = torch.ops.aten.add.Tensor(mul_85, self._frozen_param8) + mul_85 = None + view_171 = torch.ops.aten.reshape.default(add_86, [25096, 64]) + add_86 = None + + _mkl_linear_32 = self.linear(view_171) + view_171 = None + + view_172 = torch.ops.aten.reshape.default( + _mkl_linear_32, [8, 3137, 192] + ) + _mkl_linear_32 = None + view_173 = torch.ops.aten.reshape.default(view_172, [8, 3137, 3, 8, 8]) + view_172 = None + permute_101 = torch.ops.aten.permute.default(view_173, [2, 0, 3, 1, 4]) + view_173 = None + unbind_8 = torch.ops.aten.unbind.int(permute_101) + permute_101 = None + getitem_94 = unbind_8[0] + getitem_95 = unbind_8[1] + getitem_96 = unbind_8[2] + unbind_8 = None + clone_66 = torch.ops.aten.clone.default( + getitem_95, memory_format=torch.contiguous_format + ) + getitem_95 = None + amax_8 = torch.ops.aten.amax.default(clone_66, [2], True) + sub_31 = torch.ops.aten.sub.Tensor(clone_66, amax_8) + clone_66 = amax_8 = None + exp_8 = torch.ops.aten.exp.default(sub_31) + sub_31 = None + sum_9 = torch.ops.aten.sum.dim_IntList(exp_8, [2], True) + div_8 = torch.ops.aten.div.Tensor(exp_8, sum_9) + exp_8 = sum_9 = None + permute_102 = torch.ops.aten.permute.default(div_8, [0, 1, 3, 2]) + div_8 = None + expand_37 = torch.ops.aten.expand.default(permute_102, [8, 8, 8, 3137]) + permute_102 = None + view_174 = torch.ops.aten.reshape.default(expand_37, [64, 8, 3137]) + expand_37 = None + expand_38 = torch.ops.aten.expand.default(getitem_96, [8, 8, 3137, 8]) + clone_67 = torch.ops.aten.clone.default( + expand_38, memory_format=torch.contiguous_format + ) + expand_38 = None + view_175 = torch.ops.aten.reshape.default(clone_67, [64, 3137, 8]) + clone_67 = None + bmm_16 = torch.ops.aten.bmm.default(view_174, view_175) + view_174 = view_175 = None + view_176 = torch.ops.aten.reshape.default(bmm_16, [8, 8, 8, 8]) + bmm_16 = None + expand_39 = torch.ops.aten.expand.default(getitem_94, [8, 8, 3137, 8]) + clone_68 = torch.ops.aten.clone.default( + expand_39, memory_format=torch.contiguous_format + ) + expand_39 = None + view_177 = torch.ops.aten.reshape.default(clone_68, [64, 3137, 8]) + clone_68 = None + expand_40 = torch.ops.aten.expand.default(view_176, [8, 8, 8, 8]) + view_176 = None + view_178 = torch.ops.aten.reshape.default(expand_40, [64, 8, 8]) + expand_40 = None + bmm_17 = torch.ops.aten.bmm.default(view_177, view_178) + view_177 = view_178 = None + view_179 = torch.ops.aten.reshape.default(bmm_17, [8, 8, 3137, 8]) + bmm_17 = None + slice_116 = torch.ops.aten.slice.Tensor( + getitem_94, 2, 1, 9223372036854775807 + ) + getitem_94 = None + slice_120 = torch.ops.aten.slice.Tensor( + getitem_96, 2, 1, 9223372036854775807 + ) + getitem_96 = None + permute_103 = torch.ops.aten.permute.default(slice_120, [0, 1, 3, 2]) + slice_120 = None + view_180 = torch.ops.aten.reshape.default(permute_103, [8, 64, 56, 56]) + permute_103 = None + split_with_sizes_8 = torch.ops.aten.split_with_sizes.default( + view_180, [16, 24, 24], 1 + ) + view_180 = None + getitem_97 = split_with_sizes_8[0] + getitem_98 = split_with_sizes_8[1] + getitem_99 = split_with_sizes_8[2] + split_with_sizes_8 = None + + _convolution_pointwise_default_33 = self.conv3(getitem_97) + _convolution_pointwise_default_32 = self.conv4(getitem_98) + _convolution_pointwise_default_31 = self.conv5(getitem_99) + + cat_22 = torch.ops.aten.cat.default( + [ + _convolution_pointwise_default_33, + _convolution_pointwise_default_32, + _convolution_pointwise_default_31, + ], + 1, + ) + _convolution_pointwise_default_33 = ( + _convolution_pointwise_default_32 + ) = _convolution_pointwise_default_31 = None + view_181 = torch.ops.aten.reshape.default(cat_22, [8, 8, 8, 3136]) + cat_22 = None + permute_104 = torch.ops.aten.permute.default(view_181, [0, 1, 3, 2]) + view_181 = None + + mul_86 = torch.ops.aten.mul.Tensor(slice_116, permute_104) + slice_116 = permute_104 = None + constant_pad_nd_8 = torch.ops.aten.constant_pad_nd.default( + mul_86, [0, 0, 1, 0, 0, 0], 0.0 + ) + mul_86 = None + mul_87 = torch.ops.aten.mul.Tensor(view_179, 0.3535533905932738) + view_179 = None + add_87 = torch.ops.aten.add.Tensor(mul_87, constant_pad_nd_8) + mul_87 = constant_pad_nd_8 = None + return add_87 + + view_12 = torch.randn(batch_size, in_features, image_size, image_size) + + mod = M(bias=bias).eval() + with verify(dtype) as (atol, rtol): + self.common( + mod, + (view_12,), + atol=atol, + rtol=rtol, + ) + self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1) + self.assertEqual(counters["inductor"]["cpp_epilogue_fusion_counter"], 1) + @inductor_config.patch({"freezing": True}) @patches @torch.no_grad diff --git a/torch/_inductor/codegen/cpp_gemm_template.py b/torch/_inductor/codegen/cpp_gemm_template.py index 9942185d9e2d5c..a0856a15973349 100644 --- a/torch/_inductor/codegen/cpp_gemm_template.py +++ b/torch/_inductor/codegen/cpp_gemm_template.py @@ -759,6 +759,7 @@ def render( # type: ignore[override,return] use_local_acc = ( self.layout.dtype != torch.float + or template_buffer_has_other_users or int8_gemm or self.padded_n != self.n or self.maybe_k_slicing() From 641703099b49a4647fee2a4443bbebea748f2b90 Mon Sep 17 00:00:00 2001 From: Yutao Xu Date: Thu, 5 Sep 2024 08:36:35 +0000 Subject: [PATCH 0457/1018] Enhance the stability of the complex divide code (#134647) In C++, when a floating-point literal (e.g., 3.14) is compared with a variable of type float, the literal is by default interpreted as a double. ```c++ float f = 3.14f; if (f == 3.14) { // Do something } ``` If a device does not support double, an error will occur. This PR addresses the issue of complex64 errors on machines that do not support double operations. Pull Request resolved: https://github.com/pytorch/pytorch/pull/134647 Approved by: https://github.com/EikanWang, https://github.com/albanD --- c10/util/complex.h | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/c10/util/complex.h b/c10/util/complex.h index af810a780dd547..c08e10aa0f2933 100644 --- a/c10/util/complex.h +++ b/c10/util/complex.h @@ -261,19 +261,19 @@ struct alignas(sizeof(T) * 2) complex { #endif if (abs_c >= abs_d) { - if (abs_c == 0 && abs_d == 0) { + if (abs_c == U(0) && abs_d == U(0)) { /* divide by zeros should yield a complex inf or nan */ real_ = a / abs_c; imag_ = b / abs_d; } else { auto rat = d / c; - auto scl = 1.0 / (c + d * rat); + auto scl = U(1.0) / (c + d * rat); real_ = (a + b * rat) * scl; imag_ = (b - a * rat) * scl; } } else { auto rat = c / d; - auto scl = 1.0 / (d + c * rat); + auto scl = U(1.0) / (d + c * rat); real_ = (a * rat + b) * scl; imag_ = (b * rat - a) * scl; } From 13adbe3ff5fa0c685c51c1072e44fe9c5d7d5706 Mon Sep 17 00:00:00 2001 From: Jack Zhang Date: Thu, 5 Sep 2024 09:01:06 +0000 Subject: [PATCH 0458/1018] Update unbacked symints in masked_select more precisely (#134899) ## Summary At the moment, the fake impl for `masked_select` simply sets the upper range while updating its size-like SymInt to `sys.maxsize`(9223372036854775807, max value for an unsigned int64) if the there are any SymInts in the original input tensor shape. This PR constrains the range more intelligently by using the upper ranges of each SymInt in the input tensor shape. This solves an issue where an model being lowered to Executorch errors during memory planning because the memory allocated for `masked_select` ended up exceeded the 64-bit address space (`INT_MAX * size(dtype)`). ## Test plan - Passes existing unit tests (tests case where upper bound is inf) - Added unit test to verify upper bound reduction calculation - Tested end-to-end by exporting with TORCH_LOGS="export" and ensuring that the range for `masked_select`'s SymInt size has the correct upper bound Pull Request resolved: https://github.com/pytorch/pytorch/pull/134899 Approved by: https://github.com/ezyang --- test/export/test_export.py | 32 ++++++++++++++++++++++++++++++++ torch/_subclasses/fake_impls.py | 18 ++++++++++++++++-- 2 files changed, 48 insertions(+), 2 deletions(-) diff --git a/test/export/test_export.py b/test/export/test_export.py index cd4a9aeb9e590f..fca1605716f24c 100644 --- a/test/export/test_export.py +++ b/test/export/test_export.py @@ -520,6 +520,38 @@ def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: args = (torch.randn(15, 3, 256, 256), torch.ones(15, 32, 256, 256)) self.assertEqual(gm(*args), m(*args)) + def test_masked_select_dynamic(self): + class M(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + mask = x.ge(0.5) + return torch.masked_select(x, mask) + + example_args = (torch.randn(3, 4, 5),) + dim0_x_max, dim1_x_max = 100, 7 + dynamic_shapes = { + "x": { + 0: Dim("dim0_x", max=dim0_x_max), + 1: Dim("dim1_x_max", max=dim1_x_max), + } + } + m = M() + exported_program: torch.export.ExportedProgram = export( + m, args=example_args, dynamic_shapes=dynamic_shapes + ) + + # Test that the expected upper bound is among the range constraints. + expected_upper_bound = dim0_x_max * dim1_x_max * 5 + vr_upper_bounds = [ + vr.upper for vr in exported_program.range_constraints.values() + ] + self.assertTrue(expected_upper_bound in set(vr_upper_bounds)) + # Test that none of the upper bounds are larger. + for vr_upper in vr_upper_bounds: + self.assertTrue(vr_upper <= expected_upper_bound) + def test_setgrad_lifted_tensor(self): class M(torch.nn.Module): def forward(self, x, y): diff --git a/torch/_subclasses/fake_impls.py b/torch/_subclasses/fake_impls.py index c8acc90567c5eb..1ba37b6f06bddf 100644 --- a/torch/_subclasses/fake_impls.py +++ b/torch/_subclasses/fake_impls.py @@ -2,6 +2,7 @@ import functools import itertools +import math import sys from typing import Callable, Union @@ -455,10 +456,23 @@ def masked_select(fake_mode, func, self, mask): _constrain_range_for_size, has_free_symbols, ) + from torch.utils._sympy.numbers import IntInfinity + from torch.utils._sympy.value_ranges import bound_sympy + # If num elements is expressed symbolically, calculate + # the concrete value based on upper bounds. Otherwise, + # we can set max val directly. if not has_free_symbols(self.numel()): - if self.numel() > 2: - maxval = int(self.numel()) + num_elements = int(self.numel()) + else: + prod_node = math.prod(self.shape).node + prod_range = bound_sympy(prod_node.expr, prod_node.shape_env.var_to_range) + if isinstance(prod_range.upper, IntInfinity): + num_elements = sys.maxsize - 1 + else: + num_elements = prod_range.upper + if num_elements > 2: + maxval = num_elements _constrain_range_for_size(nnz, max=maxval) From b4a41dca9bd0a2472ba76280cfcca921a1944922 Mon Sep 17 00:00:00 2001 From: Feng Yuan Date: Thu, 5 Sep 2024 10:05:23 +0000 Subject: [PATCH 0459/1018] Update torch-xpu-ops pin (ATen XPU implementation) (#135185) Release cycle for PyTorch 2.5 1. Update specific AOT targets for Windows. On Windows, AOT target list prefers Intel client GPUs. Pull Request resolved: https://github.com/pytorch/pytorch/pull/135185 Approved by: https://github.com/EikanWang --- third_party/xpu.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/third_party/xpu.txt b/third_party/xpu.txt index e71ba25bc5f08f..42f3d8ca513447 100644 --- a/third_party/xpu.txt +++ b/third_party/xpu.txt @@ -1 +1 @@ -78e33d0afc3376a45e6184132f1e9fc154db3cae +17f4bd5b7ad63de408cdacee961d03c51cc75ab4 From 45c83e142ddb8b8e96192b35997303c692aadc3b Mon Sep 17 00:00:00 2001 From: Bin Bao Date: Wed, 4 Sep 2024 14:24:12 -0700 Subject: [PATCH 0460/1018] [AOTI] Fix a unbacked symint retrieve bug (#134670) Summary: Fix https://github.com/pytorch/pytorch/issues/134081. When a unbacked symint is computed as the shape of a tensor from a tuple, generated C++ code needs to use std::get<> to extract the tensor. Differential Revision: [D62142113](https://our.internmc.facebook.com/intern/diff/D62142113) Pull Request resolved: https://github.com/pytorch/pytorch/pull/134670 Approved by: https://github.com/angelayi, https://github.com/22quinn, https://github.com/chenyang78 --- test/inductor/test_aot_inductor.py | 15 +++++++++++++++ torch/_inductor/codegen/cpp_wrapper_cuda.py | 2 +- torch/_inductor/ir.py | 6 +++++- 3 files changed, 21 insertions(+), 2 deletions(-) diff --git a/test/inductor/test_aot_inductor.py b/test/inductor/test_aot_inductor.py index 784c52a50dd990..a39ac7347424ff 100644 --- a/test/inductor/test_aot_inductor.py +++ b/test/inductor/test_aot_inductor.py @@ -3421,6 +3421,20 @@ def forward(self, x, y): count, ).run(code) + def test_size_from_multi_output(self): + class Model(torch.nn.Module): + def __init__(self): + super().__init__() + self.relu = torch.nn.ReLU() + + def forward(self, x): + _x, _i = torch.unique(x, sorted=True, return_inverse=True) + _x = _x.clone().detach() + return self.relu(_x), _i + + example_inputs = (torch.randn(8, device=self.device),) + self.check_model(Model(), example_inputs) + common_utils.instantiate_parametrized_tests(AOTInductorTestsTemplate) @@ -3609,6 +3623,7 @@ def fail_non_abi_compatible_cuda(is_skip=False): "test_custom_op_missing_arg_with_default_value": fail_minimal_arrayref_interface( is_skip=True ), + "test_size_from_multi_output": fail_stack_allocation(is_skip=True), } # test_failures, xfail by default, set is_skip=True to skip diff --git a/torch/_inductor/codegen/cpp_wrapper_cuda.py b/torch/_inductor/codegen/cpp_wrapper_cuda.py index cc406e4dac4ec9..b82a707961df9e 100644 --- a/torch/_inductor/codegen/cpp_wrapper_cuda.py +++ b/torch/_inductor/codegen/cpp_wrapper_cuda.py @@ -63,7 +63,7 @@ def _new_line(self, line): class DeferredCudaDefaultGrid: """ - A marker to + A container for the default grid, which may be used by DeferredCudaGridLine """ def __init__( diff --git a/torch/_inductor/ir.py b/torch/_inductor/ir.py index b9f08fedfe847c..3c2e36286a7abd 100644 --- a/torch/_inductor/ir.py +++ b/torch/_inductor/ir.py @@ -5864,7 +5864,11 @@ def go(expr, keypath): elif isinstance(keypath[0], CallMethodKey): return go(f"{expr}.{keypath[0].name}()", keypath[1:]) elif isinstance(keypath[0], pytree.SequenceKey): - return go(f"{expr}[{keypath[0].idx}]", keypath[1:]) + return ( + go(f"std::get<{keypath[0].idx}>({expr})", keypath[1:]) + if V.graph.cpp_wrapper + else go(f"{expr}[{keypath[0].idx}]", keypath[1:]) + ) elif isinstance(keypath[0], DivideByKey): # TODO: need to assert divisibility # TODO: this is invalid C++ codegen From 11c2f957687cfa8575850a842f0fe57fbe646264 Mon Sep 17 00:00:00 2001 From: sanchitintel Date: Thu, 5 Sep 2024 12:14:53 +0000 Subject: [PATCH 0461/1018] Revise CPU vectorization ISA support API (#135075) Revising (mostly renaming) CPU vectorization ISA support API (non-frontend-user-facing). Also added AVX512_BF16 ISA detection API. Pull Request resolved: https://github.com/pytorch/pytorch/pull/135075 Approved by: https://github.com/leslie-fang-intel, https://github.com/jgong5, https://github.com/ezyang --- aten/src/ATen/cpu/Utils.cpp | 18 +++++++++++++----- aten/src/ATen/cpu/Utils.h | 11 +++++++---- torch/_C/_cpu.pyi | 9 +++++---- torch/_dynamo/trace_rules.py | 18 ++++++++++-------- torch/_inductor/cpu_vec_isa.py | 6 +++--- torch/ao/quantization/qconfig.py | 2 +- torch/cpu/__init__.py | 21 +++++++++++++-------- torch/csrc/cpu/Module.cpp | 9 +++++---- 8 files changed, 57 insertions(+), 37 deletions(-) diff --git a/aten/src/ATen/cpu/Utils.cpp b/aten/src/ATen/cpu/Utils.cpp index a50e52155b54ae..4455d4c1177312 100644 --- a/aten/src/ATen/cpu/Utils.cpp +++ b/aten/src/ATen/cpu/Utils.cpp @@ -9,7 +9,7 @@ #endif namespace at::cpu { -bool is_cpu_support_avx2() { +bool is_avx2_supported() { #if !defined(__s390x__) && !defined(__powerpc__) return cpuinfo_initialize() && cpuinfo_has_x86_avx2(); #else @@ -17,7 +17,7 @@ bool is_cpu_support_avx2() { #endif } -bool is_cpu_support_avx512() { +bool is_avx512_supported() { #if !defined(__s390x__) && !defined(__powerpc__) return cpuinfo_initialize() && cpuinfo_has_x86_avx512f() && cpuinfo_has_x86_avx512vl() && cpuinfo_has_x86_avx512bw() && cpuinfo_has_x86_avx512dq(); #else @@ -25,7 +25,7 @@ bool is_cpu_support_avx512() { #endif } -bool is_cpu_support_avx512_vnni() { +bool is_avx512_vnni_supported() { #if !defined(__s390x__) && !defined(__powerpc__) return cpuinfo_initialize() && cpuinfo_has_x86_avx512vnni(); #else @@ -33,7 +33,15 @@ bool is_cpu_support_avx512_vnni() { #endif } -bool is_cpu_support_amx_tile() { +bool is_avx512_bf16_supported() { +#if !defined(__s390x__) && !defined(__powerpc__) + return cpuinfo_initialize() && cpuinfo_has_x86_avx512bf16(); +#else + return false; +#endif +} + +bool is_amx_tile_supported() { #if !defined(__s390x__) && !defined(__powerpc__) return cpuinfo_initialize() && cpuinfo_has_x86_amx_tile(); #else @@ -42,7 +50,7 @@ bool is_cpu_support_amx_tile() { } bool init_amx() { - if (!is_cpu_support_amx_tile()) { + if (!is_amx_tile_supported()) { return false; } diff --git a/aten/src/ATen/cpu/Utils.h b/aten/src/ATen/cpu/Utils.h index 7498367fee4ea7..ad918dde7e0599 100644 --- a/aten/src/ATen/cpu/Utils.h +++ b/aten/src/ATen/cpu/Utils.h @@ -6,14 +6,17 @@ namespace at::cpu { -TORCH_API bool is_cpu_support_avx2(); -TORCH_API bool is_cpu_support_avx512(); +TORCH_API bool is_avx2_supported(); +TORCH_API bool is_avx512_supported(); // Detect if CPU support Vector Neural Network Instruction. -TORCH_API bool is_cpu_support_avx512_vnni(); +TORCH_API bool is_avx512_vnni_supported(); + +// Detect if CPU supports AVX512_BF16 ISA +TORCH_API bool is_avx512_bf16_supported(); // Detect if CPU support Advanced Matrix Extension. -TORCH_API bool is_cpu_support_amx_tile(); +TORCH_API bool is_amx_tile_supported(); // Enable the system to use AMX instructions. TORCH_API bool init_amx(); diff --git a/torch/_C/_cpu.pyi b/torch/_C/_cpu.pyi index 353ad12f5b7d3e..6593222a119f4d 100644 --- a/torch/_C/_cpu.pyi +++ b/torch/_C/_cpu.pyi @@ -2,10 +2,11 @@ from torch.types import _bool, _int # Defined in torch/csrc/cpu/Module.cpp -def _is_cpu_support_avx2() -> _bool: ... -def _is_cpu_support_avx512() -> _bool: ... -def _is_cpu_support_avx512_vnni() -> _bool: ... -def _is_cpu_support_amx_tile() -> _bool: ... +def _is_avx2_supported() -> _bool: ... +def _is_avx512_supported() -> _bool: ... +def _is_avx512_vnni_supported() -> _bool: ... +def _is_avx512_bf16_supported() -> _bool: ... +def _is_amx_tile_supported() -> _bool: ... def _init_amx() -> _bool: ... def _L1d_cache_size() -> _int: ... def _L2_cache_size() -> _int: ... diff --git a/torch/_dynamo/trace_rules.py b/torch/_dynamo/trace_rules.py index 5eee636438fed2..7c70ed0761f839 100644 --- a/torch/_dynamo/trace_rules.py +++ b/torch/_dynamo/trace_rules.py @@ -413,10 +413,11 @@ "torch._C._construct_CUDA_Tensor_From_Storage_And_Metadata", "torch._C._construct_storage_from_data_pointer", "torch._C._conv_determine_backend_memory_format", - "torch._C._cpu._is_cpu_support_avx2", - "torch._C._cpu._is_cpu_support_avx512", - "torch._C._cpu._is_cpu_support_avx512_vnni", - "torch._C._cpu._is_cpu_support_amx_tile", + "torch._C._cpu._is_avx2_supported", + "torch._C._cpu._is_avx512_supported", + "torch._C._cpu._is_avx512_vnni_supported", + "torch._C._cpu._is_avx512_bf16_supported", + "torch._C._cpu._is_amx_tile_supported", "torch._C._cpu._init_amx", "torch._C._crash_if_aten_asan", "torch._C._crash_if_csrc_asan", @@ -2431,10 +2432,11 @@ "torch.chain_matmul", "torch.compile", "torch.compiled_with_cxx11_abi", - "torch.cpu._is_cpu_support_avx2", - "torch.cpu._is_cpu_support_avx512", - "torch.cpu._is_cpu_support_avx512_vnni", - "torch.cpu._is_cpu_support_amx_tile", + "torch._C._cpu._is_avx2_supported", + "torch._C._cpu._is_avx512_supported", + "torch._C._cpu._is_avx512_vnni_supported", + "torch._C._cpu._is_avx512_bf16_supported", + "torch._C._cpu._is_amx_tile_supported", "torch.cpu._init_amx", "torch.cpu.current_device", "torch.cpu.current_stream", diff --git a/torch/_inductor/cpu_vec_isa.py b/torch/_inductor/cpu_vec_isa.py index 7cadacfdf1c47b..bc4838e5f16855 100644 --- a/torch/_inductor/cpu_vec_isa.py +++ b/torch/_inductor/cpu_vec_isa.py @@ -296,9 +296,9 @@ def _check_and_append_supported_isa( if Arch != "x86_64" and Arch != "AMD64": return supported_isa - avx2 = torch.cpu._is_cpu_support_avx2() - avx512 = torch.cpu._is_cpu_support_avx512() - amx_tile = torch.cpu._is_cpu_support_amx_tile() + avx2 = torch.cpu._is_avx2_supported() + avx512 = torch.cpu._is_avx512_supported() + amx_tile = torch.cpu._is_amx_tile_supported() _check_and_append_supported_isa(supported_isa, avx2, "avx2") _check_and_append_supported_isa(supported_isa, avx512, "avx512") diff --git a/torch/ao/quantization/qconfig.py b/torch/ao/quantization/qconfig.py index 763c181be0caa6..a867acbeb1cb86 100644 --- a/torch/ao/quantization/qconfig.py +++ b/torch/ao/quantization/qconfig.py @@ -281,7 +281,7 @@ def get_default_qconfig(backend="x86", version=0): weight=default_weight_observer, ) elif backend == "onednn": - if not torch.cpu._is_cpu_support_vnni(): + if not torch.cpu._is_vnni_supported(): warnings.warn( "Default qconfig of oneDNN backend with reduce_range of false may have accuracy issues " "on CPU without Vector Neural Network Instruction support." diff --git a/torch/cpu/__init__.py b/torch/cpu/__init__.py index 978405c0410c69..8443e0447aa25a 100644 --- a/torch/cpu/__init__.py +++ b/torch/cpu/__init__.py @@ -29,25 +29,30 @@ _device_t = Union[_device, str, int, None] -def _is_cpu_support_avx2() -> bool: +def _is_avx2_supported() -> bool: r"""Returns a bool indicating if CPU supports AVX2.""" - return torch._C._cpu._is_cpu_support_avx2() + return torch._C._cpu._is_avx2_supported() -def _is_cpu_support_avx512() -> bool: +def _is_avx512_supported() -> bool: r"""Returns a bool indicating if CPU supports AVX512.""" - return torch._C._cpu._is_cpu_support_avx512() + return torch._C._cpu._is_avx512_supported() -def _is_cpu_support_vnni() -> bool: +def _is_avx512_bf16_supported() -> bool: + r"""Returns a bool indicating if CPU supports AVX512_BF16.""" + return torch._C._cpu._is_avx512_bf16_supported() + + +def _is_vnni_supported() -> bool: r"""Returns a bool indicating if CPU supports VNNI.""" # Note: Currently, it only checks avx512_vnni, will add the support of avx2_vnni later. - return torch._C._cpu._is_cpu_support_avx512_vnni() + return torch._C._cpu._is_avx512_vnni_supported() -def _is_cpu_support_amx_tile() -> bool: +def _is_amx_tile_supported() -> bool: r"""Returns a bool indicating if CPU supports AMX_TILE.""" - return torch._C._cpu._is_cpu_support_amx_tile() + return torch._C._cpu._is_amx_tile_supported() def _init_amx() -> bool: diff --git a/torch/csrc/cpu/Module.cpp b/torch/csrc/cpu/Module.cpp index 700623782fe47d..84eb864d2ceca9 100644 --- a/torch/csrc/cpu/Module.cpp +++ b/torch/csrc/cpu/Module.cpp @@ -8,10 +8,11 @@ void initModule(PyObject* module) { auto m = py::handle(module).cast(); auto cpu = m.def_submodule("_cpu", "cpu related pybind."); - cpu.def("_is_cpu_support_avx2", at::cpu::is_cpu_support_avx2); - cpu.def("_is_cpu_support_avx512", at::cpu::is_cpu_support_avx512); - cpu.def("_is_cpu_support_avx512_vnni", at::cpu::is_cpu_support_avx512_vnni); - cpu.def("_is_cpu_support_amx_tile", at::cpu::is_cpu_support_amx_tile); + cpu.def("_is_avx2_supported", at::cpu::is_avx2_supported); + cpu.def("_is_avx512_supported", at::cpu::is_avx512_supported); + cpu.def("_is_avx512_vnni_supported", at::cpu::is_avx512_vnni_supported); + cpu.def("_is_avx512_bf16_supported", at::cpu::is_avx512_bf16_supported); + cpu.def("_is_amx_tile_supported", at::cpu::is_amx_tile_supported); cpu.def("_init_amx", at::cpu::init_amx); cpu.def("_L1d_cache_size", at::cpu::L1d_cache_size); cpu.def("_L2_cache_size", at::cpu::L2_cache_size); From b07b0f4dd2f07434ecdb23088f3d6ac711937b2f Mon Sep 17 00:00:00 2001 From: Xinyu Date: Thu, 5 Sep 2024 12:15:08 +0000 Subject: [PATCH 0462/1018] [Dynamo] Support builtin function frozenset (#134563) Support builtin function frozenset in dynamo Pull Request resolved: https://github.com/pytorch/pytorch/pull/134563 Approved by: https://github.com/anijain2305, https://github.com/EikanWang, https://github.com/jansel --- test/dynamo/test_functions.py | 65 +++++++++++++++++++++++++++++ torch/_dynamo/variables/__init__.py | 1 + torch/_dynamo/variables/builtin.py | 15 +++++++ torch/_dynamo/variables/dicts.py | 47 +++++++++++++++++++++ 4 files changed, 128 insertions(+) diff --git a/test/dynamo/test_functions.py b/test/dynamo/test_functions.py index 3359b6128f4133..a0e3369ed73c7e 100644 --- a/test/dynamo/test_functions.py +++ b/test/dynamo/test_functions.py @@ -3394,6 +3394,71 @@ def fn(x): ref = opt_fn(x) self.assertEqual(ref, res) + def test_frozenset_construction(self): + def fn(x): + s = frozenset({x}) + t = frozenset(s) + return len(t) + + opt_fn = torch.compile(fn, backend="eager", fullgraph=True) + x = torch.randn(4) + res = fn(x) + ref = opt_fn(x) + self.assertEqual(ref, res) + + def test_frozenset_reconstruction(self): + d = {} + f = frozenset() + d[f] = torch.randn(4) + + def fn(x): + k = frozenset() + torch._dynamo.graph_break() + return d[k] * x + + opt_fn = torch.compile(fn, backend="eager") + x = torch.randn(4) + res = fn(x) + ref = opt_fn(x) + self.assertEqual(ref, res) + + def test_frozenset_illegal_call_method(self): + def fn_add(): + s = frozenset((1, 2, 3)) + s.add({2}) + return len(s) + + def fn_pop(): + s = frozenset((1, 2, 3)) + s.pop() + return len(s) + + def fn_update(): + s = frozenset((1, 2, 3)) + s.update({4, 5, 6}) + return len(s) + + def fn_remove(): + s = frozenset((1, 2, 3)) + s.remove(2) + return len(s) + + def fn_discard(): + s = frozenset((1, 2, 3)) + s.discard(2) + return len(s) + + def fn_clear(): + s = frozenset((1, 2, 3)) + s.clear() + return len(s) + + for fn in [fn_add, fn_pop, fn_update, fn_remove, fn_discard, fn_clear]: + torch._dynamo.reset() + opt_fn = torch.compile(fn, backend="eager", fullgraph=True) + with self.assertRaises(torch._dynamo.exc.InternalTorchDynamoError): + opt_fn() + def test_is_tensor_tensor(self): def fn(x, y): if x is y: diff --git a/torch/_dynamo/variables/__init__.py b/torch/_dynamo/variables/__init__.py index a6a46bb3932004..021233cd6e235f 100644 --- a/torch/_dynamo/variables/__init__.py +++ b/torch/_dynamo/variables/__init__.py @@ -24,6 +24,7 @@ ConstDictVariable, CustomizedDictVariable, DefaultDictVariable, + FrozensetVariable, SetVariable, ) from .distributed import BackwardHookVariable, DistributedVariable, PlacementVariable diff --git a/torch/_dynamo/variables/builtin.py b/torch/_dynamo/variables/builtin.py index b19bee7e1b2f81..d330a4944045d0 100644 --- a/torch/_dynamo/variables/builtin.py +++ b/torch/_dynamo/variables/builtin.py @@ -49,6 +49,7 @@ ConstDictVariable, DefaultDictVariable, DictView, + FrozensetVariable, is_hashable, SetVariable, ) @@ -1428,6 +1429,20 @@ def call_set(self, tx: "InstructionTranslator", *args, **kwargs): else: unimplemented(f"set(): {args} {kwargs}") + def call_frozenset(self, tx: "InstructionTranslator", *args, **kwargs): + assert not kwargs + if not args: + return FrozensetVariable([]) + assert len(args) == 1 + arg = args[0] + if isinstance(arg, variables.FrozensetVariable): + return FrozensetVariable([x.vt for x in arg.set_items]) + elif arg.has_unpack_var_sequence(tx): + items = arg.unpack_var_sequence(tx) + return FrozensetVariable(items) + else: + unimplemented(f"frozenset(): {args} {kwargs}") + def call_zip(self, tx: "InstructionTranslator", *args, **kwargs): if kwargs: assert len(kwargs) == 1 and "strict" in kwargs diff --git a/torch/_dynamo/variables/dicts.py b/torch/_dynamo/variables/dicts.py index f6eed6f7bc3c62..6c69364ca57ee9 100644 --- a/torch/_dynamo/variables/dicts.py +++ b/torch/_dynamo/variables/dicts.py @@ -540,6 +540,53 @@ def getitem_const(self, tx: "InstructionTranslator", arg: VariableTracker): raise RuntimeError("Illegal to getitem on a set") +class FrozensetVariable(SetVariable): + def __init__( + self, + items: List[VariableTracker], + **kwargs, + ) -> None: + super().__init__(items, **kwargs) + + def debug_repr(self): + if not self.items: + return "frozenset()" + else: + return "{" + ",".join(k.vt.debug_repr() for k in self.items.keys()) + "}" + + @property + def set_items(self): + return self.items.keys() + + def python_type(self): + return frozenset + + def as_python_constant(self): + return {k.vt.as_python_constant() for k in self.set_items} + + def reconstruct(self, codegen): + codegen.foreach([x.vt for x in self.set_items]) + codegen.add_push_null( + lambda: codegen.extend_output( + [ + codegen.create_load_global("frozenset"), + ] + ) + ) + codegen.extend_output(create_call_function(0, False)) + + def call_method( + self, + tx, + name, + args: List[VariableTracker], + kwargs: Dict[str, VariableTracker], + ) -> "VariableTracker": + if name in ["add", "pop", "update", "remove", "discard", "clear"]: + raise RuntimeError(f"Illegal call_method {name} on a frozenset") + return super().call_method(tx, name, args, kwargs) + + class DictView(VariableTracker): """ Models _PyDictViewObject From b7d6649e2fc85fa9734de42863cb55cc530b4e47 Mon Sep 17 00:00:00 2001 From: min-jean-cho Date: Thu, 5 Sep 2024 12:52:43 +0000 Subject: [PATCH 0463/1018] [Intel GPU][Windows] Fix overriding default CMAKE_CXX_FLAGS (#135093) The root cause is that `/EHsc` is part of the default `CMAKE_CXX_FLAGS` in CMake. Fix to not override the default `CMAKE_CXX_FLAGS`. Pull Request resolved: https://github.com/pytorch/pytorch/pull/135093 Approved by: https://github.com/EikanWang, https://github.com/atalman --- cmake/Modules/FindMKLDNN.cmake | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/cmake/Modules/FindMKLDNN.cmake b/cmake/Modules/FindMKLDNN.cmake index b69378b6b9fb96..a8d42b6ee8da6a 100644 --- a/cmake/Modules/FindMKLDNN.cmake +++ b/cmake/Modules/FindMKLDNN.cmake @@ -43,7 +43,9 @@ IF(NOT MKLDNN_FOUND) endif() endif() if(LINUX) - set(ABI_NEUTRAL_FLAGS -fpreview-breaking-changes) + set(DNNL_CXX_FLAGS "-DCMAKE_CXX_FLAGS=-fpreview-breaking-changes") + else() + set(DNNL_CXX_FLAGS "") endif() ExternalProject_Add(xpu_mkldnn_proj SOURCE_DIR ${MKLDNN_ROOT} @@ -51,7 +53,7 @@ IF(NOT MKLDNN_FOUND) BUILD_IN_SOURCE 0 CMAKE_ARGS -DCMAKE_C_COMPILER=icx -DCMAKE_CXX_COMPILER=${SYCL_CXX_DRIVER} - -DCMAKE_CXX_FLAGS=${ABI_NEUTRAL_FLAGS} + ${DNNL_CXX_FLAGS} -DDNNL_GPU_RUNTIME=SYCL -DDNNL_CPU_RUNTIME=THREADPOOL -DDNNL_BUILD_TESTS=OFF From e8e2d64d0646d5a9ffaf3c426ea691d7ae65164e Mon Sep 17 00:00:00 2001 From: Stonepia Date: Thu, 5 Sep 2024 14:36:23 +0000 Subject: [PATCH 0464/1018] [CI] Use larger instance for building triton whl (#135201) When running CI jobs of "Build Triton Wheels", it failed due to the lack of resources. This PR uses a larger runner to avoid these issues. The failure message is like: ``` Process completed with exit code 137. ``` Related running actions: Failed actions: https://github.com/pytorch/pytorch/actions/runs/10714445036 Success actions: https://github.com/pytorch/pytorch/actions/runs/10716710830 Pull Request resolved: https://github.com/pytorch/pytorch/pull/135201 Approved by: https://github.com/chuanqi129, https://github.com/atalman --- .github/workflows/build-triton-wheel.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/build-triton-wheel.yml b/.github/workflows/build-triton-wheel.yml index d6a83eaf3654cc..a8240df9788dfb 100644 --- a/.github/workflows/build-triton-wheel.yml +++ b/.github/workflows/build-triton-wheel.yml @@ -31,7 +31,7 @@ concurrency: jobs: build-wheel: name: "Build Triton Wheel" - runs-on: [self-hosted, linux.2xlarge] + runs-on: [self-hosted, linux.4xlarge] strategy: fail-fast: false matrix: From 386e18808c69100f2e43f30a6f768a64b1eb3d15 Mon Sep 17 00:00:00 2001 From: "Edward Z. Yang" Date: Thu, 5 Sep 2024 09:33:03 -0400 Subject: [PATCH 0465/1018] Upgrade expecttest to 0.2.1 (#135136) Signed-off-by: Edward Z. Yang Pull Request resolved: https://github.com/pytorch/pytorch/pull/135136 Approved by: https://github.com/albanD, https://github.com/atalman, https://github.com/Skylion007 --- .ci/docker/requirements-ci.txt | 4 ++-- .github/requirements/pip-requirements-macOS.txt | 2 +- .github/workflows/lint.yml | 2 +- .lintrunner.toml | 2 +- requirements.txt | 2 +- 5 files changed, 6 insertions(+), 6 deletions(-) diff --git a/.ci/docker/requirements-ci.txt b/.ci/docker/requirements-ci.txt index 9a9b568dd7e2b9..07626204e6683c 100644 --- a/.ci/docker/requirements-ci.txt +++ b/.ci/docker/requirements-ci.txt @@ -30,10 +30,10 @@ dill==0.3.7 #Pinned versions: 0.3.7 #test that import: dynamo/test_replay_record.py test_dataloader.py test_datapipe.py test_serialization.py -expecttest==0.1.6 +expecttest==0.2.1 #Description: method for writing tests where test framework auto populates # the expected output based on previous runs -#Pinned versions: 0.1.6 +#Pinned versions: 0.2.1 #test that import: flatbuffers==2.0 diff --git a/.github/requirements/pip-requirements-macOS.txt b/.github/requirements/pip-requirements-macOS.txt index c3221044889c74..64396398274ac4 100644 --- a/.github/requirements/pip-requirements-macOS.txt +++ b/.github/requirements/pip-requirements-macOS.txt @@ -1,6 +1,6 @@ boto3==1.19.12 hypothesis==6.56.4 -expecttest==0.1.6 +expecttest==0.2.1 librosa>=0.6.2 mpmath==1.3.0 networkx==2.8.7 diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index f86d7270612017..84a2a370297d8c 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -223,7 +223,7 @@ jobs: cache: pip - name: Install dependencies run: | - pip install pytest-rerunfailures==11.1.* pytest-flakefinder==1.1.* pytest-xdist==3.3.* expecttest==0.1.* numpy==1.24.* + pip install pytest-rerunfailures==11.1.* pytest-flakefinder==1.1.* pytest-xdist==3.3.* expecttest==0.2.* numpy==1.24.* pip install torch --pre --index-url https://download.pytorch.org/whl/nightly/cpu/ - name: Run run_test.py (nonretryable) run: | diff --git a/.lintrunner.toml b/.lintrunner.toml index 55849ff4695a03..4789991736be17 100644 --- a/.lintrunner.toml +++ b/.lintrunner.toml @@ -138,7 +138,7 @@ init_command = [ '--dry-run={{DRYRUN}}', 'numpy==1.24.3 ; python_version == "3.8"', 'numpy==1.26.0 ; python_version >= "3.9"', - 'expecttest==0.1.6', + 'expecttest==0.2.1', 'mypy==1.10.0', 'sympy==1.12.1 ; python_version == "3.8"', 'sympy==1.13.0 ; python_version >= "3.9"', diff --git a/requirements.txt b/requirements.txt index 8ebae971e488f8..477332375872f7 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,6 +1,6 @@ # Python dependencies required for development astunparse -expecttest!=0.2.0 +expecttest>=0.2.1 hypothesis numpy psutil From 5ea5290656af697db1e8f6788bdf89a080966c0a Mon Sep 17 00:00:00 2001 From: Tarun Karuturi Date: Thu, 5 Sep 2024 16:11:18 +0000 Subject: [PATCH 0466/1018] Switch torch pt2e xnnpack tests to use export_for_training (#134788) Migrate all the callsites inside the pt2e XNNPACK tests to use export_for_training. Differential Revision: D61994553 Pull Request resolved: https://github.com/pytorch/pytorch/pull/134788 Approved by: https://github.com/mergennachin --- .../pt2e/test_xnnpack_quantizer.py | 20 +++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/test/quantization/pt2e/test_xnnpack_quantizer.py b/test/quantization/pt2e/test_xnnpack_quantizer.py index 5e850a68482886..0e88a7793ce475 100644 --- a/test/quantization/pt2e/test_xnnpack_quantizer.py +++ b/test/quantization/pt2e/test_xnnpack_quantizer.py @@ -4,7 +4,6 @@ import torch import torch._dynamo as torchdynamo -from torch._export import capture_pre_autograd_graph from torch.ao.ns.fx.utils import compute_sqnr from torch.ao.quantization import ( default_dynamic_fake_quant, @@ -30,6 +29,7 @@ get_symmetric_quantization_config, XNNPACKQuantizer, ) +from torch.export import export_for_training from torch.testing._internal.common_quantization import ( NodeSpec as ns, PT2EQuantizationTestCase, @@ -361,7 +361,7 @@ def forward(self, x): ) example_inputs = (torch.randn(2, 2),) m = M().eval() - m = capture_pre_autograd_graph(m, example_inputs) + m = export_for_training(m, example_inputs).module() m = prepare_pt2e(m, quantizer) # Use a linear count instead of names because the names might change, but # the order should be the same. @@ -497,10 +497,10 @@ def test_propagate_annotation(self): example_inputs = (torch.randn(1, 3, 5, 5),) # program capture - m = capture_pre_autograd_graph( + m = export_for_training( m, example_inputs, - ) + ).module() m = prepare_pt2e(m, quantizer) m(*example_inputs) @@ -754,10 +754,10 @@ def forward(self, input_tensor, hidden_tensor): model_fx = _convert_to_reference_decomposed_fx(model_fx) with torchdynamo.config.patch(allow_rnn=True): - model_graph = capture_pre_autograd_graph( + model_graph = export_for_training( model_graph, example_inputs, - ) + ).module() quantizer = XNNPACKQuantizer() quantization_config = get_symmetric_quantization_config( is_per_channel=False, is_dynamic=False @@ -818,10 +818,10 @@ def forward(self, input_tensor, hidden_tensor): model_fx = _convert_to_reference_decomposed_fx(model_fx) with torchdynamo.config.patch(allow_rnn=True): - model_graph = capture_pre_autograd_graph( + model_graph = export_for_training( model_graph, example_inputs, - ) + ).module() quantizer = XNNPACKQuantizer() quantization_config = get_symmetric_quantization_config( is_per_channel=False, is_dynamic=False @@ -995,10 +995,10 @@ def test_resnet18(self): m = torchvision.models.resnet18().eval() m_copy = copy.deepcopy(m) # program capture - m = capture_pre_autograd_graph( + m = export_for_training( m, example_inputs, - ) + ).module() quantizer = XNNPACKQuantizer() quantization_config = get_symmetric_quantization_config(is_per_channel=True) From 764416520a8197c4ea160a65c5723b66aa0f44e4 Mon Sep 17 00:00:00 2001 From: Tom Ritchford Date: Wed, 4 Sep 2024 13:44:38 +0000 Subject: [PATCH 0467/1018] Implement VariableTracker.python_type() (#134215) Pull Request resolved: https://github.com/pytorch/pytorch/pull/134215 Approved by: https://github.com/amjames, https://github.com/jansel --- torch/_dynamo/variables/base.py | 5 ++++- torch/_dynamo/variables/builtin.py | 3 --- torch/_dynamo/variables/constant.py | 6 ------ torch/_dynamo/variables/functions.py | 3 --- torch/_dynamo/variables/iter.py | 3 --- torch/_dynamo/variables/misc.py | 9 --------- torch/_dynamo/variables/tensor.py | 3 --- torch/_dynamo/variables/torch.py | 3 --- torch/_dynamo/variables/user_defined.py | 3 --- 9 files changed, 4 insertions(+), 34 deletions(-) diff --git a/torch/_dynamo/variables/base.py b/torch/_dynamo/variables/base.py index e7a3b7320e5c62..25b4ffad99e7b3 100644 --- a/torch/_dynamo/variables/base.py +++ b/torch/_dynamo/variables/base.py @@ -207,7 +207,10 @@ def python_type(self): Raises: NotImplementedError: If the method is not implemented in a subclass. """ - raise NotImplementedError(f"{self} has no type") + try: + return type(self.as_python_constant()) + except NotImplementedError: + raise NotImplementedError(f"{self} has no type") from None def as_python_constant(self): """For constants""" diff --git a/torch/_dynamo/variables/builtin.py b/torch/_dynamo/variables/builtin.py index d330a4944045d0..6ffd820c3f1320 100644 --- a/torch/_dynamo/variables/builtin.py +++ b/torch/_dynamo/variables/builtin.py @@ -637,9 +637,6 @@ def __str__(self) -> str: return f"{self.__class__.__name__}({name})" - def python_type(self): - return type(self.fn) - def as_python_constant(self): return self.fn diff --git a/torch/_dynamo/variables/constant.py b/torch/_dynamo/variables/constant.py index 8f62a3585bcecf..0028ecd81decda 100644 --- a/torch/_dynamo/variables/constant.py +++ b/torch/_dynamo/variables/constant.py @@ -85,9 +85,6 @@ def as_proxy(self): def __str__(self) -> str: return f"ConstantVariable({type(self.value).__name__}: {repr(self.value)})" - def python_type(self): - return type(self.value) - def as_python_constant(self): return self.value @@ -222,9 +219,6 @@ def as_proxy(self): def __str__(self) -> str: return f"EnumVariable({type(self.value)})" - def python_type(self): - return type(self.value) - def as_python_constant(self): return self.value diff --git a/torch/_dynamo/variables/functions.py b/torch/_dynamo/variables/functions.py index 3f881693c8be62..90c2d0665aebc2 100644 --- a/torch/_dynamo/variables/functions.py +++ b/torch/_dynamo/variables/functions.py @@ -626,9 +626,6 @@ def __init__(self, value, reason=None, **kwargs) -> None: self.value = value self.reason = reason - def python_type(self): - return type(self.value) - def as_python_constant(self): return self.value diff --git a/torch/_dynamo/variables/iter.py b/torch/_dynamo/variables/iter.py index 34c2354026a7bd..192f25a0c6b677 100644 --- a/torch/_dynamo/variables/iter.py +++ b/torch/_dynamo/variables/iter.py @@ -30,9 +30,6 @@ def __init__(self, value, **kwargs) -> None: def __repr__(self) -> str: return f"ItertoolsVariable({self.value})" - def python_type(self): - return type(self.value) - def as_python_constant(self): return self.value diff --git a/torch/_dynamo/variables/misc.py b/torch/_dynamo/variables/misc.py index 527474d95969c0..2b62c0c1acc116 100644 --- a/torch/_dynamo/variables/misc.py +++ b/torch/_dynamo/variables/misc.py @@ -1192,9 +1192,6 @@ def call_method( ) unimplemented("typing") - def python_type(self): - return type(self.value) - def as_python_constant(self): return self.value @@ -1312,9 +1309,6 @@ def call_method( ) -> "VariableTracker": unimplemented("numpy") - def python_type(self): - return type(self.value) - def as_python_constant(self): return self.value @@ -1489,9 +1483,6 @@ def __init__(self, value, **kwargs) -> None: super().__init__(**kwargs) self.value = value - def python_type(self): - return type(self.value) - def as_python_constant(self): return self.value diff --git a/torch/_dynamo/variables/tensor.py b/torch/_dynamo/variables/tensor.py index 5d1a8a6bda799b..088313bdbf551d 100644 --- a/torch/_dynamo/variables/tensor.py +++ b/torch/_dynamo/variables/tensor.py @@ -1328,9 +1328,6 @@ def call_function( def as_python_constant(self): return self.value - def python_type(self): - return type(self.value) - class UntypedStorageVariable(VariableTracker): _nonvar_fields = { diff --git a/torch/_dynamo/variables/torch.py b/torch/_dynamo/variables/torch.py index 30dcdb84afffc0..f4d4eb900c1735 100644 --- a/torch/_dynamo/variables/torch.py +++ b/torch/_dynamo/variables/torch.py @@ -173,9 +173,6 @@ def reconstruct(self, codegen): def as_proxy(self): return self.value - def python_type(self): - return type(self.value) - def as_python_constant(self): return self.value diff --git a/torch/_dynamo/variables/user_defined.py b/torch/_dynamo/variables/user_defined.py index 8806627e71eb8d..54dc9119d267b4 100644 --- a/torch/_dynamo/variables/user_defined.py +++ b/torch/_dynamo/variables/user_defined.py @@ -115,9 +115,6 @@ def __init__(self, value, **kwargs) -> None: def as_python_constant(self): return self.value - def python_type(self): - return type(self.value) - def as_proxy(self): return self.value From 7e214b5ed6715a4a2f9264da2efaaa1cb36ab10d Mon Sep 17 00:00:00 2001 From: Henry Tsang Date: Thu, 5 Sep 2024 16:40:12 +0000 Subject: [PATCH 0468/1018] [inductor][test] in test_unbacked_symints, replace inductor's skipCUDAIf with common device type's skipcudaif (#133936) Differential Revision: D61506212 Use `skipCUDAIf` from `torch.testing._internal.common_device_type` if we create the test class with `instantiate_device_type_tests`. `instantiate_device_type_tests` would make sure the class has attr device_type, which works with`skipCUDAIf` from `torch.testing._internal.common_device_type`. Also skipping test_vertical_pointwise_reduction_fusion for cpu test class, since the test expects cuda. FAILED [0.0026s] test/inductor/test_unbacked_symints.py::TestUnbackedSymintsCPU::test_vertical_pointwise_reduction_fusion_cpu - AttributeError: 'TestUnbackedSymintsCPU' object has no attribute 'device' repro: ``` CUDA_VISIBLE_DEVICES="" pytest test/inductor/test_unbacked_symints.py -k cpu -v ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/133936 Approved by: https://github.com/ColinPeppler, https://github.com/desertfire --- test/inductor/test_torchinductor_dynamic_shapes.py | 3 ++- test/inductor/test_unbacked_symints.py | 10 ++++++++-- torch/testing/_internal/inductor_utils.py | 8 ++++++++ 3 files changed, 18 insertions(+), 3 deletions(-) diff --git a/test/inductor/test_torchinductor_dynamic_shapes.py b/test/inductor/test_torchinductor_dynamic_shapes.py index 476419779fbb02..34c71f44332572 100644 --- a/test/inductor/test_torchinductor_dynamic_shapes.py +++ b/test/inductor/test_torchinductor_dynamic_shapes.py @@ -27,6 +27,7 @@ ) from torch.testing._internal.common_utils import ( IS_ARM64, + IS_FBCODE, parametrize, TEST_CUDA_MEM_LEAK_CHECK, TEST_WITH_ASAN, @@ -362,7 +363,7 @@ def f(x, r): f(torch.tensor([3], device=device), torch.randn(10, device=device)) - @unittest.expectedFailure + @unittest.skipUnless(IS_FBCODE, "") @torch._dynamo.config.patch( capture_scalar_outputs=True, capture_dynamic_output_shape_ops=True ) diff --git a/test/inductor/test_unbacked_symints.py b/test/inductor/test_unbacked_symints.py index 14668d51b1137e..0ef2e6131166c8 100644 --- a/test/inductor/test_unbacked_symints.py +++ b/test/inductor/test_unbacked_symints.py @@ -9,9 +9,12 @@ from torch._inductor.test_case import TestCase as InductorTestCase from torch._inductor.utils import is_big_gpu from torch.testing import make_tensor -from torch.testing._internal.common_device_type import instantiate_device_type_tests +from torch.testing._internal.common_device_type import ( + instantiate_device_type_tests, + skipCUDAIf, +) from torch.testing._internal.common_utils import IS_LINUX, parametrize -from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_CUDA, skipCUDAIf +from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_CUDA class TestUnbackedSymints(InductorTestCase): @@ -194,6 +197,9 @@ def test_vertical_pointwise_reduction_fusion(self, device): # reset in case we run both cpu and cuda tests torch._inductor.metrics.reset() + if device == "cpu": + raise unittest.SkipTest("This test requires cuda") + # Tests fusing a pointwise & reduction op with unbacked numel/rnumel. def fn(x, y, repeats): u0 = repeats.item() diff --git a/torch/testing/_internal/inductor_utils.py b/torch/testing/_internal/inductor_utils.py index 7e952e929feefa..00f9ba0dd2ee29 100644 --- a/torch/testing/_internal/inductor_utils.py +++ b/torch/testing/_internal/inductor_utils.py @@ -1,5 +1,6 @@ # mypy: ignore-errors +import logging import torch import re import unittest @@ -22,6 +23,8 @@ IS_WINDOWS, ) +log: logging.Logger = logging.getLogger(__name__) + def test_cpu(): try: CppCodeCache.load("") @@ -72,6 +75,11 @@ def skipDeviceIf(cond, msg, *, device): if cond: def decorate_fn(fn): def inner(self, *args, **kwargs): + if not hasattr(self, "device"): + warn_msg = "Expect the test class to have attribute device but not found. " + if hasattr(self, "device_type"): + warn_msg += "Consider using the skip device decorators in common_device_type.py" + log.warning(warn_msg) if self.device == device: raise unittest.SkipTest(msg) return fn(self, *args, **kwargs) From 39d8a20cb4b3e2a17996f3183b0053187eaa0174 Mon Sep 17 00:00:00 2001 From: mori360 Date: Thu, 5 Sep 2024 16:44:30 +0000 Subject: [PATCH 0469/1018] Gradient scaler for DTensor (#132816) Solve the request [here](https://github.com/pytorch/pytorch/issues/120003#issuecomment-2248805798). Enable DTensor input in gradient scaler's APIs, especially on `.unscale_()` Related dispatch strategy is added to accept DTensor input. To enable found_inf to conduct reduce action across devices, we add allreduce at dispatch with args after dispatch strategy and kernel. Since `aten._amp_foreach_non_finite_check_and_unscale_.default` is an inplace_op, grad_scale as the arg[0] with be inplaced, so that redesign a strategy or refactoring the kernel would not help Test files are testing 2 parts under 1-d(dp) and 2-d(dp,tp) cases: 1. whether the non-inf values unscaled 2. whether all DTensors at each device could found inf even not at their device. 3. If inf not found, will new parameters generates 4. if inf found, will scale be updated Pull Request resolved: https://github.com/pytorch/pytorch/pull/132816 Approved by: https://github.com/XilunWu, https://github.com/weifengpy, https://github.com/wanchaol --- .../fsdp/test_fully_shard_grad_scaler.py | 111 ++++++++++++++++++ torch/distributed/_tensor/_dispatch.py | 49 +++++++- .../distributed/_tensor/ops/_pointwise_ops.py | 1 + 3 files changed, 160 insertions(+), 1 deletion(-) create mode 100644 test/distributed/_composable/fsdp/test_fully_shard_grad_scaler.py diff --git a/test/distributed/_composable/fsdp/test_fully_shard_grad_scaler.py b/test/distributed/_composable/fsdp/test_fully_shard_grad_scaler.py new file mode 100644 index 00000000000000..1d87e1b267d0d5 --- /dev/null +++ b/test/distributed/_composable/fsdp/test_fully_shard_grad_scaler.py @@ -0,0 +1,111 @@ +# Owner(s): ["oncall: distributed"] +import copy + +import torch +import torch.nn as nn +from torch.amp.grad_scaler import GradScaler, OptState +from torch.distributed._composable.fsdp import fully_shard +from torch.distributed._tensor import init_device_mesh +from torch.distributed.tensor.parallel import ( + ColwiseParallel, + parallelize_module, + RowwiseParallel, +) +from torch.testing._internal.common_distributed import skip_if_lt_x_gpu +from torch.testing._internal.common_fsdp import FSDPTest, MLP +from torch.testing._internal.common_utils import run_tests, skipIfRocm + + +class TestFullyShardGradientScaler(FSDPTest): + @skip_if_lt_x_gpu(4) + @skipIfRocm + def test_gradient_scaler(self): + self.run_subtests( + {"has_inf": [True, False], "test_2d": [True, False]}, + self._test_gradient_scaler, + ) + + def _test_gradient_scaler(self, has_inf: bool, test_2d: bool): + torch.manual_seed(0) + model = nn.Sequential( + *[nn.Linear(4, 4, device="cuda", bias=False) for _ in range(2)] + ) + for layer in model: + fully_shard(layer) + fully_shard(model) + input = torch.randn([4, 4], device="cuda") + + if test_2d: + mesh_2d = init_device_mesh( + "cuda", (2, self.world_size // 2), mesh_dim_names=("dp", "tp") + ) + dp_mesh, tp_mesh = mesh_2d["dp"], mesh_2d["tp"] + model = nn.Sequential(MLP(2), MLP(2), MLP(2)) + tp_parallelize_plan = { + "0.in_proj": ColwiseParallel(), + "0.out_proj": RowwiseParallel(), + "1.in_proj": ColwiseParallel(), + "1.out_proj": RowwiseParallel(), + "2.in_proj": ColwiseParallel(), + "2.out_proj": RowwiseParallel(), + } + model = parallelize_module( + model, + device_mesh=tp_mesh, + parallelize_plan=tp_parallelize_plan, + ) + for module in model: + fully_shard(module, mesh=dp_mesh) + fully_shard(model, mesh=dp_mesh) + input = torch.randn((2,), device="cuda") + + loss = model(input).sum() + scaler = GradScaler(init_scale=2.0, enabled=True) + opt = torch.optim.Adam(model.parameters(), lr=1e-2) + scaler.scale(loss).backward() + inv_scale = scaler._scale.double().reciprocal().float() + if ( + has_inf is True + and opt.param_groups[0]["params"][0].grad._local_tensor.device.index == 1 + ): + opt.param_groups[0]["params"][0].grad._local_tensor[0, 0].fill_( + float("inf") + ) + inital_grad = opt.param_groups[0]["params"][0].grad.to_local().clone() + + scaler.unscale_(opt) + for found_inf in scaler._per_optimizer_states[id(opt)][ + "found_inf_per_device" + ].values(): + self.assertEqual(found_inf, has_inf) + self.assertEqual( + scaler._per_optimizer_states[id(opt)]["stage"].value, + OptState.UNSCALED.value, + ) + unscaled_grad = opt.param_groups[0]["params"][0].grad.to_local().clone() + self.assertEqual(unscaled_grad, inital_grad * inv_scale) + initial_scale = scaler.get_scale() + initial_state = copy.copy(opt.state) + + scaler.step(opt) + steped_state = copy.copy(opt.state) + if has_inf: + # assert parameters are the same before/after + self.assertEqual(steped_state, initial_state) + else: + # new parameters here if no inf found during .unscale_() + self.assertNotEqual(steped_state.items(), initial_state.items()) + + scaler.update() + updated_scale = scaler.get_scale() + if has_inf: + # assert scale is updated + backoff_factor = scaler.get_backoff_factor() + self.assertEqual(updated_scale, initial_scale * backoff_factor) + else: + # scale is not updated + self.assertEqual(updated_scale, initial_scale) + + +if __name__ == "__main__": + run_tests() diff --git a/torch/distributed/_tensor/_dispatch.py b/torch/distributed/_tensor/_dispatch.py index c811fc457d6f76..9faae22cd37353 100644 --- a/torch/distributed/_tensor/_dispatch.py +++ b/torch/distributed/_tensor/_dispatch.py @@ -24,7 +24,13 @@ convolution_handler, ) from torch.distributed._tensor._utils import try_find_mesh_from_args -from torch.distributed._tensor.placement_types import DTensorSpec, Replicate, TensorMeta +from torch.distributed._tensor.placement_types import ( + DTensorSpec, + Partial, + Placement, + Replicate, + TensorMeta, +) from torch.distributed._tensor.random import is_rng_supported_mesh @@ -66,6 +72,46 @@ def is_same_size_handler( return lhs.shape == rhs.shape +def found_inf_reduce_handler( + op_call: torch._ops.OpOverload, + args: Tuple[object, ...], + kwargs: Dict[str, object], +) -> None: + op_info = dtensor.DTensor._op_dispatcher.unwrap_to_op_info(op_call, args, kwargs) + local_tensor_args = pytree.tree_unflatten( + cast(List[object], op_info.local_args), op_info.args_tree_spec + ) + local_tensor_args = cast(Tuple[object, ...], local_tensor_args) + local_results = op_call(*local_tensor_args, **op_info.local_kwargs) + + grad_dtensor = cast(list[dtensor.DTensor], args[0])[0] + grad_placements = grad_dtensor.placements + mesh = grad_dtensor.device_mesh + + found_inf_placements: list[Placement] = [] + for placement in grad_placements: + if isinstance(placement, Replicate): + found_inf_placements.append(placement) + else: + found_inf_placements.append(Partial("max")) + + target_tensor = cast(torch.Tensor, args[1]) + spec = DTensorSpec( + mesh=mesh, + placements=tuple(found_inf_placements), + tensor_meta=TensorMeta( + shape=target_tensor.size(), + stride=target_tensor.stride(), + dtype=target_tensor.dtype, + ), + ) + found_inf_dtensor = dtensor.DTensor( + local_tensor=target_tensor, spec=spec, requires_grad=False + ) + found_inf = found_inf_dtensor.full_tensor() + target_tensor.copy_(found_inf) + + class OpDispatcher: """ Op dispatching class instance to handle args/kwargs pre-processing (un-wrapping), sharding @@ -99,6 +145,7 @@ def __init__(self) -> None: aten.is_same_size.default: is_same_size_handler, aten.convolution.default: convolution_handler, aten.convolution_backward.default: convolution_backward_handler, + aten._amp_foreach_non_finite_check_and_unscale_.default: found_inf_reduce_handler, } # This flag is used internally to control whether we treat the torch.Tensor(non-DTensor) diff --git a/torch/distributed/_tensor/ops/_pointwise_ops.py b/torch/distributed/_tensor/ops/_pointwise_ops.py index 73062441025fae..229029e863b16a 100644 --- a/torch/distributed/_tensor/ops/_pointwise_ops.py +++ b/torch/distributed/_tensor/ops/_pointwise_ops.py @@ -585,6 +585,7 @@ def linear_pointwise_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> Strategy aten._foreach_cos_.default, aten._foreach_log.default, aten._foreach_log_.default, + aten._amp_foreach_non_finite_check_and_unscale_.default, ] for_each_linearity_ops = [ From d4770700bd5ad852ee9c14cb1ce4ffcfb56528ab Mon Sep 17 00:00:00 2001 From: "Edward Z. Yang" Date: Wed, 4 Sep 2024 18:05:23 -0700 Subject: [PATCH 0470/1018] Render log filepaths that are not anchored in torch's directory in a reasonable way (#135165) For example, if I do TORCH_LOGS=fbscribelogger I'll get: ``` I0904 17:59:07.567000 3672513 fbscribelogger/__init__.py:161] stop ``` instead of ``` I0904 12:46:15.332000 2930287 ../../../../../home/ezyang/local/a/pytorch-env/lib/python3.10/site-packages/fbscribelogger/__init__.py:161] stop ``` Signed-off-by: Edward Z. Yang Pull Request resolved: https://github.com/pytorch/pytorch/pull/135165 Approved by: https://github.com/Skylion007 --- torch/_logging/_internal.py | 26 +++++++++++++++++++++++++- 1 file changed, 25 insertions(+), 1 deletion(-) diff --git a/torch/_logging/_internal.py b/torch/_logging/_internal.py index f079eb86804282..f11af08e3b768d 100644 --- a/torch/_logging/_internal.py +++ b/torch/_logging/_internal.py @@ -6,7 +6,9 @@ import logging import os import os.path +import pathlib import re +import sys import tempfile from dataclasses import dataclass, field from importlib import __import__ @@ -747,6 +749,26 @@ def _has_registered_parent(log_qname): return False +def make_module_path_relative(abs_path): + """ + Given an absolute filepath corresponding to a Python module which was + loaded via normal import mechanisms using sys.path, convert it into + a relative path relative to one of the Python search paths. + """ + + abs_path = pathlib.Path(abs_path).resolve() + + for path in sys.path: + try: + rel_path = abs_path.relative_to(path) + except ValueError: + continue + else: + return str(rel_path) + + return str(abs_path) + + # apply custom formats to artifacts when necessary class TorchLogsFormatter(logging.Formatter): def __init__(self, *, trace: bool = False): @@ -807,9 +829,11 @@ def format(self, record): if artifact_name is not None: record.artifactprefix = f" [__{artifact_name}]" + filepath = make_module_path_relative(record.pathname) + prefix = ( f"{record.rankprefix}{shortlevel}{record.asctime}.{int(record.msecs*1000):06d} {record.process} " - f"{os.path.relpath(record.pathname, os.path.dirname(os.path.dirname(torch.__file__)))}:" + f"{filepath}:" f"{record.lineno}]{record.traceid}{record.artifactprefix}" ) if self._is_trace: From cc0cd82b4a90a815e2668526e2fbd6af46508636 Mon Sep 17 00:00:00 2001 From: Mikayla Gawarecki Date: Thu, 5 Sep 2024 06:50:12 -0700 Subject: [PATCH 0471/1018] Add torch.serialization.skip_data context manager (#134504) ## Semantic The semantic is (1) By default `torch.serialization.skip_data(materialize_fake_tensors=False)` will make `torch.save` skip writing storages (but reserve space for them in the checkpoint). ```python import torch import torch.nn as nn sd = nn.Linear(3, 5).state_dict() with torch.serialization.skip_data(): torch.save(sd, 'foo.pt') print(torch.load('foo.pt', weights_only=True)) ``` (2) With `torch.serialization.skip_data(materialize_fake_tensors=True)`If FakeTensor is passed to `torch.save` the pickler will treat these FakeTensors as being "materialized" space will be reserved in the checkpoint for the associated storage bytes, and when loading the type will be Tensor instead of FakeTensor) ```python import torch import torch.nn as nn from torch._subclasses.fake_tensor import FakeTensorMode with FakeTensorMode(): m = nn.Linear(3, 5, dtype=torch.float16, device='cuda') sd = m.state_dict() with torch.serialization.skip_data(materialize_fake_tensors=True): torch.save(sd, 'bla.pt') print(torch.load('bla.pt', weights_only=True)) # OrderedDict([('weight', tensor([[0., 0., 0.], # [0., 0., 0.], # [0., 0., 0.], # [0., 0., 0.], # [0., 0., 0.]], device='cuda:0', dtype=torch.float16)), ('bias', tensor([0., 0., 0., 0., 0.], device='cuda:0', dtype=torch.float16))]) ``` ## Follow Ups - [ ] `torch.load` semantic for skip_data context manager - [ ] Mechanism for getting offsets of storages saved via this method (for writing in a separate pass) Differential Revision: [D62238610](https://our.internmc.facebook.com/intern/diff/D62238610) Pull Request resolved: https://github.com/pytorch/pytorch/pull/134504 Approved by: https://github.com/albanD --- docs/source/notes/serialization.rst | 1 + ...cpp_extensions_open_device_registration.py | 12 +- test/test_serialization.py | 88 ++++++++++++++ torch/_tensor.py | 75 ++++++++++-- torch/_utils.py | 6 +- torch/serialization.py | 110 ++++++++++++++++-- torch/storage.py | 4 + 7 files changed, 270 insertions(+), 26 deletions(-) diff --git a/docs/source/notes/serialization.rst b/docs/source/notes/serialization.rst index 5541d28bdcafa0..c05dc028a471c9 100644 --- a/docs/source/notes/serialization.rst +++ b/docs/source/notes/serialization.rst @@ -398,3 +398,4 @@ The following utility functions are related to serialization: .. autofunction:: clear_safe_globals .. autofunction:: get_safe_globals .. autoclass:: safe_globals +.. autoclass:: skip_data diff --git a/test/test_cpp_extensions_open_device_registration.py b/test/test_cpp_extensions_open_device_registration.py index 616a6e0f4b5518..4e86ed458b0788 100644 --- a/test/test_cpp_extensions_open_device_registration.py +++ b/test/test_cpp_extensions_open_device_registration.py @@ -540,7 +540,7 @@ def test_open_device_tensorlist_type_fallback(self): # call _fused_adamw_ with undefined tensor. self.module.fallback_with_undefined_tensor() - def test_open_device_numpy_serialization_map_location(self): + def test_open_device_numpy_serialization(self): torch.utils.rename_privateuse1_backend("foo") device = self.module.custom_device() default_protocol = torch.serialization.DEFAULT_PROTOCOL @@ -553,6 +553,7 @@ def test_open_device_numpy_serialization_map_location(self): self.assertTrue( rebuild_func is torch._utils._rebuild_device_tensor_from_numpy ) + # Test map_location with TemporaryFileName() as f: torch.save(sd, f) with safe_globals( @@ -569,6 +570,15 @@ def test_open_device_numpy_serialization_map_location(self): sd_loaded = torch.load(f, map_location="cpu") self.assertTrue(sd_loaded["x"].is_cpu) + # Test metadata_only + with TemporaryFileName() as f: + with self.assertRaisesRegex( + RuntimeError, + "Cannot serialize tensors on backends with no storage under skip_data context manager", + ): + with torch.serialization.skip_data(): + torch.save(sd, f) + if __name__ == "__main__": common.run_tests() diff --git a/test/test_serialization.py b/test/test_serialization.py index a041473d195b19..3ba96b80541d88 100644 --- a/test/test_serialization.py +++ b/test/test_serialization.py @@ -1,5 +1,6 @@ # Owner(s): ["module: serialization"] +import contextlib import copy import gc import gzip @@ -19,6 +20,7 @@ from pathlib import Path import torch +from torch._subclasses.fake_tensor import FakeTensorMode, FakeTensorConverter from torch._utils import _rebuild_tensor from torch._utils_internal import get_file_path_2 from torch.serialization import ( @@ -27,6 +29,7 @@ LoadEndianness, safe_globals, set_default_load_endianness, + skip_data, SourceChangeWarning, ) from torch.testing._internal.common_device_type import instantiate_device_type_tests @@ -4212,6 +4215,91 @@ def test_filewriter_metadata_writing(self, filename): sd_loaded_ref = torch.load(f) self.assertEqual(sd_loaded, sd_loaded_ref) + @parametrize("materialize_fake", (True, False)) + def test_skip_data_serialization(self, materialize_fake): + # Create one tensor that uses each of the paths in __reduce_ex__ that should work + t_device = "cuda" if torch.cuda.is_available() else "cpu" + t_v2 = torch.randn(2, 3, device=t_device) + t_v3 = torch.randn(2, 3, dtype=torch.complex32, device=t_device) + i = torch.tensor([[0, 1, 1], + [2, 0, 2]]) + v = torch.tensor([3, 4, 5], dtype=torch.float32) + if not materialize_fake: + # FakeTensorConverter messes up sizes of i and v for the sparse tensor + st = torch.sparse_coo_tensor(i, v, (2, 4)) + tt = TwoTensor(torch.randn(2, device=t_device), torch.randn(2, device=t_device)) + + mode, converter = FakeTensorMode(), FakeTensorConverter() + + def fn(t): + return converter.from_real_tensor(mode, t) if materialize_fake else t + + sd = {'t_v2': fn(t_v2), 't_v3': fn(t_v3), 'tt': fn(tt)} + sd_expected = { + 't_v2': torch.zeros(2, 3, device=t_device), + 't_v3': torch.zeros(2, 3, dtype=torch.complex32, device=t_device), + 'tt': TwoTensor(torch.zeros(2, device=t_device), torch.zeros(2, device=t_device)), + } + + if not materialize_fake: + sd['st'] = st + sd_expected['st'] = torch.sparse_coo_tensor(torch.zeros(2, 3), torch.zeros(3), (2, 4)) + + with BytesIOContext() as f: + with skip_data(materialize_fake_tensors=materialize_fake): + torch.save(sd, f) + f.seek(0) + with safe_globals([TwoTensor]): + sd_loaded = torch.load(f, weights_only=True) + self.assertEqual(sd_loaded, sd_expected, exact_device=True) + self.assertFalse(getattr(torch.serialization._serialization_tls, "materialize_fake_tensors", False)) + self.assertFalse(getattr(torch.serialization._serialization_tls, "skip_data", False)) + + # Test that without materialize_fake_tensor, behavior for fake_tensors is not altered by ctx + if not materialize_fake: + ft = converter.from_real_tensor(mode, torch.randn(2, device=t_device)) + with self.assertRaisesRegex(AttributeError, "Can't pickle local object 'WeakValueDictionary.__init__..remove'"): + with skip_data(), BytesIOContext() as f: + torch.save(ft, f) + + @parametrize("materialize_fake", (True, False)) + def test_skip_data_serialization_preserves_views(self, materialize_fake): + ctx = FakeTensorMode if materialize_fake else contextlib.nullcontext + with ctx(): + t = torch.randn(2, 3) + t_view = t.view(-1) + t_slice = t[1] + sd = {'t': t, 't_view': t_view, 't_slice': t_slice} + with BytesIOContext() as f: + with skip_data(materialize_fake_tensors=materialize_fake): + torch.save(sd, f) + f.seek(0) + sd_loaded = torch.load(f, weights_only=True) + self.assertTrue(id(sd_loaded['t_view'].untyped_storage()) == id(sd_loaded['t'].untyped_storage())) + self.assertTrue(id(sd_loaded['t_slice'].untyped_storage()) == id(sd_loaded['t'].untyped_storage())) + + def test_skip_data_serialization_error_cases(self): + def _save_load(t): + with BytesIOContext() as f: + with skip_data(): + torch.save(t, f) + f.seek(0) + torch.load(f, weights_only=True) + + nt = torch.nested.nested_tensor([torch.randn(2), torch.randn(3)]) + t = torch.randn(2, 3, device="meta") + with self.assertRaisesRegex(RuntimeError, "Cannot serialize nested tensor under skip_data context manager"): + _save_load(nt) + + with self.assertWarnsRegex(UserWarning, "meta device under skip_data context manager is a no-op"): + _save_load(t) + + with self.assertRaisesRegex(RuntimeError, "Please call torch.load outside the skip_data context manager"): + with skip_data(), BytesIOContext() as f: + torch.save(torch.randn(2, 3), f) + f.seek(0) + torch.load(f, weights_only=True) + def run(self, *args, **kwargs): with serialization_method(use_zip=True): return super().run(*args, **kwargs) diff --git a/torch/_tensor.py b/torch/_tensor.py index 98563aebae9aa7..7d8010081abc77 100644 --- a/torch/_tensor.py +++ b/torch/_tensor.py @@ -209,8 +209,19 @@ def __deepcopy__(self, memo): return new_tensor def __reduce_ex__(self, proto): + materialize_fake_tensors = ( + torch.serialization._serialization_tls.materialize_fake_tensors + ) state = torch._utils._get_obj_state(self) - if type(self) is Tensor and not state: + # Ignore all state when using FakeTensor with skip_data(materialize_fake_tensors) because FakeTensor has + # some state that cannot be pickled + if ( + # TODO: remove hasattr, it's a hack to support versions of torch that + # don't have _subclasses + hasattr(torch, "_subclasses") + and type(self) is torch._subclasses.fake_tensor.FakeTensor + and materialize_fake_tensors + ) or (type(self) is Tensor and not state): # Fast path for regular tensor without Python state. return self._reduce_ex_internal(proto) if has_torch_function_unary(self): @@ -251,6 +262,12 @@ def _reduce_ex_internal(self, proto): # See Note [Don't serialize hooks] warn_if_has_hooks(self) backward_hooks: Dict[Any, Any] = OrderedDict() + + skip_data = torch.serialization._serialization_tls.skip_data + materialize_fake_tensors = ( + torch.serialization._serialization_tls.materialize_fake_tensors + ) + # Note: Numpy array is chosen to be the rebuild component for XLA, MTIA, MAIA Tensors. # We considered a few options: # 1. CPU tensor can't be used here. @@ -268,6 +285,10 @@ def _reduce_ex_internal(self, proto): # Convert BFloat16 tesors to Float32 before conversion to numpy, as numpy doesn't # support BFloat16. The rebuild tensor from numpy takes in the original self.dtype, # this would reconstruct the BFloat16 tensor from numpy. + if skip_data: + raise RuntimeError( + "Cannot serialize tensors on backends with no storage under skip_data context manager" + ) numpy_tensor = ( self.cpu().numpy() if self.dtype != torch.bfloat16 @@ -280,6 +301,10 @@ def _reduce_ex_internal(self, proto): if self.device.type == "meta": # NB: This implementation BREAKS storage sharing. Current # hypothesis is that no one cares for meta tensors. + if skip_data: + warnings.warn( + "Serializing tensors on the meta device under skip_data context manager is a no-op" + ) arg_meta = ( self.dtype, tuple(self.size()), @@ -288,6 +313,10 @@ def _reduce_ex_internal(self, proto): ) return (torch._utils._rebuild_meta_tensor_no_storage, arg_meta) if self.is_quantized: + if skip_data: + raise RuntimeError( + "Cannot serialize qtensor under skip_data context manager, file an issue if you need this feature" + ) # quantizer_params can be different type based on torch attribute quantizer_params: Union[ Tuple[torch.qscheme, float, int], Tuple[Any, Tensor, Tensor, int] @@ -369,6 +398,10 @@ def _reduce_ex_internal(self, proto): ) return (torch._utils._rebuild_sparse_tensor, args_sparse_compressed) elif self.is_nested: + if skip_data: + raise RuntimeError( + "Cannot serialize nested tensor under skip_data context manager, file an issue if you need this feature" + ) args_nested = ( # NB: values() currently returns the storage as a buffer in an unsafe way. # Ideally, we'd use a private API for this instead. TODO: Switch to this if @@ -383,14 +416,30 @@ def _reduce_ex_internal(self, proto): type(self) is not torch.Tensor and type(self).__torch_dispatch__ is not torch.Tensor.__torch_dispatch__ and ( - isinstance( - self, - ( - torch._subclasses.fake_tensor.FakeTensor, - torch._subclasses.functional_tensor.FunctionalTensor, - ), + isinstance(self, torch._subclasses.functional_tensor.FunctionalTensor) + or ( + not isinstance(self, torch._subclasses.fake_tensor.FakeTensor) + and self.data_ptr() == 0 ) - or self.data_ptr() == 0 + ) + ): + arg_wrapper_subclass = ( + type(self), + self.dtype, + tuple(self.size()), + self.stride(), + self.storage_offset(), + self.layout, + self.device, + self.requires_grad, + ) + return (torch._utils._rebuild_wrapper_subclass, arg_wrapper_subclass) + elif ( + type(self) is not torch.Tensor + and type(self).__torch_dispatch__ is not torch.Tensor.__torch_dispatch__ + and ( + isinstance(self, torch._subclasses.fake_tensor.FakeTensor) + and not (skip_data and materialize_fake_tensors) ) ): arg_wrapper_subclass = ( @@ -418,6 +467,16 @@ def _reduce_ex_internal(self, proto): dtype=self.dtype, _internal=True, ) # type: ignore[assignment] + + # TODO: remove hasattr, it's a hack to support versions of torch that + # don't have _subclasses + if ( + hasattr(torch, "_subclasses") + and isinstance(self, torch._subclasses.fake_tensor.FakeTensor) + and skip_data + ): + storage._fake_device = self.device + args = ( storage, self.storage_offset(), diff --git a/torch/_utils.py b/torch/_utils.py index 938392fa971590..f0d38daa811490 100644 --- a/torch/_utils.py +++ b/torch/_utils.py @@ -3,7 +3,6 @@ import functools import logging import sys -import threading import traceback import warnings from collections import defaultdict @@ -109,16 +108,13 @@ def _get_async_or_non_blocking(function_name, non_blocking, kwargs): return kwargs["async"] -_thread_local_state = threading.local() - - def _get_restore_location(device): """Return the map_location location. Used for rebuild functions where the tensor device is distinct from the storage """ - map_location = getattr(_thread_local_state, "map_location", None) + map_location = torch.serialization._serialization_tls.map_location if map_location is None: return device else: diff --git a/torch/serialization.py b/torch/serialization.py index 2ac36ec371fe57..d936d31d6f5209 100644 --- a/torch/serialization.py +++ b/torch/serialization.py @@ -11,6 +11,7 @@ import sys import tarfile import tempfile +import threading import warnings from contextlib import closing, contextmanager from enum import Enum @@ -60,6 +61,7 @@ "get_safe_globals", "add_safe_globals", "safe_globals", + "skip_data", ] @@ -87,6 +89,22 @@ MAP_SHARED, MAP_PRIVATE = None, None # type: ignore[assignment] +# _serialization_tls is used to store thread local state specific to serialization +# that needs to be propagated to other files, in particular we use this for +# (1) map_location (needed for wrapper subclasses/third party devices to torch._utils) +# (2) skip_data (needed for torch.Tensor.__reduce_ex__ for skip_data ctx) +# (3) materialize_fake_tensors (needed for torch.Tensor.__reduce_ex__ for skip_data ctx) +class _SerializationLocal(threading.local): + def __init__(self): + super().__init__() + self.map_location: Optional[MAP_LOCATION] = None + self.skip_data: bool = False + self.materialize_fake_tensors: bool = False + + +_serialization_tls = _SerializationLocal() + + class SourceChangeWarning(Warning): pass @@ -268,6 +286,47 @@ class safe_globals(_weights_only_unpickler._safe_globals): """ +class skip_data: + """ + Context-manager that skips writing storage bytes for ``torch.save`` calls. + + Storages will still be saved, but the space that their bytes would usually be written to + will be empty space. The storage bytes can then be populated in a separate pass. + + .. warning:: + The ``skip_data`` context manager is an early prototype and is subject to change. + + Args: + materialize_fake_tensors: Whether to materialize FakeTensors. + + Example: + >>> # xdoctest: +SKIP("NamedTemporaryFile on Windows") + >>> import tempfile + >>> t = torch.randn(2, 3) + >>> with tempfile.NamedTemporaryFile() as f: + ... with torch.serialization.skip_data(): + ... torch.save(t, f.name) + ... torch.load(f.name, weights_only=True) + tensor([[0., 0., 0.], + [0., 0., 0.]]) + """ + + def __init__(self, materialize_fake_tensors: bool = False): + self.materialize_fake_tensors = materialize_fake_tensors + + def __enter__(self): + global _serialization_tls + self._old_skip_data = _serialization_tls.skip_data + self._old_materialize_fake_tensors = _serialization_tls.materialize_fake_tensors + _serialization_tls.skip_data = True + _serialization_tls.materialize_fake_tensors = self.materialize_fake_tensors + + def __exit__(self, type, value, tb): + global _serialization_tls + _serialization_tls.skip_data = self._old_skip_data + _serialization_tls.materialize_fake_tensors = self._old_materialize_fake_tensors + + def _is_zipfile(f) -> bool: # This is a stricter implementation than zipfile.is_zipfile(). # zipfile.is_zipfile() is True if the magic number appears anywhere in the @@ -797,6 +856,11 @@ def save( ) return else: + global _serialization_tls + if _serialization_tls.skip_data: + raise RuntimeError( + "Cannot use skip_data=True with _use_new_zipfile_serialization=False" + ) with _open_file_like(f, "wb") as opened_file: _legacy_save(obj, opened_file, pickle_module, pickle_protocol) @@ -955,7 +1019,13 @@ def persistent_id(obj: Any) -> Optional[Tuple]: ) -def _save(obj, zip_file, pickle_module, pickle_protocol, _disable_byteorder_record): +def _save( + obj, + zip_file, + pickle_module, + pickle_protocol, + _disable_byteorder_record, +): serialized_storages = {} id_map: Dict[int, str] = {} @@ -990,7 +1060,7 @@ def persistent_id(obj): # If storage is allocated, ensure that any other saved storages # pointing to the same data all have the same dtype. If storage is # not allocated, don't perform this check - if storage.data_ptr() != 0: + if str(storage.device) != "meta" and storage.data_ptr() != 0: if storage.data_ptr() in storage_dtypes: if storage_dtype != storage_dtypes[storage.data_ptr()]: raise RuntimeError( @@ -1001,7 +1071,10 @@ def persistent_id(obj): storage_dtypes[storage.data_ptr()] = storage_dtype storage_key = id_map.setdefault(storage._cdata, str(len(id_map))) - location = location_tag(storage) + if hasattr(obj, "_fake_device") and obj._fake_device is not None: + location = str(obj._fake_device) + else: + location = location_tag(storage) serialized_storages[storage_key] = storage return ("storage", storage_type, storage_key, location, storage_numel) @@ -1027,14 +1100,18 @@ def persistent_id(obj): for key in sorted(serialized_storages.keys()): name = f"data/{key}" storage = serialized_storages[key] - # given that we copy things around anyway, we might use storage.cpu() - # this means to that to get tensors serialized, you need to implement - # .cpu() on the underlying Storage - if storage.device.type != "cpu": - storage = storage.cpu() - # Now that it is on the CPU we can directly copy it into the zip file num_bytes = storage.nbytes() - zip_file.write_record(name, storage, num_bytes) + global _serialization_tls + if _serialization_tls.skip_data: + zip_file.write_record_metadata(name, num_bytes) + else: + # given that we copy things around anyway, we might use storage.cpu() + # this means to that to get tensors serialized, you need to implement + # .cpu() on the underlying Storage + if storage.device.type != "cpu": + storage = storage.cpu() + # Now that it is on the CPU we can directly copy it into the zip file + zip_file.write_record(name, storage, num_bytes) def load( @@ -1184,6 +1261,14 @@ def _get_wo_message(message: str) -> str: updated_message += message return updated_message + DOCS_MESSAGE + global _serialization_tls + skip_data = _serialization_tls.skip_data + if skip_data: + raise RuntimeError( + "`torch.load` called within a torch.serialization.skip_data context manager " + "is not supported yet. Please call torch.load outside the skip_data context manager." + ) + if weights_only is None: weights_only, warn_weights_only = False, True else: @@ -1758,9 +1843,10 @@ def find_class(self, mod_name, name): unpickler.persistent_load = persistent_load # Needed for tensors where storage device and rebuild tensor device are # not connected (wrapper subclasses and tensors rebuilt using numpy) - torch._utils._thread_local_state.map_location = map_location + global _serialization_tls + _serialization_tls.map_location = map_location result = unpickler.load() - del torch._utils._thread_local_state.map_location + _serialization_tls.map_location = None torch._utils._validate_loaded_sparse_tensors() torch._C._log_api_usage_metadata( diff --git a/torch/storage.py b/torch/storage.py index b6ba608c16e5ca..8848649905f936 100644 --- a/torch/storage.py +++ b/torch/storage.py @@ -39,6 +39,8 @@ class _StorageBase: is_sparse: _bool = False is_sparse_csr: _bool = False device: torch.device + # Used when stashing FakeTensor device onto storage in torch.save(metadata_only=True) + _fake_device: _Optional[torch.device] = None def __init__(self, *args, **kwargs): pass @@ -649,6 +651,8 @@ def _get_device_from_module(module: str): class TypedStorage: is_sparse: _bool = False + # Used when stashing FakeTensor device onto storage in torch.save(metadata_only=True) + _fake_device: _Optional[torch.device] = None dtype: torch.dtype From b4cc9b18e4b9afd0746213b9fa6102c3a3f05c56 Mon Sep 17 00:00:00 2001 From: IvanKobzarev Date: Thu, 5 Sep 2024 06:45:54 -0700 Subject: [PATCH 0472/1018] [fake_tensor] Move unrecognized_type NotImplemented before ConstProp (#135033) We should not try to do ConstProp on the unrecognized types (e.g. Subclasses). In case of those types throwing NotImplemented will jump to the next torch_dispatch. Test: ``` python test/functorch/test_aotdispatch.py -k test_aot_test_subclasses_with_tensor_factories ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/135033 Approved by: https://github.com/zou3519, https://github.com/bdhirsh --- torch/_subclasses/fake_tensor.py | 35 ++++++++++++++++---------------- 1 file changed, 18 insertions(+), 17 deletions(-) diff --git a/torch/_subclasses/fake_tensor.py b/torch/_subclasses/fake_tensor.py index 51f726f53bacc3..27348dca80e483 100644 --- a/torch/_subclasses/fake_tensor.py +++ b/torch/_subclasses/fake_tensor.py @@ -1715,11 +1715,28 @@ def _dispatch_impl( ) -> Optional[FakeTensor]: flat_args, args_spec = pytree.tree_flatten((args, kwargs)) + # DO NOT PUT LOGIC BEFORE UNRECOGNIZED TYPE CHECKING + # We must throw NotImplemented in case of unrecognized types to handle subclasses. + # Throwing the exception will pass the control to the next __torch_dispatch__. + # See [subclass inputs] below + # NB: If you're seeing a mysterious infinite loop involving fake + # tensor, it might be related to this line. Though I'm not sure + # how you'll know to read this comment, as this line won't show up + # in the stack trace. + has_unrecognized_types = _check_for_subclass(flat_args) + if has_unrecognized_types: + unrecognized_types = [ + type(x) for x in flat_args if _check_for_subclass_arg(x) + ] + not_implemented_log.debug( + "FakeTensorMode unrecognized subclass(es): %s", unrecognized_types + ) + return NotImplemented + flat_arg_fake_tensors = [t for t in flat_args if self.is_our_fake(t)] has_symbolic_sizes = any( i._has_symbolic_sizes_strides for i in flat_arg_fake_tensors ) or any(isinstance(a, SymInt) for a in flat_args) - has_subclasses = any(is_traceable_wrapper_subclass(a) for a in flat_args) converter = self.fake_tensor_converter @@ -1737,7 +1754,6 @@ def _dispatch_impl( should_allow_numbers_as_tensors(func) and not has_symbolic_sizes and not flat_arg_fake_tensors - and not has_subclasses ): assert all( t.constant is not None for t in flat_arg_fake_tensors @@ -1757,21 +1773,6 @@ def _dispatch_impl( out = out.clone() return converter.from_real_tensor(self, out, make_constant=True) - # See [subclass inputs] below - # NB: If you're seeing a mysterious infinite loop involving fake - # tensor, it might be related to this line. Though I'm not sure - # how you'll know to read this comment, as this line won't show up - # in the stack trace. - has_unrecognized_types = _check_for_subclass(flat_args) - if has_unrecognized_types: - unrecognized_types = [ - type(x) for x in flat_args if _check_for_subclass_arg(x) - ] - not_implemented_log.debug( - "FakeTensorMode unrecognized subclass(es): %s", unrecognized_types - ) - return NotImplemented - # if we are in the dispatch mode, we will enter this function even if the inputs # are not FakeTensors. For now, throw if any non-Fake Tensor inputs # and just support constructors. From 100f23d11ec6f9bbb55887af2dfcdefdb82f8626 Mon Sep 17 00:00:00 2001 From: CaoE Date: Thu, 5 Sep 2024 06:54:17 -0700 Subject: [PATCH 0473/1018] [Inductor][CPP] Leverage full bits for BF16/FP16 vectorization (#126502) Pull Request resolved: https://github.com/pytorch/pytorch/pull/126502 Approved by: https://github.com/jgong5, https://github.com/jansel --- aten/src/ATen/cpu/vec/vec256/vec256_float.h | 15 ++++++ aten/src/ATen/cpu/vec/vec512/vec512_float.h | 36 +++++++++++-- test/inductor/test_cpu_repro.py | 56 +++++++++++++++++++ torch/_inductor/codegen/cpp.py | 60 ++++++++++++++------- 4 files changed, 144 insertions(+), 23 deletions(-) diff --git a/aten/src/ATen/cpu/vec/vec256/vec256_float.h b/aten/src/ATen/cpu/vec/vec256/vec256_float.h index 0d0fe99252a7d0..dab1790b26ab01 100644 --- a/aten/src/ATen/cpu/vec/vec256/vec256_float.h +++ b/aten/src/ATen/cpu/vec/vec256/vec256_float.h @@ -636,6 +636,21 @@ inline void transpose_mxn( _mm256_storeu_ps(&dst[7 * ld_dst], th); } +template<> +inline void transpose_mxn( + const float* src, + int64_t ld_src, + float* dst, + int64_t ld_dst) { + transpose_mxn( + src , ld_src, dst, ld_dst); + transpose_mxn( + src + 8, ld_src, dst + 8 * ld_dst, ld_dst); + transpose_mxn( + src + 8 * ld_src, ld_src, dst + 8, ld_dst); + transpose_mxn( + src + 8 * ld_src + 8, ld_src, dst + 8 * ld_dst + 8, ld_dst); +} #endif }} // namespace at::vec::CPU_CAPABILITY diff --git a/aten/src/ATen/cpu/vec/vec512/vec512_float.h b/aten/src/ATen/cpu/vec/vec512/vec512_float.h index 804e9f404297ab..4e21eae91cb240 100644 --- a/aten/src/ATen/cpu/vec/vec512/vec512_float.h +++ b/aten/src/ATen/cpu/vec/vec512/vec512_float.h @@ -582,8 +582,7 @@ Vectorized inline fmsub(const Vectorized& a, const Vectorized -inline void transpose_mxn(const float* src, int64_t ld_src, float* dst, int64_t ld_dst, int M, int N) { +inline void transpose_mxn_16x16(const float* src, int64_t ld_src, float* dst, int64_t ld_dst, int M, int N) { TORCH_CHECK(M <= 16 && N <= 16, "transpose_mxn expects M, N <= 16."); // load from src to registers __m512 input[16]; @@ -667,8 +666,39 @@ inline void transpose_mxn(const float* src, int64_t ld_src, float* dst, i } } +template<> +inline void transpose_mxn(const float* src, int64_t ld_src, float* dst, int64_t ld_dst, int M, int N) { + int64_t i = 0; + for (; i < M / 16 * 16; i += 16) { + int64_t j = 0; + for (; j < N / 16 * 16; j += 16) { + transpose_mxn_16x16( + src + i * ld_src + j, ld_src, dst + j * ld_dst + i, ld_dst, 16, 16); + } + // handle remainder j + int nrem = N - j; + if (nrem > 0) { + transpose_mxn_16x16( + src + i * ld_src + j, ld_src, dst + j * ld_dst + i, ld_dst, 16, nrem); + } + } + // handle remainder i + int mrem = M - i; + if (mrem > 0) { + int j = 0; + for (; j < N / 16 * 16; j += 16) { + transpose_mxn_16x16( + src + i * ld_src + j, ld_src, dst + j * ld_dst + i, ld_dst, mrem, 16); + } + // handle remainder j + int nrem = N - j; + transpose_mxn_16x16( + src + i * ld_src + j, ld_src, dst + j * ld_dst + i, ld_dst, mrem, nrem); + } +} + template ::value && M <= 16 && N <= 16, int> = 0> + typename std::enable_if_t::value, int> = 0> inline void transpose_mxn(const float* src, int64_t ld_src, float* dst, int64_t ld_dst) { transpose_mxn(src, ld_src, dst, ld_dst, M, N); } diff --git a/test/inductor/test_cpu_repro.py b/test/inductor/test_cpu_repro.py index d1f5a82cdf6c74..100b5fa35361e1 100644 --- a/test/inductor/test_cpu_repro.py +++ b/test/inductor/test_cpu_repro.py @@ -4056,6 +4056,62 @@ def fn(x1, x2, x3): n_veckernel = 6 if op is torch.masked.mean else 3 check_metrics_vec_kernel_count(n_veckernel) + @requires_vectorization + def test_full_bits_lowp(self): + def check_use_full_bits(func, shapes, dtype, mixed, check_vecn): + example_inputs = [torch.randn(shape, dtype=dtype) for shape in shapes] + if mixed: + example_inputs[0] = example_inputs[0].to( + dtype=torch.half if dtype == torch.bfloat16 else torch.bfloat16 + ) + f_opt = torch.compile()(func) + _, code = run_and_get_cpp_code(f_opt, *example_inputs) + if check_vecn: + self.assertTrue( + "at::vec::VectorizedN" in code or "at::vec::convert" ) - code.writeline(f"at::vec::Vectorized {exponent};") + code.writeline( + f"at::vec::Vectorized {exponent};" + if n_vec == 1 + else f"at::vec::VectorizedN {exponent};" + ) code.writeline(f"{mantissa_t} {mantissa};") code.writeline("[&]()") with code.indent(): @@ -1603,6 +1608,8 @@ def frexp(x): ) code.writeline( f"{exponent} = at::vec::Vectorized::loadu(tmpbuf_exponent.data(), {cexpr_index(size)});" + if n_vec == 1 + else f"{exponent} = at::vec::VectorizedN::loadu(tmpbuf_exponent.data(), {cexpr_index(size)});" ) code.writeline( f"{mantissa} = {mantissa_t}::loadu(tmpbuf_mantissa.data(), {cexpr_index(size)});" @@ -2667,9 +2674,9 @@ def reduction(self, dtype, src_dtype, reduction_type, value): ) is_bool = dtype == torch.bool # we are using at::vec::VecMask for bool - vec_dtype = "float" if is_bool else DTYPE_TO_CPP[dtype] - vec = f"at::vec::Vectorized<{vec_dtype}>" - vec_reduce_all_func = f"at::vec::vec_reduce_all<{vec_dtype}>" + vec_dtype = torch.float if is_bool else dtype + vec = f"at::vec::Vectorized<{DTYPE_TO_CPP[vec_dtype]}>" + vec_reduce_all_func = f"at::vec::vec_reduce_all<{DTYPE_TO_CPP[vec_dtype]}, {self._get_num_vectors(vec_dtype)}>" next_value = f"{vec_reduce_all_func}([]({vec}& x, {vec}& y) {reduce_all_body}, {acc_vec})" self.reduction_suffix.writeline( @@ -2697,6 +2704,8 @@ def store_reduction(self, name, index, value): if out_dtype.is_floating_point else torch.int64 ) + out_num_vectors = V.kernel._get_num_vectors(out_dtype) + src_num_vectors = V.kernel._get_num_vectors(dtype) code = IndentedBuffer() if self.tiling_idx >= self.reduction_depth: # Horizontal reduction @@ -2710,7 +2719,15 @@ def store_reduction(self, name, index, value): if out_dtype == torch.bool: convert = f"{value}.template cast()" else: - convert = f"at::vec::convert<{DTYPE_TO_CPP[out_dtype]}>({value})" + if src_num_vectors == out_num_vectors == 1: + convert = ( + f"at::vec::convert<{DTYPE_TO_CPP[out_dtype]}>({value})" + ) + else: + convert = ( + f"at::vec::convert<{DTYPE_TO_CPP[out_dtype]}," + f"{out_num_vectors},{DTYPE_TO_CPP[dtype]},{src_num_vectors}>({value})" + ) code.writeline(f"auto {converted_value} = {convert};") value = converted_value code.splice(self._get_store_line(value, var, index, out_dtype)) @@ -3167,15 +3184,16 @@ def transform_indexing(self, index: sympy.Expr) -> sympy.Expr: ) -def get_loop_body_lowp_fp(_body: ir.LoopBody) -> Optional[torch.dtype]: +def get_loop_body_lowp_fp(_body: ir.LoopBody) -> Tuple[Optional[torch.dtype], bool]: """ - Returns the low precision float data type (torch.float16/torch.bfloat16) if all the - nodes can codegen with this data type. Otherwise returns None. + Returns the low precision data type (torch.float16/torch.bfloat16) contained in the nodes + and if all the nodes can codegen with this data type without converting to float. + Otherwise returns None and True. """ sub_blocks = [_body.root_block] + list(_body.subblocks.values()) _lowp_fp_type: Optional[torch.dtype] = None - + _use_fp32 = False for sub_block in sub_blocks: for _node in sub_block.graph.nodes: if _node.op == "placeholder" or _node.target in ( @@ -3192,23 +3210,22 @@ def get_loop_body_lowp_fp(_body: ir.LoopBody) -> Optional[torch.dtype]: "neg", "output", ]: - return None + _use_fp32 = True if hasattr(_node, "meta") and _node.meta: assert OptimizationContext.key in _node.meta opt_ctx: OptimizationContext = _node.meta[OptimizationContext.key] if not opt_ctx.dtype or opt_ctx.dtype not in DTYPE_LOWP_FP: - return None - if _lowp_fp_type: - assert ( - _lowp_fp_type == opt_ctx.dtype - ), "do not support bf16/fp16 mix" + _use_fp32 = True + elif _lowp_fp_type is not None: + if _lowp_fp_type != opt_ctx.dtype: + warnings.warn("bf16 and fp16 are mixed in the scheduler node.") else: _lowp_fp_type = opt_ctx.dtype else: - return None + _use_fp32 = True - return _lowp_fp_type + return _lowp_fp_type, _use_fp32 class TilingSelect: @@ -3232,9 +3249,9 @@ def select_tiling( if any(dtype not in VECTORIZABLE_DTYPES for dtype in all_dtypes): return [], [] dtype = torch.float - _lowp_fp_dtype = get_loop_body_lowp_fp(loop_bodies[0]) + _lowp_fp_dtype = get_loop_body_lowp_fp(loop_bodies[0])[0] if _lowp_fp_dtype and all( - (get_loop_body_lowp_fp(loop_body) == _lowp_fp_dtype) + (get_loop_body_lowp_fp(loop_body)[0] == _lowp_fp_dtype) for loop_body in loop_bodies[1:] ): dtype = _lowp_fp_dtype @@ -3433,7 +3450,10 @@ def is_lowp_fp_scheduler(self, scheduler_node: SchedulerNode): return True # Propagate the dtype to check if all the fx node is bf16/fp16 DataTypePropagation.propagate_scheduler_node(scheduler_node) - return get_loop_body_lowp_fp(scheduler_node._body) is not None + return ( + get_loop_body_lowp_fp(scheduler_node._body)[0] is not None + and not get_loop_body_lowp_fp(scheduler_node._body)[1] + ) def legalize_lowp_fp_dtype_loopbody(self, loop_body: ir.LoopBody): def add_to_dtype(sub_graph: torch.fx.Graph): From 6e787dfeb9532963001b3171d8a2aeb782f045ca Mon Sep 17 00:00:00 2001 From: Xintong Hu Date: Thu, 5 Sep 2024 18:17:06 +0000 Subject: [PATCH 0474/1018] [PT2] Directly set meta.val in group_batch_fusion_aten (#135078) Summary: instead of using FakeTensorProp after the pass Differential Revision: D62162640 Pull Request resolved: https://github.com/pytorch/pytorch/pull/135078 Approved by: https://github.com/frank-wei --- torch/_inductor/fx_passes/pre_grad.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/torch/_inductor/fx_passes/pre_grad.py b/torch/_inductor/fx_passes/pre_grad.py index e9cc56b2b0f6b4..eef058ab0b5aa2 100644 --- a/torch/_inductor/fx_passes/pre_grad.py +++ b/torch/_inductor/fx_passes/pre_grad.py @@ -12,7 +12,6 @@ matches_module_pattern, replace_node_module, ) -from torch.fx.passes.fake_tensor_prop import FakeTensorProp from torch.fx.passes.graph_transform_observer import GraphTransformObserver from torch.fx.passes.shape_prop import ShapeProp from torch.nn import functional as F @@ -169,10 +168,6 @@ def shape_prop(mod) -> None: example_inputs, "[Pre grad(predispatch IR)] Apply group_batch_fusion", ) - # update node.meta after group batch fusion - FakeTensorProp(module=gm, mode=detect_fake_mode(example_inputs)).propagate( - *example_inputs - ) pass_execution_and_save( normalize_node_kwargs_pass, gm, From f7c7342015c9d358358b3c0362e0c3c4fd10acf3 Mon Sep 17 00:00:00 2001 From: Zhonglin Han Date: Thu, 5 Sep 2024 18:53:13 +0000 Subject: [PATCH 0475/1018] Update Resize.cpp with new device type (#135117) Update Resize.cpp with new device type Pull Request resolved: https://github.com/pytorch/pytorch/pull/135117 Approved by: https://github.com/egienvalue --- aten/src/ATen/native/Resize.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/aten/src/ATen/native/Resize.cpp b/aten/src/ATen/native/Resize.cpp index e4d63de7f4d6d8..f64b9cc3cefa34 100644 --- a/aten/src/ATen/native/Resize.cpp +++ b/aten/src/ATen/native/Resize.cpp @@ -284,7 +284,7 @@ void resize_bytes_nocuda(const Storage& storage, const c10::SymInt& newsize) { } else if (device_type == at::kPrivateUse1) { at::detail::getPrivateUse1Hooks().resizePrivateUse1Bytes( storage, newsize.expect_int()); - } else if (device_type == at::kXPU || device_type == at::kHPU) { + } else if (device_type == at::kXPU || device_type == at::kHPU || device_type == at::kMTIA) { ptrdiff_t size_bytes_i = newsize.expect_int(); TORCH_CHECK( !c10::overflows(size_bytes_i), From e8e8fe358a50d5ec0990e6e68fb2bbf0c03bb6ee Mon Sep 17 00:00:00 2001 From: rzou Date: Wed, 4 Sep 2024 14:44:53 -0700 Subject: [PATCH 0476/1018] Allow cross-device copies for cpu scalars in refs (#135140) This copies our eager-mode behavior where someone can do torch.add(a, b, out=c) where a and b are CPU scalar tensors and c is a CUDA tensor. Fixes https://github.com/pytorch/pytorch/issues/121619 by side effect (we get into a situation where we're writing a CPU scalar into a FakeTensor that is actually a meta tensor) Test Plan: - new test Pull Request resolved: https://github.com/pytorch/pytorch/pull/135140 Approved by: https://github.com/williamwen42, https://github.com/yanboliang --- test/dynamo/test_misc.py | 13 +++++++++++++ torch/_prims_common/wrappers.py | 6 +++++- 2 files changed, 18 insertions(+), 1 deletion(-) diff --git a/test/dynamo/test_misc.py b/test/dynamo/test_misc.py index e0581ffc34b45b..ab1b343044a26d 100644 --- a/test/dynamo/test_misc.py +++ b/test/dynamo/test_misc.py @@ -308,6 +308,19 @@ def fn(x): "Graph break for an optree C/C++ function optree._C.PyCapsule.flatten. Consider using torch.utils._pytree - https://github.com/pytorch/pytorch/blob/main/torch/utils/_pytree.py", ) + def test_scalar_device_movement(self): + if not torch._dynamo.config.assume_static_by_default: + self.skipTest("Doesn't work with symints") + + def add_fn(a, b, out): + res = torch.add(a, b, out=out) + return res + + res = add_fn(2, 3, torch.tensor(0.0)) + add_fn = torch.compile(add_fn, backend="eager", fullgraph=True) + res_compiled = add_fn(2, 3, torch.tensor(0.0)) + self.assertEqual(res, res_compiled) + @skipIfNNModuleInlined("fails internal CI") @unittest.skipIf(IS_FBCODE, "inline cpp_extension doesn't work in fbcode") def test_cpp_extension_recommends_custom_ops(self): diff --git a/torch/_prims_common/wrappers.py b/torch/_prims_common/wrappers.py index 043322dc9d7448..a89ea7cb9997e6 100644 --- a/torch/_prims_common/wrappers.py +++ b/torch/_prims_common/wrappers.py @@ -185,11 +185,15 @@ def _maybe_resize_out( return out +def is_cpu_scalar(x: TensorLikeType) -> bool: + return x.dim() == 0 and x.device.type == "cpu" + + def _safe_copy_out( *, copy_from: TensorLikeType, copy_to: TensorLikeType, exact_dtype: bool = False ): # Checks same device - if copy_from.device != copy_to.device: + if not is_cpu_scalar(copy_from) and copy_from.device != copy_to.device: msg = ( f"Attempting to copy from device {copy_from.device} " f"to device {copy_to.device}, but cross-device copies are not allowed!" From 18476914dbf0f1cb8fabf912107ead4591437365 Mon Sep 17 00:00:00 2001 From: atalman Date: Thu, 5 Sep 2024 20:03:39 +0000 Subject: [PATCH 0477/1018] Use actions/upload-artifact@v4.4.0 for triton builds (#135263) Same as: https://github.com/pytorch/pytorch/pull/135139 Fixes upload failure: https://github.com/pytorch/pytorch/actions/runs/10722567217/job/29748125015 fix regression introduced by https://github.com/pytorch/pytorch/pull/135068 Pull Request resolved: https://github.com/pytorch/pytorch/pull/135263 Approved by: https://github.com/kit1980, https://github.com/huydhn --- .github/workflows/build-triton-wheel.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/build-triton-wheel.yml b/.github/workflows/build-triton-wheel.yml index a8240df9788dfb..f036252140fd00 100644 --- a/.github/workflows/build-triton-wheel.yml +++ b/.github/workflows/build-triton-wheel.yml @@ -120,7 +120,7 @@ jobs: fi docker exec -t "${container_name}" chown -R 1000.1000 /artifacts - - uses: actions/upload-artifact@v3 + - uses: actions/upload-artifact@v4.4.0 with: name: pytorch-triton-wheel-${{ matrix.py_vers }}-${{ matrix.device }} if-no-files-found: error @@ -253,7 +253,7 @@ jobs: docker exec -t "${container_name}" python /pytorch/.github/scripts/build_triton_wheel.py --build-conda --py-version="${PY_VERS}" $RELEASE docker exec -t "${container_name}" chown -R 1000.1000 /artifacts - - uses: actions/upload-artifact@v3 + - uses: actions/upload-artifact@v4.4.0 with: name: pytorch-triton-conda-${{ matrix.py_vers }} if-no-files-found: error From c63bb3919b526aaddddd9de3a0efeebfa6f38c85 Mon Sep 17 00:00:00 2001 From: Angela Yi Date: Thu, 5 Sep 2024 20:28:42 +0000 Subject: [PATCH 0478/1018] [export] Add ability to run eagerly on UnflattenedModule (#133996) Summary: Added the contextmanager, `_disable_interpreter`, which is meant to put around a call to `unflatten`. This will generate an UnflattendModule and sub-InterpreterModules which will not use torch.fx.Interpreter to run eagerly. We want to have this as a state of the module instead of a contextmanager around running the module because it's not clear where we are calling the unflattened module. This seems to improve the performance: https://fb.workplace.com/groups/1075192433118967/posts/1473590629945810/?comment_id=1473621763276030 Test Plan: CI Differential Revision: D60939034 Pull Request resolved: https://github.com/pytorch/pytorch/pull/133996 Approved by: https://github.com/pianpwk --- test/export/test_unflatten.py | 72 +++++++++++++++++++++++++++++++++++ torch/export/unflatten.py | 34 ++++++++++++++--- 2 files changed, 101 insertions(+), 5 deletions(-) diff --git a/test/export/test_unflatten.py b/test/export/test_unflatten.py index e4ef525e7cdbe0..d179371389c3f7 100644 --- a/test/export/test_unflatten.py +++ b/test/export/test_unflatten.py @@ -22,6 +22,7 @@ from torch._higher_order_ops.torchbind import enable_torchbind_tracing from torch.export import Constraint, Dim, export, FlatArgsAdapter, unflatten from torch.export._trace import DEFAULT_EXPORT_DYNAMO_CONFIG +from torch.export.unflatten import _disable_interpreter from torch.fx.experimental.proxy_tensor import make_fx from torch.testing import FileCheck from torch.testing._internal.common_utils import ( @@ -896,6 +897,77 @@ def forward(self, x, y): self.assertEqual(fn_count_sym_size(unflat.m1.graph), 1) self.assertEqual(fn_count_sym_size(unflat.m2.graph), 0) + def test_unflatten_eager(self): + class NestedChild(torch.nn.Module): + def forward(self, x): + return x / x + + class Child1(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.nested = NestedChild() + self.register_parameter( + "child1param", torch.nn.Parameter(torch.ones(2, 3)) + ) + + def forward(self, x): + x = self.nested(x) + return x + self.child1param + + class Child2(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.child2buffer = torch.nn.Buffer(torch.ones(2, 3)) + + def forward(self, x): + return x - self.child2buffer + + class MyModule(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.foo = Child1() + self.bar = Child2() + self.register_parameter( + "rootparam", torch.nn.Parameter(torch.ones(2, 3)) + ) + + def forward(self, x): + x = x * self.rootparam + x = self.foo(x) + x = self.bar(x) + return x + + orig_eager = MyModule() + export_module = export(orig_eager, (torch.rand(2, 3),), {}) + with _disable_interpreter(): + unflattened = unflatten(export_module) + + self.assertEqual(unflattened._run_with_interpeter, False) + self.assertEqual(unflattened.foo._run_with_interpeter, False) + + inputs = (torch.rand(2, 3),) + + # Compare the root modules and all submodules + self.compare_outputs(orig_eager, unflattened, inputs) + self.compare_outputs(orig_eager.foo, unflattened.foo, inputs) + self.compare_outputs(orig_eager.bar, unflattened.bar, inputs) + self.compare_outputs(orig_eager.foo.nested, unflattened.foo.nested, inputs) + + # Check state dicts are equal + orig_state_dict = orig_eager.state_dict() + exported_state_dict = unflattened.state_dict() + for name, value in orig_state_dict.items(): + self.assertTrue(torch.allclose(value, exported_state_dict[name])) + + # Check composability with symbolic trace, as torchrec ddp uses symbolic + # tracer + symbolic_traced = torch.fx.symbolic_trace(unflattened, concrete_args=inputs) + self.assertTrue(torch.allclose(orig_eager(*inputs), symbolic_traced(*inputs))) + + # torch.compile submodule + unflattened.foo = torch.compile(unflattened.foo, fullgraph=True) + self.compare_outputs(orig_eager, unflattened, inputs) + if __name__ == "__main__": run_tests() diff --git a/torch/export/unflatten.py b/torch/export/unflatten.py index 6c0e75e4e0076a..2962cdec8fd16c 100644 --- a/torch/export/unflatten.py +++ b/torch/export/unflatten.py @@ -3,6 +3,7 @@ import copy import operator from collections import defaultdict +from contextlib import contextmanager from copy import deepcopy from enum import Enum from typing import Any, cast, Dict, List, Optional, Set, Tuple, Union @@ -36,6 +37,20 @@ class _AttrKind(Enum): CONSTANT = "constant" +RUN_WITH_INTERPRETER = True + + +@contextmanager +def _disable_interpreter(): + global RUN_WITH_INTERPRETER + old_flag = RUN_WITH_INTERPRETER + RUN_WITH_INTERPRETER = False + try: + yield + finally: + RUN_WITH_INTERPRETER = old_flag + + # Assign attribute 'from_obj' to the qualified name 'target' on 'to_module # This installs empty Modules where none exist yet if they are subpaths of target def _assign_attr( @@ -87,13 +102,18 @@ def __init__( super().__init__() self.graph = graph self.graph.owning_module = self + self._run_with_interpeter = RUN_WITH_INTERPRETER def forward(self, *args, **kwargs): assert self.graph_module is not None, "Didn't finalize this InterpreterModule" - if torch.compiler.is_dynamo_compiling(): + if not is_fx_tracing() and ( + torch.compiler.is_dynamo_compiling() or not self._run_with_interpeter + ): # Dynamo cannot trace through torch.fx.Interpreter, so fall back to # GraphModule codegen in this instance. - return self.graph_module(*args, **kwargs) + # Patch the codegened forward to run with this InterpreterModule, + # so attribute accesses, etc. are on this module instead. + return type(self.graph_module).forward(self, *args, **kwargs) else: if kwargs: # Handle **kwargs. FX only natively supports positional @@ -186,6 +206,7 @@ def __init__( self.flat_args_adapter = flat_args_adapter # Flag to indicate whether args have been adapted. self.adapted = False + self._run_with_interpeter = RUN_WITH_INTERPRETER _inplace_buffer_mutations(export_graph, self.graph_signature) _outline_submodules(export_graph, self) @@ -480,9 +501,12 @@ def forward(self, *args, **kwargs): _check_input_constraints_for_graph( self.input_placeholders, new_flat_args_with_path, self.range_constraints ) - tree_out = torch.fx.Interpreter(self, graph=self.graph).run( - *flat_args, enable_io_processing=False - ) + if torch.compiler.is_dynamo_compiling() and not self._run_with_interpreter: + tree_out = torch.fx.GraphModule(self, self.graph)(*flat_args) + else: + tree_out = torch.fx.Interpreter(self, graph=self.graph).run( + *flat_args, enable_io_processing=False + ) return pytree.tree_unflatten(tree_out, signature.out_spec) def print_readable( From f1b68caead81976de5c3d5fa5824f754595ad7b1 Mon Sep 17 00:00:00 2001 From: Jack Taylor Date: Thu, 5 Sep 2024 20:36:45 +0000 Subject: [PATCH 0479/1018] [ROCm] remove triton-rocm commit pin and merge pins with triton.txt (#133438) Pull Request resolved: https://github.com/pytorch/pytorch/pull/133438 Approved by: https://github.com/jithunnair-amd, https://github.com/malfet Co-authored-by: Jithun Nair <37884920+jithunnair-amd@users.noreply.github.com> --- .ci/docker/centos-rocm/Dockerfile | 4 ++-- .ci/docker/ci_commit_pins/triton-rocm.txt | 1 - .ci/docker/ci_commit_pins/triton.txt | 2 +- .ci/docker/common/install_triton.sh | 5 +---- .ci/docker/ubuntu-rocm/Dockerfile | 4 ++-- .circleci/scripts/binary_populate_env.sh | 2 +- .github/scripts/build_triton_wheel.py | 4 +--- .github/workflows/build-triton-wheel.yml | 2 -- CODEOWNERS | 1 - 9 files changed, 8 insertions(+), 17 deletions(-) delete mode 100644 .ci/docker/ci_commit_pins/triton-rocm.txt diff --git a/.ci/docker/centos-rocm/Dockerfile b/.ci/docker/centos-rocm/Dockerfile index bfac9ddd859084..30ce1406e3f81c 100644 --- a/.ci/docker/centos-rocm/Dockerfile +++ b/.ci/docker/centos-rocm/Dockerfile @@ -108,10 +108,10 @@ ENV CMAKE_C_COMPILER cc ENV CMAKE_CXX_COMPILER c++ COPY ./common/install_triton.sh install_triton.sh COPY ./common/common_utils.sh common_utils.sh -COPY ci_commit_pins/triton-rocm.txt triton-rocm.txt +COPY ci_commit_pins/triton.txt triton.txt COPY triton_version.txt triton_version.txt RUN if [ -n "${TRITON}" ]; then bash ./install_triton.sh; fi -RUN rm install_triton.sh common_utils.sh triton-rocm.txt triton_version.txt +RUN rm install_triton.sh common_utils.sh triton.txt triton_version.txt # Install AOTriton (Early fail) COPY ./aotriton_version.txt aotriton_version.txt diff --git a/.ci/docker/ci_commit_pins/triton-rocm.txt b/.ci/docker/ci_commit_pins/triton-rocm.txt deleted file mode 100644 index 0cb336acccb5b1..00000000000000 --- a/.ci/docker/ci_commit_pins/triton-rocm.txt +++ /dev/null @@ -1 +0,0 @@ -21eae954efa5bf584da70324b640288c3ee7aede diff --git a/.ci/docker/ci_commit_pins/triton.txt b/.ci/docker/ci_commit_pins/triton.txt index 41c8d1602b6bef..378d625784fbc5 100644 --- a/.ci/docker/ci_commit_pins/triton.txt +++ b/.ci/docker/ci_commit_pins/triton.txt @@ -1 +1 @@ -dedb7bdf339a3546896d4820366ca562c586bfa0 +757b6a61e7df814ba806f498f8bb3160f84b120c diff --git a/.ci/docker/common/install_triton.sh b/.ci/docker/common/install_triton.sh index d4a0dc80ee177e..6e5fb8839c686e 100755 --- a/.ci/docker/common/install_triton.sh +++ b/.ci/docker/common/install_triton.sh @@ -12,10 +12,7 @@ conda_reinstall() { as_jenkins conda install -q -n py_$ANACONDA_PYTHON_VERSION -y --force-reinstall $* } -if [ -n "${ROCM_VERSION}" ]; then - TRITON_REPO="https://github.com/openai/triton" - TRITON_TEXT_FILE="triton-rocm" -elif [ -n "${XPU_VERSION}" ]; then +if [ -n "${XPU_VERSION}" ]; then TRITON_REPO="https://github.com/intel/intel-xpu-backend-for-triton" TRITON_TEXT_FILE="triton-xpu" else diff --git a/.ci/docker/ubuntu-rocm/Dockerfile b/.ci/docker/ubuntu-rocm/Dockerfile index ee9ede8ba611b6..07e25f533a71f9 100644 --- a/.ci/docker/ubuntu-rocm/Dockerfile +++ b/.ci/docker/ubuntu-rocm/Dockerfile @@ -100,10 +100,10 @@ ARG TRITON # try to reach out to S3, which docker build runners don't have access COPY ./common/install_triton.sh install_triton.sh COPY ./common/common_utils.sh common_utils.sh -COPY ci_commit_pins/triton-rocm.txt triton-rocm.txt +COPY ci_commit_pins/triton.txt triton.txt COPY triton_version.txt triton_version.txt RUN if [ -n "${TRITON}" ]; then bash ./install_triton.sh; fi -RUN rm install_triton.sh common_utils.sh triton-rocm.txt triton_version.txt +RUN rm install_triton.sh common_utils.sh triton.txt triton_version.txt # Install AOTriton COPY ./aotriton_version.txt aotriton_version.txt diff --git a/.circleci/scripts/binary_populate_env.sh b/.circleci/scripts/binary_populate_env.sh index e918635922a45e..106d0917ca68c5 100755 --- a/.circleci/scripts/binary_populate_env.sh +++ b/.circleci/scripts/binary_populate_env.sh @@ -90,7 +90,7 @@ fi if [[ "$PACKAGE_TYPE" =~ .*wheel.* && -n "$PYTORCH_BUILD_VERSION" && "$PYTORCH_BUILD_VERSION" =~ .*rocm.* && $(uname) == "Linux" ]]; then TRITON_REQUIREMENT="pytorch-triton-rocm==${TRITON_VERSION}; ${TRITON_CONSTRAINT}" if [[ -n "$PYTORCH_BUILD_VERSION" && "$PYTORCH_BUILD_VERSION" =~ .*dev.* ]]; then - TRITON_SHORTHASH=$(cut -c1-10 $PYTORCH_ROOT/.ci/docker/ci_commit_pins/triton-rocm.txt) + TRITON_SHORTHASH=$(cut -c1-10 $PYTORCH_ROOT/.ci/docker/ci_commit_pins/triton.txt) TRITON_REQUIREMENT="pytorch-triton-rocm==${TRITON_VERSION}+${TRITON_SHORTHASH}; ${TRITON_CONSTRAINT}" fi if [[ -z "${PYTORCH_EXTRA_INSTALL_REQUIREMENTS:-}" ]]; then diff --git a/.github/scripts/build_triton_wheel.py b/.github/scripts/build_triton_wheel.py index 7ee2cb4b8e61b5..096b20fe0909c0 100644 --- a/.github/scripts/build_triton_wheel.py +++ b/.github/scripts/build_triton_wheel.py @@ -15,9 +15,7 @@ def read_triton_pin(device: str = "cuda") -> str: triton_file = "triton.txt" - if device == "rocm": - triton_file = "triton-rocm.txt" - elif device == "xpu": + if device == "xpu": triton_file = "triton-xpu.txt" with open(REPO_DIR / ".ci" / "docker" / "ci_commit_pins" / triton_file) as f: return f.read().strip() diff --git a/.github/workflows/build-triton-wheel.yml b/.github/workflows/build-triton-wheel.yml index f036252140fd00..73abadd4f12e26 100644 --- a/.github/workflows/build-triton-wheel.yml +++ b/.github/workflows/build-triton-wheel.yml @@ -13,7 +13,6 @@ on: - .github/scripts/build_triton_wheel.py - .github/ci_commit_pins/triton.txt - .ci/docker/ci_commit_pins/triton.txt - - .ci/docker/ci_commit_pins/triton-rocm.txt - .ci/docker/ci_commit_pins/triton-xpu.txt pull_request: paths: @@ -21,7 +20,6 @@ on: - .github/scripts/build_triton_wheel.py - .github/ci_commit_pins/triton.txt - .ci/docker/ci_commit_pins/triton.txt - - .ci/docker/ci_commit_pins/triton-rocm.txt - .ci/docker/ci_commit_pins/triton-xpu.txt concurrency: diff --git a/CODEOWNERS b/CODEOWNERS index 7b9db26104a9d5..bafce8f6f53521 100644 --- a/CODEOWNERS +++ b/CODEOWNERS @@ -57,7 +57,6 @@ nn/qat/ @jerryzh168 # Docker /.ci/docker/ @jeffdaily /.ci/docker/ci_commit_pins/triton.txt @desertfire @Chillee @eellison @shunting314 @bertmaher @jeffdaily @jataylo @jithunnair-amd @pruthvistony -/.ci/docker/ci_commit_pins/triton-rocm.txt @jeffdaily @jataylo @jithunnair-amd @pruthvistony /.ci/docker/ci_commit_pins/triton-xpu.txt @EikanWang @gujinghui # Github Actions From d6e3887c12afc0e5d876dba5d8b9393b818c3c3d Mon Sep 17 00:00:00 2001 From: atalman Date: Thu, 5 Sep 2024 21:05:06 +0000 Subject: [PATCH 0480/1018] Use actions/upload-artifact@v4.4.0 for rest of workflows (#135264) To be consistent with https://github.com/pytorch/pytorch/pull/135263 and rest of workflows. Use v4.4.0. Pull Request resolved: https://github.com/pytorch/pytorch/pull/135264 Approved by: https://github.com/kit1980, https://github.com/malfet --- .github/actions/upload-test-artifacts/action.yml | 6 +++--- .github/workflows/_ios-build-test.yml | 6 +++--- .github/workflows/_mac-build.yml | 4 ++-- .github/workflows/_rocm-test.yml | 2 +- .github/workflows/_xpu-test.yml | 2 +- .github/workflows/scorecards.yml | 2 +- .github/workflows/target_determination.yml | 2 +- 7 files changed, 12 insertions(+), 12 deletions(-) diff --git a/.github/actions/upload-test-artifacts/action.yml b/.github/actions/upload-test-artifacts/action.yml index 04cb43b20c389b..cbb52a74afb79a 100644 --- a/.github/actions/upload-test-artifacts/action.yml +++ b/.github/actions/upload-test-artifacts/action.yml @@ -147,7 +147,7 @@ runs: # GHA upload - name: Store Test Downloaded JSONs on Github - uses: actions/upload-artifact@v3 + uses: actions/upload-artifact@v4.4.0 if: inputs.use-gha continue-on-error: true with: @@ -158,7 +158,7 @@ runs: path: test/**/*.json - name: Store Test Reports on Github - uses: actions/upload-artifact@v3 + uses: actions/upload-artifact@v4.4.0 if: inputs.use-gha continue-on-error: true with: @@ -172,7 +172,7 @@ runs: test/**/*.csv - name: Store Usage Logs on Github - uses: actions/upload-artifact@v3 + uses: actions/upload-artifact@v4.4.0 if: inputs.use-gha continue-on-error: true with: diff --git a/.github/workflows/_ios-build-test.yml b/.github/workflows/_ios-build-test.yml index 95fe6bd1a3b50a..58f55b68540ad0 100644 --- a/.github/workflows/_ios-build-test.yml +++ b/.github/workflows/_ios-build-test.yml @@ -263,7 +263,7 @@ jobs: zip -r "${IOS_ARCH}.zip" install src version.txt LICENSE ./*.podspec popd - - uses: actions/upload-artifact@v3 + - uses: actions/upload-artifact@v4.4.0 if: matrix.ios_platform == 'OS' with: name: pytorch-ios-build-artifacts-${{ matrix.ios_arch }} @@ -401,13 +401,13 @@ jobs: echo "SPEC_NAME=${SPEC_NAME}" } >> "$GITHUB_ENV" - - uses: actions/upload-artifact@v3 + - uses: actions/upload-artifact@v4.4.0 with: name: pytorch-ios-artifacts if-no-files-found: error path: ${{ env.DEST_DIR }}/${{ env.ARTIFACT_NAME }} - - uses: actions/upload-artifact@v3 + - uses: actions/upload-artifact@v4.4.0 with: name: pytorch-ios-podspec if-no-files-found: error diff --git a/.github/workflows/_mac-build.yml b/.github/workflows/_mac-build.yml index 5bc6f06b0f6bcd..8b42037c104104 100644 --- a/.github/workflows/_mac-build.yml +++ b/.github/workflows/_mac-build.yml @@ -186,7 +186,7 @@ jobs: zip -1 -r artifacts.zip dist/ build/.ninja_log build/compile_commands.json .additional_ci_files - name: Store PyTorch Build Artifacts on GHA - uses: actions/upload-artifact@v3 + uses: actions/upload-artifact@v4.4.0 if: inputs.build-generates-artifacts && steps.build.outcome != 'skipped' with: name: ${{ env.BUILD_ENVIRONMENT }} @@ -195,7 +195,7 @@ jobs: path: artifacts.zip - name: Upload sccache stats to GHA - uses: actions/upload-artifact@v3 + uses: actions/upload-artifact@v4.4.0 # Only if sccache is installed, see above if: ${{ (github.event_name == 'push' || github.event.pull_request.head.repo.full_name == github.repository) && steps.build.outcome != 'skipped' }} with: diff --git a/.github/workflows/_rocm-test.yml b/.github/workflows/_rocm-test.yml index 790cd9d403dc25..b97715b1c050a8 100644 --- a/.github/workflows/_rocm-test.yml +++ b/.github/workflows/_rocm-test.yml @@ -269,7 +269,7 @@ jobs: find . -iname "core.[1-9]*" -exec docker exec "${CONTAINER_NAME}" sh -c "gdb python {} -ex 'bt' -ex 'q'" \; - name: Store Core dumps on GitHub - uses: actions/upload-artifact@v3 + uses: actions/upload-artifact@v4.4.0 if: failure() with: name: coredumps-${{ matrix.config }}-${{ matrix.shard }}-${{ matrix.num_shards }}-${{ matrix.runner }} diff --git a/.github/workflows/_xpu-test.yml b/.github/workflows/_xpu-test.yml index 036a2c8eeca85f..56c2f9fb041d4d 100644 --- a/.github/workflows/_xpu-test.yml +++ b/.github/workflows/_xpu-test.yml @@ -261,7 +261,7 @@ jobs: docker stop "${{ env.CONTAINER_NAME }}" - name: Store Core dumps on GitHub - uses: actions/upload-artifact@v3 + uses: actions/upload-artifact@v4.4.0 if: failure() with: name: coredumps-${{ matrix.config }}-${{ matrix.shard }}-${{ matrix.num_shards }}-${{ matrix.runner }} diff --git a/.github/workflows/scorecards.yml b/.github/workflows/scorecards.yml index cd4258c478763f..29fa6718802aa8 100644 --- a/.github/workflows/scorecards.yml +++ b/.github/workflows/scorecards.yml @@ -42,7 +42,7 @@ jobs: # Upload the results as artifacts (optional). Commenting out will disable uploads of run results in SARIF # format to the repository Actions tab. - name: "Upload artifact" - uses: actions/upload-artifact@v3 + uses: actions/upload-artifact@v4.4.0 with: name: SARIF file path: results.sarif diff --git a/.github/workflows/target_determination.yml b/.github/workflows/target_determination.yml index f7b2f383f314ea..dabd8f6b7e170b 100644 --- a/.github/workflows/target_determination.yml +++ b/.github/workflows/target_determination.yml @@ -83,7 +83,7 @@ jobs: path: td_results.json - name: Store TD results on GHA - uses: actions/upload-artifact@v3 + uses: actions/upload-artifact@v4.4.0 if: steps.td.outcome == 'success' with: name: td_results.json From 42ace0a0d1a23f06267854bff607596f83ca3834 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Thu, 5 Sep 2024 21:12:54 +0000 Subject: [PATCH 0481/1018] Revert "Remove Caffe2 code from tool scripts (#134941)" This reverts commit c818ecd1698a28d9fadf4a81453a89914b18374a. Reverted https://github.com/pytorch/pytorch/pull/134941 on behalf of https://github.com/kit1980 due to breaking internal builds - The path `caffe2/operators/hip/gather_op.cuh` does not exist ([comment](https://github.com/pytorch/pytorch/pull/134941#issuecomment-2332636624)) --- tools/amd_build/build_amd.py | 13 +++++++++++++ tools/build_pytorch_libs.py | 6 ++++++ 2 files changed, 19 insertions(+) diff --git a/tools/amd_build/build_amd.py b/tools/amd_build/build_amd.py index b06799702e6258..967f63ae2c8f15 100755 --- a/tools/amd_build/build_amd.py +++ b/tools/amd_build/build_amd.py @@ -64,10 +64,21 @@ out_dir = args.output_directory includes = [ + "caffe2/operators/*", + "caffe2/sgd/*", + "caffe2/image/*", + "caffe2/transforms/*", + "caffe2/video/*", + "caffe2/distributed/*", + "caffe2/queue/*", + "caffe2/contrib/aten/*", "binaries/*", "caffe2/**/*_test*", "caffe2/core/*", + "caffe2/db/*", "caffe2/utils/*", + "caffe2/contrib/gloo/*", + "caffe2/contrib/nccl/*", "c10/cuda/*", "c10/cuda/test/CMakeLists.txt", "modules/*", @@ -107,6 +118,8 @@ includes.append(abs_new_dir) ignores = [ + "caffe2/operators/depthwise_3x3_conv_op_cudnn.cu", + "caffe2/operators/pool_op_cudnn.cu", "*/hip/*", # These files are compatible with both cuda and hip "aten/src/ATen/core/*", diff --git a/tools/build_pytorch_libs.py b/tools/build_pytorch_libs.py index 3c5893be97e7e9..64e7132a691017 100644 --- a/tools/build_pytorch_libs.py +++ b/tools/build_pytorch_libs.py @@ -2,6 +2,7 @@ import os import platform +import shutil from glob import glob from setuptools import distutils # type: ignore[import] @@ -86,3 +87,8 @@ def build_caffe2( if cmake_only: return cmake.build(my_env) + if build_python: + caffe2_proto_dir = os.path.join(cmake.build_dir, "caffe2", "proto") + for proto_file in glob(os.path.join(caffe2_proto_dir, "*.py")): + if proto_file != os.path.join(caffe2_proto_dir, "__init__.py"): + shutil.copy(proto_file, os.path.join("caffe2", "proto")) From e87cc6ca62764670d6923c70733534434a88d201 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Thu, 5 Sep 2024 21:16:14 +0000 Subject: [PATCH 0482/1018] Revert "[Reland] Refactor caching device allocator utils (#130923)" This reverts commit 9809080b9ed657a8c0ea0383be7cbdce3a26e05e. Reverted https://github.com/pytorch/pytorch/pull/130923 on behalf of https://github.com/kit1980 due to breaking internal builds - Error: Relocation overflow has occured ([comment](https://github.com/pytorch/pytorch/pull/130923#issuecomment-2332640961)) --- c10/core/CachingDeviceAllocator.h | 131 ----------------- c10/cuda/CUDACachingAllocator.cpp | 163 ++++++++++++++------- c10/cuda/CUDACachingAllocator.h | 82 +++++++++-- c10/cuda/CUDAMallocAsyncAllocator.cpp | 2 - torch/csrc/cuda/CUDAPluggableAllocator.cpp | 4 +- torch/csrc/cuda/CUDAPluggableAllocator.h | 2 +- torch/csrc/cuda/Module.cpp | 8 +- 7 files changed, 190 insertions(+), 202 deletions(-) delete mode 100644 c10/core/CachingDeviceAllocator.h diff --git a/c10/core/CachingDeviceAllocator.h b/c10/core/CachingDeviceAllocator.h deleted file mode 100644 index 8724ecf88ae0f9..00000000000000 --- a/c10/core/CachingDeviceAllocator.h +++ /dev/null @@ -1,131 +0,0 @@ -#pragma once - -#include -#include - -#include - -namespace c10::CachingDeviceAllocator { - -struct Stat { - void increase(size_t amount) { - current += static_cast(amount); - peak = std::max(current, peak); - allocated += static_cast(amount); - } - - void decrease(size_t amount) { - current -= static_cast(amount); - TORCH_INTERNAL_ASSERT_DEBUG_ONLY( - current >= 0, - "Negative tracked stat in device allocator (likely logic error)."); - freed += static_cast(amount); - } - - void reset_accumulated() { - allocated = 0; - freed = 0; - } - - void reset_peak() { - peak = current; - } - - int64_t current = 0; - int64_t peak = 0; - int64_t allocated = 0; - int64_t freed = 0; -}; - -enum struct StatType : uint64_t { - AGGREGATE = 0, - SMALL_POOL = 1, - LARGE_POOL = 2, - NUM_TYPES = 3 // remember to update this whenever a new stat type is added -}; - -using StatArray = std::array(StatType::NUM_TYPES)>; -using StatTypes = std::array(StatType::NUM_TYPES)>; - -template -void for_each_selected_stat_type(const StatTypes& stat_types, Func f) { - for (const auto stat_type : c10::irange(stat_types.size())) { - if (stat_types[stat_type]) { - f(stat_type); - } - } -} - -// Struct containing memory allocator summary statistics for a device. -struct DeviceStats { - // COUNT: allocations requested by client code - StatArray allocation; - // COUNT: number of allocated segments from device memory allocation. - StatArray segment; - // COUNT: number of active memory blocks (allocated or used by stream) - StatArray active; - // COUNT: number of inactive, split memory blocks (unallocated but can't be - // released via device memory deallocation) - StatArray inactive_split; - - // SUM: bytes allocated by this memory alocator - StatArray allocated_bytes; - // SUM: bytes reserved by this memory allocator (both free and used) - StatArray reserved_bytes; - // SUM: bytes within active memory blocks - StatArray active_bytes; - // SUM: bytes within inactive, split memory blocks - StatArray inactive_split_bytes; - // SUM: bytes requested by client code - StatArray requested_bytes; - - // COUNT: total number of failed calls to device malloc necessitating cache - // flushes. - int64_t num_alloc_retries = 0; - - // COUNT: total number of OOMs (i.e. failed calls to device memory allocation - // after cache flush) - int64_t num_ooms = 0; - - // COUNT: total number of oversize blocks allocated from pool - Stat oversize_allocations; - - // COUNT: total number of oversize blocks requiring malloc - Stat oversize_segments; - - // COUNT: total number of synchronize_and_free_events() calls - int64_t num_sync_all_streams = 0; - - // COUNT: total number of device memory allocation calls. This includes both - // mapped and malloced memory. - int64_t num_device_alloc = 0; - - // COUNT: total number of device memory deallocation calls. This includes both - // un-mapped and free memory. - int64_t num_device_free = 0; - - // SIZE: maximum block size that is allowed to be split. - int64_t max_split_size = 0; -}; - -// Size pretty-printer -inline std::string format_size(uint64_t size) { - std::ostringstream os; - os.precision(2); - os << std::fixed; - if (size <= 1024) { - os << size << " bytes"; - } else if (size <= 1048576) { - os << (static_cast(size) / 1024.0); - os << " KiB"; - } else if (size <= 1073741824ULL) { - os << static_cast(size) / 1048576.0; - os << " MiB"; - } else { - os << static_cast(size) / 1073741824.0; - os << " GiB"; - } - return os.str(); -} - -} // namespace c10::CachingDeviceAllocator diff --git a/c10/cuda/CUDACachingAllocator.cpp b/c10/cuda/CUDACachingAllocator.cpp index 470264321fffd6..fa26a9f3e0cca7 100644 --- a/c10/cuda/CUDACachingAllocator.cpp +++ b/c10/cuda/CUDACachingAllocator.cpp @@ -10,6 +10,7 @@ #include #include #include +#include #include #include @@ -26,6 +27,7 @@ #include #include #include +#include #include #include #include @@ -42,8 +44,6 @@ C10_DEFINE_REGISTRY(FreeCudaMemoryCallbacksRegistry, FreeMemoryCallback); namespace cuda::CUDACachingAllocator { -using namespace c10::CachingDeviceAllocator; - // Included here as this is externally used in CUDAAllocatorConfig const size_t kLargeBuffer = 20971520; // "large" allocations may be packed in 20 MiB blocks @@ -134,13 +134,47 @@ namespace { using stream_set = ska::flat_hash_set; +using StatTypes = std::array(StatType::NUM_TYPES)>; + +void increase_stat(Stat& stat, size_t amount) { + stat.current += static_cast(amount); + stat.peak = std::max(stat.current, stat.peak); + stat.allocated += static_cast(amount); +} + +void decrease_stat(Stat& stat, size_t amount) { + stat.current -= static_cast(amount); + TORCH_INTERNAL_ASSERT_DEBUG_ONLY( + stat.current >= 0, + "Negative tracked stat in CUDA allocator (likely logic error)."); + stat.freed += static_cast(amount); +} + +void reset_accumulated_stat(Stat& stat) { + stat.allocated = 0; + stat.freed = 0; +} + +void reset_peak_stat(Stat& stat) { + stat.peak = stat.current; +} + +template +void for_each_selected_stat_type(const StatTypes& stat_types, Func f) { + for (const auto stat_type : c10::irange(stat_types.size())) { + if (stat_types[stat_type]) { + f(stat_type); + } + } +} + void decrease_stat_array( StatArray& stat_array, size_t amount, const StatTypes& stat_types) { for_each_selected_stat_type( stat_types, [&stat_array, amount](size_t stat_type) { - stat_array[stat_type].decrease(amount); + decrease_stat(stat_array[stat_type], amount); }); } @@ -1390,16 +1424,16 @@ class DeviceCachingAllocator { // A new split inactive block is being created from a previously unsplit // block, size remaining->size bytes. for_each_selected_stat_type(params.stat_types, [&](size_t stat_type) { - stats.inactive_split_bytes[stat_type].increase(remaining->size); - stats.inactive_split[stat_type].increase(1); + increase_stat(stats.inactive_split_bytes[stat_type], remaining->size); + increase_stat(stats.inactive_split[stat_type], 1); }); } } else if (already_split && !block->expandable_segment_) { // An already-split block is becoming active for_each_selected_stat_type(params.stat_types, [&](size_t stat_type) { - stats.inactive_split_bytes[stat_type].decrease(block->size); - stats.inactive_split[stat_type].decrease(1); + decrease_stat(stats.inactive_split_bytes[stat_type], block->size); + decrease_stat(stats.inactive_split[stat_type], 1); }); } @@ -1420,14 +1454,14 @@ class DeviceCachingAllocator { TORCH_INTERNAL_ASSERT_DEBUG_ONLY(inserted); for_each_selected_stat_type(params.stat_types, [&](size_t stat_type) { - stats.allocation[stat_type].increase(1); - stats.allocated_bytes[stat_type].increase(block->size); - stats.active[stat_type].increase(1); - stats.active_bytes[stat_type].increase(block->size); - stats.requested_bytes[stat_type].increase(block->requested_size); + increase_stat(stats.allocation[stat_type], 1); + increase_stat(stats.allocated_bytes[stat_type], block->size); + increase_stat(stats.active[stat_type], 1); + increase_stat(stats.active_bytes[stat_type], block->size); + increase_stat(stats.requested_bytes[stat_type], block->requested_size); }); if (block->size >= CUDAAllocatorConfig::max_split_size()) - stats.oversize_allocations.increase(1); + increase_stat(stats.oversize_allocations, 1); c10::reportMemoryUsageToProfiler( block->ptr, @@ -1453,8 +1487,8 @@ class DeviceCachingAllocator { StatTypes stat_types = get_stat_types_for_pool(*block->pool); for_each_selected_stat_type(stat_types, [&](size_t stat_type) { - stats.allocation[stat_type].decrease(1); - stats.allocated_bytes[stat_type].decrease(block->size); + decrease_stat(stats.allocation[stat_type], 1); + decrease_stat(stats.allocated_bytes[stat_type], block->size); }); record_trace( @@ -1466,7 +1500,7 @@ class DeviceCachingAllocator { context ? context : block->context_when_allocated); if (block->size >= CUDAAllocatorConfig::max_split_size()) - stats.oversize_allocations.decrease(1); + decrease_stat(stats.oversize_allocations, 1); if (!block->stream_uses.empty()) { if (C10_UNLIKELY(!captures_underway.empty())) { @@ -1594,15 +1628,15 @@ class DeviceCachingAllocator { for (const auto statType : c10::irange(static_cast(StatType::NUM_TYPES))) { - stats.allocation[statType].reset_accumulated(); - stats.segment[statType].reset_accumulated(); - stats.active[statType].reset_accumulated(); - stats.inactive_split[statType].reset_accumulated(); - stats.allocated_bytes[statType].reset_accumulated(); - stats.reserved_bytes[statType].reset_accumulated(); - stats.active_bytes[statType].reset_accumulated(); - stats.inactive_split_bytes[statType].reset_accumulated(); - stats.requested_bytes[statType].reset_accumulated(); + reset_accumulated_stat(stats.allocation[statType]); + reset_accumulated_stat(stats.segment[statType]); + reset_accumulated_stat(stats.active[statType]); + reset_accumulated_stat(stats.inactive_split[statType]); + reset_accumulated_stat(stats.allocated_bytes[statType]); + reset_accumulated_stat(stats.reserved_bytes[statType]); + reset_accumulated_stat(stats.active_bytes[statType]); + reset_accumulated_stat(stats.inactive_split_bytes[statType]); + reset_accumulated_stat(stats.requested_bytes[statType]); } stats.num_alloc_retries = 0; @@ -1610,8 +1644,8 @@ class DeviceCachingAllocator { stats.num_sync_all_streams = 0; stats.num_device_alloc = 0; stats.num_device_free = 0; - stats.oversize_allocations.reset_accumulated(); - stats.oversize_segments.reset_accumulated(); + reset_accumulated_stat(stats.oversize_allocations); + reset_accumulated_stat(stats.oversize_segments); } /** Resets the historical peak stats for the device **/ @@ -1620,18 +1654,18 @@ class DeviceCachingAllocator { for (const auto statType : c10::irange(static_cast(StatType::NUM_TYPES))) { - stats.allocation[statType].reset_peak(); - stats.segment[statType].reset_peak(); - stats.active[statType].reset_peak(); - stats.inactive_split[statType].reset_peak(); - stats.allocated_bytes[statType].reset_peak(); - stats.reserved_bytes[statType].reset_peak(); - stats.active_bytes[statType].reset_peak(); - stats.inactive_split_bytes[statType].reset_peak(); - stats.requested_bytes[statType].reset_peak(); + reset_peak_stat(stats.allocation[statType]); + reset_peak_stat(stats.segment[statType]); + reset_peak_stat(stats.active[statType]); + reset_peak_stat(stats.inactive_split[statType]); + reset_peak_stat(stats.allocated_bytes[statType]); + reset_peak_stat(stats.reserved_bytes[statType]); + reset_peak_stat(stats.active_bytes[statType]); + reset_peak_stat(stats.inactive_split_bytes[statType]); + reset_peak_stat(stats.requested_bytes[statType]); } - stats.oversize_allocations.reset_peak(); - stats.oversize_segments.reset_peak(); + reset_peak_stat(stats.oversize_allocations); + reset_peak_stat(stats.oversize_segments); } /* Checkpoint the state of a private pool necessary to return it to its @@ -2243,7 +2277,7 @@ class DeviceCachingAllocator { total_allocated_memory += mapped_range.size; StatTypes stat_types = get_stat_types_for_pool(*to_map->pool); for_each_selected_stat_type(stat_types, [&](size_t stat_type) { - stats.reserved_bytes[stat_type].increase(mapped_range.size); + increase_stat(stats.reserved_bytes[stat_type], mapped_range.size); }); stats.num_device_alloc++; @@ -2350,23 +2384,27 @@ class DeviceCachingAllocator { // inactive_split if (!block->expandable_segment_) { if (net_change_inactive_split_blocks > 0) { - stats.inactive_split[stat_type].increase( + increase_stat( + stats.inactive_split[stat_type], static_cast(net_change_inactive_split_blocks)); } else if (net_change_inactive_split_blocks < 0) { - stats.inactive_split[stat_type].decrease( + decrease_stat( + stats.inactive_split[stat_type], static_cast(-net_change_inactive_split_blocks)); } if (net_change_inactive_split_size > 0) { - stats.inactive_split_bytes[stat_type].increase( + increase_stat( + stats.inactive_split_bytes[stat_type], static_cast(net_change_inactive_split_size)); } else if (net_change_inactive_split_size < 0) { - stats.inactive_split_bytes[stat_type].decrease( + decrease_stat( + stats.inactive_split_bytes[stat_type], static_cast(-net_change_inactive_split_size)); } } - stats.active[stat_type].decrease(1); - stats.active_bytes[stat_type].decrease(original_block_size); - stats.requested_bytes[stat_type].decrease(requested_size); + decrease_stat(stats.active[stat_type], 1); + decrease_stat(stats.active_bytes[stat_type], original_block_size); + decrease_stat(stats.requested_bytes[stat_type], requested_size); }); } @@ -2678,11 +2716,11 @@ class DeviceCachingAllocator { total_allocated_memory += size; p.block = new Block(p.device(), p.stream(), size, p.pool, (char*)ptr); for_each_selected_stat_type(p.stat_types, [&](size_t stat_type) { - stats.segment[stat_type].increase(1); - stats.reserved_bytes[stat_type].increase(size); + increase_stat(stats.segment[stat_type], 1); + increase_stat(stats.reserved_bytes[stat_type], size); }); if (size >= CUDAAllocatorConfig::max_split_size()) - stats.oversize_segments.increase(1); + increase_stat(stats.oversize_segments, 1); // p.block came from new, not cudaMalloc. It should not be nullptr here. TORCH_INTERNAL_ASSERT(p.block != nullptr && p.block->ptr != nullptr); @@ -2817,12 +2855,12 @@ class DeviceCachingAllocator { StatTypes stat_types = get_stat_types_for_pool(*pool); for_each_selected_stat_type(stat_types, [&](size_t stat_type) { - stats.segment[stat_type].decrease(1); - stats.reserved_bytes[stat_type].decrease(block->size); + decrease_stat(stats.segment[stat_type], 1); + decrease_stat(stats.reserved_bytes[stat_type], block->size); }); if (block->size >= CUDAAllocatorConfig::max_split_size()) - stats.oversize_segments.decrease(1); + decrease_stat(stats.oversize_segments, 1); pool->blocks.erase(block); delete block; } @@ -2874,7 +2912,7 @@ class DeviceCachingAllocator { total_allocated_memory -= unmapped.size; StatTypes stat_types = get_stat_types_for_pool(*block->pool); for_each_selected_stat_type(stat_types, [&](size_t stat_type) { - stats.reserved_bytes[stat_type].decrease(unmapped.size); + decrease_stat(stats.reserved_bytes[stat_type], unmapped.size); }); if (block->pool->owner_PrivatePool) { @@ -3703,6 +3741,25 @@ void local_raw_delete(void* ptr) { } } // namespace Native +// Size pretty-printer +std::string format_size(uint64_t size) { + std::ostringstream os; + os.precision(2); + os << std::fixed; + if (size <= 1024) { + os << size << " bytes"; + } else if (size <= 1048576) { + os << (static_cast(size) / 1024.0); + os << " KiB"; + } else if (size <= 1073741824ULL) { + os << static_cast(size) / 1048576.0; + os << " MiB"; + } else { + os << static_cast(size) / 1073741824.0; + os << " GiB"; + } + return os.str(); +} namespace CudaMallocAsync { // If this is put in its own header file, it gets incorrectly renamed in HIPify. diff --git a/c10/cuda/CUDACachingAllocator.h b/c10/cuda/CUDACachingAllocator.h index 70385654201f50..72617bcaf3a944 100644 --- a/c10/cuda/CUDACachingAllocator.h +++ b/c10/cuda/CUDACachingAllocator.h @@ -1,6 +1,6 @@ #pragma once -#include +#include #include #include #include @@ -48,12 +48,75 @@ C10_DECLARE_REGISTRY(FreeCudaMemoryCallbacksRegistry, FreeMemoryCallback); namespace c10::cuda::CUDACachingAllocator { -// Preserved only for BC reasons -// NOLINTNEXTLINE(misc-unused-using-decls) -using c10::CachingDeviceAllocator::DeviceStats; - extern const size_t kLargeBuffer; +struct Stat { + int64_t current = 0; + int64_t peak = 0; + int64_t allocated = 0; + int64_t freed = 0; +}; + +enum struct StatType : uint64_t { + AGGREGATE = 0, + SMALL_POOL = 1, + LARGE_POOL = 2, + NUM_TYPES = 3 // remember to update this whenever a new stat type is added +}; + +typedef std::array(StatType::NUM_TYPES)> StatArray; + +// Struct containing memory allocator summary statistics for a device. +struct DeviceStats { + // COUNT: allocations requested by client code + StatArray allocation; + // COUNT: number of allocated segments from cudaMalloc(). + StatArray segment; + // COUNT: number of active memory blocks (allocated or used by stream) + StatArray active; + // COUNT: number of inactive, split memory blocks (unallocated but can't be + // released via cudaFree) + StatArray inactive_split; + + // SUM: bytes allocated by this memory alocator + StatArray allocated_bytes; + // SUM: bytes reserved by this memory allocator (both free and used) + StatArray reserved_bytes; + // SUM: bytes within active memory blocks + StatArray active_bytes; + // SUM: bytes within inactive, split memory blocks + StatArray inactive_split_bytes; + // SUM: bytes requested by client code + StatArray requested_bytes; + + // COUNT: total number of failed calls to CUDA malloc necessitating cache + // flushes. + int64_t num_alloc_retries = 0; + + // COUNT: total number of OOMs (i.e. failed calls to CUDA after cache flush) + int64_t num_ooms = 0; + + // COUNT: total number of oversize blocks allocated from pool + Stat oversize_allocations; + + // COUNT: total number of oversize blocks requiring malloc + Stat oversize_segments; + + // COUNT: total number of synchronize_and_free_events() calls + int64_t num_sync_all_streams = 0; + + // COUNT: total number of CUDA allocation calls. This includes both cuMemMap + // and cudaMalloc. + int64_t num_device_alloc = 0; + + // COUNT: total number of CUDA free calls. This includes both cuMemUnmap + // and cudaFree. + int64_t num_device_free = 0; + + // SIZE: maximum block size that is allowed to be split. + int64_t max_split_size = 0; +}; + typedef std::shared_ptr (*CreateContextFn)(); // Struct containing info of an allocation block (i.e. a fractional part of a @@ -184,6 +247,9 @@ enum struct RecordContext { ALL = 3, // additionally record stacks for when something is freed }; +// Size pretty-printer +std::string format_size(uint64_t size); + using OutOfMemoryObserver = std::functionrecordStream(dataPtr, stream); } -inline c10::CachingDeviceAllocator::DeviceStats getDeviceStats( - c10::DeviceIndex device) { +inline DeviceStats getDeviceStats(c10::DeviceIndex device) { return get()->getDeviceStats(device); } diff --git a/c10/cuda/CUDAMallocAsyncAllocator.cpp b/c10/cuda/CUDAMallocAsyncAllocator.cpp index 3a7b485ebce223..3fe414b55e30a2 100644 --- a/c10/cuda/CUDAMallocAsyncAllocator.cpp +++ b/c10/cuda/CUDAMallocAsyncAllocator.cpp @@ -11,8 +11,6 @@ namespace c10::cuda::CUDACachingAllocator::CudaMallocAsync { -using namespace c10::CachingDeviceAllocator; - #if CUDA_VERSION >= 11040 // CUDA device allocator that uses cudaMallocAsync to implement // the same interface as CUDACachingAllocator.cpp. diff --git a/torch/csrc/cuda/CUDAPluggableAllocator.cpp b/torch/csrc/cuda/CUDAPluggableAllocator.cpp index 5220e86233bd67..c6af163481481e 100644 --- a/torch/csrc/cuda/CUDAPluggableAllocator.cpp +++ b/torch/csrc/cuda/CUDAPluggableAllocator.cpp @@ -210,8 +210,8 @@ void CUDAPluggableAllocator::recordStream( } } -c10::CachingDeviceAllocator::DeviceStats CUDAPluggableAllocator::getDeviceStats( - c10::DeviceIndex device) { +c10::cuda::CUDACachingAllocator::DeviceStats CUDAPluggableAllocator:: + getDeviceStats(c10::DeviceIndex device) { TORCH_CHECK( false, "CUDAPluggableAllocator does not yet support getDeviceStats. " diff --git a/torch/csrc/cuda/CUDAPluggableAllocator.h b/torch/csrc/cuda/CUDAPluggableAllocator.h index 8652ef0f2bfde8..70014b2fa0b37b 100644 --- a/torch/csrc/cuda/CUDAPluggableAllocator.h +++ b/torch/csrc/cuda/CUDAPluggableAllocator.h @@ -115,7 +115,7 @@ struct TORCH_CUDA_CPP_API CUDAPluggableAllocator void recordStream(const c10::DataPtr&, streamType stream) override; - c10::CachingDeviceAllocator::DeviceStats getDeviceStats( + c10::cuda::CUDACachingAllocator::DeviceStats getDeviceStats( c10::DeviceIndex device) override; void resetAccumulatedStats(c10::DeviceIndex device) override; void resetPeakStats(c10::DeviceIndex device) override; diff --git a/torch/csrc/cuda/Module.cpp b/torch/csrc/cuda/Module.cpp index 22cbadc42fc09f..134758b2c06814 100644 --- a/torch/csrc/cuda/Module.cpp +++ b/torch/csrc/cuda/Module.cpp @@ -565,10 +565,10 @@ PyObject* THCPModule_memoryStats(PyObject* _unused, PyObject* arg) { TORCH_CHECK(THPUtils_checkLong(arg), "invalid argument to memory_allocated"); const auto device_index = THPUtils_unpackDeviceIndex(arg); - using c10::CachingDeviceAllocator::DeviceStats; - using c10::CachingDeviceAllocator::Stat; - using c10::CachingDeviceAllocator::StatArray; - using c10::CachingDeviceAllocator::StatType; + using c10::cuda::CUDACachingAllocator::DeviceStats; + using c10::cuda::CUDACachingAllocator::Stat; + using c10::cuda::CUDACachingAllocator::StatArray; + using c10::cuda::CUDACachingAllocator::StatType; const auto statToDict = [](const Stat& stat) { py::dict dict; From 8625b6257c90a3fba026a79c992a35e9ab2ff9db Mon Sep 17 00:00:00 2001 From: Shangdi Yu Date: Thu, 5 Sep 2024 21:19:28 +0000 Subject: [PATCH 0483/1018] [training ir migration] Fix quantization tests (#135184) Summary: Fixed some quantization tests for new training ir: Fix batch norm node pattern matcher. In training ir, we have `aten.batch_norm` node instead of `aten._native_batch_norm_legit` and `aten._native_batch_norm_legit_no_training`. Test Plan: ``` buck run fbcode//mode/dev-nosan fbcode//caffe2/test:quantization_pt2e ``` Reviewed By: tugsbayasgalan Differential Revision: D62209819 Pull Request resolved: https://github.com/pytorch/pytorch/pull/135184 Approved by: https://github.com/tugsbayasgalan --- test/quantization/pt2e/test_quantize_pt2e.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/test/quantization/pt2e/test_quantize_pt2e.py b/test/quantization/pt2e/test_quantize_pt2e.py index 767eb00b358659..eef06f19cc8ed2 100644 --- a/test/quantization/pt2e/test_quantize_pt2e.py +++ b/test/quantization/pt2e/test_quantize_pt2e.py @@ -4,6 +4,7 @@ import torch from torch import Tensor from torch._export import capture_pre_autograd_graph +from torch._utils_internal import capture_pre_autograd_graph_using_training_ir from torch.ao.quantization import observer, ObserverOrFakeQuantize, QConfigMapping from torch.ao.quantization.qconfig import ( default_per_channel_symmetric_qnnpack_qconfig, @@ -1309,7 +1310,7 @@ def validate(self, model: torch.fx.GraphModule) -> None: m = M().eval() example_inputs = torch.randn(1, 2, 3, 3) - m = capture_pre_autograd_graph(m, example_inputs) + m = capture_pre_autograd_graph(m, (example_inputs,)) with self.assertRaises(Exception): m = prepare_pt2e(m, BackendAQuantizer()) @@ -1863,6 +1864,14 @@ def test_move_exported_model_dropout_inplace(self): self._test_move_exported_model_dropout(inplace=True) def _get_bn_train_eval_ops(self): + if capture_pre_autograd_graph_using_training_ir(): + return ( + torch.ops.aten.batch_norm.default, + torch.ops.aten.batch_norm.default, + ) + # TODO: This branch is going through a deprecated branch and should be deleted soon, + # after capture_pre_autograd_graph fully migrate to training IR + # T199018392 if TEST_WITH_ROCM: return ( torch.ops.aten.miopen_batch_norm.default, From 429de44f4b84e1bc2f89dd7a8661a5f13917d5e9 Mon Sep 17 00:00:00 2001 From: Will Feng Date: Wed, 4 Sep 2024 22:10:48 -0700 Subject: [PATCH 0484/1018] [Traceable FSDP2][Dynamo] allow tracing through auto_functionalized HOP (#135169) If an `auto_functionalized` HOP is included in backward graph due to activation checkpointing, we will run into a scenario where Compiled Autograd Dynamo tracing will need to trace through the `auto_functionalized` HOP. This PR adds support for it. Test commands: - `pytest -rA test/inductor/test_compiled_autograd.py::TestCompiledAutograd::test_trace_auto_functionalized` Pull Request resolved: https://github.com/pytorch/pytorch/pull/135169 Approved by: https://github.com/zou3519 --- test/inductor/test_compiled_autograd.py | 134 +++++++++++++++++++- torch/_dynamo/variables/higher_order_ops.py | 22 ++++ 2 files changed, 154 insertions(+), 2 deletions(-) diff --git a/test/inductor/test_compiled_autograd.py b/test/inductor/test_compiled_autograd.py index 239698d8e29f29..ad83833ea3386e 100644 --- a/test/inductor/test_compiled_autograd.py +++ b/test/inductor/test_compiled_autograd.py @@ -19,6 +19,7 @@ import torch.nn.functional as F from torch import _inductor as inductor from torch._dynamo import compiled_autograd, config +from torch._dynamo.backends.debugging import aot_eager from torch._dynamo.utils import counters from torch._inductor import config as inductor_config from torch._inductor.test_case import run_tests, TestCase @@ -30,13 +31,18 @@ # note: these tests are not run on windows due to inductor_utils.HAS_CPU -def make_compiler_fn(fullgraph=True, dynamic=True): +def make_compiler_fn(fullgraph=True, dynamic=True, backend="inductor"): + assert backend in ["inductor", "aot_eager"] + def _compiler_fn(gm): """Same as torch.compile() but counts number of compiles""" def _inner_compiler(gm_, example_inputs_): counters["compiled_autograd"]["compiles"] += 1 - return inductor.compile(gm_, example_inputs_) + if backend == "inductor": + return inductor.compile(gm_, example_inputs_) + elif backend == "aot_eager": + return aot_eager(gm_, example_inputs_) return torch.compile( gm, backend=_inner_compiler, fullgraph=fullgraph, dynamic=dynamic @@ -1431,6 +1437,130 @@ def _compiler_fn(gm): f, compiler_fn=compiler_fn_with_op_check, compile_fn=False ) + def test_trace_auto_functionalized(self): + torch.library.define( + "testlib::foo", + "(Tensor(a!) x) -> (Tensor)", + tags=torch.Tag.pt2_compliant_tag, + ) + torch.library.define( + "testlib::foo_mutated", + "(Tensor(a!) x) -> (Tensor)", + tags=torch.Tag.pt2_compliant_tag, + ) + + @torch.library.impl("testlib::foo", "cpu") + def foo(x): + x.add_(5) + return x + + @torch.library.impl("testlib::foo", "Meta") + def foo_meta(x): + return x + + @torch.library.impl("testlib::foo_mutated", "CompositeImplicitAutograd") + def foo_mutated(x): + return torch.ops.testlib.foo(x) + + def _get_custom_policy(must_recompute_list=None): + def _custom_policy(ctx, func, *args, **kwargs): + if must_recompute_list is not None and func in must_recompute_list: + return torch.utils.checkpoint.CheckpointPolicy.MUST_RECOMPUTE + else: + return torch.utils.checkpoint.CheckpointPolicy.PREFER_RECOMPUTE + + return _custom_policy + + def context_fn(): + must_recompute_list = [ + torch.ops.higher_order.auto_functionalized, + ] + return torch.utils.checkpoint.create_selective_checkpoint_contexts( + _get_custom_policy( + must_recompute_list=must_recompute_list, + ), + ) + + def g(x): + x = torch.matmul(x, x) + torch.ops.testlib.foo_mutated(x) + return torch.matmul(x, x) + + def g_cp(x): + return torch.utils.checkpoint.checkpoint( + g, x, use_reentrant=False, context_fn=context_fn + ) + + def f(): + inps = (torch.randn(4, 4, requires_grad=True),) + output = torch.compile(g_cp, backend="aot_eager", fullgraph=True)(*inps) + output.sum().backward() + return output, inps[0].grad + + """ + Walkthrough of what happens with `auto_functionalized`: + 1. `auto_functionalized` op is inserted into the graph during AOTAutograd functionalization. + We force the op to be recomputed (by using SAC), so it appears in the backward graph. + 2. The AOT backward graph looks like: + ``` + ===== Backward graph 0 ===== + def forward(self, primals_1: "f32[4, 4][4, 1]cpu", tangents_1: "f32[4, 4][4, 1]cpu"): + ... + X = torch.ops.higher_order.auto_functionalized(torch.ops.testlib.foo.default, x = mm) + ... + return (add_1,) + ``` + 3. The Compiled Autograd graph looks like: + ``` + ===== Compiled autograd graph ===== + def forward(self, inputs, sizes, scalars, hooks): + ... + X = torch.ops.higher_order.auto_functionalized(torch.ops.testlib.foo.default, x = aot0_mm) + ... + return [] + ``` + 4. The Dynamo graph captured by Compiled Autograd looks like: + ``` + ===== __compiled_fn_3 ===== + def forward(self, L_inputs_ : list): + ... + X = torch.ops.higher_order.auto_functionalized(torch.ops.testlib.foo.default, x = aot0_mm) + ... + return (new_grad,) + ``` + 5. The Compiled Autograd's AOT "forward-only" graph looks like: + ``` + ===== Forward graph 1 ===== + def forward(self, arg0_1: "f32[][]cpu", arg1_1: "f32[4, 4][4, 1]cpu"): + ... + X = torch.ops.higher_order.auto_functionalized(torch.ops.testlib.foo.default, x = mm) + ... + return (clone_1,) + ``` + 6. The `auto_functionalized` op should then be lowered using the normal lowering path in Inductor. + """ + + compiler_fn = make_compiler_fn(fullgraph=True, backend="aot_eager") + + def make_compiler_fn_with_op_check(): + def _compiler_fn(gm): + # Checks that `auto_functionalized` op exists in Compiled Autograd's Dynamo graph. + self.assertTrue( + any( + node.target is torch.ops.higher_order.auto_functionalized + for node in gm.graph.nodes + ), + f"`torch.ops.higher_order.auto_functionalized` op not found in {gm.graph}", + ) + return compiler_fn(gm) + + return _compiler_fn + + compiler_fn_with_op_check = make_compiler_fn_with_op_check() + self.check_output_and_recompiles( + f, compiler_fn=compiler_fn_with_op_check, compile_fn=False + ) + def test_non_traceable_autograd_cpp_node(self): cpp_source = """ struct CustomOpAutogradFunction : public torch::autograd::Function { diff --git a/torch/_dynamo/variables/higher_order_ops.py b/torch/_dynamo/variables/higher_order_ops.py index 35d78fd304dae8..fa0c606296e8d3 100644 --- a/torch/_dynamo/variables/higher_order_ops.py +++ b/torch/_dynamo/variables/higher_order_ops.py @@ -606,6 +606,8 @@ def make(value, source=None, **kwargs): return CallTorchbindHigherOrderVariable(value, source, **kwargs) elif value.__name__ == "wrap_with_set_grad_enabled": return WrapWithSetGradEnabledHigherOrderVariable(value, source, **kwargs) + elif value.__name__ == "auto_functionalized": + return AutoFunctionalizeHigherOrderVariable(value, source, **kwargs) else: unimplemented(f"HigherOrderOperator {value.__name__}") @@ -1738,6 +1740,26 @@ def call_function( ) +class AutoFunctionalizeHigherOrderVariable(TorchHigherOrderOperatorVariable): + def call_function( + self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]" + ) -> "VariableTracker": + from .builder import wrap_fx_proxy + + p_args = tuple(arg.as_proxy() for arg in args) + p_kwargs = {key: arg.as_proxy() for key, arg in kwargs.items()} + return wrap_fx_proxy( + tx=tx, + proxy=tx.output.create_proxy( + "call_function", + self.value, + args=p_args, + kwargs=p_kwargs, + ), + example_value=None, + ) + + class TraceWrappedHigherOrderOperatorVariable(TorchHigherOrderOperatorVariable): """ Handles torch._dynamo._trace_wrapped_higher_order_op.inner_trace From 256131b516f4f9fa50439e4e03d857807f769f96 Mon Sep 17 00:00:00 2001 From: Huy Do Date: Thu, 5 Sep 2024 21:31:36 +0000 Subject: [PATCH 0485/1018] Run inductor micro benchmark on x86 metal runner (#135042) This enables inductor micro benchmark on CPU (x86): * Running on AWS metal runner for more accurate benchmark * I add a new `arch` column, which will be either x86_64 or arm64 for CPU or GPU name for GPU. We can use this later to differentiate between different setup, i.e. cuda (a100) vs cuda (a10g) or cpu (x86_64) vs cpu (arm64) The next step would be to run this one cpu arm64, and cuda (a10g). ### Testing Here is the CSV results from my test run https://github.com/pytorch/pytorch/actions/runs/10709344180 ``` name,metric,target,actual,dtype,device,arch,is_model mlp_layer_norm_gelu,flops_utilization,0.8,17.36,bfloat16,cpu,x86_64,False gather_gemv,memory_bandwidth(GB/s),990,170.80,int8,cpu,x86_64,False gather_gemv,memory_bandwidth(GB/s),1060,204.78,bfloat16,cpu,x86_64,False Mixtral-8x7B-v0.1,token_per_sec,175,26.68,int8,cpu,x86_64,True Mixtral-8x7B-v0.1,memory_bandwidth(GB/s),1130,171.91,int8,cpu,x86_64,True Mixtral-8x7B-v0.1,compilation_time(s),162,47.36,int8,cpu,x86_64,True gemv,memory_bandwidth(GB/s),870,236.36,int8,cpu,x86_64,False gemv,memory_bandwidth(GB/s),990,305.71,bfloat16,cpu,x86_64,False Llama-2-7b-chat-hf,token_per_sec,94,14.01,bfloat16,cpu,x86_64,True Llama-2-7b-chat-hf,memory_bandwidth(GB/s),1253,185.18,bfloat16,cpu,x86_64,True Llama-2-7b-chat-hf,compilation_time(s),162,74.99,bfloat16,cpu,x86_64,True Llama-2-7b-chat-hf,token_per_sec,144,25.09,int8,cpu,x86_64,True Llama-2-7b-chat-hf,memory_bandwidth(GB/s),957,165.83,int8,cpu,x86_64,True Llama-2-7b-chat-hf,compilation_time(s),172,70.69,int8,cpu,x86_64,True layer_norm,memory_bandwidth(GB/s),950,172.03,bfloat16,cpu,x86_64,False ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/135042 Approved by: https://github.com/yanboliang --- .ci/pytorch/test.sh | 3 ++ .github/pytorch-probot.yml | 1 + .../inductor-micro-benchmark-x86.yml | 40 ++++++++++++++ .github/workflows/upload-test-stats.yml | 2 +- benchmarks/gpt_fast/benchmark.py | 52 +++++++++++++++---- benchmarks/gpt_fast/generate.py | 32 ++++++++++-- 6 files changed, 116 insertions(+), 14 deletions(-) create mode 100644 .github/workflows/inductor-micro-benchmark-x86.yml diff --git a/.ci/pytorch/test.sh b/.ci/pytorch/test.sh index 53970cf346658c..9d2f2246ad7509 100755 --- a/.ci/pytorch/test.sh +++ b/.ci/pytorch/test.sh @@ -596,6 +596,9 @@ test_single_dynamo_benchmark() { test_inductor_micro_benchmark() { TEST_REPORTS_DIR=$(pwd)/test/test-reports + if [[ "${TEST_CONFIG}" == *cpu* ]]; then + test_inductor_set_cpu_affinity + fi python benchmarks/gpt_fast/benchmark.py --output "${TEST_REPORTS_DIR}/gpt_fast_benchmark.csv" } diff --git a/.github/pytorch-probot.yml b/.github/pytorch-probot.yml index 991c139b89001f..e43c39396c5f0e 100644 --- a/.github/pytorch-probot.yml +++ b/.github/pytorch-probot.yml @@ -9,6 +9,7 @@ ciflow_push_tags: - ciflow/inductor-rocm - ciflow/inductor-perf-compare - ciflow/inductor-micro-benchmark +- ciflow/inductor-micro-benchmark-cpu-x86 - ciflow/inductor-cu124 - ciflow/linux-aarch64 - ciflow/mps diff --git a/.github/workflows/inductor-micro-benchmark-x86.yml b/.github/workflows/inductor-micro-benchmark-x86.yml new file mode 100644 index 00000000000000..d31dbc5951ea12 --- /dev/null +++ b/.github/workflows/inductor-micro-benchmark-x86.yml @@ -0,0 +1,40 @@ +name: inductor-micro-benchmark-x86 + +on: + schedule: + - cron: 0 7 * * * + push: + tags: + - ciflow/inductor-micro-benchmark-cpu-x86/* + workflow_dispatch: + + +concurrency: + group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref_name }}-${{ github.ref_type == 'branch' && github.sha }}-${{ github.event_name == 'workflow_dispatch' }}-${{ github.event_name == 'schedule' }} + cancel-in-progress: true + +permissions: read-all + +jobs: + linux-jammy-cpu-py3_9-gcc11-inductor-build: + name: linux-jammy-cpu-py3.9-gcc11-inductor + uses: ./.github/workflows/_linux-build.yml + with: + build-environment: linux-jammy-py3.9-gcc11 + docker-image-name: pytorch-linux-jammy-py3.9-gcc11-inductor-benchmarks + # Use metal host for benchmark jobs + test-matrix: | + { include: [ + { config: "inductor-micro-benchmark-cpu-x86", shard: 1, num_shards: 1, runner: "linux.24xl.spr-metal" }, + ]} + + linux-jammy-cpu-py3_9-gcc11-inductor-micro-benchmark-test: + name: linux-jammy-cpu-py3.9-gcc11-inductor + uses: ./.github/workflows/_linux-test.yml + needs: linux-jammy-cpu-py3_9-gcc11-inductor-build + with: + build-environment: linux-jammy-py3.9-gcc11 + docker-image: ${{ needs.linux-jammy-cpu-py3_9-gcc11-inductor-build.outputs.docker-image }} + test-matrix: ${{ needs.linux-jammy-cpu-py3_9-gcc11-inductor-build.outputs.test-matrix }} + use-gha: anything-non-empty-to-use-gha + timeout-minutes: 720 diff --git a/.github/workflows/upload-test-stats.yml b/.github/workflows/upload-test-stats.yml index e82dc61370d3c6..c09d76559a9355 100644 --- a/.github/workflows/upload-test-stats.yml +++ b/.github/workflows/upload-test-stats.yml @@ -2,7 +2,7 @@ name: Upload test stats on: workflow_run: - workflows: [pull, trunk, periodic, inductor, unstable, slow, unstable-periodic, inductor-periodic, rocm, inductor-micro-benchmark, inductor-cu124, inductor-rocm] + workflows: [pull, trunk, periodic, inductor, unstable, slow, unstable-periodic, inductor-periodic, rocm, inductor-micro-benchmark, inductor-micro-benchmark-x86, inductor-cu124, inductor-rocm] types: - completed diff --git a/benchmarks/gpt_fast/benchmark.py b/benchmarks/gpt_fast/benchmark.py index 95e46cb87ca9bb..6270f9744f6a73 100644 --- a/benchmarks/gpt_fast/benchmark.py +++ b/benchmarks/gpt_fast/benchmark.py @@ -3,7 +3,12 @@ import dataclasses import os -from generate import run_llama2_7b_bf16, run_llama2_7b_int8, run_mixtral_8x7b_int8 +from generate import ( + get_arch_name, + run_llama2_7b_bf16, + run_llama2_7b_int8, + run_mixtral_8x7b_int8, +) import torch import torch.nn as nn @@ -24,6 +29,7 @@ class Experiment: actual: float dtype: str device: str + arch: str # GPU name for CUDA or CPU arch for CPU is_model: bool = False @@ -71,7 +77,12 @@ def run_mlp_layer_norm_gelu(device: str = "cuda"): for _ in range(WARMUP_ITER): compiled_mod(x) - us_per_iter = benchmarker.benchmark_gpu(lambda: compiled_mod(x)) * 1000 + benchmark_fn = ( + benchmarker.benchmark_gpu + if device == "cuda" + else benchmarker.benchmark_cpu + ) + us_per_iter = benchmark_fn(lambda: compiled_mod(x)) * 1000 flops_utilization += us_per_iter * flops / 1e9 / A100_40G_BF16_TFLOPS flops_utilization = flops_utilization / len(input_shapes) @@ -84,6 +95,7 @@ def run_mlp_layer_norm_gelu(device: str = "cuda"): f"{flops_utilization:.02f}", dtype_str, device, + get_arch_name(), ) ) return results @@ -108,7 +120,12 @@ def run_layer_norm(device: str = "cuda"): for _ in range(WARMUP_ITER): compiled_mod(x) - us_per_iter = benchmarker.benchmark_gpu(lambda: compiled_mod(x)) * 1000 + benchmark_fn = ( + benchmarker.benchmark_gpu + if device == "cuda" + else benchmarker.benchmark_cpu + ) + us_per_iter = benchmark_fn(lambda: compiled_mod(x)) * 1000 memory_bandwidth += (1e6 / us_per_iter) * 2 * BS * D * dtype.itemsize / 1e9 memory_bandwidth = memory_bandwidth / len(input_shapes) @@ -121,6 +138,7 @@ def run_layer_norm(device: str = "cuda"): f"{memory_bandwidth:.02f}", dtype_str, device, + get_arch_name(), ) ) return results @@ -151,9 +169,12 @@ def gather_gemv(W, score_idxs, x): for _ in range(WARMUP_ITER): compiled_fn(W, score_idxs, x) - us_per_iter = ( - benchmarker.benchmark_gpu(lambda: compiled_fn(W, score_idxs, x)) * 1000 + benchmark_fn = ( + benchmarker.benchmark_gpu + if device == "cuda" + else benchmarker.benchmark_cpu ) + us_per_iter = benchmark_fn(lambda: compiled_fn(W, score_idxs, x)) * 1000 memory_bandwidth += (1e6 / us_per_iter) * 2 * D * D * dtype.itemsize / 1e9 memory_bandwidth = memory_bandwidth / len(input_shapes) @@ -166,6 +187,7 @@ def gather_gemv(W, score_idxs, x): f"{memory_bandwidth:.02f}", dtype_str, device, + get_arch_name(), ) ) return results @@ -186,15 +208,20 @@ def run_gemv(device: str = "cuda"): def gemv(W, x): return W.to(x.dtype) @ x - W = torch.randn(D, D, device="cuda").to(dtype=dtype) - x = torch.randn(D, device="cuda", dtype=torch.bfloat16) + W = torch.randn(D, D, device=device).to(dtype=dtype) + x = torch.randn(D, device=device, dtype=torch.bfloat16) compiled_fn = torch.compile(gemv, dynamic=False) for _ in range(WARMUP_ITER): compiled_fn(W, x) - us_per_iter = benchmarker.benchmark_gpu(lambda: compiled_fn(W, x)) * 1000 + benchmark_fn = ( + benchmarker.benchmark_gpu + if device == "cuda" + else benchmarker.benchmark_cpu + ) + us_per_iter = benchmark_fn(lambda: compiled_fn(W, x)) * 1000 memory_bandwidth += (1e6 / us_per_iter) * D * D * dtype.itemsize / 1e9 memory_bandwidth = memory_bandwidth / len(input_shapes) @@ -207,6 +234,7 @@ def gemv(W, x): f"{memory_bandwidth:.02f}", dtype_str, device, + get_arch_name(), ) ) return results @@ -252,7 +280,13 @@ def main(output_file=DEFAULT_OUTPUT_FILE): results = [] for func in all_experiments: - lst = func() + try: + device = "cuda" if torch.cuda.is_available() else "cpu" + except AssertionError: + # This happens when torch is compiled with CUDA turning off completely + device = "cpu" + + lst = func(device) for x in lst: results.append(dataclasses.astuple(x)) diff --git a/benchmarks/gpt_fast/generate.py b/benchmarks/gpt_fast/generate.py index c2aa6e6d17fec4..bbfffa75ac9fdc 100644 --- a/benchmarks/gpt_fast/generate.py +++ b/benchmarks/gpt_fast/generate.py @@ -1,5 +1,6 @@ import dataclasses import itertools +import platform import time from typing import Optional, Tuple @@ -41,6 +42,14 @@ def device_sync(device): print(f"device={device} is not yet suppported") +def get_arch_name() -> str: + if torch.cuda.is_available(): + return torch.cuda.get_device_name() + else: + # This returns x86_64 or arm64 (for aarch64) + return platform.machine() + + def multinomial_sample_one_no_sync( probs_sort, ): # Does multinomial sampling without a cuda synchronization @@ -198,7 +207,7 @@ def run_experiment( ) -> None: print(f"Loading model {x.name}") t0 = time.time() - model = _load_model(x) + model = _load_model(x, device=device) device_sync(device=device) # MKG print(f"Time to load model: {time.time() - t0:.02f} seconds") @@ -257,7 +266,9 @@ def run_llama2_7b_bf16(device: str = "cuda"): 1253, 162, ) - token_per_sec, memory_bandwidth, compilation_time = run_experiment(model) + token_per_sec, memory_bandwidth, compilation_time = run_experiment( + model, device=device + ) return [ Experiment( model.name, @@ -266,6 +277,7 @@ def run_llama2_7b_bf16(device: str = "cuda"): f"{token_per_sec:.02f}", model.mode, device, + get_arch_name(), True, ), Experiment( @@ -275,6 +287,7 @@ def run_llama2_7b_bf16(device: str = "cuda"): f"{memory_bandwidth:.02f}", model.mode, device, + get_arch_name(), True, ), Experiment( @@ -284,6 +297,7 @@ def run_llama2_7b_bf16(device: str = "cuda"): f"{compilation_time:.02f}", model.mode, device, + get_arch_name(), True, ), ] @@ -302,7 +316,9 @@ def run_llama2_7b_int8(device: str = "cuda"): 957, 172, ) - token_per_sec, memory_bandwidth, compilation_time = run_experiment(model) + token_per_sec, memory_bandwidth, compilation_time = run_experiment( + model, device=device + ) return [ Experiment( model.name, @@ -311,6 +327,7 @@ def run_llama2_7b_int8(device: str = "cuda"): f"{token_per_sec:.02f}", model.mode, device, + get_arch_name(), True, ), Experiment( @@ -320,6 +337,7 @@ def run_llama2_7b_int8(device: str = "cuda"): f"{memory_bandwidth:.02f}", model.mode, device, + get_arch_name(), True, ), Experiment( @@ -329,6 +347,7 @@ def run_llama2_7b_int8(device: str = "cuda"): f"{compilation_time:.02f}", model.mode, device, + get_arch_name(), True, ), ] @@ -348,7 +367,9 @@ def run_mixtral_8x7b_int8(device: str = "cuda"): 1130, 162, ) - token_per_sec, memory_bandwidth, compilation_time = run_experiment(model) + token_per_sec, memory_bandwidth, compilation_time = run_experiment( + model, device=device + ) return [ Experiment( model.name, @@ -357,6 +378,7 @@ def run_mixtral_8x7b_int8(device: str = "cuda"): f"{token_per_sec:.02f}", model.mode, device, + get_arch_name(), True, ), Experiment( @@ -366,6 +388,7 @@ def run_mixtral_8x7b_int8(device: str = "cuda"): f"{memory_bandwidth:.02f}", model.mode, device, + get_arch_name(), True, ), Experiment( @@ -375,6 +398,7 @@ def run_mixtral_8x7b_int8(device: str = "cuda"): f"{compilation_time:.02f}", model.mode, device, + get_arch_name(), True, ), ] From 6c4ed573ac0b8f8c6482cb147d162641ae38f670 Mon Sep 17 00:00:00 2001 From: Yidi Wu Date: Wed, 4 Sep 2024 16:21:39 -0700 Subject: [PATCH 0486/1018] [hop] preserve metadata in re-tracing hop subgraph by running with interpreter (#135159) In this way, the interpreter.run can preserve the current metadata of subgraphs correctly when tracing the subgraphs. Pull Request resolved: https://github.com/pytorch/pytorch/pull/135159 Approved by: https://github.com/tugsbayasgalan --- test/export/test_export.py | 10 ---------- torch/_higher_order_ops/associative_scan.py | 5 ++++- torch/_higher_order_ops/cond.py | 5 +++-- torch/_higher_order_ops/map.py | 3 ++- torch/_higher_order_ops/while_loop.py | 5 +++-- 5 files changed, 12 insertions(+), 16 deletions(-) diff --git a/test/export/test_export.py b/test/export/test_export.py index fca1605716f24c..1d65bf7683537f 100644 --- a/test/export/test_export.py +++ b/test/export/test_export.py @@ -999,16 +999,6 @@ def forward(self, p_linear_weight, p_linear_bias, x): return (getitem, getitem_1, getitem_2)""", ) - # TODO(yidi) - # Expected failure for test cases that calls run_decomposition(). - # The top-level cond node has pre-existing metadata, - # which overrides the metadata for operators in subgraph due to interpreter.run(), - # where cond is a single node in the interpreter.run(). And we preserve metadata - # by copying current node's metadata for all nodes created during interpreting. - @testing.expectedFailurePreDispatchRunDecomp - @testing.expectedFailureRetraceability - @testing.expectedFailureTrainingIRToRunDecomp # T193700910 - @testing.expectedFailureTrainingIRToRunDecompNonStrict def test_export_cond_preserve_torch_fn_for_subgraphs(self): class MySubModule(torch.nn.Module): def foo(self, x): diff --git a/torch/_higher_order_ops/associative_scan.py b/torch/_higher_order_ops/associative_scan.py index 3616574efe77d6..11dcca5eb4959f 100644 --- a/torch/_higher_order_ops/associative_scan.py +++ b/torch/_higher_order_ops/associative_scan.py @@ -9,6 +9,7 @@ import torch.utils._pytree as pytree from torch._C import DispatchKey from torch._higher_order_ops.utils import ( + _maybe_run_with_interpreter, _set_compilation_env, autograd_not_implemented, reenter_make_fx, @@ -357,6 +358,8 @@ def assoiciative_scan_fake_tensor_mode(mode, combine_fn, input, dim): def associative_scan_functionalize(ctx, combine_fn, input, dim): unwrapped_input = ctx.unwrap_tensors(input) with ctx.redispatch_to_next() as m: - functional_combine_fn = ctx.functionalize(combine_fn) + functional_combine_fn = ctx.functionalize( + _maybe_run_with_interpreter(combine_fn) + ) ret = associative_scan_op(functional_combine_fn, unwrapped_input, dim) return ctx.wrap_tensors(ret) diff --git a/torch/_higher_order_ops/cond.py b/torch/_higher_order_ops/cond.py index 49a1a4cd45fa39..dee400d76f5964 100644 --- a/torch/_higher_order_ops/cond.py +++ b/torch/_higher_order_ops/cond.py @@ -18,6 +18,7 @@ from torch._higher_order_ops.utils import ( _has_potential_branch_input_alias, _has_potential_branch_input_mutation, + _maybe_run_with_interpreter, _set_compilation_env, reenter_make_fx, unique_graph_id, @@ -449,8 +450,8 @@ def cond_func(ctx, pred, true_fn, false_fn, inputs): unwrapped_inputs = ctx.unwrap_tensors(inputs) unwrapped_pred = ctx.unwrap_tensors(pred) with ctx.redispatch_to_next() as m: - functional_true = ctx.functionalize(true_fn) - functional_false = ctx.functionalize(false_fn) + functional_true = ctx.functionalize(_maybe_run_with_interpreter(true_fn)) + functional_false = ctx.functionalize(_maybe_run_with_interpreter(false_fn)) pre_dispatch = hasattr(ctx, "mode") and ctx.mode.pre_dispatch for branch in [functional_true, functional_false]: if _has_potential_branch_input_mutation( diff --git a/torch/_higher_order_ops/map.py b/torch/_higher_order_ops/map.py index a8dc9d94e45f0e..d57d68d5e473f7 100644 --- a/torch/_higher_order_ops/map.py +++ b/torch/_higher_order_ops/map.py @@ -7,6 +7,7 @@ from torch._higher_order_ops.utils import ( _has_potential_branch_input_alias, _has_potential_branch_input_mutation, + _maybe_run_with_interpreter, reenter_make_fx, UnsupportedAliasMutationException, ) @@ -243,7 +244,7 @@ def map_fake_tensor_mode(mode, f, xs, args): def map_functionalize(ctx, f, xs, pos_args): unwrapped_xs = ctx.unwrap_tensors(xs) unwrapped_args = ctx.unwrap_tensors(pos_args) - wrapped_fn = ctx.functionalize(f) + wrapped_fn = ctx.functionalize(_maybe_run_with_interpreter(f)) with ctx.redispatch_to_next(): with disable_proxy_modes_tracing(): diff --git a/torch/_higher_order_ops/while_loop.py b/torch/_higher_order_ops/while_loop.py index 87764a826d9276..d64f512399d9d5 100644 --- a/torch/_higher_order_ops/while_loop.py +++ b/torch/_higher_order_ops/while_loop.py @@ -7,6 +7,7 @@ from torch._higher_order_ops.utils import ( _has_potential_branch_input_alias, _has_potential_branch_input_mutation, + _maybe_run_with_interpreter, _set_compilation_env, autograd_not_implemented, reenter_make_fx, @@ -238,8 +239,8 @@ def while_loop_func(ctx, cond_fn, body_fn, carried_inputs, additional_inputs): unwrapped_additional_inputs = ctx.unwrap_tensors(additional_inputs) unwrapped_inputs = unwrapped_carried_inputs + unwrapped_additional_inputs with ctx.redispatch_to_next() as m: - functional_cond_fn = ctx.functionalize(cond_fn) - functional_body_fn = ctx.functionalize(body_fn) + functional_cond_fn = ctx.functionalize(_maybe_run_with_interpreter(cond_fn)) + functional_body_fn = ctx.functionalize(_maybe_run_with_interpreter(body_fn)) pre_dispatch = hasattr(ctx, "mode") and ctx.mode.pre_dispatch for fn, fn_name in [ (functional_cond_fn, "cond_fn"), From 8587f2e7dacdb197e079408cb3ced1c1b853f261 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Thu, 5 Sep 2024 21:43:05 +0000 Subject: [PATCH 0487/1018] Revert "Use actions/upload-artifact@v4.4.0 for rest of workflows (#135264)" This reverts commit 9c0b03020b7204ca5d5dbe18174bab005f79c47b. Reverted https://github.com/pytorch/pytorch/pull/135264 on behalf of https://github.com/atalman due to broke CI ([comment](https://github.com/pytorch/pytorch/pull/135264#issuecomment-2332674607)) --- .github/actions/upload-test-artifacts/action.yml | 6 +++--- .github/workflows/_ios-build-test.yml | 6 +++--- .github/workflows/_mac-build.yml | 4 ++-- .github/workflows/_rocm-test.yml | 2 +- .github/workflows/_xpu-test.yml | 2 +- .github/workflows/scorecards.yml | 2 +- .github/workflows/target_determination.yml | 2 +- 7 files changed, 12 insertions(+), 12 deletions(-) diff --git a/.github/actions/upload-test-artifacts/action.yml b/.github/actions/upload-test-artifacts/action.yml index cbb52a74afb79a..04cb43b20c389b 100644 --- a/.github/actions/upload-test-artifacts/action.yml +++ b/.github/actions/upload-test-artifacts/action.yml @@ -147,7 +147,7 @@ runs: # GHA upload - name: Store Test Downloaded JSONs on Github - uses: actions/upload-artifact@v4.4.0 + uses: actions/upload-artifact@v3 if: inputs.use-gha continue-on-error: true with: @@ -158,7 +158,7 @@ runs: path: test/**/*.json - name: Store Test Reports on Github - uses: actions/upload-artifact@v4.4.0 + uses: actions/upload-artifact@v3 if: inputs.use-gha continue-on-error: true with: @@ -172,7 +172,7 @@ runs: test/**/*.csv - name: Store Usage Logs on Github - uses: actions/upload-artifact@v4.4.0 + uses: actions/upload-artifact@v3 if: inputs.use-gha continue-on-error: true with: diff --git a/.github/workflows/_ios-build-test.yml b/.github/workflows/_ios-build-test.yml index 58f55b68540ad0..95fe6bd1a3b50a 100644 --- a/.github/workflows/_ios-build-test.yml +++ b/.github/workflows/_ios-build-test.yml @@ -263,7 +263,7 @@ jobs: zip -r "${IOS_ARCH}.zip" install src version.txt LICENSE ./*.podspec popd - - uses: actions/upload-artifact@v4.4.0 + - uses: actions/upload-artifact@v3 if: matrix.ios_platform == 'OS' with: name: pytorch-ios-build-artifacts-${{ matrix.ios_arch }} @@ -401,13 +401,13 @@ jobs: echo "SPEC_NAME=${SPEC_NAME}" } >> "$GITHUB_ENV" - - uses: actions/upload-artifact@v4.4.0 + - uses: actions/upload-artifact@v3 with: name: pytorch-ios-artifacts if-no-files-found: error path: ${{ env.DEST_DIR }}/${{ env.ARTIFACT_NAME }} - - uses: actions/upload-artifact@v4.4.0 + - uses: actions/upload-artifact@v3 with: name: pytorch-ios-podspec if-no-files-found: error diff --git a/.github/workflows/_mac-build.yml b/.github/workflows/_mac-build.yml index 8b42037c104104..5bc6f06b0f6bcd 100644 --- a/.github/workflows/_mac-build.yml +++ b/.github/workflows/_mac-build.yml @@ -186,7 +186,7 @@ jobs: zip -1 -r artifacts.zip dist/ build/.ninja_log build/compile_commands.json .additional_ci_files - name: Store PyTorch Build Artifacts on GHA - uses: actions/upload-artifact@v4.4.0 + uses: actions/upload-artifact@v3 if: inputs.build-generates-artifacts && steps.build.outcome != 'skipped' with: name: ${{ env.BUILD_ENVIRONMENT }} @@ -195,7 +195,7 @@ jobs: path: artifacts.zip - name: Upload sccache stats to GHA - uses: actions/upload-artifact@v4.4.0 + uses: actions/upload-artifact@v3 # Only if sccache is installed, see above if: ${{ (github.event_name == 'push' || github.event.pull_request.head.repo.full_name == github.repository) && steps.build.outcome != 'skipped' }} with: diff --git a/.github/workflows/_rocm-test.yml b/.github/workflows/_rocm-test.yml index b97715b1c050a8..790cd9d403dc25 100644 --- a/.github/workflows/_rocm-test.yml +++ b/.github/workflows/_rocm-test.yml @@ -269,7 +269,7 @@ jobs: find . -iname "core.[1-9]*" -exec docker exec "${CONTAINER_NAME}" sh -c "gdb python {} -ex 'bt' -ex 'q'" \; - name: Store Core dumps on GitHub - uses: actions/upload-artifact@v4.4.0 + uses: actions/upload-artifact@v3 if: failure() with: name: coredumps-${{ matrix.config }}-${{ matrix.shard }}-${{ matrix.num_shards }}-${{ matrix.runner }} diff --git a/.github/workflows/_xpu-test.yml b/.github/workflows/_xpu-test.yml index 56c2f9fb041d4d..036a2c8eeca85f 100644 --- a/.github/workflows/_xpu-test.yml +++ b/.github/workflows/_xpu-test.yml @@ -261,7 +261,7 @@ jobs: docker stop "${{ env.CONTAINER_NAME }}" - name: Store Core dumps on GitHub - uses: actions/upload-artifact@v4.4.0 + uses: actions/upload-artifact@v3 if: failure() with: name: coredumps-${{ matrix.config }}-${{ matrix.shard }}-${{ matrix.num_shards }}-${{ matrix.runner }} diff --git a/.github/workflows/scorecards.yml b/.github/workflows/scorecards.yml index 29fa6718802aa8..cd4258c478763f 100644 --- a/.github/workflows/scorecards.yml +++ b/.github/workflows/scorecards.yml @@ -42,7 +42,7 @@ jobs: # Upload the results as artifacts (optional). Commenting out will disable uploads of run results in SARIF # format to the repository Actions tab. - name: "Upload artifact" - uses: actions/upload-artifact@v4.4.0 + uses: actions/upload-artifact@v3 with: name: SARIF file path: results.sarif diff --git a/.github/workflows/target_determination.yml b/.github/workflows/target_determination.yml index dabd8f6b7e170b..f7b2f383f314ea 100644 --- a/.github/workflows/target_determination.yml +++ b/.github/workflows/target_determination.yml @@ -83,7 +83,7 @@ jobs: path: td_results.json - name: Store TD results on GHA - uses: actions/upload-artifact@v4.4.0 + uses: actions/upload-artifact@v3 if: steps.td.outcome == 'success' with: name: td_results.json From 3b8fb66981a17bfdc2db9a14c95d9f7cce726279 Mon Sep 17 00:00:00 2001 From: Scott Wolchok Date: Wed, 4 Sep 2024 14:16:32 -0700 Subject: [PATCH 0488/1018] [PyTorch] Fix -Wshadow -Werror build in BFloat16-inl.h (#135031) `float_t` is required to exists in C99 math.h, which causes -Wshadow to fire. We don't need the alias, fortunately. Differential Revision: [D62135908](https://our.internmc.facebook.com/intern/diff/D62135908/) Pull Request resolved: https://github.com/pytorch/pytorch/pull/135031 Approved by: https://github.com/albanD --- c10/util/BFloat16-math.h | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/c10/util/BFloat16-math.h b/c10/util/BFloat16-math.h index 88a6b849d37bf9..5ae993e7418954 100644 --- a/c10/util/BFloat16-math.h +++ b/c10/util/BFloat16-math.h @@ -237,10 +237,9 @@ C10_HOST_DEVICE inline T nextafter(T from, T to) { // Reference: // https://git.musl-libc.org/cgit/musl/tree/src/math/nextafter.c using int_repr_t = uint16_t; - using float_t = T; constexpr uint8_t bits = 16; union { - float_t f; + T f; int_repr_t i; } ufrom = {from}, uto = {to}; From b9f9086ee3a486e76e70d9f01731d42ed5aabda1 Mon Sep 17 00:00:00 2001 From: Scott Wolchok Date: Wed, 4 Sep 2024 14:16:33 -0700 Subject: [PATCH 0489/1018] [PyTorch] Add isfinite to BFloat16-math.h (#135052) Missing function from . Differential Revision: [D62148884](https://our.internmc.facebook.com/intern/diff/D62148884/) Pull Request resolved: https://github.com/pytorch/pytorch/pull/135052 Approved by: https://github.com/PaliC, https://github.com/albanD ghstack dependencies: #135031 --- c10/util/BFloat16-math.h | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/c10/util/BFloat16-math.h b/c10/util/BFloat16-math.h index 5ae993e7418954..bad374cbd4353e 100644 --- a/c10/util/BFloat16-math.h +++ b/c10/util/BFloat16-math.h @@ -68,6 +68,12 @@ template < inline T expm1(T a) { return std::expm1(float(a)); } +template < + typename T, + typename std::enable_if_t, int> = 0> +inline bool isfinite(T a) { + return std::isfinite(float(a)); +} template < typename T, typename std::enable_if_t, int> = 0> From eecd103b78167c60b3f55602127eb9619cc6f4a2 Mon Sep 17 00:00:00 2001 From: Zhengxu Chen Date: Thu, 5 Sep 2024 21:58:20 +0000 Subject: [PATCH 0490/1018] [export] Expand coverage to more copied sym ops for unflattener. (#135119) Test Plan: buck2 test 'fbcode//mode/opt' fbcode//torchrec/ir/tests:test_serializer -- --run-disabled ``` File changed: fbcode//caffe2/torch/export/unflatten.py Buck UI: https://www.internalfb.com/buck2/2e0377e7-e2b6-4bd0-8133-a787245165a0 Test UI: https://www.internalfb.com/intern/testinfra/testrun/5066549824883887 Network: Up: 0B Down: 0B Jobs completed: 16. Time elapsed: 10.2s. Tests finished: Pass 6. Fail 0. Fatal 0. Skip 0. Build failure 0 ``` Differential Revision: D62190172 Pull Request resolved: https://github.com/pytorch/pytorch/pull/135119 Approved by: https://github.com/yushangdi --- torch/export/unflatten.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/torch/export/unflatten.py b/torch/export/unflatten.py index 2962cdec8fd16c..992818f1758048 100644 --- a/torch/export/unflatten.py +++ b/torch/export/unflatten.py @@ -890,7 +890,15 @@ def remap_input(self, x): self.parent_call_module.insert_arg(0, self.parent.remap_input(x)) return self.node_to_placeholder[x] elif x.op == "call_function" and ( - x.target == torch.ops.aten.sym_size.int + x.target + in ( + torch.ops.aten.sym_size.int, + torch.ops.aten.item.default, + torch.ops.aten.unbind.int, + torch.ops.aten.sum.dim_IntList, + torch.ops.aten.view.default, + torch.ops.aten.diff.default, + ) or (hasattr(x.target, "__module__") and x.target.__module__ == "_operator") ): # export deduplicates sym_size nodes, and may need to re-copy them From 7e3fbbdb9c62530af2731c4c31b7878f6cfc58b3 Mon Sep 17 00:00:00 2001 From: sanchitintel Date: Thu, 5 Sep 2024 22:01:11 +0000 Subject: [PATCH 0491/1018] Tune int8 AMX WoQ micro-kernel for CPU (#134832) This patch prevents performance regression against the default ATen implementation for LLaMA 3.1 int8 GPTQ WoQ workload. Uses AMX micro-kernel only if `M` >= `block_m` Pull Request resolved: https://github.com/pytorch/pytorch/pull/134832 Approved by: https://github.com/jgong5 --- torch/_inductor/codegen/cpp_micro_gemm.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/torch/_inductor/codegen/cpp_micro_gemm.py b/torch/_inductor/codegen/cpp_micro_gemm.py index fb6d86d1a720f5..cb26c48fce53c6 100644 --- a/torch/_inductor/codegen/cpp_micro_gemm.py +++ b/torch/_inductor/codegen/cpp_micro_gemm.py @@ -603,15 +603,11 @@ class CppMicroGemmAMX(CppMicroGemm): {%- if input_dtype == torch.bfloat16 and input2_dtype == torch.int8 %} // create a buffer for tiles of B. - // TODO: loop-unrolling of the "compute" lambda may result in incorrect output - // as this buffer would be used, so maybe 4 of these should be used? - // Since UT output is correct, looks like loop unrolling isn't actually happening. alignas(64) {{input_t}} bf16_weights_buf[512]; int num_b_rows = (last_k_offset > 0) ? 16 : (tail_k_size * sizeof({{input_t}})) / 4; int b_tile_ptr_stride = ldb * {{vnni_size}}; - // TODO: verify whether or not these lambdas inline auto load_B_row = [&]({{input2_t}}* src, {{input_t}}* dst) { {{kernel.unroll_pragma(2)}} for (int i = 0; i < 2; i++) { @@ -801,6 +797,14 @@ def create_from_config(cls, config: CppMicroGemmConfig): ): continue block_m, block_n, block_k = config.register_blocking + if ( + config.vec_isa_cls == VecAMX + and m < block_m + and input_dtype == torch.bfloat16 + and input2_dtype == torch.int8 + ): + # For int8 WoQ GEMM, AMX micro-kernel may not perform well if m < block_m + continue # Criteria on the ranking of configurations # 1. ISA: AMX > VEC # 2. Dividable by block sizes (block_m, block_n, block_k) From 34bebfa403e523128760d9f2ee9103e9a6db9fc8 Mon Sep 17 00:00:00 2001 From: Animesh Jain Date: Thu, 5 Sep 2024 22:05:54 +0000 Subject: [PATCH 0492/1018] [fbcode][dynamo] Turn on guard_nn_modules using justknobs_check (#134928) As Title Pull Request resolved: https://github.com/pytorch/pytorch/pull/134928 Approved by: https://github.com/ezyang --- torch/_dynamo/config.py | 2 +- torch/_dynamo/guards.py | 10 +++++++++- 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/torch/_dynamo/config.py b/torch/_dynamo/config.py index 8885ca33cc0765..7cec9d153d69a5 100644 --- a/torch/_dynamo/config.py +++ b/torch/_dynamo/config.py @@ -98,7 +98,7 @@ def is_fbcode(): allow_ignore_mark_dynamic = False # Set this to False to assume nn.Modules() contents are immutable (similar assumption as freezing) -guard_nn_modules = False if is_fbcode() else True +guard_nn_modules = True # Uses CPython internal dictionary tags to detect mutation. There is some # overlap between guard_nn_modules_using_dict_tags and guard_nn_modules flag. diff --git a/torch/_dynamo/guards.py b/torch/_dynamo/guards.py index fccea9cbd14c27..ed0c9893d808b7 100644 --- a/torch/_dynamo/guards.py +++ b/torch/_dynamo/guards.py @@ -65,6 +65,7 @@ Source, ) from torch._logging import structured +from torch._utils_internal import justknobs_check from torch.fx.experimental.symbolic_shapes import ( EqualityConstraint, is_symbolic, @@ -2238,9 +2239,16 @@ def cleanup_builder(weak_b): # Break retain cycle. See test_release_input_memory w_builder = weakref.ref(builder, cleanup_builder) + guard_on_nn_modules = config.guard_nn_modules and justknobs_check( + "pytorch/compiler:guard_nn_modules" + ) + + if not justknobs_check("pytorch/compiler:guard_nn_modules"): + log.warning("guard_nn_modules is turned off using justknobs killswitch") + for guard in sorted(guards or [], key=Guard.sort_key): if ( - not config.guard_nn_modules + not guard_on_nn_modules and guard.is_specialized_nn_module() # Default func args must be guarded on. # TODO: we could make use of 'DefaultsSource' and offer a .guard.is_defaults() API From 4835438d5229a718ddf6cfa7799362b43553024b Mon Sep 17 00:00:00 2001 From: Zain Rizvi Date: Thu, 5 Sep 2024 16:57:38 -0500 Subject: [PATCH 0493/1018] Support rolling over a percentage of workflows (#134816) In order to support adding a rollover percentage, this ended up being a complete rewrite of runner_determinator.py. Details of the new format are in the comments up top. On the plus side, this now includes some unit tests. Pull Request resolved: https://github.com/pytorch/pytorch/pull/134816 Approved by: https://github.com/PaliC, https://github.com/zxiiro --- .github/scripts/runner_determinator.py | 337 ++++++++++++++------ .github/scripts/test_runner_determinator.py | 237 ++++++++++++++ .github/workflows/_runner-determinator.yml | 337 ++++++++++++++------ 3 files changed, 699 insertions(+), 212 deletions(-) create mode 100644 .github/scripts/test_runner_determinator.py diff --git a/.github/scripts/runner_determinator.py b/.github/scripts/runner_determinator.py index e52b19c64b4e3c..c1cc82dbadbe34 100644 --- a/.github/scripts/runner_determinator.py +++ b/.github/scripts/runner_determinator.py @@ -3,49 +3,94 @@ """ This runner determinator is used to determine which set of runners to run a GitHub job on. It uses the first comment of a GitHub issue (by default -https://github.com/pytorch/test-infra/issues/5132) as a user list to determine -which users will get their jobs to run on experimental runners. This user list -is also a comma separated list of additional features or experiments which the -user could be opted in to. +https://github.com/pytorch/test-infra/issues/5132) to define the configuration +of which runners should be used to run which job. + +The configuration has two parts, the settings and a list of opted-in users, +separated by a line containing "---". If the line is not present, the +settings are considered to be empty with only the second part, the user +list, defined. + +The first part is a YAML block that defines the rollout settings. This can be +used to define any settings that are needed to determine which runners to use. +It's fields are defined by the RolloutSettings class below. + +The second part is a list of users who are explicitly opted in to the LF fleet. +The user list is also a comma separated list of additional features or +experiments which the user could be opted in to. The user list has the following rules: -- Users are GitHub usernames with the @ prefix -- If the first line is a "*" then all users will use the new runners -- If the first line is a "!" then all users will use the old runners +- Users are GitHub usernames, which must start with the @ prefix - Each user is also a comma-separated list of features/experiments to enable -- A "#" prefix indicates the user is opted out of the new runners but is opting - into features/experiments. +- A "#" prefix opts the user out of all experiments + +Example config: + # A list of experiments that can be opted into. + # This defines the behavior they'll induce when opted into. + # Expected syntax is: + # [experiment_name]: # Name of the experiment. Also used for the label prefix. + # rollout_perc: [int] # % of workflows to run with this experiment when users are not opted in. + + experiments: + lf: + rollout_percent: 25 -Example user list: + --- - @User1 - @User2,amz2023 - #@UserOptOutOfNewRunner,amz2023 + # Opt-ins: + # Users can opt into the LF fleet by adding their GitHub username to this list + # and specifying experiments to enable in a comma-separated list. + # Experiments should be from the above list. + + @User1,lf,split_build + @User2,lf + @User3,split_build """ import logging import os +import random from argparse import ArgumentParser from logging import LogRecord -from typing import Any, Iterable +from typing import Any, Dict, Iterable, List, NamedTuple, Tuple +import yaml from github import Auth, Github from github.Issue import Issue -WORKFLOW_LABEL_META = "" # use meta runners +DEFAULT_LABEL_PREFIX = "" # use meta runners WORKFLOW_LABEL_LF = "lf." # use runners from the linux foundation WORKFLOW_LABEL_LF_CANARY = "lf.c." # use canary runners from the linux foundation -RUNNER_AMI_LEGACY = "" -RUNNER_AMI_AMZ2023 = "amz2023" - GITHUB_OUTPUT = os.getenv("GITHUB_OUTPUT", "") GH_OUTPUT_KEY_AMI = "runner-ami" GH_OUTPUT_KEY_LABEL_TYPE = "label-type" +SETTING_EXPERIMENTS = "experiments" + +LF_FLEET_EXPERIMENT = "lf" +CANARY_FLEET_SUFFIX = ".c" + + +class Experiment(NamedTuple): + rollout_perc: int = ( + 0 # Percentage of workflows to experiment on when user is not opted-in. + ) + + # Add more fields as needed + + +class Settings(NamedTuple): + """ + Settings for the experiments that can be opted into. + """ + + experiments: Dict[str, Experiment] = {} + + class ColorFormatter(logging.Formatter): """Color codes the log messages based on the log level""" @@ -172,85 +217,181 @@ def is_exception_branch(branch: str) -> bool: return branch.split("/")[0] in {"main", "nightly", "release", "landchecks"} -def get_fleet(rollout_state: str, workflow_requestors: Iterable[str]) -> str: +def load_yaml(yaml_text: str) -> Any: + try: + data = yaml.safe_load(yaml_text) + return data + except yaml.YAMLError as exc: + log.exception("Error loading YAML") + raise + + +def extract_settings_user_opt_in_from_text(rollout_state: str) -> Tuple[str, str]: + """ + Extracts the text with settings, if any, and the opted in users from the rollout state. + + If the issue body contains "---" then the text above that is the settings + and the text below is the list of opted in users. + + If it doesn't contain "---" then the settings are empty and the rest is the users. """ - Determines if the job should run on the LF fleet or the Meta fleet + rollout_state_parts = rollout_state.split("---") + if len(rollout_state_parts) >= 2: + return rollout_state_parts[0], rollout_state_parts[1] + else: + return "", rollout_state + - Returns: - The appropriate label prefix for the runner, corresponding to the fleet to use. - This gets prefixed to the very start of the runner label. +class UserOptins(Dict[str, List[str]]): + """ + Dictionary of users with a list of features they have opted into """ + +def parse_user_opt_in_from_text(user_optin_text: str) -> UserOptins: + """ + Parse the user opt-in text into a key value pair of username and the list of features they have opted into + + Users are GitHub usernames with the @ prefix. Each user is also a comma-separated list of features/experiments to enable. + - Example line: "@User1,lf,split_build" + - A "#" prefix indicates the user is opted out of all experiments + + + """ + optins = UserOptins() + for user in user_optin_text.split("\n"): + user = user.strip("\r\n\t -") + if not user or not user.startswith("@"): + # Not a valid user. Skip + continue + + if user: + usr_name = user.split(",")[0].strip("@") + optins[usr_name] = [exp.strip(" ") for exp in user.split(",")[1:]] + + return optins + + +def parse_settings_from_text(settings_text: str) -> Settings: + """ + Parse the experiments from the issue body into a list of ExperimentSettings + """ try: - if rollout_state[0] == "!": - log.info("LF Workflows are disabled for everyone. Using meta runners.") - return WORKFLOW_LABEL_META - elif rollout_state[0] == "*": - log.info("LF Workflows are enabled for everyone. Using LF runners.") - return WORKFLOW_LABEL_LF - else: - all_opted_in_users = { - usr_raw.strip("\n\t@ ").split(",")[0] - for usr_raw in rollout_state.split() - } - opted_in_requestors = { - usr for usr in workflow_requestors if usr in all_opted_in_users - } - if opted_in_requestors: - log.info( - f"LF Workflows are enabled for {', '.join(opted_in_requestors)}. Using LF runners." - ) - return WORKFLOW_LABEL_LF - else: - log.info( - f"LF Workflows are disabled for {', '.join(workflow_requestors)}. Using meta runners." - ) - return WORKFLOW_LABEL_META + if settings_text: + # Escape the backtick as well so that we can have the settings in a code block on the GH issue + # for easy reading + # Note: Using ascii for the backtick so that the cat step in _runner-determinator.yml doesn't choke on + # the backtick character in shell commands. + backtick = chr(96) # backtick character + settings_text = settings_text.strip(f"\r\n\t{backtick} ") + settings = load_yaml(settings_text) + + # For now we just load experiments. We can expand this if/when we add more settings + experiments = {} + + for exp_name, exp_settings in settings.get(SETTING_EXPERIMENTS).items(): + valid_settings = {} + for setting in exp_settings: + if setting not in Experiment._fields: + log.warning( + f"Unexpected setting in experiment: {setting} = {exp_settings[setting]}" + ) + else: + valid_settings[setting] = exp_settings[setting] + + experiments[exp_name] = Experiment(**valid_settings) + return Settings(experiments) except Exception as e: - log.error( - f"Failed to get determine workflow type. Falling back to meta runners. Exception: {e}" - ) - return WORKFLOW_LABEL_META + log.exception("Failed to parse settings") + return Settings() -def get_optin_feature( - rollout_state: str, workflow_requestors: Iterable[str], feature: str, fallback: str -) -> str: + +def parse_settings(rollout_state: str) -> Settings: """ - Used to dynamically opt in jobs to specific runner-type variants. + Parse settings, if any, from the rollout state. + + If the issue body contains "---" then the text above that is the settings + and the text below is the list of opted in users. - Returns: - The runner-type's variant name if the user has opted in to the feature, otherwise returns an empty string. - This variant name is prefixed to the runner-type in the label. + If it doesn't contain "---" then the settings are empty and the default values are used. """ - try: - userlist = {u.lstrip("#").strip("\n\t@ ") for u in rollout_state.split()} - all_opted_in_users = set() - for user in userlist: - for i in user.split(","): - if i == feature: - all_opted_in_users.add(user.split(",")[0]) - opted_in_requestors = { - usr for usr in workflow_requestors if usr in all_opted_in_users - } - - if opted_in_requestors: + settings_text, _ = extract_settings_user_opt_in_from_text(rollout_state) + return parse_settings_from_text(settings_text) + + +def parse_users(rollout_state: str) -> UserOptins: + """ + Parse users from the rollout state. + + """ + _, users_text = extract_settings_user_opt_in_from_text(rollout_state) + return parse_user_opt_in_from_text(users_text) + + +def is_user_opted_in(user: str, user_optins: UserOptins, experiment_name: str) -> bool: + """ + Check if a user is opted into an experiment + """ + return experiment_name in user_optins.get(user, []) + + +def get_runner_prefix( + rollout_state: str, workflow_requestors: Iterable[str], is_canary: bool = False +) -> str: + settings = parse_settings(rollout_state) + user_optins = parse_users(rollout_state) + + fleet_prefix = "" + prefixes = [] + for experiment_name, experiment_settings in settings.experiments.items(): + enabled = False + + # Is any workflow_requestor opted in to this experiment? + opted_in_users = [ + requestor + for requestor in workflow_requestors + if is_user_opted_in(requestor, user_optins, experiment_name) + ] + + if opted_in_users: log.info( - f"Feature {feature} is enabled for {', '.join(opted_in_requestors)}. Using feature {feature}." + f"{', '.join(opted_in_users)} have opted into experiment {experiment_name}." ) - return feature + enabled = True else: - log.info( - f"Feature {feature} is disabled for {', '.join(workflow_requestors)}. Using fallback \"{fallback}\"." - ) - return fallback + # If no user is opted in, then we randomly enable the experiment based on the rollout percentage + r = random.randint(1, 100) + if r <= experiment_settings.rollout_perc: + log.info( + f"Based on rollout percentage of {experiment_settings.rollout_perc}%, enabling experiment {experiment_name}." + ) + enabled = True + + if enabled: + label = experiment_name + if experiment_name == LF_FLEET_EXPERIMENT: + # We give some special treatment to the "lf" experiment since determines the fleet we use + # - If it's enabled, then we always list it's prefix first + # - If we're in the canary branch, then we append ".c" to the lf prefix + if is_canary: + label += CANARY_FLEET_SUFFIX + fleet_prefix = label + else: + prefixes.append(label) - except Exception as e: + if len(prefixes) > 1: log.error( - f'Failed to determine if user has opted-in to feature {feature}. Using fallback "{fallback}". Exception: {e}' + f"Only a fleet and one other experiment can be enabled for a job at any time. Enabling {prefixes[0]} and ignoring the rest, which are {', '.join(prefixes[1:])}" ) - return fallback + prefixes = prefixes[:1] + + # Fleet always comes first + if fleet_prefix: + prefixes.insert(0, fleet_prefix) + + return ".".join(prefixes) + "." if prefixes else "" def get_rollout_state_from_issue(github_token: str, repo: str, issue_num: int) -> str: @@ -268,9 +409,10 @@ def main() -> None: args = parse_args() if args.github_ref_type == "branch" and is_exception_branch(args.github_branch): - log.info(f"Exception branch: '{args.github_branch}', using meta runners") - label_type = WORKFLOW_LABEL_META - runner_ami = RUNNER_AMI_LEGACY + log.info( + f"Exception branch: '{args.github_branch}', using Meta runners and no experiments." + ) + runner_label_prefix = DEFAULT_LABEL_PREFIX else: try: rollout_state = get_rollout_state_from_issue( @@ -285,35 +427,18 @@ def main() -> None: args.github_branch, ) - label_type = get_fleet( - rollout_state, - ( - args.github_issue_owner, - username, - ), - ) - runner_ami = get_optin_feature( - rollout_state=rollout_state, - workflow_requestors=( - args.github_issue_owner, - username, - ), - feature=RUNNER_AMI_AMZ2023, - fallback=RUNNER_AMI_LEGACY, + is_canary = args.github_repo == "pytorch/pytorch-canary" + + runner_label_prefix = get_runner_prefix( + rollout_state, (args.github_issue_owner, username), is_canary ) + except Exception as e: log.error( - f"Failed to get issue. Falling back to meta runners. Exception: {e}" + f"Failed to get issue. Defaulting to Meta runners and no experiments. Exception: {e}" ) - label_type = WORKFLOW_LABEL_META - runner_ami = RUNNER_AMI_LEGACY - - # For Canary builds use canary runners - if args.github_repo == "pytorch/pytorch-canary" and label_type == WORKFLOW_LABEL_LF: - label_type = WORKFLOW_LABEL_LF_CANARY - set_github_output(GH_OUTPUT_KEY_LABEL_TYPE, label_type) - set_github_output(GH_OUTPUT_KEY_AMI, runner_ami) + set_github_output(GH_OUTPUT_KEY_LABEL_TYPE, runner_label_prefix) if __name__ == "__main__": diff --git a/.github/scripts/test_runner_determinator.py b/.github/scripts/test_runner_determinator.py new file mode 100644 index 00000000000000..98063b190b0547 --- /dev/null +++ b/.github/scripts/test_runner_determinator.py @@ -0,0 +1,237 @@ +from unittest import main, TestCase +from unittest.mock import Mock, patch + +import runner_determinator as rd + + +class TestRunnerDeterminatorIssueParser(TestCase): + def test_parse_settings(self) -> None: + settings_text = """ + experiments: + lf: + rollout_perc: 25 + otherExp: + rollout_perc: 0 + --- + + Users: + @User1,lf + @User2,lf,otherExp + + """ + + settings = rd.parse_settings(settings_text) + + self.assertTupleEqual( + rd.Experiment(rollout_perc=25), + settings.experiments["lf"], + "lf settings not parsed correctly", + ) + self.assertTupleEqual( + rd.Experiment(rollout_perc=0), + settings.experiments["otherExp"], + "otherExp settings not parsed correctly", + ) + + def test_parse_settings_in_code_block(self) -> None: + settings_text = """ + + ``` + experiments: + lf: + rollout_perc: 25 + otherExp: + rollout_perc: 0 + + ``` + + --- + + Users: + @User1,lf + @User2,lf,otherExp + + """ + + settings = rd.parse_settings(settings_text) + + self.assertTupleEqual( + rd.Experiment(rollout_perc=25), + settings.experiments["lf"], + "lf settings not parsed correctly", + ) + self.assertTupleEqual( + rd.Experiment(rollout_perc=0), + settings.experiments["otherExp"], + "otherExp settings not parsed correctly", + ) + + def test_parse_users(self) -> None: + settings_text = """ + experiments: + lf: + rollout_perc: 25 + otherExp: + rollout_perc: 0 + --- + + Users: + @User1,lf + @User2,lf,otherExp + + """ + + users = rd.parse_users(settings_text) + self.assertDictEqual( + {"User1": ["lf"], "User2": ["lf", "otherExp"]}, + users, + "Users not parsed correctly", + ) + + def test_parse_users_without_settings(self) -> None: + settings_text = """ + + @User1,lf + @User2,lf,otherExp + + """ + + users = rd.parse_users(settings_text) + self.assertDictEqual( + {"User1": ["lf"], "User2": ["lf", "otherExp"]}, + users, + "Users not parsed correctly", + ) + + +class TestRunnerDeterminatorGetRunnerPrefix(TestCase): + def test_opted_in_user(self) -> None: + settings_text = """ + experiments: + lf: + rollout_perc: 25 + otherExp: + rollout_perc: 25 + --- + + Users: + @User1,lf + @User2,lf,otherExp + + """ + prefix = rd.get_runner_prefix(settings_text, ["User1"]) + self.assertEqual("lf.", prefix, "Runner prefix not correct for User1") + + def test_opted_in_user_two_experiments(self) -> None: + settings_text = """ + experiments: + lf: + rollout_perc: 25 + otherExp: + rollout_perc: 25 + --- + + Users: + @User1,lf + @User2,lf,otherExp + + """ + prefix = rd.get_runner_prefix(settings_text, ["User2"]) + self.assertEqual("lf.otherExp.", prefix, "Runner prefix not correct for User1") + + @patch("random.randint", return_value=50) + def test_opted_out_user(self, mock_randint: Mock) -> None: + settings_text = """ + experiments: + lf: + rollout_perc: 25 + otherExp: + rollout_perc: 25 + --- + + Users: + @User1,lf + @User2,lf,otherExp + + """ + prefix = rd.get_runner_prefix(settings_text, ["User3"]) + self.assertEqual("", prefix, "Runner prefix not correct for user") + + @patch("random.randint", return_value=10) + def test_opted_out_user_was_pulled_in_by_rollout(self, mock_randint: Mock) -> None: + settings_text = """ + experiments: + lf: + rollout_perc: 25 + otherExp: + rollout_perc: 25 + --- + + Users: + @User1,lf + @User2,lf,otherExp + + """ + + # User3 is opted out, but is pulled into both experiments by the 10% rollout + prefix = rd.get_runner_prefix(settings_text, ["User3"]) + self.assertEqual("lf.otherExp.", prefix, "Runner prefix not correct for user") + + def test_lf_prefix_always_comes_first(self) -> None: + settings_text = """ + experiments: + otherExp: + rollout_perc: 0 + lf: + rollout_perc: 0 + --- + + Users: + @User1,lf + @User2,otherExp,lf + + """ + + prefix = rd.get_runner_prefix(settings_text, ["User2"]) + self.assertEqual("lf.otherExp.", prefix, "Runner prefix not correct for user") + + def test_ignores_commented_users(self) -> None: + settings_text = """ + experiments: + lf: + rollout_perc: 0 + otherExp: + rollout_perc: 0 + --- + + Users: + #@User1,lf + @User2,lf,otherExp + + """ + + prefix = rd.get_runner_prefix(settings_text, ["User1"]) + self.assertEqual("", prefix, "Runner prefix not correct for user") + + def test_ignores_extra_experiments(self) -> None: + settings_text = """ + experiments: + lf: + rollout_perc: 0 + otherExp: + rollout_perc: 0 + foo: + rollout_perc: 0 + --- + + Users: + @User1,lf,otherExp,foo + + """ + + prefix = rd.get_runner_prefix(settings_text, ["User1"]) + self.assertEqual("lf.otherExp.", prefix, "Runner prefix not correct for user") + + +if __name__ == "__main__": + main() diff --git a/.github/workflows/_runner-determinator.yml b/.github/workflows/_runner-determinator.yml index 8ba709358c5ea1..5edb79da64463d 100644 --- a/.github/workflows/_runner-determinator.yml +++ b/.github/workflows/_runner-determinator.yml @@ -62,49 +62,94 @@ jobs: """ This runner determinator is used to determine which set of runners to run a GitHub job on. It uses the first comment of a GitHub issue (by default - https://github.com/pytorch/test-infra/issues/5132) as a user list to determine - which users will get their jobs to run on experimental runners. This user list - is also a comma separated list of additional features or experiments which the - user could be opted in to. + https://github.com/pytorch/test-infra/issues/5132) to define the configuration + of which runners should be used to run which job. + + The configuration has two parts, the settings and a list of opted-in users, + separated by a line containing "---". If the line is not present, the + settings are considered to be empty with only the second part, the user + list, defined. + + The first part is a YAML block that defines the rollout settings. This can be + used to define any settings that are needed to determine which runners to use. + It's fields are defined by the RolloutSettings class below. + + The second part is a list of users who are explicitly opted in to the LF fleet. + The user list is also a comma separated list of additional features or + experiments which the user could be opted in to. The user list has the following rules: - - Users are GitHub usernames with the @ prefix - - If the first line is a "*" then all users will use the new runners - - If the first line is a "!" then all users will use the old runners + - Users are GitHub usernames, which must start with the @ prefix - Each user is also a comma-separated list of features/experiments to enable - - A "#" prefix indicates the user is opted out of the new runners but is opting - into features/experiments. + - A "#" prefix opts the user out of all experiments + + Example config: + # A list of experiments that can be opted into. + # This defines the behavior they'll induce when opted into. + # Expected syntax is: + # [experiment_name]: # Name of the experiment. Also used for the label prefix. + # rollout_perc: [int] # % of workflows to run with this experiment when users are not opted in. + + experiments: + lf: + rollout_percent: 25 - Example user list: + --- - @User1 - @User2,amz2023 - #@UserOptOutOfNewRunner,amz2023 + # Opt-ins: + # Users can opt into the LF fleet by adding their GitHub username to this list + # and specifying experiments to enable in a comma-separated list. + # Experiments should be from the above list. + + @User1,lf,split_build + @User2,lf + @User3,split_build """ import logging import os + import random from argparse import ArgumentParser from logging import LogRecord - from typing import Any, Iterable + from typing import Any, Dict, Iterable, List, NamedTuple, Tuple + import yaml from github import Auth, Github from github.Issue import Issue - WORKFLOW_LABEL_META = "" # use meta runners + DEFAULT_LABEL_PREFIX = "" # use meta runners WORKFLOW_LABEL_LF = "lf." # use runners from the linux foundation WORKFLOW_LABEL_LF_CANARY = "lf.c." # use canary runners from the linux foundation - RUNNER_AMI_LEGACY = "" - RUNNER_AMI_AMZ2023 = "amz2023" - GITHUB_OUTPUT = os.getenv("GITHUB_OUTPUT", "") GH_OUTPUT_KEY_AMI = "runner-ami" GH_OUTPUT_KEY_LABEL_TYPE = "label-type" + SETTING_EXPERIMENTS = "experiments" + + LF_FLEET_EXPERIMENT = "lf" + CANARY_FLEET_SUFFIX = ".c" + + + class Experiment(NamedTuple): + rollout_perc: int = ( + 0 # Percentage of workflows to experiment on when user is not opted-in. + ) + + # Add more fields as needed + + + class Settings(NamedTuple): + """ + Settings for the experiments that can be opted into. + """ + + experiments: Dict[str, Experiment] = {} + + class ColorFormatter(logging.Formatter): """Color codes the log messages based on the log level""" @@ -231,85 +276,181 @@ jobs: return branch.split("/")[0] in {"main", "nightly", "release", "landchecks"} - def get_fleet(rollout_state: str, workflow_requestors: Iterable[str]) -> str: + def load_yaml(yaml_text: str) -> Any: + try: + data = yaml.safe_load(yaml_text) + return data + except yaml.YAMLError as exc: + log.exception("Error loading YAML") + raise + + + def extract_settings_user_opt_in_from_text(rollout_state: str) -> Tuple[str, str]: + """ + Extracts the text with settings, if any, and the opted in users from the rollout state. + + If the issue body contains "---" then the text above that is the settings + and the text below is the list of opted in users. + + If it doesn't contain "---" then the settings are empty and the rest is the users. """ - Determines if the job should run on the LF fleet or the Meta fleet + rollout_state_parts = rollout_state.split("---") + if len(rollout_state_parts) >= 2: + return rollout_state_parts[0], rollout_state_parts[1] + else: + return "", rollout_state + - Returns: - The appropriate label prefix for the runner, corresponding to the fleet to use. - This gets prefixed to the very start of the runner label. + class UserOptins(Dict[str, List[str]]): + """ + Dictionary of users with a list of features they have opted into """ + + def parse_user_opt_in_from_text(user_optin_text: str) -> UserOptins: + """ + Parse the user opt-in text into a key value pair of username and the list of features they have opted into + + Users are GitHub usernames with the @ prefix. Each user is also a comma-separated list of features/experiments to enable. + - Example line: "@User1,lf,split_build" + - A "#" prefix indicates the user is opted out of all experiments + + + """ + optins = UserOptins() + for user in user_optin_text.split("\n"): + user = user.strip("\r\n\t -") + if not user or not user.startswith("@"): + # Not a valid user. Skip + continue + + if user: + usr_name = user.split(",")[0].strip("@") + optins[usr_name] = [exp.strip(" ") for exp in user.split(",")[1:]] + + return optins + + + def parse_settings_from_text(settings_text: str) -> Settings: + """ + Parse the experiments from the issue body into a list of ExperimentSettings + """ try: - if rollout_state[0] == "!": - log.info("LF Workflows are disabled for everyone. Using meta runners.") - return WORKFLOW_LABEL_META - elif rollout_state[0] == "*": - log.info("LF Workflows are enabled for everyone. Using LF runners.") - return WORKFLOW_LABEL_LF - else: - all_opted_in_users = { - usr_raw.strip("\n\t@ ").split(",")[0] - for usr_raw in rollout_state.split() - } - opted_in_requestors = { - usr for usr in workflow_requestors if usr in all_opted_in_users - } - if opted_in_requestors: - log.info( - f"LF Workflows are enabled for {', '.join(opted_in_requestors)}. Using LF runners." - ) - return WORKFLOW_LABEL_LF - else: - log.info( - f"LF Workflows are disabled for {', '.join(workflow_requestors)}. Using meta runners." - ) - return WORKFLOW_LABEL_META + if settings_text: + # Escape the backtick as well so that we can have the settings in a code block on the GH issue + # for easy reading + # Note: Using ascii for the backtick so that the cat step in _runner-determinator.yml doesn't choke on + # the backtick character in shell commands. + backtick = chr(96) # backtick character + settings_text = settings_text.strip(f"\r\n\t{backtick} ") + settings = load_yaml(settings_text) + + # For now we just load experiments. We can expand this if/when we add more settings + experiments = {} + + for exp_name, exp_settings in settings.get(SETTING_EXPERIMENTS).items(): + valid_settings = {} + for setting in exp_settings: + if setting not in Experiment._fields: + log.warning( + f"Unexpected setting in experiment: {setting} = {exp_settings[setting]}" + ) + else: + valid_settings[setting] = exp_settings[setting] + + experiments[exp_name] = Experiment(**valid_settings) + return Settings(experiments) except Exception as e: - log.error( - f"Failed to get determine workflow type. Falling back to meta runners. Exception: {e}" - ) - return WORKFLOW_LABEL_META + log.exception("Failed to parse settings") + return Settings() - def get_optin_feature( - rollout_state: str, workflow_requestors: Iterable[str], feature: str, fallback: str - ) -> str: + + def parse_settings(rollout_state: str) -> Settings: """ - Used to dynamically opt in jobs to specific runner-type variants. + Parse settings, if any, from the rollout state. + + If the issue body contains "---" then the text above that is the settings + and the text below is the list of opted in users. - Returns: - The runner-type's variant name if the user has opted in to the feature, otherwise returns an empty string. - This variant name is prefixed to the runner-type in the label. + If it doesn't contain "---" then the settings are empty and the default values are used. """ - try: - userlist = {u.lstrip("#").strip("\n\t@ ") for u in rollout_state.split()} - all_opted_in_users = set() - for user in userlist: - for i in user.split(","): - if i == feature: - all_opted_in_users.add(user.split(",")[0]) - opted_in_requestors = { - usr for usr in workflow_requestors if usr in all_opted_in_users - } - - if opted_in_requestors: + settings_text, _ = extract_settings_user_opt_in_from_text(rollout_state) + return parse_settings_from_text(settings_text) + + + def parse_users(rollout_state: str) -> UserOptins: + """ + Parse users from the rollout state. + + """ + _, users_text = extract_settings_user_opt_in_from_text(rollout_state) + return parse_user_opt_in_from_text(users_text) + + + def is_user_opted_in(user: str, user_optins: UserOptins, experiment_name: str) -> bool: + """ + Check if a user is opted into an experiment + """ + return experiment_name in user_optins.get(user, []) + + + def get_runner_prefix( + rollout_state: str, workflow_requestors: Iterable[str], is_canary: bool = False + ) -> str: + settings = parse_settings(rollout_state) + user_optins = parse_users(rollout_state) + + fleet_prefix = "" + prefixes = [] + for experiment_name, experiment_settings in settings.experiments.items(): + enabled = False + + # Is any workflow_requestor opted in to this experiment? + opted_in_users = [ + requestor + for requestor in workflow_requestors + if is_user_opted_in(requestor, user_optins, experiment_name) + ] + + if opted_in_users: log.info( - f"Feature {feature} is enabled for {', '.join(opted_in_requestors)}. Using feature {feature}." + f"{', '.join(opted_in_users)} have opted into experiment {experiment_name}." ) - return feature + enabled = True else: - log.info( - f"Feature {feature} is disabled for {', '.join(workflow_requestors)}. Using fallback \"{fallback}\"." - ) - return fallback + # If no user is opted in, then we randomly enable the experiment based on the rollout percentage + r = random.randint(1, 100) + if r <= experiment_settings.rollout_perc: + log.info( + f"Based on rollout percentage of {experiment_settings.rollout_perc}%, enabling experiment {experiment_name}." + ) + enabled = True + + if enabled: + label = experiment_name + if experiment_name == LF_FLEET_EXPERIMENT: + # We give some special treatment to the "lf" experiment since determines the fleet we use + # - If it's enabled, then we always list it's prefix first + # - If we're in the canary branch, then we append ".c" to the lf prefix + if is_canary: + label += CANARY_FLEET_SUFFIX + fleet_prefix = label + else: + prefixes.append(label) - except Exception as e: + if len(prefixes) > 1: log.error( - f'Failed to determine if user has opted-in to feature {feature}. Using fallback "{fallback}". Exception: {e}' + f"Only a fleet and one other experiment can be enabled for a job at any time. Enabling {prefixes[0]} and ignoring the rest, which are {', '.join(prefixes[1:])}" ) - return fallback + prefixes = prefixes[:1] + + # Fleet always comes first + if fleet_prefix: + prefixes.insert(0, fleet_prefix) + + return ".".join(prefixes) + "." if prefixes else "" def get_rollout_state_from_issue(github_token: str, repo: str, issue_num: int) -> str: @@ -327,9 +468,10 @@ jobs: args = parse_args() if args.github_ref_type == "branch" and is_exception_branch(args.github_branch): - log.info(f"Exception branch: '{args.github_branch}', using meta runners") - label_type = WORKFLOW_LABEL_META - runner_ami = RUNNER_AMI_LEGACY + log.info( + f"Exception branch: '{args.github_branch}', using Meta runners and no experiments." + ) + runner_label_prefix = DEFAULT_LABEL_PREFIX else: try: rollout_state = get_rollout_state_from_issue( @@ -344,35 +486,18 @@ jobs: args.github_branch, ) - label_type = get_fleet( - rollout_state, - ( - args.github_issue_owner, - username, - ), - ) - runner_ami = get_optin_feature( - rollout_state=rollout_state, - workflow_requestors=( - args.github_issue_owner, - username, - ), - feature=RUNNER_AMI_AMZ2023, - fallback=RUNNER_AMI_LEGACY, + is_canary = args.github_repo == "pytorch/pytorch-canary" + + runner_label_prefix = get_runner_prefix( + rollout_state, (args.github_issue_owner, username), is_canary ) + except Exception as e: log.error( - f"Failed to get issue. Falling back to meta runners. Exception: {e}" + f"Failed to get issue. Defaulting to Meta runners and no experiments. Exception: {e}" ) - label_type = WORKFLOW_LABEL_META - runner_ami = RUNNER_AMI_LEGACY - - # For Canary builds use canary runners - if args.github_repo == "pytorch/pytorch-canary" and label_type == WORKFLOW_LABEL_LF: - label_type = WORKFLOW_LABEL_LF_CANARY - set_github_output(GH_OUTPUT_KEY_LABEL_TYPE, label_type) - set_github_output(GH_OUTPUT_KEY_AMI, runner_ami) + set_github_output(GH_OUTPUT_KEY_LABEL_TYPE, runner_label_prefix) if __name__ == "__main__": From ca5b9c019a291fd13dba0d455340a7e3f28ab63a Mon Sep 17 00:00:00 2001 From: eqy Date: Thu, 5 Sep 2024 22:22:43 +0000 Subject: [PATCH 0494/1018] [cuDNN][64-bit indexing] cuDNN v9.3+ supports non-batch-splittable convolutions with > 2**31 elements (#134890) For longstanding issues such as #95024 Pull Request resolved: https://github.com/pytorch/pytorch/pull/134890 Approved by: https://github.com/Skylion007 --- aten/src/ATen/native/Convolution.cpp | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/aten/src/ATen/native/Convolution.cpp b/aten/src/ATen/native/Convolution.cpp index 97a51b19377003..4a72c102777d15 100644 --- a/aten/src/ATen/native/Convolution.cpp +++ b/aten/src/ATen/native/Convolution.cpp @@ -422,7 +422,13 @@ struct ConvParams { // that maybe the compiledWithCuDNN() check sometimes segfaults (though I can't imagine how) #if !defined(C10_MOBILE) if (needs_64bit_indexing_no_split(input, weight)) { - return false; + static long cudnn_version = detail::getCUDAHooks().versionCuDNN(); + if (!(cudnn_version >= 90300 && at::native::cudnnv8_enabled_check_debug())) { + TORCH_WARN_ONCE("cuDNN cannot be used for large non-batch-splittable convolutions" + " if the V8 API is not enabled or before cuDNN version 9.3+." + " Consider upgrading cuDNN and/or enabling the V8 API for better efficiency."); + return false; + } } if (!detail::getCUDAHooks().compiledWithCuDNN()) { return false; From ac4a2f53e50a776066a9711afe3d4501f5216ec9 Mon Sep 17 00:00:00 2001 From: "Edward Z. Yang" Date: Thu, 5 Sep 2024 15:27:03 -0400 Subject: [PATCH 0495/1018] Add torch._logging.scribe (#135224) See https://github.com/pytorch/pytorch/pull/135138 for a usage example. Meta only, see https://docs.google.com/document/d/1JpbAQvRhTmuxjnKKjT7qq57dsnV84nxSLpWJo1abJuE/edit#heading=h.9wi46k7np6xw for context fbscribelogger is a library that allows us to write to scribe, which is Meta's logging infrastructure, when you have appropriate access token (this token is available for jobs running on main, as well as authorized jobs with the ci-scribe label). The resulting data is accessible via Scuba (a real time in-memory database) and Hive (a more traditional SQL persisted database). Here's the motivating use case. Suppose there is somewhere in PyTorch's codebase where you'd like to log an event, and then you'd like to find all the situations where this log is called. If PyTorch is rolled out to our internal users, we have some FB-oriented APIs (like torch._utils_internal.signpost_event) with which you can do this. But you have to actually land your PR to main, wait for it to be ingested to fbcode, and then wait for us to actually roll out this version, before you get any data. But what if you want the results within the next few hours? Instead, you can use torch._logging.scribe to directly write to our logging infrastructure *from inside CI jobs.* The most convenient approach is to log unstructured JSON blobs to `open_source_signpost` (added in this PR; you can also add your own dedicated table as described in the GDoc above). After adding logging code to your code, you can push your PR to CI, add 'ci-scribe' label, and in a few hours view the results in Scuba, e.g., (Meta-only) https://fburl.com/scuba/torch_open_source_signpost/z2mq8o4l If you want continuous logging on all commits on master, you can land your PR and it will be continuously get logging for all CI runs that happen on main. Eventually, if your dataset is important enough, you can consider collaborating with PyTorch Dev Infra to get the data collected in our public AWS cloud so that OSS users can view it without access to Meta's internal users. But this facility is really good for prototyping / one-off experiments. It's entirely self serve: just add your logging, run your PR CI with ci-scribe, get results, do analysis in Scuba. Signed-off-by: Edward Z. Yang Pull Request resolved: https://github.com/pytorch/pytorch/pull/135224 Approved by: https://github.com/Skylion007 --- .ci/docker/requirements-ci.txt | 5 ++ .../requirements/pip-requirements-macOS.txt | 1 + .github/workflows/lint.yml | 2 +- torch/_logging/scribe.py | 61 +++++++++++++++++++ 4 files changed, 68 insertions(+), 1 deletion(-) create mode 100644 torch/_logging/scribe.py diff --git a/.ci/docker/requirements-ci.txt b/.ci/docker/requirements-ci.txt index 07626204e6683c..33166eb1518723 100644 --- a/.ci/docker/requirements-ci.txt +++ b/.ci/docker/requirements-ci.txt @@ -36,6 +36,11 @@ expecttest==0.2.1 #Pinned versions: 0.2.1 #test that import: +fbscribelogger==0.1.6 +#Description: write to scribe from authenticated jobs on CI +#Pinned versions: 0.1.6 +#test that import: + flatbuffers==2.0 #Description: cross platform serialization library #Pinned versions: 2.0 diff --git a/.github/requirements/pip-requirements-macOS.txt b/.github/requirements/pip-requirements-macOS.txt index 64396398274ac4..caa831b882312c 100644 --- a/.github/requirements/pip-requirements-macOS.txt +++ b/.github/requirements/pip-requirements-macOS.txt @@ -1,6 +1,7 @@ boto3==1.19.12 hypothesis==6.56.4 expecttest==0.2.1 +fbscribelogger==0.1.6 librosa>=0.6.2 mpmath==1.3.0 networkx==2.8.7 diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index 84a2a370297d8c..b0427b87bb16ab 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -223,7 +223,7 @@ jobs: cache: pip - name: Install dependencies run: | - pip install pytest-rerunfailures==11.1.* pytest-flakefinder==1.1.* pytest-xdist==3.3.* expecttest==0.2.* numpy==1.24.* + pip install pytest-rerunfailures==11.1.* pytest-flakefinder==1.1.* pytest-xdist==3.3.* expecttest==0.2.* fbscribelogger==0.1.* numpy==1.24.* pip install torch --pre --index-url https://download.pytorch.org/whl/nightly/cpu/ - name: Run run_test.py (nonretryable) run: | diff --git a/torch/_logging/scribe.py b/torch/_logging/scribe.py new file mode 100644 index 00000000000000..18745e468b5e6a --- /dev/null +++ b/torch/_logging/scribe.py @@ -0,0 +1,61 @@ +from typing import Callable, List, Union +from typing_extensions import TypeAlias + + +try: + from fbscribelogger import make_scribe_logger # type: ignore[import-untyped] +except ImportError: + TAtom: TypeAlias = Union[int, float, bool, str] + TField: TypeAlias = Union[TAtom, List[TAtom]] + TLazyField: TypeAlias = Union[TField, Callable[[], TField]] + + def make_scribe_logger(name: str, thrift_src: str) -> Callable[..., None]: + def inner(**kwargs: TLazyField) -> None: + pass + + return inner + + +open_source_signpost = make_scribe_logger( + "TorchOpenSourceSignpost", + """ +struct TorchOpenSourceSignpostLogEntry { + + # The commit SHA that triggered the workflow, e.g., 02a6b1d30f338206a71d0b75bfa09d85fac0028a. Derived from GITHUB_SHA. + 4: optional string commit_sha; + + # Commit date (not author date) of the commit in commit_sha as timestamp, e.g., 1724208105. Increasing if merge bot is used, though not monotonic; duplicates occur when stack is landed. + 5: optional i64 commit_date; + + # The fully-formed ref of the branch or tag that triggered the workflow run, e.g., refs/pull/133891/merge or refs/heads/main. Derived from GITHUB_REF. + 6: optional string github_ref; + + # Indicates if branch protections or rulesets are configured for the ref that triggered the workflow run. Derived from GITHUB_REF_PROTECTED. + 7: optional bool github_ref_protected; + + # A unique number for each attempt of a particular workflow run in a repository, e.g., 1. Derived from GITHUB_RUN_ATTEMPT. + 8: optional string github_run_attempt; + + # A unique number for each workflow run within a repository, e.g., 19471190684. Derived from GITHUB_RUN_ID. + 9: optional string github_run_id; + + # A unique number for each run of a particular workflow in a repository, e.g., 238742. Derived from GITHUB_RUN_NUMBER. + 10: optional string github_run_number_str; + + # The name of the current job. Derived from JOB_NAME, e.g., linux-jammy-py3.8-gcc11 / test (default, 3, 4, amz2023.linux.2xlarge). + 11: optional string job_name; + + # The GitHub user who triggered the job. Derived from GITHUB_TRIGGERING_ACTOR. + 12: optional string github_triggering_actor; + 13: optional string name; # Event name + 14: optional string parameters; # Parameters (JSON data) + 16: optional string subsystem; # Subsystem the event is associated with + + # The unit timestamp in second for the Scuba Time Column override + 17: optional i64 time; + + # The weight of the record according to current sampling rate + 18: optional i64 weight; +} +""", # noqa: B950 +) From f4757a296682207585fee9d1c76121841cf3fd51 Mon Sep 17 00:00:00 2001 From: Yidi Wu Date: Wed, 4 Sep 2024 16:25:03 -0700 Subject: [PATCH 0496/1018] [cond] fix typo in cond codegen (#134708) As titled. Pull Request resolved: https://github.com/pytorch/pytorch/pull/134708 Approved by: https://github.com/jansel --- torch/_inductor/ir.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch/_inductor/ir.py b/torch/_inductor/ir.py index 3c2e36286a7abd..229751410345be 100644 --- a/torch/_inductor/ir.py +++ b/torch/_inductor/ir.py @@ -6524,7 +6524,7 @@ def create( subgraph.graph.run(*fake_operands) true_outputs = true_fn.graph.graph_outputs # type: ignore[union-attr] - false_outputs = true_fn.graph.graph_outputs # type: ignore[union-attr] + false_outputs = false_fn.graph.graph_outputs # type: ignore[union-attr] for name, outputs in (("true_fn", true_outputs), ("false_fn", false_outputs)): if _has_aliased_buffers(true_outputs): From 6640e0d06efbda139fd900aac435f4754f2aebcc Mon Sep 17 00:00:00 2001 From: Stonepia Date: Thu, 5 Sep 2024 22:46:31 +0000 Subject: [PATCH 0497/1018] [Intel Triton] Update Intel Triton to release/2.5.0 (#134074) This PR relands https://github.com/pytorch/pytorch/pull/134053 Pull Request resolved: https://github.com/pytorch/pytorch/pull/134074 Approved by: https://github.com/EikanWang --- .ci/docker/ci_commit_pins/triton-xpu.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.ci/docker/ci_commit_pins/triton-xpu.txt b/.ci/docker/ci_commit_pins/triton-xpu.txt index 3e47e877deb5fd..4b77bbfc0f8fa6 100644 --- a/.ci/docker/ci_commit_pins/triton-xpu.txt +++ b/.ci/docker/ci_commit_pins/triton-xpu.txt @@ -1 +1 @@ -1b2f15840e0d70eec50d84c7a0575cb835524def +cc981feba10a3f4c2e46f3fe368e8fcf5f5643df From 7f868f48c4c15fe2e3677c1a76e6c72a6135f4fa Mon Sep 17 00:00:00 2001 From: Chirag Pandya Date: Thu, 5 Sep 2024 23:18:10 +0000 Subject: [PATCH 0498/1018] [rfc] scuba for flight recorder (#134794) Summary: Record flight recorder status in a scuba table. Test Plan: Testing with timing out a job. Will post results soon. Differential Revision: D61729221 Pull Request resolved: https://github.com/pytorch/pytorch/pull/134794 Approved by: https://github.com/fduwjj --- .../distributed/c10d/ProcessGroupNCCL.cpp | 43 ++++++++++++++++--- .../distributed/c10d/ProcessGroupNCCL.hpp | 3 +- 2 files changed, 39 insertions(+), 7 deletions(-) diff --git a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp index 3c0f237c5774d5..d8bd66db1c0d47 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp +++ b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp @@ -1093,8 +1093,18 @@ void ProcessGroupNCCL::waitForFutureOrTimeout( std::future& fut, const std::chrono::milliseconds& timeOutMilSec, const std::string& futDescription, - bool throwException) { + bool throwException, + bool log) { std::string errorMsg; + + ::c10d::C10dLoggingData data; + if (log) { + data.integers["pg_id"] = local_id_; + data.integers["rank"] = rank_; + data.integers["global_rank"] = globalRank(); + data.strings["flight_recorder_version"] = c10d::version_val_str; + } + TORCH_CHECK(fut.valid(), "Expected a valid future"); std::future_status status = fut.wait_for(timeOutMilSec); if (status == std::future_status::ready) { @@ -1105,20 +1115,31 @@ void ProcessGroupNCCL::waitForFutureOrTimeout( if (result) { LOG(INFO) << logPrefix() << "future is successfully executed for: " << futDescription; + if (log) { + data.strings["status"] = "SUCCESS"; + } } } catch (const std::exception& e) { errorMsg = c10::str( logPrefix(), - "Exception thrown when waitng for future ", + "Exception thrown when waiting for future ", futDescription, ": ", e.what()); + if (log) { + data.strings["status"] = "EXCEPTION"; + data.strings["exception"] = e.what(); + } LOG(ERROR) << errorMsg; } catch (...) { errorMsg = c10::str( logPrefix(), - "Unknown exception thrown when waitng for future ", + "Unknown exception thrown when waiting for future ", futDescription); + if (log) { + data.strings["status"] = "EXCEPTION"; + data.strings["exception"] = "Unknown exception"; + } LOG(ERROR) << errorMsg; } } else { @@ -1129,8 +1150,15 @@ void ProcessGroupNCCL::waitForFutureOrTimeout( " timed out after ", timeOutMilSec.count(), " ms"); + data.strings["status"] = "TIMEOUT"; LOG(ERROR) << errorMsg; } + if (log) { + auto logger = c10d::C10dLogger::getLogger(); + if (logger) { + logger->log(data); + } + } if (throwException && !errorMsg.empty()) { C10_THROW_ERROR(DistBackendError, errorMsg); } @@ -1204,7 +1232,8 @@ void ProcessGroupNCCL::shutdown(std::optional reason) { std::future fut = std::async( std::launch::async, [this, &reason]() { return this->abort(reason); }); - waitForFutureOrTimeout(fut, options_->timeout, "ProcessGroup abort", true); + waitForFutureOrTimeout( + fut, options_->timeout, "ProcessGroup abort", true, false); LOG(INFO) << logPrefix() << "ProcessGroupNCCL aborts successfully."; // We need to wait for abort to finish before we can safely shut down @@ -1466,11 +1495,13 @@ void ProcessGroupNCCL::heartbeatMonitor() { std::future asyncDebugDump = std::async( std::launch::async, [this]() { return this->dumpDebuggingInfo(); }); - // wait for the dump until timeout + // wait for the dump until timeout - log data waitForFutureOrTimeout( asyncDebugDump, std::chrono::milliseconds(waitTimeoutDumpInMilSec_), - "Flight recorder dump in heartbeatMonitor"); + "Flight recorder dump in heartbeatMonitor", + false, + true); } if (get_gil_checker() != nullptr) { diff --git a/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp b/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp index cc3455e7237904..10a7d20947433a 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp +++ b/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp @@ -902,7 +902,8 @@ class TORCH_API ProcessGroupNCCL : public Backend { std::future& fut, const std::chrono::milliseconds& timeOutMilSec, const std::string& futDescription, - bool throwException = false); + bool throwException = false, + bool log = false); // When watchdog timeout, this function will be called and return debug info // for users. For now we only get information from retrieveDesyncReport. From d84a62b4ddaec5c90fd58693c4722795309d8e8d Mon Sep 17 00:00:00 2001 From: Avik Chaudhuri Date: Thu, 5 Sep 2024 23:20:31 +0000 Subject: [PATCH 0499/1018] fix fake tensor tolist implementation (#135131) Summary: When exporting for training with `tolist`, we do not hit `FunctionalTensor.tolist` since we do not functionalize. Unfortunately, this means we hit `FakeTensor.tolist`, which creates unbacked symints that are not backed by proxies. Rather than trying to patch up this low-level implementation, we replace it with essentially what `FunctionalTensor.tolist` does, which is higher-level: we essentially desugar to `item()` calls and let it take care of unbacked symints. Test Plan: Some expected failures are gone now. Also found a test for `tolist` that was written when `FunctionalTensor.tolist` was implemented but not really doing much; repurposed it now to exercise more modes. Differential Revision: D62197742 Pull Request resolved: https://github.com/pytorch/pytorch/pull/135131 Approved by: https://github.com/ezyang --- test/dynamo/test_subclasses.py | 1 + test/export/test_export.py | 17 +++++----- test/test_nestedtensor.py | 21 +++++++++---- torch/_subclasses/fake_tensor.py | 31 ++++++------------- .../_internal/common_methods_invocations.py | 2 -- 5 files changed, 33 insertions(+), 39 deletions(-) diff --git a/test/dynamo/test_subclasses.py b/test/dynamo/test_subclasses.py index 7137fd419b32be..fdc74592ec4d23 100644 --- a/test/dynamo/test_subclasses.py +++ b/test/dynamo/test_subclasses.py @@ -2239,6 +2239,7 @@ def fn(x, y): for ref_v, res_v in zip(values_copy, values): self.assertEqual(ref_v.grad, res_v.grad) + @torch._dynamo.config.patch({"capture_scalar_outputs": True}) def test_unbind(self): # NB: If we have shape e.g. (3, j0, 3), duck sizing will give us (s0, s1, s0). # This causes a recompile later on when it realizes the batch and last dim diff --git a/test/export/test_export.py b/test/export/test_export.py index 1d65bf7683537f..4220f21ecb14d0 100644 --- a/test/export/test_export.py +++ b/test/export/test_export.py @@ -2286,8 +2286,6 @@ def forward(self, t): M = M_v3 export(N(), (t,), strict=strict) - @testing.expectedFailureTrainingIRToRunDecomp - @testing.expectedFailureTrainingIRToRunDecompNonStrict # unbacked symint not tracked? @testing.expectedFailureSerDer # T195866111 def test_suggested_fixes_for_data_dependent_errors_puzzlers(self): # suggested fixes for data-dependent errors only work in non-strict mode @@ -2418,6 +2416,14 @@ def forward(self, xs, y): strict=strict, ) + def test_tolist(self): + class M(torch.nn.Module): + def forward(self, x): + return x.tolist() + + ep = export(M(), (torch.ones(3, dtype=torch.int),)) + self.assertEqual(ep.module()(torch.tensor([1, 2, 3])), [1, 2, 3]) + def test_if_functional(self): class Module(torch.nn.Module): def forward(self, x): @@ -7920,13 +7926,6 @@ def forward(self, x): arg = node.args[0] self.assertTrue(arg.op == "placeholder") - def test_tolist_nonstrict_output(self): - class M(torch.nn.Module): - def forward(self, x): - x.tolist() - - ep = torch.export.export(M(), (torch.ones(3),), strict=False) - def test_preserve_non_cia_op(self): class M(torch.nn.Module): def forward(self, x): diff --git a/test/test_nestedtensor.py b/test/test_nestedtensor.py index 1c8f23e8c41fba..d9fcd7caaf0c77 100644 --- a/test/test_nestedtensor.py +++ b/test/test_nestedtensor.py @@ -1025,9 +1025,14 @@ def test_embedding(self, device, layout): ) emb = torch.nn.Embedding(100, 8, device=device) y = emb(x) - ys = y.unbind() - for i, inp in enumerate(inputs): - self.assertEqual(emb(inp), ys[i]) + + @torch._dynamo.disable + def check(inputs, y): + ys = y.unbind() + for i, inp in enumerate(inputs): + self.assertEqual(emb(inp), ys[i]) + + check(inputs, y) @skipMeta @torch.inference_mode() @@ -7120,9 +7125,13 @@ def test_unbind_backward(self, device, dtype): a, b, c = nt.unbind() b.sum().backward() - expected_grad = torch.zeros_like(nt) - expected_grad.unbind()[1].add_(1.0) - torch._dynamo.disable(self.assertEqual)(nt.grad, expected_grad) + @torch._dynamo.disable + def check(nt): + expected_grad = torch.zeros_like(nt) + expected_grad.unbind()[1].add_(1.0) + self.assertEqual(nt.grad, expected_grad) + + check(nt) FORWARD_FAILURES = { diff --git a/torch/_subclasses/fake_tensor.py b/torch/_subclasses/fake_tensor.py index 27348dca80e483..4867b43e84c2ca 100644 --- a/torch/_subclasses/fake_tensor.py +++ b/torch/_subclasses/fake_tensor.py @@ -14,6 +14,7 @@ from collections import defaultdict from dataclasses import dataclass from typing import ( + Any, Callable, cast, Dict, @@ -893,28 +894,14 @@ def get_nested_int( ) return self.nested_int_memo * coeff - # We must handle tolist in a special way for FakeTensors here in the case - # where tolist is called from torch dispatch for tensor subclasses. - # Ordinarily, if a program calls .tolist compiling still works because there is - # special handling in dynamo, but for tensor subclasses if .tolist is called - # inside torch dispatch, the .tolist call may be directly on a FakeTensor. - # This would result in an error since wrapper subclasses don't have storage. - # To avoid this, we handle the FakeTensor case by (1) specializing on the size - # of the tensor to create the output Python list, and (2) creating unbacked - # symints for each element of the list. - def tolist(self) -> List[SymInt]: - assert self.dim() == 1, "NYI for higher dims" - shape_env = self.fake_mode.shape_env - assert shape_env is not None - out = [] - # Specialize on the length of the list - for _ in range(self.shape[0]): - s = shape_env.create_unbacked_symint() - # max value? - torch._check_is_size(s) - torch._check(s >= 2) - out.append(s) - return out + # Similar to FunctionalTensor.tolist + def tolist(self) -> Any: + if self.dim() == 0: + return self.item() + elif self.dim() == 1: + return [elem.item() for elem in self] + else: + return [elem.tolist() for elem in self] _MetadataIntLike = Union[IntLikeType, "_PySymInputStub", "_SymIntOutputStub"] diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py index 869270a600bc90..e1dc92edb6302a 100644 --- a/torch/testing/_internal/common_methods_invocations.py +++ b/torch/testing/_internal/common_methods_invocations.py @@ -23797,8 +23797,6 @@ def sample_inputs_alias_copy(op_info, device, dtype, requires_grad, **kwargs): "_refs.tensor_split", torch_opinfo_name="tensor_split", skips=( - # TensorMeta doesn't support tolist - DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_meta'), # RuntimeError: no _refs support for torch.Tensor.tolist DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref'), ), From fd35f1ca5af15c7cc46d0b569aac5ef3af58f1f4 Mon Sep 17 00:00:00 2001 From: Kulin Seth Date: Thu, 5 Sep 2024 23:23:15 +0000 Subject: [PATCH 0500/1018] [MPS] Add support for autocast in MPS (#99272) Fixes https://github.com/pytorch/pytorch/issues/88415 Need to run inductor/test_cpu_select_algorithm Pull Request resolved: https://github.com/pytorch/pytorch/pull/99272 Approved by: https://github.com/malfet Co-authored-by: Siddharth Kotapati Co-authored-by: Nikita Shulga <2453524+malfet@users.noreply.github.com> Co-authored-by: Roy Hvaara --- aten/src/ATen/autocast_mode.cpp | 114 +++++++++++++++++++++++++- aten/src/ATen/autocast_mode.h | 27 +++++- aten/src/ATen/core/interned_strings.h | 1 + c10/core/DispatchKey.cpp | 3 + c10/core/DispatchKey.h | 1 + c10/core/DispatchKeySet.h | 5 ++ test/test_autocast.py | 49 +++++++++++ test/test_mps.py | 23 ++++++ torch/amp/autocast_mode.py | 9 ++ torch/csrc/jit/passes/autocast.cpp | 2 +- torch/csrc/utils/python_dispatch.cpp | 1 + 11 files changed, 232 insertions(+), 3 deletions(-) diff --git a/aten/src/ATen/autocast_mode.cpp b/aten/src/ATen/autocast_mode.cpp index 10fb72796fc647..4766476d23377e 100644 --- a/aten/src/ATen/autocast_mode.cpp +++ b/aten/src/ATen/autocast_mode.cpp @@ -69,7 +69,7 @@ thread_local std::array at::ScalarType::Undefined, // Vulkan at::ScalarType::Undefined, // Metal at::kHalf, // XPU - at::ScalarType::Undefined, // MPS + at::kHalf, // MPS at::ScalarType::Undefined, // Meta (tensors with no data) at::kBFloat16, // HPU / HABANA at::ScalarType::Undefined, // SX-Aurora / NEC @@ -206,6 +206,118 @@ TORCH_LIBRARY_IMPL(aten, Autocast, m) { TORCH_FN((&at::autocast::binary_cross_entropy_banned))); } +TORCH_LIBRARY_IMPL(_, AutocastMPS, m) { + m.fallback(torch::CppFunction::makeFallthrough()); +} + +TORCH_LIBRARY_IMPL(aten, AutocastMPS, m) { + // lower_precision_fp + KERNEL_MPS2(_convolution, deprecated, lower_precision_fp) + KERNEL_MPS(_convolution, lower_precision_fp) + KERNEL_MPS(conv1d, lower_precision_fp) + KERNEL_MPS(conv2d, lower_precision_fp) + KERNEL_MPS(conv_tbc, lower_precision_fp) + KERNEL_MPS(conv_transpose1d, lower_precision_fp) + KERNEL_MPS2(conv_transpose2d, input, lower_precision_fp) + KERNEL_MPS(convolution, lower_precision_fp) + KERNEL_MPS(_mps_convolution, lower_precision_fp) + KERNEL_MPS(prelu, lower_precision_fp) + KERNEL_MPS(addmm, lower_precision_fp) + KERNEL_MPS(addmv, lower_precision_fp) + KERNEL_MPS(addr, lower_precision_fp) + KERNEL_MPS(matmul, lower_precision_fp) + KERNEL_MPS(einsum, lower_precision_fp) + KERNEL_MPS(mm, lower_precision_fp) + KERNEL_MPS(mv, lower_precision_fp) + KERNEL_MPS(linear, lower_precision_fp) + KERNEL_MPS(addbmm, lower_precision_fp) + KERNEL_MPS(baddbmm, lower_precision_fp) + KERNEL_MPS(bmm, lower_precision_fp) + KERNEL_MPS(chain_matmul, lower_precision_fp) + KERNEL_MPS(linalg_multi_dot, lower_precision_fp) + KERNEL_MPS(lstm_cell, lower_precision_fp) + + // fp32 + KERNEL_MPS(acos, fp32) + KERNEL_MPS(asin, fp32) + KERNEL_MPS(cosh, fp32) + KERNEL_MPS(erfinv, fp32) + KERNEL_MPS(exp, fp32) + KERNEL_MPS(expm1, fp32) + KERNEL_MPS(log, fp32) + KERNEL_MPS(log10, fp32) + KERNEL_MPS(log2, fp32) + KERNEL_MPS(log1p, fp32) + KERNEL_MPS(reciprocal, fp32) + KERNEL_MPS(rsqrt, fp32) + KERNEL_MPS(sinh, fp32) + KERNEL_MPS(tan, fp32) + KERNEL_MPS2(pow, Tensor_Scalar, fp32) + KERNEL_MPS2(pow, Tensor_Tensor, fp32) + KERNEL_MPS2(pow, Scalar, fp32) + KERNEL_MPS(softplus, fp32) + KERNEL_MPS(layer_norm, fp32) + KERNEL_MPS(native_layer_norm, fp32) + KERNEL_MPS(group_norm, fp32) + KERNEL_MPS2(frobenius_norm, dim, fp32) + KERNEL_MPS(nuclear_norm, fp32) + KERNEL_MPS2(nuclear_norm, dim, fp32) + KERNEL_MPS(batch_norm, fp32) + KERNEL_MPS(cosine_similarity, fp32) + KERNEL_MPS(poisson_nll_loss, fp32) + KERNEL_MPS(cosine_embedding_loss, fp32) + KERNEL_MPS(nll_loss, fp32) + KERNEL_MPS(nll_loss2d, fp32) + KERNEL_MPS(hinge_embedding_loss, fp32) + KERNEL_MPS(kl_div, fp32) + KERNEL_MPS(l1_loss, fp32) + KERNEL_MPS(smooth_l1_loss, fp32) + KERNEL_MPS(huber_loss, fp32) + KERNEL_MPS(mse_loss, fp32) + KERNEL_MPS(margin_ranking_loss, fp32) + KERNEL_MPS(multilabel_margin_loss, fp32) + KERNEL_MPS(soft_margin_loss, fp32) + KERNEL_MPS(triplet_margin_loss, fp32) + KERNEL_MPS(multi_margin_loss, fp32) + KERNEL_MPS(binary_cross_entropy_with_logits, fp32) + KERNEL_MPS(dist, fp32) + KERNEL_MPS(pdist, fp32) + KERNEL_MPS(cdist, fp32) + KERNEL_MPS(renorm, fp32) + KERNEL_MPS(logsumexp, fp32) + + // fp32_set_opt_dtype + KERNEL_MPS(prod, fp32) + KERNEL_MPS2(prod, dim_int, fp32) + KERNEL_MPS2(prod, dim_Dimname, fp32) + KERNEL_MPS2(softmax, int, fp32) + KERNEL_MPS2(softmax, Dimname, fp32) + KERNEL_MPS2(log_softmax, int, fp32) + KERNEL_MPS2(log_softmax, Dimname, fp32) + KERNEL_MPS(cumprod, fp32) + KERNEL_MPS2(cumprod, dimname, fp32) + KERNEL_MPS(cumsum, fp32) + KERNEL_MPS2(cumsum, dimname, fp32) + KERNEL_MPS(linalg_vector_norm, fp32) + KERNEL_MPS(linalg_matrix_norm, fp32) + KERNEL_MPS2(linalg_matrix_norm, str_ord, fp32) + KERNEL_MPS(sum, fp32) + KERNEL_MPS2(sum, dim_IntList, fp32) + KERNEL_MPS2(sum, dim_DimnameList, fp32) + // + // promote + KERNEL_MPS(addcdiv, promote) + KERNEL_MPS(addcmul, promote) + KERNEL_MPS(atan2, promote) + KERNEL_MPS(bilinear, promote) + KERNEL_MPS(cross, promote) + KERNEL_MPS(dot, promote) + KERNEL_MPS(grid_sampler, promote) + KERNEL_MPS(index_put, promote) + KERNEL_MPS(tensordot, promote) + KERNEL_MPS(scatter_add, promote) +} + TORCH_LIBRARY_IMPL(_, AutocastCPU, m) { m.fallback(torch::CppFunction::makeFallthrough()); } diff --git a/aten/src/ATen/autocast_mode.h b/aten/src/ATen/autocast_mode.h index 6c9f8c556aef8b..95f1dd2ca0c009 100644 --- a/aten/src/ATen/autocast_mode.h +++ b/aten/src/ATen/autocast_mode.h @@ -145,6 +145,8 @@ inline bool is_autocast_eligible( return tensor.is_xla() && tensor.is_floating_point(); case c10::DeviceType::PrivateUse1: return tensor.is_privateuseone() && tensor.is_floating_point(); + case c10::DeviceType::MPS: + return tensor.is_mps() && tensor.is_floating_point(); default: return false; } @@ -168,6 +170,8 @@ inline DispatchKey get_autocast_dispatch_key_from_device_type( return DispatchKey::AutocastXLA; case c10::DeviceType::PrivateUse1: return DispatchKey::AutocastPrivateUse1; + case c10::DeviceType::MPS: + return DispatchKey::AutocastMPS; default: throw std::runtime_error( "unknown device type for autocast in get_autocast_dispatch_key_from_device_type"); @@ -178,7 +182,7 @@ inline bool is_autocast_available(c10::DeviceType device_type) { if (device_type == at::kCPU || device_type == at::kCUDA || device_type == at::kXPU || device_type == at::kIPU || device_type == at::kHPU || device_type == at::kXLA || - device_type == at::kPrivateUse1) { + device_type == at::kPrivateUse1 || device_type == at::kMPS) { return true; } else { return false; @@ -745,6 +749,27 @@ copy pasted in from VariableTypeEverything.cpp with appropriate substitutions. REDISPATCH_SIGNATURE, \ POLICY) +// KERNEL_MPS registration for AutocastMPS +#define KERNEL_MPS(OP, POLICY) \ + m.impl( \ + TORCH_SELECTIVE_NAME("aten::" #OP), \ + &WrapFunction< \ + CastPolicy::POLICY, \ + DeviceType::MPS, \ + decltype(ATEN_FN(OP)), \ + decltype(ATEN_FN(OP)), \ + &ATEN_FN(OP)>::type::call); + +#define KERNEL_MPS2(OP, OVERLOAD, POLICY) \ + m.impl( \ + TORCH_SELECTIVE_NAME("aten::" #OP "." #OVERLOAD), \ + &WrapFunction< \ + CastPolicy::POLICY, \ + DeviceType::MPS, \ + decltype(ATEN_FN2(OP, OVERLOAD)), \ + decltype(ATEN_FN2(OP, OVERLOAD)), \ + &ATEN_FN2(OP, OVERLOAD)>::type::call); + // Op lists for different policies. // To make sure other backends can reuse the policy op list. #define AT_FORALL_LOWER_PRECISION_FP(_) \ diff --git a/aten/src/ATen/core/interned_strings.h b/aten/src/ATen/core/interned_strings.h index 4f6abd66cb8878..38942031befcd6 100644 --- a/aten/src/ATen/core/interned_strings.h +++ b/aten/src/ATen/core/interned_strings.h @@ -228,6 +228,7 @@ namespace c10 { _(aten, is_autocast_cpu_enabled) \ _(aten, is_autocast_xla_enabled) \ _(aten, get_autocast_dtype) \ + _(aten, is_autocast_mps_enabled) \ FORALL_ATEN_BASE_SYMBOLS(_) \ _(onnx, Add) \ _(onnx, Concat) \ diff --git a/c10/core/DispatchKey.cpp b/c10/core/DispatchKey.cpp index 0388234efd5b34..526e7f079ee5ad 100644 --- a/c10/core/DispatchKey.cpp +++ b/c10/core/DispatchKey.cpp @@ -149,6 +149,8 @@ const char* toString(DispatchKey t) { return "AutocastXLA"; case DispatchKey::AutocastPrivateUse1: return "AutocastPrivateUse1"; + case DispatchKey::AutocastMPS: + return "AutocastMPS"; case DispatchKey::FuncTorchBatched: return "FuncTorchBatched"; @@ -297,6 +299,7 @@ c10::DispatchKey parseDispatchKey(const std::string& k) { {"AutocastCUDA", c10::DispatchKey::AutocastCUDA}, {"AutocastXLA", c10::DispatchKey::AutocastXLA}, {"AutocastPrivateUse1", c10::DispatchKey::AutocastPrivateUse1}, + {"AutocastMPS", c10::DispatchKey::AutocastMPS}, {"FuncTorchBatched", c10::DispatchKey::FuncTorchBatched}, {"BatchedNestedTensor", c10::DispatchKey::BatchedNestedTensor}, {"FuncTorchVmapMode", c10::DispatchKey::FuncTorchVmapMode}, diff --git a/c10/core/DispatchKey.h b/c10/core/DispatchKey.h index 5e417ae4a3ddb1..fc5bdabd18fdd4 100644 --- a/c10/core/DispatchKey.h +++ b/c10/core/DispatchKey.h @@ -359,6 +359,7 @@ enum class DispatchKey : uint16_t { AutocastXLA, // AutocastXLA is only being used for TPUs. XLA GPUs continue to use // AutocastCUDA. + AutocastMPS, AutocastCUDA, AutocastPrivateUse1, diff --git a/c10/core/DispatchKeySet.h b/c10/core/DispatchKeySet.h index ef020071fbc2c9..ca54e1966c5e65 100644 --- a/c10/core/DispatchKeySet.h +++ b/c10/core/DispatchKeySet.h @@ -655,6 +655,7 @@ constexpr DispatchKeySet autograd_dispatch_keyset = DispatchKeySet({ constexpr DispatchKeySet autocast_dispatch_keyset = DispatchKeySet({ DispatchKey::AutocastCPU, + DispatchKey::AutocastMPS, DispatchKey::AutocastCUDA, DispatchKey::AutocastXPU, DispatchKey::AutocastIPU, @@ -671,6 +672,7 @@ constexpr DispatchKeySet default_included_set = DispatchKeySet({ constexpr DispatchKeySet default_excluded_set = DispatchKeySet({ DispatchKey::AutocastCPU, + DispatchKey::AutocastMPS, DispatchKey::AutocastCUDA, DispatchKey::AutocastXPU, DispatchKey::AutocastIPU, @@ -863,6 +865,7 @@ inline DispatchKeySet getAutocastRelatedKeySetFromBackend(BackendComponent t) { constexpr auto autocast_xla_ks = DispatchKeySet(DispatchKey::AutocastXLA); constexpr auto autocast_privateuse1_ks = DispatchKeySet(DispatchKey::AutocastPrivateUse1); + constexpr auto autocast_mps_ks = DispatchKeySet(DispatchKey::AutocastMPS); switch (t) { case BackendComponent::CPUBit: return autocast_cpu_ks; @@ -878,6 +881,8 @@ inline DispatchKeySet getAutocastRelatedKeySetFromBackend(BackendComponent t) { return autocast_xla_ks; case BackendComponent::PrivateUse1Bit: return autocast_privateuse1_ks; + case BackendComponent::MPSBit: + return autocast_mps_ks; default: return DispatchKeySet(); } diff --git a/test/test_autocast.py b/test/test_autocast.py index 0866c5c865bd91..1b25750b404c00 100644 --- a/test/test_autocast.py +++ b/test/test_autocast.py @@ -300,6 +300,55 @@ def test_autocast_prioritize(self): loss.backward() +@unittest.skipIf(not torch.backends.mps.is_available(), "requires mps") +class TestAutocastMPS(TestCase): + def test_cast_cache_is_global(self): + class CustomLinear(torch.autograd.Function): + @staticmethod + def forward(ctx, x, w_t): + ctx.save_for_backward(x, w_t) + return torch.nn.functional.linear(x, w_t) + + @staticmethod + def backward(ctx, grad_output): + x, w_t = ctx.saved_tensors + with torch.autocast(device_type="mps"): + dL_dX = torch.matmul(grad_output, w_t) + dL_dW = torch.matmul(x.transpose(0, 1), grad_output).transpose(0, 1) + return dL_dX, dL_dW + + data = torch.randn(2, 3).to("mps") + weight = torch.nn.Parameter(torch.randn(4, 3).to("mps")) + weight_dtype_cast_counter = 0 + + class WeightDTypeCastCounterMode(TorchDispatchMode): + def __torch_dispatch__(self, func, types, args=(), kwargs=None): + if ( + func is torch.ops.aten._to_copy.default + and args[0] is weight + and kwargs["dtype"] is torch.float16 + ): + nonlocal weight_dtype_cast_counter + weight_dtype_cast_counter += 1 + return func(*args, **kwargs) + + def __enter__(self): + # self.old_clear_cache = torch.clear_autocast_cache + # torch.clear_autocast_cache = lambda: None + return super().__enter__() + + def __exit__(self, exc_type, exc_val, exc_tb): + # torch.clear_autocast_cache = self.old_clear_cache + return super().__exit__(exc_type, exc_val, exc_tb) + + with WeightDTypeCastCounterMode(): + with torch.autocast(device_type="mps"): + output = CustomLinear.apply(data, weight) + s = output.sum() + s.backward() + self.assertEqual(weight_dtype_cast_counter, 2) + + class TestTorchAutocast(TestCase): def test_autocast_fast_dtype(self): gpu_fast_dtype = torch.get_autocast_dtype(device_type="cuda") diff --git a/test/test_mps.py b/test/test_mps.py index f3dd38dda7cfae..55d05e502d3d24 100644 --- a/test/test_mps.py +++ b/test/test_mps.py @@ -1203,6 +1203,29 @@ def __exit__(self, exec_type, exec_value, traceback): raise RuntimeError(msg) +class TestAutocastMPS(TestCase): + + def test_matmul_autocast(self): + autocast_tensor_A = torch.rand((8, 8), device="mps") + autocast_tensor_B = torch.rand((8, 8), device="mps") + tensor_A = autocast_tensor_A.clone().detach() + tensor_B = autocast_tensor_B.clone().detach() + autocast_output_tensor = torch.empty(8, 8) + output_tensor = autocast_output_tensor.clone().detach() + + with torch.autocast(device_type="mps"): + autocast_output_tensor = torch.mm(autocast_tensor_A, autocast_tensor_B) + autocast_output_tensor = torch.mm(autocast_tensor_A, autocast_output_tensor) + + output_tensor = torch.mm(tensor_A, tensor_B) + output_tensor = torch.mm(tensor_A, output_tensor) + + self.assertEqual(autocast_output_tensor.dtype, torch.float16, "Autocast output tensor was not expected type float16") + self.assertEqual(autocast_output_tensor, + output_tensor.to(torch.float16), + f"Autocast & non-autocast tensors did not match, \ + got:\n{autocast_output_tensor} \n{output_tensor.to(torch.float16)}") + # Expand TestCase class with Memory Leak Detection on MPS device class TestCaseMPS(TestCase): _do_mps_memory_leak_check = True diff --git a/torch/amp/autocast_mode.py b/torch/amp/autocast_mode.py index f5a50bbe2b3e2c..6aba6bbad42efc 100644 --- a/torch/amp/autocast_mode.py +++ b/torch/amp/autocast_mode.py @@ -322,6 +322,15 @@ def __init__( raise RuntimeError( "Current CUDA Device does not support bfloat16. Please switch dtype to float16." ) + elif self.device == "mps": + supported_dtype = [torch.float16] + if self.fast_dtype not in supported_dtype: + error_message = "In MPS autocast, but the target dtype is not supported. Disabling autocast.\n" + error_message += ( + "MPS Autocast only supports dtype of torch.bfloat16 currently." + ) + warnings.warn(error_message) + enabled = False elif self.device == "xla": supported_dtype = [torch.float16, torch.bfloat16] if self.fast_dtype not in supported_dtype: diff --git a/torch/csrc/jit/passes/autocast.cpp b/torch/csrc/jit/passes/autocast.cpp index 1aff3eac1e641b..766c0843026456 100644 --- a/torch/csrc/jit/passes/autocast.cpp +++ b/torch/csrc/jit/passes/autocast.cpp @@ -108,7 +108,7 @@ std::optional parseAutocast( TORCH_CHECK( dtype != c10::ScalarType::Undefined, "Autocast has invalid fast_dtype attribute"); - if (device == "cuda") { + if (device == "cuda" || device == "mps") { scope.context.gpu_enabled = enabled.value(); scope.context.gpu_scalar_type = dtype; } else if (device == "cpu") { diff --git a/torch/csrc/utils/python_dispatch.cpp b/torch/csrc/utils/python_dispatch.cpp index 455d1c95ac571c..aa875680788671 100644 --- a/torch/csrc/utils/python_dispatch.cpp +++ b/torch/csrc/utils/python_dispatch.cpp @@ -726,6 +726,7 @@ void initDispatchBindings(PyObject* module) { DEF_ONE(PreDispatch) DEF_ONE(Functionalize) DEF_ONE(AutocastCPU) + DEF_ONE(AutocastMPS) DEF_ONE(AutocastXPU) DEF_ONE(AutocastHPU) DEF_ONE(AutocastIPU) From 10aeae2eb97363a3a46f0b26807992e21f5aea72 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Thu, 5 Sep 2024 23:39:41 +0000 Subject: [PATCH 0501/1018] Revert "Support rolling over a percentage of workflows (#134816)" This reverts commit fc890b55b51098437b6149abf1026a8b2aaee389. Reverted https://github.com/pytorch/pytorch/pull/134816 on behalf of https://github.com/malfet due to Causes lint to intermittently fail ([comment](https://github.com/pytorch/pytorch/pull/134816#issuecomment-2332902609)) --- .github/scripts/runner_determinator.py | 337 ++++++-------------- .github/scripts/test_runner_determinator.py | 237 -------------- .github/workflows/_runner-determinator.yml | 337 ++++++-------------- 3 files changed, 212 insertions(+), 699 deletions(-) delete mode 100644 .github/scripts/test_runner_determinator.py diff --git a/.github/scripts/runner_determinator.py b/.github/scripts/runner_determinator.py index c1cc82dbadbe34..e52b19c64b4e3c 100644 --- a/.github/scripts/runner_determinator.py +++ b/.github/scripts/runner_determinator.py @@ -3,94 +3,49 @@ """ This runner determinator is used to determine which set of runners to run a GitHub job on. It uses the first comment of a GitHub issue (by default -https://github.com/pytorch/test-infra/issues/5132) to define the configuration -of which runners should be used to run which job. - -The configuration has two parts, the settings and a list of opted-in users, -separated by a line containing "---". If the line is not present, the -settings are considered to be empty with only the second part, the user -list, defined. - -The first part is a YAML block that defines the rollout settings. This can be -used to define any settings that are needed to determine which runners to use. -It's fields are defined by the RolloutSettings class below. - -The second part is a list of users who are explicitly opted in to the LF fleet. -The user list is also a comma separated list of additional features or -experiments which the user could be opted in to. +https://github.com/pytorch/test-infra/issues/5132) as a user list to determine +which users will get their jobs to run on experimental runners. This user list +is also a comma separated list of additional features or experiments which the +user could be opted in to. The user list has the following rules: -- Users are GitHub usernames, which must start with the @ prefix +- Users are GitHub usernames with the @ prefix +- If the first line is a "*" then all users will use the new runners +- If the first line is a "!" then all users will use the old runners - Each user is also a comma-separated list of features/experiments to enable -- A "#" prefix opts the user out of all experiments - -Example config: - # A list of experiments that can be opted into. - # This defines the behavior they'll induce when opted into. - # Expected syntax is: - # [experiment_name]: # Name of the experiment. Also used for the label prefix. - # rollout_perc: [int] # % of workflows to run with this experiment when users are not opted in. - - experiments: - lf: - rollout_percent: 25 +- A "#" prefix indicates the user is opted out of the new runners but is opting + into features/experiments. - --- +Example user list: - # Opt-ins: - # Users can opt into the LF fleet by adding their GitHub username to this list - # and specifying experiments to enable in a comma-separated list. - # Experiments should be from the above list. - - @User1,lf,split_build - @User2,lf - @User3,split_build + @User1 + @User2,amz2023 + #@UserOptOutOfNewRunner,amz2023 """ import logging import os -import random from argparse import ArgumentParser from logging import LogRecord -from typing import Any, Dict, Iterable, List, NamedTuple, Tuple +from typing import Any, Iterable -import yaml from github import Auth, Github from github.Issue import Issue -DEFAULT_LABEL_PREFIX = "" # use meta runners +WORKFLOW_LABEL_META = "" # use meta runners WORKFLOW_LABEL_LF = "lf." # use runners from the linux foundation WORKFLOW_LABEL_LF_CANARY = "lf.c." # use canary runners from the linux foundation +RUNNER_AMI_LEGACY = "" +RUNNER_AMI_AMZ2023 = "amz2023" + GITHUB_OUTPUT = os.getenv("GITHUB_OUTPUT", "") GH_OUTPUT_KEY_AMI = "runner-ami" GH_OUTPUT_KEY_LABEL_TYPE = "label-type" -SETTING_EXPERIMENTS = "experiments" - -LF_FLEET_EXPERIMENT = "lf" -CANARY_FLEET_SUFFIX = ".c" - - -class Experiment(NamedTuple): - rollout_perc: int = ( - 0 # Percentage of workflows to experiment on when user is not opted-in. - ) - - # Add more fields as needed - - -class Settings(NamedTuple): - """ - Settings for the experiments that can be opted into. - """ - - experiments: Dict[str, Experiment] = {} - - class ColorFormatter(logging.Formatter): """Color codes the log messages based on the log level""" @@ -217,181 +172,85 @@ def is_exception_branch(branch: str) -> bool: return branch.split("/")[0] in {"main", "nightly", "release", "landchecks"} -def load_yaml(yaml_text: str) -> Any: - try: - data = yaml.safe_load(yaml_text) - return data - except yaml.YAMLError as exc: - log.exception("Error loading YAML") - raise - - -def extract_settings_user_opt_in_from_text(rollout_state: str) -> Tuple[str, str]: - """ - Extracts the text with settings, if any, and the opted in users from the rollout state. - - If the issue body contains "---" then the text above that is the settings - and the text below is the list of opted in users. - - If it doesn't contain "---" then the settings are empty and the rest is the users. +def get_fleet(rollout_state: str, workflow_requestors: Iterable[str]) -> str: """ - rollout_state_parts = rollout_state.split("---") - if len(rollout_state_parts) >= 2: - return rollout_state_parts[0], rollout_state_parts[1] - else: - return "", rollout_state - + Determines if the job should run on the LF fleet or the Meta fleet -class UserOptins(Dict[str, List[str]]): - """ - Dictionary of users with a list of features they have opted into + Returns: + The appropriate label prefix for the runner, corresponding to the fleet to use. + This gets prefixed to the very start of the runner label. """ - -def parse_user_opt_in_from_text(user_optin_text: str) -> UserOptins: - """ - Parse the user opt-in text into a key value pair of username and the list of features they have opted into - - Users are GitHub usernames with the @ prefix. Each user is also a comma-separated list of features/experiments to enable. - - Example line: "@User1,lf,split_build" - - A "#" prefix indicates the user is opted out of all experiments - - - """ - optins = UserOptins() - for user in user_optin_text.split("\n"): - user = user.strip("\r\n\t -") - if not user or not user.startswith("@"): - # Not a valid user. Skip - continue - - if user: - usr_name = user.split(",")[0].strip("@") - optins[usr_name] = [exp.strip(" ") for exp in user.split(",")[1:]] - - return optins - - -def parse_settings_from_text(settings_text: str) -> Settings: - """ - Parse the experiments from the issue body into a list of ExperimentSettings - """ try: - if settings_text: - # Escape the backtick as well so that we can have the settings in a code block on the GH issue - # for easy reading - # Note: Using ascii for the backtick so that the cat step in _runner-determinator.yml doesn't choke on - # the backtick character in shell commands. - backtick = chr(96) # backtick character - settings_text = settings_text.strip(f"\r\n\t{backtick} ") - settings = load_yaml(settings_text) - - # For now we just load experiments. We can expand this if/when we add more settings - experiments = {} - - for exp_name, exp_settings in settings.get(SETTING_EXPERIMENTS).items(): - valid_settings = {} - for setting in exp_settings: - if setting not in Experiment._fields: - log.warning( - f"Unexpected setting in experiment: {setting} = {exp_settings[setting]}" - ) - else: - valid_settings[setting] = exp_settings[setting] - - experiments[exp_name] = Experiment(**valid_settings) - return Settings(experiments) + if rollout_state[0] == "!": + log.info("LF Workflows are disabled for everyone. Using meta runners.") + return WORKFLOW_LABEL_META + elif rollout_state[0] == "*": + log.info("LF Workflows are enabled for everyone. Using LF runners.") + return WORKFLOW_LABEL_LF + else: + all_opted_in_users = { + usr_raw.strip("\n\t@ ").split(",")[0] + for usr_raw in rollout_state.split() + } + opted_in_requestors = { + usr for usr in workflow_requestors if usr in all_opted_in_users + } + if opted_in_requestors: + log.info( + f"LF Workflows are enabled for {', '.join(opted_in_requestors)}. Using LF runners." + ) + return WORKFLOW_LABEL_LF + else: + log.info( + f"LF Workflows are disabled for {', '.join(workflow_requestors)}. Using meta runners." + ) + return WORKFLOW_LABEL_META except Exception as e: - log.exception("Failed to parse settings") - - return Settings() - - -def parse_settings(rollout_state: str) -> Settings: - """ - Parse settings, if any, from the rollout state. - - If the issue body contains "---" then the text above that is the settings - and the text below is the list of opted in users. - - If it doesn't contain "---" then the settings are empty and the default values are used. - """ - settings_text, _ = extract_settings_user_opt_in_from_text(rollout_state) - return parse_settings_from_text(settings_text) - + log.error( + f"Failed to get determine workflow type. Falling back to meta runners. Exception: {e}" + ) + return WORKFLOW_LABEL_META -def parse_users(rollout_state: str) -> UserOptins: - """ - Parse users from the rollout state. +def get_optin_feature( + rollout_state: str, workflow_requestors: Iterable[str], feature: str, fallback: str +) -> str: """ - _, users_text = extract_settings_user_opt_in_from_text(rollout_state) - return parse_user_opt_in_from_text(users_text) + Used to dynamically opt in jobs to specific runner-type variants. - -def is_user_opted_in(user: str, user_optins: UserOptins, experiment_name: str) -> bool: - """ - Check if a user is opted into an experiment + Returns: + The runner-type's variant name if the user has opted in to the feature, otherwise returns an empty string. + This variant name is prefixed to the runner-type in the label. """ - return experiment_name in user_optins.get(user, []) - - -def get_runner_prefix( - rollout_state: str, workflow_requestors: Iterable[str], is_canary: bool = False -) -> str: - settings = parse_settings(rollout_state) - user_optins = parse_users(rollout_state) - - fleet_prefix = "" - prefixes = [] - for experiment_name, experiment_settings in settings.experiments.items(): - enabled = False - - # Is any workflow_requestor opted in to this experiment? - opted_in_users = [ - requestor - for requestor in workflow_requestors - if is_user_opted_in(requestor, user_optins, experiment_name) - ] - - if opted_in_users: + try: + userlist = {u.lstrip("#").strip("\n\t@ ") for u in rollout_state.split()} + all_opted_in_users = set() + for user in userlist: + for i in user.split(","): + if i == feature: + all_opted_in_users.add(user.split(",")[0]) + opted_in_requestors = { + usr for usr in workflow_requestors if usr in all_opted_in_users + } + + if opted_in_requestors: log.info( - f"{', '.join(opted_in_users)} have opted into experiment {experiment_name}." + f"Feature {feature} is enabled for {', '.join(opted_in_requestors)}. Using feature {feature}." ) - enabled = True + return feature else: - # If no user is opted in, then we randomly enable the experiment based on the rollout percentage - r = random.randint(1, 100) - if r <= experiment_settings.rollout_perc: - log.info( - f"Based on rollout percentage of {experiment_settings.rollout_perc}%, enabling experiment {experiment_name}." - ) - enabled = True - - if enabled: - label = experiment_name - if experiment_name == LF_FLEET_EXPERIMENT: - # We give some special treatment to the "lf" experiment since determines the fleet we use - # - If it's enabled, then we always list it's prefix first - # - If we're in the canary branch, then we append ".c" to the lf prefix - if is_canary: - label += CANARY_FLEET_SUFFIX - fleet_prefix = label - else: - prefixes.append(label) + log.info( + f"Feature {feature} is disabled for {', '.join(workflow_requestors)}. Using fallback \"{fallback}\"." + ) + return fallback - if len(prefixes) > 1: + except Exception as e: log.error( - f"Only a fleet and one other experiment can be enabled for a job at any time. Enabling {prefixes[0]} and ignoring the rest, which are {', '.join(prefixes[1:])}" + f'Failed to determine if user has opted-in to feature {feature}. Using fallback "{fallback}". Exception: {e}' ) - prefixes = prefixes[:1] - - # Fleet always comes first - if fleet_prefix: - prefixes.insert(0, fleet_prefix) - - return ".".join(prefixes) + "." if prefixes else "" + return fallback def get_rollout_state_from_issue(github_token: str, repo: str, issue_num: int) -> str: @@ -409,10 +268,9 @@ def main() -> None: args = parse_args() if args.github_ref_type == "branch" and is_exception_branch(args.github_branch): - log.info( - f"Exception branch: '{args.github_branch}', using Meta runners and no experiments." - ) - runner_label_prefix = DEFAULT_LABEL_PREFIX + log.info(f"Exception branch: '{args.github_branch}', using meta runners") + label_type = WORKFLOW_LABEL_META + runner_ami = RUNNER_AMI_LEGACY else: try: rollout_state = get_rollout_state_from_issue( @@ -427,18 +285,35 @@ def main() -> None: args.github_branch, ) - is_canary = args.github_repo == "pytorch/pytorch-canary" - - runner_label_prefix = get_runner_prefix( - rollout_state, (args.github_issue_owner, username), is_canary + label_type = get_fleet( + rollout_state, + ( + args.github_issue_owner, + username, + ), + ) + runner_ami = get_optin_feature( + rollout_state=rollout_state, + workflow_requestors=( + args.github_issue_owner, + username, + ), + feature=RUNNER_AMI_AMZ2023, + fallback=RUNNER_AMI_LEGACY, ) - except Exception as e: log.error( - f"Failed to get issue. Defaulting to Meta runners and no experiments. Exception: {e}" + f"Failed to get issue. Falling back to meta runners. Exception: {e}" ) + label_type = WORKFLOW_LABEL_META + runner_ami = RUNNER_AMI_LEGACY + + # For Canary builds use canary runners + if args.github_repo == "pytorch/pytorch-canary" and label_type == WORKFLOW_LABEL_LF: + label_type = WORKFLOW_LABEL_LF_CANARY - set_github_output(GH_OUTPUT_KEY_LABEL_TYPE, runner_label_prefix) + set_github_output(GH_OUTPUT_KEY_LABEL_TYPE, label_type) + set_github_output(GH_OUTPUT_KEY_AMI, runner_ami) if __name__ == "__main__": diff --git a/.github/scripts/test_runner_determinator.py b/.github/scripts/test_runner_determinator.py deleted file mode 100644 index 98063b190b0547..00000000000000 --- a/.github/scripts/test_runner_determinator.py +++ /dev/null @@ -1,237 +0,0 @@ -from unittest import main, TestCase -from unittest.mock import Mock, patch - -import runner_determinator as rd - - -class TestRunnerDeterminatorIssueParser(TestCase): - def test_parse_settings(self) -> None: - settings_text = """ - experiments: - lf: - rollout_perc: 25 - otherExp: - rollout_perc: 0 - --- - - Users: - @User1,lf - @User2,lf,otherExp - - """ - - settings = rd.parse_settings(settings_text) - - self.assertTupleEqual( - rd.Experiment(rollout_perc=25), - settings.experiments["lf"], - "lf settings not parsed correctly", - ) - self.assertTupleEqual( - rd.Experiment(rollout_perc=0), - settings.experiments["otherExp"], - "otherExp settings not parsed correctly", - ) - - def test_parse_settings_in_code_block(self) -> None: - settings_text = """ - - ``` - experiments: - lf: - rollout_perc: 25 - otherExp: - rollout_perc: 0 - - ``` - - --- - - Users: - @User1,lf - @User2,lf,otherExp - - """ - - settings = rd.parse_settings(settings_text) - - self.assertTupleEqual( - rd.Experiment(rollout_perc=25), - settings.experiments["lf"], - "lf settings not parsed correctly", - ) - self.assertTupleEqual( - rd.Experiment(rollout_perc=0), - settings.experiments["otherExp"], - "otherExp settings not parsed correctly", - ) - - def test_parse_users(self) -> None: - settings_text = """ - experiments: - lf: - rollout_perc: 25 - otherExp: - rollout_perc: 0 - --- - - Users: - @User1,lf - @User2,lf,otherExp - - """ - - users = rd.parse_users(settings_text) - self.assertDictEqual( - {"User1": ["lf"], "User2": ["lf", "otherExp"]}, - users, - "Users not parsed correctly", - ) - - def test_parse_users_without_settings(self) -> None: - settings_text = """ - - @User1,lf - @User2,lf,otherExp - - """ - - users = rd.parse_users(settings_text) - self.assertDictEqual( - {"User1": ["lf"], "User2": ["lf", "otherExp"]}, - users, - "Users not parsed correctly", - ) - - -class TestRunnerDeterminatorGetRunnerPrefix(TestCase): - def test_opted_in_user(self) -> None: - settings_text = """ - experiments: - lf: - rollout_perc: 25 - otherExp: - rollout_perc: 25 - --- - - Users: - @User1,lf - @User2,lf,otherExp - - """ - prefix = rd.get_runner_prefix(settings_text, ["User1"]) - self.assertEqual("lf.", prefix, "Runner prefix not correct for User1") - - def test_opted_in_user_two_experiments(self) -> None: - settings_text = """ - experiments: - lf: - rollout_perc: 25 - otherExp: - rollout_perc: 25 - --- - - Users: - @User1,lf - @User2,lf,otherExp - - """ - prefix = rd.get_runner_prefix(settings_text, ["User2"]) - self.assertEqual("lf.otherExp.", prefix, "Runner prefix not correct for User1") - - @patch("random.randint", return_value=50) - def test_opted_out_user(self, mock_randint: Mock) -> None: - settings_text = """ - experiments: - lf: - rollout_perc: 25 - otherExp: - rollout_perc: 25 - --- - - Users: - @User1,lf - @User2,lf,otherExp - - """ - prefix = rd.get_runner_prefix(settings_text, ["User3"]) - self.assertEqual("", prefix, "Runner prefix not correct for user") - - @patch("random.randint", return_value=10) - def test_opted_out_user_was_pulled_in_by_rollout(self, mock_randint: Mock) -> None: - settings_text = """ - experiments: - lf: - rollout_perc: 25 - otherExp: - rollout_perc: 25 - --- - - Users: - @User1,lf - @User2,lf,otherExp - - """ - - # User3 is opted out, but is pulled into both experiments by the 10% rollout - prefix = rd.get_runner_prefix(settings_text, ["User3"]) - self.assertEqual("lf.otherExp.", prefix, "Runner prefix not correct for user") - - def test_lf_prefix_always_comes_first(self) -> None: - settings_text = """ - experiments: - otherExp: - rollout_perc: 0 - lf: - rollout_perc: 0 - --- - - Users: - @User1,lf - @User2,otherExp,lf - - """ - - prefix = rd.get_runner_prefix(settings_text, ["User2"]) - self.assertEqual("lf.otherExp.", prefix, "Runner prefix not correct for user") - - def test_ignores_commented_users(self) -> None: - settings_text = """ - experiments: - lf: - rollout_perc: 0 - otherExp: - rollout_perc: 0 - --- - - Users: - #@User1,lf - @User2,lf,otherExp - - """ - - prefix = rd.get_runner_prefix(settings_text, ["User1"]) - self.assertEqual("", prefix, "Runner prefix not correct for user") - - def test_ignores_extra_experiments(self) -> None: - settings_text = """ - experiments: - lf: - rollout_perc: 0 - otherExp: - rollout_perc: 0 - foo: - rollout_perc: 0 - --- - - Users: - @User1,lf,otherExp,foo - - """ - - prefix = rd.get_runner_prefix(settings_text, ["User1"]) - self.assertEqual("lf.otherExp.", prefix, "Runner prefix not correct for user") - - -if __name__ == "__main__": - main() diff --git a/.github/workflows/_runner-determinator.yml b/.github/workflows/_runner-determinator.yml index 5edb79da64463d..8ba709358c5ea1 100644 --- a/.github/workflows/_runner-determinator.yml +++ b/.github/workflows/_runner-determinator.yml @@ -62,94 +62,49 @@ jobs: """ This runner determinator is used to determine which set of runners to run a GitHub job on. It uses the first comment of a GitHub issue (by default - https://github.com/pytorch/test-infra/issues/5132) to define the configuration - of which runners should be used to run which job. - - The configuration has two parts, the settings and a list of opted-in users, - separated by a line containing "---". If the line is not present, the - settings are considered to be empty with only the second part, the user - list, defined. - - The first part is a YAML block that defines the rollout settings. This can be - used to define any settings that are needed to determine which runners to use. - It's fields are defined by the RolloutSettings class below. - - The second part is a list of users who are explicitly opted in to the LF fleet. - The user list is also a comma separated list of additional features or - experiments which the user could be opted in to. + https://github.com/pytorch/test-infra/issues/5132) as a user list to determine + which users will get their jobs to run on experimental runners. This user list + is also a comma separated list of additional features or experiments which the + user could be opted in to. The user list has the following rules: - - Users are GitHub usernames, which must start with the @ prefix + - Users are GitHub usernames with the @ prefix + - If the first line is a "*" then all users will use the new runners + - If the first line is a "!" then all users will use the old runners - Each user is also a comma-separated list of features/experiments to enable - - A "#" prefix opts the user out of all experiments - - Example config: - # A list of experiments that can be opted into. - # This defines the behavior they'll induce when opted into. - # Expected syntax is: - # [experiment_name]: # Name of the experiment. Also used for the label prefix. - # rollout_perc: [int] # % of workflows to run with this experiment when users are not opted in. - - experiments: - lf: - rollout_percent: 25 + - A "#" prefix indicates the user is opted out of the new runners but is opting + into features/experiments. - --- + Example user list: - # Opt-ins: - # Users can opt into the LF fleet by adding their GitHub username to this list - # and specifying experiments to enable in a comma-separated list. - # Experiments should be from the above list. - - @User1,lf,split_build - @User2,lf - @User3,split_build + @User1 + @User2,amz2023 + #@UserOptOutOfNewRunner,amz2023 """ import logging import os - import random from argparse import ArgumentParser from logging import LogRecord - from typing import Any, Dict, Iterable, List, NamedTuple, Tuple + from typing import Any, Iterable - import yaml from github import Auth, Github from github.Issue import Issue - DEFAULT_LABEL_PREFIX = "" # use meta runners + WORKFLOW_LABEL_META = "" # use meta runners WORKFLOW_LABEL_LF = "lf." # use runners from the linux foundation WORKFLOW_LABEL_LF_CANARY = "lf.c." # use canary runners from the linux foundation + RUNNER_AMI_LEGACY = "" + RUNNER_AMI_AMZ2023 = "amz2023" + GITHUB_OUTPUT = os.getenv("GITHUB_OUTPUT", "") GH_OUTPUT_KEY_AMI = "runner-ami" GH_OUTPUT_KEY_LABEL_TYPE = "label-type" - SETTING_EXPERIMENTS = "experiments" - - LF_FLEET_EXPERIMENT = "lf" - CANARY_FLEET_SUFFIX = ".c" - - - class Experiment(NamedTuple): - rollout_perc: int = ( - 0 # Percentage of workflows to experiment on when user is not opted-in. - ) - - # Add more fields as needed - - - class Settings(NamedTuple): - """ - Settings for the experiments that can be opted into. - """ - - experiments: Dict[str, Experiment] = {} - - class ColorFormatter(logging.Formatter): """Color codes the log messages based on the log level""" @@ -276,181 +231,85 @@ jobs: return branch.split("/")[0] in {"main", "nightly", "release", "landchecks"} - def load_yaml(yaml_text: str) -> Any: - try: - data = yaml.safe_load(yaml_text) - return data - except yaml.YAMLError as exc: - log.exception("Error loading YAML") - raise - - - def extract_settings_user_opt_in_from_text(rollout_state: str) -> Tuple[str, str]: - """ - Extracts the text with settings, if any, and the opted in users from the rollout state. - - If the issue body contains "---" then the text above that is the settings - and the text below is the list of opted in users. - - If it doesn't contain "---" then the settings are empty and the rest is the users. + def get_fleet(rollout_state: str, workflow_requestors: Iterable[str]) -> str: """ - rollout_state_parts = rollout_state.split("---") - if len(rollout_state_parts) >= 2: - return rollout_state_parts[0], rollout_state_parts[1] - else: - return "", rollout_state - + Determines if the job should run on the LF fleet or the Meta fleet - class UserOptins(Dict[str, List[str]]): - """ - Dictionary of users with a list of features they have opted into + Returns: + The appropriate label prefix for the runner, corresponding to the fleet to use. + This gets prefixed to the very start of the runner label. """ - - def parse_user_opt_in_from_text(user_optin_text: str) -> UserOptins: - """ - Parse the user opt-in text into a key value pair of username and the list of features they have opted into - - Users are GitHub usernames with the @ prefix. Each user is also a comma-separated list of features/experiments to enable. - - Example line: "@User1,lf,split_build" - - A "#" prefix indicates the user is opted out of all experiments - - - """ - optins = UserOptins() - for user in user_optin_text.split("\n"): - user = user.strip("\r\n\t -") - if not user or not user.startswith("@"): - # Not a valid user. Skip - continue - - if user: - usr_name = user.split(",")[0].strip("@") - optins[usr_name] = [exp.strip(" ") for exp in user.split(",")[1:]] - - return optins - - - def parse_settings_from_text(settings_text: str) -> Settings: - """ - Parse the experiments from the issue body into a list of ExperimentSettings - """ try: - if settings_text: - # Escape the backtick as well so that we can have the settings in a code block on the GH issue - # for easy reading - # Note: Using ascii for the backtick so that the cat step in _runner-determinator.yml doesn't choke on - # the backtick character in shell commands. - backtick = chr(96) # backtick character - settings_text = settings_text.strip(f"\r\n\t{backtick} ") - settings = load_yaml(settings_text) - - # For now we just load experiments. We can expand this if/when we add more settings - experiments = {} - - for exp_name, exp_settings in settings.get(SETTING_EXPERIMENTS).items(): - valid_settings = {} - for setting in exp_settings: - if setting not in Experiment._fields: - log.warning( - f"Unexpected setting in experiment: {setting} = {exp_settings[setting]}" - ) - else: - valid_settings[setting] = exp_settings[setting] - - experiments[exp_name] = Experiment(**valid_settings) - return Settings(experiments) + if rollout_state[0] == "!": + log.info("LF Workflows are disabled for everyone. Using meta runners.") + return WORKFLOW_LABEL_META + elif rollout_state[0] == "*": + log.info("LF Workflows are enabled for everyone. Using LF runners.") + return WORKFLOW_LABEL_LF + else: + all_opted_in_users = { + usr_raw.strip("\n\t@ ").split(",")[0] + for usr_raw in rollout_state.split() + } + opted_in_requestors = { + usr for usr in workflow_requestors if usr in all_opted_in_users + } + if opted_in_requestors: + log.info( + f"LF Workflows are enabled for {', '.join(opted_in_requestors)}. Using LF runners." + ) + return WORKFLOW_LABEL_LF + else: + log.info( + f"LF Workflows are disabled for {', '.join(workflow_requestors)}. Using meta runners." + ) + return WORKFLOW_LABEL_META except Exception as e: - log.exception("Failed to parse settings") - - return Settings() - - - def parse_settings(rollout_state: str) -> Settings: - """ - Parse settings, if any, from the rollout state. - - If the issue body contains "---" then the text above that is the settings - and the text below is the list of opted in users. - - If it doesn't contain "---" then the settings are empty and the default values are used. - """ - settings_text, _ = extract_settings_user_opt_in_from_text(rollout_state) - return parse_settings_from_text(settings_text) - + log.error( + f"Failed to get determine workflow type. Falling back to meta runners. Exception: {e}" + ) + return WORKFLOW_LABEL_META - def parse_users(rollout_state: str) -> UserOptins: - """ - Parse users from the rollout state. + def get_optin_feature( + rollout_state: str, workflow_requestors: Iterable[str], feature: str, fallback: str + ) -> str: """ - _, users_text = extract_settings_user_opt_in_from_text(rollout_state) - return parse_user_opt_in_from_text(users_text) + Used to dynamically opt in jobs to specific runner-type variants. - - def is_user_opted_in(user: str, user_optins: UserOptins, experiment_name: str) -> bool: - """ - Check if a user is opted into an experiment + Returns: + The runner-type's variant name if the user has opted in to the feature, otherwise returns an empty string. + This variant name is prefixed to the runner-type in the label. """ - return experiment_name in user_optins.get(user, []) - - - def get_runner_prefix( - rollout_state: str, workflow_requestors: Iterable[str], is_canary: bool = False - ) -> str: - settings = parse_settings(rollout_state) - user_optins = parse_users(rollout_state) - - fleet_prefix = "" - prefixes = [] - for experiment_name, experiment_settings in settings.experiments.items(): - enabled = False - - # Is any workflow_requestor opted in to this experiment? - opted_in_users = [ - requestor - for requestor in workflow_requestors - if is_user_opted_in(requestor, user_optins, experiment_name) - ] - - if opted_in_users: + try: + userlist = {u.lstrip("#").strip("\n\t@ ") for u in rollout_state.split()} + all_opted_in_users = set() + for user in userlist: + for i in user.split(","): + if i == feature: + all_opted_in_users.add(user.split(",")[0]) + opted_in_requestors = { + usr for usr in workflow_requestors if usr in all_opted_in_users + } + + if opted_in_requestors: log.info( - f"{', '.join(opted_in_users)} have opted into experiment {experiment_name}." + f"Feature {feature} is enabled for {', '.join(opted_in_requestors)}. Using feature {feature}." ) - enabled = True + return feature else: - # If no user is opted in, then we randomly enable the experiment based on the rollout percentage - r = random.randint(1, 100) - if r <= experiment_settings.rollout_perc: - log.info( - f"Based on rollout percentage of {experiment_settings.rollout_perc}%, enabling experiment {experiment_name}." - ) - enabled = True - - if enabled: - label = experiment_name - if experiment_name == LF_FLEET_EXPERIMENT: - # We give some special treatment to the "lf" experiment since determines the fleet we use - # - If it's enabled, then we always list it's prefix first - # - If we're in the canary branch, then we append ".c" to the lf prefix - if is_canary: - label += CANARY_FLEET_SUFFIX - fleet_prefix = label - else: - prefixes.append(label) + log.info( + f"Feature {feature} is disabled for {', '.join(workflow_requestors)}. Using fallback \"{fallback}\"." + ) + return fallback - if len(prefixes) > 1: + except Exception as e: log.error( - f"Only a fleet and one other experiment can be enabled for a job at any time. Enabling {prefixes[0]} and ignoring the rest, which are {', '.join(prefixes[1:])}" + f'Failed to determine if user has opted-in to feature {feature}. Using fallback "{fallback}". Exception: {e}' ) - prefixes = prefixes[:1] - - # Fleet always comes first - if fleet_prefix: - prefixes.insert(0, fleet_prefix) - - return ".".join(prefixes) + "." if prefixes else "" + return fallback def get_rollout_state_from_issue(github_token: str, repo: str, issue_num: int) -> str: @@ -468,10 +327,9 @@ jobs: args = parse_args() if args.github_ref_type == "branch" and is_exception_branch(args.github_branch): - log.info( - f"Exception branch: '{args.github_branch}', using Meta runners and no experiments." - ) - runner_label_prefix = DEFAULT_LABEL_PREFIX + log.info(f"Exception branch: '{args.github_branch}', using meta runners") + label_type = WORKFLOW_LABEL_META + runner_ami = RUNNER_AMI_LEGACY else: try: rollout_state = get_rollout_state_from_issue( @@ -486,18 +344,35 @@ jobs: args.github_branch, ) - is_canary = args.github_repo == "pytorch/pytorch-canary" - - runner_label_prefix = get_runner_prefix( - rollout_state, (args.github_issue_owner, username), is_canary + label_type = get_fleet( + rollout_state, + ( + args.github_issue_owner, + username, + ), + ) + runner_ami = get_optin_feature( + rollout_state=rollout_state, + workflow_requestors=( + args.github_issue_owner, + username, + ), + feature=RUNNER_AMI_AMZ2023, + fallback=RUNNER_AMI_LEGACY, ) - except Exception as e: log.error( - f"Failed to get issue. Defaulting to Meta runners and no experiments. Exception: {e}" + f"Failed to get issue. Falling back to meta runners. Exception: {e}" ) + label_type = WORKFLOW_LABEL_META + runner_ami = RUNNER_AMI_LEGACY + + # For Canary builds use canary runners + if args.github_repo == "pytorch/pytorch-canary" and label_type == WORKFLOW_LABEL_LF: + label_type = WORKFLOW_LABEL_LF_CANARY - set_github_output(GH_OUTPUT_KEY_LABEL_TYPE, runner_label_prefix) + set_github_output(GH_OUTPUT_KEY_LABEL_TYPE, label_type) + set_github_output(GH_OUTPUT_KEY_AMI, runner_ami) if __name__ == "__main__": From bf378284a494ababb6408417c57e0af39575a9ac Mon Sep 17 00:00:00 2001 From: Jason Ansel Date: Thu, 5 Sep 2024 12:04:46 -0700 Subject: [PATCH 0502/1018] [inductor] Improve compile time regression from MemoryDep.normalize (#135070) Possible fix for #135056 Before ![image](https://github.com/user-attachments/assets/3962cb85-e808-4fd4-991f-471ff5ef7eae) After ![image](https://github.com/user-attachments/assets/2322d48d-6518-4518-baca-336027b5cda8) Measured based on: ``` python benchmarks/dynamo/torchbench.py --ci --accuracy --timing --explain --inductor --device cuda --training --only hf_Bert_large --stats -n1 ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/135070 Approved by: https://github.com/Chillee --- torch/_inductor/scheduler.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/torch/_inductor/scheduler.py b/torch/_inductor/scheduler.py index 179dde1386ebba..5af3978c41b934 100644 --- a/torch/_inductor/scheduler.py +++ b/torch/_inductor/scheduler.py @@ -3079,9 +3079,14 @@ def fusable_read_and_write(self, read: Dep, write: MemoryDep) -> bool: if isinstance(read, MemoryDep): if read.mode == write.mode and write.mode is not None: return True - read_name = read.name - if read_name in self.mutation_renames: - read_name = self.mutation_renames[read_name] + read_name = self.mutation_renames.get(read.name, read.name) + + if ( + read_name != write.name + or free_symbol_is_type(read.index, SymT.TMP) + or free_symbol_is_type(write.index, SymT.TMP) + ): + return False if config.loop_ordering_after_fusion and read.num_vars != write.num_vars: # Need merge loops if we do loop ordering after fusion since @@ -3091,10 +3096,7 @@ def fusable_read_and_write(self, read: Dep, write: MemoryDep) -> bool: write = write.normalize() return ( - read_name == write.name - and not free_symbol_is_type(read.index, SymT.TMP) - and not free_symbol_is_type(write.index, SymT.TMP) - and read.index == write.index + read.index == write.index and len(read.size) >= len(write.size) and read.size[: len(write.size)] == write.size ) From 3bd0b105f814fe4ac9d0deaf6522a61449af7657 Mon Sep 17 00:00:00 2001 From: Jason Ansel Date: Thu, 5 Sep 2024 12:04:46 -0700 Subject: [PATCH 0503/1018] [fx] Compile time optimization in Node.__update_args_kwargs (#135076) Before this we took two passes over all of the args. Before: ![image](https://github.com/user-attachments/assets/24ce5628-03f4-4983-9f2d-5ddf0ca5816e) After: ![image](https://github.com/user-attachments/assets/c9681aa2-32f0-4f6b-a598-fc6f90ffafb5) Pull Request resolved: https://github.com/pytorch/pytorch/pull/135076 Approved by: https://github.com/Chillee ghstack dependencies: #135070 --- torch/fx/node.py | 35 +++++++++++++++++++++-------------- 1 file changed, 21 insertions(+), 14 deletions(-) diff --git a/torch/fx/node.py b/torch/fx/node.py index 218fcebe332107..456ab371c6bad0 100644 --- a/torch/fx/node.py +++ b/torch/fx/node.py @@ -33,6 +33,8 @@ BaseArgumentTypes ]] +_legal_ops = dict.fromkeys(['placeholder', 'call_method', 'call_module', 'call_function', 'get_attr', 'output', 'root']) + _side_effectful_need_to_be_preserved_pre_dispatch: Set[Callable] = { torch._C._set_grad_enabled, torch.amp._enter_autocast, @@ -164,6 +166,8 @@ class Node(_NodeBase): - ``output`` contains the output of the traced function in its ``args[0]`` attribute. This corresponds to the "return" statement in the Graph printout. """ + _args: Tuple['Argument', ...] + _kwargs: Dict[str, 'Argument'] @compatibility(is_backward_compatible=True) def __init__(self, graph: 'Graph', name: str, op: str, target: 'Target', @@ -198,7 +202,7 @@ def __init__(self, graph: 'Graph', name: str, op: str, target: 'Target', super().__init__() self.graph = graph self.name = name # unique name of value being created - assert op in ['placeholder', 'call_method', 'call_module', 'call_function', 'get_attr', 'output', 'root'] + assert op in _legal_ops self.op = op # the kind of operation = placeholder|call_method|call_module|call_function|get_attr if op == 'call_function': if not callable(target): @@ -215,7 +219,7 @@ def __init__(self, graph: 'Graph', name: str, op: str, target: 'Target', # The public API for this is `all_input_nodes`, this private attribute # should not be accessed directly. self._input_nodes : Dict[Node, None] = {} - self.__update_args_kwargs(map_arg(args, lambda x: x), map_arg(kwargs, lambda x: x)) # type: ignore[arg-type] + self.__update_args_kwargs(args, kwargs) # All of the nodes that use the value produced by this Node # Note one user may correspond to several uses, e.g. the node fo ``x + x`` @@ -364,7 +368,7 @@ def args(self, a : Tuple[Argument, ...]) -> None: """ # DO NOT CALL `__update_args_kwargs` directly. The correct way to # set `args` is via direct assignment, i.e. `node.args = new_args` - self.__update_args_kwargs(map_arg(a, lambda x: x), self._kwargs) # type: ignore[arg-type] + self.__update_args_kwargs(a, self._kwargs) @property def kwargs(self) -> Dict[str, Argument]: @@ -387,7 +391,7 @@ def kwargs(self, k : Dict[str, Argument]) -> None: """ # DO NOT CALL `__update_args_kwargs` directly. The correct way to # set `args` is via direct assignment, i.e. `node.kwargs = new_kwargs` - self.__update_args_kwargs(self._args, map_arg(k, lambda x: x)) # type: ignore[arg-type] + self.__update_args_kwargs(self._args, k) @property def all_input_nodes(self) -> List['Node']: @@ -453,9 +457,7 @@ def update_kwarg(self, key : str, arg : Argument) -> None: key (str): The key in ``self.kwargs`` of the element to update arg (Argument): The new argument value to write into ``kwargs`` """ - kwargs = dict(self.kwargs) - kwargs[key] = arg - self.kwargs = kwargs + self.kwargs = {**self.kwargs, key: arg} @property def stack_trace(self) -> Optional[str]: @@ -479,18 +481,23 @@ def __update_args_kwargs(self, new_args : Tuple['Argument', ...], new_kwargs : D """ This API is internal. Do *not* call it directly. """ - self._args = new_args - self._kwargs = new_kwargs + def update_users_and_input_nodes(n: Any) -> Any: + if isinstance(n, Node): + self._input_nodes.setdefault(n) + n.users.setdefault(self) + return n + # Clear prior users and input_nodes for old_use in self._input_nodes.keys(): old_use.users.pop(self) - self._input_nodes = {} - map_arg(self._args, self._input_nodes.setdefault) - map_arg(self._kwargs, self._input_nodes.setdefault) - for new_use in self._input_nodes.keys(): - new_use.users.setdefault(self) + # We do three things in a single pass of the args + # - Normalize list->immutable_list, dict->immutable_dict, etc + # - Populate self._input_nodes + # - Populate arg.users[self] for each arg + self._args = map_aggregate(new_args, update_users_and_input_nodes) # type: ignore[assignment] + self._kwargs = map_aggregate(new_kwargs, update_users_and_input_nodes) # type: ignore[assignment] def __repr__(self) -> str: if self._repr_fn: From c1054a5d098398680226a9a653bae2dc95a2fb66 Mon Sep 17 00:00:00 2001 From: Jason Ansel Date: Thu, 5 Sep 2024 12:04:47 -0700 Subject: [PATCH 0504/1018] [fx] Don't use generators in map_aggregate (#135082) While the generators avoid a copy, they are slow. Before: ![image](https://github.com/user-attachments/assets/70a55a9a-0595-4105-b0ab-22cf77c7409c) After: ![image](https://github.com/user-attachments/assets/cecb9c59-ae36-47de-8b08-cab2c7cb3d57) Pull Request resolved: https://github.com/pytorch/pytorch/pull/135082 Approved by: https://github.com/oulgen ghstack dependencies: #135070, #135076 --- torch/fx/node.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/torch/fx/node.py b/torch/fx/node.py index 456ab371c6bad0..f84b23e29ddf85 100644 --- a/torch/fx/node.py +++ b/torch/fx/node.py @@ -772,13 +772,16 @@ def map_aggregate(a: Argument, fn: Callable[[Argument], Argument]) -> Argument: Apply fn to each Node appearing arg. arg may be a list, tuple, slice, or dict with string keys. """ if isinstance(a, tuple): - t = tuple(map_aggregate(elem, fn) for elem in a) + t = tuple([map_aggregate(elem, fn) for elem in a]) # Support NamedTuple (if it has `_fields`) by repacking into original type. return t if not hasattr(a, '_fields') else type(a)(*t) # type: ignore[arg-type] elif isinstance(a, list): - return immutable_list(map_aggregate(elem, fn) for elem in a) + return immutable_list([map_aggregate(elem, fn) for elem in a]) elif isinstance(a, dict): - return immutable_dict((k, map_aggregate(v, fn)) for k, v in a.items()) + rv = immutable_dict() + for k, v in a.items(): + dict.__setitem__(rv, k, map_aggregate(v, fn)) + return rv elif isinstance(a, slice): return slice(map_aggregate(a.start, fn), map_aggregate(a.stop, fn), map_aggregate(a.step, fn)) else: From 9110e310e99b0c0f9b4d62dd0d8dbc4df35217e0 Mon Sep 17 00:00:00 2001 From: Jason Ansel Date: Thu, 5 Sep 2024 12:04:47 -0700 Subject: [PATCH 0505/1018] [debug] Add helper to run cProfile on a function (#135084) Pull Request resolved: https://github.com/pytorch/pytorch/pull/135084 Approved by: https://github.com/oulgen ghstack dependencies: #135070, #135076, #135082 --- torch/_dynamo/debug_utils.py | 42 +++++++++++++++++++++++++++++++++++- 1 file changed, 41 insertions(+), 1 deletion(-) diff --git a/torch/_dynamo/debug_utils.py b/torch/_dynamo/debug_utils.py index 30fe8346079d47..94687ff2747bf3 100644 --- a/torch/_dynamo/debug_utils.py +++ b/torch/_dynamo/debug_utils.py @@ -1,7 +1,8 @@ # mypy: allow-untyped-defs # mypy: disable-error-code="method-assign" - +import atexit import copy +import cProfile import functools import getpass import inspect @@ -10,6 +11,7 @@ import os import re import subprocess +import sys import tempfile import textwrap from collections import Counter @@ -782,3 +784,41 @@ def gen_tensor(shape, dtype) -> Tensor: setattr(container, attr_name, gen_tensor(shape, dtype)) return kwargs + + +def profile_to_file(filename: str) -> Callable[[T], T]: + """ + Decorator to cProfile a given function and save the result to disk on process exit. + + Args: + filename: filename to save profile to + """ + prof = cProfile.Profile() + filename = os.path.abspath(os.path.expanduser(filename)) + + def decorator(fn): + @functools.wraps(fn) + def wrapper(*args, **kwargs): + prof.enable() + try: + return fn(*args, **kwargs) + finally: + prof.disable() + + return wrapper + + def save_it(): + prof.dump_stats(filename) + sys.stderr.write( + textwrap.dedent( + f"""\ + Wrote profile to {filename}, view with: + + snakeviz {filename} + + """ + ) + ) + + atexit.register(save_it) + return decorator From 31bff703b71a99fda83d5021f373ffa81547908e Mon Sep 17 00:00:00 2001 From: titaiwangms Date: Thu, 5 Sep 2024 23:52:49 +0000 Subject: [PATCH 0506/1018] [ONNX] Delete ONNXProgramSerializer (#135261) Fixes #135182 Pull Request resolved: https://github.com/pytorch/pytorch/pull/135261 Approved by: https://github.com/justinchuby --- docs/source/onnx_dynamo.rst | 4 - test/onnx/dynamo/test_exporter_api.py | 39 +------ torch/onnx/__init__.py | 3 - torch/onnx/_internal/_exporter_legacy.py | 130 ++++------------------- 4 files changed, 19 insertions(+), 157 deletions(-) diff --git a/docs/source/onnx_dynamo.rst b/docs/source/onnx_dynamo.rst index be6b9d48a8d399..5694c40d83c4f5 100644 --- a/docs/source/onnx_dynamo.rst +++ b/docs/source/onnx_dynamo.rst @@ -28,7 +28,6 @@ The exporter is designed to be modular and extensible. It is composed of the fol - **FX Graph Extractor**: :class:`FXGraphExtractor` extracts the FX graph from the PyTorch model. - **Fake Mode**: :class:`ONNXFakeContext` is a context manager that enables fake mode for large scale models. - **ONNX Program**: :class:`ONNXProgram` is the output of the exporter that contains the exported ONNX graph and diagnostics. - - **ONNX Program Serializer**: :class:`ONNXProgramSerializer` serializes the exported model to a file. - **ONNX Diagnostic Options**: :class:`DiagnosticOptions` has a set of options that control the diagnostics emitted by the exporter. Dependencies @@ -144,9 +143,6 @@ API Reference .. autoclass:: torch.onnx.ONNXProgram :members: -.. autoclass:: torch.onnx.ONNXProgramSerializer - :members: - .. autoclass:: torch.onnx.ONNXRuntimeOptions :members: diff --git a/test/onnx/dynamo/test_exporter_api.py b/test/onnx/dynamo/test_exporter_api.py index c6e8cd47cfe24d..17d55f0374c59a 100644 --- a/test/onnx/dynamo/test_exporter_api.py +++ b/test/onnx/dynamo/test_exporter_api.py @@ -7,10 +7,7 @@ import torch from torch.onnx import dynamo_export, ExportOptions, ONNXProgram from torch.onnx._internal import _exporter_legacy -from torch.onnx._internal._exporter_legacy import ( - ONNXProgramSerializer, - ResolvedExportOptions, -) +from torch.onnx._internal._exporter_legacy import ResolvedExportOptions from torch.testing._internal import common_utils @@ -75,40 +72,6 @@ def test_save_to_existing_buffer_default_serializer(self): dynamo_export(SampleModel(), torch.randn(1, 1, 2)).save(buffer) onnx.load(buffer) - def test_save_to_file_using_specified_serializer(self): - expected_buffer = "I am not actually ONNX" - - class CustomSerializer(ONNXProgramSerializer): - def serialize( - self, onnx_program: ONNXProgram, destination: io.BufferedIOBase - ) -> None: - destination.write(expected_buffer.encode()) - - with common_utils.TemporaryFileName() as path: - dynamo_export(SampleModel(), torch.randn(1, 1, 2)).save( - path, serializer=CustomSerializer() - ) - with open(path) as fp: - self.assertEqual(fp.read(), expected_buffer) - - def test_save_to_file_using_specified_serializer_without_inheritance(self): - expected_buffer = "I am not actually ONNX" - - # NOTE: Inheritance from `ONNXProgramSerializer` is not required. - # Because `ONNXProgramSerializer` is a Protocol class. - class CustomSerializer: - def serialize( - self, onnx_program: ONNXProgram, destination: io.BufferedIOBase - ) -> None: - destination.write(expected_buffer.encode()) - - with common_utils.TemporaryFileName() as path: - dynamo_export(SampleModel(), torch.randn(1, 1, 2)).save( - path, serializer=CustomSerializer() - ) - with open(path) as fp: - self.assertEqual(fp.read(), expected_buffer) - def test_save_sarif_log_to_file_with_successful_export(self): with common_utils.TemporaryFileName(suffix=".sarif") as path: dynamo_export(SampleModel(), torch.randn(1, 1, 2)).save_diagnostics(path) diff --git a/torch/onnx/__init__.py b/torch/onnx/__init__.py index 0dbc42aea33395..9d4084da1ac4b6 100644 --- a/torch/onnx/__init__.py +++ b/torch/onnx/__init__.py @@ -44,7 +44,6 @@ "DiagnosticOptions", "ExportOptions", "ONNXProgram", - "ONNXProgramSerializer", "ONNXRuntimeOptions", "InvalidExportOptionsError", "OnnxExporterError", @@ -109,7 +108,6 @@ DiagnosticOptions, ExportOptions, ONNXProgram, - ONNXProgramSerializer, ONNXRuntimeOptions, InvalidExportOptionsError, OnnxExporterError, @@ -127,7 +125,6 @@ JitScalarType.__module__ = "torch.onnx" ExportOptions.__module__ = "torch.onnx" ONNXProgram.__module__ = "torch.onnx" -ONNXProgramSerializer.__module__ = "torch.onnx" ONNXRuntimeOptions.__module__ = "torch.onnx" dynamo_export.__module__ = "torch.onnx" InvalidExportOptionsError.__module__ = "torch.onnx" diff --git a/torch/onnx/_internal/_exporter_legacy.py b/torch/onnx/_internal/_exporter_legacy.py index c66c3492e98ba4..4d1830b3e9ce12 100644 --- a/torch/onnx/_internal/_exporter_legacy.py +++ b/torch/onnx/_internal/_exporter_legacy.py @@ -11,17 +11,7 @@ import tempfile import warnings from collections import defaultdict -from typing import ( - Any, - Callable, - Final, - Mapping, - Protocol, - runtime_checkable, - Sequence, - TYPE_CHECKING, - TypeVar, -) +from typing import Any, Callable, Final, Mapping, Sequence, TYPE_CHECKING, TypeVar from typing_extensions import Self import torch @@ -462,97 +452,6 @@ def enable_fake_mode(): ) # type: ignore[assignment] -@runtime_checkable -class ONNXProgramSerializer(Protocol): - """Protocol for serializing an ONNX graph into a specific format (e.g. Protobuf). - Note that this is an advanced usage scenario.""" - - def serialize( - self, onnx_program: ONNXProgram, destination: io.BufferedIOBase - ) -> None: - """Protocol method that must be implemented for serialization. - - Args: - onnx_program: Represents the in-memory exported ONNX model - destination: A binary IO stream or pre-allocated buffer into which - the serialized model should be written. - - Example: - - A simple serializer that writes the exported :py:obj:`onnx.ModelProto` in Protobuf - format to ``destination``: - - :: - - # xdoctest: +REQUIRES(env:TORCH_DOCTEST_ONNX) - >>> import io - >>> import torch - >>> import torch.onnx - >>> class MyModel(torch.nn.Module): # Dummy model - ... def __init__(self) -> None: - ... super().__init__() - ... self.linear = torch.nn.Linear(2, 2) - ... def forward(self, x): - ... out = self.linear(x) - ... return out - >>> class ProtobufONNXProgramSerializer: - ... def serialize( - ... self, onnx_program: torch.onnx.ONNXProgram, destination: io.BufferedIOBase - ... ) -> None: - ... destination.write(onnx_program.model_proto.SerializeToString()) - >>> model = MyModel() - >>> arg1 = torch.randn(2, 2, 2) # positional input 1 - >>> torch.onnx.dynamo_export(model, arg1).save( - ... destination="exported_model.onnx", - ... serializer=ProtobufONNXProgramSerializer(), - ... ) - """ - ... - - -class ProtobufONNXProgramSerializer: - """Serializes ONNX graph as Protobuf.""" - - def serialize( - self, onnx_program: ONNXProgram, destination: io.BufferedIOBase - ) -> None: - import onnx - - if not isinstance(onnx_program.model_proto, onnx.ModelProto): # type: ignore[attr-defined] - raise ValueError("onnx_program.ModelProto is not an onnx.ModelProto") - destination.write(onnx_program.model_proto.SerializeToString()) - - -class LargeProtobufONNXProgramSerializer: - """Serializes ONNX graph as Protobuf. - - Fallback to serializing as Protobuf with external data for models larger than 2GB. - """ - - _destination_path: Final[str] # type: ignore[misc] - - def __init__(self, destination_path: str): - self._destination_path = destination_path - - def serialize( - self, onnx_program: ONNXProgram, destination: io.BufferedIOBase - ) -> None: - """`destination` is ignored. The model is saved to `self._destination_path` instead.""" - import onnx - - if onnx_program.model_proto.ByteSize() < _PROTOBUF_SIZE_MAX_LIMIT: - onnx.save_model(onnx_program.model_proto, self._destination_path) # type: ignore[attr-defined] - else: - # ValueError: Message onnx.ModelProto exceeds maximum protobuf size of 2GB - # Fallback to serializing the model with external data. - onnx.save_model( # type: ignore[attr-defined] - onnx_program.model_proto, - self._destination_path, - save_as_external_data=True, - all_tensors_to_one_file=True, - ) - - class ONNXRuntimeOptions: """Options to influence the execution of the ONNX model through ONNX Runtime. @@ -959,7 +858,6 @@ def save( *, include_initializers: bool = True, model_state: dict[str, Any] | str | None = None, - serializer: ONNXProgramSerializer | None = None, ) -> None: """Saves the in-memory ONNX model to ``destination`` using specified ``serializer``. @@ -975,17 +873,12 @@ def save( It can be either a string with the path to a checkpoint or a dictionary with the actual model state. The supported file formats are the same as those supported by `torch.load` and `safetensors.safe_open`. Required when :func:`enable_fake_mode` is used but real initializers are needed on the ONNX graph. - serializer: The serializer to use. If not specified, the model will be serialized as Protobuf. """ + import onnx assert ( include_initializers is True or model_state is None ), "Cannot specify both `include_initializers=False` and `model_state`." - if serializer is None: - if isinstance(destination, str): - serializer = LargeProtobufONNXProgramSerializer(destination) - else: - serializer = ProtobufONNXProgramSerializer() # Add initializers when symbolic tracing is enabled _model_state_files: list[str | io.BytesIO | dict[str, Any]] = [] @@ -1031,10 +924,24 @@ def save( else: if isinstance(destination, str): with open(destination, "wb") as f: - serializer.serialize(self, f) + if self.model_proto.ByteSize() < _PROTOBUF_SIZE_MAX_LIMIT: + onnx.save_model(self.model_proto, destination) # type: ignore[attr-defined] + else: + # ValueError: Message onnx.ModelProto exceeds maximum protobuf size of 2GB + # Fallback to serializing the model with external data. + onnx.save_model( # type: ignore[attr-defined] + self.model_proto, + destination, + save_as_external_data=True, + all_tensors_to_one_file=True, + ) else: try: - serializer.serialize(self, destination) + if not isinstance(self.model_proto, onnx.ModelProto): # type: ignore[attr-defined] + raise ValueError( + "onnx_program.ModelProto is not an onnx.ModelProto" + ) + destination.write(self.model_proto.SerializeToString()) except ValueError as exc: raise ValueError( "'destination' should be provided as a path-like string when saving a model larger than 2GB. " @@ -1544,7 +1451,6 @@ def common_pre_export_passes( "DiagnosticOptions", "ExportOptions", "ONNXProgram", - "ONNXProgramSerializer", "ONNXRuntimeOptions", "InvalidExportOptionsError", "OnnxExporterError", From b0c5bd61b007e27a85aebb52dff949a1e0046973 Mon Sep 17 00:00:00 2001 From: wz337 Date: Thu, 5 Sep 2024 12:42:13 -0700 Subject: [PATCH 0507/1018] [DTensor] Fix view op replicating on tensor dim when the size of the tensor dim = 1 (#135054) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit We found a corner case that when a tensor dimension is 1, calling `view(1)` would result in an unexpected replication (see case 1 below). When the tensor dimension to shard is not 1, no matter whether the tensor dimension is evenly-shardable across the mesh dimension, it won't cause an implicit replication behind the scenes if view doesn't change the size of the given tensor dimension (see case 2 and 3). When the tensor dimension to shard is of size 1, it is not being added to shardable_dims here: https://github.com/pytorch/pytorch/blob/main/torch/distributed/_tensor/ops/_view_ops.py#L518 ``` # uneven case where the size of the tensor dimension to shard is 1 p = torch.randn(1,2) mesh = init_device_mesh(“cuda”, (2,)) dtensor = distribute_tensor(p, mesh, [Shard(0)]) t = dtensor.view(1, 2) # this would result in replication, meaning t is now replicated across all ranks. # uneven case where the size of the tensor dimension to shard is not 1 p = torch.randn(3, 2) mesh = init_device_mesh(“cuda”, (2,)) dtensor = distribute_tensor(p, mesh, [Shard(0)]) t = dtensor.view(3, 2) # this would not result in replication. # this would not result in replication, meaning t stays as sharded. # even case p = torch.randn(2,2) dtensor = distribute_tensor(p, mesh, [Shard(0)]) t = dtensor.view(2, 2) # this would not result in replication, meaning t stays as sharded. ``` Differential Revision: [D62155606](https://our.internmc.facebook.com/intern/diff/D62155606) Pull Request resolved: https://github.com/pytorch/pytorch/pull/135054 Approved by: https://github.com/tianyu-l, https://github.com/wanchaol --- test/distributed/_tensor/test_view_ops.py | 64 +++++++++++++++++++--- torch/distributed/_tensor/ops/_view_ops.py | 2 +- 2 files changed, 56 insertions(+), 10 deletions(-) diff --git a/test/distributed/_tensor/test_view_ops.py b/test/distributed/_tensor/test_view_ops.py index 8ace53d97131b2..3af39815997fc4 100644 --- a/test/distributed/_tensor/test_view_ops.py +++ b/test/distributed/_tensor/test_view_ops.py @@ -7,7 +7,13 @@ import torch import torch.distributed as dist from torch import rand, randn, Tensor -from torch.distributed._tensor import DeviceMesh, distribute_tensor, Replicate, Shard +from torch.distributed._tensor import ( + DeviceMesh, + distribute_tensor, + init_device_mesh, + Replicate, + Shard, +) from torch.distributed._tensor.debug import CommDebugMode from torch.distributed._tensor.ops._view_ops import ( Broadcast, @@ -29,6 +35,10 @@ class TestViewOps(DTensorTestBase): + @property + def world_size(self) -> int: + return 6 + def test_view_groups(self): self.assertEqual( view_groups([2, 3], [3, 2]), @@ -106,8 +116,8 @@ def test_view_groups(self): view_groups([1, 1, 3, 2, 1, 1], [6, 1, 1, 1]), ( Flatten((InputDim(2), InputDim(3))), - Singleton(), - Singleton(), + InputDim(4), + InputDim(5), Singleton(), ), ) @@ -116,7 +126,7 @@ def test_view_groups(self): ( Split(InputDim(2), (3, 4), 0), Split(InputDim(2), (3, 4), 1), - Singleton(), + InputDim(3), Flatten((InputDim(6), InputDim(7))), ), ) @@ -125,10 +135,6 @@ def test_view_groups(self): (InputDim(0), InputDim(1), InputDim(2)), ) - @property - def world_size(self) -> int: - return 6 - def call_dt_test(self, op, args, kwargs, device_mesh: DeviceMesh): dim_map = dim_maps[op] rules = dim_map(*args, **kwargs) @@ -429,7 +435,7 @@ def test_view_ops(self): self.dimmap_test( Tensor.view, (randn(1, 1, 42, 1, 24, 1), -1), - (Flatten((InputDim(2), InputDim(4))),), + (Flatten((InputDim(2), InputDim(input_dim=3), InputDim(4))),), ) self.dimmap_test( @@ -525,6 +531,46 @@ def test_complex_view_ops(self): ) self.assertEqual(out, out_dt.full_tensor()) + @with_comms + def test_dtensor_view_op_uneven(self): + """ + Test two uneven cases for view op: + 1) the sharded tensor dim is 1 so that only the first rank has an non-empty shard. + 2) the sharded tensor dim is uneven such that some ranks have full shards, + smaller non-empty shards, and empty shards. + """ + dim0_sizes = [1, self.world_size + 1] + for dim0_size in dim0_sizes: + p = torch.randn(dim0_size, 2, 2, 2) + mesh = init_device_mesh(self.device_type, (self.world_size,)) + dtensor = distribute_tensor(p, mesh, [Shard(0)]) + + with CommDebugMode() as comm_mode: + view = dtensor.view(dim0_size, 2, 4) + self.assertEqual(len(comm_mode.get_comm_counts()), 0) + # when no communication happens, the data pointer should be the same. + self.assertEqual( + view.to_local().data_ptr(), dtensor.to_local().data_ptr() + ) + + view = dtensor.view(dim0_size, 4, 2) + self.assertEqual( + view.to_local().data_ptr(), dtensor.to_local().data_ptr() + ) + self.assertEqual(len(comm_mode.get_comm_counts()), 0) + + view = dtensor.view(dim0_size, 8) + self.assertEqual( + view.to_local().data_ptr(), dtensor.to_local().data_ptr() + ) + self.assertEqual(len(comm_mode.get_comm_counts()), 0) + + view = dtensor.view(dtensor.shape) + self.assertEqual( + view.to_local().data_ptr(), dtensor.to_local().data_ptr() + ) + self.assertEqual(len(comm_mode.get_comm_counts()), 0) + if __name__ == "__main__": run_tests() diff --git a/torch/distributed/_tensor/ops/_view_ops.py b/torch/distributed/_tensor/ops/_view_ops.py index 0d7299544d4e4e..c61e21e98b89ba 100644 --- a/torch/distributed/_tensor/ops/_view_ops.py +++ b/torch/distributed/_tensor/ops/_view_ops.py @@ -377,7 +377,7 @@ def view_groups(from_size: Shape, to_size: Shape) -> DimMap: if len(to_group_shape) > 0: flattened = Flatten.new( - tuple(InputDim(fi) for fi in from_group_dim if from_size[fi] > 1) + tuple(InputDim(fi) for fi in from_group_dim if from_size[fi] >= 1) ) result_pp += [ Split.new(flattened, tuple(to_group_shape), i) From ff57f1e14b5b2c3b9aad53a96fba9c1723965c9f Mon Sep 17 00:00:00 2001 From: cyy Date: Fri, 6 Sep 2024 00:08:57 +0000 Subject: [PATCH 0508/1018] [Submodule] Bump pybind11 to v2.13.5 (#135202) Fixes #ISSUE_NUMBER Pull Request resolved: https://github.com/pytorch/pytorch/pull/135202 Approved by: https://github.com/Skylion007 --- third_party/pybind11 | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/third_party/pybind11 b/third_party/pybind11 index 941f45bcb51457..7c33cdc2d39c7b 160000 --- a/third_party/pybind11 +++ b/third_party/pybind11 @@ -1 +1 @@ -Subproject commit 941f45bcb51457884fa1afd6e24a67377d70f75c +Subproject commit 7c33cdc2d39c7b99a122579f53bc94c8eb3332ff From 71632b7dae10a9e00bf01b7cbbc60802a8caf949 Mon Sep 17 00:00:00 2001 From: Laith Sakka Date: Thu, 5 Sep 2024 13:34:28 -0700 Subject: [PATCH 0509/1018] Track base of FunctionalTensor in inference mode. (#135141) The idea behind the tracking is the following, whenever we see a tensor if the tensors is a root tensors (does not have any view metas ) when we consider is as the base of the all the tensors that shares its storage. Pull Request resolved: https://github.com/pytorch/pytorch/pull/135141 Approved by: https://github.com/zou3519 --- aten/src/ATen/FunctionalTensorWrapper.cpp | 7 +- aten/src/ATen/FunctionalTensorWrapper.h | 8 ++ test/inductor/test_auto_functionalize.py | 85 +++---------------- tools/pyi/gen_pyi.py | 3 + torch/_higher_order_ops/auto_functionalize.py | 11 ++- torch/_subclasses/functional_tensor.py | 43 ++++++++-- .../python_torch_functions_manual.cpp | 4 + 7 files changed, 74 insertions(+), 87 deletions(-) diff --git a/aten/src/ATen/FunctionalTensorWrapper.cpp b/aten/src/ATen/FunctionalTensorWrapper.cpp index dfd4928808602d..6f66e8065731a4 100644 --- a/aten/src/ATen/FunctionalTensorWrapper.cpp +++ b/aten/src/ATen/FunctionalTensorWrapper.cpp @@ -707,7 +707,12 @@ bool are_all_mutations_under_no_grad_or_inference_mode(const Tensor& functional_ } bool isFunctionalTensor(const at::Tensor& tensor) { - return tensor.unsafeGetTensorImpl()->key_set().has(c10::DispatchKey::Functionalize); + return tensor.unsafeGetTensorImpl()->key_set().has(c10::DispatchKey::Functionalize); +} + +bool isBaseTensor(const at::Tensor& tensor) { + TORCH_INTERNAL_ASSERT_DEBUG_ONLY(isFunctionalTensor(tensor)); + return unsafeGetFunctionalWrapper(tensor)->isBaseTensor(); } bool isFunctionalTensor(const std::optional& t) { diff --git a/aten/src/ATen/FunctionalTensorWrapper.h b/aten/src/ATen/FunctionalTensorWrapper.h index afb3af5fc54f29..ed5daf90e5f44b 100644 --- a/aten/src/ATen/FunctionalTensorWrapper.h +++ b/aten/src/ATen/FunctionalTensorWrapper.h @@ -165,6 +165,12 @@ struct TORCH_API FunctionalTensorWrapper : public c10::TensorImpl { was_storage_changed_ = true; } + // A FunctionalTensor is considered a base if its not a view of another + // tensor. + bool isBaseTensor() const { + return view_metas_.empty(); + } + c10::SymInt get_storage_size(bool before) { return functional_storage_impl()->get_storage_size(before); } @@ -290,6 +296,8 @@ TORCH_API inline FunctionalTensorWrapper* unsafeGetFunctionalWrapper( return functional_impl; } +TORCH_API bool isBaseTensor(const at::Tensor& tensor); + TORCH_API bool isFunctionalTensor(const at::Tensor& tensor); TORCH_API bool isFunctionalTensor(const std::optional& t); TORCH_API bool isFunctionalTensor( diff --git a/test/inductor/test_auto_functionalize.py b/test/inductor/test_auto_functionalize.py index b1e3ab1f8ad777..019e88cf1a188e 100644 --- a/test/inductor/test_auto_functionalize.py +++ b/test/inductor/test_auto_functionalize.py @@ -911,85 +911,20 @@ def test_inference_mode1_v2(self): with torch.inference_mode(): self.test_auto_functionalize_extra1() - # In inference mode we do not support inplacing views yet. @torch._inductor.config.patch(enable_auto_functionalized_v2=True) def test_inference_mode2_v2(self): - with torch.inference_mode(), torch.library._scoped_library( - "mylib", "FRAGMENT" - ) as lib: - torch.library.define( - "mylib::foo", - "(Tensor(a!) x, Tensor(b!) y) -> ()", - tags=torch.Tag.pt2_compliant_tag, - lib=lib, - ) - - @torch.library.impl("mylib::foo", "cpu", lib=lib) - @torch._dynamo.disable - def foo_impl(x, y): - x.sin_() - y.sin_() - - def f(x): - a = x[0] - b = x[1] - torch.ops.mylib.foo(a, b) - return - - orig_args = [torch.randn(2)] - - [aot_eager_args, result1, graph_aot] = self.run_aot_eager(f, orig_args) - [inductor_args, result2, graph_inductor] = self.run_inductor(f, orig_args) - eager_args = pytree.tree_map_only(torch.Tensor, torch.clone, orig_args) - result3 = f(*eager_args) - - self.assertEqual(inductor_args, eager_args) - self.assertEqual(inductor_args, aot_eager_args) - - self.assertEqual(result3, result1) - self.assertEqual(result3, result2) - - if torch._dynamo.config.assume_static_by_default: - self.assertExpectedInline( - graph_aot, - """\ -def forward(self, arg0_1: "f32[2][1]cpu"): - select: "f32[][]cpu" = torch.ops.aten.select.int(arg0_1, 0, 0) - select_1: "f32[][]cpu" = torch.ops.aten.select.int(arg0_1, 0, 1) - auto_functionalized_v2 = torch.ops.higher_order.auto_functionalized_v2(torch.ops.mylib.foo.default, _x_base_index = 0, _y_base_index = 1, _all_bases = [select, select_1]); select = select_1 = None - getitem_1: "f32[][]cpu" = auto_functionalized_v2[1] - getitem_2: "f32[][]cpu" = auto_functionalized_v2[2]; auto_functionalized_v2 = None - select_scatter: "f32[2][1]cpu" = torch.ops.aten.select_scatter.default(arg0_1, getitem_1, 0, 0); getitem_1 = None - select_scatter_1: "f32[2][1]cpu" = torch.ops.aten.select_scatter.default(select_scatter, getitem_2, 0, 1); select_scatter = getitem_2 = None - copy_: "f32[2][1]cpu" = torch.ops.aten.copy_.default(arg0_1, select_scatter_1); arg0_1 = select_scatter_1 = copy_ = None - return ()""", # noqa: B950 - ignore_comments=True, - ignore_empty_lines=True, - ) + with torch.inference_mode(): + self.test_auto_functionalize_extra2() - # 2. Run with inductor backend + @torch._inductor.config.patch(enable_auto_functionalized_v2=True) + def test_inference_mode3_v2(self): + with torch.inference_mode(): + self.test_auto_functionalize_extra3() - if torch._dynamo.config.assume_static_by_default: - self.assertExpectedInline( - graph_inductor, - """\ -def forward(self, arg0_1: "f32[2][1]cpu"): - select: "f32[][]cpu" = torch.ops.aten.select.int(arg0_1, 0, 0) - select_1: "f32[][]cpu" = torch.ops.aten.select.int(arg0_1, 0, 1) - as_strided_default: "f32[1][1]cpu" = torch.ops.aten.as_strided.default(select, [1], [1], 0); select = None - clone_default: "f32[1][1]cpu" = torch.ops.aten.clone.default(as_strided_default); as_strided_default = None - as_strided_default_1: "f32[][]cpu" = torch.ops.aten.as_strided.default(clone_default, [], [], 0); clone_default = None - as_strided_default_2: "f32[2][1]cpu" = torch.ops.aten.as_strided.default(select_1, [2], [1], 0); select_1 = None - clone_default_1: "f32[2][1]cpu" = torch.ops.aten.clone.default(as_strided_default_2); as_strided_default_2 = None - as_strided_default_3: "f32[][]cpu" = torch.ops.aten.as_strided.default(clone_default_1, [], [], 1); clone_default_1 = None - foo_default = torch.ops.mylib.foo.default(as_strided_default_1, as_strided_default_3); foo_default = None - select_scatter_default: "f32[2][1]cpu" = torch.ops.aten.select_scatter.default(arg0_1, as_strided_default_1, 0, 0); as_strided_default_1 = None - select_scatter_default_1: "f32[2][1]cpu" = torch.ops.aten.select_scatter.default(select_scatter_default, as_strided_default_3, 0, 1); select_scatter_default = as_strided_default_3 = None - copy_: "f32[2][1]cpu" = torch.ops.aten.copy_.default(arg0_1, select_scatter_default_1); arg0_1 = select_scatter_default_1 = copy_ = None - return ()""", # noqa: B950 - ignore_comments=True, - ignore_empty_lines=True, - ) + @torch._inductor.config.patch(enable_auto_functionalized_v2=True) + def test_inference_mode4_v2(self): + with torch.inference_mode(): + self.test_auto_functionalize_extra4() @torch._inductor.config.patch(enable_auto_functionalized_v2=True) def test_dynamic_v2(self): diff --git a/tools/pyi/gen_pyi.py b/tools/pyi/gen_pyi.py index e4e15bcbc4f6da..bd2fcee5e51c4a 100644 --- a/tools/pyi/gen_pyi.py +++ b/tools/pyi/gen_pyi.py @@ -781,6 +781,9 @@ def gen_pyi( "_is_functional_tensor": [ "def _is_functional_tensor(t: Tensor) -> _bool: ..." ], + "_is_functional_tensor_base": [ + "def _is_functional_tensor_base(t: Tensor) -> _bool: ..." + ], "_from_functional_tensor": [ "def _from_functional_tensor(t: Tensor) -> Tensor: ..." ], diff --git a/torch/_higher_order_ops/auto_functionalize.py b/torch/_higher_order_ops/auto_functionalize.py index 8466d42ff6f435..232981f1f0192e 100644 --- a/torch/_higher_order_ops/auto_functionalize.py +++ b/torch/_higher_order_ops/auto_functionalize.py @@ -18,6 +18,13 @@ ) +def get_base(tensor): + if torch.is_inference_mode_enabled(): + return tensor._inference_mode_base + else: + return tensor._base + + @dataclass class ViewInfo: base_index: int @@ -68,7 +75,7 @@ def write_single_view(prefix: str, tensor: Tensor, base_index: int): if tensor is None: kwargs[f"{prefix}_base_index"] = None - elif tensor._base is None: + elif get_base(tensor) is None: # if the tensor is the base (not view), for simplicity we do not serialize view meta. kwargs[f"{prefix}_base_index"] = base_index else: @@ -437,7 +444,7 @@ def do_auto_functionalize_v2( arg_to_base_index: Dict[str, Any] = {} def update_dict(tensor, arg_name, index=None): - base = tensor if tensor._base is None else tensor._base + base = tensor if get_base(tensor) is None else get_base(tensor) def set_result(base_index): if index is None: diff --git a/torch/_subclasses/functional_tensor.py b/torch/_subclasses/functional_tensor.py index e9cbb0759b4c10..7bfa16a295a5d2 100644 --- a/torch/_subclasses/functional_tensor.py +++ b/torch/_subclasses/functional_tensor.py @@ -1,6 +1,7 @@ # mypy: allow-untyped-defs import contextlib import warnings +import weakref from abc import ABC, abstractmethod from typing import Any, Callable, ContextManager, Dict, List, Optional, Tuple, Union @@ -111,7 +112,10 @@ class FunctionalTensor(torch.Tensor): torch.ops.aten.unsafe_chunk.default, # type: ignore[has-type] ] - def __new__(cls, elem): + # Used by auto_functionalize to determine base of tensors during inference mode. + _inference_mode_base: Optional["FunctionalTensor"] = None + + def __new__(cls, elem, mode): assert torch._is_functional_tensor(elem) # In general, we'd like our functional tensor subclass to only be in charge of functionalization, @@ -142,9 +146,9 @@ def __new__(cls, elem): cls, elem.shape, # sizes elem.stride() if not is_sparse_any(elem) else None, # strides - elem.storage_offset() - if not is_sparse_any(elem) - else None, # storage_offset + ( + elem.storage_offset() if not is_sparse_any(elem) else None + ), # storage_offset None, # memory_format elem.dtype, # dtype elem.layout, # layout @@ -158,6 +162,21 @@ def __new__(cls, elem): ) torch._C._set_throw_on_mutable_data_ptr(out) out.elem = elem + + if ( + torch.is_inference_mode_enabled() + and torch._inductor.config.enable_auto_functionalized_v2 + ): + if out.is_base_tensor(): + out._inference_mode_base = None + # This assumes that the FunctionalTensor.elem does not change its storage after this point. + # Otherwise this would be invalid. + mode._storage_to_base[out.elem.untyped_storage()] = out + else: + out._inference_mode_base = mode._storage_to_base[ + out.elem.untyped_storage() + ] + assert out._inference_mode_base is not None return out def __torch_dispatch__(self, func, types, args=(), kwargs=None): @@ -209,6 +228,7 @@ def __repr__(self): @staticmethod def to_functional(x): # We will do the wrapping for the user. + assert not torch._is_functional_tensor(x) # The only autograd metadata we care about on the FunctionalTensor is: # - requires_grad (so autograd runs) @@ -226,7 +246,7 @@ def to_functional(x): with functional_mode: torch._mirror_autograd_meta_to(x, x_functional) # type: ignore[attr-defined] - out = FunctionalTensor(x_functional) + out = FunctionalTensor(x_functional, functional_mode) torch._mirror_autograd_meta_to(x_functional, out) # type: ignore[attr-defined] return out @@ -234,6 +254,9 @@ def from_functional(self): torch._sync(self) return torch._from_functional_tensor(self.elem) + def is_base_tensor(self) -> bool: + return torch._is_functional_tensor_base(self.elem) + def replace_(self, output) -> None: torch._functionalize_replace(self.elem, output) @@ -316,6 +339,10 @@ def __init__(self, pre_dispatch=False, export=False, _allow_token_discovery=Fals # discovery. This flag distinguishes between the two stages. self._allow_token_discovery = _allow_token_discovery + self._storage_to_base: weakref.WeakKeyDictionary[ + torch.storage.UntypedStorage, Optional[FunctionalTensor] + ] = weakref.WeakKeyDictionary() + # No-op if FunctionalTensorMode is already in use def __enter__(self): def _get_prev_mode(): @@ -366,6 +393,7 @@ def __torch_dispatch__(self, func, types, args=(), kwargs=None): if not issubclass(t, torch._subclasses.FakeTensor) and t not in [torch.Tensor, FunctionalTensor] ] + if unrecognized_types: not_implemented_log.debug( "FunctionalTensor unrecognized subclass(es): %s", unrecognized_types @@ -417,16 +445,13 @@ def _can_decompose(func): if r is not NotImplemented: return r - def assert_is_functional(x): - assert torch._is_functional_tensor(x) - def wrap(x): # Only wrap our outputs in subclasses if the inner functionalization call # also wrapped outputs into FunctionalTensorWrappers. # When can this happen? e.g. `torch.div(2, 2)` assert not isinstance(x, FunctionalTensor) if isinstance(x, torch.Tensor) and torch._is_functional_tensor(x): - return FunctionalTensor(x) + return FunctionalTensor(x, self) return x def unwrap(x): diff --git a/torch/csrc/autograd/python_torch_functions_manual.cpp b/torch/csrc/autograd/python_torch_functions_manual.cpp index 1feb1f41cdd916..92890a1509e8ef 100644 --- a/torch/csrc/autograd/python_torch_functions_manual.cpp +++ b/torch/csrc/autograd/python_torch_functions_manual.cpp @@ -664,6 +664,10 @@ void initTorchFunctions(PyObject* module) { !at::functionalization::impl::isFunctionalTensor(o)); at::functionalization::impl::replace_(t, o); }); + py_module.def("_is_functional_tensor_base", [](const at::Tensor& t) { + TORCH_INTERNAL_ASSERT(at::functionalization::impl::isFunctionalTensor(t)); + return at::functionalization::impl::isBaseTensor(t); + }); py_module.def("_functionalize_is_multi_output_view", [](const at::Tensor& t) { TORCH_INTERNAL_ASSERT(at::functionalization::impl::isFunctionalTensor(t)); auto t_impl = at::functionalization::impl::unsafeGetFunctionalWrapper(t); From d297add86104b2e1abe964225773d5fe9e462750 Mon Sep 17 00:00:00 2001 From: Shivam Raikundalia Date: Fri, 6 Sep 2024 00:39:11 +0000 Subject: [PATCH 0510/1018] Add Percentages to Function Events (#135155) Summary: Users have recently asked that the profiler contains self/total CPU and device percentages to FunctionEvents so that teams can process the data procedurely. Some of it could be done mathematically via subroutines but since we already have the information in the _build_table, lets build it there. Test Plan: Check that we have the same table as before but also check that the parameters we check also have the expected values Differential Revision: D62210351 Pull Request resolved: https://github.com/pytorch/pytorch/pull/135155 Approved by: https://github.com/shanw-meta, https://github.com/kit1980 --- torch/autograd/profiler_util.py | 24 +++++++++++++++++------- 1 file changed, 17 insertions(+), 7 deletions(-) diff --git a/torch/autograd/profiler_util.py b/torch/autograd/profiler_util.py index 0cd81e4c208599..3957c044cbbca4 100644 --- a/torch/autograd/profiler_util.py +++ b/torch/autograd/profiler_util.py @@ -501,6 +501,9 @@ def __init__( self.is_legacy: bool = is_legacy self.flops: Optional[int] = flops self.is_user_annotation: Optional[bool] = is_user_annotation + self.self_cpu_percent = -1 + self.total_cpu_percent = -1 + self.total_device_percent = -1 def append_kernel(self, name, device, duration): assert self.device_type == DeviceType.CPU @@ -1017,26 +1020,33 @@ def trim_path(path, src_column_width): name = evt.key if max_name_column_width is not None and len(name) >= max_name_column_width - 3: name = name[: (max_name_column_width - 3)] + "..." + evt.self_cpu_percent = _format_time_share( + evt.self_cpu_time_total, sum_self_cpu_time_total + ) + evt.total_cpu_percent = ( + _format_time_share(evt.cpu_time_total, sum_self_cpu_time_total) + if not evt.is_async + else 0 + ) row_values = [ name, # Self CPU total %, 0 for async events. - _format_time_share(evt.self_cpu_time_total, sum_self_cpu_time_total), + evt.self_cpu_percent, evt.self_cpu_time_total_str, # Self CPU total # CPU total %, 0 for async events. - _format_time_share(evt.cpu_time_total, sum_self_cpu_time_total) - if not evt.is_async - else 0, + evt.total_cpu_percent, evt.cpu_time_total_str, # CPU total evt.cpu_time_str, # CPU time avg ] if has_device_time: + evt.total_device_percent = _format_time_share( + evt.self_device_time_total, sum_self_device_time_total + ) row_values.extend( [ evt.self_device_time_total_str, # device time total % - _format_time_share( - evt.self_device_time_total, sum_self_device_time_total - ), + evt.total_device_percent, evt.device_time_total_str, evt.device_time_str, # device time avg ] From cbd0b299eea6737646c1b247ba35a456fa87c183 Mon Sep 17 00:00:00 2001 From: Bin Bao Date: Thu, 5 Sep 2024 14:08:35 -0700 Subject: [PATCH 0511/1018] [AOTI] Support MKLDNN conv ops in cpp wrapper (#134475) Summary: Partially fix https://github.com/pytorch/pytorch/issues/123040. In the ABI-compatible mode, MKLDNN fallback ops do not have C shim implementations and thus need to go through the custom ops launch path. Other MLKDNN ops will be fixed in following PRs. Pull Request resolved: https://github.com/pytorch/pytorch/pull/134475 Approved by: https://github.com/leslie-fang-intel, https://github.com/chunyuan-w, https://github.com/angelayi --- test/inductor/test_cpu_cpp_wrapper.py | 18 ++-- torch/_inductor/codegen/cpp_wrapper_cpu.py | 110 +++++++++++++++------ torch/_inductor/mkldnn_ir.py | 6 ++ 3 files changed, 99 insertions(+), 35 deletions(-) diff --git a/test/inductor/test_cpu_cpp_wrapper.py b/test/inductor/test_cpu_cpp_wrapper.py index 0d4d7cb6a1dbe2..4f8fd4645f2acf 100644 --- a/test/inductor/test_cpu_cpp_wrapper.py +++ b/test/inductor/test_cpu_cpp_wrapper.py @@ -88,8 +88,6 @@ class DynamicShapesCppWrapperCpuTests(InductorTestCase): ) if config.abi_compatible: xfail_list = [ - "test_conv2d_binary_inplace_fusion_failed_cpu", - "test_conv2d_binary_inplace_fusion_pass_cpu", "test_dynamic_qlinear_cpu", "test_dynamic_qlinear_qat_cpu", "test_lstm_packed_change_input_sizes_cpu", @@ -209,8 +207,12 @@ class BaseTest(NamedTuple): test_mkldnn_pattern_matcher.TestPatternMatcher(), condition=torch.backends.mkldnn.is_available(), func_inputs=[ - ["op_mkldnn__convolution_pointwise_binary.call"], - ["op_mkldnn__convolution_pointwise__binary.call"], + None + if config.abi_compatible + else ["op_mkldnn__convolution_pointwise_binary.call"], + None + if config.abi_compatible + else ["op_mkldnn__convolution_pointwise__binary.call"], ], ), BaseTest( @@ -219,8 +221,12 @@ class BaseTest(NamedTuple): test_mkldnn_pattern_matcher.TestPatternMatcher(), condition=torch.backends.mkldnn.is_available(), func_inputs=[ - ["op_mkldnn__convolution_pointwise__binary.call"], - ["op_mkldnn__convolution_pointwise_binary.call"], + None + if config.abi_compatible + else ["op_mkldnn__convolution_pointwise__binary.call"], + None + if config.abi_compatible + else ["op_mkldnn__convolution_pointwise_binary.call"], ], ), BaseTest( diff --git a/torch/_inductor/codegen/cpp_wrapper_cpu.py b/torch/_inductor/codegen/cpp_wrapper_cpu.py index af2bd731f5c324..86af726f58b3e7 100644 --- a/torch/_inductor/codegen/cpp_wrapper_cpu.py +++ b/torch/_inductor/codegen/cpp_wrapper_cpu.py @@ -2126,8 +2126,10 @@ def generate_extern_kernel_alloc_and_find_schema_if_needed( self.allow_stack_allocation = False def extract_output_name(out): - assert out is not None, "None, i.e. optional output is not supported" - if isinstance(out, (ir.MultiOutput, ir._CollectiveKernel)): + if out is None: + # Because out is not a MultiOutput, we assume the kernel returns a single output + return [buf_name] + elif isinstance(out, (ir.MultiOutput, ir._CollectiveKernel)): return out.get_name() elif isinstance(out, (list, tuple)): return type(out)(extract_output_name(o) for o in out) @@ -2202,10 +2204,24 @@ def load_custom_op_wrapper(self): self.custom_op_wrapper_loaded = True def generate_py_arg(self, py_args_var, idx, raw_arg, arg_type): - def generate_py_arg_inner(raw_arg, arg_type): - if isinstance(arg_type, torch.TensorType): + def generate_py_arg_inner(lines, raw_arg, arg_type): + if raw_arg is None: + # Py_None is a singleton, so we have to explicitly incref it here + lines.append("Py_INCREF(Py_None);\n") + return "Py_None" + elif isinstance(arg_type, torch.TensorType): # Store AtenTensorHandle as void* - return f"PyCapsule_New(reinterpret_cast({raw_arg.codegen_reference()}.get()), NULL, NULL)" + base_handle = raw_arg.codegen_reference() + ( + tmp_raii_handle_var, + tmp_raii_handle_var_decl, + ) = self.create_tmp_raii_handle_var(base_handle) + if tmp_raii_handle_var: + lines.append(tmp_raii_handle_var_decl) + base_handle = tmp_raii_handle_var + return f"PyCapsule_New(reinterpret_cast({base_handle}.get()), NULL, NULL)" + elif isinstance(arg_type, torch.OptionalType): + return generate_py_arg_inner(lines, raw_arg, arg_type.getElementType()) elif isinstance(arg_type, torch.IntType): # int return f"PyLong_FromLongLong({raw_arg})" @@ -2244,16 +2260,24 @@ def generate_py_arg_inner(raw_arg, arg_type): f"arg type {arg_type} is not yet supported by custom_op_wrapper" ) - lines = "" + lines = [] if isinstance(arg_type, torch.ListType): assert isinstance(raw_arg, (list, tuple)), str(raw_arg) + " is not a list" - lines += f"PyObject* {py_args_var}_{idx} = PyList_New({len(raw_arg)});\n" + lines.append( + f"PyObject* {py_args_var}_{idx} = PyList_New({len(raw_arg)});\n" + ) for i, elem in enumerate(raw_arg): - lines += f"PyList_SetItem({py_args_var}_{idx}, {i}, {generate_py_arg_inner(elem, arg_type.getElementType())});\n" - lines += f"PyTuple_SetItem({py_args_var}, {idx}, {py_args_var}_{idx});\n" + lines.append( + f"PyList_SetItem({py_args_var}_{idx}, {i}, {generate_py_arg_inner(lines, elem, arg_type.getElementType())});\n" + ) + lines.append( + f"PyTuple_SetItem({py_args_var}, {idx}, {py_args_var}_{idx});\n" + ) else: - lines += f"PyTuple_SetItem({py_args_var}, {idx}, {generate_py_arg_inner(raw_arg, arg_type)});\n" - return lines + lines.append( + f"PyTuple_SetItem({py_args_var}, {idx}, {generate_py_arg_inner(lines, raw_arg, arg_type)});\n" + ) + return "".join(lines) def generate_extern_kernel_alloc_and_find_schema_if_needed_jit( self, @@ -2302,6 +2326,7 @@ def generate_extern_kernel_alloc_and_find_schema_if_needed_jit( """ assert op_overload is not None, "op_overload should not be None" + for idx, (raw_arg, schema_arg) in enumerate( zip(raw_args, op_overload._schema.arguments) ): @@ -2372,12 +2397,12 @@ def generate_reset_kernel_saved_flags(self): def generate_save_uncompiled_kernels(self): pass - def c_type_for_prim_type(self, type_) -> str: + def c_type_for_prim_type(self, val, type_) -> str: assert ( config.abi_compatible ), "c_type_for_prim_type is only used in ABI compatible mode" if isinstance(type_, torch.OptionalType): - return f"{self.c_type_for_prim_type(type_.getElementType())}*" + return f"{self.c_type_for_prim_type(val, type_.getElementType())}*" elif isinstance(type_, torch.TensorType): return "AtenTensorHandle" elif isinstance(type_, (torch.IntType, torch.SymIntType)): @@ -2388,6 +2413,22 @@ def c_type_for_prim_type(self, type_) -> str: return "int32_t" elif isinstance(type_, torch.FloatType): return "double" + elif isinstance(type_, torch.NumberType): + if isinstance(val, bool): + return "int32_t" + elif isinstance(val, int): + return "int64_t" + elif isinstance(val, float): + return "double" + elif val is None: + # This could happen when val is an optional value + return "double" + else: + raise AssertionError( + f"Unexpected type in c_type_for_prim_type: {type_=}" + ) + elif isinstance(type_, torch.StringType): + return "const char*" else: raise AssertionError(f"Unexpected type in c_type_for_prim_type: {type_=}") @@ -2476,7 +2517,7 @@ def val_to_arg_str(self, val, type_=None) -> str: return f"&{var_name}, {aux}" else: self.writeline( - f"{self.c_type_for_prim_type(element_type)} {var_name} = {self.val_to_arg_str(val, element_type)};" + f"{self.c_type_for_prim_type(val, element_type)} {var_name} = {self.val_to_arg_str(val, element_type)};" ) return f"&{var_name}" else: @@ -2487,19 +2528,13 @@ def val_to_arg_str(self, val, type_=None) -> str: base_handle = ( f"convert_arrayref_tensor_to_tensor({base_handle})" ) - if base_handle.startswith( - ( - "convert_arrayref_tensor_to_tensor", - "wrap_with_raii_handle_if_needed", - ) - ): - # wrap_with_raii_handle_if_needed creates a temp RAIIAtenTensorHandle, so we need to - # explicitly store it. Otherwise, it will be destroyed before the fallback kernel call. - tmp_var_name = f"var_{next(self.arg_var_id)}" - self.writeline( - f"RAIIAtenTensorHandle {tmp_var_name} = {base_handle};" - ) - base_handle = tmp_var_name + ( + tmp_raii_handle_var, + tmp_raii_handle_var_decl, + ) = self.create_tmp_raii_handle_var(base_handle) + if tmp_raii_handle_var: + self.writeline(tmp_raii_handle_var_decl) + base_handle = tmp_raii_handle_var var_name = f"var_{next(self.arg_var_id)}" self.writeline( f"AtenTensorHandle {var_name} = {base_handle}.get();" @@ -2519,12 +2554,12 @@ def val_to_arg_str(self, val, type_=None) -> str: # Zero-size array is not supported in the C or C++ standard, so # we declare a null pointer for it. self.writeline( - f"const {self.c_type_for_prim_type(element_type)}* {var_name} = nullptr;" + f"const {self.c_type_for_prim_type(None, element_type)}* {var_name} = nullptr;" ) else: result = f"{{{', '.join(self.val_to_arg_str(x, element_type) for x in val)}}}" self.writeline( - f"const {self.c_type_for_prim_type(element_type)} {var_name}[] = {result};" + f"const {self.c_type_for_prim_type(val[0], element_type)} {var_name}[] = {result};" ) # Need to pass the array length because we can't use std::vector return f"{var_name}, {len(val)}" @@ -2532,3 +2567,20 @@ def val_to_arg_str(self, val, type_=None) -> str: return f"{{{', '.join(self.val_to_arg_str(x, element_type) for x in val)}}}" return self.val_to_arg_str_for_prim_type(val, type_) + + def create_tmp_raii_handle_var(self, base_handle): + if base_handle.startswith( + ( + "convert_arrayref_tensor_to_tensor", + "wrap_with_raii_handle_if_needed", + ) + ): + # wrap_with_raii_handle_if_needed creates a temp RAIIAtenTensorHandle, so we need to + # explicitly store it. Otherwise, it will be destroyed before the fallback kernel call. + tmp_var_name = f"var_{next(self.arg_var_id)}" + return ( + tmp_var_name, + f"RAIIAtenTensorHandle {tmp_var_name} = {base_handle};\n", + ) + else: + return "", "" diff --git a/torch/_inductor/mkldnn_ir.py b/torch/_inductor/mkldnn_ir.py index eeab8fcc76c652..e5bd49625e3124 100644 --- a/torch/_inductor/mkldnn_ir.py +++ b/torch/_inductor/mkldnn_ir.py @@ -261,6 +261,8 @@ def codegen(self, wrapper): self.codegen_args(), self.cpp_op_schema, self.cpp_kernel_key, + op_overload=self.op_overload, + raw_args=[*self.inputs, *self.constant_args], ) if isinstance(self.layout, Layout): self.codegen_size_asserts(wrapper) @@ -335,6 +337,8 @@ def codegen(self, wrapper): self.cpp_op_schema, self.cpp_kernel_key, self.cpp_kernel_overload_name, + self.op_overload, + [*self.inputs, *self.constant_args], ) if isinstance(self.layout, Layout): self.codegen_size_asserts(wrapper) @@ -428,6 +432,8 @@ def codegen(self, wrapper): self.cpp_op_schema, self.cpp_kernel_key, self.cpp_kernel_overload_name, + self.op_overload, + [*self.inputs, *self.constant_args], ) def get_unbacked_symbol_defs(self) -> OrderedSet[sympy.Symbol]: From 5a878989b3a17eab09c80ced8f3fb26423901aa2 Mon Sep 17 00:00:00 2001 From: Bin Bao Date: Thu, 5 Sep 2024 14:08:35 -0700 Subject: [PATCH 0512/1018] [AOTI] Support MKLDNN qlinear ops in cpp wrapper (#134783) Summary: Similar to https://github.com/pytorch/pytorch/pull/134475, support qlinear in the ABI-compatible mode for cpp-wrapper Inductor. Pull Request resolved: https://github.com/pytorch/pytorch/pull/134783 Approved by: https://github.com/leslie-fang-intel, https://github.com/chunyuan-w, https://github.com/angelayi ghstack dependencies: #134475 --- test/inductor/test_cpu_cpp_wrapper.py | 14 +-- test/inductor/test_mkldnn_pattern_matcher.py | 7 +- torch/_inductor/codegen/cpp_wrapper_cpu.py | 9 ++ torch/_inductor/mkldnn_ir.py | 99 ++++++++++++++++++++ 4 files changed, 118 insertions(+), 11 deletions(-) diff --git a/test/inductor/test_cpu_cpp_wrapper.py b/test/inductor/test_cpu_cpp_wrapper.py index 4f8fd4645f2acf..2020c5d98837bb 100644 --- a/test/inductor/test_cpu_cpp_wrapper.py +++ b/test/inductor/test_cpu_cpp_wrapper.py @@ -88,8 +88,6 @@ class DynamicShapesCppWrapperCpuTests(InductorTestCase): ) if config.abi_compatible: xfail_list = [ - "test_dynamic_qlinear_cpu", - "test_dynamic_qlinear_qat_cpu", "test_lstm_packed_change_input_sizes_cpu", "test_qconv2d_add_cpu", "test_qconv2d_add_relu_cpu", @@ -97,12 +95,6 @@ class DynamicShapesCppWrapperCpuTests(InductorTestCase): "test_qconv2d_dequant_promotion_cpu", "test_qconv2d_maxpool2d_linear_dynamic_cpu", "test_qconv2d_relu_cpu", - "test_qlinear_cpu", - "test_qlinear_add_cpu", - "test_qlinear_add_relu_cpu", - "test_qlinear_dequant_promotion_cpu", - "test_qlinear_gelu_cpu", - "test_qlinear_relu_cpu", *[ func for func in dir(test_cpu_select_algorithm.TestSelectAlgorithmCPU()) @@ -325,11 +317,13 @@ class BaseTest(NamedTuple): test_mkldnn_pattern_matcher.TestDynamicPatternMatcher(), condition=torch.backends.mkldnn.is_available() and not IS_WINDOWS, func_inputs=[ - [ + None + if config.abi_compatible + else [ "op_onednn_qconv2d_pointwise_.call", "op_quantized_max_pool2d_.call", "op_onednn_qlinear_pointwise_tensor.call", - ] + ], ], ), BaseTest( diff --git a/test/inductor/test_mkldnn_pattern_matcher.py b/test/inductor/test_mkldnn_pattern_matcher.py index 3a98efae39b02c..418f743c650108 100644 --- a/test/inductor/test_mkldnn_pattern_matcher.py +++ b/test/inductor/test_mkldnn_pattern_matcher.py @@ -1802,12 +1802,17 @@ def matcher_check_fn(): mod, (v,), [ + "torch.ops.onednn.qlinear_pointwise.tensor", + "torch.ops.onednn.qlinear_pointwise.binary", + ] + if config.abi_compatible + else [ "op_onednn_qlinear_pointwise_tensor.call", "op_onednn_qlinear_pointwise_binary_tensor.call", ], [], check_quantization=True, - num_include_ops=[2, 2], + num_include_ops=[4, 4] if config.abi_compatible else [2, 2], ) else: # For python wrapper diff --git a/torch/_inductor/codegen/cpp_wrapper_cpu.py b/torch/_inductor/codegen/cpp_wrapper_cpu.py index 86af726f58b3e7..290787b99b9930 100644 --- a/torch/_inductor/codegen/cpp_wrapper_cpu.py +++ b/torch/_inductor/codegen/cpp_wrapper_cpu.py @@ -247,6 +247,11 @@ class RAIIPyObject { """ ) + @functools.lru_cache(None) # noqa: B019 + def include_extra_header(self, header: str): + # This is needed for cpp to python dtype conversion + self.header.splice(f"#include <{header}>") + def mark_output_type(self): # mark output type to unwrap tensor back to python scalar from ..ir import ShapeAsConstantBuffer @@ -2255,6 +2260,10 @@ def generate_py_arg_inner(lines, raw_arg, arg_type): raise NotImplementedError( f"arg type {arg_type} with raw_arg {raw_arg}, {type(raw_arg)} is not yet supported by custom_op_wrapper" ) + elif isinstance(raw_arg, torch.dtype): + # dtype + self.include_extra_header("torch/csrc/DynamicTypes.h") + return f"Py_NewRef(torch::getTHPDtype(static_cast({self.codegen_dtype(raw_arg)})))" else: raise NotImplementedError( f"arg type {arg_type} is not yet supported by custom_op_wrapper" diff --git a/torch/_inductor/mkldnn_ir.py b/torch/_inductor/mkldnn_ir.py index e5bd49625e3124..8562fae4bcf2a9 100644 --- a/torch/_inductor/mkldnn_ir.py +++ b/torch/_inductor/mkldnn_ir.py @@ -1233,17 +1233,23 @@ def __init__( def codegen(self, wrapper): # Parser the inputs and constant + # The raw_args setup can be skipped if there is a C shim implementation args = [x.codegen_reference() for x in self.inputs] const_args = [] const_args.extend(self.codegen_const_args()) x = args[0] + x_raw = self.inputs[0] packed_weight = args[1] + packed_weight_raw = self.inputs[1] bias = args[2] if self.has_bias else const_args[0] + bias_raw = self.inputs[2] if self.has_bias else self.constant_args[0] w_scale, w_zp = args[-2], args[-1] + w_scale_raw, w_zp_raw = self.inputs[-2], self.inputs[-1] if self.x_scale_zp_are_tensors: assert len(args) >= 4 x_scale, x_zp = args[-4], args[-3] + x_scale_raw, x_zp_raw = self.inputs[-4], self.inputs[-3] ( o_scale, o_zp, @@ -1252,6 +1258,14 @@ def codegen(self, wrapper): unary_scalars, unary_algorithm, ) = const_args[-6:] + ( + o_scale_raw, + o_zp_raw, + output_dtype_raw, + unary_attr_raw, + unary_scalars_raw, + unary_algorithm_raw, + ) = self.constant_args[-6:] else: assert len(const_args) >= 8 ( @@ -1264,6 +1278,16 @@ def codegen(self, wrapper): unary_scalars, unary_algorithm, ) = const_args[-8:] + ( + x_scale_raw, + x_zp_raw, + o_scale_raw, + o_zp_raw, + output_dtype_raw, + unary_attr_raw, + unary_scalars_raw, + unary_algorithm_raw, + ) = self.constant_args[-8:] codegen_args = ( x, @@ -1280,6 +1304,21 @@ def codegen(self, wrapper): unary_scalars, unary_algorithm, ) + raw_args = ( + x_raw, + x_scale_raw, + x_zp_raw, + packed_weight_raw, + w_scale_raw, + w_zp_raw, + bias_raw, + o_scale_raw, + o_zp_raw, + output_dtype_raw, + unary_attr_raw, + unary_scalars_raw, + unary_algorithm_raw, + ) wrapper.generate_extern_kernel_alloc_and_find_schema_if_needed( self.get_name(), self.python_kernel_name, @@ -1288,6 +1327,8 @@ def codegen(self, wrapper): self.cpp_op_schema, self.cpp_kernel_key, self.cpp_kernel_overload_name, + self.op_overload, + raw_args, ) if isinstance(self.layout, Layout): self.codegen_size_asserts(wrapper) @@ -1410,17 +1451,27 @@ def __init__( def codegen(self, wrapper): # Parser the inputs and constant + # The raw_args setup can be skipped if there is a C shim implementation args = [x.codegen_reference() for x in self.inputs] const_args = [] const_args.extend(self.codegen_const_args()) x = args[0] + x_raw = self.inputs[0] packed_weight = args[1] + packed_weight_raw = self.inputs[1] bias = args[2] if self.has_bias else const_args[0] + bias_raw = self.inputs[2] if self.has_bias else self.constant_args[0] w_scale, w_zp, other = args[-3], args[-2], args[-1] + w_scale_raw, w_zp_raw, other_raw = ( + self.inputs[-3], + self.inputs[-2], + self.inputs[-1], + ) if self.x_scale_zp_are_tensors: assert len(args) >= 5 x_scale, x_zp = args[-5], args[-4] + x_scale_raw, x_zp_raw = self.inputs[-5], self.inputs[-4] ( o_scale, o_zp, @@ -1433,6 +1484,18 @@ def codegen(self, wrapper): unary_scalars, unary_algorithm, ) = const_args[-10:] + ( + o_scale_raw, + o_zp_raw, + output_dtype_raw, + other_scale_raw, + other_zp_raw, + binary_attr_raw, + alpha_raw, + unary_attr_raw, + unary_scalars_raw, + unary_algorithm_raw, + ) = self.constant_args[-10:] else: assert len(const_args) >= 8 ( @@ -1449,6 +1512,20 @@ def codegen(self, wrapper): unary_scalars, unary_algorithm, ) = const_args[-12:] + ( + x_scale_raw, + x_zp_raw, + o_scale_raw, + o_zp_raw, + output_dtype_raw, + other_scale_raw, + other_zp_raw, + binary_attr_raw, + alpha_raw, + unary_attr_raw, + unary_scalars_raw, + unary_algorithm_raw, + ) = self.constant_args[-12:] codegen_args = ( x, @@ -1470,6 +1547,26 @@ def codegen(self, wrapper): unary_scalars, unary_algorithm, ) + raw_args = ( + x_raw, + x_scale_raw, + x_zp_raw, + packed_weight_raw, + w_scale_raw, + w_zp_raw, + other_raw, + bias_raw, + o_scale_raw, + o_zp_raw, + output_dtype_raw, + other_scale_raw, + other_zp_raw, + binary_attr_raw, + alpha_raw, + unary_attr_raw, + unary_scalars_raw, + unary_algorithm_raw, + ) wrapper.generate_extern_kernel_alloc_and_find_schema_if_needed( self.get_name(), self.python_kernel_name, @@ -1478,6 +1575,8 @@ def codegen(self, wrapper): self.cpp_op_schema, self.cpp_kernel_key, self.cpp_kernel_overload_name, + self.op_overload, + raw_args, ) if isinstance(self.layout, Layout): self.codegen_size_asserts(wrapper) From 3b3b9a6e4dccc5c0e8d7d36afefdf526693b3e3b Mon Sep 17 00:00:00 2001 From: Bin Bao Date: Thu, 5 Sep 2024 14:08:36 -0700 Subject: [PATCH 0513/1018] [AOTI] Support MKLDNN qconv ops in cpp wrapper (#134795) Summary: Similar to https://github.com/pytorch/pytorch/pull/134475, support qconv in the ABI-compatible mode for cpp-wrapper Inductor. Pull Request resolved: https://github.com/pytorch/pytorch/pull/134795 Approved by: https://github.com/leslie-fang-intel, https://github.com/chunyuan-w, https://github.com/angelayi ghstack dependencies: #134475, #134783 --- test/inductor/test_cpu_cpp_wrapper.py | 6 -- torch/_inductor/mkldnn_ir.py | 98 ++++++++++++++++++++++++++- 2 files changed, 97 insertions(+), 7 deletions(-) diff --git a/test/inductor/test_cpu_cpp_wrapper.py b/test/inductor/test_cpu_cpp_wrapper.py index 2020c5d98837bb..4d18f64ffd6255 100644 --- a/test/inductor/test_cpu_cpp_wrapper.py +++ b/test/inductor/test_cpu_cpp_wrapper.py @@ -89,12 +89,6 @@ class DynamicShapesCppWrapperCpuTests(InductorTestCase): if config.abi_compatible: xfail_list = [ "test_lstm_packed_change_input_sizes_cpu", - "test_qconv2d_add_cpu", - "test_qconv2d_add_relu_cpu", - "test_qconv2d_cpu", - "test_qconv2d_dequant_promotion_cpu", - "test_qconv2d_maxpool2d_linear_dynamic_cpu", - "test_qconv2d_relu_cpu", *[ func for func in dir(test_cpu_select_algorithm.TestSelectAlgorithmCPU()) diff --git a/torch/_inductor/mkldnn_ir.py b/torch/_inductor/mkldnn_ir.py index 8562fae4bcf2a9..12634c632b80a9 100644 --- a/torch/_inductor/mkldnn_ir.py +++ b/torch/_inductor/mkldnn_ir.py @@ -614,6 +614,7 @@ def __init__( def codegen(self, wrapper): # Parser the inputs and constant + # The raw_args setup can be skipped if there is a C shim implementation args = [x.codegen_reference() for x in self.inputs] const_arg_names = [ "x_scale", @@ -634,13 +635,21 @@ def codegen(self, wrapper): const_args = list(self.codegen_const_args(const_arg_names)) x = args[0] + x_raw = self.inputs[0] packed_weight = args[1] + packed_weight_raw = self.inputs[1] bias = args[2] if self.has_bias else const_args[2] + bias_raw = self.inputs[2] if self.has_bias else self.constant_args[2] w_scale, w_zp = args[-2], args[-1] + w_scale_raw, w_zp_raw = self.inputs[-2], self.inputs[-1] ( x_scale, x_zp, ) = const_args[:2] + ( + x_scale_raw, + x_zp_raw, + ) = self.constant_args[:2] ( stride, padding, @@ -653,7 +662,18 @@ def codegen(self, wrapper): unary_scalars, unary_algorithm, ) = const_args[-10:] - + ( + stride_raw, + padding_raw, + dilation_raw, + groups_raw, + o_scale_raw, + o_zp_raw, + output_dtype_raw, + unary_attr_raw, + unary_scalars_raw, + unary_algorithm_raw, + ) = self.constant_args[-10:] codegen_args = ( x, x_scale, @@ -673,6 +693,25 @@ def codegen(self, wrapper): unary_scalars, unary_algorithm, ) + raw_args = ( + x_raw, + x_scale_raw, + x_zp_raw, + packed_weight_raw, + w_scale_raw, + w_zp_raw, + bias_raw, + stride_raw, + padding_raw, + dilation_raw, + groups_raw, + o_scale_raw, + o_zp_raw, + output_dtype_raw, + unary_attr_raw, + unary_scalars_raw, + unary_algorithm_raw, + ) wrapper.generate_extern_kernel_alloc_and_find_schema_if_needed( self.get_name(), self.python_kernel_name, @@ -680,6 +719,8 @@ def codegen(self, wrapper): codegen_args, self.cpp_op_schema, self.cpp_kernel_key, + op_overload=self.op_overload, + raw_args=raw_args, ) if isinstance(self.layout, Layout): self.codegen_size_asserts(wrapper) @@ -812,6 +853,7 @@ def __init__( def codegen(self, wrapper): # Parser the inputs and constant + # The raw_args setup can be skipped if there is a C shim implementation args = [x.codegen_reference() for x in self.inputs] const_arg_names = [ "x_scale", @@ -836,15 +878,29 @@ def codegen(self, wrapper): const_args = list(self.codegen_const_args(const_arg_names)) x = args[0] + x_raw = self.inputs[0] packed_weight = args[1] + packed_weight_raw = self.inputs[1] bias = args[2] if self.has_bias else const_args[4] + bias_raw = self.inputs[2] if self.has_bias else self.constant_args[4] accum, w_scale, w_zp = args[-3], args[-2], args[-1] + accum_raw, w_scale_raw, w_zp_raw = ( + self.inputs[-3], + self.inputs[-2], + self.inputs[-1], + ) ( x_scale, x_zp, accum_scale, accum_zp, ) = const_args[:4] + ( + x_scale_raw, + x_zp_raw, + accum_scale_raw, + accum_zp_raw, + ) = self.constant_args[:4] ( stride, padding, @@ -859,6 +915,20 @@ def codegen(self, wrapper): unary_scalars, unary_algorithm, ) = const_args[-12:] + ( + stride_raw, + padding_raw, + dilation_raw, + groups_raw, + o_scale_raw, + o_zp_raw, + output_dtype_raw, + binary_attr_raw, + alpha_raw, + unary_attr_raw, + unary_scalars_raw, + unary_algorithm_raw, + ) = self.constant_args[-12:] conv_args = ( x, x_scale, @@ -883,6 +953,30 @@ def codegen(self, wrapper): unary_scalars, unary_algorithm, ) + raw_args = ( + x_raw, + x_scale_raw, + x_zp_raw, + accum_raw, + accum_scale_raw, + accum_zp_raw, + packed_weight_raw, + w_scale_raw, + w_zp_raw, + bias_raw, + stride_raw, + padding_raw, + dilation_raw, + groups_raw, + o_scale_raw, + o_zp_raw, + output_dtype_raw, + binary_attr_raw, + alpha_raw, + unary_attr_raw, + unary_scalars_raw, + unary_algorithm_raw, + ) wrapper.generate_extern_kernel_alloc_and_find_schema_if_needed( self.get_name(), self.python_kernel_name, @@ -891,6 +985,8 @@ def codegen(self, wrapper): self.cpp_op_schema, self.cpp_kernel_key, self.cpp_kernel_overload_name, + op_overload=self.op_overload, + raw_args=raw_args, ) if isinstance(self.layout, Layout): self.codegen_size_asserts(wrapper) From 26647f27dfcbb87930dc51c433836b02aea95469 Mon Sep 17 00:00:00 2001 From: titaiwangms Date: Fri, 6 Sep 2024 01:29:54 +0000 Subject: [PATCH 0514/1018] [ONNX] Enable experimental exporter logic to dynamo_export and support refine dynamic_shapes (#134976) (1) Enable experimental exporter logic to dynamo_export (2) Refine dynamic shapes and retry export in export strategies (3) Delete `torch_export_graph_extractor` and use the new export logic (4) Disable ExportedProgram test in `test_fx_onnx_with_onnxruntime.py`, as ONNXProgram is different now. Fixes https://github.com/pytorch/pytorch/issues/126479 Fixes #135183 Pull Request resolved: https://github.com/pytorch/pytorch/pull/134976 Approved by: https://github.com/justinchuby --- test/onnx/exporter/test_api.py | 48 ++++++ test/onnx/pytorch_test_common.py | 22 +++ test/onnx/test_fx_to_onnx.py | 44 +---- test/onnx/test_fx_to_onnx_decomp_skip.py | 15 -- test/onnx/test_fx_to_onnx_with_onnxruntime.py | 67 +------- .../test_torch_export_with_onnxruntime.py | 13 +- torch/onnx/__init__.py | 129 ++++++++++++++- torch/onnx/_internal/_exporter_legacy.py | 154 ++---------------- .../_internal/exporter/_capture_strategies.py | 34 +++- torch/onnx/_internal/exporter/_compat.py | 1 - torch/onnx/_internal/fx/_pass.py | 2 +- .../fx/torch_export_graph_extractor.py | 128 --------------- 12 files changed, 246 insertions(+), 411 deletions(-) delete mode 100644 torch/onnx/_internal/fx/torch_export_graph_extractor.py diff --git a/test/onnx/exporter/test_api.py b/test/onnx/exporter/test_api.py index 6f2a1f942313dd..18330755510454 100644 --- a/test/onnx/exporter/test_api.py +++ b/test/onnx/exporter/test_api.py @@ -148,6 +148,54 @@ def test_partial_dynamic_shapes(self): }, ) + def test_auto_convert_all_axes_to_dynamic_shapes_with_dynamo_export(self): + os.environ["TORCH_ONNX_USE_EXPERIMENTAL_LOGIC"] = "1" + assert os.environ.get("TORCH_ONNX_USE_EXPERIMENTAL_LOGIC") == "1" + + class Nested(torch.nn.Module): + def forward(self, x): + (a0, a1), (b0, b1), (c0, c1, c2) = x + return a0 + a1 + b0 + b1 + c0 + c1 + c2 + + inputs = ( + (1, 2), + ( + torch.randn(4, 4), + torch.randn(4, 4), + ), + ( + torch.randn(4, 4), + torch.randn(4, 4), + torch.randn(4, 4), + ), + ) + + onnx_program = torch.onnx.dynamo_export( + Nested(), + inputs, + export_options=torch.onnx.ExportOptions(dynamic_shapes=True), + ) + assert onnx_program is not None + onnx_testing.assert_onnx_program(onnx_program) + + def test_refine_dynamic_shapes_with_onnx_export(self): + # NOTE: From test/export/test_export.py + + # refine lower, upper bound + class TestRefineDynamicShapeModel(torch.nn.Module): + def forward(self, x, y): + if x.shape[0] >= 6 and y.shape[0] <= 16: + return x * 2.0, y + 1 + + inps = (torch.randn(16), torch.randn(12)) + dynamic_shapes = { + "x": (torch.export.Dim("dx"),), + "y": (torch.export.Dim("dy"),), + } + self.assert_export( + TestRefineDynamicShapeModel(), inps, dynamic_shapes=dynamic_shapes + ) + if __name__ == "__main__": common_utils.run_tests() diff --git a/test/onnx/pytorch_test_common.py b/test/onnx/pytorch_test_common.py index 408168a9c711eb..6c90290a3e96d2 100644 --- a/test/onnx/pytorch_test_common.py +++ b/test/onnx/pytorch_test_common.py @@ -346,6 +346,28 @@ def wrapper(self, *args, **kwargs): return wrapper +def skip_if_fake_model_and_inititalizer(reason: Optional[str] = None): + """skip test with models using ExportedProgram as input. + + Args: + reason: The reason for skip the ONNX export test. + + Returns: + A decorator for skip tests. + """ + + def skip_dec(func): + @functools.wraps(func) + def wrapper(self, *args, **kwargs): + if kwargs["use_fake_mode"] and kwargs["include_initializer"]: + return unittest.SkipTest(reason) + return func(self, *args, **kwargs) + + return wrapper + + return skip_dec + + def xfail_if_model_type_is_exportedprogram( error_message: str, reason: Optional[str] = None ): diff --git a/test/onnx/test_fx_to_onnx.py b/test/onnx/test_fx_to_onnx.py index 5de711f181f4d3..460d9996cb9300 100644 --- a/test/onnx/test_fx_to_onnx.py +++ b/test/onnx/test_fx_to_onnx.py @@ -544,34 +544,6 @@ def test_fake_tensor_mode_huggingface_tiiuae_falcon(self): onnx.checker.check_model(onnx_program.model_proto) onnx.shape_inference.infer_shapes(onnx_program.model_proto) - def test_exported_program_input_with_custom_fx_tracer(self): - from torch.onnx._internal import _exporter_legacy - from torch.onnx._internal.fx import dynamo_graph_extractor - - class Model(torch.nn.Module): - def forward(self, x): - return x + 1 - - x = torch.randn(1, 1, 2) - exported_program = torch.export.export(Model(), args=(x,)) - - export_options = torch.onnx.ExportOptions() - export_options = _exporter_legacy.ResolvedExportOptions( - export_options, model=exported_program - ) - export_options.fx_tracer = ( - dynamo_graph_extractor.DynamoExport() - ) # Override fx_tracer to an unsupported tracer - with self.assertRaises(torch.onnx.OnnxExporterError): - onnx_program = torch.onnx.dynamo_export( - exported_program, - x, - export_options=export_options, - ) - self.assertTrue(onnx_program._export_exception is not None) - with self.assertRaises(torch.onnx.InvalidExportOptionsError): - raise self._export_exception - def test_exported_program_torch_distributions_normal_Normal(self): class Model(torch.nn.Module): def __init__(self) -> None: @@ -606,21 +578,6 @@ def forward(self, input): # with no Cast node in between. self.assertEqual(div_node.input[0], model_proto.graph.input[0].name) - def test_exported_program_as_input_with_model_signature(self): - class Model(torch.nn.Module): - def forward(self, x): - return x + 1.0 - - x = torch.randn(1, 1, 2, dtype=torch.float) - exported_program = torch.export.export(Model(), args=(x,)) - - onnx_program = torch.onnx.dynamo_export( - exported_program, - x, - ) - - self.assertTrue(onnx_program.model_signature, torch.export.ExportGraphSignature) - @common_utils.parametrize( "float8_type", [ @@ -707,6 +664,7 @@ def test_checkpoint_cast(self): onnx_program.save(tmp_onnx_file.name) onnx.checker.check_model(tmp_onnx_file.name, full_check=True) + @pytorch_test_common.skip_if_fake_model_and_inititalizer("segfault") @common_utils.parametrize( "include_initializer", [ diff --git a/test/onnx/test_fx_to_onnx_decomp_skip.py b/test/onnx/test_fx_to_onnx_decomp_skip.py index db8edce1425949..466ee4a0bb956f 100644 --- a/test/onnx/test_fx_to_onnx_decomp_skip.py +++ b/test/onnx/test_fx_to_onnx_decomp_skip.py @@ -19,12 +19,6 @@ def assert_op_in_onnx_model(model: onnx.ModelProto, op_type: str): class TestDynamoExportDecompSkip(pytorch_test_common.ExportTestCase): - def _test_exported_program_forces_decomposition(self, model, input, op_type): - ep = torch.export.export(model, input) - onnx_program = torch.onnx.dynamo_export(ep, *input) - with self.assertRaises(AssertionError): - assert_op_in_onnx_model(onnx_program.model_proto, op_type) - def test_upsample_bilinear2d(self): class TestModel(torch.nn.Module): def __init__(self) -> None: @@ -37,9 +31,6 @@ def forward(self, x): onnx_program = torch.onnx.dynamo_export(TestModel(), torch.randn(1, 1, 2, 2)) # If decomposition is skipped, the model will contain a Resize op instead of fine grained subgraph. assert_op_in_onnx_model(onnx_program.model_proto, "Resize") - self._test_exported_program_forces_decomposition( - TestModel(), (torch.randn(1, 1, 2, 2),), "Resize" - ) def test_upsample_bilinear2d_output_size(self): def func(x: torch.Tensor): @@ -61,9 +52,6 @@ def forward(self, x): onnx_program = torch.onnx.dynamo_export(TestModel(), torch.randn(1, 1, 2, 2, 3)) # If decomposition is skipped, the model will contain a Resize op instead of fine grained subgraph. assert_op_in_onnx_model(onnx_program.model_proto, "Resize") - self._test_exported_program_forces_decomposition( - TestModel(), (torch.randn(1, 1, 2, 2, 3),), "Resize" - ) def test_upsample_trilinear3d_output_size(self): def func(x: torch.Tensor): @@ -82,9 +70,6 @@ def forward(self, x): # If decomposition is skipped, the model will contain an InstanceNormalization op # instead of BatchNormalization op w/ training=True. assert_op_in_onnx_model(onnx_program.model_proto, "InstanceNormalization") - self._test_exported_program_forces_decomposition( - TestModel(), (torch.randn(1, 1, 2, 2),), "InstanceNormalization" - ) if __name__ == "__main__": diff --git a/test/onnx/test_fx_to_onnx_with_onnxruntime.py b/test/onnx/test_fx_to_onnx_with_onnxruntime.py index dd53ded23cdbe3..ff4d3a91bd1afa 100644 --- a/test/onnx/test_fx_to_onnx_with_onnxruntime.py +++ b/test/onnx/test_fx_to_onnx_with_onnxruntime.py @@ -45,10 +45,7 @@ def _parameterized_class_attrs_and_values(): input_values.extend( itertools.product( (True, False), - ( - pytorch_test_common.TorchModelType.TORCH_NN_MODULE, - pytorch_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM, - ), + (pytorch_test_common.TorchModelType.TORCH_NN_MODULE,), ) ) return { @@ -912,10 +909,7 @@ def _parameterized_class_attrs_and_values_with_fake_options(): (True, False), (True, False), (True, False), - ( - pytorch_test_common.TorchModelType.TORCH_NN_MODULE, - pytorch_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM, - ), + (pytorch_test_common.TorchModelType.TORCH_NN_MODULE,), ) ) return { @@ -986,13 +980,6 @@ def _test_fake_tensor_mode_exporter( # Create the toy model with real weight. real_model = create_model() state_dict = real_model.state_dict() # concrete (non-fake) state_dict - if ( - model_type - == pytorch_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM - ): - real_model = torch.export.export( - real_model, args=create_args(), kwargs=create_kwargs() - ) with tempfile.NamedTemporaryFile( prefix=model_name, suffix=".pt" @@ -1015,13 +1002,6 @@ def _test_fake_tensor_mode_exporter( ) if export_within_fake_mode: - if ( - model_type - == pytorch_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM - ): - fake_model = torch.export.export( - fake_model, args=fake_args, kwargs=fake_kwargs - ) onnx_program = torch.onnx.dynamo_export( fake_model, *fake_args, @@ -1030,13 +1010,6 @@ def _test_fake_tensor_mode_exporter( ) if not export_within_fake_mode: - if ( - model_type - == pytorch_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM - ): - fake_model = torch.export.export( - fake_model, args=fake_args, kwargs=fake_kwargs - ) onnx_program = torch.onnx.dynamo_export( fake_model, *fake_args, **fake_kwargs, export_options=export_options ) @@ -1093,10 +1066,6 @@ def _test_fake_tensor_mode_exporter( for ref_output, ort_output in zip(ref_outputs, ort_outputs): torch.testing.assert_close(ref_output, torch.tensor(ort_output)) - @pytorch_test_common.skip_dynamic_fx_test( - reason="Dynamic shape check is not expected for exported program in this test suite.", - model_type=pytorch_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM, - ) def test_fake_tensor_mode_simple(self): def create_model() -> nn.Module: class Model(torch.nn.Module): @@ -1126,10 +1095,6 @@ def create_kwargs(): model_type=self.model_type, ) - @pytorch_test_common.skip_dynamic_fx_test( - reason="Dynamic shape check is not expected for exported program in this test suite.", - model_type=pytorch_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM, - ) @pytorch_test_common.xfail_dynamic_fx_test( error_message="!(it.GetName().empty())", reason="With after onnx==1.16, constant folding in optimizer causes this error.", @@ -1166,10 +1131,6 @@ def create_kwargs(): model_type=self.model_type, ) - @pytorch_test_common.skip_dynamic_fx_test( - reason="Dynamic shape check is not expected for exported program in this test suite.", - model_type=pytorch_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM, - ) def test_large_scale_exporter_with_toy_mlp(self): class MLPModel(nn.Module): def __init__(self) -> None: @@ -1208,10 +1169,6 @@ def create_kwargs(): model_type=self.model_type, ) - @pytorch_test_common.skip_dynamic_fx_test( - reason="Dynamic shape check is not expected for exported program in this test suite.", - model_type=pytorch_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM, - ) def test_fake_tensor_mode_huggingface_google_t5(self): config = transformers.T5Config( vocab_size=8096, d_model=64, num_layers=2, num_heads=2 @@ -1244,10 +1201,6 @@ def create_model(): model_type=self.model_type, ) - @pytorch_test_common.skip_dynamic_fx_test( - reason="Dynamic shape check is not expected for exported program in this test suite.", - model_type=pytorch_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM, - ) @pytorch_test_common.xfail_dynamic_fx_test( error_message="scaled_dot_product_attention(): argument 'is_causal' must be bool, not SymBool", reason="Dynamo error: scaled_dot_product_attention(): argument 'is_causal' must be bool, not SymBool", @@ -1310,10 +1263,6 @@ def create_kwargs(): model_type=self.model_type, ) - @pytorch_test_common.skip_dynamic_fx_test( - reason="Dynamic shape check is not expected for exported program in this test suite.", - model_type=pytorch_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM, - ) def test_fake_tensor_mode_huggingface_mosaicml_mpt(self): config = transformers.MptConfig( vocab_size=8096, d_model=64, n_heads=2, n_layers=3 @@ -1341,10 +1290,6 @@ def create_model(): model_type=self.model_type, ) - @pytorch_test_common.skip_dynamic_fx_test( - reason="Dynamic shape check is not expected for exported program in this test suite.", - model_type=pytorch_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM, - ) @pytorch_test_common.xfail_dynamic_fx_test( error_message="SymIntArrayRef expected to contain only concrete integers", model_type=pytorch_test_common.TorchModelType.TORCH_NN_MODULE, @@ -1374,10 +1319,6 @@ def create_model(): model_type=self.model_type, ) - @pytorch_test_common.skip_dynamic_fx_test( - reason="Dynamic shape check is not expected for exported program in this test suite.", - model_type=pytorch_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM, - ) @pytorch_test_common.xfail_if_model_type_is_not_exportedprogram( error_message="Expected 5 inputs, got 3", reason="https://github.com/pytorch/pytorch/issues/115745", @@ -1417,10 +1358,6 @@ def create_kwargs(): model_type=self.model_type, ) - @pytorch_test_common.skip_dynamic_fx_test( - reason="Dynamic shape check is not expected for exported program in this test suite.", - model_type=pytorch_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM, - ) @pytorch_test_common.xfail_dynamic_fx_test( error_message="SymIntArrayRef expected to contain only concrete integers", model_type=pytorch_test_common.TorchModelType.TORCH_NN_MODULE, diff --git a/test/onnx/torch_export/test_torch_export_with_onnxruntime.py b/test/onnx/torch_export/test_torch_export_with_onnxruntime.py index 7e3b24874e003b..7e7cf24a84e82a 100644 --- a/test/onnx/torch_export/test_torch_export_with_onnxruntime.py +++ b/test/onnx/torch_export/test_torch_export_with_onnxruntime.py @@ -36,14 +36,15 @@ def _compare_onnx_and_torch_exported_program( torch_outputs = torch_exported_program.module()(*input_args, **input_kwargs) else: torch_outputs = torch_exported_program(*input_args, **input_kwargs) - torch_outputs_onnx_format = onnx_exported_program.adapt_torch_outputs_to_onnx( - torch_outputs - ) - if len(torch_outputs_onnx_format) != len(onnx_outputs): + + if isinstance(torch_outputs, torch.Tensor): + torch_outputs = [torch_outputs] + + if len(torch_outputs) != len(onnx_outputs): raise AssertionError( - f"Expected {len(torch_outputs_onnx_format)} outputs, got {len(onnx_outputs)}" + f"Expected {len(torch_outputs)} outputs, got {len(onnx_outputs)}" ) - for torch_output, onnx_output in zip(torch_outputs_onnx_format, onnx_outputs): + for torch_output, onnx_output in zip(torch_outputs, onnx_outputs): torch.testing.assert_close( torch_output, torch.tensor(onnx_output), rtol=rtol, atol=atol ) diff --git a/torch/onnx/__init__.py b/torch/onnx/__init__.py index 9d4084da1ac4b6..963766a33e3ba4 100644 --- a/torch/onnx/__init__.py +++ b/torch/onnx/__init__.py @@ -54,7 +54,7 @@ "is_onnxrt_backend_supported", ] -from typing import Any, Collection, Mapping, Sequence, TYPE_CHECKING +from typing import Any, Callable, Collection, Mapping, Sequence, TYPE_CHECKING import torch from torch import _C @@ -112,7 +112,6 @@ InvalidExportOptionsError, OnnxExporterError, OnnxRegistry, - dynamo_export, enable_fake_mode, ) @@ -126,7 +125,6 @@ ExportOptions.__module__ = "torch.onnx" ONNXProgram.__module__ = "torch.onnx" ONNXRuntimeOptions.__module__ = "torch.onnx" -dynamo_export.__module__ = "torch.onnx" InvalidExportOptionsError.__module__ = "torch.onnx" OnnxExporterError.__module__ = "torch.onnx" enable_fake_mode.__module__ = "torch.onnx" @@ -393,6 +391,131 @@ def forward(self, x): return None +def dynamo_export( + model: torch.nn.Module | Callable | torch.export.ExportedProgram, # type: ignore[name-defined] + /, + *model_args, + export_options: ExportOptions | None = None, + **model_kwargs, +) -> ONNXProgram | Any: + """Export a torch.nn.Module to an ONNX graph. + + Args: + model: The PyTorch model to be exported to ONNX. + model_args: Positional inputs to ``model``. + model_kwargs: Keyword inputs to ``model``. + export_options: Options to influence the export to ONNX. + + Returns: + An in-memory representation of the exported ONNX model. + + **Example 1 - Simplest export** + :: + + class MyModel(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.linear = torch.nn.Linear(2, 2) + + def forward(self, x, bias=None): + out = self.linear(x) + out = out + bias + return out + + + model = MyModel() + kwargs = {"bias": 3.0} + args = (torch.randn(2, 2, 2),) + onnx_program = torch.onnx.dynamo_export(model, *args, **kwargs).save( + "my_simple_model.onnx" + ) + + **Example 2 - Exporting with dynamic shapes** + :: + + # The previous model can be exported with dynamic shapes + export_options = torch.onnx.ExportOptions(dynamic_shapes=True) + onnx_program = torch.onnx.dynamo_export( + model, *args, **kwargs, export_options=export_options + ) + onnx_program.save("my_dynamic_model.onnx") + """ + + # NOTE: The new exporter is experimental and is not enabled by default. + import warnings + + from torch.onnx import _flags + from torch.onnx._internal import exporter + from torch.utils import _pytree + + if isinstance(model, torch.export.ExportedProgram): + return exporter.export_compat( + model, # type: ignore[arg-type] + model_args, + f=None, + kwargs=model_kwargs, + opset_version=18, + external_data=True, + export_params=True, + fallback=True, + ) + elif _flags.USE_EXPERIMENTAL_LOGIC: + if export_options is not None: + warnings.warn( + "You are using an experimental ONNX export logic, which currently only supports dynamic shapes. " + "For a more comprehensive set of export options, including advanced features, please consider using " + "`torch.onnx.export(..., dynamo=True)`. ", + category=FutureWarning, + ) + + if export_options is not None and export_options.dynamic_shapes: + # Make all shapes dynamic + def _to_dynamic_shapes_mapper(): + arg_order = 0 + + def _to_dynamic_shape(x): + nonlocal arg_order + if isinstance(x, torch.Tensor): + rank = len(x.shape) + dynamic_shape = {} + for i in range(rank): + dynamic_shape[i] = torch.export.Dim( + f"arg_{arg_order}_dim_{i}" + ) + arg_order += 1 + return dynamic_shape + else: + return None + + return _to_dynamic_shape + + # model_args could be nested + dynamic_shapes = _pytree.tree_map( + _to_dynamic_shapes_mapper(), + model_args, + ) + else: + dynamic_shapes = None + + return exporter.export_compat( + model, # type: ignore[arg-type] + model_args, + f=None, + kwargs=model_kwargs, + dynamic_shapes=dynamic_shapes, + opset_version=18, + external_data=True, + export_params=True, + fallback=True, + ) + else: + from torch.onnx._internal._exporter_legacy import dynamo_export + + return dynamo_export( + model, *model_args, export_options=export_options, **model_kwargs + ) + + # TODO(justinchuby): Deprecate these logging functions in favor of the new diagnostic module. # Returns True iff ONNX logging is turned on. diff --git a/torch/onnx/_internal/_exporter_legacy.py b/torch/onnx/_internal/_exporter_legacy.py index 4d1830b3e9ce12..04347773abf01e 100644 --- a/torch/onnx/_internal/_exporter_legacy.py +++ b/torch/onnx/_internal/_exporter_legacy.py @@ -16,7 +16,6 @@ import torch import torch._ops -import torch.export as torch_export import torch.utils._pytree as pytree from torch.onnx._internal import io_adapter from torch.onnx._internal.diagnostics import infra @@ -304,27 +303,17 @@ class ResolvedExportOptions(ExportOptions): def __init__( self, options: ExportOptions | ResolvedExportOptions, - model: torch.nn.Module | Callable | torch_export.ExportedProgram | None = None, # type: ignore[name-defined] + model: torch.nn.Module | Callable | None = None, # type: ignore[name-defined] ): from torch.onnx._internal.fx import ( # TODO: Prevent circular dep diagnostics, dynamo_graph_extractor, - torch_export_graph_extractor, ) if isinstance(options, ResolvedExportOptions): self.dynamic_shapes = options.dynamic_shapes self.diagnostic_options = options.diagnostic_options self.fake_context = options.fake_context - # private - if isinstance(model, torch_export.ExportedProgram) and not isinstance( - options.fx_tracer, torch_export_graph_extractor.TorchExport - ): - message = "'model' of type 'ExportedProgram' is only supported with 'TorchExport' FX Tracer" - e = InvalidExportOptionsError(message) - raise InvalidExportOptionsError( - ONNXProgram._from_failure(e, options.diagnostic_context), message - ) self.fx_tracer = options.fx_tracer self.onnx_registry = options.onnx_registry self.onnxfunction_dispatcher = options.onnxfunction_dispatcher @@ -345,10 +334,8 @@ def resolve(value: T | None, fallback: T | Callable[[], T]) -> T: self.diagnostic_options = resolve( options.diagnostic_options, DiagnosticOptions() ) - if isinstance(model, torch_export.ExportedProgram): - self.fx_tracer = torch_export_graph_extractor.TorchExport() - else: - self.fx_tracer = dynamo_graph_extractor.DynamoExport() + + self.fx_tracer = dynamo_graph_extractor.DynamoExport() self.fake_context = resolve(options.fake_context, None) # type: ignore[arg-type] self.diagnostic_context = diagnostics.DiagnosticContext( @@ -492,7 +479,6 @@ class ONNXProgram: diagnostic_context: Context object for the SARIF diagnostic system responsible for logging errors and metadata. fake_context: The fake context used for symbolic tracing. export_exception: The exception that occurred during export, if any. - model_signature: The model signature for the exported ONNX graph. """ _model_proto: Final[onnx.ModelProto] # type: ignore[name-defined, misc] @@ -501,9 +487,8 @@ class ONNXProgram: _diagnostic_context: Final[diagnostics.DiagnosticContext] # type: ignore[misc] _fake_context: Final[ONNXFakeContext | None] # type: ignore[misc] _export_exception: Final[Exception | None] # type: ignore[misc] - _model_signature: Final[torch.export.ExportGraphSignature | None] # type: ignore[misc] _model_torch: Final[ # type: ignore[misc] - torch.nn.Module | Callable | torch_export.ExportedProgram | None + torch.nn.Module | Callable | None ] def __init__( @@ -515,14 +500,9 @@ def __init__( *, fake_context: ONNXFakeContext | None = None, export_exception: Exception | None = None, - model_signature: torch.export.ExportGraphSignature | None = None, - model_torch: torch.nn.Module - | Callable - | torch_export.ExportedProgram - | None = None, + model_torch: torch.nn.Module | Callable | None = None, ): self._model_proto = model_proto - self._model_signature = model_signature self._model_torch = model_torch self._input_adapter = input_adapter self._output_adapter = output_adapter @@ -533,10 +513,7 @@ def __init__( def __call__( self, *args: Any, - model_with_state_dict: torch.nn.Module - | Callable - | torch_export.ExportedProgram - | None = None, + model_with_state_dict: torch.nn.Module | Callable | None = None, options: ONNXRuntimeOptions | None = None, **kwargs: Any, ) -> Any: @@ -571,8 +548,6 @@ def __call__( onnx_model = os.path.join(tmpdir_path, "model.onnx") if isinstance(model_with_state_dict, torch.nn.Module): model_state = model_with_state_dict.state_dict() - elif isinstance(model_with_state_dict, torch_export.ExportedProgram): - model_state = model_with_state_dict.state_dict else: model_state = None self.save( @@ -608,104 +583,6 @@ def model_proto(self) -> onnx.ModelProto: # type: ignore[name-defined] raise self._export_exception return self._model_proto - @property - def model_signature(self) -> torch.export.ExportGraphSignature | None: - """The model signature for the exported ONNX graph. - - This information is relevant because ONNX specification often differs from PyTorch's, resulting - in a ONNX graph with input and output schema different from the actual PyTorch model implementation. - By using the model signature, the users can understand the inputs and outputs differences - and properly execute the model in ONNX Runtime. - - NOTE: Model signature is only available when the ONNX graph was exported from a - :class:`torch.export.ExportedProgram` object. - - NOTE: Any transformation done to the model that changes the model signature must be accompanied - by updates to this model signature as well through :class:`InputAdaptStep` and/or :class:`OutputAdaptStep`. - - Example: - - The following model produces different sets of inputs and outputs. - The first 4 inputs are model parameters (namely conv1.weight, conv2.weight, fc1.weight, fc2.weight), - and the next 2 inputs are registered buffers (namely my_buffer2, my_buffer1) and finally - the last 2 inputs are user inputs (namely x and b). - The first output is a buffer mutation (namely my_buffer2) and the last output is the actual model output. - - >>> import pprint - >>> class CustomModule(torch.nn.Module): - ... def __init__(self) -> None: - ... super().__init__() - ... self.my_parameter = torch.nn.Parameter(torch.tensor(2.0)) - ... self.register_buffer("my_buffer1", torch.tensor(3.0)) - ... self.register_buffer("my_buffer2", torch.tensor(4.0)) - ... self.conv1 = torch.nn.Conv2d(1, 32, 3, 1, bias=False) - ... self.conv2 = torch.nn.Conv2d(32, 64, 3, 1, bias=False) - ... self.fc1 = torch.nn.Linear(9216, 128, bias=False) - ... self.fc2 = torch.nn.Linear(128, 10, bias=False) - ... - ... def forward(self, x, b): - ... tensor_x = self.conv1(x) - ... tensor_x = torch.nn.functional.sigmoid(tensor_x) - ... tensor_x = self.conv2(tensor_x) - ... tensor_x = torch.nn.functional.sigmoid(tensor_x) - ... tensor_x = torch.nn.functional.max_pool2d(tensor_x, 2) - ... tensor_x = torch.flatten(tensor_x, 1) - ... tensor_x = self.fc1(tensor_x) - ... tensor_x = torch.nn.functional.sigmoid(tensor_x) - ... tensor_x = self.fc2(tensor_x) - ... output = torch.nn.functional.log_softmax(tensor_x, dim=1) - ... ( - ... self.my_buffer2.add_(1.0) + self.my_buffer1 - ... ) # Mutate buffer through in-place addition - ... return output - >>> inputs = (torch.rand((64, 1, 28, 28), dtype=torch.float32), torch.randn(3)) - >>> exported_program = torch.export.export( - ... CustomModule(), args=inputs - ... ).run_decompositions({}) - >>> onnx_program = torch.onnx.dynamo_export(exported_program, *inputs) - >>> pprint.pprint(onnx_program.model_signature) - ExportGraphSignature(input_specs=[InputSpec(kind=, - arg=TensorArgument(name='p_conv1_weight'), - target='conv1.weight', - persistent=None), - InputSpec(kind=, - arg=TensorArgument(name='p_conv2_weight'), - target='conv2.weight', - persistent=None), - InputSpec(kind=, - arg=TensorArgument(name='p_fc1_weight'), - target='fc1.weight', - persistent=None), - InputSpec(kind=, - arg=TensorArgument(name='p_fc2_weight'), - target='fc2.weight', - persistent=None), - InputSpec(kind=, - arg=TensorArgument(name='b_my_buffer2'), - target='my_buffer2', - persistent=True), - InputSpec(kind=, - arg=TensorArgument(name='b_my_buffer1'), - target='my_buffer1', - persistent=True), - InputSpec(kind=, - arg=TensorArgument(name='x'), - target=None, - persistent=None), - InputSpec(kind=, - arg=TensorArgument(name='b'), - target=None, - persistent=None)], - output_specs=[OutputSpec(kind=, - arg=TensorArgument(name='add'), - target='my_buffer2'), - OutputSpec(kind=, - arg=TensorArgument(name='_log_softmax'), - target=None)]) - """ - - return self._model_signature - @property def diagnostic_context(self) -> diagnostics.DiagnosticContext: """The diagnostic context associated with the export.""" @@ -721,10 +598,7 @@ def fake_context(self) -> ONNXFakeContext | None: def adapt_torch_inputs_to_onnx( self, *model_args, - model_with_state_dict: torch.nn.Module - | Callable - | torch_export.ExportedProgram - | None = None, + model_with_state_dict: torch.nn.Module | Callable | None = None, **model_kwargs, ) -> Sequence[torch.Tensor | int | float | bool | torch.dtype]: """Converts the PyTorch model inputs to exported ONNX model inputs format. @@ -794,10 +668,7 @@ def adapt_torch_inputs_to_onnx( def adapt_torch_outputs_to_onnx( self, model_outputs: Any, - model_with_state_dict: torch.nn.Module - | Callable - | torch_export.ExportedProgram - | None = None, + model_with_state_dict: torch.nn.Module | Callable | None = None, ) -> Sequence[torch.Tensor | int | float | bool]: """Converts the PyTorch model outputs to exported ONNX model outputs format. @@ -1050,7 +921,7 @@ class Exporter: def __init__( self, options: ResolvedExportOptions, - model: torch.nn.Module | Callable | torch_export.ExportedProgram, + model: torch.nn.Module | Callable, model_args: Sequence[Any], model_kwargs: Mapping[str, Any], ): @@ -1138,9 +1009,6 @@ def export(self) -> ONNXProgram: self.options.fx_tracer.output_adapter, self.options.diagnostic_context, fake_context=self.options.fake_context, - model_signature=getattr( - self.model, "graph_signature", None - ), # Available for isinstance(self.model, ExportedProgram) only model_torch=self.model, ) @@ -1261,12 +1129,12 @@ def missing_opset(package_name: str): def dynamo_export( - model: torch.nn.Module | Callable | torch_export.ExportedProgram, # type: ignore[name-defined] + model: torch.nn.Module | Callable, /, *model_args, export_options: ExportOptions | None = None, **model_kwargs, -) -> ONNXProgram: +) -> ONNXProgram | Any: """Export a torch.nn.Module to an ONNX graph. Args: diff --git a/torch/onnx/_internal/exporter/_capture_strategies.py b/torch/onnx/_internal/exporter/_capture_strategies.py index dc511491d6b41b..8b908bc35e9346 100644 --- a/torch/onnx/_internal/exporter/_capture_strategies.py +++ b/torch/onnx/_internal/exporter/_capture_strategies.py @@ -120,9 +120,20 @@ class TorchExportStrategy(CaptureStrategy): def _capture( self, model, args, kwargs, dynamic_shapes ) -> torch.export.ExportedProgram: - return torch.export.export( - model, args, kwargs=kwargs, dynamic_shapes=dynamic_shapes - ) + try: + return torch.export.export( + model, args, kwargs=kwargs, dynamic_shapes=dynamic_shapes + ) + except torch._dynamo.exc.UserError as exc: + # Refine the dynamic shapes based on the suggested fixes. + new_shapes = ( + torch.export.dynamic_shapes.refine_dynamic_shapes_from_suggested_fixes( + exc.msg, dynamic_shapes + ) + ) + return torch.export.export( + model, args, kwargs=kwargs, dynamic_shapes=new_shapes + ) def _enter(self, model) -> None: model_repr = _take_first_line(repr(model)) @@ -148,9 +159,20 @@ class TorchExportNonStrictStrategy(CaptureStrategy): def _capture( self, model, args, kwargs, dynamic_shapes ) -> torch.export.ExportedProgram: - return torch.export.export( - model, args, kwargs=kwargs, dynamic_shapes=dynamic_shapes, strict=False - ) + try: + return torch.export.export( + model, args, kwargs=kwargs, dynamic_shapes=dynamic_shapes, strict=False + ) + except torch._dynamo.exc.UserError as exc: + # Refine the dynamic shapes based on the suggested fixes. + new_shapes = ( + torch.export.dynamic_shapes.refine_dynamic_shapes_from_suggested_fixes( + exc.msg, dynamic_shapes + ) + ) + return torch.export.export( + model, args, kwargs=kwargs, dynamic_shapes=new_shapes, strict=False + ) def _enter(self, model) -> None: model_repr = _take_first_line(repr(model)) diff --git a/torch/onnx/_internal/exporter/_compat.py b/torch/onnx/_internal/exporter/_compat.py index 72d2a411c980cd..3fddef36b8b428 100644 --- a/torch/onnx/_internal/exporter/_compat.py +++ b/torch/onnx/_internal/exporter/_compat.py @@ -9,7 +9,6 @@ from typing import Any, Mapping, Sequence, TYPE_CHECKING import torch -import torch.export from torch.onnx._internal._lazy_import import onnxscript_apis, onnxscript_ir as ir from torch.onnx._internal.exporter import _core, _onnx_program diff --git a/torch/onnx/_internal/fx/_pass.py b/torch/onnx/_internal/fx/_pass.py index 388ae29cb699ed..5246788756f3ed 100644 --- a/torch/onnx/_internal/fx/_pass.py +++ b/torch/onnx/_internal/fx/_pass.py @@ -176,7 +176,7 @@ class Transform(abc.ABC): One important aspect to note is that if the transformation modifies the model input and/or output signature, (e.g. additional inputs/outputs are added to the model), :class:`InputAdaptStep` and/or :class:`OutputAdaptStep` - are needed to reconcile :attr:`ONNXProgram.model_signature` and :attr:`ONNXProgram.model_proto`. + are needed to reconcile :attr:`ONNXProgram.model_proto`. That is, the model signature and the model representation must match. As an additional feature, this class provides builtin support for transformation recording using the diagnostics. diff --git a/torch/onnx/_internal/fx/torch_export_graph_extractor.py b/torch/onnx/_internal/fx/torch_export_graph_extractor.py deleted file mode 100644 index aff9c154cd9e4f..00000000000000 --- a/torch/onnx/_internal/fx/torch_export_graph_extractor.py +++ /dev/null @@ -1,128 +0,0 @@ -# mypy: allow-untyped-defs -# NOTE: This file is referenced by name at -# /opt/pytorch/torch/_dynamo/eval_frame.py::DONT_WRAP_FILES. -# introduced by https://github.com/pytorch/pytorch/pull/98894. -# If this file is renamed, moved, etc please update the reference there! - -from __future__ import annotations - -from typing import Any, Callable, Mapping, Sequence, TYPE_CHECKING - -import torch._dynamo -import torch.fx -from torch.onnx._internal import _exporter_legacy, io_adapter -from torch.onnx._internal.diagnostics import infra - - -if TYPE_CHECKING: - import torch.onnx - from torch.export.exported_program import ExportedProgram - - -class TorchExport(_exporter_legacy.FXGraphExtractor): - """Generates a FX GraphModule using torch.export API - Args: - aten_graph: If True, exports a graph with ATen operators. - If False, exports a graph with Python operators. - """ - - def __init__( - self, - aten_graph: bool | None = None, - ): - super().__init__() - self.aten_graph = aten_graph or True - - def generate_fx( - self, - options: _exporter_legacy.ResolvedExportOptions, - model: ExportedProgram, # type: ignore[override] - model_args: Sequence[Any], - model_kwargs: Mapping[str, Any], - ) -> torch.fx.GraphModule: - # No need to translate callable to FX graph. - # This FX Graph extractor assumes `model` was obtained through - # exported_program = torch.export.export( - # model, - # args=model_args, # type: ignore[arg-type] - # kwargs=model_kwargs, # type: ignore[arg-type] - # ) - - # Export FX graph to ONNX ModelProto. - self.input_adapter.append_step( - io_adapter.FlattenInputWithTreeSpecValidationInputStep() - ) - self.input_adapter.append_step( - io_adapter.PrependParamsBuffersConstantAotAutogradInputStep() - ) - - # ONNX does not support None inputs. During graph building, all None inputs - # are removed. Here we register this step to input adapter. - options.fx_tracer.input_adapter.append_step(io_adapter.RemoveNoneInputStep()) - - # NOTE: temp workaround for https://github.com/pytorch/pytorch/issues/99534 - # Dynamo doesn't support non-tensor inputs. - options.fx_tracer.input_adapter.append_step( - io_adapter.RemoveNonTensorInputStep() - ) - - # ONNX does not support complex inputs. During graph building, all complex inputs - # are converted to real representation inputs. Here we register this step to - # input/output adapter. - options.fx_tracer.input_adapter.append_step( - io_adapter.ConvertComplexToRealRepresentationInputStep() - ) - - updated_model_args = self.input_adapter.apply( - *model_args, model=model, **model_kwargs - ) - - # ONNX can't represent collection types (e.g., dictionary, tuple of tuple of - # tensor, etc), we flatten the collection and register each element as output. - options.fx_tracer.output_adapter.append_step(io_adapter.FlattenOutputStep()) - - # Output post-processing steps should happen after `FlattenOutputStep`. - options.fx_tracer.output_adapter.append_step( - io_adapter.ConvertComplexToRealRepresentationOutputStep() - ) - - options.fx_tracer.output_adapter.append_step( - io_adapter.PrependParamsAndBuffersAotAutogradOutputStep() - ) - - # run_decomposition generates a new graph module with decomposed ops. - # Thus, we need to run this step after io_adapters. - model = model.run_decompositions(options.decomposition_table) - - # Export FX graph to ONNX ModelProto. - return self.pre_export_passes( # type: ignore[return-value] - options, model, model.graph_module, updated_model_args - ) - - def pre_export_passes( - self, - options: _exporter_legacy.ResolvedExportOptions, - original_model: torch.nn.Module | Callable, - fx_module: torch.fx.GraphModule, - fx_module_args: Sequence[Any], - ): - # TODO: Import here to prevent circular dependency - from torch.onnx._internal.fx import analysis, passes - - diagnostic_context = options.diagnostic_context - - # ONNX does not support concept of (implicit) type promotion. - # Insert type casts explicitly where needed. - fx_module = passes.InsertTypePromotion(diagnostic_context, fx_module).run() - - analysis.UnsupportedFxNodesAnalysis( - diagnostic_context, fx_module, options.onnxfunction_dispatcher - ).analyze(infra.levels.ERROR) - - # This operation should be invoked as the last pre export pass. - # See [NOTE: Modularize pass ordering] - fx_module = passes.Modularize( - diagnostic_context, fx_module, is_exported_program=True - ).run() - - return fx_module From 004f1f1da66748136ba639b3489979de3ed0bae2 Mon Sep 17 00:00:00 2001 From: Sunita Nadampalli Date: Fri, 6 Sep 2024 01:40:48 +0000 Subject: [PATCH 0515/1018] aarch64: extend matmul heuristic checks to all neoverse platforms (#134548) for aarch64 neoverse platforms there are two gemm backends available for matmul operator on PyTorch: (1) Arm Compute Library and (2) OpenBLAS. While Arm Compute Library provides better performance over OpenBLAS, it has overhead for the kernel launch time, and hence we use OpenBLAS for smaller tensor compute. The heuristic was originally implemented for neoverse_v1. This commit extends the heuristic to other neoverse platforms Pull Request resolved: https://github.com/pytorch/pytorch/pull/134548 Approved by: https://github.com/malfet --- aten/src/ATen/native/LinearAlgebra.cpp | 19 +++++-------------- aten/src/ATen/native/mkldnn/Utils.h | 12 ++++++++++++ 2 files changed, 17 insertions(+), 14 deletions(-) diff --git a/aten/src/ATen/native/LinearAlgebra.cpp b/aten/src/ATen/native/LinearAlgebra.cpp index 0611193357112e..2c4e000fce4421 100644 --- a/aten/src/ATen/native/LinearAlgebra.cpp +++ b/aten/src/ATen/native/LinearAlgebra.cpp @@ -19,6 +19,7 @@ #include #include #include +#include #include #include #include @@ -1358,13 +1359,8 @@ static inline int64_t get_mkldnn_matmul_min_dim() { static auto value = [&] { const int64_t default_min_dim = [&] { // Minimum dimension requirement for MKLDNN; derived based on experiments. - // By default, it's only enabled on Neoverse V1. -#if !defined(__s390x__) && !defined(__powerpc__) - if (cpuinfo_initialize() && cpuinfo_get_uarchs_count() == 1 && cpuinfo_get_uarch(0)->uarch == cpuinfo_uarch_neoverse_v1) { - return 8; - } -#endif - return 0; + //it's enabled on all Neoverse cpus. + return is_arm_neoverse() ? 8 : 0; }(); const char* ptr = std::getenv("TORCH_MKLDNN_MATMUL_MIN_DIM"); return ptr != nullptr ? std::atoi(ptr) : default_min_dim; @@ -1377,13 +1373,8 @@ static inline int64_t get_mkldnn_matmul_min_size() { static auto value = [&] { const int64_t default_min_size = [&] { // Minimum size requirement for MKLDNN; derived based on experiments. - // By default, it's only enabled on Neoverse V1. -#if !defined(__s390x__) && !defined(__powerpc__) - if (cpuinfo_initialize() && cpuinfo_get_uarchs_count() == 1 && cpuinfo_get_uarch(0)->uarch == cpuinfo_uarch_neoverse_v1) { - return 8 * 1024; - } -#endif - return 0; + // it's enabled on all Neoverse cpus. + return is_arm_neoverse() ? 8 * 1024 : 0; }(); const char* ptr = std::getenv("TORCH_MKLDNN_MATMUL_MIN_SIZE"); return ptr != nullptr ? std::atoi(ptr) : default_min_size; diff --git a/aten/src/ATen/native/mkldnn/Utils.h b/aten/src/ATen/native/mkldnn/Utils.h index a63d9ebfa2c150..2f3c791914e3db 100644 --- a/aten/src/ATen/native/mkldnn/Utils.h +++ b/aten/src/ATen/native/mkldnn/Utils.h @@ -89,10 +89,22 @@ const std::map& fusion_binary_alg_map(); inline bool mkldnn_bf16_device_check_arm() { return cpuinfo_initialize() && cpuinfo_has_arm_bf16(); } + +inline bool is_arm_neoverse() { + return (cpuinfo_initialize() && cpuinfo_get_uarchs_count() == 1 && + (cpuinfo_get_uarch(0)->uarch == cpuinfo_uarch_neoverse_v1 || + cpuinfo_get_uarch(0)->uarch == cpuinfo_uarch_neoverse_v2 || + cpuinfo_get_uarch(0)->uarch == cpuinfo_uarch_neoverse_n1 || + cpuinfo_get_uarch(0)->uarch == cpuinfo_uarch_neoverse_n2)); +} #else constexpr bool mkldnn_bf16_device_check_arm() { return false; } + +constexpr bool is_arm_neoverse() { + return false; +} #endif #if AT_MKLDNN_ENABLED() From 30d92721e0ce1b644a89b7d72c9b8c44a31d85a7 Mon Sep 17 00:00:00 2001 From: atalman Date: Fri, 6 Sep 2024 02:27:22 +0000 Subject: [PATCH 0516/1018] Use Python 3.9 on all libtorch jobs (#135245) Part of the migration py3.8->3.9 Pull Request resolved: https://github.com/pytorch/pytorch/pull/135245 Approved by: https://github.com/izaitsevfb --- .github/templates/upload.yml.j2 | 2 +- ...rm64-binary-libtorch-cxx11-abi-nightly.yml | 2 +- ...ted-windows-binary-libtorch-debug-main.yml | 4 ++-- ...-windows-binary-libtorch-debug-nightly.yml | 24 +++++++++---------- ...d-windows-binary-libtorch-release-main.yml | 4 ++-- ...indows-binary-libtorch-release-nightly.yml | 24 +++++++++---------- 6 files changed, 30 insertions(+), 30 deletions(-) diff --git a/.github/templates/upload.yml.j2 b/.github/templates/upload.yml.j2 index f4d206d4d7539d..4494af7ac50b33 100644 --- a/.github/templates/upload.yml.j2 +++ b/.github/templates/upload.yml.j2 @@ -45,7 +45,7 @@ {%- if is_windows %} # This is a dummy value for libtorch to work correctly with our batch scripts # without this value pip does not get installed for some reason - DESIRED_PYTHON: "3.8" + DESIRED_PYTHON: "3.9" {%- endif %} {%- else %} diff --git a/.github/workflows/generated-macos-arm64-binary-libtorch-cxx11-abi-nightly.yml b/.github/workflows/generated-macos-arm64-binary-libtorch-cxx11-abi-nightly.yml index 0d0043cbe99352..4343ac3e75ac33 100644 --- a/.github/workflows/generated-macos-arm64-binary-libtorch-cxx11-abi-nightly.yml +++ b/.github/workflows/generated-macos-arm64-binary-libtorch-cxx11-abi-nightly.yml @@ -49,7 +49,7 @@ jobs: DESIRED_DEVTOOLSET: cxx11-abi # This is a dummy value for libtorch to work correctly with our batch scripts # without this value pip does not get installed for some reason - DESIRED_PYTHON: "3.8" + DESIRED_PYTHON: "3.9" steps: # NOTE: These environment variables are put here so that they can be applied on every job equally # They are also here because setting them at a workflow level doesn't give us access to the diff --git a/.github/workflows/generated-windows-binary-libtorch-debug-main.yml b/.github/workflows/generated-windows-binary-libtorch-debug-main.yml index e11b4b9bb6961c..85e2564d612f45 100644 --- a/.github/workflows/generated-windows-binary-libtorch-debug-main.yml +++ b/.github/workflows/generated-windows-binary-libtorch-debug-main.yml @@ -51,7 +51,7 @@ jobs: LIBTORCH_VARIANT: shared-with-deps # This is a dummy value for libtorch to work correctly with our batch scripts # without this value pip does not get installed for some reason - DESIRED_PYTHON: "3.8" + DESIRED_PYTHON: "3.9" steps: - name: Display EC2 information shell: bash @@ -169,7 +169,7 @@ jobs: LIBTORCH_VARIANT: shared-with-deps # This is a dummy value for libtorch to work correctly with our batch scripts # without this value pip does not get installed for some reason - DESIRED_PYTHON: "3.8" + DESIRED_PYTHON: "3.9" steps: - name: Display EC2 information shell: bash diff --git a/.github/workflows/generated-windows-binary-libtorch-debug-nightly.yml b/.github/workflows/generated-windows-binary-libtorch-debug-nightly.yml index 44178c422a1b80..215dbe681896e2 100644 --- a/.github/workflows/generated-windows-binary-libtorch-debug-nightly.yml +++ b/.github/workflows/generated-windows-binary-libtorch-debug-nightly.yml @@ -58,7 +58,7 @@ jobs: LIBTORCH_VARIANT: shared-with-deps # This is a dummy value for libtorch to work correctly with our batch scripts # without this value pip does not get installed for some reason - DESIRED_PYTHON: "3.8" + DESIRED_PYTHON: "3.9" steps: - name: Display EC2 information shell: bash @@ -176,7 +176,7 @@ jobs: LIBTORCH_VARIANT: shared-with-deps # This is a dummy value for libtorch to work correctly with our batch scripts # without this value pip does not get installed for some reason - DESIRED_PYTHON: "3.8" + DESIRED_PYTHON: "3.9" steps: - name: Display EC2 information shell: bash @@ -290,7 +290,7 @@ jobs: LIBTORCH_VARIANT: shared-with-deps # This is a dummy value for libtorch to work correctly with our batch scripts # without this value pip does not get installed for some reason - DESIRED_PYTHON: "3.8" + DESIRED_PYTHON: "3.9" build_name: libtorch-cpu-shared-with-deps-debug secrets: github-token: ${{ secrets.GITHUB_TOKEN }} @@ -316,7 +316,7 @@ jobs: LIBTORCH_VARIANT: shared-with-deps # This is a dummy value for libtorch to work correctly with our batch scripts # without this value pip does not get installed for some reason - DESIRED_PYTHON: "3.8" + DESIRED_PYTHON: "3.9" steps: - name: Display EC2 information shell: bash @@ -435,7 +435,7 @@ jobs: LIBTORCH_VARIANT: shared-with-deps # This is a dummy value for libtorch to work correctly with our batch scripts # without this value pip does not get installed for some reason - DESIRED_PYTHON: "3.8" + DESIRED_PYTHON: "3.9" steps: - name: Display EC2 information shell: bash @@ -550,7 +550,7 @@ jobs: LIBTORCH_VARIANT: shared-with-deps # This is a dummy value for libtorch to work correctly with our batch scripts # without this value pip does not get installed for some reason - DESIRED_PYTHON: "3.8" + DESIRED_PYTHON: "3.9" build_name: libtorch-cuda11_8-shared-with-deps-debug secrets: github-token: ${{ secrets.GITHUB_TOKEN }} @@ -576,7 +576,7 @@ jobs: LIBTORCH_VARIANT: shared-with-deps # This is a dummy value for libtorch to work correctly with our batch scripts # without this value pip does not get installed for some reason - DESIRED_PYTHON: "3.8" + DESIRED_PYTHON: "3.9" steps: - name: Display EC2 information shell: bash @@ -695,7 +695,7 @@ jobs: LIBTORCH_VARIANT: shared-with-deps # This is a dummy value for libtorch to work correctly with our batch scripts # without this value pip does not get installed for some reason - DESIRED_PYTHON: "3.8" + DESIRED_PYTHON: "3.9" steps: - name: Display EC2 information shell: bash @@ -810,7 +810,7 @@ jobs: LIBTORCH_VARIANT: shared-with-deps # This is a dummy value for libtorch to work correctly with our batch scripts # without this value pip does not get installed for some reason - DESIRED_PYTHON: "3.8" + DESIRED_PYTHON: "3.9" build_name: libtorch-cuda12_1-shared-with-deps-debug secrets: github-token: ${{ secrets.GITHUB_TOKEN }} @@ -836,7 +836,7 @@ jobs: LIBTORCH_VARIANT: shared-with-deps # This is a dummy value for libtorch to work correctly with our batch scripts # without this value pip does not get installed for some reason - DESIRED_PYTHON: "3.8" + DESIRED_PYTHON: "3.9" steps: - name: Display EC2 information shell: bash @@ -955,7 +955,7 @@ jobs: LIBTORCH_VARIANT: shared-with-deps # This is a dummy value for libtorch to work correctly with our batch scripts # without this value pip does not get installed for some reason - DESIRED_PYTHON: "3.8" + DESIRED_PYTHON: "3.9" steps: - name: Display EC2 information shell: bash @@ -1070,7 +1070,7 @@ jobs: LIBTORCH_VARIANT: shared-with-deps # This is a dummy value for libtorch to work correctly with our batch scripts # without this value pip does not get installed for some reason - DESIRED_PYTHON: "3.8" + DESIRED_PYTHON: "3.9" build_name: libtorch-cuda12_4-shared-with-deps-debug secrets: github-token: ${{ secrets.GITHUB_TOKEN }} diff --git a/.github/workflows/generated-windows-binary-libtorch-release-main.yml b/.github/workflows/generated-windows-binary-libtorch-release-main.yml index 26c35089427569..7fd315028bb702 100644 --- a/.github/workflows/generated-windows-binary-libtorch-release-main.yml +++ b/.github/workflows/generated-windows-binary-libtorch-release-main.yml @@ -51,7 +51,7 @@ jobs: LIBTORCH_VARIANT: shared-with-deps # This is a dummy value for libtorch to work correctly with our batch scripts # without this value pip does not get installed for some reason - DESIRED_PYTHON: "3.8" + DESIRED_PYTHON: "3.9" steps: - name: Display EC2 information shell: bash @@ -169,7 +169,7 @@ jobs: LIBTORCH_VARIANT: shared-with-deps # This is a dummy value for libtorch to work correctly with our batch scripts # without this value pip does not get installed for some reason - DESIRED_PYTHON: "3.8" + DESIRED_PYTHON: "3.9" steps: - name: Display EC2 information shell: bash diff --git a/.github/workflows/generated-windows-binary-libtorch-release-nightly.yml b/.github/workflows/generated-windows-binary-libtorch-release-nightly.yml index 388c8cb5ed9379..c3ce65daff7097 100644 --- a/.github/workflows/generated-windows-binary-libtorch-release-nightly.yml +++ b/.github/workflows/generated-windows-binary-libtorch-release-nightly.yml @@ -58,7 +58,7 @@ jobs: LIBTORCH_VARIANT: shared-with-deps # This is a dummy value for libtorch to work correctly with our batch scripts # without this value pip does not get installed for some reason - DESIRED_PYTHON: "3.8" + DESIRED_PYTHON: "3.9" steps: - name: Display EC2 information shell: bash @@ -176,7 +176,7 @@ jobs: LIBTORCH_VARIANT: shared-with-deps # This is a dummy value for libtorch to work correctly with our batch scripts # without this value pip does not get installed for some reason - DESIRED_PYTHON: "3.8" + DESIRED_PYTHON: "3.9" steps: - name: Display EC2 information shell: bash @@ -290,7 +290,7 @@ jobs: LIBTORCH_VARIANT: shared-with-deps # This is a dummy value for libtorch to work correctly with our batch scripts # without this value pip does not get installed for some reason - DESIRED_PYTHON: "3.8" + DESIRED_PYTHON: "3.9" build_name: libtorch-cpu-shared-with-deps-release secrets: github-token: ${{ secrets.GITHUB_TOKEN }} @@ -316,7 +316,7 @@ jobs: LIBTORCH_VARIANT: shared-with-deps # This is a dummy value for libtorch to work correctly with our batch scripts # without this value pip does not get installed for some reason - DESIRED_PYTHON: "3.8" + DESIRED_PYTHON: "3.9" steps: - name: Display EC2 information shell: bash @@ -435,7 +435,7 @@ jobs: LIBTORCH_VARIANT: shared-with-deps # This is a dummy value for libtorch to work correctly with our batch scripts # without this value pip does not get installed for some reason - DESIRED_PYTHON: "3.8" + DESIRED_PYTHON: "3.9" steps: - name: Display EC2 information shell: bash @@ -550,7 +550,7 @@ jobs: LIBTORCH_VARIANT: shared-with-deps # This is a dummy value for libtorch to work correctly with our batch scripts # without this value pip does not get installed for some reason - DESIRED_PYTHON: "3.8" + DESIRED_PYTHON: "3.9" build_name: libtorch-cuda11_8-shared-with-deps-release secrets: github-token: ${{ secrets.GITHUB_TOKEN }} @@ -576,7 +576,7 @@ jobs: LIBTORCH_VARIANT: shared-with-deps # This is a dummy value for libtorch to work correctly with our batch scripts # without this value pip does not get installed for some reason - DESIRED_PYTHON: "3.8" + DESIRED_PYTHON: "3.9" steps: - name: Display EC2 information shell: bash @@ -695,7 +695,7 @@ jobs: LIBTORCH_VARIANT: shared-with-deps # This is a dummy value for libtorch to work correctly with our batch scripts # without this value pip does not get installed for some reason - DESIRED_PYTHON: "3.8" + DESIRED_PYTHON: "3.9" steps: - name: Display EC2 information shell: bash @@ -810,7 +810,7 @@ jobs: LIBTORCH_VARIANT: shared-with-deps # This is a dummy value for libtorch to work correctly with our batch scripts # without this value pip does not get installed for some reason - DESIRED_PYTHON: "3.8" + DESIRED_PYTHON: "3.9" build_name: libtorch-cuda12_1-shared-with-deps-release secrets: github-token: ${{ secrets.GITHUB_TOKEN }} @@ -836,7 +836,7 @@ jobs: LIBTORCH_VARIANT: shared-with-deps # This is a dummy value for libtorch to work correctly with our batch scripts # without this value pip does not get installed for some reason - DESIRED_PYTHON: "3.8" + DESIRED_PYTHON: "3.9" steps: - name: Display EC2 information shell: bash @@ -955,7 +955,7 @@ jobs: LIBTORCH_VARIANT: shared-with-deps # This is a dummy value for libtorch to work correctly with our batch scripts # without this value pip does not get installed for some reason - DESIRED_PYTHON: "3.8" + DESIRED_PYTHON: "3.9" steps: - name: Display EC2 information shell: bash @@ -1070,7 +1070,7 @@ jobs: LIBTORCH_VARIANT: shared-with-deps # This is a dummy value for libtorch to work correctly with our batch scripts # without this value pip does not get installed for some reason - DESIRED_PYTHON: "3.8" + DESIRED_PYTHON: "3.9" build_name: libtorch-cuda12_4-shared-with-deps-release secrets: github-token: ${{ secrets.GITHUB_TOKEN }} From c0049b7e7f314d9655c7eace163ad10475912a42 Mon Sep 17 00:00:00 2001 From: "Edward Z. Yang" Date: Thu, 5 Sep 2024 12:05:15 -0700 Subject: [PATCH 0517/1018] Improve test_public_bindings import module error reporting (#135258) Error was hard to understand without message. Render it now. See https://github.com/pytorch/pytorch/pull/135259 for it in action. Example failure: ``` 2024-09-05T20:04:45.3022000Z FAILED [5.9524s] test_public_bindings.py::TestPublicBindings::test_modules_can_be_imported - AssertionError: String comparison failed: '' != "torch._logging.scribe failed to import w[112 chars].py)" 2024-09-05T20:04:45.3025413Z + torch._logging.scribe failed to import with error ImportError: cannot import name 'TypeAlias' from 'typing' (/opt/conda/envs/py_3.9/lib/python3.9/typing.py) 2024-09-05T20:04:45.3026990Z ``` Signed-off-by: Edward Z. Yang Pull Request resolved: https://github.com/pytorch/pytorch/pull/135258 Approved by: https://github.com/albanD --- test/test_public_bindings.py | 18 +++++++++++++----- 1 file changed, 13 insertions(+), 5 deletions(-) diff --git a/test/test_public_bindings.py b/test/test_public_bindings.py index 5433540aeb2b6a..94d5f7804b6b87 100644 --- a/test/test_public_bindings.py +++ b/test/test_public_bindings.py @@ -3,6 +3,7 @@ import importlib import inspect import json +import logging import os import pkgutil import unittest @@ -20,6 +21,9 @@ ) +log = logging.getLogger(__name__) + + class TestPublicBindings(TestCase): def test_no_new_reexport_callables(self): """ @@ -267,7 +271,9 @@ def test_modules_can_be_imported(self): failures = [] def onerror(modname): - failures.append((modname, ImportError)) + failures.append( + (modname, ImportError("exception occurred importing package")) + ) for mod in pkgutil.walk_packages(torch.__path__, "torch.", onerror=onerror): modname = mod.name @@ -279,8 +285,8 @@ def onerror(modname): importlib.import_module(modname) except Exception as e: # Some current failures are not ImportError - - failures.append((modname, type(e))) + log.exception("import_module failed") + failures.append((modname, e)) # It is ok to add new entries here but please be careful that these modules # do not get imported by public code. @@ -441,14 +447,16 @@ def onerror(modname): } errors = [] - for mod, excep_type in failures: + for mod, exc in failures: if mod in public_allowlist: # TODO: Ensure this is the right error type continue if mod in private_allowlist: continue - errors.append(f"{mod} failed to import with error {excep_type}") + errors.append( + f"{mod} failed to import with error {type(exc).__qualname__}: {str(exc)}" + ) self.assertEqual("", "\n".join(errors)) # AttributeError: module 'torch.distributed' has no attribute '_shard' From ce6b4e414a151e97a6f47bff275c39a7dcd11a4b Mon Sep 17 00:00:00 2001 From: Haibo Chen Date: Fri, 6 Sep 2024 02:48:54 +0000 Subject: [PATCH 0518/1018] add instrumentation of CCA stats for reserved and allocated memory size (#135231) As titled Pull Request resolved: https://github.com/pytorch/pytorch/pull/135231 Approved by: https://github.com/c-p-i-o --- c10/cuda/CUDACachingAllocator.cpp | 32 +++++++++++++++++++++++++++++++ 1 file changed, 32 insertions(+) diff --git a/c10/cuda/CUDACachingAllocator.cpp b/c10/cuda/CUDACachingAllocator.cpp index fa26a9f3e0cca7..63f709596d43bc 100644 --- a/c10/cuda/CUDACachingAllocator.cpp +++ b/c10/cuda/CUDACachingAllocator.cpp @@ -6,6 +6,7 @@ #include #include #include +#include #include #include #include @@ -1463,6 +1464,12 @@ class DeviceCachingAllocator { if (block->size >= CUDAAllocatorConfig::max_split_size()) increase_stat(stats.oversize_allocations, 1); + auto allocated_bytes_gauge = + STATIC_GAUGE(pytorch.CUDACachingAllocator.allocated_bytes); + allocated_bytes_gauge.record( + stats.allocated_bytes[static_cast(StatType::AGGREGATE)] + .current); + c10::reportMemoryUsageToProfiler( block->ptr, static_cast(block->size), @@ -1490,6 +1497,11 @@ class DeviceCachingAllocator { decrease_stat(stats.allocation[stat_type], 1); decrease_stat(stats.allocated_bytes[stat_type], block->size); }); + auto allocated_bytes_gauge = + STATIC_GAUGE(pytorch.CUDACachingAllocator.allocated_bytes); + allocated_bytes_gauge.record( + stats.allocated_bytes[static_cast(StatType::AGGREGATE)] + .current); record_trace( TraceEntry::FREE_REQUESTED, @@ -2279,6 +2291,11 @@ class DeviceCachingAllocator { for_each_selected_stat_type(stat_types, [&](size_t stat_type) { increase_stat(stats.reserved_bytes[stat_type], mapped_range.size); }); + auto reserved_bytes_gauge = + STATIC_GAUGE(pytorch.CUDACachingAllocator.reserved_bytes); + reserved_bytes_gauge.record( + stats.reserved_bytes[static_cast(StatType::AGGREGATE)] + .current); stats.num_device_alloc++; record_trace( @@ -2721,6 +2738,11 @@ class DeviceCachingAllocator { }); if (size >= CUDAAllocatorConfig::max_split_size()) increase_stat(stats.oversize_segments, 1); + auto reserved_bytes_gauge = + STATIC_GAUGE(pytorch.CUDACachingAllocator.reserved_bytes); + reserved_bytes_gauge.record( + stats.reserved_bytes[static_cast(StatType::AGGREGATE)] + .current); // p.block came from new, not cudaMalloc. It should not be nullptr here. TORCH_INTERNAL_ASSERT(p.block != nullptr && p.block->ptr != nullptr); @@ -2858,6 +2880,11 @@ class DeviceCachingAllocator { decrease_stat(stats.segment[stat_type], 1); decrease_stat(stats.reserved_bytes[stat_type], block->size); }); + auto reserved_bytes_gauge = + STATIC_GAUGE(pytorch.CUDACachingAllocator.reserved_bytes); + reserved_bytes_gauge.record( + stats.reserved_bytes[static_cast(StatType::AGGREGATE)] + .current); if (block->size >= CUDAAllocatorConfig::max_split_size()) decrease_stat(stats.oversize_segments, 1); @@ -2914,6 +2941,11 @@ class DeviceCachingAllocator { for_each_selected_stat_type(stat_types, [&](size_t stat_type) { decrease_stat(stats.reserved_bytes[stat_type], unmapped.size); }); + auto reserved_bytes_gauge = + STATIC_GAUGE(pytorch.CUDACachingAllocator.reserved_bytes); + reserved_bytes_gauge.record( + stats.reserved_bytes[static_cast(StatType::AGGREGATE)] + .current); if (block->pool->owner_PrivatePool) { // The cudaFreed block belonged to a CUDA graph's PrivatePool. From c4b495415e9f67e1e279dde0edc143da37c16672 Mon Sep 17 00:00:00 2001 From: "Edward Z. Yang" Date: Thu, 5 Sep 2024 12:24:49 -0700 Subject: [PATCH 0519/1018] Consolidate raise and rewrap raise error branches (#135148) Signed-off-by: Edward Z. Yang Pull Request resolved: https://github.com/pytorch/pytorch/pull/135148 Approved by: https://github.com/anijain2305, https://github.com/albanD, https://github.com/yanboliang, https://github.com/malfet --- torch/_dynamo/convert_frame.py | 41 +++++++++++++++++----------------- 1 file changed, 20 insertions(+), 21 deletions(-) diff --git a/torch/_dynamo/convert_frame.py b/torch/_dynamo/convert_frame.py index 36c889a17d7cfc..284e36f3d4a970 100644 --- a/torch/_dynamo/convert_frame.py +++ b/torch/_dynamo/convert_frame.py @@ -914,24 +914,6 @@ def format_guard_failures() -> str: try: guarded_code = compile_inner(code, one_graph, hooks, transform) return guarded_code - except ( - Unsupported, - TorchRuntimeError, - BackendCompilerFailed, - AssertionError, - ConstraintViolationError, - GuardOnDataDependentSymNode, - ValidationException, - UncapturedHigherOrderOpError, - BisectValidationException, - ) as e: - fail_type = str(type(e)) - fail_reason = str(e) - exception_handler(e, code, frame, export=export) - fail_user_frame_filename, fail_user_frame_lineno = exc.get_exc_message( - e, compile_id - ) - raise except Exception as e: fail_type = str(type(e)) fail_reason = str(e) @@ -939,9 +921,26 @@ def format_guard_failures() -> str: fail_user_frame_filename, fail_user_frame_lineno = exc.get_exc_message( e, compile_id ) - raise InternalTorchDynamoError(str(e)).with_traceback( - e.__traceback__ - ) from None + if isinstance( + e, + ( + Unsupported, + TorchRuntimeError, + BackendCompilerFailed, + AssertionError, + ConstraintViolationError, + GuardOnDataDependentSymNode, + ValidationException, + UncapturedHigherOrderOpError, + BisectValidationException, + ), + ): + raise + else: + # Rewrap for clarity + raise InternalTorchDynamoError(str(e)).with_traceback( + e.__traceback__ + ) from None finally: if tracer: tracer.output.local_scope = {} From 2a551368ed2518132cc26c838c11d77296357922 Mon Sep 17 00:00:00 2001 From: "Edward Z. Yang" Date: Thu, 5 Sep 2024 12:24:50 -0700 Subject: [PATCH 0520/1018] Include exception type qualname when rewrapping InternalTorchDynamoError (#135145) Signed-off-by: Edward Z. Yang Pull Request resolved: https://github.com/pytorch/pytorch/pull/135145 Approved by: https://github.com/drisspg, https://github.com/anijain2305 ghstack dependencies: #135148 --- test/dynamo/test_exc.py | 2 +- torch/_dynamo/convert_frame.py | 8 +++++--- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/test/dynamo/test_exc.py b/test/dynamo/test_exc.py index 11d240924eae37..0f137776b11a71 100644 --- a/test/dynamo/test_exc.py +++ b/test/dynamo/test_exc.py @@ -107,7 +107,7 @@ def f(ctx): Traceback (most recent call last): File "test_exc.py", line N, in f raise NotImplementedError -torch._dynamo.exc.InternalTorchDynamoError: +torch._dynamo.exc.InternalTorchDynamoError: NotImplementedError: from user code: File "test_exc.py", line N, in fn001 diff --git a/torch/_dynamo/convert_frame.py b/torch/_dynamo/convert_frame.py index 284e36f3d4a970..4a28b140e151aa 100644 --- a/torch/_dynamo/convert_frame.py +++ b/torch/_dynamo/convert_frame.py @@ -917,6 +917,8 @@ def format_guard_failures() -> str: except Exception as e: fail_type = str(type(e)) fail_reason = str(e) + # NB: e's msg is mutated here to add user stack, but we DON'T want + # that stack in the Scuba logged fail_reason exception_handler(e, code, frame, export=export) fail_user_frame_filename, fail_user_frame_lineno = exc.get_exc_message( e, compile_id @@ -938,9 +940,9 @@ def format_guard_failures() -> str: raise else: # Rewrap for clarity - raise InternalTorchDynamoError(str(e)).with_traceback( - e.__traceback__ - ) from None + raise InternalTorchDynamoError( + f"{type(e).__qualname__}: {str(e)}" + ).with_traceback(e.__traceback__) from None finally: if tracer: tracer.output.local_scope = {} From ebc7e3626458f18bb1594fc7ef3aef0a90b4442e Mon Sep 17 00:00:00 2001 From: "Edward Z. Yang" Date: Thu, 5 Sep 2024 12:24:50 -0700 Subject: [PATCH 0521/1018] Report qualname of exception type rather than (#135146) Signed-off-by: Edward Z. Yang Pull Request resolved: https://github.com/pytorch/pytorch/pull/135146 Approved by: https://github.com/Skylion007, https://github.com/albanD, https://github.com/yanboliang ghstack dependencies: #135148, #135145 --- torch/_dynamo/convert_frame.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch/_dynamo/convert_frame.py b/torch/_dynamo/convert_frame.py index 4a28b140e151aa..070b6feacbb369 100644 --- a/torch/_dynamo/convert_frame.py +++ b/torch/_dynamo/convert_frame.py @@ -915,7 +915,7 @@ def format_guard_failures() -> str: guarded_code = compile_inner(code, one_graph, hooks, transform) return guarded_code except Exception as e: - fail_type = str(type(e)) + fail_type = type(e).__qualname__ fail_reason = str(e) # NB: e's msg is mutated here to add user stack, but we DON'T want # that stack in the Scuba logged fail_reason From c4e66791edf1aa37decdb921416d5901b7ab6c8c Mon Sep 17 00:00:00 2001 From: "Edward Z. Yang" Date: Thu, 5 Sep 2024 14:49:14 -0400 Subject: [PATCH 0522/1018] Remove dead expect_rational (#135105) Signed-off-by: Edward Z. Yang Pull Request resolved: https://github.com/pytorch/pytorch/pull/135105 Approved by: https://github.com/malfet --- torch/fx/experimental/sym_node.py | 4 +--- torch/fx/experimental/symbolic_shapes.py | 14 ++++++-------- 2 files changed, 7 insertions(+), 11 deletions(-) diff --git a/torch/fx/experimental/sym_node.py b/torch/fx/experimental/sym_node.py index add6cedb428a45..268f56e3214fb9 100644 --- a/torch/fx/experimental/sym_node.py +++ b/torch/fx/experimental/sym_node.py @@ -436,9 +436,7 @@ def guard_int(self, file, line): def guard_float(self, file, line): # TODO: use the file/line for some useful diagnostic on why a # guard occurred - r = self.shape_env.evaluate_expr( - self.expr, self.hint, fx_node=self.fx_node, expect_rational=False - ) + r = self.shape_env.evaluate_expr(self.expr, self.hint, fx_node=self.fx_node) try: return float(r) except Exception: diff --git a/torch/fx/experimental/symbolic_shapes.py b/torch/fx/experimental/symbolic_shapes.py index bc6645c617a479..c6a228f5256e71 100644 --- a/torch/fx/experimental/symbolic_shapes.py +++ b/torch/fx/experimental/symbolic_shapes.py @@ -4428,7 +4428,7 @@ def add_expr(expr): @_lru_cache def _maybe_evaluate_static( self, expr: "sympy.Expr", *, unbacked_only: bool = False, compute_hint: bool = False, - expect_rational=True, size_oblivious: bool = False, axioms: Optional[Tuple[sympy.Expr]] = None, + size_oblivious: bool = False, axioms: Optional[Tuple[sympy.Expr]] = None, var_to_range: Optional[Tuple[Tuple[sympy.Symbol, ValueRanges]]] = None ) -> "Optional[sympy.Expr]": """ @@ -5121,18 +5121,18 @@ def _log_guard(self, prefix: str, g, forcing_spec: bool): @lru_cache(256) @record_shapeenv_event(save_tracked_fakes=True) def evaluate_expr(self, orig_expr: "sympy.Expr", hint=None, fx_node=None, - expect_rational=True, size_oblivious: bool = False, *, forcing_spec: bool = False): + size_oblivious: bool = False, *, forcing_spec: bool = False): try: - return self._evaluate_expr(orig_expr, hint, fx_node, expect_rational, size_oblivious, forcing_spec=forcing_spec) + return self._evaluate_expr(orig_expr, hint, fx_node, size_oblivious, forcing_spec=forcing_spec) except Exception: self.log.warning( - "failed during evaluate_expr(%s, hint=%s, expect_rational=%s, size_oblivious=%s, forcing_spec=%s", - orig_expr, hint, expect_rational, size_oblivious, forcing_spec + "failed during evaluate_expr(%s, hint=%s, size_oblivious=%s, forcing_spec=%s", + orig_expr, hint, size_oblivious, forcing_spec ) raise def _evaluate_expr(self, orig_expr: "sympy.Expr", hint=None, fx_node=None, - expect_rational=True, size_oblivious: bool = False, *, forcing_spec: bool = False): + size_oblivious: bool = False, *, forcing_spec: bool = False): """ Given an expression, evaluates it, adding guards if necessary """ @@ -5200,7 +5200,6 @@ def compute_concrete_val(): expr = orig_expr static_expr = self._maybe_evaluate_static(expr, - expect_rational=expect_rational, size_oblivious=size_oblivious) if static_expr is not None: self.log.debug("eval %s == %s [statically known]", orig_expr, static_expr) @@ -5220,7 +5219,6 @@ def compute_concrete_val(): if not size_oblivious: size_oblivious_result = self._maybe_evaluate_static( expr, - expect_rational=expect_rational, size_oblivious=True ) From 8d63c8d80ad6962a2410ce7f0e222eb1caf86e5b Mon Sep 17 00:00:00 2001 From: leslie-fang-intel Date: Thu, 5 Sep 2024 06:18:48 -0700 Subject: [PATCH 0523/1018] [Inductor] Fix AOT weight alignment issue on CPU (#135205) **Summary** Fix issue: https://github.com/pytorch/pytorch/issues/135027. On CPU, the `consts_size` used to generate `_binary_constants_bin_start` is not padded to `ALIGN_BYTES`, while `serialized_weights` is, causing a failure in the 16K alignment check. Pull Request resolved: https://github.com/pytorch/pytorch/pull/135205 Approved by: https://github.com/jgong5, https://github.com/desertfire --- ...ctor_amp_freezing_torchbench_inference.csv | 2 +- ...inductor_freezing_torchbench_inference.csv | 2 +- ...ctor_amp_freezing_torchbench_inference.csv | 2 +- ...inductor_freezing_torchbench_inference.csv | 2 +- torch/_inductor/codecache.py | 26 +++++++++++++------ 5 files changed, 22 insertions(+), 12 deletions(-) diff --git a/benchmarks/dynamo/ci_expected_accuracy/cpu_aot_inductor_amp_freezing_torchbench_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/cpu_aot_inductor_amp_freezing_torchbench_inference.csv index e3772d03264f73..13762422e2a83e 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/cpu_aot_inductor_amp_freezing_torchbench_inference.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/cpu_aot_inductor_amp_freezing_torchbench_inference.csv @@ -90,7 +90,7 @@ detectron2_maskrcnn_r_50_fpn,fail_to_run,0 -dlrm,fail_to_run,0 +dlrm,pass,0 diff --git a/benchmarks/dynamo/ci_expected_accuracy/cpu_aot_inductor_freezing_torchbench_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/cpu_aot_inductor_freezing_torchbench_inference.csv index 36064e7361d0eb..1349ce2537da16 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/cpu_aot_inductor_freezing_torchbench_inference.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/cpu_aot_inductor_freezing_torchbench_inference.csv @@ -90,7 +90,7 @@ detectron2_maskrcnn_r_50_fpn,fail_to_run,0 -dlrm,fail_to_run,0 +dlrm,pass,0 diff --git a/benchmarks/dynamo/ci_expected_accuracy/dynamic_cpu_aot_inductor_amp_freezing_torchbench_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/dynamic_cpu_aot_inductor_amp_freezing_torchbench_inference.csv index 21cee0dc41ca15..085ea372032622 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/dynamic_cpu_aot_inductor_amp_freezing_torchbench_inference.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/dynamic_cpu_aot_inductor_amp_freezing_torchbench_inference.csv @@ -74,7 +74,7 @@ detectron2_fasterrcnn_r_50_fpn,fail_to_run,0 -dlrm,fail_to_run,0 +dlrm,pass,0 diff --git a/benchmarks/dynamo/ci_expected_accuracy/dynamic_cpu_aot_inductor_freezing_torchbench_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/dynamic_cpu_aot_inductor_freezing_torchbench_inference.csv index fb0c5ec473135a..2f4d4d0a4fa91b 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/dynamic_cpu_aot_inductor_freezing_torchbench_inference.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/dynamic_cpu_aot_inductor_freezing_torchbench_inference.csv @@ -74,7 +74,7 @@ detectron2_fasterrcnn_r_50_fpn,fail_to_run,0 -dlrm,fail_to_run,0 +dlrm,pass,0 diff --git a/torch/_inductor/codecache.py b/torch/_inductor/codecache.py index e3c4faedcb9f73..aaac1edc821397 100644 --- a/torch/_inductor/codecache.py +++ b/torch/_inductor/codecache.py @@ -62,6 +62,8 @@ ) from torch._utils_internal import log_cache_bypass +from .utils import _align + T = TypeVar("T") @@ -1728,10 +1730,23 @@ def _compile_consts_darwin(consts: bytes) -> str: ) output_o = os.path.splitext(input_path)[0] + ".o" + + all_cuda = all( + graph.get_original_value_of_constant(name).is_cuda + for name in graph.constants.keys() + if name not in graph.folded_constants + ) + + def get_nbytes_of_tensor(tensor: torch.Tensor, all_cuda: bool) -> int: + n_bytes = ( + torch.ops.mkldnn._nbytes(tensor) + if tensor.is_mkldnn + else tensor.untyped_storage().nbytes() + ) + return n_bytes if all_cuda else _align(n_bytes) + consts_size = sum( - torch.ops.mkldnn._nbytes(tensor) - if tensor.is_mkldnn - else tensor.untyped_storage().nbytes() + get_nbytes_of_tensor(tensor, all_cuda) for (name, tensor) in graph.constants.items() if name not in graph.folded_constants ) @@ -1825,11 +1840,6 @@ def _pad_to_alignment(raw_bytes: bytes) -> bytes: raw_bytes = bytes(raw_array.contents) return raw_bytes if all_cuda else _pad_to_alignment(raw_bytes) - all_cuda = all( - graph.get_original_value_of_constant(name).is_cuda - for name in graph.constants.keys() - if name not in graph.folded_constants - ) serialized_weights = b"".join( _to_bytes(graph.get_original_value_of_constant(name), all_cuda) for name in graph.constants.keys() From 3ee74d778a7099755fa8426559e7a8501eb93b7e Mon Sep 17 00:00:00 2001 From: leslie-fang-intel Date: Wed, 4 Sep 2024 19:50:59 -0700 Subject: [PATCH 0524/1018] [Inductor][CPP] Avoid mistake wgt tensor delete (#135100) **Summary** Fix issue: https://github.com/pytorch/pytorch/issues/134998: Previously, we only checked if the `get_attr` FX node for the weight had a single user node. However, two `get_attr` nodes may share the same tensor and should not be deleted in such cases. In this PR, we add the count of users for tensor along with the num of users for nodes to decide whether this tensor can be deleted or not. **TestPlan** ``` python test/inductor/test_cpu_select_algorithm.py -k test_linear_wgt_multi_users ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/135100 Approved by: https://github.com/jgong5 --- test/inductor/test_cpu_select_algorithm.py | 29 +++++++++++++++ torch/_inductor/codegen/cpp_gemm_template.py | 39 +++++++++++++++++++- 2 files changed, 67 insertions(+), 1 deletion(-) diff --git a/test/inductor/test_cpu_select_algorithm.py b/test/inductor/test_cpu_select_algorithm.py index 6cebd99704da2f..3dd5400e68eb53 100644 --- a/test/inductor/test_cpu_select_algorithm.py +++ b/test/inductor/test_cpu_select_algorithm.py @@ -168,6 +168,35 @@ def forward(self, x): else: self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1) + @inductor_config.patch({"freezing": True}) + @patches + @torch.no_grad + @unittest.skipIf(not TEST_MKL, "Test requires MKL") + @parametrize("in_features", (1000,)) + @parametrize("out_features", (1024,)) + @parametrize("bias", (True,)) + @dtypes( + torch.float, + ) + def test_linear_wgt_multi_users(self, in_features, out_features, bias, dtype): + class M(torch.nn.Module): + def __init__(self, bias): + super().__init__() + self.embeddings = torch.nn.Embedding(out_features, in_features) + self.linear = torch.nn.Linear(in_features, out_features, bias) + self.linear.weight = self.embeddings.weight + + def forward(self, x): + x = self.embeddings(x) + return self.linear(x) + + counters.clear() + mod = M(bias=bias).to(dtype=dtype).eval() + v = torch.LongTensor([[1, 2, 4, 5], [4, 3, 2, 9]]) + with verify(dtype) as (atol, rtol): + self.common(mod, (v,), atol=atol, rtol=rtol) + self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1) + @inductor_config.patch({"freezing": True}) @patches @torch.no_grad diff --git a/torch/_inductor/codegen/cpp_gemm_template.py b/torch/_inductor/codegen/cpp_gemm_template.py index a0856a15973349..b4b33b28216b4b 100644 --- a/torch/_inductor/codegen/cpp_gemm_template.py +++ b/torch/_inductor/codegen/cpp_gemm_template.py @@ -680,8 +680,45 @@ def postprocessor(output): # non-retraceable. To support retracing, we can add a repack node to the # FX graph. For example: # mkldnn._linear_pointwise <- repack_linear_wgt <- packed_wgt_for_template + W_tensor_users = 0 for node in reversed(V.graph.graph.nodes): - if node.name == W_node.get_name() and len(node.users) == 1: + # Case may happen when the wgt tensor is used by more than 1 get_attr node + # https://github.com/pytorch/pytorch/issues/134998 + if node.op == "get_attr" and hasattr( + V.graph.module, node.name + ): # wgt might already be deleted + comp_tensor = getattr(V.graph.module, node.name) + if ( + W.is_mkldnn == comp_tensor.is_mkldnn + and W.dtype == comp_tensor.dtype + and W.device == comp_tensor.device + and ( + ( + not W.is_mkldnn + and ( + W.untyped_storage().data_ptr() + == comp_tensor.untyped_storage().data_ptr() + ) + ) + or ( + W.is_mkldnn + and ( + torch.ops.mkldnn.data_ptr(W) + == torch.ops.mkldnn.data_ptr(comp_tensor) + ) + ) + ) + ): + W_tensor_users += 1 + + for node in reversed(V.graph.graph.nodes): + # The wgt tensor has been used by only 1 get_attr node + # The get_attr node has only 1 user fx node + if ( + node.name == W_node.get_name() + and len(node.users) == 1 + and W_tensor_users == 1 + ): del V.graph.constants[node.name] delattr(V.graph.module, node.name) delattr(V.graph.graph.owning_module, node.name) From 18cb2eb9f66d2df1b218f0cffb581f765836c468 Mon Sep 17 00:00:00 2001 From: Xu Han Date: Fri, 6 Sep 2024 03:21:07 +0000 Subject: [PATCH 0525/1018] [inductor] check intel compiler minimal version (#135209) On Windows: early version icx has `-print-file-name` issue, and can't preload correctly for inductor. Add minimal version check for Intel compiler. Pull Request resolved: https://github.com/pytorch/pytorch/pull/135209 Approved by: https://github.com/ezyang --- torch/_inductor/cpp_builder.py | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/torch/_inductor/cpp_builder.py b/torch/_inductor/cpp_builder.py index dade21d4303e58..95a0bff86fd8a4 100644 --- a/torch/_inductor/cpp_builder.py +++ b/torch/_inductor/cpp_builder.py @@ -24,6 +24,7 @@ from torch._inductor import config, exc from torch._inductor.cpu_vec_isa import invalid_vec_isa, VecISA from torch._inductor.runtime.runtime_utils import cache_dir +from torch.torch_version import TorchVersion if config.is_fbcode(): @@ -200,6 +201,16 @@ def _is_msvc_cl(cpp_compiler: str) -> bool: @functools.lru_cache(None) def _is_intel_compiler(cpp_compiler: str) -> bool: + def _check_minimal_version(compiler_version: TorchVersion) -> None: + """ + On Windows: early version icx has `-print-file-name` issue, and can't preload correctly for inductor. + """ + min_version = "2024.2.1" if _IS_WINDOWS else "0.0.0" + if compiler_version < TorchVersion(min_version): + raise RuntimeError( + f"Intel Compiler error: less than minimal version {min_version}." + ) + try: output_msg = ( subprocess.check_output( @@ -215,6 +226,13 @@ def _is_intel_compiler(cpp_compiler: str) -> bool: raise RuntimeError( "Please use icx-cl, due to torch.compile only support MSVC-like CLI (compiler flags syntax)." ) + + # Version check + icx_ver_search = re.search(r"(\d+[.]\d+[.]\d+[.]\d+)", output_msg) + if icx_ver_search is not None: + icx_ver = icx_ver_search.group(1) + _check_minimal_version(TorchVersion(icx_ver)) + return is_intel_compiler except FileNotFoundError as exc: return False From eadc9a5c56f97a9be260813c8290617a6ce292b3 Mon Sep 17 00:00:00 2001 From: hippocookie Date: Fri, 6 Sep 2024 03:25:07 +0000 Subject: [PATCH 0526/1018] [Distributed] Change function call in test to non-deprecated to eliminate warning (#134938) Migrate function call in test to eliminate warning message in below and reduce the chance of test fail when methods removed - from deprecated `save_state_dict` change to `save` - from deprecated `load_state_dict` change to `load` Warning message: ```bash pytorch/test/distributed/checkpoint/test_fsdp_model_state.py:37: FutureWarning: `save_state_dict` is deprecated and will be removed in future versions.Please use `save` instead. ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/134938 Approved by: https://github.com/wz337, https://github.com/fegin --- test/distributed/checkpoint/test_fsdp_model_state.py | 4 ++-- .../checkpoint/test_fsdp_tp_checkpoint_conversion.py | 2 +- test/distributed/checkpoint/test_hsdp_checkpoint.py | 4 ++-- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/test/distributed/checkpoint/test_fsdp_model_state.py b/test/distributed/checkpoint/test_fsdp_model_state.py index c90e5c6f267ef5..12d013d5e776e6 100644 --- a/test/distributed/checkpoint/test_fsdp_model_state.py +++ b/test/distributed/checkpoint/test_fsdp_model_state.py @@ -34,7 +34,7 @@ def _test_fsdp_model_state(self, process_group) -> None: "model": model.state_dict(), } - dist_cp.save_state_dict( + dist_cp.save( state_dict=state_dict, storage_writer=dist_cp.FileSystemWriter(CHECKPOINT_DIR), planner=DefaultSavePlanner(), @@ -55,7 +55,7 @@ def _test_fsdp_model_state(self, process_group) -> None: "model": model_2.state_dict(), } - dist_cp.load_state_dict( + dist_cp.load( state_dict=state_dict, storage_reader=dist_cp.FileSystemReader(CHECKPOINT_DIR), planner=DefaultLoadPlanner(), diff --git a/test/distributed/checkpoint/test_fsdp_tp_checkpoint_conversion.py b/test/distributed/checkpoint/test_fsdp_tp_checkpoint_conversion.py index adf2a6d3496f49..728524b4011d6b 100644 --- a/test/distributed/checkpoint/test_fsdp_tp_checkpoint_conversion.py +++ b/test/distributed/checkpoint/test_fsdp_tp_checkpoint_conversion.py @@ -40,7 +40,7 @@ def test_fsdp_to_tp(self): fsdp_state_dict = fsdp_model.state_dict() # save fsdp_state_dict to storage - dist_cp.save_state_dict( + dist_cp.save( state_dict=fsdp_state_dict, storage_writer=dist_cp.FileSystemWriter(CHECKPOINT_DIR), ) diff --git a/test/distributed/checkpoint/test_hsdp_checkpoint.py b/test/distributed/checkpoint/test_hsdp_checkpoint.py index 02904baf6a7977..23ca7c9463be7f 100644 --- a/test/distributed/checkpoint/test_hsdp_checkpoint.py +++ b/test/distributed/checkpoint/test_hsdp_checkpoint.py @@ -94,7 +94,7 @@ def test_hsdp_checkpoint(self, is_even_sharded_model) -> None: state_dict = {"model": model.state_dict()} state_dict_to_save = deepcopy(state_dict) - dist_cp.save_state_dict( + dist_cp.save( state_dict=state_dict_to_save, storage_writer=dist_cp.FileSystemWriter(CHECKPOINT_DIR), planner=DefaultSavePlanner(), @@ -113,7 +113,7 @@ def test_hsdp_checkpoint(self, is_even_sharded_model) -> None: self.assertEqual(v1.placements, v2.placements) self.assertNotEqual(v1.to_local(), v2.to_local()) - dist_cp.load_state_dict( + dist_cp.load( state_dict=state_dict_to_save, storage_reader=dist_cp.FileSystemReader(CHECKPOINT_DIR), planner=DefaultLoadPlanner(), From 4d1d1a7f49eeefc52ed11c04fafbeb7615832cce Mon Sep 17 00:00:00 2001 From: Xuan Zhang Date: Fri, 6 Sep 2024 04:12:39 +0000 Subject: [PATCH 0527/1018] [inductor][debug] fix draw_buffers (#135266) **Before:** ![image](https://github.com/user-attachments/assets/aac756f3-1349-4647-9da3-87cf105cf647) **After:** image Pull Request resolved: https://github.com/pytorch/pytorch/pull/135266 Approved by: https://github.com/yf225 --- torch/_inductor/debug.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/torch/_inductor/debug.py b/torch/_inductor/debug.py index df26bba94190d0..868833a425be4b 100644 --- a/torch/_inductor/debug.py +++ b/torch/_inductor/debug.py @@ -111,6 +111,7 @@ def func1(*args: Any) -> int: FusionMeta = collections.namedtuple("FusionMeta", ["group", "snode", "type"]) buf_to_fx_node = {} + node_to_fx_node = {} graph = torch.fx.Graph() first_node = None @@ -162,10 +163,9 @@ def in_output(snode: Union[BaseSchedulerNode, FusedSchedulerNode]) -> bool: fx_node.meta["fusion_meta"] = FusionMeta(group, snode, node_type) - if isinstance(snode, FusedSchedulerNode): - for x in snode.snodes: - buf_to_fx_node[x.get_name()] = fx_node - buf_to_fx_node[name] = fx_node + node_to_fx_node[name] = fx_node + for buf in snode.get_outputs(): + buf_to_fx_node[buf.get_name()] = fx_node if first_node is None: first_node = fx_node @@ -175,7 +175,7 @@ def in_output(snode: Union[BaseSchedulerNode, FusedSchedulerNode]) -> bool: name = snode.get_name() deps = snode.read_writes.reads - fx_node = buf_to_fx_node[name] + fx_node = node_to_fx_node[name] new_args = [] for dep in deps: if dep.name in buf_to_fx_node: @@ -184,6 +184,8 @@ def in_output(snode: Union[BaseSchedulerNode, FusedSchedulerNode]) -> bool: with graph.inserting_before(first_node): dep_node = graph.placeholder(dep.name) buf_to_fx_node[dep.name] = dep_node + if dep_node == fx_node: # to avoid cycles + continue new_args.append(dep_node) fx_node.args = tuple(new_args) From 43c700f2fab644489b89fcfd5f173c27a944db35 Mon Sep 17 00:00:00 2001 From: chilli Date: Thu, 5 Sep 2024 18:58:23 -0700 Subject: [PATCH 0528/1018] Add randomness checking for sdpa vmap (#135176) Pull Request resolved: https://github.com/pytorch/pytorch/pull/135176 Approved by: https://github.com/zou3519 --- .../functorch/BatchRulesDecompositions.cpp | 10 ++++- .../functorch/BatchRulesLinearAlgebra.cpp | 15 ++++++++ test/functorch/test_vmap.py | 38 +++++++++++++++++++ test/functorch/test_vmap_registrations.py | 1 - 4 files changed, 61 insertions(+), 3 deletions(-) diff --git a/aten/src/ATen/functorch/BatchRulesDecompositions.cpp b/aten/src/ATen/functorch/BatchRulesDecompositions.cpp index 5739e88d5ddcc1..21f17cff1f18f3 100644 --- a/aten/src/ATen/functorch/BatchRulesDecompositions.cpp +++ b/aten/src/ATen/functorch/BatchRulesDecompositions.cpp @@ -23,6 +23,9 @@ TORCH_LIBRARY_IMPL(aten, FuncTorchVmapMode, m) { OP_DECOMPOSE(dropout_); OP_DECOMPOSE(feature_alpha_dropout_); OP_DECOMPOSE(feature_dropout_); + OP_DECOMPOSE(dropout); + OP_DECOMPOSE(_scaled_dot_product_attention_math); + OP_DECOMPOSE(scaled_dot_product_attention); } static void unsupportedData(const c10::OperatorHandle& op, torch::jit::Stack* stack) { @@ -235,7 +238,6 @@ TORCH_LIBRARY_IMPL(aten, FuncTorchBatchedDecomposition, m) { OP_DECOMPOSE(relu6_); OP_DECOMPOSE(prelu); OP_DECOMPOSE2(softmax, int); - OP_DECOMPOSE(scaled_dot_product_attention); OP_DECOMPOSE(special_gammainc); OP_DECOMPOSE(special_gammaincc); OP_DECOMPOSE(special_logit); @@ -261,7 +263,6 @@ TORCH_LIBRARY_IMPL(aten, FuncTorchBatchedDecomposition, m) { OP_DECOMPOSE(special_xlogy); OP_DECOMPOSE2(special_xlogy, other_scalar); OP_DECOMPOSE2(special_xlogy, self_scalar); - OP_DECOMPOSE(_scaled_dot_product_attention_math); m.impl("split.sizes", native::split_symint); @@ -386,6 +387,11 @@ TORCH_LIBRARY_IMPL(aten, FuncTorchBatchedDecomposition, m) { OP_DECOMPOSE2(to, dtype); OP_DECOMPOSE2(to, dtype_layout); OP_DECOMPOSE2(to, other); + + // Random ops that are also registered here + OP_DECOMPOSE(dropout); + OP_DECOMPOSE(_scaled_dot_product_attention_math); + OP_DECOMPOSE(scaled_dot_product_attention); } } // namespace at::functorch diff --git a/aten/src/ATen/functorch/BatchRulesLinearAlgebra.cpp b/aten/src/ATen/functorch/BatchRulesLinearAlgebra.cpp index 6047a6eddb65d8..fed7fecc217b98 100644 --- a/aten/src/ATen/functorch/BatchRulesLinearAlgebra.cpp +++ b/aten/src/ATen/functorch/BatchRulesLinearAlgebra.cpp @@ -496,6 +496,11 @@ _scaled_dot_product_flash_attention_batch_rule( bool return_debug_mask, c10::optional scale ) { + if (dropout_p > 0) { + auto maybe_layer = maybeCurrentDynamicLayer(); + RandomnessType randomness = maybe_layer->randomness(); + check_randomness(randomness, query_bdim.has_value() || key_bdim.has_value() || value_bdim.has_value()); + } auto batch_size = get_bdim_size3(query, query_bdim, key, key_bdim, value, value_bdim); auto query_ = moveBatchDimToFront(query, query_bdim); auto key_ = moveBatchDimToFront(key, key_bdim); @@ -540,6 +545,11 @@ fourOutputs _scaled_dot_product_efficient_attention_batch_rule( bool is_causal, c10::optional scale ) { + if (dropout_p > 0) { + auto maybe_layer = maybeCurrentDynamicLayer(); + RandomnessType randomness = maybe_layer->randomness(); + check_randomness(randomness, query_bdim.has_value() || key_bdim.has_value() || value_bdim.has_value()); + } auto batch_size = get_bdim_size3(query, query_bdim, key, key_bdim, value, value_bdim); auto query_ = moveBatchDimToFront(query, query_bdim); auto key_ = moveBatchDimToFront(key, key_bdim); @@ -577,6 +587,11 @@ _scaled_dot_product_cudnn_attention_batch_rule( bool return_debug_mask, c10::optional scale ) { + if (dropout_p > 0) { + auto maybe_layer = maybeCurrentDynamicLayer(); + RandomnessType randomness = maybe_layer->randomness(); + check_randomness(randomness, query_bdim.has_value() || key_bdim.has_value() || value_bdim.has_value()); + } auto batch_size = get_bdim_size3(query, query_bdim, key, key_bdim, value, value_bdim); auto query_ = moveBatchDimToFront(query, query_bdim); auto key_ = moveBatchDimToFront(key, key_bdim); diff --git a/test/functorch/test_vmap.py b/test/functorch/test_vmap.py index a8358e62796278..7b7ae117f95c63 100644 --- a/test/functorch/test_vmap.py +++ b/test/functorch/test_vmap.py @@ -3923,6 +3923,44 @@ def T(*args): in_dims=(2, 1, None), ) + @parametrize("backend", PLATFORM_SPECIFIC_SDPA) + @parametrize("randomness", ["error", "same", "different"]) + def test_randomness(self, device, randomness, backend): + if device == "cpu": + raise unittest.SkipTest("This test is only for CUDA for now") + backend_ctx = sdpa_kernel([backend]) + with backend_ctx: + B = 4 + query = torch.rand(B, 4, 32, 8, 128, dtype=torch.float16, device=device) + key = torch.rand(B, 4, 32, 8, 128, dtype=torch.float16, device=device) + value = torch.rand(B, 4, 32, 8, 128, dtype=torch.float16, device=device) + + def f(q, k, v, dropout): + return F.scaled_dot_product_attention(q, k, v, dropout_p=dropout) + + # No matter the randomness mode, dropout=0.0 should pass + vmap( + functools.partial(f, dropout=0.0), + in_dims=(0, 0, 0), + randomness=randomness, + )(query, key, value) + + fail_with_randomness = randomness == "error" + if backend != SDPBackend.MATH: + fail_with_randomness |= randomness == "same" + context = ( + self.assertRaises(RuntimeError) + # We currently don't support randomness == "same", and "error" should always error with randomness + if fail_with_randomness + else contextlib.nullcontext() + ) + with context: + vmap( + functools.partial(f, dropout=0.5), + in_dims=(0, 0, 0), + randomness=randomness, + )(query, key, value) + @allowVmapFallbackUsage def test_inplace_view(self, device): leaf = torch.randn(4, 5, requires_grad=True) diff --git a/test/functorch/test_vmap_registrations.py b/test/functorch/test_vmap_registrations.py index bcf5bbc42c8329..1bff959e3c4f82 100644 --- a/test/functorch/test_vmap_registrations.py +++ b/test/functorch/test_vmap_registrations.py @@ -63,7 +63,6 @@ "aten::diagflat", "aten::divide.out_mode", "aten::divide_.Scalar", - "aten::dropout", "aten::dropout_", "aten::embedding_bag", "aten::embedding_bag.padding_idx", From e4a5e2c5b81dcaaad4b73585388a64b04c14f431 Mon Sep 17 00:00:00 2001 From: Laith Sakka Date: Fri, 6 Sep 2024 04:52:30 +0000 Subject: [PATCH 0529/1018] =?UTF-8?q?[RDP]=20Fix=20"No=20module=20named=20?= =?UTF-8?q?'libfb=E2=80=99"=20(#135244)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Summary: D62215095 Introduced an import error to arvr pipelines as the is_fbcode() function does not work as intended. This changes is_fbcode() to be a much stricter check. Test Plan: ``` buck2 run arvr/mode/platform010/opt-stripped //arvr/libraries/depthlink/clients/mr_replay:pipeline_runner -c bolt.use_eva3_sim=True -- --config_file arvr/libraries/depthlink/clients/mr_replay/configs/runner_config.yaml --features DEPTH ``` Differential Revision: D62237502 Pull Request resolved: https://github.com/pytorch/pytorch/pull/135244 Approved by: https://github.com/aorenste --- torch/_inductor/config.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/torch/_inductor/config.py b/torch/_inductor/config.py index 2332bebf2c7fda..684728ea48908a 100644 --- a/torch/_inductor/config.py +++ b/torch/_inductor/config.py @@ -573,17 +573,18 @@ def decide_compile_threads() -> int: # gemm autotuning global cache dir if is_fbcode(): - from libfb.py import parutil - try: + from libfb.py import parutil + if __package__: global_cache_dir = parutil.get_dir_path( os.path.join(__package__.replace(".", os.sep), "fb/cache") ) else: global_cache_dir = parutil.get_dir_path("fb/cache") - except ValueError: + except (ValueError, ModuleNotFoundError): global_cache_dir = None + else: global_cache_dir = None From 2e8d508e8a8f53ed99f5cf1b274d76f7f71b60db Mon Sep 17 00:00:00 2001 From: Jason Ansel Date: Thu, 5 Sep 2024 20:22:29 -0700 Subject: [PATCH 0530/1018] [fx] Bypass custom __setattr__ in Node.__init__ (#135079) Before: ![image](https://github.com/user-attachments/assets/5f0a6ae6-6049-44d0-b5f2-a549a23ad97f) After: ![image](https://github.com/user-attachments/assets/51c9f91b-f8a0-4043-8362-65813feec823) Pull Request resolved: https://github.com/pytorch/pytorch/pull/135079 Approved by: https://github.com/oulgen ghstack dependencies: #135070, #135076, #135082, #135084 --- torch/fx/graph.py | 4 +-- torch/fx/node.py | 77 +++++++++++++++++++++++++++++------------------ 2 files changed, 49 insertions(+), 32 deletions(-) diff --git a/torch/fx/graph.py b/torch/fx/graph.py index 62daf414ec4f4a..cb3fbeb33d83b9 100644 --- a/torch/fx/graph.py +++ b/torch/fx/graph.py @@ -816,10 +816,10 @@ def remove(self, node: Node) -> None: def find_nodes(self, *, op: str, target: Optional['Target'] = None): if op == "call_function": assert target is not None - return dict(self.table[(op, target)]).keys() + return [*self.table[(op, target)].keys()] if target is None: - return dict(self.table[(op, None)]).keys() + return [*self.table[(op, None)].keys()] # op is call_method, get_attr, call_module return [node for node in self.table[(op, None)].keys() if node.target == target] diff --git a/torch/fx/node.py b/torch/fx/node.py index f84b23e29ddf85..323f954d1a3e2e 100644 --- a/torch/fx/node.py +++ b/torch/fx/node.py @@ -168,6 +168,16 @@ class Node(_NodeBase): """ _args: Tuple['Argument', ...] _kwargs: Dict[str, 'Argument'] + graph: 'Graph' + name: str + op: str + target: 'Target' + _input_nodes: Dict['Node', None] + users: Dict['Node', None] + type: Optional[Any] + _sort_key: Any + _repr_fn: Optional[Callable[['Node'], str]] + meta: Dict[str, Any] @compatibility(is_backward_compatible=True) def __init__(self, graph: 'Graph', name: str, op: str, target: 'Target', @@ -199,11 +209,7 @@ def __init__(self, graph: 'Graph', name: str, op: str, target: 'Target', annotation of values in the generated code or for other types of analyses. """ - super().__init__() - self.graph = graph - self.name = name # unique name of value being created assert op in _legal_ops - self.op = op # the kind of operation = placeholder|call_method|call_module|call_function|get_attr if op == 'call_function': if not callable(target): raise ValueError(f'Node [graph = {graph}, name = \'{name}\'] target {target} has type {torch.typename(target)} ' @@ -212,13 +218,22 @@ def __init__(self, graph: 'Graph', name: str, op: str, target: 'Target', if not isinstance(target, str): raise ValueError(f'Node [graph = {graph}, name = \'{name}\'] target {target} has type {torch.typename(target)} ' 'but a str is expected') - self.target = target # for method/module/function, the name of the method/module/function/attr + super().__init__() + + # bypass Node.__setattr__ for perf and so that it doesn't need to handle half-built objects + assign = object.__setattr__ + + assign(self, "graph", graph) + assign(self, "name", name) # unique name of value being created + assign(self, "op", op) # the kind of operation = placeholder|call_method|call_module|call_function|get_attr + + assign(self, "target", target) # for method/module/function, the name of the method/module/function/attr # being invoked, e.g add, layer1, or torch.add # All `Node`-valued inputs. Key is the Node, value is don't-care. # The public API for this is `all_input_nodes`, this private attribute # should not be accessed directly. - self._input_nodes : Dict[Node, None] = {} + assign(self, "_input_nodes", {}) self.__update_args_kwargs(args, kwargs) # All of the nodes that use the value produced by this Node @@ -226,7 +241,8 @@ def __init__(self, graph: 'Graph', name: str, op: str, target: 'Target', # would appear once here, but represents two uses. # # Is a dict to act as an "ordered set". Keys are significant, value dont-care - self.users : Dict[Node, None] = {} + assign(self, "users", {}) + # Type expression representing the output value of this node. # This should contain the same class of Type objects that would appear # as type annotations for function inputs/outputs. @@ -237,15 +253,15 @@ def __init__(self, graph: 'Graph', name: str, op: str, target: 'Target', # generated function return type. (Note this is a special case. ``return`` # does not produce a value, it's more of a notation. Thus, this value # describes the type of args[0] in the ``return`` node. - self.type : Optional[Any] = return_type - self._sort_key: Any = () + assign(self, "type", return_type) + assign(self, "_sort_key", ()) # If set, use this fn to print this node - self._repr_fn : Optional[Callable[[Node], str]] = None + assign(self, "_repr_fn", None) # Dictionary to store metadata passes need to do their # transformations. This metadata is preserved across node copies - self.meta : Dict[str, Any] = {} + assign(self, "meta", {}) def __getstate__(self) -> Dict[str, Any]: state = self.__dict__.copy() @@ -490,14 +506,14 @@ def update_users_and_input_nodes(n: Any) -> Any: # Clear prior users and input_nodes for old_use in self._input_nodes.keys(): old_use.users.pop(self) - self._input_nodes = {} + object.__setattr__(self, "_input_nodes", {}) # bypass Node.__setattr__ # We do three things in a single pass of the args # - Normalize list->immutable_list, dict->immutable_dict, etc # - Populate self._input_nodes # - Populate arg.users[self] for each arg - self._args = map_aggregate(new_args, update_users_and_input_nodes) # type: ignore[assignment] - self._kwargs = map_aggregate(new_kwargs, update_users_and_input_nodes) # type: ignore[assignment] + object.__setattr__(self, "_args", map_aggregate(new_args, update_users_and_input_nodes)) + object.__setattr__(self, "_kwargs", map_aggregate(new_kwargs, update_users_and_input_nodes)) def __repr__(self) -> str: if self._repr_fn: @@ -740,23 +756,24 @@ def _rename(self, candidate: str) -> None: self.graph._graph_namespace._rename_object(self, name) def __setattr__(self, name: str, value: Any) -> None: - if name == 'name' and hasattr(self, "name"): - m = self.graph.owning_module - if getattr(m, "_replace_hook", None): - assert isinstance(value, str) - for user in self.users: - m._replace_hook(old=self, new=value, user=user) - update = False - if ( - hasattr(self, name) and - hasattr(self.graph, "_find_nodes_lookup_table") and - self in self.graph._find_nodes_lookup_table - ): - update = True - self.graph._find_nodes_lookup_table.remove(self) + if name in _setattr_custom_handling: + if name == "name": + m = self.graph.owning_module + if getattr(m, "_replace_hook", None): + assert isinstance(value, str) + for user in self.users: + m._replace_hook(old=self, new=value, user=user) + elif name in ("op", "target"): + table = getattr(self.graph, "_find_nodes_lookup_table", None) + if table and self in table: + self.graph._find_nodes_lookup_table.remove(self) + object.__setattr__(self, name, value) + self.graph._find_nodes_lookup_table.insert(self) + return + object.__setattr__(self, name, value) - if update: - self.graph._find_nodes_lookup_table.insert(self) + +_setattr_custom_handling = dict.fromkeys(["name", "op", "target"]) @compatibility(is_backward_compatible=True) def map_arg(a: Argument, fn: Callable[[Node], Argument]) -> Argument: From 5e9afc133d11ddfe0d4f65dafc73332e85e95fff Mon Sep 17 00:00:00 2001 From: Jason Ansel Date: Thu, 5 Sep 2024 20:22:29 -0700 Subject: [PATCH 0531/1018] [inductor] Skip retracing an existing LoopBody (#135235) This is roughly a 7% speedup in inductor compile time for hf_Bert_large. The time spent in `LoopBody.__init__` improves from 15% to 8% of `fx_codegen_and_compile`. Before ![image](https://github.com/user-attachments/assets/7de0f28e-35bd-472f-b4be-b52733d2a85c) After ![image](https://github.com/user-attachments/assets/5f0cf11a-43c5-43ae-b13c-f32383a75a7f) Overall ![image](https://github.com/user-attachments/assets/6a369d8c-fb5e-4ad2-9504-0fc745ad6568) Pull Request resolved: https://github.com/pytorch/pytorch/pull/135235 Approved by: https://github.com/oulgen ghstack dependencies: #135070, #135076, #135082, #135084, #135079 --- torch/_inductor/ir.py | 122 +++++++++++++++++++++++++++++++++++------- 1 file changed, 102 insertions(+), 20 deletions(-) diff --git a/torch/_inductor/ir.py b/torch/_inductor/ir.py index 229751410345be..ff999a18fdff71 100644 --- a/torch/_inductor/ir.py +++ b/torch/_inductor/ir.py @@ -6802,6 +6802,19 @@ class LoopBody: indexing simplifications and makes it easier to analyze loop bodies. """ + indexing_exprs: Dict[str, sympy.Expr] + indexing_exprs_name: Dict[sympy.Expr, str] + reads: List[sympy.Expr] + writes: List[sympy.Expr] + other: List[sympy.Expr] + reads_name2expr: Dict[str, sympy.Expr] + writes_name2expr: Dict[str, sympy.Expr] + submodules: Dict[str, Any] + subblocks: Dict[str, LoopBodyBlock] + indirect_vars: List[str] + indirect_var_ranges: Dict[sympy.Symbol, sympy.Expr] + root_block: LoopBodyBlock + def __init__(self, fn, args, var_ranges, iter_vars, reduce_vars): super().__init__() @@ -6815,6 +6828,15 @@ def __init__(self, fn, args, var_ranges, iter_vars, reduce_vars): self.reduce_vars = reduce_vars self.var_ranges = var_ranges + if isinstance(fn, LoopBody): + self._init_with_copy(fn, args) + else: + self._init_with_tracing(fn, args) + + self.indexing = None + + def _init_with_tracing(self, fn, args): + """Do an FX trace of an arbitrary callable to construct self""" self.indexing_exprs = {} self.indexing_exprs_name = {} self.reads = [] @@ -6826,8 +6848,44 @@ def __init__(self, fn, args, var_ranges, iter_vars, reduce_vars): self.subblocks = {} self.indirect_vars = [] self.indirect_var_ranges: Dict[sympy.Symbol, sympy.Expr] = {} - self.root_block = LoopBodyBlock(self, fn, args) - self.indexing = None + self.root_block = LoopBodyBlock(self, fn, args) # traces + del self.indexing_exprs_name # not used after _init_with_tracing + + def _init_with_copy(self, other: LoopBody, args): + """ + _init_with_tracing() is slow, so this is a fast path in the case + where we are just reordering/merging/splitting the args of an + existing LoopBody. + """ + indexing_exprs = other.indexing_from_args(args) + update_index = { + expr: indexing_exprs[name] for name, expr in other.indexing_exprs.items() + } + assert indexing_exprs.keys() == other.indexing_exprs.keys() and len( + update_index + ) == len(indexing_exprs) + + self.indexing_exprs = indexing_exprs + self.reads = [update_index[x] for x in other.reads] + self.writes = [update_index[x] for x in other.writes] + self.other = [update_index[x] for x in other.other] + self.reads_name2expr = { + k: update_index[v] for k, v in other.reads_name2expr.items() + } + self.writes_name2expr = { + k: update_index[v] for k, v in other.writes_name2expr.items() + } + self.subblocks = {k: v.clone(self) for k, v in other.subblocks.items()} + self.indirect_vars = [*other.indirect_vars] + self.indirect_var_ranges = {**other.indirect_var_ranges} + self.root_block = other.root_block.clone(self) + + submodules = {**other.submodules} + submodules.pop("get_index") + self.submodules = { + "get_index": self.get_index, + **{k: v.clone(self) for k, v in submodules.items()}, # type: ignore[attr-defined] + } def merge_loops(self) -> LoopBody: """ @@ -7020,6 +7078,35 @@ def __call__(self, *indices): self.indexing = None return result + def bind_set_indirect_shim(self, var, size, check, wrap_neg): + def set_indirect(new_var): + self.replace_indirect( + var, V.ops.indirect_indexing(new_var, size, check, wrap_neg) + ) + + set_indirect.clone = functools.partial( # type: ignore[attr-defined] + LoopBody.bind_set_indirect_shim, + var=var, + size=size, + check=check, + wrap_neg=wrap_neg, + ) + return set_indirect + + def bind_scan_shim(self, combine_fn): + def shim(dtypes, values): + return V.ops.scan(dtypes, combine_fn, values) + + shim.clone = functools.partial(LoopBody.bind_scan_shim, combine_fn=combine_fn) # type: ignore[attr-defined] + return shim + + def bind_masked_shim(self, name): + def shim(mask, other): + return V.ops.masked(mask, self.subblocks[name], other) + + shim.clone = functools.partial(LoopBody.bind_masked_shim, name=name) # type: ignore[attr-defined] + return shim + class LoopBodyBlock: """ @@ -7090,15 +7177,9 @@ def masked(mask_proxy, masked_body: Callable[..., Any], other_proxy): """ Recursively capture the masked out body in another LoopBodyBlock """ - - subblock: LoopBodyBlock - - def shim(mask, other): - return V.ops.masked(mask, subblock, other) - - name = self.body.add_submodule(shim, "masked_subblock") - subblock = LoopBodyBlock(self.body, masked_body, []) - self.body.subblocks[name] = subblock + name = self.body.add_submodule(None, "masked_subblock") + self.body.submodules[name] = self.body.bind_masked_shim(name) + self.body.subblocks[name] = LoopBodyBlock(self.body, masked_body, []) return tracer.create_proxy( "call_module", name, (mask_proxy, other_proxy), {} ) @@ -7111,9 +7192,7 @@ def scan( ], value_proxy, ): - def shim(dtypes, values): - return V.ops.scan(dtypes, combine_fn, values) - + shim = self.body.bind_scan_shim(combine_fn) name = self.body.add_submodule(shim, "scan") result = tracer.create_proxy( "call_module", @@ -7142,12 +7221,9 @@ def indirect_indexing(index_proxy, size, check=True, wrap_neg=True): """ var = self.body.add_indirect(size) - - def set_indirect(new_var): - self.body.replace_indirect( - var, V.ops.indirect_indexing(new_var, size, check, wrap_neg) - ) - + set_indirect = self.body.bind_set_indirect_shim( + var, size, check, wrap_neg + ) tracer.create_proxy( "call_module", self.body.add_submodule(set_indirect, f"set_{var}"), @@ -7196,6 +7272,12 @@ def debug_str(self, name="block"): code.strip().replace("def forward(", f"def {name}("), ) + def clone(self, body: LoopBody): + """Shallow copy with a new parent LoopBody""" + copy = LoopBodyBlock.__new__(LoopBodyBlock) + copy.__dict__.update({**self.__dict__, "body": body}) + return copy + class _CollectiveKernel(FallbackKernel): def should_allocate(self): From 5eaffc458727f5c1a7ffdada91d2daaa9ddec6e7 Mon Sep 17 00:00:00 2001 From: Jason Ansel Date: Thu, 5 Sep 2024 20:22:29 -0700 Subject: [PATCH 0532/1018] [inductor] Remove LoopBody.reads,writes,other (#135256) Pull Request resolved: https://github.com/pytorch/pytorch/pull/135256 Approved by: https://github.com/oulgen ghstack dependencies: #135070, #135076, #135082, #135084, #135079, #135235 --- torch/_inductor/ir.py | 10 ---------- 1 file changed, 10 deletions(-) diff --git a/torch/_inductor/ir.py b/torch/_inductor/ir.py index ff999a18fdff71..96c171814056de 100644 --- a/torch/_inductor/ir.py +++ b/torch/_inductor/ir.py @@ -6804,9 +6804,6 @@ class LoopBody: indexing_exprs: Dict[str, sympy.Expr] indexing_exprs_name: Dict[sympy.Expr, str] - reads: List[sympy.Expr] - writes: List[sympy.Expr] - other: List[sympy.Expr] reads_name2expr: Dict[str, sympy.Expr] writes_name2expr: Dict[str, sympy.Expr] submodules: Dict[str, Any] @@ -6839,11 +6836,8 @@ def _init_with_tracing(self, fn, args): """Do an FX trace of an arbitrary callable to construct self""" self.indexing_exprs = {} self.indexing_exprs_name = {} - self.reads = [] - self.writes = [] self.reads_name2expr = {} self.writes_name2expr = {} - self.other = [] self.submodules = {"get_index": self.get_index} self.subblocks = {} self.indirect_vars = [] @@ -6866,9 +6860,6 @@ def _init_with_copy(self, other: LoopBody, args): ) == len(indexing_exprs) self.indexing_exprs = indexing_exprs - self.reads = [update_index[x] for x in other.reads] - self.writes = [update_index[x] for x in other.writes] - self.other = [update_index[x] for x in other.other] self.reads_name2expr = { k: update_index[v] for k, v in other.reads_name2expr.items() } @@ -7024,7 +7015,6 @@ def debug_str(self): __repr__ = debug_str def add_index_expr(self, expr: sympy.Expr, category, buf_name): - getattr(self, category).append(expr) if buf_name is not None: getattr(self, f"{category}_name2expr")[buf_name] = expr if expr not in self.indexing_exprs_name: From 927dcbda0ccb4b4fb5b272523e61993d53d2371d Mon Sep 17 00:00:00 2001 From: fduwjj Date: Thu, 5 Sep 2024 20:42:04 -0700 Subject: [PATCH 0533/1018] [Fix][FR][ez] Remove debugging logs (#135308) Removing the print added during debugging process. Pull Request resolved: https://github.com/pytorch/pytorch/pull/135308 Approved by: https://github.com/wz337 --- tools/flight_recorder/components/builder.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/tools/flight_recorder/components/builder.py b/tools/flight_recorder/components/builder.py index d9eaa4c27368ab..37a9534860f189 100644 --- a/tools/flight_recorder/components/builder.py +++ b/tools/flight_recorder/components/builder.py @@ -205,12 +205,6 @@ def build_collectives( profiling_name = entries[0]["profiling_name"] pg_name = _pg_guids[(pg_name, first_rank)] collective_seq_id = entries[0]["collective_seq_id"] - print( - "collective_seq_id ", - collective_seq_id, - " p2p_seq_id ", - entries[0]["p2p_seq_id"], - ) record_id = entries[0]["record_id"] input_sizes = entries[0]["input_sizes"] output_sizes = entries[0]["output_sizes"] From 3d6996dfe27b3cc260ab1a7c57563a80676c8a6d Mon Sep 17 00:00:00 2001 From: wz337 Date: Thu, 5 Sep 2024 14:55:10 -0700 Subject: [PATCH 0534/1018] [DeviceMesh][Easy] Make RuntimeError a bit more descriptive by including the actual world_size (#135271) Pull Request resolved: https://github.com/pytorch/pytorch/pull/135271 Approved by: https://github.com/fduwjj --- torch/distributed/device_mesh.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch/distributed/device_mesh.py b/torch/distributed/device_mesh.py index 0ad17536ef1832..d8d33ab7ebfbea 100644 --- a/torch/distributed/device_mesh.py +++ b/torch/distributed/device_mesh.py @@ -469,7 +469,7 @@ def _get_or_create_default_group(self): world_size = get_world_size() if self.mesh.numel() > world_size: raise RuntimeError( - f"Mesh should not be bigger than default world size, but found {self.mesh.numel()} ranks!" + f"Mesh should not be bigger than default world size {world_size}, but found {self.mesh.numel()} ranks!" ) device_handle = _get_device_handle(self.device_type) From 93c10d1ab252056575b6c7f94d0ebaae8be4725a Mon Sep 17 00:00:00 2001 From: drisspg Date: Thu, 5 Sep 2024 16:05:56 -0700 Subject: [PATCH 0535/1018] [FlexAttention] Specify padding_value for boundary checked loads (#134573) Pull Request resolved: https://github.com/pytorch/pytorch/pull/134573 Approved by: https://github.com/Chillee --- torch/_inductor/kernel/flex_attention.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/torch/_inductor/kernel/flex_attention.py b/torch/_inductor/kernel/flex_attention.py index 5c87cc40a8e32a..24cd489c162dbd 100644 --- a/torch/_inductor/kernel/flex_attention.py +++ b/torch/_inductor/kernel/flex_attention.py @@ -231,7 +231,7 @@ def get_offset_for_next_block(loop_iter, col_indices, total_blocks, SPARSE_BLOCK q = tl.load(Q_block_ptr) else: # boundary check is not free, so we only do it when necessary. - q = tl.load(Q_block_ptr, boundary_check=(0,)) + q = tl.load(Q_block_ptr, boundary_check=(0,), padding_option = "zero") # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # We don't know anything "special" about these blocks, so we need to apply @@ -427,7 +427,7 @@ def forward_block_mn( if IS_DIVISIBLE: k = tl.load(K_block_ptr) else: - k = tl.load(K_block_ptr, boundary_check=(1,)) + k = tl.load(K_block_ptr, boundary_check=(1,), padding_option = "zero") # -- compute qk --- qk = tl.dot(q, k) # TODO: use cuda matmul when q_len <= 2. if not PRESCALE_QK: @@ -500,7 +500,7 @@ def forward_block_mn( if IS_DIVISIBLE: v = tl.load(V_block_ptr) else: - v = tl.load(V_block_ptr, boundary_check=(0,)) + v = tl.load(V_block_ptr, boundary_check=(0,), padding_option = "zero") acc = tl.dot(p.to(MATMUL_PRECISION), v, acc) # -- update m_i From dbefb26ef61ecdd84063aed4259975a2a2696167 Mon Sep 17 00:00:00 2001 From: Shangdi Yu Date: Fri, 6 Sep 2024 07:06:06 +0000 Subject: [PATCH 0536/1018] [export][training ir migration] quantized_decomposed.quantize_per_tensor decomposition (#134525) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Summary: In graph of TestXNNPACKQuantizer.test_dynamic_linear_with_con test, some quantized_decomposed.quantize_per_tensor.default ops are becoming quantized_decomposed.dequantize_per_tensor.tensor ops when using the new training ir. This is because we lift params/buffers before calling make_fx. So previously, for the graph that’s passed to make_fx,`graph.L__self___linear1.weight` is a tensor now in training ir, graph.L__self___linear1.weight is a FakeTensor. This caused the node overload to be different. Test Plan: ``` buck2 run 'fbcode//mode/dev-nosan' fbcode//caffe2/test/quantization:test_quantization -- -r test_dynamic_linear_with_conv ``` Differential Revision: D61364547 Pull Request resolved: https://github.com/pytorch/pytorch/pull/134525 Approved by: https://github.com/tugsbayasgalan, https://github.com/jerryzh168 --- test/quantization/pt2e/test_xnnpack_quantizer.py | 16 ++++++++++++++++ torch/testing/_internal/common_quantization.py | 5 +++++ 2 files changed, 21 insertions(+) diff --git a/test/quantization/pt2e/test_xnnpack_quantizer.py b/test/quantization/pt2e/test_xnnpack_quantizer.py index 0e88a7793ce475..12962c8f3b0099 100644 --- a/test/quantization/pt2e/test_xnnpack_quantizer.py +++ b/test/quantization/pt2e/test_xnnpack_quantizer.py @@ -4,6 +4,7 @@ import torch import torch._dynamo as torchdynamo +from torch._utils_internal import capture_pre_autograd_graph_using_training_ir from torch.ao.ns.fx.utils import compute_sqnr from torch.ao.quantization import ( default_dynamic_fake_quant, @@ -680,6 +681,20 @@ def test_dynamic_linear_with_conv(self): torch.ops.quantized_decomposed.quantize_per_tensor.default: 0, torch.ops.quantized_decomposed.dequantize_per_tensor.default: 1, } + + capture_pre_autograd_graph_node_occurrence = None + if capture_pre_autograd_graph_using_training_ir(): + capture_pre_autograd_graph_node_occurrence = { + # input and output are using quantize_per_tensor and weight is using quantize_per_channel + # In training IR, the decomposition is different. + # `torch.ops.quantized_decomposed.quantize_per_tensor.default` nodes becomes + # `torch.ops.quantized_decomposed.quantize_per_tensor.tensor` nodes. + torch.ops.quantized_decomposed.quantize_per_tensor.tensor: 2, + torch.ops.quantized_decomposed.dequantize_per_tensor.tensor: 2, + # note: quantize op for weights are const propagated + torch.ops.quantized_decomposed.quantize_per_tensor.default: 0, + torch.ops.quantized_decomposed.dequantize_per_tensor.default: 0, + } act_affine_quant_obs = observer.PlaceholderObserver.with_args( dtype=torch.qint8, qscheme=torch.per_tensor_affine, @@ -703,6 +718,7 @@ def test_dynamic_linear_with_conv(self): [], True, qconfig_mapping, + capture_pre_autograd_graph_node_occurrence=capture_pre_autograd_graph_node_occurrence, ) def test_gru(self): diff --git a/torch/testing/_internal/common_quantization.py b/torch/testing/_internal/common_quantization.py index 87891bcefdfc1a..ce990cd0aaf8ef 100644 --- a/torch/testing/_internal/common_quantization.py +++ b/torch/testing/_internal/common_quantization.py @@ -1247,6 +1247,7 @@ def _test_quantizer( export_with_dynamic_shape=False, is_qat=False, is_debug_mode=False, + capture_pre_autograd_graph_node_occurrence=None, ): # resetting dynamo cache torch._dynamo.reset() @@ -1305,6 +1306,10 @@ def _test_quantizer( for k, v in PT2EQuantizationTestCase._MAP_TO_FX_TRACED_OPS.items(): if k in expected_node_occurrence: node_occurrence[ns.call_function(v)] = expected_node_occurrence[k] + if capture_pre_autograd_graph_node_occurrence is not None: + node_occurrence = { + ns.call_function(k): v for k, v in capture_pre_autograd_graph_node_occurrence.items() + } self.checkGraphModuleNodes(m_fx, expected_node_occurrence=node_occurrence) fx_quant_output = m_fx(*example_inputs) self.assertEqual(fx_quant_output, pt2_quant_output) From 9a521bd28840c0f82160d40cda1b0ea596948251 Mon Sep 17 00:00:00 2001 From: Feng Yuan Date: Fri, 6 Sep 2024 07:30:09 +0000 Subject: [PATCH 0537/1018] Update torch-xpu-ops pin (ATen XPU implementation) (#135300) Release cycle for PyTorch 2.5 1. Bugfixing: correct reduction logic in cdist kernel. Pull Request resolved: https://github.com/pytorch/pytorch/pull/135300 Approved by: https://github.com/EikanWang --- third_party/xpu.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/third_party/xpu.txt b/third_party/xpu.txt index 42f3d8ca513447..5043495ce9680b 100644 --- a/third_party/xpu.txt +++ b/third_party/xpu.txt @@ -1 +1 @@ -17f4bd5b7ad63de408cdacee961d03c51cc75ab4 +2326220aa49235ff87c382a85eb7f7cf115b4338 From 549812583a696cd667ea250d449bc0e56461ac8a Mon Sep 17 00:00:00 2001 From: CaoE Date: Thu, 5 Sep 2024 18:11:00 -0700 Subject: [PATCH 0538/1018] [Inductor][CPP] Select tiling factor for lower precision data types (#133830) Pull Request resolved: https://github.com/pytorch/pytorch/pull/133830 Approved by: https://github.com/jgong5, https://github.com/jansel --- test/inductor/test_cpu_repro.py | 10 ++++++++-- torch/_inductor/codegen/cpp.py | 31 +++++++++++++++++++++++++++---- 2 files changed, 35 insertions(+), 6 deletions(-) diff --git a/test/inductor/test_cpu_repro.py b/test/inductor/test_cpu_repro.py index 100b5fa35361e1..e463f3e60b8545 100644 --- a/test/inductor/test_cpu_repro.py +++ b/test/inductor/test_cpu_repro.py @@ -4098,13 +4098,19 @@ def func2(arg0, arg1): funcs.append(func2) + # test small shapes + funcs.append(func2) + small_size = cpu_vec_isa.pick_vec_isa().nelements(dtype=torch.bfloat16) // 2 + example_shapes = [ [(10, 32, 20, 20), (10, 32, 20, 20)], [(10, 32, 20, 20)], [(10, 32, 20, 20), (10, 32, 20, 20)], + # test small shapes + [(small_size), (small_size)], ] - mixed_types = [False, False, True] - check_vecns = [True, True, True] + mixed_types = [False, False, True, False] + check_vecns = [True, True, True, False] for dtype in [torch.bfloat16, torch.float16]: for func, shapes, mixed, check_vecn in zip( diff --git a/torch/_inductor/codegen/cpp.py b/torch/_inductor/codegen/cpp.py index e91aae7e2de813..aeaf1273bf7031 100644 --- a/torch/_inductor/codegen/cpp.py +++ b/torch/_inductor/codegen/cpp.py @@ -3260,7 +3260,13 @@ def select_tiling( tiling_indices = self._select_tiling_indices( fn_list, var_sizes_list, tiling_factor ) + if tiling_indices: + group, reduction_group = max( + var_sizes_list, key=lambda sizes: len(sizes[1]) + ) + call_ranges = tuple(group) + tuple(reduction_group) + if config.cpp.enable_tiling_heuristics: def _try_get_stride( @@ -3296,10 +3302,6 @@ def _is_valid_indices( < len(itervars) ) - group, reduction_group = max( - var_sizes_list, key=lambda sizes: len(sizes[1]) - ) - call_ranges = tuple(group) + tuple(reduction_group) itervars = [ sympy_index_symbol_with_prefix(SymT.XBLOCK, n) for n in range(len(call_ranges)) @@ -3376,6 +3378,27 @@ def _is_valid_indices( # when needed. return [], [] + if dtype in DTYPE_LOWP_FP: + # For lower precision data type, if the call_range is not long enough, + # use tiling_factor // 2 for better performance + factor_lowp = cpu_vec_isa.pick_vec_isa().nelements(dtype=dtype) + for tiling_indice in tiling_indices: + if tiling_indice < 0: + tiling_indice = tiling_indice + len(call_ranges) + if tiling_indice < 0 or tiling_indice >= len(call_ranges): + continue + if has_free_symbols(call_ranges): + call_range = V.graph.sizevars.size_hint( + call_ranges[tiling_indice], fallback=0 + ) + if call_range < factor_lowp: + V.graph.sizevars.guard_lt(call_range, factor_lowp) + tiling_factor = factor_lowp // 2 + break + elif call_ranges[tiling_indice] < factor_lowp: + tiling_factor = factor_lowp // 2 + break + if len(tiling_indices) == 1: return [tiling_factor], tiling_indices if len(tiling_indices) == 2: From f37ae0f1e30b9dd5bace8e00bea9ea4bd4596af5 Mon Sep 17 00:00:00 2001 From: penguin-wwy <940375606@qq.com> Date: Fri, 6 Sep 2024 08:18:35 +0000 Subject: [PATCH 0539/1018] [Docs] Update FileCheck doc (#135199) Pull Request resolved: https://github.com/pytorch/pytorch/pull/135199 Approved by: https://github.com/soulitzer --- test/HowToWriteTestsUsingFileCheck.md | 17 ++++++++++++++++- 1 file changed, 16 insertions(+), 1 deletion(-) diff --git a/test/HowToWriteTestsUsingFileCheck.md b/test/HowToWriteTestsUsingFileCheck.md index 0795c23002a162..3a9b28d3574678 100644 --- a/test/HowToWriteTestsUsingFileCheck.md +++ b/test/HowToWriteTestsUsingFileCheck.md @@ -93,7 +93,7 @@ annotations from the example above one would write: * `CHECK-COUNT-EXACTLY-: ` Scans the input and succeeds when a line containing exactly `NUM` entries of `PATTERN` is found. -* `CHECK-DAG: pattern` +* `CHECK-DAG: ` Works similar to the usual `CHECK` pragma, but also matches if there exists a way to reorder the CHECK-DAG pragmas to satisfy all patterns. For example the following pattern: @@ -110,3 +110,18 @@ annotations from the example above one would write: bar end ``` +* `CHECK-SOURCE-HIGHLIGHTED: ` + Check for highlighted source ranges. This is useful when writing tests regarding generated error messages that require source code highlighting. + For example the following pattern: + ``` + # CHECK-SOURCE-HIGHLIGHTED: raise Exception("raised exception + ``` + would match the following input: + ``` + def method_that_raises() -> torch.Tensor: + raise Exception("raised exception") # noqa: TRY002 + ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE + builtins.Exception: raised exception + ``` +* `CHECK-REGEX: ` + Scans the input until `PATTERN` is matched, accepts RE syntax for std::regex. From 3538a8d581482c7b6d5fd745698523e92334a5fd Mon Sep 17 00:00:00 2001 From: "Sun, Jiayi" Date: Thu, 5 Sep 2024 20:31:05 -0700 Subject: [PATCH 0540/1018] [inductor] Fix gen_transposed_tile_load_store (#135307) Recent PR: https://github.com/pytorch/pytorch/pull/131745 bring new VLA logical in cpp codegen. And it will raise build fail error on MSVC and error code is `Compiler Error C2131`: https://learn.microsoft.com/en-us/cpp/error-messages/compiler-errors-1/compiler-error-c2131?view=msvc-170 reproduce UT: ```cmd pytest test\inductor\test_torchinductor_dynamic_shapes.py -v -k test_large_block_sizes_dynamic_shapes_cpu ``` Original generated code: ```c++ alignas(16) float tmp1[static_cast(((-256LL)*(c10::div_floor_integer(static_cast(ks1), static_cast(16LL)))) + (16LL*ks1))]; ``` Changes: allocate a large-enough fixed-sized buffer. New genarated code: ```c++ alignas(16) float tmp1[16*16]; ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/135307 Approved by: https://github.com/jgong5, https://github.com/jansel --- torch/_inductor/codegen/cpp.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/torch/_inductor/codegen/cpp.py b/torch/_inductor/codegen/cpp.py index aeaf1273bf7031..af892c20a0a44b 100644 --- a/torch/_inductor/codegen/cpp.py +++ b/torch/_inductor/codegen/cpp.py @@ -3086,10 +3086,7 @@ def gen_transposed_tile_load_store(self, name, var, index, is_store): tile_var = self.cse.cache[load_or_store] if need_define: - define_line = ( - f"alignas({factor}) {DTYPE_TO_CPP[dtype]} {tile_var}" - f"[{cexpr_index(self.outer_num_elems * self.inner_num_elems)}];" - ) + define_line = f"alignas({factor}) {DTYPE_TO_CPP[dtype]} {tile_var}[{factor}*{factor}];" self.preloads.writeline(define_line) load_or_store = load_or_store.replace("__place_holder__", str(tile_var)) From 3ba879376871146e92514c7da3b7e423f5129496 Mon Sep 17 00:00:00 2001 From: Michael Lazos Date: Thu, 5 Sep 2024 16:42:55 -0700 Subject: [PATCH 0541/1018] [Dynamo] Automatically in-graph traceable tensor subclass ctors (#135151) Fixes https://github.com/pytorch/pytorch/issues/114389 Previously, dynamo would attempt to trace through the `__init__` of traceable tensor subclasses, since their constructors are AOT dispatcher traceable by definition, dynamo should automatically put these in the graph like we do for any other tensors. Not doing this is difficult because dynamo would need to apply mutations post tensor subclass creation in the graph. Pull Request resolved: https://github.com/pytorch/pytorch/pull/135151 Approved by: https://github.com/bdhirsh --- test/dynamo/test_subclasses.py | 94 +++++++++++++++++++++++++ torch/_dynamo/variables/torch.py | 4 ++ torch/_dynamo/variables/user_defined.py | 6 +- torch/utils/_python_dispatch.py | 7 +- 4 files changed, 109 insertions(+), 2 deletions(-) diff --git a/test/dynamo/test_subclasses.py b/test/dynamo/test_subclasses.py index fdc74592ec4d23..5379405bfbe58e 100644 --- a/test/dynamo/test_subclasses.py +++ b/test/dynamo/test_subclasses.py @@ -30,6 +30,7 @@ ) from torch.testing._internal.inductor_utils import HAS_CUDA from torch.testing._internal.two_tensor import TwoTensor +from torch.utils._python_dispatch import return_and_correct_aliasing def traceable_subclass(c): @@ -1427,6 +1428,99 @@ def __metadata_guard__(self, x, y): lambda: torch.compile(lambda x: x * x)(x), ) + def test_subclass_constructor_proxying(self): + import dataclasses + from collections import namedtuple + from typing import Any + + @dataclasses.dataclass(frozen=True) + class SubclassTensorArgs: + original_shape: torch.Size + device: torch.device + inner_meta: Any + + SubclassTensorArgs2 = namedtuple( + "SubclassTensorArgs2", + [ + "original_shape", + "device", + "inner_meta", + ], + ) + + class SubclassTensor(torch.Tensor): + @staticmethod + def __new__(cls, a, meta): + shape = a.shape + kwargs = {} + kwargs["strides"] = a.stride() + kwargs["storage_offset"] = a.storage_offset() + kwargs["device"] = a.device + kwargs["layout"] = a.layout + kwargs["requires_grad"] = a.requires_grad + kwargs["dtype"] = a.dtype + out = torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) + return out + + def __init__(self, a, meta): + self.a = a + self.meta = meta + + def __repr__(self): + a_repr = repr(self.a) + return f"SubclassTensor({a_repr})" + + def __tensor_flatten__(self): + return ["a"], self.meta + + @staticmethod + def __tensor_unflatten__(inner_tensors, meta, _, __): + a = inner_tensors["a"] + return SubclassTensor(a, meta) + + @classmethod + def __torch_dispatch__(cls, func, types, args, kwargs): + if kwargs is None: + kwargs = {} + args_a = pytree.tree_map( + lambda x: x.a if isinstance(x, SubclassTensor) else x, args + ) + kwargs_a = pytree.tree_map( + lambda x: x.a if isinstance(x, SubclassTensor) else x, kwargs + ) + out_a = func(*args_a, **kwargs_a) + out = pytree.tree_map( + lambda x: SubclassTensor( + x, SubclassTensorArgs2(x.shape, x.device, None) + ) + if isinstance(x, torch.Tensor) + else x, + out_a, + ) + return return_and_correct_aliasing(func, args, kwargs, out) + + @torch.compile(fullgraph=True) + def f1(x): + meta = SubclassTensorArgs( + x.shape, x.device, SubclassTensorArgs(x.shape, x.device, None) + ) + out = SubclassTensor(x, meta) + return out * out + + x = torch.randn(3, 3) + f1(x) + + @torch.compile(fullgraph=True) + def f1(x): + meta = SubclassTensorArgs2( + x.shape, x.device, SubclassTensorArgs2(x.shape, x.device, None) + ) + out = SubclassTensor(x, meta) + return out * out + + x = torch.randn(3, 3) + f1(x) + def test_torch_function_subclass_survives_into_aot_autograd(self): # If you have a tensor subclass that relies on dispatch into the same op # without unwrapping and calling torch._C.DisableTorchFunctionSubclass(), diff --git a/torch/_dynamo/variables/torch.py b/torch/_dynamo/variables/torch.py index f4d4eb900c1735..caab550afb82b5 100644 --- a/torch/_dynamo/variables/torch.py +++ b/torch/_dynamo/variables/torch.py @@ -15,6 +15,7 @@ from torch._guards import TracingContext from torch._logging import warning_once from torch._streambase import _StreamBase +from torch.utils._python_dispatch import is_traceable_wrapper_subclass_type from .. import config, polyfills, variables from ..codegen import PyCodegen @@ -1025,6 +1026,9 @@ def call_nn_parameter(cls, tx, data=None, requires_grad=True): if data.source: return cls._nn_param_via_prefix_insert(tx, data, requires_grad) + if is_traceable_wrapper_subclass_type(data.class_type): + unimplemented("Parameter constructor with tensor subclass NYI") + if not can_convert_to_tracable_parameter(): unimplemented("Workaround for issues with nn_parameter construction") diff --git a/torch/_dynamo/variables/user_defined.py b/torch/_dynamo/variables/user_defined.py index 54dc9119d267b4..0d53d6d21f5502 100644 --- a/torch/_dynamo/variables/user_defined.py +++ b/torch/_dynamo/variables/user_defined.py @@ -16,6 +16,7 @@ import torch._dynamo.config import torch.nn from torch._guards import TracingContext +from torch.utils._python_dispatch import is_traceable_wrapper_subclass_type from .. import polyfills, variables from ..bytecode_transformation import create_call_function @@ -509,7 +510,10 @@ def call_function( user_cls_source=self.source, mutable_local=MutableLocal(), ) - elif self.value in self._in_graph_classes(): + elif ( + self.value in self._in_graph_classes() + or is_traceable_wrapper_subclass_type(self.value) + ): # torch.LongTensor cannot accept a list of FakeTensors. # So we stack the list of FakeTensors instead. if ( diff --git a/torch/utils/_python_dispatch.py b/torch/utils/_python_dispatch.py index eea7f058bfc5e0..70c65a69071733 100644 --- a/torch/utils/_python_dispatch.py +++ b/torch/utils/_python_dispatch.py @@ -3,7 +3,7 @@ import warnings from dataclasses import dataclass -from typing import Any, Dict, List, Optional, Set, Union, Protocol, Tuple, Sequence, overload, Deque +from typing import Any, Dict, List, Optional, Set, Union, Protocol, Tuple, Sequence, overload, Deque, Type from typing_extensions import TypeGuard from collections import deque @@ -391,6 +391,11 @@ def is_traceable_wrapper_subclass(t: object) -> TypeGuard[TensorWithFlatten]: and hasattr(t, "__tensor_unflatten__") ) +def is_traceable_wrapper_subclass_type(t: Type) -> TypeGuard[Type[TensorWithFlatten]]: + """Same as above, but takes a type argument instead of an instance.""" + return (issubclass(t, torch.Tensor) and t != torch.Tensor + and hasattr(t, "__tensor_flatten__") and hasattr(t, "__tensor_unflatten__")) + def transform_subclass(t, callback, outer_size=None, outer_stride=None): """ From ea08e7bf29c21c4b9b0fe4e234a93f94be513d56 Mon Sep 17 00:00:00 2001 From: Yan Zhiwei Date: Fri, 6 Sep 2024 07:45:04 +0000 Subject: [PATCH 0542/1018] check compilation status before query cudnn version in conv (#135332) This PR is created for fixing the https://github.com/pytorch/pytorch/issues/135322. The cudnn compilation status should be check firstly before querying version, otherwise, conv may trigger runtimeerror before any check in other non-cuda backends. Pull Request resolved: https://github.com/pytorch/pytorch/pull/135332 Approved by: https://github.com/EikanWang, https://github.com/atalman --- aten/src/ATen/native/Convolution.cpp | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/aten/src/ATen/native/Convolution.cpp b/aten/src/ATen/native/Convolution.cpp index 4a72c102777d15..2f85c9724372a8 100644 --- a/aten/src/ATen/native/Convolution.cpp +++ b/aten/src/ATen/native/Convolution.cpp @@ -421,6 +421,9 @@ struct ConvParams { // cudnn and miopen are guaranteed not to be on mobile, and T102591915 / T110194934 suggest // that maybe the compiledWithCuDNN() check sometimes segfaults (though I can't imagine how) #if !defined(C10_MOBILE) + if (!detail::getCUDAHooks().compiledWithCuDNN()) { + return false; + } if (needs_64bit_indexing_no_split(input, weight)) { static long cudnn_version = detail::getCUDAHooks().versionCuDNN(); if (!(cudnn_version >= 90300 && at::native::cudnnv8_enabled_check_debug())) { @@ -430,9 +433,6 @@ struct ConvParams { return false; } } - if (!detail::getCUDAHooks().compiledWithCuDNN()) { - return false; - } if (!input.is_cuda() || !cudnn_enabled) { return false; } From 95bd5a1b0f0afea0b9a2b5f9f47cd7e223a94296 Mon Sep 17 00:00:00 2001 From: "Edward Z. Yang" Date: Thu, 5 Sep 2024 19:35:16 -0700 Subject: [PATCH 0543/1018] Ignore fresh unbacked when doing recursive make_fx inside HOPs (#135053) Internal xref: https://fb.workplace.com/groups/6829516587176185/posts/7705964779531357/ This now also incorporates a test from https://github.com/pytorch/pytorch/pull/133585 (which it fixes) and the prep PR https://github.com/pytorch/pytorch/pull/134407 Including the PR desc from that: I am trying to fix a problem reported by user in [fb.workplace.com/groups/6829516587176185/permalink/7705964779531357](https://fb.workplace.com/groups/6829516587176185/permalink/7705964779531357/) The summary of this problem is that when we do collect metadata analysis in AOTAutograd, we accumulate pending unbacked symbols which are going to be discarded at the end of the trace. However, if we do a recursive make_fx inside tracing, as occurs with torch.cond, we end up seeing that there are pending unbacked symbols that aren't associated with a binding, even though it's spurious (they've leaked into the inner make_fx call from the outer AOTAutograd analysis). In https://github.com/pytorch/pytorch/pull/133588 I tried to just prevent adding the symbols to the pending list at all in the first place. But this itself caused some problems which were fixed in https://github.com/pytorch/pytorch/pull/124785 . The problem fixed in that PR is that when we allocate tangents that have unbacked size, something prevented them from having correct unbacked SymInts when ignore fresh unbacked SymInts was enabled. So I had patched it at the time by just not suppressing pending symbols and clearing them out some other way. I think... I was wrong in that PR? That is to say, it was OK to avoid putting the fresh unbacked symbols in the pending list; the real problem was suppressing unbacked renamings. But there doesn't seem to be a good reason to suppress these; this PR shows that it doesn't actually fail any tests if you do these anyway. Intuitively, this makes sense, because you can't trigger renamings unless you're actually adding unbacked symbols to the pending set. Signed-off-by: Edward Z. Yang Pull Request resolved: https://github.com/pytorch/pytorch/pull/135053 Approved by: https://github.com/ydwu4 --- test/export/test_export.py | 55 +++++++++++++++++++ .../test_torchinductor_dynamic_shapes.py | 24 +++++++- .../collect_metadata_analysis.py | 10 +++- torch/fx/experimental/proxy_tensor.py | 15 ++++- torch/fx/experimental/symbolic_shapes.py | 4 -- 5 files changed, 99 insertions(+), 9 deletions(-) diff --git a/test/export/test_export.py b/test/export/test_export.py index 4220f21ecb14d0..0d05b57b444e4e 100644 --- a/test/export/test_export.py +++ b/test/export/test_export.py @@ -715,6 +715,61 @@ def forward(self, x, c): foo, bad_example_inp, dynamic_shapes=dynamic_shapes, strict=False ) + def test_unbacked_to_cond(self): + class M(torch.nn.Module): + def forward(self, a): + az = a.nonzero() + + def true_fn(x): + return (x + 1).sum() + + def false_fn(x): + return (x + 3).sum() + + r = torch.cond(az.size(0) > 3, true_fn, false_fn, (az,)) + return r * 2 + + M()(torch.randn(7)) + torch.export.export(M(), (torch.randn(7),)) + + def test_unbacked_to_cond_passthrough(self): + class M(torch.nn.Module): + def forward(self, a): + az = a.nonzero() + + def true_fn(x): + return x + 1 + + def false_fn(x): + return x + 3 + + r = torch.cond(az.size(0) > 3, true_fn, false_fn, (az,)) + return r * 2 + + M()(torch.randn(7)) + torch.export.export(M(), (torch.randn(7),)) + + @torch._dynamo.config.patch(capture_scalar_outputs=True) + def test_cond_contains_unbacked_no_escape(self): + class M(torch.nn.Module): + def forward(self, a, b1, b2, c): + def true_fn(x): + return x * b1.item() + + def false_fn(x): + return x * b2.item() + + r = torch.cond(a, true_fn, false_fn, (c,)) + return r * 2 + + args = ( + torch.tensor(True), + torch.tensor([4]), + torch.tensor([4]), + torch.randn(10, requires_grad=True), + ) + torch.export.export(M(), args) + def test_state_tensors(self): class M(torch.nn.Module): # simple with register buffer def __init__(self) -> None: diff --git a/test/inductor/test_torchinductor_dynamic_shapes.py b/test/inductor/test_torchinductor_dynamic_shapes.py index 34c71f44332572..2e631445919ac2 100644 --- a/test/inductor/test_torchinductor_dynamic_shapes.py +++ b/test/inductor/test_torchinductor_dynamic_shapes.py @@ -7,7 +7,7 @@ import sys import unittest from functools import partial -from typing import List +from typing import List, Tuple import torch import torch.library @@ -561,6 +561,28 @@ def f(x): f(torch.tensor([3], device=device)) + @torch._dynamo.config.patch( + capture_scalar_outputs=True, capture_dynamic_output_shape_ops=True + ) + @torch._inductor.config.patch(implicit_fallbacks=True) + def test_multi_output_unbacked_custom_op(self, device): + @torch.library.custom_op("test::foo", mutates_args=()) + def foo(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + return torch.empty(2, device=x.device), torch.empty(3, device=x.device) + + @foo.register_fake + def _(x: torch.Tensor) -> torch.Tensor: + ctx = torch.library.get_ctx() + u0 = ctx.new_dynamic_size() + return torch.empty(u0, device=x.device), torch.empty(3, device=x.device) + + @torch.compile(fullgraph=True) + def f(x): + a, b = torch.ops.test.foo(x) + return a.sum() + b.sum() + + f(torch.tensor([3], device=device)) + @torch._inductor.config.patch(disable_cpp_codegen=True) def test_floor(self): # `int(n * 0.2)` will be generated as `floor(0.2*s0)` of torch.SymInt type. diff --git a/torch/_functorch/_aot_autograd/collect_metadata_analysis.py b/torch/_functorch/_aot_autograd/collect_metadata_analysis.py index 820c7d288cf2c6..51a0aeb24ad496 100644 --- a/torch/_functorch/_aot_autograd/collect_metadata_analysis.py +++ b/torch/_functorch/_aot_autograd/collect_metadata_analysis.py @@ -9,6 +9,7 @@ """ import collections +import contextlib import logging from functools import wraps from typing import Callable, DefaultDict, Dict, List, Optional @@ -162,15 +163,18 @@ def inner(*flat_args): # It doesn't matter if we run this under predispatch or not because it is # only for figuring out metadata mode = FunctionalTensorMode(_allow_token_discovery=True) - with disable_above, mode: + suppress_pending = contextlib.nullcontext() + fake_mode = detect_fake_mode() + if fake_mode and (shape_env := fake_mode.shape_env): + suppress_pending = shape_env.ignore_fresh_unbacked_symbols() + with disable_above, mode, suppress_pending: # precondition: The passed in function already handles unflattening inputs + flattening outputs flat_f_args = pytree.tree_map(_to_fun, flat_args) flat_f_outs = f(*flat_f_args) # We didn't do any tracing, so we don't need to process the # unbacked symbols, they will just disappear into the ether. # Also, prevent memoization from applying. - if (fake_mode := detect_fake_mode()) and (shape_env := fake_mode.shape_env): - shape_env.pending_fresh_unbacked_symbols.clear() + if fake_mode: fake_mode.epoch += 1 fake_mode.reset_nt_tensor_id_counter() diff --git a/torch/fx/experimental/proxy_tensor.py b/torch/fx/experimental/proxy_tensor.py index 7a3ece886506c3..5a5771d2ae8f7a 100644 --- a/torch/fx/experimental/proxy_tensor.py +++ b/torch/fx/experimental/proxy_tensor.py @@ -596,6 +596,19 @@ def track_tensor_tree( constant: Optional[_NestedTensors], tracer: _ProxyTracer, ) -> T: + # NB: We call set_unbacked_bindings only on the *topmost* call to + # track_tensor_tree, not recursive calls. This is because there must + # be only ONE unbacked_binding proxy call, and it should be the one + # where all of the unbacked SymInts actually first come into existence. + # If you call this again on the inner proxies for the tuple projections, + # you will have multiple unbacked_bindings for the same symbol, but + # they're not going to show up anywhere. + # + # I was briefly deceived into setting unbacked bindings recursively when + # working on https://github.com/pytorch/pytorch/pull/133585 because I + # observed that some extra unbacked bindings were needed to handle some + # higher order operator code. But actually it looks like this was + # just an unrelated bug that needed to be fixed separately. _set_unbacked_bindings(inner_res, proxy_res) def wrap_with_proxy( @@ -2198,5 +2211,5 @@ def _set_unbacked_bindings(out: object, out_proxy: _NestedProxys) -> None: fake_mode = torch._C._get_dispatch_mode(torch._C._TorchDispatchModeKey.FAKE) if fake_mode and fake_mode.shape_env: if symbol_to_path := compute_unbacked_bindings(fake_mode.shape_env, out): - assert isinstance(out_proxy, Proxy) + assert isinstance(out_proxy, Proxy), out_proxy out_proxy.node.meta["unbacked_bindings"] = symbol_to_path diff --git a/torch/fx/experimental/symbolic_shapes.py b/torch/fx/experimental/symbolic_shapes.py index c6a228f5256e71..369708d2c2dbc0 100644 --- a/torch/fx/experimental/symbolic_shapes.py +++ b/torch/fx/experimental/symbolic_shapes.py @@ -597,8 +597,6 @@ def compute_unbacked_bindings(shape_env, example_value, old_example_value=None, """ if shape_env is None: return - if shape_env._ignore_fresh_unbacked_symbols_tls(): - return fs = shape_env.pending_fresh_unbacked_symbols pending = set(fs) if pending: @@ -2722,8 +2720,6 @@ def _rename_unbacked_to(self, orig_s: sympy.Symbol, new_s: sympy.Symbol): assert isinstance(new_s, sympy.Symbol), new_s assert free_unbacked_symbols(new_s), new_s assert free_unbacked_symbols(orig_s), orig_s - if self._ignore_fresh_unbacked_symbols_tls(): - return dest = self.replacements.get(orig_s) assert not free_unbacked_symbols(dest), f"{orig_s} -> {dest}" self._set_replacement(orig_s, new_s, "rename_unbacked_to") From 2a6f2c3d3726bece625eb9d27a06099b4dfc7782 Mon Sep 17 00:00:00 2001 From: "Edward Z. Yang" Date: Wed, 4 Sep 2024 14:19:23 -0700 Subject: [PATCH 0544/1018] Also handle compiler collective when input variable doesn't exist on all ranks (#135147) Internal xref: https://fb.workplace.com/groups/3095840833991792/permalink/3810738595835342/ Signed-off-by: Edward Z. Yang Pull Request resolved: https://github.com/pytorch/pytorch/pull/135147 Approved by: https://github.com/jansel --- test/distributed/test_dynamo_distributed.py | 25 ++++++++++++++++++--- torch/_dynamo/variables/builder.py | 8 ++++--- 2 files changed, 27 insertions(+), 6 deletions(-) diff --git a/test/distributed/test_dynamo_distributed.py b/test/distributed/test_dynamo_distributed.py index 5ba20a7c948d41..52943afe984971 100644 --- a/test/distributed/test_dynamo_distributed.py +++ b/test/distributed/test_dynamo_distributed.py @@ -914,9 +914,6 @@ def test_compiler_collectives_dim_mismatch(self): with _dynamo_dist_per_rank_init(self.rank, self.world_size): torch._dynamo.utils.clear_compilation_metrics() - # TODO: This should be possible to do inside the function, but - device = f"cuda:{self.rank}" - @torch.compile() def f(x, y): zx = x.shape @@ -940,6 +937,28 @@ def f(x, y): for r in res[1:]: self.assertEqual(res[0], r) + @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") + @config.patch(enable_compiler_collectives=True) + def test_compiler_collectives_missing_source(self): + with _dynamo_dist_per_rank_init(self.rank, self.world_size): + torch._dynamo.utils.clear_compilation_metrics() + + @torch.compile() + def f(rank, xs): + return xs[rank].sum() + + xs = [] + for _ in range(self.world_size): + xs.append(torch.randn(10, device=self.rank)) + + f(self.rank, xs) + + metrics = torch._dynamo.utils.get_compilation_metrics() + res = [None] * self.world_size + torch.distributed.all_gather_object(res, len(metrics)) + for r in res[1:]: + self.assertEqual(res[0], r) + @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") @patch.object(torch._inductor.config, "fx_graph_cache", False) @patch.object(torch._inductor.config, "fx_graph_remote_cache", False) diff --git a/torch/_dynamo/variables/builder.py b/torch/_dynamo/variables/builder.py index c54c38f9bf1f83..5644f45e098ed1 100644 --- a/torch/_dynamo/variables/builder.py +++ b/torch/_dynamo/variables/builder.py @@ -2522,9 +2522,11 @@ def update_frame_state(size, stride): else: # Apply the updates for sub_state in st.all_states: - update_frame_state( - sub_state.input_sizes[name], sub_state.input_strides[name] - ) + # Not all inputs are necessarily present on all ranks + if name in sub_state.input_sizes and name in sub_state.input_strides: + update_frame_state( + sub_state.input_sizes[name], sub_state.input_strides[name] + ) frame_state_entry = tx.output.frame_state[name] # TODO: index export_constraints ahead of time so we don't have to From 89cea2ea09624b664fe61fced41b351ee3bfde2a Mon Sep 17 00:00:00 2001 From: Jiong Gong Date: Thu, 5 Sep 2024 22:35:58 -0700 Subject: [PATCH 0545/1018] [inductor][cpp][gemm] fix autotune runtime error from linear_binary fusion (#135275) Pull Request resolved: https://github.com/pytorch/pytorch/pull/135275 Approved by: https://github.com/leslie-fang-intel --- test/inductor/test_cpu_select_algorithm.py | 34 ++++++++++++++++++++ torch/_inductor/codegen/cpp_gemm_template.py | 2 +- 2 files changed, 35 insertions(+), 1 deletion(-) diff --git a/test/inductor/test_cpu_select_algorithm.py b/test/inductor/test_cpu_select_algorithm.py index 3dd5400e68eb53..ca1d4f12d6a4f4 100644 --- a/test/inductor/test_cpu_select_algorithm.py +++ b/test/inductor/test_cpu_select_algorithm.py @@ -490,6 +490,40 @@ def forward(self, x): self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1) self.assertEqual(counters["inductor"]["cpp_epilogue_fusion_counter"], 1) + @inductor_config.patch({"freezing": True}) + @patches + @torch.no_grad + @unittest.skipIf(not TEST_MKL, "Test requires MKL") + @parametrize("batch_size", (384,)) + @parametrize("in_features", (196,)) + @parametrize("out_features", (384,)) + @parametrize("bias", (True, False)) + @parametrize( + "binary", + ("add",), + ) + @dtypes(torch.float, torch.bfloat16, torch.half) + def test_linear_with_binary_input_3d( + self, batch_size, in_features, out_features, bias, binary, dtype + ): + class M(torch.nn.Module): + def __init__(self, bias, binary, other): + super().__init__() + self.linear = torch.nn.Linear(in_features, out_features, bias) + self.binary = _get_epilogue(binary, other) + + def forward(self, x): + return self.binary(self.linear(x)) + + counters.clear() + B = (2, batch_size) + v = torch.randn(*B, in_features).to(dtype=dtype) + u = torch.randn(*B, out_features).to(dtype=dtype) + mod = M(bias=bias, binary=binary, other=u).to(dtype=dtype).eval() + with verify(dtype) as (atol, rtol): + self.common(mod, (v,), atol=atol, rtol=rtol) + self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1) + @inductor_config.patch({"freezing": True}) @patches @torch.no_grad diff --git a/torch/_inductor/codegen/cpp_gemm_template.py b/torch/_inductor/codegen/cpp_gemm_template.py index b4b33b28216b4b..09a339cd8824fd 100644 --- a/torch/_inductor/codegen/cpp_gemm_template.py +++ b/torch/_inductor/codegen/cpp_gemm_template.py @@ -531,7 +531,7 @@ def normalize_shapes(inputs, layout_or_out): new_inputs = list(inputs) X = inputs[0] W = inputs[1] - B = inputs[2] if len(inputs) > 2 else None + B = inputs[2] if has_bias else None if isinstance(W, ir.IRNode): if trans_w: if not isinstance(W, ir.TensorBox): From cee3ed19168eaba64e548f18bc691d6c05051973 Mon Sep 17 00:00:00 2001 From: "Wu, Chunyuan" Date: Thu, 5 Sep 2024 22:36:00 -0700 Subject: [PATCH 0546/1018] [inductor] [cpp] improve cache blocking for is_dynamic_M (#131306) ## Performance Models with >= 3% performance speedup are listed below: ### AMP single-thread dynamic shape (measured on CPU with AMX support) No regressions | Model Family | Model Name | Speedup | |--------------|------------|---------| torchbench | soft_actor_critic| 3% Pull Request resolved: https://github.com/pytorch/pytorch/pull/131306 Approved by: https://github.com/jgong5, https://github.com/leslie-fang-intel ghstack dependencies: #135275 Co-authored-by: Jiong Gong --- torch/_inductor/codegen/cpp_gemm_template.py | 33 ++++++++++-- torch/_inductor/codegen/cpp_prefix.h | 56 ++++++++++++++++++++ 2 files changed, 85 insertions(+), 4 deletions(-) diff --git a/torch/_inductor/codegen/cpp_gemm_template.py b/torch/_inductor/codegen/cpp_gemm_template.py index 09a339cd8824fd..7bed1503900787 100644 --- a/torch/_inductor/codegen/cpp_gemm_template.py +++ b/torch/_inductor/codegen/cpp_gemm_template.py @@ -64,11 +64,28 @@ const auto Nt_blocks = Nr_blocks; const auto Kt_blocks = Kr_blocks; {%- endif %} - const int64_t Mc_blocks = Mt_blocks; - const int64_t Nc_blocks = 1; - const int64_t Kc_blocks = Kt_blocks; + int64_t Mc_blocks, Nc_blocks, Kc_blocks; + uint32_t L1_cache_size = {{L1_cache_size}}; + uint32_t L2_cache_size = {{L2_cache_size}}; + mm_get_cache_blocking<{{kernel.dtype(X)}}, {{kernel.dtype(W)}}>( + num_threads, + M, + N, + K, + Mr, + Nr, + Kr, + Mt_blocks, + Nt_blocks, + Kt_blocks, + Mc_blocks, + Nc_blocks, + Kc_blocks, + L1_cache_size, + L2_cache_size + ); const int64_t num_Mc_blocks = (Mr_blocks + Mc_blocks - 1) / Mc_blocks; - const int64_t num_Nc_blocks = Nr_blocks; + const int64_t num_Nc_blocks = (Nr_blocks + Nc_blocks - 1) / Nc_blocks; const int64_t num_k_slices = (Kr_blocks + Kt_blocks - 1) / Kt_blocks; {%- else %} constexpr int64_t M = {{kernel.size(GemmOut, 0)}}; @@ -979,6 +996,12 @@ def get_reindexer(epilogue_node): if isinstance(micro_gemm, CppMicroGemmAMX): counters["inductor"]["cpp_micro_gemm_amx_counter"] += 1 + L1_cache_size = torch._C._cpu._L1d_cache_size() # per core cache size in Bytes + assert L1_cache_size > 0, f"Expect L1_cache_size > 0 but got {L1_cache_size}" + + L2_cache_size = torch._C._cpu._L2_cache_size() # per core cache size in Bytes + assert L2_cache_size > 0, f"Expect L2_cache_size > 0 but got {L2_cache_size}" + options = dict( X=X, W=W, @@ -1008,6 +1031,8 @@ def get_reindexer(epilogue_node): w_zp=w_zp, acc_buf_dtype=torch.int32 if int8_gemm else torch.float, DTYPE_TO_CPP=DTYPE_TO_CPP, + L1_cache_size=L1_cache_size, + L2_cache_size=L2_cache_size, ) with contextlib.ExitStack() as stack: for buf in fake_buffers: diff --git a/torch/_inductor/codegen/cpp_prefix.h b/torch/_inductor/codegen/cpp_prefix.h index d3cf6270a3ec44..ace9a182472a0a 100644 --- a/torch/_inductor/codegen/cpp_prefix.h +++ b/torch/_inductor/codegen/cpp_prefix.h @@ -734,6 +734,62 @@ void mm_get_thread_blocking( assert(Mt != 0); } +template +void mm_get_cache_blocking( + int num_threads, + int64_t M, + int64_t N, + int64_t K, + int64_t Mr, + int64_t Nr, + int64_t Kr, + int64_t Mt_blocks, + int64_t Nt_blocks, + int64_t Kt_blocks, + int64_t& Mc_blocks, + int64_t& Nc_blocks, + int64_t& Kc_blocks, + uint32_t L1_cache_size, + uint32_t L2_cache_size) { + // See NOTE [CPP GEMM Cache Blocking Algorithm] for the cache blocking algorithm. + // TODO(jgong5): cache cache blocking results + // TODO: tune the factor here + float L1_limit_factor = 1.0; + float L2_limit_factor = 0.5; + + auto L1 = L1_cache_size * L1_limit_factor; + auto L2 = L2_cache_size * L2_limit_factor; + + constexpr size_t num_byte_A = sizeof(X_t); + constexpr size_t num_byte_B = sizeof(W_t); + + int64_t size_cache_B = Kr * Kt_blocks * Nr * num_byte_B; + Kc_blocks = Kt_blocks; + if (size_cache_B > L1) { + Kc_blocks = (int64_t)std::floor(L1 / (Kr * Nr * num_byte_B)); + } + + float min_Mc_ratio = 2; + int64_t min_Mc_blocks = std::ceil(min_Mc_ratio * Mr / Nr); + auto Kt_bytes = Kt_blocks * Kr * num_byte_A; + if (min_Mc_blocks * Mr * Kt_bytes < L2) { + Mc_blocks = std::min(Mt_blocks, (int64_t)std::floor(L2 / (Mr * Kt_bytes))); + Nc_blocks = 1; + } else { + Mc_blocks = Mt_blocks; + Nc_blocks = std::min((int64_t)std::ceil((float)Mc_blocks * Mr / Nr), Nt_blocks); + auto Nc_bytes = Nc_blocks * Nr * 4; + auto Kc_bytes = Kc_blocks * Kr * num_byte_A; + if (Mc_blocks * Mr * (Kc_bytes + Nc_bytes) > L2) { + auto M_max = (std::sqrt(Kc_bytes * Kc_bytes + 16 * L2) - Kc_bytes) / 8; + if (M_max < Mc_blocks * Mr) { + Mc_blocks = (int64_t)std::floor(M_max / Mr); + Nc_blocks = std::min((int64_t)std::ceil((float)Mc_blocks * Mr / Nr), Nt_blocks); + } + } + } +} + inline void mm_get_thread_blocks( int thread_id, int64_t M_blocks, From 57209e2049e0b443d528c105aeb09cb44b429eed Mon Sep 17 00:00:00 2001 From: Julia Guo <153684546+juliagmt-google@users.noreply.github.com> Date: Fri, 6 Sep 2024 13:27:24 +0000 Subject: [PATCH 0547/1018] [CD] Update binary_linux_test.sh to include calling builder smoke test (#133869) Run smoke test Fixes #1969 Pull Request resolved: https://github.com/pytorch/pytorch/pull/133869 Approved by: https://github.com/atalman Co-authored-by: Andrey Talman --- .circleci/scripts/binary_linux_test.sh | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/.circleci/scripts/binary_linux_test.sh b/.circleci/scripts/binary_linux_test.sh index c30814982ef9e6..81d7bf2c511a05 100755 --- a/.circleci/scripts/binary_linux_test.sh +++ b/.circleci/scripts/binary_linux_test.sh @@ -119,6 +119,11 @@ fi # Test the package /builder/check_binary.sh +if [[ "\$GPU_ARCH_TYPE" != *s390x* && "\$GPU_ARCH_TYPE" != *xpu* && "\$GPU_ARCH_TYPE" != *rocm* && "$PACKAGE_TYPE" != libtorch ]]; then + # Exclude s390, xpu, rocm and libtorch builds from smoke testing + python /builder/test/smoke_test/smoke_test.py --package=torchonly --torch-compile-check disabled +fi + # Clean temp files cd /builder && git clean -ffdx From d5796caaaf109b720aa89df6c0a60a80453fcf0d Mon Sep 17 00:00:00 2001 From: Will Feng Date: Thu, 5 Sep 2024 20:19:44 -0700 Subject: [PATCH 0548/1018] AOTDispatcher: limit cases when we detach() graph inputs to non-leaves (#134193) This PR is slightly a revival / update to the discussion from https://github.com/pytorch/pytorch/pull/98960: Part of FSDP2's tracing strategy right now is that: (1) it is painful/difficult to handle the case where we have multiple graph input tensors that are aliased to each other and at least one of them is duplicated (2) we already have longstanding in logic to remove duplicate input tensors from the graph in dynamo. Morally, FSDP2 gives us duplicate input tensors in the backward graph for every `unsharded_param`, because we have (a) the `unsharded_param` being closed over by the backward hook to resize/allgather, and (b) the same `unsharded_param` being saved for backward by autograd (we now guarantee in the partitioner that we will always save the base tensor for backward and recompute views) (3) However, we were still seeing cases where the `unsharded_param` showed up twice in the backward graph inputs, as distinct tensor objects (with different python ids) instead of being true duplicates that dynamo can de-dup. It turns on that this was because we were `.detach()`ing the `unsharded_param` in AOTDispatcher before plumbing it through the compiled forward (and so autograd would save a detach'd version of the `unsharded_param`). This is precisely because of the logic from https://github.com/pytorch/pytorch/pull/98960. However, re-reading the detailed comments, it seems unnecessary to do a detach() on a graph input that is a (leaf) `nn.Parameter`, even if it happens to get no gradients in the backward. Since it is a leaf, we don't have to worry about the autograd engine "continuing to backprop through the graph beyond the current tensor" (the leaf has no other grad_fn for autograd to backprop through). So this PR makes us a bit less aggressive about calling detach() on inputs: we only do it when: (1) our graph input statically will get a `None` gradient (and also has no metadata mutations, the existing state) (2) **and** our graph input is a non-leaf tensor (so detach()ing is actually required to prevent autograd from incorrectly backpropping past the non-leaf. Pull Request resolved: https://github.com/pytorch/pytorch/pull/134193 Approved by: https://github.com/yf225 Co-authored-by: Will Feng --- test/distributed/test_inductor_collectives.py | 27 ++++++++----------- test/inductor/test_distributed_patterns.py | 9 +++---- test/inductor/test_torchinductor.py | 8 +++++- .../jit_compile_runtime_wrappers.py | 6 ++++- .../testing/_internal/optests/aot_autograd.py | 4 ++- 5 files changed, 30 insertions(+), 24 deletions(-) diff --git a/test/distributed/test_inductor_collectives.py b/test/distributed/test_inductor_collectives.py index af59f6188581d8..a19183425d3907 100644 --- a/test/distributed/test_inductor_collectives.py +++ b/test/distributed/test_inductor_collectives.py @@ -1013,22 +1013,17 @@ def func(inp): return ar input = torch.ones(4, 4, device="cuda", requires_grad=True) - # TODO implement backwards - with self.assertRaisesRegex( - RuntimeError, - "element 0 of tensors does not require grad and does not have a grad_fn", - ): - compiled = torch.compile( - func, backend="aot_eager" - ) # inductor bug with single-op allreduce graph - out = compiled(input) - out.sum().backward() - - correct_input = input.clone().detach().requires_grad_() - correct = func(correct_input) - correct.sum().backward() - self.assertTrue(same(out, correct)) - self.assertTrue(same(input.grad, correct_input.grad)) + compiled = torch.compile( + func, backend="aot_eager" + ) # inductor bug with single-op allreduce graph + out = compiled(input) + out.sum().backward() + + correct_input = input.clone().detach().requires_grad_() + correct = func(correct_input) + correct.sum().backward() + self.assertTrue(same(out, correct)) + self.assertTrue(same(input.grad, correct_input.grad)) def test_meta(self): x = torch.rand((2, 3, 4), device="meta") diff --git a/test/inductor/test_distributed_patterns.py b/test/inductor/test_distributed_patterns.py index a647a16c421da4..9aa070e979a0ae 100644 --- a/test/inductor/test_distributed_patterns.py +++ b/test/inductor/test_distributed_patterns.py @@ -42,7 +42,7 @@ def fw_pre_hook(mod, inp): torch.ops.fsdp.set_( mod.unsharded_weight, all_gather(mod.sharded_weight) ) - mod.weight = mod.unsharded_weight + mod._parameters["weight"] = mod.unsharded_weight # Forward: # mod.sharded_weight = local_shard (always) @@ -54,7 +54,7 @@ def fw_pre_hook(mod, inp): # mod.unsharded_weight = zero-sized allgather def fw_post_hook(mod, inp, out): - mod.weight = mod.sharded_weight + mod._parameters["weight"] = mod.sharded_weight mod.unsharded_weight.untyped_storage().resize_(0) def bw_pre_hook(mod, gO): @@ -74,7 +74,7 @@ def bw_pre_hook(mod, gO): torch.ops.fsdp.set_( mod.unsharded_weight, all_gather(mod.sharded_weight) ) - mod.weight = mod.unsharded_weight + mod._parameters["weight"] = mod.unsharded_weight # Backward: # mod.sharded_weight = local_shard (always) @@ -88,7 +88,7 @@ def bw_pre_hook(mod, gO): def bw_post_hook(mod, gI, gO): grad = mod.weight.grad new_grad = reduce_scatter(grad) - mod.weight = mod.sharded_weight + mod._parameters["weight"] = mod.sharded_weight mod.weight.grad = new_grad mod.unsharded_weight.untyped_storage().resize_(0) @@ -99,7 +99,6 @@ def bw_post_hook(mod, gI, gO): m.sharded_weight = nn.Parameter(reduce_scatter(m.weight)) m.unsharded_weight = nn.Parameter(all_gather(m.sharded_weight)) m.unsharded_weight.untyped_storage().resize_(0) - del m.weight m.register_full_backward_pre_hook(bw_pre_hook) m.register_full_backward_hook(bw_post_hook) diff --git a/test/inductor/test_torchinductor.py b/test/inductor/test_torchinductor.py index ffd4dcb0ebc3e0..2cf1253215ad0f 100644 --- a/test/inductor/test_torchinductor.py +++ b/test/inductor/test_torchinductor.py @@ -516,7 +516,13 @@ def reference_to_expect(actual_flat, correct_flat): correct_grad = compute_grads(ref_inputs, ref_kwargs, correct, grads) all_none_grads = all(x is None for x in correct_grad) - if all_none_grads: + tensor_args = [ + x + for x in pytree.tree_flatten(example_inputs)[0] + if isinstance(x, torch.Tensor) + ] + any_non_leaves = any(x.grad_fn is not None for x in tensor_args) + if all_none_grads and any_non_leaves: # See Note [Detaching inputs that never need gradients] # There are a handful of ops that can return None gradients, into of zero gradients. # If all inputs to an AOTAutograd graph are supposed to get None gradients, diff --git a/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py b/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py index f01e83d37b1c50..5dc236f314b079 100644 --- a/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py +++ b/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py @@ -512,7 +512,11 @@ def aot_dispatch_autograd( == MutationType.MUTATED_IN_GRAPH and fw_metadata.input_info[i].mutates_storage_metadata ) - if bw_out is None and not metadata_mutation_in_graph: + is_non_leaf = ( + fw_metadata.input_info[i].requires_grad + and not fw_metadata.input_info[i].is_leaf + ) + if bw_out is None and not metadata_mutation_in_graph and is_non_leaf: _indices_of_inps_to_detach.append(i) if aot_config.enable_log: diff --git a/torch/testing/_internal/optests/aot_autograd.py b/torch/testing/_internal/optests/aot_autograd.py index 4f281c7771757c..a5552e23c8a467 100644 --- a/torch/testing/_internal/optests/aot_autograd.py +++ b/torch/testing/_internal/optests/aot_autograd.py @@ -118,7 +118,9 @@ def check(args, ignore_failure=False): raise # See https://github.com/pytorch/pytorch/pull/98960#issuecomment-1505962215 - if all(x is None for x in orig_grad): + tensor_args = [x for x in pytree.tree_flatten(args)[0] if isinstance(x, torch.Tensor)] + any_non_leaves = any(x.grad_fn is not None for x in tensor_args) + if all(x is None for x in orig_grad) and any_non_leaves: with assert_raises_regex_fn(RuntimeError, 'does not require grad and does not have a grad_fn'): call_forwards_backwards(compiled_f, args) return From c806da52c91de8e4d7595e936c7a75e0b5c9d090 Mon Sep 17 00:00:00 2001 From: David Berard Date: Thu, 5 Sep 2024 16:39:07 -0700 Subject: [PATCH 0549/1018] [inductor][triton] mark workspace args as mutated (#134648) SplitScan makes use of a workspace arg that needs to be zeroed before it is used - then, it is used to communicate between thread blocks during the triton kernel implementation. It is mutated during during the execution of the kernel, so it should be marked as such. Before this PR, it is not marked as mutated; AFAIK this is fine during normal execution, but during autotuning it causes problems. The workspace starts off zeroed (as expected), but during autotuning the kernel will be executed multiple times and the workspace does not get re-set between executions, resulting in incorrect data. If the data is used for indexing, then you can fail device-side asserts (and the results after the initial run (with autotuning) could be wrong). The test added in this PR repros the issue when the fix is removed. When we mark the arg as mutated, then the arg gets cloned before autotuning, so that the arg passed to the kernel during autotuning will always be zeroed as expected. https://github.com/pytorch/pytorch/blob/804852c1f998e6c7c58a5c7cbcb5e9063d02d54a/torch/_inductor/runtime/triton_heuristics.py#L685-L689 Pull Request resolved: https://github.com/pytorch/pytorch/pull/134648 Approved by: https://github.com/peterbell10, https://github.com/jansel --- test/inductor/test_torchinductor.py | 15 +++++++++++++++ torch/_inductor/codegen/triton.py | 14 ++++++++++++++ torch/_inductor/runtime/triton_heuristics.py | 3 ++- 3 files changed, 31 insertions(+), 1 deletion(-) diff --git a/test/inductor/test_torchinductor.py b/test/inductor/test_torchinductor.py index 2cf1253215ad0f..5bc9a7050686d2 100644 --- a/test/inductor/test_torchinductor.py +++ b/test/inductor/test_torchinductor.py @@ -1884,6 +1884,21 @@ def fn(a, b): b = make_tensor(10, 3, 352, 352, low=0, dtype=torch.float64, device=self.device) self.common(fn, (a, b), rtol=1e-4, atol=1e-5, check_lowp=False) + @config.patch(max_autotune_pointwise=True) + def test_split_cumsum_index(self): + # Split scan uses a workspace that needs to be zeroed before use. + # data[index] does indirect indexing that should catch issues if the + # workspace is not zeroed. + def fn(lengths, data): + offsets = torch.cumsum(lengths, 0) + return data[offsets] + + lengths = torch.full((2**14,), 2**2, dtype=torch.int64, device=self.device) + lengths[-2] = 3 + lengths[-1] = 3 + data = make_tensor((2**16,), dtype=torch.float32, device=self.device) + self.common(fn, (lengths, data)) + def test_split_cumprod(self): def fn(a): return torch.cumprod(a, -1) diff --git a/torch/_inductor/codegen/triton.py b/torch/_inductor/codegen/triton.py index 65dd3608d2a841..2596f3d59793e9 100644 --- a/torch/_inductor/codegen/triton.py +++ b/torch/_inductor/codegen/triton.py @@ -2624,6 +2624,20 @@ def codegen_kernel(self, name=None): mutated_args.add(self.args.inplace_buffers[mutation].inner_name) if mutation in self.args.output_buffers: mutated_args.add(self.args.output_buffers[mutation]) + + # workspace arguments are mutated, but are not marked as mutations in self.mutations + # because their buffers are added during codegen, and aren't tracked during + # lowering/scheduling. So we add them as mutated_args explicitly below. + # + # In the logic below, we only mark the workspaces a mutated if they are marked with + # zero_fill: that's because, if we don't expect the buffer to be pre-filled with + # zeros, then, although we still mutate the data, we don't care about those + # mutations because we don't make any assumptions about the contents of the + # workspace buffer. + for argname, arg in zip(argdefs, signature): + if isinstance(arg, WorkspaceArg) and arg.zero_fill: + mutated_args.add(argname) + mutated_args = sorted(mutated_args) triton_meta_signature = signature_to_meta( diff --git a/torch/_inductor/runtime/triton_heuristics.py b/torch/_inductor/runtime/triton_heuristics.py index c5a6bd69fd7263..a3b8e611f3852b 100644 --- a/torch/_inductor/runtime/triton_heuristics.py +++ b/torch/_inductor/runtime/triton_heuristics.py @@ -172,7 +172,7 @@ def __init__( triton_meta, # passed directly to triton configs, save_cache_hook, - mutated_arg_names, + mutated_arg_names: List[str], # see [Note: clone mutated buffers] heuristic_type, size_hints=None, inductor_meta=None, # metadata not relevant to triton @@ -677,6 +677,7 @@ def kernel_call(): def clone_args(self, *args, **kwargs) -> Tuple[List[Any], Dict[str, Any]]: from ..compile_fx import clone_preserve_strides + # [Note: clone mutated buffers] # clone inplace buffers to avoid autotune contaminating them if # the kernel does in-place stores. avoid cloning other buffers because # it leads to increase memory use From 81fdbb935b0cc8e54f11192ddf4da471ec842196 Mon Sep 17 00:00:00 2001 From: Yiwen Shi Date: Fri, 6 Sep 2024 14:47:03 +0000 Subject: [PATCH 0550/1018] [torchelastic] Don't do signal handling when off the main thread (#135088) Summary: In multiprocessing, signal handling is not possible if the thread is not the main thread. This resulted in the following error: > "ValueError('signal only works in main thread of the main interpreter')" To address this issue, the diff checks whether the thread is the main thread and, if not, skips signal handling. Test Plan: Before this change, MAST job failed: https://fburl.com/mlhub/iq2m10v8 With this change, MAST job succeeded: https://fburl.com/mlhub/q6kb8343 Differential Revision: D62166943 Pull Request resolved: https://github.com/pytorch/pytorch/pull/135088 Approved by: https://github.com/d4l3k --- .../elastic/multiprocessing/api_test.py | 4 ++++ .../distributed/elastic/multiprocessing/api.py | 17 ++++++++++++----- 2 files changed, 16 insertions(+), 5 deletions(-) diff --git a/test/distributed/elastic/multiprocessing/api_test.py b/test/distributed/elastic/multiprocessing/api_test.py index 9f55785810ca86..98ff8f1a309749 100644 --- a/test/distributed/elastic/multiprocessing/api_test.py +++ b/test/distributed/elastic/multiprocessing/api_test.py @@ -6,6 +6,7 @@ # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +import asyncio import ctypes import multiprocessing import os @@ -362,6 +363,9 @@ def test_pcontext_wait(self): self.assertTrue(pc._stderr_tail.stopped()) self.assertTrue(pc._stdout_tail.stopped()) + def test_pcontext_wait_on_a_child_thread(self): + asyncio.run(asyncio.to_thread(self.test_pcontext_wait)) + def test_multiprocess_context_close(self): pc = start_processes( name="sleep", diff --git a/torch/distributed/elastic/multiprocessing/api.py b/torch/distributed/elastic/multiprocessing/api.py index 0b3a9a2ce29087..5befbffe49573d 100644 --- a/torch/distributed/elastic/multiprocessing/api.py +++ b/torch/distributed/elastic/multiprocessing/api.py @@ -16,6 +16,7 @@ import subprocess import sys import tempfile +import threading import time from abc import ABC, abstractmethod from contextlib import nullcontext @@ -470,11 +471,17 @@ def __init__( def start(self) -> None: """Start processes using parameters defined in the constructor.""" - signal.signal(signal.SIGTERM, _terminate_process_handler) - signal.signal(signal.SIGINT, _terminate_process_handler) - if not IS_WINDOWS: - signal.signal(signal.SIGHUP, _terminate_process_handler) - signal.signal(signal.SIGQUIT, _terminate_process_handler) + if threading.current_thread() is threading.main_thread(): + signal.signal(signal.SIGTERM, _terminate_process_handler) + signal.signal(signal.SIGINT, _terminate_process_handler) + if not IS_WINDOWS: + signal.signal(signal.SIGHUP, _terminate_process_handler) + signal.signal(signal.SIGQUIT, _terminate_process_handler) + else: + logger.warning( + "Failed to register signal handlers since torchelastic is running on a child thread. " + "This could lead to orphaned worker processes if the torchrun is terminated." + ) self._start() self._stdout_tail.start() self._stderr_tail.start() From bc00fa125c2a608060a9e6be3f5a67ac23ac78df Mon Sep 17 00:00:00 2001 From: rzou Date: Thu, 5 Sep 2024 10:41:29 -0700 Subject: [PATCH 0551/1018] Add Inductor config for default stride behavior (#135238) By default, Inductor is allowed to manipulate the layout (strides+storage offset) of input tensors to custom operators. We want to change it so that the default is that Inductor should respect the stride order of input tensors to custom operators. This PR adds a config to toggle the behavior, in the next PR up we'll change the default. We also make the following changes: - We add a new operator Tag (flexible_layout), which means that inductor is allowed to manipulate the layout. When we flip the default, users can specify they want the old behavior by using this tag. This is a reland of https://github.com/pytorch/pytorch/pull/126986, which was previously reverted due to silent incorrectness. We've since fixed the silent incorrectness (https://github.com/pytorch/pytorch/pull/133639) Test Plan: - new test Pull Request resolved: https://github.com/pytorch/pytorch/pull/135238 Approved by: https://github.com/albanD --- aten/src/ATen/native/tags.yaml | 9 +++++++++ test/test_custom_ops.py | 20 ++++++++++++++++++++ torch/_inductor/config.py | 7 +++++++ torch/_inductor/lowering.py | 29 ++++++++++++++++++++++++----- 4 files changed, 60 insertions(+), 5 deletions(-) diff --git a/aten/src/ATen/native/tags.yaml b/aten/src/ATen/native/tags.yaml index c31721729036ff..3544a3cf0b16c6 100644 --- a/aten/src/ATen/native/tags.yaml +++ b/aten/src/ATen/native/tags.yaml @@ -46,6 +46,15 @@ desc: | This tag indicates that the operator should be passed Tensors following the same stride permutation as observed in eager when compiled in inductor. + Only one of {needs_fixed_stride_order, flexible_layout} can apply; if + multiple are assigned then we assume the most restrictive one. +- tag: flexible_layout + desc: | + This tag indicates that the custom operator can accept inputs with varying + strides/storage_offset and that when compiled, Inductor is allowed to change + the strides/storage_offset of inputs to the custom operator. + Only one of {needs_fixed_stride_order, flexible_layout} can apply; if + multiple are assigned then we assume the most restrictive one. # NOTE [Core ATen Ops] - tag: core diff --git a/test/test_custom_ops.py b/test/test_custom_ops.py index f2a51a03476b4d..05afe8fec38a38 100644 --- a/test/test_custom_ops.py +++ b/test/test_custom_ops.py @@ -3428,6 +3428,26 @@ def vmap(info, in_dims, w, x=2, *, y=3, z): self.assertTrue(called) self.assertEqual(result, w * 2 * 3 * 42) + def test_layout_constraint_tags(self): + needs_fixed_stride_order = torch._C.Tag.needs_fixed_stride_order + flexible_layout = torch._C.Tag.flexible_layout + # (tags, the result of the tag inference) + tests = [ + ({needs_fixed_stride_order}, needs_fixed_stride_order), + ({flexible_layout}, flexible_layout), + # If no tags are provided, then the following is the default + (set(), flexible_layout), + # If multiple tags are provided, then we use the most constrained tag. + ({flexible_layout, needs_fixed_stride_order}, needs_fixed_stride_order), + ] + from torch._inductor.lowering import get_layout_constraint_tag + + for tags, expected in tests: + with torch.library._scoped_library("mylib", "FRAGMENT") as m: + m.define("foobar(Tensor x) -> Tensor", tags=tags) + result = get_layout_constraint_tag(torch.ops.mylib.foobar.default) + self.assertEqual(result, expected) + @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug") def test_library_register_vmap(self): for mode in ["function", "qualname", "opoverload", "c_opdef"]: diff --git a/torch/_inductor/config.py b/torch/_inductor/config.py index 684728ea48908a..db778544d88f0e 100644 --- a/torch/_inductor/config.py +++ b/torch/_inductor/config.py @@ -65,6 +65,13 @@ def autotune_remote_cache_default() -> Optional[bool]: # sleep in inductor for testing sleep_sec_TESTING_ONLY: Optional[int] = None +# The default layout constraint for custom operators. +# This must be the name of one of the layout constraint tags +# (that is, one of {"needs_fixed_stride_order", "flexible_layout"}), +# If the custom op does not have a layout constraint tag already +# then we assume the following applies. +custom_op_default_layout_constraint = "flexible_layout" + # use cpp wrapper instead of python wrapper cpp_wrapper = os.environ.get("TORCHINDUCTOR_CPP_WRAPPER", "0") == "1" diff --git a/torch/_inductor/lowering.py b/torch/_inductor/lowering.py index 3219e3a427ee66..84120d9f9933f0 100644 --- a/torch/_inductor/lowering.py +++ b/torch/_inductor/lowering.py @@ -101,11 +101,30 @@ def maybe_layout_constraints(fn: Callable[..., Any]) -> Optional[Callable[..., A _maybe_layout_constraints[fn] = None return None # We lazily register tag-based layout constraints. - if torch._C.Tag.needs_fixed_stride_order in fn.tags: - _maybe_layout_constraints[fn] = constrain_to_fx_strides - return _maybe_layout_constraints[fn] - _maybe_layout_constraints[fn] = None - return None + + def handle_layout_constraint_tag(tag): + if tag is torch._C.Tag.needs_fixed_stride_order: + _maybe_layout_constraints[fn] = constrain_to_fx_strides + return _maybe_layout_constraints[fn] + elif tag is torch._C.Tag.flexible_layout: + _maybe_layout_constraints[fn] = None + return None + else: + raise AssertionError(f"Unknown layout constraint tag: {tag}") + + tag = get_layout_constraint_tag(fn) + return handle_layout_constraint_tag(tag) + + +def get_layout_constraint_tag(fn): + tags_by_priority = [ + torch._C.Tag.needs_fixed_stride_order, + torch._C.Tag.flexible_layout, + ] + for tag in tags_by_priority: + if tag in fn.tags: + return tag + return getattr(torch._C.Tag, config.custom_op_default_layout_constraint) def assert_nyi(cond, msg): From 0a073b1a432a5eb794c97c452f4aa6859715024b Mon Sep 17 00:00:00 2001 From: Avik Chaudhuri Date: Fri, 6 Sep 2024 15:12:40 +0000 Subject: [PATCH 0552/1018] error on exporting ScriptModule (#135302) Test Plan: added test Differential Revision: D62279179 Pull Request resolved: https://github.com/pytorch/pytorch/pull/135302 Approved by: https://github.com/yushangdi --- test/export/test_export.py | 19 +++++++++++++++++++ torch/export/__init__.py | 12 ++++++++++++ 2 files changed, 31 insertions(+) diff --git a/test/export/test_export.py b/test/export/test_export.py index 0d05b57b444e4e..8df3dc7be1ee24 100644 --- a/test/export/test_export.py +++ b/test/export/test_export.py @@ -892,6 +892,25 @@ def forward(self, x): torch.allclose(ep.module()(torch.zeros(2, 3)), torch.ones(2, 3) * 21) ) + def test_export_script_module(self): + class Foo(torch.nn.Module): + def forward(self, rv: torch.Tensor, t: torch.Tensor): + i = t.item() + return rv + i + + foo = Foo() + foo_script = torch.jit.script(foo) + inp = (torch.zeros(3, 4), torch.tensor(7)) + + with self.assertRaisesRegex( + ValueError, "Exporting a ScriptModule is not supported" + ): + export(foo_script, inp) + + from torch._export.converter import TS2EPConverter + + TS2EPConverter(foo_script, inp).convert() + def test_torch_fn(self): class M1(torch.nn.Module): def __init__(self) -> None: diff --git a/torch/export/__init__.py b/torch/export/__init__.py index 2cbbc472968a83..367b3f127a16bb 100644 --- a/torch/export/__init__.py +++ b/torch/export/__init__.py @@ -145,6 +145,12 @@ def export_for_training( raise ValueError( f"Expected `mod` to be an instance of `torch.nn.Module`, got {type(mod)}." ) + if isinstance(mod, torch.jit.ScriptModule): + raise ValueError( + "Exporting a ScriptModule is not supported. " + "Maybe try converting your ScriptModule to an ExportedProgram " + "using `TS2EPConverter(mod, args, kwargs).convert()` instead." + ) return _export_for_training( mod, args, @@ -255,6 +261,12 @@ def export( raise ValueError( f"Expected `mod` to be an instance of `torch.nn.Module`, got {type(mod)}." ) + if isinstance(mod, torch.jit.ScriptModule): + raise ValueError( + "Exporting a ScriptModule is not supported. " + "Maybe try converting your ScriptModule to an ExportedProgram " + "using `TS2EPConverter(mod, args, kwargs).convert()` instead." + ) return _export( mod, args, From 2269e782093af889a28869bfe4af1e192721b2bf Mon Sep 17 00:00:00 2001 From: Rachel Guo Date: Fri, 6 Sep 2024 15:59:32 +0000 Subject: [PATCH 0553/1018] [AOTI][Tooling][6/n] Fix long dtype input tensors calling `mean()` in `aoti_torch_print_tensor_handle` (#135072) Differential Revision: D61635232 Pull Request resolved: https://github.com/pytorch/pytorch/pull/135072 Approved by: https://github.com/hl475, https://github.com/ColinPeppler --- test/inductor/test_aot_inductor.py | 4 ---- .../csrc/inductor/aoti_torch/shim_common.cpp | 24 +++++++++++++++---- 2 files changed, 20 insertions(+), 8 deletions(-) diff --git a/test/inductor/test_aot_inductor.py b/test/inductor/test_aot_inductor.py index a39ac7347424ff..71e0a96dae17f0 100644 --- a/test/inductor/test_aot_inductor.py +++ b/test/inductor/test_aot_inductor.py @@ -3400,17 +3400,13 @@ def forward(self, x, y): ("add_kernel_0", 3), ] - # test the default debug printing codegen with config.patch({"aot_inductor.debug_intermediate_value_printer": "2"}): result, code = run_and_get_cpp_code( AOTIRunnerUtil.compile, Model(), example_inputs ) - # check the c shim print_tensor_handle call is triggered by the config and injected the cpp output code as expected self.assertEqual("aoti_torch_print_tensor_handle" in code, True) - # check the codegen for debug printing around the actual kernel call is expected - for kernel_call, count in kernel_calls: FileCheck().check_count( f"before_launch - {kernel_call}", diff --git a/torch/csrc/inductor/aoti_torch/shim_common.cpp b/torch/csrc/inductor/aoti_torch/shim_common.cpp index 5763470449e664..4cf5dd3ae228a6 100644 --- a/torch/csrc/inductor/aoti_torch/shim_common.cpp +++ b/torch/csrc/inductor/aoti_torch/shim_common.cpp @@ -1051,15 +1051,31 @@ AOTI_TORCH_EXPORT void aoti_torch_print_tensor_handle( // Print summary stats of the tensor std::cout << "Number of elements: " << numel << std::endl; + std::cout << "Dtype: " << t->dtype() << std::endl; if (numel > 0) { - std::cout << "Mean value: " << t->mean().item() << std::endl; - std::cout << "Min value: " << t->min().item() << std::endl; - std::cout << "Max value: " << t->max().item() << std::endl; + // torch/aten `mean()` function only supports float and complex dtypes + // See: + // https://github.com/pytorch/pytorch/blob/a0e062c6f1a03ec93e87413e42c4d0b336518131/aten/src/ATen/native/ReduceOps.cpp#L304-L309 + auto mean_value = [t](at::ScalarType dtype) { + return t->to(dtype).mean().item(); + }; + bool is_complex_type = + at::isComplexType(at::typeMetaToScalarType(t->dtype())); + at::ScalarType float_dtype = + is_complex_type ? at::kComplexFloat : at::kFloat; + std::cout << "Mean value: " << mean_value(float_dtype) << std::endl; + if (!is_complex_type) { + // "min_all_cuda" function is not implemented for 'ComplexFloat' type. + // (similar for max) Skip printing min/max value for complex type tensors + // here If encountered complex dtypes (rare occasions), suggest to print + // out the whole value of the tensor. + std::cout << "Min value: " << t->min().item() << std::endl; + std::cout << "Max value: " << t->max().item() << std::endl; + } } std::cout << "Device: " << t->device() << std::endl; std::cout << "Size: " << t->sizes() << std::endl; std::cout << "Stride: " << t->strides() << std::endl; - std::cout << "Dtype: " << t->dtype() << std::endl; std::cout << "Layout: " << t->layout() << std::endl; std::cout << "Is contiguous: " << t->is_contiguous() << std::endl; std::cout << "Requires grad: " << t->requires_grad() << std::endl; From e15a7f308ddd7fd963bb6ebe3119dbd73d769040 Mon Sep 17 00:00:00 2001 From: Alfredo Tupone Date: Fri, 6 Sep 2024 16:16:51 +0000 Subject: [PATCH 0554/1018] Porting to GCC 15 (#135188) uint8_t is found on cstdint header Pull Request resolved: https://github.com/pytorch/pytorch/pull/135188 Approved by: https://github.com/Skylion007 --- caffe2/utils/string_utils.cc | 1 + 1 file changed, 1 insertion(+) diff --git a/caffe2/utils/string_utils.cc b/caffe2/utils/string_utils.cc index 640cac70edbbad..ba763738f7bb04 100644 --- a/caffe2/utils/string_utils.cc +++ b/caffe2/utils/string_utils.cc @@ -3,6 +3,7 @@ #include #include #include +#include namespace caffe2 { From 59efb4f437954262f8036b439920ac7ea0a1b64e Mon Sep 17 00:00:00 2001 From: yanbing-j Date: Fri, 6 Sep 2024 16:40:26 +0000 Subject: [PATCH 0555/1018] Update submodule ideep to include aarch64 change (#134897) This PR is per ARM request, which is in https://github.com/intel/ideep/issues/334. Context for the request is: Arm team has upstreamed the dynamic quantization changes, all the PRs were merged (torch, ideep, oneDNN), but without this ideep submodule update, the feature will not work. The change is isolated to only matmul operator and quantization path alone. Pull Request resolved: https://github.com/pytorch/pytorch/pull/134897 Approved by: https://github.com/jgong5, https://github.com/atalman, https://github.com/snadampal --- third_party/ideep | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/third_party/ideep b/third_party/ideep index 383e9238c1c118..41d636c2bbcea6 160000 --- a/third_party/ideep +++ b/third_party/ideep @@ -1 +1 @@ -Subproject commit 383e9238c1c118c1ff72662f9ed7c8610150325d +Subproject commit 41d636c2bbcea6bff0faf97cdb65a48cdde987af From d653b4206ddbb6f9901e306a879aac3328f437da Mon Sep 17 00:00:00 2001 From: Riley Dulin Date: Fri, 6 Sep 2024 16:41:58 +0000 Subject: [PATCH 0556/1018] [torch][fx] Set maximum warning count during fx.Graph.lint (#135069) Summary: resnet152 spent about 15 minutes writing warning messages in _unlift during `to_executorch` because they're all written to unbuffered stderr by the `warnings` module. These warnings are almost always about get_attr nodes referencing a non-existent name: ```lang=py warnings.warn(f'Node {node} target {node.target} {atom} of {seen_qualname} does ' 'not reference an nn.Module, nn.Parameter, or buffer, which is ' 'what \'get_attr\' Nodes typically target' ) ``` I'm not aware of a way to configure the warnings module to write this out at most once, so I'm just going to disable the lint for now. Test Plan: Re-ran resnet152 with Executorch and the XNNPackBackend, it is much faster now Differential Revision: D62156090 Pull Request resolved: https://github.com/pytorch/pytorch/pull/135069 Approved by: https://github.com/yushangdi --- torch/fx/graph.py | 18 +++++++++++++++--- 1 file changed, 15 insertions(+), 3 deletions(-) diff --git a/torch/fx/graph.py b/torch/fx/graph.py index cb3fbeb33d83b9..1601eee07f2717 100644 --- a/torch/fx/graph.py +++ b/torch/fx/graph.py @@ -1551,6 +1551,8 @@ def check_arg(arg : Node, n : Optional[Node] = None) -> None: # Check targets are legit if self.owning_module: + num_warnings = 0 + MAX_WARNINGS = 5 for node in self.nodes: if node.op == 'call_function': if not callable(node.target): @@ -1577,11 +1579,21 @@ def check_arg(arg : Node, n : Optional[Node] = None) -> None: and not isinstance(new_m_itr, torch.nn.Module) and not isinstance(new_m_itr, torch.nn.Parameter) and atom not in m_itr._buffers): - warnings.warn(f'Node {node} target {node.target} {atom} of {seen_qualname} does ' - 'not reference an nn.Module, nn.Parameter, or buffer, which is ' - 'what \'get_attr\' Nodes typically target') + if num_warnings < MAX_WARNINGS: + # Don't emit this warning too frequently, + # for very large graphs this can become very expensive + # from a performance perspective. + warnings.warn(f'Node {node} target {node.target} {atom} of {seen_qualname} does ' + 'not reference an nn.Module, nn.Parameter, or buffer, which is ' + 'what \'get_attr\' Nodes typically target') + num_warnings += 1 else: m_itr = new_m_itr + if num_warnings > MAX_WARNINGS: + warnings.warn( + f'Additional {num_warnings - MAX_WARNINGS} warnings ' + 'suppressed about get_attr references' + ) @compatibility(is_backward_compatible=True) def eliminate_dead_code(self, is_impure_node: Optional[Callable[[Node], bool]] = None): From e94eb575c557867e12709314c0b84a78958931c7 Mon Sep 17 00:00:00 2001 From: Zhengxu Chen Date: Fri, 6 Sep 2024 17:02:06 +0000 Subject: [PATCH 0557/1018] [export] Record the global torch version in serialization. (#135243) Summary: In general I think it will be useful to also record the global torch version in the EP, so that we can track them in the logging in addition to the schema version. Test Plan: CI Reviewed By: henryoier Differential Revision: D62252626 Pull Request resolved: https://github.com/pytorch/pytorch/pull/135243 Approved by: https://github.com/yushangdi --- torch/_export/serde/schema.py | 3 ++- torch/_export/serde/schema.yaml | 7 +++++-- torch/_export/serde/serialize.py | 2 ++ 3 files changed, 9 insertions(+), 3 deletions(-) diff --git a/torch/_export/serde/schema.py b/torch/_export/serde/schema.py index c12997e049c1de..ce102b39367ad0 100644 --- a/torch/_export/serde/schema.py +++ b/torch/_export/serde/schema.py @@ -8,7 +8,7 @@ from torch._export.serde.union import _Union # NOTE: Please update this value if any modifications are made to the schema -SCHEMA_VERSION = (7, 2) +SCHEMA_VERSION = (7, 3) TREESPEC_VERSION = 1 @@ -378,3 +378,4 @@ class ExportedProgram: range_constraints: Dict[str, RangeConstraint] schema_version: SchemaVersion verifiers: List[str] = field(default_factory=list) + torch_version: str = "<=2.4" diff --git a/torch/_export/serde/schema.yaml b/torch/_export/serde/schema.yaml index f8eda2b5c3b02c..25a9a295ad0b97 100644 --- a/torch/_export/serde/schema.yaml +++ b/torch/_export/serde/schema.yaml @@ -1,5 +1,5 @@ # @generated by update_schema.py -# checksum<<8c1eb1b5b62b21e3725448a11d3a807cfe006e7558870a2c1ccc455d3def71c4>> +# checksum<<923abf371a1f8802cacb037d409d28273867777a98f6542fba28616c2b92b639>> Argument: kind: union fields: @@ -105,6 +105,9 @@ ExportedProgram: verifiers: type: List[str] default: '[]' + torch_version: + type: str + default: <=2.4 GradientToParameterSpec: kind: struct fields: @@ -430,5 +433,5 @@ UserOutputSpec: type: Argument SCHEMA_VERSION: - 7 -- 2 +- 3 TREESPEC_VERSION: 1 diff --git a/torch/_export/serde/serialize.py b/torch/_export/serde/serialize.py index 3b8803b776f6f9..44153ccc78eb42 100644 --- a/torch/_export/serde/serialize.py +++ b/torch/_export/serde/serialize.py @@ -1394,6 +1394,7 @@ def serialize(self, exported_program: ep.ExportedProgram) -> _SerializedProgram: minor=SCHEMA_VERSION[1], ), verifiers=[v.dialect for v in exported_program.verifiers], + torch_version=torch.__version__, ) # Test canonical form is well defined. @@ -2905,6 +2906,7 @@ def replace_output(out): range_constraints=range_constraints, schema_version=ep.schema_version, verifiers=ep.verifiers, + torch_version=ep.torch_version, ) From 516bc00259232c7181d5433a6dfcde4dd5413e46 Mon Sep 17 00:00:00 2001 From: Henry Tsang Date: Fri, 6 Sep 2024 17:25:05 +0000 Subject: [PATCH 0558/1018] [aoti][easy] remove breakpoint() in wrapper.py (#134807) Differential Revision: D61687146 Remove an unintended breakpoint in code. Pull Request resolved: https://github.com/pytorch/pytorch/pull/134807 Approved by: https://github.com/YUNQIUGUO --- torch/_inductor/codegen/wrapper.py | 1 - 1 file changed, 1 deletion(-) diff --git a/torch/_inductor/codegen/wrapper.py b/torch/_inductor/codegen/wrapper.py index 30b073ca5f7015..ddc422f4c294cb 100644 --- a/torch/_inductor/codegen/wrapper.py +++ b/torch/_inductor/codegen/wrapper.py @@ -1627,7 +1627,6 @@ def generate_example_arg_value(self, arg, arg_type, raw_arg=None, index=None): elif isinstance(arg, list): return f"[{', '.join(self.generate_example_arg_value(a, type(a)) for a in arg)}]" else: - breakpoint() raise NotImplementedError(f"Unsupported type {type(arg)}") def _grid_dim_str(self, grid_per_dim): From a6203513abb8a8cd46ec705a68997dcd74f41d30 Mon Sep 17 00:00:00 2001 From: Pian Pawakapan Date: Fri, 6 Sep 2024 17:31:25 +0000 Subject: [PATCH 0559/1018] remove _check call on item() for torch.istft (#135234) Fixes #135014 Pull Request resolved: https://github.com/pytorch/pytorch/pull/135234 Approved by: https://github.com/tugsbayasgalan --- test/export/test_export.py | 20 ++++++++++++++++++++ torch/_refs/__init__.py | 6 ------ 2 files changed, 20 insertions(+), 6 deletions(-) diff --git a/test/export/test_export.py b/test/export/test_export.py index 8df3dc7be1ee24..3081baaca35221 100644 --- a/test/export/test_export.py +++ b/test/export/test_export.py @@ -6995,6 +6995,26 @@ def forward(self, x): if node.op == "call_function": self.assertTrue(False) + @testing.expectedFailureTrainingIRToRunDecomp # T200904004 + @testing.expectedFailureTrainingIRToRunDecompNonStrict + def test_istft_op(self): + class istft_class(torch.nn.Module): + def forward(self, spec): + window = torch.hann_window(1024).type(torch.FloatTensor) + return torch.istft( + spec, + n_fft=1024, + hop_length=512, + window=window, + length=144000, + ) + + model = istft_class() + real_part = torch.randn(1, 513, 282, dtype=torch.float32) + imaginary_part = torch.randn(1, 513, 282, dtype=torch.float32) + spec = torch.complex(real_part, imaginary_part) + export(model, (spec,)) + def test_automatic_dynamic_shapes_simple_equality(self): # The next 3 test cases tests for automatic dynamic shapes specs, verifying that automatic dynamism # leads to replacement symbols being set for equalities, and inferred relationships being checked diff --git a/torch/_refs/__init__.py b/torch/_refs/__init__.py index 673f5f1dea8295..15b18fefacb777 100644 --- a/torch/_refs/__init__.py +++ b/torch/_refs/__init__.py @@ -3549,12 +3549,6 @@ def istft( y = y.narrow(dim=1, start=start, length=length) window_envelop = window_envelop.narrow(dim=1, start=start, length=length) - window_envelop_lowest = window_envelop.abs().min().lt(1e-11) - torch._check( - not window_envelop_lowest.item(), - lambda: "window overlap add min less than 1e-11", - ) - y = y / window_envelop if original_ndim == 2: y = y.squeeze(0) From b11d6e11fd92e9028b0ead645c1902d1a016a759 Mon Sep 17 00:00:00 2001 From: Tristan Rice Date: Fri, 6 Sep 2024 17:55:43 +0000 Subject: [PATCH 0560/1018] [elastic] support local_addr across all rendezvous impls (#135262) Summary: There was a regression introduced in https://github.com/pytorch/pytorch/pull/125743 that made `local_addr` no longer used. This fixes that by passing `local_addr` to `RendezvousStoreInfo.build` everywhere it's used. This also fixes a number of tests allowing them to be run in parallel which hugely sped up the testing cycle as this change touches many different rendezvous implementations. This required a few fixes in unrelated tests. Test Plan: Added tests for the common rendezvous implementations that `local_addr` to prevent future regressions. ``` buck2 test @//mode/dev-nosan fbcode//caffe2/test/distributed/elastic/... fbcode//caffe2/torch/distributed/elastic/... -- --stress-runs 3 ``` To vet the parallelism changes I also ran with 3 stress runs each to identify flakiness caused by parallelism. Differential Revision: D62256407 Pull Request resolved: https://github.com/pytorch/pytorch/pull/135262 Approved by: https://github.com/fduwjj, https://github.com/wz337 --- .../c10d_rendezvous_backend_test.py | 18 ++++++++++------ .../rendezvous/dynamic_rendezvous_test.py | 21 +++++++++++++++++-- .../timer/file_based_local_timer_test.py | 5 +++-- torch/distributed/elastic/rendezvous/api.py | 11 ++++++++-- .../rendezvous/c10d_rendezvous_backend.py | 5 ++++- .../elastic/rendezvous/dynamic_rendezvous.py | 8 +++++-- .../elastic/rendezvous/etcd_rendezvous.py | 18 +++++++++++++--- .../elastic/timer/file_based_local_timer.py | 3 ++- 8 files changed, 70 insertions(+), 19 deletions(-) diff --git a/test/distributed/elastic/rendezvous/c10d_rendezvous_backend_test.py b/test/distributed/elastic/rendezvous/c10d_rendezvous_backend_test.py index 2175e8b17e9522..89329c380f391f 100644 --- a/test/distributed/elastic/rendezvous/c10d_rendezvous_backend_test.py +++ b/test/distributed/elastic/rendezvous/c10d_rendezvous_backend_test.py @@ -25,6 +25,7 @@ C10dRendezvousBackend, create_backend, ) +from torch.distributed.elastic.utils.distributed import get_free_port class TCPStoreBackendTest(TestCase, RendezvousBackendTestMixin): @@ -69,9 +70,11 @@ def setUp(self) -> None: # For testing, the default parameters used are for tcp. If a test # uses parameters for file store, we set the self._params to # self._params_filestore. + + port = get_free_port() self._params = RendezvousParameters( backend="dummy_backend", - endpoint="localhost:29300", + endpoint=f"localhost:{port}", run_id="dummy_run_id", min_nodes=1, max_nodes=1, @@ -95,7 +98,7 @@ def setUp(self) -> None: self._expected_temp_dir = tempfile.gettempdir() self._expected_endpoint_host = "localhost" - self._expected_endpoint_port = 29300 + self._expected_endpoint_port = port self._expected_store_type = TCPStore self._expected_read_timeout = timedelta(seconds=10) @@ -173,11 +176,14 @@ def test_create_backend_returns_backend_if_is_host_is_not_specified_and_store_al def test_create_backend_returns_backend_if_endpoint_port_is_not_specified( self, ) -> None: - self._params.endpoint = self._expected_endpoint_host - - self._expected_endpoint_port = 29400 + # patch default port and pass endpoint with no port specified + with mock.patch( + "torch.distributed.elastic.rendezvous.c10d_rendezvous_backend.DEFAULT_PORT", + self._expected_endpoint_port, + ): + self._params.endpoint = self._expected_endpoint_host - self._assert_create_backend_returns_backend() + self._assert_create_backend_returns_backend() def test_create_backend_returns_backend_if_endpoint_file_is_not_specified( self, diff --git a/test/distributed/elastic/rendezvous/dynamic_rendezvous_test.py b/test/distributed/elastic/rendezvous/dynamic_rendezvous_test.py index 67cc97267a28ed..90ea76a48fadc7 100644 --- a/test/distributed/elastic/rendezvous/dynamic_rendezvous_test.py +++ b/test/distributed/elastic/rendezvous/dynamic_rendezvous_test.py @@ -1597,6 +1597,23 @@ def test_create_handler_records_and_raises_exceptions(self, record_mock) -> None create_handler(self._store, self._backend, self._params) record_mock.assert_called_once() + def test_create_handler_rdzv_local_addr(self) -> None: + params = RendezvousParameters( + backend=self._backend.name, + endpoint="dummy_endpoint", + run_id="dummy_run_id", + min_nodes=1, + max_nodes=1, + join_timeout="50", + last_call_timeout="60", + close_timeout="70", + local_addr="127.0.0.2", + ) + store = HashStore() + handler = create_handler(store, self._backend, params) + rdzv_info = handler.next_rendezvous() + self.assertEqual(rdzv_info.bootstrap_store_info.master_addr, "127.0.0.2") + def _ignore_exception(exception_type: Exception, fn: Callable): try: @@ -1656,7 +1673,7 @@ def _create_handler(self, **kwargs) -> DynamicRendezvousHandler: "min_nodes": 2, "max_nodes": 2, "join_timeout": "5", - "local_addr": f"address_{len(self._handlers)}", + "local_addr": f"127.0.0.{len(self._handlers)}", } params.update(**kwargs) @@ -1714,7 +1731,7 @@ def test_redundancy_list(self) -> None: state_and_token = self._backend.get_state() state = pickle.loads(state_and_token[0]) addresses = [node.addr for node in state.redundancy_list] - self.assertListEqual(addresses, ["address_2"]) + self.assertListEqual(addresses, ["127.0.0.2"]) def test_redundancy_transition_to_wait_list_then_join_rendezvous(self) -> None: handler1 = self._create_handler( diff --git a/test/distributed/elastic/timer/file_based_local_timer_test.py b/test/distributed/elastic/timer/file_based_local_timer_test.py index 490e4a9ce37a7c..c06f3520bac853 100644 --- a/test/distributed/elastic/timer/file_based_local_timer_test.py +++ b/test/distributed/elastic/timer/file_based_local_timer_test.py @@ -6,6 +6,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. import multiprocessing as mp +import os import signal import time import unittest @@ -37,7 +38,7 @@ class FileTimerTest(TestCase): def setUp(self): super().setUp() self.max_interval = 0.01 - self.file_path = "/tmp/test_file_path_" + str(uuid.uuid4()) + self.file_path = f"/tmp/test_file_path_{os.getpid()}_{uuid.uuid4()}" self.server = timer.FileTimerServer( self.file_path, "test", self.max_interval ) @@ -204,7 +205,7 @@ def test_send_request_without_server(self): class FileTimerServerTest(TestCase): def setUp(self): super().setUp() - self.file_path = "/tmp/test_file_path_" + str(uuid.uuid4()) + self.file_path = f"/tmp/test_file_path_{os.getpid()}_{uuid.uuid4()}" self.max_interval = 0.01 self.server = timer.FileTimerServer( self.file_path, "test", self.max_interval diff --git a/torch/distributed/elastic/rendezvous/api.py b/torch/distributed/elastic/rendezvous/api.py index 9cde6758981abf..cbfc5532c76a6a 100644 --- a/torch/distributed/elastic/rendezvous/api.py +++ b/torch/distributed/elastic/rendezvous/api.py @@ -68,14 +68,21 @@ class RendezvousStoreInfo: master_port: int @staticmethod - def build(rank: int, store: Store) -> "RendezvousStoreInfo": + def build( + rank: int, store: Store, local_addr: Optional[str] + ) -> "RendezvousStoreInfo": """Factory method, finds unused new port on rank0 host and addr/port info with all ranks. If master_addr/master_port is knowns (useful when sharing existing tcp store server) use the constructor. + + Args: + rank: rank of the current node + store: store to use for rendezvous + local_addr: address of the current node, if not provided will be resolved from hostname """ # TODO swap to collectives comms API if rank == 0: - addr = socket.getfqdn() + addr = local_addr or socket.getfqdn() port = _get_free_port() store.set(RendezvousStoreInfo.MASTER_ADDR_KEY, addr.encode(encoding="UTF-8")) # type: ignore[arg-type] store.set(RendezvousStoreInfo.MASTER_PORT_KEY, str(port).encode(encoding="UTF-8")) # type: ignore[arg-type] diff --git a/torch/distributed/elastic/rendezvous/c10d_rendezvous_backend.py b/torch/distributed/elastic/rendezvous/c10d_rendezvous_backend.py index d49f4724a0ea76..427a53bc3276dd 100644 --- a/torch/distributed/elastic/rendezvous/c10d_rendezvous_backend.py +++ b/torch/distributed/elastic/rendezvous/c10d_rendezvous_backend.py @@ -28,6 +28,9 @@ logger = logging.getLogger(__name__) +# default port for the TCP store +DEFAULT_PORT = 29400 + class C10dRendezvousBackend(RendezvousBackend): """Represents a C10d-backed rendezvous backend. @@ -132,7 +135,7 @@ def _decode_state(self, base64_state: bytes) -> Optional[Tuple[bytes, Token]]: def _create_tcp_store(params: RendezvousParameters) -> TCPStore: - host, port = parse_rendezvous_endpoint(params.endpoint, default_port=29400) + host, port = parse_rendezvous_endpoint(params.endpoint, default_port=DEFAULT_PORT) cfg_is_host = params.get_as_bool("is_host") # If the user has explicitly specified whether our process should host the diff --git a/torch/distributed/elastic/rendezvous/dynamic_rendezvous.py b/torch/distributed/elastic/rendezvous/dynamic_rendezvous.py index a2d2011f8173ea..c8e294604501db 100644 --- a/torch/distributed/elastic/rendezvous/dynamic_rendezvous.py +++ b/torch/distributed/elastic/rendezvous/dynamic_rendezvous.py @@ -1178,7 +1178,9 @@ def next_rendezvous(self) -> RendezvousInfo: # opt-out option of TCP store sharing if os.getenv("TORCH_DISABLE_SHARE_RDZV_TCP_STORE", "0") == "1": - bootstrap_store_info = RendezvousStoreInfo.build(rank, store) + bootstrap_store_info = RendezvousStoreInfo.build( + rank, store, local_addr=self._this_node.addr + ) return RendezvousInfo( store, rank, @@ -1198,7 +1200,9 @@ def next_rendezvous(self) -> RendezvousInfo: else: # If the store is not type of TCPStore start TCPStore server, which requries # bootstrapping info across ranks - self._bootstrap_store_info = RendezvousStoreInfo.build(rank, store) + self._bootstrap_store_info = RendezvousStoreInfo.build( + rank, store, local_addr=self._this_node.addr + ) if rank == 0: self._shared_tcp_store_server = self._create_tcp_store_server( self._bootstrap_store_info diff --git a/torch/distributed/elastic/rendezvous/etcd_rendezvous.py b/torch/distributed/elastic/rendezvous/etcd_rendezvous.py index fe6170ede0159e..f0aa3c8d8887d8 100644 --- a/torch/distributed/elastic/rendezvous/etcd_rendezvous.py +++ b/torch/distributed/elastic/rendezvous/etcd_rendezvous.py @@ -149,8 +149,15 @@ class EtcdRendezvousHandler(RendezvousHandler): +--------------------------------------------+--------------------------+ """ - def __init__(self, rdzv_impl): + def __init__(self, rdzv_impl: "EtcdRendezvous", local_addr: Optional[str]): + """ + Args: + rdzv_impl: the implementation of the rendezvous + local_addr: the local address of the current node + """ + self._rdzv_impl = rdzv_impl + self._local_addr = local_addr def __del__(self): # TODO: look into using weakref here instead. @@ -165,7 +172,9 @@ def next_rendezvous(self): logger.info("Creating EtcdStore as the c10d::Store implementation") store = self._rdzv_impl.setup_kv_store(rdzv_version) - bootstrap_store_info = RendezvousStoreInfo.build(rank, store) + bootstrap_store_info = RendezvousStoreInfo.build( + rank, store, local_addr=self._local_addr + ) return RendezvousInfo(store, rank, world_size, bootstrap_store_info) def is_closed(self): @@ -1062,4 +1071,7 @@ def create_rdzv_handler(params: RendezvousParameters) -> RendezvousHandler: "last_call_timeout", _DEFAULT_LAST_CALL_TIMEOUT ), ) - return EtcdRendezvousHandler(rdzv_impl=rdzv) + return EtcdRendezvousHandler( + rdzv_impl=rdzv, + local_addr=params.local_addr, + ) diff --git a/torch/distributed/elastic/timer/file_based_local_timer.py b/torch/distributed/elastic/timer/file_based_local_timer.py index 74da756d58c99a..1c6d0d9a33064a 100644 --- a/torch/distributed/elastic/timer/file_based_local_timer.py +++ b/torch/distributed/elastic/timer/file_based_local_timer.py @@ -193,10 +193,11 @@ def __init__( def start(self) -> None: logger.info( - "Starting %s..." " max_interval=%s," " daemon=%s", + "Starting %s... max_interval=%s, daemon=%s, file_path=%s", type(self).__name__, self._max_interval, self._daemon, + self._file_path, ) self._watchdog_thread = threading.Thread( target=self._watchdog_loop, daemon=self._daemon From 97f71f508343576fd980d1e5913f031447c7b27b Mon Sep 17 00:00:00 2001 From: wdziurdz <99736820+wdziurdz@users.noreply.github.com> Date: Fri, 6 Sep 2024 18:19:52 +0000 Subject: [PATCH 0561/1018] Fix incorrect trace of post-accumulate grad hook on tensor with zero dims (#135226) Fix incorrect trace of post-accumulate grad hook on tensor with zero dimensions Fixes #135207 Pull Request resolved: https://github.com/pytorch/pytorch/pull/135226 Approved by: https://github.com/xmfan --- torch/_dynamo/compiled_autograd.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torch/_dynamo/compiled_autograd.py b/torch/_dynamo/compiled_autograd.py index e68581762372e3..e7c5d2414f6e2b 100644 --- a/torch/_dynamo/compiled_autograd.py +++ b/torch/_dynamo/compiled_autograd.py @@ -237,14 +237,14 @@ def post_acc_grad_hook(self, input, hook_id): assert isinstance(input, torch.Tensor) assert self.hooks_proxy is not None hook = self.hooks_proxy[hook_id] # type: ignore[index] - proxies = self.proxy_call_hook( + proxy = self.proxy_call_hook( hook, input, hook_type="post_acc_grad_hook", ) with disable_proxy_modes_tracing(): input = [maybe_clone(input)] - self.bind_tensors_to_proxies(input, proxies) + self.bind_tensors_to_proxies(input, [proxy]) return input # Note: [Compiled autograd and cudagraphs] From a18adfa2600f4af200802a1be92475804397a348 Mon Sep 17 00:00:00 2001 From: Oguz Ulgen Date: Wed, 4 Sep 2024 20:14:03 -0700 Subject: [PATCH 0562/1018] Run bypassed graph compile outside the except block to avoid chaining of exceptions (#135175) Fixes #135172 Pull Request resolved: https://github.com/pytorch/pytorch/pull/135175 Approved by: https://github.com/masnesral, https://github.com/ezyang --- torch/_inductor/codecache.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/torch/_inductor/codecache.py b/torch/_inductor/codecache.py index aaac1edc821397..59cc47ac06c94e 100644 --- a/torch/_inductor/codecache.py +++ b/torch/_inductor/codecache.py @@ -1365,10 +1365,11 @@ def load( # type: ignore[no-untyped-def] if remote: log_cache_bypass("bypass_fx_graph", str(e)) cache_event_time = time_ns() - if not compiled_graph: - compiled_graph = compile_fx_fn( - gm, example_inputs, inputs_to_check, fx_kwargs - ) + + if not compiled_graph: + compiled_graph = compile_fx_fn( + gm, example_inputs, inputs_to_check, fx_kwargs + ) assert compiled_graph is not None cache_info["cache_state"] = cache_state chromium_log = get_chromium_event_logger() From 710645df07f7f7193e09f81b02c58d7d43716cfe Mon Sep 17 00:00:00 2001 From: Jokeren Date: Fri, 6 Sep 2024 19:04:58 +0000 Subject: [PATCH 0563/1018] [Inductor] Use argument names as the key for the `constants` dict and the `signature` dict (#135170) Referencing how triton constructs these dictionaries https://github.com/triton-lang/triton/blob/ca3fb5f6fa6365445d2d982d137aa5cb2a80f547/python/triton/runtime/jit.py#L639 Pull Request resolved: https://github.com/pytorch/pytorch/pull/135170 Approved by: https://github.com/htyu --- torch/_higher_order_ops/triton_kernel_wrap.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/torch/_higher_order_ops/triton_kernel_wrap.py b/torch/_higher_order_ops/triton_kernel_wrap.py index 37f7f05c641486..5a1ad4405c5ec3 100644 --- a/torch/_higher_order_ops/triton_kernel_wrap.py +++ b/torch/_higher_order_ops/triton_kernel_wrap.py @@ -158,15 +158,13 @@ def generate_ttir(kernel, kwargs): ] specialization = kernel._get_config(*ordered_args.values()) constants = { - i: arg - for i, arg in enumerate(ordered_args.values()) - if not isinstance(arg, Tensor) + name: arg for name, arg in ordered_args.items() if not isinstance(arg, Tensor) } # Build kernel signature -- doesn't include constexpr arguments. signature = { - i: kernel._type_of(kernel._key_of(arg)) - for i, arg in enumerate(ordered_args.values()) + name: kernel._type_of(kernel._key_of(arg)) + for i, (name, arg) in enumerate(ordered_args.items()) if i not in kernel.constexprs } From 701a8c05512a6943406b1633bf04d2a029f8c186 Mon Sep 17 00:00:00 2001 From: Nowtryz <44437613+nowtryz@users.noreply.github.com> Date: Fri, 6 Sep 2024 19:06:20 +0000 Subject: [PATCH 0564/1018] Add MaskedTensor passthrough: unfold, F.Unfold, F.Fold, stack (#125262) Hi, I noticed the `unfold` operator was missing on MaskedTensor. I tested that my change works when calling unfold and backward on a `MaskedTensor` but I didn't find the tests for the dispatch of such operation. Where is it? Pull Request resolved: https://github.com/pytorch/pytorch/pull/125262 Approved by: https://github.com/cpuhrsch --- aten/src/ATen/native/Col2Im.cpp | 2 +- aten/src/ATen/native/Im2Col.cpp | 2 +- aten/src/ATen/native/cuda/Col2Im.cu | 2 +- aten/src/ATen/native/cuda/Im2Col.cu | 2 +- docs/source/masked.rst | 3 ++ test/test_maskedtensor.py | 50 ++++++++++++++++--- torch/masked/maskedtensor/passthrough.py | 5 ++ .../_internal/common_methods_invocations.py | 4 +- 8 files changed, 56 insertions(+), 14 deletions(-) diff --git a/aten/src/ATen/native/Col2Im.cpp b/aten/src/ATen/native/Col2Im.cpp index 0a98ee5c46ca77..51e005c2901b93 100644 --- a/aten/src/ATen/native/Col2Im.cpp +++ b/aten/src/ATen/native/Col2Im.cpp @@ -144,7 +144,7 @@ static void col2im_out_cpu_template( output.resize_({batch_size, n_output_plane, output_height, output_width}); - AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(kBFloat16, kHalf, + AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND3(kBFloat16, kHalf, kBool, input.scalar_type(), "col2im_out_cpu", [&] { Tensor input_n = Tensor(); Tensor output_n = Tensor(); diff --git a/aten/src/ATen/native/Im2Col.cpp b/aten/src/ATen/native/Im2Col.cpp index dac2ee6e3f103a..25eb4d6787240a 100644 --- a/aten/src/ATen/native/Im2Col.cpp +++ b/aten/src/ATen/native/Im2Col.cpp @@ -94,7 +94,7 @@ static void im2col_out_cpu_template( output.resize_({batch_size, n_output_plane, output_length}); - AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(kBFloat16, kHalf, + AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND3(kBFloat16, kHalf, kBool, input.scalar_type(), "im2col_out_cpu", [&] { Tensor input_n; Tensor output_n; diff --git a/aten/src/ATen/native/cuda/Col2Im.cu b/aten/src/ATen/native/cuda/Col2Im.cu index bb6d4748deb1f5..e390666b307aa1 100644 --- a/aten/src/ATen/native/cuda/Col2Im.cu +++ b/aten/src/ATen/native/cuda/Col2Im.cu @@ -102,7 +102,7 @@ void col2im_out_cuda_template( output.resize_({batch_size, n_output_plane, output_height, output_width}); int64_t output_batch_stride = output.stride(0); - AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(kHalf, kBFloat16, + AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND3(kHalf, kBFloat16, kBool, input.scalar_type(), "col2im_out_cuda", [&] { int64_t height_col = (output_height + 2 * pad_height - (dilation_height * (kernel_height - 1) + 1)) / diff --git a/aten/src/ATen/native/cuda/Im2Col.cu b/aten/src/ATen/native/cuda/Im2Col.cu index 312ad893c0d81e..796eda97b37331 100644 --- a/aten/src/ATen/native/cuda/Im2Col.cu +++ b/aten/src/ATen/native/cuda/Im2Col.cu @@ -103,7 +103,7 @@ static void im2col_out_cuda_template( output.resize_({batch_size, n_output_plane, output_length}); // Launch kernel - AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(kHalf, kBFloat16, + AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND3(kHalf, kBFloat16, kBool, input.scalar_type(), "im2col_out_cuda", [&] { Tensor input_n; Tensor output_n; diff --git a/docs/source/masked.rst b/docs/source/masked.rst index 60dd67f643b890..8177b91a9c15c6 100644 --- a/docs/source/masked.rst +++ b/docs/source/masked.rst @@ -283,9 +283,11 @@ The following ops are currently supported: kron meshgrid narrow + nn.functional.unfold ravel select split + stack t transpose vsplit @@ -294,6 +296,7 @@ The following ops are currently supported: Tensor.expand_as Tensor.reshape Tensor.reshape_as + Tensor.unfold Tensor.view .. This module needs to be documented. Adding here in the meantime diff --git a/test/test_maskedtensor.py b/test/test_maskedtensor.py index bf42cf086376ec..580f1301d40605 100644 --- a/test/test_maskedtensor.py +++ b/test/test_maskedtensor.py @@ -66,6 +66,18 @@ def _compare_mts(mt1, mt2, rtol=1e-05, atol=1e-08): if not _tensors_match(a, b, exact=False, rtol=rtol, atol=atol): raise ValueError("The data in MaskedTensor mt1 and MaskedTensor mt2 do not match") +def _compare_forward_backward(data, mask, fn): + mt = masked_tensor(data, mask, requires_grad=True) + masked_res = fn(mt) + masked_res.sum().backward() + + t = data.masked_fill(~mask, float("-inf")).detach().clone().requires_grad_() + tensor_res = fn(t) + tensor_res.sum().backward() + + _compare_mt_t(masked_res, tensor_res) + _compare_mt_t(mt.grad, t.grad, atol=1e-06) + def _create_random_mask(shape, device): return make_tensor(shape, device=device, dtype=torch.bool) @@ -164,15 +176,8 @@ def test_softmax(self, device): ], device=device ) - mt = masked_tensor(data, mask, requires_grad=True) - masked_res = torch.softmax(mt, -1) - masked_res.sum().backward() - xinf = data.masked_fill(~mask, float("-inf")).detach().clone().requires_grad_() - tensor_res = torch.softmax(xinf, -1) - tensor_res.sum().backward() - _compare_mt_t(masked_res, tensor_res) - _compare_mt_t(mt.grad, xinf.grad, atol=1e-06) + _compare_forward_backward(data, mask, lambda t: torch.softmax(t, -1)) def test_where(self, device): data = torch.tensor([-10.0, -5, 0, 5, 10, 50, 60, 70, 80, 90, 100], device=device) @@ -192,6 +197,35 @@ def test_where(self, device): _compare_mt_t(mx.grad, x.grad) _compare_mt_t(my.grad, y.grad) + def test_unfold(self, device): + data = torch.rand(5, 5, device=device) + mask = torch.rand(5, 5, device=device) > 0.5 + _compare_forward_backward(data, mask, lambda t: t.unfold(1, 2, 2)) + + def test_nn_unfold(self, device): + data = torch.rand(2, 5, 3, 4, device=device) + mask = torch.rand(2, 5, 3, 4, device=device) > 0.5 + _compare_forward_backward(data, mask, lambda t: torch.nn.functional.unfold(t, kernel_size=(2, 3))) + + def test_stack(self, device): + masked_tensors = [ + masked_tensor( + torch.rand(2, 5, 3, 4, device=device), + torch.rand(2, 5, 3, 4, device=device) > 0.5, + requires_grad=True, + ) for _ in range(3) + ] + + data_tensors = [mt.get_data().detach().clone().requires_grad_() for mt in masked_tensors] + masked_res = torch.stack(masked_tensors) + tensor_res = torch.stack(data_tensors) + + masked_res.sum().backward() + tensor_res.sum().backward() + _compare_mt_t(masked_res, tensor_res) + for mt, t in zip(masked_tensors, data_tensors): + _compare_mt_t(mt.grad, t.grad, atol=1e-06) + def test_to_sparse(self, device): for sample in _generate_sample_data(device=device): data = sample.input diff --git a/torch/masked/maskedtensor/passthrough.py b/torch/masked/maskedtensor/passthrough.py index 4a2e79456c86db..ba13f50c1fee9c 100644 --- a/torch/masked/maskedtensor/passthrough.py +++ b/torch/masked/maskedtensor/passthrough.py @@ -30,6 +30,11 @@ torch.ops.aten._reshape_alias, torch.ops.aten.cat, torch.ops.aten.unsqueeze, + torch.ops.aten.unfold, + torch.ops.aten.unfold_backward, + torch.ops.aten.im2col, + torch.ops.aten.col2im, + torch.ops.aten.stack, ] diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py index e1dc92edb6302a..15102a2fba28eb 100644 --- a/torch/testing/_internal/common_methods_invocations.py +++ b/torch/testing/_internal/common_methods_invocations.py @@ -15289,8 +15289,8 @@ def sample_inputs_alias_copy(op_info, device, dtype, requires_grad, **kwargs): autodiff_nonfusible_nodes=["aten::hardswish"]), OpInfo('nn.functional.unfold', aten_name='im2col', - dtypes=floating_and_complex_types_and(torch.half, torch.bfloat16), - dtypesIfCUDA=floating_and_complex_types_and(torch.half, torch.bfloat16), + dtypes=floating_and_complex_types_and(torch.half, torch.bfloat16, torch.bool), + dtypesIfCUDA=floating_and_complex_types_and(torch.half, torch.bfloat16, torch.bool), sample_inputs_func=sample_inputs_nn_unfold, # Runs very slowly on slow gradcheck - alternatively reduce input sizes gradcheck_fast_mode=True, From 34679a69add4b62ac5e04c66ce01b4d657cfeba8 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Fri, 6 Sep 2024 19:10:54 +0000 Subject: [PATCH 0565/1018] [ONNX] Refactor exporter errors (#135180) Refactor exporter errors to combine old errors and new errors for API consistency. This PR also 1. Removes the `_C._check_onnx_proto(proto)` call in the old exporter. We don't need the ONNX checker because it is limited. 2. Removes the `OnnxExporterError` defined in the dynamo module. This class unnecessarily stores the onnx program object, making it very bulky. Instead, we revert to use the plain OnnxExporterError defined in the `errors` module and use it as the base class for all errors. 3. Continues to expose `OnnxExporterError` in `torch.onnx` and the rest of the errors in `torch.onnx.errors`. 4. Removes the `CheckerError` and `InvalidExportOptionsError` from `torch.onnx`. This is BC breaking but should have low impact. 5. I did not rename existing errors out of compatibility considerations, even though `ExporterError` would have been more succinct. Fixes https://github.com/pytorch/pytorch/issues/135125 Pull Request resolved: https://github.com/pytorch/pytorch/pull/135180 Approved by: https://github.com/titaiwangms --- docs/source/onnx.rst | 5 +- docs/source/onnx_dynamo.rst | 3 -- test/onnx/dynamo/test_exporter_api.py | 41 --------------- test/onnx/onnx_test_common.py | 37 ++++--------- test/onnx/test_fx_to_onnx.py | 14 ----- test/onnx/test_pytorch_onnx_no_runtime.py | 27 ---------- test/onnx/test_pytorch_onnx_onnxruntime.py | 2 +- torch/onnx/__init__.py | 23 ++++----- torch/onnx/_internal/_exporter_legacy.py | 60 ++++++---------------- torch/onnx/_internal/exporter/_building.py | 3 +- torch/onnx/_internal/exporter/_core.py | 10 ++-- torch/onnx/_internal/exporter/errors.py | 30 ----------- torch/onnx/errors.py | 48 ++++++++++++----- torch/onnx/utils.py | 12 ----- 14 files changed, 82 insertions(+), 233 deletions(-) delete mode 100644 torch/onnx/_internal/exporter/errors.py diff --git a/docs/source/onnx.rst b/docs/source/onnx.rst index ffaa8ef836b787..a07664aaf69473 100644 --- a/docs/source/onnx.rst +++ b/docs/source/onnx.rst @@ -63,9 +63,12 @@ also be interested in reading our `development wiki Date: Fri, 6 Sep 2024 19:35:20 +0000 Subject: [PATCH 0566/1018] [ONNX] Properly handle Attributes in traceable functions (#135367) Previously the attributes were sent in as Attr objects even when we call the function as a plain Python function. Turning them into python objects. From https://github.com/justinchuby/torch-onnx/pull/186 Related https://github.com/microsoft/onnxscript/issues/1846 Pull Request resolved: https://github.com/pytorch/pytorch/pull/135367 Approved by: https://github.com/justinchuby --- torch/onnx/_internal/exporter/_building.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/torch/onnx/_internal/exporter/_building.py b/torch/onnx/_internal/exporter/_building.py index 297896d126e976..274d977b88b465 100644 --- a/torch/onnx/_internal/exporter/_building.py +++ b/torch/onnx/_internal/exporter/_building.py @@ -495,6 +495,11 @@ def eval_function( # type: ignore[override] # call because it will filter out the unexpected kwargs for us. if function.traceable: # Trace the function call instead of adding the function as a node + # Turn the ir.Attr objects into Python constants first + named_attrs = { + name: attr.value if isinstance(attr, ir.Attr) else attr + for name, attr in named_attrs.items() + } return function.function(**named_inputs, **named_attrs) outputs = self._call_op(op_signature, named_inputs, named_attrs) From d52cd6d6f023d391a76bc3c89c32b43fba1a7ad0 Mon Sep 17 00:00:00 2001 From: Jane Xu Date: Fri, 6 Sep 2024 10:20:09 -0700 Subject: [PATCH 0567/1018] Run all autograd node post hooks (#134728) Pull Request resolved: https://github.com/pytorch/pytorch/pull/134728 Approved by: https://github.com/albanD, https://github.com/soulitzer --- test/inductor/test_compiled_autograd.py | 8 ++-- test/test_autograd.py | 52 +++++++++++++++++++++++-- torch/autograd/graph.py | 7 ++++ torch/csrc/autograd/engine.cpp | 17 ++++---- 4 files changed, 70 insertions(+), 14 deletions(-) diff --git a/test/inductor/test_compiled_autograd.py b/test/inductor/test_compiled_autograd.py index ad83833ea3386e..ea6675276e212c 100644 --- a/test/inductor/test_compiled_autograd.py +++ b/test/inductor/test_compiled_autograd.py @@ -1609,6 +1609,7 @@ def fn(): ), compiled_autograd.enable(compiler_fn): fn() + @unittest.skip("Flaky, cache from test ordering affects test. #135369") def test_autograd_cpp_node(self): cpp_source = """ struct CustomOpAutogradFunction : public torch::autograd::Function { @@ -2741,7 +2742,7 @@ def wrap_test_class(orig_cls): "test_tensor_hooks_inplace_multiple_outputs", # uses assert in hook "test_hooks", # uses assert in hook "test_accumulate_grad_posthooks_can_observe_tensor_prehook", # allclose - "test_save_tensor_hook_version_counter_not_shared", # assertEqual + "test_saved_tensors_hook_version_counter_not_shared", # assertEqual "test_post_accumulate_grad_hook_returns_not_None", # throws "test_custom_function_cycle", # assertEqual "test_mark_non_differentiable_mixed", # assertTrue @@ -2808,7 +2809,8 @@ def wrap_test_class(orig_cls): "test_custom_function_forward_mode_inplace_checks", # forward AD "test_custom_function_forward_mode_view_checks", # forward AD "test_custom_function_forward_mode_wrong_formula", # forward AD - "test_default_saved_variable_hooks_double_backward", # create_graph + "test_default_saved_tensors_hooks_double_backward", # create_graph + "test_node_post_hook_registered_during_unpack_hook", # 'NoneType' object has no attribute 'register_hook' "test_full_backward_hook_double_backward", # create_graph "test_function", # create_graph "test_grad", # create_graph @@ -2865,7 +2867,7 @@ def wrap_test_class(orig_cls): "test_graph_save_on_cpu", # does not support pin_memory: https://github.com/pytorch/pytorch/issues/134173 # Category: FakeTensor "test_saving_variable_to_disk", # torch.save should no-op and be recorded in the graph - "test_wrapped_number_saved_variable_hooks", # Proxy tensor should carryover is_wrapped_number_ of its original + "test_wrapped_number_saved_tensors_hooks", # Proxy tensor should carryover is_wrapped_number_ of its original "test_grad_batched_grad", # torch._subclasses.fake_tensor.UnsupportedFakeTensorException: meta converter nyi "test_scalar_grad_mixed_device", # Fake Tensors aren't propagating device properly for 0-dim grads # Category: Divergence from eager diff --git a/test/test_autograd.py b/test/test_autograd.py index 6c3ac3723042e2..c7fabb725082c1 100644 --- a/test/test_autograd.py +++ b/test/test_autograd.py @@ -1936,6 +1936,50 @@ def prehook2(grad_output): self.assertTrue(torch.allclose(a.grad, torch.ones(3, 3) * 2)) self.assertEqual(counter[0], 1) + def test_node_post_hook_registered_during_unpack_hook(self): + """ + Test that post hooks registered during one of the node's + unpack hooks are properly restricted and will run properly. + """ + test_case = self + + class RegisterPostNodeHook(torch.autograd.graph.saved_tensors_hooks): + def __init__(self) -> None: + def pack_tensor(tensor: torch.Tensor) -> torch.Tensor: + return tensor + + def unpack_tensor(tensor: torch.Tensor) -> torch.Tensor: + node = torch._C._current_autograd_node() + + def hook(outputs, inputs): + # Assert that inputs passed in are None + test_case.assertTrue(all(i is None for i in inputs)) + halved_outputs = tuple( + o / 2.0 if o is not None else None for o in outputs + ) + return halved_outputs + + node.register_hook(hook) + return tensor + + super().__init__(pack_tensor, unpack_tensor) + + a = torch.rand(3, 3, requires_grad=True) + + def model(): + var, mean = torch.var_mean(a, dim=0) + loss = (var + mean).sum() + loss.backward() + + model() + ref_grad = a.grad.clone() + + with RegisterPostNodeHook(): + model() + + # Verify that the post hook got called and the grad propagation worked + self.assertEqual(ref_grad / 2.0 + ref_grad, a.grad) + def test_hooks_cpp(self): # Tests hooks for autograd function implemented in C++ bn = torch.nn.BatchNorm1d(5, affine=False) @@ -9463,7 +9507,7 @@ def test_disabling_saved_tensor_hooks_nested(self): self.assertTrue(torch._C._autograd._saved_tensors_hooks_is_enabled()) - def test_saved_tensor_hooks_custom_error_propagaation(self): + def test_saved_tensor_hooks_custom_error_propagation(self): class CustomError(Exception): pass @@ -9541,7 +9585,7 @@ def unpack_hook(x): self.assertEqual(pack_count, 1) self.assertEqual(unpack_count, 1) - def test_save_tensor_hook_version_counter_not_shared(self): + def test_saved_tensors_hook_version_counter_not_shared(self): class Test(torch.autograd.Function): @staticmethod def forward(ctx, x): @@ -9638,7 +9682,7 @@ def unpack(name): y.sum().backward() self.assertEqual(2 * a, a.grad) - def test_default_saved_variable_hooks_double_backward(self): + def test_default_saved_tensors_hooks_double_backward(self): with torch.autograd.graph.saved_tensors_hooks(lambda x: x, lambda x: x): a = torch.randn(5, requires_grad=True) y = a**3 @@ -9676,7 +9720,7 @@ def test_default_saved_variable_hooks_double_backward(self): # note that in that sense, a is saved twice self.assertEqual(6 * 8 * a, a.grad) - def test_wrapped_number_saved_variable_hooks(self): + def test_wrapped_number_saved_tensors_hooks(self): def err_hook(x): raise RuntimeError("this hook should not be called") diff --git a/torch/autograd/graph.py b/torch/autograd/graph.py index 6fbf797d26889f..4792ced32b2703 100644 --- a/torch/autograd/graph.py +++ b/torch/autograd/graph.py @@ -108,6 +108,13 @@ def register_hook(self, fn: Callable[..., Any]) -> RemovableHandle: See :ref:`backward-hooks-execution` for more information on how when this hook is executed, and how its execution is ordered relative to other hooks. + .. note:: + In the rare case where the hook is registered while the Node has already + begun execution, there is no longer any guarantee on :attr:`grad_outputs` + content (it might be as usual or empty depending on other factors). The + hook can still optionally return a new gradient to be used in place of + :attr:`grad_inputs` independent of :attr:`grad_outputs`. + Example:: >>> import torch diff --git a/torch/csrc/autograd/engine.cpp b/torch/csrc/autograd/engine.cpp index de13aac33b38e8..1be6242909af71 100644 --- a/torch/csrc/autograd/engine.cpp +++ b/torch/csrc/autograd/engine.cpp @@ -819,9 +819,15 @@ static variable_list call_tensor_pre_hooks(Node& fn, variable_list inputs) { static variable_list call_post_hooks( Node& fn, variable_list outputs, - const variable_list& inputs) { + const variable_list& inputs, + const bool had_post_hooks) { for (const auto& hook : fn.post_hooks()) { - outputs = (*hook)(outputs, inputs); + if (had_post_hooks) { + outputs = (*hook)(outputs, inputs); + } else { + variable_list null_inputs; + outputs = (*hook)(outputs, null_inputs); + } } return outputs; } @@ -976,11 +982,8 @@ static variable_list call_function( return ss.str(); }); - if (has_post_hooks) { - // NOLINTNEXTLINE(bugprone-use-after-move) - return call_post_hooks(fn, std::move(outputs), inputs); - } - return outputs; + // NOLINTNEXTLINE(bugprone-use-after-move) + return call_post_hooks(fn, std::move(outputs), inputs, has_post_hooks); } void Engine::evaluate_function( From 7b98a37907cf4c9fb8252fd5ef16b0263918b259 Mon Sep 17 00:00:00 2001 From: Tristan Rice Date: Fri, 6 Sep 2024 19:54:23 +0000 Subject: [PATCH 0568/1018] [TCPStore] use wait counters (#135283) This replaces the existing TCPStore counters with the new shared wait counters. There's no users of the tcpstore counters so should be completely safe to remove. Test plan: Existing tests + build There's no OSS backend for wait counters so can't write any tests with them currently. Pull Request resolved: https://github.com/pytorch/pytorch/pull/135283 Approved by: https://github.com/c-p-i-o --- torch/csrc/distributed/c10d/TCPStore.cpp | 81 ++++-------------------- torch/csrc/distributed/c10d/TCPStore.hpp | 28 -------- torch/csrc/distributed/c10d/init.cpp | 5 -- 3 files changed, 14 insertions(+), 100 deletions(-) diff --git a/torch/csrc/distributed/c10d/TCPStore.cpp b/torch/csrc/distributed/c10d/TCPStore.cpp index 4bd3b28ec1e24c..2855ced299201c 100644 --- a/torch/csrc/distributed/c10d/TCPStore.cpp +++ b/torch/csrc/distributed/c10d/TCPStore.cpp @@ -1,3 +1,4 @@ +#include #include #include #include @@ -33,53 +34,6 @@ namespace c10d { namespace detail { -class timing_guard { - Counter& counter_; - typedef std::chrono::time_point - time_point; - time_point start_; - - public: - timing_guard(Counter& counter) - : counter_(counter), start_(std::chrono::high_resolution_clock::now()) {} - - ~timing_guard() { - stop(); - } - - void stop() { - if (start_ != time_point()) { - auto diff = std::chrono::duration_cast( - std::chrono::high_resolution_clock::now() - start_) - .count(); - counter_.update(diff); - start_ = time_point(); - } - } -}; - -void Counter::update(double val) { - count_ += 1; - - auto delta = val - mean_; - mean_ += delta / count_; - - auto delta2 = val - mean_; - m2_ += delta2 * delta2; -} - -std::unordered_map Counter::observe() const { - std::unordered_map res; - res["count"] = (double)count_; - res["mean"] = mean_; - if (count_ >= 2) { - res["sample_variance"] = m2_ / (count_ - 1); - } else { - res["sample_variance"] = std::nan("1"); - } - return res; -} - // Manages the lifecycle of a server daemon. class TCPServer { public: @@ -317,6 +271,8 @@ TCPStore::TCPStore(std::string host, const TCPStoreOptions& opts) addr_{std::move(host)}, numWorkers_{opts.numWorkers}, usingLibUv_{opts.useLibUV} { + STATIC_SCOPED_WAIT_COUNTER(pytorch.wait_counter.TCPStore__init); + if (opts.useLibUV) { TORCH_CHECK( ::c10d::detail::is_libuv_tcpstore_backend_available(), @@ -421,7 +377,7 @@ TCPStore::TCPStore(std::string host, const TCPStoreOptions& opts) TCPStore::~TCPStore() = default; void TCPStore::waitForWorkers() { - detail::timing_guard tguard(clientCounters_["waitForWorkers"]); + STATIC_SCOPED_WAIT_COUNTER(pytorch.wait_counter.TCPStore__waitForWorkers); if (numWorkers_ == std::nullopt) { return; } @@ -491,7 +447,7 @@ void TCPStore::_splitSet( } void TCPStore::set(const std::string& key, const std::vector& data) { - detail::timing_guard tguard(clientCounters_["set"]); + STATIC_SCOPED_WAIT_COUNTER(pytorch.wait_counter.TCPStore__set); const std::lock_guard lock(activeOpLock_); detail::SendBuffer buffer(*client_, detail::QueryType::SET); buffer.appendString(keyPrefix_ + key); @@ -503,7 +459,7 @@ std::vector TCPStore::compareSet( const std::string& key, const std::vector& expectedValue, const std::vector& desiredValue) { - detail::timing_guard tguard(clientCounters_["compareSet"]); + STATIC_SCOPED_WAIT_COUNTER(pytorch.wait_counter.TCPStore__compareSet); const std::lock_guard lock(activeOpLock_); detail::SendBuffer buffer(*client_, detail::QueryType::COMPARE_SET); buffer.appendString(keyPrefix_ + key); @@ -515,7 +471,7 @@ std::vector TCPStore::compareSet( } std::vector TCPStore::get(const std::string& key) { - detail::timing_guard tguard(clientCounters_["get"]); + STATIC_SCOPED_WAIT_COUNTER(pytorch.wait_counter.TCPStore__get); const std::lock_guard lock(activeOpLock_); return doGet(keyPrefix_ + key); } @@ -530,13 +486,13 @@ std::vector TCPStore::doGet(const std::string& key) { } int64_t TCPStore::add(const std::string& key, int64_t value) { - detail::timing_guard tguard(clientCounters_["add"]); + STATIC_SCOPED_WAIT_COUNTER(pytorch.wait_counter.TCPStore__add); const std::lock_guard lock(activeOpLock_); return incrementValueBy(keyPrefix_ + key, value); } bool TCPStore::deleteKey(const std::string& key) { - detail::timing_guard tguard(clientCounters_["deleteKey"]); + STATIC_SCOPED_WAIT_COUNTER(pytorch.wait_counter.TCPStore__delete); const std::lock_guard lock(activeOpLock_); detail::SendBuffer buffer(*client_, detail::QueryType::DELETE_KEY); buffer.appendString(keyPrefix_ + key); @@ -564,7 +520,7 @@ int64_t TCPStore::getNumKeys() { } bool TCPStore::check(const std::vector& keys) { - detail::timing_guard tguard(clientCounters_["check"]); + STATIC_SCOPED_WAIT_COUNTER(pytorch.wait_counter.TCPStore__check); const std::lock_guard lock(activeOpLock_); detail::SendBuffer buffer(*client_, detail::QueryType::CHECK); buffer.appendValue(keys.size()); @@ -591,7 +547,7 @@ void TCPStore::wait(const std::vector& keys) { void TCPStore::wait( const std::vector& keys, const std::chrono::milliseconds& timeout) { - detail::timing_guard tguard(clientCounters_["wait"]); + STATIC_SCOPED_WAIT_COUNTER(pytorch.wait_counter.TCPStore__wait); const std::lock_guard lock(activeOpLock_); std::vector prefixedKeys{}; prefixedKeys.reserve(keys.size()); @@ -652,7 +608,7 @@ void TCPStore::doWait( void TCPStore::append( const std::string& key, const std::vector& data) { - detail::timing_guard tguard(clientCounters_["append"]); + STATIC_SCOPED_WAIT_COUNTER(pytorch.wait_counter.TCPStore__append); const std::lock_guard lock(activeOpLock_); detail::SendBuffer buffer(*client_, detail::QueryType::APPEND); buffer.appendString(keyPrefix_ + key); @@ -662,7 +618,7 @@ void TCPStore::append( std::vector> TCPStore::multiGet( const std::vector& keys) { - detail::timing_guard tguard(clientCounters_["multiGet"]); + STATIC_SCOPED_WAIT_COUNTER(pytorch.wait_counter.TCPStore__multiGet); const std::lock_guard lock(activeOpLock_); std::vector prefixedKeys; prefixedKeys.reserve(keys.size()); @@ -689,7 +645,7 @@ std::vector> TCPStore::multiGet( void TCPStore::multiSet( const std::vector& keys, const std::vector>& values) { - detail::timing_guard tguard(clientCounters_["multiSet"]); + STATIC_SCOPED_WAIT_COUNTER(pytorch.wait_counter.TCPStore__multiSet); TORCH_CHECK( keys.size() == values.size(), "multiSet keys and values vectors must be of same size"); @@ -708,15 +664,6 @@ bool TCPStore::hasExtendedApi() const { return true; } -std::unordered_map> -TCPStore::collectClientCounters() const noexcept { - std::unordered_map> res; - for (const auto& kv : clientCounters_) { - res[kv.first] = kv.second.observe(); - } - return res; -} - std::string TCPStore::repr() const { auto clientRepr = client_ ? client_->repr() : ""; auto serverRepr = server_ ? server_->repr() : ""; diff --git a/torch/csrc/distributed/c10d/TCPStore.hpp b/torch/csrc/distributed/c10d/TCPStore.hpp index 015d134e983fe3..5ac48d8205e5da 100644 --- a/torch/csrc/distributed/c10d/TCPStore.hpp +++ b/torch/csrc/distributed/c10d/TCPStore.hpp @@ -18,30 +18,6 @@ struct SocketAddress { std::uint16_t port{}; }; -class Counter { - public: - void update(double val); - std::unordered_map observe() const; - - double mean() const noexcept { - return mean_; - } - int64_t count() const noexcept { - return count_; - } - double variance() const noexcept { - return m2_ / static_cast(count_); - } - double sample_variance() const noexcept { - return m2_ / static_cast(count_ - 1); - } - - private: - int64_t count_ = 0; - double mean_ = 0; - double m2_ = 0; -}; - } // namespace detail struct TCPStoreOptions { @@ -130,9 +106,6 @@ class TORCH_API TCPStore : public Store { return addr_.port; } - std::unordered_map> - collectClientCounters() const noexcept; - bool isLibUvBackend() const noexcept { return usingLibUv_; } @@ -162,7 +135,6 @@ class TORCH_API TCPStore : public Store { const std::string initKey_ = "init/"; const std::string keyPrefix_ = "/"; std::mutex activeOpLock_; - std::unordered_map clientCounters_; bool usingLibUv_ = true; }; diff --git a/torch/csrc/distributed/c10d/init.cpp b/torch/csrc/distributed/c10d/init.cpp index eced2e685fcf85..60d2a9680e814e 100644 --- a/torch/csrc/distributed/c10d/init.cpp +++ b/torch/csrc/distributed/c10d/init.cpp @@ -1556,15 +1556,10 @@ Example:: py::arg("master_listen_fd") = py::none(), py::arg("use_libuv") = true, py::call_guard()) - .def( - "collect_client_counters", - &::c10d::TCPStore::collectClientCounters, - "Return a dict of counters for tcp store client") .def_property_readonly( "host", &::c10d::TCPStore::getHost, R"(Gets the hostname on which the store listens for requests.)") - .def_property_readonly( "port", &::c10d::TCPStore::getPort, From 183ca1b0cad09d4751c3d12fef0948ba0398fe76 Mon Sep 17 00:00:00 2001 From: titaiwangms Date: Fri, 6 Sep 2024 20:27:50 +0000 Subject: [PATCH 0569/1018] [ONNX] Clean up the missed lines from previous PRs (#135368) Some missed deleted lines Pull Request resolved: https://github.com/pytorch/pytorch/pull/135368 Approved by: https://github.com/justinchuby --- test/onnx/test_fx_to_onnx.py | 1 - torch/onnx/_internal/_exporter_legacy.py | 4 ---- 2 files changed, 5 deletions(-) diff --git a/test/onnx/test_fx_to_onnx.py b/test/onnx/test_fx_to_onnx.py index d7a9b372d57384..1946ba586ac0d2 100644 --- a/test/onnx/test_fx_to_onnx.py +++ b/test/onnx/test_fx_to_onnx.py @@ -736,7 +736,6 @@ def forward(self, tensor_x: torch.Tensor): onnx_program.save( tmp_onnx_file.name, include_initializers=include_initializer, - model_state=state_dict if include_initializer else None, ) onnx_model = onnx.load(tmp_onnx_file.name) self.assertEqual( diff --git a/torch/onnx/_internal/_exporter_legacy.py b/torch/onnx/_internal/_exporter_legacy.py index cf5ee7d1bbd0a6..3c05b350e8c64f 100644 --- a/torch/onnx/_internal/_exporter_legacy.py +++ b/torch/onnx/_internal/_exporter_legacy.py @@ -821,10 +821,6 @@ def save( ) else: try: - if not isinstance(self.model_proto, onnx.ModelProto): # type: ignore[attr-defined] - raise ValueError( - "onnx_program.ModelProto is not an onnx.ModelProto" - ) destination.write(self.model_proto.SerializeToString()) except ValueError as exc: raise ValueError( From 09e429af97866f5529cbd8963535ab50f4eb93bd Mon Sep 17 00:00:00 2001 From: Henry Tsang Date: Fri, 6 Sep 2024 20:30:47 +0000 Subject: [PATCH 0570/1018] [test][easy] Add debug utils for cpu select algorithm test (#135038) Summary: Add debug utils to debug a flaky test in fbcode ci. Some context: https://github.com/pytorch/pytorch/pull/126545 Test Plan: ci Differential Revision: D62005445 Pull Request resolved: https://github.com/pytorch/pytorch/pull/135038 Approved by: https://github.com/jgong5, https://github.com/XuehaiPan --- test/inductor/test_cpu_select_algorithm.py | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/test/inductor/test_cpu_select_algorithm.py b/test/inductor/test_cpu_select_algorithm.py index ca1d4f12d6a4f4..813be61c951954 100644 --- a/test/inductor/test_cpu_select_algorithm.py +++ b/test/inductor/test_cpu_select_algorithm.py @@ -1,6 +1,8 @@ # Owner(s): ["oncall: cpu inductor"] import contextlib import functools +import logging +import os import sys import unittest from typing import Optional @@ -25,6 +27,9 @@ from torch.testing._internal.common_utils import IS_MACOS, parametrize, TEST_MKL +log = logging.getLogger(__name__) + + try: try: from . import test_cpu_repro, test_torchinductor @@ -264,6 +269,19 @@ def __init__(self, bias, epilogue, other): def forward(self, x): return self.epilogue(self.linear(x)) + # TODO: debug utils, safe to remove in Oct 2024 + if inductor_config.is_fbcode(): + log.warning( + f"DEBUG: torch.backends.mkl.is_available() is {torch.backends.mkl.is_available()}, " # noqa: G004 + f"torch.ops.mkldnn._is_mkldnn_fp16_supported() is {torch.ops.mkldnn._is_mkldnn_fp16_supported()}, " + f"torch.ops.mkldnn._is_mkldnn_bf16_supported() is {torch.ops.mkldnn._is_mkldnn_bf16_supported()}, " + f"inductor_config.freezing is {inductor_config.freezing}, " + f"mkldnn._is_mkldnn_acl_supported() is {torch.ops.mkldnn._is_mkldnn_acl_supported()}, " + f"torch._C.has_mkl is {torch._C.has_mkl}, " + f"PYTORCH_TEST_FBCODE is {os.getenv('PYTORCH_TEST_FBCODE')}, " + f"PYTORCH_TEST_REMOTE_GPU is {os.getenv('PYTORCH_TEST_REMOTE_GPU')}, " + ) + counters.clear() v = torch.randn(batch_size, in_features).to(dtype=dtype) u = torch.randn(batch_size, out_features).to(dtype=dtype) From f70651474f297e047ab48d56ceaabbe800c3dff7 Mon Sep 17 00:00:00 2001 From: William Wen Date: Tue, 3 Sep 2024 16:54:04 -0700 Subject: [PATCH 0571/1018] [dynamo] reland map/zip iterator related changes (#135074) Differential Revision: [D62211019](https://our.internmc.facebook.com/intern/diff/D62211019) Pull Request resolved: https://github.com/pytorch/pytorch/pull/135074 Approved by: https://github.com/jansel, https://github.com/anijain2305, https://github.com/mlazos --- test/dynamo/test_functions.py | 218 +++++++++++++++++++++++- test/dynamo/test_repros.py | 8 +- torch/_dynamo/symbolic_convert.py | 18 +- torch/_dynamo/utils.py | 8 +- torch/_dynamo/variables/__init__.py | 2 + torch/_dynamo/variables/base.py | 13 ++ torch/_dynamo/variables/builtin.py | 105 +++++++----- torch/_dynamo/variables/constant.py | 8 + torch/_dynamo/variables/dicts.py | 1 + torch/_dynamo/variables/iter.py | 212 ++++++++++++++++++++++- torch/_dynamo/variables/lists.py | 33 +++- torch/_dynamo/variables/user_defined.py | 6 +- 12 files changed, 554 insertions(+), 78 deletions(-) diff --git a/test/dynamo/test_functions.py b/test/dynamo/test_functions.py index a0e3369ed73c7e..e936e66f23f8ab 100644 --- a/test/dynamo/test_functions.py +++ b/test/dynamo/test_functions.py @@ -239,6 +239,22 @@ def test_itertools_chain_from_iterable(a, b): v = v + x return v + def test_itertools_reconstruct(self): + def fn(a): + it1 = itertools.repeat(1) + it2 = itertools.count(2) + for _ in range(3): + a += next(it1) + a += next(it2) + return it1, it2, a + + opt_fn = torch.compile(fn, backend="eager", fullgraph=True) + i1, i2, a = fn(torch.ones(3, 3)) + it1, it2, b = opt_fn(torch.ones(3, 3)) + self.assertEqual(next(i1), next(it1)) + self.assertEqual(next(i2), next(it2)) + self.assertEqual(a, b) + @make_test def test_obj_eq(a, b): v = a + b @@ -507,8 +523,7 @@ def test_deque(a, b): empty = collections.deque() d.extend(empty) - # dynamo same() util doesn't support deque so just return a list - return list(d) + return d @make_test def test_slice1(a): @@ -3115,6 +3130,199 @@ def fn(a, ind, val): fn(arr, np.s_[..., 1], np.array([3, 3])), np.array([[1, 3], [2, 3]]) ) + def test_map_return(self): + def fn(a, b): + return map(lambda x: x + 1, [a, b]) + + opt_fn = torch.compile(fn, backend="eager", fullgraph=True) + m = opt_fn(torch.randn(3, 3), torch.randn(3, 3)) + self.assertIsInstance(m, map) + + @make_test + def test_map_max(a, b): + return max(map(lambda x: x.sum(), [a, b])) + + # max(map(...)) graph breaks + @unittest.expectedFailure + @make_test + def test_map_max_const(a): + return max(map(lambda x: x, [1, 2, 3])), a + 1 + + @make_test + def test_map_list(a, b): + return list(map(lambda x: x + 1, [a, b])) + + @make_test + def test_map_tuple(a, b): + return tuple(map(lambda x: x + 1, [a, b])) + + @make_test + def test_map_iter(a, b): + it = iter(map(lambda x: x + 1, [a, b])) + return next(it) + + @make_test + def test_map_zip_dict(a): + d = dict( + zip( + map(lambda x: x + 1, [0, 1, 2]), + [map(lambda x: x - 1, [y]) for y in [3, 4, 5]], + ) + ) + return list(d[3])[0], a + 1 # noqa: RUF015 + + @make_test + def test_map_dict_fromkeys(a): + return dict.fromkeys(map(lambda x: x + 1, [0, 1])), a + 1 + + @make_test + def test_map_set(a): + return set(map(lambda x: x + 1, [0, 1])), a + 1 + + # test_map_sum defined earlier + + @make_test + def test_map_reduce(a, b): + return functools.reduce(lambda x, y: x + y, map(lambda x: x + 1, [a, b])) + + @make_test + def test_map_sorted(a): + return sorted(map(lambda x: x + 1, [0, 4, 3, 1, 2])), a + 1 + + @make_test + def test_map_list_extend(a, b, c): + l = [a] + l.extend(map(lambda x: x + 1, [b, c])) + return l + + @make_test + def test_map_list_slice_assign(a, b, c, d, e): + l = [a, b, c] + l[1:2] = map(lambda x: x + 1, [d, e]) + return l + + @make_test + def test_map_deque_extendleft(a, b, c): + d = collections.deque([a]) + d.extendleft(map(lambda x: x + 1, [b, c])) + return d + + @make_test + def test_map_str_join(a): + return "".join(map(lambda x: x, ["a", "b", "c"])), a + 1 + + def test_map_with_graph_break(self): + def f(a): + a += 1 + + def g(x): + nonlocal a + a += 1 + return x + 1 + + m = map(g, [1, 2, 3, 4, 5]) + a += next(m) # won't graph break + torch._dynamo.graph_break() + a += next(m) # will graph break + return a + + cnts = torch._dynamo.testing.CompileCounter() + opt_f = torch.compile(f, backend=cnts) + self.assertEqual(f(torch.ones(3, 3)), opt_f(torch.ones(3, 3))) + self.assertEqual(cnts.frame_count, 3) + + def test_map_reconstruct(self): + def fn(a): + return map(lambda x: x[0] + x[1], zip([1, 2, 3], [1, 2, 3])), a + 1 + + opt_fn = torch.compile(fn, backend="eager", fullgraph=True) + m = opt_fn(torch.ones(3, 3))[0] + self.assertIsInstance(m, map) + self.assertEqual(list(m), list(fn(torch.ones(3, 3))[0])) + + def test_zip_reconstruct(self): + def fn(a): + return zip([1, 2, 3], map(lambda x: x + 1, [1, 2, 3])), a + 1 + + opt_fn = torch.compile(fn, backend="eager", fullgraph=True) + m = opt_fn(torch.ones(3, 3))[0] + self.assertIsInstance(m, zip) + self.assertEqual(list(m), list(fn(torch.ones(3, 3))[0])) + + @make_test + def test_map_partial_unpack(a, b): + y = 1 + + def f(x): + nonlocal y + y += 1 + return x + + l = list(zip([a, b], map(f, [1, 2, 3, 4]))) + return a + y + + @make_test + def test_map_call_function_ex(a, b): + def f(x, y): + return x + y + + return f(*map(lambda x: x + 1, [a, b])) + + @make_test + def test_map_unpack_twice(a, b): + m = map(lambda x: x + 1, [a, b]) + l1 = list(m) + l2 = list(m) + return l1, l2 + + @make_test + def test_enumerate(a, b): + return list(enumerate([a, b], start=1)), a + 1 + + @make_test + def test_map_enumerate(a, b): + return list(enumerate(map(lambda x: x + 1, [a, b]), start=1)), a + 1 + + @make_test + def test_map_infinite(a, b): + return list(map(lambda x, y: x + y, [a, b], itertools.count(3))) + + @make_test + def test_map_unpack_vars(a, b): + x, y = map(lambda x: x + 1, [a, b]) + return x + y + + def test_enumerate_custom(self): + class MyClass: + def __iter__(self): + self.a = 1 + return self + + def __next__(self): + if self.a > 3: + raise StopIteration + self.a += 1 + return self.a + + def fn(x): + for i, it in enumerate(MyClass()): + x += i + it + return x + + opt_fn = torch.compile(fn, backend="eager", fullgraph=True) + self.assertEqual(fn(torch.ones(3, 3)), opt_fn(torch.ones(3, 3))) + + def test_enumerate_reconstruct(self): + def fn(a, b): + return enumerate([a, b], start=1) + + opt_fn = torch.compile(fn, backend="eager", fullgraph=True) + inps = (torch.randn(3, 3), torch.randn(3, 3)) + it1 = fn(*inps) + it2 = opt_fn(*inps) + self.assertIsInstance(it2, enumerate) + self.assertEqual(list(it1), list(it2)) + def udf_mul(x, y): return x * y @@ -3670,10 +3878,16 @@ def fn(x, ys, zs): with self.assertRaisesRegex(torch._dynamo.exc.UserError, "zip()"): nopython_fn(x, ys[:1], zs) + with self.assertRaisesRegex(torch._dynamo.exc.UserError, "zip()"): + nopython_fn(x, ys, zs[:1]) + # Should cause fallback if allow graph break with self.assertRaisesRegex(ValueError, "zip()"): opt_fn(x, ys[:1], zs) + with self.assertRaisesRegex(ValueError, "zip()"): + opt_fn(x, ys, zs[:1]) + def test_fn_with_attr(self): def fn(x): if fn.pred: diff --git a/test/dynamo/test_repros.py b/test/dynamo/test_repros.py index 19e9f7d9617240..a3c4226ed6a9f8 100644 --- a/test/dynamo/test_repros.py +++ b/test/dynamo/test_repros.py @@ -5476,15 +5476,17 @@ def gen_inps(len_x, len_y): return x, y def g(x, y): - return tuple(map(f, x, y)) + return map(f, x, y) opt_g = torch.compile(g, fullgraph=True, backend="eager") inps = gen_inps(3, 3) - self.assertEqual(g(*inps), opt_g(*inps)) + self.assertEqual(type(g(*inps)), type(opt_g(*inps))) + self.assertEqual(tuple(g(*inps)), tuple(opt_g(*inps))) inps = gen_inps(3, 5) - self.assertEqual(g(*inps), opt_g(*inps)) + self.assertEqual(type(g(*inps)), type(opt_g(*inps))) + self.assertEqual(tuple(g(*inps)), tuple(opt_g(*inps))) def test_staticmethod_allow_in_graph(self): class MyClass: diff --git a/torch/_dynamo/symbolic_convert.py b/torch/_dynamo/symbolic_convert.py index e605e57b6aa49c..ab92f82aa0f630 100644 --- a/torch/_dynamo/symbolic_convert.py +++ b/torch/_dynamo/symbolic_convert.py @@ -1663,8 +1663,8 @@ def CALL_FUNCTION_EX(self, inst): if not isinstance( argsvars, BaseListVariable - ) and argsvars.has_unpack_var_sequence(self): - argsvars = TupleVariable(argsvars.unpack_var_sequence(self)) + ) and argsvars.has_force_unpack_var_sequence(self): + argsvars = TupleVariable(argsvars.force_unpack_var_sequence(self)) # Unpack for cases like fn(**obj) where obj is a map if isinstance(kwargsvars, UserDefinedObjectVariable): @@ -1833,7 +1833,7 @@ def BUILD_LIST_UNPACK(self, inst, cls=ListVariable): items = [] for seq in seqs: try: - items.extend(seq.unpack_var_sequence(self)) + items.extend(seq.force_unpack_var_sequence(self)) except NotImplementedError: unimplemented(f"BUILD_LIST_UNPACK {seq}") self.push(cls(items, mutable_local=MutableLocal())) @@ -1871,7 +1871,7 @@ def BUILD_CONST_KEY_MAP(self, inst): assert isinstance(keys, TupleVariable) assert keys.is_python_constant() - keys = keys.unpack_var_sequence(self) + keys = keys.force_unpack_var_sequence(self) assert len(keys) == len(values) self.push( @@ -1961,8 +1961,8 @@ def UNPACK_SEQUENCE(self, inst): # x, y = a.shape proxy = getattr(seq.obj.as_proxy(), seq.name) val = [wrap_fx_proxy(self, proxy[i]) for i in range(inst.argval)] - elif seq.has_unpack_var_sequence(self): - val = seq.unpack_var_sequence(self) + elif seq.has_force_unpack_var_sequence(self): + val = seq.force_unpack_var_sequence(self) else: unimplemented(f"UNPACK_SEQUENCE {seq}") if len(val) != inst.argval: @@ -1975,8 +1975,8 @@ def UNPACK_EX(self, inst): prefix = inst.argval & 0xFF # low byte suffix = inst.argval >> 8 # high byte seq = self.pop() - if seq.has_unpack_var_sequence(self): - vals = list(seq.unpack_var_sequence(self)) + if seq.has_force_unpack_var_sequence(self): + vals = list(seq.force_unpack_var_sequence(self)) assert len(vals) >= prefix + suffix vals_prefix = vals[:prefix] vals_list = vals[prefix : len(vals) - suffix] @@ -2400,7 +2400,7 @@ def CALL_INTRINSIC_1(self, inst): self.UNARY_POSITIVE(inst) elif inst.argval == 6: # INTRINSIC_LIST_TO_TUPLE - self.push(TupleVariable(self.pop().unpack_var_sequence(self))) + self.push(TupleVariable(self.pop().force_unpack_var_sequence(self))) else: unimplemented(f"missing CALL_INTRINSIC_1 operand {inst.argval}") diff --git a/torch/_dynamo/utils.py b/torch/_dynamo/utils.py index 627ada1ff88a1d..0ec548d788f868 100644 --- a/torch/_dynamo/utils.py +++ b/torch/_dynamo/utils.py @@ -1608,8 +1608,12 @@ def same( """Check correctness to see if ref and res match""" if fp64_ref is None: fp64_ref = ref - if isinstance(ref, (list, tuple, torch.nn.ParameterList, torch.Size)): - assert isinstance(res, (list, tuple)), f"type mismatch {type(ref)} {type(res)}" + if isinstance( + ref, (list, tuple, collections.deque, torch.nn.ParameterList, torch.Size) + ): + assert isinstance( + res, (list, tuple, collections.deque) + ), f"type mismatch {type(ref)} {type(res)}" if len(ref) != len(res): log_error("Length mismatch") return False diff --git a/torch/_dynamo/variables/__init__.py b/torch/_dynamo/variables/__init__.py index 021233cd6e235f..d522d773e6e865 100644 --- a/torch/_dynamo/variables/__init__.py +++ b/torch/_dynamo/variables/__init__.py @@ -46,7 +46,9 @@ CycleIteratorVariable, IteratorVariable, ItertoolsVariable, + MapVariable, RepeatIteratorVariable, + ZipVariable, ) from .lazy import LazyVariableTracker from .lists import ( diff --git a/torch/_dynamo/variables/base.py b/torch/_dynamo/variables/base.py index 25b4ffad99e7b3..723c5a90c66ac6 100644 --- a/torch/_dynamo/variables/base.py +++ b/torch/_dynamo/variables/base.py @@ -289,6 +289,15 @@ def can_reconstruct(self, tx): def unpack_var_sequence(self, tx) -> List["VariableTracker"]: raise NotImplementedError + def force_unpack_var_sequence(self, tx) -> List["VariableTracker"]: + # like unpack_var_sequence, but should only be used when it is + # safe to eagerly (vs. lazily) unpack this variable. + # e.g. map(f, x) is normally evaluated lazily but sometimes + # we want to force eager unpacking, e.g. when converting to a list. + # NOTE: this method is allowed to mutate the VariableTracker, so + # it should only be called once. + return self.unpack_var_sequence(tx) + def has_unpack_var_sequence(self, tx) -> bool: try: self.unpack_var_sequence(tx) @@ -296,6 +305,10 @@ def has_unpack_var_sequence(self, tx) -> bool: except NotImplementedError: return False + # NB: don't call force_unpack_var_sequence, especially if it mutates! + def has_force_unpack_var_sequence(self, tx) -> bool: + return self.has_unpack_var_sequence(tx) + def inspect_parameter_names(self) -> List[str]: unimplemented(f"inspect_parameter_names: {self}") diff --git a/torch/_dynamo/variables/builtin.py b/torch/_dynamo/variables/builtin.py index 6ffd820c3f1320..b6ff05e429d12c 100644 --- a/torch/_dynamo/variables/builtin.py +++ b/torch/_dynamo/variables/builtin.py @@ -1058,9 +1058,8 @@ def call_str(self, tx: "InstructionTranslator", arg): return tx.inline_user_function_return(user_func_variable, [arg], {}) def _call_min_max(self, tx: "InstructionTranslator", *args): - if len(args) == 1 and args[0].has_unpack_var_sequence(tx): - # expand iterable - items = args[0].unpack_var_sequence(tx) + if len(args) == 1 and args[0].has_force_unpack_var_sequence(tx): + items = args[0].force_unpack_var_sequence(tx) return self._call_min_max_seq(tx, items) elif len(args) == 2: return self._call_min_max_binary(tx, args[0], args[1]) @@ -1075,6 +1074,10 @@ def _call_min_max_seq(self, tx: "InstructionTranslator", items): return functools.reduce(functools.partial(self._call_min_max_binary, tx), items) def _call_min_max_binary(self, tx: "InstructionTranslator", a, b): + if a is None or b is None: + # a or b could be none if we reduce and _call_min_max_binary failed + # to return something + return if self.tensor_args(a, b): if not isinstance(a, variables.TensorVariable): a, b = b, a @@ -1223,17 +1226,15 @@ def _dyn_proxy(self, tx: "InstructionTranslator", *args, **kwargs): ), ) + # NOTE must handle IteratorVariable separately! def _call_iter_tuple_list( self, tx: "InstructionTranslator", obj=None, *args, **kwargs ): + assert not isinstance(obj, variables.IteratorVariable) + if self._dynamic_args(*args, **kwargs): return self._dyn_proxy(tx, *args, **kwargs) - if isinstance(obj, variables.IteratorVariable): - # For non-list iterators, we will guard on vars that - # determine the control flow - return obj - cls = variables.BaseListVariable.cls_for(self.fn) if obj is None: return cls( @@ -1261,9 +1262,22 @@ def _call_iter_tuple_list( mutable_local=MutableLocal(), ) + def _call_tuple_list(self, tx, obj=None, *args, **kwargs): + if isinstance(obj, variables.IteratorVariable): + cls = variables.BaseListVariable.cls_for(self.fn) + return cls( + list(obj.force_unpack_var_sequence(tx)), + mutable_local=MutableLocal(), + ) + else: + return self._call_iter_tuple_list(tx, obj, *args, **kwargs) + def call_iter(self, tx: "InstructionTranslator", obj, *args, **kwargs): - # Handle the case where we are iterating over a tuple, list or iterator - ret = self._call_iter_tuple_list(tx, obj, *args, **kwargs) + if isinstance(obj, variables.IteratorVariable): + ret = obj + else: + # Handle the case where we are iterating over a tuple, list or iterator + ret = self._call_iter_tuple_list(tx, obj, *args, **kwargs) if ret is None: # If the object doesn't implement a __iter__ method, it will be an error in eager mode when calling iter on it anyway. @@ -1272,8 +1286,8 @@ def call_iter(self, tx: "InstructionTranslator", obj, *args, **kwargs): return obj.call_method(tx, "__iter__", args, kwargs) return ret - call_tuple = _call_iter_tuple_list - call_list = _call_iter_tuple_list + call_tuple = _call_tuple_list + call_list = _call_tuple_list def call_callable(self, tx: "InstructionTranslator", arg): from .functions import BaseUserFunctionVariable @@ -1331,10 +1345,12 @@ def call_custom_dict(tx: "InstructionTranslator", user_cls, *args, **kwargs): ListVariable, TupleVariable, ListIteratorVariable, + variables.IteratorVariable, ), ): items = dict( - x.unpack_var_sequence(tx) for x in arg.unpack_var_sequence(tx) + x.force_unpack_var_sequence(tx) + for x in arg.force_unpack_var_sequence(tx) ) return ConstDictVariable(items, user_cls, mutable_local=MutableLocal()) elif isinstance(arg, variables.MutableMappingVariable): @@ -1391,13 +1407,12 @@ def call_custom_dict_fromkeys( return DictVariableType( dict.fromkeys(arg, value), user_cls, mutable_local=MutableLocal() ) - elif arg.has_unpack_var_sequence(tx) and all( - is_hashable(v) for v in arg.unpack_var_sequence(tx) - ): - keys = arg.unpack_var_sequence(tx) - return DictVariableType( - dict.fromkeys(keys, value), user_cls, mutable_local=MutableLocal() - ) + elif arg.has_force_unpack_var_sequence(tx): + keys = arg.force_unpack_var_sequence(tx) + if all(is_hashable(v) for v in keys): + return DictVariableType( + dict.fromkeys(keys, value), user_cls, mutable_local=MutableLocal() + ) unimplemented(f"{user_cls.__name__}.fromkeys(): {args} {kwargs}") def call_set(self, tx: "InstructionTranslator", *args, **kwargs): @@ -1409,8 +1424,8 @@ def call_set(self, tx: "InstructionTranslator", *args, **kwargs): arg = args[0] if isinstance(arg, variables.SetVariable): return arg.clone(mutable_local=MutableLocal()) - elif arg.has_unpack_var_sequence(tx): - items = arg.unpack_var_sequence(tx) + elif arg.has_force_unpack_var_sequence(tx): + items = arg.force_unpack_var_sequence(tx) return SetVariable(items, mutable_local=MutableLocal()) elif isinstance(arg, variables.UserDefinedObjectVariable) and isinstance( arg.value, KeysView @@ -1443,16 +1458,12 @@ def call_frozenset(self, tx: "InstructionTranslator", *args, **kwargs): def call_zip(self, tx: "InstructionTranslator", *args, **kwargs): if kwargs: assert len(kwargs) == 1 and "strict" in kwargs - if all(x.has_unpack_var_sequence(tx) for x in args): - unpacked = [arg.unpack_var_sequence(tx) for arg in args] - if kwargs.pop("strict", False) and len(unpacked) > 0: - if not all(len(u) == len(unpacked[0]) for u in unpacked): - raise UserError( - ValueError, - "zip() has one argument of len differing from others", - ) - items = [variables.TupleVariable(list(item)) for item in zip(*unpacked)] - return variables.TupleVariable(items) + strict = kwargs.pop("strict", False) + args = [ + arg.unpack_var_sequence(tx) if arg.has_unpack_var_sequence(tx) else arg + for arg in args + ] + return variables.ZipVariable(args, strict=strict, mutable_local=MutableLocal()) def call_len(self, tx: "InstructionTranslator", *args, **kwargs): return args[0].call_method(tx, "__len__", args[1:], kwargs) @@ -1553,10 +1564,11 @@ def call_hasattr(self, tx: "InstructionTranslator", obj, attr): return obj.call_hasattr(tx, name) def call_map(self, tx: "InstructionTranslator", fn, *seqs): - if all(seq.has_unpack_var_sequence(tx) for seq in seqs): - unpacked = [seq.unpack_var_sequence(tx) for seq in seqs] - items = [fn.call_function(tx, list(args), {}) for args in zip(*unpacked)] - return variables.TupleVariable(items) + seqs = [ + seq.unpack_var_sequence(tx) if seq.has_unpack_var_sequence(tx) else seq + for seq in seqs + ] + return variables.MapVariable(fn, seqs, mutable_local=MutableLocal()) def call_filter(self, tx: "InstructionTranslator", fn, seq): if seq.has_unpack_var_sequence(tx): @@ -1589,10 +1601,10 @@ def call_sum(self, tx: "InstructionTranslator", seq, start=_SENTINEL): return variables.ConstantVariable.create( sum((x.value for x in seq.items), start=start.value), ) - if seq.has_unpack_var_sequence(tx): + if seq.has_force_unpack_var_sequence(tx): if start is self._SENTINEL: start = variables.ConstantVariable.create(0) - items = seq.unpack_var_sequence(tx) + items = seq.force_unpack_var_sequence(tx) return BuiltinVariable(functools.reduce).call_function( tx, [ @@ -1606,8 +1618,8 @@ def call_sum(self, tx: "InstructionTranslator", seq, start=_SENTINEL): def call_reduce( self, tx: "InstructionTranslator", function, iterable, initial=_SENTINEL ): - if iterable.has_unpack_var_sequence(tx): - items = iterable.unpack_var_sequence(tx) + if iterable.has_force_unpack_var_sequence(tx): + items = iterable.force_unpack_var_sequence(tx) if initial is self._SENTINEL: value, items = items[0], items[1:] else: @@ -1903,11 +1915,12 @@ def call_reversed(self, tx: "InstructionTranslator", obj: VariableTracker): return variables.TupleVariable(items) def call_sorted(self, tx: "InstructionTranslator", obj: VariableTracker, **kwargs): - if ( - obj.has_unpack_var_sequence(tx) - and not isinstance(obj, variables.TensorVariable) - and all(x.is_python_constant() for x in obj.unpack_var_sequence(tx)) + if obj.has_force_unpack_var_sequence(tx) and not isinstance( + obj, variables.TensorVariable ): + unpacked = obj.force_unpack_var_sequence(tx) + if not all(x.is_python_constant() for x in unpacked): + return function = kwargs.pop("key", None) reverse = kwargs.pop( "reverse", ConstantVariable.create(False) @@ -1915,7 +1928,7 @@ def call_sorted(self, tx: "InstructionTranslator", obj: VariableTracker, **kwarg assert len(kwargs) == 0 if function: items = sorted( - obj.unpack_var_sequence(tx), + unpacked, key=lambda x: function.call_function( tx, [x], {} ).as_python_constant(), @@ -1923,7 +1936,7 @@ def call_sorted(self, tx: "InstructionTranslator", obj: VariableTracker, **kwarg ) else: items = sorted( - obj.unpack_var_sequence(tx), + unpacked, key=lambda x: x.as_python_constant(), reverse=reverse, ) diff --git a/torch/_dynamo/variables/constant.py b/torch/_dynamo/variables/constant.py index 0028ecd81decda..65de0aab6ef3c7 100644 --- a/torch/_dynamo/variables/constant.py +++ b/torch/_dynamo/variables/constant.py @@ -145,6 +145,14 @@ def call_method( return variables.BuiltinVariable(str.format).call_function( tx, [self, *args], kwargs ) + elif name == "join" and istype(self.value, str): + assert len(args) == 1 and len(kwargs) == 0 + arg_unpacked = args[0].force_unpack_var_sequence(tx) + try: + arg_const = [x.as_python_constant() for x in arg_unpacked] + return ConstantVariable.create(self.value.join(arg_const)) + except NotImplementedError: + return super().call_method(tx, name, args, kwargs) if any(isinstance(x, SymNodeVariable) for x in args): # Promote to SymNodeVariable for operations involving dynamic shapes. diff --git a/torch/_dynamo/variables/dicts.py b/torch/_dynamo/variables/dicts.py index 6c69364ca57ee9..10b55c1b7e96eb 100644 --- a/torch/_dynamo/variables/dicts.py +++ b/torch/_dynamo/variables/dicts.py @@ -314,6 +314,7 @@ def call_method( ListVariable, TupleVariable, ListIteratorVariable, + variables.IteratorVariable, UserDefinedObjectVariable, ), ) diff --git a/torch/_dynamo/variables/iter.py b/torch/_dynamo/variables/iter.py index 192f25a0c6b677..1f8dac8811f53a 100644 --- a/torch/_dynamo/variables/iter.py +++ b/torch/_dynamo/variables/iter.py @@ -2,14 +2,17 @@ import itertools import operator -from typing import Dict, List, Optional, TYPE_CHECKING +import sys +from typing import Dict, List, Optional, TYPE_CHECKING, Union from .. import polyfills, variables +from ..bytecode_transformation import create_call_function, create_instruction from ..exc import ( handle_observed_exception, ObservedUserStopIteration, raise_observed_exception, unimplemented, + UserError, ) from .base import MutableLocal, VariableTracker from .constant import ConstantVariable @@ -197,6 +200,25 @@ def __init__(self, **kwargs) -> None: def next_variable(self, tx): unimplemented("abstract method, must implement") + # NOTE: only call when unpacking this iterator safely done eagerly! + # Normally, iterators are accessed lazily. + # Example of safe eager unpacking: list(map(f, seq)) + # Example of unsafe eager unpacking: list(islice(map(f, seq), 5)) + def force_unpack_var_sequence(self, tx) -> List[VariableTracker]: + result = [] + while True: + try: + result.append(self.next_variable(tx)) + except ObservedUserStopIteration: + handle_observed_exception(tx) + break + return result + + # don't call force_unpack_var_sequence since it can mutate + # IteratorVariable state! + def has_force_unpack_var_sequence(self, tx) -> bool: + return True + class RepeatIteratorVariable(IteratorVariable): def __init__(self, item: VariableTracker, **kwargs) -> None: @@ -207,6 +229,18 @@ def __init__(self, item: VariableTracker, **kwargs) -> None: def next_variable(self, tx): return self.item + def reconstruct(self, codegen): + codegen.add_push_null( + lambda: codegen.extend_output( + [ + codegen.create_load_python_module(itertools), + codegen.create_load_attr("repeat"), + ] + ) + ) + codegen(self.item) + codegen.extend_output(create_call_function(1, False)) + class CountIteratorVariable(IteratorVariable): def __init__(self, item: int = 0, step: int = 1, **kwargs) -> None: @@ -220,10 +254,23 @@ def __init__(self, item: int = 0, step: int = 1, **kwargs) -> None: def next_variable(self, tx): assert self.mutable_local + old_item = self.item tx.output.side_effects.mutation(self) - next_item = self.item.call_method(tx, "__add__", [self.step], {}) - self.item = next_item - return self.item + self.item = self.item.call_method(tx, "__add__", [self.step], {}) + return old_item + + def reconstruct(self, codegen): + codegen.add_push_null( + lambda: codegen.extend_output( + [ + codegen.create_load_python_module(itertools), + codegen.create_load_attr("count"), + ] + ) + ) + codegen(self.item) + codegen(self.step) + codegen.extend_output(create_call_function(2, False)) class CycleIteratorVariable(IteratorVariable): @@ -269,3 +316,160 @@ def next_variable(self, tx): return self.item else: raise_observed_exception(StopIteration, tx, self) + + +class ZipVariable(IteratorVariable): + """ + Represents zip(*iterables) + """ + + _nonvar_fields = { + "index", + "strict", + *IteratorVariable._nonvar_fields, + } + + def __init__( + self, + iterables: List[Union[List[VariableTracker], VariableTracker]], + strict: bool = False, + **kwargs, + ) -> None: + super().__init__(**kwargs) + assert isinstance(iterables, list) + # can be list[Variable] or VariableTracker (with next_variable implemented) + self.iterables = iterables + self.index = 0 + self.strict = strict + + def python_type(self): + return zip + + def has_unpack_var_sequence(self, tx) -> bool: + return all( + isinstance(it, list) or it.has_unpack_var_sequence(tx) + for it in self.iterables + ) + + def unpack_var_sequence(self, tx) -> List["VariableTracker"]: + assert self.has_unpack_var_sequence(tx) + iterables = [] + for it in self.iterables: + if isinstance(it, list): + iterables.append(it[self.index :]) + else: + iterables.append(it.unpack_var_sequence(tx)) + kwargs = {"strict": self.strict} if self.strict else {} + zipped = zip(*iterables, **kwargs) + return [variables.TupleVariable(list(var)) for var in zipped] + + def next_variable(self, tx): + assert self.mutable_local + old_index = self.index + args = [] + + def get_item(it): + if isinstance(it, list): + if old_index >= len(it): + raise_observed_exception(StopIteration, tx, self) + return it[old_index] + else: + return it.next_variable(tx) + + try: + for idx, it in enumerate(self.iterables): + args.append(get_item(it)) + except ObservedUserStopIteration: + if self.strict: + if idx == 0: + # all other iterables should be exhausted + for it in self.iterables: + try: + get_item(it) + except ObservedUserStopIteration: + handle_observed_exception(tx) + continue + # no ObservedUserStopIteration - fall through to UserError + break + else: + # all iterables exhausted, raise original error + raise + handle_observed_exception(tx) + raise UserError( + ValueError, + "zip() has one argument of len differing from others", + ) from None + raise + + tx.output.side_effects.mutation(self) + self.index += 1 + return variables.TupleVariable(args) + + def reconstruct_items(self, codegen): + for it in self.iterables: + if isinstance(it, list): + remaining_items = it[self.index :] + codegen.foreach(remaining_items) + codegen.append_output( + create_instruction("BUILD_TUPLE", arg=len(remaining_items)) + ) + else: + codegen(it) + + def reconstruct(self, codegen): + codegen.add_push_null( + lambda: codegen.load_import_from("builtins", "zip"), call_function_ex=True + ) + self.reconstruct_items(codegen) + codegen.append_output( + create_instruction("BUILD_TUPLE", arg=len(self.iterables)) + ) + if sys.version_info >= (3, 10): + codegen.extend_output( + [ + codegen.create_load_const("strict"), + codegen.create_load_const(self.strict), + create_instruction("BUILD_MAP", arg=1), + create_instruction("CALL_FUNCTION_EX", arg=1), + ] + ) + else: + codegen.append_output(create_instruction("CALL_FUNCTION_EX", arg=0)) + + +class MapVariable(ZipVariable): + """ + Represents map(fn, *iterables) + """ + + def __init__( + self, + fn: VariableTracker, + iterables: List[Union[List[VariableTracker], VariableTracker]], + **kwargs, + ) -> None: + super().__init__(iterables, **kwargs) + self.fn = fn + + def python_type(self): + return map + + def has_unpack_var_sequence(self, tx) -> bool: + return False + + def next_variable(self, tx): + args = super().next_variable(tx) + return self.fn.call_function(tx, args.items, {}) + + def reconstruct(self, codegen): + codegen.add_push_null( + lambda: codegen.load_import_from("builtins", "map"), call_function_ex=True + ) + codegen(self.fn) + self.reconstruct_items(codegen) + codegen.extend_output( + [ + create_instruction("BUILD_TUPLE", arg=len(self.iterables) + 1), + create_instruction("CALL_FUNCTION_EX", arg=0), + ] + ) diff --git a/torch/_dynamo/variables/lists.py b/torch/_dynamo/variables/lists.py index 1fed456a8f7d62..30916e0b699698 100644 --- a/torch/_dynamo/variables/lists.py +++ b/torch/_dynamo/variables/lists.py @@ -29,6 +29,7 @@ from .base import MutableLocal, VariableTracker from .constant import ConstantVariable from .functions import UserFunctionVariable, UserMethodVariable +from .iter import IteratorVariable if TYPE_CHECKING: @@ -334,11 +335,11 @@ def call_method( name == "extend" and self.mutable_local and args - and args[0].has_unpack_var_sequence(tx) + and args[0].has_force_unpack_var_sequence(tx) ): assert not kwargs (arg,) = args - seq = arg.unpack_var_sequence(tx) + seq = arg.force_unpack_var_sequence(tx) tx.output.side_effects.mutation(self) self.items.extend(seq) return ConstantVariable.create(None) @@ -422,11 +423,13 @@ def call_method( key, value = args tx.output.side_effects.mutation(self) if isinstance(key, SliceVariable): - if not value.has_unpack_var_sequence(tx): + if not value.has_force_unpack_var_sequence(tx): unimplemented( f"Missing dynamo support for expanding {value} into a list for slice assignment." ) - self.items[key.as_python_constant()] = value.unpack_var_sequence(tx) + self.items[key.as_python_constant()] = value.force_unpack_var_sequence( + tx + ) else: self.items[key.as_python_constant()] = value return ConstantVariable.create(None) @@ -464,7 +467,12 @@ def reconstruct(self, codegen): ) ) codegen.foreach(self.items) - codegen.extend_output(create_call_function(len(self.items), False)) + codegen.extend_output( + [ + create_instruction("BUILD_LIST", arg=len(self.items)), + *create_call_function(1, False), + ] + ) def call_method( self, @@ -487,11 +495,15 @@ def call_method( tx.output.side_effects.mutation(self) self.items[key.as_python_constant()] = value return ConstantVariable.create(None) - elif name == "extendleft" and self.mutable_local: + elif ( + name == "extendleft" + and self.mutable_local + and args[0].has_force_unpack_var_sequence(tx) + ): assert not kwargs (arg,) = args - prefix = arg.unpack_var_sequence(tx) + prefix = arg.force_unpack_var_sequence(tx) prefix.reverse() tx.output.side_effects.mutation(self) self.items = prefix + list(self.items) @@ -802,10 +814,10 @@ def var_getattr(self, tx: "InstructionTranslator", name): return self.items[fields.index(name)] -class ListIteratorVariable(VariableTracker): +class ListIteratorVariable(IteratorVariable): _nonvar_fields = { "index", - *VariableTracker._nonvar_fields, + *IteratorVariable._nonvar_fields, } def __init__(self, items, index: int = 0, **kwargs) -> None: @@ -856,6 +868,9 @@ def as_python_constant(self): def unpack_var_sequence(self, tx): return list(self.items[self.index :]) + def force_unpack_var_sequence(self, tx) -> List[VariableTracker]: + return self.unpack_var_sequence(tx) + def reconstruct(self, codegen): remaining_items = self.items[self.index :] codegen.foreach(remaining_items) diff --git a/torch/_dynamo/variables/user_defined.py b/torch/_dynamo/variables/user_defined.py index 0d53d6d21f5502..34e1d0d10c9f72 100644 --- a/torch/_dynamo/variables/user_defined.py +++ b/torch/_dynamo/variables/user_defined.py @@ -379,8 +379,8 @@ def call_function( elif self.value is collections.deque and not kwargs: if len(args) == 0: items = [] - elif len(args) == 1 and args[0].has_unpack_var_sequence(tx): - items = args[0].unpack_var_sequence(tx) + elif len(args) == 1 and args[0].has_force_unpack_var_sequence(tx): + items = args[0].force_unpack_var_sequence(tx) else: unimplemented("deque() with more than 1 arg not supported") return variables.lists.DequeVariable(items, mutable_local=MutableLocal()) @@ -749,7 +749,7 @@ def call_method( assert not (args or kwargs) items = [] keys = self.call_method(tx, "keys", [], {}) - for key in keys.unpack_var_sequence(tx): + for key in keys.force_unpack_var_sequence(tx): items.append( TupleVariable( [key, self.odict_getitem(tx, key)], From ce28d4c0e7427554c50434b9d4f60056100a2aaf Mon Sep 17 00:00:00 2001 From: Vadym Khortiuk Date: Fri, 6 Sep 2024 20:51:55 +0000 Subject: [PATCH 0572/1018] Revert expectFailureIf condition on tests with torch.compile on Windows (#134759) Fixes #134716 This PR reverts some changes introduced in https://github.com/pytorch/pytorch/commit/6eae5695466a9d698de303f82b7cdced63c30ea3 (#133987) torch.compile is not available on Windows, tests should be expected to fail. Pull Request resolved: https://github.com/pytorch/pytorch/pull/134759 Approved by: https://github.com/malfet --- test/functorch/test_eager_transforms.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/test/functorch/test_eager_transforms.py b/test/functorch/test_eager_transforms.py index ba046be3d99324..a1bd52a2fbb808 100644 --- a/test/functorch/test_eager_transforms.py +++ b/test/functorch/test_eager_transforms.py @@ -5141,8 +5141,9 @@ def wrapper(*args, **kwargs): @markDynamoStrictTest class TestCompileTransforms(TestCase): @skipIfRocm(msg="test leaks memory on ROCm") + # torch.compile is not supported on Windows CUDA. # Triton only supports GPU with SM70 or later. - @expectedFailureIf(TEST_CUDA and not SM70OrLater) + @expectedFailureIf((IS_WINDOWS and TEST_CUDA) or (TEST_CUDA and not SM70OrLater)) def test_compile_vmap_hessian(self, device): # The model and inputs are a smaller version # of code at benchmark repo: From 07673619583f6562e9ea4804180ecabdd1be7bd0 Mon Sep 17 00:00:00 2001 From: Catherine Lee Date: Fri, 6 Sep 2024 21:02:01 +0000 Subject: [PATCH 0573/1018] [ez][TD] Fix request for issue body returns None (#135389) I assumed it would be empty string if the body is empty, but its just None Pull Request resolved: https://github.com/pytorch/pytorch/pull/135389 Approved by: https://github.com/malfet --- tools/testing/target_determination/heuristics/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tools/testing/target_determination/heuristics/utils.py b/tools/testing/target_determination/heuristics/utils.py index 17259756533d5c..7c408277559b1e 100644 --- a/tools/testing/target_determination/heuristics/utils.py +++ b/tools/testing/target_determination/heuristics/utils.py @@ -116,7 +116,7 @@ def get_issue_or_pr_body(number: int) -> str: # Despite the 'issues' in the link, this also works for PRs url = f"https://api.github.com/repos/pytorch/pytorch/issues/{number}" with urlopen(Request(url, headers=headers)) as conn: - body: str = json.loads(conn.read().decode())["body"] + body: str = json.loads(conn.read().decode())["body"] or "" return body From e51d942a8dca4203490717406ec96793eab4a8cd Mon Sep 17 00:00:00 2001 From: William Wen Date: Fri, 6 Sep 2024 11:15:00 -0700 Subject: [PATCH 0574/1018] [dynamo] recursively skip frames when Dynamo cache limit is hit (#135144) Fixes https://github.com/pytorch/pytorch/pull/135144 and [T197117723](https://www.internalfb.com/intern/tasks/?t=197117723). In general, adds `SkipCodeRecursiveException` to Dynamo - when raised in Dynamo, convert_frame will return a `skip_code_recursive_flag` back to C Dynamo, signaling it to skip the current frame and all recursive calls. Pull Request resolved: https://github.com/pytorch/pytorch/pull/135144 Approved by: https://github.com/jansel, https://github.com/anijain2305 --- test/dynamo/test_modules.py | 5 ++--- test/dynamo/test_recompiles.py | 19 +++++++++++++++++++ torch/_C/_dynamo/eval_frame.pyi | 5 +++++ torch/_dynamo/config.py | 3 +++ torch/_dynamo/convert_frame.py | 21 +++++++++++++++++++-- torch/_dynamo/exc.py | 8 ++++++++ torch/csrc/dynamo/eval_frame.c | 29 +++++++++++++++++++++++++++-- torch/csrc/dynamo/extra_state.cpp | 19 ++++++++++++------- torch/csrc/dynamo/extra_state.h | 4 +++- 9 files changed, 98 insertions(+), 15 deletions(-) diff --git a/test/dynamo/test_modules.py b/test/dynamo/test_modules.py index d390e0bb2d857b..329b04fd7d810c 100644 --- a/test/dynamo/test_modules.py +++ b/test/dynamo/test_modules.py @@ -1994,9 +1994,8 @@ def fn(x, mod): mod = Mod() opt_mod(torch.randn(5, 5), mod) - # fn compiles twice, and forward twice - # (since forward is inlined when fn is compiled) - self.assertEqual(cnts.frame_count, 4) + # fn compiles twice + self.assertEqual(cnts.frame_count, 2) @patch.object(torch._dynamo.config, "inline_inbuilt_nn_modules", True) def test_inline_inbuilt_nn_modules(self): diff --git a/test/dynamo/test_recompiles.py b/test/dynamo/test_recompiles.py index bab95c799a1263..f0cba5132cf3a5 100644 --- a/test/dynamo/test_recompiles.py +++ b/test/dynamo/test_recompiles.py @@ -315,6 +315,25 @@ def forward(self, x): model(x) self.assertEqual(counter.frame_count, 2) + @patch.object(torch._dynamo.config, "cache_size_limit", 2) + def test_no_recursive_compile_after_cache_limit_hit(self): + def f(x, n): + x = x + n + return g(x, n) + + def g(x, n): + x = x + n + return h(x, n) + + def h(x, n): + return x + n + + counter = torch._dynamo.testing.CompileCounter() + opt_f = torch.compile(f, backend=counter, dynamic=False) + for i in range(10): + opt_f(torch.ones(3), i) + self.assertEqual(counter.frame_count, 2) + if __name__ == "__main__": from torch._dynamo.test_case import run_tests diff --git a/torch/_C/_dynamo/eval_frame.pyi b/torch/_C/_dynamo/eval_frame.pyi index 6c4e4be3be6378..548fc1f59e0ff1 100644 --- a/torch/_C/_dynamo/eval_frame.pyi +++ b/torch/_C/_dynamo/eval_frame.pyi @@ -8,6 +8,11 @@ from torch._dynamo.types import DynamoCallback, DynamoGuardHook # exposes the same interface. _PyInterpreterFrame = NewType("_PyInterpreterFrame", types.FrameType) +# For typechecking +SkipCodeRecursiveFlag = NewType("SkipCodeRecursiveFlag", object) +# Flag returned by Dynamo tracer to indicate to Dynamo eval frame that we should skip frames recursively. +skip_code_recursive_flag: SkipCodeRecursiveFlag + def set_eval_frame(callback: DynamoCallback) -> DynamoCallback: ... def reset_code(code: types.CodeType) -> None: ... def unsupported(obj1: object, obj2: object) -> object: ... diff --git a/torch/_dynamo/config.py b/torch/_dynamo/config.py index 7cec9d153d69a5..2ba29961af36e9 100644 --- a/torch/_dynamo/config.py +++ b/torch/_dynamo/config.py @@ -48,6 +48,9 @@ def is_fbcode(): # [@compile_ignored: runtime_behaviour] safeguarding to prevent horrible recomps accumulated_cache_size_limit = 256 +# [@compile_ignored: runtime_behaviour] skip tracing recursively if cache limit is hit +skip_code_recursive_on_cache_limit_hit = True + # whether or not to specialize on int inputs. This only has an effect with # dynamic_shapes; when dynamic_shapes is False, we ALWAYS specialize on int # inputs. Note that assume_static_by_default will also cause ints to get diff --git a/torch/_dynamo/convert_frame.py b/torch/_dynamo/convert_frame.py index 070b6feacbb369..d15f5c4aa43dcf 100644 --- a/torch/_dynamo/convert_frame.py +++ b/torch/_dynamo/convert_frame.py @@ -33,6 +33,7 @@ from torch._logging import structured from torch._utils_internal import ( compile_time_strobelight_meta, + justknobs_check, maybe_upload_prof_stats_to_manifold, signpost_event, ) @@ -68,8 +69,10 @@ from .exc import ( augment_exc_message, BackendCompilerFailed, + CacheLimitExceeded, format_error_msg, InternalTorchDynamoError, + SkipCodeRecursiveException, TorchRuntimeError, UncapturedHigherOrderOpError, unimplemented, @@ -850,7 +853,13 @@ def format_guard_failures() -> str: format_guard_failures(), troubleshooting_url, ) - unimplemented(f"{limit_type} reached") + if config.skip_code_recursive_on_cache_limit_hit and justknobs_check( + "pytorch/compiler:skip_code_recursive_on_cache_limit_hit" + ): + raise CacheLimitExceeded(f"{limit_type} reached") + else: + # do not recursively skip frames + unimplemented(f"{limit_type} reached") log.debug( "torchdynamo start compiling %s %s:%s, stack (elided %s frames):\n%s", @@ -1047,7 +1056,9 @@ def __call__( hooks: Hooks, frame_state: Dict[str, Union[int, FrameStateSizeEntry]], skip: int = 0, - ) -> Optional[GuardedCode]: + ) -> Optional[ + Union[GuardedCode, torch._C._dynamo.eval_frame.SkipCodeRecursiveFlag] + ]: counters["frames"]["total"] += 1 try: result = self._inner_convert( @@ -1112,6 +1123,12 @@ def __call__( log.info(error_msg, exc_info=True) else: log.warning(error_msg, exc_info=True) + + # If we encounter SkipCodeRecursiveException, return skip_code_recursive_flag + # to signal to Dynamo eval frame to skip the current frame and any recursive calls. + if isinstance(e, SkipCodeRecursiveException): + return torch._C._dynamo.eval_frame.skip_code_recursive_flag + return None diff --git a/torch/_dynamo/exc.py b/torch/_dynamo/exc.py index c281a9d68c4868..0d2108ada9e10e 100644 --- a/torch/_dynamo/exc.py +++ b/torch/_dynamo/exc.py @@ -168,6 +168,14 @@ def __init__(self, error_type: UserErrorType, msg, case_name=None) -> None: self.message = msg +class SkipCodeRecursiveException(TorchDynamoException): + pass + + +class CacheLimitExceeded(SkipCodeRecursiveException, Unsupported): + pass + + class UnsafeScriptObjectError(TorchDynamoException): pass diff --git a/torch/csrc/dynamo/eval_frame.c b/torch/csrc/dynamo/eval_frame.c index 1803a0e57341cb..181acbdf4946a6 100644 --- a/torch/csrc/dynamo/eval_frame.c +++ b/torch/csrc/dynamo/eval_frame.c @@ -136,7 +136,7 @@ static struct PyGetSetDef THPPyInterpreterFrame_properties[] = { static PyTypeObject THPPyInterpreterFrameType = { PyVarObject_HEAD_INIT(NULL, 0) - .tp_name = "torch._C.dynamo.eval_frame._PyInterpreterFrame", + .tp_name = "torch._C._dynamo.eval_frame._PyInterpreterFrame", .tp_basicsize = sizeof(THPPyInterpreterFrame), .tp_flags = Py_TPFLAGS_DEFAULT, .tp_getset = THPPyInterpreterFrame_properties, @@ -521,6 +521,8 @@ static PyObject* _custom_eval_frame_shim( return result; } +static PyObject* skip_code_recursive_flag; + // NOTE: In 3.12+, the frame evaluation function (callee) is responsible for clearing/popping // the frame, meaning that unless we default evaluate the original frame, // we are responsible for clearing it - via clear_old_frame_if_python_312_plus. @@ -580,6 +582,13 @@ static PyObject* _custom_eval_frame( DEBUG_TRACE("skip %s", get_frame_name(frame)); return eval_frame_default(tstate, frame, throw_flag); } + if (extra == SKIP_CODE_RECURSIVE) { + DEBUG_TRACE("skip recursive %s", get_frame_name(frame)); + eval_frame_callback_set(Py_None); + PyObject* result = eval_frame_default(tstate, frame, throw_flag); + eval_frame_callback_set(callback); + return result; + } if (extra == NULL) { extra = init_and_set_extra_state(F_CODE(frame)); @@ -668,6 +677,14 @@ static PyObject* _custom_eval_frame( // inside the torch.compile block we won't try to Dynamo anything else. *should_clear_frame = 1; return NULL; + } else if (result == skip_code_recursive_flag) { + // Dynamo returned skip_code_recursive_flag, so we should recursively skip code. + DEBUG_TRACE("create skip recursive %s", get_frame_name(frame)); + set_extra_state(F_CODE(frame), SKIP_CODE_RECURSIVE); + PyObject* r = eval_frame_default(tstate, frame, throw_flag); + // Re-enable custom behavior + eval_frame_callback_set(callback); + return r; } else if (result != Py_None) { DEBUG_TRACE("create cache %s", get_frame_name(frame)); @@ -712,7 +729,7 @@ static struct PyGetSetDef THPPyInterpreterFrame_properties[] = {NULL}; static PyTypeObject THPPyInterpreterFrameType = { PyVarObject_HEAD_INIT(NULL, 0) - .tp_name = "torch._C.dynamo.eval_frame._PyInterpreterFrame", + .tp_name = "torch._C._dynamo.eval_frame._PyInterpreterFrame", .tp_basicsize = sizeof(THPPyInterpreterFrame), .tp_flags = Py_TPFLAGS_DEFAULT, .tp_getset = THPPyInterpreterFrame_properties, @@ -883,5 +900,13 @@ PyObject* torch_c_dynamo_eval_frame_init(void) { } #endif + skip_code_recursive_flag = PyObject_New(PyObject, &PyBaseObject_Type); + if (skip_code_recursive_flag == NULL) { + return NULL; + } + if (PyModule_AddObject(module, "skip_code_recursive_flag", skip_code_recursive_flag) != 0) { + return NULL; + } + return module; } diff --git a/torch/csrc/dynamo/extra_state.cpp b/torch/csrc/dynamo/extra_state.cpp index abc6951ab554c6..d661f57d2cf32b 100644 --- a/torch/csrc/dynamo/extra_state.cpp +++ b/torch/csrc/dynamo/extra_state.cpp @@ -38,15 +38,20 @@ void ExtraState::invalidate(CacheEntry* cache_entry) { this->cache_entry_list.erase(cache_entry->_owner_loc); } +static bool is_extra_state_unset(ExtraState* extra_state) { + return extra_state == nullptr || extra_state == SKIP_CODE || + extra_state == SKIP_CODE_RECURSIVE; +} + CacheEntry* extract_cache_entry(ExtraState* extra_state) { - if (extra_state == nullptr || extra_state == SKIP_CODE) { + if (is_extra_state_unset(extra_state)) { return nullptr; } return extra_state->get_first_entry(); } FrameState* extract_frame_state(ExtraState* extra_state) { - if (extra_state == nullptr || extra_state == SKIP_CODE) { + if (is_extra_state_unset(extra_state)) { return nullptr; } return (FrameState*)extra_state->frame_state.ptr(); @@ -60,16 +65,14 @@ ExtraState* get_extra_state(PyCodeObject* code) { void destroy_extra_state(void* obj) { ExtraState* extra = (ExtraState*)obj; - if (extra != nullptr && extra != SKIP_CODE) { + if (!is_extra_state_unset(extra)) { delete extra; } } void set_extra_state(PyCodeObject* code, ExtraState* extra_state) { ExtraState* old_extra_state = get_extra_state(code); - CHECK( - old_extra_state == nullptr || old_extra_state == SKIP_CODE || - old_extra_state != extra_state); + CHECK(is_extra_state_unset(extra_state) || old_extra_state != extra_state); _PyCode_SetExtra((PyObject*)code, extra_index, extra_state); } @@ -80,6 +83,8 @@ ExtraState* init_and_set_extra_state(PyCodeObject* code) { ExtraState* extra_state = new ExtraState(); NULL_CHECK(extra_state); set_extra_state(code, extra_state); + // freed by destroy_extra_state (since we need to pass these objects to C) + // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks) return extra_state; } @@ -173,7 +178,7 @@ py::list _debug_get_cache_entry_list(const py::handle& code_obj) { PyCodeObject* code = (PyCodeObject*)code_obj.ptr(); ExtraState* extra = get_extra_state(code); py::list result; - if (extra && extra != SKIP_CODE) { + if (!is_extra_state_unset(extra)) { for (CacheEntry& e : extra->cache_entry_list) { result.append(py::cast(e, py::return_value_policy::reference)); } diff --git a/torch/csrc/dynamo/extra_state.h b/torch/csrc/dynamo/extra_state.h index 22997d0e1a252c..1f6ccc7061a0c5 100644 --- a/torch/csrc/dynamo/extra_state.h +++ b/torch/csrc/dynamo/extra_state.h @@ -16,6 +16,8 @@ extern "C" { // Flag to just run a frame normally #define SKIP_CODE ((void*)0x1) +// Flag to run a frame and any recursive calls normally +#define SKIP_CODE_RECURSIVE ((void*)0x2) // Points to the extra scratch space on the code object extern Py_ssize_t extra_index; @@ -97,7 +99,7 @@ void destroy_extra_state(void* obj); // - there is no return, but the extra_state is stolen, so it becomes // set_extra_state responsibility to clean it up. It will be deleted during // the reset_code/skip, when the set_extra_state is called with -// NULL/SKIP_CODE. +// NULL/SKIP_CODE/SKIP_CODE_RECURSIVE. // Invariant - Dont set the extra state for the extra state that is already on // the code object. Otherwise, we will first free up the old extra state From 327531aff882c11db83438942d491da0d91f2743 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Fri, 6 Sep 2024 21:39:17 +0000 Subject: [PATCH 0575/1018] Revert "[ONNX] Refactor exporter errors (#135180)" This reverts commit 5eebd9315a72422d59b6f8d8ca8e4e573e231d5c. Reverted https://github.com/pytorch/pytorch/pull/135180 on behalf of https://github.com/clee2000 due to I think this broke test_public_bindings.py::TestPublicBindings::test_correct_module_names [GH job link](https://github.com/pytorch/pytorch/actions/runs/10743909338/job/29800779403) [HUD commit link](https://hud.pytorch.org/pytorch/pytorch/commit/5eebd9315a72422d59b6f8d8ca8e4e573e231d5c), possibly a landrace with the PR that landed before it ([comment](https://github.com/pytorch/pytorch/pull/135180#issuecomment-2334844191)) --- docs/source/onnx.rst | 5 +- docs/source/onnx_dynamo.rst | 3 ++ test/onnx/dynamo/test_exporter_api.py | 41 +++++++++++++++ test/onnx/onnx_test_common.py | 37 +++++++++---- test/onnx/test_fx_to_onnx.py | 14 +++++ test/onnx/test_pytorch_onnx_no_runtime.py | 27 ++++++++++ test/onnx/test_pytorch_onnx_onnxruntime.py | 2 +- torch/onnx/__init__.py | 23 +++++---- torch/onnx/_internal/_exporter_legacy.py | 60 ++++++++++++++++------ torch/onnx/_internal/exporter/_building.py | 3 +- torch/onnx/_internal/exporter/_core.py | 10 ++-- torch/onnx/_internal/exporter/errors.py | 30 +++++++++++ torch/onnx/errors.py | 48 +++++------------ torch/onnx/utils.py | 12 +++++ 14 files changed, 233 insertions(+), 82 deletions(-) create mode 100644 torch/onnx/_internal/exporter/errors.py diff --git a/docs/source/onnx.rst b/docs/source/onnx.rst index a07664aaf69473..ffaa8ef836b787 100644 --- a/docs/source/onnx.rst +++ b/docs/source/onnx.rst @@ -63,12 +63,9 @@ also be interested in reading our `development wiki Date: Thu, 5 Sep 2024 18:31:59 -0700 Subject: [PATCH 0576/1018] [FR] Make trace_dir a required argument (#135157) Ensures users get a clean error if they forget to specify the dir, and improves the help message. Pull Request resolved: https://github.com/pytorch/pytorch/pull/135157 Approved by: https://github.com/c-p-i-o, https://github.com/fduwjj --- tools/flight_recorder/components/config_manager.py | 4 ++-- tools/flight_recorder/fr_trace.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/tools/flight_recorder/components/config_manager.py b/tools/flight_recorder/components/config_manager.py index 1f4511d2979192..47de3ef3ee052d 100644 --- a/tools/flight_recorder/components/config_manager.py +++ b/tools/flight_recorder/components/config_manager.py @@ -17,9 +17,9 @@ def __init__(self: "JobConfig"): self.parser = argparse.ArgumentParser( description="PyTorch Flight recorder analyzing script." ) - self.parser.add_argument( - "-d", "--dir", required=True, help="Directory with flight recorder dumps" + "trace_dir", + help="Directory containing one trace file per rank, named with _.", ) self.parser.add_argument( "--selected-ranks", diff --git a/tools/flight_recorder/fr_trace.py b/tools/flight_recorder/fr_trace.py index 298852d9bb4b5d..c2b8b81a9fa27e 100644 --- a/tools/flight_recorder/fr_trace.py +++ b/tools/flight_recorder/fr_trace.py @@ -40,7 +40,7 @@ def main(args: Optional[Sequence[str]] = None) -> None: config = JobConfig() args = config.parse_args(args) - details, version = read_dir(args.prefix, args.dir) + details, version = read_dir(args.prefix, args.trace_dir) db = build_db(details, args, version) if args.output: with open(args.output, "wb") as f: From 8fb9bf19182079a58f62f1c7d425d7877463df69 Mon Sep 17 00:00:00 2001 From: Will Constable Date: Thu, 5 Sep 2024 18:32:00 -0700 Subject: [PATCH 0577/1018] [FR] Automatically infer a common filename prefix (#135158) Save the annoyance of specifying this on the command line each time Pull Request resolved: https://github.com/pytorch/pytorch/pull/135158 Approved by: https://github.com/fduwjj, https://github.com/c-p-i-o ghstack dependencies: #135157 --- .../components/config_manager.py | 7 +++- tools/flight_recorder/components/loader.py | 42 ++++++++++++++++--- 2 files changed, 42 insertions(+), 7 deletions(-) diff --git a/tools/flight_recorder/components/config_manager.py b/tools/flight_recorder/components/config_manager.py index 47de3ef3ee052d..618aa40b55be58 100644 --- a/tools/flight_recorder/components/config_manager.py +++ b/tools/flight_recorder/components/config_manager.py @@ -39,8 +39,11 @@ def __init__(self: "JobConfig"): self.parser.add_argument( "-p", "--prefix", - help="prefix to strip such that rank can be extracted", - default="rank_", + help=( + "Common filename prefix to strip such that rank can be extracted. " + "If not specified, will attempt to infer a common prefix." + ), + default=None, ) self.parser.add_argument("-j", "--just_print_entries", action="store_true") self.parser.add_argument("-v", "--verbose", action="store_true") diff --git a/tools/flight_recorder/components/loader.py b/tools/flight_recorder/components/loader.py index 19987ddc74d004..451f14df37f205 100644 --- a/tools/flight_recorder/components/loader.py +++ b/tools/flight_recorder/components/loader.py @@ -7,15 +7,16 @@ import gc import os import pickle +import re import time -from typing import Any, Dict, List, Tuple, Union +import typing +from collections import defaultdict +from typing import Any, Dict, List, Optional, Set, Tuple, Union def read_dump(prefix: str, filename: str) -> Dict[str, Union[str, int, List[Any]]]: basename = os.path.basename(filename) - assert ( - basename.find(prefix) == 0 - ), f"args.prefix ({prefix}) must match the beginning of each filename ({basename})" + rank = int(basename[len(prefix) :]) host_name = f"host_rank{rank}" @@ -35,13 +36,44 @@ def read_dump(prefix: str, filename: str) -> Dict[str, Union[str, int, List[Any] } -def read_dir(prefix: str, folder: str) -> Tuple[Dict[str, Dict[str, Any]], str]: +exp = re.compile(r"([\w\-\_]*?)(\d+)$") + + +def _determine_prefix(files: List[str]) -> str: + """If the user doesn't specify a prefix, but does pass a dir full of similarly-prefixed files, we should be able to + infer the common prefix most of the time. But if we can't confidently infer, just fall back to requring the user + to specify it + """ + possible_prefixes: typing.DefaultDict[str, Set[int]] = defaultdict(set) + for f in files: + m = exp.search(f) + if m: + p, r = m.groups() + possible_prefixes[p].add(int(r)) + if len(possible_prefixes) == 1: + prefix = next(iter(possible_prefixes)) + print(f"Inferred common prefix {prefix}") + return prefix + else: + raise ValueError( + "Unable to automatically determine the common prefix for the trace file names. " + "Please specify --prefix argument manually" + ) + + +def read_dir( + prefix: Optional[str], folder: str +) -> Tuple[Dict[str, Dict[str, Any]], str]: gc.disable() details = {} t0 = time.time() version = "" for root, _, files in os.walk(folder): + if prefix is None: + prefix = _determine_prefix(files) for f in files: + if f.find(prefix) != 0: + continue details[f] = read_dump(prefix, os.path.join(root, f)) if not version: version = str(details[f]["version"]) From 5e5e53838ff3ec1022201d6fb9c5662415d85407 Mon Sep 17 00:00:00 2001 From: Jane Xu Date: Fri, 6 Sep 2024 11:55:40 -0700 Subject: [PATCH 0578/1018] [BE] Clarify defaulting behavior in optimizer (#135384) Fixes #135340 Pull Request resolved: https://github.com/pytorch/pytorch/pull/135384 Approved by: https://github.com/drisspg, https://github.com/jainapurva --- torch/optim/optimizer.py | 18 +++++++----------- 1 file changed, 7 insertions(+), 11 deletions(-) diff --git a/torch/optim/optimizer.py b/torch/optim/optimizer.py index a417a44817ebfd..8f7993842c1009 100644 --- a/torch/optim/optimizer.py +++ b/torch/optim/optimizer.py @@ -245,17 +245,13 @@ def _get_capturable_supported_devices(supports_xla: bool = True) -> List[str]: are supported. (default: None) .. note:: The foreach and fused implementations are typically faster than the for-loop, - single-tensor implementation. Thus, if the user has not specified BOTH flags - (i.e., when foreach = fused = None), we will attempt defaulting to the foreach - implementation when the tensors are all on CUDA. For example, if the user specifies - True for fused but nothing for foreach, we will run the fused implementation. If - the user specifies False for foreach but nothing for fused (or False for fused but - nothing for foreach), we will run the for-loop implementation. If the user specifies - True for both foreach and fused, we will prioritize fused over foreach, as it is - typically faster. We attempt to use the fastest, so the hierarchy goes fused -> - foreach -> for-loop. HOWEVER, since the fused implementation is relatively new, - we want to give it sufficient bake-in time, so we default to foreach and NOT - fused when the user has not specified either flag.""" + single-tensor implementation, with fused being theoretically fastest with both + vertical and horizontal fusion. As such, if the user has not specified either + flag (i.e., when foreach = fused = None), we will attempt defaulting to the foreach + implementation when the tensors are all on CUDA. Why not fused? Since the fused + implementation is relatively new, we want to give it sufficient bake-in time. + To specify fused, pass True for fused. To force running the for-loop + implementation, pass False for either foreach or fused. """ _capturable_doc = r"""capturable (bool, optional): whether this instance is safe to capture in a CUDA graph. Passing True can impair ungraphed performance, From 5c745db96a9d23060074e1820ebe9eba61a6a14f Mon Sep 17 00:00:00 2001 From: Sam Larsen Date: Fri, 6 Sep 2024 21:53:39 +0000 Subject: [PATCH 0579/1018] Require tlparse for failing tests in test_structured_trace.py (#135376) Summary: These tests are currently failing internally. Per discussion, skip if tlparse is unavailable Test Plan: ``` feature remove tlparse buck2 test 'fbcode//mode/opt' fbcode//caffe2/test/dynamo:test_dynamo -- --run-disabled --regex test_structured_trace.py feature install tlparse buck2 test 'fbcode//mode/opt' fbcode//caffe2/test/dynamo:test_dynamo -- --run-disabled --regex test_structured_trace.py ``` Differential Revision: D62310342 Pull Request resolved: https://github.com/pytorch/pytorch/pull/135376 Approved by: https://github.com/ezyang --- test/dynamo/test_structured_trace.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/test/dynamo/test_structured_trace.py b/test/dynamo/test_structured_trace.py index 91089220a82a52..2ec2310acf3274 100644 --- a/test/dynamo/test_structured_trace.py +++ b/test/dynamo/test_structured_trace.py @@ -23,6 +23,8 @@ from torch.testing._internal.inductor_utils import HAS_CUDA +HAS_TLPARSE = shutil.which("tlparse") is not None +requires_tlparse = unittest.skipUnless(HAS_TLPARSE, "requires tlparse") requires_cuda = unittest.skipUnless(HAS_CUDA, "requires cuda") requires_distributed = functools.partial( unittest.skipIf, not dist.is_available(), "requires distributed" @@ -213,6 +215,7 @@ def test_cudagraphs(self): self.assertParses() + @requires_tlparse def test_recompiles(self): def fn(x, y): return torch.add(x, y) @@ -257,6 +260,7 @@ def fn(x, y): self.assertParses() + @requires_tlparse def test_example_fn(self): fn_opt = torch._dynamo.optimize("inductor")(example_fn) fn_opt(torch.ones(1000, 1000)) @@ -280,6 +284,7 @@ def test_example_fn(self): self.assertParses() + @requires_tlparse def test_dynamo_error(self): try: fn_opt = torch._dynamo.optimize("inductor")(dynamo_error_fn) @@ -299,6 +304,7 @@ def test_dynamo_error(self): self.assertParses() + @requires_tlparse def test_inductor_error(self): import torch._inductor.lowering @@ -463,6 +469,7 @@ def forward(self, x): self.assertParses() + @requires_tlparse def test_graph_breaks(self): @torch._dynamo.optimize("inductor") def fn(x): @@ -496,6 +503,7 @@ def fn(x): # TODO: bring in the trace_source tests once we start emitting bytecode + @requires_tlparse def test_graph_sizes_dynamic(self): def fn(a, b): return a @ b @@ -535,6 +543,7 @@ def fn(a, b): self.assertParses() + @requires_tlparse def test_guards_recompiles(self): def fn(x, ys, zs): return inner(x, ys, zs) @@ -606,6 +615,7 @@ def forward(self, x, y): """, # noqa: B950 ) + @requires_tlparse @torch._inductor.config.patch("fx_graph_cache", True) def test_codecache(self): def fn(a): @@ -648,6 +658,7 @@ def fn(a): ) self.assertParses() + @requires_tlparse @torch._inductor.config.patch("fx_graph_cache", True) @show_chrome_events def test_chromium_event(self): From 35b3a705f4cb0e982c38ea4b71cb9e13038fbed9 Mon Sep 17 00:00:00 2001 From: Yidi Wu Date: Fri, 6 Sep 2024 09:51:25 -0700 Subject: [PATCH 0580/1018] [export] fix placeholder name collision tests by removing map call (#135366) The current test is failing because of the current unstable state of map. torch.compile and non-strict export are taking two seperate routes unlike cond and while_loop. This pr fix the test it self. We'll fix map in follow up PRs. Pull Request resolved: https://github.com/pytorch/pytorch/pull/135366 Approved by: https://github.com/angelayi --- test/export/test_export.py | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/test/export/test_export.py b/test/export/test_export.py index 3081baaca35221..0e26947f5c3026 100644 --- a/test/export/test_export.py +++ b/test/export/test_export.py @@ -6313,7 +6313,6 @@ def forward(self, mul, add, add_1): real_names_and_ops = [(node.name, node.op) for node in ep.graph.nodes] self.assertEqual(expected_names_and_ops, real_names_and_ops) - @testing.expectedFailureRetraceability def test_placeholder_naming_collisions_hoo_subgraphs(self): # test collisions between user inputs, top-level nodes, and HOO subgraph nodes class Foo(torch.nn.Module): @@ -6369,11 +6368,7 @@ def forward(self, x, mul, mul_1): # (please never do this) class Foo(torch.nn.Module): def forward(self, input, true_graph, body_graph): - def map_body(x, y): - return x + y - - x = map(map_body, input, body_graph[0]) - x = x + true_graph[0] + true_graph[1] + x = input + true_graph[0] + true_graph[1] x = cond(x.sum() > 0, lambda x: x * 2.0, lambda x: x + 2.0, [x]) x = cond(x.sum() > 0, lambda x: x * 2.0, lambda x: x + 2.0, [x]) return x @@ -6385,7 +6380,6 @@ def map_body(x, y): ) ep = export(Foo(), inputs) expected_getattr_names = [ - "body_graph_1", "true_graph_2", "false_graph_0", "true_graph_3", From 9445b1ffa3bd1c57ca5888e67545c20e1fe110b1 Mon Sep 17 00:00:00 2001 From: Henry Tsang Date: Fri, 6 Sep 2024 23:05:44 +0000 Subject: [PATCH 0581/1018] [aoti test] Disable FP8 funz dtypes in fp8 runtime check test (#135373) Fixing https://github.com/pytorch/pytorch/issues/126734 Key is the funz FP8 types are for AMD only. source: https://github.com/openxla/stablehlo/blob/main/rfcs/20230321-fp8_fnuz.md Pull Request resolved: https://github.com/pytorch/pytorch/pull/135373 Approved by: https://github.com/chenyang78 --- test/inductor/test_aot_inductor.py | 17 ++++++----------- 1 file changed, 6 insertions(+), 11 deletions(-) diff --git a/test/inductor/test_aot_inductor.py b/test/inductor/test_aot_inductor.py index 71e0a96dae17f0..60cffb05d63e9f 100644 --- a/test/inductor/test_aot_inductor.py +++ b/test/inductor/test_aot_inductor.py @@ -2958,29 +2958,24 @@ class Model(torch.nn.Module): def __init__(self) -> None: super().__init__() - def forward(self, x0, x1, x2, x3): - t = ( - x0.to(torch.float) - + x1.to(torch.float) - + x2.to(torch.float) - + x3.to(torch.float) - ) + def forward(self, x0, x1): + t = x0.to(torch.float) + x1.to(torch.float) return t inputs = [] for dtype in ( torch.float8_e4m3fn, torch.float8_e5m2, - torch.float8_e4m3fnuz, - torch.float8_e5m2fnuz, + # FP8 funz are for AMD + # see https://github.com/pytorch/pytorch/issues/126734 + # torch.float8_e4m3fnuz, + # torch.float8_e5m2fnuz, ): inputs.append(torch.ones(8, 8, 8, dtype=dtype, device=self.device)) dim0 = Dim("s0", min=2, max=1024) dynamic_shapes = { "x0": {0: dim0}, "x1": {0: dim0}, - "x2": {0: dim0}, - "x3": {0: dim0}, } with torch.no_grad(), config.patch( { From 826df27b07a94b91f149b8b345b494d1f223438e Mon Sep 17 00:00:00 2001 From: "Sun, Jiayi" Date: Fri, 6 Sep 2024 00:57:27 -0700 Subject: [PATCH 0582/1018] [inductor] Fix loop split optimization (#135303) Fix https://github.com/pytorch/pytorch/issues/135274. Improve the check whether the div expr matches: add a check whether `split_var` is in `original_body.iter_vars`. Pull Request resolved: https://github.com/pytorch/pytorch/pull/135303 Approved by: https://github.com/jgong5, https://github.com/leslie-fang-intel, https://github.com/jansel --- torch/_inductor/codegen/cpp.py | 1 + 1 file changed, 1 insertion(+) diff --git a/torch/_inductor/codegen/cpp.py b/torch/_inductor/codegen/cpp.py index af892c20a0a44b..d8a3146429ba76 100644 --- a/torch/_inductor/codegen/cpp.py +++ b/torch/_inductor/codegen/cpp.py @@ -4207,6 +4207,7 @@ def try_loop_split(self, nodes: List[SchedulerNode]): if ( isinstance(split_number, sympy.core.numbers.Integer) and isinstance(split_var, sympy.core.symbol.Symbol) + and split_var in original_body.iter_vars and divide_index_name is not None and all( stride_at_vec_range(expr, split_var) == 1 From a1dfb8541548fb084944a3d6e971ce1e06458138 Mon Sep 17 00:00:00 2001 From: leslie-fang-intel Date: Thu, 5 Sep 2024 19:57:14 -0700 Subject: [PATCH 0583/1018] [Inductor][CPP] Fix the issue of view dtype (#135301) **Summary** Fix issue: https://github.com/pytorch/pytorch/issues/135160, it's a regression introduced by https://github.com/pytorch/pytorch/pull/134569, where the dtype of `to_dtype_bitcast` was incorrectly handled when using the scalarize implementation. **TestPlan** ``` python -u -m pytest -s -v test/inductor/test_cpu_repro.py -k test_view_dtype ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/135301 Approved by: https://github.com/jgong5, https://github.com/jansel --- test/inductor/test_cpu_repro.py | 10 ++++++++++ torch/_inductor/codegen/common.py | 2 ++ torch/_inductor/codegen/cpp.py | 9 +++++++-- 3 files changed, 19 insertions(+), 2 deletions(-) diff --git a/test/inductor/test_cpu_repro.py b/test/inductor/test_cpu_repro.py index e463f3e60b8545..669d8ad598b1bd 100644 --- a/test/inductor/test_cpu_repro.py +++ b/test/inductor/test_cpu_repro.py @@ -1907,6 +1907,16 @@ def test_bitwise_right_shift(self): res = cfn(x, bit_num) self.assertEqual(res_aten_eager, res) + def test_view_dtype(self): + def f(x): + return x.view(torch.int32) >> 2 + + input = torch.ones(16, 16) + res_aten_eager = f(input) + cfn = torch.compile(f) + res = cfn(input) + self.assertEqual(res_aten_eager, res) + @patch("torch.cuda.is_available", lambda: False) def test_scatter_using_atomic_add(self): def fn(a, dim, index, b): diff --git a/torch/_inductor/codegen/common.py b/torch/_inductor/codegen/common.py index e27b847d0f4ca3..a703f7d66c8ee0 100644 --- a/torch/_inductor/codegen/common.py +++ b/torch/_inductor/codegen/common.py @@ -344,6 +344,8 @@ def deduce_output_dtype_by_name( ): buf_name = args[1] return V.graph.get_dtype(buf_name) # type: ignore[arg-type] + elif op_name == "to_dtype_bitcast": + return kwargs["dtype"] if "dtype" in kwargs else args[-2] return None diff --git a/torch/_inductor/codegen/cpp.py b/torch/_inductor/codegen/cpp.py index d8a3146429ba76..e1addfbbebb8c3 100644 --- a/torch/_inductor/codegen/cpp.py +++ b/torch/_inductor/codegen/cpp.py @@ -1640,6 +1640,11 @@ def inner(*args, **kwargs): "signbit", ) octype = "bool" if output_mask else cdtype + octype = ( + DTYPE_TO_CPP[args[-2]] + if (scalar_func.__name__ == "to_dtype_bitcast") + else octype + ) with code.indent(): for argidx, arg in enumerate(args): if isinstance(arg, CppCSEVariable): @@ -1668,9 +1673,9 @@ def inner(*args, **kwargs): else: load_args = f"tmpbuf_out.data(), {cexpr_index(size)}" if n_vec == 1: - load_fn = f"at::vec::Vectorized<{cdtype}>::loadu" + load_fn = f"at::vec::Vectorized<{octype}>::loadu" else: - load_fn = f" at::vec::VectorizedN<{cdtype}, {n_vec}>::loadu" + load_fn = f" at::vec::VectorizedN<{octype}, {n_vec}>::loadu" code.writeline(f"return {load_fn}({load_args});") code.writeline("()") return code From 09c2543192dfb98355aeedad679109d1a5e8eb13 Mon Sep 17 00:00:00 2001 From: Sahan Paliskara Date: Fri, 6 Sep 2024 23:57:56 +0000 Subject: [PATCH 0584/1018] [Split Build] Refactor split build binary builds into their own workflows and move split build binary builds to periodic (#134624) As we need to move split build binary tests from trunk to periodic this pr, refactors those jobs out into its own workflow to achieve this. Pull Request resolved: https://github.com/pytorch/pytorch/pull/134624 Approved by: https://github.com/malfet --- .../scripts/generate_binary_build_matrix.py | 47 +- .github/scripts/generate_ci_workflows.py | 35 + ...linux-aarch64-binary-manywheel-nightly.yml | 20 + ...ated-linux-binary-manywheel-main-split.yml | 182 ++ .../generated-linux-binary-manywheel-main.yml | 147 +- ...d-linux-binary-manywheel-nightly-split.yml | 1516 +++++++++++++++++ ...nerated-linux-binary-manywheel-nightly.yml | 1329 ++------------- ...d-linux-s390x-binary-manywheel-nightly.yml | 15 + 8 files changed, 1982 insertions(+), 1309 deletions(-) create mode 100644 .github/workflows/generated-linux-binary-manywheel-main-split.yml create mode 100644 .github/workflows/generated-linux-binary-manywheel-nightly-split.yml diff --git a/.github/scripts/generate_binary_build_matrix.py b/.github/scripts/generate_binary_build_matrix.py index 6b33924a0e5546..365f750e1ad7bb 100644 --- a/.github/scripts/generate_binary_build_matrix.py +++ b/.github/scripts/generate_binary_build_matrix.py @@ -325,6 +325,7 @@ def generate_wheels_matrix( os: str, arches: Optional[List[str]] = None, python_versions: Optional[List[str]] = None, + use_split_build: bool = False, ) -> List[Dict[str, str]]: package_type = "wheel" if os == "linux" or os == "linux-aarch64" or os == "linux-s390x": @@ -371,7 +372,17 @@ def generate_wheels_matrix( ) and python_version == "3.13": continue + if use_split_build and ( + arch_version not in ["12.4", "12.1", "11.8", "cpu"] or os != "linux" + ): + raise RuntimeError( + "Split build is only supported on linux with cuda 12.4, 12.1, 11.8, and cpu.\n" + f"Currently attempting to build on arch version {arch_version} and os {os}.\n" + "Please modify the matrix generation to exclude this combination." + ) + # 12.1 linux wheels require PYTORCH_EXTRA_INSTALL_REQUIREMENTS to install + if ( arch_version in ["12.4", "12.1", "11.8"] and os == "linux" @@ -385,6 +396,7 @@ def generate_wheels_matrix( "desired_cuda": translate_desired_cuda( gpu_arch_type, gpu_arch_version ), + "use_split_build": "True" if use_split_build else "False", "devtoolset": ( "cxx11-abi" if arch_version == "cuda-aarch64" else "" ), @@ -400,7 +412,8 @@ def generate_wheels_matrix( ), } ) - if arch_version != "cuda-aarch64": + # Special build building to use on Colab. PyThon 3.10 for 12.1 CUDA + if python_version == "3.10" and arch_version == "12.1": ret.append( { "python_version": python_version, @@ -409,40 +422,16 @@ def generate_wheels_matrix( "desired_cuda": translate_desired_cuda( gpu_arch_type, gpu_arch_version ), - "use_split_build": "True", + "use_split_build": "True" if use_split_build else "False", "devtoolset": "", "container_image": WHEEL_CONTAINER_IMAGES[arch_version], "package_type": package_type, - "pytorch_extra_install_requirements": ( - PYTORCH_EXTRA_INSTALL_REQUIREMENTS[arch_version] # fmt: skip - if os != "linux-aarch64" - else "" - ), - "build_name": f"{package_type}-py{python_version}-{gpu_arch_type}{gpu_arch_version}-split".replace( # noqa: B950 + "pytorch_extra_install_requirements": "", + "build_name": f"{package_type}-py{python_version}-{gpu_arch_type}{gpu_arch_version}-full".replace( # noqa: B950 ".", "_" ), } ) - # Special build building to use on Colab. PyThon 3.10 for 12.1 CUDA - if python_version == "3.10" and arch_version == "12.1": - ret.append( - { - "python_version": python_version, - "gpu_arch_type": gpu_arch_type, - "gpu_arch_version": gpu_arch_version, - "desired_cuda": translate_desired_cuda( - gpu_arch_type, gpu_arch_version - ), - "use_split_build": "False", - "devtoolset": "", - "container_image": WHEEL_CONTAINER_IMAGES[arch_version], - "package_type": package_type, - "pytorch_extra_install_requirements": "", - "build_name": f"{package_type}-py{python_version}-{gpu_arch_type}{gpu_arch_version}-full".replace( # noqa: B950 - ".", "_" - ), - } - ) else: ret.append( { @@ -452,6 +441,7 @@ def generate_wheels_matrix( "desired_cuda": translate_desired_cuda( gpu_arch_type, gpu_arch_version ), + "use_split_build": "True" if use_split_build else "False", "devtoolset": ( "cxx11-abi" if arch_version == "cpu-cxx11-abi" else "" ), @@ -467,6 +457,7 @@ def generate_wheels_matrix( ), } ) + return ret diff --git a/.github/scripts/generate_ci_workflows.py b/.github/scripts/generate_ci_workflows.py index 8a66ebe5f59b5f..9f3536e2ede50f 100755 --- a/.github/scripts/generate_ci_workflows.py +++ b/.github/scripts/generate_ci_workflows.py @@ -61,6 +61,7 @@ class BinaryBuildWorkflow: # Mainly for macos cross_compile_arm64: bool = False macos_runner: str = "macos-14-xlarge" + use_split_build: bool = False def __post_init__(self) -> None: if self.abi_version: @@ -75,6 +76,11 @@ def generate_workflow_file(self, workflow_template: jinja2.Template) -> None: GITHUB_DIR / f"workflows/generated-{self.build_environment}-{self.branches}.yml" ) + if self.use_split_build: + output_file_path = ( + GITHUB_DIR + / f"workflows/generated-{self.build_environment}-{self.branches}-split.yml" + ) with open(output_file_path, "w") as output_file: GENERATED = "generated" # Note that please keep the variable GENERATED otherwise phabricator will hide the whole file output_file.writelines([f"# @{GENERATED} DO NOT EDIT MANUALLY\n"]) @@ -110,6 +116,20 @@ class OperatingSystem: isolated_workflow=True, ), ), + BinaryBuildWorkflow( + os=OperatingSystem.LINUX, + package_type="manywheel", + build_configs=generate_binary_build_matrix.generate_wheels_matrix( + OperatingSystem.LINUX, + use_split_build=True, + arches=["11.8", "12.1", "12.4", "cpu"], + ), + ciflow_config=CIFlowConfig( + labels={LABEL_CIFLOW_BINARIES, LABEL_CIFLOW_BINARIES_WHEEL}, + isolated_workflow=True, + ), + use_split_build=True, + ), BinaryBuildWorkflow( os=OperatingSystem.LINUX, package_type="conda", @@ -162,6 +182,21 @@ class OperatingSystem: ), branches="main", ), + BinaryBuildWorkflow( + os=OperatingSystem.LINUX, + package_type="manywheel", + build_configs=generate_binary_build_matrix.generate_wheels_matrix( + OperatingSystem.LINUX, + arches=["11.8", "12.1", "12.4"], + python_versions=["3.9"], + use_split_build=True, + ), + ciflow_config=CIFlowConfig( + labels={LABEL_CIFLOW_PERIODIC}, + ), + branches="main", + use_split_build=True, + ), BinaryBuildWorkflow( os=OperatingSystem.LINUX, package_type="libtorch", diff --git a/.github/workflows/generated-linux-aarch64-binary-manywheel-nightly.yml b/.github/workflows/generated-linux-aarch64-binary-manywheel-nightly.yml index 3ce438c9bc5f09..aeb85c90feb78f 100644 --- a/.github/workflows/generated-linux-aarch64-binary-manywheel-nightly.yml +++ b/.github/workflows/generated-linux-aarch64-binary-manywheel-nightly.yml @@ -58,6 +58,7 @@ jobs: DESIRED_CUDA: cpu GPU_ARCH_TYPE: cpu-aarch64 DOCKER_IMAGE: pytorch/manylinuxaarch64-builder:cpu-aarch64-main + use_split_build: False DESIRED_PYTHON: "3.9" runs_on: linux.arm64.m7g.4xlarge.ephemeral ALPINE_IMAGE: "arm64v8/alpine" @@ -81,6 +82,7 @@ jobs: DESIRED_CUDA: cpu GPU_ARCH_TYPE: cpu-aarch64 DOCKER_IMAGE: pytorch/manylinuxaarch64-builder:cpu-aarch64-main + use_split_build: False DESIRED_PYTHON: "3.9" build_name: manywheel-py3_9-cpu-aarch64 build_environment: linux-aarch64-binary-manywheel @@ -103,6 +105,7 @@ jobs: DESIRED_CUDA: cpu GPU_ARCH_TYPE: cpu-aarch64 DOCKER_IMAGE: pytorch/manylinuxaarch64-builder:cpu-aarch64-main + use_split_build: False DESIRED_PYTHON: "3.9" build_name: manywheel-py3_9-cpu-aarch64 secrets: @@ -125,6 +128,7 @@ jobs: GPU_ARCH_TYPE: cuda-aarch64 DOCKER_IMAGE: pytorch/manylinuxaarch64-builder:cuda12.4-main DESIRED_DEVTOOLSET: cxx11-abi + use_split_build: False DESIRED_PYTHON: "3.9" runs_on: linux.arm64.m7g.4xlarge.ephemeral ALPINE_IMAGE: "arm64v8/alpine" @@ -149,6 +153,7 @@ jobs: GPU_ARCH_TYPE: cuda-aarch64 DOCKER_IMAGE: pytorch/manylinuxaarch64-builder:cuda12.4-main DESIRED_DEVTOOLSET: cxx11-abi + use_split_build: False DESIRED_PYTHON: "3.9" build_name: manywheel-py3_9-cuda-aarch64 secrets: @@ -170,6 +175,7 @@ jobs: DESIRED_CUDA: cpu GPU_ARCH_TYPE: cpu-aarch64 DOCKER_IMAGE: pytorch/manylinuxaarch64-builder:cpu-aarch64-main + use_split_build: False DESIRED_PYTHON: "3.10" runs_on: linux.arm64.m7g.4xlarge.ephemeral ALPINE_IMAGE: "arm64v8/alpine" @@ -193,6 +199,7 @@ jobs: DESIRED_CUDA: cpu GPU_ARCH_TYPE: cpu-aarch64 DOCKER_IMAGE: pytorch/manylinuxaarch64-builder:cpu-aarch64-main + use_split_build: False DESIRED_PYTHON: "3.10" build_name: manywheel-py3_10-cpu-aarch64 build_environment: linux-aarch64-binary-manywheel @@ -215,6 +222,7 @@ jobs: DESIRED_CUDA: cpu GPU_ARCH_TYPE: cpu-aarch64 DOCKER_IMAGE: pytorch/manylinuxaarch64-builder:cpu-aarch64-main + use_split_build: False DESIRED_PYTHON: "3.10" build_name: manywheel-py3_10-cpu-aarch64 secrets: @@ -237,6 +245,7 @@ jobs: GPU_ARCH_TYPE: cuda-aarch64 DOCKER_IMAGE: pytorch/manylinuxaarch64-builder:cuda12.4-main DESIRED_DEVTOOLSET: cxx11-abi + use_split_build: False DESIRED_PYTHON: "3.10" runs_on: linux.arm64.m7g.4xlarge.ephemeral ALPINE_IMAGE: "arm64v8/alpine" @@ -261,6 +270,7 @@ jobs: GPU_ARCH_TYPE: cuda-aarch64 DOCKER_IMAGE: pytorch/manylinuxaarch64-builder:cuda12.4-main DESIRED_DEVTOOLSET: cxx11-abi + use_split_build: False DESIRED_PYTHON: "3.10" build_name: manywheel-py3_10-cuda-aarch64 secrets: @@ -282,6 +292,7 @@ jobs: DESIRED_CUDA: cpu GPU_ARCH_TYPE: cpu-aarch64 DOCKER_IMAGE: pytorch/manylinuxaarch64-builder:cpu-aarch64-main + use_split_build: False DESIRED_PYTHON: "3.11" runs_on: linux.arm64.m7g.4xlarge.ephemeral ALPINE_IMAGE: "arm64v8/alpine" @@ -305,6 +316,7 @@ jobs: DESIRED_CUDA: cpu GPU_ARCH_TYPE: cpu-aarch64 DOCKER_IMAGE: pytorch/manylinuxaarch64-builder:cpu-aarch64-main + use_split_build: False DESIRED_PYTHON: "3.11" build_name: manywheel-py3_11-cpu-aarch64 build_environment: linux-aarch64-binary-manywheel @@ -327,6 +339,7 @@ jobs: DESIRED_CUDA: cpu GPU_ARCH_TYPE: cpu-aarch64 DOCKER_IMAGE: pytorch/manylinuxaarch64-builder:cpu-aarch64-main + use_split_build: False DESIRED_PYTHON: "3.11" build_name: manywheel-py3_11-cpu-aarch64 secrets: @@ -349,6 +362,7 @@ jobs: GPU_ARCH_TYPE: cuda-aarch64 DOCKER_IMAGE: pytorch/manylinuxaarch64-builder:cuda12.4-main DESIRED_DEVTOOLSET: cxx11-abi + use_split_build: False DESIRED_PYTHON: "3.11" runs_on: linux.arm64.m7g.4xlarge.ephemeral ALPINE_IMAGE: "arm64v8/alpine" @@ -373,6 +387,7 @@ jobs: GPU_ARCH_TYPE: cuda-aarch64 DOCKER_IMAGE: pytorch/manylinuxaarch64-builder:cuda12.4-main DESIRED_DEVTOOLSET: cxx11-abi + use_split_build: False DESIRED_PYTHON: "3.11" build_name: manywheel-py3_11-cuda-aarch64 secrets: @@ -394,6 +409,7 @@ jobs: DESIRED_CUDA: cpu GPU_ARCH_TYPE: cpu-aarch64 DOCKER_IMAGE: pytorch/manylinuxaarch64-builder:cpu-aarch64-main + use_split_build: False DESIRED_PYTHON: "3.12" runs_on: linux.arm64.m7g.4xlarge.ephemeral ALPINE_IMAGE: "arm64v8/alpine" @@ -417,6 +433,7 @@ jobs: DESIRED_CUDA: cpu GPU_ARCH_TYPE: cpu-aarch64 DOCKER_IMAGE: pytorch/manylinuxaarch64-builder:cpu-aarch64-main + use_split_build: False DESIRED_PYTHON: "3.12" build_name: manywheel-py3_12-cpu-aarch64 build_environment: linux-aarch64-binary-manywheel @@ -439,6 +456,7 @@ jobs: DESIRED_CUDA: cpu GPU_ARCH_TYPE: cpu-aarch64 DOCKER_IMAGE: pytorch/manylinuxaarch64-builder:cpu-aarch64-main + use_split_build: False DESIRED_PYTHON: "3.12" build_name: manywheel-py3_12-cpu-aarch64 secrets: @@ -461,6 +479,7 @@ jobs: GPU_ARCH_TYPE: cuda-aarch64 DOCKER_IMAGE: pytorch/manylinuxaarch64-builder:cuda12.4-main DESIRED_DEVTOOLSET: cxx11-abi + use_split_build: False DESIRED_PYTHON: "3.12" runs_on: linux.arm64.m7g.4xlarge.ephemeral ALPINE_IMAGE: "arm64v8/alpine" @@ -485,6 +504,7 @@ jobs: GPU_ARCH_TYPE: cuda-aarch64 DOCKER_IMAGE: pytorch/manylinuxaarch64-builder:cuda12.4-main DESIRED_DEVTOOLSET: cxx11-abi + use_split_build: False DESIRED_PYTHON: "3.12" build_name: manywheel-py3_12-cuda-aarch64 secrets: diff --git a/.github/workflows/generated-linux-binary-manywheel-main-split.yml b/.github/workflows/generated-linux-binary-manywheel-main-split.yml new file mode 100644 index 00000000000000..61d6df214c5a2c --- /dev/null +++ b/.github/workflows/generated-linux-binary-manywheel-main-split.yml @@ -0,0 +1,182 @@ +# @generated DO NOT EDIT MANUALLY + +# Template is at: .github/templates/linux_binary_build_workflow.yml.j2 +# Generation script: .github/scripts/generate_ci_workflows.py +name: linux-binary-manywheel + + +on: + push: + branches: + - main + tags: + - 'ciflow/periodic/*' + workflow_dispatch: + +env: + # Needed for conda builds + ALPINE_IMAGE: "308535385114.dkr.ecr.us-east-1.amazonaws.com/tool/alpine" + ANACONDA_USER: pytorch + AWS_DEFAULT_REGION: us-east-1 + BINARY_ENV_FILE: /tmp/env + BUILD_ENVIRONMENT: linux-binary-manywheel + BUILDER_ROOT: /builder + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + PR_NUMBER: ${{ github.event.pull_request.number }} + PYTORCH_FINAL_PACKAGE_DIR: /artifacts + PYTORCH_ROOT: /pytorch + SHA1: ${{ github.event.pull_request.head.sha || github.sha }} + SKIP_ALL_TESTS: 0 +concurrency: + group: linux-binary-manywheel-${{ github.event.pull_request.number || github.ref_name }}-${{ github.ref_type == 'branch' && github.sha }}-${{ github.event_name == 'workflow_dispatch' }} + cancel-in-progress: true + +jobs: + get-label-type: + name: get-label-type + uses: ./.github/workflows/_runner-determinator.yml + with: + triggering_actor: ${{ github.triggering_actor }} + issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }} + curr_branch: ${{ github.head_ref || github.ref_name }} + curr_ref_type: ${{ github.ref_type }} + manywheel-py3_9-cuda11_8-build: + if: ${{ github.repository_owner == 'pytorch' }} + uses: ./.github/workflows/_binary-build-linux.yml + needs: get-label-type + with: + PYTORCH_ROOT: /pytorch + BUILDER_ROOT: /builder + PACKAGE_TYPE: manywheel + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: cu118 + GPU_ARCH_VERSION: 11.8 + GPU_ARCH_TYPE: cuda + DOCKER_IMAGE: pytorch/manylinux-builder:cuda11.8-main + use_split_build: True + DESIRED_PYTHON: "3.9" + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" + build_name: manywheel-py3_9-cuda11_8 + build_environment: linux-binary-manywheel + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu11==11.8.89; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu11==11.8.89; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu11==11.8.87; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu11==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu11==11.11.3.6; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu11==10.9.0.58; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu11==10.3.0.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu11==11.4.1.48; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu11==11.7.5.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu11==2.21.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu11==11.8.86; platform_system == 'Linux' and platform_machine == 'x86_64' + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} + manywheel-py3_9-cuda11_8-test: # Testing + if: ${{ github.repository_owner == 'pytorch' }} + needs: + - manywheel-py3_9-cuda11_8-build + - get-label-type + uses: ./.github/workflows/_binary-test-linux.yml + with: + PYTORCH_ROOT: /pytorch + BUILDER_ROOT: /builder + PACKAGE_TYPE: manywheel + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: cu118 + GPU_ARCH_VERSION: 11.8 + GPU_ARCH_TYPE: cuda + DOCKER_IMAGE: pytorch/manylinux-builder:cuda11.8-main + use_split_build: True + DESIRED_PYTHON: "3.9" + build_name: manywheel-py3_9-cuda11_8 + build_environment: linux-binary-manywheel + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" + runs_on: linux.4xlarge.nvidia.gpu + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} + + manywheel-py3_9-cuda12_1-build: + if: ${{ github.repository_owner == 'pytorch' }} + uses: ./.github/workflows/_binary-build-linux.yml + needs: get-label-type + with: + PYTORCH_ROOT: /pytorch + BUILDER_ROOT: /builder + PACKAGE_TYPE: manywheel + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: cu121 + GPU_ARCH_VERSION: 12.1 + GPU_ARCH_TYPE: cuda + DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.1-main + use_split_build: True + DESIRED_PYTHON: "3.9" + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" + build_name: manywheel-py3_9-cuda12_1 + build_environment: linux-binary-manywheel + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.21.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} + manywheel-py3_9-cuda12_1-test: # Testing + if: ${{ github.repository_owner == 'pytorch' }} + needs: + - manywheel-py3_9-cuda12_1-build + - get-label-type + uses: ./.github/workflows/_binary-test-linux.yml + with: + PYTORCH_ROOT: /pytorch + BUILDER_ROOT: /builder + PACKAGE_TYPE: manywheel + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: cu121 + GPU_ARCH_VERSION: 12.1 + GPU_ARCH_TYPE: cuda + DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.1-main + use_split_build: True + DESIRED_PYTHON: "3.9" + build_name: manywheel-py3_9-cuda12_1 + build_environment: linux-binary-manywheel + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" + runs_on: linux.4xlarge.nvidia.gpu + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} + + manywheel-py3_9-cuda12_4-build: + if: ${{ github.repository_owner == 'pytorch' }} + uses: ./.github/workflows/_binary-build-linux.yml + needs: get-label-type + with: + PYTORCH_ROOT: /pytorch + BUILDER_ROOT: /builder + PACKAGE_TYPE: manywheel + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: cu124 + GPU_ARCH_VERSION: 12.4 + GPU_ARCH_TYPE: cuda + DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.4-main + use_split_build: True + DESIRED_PYTHON: "3.9" + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" + build_name: manywheel-py3_9-cuda12_4 + build_environment: linux-binary-manywheel + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.4.5.8; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.2.1.3; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.5.147; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.6.1.9; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.3.1.170; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.21.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} + manywheel-py3_9-cuda12_4-test: # Testing + if: ${{ github.repository_owner == 'pytorch' }} + needs: + - manywheel-py3_9-cuda12_4-build + - get-label-type + uses: ./.github/workflows/_binary-test-linux.yml + with: + PYTORCH_ROOT: /pytorch + BUILDER_ROOT: /builder + PACKAGE_TYPE: manywheel + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: cu124 + GPU_ARCH_VERSION: 12.4 + GPU_ARCH_TYPE: cuda + DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.4-main + use_split_build: True + DESIRED_PYTHON: "3.9" + build_name: manywheel-py3_9-cuda12_4 + build_environment: linux-binary-manywheel + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" + runs_on: linux.4xlarge.nvidia.gpu + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} diff --git a/.github/workflows/generated-linux-binary-manywheel-main.yml b/.github/workflows/generated-linux-binary-manywheel-main.yml index 298c0b5e105ccb..d87b832bf03cb0 100644 --- a/.github/workflows/generated-linux-binary-manywheel-main.yml +++ b/.github/workflows/generated-linux-binary-manywheel-main.yml @@ -54,6 +54,7 @@ jobs: GPU_ARCH_VERSION: 11.8 GPU_ARCH_TYPE: cuda DOCKER_IMAGE: pytorch/manylinux-builder:cuda11.8-main + use_split_build: False DESIRED_PYTHON: "3.9" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_9-cuda11_8 @@ -77,6 +78,7 @@ jobs: GPU_ARCH_VERSION: 11.8 GPU_ARCH_TYPE: cuda DOCKER_IMAGE: pytorch/manylinux-builder:cuda11.8-main + use_split_build: False DESIRED_PYTHON: "3.9" build_name: manywheel-py3_9-cuda11_8 build_environment: linux-binary-manywheel @@ -85,53 +87,6 @@ jobs: secrets: github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_9-cuda11_8-split-build: - if: ${{ github.repository_owner == 'pytorch' }} - uses: ./.github/workflows/_binary-build-linux.yml - needs: get-label-type - with: - PYTORCH_ROOT: /pytorch - BUILDER_ROOT: /builder - PACKAGE_TYPE: manywheel - # TODO: This is a legacy variable that we eventually want to get rid of in - # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu118 - GPU_ARCH_VERSION: 11.8 - GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/manylinux-builder:cuda11.8-main - use_split_build: True - DESIRED_PYTHON: "3.9" - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - build_name: manywheel-py3_9-cuda11_8-split - build_environment: linux-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu11==11.8.89; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu11==11.8.89; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu11==11.8.87; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu11==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu11==11.11.3.6; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu11==10.9.0.58; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu11==10.3.0.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu11==11.4.1.48; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu11==11.7.5.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu11==2.21.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu11==11.8.86; platform_system == 'Linux' and platform_machine == 'x86_64' - secrets: - github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_9-cuda11_8-split-test: # Testing - if: ${{ github.repository_owner == 'pytorch' }} - needs: - - manywheel-py3_9-cuda11_8-split-build - - get-label-type - uses: ./.github/workflows/_binary-test-linux.yml - with: - PYTORCH_ROOT: /pytorch - BUILDER_ROOT: /builder - PACKAGE_TYPE: manywheel - # TODO: This is a legacy variable that we eventually want to get rid of in - # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu118 - GPU_ARCH_VERSION: 11.8 - GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/manylinux-builder:cuda11.8-main - use_split_build: True - DESIRED_PYTHON: "3.9" - build_name: manywheel-py3_9-cuda11_8-split - build_environment: linux-binary-manywheel - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - runs_on: linux.4xlarge.nvidia.gpu - secrets: - github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_9-cuda12_1-build: if: ${{ github.repository_owner == 'pytorch' }} uses: ./.github/workflows/_binary-build-linux.yml @@ -146,6 +101,7 @@ jobs: GPU_ARCH_VERSION: 12.1 GPU_ARCH_TYPE: cuda DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.1-main + use_split_build: False DESIRED_PYTHON: "3.9" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_9-cuda12_1 @@ -169,6 +125,7 @@ jobs: GPU_ARCH_VERSION: 12.1 GPU_ARCH_TYPE: cuda DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.1-main + use_split_build: False DESIRED_PYTHON: "3.9" build_name: manywheel-py3_9-cuda12_1 build_environment: linux-binary-manywheel @@ -177,53 +134,6 @@ jobs: secrets: github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_9-cuda12_1-split-build: - if: ${{ github.repository_owner == 'pytorch' }} - uses: ./.github/workflows/_binary-build-linux.yml - needs: get-label-type - with: - PYTORCH_ROOT: /pytorch - BUILDER_ROOT: /builder - PACKAGE_TYPE: manywheel - # TODO: This is a legacy variable that we eventually want to get rid of in - # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu121 - GPU_ARCH_VERSION: 12.1 - GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.1-main - use_split_build: True - DESIRED_PYTHON: "3.9" - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - build_name: manywheel-py3_9-cuda12_1-split - build_environment: linux-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.21.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' - secrets: - github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_9-cuda12_1-split-test: # Testing - if: ${{ github.repository_owner == 'pytorch' }} - needs: - - manywheel-py3_9-cuda12_1-split-build - - get-label-type - uses: ./.github/workflows/_binary-test-linux.yml - with: - PYTORCH_ROOT: /pytorch - BUILDER_ROOT: /builder - PACKAGE_TYPE: manywheel - # TODO: This is a legacy variable that we eventually want to get rid of in - # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu121 - GPU_ARCH_VERSION: 12.1 - GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.1-main - use_split_build: True - DESIRED_PYTHON: "3.9" - build_name: manywheel-py3_9-cuda12_1-split - build_environment: linux-binary-manywheel - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - runs_on: linux.4xlarge.nvidia.gpu - secrets: - github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_9-cuda12_4-build: if: ${{ github.repository_owner == 'pytorch' }} uses: ./.github/workflows/_binary-build-linux.yml @@ -238,6 +148,7 @@ jobs: GPU_ARCH_VERSION: 12.4 GPU_ARCH_TYPE: cuda DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.4-main + use_split_build: False DESIRED_PYTHON: "3.9" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_9-cuda12_4 @@ -261,6 +172,7 @@ jobs: GPU_ARCH_VERSION: 12.4 GPU_ARCH_TYPE: cuda DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.4-main + use_split_build: False DESIRED_PYTHON: "3.9" build_name: manywheel-py3_9-cuda12_4 build_environment: linux-binary-manywheel @@ -268,50 +180,3 @@ jobs: runs_on: linux.4xlarge.nvidia.gpu secrets: github-token: ${{ secrets.GITHUB_TOKEN }} - - manywheel-py3_9-cuda12_4-split-build: - if: ${{ github.repository_owner == 'pytorch' }} - uses: ./.github/workflows/_binary-build-linux.yml - needs: get-label-type - with: - PYTORCH_ROOT: /pytorch - BUILDER_ROOT: /builder - PACKAGE_TYPE: manywheel - # TODO: This is a legacy variable that we eventually want to get rid of in - # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu124 - GPU_ARCH_VERSION: 12.4 - GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.4-main - use_split_build: True - DESIRED_PYTHON: "3.9" - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - build_name: manywheel-py3_9-cuda12_4-split - build_environment: linux-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.4.5.8; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.2.1.3; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.5.147; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.6.1.9; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.3.1.170; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.21.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' - secrets: - github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_9-cuda12_4-split-test: # Testing - if: ${{ github.repository_owner == 'pytorch' }} - needs: - - manywheel-py3_9-cuda12_4-split-build - - get-label-type - uses: ./.github/workflows/_binary-test-linux.yml - with: - PYTORCH_ROOT: /pytorch - BUILDER_ROOT: /builder - PACKAGE_TYPE: manywheel - # TODO: This is a legacy variable that we eventually want to get rid of in - # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu124 - GPU_ARCH_VERSION: 12.4 - GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.4-main - use_split_build: True - DESIRED_PYTHON: "3.9" - build_name: manywheel-py3_9-cuda12_4-split - build_environment: linux-binary-manywheel - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - runs_on: linux.4xlarge.nvidia.gpu - secrets: - github-token: ${{ secrets.GITHUB_TOKEN }} diff --git a/.github/workflows/generated-linux-binary-manywheel-nightly-split.yml b/.github/workflows/generated-linux-binary-manywheel-nightly-split.yml new file mode 100644 index 00000000000000..8e57081422f1a9 --- /dev/null +++ b/.github/workflows/generated-linux-binary-manywheel-nightly-split.yml @@ -0,0 +1,1516 @@ +# @generated DO NOT EDIT MANUALLY + +# Template is at: .github/templates/linux_binary_build_workflow.yml.j2 +# Generation script: .github/scripts/generate_ci_workflows.py +name: linux-binary-manywheel + + +on: + push: + # NOTE: Meta Employees can trigger new nightlies using: https://fburl.com/trigger_pytorch_nightly_build + branches: + - nightly + tags: + # NOTE: Binary build pipelines should only get triggered on release candidate builds + # Release candidate tags look like: v1.11.0-rc1 + - v[0-9]+.[0-9]+.[0-9]+-rc[0-9]+ + - 'ciflow/binaries/*' + - 'ciflow/binaries_wheel/*' + workflow_dispatch: + +env: + # Needed for conda builds + ALPINE_IMAGE: "308535385114.dkr.ecr.us-east-1.amazonaws.com/tool/alpine" + ANACONDA_USER: pytorch + AWS_DEFAULT_REGION: us-east-1 + BINARY_ENV_FILE: /tmp/env + BUILD_ENVIRONMENT: linux-binary-manywheel + BUILDER_ROOT: /builder + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + PR_NUMBER: ${{ github.event.pull_request.number }} + PYTORCH_FINAL_PACKAGE_DIR: /artifacts + PYTORCH_ROOT: /pytorch + SHA1: ${{ github.event.pull_request.head.sha || github.sha }} + SKIP_ALL_TESTS: 0 +concurrency: + group: linux-binary-manywheel-${{ github.event.pull_request.number || github.ref_name }}-${{ github.ref_type == 'branch' && github.sha }}-${{ github.event_name == 'workflow_dispatch' }} + cancel-in-progress: true + +jobs: + get-label-type: + name: get-label-type + uses: ./.github/workflows/_runner-determinator.yml + with: + triggering_actor: ${{ github.triggering_actor }} + issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }} + curr_branch: ${{ github.head_ref || github.ref_name }} + curr_ref_type: ${{ github.ref_type }} + manywheel-py3_9-cuda11_8-build: + if: ${{ github.repository_owner == 'pytorch' }} + uses: ./.github/workflows/_binary-build-linux.yml + needs: get-label-type + with: + PYTORCH_ROOT: /pytorch + BUILDER_ROOT: /builder + PACKAGE_TYPE: manywheel + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: cu118 + GPU_ARCH_VERSION: 11.8 + GPU_ARCH_TYPE: cuda + DOCKER_IMAGE: pytorch/manylinux-builder:cuda11.8-main + use_split_build: True + DESIRED_PYTHON: "3.9" + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" + build_name: manywheel-py3_9-cuda11_8 + build_environment: linux-binary-manywheel + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu11==11.8.89; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu11==11.8.89; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu11==11.8.87; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu11==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu11==11.11.3.6; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu11==10.9.0.58; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu11==10.3.0.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu11==11.4.1.48; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu11==11.7.5.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu11==2.21.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu11==11.8.86; platform_system == 'Linux' and platform_machine == 'x86_64' + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} + manywheel-py3_9-cuda11_8-test: # Testing + if: ${{ github.repository_owner == 'pytorch' }} + needs: + - manywheel-py3_9-cuda11_8-build + - get-label-type + uses: ./.github/workflows/_binary-test-linux.yml + with: + PYTORCH_ROOT: /pytorch + BUILDER_ROOT: /builder + PACKAGE_TYPE: manywheel + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: cu118 + GPU_ARCH_VERSION: 11.8 + GPU_ARCH_TYPE: cuda + DOCKER_IMAGE: pytorch/manylinux-builder:cuda11.8-main + use_split_build: True + DESIRED_PYTHON: "3.9" + build_name: manywheel-py3_9-cuda11_8 + build_environment: linux-binary-manywheel + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" + runs_on: linux.4xlarge.nvidia.gpu + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} + manywheel-py3_9-cuda11_8-upload: # Uploading + if: ${{ github.repository_owner == 'pytorch' }} + permissions: + id-token: write + contents: read + needs: manywheel-py3_9-cuda11_8-test + with: + PYTORCH_ROOT: /pytorch + BUILDER_ROOT: /builder + PACKAGE_TYPE: manywheel + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: cu118 + GPU_ARCH_VERSION: 11.8 + GPU_ARCH_TYPE: cuda + DOCKER_IMAGE: pytorch/manylinux-builder:cuda11.8-main + use_split_build: True + DESIRED_PYTHON: "3.9" + build_name: manywheel-py3_9-cuda11_8 + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} + conda-pytorchbot-token: ${{ secrets.CONDA_PYTORCHBOT_TOKEN }} + conda-pytorchbot-token-test: ${{ secrets.CONDA_PYTORCHBOT_TOKEN_TEST }} + uses: ./.github/workflows/_binary-upload.yml + + manywheel-py3_9-cuda12_1-build: + if: ${{ github.repository_owner == 'pytorch' }} + uses: ./.github/workflows/_binary-build-linux.yml + needs: get-label-type + with: + PYTORCH_ROOT: /pytorch + BUILDER_ROOT: /builder + PACKAGE_TYPE: manywheel + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: cu121 + GPU_ARCH_VERSION: 12.1 + GPU_ARCH_TYPE: cuda + DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.1-main + use_split_build: True + DESIRED_PYTHON: "3.9" + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" + build_name: manywheel-py3_9-cuda12_1 + build_environment: linux-binary-manywheel + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.21.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} + manywheel-py3_9-cuda12_1-test: # Testing + if: ${{ github.repository_owner == 'pytorch' }} + needs: + - manywheel-py3_9-cuda12_1-build + - get-label-type + uses: ./.github/workflows/_binary-test-linux.yml + with: + PYTORCH_ROOT: /pytorch + BUILDER_ROOT: /builder + PACKAGE_TYPE: manywheel + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: cu121 + GPU_ARCH_VERSION: 12.1 + GPU_ARCH_TYPE: cuda + DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.1-main + use_split_build: True + DESIRED_PYTHON: "3.9" + build_name: manywheel-py3_9-cuda12_1 + build_environment: linux-binary-manywheel + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" + runs_on: linux.4xlarge.nvidia.gpu + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} + manywheel-py3_9-cuda12_1-upload: # Uploading + if: ${{ github.repository_owner == 'pytorch' }} + permissions: + id-token: write + contents: read + needs: manywheel-py3_9-cuda12_1-test + with: + PYTORCH_ROOT: /pytorch + BUILDER_ROOT: /builder + PACKAGE_TYPE: manywheel + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: cu121 + GPU_ARCH_VERSION: 12.1 + GPU_ARCH_TYPE: cuda + DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.1-main + use_split_build: True + DESIRED_PYTHON: "3.9" + build_name: manywheel-py3_9-cuda12_1 + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} + conda-pytorchbot-token: ${{ secrets.CONDA_PYTORCHBOT_TOKEN }} + conda-pytorchbot-token-test: ${{ secrets.CONDA_PYTORCHBOT_TOKEN_TEST }} + uses: ./.github/workflows/_binary-upload.yml + + manywheel-py3_9-cuda12_4-build: + if: ${{ github.repository_owner == 'pytorch' }} + uses: ./.github/workflows/_binary-build-linux.yml + needs: get-label-type + with: + PYTORCH_ROOT: /pytorch + BUILDER_ROOT: /builder + PACKAGE_TYPE: manywheel + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: cu124 + GPU_ARCH_VERSION: 12.4 + GPU_ARCH_TYPE: cuda + DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.4-main + use_split_build: True + DESIRED_PYTHON: "3.9" + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" + build_name: manywheel-py3_9-cuda12_4 + build_environment: linux-binary-manywheel + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.4.5.8; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.2.1.3; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.5.147; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.6.1.9; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.3.1.170; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.21.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} + manywheel-py3_9-cuda12_4-test: # Testing + if: ${{ github.repository_owner == 'pytorch' }} + needs: + - manywheel-py3_9-cuda12_4-build + - get-label-type + uses: ./.github/workflows/_binary-test-linux.yml + with: + PYTORCH_ROOT: /pytorch + BUILDER_ROOT: /builder + PACKAGE_TYPE: manywheel + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: cu124 + GPU_ARCH_VERSION: 12.4 + GPU_ARCH_TYPE: cuda + DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.4-main + use_split_build: True + DESIRED_PYTHON: "3.9" + build_name: manywheel-py3_9-cuda12_4 + build_environment: linux-binary-manywheel + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" + runs_on: linux.4xlarge.nvidia.gpu + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} + manywheel-py3_9-cuda12_4-upload: # Uploading + if: ${{ github.repository_owner == 'pytorch' }} + permissions: + id-token: write + contents: read + needs: manywheel-py3_9-cuda12_4-test + with: + PYTORCH_ROOT: /pytorch + BUILDER_ROOT: /builder + PACKAGE_TYPE: manywheel + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: cu124 + GPU_ARCH_VERSION: 12.4 + GPU_ARCH_TYPE: cuda + DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.4-main + use_split_build: True + DESIRED_PYTHON: "3.9" + build_name: manywheel-py3_9-cuda12_4 + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} + conda-pytorchbot-token: ${{ secrets.CONDA_PYTORCHBOT_TOKEN }} + conda-pytorchbot-token-test: ${{ secrets.CONDA_PYTORCHBOT_TOKEN_TEST }} + uses: ./.github/workflows/_binary-upload.yml + + manywheel-py3_9-cpu-build: + if: ${{ github.repository_owner == 'pytorch' }} + uses: ./.github/workflows/_binary-build-linux.yml + needs: get-label-type + with: + PYTORCH_ROOT: /pytorch + BUILDER_ROOT: /builder + PACKAGE_TYPE: manywheel + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: cpu + GPU_ARCH_TYPE: cpu + DOCKER_IMAGE: pytorch/manylinux-builder:cpu-main + use_split_build: True + DESIRED_PYTHON: "3.9" + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" + build_name: manywheel-py3_9-cpu + build_environment: linux-binary-manywheel + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} + manywheel-py3_9-cpu-test: # Testing + if: ${{ github.repository_owner == 'pytorch' }} + needs: + - manywheel-py3_9-cpu-build + - get-label-type + uses: ./.github/workflows/_binary-test-linux.yml + with: + PYTORCH_ROOT: /pytorch + BUILDER_ROOT: /builder + PACKAGE_TYPE: manywheel + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: cpu + GPU_ARCH_TYPE: cpu + DOCKER_IMAGE: pytorch/manylinux-builder:cpu-main + use_split_build: True + DESIRED_PYTHON: "3.9" + build_name: manywheel-py3_9-cpu + build_environment: linux-binary-manywheel + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" + runs_on: linux.4xlarge + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} + manywheel-py3_9-cpu-upload: # Uploading + if: ${{ github.repository_owner == 'pytorch' }} + permissions: + id-token: write + contents: read + needs: manywheel-py3_9-cpu-test + with: + PYTORCH_ROOT: /pytorch + BUILDER_ROOT: /builder + PACKAGE_TYPE: manywheel + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: cpu + GPU_ARCH_TYPE: cpu + DOCKER_IMAGE: pytorch/manylinux-builder:cpu-main + use_split_build: True + DESIRED_PYTHON: "3.9" + build_name: manywheel-py3_9-cpu + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} + conda-pytorchbot-token: ${{ secrets.CONDA_PYTORCHBOT_TOKEN }} + conda-pytorchbot-token-test: ${{ secrets.CONDA_PYTORCHBOT_TOKEN_TEST }} + uses: ./.github/workflows/_binary-upload.yml + + manywheel-py3_10-cuda11_8-build: + if: ${{ github.repository_owner == 'pytorch' }} + uses: ./.github/workflows/_binary-build-linux.yml + needs: get-label-type + with: + PYTORCH_ROOT: /pytorch + BUILDER_ROOT: /builder + PACKAGE_TYPE: manywheel + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: cu118 + GPU_ARCH_VERSION: 11.8 + GPU_ARCH_TYPE: cuda + DOCKER_IMAGE: pytorch/manylinux-builder:cuda11.8-main + use_split_build: True + DESIRED_PYTHON: "3.10" + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" + build_name: manywheel-py3_10-cuda11_8 + build_environment: linux-binary-manywheel + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu11==11.8.89; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu11==11.8.89; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu11==11.8.87; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu11==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu11==11.11.3.6; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu11==10.9.0.58; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu11==10.3.0.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu11==11.4.1.48; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu11==11.7.5.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu11==2.21.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu11==11.8.86; platform_system == 'Linux' and platform_machine == 'x86_64' + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} + manywheel-py3_10-cuda11_8-test: # Testing + if: ${{ github.repository_owner == 'pytorch' }} + needs: + - manywheel-py3_10-cuda11_8-build + - get-label-type + uses: ./.github/workflows/_binary-test-linux.yml + with: + PYTORCH_ROOT: /pytorch + BUILDER_ROOT: /builder + PACKAGE_TYPE: manywheel + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: cu118 + GPU_ARCH_VERSION: 11.8 + GPU_ARCH_TYPE: cuda + DOCKER_IMAGE: pytorch/manylinux-builder:cuda11.8-main + use_split_build: True + DESIRED_PYTHON: "3.10" + build_name: manywheel-py3_10-cuda11_8 + build_environment: linux-binary-manywheel + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" + runs_on: linux.4xlarge.nvidia.gpu + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} + manywheel-py3_10-cuda11_8-upload: # Uploading + if: ${{ github.repository_owner == 'pytorch' }} + permissions: + id-token: write + contents: read + needs: manywheel-py3_10-cuda11_8-test + with: + PYTORCH_ROOT: /pytorch + BUILDER_ROOT: /builder + PACKAGE_TYPE: manywheel + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: cu118 + GPU_ARCH_VERSION: 11.8 + GPU_ARCH_TYPE: cuda + DOCKER_IMAGE: pytorch/manylinux-builder:cuda11.8-main + use_split_build: True + DESIRED_PYTHON: "3.10" + build_name: manywheel-py3_10-cuda11_8 + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} + conda-pytorchbot-token: ${{ secrets.CONDA_PYTORCHBOT_TOKEN }} + conda-pytorchbot-token-test: ${{ secrets.CONDA_PYTORCHBOT_TOKEN_TEST }} + uses: ./.github/workflows/_binary-upload.yml + + manywheel-py3_10-cuda12_1-build: + if: ${{ github.repository_owner == 'pytorch' }} + uses: ./.github/workflows/_binary-build-linux.yml + needs: get-label-type + with: + PYTORCH_ROOT: /pytorch + BUILDER_ROOT: /builder + PACKAGE_TYPE: manywheel + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: cu121 + GPU_ARCH_VERSION: 12.1 + GPU_ARCH_TYPE: cuda + DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.1-main + use_split_build: True + DESIRED_PYTHON: "3.10" + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" + build_name: manywheel-py3_10-cuda12_1 + build_environment: linux-binary-manywheel + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.21.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} + manywheel-py3_10-cuda12_1-test: # Testing + if: ${{ github.repository_owner == 'pytorch' }} + needs: + - manywheel-py3_10-cuda12_1-build + - get-label-type + uses: ./.github/workflows/_binary-test-linux.yml + with: + PYTORCH_ROOT: /pytorch + BUILDER_ROOT: /builder + PACKAGE_TYPE: manywheel + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: cu121 + GPU_ARCH_VERSION: 12.1 + GPU_ARCH_TYPE: cuda + DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.1-main + use_split_build: True + DESIRED_PYTHON: "3.10" + build_name: manywheel-py3_10-cuda12_1 + build_environment: linux-binary-manywheel + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" + runs_on: linux.4xlarge.nvidia.gpu + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} + manywheel-py3_10-cuda12_1-upload: # Uploading + if: ${{ github.repository_owner == 'pytorch' }} + permissions: + id-token: write + contents: read + needs: manywheel-py3_10-cuda12_1-test + with: + PYTORCH_ROOT: /pytorch + BUILDER_ROOT: /builder + PACKAGE_TYPE: manywheel + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: cu121 + GPU_ARCH_VERSION: 12.1 + GPU_ARCH_TYPE: cuda + DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.1-main + use_split_build: True + DESIRED_PYTHON: "3.10" + build_name: manywheel-py3_10-cuda12_1 + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} + conda-pytorchbot-token: ${{ secrets.CONDA_PYTORCHBOT_TOKEN }} + conda-pytorchbot-token-test: ${{ secrets.CONDA_PYTORCHBOT_TOKEN_TEST }} + uses: ./.github/workflows/_binary-upload.yml + + manywheel-py3_10-cuda12_1-full-build: + if: ${{ github.repository_owner == 'pytorch' }} + uses: ./.github/workflows/_binary-build-linux.yml + needs: get-label-type + with: + PYTORCH_ROOT: /pytorch + BUILDER_ROOT: /builder + PACKAGE_TYPE: manywheel + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: cu121 + GPU_ARCH_VERSION: 12.1 + GPU_ARCH_TYPE: cuda + DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.1-main + use_split_build: True + DESIRED_PYTHON: "3.10" + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" + build_name: manywheel-py3_10-cuda12_1-full + build_environment: linux-binary-manywheel + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} + manywheel-py3_10-cuda12_1-full-test: # Testing + if: ${{ github.repository_owner == 'pytorch' }} + needs: + - manywheel-py3_10-cuda12_1-full-build + - get-label-type + uses: ./.github/workflows/_binary-test-linux.yml + with: + PYTORCH_ROOT: /pytorch + BUILDER_ROOT: /builder + PACKAGE_TYPE: manywheel + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: cu121 + GPU_ARCH_VERSION: 12.1 + GPU_ARCH_TYPE: cuda + DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.1-main + use_split_build: True + DESIRED_PYTHON: "3.10" + build_name: manywheel-py3_10-cuda12_1-full + build_environment: linux-binary-manywheel + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" + runs_on: linux.4xlarge.nvidia.gpu + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} + manywheel-py3_10-cuda12_1-full-upload: # Uploading + if: ${{ github.repository_owner == 'pytorch' }} + permissions: + id-token: write + contents: read + needs: manywheel-py3_10-cuda12_1-full-test + with: + PYTORCH_ROOT: /pytorch + BUILDER_ROOT: /builder + PACKAGE_TYPE: manywheel + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: cu121 + GPU_ARCH_VERSION: 12.1 + GPU_ARCH_TYPE: cuda + DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.1-main + use_split_build: True + DESIRED_PYTHON: "3.10" + build_name: manywheel-py3_10-cuda12_1-full + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} + conda-pytorchbot-token: ${{ secrets.CONDA_PYTORCHBOT_TOKEN }} + conda-pytorchbot-token-test: ${{ secrets.CONDA_PYTORCHBOT_TOKEN_TEST }} + uses: ./.github/workflows/_binary-upload.yml + + manywheel-py3_10-cuda12_4-build: + if: ${{ github.repository_owner == 'pytorch' }} + uses: ./.github/workflows/_binary-build-linux.yml + needs: get-label-type + with: + PYTORCH_ROOT: /pytorch + BUILDER_ROOT: /builder + PACKAGE_TYPE: manywheel + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: cu124 + GPU_ARCH_VERSION: 12.4 + GPU_ARCH_TYPE: cuda + DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.4-main + use_split_build: True + DESIRED_PYTHON: "3.10" + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" + build_name: manywheel-py3_10-cuda12_4 + build_environment: linux-binary-manywheel + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.4.5.8; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.2.1.3; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.5.147; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.6.1.9; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.3.1.170; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.21.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} + manywheel-py3_10-cuda12_4-test: # Testing + if: ${{ github.repository_owner == 'pytorch' }} + needs: + - manywheel-py3_10-cuda12_4-build + - get-label-type + uses: ./.github/workflows/_binary-test-linux.yml + with: + PYTORCH_ROOT: /pytorch + BUILDER_ROOT: /builder + PACKAGE_TYPE: manywheel + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: cu124 + GPU_ARCH_VERSION: 12.4 + GPU_ARCH_TYPE: cuda + DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.4-main + use_split_build: True + DESIRED_PYTHON: "3.10" + build_name: manywheel-py3_10-cuda12_4 + build_environment: linux-binary-manywheel + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" + runs_on: linux.4xlarge.nvidia.gpu + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} + manywheel-py3_10-cuda12_4-upload: # Uploading + if: ${{ github.repository_owner == 'pytorch' }} + permissions: + id-token: write + contents: read + needs: manywheel-py3_10-cuda12_4-test + with: + PYTORCH_ROOT: /pytorch + BUILDER_ROOT: /builder + PACKAGE_TYPE: manywheel + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: cu124 + GPU_ARCH_VERSION: 12.4 + GPU_ARCH_TYPE: cuda + DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.4-main + use_split_build: True + DESIRED_PYTHON: "3.10" + build_name: manywheel-py3_10-cuda12_4 + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} + conda-pytorchbot-token: ${{ secrets.CONDA_PYTORCHBOT_TOKEN }} + conda-pytorchbot-token-test: ${{ secrets.CONDA_PYTORCHBOT_TOKEN_TEST }} + uses: ./.github/workflows/_binary-upload.yml + + manywheel-py3_10-cpu-build: + if: ${{ github.repository_owner == 'pytorch' }} + uses: ./.github/workflows/_binary-build-linux.yml + needs: get-label-type + with: + PYTORCH_ROOT: /pytorch + BUILDER_ROOT: /builder + PACKAGE_TYPE: manywheel + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: cpu + GPU_ARCH_TYPE: cpu + DOCKER_IMAGE: pytorch/manylinux-builder:cpu-main + use_split_build: True + DESIRED_PYTHON: "3.10" + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" + build_name: manywheel-py3_10-cpu + build_environment: linux-binary-manywheel + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} + manywheel-py3_10-cpu-test: # Testing + if: ${{ github.repository_owner == 'pytorch' }} + needs: + - manywheel-py3_10-cpu-build + - get-label-type + uses: ./.github/workflows/_binary-test-linux.yml + with: + PYTORCH_ROOT: /pytorch + BUILDER_ROOT: /builder + PACKAGE_TYPE: manywheel + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: cpu + GPU_ARCH_TYPE: cpu + DOCKER_IMAGE: pytorch/manylinux-builder:cpu-main + use_split_build: True + DESIRED_PYTHON: "3.10" + build_name: manywheel-py3_10-cpu + build_environment: linux-binary-manywheel + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" + runs_on: linux.4xlarge + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} + manywheel-py3_10-cpu-upload: # Uploading + if: ${{ github.repository_owner == 'pytorch' }} + permissions: + id-token: write + contents: read + needs: manywheel-py3_10-cpu-test + with: + PYTORCH_ROOT: /pytorch + BUILDER_ROOT: /builder + PACKAGE_TYPE: manywheel + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: cpu + GPU_ARCH_TYPE: cpu + DOCKER_IMAGE: pytorch/manylinux-builder:cpu-main + use_split_build: True + DESIRED_PYTHON: "3.10" + build_name: manywheel-py3_10-cpu + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} + conda-pytorchbot-token: ${{ secrets.CONDA_PYTORCHBOT_TOKEN }} + conda-pytorchbot-token-test: ${{ secrets.CONDA_PYTORCHBOT_TOKEN_TEST }} + uses: ./.github/workflows/_binary-upload.yml + + manywheel-py3_11-cuda11_8-build: + if: ${{ github.repository_owner == 'pytorch' }} + uses: ./.github/workflows/_binary-build-linux.yml + needs: get-label-type + with: + PYTORCH_ROOT: /pytorch + BUILDER_ROOT: /builder + PACKAGE_TYPE: manywheel + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: cu118 + GPU_ARCH_VERSION: 11.8 + GPU_ARCH_TYPE: cuda + DOCKER_IMAGE: pytorch/manylinux-builder:cuda11.8-main + use_split_build: True + DESIRED_PYTHON: "3.11" + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" + build_name: manywheel-py3_11-cuda11_8 + build_environment: linux-binary-manywheel + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu11==11.8.89; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu11==11.8.89; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu11==11.8.87; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu11==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu11==11.11.3.6; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu11==10.9.0.58; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu11==10.3.0.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu11==11.4.1.48; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu11==11.7.5.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu11==2.21.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu11==11.8.86; platform_system == 'Linux' and platform_machine == 'x86_64' + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} + manywheel-py3_11-cuda11_8-test: # Testing + if: ${{ github.repository_owner == 'pytorch' }} + needs: + - manywheel-py3_11-cuda11_8-build + - get-label-type + uses: ./.github/workflows/_binary-test-linux.yml + with: + PYTORCH_ROOT: /pytorch + BUILDER_ROOT: /builder + PACKAGE_TYPE: manywheel + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: cu118 + GPU_ARCH_VERSION: 11.8 + GPU_ARCH_TYPE: cuda + DOCKER_IMAGE: pytorch/manylinux-builder:cuda11.8-main + use_split_build: True + DESIRED_PYTHON: "3.11" + build_name: manywheel-py3_11-cuda11_8 + build_environment: linux-binary-manywheel + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" + runs_on: linux.4xlarge.nvidia.gpu + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} + manywheel-py3_11-cuda11_8-upload: # Uploading + if: ${{ github.repository_owner == 'pytorch' }} + permissions: + id-token: write + contents: read + needs: manywheel-py3_11-cuda11_8-test + with: + PYTORCH_ROOT: /pytorch + BUILDER_ROOT: /builder + PACKAGE_TYPE: manywheel + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: cu118 + GPU_ARCH_VERSION: 11.8 + GPU_ARCH_TYPE: cuda + DOCKER_IMAGE: pytorch/manylinux-builder:cuda11.8-main + use_split_build: True + DESIRED_PYTHON: "3.11" + build_name: manywheel-py3_11-cuda11_8 + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} + conda-pytorchbot-token: ${{ secrets.CONDA_PYTORCHBOT_TOKEN }} + conda-pytorchbot-token-test: ${{ secrets.CONDA_PYTORCHBOT_TOKEN_TEST }} + uses: ./.github/workflows/_binary-upload.yml + + manywheel-py3_11-cuda12_1-build: + if: ${{ github.repository_owner == 'pytorch' }} + uses: ./.github/workflows/_binary-build-linux.yml + needs: get-label-type + with: + PYTORCH_ROOT: /pytorch + BUILDER_ROOT: /builder + PACKAGE_TYPE: manywheel + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: cu121 + GPU_ARCH_VERSION: 12.1 + GPU_ARCH_TYPE: cuda + DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.1-main + use_split_build: True + DESIRED_PYTHON: "3.11" + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" + build_name: manywheel-py3_11-cuda12_1 + build_environment: linux-binary-manywheel + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.21.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} + manywheel-py3_11-cuda12_1-test: # Testing + if: ${{ github.repository_owner == 'pytorch' }} + needs: + - manywheel-py3_11-cuda12_1-build + - get-label-type + uses: ./.github/workflows/_binary-test-linux.yml + with: + PYTORCH_ROOT: /pytorch + BUILDER_ROOT: /builder + PACKAGE_TYPE: manywheel + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: cu121 + GPU_ARCH_VERSION: 12.1 + GPU_ARCH_TYPE: cuda + DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.1-main + use_split_build: True + DESIRED_PYTHON: "3.11" + build_name: manywheel-py3_11-cuda12_1 + build_environment: linux-binary-manywheel + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" + runs_on: linux.4xlarge.nvidia.gpu + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} + manywheel-py3_11-cuda12_1-upload: # Uploading + if: ${{ github.repository_owner == 'pytorch' }} + permissions: + id-token: write + contents: read + needs: manywheel-py3_11-cuda12_1-test + with: + PYTORCH_ROOT: /pytorch + BUILDER_ROOT: /builder + PACKAGE_TYPE: manywheel + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: cu121 + GPU_ARCH_VERSION: 12.1 + GPU_ARCH_TYPE: cuda + DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.1-main + use_split_build: True + DESIRED_PYTHON: "3.11" + build_name: manywheel-py3_11-cuda12_1 + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} + conda-pytorchbot-token: ${{ secrets.CONDA_PYTORCHBOT_TOKEN }} + conda-pytorchbot-token-test: ${{ secrets.CONDA_PYTORCHBOT_TOKEN_TEST }} + uses: ./.github/workflows/_binary-upload.yml + + manywheel-py3_11-cuda12_4-build: + if: ${{ github.repository_owner == 'pytorch' }} + uses: ./.github/workflows/_binary-build-linux.yml + needs: get-label-type + with: + PYTORCH_ROOT: /pytorch + BUILDER_ROOT: /builder + PACKAGE_TYPE: manywheel + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: cu124 + GPU_ARCH_VERSION: 12.4 + GPU_ARCH_TYPE: cuda + DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.4-main + use_split_build: True + DESIRED_PYTHON: "3.11" + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" + build_name: manywheel-py3_11-cuda12_4 + build_environment: linux-binary-manywheel + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.4.5.8; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.2.1.3; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.5.147; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.6.1.9; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.3.1.170; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.21.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} + manywheel-py3_11-cuda12_4-test: # Testing + if: ${{ github.repository_owner == 'pytorch' }} + needs: + - manywheel-py3_11-cuda12_4-build + - get-label-type + uses: ./.github/workflows/_binary-test-linux.yml + with: + PYTORCH_ROOT: /pytorch + BUILDER_ROOT: /builder + PACKAGE_TYPE: manywheel + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: cu124 + GPU_ARCH_VERSION: 12.4 + GPU_ARCH_TYPE: cuda + DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.4-main + use_split_build: True + DESIRED_PYTHON: "3.11" + build_name: manywheel-py3_11-cuda12_4 + build_environment: linux-binary-manywheel + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" + runs_on: linux.4xlarge.nvidia.gpu + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} + manywheel-py3_11-cuda12_4-upload: # Uploading + if: ${{ github.repository_owner == 'pytorch' }} + permissions: + id-token: write + contents: read + needs: manywheel-py3_11-cuda12_4-test + with: + PYTORCH_ROOT: /pytorch + BUILDER_ROOT: /builder + PACKAGE_TYPE: manywheel + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: cu124 + GPU_ARCH_VERSION: 12.4 + GPU_ARCH_TYPE: cuda + DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.4-main + use_split_build: True + DESIRED_PYTHON: "3.11" + build_name: manywheel-py3_11-cuda12_4 + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} + conda-pytorchbot-token: ${{ secrets.CONDA_PYTORCHBOT_TOKEN }} + conda-pytorchbot-token-test: ${{ secrets.CONDA_PYTORCHBOT_TOKEN_TEST }} + uses: ./.github/workflows/_binary-upload.yml + + manywheel-py3_11-cpu-build: + if: ${{ github.repository_owner == 'pytorch' }} + uses: ./.github/workflows/_binary-build-linux.yml + needs: get-label-type + with: + PYTORCH_ROOT: /pytorch + BUILDER_ROOT: /builder + PACKAGE_TYPE: manywheel + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: cpu + GPU_ARCH_TYPE: cpu + DOCKER_IMAGE: pytorch/manylinux-builder:cpu-main + use_split_build: True + DESIRED_PYTHON: "3.11" + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" + build_name: manywheel-py3_11-cpu + build_environment: linux-binary-manywheel + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} + manywheel-py3_11-cpu-test: # Testing + if: ${{ github.repository_owner == 'pytorch' }} + needs: + - manywheel-py3_11-cpu-build + - get-label-type + uses: ./.github/workflows/_binary-test-linux.yml + with: + PYTORCH_ROOT: /pytorch + BUILDER_ROOT: /builder + PACKAGE_TYPE: manywheel + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: cpu + GPU_ARCH_TYPE: cpu + DOCKER_IMAGE: pytorch/manylinux-builder:cpu-main + use_split_build: True + DESIRED_PYTHON: "3.11" + build_name: manywheel-py3_11-cpu + build_environment: linux-binary-manywheel + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" + runs_on: linux.4xlarge + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} + manywheel-py3_11-cpu-upload: # Uploading + if: ${{ github.repository_owner == 'pytorch' }} + permissions: + id-token: write + contents: read + needs: manywheel-py3_11-cpu-test + with: + PYTORCH_ROOT: /pytorch + BUILDER_ROOT: /builder + PACKAGE_TYPE: manywheel + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: cpu + GPU_ARCH_TYPE: cpu + DOCKER_IMAGE: pytorch/manylinux-builder:cpu-main + use_split_build: True + DESIRED_PYTHON: "3.11" + build_name: manywheel-py3_11-cpu + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} + conda-pytorchbot-token: ${{ secrets.CONDA_PYTORCHBOT_TOKEN }} + conda-pytorchbot-token-test: ${{ secrets.CONDA_PYTORCHBOT_TOKEN_TEST }} + uses: ./.github/workflows/_binary-upload.yml + + manywheel-py3_12-cuda11_8-build: + if: ${{ github.repository_owner == 'pytorch' }} + uses: ./.github/workflows/_binary-build-linux.yml + needs: get-label-type + with: + PYTORCH_ROOT: /pytorch + BUILDER_ROOT: /builder + PACKAGE_TYPE: manywheel + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: cu118 + GPU_ARCH_VERSION: 11.8 + GPU_ARCH_TYPE: cuda + DOCKER_IMAGE: pytorch/manylinux-builder:cuda11.8-main + use_split_build: True + DESIRED_PYTHON: "3.12" + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" + build_name: manywheel-py3_12-cuda11_8 + build_environment: linux-binary-manywheel + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu11==11.8.89; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu11==11.8.89; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu11==11.8.87; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu11==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu11==11.11.3.6; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu11==10.9.0.58; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu11==10.3.0.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu11==11.4.1.48; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu11==11.7.5.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu11==2.21.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu11==11.8.86; platform_system == 'Linux' and platform_machine == 'x86_64' + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} + manywheel-py3_12-cuda11_8-test: # Testing + if: ${{ github.repository_owner == 'pytorch' }} + needs: + - manywheel-py3_12-cuda11_8-build + - get-label-type + uses: ./.github/workflows/_binary-test-linux.yml + with: + PYTORCH_ROOT: /pytorch + BUILDER_ROOT: /builder + PACKAGE_TYPE: manywheel + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: cu118 + GPU_ARCH_VERSION: 11.8 + GPU_ARCH_TYPE: cuda + DOCKER_IMAGE: pytorch/manylinux-builder:cuda11.8-main + use_split_build: True + DESIRED_PYTHON: "3.12" + build_name: manywheel-py3_12-cuda11_8 + build_environment: linux-binary-manywheel + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" + runs_on: linux.4xlarge.nvidia.gpu + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} + manywheel-py3_12-cuda11_8-upload: # Uploading + if: ${{ github.repository_owner == 'pytorch' }} + permissions: + id-token: write + contents: read + needs: manywheel-py3_12-cuda11_8-test + with: + PYTORCH_ROOT: /pytorch + BUILDER_ROOT: /builder + PACKAGE_TYPE: manywheel + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: cu118 + GPU_ARCH_VERSION: 11.8 + GPU_ARCH_TYPE: cuda + DOCKER_IMAGE: pytorch/manylinux-builder:cuda11.8-main + use_split_build: True + DESIRED_PYTHON: "3.12" + build_name: manywheel-py3_12-cuda11_8 + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} + conda-pytorchbot-token: ${{ secrets.CONDA_PYTORCHBOT_TOKEN }} + conda-pytorchbot-token-test: ${{ secrets.CONDA_PYTORCHBOT_TOKEN_TEST }} + uses: ./.github/workflows/_binary-upload.yml + + manywheel-py3_12-cuda12_1-build: + if: ${{ github.repository_owner == 'pytorch' }} + uses: ./.github/workflows/_binary-build-linux.yml + needs: get-label-type + with: + PYTORCH_ROOT: /pytorch + BUILDER_ROOT: /builder + PACKAGE_TYPE: manywheel + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: cu121 + GPU_ARCH_VERSION: 12.1 + GPU_ARCH_TYPE: cuda + DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.1-main + use_split_build: True + DESIRED_PYTHON: "3.12" + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" + build_name: manywheel-py3_12-cuda12_1 + build_environment: linux-binary-manywheel + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.21.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} + manywheel-py3_12-cuda12_1-test: # Testing + if: ${{ github.repository_owner == 'pytorch' }} + needs: + - manywheel-py3_12-cuda12_1-build + - get-label-type + uses: ./.github/workflows/_binary-test-linux.yml + with: + PYTORCH_ROOT: /pytorch + BUILDER_ROOT: /builder + PACKAGE_TYPE: manywheel + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: cu121 + GPU_ARCH_VERSION: 12.1 + GPU_ARCH_TYPE: cuda + DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.1-main + use_split_build: True + DESIRED_PYTHON: "3.12" + build_name: manywheel-py3_12-cuda12_1 + build_environment: linux-binary-manywheel + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" + runs_on: linux.4xlarge.nvidia.gpu + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} + manywheel-py3_12-cuda12_1-upload: # Uploading + if: ${{ github.repository_owner == 'pytorch' }} + permissions: + id-token: write + contents: read + needs: manywheel-py3_12-cuda12_1-test + with: + PYTORCH_ROOT: /pytorch + BUILDER_ROOT: /builder + PACKAGE_TYPE: manywheel + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: cu121 + GPU_ARCH_VERSION: 12.1 + GPU_ARCH_TYPE: cuda + DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.1-main + use_split_build: True + DESIRED_PYTHON: "3.12" + build_name: manywheel-py3_12-cuda12_1 + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} + conda-pytorchbot-token: ${{ secrets.CONDA_PYTORCHBOT_TOKEN }} + conda-pytorchbot-token-test: ${{ secrets.CONDA_PYTORCHBOT_TOKEN_TEST }} + uses: ./.github/workflows/_binary-upload.yml + + manywheel-py3_12-cuda12_4-build: + if: ${{ github.repository_owner == 'pytorch' }} + uses: ./.github/workflows/_binary-build-linux.yml + needs: get-label-type + with: + PYTORCH_ROOT: /pytorch + BUILDER_ROOT: /builder + PACKAGE_TYPE: manywheel + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: cu124 + GPU_ARCH_VERSION: 12.4 + GPU_ARCH_TYPE: cuda + DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.4-main + use_split_build: True + DESIRED_PYTHON: "3.12" + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" + build_name: manywheel-py3_12-cuda12_4 + build_environment: linux-binary-manywheel + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.4.5.8; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.2.1.3; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.5.147; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.6.1.9; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.3.1.170; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.21.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} + manywheel-py3_12-cuda12_4-test: # Testing + if: ${{ github.repository_owner == 'pytorch' }} + needs: + - manywheel-py3_12-cuda12_4-build + - get-label-type + uses: ./.github/workflows/_binary-test-linux.yml + with: + PYTORCH_ROOT: /pytorch + BUILDER_ROOT: /builder + PACKAGE_TYPE: manywheel + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: cu124 + GPU_ARCH_VERSION: 12.4 + GPU_ARCH_TYPE: cuda + DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.4-main + use_split_build: True + DESIRED_PYTHON: "3.12" + build_name: manywheel-py3_12-cuda12_4 + build_environment: linux-binary-manywheel + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" + runs_on: linux.4xlarge.nvidia.gpu + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} + manywheel-py3_12-cuda12_4-upload: # Uploading + if: ${{ github.repository_owner == 'pytorch' }} + permissions: + id-token: write + contents: read + needs: manywheel-py3_12-cuda12_4-test + with: + PYTORCH_ROOT: /pytorch + BUILDER_ROOT: /builder + PACKAGE_TYPE: manywheel + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: cu124 + GPU_ARCH_VERSION: 12.4 + GPU_ARCH_TYPE: cuda + DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.4-main + use_split_build: True + DESIRED_PYTHON: "3.12" + build_name: manywheel-py3_12-cuda12_4 + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} + conda-pytorchbot-token: ${{ secrets.CONDA_PYTORCHBOT_TOKEN }} + conda-pytorchbot-token-test: ${{ secrets.CONDA_PYTORCHBOT_TOKEN_TEST }} + uses: ./.github/workflows/_binary-upload.yml + + manywheel-py3_12-cpu-build: + if: ${{ github.repository_owner == 'pytorch' }} + uses: ./.github/workflows/_binary-build-linux.yml + needs: get-label-type + with: + PYTORCH_ROOT: /pytorch + BUILDER_ROOT: /builder + PACKAGE_TYPE: manywheel + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: cpu + GPU_ARCH_TYPE: cpu + DOCKER_IMAGE: pytorch/manylinux-builder:cpu-main + use_split_build: True + DESIRED_PYTHON: "3.12" + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" + build_name: manywheel-py3_12-cpu + build_environment: linux-binary-manywheel + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} + manywheel-py3_12-cpu-test: # Testing + if: ${{ github.repository_owner == 'pytorch' }} + needs: + - manywheel-py3_12-cpu-build + - get-label-type + uses: ./.github/workflows/_binary-test-linux.yml + with: + PYTORCH_ROOT: /pytorch + BUILDER_ROOT: /builder + PACKAGE_TYPE: manywheel + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: cpu + GPU_ARCH_TYPE: cpu + DOCKER_IMAGE: pytorch/manylinux-builder:cpu-main + use_split_build: True + DESIRED_PYTHON: "3.12" + build_name: manywheel-py3_12-cpu + build_environment: linux-binary-manywheel + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" + runs_on: linux.4xlarge + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} + manywheel-py3_12-cpu-upload: # Uploading + if: ${{ github.repository_owner == 'pytorch' }} + permissions: + id-token: write + contents: read + needs: manywheel-py3_12-cpu-test + with: + PYTORCH_ROOT: /pytorch + BUILDER_ROOT: /builder + PACKAGE_TYPE: manywheel + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: cpu + GPU_ARCH_TYPE: cpu + DOCKER_IMAGE: pytorch/manylinux-builder:cpu-main + use_split_build: True + DESIRED_PYTHON: "3.12" + build_name: manywheel-py3_12-cpu + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} + conda-pytorchbot-token: ${{ secrets.CONDA_PYTORCHBOT_TOKEN }} + conda-pytorchbot-token-test: ${{ secrets.CONDA_PYTORCHBOT_TOKEN_TEST }} + uses: ./.github/workflows/_binary-upload.yml + + manywheel-py3_13-cuda11_8-build: + if: ${{ github.repository_owner == 'pytorch' }} + uses: ./.github/workflows/_binary-build-linux.yml + needs: get-label-type + with: + PYTORCH_ROOT: /pytorch + BUILDER_ROOT: /builder + PACKAGE_TYPE: manywheel + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: cu118 + GPU_ARCH_VERSION: 11.8 + GPU_ARCH_TYPE: cuda + DOCKER_IMAGE: pytorch/manylinux-builder:cuda11.8-main + use_split_build: True + DESIRED_PYTHON: "3.13" + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" + build_name: manywheel-py3_13-cuda11_8 + build_environment: linux-binary-manywheel + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu11==11.8.89; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu11==11.8.89; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu11==11.8.87; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu11==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu11==11.11.3.6; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu11==10.9.0.58; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu11==10.3.0.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu11==11.4.1.48; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu11==11.7.5.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu11==2.21.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu11==11.8.86; platform_system == 'Linux' and platform_machine == 'x86_64' + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} + manywheel-py3_13-cuda11_8-test: # Testing + if: ${{ github.repository_owner == 'pytorch' }} + needs: + - manywheel-py3_13-cuda11_8-build + - get-label-type + uses: ./.github/workflows/_binary-test-linux.yml + with: + PYTORCH_ROOT: /pytorch + BUILDER_ROOT: /builder + PACKAGE_TYPE: manywheel + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: cu118 + GPU_ARCH_VERSION: 11.8 + GPU_ARCH_TYPE: cuda + DOCKER_IMAGE: pytorch/manylinux-builder:cuda11.8-main + use_split_build: True + DESIRED_PYTHON: "3.13" + build_name: manywheel-py3_13-cuda11_8 + build_environment: linux-binary-manywheel + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" + runs_on: linux.4xlarge.nvidia.gpu + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} + manywheel-py3_13-cuda11_8-upload: # Uploading + if: ${{ github.repository_owner == 'pytorch' }} + permissions: + id-token: write + contents: read + needs: manywheel-py3_13-cuda11_8-test + with: + PYTORCH_ROOT: /pytorch + BUILDER_ROOT: /builder + PACKAGE_TYPE: manywheel + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: cu118 + GPU_ARCH_VERSION: 11.8 + GPU_ARCH_TYPE: cuda + DOCKER_IMAGE: pytorch/manylinux-builder:cuda11.8-main + use_split_build: True + DESIRED_PYTHON: "3.13" + build_name: manywheel-py3_13-cuda11_8 + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} + conda-pytorchbot-token: ${{ secrets.CONDA_PYTORCHBOT_TOKEN }} + conda-pytorchbot-token-test: ${{ secrets.CONDA_PYTORCHBOT_TOKEN_TEST }} + uses: ./.github/workflows/_binary-upload.yml + + manywheel-py3_13-cuda12_1-build: + if: ${{ github.repository_owner == 'pytorch' }} + uses: ./.github/workflows/_binary-build-linux.yml + needs: get-label-type + with: + PYTORCH_ROOT: /pytorch + BUILDER_ROOT: /builder + PACKAGE_TYPE: manywheel + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: cu121 + GPU_ARCH_VERSION: 12.1 + GPU_ARCH_TYPE: cuda + DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.1-main + use_split_build: True + DESIRED_PYTHON: "3.13" + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" + build_name: manywheel-py3_13-cuda12_1 + build_environment: linux-binary-manywheel + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.21.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} + manywheel-py3_13-cuda12_1-test: # Testing + if: ${{ github.repository_owner == 'pytorch' }} + needs: + - manywheel-py3_13-cuda12_1-build + - get-label-type + uses: ./.github/workflows/_binary-test-linux.yml + with: + PYTORCH_ROOT: /pytorch + BUILDER_ROOT: /builder + PACKAGE_TYPE: manywheel + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: cu121 + GPU_ARCH_VERSION: 12.1 + GPU_ARCH_TYPE: cuda + DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.1-main + use_split_build: True + DESIRED_PYTHON: "3.13" + build_name: manywheel-py3_13-cuda12_1 + build_environment: linux-binary-manywheel + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" + runs_on: linux.4xlarge.nvidia.gpu + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} + manywheel-py3_13-cuda12_1-upload: # Uploading + if: ${{ github.repository_owner == 'pytorch' }} + permissions: + id-token: write + contents: read + needs: manywheel-py3_13-cuda12_1-test + with: + PYTORCH_ROOT: /pytorch + BUILDER_ROOT: /builder + PACKAGE_TYPE: manywheel + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: cu121 + GPU_ARCH_VERSION: 12.1 + GPU_ARCH_TYPE: cuda + DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.1-main + use_split_build: True + DESIRED_PYTHON: "3.13" + build_name: manywheel-py3_13-cuda12_1 + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} + conda-pytorchbot-token: ${{ secrets.CONDA_PYTORCHBOT_TOKEN }} + conda-pytorchbot-token-test: ${{ secrets.CONDA_PYTORCHBOT_TOKEN_TEST }} + uses: ./.github/workflows/_binary-upload.yml + + manywheel-py3_13-cuda12_4-build: + if: ${{ github.repository_owner == 'pytorch' }} + uses: ./.github/workflows/_binary-build-linux.yml + needs: get-label-type + with: + PYTORCH_ROOT: /pytorch + BUILDER_ROOT: /builder + PACKAGE_TYPE: manywheel + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: cu124 + GPU_ARCH_VERSION: 12.4 + GPU_ARCH_TYPE: cuda + DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.4-main + use_split_build: True + DESIRED_PYTHON: "3.13" + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" + build_name: manywheel-py3_13-cuda12_4 + build_environment: linux-binary-manywheel + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.4.5.8; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.2.1.3; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.5.147; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.6.1.9; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.3.1.170; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.21.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} + manywheel-py3_13-cuda12_4-test: # Testing + if: ${{ github.repository_owner == 'pytorch' }} + needs: + - manywheel-py3_13-cuda12_4-build + - get-label-type + uses: ./.github/workflows/_binary-test-linux.yml + with: + PYTORCH_ROOT: /pytorch + BUILDER_ROOT: /builder + PACKAGE_TYPE: manywheel + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: cu124 + GPU_ARCH_VERSION: 12.4 + GPU_ARCH_TYPE: cuda + DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.4-main + use_split_build: True + DESIRED_PYTHON: "3.13" + build_name: manywheel-py3_13-cuda12_4 + build_environment: linux-binary-manywheel + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" + runs_on: linux.4xlarge.nvidia.gpu + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} + manywheel-py3_13-cuda12_4-upload: # Uploading + if: ${{ github.repository_owner == 'pytorch' }} + permissions: + id-token: write + contents: read + needs: manywheel-py3_13-cuda12_4-test + with: + PYTORCH_ROOT: /pytorch + BUILDER_ROOT: /builder + PACKAGE_TYPE: manywheel + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: cu124 + GPU_ARCH_VERSION: 12.4 + GPU_ARCH_TYPE: cuda + DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.4-main + use_split_build: True + DESIRED_PYTHON: "3.13" + build_name: manywheel-py3_13-cuda12_4 + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} + conda-pytorchbot-token: ${{ secrets.CONDA_PYTORCHBOT_TOKEN }} + conda-pytorchbot-token-test: ${{ secrets.CONDA_PYTORCHBOT_TOKEN_TEST }} + uses: ./.github/workflows/_binary-upload.yml + + manywheel-py3_13-cpu-build: + if: ${{ github.repository_owner == 'pytorch' }} + uses: ./.github/workflows/_binary-build-linux.yml + needs: get-label-type + with: + PYTORCH_ROOT: /pytorch + BUILDER_ROOT: /builder + PACKAGE_TYPE: manywheel + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: cpu + GPU_ARCH_TYPE: cpu + DOCKER_IMAGE: pytorch/manylinux-builder:cpu-main + use_split_build: True + DESIRED_PYTHON: "3.13" + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" + build_name: manywheel-py3_13-cpu + build_environment: linux-binary-manywheel + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} + manywheel-py3_13-cpu-test: # Testing + if: ${{ github.repository_owner == 'pytorch' }} + needs: + - manywheel-py3_13-cpu-build + - get-label-type + uses: ./.github/workflows/_binary-test-linux.yml + with: + PYTORCH_ROOT: /pytorch + BUILDER_ROOT: /builder + PACKAGE_TYPE: manywheel + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: cpu + GPU_ARCH_TYPE: cpu + DOCKER_IMAGE: pytorch/manylinux-builder:cpu-main + use_split_build: True + DESIRED_PYTHON: "3.13" + build_name: manywheel-py3_13-cpu + build_environment: linux-binary-manywheel + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" + runs_on: linux.4xlarge + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} + manywheel-py3_13-cpu-upload: # Uploading + if: ${{ github.repository_owner == 'pytorch' }} + permissions: + id-token: write + contents: read + needs: manywheel-py3_13-cpu-test + with: + PYTORCH_ROOT: /pytorch + BUILDER_ROOT: /builder + PACKAGE_TYPE: manywheel + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: cpu + GPU_ARCH_TYPE: cpu + DOCKER_IMAGE: pytorch/manylinux-builder:cpu-main + use_split_build: True + DESIRED_PYTHON: "3.13" + build_name: manywheel-py3_13-cpu + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} + conda-pytorchbot-token: ${{ secrets.CONDA_PYTORCHBOT_TOKEN }} + conda-pytorchbot-token-test: ${{ secrets.CONDA_PYTORCHBOT_TOKEN_TEST }} + uses: ./.github/workflows/_binary-upload.yml diff --git a/.github/workflows/generated-linux-binary-manywheel-nightly.yml b/.github/workflows/generated-linux-binary-manywheel-nightly.yml index 4fff1e0bacd2bf..f817e5a009122d 100644 --- a/.github/workflows/generated-linux-binary-manywheel-nightly.yml +++ b/.github/workflows/generated-linux-binary-manywheel-nightly.yml @@ -58,6 +58,7 @@ jobs: DESIRED_CUDA: cpu GPU_ARCH_TYPE: cpu DOCKER_IMAGE: pytorch/manylinux-builder:cpu-main + use_split_build: False DESIRED_PYTHON: "3.9" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_9-cpu @@ -79,6 +80,7 @@ jobs: DESIRED_CUDA: cpu GPU_ARCH_TYPE: cpu DOCKER_IMAGE: pytorch/manylinux-builder:cpu-main + use_split_build: False DESIRED_PYTHON: "3.9" build_name: manywheel-py3_9-cpu build_environment: linux-binary-manywheel @@ -101,6 +103,7 @@ jobs: DESIRED_CUDA: cpu GPU_ARCH_TYPE: cpu DOCKER_IMAGE: pytorch/manylinux-builder:cpu-main + use_split_build: False DESIRED_PYTHON: "3.9" build_name: manywheel-py3_9-cpu secrets: @@ -123,6 +126,7 @@ jobs: GPU_ARCH_TYPE: cpu-cxx11-abi DOCKER_IMAGE: pytorch/manylinuxcxx11-abi-builder:cpu-cxx11-abi-main DESIRED_DEVTOOLSET: cxx11-abi + use_split_build: False DESIRED_PYTHON: "3.9" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_9-cpu-cxx11-abi @@ -145,6 +149,7 @@ jobs: GPU_ARCH_TYPE: cpu-cxx11-abi DOCKER_IMAGE: pytorch/manylinuxcxx11-abi-builder:cpu-cxx11-abi-main DESIRED_DEVTOOLSET: cxx11-abi + use_split_build: False DESIRED_PYTHON: "3.9" build_name: manywheel-py3_9-cpu-cxx11-abi build_environment: linux-binary-manywheel @@ -168,6 +173,7 @@ jobs: GPU_ARCH_TYPE: cpu-cxx11-abi DOCKER_IMAGE: pytorch/manylinuxcxx11-abi-builder:cpu-cxx11-abi-main DESIRED_DEVTOOLSET: cxx11-abi + use_split_build: False DESIRED_PYTHON: "3.9" build_name: manywheel-py3_9-cpu-cxx11-abi secrets: @@ -190,6 +196,7 @@ jobs: GPU_ARCH_VERSION: 11.8 GPU_ARCH_TYPE: cuda DOCKER_IMAGE: pytorch/manylinux-builder:cuda11.8-main + use_split_build: False DESIRED_PYTHON: "3.9" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_9-cuda11_8 @@ -213,6 +220,7 @@ jobs: GPU_ARCH_VERSION: 11.8 GPU_ARCH_TYPE: cuda DOCKER_IMAGE: pytorch/manylinux-builder:cuda11.8-main + use_split_build: False DESIRED_PYTHON: "3.9" build_name: manywheel-py3_9-cuda11_8 build_environment: linux-binary-manywheel @@ -236,6 +244,7 @@ jobs: GPU_ARCH_VERSION: 11.8 GPU_ARCH_TYPE: cuda DOCKER_IMAGE: pytorch/manylinux-builder:cuda11.8-main + use_split_build: False DESIRED_PYTHON: "3.9" build_name: manywheel-py3_9-cuda11_8 secrets: @@ -244,77 +253,6 @@ jobs: conda-pytorchbot-token-test: ${{ secrets.CONDA_PYTORCHBOT_TOKEN_TEST }} uses: ./.github/workflows/_binary-upload.yml - manywheel-py3_9-cuda11_8-split-build: - if: ${{ github.repository_owner == 'pytorch' }} - uses: ./.github/workflows/_binary-build-linux.yml - needs: get-label-type - with: - PYTORCH_ROOT: /pytorch - BUILDER_ROOT: /builder - PACKAGE_TYPE: manywheel - # TODO: This is a legacy variable that we eventually want to get rid of in - # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu118 - GPU_ARCH_VERSION: 11.8 - GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/manylinux-builder:cuda11.8-main - use_split_build: True - DESIRED_PYTHON: "3.9" - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - build_name: manywheel-py3_9-cuda11_8-split - build_environment: linux-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu11==11.8.89; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu11==11.8.89; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu11==11.8.87; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu11==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu11==11.11.3.6; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu11==10.9.0.58; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu11==10.3.0.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu11==11.4.1.48; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu11==11.7.5.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu11==2.21.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu11==11.8.86; platform_system == 'Linux' and platform_machine == 'x86_64' - secrets: - github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_9-cuda11_8-split-test: # Testing - if: ${{ github.repository_owner == 'pytorch' }} - needs: - - manywheel-py3_9-cuda11_8-split-build - - get-label-type - uses: ./.github/workflows/_binary-test-linux.yml - with: - PYTORCH_ROOT: /pytorch - BUILDER_ROOT: /builder - PACKAGE_TYPE: manywheel - # TODO: This is a legacy variable that we eventually want to get rid of in - # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu118 - GPU_ARCH_VERSION: 11.8 - GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/manylinux-builder:cuda11.8-main - use_split_build: True - DESIRED_PYTHON: "3.9" - build_name: manywheel-py3_9-cuda11_8-split - build_environment: linux-binary-manywheel - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - runs_on: linux.4xlarge.nvidia.gpu - secrets: - github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_9-cuda11_8-split-upload: # Uploading - if: ${{ github.repository_owner == 'pytorch' }} - permissions: - id-token: write - contents: read - needs: manywheel-py3_9-cuda11_8-split-test - with: - PYTORCH_ROOT: /pytorch - BUILDER_ROOT: /builder - PACKAGE_TYPE: manywheel - # TODO: This is a legacy variable that we eventually want to get rid of in - # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu118 - GPU_ARCH_VERSION: 11.8 - GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/manylinux-builder:cuda11.8-main - use_split_build: True - DESIRED_PYTHON: "3.9" - build_name: manywheel-py3_9-cuda11_8-split - secrets: - github-token: ${{ secrets.GITHUB_TOKEN }} - conda-pytorchbot-token: ${{ secrets.CONDA_PYTORCHBOT_TOKEN }} - conda-pytorchbot-token-test: ${{ secrets.CONDA_PYTORCHBOT_TOKEN_TEST }} - uses: ./.github/workflows/_binary-upload.yml - manywheel-py3_9-cuda12_1-build: if: ${{ github.repository_owner == 'pytorch' }} uses: ./.github/workflows/_binary-build-linux.yml @@ -329,6 +267,7 @@ jobs: GPU_ARCH_VERSION: 12.1 GPU_ARCH_TYPE: cuda DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.1-main + use_split_build: False DESIRED_PYTHON: "3.9" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_9-cuda12_1 @@ -352,6 +291,7 @@ jobs: GPU_ARCH_VERSION: 12.1 GPU_ARCH_TYPE: cuda DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.1-main + use_split_build: False DESIRED_PYTHON: "3.9" build_name: manywheel-py3_9-cuda12_1 build_environment: linux-binary-manywheel @@ -375,6 +315,7 @@ jobs: GPU_ARCH_VERSION: 12.1 GPU_ARCH_TYPE: cuda DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.1-main + use_split_build: False DESIRED_PYTHON: "3.9" build_name: manywheel-py3_9-cuda12_1 secrets: @@ -383,77 +324,6 @@ jobs: conda-pytorchbot-token-test: ${{ secrets.CONDA_PYTORCHBOT_TOKEN_TEST }} uses: ./.github/workflows/_binary-upload.yml - manywheel-py3_9-cuda12_1-split-build: - if: ${{ github.repository_owner == 'pytorch' }} - uses: ./.github/workflows/_binary-build-linux.yml - needs: get-label-type - with: - PYTORCH_ROOT: /pytorch - BUILDER_ROOT: /builder - PACKAGE_TYPE: manywheel - # TODO: This is a legacy variable that we eventually want to get rid of in - # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu121 - GPU_ARCH_VERSION: 12.1 - GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.1-main - use_split_build: True - DESIRED_PYTHON: "3.9" - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - build_name: manywheel-py3_9-cuda12_1-split - build_environment: linux-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.21.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' - secrets: - github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_9-cuda12_1-split-test: # Testing - if: ${{ github.repository_owner == 'pytorch' }} - needs: - - manywheel-py3_9-cuda12_1-split-build - - get-label-type - uses: ./.github/workflows/_binary-test-linux.yml - with: - PYTORCH_ROOT: /pytorch - BUILDER_ROOT: /builder - PACKAGE_TYPE: manywheel - # TODO: This is a legacy variable that we eventually want to get rid of in - # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu121 - GPU_ARCH_VERSION: 12.1 - GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.1-main - use_split_build: True - DESIRED_PYTHON: "3.9" - build_name: manywheel-py3_9-cuda12_1-split - build_environment: linux-binary-manywheel - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - runs_on: linux.4xlarge.nvidia.gpu - secrets: - github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_9-cuda12_1-split-upload: # Uploading - if: ${{ github.repository_owner == 'pytorch' }} - permissions: - id-token: write - contents: read - needs: manywheel-py3_9-cuda12_1-split-test - with: - PYTORCH_ROOT: /pytorch - BUILDER_ROOT: /builder - PACKAGE_TYPE: manywheel - # TODO: This is a legacy variable that we eventually want to get rid of in - # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu121 - GPU_ARCH_VERSION: 12.1 - GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.1-main - use_split_build: True - DESIRED_PYTHON: "3.9" - build_name: manywheel-py3_9-cuda12_1-split - secrets: - github-token: ${{ secrets.GITHUB_TOKEN }} - conda-pytorchbot-token: ${{ secrets.CONDA_PYTORCHBOT_TOKEN }} - conda-pytorchbot-token-test: ${{ secrets.CONDA_PYTORCHBOT_TOKEN_TEST }} - uses: ./.github/workflows/_binary-upload.yml - manywheel-py3_9-cuda12_4-build: if: ${{ github.repository_owner == 'pytorch' }} uses: ./.github/workflows/_binary-build-linux.yml @@ -468,6 +338,7 @@ jobs: GPU_ARCH_VERSION: 12.4 GPU_ARCH_TYPE: cuda DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.4-main + use_split_build: False DESIRED_PYTHON: "3.9" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_9-cuda12_4 @@ -491,6 +362,7 @@ jobs: GPU_ARCH_VERSION: 12.4 GPU_ARCH_TYPE: cuda DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.4-main + use_split_build: False DESIRED_PYTHON: "3.9" build_name: manywheel-py3_9-cuda12_4 build_environment: linux-binary-manywheel @@ -514,6 +386,7 @@ jobs: GPU_ARCH_VERSION: 12.4 GPU_ARCH_TYPE: cuda DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.4-main + use_split_build: False DESIRED_PYTHON: "3.9" build_name: manywheel-py3_9-cuda12_4 secrets: @@ -522,77 +395,6 @@ jobs: conda-pytorchbot-token-test: ${{ secrets.CONDA_PYTORCHBOT_TOKEN_TEST }} uses: ./.github/workflows/_binary-upload.yml - manywheel-py3_9-cuda12_4-split-build: - if: ${{ github.repository_owner == 'pytorch' }} - uses: ./.github/workflows/_binary-build-linux.yml - needs: get-label-type - with: - PYTORCH_ROOT: /pytorch - BUILDER_ROOT: /builder - PACKAGE_TYPE: manywheel - # TODO: This is a legacy variable that we eventually want to get rid of in - # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu124 - GPU_ARCH_VERSION: 12.4 - GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.4-main - use_split_build: True - DESIRED_PYTHON: "3.9" - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - build_name: manywheel-py3_9-cuda12_4-split - build_environment: linux-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.4.5.8; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.2.1.3; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.5.147; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.6.1.9; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.3.1.170; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.21.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' - secrets: - github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_9-cuda12_4-split-test: # Testing - if: ${{ github.repository_owner == 'pytorch' }} - needs: - - manywheel-py3_9-cuda12_4-split-build - - get-label-type - uses: ./.github/workflows/_binary-test-linux.yml - with: - PYTORCH_ROOT: /pytorch - BUILDER_ROOT: /builder - PACKAGE_TYPE: manywheel - # TODO: This is a legacy variable that we eventually want to get rid of in - # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu124 - GPU_ARCH_VERSION: 12.4 - GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.4-main - use_split_build: True - DESIRED_PYTHON: "3.9" - build_name: manywheel-py3_9-cuda12_4-split - build_environment: linux-binary-manywheel - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - runs_on: linux.4xlarge.nvidia.gpu - secrets: - github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_9-cuda12_4-split-upload: # Uploading - if: ${{ github.repository_owner == 'pytorch' }} - permissions: - id-token: write - contents: read - needs: manywheel-py3_9-cuda12_4-split-test - with: - PYTORCH_ROOT: /pytorch - BUILDER_ROOT: /builder - PACKAGE_TYPE: manywheel - # TODO: This is a legacy variable that we eventually want to get rid of in - # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu124 - GPU_ARCH_VERSION: 12.4 - GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.4-main - use_split_build: True - DESIRED_PYTHON: "3.9" - build_name: manywheel-py3_9-cuda12_4-split - secrets: - github-token: ${{ secrets.GITHUB_TOKEN }} - conda-pytorchbot-token: ${{ secrets.CONDA_PYTORCHBOT_TOKEN }} - conda-pytorchbot-token-test: ${{ secrets.CONDA_PYTORCHBOT_TOKEN_TEST }} - uses: ./.github/workflows/_binary-upload.yml - manywheel-py3_9-rocm6_1-build: if: ${{ github.repository_owner == 'pytorch' }} uses: ./.github/workflows/_binary-build-linux.yml @@ -607,6 +409,7 @@ jobs: GPU_ARCH_VERSION: 6.1 GPU_ARCH_TYPE: rocm DOCKER_IMAGE: pytorch/manylinux-builder:rocm6.1-main + use_split_build: False DESIRED_PYTHON: "3.9" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_9-rocm6_1 @@ -631,6 +434,7 @@ jobs: GPU_ARCH_TYPE: rocm SKIP_ALL_TESTS: 1 DOCKER_IMAGE: pytorch/manylinux-builder:rocm6.1-main + use_split_build: False DESIRED_PYTHON: "3.9" steps: - name: Setup ROCm @@ -692,6 +496,7 @@ jobs: GPU_ARCH_VERSION: 6.1 GPU_ARCH_TYPE: rocm DOCKER_IMAGE: pytorch/manylinux-builder:rocm6.1-main + use_split_build: False DESIRED_PYTHON: "3.9" build_name: manywheel-py3_9-rocm6_1 secrets: @@ -714,6 +519,7 @@ jobs: GPU_ARCH_VERSION: 6.2 GPU_ARCH_TYPE: rocm DOCKER_IMAGE: pytorch/manylinux-builder:rocm6.2-main + use_split_build: False DESIRED_PYTHON: "3.9" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_9-rocm6_2 @@ -738,6 +544,7 @@ jobs: GPU_ARCH_TYPE: rocm SKIP_ALL_TESTS: 1 DOCKER_IMAGE: pytorch/manylinux-builder:rocm6.2-main + use_split_build: False DESIRED_PYTHON: "3.9" steps: - name: Setup ROCm @@ -799,6 +606,7 @@ jobs: GPU_ARCH_VERSION: 6.2 GPU_ARCH_TYPE: rocm DOCKER_IMAGE: pytorch/manylinux-builder:rocm6.2-main + use_split_build: False DESIRED_PYTHON: "3.9" build_name: manywheel-py3_9-rocm6_2 secrets: @@ -820,6 +628,7 @@ jobs: DESIRED_CUDA: xpu GPU_ARCH_TYPE: xpu DOCKER_IMAGE: pytorch/manylinux2_28-builder:xpu-main + use_split_build: False DESIRED_PYTHON: "3.9" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_9-xpu @@ -843,6 +652,7 @@ jobs: GPU_ARCH_TYPE: xpu SKIP_ALL_TESTS: 1 DOCKER_IMAGE: pytorch/manylinux2_28-builder:xpu-main + use_split_build: False DESIRED_PYTHON: "3.9" permissions: id-token: write @@ -912,6 +722,7 @@ jobs: DESIRED_CUDA: xpu GPU_ARCH_TYPE: xpu DOCKER_IMAGE: pytorch/manylinux2_28-builder:xpu-main + use_split_build: False DESIRED_PYTHON: "3.9" build_name: manywheel-py3_9-xpu secrets: @@ -933,6 +744,7 @@ jobs: DESIRED_CUDA: cpu GPU_ARCH_TYPE: cpu DOCKER_IMAGE: pytorch/manylinux-builder:cpu-main + use_split_build: False DESIRED_PYTHON: "3.10" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_10-cpu @@ -954,6 +766,7 @@ jobs: DESIRED_CUDA: cpu GPU_ARCH_TYPE: cpu DOCKER_IMAGE: pytorch/manylinux-builder:cpu-main + use_split_build: False DESIRED_PYTHON: "3.10" build_name: manywheel-py3_10-cpu build_environment: linux-binary-manywheel @@ -976,6 +789,7 @@ jobs: DESIRED_CUDA: cpu GPU_ARCH_TYPE: cpu DOCKER_IMAGE: pytorch/manylinux-builder:cpu-main + use_split_build: False DESIRED_PYTHON: "3.10" build_name: manywheel-py3_10-cpu secrets: @@ -998,6 +812,7 @@ jobs: GPU_ARCH_TYPE: cpu-cxx11-abi DOCKER_IMAGE: pytorch/manylinuxcxx11-abi-builder:cpu-cxx11-abi-main DESIRED_DEVTOOLSET: cxx11-abi + use_split_build: False DESIRED_PYTHON: "3.10" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_10-cpu-cxx11-abi @@ -1020,6 +835,7 @@ jobs: GPU_ARCH_TYPE: cpu-cxx11-abi DOCKER_IMAGE: pytorch/manylinuxcxx11-abi-builder:cpu-cxx11-abi-main DESIRED_DEVTOOLSET: cxx11-abi + use_split_build: False DESIRED_PYTHON: "3.10" build_name: manywheel-py3_10-cpu-cxx11-abi build_environment: linux-binary-manywheel @@ -1043,6 +859,7 @@ jobs: GPU_ARCH_TYPE: cpu-cxx11-abi DOCKER_IMAGE: pytorch/manylinuxcxx11-abi-builder:cpu-cxx11-abi-main DESIRED_DEVTOOLSET: cxx11-abi + use_split_build: False DESIRED_PYTHON: "3.10" build_name: manywheel-py3_10-cpu-cxx11-abi secrets: @@ -1065,6 +882,7 @@ jobs: GPU_ARCH_VERSION: 11.8 GPU_ARCH_TYPE: cuda DOCKER_IMAGE: pytorch/manylinux-builder:cuda11.8-main + use_split_build: False DESIRED_PYTHON: "3.10" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_10-cuda11_8 @@ -1088,6 +906,7 @@ jobs: GPU_ARCH_VERSION: 11.8 GPU_ARCH_TYPE: cuda DOCKER_IMAGE: pytorch/manylinux-builder:cuda11.8-main + use_split_build: False DESIRED_PYTHON: "3.10" build_name: manywheel-py3_10-cuda11_8 build_environment: linux-binary-manywheel @@ -1111,6 +930,7 @@ jobs: GPU_ARCH_VERSION: 11.8 GPU_ARCH_TYPE: cuda DOCKER_IMAGE: pytorch/manylinux-builder:cuda11.8-main + use_split_build: False DESIRED_PYTHON: "3.10" build_name: manywheel-py3_10-cuda11_8 secrets: @@ -1119,7 +939,7 @@ jobs: conda-pytorchbot-token-test: ${{ secrets.CONDA_PYTORCHBOT_TOKEN_TEST }} uses: ./.github/workflows/_binary-upload.yml - manywheel-py3_10-cuda11_8-split-build: + manywheel-py3_10-cuda12_1-build: if: ${{ github.repository_owner == 'pytorch' }} uses: ./.github/workflows/_binary-build-linux.yml needs: get-label-type @@ -1129,22 +949,22 @@ jobs: PACKAGE_TYPE: manywheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu118 - GPU_ARCH_VERSION: 11.8 + DESIRED_CUDA: cu121 + GPU_ARCH_VERSION: 12.1 GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/manylinux-builder:cuda11.8-main - use_split_build: True + DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.1-main + use_split_build: False DESIRED_PYTHON: "3.10" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - build_name: manywheel-py3_10-cuda11_8-split + build_name: manywheel-py3_10-cuda12_1 build_environment: linux-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu11==11.8.89; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu11==11.8.89; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu11==11.8.87; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu11==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu11==11.11.3.6; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu11==10.9.0.58; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu11==10.3.0.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu11==11.4.1.48; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu11==11.7.5.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu11==2.21.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu11==11.8.86; platform_system == 'Linux' and platform_machine == 'x86_64' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.21.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_10-cuda11_8-split-test: # Testing + manywheel-py3_10-cuda12_1-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} needs: - - manywheel-py3_10-cuda11_8-split-build + - manywheel-py3_10-cuda12_1-build - get-label-type uses: ./.github/workflows/_binary-test-linux.yml with: @@ -1153,44 +973,44 @@ jobs: PACKAGE_TYPE: manywheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu118 - GPU_ARCH_VERSION: 11.8 + DESIRED_CUDA: cu121 + GPU_ARCH_VERSION: 12.1 GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/manylinux-builder:cuda11.8-main - use_split_build: True + DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.1-main + use_split_build: False DESIRED_PYTHON: "3.10" - build_name: manywheel-py3_10-cuda11_8-split + build_name: manywheel-py3_10-cuda12_1 build_environment: linux-binary-manywheel runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" runs_on: linux.4xlarge.nvidia.gpu secrets: github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_10-cuda11_8-split-upload: # Uploading + manywheel-py3_10-cuda12_1-upload: # Uploading if: ${{ github.repository_owner == 'pytorch' }} permissions: id-token: write contents: read - needs: manywheel-py3_10-cuda11_8-split-test + needs: manywheel-py3_10-cuda12_1-test with: PYTORCH_ROOT: /pytorch BUILDER_ROOT: /builder PACKAGE_TYPE: manywheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu118 - GPU_ARCH_VERSION: 11.8 + DESIRED_CUDA: cu121 + GPU_ARCH_VERSION: 12.1 GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/manylinux-builder:cuda11.8-main - use_split_build: True + DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.1-main + use_split_build: False DESIRED_PYTHON: "3.10" - build_name: manywheel-py3_10-cuda11_8-split + build_name: manywheel-py3_10-cuda12_1 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} conda-pytorchbot-token: ${{ secrets.CONDA_PYTORCHBOT_TOKEN }} conda-pytorchbot-token-test: ${{ secrets.CONDA_PYTORCHBOT_TOKEN_TEST }} uses: ./.github/workflows/_binary-upload.yml - manywheel-py3_10-cuda12_1-build: + manywheel-py3_10-cuda12_1-full-build: if: ${{ github.repository_owner == 'pytorch' }} uses: ./.github/workflows/_binary-build-linux.yml needs: get-label-type @@ -1204,17 +1024,17 @@ jobs: GPU_ARCH_VERSION: 12.1 GPU_ARCH_TYPE: cuda DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.1-main + use_split_build: False DESIRED_PYTHON: "3.10" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - build_name: manywheel-py3_10-cuda12_1 + build_name: manywheel-py3_10-cuda12_1-full build_environment: linux-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.21.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_10-cuda12_1-test: # Testing + manywheel-py3_10-cuda12_1-full-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} needs: - - manywheel-py3_10-cuda12_1-build + - manywheel-py3_10-cuda12_1-full-build - get-label-type uses: ./.github/workflows/_binary-test-linux.yml with: @@ -1227,146 +1047,7 @@ jobs: GPU_ARCH_VERSION: 12.1 GPU_ARCH_TYPE: cuda DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.1-main - DESIRED_PYTHON: "3.10" - build_name: manywheel-py3_10-cuda12_1 - build_environment: linux-binary-manywheel - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - runs_on: linux.4xlarge.nvidia.gpu - secrets: - github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_10-cuda12_1-upload: # Uploading - if: ${{ github.repository_owner == 'pytorch' }} - permissions: - id-token: write - contents: read - needs: manywheel-py3_10-cuda12_1-test - with: - PYTORCH_ROOT: /pytorch - BUILDER_ROOT: /builder - PACKAGE_TYPE: manywheel - # TODO: This is a legacy variable that we eventually want to get rid of in - # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu121 - GPU_ARCH_VERSION: 12.1 - GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.1-main - DESIRED_PYTHON: "3.10" - build_name: manywheel-py3_10-cuda12_1 - secrets: - github-token: ${{ secrets.GITHUB_TOKEN }} - conda-pytorchbot-token: ${{ secrets.CONDA_PYTORCHBOT_TOKEN }} - conda-pytorchbot-token-test: ${{ secrets.CONDA_PYTORCHBOT_TOKEN_TEST }} - uses: ./.github/workflows/_binary-upload.yml - - manywheel-py3_10-cuda12_1-split-build: - if: ${{ github.repository_owner == 'pytorch' }} - uses: ./.github/workflows/_binary-build-linux.yml - needs: get-label-type - with: - PYTORCH_ROOT: /pytorch - BUILDER_ROOT: /builder - PACKAGE_TYPE: manywheel - # TODO: This is a legacy variable that we eventually want to get rid of in - # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu121 - GPU_ARCH_VERSION: 12.1 - GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.1-main - use_split_build: True - DESIRED_PYTHON: "3.10" - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - build_name: manywheel-py3_10-cuda12_1-split - build_environment: linux-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.21.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' - secrets: - github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_10-cuda12_1-split-test: # Testing - if: ${{ github.repository_owner == 'pytorch' }} - needs: - - manywheel-py3_10-cuda12_1-split-build - - get-label-type - uses: ./.github/workflows/_binary-test-linux.yml - with: - PYTORCH_ROOT: /pytorch - BUILDER_ROOT: /builder - PACKAGE_TYPE: manywheel - # TODO: This is a legacy variable that we eventually want to get rid of in - # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu121 - GPU_ARCH_VERSION: 12.1 - GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.1-main - use_split_build: True - DESIRED_PYTHON: "3.10" - build_name: manywheel-py3_10-cuda12_1-split - build_environment: linux-binary-manywheel - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - runs_on: linux.4xlarge.nvidia.gpu - secrets: - github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_10-cuda12_1-split-upload: # Uploading - if: ${{ github.repository_owner == 'pytorch' }} - permissions: - id-token: write - contents: read - needs: manywheel-py3_10-cuda12_1-split-test - with: - PYTORCH_ROOT: /pytorch - BUILDER_ROOT: /builder - PACKAGE_TYPE: manywheel - # TODO: This is a legacy variable that we eventually want to get rid of in - # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu121 - GPU_ARCH_VERSION: 12.1 - GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.1-main - use_split_build: True - DESIRED_PYTHON: "3.10" - build_name: manywheel-py3_10-cuda12_1-split - secrets: - github-token: ${{ secrets.GITHUB_TOKEN }} - conda-pytorchbot-token: ${{ secrets.CONDA_PYTORCHBOT_TOKEN }} - conda-pytorchbot-token-test: ${{ secrets.CONDA_PYTORCHBOT_TOKEN_TEST }} - uses: ./.github/workflows/_binary-upload.yml - - manywheel-py3_10-cuda12_1-full-build: - if: ${{ github.repository_owner == 'pytorch' }} - uses: ./.github/workflows/_binary-build-linux.yml - needs: get-label-type - with: - PYTORCH_ROOT: /pytorch - BUILDER_ROOT: /builder - PACKAGE_TYPE: manywheel - # TODO: This is a legacy variable that we eventually want to get rid of in - # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu121 - GPU_ARCH_VERSION: 12.1 - GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.1-main - use_split_build: False - DESIRED_PYTHON: "3.10" - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - build_name: manywheel-py3_10-cuda12_1-full - build_environment: linux-binary-manywheel - secrets: - github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_10-cuda12_1-full-test: # Testing - if: ${{ github.repository_owner == 'pytorch' }} - needs: - - manywheel-py3_10-cuda12_1-full-build - - get-label-type - uses: ./.github/workflows/_binary-test-linux.yml - with: - PYTORCH_ROOT: /pytorch - BUILDER_ROOT: /builder - PACKAGE_TYPE: manywheel - # TODO: This is a legacy variable that we eventually want to get rid of in - # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu121 - GPU_ARCH_VERSION: 12.1 - GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.1-main - use_split_build: False + use_split_build: False DESIRED_PYTHON: "3.10" build_name: manywheel-py3_10-cuda12_1-full build_environment: linux-binary-manywheel @@ -1413,6 +1094,7 @@ jobs: GPU_ARCH_VERSION: 12.4 GPU_ARCH_TYPE: cuda DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.4-main + use_split_build: False DESIRED_PYTHON: "3.10" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_10-cuda12_4 @@ -1436,6 +1118,7 @@ jobs: GPU_ARCH_VERSION: 12.4 GPU_ARCH_TYPE: cuda DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.4-main + use_split_build: False DESIRED_PYTHON: "3.10" build_name: manywheel-py3_10-cuda12_4 build_environment: linux-binary-manywheel @@ -1459,6 +1142,7 @@ jobs: GPU_ARCH_VERSION: 12.4 GPU_ARCH_TYPE: cuda DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.4-main + use_split_build: False DESIRED_PYTHON: "3.10" build_name: manywheel-py3_10-cuda12_4 secrets: @@ -1467,77 +1151,6 @@ jobs: conda-pytorchbot-token-test: ${{ secrets.CONDA_PYTORCHBOT_TOKEN_TEST }} uses: ./.github/workflows/_binary-upload.yml - manywheel-py3_10-cuda12_4-split-build: - if: ${{ github.repository_owner == 'pytorch' }} - uses: ./.github/workflows/_binary-build-linux.yml - needs: get-label-type - with: - PYTORCH_ROOT: /pytorch - BUILDER_ROOT: /builder - PACKAGE_TYPE: manywheel - # TODO: This is a legacy variable that we eventually want to get rid of in - # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu124 - GPU_ARCH_VERSION: 12.4 - GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.4-main - use_split_build: True - DESIRED_PYTHON: "3.10" - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - build_name: manywheel-py3_10-cuda12_4-split - build_environment: linux-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.4.5.8; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.2.1.3; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.5.147; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.6.1.9; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.3.1.170; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.21.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' - secrets: - github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_10-cuda12_4-split-test: # Testing - if: ${{ github.repository_owner == 'pytorch' }} - needs: - - manywheel-py3_10-cuda12_4-split-build - - get-label-type - uses: ./.github/workflows/_binary-test-linux.yml - with: - PYTORCH_ROOT: /pytorch - BUILDER_ROOT: /builder - PACKAGE_TYPE: manywheel - # TODO: This is a legacy variable that we eventually want to get rid of in - # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu124 - GPU_ARCH_VERSION: 12.4 - GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.4-main - use_split_build: True - DESIRED_PYTHON: "3.10" - build_name: manywheel-py3_10-cuda12_4-split - build_environment: linux-binary-manywheel - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - runs_on: linux.4xlarge.nvidia.gpu - secrets: - github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_10-cuda12_4-split-upload: # Uploading - if: ${{ github.repository_owner == 'pytorch' }} - permissions: - id-token: write - contents: read - needs: manywheel-py3_10-cuda12_4-split-test - with: - PYTORCH_ROOT: /pytorch - BUILDER_ROOT: /builder - PACKAGE_TYPE: manywheel - # TODO: This is a legacy variable that we eventually want to get rid of in - # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu124 - GPU_ARCH_VERSION: 12.4 - GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.4-main - use_split_build: True - DESIRED_PYTHON: "3.10" - build_name: manywheel-py3_10-cuda12_4-split - secrets: - github-token: ${{ secrets.GITHUB_TOKEN }} - conda-pytorchbot-token: ${{ secrets.CONDA_PYTORCHBOT_TOKEN }} - conda-pytorchbot-token-test: ${{ secrets.CONDA_PYTORCHBOT_TOKEN_TEST }} - uses: ./.github/workflows/_binary-upload.yml - manywheel-py3_10-rocm6_1-build: if: ${{ github.repository_owner == 'pytorch' }} uses: ./.github/workflows/_binary-build-linux.yml @@ -1552,6 +1165,7 @@ jobs: GPU_ARCH_VERSION: 6.1 GPU_ARCH_TYPE: rocm DOCKER_IMAGE: pytorch/manylinux-builder:rocm6.1-main + use_split_build: False DESIRED_PYTHON: "3.10" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_10-rocm6_1 @@ -1576,6 +1190,7 @@ jobs: GPU_ARCH_TYPE: rocm SKIP_ALL_TESTS: 1 DOCKER_IMAGE: pytorch/manylinux-builder:rocm6.1-main + use_split_build: False DESIRED_PYTHON: "3.10" steps: - name: Setup ROCm @@ -1637,6 +1252,7 @@ jobs: GPU_ARCH_VERSION: 6.1 GPU_ARCH_TYPE: rocm DOCKER_IMAGE: pytorch/manylinux-builder:rocm6.1-main + use_split_build: False DESIRED_PYTHON: "3.10" build_name: manywheel-py3_10-rocm6_1 secrets: @@ -1659,6 +1275,7 @@ jobs: GPU_ARCH_VERSION: 6.2 GPU_ARCH_TYPE: rocm DOCKER_IMAGE: pytorch/manylinux-builder:rocm6.2-main + use_split_build: False DESIRED_PYTHON: "3.10" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_10-rocm6_2 @@ -1683,6 +1300,7 @@ jobs: GPU_ARCH_TYPE: rocm SKIP_ALL_TESTS: 1 DOCKER_IMAGE: pytorch/manylinux-builder:rocm6.2-main + use_split_build: False DESIRED_PYTHON: "3.10" steps: - name: Setup ROCm @@ -1744,6 +1362,7 @@ jobs: GPU_ARCH_VERSION: 6.2 GPU_ARCH_TYPE: rocm DOCKER_IMAGE: pytorch/manylinux-builder:rocm6.2-main + use_split_build: False DESIRED_PYTHON: "3.10" build_name: manywheel-py3_10-rocm6_2 secrets: @@ -1765,6 +1384,7 @@ jobs: DESIRED_CUDA: xpu GPU_ARCH_TYPE: xpu DOCKER_IMAGE: pytorch/manylinux2_28-builder:xpu-main + use_split_build: False DESIRED_PYTHON: "3.10" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_10-xpu @@ -1788,6 +1408,7 @@ jobs: GPU_ARCH_TYPE: xpu SKIP_ALL_TESTS: 1 DOCKER_IMAGE: pytorch/manylinux2_28-builder:xpu-main + use_split_build: False DESIRED_PYTHON: "3.10" permissions: id-token: write @@ -1857,6 +1478,7 @@ jobs: DESIRED_CUDA: xpu GPU_ARCH_TYPE: xpu DOCKER_IMAGE: pytorch/manylinux2_28-builder:xpu-main + use_split_build: False DESIRED_PYTHON: "3.10" build_name: manywheel-py3_10-xpu secrets: @@ -1878,6 +1500,7 @@ jobs: DESIRED_CUDA: cpu GPU_ARCH_TYPE: cpu DOCKER_IMAGE: pytorch/manylinux-builder:cpu-main + use_split_build: False DESIRED_PYTHON: "3.11" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_11-cpu @@ -1899,6 +1522,7 @@ jobs: DESIRED_CUDA: cpu GPU_ARCH_TYPE: cpu DOCKER_IMAGE: pytorch/manylinux-builder:cpu-main + use_split_build: False DESIRED_PYTHON: "3.11" build_name: manywheel-py3_11-cpu build_environment: linux-binary-manywheel @@ -1921,6 +1545,7 @@ jobs: DESIRED_CUDA: cpu GPU_ARCH_TYPE: cpu DOCKER_IMAGE: pytorch/manylinux-builder:cpu-main + use_split_build: False DESIRED_PYTHON: "3.11" build_name: manywheel-py3_11-cpu secrets: @@ -1943,6 +1568,7 @@ jobs: GPU_ARCH_TYPE: cpu-cxx11-abi DOCKER_IMAGE: pytorch/manylinuxcxx11-abi-builder:cpu-cxx11-abi-main DESIRED_DEVTOOLSET: cxx11-abi + use_split_build: False DESIRED_PYTHON: "3.11" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_11-cpu-cxx11-abi @@ -1965,6 +1591,7 @@ jobs: GPU_ARCH_TYPE: cpu-cxx11-abi DOCKER_IMAGE: pytorch/manylinuxcxx11-abi-builder:cpu-cxx11-abi-main DESIRED_DEVTOOLSET: cxx11-abi + use_split_build: False DESIRED_PYTHON: "3.11" build_name: manywheel-py3_11-cpu-cxx11-abi build_environment: linux-binary-manywheel @@ -1988,6 +1615,7 @@ jobs: GPU_ARCH_TYPE: cpu-cxx11-abi DOCKER_IMAGE: pytorch/manylinuxcxx11-abi-builder:cpu-cxx11-abi-main DESIRED_DEVTOOLSET: cxx11-abi + use_split_build: False DESIRED_PYTHON: "3.11" build_name: manywheel-py3_11-cpu-cxx11-abi secrets: @@ -2010,6 +1638,7 @@ jobs: GPU_ARCH_VERSION: 11.8 GPU_ARCH_TYPE: cuda DOCKER_IMAGE: pytorch/manylinux-builder:cuda11.8-main + use_split_build: False DESIRED_PYTHON: "3.11" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_11-cuda11_8 @@ -2033,6 +1662,7 @@ jobs: GPU_ARCH_VERSION: 11.8 GPU_ARCH_TYPE: cuda DOCKER_IMAGE: pytorch/manylinux-builder:cuda11.8-main + use_split_build: False DESIRED_PYTHON: "3.11" build_name: manywheel-py3_11-cuda11_8 build_environment: linux-binary-manywheel @@ -2056,6 +1686,7 @@ jobs: GPU_ARCH_VERSION: 11.8 GPU_ARCH_TYPE: cuda DOCKER_IMAGE: pytorch/manylinux-builder:cuda11.8-main + use_split_build: False DESIRED_PYTHON: "3.11" build_name: manywheel-py3_11-cuda11_8 secrets: @@ -2064,7 +1695,7 @@ jobs: conda-pytorchbot-token-test: ${{ secrets.CONDA_PYTORCHBOT_TOKEN_TEST }} uses: ./.github/workflows/_binary-upload.yml - manywheel-py3_11-cuda11_8-split-build: + manywheel-py3_11-cuda12_1-build: if: ${{ github.repository_owner == 'pytorch' }} uses: ./.github/workflows/_binary-build-linux.yml needs: get-label-type @@ -2074,22 +1705,22 @@ jobs: PACKAGE_TYPE: manywheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu118 - GPU_ARCH_VERSION: 11.8 + DESIRED_CUDA: cu121 + GPU_ARCH_VERSION: 12.1 GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/manylinux-builder:cuda11.8-main - use_split_build: True + DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.1-main + use_split_build: False DESIRED_PYTHON: "3.11" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - build_name: manywheel-py3_11-cuda11_8-split + build_name: manywheel-py3_11-cuda12_1 build_environment: linux-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu11==11.8.89; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu11==11.8.89; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu11==11.8.87; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu11==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu11==11.11.3.6; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu11==10.9.0.58; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu11==10.3.0.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu11==11.4.1.48; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu11==11.7.5.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu11==2.21.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu11==11.8.86; platform_system == 'Linux' and platform_machine == 'x86_64' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.21.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_11-cuda11_8-split-test: # Testing + manywheel-py3_11-cuda12_1-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} needs: - - manywheel-py3_11-cuda11_8-split-build + - manywheel-py3_11-cuda12_1-build - get-label-type uses: ./.github/workflows/_binary-test-linux.yml with: @@ -2098,44 +1729,44 @@ jobs: PACKAGE_TYPE: manywheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu118 - GPU_ARCH_VERSION: 11.8 + DESIRED_CUDA: cu121 + GPU_ARCH_VERSION: 12.1 GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/manylinux-builder:cuda11.8-main - use_split_build: True + DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.1-main + use_split_build: False DESIRED_PYTHON: "3.11" - build_name: manywheel-py3_11-cuda11_8-split + build_name: manywheel-py3_11-cuda12_1 build_environment: linux-binary-manywheel runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" runs_on: linux.4xlarge.nvidia.gpu secrets: github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_11-cuda11_8-split-upload: # Uploading + manywheel-py3_11-cuda12_1-upload: # Uploading if: ${{ github.repository_owner == 'pytorch' }} permissions: id-token: write contents: read - needs: manywheel-py3_11-cuda11_8-split-test + needs: manywheel-py3_11-cuda12_1-test with: PYTORCH_ROOT: /pytorch BUILDER_ROOT: /builder PACKAGE_TYPE: manywheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu118 - GPU_ARCH_VERSION: 11.8 + DESIRED_CUDA: cu121 + GPU_ARCH_VERSION: 12.1 GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/manylinux-builder:cuda11.8-main - use_split_build: True + DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.1-main + use_split_build: False DESIRED_PYTHON: "3.11" - build_name: manywheel-py3_11-cuda11_8-split + build_name: manywheel-py3_11-cuda12_1 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} conda-pytorchbot-token: ${{ secrets.CONDA_PYTORCHBOT_TOKEN }} conda-pytorchbot-token-test: ${{ secrets.CONDA_PYTORCHBOT_TOKEN_TEST }} uses: ./.github/workflows/_binary-upload.yml - manywheel-py3_11-cuda12_1-build: + manywheel-py3_11-cuda12_4-build: if: ${{ github.repository_owner == 'pytorch' }} uses: ./.github/workflows/_binary-build-linux.yml needs: get-label-type @@ -2145,157 +1776,19 @@ jobs: PACKAGE_TYPE: manywheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu121 - GPU_ARCH_VERSION: 12.1 + DESIRED_CUDA: cu124 + GPU_ARCH_VERSION: 12.4 GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.1-main + DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.4-main + use_split_build: False DESIRED_PYTHON: "3.11" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - build_name: manywheel-py3_11-cuda12_1 + build_name: manywheel-py3_11-cuda12_4 build_environment: linux-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.21.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.4.5.8; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.2.1.3; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.5.147; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.6.1.9; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.3.1.170; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.21.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_11-cuda12_1-test: # Testing - if: ${{ github.repository_owner == 'pytorch' }} - needs: - - manywheel-py3_11-cuda12_1-build - - get-label-type - uses: ./.github/workflows/_binary-test-linux.yml - with: - PYTORCH_ROOT: /pytorch - BUILDER_ROOT: /builder - PACKAGE_TYPE: manywheel - # TODO: This is a legacy variable that we eventually want to get rid of in - # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu121 - GPU_ARCH_VERSION: 12.1 - GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.1-main - DESIRED_PYTHON: "3.11" - build_name: manywheel-py3_11-cuda12_1 - build_environment: linux-binary-manywheel - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - runs_on: linux.4xlarge.nvidia.gpu - secrets: - github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_11-cuda12_1-upload: # Uploading - if: ${{ github.repository_owner == 'pytorch' }} - permissions: - id-token: write - contents: read - needs: manywheel-py3_11-cuda12_1-test - with: - PYTORCH_ROOT: /pytorch - BUILDER_ROOT: /builder - PACKAGE_TYPE: manywheel - # TODO: This is a legacy variable that we eventually want to get rid of in - # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu121 - GPU_ARCH_VERSION: 12.1 - GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.1-main - DESIRED_PYTHON: "3.11" - build_name: manywheel-py3_11-cuda12_1 - secrets: - github-token: ${{ secrets.GITHUB_TOKEN }} - conda-pytorchbot-token: ${{ secrets.CONDA_PYTORCHBOT_TOKEN }} - conda-pytorchbot-token-test: ${{ secrets.CONDA_PYTORCHBOT_TOKEN_TEST }} - uses: ./.github/workflows/_binary-upload.yml - - manywheel-py3_11-cuda12_1-split-build: - if: ${{ github.repository_owner == 'pytorch' }} - uses: ./.github/workflows/_binary-build-linux.yml - needs: get-label-type - with: - PYTORCH_ROOT: /pytorch - BUILDER_ROOT: /builder - PACKAGE_TYPE: manywheel - # TODO: This is a legacy variable that we eventually want to get rid of in - # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu121 - GPU_ARCH_VERSION: 12.1 - GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.1-main - use_split_build: True - DESIRED_PYTHON: "3.11" - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - build_name: manywheel-py3_11-cuda12_1-split - build_environment: linux-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.21.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' - secrets: - github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_11-cuda12_1-split-test: # Testing - if: ${{ github.repository_owner == 'pytorch' }} - needs: - - manywheel-py3_11-cuda12_1-split-build - - get-label-type - uses: ./.github/workflows/_binary-test-linux.yml - with: - PYTORCH_ROOT: /pytorch - BUILDER_ROOT: /builder - PACKAGE_TYPE: manywheel - # TODO: This is a legacy variable that we eventually want to get rid of in - # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu121 - GPU_ARCH_VERSION: 12.1 - GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.1-main - use_split_build: True - DESIRED_PYTHON: "3.11" - build_name: manywheel-py3_11-cuda12_1-split - build_environment: linux-binary-manywheel - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - runs_on: linux.4xlarge.nvidia.gpu - secrets: - github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_11-cuda12_1-split-upload: # Uploading - if: ${{ github.repository_owner == 'pytorch' }} - permissions: - id-token: write - contents: read - needs: manywheel-py3_11-cuda12_1-split-test - with: - PYTORCH_ROOT: /pytorch - BUILDER_ROOT: /builder - PACKAGE_TYPE: manywheel - # TODO: This is a legacy variable that we eventually want to get rid of in - # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu121 - GPU_ARCH_VERSION: 12.1 - GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.1-main - use_split_build: True - DESIRED_PYTHON: "3.11" - build_name: manywheel-py3_11-cuda12_1-split - secrets: - github-token: ${{ secrets.GITHUB_TOKEN }} - conda-pytorchbot-token: ${{ secrets.CONDA_PYTORCHBOT_TOKEN }} - conda-pytorchbot-token-test: ${{ secrets.CONDA_PYTORCHBOT_TOKEN_TEST }} - uses: ./.github/workflows/_binary-upload.yml - - manywheel-py3_11-cuda12_4-build: - if: ${{ github.repository_owner == 'pytorch' }} - uses: ./.github/workflows/_binary-build-linux.yml - needs: get-label-type - with: - PYTORCH_ROOT: /pytorch - BUILDER_ROOT: /builder - PACKAGE_TYPE: manywheel - # TODO: This is a legacy variable that we eventually want to get rid of in - # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu124 - GPU_ARCH_VERSION: 12.4 - GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.4-main - DESIRED_PYTHON: "3.11" - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - build_name: manywheel-py3_11-cuda12_4 - build_environment: linux-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.4.5.8; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.2.1.3; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.5.147; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.6.1.9; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.3.1.170; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.21.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' - secrets: - github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_11-cuda12_4-test: # Testing + manywheel-py3_11-cuda12_4-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} needs: - manywheel-py3_11-cuda12_4-build @@ -2311,6 +1804,7 @@ jobs: GPU_ARCH_VERSION: 12.4 GPU_ARCH_TYPE: cuda DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.4-main + use_split_build: False DESIRED_PYTHON: "3.11" build_name: manywheel-py3_11-cuda12_4 build_environment: linux-binary-manywheel @@ -2334,6 +1828,7 @@ jobs: GPU_ARCH_VERSION: 12.4 GPU_ARCH_TYPE: cuda DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.4-main + use_split_build: False DESIRED_PYTHON: "3.11" build_name: manywheel-py3_11-cuda12_4 secrets: @@ -2342,77 +1837,6 @@ jobs: conda-pytorchbot-token-test: ${{ secrets.CONDA_PYTORCHBOT_TOKEN_TEST }} uses: ./.github/workflows/_binary-upload.yml - manywheel-py3_11-cuda12_4-split-build: - if: ${{ github.repository_owner == 'pytorch' }} - uses: ./.github/workflows/_binary-build-linux.yml - needs: get-label-type - with: - PYTORCH_ROOT: /pytorch - BUILDER_ROOT: /builder - PACKAGE_TYPE: manywheel - # TODO: This is a legacy variable that we eventually want to get rid of in - # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu124 - GPU_ARCH_VERSION: 12.4 - GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.4-main - use_split_build: True - DESIRED_PYTHON: "3.11" - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - build_name: manywheel-py3_11-cuda12_4-split - build_environment: linux-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.4.5.8; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.2.1.3; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.5.147; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.6.1.9; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.3.1.170; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.21.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' - secrets: - github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_11-cuda12_4-split-test: # Testing - if: ${{ github.repository_owner == 'pytorch' }} - needs: - - manywheel-py3_11-cuda12_4-split-build - - get-label-type - uses: ./.github/workflows/_binary-test-linux.yml - with: - PYTORCH_ROOT: /pytorch - BUILDER_ROOT: /builder - PACKAGE_TYPE: manywheel - # TODO: This is a legacy variable that we eventually want to get rid of in - # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu124 - GPU_ARCH_VERSION: 12.4 - GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.4-main - use_split_build: True - DESIRED_PYTHON: "3.11" - build_name: manywheel-py3_11-cuda12_4-split - build_environment: linux-binary-manywheel - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - runs_on: linux.4xlarge.nvidia.gpu - secrets: - github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_11-cuda12_4-split-upload: # Uploading - if: ${{ github.repository_owner == 'pytorch' }} - permissions: - id-token: write - contents: read - needs: manywheel-py3_11-cuda12_4-split-test - with: - PYTORCH_ROOT: /pytorch - BUILDER_ROOT: /builder - PACKAGE_TYPE: manywheel - # TODO: This is a legacy variable that we eventually want to get rid of in - # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu124 - GPU_ARCH_VERSION: 12.4 - GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.4-main - use_split_build: True - DESIRED_PYTHON: "3.11" - build_name: manywheel-py3_11-cuda12_4-split - secrets: - github-token: ${{ secrets.GITHUB_TOKEN }} - conda-pytorchbot-token: ${{ secrets.CONDA_PYTORCHBOT_TOKEN }} - conda-pytorchbot-token-test: ${{ secrets.CONDA_PYTORCHBOT_TOKEN_TEST }} - uses: ./.github/workflows/_binary-upload.yml - manywheel-py3_11-rocm6_1-build: if: ${{ github.repository_owner == 'pytorch' }} uses: ./.github/workflows/_binary-build-linux.yml @@ -2427,6 +1851,7 @@ jobs: GPU_ARCH_VERSION: 6.1 GPU_ARCH_TYPE: rocm DOCKER_IMAGE: pytorch/manylinux-builder:rocm6.1-main + use_split_build: False DESIRED_PYTHON: "3.11" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_11-rocm6_1 @@ -2451,6 +1876,7 @@ jobs: GPU_ARCH_TYPE: rocm SKIP_ALL_TESTS: 1 DOCKER_IMAGE: pytorch/manylinux-builder:rocm6.1-main + use_split_build: False DESIRED_PYTHON: "3.11" steps: - name: Setup ROCm @@ -2512,6 +1938,7 @@ jobs: GPU_ARCH_VERSION: 6.1 GPU_ARCH_TYPE: rocm DOCKER_IMAGE: pytorch/manylinux-builder:rocm6.1-main + use_split_build: False DESIRED_PYTHON: "3.11" build_name: manywheel-py3_11-rocm6_1 secrets: @@ -2534,6 +1961,7 @@ jobs: GPU_ARCH_VERSION: 6.2 GPU_ARCH_TYPE: rocm DOCKER_IMAGE: pytorch/manylinux-builder:rocm6.2-main + use_split_build: False DESIRED_PYTHON: "3.11" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_11-rocm6_2 @@ -2558,6 +1986,7 @@ jobs: GPU_ARCH_TYPE: rocm SKIP_ALL_TESTS: 1 DOCKER_IMAGE: pytorch/manylinux-builder:rocm6.2-main + use_split_build: False DESIRED_PYTHON: "3.11" steps: - name: Setup ROCm @@ -2619,6 +2048,7 @@ jobs: GPU_ARCH_VERSION: 6.2 GPU_ARCH_TYPE: rocm DOCKER_IMAGE: pytorch/manylinux-builder:rocm6.2-main + use_split_build: False DESIRED_PYTHON: "3.11" build_name: manywheel-py3_11-rocm6_2 secrets: @@ -2640,6 +2070,7 @@ jobs: DESIRED_CUDA: xpu GPU_ARCH_TYPE: xpu DOCKER_IMAGE: pytorch/manylinux2_28-builder:xpu-main + use_split_build: False DESIRED_PYTHON: "3.11" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_11-xpu @@ -2663,6 +2094,7 @@ jobs: GPU_ARCH_TYPE: xpu SKIP_ALL_TESTS: 1 DOCKER_IMAGE: pytorch/manylinux2_28-builder:xpu-main + use_split_build: False DESIRED_PYTHON: "3.11" permissions: id-token: write @@ -2732,6 +2164,7 @@ jobs: DESIRED_CUDA: xpu GPU_ARCH_TYPE: xpu DOCKER_IMAGE: pytorch/manylinux2_28-builder:xpu-main + use_split_build: False DESIRED_PYTHON: "3.11" build_name: manywheel-py3_11-xpu secrets: @@ -2753,6 +2186,7 @@ jobs: DESIRED_CUDA: cpu GPU_ARCH_TYPE: cpu DOCKER_IMAGE: pytorch/manylinux-builder:cpu-main + use_split_build: False DESIRED_PYTHON: "3.12" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_12-cpu @@ -2774,6 +2208,7 @@ jobs: DESIRED_CUDA: cpu GPU_ARCH_TYPE: cpu DOCKER_IMAGE: pytorch/manylinux-builder:cpu-main + use_split_build: False DESIRED_PYTHON: "3.12" build_name: manywheel-py3_12-cpu build_environment: linux-binary-manywheel @@ -2796,6 +2231,7 @@ jobs: DESIRED_CUDA: cpu GPU_ARCH_TYPE: cpu DOCKER_IMAGE: pytorch/manylinux-builder:cpu-main + use_split_build: False DESIRED_PYTHON: "3.12" build_name: manywheel-py3_12-cpu secrets: @@ -2818,6 +2254,7 @@ jobs: GPU_ARCH_TYPE: cpu-cxx11-abi DOCKER_IMAGE: pytorch/manylinuxcxx11-abi-builder:cpu-cxx11-abi-main DESIRED_DEVTOOLSET: cxx11-abi + use_split_build: False DESIRED_PYTHON: "3.12" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_12-cpu-cxx11-abi @@ -2840,6 +2277,7 @@ jobs: GPU_ARCH_TYPE: cpu-cxx11-abi DOCKER_IMAGE: pytorch/manylinuxcxx11-abi-builder:cpu-cxx11-abi-main DESIRED_DEVTOOLSET: cxx11-abi + use_split_build: False DESIRED_PYTHON: "3.12" build_name: manywheel-py3_12-cpu-cxx11-abi build_environment: linux-binary-manywheel @@ -2863,6 +2301,7 @@ jobs: GPU_ARCH_TYPE: cpu-cxx11-abi DOCKER_IMAGE: pytorch/manylinuxcxx11-abi-builder:cpu-cxx11-abi-main DESIRED_DEVTOOLSET: cxx11-abi + use_split_build: False DESIRED_PYTHON: "3.12" build_name: manywheel-py3_12-cpu-cxx11-abi secrets: @@ -2885,6 +2324,7 @@ jobs: GPU_ARCH_VERSION: 11.8 GPU_ARCH_TYPE: cuda DOCKER_IMAGE: pytorch/manylinux-builder:cuda11.8-main + use_split_build: False DESIRED_PYTHON: "3.12" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_12-cuda11_8 @@ -2908,6 +2348,7 @@ jobs: GPU_ARCH_VERSION: 11.8 GPU_ARCH_TYPE: cuda DOCKER_IMAGE: pytorch/manylinux-builder:cuda11.8-main + use_split_build: False DESIRED_PYTHON: "3.12" build_name: manywheel-py3_12-cuda11_8 build_environment: linux-binary-manywheel @@ -2931,6 +2372,7 @@ jobs: GPU_ARCH_VERSION: 11.8 GPU_ARCH_TYPE: cuda DOCKER_IMAGE: pytorch/manylinux-builder:cuda11.8-main + use_split_build: False DESIRED_PYTHON: "3.12" build_name: manywheel-py3_12-cuda11_8 secrets: @@ -2939,217 +2381,7 @@ jobs: conda-pytorchbot-token-test: ${{ secrets.CONDA_PYTORCHBOT_TOKEN_TEST }} uses: ./.github/workflows/_binary-upload.yml - manywheel-py3_12-cuda11_8-split-build: - if: ${{ github.repository_owner == 'pytorch' }} - uses: ./.github/workflows/_binary-build-linux.yml - needs: get-label-type - with: - PYTORCH_ROOT: /pytorch - BUILDER_ROOT: /builder - PACKAGE_TYPE: manywheel - # TODO: This is a legacy variable that we eventually want to get rid of in - # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu118 - GPU_ARCH_VERSION: 11.8 - GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/manylinux-builder:cuda11.8-main - use_split_build: True - DESIRED_PYTHON: "3.12" - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - build_name: manywheel-py3_12-cuda11_8-split - build_environment: linux-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu11==11.8.89; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu11==11.8.89; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu11==11.8.87; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu11==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu11==11.11.3.6; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu11==10.9.0.58; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu11==10.3.0.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu11==11.4.1.48; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu11==11.7.5.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu11==2.21.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu11==11.8.86; platform_system == 'Linux' and platform_machine == 'x86_64' - secrets: - github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_12-cuda11_8-split-test: # Testing - if: ${{ github.repository_owner == 'pytorch' }} - needs: - - manywheel-py3_12-cuda11_8-split-build - - get-label-type - uses: ./.github/workflows/_binary-test-linux.yml - with: - PYTORCH_ROOT: /pytorch - BUILDER_ROOT: /builder - PACKAGE_TYPE: manywheel - # TODO: This is a legacy variable that we eventually want to get rid of in - # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu118 - GPU_ARCH_VERSION: 11.8 - GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/manylinux-builder:cuda11.8-main - use_split_build: True - DESIRED_PYTHON: "3.12" - build_name: manywheel-py3_12-cuda11_8-split - build_environment: linux-binary-manywheel - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - runs_on: linux.4xlarge.nvidia.gpu - secrets: - github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_12-cuda11_8-split-upload: # Uploading - if: ${{ github.repository_owner == 'pytorch' }} - permissions: - id-token: write - contents: read - needs: manywheel-py3_12-cuda11_8-split-test - with: - PYTORCH_ROOT: /pytorch - BUILDER_ROOT: /builder - PACKAGE_TYPE: manywheel - # TODO: This is a legacy variable that we eventually want to get rid of in - # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu118 - GPU_ARCH_VERSION: 11.8 - GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/manylinux-builder:cuda11.8-main - use_split_build: True - DESIRED_PYTHON: "3.12" - build_name: manywheel-py3_12-cuda11_8-split - secrets: - github-token: ${{ secrets.GITHUB_TOKEN }} - conda-pytorchbot-token: ${{ secrets.CONDA_PYTORCHBOT_TOKEN }} - conda-pytorchbot-token-test: ${{ secrets.CONDA_PYTORCHBOT_TOKEN_TEST }} - uses: ./.github/workflows/_binary-upload.yml - - manywheel-py3_12-cuda12_1-build: - if: ${{ github.repository_owner == 'pytorch' }} - uses: ./.github/workflows/_binary-build-linux.yml - needs: get-label-type - with: - PYTORCH_ROOT: /pytorch - BUILDER_ROOT: /builder - PACKAGE_TYPE: manywheel - # TODO: This is a legacy variable that we eventually want to get rid of in - # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu121 - GPU_ARCH_VERSION: 12.1 - GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.1-main - DESIRED_PYTHON: "3.12" - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - build_name: manywheel-py3_12-cuda12_1 - build_environment: linux-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.21.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' - secrets: - github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_12-cuda12_1-test: # Testing - if: ${{ github.repository_owner == 'pytorch' }} - needs: - - manywheel-py3_12-cuda12_1-build - - get-label-type - uses: ./.github/workflows/_binary-test-linux.yml - with: - PYTORCH_ROOT: /pytorch - BUILDER_ROOT: /builder - PACKAGE_TYPE: manywheel - # TODO: This is a legacy variable that we eventually want to get rid of in - # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu121 - GPU_ARCH_VERSION: 12.1 - GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.1-main - DESIRED_PYTHON: "3.12" - build_name: manywheel-py3_12-cuda12_1 - build_environment: linux-binary-manywheel - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - runs_on: linux.4xlarge.nvidia.gpu - secrets: - github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_12-cuda12_1-upload: # Uploading - if: ${{ github.repository_owner == 'pytorch' }} - permissions: - id-token: write - contents: read - needs: manywheel-py3_12-cuda12_1-test - with: - PYTORCH_ROOT: /pytorch - BUILDER_ROOT: /builder - PACKAGE_TYPE: manywheel - # TODO: This is a legacy variable that we eventually want to get rid of in - # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu121 - GPU_ARCH_VERSION: 12.1 - GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.1-main - DESIRED_PYTHON: "3.12" - build_name: manywheel-py3_12-cuda12_1 - secrets: - github-token: ${{ secrets.GITHUB_TOKEN }} - conda-pytorchbot-token: ${{ secrets.CONDA_PYTORCHBOT_TOKEN }} - conda-pytorchbot-token-test: ${{ secrets.CONDA_PYTORCHBOT_TOKEN_TEST }} - uses: ./.github/workflows/_binary-upload.yml - - manywheel-py3_12-cuda12_1-split-build: - if: ${{ github.repository_owner == 'pytorch' }} - uses: ./.github/workflows/_binary-build-linux.yml - needs: get-label-type - with: - PYTORCH_ROOT: /pytorch - BUILDER_ROOT: /builder - PACKAGE_TYPE: manywheel - # TODO: This is a legacy variable that we eventually want to get rid of in - # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu121 - GPU_ARCH_VERSION: 12.1 - GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.1-main - use_split_build: True - DESIRED_PYTHON: "3.12" - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - build_name: manywheel-py3_12-cuda12_1-split - build_environment: linux-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.21.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' - secrets: - github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_12-cuda12_1-split-test: # Testing - if: ${{ github.repository_owner == 'pytorch' }} - needs: - - manywheel-py3_12-cuda12_1-split-build - - get-label-type - uses: ./.github/workflows/_binary-test-linux.yml - with: - PYTORCH_ROOT: /pytorch - BUILDER_ROOT: /builder - PACKAGE_TYPE: manywheel - # TODO: This is a legacy variable that we eventually want to get rid of in - # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu121 - GPU_ARCH_VERSION: 12.1 - GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.1-main - use_split_build: True - DESIRED_PYTHON: "3.12" - build_name: manywheel-py3_12-cuda12_1-split - build_environment: linux-binary-manywheel - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - runs_on: linux.4xlarge.nvidia.gpu - secrets: - github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_12-cuda12_1-split-upload: # Uploading - if: ${{ github.repository_owner == 'pytorch' }} - permissions: - id-token: write - contents: read - needs: manywheel-py3_12-cuda12_1-split-test - with: - PYTORCH_ROOT: /pytorch - BUILDER_ROOT: /builder - PACKAGE_TYPE: manywheel - # TODO: This is a legacy variable that we eventually want to get rid of in - # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu121 - GPU_ARCH_VERSION: 12.1 - GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.1-main - use_split_build: True - DESIRED_PYTHON: "3.12" - build_name: manywheel-py3_12-cuda12_1-split - secrets: - github-token: ${{ secrets.GITHUB_TOKEN }} - conda-pytorchbot-token: ${{ secrets.CONDA_PYTORCHBOT_TOKEN }} - conda-pytorchbot-token-test: ${{ secrets.CONDA_PYTORCHBOT_TOKEN_TEST }} - uses: ./.github/workflows/_binary-upload.yml - - manywheel-py3_12-cuda12_4-build: + manywheel-py3_12-cuda12_1-build: if: ${{ github.repository_owner == 'pytorch' }} uses: ./.github/workflows/_binary-build-linux.yml needs: get-label-type @@ -3159,21 +2391,22 @@ jobs: PACKAGE_TYPE: manywheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu124 - GPU_ARCH_VERSION: 12.4 + DESIRED_CUDA: cu121 + GPU_ARCH_VERSION: 12.1 GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.4-main + DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.1-main + use_split_build: False DESIRED_PYTHON: "3.12" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - build_name: manywheel-py3_12-cuda12_4 + build_name: manywheel-py3_12-cuda12_1 build_environment: linux-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.4.5.8; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.2.1.3; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.5.147; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.6.1.9; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.3.1.170; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.21.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.21.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_12-cuda12_4-test: # Testing + manywheel-py3_12-cuda12_1-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} needs: - - manywheel-py3_12-cuda12_4-build + - manywheel-py3_12-cuda12_1-build - get-label-type uses: ./.github/workflows/_binary-test-linux.yml with: @@ -3182,42 +2415,44 @@ jobs: PACKAGE_TYPE: manywheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu124 - GPU_ARCH_VERSION: 12.4 + DESIRED_CUDA: cu121 + GPU_ARCH_VERSION: 12.1 GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.4-main + DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.1-main + use_split_build: False DESIRED_PYTHON: "3.12" - build_name: manywheel-py3_12-cuda12_4 + build_name: manywheel-py3_12-cuda12_1 build_environment: linux-binary-manywheel runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" runs_on: linux.4xlarge.nvidia.gpu secrets: github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_12-cuda12_4-upload: # Uploading + manywheel-py3_12-cuda12_1-upload: # Uploading if: ${{ github.repository_owner == 'pytorch' }} permissions: id-token: write contents: read - needs: manywheel-py3_12-cuda12_4-test + needs: manywheel-py3_12-cuda12_1-test with: PYTORCH_ROOT: /pytorch BUILDER_ROOT: /builder PACKAGE_TYPE: manywheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu124 - GPU_ARCH_VERSION: 12.4 + DESIRED_CUDA: cu121 + GPU_ARCH_VERSION: 12.1 GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.4-main + DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.1-main + use_split_build: False DESIRED_PYTHON: "3.12" - build_name: manywheel-py3_12-cuda12_4 + build_name: manywheel-py3_12-cuda12_1 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} conda-pytorchbot-token: ${{ secrets.CONDA_PYTORCHBOT_TOKEN }} conda-pytorchbot-token-test: ${{ secrets.CONDA_PYTORCHBOT_TOKEN_TEST }} uses: ./.github/workflows/_binary-upload.yml - manywheel-py3_12-cuda12_4-split-build: + manywheel-py3_12-cuda12_4-build: if: ${{ github.repository_owner == 'pytorch' }} uses: ./.github/workflows/_binary-build-linux.yml needs: get-label-type @@ -3231,18 +2466,18 @@ jobs: GPU_ARCH_VERSION: 12.4 GPU_ARCH_TYPE: cuda DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.4-main - use_split_build: True + use_split_build: False DESIRED_PYTHON: "3.12" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - build_name: manywheel-py3_12-cuda12_4-split + build_name: manywheel-py3_12-cuda12_4 build_environment: linux-binary-manywheel PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.4.5.8; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.2.1.3; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.5.147; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.6.1.9; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.3.1.170; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.21.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_12-cuda12_4-split-test: # Testing + manywheel-py3_12-cuda12_4-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} needs: - - manywheel-py3_12-cuda12_4-split-build + - manywheel-py3_12-cuda12_4-build - get-label-type uses: ./.github/workflows/_binary-test-linux.yml with: @@ -3255,20 +2490,20 @@ jobs: GPU_ARCH_VERSION: 12.4 GPU_ARCH_TYPE: cuda DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.4-main - use_split_build: True + use_split_build: False DESIRED_PYTHON: "3.12" - build_name: manywheel-py3_12-cuda12_4-split + build_name: manywheel-py3_12-cuda12_4 build_environment: linux-binary-manywheel runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" runs_on: linux.4xlarge.nvidia.gpu secrets: github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_12-cuda12_4-split-upload: # Uploading + manywheel-py3_12-cuda12_4-upload: # Uploading if: ${{ github.repository_owner == 'pytorch' }} permissions: id-token: write contents: read - needs: manywheel-py3_12-cuda12_4-split-test + needs: manywheel-py3_12-cuda12_4-test with: PYTORCH_ROOT: /pytorch BUILDER_ROOT: /builder @@ -3279,9 +2514,9 @@ jobs: GPU_ARCH_VERSION: 12.4 GPU_ARCH_TYPE: cuda DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.4-main - use_split_build: True + use_split_build: False DESIRED_PYTHON: "3.12" - build_name: manywheel-py3_12-cuda12_4-split + build_name: manywheel-py3_12-cuda12_4 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} conda-pytorchbot-token: ${{ secrets.CONDA_PYTORCHBOT_TOKEN }} @@ -3302,6 +2537,7 @@ jobs: GPU_ARCH_VERSION: 6.1 GPU_ARCH_TYPE: rocm DOCKER_IMAGE: pytorch/manylinux-builder:rocm6.1-main + use_split_build: False DESIRED_PYTHON: "3.12" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_12-rocm6_1 @@ -3326,6 +2562,7 @@ jobs: GPU_ARCH_TYPE: rocm SKIP_ALL_TESTS: 1 DOCKER_IMAGE: pytorch/manylinux-builder:rocm6.1-main + use_split_build: False DESIRED_PYTHON: "3.12" steps: - name: Setup ROCm @@ -3387,6 +2624,7 @@ jobs: GPU_ARCH_VERSION: 6.1 GPU_ARCH_TYPE: rocm DOCKER_IMAGE: pytorch/manylinux-builder:rocm6.1-main + use_split_build: False DESIRED_PYTHON: "3.12" build_name: manywheel-py3_12-rocm6_1 secrets: @@ -3409,6 +2647,7 @@ jobs: GPU_ARCH_VERSION: 6.2 GPU_ARCH_TYPE: rocm DOCKER_IMAGE: pytorch/manylinux-builder:rocm6.2-main + use_split_build: False DESIRED_PYTHON: "3.12" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_12-rocm6_2 @@ -3433,6 +2672,7 @@ jobs: GPU_ARCH_TYPE: rocm SKIP_ALL_TESTS: 1 DOCKER_IMAGE: pytorch/manylinux-builder:rocm6.2-main + use_split_build: False DESIRED_PYTHON: "3.12" steps: - name: Setup ROCm @@ -3494,6 +2734,7 @@ jobs: GPU_ARCH_VERSION: 6.2 GPU_ARCH_TYPE: rocm DOCKER_IMAGE: pytorch/manylinux-builder:rocm6.2-main + use_split_build: False DESIRED_PYTHON: "3.12" build_name: manywheel-py3_12-rocm6_2 secrets: @@ -3515,6 +2756,7 @@ jobs: DESIRED_CUDA: xpu GPU_ARCH_TYPE: xpu DOCKER_IMAGE: pytorch/manylinux2_28-builder:xpu-main + use_split_build: False DESIRED_PYTHON: "3.12" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_12-xpu @@ -3538,6 +2780,7 @@ jobs: GPU_ARCH_TYPE: xpu SKIP_ALL_TESTS: 1 DOCKER_IMAGE: pytorch/manylinux2_28-builder:xpu-main + use_split_build: False DESIRED_PYTHON: "3.12" permissions: id-token: write @@ -3607,6 +2850,7 @@ jobs: DESIRED_CUDA: xpu GPU_ARCH_TYPE: xpu DOCKER_IMAGE: pytorch/manylinux2_28-builder:xpu-main + use_split_build: False DESIRED_PYTHON: "3.12" build_name: manywheel-py3_12-xpu secrets: @@ -3628,6 +2872,7 @@ jobs: DESIRED_CUDA: cpu GPU_ARCH_TYPE: cpu DOCKER_IMAGE: pytorch/manylinux-builder:cpu-main + use_split_build: False DESIRED_PYTHON: "3.13" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_13-cpu @@ -3649,6 +2894,7 @@ jobs: DESIRED_CUDA: cpu GPU_ARCH_TYPE: cpu DOCKER_IMAGE: pytorch/manylinux-builder:cpu-main + use_split_build: False DESIRED_PYTHON: "3.13" build_name: manywheel-py3_13-cpu build_environment: linux-binary-manywheel @@ -3671,6 +2917,7 @@ jobs: DESIRED_CUDA: cpu GPU_ARCH_TYPE: cpu DOCKER_IMAGE: pytorch/manylinux-builder:cpu-main + use_split_build: False DESIRED_PYTHON: "3.13" build_name: manywheel-py3_13-cpu secrets: @@ -3693,6 +2940,7 @@ jobs: GPU_ARCH_TYPE: cpu-cxx11-abi DOCKER_IMAGE: pytorch/manylinuxcxx11-abi-builder:cpu-cxx11-abi-main DESIRED_DEVTOOLSET: cxx11-abi + use_split_build: False DESIRED_PYTHON: "3.13" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_13-cpu-cxx11-abi @@ -3715,6 +2963,7 @@ jobs: GPU_ARCH_TYPE: cpu-cxx11-abi DOCKER_IMAGE: pytorch/manylinuxcxx11-abi-builder:cpu-cxx11-abi-main DESIRED_DEVTOOLSET: cxx11-abi + use_split_build: False DESIRED_PYTHON: "3.13" build_name: manywheel-py3_13-cpu-cxx11-abi build_environment: linux-binary-manywheel @@ -3738,6 +2987,7 @@ jobs: GPU_ARCH_TYPE: cpu-cxx11-abi DOCKER_IMAGE: pytorch/manylinuxcxx11-abi-builder:cpu-cxx11-abi-main DESIRED_DEVTOOLSET: cxx11-abi + use_split_build: False DESIRED_PYTHON: "3.13" build_name: manywheel-py3_13-cpu-cxx11-abi secrets: @@ -3760,6 +3010,7 @@ jobs: GPU_ARCH_VERSION: 11.8 GPU_ARCH_TYPE: cuda DOCKER_IMAGE: pytorch/manylinux-builder:cuda11.8-main + use_split_build: False DESIRED_PYTHON: "3.13" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_13-cuda11_8 @@ -3783,6 +3034,7 @@ jobs: GPU_ARCH_VERSION: 11.8 GPU_ARCH_TYPE: cuda DOCKER_IMAGE: pytorch/manylinux-builder:cuda11.8-main + use_split_build: False DESIRED_PYTHON: "3.13" build_name: manywheel-py3_13-cuda11_8 build_environment: linux-binary-manywheel @@ -3806,6 +3058,7 @@ jobs: GPU_ARCH_VERSION: 11.8 GPU_ARCH_TYPE: cuda DOCKER_IMAGE: pytorch/manylinux-builder:cuda11.8-main + use_split_build: False DESIRED_PYTHON: "3.13" build_name: manywheel-py3_13-cuda11_8 secrets: @@ -3814,77 +3067,6 @@ jobs: conda-pytorchbot-token-test: ${{ secrets.CONDA_PYTORCHBOT_TOKEN_TEST }} uses: ./.github/workflows/_binary-upload.yml - manywheel-py3_13-cuda11_8-split-build: - if: ${{ github.repository_owner == 'pytorch' }} - uses: ./.github/workflows/_binary-build-linux.yml - needs: get-label-type - with: - PYTORCH_ROOT: /pytorch - BUILDER_ROOT: /builder - PACKAGE_TYPE: manywheel - # TODO: This is a legacy variable that we eventually want to get rid of in - # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu118 - GPU_ARCH_VERSION: 11.8 - GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/manylinux-builder:cuda11.8-main - use_split_build: True - DESIRED_PYTHON: "3.13" - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - build_name: manywheel-py3_13-cuda11_8-split - build_environment: linux-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu11==11.8.89; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu11==11.8.89; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu11==11.8.87; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu11==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu11==11.11.3.6; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu11==10.9.0.58; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu11==10.3.0.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu11==11.4.1.48; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu11==11.7.5.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu11==2.21.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu11==11.8.86; platform_system == 'Linux' and platform_machine == 'x86_64' - secrets: - github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_13-cuda11_8-split-test: # Testing - if: ${{ github.repository_owner == 'pytorch' }} - needs: - - manywheel-py3_13-cuda11_8-split-build - - get-label-type - uses: ./.github/workflows/_binary-test-linux.yml - with: - PYTORCH_ROOT: /pytorch - BUILDER_ROOT: /builder - PACKAGE_TYPE: manywheel - # TODO: This is a legacy variable that we eventually want to get rid of in - # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu118 - GPU_ARCH_VERSION: 11.8 - GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/manylinux-builder:cuda11.8-main - use_split_build: True - DESIRED_PYTHON: "3.13" - build_name: manywheel-py3_13-cuda11_8-split - build_environment: linux-binary-manywheel - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - runs_on: linux.4xlarge.nvidia.gpu - secrets: - github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_13-cuda11_8-split-upload: # Uploading - if: ${{ github.repository_owner == 'pytorch' }} - permissions: - id-token: write - contents: read - needs: manywheel-py3_13-cuda11_8-split-test - with: - PYTORCH_ROOT: /pytorch - BUILDER_ROOT: /builder - PACKAGE_TYPE: manywheel - # TODO: This is a legacy variable that we eventually want to get rid of in - # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu118 - GPU_ARCH_VERSION: 11.8 - GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/manylinux-builder:cuda11.8-main - use_split_build: True - DESIRED_PYTHON: "3.13" - build_name: manywheel-py3_13-cuda11_8-split - secrets: - github-token: ${{ secrets.GITHUB_TOKEN }} - conda-pytorchbot-token: ${{ secrets.CONDA_PYTORCHBOT_TOKEN }} - conda-pytorchbot-token-test: ${{ secrets.CONDA_PYTORCHBOT_TOKEN_TEST }} - uses: ./.github/workflows/_binary-upload.yml - manywheel-py3_13-cuda12_1-build: if: ${{ github.repository_owner == 'pytorch' }} uses: ./.github/workflows/_binary-build-linux.yml @@ -3899,6 +3081,7 @@ jobs: GPU_ARCH_VERSION: 12.1 GPU_ARCH_TYPE: cuda DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.1-main + use_split_build: False DESIRED_PYTHON: "3.13" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_13-cuda12_1 @@ -3922,6 +3105,7 @@ jobs: GPU_ARCH_VERSION: 12.1 GPU_ARCH_TYPE: cuda DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.1-main + use_split_build: False DESIRED_PYTHON: "3.13" build_name: manywheel-py3_13-cuda12_1 build_environment: linux-binary-manywheel @@ -3945,6 +3129,7 @@ jobs: GPU_ARCH_VERSION: 12.1 GPU_ARCH_TYPE: cuda DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.1-main + use_split_build: False DESIRED_PYTHON: "3.13" build_name: manywheel-py3_13-cuda12_1 secrets: @@ -3953,77 +3138,6 @@ jobs: conda-pytorchbot-token-test: ${{ secrets.CONDA_PYTORCHBOT_TOKEN_TEST }} uses: ./.github/workflows/_binary-upload.yml - manywheel-py3_13-cuda12_1-split-build: - if: ${{ github.repository_owner == 'pytorch' }} - uses: ./.github/workflows/_binary-build-linux.yml - needs: get-label-type - with: - PYTORCH_ROOT: /pytorch - BUILDER_ROOT: /builder - PACKAGE_TYPE: manywheel - # TODO: This is a legacy variable that we eventually want to get rid of in - # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu121 - GPU_ARCH_VERSION: 12.1 - GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.1-main - use_split_build: True - DESIRED_PYTHON: "3.13" - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - build_name: manywheel-py3_13-cuda12_1-split - build_environment: linux-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.21.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' - secrets: - github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_13-cuda12_1-split-test: # Testing - if: ${{ github.repository_owner == 'pytorch' }} - needs: - - manywheel-py3_13-cuda12_1-split-build - - get-label-type - uses: ./.github/workflows/_binary-test-linux.yml - with: - PYTORCH_ROOT: /pytorch - BUILDER_ROOT: /builder - PACKAGE_TYPE: manywheel - # TODO: This is a legacy variable that we eventually want to get rid of in - # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu121 - GPU_ARCH_VERSION: 12.1 - GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.1-main - use_split_build: True - DESIRED_PYTHON: "3.13" - build_name: manywheel-py3_13-cuda12_1-split - build_environment: linux-binary-manywheel - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - runs_on: linux.4xlarge.nvidia.gpu - secrets: - github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_13-cuda12_1-split-upload: # Uploading - if: ${{ github.repository_owner == 'pytorch' }} - permissions: - id-token: write - contents: read - needs: manywheel-py3_13-cuda12_1-split-test - with: - PYTORCH_ROOT: /pytorch - BUILDER_ROOT: /builder - PACKAGE_TYPE: manywheel - # TODO: This is a legacy variable that we eventually want to get rid of in - # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu121 - GPU_ARCH_VERSION: 12.1 - GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.1-main - use_split_build: True - DESIRED_PYTHON: "3.13" - build_name: manywheel-py3_13-cuda12_1-split - secrets: - github-token: ${{ secrets.GITHUB_TOKEN }} - conda-pytorchbot-token: ${{ secrets.CONDA_PYTORCHBOT_TOKEN }} - conda-pytorchbot-token-test: ${{ secrets.CONDA_PYTORCHBOT_TOKEN_TEST }} - uses: ./.github/workflows/_binary-upload.yml - manywheel-py3_13-cuda12_4-build: if: ${{ github.repository_owner == 'pytorch' }} uses: ./.github/workflows/_binary-build-linux.yml @@ -4038,6 +3152,7 @@ jobs: GPU_ARCH_VERSION: 12.4 GPU_ARCH_TYPE: cuda DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.4-main + use_split_build: False DESIRED_PYTHON: "3.13" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_13-cuda12_4 @@ -4061,6 +3176,7 @@ jobs: GPU_ARCH_VERSION: 12.4 GPU_ARCH_TYPE: cuda DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.4-main + use_split_build: False DESIRED_PYTHON: "3.13" build_name: manywheel-py3_13-cuda12_4 build_environment: linux-binary-manywheel @@ -4084,6 +3200,7 @@ jobs: GPU_ARCH_VERSION: 12.4 GPU_ARCH_TYPE: cuda DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.4-main + use_split_build: False DESIRED_PYTHON: "3.13" build_name: manywheel-py3_13-cuda12_4 secrets: @@ -4092,77 +3209,6 @@ jobs: conda-pytorchbot-token-test: ${{ secrets.CONDA_PYTORCHBOT_TOKEN_TEST }} uses: ./.github/workflows/_binary-upload.yml - manywheel-py3_13-cuda12_4-split-build: - if: ${{ github.repository_owner == 'pytorch' }} - uses: ./.github/workflows/_binary-build-linux.yml - needs: get-label-type - with: - PYTORCH_ROOT: /pytorch - BUILDER_ROOT: /builder - PACKAGE_TYPE: manywheel - # TODO: This is a legacy variable that we eventually want to get rid of in - # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu124 - GPU_ARCH_VERSION: 12.4 - GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.4-main - use_split_build: True - DESIRED_PYTHON: "3.13" - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - build_name: manywheel-py3_13-cuda12_4-split - build_environment: linux-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.4.5.8; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.2.1.3; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.5.147; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.6.1.9; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.3.1.170; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.21.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' - secrets: - github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_13-cuda12_4-split-test: # Testing - if: ${{ github.repository_owner == 'pytorch' }} - needs: - - manywheel-py3_13-cuda12_4-split-build - - get-label-type - uses: ./.github/workflows/_binary-test-linux.yml - with: - PYTORCH_ROOT: /pytorch - BUILDER_ROOT: /builder - PACKAGE_TYPE: manywheel - # TODO: This is a legacy variable that we eventually want to get rid of in - # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu124 - GPU_ARCH_VERSION: 12.4 - GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.4-main - use_split_build: True - DESIRED_PYTHON: "3.13" - build_name: manywheel-py3_13-cuda12_4-split - build_environment: linux-binary-manywheel - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - runs_on: linux.4xlarge.nvidia.gpu - secrets: - github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_13-cuda12_4-split-upload: # Uploading - if: ${{ github.repository_owner == 'pytorch' }} - permissions: - id-token: write - contents: read - needs: manywheel-py3_13-cuda12_4-split-test - with: - PYTORCH_ROOT: /pytorch - BUILDER_ROOT: /builder - PACKAGE_TYPE: manywheel - # TODO: This is a legacy variable that we eventually want to get rid of in - # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu124 - GPU_ARCH_VERSION: 12.4 - GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.4-main - use_split_build: True - DESIRED_PYTHON: "3.13" - build_name: manywheel-py3_13-cuda12_4-split - secrets: - github-token: ${{ secrets.GITHUB_TOKEN }} - conda-pytorchbot-token: ${{ secrets.CONDA_PYTORCHBOT_TOKEN }} - conda-pytorchbot-token-test: ${{ secrets.CONDA_PYTORCHBOT_TOKEN_TEST }} - uses: ./.github/workflows/_binary-upload.yml - manywheel-py3_13-xpu-build: if: ${{ github.repository_owner == 'pytorch' }} uses: ./.github/workflows/_binary-build-linux.yml @@ -4176,6 +3222,7 @@ jobs: DESIRED_CUDA: xpu GPU_ARCH_TYPE: xpu DOCKER_IMAGE: pytorch/manylinux2_28-builder:xpu-main + use_split_build: False DESIRED_PYTHON: "3.13" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_13-xpu @@ -4199,6 +3246,7 @@ jobs: GPU_ARCH_TYPE: xpu SKIP_ALL_TESTS: 1 DOCKER_IMAGE: pytorch/manylinux2_28-builder:xpu-main + use_split_build: False DESIRED_PYTHON: "3.13" permissions: id-token: write @@ -4268,6 +3316,7 @@ jobs: DESIRED_CUDA: xpu GPU_ARCH_TYPE: xpu DOCKER_IMAGE: pytorch/manylinux2_28-builder:xpu-main + use_split_build: False DESIRED_PYTHON: "3.13" build_name: manywheel-py3_13-xpu secrets: diff --git a/.github/workflows/generated-linux-s390x-binary-manywheel-nightly.yml b/.github/workflows/generated-linux-s390x-binary-manywheel-nightly.yml index 574a2f87c014cc..22468055434e86 100644 --- a/.github/workflows/generated-linux-s390x-binary-manywheel-nightly.yml +++ b/.github/workflows/generated-linux-s390x-binary-manywheel-nightly.yml @@ -58,6 +58,7 @@ jobs: DESIRED_CUDA: cpu GPU_ARCH_TYPE: cpu-s390x DOCKER_IMAGE: pytorch/manylinuxs390x-builder:cpu-s390x-main + use_split_build: False DESIRED_PYTHON: "3.9" runs_on: linux.s390x ALPINE_IMAGE: "docker.io/s390x/alpine" @@ -81,6 +82,7 @@ jobs: DESIRED_CUDA: cpu GPU_ARCH_TYPE: cpu-s390x DOCKER_IMAGE: pytorch/manylinuxs390x-builder:cpu-s390x-main + use_split_build: False DESIRED_PYTHON: "3.9" build_name: manywheel-py3_9-cpu-s390x build_environment: linux-s390x-binary-manywheel @@ -103,6 +105,7 @@ jobs: DESIRED_CUDA: cpu GPU_ARCH_TYPE: cpu-s390x DOCKER_IMAGE: pytorch/manylinuxs390x-builder:cpu-s390x-main + use_split_build: False DESIRED_PYTHON: "3.9" build_name: manywheel-py3_9-cpu-s390x secrets: @@ -124,6 +127,7 @@ jobs: DESIRED_CUDA: cpu GPU_ARCH_TYPE: cpu-s390x DOCKER_IMAGE: pytorch/manylinuxs390x-builder:cpu-s390x-main + use_split_build: False DESIRED_PYTHON: "3.10" runs_on: linux.s390x ALPINE_IMAGE: "docker.io/s390x/alpine" @@ -147,6 +151,7 @@ jobs: DESIRED_CUDA: cpu GPU_ARCH_TYPE: cpu-s390x DOCKER_IMAGE: pytorch/manylinuxs390x-builder:cpu-s390x-main + use_split_build: False DESIRED_PYTHON: "3.10" build_name: manywheel-py3_10-cpu-s390x build_environment: linux-s390x-binary-manywheel @@ -169,6 +174,7 @@ jobs: DESIRED_CUDA: cpu GPU_ARCH_TYPE: cpu-s390x DOCKER_IMAGE: pytorch/manylinuxs390x-builder:cpu-s390x-main + use_split_build: False DESIRED_PYTHON: "3.10" build_name: manywheel-py3_10-cpu-s390x secrets: @@ -190,6 +196,7 @@ jobs: DESIRED_CUDA: cpu GPU_ARCH_TYPE: cpu-s390x DOCKER_IMAGE: pytorch/manylinuxs390x-builder:cpu-s390x-main + use_split_build: False DESIRED_PYTHON: "3.11" runs_on: linux.s390x ALPINE_IMAGE: "docker.io/s390x/alpine" @@ -213,6 +220,7 @@ jobs: DESIRED_CUDA: cpu GPU_ARCH_TYPE: cpu-s390x DOCKER_IMAGE: pytorch/manylinuxs390x-builder:cpu-s390x-main + use_split_build: False DESIRED_PYTHON: "3.11" build_name: manywheel-py3_11-cpu-s390x build_environment: linux-s390x-binary-manywheel @@ -235,6 +243,7 @@ jobs: DESIRED_CUDA: cpu GPU_ARCH_TYPE: cpu-s390x DOCKER_IMAGE: pytorch/manylinuxs390x-builder:cpu-s390x-main + use_split_build: False DESIRED_PYTHON: "3.11" build_name: manywheel-py3_11-cpu-s390x secrets: @@ -256,6 +265,7 @@ jobs: DESIRED_CUDA: cpu GPU_ARCH_TYPE: cpu-s390x DOCKER_IMAGE: pytorch/manylinuxs390x-builder:cpu-s390x-main + use_split_build: False DESIRED_PYTHON: "3.12" runs_on: linux.s390x ALPINE_IMAGE: "docker.io/s390x/alpine" @@ -279,6 +289,7 @@ jobs: DESIRED_CUDA: cpu GPU_ARCH_TYPE: cpu-s390x DOCKER_IMAGE: pytorch/manylinuxs390x-builder:cpu-s390x-main + use_split_build: False DESIRED_PYTHON: "3.12" build_name: manywheel-py3_12-cpu-s390x build_environment: linux-s390x-binary-manywheel @@ -301,6 +312,7 @@ jobs: DESIRED_CUDA: cpu GPU_ARCH_TYPE: cpu-s390x DOCKER_IMAGE: pytorch/manylinuxs390x-builder:cpu-s390x-main + use_split_build: False DESIRED_PYTHON: "3.12" build_name: manywheel-py3_12-cpu-s390x secrets: @@ -322,6 +334,7 @@ jobs: DESIRED_CUDA: cpu GPU_ARCH_TYPE: cpu-s390x DOCKER_IMAGE: pytorch/manylinuxs390x-builder:cpu-s390x-main + use_split_build: False DESIRED_PYTHON: "3.13" runs_on: linux.s390x ALPINE_IMAGE: "docker.io/s390x/alpine" @@ -345,6 +358,7 @@ jobs: DESIRED_CUDA: cpu GPU_ARCH_TYPE: cpu-s390x DOCKER_IMAGE: pytorch/manylinuxs390x-builder:cpu-s390x-main + use_split_build: False DESIRED_PYTHON: "3.13" build_name: manywheel-py3_13-cpu-s390x build_environment: linux-s390x-binary-manywheel @@ -367,6 +381,7 @@ jobs: DESIRED_CUDA: cpu GPU_ARCH_TYPE: cpu-s390x DOCKER_IMAGE: pytorch/manylinuxs390x-builder:cpu-s390x-main + use_split_build: False DESIRED_PYTHON: "3.13" build_name: manywheel-py3_13-cpu-s390x secrets: From dc8aa60ca9e23922be14a3ef3ac97378e9aa09fa Mon Sep 17 00:00:00 2001 From: Shangdi Yu Date: Sat, 7 Sep 2024 00:02:35 +0000 Subject: [PATCH 0585/1018] Change test_constant_prop_preserve_metadata (#135268) Summary: In new export_for_training, "stack_trace" does not exist in node meta anymore. Test Plan: ``` buck run fbcode//mode/dev-nosan fbcode//caffe2/test:quantization_pt2e -- -r test_constant_prop_preserve_metadata ``` Reviewed By: angelayi Differential Revision: D62219974 Pull Request resolved: https://github.com/pytorch/pytorch/pull/135268 Approved by: https://github.com/angelayi --- test/quantization/pt2e/test_quantize_pt2e.py | 1 - 1 file changed, 1 deletion(-) diff --git a/test/quantization/pt2e/test_quantize_pt2e.py b/test/quantization/pt2e/test_quantize_pt2e.py index eef06f19cc8ed2..62fbbf7162747e 100644 --- a/test/quantization/pt2e/test_quantize_pt2e.py +++ b/test/quantization/pt2e/test_quantize_pt2e.py @@ -1453,7 +1453,6 @@ def forward(self, x): for n in m.graph.nodes: if n.op == "get_attr" and "frozen_param" in n.target: - self.assertIn("stack_trace", n.meta) for key in n.meta: self.assertEqual(n.meta[key], weight_meta[key]) From 98a6753231575b835a6fdef8f10af24c7af3c69c Mon Sep 17 00:00:00 2001 From: Will Feng Date: Fri, 6 Sep 2024 13:49:52 -0700 Subject: [PATCH 0586/1018] [Dynamo][DTensor] Fixes SymNodeVariable() is not a constant error in Compiled DDP + TP unit test (#135315) Before the fix, the unit test will fail at forward Dynamo tracing: ``` File "/data/users/willfeng/pytorch/test/distributed/_composable/test_replicate_with_compiler.py", line 415, in test_ddp_tp loss = compiled_replicate_model(data).sum() ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ ... torch._dynamo.exc.InternalTorchDynamoError: SymNodeVariable() is not a constant from user code: File "/data/users/willfeng/pytorch/torch/distributed/tensor/parallel/_data_parallel_utils.py", line 34, in _unflatten_tensor result = DTensor.from_local( ``` After the fix, the compilation fails at a later step (Compiled Autograd tracing), due to needing "pre-dispatch tracing of backward graph" feature (see details at https://github.com/pytorch/pytorch/issues/127797#issuecomment-2291695474). I believe this PR is a net improvement, because it should also fix the 1D Traceable FSDP2 failure case on internal models (https://github.com/pytorch/pytorch/issues/130978#issuecomment-2319476690), which is much harder to build a minimal unit test for. Fixes https://github.com/pytorch/pytorch/issues/130978. Pull Request resolved: https://github.com/pytorch/pytorch/pull/135315 Approved by: https://github.com/bdhirsh --- .../test_replicate_with_compiler.py | 62 +++++++++--------- .../_tensor/test_dtensor_compile.py | 64 +++++++++++++++++++ torch/_dynamo/variables/torch.py | 22 +++++-- 3 files changed, 113 insertions(+), 35 deletions(-) diff --git a/test/distributed/_composable/test_replicate_with_compiler.py b/test/distributed/_composable/test_replicate_with_compiler.py index 41ce99090018b0..14d9cc47271f59 100644 --- a/test/distributed/_composable/test_replicate_with_compiler.py +++ b/test/distributed/_composable/test_replicate_with_compiler.py @@ -32,6 +32,7 @@ skip_if_rocm, ) from torch.testing._internal.common_utils import run_tests +from torch.testing._internal.distributed.fake_pg import FakeStore from torch.utils._triton import has_triton from torch.utils.checkpoint import checkpoint @@ -367,35 +368,28 @@ def test_bucketing_concat_op(self): fc.run(code) -class DDP_TP_Test(MultiProcessInductorTestCase): - @property - def world_size(self) -> int: - return min(4, torch.cuda.device_count()) +class DDP_TP_Test(InductorTestCase): + def setUp(self): + self.rank = 0 + self.world_size = 4 + torch.cuda.set_device("cuda:0") - def setUp(self) -> None: - super().setUp() - self._spawn_processes() + store = FakeStore() + dist.init_process_group( + backend="fake", + world_size=self.world_size, + rank=self.rank, + store=store, + ) def tearDown(self): - super().tearDown() - try: - os.remove(self.file_name) - except OSError: - pass + dist.destroy_process_group() @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") @skip_if_rocm - @skip_if_lt_x_gpu(4) def test_ddp_tp(self): - torch.cuda.set_device(f"cuda:{self.rank}") - dist.init_process_group( - backend="nccl", - rank=self.rank, - world_size=self.world_size, - store=dist.FileStore(self.file_name, self.world_size), - ) - model = Net().cuda() - compiled_replicate_model = deepcopy(model) + ref_model = Net() + compiled_replicate_model = deepcopy(ref_model) mesh_2d = init_device_mesh( "cuda", (2, self.world_size // 2), mesh_dim_names=("dp", "tp") ) @@ -407,8 +401,8 @@ def test_ddp_tp(self): "fc3": ColwiseParallel(), "fc4": RowwiseParallel(), } - model = parallelize_module(model, tp_mesh, parallelize_plan) - model = replicate(model, device_mesh=dp_mesh) + ref_model = parallelize_module(ref_model, tp_mesh, parallelize_plan) + ref_model = replicate(ref_model, device_mesh=dp_mesh) compiled_replicate_model = parallelize_module( compiled_replicate_model, tp_mesh, parallelize_plan ) @@ -416,15 +410,23 @@ def test_ddp_tp(self): compiled_replicate_model, device_mesh=dp_mesh ) compiled_replicate_model = torch.compile(compiled_replicate_model) - data = torch.randn([1, DIM]).cuda() + data = torch.randn([1, DIM]) with compiled_autograd.enable(compiler_fn()): loss = compiled_replicate_model(data).sum() - loss.backward() + # TODO: We need "pre-dispatch tracing of backward graph" to make this work: + # https://github.com/pytorch/pytorch/issues/127797#issuecomment-2291695474 + with self.assertRaisesRegex( + AssertionError, + "Expected ProxyTensor, got ", + ): + loss.backward() - loss = model(data).sum() - loss.backward() - for p1, p2 in zip(model.parameters(), compiled_replicate_model.parameters()): - self.assertEqual(p1.grad, p2.grad) + # ref_loss = ref_model(data).sum() + # ref_loss.backward() + # for p1, p2 in zip( + # ref_model.parameters(), compiled_replicate_model.parameters() + # ): + # self.assertEqual(p1.grad, p2.grad) if __name__ == "__main__": diff --git a/test/distributed/_tensor/test_dtensor_compile.py b/test/distributed/_tensor/test_dtensor_compile.py index a58ab232748c83..eaef0716ba30dc 100644 --- a/test/distributed/_tensor/test_dtensor_compile.py +++ b/test/distributed/_tensor/test_dtensor_compile.py @@ -299,6 +299,70 @@ def from_local_kwargs_fn(x): self.assertEqual(res, ref) self.assertEqual(cnt.frame_count, 2) + def test_dynamo_dtensor_from_local_dynamic_shapes(self): + mesh = DeviceMesh(self.device_type, torch.arange(self.world_size)) + + # Case 1: all dims dynamic + def fn(x): + dt = DTensor.from_local( + x, + mesh, + [Replicate()], + run_check=False, + shape=x.shape, + stride=x.stride(), + ) + return dt.to_local() + 2 + + inp = torch.randn(4, 6, requires_grad=True) + ref = fn(inp) + cnt = torch._dynamo.testing.CompileCounterWithBackend("aot_eager") + res = torch.compile(fn, backend=cnt, fullgraph=True, dynamic=True)(inp) + res.sum().backward() + + self.assertEqual(res, ref) + self.assertEqual(cnt.frame_count, 1) + + # Case 2: only sizes are dynamic, strides are static + def fn(x): + dt = DTensor.from_local( + x, mesh, [Replicate()], run_check=False, shape=x.shape, stride=(1,) + ) + return dt.to_local() + 2 + + inp = torch.randn(4, requires_grad=True) + torch._dynamo.mark_dynamic(inp, 0) + ref = fn(inp) + cnt = torch._dynamo.testing.CompileCounterWithBackend("aot_eager") + res = torch.compile(fn, backend=cnt, fullgraph=True)(inp) + res.sum().backward() + + self.assertEqual(res, ref) + self.assertEqual(cnt.frame_count, 1) + + # Case 3: both sizes and strides have a mix of dynamic and static dims + def fn(x): + dt = DTensor.from_local( + x, + mesh, + [Replicate()], + run_check=False, + shape=(x.shape[0], x.shape[1], 2), + stride=(x.stride()[0], 2, 1), + ) + return dt.to_local() + 2 + + inp = torch.randn(4, 6, 2, requires_grad=True) + torch._dynamo.mark_dynamic(inp, 0) + torch._dynamo.mark_dynamic(inp, 1) + ref = fn(inp) + cnt = torch._dynamo.testing.CompileCounterWithBackend("aot_eager") + res = torch.compile(fn, backend=cnt, fullgraph=True)(inp) + res.sum().backward() + + self.assertEqual(res, ref) + self.assertEqual(cnt.frame_count, 1) + def test_dynamo_dtensor_recompile(self): mesh = DeviceMesh(self.device_type, torch.arange(self.world_size)) diff --git a/torch/_dynamo/variables/torch.py b/torch/_dynamo/variables/torch.py index caab550afb82b5..34f8166c7bbcde 100644 --- a/torch/_dynamo/variables/torch.py +++ b/torch/_dynamo/variables/torch.py @@ -671,10 +671,19 @@ def handle_from_local(self, tx: "InstructionTranslator", *args, **kwargs): # rewrite non-primitive args/kwargs to be included in the on-the-fly prim function # and rewrite args to have only proxyable args, then insert call_function args_as_value = [x.as_python_constant() for x in args[1:]] - kwargs_as_value = {k: v.as_python_constant() for k, v in kwargs.items()} - - def fn_with_prim_types(x): - return self.value(x, *args_as_value, **kwargs_as_value) + kwargs_as_value = { + k: v.as_python_constant() + for k, v in kwargs.items() + if k not in ["shape", "stride"] + } + kwargs_to_be_proxied = { + k: kwargs[k] for k in ["shape", "stride"] if k in kwargs + } + + def fn_with_prim_types(x, shape=None, stride=None): + return self.value( + x, *args_as_value, **kwargs_as_value, shape=shape, stride=stride + ) # attach the same function name for better debugging fn_with_prim_types.__name__ = "prim " + self.value.__name__ @@ -684,7 +693,10 @@ def fn_with_prim_types(x): proxy=tx.output.create_proxy( "call_function", fn_with_prim_types, - *proxy_args_kwargs([args[0]], {}), + *proxy_args_kwargs( + [args[0]], + kwargs_to_be_proxied, + ), ), ) From 644be2655a0172a0caa7ff4b80a455f06a0c25a1 Mon Sep 17 00:00:00 2001 From: Sergii Dymchenko Date: Sat, 7 Sep 2024 00:49:53 +0000 Subject: [PATCH 0587/1018] Add release matrix for 2.5 (#135383) Pull Request resolved: https://github.com/pytorch/pytorch/pull/135383 Approved by: https://github.com/huydhn --- RELEASE.md | 1 + 1 file changed, 1 insertion(+) diff --git a/RELEASE.md b/RELEASE.md index 476cf199fdbbeb..59a3336b225331 100644 --- a/RELEASE.md +++ b/RELEASE.md @@ -50,6 +50,7 @@ Following is the Release Compatibility Matrix for PyTorch releases: | PyTorch version | Python | Stable CUDA | Experimental CUDA | Stable ROCm | | --- | --- | --- | --- | --- | +| 2.5 | >=3.9, <=3.12, (3.13 experimental) | CUDA 11.8, CUDA 12.1, CUDA 12.4, CUDNN 9.1.0.70 | None | ROCm 6.2 | | 2.4 | >=3.8, <=3.12 | CUDA 11.8, CUDA 12.1, CUDNN 9.1.0.70 | CUDA 12.4, CUDNN 9.1.0.70 | ROCm 6.1 | | 2.3 | >=3.8, <=3.11, (3.12 experimental) | CUDA 11.8, CUDNN 8.7.0.84 | CUDA 12.1, CUDNN 8.9.2.26 | ROCm 6.0 | | 2.2 | >=3.8, <=3.11, (3.12 experimental) | CUDA 11.8, CUDNN 8.7.0.84 | CUDA 12.1, CUDNN 8.9.2.26 | ROCm 5.7 | From a581119d3f6e3ff918cf77c6f93241053f7eeece Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Sat, 7 Sep 2024 00:50:13 +0000 Subject: [PATCH 0588/1018] [ONNX] Refactor exporter errors (#135180) Refactor exporter errors to combine old errors and new errors for API consistency. This PR also 1. Removes the `_C._check_onnx_proto(proto)` call in the old exporter. We don't need the ONNX checker because it is limited. 2. Removes the `OnnxExporterError` defined in the dynamo module. This class unnecessarily stores the onnx program object, making it very bulky. Instead, we revert to use the plain OnnxExporterError defined in the `errors` module and use it as the base class for all errors. 3. Continues to expose `OnnxExporterError` in `torch.onnx` and the rest of the errors in `torch.onnx.errors`. 4. Removes the `CheckerError` and `InvalidExportOptionsError` from `torch.onnx`. This is BC breaking but should have low impact. 5. I did not rename existing errors out of compatibility considerations, even though `ExporterError` would have been more succinct. Fixes https://github.com/pytorch/pytorch/issues/135125 Pull Request resolved: https://github.com/pytorch/pytorch/pull/135180 Approved by: https://github.com/titaiwangms --- docs/source/onnx_dynamo.rst | 3 -- test/onnx/dynamo/test_exporter_api.py | 41 --------------- test/onnx/onnx_test_common.py | 37 ++++--------- test/onnx/test_fx_to_onnx.py | 14 ----- test/onnx/test_pytorch_onnx_no_runtime.py | 27 ---------- test/onnx/test_pytorch_onnx_onnxruntime.py | 2 +- torch/onnx/__init__.py | 23 ++++----- torch/onnx/_internal/_exporter_legacy.py | 60 ++++++---------------- torch/onnx/_internal/exporter/_building.py | 10 ++-- torch/onnx/_internal/exporter/_core.py | 16 +++--- torch/onnx/_internal/exporter/_errors.py | 21 ++++++++ torch/onnx/_internal/exporter/errors.py | 30 ----------- torch/onnx/errors.py | 33 ++++++------ torch/onnx/utils.py | 12 ----- 14 files changed, 87 insertions(+), 242 deletions(-) create mode 100644 torch/onnx/_internal/exporter/_errors.py delete mode 100644 torch/onnx/_internal/exporter/errors.py diff --git a/docs/source/onnx_dynamo.rst b/docs/source/onnx_dynamo.rst index 5694c40d83c4f5..89b995687dab8e 100644 --- a/docs/source/onnx_dynamo.rst +++ b/docs/source/onnx_dynamo.rst @@ -146,9 +146,6 @@ API Reference .. autoclass:: torch.onnx.ONNXRuntimeOptions :members: -.. autoclass:: torch.onnx.InvalidExportOptionsError - :members: - .. autoclass:: torch.onnx.OnnxExporterError :members: diff --git a/test/onnx/dynamo/test_exporter_api.py b/test/onnx/dynamo/test_exporter_api.py index 17d55f0374c59a..e2c8cf3df0b6c3 100644 --- a/test/onnx/dynamo/test_exporter_api.py +++ b/test/onnx/dynamo/test_exporter_api.py @@ -1,12 +1,10 @@ # Owner(s): ["module: onnx"] import io -import os import onnx import torch from torch.onnx import dynamo_export, ExportOptions, ONNXProgram -from torch.onnx._internal import _exporter_legacy from torch.onnx._internal._exporter_legacy import ResolvedExportOptions from torch.testing._internal import common_utils @@ -72,45 +70,6 @@ def test_save_to_existing_buffer_default_serializer(self): dynamo_export(SampleModel(), torch.randn(1, 1, 2)).save(buffer) onnx.load(buffer) - def test_save_sarif_log_to_file_with_successful_export(self): - with common_utils.TemporaryFileName(suffix=".sarif") as path: - dynamo_export(SampleModel(), torch.randn(1, 1, 2)).save_diagnostics(path) - self.assertTrue(os.path.exists(path)) - - def test_save_sarif_log_to_file_with_failed_export(self): - class ModelWithExportError(torch.nn.Module): - def forward(self, x): - raise RuntimeError("Export error") - - with self.assertRaises(RuntimeError): - dynamo_export(ModelWithExportError(), torch.randn(1, 1, 2)) - self.assertTrue( - os.path.exists(_exporter_legacy._DEFAULT_FAILED_EXPORT_SARIF_LOG_PATH) - ) - - def test_onnx_program_accessible_from_exception_when_export_failed(self): - class ModelWithExportError(torch.nn.Module): - def forward(self, x): - raise RuntimeError("Export error") - - with self.assertRaises(torch.onnx.OnnxExporterError) as cm: - dynamo_export(ModelWithExportError(), torch.randn(1, 1, 2)) - self.assertIsInstance(cm.exception, torch.onnx.OnnxExporterError) - self.assertIsInstance(cm.exception.onnx_program, ONNXProgram) - - def test_access_onnx_program_model_proto_raises_when_onnx_program_is_emitted_from_failed_export( - self, - ): - class ModelWithExportError(torch.nn.Module): - def forward(self, x): - raise RuntimeError("Export error") - - with self.assertRaises(torch.onnx.OnnxExporterError) as cm: - dynamo_export(ModelWithExportError(), torch.randn(1, 1, 2)) - onnx_program = cm.exception.onnx_program - with self.assertRaises(RuntimeError): - onnx_program.model_proto - def test_raise_from_diagnostic_warning_when_diagnostic_option_warning_as_error_is_true( self, ): diff --git a/test/onnx/onnx_test_common.py b/test/onnx/onnx_test_common.py index 79d59a1e816816..69c35e44d57555 100644 --- a/test/onnx/onnx_test_common.py +++ b/test/onnx/onnx_test_common.py @@ -32,7 +32,6 @@ import torch from torch import export as torch_export from torch.onnx import _constants, verification -from torch.onnx._internal.fx import diagnostics from torch.testing._internal import common_utils from torch.testing._internal.opinfo import core as opinfo_core from torch.types import Number @@ -286,35 +285,19 @@ def run_test_with_fx_to_onnx_exporter_and_onnx_runtime( # Feed args and kwargs into exporter. # Note that exporter should flatten kwargs into positional args the exported model; # since ONNX doesn't represent kwargs. - export_error: Optional[torch.onnx.OnnxExporterError] = None - try: - with _dynamo_config.patch(do_not_emit_runtime_asserts=True): - onnx_program = torch.onnx.dynamo_export( - ref_model, - *ref_input_args, - **ref_input_kwargs, - export_options=torch.onnx.ExportOptions( - dynamic_shapes=self.dynamic_shapes, - diagnostic_options=torch.onnx.DiagnosticOptions( - verbosity_level=logging.DEBUG - ), + with _dynamo_config.patch(do_not_emit_runtime_asserts=True): + onnx_program = torch.onnx.dynamo_export( + ref_model, + *ref_input_args, + **ref_input_kwargs, + export_options=torch.onnx.ExportOptions( + dynamic_shapes=self.dynamic_shapes, + diagnostic_options=torch.onnx.DiagnosticOptions( + verbosity_level=logging.DEBUG ), - ) - except torch.onnx.OnnxExporterError as e: - export_error = e - onnx_program = e.onnx_program - - if diagnostics.is_onnx_diagnostics_log_artifact_enabled(): - onnx_program.save_diagnostics( - f"test_report_{self._testMethodName}" - f"_dynamic_axes_{self.dynamic_shapes}" - f"_model_type_{self.model_type}" - ".sarif" + ), ) - if export_error is not None: - raise export_error - if not skip_dynamic_shapes_check: assert_dynamic_shapes(onnx_program, self.dynamic_shapes) diff --git a/test/onnx/test_fx_to_onnx.py b/test/onnx/test_fx_to_onnx.py index 3c973d80203ee2..1946ba586ac0d2 100644 --- a/test/onnx/test_fx_to_onnx.py +++ b/test/onnx/test_fx_to_onnx.py @@ -242,20 +242,6 @@ def forward(self, input): export_options=torch.onnx.ExportOptions(onnx_registry=registry), ) - try: - torch.onnx.dynamo_export( - TraceModel(), - x, - export_options=torch.onnx.ExportOptions(onnx_registry=registry), - ) - except torch.onnx.OnnxExporterError as e: - assert_has_diagnostics( - e.onnx_program.diagnostic_context, - diagnostics.rules.no_symbolic_function_for_call_function, - diagnostics.levels.ERROR, - expected_node="aten.mul.Tensor", - ) - def test_symbolic_shape_of_values_inside_function_is_exported_as_graph_value_info( self, ): diff --git a/test/onnx/test_pytorch_onnx_no_runtime.py b/test/onnx/test_pytorch_onnx_no_runtime.py index 64cbba6fc15fc5..37ca3836e53876 100644 --- a/test/onnx/test_pytorch_onnx_no_runtime.py +++ b/test/onnx/test_pytorch_onnx_no_runtime.py @@ -600,33 +600,6 @@ def forward(self, x, y): + ")." ) - def test_onnx_checker_invalid_graph(self): - class CustomAddModule(torch.nn.Module): - def forward(self, x, y): - return torch.add(x, y) - - def symbolic_custom_invalid_add(g, input, other, alpha=None): - return g.op("Add", input, other, invalid_attr_i=1) - - torch.onnx.register_custom_op_symbolic( - "::add", symbolic_custom_invalid_add, opset_version=9 - ) - - x = torch.randn(2, 3, 4) - y = torch.randn(2, 3, 4) - - test_model = CustomAddModule() - f = io.BytesIO() - - try: - with self.assertRaises(torch.onnx.errors.CheckerError): - torch.onnx.export(test_model, (x, y), f, opset_version=9) - finally: - torch.onnx.unregister_custom_op_symbolic("::add", 9) - - self.assertTrue(f.getvalue(), "ONNX graph was not exported.") - loaded_model = onnx.load_from_string(f.getvalue()) - def test_shape_value_map(self): class RSoftMax(torch.nn.Module): def __init__(self, radix, cardinality): diff --git a/test/onnx/test_pytorch_onnx_onnxruntime.py b/test/onnx/test_pytorch_onnx_onnxruntime.py index 5d2dc4bdf29f8a..091209a8a4f697 100644 --- a/test/onnx/test_pytorch_onnx_onnxruntime.py +++ b/test/onnx/test_pytorch_onnx_onnxruntime.py @@ -13582,7 +13582,7 @@ def forward(self, input, grid): ): if self.opset_version < 20: with self.assertRaises( - torch.onnx.errors.OnnxExporterError, + torch.onnx.OnnxExporterError, ): self.run_test( GridSampleModule(mode, padding_mode, align_corners), diff --git a/torch/onnx/__init__.py b/torch/onnx/__init__.py index 963766a33e3ba4..d4d5d2c71255d8 100644 --- a/torch/onnx/__init__.py +++ b/torch/onnx/__init__.py @@ -38,15 +38,13 @@ "unregister_custom_op_symbolic", "disable_log", "enable_log", - # Errors - "CheckerError", # Backwards compatibility + # Base error + "OnnxExporterError", # Dynamo Exporter "DiagnosticOptions", "ExportOptions", "ONNXProgram", "ONNXRuntimeOptions", - "InvalidExportOptionsError", - "OnnxExporterError", "OnnxRegistry", "dynamo_export", "enable_fake_mode", @@ -69,7 +67,7 @@ OrtExecutionProvider as _OrtExecutionProvider, ) from ._type_utils import JitScalarType -from .errors import CheckerError # Backwards compatibility +from .errors import OnnxExporterError from .utils import ( _optimize_graph, _run_symbolic_function, @@ -109,8 +107,6 @@ ExportOptions, ONNXProgram, ONNXRuntimeOptions, - InvalidExportOptionsError, - OnnxExporterError, OnnxRegistry, enable_fake_mode, ) @@ -120,20 +116,19 @@ import os # Set namespace for exposed private names +DiagnosticOptions.__module__ = "torch.onnx" +ExportOptions.__module__ = "torch.onnx" ExportTypes.__module__ = "torch.onnx" JitScalarType.__module__ = "torch.onnx" -ExportOptions.__module__ = "torch.onnx" ONNXProgram.__module__ = "torch.onnx" ONNXRuntimeOptions.__module__ = "torch.onnx" -InvalidExportOptionsError.__module__ = "torch.onnx" OnnxExporterError.__module__ = "torch.onnx" -enable_fake_mode.__module__ = "torch.onnx" OnnxRegistry.__module__ = "torch.onnx" -DiagnosticOptions.__module__ = "torch.onnx" -is_onnxrt_backend_supported.__module__ = "torch.onnx" -_OrtExecutionProvider.__module__ = "torch.onnx" -_OrtBackendOptions.__module__ = "torch.onnx" _OrtBackend.__module__ = "torch.onnx" +_OrtBackendOptions.__module__ = "torch.onnx" +_OrtExecutionProvider.__module__ = "torch.onnx" +enable_fake_mode.__module__ = "torch.onnx" +is_onnxrt_backend_supported.__module__ = "torch.onnx" producer_name = "pytorch" producer_version = _C_onnx.PRODUCER_VERSION diff --git a/torch/onnx/_internal/_exporter_legacy.py b/torch/onnx/_internal/_exporter_legacy.py index 31e3ae4b5a338e..3c05b350e8c64f 100644 --- a/torch/onnx/_internal/_exporter_legacy.py +++ b/torch/onnx/_internal/_exporter_legacy.py @@ -1,7 +1,19 @@ # mypy: allow-untyped-defs -from __future__ import ( # for onnx.ModelProto (ONNXProgram) and onnxruntime (ONNXRuntimeOptions) - annotations, -) +from __future__ import annotations + + +__all__ = [ + "DiagnosticOptions", + "ExportOptions", + "ONNXProgram", + "ONNXRuntimeOptions", + "InvalidExportOptionsError", + "OnnxRegistry", + "UnsatisfiedDependencyError", + "dynamo_export", + "enable_fake_mode", +] + import abc import contextlib @@ -17,6 +29,7 @@ import torch import torch._ops import torch.utils._pytree as pytree +from torch.onnx import errors from torch.onnx._internal import io_adapter from torch.onnx._internal.diagnostics import infra from torch.onnx._internal.fx import ( @@ -1057,28 +1070,6 @@ def __init__(self, package_name: str, message: str): self.package_name = package_name -class OnnxExporterError(RuntimeError): - """Raised when an ONNX exporter error occurs. - - This exception is thrown when there's an error during the ONNX export process. - It encapsulates the :class:`ONNXProgram` object generated until the failure, allowing - access to the partial export results and associated metadata. - """ - - onnx_program: Final[ONNXProgram] # type: ignore[misc] - - def __init__(self, onnx_program: ONNXProgram, message: str): - """ - Initializes the OnnxExporterError with the given ONNX program and message. - - Args: - onnx_program (ONNXProgram): The partial results of the ONNX export. - message (str): The error message to be displayed. - """ - super().__init__(message) - self.onnx_program = onnx_program - - class InvalidExportOptionsError(RuntimeError): """Raised when user specified an invalid value for the :class:`ExportOptions`.""" @@ -1228,10 +1219,7 @@ def forward(self, x, bias=None): "or SARIF web viewer (https://microsoft.github.io/sarif-web-component/). " f"Please report a bug on PyTorch Github: {_PYTORCH_GITHUB_ISSUES_URL}" ) - raise OnnxExporterError( - ONNXProgram._from_failure(e, resolved_export_options.diagnostic_context), - message, - ) from e + raise errors.OnnxExporterError(message) from e def common_pre_export_passes( @@ -1309,17 +1297,3 @@ def common_pre_export_passes( ) return module - - -__all__ = [ - "DiagnosticOptions", - "ExportOptions", - "ONNXProgram", - "ONNXRuntimeOptions", - "InvalidExportOptionsError", - "OnnxExporterError", - "OnnxRegistry", - "UnsatisfiedDependencyError", - "dynamo_export", - "enable_fake_mode", -] diff --git a/torch/onnx/_internal/exporter/_building.py b/torch/onnx/_internal/exporter/_building.py index 4c5a7e32f244cb..48d5728c1cfb66 100644 --- a/torch/onnx/_internal/exporter/_building.py +++ b/torch/onnx/_internal/exporter/_building.py @@ -20,7 +20,7 @@ from onnxscript.ir import convenience as ir_convenience import torch -from torch.onnx._internal.exporter import _schemas, _tensors, errors +from torch.onnx._internal.exporter import _errors, _schemas, _tensors if TYPE_CHECKING: @@ -372,7 +372,7 @@ def _call_op( op_signature, named_inputs, type_binding, self.constant_farm, self.opset ) except Exception as e: - raise errors.GraphConstructionError( + raise _errors.GraphConstructionError( f"Error processing Python constants for operator '{op_signature.domain}::{op_signature.name}'. " f"named_inputs={named_inputs}, named_attrs={named_attrs}, opset={self.opset}, op_signature={op_signature}." ) from e @@ -384,7 +384,7 @@ def _call_op( ) ) except Exception as e: - raise errors.GraphConstructionError( + raise _errors.GraphConstructionError( f"Error constructing node for operator '{op_signature.domain}::{op_signature.name}'. " f"named_inputs={named_inputs}, converted_named_inputs={converted_named_inputs}, " f"named_attrs={named_attrs}, opset={self.opset}, op_signature={op_signature}." @@ -428,7 +428,7 @@ def eval( return outputs[0] return outputs except Exception as e: - raise errors.GraphConstructionError( + raise _errors.GraphConstructionError( f"Error calling operator '{schema.name}' with args {args} and kwargs {kwargs}." ) from e @@ -513,7 +513,7 @@ def eval_function( # type: ignore[override] _, lineno = inspect.getsourcelines(function.function) except Exception: source_file = lineno = None - raise errors.GraphConstructionError( + raise _errors.GraphConstructionError( f"Error calling function '{function.name}' with args {args} and kwargs {kwargs}." + f" The function is defined at '{source_file}:{lineno}'." if source_file diff --git a/torch/onnx/_internal/exporter/_core.py b/torch/onnx/_internal/exporter/_core.py index 5994db2c4e0e55..2daa37eb8cd500 100644 --- a/torch/onnx/_internal/exporter/_core.py +++ b/torch/onnx/_internal/exporter/_core.py @@ -30,6 +30,7 @@ _building, _capture_strategies, _dispatching, + _errors, _fx_passes, _ir_passes, _onnx_program, @@ -37,7 +38,6 @@ _reporting, _tensors, _verification, - errors, ) @@ -426,7 +426,7 @@ def _handle_call_function_node_with_lowering( if onnx_function is None: # TODO(justinchuby): Fall back to ATen op or do something else? - raise errors.DispatchError( + raise _errors.DispatchError( f"No ONNX function found for {node.target!r}. Failure message: {message}" ) @@ -453,7 +453,7 @@ def _handle_call_function_node_with_lowering( try: outputs = onnx_function(*onnx_args, **onnx_kwargs) except Exception as e: - raise errors.GraphConstructionError( + raise _errors.GraphConstructionError( f"Error when calling function '{onnx_function}' with args '{onnx_args}' and kwargs '{onnx_kwargs}'" ) from e @@ -547,7 +547,7 @@ def _add_nodes( # No lowering _handle_call_function_node(model.graph, node, node_name_to_values) except Exception as e: - raise errors.OnnxConversionError( + raise _errors.ConversionError( f"Error when translating node {node.format_node()}. See the stack trace for more information." ) from e return node_name_to_values @@ -958,7 +958,7 @@ def export( Raises: TorchExportError: If the export process fails with torch.export. - OnnxConversionError: If the ExportedProgram to ONNX translation fails. + ConversionError: If the ExportedProgram to ONNX translation fails. """ # Set up the error reporting facilities timestamp = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S-%f") @@ -1039,7 +1039,7 @@ def export( # focus on the torch.export.export error. Errors from other strategies like # torch.jit.trace is due to the fallback and can be confusing to users. # We save all errors in the error report. - raise errors.TorchExportError( + raise _errors.TorchExportError( _STEP_ONE_ERROR_MESSAGE + ( f"\nError report has been saved to '{report_path}'." @@ -1099,7 +1099,7 @@ def export( else: report_path = None - raise errors.OnnxConversionError( + raise _errors.ConversionError( _STEP_TWO_ERROR_MESSAGE + (f"\nError report has been saved to '{report_path}'." if report else "") + _summarize_exception_stack(e) @@ -1163,7 +1163,7 @@ def export( else: report_path = None - raise errors.OnnxConversionError( + raise _errors.ConversionError( _STEP_TWO_ERROR_MESSAGE + (f"\nError report has been saved to '{report_path}'." if report else "") + _summarize_exception_stack(e) diff --git a/torch/onnx/_internal/exporter/_errors.py b/torch/onnx/_internal/exporter/_errors.py new file mode 100644 index 00000000000000..ff41bbe695fe7d --- /dev/null +++ b/torch/onnx/_internal/exporter/_errors.py @@ -0,0 +1,21 @@ +"""Error classes for the ONNX exporter.""" + +from __future__ import annotations + +import torch.onnx.errors + + +class TorchExportError(torch.onnx.errors.OnnxExporterError): + """Error during graph capturing using torch.export.""" + + +class ConversionError(torch.onnx.errors.OnnxExporterError): + """Error during ExportedProgram to ONNX conversion.""" + + +class DispatchError(ConversionError): + """Error during ONNX Function dispatching.""" + + +class GraphConstructionError(ConversionError): + """Error during ONNX graph construction.""" diff --git a/torch/onnx/_internal/exporter/errors.py b/torch/onnx/_internal/exporter/errors.py deleted file mode 100644 index a70eccf3a5633e..00000000000000 --- a/torch/onnx/_internal/exporter/errors.py +++ /dev/null @@ -1,30 +0,0 @@ -class ExporterError(RuntimeError): - """Error during export.""" - - -class TorchExportError(ExporterError): - """Error during torch.export.export.""" - - -class OnnxConversionError(ExporterError): - """Error during ONNX conversion.""" - - -class DispatchError(OnnxConversionError): - """Error during ONNX Funtion dispatching.""" - - -class GraphConstructionError(OnnxConversionError): - """Error during graph construction.""" - - -class OnnxCheckerError(ExporterError): - """Error during ONNX model checking.""" - - -class OnnxRuntimeError(ExporterError): - """Error during ONNX Runtime execution.""" - - -class OnnxValidationError(ExporterError): - """Output value mismatch.""" diff --git a/torch/onnx/errors.py b/torch/onnx/errors.py index 2f5271596b5c28..f6e035a8a85f1d 100644 --- a/torch/onnx/errors.py +++ b/torch/onnx/errors.py @@ -2,41 +2,38 @@ from __future__ import annotations -import textwrap -from typing import TYPE_CHECKING - -from torch.onnx import _constants -from torch.onnx._internal import diagnostics - - -if TYPE_CHECKING: - from torch import _C __all__ = [ - "OnnxExporterError", "OnnxExporterWarning", - "CheckerError", "SymbolicValueError", "UnsupportedOperatorError", ] +import textwrap +from typing import TYPE_CHECKING -class OnnxExporterWarning(UserWarning): - """Base class for all warnings in the ONNX exporter.""" +if TYPE_CHECKING: + from torch import _C -class OnnxExporterError(RuntimeError): - """Errors raised by the ONNX exporter.""" + +class OnnxExporterWarning(UserWarning): + """Warnings in the ONNX exporter.""" -class CheckerError(OnnxExporterError): - """Raised when ONNX checker detects an invalid model.""" +class OnnxExporterError(RuntimeError): + """Errors raised by the ONNX exporter. This is the base class for all exporter errors.""" class UnsupportedOperatorError(OnnxExporterError): """Raised when an operator is unsupported by the exporter.""" + # NOTE: This is legacy and is only used by the torchscript exporter + # Clean up when the torchscript exporter is removed def __init__(self, name: str, version: int, supported_version: int | None): + from torch.onnx import _constants + from torch.onnx._internal import diagnostics + if supported_version is not None: diagnostic_rule: diagnostics.infra.Rule = ( diagnostics.rules.operator_supported_in_newer_opset_version @@ -60,6 +57,8 @@ def __init__(self, name: str, version: int, supported_version: int | None): class SymbolicValueError(OnnxExporterError): """Errors around TorchScript values and nodes.""" + # NOTE: This is legacy and is only used by the torchscript exporter + # Clean up when the torchscript exporter is removed def __init__(self, msg: str, value: _C.Value): message = ( f"{msg} [Caused by the value '{value}' (type '{value.type()}') in the " diff --git a/torch/onnx/utils.py b/torch/onnx/utils.py index 9398b92cd9a43b..2b7848e5805845 100644 --- a/torch/onnx/utils.py +++ b/torch/onnx/utils.py @@ -1644,18 +1644,6 @@ def _export( if verbose: torch.onnx.log("Exported graph: ", graph) onnx_proto_utils._export_file(proto, f, export_type, export_map) - # The ONNX checker only works for ONNX graph. So if the operator_export_type is not ONNX, - # we can skip this check. - # If large model format export is enabled, proto will only contain data location instead of - # raw data and _check_onnx_proto() will fail because it can only handle the raw ONNX proto - # string in memory. - if (operator_export_type is _C_onnx.OperatorExportTypes.ONNX) and ( - not val_use_external_data_format - ): - try: - _C._check_onnx_proto(proto) - except RuntimeError as e: - raise errors.CheckerError(e) from e finally: assert GLOBALS.in_onnx_export GLOBALS.in_onnx_export = False From 8fcfdead5711276431a87658144d18d64a01b85b Mon Sep 17 00:00:00 2001 From: blaine-rister <145300525+blaine-rister@users.noreply.github.com> Date: Sat, 7 Sep 2024 02:19:14 +0000 Subject: [PATCH 0589/1018] [Inductor] Optionally allow padding on non-GPU devices (#135280) This is the OSS component of a larger MTIA diff. Currently, Inductor disables padding for non-GPU devices. We need to change this behavior to enable padding on MTIA. This PR adds a config option to enable padding on the CPU, or any other non-GPU device. In the future, we might want to enable padding on all devices by default. However, that might require supporting device-dependent padding defaults, since CPUs will likely use different settings than H100 GPUs. Differential Revision: D61038114 Pull Request resolved: https://github.com/pytorch/pytorch/pull/135280 Approved by: https://github.com/jfix71, https://github.com/shunting314 --- torch/_inductor/compile_fx.py | 2 +- torch/_inductor/config.py | 3 +++ 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/torch/_inductor/compile_fx.py b/torch/_inductor/compile_fx.py index fd8d6626deeb95..db4c1e0eddfece 100644 --- a/torch/_inductor/compile_fx.py +++ b/torch/_inductor/compile_fx.py @@ -382,7 +382,7 @@ def maybe_disable_comprehensive_padding(example_inputs: List[torch.Tensor]): is_gpu(t.device.type) for t in example_inputs if isinstance(t, torch.Tensor) ) - if config.comprehensive_padding and not has_gpu: + if config.disable_padding_cpu and config.comprehensive_padding and not has_gpu: perf_hint_log.info("Skip comprehensive padding on CPU") return config.patch(comprehensive_padding=False) else: diff --git a/torch/_inductor/config.py b/torch/_inductor/config.py index db778544d88f0e..80b65fbc31af91 100644 --- a/torch/_inductor/config.py +++ b/torch/_inductor/config.py @@ -608,6 +608,9 @@ def decide_compile_threads() -> int: ) pad_channels_last = False +# Disable comprehensive padding on the CPU +disable_padding_cpu = True + # The width of comprehensive padding, in bytes. # CUDA max memory transaction size is 128 bytes for a warp. padding_alignment_bytes = 128 From 530f7eadaebd3eec18634cc8e752917f26d59d23 Mon Sep 17 00:00:00 2001 From: Yiming Zhou Date: Sat, 7 Sep 2024 02:31:43 +0000 Subject: [PATCH 0590/1018] [quant][pt2e] fix placeholder typo and related quantization tests (#135379) A previous typo on "placeholder" and related tests in quantization are fixed. Pull Request resolved: https://github.com/pytorch/pytorch/pull/135379 Approved by: https://github.com/jerryzh168 --- test/quantization/pt2e/test_numeric_debugger.py | 6 +++--- torch/ao/quantization/pt2e/_numeric_debugger.py | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/test/quantization/pt2e/test_numeric_debugger.py b/test/quantization/pt2e/test_numeric_debugger.py index 094193f6cf00b4..e6cc4dc3252f44 100644 --- a/test/quantization/pt2e/test_numeric_debugger.py +++ b/test/quantization/pt2e/test_numeric_debugger.py @@ -75,7 +75,7 @@ def test_quantize_pt2e_preserve_handle(self): m = prepare_pt2e(m, quantizer) debug_handle_map = _extract_debug_handles(m) res_counter = Counter(debug_handle_map.values()) - repeated_debug_handle_ids = [3, 4, 7] + repeated_debug_handle_ids = [2, 3, 6] # 3 ids were repeated because we copy over the id from node to its output observer # torch.ops.aten.conv2d.default, torch.ops.aten.squeeze.dim and torch.ops.aten.conv1d.default for dh_id in repeated_debug_handle_ids: @@ -87,7 +87,7 @@ def test_quantize_pt2e_preserve_handle(self): res_counter = Counter(debug_handle_map.values()) # same set of ids where repeated, because we copy over the id from observer/fake_quant to # dequantize node - repeated_debug_handle_ids = [3, 4, 7] + repeated_debug_handle_ids = [2, 3, 6] for dh_id in repeated_debug_handle_ids: self.assertEqual(res_counter[dh_id], 2) @@ -161,7 +161,7 @@ def test_prepare_for_propagation_comparison(self): from torch.ao.quantization.pt2e._numeric_debugger import OutputLogger loggers = [m for m in m_logger.modules() if isinstance(m, OutputLogger)] - self.assertEqual(len(loggers), 8) + self.assertEqual(len(loggers), 7) self.assertTrue("conv2d" in [logger.node_name for logger in loggers]) self.assertEqual(res, ref) diff --git a/torch/ao/quantization/pt2e/_numeric_debugger.py b/torch/ao/quantization/pt2e/_numeric_debugger.py index cff3b6d2aa3216..fedcf470a18a1d 100644 --- a/torch/ao/quantization/pt2e/_numeric_debugger.py +++ b/torch/ao/quantization/pt2e/_numeric_debugger.py @@ -21,7 +21,7 @@ def generate_numeric_debug_handle(graph_module: GraphModule) -> None: """ unique_id = 0 for node in graph_module.graph.nodes: - if node.op in ["output", "placehodler"]: + if node.op in ["output", "placeholder"]: continue if CUSTOM_KEY not in node.meta: From bd4d37a1a14072b069468407969f6582ad6258ac Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Sat, 7 Sep 2024 03:07:39 +0000 Subject: [PATCH 0591/1018] [ONNX] Handle mixed sequence inputs properly (#135378) Previously, when an input contains a mixture of `Value` and python constants like `[SymbolicTensor('sym_size_int_3', type=Tensor(INT64), shape=[], producer=node_Shape_0, index=0), 512]`, we get errors like ```pytb Traceback (most recent call last): File "/Users/justinc/Documents/GitHub/torch-onnx/src/torch_onnx/_building.py", line 367, in _call_op converted_named_inputs = _process_python_constants_and_sequences( ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/Users/justinc/Documents/GitHub/torch-onnx/src/torch_onnx/_building.py", line 275, in _process_python_constants_and_sequences raise TypeError( TypeError: Constant input '[SymbolicTensor('sym_size_int_3', type=Tensor(INT64), shape=[], producer=node_Shape_0, index=0), 512]' of type '' is not supported ``` This PR updates Sequence handling to support this case, as well as variadic inputs and ONNX Sequence inputs. Synced from https://github.com/justinchuby/torch-onnx/pull/187 Pull Request resolved: https://github.com/pytorch/pytorch/pull/135378 Approved by: https://github.com/titaiwangms --- torch/onnx/_internal/exporter/_building.py | 277 ++++++++++++++++----- 1 file changed, 212 insertions(+), 65 deletions(-) diff --git a/torch/onnx/_internal/exporter/_building.py b/torch/onnx/_internal/exporter/_building.py index 48d5728c1cfb66..3e5b48394c6579 100644 --- a/torch/onnx/_internal/exporter/_building.py +++ b/torch/onnx/_internal/exporter/_building.py @@ -13,7 +13,7 @@ import copy import inspect import logging -from typing import Any, Mapping, Sequence, TYPE_CHECKING, Union +from typing import Any, Iterable, Mapping, Sequence, TYPE_CHECKING, Union import onnxscript from onnxscript import evaluator, ir @@ -29,12 +29,13 @@ logger = logging.getLogger(__name__) -# TODO(justinchuby): Update ValidAttributeType to ir_convenience.SupportedAttrTypes ValidAttributeType = Union[ ir.TensorProtocol, int, float, bool, str, Sequence[int], Sequence[float], None ] -AllowedArgType = Union[ir.Value, Sequence[ir.Value], ValidAttributeType] +AllowedArgType = Union[ + ir.Value, Sequence[Union[ir.Value, ValidAttributeType]], ValidAttributeType +] # Logic for adapting inputs from general Python or PyTorch inputs to ONNX ir.Value @@ -181,13 +182,103 @@ def _resolve_parameter_dtypes( return type_binding -def _process_python_constants_and_sequences( +def _determine_input_dtype( + param: _schemas.Parameter, + arg: AllowedArgType, + type_binding: Mapping[_schemas.TypeConstraintParam, ir.TypeProtocol], +) -> ir.DataType: + """Determine the dtype of the input that is a mix of Python constants and ir.Value.""" + if param.type_constraint in type_binding: + # A known dtype is available because it was resolved + return type_binding[param.type_constraint].dtype + if len(param.type_constraint.allowed_types) == 1: + # Only one type is allowed by the type constraint + return next(iter(param.type_constraint.allowed_types)).dtype + + # No dtype information available. Infer from the Python constant or (in the Sequence case) + # from a mix of Python constants and ir.Value + if isinstance(arg, bool): + return ir.DataType.BOOL + if isinstance(arg, float): + return ir.DataType.FLOAT + if isinstance(arg, int): + return ir.DataType.INT64 + if isinstance(arg, str): + return ir.DataType.STRING + if isinstance(arg, (ir.Tensor, ir.TensorProtocol)): + return arg.dtype + if arg is None: + return ir.DataType.UNDEFINED + + # Handle sequences + if isinstance(arg, (tuple, list)): + if len(arg) == 0: + # Special case: Treat empty sequence as INT64 as they are typically used for shape + return ir.DataType.INT64 + + # Try to obtain the dtype from one of the values + for val in arg: + if isinstance(val, ir.Value) and val.dtype is not None: + return val.dtype + + if any(isinstance(val, float) for val in arg): + # If any float is present, the dtype is float + return ir.DataType.FLOAT + elif any(isinstance(val, int) for val in arg): + # Otherwise if any int is present, the dtype is int + return ir.DataType.INT64 + + raise ValueError( + f"Could not determine the dtype for the input '{param.name}'. " + f"param={param}, arg={arg}, param_type_constraint={param.type_constraint}, " + f"type_binding={type_binding}" + ) + + +def _allowed_types_are_sequence_types(allowed_types: Iterable[ir.TypeProtocol]) -> bool: + """Check if all allowed types are Sequence types.""" + return all(isinstance(t, ir.SequenceType) for t in allowed_types) + + +def _get_or_create_constant( + constant_farm: dict[ + tuple[ + bool | int | float | str | tuple[int] | tuple[float], + ir.DataType, + ], + ir.Value, + ], + arg: bool + | int + | float + | str + | tuple[int] + | tuple[float] + | tuple[bool] + | list[int] + | list[float] + | list[bool], + dtype: ir.DataType, + opset: onnxscript.values.Opset, +) -> ir.Value: + if isinstance(arg, list): + # Make the arg hashable + arg = tuple(arg) # type: ignore[assignment] + constant_value = constant_farm.get((arg, dtype)) # type: ignore[arg-type] + if constant_value is None: + constant_tensor = ir.tensor(value=arg, dtype=dtype) # type: ignore[arg-type] + constant_value = opset.Constant(value=constant_tensor) + constant_farm[(arg, dtype)] = constant_value # type: ignore[arg-type,index] + return constant_value + + +def _process_python_constants( signature: _schemas.OpSignature, named_inputs: dict[str, AllowedArgType], type_binding: Mapping[_schemas.TypeConstraintParam, ir.TypeProtocol], constant_farm: dict[ tuple[ - bool | int | float | str | ir.TensorProtocol | tuple[int] | tuple[float], + bool | int | float | str | tuple[int] | tuple[float], ir.DataType, ], ir.Value, @@ -206,7 +297,7 @@ def _process_python_constants_and_sequences( opset: The Opset to use for creating Constant nodes. Returns: - None + A mapping of parameter names to Python constants converted to constant Nodes. """ # 3. Convert Python constants to Constant nodes based on the dtype information; # construct sequences @@ -225,80 +316,128 @@ def _process_python_constants_and_sequences( if isinstance(arg, ir.Value): # TODO(justinchuby): Cast the ir.Value here if needed continue + if ( isinstance(arg, Sequence) and len(arg) > 0 - and all(isinstance(val, ir.Value) for val in arg) + and any(isinstance(val, ir.Value) for val in arg) ): # Skip the sequence of ir.Value. This is a variadic input or a Sequence input - # NOTE: Variadic operators like Max can be called with mixed ir.Value and Python constants - # like `Max(0, ir.Value())` - # We need to convert the Python constants to Constant nodes - # NOTE: Important to check that arg is not empty because we need to treat it as list[int] or list[float] + # It will be handled by _process_python_sequences + continue + if param.variadic: + # Handled by _process_python_sequences + continue + if _allowed_types_are_sequence_types(param.type_constraint.allowed_types): + # Handled by _process_python_sequences continue - # if param.variadic: - # # FXIME: Handle variadic inputs and sequence inputs differently - # raise NotImplementedError - # TODO: Find a way to recursively build constants. Maybe extract the logic out. - # FIXME: I am here - - assert isinstance( - param, _schemas.Parameter - ), f"Expected Parameter, got {type(param)}" - if param.type_constraint in type_binding: - # A known dtype is available - dtype = type_binding[param.type_constraint].dtype - elif len(param.type_constraint.allowed_types) == 1: - # Only one type is allowed - dtype = next(iter(param.type_constraint.allowed_types)).dtype - else: - # No dtype information available. Infer from the Python constant - if isinstance(arg, bool): - dtype = ir.DataType.BOOL - elif isinstance(arg, float): - dtype = ir.DataType.FLOAT - elif isinstance(arg, int): - dtype = ir.DataType.INT64 - elif isinstance(arg, str): - dtype = ir.DataType.STRING - elif isinstance(arg, (tuple, list)) and all( - isinstance(val, int) for val in arg - ): - dtype = ir.DataType.INT64 - elif isinstance(arg, (tuple, list)) and any( - isinstance(val, float) for val in arg - ): - # NOTE: if any float is present, the dtype is float - dtype = ir.DataType.FLOAT - elif isinstance(arg, (ir.Tensor, ir.TensorProtocol)): - dtype = arg.dtype - elif arg is None: - dtype = ir.DataType.UNDEFINED - else: - raise TypeError( - f"Constant input '{arg}' of type '{type(arg)}' is not supported" - ) + dtype = _determine_input_dtype(param, arg, type_binding) if arg is None: constant_value = None - elif not isinstance(arg, (ir.Tensor, ir.TensorProtocol)): - # Deduplicate the constants - if isinstance(arg, (tuple, list)): - # Make the arg hashable - arg = tuple(arg) # noqa: PLW2901 - constant_value = constant_farm.get((arg, dtype)) # type: ignore[arg-type] - if constant_value is None: - constant_tensor = ir.tensor(value=arg, dtype=dtype) # type: ignore[arg-type] - constant_value = opset.Constant(value=constant_tensor) - constant_farm[(arg, dtype)] = constant_value # type: ignore[arg-type,index] - else: + elif isinstance(arg, (ir.Tensor, ir.TensorProtocol)): constant_value = opset.Constant(value=arg) + else: + # Deduplicate the constants + constant_value = _get_or_create_constant(constant_farm, arg, dtype, opset) # type: ignore[arg-type] named_inputs[param.name] = constant_value return named_inputs # type: ignore[return-value] +def _process_python_sequences( + signature: _schemas.OpSignature, + named_inputs: dict[str, AllowedArgType], + type_binding: Mapping[_schemas.TypeConstraintParam, ir.TypeProtocol], + constant_farm: dict[ + tuple[ + bool | int | float | str | ir.TensorProtocol | tuple[int] | tuple[float], + ir.DataType, + ], + ir.Value, + ], + opset: onnxscript.values.Opset, +): + """Handle three types of sequences. + + 1. Variadic inputs + 2. Sequence input of ir.Value, + 3. Sequence of Python constants that contains ir.Value + """ + for name, arg in named_inputs.items(): + param = signature.params_map[name] + assert isinstance( + param, _schemas.Parameter + ), f"Expected Parameter, got {type(param)}" + + if not isinstance(arg, (tuple, list)): + continue + + if len(arg) == 0: + # Skip empty sequences + continue + + # 1. Sequence input of ir.Value + if _allowed_types_are_sequence_types(param.type_constraint.allowed_types): + # Turn the list into a Sequence node + # Constant op creation will be handled by the variadic case below when calling + # the SequenceConstruct op. + named_inputs[name] = opset.SequenceConstruct(*arg) + continue + + # 2. Variadic inputs + # NOTE: Variadic operators like Max can be called with mixed ir.Value and Python constants + # like `Max(0, ir.Value())` + # We need to convert the Python constants to Constant nodes + if param.variadic: + if all(isinstance(val, ir.Value) for val in arg): + # Skip the variadic input if all values are ir.Value + continue + + dtype = _determine_input_dtype(param, arg, type_binding) + new_args = [] + for val in arg: + if isinstance(val, ir.Value): + new_args.append(val) + else: + constant_tensor = ir.tensor(value=val, dtype=dtype) # type: ignore[arg-type] + constant_value = opset.Constant(value=constant_tensor) + new_args.append(constant_value) + named_inputs[name] = new_args + continue + else: + # 3. Concat the list as a single input + # E.g. [Value, 42] should be converted to op.Concat(Value, Constant(42)) + # when the expected input type is INT64 + # We assume this only happens for 1D cases + if all(isinstance(val, ir.Value) for val in arg): + named_inputs[name] = opset.Concat(*arg) + continue + + dtype = _determine_input_dtype(param, arg, type_binding) + new_args = [] + for val in arg: + if isinstance(val, ir.Value): + new_args.append(val) + elif val is None: + # Skip None values + continue + elif isinstance(arg, (ir.Tensor, ir.TensorProtocol)): + new_args.append(opset.Constant(value=val)) + else: + # Turn the Python constant into 1D tensor for the constant + assert isinstance( + val, (bool, int, float) + ), f"Expected int or float, got {type(val)}" + new_args.append( + _get_or_create_constant(constant_farm, [arg], dtype, opset) # type: ignore[arg-type] + ) + named_inputs[name] = opset.Concat(*new_args) + continue + return named_inputs + + def _construct_node( signature: _schemas.OpSignature, named_inputs: Mapping[str, ir.Value | None], @@ -368,9 +507,17 @@ def _call_op( """ type_binding = _resolve_parameter_dtypes(op_signature, named_inputs) try: - converted_named_inputs = _process_python_constants_and_sequences( + converted_named_inputs = _process_python_constants( op_signature, named_inputs, type_binding, self.constant_farm, self.opset ) + converted_named_inputs = _process_python_sequences( + op_signature, + converted_named_inputs, # type: ignore[arg-type] + type_binding, + self.constant_farm, + self.opset, + ) + except Exception as e: raise _errors.GraphConstructionError( f"Error processing Python constants for operator '{op_signature.domain}::{op_signature.name}'. " From 61a570b55e6e0c7608e48d533f5ae466fbf5ccfd Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Fri, 6 Sep 2024 17:44:20 -0700 Subject: [PATCH 0592/1018] [FlexAttention] Skip very small block size unit tests on H100 due to Triton bug (#135393) Pull Request resolved: https://github.com/pytorch/pytorch/pull/135393 Approved by: https://github.com/BoyuanFeng --- test/inductor/test_flex_attention.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/test/inductor/test_flex_attention.py b/test/inductor/test_flex_attention.py index 0ba2d1ca08152f..9b52048f1dbf37 100644 --- a/test/inductor/test_flex_attention.py +++ b/test/inductor/test_flex_attention.py @@ -41,6 +41,15 @@ "Requires CUDA and Triton", ) +# Use this decorator only when hitting Triton bugs on H100 +running_on_a100_only = skipUnless( + torch.cuda.is_available() + and torch.version.hip is None + and has_triton() + and torch.cuda.get_device_capability() == (8, 0), + "Requires A100 and Triton", +) + Tolerances = namedtuple("Tolerances", ["atol", "rtol"]) torch.set_float32_matmul_precision("high") @@ -526,7 +535,7 @@ def run_automatic_dynamic_test( def test_builtin_score_mods(self, dtype: torch.dtype, score_mod: Callable): self.run_test(score_mod, dtype) - @supported_platform + @running_on_a100_only @common_utils.parametrize("dtype", test_dtypes_fast) @common_utils.parametrize("score_mod", test_score_mods) def test_builtin_score_mods_seqlen_lt_default_sparse_block_size( @@ -540,7 +549,7 @@ def test_builtin_score_mods_seqlen_lt_default_sparse_block_size( ) self.run_test_with_call(attention, dtype, B, H, 64, D, B, H, 64, D) - @supported_platform + @running_on_a100_only @common_utils.parametrize("dtype", test_dtypes_fast) @common_utils.parametrize("score_mod", test_score_mods) def test_builtin_score_mods_seqlen_lt_custom_sparse_block_size( From 909fb16f8b4f9fbaa0eb999d9aeb2a8f188cc2e1 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Sat, 7 Sep 2024 04:48:18 +0000 Subject: [PATCH 0593/1018] [ONNX] Support FakeTensor in ONNXProgram (#135399) Sync with https://github.com/justinchuby/torch-onnx/compare/v0.1.20...v0.1.21 to support FakeTensors in ONNXProgram. Specifically, this PR implements the `apply_weights` method to allow users to supply a dictionary of concrete tensors to replace FakeTensors in the exported model weights. An error is raised when users try to serialize a FakeTensor to avoid segfaults. Also fixed a bug in `.save()` when `keep_initializers_as_inputs` is True and `include_initializers` is False. Pull Request resolved: https://github.com/pytorch/pytorch/pull/135399 Approved by: https://github.com/titaiwangms --- test/onnx/pytorch_test_common.py | 22 ---------- test/onnx/test_fx_to_onnx.py | 2 +- torch/onnx/_internal/_exporter_legacy.py | 13 +++++- torch/onnx/_internal/exporter/_core.py | 10 ++++- .../onnx/_internal/exporter/_onnx_program.py | 41 ++++++++++++++++++- 5 files changed, 61 insertions(+), 27 deletions(-) diff --git a/test/onnx/pytorch_test_common.py b/test/onnx/pytorch_test_common.py index 6c90290a3e96d2..408168a9c711eb 100644 --- a/test/onnx/pytorch_test_common.py +++ b/test/onnx/pytorch_test_common.py @@ -346,28 +346,6 @@ def wrapper(self, *args, **kwargs): return wrapper -def skip_if_fake_model_and_inititalizer(reason: Optional[str] = None): - """skip test with models using ExportedProgram as input. - - Args: - reason: The reason for skip the ONNX export test. - - Returns: - A decorator for skip tests. - """ - - def skip_dec(func): - @functools.wraps(func) - def wrapper(self, *args, **kwargs): - if kwargs["use_fake_mode"] and kwargs["include_initializer"]: - return unittest.SkipTest(reason) - return func(self, *args, **kwargs) - - return wrapper - - return skip_dec - - def xfail_if_model_type_is_exportedprogram( error_message: str, reason: Optional[str] = None ): diff --git a/test/onnx/test_fx_to_onnx.py b/test/onnx/test_fx_to_onnx.py index 1946ba586ac0d2..47778a689d8330 100644 --- a/test/onnx/test_fx_to_onnx.py +++ b/test/onnx/test_fx_to_onnx.py @@ -650,7 +650,6 @@ def test_checkpoint_cast(self): onnx_program.save(tmp_onnx_file.name) onnx.checker.check_model(tmp_onnx_file.name, full_check=True) - @pytorch_test_common.skip_if_fake_model_and_inititalizer("segfault") @common_utils.parametrize( "include_initializer", [ @@ -732,6 +731,7 @@ def forward(self, tensor_x: torch.Tensor): onnx_program = torch.onnx.dynamo_export( model, tensor_x, export_options=export_options ) + onnx_program.apply_weights(state_dict) with tempfile.NamedTemporaryFile(suffix=".onnx") as tmp_onnx_file: onnx_program.save( tmp_onnx_file.name, diff --git a/torch/onnx/_internal/_exporter_legacy.py b/torch/onnx/_internal/_exporter_legacy.py index 3c05b350e8c64f..71b828b9bceb7a 100644 --- a/torch/onnx/_internal/_exporter_legacy.py +++ b/torch/onnx/_internal/_exporter_legacy.py @@ -522,6 +522,7 @@ def __init__( self._diagnostic_context = diagnostic_context self._fake_context = fake_context self._export_exception = export_exception + self._state_dict: dict[str, torch.Tensor] = {} def __call__( self, @@ -562,7 +563,7 @@ def __call__( if isinstance(model_with_state_dict, torch.nn.Module): model_state = model_with_state_dict.state_dict() else: - model_state = None + model_state = self._state_dict self.save( onnx_model, model_state=model_state, @@ -736,6 +737,13 @@ def adapt_torch_outputs_to_onnx( ), "model_with_state_dict must be specified." return self._output_adapter.apply(model_outputs, model=model_with_state_dict) # type: ignore[return-value] + def apply_weights(self, state_dict: dict[str, torch.Tensor]) -> None: + """Apply the weights from the specified state dict to the ONNX model. + Args: + state_dict: The state dict containing the weights to apply to the ONNX model. + """ + self._state_dict = state_dict + def save( self, destination: str | io.BufferedIOBase, @@ -764,6 +772,9 @@ def save( include_initializers is True or model_state is None ), "Cannot specify both `include_initializers=False` and `model_state`." + if self._state_dict and model_state is None: + model_state = self._state_dict + # Add initializers when symbolic tracing is enabled _model_state_files: list[str | io.BytesIO | dict[str, Any]] = [] if include_initializers: diff --git a/torch/onnx/_internal/exporter/_core.py b/torch/onnx/_internal/exporter/_core.py index 2daa37eb8cd500..ae216c61f13c93 100644 --- a/torch/onnx/_internal/exporter/_core.py +++ b/torch/onnx/_internal/exporter/_core.py @@ -16,8 +16,6 @@ import onnxscript import onnxscript.evaluator -import onnxscript.function_libs -import onnxscript.function_libs.torch_lib from onnxscript import ir from onnxscript.ir import convenience as ir_convenience @@ -117,6 +115,14 @@ def tobytes(self) -> bytes: # Implement tobytes to support native PyTorch types so we can use types like bloat16 # Reading from memory directly is also more efficient because # it avoids copying to a NumPy array + import torch._subclasses.fake_tensor + + if isinstance(self.raw, torch._subclasses.fake_tensor.FakeTensor): + raise TypeError( + f"Cannot take content out from the FakeTensor ('{self.name}'). Please replace the tensor " + "with a tensor backed by real data using ONNXProgram.apply_weights() " + "or save the model without initializers by setting include_initializers=False." + ) tensor = self.raw.detach().cpu().contiguous() return bytes( (ctypes.c_ubyte * tensor.element_size() * tensor.numel()).from_address( diff --git a/torch/onnx/_internal/exporter/_onnx_program.py b/torch/onnx/_internal/exporter/_onnx_program.py index 49172e772c8a1c..6fb62a5961f875 100644 --- a/torch/onnx/_internal/exporter/_onnx_program.py +++ b/torch/onnx/_internal/exporter/_onnx_program.py @@ -11,6 +11,7 @@ import os import tempfile import textwrap +import warnings from typing import Callable, Sequence, TYPE_CHECKING import torch @@ -18,6 +19,9 @@ from torch.utils import _pytree +# NOTE: DO NOT import module from torch.onnx._internal to this module in the global scope +# because ONNXProgram is exposed to the public API + if TYPE_CHECKING: import onnxruntime as ort @@ -124,6 +128,22 @@ def save( When `external_data` is `True` or the model is larger than 2GB, the weights are saved as external data in a separate file. + Initializer (model weights) serialization behaviors: + - include_initializers=True, keep_initializers_as_inputs=False (default): + The initializers are included in the saved model. + - include_initializers=True, keep_initializers_as_inputs=True: + The initializers are included in the saved model and kept as model inputs. + Choose this option if you want the ability to override the model weights + during inference. + - include_initializers=False, keep_initializers_as_inputs=False: + The initializers are not included in the saved model and are not listed + as model inputs. Choose this option if you want to attach the initializers + to the ONNX model in a separate, post-processing, step. + - include_initializers=False, keep_initializers_as_inputs=True: + The initializers are not included in the saved model but are listed as model + inputs. Choose this option if you want to supply the initializers during + inference and want to minimize the size of the saved model. + Args: destination: The path to save the ONNX model to. include_initializers: Whether to include the initializers in the saved model. @@ -142,7 +162,7 @@ def save( if not include_initializers: self.model.graph.initializers.clear() if keep_initializers_as_inputs: - self.model.graph.inputs.extend(self.model.graph.initializers.values()) # type: ignore[arg-type] + self.model.graph.inputs.extend(original_initializers.values()) # type: ignore[arg-type] # Save the model to disk if ( @@ -160,6 +180,25 @@ def save( self.model.graph.inputs.clear() self.model.graph.inputs.extend(original_inputs) + def apply_weights(self, state_dict: dict[str, torch.Tensor]) -> None: + """Apply the weights from the specified state dict to the ONNX model. + Args: + state_dict: The state dict containing the weights to apply to the ONNX model. + """ + from torch.onnx._internal.exporter import _core + + for name, tensor in state_dict.items(): + if name in self.model.graph.initializers: + self.model.graph.initializers[name].const_value = _core.TorchTensor( + tensor, name + ) + else: + warnings.warn( + f"Weight '{name}' not found in the model. Skipped applying.", + category=torch.onnx.errors.OnnxExporterWarning, + stacklevel=1, + ) + def initialize_inference_session( self, initializer: Callable[ From 1647ed12173b2e09c228445a66ce503c6fb157b2 Mon Sep 17 00:00:00 2001 From: Jiong Gong Date: Sat, 7 Sep 2024 00:29:29 -0700 Subject: [PATCH 0594/1018] [inductor][cpp][gemm] reduce memory alloc overhead by allocating local acc once per thread (#135277) Pull Request resolved: https://github.com/pytorch/pytorch/pull/135277 Approved by: https://github.com/leslie-fang-intel Co-authored-by: Wu, Chunyuan --- torch/_inductor/codegen/cpp_gemm_template.py | 14 +++++++------- torch/_inductor/codegen/cpp_prefix.h | 2 +- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/torch/_inductor/codegen/cpp_gemm_template.py b/torch/_inductor/codegen/cpp_gemm_template.py index 7bed1503900787..8e102677d56173 100644 --- a/torch/_inductor/codegen/cpp_gemm_template.py +++ b/torch/_inductor/codegen/cpp_gemm_template.py @@ -137,14 +137,14 @@ const int64_t k_block_end = Kr_blocks; {%- endif %} {{ micro_gemm.codegen_init(kernel) }} +{%- if use_local_acc %} + {%- set acc_buf_name = "local_acc_buf" %} + {{ kernel.define_buffer(acc_buf_name, ["Mc_blocks*Mr", "Nc_blocks*Nr"], acc_buf_dtype) }} +{%- endif %} for (int64_t mc = m_block_start; mc < m_block_end; mc += Mc_blocks) { const int64_t m_start = mc * Mr; const int64_t m_end = std::min(std::min(mc + Mc_blocks, m_block_end) * Mr, M); const int64_t m_size = m_end - m_start; -{%- if use_local_acc %} - {%- set acc_buf_name = "local_acc_buf" %} - {{ kernel.define_buffer(acc_buf_name, ["m_end - m_start", "Nc_blocks*Nr"], acc_buf_dtype) }} -{%- endif %} for (int64_t nc = n_block_start; nc < n_block_end; nc += Nc_blocks) { const int64_t n_start = nc * Nr; const int64_t n_end = std::min(std::min(nc + Nc_blocks, n_block_end) * Nr, N); @@ -162,7 +162,7 @@ int64_t k_end = std::min(std::min(kc + Kc_blocks, k_block_end) * Kr, K); {%- set tile_X = kernel.slice_nd(X, [("m_start", "m_end"), ("k_start", "k_end")]) %} for (int64_t nci = nc; nci < nc_block_end; nci++) { -{%- set acc_slice = kernel.slice_nd(acc, [(), ("(nci - nc)*Nr", "(nci - nc + 1)*Nr")]) %} +{%- set acc_slice = kernel.slice_nd(acc, [("0", "m_end - m_start"), ("(nci - nc)*Nr", "(nci - nc + 1)*Nr")]) %} {%- set tile_W_3d = kernel.slice_nd(W, [("nci", "nci + 1"), ("k_start", "k_end"), ()]) %} {%- set tile_W = kernel.view(tile_W_3d, ["k_end - k_start", micro_gemm.register_blocking.block_n]) %} if (kc == k_block_start) { @@ -180,7 +180,7 @@ {%- endif %} { {%- set tile_Y = kernel.slice_nd(Y_2d, [("m_start", "m_end"), ("n_start", "n_end")]) %} -{%- set tile_acc = kernel.slice_nd(acc, [(), ("0", "n_end - n_start")]) %} +{%- set tile_acc = kernel.slice_nd(acc, [("0", "m_end - m_start"), ("0", "n_end - n_start")]) %} {{ kernel.store_output( tile_Y, tile_acc, GemmOut, epilogue_nodes, offsets=("m_start", "n_start"), reindexers=reindexers )|indent(20, false) @@ -394,7 +394,7 @@ def get_cache_blocking(register_blocking, thread_blocking): # The ratios below are empirically determined to decide # the effective sizes of L1 and L2. # TODO: tune the factor here - L1_limit_factor = 1 + L1_limit_factor = 0.8 L2_limit_factor = 0.5 L1_cache_size = ( diff --git a/torch/_inductor/codegen/cpp_prefix.h b/torch/_inductor/codegen/cpp_prefix.h index ace9a182472a0a..6054473fdd5bc9 100644 --- a/torch/_inductor/codegen/cpp_prefix.h +++ b/torch/_inductor/codegen/cpp_prefix.h @@ -754,7 +754,7 @@ void mm_get_cache_blocking( // See NOTE [CPP GEMM Cache Blocking Algorithm] for the cache blocking algorithm. // TODO(jgong5): cache cache blocking results // TODO: tune the factor here - float L1_limit_factor = 1.0; + float L1_limit_factor = 0.8; float L2_limit_factor = 0.5; auto L1 = L1_cache_size * L1_limit_factor; From 310df1100eb9710385684e6204f65aea0239f2ee Mon Sep 17 00:00:00 2001 From: Jiong Gong Date: Sat, 7 Sep 2024 00:29:30 -0700 Subject: [PATCH 0595/1018] [inductor][cpp][gemm] enable dynamic M for k-slicing (#133447) Pull Request resolved: https://github.com/pytorch/pytorch/pull/133447 Approved by: https://github.com/leslie-fang-intel ghstack dependencies: #135277 Co-authored-by: Wu, Chunyuan --- torch/_inductor/codegen/cpp_gemm_template.py | 3 +- torch/_inductor/codegen/cpp_prefix.h | 54 ++++++++++++++------ 2 files changed, 41 insertions(+), 16 deletions(-) diff --git a/torch/_inductor/codegen/cpp_gemm_template.py b/torch/_inductor/codegen/cpp_gemm_template.py index 8e102677d56173..b2a2b1c41576d0 100644 --- a/torch/_inductor/codegen/cpp_gemm_template.py +++ b/torch/_inductor/codegen/cpp_gemm_template.py @@ -58,7 +58,7 @@ const int64_t Mr_blocks = (M + Mr - 1) / Mr; {%- if num_threads > 1 %} int64_t Mt_blocks, Nt_blocks, Kt_blocks; - mm_get_thread_blocking(num_threads, M, N, K, Mr, Nr, Kr, Mt_blocks, Nt_blocks, Kt_blocks); + mm_get_thread_blocking(num_threads, {{config.cpp.gemm_max_k_slices}}, M, N, K, Mr, Nr, Kr, Mt_blocks, Nt_blocks, Kt_blocks); {%- else %} const auto Mt_blocks = Mr_blocks; const auto Nt_blocks = Nr_blocks; @@ -1033,6 +1033,7 @@ def get_reindexer(epilogue_node): DTYPE_TO_CPP=DTYPE_TO_CPP, L1_cache_size=L1_cache_size, L2_cache_size=L2_cache_size, + config=config, ) with contextlib.ExitStack() as stack: for buf in fake_buffers: diff --git a/torch/_inductor/codegen/cpp_prefix.h b/torch/_inductor/codegen/cpp_prefix.h index 6054473fdd5bc9..a9db2d346ca65b 100644 --- a/torch/_inductor/codegen/cpp_prefix.h +++ b/torch/_inductor/codegen/cpp_prefix.h @@ -645,6 +645,7 @@ void atomic_add_vec(T *addr, at::vec::VectorizedN index, at::vec::V void mm_get_thread_blocking( int num_threads, + int max_k_slices, int64_t M, int64_t N, int64_t K, @@ -676,15 +677,16 @@ void mm_get_thread_blocking( return std::make_tuple(std::move(factors), count); }; - auto get_blocking = [](int64_t num_threads, - int64_t factor, + auto get_blocking = [](int64_t m_factor, + int64_t n_factor, + int64_t k_factor, int64_t m_blocks, int64_t n_blocks, int64_t k_blocks) { - int64_t thread_block_n = (n_blocks + factor - 1) / factor; - int64_t cofactor = num_threads / factor; - int64_t thread_block_m = (m_blocks + cofactor - 1) / cofactor; - return std::make_tuple(thread_block_m, thread_block_n, k_blocks); + int64_t thread_block_k = (k_blocks + k_factor - 1) / k_factor; + int64_t thread_block_n = (n_blocks + n_factor - 1) / n_factor; + int64_t thread_block_m = (m_blocks + m_factor - 1) / m_factor; + return std::make_tuple(thread_block_m, thread_block_n, thread_block_k); }; auto is_better_blocking = [=](int64_t Mt_, @@ -693,7 +695,7 @@ void mm_get_thread_blocking( int64_t Mt, int64_t Nt, int64_t Kt) { - return Mt == 0 || Mt_ * Mr + Nt_ * Nr < Mt * Mr + Nt * Nr; + return Mt == 0 || Kt_ < Kt || Mt_ * Mr + Nt_ * Nr < Mt * Mr + Nt * Nr; }; int64_t m_blocks = (M + Mr - 1) / Mr; @@ -704,11 +706,11 @@ void mm_get_thread_blocking( assert(count > 0); for (int i = 0; i < count; ++i) { - int64_t factor = factors[i]; - if (n_blocks >= factor && - m_blocks >= num_threads / factor) { + int64_t n_factor = factors[i]; + int64_t m_factor = num_threads / n_factor; + if (n_blocks >= n_factor && m_blocks >= m_factor) { auto [Mt_, Nt_, Kt_] = get_blocking( - num_threads, factor, m_blocks, n_blocks, k_blocks); + m_factor, n_factor, 1, m_blocks, n_blocks, k_blocks); if (is_better_blocking(Mt_, Nt_, Kt_, Mt, Nt, Kt)) { std::tie(Mt, Nt, Kt) = std::make_tuple(Mt_, Nt_, Kt_); } @@ -720,11 +722,33 @@ void mm_get_thread_blocking( } for (int i = 0; i < count; ++i) { - int64_t factor = factors[i]; - int64_t cofactor = num_threads / factor; - if (n_blocks >= factor || m_blocks >= cofactor) { + int64_t k_factor = factors[i]; + if (k_blocks >= k_factor && (max_k_slices == 0 || k_factor <= max_k_slices)) { + auto [mxn_factors, mxn_count] = get_factors(num_threads / k_factor); + for (int j = 0; j < mxn_count; ++j) { + int64_t n_factor = mxn_factors[j]; + int64_t m_factor = num_threads / (k_factor * n_factor); + if (n_blocks >= n_factor && m_blocks >= m_factor) { + auto [Mt_, Nt_, Kt_] = get_blocking( + m_factor, n_factor, k_factor, m_blocks, n_blocks, k_blocks); + if (is_better_blocking(Mt_, Nt_, Kt_, Mt, Nt, Kt)) { + std::tie(Mt, Nt, Kt) = std::make_tuple(Mt_, Nt_, Kt_); + } + } + } + } + } + + if (Mt != 0) { + return; + } + + for (int i = 0; i < count; ++i) { + int64_t n_factor = factors[i]; + int64_t m_factor = num_threads / n_factor; + if (n_blocks >= n_factor || m_blocks >= m_factor) { auto [Mt_, Nt_, Kt_] = get_blocking( - num_threads, factor, m_blocks, n_blocks, k_blocks); + m_factor, n_factor, 1, m_blocks, n_blocks, k_blocks); if (is_better_blocking(Mt_, Nt_, Kt_, Mt, Nt, Kt)) { std::tie(Mt, Nt, Kt) = std::make_tuple(Mt_, Nt_, Kt_); } From 308e1f746729d829b5fcdfe04dc272f31bf3a204 Mon Sep 17 00:00:00 2001 From: Jiong Gong Date: Sat, 7 Sep 2024 00:29:32 -0700 Subject: [PATCH 0596/1018] [inductor][cpp][gemm] cache blocking config for dynamic shapes (#133538) Pull Request resolved: https://github.com/pytorch/pytorch/pull/133538 Approved by: https://github.com/leslie-fang-intel ghstack dependencies: #135277, #133447 Co-authored-by: Wu, Chunyuan --- test/inductor/test_cpu_select_algorithm.py | 7 ++ torch/_inductor/codegen/cpp_prefix.h | 112 +++++++++++++++++---- 2 files changed, 98 insertions(+), 21 deletions(-) diff --git a/test/inductor/test_cpu_select_algorithm.py b/test/inductor/test_cpu_select_algorithm.py index 813be61c951954..a3ff2b3590aaae 100644 --- a/test/inductor/test_cpu_select_algorithm.py +++ b/test/inductor/test_cpu_select_algorithm.py @@ -1400,6 +1400,13 @@ class TestSelectAlgorithmDynamicShapes(_DynamicShapesTestBase): test_quantized_linear_amx_dynamic_shapes = ( TestSelectAlgorithm.test_quantized_linear_amx ) + test_linear_k_slicing_dynamic_shapes = TestSelectAlgorithm.test_linear_k_slicing + test_linear_cache_blocking_dynamic_shapes = ( + TestSelectAlgorithm.test_linear_cache_blocking + ) + test_linear_thread_factors_dynamic_shapes = ( + TestSelectAlgorithm.test_linear_thread_factors + ) instantiate_device_type_tests(TestSelectAlgorithm, globals(), only_for="cpu") diff --git a/torch/_inductor/codegen/cpp_prefix.h b/torch/_inductor/codegen/cpp_prefix.h index a9db2d346ca65b..b46e772cd6ce11 100644 --- a/torch/_inductor/codegen/cpp_prefix.h +++ b/torch/_inductor/codegen/cpp_prefix.h @@ -7,6 +7,7 @@ #include #include #include +#include #include // WARNING: be extra careful when including more ATen/c10 header files here! @@ -643,7 +644,37 @@ void atomic_add_vec(T *addr, at::vec::VectorizedN index, at::vec::V } #endif -void mm_get_thread_blocking( +std::tuple, int> _get_factors(int64_t number) { + int count = 0; + for (int64_t i = std::sqrt(number); i > 0; --i) { + if (number % i == 0) { + count += 2; + } + } + auto factors = std::shared_ptr(new int64_t[count]); + int index = 0; + for (int64_t i = std::sqrt(number); i > 0; --i) { + if (number % i == 0) { + factors[index++] = number / i; + factors[index++] = i; + } + } + return std::make_tuple(factors, count); +} + +std::tuple, int> get_factors(int64_t number) { + thread_local std::map, int>> cache; + auto it = cache.find(number); + if (it != cache.end()) { + return it->second; + } else { + auto factors = _get_factors(number); + cache[number] = factors; + return factors; + } +} + +void _mm_get_thread_blocking( int num_threads, int max_k_slices, int64_t M, @@ -656,27 +687,8 @@ void mm_get_thread_blocking( int64_t& Nt, int64_t& Kt) { // see NOTE [Thread blocking in Cpp GEMM] for heuristics - // TODO(jgong5): cache thread blocking results Mt = Nt = Kt = 0; - auto get_factors = [](int64_t number) { - int count = 0; - for (int64_t i = std::sqrt(number); i > 0; --i) { - if (number % i == 0) { - count += 2; - } - } - auto factors = std::make_unique(count); - int index = 0; - for (int64_t i = std::sqrt(number); i > 0; --i) { - if (number % i == 0) { - factors[index++] = number / i; - factors[index++] = i; - } - } - return std::make_tuple(std::move(factors), count); - }; - auto get_blocking = [](int64_t m_factor, int64_t n_factor, int64_t k_factor, @@ -758,8 +770,34 @@ void mm_get_thread_blocking( assert(Mt != 0); } +void mm_get_thread_blocking( + int num_threads, + int max_k_slices, + int64_t M, + int64_t N, + int64_t K, + int64_t Mr, + int64_t Nr, + int64_t Kr, + int64_t& Mt, + int64_t& Nt, + int64_t& Kt) { + thread_local std::map< + std::tuple, + std::tuple> cache; + auto key = std::make_tuple(num_threads, max_k_slices, M, N, K, Mr, Nr, Kr); + auto it = cache.find(key); + if (it != cache.end()) { + std::tie(Mt, Nt, Kt) = it->second; + return; + } else { + _mm_get_thread_blocking(num_threads, max_k_slices, M, N, K, Mr, Nr, Kr, Mt, Nt, Kt); + cache[key] = std::make_tuple(Mt, Nt, Kt); + } +} + template -void mm_get_cache_blocking( +void _mm_get_cache_blocking( int num_threads, int64_t M, int64_t N, @@ -814,6 +852,38 @@ void mm_get_cache_blocking( } } +template +void mm_get_cache_blocking( + int num_threads, + int64_t M, + int64_t N, + int64_t K, + int64_t Mr, + int64_t Nr, + int64_t Kr, + int64_t Mt_blocks, + int64_t Nt_blocks, + int64_t Kt_blocks, + int64_t& Mc_blocks, + int64_t& Nc_blocks, + int64_t& Kc_blocks, + uint32_t L1_cache_size, + uint32_t L2_cache_size) { + thread_local std::map< + std::tuple, + std::tuple> cache; + auto key = std::make_tuple(num_threads, M, N, K, Mr, Nr, Kr, Mt_blocks, Nt_blocks, Kt_blocks, L1_cache_size, L2_cache_size); + auto it = cache.find(key); + if (it != cache.end()) { + std::tie(Mc_blocks, Nc_blocks, Kc_blocks) = it->second; + return; + } else { + _mm_get_cache_blocking( + num_threads, M, N, K, Mr, Nr, Kr, Mt_blocks, Nt_blocks, Kt_blocks, Mc_blocks, Nc_blocks, Kc_blocks, L1_cache_size, L2_cache_size); + cache[key] = std::make_tuple(Mc_blocks, Nc_blocks, Kc_blocks); + } +} + inline void mm_get_thread_blocks( int thread_id, int64_t M_blocks, From bc3297d608a69a2bf16209ee385f6fb0c4d8f8d8 Mon Sep 17 00:00:00 2001 From: "Yu, Guangye" Date: Sat, 7 Sep 2024 10:07:00 +0000 Subject: [PATCH 0597/1018] [Reland] Refactor caching device allocator utils (#130923) # Motivation Following [[RFC] Intel GPU Runtime Upstreaming for Allocator ](https://github.com/pytorch/pytorch/issues/116322), this PR aims to refactor caching device allocator utils to improve code reuse usage. This is the first PR, we could prepare some follow-up PRs continuing to refactor the device caching allocator. Pull Request resolved: https://github.com/pytorch/pytorch/pull/130923 Approved by: https://github.com/EikanWang, https://github.com/gujinghui, https://github.com/albanD, https://github.com/eqy --- c10/core/CachingDeviceAllocator.h | 131 +++++++++++++++++ c10/cuda/CUDACachingAllocator.cpp | 163 +++++++-------------- c10/cuda/CUDACachingAllocator.h | 82 ++--------- c10/cuda/CUDAMallocAsyncAllocator.cpp | 2 + torch/csrc/cuda/CUDAPluggableAllocator.cpp | 4 +- torch/csrc/cuda/CUDAPluggableAllocator.h | 2 +- torch/csrc/cuda/Module.cpp | 8 +- 7 files changed, 202 insertions(+), 190 deletions(-) create mode 100644 c10/core/CachingDeviceAllocator.h diff --git a/c10/core/CachingDeviceAllocator.h b/c10/core/CachingDeviceAllocator.h new file mode 100644 index 00000000000000..8724ecf88ae0f9 --- /dev/null +++ b/c10/core/CachingDeviceAllocator.h @@ -0,0 +1,131 @@ +#pragma once + +#include +#include + +#include + +namespace c10::CachingDeviceAllocator { + +struct Stat { + void increase(size_t amount) { + current += static_cast(amount); + peak = std::max(current, peak); + allocated += static_cast(amount); + } + + void decrease(size_t amount) { + current -= static_cast(amount); + TORCH_INTERNAL_ASSERT_DEBUG_ONLY( + current >= 0, + "Negative tracked stat in device allocator (likely logic error)."); + freed += static_cast(amount); + } + + void reset_accumulated() { + allocated = 0; + freed = 0; + } + + void reset_peak() { + peak = current; + } + + int64_t current = 0; + int64_t peak = 0; + int64_t allocated = 0; + int64_t freed = 0; +}; + +enum struct StatType : uint64_t { + AGGREGATE = 0, + SMALL_POOL = 1, + LARGE_POOL = 2, + NUM_TYPES = 3 // remember to update this whenever a new stat type is added +}; + +using StatArray = std::array(StatType::NUM_TYPES)>; +using StatTypes = std::array(StatType::NUM_TYPES)>; + +template +void for_each_selected_stat_type(const StatTypes& stat_types, Func f) { + for (const auto stat_type : c10::irange(stat_types.size())) { + if (stat_types[stat_type]) { + f(stat_type); + } + } +} + +// Struct containing memory allocator summary statistics for a device. +struct DeviceStats { + // COUNT: allocations requested by client code + StatArray allocation; + // COUNT: number of allocated segments from device memory allocation. + StatArray segment; + // COUNT: number of active memory blocks (allocated or used by stream) + StatArray active; + // COUNT: number of inactive, split memory blocks (unallocated but can't be + // released via device memory deallocation) + StatArray inactive_split; + + // SUM: bytes allocated by this memory alocator + StatArray allocated_bytes; + // SUM: bytes reserved by this memory allocator (both free and used) + StatArray reserved_bytes; + // SUM: bytes within active memory blocks + StatArray active_bytes; + // SUM: bytes within inactive, split memory blocks + StatArray inactive_split_bytes; + // SUM: bytes requested by client code + StatArray requested_bytes; + + // COUNT: total number of failed calls to device malloc necessitating cache + // flushes. + int64_t num_alloc_retries = 0; + + // COUNT: total number of OOMs (i.e. failed calls to device memory allocation + // after cache flush) + int64_t num_ooms = 0; + + // COUNT: total number of oversize blocks allocated from pool + Stat oversize_allocations; + + // COUNT: total number of oversize blocks requiring malloc + Stat oversize_segments; + + // COUNT: total number of synchronize_and_free_events() calls + int64_t num_sync_all_streams = 0; + + // COUNT: total number of device memory allocation calls. This includes both + // mapped and malloced memory. + int64_t num_device_alloc = 0; + + // COUNT: total number of device memory deallocation calls. This includes both + // un-mapped and free memory. + int64_t num_device_free = 0; + + // SIZE: maximum block size that is allowed to be split. + int64_t max_split_size = 0; +}; + +// Size pretty-printer +inline std::string format_size(uint64_t size) { + std::ostringstream os; + os.precision(2); + os << std::fixed; + if (size <= 1024) { + os << size << " bytes"; + } else if (size <= 1048576) { + os << (static_cast(size) / 1024.0); + os << " KiB"; + } else if (size <= 1073741824ULL) { + os << static_cast(size) / 1048576.0; + os << " MiB"; + } else { + os << static_cast(size) / 1073741824.0; + os << " GiB"; + } + return os.str(); +} + +} // namespace c10::CachingDeviceAllocator diff --git a/c10/cuda/CUDACachingAllocator.cpp b/c10/cuda/CUDACachingAllocator.cpp index 63f709596d43bc..3da3c6d4f5d05a 100644 --- a/c10/cuda/CUDACachingAllocator.cpp +++ b/c10/cuda/CUDACachingAllocator.cpp @@ -11,7 +11,6 @@ #include #include #include -#include #include #include @@ -28,7 +27,6 @@ #include #include #include -#include #include #include #include @@ -45,6 +43,8 @@ C10_DEFINE_REGISTRY(FreeCudaMemoryCallbacksRegistry, FreeMemoryCallback); namespace cuda::CUDACachingAllocator { +using namespace c10::CachingDeviceAllocator; + // Included here as this is externally used in CUDAAllocatorConfig const size_t kLargeBuffer = 20971520; // "large" allocations may be packed in 20 MiB blocks @@ -135,47 +135,13 @@ namespace { using stream_set = ska::flat_hash_set; -using StatTypes = std::array(StatType::NUM_TYPES)>; - -void increase_stat(Stat& stat, size_t amount) { - stat.current += static_cast(amount); - stat.peak = std::max(stat.current, stat.peak); - stat.allocated += static_cast(amount); -} - -void decrease_stat(Stat& stat, size_t amount) { - stat.current -= static_cast(amount); - TORCH_INTERNAL_ASSERT_DEBUG_ONLY( - stat.current >= 0, - "Negative tracked stat in CUDA allocator (likely logic error)."); - stat.freed += static_cast(amount); -} - -void reset_accumulated_stat(Stat& stat) { - stat.allocated = 0; - stat.freed = 0; -} - -void reset_peak_stat(Stat& stat) { - stat.peak = stat.current; -} - -template -void for_each_selected_stat_type(const StatTypes& stat_types, Func f) { - for (const auto stat_type : c10::irange(stat_types.size())) { - if (stat_types[stat_type]) { - f(stat_type); - } - } -} - void decrease_stat_array( StatArray& stat_array, size_t amount, const StatTypes& stat_types) { for_each_selected_stat_type( stat_types, [&stat_array, amount](size_t stat_type) { - decrease_stat(stat_array[stat_type], amount); + stat_array[stat_type].decrease(amount); }); } @@ -1425,16 +1391,16 @@ class DeviceCachingAllocator { // A new split inactive block is being created from a previously unsplit // block, size remaining->size bytes. for_each_selected_stat_type(params.stat_types, [&](size_t stat_type) { - increase_stat(stats.inactive_split_bytes[stat_type], remaining->size); - increase_stat(stats.inactive_split[stat_type], 1); + stats.inactive_split_bytes[stat_type].increase(remaining->size); + stats.inactive_split[stat_type].increase(1); }); } } else if (already_split && !block->expandable_segment_) { // An already-split block is becoming active for_each_selected_stat_type(params.stat_types, [&](size_t stat_type) { - decrease_stat(stats.inactive_split_bytes[stat_type], block->size); - decrease_stat(stats.inactive_split[stat_type], 1); + stats.inactive_split_bytes[stat_type].decrease(block->size); + stats.inactive_split[stat_type].decrease(1); }); } @@ -1455,14 +1421,14 @@ class DeviceCachingAllocator { TORCH_INTERNAL_ASSERT_DEBUG_ONLY(inserted); for_each_selected_stat_type(params.stat_types, [&](size_t stat_type) { - increase_stat(stats.allocation[stat_type], 1); - increase_stat(stats.allocated_bytes[stat_type], block->size); - increase_stat(stats.active[stat_type], 1); - increase_stat(stats.active_bytes[stat_type], block->size); - increase_stat(stats.requested_bytes[stat_type], block->requested_size); + stats.allocation[stat_type].increase(1); + stats.allocated_bytes[stat_type].increase(block->size); + stats.active[stat_type].increase(1); + stats.active_bytes[stat_type].increase(block->size); + stats.requested_bytes[stat_type].increase(block->requested_size); }); if (block->size >= CUDAAllocatorConfig::max_split_size()) - increase_stat(stats.oversize_allocations, 1); + stats.oversize_allocations.increase(1); auto allocated_bytes_gauge = STATIC_GAUGE(pytorch.CUDACachingAllocator.allocated_bytes); @@ -1494,8 +1460,8 @@ class DeviceCachingAllocator { StatTypes stat_types = get_stat_types_for_pool(*block->pool); for_each_selected_stat_type(stat_types, [&](size_t stat_type) { - decrease_stat(stats.allocation[stat_type], 1); - decrease_stat(stats.allocated_bytes[stat_type], block->size); + stats.allocation[stat_type].decrease(1); + stats.allocated_bytes[stat_type].decrease(block->size); }); auto allocated_bytes_gauge = STATIC_GAUGE(pytorch.CUDACachingAllocator.allocated_bytes); @@ -1512,7 +1478,7 @@ class DeviceCachingAllocator { context ? context : block->context_when_allocated); if (block->size >= CUDAAllocatorConfig::max_split_size()) - decrease_stat(stats.oversize_allocations, 1); + stats.oversize_allocations.decrease(1); if (!block->stream_uses.empty()) { if (C10_UNLIKELY(!captures_underway.empty())) { @@ -1640,15 +1606,15 @@ class DeviceCachingAllocator { for (const auto statType : c10::irange(static_cast(StatType::NUM_TYPES))) { - reset_accumulated_stat(stats.allocation[statType]); - reset_accumulated_stat(stats.segment[statType]); - reset_accumulated_stat(stats.active[statType]); - reset_accumulated_stat(stats.inactive_split[statType]); - reset_accumulated_stat(stats.allocated_bytes[statType]); - reset_accumulated_stat(stats.reserved_bytes[statType]); - reset_accumulated_stat(stats.active_bytes[statType]); - reset_accumulated_stat(stats.inactive_split_bytes[statType]); - reset_accumulated_stat(stats.requested_bytes[statType]); + stats.allocation[statType].reset_accumulated(); + stats.segment[statType].reset_accumulated(); + stats.active[statType].reset_accumulated(); + stats.inactive_split[statType].reset_accumulated(); + stats.allocated_bytes[statType].reset_accumulated(); + stats.reserved_bytes[statType].reset_accumulated(); + stats.active_bytes[statType].reset_accumulated(); + stats.inactive_split_bytes[statType].reset_accumulated(); + stats.requested_bytes[statType].reset_accumulated(); } stats.num_alloc_retries = 0; @@ -1656,8 +1622,8 @@ class DeviceCachingAllocator { stats.num_sync_all_streams = 0; stats.num_device_alloc = 0; stats.num_device_free = 0; - reset_accumulated_stat(stats.oversize_allocations); - reset_accumulated_stat(stats.oversize_segments); + stats.oversize_allocations.reset_accumulated(); + stats.oversize_segments.reset_accumulated(); } /** Resets the historical peak stats for the device **/ @@ -1666,18 +1632,18 @@ class DeviceCachingAllocator { for (const auto statType : c10::irange(static_cast(StatType::NUM_TYPES))) { - reset_peak_stat(stats.allocation[statType]); - reset_peak_stat(stats.segment[statType]); - reset_peak_stat(stats.active[statType]); - reset_peak_stat(stats.inactive_split[statType]); - reset_peak_stat(stats.allocated_bytes[statType]); - reset_peak_stat(stats.reserved_bytes[statType]); - reset_peak_stat(stats.active_bytes[statType]); - reset_peak_stat(stats.inactive_split_bytes[statType]); - reset_peak_stat(stats.requested_bytes[statType]); + stats.allocation[statType].reset_peak(); + stats.segment[statType].reset_peak(); + stats.active[statType].reset_peak(); + stats.inactive_split[statType].reset_peak(); + stats.allocated_bytes[statType].reset_peak(); + stats.reserved_bytes[statType].reset_peak(); + stats.active_bytes[statType].reset_peak(); + stats.inactive_split_bytes[statType].reset_peak(); + stats.requested_bytes[statType].reset_peak(); } - reset_peak_stat(stats.oversize_allocations); - reset_peak_stat(stats.oversize_segments); + stats.oversize_allocations.reset_peak(); + stats.oversize_segments.reset_peak(); } /* Checkpoint the state of a private pool necessary to return it to its @@ -2289,7 +2255,7 @@ class DeviceCachingAllocator { total_allocated_memory += mapped_range.size; StatTypes stat_types = get_stat_types_for_pool(*to_map->pool); for_each_selected_stat_type(stat_types, [&](size_t stat_type) { - increase_stat(stats.reserved_bytes[stat_type], mapped_range.size); + stats.reserved_bytes[stat_type].increase(mapped_range.size); }); auto reserved_bytes_gauge = STATIC_GAUGE(pytorch.CUDACachingAllocator.reserved_bytes); @@ -2401,27 +2367,23 @@ class DeviceCachingAllocator { // inactive_split if (!block->expandable_segment_) { if (net_change_inactive_split_blocks > 0) { - increase_stat( - stats.inactive_split[stat_type], + stats.inactive_split[stat_type].increase( static_cast(net_change_inactive_split_blocks)); } else if (net_change_inactive_split_blocks < 0) { - decrease_stat( - stats.inactive_split[stat_type], + stats.inactive_split[stat_type].decrease( static_cast(-net_change_inactive_split_blocks)); } if (net_change_inactive_split_size > 0) { - increase_stat( - stats.inactive_split_bytes[stat_type], + stats.inactive_split_bytes[stat_type].increase( static_cast(net_change_inactive_split_size)); } else if (net_change_inactive_split_size < 0) { - decrease_stat( - stats.inactive_split_bytes[stat_type], + stats.inactive_split_bytes[stat_type].decrease( static_cast(-net_change_inactive_split_size)); } } - decrease_stat(stats.active[stat_type], 1); - decrease_stat(stats.active_bytes[stat_type], original_block_size); - decrease_stat(stats.requested_bytes[stat_type], requested_size); + stats.active[stat_type].decrease(1); + stats.active_bytes[stat_type].decrease(original_block_size); + stats.requested_bytes[stat_type].decrease(requested_size); }); } @@ -2733,11 +2695,11 @@ class DeviceCachingAllocator { total_allocated_memory += size; p.block = new Block(p.device(), p.stream(), size, p.pool, (char*)ptr); for_each_selected_stat_type(p.stat_types, [&](size_t stat_type) { - increase_stat(stats.segment[stat_type], 1); - increase_stat(stats.reserved_bytes[stat_type], size); + stats.segment[stat_type].increase(1); + stats.reserved_bytes[stat_type].increase(size); }); if (size >= CUDAAllocatorConfig::max_split_size()) - increase_stat(stats.oversize_segments, 1); + stats.oversize_segments.increase(1); auto reserved_bytes_gauge = STATIC_GAUGE(pytorch.CUDACachingAllocator.reserved_bytes); reserved_bytes_gauge.record( @@ -2877,8 +2839,8 @@ class DeviceCachingAllocator { StatTypes stat_types = get_stat_types_for_pool(*pool); for_each_selected_stat_type(stat_types, [&](size_t stat_type) { - decrease_stat(stats.segment[stat_type], 1); - decrease_stat(stats.reserved_bytes[stat_type], block->size); + stats.segment[stat_type].decrease(1); + stats.reserved_bytes[stat_type].decrease(block->size); }); auto reserved_bytes_gauge = STATIC_GAUGE(pytorch.CUDACachingAllocator.reserved_bytes); @@ -2887,7 +2849,7 @@ class DeviceCachingAllocator { .current); if (block->size >= CUDAAllocatorConfig::max_split_size()) - decrease_stat(stats.oversize_segments, 1); + stats.oversize_segments.decrease(1); pool->blocks.erase(block); delete block; } @@ -2939,7 +2901,7 @@ class DeviceCachingAllocator { total_allocated_memory -= unmapped.size; StatTypes stat_types = get_stat_types_for_pool(*block->pool); for_each_selected_stat_type(stat_types, [&](size_t stat_type) { - decrease_stat(stats.reserved_bytes[stat_type], unmapped.size); + stats.reserved_bytes[stat_type].decrease(unmapped.size); }); auto reserved_bytes_gauge = STATIC_GAUGE(pytorch.CUDACachingAllocator.reserved_bytes); @@ -3773,25 +3735,6 @@ void local_raw_delete(void* ptr) { } } // namespace Native -// Size pretty-printer -std::string format_size(uint64_t size) { - std::ostringstream os; - os.precision(2); - os << std::fixed; - if (size <= 1024) { - os << size << " bytes"; - } else if (size <= 1048576) { - os << (static_cast(size) / 1024.0); - os << " KiB"; - } else if (size <= 1073741824ULL) { - os << static_cast(size) / 1048576.0; - os << " MiB"; - } else { - os << static_cast(size) / 1073741824.0; - os << " GiB"; - } - return os.str(); -} namespace CudaMallocAsync { // If this is put in its own header file, it gets incorrectly renamed in HIPify. diff --git a/c10/cuda/CUDACachingAllocator.h b/c10/cuda/CUDACachingAllocator.h index 72617bcaf3a944..70385654201f50 100644 --- a/c10/cuda/CUDACachingAllocator.h +++ b/c10/cuda/CUDACachingAllocator.h @@ -1,6 +1,6 @@ #pragma once -#include +#include #include #include #include @@ -48,74 +48,11 @@ C10_DECLARE_REGISTRY(FreeCudaMemoryCallbacksRegistry, FreeMemoryCallback); namespace c10::cuda::CUDACachingAllocator { -extern const size_t kLargeBuffer; - -struct Stat { - int64_t current = 0; - int64_t peak = 0; - int64_t allocated = 0; - int64_t freed = 0; -}; - -enum struct StatType : uint64_t { - AGGREGATE = 0, - SMALL_POOL = 1, - LARGE_POOL = 2, - NUM_TYPES = 3 // remember to update this whenever a new stat type is added -}; +// Preserved only for BC reasons +// NOLINTNEXTLINE(misc-unused-using-decls) +using c10::CachingDeviceAllocator::DeviceStats; -typedef std::array(StatType::NUM_TYPES)> StatArray; - -// Struct containing memory allocator summary statistics for a device. -struct DeviceStats { - // COUNT: allocations requested by client code - StatArray allocation; - // COUNT: number of allocated segments from cudaMalloc(). - StatArray segment; - // COUNT: number of active memory blocks (allocated or used by stream) - StatArray active; - // COUNT: number of inactive, split memory blocks (unallocated but can't be - // released via cudaFree) - StatArray inactive_split; - - // SUM: bytes allocated by this memory alocator - StatArray allocated_bytes; - // SUM: bytes reserved by this memory allocator (both free and used) - StatArray reserved_bytes; - // SUM: bytes within active memory blocks - StatArray active_bytes; - // SUM: bytes within inactive, split memory blocks - StatArray inactive_split_bytes; - // SUM: bytes requested by client code - StatArray requested_bytes; - - // COUNT: total number of failed calls to CUDA malloc necessitating cache - // flushes. - int64_t num_alloc_retries = 0; - - // COUNT: total number of OOMs (i.e. failed calls to CUDA after cache flush) - int64_t num_ooms = 0; - - // COUNT: total number of oversize blocks allocated from pool - Stat oversize_allocations; - - // COUNT: total number of oversize blocks requiring malloc - Stat oversize_segments; - - // COUNT: total number of synchronize_and_free_events() calls - int64_t num_sync_all_streams = 0; - - // COUNT: total number of CUDA allocation calls. This includes both cuMemMap - // and cudaMalloc. - int64_t num_device_alloc = 0; - - // COUNT: total number of CUDA free calls. This includes both cuMemUnmap - // and cudaFree. - int64_t num_device_free = 0; - - // SIZE: maximum block size that is allowed to be split. - int64_t max_split_size = 0; -}; +extern const size_t kLargeBuffer; typedef std::shared_ptr (*CreateContextFn)(); @@ -247,9 +184,6 @@ enum struct RecordContext { ALL = 3, // additionally record stacks for when something is freed }; -// Size pretty-printer -std::string format_size(uint64_t size); - using OutOfMemoryObserver = std::functionrecordStream(dataPtr, stream); } -inline DeviceStats getDeviceStats(c10::DeviceIndex device) { +inline c10::CachingDeviceAllocator::DeviceStats getDeviceStats( + c10::DeviceIndex device) { return get()->getDeviceStats(device); } diff --git a/c10/cuda/CUDAMallocAsyncAllocator.cpp b/c10/cuda/CUDAMallocAsyncAllocator.cpp index 3fe414b55e30a2..3a7b485ebce223 100644 --- a/c10/cuda/CUDAMallocAsyncAllocator.cpp +++ b/c10/cuda/CUDAMallocAsyncAllocator.cpp @@ -11,6 +11,8 @@ namespace c10::cuda::CUDACachingAllocator::CudaMallocAsync { +using namespace c10::CachingDeviceAllocator; + #if CUDA_VERSION >= 11040 // CUDA device allocator that uses cudaMallocAsync to implement // the same interface as CUDACachingAllocator.cpp. diff --git a/torch/csrc/cuda/CUDAPluggableAllocator.cpp b/torch/csrc/cuda/CUDAPluggableAllocator.cpp index c6af163481481e..5220e86233bd67 100644 --- a/torch/csrc/cuda/CUDAPluggableAllocator.cpp +++ b/torch/csrc/cuda/CUDAPluggableAllocator.cpp @@ -210,8 +210,8 @@ void CUDAPluggableAllocator::recordStream( } } -c10::cuda::CUDACachingAllocator::DeviceStats CUDAPluggableAllocator:: - getDeviceStats(c10::DeviceIndex device) { +c10::CachingDeviceAllocator::DeviceStats CUDAPluggableAllocator::getDeviceStats( + c10::DeviceIndex device) { TORCH_CHECK( false, "CUDAPluggableAllocator does not yet support getDeviceStats. " diff --git a/torch/csrc/cuda/CUDAPluggableAllocator.h b/torch/csrc/cuda/CUDAPluggableAllocator.h index 70014b2fa0b37b..8652ef0f2bfde8 100644 --- a/torch/csrc/cuda/CUDAPluggableAllocator.h +++ b/torch/csrc/cuda/CUDAPluggableAllocator.h @@ -115,7 +115,7 @@ struct TORCH_CUDA_CPP_API CUDAPluggableAllocator void recordStream(const c10::DataPtr&, streamType stream) override; - c10::cuda::CUDACachingAllocator::DeviceStats getDeviceStats( + c10::CachingDeviceAllocator::DeviceStats getDeviceStats( c10::DeviceIndex device) override; void resetAccumulatedStats(c10::DeviceIndex device) override; void resetPeakStats(c10::DeviceIndex device) override; diff --git a/torch/csrc/cuda/Module.cpp b/torch/csrc/cuda/Module.cpp index 134758b2c06814..22cbadc42fc09f 100644 --- a/torch/csrc/cuda/Module.cpp +++ b/torch/csrc/cuda/Module.cpp @@ -565,10 +565,10 @@ PyObject* THCPModule_memoryStats(PyObject* _unused, PyObject* arg) { TORCH_CHECK(THPUtils_checkLong(arg), "invalid argument to memory_allocated"); const auto device_index = THPUtils_unpackDeviceIndex(arg); - using c10::cuda::CUDACachingAllocator::DeviceStats; - using c10::cuda::CUDACachingAllocator::Stat; - using c10::cuda::CUDACachingAllocator::StatArray; - using c10::cuda::CUDACachingAllocator::StatType; + using c10::CachingDeviceAllocator::DeviceStats; + using c10::CachingDeviceAllocator::Stat; + using c10::CachingDeviceAllocator::StatArray; + using c10::CachingDeviceAllocator::StatType; const auto statToDict = [](const Stat& stat) { py::dict dict; From 2554b8ec15d6767fd5ad80c1e72c5517eb9c9f38 Mon Sep 17 00:00:00 2001 From: "Yu, Guangye" Date: Sat, 7 Sep 2024 10:07:02 +0000 Subject: [PATCH 0598/1018] [Intel GPU] Add XPU memory-related APIs (#129919) # Motivation According to https://github.com/pytorch/pytorch/issues/116322, we will help unify the device allocator. So we introduce a simple xpu device allocator only with the key functionality first. And expect to add some memory statistics-related functionality after the unification. But now, some memory statistic-related APIs listed in https://github.com/pytorch/pytorch/issues/127929 are requested. We need more time to unify the device allocator. In order to facilitate the user experience, we expect to support these memory statistic-related APIs before the unification. # Additional Context Fixes: #127929 Pull Request resolved: https://github.com/pytorch/pytorch/pull/129919 Approved by: https://github.com/dvrogozh, https://github.com/abhilash1910, https://github.com/gujinghui, https://github.com/EikanWang, https://github.com/albanD ghstack dependencies: #130923 --- c10/xpu/XPUCachingAllocator.cpp | 133 +++++++++++++++++++++- c10/xpu/XPUCachingAllocator.h | 9 +- docs/source/xpu.rst | 21 +++- test/test_xpu.py | 38 +++++++ torch/_C/__init__.pyi.in | 3 + torch/csrc/xpu/Module.cpp | 75 +++++++++++++ torch/xpu/__init__.py | 45 +++++--- torch/xpu/memory.py | 191 ++++++++++++++++++++++++++++++++ 8 files changed, 495 insertions(+), 20 deletions(-) create mode 100644 torch/xpu/memory.py diff --git a/c10/xpu/XPUCachingAllocator.cpp b/c10/xpu/XPUCachingAllocator.cpp index da57191fa1601c..42db9c89c98ab6 100644 --- a/c10/xpu/XPUCachingAllocator.cpp +++ b/c10/xpu/XPUCachingAllocator.cpp @@ -9,6 +9,8 @@ namespace c10::xpu::XPUCachingAllocator { +using namespace c10::CachingDeviceAllocator; + // newly allocated memory with 512-byte alignment. constexpr size_t kDeviceAlignment = 512; // all sizes are rounded to at least 512 bytes @@ -117,6 +119,7 @@ struct AllocParams { BlockPool* pool; size_t alloc_size; Block* block; + StatTypes stat_types = {}; }; } // anonymous namespace @@ -124,6 +127,7 @@ struct AllocParams { class DeviceCachingAllocator { private: mutable std::recursive_mutex mutex; + DeviceStats stats; BlockPool large_blocks; // unallocated cached blocks larger than 1 MB BlockPool small_blocks; // unallocated cached blocks 1 MB or smaller ska::flat_hash_set active_blocks; // allocated or in use by a stream @@ -173,6 +177,12 @@ class DeviceCachingAllocator { active_blocks.erase(block); bool inserted = pool.blocks.insert(block).second; TORCH_INTERNAL_ASSERT_DEBUG_ONLY(inserted); + + StatTypes stat_types = get_stat_types_for_pool(pool); + for_each_selected_stat_type(stat_types, [&](size_t stat_type) { + stats.active_bytes[stat_type].decrease(block->size); + stats.requested_bytes[stat_type].decrease(block->requested_size); + }); } void process_events() { @@ -250,6 +260,9 @@ class DeviceCachingAllocator { return false; } p.block = new Block(device, p.queue(), size, p.pool, ptr); + for_each_selected_stat_type(p.stat_types, [&](size_t stat_type) { + stats.reserved_bytes[stat_type].increase(size); + }); return true; } @@ -281,6 +294,12 @@ class DeviceCachingAllocator { sycl::free(block->ptr, xpu::get_device_context()); auto* pool = block->pool; pool->blocks.erase(block); + + StatTypes stat_types = get_stat_types_for_pool(*pool); + for_each_selected_stat_type(stat_types, [&](size_t stat_type) { + stats.reserved_bytes[stat_type].decrease(block->size); + }); + delete block; } @@ -314,6 +333,14 @@ class DeviceCachingAllocator { } } + StatTypes get_stat_types_for_pool(const BlockPool& pool) { + StatTypes stat_types = {}; + stat_types[static_cast(StatType::AGGREGATE)] = true; + stat_types[static_cast( + pool.is_small ? StatType::SMALL_POOL : StatType::LARGE_POOL)] = true; + return stat_types; + } + Block* alloc_found_block( AllocParams params, size_t orig_size, @@ -350,6 +377,12 @@ class DeviceCachingAllocator { bool inserted = active_blocks.insert(block).second; TORCH_INTERNAL_ASSERT_DEBUG_ONLY(inserted) + for_each_selected_stat_type(params.stat_types, [&](size_t stat_type) { + stats.allocated_bytes[stat_type].increase(block->size); + stats.active_bytes[stat_type].increase(block->size); + stats.requested_bytes[stat_type].increase(block->requested_size); + }); + return block; } @@ -376,6 +409,7 @@ class DeviceCachingAllocator { auto& pool = get_pool(size); const size_t alloc_size = get_allocation_size(size); AllocParams params(device, size, &queue, &pool, alloc_size); + params.stat_types = get_stat_types_for_pool(pool); // First, try to get a block from the existing pool. bool block_found = get_free_block(params); @@ -384,9 +418,32 @@ class DeviceCachingAllocator { block_found = alloc_block(params) || (release_cached_blocks() && alloc_block(params)); } - TORCH_CHECK( - block_found, - "XPU out of memory, please use `empty_cache` to release all unoccupied cached memory."); + if (!block_found) { + c10::xpu::DeviceProp device_prop; + c10::xpu::get_device_properties(&device_prop, device); + auto device_total = device_prop.global_mem_size; + auto allocated_bytes = + stats.allocated_bytes[static_cast(StatType::AGGREGATE)] + .current; + auto reserved_bytes = + stats.reserved_bytes[static_cast(StatType::AGGREGATE)] + .current; + TORCH_CHECK_WITH( + OutOfMemoryError, + false, + "XPU out of memory. Tried to allocate ", + format_size(alloc_size), + ". GPU ", + static_cast(device), + " has a total capacity of ", + format_size(device_total), + ". Of the allocated memory ", + format_size(allocated_bytes), + " is allocated by PyTorch, and ", + format_size(reserved_bytes - allocated_bytes), + " is reserved by PyTorch but unallocated.", + " Please use `empty_cache` to release all unoccupied cached memory."); + } bool split_remainder = should_split(params.block, params.size()); return alloc_found_block(std::move(params), orig_size, split_remainder); } @@ -395,6 +452,11 @@ class DeviceCachingAllocator { std::scoped_lock lock(mutex); block->allocated = false; + StatTypes stat_types = get_stat_types_for_pool(*block->pool); + for_each_selected_stat_type(stat_types, [&](size_t stat_type) { + stats.allocated_bytes[stat_type].decrease(block->size); + }); + if (!block->stream_uses.empty()) { insert_events(block); } else { @@ -414,6 +476,35 @@ class DeviceCachingAllocator { std::scoped_lock lock(mutex); release_cached_blocks(); } + + DeviceStats getStats() { + std::scoped_lock lock(mutex); + return stats; + } + + void resetAccumulatedStats() { + std::scoped_lock lock(mutex); + + for (const auto statType : + c10::irange(static_cast(StatType::NUM_TYPES))) { + stats.allocated_bytes[statType].reset_accumulated(); + stats.reserved_bytes[statType].reset_accumulated(); + stats.active_bytes[statType].reset_accumulated(); + stats.requested_bytes[statType].reset_accumulated(); + } + } + + void resetPeakStats() { + std::scoped_lock lock(mutex); + + for (const auto statType : + c10::irange(static_cast(StatType::NUM_TYPES))) { + stats.allocated_bytes[statType].reset_peak(); + stats.reserved_bytes[statType].reset_peak(); + stats.active_bytes[statType].reset_peak(); + stats.requested_bytes[statType].reset_peak(); + } + } }; void local_raw_delete(void* ptr); @@ -547,6 +638,30 @@ class XPUAllocator : public Allocator { void copy_data(void* dest, const void* src, std::size_t count) const final { xpu::getCurrentXPUStream().queue().memcpy(dest, src, count); } + + void assertValidDevice(DeviceIndex device) { + const auto device_num = device_allocators.size(); + TORCH_CHECK( + 0 <= device && device < static_cast(device_num), + "Invalid device argument ", + device, + ": did you call init?"); + } + + DeviceStats getDeviceStats(DeviceIndex device) { + assertValidDevice(device); + return device_allocators[device]->getStats(); + } + + void resetPeakStats(DeviceIndex device) { + assertValidDevice(device); + device_allocators[device]->resetPeakStats(); + } + + void resetAccumulatedStats(DeviceIndex device) { + assertValidDevice(device); + device_allocators[device]->resetAccumulatedStats(); + } }; static XPUAllocator allocator; @@ -567,6 +682,18 @@ void emptyCache() { return allocator.emptyCache(); } +void resetPeakStats(DeviceIndex device) { + return allocator.resetPeakStats(device); +} + +void resetAccumulatedStats(DeviceIndex device) { + return allocator.resetAccumulatedStats(device); +} + +DeviceStats getDeviceStats(DeviceIndex device) { + return allocator.getDeviceStats(device); +} + void* raw_alloc(size_t size) { return allocator.raw_alloc(size); } diff --git a/c10/xpu/XPUCachingAllocator.h b/c10/xpu/XPUCachingAllocator.h index 683654263a473d..6cdc8c8c71a6c9 100644 --- a/c10/xpu/XPUCachingAllocator.h +++ b/c10/xpu/XPUCachingAllocator.h @@ -1,6 +1,6 @@ #pragma once -#include +#include #include namespace c10::xpu::XPUCachingAllocator { @@ -11,6 +11,13 @@ C10_XPU_API void init(DeviceIndex device_count); C10_XPU_API void emptyCache(); +C10_XPU_API void resetPeakStats(DeviceIndex device); + +C10_XPU_API void resetAccumulatedStats(DeviceIndex device); + +C10_XPU_API c10::CachingDeviceAllocator::DeviceStats getDeviceStats( + DeviceIndex device); + C10_XPU_API void* raw_alloc(size_t size); C10_XPU_API void raw_delete(void* ptr); diff --git a/docs/source/xpu.rst b/docs/source/xpu.rst index d4085cf4e6267c..a83bea4d1b3f86 100644 --- a/docs/source/xpu.rst +++ b/docs/source/xpu.rst @@ -13,7 +13,6 @@ torch.xpu device device_count device_of - empty_cache get_device_capability get_device_name get_device_properties @@ -51,7 +50,25 @@ Streams and events Stream +Memory management +----------------- +.. autosummary:: + :toctree: generated + :nosignatures: + + empty_cache + max_memory_allocated + max_memory_reserved + memory_allocated + memory_reserved + memory_stats + memory_stats_as_nested_dict + reset_accumulated_memory_stats + reset_peak_memory_stats + + .. This module needs to be documented. Adding here in the meantime .. for tracking purposes +.. py:module:: torch.xpu.memory .. py:module:: torch.xpu.random -.. py:module:: torch.xpu.streams \ No newline at end of file +.. py:module:: torch.xpu.streams diff --git a/test/test_xpu.py b/test/test_xpu.py index e77a1e758f0c9f..8fa2d5da21b620 100644 --- a/test/test_xpu.py +++ b/test/test_xpu.py @@ -366,6 +366,44 @@ def test_serialization_array_with_empty(self): self.assertIs(type(copy), type(original)) self.assertEqual(copy.get_device(), original.get_device()) + def test_out_of_memory(self): + tensor = torch.zeros(1024, device="xpu") + + with self.assertRaisesRegex(RuntimeError, "Tried to allocate 800000000.00 GiB"): + torch.empty(1024 * 1024 * 1024 * 800000000, dtype=torch.int8, device="xpu") + + with self.assertRaisesRegex(RuntimeError, "XPU out of memory."): + torch.empty(1024 * 1024 * 1024 * 8000000000, dtype=torch.int8, device="xpu") + + def test_raises_oom(self): + torch.xpu.memory.empty_cache() + with self.assertRaises(torch.OutOfMemoryError): + torch.empty(1024 * 1024 * 1024 * 1024, device="xpu") + + def test_memory_allocation(self): + torch.xpu.empty_cache() + prev = torch.xpu.memory_allocated() + a = torch.ones(10, device="xpu") + self.assertGreater(torch.xpu.memory_allocated(), prev) + self.assertGreater(torch.xpu.memory_reserved(), 0) + del a + self.assertEqual(torch.xpu.memory_allocated(), prev) + torch.xpu.empty_cache() + self.assertEqual(torch.xpu.memory_reserved(), 0) + + @unittest.skipIf(not TEST_MULTIXPU, "only one GPU detected") + def test_device_memory_allocated(self): + device_count = torch.xpu.device_count() + current_alloc = [torch.xpu.memory_allocated(idx) for idx in range(device_count)] + x = torch.ones(10, device="xpu:0") + self.assertGreater(torch.xpu.memory_allocated(0), current_alloc[0]) + self.assertTrue( + all( + torch.xpu.memory_allocated(idx) == current_alloc[idx] + for idx in range(1, device_count) + ) + ) + instantiate_device_type_tests(TestXpu, globals(), only_for="xpu", allow_xpu=True) diff --git a/torch/_C/__init__.pyi.in b/torch/_C/__init__.pyi.in index 23f3bae774fec6..338f616bff60aa 100644 --- a/torch/_C/__init__.pyi.in +++ b/torch/_C/__init__.pyi.in @@ -2108,6 +2108,9 @@ def _xpu_getCurrentStream(device: _int) -> Tuple: ... def _xpu_getCurrentRawStream(device: _int) -> _int: ... def _xpu_synchronize(device: _int) -> None: ... def _xpu_emptyCache() -> None: ... +def _xpu_memoryStats(device: _int) -> Dict[str, Any]: ... +def _xpu_resetAccumulatedMemoryStats(device: _int) -> None: ... +def _xpu_resetPeakMemoryStats(device: _int) -> None: ... class _XpuDeviceProperties: name: str diff --git a/torch/csrc/xpu/Module.cpp b/torch/csrc/xpu/Module.cpp index d8e49d9d56b9a3..6e6c9a4564b65e 100644 --- a/torch/csrc/xpu/Module.cpp +++ b/torch/csrc/xpu/Module.cpp @@ -197,6 +197,72 @@ PyObject* THXPModule_emptyCache(PyObject* self, PyObject* noargs) { Py_RETURN_NONE; } +PyObject* THXPModule_memoryStats(PyObject* self, PyObject* arg) { + HANDLE_TH_ERRORS + TORCH_CHECK(THPUtils_checkLong(arg), "invalid argument to memory_stats"); + const auto device_index = THPUtils_unpackDeviceIndex(arg); + + using c10::CachingDeviceAllocator::DeviceStats; + using c10::CachingDeviceAllocator::Stat; + using c10::CachingDeviceAllocator::StatArray; + using c10::CachingDeviceAllocator::StatType; + + const auto statToDict = [](const Stat& stat) { + py::dict dict; + + dict["current"] = stat.current; + dict["peak"] = stat.peak; + dict["allocated"] = stat.allocated; + dict["freed"] = stat.freed; + return dict; + }; + + const auto statArrayToDict = [=](const StatArray& statArray) { + const std::array(StatType::NUM_TYPES)> + statTypeNames = {"all", "small_pool", "large_pool"}; + py::dict dict; + for (const auto i : c10::irange(statTypeNames.size())) { + dict[statTypeNames[i]] = statToDict(statArray[i]); + } + return dict; + }; + + const DeviceStats stats = + c10::xpu::XPUCachingAllocator::getDeviceStats(device_index); + + py::dict result; + result["allocated_bytes"] = statArrayToDict(stats.allocated_bytes); + result["reserved_bytes"] = statArrayToDict(stats.reserved_bytes); + result["active_bytes"] = statArrayToDict(stats.active_bytes); + result["requested_bytes"] = statArrayToDict(stats.requested_bytes); + + return result.release().ptr(); + END_HANDLE_TH_ERRORS +} + +PyObject* THXPModule_resetPeakMemoryStats(PyObject* self, PyObject* arg) { + HANDLE_TH_ERRORS + TORCH_CHECK( + THPUtils_checkLong(arg), "invalid argument to reset_peak_memory_stats"); + const auto device_index = THPUtils_unpackDeviceIndex(arg); + c10::xpu::XPUCachingAllocator::resetPeakStats(device_index); + END_HANDLE_TH_ERRORS + Py_RETURN_NONE; +} + +PyObject* THXPModule_resetAccumulatedMemoryStats( + PyObject* self, + PyObject* arg) { + HANDLE_TH_ERRORS + TORCH_CHECK( + THPUtils_checkLong(arg), + "invalid argument to reset_accumulated_memory_stats"); + const auto device_index = THPUtils_unpackDeviceIndex(arg); + c10::xpu::XPUCachingAllocator::resetAccumulatedStats(device_index); + END_HANDLE_TH_ERRORS + Py_RETURN_NONE; +} + // XPU module initialization static void registerXpuDeviceProperties(PyObject* module) { @@ -353,6 +419,15 @@ static struct PyMethodDef _THXPModule_methods[] = { nullptr}, {"_xpu_synchronize", THXPModule_xpuSynchronize, METH_O, nullptr}, {"_xpu_emptyCache", THXPModule_emptyCache, METH_NOARGS, nullptr}, + {"_xpu_memoryStats", THXPModule_memoryStats, METH_O, nullptr}, + {"_xpu_resetAccumulatedMemoryStats", + THXPModule_resetAccumulatedMemoryStats, + METH_O, + nullptr}, + {"_xpu_resetPeakMemoryStats", + THXPModule_resetPeakMemoryStats, + METH_O, + nullptr}, {nullptr}}; PyMethodDef* THXPModule_methods() { diff --git a/torch/xpu/__init__.py b/torch/xpu/__init__.py index b6eafc38ff1da2..a51924f89d21fe 100644 --- a/torch/xpu/__init__.py +++ b/torch/xpu/__init__.py @@ -388,19 +388,6 @@ def synchronize(device: _device_t = None) -> None: return torch._C._xpu_synchronize(device) -def empty_cache() -> None: - r"""Release all unoccupied cached memory currently held by the caching - allocator so that those can be used in other XPU application. - - .. note:: - :func:`~torch.xpu.empty_cache` doesn't increase the amount of XPU - memory available for PyTorch. However, it may help reduce fragmentation - of XPU memory in certain cases. - """ - if is_initialized(): - torch._C._xpu_emptyCache() - - def _get_generator(device: torch.device) -> torch._C.Generator: r"""Return the XPU Generator object for the given device. @@ -448,7 +435,29 @@ def _get_rng_state_offset(device: Union[int, str, torch.device] = "xpu") -> int: return default_generator.get_offset() -from .random import * # noqa: F403 +# import here to avoid circular import +from .memory import ( + empty_cache, + max_memory_allocated, + max_memory_reserved, + memory_allocated, + memory_reserved, + memory_stats, + memory_stats_as_nested_dict, + reset_accumulated_memory_stats, + reset_peak_memory_stats, +) +from .random import ( + get_rng_state, + get_rng_state_all, + initial_seed, + manual_seed, + manual_seed_all, + seed, + seed_all, + set_rng_state, + set_rng_state_all, +) __all__ = [ @@ -475,6 +484,14 @@ def _get_rng_state_offset(device: Union[int, str, torch.device] = "xpu") -> int: "is_initialized", "manual_seed", "manual_seed_all", + "max_memory_allocated", + "max_memory_reserved", + "memory_allocated", + "memory_reserved", + "memory_stats", + "memory_stats_as_nested_dict", + "reset_accumulated_memory_stats", + "reset_peak_memory_stats", "seed", "seed_all", "set_device", diff --git a/torch/xpu/memory.py b/torch/xpu/memory.py new file mode 100644 index 00000000000000..32e7f6c2ce315f --- /dev/null +++ b/torch/xpu/memory.py @@ -0,0 +1,191 @@ +import collections +from typing import Any, Dict, Union + +import torch +from torch.types import Device + +from . import _get_device_index, is_initialized + + +_device_t = Union[Device, str, int, None] + + +def empty_cache() -> None: + r"""Release all unoccupied cached memory currently held by the caching + allocator so that those can be used in other XPU application. + + .. note:: + :func:`~torch.xpu.empty_cache` doesn't increase the amount of XPU + memory available for PyTorch. However, it may help reduce fragmentation + of XPU memory in certain cases. + """ + if is_initialized(): + torch._C._xpu_emptyCache() + + +def reset_peak_memory_stats(device: _device_t = None) -> None: + r"""Reset the "peak" stats tracked by the XPU memory allocator. + + See :func:`~torch.xpu.memory_stats` for details. Peak stats correspond to the + `"peak"` key in each individual stat dict. + + Args: + device (torch.device or int or str, optional): selected device. Returns + statistic for the current device, given by :func:`~torch.xpu.current_device`, + if :attr:`device` is ``None`` (default). + """ + device = _get_device_index(device, optional=True) + return torch._C._xpu_resetPeakMemoryStats(device) + + +def reset_accumulated_memory_stats(device: _device_t = None) -> None: + r"""Reset the "accumulated" (historical) stats tracked by the XPU memory allocator. + + See :func:`~torch.xpu.memory_stats` for details. Accumulated stats correspond to + the `"allocated"` and `"freed"` keys in each individual stat dict. + + Args: + device (torch.device or int or str, optional): selected device. Returns + statistic for the current device, given by :func:`~torch.xpu.current_device`, + if :attr:`device` is ``None`` (default). + """ + device = _get_device_index(device, optional=True) + return torch._C._xpu_resetAccumulatedMemoryStats(device) + + +def memory_stats_as_nested_dict(device: _device_t = None) -> Dict[str, Any]: + r"""Return the result of :func:`~torch.xpu.memory_stats` as a nested dictionary.""" + if not is_initialized(): + return {} + device = _get_device_index(device, optional=True) + return torch._C._xpu_memoryStats(device) + + +def memory_stats(device: _device_t = None) -> Dict[str, Any]: + r"""Return a dictionary of XPU memory allocator statistics for a given device. + + The return value of this function is a dictionary of statistics, each of + which is a non-negative integer. + + Core statistics: + + - ``"allocated_bytes.{all,large_pool,small_pool}.{current,peak,allocated,freed}"``: + amount of allocated memory. + - ``"reserved_bytes.{all,large_pool,small_pool}.{current,peak,allocated,freed}"``: + amount of reserved memory. + - ``"active_bytes.{all,large_pool,small_pool}.{current,peak,allocated,freed}"``: + amount of active memory. + - ``"requested_bytes.{all,large_pool,small_pool}.{current,peak,allocated,freed}"``: + memory requested by client code, compare this with allocated_bytes to check if + allocation rounding adds too much overhead. + + For these core statistics, values are broken down as follows. + + Pool type: + + - ``all``: combined statistics across all memory pools. + - ``large_pool``: statistics for the large allocation pool (for size >= 1MB allocations). + - ``small_pool``: statistics for the small allocation pool (for size < 1MB allocations). + + Metric type: + + - ``current``: current value of this metric. + - ``peak``: maximum value of this metric. + - ``allocated``: historical total increase in this metric. + - ``freed``: historical total decrease in this metric. + + Args: + device (torch.device or int or str, optional): selected device. Returns + statistics for the current device, given by :func:`~torch.xpu.current_device`, + if :attr:`device` is ``None`` (default). + """ + result = [] + + def _recurse_add_to_result(prefix: str, obj: Any) -> None: + if isinstance(obj, dict): + if len(prefix) > 0: + prefix += "." + for k, v in obj.items(): + _recurse_add_to_result(prefix + k, v) + else: + result.append((prefix, obj)) + + stats = memory_stats_as_nested_dict(device=device) + _recurse_add_to_result("", stats) + result.sort() + + return collections.OrderedDict(result) + + +def memory_allocated(device: _device_t = None) -> int: + r"""Return the current GPU memory occupied by tensors in bytes for a given device. + + Args: + device (torch.device or int or str, optional): selected device. Returns + statistic for the current device, given by :func:`~torch.xpu.current_device`, + if :attr:`device` is ``None`` (default). + + .. note:: + This is likely less than the amount shown in `xpu-smi` since some + unused memory can be held by the caching allocator and some context + needs to be created on GPU. + """ + return memory_stats(device=device).get("allocated_bytes.all.current", 0) + + +def max_memory_allocated(device: _device_t = None) -> int: + r"""Return the maximum GPU memory occupied by tensors in bytes for a given device. + + By default, this returns the peak allocated memory since the beginning of + this program. :func:`~torch.xpu.reset_peak_memory_stats` can be used to + reset the starting point in tracking this metric. For example, these two + functions can measure the peak allocated memory usage of each iteration in a + training loop. + + Args: + device (torch.device or int or str, optional): selected device. Returns + statistic for the current device, given by :func:`~torch.xpu.current_device`, + if :attr:`device` is ``None`` (default). + """ + return memory_stats(device=device).get("allocated_bytes.all.peak", 0) + + +def memory_reserved(device: _device_t = None) -> int: + r"""Return the current GPU memory managed by the caching allocator in bytes for a given device. + + Args: + device (torch.device or int or str, optional): selected device. Returns + statistic for the current device, given by :func:`~torch.xpu.current_device`, + if :attr:`device` is ``None`` (default). + """ + return memory_stats(device=device).get("reserved_bytes.all.current", 0) + + +def max_memory_reserved(device: _device_t = None) -> int: + r"""Return the maximum GPU memory managed by the caching allocator in bytes for a given device. + + By default, this returns the peak cached memory since the beginning of this + program. :func:`~torch.xpu.reset_peak_memory_stats` can be used to reset + the starting point in tracking this metric. For example, these two functions + can measure the peak cached memory amount of each iteration in a training + loop. + + Args: + device (torch.device or int or str, optional): selected device. Returns + statistic for the current device, given by :func:`~torch.xpu.current_device`, + if :attr:`device` is ``None`` (default). + """ + return memory_stats(device=device).get("reserved_bytes.all.peak", 0) + + +__all__ = [ + "empty_cache", + "max_memory_allocated", + "max_memory_reserved", + "memory_allocated", + "memory_reserved", + "memory_stats", + "memory_stats_as_nested_dict", + "reset_accumulated_memory_stats", + "reset_peak_memory_stats", +] From 9eaf66408254d6de368e348dcea407bc8d73b1e2 Mon Sep 17 00:00:00 2001 From: CaoE Date: Fri, 6 Sep 2024 21:42:54 -0700 Subject: [PATCH 0599/1018] Add oneDNN BRGEMM support on CPU (#131878) Pull Request resolved: https://github.com/pytorch/pytorch/pull/131878 Approved by: https://github.com/jgong5, https://github.com/peterbell10 --- BUILD.bazel | 1 + aten/src/ATen/native/CPUBlas.cpp | 375 ++++++++++++++++++++++++++++++- aten/src/ATen/native/CPUBlas.h | 57 ++++- cmake/Modules/FindMKLDNN.cmake | 9 +- third_party/mkl-dnn.BUILD | 3 +- 5 files changed, 440 insertions(+), 5 deletions(-) diff --git a/BUILD.bazel b/BUILD.bazel index c71781606b7098..1018f7907adecd 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -332,6 +332,7 @@ intern_build_aten_ops( "@fbgemm", "@mkl", "@sleef", + "@mkl_dnn//:mkl-dnn", ], ) diff --git a/aten/src/ATen/native/CPUBlas.cpp b/aten/src/ATen/native/CPUBlas.cpp index 2e0dadfbabca65..689d3bacd5a571 100644 --- a/aten/src/ATen/native/CPUBlas.cpp +++ b/aten/src/ATen/native/CPUBlas.cpp @@ -41,6 +41,17 @@ extern "C" void zaxpy_(int *n, void *a, const void *x, int *incx, void *y, int * #include #endif // USE_FBGEMM +#if AT_MKLDNN_ENABLED() +#include +#endif // oneDNN + +#define ONEDNN_UKERNEL_ENABLED (DNNL_VERSION_MAJOR >=3 && DNNL_VERSION_MINOR >=5) + +#if ONEDNN_UKERNEL_ENABLED && (defined(__x86_64__) || (defined(_M_X64) && !defined(_M_ARM64EC))) +#include +#include +#endif // oneDNN BRGEMM + namespace at::native::cpublas { namespace internal { @@ -822,4 +833,366 @@ void copy(int64_t n, const c10::complex *x, int64_t incx, c10::complex +struct UnsafeUkernelKeyHasher { + std::size_t operator()(const key_t& key) const; +}; + +template<> +std::size_t UnsafeUkernelKeyHasher::operator()(const BrgemmKey& key) const { + // Use beta, M, N, and K to compute hash to reduce the overhead as + // batch size, alpha, and data types are unlikely to change within the same kernel and + // leading dimensions are likely to be related to M, K, N or use fixed values. + std::size_t h = std::hash()(key.beta + 1); + h = std::hash()(key.M) ^ (h << 1); + h = std::hash()(key.N) ^ (h << 1); + h = std::hash()(key.K) ^ (h << 1); + h = std::hash()(key.ldc) ^ (h << 1); + return h; +} + +template<> +std::size_t UnsafeUkernelKeyHasher::operator()(const PackKey& key) const { + // Use K and N to compute hash to reduce the overhead as + // data types are unlikely to change and + // ld_in/ld_out is likely to be related to K, N or use fixed values + std::size_t h = std::hash()(key.K); + h = std::hash()(key.N) ^ (h << 1); + return h; +} + +template +struct KernelCache { + using kstore_t = std::unordered_map, UnsafeUkernelKeyHasher>; + static inline std::shared_ptr&& fetch_or_create( + const key_t& key, + const std::function()>& callback) { + auto&& search = get_store().find(key); + if (search != get_store().end()) { + return std::move(search->second); + } else { + get_store().insert({key, callback()}); + return std::move(get_store()[key]); + } + } + + static inline kstore_t& get_store() { + static thread_local kstore_t cache_kernels; + return cache_kernels; + } +}; + +// Helper struct for convenient brgemm configuration +struct GemmHelper { + GemmHelper( + int64_t M, + int64_t N, + int64_t K, + int64_t bs, + int64_t ld_a, + int64_t ld_b, + int64_t ld_c, + ScalarType dt_a, + ScalarType dt_b, + ScalarType dt_c, + const float alpha, + const float beta) { + // Create brgemm + brg = dnnl::ukernel::brgemm( + M, + N, + K, + bs, + ld_a, + ld_b, + ld_c, + get_dnnl_dtype(dt_a), + get_dnnl_dtype(dt_b), + get_dnnl_dtype(dt_c), + alpha, + beta); + // Create a scratchpad buffer for the brgemm execution + scratchpad = std::vector(brg.get_scratchpad_size()); + // Prepare default vector of pairs of tensors A and B offsets for each batch. + A_B_offsets.reserve(1); + A_B_offsets[0] = std::make_pair(0, 0); + } + dnnl::ukernel::brgemm brg; + std::vector scratchpad; + std::vector> A_B_offsets; +}; + +struct Brgemm : public KernelCache { + // Fetch/create GemmHelper object and execute brgemm with batch size = 1 + template + static inline void call( + int64_t M, + int64_t N, + int64_t K, + int64_t ld_a, + int64_t ld_b, + int64_t ld_c, + const float alpha, + const float beta, + const scalar_t_a* A, + const scalar_t_b* B, + scalar_t_c* C) { + auto&& key = BrgemmKey( + M, + N, + K, + int64_t(1), + ld_a, + ld_b, + ld_c, + c10::CppTypeToScalarType::value, + c10::CppTypeToScalarType::value, + c10::CppTypeToScalarType::value, + alpha, + beta); + // Fetch/create GemmHelper object + auto&& value = fetch_or_create(key, [&]() { + auto&& v = std::make_shared( + M, + N, + K, + 1, + ld_a, + ld_b, + ld_c, + c10::CppTypeToScalarType::value, + c10::CppTypeToScalarType::value, + c10::CppTypeToScalarType::value, + alpha, + beta); + (*v).brg.generate(); + return std::move(v); + }); + if (get_current() != value) { + dnnl::ukernel::brgemm::release_hw_context(); + ((*value).brg).set_hw_context(); + get_current() = value; + } + ((*value).brg) + .execute(A, B, (*value).A_B_offsets, C, (*value).scratchpad.data()); + } + + static inline std::shared_ptr& get_current() { + static thread_local std::shared_ptr current; + return current; + } + + static inline bool device_check(ScalarType dtype) { + if (!at::globalContext().userEnabledMkldnn()) { + return false; + } + if (dtype == ScalarType::Half) { + static bool fp16_support = dnnl::get_effective_cpu_isa() >= dnnl::cpu_isa::avx512_core_fp16; + return fp16_support; + } + return false; + } +}; + +using pack_t = dnnl::ukernel::brgemm_pack_B; +struct Pack : public KernelCache { + static inline void call( + int64_t K, + int64_t N, + int64_t ld_in, + int64_t ld_out, + ScalarType dt_in, + ScalarType dt_out, + const void* in, + void* out) { + auto&& key = PackKey(K, N, ld_in, ld_out, dt_in, dt_out); + auto&& pack = fetch_or_create(key, [&]() { + auto&& p = std::make_shared( + K, N, ld_in, ld_out, get_dnnl_dtype(dt_in), get_dnnl_dtype(dt_out)); + if (need_pack(dt_in)) { + (*p).generate(); + } + return std::move(p); + }); + if (need_pack(dt_in)) { + (*pack).execute(in, out); + } else { + TORCH_CHECK(false, "No need to pack"); + } + } + + static inline bool need_pack(ScalarType dtype) { + if (!at::globalContext().userEnabledMkldnn()) { + return false; + } + if (dtype == ScalarType::Half) { + static bool fp16_pack = dnnl::get_effective_cpu_isa() >= dnnl::cpu_isa::avx512_core_amx_fp16; + return fp16_pack; + } + return false; + } +}; +#endif + +void brgemm( + int64_t M, + int64_t N, + int64_t K, + int64_t ld_a, + int64_t ld_b, + int64_t ld_c, + const float alpha, + const float beta, + const at::Half* A, + const at::Half* B, + float* C) { +#if ONEDNN_UKERNEL_ENABLED && (defined(__x86_64__) || (defined(_M_X64) && !defined(_M_ARM64EC))) + if (Brgemm::device_check(ScalarType::Half)) { + Brgemm::call( + M, N, K, ld_a, ld_b, ld_c, alpha, beta, A, B, C); + return; + } +#endif + TORCH_CHECK(false, + "Half Brgemm is only supported on X64 when oneDNN ukernel is enabled and avx512_fp16 is supported"); +} + +void brgemm( + int64_t M, + int64_t N, + int64_t K, + int64_t ld_a, + int64_t ld_b, + int64_t ld_c, + const float alpha, + const float beta, + const at::BFloat16* A, + const at::BFloat16* B, + float* C) { + TORCH_CHECK(false, + "BFloat16 Brgemm is currently not supported"); +} + +void brgemm_release() { +#if ONEDNN_UKERNEL_ENABLED && (defined(__x86_64__) || (defined(_M_X64) && !defined(_M_ARM64EC))) + dnnl::ukernel::brgemm::release_hw_context(); +#endif +} + +void pack( + int64_t K, + int64_t N, + int64_t ld_in, + int64_t ld_out, + ScalarType dt_in, + ScalarType dt_out, + const void* in, + void* out) { +#if ONEDNN_UKERNEL_ENABLED && (defined(__x86_64__) || (defined(_M_X64) && !defined(_M_ARM64EC))) + Pack::call(K, N, ld_in, ld_out, dt_in, dt_out, in, out); +#else + TORCH_CHECK(false, "pack is only supported on X64 with oneDNN ukernel enabled"); +#endif +} + +bool need_pack(ScalarType dt_in) { +#if ONEDNN_UKERNEL_ENABLED && (defined(__x86_64__) || (defined(_M_X64) && !defined(_M_ARM64EC))) + return Pack::need_pack(dt_in); +#else + return false; +#endif +} + +} // namespace at::native::cpublas diff --git a/aten/src/ATen/native/CPUBlas.h b/aten/src/ATen/native/CPUBlas.h index 3b30df1c21fad9..01ec49e4b54e00 100644 --- a/aten/src/ATen/native/CPUBlas.h +++ b/aten/src/ATen/native/CPUBlas.h @@ -7,6 +7,7 @@ #include #include + namespace at::native::cpublas { namespace internal { @@ -186,4 +187,58 @@ void copy(int64_t n, const float *x, int64_t incx, float *y, int64_t incy); void copy(int64_t n, const c10::complex *x, int64_t incx, c10::complex *y, int64_t incy); void copy(int64_t n, const c10::complex *x, int64_t incx, c10::complex *y, int64_t incy); -} // namespace at::native::cpublas +// Batch-reduce GEMM +// Operates by the following formula: +// C = alpha * SUM(A[i] x B[i]) + beta * C, i = 0 to batch size +// A Base pointer to a tensor A. +// B Base pointer to a tensor B. +// Byte offsets vector of pairs of tensors A and B offsets for +// each batch. The number of batches must coincide with the +// `batch_size` value passed at object construction stage. +// C Pointer to a tensor C (accumulation buffer). +// scratchpad Pointer to a scratchpad buffer. +// Currently, only brgemm with batch size = 1 will be used +TORCH_API void brgemm( + int64_t M, + int64_t N, + int64_t K, + int64_t ld_a, + int64_t ld_b, + int64_t ld_c, + const float alpha, + const float beta, + const at::Half* A, + const at::Half* B, + float* C); + +TORCH_API void brgemm( + int64_t M, + int64_t N, + int64_t K, + int64_t ld_a, + int64_t ld_b, + int64_t ld_c, + const float alpha, + const float beta, + const at::BFloat16* A, + const at::BFloat16* B, + float* C); + +// Release brgemm hardware context +void brgemm_release(); + +// Pack B matrix to get better performance if needed +void pack( + int64_t K, + int64_t N, + int64_t ld_in, + int64_t ld_out, + ScalarType dt_in, + ScalarType dt_out, + const void* in, + void* out); + +// Whether pack is needed in the platform. +bool need_pack(ScalarType dt_in); + +} // namespace at::native::cpublas diff --git a/cmake/Modules/FindMKLDNN.cmake b/cmake/Modules/FindMKLDNN.cmake index a8d42b6ee8da6a..234d361d7f5c27 100644 --- a/cmake/Modules/FindMKLDNN.cmake +++ b/cmake/Modules/FindMKLDNN.cmake @@ -87,13 +87,18 @@ IF(NOT MKLDNN_FOUND) SET(ONEDNN_BUILD_GRAPH ON CACHE BOOL "" FORCE) ENDIF(NOT APPLE AND NOT WIN32 AND NOT BUILD_LITE_INTERPRETER) + IF(EXISTS "${MKLDNN_ROOT}/include/oneapi/dnnl/dnnl_ukernel.hpp") + MESSAGE("-- Will build oneDNN UKERNEL") + SET(DNNL_EXPERIMENTAL_UKERNEL ON CACHE BOOL "" FORCE) + ENDIF(EXISTS "${MKLDNN_ROOT}/include/oneapi/dnnl/dnnl_ukernel.hpp") + FIND_PACKAGE(BLAS) FIND_PATH(IDEEP_INCLUDE_DIR ideep.hpp PATHS ${IDEEP_ROOT} PATH_SUFFIXES include) - FIND_PATH(MKLDNN_INCLUDE_DIR dnnl.hpp dnnl.h PATHS ${MKLDNN_ROOT} PATH_SUFFIXES include/oneapi/dnnl) + FIND_PATH(MKLDNN_INCLUDE_DIR dnnl.hpp dnnl.h dnnl_ukernel.hpp dnnl_ukernel.h PATHS ${MKLDNN_ROOT} PATH_SUFFIXES include/oneapi/dnnl) IF(NOT MKLDNN_INCLUDE_DIR) MESSAGE("MKLDNN_INCLUDE_DIR not found") EXECUTE_PROCESS(COMMAND git${CMAKE_EXECUTABLE_SUFFIX} submodule update --init mkl-dnn WORKING_DIRECTORY ${IDEEP_ROOT}) - FIND_PATH(MKLDNN_INCLUDE_DIR dnnl.hpp dnnl.h PATHS ${MKLDNN_ROOT} PATH_SUFFIXES include) + FIND_PATH(MKLDNN_INCLUDE_DIR dnnl.hpp dnnl.h dnnl_ukernel.hpp dnnl_ukernel.h PATHS ${MKLDNN_ROOT} PATH_SUFFIXES include) ENDIF(NOT MKLDNN_INCLUDE_DIR) IF(BUILD_ONEDNN_GRAPH) FIND_PATH(LLGA_INCLUDE_DIR dnnl_graph.hpp PATHS ${LLGA_ROOT} PATH_SUFFIXES include/oneapi/dnnl) diff --git a/third_party/mkl-dnn.BUILD b/third_party/mkl-dnn.BUILD index 64154894d65e69..aa55943bc913c9 100644 --- a/third_party/mkl-dnn.BUILD +++ b/third_party/mkl-dnn.BUILD @@ -11,9 +11,9 @@ _DNNL_RUNTIME_OMP = { "#cmakedefine DNNL_SYCL_CUDA": "/* #undef DNNL_SYCL_CUDA */", "#cmakedefine DNNL_SYCL_HIP": "/* #undef DNNL_SYCL_HIP */", "#cmakedefine DNNL_ENABLE_STACK_CHECKER": "#undef DNNL_ENABLE_STACK_CHECKER", + "#cmakedefine DNNL_EXPERIMENTAL_UKERNEL": "/* undef DNNL_EXPERIMENTAL_UKERNEL */", "#cmakedefine DNNL_EXPERIMENTAL": "#undef DNNL_EXPERIMENTAL", "#cmakedefine DNNL_EXPERIMENTAL_SPARSE": "#undef DNNL_EXPERIMENTAL_SPARSE", - "#cmakedefine DNNL_EXPERIMENTAL_UKERNEL": "#undef DNNL_EXPERIMENTAL_UKERNEL", "#cmakedefine ONEDNN_BUILD_GRAPH": "#undef ONEDNN_BUILD_GRAPH", "#cmakedefine DNNL_EXPERIMENTAL_PROFILING": "#undef DNNL_EXPERIMENTAL_PROFILING", "#cmakedefine01 BUILD_TRAINING": "#define BUILD_TRAINING 1", @@ -138,6 +138,7 @@ cc_library( "DNNL_ENABLE_CONCURRENT_EXEC", "DNNL_ENABLE_PRIMITIVE_CACHE", "DNNL_ENABLE_CPU_ISA_HINTS", + "DNNL_EXPERIMENTAL_UKERNEL", "ONEDNN_BUILD_GRAPH", ], ) From a95624c484cfc8fa2900144f33420064f93a00f4 Mon Sep 17 00:00:00 2001 From: "Wu, Chunyuan" Date: Sat, 7 Sep 2024 03:56:27 -0700 Subject: [PATCH 0600/1018] [Doc] update max-autotune for CPU (#134986) The current doc for `max-autotune` is applicable only for GPU. This PR adds the corresponding content for CPU. Pull Request resolved: https://github.com/pytorch/pytorch/pull/134986 Approved by: https://github.com/jgong5, https://github.com/malfet Co-authored-by: Nikita Shulga <2453524+malfet@users.noreply.github.com> --- torch/__init__.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/torch/__init__.py b/torch/__init__.py index d5a4644a85490f..64bfab708388a1 100644 --- a/torch/__init__.py +++ b/torch/__init__.py @@ -2378,8 +2378,9 @@ def compile( There are other circumstances where CUDA graphs are not applicable; use TORCH_LOG=perf_hints to debug. - - "max-autotune" is a mode that leverages Triton based matrix multiplications and convolutions - It enables CUDA graphs by default. + - "max-autotune" is a mode that leverages Triton or template based matrix multiplications + on supported devices and Triton based convolutions on GPU. + It enables CUDA graphs by default on GPU. - "max-autotune-no-cudagraphs" is a mode similar to "max-autotune" but without CUDA graphs From 3a5f9d0dc3d35b19708008f713f47b6a900750d2 Mon Sep 17 00:00:00 2001 From: Jason Ansel Date: Fri, 6 Sep 2024 21:17:18 -0700 Subject: [PATCH 0601/1018] [inductor] Move LoopBody to its own file (#135257) Pull Request resolved: https://github.com/pytorch/pytorch/pull/135257 Approved by: https://github.com/oulgen --- torch/_inductor/bounds.py | 2 +- torch/_inductor/codegen/common.py | 2 +- torch/_inductor/codegen/cpp.py | 11 +- .../_inductor/codegen/cpp_template_kernel.py | 3 +- torch/_inductor/codegen/cpp_utils.py | 8 +- torch/_inductor/ir.py | 501 +---------------- torch/_inductor/loop_body.py | 519 ++++++++++++++++++ torch/_inductor/optimize_indexing.py | 2 +- torch/_inductor/scheduler.py | 5 +- torch/_inductor/virtualized.py | 2 +- 10 files changed, 539 insertions(+), 516 deletions(-) create mode 100644 torch/_inductor/loop_body.py diff --git a/torch/_inductor/bounds.py b/torch/_inductor/bounds.py index c2b0fe201fd685..7452f2bb1b62b6 100644 --- a/torch/_inductor/bounds.py +++ b/torch/_inductor/bounds.py @@ -9,7 +9,7 @@ import torch from torch.utils._sympy.value_ranges import bound_sympy, ValueRangeAnalysis, ValueRanges -from .ir import InterpreterShim, LoopBody, LoopBodyBlock +from .loop_body import InterpreterShim, LoopBody, LoopBodyBlock from .utils import cache_on_self, dominated_nodes from .virtualized import V diff --git a/torch/_inductor/codegen/common.py b/torch/_inductor/codegen/common.py index a703f7d66c8ee0..f80d1f71e380d0 100644 --- a/torch/_inductor/codegen/common.py +++ b/torch/_inductor/codegen/common.py @@ -439,7 +439,7 @@ def propagate_loopbody(cls, body): @classmethod def propagate_scheduler_node(cls, node): - from ..ir import LoopBody + from ..loop_body import LoopBody from ..scheduler import SchedulerNode assert isinstance(node, SchedulerNode) diff --git a/torch/_inductor/codegen/cpp.py b/torch/_inductor/codegen/cpp.py index e1addfbbebb8c3..6d7f38777b3bed 100644 --- a/torch/_inductor/codegen/cpp.py +++ b/torch/_inductor/codegen/cpp.py @@ -22,6 +22,7 @@ from ..._dynamo.utils import counters from .. import codecache, config, cpp_builder, cpu_vec_isa, ir, metrics +from ..loop_body import LoopBody from ..scheduler import ( BaseSchedulerNode, BaseScheduling, @@ -3186,7 +3187,7 @@ def transform_indexing(self, index: sympy.Expr) -> sympy.Expr: ) -def get_loop_body_lowp_fp(_body: ir.LoopBody) -> Tuple[Optional[torch.dtype], bool]: +def get_loop_body_lowp_fp(_body: LoopBody) -> Tuple[Optional[torch.dtype], bool]: """ Returns the low precision data type (torch.float16/torch.bfloat16) contained in the nodes and if all the nodes can codegen with this data type without converting to float. @@ -3471,7 +3472,7 @@ def data_type_propagation(self, nodes): # Check if all the nodes of a given fx graph can support BF16/FP16 def is_lowp_fp_scheduler(self, scheduler_node: SchedulerNode): - if not isinstance(scheduler_node._body, ir.LoopBody): + if not isinstance(scheduler_node._body, LoopBody): return True # Propagate the dtype to check if all the fx node is bf16/fp16 DataTypePropagation.propagate_scheduler_node(scheduler_node) @@ -3480,7 +3481,7 @@ def is_lowp_fp_scheduler(self, scheduler_node: SchedulerNode): and not get_loop_body_lowp_fp(scheduler_node._body)[1] ) - def legalize_lowp_fp_dtype_loopbody(self, loop_body: ir.LoopBody): + def legalize_lowp_fp_dtype_loopbody(self, loop_body: LoopBody): def add_to_dtype(sub_graph: torch.fx.Graph): def is_lowp_fp_load(node: torch.fx.Node): if node.target not in ["load"]: @@ -3647,7 +3648,7 @@ def legalize_lowp_fp_dtype(self, nodes): for _node in nodes: assert isinstance(_node, SchedulerNode) - assert isinstance(_node._body, ir.LoopBody) + assert isinstance(_node._body, LoopBody) node: SchedulerNode = _node def is_memory_copy_scheduler_node(node: SchedulerNode): @@ -3658,7 +3659,7 @@ def is_memory_copy_scheduler_node(node: SchedulerNode): should_legalize = not is_memory_copy_scheduler_node(node) if should_legalize: - body: ir.LoopBody = node._body + body: LoopBody = node._body self.legalize_lowp_fp_dtype_loopbody(body) def codegen_functions(self, fn_list, var_sizes_list): diff --git a/torch/_inductor/codegen/cpp_template_kernel.py b/torch/_inductor/codegen/cpp_template_kernel.py index e9993e22cf63fe..0333720bbdc689 100644 --- a/torch/_inductor/codegen/cpp_template_kernel.py +++ b/torch/_inductor/codegen/cpp_template_kernel.py @@ -10,6 +10,7 @@ from .. import config, cpp_builder, ir, lowering as L from ..autotune_process import CppBenchmarkRequest +from ..loop_body import LoopBody from ..select_algorithm import PartialRender from ..utils import sympy_index_symbol, sympy_index_symbol_with_prefix from ..virtualized import V @@ -239,7 +240,7 @@ def fn(*args): node.make_loader()(new_args).value, ) - body = ir.LoopBody( + body = LoopBody( fn, (list(var_ranges.keys()), ()), var_ranges, diff --git a/torch/_inductor/codegen/cpp_utils.py b/torch/_inductor/codegen/cpp_utils.py index 46ec1ac9644891..e69bb3637f7249 100644 --- a/torch/_inductor/codegen/cpp_utils.py +++ b/torch/_inductor/codegen/cpp_utils.py @@ -16,6 +16,7 @@ from torch.utils._sympy.value_ranges import ValueRanges from .. import ir +from ..loop_body import LoopBody from ..utils import IndentedBuffer, sympy_index_symbol_with_prefix, sympy_subs from ..virtualized import ops, OpsValue, V from .common import ( @@ -883,20 +884,19 @@ def inner_fn(index): def _get_loop_body(fn_list): - loop_bodies = None - if all(isinstance(fn, ir.LoopBody) for fn in fn_list): + if all(isinstance(fn, LoopBody) for fn in fn_list): loop_bodies = fn_list else: if hasattr(fn_list[0], "original_fn"): # For the case of local buffer, we wrap the fn with localize_function assert all(hasattr(fn, "original_fn") for fn in fn_list) assert all( - isinstance(fn.original_fn.args[0]._body, ir.LoopBody) for fn in fn_list + isinstance(fn.original_fn.args[0]._body, LoopBody) for fn in fn_list ) loop_bodies = [fn.original_fn.args[0]._body for fn in fn_list] else: assert all(isinstance(fn, functools.partial) for fn in fn_list) - assert all(isinstance(fn.args[0]._body, ir.LoopBody) for fn in fn_list) + assert all(isinstance(fn.args[0]._body, LoopBody) for fn in fn_list) loop_bodies = [fn.args[0]._body for fn in fn_list] assert loop_bodies is not None return loop_bodies diff --git a/torch/_inductor/ir.py b/torch/_inductor/ir.py index 96c171814056de..6d09d6592a0e13 100644 --- a/torch/_inductor/ir.py +++ b/torch/_inductor/ir.py @@ -6,7 +6,6 @@ import functools import itertools import logging -import re import textwrap import traceback from contextlib import nullcontext @@ -72,6 +71,7 @@ extract_read_writes, var_builder, ) +from .loop_body import LoopBody from .ops_handler import OpCounterCSE from .runtime.benchmarking import benchmarker from .runtime.hints import ReductionHint @@ -6770,505 +6770,6 @@ def codegen_reference(self, writer=None): return self.name -class InterpreterShim(torch.fx.Interpreter): - @staticmethod - @functools.lru_cache(None) - def _dummy_gm(): - return torch.fx.symbolic_trace(identity) - - def __init__(self, graph, submodules): - # call super() with a placeholder to avoid constructing a - # GraphModule which is very expensive (it does codegen). - super().__init__(self._dummy_gm(), garbage_collect_values=False) - self.module = self # type: ignore[assignment] - self.graph = graph - self.submodules = submodules - self.extra_traceback = False - self.fetch_attr = submodules.__getitem__ # type: ignore[method-assign] - self.current_node = None - - def run_node(self, n: torch.fx.Node) -> Any: - self.current_node = n - return super().run_node(n) - - def run(self, *args, **kwargs): - with V.set_interpreter_handler(self): - return super().run(*args, **kwargs) - - -class LoopBody: - """ - Captures the body of a Loops subclass into an FX graph. Persists any - indexing simplifications and makes it easier to analyze loop bodies. - """ - - indexing_exprs: Dict[str, sympy.Expr] - indexing_exprs_name: Dict[sympy.Expr, str] - reads_name2expr: Dict[str, sympy.Expr] - writes_name2expr: Dict[str, sympy.Expr] - submodules: Dict[str, Any] - subblocks: Dict[str, LoopBodyBlock] - indirect_vars: List[str] - indirect_var_ranges: Dict[sympy.Symbol, sympy.Expr] - root_block: LoopBodyBlock - - def __init__(self, fn, args, var_ranges, iter_vars, reduce_vars): - super().__init__() - - _flat_sizes = tuple(var_ranges.values()) - self.sizes = ( - _flat_sizes[: len(iter_vars)], - _flat_sizes[len(iter_vars) :], - ) - - self.iter_vars = iter_vars - self.reduce_vars = reduce_vars - self.var_ranges = var_ranges - - if isinstance(fn, LoopBody): - self._init_with_copy(fn, args) - else: - self._init_with_tracing(fn, args) - - self.indexing = None - - def _init_with_tracing(self, fn, args): - """Do an FX trace of an arbitrary callable to construct self""" - self.indexing_exprs = {} - self.indexing_exprs_name = {} - self.reads_name2expr = {} - self.writes_name2expr = {} - self.submodules = {"get_index": self.get_index} - self.subblocks = {} - self.indirect_vars = [] - self.indirect_var_ranges: Dict[sympy.Symbol, sympy.Expr] = {} - self.root_block = LoopBodyBlock(self, fn, args) # traces - del self.indexing_exprs_name # not used after _init_with_tracing - - def _init_with_copy(self, other: LoopBody, args): - """ - _init_with_tracing() is slow, so this is a fast path in the case - where we are just reordering/merging/splitting the args of an - existing LoopBody. - """ - indexing_exprs = other.indexing_from_args(args) - update_index = { - expr: indexing_exprs[name] for name, expr in other.indexing_exprs.items() - } - assert indexing_exprs.keys() == other.indexing_exprs.keys() and len( - update_index - ) == len(indexing_exprs) - - self.indexing_exprs = indexing_exprs - self.reads_name2expr = { - k: update_index[v] for k, v in other.reads_name2expr.items() - } - self.writes_name2expr = { - k: update_index[v] for k, v in other.writes_name2expr.items() - } - self.subblocks = {k: v.clone(self) for k, v in other.subblocks.items()} - self.indirect_vars = [*other.indirect_vars] - self.indirect_var_ranges = {**other.indirect_var_ranges} - self.root_block = other.root_block.clone(self) - - submodules = {**other.submodules} - submodules.pop("get_index") - self.submodules = { - "get_index": self.get_index, - **{k: v.clone(self) for k, v in submodules.items()}, # type: ignore[attr-defined] - } - - def merge_loops(self) -> LoopBody: - """ - Merge both iteration and reduction loops and return a new LoopBody. - """ - old_body = self - old_sizes = self.sizes - old_iter_vars, old_reduce_vars = old_body.vars - old_iter_sizes, old_reduce_sizes = old_sizes - - index_exprs = [*old_body.indexing_exprs.values()] - - iter_sizes, iter_reindex, _ = V.graph.sizevars._simplify_loops( - old_iter_vars, - old_iter_sizes, - index_prevent_reordering(index_exprs, old_iter_vars, old_iter_sizes), - ) - - reduce_sizes, reduce_reindex, _ = V.graph.sizevars._simplify_loops( - old_reduce_vars, - old_reduce_sizes, - index_prevent_reordering(index_exprs, old_reduce_vars, old_reduce_sizes), - ) - - # if iter_sizes == old_iter_sizes: - # # no dimensions get merged. - # return old_sizes, old_body - - # Note: if no dimension get merges, the symbol prefix will - # remain 'y'. But if we merge dimensions, we change prefix to - # 'z'. If this is an issue, we can always retrace the LoopBody - # to change symbol prefix to 'z'. - # - # There is indeed an issue due to symbol name conflicting. - # y0 maybe reused for the y dimension later. - ( - iter_vars, - reduce_vars, - ), var_ranges = dependencies.index_vars_no_squeeze( - iter_sizes, reduce_sizes, prefix="t" - ) - new_body = LoopBody( - old_body, - [iter_reindex(iter_vars), reduce_reindex(reduce_vars)], - var_ranges, - iter_vars, - reduce_vars, - ) - - # use the original symbol prefix - # Can try to optimize if this is a bottleneck for compilation time - (iter_vars2, reduce_vars2), var_ranges2 = dependencies.index_vars_no_squeeze( - iter_sizes, reduce_sizes, prefix="z" - ) - new_body2 = LoopBody( - new_body, (iter_vars2, reduce_vars2), var_ranges2, iter_vars2, reduce_vars2 - ) - return new_body2 - - def reorder_iter_loops(self, new_order) -> LoopBody: - """ - Reorder iteration loops and return a new LoopBody. - """ - old_body = self - old_sizes = self.sizes - assert len(old_sizes[0]) == len(new_order) - reorder_fn = same_reorder(new_order) - - iter_size, reduce_size = old_sizes - new_iter_size = reorder_fn(iter_size) - - new_sizes = (new_iter_size, reduce_size) - - (iter_vars, reduce_vars), var_ranges = dependencies.index_vars_no_squeeze( - *new_sizes, prefix="t" # type: ignore[arg-type] - ) - - inverse_order = {b: a for a, b in enumerate(new_order)} - inverse_order = [inverse_order[i] for i in range(len(new_order))] - - def new_body(*indices: Sequence[sympy.Expr]) -> Any: - index = list(itertools.chain(*indices)) - assert len(index) == len(iter_size) + len(reduce_size) - iter_idx = index[: len(iter_size)] - reduce_idx = index[len(iter_size) :] - iter_idx = [iter_idx[i] for i in inverse_order] - return old_body(iter_idx, reduce_idx) - - loop_body = LoopBody( - new_body, (iter_vars, reduce_vars), var_ranges, iter_vars, reduce_vars - ) - - # use the original symbol prefix so we can do multiple round of reordering - (iter_vars2, reduce_vars2), var_ranges2 = dependencies.index_vars_no_squeeze( - *new_sizes, prefix="z" # type: ignore[arg-type] - ) - new_body = LoopBody( - loop_body, (iter_vars2, reduce_vars2), var_ranges2, iter_vars2, reduce_vars2 - ) - return new_body - - @property - def vars(self): - assert self.iter_vars is not None - assert self.reduce_vars is not None - return self.iter_vars, self.reduce_vars - - @cache_on_self - def get_nodes(self): - all_graphs = itertools.chain( - (self.root_block.graph,), - (block.graph for block in self.subblocks.values()), - ) - return [node for graph in all_graphs for node in graph.nodes] - - @cache_on_self - def bounds(self): - # Doing a local import to avoid dumping all the code here - from .bounds import BoundVars - - return BoundVars(self) - - def debug_str(self): - lines = [f"var_ranges = {dict(self.var_ranges)}"] - lines.extend([f"{name} = {val}" for name, val in self.indexing_exprs.items()]) - lines.extend( - [ - block.debug_str(name) - for name, block in itertools.chain( - [("body", self.root_block)], self.subblocks.items() - ) - ] - ) - return "\n".join(lines) - - __repr__ = debug_str - - def add_index_expr(self, expr: sympy.Expr, category, buf_name): - if buf_name is not None: - getattr(self, f"{category}_name2expr")[buf_name] = expr - if expr not in self.indexing_exprs_name: - name = f"index{len(self.indexing_exprs)}" - self.indexing_exprs_name[expr] = name - self.indexing_exprs[name] = expr - return self.indexing_exprs_name[expr] - - def add_submodule(self, block, prefix): - """Not actually for nn.Modules, but subblocks in generated code are mapped to FX call_module opcodes""" - if prefix[-1].isnumeric() and prefix not in self.submodules: - name = prefix - else: - name = f"{prefix}{len(self.submodules)}" - self.submodules[name] = block - return name - - def add_indirect(self, size): - var = sympy_index_symbol_with_prefix(SymT.INDIRECT, len(self.indirect_vars)) - assert var not in self.indirect_var_ranges - self.indirect_vars.append(var) - self.indirect_var_ranges[var] = size - return var - - def replace_indirect(self, old, new): - """Swap in a variable used in indirect indexing""" - if str(old) == str(new): - return - assert self.indexing is not None - self.indexing = {k: sympy_subs(v, {old: new}) for k, v in self.indexing.items()} - - def get_index(self, name): - assert self.indexing is not None - return self.indexing[name] - - def indexing_from_args(self, indices): - index = [*itertools.chain.from_iterable(indices)] - assert len(index) == len(self.var_ranges), (index, self.var_ranges) - assert all( - v not in self.var_ranges for v in index - ), f"{self.var_ranges=}, {indices=}" - replacements = dict(zip(self.var_ranges.keys(), index)) - return { - name: sympy_subs(expr, replacements) - for name, expr in self.indexing_exprs.items() - } - - def __call__(self, *indices): - self.indexing = self.indexing_from_args(indices) - result = self.root_block() - self.indexing = None - return result - - def bind_set_indirect_shim(self, var, size, check, wrap_neg): - def set_indirect(new_var): - self.replace_indirect( - var, V.ops.indirect_indexing(new_var, size, check, wrap_neg) - ) - - set_indirect.clone = functools.partial( # type: ignore[attr-defined] - LoopBody.bind_set_indirect_shim, - var=var, - size=size, - check=check, - wrap_neg=wrap_neg, - ) - return set_indirect - - def bind_scan_shim(self, combine_fn): - def shim(dtypes, values): - return V.ops.scan(dtypes, combine_fn, values) - - shim.clone = functools.partial(LoopBody.bind_scan_shim, combine_fn=combine_fn) # type: ignore[attr-defined] - return shim - - def bind_masked_shim(self, name): - def shim(mask, other): - return V.ops.masked(mask, self.subblocks[name], other) - - shim.clone = functools.partial(LoopBody.bind_masked_shim, name=name) # type: ignore[attr-defined] - return shim - - -class LoopBodyBlock: - """ - Captures the body of a Loops subclass into an FX graph. - In normal cases there will be a 1:1 mapping between LoopBody and - LoopBodyBlock, hower in the case of ops.masked() the masked out - operations will manifest as an extra LoopBodyBlock. - """ - - def __init__(self, body: LoopBody, fn: Callable[..., Any], args: List[Any]): - self.body = body - - def add_index(expr, category, buf_name=None): - return tracer.create_proxy( - "call_module", - "get_index", - (self.body.add_index_expr(expr, category, buf_name),), - {}, - ) - - class CaptureIndexing(V.WrapperHandler): # type: ignore[name-defined] - self.name = "CaptureIndexing" - - def load(self, name: str, index: sympy.Expr): - index = add_index(index, "reads", name) - return self._inner.load(name, index) - - def store(self, name, index, value, mode=None): - index = add_index(index, "writes", name) - return self._inner.store(name, index, value, mode) - - def store_reduction(self, name, index, value): - index = add_index(index, "writes", name) - return self._inner.store_reduction(name, index, value) - - def reduction(self, dtype, src_dtype, reduction_type, value): - result = self._inner.reduction(dtype, src_dtype, reduction_type, value) - if "welford" in reduction_type: - return tuple(result[i] for i in range(3)) - return result - - def index_expr(self, index, dtype): - if isinstance(index, (int, sympy.Integer)): - return self._inner.constant(int(index), dtype) - index = add_index(index, "other") - return self._inner.index_expr(index, dtype) - - def check_bounds(self, index, size, lower, upper): - index = add_index(index, "other") - size = add_index(size, "other") - return self._inner.check_bounds(index, size, lower, upper) - - def bucketize( - self, - values, - offsets_name: str, - offsets_size: sympy.Expr, - indexing_dtype: torch.dtype, - right: bool, - ): - offsets_size = add_index(offsets_size, "other") - return self._inner.bucketize( - values, offsets_name, offsets_size, indexing_dtype, right - ) - - @staticmethod - def masked(mask_proxy, masked_body: Callable[..., Any], other_proxy): - """ - Recursively capture the masked out body in another LoopBodyBlock - """ - name = self.body.add_submodule(None, "masked_subblock") - self.body.submodules[name] = self.body.bind_masked_shim(name) - self.body.subblocks[name] = LoopBodyBlock(self.body, masked_body, []) - return tracer.create_proxy( - "call_module", name, (mask_proxy, other_proxy), {} - ) - - @staticmethod - def scan( - dtype_proxy, - combine_fn: Callable[ - [Tuple[Any, ...], Tuple[Any, ...]], Tuple[Any, ...] - ], - value_proxy, - ): - shim = self.body.bind_scan_shim(combine_fn) - name = self.body.add_submodule(shim, "scan") - result = tracer.create_proxy( - "call_module", - name, - (dtype_proxy, value_proxy), - {}, - ) - # Proxies are iterable, but some methods expect tuples/lists - return tuple(result[i] for i in range(len(value_proxy))) - - def sort(self, dtypes, values, stable, descending): - result = self._inner.sort(dtypes, values, stable, descending) - # Proxies are iterable, but some methods expect tuples/lists - return tuple(result[i] for i in range(len(values))) - - def frexp(self, value_proxy): - result = self._inner.frexp(value_proxy) - # Proxies are iterable, but some methods expect tuples/lists - return (result[0], result[1]) - - @staticmethod - def indirect_indexing(index_proxy, size, check=True, wrap_neg=True): - """ - Flow data from tensors into indexing formulas. - Introduce a call_module to update the indexing. - """ - - var = self.body.add_indirect(size) - set_indirect = self.body.bind_set_indirect_shim( - var, size, check, wrap_neg - ) - tracer.create_proxy( - "call_module", - self.body.add_submodule(set_indirect, f"set_{var}"), - (index_proxy,), - {}, - ) - return var - - @staticmethod - def output(result): - tracer.create_proxy("output", "output", (result,), {}) - - tracer = torch.fx.Tracer() - tracer.graph = torch.fx.Graph(tracer_cls=tracer.__class__) - proxy_ops = tracer.create_proxy("placeholder", "ops", (), {}) - - from .index_propagation import IndexPropagation - from .sizevars import SimplifyIndexing - - handler: Any = SimplifyIndexing( - CaptureIndexing(proxy_ops), self.body.var_ranges - ) - if config.constant_and_index_propagation: - handler = IndexPropagation( - handler, self.body.var_ranges, self.body.indirect_var_ranges - ) - - with V.set_ops_handler(handler): - # This indirection is just a cute way to get IndexPropagation to - # unwrap the return value. - ops.output(fn(*args)) - self.graph = tracer.graph - - def __call__(self): - graph = self.graph - submodules = self.body.submodules - - return InterpreterShim(graph, submodules).run(V.get_ops_handler()) - - def debug_str(self, name="block"): - code = torch.fx.GraphModule(self.body.submodules, self.graph).code - return re.sub( - # strip `; del var0` suffixes to make output prettier - r";[^\n]*", - "", - code.strip().replace("def forward(", f"def {name}("), - ) - - def clone(self, body: LoopBody): - """Shallow copy with a new parent LoopBody""" - copy = LoopBodyBlock.__new__(LoopBodyBlock) - copy.__dict__.update({**self.__dict__, "body": body}) - return copy - - class _CollectiveKernel(FallbackKernel): def should_allocate(self): return False diff --git a/torch/_inductor/loop_body.py b/torch/_inductor/loop_body.py new file mode 100644 index 00000000000000..78f043d14ac1d1 --- /dev/null +++ b/torch/_inductor/loop_body.py @@ -0,0 +1,519 @@ +# mypy: allow-untyped-defs +from __future__ import annotations + +import functools +import itertools +import re +from typing import Any, Callable, Dict, List, Sequence, Tuple + +import sympy + +import torch.fx +from torch._dynamo.utils import identity +from torch.utils._sympy.symbol import SymT + +from . import config, dependencies +from .codegen.common import index_prevent_reordering +from .utils import cache_on_self, sympy_index_symbol_with_prefix, sympy_subs +from .virtualized import ops, V + + +class InterpreterShim(torch.fx.Interpreter): + @staticmethod + @functools.lru_cache(None) + def _dummy_gm(): + return torch.fx.symbolic_trace(identity) + + def __init__(self, graph, submodules): + # call super() with a placeholder to avoid constructing a + # GraphModule which is very expensive (it does codegen). + super().__init__(self._dummy_gm(), garbage_collect_values=False) + self.module = self # type: ignore[assignment] + self.graph = graph + self.submodules = submodules + self.extra_traceback = False + self.fetch_attr = submodules.__getitem__ # type: ignore[method-assign] + self.current_node = None + + def run_node(self, n: torch.fx.Node) -> Any: + self.current_node = n + return super().run_node(n) + + def run(self, *args, **kwargs): + with V.set_interpreter_handler(self): + return super().run(*args, **kwargs) + + +class LoopBody: + """ + Captures the body of a Loops subclass into an FX graph. Persists any + indexing simplifications and makes it easier to analyze loop bodies. + """ + + indexing_exprs: Dict[str, sympy.Expr] + indexing_exprs_name: Dict[sympy.Expr, str] + reads_name2expr: Dict[str, sympy.Expr] + writes_name2expr: Dict[str, sympy.Expr] + submodules: Dict[str, Any] + subblocks: Dict[str, LoopBodyBlock] + indirect_vars: List[str] + indirect_var_ranges: Dict[sympy.Symbol, sympy.Expr] + root_block: LoopBodyBlock + + def __init__(self, fn, args, var_ranges, iter_vars, reduce_vars): + super().__init__() + + _flat_sizes = tuple(var_ranges.values()) + self.sizes = ( + _flat_sizes[: len(iter_vars)], + _flat_sizes[len(iter_vars) :], + ) + + self.iter_vars = iter_vars + self.reduce_vars = reduce_vars + self.var_ranges = var_ranges + + if isinstance(fn, LoopBody): + self._init_with_copy(fn, args) + else: + self._init_with_tracing(fn, args) + + self.indexing = None + + def _init_with_tracing(self, fn, args): + """Do an FX trace of an arbitrary callable to construct self""" + self.indexing_exprs = {} + self.indexing_exprs_name = {} + self.reads_name2expr = {} + self.writes_name2expr = {} + self.submodules = {"get_index": self.get_index} + self.subblocks = {} + self.indirect_vars = [] + self.indirect_var_ranges: Dict[sympy.Symbol, sympy.Expr] = {} + self.root_block = LoopBodyBlock(self, fn, args) # traces + del self.indexing_exprs_name # not used after _init_with_tracing + + def _init_with_copy(self, other: LoopBody, args): + """ + _init_with_tracing() is slow, so this is a fast path in the case + where we are just reordering/merging/splitting the args of an + existing LoopBody. + """ + indexing_exprs = other.indexing_from_args(args) + update_index = { + expr: indexing_exprs[name] for name, expr in other.indexing_exprs.items() + } + assert indexing_exprs.keys() == other.indexing_exprs.keys() and len( + update_index + ) == len(indexing_exprs) + + self.indexing_exprs = indexing_exprs + self.reads_name2expr = { + k: update_index[v] for k, v in other.reads_name2expr.items() + } + self.writes_name2expr = { + k: update_index[v] for k, v in other.writes_name2expr.items() + } + self.subblocks = {k: v.clone(self) for k, v in other.subblocks.items()} + self.indirect_vars = [*other.indirect_vars] + self.indirect_var_ranges = {**other.indirect_var_ranges} + self.root_block = other.root_block.clone(self) + + submodules = {**other.submodules} + submodules.pop("get_index") + self.submodules = { + "get_index": self.get_index, + **{k: v.clone(self) for k, v in submodules.items()}, # type: ignore[attr-defined] + } + + def merge_loops(self) -> LoopBody: + """ + Merge both iteration and reduction loops and return a new LoopBody. + """ + old_body = self + old_sizes = self.sizes + old_iter_vars, old_reduce_vars = old_body.vars + old_iter_sizes, old_reduce_sizes = old_sizes + + index_exprs = [*old_body.indexing_exprs.values()] + + iter_sizes, iter_reindex, _ = V.graph.sizevars._simplify_loops( + old_iter_vars, + old_iter_sizes, + index_prevent_reordering(index_exprs, old_iter_vars, old_iter_sizes), + ) + + reduce_sizes, reduce_reindex, _ = V.graph.sizevars._simplify_loops( + old_reduce_vars, + old_reduce_sizes, + index_prevent_reordering(index_exprs, old_reduce_vars, old_reduce_sizes), + ) + + # if iter_sizes == old_iter_sizes: + # # no dimensions get merged. + # return old_sizes, old_body + + # Note: if no dimension get merges, the symbol prefix will + # remain 'y'. But if we merge dimensions, we change prefix to + # 'z'. If this is an issue, we can always retrace the LoopBody + # to change symbol prefix to 'z'. + # + # There is indeed an issue due to symbol name conflicting. + # y0 maybe reused for the y dimension later. + ( + iter_vars, + reduce_vars, + ), var_ranges = dependencies.index_vars_no_squeeze( + iter_sizes, reduce_sizes, prefix="t" + ) + new_body = LoopBody( + old_body, + [iter_reindex(iter_vars), reduce_reindex(reduce_vars)], + var_ranges, + iter_vars, + reduce_vars, + ) + + # use the original symbol prefix + # Can try to optimize if this is a bottleneck for compilation time + (iter_vars2, reduce_vars2), var_ranges2 = dependencies.index_vars_no_squeeze( + iter_sizes, reduce_sizes, prefix="z" + ) + new_body2 = LoopBody( + new_body, (iter_vars2, reduce_vars2), var_ranges2, iter_vars2, reduce_vars2 + ) + return new_body2 + + def reorder_iter_loops(self, new_order) -> LoopBody: + """ + Reorder iteration loops and return a new LoopBody. + """ + from .ir import same_reorder + + old_body = self + old_sizes = self.sizes + assert len(old_sizes[0]) == len(new_order) + reorder_fn = same_reorder(new_order) + + iter_size, reduce_size = old_sizes + new_iter_size = reorder_fn(iter_size) + + new_sizes = (new_iter_size, reduce_size) + + (iter_vars, reduce_vars), var_ranges = dependencies.index_vars_no_squeeze( + *new_sizes, prefix="t" # type: ignore[arg-type] + ) + + inverse_order = {b: a for a, b in enumerate(new_order)} + inverse_order = [inverse_order[i] for i in range(len(new_order))] + + def new_body(*indices: Sequence[sympy.Expr]) -> Any: + index = list(itertools.chain(*indices)) + assert len(index) == len(iter_size) + len(reduce_size) + iter_idx = index[: len(iter_size)] + reduce_idx = index[len(iter_size) :] + iter_idx = [iter_idx[i] for i in inverse_order] + return old_body(iter_idx, reduce_idx) + + loop_body = LoopBody( + new_body, (iter_vars, reduce_vars), var_ranges, iter_vars, reduce_vars + ) + + # use the original symbol prefix so we can do multiple round of reordering + (iter_vars2, reduce_vars2), var_ranges2 = dependencies.index_vars_no_squeeze( + *new_sizes, prefix="z" # type: ignore[arg-type] + ) + new_body = LoopBody( + loop_body, (iter_vars2, reduce_vars2), var_ranges2, iter_vars2, reduce_vars2 + ) + return new_body + + @property + def vars(self): + assert self.iter_vars is not None + assert self.reduce_vars is not None + return self.iter_vars, self.reduce_vars + + @cache_on_self + def get_nodes(self): + all_graphs = itertools.chain( + (self.root_block.graph,), + (block.graph for block in self.subblocks.values()), + ) + return [node for graph in all_graphs for node in graph.nodes] + + @cache_on_self + def bounds(self): + # Doing a local import to avoid dumping all the code here + from .bounds import BoundVars + + return BoundVars(self) + + def debug_str(self): + lines = [f"var_ranges = {dict(self.var_ranges)}"] + lines.extend([f"{name} = {val}" for name, val in self.indexing_exprs.items()]) + lines.extend( + [ + block.debug_str(name) + for name, block in itertools.chain( + [("body", self.root_block)], self.subblocks.items() + ) + ] + ) + return "\n".join(lines) + + __repr__ = debug_str + + def add_index_expr(self, expr: sympy.Expr, category, buf_name): + if buf_name is not None: + getattr(self, f"{category}_name2expr")[buf_name] = expr + if expr not in self.indexing_exprs_name: + name = f"index{len(self.indexing_exprs)}" + self.indexing_exprs_name[expr] = name + self.indexing_exprs[name] = expr + return self.indexing_exprs_name[expr] + + def add_submodule(self, block, prefix): + """Not actually for nn.Modules, but subblocks in generated code are mapped to FX call_module opcodes""" + if prefix[-1].isnumeric() and prefix not in self.submodules: + name = prefix + else: + name = f"{prefix}{len(self.submodules)}" + self.submodules[name] = block + return name + + def add_indirect(self, size): + var = sympy_index_symbol_with_prefix(SymT.INDIRECT, len(self.indirect_vars)) + assert var not in self.indirect_var_ranges + self.indirect_vars.append(var) + self.indirect_var_ranges[var] = size + return var + + def replace_indirect(self, old, new): + """Swap in a variable used in indirect indexing""" + if str(old) == str(new): + return + assert self.indexing is not None + self.indexing = {k: sympy_subs(v, {old: new}) for k, v in self.indexing.items()} + + def get_index(self, name): + assert self.indexing is not None + return self.indexing[name] + + def indexing_from_args(self, indices): + index = [*itertools.chain.from_iterable(indices)] + assert len(index) == len(self.var_ranges), (index, self.var_ranges) + assert all( + v not in self.var_ranges for v in index + ), f"{self.var_ranges=}, {indices=}" + replacements = dict(zip(self.var_ranges.keys(), index)) + return { + name: sympy_subs(expr, replacements) + for name, expr in self.indexing_exprs.items() + } + + def __call__(self, *indices): + self.indexing = self.indexing_from_args(indices) + result = self.root_block() + self.indexing = None + return result + + def bind_set_indirect_shim(self, var, size, check, wrap_neg): + def set_indirect(new_var): + self.replace_indirect( + var, V.ops.indirect_indexing(new_var, size, check, wrap_neg) + ) + + set_indirect.clone = functools.partial( # type: ignore[attr-defined] + LoopBody.bind_set_indirect_shim, + var=var, + size=size, + check=check, + wrap_neg=wrap_neg, + ) + return set_indirect + + def bind_scan_shim(self, combine_fn): + def shim(dtypes, values): + return V.ops.scan(dtypes, combine_fn, values) + + shim.clone = functools.partial(LoopBody.bind_scan_shim, combine_fn=combine_fn) # type: ignore[attr-defined] + return shim + + def bind_masked_shim(self, name): + def shim(mask, other): + return V.ops.masked(mask, self.subblocks[name], other) + + shim.clone = functools.partial(LoopBody.bind_masked_shim, name=name) # type: ignore[attr-defined] + return shim + + +class LoopBodyBlock: + """ + Captures the body of a Loops subclass into an FX graph. + In normal cases there will be a 1:1 mapping between LoopBody and + LoopBodyBlock, hower in the case of ops.masked() the masked out + operations will manifest as an extra LoopBodyBlock. + """ + + def __init__(self, body: LoopBody, fn: Callable[..., Any], args: List[Any]): + self.body = body + + def add_index(expr, category, buf_name=None): + return tracer.create_proxy( + "call_module", + "get_index", + (self.body.add_index_expr(expr, category, buf_name),), + {}, + ) + + class CaptureIndexing(V.WrapperHandler): # type: ignore[name-defined] + self.name = "CaptureIndexing" + + def load(self, name: str, index: sympy.Expr): + index = add_index(index, "reads", name) + return self._inner.load(name, index) + + def store(self, name, index, value, mode=None): + index = add_index(index, "writes", name) + return self._inner.store(name, index, value, mode) + + def store_reduction(self, name, index, value): + index = add_index(index, "writes", name) + return self._inner.store_reduction(name, index, value) + + def reduction(self, dtype, src_dtype, reduction_type, value): + result = self._inner.reduction(dtype, src_dtype, reduction_type, value) + if "welford" in reduction_type: + return tuple(result[i] for i in range(3)) + return result + + def index_expr(self, index, dtype): + if isinstance(index, (int, sympy.Integer)): + return self._inner.constant(int(index), dtype) + index = add_index(index, "other") + return self._inner.index_expr(index, dtype) + + def check_bounds(self, index, size, lower, upper): + index = add_index(index, "other") + size = add_index(size, "other") + return self._inner.check_bounds(index, size, lower, upper) + + def bucketize( + self, + values, + offsets_name: str, + offsets_size: sympy.Expr, + indexing_dtype: torch.dtype, + right: bool, + ): + offsets_size = add_index(offsets_size, "other") + return self._inner.bucketize( + values, offsets_name, offsets_size, indexing_dtype, right + ) + + @staticmethod + def masked(mask_proxy, masked_body: Callable[..., Any], other_proxy): + """ + Recursively capture the masked out body in another LoopBodyBlock + """ + name = self.body.add_submodule(None, "masked_subblock") + self.body.submodules[name] = self.body.bind_masked_shim(name) + self.body.subblocks[name] = LoopBodyBlock(self.body, masked_body, []) + return tracer.create_proxy( + "call_module", name, (mask_proxy, other_proxy), {} + ) + + @staticmethod + def scan( + dtype_proxy, + combine_fn: Callable[ + [Tuple[Any, ...], Tuple[Any, ...]], Tuple[Any, ...] + ], + value_proxy, + ): + shim = self.body.bind_scan_shim(combine_fn) + name = self.body.add_submodule(shim, "scan") + result = tracer.create_proxy( + "call_module", + name, + (dtype_proxy, value_proxy), + {}, + ) + # Proxies are iterable, but some methods expect tuples/lists + return tuple(result[i] for i in range(len(value_proxy))) + + def sort(self, dtypes, values, stable, descending): + result = self._inner.sort(dtypes, values, stable, descending) + # Proxies are iterable, but some methods expect tuples/lists + return tuple(result[i] for i in range(len(values))) + + def frexp(self, value_proxy): + result = self._inner.frexp(value_proxy) + # Proxies are iterable, but some methods expect tuples/lists + return (result[0], result[1]) + + @staticmethod + def indirect_indexing(index_proxy, size, check=True, wrap_neg=True): + """ + Flow data from tensors into indexing formulas. + Introduce a call_module to update the indexing. + """ + + var = self.body.add_indirect(size) + set_indirect = self.body.bind_set_indirect_shim( + var, size, check, wrap_neg + ) + tracer.create_proxy( + "call_module", + self.body.add_submodule(set_indirect, f"set_{var}"), + (index_proxy,), + {}, + ) + return var + + @staticmethod + def output(result): + tracer.create_proxy("output", "output", (result,), {}) + + tracer = torch.fx.Tracer() + tracer.graph = torch.fx.Graph(tracer_cls=tracer.__class__) + proxy_ops = tracer.create_proxy("placeholder", "ops", (), {}) + + from .index_propagation import IndexPropagation + from .sizevars import SimplifyIndexing + + handler: Any = SimplifyIndexing( + CaptureIndexing(proxy_ops), self.body.var_ranges + ) + if config.constant_and_index_propagation: + handler = IndexPropagation( + handler, self.body.var_ranges, self.body.indirect_var_ranges + ) + + with V.set_ops_handler(handler): + # This indirection is just a cute way to get IndexPropagation to + # unwrap the return value. + ops.output(fn(*args)) + self.graph = tracer.graph + + def __call__(self): + graph = self.graph + submodules = self.body.submodules + + return InterpreterShim(graph, submodules).run(V.get_ops_handler()) + + def debug_str(self, name="block"): + code = torch.fx.GraphModule(self.body.submodules, self.graph).code + return re.sub( + # strip `; del var0` suffixes to make output prettier + r";[^\n]*", + "", + code.strip().replace("def forward(", f"def {name}("), + ) + + def clone(self, body: LoopBody): + """Shallow copy with a new parent LoopBody""" + copy = LoopBodyBlock.__new__(LoopBodyBlock) + copy.__dict__.update({**self.__dict__, "body": body}) + return copy diff --git a/torch/_inductor/optimize_indexing.py b/torch/_inductor/optimize_indexing.py index c333bd882942e5..96bf8641f3c9a6 100644 --- a/torch/_inductor/optimize_indexing.py +++ b/torch/_inductor/optimize_indexing.py @@ -6,7 +6,7 @@ import torch from torch.utils._sympy.value_ranges import ValueRanges -from .ir import LoopBody +from .loop_body import LoopBody from .utils import dominated_nodes diff --git a/torch/_inductor/scheduler.py b/torch/_inductor/scheduler.py index 5af3978c41b934..9e991fc26be605 100644 --- a/torch/_inductor/scheduler.py +++ b/torch/_inductor/scheduler.py @@ -46,6 +46,7 @@ from .comm_analysis import estimate_nccl_collective_runtime from .dependencies import Dep, MemoryDep, StarDep, WeakDep from .ir import ComputedBuffer, MultiOutput, MultiOutputLayout +from .loop_body import LoopBody from .runtime.runtime_utils import green_text, red_text from .sizevars import SimplifyIndexing from .utils import ( @@ -922,7 +923,7 @@ def debug_str_extra(self) -> str: buf_name = dep.name buf = V.graph.get_buffer(buf_name) lines.append(f"{buf_name}_layout = {pformat(buf.layout)}") - if isinstance(self._body, ir.LoopBody): + if isinstance(self._body, LoopBody): lines.append(f"class {name}_loop_body:") lines.append(textwrap.indent(self._body.debug_str(), " ")) @@ -1011,7 +1012,7 @@ def can_inplace(self, read_dep: dependencies.Dep) -> bool: @cache_on_self def _get_atomic_add_buffers(self) -> OrderedSet[str]: buffers_store_as_atomic_add: OrderedSet[str] = OrderedSet() - if isinstance(self._body, ir.LoopBody): + if isinstance(self._body, LoopBody): for node in self._body.get_nodes(): if ( node.op == "call_method" diff --git a/torch/_inductor/virtualized.py b/torch/_inductor/virtualized.py index b47cd97c1667c0..b00a94c5f2cee3 100644 --- a/torch/_inductor/virtualized.py +++ b/torch/_inductor/virtualized.py @@ -76,7 +76,7 @@ from torch._inductor.codegen.cpp_utils import LocalBufferContext from torch._inductor.debug import DebugContext from torch._inductor.graph import GraphLowering - from torch._inductor.ir import InterpreterShim + from torch._inductor.loop_body import InterpreterShim from torch._subclasses import FakeTensorMode threadlocal = local() From c96c0b2ede79fd1faba4eb65da742f04adea3eb6 Mon Sep 17 00:00:00 2001 From: Sam Larsen Date: Fri, 6 Sep 2024 11:40:31 -0700 Subject: [PATCH 0602/1018] [inductor] Catch BrokenProcessPool and print a more helpful message. (#135120) Summary: BrokenProcessPool means a parallel-compile subprocess exited, which we never expect. It's likely due to a crash, so print a more meaningful error message and instructions that it's probably easier to debug by turning off parallel compile. Output looks like: ``` ... File "/data/users/slarsen/pytorch/torch/_inductor/runtime/compile_tasks.py", line 45, in _reload_python_module exec(code, mod.__dict__, mod.__dict__) File "/tmp/torchinductor_slarsen/4q/c4qw7xk5lbb7whg5txnk4hwbc7z6kepak3o666tr3d64gcad5r5b.py", line 815, in async_compile.wait(globals()) File "/data/users/slarsen/pytorch/torch/_inductor/async_compile.py", line 265, in wait raise RuntimeError( RuntimeError: A compilation subprocess exited unexpectedly. This is likely due to a crash. To facilitate debugging, you can re-run with TORCHINDUCTOR_COMPILE_THREADS=1 to cause compilation to occur in the main process. ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/135120 Approved by: https://github.com/Chillee --- torch/_inductor/async_compile.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/torch/_inductor/async_compile.py b/torch/_inductor/async_compile.py index ce265460958439..0794ceb3eed571 100644 --- a/torch/_inductor/async_compile.py +++ b/torch/_inductor/async_compile.py @@ -7,6 +7,7 @@ import os import sys from concurrent.futures import Future, ProcessPoolExecutor, ThreadPoolExecutor +from concurrent.futures.process import BrokenProcessPool from functools import partial from time import time from typing import Any, Callable, Dict, List, Optional, Set, TYPE_CHECKING @@ -271,7 +272,15 @@ def wait(self, scope: Dict[str, Any]) -> None: if config.verbose_progress and not isinstance(pbar, _Faketqdm): pbar.set_postfix_str(key) if isinstance(result, (Future, CodeCacheFuture)): - scope[key] = result.result() + try: + scope[key] = result.result() + except BrokenProcessPool as e: + raise RuntimeError( + "A compilation subprocess exited unexpectedly. This " + "is likely due to a crash. To facilitate debugging, " + "you can re-run with TORCHINDUCTOR_COMPILE_THREADS=1 " + "to cause compilation to occur in the main process." + ) from e pbar.update(1) _compile_end() From a14af6b69f678f51a4636e5f460edce043710b16 Mon Sep 17 00:00:00 2001 From: drisspg Date: Fri, 6 Sep 2024 15:35:07 -0700 Subject: [PATCH 0603/1018] [FlexAttention] Align the matmul tensorcore usage (#135168) Pull Request resolved: https://github.com/pytorch/pytorch/pull/135168 Approved by: https://github.com/Chillee --- test/inductor/test_flex_attention.py | 49 +++++++++++++++++++++++- torch/_inductor/kernel/flex_attention.py | 27 ++++++++----- 2 files changed, 66 insertions(+), 10 deletions(-) diff --git a/test/inductor/test_flex_attention.py b/test/inductor/test_flex_attention.py index 9b52048f1dbf37..06b012348d9f88 100644 --- a/test/inductor/test_flex_attention.py +++ b/test/inductor/test_flex_attention.py @@ -4,7 +4,7 @@ import functools import string from collections import namedtuple -from contextlib import nullcontext +from contextlib import contextmanager, nullcontext from typing import Callable, Optional from unittest import expectedFailure, skip, skipUnless from unittest.mock import patch @@ -57,6 +57,22 @@ Tensor = torch.Tensor +@contextmanager +def temp_float32_matmul_precision(precision: str): + """ + Temporarily set the float32 matmul precision and restore it after the context is exited. + + Args: + precision (str): The precision to set ('highest', 'high', or 'medium'). + """ + original_precision = torch.get_float32_matmul_precision() + try: + torch.set_float32_matmul_precision(precision) + yield + finally: + torch.set_float32_matmul_precision(original_precision) + + def rmse(ref, res): """ Calculate root mean squared error @@ -1470,6 +1486,37 @@ def test_differentiable_logsumexp_compiled(self): v_grad, v_grad2, atol=tolerance.atol, rtol=tolerance.rtol ) + @supported_platform + def test_float32_matmul_precision(self): + make_tensor = functools.partial( + torch.zeros, + (2, 2, 128, 32), + device="cuda", + dtype=torch.float32, + requires_grad=False, + ) + query, key, value = make_tensor(), make_tensor(), make_tensor() + query.fill_(0.2) + key.fill_(0.3) + value.fill_(0.4) + + query.requires_grad = True + key.requires_grad = True + value.requires_grad = True + + def score_mod(score, b, h, q, kv): + return score * 2 + + with temp_float32_matmul_precision("highest"): + out_eager = flex_attention(query, key, value, score_mod) + flex_compiled = torch.compile(flex_attention, fullgraph=True) + out_compiled = flex_compiled(query, key, value, score_mod) + + grads_eager = torch.autograd.grad(out_eager.sum(), (query, key, value)) + grads_compile = torch.autograd.grad(out_compiled.sum(), (query, key, value)) + + torch.testing.assert_close(grads_eager, grads_compile) + @supported_platform @common_utils.parametrize("score_mod_name", ["_head_offset"]) @common_utils.parametrize("mode", ["eager", "aot_eager"]) diff --git a/torch/_inductor/kernel/flex_attention.py b/torch/_inductor/kernel/flex_attention.py index 24cd489c162dbd..4c1c4de3338cf0 100644 --- a/torch/_inductor/kernel/flex_attention.py +++ b/torch/_inductor/kernel/flex_attention.py @@ -55,6 +55,13 @@ def maybe_realize(args: List[Optional[IRNode]]): return tree_map(lambda x: realize_inputs(x) if x is not None else None, args) +def get_float32_precision(): + if torch.get_float32_matmul_precision() == "highest": + return "'ieee'" + else: + return "'tf32'" + + def build_subgraph_buffer( args: List[TensorBox], subgraph: Subgraph, @@ -429,7 +436,7 @@ def forward_block_mn( else: k = tl.load(K_block_ptr, boundary_check=(1,), padding_option = "zero") # -- compute qk --- - qk = tl.dot(q, k) # TODO: use cuda matmul when q_len <= 2. + qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2. if not PRESCALE_QK: qk *= SM_SCALE # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ @@ -501,7 +508,7 @@ def forward_block_mn( v = tl.load(V_block_ptr) else: v = tl.load(V_block_ptr, boundary_check=(0,), padding_option = "zero") - acc = tl.dot(p.to(MATMUL_PRECISION), v, acc) + acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION) # -- update m_i m_i = m_ij @@ -689,6 +696,7 @@ def flex_attention( mask_graph_placeholder_inps + list(mask_mod_other_buffers), mask_graph ) kernel_options = dict(kernel_options) + kernel_options.setdefault("FLOAT32_PRECISION", get_float32_precision()) if _use_flex_decoding(query, kernel_options): return create_flex_decoding_kernel( query, @@ -1270,7 +1278,7 @@ def bwd_dq_block_mn( kT = tl.load(kT_ptrs) else: kT = tl.load(kT_ptrs, mask=offs_n2[None, :] < KV_LEN) - qk = tl.dot(q, kT) + qk = tl.dot(q, kT, input_precision=FLOAT32_PRECISION) if not PRESCALE_QK: qk *= SM_SCALE # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ @@ -1320,7 +1328,7 @@ def bwd_dq_block_mn( vT = tl.load(vT_ptrs) else: vT = tl.load(vT_ptrs, mask=offs_n2[None, :] < KV_LEN) - dp = tl.dot(do, vT) + dp = tl.dot(do, vT, input_precision=FLOAT32_PRECISION) ds = p * (dp - Di[:, None]) # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ {{ modification( @@ -1346,7 +1354,7 @@ def bwd_dq_block_mn( # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ ds = ds.to(MATMUL_PRECISION) # Compute dQ. - dq += tl.dot(ds, tl.trans(kT)) + dq += tl.dot(ds, tl.trans(kT), input_precision=FLOAT32_PRECISION) return dq @@ -1454,7 +1462,7 @@ def bwd_dkdv_block_mn( qT = tl.load(qT_ptrs, mask=offs_m1[None, :] < Q_LEN) lse = tl.load(LSE + offs_m1, mask=offs_m1 < Q_LEN) lse = tl.where(lse == -float("inf"), 0.0, lse) - qkT = tl.dot(k, qT) + qkT = tl.dot(k, qT, input_precision=FLOAT32_PRECISION) if not PRESCALE_QK: qkT *= SM_SCALE # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ @@ -1504,13 +1512,13 @@ def bwd_dkdv_block_mn( do = tl.load(do_ptrs, mask=offs_m1[:, None] < Q_LEN) # Compute dV. ppT = pT - dv += tl.dot(ppT.to(MATMUL_PRECISION), do) + dv += tl.dot(ppT.to(MATMUL_PRECISION), do, input_precision=FLOAT32_PRECISION) if IS_DIVISIBLE: Di = tl.load(DELTA + offs_m1) else: Di = tl.load(DELTA + offs_m1, mask=offs_m1 < Q_LEN) # Compute dP and dS. - dpT = tl.dot(v, tl.trans(do)) + dpT = tl.dot(v, tl.trans(do), input_precision=FLOAT32_PRECISION) dsT = pT * (dpT - Di[None, :]) # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ {{ modification( @@ -1533,7 +1541,7 @@ def bwd_dkdv_block_mn( # (grads) apply mask for partially unmasked block dsT = tl.where(mask_mod_output, dsT, 0.0) # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - dk += tl.dot(dsT.to(MATMUL_PRECISION), tl.trans(qT)) + dk += tl.dot(dsT.to(MATMUL_PRECISION), tl.trans(qT), input_precision=FLOAT32_PRECISION) return dk, dv """ @@ -1614,6 +1622,7 @@ def flex_attention_backward(*args, **kwargs): B = Bq kernel_options = dict(kernel_options) + kernel_options.setdefault("FLOAT32_PRECISION", get_float32_precision()) if seq_len_q % 128 != 0 or seq_len_kv % 128 != 0: kernel_options.setdefault("IS_DIVISIBLE", False) else: From f5f6cdeefe71295ebf816f97364747889cf4ddd8 Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Fri, 6 Sep 2024 22:21:46 -0700 Subject: [PATCH 0604/1018] [Dynamo] Fix Huggingface PretrainedConfig get non const attr (#135413) Fixes #135329 Pull Request resolved: https://github.com/pytorch/pytorch/pull/135413 Approved by: https://github.com/anijain2305 --- test/dynamo/test_model_output.py | 15 +++++++++++++++ torch/_dynamo/variables/dicts.py | 10 ++++++++-- 2 files changed, 23 insertions(+), 2 deletions(-) diff --git a/test/dynamo/test_model_output.py b/test/dynamo/test_model_output.py index 4dfa93bdc14911..e9eb78254ce067 100644 --- a/test/dynamo/test_model_output.py +++ b/test/dynamo/test_model_output.py @@ -47,6 +47,21 @@ def fn(a, tmp): res = opt_fn(x, tmp) self.assertTrue(same(ref, res)) + @maybe_skip + def test_pretrained_non_const_attr(self): + def fn(a, tmp): + if tmp.pruned_heads: + return a + 1 + else: + return a - 1 + + x = torch.randn(2) + tmp = PretrainedConfig() + ref = fn(x, tmp) + opt_fn = torch.compile(backend="eager", fullgraph=True)(fn) + res = opt_fn(x, tmp) + self.assertTrue(same(ref, res)) + class TestModelOutput(torch._dynamo.test_case.TestCase): @maybe_skip diff --git a/torch/_dynamo/variables/dicts.py b/torch/_dynamo/variables/dicts.py index 10b55c1b7e96eb..e3ca0b43c0a6a7 100644 --- a/torch/_dynamo/variables/dicts.py +++ b/torch/_dynamo/variables/dicts.py @@ -952,9 +952,15 @@ def __init__(self, obj, **kwargs) -> None: assert self.is_matching_cls(type(obj)) def var_getattr(self, tx: "InstructionTranslator", name: str) -> "VariableTracker": - from . import ConstantVariable + from .builder import VariableBuilder + + try: + attr_value = getattr(self.obj, name) + attr_source = AttrSource(self.source, name) + return VariableBuilder(tx, attr_source)(attr_value) - return ConstantVariable.create(getattr(self.obj, name)) + except AttributeError: + unimplemented(f"getattr({self.value}, {name})") def call_hasattr(self, tx: "InstructionTranslator", name: str) -> "VariableTracker": return variables.ConstantVariable.create(hasattr(self.obj, name)) From 97a42884a437be043c3a792d8d4c0e01cf312413 Mon Sep 17 00:00:00 2001 From: Bob Ren Date: Sat, 7 Sep 2024 16:53:27 -0700 Subject: [PATCH 0605/1018] remove commented out breakpoints (#135363) Pull Request resolved: https://github.com/pytorch/pytorch/pull/135363 Approved by: https://github.com/oulgen --- test/dynamo/test_sources.py | 1 - test/torch_np/check_tests_conform.py | 1 - 2 files changed, 2 deletions(-) diff --git a/test/dynamo/test_sources.py b/test/dynamo/test_sources.py index 48646ac44c5ef3..0f2f7ded33fea5 100644 --- a/test/dynamo/test_sources.py +++ b/test/dynamo/test_sources.py @@ -72,7 +72,6 @@ def forward(self): lambda x, _: CausalLMOutputWithPast(), ) - # breakpoint() torch.export.export(Model(), ()) diff --git a/test/torch_np/check_tests_conform.py b/test/torch_np/check_tests_conform.py index 05ff5357b7c7c2..c1795bfdc0f8d9 100644 --- a/test/torch_np/check_tests_conform.py +++ b/test/torch_np/check_tests_conform.py @@ -43,7 +43,6 @@ def check(path): else: report_violation(line, num, "off-class parametrize") if not src[nn - 1].startswith("@instantiate_parametrized_tests"): - # breakpoint() report_violation( line, num, f"missing instantiation of parametrized tests in {ln}?" ) From e245b0a7aa0853621aec1e71aed98da51cc949bd Mon Sep 17 00:00:00 2001 From: Huamin Li Date: Sun, 8 Sep 2024 04:16:24 +0000 Subject: [PATCH 0606/1018] Change wrapped_linear_prepack and wrapped_quantized_linear_prepacked to private by adding _ as prefix (#135401) Summary: In https://github.com/pytorch/pytorch/pull/134232, we added two new ops wrapped_linear_prepack and wrapped_quantized_linear_prepacked. From the review comments and offline discussion, we are changing them to private by adding `_` as prefix Differential Revision: D62325142 Pull Request resolved: https://github.com/pytorch/pytorch/pull/135401 Approved by: https://github.com/houseroad --- aten/src/ATen/native/native_functions.yaml | 4 +-- .../native/quantized/cpu/qlinear_prepack.cpp | 32 +++++++++---------- aten/src/ATen/native/quantized/library.cpp | 4 +-- .../check_forward_backward_compatibility.py | 5 +++ test/quantization/core/test_quantized_op.py | 8 ++--- torch/_inductor/decomposition.py | 4 +-- torch/csrc/inductor/aoti_torch/c/shim.h | 4 +-- .../csrc/inductor/aoti_torch/shim_common.cpp | 12 +++---- torch/overrides.py | 4 +-- 9 files changed, 41 insertions(+), 36 deletions(-) diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index 683b38e5f20ab4..9342da626beb03 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -3400,9 +3400,9 @@ - func: fbgemm_pack_gemm_matrix_fp16(Tensor input) -> Tensor -- func: wrapped_linear_prepack(Tensor weight, Tensor weight_scale, Tensor weight_zero_point, Tensor bias) -> Tensor +- func: _wrapped_linear_prepack(Tensor weight, Tensor weight_scale, Tensor weight_zero_point, Tensor bias) -> Tensor -- func: wrapped_quantized_linear_prepacked(Tensor input, Tensor input_scale, Tensor input_zero_point, Tensor packed_weight, Tensor output_scale, Tensor output_zero_point, int out_channel) -> Tensor +- func: _wrapped_quantized_linear_prepacked(Tensor input, Tensor input_scale, Tensor input_zero_point, Tensor packed_weight, Tensor output_scale, Tensor output_zero_point, int out_channel) -> Tensor - func: fbgemm_linear_fp16_weight_fp32_activation(Tensor input, Tensor packed_weight, Tensor bias) -> Tensor diff --git a/aten/src/ATen/native/quantized/cpu/qlinear_prepack.cpp b/aten/src/ATen/native/quantized/cpu/qlinear_prepack.cpp index 28fcab13ff3e06..f4c55b2a3cfe4f 100644 --- a/aten/src/ATen/native/quantized/cpu/qlinear_prepack.cpp +++ b/aten/src/ATen/native/quantized/cpu/qlinear_prepack.cpp @@ -436,12 +436,12 @@ at::Tensor wrapped_quantized_linear_meta( #endif // USE_FBGEMM } -at::Tensor wrapped_linear_prepack(const at::Tensor& weight, +at::Tensor _wrapped_linear_prepack(const at::Tensor& weight, const at::Tensor& weight_scale, const at::Tensor& weight_zero_point, const at::Tensor& bias); -at::Tensor wrapped_linear_prepack(const at::Tensor& weight, +at::Tensor _wrapped_linear_prepack(const at::Tensor& weight, const at::Tensor& weight_scale, const at::Tensor& weight_zero_point, const at::Tensor& bias) { @@ -474,14 +474,14 @@ at::Tensor wrapped_linear_prepack(const at::Tensor& weight, #endif // USE_FBGEMM } -at::Tensor wrapped_quantized_linear_prepacked(const at::Tensor& input, const at::Tensor& input_scale, +at::Tensor _wrapped_quantized_linear_prepacked(const at::Tensor& input, const at::Tensor& input_scale, const at::Tensor& input_zero_point, const at::Tensor& packed_weight, const at::Tensor& output_scale, const at::Tensor& output_zero_point, [[maybe_unused]] const int64_t out_channel); -at::Tensor wrapped_quantized_linear_prepacked(const at::Tensor& input, const at::Tensor& input_scale, +at::Tensor _wrapped_quantized_linear_prepacked(const at::Tensor& input, const at::Tensor& input_scale, const at::Tensor& input_zero_point, const at::Tensor& packed_weight, const at::Tensor& output_scale, @@ -507,12 +507,12 @@ at::Tensor wrapped_quantized_linear_prepacked(const at::Tensor& input, const at: #endif // USE_FBGEMM } -at::Tensor wrapped_linear_prepack_meta(const at::Tensor& weight, +at::Tensor _wrapped_linear_prepack_meta(const at::Tensor& weight, [[maybe_unused]] const at::Tensor& weight_scale, [[maybe_unused]] const at::Tensor& weight_zero_point, [[maybe_unused]] const at::Tensor& bias); -at::Tensor wrapped_linear_prepack_meta(const at::Tensor& weight, +at::Tensor _wrapped_linear_prepack_meta(const at::Tensor& weight, [[maybe_unused]] const at::Tensor& weight_scale, [[maybe_unused]] const at::Tensor& weight_zero_point, [[maybe_unused]] const at::Tensor& bias) { @@ -530,7 +530,7 @@ at::Tensor wrapped_linear_prepack_meta(const at::Tensor& weight, #endif // USE_FBGEMM } -at::Tensor wrapped_quantized_linear_prepacked_meta(const at::Tensor& input, +at::Tensor _wrapped_quantized_linear_prepacked_meta(const at::Tensor& input, [[maybe_unused]] const at::Tensor& input_scale, [[maybe_unused]] const at::Tensor& input_zero_point, [[maybe_unused]] const at::Tensor& packed_weight, @@ -538,7 +538,7 @@ at::Tensor wrapped_quantized_linear_prepacked_meta(const at::Tensor& input, [[maybe_unused]] const at::Tensor& output_zero_point, const int64_t out_channel); -at::Tensor wrapped_quantized_linear_prepacked_meta(const at::Tensor& input, +at::Tensor _wrapped_quantized_linear_prepacked_meta(const at::Tensor& input, [[maybe_unused]] const at::Tensor& input_scale, [[maybe_unused]] const at::Tensor& input_zero_point, [[maybe_unused]] const at::Tensor& packed_weight, @@ -695,21 +695,21 @@ TORCH_LIBRARY_IMPL(_quantized, CPU, m) { m.impl(TORCH_SELECTIVE_NAME("_quantized::linear_prepack_fp16_legacy"), TORCH_FN(QLinearPackWeightFp16Legacy::run)); m.impl(TORCH_SELECTIVE_NAME("_quantized::wrapped_quantized_linear"), TORCH_FN(wrapped_quantized_linear)); m.impl( - TORCH_SELECTIVE_NAME("_quantized::wrapped_linear_prepack"), - wrapped_linear_prepack); + TORCH_SELECTIVE_NAME("_quantized::_wrapped_linear_prepack"), + _wrapped_linear_prepack); m.impl( - TORCH_SELECTIVE_NAME("_quantized::wrapped_quantized_linear_prepacked"), - wrapped_quantized_linear_prepacked); + TORCH_SELECTIVE_NAME("_quantized::_wrapped_quantized_linear_prepacked"), + _wrapped_quantized_linear_prepacked); } TORCH_LIBRARY_IMPL(_quantized, Meta, m) { m.impl(TORCH_SELECTIVE_NAME("_quantized::wrapped_quantized_linear"), TORCH_FN(wrapped_quantized_linear_meta)); m.impl( - TORCH_SELECTIVE_NAME("_quantized::wrapped_linear_prepack"), - wrapped_linear_prepack_meta); + TORCH_SELECTIVE_NAME("_quantized::_wrapped_linear_prepack"), + _wrapped_linear_prepack_meta); m.impl( - TORCH_SELECTIVE_NAME("_quantized::wrapped_quantized_linear_prepacked"), - wrapped_quantized_linear_prepacked_meta); + TORCH_SELECTIVE_NAME("_quantized::_wrapped_quantized_linear_prepacked"), + _wrapped_quantized_linear_prepacked_meta); } TORCH_LIBRARY_IMPL(onednn, CPU, m) { diff --git a/aten/src/ATen/native/quantized/library.cpp b/aten/src/ATen/native/quantized/library.cpp index be2dc9eb44f485..7780953a94dc3f 100644 --- a/aten/src/ATen/native/quantized/library.cpp +++ b/aten/src/ATen/native/quantized/library.cpp @@ -251,8 +251,8 @@ TORCH_LIBRARY(_quantized, m) { m.def(TORCH_SELECTIVE_SCHEMA("_quantized::wrapped_fbgemm_pack_gemm_matrix_fp16(Tensor W) -> Tensor")); m.def(TORCH_SELECTIVE_SCHEMA("_quantized::wrapped_fbgemm_linear_fp16_weight(Tensor X, Tensor W, Tensor B, int out_channel) -> Tensor")); m.def(TORCH_SELECTIVE_SCHEMA("_quantized::wrapped_quantized_linear(Tensor X, Tensor X_scale, Tensor X_zero_point, Tensor W, Tensor W_scale, Tensor W_zero_point, Tensor B, Tensor output_scale, Tensor output_zero_point, int out_channel) -> Tensor Y")); - m.def(TORCH_SELECTIVE_SCHEMA("_quantized::wrapped_linear_prepack(Tensor W, Tensor W_scale, Tensor W_zero_point, Tensor B) -> Tensor")); - m.def(TORCH_SELECTIVE_SCHEMA("_quantized::wrapped_quantized_linear_prepacked(Tensor X, Tensor X_scale, Tensor X_zero_point, Tensor W_prepack, Tensor output_scale, Tensor output_zero_point, int out_channel) -> Tensor Y")); + m.def(TORCH_SELECTIVE_SCHEMA("_quantized::_wrapped_linear_prepack(Tensor W, Tensor W_scale, Tensor W_zero_point, Tensor B) -> Tensor")); + m.def(TORCH_SELECTIVE_SCHEMA("_quantized::_wrapped_quantized_linear_prepacked(Tensor X, Tensor X_scale, Tensor X_zero_point, Tensor W_prepack, Tensor output_scale, Tensor output_zero_point, int out_channel) -> Tensor Y")); } TORCH_LIBRARY(onednn, m) { diff --git a/test/forward_backward_compatibility/check_forward_backward_compatibility.py b/test/forward_backward_compatibility/check_forward_backward_compatibility.py index 002fd7691b6cc4..8c438bc2e4fc72 100644 --- a/test/forward_backward_compatibility/check_forward_backward_compatibility.py +++ b/test/forward_backward_compatibility/check_forward_backward_compatibility.py @@ -145,6 +145,11 @@ ("onednn::qlinear_pointwise.binary_tensor", datetime.date(2024, 12, 31)), ("aten::_scaled_mm.out", datetime.date(2024, 12, 31)), ("aten::_scaled_mm", datetime.date(2024, 12, 31)), + ("aten::wrapped_quantized_linear_prepacked", datetime.date(2024, 12, 31)), + ("aten::wrapped_linear_prepack", datetime.date(2024, 12, 31)), + ("_quantized::wrapped_linear_prepack", datetime.date(2024, 12, 31)), + ("_quantized::wrapped_linear_prepacked", datetime.date(2024, 12, 31)), + ("_quantized::wrapped_quantized_linear_prepacked", datetime.date(2024, 12, 31)), # BC-breaking change in can_cast signature: 'from' -> 'from_' ("aten::can_cast", datetime.date(2024, 5, 31)), ] diff --git a/test/quantization/core/test_quantized_op.py b/test/quantization/core/test_quantized_op.py index b432f770a0abb7..999b63dbbb0a65 100644 --- a/test/quantization/core/test_quantized_op.py +++ b/test/quantization/core/test_quantized_op.py @@ -4223,8 +4223,8 @@ def test_wrapped_quantized_linear(self, m, n, k): ret_ref = qlinear.dequantize() self.assertEqual(ret, ret_ref) - """Tests the correctness of the _quantized::wrapped_linear_prepack and - _quantized::wrapped_quantized_linear_prepacked ops.""" + """Tests the correctness of the _quantized::_wrapped_linear_prepack and + _quantized::_wrapped_quantized_linear_prepacked ops.""" @skipIfNoFBGEMM @given( m=st.integers(2, 6), @@ -4243,13 +4243,13 @@ def test_wrapped_quantized_linear_prepacked(self, m, n, k): output_zero_point = torch.tensor(0) out_channel = n - ret_1 = torch.ops._quantized.wrapped_linear_prepack( + ret_1 = torch.ops._quantized._wrapped_linear_prepack( weight, weight_scale, weight_zero_point, bias ) - ret_2 = torch.ops._quantized.wrapped_quantized_linear_prepacked( + ret_2 = torch.ops._quantized._wrapped_quantized_linear_prepacked( input, input_scale, input_zero_point, diff --git a/torch/_inductor/decomposition.py b/torch/_inductor/decomposition.py index 105a1e96f3da59..0e067395f80712 100644 --- a/torch/_inductor/decomposition.py +++ b/torch/_inductor/decomposition.py @@ -651,10 +651,10 @@ def wrapped_quantized_linear( out_zero_point: torch.Tensor, out_channel: int, ) -> torch.Tensor: - packed_weight = torch.ops._quantized.wrapped_linear_prepack( + packed_weight = torch.ops._quantized._wrapped_linear_prepack( weight, weight_scale, weight_zero_point, bias ) - return torch.ops._quantized.wrapped_quantized_linear_prepacked( + return torch.ops._quantized._wrapped_quantized_linear_prepacked( input, input_scale, input_zero_point, diff --git a/torch/csrc/inductor/aoti_torch/c/shim.h b/torch/csrc/inductor/aoti_torch/c/shim.h index e6e1c6aa770890..49b16504e6ce2e 100644 --- a/torch/csrc/inductor/aoti_torch/c/shim.h +++ b/torch/csrc/inductor/aoti_torch/c/shim.h @@ -493,7 +493,7 @@ aoti_torch_cpu_wrapped_fbgemm_pack_gemm_matrix_fp16( // This will soon be deprecated after ao_quantization is complete. // Please refrain from using this or increasing callsites. -AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_wrapped_linear_prepack( +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__wrapped_linear_prepack( AtenTensorHandle weight, AtenTensorHandle weight_scale, AtenTensorHandle weight_zero_point, @@ -513,7 +513,7 @@ aoti_torch_cpu_wrapped_fbgemm_linear_fp16_weight( // This will soon be deprecated after ao_quantization is complete. // Please refrain from using this or increasing callsites. AOTI_TORCH_EXPORT AOTITorchError -aoti_torch_cpu_wrapped_quantized_linear_prepacked( +aoti_torch_cpu__wrapped_quantized_linear_prepacked( AtenTensorHandle input, AtenTensorHandle input_scale, AtenTensorHandle input_zero_point, diff --git a/torch/csrc/inductor/aoti_torch/shim_common.cpp b/torch/csrc/inductor/aoti_torch/shim_common.cpp index 4cf5dd3ae228a6..434da9ad7d9a32 100644 --- a/torch/csrc/inductor/aoti_torch/shim_common.cpp +++ b/torch/csrc/inductor/aoti_torch/shim_common.cpp @@ -26,6 +26,8 @@ #include #include #include +#include +#include #include #include #include @@ -42,8 +44,6 @@ #include #include #include -#include -#include #endif @@ -814,7 +814,7 @@ AOTITorchError aoti_torch_cpu_wrapped_fbgemm_pack_gemm_matrix_fp16( }); } -AOTITorchError aoti_torch_cpu_wrapped_linear_prepack( +AOTITorchError aoti_torch_cpu__wrapped_linear_prepack( AtenTensorHandle weight, AtenTensorHandle weight_scale, AtenTensorHandle weight_zero_point, @@ -828,7 +828,7 @@ AOTITorchError aoti_torch_cpu_wrapped_linear_prepack( tensor_handle_to_tensor_pointer(weight_zero_point); at::Tensor* bias_tensor = tensor_handle_to_tensor_pointer(bias); - *out = new_tensor_handle(at::wrapped_linear_prepack( + *out = new_tensor_handle(at::_wrapped_linear_prepack( *weight_tensor, *weight_scale_tensor, *weight_zero_point_tensor, @@ -852,7 +852,7 @@ AOTITorchError aoti_torch_cpu_wrapped_fbgemm_linear_fp16_weight( }); } -AOTITorchError aoti_torch_cpu_wrapped_quantized_linear_prepacked( +AOTITorchError aoti_torch_cpu__wrapped_quantized_linear_prepacked( AtenTensorHandle input, AtenTensorHandle input_scale, AtenTensorHandle input_zero_point, @@ -871,7 +871,7 @@ AOTITorchError aoti_torch_cpu_wrapped_quantized_linear_prepacked( at::Tensor* out_scale_tensor = tensor_handle_to_tensor_pointer(out_scale); at::Tensor* out_zeropoint_tensor = tensor_handle_to_tensor_pointer(out_zeropoint); - *out = new_tensor_handle(at::wrapped_quantized_linear_prepacked( + *out = new_tensor_handle(at::_wrapped_quantized_linear_prepacked( *input_tensor, *input_scale_tensor, *input_zero_point_tensor, diff --git a/torch/overrides.py b/torch/overrides.py index 91b383f32483a6..7a568d7e22c1cd 100644 --- a/torch/overrides.py +++ b/torch/overrides.py @@ -1252,8 +1252,8 @@ def get_testing_overrides() -> Dict[Callable, Callable]: torch.vsplit: lambda input, indices_or_sections: -1, torch.vstack: lambda tensors, out=None: -1, torch.where: lambda condition, x=None, y=None: -1, - torch.wrapped_linear_prepack: lambda weight, weight_scale, weight_zero_point, bias : -1, - torch.wrapped_quantized_linear_prepacked: ( + torch._wrapped_linear_prepack: lambda weight, weight_scale, weight_zero_point, bias : -1, + torch._wrapped_quantized_linear_prepacked: ( lambda input, input_scale, input_zero_point, prepacked, out_scale, out_zero_point, out_channel : -1 # noqa: B950 ), torch.zeros_like: lambda input, dtype=None, layout=None, device=None, requires_grad=False: -1, From c6d33a15852292933c49da1c5e87ef0917588873 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Sun, 8 Sep 2024 05:30:34 +0000 Subject: [PATCH 0607/1018] [ONNX] Re-raise the exception if the dynamic shapes cannot be refined (#135418) Improve error reporting. Otherwise users will just see not being able to refine shapes most of the time. Pull Request resolved: https://github.com/pytorch/pytorch/pull/135418 Approved by: https://github.com/titaiwangms --- .../_internal/exporter/_capture_strategies.py | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/torch/onnx/_internal/exporter/_capture_strategies.py b/torch/onnx/_internal/exporter/_capture_strategies.py index 8b908bc35e9346..4cec92854ea858 100644 --- a/torch/onnx/_internal/exporter/_capture_strategies.py +++ b/torch/onnx/_internal/exporter/_capture_strategies.py @@ -126,11 +126,13 @@ def _capture( ) except torch._dynamo.exc.UserError as exc: # Refine the dynamic shapes based on the suggested fixes. - new_shapes = ( - torch.export.dynamic_shapes.refine_dynamic_shapes_from_suggested_fixes( + try: + new_shapes = torch.export.dynamic_shapes.refine_dynamic_shapes_from_suggested_fixes( exc.msg, dynamic_shapes ) - ) + except Exception: + # If the dynamic shapes cannot be refined, re-raise the exception. + raise exc from None return torch.export.export( model, args, kwargs=kwargs, dynamic_shapes=new_shapes ) @@ -165,11 +167,13 @@ def _capture( ) except torch._dynamo.exc.UserError as exc: # Refine the dynamic shapes based on the suggested fixes. - new_shapes = ( - torch.export.dynamic_shapes.refine_dynamic_shapes_from_suggested_fixes( + try: + new_shapes = torch.export.dynamic_shapes.refine_dynamic_shapes_from_suggested_fixes( exc.msg, dynamic_shapes ) - ) + except Exception: + # If the dynamic shapes cannot be refined, re-raise the exception. + raise exc from None return torch.export.export( model, args, kwargs=kwargs, dynamic_shapes=new_shapes, strict=False ) From ca82d634fa65b461ce8afb96939695b390fd4145 Mon Sep 17 00:00:00 2001 From: CaoE Date: Fri, 6 Sep 2024 21:42:54 -0700 Subject: [PATCH 0608/1018] Use BRGEMM for Half flash attention forward kernel (#131879) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Use oneDNN BRGEMM on packed data to get better performance on the 5th generation of Xeon where Intel® Advanced Matrix Extensions (AMX) will have fp16 support, e.g. amx-fp16. Multiple models have achieved acceleration, for instance, FP16 stable diffusion v2.1 has achieved over 50% improvement. Pull Request resolved: https://github.com/pytorch/pytorch/pull/131879 Approved by: https://github.com/jgong5, https://github.com/peterbell10 ghstack dependencies: #131878 --- aten/src/ATen/native/CPUBlas.cpp | 16 - aten/src/ATen/native/CPUBlas.h | 18 - .../ATen/native/cpu/FlashAttentionKernel.cpp | 395 ++++++++++++++++-- aten/src/ATen/native/cpu/utils.h | 6 + test/test_transformers.py | 4 +- 5 files changed, 371 insertions(+), 68 deletions(-) diff --git a/aten/src/ATen/native/CPUBlas.cpp b/aten/src/ATen/native/CPUBlas.cpp index 689d3bacd5a571..0a2789f4caff8b 100644 --- a/aten/src/ATen/native/CPUBlas.cpp +++ b/aten/src/ATen/native/CPUBlas.cpp @@ -1149,22 +1149,6 @@ void brgemm( "Half Brgemm is only supported on X64 when oneDNN ukernel is enabled and avx512_fp16 is supported"); } -void brgemm( - int64_t M, - int64_t N, - int64_t K, - int64_t ld_a, - int64_t ld_b, - int64_t ld_c, - const float alpha, - const float beta, - const at::BFloat16* A, - const at::BFloat16* B, - float* C) { - TORCH_CHECK(false, - "BFloat16 Brgemm is currently not supported"); -} - void brgemm_release() { #if ONEDNN_UKERNEL_ENABLED && (defined(__x86_64__) || (defined(_M_X64) && !defined(_M_ARM64EC))) dnnl::ukernel::brgemm::release_hw_context(); diff --git a/aten/src/ATen/native/CPUBlas.h b/aten/src/ATen/native/CPUBlas.h index 01ec49e4b54e00..ad209329c95ecc 100644 --- a/aten/src/ATen/native/CPUBlas.h +++ b/aten/src/ATen/native/CPUBlas.h @@ -192,12 +192,7 @@ void copy(int64_t n, const c10::complex *x, int64_t incx, c10::complex #endif - namespace at::native { namespace { @@ -202,7 +201,97 @@ void reshape_attn_mask_to_4d( .expand({attn_mask_size_0, attn_mask_size_1, qSize, kvSize}); } -template +template +inline void copy_value_with_pad( + const scalar_t* value_ptr, + scalar_t* dst_ptr, + int64_t rows, + int64_t cols, + int64_t prows, + int64_t pcols, + int64_t ldi) { + auto vec_size = at::vec::Vectorized::size(); + int64_t i = 0; + for (; i < rows; i++) { + int64_t j = 0; + for (; j < cols - (cols % vec_size); j += vec_size) { + auto vec_v = + at::vec::Vectorized::loadu(value_ptr + i * ldi + j); + vec_v.store(dst_ptr + i * pcols + j); + } + + if (j < cols) { + auto vec_v = at::vec::Vectorized::loadu( + value_ptr + i * ldi + j, cols - j); + vec_v.store(dst_ptr + i * pcols + j, cols - j); + } + + // col padding + auto psize = pcols - cols; + if (psize > 0) { + auto zero_vec = at::vec::Vectorized(0); + int64_t pj = 0; + for (; pj < psize - (psize % vec_size); pj += vec_size) { + zero_vec.store(dst_ptr + i * pcols + cols + pj); + } + if (pj < psize) { + zero_vec.store(dst_ptr + i * pcols + cols + pj, psize - pj); + } + } + } + // row padding + for (; i < prows; i++) { + auto zero_vec = at::vec::Vectorized(0); + int64_t j = 0; + for (; j < pcols - (pcols % vec_size); j += vec_size) { + zero_vec.store(dst_ptr + i * pcols + j); + } + if (j < pcols) { + zero_vec.store(dst_ptr + i * pcols + j, pcols - j); + } + + } +} + +template +inline void pad_remain_row_col_zero( + scalar_t* value_ptr, + int rows, + int cols, + int prows, + int pcols, + int ldi) { + auto psize = pcols - cols; + if (psize == 0 && prows == rows) { + return; + } + auto vec_size = at::vec::Vectorized::size(); + auto zero = at::vec::Vectorized(0); + if (psize > 0) { + for (int i = 0; i < rows; i++) { + int j = 0; + for (; j < psize - (psize % vec_size); j += vec_size) { + zero.store(value_ptr + i * ldi + cols + j); + } + if (j < psize) { + zero.store(value_ptr + i * ldi + cols + j, psize - j); + } + } + } + + for (int i = rows; i < prows; i++) { + int j = 0; + for (; j < pcols - (pcols % vec_size); j += vec_size) { + zero.store(value_ptr + i * ldi + j); + } + if (j < pcols) { + zero.store(value_ptr + i * ldi + j, pcols - j); + } + } + +} + +template void cpu_flash_attention( const Tensor& output, const Tensor& logsumexp, @@ -278,21 +367,70 @@ void cpu_flash_attention( int64_t qSplitSize = q_split_size > qSize ? qSize : q_split_size; int64_t kvSplitSize = kv_split_size > kvSize ? kvSize : kv_split_size; - int64_t qSlice = (qSize - 1) / qSplitSize + 1; + int64_t qSlice = (qSize + qSplitSize - 1) / qSplitSize; + int64_t kvSlice = (kvSize + kvSplitSize - 1) / kvSplitSize; + int64_t kvTail = (kvSize - 1) % kvSplitSize + 1; int64_t num_thread = at::get_num_threads(); const auto dtype = query.scalar_type(); const auto accumulate_dtype = toOpMathType(dtype); + // Whether pack is needed + bool need_pack = false; + // Block size of packing B matrix + int64_t packb_size = 64; + // Use packb_size due to the limitation: + // oneDNN pack only supports output leading dimention being one of (16, 32, 48, 64) + // For instance, + // for q @ k.T [qSplitSize, headSize] * [headSize, kvSplitSize] = [qSplitSize, kvSplitSize], + // we need to split kvSplitSize with packb_size for packing k.T, + // for (q @ k.T) @ v [qSplitSize, kvSplitSize] x [kvSplitSize, headSize] -> [qSplitSize, headSize], + // we need to split headSize with packb_size for packing v + // TODO Simplify the check when oneDNN supports fused pack with transpose and has better performance + if (with_pack) { + need_pack = num_head >= 4 && headSize % packb_size == 0 && kvSize >= packb_size; + if (need_pack) { + float pack_size = batchSize * num_head * kvSize * headSize / 1024; + float gemm_size_per_thread = + (batchSize * num_head * qSlice + num_thread - 1) / num_thread * + qSplitSize * (is_causal ? qSize : kvSize) * headSize / 1024; + float gsize = gemm_size_per_thread / pack_size; + // When the number of gemm is much greater than the number of pack, + // the pack and padding overhead can be overlaped. + if (pack_size < 2688) { + need_pack = gsize >= 36 || (gsize >= 24 && headSize > packb_size); + } else if (pack_size < 16384) { + need_pack = gsize >= (is_causal ? 54 : 52); + } else { + need_pack = gsize >= (is_causal ? 54 : 40); + } + } + } + + int64_t rHeadSize = need_pack ? (headSize + packb_size - 1) / packb_size * packb_size : headSize; + int64_t rkvSplitSize = need_pack ? (kvSplitSize + packb_size - 1) / packb_size * packb_size : kvSplitSize; + int64_t rkvTail = need_pack ? (kvTail + packb_size - 1) / packb_size * packb_size : kvTail; + int64_t rkvSize = kv_split_size > kvSize ? rkvTail : rkvSplitSize * kvSlice + rkvTail; + + // oneDNN pack does not support odd K now, we need also pad odd K + bool headSize_even = headSize % 2 == 0; + int64_t eheadSize = need_pack && !headSize_even ? headSize + 1: headSize; + int64_t ekvSplitSize = need_pack && (kvSplitSize % 2 != 0) ? kvSplitSize + 1 : kvSplitSize; + int64_t ekvTail = need_pack && (kvTail % 2 != 0) ? kvTail + 1 : kvTail; + // allocate per thread temp buf (accumulate type) int64_t size_per_thread = - /* qk */ qSplitSize * kvSplitSize + + /* qk */ qSplitSize * rkvSplitSize + /* qk_max */ qSplitSize + /* qk_sum */ qSplitSize + - /* dst */ qSplitSize * headSize; + /* dst */ qSplitSize * rHeadSize; at::Tensor buf = at::empty({num_thread, size_per_thread}, query.options().dtype(accumulate_dtype)); - at::Tensor buf_reduced = at::empty({num_thread, qSplitSize, is_reduced_type ? kvSplitSize : 0}, query.options()); + at::Tensor buf_reduced = at::empty( + {num_thread, + qSplitSize, + is_reduced_type ? ekvSplitSize : 0}, + query.options()); // Data ptrs const scalar_t* q_data = query.const_data_ptr(); @@ -306,16 +444,128 @@ void cpu_flash_attention( accum_t* buf_data = buf.data_ptr(); scalar_t* buf_reduced_data = is_reduced_type ? buf_reduced.data_ptr() : nullptr; + // Buffer to store padding query + scalar_t* query_padding_ptr = nullptr; + std::unique_ptr query_padding_data; + if (!headSize_even && need_pack) { + query_padding_data = std::make_unique(num_thread * qSplitSize * eheadSize); + query_padding_ptr = query_padding_data.get(); + } + // Buffer to store Key and Value after transforms + scalar_t* key_reorder_ptr = nullptr; + std::unique_ptr key_reorder_data; + scalar_t* value_reorder_ptr = nullptr; + std::unique_ptr value_reorder_data; + int kv_padding_size = (kvSize - 1) / kvSplitSize * ekvSplitSize + ekvTail; + if (need_pack) { + key_reorder_data = std::make_unique(batchSize * num_head * eheadSize * rkvSize); + key_reorder_ptr = key_reorder_data.get(); + value_reorder_data = std::make_unique(batchSize * num_head * kv_padding_size * rHeadSize); + value_reorder_ptr = value_reorder_data.get(); + } + + // Reorder K, V + if (need_pack) { + at::parallel_for(0, batchSize * num_head * kvSlice, 1, [&](int64_t begin, int64_t end) { + int64_t i = 0, j = 0, l = 0, n = 0; + at::native::data_index_init(begin, i, batchSize, j, num_head, l, kvSlice); + std::unique_ptr transpose_buffer = std::make_unique(eheadSize * packb_size); + scalar_t* transpose_buffer_ptr = transpose_buffer.get(); + std::unique_ptr v_copy_buffer = std::make_unique(ekvSplitSize * packb_size); + scalar_t* v_copy_buffer_ptr = v_copy_buffer.get(); + for (const auto z : c10::irange(begin, end)) { + (void)z; // Suppress unused variable + n = l * kvSplitSize; + int64_t kvBlockSize = std::min(kvSplitSize, kvSize - n); + int64_t ekvBlockSize = kvBlockSize % 2 == 0 ? kvBlockSize : kvBlockSize + 1; + + // Split kvSplitSize with packb_size + // [kvSplitSize, headSize] -> [div_up(kvSplitSize, packb_size), packb_size, headSize] + // Transpose [packb_size, headSize] -> [headSize, packb_size] + // Pack transposed buffer + + for (int64_t b = 0; b < kvBlockSize; b += packb_size) { + bool tail = kvBlockSize - b < packb_size; + // TODO Use fused pack with transpose support when oneDNN supports such usage + utils::transpose( + tail ? kvBlockSize - b : packb_size, + headSize, + /* src_ptr */ + reinterpret_cast( + k_data + i * kStrideB + j * kStrideH + n * kStrideN + + b * kStrideN), + /* ld_src */ kStrideN, + /* dst */ reinterpret_cast(transpose_buffer_ptr), + /* ld_dst */ packb_size); + // Pad [headSize, x] -> [eheadSize, x] + if (!headSize_even) { + pad_remain_row_col_zero( + transpose_buffer_ptr, + headSize, + packb_size, + eheadSize, + packb_size, + packb_size); + } + // Pack + cpublas::pack( + /* K */ eheadSize, + /* N */ packb_size, + /* ld_in */ packb_size, + /* ld_out */ packb_size, + /* dt_in */ dtype, + /* dt_out */ dtype, + transpose_buffer_ptr, + key_reorder_ptr + i * num_head * eheadSize * rkvSize + + j * eheadSize * rkvSize + n * eheadSize + b * eheadSize); + } + + // Split headSize with packb_size + // [kvSplitSize, headSize] -> [kvSplitSize, div_up(headSize, packb_size), packb_size] + for (int64_t b = 0; b < headSize; b += packb_size) { + // Do copy due to the limitation of input_ld of oneDNN pack: + // Regarding packing [K, N], only input_ld == N is supported + // TODO: remove the copy when pack supports input_ld >= N + copy_value_with_pad( + v_data + i * vStrideB + j * vStrideH + n * vStrideN + b, + v_copy_buffer_ptr, + kvBlockSize, + (headSize - b < packb_size) ? headSize - b : packb_size, + ekvBlockSize, + packb_size, + vStrideN); + cpublas::pack( + ekvBlockSize, + packb_size, + packb_size, + packb_size, + dtype, + dtype, + v_copy_buffer_ptr, + value_reorder_ptr + + i * num_head * kv_padding_size * rHeadSize + + j * kv_padding_size * rHeadSize + n * rHeadSize + + ekvBlockSize * b); + } + // Move to the next query + at::native::data_index_step(i, batchSize, j, num_head, l, kvSlice); + } + }); + } + at::parallel_for(0, batchSize * num_head * qSlice, 1, [&](int64_t begin, int64_t end) { int64_t i = 0, j = 0, k = 0; data_index_init(begin, i, batchSize, j, num_head, k, qSlice); int ompIdx = at::get_thread_num(); accum_t* buf_ptr = buf_data + ompIdx * size_per_thread; accum_t* qk_data = buf_ptr; - accum_t* qk_max_data = qk_data + qSplitSize * kvSplitSize; + accum_t* qk_max_data = qk_data + qSplitSize * rkvSplitSize; accum_t* qk_sum_data = qk_max_data + qSplitSize; accum_t* dst_data = qk_sum_data + qSplitSize; - scalar_t* qk_reduced_data = is_reduced_type ? buf_reduced_data + ompIdx * qSplitSize * kvSplitSize : nullptr; + scalar_t* qk_reduced_data = is_reduced_type ? buf_reduced_data + ompIdx * qSplitSize * ekvSplitSize : nullptr; + scalar_t* query_t_padding_ptr = (!headSize_even && need_pack) + ? query_padding_ptr + ompIdx * qSplitSize * eheadSize + : nullptr; for (const auto z : c10::irange(begin, end)) { (void)z; // Suppress unused variable @@ -327,10 +577,46 @@ void cpu_flash_attention( fill_stub(qk_sum_data, static_cast(0), qBlockSize); int64_t num_keys = is_causal ? std::min(m + qBlockSize, kvSize) : kvSize; + if (!headSize_even && need_pack) { + // Pad query if headSize is not even + // [qBlockSize, headSize] -> [qBlockSize, eheadSize] + copy_value_with_pad( + q_data + i * qStrideB + j * qStrideH + m * qStrideM, + query_t_padding_ptr, + qBlockSize, + headSize, + qBlockSize, + eheadSize, + qStrideM + ); + } for (int64_t n = 0; n < num_keys; n += kvSplitSize) { int64_t kvBlockSize = std::min(kvSplitSize, kvSize - n); + int64_t ekvBlockSize = (need_pack && kvBlockSize % 2 != 0) ? kvBlockSize + 1 : kvBlockSize; + int64_t rkvBlockSize = kvBlockSize == kvSplitSize ? rkvSplitSize : rkvTail; // Calculate scale * q @ k.T - cpublas::gemm( + if (need_pack) { + if constexpr (std::is_same_v) { + for (int64_t b = 0; b < kvBlockSize; b += packb_size) { + cpublas::brgemm( + qBlockSize, + packb_size, + eheadSize, + headSize_even ? qStrideM : eheadSize, + packb_size, + rkvBlockSize, + 1.f, + 0.f, + !headSize_even + ? query_t_padding_ptr + : q_data + i * qStrideB + j * qStrideH + m * qStrideM, + key_reorder_ptr + i * num_head * eheadSize * rkvSize + + j * eheadSize * rkvSize + n * eheadSize + b * eheadSize, + qk_data + b); + } + } + } else { + cpublas::gemm( TransposeType::Transpose, TransposeType::NoTranspose, kvBlockSize, @@ -346,11 +632,12 @@ void cpu_flash_attention( static_cast(0), qk_data, kvBlockSize); + } // Apply causal mask, fill unused with -inf if (is_causal && num_keys - n <= kvSplitSize) { for (const auto row : c10::irange(qBlockSize)) { int64_t last_col = m + row - n; - accum_t* row_ptr = qk_data + row * kvBlockSize; + accum_t* row_ptr = qk_data + row * rkvBlockSize; fill_stub(row_ptr + last_col + 1, -std::numeric_limits::infinity(), kvBlockSize - last_col - 1); @@ -363,29 +650,29 @@ void cpu_flash_attention( for (int64_t row = 0; row < qBlockSize; ++row) { #if __GNUC__ == 11 && __GNUC_MINOR__ >= 4 && defined(__ARM_FEATURE_SVE) _scale_attn_mask_fusion_kernel( - qk_data + row * kvBlockSize, + qk_data + row * rkvBlockSize, mask_data + i * mStrideB + j * mStrideH + (m + row) * mStrideM + (mStrideN == 0 ? 0 : n), kvBlockSize, - qk_data + row * kvBlockSize, + qk_data + row * rkvBlockSize, scaling_factor, mStrideN == 0); #else if (mStrideN == 0) { _scale_attn_mask_fusion_kernel( - qk_data + row * kvBlockSize, + qk_data + row * rkvBlockSize, mask_data + i * mStrideB + j * mStrideH + (m + row) * mStrideM, kvBlockSize, - qk_data + row * kvBlockSize, + qk_data + row * rkvBlockSize, scaling_factor); } else { _scale_attn_mask_fusion_kernel( - qk_data + row * kvBlockSize, + qk_data + row * rkvBlockSize, mask_data + i * mStrideB + j * mStrideH + (m + row) * mStrideM + n, kvBlockSize, - qk_data + row * kvBlockSize, + qk_data + row * rkvBlockSize, scaling_factor); } #endif @@ -398,28 +685,28 @@ void cpu_flash_attention( // max per row tmp_max = at::vec::reduce_all( [](Vec& x, Vec& y) { return at::vec::maximum(x, y); }, - qk_data + row * kvBlockSize, + qk_data + row * rkvBlockSize, kvBlockSize); } else { // apply scaling factor and max per row in fusion _mul_reduce_max_fusion_kernel( - qk_data + row * kvBlockSize, + qk_data + row * rkvBlockSize, scaling_factor, kvBlockSize, - qk_data + row * kvBlockSize, + qk_data + row * rkvBlockSize, tmp_max); } tmp_max = qk_max_data[row] > tmp_max ? qk_max_data[row] : tmp_max; if (tmp_max == -std::numeric_limits::infinity()) { // to avoid `nan = exp2f(-inf - (-inf))` - fill_stub(conditional_data_ptr(qk_data, qk_reduced_data) + row * kvBlockSize, + fill_stub(conditional_data_ptr(qk_data, qk_reduced_data) + row * ekvBlockSize, static_cast(0), kvBlockSize); } else { tmp_sum = tmp_max; // qk <- exp(qk - max) and sum per row _exp_reduce_sum_fusion_kernel( - qk_data + row * kvBlockSize, kvBlockSize, - conditional_data_ptr(qk_data, qk_reduced_data) + row * kvBlockSize, + qk_data + row * rkvBlockSize, kvBlockSize, + conditional_data_ptr(qk_data, qk_reduced_data) + row * ekvBlockSize, tmp_sum); // exp_tmp <- exp(max[row] - max) exp_tmp = std::exp(qk_max_data[row] - tmp_max); @@ -431,12 +718,40 @@ void cpu_flash_attention( if (n > 0) { vec::map( [exp_tmp](Vec x) { return x * Vec(exp_tmp); }, - dst_data + row * headSize, dst_data + row * headSize, headSize); + dst_data + row * rHeadSize, + dst_data + row * rHeadSize, + headSize); } } + if (need_pack && kvBlockSize % 2 != 0) { + // Pad: [qSplitSize,kvSplitSize] -> [qSplitSize,kvSplitSize + 1] + *(qk_reduced_data + row * (1 + kvBlockSize) + kvBlockSize) = scalar_t(0); + } } // Calculate Softmax(q @ k.T) @ v - cpublas::gemm( + if (need_pack) { + int64_t psize = n / kvSplitSize * ekvSplitSize; + if constexpr (std::is_same_v) { + for (int64_t b = 0; b < headSize; b += packb_size) { + cpublas::brgemm( + qBlockSize, + packb_size, + ekvBlockSize, + ekvBlockSize, + packb_size, + rHeadSize, + 1.0, + n == 0 ? 0.f : 1.f, + qk_reduced_data, + value_reorder_ptr + + i * num_head * kv_padding_size * rHeadSize + + j * kv_padding_size * rHeadSize + psize * rHeadSize + + b * ekvBlockSize, + dst_data + b); + } + } + } else { + cpublas::gemm( TransposeType::NoTranspose, TransposeType::NoTranspose, headSize, @@ -451,6 +766,7 @@ void cpu_flash_attention( n == 0 ? static_cast(0) : static_cast(1), dst_data, headSize); + } } // dst <- dst / sum[row] @@ -465,7 +781,7 @@ void cpu_flash_attention( vec::map( [sum_reciprocal](Vec x) { return x * Vec(sum_reciprocal); }, out_data + i * oStrideB + j * oStrideH + m * oStrideM + row * oStrideM, - dst_data + row * headSize, + dst_data + row * rHeadSize, headSize); } // Store logsumexp for backward @@ -478,7 +794,9 @@ void cpu_flash_attention( data_index_step(i, batchSize, j, num_head, k, qSlice); } }); - + if (need_pack) { + cpublas::brgemm_release(); + } } template @@ -826,6 +1144,13 @@ void cpu_flash_attention_backward( AT_PRIVATE_CASE_TYPE_USING_HINT( \ at::ScalarType::Half, mask_t, __VA_ARGS__)) +#define FLASH_ATTENTION_KERNEL(FNAME, PACK, TYPE1, TYPE2, SEQ1, SEQ2, ...) \ + if (PACK) { \ + FNAME(__VA_ARGS__); \ + } else { \ + FNAME(__VA_ARGS__); \ + } + void flash_attention_kernel_impl( const Tensor& output, const Tensor& logsumexp, @@ -838,33 +1163,37 @@ void flash_attention_kernel_impl( std::optional scale) { auto q_seq_len = query.size(2); + // When q_seq_len and k_seq_len are long enough, + // cpu_flash_attention with pack has better performance. + bool could_pack = (query.scalar_type() == kHalf && cpublas::need_pack(kHalf)); + AT_DISPATCH_FLOATING_TYPES_AND2(kBFloat16, kHalf, query.scalar_type(), "flash_attention", [&] { if (!attn_mask.has_value()) { if (q_seq_len >= 768) { - cpu_flash_attention( + FLASH_ATTENTION_KERNEL(cpu_flash_attention, could_pack, scalar_t, scalar_t, 256, 512, output, logsumexp, query, key, value, dropout_p, is_causal, attn_mask, scale); } else if (q_seq_len >= 192) { - cpu_flash_attention( + FLASH_ATTENTION_KERNEL(cpu_flash_attention, could_pack, scalar_t, scalar_t, 64, 512, output, logsumexp, query, key, value, dropout_p, is_causal, attn_mask, scale); } else { - cpu_flash_attention( + FLASH_ATTENTION_KERNEL(cpu_flash_attention, could_pack, scalar_t, scalar_t, 32, 512, output, logsumexp, query, key, value, dropout_p, is_causal, attn_mask, scale); } } else { AT_DISPATCH_MASK_TYPES(attn_mask.value().scalar_type(), "flash_attention_mask", [&]() { if (q_seq_len >= 768) { - cpu_flash_attention( + FLASH_ATTENTION_KERNEL(cpu_flash_attention, could_pack, scalar_t, mask_t, 256, 512, output, logsumexp, query, key, value, dropout_p, is_causal, attn_mask, scale); } else if (q_seq_len >= 192) { - cpu_flash_attention( + FLASH_ATTENTION_KERNEL(cpu_flash_attention, could_pack, scalar_t, mask_t, 64, 512, output, logsumexp, query, key, value, dropout_p, is_causal, attn_mask, scale); } else { - cpu_flash_attention( + FLASH_ATTENTION_KERNEL(cpu_flash_attention, could_pack, scalar_t, mask_t, 32, 512, output, logsumexp, query, key, value, dropout_p, is_causal, attn_mask, scale); } @@ -873,6 +1202,8 @@ void flash_attention_kernel_impl( }); } +#undef FLASH_ATTENTION_KERNEL + void flash_attention_backward_kernel_impl( const at::Tensor& grad_q, const at::Tensor& grad_k, diff --git a/aten/src/ATen/native/cpu/utils.h b/aten/src/ATen/native/cpu/utils.h index e74621bef67506..a558c1bf13139a 100644 --- a/aten/src/ATen/native/cpu/utils.h +++ b/aten/src/ATen/native/cpu/utils.h @@ -159,6 +159,12 @@ inline void transpose(int64_t M, int64_t N, const float* src, int64_t ld_ TORCH_CHECK(fbgemm::fbgemmSupportedCPU(), "Your CPU does not support FBGEMM."); fbgemm::transpose_simd(M, N, src, ld_src, dst, ld_dst); } + +template <> +inline void transpose(int64_t M, int64_t N, const uint16_t* src, int64_t ld_src, uint16_t* dst, int64_t ld_dst) { + TORCH_CHECK(fbgemm::fbgemmSupportedCPU(), "Your CPU does not support FBGEMM."); + fbgemm::transpose_simd(M, N, src, ld_src, dst, ld_dst); +} #endif template diff --git a/test/test_transformers.py b/test/test_transformers.py index 5ccce245bf965e..ae70b7dc09e9d8 100644 --- a/test/test_transformers.py +++ b/test/test_transformers.py @@ -1981,8 +1981,8 @@ def test_fused_sdp_choice_cpu(self, device, type: str, dropout: float, dtype: to @parametrize("fused_kernel", [SDPBackend.FLASH_ATTENTION]) @parametrize("dtype", [torch.float64, torch.float32, torch.bfloat16, torch.float16]) @parametrize("batch_size", [2, 12]) - @parametrize("q_seq_len", [514, 1030]) - @parametrize("kv_seq_len", [514]) + @parametrize("q_seq_len", [11, 514, 1030]) + @parametrize("kv_seq_len", [17, 514]) @parametrize("n_head", [1, 3]) @parametrize("head_dim", [8]) @parametrize("mask_dim", [2, 4]) From ecdc28108a1d9f7b94eb3c2d26e149bc6752e8b3 Mon Sep 17 00:00:00 2001 From: Animesh Jain Date: Thu, 5 Sep 2024 16:27:22 -0700 Subject: [PATCH 0609/1018] [dynamo] Remove skip from jit freeze tests (#135281) Fixes https://github.com/pytorch/pytorch/issues/119781 Pull Request resolved: https://github.com/pytorch/pytorch/pull/135281 Approved by: https://github.com/zou3519 --- test/jit/test_freezing.py | 1 - 1 file changed, 1 deletion(-) diff --git a/test/jit/test_freezing.py b/test/jit/test_freezing.py index ac3d8e40b05529..9cbbacb087103f 100644 --- a/test/jit/test_freezing.py +++ b/test/jit/test_freezing.py @@ -46,7 +46,6 @@ def removeExceptions(graph): n.destroy() -@skipIfTorchDynamo("somehow causing hanging during python shutdown") class TestFreezing(JitTestCase): def test_freeze_module(self): class M(nn.Module): From 017e3b5ca4357a8b9ab3c91cb3249ed327b2d29e Mon Sep 17 00:00:00 2001 From: Wanchao Liang Date: Sun, 8 Sep 2024 17:08:40 +0000 Subject: [PATCH 0610/1018] [reland][dtensor] move DTensor to public namespace (#134203) reland of https://github.com/pytorch/pytorch/pull/133113 I have to create a new PR because the previous reverted PR could not either be rebased, or imported successfully :( ---- Moving DTensor to be in the public namespace, to formally add the documentation page that includes all the public APIs. This includes: * many path renames and path import fixes * a dedicated doc page without too much content yet (adding in the next PRs) * To preserve the BC for users still using the torch.distributed._tensor, I added a shim script to redirect old path calls to the new module The BC preserving is evidented by the fact that all DTensor tests are still working without changing the public imports. So it's safe to land the changes Pull Request resolved: https://github.com/pytorch/pytorch/pull/134203 Approved by: https://github.com/tianyu-l --- docs/source/distributed.rst | 1 - docs/source/distributed.tensor.rst | 191 +++ docs/source/index.rst | 3 +- test/allowlist_for_publicAPI.json | 3 +- .../fsdp/test_fully_shard_clip_grad_norm_.py | 2 +- .../_composable/fsdp/test_fully_shard_comm.py | 2 +- .../_composable/fsdp/test_fully_shard_init.py | 2 +- .../fsdp/test_fully_shard_training.py | 2 +- .../test_2d_composability.py | 2 +- .../_tensor/debug/test_comm_mode.py | 2 +- .../_tensor/debug/test_comm_mode_features.py | 2 +- .../_tensor/debug/test_op_coverage.py | 2 +- .../_tensor/experimental/test_local_map.py | 2 +- .../_tensor/experimental/test_tp_transform.py | 2 +- test/distributed/_tensor/test_attention.py | 4 +- test/distributed/_tensor/test_common_rules.py | 4 +- test/distributed/_tensor/test_dtensor.py | 4 +- .../distributed/_tensor/test_embedding_ops.py | 6 +- test/distributed/_tensor/test_math_ops.py | 4 +- test/distributed/_tensor/test_matrix_ops.py | 2 +- test/distributed/_tensor/test_op_strategy.py | 18 +- test/distributed/_tensor/test_random_ops.py | 4 +- test/distributed/_tensor/test_redistribute.py | 6 +- test/distributed/_tensor/test_tensor_ops.py | 6 +- test/distributed/_tensor/test_utils.py | 11 +- test/distributed/_tensor/test_view_ops.py | 6 +- .../checkpoint/test_state_dict_utils.py | 3 +- .../fsdp/test_fsdp_tp_integration.py | 2 +- .../tensor/parallel/test_tp_examples.py | 4 +- .../tensor/parallel/test_tp_random_state.py | 2 +- .../tensor/parallel/test_tp_style.py | 4 +- test/distributed/test_device_mesh.py | 12 +- torch/_dynamo/guards.py | 4 +- torch/_dynamo/trace_rules.py | 6 +- torch/_dynamo/variables/distributed.py | 6 +- torch/_dynamo/variables/torch.py | 2 +- .../_composable/fsdp/_fsdp_collectives.py | 2 +- .../_composable/fsdp/_fsdp_common.py | 4 +- .../_composable/fsdp/_fsdp_init.py | 2 +- .../_composable/fsdp/_fsdp_param.py | 12 +- .../_composable/fsdp/fully_shard.py | 2 +- torch/distributed/_functional_collectives.py | 2 +- torch/distributed/_state_dict_utils.py | 2 +- torch/distributed/_tensor/__init__.py | 78 +- torch/distributed/_tensor/api.py | 1238 +---------------- .../_tensor/experimental/__init__.py | 19 - torch/distributed/_tensor/placement_types.py | 916 +----------- torch/distributed/checkpoint/_traverse.py | 2 +- .../distributed/checkpoint/default_planner.py | 2 +- .../examples/async_checkpointing_example.py | 2 +- .../checkpoint/examples/stateful_example.py | 2 +- torch/distributed/checkpoint/optimizer.py | 2 +- .../distributed/checkpoint/planner_helpers.py | 4 +- torch/distributed/checkpoint/state_dict.py | 2 +- torch/distributed/fsdp/_flat_param.py | 4 +- torch/distributed/fsdp/_fsdp_extensions.py | 2 +- torch/distributed/fsdp/_optim_utils.py | 2 +- torch/distributed/fsdp/_shard_utils.py | 2 +- torch/distributed/fsdp/_state_dict_utils.py | 2 +- .../fsdp/fully_sharded_data_parallel.py | 2 +- .../distributed/{_tensor => tensor}/README.md | 0 torch/distributed/tensor/__init__.py | 67 + torch/distributed/tensor/_api.py | 1231 ++++++++++++++++ .../{_tensor => tensor}/_collective_utils.py | 10 +- .../{_tensor => tensor}/_dispatch.py | 25 +- torch/distributed/tensor/_dtensor_spec.py | 276 ++++ .../{_tensor => tensor}/_op_schema.py | 3 +- .../{_tensor/ops => tensor/_ops}/__init__.py | 0 .../ops => tensor/_ops}/_common_rules.py | 8 +- .../{_tensor/ops => tensor/_ops}/_conv_ops.py | 6 +- .../ops => tensor/_ops}/_einsum_strategy.py | 8 +- .../ops => tensor/_ops}/_embedding_ops.py | 8 +- .../ops => tensor/_ops}/_experimental_ops.py | 9 +- .../{_tensor/ops => tensor/_ops}/_math_ops.py | 12 +- .../ops => tensor/_ops}/_matrix_ops.py | 16 +- .../ops => tensor/_ops}/_pointwise_ops.py | 10 +- .../ops => tensor/_ops}/_random_ops.py | 6 +- .../ops => tensor/_ops}/_tensor_ops.py | 14 +- .../{_tensor/ops => tensor/_ops}/_view_ops.py | 10 +- .../{_tensor/ops => tensor/_ops}/utils.py | 12 +- .../{_tensor/random.py => tensor/_random.py} | 7 +- .../{_tensor => tensor}/_redistribute.py | 9 +- .../{_tensor => tensor}/_sharding_prop.py | 8 +- .../{_tensor => tensor}/_shards_wrapper.py | 0 .../{_tensor => tensor}/_tp_conv.py | 2 +- .../distributed/{_tensor => tensor}/_utils.py | 8 +- .../{_tensor => tensor}/debug/__init__.py | 11 +- .../debug/_comm_mode.py} | 41 +- .../{_tensor => tensor}/debug/_op_coverage.py | 2 +- .../debug/_visualize_sharding.py} | 9 +- .../debug/comm_mode_broswer_visual.js | 0 .../{_tensor => tensor}/device_mesh.py | 0 .../examples/comm_mode_features_example.py | 8 +- .../examples/convnext_example.py | 2 +- .../examples/torchrec_sharding_example.py | 16 +- .../examples/visualize_sharding_example.py | 4 +- .../tensor/experimental/__init__.py | 32 + .../experimental/_attention.py} | 2 +- .../experimental/_func_map.py} | 53 +- .../experimental/_register_sharding.py} | 12 +- .../experimental/_tp_transform.py} | 15 +- .../tensor/parallel/_data_parallel_utils.py | 4 +- torch/distributed/tensor/parallel/_utils.py | 4 +- torch/distributed/tensor/parallel/api.py | 6 +- torch/distributed/tensor/parallel/fsdp.py | 2 +- .../tensor/parallel/input_reshard.py | 2 +- torch/distributed/tensor/parallel/loss.py | 11 +- torch/distributed/tensor/parallel/style.py | 4 +- torch/distributed/tensor/placement_types.py | 652 +++++++++ torch/testing/_internal/common_fsdp.py | 2 +- 110 files changed, 2786 insertions(+), 2518 deletions(-) create mode 100644 docs/source/distributed.tensor.rst delete mode 100644 torch/distributed/_tensor/experimental/__init__.py rename torch/distributed/{_tensor => tensor}/README.md (100%) create mode 100644 torch/distributed/tensor/_api.py rename torch/distributed/{_tensor => tensor}/_collective_utils.py (97%) rename torch/distributed/{_tensor => tensor}/_dispatch.py (97%) create mode 100644 torch/distributed/tensor/_dtensor_spec.py rename torch/distributed/{_tensor => tensor}/_op_schema.py (99%) rename torch/distributed/{_tensor/ops => tensor/_ops}/__init__.py (100%) rename torch/distributed/{_tensor/ops => tensor/_ops}/_common_rules.py (97%) rename torch/distributed/{_tensor/ops => tensor/_ops}/_conv_ops.py (93%) rename torch/distributed/{_tensor/ops => tensor/_ops}/_einsum_strategy.py (97%) rename torch/distributed/{_tensor/ops => tensor/_ops}/_embedding_ops.py (98%) rename torch/distributed/{_tensor/ops => tensor/_ops}/_experimental_ops.py (69%) rename torch/distributed/{_tensor/ops => tensor/_ops}/_math_ops.py (99%) rename torch/distributed/{_tensor/ops => tensor/_ops}/_matrix_ops.py (98%) rename torch/distributed/{_tensor/ops => tensor/_ops}/_pointwise_ops.py (98%) rename torch/distributed/{_tensor/ops => tensor/_ops}/_random_ops.py (90%) rename torch/distributed/{_tensor/ops => tensor/_ops}/_tensor_ops.py (98%) rename torch/distributed/{_tensor/ops => tensor/_ops}/_view_ops.py (98%) rename torch/distributed/{_tensor/ops => tensor/_ops}/utils.py (96%) rename torch/distributed/{_tensor/random.py => tensor/_random.py} (98%) rename torch/distributed/{_tensor => tensor}/_redistribute.py (98%) rename torch/distributed/{_tensor => tensor}/_sharding_prop.py (99%) rename torch/distributed/{_tensor => tensor}/_shards_wrapper.py (100%) rename torch/distributed/{_tensor => tensor}/_tp_conv.py (99%) rename torch/distributed/{_tensor => tensor}/_utils.py (98%) rename torch/distributed/{_tensor => tensor}/debug/__init__.py (58%) rename torch/distributed/{_tensor/debug/comm_mode.py => tensor/debug/_comm_mode.py} (96%) rename torch/distributed/{_tensor => tensor}/debug/_op_coverage.py (98%) rename torch/distributed/{_tensor/debug/visualize_sharding.py => tensor/debug/_visualize_sharding.py} (95%) rename torch/distributed/{_tensor => tensor}/debug/comm_mode_broswer_visual.js (100%) rename torch/distributed/{_tensor => tensor}/device_mesh.py (100%) rename torch/distributed/{_tensor => tensor}/examples/comm_mode_features_example.py (99%) rename torch/distributed/{_tensor => tensor}/examples/convnext_example.py (99%) rename torch/distributed/{_tensor => tensor}/examples/torchrec_sharding_example.py (98%) rename torch/distributed/{_tensor => tensor}/examples/visualize_sharding_example.py (92%) create mode 100644 torch/distributed/tensor/experimental/__init__.py rename torch/distributed/{_tensor/experimental/attention.py => tensor/experimental/_attention.py} (99%) rename torch/distributed/{_tensor/experimental/func_map.py => tensor/experimental/_func_map.py} (82%) rename torch/distributed/{_tensor/experimental/register_sharding.py => tensor/experimental/_register_sharding.py} (93%) rename torch/distributed/{_tensor/experimental/tp_transform.py => tensor/experimental/_tp_transform.py} (98%) create mode 100644 torch/distributed/tensor/placement_types.py diff --git a/docs/source/distributed.rst b/docs/source/distributed.rst index f4c73b9381e594..b0661f867c961c 100644 --- a/docs/source/distributed.rst +++ b/docs/source/distributed.rst @@ -876,7 +876,6 @@ If you are running single node training, it may be convenient to interactively b .. py:module:: torch.distributed.nn.api .. py:module:: torch.distributed.nn.jit .. py:module:: torch.distributed.nn.jit.templates -.. py:module:: torch.distributed.tensor .. py:module:: torch.distributed.algorithms.ddp_comm_hooks.ddp_zero_hook .. py:module:: torch.distributed.algorithms.ddp_comm_hooks.debugging_hooks .. py:module:: torch.distributed.algorithms.ddp_comm_hooks.default_hooks diff --git a/docs/source/distributed.tensor.rst b/docs/source/distributed.tensor.rst new file mode 100644 index 00000000000000..1df4f2e43d5d9a --- /dev/null +++ b/docs/source/distributed.tensor.rst @@ -0,0 +1,191 @@ +.. currentmodule:: torch.distributed.tensor + +torch.distributed.tensor +=========================== + +.. note:: + ``torch.distributed.tensor`` is currently in alpha state and under + development, we are committing backward compatibility for the most APIs listed + in the doc, but there might be API changes if necessary. + + +PyTorch DTensor (Distributed Tensor) +--------------------------------------- + +PyTorch DTensor offers simple and flexible tensor sharding primitives that transparently handles distributed +logic, including sharded storage, operator computation and collective communications across devices/hosts. +``DTensor`` could be used to build different paralleism solutions and support sharded state_dict representation +when working with multi-dimensional sharding. + +Please see examples from the PyTorch native parallelism solutions that are built on top of ``DTensor``: + +* `Tensor Parallel `__ +* `FSDP2 `__ + +.. automodule:: torch.distributed.tensor + +:class:`DTensor` follows the SPMD (single program, multiple data) programming model to empower users to +write distributed program as if it's a **single-device program with the same convergence property**. It +provides a uniform tensor sharding layout (DTensor Layout) through specifying the :class:`DeviceMesh` +and :class:`Placement`: + +- :class:`DeviceMesh` represents the device topology and the communicators of the cluster using + an n-dimensional array. + +- :class:`Placement` describes the sharding layout of the logical tensor on the :class:`DeviceMesh`. + DTensor supports three types of placements: :class:`Shard`, :class:`Replicate` and :class:`Partial`. + + +DTensor Class APIs +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +.. currentmodule:: torch.distributed.tensor + +:class:`DTensor` is a ``torch.Tensor`` subclass. This means once a :class:`DTensor` is created, it could be +used in very similar way to ``torch.Tensor``, including running different types of PyTorch operators as if +running them in a single device, allowing proper distributed computation for PyTorch operators. + +In addition to existing ``torch.Tensor`` methods, it also offers a set of additional methods to interact with +``torch.Tensor``, ``redistribute`` the DTensor Layout to a new DTensor, get the full tensor content +on all devices, etc. + +.. autoclass:: DTensor + :members: + :member-order: bysource + + +DeviceMesh as the distributed communicator +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +.. currentmodule:: torch.distributed.device_mesh + +:class:`DeviceMesh` was built from DTensor as the abstraction to describe cluster's device topology and represent +multi-dimensional communicators (on top of ``ProcessGroup``). To see the details of how to create/use a DeviceMesh, +please refer to the `DeviceMesh recipe `__. + + +DTensor Placement Types +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +.. automodule:: torch.distributed.tensor.placement_types +.. currentmodule:: torch.distributed.tensor.placement_types + +DTensor supports the following types of :class:`Placement` on each :class:`DeviceMesh` dimension: + +.. autoclass:: Shard + :members: + :undoc-members: + +.. autoclass:: Replicate + :members: + :undoc-members: + +.. autoclass:: Partial + :members: + :undoc-members: + +.. autoclass:: Placement + :members: + :undoc-members: + + +Different ways to create a DTensor +--------------------------------------- + +.. currentmodule:: torch.distributed.tensor + +There're three ways to construct a :class:`DTensor`: + * :meth:`distribute_tensor` creates a :class:`DTensor` from a logical or "global" ``torch.Tensor`` on + each rank. This could be used to shard the leaf ``torch.Tensor`` s (i.e. model parameters/buffers + and inputs). + * :meth:`DTensor.from_local` creates a :class:`DTensor` from a local ``torch.Tensor`` on each rank, which can + be used to create :class:`DTensor` from a non-leaf ``torch.Tensor`` s (i.e. intermediate activation + tensors during forward/backward). + * DTensor provides dedicated tensor factory functions (e.g. :meth:`empty`, :meth:`ones`, :meth:`randn`, etc.) + to allow different :class:`DTensor` creations by directly specifying the :class:`DeviceMesh` and + :class:`Placement`. Compare to :meth:`distribute_tensor`, this could directly materializing the sharded memory + on device, instead of performing sharding after initializing the logical Tensor memory. + +Create DTensor from a logical torch.Tensor +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +The SPMD (single program, multiple data) programming model in ``torch.distributed`` launches multiple processes +(i.e. via ``torchrun``) to execute the same program, this means that the model inside the program would be +initialized on different processes first (i.e. the model might be initialized on CPU, or meta device, or directly +on GPU if enough memory). + +``DTensor`` offers a :meth:`distribute_tensor` API that could shard the model weights or Tensors to ``DTensor`` s, +where it would create a DTensor from the "logical" Tensor on each process. This would empower the created +``DTensor`` s to comply with the single device semantic, which is critical for **numerical correctness**. + +.. autofunction:: distribute_tensor + +Along with :meth:`distribute_tensor`, DTensor also offers a :meth:`distribute_module` API to allow easier +sharding on the :class:`nn.Module` level + +.. autofunction:: distribute_module + + +DTensor Factory Functions +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +DTensor also provides dedicated tensor factory functions to allow creating :class:`DTensor` directly +using torch.Tensor like factory function APIs (i.e. torch.ones, torch.empty, etc), by additionally +specifying the :class:`DeviceMesh` and :class:`Placement` for the :class:`DTensor` created: + +.. autofunction:: zeros + +.. autofunction:: ones + +.. autofunction:: empty + +.. autofunction:: full + +.. autofunction:: rand + +.. autofunction:: randn + + +Debugging +--------------------------------------- + +.. automodule:: torch.distributed.tensor.debug +.. currentmodule:: torch.distributed.tensor.debug + +Logging +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +When launching the program, you can turn on additional logging using the `TORCH_LOGS` environment variable from +`torch._logging `__ : + +* `TORCH_LOGS=+dtensor` will display `logging.DEBUG` messages and all levels above it. +* `TORCH_LOGS=dtensor` will display `logging.INFO` messages and above. +* `TORCH_LOGS=-dtensor` will display `logging.WARNING` messages and above. + +Debugging Tools +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +To debug the program that applied DTensor, and understand more details about what collectives happened under the +hood, DTensor provides a :class:`CommDebugMode`: + +.. autoclass:: CommDebugMode + :members: + :undoc-members: + +To visualize the sharding of a DTensor that have less than 3 dimensions, DTensor provides :meth:`visualize_sharding`: + +.. autofunction:: visualize_sharding + + +Experimental Features +--------------------------------------- + +``DTensor`` also provides a set of experimental features. These features are either in prototyping stage, or the basic +functionality is done and but looking for user feedbacks. Please submit a issue to PyTorch if you have feedbacks to +these features. + +.. automodule:: torch.distributed.tensor.experimental +.. currentmodule:: torch.distributed.tensor.experimental + +.. autofunction:: local_map +.. autofunction:: register_sharding + + +.. modules that are missing docs, add the doc later when necessary +.. py:module:: torch.distributed.tensor.device_mesh diff --git a/docs/source/index.rst b/docs/source/index.rst index dcaadcbb63edc1..773e64204293b7 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -74,12 +74,13 @@ Features described in this documentation are classified by release status: torch.backends torch.export torch.distributed + torch.distributed.tensor torch.distributed.algorithms.join torch.distributed.elastic torch.distributed.fsdp + torch.distributed.tensor.parallel torch.distributed.optim torch.distributed.pipelining - torch.distributed.tensor.parallel torch.distributed.checkpoint torch.distributions torch.compiler diff --git a/test/allowlist_for_publicAPI.json b/test/allowlist_for_publicAPI.json index 05b1b8223cc038..9107da9a37cfe0 100644 --- a/test/allowlist_for_publicAPI.json +++ b/test/allowlist_for_publicAPI.json @@ -33,7 +33,8 @@ "torch.nn.quantizable": "torch.ao.nn.quantizable", "torch.nn.quantizable.modules": "torch.ao.nn.quantizable.modules", "torch.nn.quantizable.modules.activation": "torch.ao.nn.quantizable.modules.activation", - "torch.nn.quantizable.modules.rnn": "torch.ao.nn.quantizable.modules.rnn" + "torch.nn.quantizable.modules.rnn": "torch.ao.nn.quantizable.modules.rnn", + "torch.distributed.tensor.device_mesh": "torch.distributed.device_mesh" }, "torch.backends": [ "contextmanager" diff --git a/test/distributed/_composable/fsdp/test_fully_shard_clip_grad_norm_.py b/test/distributed/_composable/fsdp/test_fully_shard_clip_grad_norm_.py index 521b20ca5b0bdf..636c98ec5ba590 100644 --- a/test/distributed/_composable/fsdp/test_fully_shard_clip_grad_norm_.py +++ b/test/distributed/_composable/fsdp/test_fully_shard_clip_grad_norm_.py @@ -8,8 +8,8 @@ import torch.nn as nn from torch.distributed._composable import replicate from torch.distributed._composable.fsdp import fully_shard -from torch.distributed._tensor.debug import CommDebugMode from torch.distributed.device_mesh import DeviceMesh, init_device_mesh +from torch.distributed.tensor.debug import CommDebugMode from torch.testing._internal.common_distributed import skip_if_lt_x_gpu from torch.testing._internal.common_fsdp import FSDPTest, MLPStack from torch.testing._internal.common_utils import run_tests diff --git a/test/distributed/_composable/fsdp/test_fully_shard_comm.py b/test/distributed/_composable/fsdp/test_fully_shard_comm.py index 6e92ce3b36e62e..5641082a8063cd 100644 --- a/test/distributed/_composable/fsdp/test_fully_shard_comm.py +++ b/test/distributed/_composable/fsdp/test_fully_shard_comm.py @@ -32,9 +32,9 @@ from torch.distributed._composable.fsdp._fsdp_param import ShardedState from torch.distributed._composable.fsdp._fsdp_param_group import FSDPParamGroup from torch.distributed._tensor import DTensor -from torch.distributed._tensor.debug.comm_mode import CommDebugMode from torch.distributed._tensor.experimental import implicit_replication from torch.distributed.device_mesh import DeviceMesh, init_device_mesh +from torch.distributed.tensor.debug import CommDebugMode from torch.testing._internal.common_cuda import TEST_CUDA from torch.testing._internal.common_distributed import skip_if_lt_x_gpu from torch.testing._internal.common_fsdp import ( diff --git a/test/distributed/_composable/fsdp/test_fully_shard_init.py b/test/distributed/_composable/fsdp/test_fully_shard_init.py index 0ce3a029ca4318..c0c585f9d767c5 100644 --- a/test/distributed/_composable/fsdp/test_fully_shard_init.py +++ b/test/distributed/_composable/fsdp/test_fully_shard_init.py @@ -23,7 +23,6 @@ Replicate, Shard, ) -from torch.distributed._tensor.placement_types import _StridedShard from torch.distributed.device_mesh import init_device_mesh from torch.distributed.fsdp._init_utils import ( _init_inter_node_process_group, @@ -34,6 +33,7 @@ parallelize_module, RowwiseParallel, ) +from torch.distributed.tensor.placement_types import _StridedShard from torch.testing._internal.common_cuda import TEST_CUDA from torch.testing._internal.common_fsdp import FSDPTestMultiThread, MLP from torch.testing._internal.common_utils import run_tests diff --git a/test/distributed/_composable/fsdp/test_fully_shard_training.py b/test/distributed/_composable/fsdp/test_fully_shard_training.py index 66417c5c1edda4..f8bad4c5d5fa3a 100644 --- a/test/distributed/_composable/fsdp/test_fully_shard_training.py +++ b/test/distributed/_composable/fsdp/test_fully_shard_training.py @@ -20,12 +20,12 @@ register_fsdp_forward_method, ) from torch.distributed._tensor import DTensor, init_device_mesh -from torch.distributed._tensor.debug.comm_mode import CommDebugMode from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import ( _CHECKPOINT_PREFIX, apply_activation_checkpointing, ) from torch.distributed.device_mesh import DeviceMesh +from torch.distributed.tensor.debug import CommDebugMode from torch.testing._internal.common_cuda import TEST_CUDA from torch.testing._internal.common_distributed import skip_if_lt_x_gpu from torch.testing._internal.common_fsdp import ( diff --git a/test/distributed/_composable/test_composability/test_2d_composability.py b/test/distributed/_composable/test_composability/test_2d_composability.py index 5f2c9dbe1240b7..754705b567a324 100644 --- a/test/distributed/_composable/test_composability/test_2d_composability.py +++ b/test/distributed/_composable/test_composability/test_2d_composability.py @@ -14,7 +14,6 @@ from torch.distributed._composable import replicate from torch.distributed._composable.fsdp import CPUOffloadPolicy, fully_shard from torch.distributed._tensor import DTensor, init_device_mesh, Replicate, Shard -from torch.distributed._tensor.debug.comm_mode import CommDebugMode from torch.distributed.checkpoint.state_dict import ( get_model_state_dict, get_optimizer_state_dict, @@ -27,6 +26,7 @@ clean_tensor_name, ) from torch.distributed.fsdp.fully_sharded_data_parallel import StateDictType +from torch.distributed.tensor.debug import CommDebugMode from torch.distributed.tensor.parallel import ( ColwiseParallel, parallelize_module, diff --git a/test/distributed/_tensor/debug/test_comm_mode.py b/test/distributed/_tensor/debug/test_comm_mode.py index 7d5f56c30118ce..3428bca2c83ba3 100644 --- a/test/distributed/_tensor/debug/test_comm_mode.py +++ b/test/distributed/_tensor/debug/test_comm_mode.py @@ -5,8 +5,8 @@ import torch.distributed._functional_collectives as funcol import torch.nn as nn from torch.distributed._tensor import DeviceMesh, DTensor -from torch.distributed._tensor.debug.comm_mode import CommDebugMode from torch.distributed._tensor.placement_types import Shard +from torch.distributed.tensor.debug import CommDebugMode from torch.testing._internal.common_distributed import requires_nccl from torch.testing._internal.common_utils import run_tests, TestCase from torch.testing._internal.distributed._tensor.common_dtensor import MLPModule diff --git a/test/distributed/_tensor/debug/test_comm_mode_features.py b/test/distributed/_tensor/debug/test_comm_mode_features.py index 05469b714cf57b..fc19cddb58f4a3 100644 --- a/test/distributed/_tensor/debug/test_comm_mode_features.py +++ b/test/distributed/_tensor/debug/test_comm_mode_features.py @@ -6,7 +6,7 @@ import torch from torch.distributed._tensor import DeviceMesh from torch.distributed._tensor.api import distribute_tensor, DTensor -from torch.distributed._tensor.debug import CommDebugMode +from torch.distributed.tensor.debug import CommDebugMode from torch.distributed.tensor.parallel import ( ColwiseParallel, parallelize_module, diff --git a/test/distributed/_tensor/debug/test_op_coverage.py b/test/distributed/_tensor/debug/test_op_coverage.py index 775a9f0d525d98..0392bae03c504a 100644 --- a/test/distributed/_tensor/debug/test_op_coverage.py +++ b/test/distributed/_tensor/debug/test_op_coverage.py @@ -2,7 +2,7 @@ import torch import torch.nn as nn -from torch.distributed._tensor.debug._op_coverage import get_inductor_decomp_graphs +from torch.distributed.tensor.debug._op_coverage import get_inductor_decomp_graphs from torch.testing._internal.common_utils import run_tests, TestCase diff --git a/test/distributed/_tensor/experimental/test_local_map.py b/test/distributed/_tensor/experimental/test_local_map.py index 391118b73b0897..85d9977310c1d4 100644 --- a/test/distributed/_tensor/experimental/test_local_map.py +++ b/test/distributed/_tensor/experimental/test_local_map.py @@ -11,8 +11,8 @@ Replicate, Shard, ) -from torch.distributed._tensor.debug import CommDebugMode from torch.distributed._tensor.experimental import local_map +from torch.distributed.tensor.debug import CommDebugMode from torch.testing._internal.common_utils import run_tests from torch.testing._internal.distributed._tensor.common_dtensor import ( DTensorTestBase, diff --git a/test/distributed/_tensor/experimental/test_tp_transform.py b/test/distributed/_tensor/experimental/test_tp_transform.py index 18a322710825ef..719b3a21c08966 100644 --- a/test/distributed/_tensor/experimental/test_tp_transform.py +++ b/test/distributed/_tensor/experimental/test_tp_transform.py @@ -3,7 +3,7 @@ from typing import Dict import torch -from torch.distributed._tensor.experimental.tp_transform import ( +from torch.distributed._tensor.experimental._tp_transform import ( tensor_parallel_transformation, ) from torch.distributed.tensor.parallel.style import ( diff --git a/test/distributed/_tensor/test_attention.py b/test/distributed/_tensor/test_attention.py index 0f18d54697021a..f5ae0d6a1179b6 100644 --- a/test/distributed/_tensor/test_attention.py +++ b/test/distributed/_tensor/test_attention.py @@ -7,14 +7,14 @@ import torch.nn.functional as F from torch import nn from torch.distributed._tensor import DeviceMesh -from torch.distributed._tensor.debug import CommDebugMode -from torch.distributed._tensor.experimental.attention import ( +from torch.distributed._tensor.experimental._attention import ( _AttentionContextParallel, _CausalBehavior, _context_parallel_buffers, _is_causal_behavior, context_parallel, ) +from torch.distributed.tensor.debug import CommDebugMode from torch.distributed.tensor.parallel import parallelize_module from torch.nn.attention import sdpa_kernel, SDPBackend from torch.testing._internal.common_cuda import ( diff --git a/test/distributed/_tensor/test_common_rules.py b/test/distributed/_tensor/test_common_rules.py index 77b5d91405a73a..f712ef1cf02bad 100644 --- a/test/distributed/_tensor/test_common_rules.py +++ b/test/distributed/_tensor/test_common_rules.py @@ -3,9 +3,9 @@ import torch from torch.distributed._tensor import DeviceMesh -from torch.distributed._tensor._op_schema import OpSchema -from torch.distributed._tensor.ops._common_rules import einop_rule, pointwise_rule from torch.distributed._tensor.placement_types import DTensorSpec, TensorMeta +from torch.distributed.tensor._op_schema import OpSchema +from torch.distributed.tensor._ops._common_rules import einop_rule, pointwise_rule from torch.testing._internal.common_utils import run_tests from torch.testing._internal.distributed._tensor.common_dtensor import ( DTensorTestBase, diff --git a/test/distributed/_tensor/test_dtensor.py b/test/distributed/_tensor/test_dtensor.py index e4570f7dc41bcb..bc94a4c859a68d 100644 --- a/test/distributed/_tensor/test_dtensor.py +++ b/test/distributed/_tensor/test_dtensor.py @@ -14,7 +14,6 @@ DTensor, init_device_mesh, ) -from torch.distributed._tensor.debug import CommDebugMode from torch.distributed._tensor.experimental import implicit_replication from torch.distributed._tensor.placement_types import ( DTensorSpec, @@ -23,6 +22,7 @@ Shard, TensorMeta, ) +from torch.distributed.tensor.debug import CommDebugMode from torch.distributed.tensor.parallel import ( ColwiseParallel, parallelize_module, @@ -943,7 +943,7 @@ def test_split_tensor_1D(self) -> None: ] assert_array_equal(expected_pad_sizes, pad_sizes) - from torch.distributed._tensor._collective_utils import unpad_tensor + from torch.distributed.tensor._collective_utils import unpad_tensor unpadded_list = [ unpad_tensor(tensor, shard_placement.dim, pad_sizes[i]) diff --git a/test/distributed/_tensor/test_embedding_ops.py b/test/distributed/_tensor/test_embedding_ops.py index 3c366d570b37cc..889f62a8f22309 100644 --- a/test/distributed/_tensor/test_embedding_ops.py +++ b/test/distributed/_tensor/test_embedding_ops.py @@ -10,7 +10,7 @@ Replicate, Shard, ) -from torch.distributed._tensor.debug import CommDebugMode +from torch.distributed.tensor.debug import CommDebugMode from torch.testing._internal.common_utils import run_tests, TEST_WITH_DEV_DBG_ASAN from torch.testing._internal.distributed._tensor.common_dtensor import ( DTensorTestBase, @@ -167,7 +167,7 @@ def test_sharded_embedding_rowwise(self): self._run_embedding_op_test(mesh, 0, [6, 7, 6], 13, 22) self._run_embedding_op_test(mesh, 0, [34], 15, 14, padding_idx=10) - from torch.distributed._tensor.ops._embedding_ops import _MaskPartial + from torch.distributed.tensor._ops._embedding_ops import _MaskPartial # test collectives embedding_mod = torch.nn.Embedding(10, 20, device=self.device_type) @@ -191,7 +191,7 @@ def test_multiple_embeddings_rowwise(self): inp = torch.randint(0, 10, (4, 4), device=self.device_type) replicated_inp = DTensor.from_local(inp, mesh, [Replicate()], run_check=False) - from torch.distributed._tensor.ops._embedding_ops import _MaskPartial + from torch.distributed.tensor._ops._embedding_ops import _MaskPartial # case 1: two embeddings with the same shape, thus sharing the underying _MaskPartial # and MaskBuffer, because of cache hit from sharding propagation diff --git a/test/distributed/_tensor/test_math_ops.py b/test/distributed/_tensor/test_math_ops.py index f6fa076b5b9069..d40b0571eb6947 100644 --- a/test/distributed/_tensor/test_math_ops.py +++ b/test/distributed/_tensor/test_math_ops.py @@ -13,9 +13,9 @@ distribute_tensor, DTensor, ) -from torch.distributed._tensor.debug import CommDebugMode -from torch.distributed._tensor.ops.utils import is_tensor_partial, normalize_dim from torch.distributed._tensor.placement_types import Replicate, Shard +from torch.distributed.tensor._ops.utils import is_tensor_partial, normalize_dim +from torch.distributed.tensor.debug import CommDebugMode from torch.distributed.tensor.parallel import ( ColwiseParallel, parallelize_module, diff --git a/test/distributed/_tensor/test_matrix_ops.py b/test/distributed/_tensor/test_matrix_ops.py index 45988e233248f0..40241917bd7c52 100644 --- a/test/distributed/_tensor/test_matrix_ops.py +++ b/test/distributed/_tensor/test_matrix_ops.py @@ -8,13 +8,13 @@ import torch.nn.functional as F from torch.distributed._tensor import DeviceMesh, distribute_tensor from torch.distributed._tensor.api import DTensor -from torch.distributed._tensor.debug import CommDebugMode from torch.distributed._tensor.placement_types import ( Partial, Placement, Replicate, Shard, ) +from torch.distributed.tensor.debug import CommDebugMode from torch.testing._internal.common_utils import run_tests from torch.testing._internal.distributed._tensor.common_dtensor import ( DTensorTestBase, diff --git a/test/distributed/_tensor/test_op_strategy.py b/test/distributed/_tensor/test_op_strategy.py index 302e2675cc8995..86d5b92075eab9 100644 --- a/test/distributed/_tensor/test_op_strategy.py +++ b/test/distributed/_tensor/test_op_strategy.py @@ -4,12 +4,6 @@ import torch from torch.distributed._tensor import DeviceMesh, DTensor -from torch.distributed._tensor._collective_utils import redistribute_cost -from torch.distributed._tensor._op_schema import OpSchema, OpStrategy, PlacementStrategy -from torch.distributed._tensor.ops._einsum_strategy import ( - EinsumDims, - gen_einsum_strategies, -) from torch.distributed._tensor.placement_types import ( DTensorSpec, Partial, @@ -17,6 +11,12 @@ Shard, TensorMeta, ) +from torch.distributed.tensor._collective_utils import redistribute_cost +from torch.distributed.tensor._op_schema import OpSchema, OpStrategy, PlacementStrategy +from torch.distributed.tensor._ops._einsum_strategy import ( + EinsumDims, + gen_einsum_strategies, +) from torch.testing._internal.common_utils import run_tests, TestCase from torch.testing._internal.distributed._tensor.common_dtensor import DTensorOpTestBase @@ -169,7 +169,7 @@ def test_redistribute_cost_mesh_1d(self): def test_redistribute_cost_latency(self): # test cost model on addmm op - from torch.distributed._tensor.ops._matrix_ops import addmm_strategy + from torch.distributed.tensor._ops._matrix_ops import addmm_strategy mesh = self.build_device_mesh() shard0_placement = (Shard(0),) @@ -246,7 +246,7 @@ def test_redistribute_cost_mesh_2d(self): self.assertTrue(allreduce_cost > reduce_scatter_cost) def test_mm_strategies(self): - from torch.distributed._tensor.ops._matrix_ops import mm_strategy + from torch.distributed.tensor._ops._matrix_ops import mm_strategy mesh = self.build_device_mesh() lhs_tensor = torch.randn(6, 8) @@ -292,7 +292,7 @@ def test_mm_strategies(self): self.assertFalse(output_sharding.needs_redistribute) def test_bmm_strategies(self): - from torch.distributed._tensor.ops._matrix_ops import bmm_strategy + from torch.distributed.tensor._ops._matrix_ops import bmm_strategy mesh = self.build_device_mesh() lhs_tensor = torch.randn(8, 6, 8) diff --git a/test/distributed/_tensor/test_random_ops.py b/test/distributed/_tensor/test_random_ops.py index c149075eddd989..6964b412537a23 100644 --- a/test/distributed/_tensor/test_random_ops.py +++ b/test/distributed/_tensor/test_random_ops.py @@ -5,13 +5,13 @@ import torch import torch.distributed._functional_collectives as funcol -import torch.distributed._tensor.random as random +import torch.distributed.tensor._random as random from torch.distributed._tensor import DeviceMesh, DTensor from torch.distributed._tensor._utils import compute_local_shape_and_global_offset from torch.distributed._tensor.api import distribute_tensor from torch.distributed._tensor.placement_types import Replicate, Shard -from torch.distributed._tensor.random import is_rng_supported_mesh, manual_seed from torch.distributed.distributed_c10d import broadcast_object_list +from torch.distributed.tensor._random import is_rng_supported_mesh, manual_seed from torch.testing._internal.common_utils import run_tests from torch.testing._internal.distributed._tensor.common_dtensor import ( DTensorTestBase, diff --git a/test/distributed/_tensor/test_redistribute.py b/test/distributed/_tensor/test_redistribute.py index 5634869155052c..6a940d3c11e8cf 100644 --- a/test/distributed/_tensor/test_redistribute.py +++ b/test/distributed/_tensor/test_redistribute.py @@ -5,10 +5,10 @@ import torch from torch.distributed._tensor import DeviceMesh, distribute_tensor, DTensor -from torch.distributed._tensor._collective_utils import shard_dim_alltoall -from torch.distributed._tensor.debug import CommDebugMode from torch.distributed._tensor.placement_types import Partial, Replicate, Shard from torch.distributed.device_mesh import init_device_mesh +from torch.distributed.tensor._collective_utils import shard_dim_alltoall +from torch.distributed.tensor.debug import CommDebugMode from torch.testing._internal.common_utils import run_tests from torch.testing._internal.distributed._tensor.common_dtensor import ( DTensorTestBase, @@ -207,7 +207,7 @@ def test_replicate_to_partial(self): with self.assertRaisesRegex(RuntimeError, "Can not redistribute to Partial"): partial_tensor = replica_tensor.redistribute(device_mesh, [partial_spec]) - from torch.distributed._tensor._redistribute import Redistribute + from torch.distributed.tensor._redistribute import Redistribute comm_mode = CommDebugMode() diff --git a/test/distributed/_tensor/test_tensor_ops.py b/test/distributed/_tensor/test_tensor_ops.py index dde1a515b5da1e..f9153c126bc8e9 100644 --- a/test/distributed/_tensor/test_tensor_ops.py +++ b/test/distributed/_tensor/test_tensor_ops.py @@ -3,8 +3,8 @@ import torch from torch.distributed._tensor import DeviceMesh, distribute_tensor, DTensor -from torch.distributed._tensor.debug import CommDebugMode from torch.distributed._tensor.placement_types import Partial, Replicate, Shard +from torch.distributed.tensor.debug import CommDebugMode from torch.testing._internal.common_distributed import skip_if_lt_x_gpu from torch.testing._internal.common_utils import run_tests from torch.testing._internal.distributed._tensor.common_dtensor import ( @@ -445,7 +445,7 @@ def test_gather(self): # case 2 input sharding: input sharded, index replicated, output mask partial # only works when index has size 1 on the gather dimension and # input is sharded on the gather dimension - from torch.distributed._tensor.ops._embedding_ops import _MaskPartial + from torch.distributed.tensor._ops._embedding_ops import _MaskPartial gather_dim = 1 global_input = torch.randn(12, 8, 16) @@ -614,7 +614,7 @@ def test_dtensor_dtype_conversion(self): self.assertEqual(bf16_sharded_dtensor1.dtype, torch.bfloat16) self.assertEqual(bf16_sharded_dtensor1.to_local().dtype, torch.bfloat16) - from torch.distributed._tensor.debug import _get_sharding_prop_cache_info + from torch.distributed.tensor.debug import _get_sharding_prop_cache_info # by this point we only have cache misses hits, misses, _, _ = _get_sharding_prop_cache_info() diff --git a/test/distributed/_tensor/test_utils.py b/test/distributed/_tensor/test_utils.py index 3bad008fb806eb..ea05d2027b5829 100644 --- a/test/distributed/_tensor/test_utils.py +++ b/test/distributed/_tensor/test_utils.py @@ -8,15 +8,10 @@ compute_local_shape, compute_local_shape_and_global_offset, ) -from torch.distributed._tensor.debug import CommDebugMode -from torch.distributed._tensor.placement_types import ( - _StridedShard, - DTensorSpec, - Replicate, - Shard, - TensorMeta, -) from torch.distributed.device_mesh import DeviceMesh, init_device_mesh +from torch.distributed.tensor._dtensor_spec import DTensorSpec, TensorMeta +from torch.distributed.tensor.debug import CommDebugMode +from torch.distributed.tensor.placement_types import _StridedShard, Replicate, Shard from torch.testing._internal.common_utils import run_tests from torch.testing._internal.distributed._tensor.common_dtensor import ( DTensorTestBase, diff --git a/test/distributed/_tensor/test_view_ops.py b/test/distributed/_tensor/test_view_ops.py index 3af39815997fc4..630c7f8511d825 100644 --- a/test/distributed/_tensor/test_view_ops.py +++ b/test/distributed/_tensor/test_view_ops.py @@ -14,8 +14,8 @@ Replicate, Shard, ) -from torch.distributed._tensor.debug import CommDebugMode -from torch.distributed._tensor.ops._view_ops import ( +from torch.distributed._tensor.placement_types import Placement +from torch.distributed.tensor._ops._view_ops import ( Broadcast, dim_maps, Flatten, @@ -25,7 +25,7 @@ Split, view_groups, ) -from torch.distributed._tensor.placement_types import Placement +from torch.distributed.tensor.debug import CommDebugMode from torch.testing._internal.common_utils import run_tests from torch.testing._internal.distributed._tensor.common_dtensor import ( DTensorTestBase, diff --git a/test/distributed/checkpoint/test_state_dict_utils.py b/test/distributed/checkpoint/test_state_dict_utils.py index 4bd4f9cf21f214..757cf77c067395 100644 --- a/test/distributed/checkpoint/test_state_dict_utils.py +++ b/test/distributed/checkpoint/test_state_dict_utils.py @@ -12,8 +12,7 @@ _gather_state_dict, _offload_state_dict_to_cpu, ) -from torch.distributed._tensor import DTensor -from torch.distributed._tensor.placement_types import Shard +from torch.distributed._tensor import DTensor, Shard from torch.testing._internal.common_utils import run_tests from torch.testing._internal.distributed._tensor.common_dtensor import ( DTensorTestBase, diff --git a/test/distributed/fsdp/test_fsdp_tp_integration.py b/test/distributed/fsdp/test_fsdp_tp_integration.py index c0a2710e9c9023..5f04bc8045dbbd 100644 --- a/test/distributed/fsdp/test_fsdp_tp_integration.py +++ b/test/distributed/fsdp/test_fsdp_tp_integration.py @@ -14,12 +14,12 @@ Replicate, Shard, ) -from torch.distributed._tensor.debug import CommDebugMode from torch.distributed.fsdp.fully_sharded_data_parallel import ( CPUOffload, FullyShardedDataParallel as FSDP, ShardingStrategy, ) +from torch.distributed.tensor.debug import CommDebugMode from torch.distributed.tensor.parallel import ( ColwiseParallel, parallelize_module, diff --git a/test/distributed/tensor/parallel/test_tp_examples.py b/test/distributed/tensor/parallel/test_tp_examples.py index 0c470890241214..43662c4d6cf2d4 100644 --- a/test/distributed/tensor/parallel/test_tp_examples.py +++ b/test/distributed/tensor/parallel/test_tp_examples.py @@ -15,11 +15,11 @@ Replicate, Shard, ) -from torch.distributed._tensor.debug import CommDebugMode from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import ( checkpoint_wrapper, CheckpointImpl, ) +from torch.distributed.tensor.debug import CommDebugMode from torch.distributed.tensor.parallel import ( ColwiseParallel, loss_parallel, @@ -118,8 +118,6 @@ def _test_mlp_training_e2e(self, is_seq_parallel=False, recompute_activation=Fal output = model(inp) output.sum().backward() - from torch.distributed._tensor.debug import CommDebugMode - comm_mode = CommDebugMode() with comm_mode: output_tp = model_tp(inp) diff --git a/test/distributed/tensor/parallel/test_tp_random_state.py b/test/distributed/tensor/parallel/test_tp_random_state.py index 41ca79121d4b48..b9f73a70430d46 100644 --- a/test/distributed/tensor/parallel/test_tp_random_state.py +++ b/test/distributed/tensor/parallel/test_tp_random_state.py @@ -1,7 +1,7 @@ # Owner(s): ["oncall: distributed"] import torch import torch.distributed._functional_collectives as funcol -import torch.distributed._tensor.random as random +import torch.distributed.tensor._random as random from torch.distributed._tensor import init_device_mesh, Replicate from torch.distributed.tensor.parallel.api import parallelize_module from torch.distributed.tensor.parallel.style import ColwiseParallel diff --git a/test/distributed/tensor/parallel/test_tp_style.py b/test/distributed/tensor/parallel/test_tp_style.py index ebd6d82a8c772e..28ff10bab099ca 100644 --- a/test/distributed/tensor/parallel/test_tp_style.py +++ b/test/distributed/tensor/parallel/test_tp_style.py @@ -12,8 +12,7 @@ Replicate, Shard, ) -from torch.distributed._tensor.debug import CommDebugMode -from torch.distributed._tensor.placement_types import _Partial +from torch.distributed.tensor.debug import CommDebugMode from torch.distributed.tensor.parallel import parallelize_module from torch.distributed.tensor.parallel.style import ( ColwiseParallel, @@ -22,6 +21,7 @@ RowwiseParallel, SequenceParallel, ) +from torch.distributed.tensor.placement_types import _Partial from torch.testing._internal.common_utils import run_tests from torch.testing._internal.distributed._tensor.common_dtensor import ( DTensorTestBase, diff --git a/test/distributed/test_device_mesh.py b/test/distributed/test_device_mesh.py index b1c5b4d5883de6..8431bac6f6f913 100644 --- a/test/distributed/test_device_mesh.py +++ b/test/distributed/test_device_mesh.py @@ -5,12 +5,6 @@ import torch import torch.distributed._functional_collectives as funcol from torch.distributed._tensor import DTensor -from torch.distributed._tensor._collective_utils import ( - mesh_broadcast, - mesh_scatter, - unpad_tensor, -) -from torch.distributed._tensor.placement_types import _Partial, Shard from torch.distributed.device_mesh import _mesh_resources, DeviceMesh, init_device_mesh from torch.distributed.distributed_c10d import ( _get_default_group, @@ -22,6 +16,12 @@ is_nccl_available, ProcessGroup, ) +from torch.distributed.tensor._collective_utils import ( + mesh_broadcast, + mesh_scatter, + unpad_tensor, +) +from torch.distributed.tensor.placement_types import _Partial, Shard from torch.testing._internal.common_distributed import skip_if_lt_x_gpu from torch.testing._internal.common_utils import run_tests from torch.testing._internal.distributed._tensor.common_dtensor import ( diff --git a/torch/_dynamo/guards.py b/torch/_dynamo/guards.py index ed0c9893d808b7..1923426bf060f8 100644 --- a/torch/_dynamo/guards.py +++ b/torch/_dynamo/guards.py @@ -1474,12 +1474,12 @@ def EQUALS_MATCH(self, guard: Guard): ) if torch.distributed.is_available(): - from torch.distributed._tensor.placement_types import ( + from torch.distributed.device_mesh import DeviceMesh + from torch.distributed.tensor.placement_types import ( Partial, Replicate, Shard, ) - from torch.distributed.device_mesh import DeviceMesh ok_types = ok_types + ( Shard, diff --git a/torch/_dynamo/trace_rules.py b/torch/_dynamo/trace_rules.py index 7c70ed0761f839..6609b84885ecd8 100644 --- a/torch/_dynamo/trace_rules.py +++ b/torch/_dynamo/trace_rules.py @@ -144,7 +144,7 @@ "torch.distributed.is_initialized": TorchInGraphFunctionVariable, "torch.distributed.get_rank": TorchInGraphFunctionVariable, "torch.distributed.get_world_size": TorchInGraphFunctionVariable, - "torch.distributed._tensor.api.DTensor#from_local": TorchInGraphFunctionVariable, + "torch.distributed.tensor._api.DTensor#from_local": TorchInGraphFunctionVariable, "torch.distributed.distributed_c10d._get_group_size_by_name": TorchInGraphFunctionVariable, "torch.distributed.distributed_c10d._resolve_group_name_by_ranks_and_tag": TorchInGraphFunctionVariable, "torch.distributed.distributed_c10d._get_group_tag": TorchInGraphFunctionVariable, @@ -3200,8 +3200,8 @@ def _module_dir(m: types.ModuleType): if torch.distributed.is_available(): LEGACY_MOD_INLINELIST |= { - "torch.distributed._tensor.api", - "torch.distributed._tensor.device_mesh", + "torch.distributed.tensor._api", + "torch.distributed.tensor.device_mesh", "torch.distributed.device_mesh", "torch.distributed.algorithms._checkpoint.checkpoint_wrapper", "torch.distributed.tensor.parallel._data_parallel_utils", diff --git a/torch/_dynamo/variables/distributed.py b/torch/_dynamo/variables/distributed.py index 1c79ce094822b0..c14b8794cba5f7 100644 --- a/torch/_dynamo/variables/distributed.py +++ b/torch/_dynamo/variables/distributed.py @@ -50,7 +50,7 @@ def is_available(): def is_from_local(value): if not DistributedVariable.is_available(): return False - from torch.distributed._tensor import DTensor + from torch.distributed.tensor import DTensor return inspect.isfunction(value) and value is DTensor.from_local @@ -108,7 +108,7 @@ def is_placement_type(value): if not DistributedVariable.is_available(): return False - from torch.distributed._tensor.placement_types import Placement + from torch.distributed.tensor.placement_types import Placement return type(value) is type and issubclass(value, Placement) @@ -143,7 +143,7 @@ def is_placement(value): if not DistributedVariable.is_available(): return False - from torch.distributed._tensor.placement_types import Placement + from torch.distributed.tensor.placement_types import Placement return isinstance(value, Placement) diff --git a/torch/_dynamo/variables/torch.py b/torch/_dynamo/variables/torch.py index 34f8166c7bbcde..3dfda1cfec83da 100644 --- a/torch/_dynamo/variables/torch.py +++ b/torch/_dynamo/variables/torch.py @@ -621,7 +621,6 @@ def handle_sdpa_params(self, tx: "InstructionTranslator", *args, **kwargs): ) if DistributedVariable.is_available(): - from torch.distributed._tensor import DTensor from torch.distributed.distributed_c10d import ( _get_group_size_by_name, _get_group_tag, @@ -629,6 +628,7 @@ def handle_sdpa_params(self, tx: "InstructionTranslator", *args, **kwargs): _resolve_group_name_by_ranks_and_tag, get_process_group_ranks, ) + from torch.distributed.tensor import DTensor @register( _get_group_size_by_name, diff --git a/torch/distributed/_composable/fsdp/_fsdp_collectives.py b/torch/distributed/_composable/fsdp/_fsdp_collectives.py index 1351d781ec91fd..4e10f4594c1529 100644 --- a/torch/distributed/_composable/fsdp/_fsdp_collectives.py +++ b/torch/distributed/_composable/fsdp/_fsdp_collectives.py @@ -4,8 +4,8 @@ import torch import torch._dynamo.compiled_autograd as ca import torch.distributed as dist -from torch.distributed._tensor import DTensor from torch.distributed.distributed_c10d import ReduceOp +from torch.distributed.tensor import DTensor from ._fsdp_common import ( _get_dim0_padded_size, diff --git a/torch/distributed/_composable/fsdp/_fsdp_common.py b/torch/distributed/_composable/fsdp/_fsdp_common.py index 36b181250f28da..31b74079aaa8be 100644 --- a/torch/distributed/_composable/fsdp/_fsdp_common.py +++ b/torch/distributed/_composable/fsdp/_fsdp_common.py @@ -10,8 +10,8 @@ import torch.distributed as dist import torch.nn as nn from torch.distributed._composable.contract import _get_registry -from torch.distributed._tensor import DeviceMesh, DTensor -from torch.distributed._tensor.placement_types import DTensorSpec +from torch.distributed.tensor import DeviceMesh, DTensor +from torch.distributed.tensor._dtensor_spec import DTensorSpec @dataclass diff --git a/torch/distributed/_composable/fsdp/_fsdp_init.py b/torch/distributed/_composable/fsdp/_fsdp_init.py index 1a137d24b86723..c07e323449f30c 100644 --- a/torch/distributed/_composable/fsdp/_fsdp_init.py +++ b/torch/distributed/_composable/fsdp/_fsdp_init.py @@ -4,8 +4,8 @@ import torch import torch.distributed as dist import torch.nn as nn -from torch.distributed._tensor import DeviceMesh, DTensor, init_device_mesh from torch.distributed.device_mesh import _get_device_handle +from torch.distributed.tensor import DeviceMesh, DTensor, init_device_mesh from torch.utils._python_dispatch import is_traceable_wrapper_subclass from ._fsdp_common import _is_composable_with_fsdp, FSDPMeshInfo, HSDPMeshInfo diff --git a/torch/distributed/_composable/fsdp/_fsdp_param.py b/torch/distributed/_composable/fsdp/_fsdp_param.py index 7bd5dd72badabd..f6fc631ec8ac46 100644 --- a/torch/distributed/_composable/fsdp/_fsdp_param.py +++ b/torch/distributed/_composable/fsdp/_fsdp_param.py @@ -9,14 +9,10 @@ import torch.nn as nn from torch._prims_common import make_contiguous_strides_for from torch.distributed._functional_collectives import AsyncCollectiveTensor -from torch.distributed._tensor import DTensor, Replicate, Shard -from torch.distributed._tensor.device_mesh import _mesh_resources -from torch.distributed._tensor.placement_types import ( - _StridedShard, - DTensorSpec, - Placement, - TensorMeta, -) +from torch.distributed.tensor import DTensor, Replicate, Shard +from torch.distributed.tensor._dtensor_spec import DTensorSpec, TensorMeta +from torch.distributed.tensor.device_mesh import _mesh_resources +from torch.distributed.tensor.placement_types import _StridedShard, Placement from ._fsdp_api import CPUOffloadPolicy, MixedPrecisionPolicy, OffloadPolicy from ._fsdp_common import ( diff --git a/torch/distributed/_composable/fsdp/fully_shard.py b/torch/distributed/_composable/fsdp/fully_shard.py index b06e1a6b777a1a..5da4f83f9e5594 100644 --- a/torch/distributed/_composable/fsdp/fully_shard.py +++ b/torch/distributed/_composable/fsdp/fully_shard.py @@ -6,7 +6,7 @@ import torch import torch.nn as nn from torch.distributed._composable import contract -from torch.distributed._tensor import DeviceMesh +from torch.distributed.tensor import DeviceMesh from torch.distributed.utils import _get_root_modules from ._fsdp_api import MixedPrecisionPolicy, OffloadPolicy diff --git a/torch/distributed/_functional_collectives.py b/torch/distributed/_functional_collectives.py index 2df9435c369a7a..77b962bf3df47c 100644 --- a/torch/distributed/_functional_collectives.py +++ b/torch/distributed/_functional_collectives.py @@ -97,7 +97,7 @@ def is_torchdynamo_compiling(): List[List[int]], dist.ProcessGroup, DeviceMesh, - Tuple["dist._tensor.DeviceMesh", int], + Tuple["dist.tensor.DeviceMesh", int], str, ] diff --git a/torch/distributed/_state_dict_utils.py b/torch/distributed/_state_dict_utils.py index b2694f69d720d0..f22a13cb6ec0f4 100644 --- a/torch/distributed/_state_dict_utils.py +++ b/torch/distributed/_state_dict_utils.py @@ -27,7 +27,7 @@ if dist.is_available() or TYPE_CHECKING: from torch.distributed import distributed_c10d from torch.distributed._shard.sharded_tensor import ShardedTensor - from torch.distributed._tensor import distribute_tensor, DTensor, Replicate + from torch.distributed.tensor import distribute_tensor, DTensor, Replicate def _identity_func( diff --git a/torch/distributed/_tensor/__init__.py b/torch/distributed/_tensor/__init__.py index c741ea5a0c46e9..40f9727015d764 100644 --- a/torch/distributed/_tensor/__init__.py +++ b/torch/distributed/_tensor/__init__.py @@ -1,58 +1,44 @@ -# mypy: allow-untyped-defs -# Copyright (c) Meta Platforms, Inc. and affiliates +""" +NOTICE: DTensor has moved to torch.distributed.tensor -import torch -import torch.distributed._tensor.ops as _ops # force import all built-in dtensor ops -from torch.distributed._tensor.api import ( +This file is a shim to redirect to the new location, and +we keep the old import path starts with `_tensor` for +backward compatibility. We will remove this folder once +we resolve all the BC issues. +""" +import sys +from importlib import import_module + + +submodules = [ + # TODO: _shards_wrapper/_utils here mainly for checkpoint BC, remove them + "_shards_wrapper", + "_utils", + "experimental", + "device_mesh", +] + +# Redirect imports +for submodule in submodules: + full_module_name = f"torch.distributed.tensor.{submodule}" + sys.modules[f"torch.distributed._tensor.{submodule}"] = import_module( + full_module_name + ) + +from torch.distributed.tensor import ( # noqa: F401 + DeviceMesh, distribute_module, distribute_tensor, DTensor, empty, full, + init_device_mesh, ones, - rand, - randn, - zeros, -) -from torch.distributed._tensor.placement_types import ( Partial, Placement, + rand, + randn, Replicate, Shard, + zeros, ) -from torch.distributed.device_mesh import DeviceMesh, init_device_mesh -from torch.optim.optimizer import ( - _foreach_supported_types as _optim_foreach_supported_types, -) -from torch.utils._foreach_utils import ( - _foreach_supported_types as _util_foreach_supported_types, -) - - -# All public APIs from dtensor package -__all__ = [ - "DTensor", - "DeviceMesh", - "distribute_tensor", - "distribute_module", - "init_device_mesh,", - "Shard", - "Replicate", - "Partial", - "Placement", - "ones", - "empty", - "full", - "rand", - "randn", - "zeros", -] - - -# Append DTensor to the list of supported types for foreach implementation for optimizer -# and clip_grad_norm_ so that we will try to use foreach over the for-loop implementation on CUDA. -if DTensor not in _optim_foreach_supported_types: - _optim_foreach_supported_types.append(DTensor) - -if DTensor not in _util_foreach_supported_types: - _util_foreach_supported_types.append(DTensor) diff --git a/torch/distributed/_tensor/api.py b/torch/distributed/_tensor/api.py index 7970cd7fb4bd51..9e5742156a86ca 100644 --- a/torch/distributed/_tensor/api.py +++ b/torch/distributed/_tensor/api.py @@ -1,1233 +1,9 @@ -# mypy: allow-untyped-decorators -# mypy: allow-untyped-defs -# Copyright (c) Meta Platforms, Inc. and affiliates -import inspect -import warnings -from typing import Any, Callable, cast, Optional, Sequence, Tuple +""" +NOTE: torch.distributed._tensor has been moved to torch.distributed.tensor. +The imports here are purely for backward compatibility. We will remove these +imports in a few releases -import torch -import torch.distributed._tensor._dispatch as op_dispatch -import torch.distributed._tensor.random as random -import torch.nn as nn -from torch.distributed._tensor._collective_utils import ( - check_tensor_meta, - mesh_broadcast, -) -from torch.distributed._tensor._redistribute import ( - Redistribute, - redistribute_local_tensor, -) -from torch.distributed._tensor._utils import ( - compute_global_tensor_info, - compute_local_shape, - normalize_to_torch_size, -) -from torch.distributed._tensor.placement_types import ( - DTensorSpec, - Partial, - Placement, - Replicate, - Shard, - TensorMeta, -) -from torch.distributed._tensor.random import ( - is_rng_supported_mesh, - OffsetBasedRNGTracker, -) -from torch.distributed.device_mesh import _mesh_resources, DeviceMesh +TODO: throw warnings when this module imported +""" - -__all__ = [ - "DTensor", - "distribute_tensor", - "distribute_module", - "ones", - "empty", - "full", - "rand", - "randn", - "zeros", -] - -aten = torch.ops.aten - - -# NOTE [Autograd interaction between torch.Tensor] -# -# The autograd functions defined below are being used by the public -# facing APIs (i.e. from_local, to_local) to ensure DTensor to work -# together with torch.Tensor within the autograd engine. This -# allows DTensor to only exist on part of the module hierarchy. -# -# As an example, we have the a module that consists of submodules -# A, B, and C, the execution flow would be like: -# input(torch.Tensor) -> Module A -> Module B -> Module C -> output (torch.Tensor) -# -# Suppose I only want to make Module B be a sharded module with -# DTensor params, the following forward/backward should work: -# -# input(torch.Tensor) -> Module A -# -> DTensor input (from_local) -> Sharded Module B -> DTensor output -# -> torch.Tensor output (to_local) -> Module C -# -# So from_local/to_local must be Autograd functions. -# -class _ToTorchTensor(torch.autograd.Function): - @staticmethod - def forward( # type: ignore[override] - ctx, - input: "DTensor", - grad_placements: Optional[Sequence[Placement]], - ): - ctx.dtensor_spec = input._spec - ctx.grad_placements = grad_placements - local_tensor = input._local_tensor - - # We need to return a fresh Tensor object there as autograd metadata - # will be inplaced into it. So we don't want to pollute the Tensor - # object stored in the _local_tensor of this DTensor. - return local_tensor.view_as(local_tensor) - - @staticmethod - def backward(ctx, grad_output: torch.Tensor): # type: ignore[override] - dtensor_spec = ctx.dtensor_spec - mesh = dtensor_spec.mesh - grad_placements = ctx.grad_placements - dtensor_meta = dtensor_spec.tensor_meta - - _, tensor_stride = compute_global_tensor_info( - grad_output, mesh, dtensor_spec.placements - ) - tensor_stride = tuple(tensor_stride) - grad_placements = grad_placements or dtensor_spec.placements - grad_spec = DTensorSpec( - mesh, - grad_placements, - tensor_meta=TensorMeta( - shape=dtensor_meta.shape, - stride=tensor_stride, - dtype=dtensor_meta.dtype, - ), - ) - - return ( - DTensor( - grad_output, - grad_spec, - requires_grad=grad_output.requires_grad, - ), - None, - ) - - -class _FromTorchTensor(torch.autograd.Function): - @staticmethod - def forward( # type: ignore[override] - ctx, # pyre-ignore[2]: Parameter must be annotated. - input: torch.Tensor, - device_mesh: DeviceMesh, - placements: Tuple[Placement, ...], - run_check: bool, - shape: Optional[torch.Size] = None, - stride: Optional[Tuple[int, ...]] = None, - ) -> "DTensor": - ctx.previous_placement = placements - ctx.previous_device_mesh = device_mesh - - if shape and stride: - tensor_shape, tensor_stride = shape, stride - elif not shape and not stride: - # if it's not by default run_check, we assume user is certain that each - # rank has the same tensor shape, and we just use that to calculate the - # global shape - global_shape, global_stride = compute_global_tensor_info( - input, device_mesh, placements - ) - tensor_shape, tensor_stride = torch.Size(global_shape), tuple(global_stride) - else: - raise RuntimeError( - f"Found shape:{shape}, stride:{stride}.", - "Please pass both shape and stride at the same time.", - ) - - if device_mesh.get_coordinate() is None: - # if the global rank is not participating in the device mesh, we - # simply set the local tensor to an empty tensor - input = input.new_empty(0, requires_grad=input.requires_grad) - elif run_check: - # TODO: support uneven sharding when global shape/stride not passed, by - # building the global TensorMeta during check_tensor_meta - check_shape_stride = not shape and not stride - check_tensor_meta(input, check_shape_stride=check_shape_stride) - # TODO: See if we need to make this run_check logic - # have a corresponding backward. - for idx, placement in enumerate(placements): - if placement.is_replicate(): - # broadcast rank 0 tensor to all ranks - # only broadcast if run_check is True - input = input.contiguous() - mesh_broadcast(input, device_mesh, mesh_dim=idx) - - dist_spec = DTensorSpec( - device_mesh, - placements, - tensor_meta=TensorMeta( - tensor_shape, - tensor_stride, - input.dtype, - ), - ) - - # We want a fresh Tensor object that shares memory with the input tensor - dist_tensor = DTensor( - input.view_as(input), - dist_spec, - # requires_grad of the dist tensor depends on if input - # requires_grad or not - requires_grad=input.requires_grad, - ) - return dist_tensor - - @staticmethod - def backward(ctx, grad_output: "DTensor"): # type: ignore[override] - previous_placement = ctx.previous_placement - previous_device_mesh = ctx.previous_device_mesh - - # reshard to the placement when creating DistributedTensor - # so that the gradient layout matches, and we could return - # local gradients directly - if grad_output.placements != previous_placement: - current_spec = grad_output._spec - target_spec = DTensorSpec( - previous_device_mesh, - previous_placement, - tensor_meta=grad_output._spec.tensor_meta, - ) - local_tensor = grad_output._local_tensor - output = redistribute_local_tensor( - local_tensor, current_spec, target_spec, is_backward=True - ) - # TODO: return the redistributed local tensor directly without - # differentiable backward. see if this make sense for all cases. - return output, None, None, None, None, None - - # TODO: backward is also differentiable now, add a test - # to test higher level gradients. - return grad_output.to_local(), None, None, None, None, None - - -class DTensor(torch.Tensor): - """ - ``DTensor`` (Distributed Tensor) is a subclass of ``torch.Tensor`` that provides single-device like - abstraction to program with multi-device ``torch.Tensor``. It describes the distributed tensor sharding - layout (DTensor Layout) through the :class:`DeviceMesh` and following types of :class:`Placement`: - - * :class:`Shard`: Tensor sharded on the tensor dimension ``dim`` on the devices of the ``DeviceMesh`` dimension - * :class:`Replicate`: Tensor replicated on the devices of the ``DeviceMesh`` dimension - * :class:`Partial`: Tensor is pending reduction on the devices of the ``DeviceMesh`` dimension - - When calling PyTorch operators, ``DTensor`` overrides the PyTorch operators to perform sharded computation and issue - communications whenever necessary. Along with the operator computation, ``DTensor`` will transform or propagate the - placements (DTensor Layout) properly (based on the operator semantic itself) and generate new ``DTensor`` outputs. - - To ensure numerical correctness of the ``DTensor`` sharded computation when calling PyTorch operators, ``DTensor`` - requires every Tensor argument of the operator be DTensor. - - """ - - _local_tensor: torch.Tensor - _spec: DTensorSpec - __slots__ = ["_local_tensor", "_spec"] - - # _op_dispatcher instance as a class attribute to handle runtime dispatching logic - _op_dispatcher: op_dispatch.OpDispatcher = op_dispatch.OpDispatcher() - - @staticmethod - @torch._disable_dynamo - def __new__( - cls, - local_tensor: torch.Tensor, - spec: DTensorSpec, - *, - requires_grad: bool, - ) -> "DTensor": - """ - Construct a DTensor from a local tensor, device mesh, and placement and - other tensor properties (i.e. shape, requires_grad, strides, etc). - Note: This is not a public API and it's only supposed to be used by the - operator implementations and internals. If you want to construct a - DTensor from a local tensor, consider using ``DTensor.from_local``, if - you want to construct a DTensor from a "global" tensor (where you - already have tensor initialized and want to shard this tensor), - consider using ``distribute_tensor``. - """ - if local_tensor.requires_grad and not requires_grad: - warnings.warn( - "To construct DTensor from torch.Tensor, it's recommended to " - "use local_tensor.detach() and make requires_grad consistent." - ) - - # new method instruct wrapper tensor from local_tensor and add - # placement spec, it does not do actual distribution - assert spec.tensor_meta is not None, "TensorMeta should not be None!" - r = torch.Tensor._make_wrapper_subclass( # type: ignore[attr-defined] - cls, - spec.tensor_meta.shape, - strides=spec.tensor_meta.stride, - dtype=local_tensor.dtype, - device=local_tensor.device, - layout=local_tensor.layout, - requires_grad=requires_grad, - ) - - r._spec = spec - r._local_tensor = local_tensor - return r - - # pyre-fixme[14]: `__repr__` overrides method defined in `DTensor` inconsistently. - # pyre-fixme[3]: Return type must be annotated. - def __repr__(self): - # TODO: consider all_gather the local tensors for better debugging - return f"DTensor(local_tensor={self._local_tensor}, device_mesh={self._spec.mesh}, placements={self._spec.placements})" - - def __tensor_flatten__(self): - """ - protocol to inform how to flatten a DTensor to local tensor - for PT2 tracing - """ - return ["_local_tensor"], (self._spec, self.requires_grad) - - @staticmethod - def __tensor_unflatten__(inner_tensors, flatten_spec, outer_size, outer_stride): - assert ( - flatten_spec is not None - ), "Expecting spec to be not None from `__tensor_flatten__` return value!" - local_tensor = inner_tensors["_local_tensor"] - spec, requires_grad = flatten_spec - unflatten_tensor_meta = TensorMeta( - shape=outer_size, - stride=outer_stride, - dtype=spec.tensor_meta.dtype, - ) - unflatten_spec = DTensorSpec( - spec.mesh, - spec.placements, - tensor_meta=unflatten_tensor_meta, - ) - return DTensor( - local_tensor, - unflatten_spec, - requires_grad=requires_grad, - ) - - def __coerce_tangent_metadata__(self): - if not any(isinstance(p, Partial) for p in self.placements): - return self - placements = [ - Replicate() if isinstance(p, Partial) else p for p in self.placements - ] - return self.redistribute(device_mesh=self.device_mesh, placements=placements) - - def __coerce_same_metadata_as_tangent__(self, flatten_spec): - (spec, _) = flatten_spec # Result of tensor_flatten() - return self.redistribute( - device_mesh=self.device_mesh, - placements=spec.placements, - ) - - @classmethod - @torch._disable_dynamo - # pyre-fixme[3]: Return type must be annotated. - # pyre-fixme[2]: Parameter must be annotated. - def __torch_dispatch__(cls, func, types, args=(), kwargs=None): - return DTensor._op_dispatcher.dispatch( - func, - args, - kwargs or {}, - ) - - @staticmethod - def from_local( - local_tensor: torch.Tensor, - device_mesh: Optional[DeviceMesh] = None, - placements: Optional[Sequence[Placement]] = None, - *, - run_check: bool = False, - shape: Optional[torch.Size] = None, - stride: Optional[Tuple[int, ...]] = None, - ) -> "DTensor": - """ - Create a :class:`DTensor` from a local torch.Tensor on each rank - according to the ``device_mesh`` and ``placements`` specified. - - Args: - local_tensor (torch.Tensor): local torch.Tensor on each rank. - device_mesh (:class:`DeviceMesh`, optional): DeviceMesh to place the - tensor, if not specified, must be called under a DeviceMesh - context manager, default: None - placements (List[:class:`Placement`], optional): the placements that - describes how to place the local torch.Tensor on DeviceMesh, must - have the same number of elements as ``device_mesh.ndim``. - - Keyword args: - run_check (bool, optional): at a cost of extra communications, perform - sanity check across ranks to check each local tensor's meta information - to ensure correctness. If have :class:`Replicate` in ``placements``, the - data on first rank of the device mesh dimension will be broadcasted - to other ranks. default: False - shape (torch.Size, optional): A List of int which specifies the size of - DTensor which build on top of `local_tensor`. Note this needs to be - provided if the shape of ``local_tensor`` are different across the ranks. - If not provided, ``shape`` will be computed assuming the given distributed - tensor is evenly sharded across ranks. default: None - stride (tuple, optional): A List of int which specifies the stride of DTensor. - If not provided, ``stride`` will be computed assuming the given distributed - tensor is evenly sharded across ranks. default: None - - Returns: - A :class:`DTensor` object - - .. note:: When ``run_check=False``, it is the user's responsibility to ensure the - local tensor passed in is correct across ranks (i.e. the tensor is sharded for - the ``Shard(dim)`` placement or replicated for the ``Replicate()`` placement). - If not, the behavior of the created DTensor is undefined. - - .. note:: ``from_local`` is differentiable, the `requires_grad` of the created - `DTensor` object will depend on if `local_tensor` requires_grad or not. - """ - # if same shape/dtype, no need to run_check, if not, must allgather - # the metadatas to check the size/dtype across ranks - # There should be no data communication unless there's replication - # strategy, where we broadcast the replication from the first rank - # in the mesh dimension - device_mesh = device_mesh or _mesh_resources.get_current_mesh() - device_type = device_mesh.device_type - - # convert the local tensor to desired device base on device mesh's device_type - if device_type != local_tensor.device.type and not local_tensor.is_meta: - local_tensor = local_tensor.to(device_type) - - # set default placements to replicated if not specified - if placements is None: - placements = [Replicate() for _ in range(device_mesh.ndim)] - else: - placements = list(placements) - for idx, placement in enumerate(placements): - # normalize shard dim to be positive - if placement.is_shard(): - placement = cast(Shard, placement) - if placement.dim < 0: - placements[idx] = Shard(placement.dim + local_tensor.ndim) - - # `from_local` is differentiable, and the gradient of the dist tensor this function - # created should flow back the gradients to the local_tensor, so we call an autograd - # function to construct the dist tensor instead. - return _FromTorchTensor.apply( # pyre-ignore[16]: autograd func - local_tensor, - device_mesh, - tuple(placements), - run_check, - shape, - stride, - ) - - def to_local( - self, *, grad_placements: Optional[Sequence[Placement]] = None - ) -> torch.Tensor: - """ - Get the local tensor of this DTensor on its current rank. For sharding it returns - a local shard of the logical tensor view, for replication it returns the replica on - its current rank. - - Keyword args: - grad_placements (List[:class:`Placement`], optional): the placements describes - the future layout of any gradient layout of the Tensor returned from this - function. - `to_local` converts DTensor to local tensor and the returned local tensor - might not be used as the original DTensor layout later in the code. This - argument is the hint that user can give to autograd in case the gradient - layout of the returned tensor does not match the original DTensor layout. - If not specified, we will assume the gradient layout remains the same - as the original DTensor and use that for gradient computation. - - Returns: - A :class:`torch.Tensor` or ``AsyncCollectiveTensor`` object. it represents the - local tensor on its current rank. When an ``AsyncCollectiveTensor`` object is returned, - it means the local tensor is not ready yet (i.e. communication is not finished). In this - case, user needs to call ``wait`` to wait the local tensor to be ready. - - .. note:: ``to_local`` is differentiable, the ``requires_grad`` of the local tensor returned - will depend on if the `DTensor` requires_grad or not. - """ - if not torch.is_grad_enabled(): - return self._local_tensor - - if grad_placements is not None and not isinstance(grad_placements, tuple): - grad_placements = tuple(grad_placements) - return _ToTorchTensor.apply( - self, grad_placements - ) # pyre-ignore[16]: autograd func - - def redistribute( - self, - device_mesh: Optional[DeviceMesh] = None, - placements: Optional[Sequence[Placement]] = None, - *, - async_op: bool = False, - ) -> "DTensor": - """ - ``redistribute`` performs necessary collective operations that redistribute the current - DTensor from its current placements to a new placements, or from is current DeviceMesh - to a new DeviceMesh. i.e. we can turn a Sharded DTensor to a Replicated DTensor by - specifying a Replicate placement for each dimension of the DeviceMesh. - - When redistributing from current to the new placements on one device mesh dimension, we - will perform the following operations including communication collective or local operation: - - 1. ``Shard(dim)`` -> ``Replicate()``: ``all_gather`` - 2. ``Shard(src_dim)`` -> ``Shard(dst_dim)``: ``all_to_all`` - 3. ``Replicate()`` -> ``Shard(dim)``: local chunking (i.e. ``torch.chunk``) - 4. ``Partial()`` -> ``Replicate()``: ``all_reduce`` - 5. ``Partial()`` -> ``Shard(dim)``: ``reduce_scatter`` - - - ``redistribute`` would correctly figure out the necessary redistribute steps for DTensors - that are created either on 1-D or N-D DeviceMesh. - - Args: - device_mesh (:class:`DeviceMesh`, optional): DeviceMesh to place the - DTensor. If not specified, it would use the current DTensor's DeviceMesh. - default: None - placements (List[:class:`Placement`], optional): the new placements that - describes how to place the DTensor into the DeviceMesh, must - have the same number of elements as ``device_mesh.ndim``. - default: replicate on all mesh dimensions - - Keyword args: - async_op (bool, optional): whether to perform the DTensor redistribute operation - asynchronously or not. Default: False - - Returns: - A :class:`DTensor` object - - .. note:: ``redistribute`` is differentiable, which means user do not need to worry about - the backward formula of the redistribute operation. - - .. note:: ``redistribute`` currently only supports redistributing DTensor on the same DeviceMesh, - Please file an issue if you need to redistribute DTensor to different DeviceMesh. - """ - # NOTE: This redistribute API currently only supports out - # of place redistribution, i.e. it always create a new - # DTensor object and leave the original one unchanged. - - # if device_mesh is not specified, use the current device_mesh - device_mesh = device_mesh or self.device_mesh - # raise error if new placements not specified - if placements is None: - raise RuntimeError("placements is needed for redistribute!") - - placements = list(placements) - for i, placement in enumerate(placements): - if placement.is_partial(): - raise RuntimeError( - "Can not redistribute to Partial, redistributing to Partial is for internal use only!" - ) - elif isinstance(placement, Shard) and placement.dim < 0: - # normalize shard dim to be positive - placements[i] = Shard(placement.dim + self.ndim) - placements = tuple(placements) - - # pyre-fixme[16]: `Redistribute` has no attribute `apply`. - return Redistribute.apply(self, device_mesh, placements, async_op) - - def full_tensor( - self, *, grad_placements: Optional[Sequence[Placement]] = None - ) -> torch.Tensor: - """ - Return the full tensor of this DTensor. It will perform necessary collectives - to gather the local tensors from other ranks in its DeviceMesh and concatenate - them together. It's a syntatic sugar of the following code: - - ``dtensor.redistribute(placements=[Replicate()] * mesh.ndim).to_local()`` - - Keyword args: - grad_placements (List[:class:`Placement`], optional): the placements describes - the future layout of any gradient layout of the full Tensor returned from this - function. - `full_tensor` converts DTensor to a full torch.Tensor and the returned torch.tensor - might not be used as the original replicated DTensor layout later in the code. This - argument is the hint that user can give to autograd in case the gradient - layout of the returned tensor does not match the original replicated DTensor layout. - If not specified, we will assume the gradient layout of the full tensor be replicated. - - Returns: - A :class:`torch.Tensor` object that represents the full tensor of this DTensor. - - .. note:: ``full_tensor`` is differentiable. - """ - - redist_res = self.redistribute( - placements=[Replicate()] * self.device_mesh.ndim, async_op=False - ) - return _ToTorchTensor.apply(redist_res, grad_placements) - - @property - def device_mesh(self) -> DeviceMesh: - """ - The :class:`DeviceMesh` attribute that associates with this DTensor object. - - .. note:: ``device_mesh`` is a read-only property, it can not be set. - """ - return self._spec.mesh - - @property - def placements(self) -> Tuple[Placement, ...]: - """ - The placements attribute of this DTensor that describes the layout of this - DTensor on the its DeviceMesh. - - .. note:: ``placements`` is a read-only property, it can not be set. - """ - return self._spec.placements - - def __create_write_items__(self, fqn: str, object: Any): - from torch.distributed.checkpoint.planner_helpers import ( - _create_write_items_for_dtensor, - ) - - if hasattr(self._local_tensor, "__create_write_items__"): - return self._local_tensor.__create_write_items__(fqn, object) # type: ignore[attr-defined] - elif isinstance(self._local_tensor, torch.Tensor): - return [_create_write_items_for_dtensor(fqn, object)] - else: - raise RuntimeError("Unsupported tensor type!") - - def __create_chunk_list__(self): - from torch.distributed.checkpoint.planner_helpers import ( - _create_chunk_from_dtensor, - ) - - if hasattr(self._local_tensor, "__create_chunk_list__"): - return self._local_tensor.__create_chunk_list__() # type: ignore[attr-defined] - elif isinstance(self._local_tensor, torch.Tensor): - return [_create_chunk_from_dtensor(self)] - else: - raise RuntimeError("Unsupported tensor type!") - - def __get_tensor_shard__(self, index): - if hasattr(self._local_tensor, "__get_tensor_shard__"): - return self._local_tensor.__get_tensor_shard__(index) # type: ignore[attr-defined] - elif isinstance(self._local_tensor, torch.Tensor): - return self.to_local() - else: - raise RuntimeError("Unsupported tensor type!") - - -def distribute_tensor( - tensor: torch.Tensor, - device_mesh: Optional[DeviceMesh] = None, - placements: Optional[Sequence[Placement]] = None, -) -> DTensor: - """ - Distribute a leaf ``torch.Tensor`` (i.e. nn.Parameter/buffers) to the ``device_mesh`` according - to the ``placements`` specified. The rank of ``device_mesh`` and ``placements`` must be the - same. The ``tensor`` to distribute is the logical or "global" tensor, and the API would use - the ``tensor`` from first rank of the DeviceMesh dimension as the source of truth to perserve - the single-device semantic. If you want to construct a DTensor in the middle of the Autograd - computation, please use ``DTensor.from_local`` instead. - - Args: - tensor (torch.Tensor): torch.Tensor to be distributed. Note that if you - want to shard a tensor on a dimension that is not evenly divisible by - the number of devices in that mesh dimension, we use ``torch.chunk`` - semantic to shard the tensor and scatter the shards. The uneven sharding - behavior is experimental and subject to change. - device_mesh (:class:`DeviceMesh`, optional): DeviceMesh to distribute the - tensor, if not specified, must be called under a DeviceMesh context - manager, default: None - placements (List[:class:`Placement`], optional): the placements that - describes how to place the tensor on DeviceMesh, must have the same - number of elements as ``device_mesh.ndim``. If not specified, we will - by default replicate the tensor across the ``device_mesh`` from the - first rank of each dimension of the `device_mesh`. - - Returns: - A :class:`DTensor` or ``XLAShardedTensor`` object. - - .. note:: - When initialize the DeviceMesh with the ``xla`` device_type, ``distribute_tensor`` - return `XLAShardedTensor` instead. see [link](https://github.com/pytorch/pytorch/issues/92909) - for more details. The XLA integration is experimental and subject to change. - """ - - torch._C._log_api_usage_once("torch.dtensor.distribute_tensor") - - # get default device mesh if there's nothing specified - device_mesh = device_mesh or _mesh_resources.get_current_mesh() - device_type = device_mesh.device_type - if device_type == "xla": - try: - # call PyTorch/XLA SPMD for `xla` backend type device mesh. - # This returns XLAShardedTensor - from torch_xla.distributed.spmd import ( # type:ignore[import] - xla_distribute_tensor, - ) - - return xla_distribute_tensor( - tensor, device_mesh, placements - ) # type:ignore[return-value] - except ImportError as e: - msg = "To use DTensor API with xla, you must install the torch_xla package!" - raise ImportError(msg) from e - - # instantiate a RNG tracker if haven't. By default DTensor uses an - # OffsetBasedRNGTracker to perform random operators. - # TODO: the value assignment to global variable is not the ideal solution - # we can replace it in future. - if not random._rng_tracker and is_rng_supported_mesh(device_mesh): - random._rng_tracker = OffsetBasedRNGTracker(device_type) - - if not tensor.is_leaf: - raise RuntimeError( - "`distribute_tensor` should be used to distribute leaf tensors! but found non-leaf tensor!" - ) - - # convert tensor to the corresponding device type if it's not in that device type - if device_type != tensor.device.type and not tensor.is_meta: - tensor = tensor.to(device_type) - - # set default placements to replicated if not specified - if placements is None: - placements = [Replicate() for _ in range(device_mesh.ndim)] - - if len(placements) != device_mesh.ndim: - raise ValueError( - f"`placements` must have the same length as `device_mesh.ndim`! " - f"Found placements length: {len(placements)}, and device_mesh.ndim: {device_mesh.ndim}." - ) - if isinstance(tensor, DTensor): - # if the tensor is already a DTensor, we need to check: - # 1. if the we can further shard this DTensor if the two device mesh belong to - # the same parenet mesh and further sharding is possible. - # 2. check if device mesh and placements are the same - if tensor.device_mesh != device_mesh: - raise ValueError( - f"Cannot distribute a DTensor with device mesh {tensor.device_mesh} " - f"to a different device mesh {device_mesh}." - ) - if tensor.placements != tuple(placements): - raise ValueError( - f"Cannot distribute a DTensor with placements {tensor.placements} " - f"to a different placements {placements}. do you want to call " - f"`redistribute` instead?" - ) - return tensor - - local_tensor = tensor.detach() - - # TODO(xilun): address sharding order - # distribute the tensor according to the placements. - placements = list(placements) - for idx, placement in enumerate(placements): - if placement.is_shard(): - placement = cast(Shard, placement) - if placement.dim < 0: - # normalize shard placement dim - placement = Shard(placement.dim + tensor.ndim) - placements[idx] = placement - local_tensor = placement._shard_tensor(local_tensor, device_mesh, idx) - elif placement.is_replicate(): - placement = cast(Replicate, placement) - local_tensor = placement._replicate_tensor(local_tensor, device_mesh, idx) - else: - raise RuntimeError( - f"Trying to distribute tensor with unsupported placements {placement} on device mesh dimension {idx}!" - ) - placements = tuple(placements) - - assert local_tensor is not None, "distributing a tensor should not be None" - # detach the local tensor passed to DTensor since after the construction - # of DTensor, autograd would work on top of DTensor instead of local tensor - spec = DTensorSpec( - mesh=device_mesh, - placements=placements, - tensor_meta=TensorMeta( - shape=tensor.size(), - stride=tensor.stride(), - dtype=tensor.dtype, - ), - ) - return DTensor( - local_tensor.requires_grad_(tensor.requires_grad), - spec, - requires_grad=tensor.requires_grad, - ) - - -def distribute_module( - module: nn.Module, - device_mesh: Optional[DeviceMesh] = None, - partition_fn: Optional[Callable[[str, nn.Module, DeviceMesh], None]] = None, - input_fn: Optional[Callable[[nn.Module, Any, DeviceMesh], None]] = None, - output_fn: Optional[Callable[[nn.Module, Any, DeviceMesh], None]] = None, -) -> nn.Module: - """ - This function expose three functions to control the parameters/inputs/outputs of the module: - - 1. To perform sharding on the module before runtime execution by specifying the - ``partition_fn`` (i.e. allow user to convert Module parameters to :class:`DTensor` - parameters according to the `partition_fn` specified). - 2. To control the inputs or outputs of the module during runtime execution by - specifying the ``input_fn`` and ``output_fn``. (i.e. convert the input to - :class:`DTensor`, convert the output back to ``torch.Tensor``) - - Args: - module (:class:`nn.Module`): user module to be partitioned. - device_mesh (:class:`DeviceMesh`): the device mesh to place the module. - partition_fn (Callable): the function to partition parameters (i.e. shard certain - parameters across the ``device_mesh``). If ``partition_fn`` is not specified, - by default we replicate all module parameters of ``module`` across the mesh. - input_fn (Callable): specify the input distribution, i.e. could control how the - input of the module is sharded. ``input_fn`` will be installed as a module - ``forward_pre_hook`` (pre forward hook). - output_fn (Callable): specify the output distribution, i.e. could control how the - output is sharded, or convert it back to torch.Tensor. ``output_fn`` will be - installed as a module ``forward_hook`` (post forward hook). - - Returns: - A module that contains parameters/buffers that are all ``DTensor`` s. - - .. note:: - When initialize the DeviceMesh with the ``xla`` device_type, ``distribute_module`` - return nn.Module with PyTorch/XLA SPMD annotated parameters. See [link](https://github.com/pytorch/pytorch/issues/92909) - for more details. The XLA integration is experimental and subject to change. - - """ - - torch._C._log_api_usage_once("torch.dtensor.distribute_module") - - device_mesh = device_mesh or _mesh_resources.get_current_mesh() - device_type = device_mesh.device_type - if device_type == "xla": - try: - # This function annotates all module parameters for auto-partitioning with - # PyTorch/XLA SPMD or explicitly partition to :class:`XLAShardedTensor` parameters - # according to the `partition_fn` specified. - from torch_xla.distributed.spmd import ( # type:ignore[import] - xla_distribute_module, - ) - - return xla_distribute_module( - module, device_mesh, partition_fn, input_fn, output_fn - ) # type:ignore[return-value] - except ImportError as e: - msg = "To use DTensor API with xla, you must install the torch_xla package!" - raise ImportError(msg) from e - - def replicate_module_params_buffers(m: nn.Module, mesh: DeviceMesh) -> None: - # This function loop over the immediate module parameters and - # buffers, replicate all non DTensor params/buffers to DTensor - # parameters/buffers, if they have not been partitioned in the - # partition_fn, we can't easily use `module._apply` here - # because we don't know what happened inside partition_fn as - # user could do anything, i.e. install hooks, and we want to - # preserve those. - full_replicate = [Replicate()] * mesh.ndim - for key, param in m._parameters.items(): - if param is not None and not isinstance(param, DTensor): - m.register_parameter( - key, - nn.Parameter(distribute_tensor(param.data, mesh, full_replicate)), - ) - for key, buffer in m._buffers.items(): - if buffer is not None and not isinstance(buffer, DTensor): - m._buffers[key] = distribute_tensor(buffer, mesh, full_replicate) - - if partition_fn is None: - # if partition_fn not specified, we by default replicate - # all module params/buffers - for name, submod in module.named_modules(): - replicate_module_params_buffers(submod, device_mesh) - else: - # apply partition_fun to submodules - for name, submod in module.named_modules(): - partition_fn(name, submod, device_mesh) - replicate_module_params_buffers(submod, device_mesh) - - # register input_fn as module forward pre hook - if input_fn is not None: - # check the input_fn signature - num_args = len(inspect.signature(input_fn).parameters) - if num_args == 2: - # input_fn only takes in inputs and device mesh - warnings.warn( - "Deprecating input_fn that takes two arguments (inputs, device_mesh), " - "please use input_fn that takes in (module, inputs, device_mesh) instead!", - FutureWarning, - stacklevel=2, - ) - module.register_forward_pre_hook(lambda _, inputs: input_fn(inputs, device_mesh)) # type: ignore[call-arg] - elif num_args == 3: - # input_fn takes in module, inputs, device mesh - module.register_forward_pre_hook( - lambda mod, inputs: input_fn(mod, inputs, device_mesh) - ) - else: - raise ValueError( - f"input_fn should take in 3 arguments, but got {num_args} arguments!" - ) - # register output_fn as module forward hook - if output_fn is not None: - num_args = len(inspect.signature(output_fn).parameters) - if num_args == 2: - # output_fn only takes in outputs and device mesh - warnings.warn( - "Deprecating output_fn that takes two arguments (inputs, device_mesh), " - "please use output_fn that takes in (module, inputs, device_mesh) instead!", - FutureWarning, - stacklevel=2, - ) - module.register_forward_hook( - lambda mod, inputs, outputs: output_fn(outputs, device_mesh) # type: ignore[call-arg] - ) - elif num_args == 3: - module.register_forward_hook( - lambda mod, inputs, outputs: output_fn(mod, outputs, device_mesh) - ) - else: - raise ValueError( - f"output_fn should take in 3 arguments, but got {num_args} arguments!" - ) - - return module - - -# Below are tensor factory function APIs, which are used to create a DTensor directly. We need -# to make separate factory function APIs because tensor subclass could not override the tensor -# factory methods, and we need user to call the factory functions with user intended device_mesh -# and placements to create a proper DTensor. - - -def _dtensor_init_helper( # type: ignore[no-untyped-def] - init_op, - size: torch.Size, - device_mesh: Optional[DeviceMesh] = None, - placements: Optional[Sequence[Placement]] = None, - **kwargs, -) -> DTensor: - # from torch.distributed._tensor.placement_types import DTensorSpec, TensorMeta - - # if device_mesh is None, use the one from mesh resources - device_mesh = device_mesh or _mesh_resources.get_current_mesh() - kwargs["device"] = device_mesh.device_type - - # set default placements to replicated if not specified - placements = placements or tuple(Replicate() for _ in range(device_mesh.ndim)) - - # check device_mesh againts placements - assert device_mesh.ndim == len( - placements - ), "mesh dimension does not match the length of placements" - - assert kwargs["layout"] == torch.strided, "layout value not supported!" - torch_stride = torch._prims_common.make_contiguous_strides_for(size) - - # get local tensor shape - local_shape = compute_local_shape(size, device_mesh, placements) - # initialize the local tensor - if init_op == torch.full: - fill_value = kwargs.pop("fill_value", 0) - local_tensor = init_op(local_shape, fill_value, **kwargs) - elif init_op == torch.rand or init_op == torch.randn: - # this tensor meta is not used except `shape` - dtype = kwargs.get("dtype", torch.get_default_dtype()) - - tensor_meta = TensorMeta(size, (0,), dtype) - spec = DTensorSpec(device_mesh, tuple(placements), tensor_meta=tensor_meta) - - if random.is_rng_supported_mesh(device_mesh) and not random._rng_tracker: - random._rng_tracker = random.OffsetBasedRNGTracker() - - assert random._rng_tracker is not None - with random._rng_tracker._distribute_region(spec): - local_tensor = init_op(local_shape, **kwargs) - else: - local_tensor = init_op(local_shape, **kwargs) - - spec = DTensorSpec( - device_mesh, - tuple(placements), - tensor_meta=TensorMeta( - size, - torch_stride, - local_tensor.dtype, - ), - ) - - return DTensor( - local_tensor, - spec, - requires_grad=kwargs["requires_grad"], - ) - - -def ones( # type: ignore[no-untyped-def] - *size, - dtype: Optional[torch.dtype] = None, - layout: torch.layout = torch.strided, - requires_grad: bool = False, - device_mesh: Optional[DeviceMesh] = None, - placements: Optional[Sequence[Placement]] = None, -) -> DTensor: - """ - Returns a :class:`DTensor` filled with the scalar value 1, with the shape defined - by the variable argument ``size``. - - Args: - size (int...): a sequence of integers defining the shape of the output :class:`DTensor`. - Can be a variable number of arguments or a collection like a list or tuple. - E.g.: ones(1,2,3..) or ones([1,2,3..]) or ones((1,2,3..)) - - Keyword args: - dtype (:class:`torch.dtype`, optional): the desired data type of returned :class:`DTensor`. - Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). - layout (:class:`torch.layout`, optional): the desired layout of returned DTensor. - Default: ``torch.strided``. - requires_grad (bool, optional): If autograd should record operations on the - returned :class:`DTensor`. Default: ``False``. - device_mesh: :class:`DeviceMesh` type, contains the mesh info of ranks - placements: a sequence of :class:`Placement` type: ``Shard``, ``Replicate`` - - Returns: - A :class:`DTensor` object on each rank - """ - torch_size = normalize_to_torch_size(size) - - return _dtensor_init_helper( - torch.ones, - torch_size, - dtype=dtype, - layout=layout, - requires_grad=requires_grad, - device_mesh=device_mesh, - placements=placements, - ) - - -def empty( # type: ignore[no-untyped-def] - *size, - dtype: Optional[torch.dtype] = None, - layout: torch.layout = torch.strided, - requires_grad: bool = False, - device_mesh: Optional[DeviceMesh] = None, - placements: Optional[Sequence[Placement]] = None, -) -> DTensor: - """ - Returns a :class:`DTensor` filled with uninitialized data. The shape of the :class:`DTensor` - is defined by the variable argument ``size``. - - Args: - size (int...): a sequence of integers defining the shape of the output :class:`DTensor`. - Can be a variable number of arguments or a collection like a list or tuple. - E.g.: empty(1,2,3..) or empty([1,2,3..]) or empty((1,2,3..)) - - Keyword args: - dtype (:class:`torch.dtype`, optional): the desired data type of returned :class:`DTensor`. - Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`).\ - layout (:class:`torch.layout`, optional): the desired layout of returned :class:`DTensor`. - Default: ``torch.strided``. - requires_grad (bool, optional): If autograd should record operations on the - returned :class:`DTensor`. Default: ``False``. - device_mesh: :class:`DeviceMesh` type, contains the mesh info of ranks - placements: a sequence of :class:`Placement` type: ``Shard``, ``Replicate`` - - Returns: - A :class:`DTensor` object on each rank - """ - torch_size = normalize_to_torch_size(size) - - return _dtensor_init_helper( - torch.empty, - torch_size, - dtype=dtype, - layout=layout, - requires_grad=requires_grad, - device_mesh=device_mesh, - placements=placements, - ) - - -def full( # type: ignore[no-untyped-def] - size, - fill_value, - *, - dtype: Optional[torch.dtype] = None, - layout: torch.layout = torch.strided, - requires_grad: bool = False, - device_mesh: Optional[DeviceMesh] = None, - placements: Optional[Sequence[Placement]] = None, -) -> DTensor: - """ - Returns a :class:`DTensor` filled with ``fill_value``. The scalar value type should match - ``device_mesh.device_type``. - - Args: - size (int...): a sequence of integers defining the shape of the output :class:`DTensor`. - Can be a variable number of arguments or a collection like a list or tuple. - E.g.: ones(1,2,3..) or ones([1,2,3..]) or ones((1,2,3..)) - fill_value(Scalar): the value to fill the output tensor with. - - Keyword args: - dtype (:class:`torch.dtype`, optional): the desired data type of returned :class:`DTensor`. - Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). - layout (:class:`torch.layout`, optional): the desired layout of returned DTensor. - Default: ``torch.strided``. - requires_grad (bool, optional): If autograd should record operations on the - returned :class:`DTensor`. Default: ``False``. - device_mesh: :class:`DeviceMesh` type, contains the mesh info of ranks. - placements: a sequence of :class:`Placement` type: ``Shard``, ``Replicate`` - - Returns: - A :class:`DTensor` object on each rank - """ - torch_size = normalize_to_torch_size(size) - - return _dtensor_init_helper( - torch.full, - torch_size, - fill_value=fill_value, - dtype=dtype, - layout=layout, - requires_grad=requires_grad, - device_mesh=device_mesh, - placements=placements, - ) - - -def rand( # type: ignore[no-untyped-def] - *size, - requires_grad: bool = False, - dtype: Optional[torch.dtype] = None, - layout: torch.layout = torch.strided, - device_mesh: Optional[DeviceMesh] = None, - placements: Optional[Sequence[Placement]] = None, -) -> DTensor: - """ - Returns a :class:`DTensor` filled with random numbers from a uniform distribution - on the interval ``[0, 1)``. The shape of the tensor is defined by the variable - argument ``size``. - - Args: - size (int...): a sequence of integers defining the shape of the output :class:`DTensor`. - Can be a variable number of arguments or a collection like a list or tuple. - E.g.: ones(1,2,3..) or ones([1,2,3..]) or ones((1,2,3..)) - - Keyword args: - dtype (:class:`torch.dtype`, optional): the desired data type of returned :class:`DTensor`. - Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). - layout (:class:`torch.layout`, optional): the desired layout of returned DTensor. - Default: ``torch.strided``. - requires_grad (bool, optional): If autograd should record operations on the - returned :class:`DTensor`. Default: ``False``. - device_mesh: :class:`DeviceMesh` type, contains the mesh info of ranks. - placements: a sequence of :class:`Placement` type: ``Shard``, ``Replicate`` - - Returns: - A :class:`DTensor` object on each rank - """ - torch_size = normalize_to_torch_size(size) - - return _dtensor_init_helper( - torch.rand, - torch_size, - dtype=dtype, - layout=layout, - requires_grad=requires_grad, - device_mesh=device_mesh, - placements=placements, - ) - - -def randn( # type: ignore[no-untyped-def] - *size, - requires_grad: bool = False, - dtype: Optional[torch.dtype] = None, - layout: torch.layout = torch.strided, - device_mesh: Optional[DeviceMesh] = None, - placements: Optional[Sequence[Placement]] = None, -) -> DTensor: - """ - Returns a :class:`DTensor` filled with random numbers from a normal distribution - with mean 0 and variance 1. The shape of the tensor is defined by the variable - argument ``size``. - - Args: - size (int...): a sequence of integers defining the shape of the output :class:`DTensor`. - Can be a variable number of arguments or a collection like a list or tuple. - E.g.: ones(1,2,3..) or ones([1,2,3..]) or ones((1,2,3..)) - - Keyword args: - dtype (:class:`torch.dtype`, optional): the desired data type of returned :class:`DTensor`. - Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). - layout (:class:`torch.layout`, optional): the desired layout of returned DTensor. - Default: ``torch.strided``. - requires_grad (bool, optional): If autograd should record operations on the - returned :class:`DTensor`. Default: ``False``. - device_mesh: :class:`DeviceMesh` type, contains the mesh info of ranks. - placements: a sequence of :class:`Placement` type: ``Shard``, ``Replicate`` - - Returns: - A :class:`DTensor` object on each rank - """ - torch_size = normalize_to_torch_size(size) - - return _dtensor_init_helper( - torch.randn, - torch_size, - dtype=dtype, - layout=layout, - requires_grad=requires_grad, - device_mesh=device_mesh, - placements=placements, - ) - - -def zeros( # type: ignore[no-untyped-def] - *size, - requires_grad: bool = False, - dtype: Optional[torch.dtype] = None, - layout: torch.layout = torch.strided, - device_mesh: Optional[DeviceMesh] = None, - placements: Optional[Sequence[Placement]] = None, -) -> DTensor: - """ - Returns a :class:`DTensor` filled with the scalar value 0. - - Args: - size (int...): a sequence of integers defining the shape of the output :class:`DTensor`. - Can be a variable number of arguments or a collection like a list or tuple. - E.g.: zeros(1,2,3..) or zeros([1,2,3..]) or zeros((1,2,3..)) - Keyword args: - requires_grad (bool, optional): If autograd should record operations on the - returned :class:`DTensor`. Default: ``False``. - dtype (:class:`torch.dtype`, optional): the desired data type of returned :class:`DTensor`. - Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). - layout (:class:`torch.layout`, optional): the desired layout of returned :class:`DTensor`. - Default: ``torch.strided``. - device_mesh: :class:`DeviceMesh` type, contains the mesh info of ranks - placements: a sequence of :class:`Placement` type: ``Shard``, ``Replicate`` - - Returns: - A :class:`DTensor` object on each rank - """ - torch_size = normalize_to_torch_size(size) - - return _dtensor_init_helper( - torch.zeros, - torch_size, - dtype=dtype, - layout=layout, - requires_grad=requires_grad, - device_mesh=device_mesh, - placements=placements, - ) +from torch.distributed.tensor._api import * # noqa: F401, F403 diff --git a/torch/distributed/_tensor/experimental/__init__.py b/torch/distributed/_tensor/experimental/__init__.py deleted file mode 100644 index 0b699eaf878d2c..00000000000000 --- a/torch/distributed/_tensor/experimental/__init__.py +++ /dev/null @@ -1,19 +0,0 @@ -# mypy: allow-untyped-defs -# Copyright (c) Meta Platforms, Inc. and affiliates -from contextlib import contextmanager - -from torch.distributed._tensor.api import DTensor -from torch.distributed._tensor.experimental.func_map import local_map -from torch.distributed._tensor.experimental.register_sharding import register_sharding - - -__all__ = ["implicit_replication", "local_map", "register_sharding"] - - -@contextmanager -def implicit_replication(): - try: - DTensor._op_dispatcher._allow_implicit_replication = True - yield - finally: - DTensor._op_dispatcher._allow_implicit_replication = False diff --git a/torch/distributed/_tensor/placement_types.py b/torch/distributed/_tensor/placement_types.py index 135dd8449a702a..6a4e70dbba4554 100644 --- a/torch/distributed/_tensor/placement_types.py +++ b/torch/distributed/_tensor/placement_types.py @@ -1,910 +1,10 @@ -# mypy: allow-untyped-defs -# Copyright (c) Meta Platforms, Inc. and affiliates +""" +NOTE: torch.distributed._tensor has been moved to torch.distributed.tensor. +The imports here are purely for backward compatibility. We will remove these +imports in a few releases -from dataclasses import dataclass -from typing import Any, cast, List, NamedTuple, Optional, Tuple +TODO: throw warnings when this module imported +""" -import torch -import torch.distributed._functional_collectives as funcol -from torch.distributed._tensor._collective_utils import ( - fill_empty_tensor_to_shards, - mesh_broadcast, - mesh_scatter, - pad_tensor, - shard_dim_alltoall, - unpad_tensor, -) -from torch.distributed.device_mesh import DeviceMesh - - -__all__ = ["Placement", "Shard", "Replicate", "Partial", "DTensorSpec", "TensorMeta"] - - -class Placement: - # base class Placement type - - # convenient utils to check for placement types - def is_shard(self, dim: Optional[int] = None) -> bool: - is_shard_instance = isinstance(self, Shard) - if dim is not None and is_shard_instance: - return cast(Shard, self).dim == dim - else: - return is_shard_instance - - def is_replicate(self) -> bool: - return isinstance(self, Replicate) - - def is_partial(self) -> bool: - return isinstance(self, Partial) - - -@dataclass(frozen=True) -class Shard(Placement): - """ - The ``Shard(dim)`` placement describes the DTensor sharding on tensor dimension - ``dim`` over a corresponding ``DeviceMesh`` dimension, where each rank on the - DeviceMesh dimension only holds a shard/piece of the global Tensor. The - ``Shard(dim)`` placement follows the ``torch.chunk(dim)`` semantic, where the - last few shards on the DeviceMesh dimension might be empty when the tensor dimension - is not evenly divisble on the DeviceMesh dimension. The ``Shard`` placement can be - used by all DTensor APIs (i.e. distribute_tensor, from_local, etc.) - - Args: - dim (int): The tensor dimension that describes the DTensor is sharded over its - corresponding DeviceMesh dimension. - - .. warning:: sharding on a tensor dimension where the tensor dimension size is not - evenly divisible on a DeviceMesh dimension is currently experimental and subject to change. - """ - - dim: int - - def _split_tensor( - self, - tensor: torch.Tensor, - num_chunks: int, - *, - with_padding: bool = True, - contiguous: bool = True, - ) -> Tuple[List[torch.Tensor], List[int]]: - """ - This function uses torch.chunk to split a tensor into num_chunks shards along - the Shard placement dimension, and return a list of shards with their pad sizes. - - Keyword args: - with_padding (bool, optional): when True, we pad the tensor on the last - few ranks before calling the collectives (i.e. scatter/all_gather, etc.). - This is because collectives usually require equal size tensor inputs - """ - assert ( - self.dim <= tensor.ndim - ), f"Sharding dim {self.dim} greater than tensor ndim {tensor.ndim}" - - # chunk tensor over dimension `dim` into n slices - tensor_list = list(torch.chunk(tensor, num_chunks, dim=self.dim)) - num_empty_tensors = num_chunks - len(tensor_list) - - # if no need to have padding or tensor dim size is evenly sharded already - # we can return early. - if not with_padding or tensor.size(self.dim) % num_chunks == 0: - if contiguous: - tensor_list = [t.contiguous() for t in tensor_list] - return ( - fill_empty_tensor_to_shards(tensor_list, self.dim, num_empty_tensors), - [], - ) - - # compute the chunk size inline with ``torch.chunk`` to calculate padding - full_chunk_size = (tensor.size(self.dim) + num_chunks - 1) // num_chunks - - # Compute chunk size for each chunk for ``self.dim`` - chunk_sizes = [ - tensor_list[idx].size(self.dim) if idx < len(tensor_list) else 0 - for idx in range(num_chunks) - ] - # Compute pad size on each chunk - pad_sizes = [full_chunk_size - chunk_size for chunk_size in chunk_sizes] - - # Reuse tensor to fill empty chunk with empty tensor - tensor_list = fill_empty_tensor_to_shards( - tensor_list, self.dim, num_empty_tensors - ) - shard_list = [] - for shard, pad_size in zip(tensor_list, pad_sizes): - # Fill the empty tensor with zeroes with padding. - if with_padding and pad_size > 0: - shard = pad_tensor(shard, self.dim, pad_size) - shard = shard.contiguous() if contiguous else shard - shard_list.append(shard) - return shard_list, pad_sizes - - @staticmethod - def _local_shard_size_on_dim( - size_on_dim: int, - num_chunks: int, - rank: int, - return_offset: bool = False, - ) -> Tuple[int, int]: - """ - returns the local shard size and offset on a given tensor dim - """ - # Compute the chunk size inline with ``torch.chunk`` - if size_on_dim % num_chunks == 0: - full_chunk_size = size_on_dim // num_chunks - return full_chunk_size, full_chunk_size * rank if return_offset else -1 - - # uneven sharding case - full_chunk_size = (size_on_dim + num_chunks - 1) // num_chunks - shard_starting_idx = full_chunk_size * rank - - if size_on_dim < shard_starting_idx: - return 0, size_on_dim if return_offset else -1 - else: - local_shard_size = ( - min(size_on_dim, shard_starting_idx + full_chunk_size) - - shard_starting_idx - ) - return local_shard_size, shard_starting_idx if return_offset else -1 - - def _shard_tensor( - self, tensor: torch.Tensor, mesh: DeviceMesh, mesh_dim: int - ) -> torch.Tensor: - """ - shard and scatter a tensor on a mesh dimension (use coordinate - 0 on the mesh dimension as source of truth) - """ - my_coordinate = mesh.get_coordinate() - num_chunks = mesh.size(mesh_dim=mesh_dim) - - if my_coordinate is None: - # if rank is not part of mesh, we simply return an empty tensor - return tensor.new_empty(0, requires_grad=tensor.requires_grad) - - scatter_list, pad_sizes = self._split_tensor( - tensor, num_chunks, with_padding=True, contiguous=True - ) - - mesh_dim_local_rank = my_coordinate[mesh_dim] - output = torch.empty_like(scatter_list[mesh_dim_local_rank]) - mesh_scatter(output, scatter_list, mesh, mesh_dim=mesh_dim) - - # Only unpad if the local_tensor was padded on the dimension. - if pad_sizes and pad_sizes[mesh_dim_local_rank] > 0: - output = unpad_tensor(output, self.dim, pad_sizes[mesh_dim_local_rank]) - return output - - def _reduce_shard_tensor( - self, - tensor: torch.Tensor, - mesh: DeviceMesh, - reduce_op: str, - mesh_dim: int, - ) -> torch.Tensor: - """ - reduce and scatter a tensor on a mesh dimension - """ - my_coordinate = mesh.get_coordinate() - num_chunks = mesh.size(mesh_dim=mesh_dim) - - if my_coordinate is None: - # if rank is not part of mesh, we simply return local_tensor, - # which should be an empty tensor - return tensor - - is_padded = tensor.size(self.dim) % num_chunks != 0 - if is_padded: - scattered_list, pad_sizes = self._split_tensor( - tensor, num_chunks, with_padding=True, contiguous=True - ) - tensor = torch.cat(scattered_list, dim=self.dim) - elif not tensor.is_contiguous(): - tensor = tensor.contiguous() - - output = funcol.reduce_scatter_tensor( - tensor, reduce_op, scatter_dim=self.dim, group=(mesh, mesh_dim) - ) - - if is_padded: - output = unpad_tensor(output, self.dim, pad_sizes[my_coordinate[mesh_dim]]) # type: ignore[possibly-undefined] - return output - - def _to_replicate_tensor( - self, - local_tensor: torch.Tensor, - mesh: DeviceMesh, - mesh_dim: int, - current_logical_shape: List[int], - ) -> torch.Tensor: - """ - This function all_gather all shards and return a tensor that - is replicated on the previously sharded mesh dimension - """ - num_chunks = mesh.size(mesh_dim=mesh_dim) - # check if it's uneven, so we need to pad input tensor before all_gather - local_shape = list(local_tensor.size()) - - logical_dim_size = current_logical_shape[self.dim] - is_padded = logical_dim_size % num_chunks != 0 - - if is_padded: - full_chunk_size = (logical_dim_size + num_chunks - 1) // num_chunks - pad_size = full_chunk_size - local_shape[self.dim] - local_tensor = pad_tensor(local_tensor, self.dim, pad_size) - - if not local_tensor.is_contiguous(): - local_tensor = local_tensor.contiguous() - - result = funcol.all_gather_tensor( - local_tensor, - gather_dim=self.dim, - group=(mesh, mesh_dim), - ) - if is_padded: - unpad_size = full_chunk_size * num_chunks - logical_dim_size # type: ignore[possibly-undefined] - result = unpad_tensor(result, self.dim, unpad_size) - return result - - def _replicate_to_shard( - self, - local_tensor: torch.Tensor, - mesh: DeviceMesh, - mesh_dim: int, - shard_index: int, - ) -> torch.Tensor: - """ - transform from replicated tensor to a sharded tensor on - the current rank, which would perform a local chunk - """ - num_chunks = mesh.size(mesh_dim=mesh_dim) - shards, _ = self._split_tensor( - local_tensor, - num_chunks, - with_padding=False, - contiguous=False, - ) - return shards[shard_index].clone() - - def _to_new_shard_dim( - self, - local_tensor: torch.Tensor, - mesh: DeviceMesh, - mesh_dim: int, - current_logical_shape: List[int], - new_shard_dim: int, - ) -> torch.Tensor: - """ - transform from existing sharded tensor to a new sharded tensor on - that shard on a new dimension, which performs an alltoall - """ - my_coordinate = mesh.get_coordinate() - if my_coordinate is None: - # if rank is not part of mesh, we simply return local_tensor, - # which should be an empty tensor - return local_tensor - - num_chunks = mesh.size(mesh_dim=mesh_dim) - - old_dim_logical_size = current_logical_shape[self.dim] - new_dim_logical_size = current_logical_shape[new_shard_dim] - old_dim_padding = old_dim_logical_size % num_chunks != 0 - new_dim_padding = new_dim_logical_size % num_chunks != 0 - if old_dim_padding: - old_dim_full_chunk_size = ( - old_dim_logical_size + num_chunks - 1 - ) // num_chunks - old_dim_pad_size = old_dim_full_chunk_size - local_tensor.size(self.dim) - local_tensor = pad_tensor(local_tensor, self.dim, old_dim_pad_size) - if new_dim_padding: - new_dim_full_chunk_size = ( - new_dim_logical_size + num_chunks - 1 - ) // num_chunks - new_dim_pad_size = new_dim_full_chunk_size * num_chunks - local_tensor.size( - new_shard_dim - ) - local_tensor = pad_tensor(local_tensor, new_shard_dim, new_dim_pad_size) - - if not local_tensor.is_contiguous(): - local_tensor = local_tensor.contiguous() - - new_tensor = shard_dim_alltoall( - local_tensor, self.dim, new_shard_dim, mesh, mesh_dim - ) - - if old_dim_padding: - old_dim_unpad_size = ( - old_dim_full_chunk_size * num_chunks - current_logical_shape[self.dim] # type: ignore[possibly-undefined] - ) - new_tensor = unpad_tensor(new_tensor, self.dim, old_dim_unpad_size) # type: ignore[possibly-undefined] - - if new_dim_padding: - local_shard_size_on_new_dim = self._local_shard_size_on_dim( - new_dim_logical_size, num_chunks, my_coordinate[mesh_dim] - )[0] - new_dim_unpad_size = new_dim_full_chunk_size - local_shard_size_on_new_dim # type: ignore[possibly-undefined] - new_tensor = unpad_tensor(new_tensor, new_shard_dim, new_dim_unpad_size) # type: ignore[possibly-undefined] - - return new_tensor - - def __eq__(self, other: object) -> bool: - if not isinstance(other, Shard): - return False - return self.dim == other.dim - - def __hash__(self) -> int: - return hash(self.dim) - - def __repr__(self) -> str: - """ - machine readable representation of the Shard placement - """ - return f"Shard(dim={self.dim})" - - def __str__(self) -> str: - """human readable representation of the Shard placement""" - return f"S({self.dim})" - - -# kw_only is only available in python >= 3.10 -kw_only_dataclass = dict(kw_only=True) if "kw_only" in dataclass.__kwdefaults__ else {} - - -@dataclass(frozen=True, **kw_only_dataclass) -class _StridedShard(Shard): - """ - _StridedShard is only introduced to support 2D FSDP2 + TP sharding where the tensor - is sharded on the TP mesh dimension first, then sharded on the FSDP mesh dimension. - We call this right-to-left sharding which is the opposite of the default - left-to-right sharding. See the example below: - tensor shape: [8, 8] - mesh: [[0, 1], [2, 3]], names=("dp", "tp") - placements: [Shard(0), Shard(0)] - - The default sharding behavior shards the tensor on "dp" mesh dimension first then - "tp" dimension. The sharding result will be: - Rank | Mesh Coordinate | Shard Index - ------------------------------------------------ - 0 | (0, 0) | 0 (row 0-1) - 1 | (0, 1) | 1 (row 2-3) - 2 | (1, 0) | 2 (row 4-5) - 3 | (1, 1) | 3 (row 6-7) - - While the FSDP2 + TP sharding behavior does the opposite: it shards the tensor on - "tp" mesh dim first then "dp" dim. This right-to-left sharding will produce the - result: - Rank | Mesh Coordinate | Shard Index - ------------------------------------------------ - 0 | (0, 0) | 0 (row 0-1) - 1 | (0, 1) | 2 (row 4-5) - 2 | (1, 0) | 1 (row 2-3) - 3 | (1, 1) | 3 (row 6-7) - - The consequence is, any attempt to redistribute this DTensor to a full replica will - produce a wrong result because the shard-to-replicate redistribution always happens - right-to-left, regardless it's left-to-right sharding or right-to-left. To address - this, we use _StridedShard placement to make this right-to-left sharding compatible - with our left-to-right convention on both tensor distribution and redistribution. - - Now with _StridedShard, the right-to-left sharding above can be represented as: - tensor shape: [8, 8] - mesh: [[0, 1], [2, 3]], names=("dp", "tp") - placements: [_StridedShard(0, split_factor=2), Shard(0)] - - And a left-to-right processing of `placements` will produce the same result, which is - different from using the `Shard` placement: - Rank | Mesh Coordinate | Shard Index - ------------------------------------------------ - 0 | (0, 0) | 0 (row 0-1) - 1 | (0, 1) | 2 (row 4-5) - 2 | (1, 0) | 1 (row 2-3) - 3 | (1, 1) | 3 (row 6-7) - - The argument `split_factor` is the number of existing shards over the tensor sharding - dimension before processing the _StridedShard placement, as if the sharding happened - right-to-left. In the example above, the tensor should first be sharded on the "tp" - dimension into 2 shards before being sharded on the "dp" dimension. Therefore, the - `split_factor` of the _StridedShard placement on "dp" dim is 2. - - TODO: strided sharding needs to work fine with uneven sharding. Now it forbids - resharding if the tensor is unevenly sharded. - TODO: we should remove _StridedShard placement once we can unify it with Shard - """ - - split_factor: int - - def __eq__(self, other: object) -> bool: - if isinstance(other, _StridedShard): - return self.dim == other.dim and self.split_factor == other.split_factor - elif isinstance(other, Shard): - # TODO: this is to avoid extra all-gather in dtensor op dispatch - # note that sharding prop would not produce _StridedShard and an - # placement inequality would introduce an all-gather for resharding - return self.dim == other.dim - return False - - def __hash__(self) -> int: - return hash((self.dim, self.split_factor)) - - def __repr__(self) -> str: - """ - machine readable representation of the _StridedShard placement - """ - return f"_StridedShard(dim={self.dim}, sf={self.split_factor})" - - def __str__(self) -> str: - """human readable representation of the _StridedShard placement""" - return f"_S({self.dim}, {self.split_factor})" - - def _split_tensor( - self, - tensor: torch.Tensor, - num_chunks: int, - *, - with_padding: bool = True, - contiguous: bool = True, - ) -> Tuple[List[torch.Tensor], List[int]]: - """ - TODO: currently _StridedShard does not support padding - """ - assert ( - self.dim <= tensor.ndim - ), f"Sharding dim {self.dim} greater than tensor ndim {tensor.ndim}" - - total_split = num_chunks * self.split_factor - assert tensor.size(self.dim) % total_split == 0, ( - "_StridedShard currently only allows even sharding but got tensor size" - f" {tensor.size(self.dim)} on dim {self.dim} and total split" - f" {total_split}={num_chunks} * {self.split_factor}" - ) - - group_size = self.split_factor - total_split_tensor_list = list(torch.chunk(tensor, total_split, dim=self.dim)) - tensor_list = [ - torch.cat( - [ - total_split_tensor_list[i + j * num_chunks] # stride is num_chunks - for j in range(group_size) - ], - dim=self.dim, - ) - for i in range(num_chunks) - ] - - if contiguous: - tensor_list = [t.contiguous() for t in tensor_list] - - return tensor_list, [] - - def _to_replicate_tensor( - self, - local_tensor: torch.Tensor, - mesh: DeviceMesh, - mesh_dim: int, - current_logical_shape: List[int], - ) -> torch.Tensor: - """ - Note: currently _StridedShard does not support padding - """ - num_chunks = mesh.size(mesh_dim=mesh_dim) - total_split = num_chunks * self.split_factor - # NOTE: we require Strided Sharding to be even for now - assert current_logical_shape[self.dim] % total_split == 0, ( - "_StridedShard requires even sharding but got tensor size " - f"{current_logical_shape[self.dim]} on dim {self.dim} and " - f"total split {total_split}=num_chunks {num_chunks} " - f"* split_factor {self.split_factor}" - ) - - result = funcol.all_gather_tensor( - local_tensor, - gather_dim=self.dim, - group=(mesh, mesh_dim), - ) - if isinstance(result, funcol.AsyncCollectiveTensor): - result = result.wait() - - tensor_shard_list = torch.chunk(result, total_split, dim=self.dim) - # rearrange the order - new_tensor_shard_list = [] - for idx in range(len(tensor_shard_list)): - # the shard split of index `idx` is assigned a new index within - # _StridedShard._split_tensor: - # the original tensor was split into `total_split` chunks, - # all chunks with the same `idx % num_chunks` are merged into one - # new shard and placed on mesh's local rank `idx % num_chunks` - idx_after_split = idx % num_chunks * self.split_factor + idx // num_chunks - new_tensor_shard_list.append(tensor_shard_list[idx_after_split]) - - return torch.cat(new_tensor_shard_list, dim=self.dim).contiguous() - - -@dataclass(frozen=True) -class Replicate(Placement): - """ - The ``Replicate()`` placement describes the DTensor replicating on a corresponding - ``DeviceMesh`` dimension, where each rank on the DeviceMesh dimension holds a - replica of the global Tensor. The ``Replicate`` placement can be used by all - DTensor APIs (i.e. ``distribute_tensor``, ``DTensor.from_local``, etc.) - """ - - def __eq__(self, other: object) -> bool: - return isinstance(other, Replicate) - - def __hash__(self) -> int: - # every replicate placement is the same - return -1 - - def __repr__(self) -> str: - """ - machine readable representation of the Replicate placement - """ - return "Replicate()" - - def __str__(self) -> str: - """ - human readable representation of the Replicate placement - """ - return "R" - - def _replicate_tensor( - self, tensor: torch.Tensor, mesh: DeviceMesh, mesh_dim: int - ) -> torch.Tensor: - """ - Replicate (broadcast) a torch.Tensor on a mesh dimension (use - the first coordinate on the mesh dimension as source of truth) - """ - my_coordinate = mesh.get_coordinate() - if my_coordinate is None: - # if rank is not part of mesh, we simply return an empty tensor - return tensor.new_empty(0, requires_grad=tensor.requires_grad) - - tensor = tensor.contiguous() - mesh_broadcast(tensor, mesh, mesh_dim=mesh_dim) - return tensor - - -@dataclass(frozen=True) -class Partial(Placement): - """ - The ``Partial(reduce_op)`` placement describes the DTensor that is pending - reduction on a specified ``DeviceMesh`` dimension, where each rank on the - DeviceMesh dimension holds the partial value of the global Tensor. User can - redistribute the ``Partial`` DTensor to a ``Replicate`` or ``Shard(dim)`` - placement on the specified ``DeviceMesh`` dimension using ``redistribute``, - which would trigger necessary communication operations under the hood (i.e. - ``allreduce``, ``reduce_scatter``). - - Args: - reduce_op (str, optional): The reduction op to be used for the partial DTensor - to produce Replicated/Sharded DTensor. Only element-wise reduction operations - are supported, including: "sum", "avg", "product", "max", "min", default: "sum". - - .. note:: The ``Partial`` placement can be generated as a result of the DTensor operators, - and can only be used by the ``DTensor.from_local`` API. - """ - - reduce_op: str = "sum" - - def _reduce_value( - self, tensor: torch.Tensor, mesh: DeviceMesh, mesh_dim: int - ) -> torch.Tensor: - # Partial placement contract #1: - # _reduce_value: reduce the value of the tensor on the mesh dimension - return funcol.all_reduce( - tensor, reduceOp=self.reduce_op, group=(mesh, mesh_dim) - ) - - def _reduce_shard_value( - self, - tensor: torch.Tensor, - mesh: DeviceMesh, - mesh_dim: int, - shard_spec: Placement, - ) -> torch.Tensor: - # Partial placement contract #2: - # _reduce_shard_value: reduce_scatter the value of the tensor over the mesh dimension - shard_spec = cast(Shard, shard_spec) - return shard_spec._reduce_shard_tensor(tensor, mesh, self.reduce_op, mesh_dim) - - def _partition_value( - self, tensor: torch.Tensor, mesh: DeviceMesh, mesh_dim: int - ) -> torch.Tensor: - # Partial placement contract #3: - # _partition_value: partition the value of a replicated tensor on the mesh dimension - - # _partition_value is the conjugate operation of _reduce_value - # - i.e. _partition_value on a sum reduce op is just a divison operation - # - the _reduce_value on a sum reduce op would just be a sum(allreduce) operation - # TODO: if the reduce_op is min/max, etc. the _partition_value should be a - # different operation - assert self.reduce_op == "sum", "only support replicate to PartialSUM for now!" - num_chunks = mesh.size(mesh_dim=mesh_dim) - return tensor / num_chunks - - def __eq__(self, other: object) -> bool: - if not isinstance(other, Partial): - return False - return self.reduce_op == other.reduce_op - - def __hash__(self) -> int: - return 1 + hash(self.reduce_op) - - def __repr__(self) -> str: - """ - machine readable representation of the Partial placement - """ - return f"Partial({self.reduce_op})" - - def __str__(self) -> str: - """ - human readable representation of the Partial placement - """ - return "P" - - -# We keep the old _Partial name for a while for BC reason -_Partial = Partial - - -class TensorMeta(NamedTuple): - # simple named tuple to represent tensor metadata - # intentionally to stay simple only for sharding - # propagation purposes. - shape: torch.Size - stride: Tuple[int, ...] - dtype: torch.dtype - - -# used internally to propagate the placements -@dataclass -class DTensorSpec: - mesh: DeviceMesh - placements: Tuple[Placement, ...] - - # tensor meta will only be set during sharding propagation - tensor_meta: Optional[TensorMeta] = None - - def __post_init__(self): - if not isinstance(self.placements, tuple): - self.placements = tuple(self.placements) - self._hash: Optional[int] = None - - def __setattr__(self, attr: str, value: Any): - super().__setattr__(attr, value) - # Make sure to recompute the hash in case any of the hashed attributes - # change (though we do not expect `mesh` or `placements` to change) - if hasattr(self, "_hash") and attr in ("mesh", "placements", "tensor_meta"): - self._hash = None - - def _hash_impl(self) -> int: - # hashing and equality check for DTensorSpec are used to cache the sharding - # propagation results. We only need to consider the mesh, placements, shape - # dtype and stride. - # Caveat: we need to keep this in mind and sync hash and eq if we add more - # fields to them. - if self.tensor_meta is not None: - return hash( - ( - self.mesh, - self.placements, - self.tensor_meta.shape, - self.tensor_meta.stride, - self.tensor_meta.dtype, - ) - ) - return hash((self.mesh, self.placements)) - - def __hash__(self) -> int: - # We lazily cache the spec to avoid recomputing the hash upon each - # use, where we make sure to update the hash when the `tensor_meta` - # changes by overriding `__setattr__`. This must be lazy so that Dynamo - # does not try to hash non-singleton `SymInt`s for the stride. - if self._hash is None: - self._hash = self._hash_impl() - return self._hash - - def __eq__(self, __o: object) -> bool: - if not ( - isinstance(__o, DTensorSpec) - and self.mesh == __o.mesh - and self.placements == __o.placements - ): - return False - if self.tensor_meta is None or __o.tensor_meta is None: - return self.tensor_meta == __o.tensor_meta - - return ( - self.tensor_meta.shape == __o.tensor_meta.shape # type: ignore[union-attr] - and self.tensor_meta.stride == __o.tensor_meta.stride # type: ignore[union-attr] - and self.tensor_meta.dtype == __o.tensor_meta.dtype # type: ignore[union-attr] - ) - - def __str__(self) -> str: - """ - human readable representation of the DTensorSpec - """ - if len(self.placements) == 1: - placement_str = str(self.placements[0]) - else: - placement_str = str(self.placements) - - if self.tensor_meta is not None: - tensor_shape = str(tuple(self.tensor_meta.shape)) - else: - tensor_shape = "unknown shape" - - return f"Spec({placement_str} on {tensor_shape})" - - @property - def shape(self) -> torch.Size: - if self.tensor_meta is None: - raise ValueError("tensor_meta is not set") - return self.tensor_meta.shape - - @property - def stride(self) -> Tuple[int, ...]: - if self.tensor_meta is None: - raise ValueError("tensor_meta is not set") - return self.tensor_meta.stride - - @property - def ndim(self) -> int: - if self.tensor_meta is None: - raise ValueError("tensor_meta is not set") - return len(self.tensor_meta.shape) - - @property - def num_shards(self) -> int: - num_shards = 1 - for i, placement in enumerate(self.placements): - if placement.is_shard(): - num_shards *= self.mesh.size(i) - return num_shards - - @property - def device_mesh(self) -> DeviceMesh: - # simple aliasing for the mesh field, make some - # checks that mixes DTensor/DTensorSpec easier - return self.mesh - - @property - def dim_map(self) -> List[int]: - """ - dim_map is a property we derive from `placements` of - the distributed tensor. It simply return a list of ints - where dim_map[i] denotes the sharding mapping to the mesh - dimension, and len(dim_map) == dist_tensor.ndim - dim_map[i] = -1: means tensor dim i replicate on mesh - dim_map[i] = j: means tensor dim i shard on mesh dim j - - For example, we have a dist tensor that have the shape of - [18, 20, 30], and device_mesh([0, 1, 2, 3]), placements: - [Shard(1)], the dim_map of this placement would be: - [-1, 0, -1]. This representation is pretty helpful during - sharding propagation where we could know exactly each - tensor dimension is sharded or not. - - Note that if placements contains `_Partial`, we have to - explicitly deal with it, so that when we create a DTensorSpec - with dim_map, we could properly record the pending sums. - """ - # dims mapping of dist tensor sharding - # return size of tensor ndim, -1 represent replicate - # and int >=0 represent shard on that device mesh dim - r = [-1] * self.ndim - for i, placement in enumerate(self.placements): - if placement.is_shard(): - shard_dim = cast(Shard, placement).dim - if r[shard_dim] > -1: - raise ValueError( - f"Tensor dim {shard_dim} is already sharded on mesh dim {r[shard_dim]}," - " DTensor operator implementation does not support things like hybrid" - " sharding strategies yet (i.e. [Shard(0), Shard(0)])" - ) - r[shard_dim] = i - return r - - @property - def num_shards_map(self) -> List[int]: - """ - dim_map is a property we derive from `placements` of - the distributed tensor. Unlike `dim_map`, `num_shards_map` - denotes how many shards each tensor dim has. Like `dim_map`: - len(num_shards_map) == dist_tensor.ndim - num_shards_map[i] = 1: means tensor dim i is not sharded - num_shards_map[i] = j: means tensor dim i has j shards in total - - For example, we have a dist tensor of shape [18, 20, 30], - a device_mesh ([[0, 1, 2, 3], [4, 5, 6, 7]]), and placements - ([Shard(1), Shard(0)]), the num_shards_map of this distributed tensor - would be: [4, 2, 1]. - """ - r = [1] * self.ndim - for i, placement in enumerate(self.placements): - if placement.is_shard(): - shard_dim = cast(Shard, placement).dim - r[shard_dim] *= self.mesh.size(i) - - return r - - @property - def sums(self) -> List[int]: - """ - sums is a property we derive from `placements` of the - distributed tensor. It simply return a list of ints where - sums[i] denotes the pending sum (partial) on mesh dim i - """ - return [ - idx - for idx, placement in enumerate(self.placements) - if placement.is_partial() - ] - - @classmethod - def from_dim_map( - cls, - mesh: DeviceMesh, - dim_map: List[int], - sums: List[int], - tensor_meta: Optional[TensorMeta] = None, - ) -> "DTensorSpec": - """ - Construct a DTensorSpec from dim_map list and pending sum. - - Args: - mesh (class:`DeviceMesh`): device mesh to be used in the DTensorSpec - dim_map (List[int]): a list of integer that represents sharding on each - tensor dimension, see `dim_map` property doc for details - sums (List[int]): a list of integer that represents the dist tensor have - pending sum on which device mesh dimension. - tensor meta (TensorMeta): DTensor metadata - - Return: - a class:`DTensorSpec` object - """ - # by default replicate on device mesh dims - placements: List[Placement] = [Replicate() for _ in range(mesh.ndim)] - - # find all mesh dims that need pending reductions - for s in sums: - placements[s] = Partial() - - for i, m in enumerate(dim_map): - if m >= 0: - placement = placements[m] - if placement.is_shard(): - placement = cast(Shard, placement) - raise RuntimeError( - f"DeviceMesh dimension cann't be mapped to two dimension of the same tensor: {i} and {placement.dim}" - ) - elif placement.is_partial(): - raise RuntimeError( - f"DeviceMesh dimension {m} cannot be both shard and partial!" - ) - placements[m] = Shard(i) - - return cls(mesh, tuple(placements), tensor_meta=tensor_meta) - - def is_replicated(self): - """ - return True if the current DTensorSpec replicates on all mesh dims (devices) - """ - return all(placement.is_replicate() for placement in self.placements) - - def is_sharded(self): - """ - return True if the current DTensorSpec is sharded on any mesh dims (devices) - """ - return any(placement.is_shard() for placement in self.placements) - - def shallow_copy_with_tensor_meta( - self, tensor_meta: Optional[TensorMeta] - ) -> "DTensorSpec": - """ - Shallow copy the DTensorSpec with a new tensor_meta. - """ - assert tensor_meta is not None, "shallow copy with no tensor_meta!" - return DTensorSpec( - self.mesh, - self.placements, - tensor_meta=tensor_meta, - ) +from torch.distributed.tensor._dtensor_spec import * # noqa: F401, F403 +from torch.distributed.tensor.placement_types import * # noqa: F401, F403 diff --git a/torch/distributed/checkpoint/_traverse.py b/torch/distributed/checkpoint/_traverse.py index db7957e814604c..2c79013abeb269 100644 --- a/torch/distributed/checkpoint/_traverse.py +++ b/torch/distributed/checkpoint/_traverse.py @@ -14,8 +14,8 @@ import torch from torch.distributed._shard.sharded_tensor.api import ShardedTensor -from torch.distributed._tensor import DTensor from torch.distributed.checkpoint.metadata import STATE_DICT_TYPE +from torch.distributed.tensor import DTensor PATH_ITEM = Union[str, int] diff --git a/torch/distributed/checkpoint/default_planner.py b/torch/distributed/checkpoint/default_planner.py index b0325b195c82d4..67f3c73bca51da 100644 --- a/torch/distributed/checkpoint/default_planner.py +++ b/torch/distributed/checkpoint/default_planner.py @@ -11,7 +11,6 @@ import torch from torch.distributed._shard._utils import narrow_tensor_by_index -from torch.distributed._tensor import DTensor from torch.distributed.checkpoint._dedup_save_plans import dedup_save_plans from torch.distributed.checkpoint._nested_dict import ( FLATTEN_MAPPING, @@ -45,6 +44,7 @@ _init_state_dict, ) from torch.distributed.checkpoint.utils import find_state_dict_object +from torch.distributed.tensor import DTensor from . import _version diff --git a/torch/distributed/checkpoint/examples/async_checkpointing_example.py b/torch/distributed/checkpoint/examples/async_checkpointing_example.py index 589f9b93544289..0f8e392f4e9c36 100644 --- a/torch/distributed/checkpoint/examples/async_checkpointing_example.py +++ b/torch/distributed/checkpoint/examples/async_checkpointing_example.py @@ -11,12 +11,12 @@ import torch.multiprocessing as mp import torch.nn as nn import torch.nn.functional as F -from torch.distributed._tensor.device_mesh import init_device_mesh from torch.distributed.checkpoint.state_dict import ( _patch_model_state_dict, _patch_optimizer_state_dict, ) from torch.distributed.fsdp import FullyShardedDataParallel as FSDP +from torch.distributed.tensor.device_mesh import init_device_mesh DEVICE = "cuda" diff --git a/torch/distributed/checkpoint/examples/stateful_example.py b/torch/distributed/checkpoint/examples/stateful_example.py index 9624229b53f79b..1e5e2f7c967b63 100644 --- a/torch/distributed/checkpoint/examples/stateful_example.py +++ b/torch/distributed/checkpoint/examples/stateful_example.py @@ -12,11 +12,11 @@ import torch.distributed.checkpoint as dcp import torch.multiprocessing as mp import torch.nn as nn -from torch.distributed._tensor.device_mesh import init_device_mesh from torch.distributed.checkpoint.state_dict import ( _patch_model_state_dict, _patch_optimizer_state_dict, ) +from torch.distributed.device_mesh import init_device_mesh from torch.distributed.fsdp import FullyShardedDataParallel as FSDP diff --git a/torch/distributed/checkpoint/optimizer.py b/torch/distributed/checkpoint/optimizer.py index f37fae5c0a87e7..5ad170768db929 100644 --- a/torch/distributed/checkpoint/optimizer.py +++ b/torch/distributed/checkpoint/optimizer.py @@ -12,7 +12,6 @@ ) from torch.distributed._shard.sharded_tensor.shard import Shard from torch.distributed._shard.sharding_spec.chunk_sharding_spec import ChunkShardingSpec -from torch.distributed._tensor import DTensor from torch.distributed.checkpoint._nested_dict import unflatten_state_dict from torch.distributed.checkpoint.default_planner import DefaultLoadPlanner from torch.distributed.checkpoint.metadata import ( @@ -39,6 +38,7 @@ from torch.distributed.distributed_c10d import _get_default_group from torch.distributed.fsdp._shard_utils import _create_chunk_sharded_tensor from torch.distributed.remote_device import _remote_device +from torch.distributed.tensor import DTensor STATE_DICT_2D_LAYOUT = Dict[str, Tuple[Optional[Sequence[int]], Sequence[int]]] diff --git a/torch/distributed/checkpoint/planner_helpers.py b/torch/distributed/checkpoint/planner_helpers.py index 17ab4b027d3a37..d715c651e024ba 100644 --- a/torch/distributed/checkpoint/planner_helpers.py +++ b/torch/distributed/checkpoint/planner_helpers.py @@ -7,8 +7,8 @@ from torch._utils import _get_device_module from torch.distributed._shard.metadata import ShardMetadata from torch.distributed._shard.sharded_tensor import ShardedTensor -from torch.distributed._tensor import DTensor -from torch.distributed._tensor._utils import compute_local_shape_and_global_offset +from torch.distributed.tensor import DTensor +from torch.distributed.tensor._utils import compute_local_shape_and_global_offset from .metadata import ( BytesStorageMetadata, diff --git a/torch/distributed/checkpoint/state_dict.py b/torch/distributed/checkpoint/state_dict.py index cf20ecbfb213a3..bd0bff271a5974 100644 --- a/torch/distributed/checkpoint/state_dict.py +++ b/torch/distributed/checkpoint/state_dict.py @@ -32,7 +32,6 @@ _offload_state_dict_to_cpu, _unflatten_state_dict, ) -from torch.distributed._tensor import DTensor from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import ( _CHECKPOINT_PREFIX, ) @@ -50,6 +49,7 @@ _get_module_fsdp_state_if_fully_sharded_module, FSDP_WRAPPED_MODULE, ) +from torch.distributed.tensor import DTensor from torch.nn.modules.module import _IncompatibleKeys from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils._pytree import tree_map_only diff --git a/torch/distributed/fsdp/_flat_param.py b/torch/distributed/fsdp/_flat_param.py index 8bc975dc72fd5a..ede7d06ec9a1de 100644 --- a/torch/distributed/fsdp/_flat_param.py +++ b/torch/distributed/fsdp/_flat_param.py @@ -1887,7 +1887,7 @@ def _use_unsharded_views(self, as_params: bool) -> None: flat_param = self.flat_param self._check_unsharded(flat_param) views = self._get_unflat_views() - from torch.distributed._tensor import DTensor + from torch.distributed.tensor import DTensor for i, (view, (param_name, module, _)) in enumerate( zip(views, flat_param._param_infos) @@ -2717,7 +2717,7 @@ def _warn_use_fake_reduce(log: logging.Logger, warning: str): def _same_storage(a, b): # Params are DTensors in backward # with SHARD_GRAD_OP + TP - from torch.distributed._tensor import DTensor + from torch.distributed.tensor import DTensor if isinstance(a, DTensor): a = a._local_tensor diff --git a/torch/distributed/fsdp/_fsdp_extensions.py b/torch/distributed/fsdp/_fsdp_extensions.py index 7d5d1e5f98a342..d96fa99a236571 100644 --- a/torch/distributed/fsdp/_fsdp_extensions.py +++ b/torch/distributed/fsdp/_fsdp_extensions.py @@ -5,12 +5,12 @@ import torch.distributed as dist from torch.distributed._shard.sharded_tensor.api import ShardedTensor from torch.distributed._shard.sharded_tensor.shard import Shard -from torch.distributed._tensor import DeviceMesh, DTensor from torch.distributed.fsdp._shard_utils import ( _all_gather_dtensor, _create_chunk_dtensor, _create_chunk_sharded_tensor, ) +from torch.distributed.tensor import DeviceMesh, DTensor class FSDPExtensions(ABC): diff --git a/torch/distributed/fsdp/_optim_utils.py b/torch/distributed/fsdp/_optim_utils.py index 4cfe761769a3b9..35beee36ef5839 100644 --- a/torch/distributed/fsdp/_optim_utils.py +++ b/torch/distributed/fsdp/_optim_utils.py @@ -27,7 +27,6 @@ import torch.distributed.fsdp._traversal_utils as traversal_utils import torch.nn as nn from torch.distributed._state_dict_utils import _gather_state_dict -from torch.distributed._tensor import DTensor, Replicate from torch.distributed.distributed_c10d import _get_pg_default_device from torch.distributed.fsdp._common_utils import ( _apply_to_modules, @@ -53,6 +52,7 @@ StateDictSettings, StateDictType, ) +from torch.distributed.tensor import DTensor, Replicate from torch.utils._pytree import tree_map_only diff --git a/torch/distributed/fsdp/_shard_utils.py b/torch/distributed/fsdp/_shard_utils.py index b6cd1e108fb767..9769d5296d4efb 100644 --- a/torch/distributed/fsdp/_shard_utils.py +++ b/torch/distributed/fsdp/_shard_utils.py @@ -15,7 +15,7 @@ TensorProperties, ) from torch.distributed._shard.sharding_spec import ShardMetadata -from torch.distributed._tensor import DeviceMesh, DTensor, Replicate, Shard as DShard +from torch.distributed.tensor import DeviceMesh, DTensor, Replicate, Shard as DShard def _get_remote_device_str(rank, device_type, num_devices_per_node): diff --git a/torch/distributed/fsdp/_state_dict_utils.py b/torch/distributed/fsdp/_state_dict_utils.py index a64c6413053ecc..06dfb00d71e898 100644 --- a/torch/distributed/fsdp/_state_dict_utils.py +++ b/torch/distributed/fsdp/_state_dict_utils.py @@ -25,7 +25,6 @@ Shard, ShardedTensor, ) -from torch.distributed._tensor import DTensor from torch.distributed.device_mesh import _mesh_resources from torch.distributed.fsdp._common_utils import ( _FSDPState, @@ -49,6 +48,7 @@ ShardingStrategy, StateDictType, ) +from torch.distributed.tensor import DTensor from torch.distributed.utils import _replace_by_prefix from ._fsdp_extensions import ( diff --git a/torch/distributed/fsdp/fully_sharded_data_parallel.py b/torch/distributed/fsdp/fully_sharded_data_parallel.py index 27f375da6f0669..3ff9e80ca12e3c 100644 --- a/torch/distributed/fsdp/fully_sharded_data_parallel.py +++ b/torch/distributed/fsdp/fully_sharded_data_parallel.py @@ -25,7 +25,6 @@ import torch.distributed as dist import torch.distributed.fsdp._traversal_utils as traversal_utils import torch.nn as nn -from torch.distributed._tensor import DeviceMesh from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import ( _CHECKPOINT_WRAPPED_MODULE, ActivationWrapper, @@ -84,6 +83,7 @@ StateDictSettings, StateDictType, ) +from torch.distributed.tensor import DeviceMesh from torch.distributed.utils import _p_assert from ._flat_param import FlatParameter, FlatParamHandle diff --git a/torch/distributed/_tensor/README.md b/torch/distributed/tensor/README.md similarity index 100% rename from torch/distributed/_tensor/README.md rename to torch/distributed/tensor/README.md diff --git a/torch/distributed/tensor/__init__.py b/torch/distributed/tensor/__init__.py index e69de29bb2d1d6..f2746f60ba8720 100644 --- a/torch/distributed/tensor/__init__.py +++ b/torch/distributed/tensor/__init__.py @@ -0,0 +1,67 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates + +import torch +import torch.distributed.tensor._ops # force import all built-in dtensor ops +from torch.distributed.device_mesh import DeviceMesh, init_device_mesh # noqa: F401 +from torch.distributed.tensor._api import ( + distribute_module, + distribute_tensor, + DTensor, + empty, + full, + ones, + rand, + randn, + zeros, +) +from torch.distributed.tensor.placement_types import ( + Partial, + Placement, + Replicate, + Shard, +) +from torch.optim.optimizer import ( + _foreach_supported_types as _optim_foreach_supported_types, +) +from torch.utils._foreach_utils import ( + _foreach_supported_types as _util_foreach_supported_types, +) + + +# All public APIs from dtensor package +__all__ = [ + "DTensor", + "distribute_tensor", + "distribute_module", + "Shard", + "Replicate", + "Partial", + "Placement", + "ones", + "empty", + "full", + "rand", + "randn", + "zeros", +] + + +# Append DTensor to the list of supported types for foreach implementation for optimizer +# and clip_grad_norm_ so that we will try to use foreach over the for-loop implementation on CUDA. +if DTensor not in _optim_foreach_supported_types: + _optim_foreach_supported_types.append(DTensor) + +if DTensor not in _util_foreach_supported_types: + _util_foreach_supported_types.append(DTensor) + + +# Set namespace for exposed private names +DTensor.__module__ = "torch.distributed.tensor" +distribute_tensor.__module__ = "torch.distributed.tensor" +distribute_module.__module__ = "torch.distributed.tensor" +ones.__module__ = "torch.distributed.tensor" +empty.__module__ = "torch.distributed.tensor" +full.__module__ = "torch.distributed.tensor" +rand.__module__ = "torch.distributed.tensor" +randn.__module__ = "torch.distributed.tensor" +zeros.__module__ = "torch.distributed.tensor" diff --git a/torch/distributed/tensor/_api.py b/torch/distributed/tensor/_api.py new file mode 100644 index 00000000000000..ec2993690e74e2 --- /dev/null +++ b/torch/distributed/tensor/_api.py @@ -0,0 +1,1231 @@ +# mypy: allow-untyped-decorators +# mypy: allow-untyped-defs +# Copyright (c) Meta Platforms, Inc. and affiliates +import inspect +import warnings +from typing import Any, Callable, cast, Optional, Sequence, Tuple + +import torch +import torch.distributed.tensor._dispatch as op_dispatch +import torch.distributed.tensor._random as random +import torch.nn as nn +from torch.distributed.device_mesh import _mesh_resources, DeviceMesh +from torch.distributed.tensor._collective_utils import check_tensor_meta, mesh_broadcast +from torch.distributed.tensor._dtensor_spec import DTensorSpec, TensorMeta +from torch.distributed.tensor._random import ( + is_rng_supported_mesh, + OffsetBasedRNGTracker, +) +from torch.distributed.tensor._redistribute import ( + Redistribute, + redistribute_local_tensor, +) +from torch.distributed.tensor._utils import ( + compute_global_tensor_info, + compute_local_shape, + normalize_to_torch_size, +) +from torch.distributed.tensor.placement_types import ( + Partial, + Placement, + Replicate, + Shard, +) + + +__all__ = [ + "DTensor", + "distribute_tensor", + "distribute_module", + "ones", + "empty", + "full", + "rand", + "randn", + "zeros", +] + +aten = torch.ops.aten + + +# NOTE [Autograd interaction between torch.Tensor] +# +# The autograd functions defined below are being used by the public +# facing APIs (i.e. from_local, to_local) to ensure DTensor to work +# together with torch.Tensor within the autograd engine. This +# allows DTensor to only exist on part of the module hierarchy. +# +# As an example, we have the a module that consists of submodules +# A, B, and C, the execution flow would be like: +# input(torch.Tensor) -> Module A -> Module B -> Module C -> output (torch.Tensor) +# +# Suppose I only want to make Module B be a sharded module with +# DTensor params, the following forward/backward should work: +# +# input(torch.Tensor) -> Module A +# -> DTensor input (from_local) -> Sharded Module B -> DTensor output +# -> torch.Tensor output (to_local) -> Module C +# +# So from_local/to_local must be Autograd functions. +# +class _ToTorchTensor(torch.autograd.Function): + @staticmethod + def forward( # type: ignore[override] + ctx, + input: "DTensor", + grad_placements: Optional[Sequence[Placement]], + ): + ctx.dtensor_spec = input._spec + ctx.grad_placements = grad_placements + local_tensor = input._local_tensor + + # We need to return a fresh Tensor object there as autograd metadata + # will be inplaced into it. So we don't want to pollute the Tensor + # object stored in the _local_tensor of this DTensor. + return local_tensor.view_as(local_tensor) + + @staticmethod + def backward(ctx, grad_output: torch.Tensor): # type: ignore[override] + dtensor_spec = ctx.dtensor_spec + mesh = dtensor_spec.mesh + grad_placements = ctx.grad_placements + dtensor_meta = dtensor_spec.tensor_meta + + _, tensor_stride = compute_global_tensor_info( + grad_output, mesh, dtensor_spec.placements + ) + tensor_stride = tuple(tensor_stride) + grad_placements = grad_placements or dtensor_spec.placements + grad_spec = DTensorSpec( + mesh, + grad_placements, + tensor_meta=TensorMeta( + shape=dtensor_meta.shape, + stride=tensor_stride, + dtype=dtensor_meta.dtype, + ), + ) + + return ( + DTensor( + grad_output, + grad_spec, + requires_grad=grad_output.requires_grad, + ), + None, + ) + + +class _FromTorchTensor(torch.autograd.Function): + @staticmethod + def forward( # type: ignore[override] + ctx, # pyre-ignore[2]: Parameter must be annotated. + input: torch.Tensor, + device_mesh: DeviceMesh, + placements: Tuple[Placement, ...], + run_check: bool, + shape: Optional[torch.Size] = None, + stride: Optional[Tuple[int, ...]] = None, + ) -> "DTensor": + ctx.previous_placement = placements + ctx.previous_device_mesh = device_mesh + + if shape and stride: + tensor_shape, tensor_stride = shape, stride + elif not shape and not stride: + # if it's not by default run_check, we assume user is certain that each + # rank has the same tensor shape, and we just use that to calculate the + # global shape + global_shape, global_stride = compute_global_tensor_info( + input, device_mesh, placements + ) + tensor_shape, tensor_stride = torch.Size(global_shape), tuple(global_stride) + else: + raise RuntimeError( + f"Found shape:{shape}, stride:{stride}.", + "Please pass both shape and stride at the same time.", + ) + + if device_mesh.get_coordinate() is None: + # if the global rank is not participating in the device mesh, we + # simply set the local tensor to an empty tensor + input = input.new_empty(0, requires_grad=input.requires_grad) + elif run_check: + # TODO: support uneven sharding when global shape/stride not passed, by + # building the global TensorMeta during check_tensor_meta + check_shape_stride = not shape and not stride + check_tensor_meta(input, check_shape_stride=check_shape_stride) + # TODO: See if we need to make this run_check logic + # have a corresponding backward. + for idx, placement in enumerate(placements): + if placement.is_replicate(): + # broadcast rank 0 tensor to all ranks + # only broadcast if run_check is True + input = input.contiguous() + mesh_broadcast(input, device_mesh, mesh_dim=idx) + + dist_spec = DTensorSpec( + device_mesh, + placements, + tensor_meta=TensorMeta( + tensor_shape, + tensor_stride, + input.dtype, + ), + ) + + # We want a fresh Tensor object that shares memory with the input tensor + dist_tensor = DTensor( + input.view_as(input), + dist_spec, + # requires_grad of the dist tensor depends on if input + # requires_grad or not + requires_grad=input.requires_grad, + ) + return dist_tensor + + @staticmethod + def backward(ctx, grad_output: "DTensor"): # type: ignore[override] + previous_placement = ctx.previous_placement + previous_device_mesh = ctx.previous_device_mesh + + # reshard to the placement when creating DistributedTensor + # so that the gradient layout matches, and we could return + # local gradients directly + if grad_output.placements != previous_placement: + current_spec = grad_output._spec + target_spec = DTensorSpec( + previous_device_mesh, + previous_placement, + tensor_meta=grad_output._spec.tensor_meta, + ) + local_tensor = grad_output._local_tensor + output = redistribute_local_tensor( + local_tensor, current_spec, target_spec, is_backward=True + ) + # TODO: return the redistributed local tensor directly without + # differentiable backward. see if this make sense for all cases. + return output, None, None, None, None, None + + # TODO: backward is also differentiable now, add a test + # to test higher level gradients. + return grad_output.to_local(), None, None, None, None, None + + +class DTensor(torch.Tensor): + """ + ``DTensor`` (Distributed Tensor) is a subclass of ``torch.Tensor`` that provides single-device like + abstraction to program with multi-device ``torch.Tensor``. It describes the distributed tensor sharding + layout (DTensor Layout) through the :class:`DeviceMesh` and following types of :class:`Placement`: + + * :class:`Shard`: Tensor sharded on the tensor dimension ``dim`` on the devices of the ``DeviceMesh`` dimension + * :class:`Replicate`: Tensor replicated on the devices of the ``DeviceMesh`` dimension + * :class:`Partial`: Tensor is pending reduction on the devices of the ``DeviceMesh`` dimension + + When calling PyTorch operators, ``DTensor`` overrides the PyTorch operators to perform sharded computation and issue + communications whenever necessary. Along with the operator computation, ``DTensor`` will transform or propagate the + placements (DTensor Layout) properly (based on the operator semantic itself) and generate new ``DTensor`` outputs. + + To ensure numerical correctness of the ``DTensor`` sharded computation when calling PyTorch operators, ``DTensor`` + requires every Tensor argument of the operator be DTensor. + + """ + + _local_tensor: torch.Tensor + _spec: DTensorSpec + __slots__ = ["_local_tensor", "_spec"] + + # _op_dispatcher instance as a class attribute to handle runtime dispatching logic + _op_dispatcher: op_dispatch.OpDispatcher = op_dispatch.OpDispatcher() + + @staticmethod + @torch._disable_dynamo + def __new__( + cls, + local_tensor: torch.Tensor, + spec: DTensorSpec, + *, + requires_grad: bool, + ) -> "DTensor": + """ + Construct a DTensor from a local tensor, device mesh, and placement and + other tensor properties (i.e. shape, requires_grad, strides, etc). + + .. note:: This is not a public API and it's only supposed to be used by the + operator implementations and internals. If you want to construct a + DTensor from a local tensor, consider using ``DTensor.from_local``, if + you want to construct a DTensor from a "global" tensor (where you + already have tensor initialized and want to shard this tensor), + consider using ``distribute_tensor``. + """ + if local_tensor.requires_grad and not requires_grad: + warnings.warn( + "To construct DTensor from torch.Tensor, it's recommended to " + "use local_tensor.detach() and make requires_grad consistent." + ) + + # new method instruct wrapper tensor from local_tensor and add + # placement spec, it does not do actual distribution + assert spec.tensor_meta is not None, "TensorMeta should not be None!" + r = torch.Tensor._make_wrapper_subclass( # type: ignore[attr-defined] + cls, + spec.tensor_meta.shape, + strides=spec.tensor_meta.stride, + dtype=local_tensor.dtype, + device=local_tensor.device, + layout=local_tensor.layout, + requires_grad=requires_grad, + ) + + r._spec = spec + r._local_tensor = local_tensor + return r + + # pyre-fixme[14]: `__repr__` overrides method defined in `DTensor` inconsistently. + # pyre-fixme[3]: Return type must be annotated. + def __repr__(self): + # TODO: consider all_gather the local tensors for better debugging + return f"DTensor(local_tensor={self._local_tensor}, device_mesh={self._spec.mesh}, placements={self._spec.placements})" + + def __tensor_flatten__(self): + """ + protocol to inform how to flatten a DTensor to local tensor + for PT2 tracing + """ + return ["_local_tensor"], (self._spec, self.requires_grad) + + @staticmethod + def __tensor_unflatten__(inner_tensors, flatten_spec, outer_size, outer_stride): + assert ( + flatten_spec is not None + ), "Expecting spec to be not None from `__tensor_flatten__` return value!" + local_tensor = inner_tensors["_local_tensor"] + spec, requires_grad = flatten_spec + unflatten_tensor_meta = TensorMeta( + shape=outer_size, + stride=outer_stride, + dtype=spec.tensor_meta.dtype, + ) + unflatten_spec = DTensorSpec( + spec.mesh, + spec.placements, + tensor_meta=unflatten_tensor_meta, + ) + return DTensor( + local_tensor, + unflatten_spec, + requires_grad=requires_grad, + ) + + def __coerce_tangent_metadata__(self): + if not any(isinstance(p, Partial) for p in self.placements): + return self + placements = [ + Replicate() if isinstance(p, Partial) else p for p in self.placements + ] + return self.redistribute(device_mesh=self.device_mesh, placements=placements) + + def __coerce_same_metadata_as_tangent__(self, flatten_spec): + (spec, _) = flatten_spec # Result of tensor_flatten() + return self.redistribute( + device_mesh=self.device_mesh, + placements=spec.placements, + ) + + @classmethod + @torch._disable_dynamo + # pyre-fixme[3]: Return type must be annotated. + # pyre-fixme[2]: Parameter must be annotated. + def __torch_dispatch__(cls, func, types, args=(), kwargs=None): + return DTensor._op_dispatcher.dispatch( + func, + args, + kwargs or {}, + ) + + @staticmethod + def from_local( + local_tensor: torch.Tensor, + device_mesh: Optional[DeviceMesh] = None, + placements: Optional[Sequence[Placement]] = None, + *, + run_check: bool = False, + shape: Optional[torch.Size] = None, + stride: Optional[Tuple[int, ...]] = None, + ) -> "DTensor": + """ + Create a :class:`DTensor` from a local torch.Tensor on each rank + according to the ``device_mesh`` and ``placements`` specified. + + Args: + local_tensor (torch.Tensor): local torch.Tensor on each rank. + device_mesh (:class:`DeviceMesh`, optional): DeviceMesh to place the + tensor, if not specified, must be called under a DeviceMesh + context manager, default: None + placements (List[:class:`Placement`], optional): the placements that + describes how to place the local torch.Tensor on DeviceMesh, must + have the same number of elements as ``device_mesh.ndim``. + + Keyword args: + run_check (bool, optional): at a cost of extra communications, perform + sanity check across ranks to check each local tensor's meta information + to ensure correctness. If have :class:`Replicate` in ``placements``, the + data on first rank of the device mesh dimension will be broadcasted + to other ranks. default: False + shape (torch.Size, optional): A List of int which specifies the size of + DTensor which build on top of `local_tensor`. Note this needs to be + provided if the shape of ``local_tensor`` are different across the ranks. + If not provided, ``shape`` will be computed assuming the given distributed + tensor is evenly sharded across ranks. default: None + stride (tuple, optional): A List of int which specifies the stride of DTensor. + If not provided, ``stride`` will be computed assuming the given distributed + tensor is evenly sharded across ranks. default: None + + Returns: + A :class:`DTensor` object + + .. note:: When ``run_check=False``, it is the user's responsibility to ensure the + local tensor passed in is correct across ranks (i.e. the tensor is sharded for + the ``Shard(dim)`` placement or replicated for the ``Replicate()`` placement). + If not, the behavior of the created DTensor is undefined. + + .. note:: ``from_local`` is differentiable, the `requires_grad` of the created + `DTensor` object will depend on if `local_tensor` requires_grad or not. + """ + # if same shape/dtype, no need to run_check, if not, must allgather + # the metadatas to check the size/dtype across ranks + # There should be no data communication unless there's replication + # strategy, where we broadcast the replication from the first rank + # in the mesh dimension + device_mesh = device_mesh or _mesh_resources.get_current_mesh() + device_type = device_mesh.device_type + + # convert the local tensor to desired device base on device mesh's device_type + if device_type != local_tensor.device.type and not local_tensor.is_meta: + local_tensor = local_tensor.to(device_type) + + # set default placements to replicated if not specified + if placements is None: + placements = [Replicate() for _ in range(device_mesh.ndim)] + else: + placements = list(placements) + for idx, placement in enumerate(placements): + # normalize shard dim to be positive + if placement.is_shard(): + placement = cast(Shard, placement) + if placement.dim < 0: + placements[idx] = Shard(placement.dim + local_tensor.ndim) + + # `from_local` is differentiable, and the gradient of the dist tensor this function + # created should flow back the gradients to the local_tensor, so we call an autograd + # function to construct the dist tensor instead. + return _FromTorchTensor.apply( # pyre-ignore[16]: autograd func + local_tensor, + device_mesh, + tuple(placements), + run_check, + shape, + stride, + ) + + def to_local( + self, *, grad_placements: Optional[Sequence[Placement]] = None + ) -> torch.Tensor: + """ + Get the local tensor of this DTensor on its current rank. For sharding it returns + a local shard of the logical tensor view, for replication it returns the replica on + its current rank. + + Keyword args: + grad_placements (List[:class:`Placement`], optional): the placements describes + the future layout of any gradient layout of the Tensor returned from this + function. + `to_local` converts DTensor to local tensor and the returned local tensor + might not be used as the original DTensor layout later in the code. This + argument is the hint that user can give to autograd in case the gradient + layout of the returned tensor does not match the original DTensor layout. + If not specified, we will assume the gradient layout remains the same + as the original DTensor and use that for gradient computation. + + Returns: + A :class:`torch.Tensor` or ``AsyncCollectiveTensor`` object. it represents the + local tensor on its current rank. When an ``AsyncCollectiveTensor`` object is returned, + it means the local tensor is not ready yet (i.e. communication is not finished). In this + case, user needs to call ``wait`` to wait the local tensor to be ready. + + .. note:: ``to_local`` is differentiable, the ``requires_grad`` of the local tensor returned + will depend on if the `DTensor` requires_grad or not. + """ + if not torch.is_grad_enabled(): + return self._local_tensor + + if grad_placements is not None and not isinstance(grad_placements, tuple): + grad_placements = tuple(grad_placements) + return _ToTorchTensor.apply( + self, grad_placements + ) # pyre-ignore[16]: autograd func + + def redistribute( + self, + device_mesh: Optional[DeviceMesh] = None, + placements: Optional[Sequence[Placement]] = None, + *, + async_op: bool = False, + ) -> "DTensor": + """ + ``redistribute`` performs necessary collective operations that redistribute the current + DTensor from its current placements to a new placements, or from is current DeviceMesh + to a new DeviceMesh. i.e. we can turn a Sharded DTensor to a Replicated DTensor by + specifying a Replicate placement for each dimension of the DeviceMesh. + + When redistributing from current to the new placements on one device mesh dimension, we + will perform the following operations including communication collective or local operation: + + 1. ``Shard(dim)`` -> ``Replicate()``: ``all_gather`` + 2. ``Shard(src_dim)`` -> ``Shard(dst_dim)``: ``all_to_all`` + 3. ``Replicate()`` -> ``Shard(dim)``: local chunking (i.e. ``torch.chunk``) + 4. ``Partial()`` -> ``Replicate()``: ``all_reduce`` + 5. ``Partial()`` -> ``Shard(dim)``: ``reduce_scatter`` + + + ``redistribute`` would correctly figure out the necessary redistribute steps for DTensors + that are created either on 1-D or N-D DeviceMesh. + + Args: + device_mesh (:class:`DeviceMesh`, optional): DeviceMesh to place the + DTensor. If not specified, it would use the current DTensor's DeviceMesh. + default: None + placements (List[:class:`Placement`], optional): the new placements that + describes how to place the DTensor into the DeviceMesh, must + have the same number of elements as ``device_mesh.ndim``. + default: replicate on all mesh dimensions + + Keyword args: + async_op (bool, optional): whether to perform the DTensor redistribute operation + asynchronously or not. Default: False + + Returns: + A :class:`DTensor` object + + .. note:: ``redistribute`` is differentiable, which means user do not need to worry about + the backward formula of the redistribute operation. + + .. note:: ``redistribute`` currently only supports redistributing DTensor on the same DeviceMesh, + Please file an issue if you need to redistribute DTensor to different DeviceMesh. + """ + # NOTE: This redistribute API currently only supports out + # of place redistribution, i.e. it always create a new + # DTensor object and leave the original one unchanged. + + # if device_mesh is not specified, use the current device_mesh + device_mesh = device_mesh or self.device_mesh + # raise error if new placements not specified + if placements is None: + raise RuntimeError("placements is needed for redistribute!") + + placements = list(placements) + for i, placement in enumerate(placements): + if placement.is_partial(): + raise RuntimeError( + "Can not redistribute to Partial, redistributing to Partial is for internal use only!" + ) + elif isinstance(placement, Shard) and placement.dim < 0: + # normalize shard dim to be positive + placements[i] = Shard(placement.dim + self.ndim) + placements = tuple(placements) + + # pyre-fixme[16]: `Redistribute` has no attribute `apply`. + return Redistribute.apply(self, device_mesh, placements, async_op) + + def full_tensor( + self, *, grad_placements: Optional[Sequence[Placement]] = None + ) -> torch.Tensor: + """ + Return the full tensor of this DTensor. It will perform necessary collectives + to gather the local tensors from other ranks in its DeviceMesh and concatenate + them together. It's a syntatic sugar of the following code: + + ``dtensor.redistribute(placements=[Replicate()] * mesh.ndim).to_local()`` + + Keyword args: + grad_placements (List[:class:`Placement`], optional): the placements describes + the future layout of any gradient layout of the full Tensor returned from this + function. + `full_tensor` converts DTensor to a full torch.Tensor and the returned torch.tensor + might not be used as the original replicated DTensor layout later in the code. This + argument is the hint that user can give to autograd in case the gradient + layout of the returned tensor does not match the original replicated DTensor layout. + If not specified, we will assume the gradient layout of the full tensor be replicated. + + Returns: + A :class:`torch.Tensor` object that represents the full tensor of this DTensor. + + .. note:: ``full_tensor`` is differentiable. + """ + + redist_res = self.redistribute( + placements=[Replicate()] * self.device_mesh.ndim, async_op=False + ) + return _ToTorchTensor.apply(redist_res, grad_placements) + + @property + def device_mesh(self) -> DeviceMesh: + """ + The :class:`DeviceMesh` attribute that associates with this DTensor object. + + .. note:: ``device_mesh`` is a read-only property, it can not be set. + """ + return self._spec.mesh + + @property + def placements(self) -> Tuple[Placement, ...]: + """ + The placements attribute of this DTensor that describes the layout of this + DTensor on the its DeviceMesh. + + .. note:: ``placements`` is a read-only property, it can not be set. + """ + return self._spec.placements + + def __create_write_items__(self, fqn: str, object: Any): + from torch.distributed.checkpoint.planner_helpers import ( + _create_write_items_for_dtensor, + ) + + if hasattr(self._local_tensor, "__create_write_items__"): + return self._local_tensor.__create_write_items__(fqn, object) # type: ignore[attr-defined] + elif isinstance(self._local_tensor, torch.Tensor): + return [_create_write_items_for_dtensor(fqn, object)] + else: + raise RuntimeError("Unsupported tensor type!") + + def __create_chunk_list__(self): + from torch.distributed.checkpoint.planner_helpers import ( + _create_chunk_from_dtensor, + ) + + if hasattr(self._local_tensor, "__create_chunk_list__"): + return self._local_tensor.__create_chunk_list__() # type: ignore[attr-defined] + elif isinstance(self._local_tensor, torch.Tensor): + return [_create_chunk_from_dtensor(self)] + else: + raise RuntimeError("Unsupported tensor type!") + + def __get_tensor_shard__(self, index): + if hasattr(self._local_tensor, "__get_tensor_shard__"): + return self._local_tensor.__get_tensor_shard__(index) # type: ignore[attr-defined] + elif isinstance(self._local_tensor, torch.Tensor): + return self.to_local() + else: + raise RuntimeError("Unsupported tensor type!") + + +def distribute_tensor( + tensor: torch.Tensor, + device_mesh: Optional[DeviceMesh] = None, + placements: Optional[Sequence[Placement]] = None, +) -> DTensor: + """ + Distribute a leaf ``torch.Tensor`` (i.e. nn.Parameter/buffers) to the ``device_mesh`` according + to the ``placements`` specified. The rank of ``device_mesh`` and ``placements`` must be the + same. The ``tensor`` to distribute is the logical or "global" tensor, and the API would use + the ``tensor`` from first rank of the DeviceMesh dimension as the source of truth to perserve + the single-device semantic. If you want to construct a DTensor in the middle of the Autograd + computation, please use :meth:`DTensor.from_local` instead. + + Args: + tensor (torch.Tensor): torch.Tensor to be distributed. Note that if you + want to shard a tensor on a dimension that is not evenly divisible by + the number of devices in that mesh dimension, we use ``torch.chunk`` + semantic to shard the tensor and scatter the shards. The uneven sharding + behavior is experimental and subject to change. + device_mesh (:class:`DeviceMesh`, optional): DeviceMesh to distribute the + tensor, if not specified, must be called under a DeviceMesh context + manager, default: None + placements (List[:class:`Placement`], optional): the placements that + describes how to place the tensor on DeviceMesh, must have the same + number of elements as ``device_mesh.ndim``. If not specified, we will + by default replicate the tensor across the ``device_mesh`` from the + first rank of each dimension of the `device_mesh`. + + Returns: + A :class:`DTensor` or ``XLAShardedTensor`` object. + + .. note:: + When initialize the DeviceMesh with the ``xla`` device_type, ``distribute_tensor`` + return `XLAShardedTensor` instead. see `this issue `__ + for more details. The XLA integration is experimental and subject to change. + """ + + torch._C._log_api_usage_once("torch.dtensor.distribute_tensor") + + # get default device mesh if there's nothing specified + device_mesh = device_mesh or _mesh_resources.get_current_mesh() + device_type = device_mesh.device_type + if device_type == "xla": + try: + # call PyTorch/XLA SPMD for `xla` backend type device mesh. + # This returns XLAShardedTensor + from torch_xla.distributed.spmd import ( # type:ignore[import] + xla_distribute_tensor, + ) + + return xla_distribute_tensor( + tensor, device_mesh, placements + ) # type:ignore[return-value] + except ImportError as e: + msg = "To use DTensor API with xla, you must install the torch_xla package!" + raise ImportError(msg) from e + + # instantiate a RNG tracker if haven't. By default DTensor uses an + # OffsetBasedRNGTracker to perform random operators. + # TODO: the value assignment to global variable is not the ideal solution + # we can replace it in future. + if not random._rng_tracker and is_rng_supported_mesh(device_mesh): + random._rng_tracker = OffsetBasedRNGTracker(device_type) + + if not tensor.is_leaf: + raise RuntimeError( + "`distribute_tensor` should be used to distribute leaf tensors! but found non-leaf tensor!" + ) + + # convert tensor to the corresponding device type if it's not in that device type + if device_type != tensor.device.type and not tensor.is_meta: + tensor = tensor.to(device_type) + + # set default placements to replicated if not specified + if placements is None: + placements = [Replicate() for _ in range(device_mesh.ndim)] + + if len(placements) != device_mesh.ndim: + raise ValueError( + f"`placements` must have the same length as `device_mesh.ndim`! " + f"Found placements length: {len(placements)}, and device_mesh.ndim: {device_mesh.ndim}." + ) + if isinstance(tensor, DTensor): + # if the tensor is already a DTensor, we need to check: + # 1. if the we can further shard this DTensor if the two device mesh belong to + # the same parenet mesh and further sharding is possible. + # 2. check if device mesh and placements are the same + if tensor.device_mesh != device_mesh: + raise ValueError( + f"Cannot distribute a DTensor with device mesh {tensor.device_mesh} " + f"to a different device mesh {device_mesh}." + ) + if tensor.placements != tuple(placements): + raise ValueError( + f"Cannot distribute a DTensor with placements {tensor.placements} " + f"to a different placements {placements}. do you want to call " + f"`redistribute` instead?" + ) + return tensor + + local_tensor = tensor.detach() + + # TODO(xilun): address sharding order + # distribute the tensor according to the placements. + placements = list(placements) + for idx, placement in enumerate(placements): + if placement.is_shard(): + placement = cast(Shard, placement) + if placement.dim < 0: + # normalize shard placement dim + placement = Shard(placement.dim + tensor.ndim) + placements[idx] = placement + local_tensor = placement._shard_tensor(local_tensor, device_mesh, idx) + elif placement.is_replicate(): + placement = cast(Replicate, placement) + local_tensor = placement._replicate_tensor(local_tensor, device_mesh, idx) + else: + raise RuntimeError( + f"Trying to distribute tensor with unsupported placements {placement} on device mesh dimension {idx}!" + ) + placements = tuple(placements) + + assert local_tensor is not None, "distributing a tensor should not be None" + # detach the local tensor passed to DTensor since after the construction + # of DTensor, autograd would work on top of DTensor instead of local tensor + spec = DTensorSpec( + mesh=device_mesh, + placements=placements, + tensor_meta=TensorMeta( + shape=tensor.size(), + stride=tensor.stride(), + dtype=tensor.dtype, + ), + ) + return DTensor( + local_tensor.requires_grad_(tensor.requires_grad), + spec, + requires_grad=tensor.requires_grad, + ) + + +def distribute_module( + module: nn.Module, + device_mesh: Optional[DeviceMesh] = None, + partition_fn: Optional[Callable[[str, nn.Module, DeviceMesh], None]] = None, + input_fn: Optional[Callable[[nn.Module, Any, DeviceMesh], None]] = None, + output_fn: Optional[Callable[[nn.Module, Any, DeviceMesh], None]] = None, +) -> nn.Module: + """ + This function expose three functions to control the parameters/inputs/outputs of the module: + + 1. To perform sharding on the module before runtime execution by specifying the + ``partition_fn`` (i.e. allow user to convert Module parameters to :class:`DTensor` + parameters according to the `partition_fn` specified). + 2. To control the inputs or outputs of the module during runtime execution by + specifying the ``input_fn`` and ``output_fn``. (i.e. convert the input to + :class:`DTensor`, convert the output back to ``torch.Tensor``) + + Args: + module (:class:`nn.Module`): user module to be partitioned. + device_mesh (:class:`DeviceMesh`): the device mesh to place the module. + partition_fn (Callable): the function to partition parameters (i.e. shard certain + parameters across the ``device_mesh``). If ``partition_fn`` is not specified, + by default we replicate all module parameters of ``module`` across the mesh. + input_fn (Callable): specify the input distribution, i.e. could control how the + input of the module is sharded. ``input_fn`` will be installed as a module + ``forward_pre_hook`` (pre forward hook). + output_fn (Callable): specify the output distribution, i.e. could control how the + output is sharded, or convert it back to torch.Tensor. ``output_fn`` will be + installed as a module ``forward_hook`` (post forward hook). + + Returns: + A module that contains parameters/buffers that are all ``DTensor`` s. + + .. note:: + When initialize the DeviceMesh with the ``xla`` device_type, ``distribute_module`` + return nn.Module with PyTorch/XLA SPMD annotated parameters. See + `this issue `__ + for more details. The XLA integration is experimental and subject to change. + + """ + + torch._C._log_api_usage_once("torch.dtensor.distribute_module") + + device_mesh = device_mesh or _mesh_resources.get_current_mesh() + device_type = device_mesh.device_type + if device_type == "xla": + try: + # This function annotates all module parameters for auto-partitioning with + # PyTorch/XLA SPMD or explicitly partition to :class:`XLAShardedTensor` parameters + # according to the `partition_fn` specified. + from torch_xla.distributed.spmd import ( # type:ignore[import] + xla_distribute_module, + ) + + return xla_distribute_module( + module, device_mesh, partition_fn, input_fn, output_fn + ) # type:ignore[return-value] + except ImportError as e: + msg = "To use DTensor API with xla, you must install the torch_xla package!" + raise ImportError(msg) from e + + def replicate_module_params_buffers(m: nn.Module, mesh: DeviceMesh) -> None: + # This function loop over the immediate module parameters and + # buffers, replicate all non DTensor params/buffers to DTensor + # parameters/buffers, if they have not been partitioned in the + # partition_fn, we can't easily use `module._apply` here + # because we don't know what happened inside partition_fn as + # user could do anything, i.e. install hooks, and we want to + # preserve those. + full_replicate = [Replicate()] * mesh.ndim + for key, param in m._parameters.items(): + if param is not None and not isinstance(param, DTensor): + m.register_parameter( + key, + nn.Parameter(distribute_tensor(param.data, mesh, full_replicate)), + ) + for key, buffer in m._buffers.items(): + if buffer is not None and not isinstance(buffer, DTensor): + m._buffers[key] = distribute_tensor(buffer, mesh, full_replicate) + + if partition_fn is None: + # if partition_fn not specified, we by default replicate + # all module params/buffers + for name, submod in module.named_modules(): + replicate_module_params_buffers(submod, device_mesh) + else: + # apply partition_fun to submodules + for name, submod in module.named_modules(): + partition_fn(name, submod, device_mesh) + replicate_module_params_buffers(submod, device_mesh) + + # register input_fn as module forward pre hook + if input_fn is not None: + # check the input_fn signature + num_args = len(inspect.signature(input_fn).parameters) + if num_args == 2: + # input_fn only takes in inputs and device mesh + warnings.warn( + "Deprecating input_fn that takes two arguments (inputs, device_mesh), " + "please use input_fn that takes in (module, inputs, device_mesh) instead!", + FutureWarning, + stacklevel=2, + ) + module.register_forward_pre_hook(lambda _, inputs: input_fn(inputs, device_mesh)) # type: ignore[call-arg] + elif num_args == 3: + # input_fn takes in module, inputs, device mesh + module.register_forward_pre_hook( + lambda mod, inputs: input_fn(mod, inputs, device_mesh) + ) + else: + raise ValueError( + f"input_fn should take in 3 arguments, but got {num_args} arguments!" + ) + # register output_fn as module forward hook + if output_fn is not None: + num_args = len(inspect.signature(output_fn).parameters) + if num_args == 2: + # output_fn only takes in outputs and device mesh + warnings.warn( + "Deprecating output_fn that takes two arguments (inputs, device_mesh), " + "please use output_fn that takes in (module, inputs, device_mesh) instead!", + FutureWarning, + stacklevel=2, + ) + module.register_forward_hook( + lambda mod, inputs, outputs: output_fn(outputs, device_mesh) # type: ignore[call-arg] + ) + elif num_args == 3: + module.register_forward_hook( + lambda mod, inputs, outputs: output_fn(mod, outputs, device_mesh) + ) + else: + raise ValueError( + f"output_fn should take in 3 arguments, but got {num_args} arguments!" + ) + + return module + + +# Below are tensor factory function APIs, which are used to create a DTensor directly. We need +# to make separate factory function APIs because tensor subclass could not override the tensor +# factory methods, and we need user to call the factory functions with user intended device_mesh +# and placements to create a proper DTensor. + + +def _dtensor_init_helper( # type: ignore[no-untyped-def] + init_op, + size: torch.Size, + device_mesh: Optional[DeviceMesh] = None, + placements: Optional[Sequence[Placement]] = None, + **kwargs, +) -> DTensor: + # from torch.distributed._tensor.placement_types import DTensorSpec, TensorMeta + + # if device_mesh is None, use the one from mesh resources + device_mesh = device_mesh or _mesh_resources.get_current_mesh() + kwargs["device"] = device_mesh.device_type + + # set default placements to replicated if not specified + placements = placements or tuple(Replicate() for _ in range(device_mesh.ndim)) + + # check device_mesh againts placements + assert device_mesh.ndim == len( + placements + ), "mesh dimension does not match the length of placements" + + assert kwargs["layout"] == torch.strided, "layout value not supported!" + torch_stride = torch._prims_common.make_contiguous_strides_for(size) + + # get local tensor shape + local_shape = compute_local_shape(size, device_mesh, placements) + # initialize the local tensor + if init_op == torch.full: + fill_value = kwargs.pop("fill_value", 0) + local_tensor = init_op(local_shape, fill_value, **kwargs) + elif init_op == torch.rand or init_op == torch.randn: + # this tensor meta is not used except `shape` + dtype = kwargs.get("dtype", torch.get_default_dtype()) + + tensor_meta = TensorMeta(size, (0,), dtype) + spec = DTensorSpec(device_mesh, tuple(placements), tensor_meta=tensor_meta) + + if random.is_rng_supported_mesh(device_mesh) and not random._rng_tracker: + random._rng_tracker = random.OffsetBasedRNGTracker() + + assert random._rng_tracker is not None + with random._rng_tracker._distribute_region(spec): + local_tensor = init_op(local_shape, **kwargs) + else: + local_tensor = init_op(local_shape, **kwargs) + + spec = DTensorSpec( + device_mesh, + tuple(placements), + tensor_meta=TensorMeta( + size, + torch_stride, + local_tensor.dtype, + ), + ) + + return DTensor( + local_tensor, + spec, + requires_grad=kwargs["requires_grad"], + ) + + +def ones( # type: ignore[no-untyped-def] + *size, + dtype: Optional[torch.dtype] = None, + layout: torch.layout = torch.strided, + requires_grad: bool = False, + device_mesh: Optional[DeviceMesh] = None, + placements: Optional[Sequence[Placement]] = None, +) -> DTensor: + """ + Returns a :class:`DTensor` filled with the scalar value 1, with the shape defined + by the variable argument ``size``. + + Args: + size (int...): a sequence of integers defining the shape of the output :class:`DTensor`. + Can be a variable number of arguments or a collection like a list or tuple. + E.g.: ones(1,2,3..) or ones([1,2,3..]) or ones((1,2,3..)) + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of returned :class:`DTensor`. + Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). + layout (:class:`torch.layout`, optional): the desired layout of returned DTensor. + Default: ``torch.strided``. + requires_grad (bool, optional): If autograd should record operations on the + returned :class:`DTensor`. Default: ``False``. + device_mesh: :class:`DeviceMesh` type, contains the mesh info of ranks + placements: a sequence of :class:`Placement` type: ``Shard``, ``Replicate`` + + Returns: + A :class:`DTensor` object on each rank + """ + torch_size = normalize_to_torch_size(size) + + return _dtensor_init_helper( + torch.ones, + torch_size, + dtype=dtype, + layout=layout, + requires_grad=requires_grad, + device_mesh=device_mesh, + placements=placements, + ) + + +def empty( # type: ignore[no-untyped-def] + *size, + dtype: Optional[torch.dtype] = None, + layout: torch.layout = torch.strided, + requires_grad: bool = False, + device_mesh: Optional[DeviceMesh] = None, + placements: Optional[Sequence[Placement]] = None, +) -> DTensor: + """ + Returns a :class:`DTensor` filled with uninitialized data. The shape of the :class:`DTensor` + is defined by the variable argument ``size``. + + Args: + size (int...): a sequence of integers defining the shape of the output :class:`DTensor`. + Can be a variable number of arguments or a collection like a list or tuple. + E.g.: empty(1,2,3..) or empty([1,2,3..]) or empty((1,2,3..)) + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of returned :class:`DTensor`. + Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`).\ + layout (:class:`torch.layout`, optional): the desired layout of returned :class:`DTensor`. + Default: ``torch.strided``. + requires_grad (bool, optional): If autograd should record operations on the + returned :class:`DTensor`. Default: ``False``. + device_mesh: :class:`DeviceMesh` type, contains the mesh info of ranks + placements: a sequence of :class:`Placement` type: ``Shard``, ``Replicate`` + + Returns: + A :class:`DTensor` object on each rank + """ + torch_size = normalize_to_torch_size(size) + + return _dtensor_init_helper( + torch.empty, + torch_size, + dtype=dtype, + layout=layout, + requires_grad=requires_grad, + device_mesh=device_mesh, + placements=placements, + ) + + +def full( # type: ignore[no-untyped-def] + size, + fill_value, + *, + dtype: Optional[torch.dtype] = None, + layout: torch.layout = torch.strided, + requires_grad: bool = False, + device_mesh: Optional[DeviceMesh] = None, + placements: Optional[Sequence[Placement]] = None, +) -> DTensor: + """ + Returns a :class:`DTensor` filled with ``fill_value`` according to ``device_mesh`` and + ``placements``, with the shape defined by the argument ``size``. + + Args: + size (int...): a sequence of integers defining the shape of the output :class:`DTensor`. + Can be a variable number of arguments or a collection like a list or tuple. + E.g.: ones(1,2,3..) or ones([1,2,3..]) or ones((1,2,3..)) + fill_value(Scalar): the value to fill the output tensor with. + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of returned :class:`DTensor`. + Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). + layout (:class:`torch.layout`, optional): the desired layout of returned DTensor. + Default: ``torch.strided``. + requires_grad (bool, optional): If autograd should record operations on the + returned :class:`DTensor`. Default: ``False``. + device_mesh: :class:`DeviceMesh` type, contains the mesh info of ranks. + placements: a sequence of :class:`Placement` type: ``Shard``, ``Replicate`` + + Returns: + A :class:`DTensor` object on each rank + """ + torch_size = normalize_to_torch_size(size) + + return _dtensor_init_helper( + torch.full, + torch_size, + fill_value=fill_value, + dtype=dtype, + layout=layout, + requires_grad=requires_grad, + device_mesh=device_mesh, + placements=placements, + ) + + +def rand( # type: ignore[no-untyped-def] + *size, + requires_grad: bool = False, + dtype: Optional[torch.dtype] = None, + layout: torch.layout = torch.strided, + device_mesh: Optional[DeviceMesh] = None, + placements: Optional[Sequence[Placement]] = None, +) -> DTensor: + """ + Returns a :class:`DTensor` filled with random numbers from a uniform distribution + on the interval ``[0, 1)``. The shape of the tensor is defined by the variable + argument ``size``. + + Args: + size (int...): a sequence of integers defining the shape of the output :class:`DTensor`. + Can be a variable number of arguments or a collection like a list or tuple. + E.g.: ones(1,2,3..) or ones([1,2,3..]) or ones((1,2,3..)) + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of returned :class:`DTensor`. + Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). + layout (:class:`torch.layout`, optional): the desired layout of returned DTensor. + Default: ``torch.strided``. + requires_grad (bool, optional): If autograd should record operations on the + returned :class:`DTensor`. Default: ``False``. + device_mesh: :class:`DeviceMesh` type, contains the mesh info of ranks. + placements: a sequence of :class:`Placement` type: ``Shard``, ``Replicate`` + + Returns: + A :class:`DTensor` object on each rank + """ + torch_size = normalize_to_torch_size(size) + + return _dtensor_init_helper( + torch.rand, + torch_size, + dtype=dtype, + layout=layout, + requires_grad=requires_grad, + device_mesh=device_mesh, + placements=placements, + ) + + +def randn( # type: ignore[no-untyped-def] + *size, + requires_grad: bool = False, + dtype: Optional[torch.dtype] = None, + layout: torch.layout = torch.strided, + device_mesh: Optional[DeviceMesh] = None, + placements: Optional[Sequence[Placement]] = None, +) -> DTensor: + """ + Returns a :class:`DTensor` filled with random numbers from a normal distribution + with mean 0 and variance 1. The shape of the tensor is defined by the variable + argument ``size``. + + Args: + size (int...): a sequence of integers defining the shape of the output :class:`DTensor`. + Can be a variable number of arguments or a collection like a list or tuple. + E.g.: ones(1,2,3..) or ones([1,2,3..]) or ones((1,2,3..)) + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of returned :class:`DTensor`. + Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). + layout (:class:`torch.layout`, optional): the desired layout of returned DTensor. + Default: ``torch.strided``. + requires_grad (bool, optional): If autograd should record operations on the + returned :class:`DTensor`. Default: ``False``. + device_mesh: :class:`DeviceMesh` type, contains the mesh info of ranks. + placements: a sequence of :class:`Placement` type: ``Shard``, ``Replicate`` + + Returns: + A :class:`DTensor` object on each rank + """ + torch_size = normalize_to_torch_size(size) + + return _dtensor_init_helper( + torch.randn, + torch_size, + dtype=dtype, + layout=layout, + requires_grad=requires_grad, + device_mesh=device_mesh, + placements=placements, + ) + + +def zeros( # type: ignore[no-untyped-def] + *size, + requires_grad: bool = False, + dtype: Optional[torch.dtype] = None, + layout: torch.layout = torch.strided, + device_mesh: Optional[DeviceMesh] = None, + placements: Optional[Sequence[Placement]] = None, +) -> DTensor: + """ + Returns a :class:`DTensor` filled with the scalar value 0. + + Args: + size (int...): a sequence of integers defining the shape of the output :class:`DTensor`. + Can be a variable number of arguments or a collection like a list or tuple. + E.g.: zeros(1,2,3..) or zeros([1,2,3..]) or zeros((1,2,3..)) + Keyword args: + requires_grad (bool, optional): If autograd should record operations on the + returned :class:`DTensor`. Default: ``False``. + dtype (:class:`torch.dtype`, optional): the desired data type of returned :class:`DTensor`. + Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). + layout (:class:`torch.layout`, optional): the desired layout of returned :class:`DTensor`. + Default: ``torch.strided``. + device_mesh: :class:`DeviceMesh` type, contains the mesh info of ranks + placements: a sequence of :class:`Placement` type: ``Shard``, ``Replicate`` + + Returns: + A :class:`DTensor` object on each rank + """ + torch_size = normalize_to_torch_size(size) + + return _dtensor_init_helper( + torch.zeros, + torch_size, + dtype=dtype, + layout=layout, + requires_grad=requires_grad, + device_mesh=device_mesh, + placements=placements, + ) diff --git a/torch/distributed/_tensor/_collective_utils.py b/torch/distributed/tensor/_collective_utils.py similarity index 97% rename from torch/distributed/_tensor/_collective_utils.py rename to torch/distributed/tensor/_collective_utils.py index a59aa44492fa77..34795858356ab8 100644 --- a/torch/distributed/_tensor/_collective_utils.py +++ b/torch/distributed/tensor/_collective_utils.py @@ -7,7 +7,7 @@ import torch import torch.distributed._functional_collectives as funcol -import torch.distributed._tensor.placement_types as placement_types +import torch.distributed.tensor._dtensor_spec as dtensor_spec from torch._C._distributed_c10d import _resolve_process_group from torch.distributed.device_mesh import _mesh_resources, DeviceMesh from torch.distributed.distributed_c10d import ( @@ -202,7 +202,7 @@ def fill_empty_tensor_to_shards( def check_tensor_meta( local_tensor, check_shape_stride=False -) -> Optional["placement_types.TensorMeta"]: +) -> Optional["dtensor_spec.TensorMeta"]: local_metadata = { "dtype": local_tensor.dtype, "requires_grad": local_tensor.requires_grad, @@ -224,7 +224,7 @@ def check_tensor_meta( return None -def spec_to_bytes(spec: "placement_types.DTensorSpec") -> int: +def spec_to_bytes(spec: "dtensor_spec.DTensorSpec") -> int: assert spec.tensor_meta is not None, "spec should have tensor meta defined!" return spec.tensor_meta.dtype.itemsize * math.prod(spec.shape) @@ -311,8 +311,8 @@ def reduce_scatter_cost( def redistribute_cost( - current_spec: "placement_types.DTensorSpec", - target_spec: "placement_types.DTensorSpec", + current_spec: "dtensor_spec.DTensorSpec", + target_spec: "dtensor_spec.DTensorSpec", ) -> float: """ This function returns the cost of redistribute from current to target DTensorSpec. diff --git a/torch/distributed/_tensor/_dispatch.py b/torch/distributed/tensor/_dispatch.py similarity index 97% rename from torch/distributed/_tensor/_dispatch.py rename to torch/distributed/tensor/_dispatch.py index 9faae22cd37353..4579a16826d0fb 100644 --- a/torch/distributed/_tensor/_dispatch.py +++ b/torch/distributed/tensor/_dispatch.py @@ -8,30 +8,25 @@ import torch import torch.distributed as dist -import torch.distributed._tensor.api as dtensor -import torch.distributed._tensor.random as random -from torch.distributed._tensor._op_schema import ( +import torch.distributed.tensor._api as dtensor +import torch.distributed.tensor._random as random +from torch.distributed.tensor._dtensor_spec import DTensorSpec, TensorMeta +from torch.distributed.tensor._op_schema import ( _is_inplace_op, _is_out_variant_op, OpInfo, OpSchema, OutputSpecType, ) -from torch.distributed._tensor._redistribute import redistribute_local_tensor -from torch.distributed._tensor._sharding_prop import ShardingPropagator -from torch.distributed._tensor._tp_conv import ( +from torch.distributed.tensor._random import is_rng_supported_mesh +from torch.distributed.tensor._redistribute import redistribute_local_tensor +from torch.distributed.tensor._sharding_prop import ShardingPropagator +from torch.distributed.tensor._tp_conv import ( convolution_backward_handler, convolution_handler, ) -from torch.distributed._tensor._utils import try_find_mesh_from_args -from torch.distributed._tensor.placement_types import ( - DTensorSpec, - Partial, - Placement, - Replicate, - TensorMeta, -) -from torch.distributed._tensor.random import is_rng_supported_mesh +from torch.distributed.tensor._utils import try_find_mesh_from_args +from torch.distributed.tensor.placement_types import Partial, Placement, Replicate if TYPE_CHECKING: diff --git a/torch/distributed/tensor/_dtensor_spec.py b/torch/distributed/tensor/_dtensor_spec.py new file mode 100644 index 00000000000000..e80729c7b62869 --- /dev/null +++ b/torch/distributed/tensor/_dtensor_spec.py @@ -0,0 +1,276 @@ +from dataclasses import dataclass +from typing import Any, cast, List, NamedTuple, Optional, Tuple + +import torch +from torch.distributed.device_mesh import DeviceMesh +from torch.distributed.tensor.placement_types import ( + Partial, + Placement, + Replicate, + Shard, +) + + +class TensorMeta(NamedTuple): + # simple named tuple to represent tensor metadata + # intentionally to stay simple only for sharding + # propagation purposes. + shape: torch.Size + stride: Tuple[int, ...] + dtype: torch.dtype + + +# used internally to propagate the placements +@dataclass +class DTensorSpec: + mesh: DeviceMesh + placements: Tuple[Placement, ...] + + # tensor meta will only be set during sharding propagation + tensor_meta: Optional[TensorMeta] = None + + def __post_init__(self) -> None: + if not isinstance(self.placements, tuple): + self.placements = tuple(self.placements) + self._hash: Optional[int] = None + + def __setattr__(self, attr: str, value: Any) -> None: + super().__setattr__(attr, value) + # Make sure to recompute the hash in case any of the hashed attributes + # change (though we do not expect `mesh` or `placements` to change) + if hasattr(self, "_hash") and attr in ("mesh", "placements", "tensor_meta"): + self._hash = None + + def _hash_impl(self) -> int: + # hashing and equality check for DTensorSpec are used to cache the sharding + # propagation results. We only need to consider the mesh, placements, shape + # dtype and stride. + # Caveat: we need to keep this in mind and sync hash and eq if we add more + # fields to them. + if self.tensor_meta is not None: + return hash( + ( + self.mesh, + self.placements, + self.tensor_meta.shape, + self.tensor_meta.stride, + self.tensor_meta.dtype, + ) + ) + return hash((self.mesh, self.placements)) + + def __hash__(self) -> int: + # We lazily cache the spec to avoid recomputing the hash upon each + # use, where we make sure to update the hash when the `tensor_meta` + # changes by overriding `__setattr__`. This must be lazy so that Dynamo + # does not try to hash non-singleton `SymInt`s for the stride. + if self._hash is None: + self._hash = self._hash_impl() + return self._hash + + def __eq__(self, __o: object) -> bool: + if not ( + isinstance(__o, DTensorSpec) + and self.mesh == __o.mesh + and self.placements == __o.placements + ): + return False + if self.tensor_meta is None or __o.tensor_meta is None: + return self.tensor_meta == __o.tensor_meta + + return ( + self.tensor_meta.shape == __o.tensor_meta.shape # type: ignore[union-attr] + and self.tensor_meta.stride == __o.tensor_meta.stride # type: ignore[union-attr] + and self.tensor_meta.dtype == __o.tensor_meta.dtype # type: ignore[union-attr] + ) + + def __str__(self) -> str: + """ + human readable representation of the DTensorSpec + """ + if len(self.placements) == 1: + placement_str = str(self.placements[0]) + else: + placement_str = str(self.placements) + + if self.tensor_meta is not None: + tensor_shape = str(tuple(self.tensor_meta.shape)) + else: + tensor_shape = "unknown shape" + + return f"Spec({placement_str} on {tensor_shape})" + + @property + def shape(self) -> torch.Size: + if self.tensor_meta is None: + raise ValueError("tensor_meta is not set") + return self.tensor_meta.shape + + @property + def stride(self) -> Tuple[int, ...]: + if self.tensor_meta is None: + raise ValueError("tensor_meta is not set") + return self.tensor_meta.stride + + @property + def ndim(self) -> int: + if self.tensor_meta is None: + raise ValueError("tensor_meta is not set") + return len(self.tensor_meta.shape) + + @property + def num_shards(self) -> int: + num_shards = 1 + for i, placement in enumerate(self.placements): + if placement.is_shard(): + num_shards *= self.mesh.size(i) + return num_shards + + @property + def device_mesh(self) -> DeviceMesh: + # simple aliasing for the mesh field, make some + # checks that mixes DTensor/DTensorSpec easier + return self.mesh + + @property + def dim_map(self) -> List[int]: + """ + dim_map is a property we derive from `placements` of + the distributed tensor. It simply return a list of ints + where dim_map[i] denotes the sharding mapping to the mesh + dimension, and len(dim_map) == dist_tensor.ndim + dim_map[i] = -1: means tensor dim i replicate on mesh + dim_map[i] = j: means tensor dim i shard on mesh dim j + + For example, we have a dist tensor that have the shape of + [18, 20, 30], and device_mesh([0, 1, 2, 3]), placements: + [Shard(1)], the dim_map of this placement would be: + [-1, 0, -1]. This representation is pretty helpful during + sharding propagation where we could know exactly each + tensor dimension is sharded or not. + + Note that if placements contains `_Partial`, we have to + explicitly deal with it, so that when we create a DTensorSpec + with dim_map, we could properly record the pending sums. + """ + # dims mapping of dist tensor sharding + # return size of tensor ndim, -1 represent replicate + # and int >=0 represent shard on that device mesh dim + r = [-1] * self.ndim + for i, placement in enumerate(self.placements): + if placement.is_shard(): + shard_dim = cast(Shard, placement).dim + if r[shard_dim] > -1: + raise ValueError( + f"Tensor dim {shard_dim} is already sharded on mesh dim {r[shard_dim]}," + " DTensor operator implementation does not support things like hybrid" + " sharding strategies yet (i.e. [Shard(0), Shard(0)])" + ) + r[shard_dim] = i + return r + + @property + def num_shards_map(self) -> List[int]: + """ + dim_map is a property we derive from `placements` of + the distributed tensor. Unlike `dim_map`, `num_shards_map` + denotes how many shards each tensor dim has. Like `dim_map`: + len(num_shards_map) == dist_tensor.ndim + num_shards_map[i] = 1: means tensor dim i is not sharded + num_shards_map[i] = j: means tensor dim i has j shards in total + + For example, we have a dist tensor of shape [18, 20, 30], + a device_mesh ([[0, 1, 2, 3], [4, 5, 6, 7]]), and placements + ([Shard(1), Shard(0)]), the num_shards_map of this distributed tensor + would be: [4, 2, 1]. + """ + r = [1] * self.ndim + for i, placement in enumerate(self.placements): + if placement.is_shard(): + shard_dim = cast(Shard, placement).dim + r[shard_dim] *= self.mesh.size(i) + + return r + + @property + def sums(self) -> List[int]: + """ + sums is a property we derive from `placements` of the + distributed tensor. It simply return a list of ints where + sums[i] denotes the pending sum (partial) on mesh dim i + """ + return [ + idx + for idx, placement in enumerate(self.placements) + if placement.is_partial() + ] + + @classmethod + def from_dim_map( + cls, + mesh: DeviceMesh, + dim_map: List[int], + sums: List[int], + tensor_meta: Optional[TensorMeta] = None, + ) -> "DTensorSpec": + """ + Construct a DTensorSpec from dim_map list and pending sum. + + Args: + mesh (class:`DeviceMesh`): device mesh to be used in the DTensorSpec + dim_map (List[int]): a list of integer that represents sharding on each + tensor dimension, see `dim_map` property doc for details + sums (List[int]): a list of integer that represents the dist tensor have + pending sum on which device mesh dimension. + tensor meta (TensorMeta): DTensor metadata + + Return: + a class:`DTensorSpec` object + """ + # by default replicate on device mesh dims + placements: List[Placement] = [Replicate() for _ in range(mesh.ndim)] + + # find all mesh dims that need pending reductions + for s in sums: + placements[s] = Partial() + + for i, m in enumerate(dim_map): + if m >= 0: + placement = placements[m] + if placement.is_shard(): + placement = cast(Shard, placement) + raise RuntimeError( + f"DeviceMesh dimension cann't be mapped to two dimension of the same tensor: {i} and {placement.dim}" + ) + elif placement.is_partial(): + raise RuntimeError( + f"DeviceMesh dimension {m} cannot be both shard and partial!" + ) + placements[m] = Shard(i) + + return cls(mesh, tuple(placements), tensor_meta=tensor_meta) + + def is_replicated(self) -> bool: + """ + return True if the current DTensorSpec replicates on all mesh dims (devices) + """ + return all(placement.is_replicate() for placement in self.placements) + + def is_sharded(self) -> bool: + """ + return True if the current DTensorSpec is sharded on any mesh dims (devices) + """ + return any(placement.is_shard() for placement in self.placements) + + def shallow_copy_with_tensor_meta( + self, tensor_meta: Optional[TensorMeta] + ) -> "DTensorSpec": + """ + Shallow copy the DTensorSpec with a new tensor_meta. + """ + assert tensor_meta is not None, "shallow copy with no tensor_meta!" + return DTensorSpec( + self.mesh, + self.placements, + tensor_meta=tensor_meta, + ) diff --git a/torch/distributed/_tensor/_op_schema.py b/torch/distributed/tensor/_op_schema.py similarity index 99% rename from torch/distributed/_tensor/_op_schema.py rename to torch/distributed/tensor/_op_schema.py index b2461086149929..190886da21fd22 100644 --- a/torch/distributed/_tensor/_op_schema.py +++ b/torch/distributed/tensor/_op_schema.py @@ -5,8 +5,9 @@ import torch from torch._ops import OpOverload -from torch.distributed._tensor.placement_types import DTensorSpec, Placement from torch.distributed.device_mesh import DeviceMesh +from torch.distributed.tensor._dtensor_spec import DTensorSpec +from torch.distributed.tensor.placement_types import Placement try: diff --git a/torch/distributed/_tensor/ops/__init__.py b/torch/distributed/tensor/_ops/__init__.py similarity index 100% rename from torch/distributed/_tensor/ops/__init__.py rename to torch/distributed/tensor/_ops/__init__.py diff --git a/torch/distributed/_tensor/ops/_common_rules.py b/torch/distributed/tensor/_ops/_common_rules.py similarity index 97% rename from torch/distributed/_tensor/ops/_common_rules.py rename to torch/distributed/tensor/_ops/_common_rules.py index f70b27076de7c4..2a41252be40e76 100644 --- a/torch/distributed/_tensor/ops/_common_rules.py +++ b/torch/distributed/tensor/_ops/_common_rules.py @@ -2,15 +2,15 @@ from typing import cast, Dict, List, Optional, Tuple import torch -from torch.distributed._tensor._op_schema import ( +from torch.distributed.tensor._dtensor_spec import DTensorSpec, TensorMeta +from torch.distributed.tensor._op_schema import ( _is_inplace_op, _is_out_variant_op, OpSchema, OutputSharding, ) -from torch.distributed._tensor._utils import compute_local_shape -from torch.distributed._tensor.ops.utils import prod -from torch.distributed._tensor.placement_types import DTensorSpec, TensorMeta +from torch.distributed.tensor._ops.utils import prod +from torch.distributed.tensor._utils import compute_local_shape def _replace_char_in_str(string: str, new_char: str, idx: int) -> str: diff --git a/torch/distributed/_tensor/ops/_conv_ops.py b/torch/distributed/tensor/_ops/_conv_ops.py similarity index 93% rename from torch/distributed/_tensor/ops/_conv_ops.py rename to torch/distributed/tensor/_ops/_conv_ops.py index 3a3b743dc4710e..db2a8136e14da0 100644 --- a/torch/distributed/_tensor/ops/_conv_ops.py +++ b/torch/distributed/tensor/_ops/_conv_ops.py @@ -4,9 +4,9 @@ from typing import List import torch -from torch.distributed._tensor._op_schema import OpSchema, OutputSharding -from torch.distributed._tensor.ops.utils import register_prop_rule -from torch.distributed._tensor.placement_types import DTensorSpec, TensorMeta +from torch.distributed.tensor._dtensor_spec import DTensorSpec, TensorMeta +from torch.distributed.tensor._op_schema import OpSchema, OutputSharding +from torch.distributed.tensor._ops.utils import register_prop_rule aten = torch.ops.aten diff --git a/torch/distributed/_tensor/ops/_einsum_strategy.py b/torch/distributed/tensor/_ops/_einsum_strategy.py similarity index 97% rename from torch/distributed/_tensor/ops/_einsum_strategy.py rename to torch/distributed/tensor/_ops/_einsum_strategy.py index 97dd43b1524dc8..fc3227600b35d2 100644 --- a/torch/distributed/_tensor/ops/_einsum_strategy.py +++ b/torch/distributed/tensor/_ops/_einsum_strategy.py @@ -2,15 +2,15 @@ from dataclasses import dataclass from typing import List, Set, Tuple -from torch.distributed._tensor._op_schema import OpStrategy, PlacementStrategy -from torch.distributed._tensor.placement_types import ( - DTensorSpec, +from torch.distributed.device_mesh import DeviceMesh +from torch.distributed.tensor._dtensor_spec import DTensorSpec +from torch.distributed.tensor._op_schema import OpStrategy, PlacementStrategy +from torch.distributed.tensor.placement_types import ( Partial, Placement, Replicate, Shard, ) -from torch.distributed.device_mesh import DeviceMesh @dataclass diff --git a/torch/distributed/_tensor/ops/_embedding_ops.py b/torch/distributed/tensor/_ops/_embedding_ops.py similarity index 98% rename from torch/distributed/_tensor/ops/_embedding_ops.py rename to torch/distributed/tensor/_ops/_embedding_ops.py index 15b2af2a01c17b..ae333b800ffcb8 100644 --- a/torch/distributed/_tensor/ops/_embedding_ops.py +++ b/torch/distributed/tensor/_ops/_embedding_ops.py @@ -7,23 +7,23 @@ import torch import torch.distributed._functional_collectives as funcol -from torch.distributed._tensor._op_schema import ( +from torch.distributed.device_mesh import DeviceMesh +from torch.distributed.tensor._op_schema import ( OpSchema, OpStrategy, PlacementList, StrategyType, ) -from torch.distributed._tensor.ops.utils import ( +from torch.distributed.tensor._ops.utils import ( expand_to_full_mesh_op_strategy, register_op_strategy, ) -from torch.distributed._tensor.placement_types import ( +from torch.distributed.tensor.placement_types import ( Partial, Placement, Replicate, Shard, ) -from torch.distributed.device_mesh import DeviceMesh aten = torch.ops.aten diff --git a/torch/distributed/_tensor/ops/_experimental_ops.py b/torch/distributed/tensor/_ops/_experimental_ops.py similarity index 69% rename from torch/distributed/_tensor/ops/_experimental_ops.py rename to torch/distributed/tensor/_ops/_experimental_ops.py index 1edfb8d294ffea..d4ab5cc3aeeb02 100644 --- a/torch/distributed/_tensor/ops/_experimental_ops.py +++ b/torch/distributed/tensor/_ops/_experimental_ops.py @@ -3,15 +3,16 @@ # implement matrix related ops for distributed tensor import torch -from torch.distributed._tensor._op_schema import ( +from torch.distributed.tensor._dtensor_spec import DTensorSpec +from torch.distributed.tensor._op_schema import ( OpSchema, OpStrategy, PlacementStrategy, StrategyType, ) -from torch.distributed._tensor.device_mesh import DeviceMesh -from torch.distributed._tensor.ops.utils import register_op_strategy -from torch.distributed._tensor.placement_types import DTensorSpec, Replicate +from torch.distributed.tensor._ops.utils import register_op_strategy +from torch.distributed.tensor.device_mesh import DeviceMesh +from torch.distributed.tensor.placement_types import Replicate aten = torch.ops.aten diff --git a/torch/distributed/_tensor/ops/_math_ops.py b/torch/distributed/tensor/_ops/_math_ops.py similarity index 99% rename from torch/distributed/_tensor/ops/_math_ops.py rename to torch/distributed/tensor/_ops/_math_ops.py index bfa2622c96cade..4905c338918595 100644 --- a/torch/distributed/_tensor/ops/_math_ops.py +++ b/torch/distributed/tensor/_ops/_math_ops.py @@ -7,7 +7,9 @@ from typing import cast, List, Optional, Sequence, Tuple, Union import torch -from torch.distributed._tensor._op_schema import ( +from torch.distributed.device_mesh import DeviceMesh +from torch.distributed.tensor._dtensor_spec import DTensorSpec +from torch.distributed.tensor._op_schema import ( OpSchema, OpStrategy, PlacementList, @@ -15,8 +17,7 @@ RuntimeSchemaInfo, TupleStrategy, ) -from torch.distributed._tensor._utils import normalize_to_torch_size -from torch.distributed._tensor.ops.utils import ( +from torch.distributed.tensor._ops.utils import ( as_list, expand_to_full_mesh_op_strategy, generate_redistribute_costs, @@ -25,14 +26,13 @@ normalize_dims, register_op_strategy, ) -from torch.distributed._tensor.placement_types import ( - DTensorSpec, +from torch.distributed.tensor._utils import normalize_to_torch_size +from torch.distributed.tensor.placement_types import ( Partial, Placement, Replicate, Shard, ) -from torch.distributed.device_mesh import DeviceMesh aten = torch.ops.aten diff --git a/torch/distributed/_tensor/ops/_matrix_ops.py b/torch/distributed/tensor/_ops/_matrix_ops.py similarity index 98% rename from torch/distributed/_tensor/ops/_matrix_ops.py rename to torch/distributed/tensor/_ops/_matrix_ops.py index 9d63dd53389987..fd9a7a430a70eb 100644 --- a/torch/distributed/_tensor/ops/_matrix_ops.py +++ b/torch/distributed/tensor/_ops/_matrix_ops.py @@ -5,15 +5,17 @@ from typing import List import torch -from torch.distributed._tensor._op_schema import ( +from torch.distributed.device_mesh import DeviceMesh +from torch.distributed.tensor._dtensor_spec import DTensorSpec +from torch.distributed.tensor._op_schema import ( OpSchema, OpStrategy, PlacementList, PlacementStrategy, RuntimeSchemaInfo, ) -from torch.distributed._tensor.ops._einsum_strategy import gen_einsum_strategies -from torch.distributed._tensor.ops.utils import ( +from torch.distributed.tensor._ops._einsum_strategy import gen_einsum_strategies +from torch.distributed.tensor._ops.utils import ( expand_to_full_mesh_op_strategy, generate_redistribute_costs, infer_broadcast_dims_map, @@ -21,13 +23,7 @@ map_placements_after_broadcast, register_op_strategy, ) -from torch.distributed._tensor.placement_types import ( - DTensorSpec, - Placement, - Replicate, - Shard, -) -from torch.distributed.device_mesh import DeviceMesh +from torch.distributed.tensor.placement_types import Placement, Replicate, Shard aten = torch.ops.aten diff --git a/torch/distributed/_tensor/ops/_pointwise_ops.py b/torch/distributed/tensor/_ops/_pointwise_ops.py similarity index 98% rename from torch/distributed/_tensor/ops/_pointwise_ops.py rename to torch/distributed/tensor/_ops/_pointwise_ops.py index 229029e863b16a..bb40865ed9c06a 100644 --- a/torch/distributed/_tensor/ops/_pointwise_ops.py +++ b/torch/distributed/tensor/_ops/_pointwise_ops.py @@ -2,7 +2,9 @@ from typing import List, Sequence, Tuple import torch -from torch.distributed._tensor._op_schema import ( +from torch.distributed.device_mesh import DeviceMesh +from torch.distributed.tensor._dtensor_spec import DTensorSpec +from torch.distributed.tensor._op_schema import ( _is_inplace_op, _is_out_variant_op, OpSchema, @@ -12,21 +14,19 @@ StrategyType, TupleStrategy, ) -from torch.distributed._tensor.ops.utils import ( +from torch.distributed.tensor._ops.utils import ( generate_redistribute_costs, infer_broadcast_dims_map, map_placements_after_broadcast, normalize_dim, register_op_strategy, ) -from torch.distributed._tensor.placement_types import ( - DTensorSpec, +from torch.distributed.tensor.placement_types import ( Partial, Placement, Replicate, Shard, ) -from torch.distributed.device_mesh import DeviceMesh aten = torch.ops.aten diff --git a/torch/distributed/_tensor/ops/_random_ops.py b/torch/distributed/tensor/_ops/_random_ops.py similarity index 90% rename from torch/distributed/_tensor/ops/_random_ops.py rename to torch/distributed/tensor/_ops/_random_ops.py index f54cfea7b00509..726b25e1eed023 100644 --- a/torch/distributed/_tensor/ops/_random_ops.py +++ b/torch/distributed/tensor/_ops/_random_ops.py @@ -1,14 +1,14 @@ # mypy: allow-untyped-decorators # Copyright (c) Meta Platforms, Inc. and affiliates import torch -from torch.distributed._tensor._op_schema import ( +from torch.distributed.device_mesh import DeviceMesh +from torch.distributed.tensor._op_schema import ( OpSchema, OpStrategy, PlacementStrategy, StrategyType, ) -from torch.distributed._tensor.ops.utils import is_tensor_partial, register_op_strategy -from torch.distributed.device_mesh import DeviceMesh +from torch.distributed.tensor._ops.utils import is_tensor_partial, register_op_strategy aten = torch.ops.aten diff --git a/torch/distributed/_tensor/ops/_tensor_ops.py b/torch/distributed/tensor/_ops/_tensor_ops.py similarity index 98% rename from torch/distributed/_tensor/ops/_tensor_ops.py rename to torch/distributed/tensor/_ops/_tensor_ops.py index 335773f4f8a286..e9bcb3b0d12240 100644 --- a/torch/distributed/_tensor/ops/_tensor_ops.py +++ b/torch/distributed/tensor/_ops/_tensor_ops.py @@ -4,7 +4,9 @@ from typing import cast, List, Optional, Sequence, Tuple import torch -from torch.distributed._tensor._op_schema import ( +from torch.distributed.device_mesh import DeviceMesh +from torch.distributed.tensor._dtensor_spec import DTensorSpec +from torch.distributed.tensor._op_schema import ( _is_inplace_op, OpSchema, OpStrategy, @@ -15,9 +17,9 @@ StrategyType, TupleStrategy, ) -from torch.distributed._tensor.ops._common_rules import pointwise_rule -from torch.distributed._tensor.ops._embedding_ops import _MaskPartial -from torch.distributed._tensor.ops.utils import ( +from torch.distributed.tensor._ops._common_rules import pointwise_rule +from torch.distributed.tensor._ops._embedding_ops import _MaskPartial +from torch.distributed.tensor._ops.utils import ( expand_to_full_mesh_op_strategy, is_tensor_dim_sharded, is_tensor_evenly_shardable, @@ -26,14 +28,12 @@ register_op_strategy, register_prop_rule, ) -from torch.distributed._tensor.placement_types import ( - DTensorSpec, +from torch.distributed.tensor.placement_types import ( Partial, Placement, Replicate, Shard, ) -from torch.distributed.device_mesh import DeviceMesh aten = torch.ops.aten diff --git a/torch/distributed/_tensor/ops/_view_ops.py b/torch/distributed/tensor/_ops/_view_ops.py similarity index 98% rename from torch/distributed/_tensor/ops/_view_ops.py rename to torch/distributed/tensor/_ops/_view_ops.py index c61e21e98b89ba..451b92c80b24b8 100644 --- a/torch/distributed/_tensor/ops/_view_ops.py +++ b/torch/distributed/tensor/_ops/_view_ops.py @@ -17,23 +17,23 @@ import torch from torch import Tensor -from torch.distributed._tensor._op_schema import ( +from torch.distributed.device_mesh import DeviceMesh +from torch.distributed.tensor._dtensor_spec import DTensorSpec +from torch.distributed.tensor._op_schema import ( OpSchema, OpStrategy, PlacementStrategy, RuntimeSchemaInfo, StrategyType, ) -from torch.distributed._tensor.api import Shard -from torch.distributed._tensor.ops.utils import ( +from torch.distributed.tensor._ops.utils import ( generate_redistribute_costs, normalize_dim, normalize_dims, prod, register_op_strategy, ) -from torch.distributed._tensor.placement_types import DTensorSpec, Placement, Replicate -from torch.distributed.device_mesh import DeviceMesh +from torch.distributed.tensor.placement_types import Placement, Replicate, Shard aten = torch.ops.aten diff --git a/torch/distributed/_tensor/ops/utils.py b/torch/distributed/tensor/_ops/utils.py similarity index 96% rename from torch/distributed/_tensor/ops/utils.py rename to torch/distributed/tensor/_ops/utils.py index 31a43b5ad8a80b..334bcc4a37fea3 100644 --- a/torch/distributed/_tensor/ops/utils.py +++ b/torch/distributed/tensor/_ops/utils.py @@ -6,18 +6,18 @@ from typing import cast, Iterable, List, Optional, Sequence, Tuple, Union import torch -from torch.distributed._tensor._collective_utils import redistribute_cost -from torch.distributed._tensor._op_schema import ( +from torch.distributed.tensor._api import DTensor +from torch.distributed.tensor._collective_utils import redistribute_cost +from torch.distributed.tensor._dtensor_spec import DTensorSpec +from torch.distributed.tensor._op_schema import ( OpSchema, OpStrategy, PlacementList, PlacementStrategy, RuntimeSchemaInfo, ) -from torch.distributed._tensor.api import DTensor -from torch.distributed._tensor.device_mesh import DeviceMesh -from torch.distributed._tensor.placement_types import ( - DTensorSpec, +from torch.distributed.tensor.device_mesh import DeviceMesh +from torch.distributed.tensor.placement_types import ( Partial, Placement, Replicate, diff --git a/torch/distributed/_tensor/random.py b/torch/distributed/tensor/_random.py similarity index 98% rename from torch/distributed/_tensor/random.py rename to torch/distributed/tensor/_random.py index ac3771be0ba2f0..db4b2832548c07 100644 --- a/torch/distributed/_tensor/random.py +++ b/torch/distributed/tensor/_random.py @@ -7,8 +7,9 @@ import torch import torch.distributed as dist from torch import Tensor -from torch.distributed._tensor.placement_types import DTensorSpec, Shard from torch.distributed.device_mesh import _get_device_handle, DeviceMesh +from torch.distributed.tensor._dtensor_spec import DTensorSpec +from torch.distributed.tensor.placement_types import Shard __all__ = [ @@ -290,7 +291,7 @@ def _set_pre_op_offset(self, spec: DTensorSpec) -> None: return_offset=False, )[0] - from torch.distributed._tensor.ops.utils import prod + from torch.distributed.tensor._ops.utils import prod local_size = prod(local_size_on_rank_0) @@ -317,7 +318,7 @@ def _set_post_op_offset(self, spec: DTensorSpec, old_offset: int) -> None: """ dtensor_shape = spec.shape - from torch.distributed._tensor.ops.utils import prod + from torch.distributed.tensor._ops.utils import prod numel = prod(dtensor_shape) # pytorch: offset must be multiple of 4 diff --git a/torch/distributed/_tensor/_redistribute.py b/torch/distributed/tensor/_redistribute.py similarity index 98% rename from torch/distributed/_tensor/_redistribute.py rename to torch/distributed/tensor/_redistribute.py index 127bf3e9857d74..6b4e37e54744f4 100644 --- a/torch/distributed/_tensor/_redistribute.py +++ b/torch/distributed/tensor/_redistribute.py @@ -6,15 +6,14 @@ import torch import torch.distributed._functional_collectives as funcol -import torch.distributed._tensor.api as dtensor -from torch.distributed._tensor.device_mesh import DeviceMesh -from torch.distributed._tensor.placement_types import ( - DTensorSpec, +import torch.distributed.tensor._api as dtensor +from torch.distributed.tensor._dtensor_spec import DTensorSpec, TensorMeta +from torch.distributed.tensor.device_mesh import DeviceMesh +from torch.distributed.tensor.placement_types import ( Partial, Placement, Replicate, Shard, - TensorMeta, ) diff --git a/torch/distributed/_tensor/_sharding_prop.py b/torch/distributed/tensor/_sharding_prop.py similarity index 99% rename from torch/distributed/_tensor/_sharding_prop.py rename to torch/distributed/tensor/_sharding_prop.py index c8c8ddb9817c1c..c277655ff4bf0f 100644 --- a/torch/distributed/_tensor/_sharding_prop.py +++ b/torch/distributed/tensor/_sharding_prop.py @@ -7,7 +7,9 @@ import torch from torch._ops import OpOverload from torch._subclasses import FakeTensorMode -from torch.distributed._tensor._op_schema import ( +from torch.distributed.device_mesh import DeviceMesh +from torch.distributed.tensor._dtensor_spec import DTensorSpec, TensorMeta +from torch.distributed.tensor._op_schema import ( OpInfo, OpSchema, OpStrategy, @@ -18,13 +20,11 @@ StrategyType, TupleStrategy, ) -from torch.distributed._tensor._utils import ( +from torch.distributed.tensor._utils import ( compute_local_shape, compute_local_stride, try_find_mesh_from_args, ) -from torch.distributed._tensor.placement_types import DTensorSpec, TensorMeta -from torch.distributed.device_mesh import DeviceMesh aten = torch.ops.aten diff --git a/torch/distributed/_tensor/_shards_wrapper.py b/torch/distributed/tensor/_shards_wrapper.py similarity index 100% rename from torch/distributed/_tensor/_shards_wrapper.py rename to torch/distributed/tensor/_shards_wrapper.py diff --git a/torch/distributed/_tensor/_tp_conv.py b/torch/distributed/tensor/_tp_conv.py similarity index 99% rename from torch/distributed/_tensor/_tp_conv.py rename to torch/distributed/tensor/_tp_conv.py index cc6f1968e6ef99..ac11ef2162cbb7 100644 --- a/torch/distributed/_tensor/_tp_conv.py +++ b/torch/distributed/tensor/_tp_conv.py @@ -5,7 +5,7 @@ import torch import torch.distributed as dist -import torch.distributed._tensor.api as dtensor +import torch.distributed.tensor._api as dtensor aten = torch.ops.aten diff --git a/torch/distributed/_tensor/_utils.py b/torch/distributed/tensor/_utils.py similarity index 98% rename from torch/distributed/_tensor/_utils.py rename to torch/distributed/tensor/_utils.py index b7c42f094be9e6..55a892063db926 100644 --- a/torch/distributed/_tensor/_utils.py +++ b/torch/distributed/tensor/_utils.py @@ -1,17 +1,17 @@ from typing import cast, List, Sequence, Tuple import torch -import torch.distributed._tensor.api as dtensor +import torch.distributed.tensor._api as dtensor from torch._prims_common import ShapeType -from torch.distributed._tensor.placement_types import ( +from torch.distributed.device_mesh import DeviceMesh +from torch.distributed.tensor._dtensor_spec import DTensorSpec +from torch.distributed.tensor.placement_types import ( _StridedShard, - DTensorSpec, Partial, Placement, Replicate, Shard, ) -from torch.distributed.device_mesh import DeviceMesh # TODO: audit existing code base to see if we can safely remove this API. diff --git a/torch/distributed/_tensor/debug/__init__.py b/torch/distributed/tensor/debug/__init__.py similarity index 58% rename from torch/distributed/_tensor/debug/__init__.py rename to torch/distributed/tensor/debug/__init__.py index 7ad08c376028dd..e5bf3b833fe478 100644 --- a/torch/distributed/_tensor/debug/__init__.py +++ b/torch/distributed/tensor/debug/__init__.py @@ -1,6 +1,6 @@ # mypy: allow-untyped-defs -from torch.distributed._tensor.debug.comm_mode import CommDebugMode -from torch.distributed._tensor.debug.visualize_sharding import visualize_sharding +from torch.distributed.tensor.debug._comm_mode import CommDebugMode +from torch.distributed.tensor.debug._visualize_sharding import visualize_sharding __all__ = ["CommDebugMode", "visualize_sharding"] @@ -12,8 +12,13 @@ def _get_sharding_prop_cache_info(): This would return a named tuple showing hits, misses, maxsize and cursize of the sharding propagator cache. """ - from torch.distributed._tensor.api import DTensor + from torch.distributed.tensor._api import DTensor return ( DTensor._op_dispatcher.sharding_propagator.propagate_op_sharding.cache_info() # type:ignore[attr-defined] ) + + +# Set namespace for exposed private names +CommDebugMode.__module__ = "torch.distributed.tensor.debug" +visualize_sharding.__module__ = "torch.distributed.tensor.debug" diff --git a/torch/distributed/_tensor/debug/comm_mode.py b/torch/distributed/tensor/debug/_comm_mode.py similarity index 96% rename from torch/distributed/_tensor/debug/comm_mode.py rename to torch/distributed/tensor/debug/_comm_mode.py index f4632fdf691ed8..6a776c88cf0b82 100644 --- a/torch/distributed/_tensor/debug/comm_mode.py +++ b/torch/distributed/tensor/debug/_comm_mode.py @@ -10,8 +10,8 @@ import torch.nn from torch._guards import detect_fake_mode from torch.autograd.graph import register_multi_grad_hook -from torch.distributed._tensor.api import DTensor from torch.distributed._tools.mod_tracker import ModTracker +from torch.distributed.tensor._api import DTensor from torch.nn.modules.module import ( register_module_forward_hook, register_module_forward_pre_hook, @@ -69,7 +69,7 @@ } -class CommModeModuleTracker(ModTracker): +class _CommModeModuleTracker(ModTracker): """ Inherits ModuleTracker and expands on its functionality to track the parameters and sharding information of a model at a module-level @@ -222,12 +222,11 @@ def print_sharding_info(self): class CommDebugMode(TorchDispatchMode): """ - ``CommDebugMode`` is a context manager that counts the number of + :class:`CommDebugMode` is a context manager that counts the number of functional collectives within its context. It does this using a ``TorchDispatchMode``. - NOTE: this mode only works for functional collective atm and the - distributed_c10d collectives are not supported yet. + .. note: Not all collectives are supported yet. Example usage @@ -237,7 +236,7 @@ class CommDebugMode(TorchDispatchMode): comm_mode = CommDebugMode() with comm_mode: mod.sum().backward() - + print(comm_mode.get_comm_counts()) """ def __init__(self): @@ -250,7 +249,7 @@ def __init__(self): self.comm_registry.add(py_op) self.comm_registry.add(torch.ops._dtensor.shard_dim_alltoall) - self.advanced_module_tracker = CommModeModuleTracker() + self.advanced_module_tracker = _CommModeModuleTracker() def generate_json_dump(self, file_name="comm_mode_log.json", noise_level=3): """ @@ -266,7 +265,7 @@ def generate_json_dump(self, file_name="comm_mode_log.json", noise_level=3): include_module_data, include_ops, include_trivial_ops, - ) = self.set_noise_parameters(noise_level) + ) = self._set_noise_parameters(noise_level) # recursively builds json data def add_json_information(json_dict, fqn): @@ -321,7 +320,9 @@ def add_json_information(json_dict, fqn): forward_operations, backward_operations, checkpointing_operations, - ) = self.get_operations_list(self.comm_module_operation_counts[fqn]) + ) = self._get_operations_list( + self.comm_module_operation_counts[fqn] + ) # remove all operations who don't have DTensor inputs if not include_ops: @@ -414,7 +415,7 @@ def generate_comm_debug_tracing_table(self, noise_level=3): include_module_data, include_ops, include_trivial_ops, - ) = self.set_noise_parameters(noise_level) + ) = self._set_noise_parameters(noise_level) table = "" for fqn in self.advanced_module_tracker.module_helper_dict: @@ -469,7 +470,9 @@ def generate_comm_debug_tracing_table(self, noise_level=3): forward_operations, backward_operations, checkpointing_operations, - ) = self.get_operations_list(self.comm_module_operation_counts[fqn]) + ) = self._get_operations_list( + self.comm_module_operation_counts[fqn] + ) def add_tracing_information(table, collectives_dict, operation_list): """ @@ -544,7 +547,7 @@ def add_operations( return table - def get_operations_list(self, module_operation_counts): + def _get_operations_list(self, module_operation_counts): forward_operations = [ op for op in module_operation_counts["operations_list"] if not op["is_bw"] ] @@ -572,12 +575,6 @@ def get_comm_counts(self) -> Dict[Any, int]: """ return self.comm_counts - def get_comm_module_counts(self) -> Dict[str, Dict[Any, int]]: - """ - Returns the communication counts at a module level as a dictionary. - """ - return self.comm_module_counts - def get_parameter_info(self) -> Dict[str, Dict[str, Any]]: return self.advanced_module_tracker.module_parameters_dict @@ -613,13 +610,7 @@ def log_comm_debug_tracing_table_to_file( with open(file_name, "w") as log_file: log_file.write(table) - def print_paramater_info(self): - self.advanced_module_tracker.print_paramater_info() - - def print_sharding_info(self): - self.advanced_module_tracker.print_sharding_info() - - def set_noise_parameters(self, noise_level): + def _set_noise_parameters(self, noise_level): """ sets variables controlling what information displays based on noise level """ diff --git a/torch/distributed/_tensor/debug/_op_coverage.py b/torch/distributed/tensor/debug/_op_coverage.py similarity index 98% rename from torch/distributed/_tensor/debug/_op_coverage.py rename to torch/distributed/tensor/debug/_op_coverage.py index 214c4f003ff2d6..258dc27a6c43db 100644 --- a/torch/distributed/_tensor/debug/_op_coverage.py +++ b/torch/distributed/tensor/debug/_op_coverage.py @@ -8,7 +8,7 @@ from functorch.compile import make_boxed_func from torch._functorch.compilers import aot_module from torch._inductor.decomposition import select_decomp_table -from torch.distributed._tensor import DTensor +from torch.distributed.tensor import DTensor inductor_decomps = select_decomp_table() diff --git a/torch/distributed/_tensor/debug/visualize_sharding.py b/torch/distributed/tensor/debug/_visualize_sharding.py similarity index 95% rename from torch/distributed/_tensor/debug/visualize_sharding.py rename to torch/distributed/tensor/debug/_visualize_sharding.py index c080fa0020e265..ade935d5133cf1 100644 --- a/torch/distributed/_tensor/debug/visualize_sharding.py +++ b/torch/distributed/tensor/debug/_visualize_sharding.py @@ -4,8 +4,8 @@ import numpy as np from torch._prims_common import ShapeType -from torch.distributed._tensor import DeviceMesh -from torch.distributed._tensor.placement_types import Placement, Shard +from torch.distributed.tensor import DeviceMesh +from torch.distributed.tensor.placement_types import Placement, Shard __all__ = ["visualize_sharding"] @@ -135,10 +135,9 @@ def _compute_local_shape_and_global_offset( def visualize_sharding(dtensor, header=""): """ - Visualizes sharding in 1D-2D dtensors - Requires tabulate, install with `pip install tabulate` + Visualizes sharding in the terminal for :class:`DTensor` that are 1D or 2D. - note: no sharding info will be printed for empty tensors + .. note:: This requires the ``tabulate`` package. No sharding info will be printed for empty tensors """ if dtensor.numel() == 0: # we do not print for empty dtensors return diff --git a/torch/distributed/_tensor/debug/comm_mode_broswer_visual.js b/torch/distributed/tensor/debug/comm_mode_broswer_visual.js similarity index 100% rename from torch/distributed/_tensor/debug/comm_mode_broswer_visual.js rename to torch/distributed/tensor/debug/comm_mode_broswer_visual.js diff --git a/torch/distributed/_tensor/device_mesh.py b/torch/distributed/tensor/device_mesh.py similarity index 100% rename from torch/distributed/_tensor/device_mesh.py rename to torch/distributed/tensor/device_mesh.py diff --git a/torch/distributed/_tensor/examples/comm_mode_features_example.py b/torch/distributed/tensor/examples/comm_mode_features_example.py similarity index 99% rename from torch/distributed/_tensor/examples/comm_mode_features_example.py rename to torch/distributed/tensor/examples/comm_mode_features_example.py index b98a8a3962c4a2..98143973145330 100644 --- a/torch/distributed/_tensor/examples/comm_mode_features_example.py +++ b/torch/distributed/tensor/examples/comm_mode_features_example.py @@ -8,8 +8,8 @@ import torch import torch.nn as nn -from torch.distributed._tensor import DeviceMesh -from torch.distributed._tensor.debug import CommDebugMode +from torch.distributed.tensor import DeviceMesh +from torch.distributed.tensor.debug import CommDebugMode from torch.distributed.tensor.parallel import ( ColwiseParallel, parallelize_module, @@ -115,7 +115,7 @@ def example_MLP_distributed_sharding_display(self) -> None: output_tp = model(inp) output_tp.sum().backward() - comm_mode.print_sharding_info() + print(comm_mode.get_sharding_info()) def example_MLPStacked_distributed_sharding_display(self) -> None: """ @@ -152,7 +152,7 @@ def example_MLPStacked_distributed_sharding_display(self) -> None: output_tp = model(inp) output_tp.sum().backward() - comm_mode.print_sharding_info() + print(comm_mode.get_sharding_info()) def example_MLP_module_tracing(self) -> None: """ diff --git a/torch/distributed/_tensor/examples/convnext_example.py b/torch/distributed/tensor/examples/convnext_example.py similarity index 99% rename from torch/distributed/_tensor/examples/convnext_example.py rename to torch/distributed/tensor/examples/convnext_example.py index 14abad0033460f..57d7bca8cc08bc 100644 --- a/torch/distributed/_tensor/examples/convnext_example.py +++ b/torch/distributed/tensor/examples/convnext_example.py @@ -12,7 +12,7 @@ import torch import torch.distributed as dist import torch.nn as nn -from torch.distributed._tensor import ( +from torch.distributed.tensor import ( DeviceMesh, distribute_module, distribute_tensor, diff --git a/torch/distributed/_tensor/examples/torchrec_sharding_example.py b/torch/distributed/tensor/examples/torchrec_sharding_example.py similarity index 98% rename from torch/distributed/_tensor/examples/torchrec_sharding_example.py rename to torch/distributed/tensor/examples/torchrec_sharding_example.py index 33f8c7017f5be4..9e6f4054e292b4 100644 --- a/torch/distributed/_tensor/examples/torchrec_sharding_example.py +++ b/torch/distributed/tensor/examples/torchrec_sharding_example.py @@ -9,23 +9,23 @@ from typing import List, TYPE_CHECKING import torch -from torch.distributed._tensor import ( +from torch.distributed.checkpoint.metadata import ( + ChunkStorageMetadata, + TensorProperties, + TensorStorageMetadata, +) +from torch.distributed.tensor import ( DeviceMesh, DTensor, init_device_mesh, Replicate, Shard, ) -from torch.distributed._tensor.debug.visualize_sharding import visualize_sharding -from torch.distributed.checkpoint.metadata import ( - ChunkStorageMetadata, - TensorProperties, - TensorStorageMetadata, -) +from torch.distributed.tensor.debug import visualize_sharding if TYPE_CHECKING: - from torch.distributed._tensor.placement_types import Placement + from torch.distributed.tensor.placement_types import Placement def get_device_type(): diff --git a/torch/distributed/_tensor/examples/visualize_sharding_example.py b/torch/distributed/tensor/examples/visualize_sharding_example.py similarity index 92% rename from torch/distributed/_tensor/examples/visualize_sharding_example.py rename to torch/distributed/tensor/examples/visualize_sharding_example.py index 5494316428b594..71cad75ef95f81 100644 --- a/torch/distributed/_tensor/examples/visualize_sharding_example.py +++ b/torch/distributed/tensor/examples/visualize_sharding_example.py @@ -6,8 +6,8 @@ import os import torch -from torch.distributed._tensor import DeviceMesh, distribute_tensor, Replicate, Shard -from torch.distributed._tensor.debug.visualize_sharding import visualize_sharding +from torch.distributed.tensor import DeviceMesh, distribute_tensor, Replicate, Shard +from torch.distributed.tensor.debug import visualize_sharding world_size = int(os.environ["WORLD_SIZE"]) diff --git a/torch/distributed/tensor/experimental/__init__.py b/torch/distributed/tensor/experimental/__init__.py new file mode 100644 index 00000000000000..5193034770af13 --- /dev/null +++ b/torch/distributed/tensor/experimental/__init__.py @@ -0,0 +1,32 @@ +# mypy: allow-untyped-defs +# Copyright (c) Meta Platforms, Inc. and affiliates +from contextlib import contextmanager + +from torch.distributed.tensor._api import DTensor +from torch.distributed.tensor.experimental._func_map import local_map +from torch.distributed.tensor.experimental._register_sharding import register_sharding + + +__all__ = ["implicit_replication", "local_map", "register_sharding"] + + +@contextmanager +def implicit_replication(): + """ + This context manager allows :class:`DTensor` to implicitly treat all non-DTensors (``torch.Tensor``) + in the program be replicate :class:`DTensor` s during the operator computation. + + .. warning:: This might possible lead to incorrect results if ``torch.Tensor`` s are not replicated + in practice, please use it at your discretion. + """ + try: + DTensor._op_dispatcher._allow_implicit_replication = True + yield + finally: + DTensor._op_dispatcher._allow_implicit_replication = False + + +# Set namespace for exposed private names +implicit_replication.__module__ = "torch.distributed.tensor.experimental" +local_map.__module__ = "torch.distributed.tensor.experimental" +register_sharding.__module__ = "torch.distributed.tensor.experimental" diff --git a/torch/distributed/_tensor/experimental/attention.py b/torch/distributed/tensor/experimental/_attention.py similarity index 99% rename from torch/distributed/_tensor/experimental/attention.py rename to torch/distributed/tensor/experimental/_attention.py index 7ed404872835aa..b5c05232a57349 100644 --- a/torch/distributed/_tensor/experimental/attention.py +++ b/torch/distributed/tensor/experimental/_attention.py @@ -24,8 +24,8 @@ import torch.distributed._functional_collectives as ft_c import torch.nn.functional as F from torch import nn -from torch.distributed._tensor import distribute_module, DTensor, Replicate, Shard from torch.distributed.device_mesh import DeviceMesh +from torch.distributed.tensor import distribute_module, DTensor, Replicate, Shard from torch.distributed.tensor.parallel.style import ParallelStyle diff --git a/torch/distributed/_tensor/experimental/func_map.py b/torch/distributed/tensor/experimental/_func_map.py similarity index 82% rename from torch/distributed/_tensor/experimental/func_map.py rename to torch/distributed/tensor/experimental/_func_map.py index de417f2d8fbfe6..23b69373e74e57 100644 --- a/torch/distributed/_tensor/experimental/func_map.py +++ b/torch/distributed/tensor/experimental/_func_map.py @@ -4,8 +4,8 @@ import torch from torch.distributed._functional_collectives import AsyncCollectiveTensor -from torch.distributed._tensor import DeviceMesh, DTensor -from torch.distributed._tensor.placement_types import Placement +from torch.distributed.tensor import DeviceMesh, DTensor +from torch.distributed.tensor.placement_types import Placement try: @@ -16,7 +16,6 @@ __all__ = ["local_map"] - PlacementType = Optional[Sequence[Placement]] InputPlacements = Optional[Tuple[PlacementType, ...]] OutputPlacements = Union[PlacementType, Tuple[PlacementType, ...]] @@ -31,32 +30,34 @@ def local_map( redistribute_inputs: bool = False, ): """ - ``local_map`` is an experimental API that allows users to apply on :class:`DTensors` - a function that is written to be applied on :class:`~torch.Tensors`. + :meth:`local_map` is an experimental API that allows users to pass :class:`DTensor` s + to a function that is written to be applied on ``torch.Tensor`` s. It is done by extracting + the local components of :class:`DTensor`, call the function, and wrap the outputs to + :class:`DTensor` according to the ``out_placements``. Args: func (Callable): the function to be applied on each local shard of - :class:`DTensor`s. + :class:`DTensor` s. out_placements (Union[`PlacementType`, Tuple[`PlacementType`, ...]]): - the desired placements of the :class:`DTensor`s in ``func``'s flattened output. + the desired placements of the :class:`DTensor` s in ``func``'s flattened output. If the flattened ``output`` is a single value, the ``out_placements`` should be of type `PlacementType`. Otherwise if the flattened ``output`` has multiple values, the ``out_placements`` should be a tuple of `PlacementType` values 1:1 mapping to the flattened ``output``. Besides, for :class:`Tensor` output, we use `PlacementType` as its - placements (a `Tuple[Placement]` value). For non-:class:`Tensor` output, - the `PlacementType` should be `None`. + placements (a `Tuple[Placement]` value). For non-Tensor output, the `PlacementType` + should be `None`. Note that the only exception is when no :class:`DTensor` argument is passed in. In this case, even if `out_placements` is not `None`, the result function - should ignore the desired placements because the application is not on - :class:`DTensors`. + should ignore the desired placements because the function is not running with + :class:`DTensor` s. in_placements (Tuple[`PlacementType`, ...], optional): - the required placements of the :class:`DTensor`s in ``func``'s flattened input. - If ``in_placements`` is specified, ``local_map`` would examine whether the + the required placements of the :class:`DTensor` s in the flattened inputs of ``func``. + If ``in_placements`` is specified, :meth:`local_map` would examine whether the placements of each :class:`DTensor` argument is the same as the required placements or not. If the placements are not the same and ``redistribute_inputs`` is ``False``, an exception will be raised. Otherwise if - ``redistribute_inputs`` is `True`, the argument will be first redistributed to + ``redistribute_inputs`` is ``True``, the argument will be first redistributed to the required sharding placements before passing its local tensor to ``func``. The only exception is when required placements are not ``None`` and the argument is a :class:`torch.Tensor`. In this case, the placements examination @@ -64,12 +65,12 @@ def local_map( If ``in_placements`` is ``None``, no placements examination will be performed. Default: None device_mesh (:class:`DeviceMesh`, optional): - the device mesh that all the :class:`DTensor`s are placed on. If not - specified, this will be inferred from the input :class:`DTensor`s' device - mesh. `local_map` requires every :class:`DTensor`s to be placed on the same + the device mesh that all the :class:`DTensor` s are placed on. If not + specified, this will be inferred from the input :class:`DTensor` s' device + mesh. `local_map` requires every :class:`DTensor` s to be placed on the same device mesh. Default: None. redistribute_inputs (bool, optional): - the bool value indicating whether to reshard the input :class:`DTensor`s when + the bool value indicating whether to reshard the input :class:`DTensor` s when their placements are different from the required input placements. If this value is ``False`` and some :class:`DTensor` input has a different placement, an exception will be raised. Default: False. @@ -79,16 +80,16 @@ def local_map( and returns a :class:`DTensor` constructed from the return value of ``func``. Raises: - AssertionError: If the input :class:`DTensor`s are not placed on the same device - mesh, or if they are placed on a different device mesh than the ``device_mesh`` - argument passed in. + AssertionError: If the input :class:`DTensor` is not placed on the same device + mesh, or if they are placed on a different device mesh than the ``device_mesh`` + argument passed in. - AssertionError: For any non-:class:`DTensor` output, we require its corresponding - output placement in ``out_placements`` be None. An AssertionError will be raised - if this is not the case. + AssertionError: For any non-DTensor output, we require its corresponding + output placement in ``out_placements`` be None. An AssertionError will be raised + if this is not the case. ValueError: If ``redistribute_inputs=False`` but the input :class:`DTensor` needs - a redistribution according to ``in_placements``. + a redistribution according to ``in_placements``. Example: >>> # xdoctest: +SKIP("distributed") @@ -115,7 +116,7 @@ def local_map( >>> X_dt = distribute_tensor(X, device_mesh, (row_wise)) # row-wisely sharded X tensor >>> Y_dt = local_mm_allreduce_forward(device_mesh, W_dt, X_dt) # apply local_mm_allreduce_forward to DTensors - NOTE: This API is currently experimental and subject to change + .. note:: This API is currently experimental and subject to change """ def wrapped(*args, **kwargs): diff --git a/torch/distributed/_tensor/experimental/register_sharding.py b/torch/distributed/tensor/experimental/_register_sharding.py similarity index 93% rename from torch/distributed/_tensor/experimental/register_sharding.py rename to torch/distributed/tensor/experimental/_register_sharding.py index a6a70c162e97fc..c526e2e0a44090 100644 --- a/torch/distributed/_tensor/experimental/register_sharding.py +++ b/torch/distributed/tensor/experimental/_register_sharding.py @@ -5,8 +5,8 @@ import torch from torch._ops import OpOverload -from torch.distributed._tensor import DeviceMesh, DTensor -from torch.distributed._tensor._op_schema import ( +from torch.distributed.tensor import DeviceMesh, DTensor +from torch.distributed.tensor._op_schema import ( _is_inplace_op, OpSchema, OpStrategy, @@ -15,7 +15,7 @@ StrategyType, TupleStrategy, ) -from torch.distributed._tensor.ops.utils import expand_to_full_mesh_op_strategy +from torch.distributed.tensor._ops.utils import expand_to_full_mesh_op_strategy __all__ = ["register_sharding"] @@ -23,8 +23,8 @@ def register_sharding(op: Union[OpOverload, List[OpOverload]]): """ - ``register_sharding`` is an experimental API that allows users to register sharding - strategies for an operator when the tensor inputs and outputs are :class:`DTensor`s. + :meth:`register_sharding` is an experimental API that allows users to register sharding + strategies for an operator when the tensor inputs and outputs are DTensor. It can be useful when: (1) there doesn't exist a default sharding strategy for ``op``, e.g. when ``op`` is a custom operator that is not supported by :class:`DTensor`; (2) when users would like to overwrite default sharding strategies of existing operators. @@ -62,6 +62,8 @@ def register_sharding(op: Union[OpOverload, List[OpOverload]]): >>> acceptable_shardings.append(all_sharded) >>> >>> return acceptable_shardings + + .. note:: This API is currently experimental and subject to change """ def custom_strategy( diff --git a/torch/distributed/_tensor/experimental/tp_transform.py b/torch/distributed/tensor/experimental/_tp_transform.py similarity index 98% rename from torch/distributed/_tensor/experimental/tp_transform.py rename to torch/distributed/tensor/experimental/_tp_transform.py index 114192c9880494..81b47c9b5a7d3b 100644 --- a/torch/distributed/_tensor/experimental/tp_transform.py +++ b/torch/distributed/tensor/experimental/_tp_transform.py @@ -5,22 +5,17 @@ import torch from torch._subclasses.fake_tensor import FakeTensor -from torch.distributed._tensor import DeviceMesh, distribute_tensor, DTensor -from torch.distributed._tensor._op_schema import ( - DTensorSpec, +from torch.distributed.tensor import DeviceMesh, distribute_tensor, DTensor +from torch.distributed.tensor._dtensor_spec import DTensorSpec, TensorMeta +from torch.distributed.tensor._op_schema import ( OpSchema, OutputSharding, OutputSpecType, PlacementStrategy, ) -from torch.distributed._tensor._redistribute import redistribute_local_tensor -from torch.distributed._tensor.placement_types import ( - Placement, - Replicate, - Shard, - TensorMeta, -) +from torch.distributed.tensor._redistribute import redistribute_local_tensor from torch.distributed.tensor.parallel.style import ColwiseParallel, ParallelStyle +from torch.distributed.tensor.placement_types import Placement, Replicate, Shard from torch.export import ExportedProgram from torch.export.exported_program import ExportGraphSignature from torch.fx import GraphModule diff --git a/torch/distributed/tensor/parallel/_data_parallel_utils.py b/torch/distributed/tensor/parallel/_data_parallel_utils.py index 2e1ebfd53ab7b3..0d1097f4328c3e 100644 --- a/torch/distributed/tensor/parallel/_data_parallel_utils.py +++ b/torch/distributed/tensor/parallel/_data_parallel_utils.py @@ -3,8 +3,8 @@ import torch from torch.distributed._functional_collectives import AsyncCollectiveTensor -from torch.distributed._tensor import DTensor -from torch.distributed._tensor.placement_types import DTensorSpec +from torch.distributed.tensor import DTensor +from torch.distributed.tensor._dtensor_spec import DTensorSpec @no_type_check diff --git a/torch/distributed/tensor/parallel/_utils.py b/torch/distributed/tensor/parallel/_utils.py index 0eace5f6d78061..f50b5dd64768d0 100644 --- a/torch/distributed/tensor/parallel/_utils.py +++ b/torch/distributed/tensor/parallel/_utils.py @@ -2,9 +2,9 @@ import warnings from typing import Tuple, Union -from torch.distributed._tensor import DeviceMesh -from torch.distributed._tensor.placement_types import Placement from torch.distributed.device_mesh import _mesh_resources +from torch.distributed.tensor import DeviceMesh +from torch.distributed.tensor.placement_types import Placement try: diff --git a/torch/distributed/tensor/parallel/api.py b/torch/distributed/tensor/parallel/api.py index e0fc4d2ef2b725..db4db018cce2d3 100644 --- a/torch/distributed/tensor/parallel/api.py +++ b/torch/distributed/tensor/parallel/api.py @@ -3,10 +3,10 @@ from typing import Dict, Union import torch -import torch.distributed._tensor.random as random +import torch.distributed.tensor._random as random import torch.nn as nn -from torch.distributed._tensor import DeviceMesh -from torch.distributed._tensor.random import ( +from torch.distributed.tensor import DeviceMesh +from torch.distributed.tensor._random import ( is_rng_supported_mesh, TensorParallelRNGTracker, ) diff --git a/torch/distributed/tensor/parallel/fsdp.py b/torch/distributed/tensor/parallel/fsdp.py index 8c4a7a655a8ffd..d2faa9ed32dc5d 100644 --- a/torch/distributed/tensor/parallel/fsdp.py +++ b/torch/distributed/tensor/parallel/fsdp.py @@ -14,12 +14,12 @@ ) from torch.distributed._shard.sharding_spec import ShardMetadata from torch.distributed._shard.sharding_spec.chunk_sharding_spec import ChunkShardingSpec -from torch.distributed._tensor import DeviceMesh, DTensor, Replicate, Shard as DShard from torch.distributed.device_mesh import _mesh_resources from torch.distributed.fsdp._common_utils import _set_fsdp_flattened from torch.distributed.fsdp._fsdp_extensions import FSDPExtensions from torch.distributed.fsdp._shard_utils import _create_chunk_sharded_tensor from torch.distributed.remote_device import _remote_device +from torch.distributed.tensor import DeviceMesh, DTensor, Replicate, Shard as DShard from torch.distributed.tensor.parallel._data_parallel_utils import ( _flatten_tensor, _unflatten_tensor, diff --git a/torch/distributed/tensor/parallel/input_reshard.py b/torch/distributed/tensor/parallel/input_reshard.py index 4e7af55d32c356..630c287cae88f5 100644 --- a/torch/distributed/tensor/parallel/input_reshard.py +++ b/torch/distributed/tensor/parallel/input_reshard.py @@ -3,7 +3,7 @@ from typing import Any, Optional, Tuple import torch -from torch.distributed._tensor import DeviceMesh, DTensor, Replicate, Shard +from torch.distributed.tensor import DeviceMesh, DTensor, Replicate, Shard __all__ = [ diff --git a/torch/distributed/tensor/parallel/loss.py b/torch/distributed/tensor/parallel/loss.py index 79f6f08e595393..99f1e3ad6ef9ad 100644 --- a/torch/distributed/tensor/parallel/loss.py +++ b/torch/distributed/tensor/parallel/loss.py @@ -8,15 +8,16 @@ import torch.distributed._functional_collectives as funcol import torch.distributed.distributed_c10d as c10d from torch import Tensor -from torch.distributed._tensor import DTensor, Replicate, Shard -from torch.distributed._tensor.ops._embedding_ops import _MaskPartial -from torch.distributed._tensor.ops._math_ops import ( +from torch.distributed.device_mesh import DeviceMesh +from torch.distributed.tensor import DTensor, Replicate, Shard +from torch.distributed.tensor._dtensor_spec import DTensorSpec, TensorMeta +from torch.distributed.tensor._ops._embedding_ops import _MaskPartial +from torch.distributed.tensor._ops._math_ops import ( _skip_dim, Reduction, replicate_reduction_dims, ) -from torch.distributed._tensor.placement_types import DTensorSpec, Placement, TensorMeta -from torch.distributed.device_mesh import DeviceMesh +from torch.distributed.tensor.placement_types import Placement aten = torch.ops.aten diff --git a/torch/distributed/tensor/parallel/style.py b/torch/distributed/tensor/parallel/style.py index 46b19a918296c4..d1e8dd7e236cd1 100644 --- a/torch/distributed/tensor/parallel/style.py +++ b/torch/distributed/tensor/parallel/style.py @@ -6,7 +6,7 @@ import torch import torch.nn as nn -from torch.distributed._tensor import ( +from torch.distributed.tensor import ( DeviceMesh, distribute_module, distribute_tensor, @@ -14,7 +14,7 @@ Replicate, Shard, ) -from torch.distributed._tensor.placement_types import Placement +from torch.distributed.tensor.placement_types import Placement __all__ = [ diff --git a/torch/distributed/tensor/placement_types.py b/torch/distributed/tensor/placement_types.py new file mode 100644 index 00000000000000..0d9834ab8b81cd --- /dev/null +++ b/torch/distributed/tensor/placement_types.py @@ -0,0 +1,652 @@ +# mypy: allow-untyped-defs +# Copyright (c) Meta Platforms, Inc. and affiliates + +from dataclasses import dataclass +from typing import cast, List, Optional, Tuple + +import torch +import torch.distributed._functional_collectives as funcol +from torch.distributed.device_mesh import DeviceMesh +from torch.distributed.tensor._collective_utils import ( + fill_empty_tensor_to_shards, + mesh_broadcast, + mesh_scatter, + pad_tensor, + shard_dim_alltoall, + unpad_tensor, +) + + +__all__ = ["Placement", "Shard", "Replicate", "Partial"] + + +class Placement: + """ + The base class for the Placement type, where it describes how a DTensor is placed onto the + ``DeviceMesh``. ``Placement`` and ``DeviceMesh`` together could describe the DTensor Layout. + It is the base class of the three main DTensor Placement types: ``Shard``, ``Replicate``, + and ``Partial``. + + This class is not meant to be used directly, mainly served as a typing stub. + """ + + # convenient utils to check for placement types + def is_shard(self, dim: Optional[int] = None) -> bool: + is_shard_instance = isinstance(self, Shard) + if dim is not None and is_shard_instance: + return cast(Shard, self).dim == dim + else: + return is_shard_instance + + def is_replicate(self) -> bool: + return isinstance(self, Replicate) + + def is_partial(self) -> bool: + return isinstance(self, Partial) + + +@dataclass(frozen=True) +class Shard(Placement): + """ + The ``Shard(dim)`` placement describes the DTensor sharding on tensor dimension + ``dim`` over a corresponding ``DeviceMesh`` dimension, where each rank on the + DeviceMesh dimension only holds a shard/piece of the global Tensor. The + ``Shard(dim)`` placement follows the ``torch.chunk(dim)`` semantic, where the + last few shards on the DeviceMesh dimension might be empty when the tensor dimension + is not evenly divisble on the DeviceMesh dimension. The ``Shard`` placement can be + used by all DTensor APIs (i.e. distribute_tensor, from_local, etc.) + + Args: + dim (int): The tensor dimension that describes the DTensor is sharded over its + corresponding DeviceMesh dimension. + + .. warning:: sharding on a tensor dimension where the tensor dimension size is not + evenly divisible on a DeviceMesh dimension is currently experimental and subject to change. + """ + + dim: int + + def _split_tensor( + self, + tensor: torch.Tensor, + num_chunks: int, + *, + with_padding: bool = True, + contiguous: bool = True, + ) -> Tuple[List[torch.Tensor], List[int]]: + """ + This function uses torch.chunk to split a tensor into num_chunks shards along + the Shard placement dimension, and return a list of shards with their pad sizes. + + Keyword args: + with_padding (bool, optional): when True, we pad the tensor on the last + few ranks before calling the collectives (i.e. scatter/all_gather, etc.). + This is because collectives usually require equal size tensor inputs + """ + assert ( + self.dim <= tensor.ndim + ), f"Sharding dim {self.dim} greater than tensor ndim {tensor.ndim}" + + # chunk tensor over dimension `dim` into n slices + tensor_list = list(torch.chunk(tensor, num_chunks, dim=self.dim)) + num_empty_tensors = num_chunks - len(tensor_list) + + # if no need to have padding or tensor dim size is evenly sharded already + # we can return early. + if not with_padding or tensor.size(self.dim) % num_chunks == 0: + if contiguous: + tensor_list = [t.contiguous() for t in tensor_list] + return ( + fill_empty_tensor_to_shards(tensor_list, self.dim, num_empty_tensors), + [], + ) + + # compute the chunk size inline with ``torch.chunk`` to calculate padding + full_chunk_size = (tensor.size(self.dim) + num_chunks - 1) // num_chunks + + # Compute chunk size for each chunk for ``self.dim`` + chunk_sizes = [ + tensor_list[idx].size(self.dim) if idx < len(tensor_list) else 0 + for idx in range(num_chunks) + ] + # Compute pad size on each chunk + pad_sizes = [full_chunk_size - chunk_size for chunk_size in chunk_sizes] + + # Reuse tensor to fill empty chunk with empty tensor + tensor_list = fill_empty_tensor_to_shards( + tensor_list, self.dim, num_empty_tensors + ) + shard_list = [] + for shard, pad_size in zip(tensor_list, pad_sizes): + # Fill the empty tensor with zeroes with padding. + if with_padding and pad_size > 0: + shard = pad_tensor(shard, self.dim, pad_size) + shard = shard.contiguous() if contiguous else shard + shard_list.append(shard) + return shard_list, pad_sizes + + @staticmethod + def _local_shard_size_on_dim( + size_on_dim: int, + num_chunks: int, + rank: int, + return_offset: bool = False, + ) -> Tuple[int, int]: + """ + returns the local shard size and offset on a given tensor dim + """ + # Compute the chunk size inline with ``torch.chunk`` + if size_on_dim % num_chunks == 0: + full_chunk_size = size_on_dim // num_chunks + return full_chunk_size, full_chunk_size * rank if return_offset else -1 + + # uneven sharding case + full_chunk_size = (size_on_dim + num_chunks - 1) // num_chunks + shard_starting_idx = full_chunk_size * rank + + if size_on_dim < shard_starting_idx: + return 0, size_on_dim if return_offset else -1 + else: + local_shard_size = ( + min(size_on_dim, shard_starting_idx + full_chunk_size) + - shard_starting_idx + ) + return local_shard_size, shard_starting_idx if return_offset else -1 + + def _shard_tensor( + self, tensor: torch.Tensor, mesh: DeviceMesh, mesh_dim: int + ) -> torch.Tensor: + """ + shard and scatter a tensor on a mesh dimension (use coordinate + 0 on the mesh dimension as source of truth) + """ + my_coordinate = mesh.get_coordinate() + num_chunks = mesh.size(mesh_dim=mesh_dim) + + if my_coordinate is None: + # if rank is not part of mesh, we simply return an empty tensor + return tensor.new_empty(0, requires_grad=tensor.requires_grad) + + scatter_list, pad_sizes = self._split_tensor( + tensor, num_chunks, with_padding=True, contiguous=True + ) + + mesh_dim_local_rank = my_coordinate[mesh_dim] + output = torch.empty_like(scatter_list[mesh_dim_local_rank]) + mesh_scatter(output, scatter_list, mesh, mesh_dim=mesh_dim) + + # Only unpad if the local_tensor was padded on the dimension. + if pad_sizes and pad_sizes[mesh_dim_local_rank] > 0: + output = unpad_tensor(output, self.dim, pad_sizes[mesh_dim_local_rank]) + return output + + def _reduce_shard_tensor( + self, + tensor: torch.Tensor, + mesh: DeviceMesh, + reduce_op: str, + mesh_dim: int, + ) -> torch.Tensor: + """ + reduce and scatter a tensor on a mesh dimension + """ + my_coordinate = mesh.get_coordinate() + num_chunks = mesh.size(mesh_dim=mesh_dim) + + if my_coordinate is None: + # if rank is not part of mesh, we simply return local_tensor, + # which should be an empty tensor + return tensor + + is_padded = tensor.size(self.dim) % num_chunks != 0 + if is_padded: + scattered_list, pad_sizes = self._split_tensor( + tensor, num_chunks, with_padding=True, contiguous=True + ) + tensor = torch.cat(scattered_list, dim=self.dim) + elif not tensor.is_contiguous(): + tensor = tensor.contiguous() + + output = funcol.reduce_scatter_tensor( + tensor, reduce_op, scatter_dim=self.dim, group=(mesh, mesh_dim) + ) + + if is_padded: + output = unpad_tensor(output, self.dim, pad_sizes[my_coordinate[mesh_dim]]) # type: ignore[possibly-undefined] + return output + + def _to_replicate_tensor( + self, + local_tensor: torch.Tensor, + mesh: DeviceMesh, + mesh_dim: int, + current_logical_shape: List[int], + ) -> torch.Tensor: + """ + This function all_gather all shards and return a tensor that + is replicated on the previously sharded mesh dimension + """ + num_chunks = mesh.size(mesh_dim=mesh_dim) + # check if it's uneven, so we need to pad input tensor before all_gather + local_shape = list(local_tensor.size()) + + logical_dim_size = current_logical_shape[self.dim] + is_padded = logical_dim_size % num_chunks != 0 + + if is_padded: + full_chunk_size = (logical_dim_size + num_chunks - 1) // num_chunks + pad_size = full_chunk_size - local_shape[self.dim] + local_tensor = pad_tensor(local_tensor, self.dim, pad_size) + + if not local_tensor.is_contiguous(): + local_tensor = local_tensor.contiguous() + + result = funcol.all_gather_tensor( + local_tensor, + gather_dim=self.dim, + group=(mesh, mesh_dim), + ) + if is_padded: + unpad_size = full_chunk_size * num_chunks - logical_dim_size # type: ignore[possibly-undefined] + result = unpad_tensor(result, self.dim, unpad_size) + return result + + def _replicate_to_shard( + self, + local_tensor: torch.Tensor, + mesh: DeviceMesh, + mesh_dim: int, + shard_index: int, + ) -> torch.Tensor: + """ + transform from replicated tensor to a sharded tensor on + the current rank, which would perform a local chunk + """ + num_chunks = mesh.size(mesh_dim=mesh_dim) + shards, _ = self._split_tensor( + local_tensor, + num_chunks, + with_padding=False, + contiguous=False, + ) + return shards[shard_index].clone() + + def _to_new_shard_dim( + self, + local_tensor: torch.Tensor, + mesh: DeviceMesh, + mesh_dim: int, + current_logical_shape: List[int], + new_shard_dim: int, + ) -> torch.Tensor: + """ + transform from existing sharded tensor to a new sharded tensor on + that shard on a new dimension, which performs an alltoall + """ + my_coordinate = mesh.get_coordinate() + if my_coordinate is None: + # if rank is not part of mesh, we simply return local_tensor, + # which should be an empty tensor + return local_tensor + + num_chunks = mesh.size(mesh_dim=mesh_dim) + + old_dim_logical_size = current_logical_shape[self.dim] + new_dim_logical_size = current_logical_shape[new_shard_dim] + old_dim_padding = old_dim_logical_size % num_chunks != 0 + new_dim_padding = new_dim_logical_size % num_chunks != 0 + if old_dim_padding: + old_dim_full_chunk_size = ( + old_dim_logical_size + num_chunks - 1 + ) // num_chunks + old_dim_pad_size = old_dim_full_chunk_size - local_tensor.size(self.dim) + local_tensor = pad_tensor(local_tensor, self.dim, old_dim_pad_size) + if new_dim_padding: + new_dim_full_chunk_size = ( + new_dim_logical_size + num_chunks - 1 + ) // num_chunks + new_dim_pad_size = new_dim_full_chunk_size * num_chunks - local_tensor.size( + new_shard_dim + ) + local_tensor = pad_tensor(local_tensor, new_shard_dim, new_dim_pad_size) + + if not local_tensor.is_contiguous(): + local_tensor = local_tensor.contiguous() + + new_tensor = shard_dim_alltoall( + local_tensor, self.dim, new_shard_dim, mesh, mesh_dim + ) + + if old_dim_padding: + old_dim_unpad_size = ( + old_dim_full_chunk_size * num_chunks - current_logical_shape[self.dim] # type: ignore[possibly-undefined] + ) + new_tensor = unpad_tensor(new_tensor, self.dim, old_dim_unpad_size) # type: ignore[possibly-undefined] + + if new_dim_padding: + local_shard_size_on_new_dim = self._local_shard_size_on_dim( + new_dim_logical_size, num_chunks, my_coordinate[mesh_dim] + )[0] + new_dim_unpad_size = new_dim_full_chunk_size - local_shard_size_on_new_dim # type: ignore[possibly-undefined] + new_tensor = unpad_tensor(new_tensor, new_shard_dim, new_dim_unpad_size) # type: ignore[possibly-undefined] + + return new_tensor + + def __eq__(self, other: object) -> bool: + if not isinstance(other, Shard): + return False + return self.dim == other.dim + + def __hash__(self) -> int: + return hash(self.dim) + + def __repr__(self) -> str: + """ + machine readable representation of the Shard placement + """ + return f"Shard(dim={self.dim})" + + def __str__(self) -> str: + """human readable representation of the Shard placement""" + return f"S({self.dim})" + + +# kw_only is only available in python >= 3.10 +kw_only_dataclass = dict(kw_only=True) if "kw_only" in dataclass.__kwdefaults__ else {} + + +@dataclass(frozen=True, **kw_only_dataclass) +class _StridedShard(Shard): + """ + _StridedShard is only introduced to support 2D FSDP2 + TP sharding where the tensor + is sharded on the TP mesh dimension first, then sharded on the FSDP mesh dimension. + We call this right-to-left sharding which is the opposite of the default + left-to-right sharding. See the example below: + tensor shape: [8, 8] + mesh: [[0, 1], [2, 3]], names=("dp", "tp") + placements: [Shard(0), Shard(0)] + + The default sharding behavior shards the tensor on "dp" mesh dimension first then + "tp" dimension. The sharding result will be: + Rank | Mesh Coordinate | Shard Index + ------------------------------------------------ + 0 | (0, 0) | 0 (row 0-1) + 1 | (0, 1) | 1 (row 2-3) + 2 | (1, 0) | 2 (row 4-5) + 3 | (1, 1) | 3 (row 6-7) + + While the FSDP2 + TP sharding behavior does the opposite: it shards the tensor on + "tp" mesh dim first then "dp" dim. This right-to-left sharding will produce the + result: + Rank | Mesh Coordinate | Shard Index + ------------------------------------------------ + 0 | (0, 0) | 0 (row 0-1) + 1 | (0, 1) | 2 (row 4-5) + 2 | (1, 0) | 1 (row 2-3) + 3 | (1, 1) | 3 (row 6-7) + + The consequence is, any attempt to redistribute this DTensor to a full replica will + produce a wrong result because the shard-to-replicate redistribution always happens + right-to-left, regardless it's left-to-right sharding or right-to-left. To address + this, we use _StridedShard placement to make this right-to-left sharding compatible + with our left-to-right convention on both tensor distribution and redistribution. + + Now with _StridedShard, the right-to-left sharding above can be represented as: + tensor shape: [8, 8] + mesh: [[0, 1], [2, 3]], names=("dp", "tp") + placements: [_StridedShard(0, split_factor=2), Shard(0)] + + And a left-to-right processing of `placements` will produce the same result, which is + different from using the `Shard` placement: + Rank | Mesh Coordinate | Shard Index + ------------------------------------------------ + 0 | (0, 0) | 0 (row 0-1) + 1 | (0, 1) | 2 (row 4-5) + 2 | (1, 0) | 1 (row 2-3) + 3 | (1, 1) | 3 (row 6-7) + + The argument `split_factor` is the number of existing shards over the tensor sharding + dimension before processing the _StridedShard placement, as if the sharding happened + right-to-left. In the example above, the tensor should first be sharded on the "tp" + dimension into 2 shards before being sharded on the "dp" dimension. Therefore, the + `split_factor` of the _StridedShard placement on "dp" dim is 2. + + TODO: strided sharding needs to work fine with uneven sharding. Now it forbids + resharding if the tensor is unevenly sharded. + TODO: we should remove _StridedShard placement once we can unify it with Shard + """ + + split_factor: int + + def __eq__(self, other: object) -> bool: + if isinstance(other, _StridedShard): + return self.dim == other.dim and self.split_factor == other.split_factor + elif isinstance(other, Shard): + # TODO: this is to avoid extra all-gather in dtensor op dispatch + # note that sharding prop would not produce _StridedShard and an + # placement inequality would introduce an all-gather for resharding + return self.dim == other.dim + return False + + def __hash__(self) -> int: + return hash((self.dim, self.split_factor)) + + def __repr__(self) -> str: + """ + machine readable representation of the _StridedShard placement + """ + return f"_StridedShard(dim={self.dim}, sf={self.split_factor})" + + def __str__(self) -> str: + """human readable representation of the _StridedShard placement""" + return f"_S({self.dim}, {self.split_factor})" + + def _split_tensor( + self, + tensor: torch.Tensor, + num_chunks: int, + *, + with_padding: bool = True, + contiguous: bool = True, + ) -> Tuple[List[torch.Tensor], List[int]]: + """ + TODO: currently _StridedShard does not support padding + """ + assert ( + self.dim <= tensor.ndim + ), f"Sharding dim {self.dim} greater than tensor ndim {tensor.ndim}" + + total_split = num_chunks * self.split_factor + assert tensor.size(self.dim) % total_split == 0, ( + "_StridedShard currently only allows even sharding but got tensor size" + f" {tensor.size(self.dim)} on dim {self.dim} and total split" + f" {total_split}={num_chunks} * {self.split_factor}" + ) + + group_size = self.split_factor + total_split_tensor_list = list(torch.chunk(tensor, total_split, dim=self.dim)) + tensor_list = [ + torch.cat( + [ + total_split_tensor_list[i + j * num_chunks] # stride is num_chunks + for j in range(group_size) + ], + dim=self.dim, + ) + for i in range(num_chunks) + ] + + if contiguous: + tensor_list = [t.contiguous() for t in tensor_list] + + return tensor_list, [] + + def _to_replicate_tensor( + self, + local_tensor: torch.Tensor, + mesh: DeviceMesh, + mesh_dim: int, + current_logical_shape: List[int], + ) -> torch.Tensor: + """ + Note: currently _StridedShard does not support padding + """ + num_chunks = mesh.size(mesh_dim=mesh_dim) + total_split = num_chunks * self.split_factor + # NOTE: we require Strided Sharding to be even for now + assert current_logical_shape[self.dim] % total_split == 0, ( + "_StridedShard requires even sharding but got tensor size " + f"{current_logical_shape[self.dim]} on dim {self.dim} and " + f"total split {total_split}=num_chunks {num_chunks} " + f"* split_factor {self.split_factor}" + ) + + result = funcol.all_gather_tensor( + local_tensor, + gather_dim=self.dim, + group=(mesh, mesh_dim), + ) + if isinstance(result, funcol.AsyncCollectiveTensor): + result = result.wait() + + tensor_shard_list = torch.chunk(result, total_split, dim=self.dim) + # rearrange the order + new_tensor_shard_list = [] + for idx in range(len(tensor_shard_list)): + # the shard split of index `idx` is assigned a new index within + # _StridedShard._split_tensor: + # the original tensor was split into `total_split` chunks, + # all chunks with the same `idx % num_chunks` are merged into one + # new shard and placed on mesh's local rank `idx % num_chunks` + idx_after_split = idx % num_chunks * self.split_factor + idx // num_chunks + new_tensor_shard_list.append(tensor_shard_list[idx_after_split]) + + return torch.cat(new_tensor_shard_list, dim=self.dim).contiguous() + + +@dataclass(frozen=True) +class Replicate(Placement): + """ + The ``Replicate()`` placement describes the DTensor replicating on a corresponding + ``DeviceMesh`` dimension, where each rank on the DeviceMesh dimension holds a + replica of the global Tensor. The ``Replicate`` placement can be used by all + DTensor APIs (i.e. ``distribute_tensor``, ``DTensor.from_local``, etc.) + """ + + def __eq__(self, other: object) -> bool: + return isinstance(other, Replicate) + + def __hash__(self) -> int: + # every replicate placement is the same + return -1 + + def __repr__(self) -> str: + """ + machine readable representation of the Replicate placement + """ + return "Replicate()" + + def __str__(self) -> str: + """ + human readable representation of the Replicate placement + """ + return "R" + + def _replicate_tensor( + self, tensor: torch.Tensor, mesh: DeviceMesh, mesh_dim: int + ) -> torch.Tensor: + """ + Replicate (broadcast) a torch.Tensor on a mesh dimension (use + the first coordinate on the mesh dimension as source of truth) + """ + my_coordinate = mesh.get_coordinate() + if my_coordinate is None: + # if rank is not part of mesh, we simply return an empty tensor + return tensor.new_empty(0, requires_grad=tensor.requires_grad) + + tensor = tensor.contiguous() + mesh_broadcast(tensor, mesh, mesh_dim=mesh_dim) + return tensor + + +@dataclass(frozen=True) +class Partial(Placement): + """ + The ``Partial(reduce_op)`` placement describes the DTensor that is pending + reduction on a specified ``DeviceMesh`` dimension, where each rank on the + DeviceMesh dimension holds the partial value of the global Tensor. User can + redistribute the ``Partial`` DTensor to a ``Replicate`` or ``Shard(dim)`` + placement on the specified ``DeviceMesh`` dimension using ``redistribute``, + which would trigger necessary communication operations under the hood (i.e. + ``allreduce``, ``reduce_scatter``). + + Args: + reduce_op (str, optional): The reduction op to be used for the partial DTensor + to produce Replicated/Sharded DTensor. Only element-wise reduction operations + are supported, including: "sum", "avg", "product", "max", "min", default: "sum". + + .. note:: The ``Partial`` placement can be generated as a result of the DTensor operators, + and can only be used by the ``DTensor.from_local`` API. + """ + + reduce_op: str = "sum" + + def _reduce_value( + self, tensor: torch.Tensor, mesh: DeviceMesh, mesh_dim: int + ) -> torch.Tensor: + # Partial placement contract #1: + # _reduce_value: reduce the value of the tensor on the mesh dimension + return funcol.all_reduce( + tensor, reduceOp=self.reduce_op, group=(mesh, mesh_dim) + ) + + def _reduce_shard_value( + self, + tensor: torch.Tensor, + mesh: DeviceMesh, + mesh_dim: int, + shard_spec: Placement, + ) -> torch.Tensor: + # Partial placement contract #2: + # _reduce_shard_value: reduce_scatter the value of the tensor over the mesh dimension + shard_spec = cast(Shard, shard_spec) + return shard_spec._reduce_shard_tensor(tensor, mesh, self.reduce_op, mesh_dim) + + def _partition_value( + self, tensor: torch.Tensor, mesh: DeviceMesh, mesh_dim: int + ) -> torch.Tensor: + # Partial placement contract #3: + # _partition_value: partition the value of a replicated tensor on the mesh dimension + + # _partition_value is the conjugate operation of _reduce_value + # - i.e. _partition_value on a sum reduce op is just a divison operation + # - the _reduce_value on a sum reduce op would just be a sum(allreduce) operation + # TODO: if the reduce_op is min/max, etc. the _partition_value should be a + # different operation + assert self.reduce_op == "sum", "only support replicate to PartialSUM for now!" + num_chunks = mesh.size(mesh_dim=mesh_dim) + return tensor / num_chunks + + def __eq__(self, other: object) -> bool: + if not isinstance(other, Partial): + return False + return self.reduce_op == other.reduce_op + + def __hash__(self) -> int: + return 1 + hash(self.reduce_op) + + def __repr__(self) -> str: + """ + machine readable representation of the Partial placement + """ + return f"Partial({self.reduce_op})" + + def __str__(self) -> str: + """ + human readable representation of the Partial placement + """ + return "P" + + +# We keep the old _Partial name for a while for BC reason +_Partial = Partial diff --git a/torch/testing/_internal/common_fsdp.py b/torch/testing/_internal/common_fsdp.py index fe02eeeabb1baf..f9eff69767931c 100644 --- a/torch/testing/_internal/common_fsdp.py +++ b/torch/testing/_internal/common_fsdp.py @@ -34,7 +34,6 @@ FSDPParamGroup, RegisterPostBackwardFunction, ) -from torch.distributed._tensor import distribute_tensor, DTensor, Shard from torch.distributed.device_mesh import DeviceMesh from torch.distributed.fsdp import CPUOffload, FullyShardedDataParallel as FSDP from torch.distributed.fsdp._common_utils import TrainingState @@ -46,6 +45,7 @@ ) from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler from torch.distributed.fsdp.wrap import always_wrap_policy, ModuleWrapPolicy, wrap +from torch.distributed.tensor import distribute_tensor, DTensor, Shard from torch.distributed.tensor.parallel import ( ColwiseParallel, parallelize_module, From 220a68a002d3d51747f0299ce5d3ebb1154336bd Mon Sep 17 00:00:00 2001 From: cyy Date: Sun, 8 Sep 2024 17:18:29 +0000 Subject: [PATCH 0611/1018] [22/N] Fix clang-tidy warnings in jit (#135319) Follows #134537 Pull Request resolved: https://github.com/pytorch/pytorch/pull/135319 Approved by: https://github.com/titaiwangms --- torch/csrc/jit/passes/onnx.cpp | 1 - .../onnx/cast_all_constant_to_floating.cpp | 6 +-- .../onnx/cast_all_constant_to_floating.h | 6 +-- torch/csrc/jit/passes/onnx/constant_fold.cpp | 6 +-- torch/csrc/jit/passes/onnx/constant_fold.h | 7 +-- torch/csrc/jit/passes/onnx/constant_map.cpp | 54 +++++++++---------- torch/csrc/jit/passes/onnx/constant_map.h | 8 ++- .../passes/onnx/deduplicate_initializers.cpp | 6 +-- .../passes/onnx/deduplicate_initializers.h | 7 +-- .../passes/onnx/eliminate_unused_items.cpp | 6 +-- .../jit/passes/onnx/eliminate_unused_items.h | 7 +-- torch/csrc/jit/passes/onnx/eval_peephole.cpp | 6 +-- torch/csrc/jit/passes/onnx/eval_peephole.h | 7 +-- .../passes/onnx/fixup_onnx_controlflow.cpp | 22 +++----- .../jit/passes/onnx/fixup_onnx_controlflow.h | 6 +-- .../jit/passes/onnx/function_extraction.cpp | 26 +++++---- .../jit/passes/onnx/function_extraction.h | 10 +--- .../jit/passes/onnx/function_substitution.cpp | 6 +-- .../jit/passes/onnx/function_substitution.h | 4 +- torch/csrc/jit/passes/onnx/helper.cpp | 6 +-- torch/csrc/jit/passes/onnx/helper.h | 6 +-- .../jit/passes/onnx/list_model_parameters.cpp | 6 +-- .../jit/passes/onnx/list_model_parameters.h | 6 +-- torch/csrc/jit/passes/onnx/naming.cpp | 30 +++++------ torch/csrc/jit/passes/onnx/naming.h | 14 ++--- torch/csrc/jit/passes/onnx/onnx_log.cpp | 8 +-- torch/csrc/jit/passes/onnx/onnx_log.h | 8 +-- torch/csrc/jit/passes/onnx/peephole.cpp | 6 +-- torch/csrc/jit/passes/onnx/peephole.h | 6 +-- .../passes/onnx/prepare_division_for_onnx.cpp | 6 +-- .../passes/onnx/prepare_division_for_onnx.h | 6 +-- .../jit/passes/onnx/preprocess_for_onnx.cpp | 6 +-- .../jit/passes/onnx/preprocess_for_onnx.h | 6 +-- .../onnx/remove_inplace_ops_for_onnx.cpp | 20 ++++--- .../passes/onnx/remove_inplace_ops_for_onnx.h | 6 +-- .../jit/passes/onnx/scalar_type_analysis.cpp | 6 +-- .../jit/passes/onnx/scalar_type_analysis.h | 6 +-- .../jit/passes/onnx/shape_type_inference.cpp | 33 ++++++------ .../jit/passes/onnx/shape_type_inference.h | 14 +++-- .../passes/onnx/unpack_quantized_weights.cpp | 7 ++- .../passes/onnx/unpack_quantized_weights.h | 6 +-- 41 files changed, 159 insertions(+), 260 deletions(-) diff --git a/torch/csrc/jit/passes/onnx.cpp b/torch/csrc/jit/passes/onnx.cpp index 50e034dd40f0b3..238b8e5c236efe 100644 --- a/torch/csrc/jit/passes/onnx.cpp +++ b/torch/csrc/jit/passes/onnx.cpp @@ -164,7 +164,6 @@ void PreprocessCaffe2Ops(std::shared_ptr& graph) { std::shared_ptr ToONNX( std::shared_ptr& graph, ::torch::onnx::OperatorExportTypes operator_export_type) { - auto constant_value_map = ConstantValueMap::getInstance(); ConstantValueMap::ClearMaps(); auto new_graph = std::make_shared(graph->current_scope()); py::dict env; diff --git a/torch/csrc/jit/passes/onnx/cast_all_constant_to_floating.cpp b/torch/csrc/jit/passes/onnx/cast_all_constant_to_floating.cpp index 65aefb9a577613..5a62e02b628e53 100644 --- a/torch/csrc/jit/passes/onnx/cast_all_constant_to_floating.cpp +++ b/torch/csrc/jit/passes/onnx/cast_all_constant_to_floating.cpp @@ -1,8 +1,7 @@ #include #include -namespace torch { -namespace jit { +namespace torch::jit { namespace onnx { using namespace ::c10::onnx; } @@ -70,5 +69,4 @@ void CastAllConstantToFloating(Block* block) { void CastAllConstantToFloating(const std::shared_ptr& graph) { CastAllConstantToFloating(graph->block()); } -} // namespace jit -} // namespace torch +} // namespace torch::jit diff --git a/torch/csrc/jit/passes/onnx/cast_all_constant_to_floating.h b/torch/csrc/jit/passes/onnx/cast_all_constant_to_floating.h index c9c40f2dd9a293..72321f0c27d706 100644 --- a/torch/csrc/jit/passes/onnx/cast_all_constant_to_floating.h +++ b/torch/csrc/jit/passes/onnx/cast_all_constant_to_floating.h @@ -4,9 +4,7 @@ #include -namespace torch { -namespace jit { +namespace torch::jit { // see .cpp for docs TORCH_API void CastAllConstantToFloating(const std::shared_ptr& graph); -} // namespace jit -} // namespace torch +} // namespace torch::jit diff --git a/torch/csrc/jit/passes/onnx/constant_fold.cpp b/torch/csrc/jit/passes/onnx/constant_fold.cpp index 61d97057c5b429..75f1ae4c349aaa 100644 --- a/torch/csrc/jit/passes/onnx/constant_fold.cpp +++ b/torch/csrc/jit/passes/onnx/constant_fold.cpp @@ -9,8 +9,7 @@ #include #include -namespace torch { -namespace jit { +namespace torch::jit { namespace onnx { using namespace ::c10::onnx; @@ -707,5 +706,4 @@ void ConstantFoldONNX( GRAPH_DUMP("After ConstantFoldONNX:", g); } -} // namespace jit -} // namespace torch +} // namespace torch::jit diff --git a/torch/csrc/jit/passes/onnx/constant_fold.h b/torch/csrc/jit/passes/onnx/constant_fold.h index d25ebee32a787e..899ae706ca8a2e 100644 --- a/torch/csrc/jit/passes/onnx/constant_fold.h +++ b/torch/csrc/jit/passes/onnx/constant_fold.h @@ -5,8 +5,7 @@ #include #include -namespace torch { -namespace jit { +namespace torch::jit { const int ONNX_OPSET_9 = 9; const int ONNX_OPSET_10 = 10; @@ -30,6 +29,4 @@ void ConstantFoldONNX( std::map& paramDict, int opset_version); -} // namespace jit - -} // namespace torch +} // namespace torch::jit diff --git a/torch/csrc/jit/passes/onnx/constant_map.cpp b/torch/csrc/jit/passes/onnx/constant_map.cpp index 99c801dcf77367..8b9a6273ef619d 100644 --- a/torch/csrc/jit/passes/onnx/constant_map.cpp +++ b/torch/csrc/jit/passes/onnx/constant_map.cpp @@ -7,12 +7,7 @@ #include #include -namespace torch { -namespace jit { - -namespace onnx { -using namespace ::c10::onnx; -} +namespace torch::jit { // Meyer’s Singleton for C++ 14 ConstantValueMap& ConstantValueMap::getInstance() { @@ -290,7 +285,7 @@ void ConstantValueMap::ClearMaps() { // For debug only. void ConstantValueMap::PrintMaps() { - std::cout << "Rank/Shape Map:" << std::endl; + std::cout << "Rank/Shape Map:" << '\n'; for (const auto& x : ConstantValueMap::getInstance().rankMap) { std::stringstream ss; if (ConstantValueMap::getInstance().shapeMap.find(x.first) != @@ -308,45 +303,45 @@ void ConstantValueMap::PrintMaps() { } } ss << " (rank = " << x.second << ")"; - std::cout << "node " << x.first << ": " << ss.str() << std::endl; + std::cout << "node " << x.first << ": " << ss.str() << '\n'; } - std::cout << std::endl; - std::cout << "Value Map:" << std::endl; + std::cout << '\n'; + std::cout << "Value Map:" << '\n'; for (const auto& x : ConstantValueMap::getInstance().tensorValueMap) { - std::cout << "node " << x.first << ": " << x.second << std::endl; + std::cout << "node " << x.first << ": " << x.second << '\n'; } - std::cout << std::endl; - std::cout << "TypeReliable Map:" << std::endl; + std::cout << '\n'; + std::cout << "TypeReliable Map:" << '\n'; size_t count = 0; for (const auto& x : ConstantValueMap::getInstance().typeReliableMap) { std::cout << "(node " << x.first << ": " << x.second << "), "; count++; if (count % 10 == 0) { - std::cout << std::endl; + std::cout << '\n'; } } - std::cout << std::endl; - std::cout << "UseInferredType Map:" << std::endl; + std::cout << '\n'; + std::cout << "UseInferredType Map:" << '\n'; count = 0; for (const auto& x : ConstantValueMap::getInstance().useInferredTypeMap) { std::cout << "(node " << x.first << ": " << x.second << "), "; count++; if (count % 10 == 0) { - std::cout << std::endl; + std::cout << '\n'; } } - std::cout << std::endl; - std::cout << "ShapeValue Map:" << std::endl; + std::cout << '\n'; + std::cout << "ShapeValue Map:" << '\n'; count = 0; for (const auto& x : ConstantValueMap::getInstance().shapeValueMap) { std::cout << "(node " << x.first << ": " << x.second << "), "; count++; if (count % 10 == 0) { - std::cout << std::endl; + std::cout << '\n'; } } - std::cout << std::endl; - std::cout << "InferredShape Map:" << std::endl; + std::cout << '\n'; + std::cout << "InferredShape Map:" << '\n'; count = 0; for (const auto& x : ConstantValueMap::getInstance().inferredShapeData) { std::cout << "(node " << x.first << ": "; @@ -360,29 +355,28 @@ void ConstantValueMap::PrintMaps() { std::cout << "), "; count++; if (count % 10 == 0) { - std::cout << std::endl; + std::cout << '\n'; } } - std::cout << std::endl; - std::cout << "SymbolDim Map:" << std::endl; + std::cout << '\n'; + std::cout << "SymbolDim Map:" << '\n'; count = 0; for (const auto& x : ConstantValueMap::getInstance().symbolDimMap) { std::cout << "(" << x.first << ": " << x.second << "), "; count++; if (count % 10 == 0) { - std::cout << std::endl; + std::cout << '\n'; } } - std::cout << "DimSymbol Map:" << std::endl; + std::cout << "DimSymbol Map:" << '\n'; count = 0; for (const auto& x : ConstantValueMap::getInstance().dimSymbolMap) { std::cout << "(" << x.first << ": " << x.second << "), "; count++; if (count % 10 == 0) { - std::cout << std::endl; + std::cout << '\n'; } } } -} // namespace jit -} // namespace torch +} // namespace torch::jit diff --git a/torch/csrc/jit/passes/onnx/constant_map.h b/torch/csrc/jit/passes/onnx/constant_map.h index 9bea729a37ad54..60d4470c1b12cf 100644 --- a/torch/csrc/jit/passes/onnx/constant_map.h +++ b/torch/csrc/jit/passes/onnx/constant_map.h @@ -13,8 +13,7 @@ C10_DIAGNOSTIC_POP() #include #include -namespace torch { -namespace jit { +namespace torch::jit { using ShapeDataMap = std::unordered_map; @@ -112,8 +111,7 @@ class ConstantValueMap { // Stores if all graph-level inputs have static shape std::optional allGraphInputsStatic; // True if reliable has been computed for all graph inputs - bool allGraphInputsReliableComputed; + bool allGraphInputsReliableComputed{}; }; -} // namespace jit -} // namespace torch +} // namespace torch::jit diff --git a/torch/csrc/jit/passes/onnx/deduplicate_initializers.cpp b/torch/csrc/jit/passes/onnx/deduplicate_initializers.cpp index cbb97c238e7618..009bb157737405 100644 --- a/torch/csrc/jit/passes/onnx/deduplicate_initializers.cpp +++ b/torch/csrc/jit/passes/onnx/deduplicate_initializers.cpp @@ -4,8 +4,7 @@ #include -namespace torch { -namespace jit { +namespace torch::jit { namespace onnx { using namespace ::c10::onnx; @@ -99,5 +98,4 @@ void DeduplicateInitializers( buildParamsMapFromValueToParamsMap(valsToParamsMap, paramsDict); } -} // namespace jit -} // namespace torch +} // namespace torch::jit diff --git a/torch/csrc/jit/passes/onnx/deduplicate_initializers.h b/torch/csrc/jit/passes/onnx/deduplicate_initializers.h index 60fbd8f089d4f8..f6da1601110994 100644 --- a/torch/csrc/jit/passes/onnx/deduplicate_initializers.h +++ b/torch/csrc/jit/passes/onnx/deduplicate_initializers.h @@ -4,14 +4,11 @@ #include -namespace torch { -namespace jit { +namespace torch::jit { void DeduplicateInitializers( std::shared_ptr& g, std::map& paramsDict, bool is_train); -} // namespace jit - -} // namespace torch +} // namespace torch::jit diff --git a/torch/csrc/jit/passes/onnx/eliminate_unused_items.cpp b/torch/csrc/jit/passes/onnx/eliminate_unused_items.cpp index b803ae35cf3ca5..50e5e4363528e0 100644 --- a/torch/csrc/jit/passes/onnx/eliminate_unused_items.cpp +++ b/torch/csrc/jit/passes/onnx/eliminate_unused_items.cpp @@ -1,8 +1,7 @@ #include #include -namespace torch { -namespace jit { +namespace torch::jit { namespace onnx { using namespace ::c10::onnx; @@ -16,5 +15,4 @@ void EliminateUnusedItemsONNX(Block* b, ParamMap& paramsDict) { return; } -} // namespace jit -} // namespace torch +} // namespace torch::jit diff --git a/torch/csrc/jit/passes/onnx/eliminate_unused_items.h b/torch/csrc/jit/passes/onnx/eliminate_unused_items.h index 74e658070edbd2..793a8c1041ff69 100644 --- a/torch/csrc/jit/passes/onnx/eliminate_unused_items.h +++ b/torch/csrc/jit/passes/onnx/eliminate_unused_items.h @@ -2,8 +2,7 @@ #include -namespace torch { -namespace jit { +namespace torch::jit { // EliminateUnusedItemsONNX pass is removing unused // initializers and inputs, this is needed because @@ -12,6 +11,4 @@ void EliminateUnusedItemsONNX( Block* b, std::map& paramDict); -} // namespace jit - -} // namespace torch +} // namespace torch::jit diff --git a/torch/csrc/jit/passes/onnx/eval_peephole.cpp b/torch/csrc/jit/passes/onnx/eval_peephole.cpp index d6be7039f53188..d25a215ecc8af6 100644 --- a/torch/csrc/jit/passes/onnx/eval_peephole.cpp +++ b/torch/csrc/jit/passes/onnx/eval_peephole.cpp @@ -6,8 +6,7 @@ #include #include -namespace torch { -namespace jit { +namespace torch::jit { namespace onnx { using namespace ::c10::onnx; @@ -152,5 +151,4 @@ void EvalPeepholeONNX(std::shared_ptr& g, ParamMap& paramsDict) { GRAPH_DUMP("After EvalPeepholeONNX:", g); } -} // namespace jit -} // namespace torch +} // namespace torch::jit diff --git a/torch/csrc/jit/passes/onnx/eval_peephole.h b/torch/csrc/jit/passes/onnx/eval_peephole.h index 6f8961d08fd5eb..1bd0bd4c24d225 100644 --- a/torch/csrc/jit/passes/onnx/eval_peephole.h +++ b/torch/csrc/jit/passes/onnx/eval_peephole.h @@ -4,13 +4,10 @@ #include -namespace torch { -namespace jit { +namespace torch::jit { void EvalPeepholeONNX( std::shared_ptr& g, std::map& paramDict); -} // namespace jit - -} // namespace torch +} // namespace torch::jit diff --git a/torch/csrc/jit/passes/onnx/fixup_onnx_controlflow.cpp b/torch/csrc/jit/passes/onnx/fixup_onnx_controlflow.cpp index 5407d9aa2986b1..75c44aead7caf3 100644 --- a/torch/csrc/jit/passes/onnx/fixup_onnx_controlflow.cpp +++ b/torch/csrc/jit/passes/onnx/fixup_onnx_controlflow.cpp @@ -8,19 +8,14 @@ #include #include -namespace torch { -namespace jit { - -namespace onnx { -using namespace ::c10::onnx; -} +namespace torch::jit { namespace { const int ONNX_OPSET_13 = 13; const int ONNX_TYPE_BOOL = 9; Node* CreateCastToBoolNode(Value* val, Graph* graph) { - Node* cast_node = graph->create(onnx::Cast); + Node* cast_node = graph->create(c10::onnx::Cast); cast_node->addInput(val); cast_node->i_(attr::to, ONNX_TYPE_BOOL); cast_node->output()->setType(BoolType::get()); @@ -149,7 +144,7 @@ std::vector ConvertSequenceDependencies(Node* node, int opset_version) { // Split the added scan_output back to expected tensor sequence. auto loop_output = loop_node->output(i - 2); Node* split_node = - loop_node->owningGraph()->create(onnx::SplitToSequence); + loop_node->owningGraph()->create(c10::onnx::SplitToSequence); loop_output->replaceAllUsesWith(split_node->output()); split_node->i_(attr::keepdims, 0); split_node->addInput(loop_output); @@ -191,7 +186,7 @@ std::vector ConvertSequenceDependencies(Node* node, int opset_version) { return new_outputs; } -Node* ONNXOptionalNode(OptionalTypePtr opt_type, Graph* g) { +Node* ONNXOptionalNode(const OptionalTypePtr& opt_type, Graph* g) { TORCH_INTERNAL_ASSERT(opt_type); TypePtr elem_type = opt_type->getElementType(); Node* opt_node = g->create(::c10::onnx::Optional, 1); @@ -208,7 +203,7 @@ Node* ONNXOptionalNode(OptionalTypePtr opt_type, Graph* g) { // 2. Loop Op: insert Optional node before output, if input is Optional type // or output type is None. void ReplaceBlockOutputWithOptional( - OptionalTypePtr opt_type, + const OptionalTypePtr& opt_type, Block* block, size_t i) { Node* opt_node = ONNXOptionalNode(opt_type, block->owningGraph()); @@ -235,9 +230,9 @@ void FixupONNXSubblockOutputs(Node* n) { // Identity(None). Also enables shape inference later on, since // ONNX shape inference doesn't handle None. if (output->type()->cast()) { - id_node = block->owningGraph()->create(onnx::Optional); + id_node = block->owningGraph()->create(c10::onnx::Optional); } else { - id_node = block->owningGraph()->create(onnx::Identity); + id_node = block->owningGraph()->create(c10::onnx::Identity); id_node->addInput(output); } id_node->insertBefore(block->return_node()); @@ -741,5 +736,4 @@ void FixupONNXControlflowNodeOutputs(Node* n) { } } -} // namespace jit -} // namespace torch +} // namespace torch::jit diff --git a/torch/csrc/jit/passes/onnx/fixup_onnx_controlflow.h b/torch/csrc/jit/passes/onnx/fixup_onnx_controlflow.h index 8d33c2dd1fb5e4..f7fc049de99814 100644 --- a/torch/csrc/jit/passes/onnx/fixup_onnx_controlflow.h +++ b/torch/csrc/jit/passes/onnx/fixup_onnx_controlflow.h @@ -2,11 +2,9 @@ #include -namespace torch { -namespace jit { +namespace torch::jit { std::vector FixupONNXControlflowNode(Node* n, int opset_version); void FixupONNXControlflowNodeOutputs(Node* n); -} // namespace jit -} // namespace torch +} // namespace torch::jit diff --git a/torch/csrc/jit/passes/onnx/function_extraction.cpp b/torch/csrc/jit/passes/onnx/function_extraction.cpp index febf412e5d1224..c988b9243669cc 100644 --- a/torch/csrc/jit/passes/onnx/function_extraction.cpp +++ b/torch/csrc/jit/passes/onnx/function_extraction.cpp @@ -2,9 +2,7 @@ #include #include -namespace torch { -namespace jit { -namespace onnx { +namespace torch::jit::onnx { namespace { @@ -75,9 +73,9 @@ struct FunctionExtractor { using FunctionCtxPtr = FunctionContext*; using func_ctx_map = std::unordered_map; - static bool IsValidScope(ScopePtr s); + static bool IsValidScope(const ScopePtr& s); static std::optional InferScope(Node* n); - static bool IsAncestor(ScopePtr parent, ScopePtr child); + static bool IsAncestor(const ScopePtr& parent, ScopePtr child); static std::optional FindCommonAncestor(ScopePtr a, ScopePtr b); static std::optional FindCommonAncestor(const scope_list& scopes); std::shared_ptr ConstructFuncGraph(FunctionContext& ctx); @@ -88,7 +86,9 @@ struct FunctionExtractor { scope_ctx_map& scope_ctxs, const std::shared_ptr& graph); - static void HandleNoScopeNodes(scope_ctx_map&, node_list no_scope_nlist); + static void HandleNoScopeNodes( + scope_ctx_map&, + const node_list& no_scope_nlist); std::tuple PartitionNodesByScope(Block* b); scope_ctx_map PartitionNodesByScope(const std::shared_ptr& graph); static std::unordered_map PartitionIdenticalScopes( @@ -279,11 +279,11 @@ void FunctionExtractor::DebugPrintGraphWithFunction( GRAPH_UPDATE("Main graph: ", g->toString()); } -bool FunctionExtractor::IsValidScope(ScopePtr s) { +bool FunctionExtractor::IsValidScope(const ScopePtr& s) { return !s->isRoot() && !s->isBlank(); } -bool FunctionExtractor::IsAncestor(ScopePtr parent, ScopePtr child) { +bool FunctionExtractor::IsAncestor(const ScopePtr& parent, ScopePtr child) { if (!IsValidScope(parent) || !IsValidScope(child) || parent->getDepth() >= child->getDepth()) { return false; @@ -376,7 +376,7 @@ std::optional FunctionExtractor::InferScope(Node* n) { std::all_of( output_scopes.begin(), output_scopes.end(), - [&output_scopes](ScopePtr scope) -> bool { + [&output_scopes](const ScopePtr& scope) -> bool { return IsValidScope(scope) && scope == output_scopes.at(0); })) { return output_scopes.at(0); @@ -385,7 +385,7 @@ std::optional FunctionExtractor::InferScope(Node* n) { std::all_of( input_scopes.begin(), input_scopes.end(), - [&input_scopes](ScopePtr scope) -> bool { + [&input_scopes](const ScopePtr& scope) -> bool { return IsValidScope(scope) && scope == input_scopes.at(0); })) { return input_scopes.at(0); @@ -822,7 +822,7 @@ void FunctionExtractor::ScopeContext::PopulateInputsOutputs( void FunctionExtractor::HandleNoScopeNodes( scope_ctx_map& scope_ctxs, - node_list no_scope_nlist) { + const node_list& no_scope_nlist) { GRAPH_UPDATE("No scope node count: ", no_scope_nlist.size()); for (auto n : no_scope_nlist) { TORCH_WARN( @@ -1181,6 +1181,4 @@ void ONNXTrackScopeAttributes( } } -} // namespace onnx -} // namespace jit -} // namespace torch +} // namespace torch::jit::onnx diff --git a/torch/csrc/jit/passes/onnx/function_extraction.h b/torch/csrc/jit/passes/onnx/function_extraction.h index 40555f8e3561ca..fea0d23e703010 100644 --- a/torch/csrc/jit/passes/onnx/function_extraction.h +++ b/torch/csrc/jit/passes/onnx/function_extraction.h @@ -2,9 +2,6 @@ #include -namespace torch { -namespace jit { - // This api will be used by serialization/export.cpp to extract function // information. It should do conversion on graph to // 1. Extract subgraph pattern of functions and define as local function @@ -15,7 +12,7 @@ namespace jit { // represent these info inside Graph object. // export.cpp will serialize the ONNX model with function_proto with // above information. -namespace onnx { +namespace torch::jit::onnx { // The following return types are used to track information regarding function // attributes, that are unable to be traced through Torch IR. @@ -64,7 +61,4 @@ TORCH_API void ONNXTrackScopeAttributes( std::shared_ptr& graph, std::map& attributes); -} // namespace onnx - -} // namespace jit -} // namespace torch +} // namespace torch::jit::onnx diff --git a/torch/csrc/jit/passes/onnx/function_substitution.cpp b/torch/csrc/jit/passes/onnx/function_substitution.cpp index 81bfa3fd6caf59..5a8e24b015da35 100644 --- a/torch/csrc/jit/passes/onnx/function_substitution.cpp +++ b/torch/csrc/jit/passes/onnx/function_substitution.cpp @@ -4,8 +4,7 @@ #include #include -namespace torch { -namespace jit { +namespace torch::jit { namespace { @@ -193,5 +192,4 @@ void ONNXFunctionCallSubstitution(Graph& graph) { GRAPH_DUMP("After function call substitution calls: ", &graph); } -} // namespace jit -} // namespace torch +} // namespace torch::jit diff --git a/torch/csrc/jit/passes/onnx/function_substitution.h b/torch/csrc/jit/passes/onnx/function_substitution.h index b07acef2c3fc66..3571bab936e2c4 100644 --- a/torch/csrc/jit/passes/onnx/function_substitution.h +++ b/torch/csrc/jit/passes/onnx/function_substitution.h @@ -2,10 +2,8 @@ #include -namespace torch { -namespace jit { +namespace torch::jit { TORCH_API void ONNXFunctionCallSubstitution(Graph& graph); } -} // namespace torch diff --git a/torch/csrc/jit/passes/onnx/helper.cpp b/torch/csrc/jit/passes/onnx/helper.cpp index 9d4c5061414c5a..64a4bb9bdfab36 100644 --- a/torch/csrc/jit/passes/onnx/helper.cpp +++ b/torch/csrc/jit/passes/onnx/helper.cpp @@ -12,8 +12,7 @@ #include -namespace torch { -namespace jit { +namespace torch::jit { namespace onnx { using namespace ::c10::onnx; @@ -296,5 +295,4 @@ void ONNXLintGraph(const std::shared_ptr& graph) { " constants."); } -} // namespace jit -} // namespace torch +} // namespace torch::jit diff --git a/torch/csrc/jit/passes/onnx/helper.h b/torch/csrc/jit/passes/onnx/helper.h index 9e09c638779ef5..09b31576998a5d 100644 --- a/torch/csrc/jit/passes/onnx/helper.h +++ b/torch/csrc/jit/passes/onnx/helper.h @@ -3,8 +3,7 @@ #include #include -namespace torch { -namespace jit { +namespace torch::jit { // Utility functions for PyTorch to ONNX conversion. @@ -73,5 +72,4 @@ class ScalarTypeHashFunction { } }; -} // namespace jit -} // namespace torch +} // namespace torch::jit diff --git a/torch/csrc/jit/passes/onnx/list_model_parameters.cpp b/torch/csrc/jit/passes/onnx/list_model_parameters.cpp index 6a1e3b08f3b9a8..268cc4dc7f9cbb 100644 --- a/torch/csrc/jit/passes/onnx/list_model_parameters.cpp +++ b/torch/csrc/jit/passes/onnx/list_model_parameters.cpp @@ -4,8 +4,7 @@ #include #include -namespace torch { -namespace jit { +namespace torch::jit { namespace onnx { using namespace ::c10::onnx; @@ -191,5 +190,4 @@ std::pair> list_module_parameters( return std::make_pair(moduleClone, parameterIValues); } -} // namespace jit -} // namespace torch +} // namespace torch::jit diff --git a/torch/csrc/jit/passes/onnx/list_model_parameters.h b/torch/csrc/jit/passes/onnx/list_model_parameters.h index 50d1cea2b8fe00..114b3b2d894139 100644 --- a/torch/csrc/jit/passes/onnx/list_model_parameters.h +++ b/torch/csrc/jit/passes/onnx/list_model_parameters.h @@ -3,11 +3,9 @@ #include #include -namespace torch { -namespace jit { +namespace torch::jit { TORCH_API std::pair> list_module_parameters( const Module& module); -} // namespace jit -} // namespace torch +} // namespace torch::jit diff --git a/torch/csrc/jit/passes/onnx/naming.cpp b/torch/csrc/jit/passes/onnx/naming.cpp index ed62c79231899d..62fd67e1d2ca47 100644 --- a/torch/csrc/jit/passes/onnx/naming.cpp +++ b/torch/csrc/jit/passes/onnx/naming.cpp @@ -3,9 +3,7 @@ #include -namespace torch { -namespace jit { -namespace onnx { +namespace torch::jit::onnx { namespace ONNXScopeName { @@ -16,7 +14,7 @@ const std::string name_separator = "::"; namespace { std::string nameFromRoot( - torch::jit::ScopePtr scope, + const torch::jit::ScopePtr& scope, const std::string& layer_separator, NameFunc name_func) { std::string out = (*name_func)(scope); @@ -32,7 +30,7 @@ std::string nameFromRoot( } std::pair parseNameFromScope( - torch::jit::ScopePtr scope) { + const torch::jit::ScopePtr& scope) { std::string full_name = scope->name().toUnqualString(); auto pos = full_name.find(name_separator); TORCH_CHECK( @@ -55,7 +53,7 @@ std::string variableName(torch::jit::ScopePtr scope) { } std::string variableNameFromRoot( - torch::jit::ScopePtr scope, + const torch::jit::ScopePtr& scope, const std::string& layer_separator) { return nameFromRoot(scope, layer_separator, &variableName); } @@ -65,12 +63,12 @@ std::string className(torch::jit::ScopePtr scope) { } std::string classNameFromRoot( - torch::jit::ScopePtr scope, + const torch::jit::ScopePtr& scope, const std::string& layer_separator) { return nameFromRoot(scope, layer_separator, &className); } -bool isCompatibleScope(torch::jit::ScopePtr scope) { +bool isCompatibleScope(const torch::jit::ScopePtr& scope) { return !scope->isRoot() && !scope->isBlank() && (std::string(scope->name().toUnqualString()).find(name_separator) != std::string::npos); @@ -89,7 +87,7 @@ class NodeNameGenerator { virtual void CreateNodeName(Node* n) = 0; void PopulateNodeNames(Block*); void UpdateOutputsNames(Node* n); - bool IsGraphOutput(const Value* v, const std::shared_ptr graph) const; + bool IsGraphOutput(const Value* v, const std::shared_ptr& graph) const; protected: std::string CreateUniqueName( @@ -99,19 +97,21 @@ class NodeNameGenerator { std::unordered_map node_names_; std::unordered_map base_node_name_counts_; std::shared_ptr graph_; + // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members) const std::string layer_separator_ = "/"; }; NodeNameGenerator::~NodeNameGenerator() = default; class ScopedNodeNameGenerator : public NodeNameGenerator { public: - ScopedNodeNameGenerator(std::shared_ptr g) : NodeNameGenerator(g){}; + ScopedNodeNameGenerator(std::shared_ptr g) + : NodeNameGenerator(std::move(g)){}; protected: void CreateNodeName(Node* n) override; private: - std::string GetFullScopeName(ScopePtr scope); + std::string GetFullScopeName(const ScopePtr& scope); std::unordered_map full_scope_names_; std::unordered_map base_scope_name_counts_; }; @@ -131,7 +131,7 @@ std::string NodeNameGenerator::CreateUniqueName( bool NodeNameGenerator::IsGraphOutput( const Value* v, - const std::shared_ptr graph) const { + const std::shared_ptr& graph) const { for (const auto* graph_output : graph->outputs()) { if (v == graph_output) { return true; @@ -185,7 +185,7 @@ void ScopedNodeNameGenerator::CreateNodeName(Node* n) { n->s_(Symbol::attr(::torch::onnx::kOnnxNodeNameAttribute), node_names_[n]); } -std::string ScopedNodeNameGenerator::GetFullScopeName(ScopePtr scope) { +std::string ScopedNodeNameGenerator::GetFullScopeName(const ScopePtr& scope) { if (full_scope_names_.find(scope) == full_scope_names_.end()) { auto full_scope_name = ONNXScopeName::variableNameFromRoot(scope, layer_separator_); @@ -202,6 +202,4 @@ void AssignScopedNamesForNodeAndValue(std::shared_ptr& graph) { node_name_generator->PopulateNodeNames(); } -} // namespace onnx -} // namespace jit -} // namespace torch +} // namespace torch::jit::onnx diff --git a/torch/csrc/jit/passes/onnx/naming.h b/torch/csrc/jit/passes/onnx/naming.h index 0472d041a92331..905d47bf541b4e 100644 --- a/torch/csrc/jit/passes/onnx/naming.h +++ b/torch/csrc/jit/passes/onnx/naming.h @@ -2,9 +2,7 @@ #include -namespace torch { -namespace jit { -namespace onnx { +namespace torch::jit::onnx { namespace ONNXScopeName { @@ -13,18 +11,16 @@ std::string createFullScopeName( const std::string& variable_name); std::string variableName(torch::jit::ScopePtr scope); std::string variableNameFromRoot( - torch::jit::ScopePtr scope, + const torch::jit::ScopePtr& scope, const std::string& layer_separator); std::string className(torch::jit::ScopePtr scope); std::string classNameFromRoot( - torch::jit::ScopePtr scope, + const torch::jit::ScopePtr& scope, const std::string& layer_separator); -bool isCompatibleScope(torch::jit::ScopePtr scope); +bool isCompatibleScope(const torch::jit::ScopePtr& scope); } // namespace ONNXScopeName TORCH_API void AssignScopedNamesForNodeAndValue(std::shared_ptr& graph); -} // namespace onnx -} // namespace jit -} // namespace torch +} // namespace torch::jit::onnx diff --git a/torch/csrc/jit/passes/onnx/onnx_log.cpp b/torch/csrc/jit/passes/onnx/onnx_log.cpp index eff2ae4d5a3231..f749690145601f 100644 --- a/torch/csrc/jit/passes/onnx/onnx_log.cpp +++ b/torch/csrc/jit/passes/onnx/onnx_log.cpp @@ -1,9 +1,7 @@ #include #include -namespace torch { -namespace jit { -namespace onnx { +namespace torch::jit::onnx { namespace { bool log_enabled = false; @@ -26,6 +24,4 @@ std::ostream& _get_log_output_stream() { return out ? *out : std::cout; } -} // namespace onnx -} // namespace jit -} // namespace torch +} // namespace torch::jit::onnx diff --git a/torch/csrc/jit/passes/onnx/onnx_log.h b/torch/csrc/jit/passes/onnx/onnx_log.h index a659a122342760..b3343df4c6e387 100644 --- a/torch/csrc/jit/passes/onnx/onnx_log.h +++ b/torch/csrc/jit/passes/onnx/onnx_log.h @@ -4,9 +4,7 @@ #include #include -namespace torch { -namespace jit { -namespace onnx { +namespace torch::jit::onnx { TORCH_API bool is_log_enabled(); @@ -22,6 +20,4 @@ TORCH_API std::ostream& _get_log_output_stream(); << ::c10::str(__VA_ARGS__) << std::endl; \ } -} // namespace onnx -} // namespace jit -} // namespace torch +} // namespace torch::jit::onnx diff --git a/torch/csrc/jit/passes/onnx/peephole.cpp b/torch/csrc/jit/passes/onnx/peephole.cpp index 9ccea955aee1b7..407a8c79dfb7fe 100644 --- a/torch/csrc/jit/passes/onnx/peephole.cpp +++ b/torch/csrc/jit/passes/onnx/peephole.cpp @@ -23,8 +23,7 @@ typedef SSIZE_T ssize_t; #endif -namespace torch { -namespace jit { +namespace torch::jit { namespace onnx { using namespace ::c10::onnx; @@ -1068,5 +1067,4 @@ void PeepholeOptimizeONNX( GRAPH_DUMP("After PeepholeOptimizeONNX", graph); } -} // namespace jit -} // namespace torch +} // namespace torch::jit diff --git a/torch/csrc/jit/passes/onnx/peephole.h b/torch/csrc/jit/passes/onnx/peephole.h index 7d23267310ab85..3a3819974d5625 100644 --- a/torch/csrc/jit/passes/onnx/peephole.h +++ b/torch/csrc/jit/passes/onnx/peephole.h @@ -2,13 +2,11 @@ #include -namespace torch { -namespace jit { +namespace torch::jit { void PeepholeOptimizeONNX( std::shared_ptr& graph, int opset_version, bool fixed_batch_size); -} // namespace jit -} // namespace torch +} // namespace torch::jit diff --git a/torch/csrc/jit/passes/onnx/prepare_division_for_onnx.cpp b/torch/csrc/jit/passes/onnx/prepare_division_for_onnx.cpp index 63ad0040237e25..f7c0e0a7cac0c0 100644 --- a/torch/csrc/jit/passes/onnx/prepare_division_for_onnx.cpp +++ b/torch/csrc/jit/passes/onnx/prepare_division_for_onnx.cpp @@ -3,8 +3,7 @@ #include #include -namespace torch { -namespace jit { +namespace torch::jit { // onnx only supports tensors, but 1 / 2 = 0.5 and tensor(1) / tensor(2) = 0, // so before converting the ints to tensors we need to cast them to floats. @@ -43,5 +42,4 @@ void PrepareDivisionForONNX(const std::shared_ptr& graph) { GRAPH_DUMP("After PrepareDivisionForONNX: ", graph); } -} // namespace jit -} // namespace torch +} // namespace torch::jit diff --git a/torch/csrc/jit/passes/onnx/prepare_division_for_onnx.h b/torch/csrc/jit/passes/onnx/prepare_division_for_onnx.h index 1bfd8eaef311d2..b9e25861b4778c 100644 --- a/torch/csrc/jit/passes/onnx/prepare_division_for_onnx.h +++ b/torch/csrc/jit/passes/onnx/prepare_division_for_onnx.h @@ -2,8 +2,7 @@ #include -namespace torch { -namespace jit { +namespace torch::jit { // Prepare division ops for ONNX export. This is necessary for and only used // by ONNX export. @@ -15,5 +14,4 @@ namespace jit { // TORCH_API void PrepareDivisionForONNX(const std::shared_ptr& graph); -} // namespace jit -} // namespace torch +} // namespace torch::jit diff --git a/torch/csrc/jit/passes/onnx/preprocess_for_onnx.cpp b/torch/csrc/jit/passes/onnx/preprocess_for_onnx.cpp index 89c3a42a9c0814..5f35a85b2aa891 100644 --- a/torch/csrc/jit/passes/onnx/preprocess_for_onnx.cpp +++ b/torch/csrc/jit/passes/onnx/preprocess_for_onnx.cpp @@ -6,8 +6,7 @@ #include #include -namespace torch { -namespace jit { +namespace torch::jit { namespace onnx { using namespace ::c10::onnx; @@ -229,5 +228,4 @@ void PreprocessForONNX(std::shared_ptr& graph) { GRAPH_DUMP("After fuseListAndListUnpack: ", graph); } -} // namespace jit -} // namespace torch +} // namespace torch::jit diff --git a/torch/csrc/jit/passes/onnx/preprocess_for_onnx.h b/torch/csrc/jit/passes/onnx/preprocess_for_onnx.h index dc180a59c1b6b7..541e4339768e23 100644 --- a/torch/csrc/jit/passes/onnx/preprocess_for_onnx.h +++ b/torch/csrc/jit/passes/onnx/preprocess_for_onnx.h @@ -2,10 +2,8 @@ #include -namespace torch { -namespace jit { +namespace torch::jit { void PreprocessForONNX(std::shared_ptr& graph); -} // namespace jit -} // namespace torch +} // namespace torch::jit diff --git a/torch/csrc/jit/passes/onnx/remove_inplace_ops_for_onnx.cpp b/torch/csrc/jit/passes/onnx/remove_inplace_ops_for_onnx.cpp index 8e0258a3049e02..4fa3068c3d1e46 100644 --- a/torch/csrc/jit/passes/onnx/remove_inplace_ops_for_onnx.cpp +++ b/torch/csrc/jit/passes/onnx/remove_inplace_ops_for_onnx.cpp @@ -13,8 +13,7 @@ #include -namespace torch { -namespace jit { +namespace torch::jit { namespace { @@ -368,7 +367,7 @@ static void PrepareForRemoveMutations(MutationRemover& mr, Block* b) { } } -static void PrepareForRemoveMutations(std::shared_ptr graph) { +static void PrepareForRemoveMutations(const std::shared_ptr& graph) { MutationRemover mr(graph); PrepareForRemoveMutations(mr, graph->block()); GRAPH_DUMP("After PrepareForRemoveMutations: ", graph); @@ -438,23 +437,23 @@ std::string InplaceConverter::ValueTracker::toString() const { // ss << "Current graph: " << graph_->toString() << std::endl; ss << "Tracking " << value_to_sorted_aliases_.size() << " individual values." - << std::endl; - ss << "value_to_sorted_aliases_: " << std::endl; + << '\n'; + ss << "value_to_sorted_aliases_: " << '\n'; size_t idx = 0; for (const auto& it : value_to_sorted_aliases_) { - ss << "Value[" << idx << "]: " << it.first->debugName() << std::endl; + ss << "Value[" << idx << "]: " << it.first->debugName() << '\n'; ss << " Mapping to "; for (auto v : it.second) { ss << v->debugName() << " "; } - ss << std::endl; + ss << '\n'; idx++; } - ss << "alias_to_value_: " << std::endl; + ss << "alias_to_value_: " << '\n'; for (auto it : alias_to_value_) { ss << " Alias " << it.first->debugName(); - ss << " map to " << it.second->debugName() << std::endl; + ss << " map to " << it.second->debugName() << '\n'; } return ss.str(); @@ -890,5 +889,4 @@ void RemoveInplaceOpsForONNX( ic.convertMutationForONNX(); } -} // namespace jit -} // namespace torch +} // namespace torch::jit diff --git a/torch/csrc/jit/passes/onnx/remove_inplace_ops_for_onnx.h b/torch/csrc/jit/passes/onnx/remove_inplace_ops_for_onnx.h index 6afd28268bdeb9..996c7f80a6c1f6 100644 --- a/torch/csrc/jit/passes/onnx/remove_inplace_ops_for_onnx.h +++ b/torch/csrc/jit/passes/onnx/remove_inplace_ops_for_onnx.h @@ -2,12 +2,10 @@ #include -namespace torch { -namespace jit { +namespace torch::jit { TORCH_API void RemoveInplaceOpsForONNX( const std::shared_ptr& graph, Module* model); -} // namespace jit -} // namespace torch +} // namespace torch::jit diff --git a/torch/csrc/jit/passes/onnx/scalar_type_analysis.cpp b/torch/csrc/jit/passes/onnx/scalar_type_analysis.cpp index 009566499275b4..5369dedac9762b 100644 --- a/torch/csrc/jit/passes/onnx/scalar_type_analysis.cpp +++ b/torch/csrc/jit/passes/onnx/scalar_type_analysis.cpp @@ -4,8 +4,7 @@ #include #include -namespace torch { -namespace jit { +namespace torch::jit { namespace onnx { using namespace ::c10::onnx; @@ -479,5 +478,4 @@ void ScalarTypeAnalysisNodeForONNX(Node* n) { ImplicitCastNodeForONNX(n); } -} // namespace jit -} // namespace torch +} // namespace torch::jit diff --git a/torch/csrc/jit/passes/onnx/scalar_type_analysis.h b/torch/csrc/jit/passes/onnx/scalar_type_analysis.h index 90a433a3ce47ed..8c5051220e84f6 100644 --- a/torch/csrc/jit/passes/onnx/scalar_type_analysis.h +++ b/torch/csrc/jit/passes/onnx/scalar_type_analysis.h @@ -2,8 +2,7 @@ #include -namespace torch { -namespace jit { +namespace torch::jit { TORCH_API void ScalarTypeAnalysisForONNX( const std::shared_ptr& graph, @@ -11,5 +10,4 @@ TORCH_API void ScalarTypeAnalysisForONNX( int opset_version); void ScalarTypeAnalysisNodeForONNX(Node* n); -} // namespace jit -} // namespace torch +} // namespace torch::jit diff --git a/torch/csrc/jit/passes/onnx/shape_type_inference.cpp b/torch/csrc/jit/passes/onnx/shape_type_inference.cpp index 294d0682246e9f..9a6dcd9bd761dd 100644 --- a/torch/csrc/jit/passes/onnx/shape_type_inference.cpp +++ b/torch/csrc/jit/passes/onnx/shape_type_inference.cpp @@ -22,16 +22,15 @@ #include #include -namespace torch { -namespace jit { +namespace torch::jit { inline bool PyNone_Check(PyObject* o) { return o == Py_None; } std::pair MergeInferredType( - TypePtr existing_type, - TypePtr inferred_type) { + const TypePtr& existing_type, + const TypePtr& inferred_type) { auto new_list_type = inferred_type->cast(); auto use_inferred_type = false; if (new_list_type) { @@ -75,8 +74,8 @@ std::pair MergeInferredType( void MergeInferredTypeAndSetMap( Value* dest_v, - TypePtr existing_type, - TypePtr inferred_type) { + const TypePtr& existing_type, + const TypePtr& inferred_type) { auto [mergedType, inferred] = MergeInferredType(existing_type, inferred_type); dest_v->setType(mergedType); ConstantValueMap::SetUseInferredType(dest_v->debugName(), inferred); @@ -256,7 +255,7 @@ bool CustomSettype(Node* node) { Value* CloneValueFromListConstruct( Value* v, - std::shared_ptr n_graph, + const std::shared_ptr& n_graph, int opset_version) { auto lc_node = v->node(); TORCH_INTERNAL_ASSERT(lc_node->kind() == ::c10::prim::ListConstruct); @@ -355,7 +354,7 @@ Node* CloneNodeToGraph( return clone_node; } -bool HasValidType(TypePtr type, std::string name) { +bool HasValidType(const TypePtr& type, const std::string& name) { if (auto t_type = type->cast()) { if (!t_type->scalarType().has_value()) { GRAPH_UPDATE("Input ", name, " is missing tensor datatype."); @@ -371,7 +370,7 @@ bool HasValidType(TypePtr type, std::string name) { return true; } -bool IsGraphValidForInference(std::shared_ptr graph) { +bool IsGraphValidForInference(const std::shared_ptr& graph) { // Verify if every input has type (either Tensor, Sequence or Optional) and // scalar type. This is a requirement for ONNX graph inputs. for (auto in : graph->inputs()) { @@ -381,7 +380,7 @@ bool IsGraphValidForInference(std::shared_ptr graph) { } void ConvertGraphToONNXProto( - std::shared_ptr graph, + const std::shared_ptr& graph, std::shared_ptr& model_proto, SymbolDimMap& symbol_dim_map, DimSymbolMap& dim_symbol_map, @@ -1652,7 +1651,8 @@ void SpecialPostProcess(Node* n) { auto seq_node = n->input(0)->node(); auto t_type = n->input(1)->type()->cast(); - auto update_sequence_empty_dtype = [](Node* n, TensorTypePtr t_type) { + auto update_sequence_empty_dtype = [](Node* n, + const TensorTypePtr& t_type) { TORCH_INTERNAL_ASSERT(n && n->kind() == ::c10::onnx::SequenceEmpty); TORCH_INTERNAL_ASSERT(t_type && t_type->scalarType().has_value()); auto scalar_type = t_type->scalarType().value(); @@ -1711,7 +1711,7 @@ void SpecialPostProcess(Node* n) { return nullptr; }; return find_sequence_empty_impl( - input, t_type, find_sequence_empty_impl); + input, std::move(t_type), find_sequence_empty_impl); }; if (seq_node && t_type && t_type->scalarType()) { @@ -2122,9 +2122,9 @@ void ONNXShapeTypeInference( case ::c10::onnx::Gather: { auto* schema_registry = onnx::OpSchemaRegistry::Instance(); onnx::ShapeInferenceOptions options{ - /*check_type=*/false, - /*error_mode=*/false, - /*enable_data_propagation=*/true}; + /*check_type_val=*/false, + /*strict_mode_val=*/0, + /*data_prop_val=*/true}; onnx::shape_inference::InferShapes( *model_proto, schema_registry, options, &inferred_shape_data); break; @@ -2509,5 +2509,4 @@ void UpdateShapeConstantIfReliable(torch::jit::Value* node_output) { } } -} // namespace jit -} // namespace torch +} // namespace torch::jit diff --git a/torch/csrc/jit/passes/onnx/shape_type_inference.h b/torch/csrc/jit/passes/onnx/shape_type_inference.h index 685ca39c16dec1..bca534654febb8 100644 --- a/torch/csrc/jit/passes/onnx/shape_type_inference.h +++ b/torch/csrc/jit/passes/onnx/shape_type_inference.h @@ -6,8 +6,7 @@ #include -namespace torch { -namespace jit { +namespace torch::jit { // Merges existing_type and inferred_type. // Returns {merged type, whether or not inferred_type was used}. @@ -28,13 +27,13 @@ namespace jit { // ONNX represents list of scalars by 1-d Tensor. Return inferred type since // it is more compatible with ONNX. std::pair MergeInferredType( - TypePtr existing_type, - TypePtr inferred_type); + const TypePtr& existing_type, + const TypePtr& inferred_type); void MergeInferredTypeAndSetMap( Value* dest_v, - TypePtr existing_type, - TypePtr inferred_type); + const TypePtr& existing_type, + const TypePtr& inferred_type); // Update graph input types with dynamic axes info. // Axes that are marked as dynamic will be assigned as dynamic ShapeSymbol. @@ -96,5 +95,4 @@ void UpdateReliable( void UpdateReliable(torch::jit::Node* n); void UpdateShapeConstantIfReliable(torch::jit::Value* output); -} // namespace jit -} // namespace torch +} // namespace torch::jit diff --git a/torch/csrc/jit/passes/onnx/unpack_quantized_weights.cpp b/torch/csrc/jit/passes/onnx/unpack_quantized_weights.cpp index e34e06fbcbfaf9..7a4f95ec69763a 100644 --- a/torch/csrc/jit/passes/onnx/unpack_quantized_weights.cpp +++ b/torch/csrc/jit/passes/onnx/unpack_quantized_weights.cpp @@ -14,8 +14,8 @@ #include using ::c10::Dispatcher; -namespace torch { -namespace jit { + +namespace torch::jit { namespace onnx { using namespace ::c10::onnx; @@ -765,5 +765,4 @@ void insertPermutes( GRAPH_DUMP("After insertPermutes: ", graph); } -} // namespace jit -} // namespace torch +} // namespace torch::jit diff --git a/torch/csrc/jit/passes/onnx/unpack_quantized_weights.h b/torch/csrc/jit/passes/onnx/unpack_quantized_weights.h index aa9b7af63478c0..70a99b4ef1859b 100644 --- a/torch/csrc/jit/passes/onnx/unpack_quantized_weights.h +++ b/torch/csrc/jit/passes/onnx/unpack_quantized_weights.h @@ -6,8 +6,7 @@ #include -namespace torch { -namespace jit { +namespace torch::jit { TORCH_API void UnpackQuantizedWeights( std::shared_ptr& graph, @@ -15,5 +14,4 @@ TORCH_API void UnpackQuantizedWeights( TORCH_API void insertPermutes( std::shared_ptr& graph, std::map& paramsDict); -} // namespace jit -} // namespace torch +} // namespace torch::jit From b0ed5df855d7d81469271c67e065cdf2bc53094d Mon Sep 17 00:00:00 2001 From: Jason Ansel Date: Sat, 7 Sep 2024 19:54:40 -0700 Subject: [PATCH 0612/1018] [inductor] Refactor LoopBody.memory_usage (#135286) This is preparing for some other changes where I speed up extract_read_writes tracing. Pull Request resolved: https://github.com/pytorch/pytorch/pull/135286 Approved by: https://github.com/oulgen --- torch/_inductor/codegen/cpp.py | 8 +-- torch/_inductor/ir.py | 4 +- torch/_inductor/loop_body.py | 125 +++++++++++++++++++++++---------- 3 files changed, 95 insertions(+), 42 deletions(-) diff --git a/torch/_inductor/codegen/cpp.py b/torch/_inductor/codegen/cpp.py index 6d7f38777b3bed..418bb43836f2f4 100644 --- a/torch/_inductor/codegen/cpp.py +++ b/torch/_inductor/codegen/cpp.py @@ -4346,9 +4346,9 @@ def is_all_write_read_contiguous(): ): contiguous_index_expr += stride * var stride *= range - write_index_expr = scheduler_node._body.writes_name2expr[ + write_index_expr = scheduler_node._body.get_write_expr( scheduler_buffer.get_name() - ] + ) def is_contiguous_index(x): return x == contiguous_index_expr @@ -4356,9 +4356,9 @@ def is_contiguous_index(x): return is_contiguous_index(write_index_expr) and all( isinstance(user.node, SchedulerNode) and is_contiguous_index( - user.node._body.reads_name2expr[ + user.node._body.get_read_expr( scheduler_buffer.get_name() - ], + ), ) for user in scheduler_buffer.users ) diff --git a/torch/_inductor/ir.py b/torch/_inductor/ir.py index 6d09d6592a0e13..ef8ce7b36bbc30 100644 --- a/torch/_inductor/ir.py +++ b/torch/_inductor/ir.py @@ -3773,9 +3773,9 @@ def simplify_and_reorder( ] index_formulas += extra_indexing_expr - memory_addrs = [*body.writes_name2expr.values()] + memory_addrs = [*body.get_write_exprs()] if not V.graph.has_feature(self, BackendFeature.PREFER_STORE_LOOP_ORDER): - memory_addrs.extend(body.reads_name2expr.values()) + memory_addrs.extend(body.get_read_exprs()) def simplify_and_reorder(x_vars, support_vars, sizes, simplify_loops): sizes, reindex0, reindex1 = self._apply_loop_reordering( diff --git a/torch/_inductor/loop_body.py b/torch/_inductor/loop_body.py index 78f043d14ac1d1..046ef11b61e307 100644 --- a/torch/_inductor/loop_body.py +++ b/torch/_inductor/loop_body.py @@ -4,7 +4,8 @@ import functools import itertools import re -from typing import Any, Callable, Dict, List, Sequence, Tuple +from enum import auto, Enum +from typing import Any, Callable, Dict, List, NamedTuple, Optional, Sequence, Tuple import sympy @@ -44,6 +45,23 @@ def run(self, *args, **kwargs): return super().run(*args, **kwargs) +class MemoryEntry(NamedTuple): + index_name: str # LoopBody.indexing_exprs[index_name] + buffer_name: Optional[str] + mode: Optional[str] # V.ops.store(..., mode=mode) + + +class MemoryUsageType(Enum): + # These are 1:1 with the opcode generating the usage + LOAD = auto() + LOAD_SEED = auto() + STORE = auto() + STORE_REDUCTION = auto() + INDEX_EXPR = auto() + CHECK_BOUNDS = auto() + BUCKETIZE = auto() + + class LoopBody: """ Captures the body of a Loops subclass into an FX graph. Persists any @@ -52,13 +70,12 @@ class LoopBody: indexing_exprs: Dict[str, sympy.Expr] indexing_exprs_name: Dict[sympy.Expr, str] - reads_name2expr: Dict[str, sympy.Expr] - writes_name2expr: Dict[str, sympy.Expr] submodules: Dict[str, Any] subblocks: Dict[str, LoopBodyBlock] indirect_vars: List[str] indirect_var_ranges: Dict[sympy.Symbol, sympy.Expr] root_block: LoopBodyBlock + memory_usage: Dict[MemoryUsageType, List[MemoryEntry]] def __init__(self, fn, args, var_ranges, iter_vars, reduce_vars): super().__init__() @@ -84,12 +101,11 @@ def _init_with_tracing(self, fn, args): """Do an FX trace of an arbitrary callable to construct self""" self.indexing_exprs = {} self.indexing_exprs_name = {} - self.reads_name2expr = {} - self.writes_name2expr = {} self.submodules = {"get_index": self.get_index} self.subblocks = {} self.indirect_vars = [] self.indirect_var_ranges: Dict[sympy.Symbol, sympy.Expr] = {} + self.memory_usage = {t: [] for t in MemoryUsageType} self.root_block = LoopBodyBlock(self, fn, args) # traces del self.indexing_exprs_name # not used after _init_with_tracing @@ -99,24 +115,11 @@ def _init_with_copy(self, other: LoopBody, args): where we are just reordering/merging/splitting the args of an existing LoopBody. """ - indexing_exprs = other.indexing_from_args(args) - update_index = { - expr: indexing_exprs[name] for name, expr in other.indexing_exprs.items() - } - assert indexing_exprs.keys() == other.indexing_exprs.keys() and len( - update_index - ) == len(indexing_exprs) - - self.indexing_exprs = indexing_exprs - self.reads_name2expr = { - k: update_index[v] for k, v in other.reads_name2expr.items() - } - self.writes_name2expr = { - k: update_index[v] for k, v in other.writes_name2expr.items() - } + self.indexing_exprs = other.indexing_from_args(args) self.subblocks = {k: v.clone(self) for k, v in other.subblocks.items()} - self.indirect_vars = [*other.indirect_vars] - self.indirect_var_ranges = {**other.indirect_var_ranges} + self.indirect_vars = other.indirect_vars + self.indirect_var_ranges = other.indirect_var_ranges + self.memory_usage = other.memory_usage self.root_block = other.root_block.clone(self) submodules = {**other.submodules} @@ -249,6 +252,37 @@ def bounds(self): return BoundVars(self) + def get_read_expr(self, buffer_name): + # reversed to match old behavior + for entry in reversed(self.memory_usage[MemoryUsageType.LOAD]): + if entry.buffer_name == buffer_name: + return self.indexing_exprs[entry.index_name] + raise KeyError(buffer_name) + + def get_write_expr(self, buffer_name): + for entry in itertools.chain( + self.memory_usage[MemoryUsageType.STORE], + self.memory_usage[MemoryUsageType.STORE_REDUCTION], + ): + if entry.buffer_name == buffer_name: + return self.indexing_exprs[entry.index_name] + raise KeyError(buffer_name) + + def get_read_exprs(self): + return [ + self.indexing_exprs[entry.index_name] + for entry in self.memory_usage[MemoryUsageType.LOAD] + ] + + def get_write_exprs(self): + return [ + self.indexing_exprs[entry.index_name] + for entry in itertools.chain( + self.memory_usage[MemoryUsageType.STORE], + self.memory_usage[MemoryUsageType.STORE_REDUCTION], + ) + ] + def debug_str(self): lines = [f"var_ranges = {dict(self.var_ranges)}"] lines.extend([f"{name} = {val}" for name, val in self.indexing_exprs.items()]) @@ -264,14 +298,20 @@ def debug_str(self): __repr__ = debug_str - def add_index_expr(self, expr: sympy.Expr, category, buf_name): - if buf_name is not None: - getattr(self, f"{category}_name2expr")[buf_name] = expr - if expr not in self.indexing_exprs_name: + def add_index_expr( + self, + expr: sympy.Expr, + mtype: MemoryUsageType, + buffer_name: Optional[str] = None, + mode: Optional[str] = None, + ): + name = self.indexing_exprs_name.get(expr) + if not name: name = f"index{len(self.indexing_exprs)}" self.indexing_exprs_name[expr] = name self.indexing_exprs[name] = expr - return self.indexing_exprs_name[expr] + self.memory_usage[mtype].append(MemoryEntry(name, buffer_name, mode)) + return name def add_submodule(self, block, prefix): """Not actually for nn.Modules, but subblocks in generated code are mapped to FX call_module opcodes""" @@ -359,11 +399,11 @@ class LoopBodyBlock: def __init__(self, body: LoopBody, fn: Callable[..., Any], args: List[Any]): self.body = body - def add_index(expr, category, buf_name=None): + def add_index(expr: sympy.Expr, mtype: MemoryUsageType, **kwargs): return tracer.create_proxy( "call_module", "get_index", - (self.body.add_index_expr(expr, category, buf_name),), + (body.add_index_expr(expr, mtype, **kwargs),), {}, ) @@ -371,15 +411,26 @@ class CaptureIndexing(V.WrapperHandler): # type: ignore[name-defined] self.name = "CaptureIndexing" def load(self, name: str, index: sympy.Expr): - index = add_index(index, "reads", name) + index = add_index(index, MemoryUsageType.LOAD, buffer_name=name) return self._inner.load(name, index) + def load_seed(self, name: str, index: int): + assert isinstance(index, int) + body.add_index_expr( + sympy.Integer(index), MemoryUsageType.LOAD_SEED, buffer_name=name + ) + return self._inner.load_seed(name, index) + def store(self, name, index, value, mode=None): - index = add_index(index, "writes", name) + index = add_index( + index, MemoryUsageType.STORE, buffer_name=name, mode=mode + ) return self._inner.store(name, index, value, mode) def store_reduction(self, name, index, value): - index = add_index(index, "writes", name) + index = add_index( + index, MemoryUsageType.STORE_REDUCTION, buffer_name=name + ) return self._inner.store_reduction(name, index, value) def reduction(self, dtype, src_dtype, reduction_type, value): @@ -391,12 +442,12 @@ def reduction(self, dtype, src_dtype, reduction_type, value): def index_expr(self, index, dtype): if isinstance(index, (int, sympy.Integer)): return self._inner.constant(int(index), dtype) - index = add_index(index, "other") + index = add_index(index, MemoryUsageType.INDEX_EXPR) return self._inner.index_expr(index, dtype) def check_bounds(self, index, size, lower, upper): - index = add_index(index, "other") - size = add_index(size, "other") + index = add_index(index, MemoryUsageType.CHECK_BOUNDS) + size = add_index(size, MemoryUsageType.CHECK_BOUNDS) return self._inner.check_bounds(index, size, lower, upper) def bucketize( @@ -407,7 +458,9 @@ def bucketize( indexing_dtype: torch.dtype, right: bool, ): - offsets_size = add_index(offsets_size, "other") + offsets_size = add_index( + offsets_size, MemoryUsageType.BUCKETIZE, buffer_name=offsets_name + ) return self._inner.bucketize( values, offsets_name, offsets_size, indexing_dtype, right ) From 1fc02a60754dc35ec48570a4730b825cb405b9de Mon Sep 17 00:00:00 2001 From: Jason Ansel Date: Sat, 7 Sep 2024 19:54:40 -0700 Subject: [PATCH 0613/1018] [inductor] Remove ReadWrites.op_counts (#135306) This was (almost) unused. Pull Request resolved: https://github.com/pytorch/pytorch/pull/135306 Approved by: https://github.com/oulgen ghstack dependencies: #135286 --- torch/_inductor/codegen/cpp.py | 13 ++---------- torch/_inductor/dependencies.py | 36 +++------------------------------ torch/_inductor/loop_body.py | 18 +++++++++++++++++ torch/_inductor/scheduler.py | 10 --------- 4 files changed, 23 insertions(+), 54 deletions(-) diff --git a/torch/_inductor/codegen/cpp.py b/torch/_inductor/codegen/cpp.py index 418bb43836f2f4..e1d54762d9e8fd 100644 --- a/torch/_inductor/codegen/cpp.py +++ b/torch/_inductor/codegen/cpp.py @@ -3649,17 +3649,8 @@ def legalize_lowp_fp_dtype(self, nodes): for _node in nodes: assert isinstance(_node, SchedulerNode) assert isinstance(_node._body, LoopBody) - node: SchedulerNode = _node - - def is_memory_copy_scheduler_node(node: SchedulerNode): - op_counts = node.read_writes.op_counts - return ( - len(op_counts) == 2 and "load" in op_counts and "store" in op_counts - ) - - should_legalize = not is_memory_copy_scheduler_node(node) - if should_legalize: - body: LoopBody = node._body + body: LoopBody = _node._body + if not body.is_memory_copy(): self.legalize_lowp_fp_dtype_loopbody(body) def codegen_functions(self, fn_list, var_sizes_list): diff --git a/torch/_inductor/dependencies.py b/torch/_inductor/dependencies.py index 851ff2d648fe35..adb560ace3a039 100644 --- a/torch/_inductor/dependencies.py +++ b/torch/_inductor/dependencies.py @@ -1,6 +1,5 @@ # mypy: allow-untyped-defs import abc -import collections import dataclasses import itertools import logging @@ -356,9 +355,6 @@ class ReadWrites: index_exprs: OrderedSet[IndexExprDep] range_vars: Optional[List[sympy.Expr]] = None var_ranges: Optional[VarRanges] = None - op_counts: typing.Counter[str] = dataclasses.field( - default_factory=collections.Counter - ) def rename(self, renames: typing.Dict[str, str]) -> "ReadWrites": return ReadWrites( @@ -367,7 +363,6 @@ def rename(self, renames: typing.Dict[str, str]) -> "ReadWrites": self.index_exprs, self.range_vars, self.var_ranges, - op_counts=self.op_counts, ) def with_read(self, dep: Union[Dep, Set[Dep]]) -> "ReadWrites": @@ -380,28 +375,20 @@ def with_read(self, dep: Union[Dep, Set[Dep]]) -> "ReadWrites": self.index_exprs, self.range_vars, self.var_ranges, - op_counts=self.op_counts, ) def merge(self, other: "ReadWrites"): reads = OrderedSet.union(self.reads, other.reads) writes = OrderedSet.union(self.writes, other.writes) index_exprs = OrderedSet.union(self.index_exprs, other.index_exprs) - op_counts = collections.Counter(self.op_counts) - op_counts.update(other.op_counts) - return ReadWrites(reads - writes, writes, index_exprs, op_counts=op_counts) + return ReadWrites(reads - writes, writes, index_exprs) @staticmethod def merge_list(read_writes: List["ReadWrites"]): all_writes = OrderedSet.union(*[rw.writes for rw in read_writes]) all_reads = OrderedSet.union(*[rw.reads for rw in read_writes]) - all_writes all_index_exprs = OrderedSet.union(*[rw.index_exprs for rw in read_writes]) - - op_counts: typing.Counter[Any] = collections.Counter() - for rw in read_writes: - op_counts.update(rw.op_counts) - - return ReadWrites(all_reads, all_writes, all_index_exprs, op_counts=op_counts) + return ReadWrites(all_reads, all_writes, all_index_exprs) def remove_reads(self, rem_reads): return ReadWrites( @@ -410,7 +397,6 @@ def remove_reads(self, rem_reads): self.index_exprs, self.range_vars, self.var_ranges, - op_counts=self.op_counts, ) def reads_and_writes(self): @@ -461,7 +447,6 @@ def _normalize( # Try to further simplify the indexes even if simplify_loops didn't # convert it to the simplest form because of the interference from # different indexing formulas. - free_symbols = index.free_symbols index_vars = [*var_ranges.keys()] sizes = tuple(var_ranges.values()) # type: ignore[assignment] new_sizes, reindex, prune = V.graph.sizevars._simplify_loops( @@ -531,25 +516,11 @@ def bucketize( return f"bucketize({values}, {offsets_name}, {sympy_str(offsets_size)}, {indexing_dtype}, {right})" -class _OpCounter: - """Shim to count how many times each op is used""" - - def __init__(self, inner) -> None: - super().__init__() - self.parent_handler = inner - self._op_counts: typing.Counter[Any] = collections.Counter() - - def __getattr__(self, name): - self._op_counts[name] += 1 - return getattr(self.parent_handler, name) - - class RecordLoadStore(V.KernelFormatterHandler): # type: ignore[name-defined] def __init__(self, var_ranges: VarRanges, normalize: bool) -> None: parent_handler = _RecordLoadStoreInner( var_ranges=var_ranges, normalize=normalize ) - parent_handler = _OpCounter(parent_handler) super().__init__(parent_handler=parent_handler) @@ -603,14 +574,13 @@ def extract_read_writes( else: range_vars = list(itertools.chain.from_iterable(args)) - inner = rw.parent_handler.parent_handler + inner = rw.parent_handler return ReadWrites( OrderedSet(inner._reads), OrderedSet(inner._writes), inner._index_exprs, range_vars, var_ranges, - rw.parent_handler._op_counts, ) diff --git a/torch/_inductor/loop_body.py b/torch/_inductor/loop_body.py index 046ef11b61e307..dcf0b61ff5adac 100644 --- a/torch/_inductor/loop_body.py +++ b/torch/_inductor/loop_body.py @@ -296,6 +296,18 @@ def debug_str(self): ) return "\n".join(lines) + def is_memory_copy(self) -> bool: + """ + True of this contains only a single loads and store. + Note, this could involve a layout change. + """ + return ( + len(self.memory_usage[MemoryUsageType.LOAD]) == 1 + and len(self.memory_usage[MemoryUsageType.STORE]) == 1 + and len(self.submodules) == 1 # get_index + and self.root_block.contains_only_ops(("load", "store")) + ) + __repr__ = debug_str def add_index_expr( @@ -565,6 +577,12 @@ def debug_str(self, name="block"): code.strip().replace("def forward(", f"def {name}("), ) + def contains_only_ops(self, allowed_ops) -> bool: + return all( + node.target in allowed_ops + for node in self.graph.find_nodes(op="call_method") + ) + def clone(self, body: LoopBody): """Shallow copy with a new parent LoopBody""" copy = LoopBodyBlock.__new__(LoopBodyBlock) diff --git a/torch/_inductor/scheduler.py b/torch/_inductor/scheduler.py index 9e991fc26be605..6157c02aa5465d 100644 --- a/torch/_inductor/scheduler.py +++ b/torch/_inductor/scheduler.py @@ -275,9 +275,6 @@ def mark_run(self) -> None: for buf in self.outputs: buf.allocate() - def op_counts(self) -> Counter[str]: - return self.read_writes.op_counts - def used_buffer_names(self) -> OrderedSet[str]: return OrderedSet( dep.name @@ -1219,13 +1216,6 @@ def get_device(self) -> torch.device: def has_aliasing_or_mutation(self) -> bool: return any(x.has_aliasing_or_mutation() for x in self.snodes) - @cache_on_self - def op_counts(self) -> Counter[str]: - op_counts: Counter[str] = collections.Counter() - for node in self.snodes: - op_counts.update(node.op_counts()) - return op_counts - # None of these need to be implemented, as a FusedSchedulerNode is just an # abstraction for scheduling purposes def update_mutated_names(self, renames: Dict[str, str]) -> None: From 21a57c79bbb3134f7f2b01447aa9565813d17d10 Mon Sep 17 00:00:00 2001 From: Jason Ansel Date: Sat, 7 Sep 2024 19:54:40 -0700 Subject: [PATCH 0614/1018] [inductor] Fast path for extract_read_writes without tracing (#135377) Before (bottom of stack): ![image](https://github.com/user-attachments/assets/13060ff9-b31d-42a9-8e8f-c50b2bf3dc2f) After (this PR): ![image](https://github.com/user-attachments/assets/7d190821-b614-46b7-9e9e-9087443df654) Pull Request resolved: https://github.com/pytorch/pytorch/pull/135377 Approved by: https://github.com/oulgen ghstack dependencies: #135286, #135306 --- torch/_inductor/dependencies.py | 43 +++++++++++++++++++++++++++++---- torch/_inductor/ir.py | 30 +++++++++-------------- torch/_inductor/scheduler.py | 9 +++---- 3 files changed, 53 insertions(+), 29 deletions(-) diff --git a/torch/_inductor/dependencies.py b/torch/_inductor/dependencies.py index adb560ace3a039..95a48281977c87 100644 --- a/torch/_inductor/dependencies.py +++ b/torch/_inductor/dependencies.py @@ -563,18 +563,51 @@ def extract_read_writes( *argsizes: Tuple[sympy.Expr, ...], normalize: bool = False, prefix: str = "d", + hidden_args=(), ): args, var_ranges = index_vars_squeeze(*argsizes, prefix=prefix) - rw = RecordLoadStore(var_ranges, normalize=normalize) - with V.set_ops_handler(rw): - fn(*args) + + from .loop_body import LoopBody, MemoryUsageType + + if isinstance(fn, LoopBody): + # Fast path to avoid tracing when we already have a LoopBody + inner = _RecordLoadStoreInner(var_ranges=var_ranges, normalize=normalize) + name_to_index = fn.indexing_from_args([*args, *hidden_args]) + if fn.indirect_vars: + # mimic the `tmpX` naming tracing gives us + repl = {v: sympy.Symbol(f"tmp{i}") for i, v in enumerate(fn.indirect_vars)} + name_to_index = {k: sympy_subs(v, repl) for k, v in name_to_index.items()} + for entry in fn.memory_usage[MemoryUsageType.LOAD]: + inner.load(entry.buffer_name, name_to_index[entry.index_name]) + for entry in fn.memory_usage[MemoryUsageType.LOAD_SEED]: + inner.load_seed(entry.buffer_name, int(name_to_index[entry.index_name])) + for entry in fn.memory_usage[MemoryUsageType.STORE]: + inner.store( + entry.buffer_name, name_to_index[entry.index_name], None, entry.mode + ) + for entry in fn.memory_usage[MemoryUsageType.STORE_REDUCTION]: + inner.store_reduction( + entry.buffer_name, name_to_index[entry.index_name], None + ) + for entry in fn.memory_usage[MemoryUsageType.INDEX_EXPR]: + inner.index_expr(name_to_index[entry.index_name], None) + for entry in fn.memory_usage[MemoryUsageType.BUCKETIZE]: + inner.bucketize( + None, entry.buffer_name, name_to_index[entry.index_name], None, None + ) + # fn.memory_usage[MemoryUsageType.CHECK_BOUNDS] intentionally skipped + else: + # Slow path tracing the function + rw = RecordLoadStore(var_ranges, normalize=normalize) + with V.set_ops_handler(rw): + fn(*args, *hidden_args) + inner = rw.parent_handler if normalize: range_vars = [] # Number of vars could differ due to normalization else: - range_vars = list(itertools.chain.from_iterable(args)) + range_vars = [*itertools.chain.from_iterable(args)] - inner = rw.parent_handler return ReadWrites( OrderedSet(inner._reads), OrderedSet(inner._writes), diff --git a/torch/_inductor/ir.py b/torch/_inductor/ir.py index ef8ce7b36bbc30..8c0ec1042ead14 100644 --- a/torch/_inductor/ir.py +++ b/torch/_inductor/ir.py @@ -3645,12 +3645,6 @@ def get_fill_order(self): self.data.get_pointwise_size(), self.data.get_reduction_size() ) reads = self.get_read_writes().reads - reads_bufs = [ - V.graph.name_to_buffer[r.name] - if r.name in V.graph.name_to_buffer.keys() - else None - for r in reads - ] # only consider reads to buffer of same size # ignore StarDeps because they don't contribute stride information assert all( @@ -6419,19 +6413,17 @@ def num_reads(self): if isinstance(data, (InputsKernel, InputBuffer, ReinterpretView)): return 1 if isinstance(data, ComputedBuffer): - read_writes = data.get_read_writes() - else: - assert isinstance(data, (Pointwise, Reduction)), type(data) - read_writes = ComputedBuffer( - name=None, - layout=FlexibleLayout( - device=data.get_device(), - dtype=data.get_dtype(), - size=data.get_size(), - ), - data=data, - ).get_read_writes() - return len(read_writes.reads) + return data.num_reads() + assert isinstance(data, (Pointwise, Reduction)), type(data) + return ComputedBuffer( + name=None, + layout=FlexibleLayout( + device=data.get_device(), + dtype=data.get_dtype(), + size=data.get_size(), + ), + data=data, + ).num_reads() @cache_on_self def is_pointwise_non_scalar_tensor_num_reads_larger_than_one(self): diff --git a/torch/_inductor/scheduler.py b/torch/_inductor/scheduler.py index 6157c02aa5465d..16952d0b9d55aa 100644 --- a/torch/_inductor/scheduler.py +++ b/torch/_inductor/scheduler.py @@ -982,16 +982,15 @@ def codegen(self, index_vars: Sequence[Sequence[sympy.Expr]]) -> None: log.fatal("Error in codegen for %s", self.node) raise + @cache_on_self def pointwise_read_writes(self) -> dependencies.ReadWrites: """ Get the memory dependencies in the non-reduction axis. """ sizes, reduction_sizes = self._sizes - - def fn(index: Sequence[sympy.Symbol]) -> str: - return self._body(index, [sympy.Integer(0) for _ in reduction_sizes]) - - return dependencies.extract_read_writes(fn, sizes) + return dependencies.extract_read_writes( + self._body, sizes, hidden_args=[[sympy.Integer(0)] * len(reduction_sizes)] + ) def can_inplace(self, read_dep: dependencies.Dep) -> bool: if self.is_template(): From ca3ab31b5f11e378396d4d9db786b3c2e6416acf Mon Sep 17 00:00:00 2001 From: Jason Ansel Date: Sat, 7 Sep 2024 19:54:41 -0700 Subject: [PATCH 0615/1018] [inductor] Refactor BaseSchedulerNode.__init__ (#135400) Might be a small compile time improvement since we remove a call to extract_read_writes(). Pull Request resolved: https://github.com/pytorch/pytorch/pull/135400 Approved by: https://github.com/oulgen ghstack dependencies: #135286, #135306, #135377 --- torch/_inductor/scheduler.py | 41 ++++++++++++++++++++++-------------- 1 file changed, 25 insertions(+), 16 deletions(-) diff --git a/torch/_inductor/scheduler.py b/torch/_inductor/scheduler.py index 16952d0b9d55aa..564a9b4ccfd833 100644 --- a/torch/_inductor/scheduler.py +++ b/torch/_inductor/scheduler.py @@ -159,27 +159,27 @@ class BaseSchedulerNode: group: Tuple[torch.device, Tuple[Tuple[sympy.Expr, ...], ...]] read_writes: dependencies.ReadWrites unmet_dependencies: OrderedSet[Dep] - - def __init__(self, scheduler: Scheduler, node: ir.Operation) -> None: + # .min_order and .max_order are only relevant for "grouped" nodes such as FusedSchedulerNode. + # e.g. if the FusedSchedulerNode includes nodes (op_1, op_2, op_3), and op_X is X-th node + # in `self.scheduler.nodes`, then for this FusedSchedulerNode, .min_order is 1 and .max_order is 3. + # For non-"grouped" nodes (i.e. regular SchedulerNode), + # .min_order = .max_order = X if this node is X-th node in `self.scheduler.nodes`. + min_order: int + max_order: int + + def __init__(self, scheduler: Scheduler) -> None: self.scheduler: Scheduler = scheduler + + def _init_from_node(self, node: ir.Operation) -> None: self.node: Optional[ir.Operation] = node - self.set_read_writes(node.get_read_writes()) self.ancestors: OrderedSet[str] = OrderedSet() - # .min_order and .max_order are only relevant for "grouped" nodes such as FusedSchedulerNode. - # e.g. if the FusedSchedulerNode includes nodes (op_1, op_2, op_3), and op_X is X-th node - # in `self.scheduler.nodes`, then for this FusedSchedulerNode, .min_order is 1 and .max_order is 3. - # For non-"grouped" nodes (i.e. regular SchedulerNode), - # .min_order = .max_order = X if this node is X-th node in `self.scheduler.nodes`. - self.min_order: int - self.max_order: int self.last_usage: OrderedSet[ str ] = OrderedSet() # buffers that won't be used after this kernel self.written = False - self.outputs: List[SchedulerBuffer] = [ SchedulerBuffer( - scheduler=scheduler, + scheduler=self.scheduler, node=output, defining_op=self, ) @@ -799,6 +799,11 @@ def should_prune(dep: Dep) -> bool: class ExternKernelSchedulerNode(BaseSchedulerNode): + def __init__(self, scheduler: Scheduler, node: ir.Operation) -> None: + super().__init__(scheduler) + self._init_from_node(node) + self.set_read_writes(node.get_read_writes()) + def debug_str_extra(self) -> str: return f"{self.get_name()}.node.kernel = {getattr(self.node, 'python_kernel_name', None)}" @@ -811,7 +816,10 @@ def has_side_effects(self) -> bool: class NopKernelSchedulerNode(BaseSchedulerNode): - pass + def __init__(self, scheduler: Scheduler, node: ir.Operation) -> None: + super().__init__(scheduler) + self._init_from_node(node) + self.set_read_writes(node.get_read_writes()) class SchedulerNode(BaseSchedulerNode): @@ -820,7 +828,8 @@ def __init__( scheduler: Scheduler, node: Union[ir.ComputedBuffer, ir.TemplateBuffer], ) -> None: - super().__init__(scheduler, node) + super().__init__(scheduler) + self._init_from_node(node) self._compute_attrs() def _compute_attrs( @@ -1121,7 +1130,7 @@ def reorder_loops_by_dep_pair( refresh_group_node_dependencies(self) def __init__(self, scheduler: Scheduler, snodes: List[BaseSchedulerNode]) -> None: - # NB: No need to call super().__init__() because we don't need to re-use any of its logic. + super().__init__(scheduler) init_group_node(self, scheduler, snodes) self.users: List[NodeUser] = [] self.group = max(snodes, key=lambda x: int(x.is_reduction())).group @@ -1590,7 +1599,7 @@ def create(cls, snodes: List[BaseSchedulerNode]) -> GroupedSchedulerNode: return grouped_snode def __init__(self, scheduler: Scheduler, snodes: List[BaseSchedulerNode]) -> None: - # NB: No need to call super().__init__() because we don't need to re-use any of its logic. + super().__init__(scheduler) init_group_node(self, scheduler, snodes) def unpack(self) -> List[BaseSchedulerNode]: From 10d3a5ed801d4651b7d6010b16fd212a5169ea1d Mon Sep 17 00:00:00 2001 From: Jason Ansel Date: Sat, 7 Sep 2024 19:54:41 -0700 Subject: [PATCH 0616/1018] [inductor] Cleanup analysis done at lowering time (#135412) Before this we would take multiple passes over the body of each IRNode as we did lowering. This combines most analysis into `OpCounterCSE` so it can be done in a single pass. Before: ![image](https://github.com/user-attachments/assets/0047db09-4258-4491-a9a6-b078e183092a) After: ![image](https://github.com/user-attachments/assets/1e03adcb-8303-4bb1-8bbb-cc42dacd44d7) This stack: ![image](https://github.com/user-attachments/assets/d6b50b24-c30c-4d23-8b1a-344b3ba65d7a) Pull Request resolved: https://github.com/pytorch/pytorch/pull/135412 Approved by: https://github.com/oulgen ghstack dependencies: #135286, #135306, #135377, #135400 --- torch/_inductor/ir.py | 93 ++++++++++++++-------------------- torch/_inductor/lowering.py | 2 +- torch/_inductor/ops_handler.py | 78 +++++++++++++++++++++++----- 3 files changed, 103 insertions(+), 70 deletions(-) diff --git a/torch/_inductor/ir.py b/torch/_inductor/ir.py index 8c0ec1042ead14..0929996fe17452 100644 --- a/torch/_inductor/ir.py +++ b/torch/_inductor/ir.py @@ -72,7 +72,7 @@ var_builder, ) from .loop_body import LoopBody -from .ops_handler import OpCounterCSE +from .ops_handler import OpCounterCSE, OpCountResult from .runtime.benchmarking import benchmarker from .runtime.hints import ReductionHint from .utils import ( @@ -412,6 +412,7 @@ def codegen_reference(self, writer=None): dtype: torch.dtype get_name: Callable[[], str] get_reads: Callable[[], Any] + num_reads: Callable[[], int] get_stride: Callable[[], Any] get_storage_numel: Callable[[], Any] has_exceeded_max_reads: Callable[[], bool] @@ -555,25 +556,25 @@ def _index(ranges, prefix=SymT.INDEX): ] @cache_on_self - def inner_fn_opcount(self): + def inner_fn_opcount(self) -> OpCountResult: opcounter = OpCounterCSE(V.MockHandler()) - with V.set_ops_handler(opcounter), patch.object( FlexibleLayout, "allow_indexing", True ): self.inner_fn(*self.inner_fn_args()) - return opcounter.op_count + return opcounter.getvalue() def inner_fn_args(self): return (self._index(self.ranges),) + @cache_on_self def inner_fn_str(self): return V.KernelFormatterHandler.ir_to_string( self.inner_fn, *self.inner_fn_args() ) def has_large_inner_fn(self): - return self.inner_fn_opcount() > config.realize_opcount_threshold + return self.inner_fn_opcount().num_ops > config.realize_opcount_threshold def inner_fn_free_unbacked_symbols(self): index = self._index(self.ranges) @@ -594,7 +595,10 @@ def get_reads(self): ).reads def get_read_names(self) -> OrderedSet[str]: - return OrderedSet(dep.name for dep in self.get_reads()) + return OrderedSet(self.inner_fn_opcount().read_buffers) + + def num_reads(self): + return len(self.inner_fn_opcount().read_buffers) def get_reduction_size(self): raise NotImplementedError( @@ -2682,6 +2686,9 @@ def codegen_reference(self, writer=None): dtype=self.layout.dtype, ) + def num_reads(self): + return 1 + @dataclasses.dataclass class DtypeView(BaseView): @@ -3509,7 +3516,8 @@ def __post_init__(self): class InputBuffer(Buffer): - pass + def num_reads(self): + return 1 class ConstantBuffer(InputBuffer): @@ -3570,9 +3578,11 @@ def get_computed_buffer_name(self): return self.data.name return None - @cache_on_self def num_reads(self): - return len(self.get_read_writes().reads) + return self.data.num_reads() + + def get_read_names(self) -> OrderedSet[str]: + return self.data.get_read_names() def get_read_writes(self): with patch.object(FlexibleLayout, "allow_indexing", True): @@ -4154,6 +4164,9 @@ def unwrap_storage(inputs): def is_extern(self): return True + def num_reads(self): + return 1 + class NopKernel(InputsKernel): def is_no_op(self): @@ -6371,8 +6384,7 @@ def realize_hint(self): """ if ( isinstance(self.data, (Pointwise, Reduction)) - and self.num_reads() > 1 - and self.is_pointwise_non_scalar_tensor_num_reads_larger_than_one() + and self.data.inner_fn_opcount().nontrivial_read_count > 1 ): self.realize() @@ -6382,61 +6394,30 @@ def has_exceeded_max_reads(self): or self.has_large_inner_fn() ) - def mark_reuse(self, users): + def should_realize_on_reuse(self, users): """ A heuristic to decide if we should realize a tensor that is used multiple times. """ - - def should_realize_on_cpu(loops: Union[Pointwise, Reduction]): - """ - The heuristic for realizing reused result of heavy ops on cpu - """ - heavy_ops = ["exp", "sigmoid"] # a list of heavy ops - fn_str = loops.inner_fn_str() - return any((op + "(") in fn_str for op in heavy_ops) - - if ( - users > 1 - and isinstance(self.data, (Pointwise, Reduction)) - and ( + if users > 1 and isinstance(self.data, (Pointwise, Reduction)): + if is_cpu(self.data): + # Heuristic for realizing reused result of heavy ops on cpu + opcount = self.data.inner_fn_opcount() + heavy_ops = ["exp", "sigmoid"] # a list of heavy ops + if any(x in opcount.used_ops for x in heavy_ops): + return True + return ( self.num_reads() > config.realize_reads_threshold or self.has_large_inner_fn() - or (is_cpu(self.data) and should_realize_on_cpu(self.data)) ) - ): + return False + + def mark_reuse(self, users): + if self.should_realize_on_reuse(users): self.realize() - @cache_on_self def num_reads(self): - data = self.data - if isinstance(data, (InputsKernel, InputBuffer, ReinterpretView)): - return 1 - if isinstance(data, ComputedBuffer): - return data.num_reads() - assert isinstance(data, (Pointwise, Reduction)), type(data) - return ComputedBuffer( - name=None, - layout=FlexibleLayout( - device=data.get_device(), - dtype=data.get_dtype(), - size=data.get_size(), - ), - data=data, - ).num_reads() - - @cache_on_self - def is_pointwise_non_scalar_tensor_num_reads_larger_than_one(self): - # Skip the check for non Pointwise instances - return ( - (sum(read.index != 0 for read in self.data.get_reads()) > 1) - if isinstance(self.data, Pointwise) - and all( - not isinstance(read, dependencies.StarDep) - for read in self.data.get_reads() - ) - else True - ) + return self.data.num_reads() @dataclasses.dataclass diff --git a/torch/_inductor/lowering.py b/torch/_inductor/lowering.py index 84120d9f9933f0..adf921913edfcd 100644 --- a/torch/_inductor/lowering.py +++ b/torch/_inductor/lowering.py @@ -1481,7 +1481,7 @@ def op_count(x): if not isinstance(x, ir.Pointwise): return 0 - count = x.inner_fn_opcount() + count = x.inner_fn_opcount().num_ops for read in x.get_read_names(): count += op_count(V.graph.get_buffer(read)) diff --git a/torch/_inductor/ops_handler.py b/torch/_inductor/ops_handler.py index bf8f0d3d6370ca..c47ee1026ab919 100644 --- a/torch/_inductor/ops_handler.py +++ b/torch/_inductor/ops_handler.py @@ -5,7 +5,9 @@ Callable, Dict, Generic, + List, Literal, + NamedTuple, Optional, Tuple, TypeVar, @@ -19,6 +21,7 @@ import torch import torch.utils._pytree as pytree +from ..utils._ordered_set import OrderedSet from .utils import IndentedBuffer, reduction_num_outputs, sympy_index_symbol, sympy_str @@ -955,6 +958,13 @@ def _typecheck_AddParenHandler(h: AddParenHandler[T]) -> OpsHandler[T]: return h +class OpCountResult(NamedTuple): + num_ops: int + used_ops: OrderedSet[str] + read_buffers: List[str] + nontrivial_read_count: int + + class OpCounterCSE: """Shim to count how many ops are used""" @@ -963,25 +973,67 @@ def __init__(self, inner): self.parent_handler = inner self.op_count = 0 self.var_names = {} + self._used_ops: OrderedSet[str] = OrderedSet() + self._read_names: List[str] = [] + self._nontrivial_read_count = 0 def __getattr__(self, name): def inner(*args, **kwargs): - val = getattr(self.parent_handler, name)(*args, **kwargs) - if name == "indirect_indexing": - return val + return pytree.tree_map( + self._update_count, getattr(self.parent_handler, name)(*args, **kwargs) + ) - def count(val): - if val not in self.var_names: - varname = f"tmp{self.op_count}" - self.op_count += 1 - self.var_names[val] = varname - return varname - else: - return self.var_names[val] + self._used_ops.add(name) + return inner - return pytree.tree_map(count, val) + def _update_count(self, val): + varname = self.var_names.get(val) + if not varname: + varname = f"tmp{self.op_count}" + self.op_count += 1 + self.var_names[val] = varname + return varname + + def indirect_indexing(self, *args, **kwargs): + self._used_ops.add("indirect_indexing") + return self.parent_handler.indirect_indexing(*args, **kwargs) + + def load(self, name: str, index: sympy.Expr) -> str: + val = self.parent_handler.load(name, index) + if val not in self.var_names: + self._used_ops.add("load") + self._read_names.append(name) + if not isinstance(index, (sympy.Integer, int)): + self._nontrivial_read_count += 1 + return self._update_count(val) - return inner + def load_seed(self, name: str, offset: T): + val = self.parent_handler.load_seed(name, offset) + if val not in self.var_names: + self._used_ops.add("load_seed") + self._read_names.append(name) + return self._update_count(val) + + def bucketize( + self, + values, + offsets_name: str, + offsets_size: sympy.Expr, + indexing_dtype: torch.dtype, + right: bool, + ): + val = self.parent_handler.bucketize( + values, offsets_name, offsets_size, indexing_dtype, right + ) + if val not in self.var_names: + self._used_ops.add("bucketize") + self._read_names.append(offsets_name) + return self._update_count(val) + + def getvalue(self): + return OpCountResult( + self.op_count, self._used_ops, self._read_names, self._nontrivial_read_count + ) def _typecheck_OpCounterCSE(h: OpCounterCSE) -> OpsHandler[str]: From c2ad5dfc691fedd3e49525b115a44ce334f639d1 Mon Sep 17 00:00:00 2001 From: yuqingj Date: Sat, 7 Sep 2024 20:41:46 -0700 Subject: [PATCH 0617/1018] [NJT]Add permute ops support (#135336) Pull Request resolved: https://github.com/pytorch/pytorch/pull/135336 Approved by: https://github.com/davidberard98 --- test/test_nestedtensor.py | 31 ++++++++++++++++++++++++++++ torch/nested/_internal/ops.py | 38 +++++++++++++++++++++++++++++++++++ 2 files changed, 69 insertions(+) diff --git a/test/test_nestedtensor.py b/test/test_nestedtensor.py index d9fcd7caaf0c77..49cb422f8720cc 100644 --- a/test/test_nestedtensor.py +++ b/test/test_nestedtensor.py @@ -5955,6 +5955,37 @@ def check_nt_equality(x, y): self.assertFalse(clone.is_contiguous()) check_nt_equality(detached, transposed) + def test_permute(self, device): + nt = random_nt_from_dims( + [2, None, 3, 5], device, torch.float32, layout=torch.jagged + ) + nt_shape = nt.shape + nt_inner_shape = nt.values().shape + with self.assertRaisesRegex( + ValueError, + r"permute\(\): number of dimensions in the tensor input \(4\) " + + r"does not match the length of the desired ordering of dimensions \(3\).", + ): + nt.permute(0, 2, 1) + with self.assertRaisesRegex( + ValueError, r"permute\(\): duplicate dims are not allowed." + ): + nt.permute(0, 2, -2, 3) + with self.assertRaisesRegex( + ValueError, "Permute is not supported on the batch dimension for jagged NT" + ): + nt.permute(1, 0, 2, 3) + nt_permute = nt.permute(0, 2, 1, -1) + self.assertEqual( + nt_permute.shape, (nt_shape[0], nt_shape[2], nt_shape[1], nt_shape[3]) + ) + self.assertEqual( + nt_permute.values().shape, + (nt_inner_shape[1], nt_inner_shape[0], nt_inner_shape[2]), + ) + self.assertEqual(nt_permute._ragged_idx, 2) + self.assertEqual(nt_permute.permute(0, 2, 1, 3), nt) + def test_to_dtype(self, device): nt = random_nt_from_dims( [2, None, 3], device, torch.float32, layout=torch.jagged diff --git a/torch/nested/_internal/ops.py b/torch/nested/_internal/ops.py index 0d203270b820b2..ed9a54f9dca932 100644 --- a/torch/nested/_internal/ops.py +++ b/torch/nested/_internal/ops.py @@ -1159,6 +1159,44 @@ def transpose_int(func, *args, **kwargs): return NestedTensor(func(inp._values, **new_kwargs), **extract_kwargs(inp)) +@register_jagged_func(torch.ops.aten.permute.default, "self: jt_all, dims: any") +def permute_default(func, *args, **kwargs): + _, new_kwargs = normalize_function( # type: ignore[misc] + func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True + ) + inp = new_kwargs.pop("input") + dims = new_kwargs.pop("dims") + inp_kwargs = extract_kwargs(inp) + inp_dim = len(inp._size) + + # The first two checks are the same as the checks in the normal permute implementation + if inp_dim != len(dims): + raise ValueError( + f"permute(): number of dimensions in the tensor input ({inp_dim}) " + + f"does not match the length of the desired ordering of dimensions ({len(dims)}).", + ) + + from torch._prims_common import canonicalize_dims + + canonicalized_dims = canonicalize_dims(inp_dim, dims) + + if len(canonicalized_dims) != len(set(canonicalized_dims)): + raise ValueError("permute(): duplicate dims are not allowed.") + + if inp._lengths is not None: + raise ValueError( + "permute(): not supported on jagged layout nested tensor with holes" + ) + if canonicalized_dims[0] != 0: + raise ValueError( + "Permute is not supported on the batch dimension for jagged NT" + ) + inp_kwargs["_ragged_idx"] = canonicalized_dims.index(inp._ragged_idx) + inner_dims = [_outer_to_inner_dim(inp_dim, dim) for dim in canonicalized_dims[1:]] + new_kwargs["dims"] = inner_dims + return NestedTensor(func(inp._values, **new_kwargs), **inp_kwargs) + + @register_jagged_func( [torch.ops.aten.view.default, torch.ops.aten._unsafe_view.default], "self: jt_all, size: any", From 0cfe5e782d4d3aa05e1cf0579c2a1411cb0ad3c6 Mon Sep 17 00:00:00 2001 From: Xu Han Date: Mon, 9 Sep 2024 03:10:54 +0000 Subject: [PATCH 0618/1018] [inductor] calibration inductor windows uts (18/N) (#135449) skip test_quantized_* UTs of `test/inductor/test_cpu_select_algorithm.py`. Windows inductor don't support quantize so far. Pull Request resolved: https://github.com/pytorch/pytorch/pull/135449 Approved by: https://github.com/ezyang --- test/inductor/test_cpu_select_algorithm.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/test/inductor/test_cpu_select_algorithm.py b/test/inductor/test_cpu_select_algorithm.py index a3ff2b3590aaae..19f518387a398c 100644 --- a/test/inductor/test_cpu_select_algorithm.py +++ b/test/inductor/test_cpu_select_algorithm.py @@ -24,7 +24,12 @@ from torch.testing._internal.common_quantized import ( _calculate_dynamic_per_channel_qparams, ) -from torch.testing._internal.common_utils import IS_MACOS, parametrize, TEST_MKL +from torch.testing._internal.common_utils import ( + IS_MACOS, + parametrize, + skipIfWindows, + TEST_MKL, +) log = logging.getLogger(__name__) @@ -1055,6 +1060,7 @@ def forward(self, arg152_1): "gelu", ), ) + @skipIfWindows(msg="Windows don't support quantize.") def test_quantized_linear_with_pointwise( self, batch_size, in_features, out_features, bias, input_3d, dtype, epilogue ): @@ -1172,6 +1178,7 @@ def forward(self, x, scale): "relu", ), ) + @skipIfWindows(msg="Windows don't support quantize.") def test_quantized_linear_with_pointwise_binary( self, batch_size, @@ -1252,6 +1259,7 @@ def forward(self, x, other, other2): @parametrize("in_features", (4, 68, 128)) # k should be a multiple of 4 @parametrize("out_features", (64, 65)) @parametrize("bias", (True, False)) + @skipIfWindows(msg="Windows don't support quantize.") def test_quantized_linear_amx(self, batch_size, in_features, out_features, bias): class M(torch.nn.Module): def __init__(self, bias): From 18e391887f22b8d6da8f23cad6afcf35424f5d41 Mon Sep 17 00:00:00 2001 From: Alexander Kurakin Date: Mon, 9 Sep 2024 03:16:09 +0000 Subject: [PATCH 0619/1018] docs: `torch.nn.utils.rnn.pack_padded_sequence`: docs improve (#135417) docs: `torch.nn.utils.rnn.pack_padded_sequence`: docs improve /cc @mikaylagawarecki Pull Request resolved: https://github.com/pytorch/pytorch/pull/135417 Approved by: https://github.com/ezyang --- torch/nn/utils/rnn.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/torch/nn/utils/rnn.py b/torch/nn/utils/rnn.py index 09a38e1119de6b..13fa6324833196 100644 --- a/torch/nn/utils/rnn.py +++ b/torch/nn/utils/rnn.py @@ -285,10 +285,10 @@ def pack_padded_sequence( ) -> PackedSequence: r"""Packs a Tensor containing padded sequences of variable length. - :attr:`input` can be of size ``T x B x *`` where ``T`` is the length of the - longest sequence, ``B`` is the batch size, and ``*`` is any number of dimensions - (including 0). If :attr:`batch_first` is ``False``, ``T x B x *`` :attr:`input` is expected, - ``B x T x *`` otherwise. + :attr:`input` can be of size ``T x B x *`` (if :attr:`batch_first` is ``False``) + or ``B x T x *`` (if :attr:`batch_first` is ``True``) where ``T`` is the length + of the longest sequence, ``B`` is the batch size, and ``*`` is any number of dimensions + (including 0). For unsorted sequences, use `enforce_sorted = False`. If :attr:`enforce_sorted` is ``True``, the sequences should be sorted by length in a decreasing order, i.e. From 74eb53a359a8bb34f17a0b8902de50f446c9a6ca Mon Sep 17 00:00:00 2001 From: Rafal Litka Date: Mon, 9 Sep 2024 03:32:18 +0000 Subject: [PATCH 0620/1018] enable lazy_init for hpu (#135203) enables lazy_init for hpu device Pull Request resolved: https://github.com/pytorch/pytorch/pull/135203 Approved by: https://github.com/ezyang --- torch/csrc/utils/device_lazy_init.h | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/torch/csrc/utils/device_lazy_init.h b/torch/csrc/utils/device_lazy_init.h index 79c05f3c9ada77..c0147977ead29e 100644 --- a/torch/csrc/utils/device_lazy_init.h +++ b/torch/csrc/utils/device_lazy_init.h @@ -28,7 +28,8 @@ void set_requires_device_init(at::DeviceType device_type, bool value); inline void maybe_initialize_device(at::Device& device) { // Add more devices here to enable lazy initialization. - if (device.is_cuda() || device.is_xpu() || device.is_privateuseone()) { + if (device.is_cuda() || device.is_xpu() || device.is_privateuseone() || + device.is_hpu()) { device_lazy_init(device.type()); } } From 55aa062b61a1fd2f8f45f31b6e802bc79ae328b4 Mon Sep 17 00:00:00 2001 From: "Zhou, Lingzhi" Date: Mon, 9 Sep 2024 03:34:00 +0000 Subject: [PATCH 0621/1018] [Partitioner] Query whether nodes exist in graph faster (#135316) Find node if exist in graph.nodes (linked list) take too long time. Using graph._find_nodes_lookup_table (hash table) instead to speed up. Pull Request resolved: https://github.com/pytorch/pytorch/pull/135316 Approved by: https://github.com/ezyang --- torch/fx/passes/utils/fuser_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch/fx/passes/utils/fuser_utils.py b/torch/fx/passes/utils/fuser_utils.py index 324e8a67801564..11a9cfa34898ab 100644 --- a/torch/fx/passes/utils/fuser_utils.py +++ b/torch/fx/passes/utils/fuser_utils.py @@ -118,7 +118,7 @@ def fuse_as_graphmodule(gm: GraphModule, for node in nodes: assert node.graph.owning_module is gm, f"{node} doesn't belong to passed in graph module {gm._get_name()}" assert not node._erased, f"{node} has been removed from owning graph" - assert node in gm.graph.nodes, f"{node} is not found in graph module {gm._get_name()}" + assert node in gm.graph._find_nodes_lookup_table, f"{node} is not found in graph module {gm._get_name()}" # validates partition doesn't introduce dependency circles in the graph assert validate_partition(nodes), "Invalid partition, found dependency cycles" From dc56ea44ea595ec4f9f8438a7091f85f44f1fe09 Mon Sep 17 00:00:00 2001 From: PhilipMay Date: Mon, 9 Sep 2024 03:34:50 +0000 Subject: [PATCH 0622/1018] Fix return type of `nansum` example. (#135435) One of the examples in the documentation of `torch.nansum` contains a wrong return type. This fixes it. Pull Request resolved: https://github.com/pytorch/pytorch/pull/135435 Approved by: https://github.com/ezyang --- torch/_torch_docs.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch/_torch_docs.py b/torch/_torch_docs.py index bba6dc2e307329..70f7192ee2d73c 100644 --- a/torch/_torch_docs.py +++ b/torch/_torch_docs.py @@ -10665,7 +10665,7 @@ def merge_dicts(*dicts): Example:: >>> torch.nansum(torch.tensor([1., float("nan")])) - 1.0 + tensor(1.) >>> a = torch.tensor([[1, 2], [3., float("nan")]]) >>> torch.nansum(a) tensor(6.) From 1bdf3668e7d10d66d95cc91a1e675793d51d84a3 Mon Sep 17 00:00:00 2001 From: xingyunjohn1 <1298459635@qq.com> Date: Mon, 9 Sep 2024 03:47:32 +0000 Subject: [PATCH 0623/1018] =?UTF-8?q?Fix=20example:=20Address=20broadcasti?= =?UTF-8?q?ng=20error=20in=20the=20addition=20of=20`attn=5Fbias=E2=80=A6?= =?UTF-8?q?=20(#135427)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit …` and `attn_mask`, and correct device assignment for newly created variables in the method. Fix example: Address broadcasting error in the addition of `attn_bias` and `attn_mask`, and correct device assignment for newly created variables in the method. 1. Adding `attn_bias += attn_mask` results in a broadcasting error. The expected shape of `attn_bias` is (L, S), so the output should also have the shape (L, S). However, when the input shape is (N, num_heads, L, S), broadcasting occurs, leading to an output shape of (N, num_heads, L, S), which is not desired. 2. `attn_bias` is a newly created variable within the method, but it is not assigned to the correct device. **This is my retry of PR #130209 . The PR has been merged into commit `d4a79d4a7c746068d25fe5cf9333495561f4ce1f`, but the modifications were overwritten by subsequent commits.** Co-authored-by: mikaylagawarecki @mikaylagawarecki provided a more elegant implementation. Pull Request resolved: https://github.com/pytorch/pytorch/pull/135427 Approved by: https://github.com/ezyang --- torch/nn/functional.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torch/nn/functional.py b/torch/nn/functional.py index 3072ac0fef0f29..9640ca1e76e290 100644 --- a/torch/nn/functional.py +++ b/torch/nn/functional.py @@ -5620,7 +5620,7 @@ def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0. is_causal=False, scale=None, enable_gqa=False) -> torch.Tensor: L, S = query.size(-2), key.size(-2) scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale - attn_bias = torch.zeros(L, S, dtype=query.dtype) + attn_bias = torch.zeros(L, S, dtype=query.dtype, device=query.device) if is_causal: assert attn_mask is None temp_mask = torch.ones(L, S, dtype=torch.bool).tril(diagonal=0) @@ -5631,7 +5631,7 @@ def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0. if attn_mask.dtype == torch.bool: attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf")) else: - attn_bias += attn_mask + attn_bias = attn_mask + attn_bias if enable_gqa: key = key.repeat_interleave(query.size(-3)//key.size(-3), -3) From 78670d61c8d4c0e110f7ecaaef2c42c863e3b572 Mon Sep 17 00:00:00 2001 From: cyy Date: Mon, 9 Sep 2024 05:03:29 +0000 Subject: [PATCH 0624/1018] Check function declarations in Caffe2 code (#134925) Fixes #ISSUE_NUMBER Pull Request resolved: https://github.com/pytorch/pytorch/pull/134925 Approved by: https://github.com/ezyang --- caffe2/CMakeLists.txt | 2 +- caffe2/perfkernels/embedding_lookup_idx.cc | 4 +++- caffe2/serialize/crc.cc | 1 - caffe2/utils/proto_wrap.h | 23 ++++++++++++++++++++++ 4 files changed, 27 insertions(+), 3 deletions(-) diff --git a/caffe2/CMakeLists.txt b/caffe2/CMakeLists.txt index 8552dab106e4cf..2160399a3ea296 100644 --- a/caffe2/CMakeLists.txt +++ b/caffe2/CMakeLists.txt @@ -781,7 +781,7 @@ if("${CMAKE_CXX_COMPILER_ID}" MATCHES "Clang" AND NOT USE_IOS AND NOT USE_COREML set_source_files_properties(${source_file} PROPERTIES COMPILE_OPTIONS "-Wno-missing-prototypes;-Wno-error=missing-prototypes") continue() endif() - string(FIND "${source_file}" "caffe2" res) + string(FIND "${source_file}" "embedding_lookup_idx_avx2.cc" res) if(res GREATER -1) set_source_files_properties(${source_file} PROPERTIES COMPILE_OPTIONS "-Wno-missing-prototypes;-Wno-error=missing-prototypes") endif() diff --git a/caffe2/perfkernels/embedding_lookup_idx.cc b/caffe2/perfkernels/embedding_lookup_idx.cc index c9b91dc31b880f..5fcf71016aea69 100644 --- a/caffe2/perfkernels/embedding_lookup_idx.cc +++ b/caffe2/perfkernels/embedding_lookup_idx.cc @@ -6,6 +6,7 @@ #include #include "caffe2/perfkernels/common.h" +C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wmissing-prototypes") namespace caffe2 { /** @@ -125,7 +126,7 @@ static bool EmbeddingLookupGenericSlowIdx( const float* scale_bias, \ bool normalize_by_lengths, \ OutType* out) { \ - if (std::is_same::value) { \ + if constexpr (std::is_same::value) { \ CAFFE_ENFORCE(scale_bias != nullptr, "scale_bias must not be nullptr"); \ } else { \ CAFFE_ENFORCE(scale_bias == nullptr, "scale_bias must be nullptr"); \ @@ -231,3 +232,4 @@ EMBEDDING_IDX_SPECIALIZATION(int64_t, uint8_t, uint8_t, float, true); #undef EMBEDDING_IDX_SPECIALIZATION } // namespace caffe2 +C10_DIAGNOSTIC_POP() diff --git a/caffe2/serialize/crc.cc b/caffe2/serialize/crc.cc index 7a7173e417fd3b..944d95c2ec3565 100644 --- a/caffe2/serialize/crc.cc +++ b/caffe2/serialize/crc.cc @@ -1,5 +1,4 @@ #include "miniz.h" -#include #include "caffe2/serialize/crc_alt.h" diff --git a/caffe2/utils/proto_wrap.h b/caffe2/utils/proto_wrap.h index 75e2c4be5c190c..4ce0ce5d3b8bfb 100644 --- a/caffe2/utils/proto_wrap.h +++ b/caffe2/utils/proto_wrap.h @@ -9,6 +9,29 @@ namespace caffe2 { // testing and valgrind cases to avoid protobuf appearing to "leak" memory). TORCH_API void ShutdownProtobufLibrary(); +// Caffe2 wrapper functions for protobuf's GetEmptyStringAlreadyInited() +// function used to avoid duplicated global variable in the case when protobuf +// is built with hidden visibility. +TORCH_API const ::std::string& GetEmptyStringAlreadyInited(); } // namespace caffe2 +namespace ONNX_NAMESPACE { + +// ONNX wrapper functions for protobuf's GetEmptyStringAlreadyInited() function +// used to avoid duplicated global variable in the case when protobuf +// is built with hidden visibility. +TORCH_API const ::std::string& GetEmptyStringAlreadyInited(); + +} // namespace ONNX_NAMESPACE + +namespace torch { + +// Caffe2 wrapper functions for protobuf's GetEmptyStringAlreadyInited() +// function used to avoid duplicated global variable in the case when protobuf +// is built with hidden visibility. +TORCH_API const ::std::string& GetEmptyStringAlreadyInited(); + +void ShutdownProtobufLibrary(); + +} // namespace torch #endif // CAFFE2_UTILS_PROTO_WRAP_H_ From 640b85c3657fb64fc5a2a9911d8771e7d3130454 Mon Sep 17 00:00:00 2001 From: Jiong Gong Date: Sun, 8 Sep 2024 18:56:51 -0700 Subject: [PATCH 0625/1018] [inductor][cpp][gemm] fix perf regression xcit_large_24_p8_224 (#134686) (#135438) Fix #134686. PR https://github.com/pytorch/pytorch/pull/132729 makes GEMM template faster for one of the GEMMs in xcit_large_24_p8_224: SingleProcess AUTOTUNE benchmarking takes 1.7088 seconds and 1.9207 seconds precompiling AUTOTUNE linear_unary(12544x3072, 768x3072, 768) cpp_packed_gemm_2 2.9371 ms 100.0% _linear_pointwise 3.1584 ms 93.0% But it is slower than Aten in the e2e run due to different cache behavior. The access to the input data (12544x3072) is LLC latency bound and bottlenecks seen due to the memory synchronization (data transfers and coherence updates across processors). This PR tries to mitigate the problem by cooperatively loading different chunks of input data from different processors that share the input data. Pull Request resolved: https://github.com/pytorch/pytorch/pull/135438 Approved by: https://github.com/leslie-fang-intel --- torch/_inductor/codegen/cpp_gemm_template.py | 73 ++++++++++++-------- torch/_inductor/codegen/cpp_prefix.h | 26 ------- 2 files changed, 46 insertions(+), 53 deletions(-) diff --git a/torch/_inductor/codegen/cpp_gemm_template.py b/torch/_inductor/codegen/cpp_gemm_template.py index b2a2b1c41576d0..9756af21753078 100644 --- a/torch/_inductor/codegen/cpp_gemm_template.py +++ b/torch/_inductor/codegen/cpp_gemm_template.py @@ -86,7 +86,9 @@ ); const int64_t num_Mc_blocks = (Mr_blocks + Mc_blocks - 1) / Mc_blocks; const int64_t num_Nc_blocks = (Nr_blocks + Nc_blocks - 1) / Nc_blocks; - const int64_t num_k_slices = (Kr_blocks + Kt_blocks - 1) / Kt_blocks; + const int64_t num_Mt_blocks = (Mr_blocks + Mt_blocks - 1) / Mt_blocks; + const int64_t num_Nt_blocks = (Nr_blocks + Nt_blocks - 1) / Nt_blocks; + const int64_t num_Kt_blocks = (Kr_blocks + Kt_blocks - 1) / Kt_blocks; {%- else %} constexpr int64_t M = {{kernel.size(GemmOut, 0)}}; constexpr int64_t Mr_blocks = (M + Mr - 1) / Mr; @@ -98,7 +100,9 @@ constexpr int64_t Kc_blocks = {{template.cache_blocking().block_k}}; constexpr int64_t num_Mc_blocks = (Mr_blocks + Mc_blocks - 1) / Mc_blocks; constexpr int64_t num_Nc_blocks = (Nr_blocks + Nc_blocks - 1) / Nc_blocks; - constexpr int64_t num_k_slices = (Kr_blocks + Kt_blocks - 1) / Kt_blocks; + constexpr int64_t num_Mt_blocks = (Mr_blocks + Mt_blocks - 1) / Mt_blocks; + constexpr int64_t num_Nt_blocks = (Nr_blocks + Nt_blocks - 1) / Nt_blocks; + constexpr int64_t num_Kt_blocks = (Kr_blocks + Kt_blocks - 1) / Kt_blocks; {%- endif %} // make sure all partitions are assigned @@ -109,8 +113,8 @@ {%- if maybe_k_slicing %} std::unique_ptr[]> local_buf_ptrs; - if (num_k_slices > 1) { - local_buf_ptrs.reset(new std::unique_ptr<{{DTYPE_TO_CPP[acc_buf_dtype]}}[]>[num_Mc_blocks * num_Nc_blocks * num_k_slices]); + if (num_Kt_blocks > 1) { + local_buf_ptrs.reset(new std::unique_ptr<{{DTYPE_TO_CPP[acc_buf_dtype]}}[]>[num_Mc_blocks * num_Nc_blocks * num_Kt_blocks]); } {%- endif %} @@ -118,30 +122,44 @@ #pragma omp parallel num_threads({{num_threads}}) { const int tid = omp_get_thread_num(); - int64_t m_block_start, m_block_end, n_block_start, n_block_end, k_block_start, k_block_end; - mm_get_thread_blocks( - tid, Mr_blocks, Nr_blocks, Kr_blocks, Mt_blocks, Nt_blocks, Kt_blocks, - m_block_start, m_block_end, n_block_start, n_block_end, k_block_start, k_block_end); - {%- if maybe_k_slicing %} - const int64_t k_group_id = tid / num_k_slices; - const int64_t k_slice_id = tid % num_k_slices; - {%- endif %} + const int64_t k_group_id = tid / num_Kt_blocks; + const int64_t k_slice_id = tid % num_Kt_blocks; + const int64_t n_group_id = k_group_id / num_Nt_blocks; + const int64_t n_slice_id = k_group_id % num_Nt_blocks; + const int64_t k_block_start = k_slice_id * Kt_blocks; + const int64_t k_block_end = std::min(k_block_start + Kt_blocks, Kr_blocks); + const int64_t n_block_start = n_slice_id * Nt_blocks; + const int64_t n_block_end = std::min(n_block_start + Nt_blocks, Nr_blocks); + const int64_t m_block_start = std::min(n_group_id * Mt_blocks, Mr_blocks); + const int64_t m_block_end = std::min(m_block_start + Mt_blocks, Mr_blocks); + const int64_t num_Mc_blocks_per_thread = (m_block_end - m_block_start + Mc_blocks - 1) / Mc_blocks; {%- else %} { - const int tid = 0; - const int64_t m_block_start = 0; - const int64_t m_block_end = Mr_blocks; - const int64_t n_block_start = 0; - const int64_t n_block_end = Nr_blocks; - const int64_t k_block_start = 0; - const int64_t k_block_end = Kr_blocks; + constexpr int tid = 0; + constexpr int64_t k_group_id = 0; + constexpr int64_t k_slice_id = 0; + constexpr int64_t n_group_id = 0; + constexpr int64_t n_slice_id = 0; + constexpr int64_t m_block_start = 0; + constexpr int64_t m_block_end = Mr_blocks; + constexpr int64_t n_block_start = 0; + constexpr int64_t n_block_end = Nr_blocks; + constexpr int64_t k_block_start = 0; + constexpr int64_t k_block_end = Kr_blocks; + {%- if is_dynamic_M %} + const int64_t num_Mc_blocks_per_thread = num_Mc_blocks; + {%- else %} + constexpr int64_t num_Mc_blocks_per_thread = num_Mc_blocks; + {%- endif %} {%- endif %} {{ micro_gemm.codegen_init(kernel) }} {%- if use_local_acc %} {%- set acc_buf_name = "local_acc_buf" %} {{ kernel.define_buffer(acc_buf_name, ["Mc_blocks*Mr", "Nc_blocks*Nr"], acc_buf_dtype) }} {%- endif %} - for (int64_t mc = m_block_start; mc < m_block_end; mc += Mc_blocks) { + for (int64_t mc_block_id = 0; mc_block_id < num_Mc_blocks_per_thread; mc_block_id++) { + const int64_t my_mc_block_id = (mc_block_id + n_slice_id) % num_Mc_blocks_per_thread; + const int64_t mc = m_block_start + my_mc_block_id * Mc_blocks; const int64_t m_start = mc * Mr; const int64_t m_end = std::min(std::min(mc + Mc_blocks, m_block_end) * Mr, M); const int64_t m_size = m_end - m_start; @@ -173,9 +191,10 @@ } } {%- if maybe_k_slicing %} - if (num_k_slices > 1) { + if (num_Kt_blocks > 1) { const int64_t mxn_cache_block_id = (mc / Mc_blocks) * num_Nc_blocks + nc; - local_buf_ptrs[mxn_cache_block_id * num_k_slices + k_slice_id].reset({{ kernel.release_buffer(acc_buf_name) }}); + local_buf_ptrs[mxn_cache_block_id * num_Kt_blocks + k_slice_id].reset( + {{ kernel.release_buffer(acc_buf_name) }}); } else {%- endif %} { @@ -189,14 +208,14 @@ } } {%- if maybe_k_slicing %} - if (num_k_slices > 1) { + if (num_Kt_blocks > 1) { #pragma omp barrier for (int64_t mc = m_block_start; mc < m_block_end; mc += Mc_blocks) { // We slice M-dim and each thread in the k-slicing group works on a slice const int64_t m_start_unsliced = mc * Mr; const int64_t m_end_unsliced = std::min(std::min(mc + Mc_blocks, m_block_end) * Mr, M); const int64_t m_size_unsliced = m_end_unsliced - m_start_unsliced; - const int64_t m_slice_size = (m_size_unsliced + num_k_slices - 1) / num_k_slices; + const int64_t m_slice_size = (m_size_unsliced + num_Kt_blocks - 1) / num_Kt_blocks; const int64_t m_start = std::min(m_start_unsliced + m_slice_size * k_slice_id, m_end_unsliced); const int64_t m_end = std::min(m_start_unsliced + m_slice_size * (k_slice_id + 1), m_end_unsliced); const int64_t m_size = m_end - m_start; @@ -206,9 +225,9 @@ const int64_t n_end = std::min(std::min(nc + Nc_blocks, n_block_end) * Nr, N); const int64_t n_size = n_end - n_start; const int64_t mxn_cache_block_id = (mc / Mc_blocks) * num_Nc_blocks + nc; - auto {{acc_buf_name}} = local_buf_ptrs[mxn_cache_block_id * num_k_slices].get(); - for (int64_t other_slice = 1; other_slice < num_k_slices; other_slice++) { - auto other_acc = local_buf_ptrs[mxn_cache_block_id * num_k_slices + other_slice].get(); + auto {{acc_buf_name}} = local_buf_ptrs[mxn_cache_block_id * num_Kt_blocks].get(); + for (int64_t other_slice = 1; other_slice < num_Kt_blocks; other_slice++) { + auto other_acc = local_buf_ptrs[mxn_cache_block_id * num_Kt_blocks + other_slice].get(); for (int64_t m = m_offset; m < m_offset + m_size; m++) { #pragma omp simd for (int64_t n = 0; n < n_size; n++) { diff --git a/torch/_inductor/codegen/cpp_prefix.h b/torch/_inductor/codegen/cpp_prefix.h index b46e772cd6ce11..8972e5ef522d42 100644 --- a/torch/_inductor/codegen/cpp_prefix.h +++ b/torch/_inductor/codegen/cpp_prefix.h @@ -884,32 +884,6 @@ void mm_get_cache_blocking( } } -inline void mm_get_thread_blocks( - int thread_id, - int64_t M_blocks, - int64_t N_blocks, - int64_t K_blocks, - int64_t Mt_blocks, - int64_t Nt_blocks, - int64_t Kt_blocks, - int64_t& m_block_start, - int64_t& m_block_end, - int64_t& n_block_start, - int64_t& n_block_end, - int64_t& k_block_start, - int64_t& k_block_end) { - int64_t num_Kt = (K_blocks + Kt_blocks - 1) / Kt_blocks; - k_block_start = (thread_id % num_Kt) * Kt_blocks; - k_block_end = std::min(k_block_start + Kt_blocks, K_blocks); - thread_id /= num_Kt; - int64_t num_Nt = (N_blocks + Nt_blocks - 1) / Nt_blocks; - n_block_start = (thread_id % num_Nt) * Nt_blocks; - n_block_end = std::min(n_block_start + Nt_blocks, N_blocks); - thread_id /= num_Nt; - m_block_start = std::min(thread_id * Mt_blocks, M_blocks); - m_block_end = std::min(m_block_start + Mt_blocks, M_blocks); -} - struct amx_tilecfg { uint8_t palette_id; uint8_t start_row; From fcaed9a0209e94aaa3ca10ddaa871531334f76b1 Mon Sep 17 00:00:00 2001 From: PHLens Date: Mon, 9 Sep 2024 13:27:58 +0000 Subject: [PATCH 0626/1018] Add `__init__.py` to shape inference folder. (#135461) Fixes #135196 Pull Request resolved: https://github.com/pytorch/pytorch/pull/135461 Approved by: https://github.com/ezyang --- torch/fx/experimental/shape_inference/__init__.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) create mode 100644 torch/fx/experimental/shape_inference/__init__.py diff --git a/torch/fx/experimental/shape_inference/__init__.py b/torch/fx/experimental/shape_inference/__init__.py new file mode 100644 index 00000000000000..e69de29bb2d1d6 From 80bf89f0ffec1de03f84fee0584cd96645b0f7e9 Mon Sep 17 00:00:00 2001 From: Yuxin Wu Date: Mon, 9 Sep 2024 13:28:29 +0000 Subject: [PATCH 0627/1018] Fix wrong error msg (#135423) Pull Request resolved: https://github.com/pytorch/pytorch/pull/135423 Approved by: https://github.com/ezyang --- torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp b/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp index 10a7d20947433a..39717696b0746e 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp +++ b/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp @@ -178,7 +178,7 @@ struct DumpPipe { getCvarInt({"TORCH_NCCL_TRACE_BUFFER_SIZE"}, 0) <= 0) { return; } - TORCH_CHECK(!fileStem.empty(), "TORCH_NCCL_DEBUG_INFO_TEMP_FILE is empty"); + TORCH_CHECK(!fileStem.empty(), "TORCH_NCCL_DEBUG_INFO_PIPE_FILE is empty"); std::string filename = c10::str(fileStem, rank, ".pipe"); TORCH_CHECK( unlink(filename.c_str()) != -1 || errno == ENOENT, From 7bee08cc1c6ba44f5925cf9e89a492d87b480b07 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xavier=20Dupr=C3=A9?= Date: Mon, 9 Sep 2024 15:09:01 +0000 Subject: [PATCH 0628/1018] [ONNX] Improves documentation of ONNX exporter (#135372) The PR updates the documentation to reflect the changes introduced in pytorch 2.5 and related to onnx exporter. Pull Request resolved: https://github.com/pytorch/pytorch/pull/135372 Approved by: https://github.com/justinchuby Co-authored-by: Justin Chu --- docs/source/onnx.rst | 30 +++++++++++++++++++++++++++++- docs/source/onnx_dynamo.rst | 27 +++++++++++++++++++++++---- torch/onnx/__init__.py | 10 +++++++++- 3 files changed, 61 insertions(+), 6 deletions(-) diff --git a/docs/source/onnx.rst b/docs/source/onnx.rst index ffaa8ef836b787..cb795cfb11f91a 100644 --- a/docs/source/onnx.rst +++ b/docs/source/onnx.rst @@ -13,7 +13,35 @@ The exported model can be consumed by any of the many `runtimes that support ONNX `_, including Microsoft's `ONNX Runtime `_. -**There are two flavors of ONNX exporter API that you can use, as listed below:** +**There are two flavors of ONNX exporter API that you can use, as listed below.** +Both can be called through function :func:`torch.onnx.export`. +Next example shows how to export a simple model. + +.. code-block:: python + + import torch + + class MyModel(torch.nn.Module): + def __init__(self): + super(MyModel, self).__init__() + self.conv1 = torch.nn.Conv2d(1, 128, 5) + + def forward(self, x): + return torch.relu(self.conv1(x)) + + input_tensor = torch.rand((1, 1, 128, 128), dtype=torch.float32) + + model = MyModel() + + torch.onnx.export( + model, # model to export + (input_tensor,), # inputs of the model, + "my_model.onnx", # filename of the ONNX model + input_names=["input"], # Rename inputs for the ONNX model + dynamo=True # True or False to select the exporter to use + ) + +Next sections introduces the two versions of the exporter. TorchDynamo-based ONNX Exporter ------------------------------- diff --git a/docs/source/onnx_dynamo.rst b/docs/source/onnx_dynamo.rst index 89b995687dab8e..9865844f32e6b2 100644 --- a/docs/source/onnx_dynamo.rst +++ b/docs/source/onnx_dynamo.rst @@ -44,6 +44,9 @@ They can be installed through `pip `_: pip install --upgrade onnx onnxscript +`onnxruntime `_ can then be used to execute the model +on a large variety of processors. + A simple example ---------------- @@ -74,9 +77,9 @@ See below a demonstration of exporter API in action with a simple Multilayer Per model = MLPModel() tensor_x = torch.rand((97, 8), dtype=torch.float32) - onnx_program = torch.onnx.dynamo_export(model, tensor_x) + onnx_program = torch.onnx.export(model, (tensor_x,), dynamo=True) -As the code above shows, all you need is to provide :func:`torch.onnx.dynamo_export` with an instance of the model and its input. +As the code above shows, all you need is to provide :func:`torch.onnx.export` with an instance of the model and its input. The exporter will then return an instance of :class:`torch.onnx.ONNXProgram` that contains the exported ONNX graph along with extra information. The in-memory model available through ``onnx_program.model_proto`` is an ``onnx.ModelProto`` object in compliance with the `ONNX IR spec `_. @@ -86,6 +89,17 @@ The ONNX model may then be serialized into a `Protobuf file `__ to help users debug and improve their model using a GUI, such as diff --git a/torch/onnx/__init__.py b/torch/onnx/__init__.py index d4d5d2c71255d8..27c8e2c6240c18 100644 --- a/torch/onnx/__init__.py +++ b/torch/onnx/__init__.py @@ -284,7 +284,9 @@ def forward(self, x): This is required for models with large weights that exceed the ONNX file size limit (2GB). When False, the weights are saved in the ONNX file with the model architecture. dynamic_shapes: A dictionary of dynamic shapes for the model inputs. Refer to - :func:`torch.export.export` for more details. + :func:`torch.export.export` for more details. This is only used (and preferred) when dynamo is True. + Only one parameter `dynamic_axes` or `dynamic_shapes` should be set + at the same time. report: Whether to generate a markdown report for the export process. verify: Whether to verify the exported model using ONNX Runtime. profile: Whether to profile the export process. @@ -364,6 +366,12 @@ def forward(self, x): else: from torch.onnx.utils import export + if dynamic_shapes: + raise ValueError( + "The exporter only supports dynamic shapes " + "through parameter dynamic_axes when dynamo=False." + ) + export( model, args, From f33c7ffd0845141f2bb879a7cad7ad84814a8e6f Mon Sep 17 00:00:00 2001 From: Roy Hvaara Date: Mon, 9 Sep 2024 15:18:49 +0000 Subject: [PATCH 0629/1018] [MPS] Update decorator comments with issue ref (#135448) Updating the comments with references to better places for context now that the bugs have been identified. xref #135442 #135447 #134184 Pull Request resolved: https://github.com/pytorch/pytorch/pull/135448 Approved by: https://github.com/ezyang --- test/test_nn.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/test_nn.py b/test/test_nn.py index 813197b4499b67..07500b65255bf7 100644 --- a/test/test_nn.py +++ b/test/test_nn.py @@ -8583,7 +8583,7 @@ def test_ReplicationPad_empty(self, device, dtype): with self.assertRaisesRegex(RuntimeError, 'padding size is expected to be 6'): torch._C._nn.replication_pad3d(torch.randn([2]), padding=[]) - @expectedFailureMPS # TODO(hvaara): Investigate as possible bug. + @expectedFailureMPS # Correctness issue https://github.com/pytorch/pytorch/issues/135447 def test_ReplicationPad1d_large(self, device): shapes = ([2, 65736, 4], [65736, 2, 4]) pl, pr = 3, 4 @@ -8608,7 +8608,7 @@ def test_ReplicationPad1d_large(self, device): self.assertEqual(x.grad[:, :, 0], g[:, :, : pl + 1].sum(-1)) self.assertEqual(x.grad[:, :, -1], g[:, :, -pr - 1:].sum(-1)) - @expectedFailureMPS # TODO(hvaara): Investigate as possible bug. + @expectedFailureMPS # Correctness issue https://github.com/pytorch/pytorch/issues/135447 def test_ReplicationPad2d_large(self, device): shapes = ([2, 65736, 4, 4], [65736, 2, 4, 4]) pl, pr, pt, pb = 3, 4, 5, 6 From feae5d8392a85c5d9301fd8dadfeae466ff9bdd4 Mon Sep 17 00:00:00 2001 From: CaoE Date: Mon, 9 Sep 2024 15:31:36 +0000 Subject: [PATCH 0630/1018] Use float data type for Half var_sum in batchnorm stats updating on CPU (#126525) Using float data type for Half `var_sum` in batchnorm stats updating on CPU to avoid `var_sum` overflow since the representation range of Half is small. Pull Request resolved: https://github.com/pytorch/pytorch/pull/126525 Approved by: https://github.com/jgong5, https://github.com/peterbell10 --- aten/src/ATen/native/Normalization.cpp | 14 ++++++++++---- test/test_nn.py | 22 ++++++++++++++++++++++ 2 files changed, 32 insertions(+), 4 deletions(-) diff --git a/aten/src/ATen/native/Normalization.cpp b/aten/src/ATen/native/Normalization.cpp index bedf6eda03cedf..796eac362b1246 100644 --- a/aten/src/ATen/native/Normalization.cpp +++ b/aten/src/ATen/native/Normalization.cpp @@ -209,7 +209,13 @@ std::tuple batch_norm_cpu_update_stats_template( bool all_contiguous = is_contiguous(input); constexpr bool mixed_type = !std::is_same_v; - const auto dtype = mixed_type ? kFloat : input.scalar_type(); + // Using float data type for Half _var_sum in batchnorm stats updating on CPU + // to avoid _var_sum overflow since the representation range of Half is small. + using opmath_t = std::conditional_t, at::opmath_type, param_t>; + auto dtype = mixed_type ? kFloat : input.scalar_type(); + if (dtype == kHalf) { + dtype = kFloat; + } auto save_mean_a = save_mean.accessor(); auto save_var_transform_a = save_var_transform.accessor(); @@ -220,9 +226,9 @@ std::tuple batch_norm_cpu_update_stats_template( if (all_contiguous) { auto _mean = at::empty({n_input}, input.options().dtype(dtype)); auto _var_sum = at::empty({n_input}, input.options().dtype(dtype)); - auto _mean_a = _mean.accessor(); - auto _var_sum_a = _var_sum.accessor(); - auto momentum_ = static_cast(momentum); + auto _mean_a = _mean.accessor(); + auto _var_sum_a = _var_sum.accessor(); + auto momentum_ = static_cast(momentum); batch_norm_cpu_collect_stats_stub(kCPU, _mean, _var_sum, input); diff --git a/test/test_nn.py b/test/test_nn.py index 07500b65255bf7..df8120eb84e26e 100644 --- a/test/test_nn.py +++ b/test/test_nn.py @@ -4829,6 +4829,28 @@ def helper(self, mod, size, dtype, mixed_dtype=False, format=torch.channels_last mixed_dtype = False helper(self, nn.BatchNorm3d, shape, dtype, mixed_dtype, torch.channels_last_3d, precisons[dtype]) + def test_batchnorm_half_overflow(self): + def helper(self, mod, size, format): + channels = size[1] + input = torch.randn(size, dtype=torch.half, device='cpu', requires_grad=True) + input = input.contiguous(memory_format=format) + bn = mod(channels).cpu().to(torch.half) + out = bn(input) + + ref_bn = mod(channels).cpu().to(torch.float) + ref_bn.load_state_dict(bn.to(torch.float).state_dict()) + ref_out = ref_bn(input) + + self.assertFalse(out.isinf().any()) + self.assertFalse(out.isnan().any()) + self.assertEqual(out, ref_out) + + for format in [torch.contiguous_format, torch.channels_last]: + helper(self, nn.BatchNorm2d, (4, 80, 500, 500), format) + + for format in [torch.contiguous_format, torch.channels_last_3d]: + helper(self, nn.BatchNorm3d, (4, 80, 20, 100, 100), format) + @parametrize_test( 'bn_module', [ From a66ea7d48043dc89e9376ece7666f617c5789156 Mon Sep 17 00:00:00 2001 From: Victor Tao Date: Mon, 9 Sep 2024 16:24:56 +0000 Subject: [PATCH 0631/1018] [Inductor] Make static_input_idxs a set for faster lookup (#135314) `static_input_idxs` is only used for lookups. With large models, this is a large list. This takes over a millisecond in some cases. Profile before change: image Profile after change: gaps are smaller, 1ms speedup before launching the cuda graph image Pull Request resolved: https://github.com/pytorch/pytorch/pull/135314 Approved by: https://github.com/oulgen --- torch/_inductor/compile_fx.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/torch/_inductor/compile_fx.py b/torch/_inductor/compile_fx.py index db4c1e0eddfece..f68fb6dbef4a97 100644 --- a/torch/_inductor/compile_fx.py +++ b/torch/_inductor/compile_fx.py @@ -62,6 +62,7 @@ from torch.fx.experimental.symbolic_shapes import free_unbacked_symbols, SymExprPrinter from torch.fx.passes.fake_tensor_prop import FakeTensorProp from torch.monitor import _WaitCounter +from torch.utils._ordered_set import OrderedSet from .._dynamo.backends.common import aot_autograd from ..fx._lazy_graph_module import _use_lazy_graph_module # type: ignore[attr-defined] @@ -1037,7 +1038,7 @@ def cudagraphify_impl( Assumes inputs[static_input_idxs[i]] are always the same memory address """ check_input_idxs = get_input_idxs_to_check(inputs, static_input_idxs) # type: ignore[arg-type] - static_input_idxs = remove_unaligned_input_idxs(inputs, static_input_idxs) # type: ignore[arg-type] + static_input_idxs = OrderedSet(remove_unaligned_input_idxs(inputs, static_input_idxs)) # type: ignore[arg-type] copy_misaligned_inputs(inputs, check_input_idxs) # type: ignore[arg-type] assert isinstance(inputs, list) From 9b850f60d948af834db38de907c9006f90f57b54 Mon Sep 17 00:00:00 2001 From: atalman Date: Mon, 9 Sep 2024 16:49:04 +0000 Subject: [PATCH 0632/1018] [Release] Apply Release changes scripts after release 2.4 (#135495) Based on additional changes required for https://github.com/pytorch/pytorch/pull/128347 Pull Request resolved: https://github.com/pytorch/pytorch/pull/135495 Approved by: https://github.com/kit1980 --- scripts/release/apply-release-changes.sh | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/scripts/release/apply-release-changes.sh b/scripts/release/apply-release-changes.sh index c1f37a402928bd..49d3d0b5192f76 100755 --- a/scripts/release/apply-release-changes.sh +++ b/scripts/release/apply-release-changes.sh @@ -17,19 +17,25 @@ GIT_TOP_DIR=$(git rev-parse --show-toplevel) RELEASE_VERSION=${RELEASE_VERSION:-$(cut -d'.' -f1-2 "${GIT_TOP_DIR}/version.txt")} DRY_RUN=${DRY_RUN:-enabled} -# Change all GitHub Actions to reference the test-infra release branch -# as opposed to main. echo "Applying to workflows" for i in .github/workflows/*.yml; do sed -i -e s#@main#@"release/${RELEASE_VERSION}"# $i; done -# Change all checkout step in templates to not add ref to checkout echo "Applying to templates" for i in .github/templates/*.yml.j2; do sed -i 's#common.checkout(\(.*\))#common.checkout(\1, checkout_pr_head=False)#' $i; + sed -i -e s#main#"release/${RELEASE_VERSION}"# $i; done +echo "Applying to changes to linux binary builds" +for i in ".github/workflows/_binary-build-linux.yml" ".github/workflows/_binary-test-linux.yml"; do + sed -i "/github.event_name == 'pull_request'/d" $i; + sed -i -e s#main#"release/${RELEASE_VERSION}"# $i; +done + +sed -i -e "/generate_ci_workflows.py/i \\\t\t\t\texport RELEASE_VERSION_TAG=${RELEASE_VERSION}" .github/workflows/lint.yml + # Triton wheel echo "Triton Changes" sed -i -e s#-\ main#"-\ release\/${RELEASE_VERSION}"# .github/workflows/build-triton-wheel.yml From e81fe7ec1b07a5a3b906a7ebbd0ce93ba0cd8eab Mon Sep 17 00:00:00 2001 From: Bin Bao Date: Sat, 7 Sep 2024 06:25:04 -0700 Subject: [PATCH 0633/1018] [AOTI] Add C shim for aten.mkldnn_rnn_layer in cpp wrapper (#134857) Summary: Support aten.mkldnn_rnn_layer in the ABI-compatible mode. Because aten.mkldnn_rnn_layer is an aten op, it is easier to add a C shim function for it. Pull Request resolved: https://github.com/pytorch/pytorch/pull/134857 Approved by: https://github.com/angelayi --- build_variables.bzl | 1 + test/inductor/test_cpu_cpp_wrapper.py | 1 - torch/_inductor/codegen/cpp_wrapper_cpu.py | 6 +- torch/_inductor/codegen/wrapper.py | 3 + torch/_inductor/mkldnn_ir.py | 11 +++- torch/csrc/inductor/aoti_torch/c/shim.h | 4 +- .../csrc/inductor/aoti_torch/c/shim_mkldnn.h | 39 ++++++++++++ .../csrc/inductor/aoti_torch/shim_mkldnn.cpp | 62 +++++++++++++++++++ 8 files changed, 123 insertions(+), 4 deletions(-) create mode 100644 torch/csrc/inductor/aoti_torch/c/shim_mkldnn.h create mode 100644 torch/csrc/inductor/aoti_torch/shim_mkldnn.cpp diff --git a/build_variables.bzl b/build_variables.bzl index 306e71b61f1d0f..11f5a5c9108e35 100644 --- a/build_variables.bzl +++ b/build_variables.bzl @@ -469,6 +469,7 @@ inductor_core_resources = [ "torch/csrc/inductor/aoti_runner/model_container_runner.cpp", "torch/csrc/inductor/aoti_runner/model_container_runner_cpu.cpp", "torch/csrc/inductor/aoti_torch/shim_common.cpp", + "torch/csrc/inductor/aoti_torch/shim_mkldnn.cpp", "torch/csrc/inductor/aoti_torch/tensor_converter.cpp", "torch/csrc/inductor/aoti_torch/mkldnn_tensor.cpp", "torch/csrc/inductor/aoti_torch/oss_proxy_executor.cpp", diff --git a/test/inductor/test_cpu_cpp_wrapper.py b/test/inductor/test_cpu_cpp_wrapper.py index 4d18f64ffd6255..d02ca75cfa5c44 100644 --- a/test/inductor/test_cpu_cpp_wrapper.py +++ b/test/inductor/test_cpu_cpp_wrapper.py @@ -88,7 +88,6 @@ class DynamicShapesCppWrapperCpuTests(InductorTestCase): ) if config.abi_compatible: xfail_list = [ - "test_lstm_packed_change_input_sizes_cpu", *[ func for func in dir(test_cpu_select_algorithm.TestSelectAlgorithmCPU()) diff --git a/torch/_inductor/codegen/cpp_wrapper_cpu.py b/torch/_inductor/codegen/cpp_wrapper_cpu.py index 290787b99b9930..07d84ac249d5eb 100644 --- a/torch/_inductor/codegen/cpp_wrapper_cpu.py +++ b/torch/_inductor/codegen/cpp_wrapper_cpu.py @@ -1269,7 +1269,11 @@ def generate_c_shim_extern_kernel_alloc(self, extern_kernel, args): def generate_extern_kernel_alloc(self, extern_kernel, args): if config.abi_compatible: - self.generate_c_shim_extern_kernel_alloc(extern_kernel, args) + if hasattr(extern_kernel, "outputs"): + # ir.ExternKernelAlloc may have outputs if it returns a tuple + self.generate_c_shim_fallback_kernel(extern_kernel, args) + else: + self.generate_c_shim_extern_kernel_alloc(extern_kernel, args) else: super().generate_extern_kernel_alloc(extern_kernel, args) diff --git a/torch/_inductor/codegen/wrapper.py b/torch/_inductor/codegen/wrapper.py index ddc422f4c294cb..500532d432b39f 100644 --- a/torch/_inductor/codegen/wrapper.py +++ b/torch/_inductor/codegen/wrapper.py @@ -581,6 +581,9 @@ def write_header(self) -> None: strip=True, ) + def include_extra_header(self, header: str): + pass + def write_kernel_autotune_defs_header(self) -> None: self.kernel_autotune_defs.splice( f""" diff --git a/torch/_inductor/mkldnn_ir.py b/torch/_inductor/mkldnn_ir.py index 12634c632b80a9..df9e475be2796d 100644 --- a/torch/_inductor/mkldnn_ir.py +++ b/torch/_inductor/mkldnn_ir.py @@ -1787,6 +1787,7 @@ def __init__( None, op_overload=torch.ops.aten.mkldnn_rnn_layer.default, ) + self.outputs: List[MultiOutput] = [] @classmethod def create( @@ -1856,11 +1857,14 @@ def get_strides_of_lstm_output(output_shape, batch_first): assert len(output_shape) == 3, "Expect output_shape to be 3D" return FlexibleLayout.contiguous_strides(output_shape) - output_sizes = [output_shape, hy_shape, cy_shape] + # C shim call requires all the outputs to be passed in, and thus the last + # dummy return value is added. + output_sizes = [output_shape, hy_shape, cy_shape, [1]] output_strides = [ get_strides_of_lstm_output(output_shape, batch_first), FlexibleLayout.contiguous_strides(hy_shape), FlexibleLayout.contiguous_strides(cy_shape), + [1], ] output_ir = [ MultiOutput( @@ -1877,5 +1881,10 @@ def get_strides_of_lstm_output(output_shape, batch_first): zip(output_sizes, output_strides) ) ] + packed.outputs = output_ir return output_ir + + def codegen(self, wrapper): + wrapper.include_extra_header("torch/csrc/inductor/aoti_torch/c/shim_mkldnn.h") + return super().codegen(wrapper) diff --git a/torch/csrc/inductor/aoti_torch/c/shim.h b/torch/csrc/inductor/aoti_torch/c/shim.h index 49b16504e6ce2e..b470b5f10061b8 100644 --- a/torch/csrc/inductor/aoti_torch/c/shim.h +++ b/torch/csrc/inductor/aoti_torch/c/shim.h @@ -51,6 +51,8 @@ #endif // _WIN32 #endif // __GNUC__ +// The following files are implemented in a header-only way and are guarded by +// test/cpp/aoti_abi_check #include #include #include @@ -612,7 +614,7 @@ aoti_torch_delete_cuda_stream_guard(CUDAStreamGuardHandle guard); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_get_current_cuda_stream(int32_t device_index, void** ret_stream); -#endif +#endif // USE_CUDA // See `ProxyExecutor Design Note` in ir.py for more details AOTI_TORCH_EXPORT AOTITorchError aoti_torch_proxy_executor_call_function( diff --git a/torch/csrc/inductor/aoti_torch/c/shim_mkldnn.h b/torch/csrc/inductor/aoti_torch/c/shim_mkldnn.h new file mode 100644 index 00000000000000..63b35f62b3dc34 --- /dev/null +++ b/torch/csrc/inductor/aoti_torch/c/shim_mkldnn.h @@ -0,0 +1,39 @@ +#ifndef AOTI_TORCH_SHIM_MKLDNN +#define AOTI_TORCH_SHIM_MKLDNN + +#include +#include + +#if AT_MKLDNN_ENABLED() +#ifdef __cplusplus +extern "C" { +#endif + +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_mkldnn_rnn_layer( + AtenTensorHandle input, + AtenTensorHandle weight0, + AtenTensorHandle weight1, + AtenTensorHandle weight2, + AtenTensorHandle weight3, + AtenTensorHandle hx_, + AtenTensorHandle cx_, + int32_t reverse, + const int64_t* batch_sizes, + int64_t batch_sizes_len_, + int64_t mode, + int64_t hidden_size, + int64_t num_layers, + int32_t has_biases, + int32_t bidirectional, + int32_t batch_first, + int32_t train, + AtenTensorHandle* ret0, + AtenTensorHandle* ret1, + AtenTensorHandle* ret2, + AtenTensorHandle* ret3); + +#ifdef __cplusplus +} // extern "C" +#endif +#endif // AT_MKLDNN_ENABLED() +#endif // AOTI_TORCH_SHIM_MKLDNN diff --git a/torch/csrc/inductor/aoti_torch/shim_mkldnn.cpp b/torch/csrc/inductor/aoti_torch/shim_mkldnn.cpp new file mode 100644 index 00000000000000..14f9fbf69459e6 --- /dev/null +++ b/torch/csrc/inductor/aoti_torch/shim_mkldnn.cpp @@ -0,0 +1,62 @@ + +#include +#include + +#ifndef AT_PER_OPERATOR_HEADERS +#include +#else +#include +#endif + +using namespace torch::aot_inductor; + +#if AT_MKLDNN_ENABLED() + +AOTITorchError aoti_torch_cpu_mkldnn_rnn_layer( + AtenTensorHandle input, + AtenTensorHandle weight0, + AtenTensorHandle weight1, + AtenTensorHandle weight2, + AtenTensorHandle weight3, + AtenTensorHandle hx_, + AtenTensorHandle cx_, + int32_t reverse, + const int64_t* batch_sizes, + int64_t batch_sizes_len_, + int64_t mode, + int64_t hidden_size, + int64_t num_layers, + int32_t has_biases, + int32_t bidirectional, + int32_t batch_first, + int32_t train, + AtenTensorHandle* ret0, + AtenTensorHandle* ret1, + AtenTensorHandle* ret2, + AtenTensorHandle* ret3) { + AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({ + auto tmp_result = at::cpu::mkldnn_rnn_layer( + *tensor_handle_to_tensor_pointer(input), + *tensor_handle_to_tensor_pointer(weight0), + *tensor_handle_to_tensor_pointer(weight1), + *tensor_handle_to_tensor_pointer(weight2), + *tensor_handle_to_tensor_pointer(weight3), + *tensor_handle_to_tensor_pointer(hx_), + *tensor_handle_to_tensor_pointer(cx_), + reverse, + pointer_to_list(batch_sizes, batch_sizes_len_), + mode, + hidden_size, + num_layers, + has_biases, + bidirectional, + batch_first, + train); + *ret0 = new_tensor_handle(std::move(std::get<0>(tmp_result))); + *ret1 = new_tensor_handle(std::move(std::get<1>(tmp_result))); + *ret2 = new_tensor_handle(std::move(std::get<2>(tmp_result))); + *ret3 = new_tensor_handle(std::move(std::get<3>(tmp_result))); + }); +} + +#endif // AT_MKLDNN_ENABLED() From 399931215749a8a0f77ea90db84c8f702b0dbcfd Mon Sep 17 00:00:00 2001 From: Bin Bao Date: Sat, 7 Sep 2024 06:25:05 -0700 Subject: [PATCH 0634/1018] [AOTI] Fix assert_function call in cpu autotune template (#135086) Summary: In the ABI-compatible mode, assert_function should be AOTI_TORCH_CHECK. Pull Request resolved: https://github.com/pytorch/pytorch/pull/135086 Approved by: https://github.com/chenyang78, https://github.com/angelayi ghstack dependencies: #134857 --- test/inductor/test_aot_inductor.py | 4 +++- test/inductor/test_cpu_cpp_wrapper.py | 13 ++++++------- torch/_inductor/codegen/cpp.py | 4 +--- torch/_inductor/codegen/cpp_micro_gemm.py | 10 +++++----- torch/_inductor/codegen/cpp_template.py | 9 ++++----- torch/_inductor/graph.py | 3 +-- 6 files changed, 20 insertions(+), 23 deletions(-) diff --git a/test/inductor/test_aot_inductor.py b/test/inductor/test_aot_inductor.py index 60cffb05d63e9f..cfac539c63b879 100644 --- a/test/inductor/test_aot_inductor.py +++ b/test/inductor/test_aot_inductor.py @@ -3260,7 +3260,9 @@ def forward(self, values, offsets): model, example_inputs_list, dynamic_shapes=dynamic_shapes ) - @common_utils.parametrize("max_autotune", [False, True]) + # max_autotune is disabled due to https://github.com/pytorch/pytorch/issues/135106 + # @common_utils.parametrize("max_autotune", [False, True]) + @common_utils.parametrize("max_autotune", [False]) def test_misc_1(self, max_autotune): if self.device == "cpu" and IS_MACOS and max_autotune: raise unittest.SkipTest("max_autotune not supported on macos") diff --git a/test/inductor/test_cpu_cpp_wrapper.py b/test/inductor/test_cpu_cpp_wrapper.py index d02ca75cfa5c44..8ea3e55468019f 100644 --- a/test/inductor/test_cpu_cpp_wrapper.py +++ b/test/inductor/test_cpu_cpp_wrapper.py @@ -87,13 +87,7 @@ class DynamicShapesCppWrapperCpuTests(InductorTestCase): } ) if config.abi_compatible: - xfail_list = [ - *[ - func - for func in dir(test_cpu_select_algorithm.TestSelectAlgorithmCPU()) - if func.startswith("test_linear_with_pointwise") - ], - ] + xfail_list = [] for test_name in xfail_list: test_failures_cpp_wrapper[test_name] = test_torchinductor.TestFailure( ("cpp_wrapper",), is_skip=False @@ -103,6 +97,11 @@ class DynamicShapesCppWrapperCpuTests(InductorTestCase): ] = test_torchinductor.TestFailure(("cpp_wrapper",), is_skip=False) skip_list = [ "test_multihead_attention_cpu", + *[ + func + for func in dir(test_cpu_select_algorithm.TestSelectAlgorithmCPU()) + if func.startswith("test_linear_with_pointwise") + ], ] for test_name in skip_list: test_failures_cpp_wrapper[test_name] = test_torchinductor.TestFailure( diff --git a/torch/_inductor/codegen/cpp.py b/torch/_inductor/codegen/cpp.py index e1d54762d9e8fd..940bc5f73cd838 100644 --- a/torch/_inductor/codegen/cpp.py +++ b/torch/_inductor/codegen/cpp.py @@ -2146,9 +2146,7 @@ def codegen_loops(self, code, worksharing): @property def assert_function(self) -> str: - if V.graph.aot_mode: - # TODO: Using AOTI_TORCH_CHECK is causing performance drop for some models - # compared with JIT Inductor which uses TORCH_CHECK + if config.abi_compatible: return "AOTI_TORCH_CHECK" else: return "TORCH_CHECK" diff --git a/torch/_inductor/codegen/cpp_micro_gemm.py b/torch/_inductor/codegen/cpp_micro_gemm.py index cb26c48fce53c6..2f0c3e78f56791 100644 --- a/torch/_inductor/codegen/cpp_micro_gemm.py +++ b/torch/_inductor/codegen/cpp_micro_gemm.py @@ -332,8 +332,8 @@ class CppMicroGemmFP32Vec(CppMicroGemm): TEMPLATE_ENTRY = r""" {{declare_kernel}} { - TORCH_CHECK(N % {{block_n}} == 0, "N dimension must be multiple of {{block_n}}"); - TORCH_CHECK(K % {{block_k}} == 0, "K dimension must be multiple of {{block_k}}"); + {{kernel.assert_function}}(N % {{block_n}} == 0, "N dimension must be multiple of {{block_n}}"); + {{kernel.assert_function}}(K % {{block_k}} == 0, "K dimension must be multiple of {{block_k}}"); // TODO(jgong5): loop unroll for M and N for (int64_t m = 0; m < M; m += {{block_m}}) { int64_t block_m = std::min(M - m, {{block_m}}); @@ -364,7 +364,7 @@ class CppMicroGemmFP32Vec(CppMicroGemm): break; {%- endfor %} default: - {{kernel.assert_function}}(false, "Unsupported block_m: ", block_m); + {{kernel.assert_function}}(false, "Unsupported block_m: {{block_m}}"); } } } @@ -509,8 +509,8 @@ class CppMicroGemmAMX(CppMicroGemm): TEMPLATE_ENTRY = r""" {{declare_kernel}} { - TORCH_CHECK(N % {{block_n}} == 0, "N dimension must be multiple of {{block_n}}"); - TORCH_CHECK(K % 2 == 0, "K dimension must be multiple of 2"); + {{kernel.assert_function}}(N % {{block_n}} == 0, "N dimension must be multiple of {{block_n}}"); + {{kernel.assert_function}}(K % 2 == 0, "K dimension must be multiple of 2"); // TODO(jgong5): loop unroll for M and N for (int64_t m = 0; m < M; m += {{block_m}}) { int64_t block_m = std::min(M - m, {{block_m}}); diff --git a/torch/_inductor/codegen/cpp_template.py b/torch/_inductor/codegen/cpp_template.py index 2bce16b2ad1538..a237924b9182d4 100644 --- a/torch/_inductor/codegen/cpp_template.py +++ b/torch/_inductor/codegen/cpp_template.py @@ -111,11 +111,10 @@ def make_kernel_render( def header(self) -> IndentedBuffer: res = IndentedBuffer() res.writeline(codecache.cpp_prefix()) - res.splice( - """ - #include "c10/util/Unroll.h" - """ - ) + # TODO: add c10::ForcedUnroll test to test_aoti_abi_check + res.splice("""#include """) + if config.abi_compatible: + res.splice("""#include """) enable_kernel_profile = config.cpp.enable_kernel_profile and sys.platform in [ "linux", "win32", diff --git a/torch/_inductor/graph.py b/torch/_inductor/graph.py index 3448ea6eb9620a..6a4f8ce741d6f4 100644 --- a/torch/_inductor/graph.py +++ b/torch/_inductor/graph.py @@ -854,7 +854,6 @@ def get_original_value_of_constant(self, name: str) -> torch.Tensor: def allocate_non_dup_const_name( self, name: Optional[str], data: Union[Tensor] ) -> str: - orig_name = name if not config.aot_inductor.use_runtime_constant_folding: for constant_name, value in self.constants.items(): if ( @@ -871,7 +870,7 @@ def allocate_non_dup_const_name( if name is None: name = f"constant{len(self.constants)}" - assert name is not None + orig_name = name if name[0].isdigit(): name = f"constant_{name}" name = self.qualify_name(name) From f3fec525efca4f727087f2b8c488b235ffe15d5f Mon Sep 17 00:00:00 2001 From: Roy Hvaara Date: Mon, 9 Sep 2024 17:12:36 +0000 Subject: [PATCH 0635/1018] [MPS] Add regression test for `fft.fftfreq` (#135440) The issue reported in #135223 was already solved in #128393. This PR adds a regression test for it. Fixes #135223 Pull Request resolved: https://github.com/pytorch/pytorch/pull/135440 Approved by: https://github.com/ezyang --- test/test_mps.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/test/test_mps.py b/test/test_mps.py index 55d05e502d3d24..eb99a6afdc6490 100644 --- a/test/test_mps.py +++ b/test/test_mps.py @@ -2732,6 +2732,12 @@ def test_ifft(self): # Expecting the inverted to yield the original signal self.assertEqual(ifft_result, signal) + # Regression test for https://github.com/pytorch/pytorch/issues/135223 + def test_fftfreq(self): + freq_cpu = torch.fft.fftfreq(10**4, device='cpu') + freq_mps = torch.fft.fftfreq(10**4, device='mps') + self.assertEqual(freq_cpu, freq_mps) + def test_instance_norm(self): def helper(shape, eps=1, momentum=0.1, wts=False, channels_last=False, track_running_stats=True, test_module=False): From db29d69b51ce73c5d2f51b2c4be601ab0d733bd4 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Mon, 9 Sep 2024 17:33:01 +0000 Subject: [PATCH 0636/1018] Revert "[Inductor] Make static_input_idxs a set for faster lookup (#135314)" This reverts commit 011cae9570fb3c44b7f6f0c8004c470579ed21da. Reverted https://github.com/pytorch/pytorch/pull/135314 on behalf of https://github.com/ZainRizvi due to Lint is failing on this file in trunk. See [GH job link](https://github.com/pytorch/pytorch/actions/runs/10777258770/job/29885960050) [HUD commit link](https://hud.pytorch.org/pytorch/pytorch/commit/011cae9570fb3c44b7f6f0c8004c470579ed21da) ([comment](https://github.com/pytorch/pytorch/pull/135314#issuecomment-2338678219)) --- torch/_inductor/compile_fx.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/torch/_inductor/compile_fx.py b/torch/_inductor/compile_fx.py index f68fb6dbef4a97..db4c1e0eddfece 100644 --- a/torch/_inductor/compile_fx.py +++ b/torch/_inductor/compile_fx.py @@ -62,7 +62,6 @@ from torch.fx.experimental.symbolic_shapes import free_unbacked_symbols, SymExprPrinter from torch.fx.passes.fake_tensor_prop import FakeTensorProp from torch.monitor import _WaitCounter -from torch.utils._ordered_set import OrderedSet from .._dynamo.backends.common import aot_autograd from ..fx._lazy_graph_module import _use_lazy_graph_module # type: ignore[attr-defined] @@ -1038,7 +1037,7 @@ def cudagraphify_impl( Assumes inputs[static_input_idxs[i]] are always the same memory address """ check_input_idxs = get_input_idxs_to_check(inputs, static_input_idxs) # type: ignore[arg-type] - static_input_idxs = OrderedSet(remove_unaligned_input_idxs(inputs, static_input_idxs)) # type: ignore[arg-type] + static_input_idxs = remove_unaligned_input_idxs(inputs, static_input_idxs) # type: ignore[arg-type] copy_misaligned_inputs(inputs, check_input_idxs) # type: ignore[arg-type] assert isinstance(inputs, list) From a3187c7d70c21dfaf3a5222968b3b71db4efb2cb Mon Sep 17 00:00:00 2001 From: Chien-Chin Huang Date: Fri, 6 Sep 2024 16:44:30 -0700 Subject: [PATCH 0637/1018] [CP] Extend CP to support load-balancing shards (#132442) This PR extends the current ring attention to support load-balancing shards -- the context/sequence is divided into `2 * world_size` shards and each rank gets `rank` and `(world_size * 2 - rank - 1)` shards. The data re-shuffling is done in the `context_parallel` API. Pull Request resolved: https://github.com/pytorch/pytorch/pull/132442 Approved by: https://github.com/wconstab --- test/distributed/_tensor/test_attention.py | 49 +- .../tensor/experimental/_attention.py | 422 ++++++++++++++++-- 2 files changed, 417 insertions(+), 54 deletions(-) diff --git a/test/distributed/_tensor/test_attention.py b/test/distributed/_tensor/test_attention.py index f5ae0d6a1179b6..238e551ab6a4aa 100644 --- a/test/distributed/_tensor/test_attention.py +++ b/test/distributed/_tensor/test_attention.py @@ -10,9 +10,10 @@ from torch.distributed._tensor.experimental._attention import ( _AttentionContextParallel, _CausalBehavior, - _context_parallel_buffers, + _cp_options, _is_causal_behavior, context_parallel, + context_parallel_unshard, ) from torch.distributed.tensor.debug import CommDebugMode from torch.distributed.tensor.parallel import parallelize_module @@ -24,6 +25,7 @@ ) from torch.testing._internal.common_distributed import skip_if_lt_x_gpu from torch.testing._internal.common_utils import ( + decorateIf, instantiate_parametrized_tests, parametrize, run_tests, @@ -57,11 +59,19 @@ def world_size(self) -> int: "Does not support flash nor efficient attention", ) @with_comms + @decorateIf( + unittest.skip, lambda params: params["load_balance"] and not params["is_causal"] + ) @parametrize("is_causal", [True, False]) @parametrize("compiled", [True, False]) @parametrize("backend", backends) + @parametrize("load_balance", [True, False]) def test_ring_attention_sdpa( - self, is_causal: bool, compiled: bool, backend: SDPBackend + self, + is_causal: bool, + compiled: bool, + backend: SDPBackend, + load_balance: bool, ) -> None: device_mesh = DeviceMesh(self.device_type, torch.arange(0, self.world_size)) dtype = torch.bfloat16 @@ -79,6 +89,8 @@ def test_ring_attention_sdpa( # TODO: Fix this after we move `wait_tensor` to use `with_effect`. return + _cp_options.enable_load_balance = load_balance + q = torch.rand( (bs, nheads, self.world_size * query_tokens, dim), device=self.device_type, @@ -108,12 +120,6 @@ def test_ring_attention_sdpa( out = F.scaled_dot_product_attention(q, k, v, is_causal=is_causal) out.sum().backward() - local_out, local_dq, local_dk, local_dv = _context_parallel_buffers( - device_mesh, - buffers=(out, q.grad, k.grad, v.grad), - buffer_seq_dims=(2, 2, 2, 2), - ) - cp_q = q.clone().detach() cp_k = k.clone().detach() cp_v = v.clone().detach() @@ -154,21 +160,26 @@ def test_ring_attention_sdpa( # Due to numerical error, we need to choose different atol for different # attention kernels + cp_out, cp_dq, cp_dk, cp_dv = context_parallel_unshard( + device_mesh, + [cp_out, cp_q.grad, cp_k.grad, cp_v.grad], + [2, 2, 2, 2], + ) atol = ( 1e-08 if backend == SDPBackend.EFFICIENT_ATTENTION else 1e-3 * self.world_size ) - self.assertTrue(torch.allclose(local_out, cp_out, atol=atol)) + self.assertTrue(torch.allclose(out, cp_out, atol=atol)) atol = ( 2e-06 if backend == SDPBackend.EFFICIENT_ATTENTION else 8e-3 * self.world_size ) - self.assertTrue(torch.allclose(local_dq, cp_q.grad, atol=atol)) - self.assertTrue(torch.allclose(local_dk, cp_k.grad, atol=atol)) - self.assertTrue(torch.allclose(local_dv, cp_v.grad, atol=atol)) + self.assertTrue(torch.allclose(q.grad, cp_dq, atol=atol)) + self.assertTrue(torch.allclose(k.grad, cp_dk, atol=atol)) + self.assertTrue(torch.allclose(v.grad, cp_dv, atol=atol)) cp_q.grad = None cp_k.grad = None @@ -178,6 +189,7 @@ def test_ring_attention_sdpa( cp_v.requires_grad = False def test_is_causal_behavior(self) -> None: + _cp_options.enable_load_balance = False self.assertEqual( _is_causal_behavior(rank=0, world_size=4, i=0, is_causal=False), _CausalBehavior.NOT_IS_CAUSAL, @@ -194,6 +206,18 @@ def test_is_causal_behavior(self) -> None: behavior, ) + _cp_options.enable_load_balance = True + ranks = [ + [_CausalBehavior.IS_CAUSAL, _CausalBehavior.NOT_IS_CAUSAL], + [_CausalBehavior.IS_CAUSAL, _CausalBehavior.NOT_IS_CAUSAL], + ] + for rank, iters in enumerate(ranks): + for i, behavior in enumerate(iters): + self.assertEqual( + _is_causal_behavior(rank=rank, world_size=2, i=i, is_causal=True), + behavior, + ) + @skip_if_lt_x_gpu(2) @unittest.skipIf( not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Does not support flash attention" @@ -202,6 +226,7 @@ def test_is_causal_behavior(self) -> None: @sdpa_kernel(backends=[SDPBackend.FLASH_ATTENTION]) @parametrize("is_causal", [True, False]) def test_ring_attention_native_transformer(self, is_causal: bool) -> None: + _cp_options.enable_load_balance = is_causal device_mesh = DeviceMesh( self.device_type, torch.arange(0, self.world_size), diff --git a/torch/distributed/tensor/experimental/_attention.py b/torch/distributed/tensor/experimental/_attention.py index b5c05232a57349..a00c92d9bba16c 100644 --- a/torch/distributed/tensor/experimental/_attention.py +++ b/torch/distributed/tensor/experimental/_attention.py @@ -5,6 +5,8 @@ import logging import types import weakref +from abc import ABC, abstractmethod +from dataclasses import dataclass from enum import Enum from typing import ( Any, @@ -34,10 +36,18 @@ aten = torch.ops.aten logger = logging.getLogger(__name__) -# Whether to upcast parameters and gradients to float32 to avoid accumulation -# errors. It is likely this is always True but we currently keep this variable -# for the experimental purpose. -_convert_to_f32 = True + + +@dataclass +class _ContextParallelOptions: + # Whether to upcast parameters and gradients to float32 to avoid accumulation + # errors. It is likely this is always True but we currently keep this variable + # for the experimental purpose. + convert_to_f32: bool = True + enable_load_balance = True + + +_cp_options = _ContextParallelOptions() class _CausalBehavior(Enum): @@ -60,7 +70,7 @@ def _is_causal_behavior( return _CausalBehavior.IS_CAUSAL source_rank = (rank - i) % world_size - if source_rank < rank: + if source_rank < rank or _cp_options.enable_load_balance: return _CausalBehavior.NOT_IS_CAUSAL else: return _CausalBehavior.SKIP @@ -76,31 +86,91 @@ def _maybe_wait(tensor: torch.Tensor) -> torch.Tensor: return tensor +def _partial_update( + original: torch.Tensor, + new: torch.Tensor, + dim: int, + n_chunks: int, + idx: int, + add: bool, +) -> torch.Tensor: + """ + This API partially update a chunk of ``original`` tensor. The ``original`` + tensor will be first chunked along ``dim`` dimension then the ``idx`` chunk + will be updated with ``new``. If ``add`` is True, the chunk will be added + with ``new``, otherwise the chunk with be replaced by ``add``. + + The result is a tensor that is the same size as ``original``. + """ + chunks = list(original.chunk(n_chunks, dim=dim)) + assert chunks[idx].shape == new.shape, (original.shape, new.shape, idx) + if add: + chunks[idx] += new + else: + chunks[idx] = new + return torch.cat(chunks, dim=dim) + + class _SDPAMerger: """A class to help to merge the local SDPA result.""" - def __init__(self, convert_to_f32: bool): + def __init__(self, convert_to_f32: bool, seq_dim: int): + self._seq_dim = seq_dim self._out: Optional[torch.Tensor] = None self._lse: Optional[torch.Tensor] = None self._convert_to_f32 = convert_to_f32 self._out_dtype = torch.float32 self._lse_dtype = torch.float32 - def _merge_one(self, block_out: torch.Tensor, block_lse: torch.Tensor) -> None: + def _merge_one( + self, block_out: torch.Tensor, block_lse: torch.Tensor, partial: bool + ) -> None: block_lse = block_lse.unsqueeze(dim=-1) if self._lse is None: self._lse = block_lse self._out = block_out else: + ROUND_ROBIN_CYCLE = 2 + assert self._lse is not None + assert self._out is not None + lse = ( + self._lse.chunk(ROUND_ROBIN_CYCLE, dim=self._seq_dim)[1] + if partial + else self._lse + ) + out = ( + self._out.chunk(ROUND_ROBIN_CYCLE, dim=self._seq_dim)[1] + if partial + else self._out + ) + # The algorithm from # github.com/zhuzilin/ring-flash-attention/pull/34#issuecomment-2076126795 # gives a relatively stable result. - self._out = self._out - F.sigmoid(block_lse - self._lse) * ( - self._out - block_out - ) - self._lse = self._lse - F.logsigmoid(self._lse - block_lse) + out = out - F.sigmoid(block_lse - lse) * (out - block_out) + lse = lse - F.logsigmoid(lse - block_lse) + if partial: + self._lse = _partial_update( + self._lse, + lse, + dim=self._seq_dim, + n_chunks=ROUND_ROBIN_CYCLE, + idx=1, + add=False, + ) + self._out = _partial_update( + self._out, + out, + dim=self._seq_dim, + n_chunks=ROUND_ROBIN_CYCLE, + idx=1, + add=False, + ) + else: + self._lse = lse + self._out = out - def step(self, out: torch.Tensor, lse: torch.Tensor) -> None: + def step(self, out: torch.Tensor, lse: torch.Tensor, partial: bool) -> None: self._out_dtype = out.dtype self._lse_dtype = lse.dtype @@ -108,7 +178,7 @@ def step(self, out: torch.Tensor, lse: torch.Tensor) -> None: out = out.to(torch.float32) lse = lse.to(torch.float32) - self._merge_one(out, lse) + self._merge_one(out, lse, partial) def results(self) -> Tuple[torch.Tensor, torch.Tensor]: assert self._out is not None @@ -131,8 +201,10 @@ def _scaled_dot_product_ring_flash_attention( if return_debug_mask: raise NotImplementedError("return_debug_mask is not supported yet") + seq_dim = 2 return _templated_ring_attention( mesh, + seq_dim, aten._scaled_dot_product_flash_attention, query=query, key=key, @@ -160,8 +232,10 @@ def _scaled_dot_product_ring_efficient_attention( if not compute_log_sumexp: raise NotImplementedError("compute_log_sumexp must be set") + seq_dim = 2 return _templated_ring_attention( mesh, + seq_dim, aten._scaled_dot_product_efficient_attention, query=query, key=key, @@ -188,6 +262,7 @@ def __call__( def _ring_rotate( block: torch.Tensor, pg: dist.ProcessGroup, send_to_next: bool ) -> torch.Tensor: + block = block.contiguous() size = dist.get_world_size(pg) dsts = ( list(range(1, size)) + [0] @@ -199,6 +274,7 @@ def _ring_rotate( def _templated_ring_attention( mesh: DeviceMesh, + seq_dim: int, op: _AttentionOp, query: torch.Tensor, key: torch.Tensor, @@ -209,6 +285,61 @@ def _templated_ring_attention( """ This is a generalized ring attention implementation that can support multiple attention ops. + Note [Context parallelism load balance algorithm for causal masking] + ===================== + This explanation uses an example to illustrate the CP algorithm with causal + masking. + + Consider a scenario where the sequence length of q, k, and v is 4 (e.g., + q = (q0, q1, q2, q3)), and there are two ranks. For simplicity, we will discuss + only q and k, as v follows the same pattern as k. + + The diagram below represents a complete QK^T operation without parallelism. + The `****` entries indicate that the result is not required due to causal + masking (e.g., q0k1 is marked as `****`). + + +----+------------------------+ + | | k0 k1 k2 k3 | + +----+------------------------+ + | q0 | q0k0, ****, ****, **** | + | q1 | q1k0, q1k1, ****, **** | + | q2 | q2k0, q2k1, q2k2, **** | + | q3 | q3k0, q3k1, q3k2, q3k3 | + +----+------------------------+ + + ### No Load Balance: + + In this scenario, each rank owns a local chunk of q, k, and v, with each chunk + containing two elements. Rank0 is responsible for managing (q0, q1) and (k0, k1), + while rank1 manages (q2, q3) and (k2, k3). + + First Iteration: Both rank0 and rank1 perform SDPA with their local qkv pairs. + Causal masking is enabled as some results are not required (e.g., q0k1). + + Second Iteration: Local queries remain the same, but local kv pairs are exchanged. + Rank0 now has (q0, q1) and (k2, k3); rank1 has (q2, q3) and (k0, k1). Rank0 performs + no computation, while rank1 computes locally without causal masking since all results + (q2k0, q2k1, q3k0, q3k1) are needed. + + ### Round-robin Load Balance: + + In this setup, each rank owns two local chunks of q, k, and v, with each chunk + containing one element. Rank0 manages (q0, q3) and (k0, k3); Rank1 manages (q1, q2) + and (k1, k2). Although the local chunks are not consecutive, they are concatenated to + enable SDPA to be performed in a single call for each step. Consequently, the chunk() + function may be required to prepare the correct q, k, and v configurations. + + First Iteration: Both ranks perform SDPA with their local qkv pairs, similar to the + no-load-balance case. This iteration corresponds to the `if` of the + (`if, `elif`, `else`) in the implemementation. + + Second Iteration: Rank0 now has (q0, q3) and (k1, k2); rank1 has (q1, q2) and + (k0, k3). For rank0, no computation is needed for q0. However, computations for + q3k1 and q3k2 are required, so only q3 is used for SDPA. This corresponds to the + `else` of the (`if`, `elif`, `else`) in the implemementation. + For rank1, k0 is not needed for q1 and q2, so only k3 is used for SDPA. This + corresponds to the `elif` of (`if`, `elif`, `else`) in the implementation. + Parameters ---------- op: @@ -246,20 +377,21 @@ def _templated_ring_attention( key = key.contiguous() value = value.contiguous() - sdpa_merger = _SDPAMerger(_convert_to_f32) + sdpa_merger = _SDPAMerger(_cp_options.convert_to_f32, seq_dim=seq_dim) rest: List[Any] out: torch.Tensor logsumexp: torch.Tensor for i in range(size): - # overlap communication with compute if next_kv is not None: + # Wait for the kv from the (cp_rank - 1) rank. next_kv = _maybe_wait(next_kv) key = next_kv[: key.numel()].reshape(key.shape) value = next_kv[key.numel() :].reshape(value.shape) if i < (size - 1): + # Send the k, v to the next rank next_kv = torch.cat([key.flatten(), value.flatten()]) next_kv = _ring_rotate(next_kv, pg, send_to_next=True) @@ -267,16 +399,44 @@ def _templated_ring_attention( rank=rank, world_size=size, i=i, is_causal=is_causal ) - if is_causal_behavior != _CausalBehavior.SKIP: - out, logsumexp, *rest = op( + # For a detailed understanding of the load balancing algorithm, see + # Note [Context parallelism load balance algorithm for causal masking] + if is_causal_behavior == _CausalBehavior.SKIP: + # If i > rank and load balancing is not turned on. + continue + + if i == 0 or (not _cp_options.enable_load_balance or not is_causal): + # When local balance is enabled, we still need to do SDPA with + # the both local chunks of q, k, v for the first iteration. + q, k, v, partial = (query, key, value, False) + elif i <= rank: + # Round-robin load balancing case, and i <= rank. + # We need to do SPDA, with only the first local chunk of the k, v. + # Note that q, k, v, each contains two local chunks. + ROUND_ROBIN_CYCLE = 2 + q, k, v, partial = ( query, - key, - value, - is_causal=is_causal_behavior.value, - **kwargs, + key.chunk(ROUND_ROBIN_CYCLE, dim=2)[0], + value.chunk(ROUND_ROBIN_CYCLE, dim=2)[0], + False, ) - - sdpa_merger.step(out, logsumexp) + else: + # Round-robin load balancing case, and i > rank. + # We need to do SPDA with only the second half of the q, and update + # only the the second part of logsumexp. So partial is True. + # Note that q, k, v, each contains two chunks. + q, k, v, partial = query.chunk(2, dim=2)[1], key, value, True + + # See https://github.com/pytorch/pytorch/blob/release/2.4/aten/src/ATen/native/native_functions.yaml#L14695 + # for the SDPA kernel definitions. + out, logsumexp, *rest = op( + q, + k, + v, + is_causal=is_causal_behavior.value, + **kwargs, + ) + sdpa_merger.step(out, logsumexp, partial) return *sdpa_merger.results(), *rest @@ -358,6 +518,7 @@ def _sdpa_backward_handler( def _templated_ring_attention_backward( mesh: DeviceMesh, + seq_dim: int, op: _AttentionOp, grad_out: torch.Tensor, grad_out_name: str, @@ -369,6 +530,7 @@ def _templated_ring_attention_backward( is_causal: bool, **kwargs: Any, ) -> Tuple[torch.Tensor, ...]: + """This API implements the backward of the ring attention.""" pg = mesh.get_group() assert isinstance(pg, dist.ProcessGroup), "must be single dimension" rank = dist.get_rank(pg) @@ -378,7 +540,7 @@ def _templated_ring_attention_backward( rest: List[Any] grad_query_, grad_key_, grad_value_ = None, None, None - accum_dtype = torch.float32 if _convert_to_f32 else query.dtype + accum_dtype = torch.float32 if _cp_options.convert_to_f32 else query.dtype grad_query = torch.zeros_like(query, dtype=accum_dtype) grad_key = torch.zeros_like(key, dtype=accum_dtype) grad_value = torch.zeros_like(value, dtype=accum_dtype) @@ -387,6 +549,7 @@ def _templated_ring_attention_backward( value = value.contiguous() for i in range(size): if next_kv is not None: + # Wait for the kv from the (cp_rank - 1) rank. buffer = _maybe_wait(next_kv) pointer = 0 key = buffer[pointer : pointer + key.numel()].reshape(key.shape) @@ -395,6 +558,7 @@ def _templated_ring_attention_backward( pointer += value.numel() if i != size - 1: + # Send the kv to the next rank. next_kv = torch.cat([key.flatten(), value.flatten()]) next_kv = _ring_rotate(next_kv, pg, send_to_next=True) @@ -403,13 +567,45 @@ def _templated_ring_attention_backward( ) if is_causal_behavior != _CausalBehavior.SKIP: - kwargs[grad_out_name] = grad_out + if i == 0 or (not _cp_options.enable_load_balance or not is_causal): + # We need to do SDPA with the full local q, k, v. + q, k, v, out_, dout, lse = (query, key, value, out, grad_out, logsumexp) + elif i <= rank: + # Round-robin load balancing case, and i <= rank. + # We need to do SPDA with only the first half of the k, v. + # Note that q, k, v, each contains two chunks. + q, k, v, out_, dout, lse = ( + query, + key.chunk(2, dim=seq_dim)[0], + value.chunk(2, dim=seq_dim)[0], + out, + grad_out, + logsumexp, + ) + else: + # Round-robin load balancing case, and i > rank. + # We need to do SPDA with only the second half of the q + # Note that q, k, v, each contains two chunks. + q, k, v, out_, dout, lse = ( + query.chunk(2, dim=seq_dim)[1], + key, + value, + out.chunk(2, dim=seq_dim)[1], + grad_out.chunk(2, dim=seq_dim)[1], + # Need to make logsumexp contiguous, otherwise there will + # be numerical error. + logsumexp.chunk(2, dim=seq_dim)[1].contiguous(), + ) + + kwargs[grad_out_name] = dout + # See https://github.com/pytorch/pytorch/blob/release/2.4/aten/src/ATen/native/native_functions.yaml#L14695 + # for the SDPA kernel definitions. grad_query_, grad_key_, grad_value_, *rest = op( - query=query, - key=key, - value=value, - out=out, - logsumexp=logsumexp, + query=q, + key=k, + value=v, + out=out_, + logsumexp=lse, is_causal=is_causal_behavior.value, **kwargs, ) @@ -418,10 +614,14 @@ def _templated_ring_attention_backward( grad_key_ = torch.zeros_like(key, dtype=accum_dtype) grad_value_ = torch.zeros_like(value, dtype=accum_dtype) - # Get the grad key and grad value for the i round. - if i > 0: + ROUND_ROBIN_CYCLE = 2 + if i == 0: + grad_key += grad_key_ + grad_value += grad_value_ + else: pointer = 0 assert next_grad_kv is not None + # Wait for the kv gradient from (cp_rank - 1) rank. next_grad_kv = _maybe_wait(next_grad_kv) grad_key = next_grad_kv[pointer : pointer + grad_key.numel()].reshape( grad_key.shape @@ -431,13 +631,42 @@ def _templated_ring_attention_backward( grad_value.shape ) - grad_key += grad_key_ - grad_value += grad_value_ + if i <= rank and _cp_options.enable_load_balance: + grad_key = _partial_update( + grad_key, + grad_key_, + dim=seq_dim, + n_chunks=ROUND_ROBIN_CYCLE, + idx=0, + add=True, + ) + grad_value = _partial_update( + grad_value, + grad_value_, + dim=seq_dim, + n_chunks=ROUND_ROBIN_CYCLE, + idx=0, + add=True, + ) + else: + grad_key += grad_key_ + grad_value += grad_value_ - # Send the key, value, grad key, and grad value to the next rank. next_grad_kv = torch.cat([grad_key.flatten(), grad_value.flatten()]) + # Send the grad key, and grad value to the next rank. next_grad_kv = _ring_rotate(next_grad_kv, pg, send_to_next=True) - grad_query += grad_query_ + + if i <= rank or not _cp_options.enable_load_balance: + grad_query += grad_query_ + else: + grad_query = _partial_update( + grad_query, + grad_query_, + dim=seq_dim, + n_chunks=ROUND_ROBIN_CYCLE, + idx=1, + add=True, + ) assert next_grad_kv is not None assert grad_key_ is not None @@ -473,8 +702,10 @@ def _scaled_dot_product_ring_flash_attention_backward( *, scale: Optional[float] = None, ) -> Tuple[torch.Tensor, ...]: + seq_dim = 2 return _templated_ring_attention_backward( mesh, + seq_dim, aten._scaled_dot_product_flash_attention_backward.default, grad_out=grad_out, grad_out_name="grad_out", @@ -512,8 +743,10 @@ def _scaled_dot_product_ring_efficient_attention_backward( *, scale: Optional[float] = None, ) -> Tuple[torch.Tensor, ...]: + seq_dim = 2 return _templated_ring_attention_backward( mesh, + seq_dim, aten._scaled_dot_product_efficient_attention_backward.default, grad_out=grad_out, grad_out_name="grad_out_", @@ -779,10 +1012,94 @@ def attention_output_fn(mesh: DeviceMesh, outputs: Any) -> Any: _restore_function(F.scaled_dot_product_attention, F) -def _get_sequence_shard( - buffer: torch.Tensor, mesh: DeviceMesh, seq_dim: int -) -> torch.Tensor: - return buffer.chunk(mesh.size(), dim=seq_dim)[mesh.get_local_rank()] +class _LoadBalancer(ABC): + @classmethod + @abstractmethod + def shard( + cls, buffer: torch.Tensor, mesh: DeviceMesh, seq_dim: int + ) -> torch.Tensor: + ... + + @classmethod + @abstractmethod + def unshard( + cls, buffer: torch.Tensor, mesh: DeviceMesh, seq_dim: int + ) -> torch.Tensor: + ... + + +class _SequentialSharder(_LoadBalancer): + """ + This load balancer chunks the buffer into cp_world_size and rank0 gets + 0th shard, rank1 gets 1st shard, ... + So this doesn't have any load balancing effect when using the causal masking. + """ + + @classmethod + def shard( + cls, buffer: torch.Tensor, mesh: DeviceMesh, seq_dim: int + ) -> torch.Tensor: + assert buffer.size()[seq_dim] % mesh.size() == 0 + return buffer.chunk(mesh.size(), dim=seq_dim)[mesh.get_local_rank()] + + @classmethod + def unshard( + cls, buffer: torch.Tensor, mesh: DeviceMesh, seq_dim: int + ) -> torch.Tensor: + buffer = buffer.contiguous() + all_buffers = [torch.empty_like(buffer) for _ in range(mesh.size())] + ft_c.all_gather_inplace(all_buffers, buffer, mesh) + return torch.cat(all_buffers, dim=seq_dim) + + +class _RoundRobinLoadBalancer(_LoadBalancer): + """ + This load balancer chunk the buffer into cp_world_size * ROUND_ROBIN_CYCLE + shards, and uses a round robin approach to achieve load balancing. + Since ROUND_ROBIN_CYCLE being 2 will achieve perfect load balancing for + causal masking, we assume ROUND_ROBIN_CYCLE is always 2 to simplify the + implementation. + """ + + ROUND_ROBIN_CYCLE = 2 + + @classmethod + def shard( + cls, buffer: torch.Tensor, mesh: DeviceMesh, seq_dim: int + ) -> torch.Tensor: + assert ( + cls.ROUND_ROBIN_CYCLE == 2 + ), "The current implementation only works if ROUND_ROBIN_CYCLE is 2." + cp_world_size = mesh.size() + cp_rank = mesh.get_local_rank() + assert buffer.size()[seq_dim] % (cp_world_size * 2) == 0 + chunks = buffer.chunk(cp_world_size * 2, dim=seq_dim) + return torch.cat( + (chunks[cp_rank], chunks[cp_world_size * 2 - cp_rank - 1]), + dim=seq_dim, + ) + + @classmethod + def unshard( + cls, buffer: torch.Tensor, mesh: DeviceMesh, seq_dim: int + ) -> torch.Tensor: + assert ( + cls.ROUND_ROBIN_CYCLE == 2 + ), "The current implementation only works if ROUND_ROBIN_CYCLE is 2." + buffer = buffer.contiguous() + cp_world_size = mesh.size() + cp_rank = mesh.get_local_rank() + + all_buffers = [torch.empty_like(buffer) for _ in range(cp_world_size)] + ft_c.all_gather_inplace(all_buffers, buffer, mesh) + sliced_buffers = [sb for b in all_buffers for sb in b.chunk(2, dim=seq_dim)] + ordered_buffers = list(sliced_buffers) + for i, b in enumerate(sliced_buffers): + if i % 2 == 0: + ordered_buffers[i // 2] = b + else: + ordered_buffers[cp_world_size * 2 - (i // 2) - 1] = b + return torch.cat(ordered_buffers, dim=seq_dim) def _context_parallel_buffers( @@ -792,8 +1109,13 @@ def _context_parallel_buffers( ) -> List[torch.Tensor]: """Shard the buffers along the sequence dimensions according to CP rules.""" new_buffers = [] + sharder = ( + _RoundRobinLoadBalancer + if _cp_options.enable_load_balance + else _SequentialSharder + ) for buffer, seq_dim in zip(buffers, buffer_seq_dims): - new_buffers.append(_get_sequence_shard(buffer, mesh, seq_dim)) + new_buffers.append(sharder.shard(buffer, mesh, seq_dim)) return new_buffers @@ -851,7 +1173,6 @@ def context_parallel( raise ValueError("`no_restore_buffers` must be a subset of `buffers`.") original_buffers = [None if b in no_restore_buffers else b.clone() for b in buffers] - chunks = _context_parallel_buffers(mesh, buffers, buffer_seq_dims) for buffer, chunk in zip(buffers, chunks): chunk = chunk.clone() @@ -865,3 +1186,20 @@ def context_parallel( if original_buffer is not None: buffer.resize_(original_buffer.shape) buffer.copy_(original_buffer) + + +@torch.no_grad() +def context_parallel_unshard( + mesh: DeviceMesh, + buffers: List[torch.Tensor], + seq_dims: List[int], +) -> List[torch.Tensor]: + """ + Unshard the tensors (e.g., output) that are sharded due to context parallelism. + """ + sharder = ( + _RoundRobinLoadBalancer + if _cp_options.enable_load_balance + else _SequentialSharder + ) + return [sharder.unshard(b, mesh, dim) for b, dim in zip(buffers, seq_dims)] From 0d5ac7c5bf0c7c6e11247ce12685e884e211b0d5 Mon Sep 17 00:00:00 2001 From: shubhambhokare1 Date: Mon, 9 Sep 2024 18:10:37 +0000 Subject: [PATCH 0638/1018] [ONNX] Update fake mode usage in onnx docs (#135512) Update fake mode usage in onnx docs Pull Request resolved: https://github.com/pytorch/pytorch/pull/135512 Approved by: https://github.com/justinchuby --- torch/onnx/_internal/_exporter_legacy.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/torch/onnx/_internal/_exporter_legacy.py b/torch/onnx/_internal/_exporter_legacy.py index 71b828b9bceb7a..f16cb09d2b2506 100644 --- a/torch/onnx/_internal/_exporter_legacy.py +++ b/torch/onnx/_internal/_exporter_legacy.py @@ -419,10 +419,15 @@ def enable_fake_mode(): ... arg1, ... export_options=export_options ... ) + ... onnx_program.apply_weights(MyModel().state_dict()) >>> # Saving model WITHOUT initializers - >>> onnx_program.save("my_model_without_initializers.onnx") + >>> onnx_program.save( + ... "my_model_without_initializers.onnx", + ... include_initializers=False, + ... keep_initializers_as_inputs=True, + ... ) >>> # Saving model WITH initializers - >>> onnx_program.save("my_model_with_initializers.onnx", model_state=MyModel().state_dict()) + >>> onnx_program.save("my_model_with_initializers.onnx") .. warning:: This API is experimental and is *NOT* backward-compatible. From b3ad5159ee714108748a926435c91c7e251ebcc9 Mon Sep 17 00:00:00 2001 From: rzou Date: Mon, 9 Sep 2024 07:30:54 -0700 Subject: [PATCH 0639/1018] Fix bugs blocking flipping the default layout constraint for custom ops (#135391) Fixes two things: - For regular PyTorch ops, the default layout constraint tag is always flexible_layout. This was a bug with #135238 - Mark the new quantized _wrapped_linear_prepack ops as flexible_layout. The metas for these are incorrect, I didn't want to fix them (and changing the default requires the metas actually be correct). Test Plan: - The next PR up in the stack. The PRs are split because the next one is riskier. foo Pull Request resolved: https://github.com/pytorch/pytorch/pull/135391 Approved by: https://github.com/albanD --- aten/src/ATen/native/quantized/library.cpp | 4 ++-- torch/_inductor/lowering.py | 2 ++ 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/aten/src/ATen/native/quantized/library.cpp b/aten/src/ATen/native/quantized/library.cpp index 7780953a94dc3f..e3b95687306e8b 100644 --- a/aten/src/ATen/native/quantized/library.cpp +++ b/aten/src/ATen/native/quantized/library.cpp @@ -251,8 +251,8 @@ TORCH_LIBRARY(_quantized, m) { m.def(TORCH_SELECTIVE_SCHEMA("_quantized::wrapped_fbgemm_pack_gemm_matrix_fp16(Tensor W) -> Tensor")); m.def(TORCH_SELECTIVE_SCHEMA("_quantized::wrapped_fbgemm_linear_fp16_weight(Tensor X, Tensor W, Tensor B, int out_channel) -> Tensor")); m.def(TORCH_SELECTIVE_SCHEMA("_quantized::wrapped_quantized_linear(Tensor X, Tensor X_scale, Tensor X_zero_point, Tensor W, Tensor W_scale, Tensor W_zero_point, Tensor B, Tensor output_scale, Tensor output_zero_point, int out_channel) -> Tensor Y")); - m.def(TORCH_SELECTIVE_SCHEMA("_quantized::_wrapped_linear_prepack(Tensor W, Tensor W_scale, Tensor W_zero_point, Tensor B) -> Tensor")); - m.def(TORCH_SELECTIVE_SCHEMA("_quantized::_wrapped_quantized_linear_prepacked(Tensor X, Tensor X_scale, Tensor X_zero_point, Tensor W_prepack, Tensor output_scale, Tensor output_zero_point, int out_channel) -> Tensor Y")); + m.def(TORCH_SELECTIVE_SCHEMA("_quantized::_wrapped_linear_prepack(Tensor W, Tensor W_scale, Tensor W_zero_point, Tensor B) -> Tensor"), {at::Tag::flexible_layout}); + m.def(TORCH_SELECTIVE_SCHEMA("_quantized::_wrapped_quantized_linear_prepacked(Tensor X, Tensor X_scale, Tensor X_zero_point, Tensor W_prepack, Tensor output_scale, Tensor output_zero_point, int out_channel) -> Tensor Y"), {at::Tag::flexible_layout}); } TORCH_LIBRARY(onednn, m) { diff --git a/torch/_inductor/lowering.py b/torch/_inductor/lowering.py index adf921913edfcd..a1eade0aab773c 100644 --- a/torch/_inductor/lowering.py +++ b/torch/_inductor/lowering.py @@ -124,6 +124,8 @@ def get_layout_constraint_tag(fn): for tag in tags_by_priority: if tag in fn.tags: return tag + if torch._library.utils.is_builtin(fn): + return torch._C.Tag.flexible_layout return getattr(torch._C.Tag, config.custom_op_default_layout_constraint) From 567e33b22e12c5381009cf80bf820df8d8e2f7c1 Mon Sep 17 00:00:00 2001 From: Sergii Dymchenko Date: Mon, 9 Sep 2024 19:07:48 +0000 Subject: [PATCH 0640/1018] Don't try to tag s390x docker images (#135509) Pull Request resolved: https://github.com/pytorch/pytorch/pull/135509 Approved by: https://github.com/atalman --- .github/scripts/tag_docker_images_for_release.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/.github/scripts/tag_docker_images_for_release.py b/.github/scripts/tag_docker_images_for_release.py index 62a4f21ae7dd20..73dfda03b1c82c 100644 --- a/.github/scripts/tag_docker_images_for_release.py +++ b/.github/scripts/tag_docker_images_for_release.py @@ -51,6 +51,8 @@ def main() -> None: for platform_image in platform_images: # type: ignore[attr-defined] for arch in platform_image.keys(): # type: ignore[attr-defined] + if arch == "cpu-s390x": + continue tag_image( platform_image[arch], # type: ignore[index] default_tag, From 5c4f381313465d8c183c78156bc71163e7e818cd Mon Sep 17 00:00:00 2001 From: imShZh Date: Mon, 9 Sep 2024 19:32:16 +0000 Subject: [PATCH 0641/1018] Fix symbolic number's type and tensor's dtype mismatch bug in Tensor ctor (#135433) Fixes #135432 In the current implementation, if we try to store a symbolic number in Tensor's constructor, it assumes that the tensor's dtype and the symbolic number's type are matched, which is not the case. In other words, if we try to store a `SymInt`, current implementation assumes tensor's dtype is `torch.int32`, `torch.int64` or something. And if we try to store a `SymFloat`, it assumes tensor's dtype is `torch.float32` or `torch.float64`. However, the tensor's dtype could also be `torch.float32` or something else when we try to store `SymInt`, which would be wrong. This PR stores symbolic numbers by tensor's scalar type by wrapping `SymInt` and `SymFoat`'s guarded number into a PyObject. Pull Request resolved: https://github.com/pytorch/pytorch/pull/135433 Approved by: https://github.com/ezyang --- test/test_dynamic_shapes.py | 28 ++++++++++++++++++++++++++++ torch/csrc/utils/tensor_new.cpp | 30 +++--------------------------- 2 files changed, 31 insertions(+), 27 deletions(-) diff --git a/test/test_dynamic_shapes.py b/test/test_dynamic_shapes.py index 8b1c1c109ddba4..66f6fdeb7237a5 100644 --- a/test/test_dynamic_shapes.py +++ b/test/test_dynamic_shapes.py @@ -33,6 +33,7 @@ StatelessSymbolicContext, statically_known_true, ) +from torch.testing._internal.common_dtype import all_types_and from torch.testing._internal.common_utils import ( instantiate_parametrized_tests, parametrize, @@ -1068,6 +1069,33 @@ def test_ephemeral_source_unified_with_non_ephemeral_source(self): self.assertEqual(x.stride(), y.stride()) self.assertEqual(x.storage_offset(), y.storage_offset()) + def test_tensor_factory_with_symint(self): + args = list(range(0, 3)) + expected = torch.tensor(args) + + shape_env = ShapeEnv() + sym_args = [create_symint(shape_env, i) for i in args] + + # test tensor factories + for dt in all_types_and(torch.half, torch.bfloat16): + res = torch.tensor(sym_args, dtype=dt) + self.assertEqual(res, expected, exact_dtype=False) + + # test legacy tensor factories + legacy_ctors = [ + torch.Tensor, + torch.LongTensor, + torch.DoubleTensor, + torch.FloatTensor, + torch.IntTensor, + torch.ShortTensor, + torch.HalfTensor, + torch.ByteTensor, + ] + for Tensor in legacy_ctors: + res = Tensor(sym_args) + self.assertEqual(res, expected, exact_dtype=False) + @skipIfTorchDynamo( "Creating ShapeEnv fails for confusing reasons (also we never expect dynamo to see code like this)" diff --git a/torch/csrc/utils/tensor_new.cpp b/torch/csrc/utils/tensor_new.cpp index 10fbcfc41b2d74..de58b1965492da 100644 --- a/torch/csrc/utils/tensor_new.cpp +++ b/torch/csrc/utils/tensor_new.cpp @@ -219,37 +219,13 @@ void recursive_store( auto new_obj = py::reinterpret_borrow(obj); auto val = new_obj.cast(); const double double_val = val.guard_float(__FILE__, __LINE__); - switch (elementSize) { - case 8: - *reinterpret_cast(data) = double_val; - break; - case 4: - *reinterpret_cast(data) = static_cast(double_val); - break; - } - return; + obj = Py_BuildValue("d", double_val); } if (is_symint) { auto new_obj = py::reinterpret_borrow(obj); auto val = new_obj.cast(); - const auto int_val = val.guard_int(__FILE__, __LINE__); - switch (elementSize) { - case 8: - *reinterpret_cast(data) = int_val; - break; - case 4: - *reinterpret_cast(data) = static_cast(int_val); - break; - case 2: - *reinterpret_cast(data) = static_cast(int_val); - break; - case 1: - *reinterpret_cast(data) = static_cast(int_val); - break; - default: - TORCH_CHECK(false, "Unexpected elementSize ", elementSize); - } - return; + const int64_t int_val = val.guard_int(__FILE__, __LINE__); + obj = Py_BuildValue("i", int_val); } torch::utils::store_scalar(data, scalarType, obj); return; From 9feb9dcf84bf8d0843bb332ff0c7fc257ce1cfbf Mon Sep 17 00:00:00 2001 From: Sahan Paliskara Date: Mon, 9 Sep 2024 19:35:57 +0000 Subject: [PATCH 0642/1018] [split build] move periodic split builds into own concurrency group (#135510) To avoid nightly workflows cancelling each other Pull Request resolved: https://github.com/pytorch/pytorch/pull/135510 Approved by: https://github.com/clee2000, https://github.com/huydhn, https://github.com/malfet Co-authored-by: Nikita Shulga <2453524+malfet@users.noreply.github.com> --- .github/scripts/generate_ci_workflows.py | 5 +- ...nerated-linux-binary-manywheel-split-main} | 18 ++-- ...ated-linux-binary-manywheel-split-nightly} | 90 +++++++++---------- 3 files changed, 58 insertions(+), 55 deletions(-) rename .github/workflows/{generated-linux-binary-manywheel-main-split.yml => generated-linux-binary-manywheel-split-main} (93%) rename .github/workflows/{generated-linux-binary-manywheel-nightly-split.yml => generated-linux-binary-manywheel-split-nightly} (96%) diff --git a/.github/scripts/generate_ci_workflows.py b/.github/scripts/generate_ci_workflows.py index 9f3536e2ede50f..302d8263846440 100755 --- a/.github/scripts/generate_ci_workflows.py +++ b/.github/scripts/generate_ci_workflows.py @@ -70,6 +70,9 @@ def __post_init__(self) -> None: ) else: self.build_environment = f"{self.os}-binary-{self.package_type}" + if self.use_split_build: + # added to distinguish concurrency groups + self.build_environment += "-split" def generate_workflow_file(self, workflow_template: jinja2.Template) -> None: output_file_path = ( @@ -79,7 +82,7 @@ def generate_workflow_file(self, workflow_template: jinja2.Template) -> None: if self.use_split_build: output_file_path = ( GITHUB_DIR - / f"workflows/generated-{self.build_environment}-{self.branches}-split.yml" + / f"workflows/generated-{self.build_environment}-{self.branches}" ) with open(output_file_path, "w") as output_file: GENERATED = "generated" # Note that please keep the variable GENERATED otherwise phabricator will hide the whole file diff --git a/.github/workflows/generated-linux-binary-manywheel-main-split.yml b/.github/workflows/generated-linux-binary-manywheel-split-main similarity index 93% rename from .github/workflows/generated-linux-binary-manywheel-main-split.yml rename to .github/workflows/generated-linux-binary-manywheel-split-main index 61d6df214c5a2c..9c2456e4632c75 100644 --- a/.github/workflows/generated-linux-binary-manywheel-main-split.yml +++ b/.github/workflows/generated-linux-binary-manywheel-split-main @@ -2,7 +2,7 @@ # Template is at: .github/templates/linux_binary_build_workflow.yml.j2 # Generation script: .github/scripts/generate_ci_workflows.py -name: linux-binary-manywheel +name: linux-binary-manywheel-split on: @@ -19,7 +19,7 @@ env: ANACONDA_USER: pytorch AWS_DEFAULT_REGION: us-east-1 BINARY_ENV_FILE: /tmp/env - BUILD_ENVIRONMENT: linux-binary-manywheel + BUILD_ENVIRONMENT: linux-binary-manywheel-split BUILDER_ROOT: /builder GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} PR_NUMBER: ${{ github.event.pull_request.number }} @@ -28,7 +28,7 @@ env: SHA1: ${{ github.event.pull_request.head.sha || github.sha }} SKIP_ALL_TESTS: 0 concurrency: - group: linux-binary-manywheel-${{ github.event.pull_request.number || github.ref_name }}-${{ github.ref_type == 'branch' && github.sha }}-${{ github.event_name == 'workflow_dispatch' }} + group: linux-binary-manywheel-split-${{ github.event.pull_request.number || github.ref_name }}-${{ github.ref_type == 'branch' && github.sha }}-${{ github.event_name == 'workflow_dispatch' }} cancel-in-progress: true jobs: @@ -58,7 +58,7 @@ jobs: DESIRED_PYTHON: "3.9" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_9-cuda11_8 - build_environment: linux-binary-manywheel + build_environment: linux-binary-manywheel-split PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu11==11.8.89; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu11==11.8.89; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu11==11.8.87; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu11==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu11==11.11.3.6; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu11==10.9.0.58; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu11==10.3.0.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu11==11.4.1.48; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu11==11.7.5.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu11==2.21.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu11==11.8.86; platform_system == 'Linux' and platform_machine == 'x86_64' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} @@ -81,7 +81,7 @@ jobs: use_split_build: True DESIRED_PYTHON: "3.9" build_name: manywheel-py3_9-cuda11_8 - build_environment: linux-binary-manywheel + build_environment: linux-binary-manywheel-split runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" runs_on: linux.4xlarge.nvidia.gpu secrets: @@ -105,7 +105,7 @@ jobs: DESIRED_PYTHON: "3.9" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_9-cuda12_1 - build_environment: linux-binary-manywheel + build_environment: linux-binary-manywheel-split PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.21.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} @@ -128,7 +128,7 @@ jobs: use_split_build: True DESIRED_PYTHON: "3.9" build_name: manywheel-py3_9-cuda12_1 - build_environment: linux-binary-manywheel + build_environment: linux-binary-manywheel-split runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" runs_on: linux.4xlarge.nvidia.gpu secrets: @@ -152,7 +152,7 @@ jobs: DESIRED_PYTHON: "3.9" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_9-cuda12_4 - build_environment: linux-binary-manywheel + build_environment: linux-binary-manywheel-split PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.4.5.8; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.2.1.3; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.5.147; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.6.1.9; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.3.1.170; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.21.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} @@ -175,7 +175,7 @@ jobs: use_split_build: True DESIRED_PYTHON: "3.9" build_name: manywheel-py3_9-cuda12_4 - build_environment: linux-binary-manywheel + build_environment: linux-binary-manywheel-split runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" runs_on: linux.4xlarge.nvidia.gpu secrets: diff --git a/.github/workflows/generated-linux-binary-manywheel-nightly-split.yml b/.github/workflows/generated-linux-binary-manywheel-split-nightly similarity index 96% rename from .github/workflows/generated-linux-binary-manywheel-nightly-split.yml rename to .github/workflows/generated-linux-binary-manywheel-split-nightly index 8e57081422f1a9..0d29acf56a0835 100644 --- a/.github/workflows/generated-linux-binary-manywheel-nightly-split.yml +++ b/.github/workflows/generated-linux-binary-manywheel-split-nightly @@ -2,7 +2,7 @@ # Template is at: .github/templates/linux_binary_build_workflow.yml.j2 # Generation script: .github/scripts/generate_ci_workflows.py -name: linux-binary-manywheel +name: linux-binary-manywheel-split on: @@ -24,7 +24,7 @@ env: ANACONDA_USER: pytorch AWS_DEFAULT_REGION: us-east-1 BINARY_ENV_FILE: /tmp/env - BUILD_ENVIRONMENT: linux-binary-manywheel + BUILD_ENVIRONMENT: linux-binary-manywheel-split BUILDER_ROOT: /builder GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} PR_NUMBER: ${{ github.event.pull_request.number }} @@ -33,7 +33,7 @@ env: SHA1: ${{ github.event.pull_request.head.sha || github.sha }} SKIP_ALL_TESTS: 0 concurrency: - group: linux-binary-manywheel-${{ github.event.pull_request.number || github.ref_name }}-${{ github.ref_type == 'branch' && github.sha }}-${{ github.event_name == 'workflow_dispatch' }} + group: linux-binary-manywheel-split-${{ github.event.pull_request.number || github.ref_name }}-${{ github.ref_type == 'branch' && github.sha }}-${{ github.event_name == 'workflow_dispatch' }} cancel-in-progress: true jobs: @@ -63,7 +63,7 @@ jobs: DESIRED_PYTHON: "3.9" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_9-cuda11_8 - build_environment: linux-binary-manywheel + build_environment: linux-binary-manywheel-split PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu11==11.8.89; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu11==11.8.89; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu11==11.8.87; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu11==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu11==11.11.3.6; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu11==10.9.0.58; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu11==10.3.0.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu11==11.4.1.48; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu11==11.7.5.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu11==2.21.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu11==11.8.86; platform_system == 'Linux' and platform_machine == 'x86_64' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} @@ -86,7 +86,7 @@ jobs: use_split_build: True DESIRED_PYTHON: "3.9" build_name: manywheel-py3_9-cuda11_8 - build_environment: linux-binary-manywheel + build_environment: linux-binary-manywheel-split runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" runs_on: linux.4xlarge.nvidia.gpu secrets: @@ -134,7 +134,7 @@ jobs: DESIRED_PYTHON: "3.9" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_9-cuda12_1 - build_environment: linux-binary-manywheel + build_environment: linux-binary-manywheel-split PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.21.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} @@ -157,7 +157,7 @@ jobs: use_split_build: True DESIRED_PYTHON: "3.9" build_name: manywheel-py3_9-cuda12_1 - build_environment: linux-binary-manywheel + build_environment: linux-binary-manywheel-split runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" runs_on: linux.4xlarge.nvidia.gpu secrets: @@ -205,7 +205,7 @@ jobs: DESIRED_PYTHON: "3.9" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_9-cuda12_4 - build_environment: linux-binary-manywheel + build_environment: linux-binary-manywheel-split PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.4.5.8; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.2.1.3; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.5.147; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.6.1.9; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.3.1.170; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.21.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} @@ -228,7 +228,7 @@ jobs: use_split_build: True DESIRED_PYTHON: "3.9" build_name: manywheel-py3_9-cuda12_4 - build_environment: linux-binary-manywheel + build_environment: linux-binary-manywheel-split runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" runs_on: linux.4xlarge.nvidia.gpu secrets: @@ -275,7 +275,7 @@ jobs: DESIRED_PYTHON: "3.9" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_9-cpu - build_environment: linux-binary-manywheel + build_environment: linux-binary-manywheel-split secrets: github-token: ${{ secrets.GITHUB_TOKEN }} manywheel-py3_9-cpu-test: # Testing @@ -296,7 +296,7 @@ jobs: use_split_build: True DESIRED_PYTHON: "3.9" build_name: manywheel-py3_9-cpu - build_environment: linux-binary-manywheel + build_environment: linux-binary-manywheel-split runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" runs_on: linux.4xlarge secrets: @@ -343,7 +343,7 @@ jobs: DESIRED_PYTHON: "3.10" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_10-cuda11_8 - build_environment: linux-binary-manywheel + build_environment: linux-binary-manywheel-split PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu11==11.8.89; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu11==11.8.89; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu11==11.8.87; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu11==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu11==11.11.3.6; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu11==10.9.0.58; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu11==10.3.0.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu11==11.4.1.48; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu11==11.7.5.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu11==2.21.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu11==11.8.86; platform_system == 'Linux' and platform_machine == 'x86_64' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} @@ -366,7 +366,7 @@ jobs: use_split_build: True DESIRED_PYTHON: "3.10" build_name: manywheel-py3_10-cuda11_8 - build_environment: linux-binary-manywheel + build_environment: linux-binary-manywheel-split runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" runs_on: linux.4xlarge.nvidia.gpu secrets: @@ -414,7 +414,7 @@ jobs: DESIRED_PYTHON: "3.10" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_10-cuda12_1 - build_environment: linux-binary-manywheel + build_environment: linux-binary-manywheel-split PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.21.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} @@ -437,7 +437,7 @@ jobs: use_split_build: True DESIRED_PYTHON: "3.10" build_name: manywheel-py3_10-cuda12_1 - build_environment: linux-binary-manywheel + build_environment: linux-binary-manywheel-split runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" runs_on: linux.4xlarge.nvidia.gpu secrets: @@ -485,7 +485,7 @@ jobs: DESIRED_PYTHON: "3.10" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_10-cuda12_1-full - build_environment: linux-binary-manywheel + build_environment: linux-binary-manywheel-split secrets: github-token: ${{ secrets.GITHUB_TOKEN }} manywheel-py3_10-cuda12_1-full-test: # Testing @@ -507,7 +507,7 @@ jobs: use_split_build: True DESIRED_PYTHON: "3.10" build_name: manywheel-py3_10-cuda12_1-full - build_environment: linux-binary-manywheel + build_environment: linux-binary-manywheel-split runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" runs_on: linux.4xlarge.nvidia.gpu secrets: @@ -555,7 +555,7 @@ jobs: DESIRED_PYTHON: "3.10" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_10-cuda12_4 - build_environment: linux-binary-manywheel + build_environment: linux-binary-manywheel-split PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.4.5.8; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.2.1.3; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.5.147; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.6.1.9; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.3.1.170; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.21.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} @@ -578,7 +578,7 @@ jobs: use_split_build: True DESIRED_PYTHON: "3.10" build_name: manywheel-py3_10-cuda12_4 - build_environment: linux-binary-manywheel + build_environment: linux-binary-manywheel-split runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" runs_on: linux.4xlarge.nvidia.gpu secrets: @@ -625,7 +625,7 @@ jobs: DESIRED_PYTHON: "3.10" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_10-cpu - build_environment: linux-binary-manywheel + build_environment: linux-binary-manywheel-split secrets: github-token: ${{ secrets.GITHUB_TOKEN }} manywheel-py3_10-cpu-test: # Testing @@ -646,7 +646,7 @@ jobs: use_split_build: True DESIRED_PYTHON: "3.10" build_name: manywheel-py3_10-cpu - build_environment: linux-binary-manywheel + build_environment: linux-binary-manywheel-split runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" runs_on: linux.4xlarge secrets: @@ -693,7 +693,7 @@ jobs: DESIRED_PYTHON: "3.11" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_11-cuda11_8 - build_environment: linux-binary-manywheel + build_environment: linux-binary-manywheel-split PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu11==11.8.89; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu11==11.8.89; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu11==11.8.87; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu11==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu11==11.11.3.6; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu11==10.9.0.58; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu11==10.3.0.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu11==11.4.1.48; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu11==11.7.5.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu11==2.21.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu11==11.8.86; platform_system == 'Linux' and platform_machine == 'x86_64' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} @@ -716,7 +716,7 @@ jobs: use_split_build: True DESIRED_PYTHON: "3.11" build_name: manywheel-py3_11-cuda11_8 - build_environment: linux-binary-manywheel + build_environment: linux-binary-manywheel-split runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" runs_on: linux.4xlarge.nvidia.gpu secrets: @@ -764,7 +764,7 @@ jobs: DESIRED_PYTHON: "3.11" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_11-cuda12_1 - build_environment: linux-binary-manywheel + build_environment: linux-binary-manywheel-split PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.21.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} @@ -787,7 +787,7 @@ jobs: use_split_build: True DESIRED_PYTHON: "3.11" build_name: manywheel-py3_11-cuda12_1 - build_environment: linux-binary-manywheel + build_environment: linux-binary-manywheel-split runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" runs_on: linux.4xlarge.nvidia.gpu secrets: @@ -835,7 +835,7 @@ jobs: DESIRED_PYTHON: "3.11" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_11-cuda12_4 - build_environment: linux-binary-manywheel + build_environment: linux-binary-manywheel-split PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.4.5.8; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.2.1.3; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.5.147; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.6.1.9; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.3.1.170; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.21.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} @@ -858,7 +858,7 @@ jobs: use_split_build: True DESIRED_PYTHON: "3.11" build_name: manywheel-py3_11-cuda12_4 - build_environment: linux-binary-manywheel + build_environment: linux-binary-manywheel-split runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" runs_on: linux.4xlarge.nvidia.gpu secrets: @@ -905,7 +905,7 @@ jobs: DESIRED_PYTHON: "3.11" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_11-cpu - build_environment: linux-binary-manywheel + build_environment: linux-binary-manywheel-split secrets: github-token: ${{ secrets.GITHUB_TOKEN }} manywheel-py3_11-cpu-test: # Testing @@ -926,7 +926,7 @@ jobs: use_split_build: True DESIRED_PYTHON: "3.11" build_name: manywheel-py3_11-cpu - build_environment: linux-binary-manywheel + build_environment: linux-binary-manywheel-split runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" runs_on: linux.4xlarge secrets: @@ -973,7 +973,7 @@ jobs: DESIRED_PYTHON: "3.12" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_12-cuda11_8 - build_environment: linux-binary-manywheel + build_environment: linux-binary-manywheel-split PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu11==11.8.89; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu11==11.8.89; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu11==11.8.87; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu11==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu11==11.11.3.6; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu11==10.9.0.58; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu11==10.3.0.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu11==11.4.1.48; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu11==11.7.5.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu11==2.21.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu11==11.8.86; platform_system == 'Linux' and platform_machine == 'x86_64' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} @@ -996,7 +996,7 @@ jobs: use_split_build: True DESIRED_PYTHON: "3.12" build_name: manywheel-py3_12-cuda11_8 - build_environment: linux-binary-manywheel + build_environment: linux-binary-manywheel-split runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" runs_on: linux.4xlarge.nvidia.gpu secrets: @@ -1044,7 +1044,7 @@ jobs: DESIRED_PYTHON: "3.12" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_12-cuda12_1 - build_environment: linux-binary-manywheel + build_environment: linux-binary-manywheel-split PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.21.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} @@ -1067,7 +1067,7 @@ jobs: use_split_build: True DESIRED_PYTHON: "3.12" build_name: manywheel-py3_12-cuda12_1 - build_environment: linux-binary-manywheel + build_environment: linux-binary-manywheel-split runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" runs_on: linux.4xlarge.nvidia.gpu secrets: @@ -1115,7 +1115,7 @@ jobs: DESIRED_PYTHON: "3.12" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_12-cuda12_4 - build_environment: linux-binary-manywheel + build_environment: linux-binary-manywheel-split PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.4.5.8; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.2.1.3; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.5.147; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.6.1.9; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.3.1.170; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.21.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} @@ -1138,7 +1138,7 @@ jobs: use_split_build: True DESIRED_PYTHON: "3.12" build_name: manywheel-py3_12-cuda12_4 - build_environment: linux-binary-manywheel + build_environment: linux-binary-manywheel-split runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" runs_on: linux.4xlarge.nvidia.gpu secrets: @@ -1185,7 +1185,7 @@ jobs: DESIRED_PYTHON: "3.12" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_12-cpu - build_environment: linux-binary-manywheel + build_environment: linux-binary-manywheel-split secrets: github-token: ${{ secrets.GITHUB_TOKEN }} manywheel-py3_12-cpu-test: # Testing @@ -1206,7 +1206,7 @@ jobs: use_split_build: True DESIRED_PYTHON: "3.12" build_name: manywheel-py3_12-cpu - build_environment: linux-binary-manywheel + build_environment: linux-binary-manywheel-split runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" runs_on: linux.4xlarge secrets: @@ -1253,7 +1253,7 @@ jobs: DESIRED_PYTHON: "3.13" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_13-cuda11_8 - build_environment: linux-binary-manywheel + build_environment: linux-binary-manywheel-split PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu11==11.8.89; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu11==11.8.89; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu11==11.8.87; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu11==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu11==11.11.3.6; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu11==10.9.0.58; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu11==10.3.0.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu11==11.4.1.48; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu11==11.7.5.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu11==2.21.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu11==11.8.86; platform_system == 'Linux' and platform_machine == 'x86_64' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} @@ -1276,7 +1276,7 @@ jobs: use_split_build: True DESIRED_PYTHON: "3.13" build_name: manywheel-py3_13-cuda11_8 - build_environment: linux-binary-manywheel + build_environment: linux-binary-manywheel-split runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" runs_on: linux.4xlarge.nvidia.gpu secrets: @@ -1324,7 +1324,7 @@ jobs: DESIRED_PYTHON: "3.13" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_13-cuda12_1 - build_environment: linux-binary-manywheel + build_environment: linux-binary-manywheel-split PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.21.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} @@ -1347,7 +1347,7 @@ jobs: use_split_build: True DESIRED_PYTHON: "3.13" build_name: manywheel-py3_13-cuda12_1 - build_environment: linux-binary-manywheel + build_environment: linux-binary-manywheel-split runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" runs_on: linux.4xlarge.nvidia.gpu secrets: @@ -1395,7 +1395,7 @@ jobs: DESIRED_PYTHON: "3.13" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_13-cuda12_4 - build_environment: linux-binary-manywheel + build_environment: linux-binary-manywheel-split PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.4.5.8; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.2.1.3; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.5.147; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.6.1.9; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.3.1.170; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.21.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} @@ -1418,7 +1418,7 @@ jobs: use_split_build: True DESIRED_PYTHON: "3.13" build_name: manywheel-py3_13-cuda12_4 - build_environment: linux-binary-manywheel + build_environment: linux-binary-manywheel-split runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" runs_on: linux.4xlarge.nvidia.gpu secrets: @@ -1465,7 +1465,7 @@ jobs: DESIRED_PYTHON: "3.13" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_13-cpu - build_environment: linux-binary-manywheel + build_environment: linux-binary-manywheel-split secrets: github-token: ${{ secrets.GITHUB_TOKEN }} manywheel-py3_13-cpu-test: # Testing @@ -1486,7 +1486,7 @@ jobs: use_split_build: True DESIRED_PYTHON: "3.13" build_name: manywheel-py3_13-cpu - build_environment: linux-binary-manywheel + build_environment: linux-binary-manywheel-split runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" runs_on: linux.4xlarge secrets: From 239bad2ffa1f69b973d5b317aa955544455facf7 Mon Sep 17 00:00:00 2001 From: Joel Schlosser Date: Mon, 9 Sep 2024 12:05:27 -0400 Subject: [PATCH 0643/1018] NJT <-> padded dense conversions (#125947) This PR: * Implements the pre-existing `nt.to_padded_tensor(padding_val)` ATen op via the FBGEMM kernel + appropriate view gymnastics (since that kernel only handles 2D values) * Introduces a new `_nested_from_padded_tensor` op for the reverse conversion, implemented via the reverse FBGEMM kernel + view gymnastics * Note: there is currently no public API for this; design booted to a future PR TODO: * ~~Propagate min / max sequence length via the new factory function `_nested_from_padded_tensor`~~ * ~~Verify that Inductor does computation fusion via test logic~~ Pull Request resolved: https://github.com/pytorch/pytorch/pull/125947 Approved by: https://github.com/soulitzer --- aten/src/ATen/native/native_functions.yaml | 5 + ...asDecompTest.test_has_decomposition.expect | 1 + test/test_nestedtensor.py | 176 ++++++++++++++++++ tools/autograd/derivatives.yaml | 7 +- torch/_dynamo/trace_rules.py | 4 +- torch/nested/_internal/nested_tensor.py | 25 +++ torch/nested/_internal/ops.py | 92 +++++++++ 7 files changed, 308 insertions(+), 2 deletions(-) diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index 9342da626beb03..6bcbdc0a6a7954 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -14706,6 +14706,11 @@ CUDA: _fbgemm_dense_to_jagged_forward_symint CPU: _padded_dense_to_jagged_forward_cpu +- func: _nested_from_padded_tensor(Tensor padded, Tensor offsets, Tensor dummy, int ragged_idx=1, Tensor? min_seqlen=None, Tensor? max_seqlen=None, SymInt? sum_S=None) -> Tensor + variants: function + device_check: NoCheck + dispatch: {} + - func: _nested_tensor_softmax_with_shape(Tensor self, Tensor query) -> Tensor dispatch: NestedTensorCPU: NestedTensor_softmax_dropout diff --git a/test/expect/HasDecompTest.test_has_decomposition.expect b/test/expect/HasDecompTest.test_has_decomposition.expect index 58d52f34345e6f..ef1752b5743dde 100644 --- a/test/expect/HasDecompTest.test_has_decomposition.expect +++ b/test/expect/HasDecompTest.test_has_decomposition.expect @@ -445,6 +445,7 @@ aten::_nested_from_padded aten::_nested_from_padded.out aten::_nested_from_padded_and_nested_example aten::_nested_from_padded_and_nested_example.out +aten::_nested_from_padded_tensor aten::_nested_get_jagged_dummy aten::_nested_get_lengths aten::_nested_get_max_seqlen diff --git a/test/test_nestedtensor.py b/test/test_nestedtensor.py index 49cb422f8720cc..82db5d038bee71 100644 --- a/test/test_nestedtensor.py +++ b/test/test_nestedtensor.py @@ -1,5 +1,6 @@ # Owner(s): ["module: nestedtensor"] +import ast import io import itertools import math @@ -7164,6 +7165,181 @@ def check(nt): check(nt) + @dtypes(torch.float32, torch.double, torch.half) + @parametrize("nt_dim", [2, 3, 4]) + @parametrize("requires_grad", [False, True]) + def test_to_padded_tensor(self, device, dtype, nt_dim, requires_grad): + if nt_dim == 2: + post_seq_len_shape = () + elif nt_dim == 3: + post_seq_len_shape = (10,) + elif nt_dim == 4: + post_seq_len_shape = (9, 10) + + nt = torch.nested.nested_tensor( + [ + torch.randn(n, *post_seq_len_shape, device=device, dtype=dtype) + for n in range(2, 9) + ], + layout=torch.jagged, + requires_grad=requires_grad, + ) + + PADDING_VAL = 4.2 + expected_padded = nt._values.new_full((7, 8, *post_seq_len_shape), PADDING_VAL) + for i, component in enumerate(nt.unbind()): + expected_padded[i, : component.shape[0]].copy_(component) + + padded = nt.to_padded_tensor(PADDING_VAL) + self.assertEqual(expected_padded, padded) + + # convert padded dense -> NJT + from torch.nested._internal.nested_tensor import nested_from_padded + + nt2 = nested_from_padded(padded, nt.offsets()) + self.assertEqual(nt, nt2) + + if requires_grad: + # ensure gradients flow through conversions + nt2.backward(torch.ones_like(nt2)) + self.assertEqual(nt.grad, torch.ones_like(nt)) + + # blows up due to test parametrization otherwise + @torch._dynamo.utils.disable_cache_limit() + @skipIfTorchDynamo("SDPA test compiles internally") + @unittest.skipIf(IS_WINDOWS, reason="Windows not yet supported for torch.compile") + @skipCUDAIf(not SM70OrLater, "GPU capability is < SM70") + @skipCUDAIfRocm + @dtypes(torch.float32, torch.double, torch.half) + @parametrize("nt_dim", [2, 3, 4]) + @parametrize("requires_grad", [False, True]) + def test_to_padded_tensor_compile(self, device, dtype, nt_dim, requires_grad): + if nt_dim == 2: + post_seq_len_shape = () + elif nt_dim == 3: + post_seq_len_shape = (10,) + elif nt_dim == 4: + post_seq_len_shape = (9, 10) + + nt = torch.nested.nested_tensor( + [ + torch.randn(n, *post_seq_len_shape, device=device, dtype=dtype) + for n in range(2, 9) + ], + layout=torch.jagged, + requires_grad=requires_grad, + ) + + def f(x): + return x.sin() + 1 + + from torch.nested._internal.nested_tensor import nested_from_padded + + @torch.compile(fullgraph=True) + def g(nt): + def _g(nt): + PADDING_VAL = 4.2 + padded = nt.to_padded_tensor(PADDING_VAL) + padded = f(padded) + # NB: sum_S must be specified to use the lowering for dense -> jagged + # and get full fusion + return nested_from_padded( + padded, nt.offsets(), sum_S=nt.values().shape[0] + ) + + # NB: use checkpointing to force fusion + return torch.utils.checkpoint.checkpoint(_g, nt, use_reentrant=False) + + expected_output = f(nt) + if requires_grad: + expected_output.backward(torch.ones_like(expected_output)) + expected_grad = nt.grad.clone().detach() + nt.grad = None + + from torch._inductor.utils import run_and_get_code + + compiled_output, generated_code = run_and_get_code(g, nt) + if requires_grad: + compiled_output.backward(torch.ones_like(compiled_output)) + compiled_grad = nt.grad.clone().detach() + self.assertEqual(compiled_grad, expected_grad, rtol=1e-3, atol=1e-3) + + self.assertEqual(compiled_output, expected_output, rtol=1e-3, atol=1e-3) + + # === Verify that computation fusion happens. === + # Fallback op call -> fusion didn't happen. + fallback_op_calls_present = any( + "torch.ops.aten._padded_dense_to_jagged_forward.default(" + in generated_code[i] + or "torch.ops.aten._jagged_to_padded_dense_forward.default(" + in generated_code[i] + for i in range(len(generated_code)) + ) + + # NB: Fusion isn't supported on CPU. + self.assertEqual("cuda" in device, not fallback_op_calls_present) + + for i in range(len(generated_code)): + # Examine buffer construction lines in the generated code to determine + # whether fusion occurred. If fusion happens, a 3D buffer with shape + # (B, max_seqlen, D) should never be materialized. + buffer_constructions = [ + line.strip() + for line in generated_code[i].split("\n") + if "empty_strided_cuda(" in line + ] + + buffer_dims = [ + # buffer dim == number of elements in the tensor size tuple arg + len(ast.parse(t).body[0].value.args[0].elts) + for t in buffer_constructions + ] + + if "cuda" in device: + self.assertFalse(any(d == 3 for d in buffer_dims)) + + @dtypes(torch.float32) + @skipIfTorchDynamo("Test compiles internally") + @unittest.skipIf( + sys.version_info >= (3, 12), "torch.compile is not supported on python 3.12+" + ) + @unittest.skipIf(IS_WINDOWS, reason="Windows not yet supported for torch.compile") + @skipCUDAIf(not SM70OrLater, "GPU capability is < SM70") + @skipCUDAIfRocm + def test_compile_padded_dense_conversion_preserves_metadata_cache( + self, device, dtype + ): + # shape (B, *, D) + nt = random_nt_from_dims( + [4, None, 3, 16], + device=device, + dtype=dtype, + layout=torch.jagged, + requires_grad=True, + ) + + # expect min / max seqlen to be stored here + cache = dict(nt._metadata_cache) + + @torch.compile + def g(nt): + padded = nt.to_padded_tensor(0.3) + intermediate = padded.sin() + 1 + + from torch.nested._internal.nested_tensor import nested_from_padded + + return nested_from_padded( + intermediate, + nt.offsets(), + min_seqlen=nt._min_seqlen, + max_seqlen=nt._max_seqlen, + sum_S=nt.values().shape[0], + ) + + output = g(nt) + output.backward(torch.ones_like(output)) + self.assertEqual(output._metadata_cache, cache) + FORWARD_FAILURES = { # === BEGIN NotImplementedError SECTION === diff --git a/tools/autograd/derivatives.yaml b/tools/autograd/derivatives.yaml index 9f7ea3fbeb4ff4..f06c4f4bd16908 100644 --- a/tools/autograd/derivatives.yaml +++ b/tools/autograd/derivatives.yaml @@ -2825,9 +2825,14 @@ cpu_nested_shape_example: non_differentiable - name: to_padded_tensor(Tensor self, float padding, SymInt[]? output_size=None) -> Tensor - self: at::_nested_from_padded(grad, self._nested_tensor_size()) + self: "self.layout() == c10::kJagged ? at::_nested_from_padded_tensor_symint(grad, at::_nested_get_offsets(self), at::_nested_get_jagged_dummy(self), at::_nested_get_ragged_idx(self), at::_nested_get_min_seqlen(self).defined() ? std::optional(at::_nested_get_min_seqlen(self)) : ::std::nullopt, at::_nested_get_max_seqlen(self).defined() ? std::optional(at::_nested_get_max_seqlen(self)) : ::std::nullopt, std::optional(at::_nested_get_values(self).sym_size(0))) : at::_nested_from_padded(grad, self._nested_tensor_size())" padding: non_differentiable +- name: _nested_from_padded_tensor(Tensor padded, Tensor offsets, Tensor dummy, int ragged_idx=1, Tensor? min_seqlen=None, Tensor? max_seqlen=None, SymInt? sum_S=None) -> Tensor + padded: grad.to_padded_tensor_symint(0.0, at::OptionalArrayRef(padded.sym_sizes())) + offsets: non_differentiable + dummy: non_differentiable + - name: _nested_view_from_buffer(Tensor(a) self, Tensor nested_size, Tensor nested_strides, Tensor offsets) -> Tensor(a) self: grad.values() nested_size: non_differentiable diff --git a/torch/_dynamo/trace_rules.py b/torch/_dynamo/trace_rules.py index 6609b84885ecd8..122ccea258826e 100644 --- a/torch/_dynamo/trace_rules.py +++ b/torch/_dynamo/trace_rules.py @@ -178,8 +178,9 @@ "torch.nn.Parameter": TorchInGraphFunctionVariable, "torch.nn.Buffer": TorchInGraphFunctionVariable, "torch._nested_tensor_from_mask": SkipFunctionVariable, - "torch._nested_from_padded": SkipFunctionVariable, + "torch.nested._internal.nested_tensor.nested_from_padded": TorchInGraphFunctionVariable, "torch.nested.nested_tensor_from_jagged": UserFunctionVariable, + "torch.nested.nested_tensor_from_padded": UserFunctionVariable, # symbol operators implemented in Python "torch.sym_not": TorchInGraphFunctionVariable, "torch.sym_float": TorchInGraphFunctionVariable, @@ -1525,6 +1526,7 @@ "torch._neg_view_copy", "torch._neg_view", "torch._nested_from_padded_and_nested_example", + "torch._nested_from_padded_tensor", "torch._nested_tensor_from_mask_left_aligned", "torch._nested_tensor_from_tensor_list", "torch._nested_tensor_softmax_with_shape", diff --git a/torch/nested/_internal/nested_tensor.py b/torch/nested/_internal/nested_tensor.py index 6a0425a13d43af..76f8d5a29e2d8f 100644 --- a/torch/nested/_internal/nested_tensor.py +++ b/torch/nested/_internal/nested_tensor.py @@ -562,3 +562,28 @@ def nested_view_from_values_offsets_lengths( min_seqlen_tensor, max_seqlen_tensor, ) # type: ignore[return-value] + + +def nested_from_padded( + padded, offsets, ragged_idx=1, min_seqlen=None, max_seqlen=None, sum_S=None +): + if ragged_idx != 1: + raise RuntimeError("nested_from_padded(): only ragged_idx=1 supported for now") + + min_seqlen_tensor = None + if min_seqlen is not None: + min_seqlen_tensor = _store_val_in_tensor(min_seqlen) + + max_seqlen_tensor = None + if max_seqlen is not None: + max_seqlen_tensor = _store_val_in_tensor(max_seqlen) + + return torch._nested_from_padded_tensor( + padded, + offsets, + _nt_view_dummy(), + ragged_idx, + min_seqlen_tensor, + max_seqlen_tensor, + sum_S, + ) diff --git a/torch/nested/_internal/ops.py b/torch/nested/_internal/ops.py index ed9a54f9dca932..f8053d3bcf7baf 100644 --- a/torch/nested/_internal/ops.py +++ b/torch/nested/_internal/ops.py @@ -1553,6 +1553,98 @@ def all_default(func, *args, **kwargs): return func(inp._values) +@register_jagged_func( + torch.ops.aten.to_padded_tensor.default, "self: jt, padding: any, output_size: any?" +) +def to_padded_tensor_default(func, *args, **kwargs): + _, new_kwargs = normalize_function( # type: ignore[misc] + func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True + ) + + inp = new_kwargs.pop("input") + + # TODO: Handle the rest of output_size + output_size = new_kwargs["output_size"] + if output_size is not None: + max_seq_len = output_size[inp._ragged_idx] + else: + max_seq_len = inp._max_seqlen + + # only 2D values is supported by the underlying FBGEMM kernel so do shape + # gymnastics if needed + values = inp.values() + values_shape = values.shape + if values.dim() > 2: + values = values.flatten(start_dim=1) + elif values.dim() == 1: + values = values.unsqueeze(-1) + + padded_out = torch.ops.aten._jagged_to_padded_dense_forward( + values, + [inp._offsets], + [max_seq_len], + new_kwargs["padding"], + ) + + # shape gymnastics part 2 + if len(values_shape) > 2: + padded_out = padded_out.unflatten(-1, values_shape[1:]) + elif len(values_shape) == 1: + padded_out = padded_out.squeeze(-1) + + return padded_out + + +@register_jagged_func( + torch.ops.aten._nested_from_padded_tensor.default, + "padded: t, offsets: t, dummy: jt, ragged_idx: any?, min_seqlen: any?, max_seqlen: any?, sum_S: any?", +) +def _nested_from_padded_tensor_default(func, *args, **kwargs): + _, new_kwargs = normalize_function( # type: ignore[misc] + func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True + ) + + if new_kwargs["ragged_idx"] != 1: + raise RuntimeError( + "_nested_from_padded_tensor(): only ragged_idx=1 supported for jagged layout" + ) + + padded, offsets = new_kwargs["padded"], new_kwargs["offsets"] + + # non-3D padded is not supported by the underlying FBGEMM kernel so do shape gymnastics + padded_shape = padded.shape + if padded.dim() > 3: + padded = padded.flatten(start_dim=2) + elif padded.dim() < 3: + padded = padded.unsqueeze(-1) + + values = torch.ops.aten._padded_dense_to_jagged_forward( + padded, [offsets], new_kwargs["sum_S"] + ) + + # shape gymnastics part 2 + if len(padded_shape) > 3: + values = values.unflatten(-1, padded_shape[2:]) + elif len(padded_shape) < 3: + values = values.squeeze(-1) + + ragged_idx = new_kwargs["ragged_idx"] + min_seqlen = new_kwargs["min_seqlen"] + max_seqlen = new_kwargs["max_seqlen"] + metadata_cache = {} + if min_seqlen is not None: + metadata_cache["min_seqlen"] = min_seqlen + if max_seqlen is not None: + metadata_cache["max_seqlen"] = max_seqlen + + return NestedTensor( + values, + offsets, + _ragged_idx=ragged_idx, + _metadata_cache=metadata_cache, + ) + + @register_jagged_func( torch.ops.aten._nested_view_from_jagged.default, "values: t, offsets: t, dummy: jt_all, lengths: t?, ragged_idx: any?, min_seqlen: t?, max_seqlen: t?", From ae93540609ddb3b027c9790b2a215eb2e2623c59 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Mon, 9 Sep 2024 20:15:12 +0000 Subject: [PATCH 0644/1018] Revert "[ONNX] Update fake mode usage in onnx docs (#135512)" This reverts commit a13c118994b4f118388d97a35abcb91a396cd437. Reverted https://github.com/pytorch/pytorch/pull/135512 on behalf of https://github.com/davidberard98 due to failing test https://github.com/pytorch/pytorch/actions/runs/10778813316/job/29891679127 ([comment](https://github.com/pytorch/pytorch/pull/135512#issuecomment-2338999090)) --- torch/onnx/_internal/_exporter_legacy.py | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/torch/onnx/_internal/_exporter_legacy.py b/torch/onnx/_internal/_exporter_legacy.py index f16cb09d2b2506..71b828b9bceb7a 100644 --- a/torch/onnx/_internal/_exporter_legacy.py +++ b/torch/onnx/_internal/_exporter_legacy.py @@ -419,15 +419,10 @@ def enable_fake_mode(): ... arg1, ... export_options=export_options ... ) - ... onnx_program.apply_weights(MyModel().state_dict()) >>> # Saving model WITHOUT initializers - >>> onnx_program.save( - ... "my_model_without_initializers.onnx", - ... include_initializers=False, - ... keep_initializers_as_inputs=True, - ... ) + >>> onnx_program.save("my_model_without_initializers.onnx") >>> # Saving model WITH initializers - >>> onnx_program.save("my_model_with_initializers.onnx") + >>> onnx_program.save("my_model_with_initializers.onnx", model_state=MyModel().state_dict()) .. warning:: This API is experimental and is *NOT* backward-compatible. From f5f4a58b16639659a0834fdbde897248ff75fc2d Mon Sep 17 00:00:00 2001 From: Maclyn Brandwein Date: Mon, 9 Sep 2024 20:18:08 +0000 Subject: [PATCH 0645/1018] Fix broken E2E tests on Linux machines (#135394) Summary: I'm not entirely sure why this is failing with an `ImportError` (according to lastnameye a super class of `ModuleNotFoundError`s), but on our E2E tests on Linux machines (but not Macs?), we're seeing the import failure not getting caught -- `ImportError: cannot import name 'parutil' from 'libfb.py' (/data/sandcastle/boxes/eden-trunk-hg-full-fbsource/buck-out/v2/gen/fbsource/d0c916ec8d40ce11/arvr/libraries/ctrl/studies/replay/__ctrl-r__/ctrl-r#link-tree/libfb/py/__init__.py)` from this test run https://www.internalfb.com/sandcastle/workflow/2522015791331601269, an instance of this job: https://www.internalfb.com/intern/test/844425085172858?ref_report_id=0 is the overall job Test Plan: `arc skycastle schedule tools/skycastle/workflows2/ctrl/js_tests.sky:test_js_e2e_replay_tests --sandcastle-spec-overrides '{"type": "fbcode", "unicastle_size": "I1_MEDIUM"}'` -> https://www.internalfb.com/sandcastle/workflow/256705178764255769 Differential Revision: D62321167 Pull Request resolved: https://github.com/pytorch/pytorch/pull/135394 Approved by: https://github.com/laithsakka --- torch/_inductor/config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch/_inductor/config.py b/torch/_inductor/config.py index 80b65fbc31af91..28664d54f4e58a 100644 --- a/torch/_inductor/config.py +++ b/torch/_inductor/config.py @@ -589,7 +589,7 @@ def decide_compile_threads() -> int: ) else: global_cache_dir = parutil.get_dir_path("fb/cache") - except (ValueError, ModuleNotFoundError): + except (ValueError, ImportError): global_cache_dir = None else: From 3713952d8c11539ab5e4088d7253e4550a0e214c Mon Sep 17 00:00:00 2001 From: atalman Date: Mon, 9 Sep 2024 20:48:52 +0000 Subject: [PATCH 0646/1018] Use upload-artifact@v4.4.0 for create_release.yml (#135528) Fixes failure: https://github.com/pytorch/pytorch/actions/runs/10780281005/job/29895846007 Due broken sync ``` actions/upload-artifact@v2 and actions/download-artifact@v4.1.7 ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/135528 Approved by: https://github.com/kit1980, https://github.com/malfet --- .github/workflows/create_release.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/create_release.yml b/.github/workflows/create_release.yml index ab8275ad263f69..43cb2ac5cb2ca3 100644 --- a/.github/workflows/create_release.yml +++ b/.github/workflows/create_release.yml @@ -63,7 +63,7 @@ jobs: files: ${{env.PT_RELEASE_FILE}} - name: Upload source distribution to GHA artifacts for release tags if: ${{ github.event_name == 'push' && startsWith(github.ref, 'refs/tags/v') && contains(github.ref, 'rc') }} - uses: actions/upload-artifact@v2 + uses: actions/upload-artifact@v4.4.0 with: name: ${{ env.PT_RELEASE_FILE }} path: ${{ env.PT_RELEASE_FILE }} From 3592d0baf57941063a9c06b6e2ac5d7f20f874aa Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Mon, 9 Sep 2024 21:55:13 +0000 Subject: [PATCH 0647/1018] Revert "Add `__init__.py` to shape inference folder. (#135461)" This reverts commit dced0d6d9f05f0962f74a3c6227f774111c15715. Reverted https://github.com/pytorch/pytorch/pull/135461 on behalf of https://github.com/huydhn due to Sorry for reverting your change but it exposes some public function without appropriate doc. I will reopen the issue with hi-prio so that it can be fixed properly ([comment](https://github.com/pytorch/pytorch/pull/135461#issuecomment-2339218382)) --- torch/fx/experimental/shape_inference/__init__.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) delete mode 100644 torch/fx/experimental/shape_inference/__init__.py diff --git a/torch/fx/experimental/shape_inference/__init__.py b/torch/fx/experimental/shape_inference/__init__.py deleted file mode 100644 index e69de29bb2d1d6..00000000000000 From 33b8c35571673498652c7c22158c03a81cbb0c0c Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Mon, 9 Sep 2024 22:01:09 +0000 Subject: [PATCH 0648/1018] Revert "NJT <-> padded dense conversions (#125947)" This reverts commit 09a5e88bef04d5485b70d8f65f46a675aaa52942. Reverted https://github.com/pytorch/pytorch/pull/125947 on behalf of https://github.com/huydhn due to Sorry for reverting your change but it is failing dynamo test https://hud.pytorch.org/pytorch/pytorch/commit/09a5e88bef04d5485b70d8f65f46a675aaa52942, maybe a landrace ([comment](https://github.com/pytorch/pytorch/pull/125947#issuecomment-2339228570)) --- aten/src/ATen/native/native_functions.yaml | 5 - ...asDecompTest.test_has_decomposition.expect | 1 - test/test_nestedtensor.py | 176 ------------------ tools/autograd/derivatives.yaml | 7 +- torch/_dynamo/trace_rules.py | 4 +- torch/nested/_internal/nested_tensor.py | 25 --- torch/nested/_internal/ops.py | 92 --------- 7 files changed, 2 insertions(+), 308 deletions(-) diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index 6bcbdc0a6a7954..9342da626beb03 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -14706,11 +14706,6 @@ CUDA: _fbgemm_dense_to_jagged_forward_symint CPU: _padded_dense_to_jagged_forward_cpu -- func: _nested_from_padded_tensor(Tensor padded, Tensor offsets, Tensor dummy, int ragged_idx=1, Tensor? min_seqlen=None, Tensor? max_seqlen=None, SymInt? sum_S=None) -> Tensor - variants: function - device_check: NoCheck - dispatch: {} - - func: _nested_tensor_softmax_with_shape(Tensor self, Tensor query) -> Tensor dispatch: NestedTensorCPU: NestedTensor_softmax_dropout diff --git a/test/expect/HasDecompTest.test_has_decomposition.expect b/test/expect/HasDecompTest.test_has_decomposition.expect index ef1752b5743dde..58d52f34345e6f 100644 --- a/test/expect/HasDecompTest.test_has_decomposition.expect +++ b/test/expect/HasDecompTest.test_has_decomposition.expect @@ -445,7 +445,6 @@ aten::_nested_from_padded aten::_nested_from_padded.out aten::_nested_from_padded_and_nested_example aten::_nested_from_padded_and_nested_example.out -aten::_nested_from_padded_tensor aten::_nested_get_jagged_dummy aten::_nested_get_lengths aten::_nested_get_max_seqlen diff --git a/test/test_nestedtensor.py b/test/test_nestedtensor.py index 82db5d038bee71..49cb422f8720cc 100644 --- a/test/test_nestedtensor.py +++ b/test/test_nestedtensor.py @@ -1,6 +1,5 @@ # Owner(s): ["module: nestedtensor"] -import ast import io import itertools import math @@ -7165,181 +7164,6 @@ def check(nt): check(nt) - @dtypes(torch.float32, torch.double, torch.half) - @parametrize("nt_dim", [2, 3, 4]) - @parametrize("requires_grad", [False, True]) - def test_to_padded_tensor(self, device, dtype, nt_dim, requires_grad): - if nt_dim == 2: - post_seq_len_shape = () - elif nt_dim == 3: - post_seq_len_shape = (10,) - elif nt_dim == 4: - post_seq_len_shape = (9, 10) - - nt = torch.nested.nested_tensor( - [ - torch.randn(n, *post_seq_len_shape, device=device, dtype=dtype) - for n in range(2, 9) - ], - layout=torch.jagged, - requires_grad=requires_grad, - ) - - PADDING_VAL = 4.2 - expected_padded = nt._values.new_full((7, 8, *post_seq_len_shape), PADDING_VAL) - for i, component in enumerate(nt.unbind()): - expected_padded[i, : component.shape[0]].copy_(component) - - padded = nt.to_padded_tensor(PADDING_VAL) - self.assertEqual(expected_padded, padded) - - # convert padded dense -> NJT - from torch.nested._internal.nested_tensor import nested_from_padded - - nt2 = nested_from_padded(padded, nt.offsets()) - self.assertEqual(nt, nt2) - - if requires_grad: - # ensure gradients flow through conversions - nt2.backward(torch.ones_like(nt2)) - self.assertEqual(nt.grad, torch.ones_like(nt)) - - # blows up due to test parametrization otherwise - @torch._dynamo.utils.disable_cache_limit() - @skipIfTorchDynamo("SDPA test compiles internally") - @unittest.skipIf(IS_WINDOWS, reason="Windows not yet supported for torch.compile") - @skipCUDAIf(not SM70OrLater, "GPU capability is < SM70") - @skipCUDAIfRocm - @dtypes(torch.float32, torch.double, torch.half) - @parametrize("nt_dim", [2, 3, 4]) - @parametrize("requires_grad", [False, True]) - def test_to_padded_tensor_compile(self, device, dtype, nt_dim, requires_grad): - if nt_dim == 2: - post_seq_len_shape = () - elif nt_dim == 3: - post_seq_len_shape = (10,) - elif nt_dim == 4: - post_seq_len_shape = (9, 10) - - nt = torch.nested.nested_tensor( - [ - torch.randn(n, *post_seq_len_shape, device=device, dtype=dtype) - for n in range(2, 9) - ], - layout=torch.jagged, - requires_grad=requires_grad, - ) - - def f(x): - return x.sin() + 1 - - from torch.nested._internal.nested_tensor import nested_from_padded - - @torch.compile(fullgraph=True) - def g(nt): - def _g(nt): - PADDING_VAL = 4.2 - padded = nt.to_padded_tensor(PADDING_VAL) - padded = f(padded) - # NB: sum_S must be specified to use the lowering for dense -> jagged - # and get full fusion - return nested_from_padded( - padded, nt.offsets(), sum_S=nt.values().shape[0] - ) - - # NB: use checkpointing to force fusion - return torch.utils.checkpoint.checkpoint(_g, nt, use_reentrant=False) - - expected_output = f(nt) - if requires_grad: - expected_output.backward(torch.ones_like(expected_output)) - expected_grad = nt.grad.clone().detach() - nt.grad = None - - from torch._inductor.utils import run_and_get_code - - compiled_output, generated_code = run_and_get_code(g, nt) - if requires_grad: - compiled_output.backward(torch.ones_like(compiled_output)) - compiled_grad = nt.grad.clone().detach() - self.assertEqual(compiled_grad, expected_grad, rtol=1e-3, atol=1e-3) - - self.assertEqual(compiled_output, expected_output, rtol=1e-3, atol=1e-3) - - # === Verify that computation fusion happens. === - # Fallback op call -> fusion didn't happen. - fallback_op_calls_present = any( - "torch.ops.aten._padded_dense_to_jagged_forward.default(" - in generated_code[i] - or "torch.ops.aten._jagged_to_padded_dense_forward.default(" - in generated_code[i] - for i in range(len(generated_code)) - ) - - # NB: Fusion isn't supported on CPU. - self.assertEqual("cuda" in device, not fallback_op_calls_present) - - for i in range(len(generated_code)): - # Examine buffer construction lines in the generated code to determine - # whether fusion occurred. If fusion happens, a 3D buffer with shape - # (B, max_seqlen, D) should never be materialized. - buffer_constructions = [ - line.strip() - for line in generated_code[i].split("\n") - if "empty_strided_cuda(" in line - ] - - buffer_dims = [ - # buffer dim == number of elements in the tensor size tuple arg - len(ast.parse(t).body[0].value.args[0].elts) - for t in buffer_constructions - ] - - if "cuda" in device: - self.assertFalse(any(d == 3 for d in buffer_dims)) - - @dtypes(torch.float32) - @skipIfTorchDynamo("Test compiles internally") - @unittest.skipIf( - sys.version_info >= (3, 12), "torch.compile is not supported on python 3.12+" - ) - @unittest.skipIf(IS_WINDOWS, reason="Windows not yet supported for torch.compile") - @skipCUDAIf(not SM70OrLater, "GPU capability is < SM70") - @skipCUDAIfRocm - def test_compile_padded_dense_conversion_preserves_metadata_cache( - self, device, dtype - ): - # shape (B, *, D) - nt = random_nt_from_dims( - [4, None, 3, 16], - device=device, - dtype=dtype, - layout=torch.jagged, - requires_grad=True, - ) - - # expect min / max seqlen to be stored here - cache = dict(nt._metadata_cache) - - @torch.compile - def g(nt): - padded = nt.to_padded_tensor(0.3) - intermediate = padded.sin() + 1 - - from torch.nested._internal.nested_tensor import nested_from_padded - - return nested_from_padded( - intermediate, - nt.offsets(), - min_seqlen=nt._min_seqlen, - max_seqlen=nt._max_seqlen, - sum_S=nt.values().shape[0], - ) - - output = g(nt) - output.backward(torch.ones_like(output)) - self.assertEqual(output._metadata_cache, cache) - FORWARD_FAILURES = { # === BEGIN NotImplementedError SECTION === diff --git a/tools/autograd/derivatives.yaml b/tools/autograd/derivatives.yaml index f06c4f4bd16908..9f7ea3fbeb4ff4 100644 --- a/tools/autograd/derivatives.yaml +++ b/tools/autograd/derivatives.yaml @@ -2825,14 +2825,9 @@ cpu_nested_shape_example: non_differentiable - name: to_padded_tensor(Tensor self, float padding, SymInt[]? output_size=None) -> Tensor - self: "self.layout() == c10::kJagged ? at::_nested_from_padded_tensor_symint(grad, at::_nested_get_offsets(self), at::_nested_get_jagged_dummy(self), at::_nested_get_ragged_idx(self), at::_nested_get_min_seqlen(self).defined() ? std::optional(at::_nested_get_min_seqlen(self)) : ::std::nullopt, at::_nested_get_max_seqlen(self).defined() ? std::optional(at::_nested_get_max_seqlen(self)) : ::std::nullopt, std::optional(at::_nested_get_values(self).sym_size(0))) : at::_nested_from_padded(grad, self._nested_tensor_size())" + self: at::_nested_from_padded(grad, self._nested_tensor_size()) padding: non_differentiable -- name: _nested_from_padded_tensor(Tensor padded, Tensor offsets, Tensor dummy, int ragged_idx=1, Tensor? min_seqlen=None, Tensor? max_seqlen=None, SymInt? sum_S=None) -> Tensor - padded: grad.to_padded_tensor_symint(0.0, at::OptionalArrayRef(padded.sym_sizes())) - offsets: non_differentiable - dummy: non_differentiable - - name: _nested_view_from_buffer(Tensor(a) self, Tensor nested_size, Tensor nested_strides, Tensor offsets) -> Tensor(a) self: grad.values() nested_size: non_differentiable diff --git a/torch/_dynamo/trace_rules.py b/torch/_dynamo/trace_rules.py index 122ccea258826e..6609b84885ecd8 100644 --- a/torch/_dynamo/trace_rules.py +++ b/torch/_dynamo/trace_rules.py @@ -178,9 +178,8 @@ "torch.nn.Parameter": TorchInGraphFunctionVariable, "torch.nn.Buffer": TorchInGraphFunctionVariable, "torch._nested_tensor_from_mask": SkipFunctionVariable, - "torch.nested._internal.nested_tensor.nested_from_padded": TorchInGraphFunctionVariable, + "torch._nested_from_padded": SkipFunctionVariable, "torch.nested.nested_tensor_from_jagged": UserFunctionVariable, - "torch.nested.nested_tensor_from_padded": UserFunctionVariable, # symbol operators implemented in Python "torch.sym_not": TorchInGraphFunctionVariable, "torch.sym_float": TorchInGraphFunctionVariable, @@ -1526,7 +1525,6 @@ "torch._neg_view_copy", "torch._neg_view", "torch._nested_from_padded_and_nested_example", - "torch._nested_from_padded_tensor", "torch._nested_tensor_from_mask_left_aligned", "torch._nested_tensor_from_tensor_list", "torch._nested_tensor_softmax_with_shape", diff --git a/torch/nested/_internal/nested_tensor.py b/torch/nested/_internal/nested_tensor.py index 76f8d5a29e2d8f..6a0425a13d43af 100644 --- a/torch/nested/_internal/nested_tensor.py +++ b/torch/nested/_internal/nested_tensor.py @@ -562,28 +562,3 @@ def nested_view_from_values_offsets_lengths( min_seqlen_tensor, max_seqlen_tensor, ) # type: ignore[return-value] - - -def nested_from_padded( - padded, offsets, ragged_idx=1, min_seqlen=None, max_seqlen=None, sum_S=None -): - if ragged_idx != 1: - raise RuntimeError("nested_from_padded(): only ragged_idx=1 supported for now") - - min_seqlen_tensor = None - if min_seqlen is not None: - min_seqlen_tensor = _store_val_in_tensor(min_seqlen) - - max_seqlen_tensor = None - if max_seqlen is not None: - max_seqlen_tensor = _store_val_in_tensor(max_seqlen) - - return torch._nested_from_padded_tensor( - padded, - offsets, - _nt_view_dummy(), - ragged_idx, - min_seqlen_tensor, - max_seqlen_tensor, - sum_S, - ) diff --git a/torch/nested/_internal/ops.py b/torch/nested/_internal/ops.py index f8053d3bcf7baf..ed9a54f9dca932 100644 --- a/torch/nested/_internal/ops.py +++ b/torch/nested/_internal/ops.py @@ -1553,98 +1553,6 @@ def all_default(func, *args, **kwargs): return func(inp._values) -@register_jagged_func( - torch.ops.aten.to_padded_tensor.default, "self: jt, padding: any, output_size: any?" -) -def to_padded_tensor_default(func, *args, **kwargs): - _, new_kwargs = normalize_function( # type: ignore[misc] - func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True - ) - - inp = new_kwargs.pop("input") - - # TODO: Handle the rest of output_size - output_size = new_kwargs["output_size"] - if output_size is not None: - max_seq_len = output_size[inp._ragged_idx] - else: - max_seq_len = inp._max_seqlen - - # only 2D values is supported by the underlying FBGEMM kernel so do shape - # gymnastics if needed - values = inp.values() - values_shape = values.shape - if values.dim() > 2: - values = values.flatten(start_dim=1) - elif values.dim() == 1: - values = values.unsqueeze(-1) - - padded_out = torch.ops.aten._jagged_to_padded_dense_forward( - values, - [inp._offsets], - [max_seq_len], - new_kwargs["padding"], - ) - - # shape gymnastics part 2 - if len(values_shape) > 2: - padded_out = padded_out.unflatten(-1, values_shape[1:]) - elif len(values_shape) == 1: - padded_out = padded_out.squeeze(-1) - - return padded_out - - -@register_jagged_func( - torch.ops.aten._nested_from_padded_tensor.default, - "padded: t, offsets: t, dummy: jt, ragged_idx: any?, min_seqlen: any?, max_seqlen: any?, sum_S: any?", -) -def _nested_from_padded_tensor_default(func, *args, **kwargs): - _, new_kwargs = normalize_function( # type: ignore[misc] - func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True - ) - - if new_kwargs["ragged_idx"] != 1: - raise RuntimeError( - "_nested_from_padded_tensor(): only ragged_idx=1 supported for jagged layout" - ) - - padded, offsets = new_kwargs["padded"], new_kwargs["offsets"] - - # non-3D padded is not supported by the underlying FBGEMM kernel so do shape gymnastics - padded_shape = padded.shape - if padded.dim() > 3: - padded = padded.flatten(start_dim=2) - elif padded.dim() < 3: - padded = padded.unsqueeze(-1) - - values = torch.ops.aten._padded_dense_to_jagged_forward( - padded, [offsets], new_kwargs["sum_S"] - ) - - # shape gymnastics part 2 - if len(padded_shape) > 3: - values = values.unflatten(-1, padded_shape[2:]) - elif len(padded_shape) < 3: - values = values.squeeze(-1) - - ragged_idx = new_kwargs["ragged_idx"] - min_seqlen = new_kwargs["min_seqlen"] - max_seqlen = new_kwargs["max_seqlen"] - metadata_cache = {} - if min_seqlen is not None: - metadata_cache["min_seqlen"] = min_seqlen - if max_seqlen is not None: - metadata_cache["max_seqlen"] = max_seqlen - - return NestedTensor( - values, - offsets, - _ragged_idx=ragged_idx, - _metadata_cache=metadata_cache, - ) - - @register_jagged_func( torch.ops.aten._nested_view_from_jagged.default, "values: t, offsets: t, dummy: jt_all, lengths: t?, ragged_idx: any?, min_seqlen: t?, max_seqlen: t?", From 1b6c03e0eb0567c8f265118800008c7f7eeaeb2d Mon Sep 17 00:00:00 2001 From: shubhambhokare1 Date: Mon, 9 Sep 2024 22:28:38 +0000 Subject: [PATCH 0649/1018] [ONNX] Drop final None values as inputs for nodes in exporter graph (#135520) When value for an optional input is not provided, it is defaulted to `None`, which gets translates to "" in the onnx graph. To avoid this, if we have a list of inputs and the final few are all `None`, strip them in the graph. Pull Request resolved: https://github.com/pytorch/pytorch/pull/135520 Approved by: https://github.com/justinchuby Co-authored-by: Justin Chu --- torch/onnx/_internal/exporter/_building.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/torch/onnx/_internal/exporter/_building.py b/torch/onnx/_internal/exporter/_building.py index 3e5b48394c6579..7729aa000b281e 100644 --- a/torch/onnx/_internal/exporter/_building.py +++ b/torch/onnx/_internal/exporter/_building.py @@ -465,6 +465,12 @@ def _construct_node( else: inputs.append(value) + # If final inputs are None, strip them from the node inputs + for input in reversed(inputs): + if input is not None: + break + inputs.pop() + # Construct and filter out None attributes attributes = [ attr From cb191af8853838f6d80c5dcf57eb5bb2a22e787e Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Mon, 9 Sep 2024 11:59:23 -0700 Subject: [PATCH 0650/1018] [Dynamo] Adding CallFunctionNoArgsSource and (#135425) CallFunctionNoArgsGuardAccessor to support torch.cuda.current_device() Pull Request resolved: https://github.com/pytorch/pytorch/pull/135425 Approved by: https://github.com/anijain2305 --- .../aot_eager_torchbench_training.csv | 2 +- .../dynamic_aot_eager_torchbench_training.csv | 2 +- .../dynamic_inductor_torchbench_training.csv | 2 +- .../dynamo_eager_torchbench_training.csv | 2 +- .../inductor_torchbench_training.csv | 2 +- test/dynamo/test_functions.py | 21 +++++ test/dynamo/test_guard_manager.py | 9 ++ .../TestTorch.test_cuda_not_built | 0 torch/_dynamo/guards.py | 10 ++- torch/_dynamo/source.py | 5 ++ torch/_dynamo/trace_rules.py | 2 +- torch/_dynamo/variables/torch.py | 13 ++- torch/csrc/dynamo/guards.cpp | 89 +++++++++++++++++++ 13 files changed, 150 insertions(+), 9 deletions(-) delete mode 100644 test/dynamo_expected_failures/TestTorch.test_cuda_not_built diff --git a/benchmarks/dynamo/ci_expected_accuracy/aot_eager_torchbench_training.csv b/benchmarks/dynamo/ci_expected_accuracy/aot_eager_torchbench_training.csv index 223aae9b97234b..914849fa010cc1 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/aot_eager_torchbench_training.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/aot_eager_torchbench_training.csv @@ -110,7 +110,7 @@ hf_GPT2_large,pass_due_to_skip,0 -hf_Reformer,pass,25 +hf_Reformer,pass,23 diff --git a/benchmarks/dynamo/ci_expected_accuracy/dynamic_aot_eager_torchbench_training.csv b/benchmarks/dynamo/ci_expected_accuracy/dynamic_aot_eager_torchbench_training.csv index 47894a2b48c372..5d2fe7d197768e 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/dynamic_aot_eager_torchbench_training.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/dynamic_aot_eager_torchbench_training.csv @@ -110,7 +110,7 @@ hf_GPT2_large,pass_due_to_skip,0 -hf_Reformer,pass,25 +hf_Reformer,pass,23 diff --git a/benchmarks/dynamo/ci_expected_accuracy/dynamic_inductor_torchbench_training.csv b/benchmarks/dynamo/ci_expected_accuracy/dynamic_inductor_torchbench_training.csv index 3d3548d3ac5aee..ab99edec8b4ecf 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/dynamic_inductor_torchbench_training.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/dynamic_inductor_torchbench_training.csv @@ -110,7 +110,7 @@ hf_GPT2_large,pass_due_to_skip,0 -hf_Reformer,pass,25 +hf_Reformer,pass,23 diff --git a/benchmarks/dynamo/ci_expected_accuracy/dynamo_eager_torchbench_training.csv b/benchmarks/dynamo/ci_expected_accuracy/dynamo_eager_torchbench_training.csv index 223aae9b97234b..914849fa010cc1 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/dynamo_eager_torchbench_training.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/dynamo_eager_torchbench_training.csv @@ -110,7 +110,7 @@ hf_GPT2_large,pass_due_to_skip,0 -hf_Reformer,pass,25 +hf_Reformer,pass,23 diff --git a/benchmarks/dynamo/ci_expected_accuracy/inductor_torchbench_training.csv b/benchmarks/dynamo/ci_expected_accuracy/inductor_torchbench_training.csv index 223aae9b97234b..914849fa010cc1 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/inductor_torchbench_training.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/inductor_torchbench_training.csv @@ -110,7 +110,7 @@ hf_GPT2_large,pass_due_to_skip,0 -hf_Reformer,pass,25 +hf_Reformer,pass,23 diff --git a/test/dynamo/test_functions.py b/test/dynamo/test_functions.py index e936e66f23f8ab..c99948a902dc3b 100644 --- a/test/dynamo/test_functions.py +++ b/test/dynamo/test_functions.py @@ -28,6 +28,7 @@ from torch._dynamo.variables import ConstantVariable from torch._dynamo.variables.lists import RangeVariable from torch.nn import functional as F +from torch.testing._internal.common_cuda import TEST_MULTIGPU from torch.testing._internal.common_utils import ( disable_translation_validation_if_dynamic_shapes, instantiate_parametrized_tests, @@ -3888,6 +3889,26 @@ def fn(x, ys, zs): with self.assertRaisesRegex(ValueError, "zip()"): opt_fn(x, ys, zs[:1]) + @unittest.skipIf(not TEST_MULTIGPU, "detected only one GPU") + def test_cuda_current_device(self): + def fn(x): + y = torch.empty( + (2, 3), dtype=torch.float32, device=torch.cuda.current_device() + ) + y.copy_(x) + return torch.sin(y + y.device.index) + + counter = torch._dynamo.testing.CompileCounter() + opt_fn = torch.compile(backend=counter, fullgraph=True)(fn) + + with torch.cuda.device(0): + x = torch.randn(2, 3) + self.assertEqual(opt_fn(x), fn(x)) + self.assertEqual(counter.frame_count, 1) + with torch.cuda.device(1): + self.assertEqual(opt_fn(x), fn(x)) + self.assertEqual(counter.frame_count, 2) + def test_fn_with_attr(self): def fn(x): if fn.pred: diff --git a/test/dynamo/test_guard_manager.py b/test/dynamo/test_guard_manager.py index d7e85014ef3657..45f8cbe9690308 100644 --- a/test/dynamo/test_guard_manager.py +++ b/test/dynamo/test_guard_manager.py @@ -1,5 +1,6 @@ # Owner(s): ["module: dynamo"] import functools +import unittest import weakref import torch @@ -373,6 +374,14 @@ def test_weakref_alive_guard(self): del x self.assertFalse(guard(weakref_x())) + @unittest.skipIf(not torch.cuda.is_available(), "requires cuda") + def test_call_function_no_args_guard(self): + x = torch.cuda.current_device() + guard = guards.EQUALS_MATCH(x, [0]) + self.assertTrue(guard(0)) + self.assertFalse(guard(1)) + self.assertFalse(guard(2)) + def test_guard_manager_leaf_guard(self): guard_manager = RootGuardManager() guard_manager.add_type_match_guard(id_type(5), ["type(x) == int"]) diff --git a/test/dynamo_expected_failures/TestTorch.test_cuda_not_built b/test/dynamo_expected_failures/TestTorch.test_cuda_not_built deleted file mode 100644 index e69de29bb2d1d6..00000000000000 diff --git a/torch/_dynamo/guards.py b/torch/_dynamo/guards.py index 1923426bf060f8..b0bd273e96bbb5 100644 --- a/torch/_dynamo/guards.py +++ b/torch/_dynamo/guards.py @@ -79,6 +79,7 @@ from .source import ( AttrProxySource, AttrSource, + CallFunctionNoArgsSource, ChainedSource, ConstDictKeySource, DefaultsSource, @@ -1093,13 +1094,20 @@ def get_guard_manager_from_source(self, source): example_value=example_value, guard_manager_enum=guard_manager_enum, ) - elif isinstance(source, WeakRefCallSource): + elif istype(source, WeakRefCallSource): assert base_guard_manager # to make mypy happy out = base_guard_manager.weakref_call_manager( source=source_name, example_value=example_value, guard_manager_enum=guard_manager_enum, ) + elif istype(source, CallFunctionNoArgsSource): + assert base_guard_manager # to make mypy happy + out = base_guard_manager.call_function_no_args_manager( + source=source_name, + example_value=example_value, + guard_manager_enum=guard_manager_enum, + ) else: raise AssertionError( f"missing guard manager builder {source} - {source.name()}" diff --git a/torch/_dynamo/source.py b/torch/_dynamo/source.py index 2d3a4424167da9..860e7d3429f487 100644 --- a/torch/_dynamo/source.py +++ b/torch/_dynamo/source.py @@ -185,6 +185,11 @@ def name(self): return f"{self.base.name()}()" +@dataclasses.dataclass(frozen=True) +class CallFunctionNoArgsSource(WeakRefCallSource): + pass + + @dataclasses.dataclass(frozen=True) class AttrSource(ChainedSource): member: str diff --git a/torch/_dynamo/trace_rules.py b/torch/_dynamo/trace_rules.py index 6609b84885ecd8..e1ee77fe63a085 100644 --- a/torch/_dynamo/trace_rules.py +++ b/torch/_dynamo/trace_rules.py @@ -191,7 +191,7 @@ "torch.Tensor#_make_wrapper_subclass": SkipFunctionVariable, "torch.Tensor#__init__": SkipFunctionVariable, "torch.cuda.set_device": SkipFunctionVariable, - "torch.cuda.current_device": SkipFunctionVariable, + "torch.cuda.current_device": TorchInGraphFunctionVariable, "torch._C.autocast_decrement_nesting": SkipFunctionVariable, "torch._C.autocast_increment_nesting": SkipFunctionVariable, "torch.autograd.grad": SkipFunctionVariable, diff --git a/torch/_dynamo/variables/torch.py b/torch/_dynamo/variables/torch.py index 3dfda1cfec83da..b32ac9af6ccfd9 100644 --- a/torch/_dynamo/variables/torch.py +++ b/torch/_dynamo/variables/torch.py @@ -27,7 +27,7 @@ from ..device_interface import get_registered_device_interfaces from ..exc import unimplemented from ..guards import GuardBuilder, install_guard -from ..source import SyntheticLocalSource +from ..source import CallFunctionNoArgsSource, SyntheticLocalSource from ..utils import ( check_unspec_or_constant_args, guard_if_dyn, @@ -100,6 +100,10 @@ ] ) +constant_fold_functions_need_guards = [ + torch.cuda.current_device, +] + constant_fold_functions = [ torch._assert, torch._utils._get_device_index, @@ -120,7 +124,7 @@ torch.promote_types, torch._C._get_privateuse1_backend_name, torch.autograd._is_checkpoint_valid, -] +] + constant_fold_functions_need_guards if torch.distributed.is_available(): constant_fold_functions.extend( [ @@ -130,6 +134,7 @@ ] ) # Convert to dict for O(1) access times +constant_fold_functions_need_guards = dict.fromkeys(constant_fold_functions_need_guards) constant_fold_functions = dict.fromkeys(constant_fold_functions) @@ -836,6 +841,10 @@ def call_function( if self.can_constant_fold_through() and check_unspec_or_constant_args( args, kwargs ): + # constant fold functions need to be guarded. + if self.value in constant_fold_functions_need_guards: + source = CallFunctionNoArgsSource(self.source) + install_guard(source.make_guard(GuardBuilder.EQUALS_MATCH)) # constant fold return ConstantVariable.create( self.as_python_constant()( diff --git a/torch/csrc/dynamo/guards.cpp b/torch/csrc/dynamo/guards.cpp index 0c2cf51a2cbbbb..afc3182c632640 100644 --- a/torch/csrc/dynamo/guards.cpp +++ b/torch/csrc/dynamo/guards.cpp @@ -3438,6 +3438,69 @@ class WeakRefCallGuardAccessor : public GuardAccessor { } }; +/** + * Implements function call no args - e.g, torch.cuda.current_device() + */ +class CallFunctionNoArgsGuardAccessor : public GuardAccessor { + public: + CallFunctionNoArgsGuardAccessor( + RootGuardManager* root, + py::str name, + std::string source, + py::handle example_value, + py::handle guard_manager_enum) + : GuardAccessor( + root, + std::move(name), + std::move(source), + example_value, + guard_manager_enum) {} + + // NB: Intentional duplication between check_nopybind and + // check_verbose_nopybind. + bool check_nopybind(PyObject* obj, bool matches_dict_tag = false) + override { // borrowed ref + if (!PyCallable_Check(obj)) { + return false; + } + + PyObject* x = PyObject_CallNoArgs(obj); + if (x == nullptr) { + // Call failed, clear the exception and return false. + PyErr_Clear(); + return false; + } + + bool result = _guard_manager->check_nopybind(x); + Py_DECREF(x); + return result; + } + + GuardDebugInfo check_verbose_nopybind( + PyObject* obj) override { // borrowed ref + if (!PyCallable_Check(obj)) { + return GuardDebugInfo( + false, std::string("Not a callable obj ") + get_source(), 0); + } + + PyObject* x = PyObject_CallNoArgs(obj); + if (x == nullptr) { + // Call failed, clear the exception and return debug info. + std::string exc_message = get_exception_message(); + PyErr_Clear(); + return GuardDebugInfo(false, exc_message, 0); + } + + GuardDebugInfo result = _guard_manager->check_verbose_nopybind(x); + Py_DECREF(x); + return result; + } + + std::string repr() const override { + return "CallFunctionNoArgsGuardAccessor()"; + } +}; + /** * Similar to PythonLambdaLeafGuard, this class is a way to allow developers to * supply accessor as a python function. This is useful for from_numpy source. @@ -3782,6 +3845,12 @@ PyObject* torch_c_dynamo_guards_init() { std::unique_ptr>( py_m, "WeakRefCallGuardAccessor"); // NOLINTNEXTLINE(bugprone-unused-raii) + py::class_< + CallFunctionNoArgsGuardAccessor, + GuardAccessor, + std::unique_ptr>( + py_m, "CallFunctionNoArgsGuardAccessor"); + // NOLINTNEXTLINE(bugprone-unused-raii) py::class_< TupleIteratorGetItemAccessor, GuardAccessor, @@ -4101,6 +4170,26 @@ PyObject* torch_c_dynamo_guards_init() { py::return_value_policy::reference) // return by reference because GuardManager has the ownership of accessors // and guard managers + .def( + "call_function_no_args_manager", + [](GuardManager& self, + std::string source, + py::handle example_value, + py::handle guard_manager_enum) -> GuardManager* { + // A unique key is used to save as the accessor key. + py::str unique_key("__call_function_no_args_accessor__"); + return self.get_child_manager( + std::move(unique_key), + std::move(source), + example_value, + guard_manager_enum); + }, + py::arg("source"), + py::arg("example_value"), + py::arg("guard_manager_enum"), + py::return_value_policy::reference) + // return by reference because GuardManager has the ownership of accessors + // and guard managers .def( "tuple_iterator_getitem_manager", &GuardManager::get_child_manager, From 47e3d94ea0e2a528a3f4e4659b852156e971fe2e Mon Sep 17 00:00:00 2001 From: Scott Wolchok Date: Mon, 9 Sep 2024 13:17:00 -0700 Subject: [PATCH 0651/1018] [xplat][XNNPACK] don't prefer static linkage in xplat for main target (#135529) Building XNNPACK as a static library has some issues because of multiple global params floating around. Let's try to get rid of it in xplat and see how it fares. Differential Revision: [D60776152](https://our.internmc.facebook.com/intern/diff/D60776152/) **NOTE FOR REVIEWERS**: This PR has internal Meta-specific changes or comments, please review them on [Phabricator](https://our.internmc.facebook.com/intern/diff/D60776152/)! Pull Request resolved: https://github.com/pytorch/pytorch/pull/135529 Approved by: https://github.com/kimishpatel, https://github.com/mcr229, https://github.com/kirklandsign --- third_party/xnnpack.buck.bzl | 455 +++++++++-------------------------- 1 file changed, 113 insertions(+), 342 deletions(-) diff --git a/third_party/xnnpack.buck.bzl b/third_party/xnnpack.buck.bzl index 6ce938b28ad685..1ca8e776de683b 100644 --- a/third_party/xnnpack.buck.bzl +++ b/third_party/xnnpack.buck.bzl @@ -19,8 +19,8 @@ load( "PROD_AVX512F_MICROKERNEL_SRCS", "PROD_AVX512SKX_MICROKERNEL_SRCS", "PROD_AVX512VBMI_MICROKERNEL_SRCS", - "PROD_AVX512VNNI_MICROKERNEL_SRCS", "PROD_AVX512VNNIGFNI_MICROKERNEL_SRCS", + "PROD_AVX512VNNI_MICROKERNEL_SRCS", "PROD_AVXVNNI_MICROKERNEL_SRCS", "PROD_AVX_MICROKERNEL_SRCS", "PROD_F16C_MICROKERNEL_SRCS", @@ -46,6 +46,12 @@ load( "PROD_XOP_MICROKERNEL_SRCS", ) +XNN_COMMON_PREPROCESSOR_FLAGS = [ + "-DXNN_PRIVATE=", + "-DXNN_INTERNAL=", + "-DXNN_LOG_LEVEL=0" +] + # This defines XNNPACK targets for both fbsource BUCK and OSS BUCK # Note that the file path is relative to the BUCK file that called from, not to this bzl file. # So for fbsource build it points to xplat/third-party/XNNPACK/XNNPACK, @@ -78,9 +84,7 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F }, apple_sdks = (IOS, MACOSX, APPLETVOS), labels = labels, - preprocessor_flags = [ - "-DXNN_LOG_LEVEL=0", - ], + preprocessor_flags = XNN_COMMON_PREPROCESSOR_FLAGS, visibility = ["PUBLIC"], exported_deps = [ # Dependency only on pthreadpool interface @@ -99,14 +103,10 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F compiler_flags = [ "-O2", ], - fbobjc_preprocessor_flags = [ - "-DXNN_PRIVATE=", - "-DXNN_INTERNAL=", - ], labels = labels, + fbandroid_link_whole = True, preferred_linkage = "static", - preprocessor_flags = [ - "-DXNN_LOG_LEVEL=0", + preprocessor_flags = XNN_COMMON_PREPROCESSOR_FLAGS + [ "-DXNN_ENABLE_SPARSE=0", "-DXNN_ENABLE_MEMOPT", ], @@ -132,15 +132,10 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F compiler_flags = [ "-O2", ], - fbobjc_preprocessor_flags = [ - "-DXNN_PRIVATE=", - "-DXNN_INTERNAL=", - ], labels = labels, + fbandroid_link_whole = True, preferred_linkage = "static", - preprocessor_flags = [ - "-DXNN_LOG_LEVEL=0", - ], + preprocessor_flags = XNN_COMMON_PREPROCESSOR_FLAGS, visibility = ["PUBLIC"], windows_clang_compiler_flags_override = WINDOWS_FLAGS + WINDOWS_CLANG_COMPILER_FLAGS, windows_compiler_flags_override = WINDOWS_FLAGS, @@ -167,15 +162,10 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F "-fno-math-errno", "-ffp-contract=off", ], - fbobjc_preprocessor_flags = [ - "-DXNN_PRIVATE=", - "-DXNN_INTERNAL=", - ], labels = labels, + fbandroid_link_whole = True, preferred_linkage = "static", - preprocessor_flags = [ - "-DXNN_LOG_LEVEL=0", - ], + preprocessor_flags = XNN_COMMON_PREPROCESSOR_FLAGS, visibility = ["PUBLIC"], windows_clang_compiler_flags_override = WINDOWS_FLAGS + WINDOWS_CLANG_COMPILER_FLAGS, windows_compiler_flags_override = WINDOWS_FLAGS, @@ -198,10 +188,6 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F compiler_flags = [ "-O2", ], - fbobjc_preprocessor_flags = [ - "-DXNN_PRIVATE=", - "-DXNN_INTERNAL=", - ], labels = labels, platform_compiler_flags = [ ( @@ -217,10 +203,9 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F PROD_SSE_MICROKERNEL_SRCS, ), ] if not is_arvr_mode() else []), + fbandroid_link_whole = True, preferred_linkage = "static", - preprocessor_flags = [ - "-DXNN_LOG_LEVEL=0", - ], + preprocessor_flags = XNN_COMMON_PREPROCESSOR_FLAGS, visibility = ["PUBLIC"], windows_clang_compiler_flags_override = WINDOWS_FLAGS + WINDOWS_CLANG_COMPILER_FLAGS + ["-msse"], windows_compiler_flags_override = WINDOWS_FLAGS + ["-msse"], @@ -240,10 +225,6 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F compiler_flags = [ "-O2", ], - fbobjc_preprocessor_flags = [ - "-DXNN_PRIVATE=", - "-DXNN_INTERNAL=", - ], labels = labels, platform_compiler_flags = [ ( @@ -253,10 +234,9 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F ], ), ], + fbandroid_link_whole = True, preferred_linkage = "static", - preprocessor_flags = [ - "-DXNN_LOG_LEVEL=0", - ], + preprocessor_flags = XNN_COMMON_PREPROCESSOR_FLAGS, visibility = ["PUBLIC"], windows_clang_compiler_flags_override = WINDOWS_FLAGS + WINDOWS_CLANG_COMPILER_FLAGS + ["-msse"], windows_compiler_flags_override = WINDOWS_FLAGS + ["-msse"], @@ -278,10 +258,6 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F compiler_flags = [ "-O2", ], - fbobjc_preprocessor_flags = [ - "-DXNN_PRIVATE=", - "-DXNN_INTERNAL=", - ], labels = labels, platform_compiler_flags = [ ( @@ -297,10 +273,9 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F PROD_SSE2_MICROKERNEL_SRCS, ), ] if not is_arvr_mode() else []), + fbandroid_link_whole = True, preferred_linkage = "static", - preprocessor_flags = [ - "-DXNN_LOG_LEVEL=0", - ], + preprocessor_flags = XNN_COMMON_PREPROCESSOR_FLAGS, visibility = ["PUBLIC"], windows_clang_compiler_flags_override = WINDOWS_FLAGS + WINDOWS_CLANG_COMPILER_FLAGS + ["-msse2"], windows_compiler_flags_override = WINDOWS_FLAGS + ["-msse2"], @@ -321,10 +296,6 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F compiler_flags = [ "-O2", ], - fbobjc_preprocessor_flags = [ - "-DXNN_PRIVATE=", - "-DXNN_INTERNAL=", - ], labels = labels, platform_compiler_flags = [ ( @@ -334,10 +305,9 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F ], ), ], + fbandroid_link_whole = True, preferred_linkage = "static", - preprocessor_flags = [ - "-DXNN_LOG_LEVEL=0", - ], + preprocessor_flags = XNN_COMMON_PREPROCESSOR_FLAGS, visibility = ["PUBLIC"], windows_clang_compiler_flags_override = WINDOWS_FLAGS + WINDOWS_CLANG_COMPILER_FLAGS + ["-msse2"], windows_compiler_flags_override = WINDOWS_FLAGS + ["-msse2"], @@ -360,10 +330,6 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F compiler_flags = [ "-O2", ], - fbobjc_preprocessor_flags = [ - "-DXNN_PRIVATE=", - "-DXNN_INTERNAL=", - ], labels = labels, platform_compiler_flags = [ ( @@ -379,10 +345,9 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F PROD_SSSE3_MICROKERNEL_SRCS, ), ] if not is_arvr_mode() else []), + fbandroid_link_whole = True, preferred_linkage = "static", - preprocessor_flags = [ - "-DXNN_LOG_LEVEL=0", - ], + preprocessor_flags = XNN_COMMON_PREPROCESSOR_FLAGS, visibility = ["PUBLIC"], windows_clang_compiler_flags_override = WINDOWS_FLAGS + WINDOWS_CLANG_COMPILER_FLAGS + ["-mssse3"], windows_compiler_flags_override = WINDOWS_FLAGS + ["-mssse3"], @@ -403,10 +368,6 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F compiler_flags = [ "-O2", ], - fbobjc_preprocessor_flags = [ - "-DXNN_PRIVATE=", - "-DXNN_INTERNAL=", - ], labels = labels, platform_compiler_flags = [ ( @@ -416,10 +377,9 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F ], ), ], + fbandroid_link_whole = True, preferred_linkage = "static", - preprocessor_flags = [ - "-DXNN_LOG_LEVEL=0", - ], + preprocessor_flags = XNN_COMMON_PREPROCESSOR_FLAGS, visibility = ["PUBLIC"], windows_clang_compiler_flags_override = WINDOWS_FLAGS + WINDOWS_CLANG_COMPILER_FLAGS + ["-mssse3"], windows_compiler_flags_override = WINDOWS_FLAGS + ["-mssse3"], @@ -442,10 +402,6 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F compiler_flags = [ "-O2", ], - fbobjc_preprocessor_flags = [ - "-DXNN_PRIVATE=", - "-DXNN_INTERNAL=", - ], labels = labels, platform_compiler_flags = [ ( @@ -461,10 +417,9 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F PROD_SSE41_MICROKERNEL_SRCS, ), ] if not is_arvr_mode() else []), + fbandroid_link_whole = True, preferred_linkage = "static", - preprocessor_flags = [ - "-DXNN_LOG_LEVEL=0", - ], + preprocessor_flags = XNN_COMMON_PREPROCESSOR_FLAGS, visibility = ["PUBLIC"], windows_clang_compiler_flags_override = WINDOWS_FLAGS + WINDOWS_CLANG_COMPILER_FLAGS + ["-msse4.1"], windows_compiler_flags_override = WINDOWS_FLAGS + ["-msse4.1"], @@ -485,10 +440,6 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F compiler_flags = [ "-O2", ], - fbobjc_preprocessor_flags = [ - "-DXNN_PRIVATE=", - "-DXNN_INTERNAL=", - ], labels = labels, platform_compiler_flags = [ ( @@ -498,10 +449,9 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F ], ), ], + fbandroid_link_whole = True, preferred_linkage = "static", - preprocessor_flags = [ - "-DXNN_LOG_LEVEL=0", - ], + preprocessor_flags = XNN_COMMON_PREPROCESSOR_FLAGS, visibility = ["PUBLIC"], windows_clang_compiler_flags_override = WINDOWS_FLAGS + WINDOWS_CLANG_COMPILER_FLAGS + ["-msse4.1"], windows_compiler_flags_override = WINDOWS_FLAGS + ["-msse4.1"], @@ -532,10 +482,6 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F "-mavx", ], }), - fbobjc_preprocessor_flags = [ - "-DXNN_PRIVATE=", - "-DXNN_INTERNAL=", - ], labels = labels, platform_compiler_flags = [ ( @@ -551,10 +497,9 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F PROD_AVX_MICROKERNEL_SRCS, ), ] if not is_arvr_mode() else []), + fbandroid_link_whole = True, preferred_linkage = "static", - preprocessor_flags = [ - "-DXNN_LOG_LEVEL=0", - ], + preprocessor_flags = XNN_COMMON_PREPROCESSOR_FLAGS, visibility = ["PUBLIC"], windows_clang_compiler_flags_override = WINDOWS_FLAGS + WINDOWS_CLANG_COMPILER_FLAGS + ["-mavx"], windows_compiler_flags_override = WINDOWS_FLAGS + ["-mavx"], @@ -575,10 +520,6 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F "-O2", "-mavx", ], - fbobjc_preprocessor_flags = [ - "-DXNN_PRIVATE=", - "-DXNN_INTERNAL=", - ], labels = labels, platform_compiler_flags = [ ( @@ -588,10 +529,9 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F ], ), ], + fbandroid_link_whole = True, preferred_linkage = "static", - preprocessor_flags = [ - "-DXNN_LOG_LEVEL=0", - ], + preprocessor_flags = XNN_COMMON_PREPROCESSOR_FLAGS, visibility = ["PUBLIC"], windows_clang_compiler_flags_override = WINDOWS_FLAGS + WINDOWS_CLANG_COMPILER_FLAGS + ["-mavx"], windows_compiler_flags_override = WINDOWS_FLAGS + ["-mavx"], @@ -631,10 +571,6 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F "-mavx512dq", ], }), - fbobjc_preprocessor_flags = [ - "-DXNN_PRIVATE=", - "-DXNN_INTERNAL=", - ], labels = labels, platform_compiler_flags = [ ( @@ -657,9 +593,7 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F ), ] if not is_arvr_mode() else []), preferred_linkage = "static", - preprocessor_flags = [ - "-DXNN_LOG_LEVEL=0", - ], + preprocessor_flags = XNN_COMMON_PREPROCESSOR_FLAGS, visibility = ["PUBLIC"], windows_clang_compiler_flags_override = WINDOWS_FLAGS + WINDOWS_CLANG_COMPILER_FLAGS + ["-mavx"], windows_compiler_flags_override = WINDOWS_FLAGS + ["-mavx"], @@ -679,10 +613,6 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F compiler_flags = [ "-O2", ], - fbobjc_preprocessor_flags = [ - "-DXNN_PRIVATE=", - "-DXNN_INTERNAL=", - ], labels = labels, platform_compiler_flags = [ ( @@ -699,9 +629,7 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F ), ], preferred_linkage = "static", - preprocessor_flags = [ - "-DXNN_LOG_LEVEL=0", - ], + preprocessor_flags = XNN_COMMON_PREPROCESSOR_FLAGS, visibility = ["PUBLIC"], windows_clang_compiler_flags_override = WINDOWS_FLAGS + WINDOWS_CLANG_COMPILER_FLAGS + ["-mavx"], windows_compiler_flags_override = WINDOWS_FLAGS + ["-mavx"], @@ -725,16 +653,22 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F ] + select({ "DEFAULT": [], "ovr_config//cpu:x86_32": [ - "-mavx", + "-mavx512f", + "-mavx512cd", + "-mavx512bw", + "-mavx512dq", + "-mavx512vl", + "-mavx512vnni", ], "ovr_config//cpu:x86_64": [ - "-mavx", + "-mavx512f", + "-mavx512cd", + "-mavx512bw", + "-mavx512dq", + "-mavx512vl", + "-mavx512vnni", ], }), - fbobjc_preprocessor_flags = [ - "-DXNN_PRIVATE=", - "-DXNN_INTERNAL=", - ], labels = labels, platform_compiler_flags = [ ( @@ -756,9 +690,7 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F ), ] if not is_arvr_mode() else []), preferred_linkage = "static", - preprocessor_flags = [ - "-DXNN_LOG_LEVEL=0", - ], + preprocessor_flags = XNN_COMMON_PREPROCESSOR_FLAGS, exported_preprocessor_flags = [ "-DXNN_ENABLE_AVX512VNNI" ], @@ -781,10 +713,6 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F compiler_flags = [ "-O2", ], - fbobjc_preprocessor_flags = [ - "-DXNN_PRIVATE=", - "-DXNN_INTERNAL=", - ], labels = labels, platform_compiler_flags = [ ( @@ -800,9 +728,7 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F ), ], preferred_linkage = "static", - preprocessor_flags = [ - "-DXNN_LOG_LEVEL=0", - ], + preprocessor_flags = XNN_COMMON_PREPROCESSOR_FLAGS, exported_preprocessor_flags = [ "-DXNN_ENABLE_AVX512VNNI" ], @@ -830,10 +756,6 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F "-mf16c", "-mfma", ], - fbobjc_preprocessor_flags = [ - "-DXNN_PRIVATE=", - "-DXNN_INTERNAL=", - ], labels = labels, platform_compiler_flags = [ ( @@ -853,9 +775,7 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F ), ] if not is_arvr_mode() else []), preferred_linkage = "static", - preprocessor_flags = [ - "-DXNN_LOG_LEVEL=0", - ], + preprocessor_flags = XNN_COMMON_PREPROCESSOR_FLAGS, visibility = ["PUBLIC"], windows_clang_compiler_flags_override = WINDOWS_FLAGS + WINDOWS_CLANG_COMPILER_FLAGS + ["-mavx"], windows_compiler_flags_override = WINDOWS_FLAGS + ["-mavx"], @@ -875,10 +795,6 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F compiler_flags = [ "-O2", ], - fbobjc_preprocessor_flags = [ - "-DXNN_PRIVATE=", - "-DXNN_INTERNAL=", - ], labels = labels, platform_compiler_flags = [ ( @@ -890,9 +806,7 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F ), ], preferred_linkage = "static", - preprocessor_flags = [ - "-DXNN_LOG_LEVEL=0", - ], + preprocessor_flags = XNN_COMMON_PREPROCESSOR_FLAGS, visibility = ["PUBLIC"], windows_clang_compiler_flags_override = WINDOWS_FLAGS + WINDOWS_CLANG_COMPILER_FLAGS + ["-mavx"], windows_compiler_flags_override = WINDOWS_FLAGS + ["-mavx"], @@ -922,10 +836,6 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F "-mf16c", ], }), - fbobjc_preprocessor_flags = [ - "-DXNN_PRIVATE=", - "-DXNN_INTERNAL=", - ], labels = labels, platform_compiler_flags = [ ( @@ -942,10 +852,9 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F ), ] if not is_arvr_mode() else []), platforms = (APPLE, ANDROID, CXX, WINDOWS), + fbandroid_link_whole = True, preferred_linkage = "static", - preprocessor_flags = [ - "-DXNN_LOG_LEVEL=0", - ], + preprocessor_flags = XNN_COMMON_PREPROCESSOR_FLAGS, visibility = ["PUBLIC"], windows_clang_compiler_flags_override = WINDOWS_FLAGS + WINDOWS_CLANG_COMPILER_FLAGS + ["-mf16c"], windows_compiler_flags_override = WINDOWS_FLAGS + ["-mf16c"], @@ -966,10 +875,6 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F "-O2", "-mf16c", ], - fbobjc_preprocessor_flags = [ - "-DXNN_PRIVATE=", - "-DXNN_INTERNAL=", - ], labels = labels, platform_compiler_flags = [ ( @@ -980,10 +885,9 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F ), ], platforms = (APPLE, ANDROID, CXX, WINDOWS), + fbandroid_link_whole = True, preferred_linkage = "static", - preprocessor_flags = [ - "-DXNN_LOG_LEVEL=0", - ], + preprocessor_flags = XNN_COMMON_PREPROCESSOR_FLAGS, visibility = ["PUBLIC"], windows_clang_compiler_flags_override = WINDOWS_FLAGS + WINDOWS_CLANG_COMPILER_FLAGS + ["-mf16c"], windows_compiler_flags_override = WINDOWS_FLAGS + ["-mf16c"], @@ -1015,10 +919,6 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F "-mf16c", ], }), - fbobjc_preprocessor_flags = [ - "-DXNN_PRIVATE=", - "-DXNN_INTERNAL=", - ], labels = labels, platform_compiler_flags = [ ( @@ -1035,10 +935,9 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F PROD_FMA3_MICROKERNEL_SRCS, ), ] if not is_arvr_mode() else []), + fbandroid_link_whole = True, preferred_linkage = "static", - preprocessor_flags = [ - "-DXNN_LOG_LEVEL=0", - ], + preprocessor_flags = XNN_COMMON_PREPROCESSOR_FLAGS, visibility = ["PUBLIC"], windows_clang_compiler_flags_override = WINDOWS_FLAGS + WINDOWS_CLANG_COMPILER_FLAGS + [ "-mfma", @@ -1066,10 +965,6 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F "-mfma", "-mf16c", ], - fbobjc_preprocessor_flags = [ - "-DXNN_PRIVATE=", - "-DXNN_INTERNAL=", - ], labels = labels, platform_compiler_flags = [ ( @@ -1080,10 +975,9 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F ], ), ], + fbandroid_link_whole = True, preferred_linkage = "static", - preprocessor_flags = [ - "-DXNN_LOG_LEVEL=0", - ], + preprocessor_flags = XNN_COMMON_PREPROCESSOR_FLAGS, visibility = ["PUBLIC"], windows_clang_compiler_flags_override = WINDOWS_FLAGS + WINDOWS_CLANG_COMPILER_FLAGS + [ "-mfma", @@ -1123,10 +1017,6 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F "-mf16c", ], }), - fbobjc_preprocessor_flags = [ - "-DXNN_PRIVATE=", - "-DXNN_INTERNAL=", - ], labels = labels, platform_compiler_flags = [ ( @@ -1144,10 +1034,9 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F PROD_AVX2_MICROKERNEL_SRCS, ), ] if not is_arvr_mode() else []), + fbandroid_link_whole = True, preferred_linkage = "static", - preprocessor_flags = [ - "-DXNN_LOG_LEVEL=0", - ], + preprocessor_flags = XNN_COMMON_PREPROCESSOR_FLAGS, visibility = ["PUBLIC"], windows_clang_compiler_flags_override = WINDOWS_FLAGS + WINDOWS_CLANG_COMPILER_FLAGS + [ "-mavx2", @@ -1178,10 +1067,6 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F "-mfma", "-mf16c", ], - fbobjc_preprocessor_flags = [ - "-DXNN_PRIVATE=", - "-DXNN_INTERNAL=", - ], labels = labels, platform_compiler_flags = [ ( @@ -1193,10 +1078,9 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F ], ), ], + fbandroid_link_whole = True, preferred_linkage = "static", - preprocessor_flags = [ - "-DXNN_LOG_LEVEL=0", - ], + preprocessor_flags = XNN_COMMON_PREPROCESSOR_FLAGS, visibility = ["PUBLIC"], windows_clang_compiler_flags_override = WINDOWS_FLAGS + WINDOWS_CLANG_COMPILER_FLAGS + [ "-mavx2", @@ -1235,10 +1119,6 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F "-mavx512f", ], }), - fbobjc_preprocessor_flags = [ - "-DXNN_PRIVATE=", - "-DXNN_INTERNAL=", - ], labels = labels, platform_compiler_flags = [ ( @@ -1254,10 +1134,9 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F PROD_AVX512F_MICROKERNEL_SRCS, ), ] if not is_arvr_mode() else []), + fbandroid_link_whole = True, preferred_linkage = "static", - preprocessor_flags = [ - "-DXNN_LOG_LEVEL=0", - ], + preprocessor_flags = XNN_COMMON_PREPROCESSOR_FLAGS, visibility = ["PUBLIC"], windows_clang_compiler_flags_override = WINDOWS_FLAGS + WINDOWS_CLANG_COMPILER_FLAGS + ["-mavx512f"], windows_compiler_flags_override = WINDOWS_FLAGS + ["-mavx512f"], @@ -1296,10 +1175,6 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F "-mavx512vbmi", ], }), - fbobjc_preprocessor_flags = [ - "-DXNN_PRIVATE=", - "-DXNN_INTERNAL=", - ], labels = labels, platform_compiler_flags = [ ( @@ -1320,10 +1195,9 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F PROD_AVX512VBMI_MICROKERNEL_SRCS, ), ] if not is_arvr_mode() else []), + fbandroid_link_whole = True, preferred_linkage = "static", - preprocessor_flags = [ - "-DXNN_LOG_LEVEL=0", - ], + preprocessor_flags = XNN_COMMON_PREPROCESSOR_FLAGS, visibility = ["PUBLIC"], windows_clang_compiler_flags_override = WINDOWS_FLAGS + WINDOWS_CLANG_COMPILER_FLAGS + [ "-mavx512f", @@ -1358,10 +1232,6 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F "-O2", "-mavx512f", ], - fbobjc_preprocessor_flags = [ - "-DXNN_PRIVATE=", - "-DXNN_INTERNAL=", - ], labels = labels, platform_compiler_flags = [ ( @@ -1371,10 +1241,9 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F ], ), ], + fbandroid_link_whole = True, preferred_linkage = "static", - preprocessor_flags = [ - "-DXNN_LOG_LEVEL=0", - ], + preprocessor_flags = XNN_COMMON_PREPROCESSOR_FLAGS, visibility = ["PUBLIC"], windows_clang_compiler_flags_override = WINDOWS_FLAGS + WINDOWS_CLANG_COMPILER_FLAGS + ["-mavx512f"], windows_compiler_flags_override = WINDOWS_FLAGS + ["-mavx512f"], @@ -1412,10 +1281,6 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F "-mavx512vl", ], }), - fbobjc_preprocessor_flags = [ - "-DXNN_PRIVATE=", - "-DXNN_INTERNAL=", - ], labels = labels, platform_compiler_flags = [ ( @@ -1435,10 +1300,9 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F PROD_AVX512SKX_MICROKERNEL_SRCS, ), ] if not is_arvr_mode() else []), + fbandroid_link_whole = True, preferred_linkage = "static", - preprocessor_flags = [ - "-DXNN_LOG_LEVEL=0", - ], + preprocessor_flags = XNN_COMMON_PREPROCESSOR_FLAGS, visibility = ["PUBLIC"], windows_clang_compiler_flags_override = WINDOWS_FLAGS + WINDOWS_CLANG_COMPILER_FLAGS + [ "-mavx512f", @@ -1453,7 +1317,7 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F "-mavx512bw", "-mavx512dq", "-mavx512vl", - + ], deps = [ ":interface", @@ -1476,10 +1340,6 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F "-mavx512dq", "-mavx512vl", ], - fbobjc_preprocessor_flags = [ - "-DXNN_PRIVATE=", - "-DXNN_INTERNAL=", - ], labels = labels, platform_compiler_flags = [ ( @@ -1493,10 +1353,9 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F ], ), ], + fbandroid_link_whole = True, preferred_linkage = "static", - preprocessor_flags = [ - "-DXNN_LOG_LEVEL=0", - ], + preprocessor_flags = XNN_COMMON_PREPROCESSOR_FLAGS, visibility = ["PUBLIC"], windows_clang_compiler_flags_override = WINDOWS_FLAGS + WINDOWS_CLANG_COMPILER_FLAGS + [ "-mavx512f", @@ -1533,10 +1392,6 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F "-fno-fast-math", "-fno-math-errno", ], - fbobjc_preprocessor_flags = [ - "-DXNN_PRIVATE=", - "-DXNN_INTERNAL=", - ], labels = labels, platform_compiler_flags = [ ( @@ -1549,10 +1404,9 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F ], ), ], + fbandroid_link_whole = True, preferred_linkage = "static", - preprocessor_flags = [ - "-DXNN_LOG_LEVEL=0", - ], + preprocessor_flags = XNN_COMMON_PREPROCESSOR_FLAGS, visibility = ["PUBLIC"], windows_clang_compiler_flags_override = WINDOWS_FLAGS + WINDOWS_CLANG_COMPILER_FLAGS, windows_compiler_flags_override = WINDOWS_FLAGS, @@ -1584,10 +1438,6 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F "-mfpu=neon", ], }), - fbobjc_preprocessor_flags = [ - "-DXNN_PRIVATE=", - "-DXNN_INTERNAL=", - ], labels = labels, platform_compiler_flags = [ ( @@ -1605,10 +1455,9 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F PROD_NEON_MICROKERNEL_SRCS, ), ] if not is_arvr_mode() else [], + fbandroid_link_whole = True, preferred_linkage = "static", - preprocessor_flags = [ - "-DXNN_LOG_LEVEL=0", - ], + preprocessor_flags = XNN_COMMON_PREPROCESSOR_FLAGS, visibility = ["PUBLIC"], windows_clang_compiler_flags_override = WINDOWS_FLAGS + WINDOWS_CLANG_COMPILER_FLAGS, windows_compiler_flags_override = WINDOWS_FLAGS, @@ -1633,10 +1482,6 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F compiler_flags = [ "-O2", ], - fbobjc_preprocessor_flags = [ - "-DXNN_PRIVATE=", - "-DXNN_INTERNAL=", - ], platform_srcs = [ ( "(aarch64|arm64)", @@ -1644,10 +1489,9 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F ), ] if not is_arvr_mode() else [], labels = labels, + fbandroid_link_whole = True, preferred_linkage = "static", - preprocessor_flags = [ - "-DXNN_LOG_LEVEL=0", - ], + preprocessor_flags = XNN_COMMON_PREPROCESSOR_FLAGS, visibility = ["PUBLIC"], windows_clang_compiler_flags_override = WINDOWS_FLAGS + WINDOWS_CLANG_COMPILER_FLAGS, windows_compiler_flags_override = WINDOWS_FLAGS, @@ -1679,10 +1523,6 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F "-mfpu=neon-vfpv4", ], }), - fbobjc_preprocessor_flags = [ - "-DXNN_PRIVATE=", - "-DXNN_INTERNAL=", - ], labels = labels, platform_compiler_flags = [ ( @@ -1707,10 +1547,9 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F PROD_NEONFMA_MICROKERNEL_SRCS, ), ] if not is_arvr_mode() else [], + fbandroid_link_whole = True, preferred_linkage = "static", - preprocessor_flags = [ - "-DXNN_LOG_LEVEL=0", - ], + preprocessor_flags = XNN_COMMON_PREPROCESSOR_FLAGS, visibility = ["PUBLIC"], windows_clang_compiler_flags_override = WINDOWS_FLAGS + WINDOWS_CLANG_COMPILER_FLAGS, windows_compiler_flags_override = WINDOWS_FLAGS, @@ -1735,10 +1574,6 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F compiler_flags = [ "-O2", ], - fbobjc_preprocessor_flags = [ - "-DXNN_PRIVATE=", - "-DXNN_INTERNAL=", - ], labels = labels, platform_srcs = [ ( @@ -1747,10 +1582,9 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F ), ] if not is_arvr_mode() else [], platforms = (APPLE, ANDROID, CXX, WINDOWS), + fbandroid_link_whole = True, preferred_linkage = "static", - preprocessor_flags = [ - "-DXNN_LOG_LEVEL=0", - ], + preprocessor_flags = XNN_COMMON_PREPROCESSOR_FLAGS, visibility = ["PUBLIC"], windows_clang_compiler_flags_override = WINDOWS_FLAGS + WINDOWS_CLANG_COMPILER_FLAGS, windows_compiler_flags_override = WINDOWS_FLAGS, @@ -1787,10 +1621,6 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F "-ffinite-math-only", ], }), - fbobjc_preprocessor_flags = [ - "-DXNN_PRIVATE=", - "-DXNN_INTERNAL=", - ], labels = labels, platform_compiler_flags = [ ( @@ -1807,10 +1637,9 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F ], ), ], + fbandroid_link_whole = True, preferred_linkage = "static", - preprocessor_flags = [ - "-DXNN_LOG_LEVEL=0", - ], + preprocessor_flags = XNN_COMMON_PREPROCESSOR_FLAGS, visibility = ["PUBLIC"], windows_clang_compiler_flags_override = WINDOWS_FLAGS + WINDOWS_CLANG_COMPILER_FLAGS, windows_compiler_flags_override = WINDOWS_FLAGS, @@ -1838,10 +1667,6 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F "-mfpu=neon-fp16", ], }), - fbobjc_preprocessor_flags = [ - "-DXNN_PRIVATE=", - "-DXNN_INTERNAL=", - ], labels = labels, platform_compiler_flags = [ ( @@ -1853,10 +1678,9 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F ], ), ], + fbandroid_link_whole = True, preferred_linkage = "static", - preprocessor_flags = [ - "-DXNN_LOG_LEVEL=0", - ], + preprocessor_flags = XNN_COMMON_PREPROCESSOR_FLAGS, visibility = ["PUBLIC"], windows_clang_compiler_flags_override = WINDOWS_FLAGS + WINDOWS_CLANG_COMPILER_FLAGS, windows_compiler_flags_override = WINDOWS_FLAGS, @@ -1880,10 +1704,6 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F "DEFAULT": [], "ovr_config//cpu:arm64": ["-march=armv8-a"], }), - fbobjc_preprocessor_flags = [ - "-DXNN_PRIVATE=", - "-DXNN_INTERNAL=", - ], labels = labels, platform_compiler_flags = [ ( @@ -1908,10 +1728,9 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F ], ), ], + fbandroid_link_whole = True, preferred_linkage = "static", - preprocessor_flags = [ - "-DXNN_LOG_LEVEL=0", - ], + preprocessor_flags = XNN_COMMON_PREPROCESSOR_FLAGS, visibility = ["PUBLIC"], windows_clang_compiler_flags_override = WINDOWS_FLAGS + WINDOWS_CLANG_COMPILER_FLAGS, windows_compiler_flags_override = WINDOWS_FLAGS, @@ -1943,10 +1762,6 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F "-mfloat-abi=softfp", ], }), - fbobjc_preprocessor_flags = [ - "-DXNN_PRIVATE=", - "-DXNN_INTERNAL=", - ], labels = labels, platform_compiler_flags = [ ( @@ -1964,10 +1779,9 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F PROD_NEONDOT_MICROKERNEL_SRCS, ), ] if not is_arvr_mode() else [], + fbandroid_link_whole = True, preferred_linkage = "static", - preprocessor_flags = [ - "-DXNN_LOG_LEVEL=0", - ], + preprocessor_flags = XNN_COMMON_PREPROCESSOR_FLAGS, visibility = ["PUBLIC"], windows_clang_compiler_flags_override = WINDOWS_FLAGS + WINDOWS_CLANG_COMPILER_FLAGS, windows_compiler_flags_override = WINDOWS_FLAGS, @@ -1995,10 +1809,6 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F "DEFAULT": [], "ovr_config//cpu:arm64": ["-march=armv8.2-a+dotprod"], }), - fbobjc_preprocessor_flags = [ - "-DXNN_PRIVATE=", - "-DXNN_INTERNAL=", - ], labels = labels, platform_compiler_flags = [ ( @@ -2014,10 +1824,9 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F PROD_NEONDOT_MICROKERNEL_SRCS + PROD_NEONDOT_AARCH64_MICROKERNEL_SRCS, ), ] if not is_arvr_mode() else [], + fbandroid_link_whole = True, preferred_linkage = "static", - preprocessor_flags = [ - "-DXNN_LOG_LEVEL=0", - ], + preprocessor_flags = XNN_COMMON_PREPROCESSOR_FLAGS, visibility = ["PUBLIC"], windows_clang_compiler_flags_override = WINDOWS_FLAGS + WINDOWS_CLANG_COMPILER_FLAGS, windows_compiler_flags_override = WINDOWS_FLAGS, @@ -2065,15 +1874,10 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F PROD_NEONDOTFP16ARITH_MICROKERNEL_SRCS, ), ] if not is_arvr_mode() else [], - fbobjc_preprocessor_flags = [ - "-DXNN_PRIVATE=", - "-DXNN_INTERNAL=", - ], labels = labels, + fbandroid_link_whole = True, preferred_linkage = "static", - preprocessor_flags = [ - "-DXNN_LOG_LEVEL=0", - ], + preprocessor_flags = XNN_COMMON_PREPROCESSOR_FLAGS, visibility = ["PUBLIC"], windows_clang_compiler_flags_override = WINDOWS_FLAGS + WINDOWS_CLANG_COMPILER_FLAGS, windows_compiler_flags_override = WINDOWS_FLAGS, @@ -2103,10 +1907,6 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F "-march=armv8.2-a+dotprod+fp16", ], }), - fbobjc_preprocessor_flags = [ - "-DXNN_PRIVATE=", - "-DXNN_INTERNAL=", - ], platform_compiler_flags = [ ( "(aarch64|arm64)", @@ -2122,10 +1922,9 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F ), ] if not is_arvr_mode() else [], labels = labels, + fbandroid_link_whole = True, preferred_linkage = "static", - preprocessor_flags = [ - "-DXNN_LOG_LEVEL=0", - ], + preprocessor_flags = XNN_COMMON_PREPROCESSOR_FLAGS, visibility = ["PUBLIC"], windows_clang_compiler_flags_override = WINDOWS_FLAGS + WINDOWS_CLANG_COMPILER_FLAGS, windows_compiler_flags_override = WINDOWS_FLAGS, @@ -2157,10 +1956,6 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F "-mfpu=neon-fp-armv8", ], }), - fbobjc_preprocessor_flags = [ - "-DXNN_PRIVATE=", - "-DXNN_INTERNAL=", - ], labels = labels, platform_compiler_flags = [ ( @@ -2178,10 +1973,9 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F PROD_NEONFP16ARITH_MICROKERNEL_SRCS, ), ] if not is_arvr_mode() else [], + fbandroid_link_whole = True, preferred_linkage = "static", - preprocessor_flags = [ - "-DXNN_LOG_LEVEL=0", - ], + preprocessor_flags = XNN_COMMON_PREPROCESSOR_FLAGS, visibility = ["PUBLIC"], windows_clang_compiler_flags_override = WINDOWS_FLAGS + WINDOWS_CLANG_COMPILER_FLAGS, windows_compiler_flags_override = WINDOWS_FLAGS, @@ -2209,10 +2003,6 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F "DEFAULT": [], "ovr_config//cpu:arm64": ["-march=armv8.2-a+fp16"], }), - fbobjc_preprocessor_flags = [ - "-DXNN_PRIVATE=", - "-DXNN_INTERNAL=", - ], labels = labels, platform_compiler_flags = [ ( @@ -2228,10 +2018,9 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F PROD_NEONFP16ARITH_MICROKERNEL_SRCS + PROD_NEONFP16ARITH_AARCH64_MICROKERNEL_SRCS, ), ] if not is_arvr_mode() else [], + fbandroid_link_whole = True, preferred_linkage = "static", - preprocessor_flags = [ - "-DXNN_LOG_LEVEL=0", - ], + preprocessor_flags = XNN_COMMON_PREPROCESSOR_FLAGS, visibility = ["PUBLIC"], windows_clang_compiler_flags_override = WINDOWS_FLAGS + WINDOWS_CLANG_COMPILER_FLAGS, windows_compiler_flags_override = WINDOWS_FLAGS, @@ -2263,10 +2052,6 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F "-march=armv8.2-a+i8mm+fp16", ], }), - fbobjc_preprocessor_flags = [ - "-DXNN_PRIVATE=", - "-DXNN_INTERNAL=", - ], labels = labels, platform_compiler_flags = [ ( @@ -2285,10 +2070,9 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F ), ], platforms = (APPLE, ANDROID, CXX, WINDOWS), + fbandroid_link_whole = True, preferred_linkage = "static", - preprocessor_flags = [ - "-DXNN_LOG_LEVEL=0", - ], + preprocessor_flags = XNN_COMMON_PREPROCESSOR_FLAGS, visibility = ["PUBLIC"], windows_clang_compiler_flags_override = WINDOWS_FLAGS + WINDOWS_CLANG_COMPILER_FLAGS, windows_compiler_flags_override = WINDOWS_FLAGS, @@ -2317,10 +2101,6 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F "-mfpu=neon-fp-armv8", ], }), - fbobjc_preprocessor_flags = [ - "-DXNN_PRIVATE=", - "-DXNN_INTERNAL=", - ], labels = labels, platform_compiler_flags = [ ( @@ -2333,10 +2113,9 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F ), ], platforms = (APPLE, ANDROID, CXX, WINDOWS), + fbandroid_link_whole = True, preferred_linkage = "static", - preprocessor_flags = [ - "-DXNN_LOG_LEVEL=0", - ], + preprocessor_flags = XNN_COMMON_PREPROCESSOR_FLAGS, visibility = ["PUBLIC"], windows_clang_compiler_flags_override = WINDOWS_FLAGS + WINDOWS_CLANG_COMPILER_FLAGS, windows_compiler_flags_override = WINDOWS_FLAGS, @@ -2363,10 +2142,6 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F "-march=armv8.2-a+fp16+dotprod", ], }), - fbobjc_preprocessor_flags = [ - "-DXNN_PRIVATE=", - "-DXNN_INTERNAL=", - ], labels = labels, platform_compiler_flags = [ ( @@ -2376,10 +2151,9 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F ], ), ], + fbandroid_link_whole = True, preferred_linkage = "static", - preprocessor_flags = [ - "-DXNN_LOG_LEVEL=0", - ], + preprocessor_flags = XNN_COMMON_PREPROCESSOR_FLAGS, visibility = ["PUBLIC"], windows_clang_compiler_flags_override = WINDOWS_FLAGS + WINDOWS_CLANG_COMPILER_FLAGS, windows_compiler_flags_override = WINDOWS_FLAGS, @@ -2393,6 +2167,7 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F name = "arm64_lib", apple_sdks = (IOS, MACOSX, APPLETVOS), labels = labels, + fbandroid_link_whole = True, preferred_linkage = "static", visibility = ["PUBLIC"], deps = [ @@ -2495,6 +2270,7 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F name = "armv7_lib", apple_sdks = (IOS, MACOSX, APPLETVOS), labels = labels, + fbandroid_link_whole = True, preferred_linkage = "static", visibility = ["PUBLIC"], deps = [ @@ -2510,6 +2286,7 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F name = "prod_ukernels", apple_sdks = (IOS, MACOSX, APPLETVOS), labels = labels, + fbandroid_link_whole = True, preferred_linkage = "static", visibility = ["PUBLIC"], deps = [ @@ -2543,19 +2320,13 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F exported_headers = { "xnnpack.h": "XNNPACK/include/xnnpack.h", }, - fbobjc_preprocessor_flags = [ - "-DXNN_PRIVATE=", - "-DXNN_INTERNAL=", - ], header_namespace = "", headers = subdir_glob([ ("XNNPACK/src", "**/*.h"), ("XNNPACK/include", "**/*.h"), ]), platforms = (APPLE, ANDROID, CXX, WINDOWS), - preferred_linkage = "static", - preprocessor_flags = [ - "-DXNN_LOG_LEVEL=0", + preprocessor_flags = XNN_COMMON_PREPROCESSOR_FLAGS + [ "-DXNN_NO_Q8_OPERATORS", "-DXNN_NO_F16_OPERATORS", "-DXNN_NO_NCHW_OPERATORS", From b0bcefbb89c3804fd72b0d464ac38047ac231c15 Mon Sep 17 00:00:00 2001 From: Jeff Daily Date: Mon, 9 Sep 2024 23:04:34 +0000 Subject: [PATCH 0652/1018] remove amax_ptr from scaled_gemm (#135421) amax was removed from _scaled_mm by #128683. Remove it from the internal at::cuda::blas::scaled_gemm, as well. This allows hipBLASLt to find additional solutions rather than forcing amax to be used and then discarding the result. Pull Request resolved: https://github.com/pytorch/pytorch/pull/135421 Approved by: https://github.com/drisspg, https://github.com/eqy --- aten/src/ATen/cuda/CUDABlas.cpp | 9 ++------- aten/src/ATen/cuda/CUDABlas.h | 1 - aten/src/ATen/cuda/tunable/TunableGemm.h | 1 - aten/src/ATen/native/cuda/Blas.cpp | 18 ++---------------- 4 files changed, 4 insertions(+), 25 deletions(-) diff --git a/aten/src/ATen/cuda/CUDABlas.cpp b/aten/src/ATen/cuda/CUDABlas.cpp index eea4a9f421a0c5..9b3fd5dc6e4dd9 100644 --- a/aten/src/ATen/cuda/CUDABlas.cpp +++ b/aten/src/ATen/cuda/CUDABlas.cpp @@ -1408,7 +1408,6 @@ void scaled_gemm( const void *result_scale_ptr, int64_t result_ld, ScalarType result_dtype, - void* amax_ptr, bool use_fast_accum) { #if CUDA_VERSION >= 11080 || defined(USE_ROCM) const auto computeType = CUBLAS_COMPUTE_32F; @@ -1421,13 +1420,9 @@ void scaled_gemm( computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_TRANSB, _cublasOpFromChar(transb)); computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_A_SCALE_POINTER, mat1_scale_ptr); computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_B_SCALE_POINTER, mat2_scale_ptr); - computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_D_SCALE_POINTER, result_scale_ptr); -#if !defined(USE_ROCM) || (defined(USE_ROCM) && ROCM_VERSION >= 60200) - // Amax support in ROCm as of 6.2 - if (isFloat8Type(result_dtype)) { - computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_AMAX_D_POINTER, amax_ptr); + if (result_scale_ptr != nullptr) { + computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_D_SCALE_POINTER, result_scale_ptr); } -#endif #ifndef USE_ROCM computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_FAST_ACCUM, fastAccuMode); #endif diff --git a/aten/src/ATen/cuda/CUDABlas.h b/aten/src/ATen/cuda/CUDABlas.h index 2c6cef95f79fe8..e6f0c5a9a373ba 100644 --- a/aten/src/ATen/cuda/CUDABlas.h +++ b/aten/src/ATen/cuda/CUDABlas.h @@ -140,7 +140,6 @@ void scaled_gemm( const void* result_scale_ptr, int64_t result_ld, ScalarType result_dtype, - void* amax_ptr, bool use_fast_accum); #define CUDABLAS_BGEMM_ARGTYPES(Dtype) \ diff --git a/aten/src/ATen/cuda/tunable/TunableGemm.h b/aten/src/ATen/cuda/tunable/TunableGemm.h index 50a7344b0260c2..00b02e91b4f359 100644 --- a/aten/src/ATen/cuda/tunable/TunableGemm.h +++ b/aten/src/ATen/cuda/tunable/TunableGemm.h @@ -104,7 +104,6 @@ class DefaultScaledGemmOp : public Callable> { params->c_scale_ptr, params->ldc, params->c_dtype, - params->amax_ptr, params->use_fast_accum); return OK; } diff --git a/aten/src/ATen/native/cuda/Blas.cpp b/aten/src/ATen/native/cuda/Blas.cpp index 991c7d2dba16db..741d05bdd7169e 100644 --- a/aten/src/ATen/native/cuda/Blas.cpp +++ b/aten/src/ATen/native/cuda/Blas.cpp @@ -964,9 +964,9 @@ ScalingType get_scaling_type( } // namespace -// Computes matrix multiply + bias while applying scaling to input and output matrices and computes amax +// Computes matrix multiply + bias while applying scaling to input and output matrices // Scales are only applicable when matrices are of Float8 type and assumbed to be equal to 1.0 by default. -// If output matrix type is 16 or 32-bit type, neither scale_result is applied nor amax is computed. +// If output matrix type is 16 or 32-bit type, scale_result is not applied. // Known limitations: // - Only works if mat1 is row-major and mat2 is column-major // - Only works if matrices sizes are divisible by 32 @@ -1068,9 +1068,6 @@ _scaled_mm_out_cuda(const Tensor& mat1, const Tensor& mat2, const auto out_dtype_ = args.result->scalar_type(); TORCH_CHECK(args.transa == 't' && args.transb == 'n', "Only multiplication of row-major and column-major matrices is supported by cuBLASLt"); - // Some scaled_gemms require an amax to populate lets create one here - Tensor amax = at::empty({0}, mat1.options().dtype(ScalarType::Float)); - #ifdef USE_ROCM auto tuning_ctx = at::cuda::tunable::getTuningContext(); if (tuning_ctx->IsTunableOpEnabled()) { @@ -1126,7 +1123,6 @@ _scaled_mm_out_cuda(const Tensor& mat1, const Tensor& mat2, params.c_scale_ptr = scale_result ? scale_result->data_ptr() : nullptr; params.ldc = args.result_ld; params.c_dtype = out_dtype_; - params.amax_ptr = amax.data_ptr(); params.use_fast_accum = use_fast_accum; if (transa_ && transb_) { TUNABLE_DISPATCH(at::cuda::tunable::BlasOp::T, at::cuda::tunable::BlasOp::T) @@ -1150,11 +1146,6 @@ _scaled_mm_out_cuda(const Tensor& mat1, const Tensor& mat2, else #endif { -#if defined(USE_ROCM) && ROCM_VERSION >= 60200 - // hipBlasLT requires scaleD to be set to something in order to use AMAX - auto dummy_options = TensorOptions().dtype(kFloat).device(kCUDA); - auto dummy_scale = at::ones(1, dummy_options); -#endif at::cuda::blas::scaled_gemm( args.transa, args.transb, @@ -1172,14 +1163,9 @@ _scaled_mm_out_cuda(const Tensor& mat1, const Tensor& mat2, bias ? bias->data_ptr(): nullptr, bias ? bias->scalar_type() : isFloat8Type(out_dtype_) ? at::ScalarType::Half : out_dtype_, args.result->data_ptr(), -#if defined(USE_ROCM) && ROCM_VERSION >= 60200 - scale_result ? scale_result->data_ptr() : dummy_scale.data_ptr(), -#else scale_result ? scale_result->data_ptr() : nullptr, -#endif args.result_ld, out_dtype_, - amax.data_ptr(), use_fast_accum); } From 84efb231fa915a2b9060bec15d7e25223536ad66 Mon Sep 17 00:00:00 2001 From: Guilherme Leobas Date: Mon, 9 Sep 2024 12:14:52 -0300 Subject: [PATCH 0653/1018] Enable forward AD in functional.affine_grid (#135494) Fixes #121411 Pull Request resolved: https://github.com/pytorch/pytorch/pull/135494 Approved by: https://github.com/zou3519, https://github.com/soulitzer --- test/test_nn.py | 4 ++-- tools/autograd/derivatives.yaml | 1 + 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/test/test_nn.py b/test/test_nn.py index df8120eb84e26e..8bb14f3188a2bd 100644 --- a/test/test_nn.py +++ b/test/test_nn.py @@ -6224,7 +6224,7 @@ def test_affine_grid(self): warnings.simplefilter("always") # python2 requires this so other tests can trigger self.assertTrue(gradcheck( lambda inp: F.affine_grid(inp, sz, align_corners=align_corners), - (inp,))) + (inp,), check_forward_ad=True)) # test CPU against CUDA if TEST_CUDA: @@ -6276,7 +6276,7 @@ def test_affine_grid_3d(self): warnings.simplefilter("always") # python2 requires this so other tests can trigger self.assertTrue(gradcheck( lambda inp: F.affine_grid(inp, sz, align_corners=align_corners), - (inp,))) + (inp,), check_forward_ad=True)) # test CPU against CUDA if TEST_CUDA: diff --git a/tools/autograd/derivatives.yaml b/tools/autograd/derivatives.yaml index 9f7ea3fbeb4ff4..aaa6c3a55a43ef 100644 --- a/tools/autograd/derivatives.yaml +++ b/tools/autograd/derivatives.yaml @@ -278,6 +278,7 @@ - name: affine_grid_generator(Tensor theta, SymInt[] size, bool align_corners) -> Tensor theta: affine_grid_generator_backward_symint(grad, size, align_corners) + result: auto_linear - name: alias(Tensor(a) self) -> Tensor(a) self: grad From 1de5e5478158bc1309ea0014fab0c76fbbbc5b6b Mon Sep 17 00:00:00 2001 From: Yueming Hao Date: Tue, 10 Sep 2024 02:33:26 +0000 Subject: [PATCH 0654/1018] [inductor]Add profiler to operatorbench (#135515) Add profiling to operatorbench. The new argument `--profile` is added and the profiling trace is like the following figure. image Pull Request resolved: https://github.com/pytorch/pytorch/pull/135515 Approved by: https://github.com/shunting314 --- .../dynamo/microbenchmarks/operatorbench.py | 118 ++++++++++++------ 1 file changed, 82 insertions(+), 36 deletions(-) diff --git a/benchmarks/dynamo/microbenchmarks/operatorbench.py b/benchmarks/dynamo/microbenchmarks/operatorbench.py index 2ebca7e340336c..d61ec36870563f 100644 --- a/benchmarks/dynamo/microbenchmarks/operatorbench.py +++ b/benchmarks/dynamo/microbenchmarks/operatorbench.py @@ -1,5 +1,7 @@ #!/usr/bin/env python3 +from contextlib import nullcontext + import click import numpy as np from operator_inp_utils import OperatorInputsLoader @@ -16,11 +18,13 @@ aten = torch.ops.aten +profile_enabled = False def compute_speedups( operator, models, example_inputs, repeats, accuracy_checking=False, device="cuda" ): + global profile_enabled expected = models[0](*example_inputs) if accuracy_checking: for model in models[1:]: @@ -35,20 +39,32 @@ def compute_speedups( timings = np.zeros((repeats, len(models)), np.float64) for rep in range(repeats): - # interleave the runs to handle frequency scaling and load changes - for m, model in enumerate(models): - if device == "cuda": - model(*example_inputs) - - # benchmarker.benchmark_gpu() clears L2 cache to hide the latency of CPU launch time - # along with cuda synchronization - timings[rep, m] = benchmarker.benchmark_gpu( - lambda: model(*example_inputs) + record_rep_context = ( + torch.profiler.record_function(f"rep_{rep}") + if profile_enabled + else nullcontext() + ) + with record_rep_context: + # interleave the runs to handle frequency scaling and load changes + for m, model in enumerate(models): + record_model_context = ( + torch.profiler.record_function(f"model_{m}") + if profile_enabled + else nullcontext() ) - else: - from torch._inductor.utils import timed - - timings[rep, m] = timed(model, example_inputs) + with record_model_context: + if device == "cuda": + model(*example_inputs) + + # benchmarker.benchmark_gpu() clears L2 cache to hide the latency of CPU launch time + # along with cuda synchronization + timings[rep, m] = benchmarker.benchmark_gpu( + lambda: model(*example_inputs) + ) + else: + from torch._inductor.utils import timed + + timings[rep, m] = timed(model, example_inputs) return np.median(timings, axis=0) @@ -171,6 +187,7 @@ def skip_operator(operator): @click.option( "--channels-last", help="force inputs to channels last", is_flag=True, default=False ) +@click.option("--profile", help="profile the benchmark", is_flag=True, default=False) def benchmark( suite, op, @@ -183,7 +200,9 @@ def benchmark( inp_file, start_idx, channels_last, + profile, ): + global profile_enabled if inp_file is not None: loader = OperatorInputsLoader(inp_file) else: @@ -209,6 +228,8 @@ def benchmark( ops = [eval(op)] max_samples = max_samples + start_idx + profile_enabled = profile + for operator in ops: if skip_operator(operator): continue @@ -216,10 +237,31 @@ def benchmark( print(f"Running {operator}") inp_gen = loader.get_inputs_for_operator(operator, dtype=dtype, device=device) timings = [] - - for i in range(min(max_samples, 1000000)): + inputs_list = [] + for _ in range(min(max_samples, 1000000)): try: inps = next(inp_gen) + inputs_list.append(inps) + except StopIteration: + break + + profiler_context = ( + torch.profiler.profile( + activities=[ + torch.profiler.ProfilerActivity.CPU, + torch.profiler.ProfilerActivity.CUDA, + ], + record_shapes=False, + profile_memory=False, + on_trace_ready=torch.profiler.tensorboard_trace_handler( + f"./log/operator_{operator}", use_gzip=True + ), + ) + if profile_enabled + else nullcontext() + ) + with profiler_context as prof: + for i, inps in enumerate(inputs_list): if inps is None: break if i < start_idx: @@ -230,28 +272,32 @@ def benchmark( args, kwargs = tree_map_only( torch.Tensor, to_channels_last, (args, kwargs) ) - - except StopIteration: - break - try: - # aten, nvfuser, inductor - timings.append( - microbenchmark( - operator, - args, - kwargs, - dtype, - accuracy_checking, - repeats, - measure_nvfuser, - device, + try: + iter_context = ( + torch.profiler.record_function(f"iter_{i}") + if profile_enabled + else nullcontext() ) - ) - except Exception as e: - print(f"error {operator}") - print(e) - # comment out this line to avoid blocking other tests - # raise e + with iter_context: + # aten, nvfuser, inductor + timings.append( + microbenchmark( + operator, + args, + kwargs, + dtype, + accuracy_checking, + repeats, + measure_nvfuser, + device, + ) + ) + + except Exception as e: + print(f"error {operator}") + print(e) + # comment out this line to avoid blocking other tests + # raise e if not timings: continue From 30e6e46bf3a5e68f96eb78d561ff9079ccaca442 Mon Sep 17 00:00:00 2001 From: "Wu, Chunyuan" Date: Mon, 9 Sep 2024 10:55:47 +0000 Subject: [PATCH 0655/1018] [inductor] [cpp] fix the input contiguous check in max-autotune (#134982) ## Description Fixes the FP32 accuracy failure of `resmlp_12_224` and BF16 accuracy failure of `volo_d1_224` in timm. In this PR, we check whether input is contiguous using the following way: If it has `FixedLayout`, we know the accurate strides. For `FlexibleLayout`, if its data is a `ComputedBuffer`, we could get the fill order of the buffer to decide whether it's contiguous. For the other cases, we won't use GEMM template as we can't infer whether it's contiguous. ## Additional context The current GEMM template only supports this case: `input.get_stride()[-1] == 1`. In `resmlp_12_224`, when we run into this check, the layout of `input` is a `FlexibleLayout`. The reason is that when realizing the input which is a `View` IR, the `convert_to_reinterpret_view` call fails: https://github.com/pytorch/pytorch/blob/d14fe3ffeddff743af09ce7c8d91127940ddf7ed/torch/_inductor/ir.py#L4712-L4715 And it finally runs into this `copy_input` and returns a `FlexibleLayout`. https://github.com/pytorch/pytorch/blob/d14fe3ffeddff743af09ce7c8d91127940ddf7ed/torch/_inductor/ir.py#L4722 When checking its stride, this `FlexibleLayout` indeed satisfies `input.get_stride()[-1] == 1` but it is later decided as a `FixedLayout` with `size = (3072, 196), stride = (1, 3072)`, which is not supported by the GEMM template, thus causing accuracy issue in this model. The `FlexibleLayout` is converted to `FixedLayout` during [CppPackedGemmTemplate.add_choices](https://github.com/pytorch/pytorch/blob/d14fe3ffeddff743af09ce7c8d91127940ddf7ed/torch/_inductor/mkldnn_lowerings.py#L1051) which calls [slice_nd](https://github.com/pytorch/pytorch/blob/d14fe3ffeddff743af09ce7c8d91127940ddf7ed/torch/_inductor/codegen/cpp_template_kernel.py#L150) when rendering the kernel (`slice_nd(X)`). When creating the `SliceView` IR, [as_storage_and_layout](https://github.com/pytorch/pytorch/blob/d14fe3ffeddff743af09ce7c8d91127940ddf7ed/torch/_inductor/ir.py#L2288) invokes [decide_layout](https://github.com/pytorch/pytorch/blob/d14fe3ffeddff743af09ce7c8d91127940ddf7ed/torch/_inductor/ir.py#L2135) and converts it to a `FixedLayout` with `size = (3072, 196), stride = (1, 3072)`. Pull Request resolved: https://github.com/pytorch/pytorch/pull/134982 Approved by: https://github.com/jgong5, https://github.com/leslie-fang-intel, https://github.com/jansel --- test/inductor/test_cpu_select_algorithm.py | 71 ++++++++++++++++++++++ torch/_inductor/utils.py | 6 +- 2 files changed, 76 insertions(+), 1 deletion(-) diff --git a/test/inductor/test_cpu_select_algorithm.py b/test/inductor/test_cpu_select_algorithm.py index 19f518387a398c..2034fa05653a87 100644 --- a/test/inductor/test_cpu_select_algorithm.py +++ b/test/inductor/test_cpu_select_algorithm.py @@ -469,6 +469,77 @@ def forward(self, mul_272, _convolution_pointwise_default_31): self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 2) self.assertEqual(counters["inductor"]["cpp_epilogue_fusion_counter"], 2) + @inductor_config.patch({"freezing": True}) + @patches + @torch.no_grad + @unittest.skipIf(not TEST_MKL, "Test requires MKL") + @parametrize("batch_size", (8,)) + @parametrize("in_features", (3,)) + @parametrize("linear_in_features", (384,)) + @parametrize("out_features", (196,)) + @parametrize("bias", (True,)) + @dtypes(torch.float) + def test_linear_with_input_of_flexible_layout( + self, batch_size, in_features, linear_in_features, out_features, bias, dtype + ): + # Reproducer from the resmlp_12_224 model in timm + flatten_BS = int(batch_size * linear_in_features) + + class M(torch.nn.Module): + def __init__(self, bias): + super().__init__() + self.conv = torch.nn.Conv2d( + in_features, + linear_in_features, + kernel_size=16, + padding=0, + stride=16, + dilation=1, + groups=1, + ) + self._frozen_param151 = torch.randn(1, 1, linear_in_features) + self._frozen_param3 = torch.randn(1, 1, linear_in_features) + self._frozen_param2 = torch.randn(linear_in_features) + + self.linear = torch.nn.Linear(out_features, out_features, bias) + + def forward(self, arg150_1): + _convolution_pointwise_default = self.conv(arg150_1) + view_73 = torch.ops.aten.reshape.default( + _convolution_pointwise_default, + [batch_size, linear_in_features, out_features], + ) + _convolution_pointwise_default = None + permute_62 = torch.ops.aten.permute.default(view_73, [0, 2, 1]) + view_73 = None + mul_111 = torch.ops.aten.mul.Tensor(self._frozen_param151, permute_62) + add_73 = torch.ops.aten.add.Tensor(self._frozen_param3, mul_111) + permute_63 = torch.ops.aten.permute.default(add_73, [0, 2, 1]) + add_73 = None + view_74 = torch.ops.aten.reshape.default( + permute_63, [flatten_BS, out_features] + ) + permute_63 = None + _mkl_linear_36 = self.linear(view_74) + view_75 = torch.ops.aten.reshape.default( + _mkl_linear_36, [batch_size, linear_in_features, out_features] + ) + _mkl_linear_36 = None + permute_65 = torch.ops.aten.permute.default(view_75, [0, 2, 1]) + view_75 = None + mul_112 = torch.ops.aten.mul.Tensor(self._frozen_param2, permute_65) + _frozen_param2 = permute_65 = None + add_74 = torch.ops.aten.add.Tensor(permute_62, mul_112) + permute_62 = mul_112 = None + return add_74 + + v = torch.randn(batch_size, in_features, 224, 224).to(dtype=dtype) + mod = M(bias=bias).to(dtype=dtype).eval() + with verify(dtype) as (atol, rtol): + self.common(mod, (v,), atol=atol, rtol=rtol) + self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1) + self.assertEqual(counters["inductor"]["cpp_epilogue_fusion_counter"], 1) + @inductor_config.patch({"freezing": True}) @patches @torch.no_grad diff --git a/torch/_inductor/utils.py b/torch/_inductor/utils.py index 5c76dd19822a2d..b17f3a68559a60 100644 --- a/torch/_inductor/utils.py +++ b/torch/_inductor/utils.py @@ -1251,10 +1251,14 @@ def use_cpp_packed_gemm_template(layout, mat1, mat2, mat2_transposed=False): num_threads=parallel_num_threads(), ) + def is_last_dim_stride1(x): + x.freeze_layout() + return x.get_stride()[-1] == 1 + return ( layout.dtype in layout_dtypes and micro_gemm is not None - and mat1.get_stride()[-1] == 1 # TODO(jgong5): support transposed input + and is_last_dim_stride1(mat1) # TODO(jgong5): support transposed input and isinstance(mat2, ir.StorageBox) and mat2.is_module_buffer() ) From faf76de75d0367c015da30601df48c18b8e7cf23 Mon Sep 17 00:00:00 2001 From: zengxian Date: Tue, 10 Sep 2024 03:02:19 +0000 Subject: [PATCH 0656/1018] Fix dynamo benchmark skip logic for cpu device (#135193) Fixes #132380, adjust torchbench and huggingface skip models list, then we can remove `--no-skip` when running benchmarks on 3 suites. Pull Request resolved: https://github.com/pytorch/pytorch/pull/135193 Approved by: https://github.com/chuanqi129, https://github.com/jansel --- ...ctor_amp_freezing_torchbench_inference.csv | 16 +++++++++++++++ ...nductor_freezing_huggingface_inference.csv | 4 ++++ ...inductor_freezing_torchbench_inference.csv | 16 +++++++++++++++ ...tor_amp_freezing_huggingface_inference.csv | 4 ++++ ...ctor_amp_freezing_torchbench_inference.csv | 20 +++++++++++++++++++ ...nductor_freezing_huggingface_inference.csv | 4 ++++ ...inductor_freezing_torchbench_inference.csv | 20 +++++++++++++++++++ .../cpu_inductor_huggingface_inference.csv | 4 ++++ .../cpu_inductor_torchbench_inference.csv | 12 +++++++++++ ...ctor_amp_freezing_torchbench_inference.csv | 16 +++++++++++++++ ...inductor_freezing_torchbench_inference.csv | 16 +++++++++++++++ ...mic_cpu_inductor_huggingface_inference.csv | 4 ++++ ...amic_cpu_inductor_torchbench_inference.csv | 12 +++++++++++ benchmarks/dynamo/common.py | 12 ++++++++--- benchmarks/dynamo/huggingface.py | 2 +- benchmarks/dynamo/huggingface.yaml | 5 ++--- benchmarks/dynamo/runner.py | 7 +------ benchmarks/dynamo/torchbench.py | 8 ++++++-- benchmarks/dynamo/torchbench.yaml | 10 +++++++--- 19 files changed, 174 insertions(+), 18 deletions(-) diff --git a/benchmarks/dynamo/ci_expected_accuracy/cpu_aot_inductor_amp_freezing_torchbench_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/cpu_aot_inductor_amp_freezing_torchbench_inference.csv index 13762422e2a83e..71345480d423d8 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/cpu_aot_inductor_amp_freezing_torchbench_inference.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/cpu_aot_inductor_amp_freezing_torchbench_inference.csv @@ -130,6 +130,10 @@ hf_Bert_large,pass,0 +hf_BigBird,pass,0 + + + hf_DistilBert,pass,0 @@ -142,6 +146,10 @@ hf_GPT2_large,pass_due_to_skip,0 +hf_T5,pass,0 + + + hf_T5_base,pass,0 @@ -170,6 +178,10 @@ maml_omniglot,pass,0 +mnasnet1_0,pass,0 + + + mobilenet_v2,pass,0 @@ -242,6 +254,10 @@ resnext50_32x4d,pass,0 +shufflenet_v2_x1_0,pass,0 + + + soft_actor_critic,fail_to_run,0 diff --git a/benchmarks/dynamo/ci_expected_accuracy/cpu_aot_inductor_freezing_huggingface_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/cpu_aot_inductor_freezing_huggingface_inference.csv index 8d3646d6533a40..1cafcbe55675d3 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/cpu_aot_inductor_freezing_huggingface_inference.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/cpu_aot_inductor_freezing_huggingface_inference.csv @@ -126,6 +126,10 @@ MobileBertForQuestionAnswering,pass,0 +OPTForCausalLM,pass,0 + + + PLBartForCausalLM,pass,0 diff --git a/benchmarks/dynamo/ci_expected_accuracy/cpu_aot_inductor_freezing_torchbench_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/cpu_aot_inductor_freezing_torchbench_inference.csv index 1349ce2537da16..282c18feea4f64 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/cpu_aot_inductor_freezing_torchbench_inference.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/cpu_aot_inductor_freezing_torchbench_inference.csv @@ -130,6 +130,10 @@ hf_Bert_large,pass,0 +hf_BigBird,pass,0 + + + hf_DistilBert,pass,0 @@ -142,6 +146,10 @@ hf_GPT2_large,pass_due_to_skip,0 +hf_T5,pass,0 + + + hf_T5_base,pass,0 @@ -170,6 +178,10 @@ maml_omniglot,pass,0 +mnasnet1_0,pass,0 + + + mobilenet_v2,pass,0 @@ -242,6 +254,10 @@ resnext50_32x4d,pass,0 +shufflenet_v2_x1_0,pass,0 + + + soft_actor_critic,fail_to_run,0 diff --git a/benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_amp_freezing_huggingface_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_amp_freezing_huggingface_inference.csv index cfc3baf0d2ef9d..349239b058a7b7 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_amp_freezing_huggingface_inference.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_amp_freezing_huggingface_inference.csv @@ -130,6 +130,10 @@ MobileBertForQuestionAnswering,pass,0 +OPTForCausalLM,pass,0 + + + PLBartForCausalLM,pass,0 diff --git a/benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_amp_freezing_torchbench_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_amp_freezing_torchbench_inference.csv index c24f8e1b3fbbc3..dafbd90e9aa793 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_amp_freezing_torchbench_inference.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_amp_freezing_torchbench_inference.csv @@ -138,6 +138,10 @@ hf_Bert_large,pass,0 +hf_BigBird,pass,19 + + + hf_DistilBert,pass,0 @@ -150,10 +154,18 @@ hf_GPT2_large,pass_due_to_skip,0 +hf_Longformer,pass,4 + + + hf_Reformer,pass,5 +hf_T5,pass,0 + + + hf_T5_base,pass,0 @@ -178,6 +190,10 @@ maml,pass_due_to_skip,0 +mnasnet1_0,pass,0 + + + maml_omniglot,pass,0 @@ -258,6 +274,10 @@ resnext50_32x4d,pass,0 +shufflenet_v2_x1_0,pass,0 + + + soft_actor_critic,pass,0 diff --git a/benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_freezing_huggingface_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_freezing_huggingface_inference.csv index cfc3baf0d2ef9d..349239b058a7b7 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_freezing_huggingface_inference.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_freezing_huggingface_inference.csv @@ -130,6 +130,10 @@ MobileBertForQuestionAnswering,pass,0 +OPTForCausalLM,pass,0 + + + PLBartForCausalLM,pass,0 diff --git a/benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_freezing_torchbench_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_freezing_torchbench_inference.csv index 71f6ccce68c437..a897806e5188b7 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_freezing_torchbench_inference.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_freezing_torchbench_inference.csv @@ -138,6 +138,10 @@ hf_Bert_large,pass,0 +hf_BigBird,pass,19 + + + hf_DistilBert,pass,0 @@ -150,10 +154,18 @@ hf_GPT2_large,pass_due_to_skip,0 +hf_Longformer,pass,4 + + + hf_Reformer,pass,5 +hf_T5,pass,0 + + + hf_T5_base,pass,0 @@ -182,6 +194,10 @@ maml_omniglot,pass,0 +mnasnet1_0,pass,0 + + + mobilenet_v2,pass,0 @@ -258,6 +274,10 @@ resnext50_32x4d,pass,0 +shufflenet_v2_x1_0,pass,0 + + + soft_actor_critic,pass,0 diff --git a/benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_huggingface_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_huggingface_inference.csv index cfc3baf0d2ef9d..349239b058a7b7 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_huggingface_inference.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_huggingface_inference.csv @@ -130,6 +130,10 @@ MobileBertForQuestionAnswering,pass,0 +OPTForCausalLM,pass,0 + + + PLBartForCausalLM,pass,0 diff --git a/benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_torchbench_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_torchbench_inference.csv index 8b3868107df51e..c8980699f9615d 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_torchbench_inference.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_torchbench_inference.csv @@ -138,6 +138,10 @@ hf_Bert_large,pass,0 +hf_BigBird,pass,13 + + + hf_DistilBert,pass,0 @@ -150,10 +154,18 @@ hf_GPT2_large,pass_due_to_skip,0 +hf_Longformer,pass,4 + + + hf_Reformer,pass,5 +hf_T5,pass,0 + + + hf_T5_base,pass,0 diff --git a/benchmarks/dynamo/ci_expected_accuracy/dynamic_cpu_aot_inductor_amp_freezing_torchbench_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/dynamic_cpu_aot_inductor_amp_freezing_torchbench_inference.csv index 085ea372032622..9fe2b93f08e814 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/dynamic_cpu_aot_inductor_amp_freezing_torchbench_inference.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/dynamic_cpu_aot_inductor_amp_freezing_torchbench_inference.csv @@ -114,6 +114,10 @@ hf_Bert_large,pass,0 +hf_BigBird,pass,0 + + + hf_DistilBert,pass,0 @@ -126,6 +130,10 @@ hf_GPT2_large,pass_due_to_skip,0 +hf_T5,pass,0 + + + hf_T5_base,pass,0 @@ -158,6 +166,10 @@ mobilenet_v2,pass,0 +mnasnet1_0,pass,0 + + + mobilenet_v2_quantized_qat,fail_to_run,0 @@ -226,6 +238,10 @@ resnext50_32x4d,pass,0 +shufflenet_v2_x1_0,pass,0 + + + soft_actor_critic,fail_to_run,0 diff --git a/benchmarks/dynamo/ci_expected_accuracy/dynamic_cpu_aot_inductor_freezing_torchbench_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/dynamic_cpu_aot_inductor_freezing_torchbench_inference.csv index 2f4d4d0a4fa91b..98fe4427622601 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/dynamic_cpu_aot_inductor_freezing_torchbench_inference.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/dynamic_cpu_aot_inductor_freezing_torchbench_inference.csv @@ -114,6 +114,10 @@ hf_Bert_large,pass,0 +hf_BigBird,pass,0 + + + hf_DistilBert,pass,0 @@ -126,6 +130,10 @@ hf_GPT2_large,pass_due_to_skip,0 +hf_T5,pass,0 + + + hf_T5_base,pass,0 @@ -154,6 +162,10 @@ maml_omniglot,pass,0 +mnasnet1_0,pass,0 + + + mobilenet_v2,pass,0 @@ -226,6 +238,10 @@ resnext50_32x4d,pass,0 +shufflenet_v2_x1_0,pass,0 + + + soft_actor_critic,fail_to_run,0 diff --git a/benchmarks/dynamo/ci_expected_accuracy/dynamic_cpu_inductor_huggingface_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/dynamic_cpu_inductor_huggingface_inference.csv index cfc3baf0d2ef9d..349239b058a7b7 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/dynamic_cpu_inductor_huggingface_inference.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/dynamic_cpu_inductor_huggingface_inference.csv @@ -130,6 +130,10 @@ MobileBertForQuestionAnswering,pass,0 +OPTForCausalLM,pass,0 + + + PLBartForCausalLM,pass,0 diff --git a/benchmarks/dynamo/ci_expected_accuracy/dynamic_cpu_inductor_torchbench_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/dynamic_cpu_inductor_torchbench_inference.csv index 449d34d42f68a2..9b6ec5b6cddecf 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/dynamic_cpu_inductor_torchbench_inference.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/dynamic_cpu_inductor_torchbench_inference.csv @@ -122,6 +122,10 @@ hf_Bert_large,pass,0 +hf_BigBird,pass,13 + + + hf_DistilBert,pass,0 @@ -134,10 +138,18 @@ hf_GPT2_large,pass_due_to_skip,0 +hf_Longformer,pass,4 + + + hf_Reformer,pass,5 +hf_T5,pass,0 + + + hf_T5_base,pass,0 diff --git a/benchmarks/dynamo/common.py b/benchmarks/dynamo/common.py index bea9fc76e226a7..b3de936bad1b48 100644 --- a/benchmarks/dynamo/common.py +++ b/benchmarks/dynamo/common.py @@ -2370,7 +2370,11 @@ def skip_models_for_cpu(self): return set() @property - def skip_models_for_freezing(self): + def skip_models_for_freezing_cpu(self): + return set() + + @property + def skip_models_for_freezing_cuda(self): return set() @property @@ -4275,7 +4279,6 @@ def run(runner, args, original_dir=None): runner.skip_models.update(runner.slow_models) if args.devices == ["cpu"]: - runner.skip_models.update(runner.very_slow_models) runner.skip_models.update(runner.skip_models_for_cpu) elif args.devices == ["cuda"]: runner.skip_models.update(runner.skip_models_for_cuda) @@ -4284,7 +4287,10 @@ def run(runner, args, original_dir=None): runner.skip_models.update(runner.skip_multiprocess_models) if args.freezing: - runner.skip_models.update(runner.skip_models_for_freezing) + if args.devices == ["cpu"]: + runner.skip_models.update(runner.skip_models_for_freezing_cpu) + elif args.devices == ["cuda"]: + runner.skip_models.update(runner.skip_models_for_freezing_cuda) if args.no_skip: runner.skip_models.clear() diff --git a/benchmarks/dynamo/huggingface.py b/benchmarks/dynamo/huggingface.py index 40268c0773ae74..06bf4f0ee7610a 100755 --- a/benchmarks/dynamo/huggingface.py +++ b/benchmarks/dynamo/huggingface.py @@ -505,7 +505,7 @@ def get_tolerance_and_cosine_flag(self, is_training, current_device, name): return 4e-3, cosine if ( current_device == "cpu" - and name in self._config["tolerance"]["higher_inference"] + and name in self._config["tolerance"]["higher_inference_cpu"] ): return 4e-3, cosine return 1e-3, cosine diff --git a/benchmarks/dynamo/huggingface.yaml b/benchmarks/dynamo/huggingface.yaml index 9650defcbfac66..2ddc242537d6e7 100644 --- a/benchmarks/dynamo/huggingface.yaml +++ b/benchmarks/dynamo/huggingface.yaml @@ -11,9 +11,7 @@ skip: - GPTJForQuestionAnswering device: - cpu: - # OOMs - - OPTForCausalLM + cpu: [] control_flow: - AllenaiLongformerBase @@ -71,6 +69,7 @@ batch_size: TrOCRForCausalLM: 2 XGLMForCausalLM: 4 XLNetLMHeadModel: 2 + YituTechConvBert: 2 tolerance: diff --git a/benchmarks/dynamo/runner.py b/benchmarks/dynamo/runner.py index 747a8e06e635ad..2b13a4b5c5308a 100755 --- a/benchmarks/dynamo/runner.py +++ b/benchmarks/dynamo/runner.py @@ -387,11 +387,6 @@ def get_skip_tests(suite, device, is_training: bool): skip_tests.update(module.TorchBenchmarkRunner().skip_models_for_cpu) elif device == "cuda": skip_tests.update(module.TorchBenchmarkRunner().skip_models_for_cuda) - else: - if hasattr(module, "SKIP"): - skip_tests.update(module.SKIP) - if is_training and hasattr(module, "SKIP_TRAIN"): - skip_tests.update(module.SKIP_TRAIN) skip_tests = (f"-x {name}" for name in skip_tests) skip_str = " ".join(skip_tests) @@ -438,7 +433,7 @@ def generate_commands(args, dtypes, suites, devices, compilers, output_dir): if args.enable_cpu_launcher: launcher_cmd = f"python -m torch.backends.xeon.run_cpu {args.cpu_launcher_args}" cmd = f"{launcher_cmd} benchmarks/dynamo/{suite}.py --{testing} --{dtype} -d{device} --output={output_filename}" - cmd = f"{cmd} {base_cmd} {args.extra_args} --no-skip --dashboard" + cmd = f"{cmd} {base_cmd} {args.extra_args} --dashboard" skip_tests_str = get_skip_tests(suite, device, args.training) cmd = f"{cmd} {skip_tests_str}" diff --git a/benchmarks/dynamo/torchbench.py b/benchmarks/dynamo/torchbench.py index 6233ecb4e9ae9b..b27951b16494d6 100755 --- a/benchmarks/dynamo/torchbench.py +++ b/benchmarks/dynamo/torchbench.py @@ -138,8 +138,12 @@ def skip_models_for_cuda(self): return self._skip["device"]["cuda"] @property - def skip_models_for_freezing(self): - return self._skip["freezing"] + def skip_models_for_freezing_cuda(self): + return self._skip["freezing"]["cuda"] + + @property + def skip_models_for_freezing_cpu(self): + return self._skip["freezing"]["cpu"] @property def slow_models(self): diff --git a/benchmarks/dynamo/torchbench.yaml b/benchmarks/dynamo/torchbench.yaml index 84c7de12a36ee6..0b9a083515a69a 100644 --- a/benchmarks/dynamo/torchbench.yaml +++ b/benchmarks/dynamo/torchbench.yaml @@ -191,6 +191,7 @@ skip: - hf_Whisper - stable_diffusion_text_encoder - llava + - moco cuda: [] @@ -232,10 +233,13 @@ skip: # for these models, conv-batchnorm fusing causes big numerical churn. # Skip them + # mnasnet1_0 and shufflenet_v2_x1_0 can pass on cpu, moco cuda only. freezing: - - mnasnet1_0 - - moco - - shufflenet_v2_x1_0 + cuda: + - mnasnet1_0 + - moco + - shufflenet_v2_x1_0 + cpu: [] From c4641d44d1e5fa6b651d797be6752b1b80ab1e0c Mon Sep 17 00:00:00 2001 From: Chien-Chin Huang Date: Mon, 9 Sep 2024 15:46:24 -0700 Subject: [PATCH 0657/1018] [DCP] Fixes the stateless optimizer issue of distributed state_dict (#135535) Some optimizers don't have states that can cause get_state_dict/set_state_dict behave incorrectly. This PR fixes the issues. fixes: https://github.com/pytorch/pytorch/issues/133415 Pull Request resolved: https://github.com/pytorch/pytorch/pull/135535 Approved by: https://github.com/wz337 --- .../distributed/checkpoint/test_state_dict.py | 71 +++++++++++-------- torch/distributed/checkpoint/state_dict.py | 16 ++--- .../distributed/common_state_dict.py | 7 +- 3 files changed, 54 insertions(+), 40 deletions(-) diff --git a/test/distributed/checkpoint/test_state_dict.py b/test/distributed/checkpoint/test_state_dict.py index e299e3ee8269ed..e3dfb782ad5657 100644 --- a/test/distributed/checkpoint/test_state_dict.py +++ b/test/distributed/checkpoint/test_state_dict.py @@ -86,18 +86,19 @@ def _test_save_load( model, optim, copy_optim, dist_model, dist_optim = init_model_optim() # Train 10 steps. + _dist_optim = [dist_optim] if not isinstance(dist_optim, list) else dist_optim for i in range(10): + optim.zero_grad() + for d_optim in _dist_optim: + d_optim.zero_grad() + batch = torch.rand(8, 100, device="cuda") model(batch).sum().backward() - optim.step() dist_model(batch).sum().backward() - if not isinstance(dist_optim, list): - dist_optim.step() - dist_optim.zero_grad() - else: - for _dist_optim in dist_optim: - _dist_optim.zero_grad() - optim.zero_grad() + + optim.step() + for d_optim in _dist_optim: + d_optim.step() # Get the state_dict, and compare the result msd = model.state_dict() @@ -176,8 +177,8 @@ def init_model_optim(): device_mesh = init_device_mesh("cuda", (self.world_size,)) orig_model = CompositeParamModel(device=torch.device("cuda")) - orig_optim = optimizer_class(orig_model.parameters(), lr=1e-3, foreach=True) - copy_optim = optimizer_class(orig_model.parameters(), lr=1e-3, foreach=True) + orig_optim = optimizer_class(orig_model.parameters(), lr=1e-4, foreach=True) + copy_optim = optimizer_class(orig_model.parameters(), lr=1e-4, foreach=True) if wrapping: strategy = set(wrapping) else: @@ -204,7 +205,7 @@ def init_model_optim(): if compile_model: dist_model = torch.compile(dist_model) - dist_optim = optimizer_class(dist_model.parameters(), lr=1e-3, foreach=True) + dist_optim = optimizer_class(dist_model.parameters(), lr=1e-4, foreach=True) return orig_model, orig_optim, copy_optim, dist_model, dist_optim self._test_save_load(init_model_optim) @@ -218,7 +219,11 @@ def test_fsdp(self) -> None: "use_composable": [True, False], "use_dtensor": [True, False], "wrapping": [(), (nn.Linear, UnitModule)], - "optimizer_class": [torch.optim.Adam, torch.optim.AdamW], + "optimizer_class": [ + torch.optim.Adam, + torch.optim.AdamW, + torch.optim.SGD, + ], }, self._test_fsdp, ) @@ -248,10 +253,10 @@ def _test_fsdp2( def init_model_optim(): orig_model = CompositeParamModel(device=torch.device("cuda")) orig_optim = optimizer_class( - orig_model.parameters(), lr=1e-3, foreach=foreach + orig_model.parameters(), lr=1e-4, foreach=foreach ) copy_optim = optimizer_class( - orig_model.parameters(), lr=1e-3, foreach=foreach + orig_model.parameters(), lr=1e-4, foreach=foreach ) dist_model = FSDP2( @@ -262,7 +267,7 @@ def init_model_optim(): if compile_model: dist_model = torch.compile(dist_model) dist_optim = optimizer_class( - dist_model.parameters(), lr=1e-3, foreach=foreach + dist_model.parameters(), lr=1e-4, foreach=foreach ) return orig_model, orig_optim, copy_optim, dist_model, dist_optim @@ -284,13 +289,13 @@ def test_fsdp2(self) -> None: def _test_ddp(self, use_composable: bool, optimizer_class: Type[Optimizer]) -> None: def init_model_optim(): orig_model = CompositeParamModel(device=torch.device("cuda")) - orig_optim = optimizer_class(orig_model.parameters(), lr=1e-3) - copy_optim = optimizer_class(orig_model.parameters(), lr=1e-3) + orig_optim = optimizer_class(orig_model.parameters(), lr=1e-4) + copy_optim = optimizer_class(orig_model.parameters(), lr=1e-4) if use_composable: dist_model = replicate(copy.deepcopy(orig_model)) else: dist_model = DDP(copy.deepcopy(orig_model)) - dist_optim = optimizer_class(dist_model.parameters(), lr=1e-3) + dist_optim = optimizer_class(dist_model.parameters(), lr=1e-4) return orig_model, orig_optim, copy_optim, dist_model, dist_optim self._test_save_load(init_model_optim) @@ -301,7 +306,11 @@ def test_ddp(self) -> None: self.run_subtests( { "use_composable": [True, False], - "optimizer_class": [torch.optim.Adam, torch.optim.AdamW], + "optimizer_class": [ + torch.optim.Adam, + torch.optim.AdamW, + torch.optim.SGD, + ], }, self._test_ddp, ) @@ -320,8 +329,8 @@ def init_model_optim(): orig_model.u1.parameters(), orig_model.u2.parameters() ): param.requires_grad = False - orig_optim = optimizer_class(orig_model.parameters(), lr=1e-3) - copy_optim = optimizer_class(orig_model.parameters(), lr=1e-3) + orig_optim = optimizer_class(orig_model.parameters(), lr=1e-4) + copy_optim = optimizer_class(orig_model.parameters(), lr=1e-4) dist_model = copy.deepcopy(orig_model) if use_composable: replicate(dist_model.l) @@ -336,13 +345,13 @@ def init_model_optim(): ) if optim_in_backward: _apply_optimizer_in_backward( - optimizer_class, dist_model.parameters(), {"lr": 1e-3} + optimizer_class, dist_model.parameters(), {"lr": 1e-4} ) dist_optim = [ p._in_backward_optimizers[0] for p in dist_model.parameters() ] else: - dist_optim = optimizer_class(dist_model.parameters(), lr=1e-3) + dist_optim = optimizer_class(dist_model.parameters(), lr=1e-4) return orig_model, orig_optim, copy_optim, dist_model, dist_optim self._test_save_load(init_model_optim, test_frozen) @@ -395,10 +404,10 @@ def test_apply_optimizer_in_backward(self) -> None: def _test_single_gpu(self, optimizer_class: Type[Optimizer]) -> None: def init_model_optim(): orig_model = CompositeParamModel(device=torch.device("cuda")) - orig_optim = optimizer_class(orig_model.parameters(), lr=1e-3) - copy_optim = optimizer_class(orig_model.parameters(), lr=1e-3) + orig_optim = optimizer_class(orig_model.parameters(), lr=1e-4) + copy_optim = optimizer_class(orig_model.parameters(), lr=1e-4) model_copy = copy.deepcopy(orig_model) - optim_copy = optimizer_class(model_copy.parameters(), lr=1e-3) + optim_copy = optimizer_class(model_copy.parameters(), lr=1e-4) return orig_model, orig_optim, copy_optim, model_copy, optim_copy self._test_save_load(init_model_optim) @@ -445,7 +454,7 @@ def _test_cpu_offload_full_state_dict( device_mesh=device_mesh, ) - dist_optim = optimizer_class(dist_model.parameters(), lr=1e-3) + dist_optim = optimizer_class(dist_model.parameters(), lr=1e-4) mst, ost = get_state_dict( dist_model, @@ -887,10 +896,10 @@ def forward(self, input): def init_model_optim(): device_mesh = init_device_mesh("cuda", (self.world_size,)) orig_model = TiedEmbeddingModel(10000, 300).to(torch.device("cuda")) - orig_optim = torch.optim.AdamW(orig_model.parameters(), lr=1e-3) - copy_optim = torch.optim.AdamW(orig_model.parameters(), lr=1e-3) + orig_optim = torch.optim.AdamW(orig_model.parameters(), lr=1e-4) + copy_optim = torch.optim.AdamW(orig_model.parameters(), lr=1e-4) dist_model = FSDP(copy.deepcopy(orig_model), device_mesh=device_mesh) - dist_optim = torch.optim.AdamW(dist_model.parameters(), lr=1e-3) + dist_optim = torch.optim.AdamW(dist_model.parameters(), lr=1e-4) return orig_model, orig_optim, copy_optim, dist_model, dist_optim self._test_save_load(init_model_optim) @@ -958,7 +967,7 @@ def setUp(self) -> None: @skip_if_lt_x_gpu(1) def test_no_dist(self) -> None: model = CompositeParamModel(device=torch.device("cuda")) - optim = torch.optim.AdamW(model.parameters(), lr=1e-3) + optim = torch.optim.AdamW(model.parameters(), lr=1e-4) self.assertFalse(dist.is_initialized()) msd = get_model_state_dict( diff --git a/torch/distributed/checkpoint/state_dict.py b/torch/distributed/checkpoint/state_dict.py index bd0bff271a5974..ef90b32eb2042f 100644 --- a/torch/distributed/checkpoint/state_dict.py +++ b/torch/distributed/checkpoint/state_dict.py @@ -591,17 +591,17 @@ def _init_optim_state(optim: torch.optim.Optimizer) -> None: # The optimizer state is initialized. return + # There are some stateless optimizers like SGD. These optimizer will + # not return in the above condition. So if gradients exist, we should also + # return. If gradients do not exist, the following initialization should + # not disturb SGD because the gradients and lr are both zero. for param_group in optim.param_groups: for param in param_group[_PARAMS]: if param.grad is not None: - raise RuntimeError( - "state_dict can only be used if the optimizer " - "states are initialized (usually after one step() with " - "gradients) or gradients are None. For the later case, " - "state_dict will fake the gradients as zero " - "to initialize the optimizer states. However, the " - "gradients are not None." - ) + return + + for param_group in optim.param_groups: + for param in param_group[_PARAMS]: if param.requires_grad: param.grad = torch.zeros_like(param) diff --git a/torch/testing/_internal/distributed/common_state_dict.py b/torch/testing/_internal/distributed/common_state_dict.py index 5aa8d7f14aa0ab..c2ab9af9f197e8 100644 --- a/torch/testing/_internal/distributed/common_state_dict.py +++ b/torch/testing/_internal/distributed/common_state_dict.py @@ -43,7 +43,12 @@ def _verify_msd( dist_param = dist_msd.get(fqn, None) if not options.ignore_frozen_params: self.assertIsNotNone(dist_param, f"{fqn=}") - self._compare_tensor(param, dist_param, offload_to_cpu) + try: + self._compare_tensor(param, dist_param, offload_to_cpu) + except AssertionError as e: + raise AssertionError( + f"{fqn} has mismatched value {param} {dist_param}" + ) from e elif dist_param is None: self.assertFalse(param.requires_grad, f"{fqn=}") From 6f2ea46e6ce8ed4fb6b005d52a19e3169f7d2ec7 Mon Sep 17 00:00:00 2001 From: Avik Chaudhuri Date: Tue, 10 Sep 2024 03:47:36 +0000 Subject: [PATCH 0658/1018] do not raise when flatten_fn_with_keys not found when suggesting fixes (#135518) Test Plan: added test Differential Revision: D62395371 Pull Request resolved: https://github.com/pytorch/pytorch/pull/135518 Approved by: https://github.com/zhxchen17 --- test/export/test_export.py | 30 ++++++++++++++++++++++++ torch/fx/experimental/symbolic_shapes.py | 11 ++++++++- 2 files changed, 40 insertions(+), 1 deletion(-) diff --git a/test/export/test_export.py b/test/export/test_export.py index 0e26947f5c3026..7ee5e5c982b0e4 100644 --- a/test/export/test_export.py +++ b/test/export/test_export.py @@ -2490,6 +2490,36 @@ def forward(self, xs, y): strict=strict, ) + class Box: + def __init__(self, content): + self.content = content + + from torch.utils._pytree import register_pytree_node + + register_pytree_node( + Box, + lambda box: ([box.content], None), # flatten_fn + lambda contents, _context: Box(*contents), # unflatten_fn + flatten_with_keys_fn=None, # unflatten_fn + serialized_type_name="test_no_suggested_fixes_for_data_dependent_errors.Box", + ) + + class cf_stacklist_udd(torch.nn.Module): + def forward(self, xs, y): + box = Box(y.item()) + # box.content is not a local, so we can't suggest a fix + return torch.stack(xs, 0).narrow(0, box.content, 1).squeeze() + + with self.assertRaisesRegex( + error_type, + "Could not guard on data-dependent expression u0 < 0", + ): + export( + cf_stacklist_udd(), + ([torch.ones(5) * i for i in range(10)], torch.tensor(2)), + strict=strict, + ) + def test_tolist(self): class M(torch.nn.Module): def forward(self, x): diff --git a/torch/fx/experimental/symbolic_shapes.py b/torch/fx/experimental/symbolic_shapes.py index 369708d2c2dbc0..21e72542d34cda 100644 --- a/torch/fx/experimental/symbolic_shapes.py +++ b/torch/fx/experimental/symbolic_shapes.py @@ -5588,8 +5588,17 @@ def _suggest_fixes_for_data_dependent_error_non_strict(e): # map symbol names reachable via frame locals to their source-level names src_map = defaultdict(list) for var, val in frame.f_locals.items(): + try: + tree_leaves_with_path = pytree.tree_leaves_with_path(val) + except ValueError: + log.warning( + "pytree.tree_leaves_with_path failed for value of type {%s} in local variable {%s}", + type(val), + var, + ) + continue # figure out how to access any symbol inside `val` through `var` - for path, leaf in pytree.tree_leaves_with_path(val): + for path, leaf in tree_leaves_with_path: name = var + pytree.keystr(path) if isinstance(leaf, torch.SymInt): src_map[str(leaf.node.expr)].append(name) From e57006dc9a7165cd4459af2ab4bc57cc180725f3 Mon Sep 17 00:00:00 2001 From: Thomas Bohnstingl Date: Tue, 10 Sep 2024 04:51:16 +0000 Subject: [PATCH 0659/1018] Implementation of scan (#134102) This operation is supposed to be the pendant to the `associative_scan`, but can operate with non-associative functions. @ydwu4 Pull Request resolved: https://github.com/pytorch/pytorch/pull/134102 Approved by: https://github.com/ydwu4 --- test/functorch/test_control_flow.py | 1794 +++++++++++++++++-- test/test_fake_tensor.py | 16 + torch/_dynamo/trace_rules.py | 2 + torch/_dynamo/variables/higher_order_ops.py | 219 ++- torch/_higher_order_ops/associative_scan.py | 95 +- torch/_higher_order_ops/scan.py | 438 +++++ torch/_higher_order_ops/utils.py | 1 + torch/_inductor/lowering.py | 10 +- 8 files changed, 2360 insertions(+), 215 deletions(-) create mode 100644 torch/_higher_order_ops/scan.py diff --git a/test/functorch/test_control_flow.py b/test/functorch/test_control_flow.py index ba57a91a9e34a4..e07510828f4584 100644 --- a/test/functorch/test_control_flow.py +++ b/test/functorch/test_control_flow.py @@ -8,6 +8,7 @@ from functorch.experimental import control_flow from functorch.experimental.control_flow import cond, UnsupportedAliasMutationException from torch._higher_order_ops.associative_scan import associative_scan +from torch._higher_order_ops.scan import scan from torch._higher_order_ops.while_loop import while_loop from torch._subclasses.functional_tensor import ( CppFunctionalizeAPI, @@ -23,6 +24,7 @@ instantiate_parametrized_tests, IS_WINDOWS, parametrize, + requires_cuda, run_tests, skipIfTorchDynamo, TEST_WITH_TORCHDYNAMO, @@ -81,8 +83,8 @@ def _fake_while_loop(cond_fn, body_fn, operands): return operands -def _fake_associative_scan(combine_fn, input, dim, reverse=False): - inp_leaves, spec = pytree.tree_flatten(input) +def _fake_associative_scan(combine_fn, xs, dim, reverse=False): + inp_leaves, spec = pytree.tree_flatten(xs) result_flat = [] num_leaves = len(inp_leaves) op = reversed if reverse else lambda x: x @@ -109,6 +111,118 @@ def _fake_associative_scan(combine_fn, input, dim, reverse=False): return pytree.tree_unflatten(results, spec) +def _fake_scan(combine_fn, init, xs=None, dim=0, reverse=False): + carry_leaves, carry_spec = pytree.tree_flatten(init) + inp_leaves, inp_spec = pytree.tree_flatten(xs) + if xs is None or len(inp_leaves) == 0: + return init, [] + result_flat = [] + carry = carry_leaves + op = reversed if reverse else lambda x: x + + dummy_carry, dummy_out = combine_fn( + pytree.tree_unflatten(carry, carry_spec), + pytree.tree_unflatten( + [torch._ops.ops.aten.slice(elem, dim, 0, 1, 1) for elem in inp_leaves], + inp_spec, + ), + ) + dummy_out_leaves, dummy_out_spec = pytree.tree_flatten(dummy_out) + num_leaves = len(dummy_out_leaves) + + for ind in op(range(inp_leaves[0].size(dim))): + xs = [ + torch._ops.ops.aten.slice(elem, dim, ind, ind + 1, 1) for elem in inp_leaves + ] + + carry, y = combine_fn( + pytree.tree_unflatten(carry, carry_spec), + pytree.tree_unflatten(xs, inp_spec), + ) + carry, _ = pytree.tree_flatten(carry) + y, _ = pytree.tree_flatten(y) + result_flat.append(y) + + results = [ + torch.concatenate([e[leave_ind] for e in op(result_flat)], dim) + for leave_ind in range(num_leaves) + ] + return ( + pytree.tree_unflatten(carry, carry_spec), + pytree.tree_unflatten(results, dummy_out_spec), + ) + + +def compile_mode_helper(fct, compile_mode): + if compile_mode == "compile": + return torch.compile(fct, fullgraph=True, dynamic=False) + elif compile_mode == "compile_dynamic_shape": + return torch.compile(fct, fullgraph=True, dynamic=True) + elif compile_mode == "eager": + return torch.compile(fct, fullgraph=True, backend="eager") + else: + return fct + + +def get_scan_combine_fn(name, associative=True): + def add(x: torch.Tensor, y: torch.Tensor): + return x + y + + def adds(x: torch.Tensor, y: torch.Tensor): + return x + x, y + y + + def mul(x: torch.Tensor, y: torch.Tensor): + return x * y + + def div(x: torch.Tensor, y: torch.Tensor): + return x / y + + def s5_operator(x, y): + A_i, Bu_i = x + A_j, Bu_j = y + return A_j * A_i, A_j * Bu_i + Bu_j + + def tuple_fct(x, y): + return (x[0] + y[0], x[1] * y[1]) + + def complex_pointwise(x, y): + return { + "i": x["i"] * y["i"], + "j": ( + [x["j"][0][0] * y["j"][0][0]], + [{"o": x["j"][1][0]["o"] + y["j"][1][0]["o"]}], + ), + } + + def non_pointwise(x: torch.Tensor, y: torch.Tensor): + W = torch.diag(torch.ones(2, device=x.device)) + return x @ W + y @ W + + if name == "add": + fct = add + elif name == "adds": + fct = adds + elif name == "mul": + fct = mul + elif name == "div": + fct = div + elif name == "s5_operator": + fct = s5_operator + elif name == "tuple_fct": + fct = tuple_fct + elif name == "complex_pointwise": + fct = complex_pointwise + elif name == "non_pointwise": + fct = non_pointwise + else: + raise ValueError("Combine_fn name unknown!") + + if not associative: + return lambda x, y: (fct(x, y), fct(x, y)) + else: + return fct + + def _while_loop_tests(): def simple(x): def cond_fn(x): @@ -1205,9 +1319,11 @@ def fwbw(map_op, f, x, y): fake_outs = fwbw(_fake_map, f, x, y) self.assertEqual(true_outs, fake_outs) + # TODO: provide an implementation for all compile modes and re-enable all test @unittest.skipIf(not SM70OrLater, "triton") - @unittest.skipIf(not torch.cuda.is_available(), "Test requires CUDA.") + @requires_cuda @parametrize("reverse", [False, True]) + @parametrize("compile_mode", ["none", "compile", "compile_dynamic_shape"]) @parametrize("combine_mode", ["pointwise", "generic"]) @parametrize("device", [torch.device("cpu"), torch.device("cuda")]) # Skipping the combination of combine_mode=pointwise and device=cpu @@ -1219,28 +1335,36 @@ def fwbw(map_op, f, x, y): and params["device"] == torch.device("cpu") ), ) - def test_pointwise_associative_scan_simple(self, reverse, combine_mode, device): - def add(x: torch.Tensor, y: torch.Tensor): - return x + y - - def mul(x: torch.Tensor, y: torch.Tensor): - return x * y - + def test_associative_scan_compile( + self, combine_mode, reverse, compile_mode, device + ): x = torch.randn(3, 10, 2, device=device) - for op, op_pt in [(add, torch.cumsum), (mul, torch.cumprod)]: - result = associative_scan( - op, x, 0, reverse=reverse, combine_mode=combine_mode - ) - result_exp = _fake_associative_scan(op, x, 0, reverse=reverse) + scan_fct = compile_mode_helper(associative_scan, compile_mode) + + for op, op_pt in [ + (get_scan_combine_fn("add", True), torch.cumsum), + (get_scan_combine_fn("mul", True), torch.cumprod), + ]: + result = scan_fct(op, x, 0, reverse=reverse, combine_mode=combine_mode) + result_exp = _fake_associative_scan(op, xs=x, dim=0, reverse=reverse) self.assertEqual(result, result_exp) + if not reverse: + result_exp_PT = op_pt(x, 0) + self.assertEqual(result, result_exp_PT) # Jax Examples x = torch.arange(0, 4, device=device) - cumsum1 = associative_scan( - add, x, 0, reverse=reverse, combine_mode=combine_mode + cumsum1 = scan_fct( + get_scan_combine_fn("add", True), + x, + 0, + reverse=reverse, + combine_mode=combine_mode, + ) + cumsum_exp = _fake_associative_scan( + get_scan_combine_fn("add", True), x, 0, reverse=reverse ) - cumsum_exp = _fake_associative_scan(add, x, 0, reverse=reverse) if not reverse: self.assertEqual( cumsum1, torch.tensor([0.0, 1.0, 3.0, 6.0], dtype=torch.int64) @@ -1251,8 +1375,177 @@ def mul(x: torch.Tensor, y: torch.Tensor): ) self.assertEqual(cumsum1, cumsum_exp) + # TODO: provide an implementation for all compile modes and re-enable all test + @requires_cuda + @parametrize("reverse", [False, True]) + @parametrize("compile_mode", ["none", "eager"]) + @parametrize("device", [torch.device("cpu"), torch.device("cuda")]) + def test_scan_compile(self, reverse, compile_mode, device): + def add2(x: torch.Tensor, y: torch.Tensor): + return x * y, x + y + + x = torch.randn(3, 10, 2, device=device) + + scan_fct = compile_mode_helper(scan, compile_mode) + + for op, op_pt, init in [ + ( + get_scan_combine_fn("add", False), + torch.cumsum, + torch.zeros(1, 10, 2, device=device), + ), + ( + get_scan_combine_fn("mul", False), + torch.cumprod, + torch.ones(1, 10, 2, device=device), + ), + ]: + result = scan_fct(op, init, x, dim=0, reverse=reverse) + result_exp = _fake_scan(op, init=init, xs=x, dim=0, reverse=reverse) + self.assertEqual(result, result_exp) + if not reverse: + result_exp_PT = op_pt(x, 0) + self.assertEqual(result[1], result_exp_PT) + + # Jax Examples + x = torch.arange(0, 4, device=device, dtype=torch.int64) + init = torch.zeros(1, device=device, dtype=torch.int64) + cumsum1 = scan_fct( + get_scan_combine_fn("add", False), + init, + x, + dim=0, + reverse=reverse, + ) + cumsum_exp = _fake_scan( + get_scan_combine_fn("add", False), + init=init, + xs=x, + dim=0, + reverse=reverse, + ) + if not reverse: + self.assertEqual( + cumsum1[1], torch.tensor([0.0, 1.0, 3.0, 6.0], dtype=torch.int64) + ) + self.assertEqual(cumsum1[0], torch.tensor([6.0], dtype=torch.int64)) + else: + self.assertEqual( + cumsum1[1], torch.tensor([6.0, 6.0, 5.0, 3.0], dtype=torch.int64) + ) + self.assertEqual(cumsum1[0], torch.tensor([6.0], dtype=torch.int64)) + self.assertEqual(cumsum1, cumsum_exp) + + # Different carry computation as output computation + x = torch.arange(1, 5, device=device, dtype=torch.int64) + init = torch.ones(1, device=device, dtype=torch.int64) + result = scan_fct(add2, init, x, dim=0, reverse=reverse) + result_exp = _fake_scan(add2, init=init, xs=x, dim=0, reverse=reverse) + if not reverse: + self.assertEqual( + result[1], torch.tensor([2.0, 3.0, 5.0, 10.0], dtype=torch.int64) + ) + self.assertEqual(result[0], torch.tensor([24.0], dtype=torch.int64)) + else: + self.assertEqual( + result[1], torch.tensor([25.0, 14.0, 7.0, 5.0], dtype=torch.int64) + ) + self.assertEqual(result[0], torch.tensor([24.0], dtype=torch.int64)) + self.assertEqual(result, result_exp) + + # Non associative operation + x = torch.arange(0, 5, device=device, dtype=torch.float32) + init = torch.ones(1, device=device, dtype=torch.float32) + result = scan_fct( + get_scan_combine_fn("div", False), + init, + x, + dim=0, + reverse=reverse, + ) + result_exp = _fake_scan( + get_scan_combine_fn("div", False), + init=init, + xs=x, + dim=0, + reverse=reverse, + ) + self.assertEqual(result, result_exp) + + # TODO: provide an implementation for all compile modes and re-enable all test + @requires_cuda + @parametrize("reverse", [False, True]) + @parametrize("compile_mode", ["none", "eager"]) + @parametrize("device", [torch.device("cpu"), torch.device("cuda")]) + @parametrize( + "dtype", + [ + torch.float16, + torch.float32, + torch.int32, + torch.int64, + torch.complex64, + ], + ) + def test_scan_dtype(self, reverse, compile_mode, device, dtype): + scan_fct = compile_mode_helper(scan, compile_mode) + + # Check all outputs and carries on the correct device and with torch.float32 + x = torch.randn(3, 10, 2, device=device).to(dtype=dtype) + op, init = ( + get_scan_combine_fn("adds"), + torch.zeros(1, 10, 2, device=device, dtype=dtype), + ) + result = scan_fct(op, init, x, dim=0, reverse=reverse) + result_exp = _fake_scan(op, init=init, xs=x, dim=0, reverse=reverse) + self.assertEqual(result, result_exp) + self.assertEqual( + [[r.device.type for r in res] for res in result], + [[device.type for _ in res] for res in result], + ) + self.assertEqual( + [[r.dtype for r in res] for res in result], + [[dtype for _ in res] for res in result], + ) + + # Check all outputs and carries on the correct device and + # carry.dtype torch.float32 and output.dtype torch.float16 + x = torch.randn(3, 10, 2, device=device).to(dtype=dtype) + op, init = ( + get_scan_combine_fn("adds"), + torch.zeros(1, 10, 2, device=device, dtype=torch.float32), + ) + result = scan_fct(op, init, x, dim=0, reverse=reverse) + result_exp = _fake_scan(op, init=init, xs=x, dim=0, reverse=reverse) + self.assertEqual(result, result_exp) + self.assertEqual( + [[r.dtype for r in res] for res in result], + [ + [torch.float32 for _ in range(len(result[0]))], + [dtype for _ in range(len(result[1]))], + ], + ) + + # Check all outputs and carries on the correct device and + # carry.dtype torch.int64 and output.dtype torch.float32 + x = torch.randn(3, 10, 2, device=device) + op, init = ( + get_scan_combine_fn("adds"), + torch.zeros(1, 10, 2, device=device, dtype=dtype), + ) + result = scan_fct(op, init, x, dim=0, reverse=reverse) + result_exp = _fake_scan(op, init=init, xs=x, dim=0, reverse=reverse) + self.assertEqual(result, result_exp) + self.assertEqual( + [[r.dtype for r in res] for res in result], + [ + [dtype for _ in range(len(result[0]))], + [torch.float32 for _ in range(len(result[1]))], + ], + ) + @unittest.skipIf(not SM70OrLater, "triton") - @unittest.skipIf(not torch.cuda.is_available(), "Test requires CUDA.") + @requires_cuda @parametrize("reverse", [False, True]) @parametrize("combine_mode", ["pointwise", "generic"]) @parametrize("device", [torch.device("cpu"), torch.device("cuda")]) @@ -1265,22 +1558,19 @@ def mul(x: torch.Tensor, y: torch.Tensor): and params["device"] == torch.device("cpu") ), ) - def test_pointwise_associative_scan_dim(self, reverse, combine_mode, device): + def test_associative_scan_dim(self, combine_mode, reverse, device): import random - def add(x: torch.Tensor, y: torch.Tensor): - return x + y - - def mul(x: torch.Tensor, y: torch.Tensor): - return x * y - num_dims = [random.randint(2, 5) for _ in range(10)] for num_dim in num_dims: shapes = [random.randint(1, 10) for _ in range(num_dim)] rnd_scan_dim = random.randint(0, num_dim - 1) x = torch.randn(*shapes, device=device) - for op, op_pt in [(add, torch.cumsum), (mul, torch.cumprod)]: + for op, op_pt in [ + (get_scan_combine_fn("add", True), torch.cumsum), + (get_scan_combine_fn("mul", True), torch.cumprod), + ]: result = associative_scan( op, x, rnd_scan_dim, reverse=reverse, combine_mode=combine_mode ) @@ -1292,71 +1582,45 @@ def mul(x: torch.Tensor, y: torch.Tensor): result_exp_PT = op_pt(x, rnd_scan_dim) self.assertEqual(result, result_exp_PT) - @unittest.skipIf(not SM70OrLater, "triton") - @unittest.skipIf(not torch.cuda.is_available(), "Test requires CUDA.") + @requires_cuda @parametrize("reverse", [False, True]) - @parametrize("combine_mode", ["pointwise", "generic"]) - @parametrize("compile_mode", ["compile", "compile_dynamic_shape"]) @parametrize("device", [torch.device("cpu"), torch.device("cuda")]) - # Skipping the combination of combine_mode=pointwise and device=cpu - # as the current implementation of pointwise does only support CUDA device - @decorateIf( - unittest.skip, - lambda params: ( - params["combine_mode"] == "pointwise" - and params["device"] == torch.device("cpu") - ), - ) - def test_pointwise_associative_scan_compile( - self, reverse, combine_mode, compile_mode, device - ): - def add(x: torch.Tensor, y: torch.Tensor): - return x + y - - def mul(x: torch.Tensor, y: torch.Tensor): - return x * y - - x = torch.randn(3, 10, 2, device=device) - torch.compiler.reset() - if compile_mode == "compile": - associative_scan_fct = torch.compile( - associative_scan, fullgraph=True, dynamic=False - ) - else: - associative_scan_fct = torch.compile( - associative_scan, fullgraph=True, dynamic=True - ) - - for op, op_pt in [(add, torch.cumsum), (mul, torch.cumprod)]: - result = associative_scan_fct( - op, x, 0, reverse=reverse, combine_mode=combine_mode - ) - result_exp = _fake_associative_scan(op, x, 0, reverse=reverse) - self.assertEqual(result, result_exp) - if not reverse: - result_exp_PT = op_pt(x, 0) - self.assertEqual(result, result_exp_PT) + def test_scan_dim(self, reverse, device): + import random - # Jax Examples - x = torch.arange(0, 4, device=device) - cumsum1 = associative_scan( - add, x, 0, reverse=reverse, combine_mode=combine_mode - ) - cumsum_exp = _fake_associative_scan(add, x, 0, reverse=reverse) - if not reverse: - self.assertEqual( - cumsum1, torch.tensor([0.0, 1.0, 3.0, 6.0], dtype=torch.int64) - ) - else: - self.assertEqual( - cumsum1, torch.tensor([6.0, 6.0, 5.0, 3.0], dtype=torch.int64) - ) - self.assertEqual(cumsum1, cumsum_exp) + num_dims = [random.randint(2, 5) for _ in range(10)] + for num_dim in num_dims: + shapes = [random.randint(1, 10) for _ in range(num_dim)] + rnd_scan_dim = random.randint(0, num_dim - 1) + x = torch.randn(*shapes, device=device) + init_shapes = shapes + init_shapes[rnd_scan_dim] = 1 + + for op, op_pt, init in [ + ( + get_scan_combine_fn("add", False), + torch.cumsum, + torch.zeros(*init_shapes, device=device), + ), + ( + get_scan_combine_fn("mul", False), + torch.cumprod, + torch.ones(*init_shapes, device=device), + ), + ]: + result = scan(op, init, x, dim=rnd_scan_dim, reverse=reverse) + result_exp = _fake_scan( + op, init=init, xs=x, dim=rnd_scan_dim, reverse=reverse + ) + self.assertEqual(result, result_exp) + if not reverse: + result_exp_PT = op_pt(x, rnd_scan_dim) + self.assertEqual(result[1], result_exp_PT) @unittest.skipIf(not SM70OrLater, "triton") - @unittest.skipIf(not torch.cuda.is_available(), "Test requires CUDA.") - @parametrize("reverse", [False, True]) + @requires_cuda @parametrize("combine_mode", ["pointwise", "generic"]) + @parametrize("reverse", [False, True]) @parametrize("device", [torch.device("cpu"), torch.device("cuda")]) # Skipping the combination of combine_mode=pointwise and device=cpu # as the current implementation of pointwise does only support CUDA device @@ -1367,18 +1631,7 @@ def mul(x: torch.Tensor, y: torch.Tensor): and params["device"] == torch.device("cpu") ), ) - def test_pointwise_associative_scan_binary_operator( - self, reverse, combine_mode, device - ): - def fct(x, y): - A_i, Bu_i = x - A_j, Bu_j = y - return A_j * A_i, A_j * Bu_i + Bu_j - - torch.compiler.reset() - associative_scan1 = torch.compile(associative_scan, fullgraph=True) - associative_scan2 = associative_scan - + def test_associative_scan_binary_operator(self, combine_mode, reverse, device): state_dim = 20 timesteps = 10 projected_inputs = torch.randn( @@ -1387,28 +1640,62 @@ def fct(x, y): A = torch.randn(state_dim, requires_grad=True, device=device) elements = (A.repeat((timesteps, 1)), projected_inputs) - result1 = associative_scan1( - fct, elements, 0, combine_mode=combine_mode, reverse=reverse + result1 = associative_scan( + get_scan_combine_fn("s5_operator", True), + elements, + 0, + combine_mode=combine_mode, + reverse=reverse, ) - result2 = associative_scan2( - fct, elements, 0, combine_mode=combine_mode, reverse=reverse + expected_result = _fake_associative_scan( + get_scan_combine_fn("s5_operator", True), elements, 0, reverse=reverse ) - expected_result = _fake_associative_scan(fct, elements, 0, reverse=reverse) self.assertEqual( result1, expected_result, ) self.assertEqual([r.device.type for r in result1], [device.type] * len(result1)) - self.assertEqual( - result2, - expected_result, + + @requires_cuda + @parametrize("reverse", [False, True]) + @parametrize("device", [torch.device("cpu"), torch.device("cuda")]) + def test_scan_binary_operator(self, reverse, device): + state_dim = 20 + timesteps = 10 + projected_inputs = torch.randn( + timesteps, state_dim, requires_grad=True, device=device + ) + A = torch.randn(state_dim, requires_grad=True, device=device) + elements = (A.repeat((timesteps, 1)), projected_inputs) + init = tuple( + [torch.ones_like(torch._ops.ops.aten.slice(elements[0], 0, 0, 1, 1))] + + [ + torch.zeros_like( + torch._ops.ops.aten.slice(projected_inputs, 0, 0, 1, 1) + ) + ] ) - self.assertEqual([r.device.type for r in result2], [device.type] * len(result2)) + + result = scan( + get_scan_combine_fn("s5_operator", False), + init, + elements, + dim=0, + reverse=reverse, + ) + expected_result = _fake_scan( + get_scan_combine_fn("s5_operator", False), + init=init, + xs=elements, + dim=0, + reverse=reverse, + ) + self.assertEqual(result, expected_result) @unittest.skipIf(not SM70OrLater, "triton") - @unittest.skipIf(not torch.cuda.is_available(), "Test requires CUDA.") - @parametrize("reverse", [False, True]) + @requires_cuda @parametrize("combine_mode", ["pointwise", "generic"]) + @parametrize("reverse", [False, True]) @parametrize("device", [torch.device("cpu"), torch.device("cuda")]) # Skipping the combination of combine_mode=pointwise and device=cpu # as the current implementation of pointwise does only support CUDA device @@ -1419,24 +1706,92 @@ def fct(x, y): and params["device"] == torch.device("cpu") ), ) - def test_pointwise_associative_scan_tuple(self, reverse, combine_mode, device): - def fct(x, y): - return (x[0] + y[0], x[1] * y[1]) - - x = torch.randn(3, 2, 2, device=device, requires_grad=True) - y = torch.randn(3, 2, 2, device=device, requires_grad=True) + def test_associative_scan_tuple(self, combine_mode, reverse, device): + x = torch.randn(3, 2, 2, device=device) + y = torch.randn(3, 2, 2, device=device) inp = (x, y) result1 = associative_scan( - fct, inp, 0, reverse=reverse, combine_mode=combine_mode + get_scan_combine_fn("tuple_fct", True), + inp, + 0, + reverse=reverse, + combine_mode=combine_mode, + ) + expected_result = _fake_associative_scan( + get_scan_combine_fn("tuple_fct", True), inp, 0, reverse=reverse ) - expected_result = _fake_associative_scan(fct, inp, 0, reverse=reverse) self.assertEqual(result1, expected_result) - @unittest.skipIf(not SM70OrLater, "triton") - @unittest.skipIf(not torch.cuda.is_available(), "Test requires CUDA.") + @requires_cuda @parametrize("reverse", [False, True]) + @parametrize("device", [torch.device("cpu"), torch.device("cuda")]) + def test_scan_tuple(self, reverse, device): + x = torch.randn(3, 2, 2, device=device) + y = torch.randn(3, 2, 2, device=device) + inp = (x, y) + init = tuple(torch._ops.ops.aten.slice(e, 0, 0, 1, 1) for e in inp) + + result_same = scan( + get_scan_combine_fn("tuple_fct", False), + init, + inp, + dim=0, + reverse=reverse, + ) + expected_result = _fake_scan( + get_scan_combine_fn("tuple_fct", False), + init=init, + xs=inp, + dim=0, + reverse=reverse, + ) + self.assertEqual(result_same, expected_result) + + def fct_different_output_tuple(x, y): + return ((x[0] + y[0], x[1] * y[1]), (x[1] * y[1])) + + inp = (x, y) + init = tuple(torch._ops.ops.aten.slice(e, 0, 0, 1, 1) for e in inp) + + result_diff = scan( + fct_different_output_tuple, init, inp, dim=0, reverse=reverse + ) + expected_result = _fake_scan( + fct_different_output_tuple, init=init, xs=inp, dim=0, reverse=reverse + ) + self.assertEqual(result_diff, expected_result) + self.assertEqual(result_diff[1], result_same[1][1]) + + @unittest.skipIf(not SM70OrLater, "triton") + @requires_cuda + @parametrize("device", [torch.device("cpu"), torch.device("cuda")]) + def test_associative_scan_wrong_pytree(self, device): + def fct_wrong_pytree(x, y): + return { + "i": x["i"] * y["j"][0][0], + "k": 0.0, + "j": ([x["j"][1][0]["o"]], [{"o": torch.sin(x["i"])}]), + } + + x = torch.randn(3, 2, 2, device=device) + y = torch.randn(3, 2, 2, device=device) + z = torch.randn(3, 2, 2, device=device) + inp = {"i": x, "j": ([y], [{"o": z}])} + + with self.assertRaisesRegex( + # Should be: RuntimeError, + # r"The number of leaves of the pytree of the output of the operator + # needs to match the lenght of the pytree of the input", + torch._dynamo.exc.Unsupported, + "Observed exception.*", + ): + result = associative_scan(fct_wrong_pytree, inp, 0, combine_mode="generic") + + @unittest.skipIf(not SM70OrLater, "triton") + @requires_cuda @parametrize("combine_mode", ["pointwise", "generic"]) + @parametrize("reverse", [False, True]) @parametrize("device", [torch.device("cpu"), torch.device("cuda")]) # Skipping the combination of combine_mode=pointwise and device=cpu # as the current implementation of pointwise does only support CUDA device @@ -1447,16 +1802,7 @@ def fct(x, y): and params["device"] == torch.device("cpu") ), ) - def test_pointwise_associative_scan_complex_pytree( - self, reverse, combine_mode, device - ): - def fct_wrong_pytree(x, y): - return { - "i": x["i"] * y["j"][0][0], - "k": 0.0, - "j": ([x["j"][1][0]["o"]], [{"o": torch.sin(x["i"])}]), - } - + def test_associative_scan_complex_pytree(self, combine_mode, reverse, device): def fct_pointwise(x, y): return { "i": x["i"] * y["i"], @@ -1466,49 +1812,1130 @@ def fct_pointwise(x, y): ), } - x = torch.randn(3, 2, 2, device=device, requires_grad=True) - y = torch.randn(3, 2, 2, device=device, requires_grad=True) - z = torch.randn(3, 2, 2, device=device, requires_grad=True) + x = torch.randn(3, 2, 2, device=device) + y = torch.randn(3, 2, 2, device=device) + z = torch.randn(3, 2, 2, device=device) inp = {"i": x, "j": ([y], [{"o": z}])} - with self.assertRaisesRegex(Exception, r"."): - result = associative_scan(fct_wrong_pytree, inp, 0, combine_mode="generic") + result = associative_scan( + get_scan_combine_fn("complex_pointwise", True), + inp, + 0, + combine_mode=combine_mode, + reverse=reverse, + ) + expected_result = _fake_associative_scan( + get_scan_combine_fn("complex_pointwise", True), inp, 0, reverse=reverse + ) + self.assertEqual(result, expected_result) + + @requires_cuda + @parametrize("device", [torch.device("cpu"), torch.device("cuda")]) + def test_scan_wrong_pytree(self, device): + # Init and input have same pytree + def fct_wrong_pytree(x, y): + return ( + { + "i": x["i"] * y["j"][0][0], + "k": 0.0, + "j": ([x["j"][1][0]["o"]], [{"o": torch.sin(x["i"])}]), + }, + { + "i": x["i"] * y["j"][0][0], + "k": 0.0, + "j": ([x["j"][1][0]["o"]], [{"o": torch.sin(x["i"])}]), + }, + ) - torch.compiler.reset() - associative_scan1 = torch.compile(associative_scan, fullgraph=True) - associative_scan2 = associative_scan + x = torch.randn(3, 2, 2, device=device) + y = torch.randn(3, 2, 2, device=device) + z = torch.randn(3, 2, 2, device=device) + inp = {"i": x, "j": ([y], [{"o": z}])} + inp_flat, inp_spec = pytree.tree_flatten(inp) + init_flat = [torch._ops.ops.aten.slice(e, 0, 0, 1, 1) for e in inp_flat] + init = pytree.tree_unflatten(init_flat, inp_spec) - result1 = associative_scan1( - fct_pointwise, inp, 0, combine_mode=combine_mode, reverse=reverse + with self.assertRaisesRegex( + # Should be: RuntimeError, + # r"The number of leaves of the pytree of the new carry produced by + # the operator needs to match the length of the pytree of the init", + torch._dynamo.exc.Unsupported, + "Observed exception.*", + ): + result = scan(fct_wrong_pytree, init, inp, dim=0) + + @requires_cuda + @parametrize("reverse", [False, True]) + @parametrize("device", [torch.device("cpu"), torch.device("cuda")]) + def test_scan_complex_pytree(self, reverse, device): + # Init and input have same pytree + + x = torch.randn(3, 2, 2, device=device) + y = torch.randn(3, 2, 2, device=device) + z = torch.randn(3, 2, 2, device=device) + inp = {"i": x, "j": ([y], [{"o": z}])} + inp_flat, inp_spec = pytree.tree_flatten(inp) + init_flat = [torch._ops.ops.aten.slice(e, 0, 0, 1, 1) for e in inp_flat] + init = pytree.tree_unflatten(init_flat, inp_spec) + + result = scan( + get_scan_combine_fn("complex_pointwise", False), + init, + inp, + dim=0, + reverse=reverse, ) - result2 = associative_scan2( - fct_pointwise, inp, 0, combine_mode=combine_mode, reverse=reverse + expected_result = _fake_scan( + get_scan_combine_fn("complex_pointwise", False), + init=init, + xs=inp, + dim=0, + reverse=reverse, ) - expected_result = _fake_associative_scan(fct_pointwise, inp, 0, reverse=reverse) + self.assertEqual(result, expected_result) + + # TODO: provide an implementation for all compile modes and re-enable all test + @unittest.skipIf(not SM70OrLater, "triton") + @requires_cuda + @parametrize("combine_mode", ["pointwise", "generic"]) + @parametrize("compile_mode", ["none", "compile", "compile_dynamic_shape"]) + @parametrize("reverse", [False, True]) + @parametrize("device", [torch.device("cpu"), torch.device("cuda")]) + # Skipping the combination of combine_mode=pointwise and device=cpu + # as the current implementation of pointwise does only support CUDA device + @decorateIf( + unittest.skip, + lambda params: ( + params["combine_mode"] == "pointwise" + and params["device"] == torch.device("cpu") + ), + ) + def test_associative_scan_downstream_scan_matmul( + self, combine_mode, compile_mode, reverse, device + ): + # Chain with matmul + def chain_fct(inp): + W = torch.ones(2, 5, device=device) + o = associative_scan( + get_scan_combine_fn("add", True), + inp, + 1, + reverse=reverse, + combine_mode=combine_mode, + ) + return o @ W + + fct_cmp = compile_mode_helper(chain_fct, compile_mode) + + inp = torch.randn(3, 10, 2, device=device) + expected_result = _fake_associative_scan( + get_scan_combine_fn("add", True), inp, 1, reverse=reverse + ) @ torch.ones(2, 5, device=device) + result1 = fct_cmp(inp) self.assertEqual(result1, expected_result) - self.assertEqual(result2, expected_result) + # TODO: provide an implementation for all compile modes and re-enable all test @unittest.skipIf(not SM70OrLater, "triton") - @unittest.skipIf(not torch.cuda.is_available(), "Test requires CUDA.") + @requires_cuda + @parametrize("combine_mode", ["pointwise", "generic"]) + @parametrize("compile_mode", ["none", "compile", "compile_dynamic_shape"]) @parametrize("reverse", [False, True]) @parametrize("device", [torch.device("cpu"), torch.device("cuda")]) - def test_generic_associative_scan_generic_simple(self, reverse, device): - def non_pointwise(x: torch.Tensor, y: torch.Tensor): - W = torch.diag(torch.ones(2, device=device)) - return x @ W + y @ W + # Skipping the combination of combine_mode=pointwise and device=cpu + # as the current implementation of pointwise does only support CUDA device + @decorateIf( + unittest.skip, + lambda params: ( + params["combine_mode"] == "pointwise" + and params["device"] == torch.device("cpu") + ), + ) + def test_associative_scan_downstream_scan_scan( + self, combine_mode, compile_mode, reverse, device + ): + # Chain with scan + def chain_fct_same_dim(inp): + o1 = associative_scan( + get_scan_combine_fn("add", True), + inp, + 1, + combine_mode=combine_mode, + reverse=reverse, + ) + o2 = associative_scan( + get_scan_combine_fn("add", True), + o1, + 1, + combine_mode=combine_mode, + reverse=reverse, + ) + return o2 + + fct_cmp = compile_mode_helper(chain_fct_same_dim, compile_mode) + + inp = torch.randn(3, 10, 2, device=device) + + expected_result = _fake_associative_scan( + get_scan_combine_fn("add", True), + _fake_associative_scan( + get_scan_combine_fn("add", True), inp, 1, reverse=reverse + ), + 1, + reverse=reverse, + ) + result1 = fct_cmp(inp) + self.assertEqual(result1, expected_result) + + # TODO: provide an implementation for all compile modes and re-enable all test + @unittest.skipIf(not SM70OrLater, "triton") + @requires_cuda + @parametrize("combine_mode", ["pointwise", "generic"]) + @parametrize("compile_mode", ["none", "compile", "compile_dynamic_shape"]) + @parametrize("reverse", [False, True]) + @parametrize("device", [torch.device("cpu"), torch.device("cuda")]) + # Skipping the combination of combine_mode=pointwise and device=cpu + # as the current implementation of pointwise does only support CUDA device + @decorateIf( + unittest.skip, + lambda params: ( + params["combine_mode"] == "pointwise" + and params["device"] == torch.device("cpu") + ), + ) + def test_associative_scan_downstream_scan_scan_different_dim( + self, combine_mode, compile_mode, reverse, device + ): + # Chain with scan on different dim + def chain_fct_different_dim(inp): + o1 = associative_scan( + get_scan_combine_fn("add", True), + inp, + 1, + combine_mode=combine_mode, + reverse=reverse, + ) + o2 = associative_scan( + get_scan_combine_fn("add", True), + o1, + 0, + combine_mode=combine_mode, + reverse=reverse, + ) + return o2 + + fct_cmp = compile_mode_helper(chain_fct_different_dim, compile_mode) + + inp = torch.randn(3, 10, 2, device=device) + expected_result = _fake_associative_scan( + get_scan_combine_fn("add", True), + _fake_associative_scan( + get_scan_combine_fn("add", True), inp, 1, reverse=reverse + ), + 0, + reverse=reverse, + ) + result1 = fct_cmp(inp) + self.assertEqual(result1, expected_result) + + # TODO: provide an implementation for all compile modes and re-enable all test + @requires_cuda + @parametrize("compile_mode", ["none", "eager"]) + @parametrize("reverse", [False, True]) + @parametrize("device", [torch.device("cpu"), torch.device("cuda")]) + def test_scan_downstream_scan_matmul(self, compile_mode, reverse, device): + inp = torch.randn(3, 10, 2, device=device) + init = torch.randn(3, 1, 2, device=device) + + for ind in range(2): + # Chain with matmul + def chain_fct(inp): + W = torch.ones(2, 5, device=device) + o = scan( + get_scan_combine_fn("add", False), + init, + inp, + dim=1, + reverse=reverse, + ) + return o[ind] @ W + + fct_cmp = compile_mode_helper(chain_fct, compile_mode) + + expected_result = _fake_scan( + get_scan_combine_fn("add", False), + init=init, + xs=inp, + dim=1, + reverse=reverse, + )[ind] @ torch.ones(2, 5, device=device) + result1 = fct_cmp(inp) + self.assertEqual(result1, expected_result) + + # TODO: provide an implementation for all compile modes and re-enable all test + @requires_cuda + @parametrize("compile_mode", ["none", "eager"]) + @parametrize("reverse", [False, True]) + @parametrize("device", [torch.device("cpu"), torch.device("cuda")]) + def test_scan_downstream_scan_scan(self, compile_mode, reverse, device): + inp = torch.randn(3, 10, 2, device=device) + init = torch.randn(3, 1, 2, device=device) + + # Chain with scan + def chain_fct_same_dim(inp): + o1 = scan( + get_scan_combine_fn("add", False), + init, + inp, + dim=1, + reverse=reverse, + ) + o2 = scan( + get_scan_combine_fn("add", False), + init, + o1[1], + dim=1, + reverse=reverse, + ) + return o2 + + fct_cmp = compile_mode_helper(chain_fct_same_dim, compile_mode) + + expected_result = _fake_scan( + get_scan_combine_fn("add", False), + init=init, + xs=_fake_scan( + get_scan_combine_fn("add", False), + init=init, + xs=inp, + dim=1, + reverse=reverse, + )[1], + dim=1, + reverse=reverse, + ) + result1 = fct_cmp(inp) + self.assertEqual(result1, expected_result) + # TODO: provide an implementation for all compile modes and re-enable all test + @requires_cuda + @parametrize("compile_mode", ["none", "eager"]) + @parametrize("reverse", [False, True]) + @parametrize("device", [torch.device("cpu"), torch.device("cuda")]) + def test_scan_downstream_scan_scan_dim(self, compile_mode, reverse, device): + inp = torch.randn(3, 10, 2, device=device) + init = torch.randn(3, 1, 2, device=device) + + # Chain with scan on different dim + init2 = torch.randn(1, 10, 2, device=device) + + def chain_fct_different_dim(inp): + o1 = scan( + get_scan_combine_fn("add", False), + init, + inp, + dim=1, + reverse=reverse, + ) + o2 = scan( + get_scan_combine_fn("add", False), + init2, + o1[1], + dim=0, + reverse=reverse, + ) + return o2 + + fct_cmp = compile_mode_helper(chain_fct_different_dim, compile_mode) + + expected_result = _fake_scan( + get_scan_combine_fn("add", False), + init=init2, + xs=_fake_scan( + get_scan_combine_fn("add", False), + init=init, + xs=inp, + dim=1, + reverse=reverse, + )[1], + dim=0, + reverse=reverse, + ) + result1 = fct_cmp(inp) + self.assertEqual(result1, expected_result) + + @unittest.skipIf(not SM70OrLater, "triton") + @requires_cuda + @parametrize("reverse", [False, True]) + @parametrize("device", [torch.device("cpu"), torch.device("cuda")]) + # Skipping the combination of associative_scan and device=cpu + # as the current implementation of pointwise does only support CUDA device + @decorateIf( + unittest.skip, + lambda params: (params["device"] == torch.device("cpu")), + ) + def test_associative_scan_non_pointwise(self, reverse, device): x = torch.randn(3, 10, 2, device=device) - with self.assertRaisesRegex(Exception, ".*"): + # Expected to fail, as the pointwise combine_mode does not allow non-pointwise operations + with self.assertRaisesRegex( + Exception, + "For combine_mode='pointwise', the combine_fn needs to be pointwise", + ): out = associative_scan( - non_pointwise, x, 0, reverse=reverse, combine_mode="pointwise" + get_scan_combine_fn("non_pointwise", True), + x, + 0, + reverse=reverse, + combine_mode="pointwise", ) + @unittest.skipIf(not SM70OrLater, "triton") + @requires_cuda + @parametrize("reverse", [False, True]) + @parametrize("device", [torch.device("cpu"), torch.device("cuda")]) + # Skipping the combination of associative_scan and device=cpu + # as the current implementation of pointwise does only support CUDA device + @decorateIf( + unittest.skip, + lambda params: (params["device"] == torch.device("cpu")), + ) + def test_associative_scan_non_pointwise_generic(self, reverse, device): + x = torch.randn(3, 10, 2, device=device) + result_expected = _fake_associative_scan( + get_scan_combine_fn("non_pointwise", True), x, 0, reverse=reverse + ) result1 = associative_scan( - non_pointwise, x, 0, reverse=reverse, combine_mode="generic" + get_scan_combine_fn("non_pointwise", True), + x, + 0, + reverse=reverse, + combine_mode="generic", ) - result_expected = _fake_associative_scan(non_pointwise, x, 0, reverse=reverse) self.assertEqual(result1, result_expected) + @requires_cuda + @parametrize("reverse", [False, True]) + @parametrize("device", [torch.device("cpu"), torch.device("cuda")]) + def test_scan_non_pointwise(self, reverse, device): + x = torch.randn(3, 10, 2, device=device) + init = torch.randn(1, 10, 2, device=device) + result_expected = _fake_scan( + get_scan_combine_fn("non_pointwise", False), + init=init, + xs=x, + dim=0, + reverse=reverse, + ) + + out = scan( + get_scan_combine_fn("non_pointwise", False), + init, + x, + dim=0, + reverse=reverse, + ) + self.assertEqual(out, result_expected) + + @requires_cuda + @parametrize("reverse", [False, True]) + @parametrize("device", [torch.device("cpu"), torch.device("cuda")]) + def test_scan_compile_cnt(self, reverse, device): + dim = 1 + + from torch._dynamo.testing import CompileCounter + + # Tests rely on automatic_dynamic = True + with torch._dynamo.config.patch(automatic_dynamic_shapes=True): + cnt = CompileCounter() + x = torch.randn(3, 2, 5, device=device) + init = torch.randn(3, 1, 5, device=device) + # First compilation step + torch.compile(scan, backend=cnt)( + get_scan_combine_fn("add", False), + init, + x, + dim=dim, + reverse=reverse, + ) + self.assertEqual(cnt.frame_count, 1) + + x = torch.randn(3, 20, 5, device=device) + init = torch.randn(3, 1, 5, device=device) + # Recompilation due to first different size + torch.compile(scan, backend=cnt)( + get_scan_combine_fn("add", False), + init, + x, + dim=dim, + reverse=reverse, + ) + self.assertEqual(cnt.frame_count, 2) + + x = torch.randn(3, 40, 5, device=device) + init = torch.randn(3, 1, 5, device=device) + # No recompilation, because of dynamic shape + torch.compile(scan, backend=cnt)( + get_scan_combine_fn("add", False), + init, + x, + dim=dim, + reverse=reverse, + ) + self.assertEqual(cnt.frame_count, 2) + + x = torch.randn(3, 40, 5, device=device) + init = torch.randn(3, 40, 1, device=device) + # Recompilation because of dim change + torch.compile(scan, backend=cnt)( + get_scan_combine_fn("add", False), + init, + x, + dim=2, + reverse=reverse, + ) + self.assertEqual(cnt.frame_count, 3) + + x = torch.randn(3, 40, 20, device=device) + init = torch.randn(3, 40, 1, device=device) + # Recompilation due to first different size on new dim + torch.compile(scan, backend=cnt)( + get_scan_combine_fn("add", False), + init, + x, + dim=2, + reverse=reverse, + ) + self.assertEqual(cnt.frame_count, 4) + + x = torch.randn(3, 40, 40, device=device) + init = torch.randn(3, 40, 1, device=device) + # No recompilation, because of dynamic shape on new dim + torch.compile(scan, backend=cnt)( + get_scan_combine_fn("add", False), + init, + x, + dim=2, + reverse=reverse, + ) + self.assertEqual(cnt.frame_count, 4) + + x = torch.randn(3, 60, 40, device=device) + init = torch.randn(3, 1, 40, device=device) + # Recompilation because of dim change + torch.compile(scan, backend=cnt)( + get_scan_combine_fn("add", False), + init, + x, + dim=1, + reverse=reverse, + ) + self.assertEqual(cnt.frame_count, 5) + + x = torch.randn(3, 60, 40, device=device) + init = torch.randn(3, 1, 40, device=device) + # Recompilation because of reverse change + torch.compile(scan, backend=cnt)( + get_scan_combine_fn("add", False), + init, + x, + dim=1, + reverse=not reverse, + ) + self.assertEqual(cnt.frame_count, 6) + + x = torch.randn(3, 60, 40, device=device) + init = torch.randn(3, 1, 40, device=device) + # No recompilation, as nothing changed + torch.compile(scan, backend=cnt)( + get_scan_combine_fn("add", False), + init, + x, + dim=1, + reverse=not reverse, + ) + self.assertEqual(cnt.frame_count, 6) + + x = torch.randn(3, 120, 80, device=device) + init = torch.randn(3, 1, 80, device=device) + # No recompilation, final test + torch.compile(scan, backend=cnt)( + get_scan_combine_fn("add", False), + init, + x, + dim=1, + reverse=reverse, + ) + self.assertEqual(cnt.frame_count, 6) + + @requires_cuda + @parametrize("reverse", [False, True]) + @parametrize("compile_mode", ["none", "eager"]) + @parametrize("device", [torch.device("cpu"), torch.device("cuda")]) + def test_scan_init_scanned_0(self, reverse, compile_mode, device): + scan_fct = compile_mode_helper(scan, compile_mode) + + # Only init and no input + x = torch.randn(3, 1, 2, device=device) + init = torch.randn(3, 1, 2, device=device) + dim = 1 + + # Scan dimension is 0 + init = torch._ops.ops.aten.slice(x, dim, 0, 1, 1) + inp = torch._ops.ops.aten.slice(x, dim, 1, None, 1) + with self.assertRaisesRegex( + # Should be: RuntimeError, "Input leaves must have a scan dimension > 0" + torch._dynamo.exc.Unsupported, + "Observed exception.*", + ): + result_init = scan_fct( + get_scan_combine_fn("add", False), + init, + inp, + dim=dim, + reverse=reverse, + ) + + @requires_cuda + @parametrize("reverse", [False, True]) + @parametrize("compile_mode", ["none", "eager"]) + @parametrize("device", [torch.device("cpu"), torch.device("cuda")]) + def test_scan_init_non_tensor(self, reverse, compile_mode, device): + scan_fct = compile_mode_helper(scan, compile_mode) + + # Only init and no input + x = torch.randn(3, 1, 2, device=device) + init = torch.randn(3, 1, 2, device=device) + dim = 1 + + # Init is a float and not a tensor + inp = torch._ops.ops.aten.slice(x, dim, 1, None, 1) + init = 1.0 + with self.assertRaisesRegex( + # Should be: RuntimeError, "Init leaves must be a Tensor" + torch._dynamo.exc.Unsupported, + "Observed exception.*", + ): + result_init = scan_fct( + get_scan_combine_fn("add", False), init, inp, dim=dim, reverse=reverse + ) + + @requires_cuda + @parametrize("reverse", [False, True]) + @parametrize("compile_mode", ["none", "eager"]) + @parametrize("device", [torch.device("cpu"), torch.device("cuda")]) + def test_scan_init_wrong_shape(self, reverse, compile_mode, device): + scan_fct = compile_mode_helper(scan, compile_mode) + + # Only init and no input + x = torch.randn(3, 1, 2, device=device) + init = torch.randn(3, 1, 2, device=device) + dim = 1 + + # Init wrong shape (Other dim different) + inp = torch._ops.ops.aten.slice(x, dim, 1, None, 1) + init = torch._ops.ops.aten.slice(x, dim, 0, 1, 1) + init = torch.tile(init, (1, 2, 1)) + with self.assertRaisesRegex( + # Should be: RuntimeError, "The size of tensor a.*" + torch._dynamo.exc.Unsupported, + "Observed exception.*", + ): + result_init = scan_fct( + get_scan_combine_fn("add", False), + init, + inp, + dim=dim, + reverse=reverse, + ) + + @requires_cuda + @parametrize("reverse", [False, True]) + @parametrize("compile_mode", ["none", "eager"]) + @parametrize("device", [torch.device("cpu"), torch.device("cuda")]) + def test_scan_init_wrong_pytree(self, reverse, compile_mode, device): + def add_one_carry(x: torch.Tensor, y: torch.Tensor): + return x[0], x + + scan_fct = compile_mode_helper(scan, compile_mode) + + # Only init and no input + x = torch.randn(3, 1, 2, device=device) + init = torch.randn(3, 1, 2, device=device) + dim = 1 + + # Init wrong pytree + inp = torch._ops.ops.aten.slice(x, dim, 1, None, 1) + init = ( + torch._ops.ops.aten.slice(x, dim, 0, 1, 1), + torch._ops.ops.aten.slice(x, dim, 0, 1, 1), + ) + + with self.assertRaisesRegex( + # Should be: RuntimeError: The number of leaves of the pytree of the new carry produced + # by the operator needs to match the length of the pytree of the init + torch._dynamo.exc.Unsupported, + "Observed exception.*", + ): + result_init = scan_fct(add_one_carry, init, inp, dim=dim, reverse=reverse) + + @requires_cuda + @parametrize("reverse", [False, True]) + @parametrize("compile_mode", ["none", "eager"]) + @parametrize("device", [torch.device("cpu"), torch.device("cuda")]) + def test_scan_init(self, reverse, compile_mode, device): + scan_fct = compile_mode_helper(scan, compile_mode) + + # Only init and no input + x = torch.randn(3, 1, 2, device=device) + init = torch.randn(3, 1, 2, device=device) + dim = 1 + op, op_pt = (get_scan_combine_fn("add", False), torch.cumsum) + + # Only init given + init = torch._ops.ops.aten.slice(x, dim, 0, 1, 1) + result = scan_fct(op, init, [], dim=dim, reverse=reverse) + result_exp = _fake_scan(op, init=init, xs=[], dim=dim, reverse=reverse) + result_init = scan_fct(op, init, [], dim=dim, reverse=reverse) + self.assertEqual(result, result_exp) + self.assertEqual(result_init, result_exp) + self.assertEqual(result_init[0], init) + + x = torch.randn(3, 5, 2, device=device) + init = torch.randn(3, 5, 2, device=device) + dim = 0 + + op, op_pt = (get_scan_combine_fn("add", False), torch.cumsum) + inp = torch._ops.ops.aten.slice(x, dim, 1, None, 1) + + # Init tensor scalar + init = torch.ones(1, device=device) + + def add_scalar_carry(x: torch.Tensor, y: torch.Tensor): + return x + 1.0, x + y + + result_init = scan_fct(add_scalar_carry, init, inp, dim=dim, reverse=reverse) + result_exp = _fake_scan( + add_scalar_carry, init=init, xs=inp, dim=dim, reverse=reverse + ) + self.assertEqual(result_init, result_exp) + self.assertEqual(result_init[0], torch.tensor([3.0], device=device)) + + # Init tensor entirely different shape than inp + init = torch.randn(7, 8, device=device) + + def add_scalar_carry2(x: torch.Tensor, y: torch.Tensor): + return x + 1.0, x[: y.shape[1], : y.shape[2]] + y + + result_init = scan_fct(add_scalar_carry2, init, inp, dim=dim, reverse=reverse) + result_exp = _fake_scan( + add_scalar_carry2, init=init, xs=inp, dim=dim, reverse=reverse + ) + self.assertEqual(result_init, result_exp) + + # Init with two timestep on dim axis. Should work as y has always 1 on dim axis and + # hence automatic broadcasting should work + # I.e., the input shape is 2x5x2, but the carry at each iteration is 2x5x2, + # thus the output of each iteration is 2x5x2, which results in the total output + # to be 4x5x2 + init = torch._ops.ops.aten.slice(x, dim, 0, 2, 1) + result_init = scan_fct(op, init, inp, dim=dim, reverse=reverse) + result_exp = _fake_scan(op, init=init, xs=inp, dim=dim, reverse=reverse) + self.assertEqual(result_init, result_exp) + self.assertEqual(result_init[0].shape, torch.Size([2, 5, 2])) + + init = torch.tile(init, (1, 2, 1)) + + def add_scalar_carry_sliced_out(x: torch.Tensor, y: torch.Tensor): + return x + 1.0, x[:, :1, :] + y + + result_init = scan_fct( + add_scalar_carry_sliced_out, init, inp, dim=dim, reverse=reverse + ) + result_exp = _fake_scan( + add_scalar_carry_sliced_out, init=init, xs=inp, dim=dim, reverse=reverse + ) + self.assertEqual(result_init, result_exp) + self.assertEqual(result_init[0].shape, torch.Size([2, 10, 2])) + self.assertEqual(result_init[1].shape, torch.Size([4, 5, 2])) + + # Correct case + op, op_pt = (get_scan_combine_fn("add", False), torch.cumsum) + x = torch.randn(3, 2, 2, device=device) + dim = 1 + + if reverse: + init = torch.zeros_like(torch._ops.ops.aten.slice(x, dim, -1, None, 1)) + inp = torch._ops.ops.aten.slice(x, dim, 0, -1, 1) + else: + init = torch.zeros_like(torch._ops.ops.aten.slice(x, dim, 0, 1, 1)) + inp = torch._ops.ops.aten.slice(x, dim, 1, None, 1) + + result = scan_fct(op, init, x, dim=dim, reverse=reverse) + result_exp = _fake_scan(op, init=init, xs=x, dim=dim, reverse=reverse) + + self.assertEqual(result, result_exp) + if not reverse: + result_exp_PT = op_pt(x, dim) + self.assertEqual(result[1], result_exp_PT) + + @requires_cuda + @parametrize("reverse", [False, True]) + @parametrize("device", [torch.device("cpu"), torch.device("cuda")]) + def test_scan_carry_wrong_pytree(self, reverse, device): + def fct_pointwise_carry_wrong_pytree(x, y): + return ( + ( + x["i"], + { + "i": x["i"] * y["i"], + "j": ( + [x["j"][0][0] * y["j"][0][0]], + [{"o": x["j"][1][0]["o"] + y["j"][1][0]["o"]}], + ), + }, + ), + { + "i": x["i"] * y["i"], + "j": ( + [x["j"][0][0] * y["j"][0][0]], + [{"o": x["j"][1][0]["o"] + y["j"][1][0]["o"]}], + ), + }, + ) + + x = torch.randn(3, 2, 2, device=device) + y = torch.randn(3, 2, 2, device=device) + z = torch.randn(3, 2, 2, device=device) + inp = {"i": x, "j": ([y], [{"o": z}])} + inp_flat, inp_spec = pytree.tree_flatten(inp) + init_flat = [torch._ops.ops.aten.slice(e, 0, 0, 1, 1) for e in inp_flat] + init = pytree.tree_unflatten(init_flat, inp_spec) + + # Wrong pytree of the carry produced by the operation + with self.assertRaisesRegex( + # Should be: RuntimeError: The number of leaves of the pytree of the new carry + # produced by the operator needs to match the length of the pytree of the init + torch._dynamo.exc.Unsupported, + "Observed exception.*", + ): + result = scan( + fct_pointwise_carry_wrong_pytree, + init, + inp, + dim=0, + reverse=reverse, + ) + + @requires_cuda + @parametrize("reverse", [False, True]) + @parametrize("device", [torch.device("cpu"), torch.device("cuda")]) + def test_scan_init_wrong_pytree_complex(self, reverse, device): + x = torch.randn(3, 2, 2, device=device) + y = torch.randn(3, 2, 2, device=device) + z = torch.randn(3, 2, 2, device=device) + + # Wrong pytree fed to the function + init = { + "i": torch._ops.ops.aten.slice(x, 0, 0, 1, 1), + "j": ( + {"a": torch._ops.ops.aten.slice(x, 0, 0, 1, 1)}, + [torch._ops.ops.aten.slice(y, 0, 0, 1, 1)], + [{"o": torch._ops.ops.aten.slice(z, 0, 0, 1, 1)}], + ), + } + inp = { + "i": torch._ops.ops.aten.slice(x, 0, 0, None, 1), + "j": ( + [torch._ops.ops.aten.slice(y, 0, 0, None, 1)], + [{"o": torch._ops.ops.aten.slice(z, 0, 0, None, 1)}], + ), + } + with self.assertRaisesRegex( + Exception, + ".*", + ): + result = scan( + get_scan_combine_fn("complex_pointwise", False), + init, + inp, + dim=0, + reverse=reverse, + ) + + @requires_cuda + @parametrize("reverse", [False, True]) + @parametrize("device", [torch.device("cpu"), torch.device("cuda")]) + def test_scan_init_pytree_complex(self, reverse, device): + def fct_pointwise_different_output(x, y): + return ( + { + "i": x["i"] * y["i"], + "j": ( + [x["j"][0][0] * y["j"][0][0]], + [{"o": x["j"][1][0]["o"] + y["j"][1][0]["o"]}], + ), + }, + ( + y["i"], + { + "o": x["i"] * y["i"], + "j": ( + [x["j"][0][0] * y["j"][0][0]], + [{"o": x["j"][1][0]["o"] + y["j"][1][0]["o"]}], + ), + }, + ), + ) + + def fct_pointwise_different_carry(x, y): + return ( + { + "i": x["i"] * y["i"], + "j": ( + x["i"], + [x["j"][1][0] * y["j"][0][0]], + [{"o": x["j"][2][0]["o"] + y["j"][1][0]["o"]}], + ), + }, + ( + y["i"], + { + "o": x["i"] * y["i"] + x["j"][0][0], + "j": ( + [x["j"][1][0] * y["j"][0][0]], + [{"o": x["j"][2][0]["o"] + y["j"][1][0]["o"]}], + ), + }, + ), + ) + + x = torch.randn(3, 2, 2, device=device) + y = torch.randn(3, 2, 2, device=device) + z = torch.randn(3, 2, 2, device=device) + + if reverse: + init_start, init_end = -1, None + inp_start, inp_end = 0, -1 + else: + init_start, init_end = 0, 1 + inp_start, inp_end = 1, None + + # Regular case + init = { + "i": torch._ops.ops.aten.slice(x, 0, init_start, init_end, 1), + "j": ( + [torch._ops.ops.aten.slice(y, 0, init_start, init_end, 1)], + [{"o": torch._ops.ops.aten.slice(z, 0, init_start, init_end, 1)}], + ), + } + inp = { + "i": torch._ops.ops.aten.slice(x, 0, inp_start, inp_end, 1), + "j": ( + [torch._ops.ops.aten.slice(y, 0, inp_start, inp_end, 1)], + [{"o": torch._ops.ops.aten.slice(z, 0, inp_start, inp_end, 1)}], + ), + } + result = scan( + get_scan_combine_fn("complex_pointwise", False), + init, + inp, + dim=0, + reverse=reverse, + ) + expected_result = _fake_scan( + get_scan_combine_fn("complex_pointwise", False), + init, + inp, + dim=0, + reverse=reverse, + ) + self.assertEqual(result, expected_result) + + # Pytree of output is different + result = scan(fct_pointwise_different_output, init, inp, dim=0, reverse=reverse) + expected_result = _fake_scan( + fct_pointwise_different_output, init=init, xs=inp, dim=0, reverse=reverse + ) + self.assertEqual(result, expected_result) + + # Pytree of carry is different + init = { + "i": torch._ops.ops.aten.slice(x, 0, init_start, init_end, 1), + "j": ( + torch._ops.ops.aten.slice(x, 0, init_start, init_end, 1), + [torch._ops.ops.aten.slice(y, 0, init_start, init_end, 1)], + [{"o": torch._ops.ops.aten.slice(z, 0, init_start, init_end, 1)}], + ), + } + inp = { + "i": torch._ops.ops.aten.slice(x, 0, inp_start, inp_end, 1), + "j": ( + [torch._ops.ops.aten.slice(y, 0, inp_start, inp_end, 1)], + [{"o": torch._ops.ops.aten.slice(z, 0, inp_start, inp_end, 1)}], + ), + } + result = scan(fct_pointwise_different_carry, init, inp, dim=0, reverse=reverse) + expected_result = _fake_scan( + fct_pointwise_different_carry, init=init, xs=inp, dim=0, reverse=reverse + ) + self.assertEqual(result, expected_result) + + def test_scan_RNN(self): + dim = 1 + device = torch.device("cpu") + + rnn = torch.nn.RNN( + input_size=5, + hidden_size=7, + ) + rnn = rnn.to(device=device) + x = torch.randn(1, 2, 5, device=device) + h = torch.randn(1, 2, 7, device=device) + + new_state_dict = { + "weight_ih_l0": torch.ones_like(rnn.weight_ih_l0), + "bias_ih_l0": torch.ones_like(rnn.bias_ih_l0), + "weight_hh_l0": torch.ones_like(rnn.weight_hh_l0), + "bias_hh_l0": torch.ones_like(rnn.bias_hh_l0), + } + rnn.load_state_dict(new_state_dict) + + def RNN(x: torch.Tensor, y: torch.Tensor): + W_ih = torch.ones((5, 7), device=device) + b_ih = torch.ones((7), device=device) + W_hh = torch.ones((7, 7), device=device) + b_hh = torch.ones((7), device=device) + c_new = y @ W_ih + b_ih + h_new = torch.tanh(c_new + x @ W_hh + b_hh) + return h_new, h_new + + expected_result = rnn( + torch.permute(x, (1, 0, 2)), torch.unsqueeze(h[:, 0, :], 0) + ) + expected_result_out = torch.permute(expected_result[0], (1, 0, 2)) + expected_result_state = torch.permute(expected_result[1], (1, 0, 2)) + result = scan(RNN, h[:, 0:1, :], x, dim=dim) + self.assertEqual(result[0], expected_result_state) + self.assertEqual(result[1], expected_result_out) + + @skipIfNoDynamoSupport + def test_scan_simple_graph_no_carry(self): + x = torch.randn(3, 10, 2, device=torch.device("cpu")) + init = torch.randn(1, 10, 2, device=torch.device("cpu")) + + def f(fct, init, xs): + return scan(fct, init, xs, dim=0, reverse=True) + + # Wrong number of returns from function + with self.assertRaisesRegex( + # Should be: RuntimeError: The pytree of the new carry produced + # by the operator needs to match the pytree of the init + torch._dynamo.exc.Unsupported, + "Observed exception.*", + ): + gm = make_fx(f, tracing_mode="symbolic")( + get_scan_combine_fn("add", True), init, x + ) + + @skipIfNoDynamoSupport + def test_scan_simple_graph_wrong_carry(self): + def add_wrong_carry(x: torch.Tensor, y: torch.Tensor): + return (x + y)[0, :], x + y + + x = torch.randn(3, 10, 2, device=torch.device("cpu")) + init = torch.randn(1, 10, 2, device=torch.device("cpu")) + + def f(fct, init, xs): + return scan(fct, init, xs, dim=0, reverse=True) + + # Wrong carry shape + with self.assertRaisesRegex( + # Should be: RuntimeError: The pytree of the new carry produced by + # the operator needs to match the pytree of the init + torch._dynamo.exc.Unsupported, + "Observed exception.*", + ): + gm = make_fx(f, tracing_mode="symbolic")(add_wrong_carry, init, x) + + @skipIfNoDynamoSupport + def test_scan_simple_graph_wrong_dtype(self): + def add_wrong_dtype(x: torch.Tensor, y: torch.Tensor): + return torch.ones_like(x + y, dtype=torch.int64), x + y + + x = torch.randn(3, 10, 2, device=torch.device("cpu")) + init = torch.randn(1, 10, 2, device=torch.device("cpu")) + + def f(fct, init, xs): + return scan(fct, init, xs, dim=0, reverse=True) + + # Wrong dtype + with self.assertRaisesRegex( + # Should be: RuntimeError: Expected the init and + # the new carry produced by the operator to be a tensor of + # torch.int64 but got torch.float32 and torch.int64 + torch._dynamo.exc.UncapturedHigherOrderOpError, + ".*", + ): + gm = make_fx(f, tracing_mode="symbolic")(add_wrong_dtype, init, x) + + @skipIfNoDynamoSupport + def test_scan_simple_graph(self): + from torch._dynamo.testing import EagerAndRecordGraphs + + x = torch.randn(3, 10, 2, device=torch.device("cpu")) + init = torch.randn(1, 10, 2, device=torch.device("cpu")) + + def f(fct, init, xs): + return scan(fct, init, xs, dim=0, reverse=True) + + # Correct case + gm = make_fx(f, tracing_mode="symbolic")( + get_scan_combine_fn("add", False), init, x + ) + self.assertExpectedInline( + gm.code.strip(), + """\ +def forward(self, fct_1, init_1, xs_1): + slice_1 = torch.ops.aten.slice.Tensor(xs_1, 0, 0, 1) + add = torch.ops.aten.add.Tensor(init_1, slice_1); add = None + add_1 = torch.ops.aten.add.Tensor(init_1, slice_1); slice_1 = add_1 = None + sym_size_int = torch.ops.aten.sym_size.int(init_1, 1) + sym_size_int_1 = torch.ops.aten.sym_size.int(init_1, 2) + new_empty = torch.ops.aten.new_empty.default(init_1, [1, sym_size_int, sym_size_int_1], dtype = torch.float32, device = device(type='cpu'), pin_memory = False); new_empty = None + new_empty_1 = torch.ops.aten.new_empty.default(xs_1, [1, sym_size_int, sym_size_int_1], dtype = torch.float32, device = device(type='cpu'), pin_memory = False); sym_size_int = sym_size_int_1 = new_empty_1 = None + scan_combine_graph_0 = self.scan_combine_graph_0 + scan = torch.ops.higher_order.scan(scan_combine_graph_0, [init_1], [xs_1], 0, True); scan_combine_graph_0 = init_1 = xs_1 = None + getitem = scan[0] + getitem_1 = getitem[0]; getitem = None + getitem_2 = scan[1]; scan = None + getitem_3 = getitem_2[0]; getitem_2 = None + return (getitem_1, getitem_3)""", # noqa: B950 + ) + + # Check graph + backend = EagerAndRecordGraphs() + torch.compile(f, backend=backend)(get_scan_combine_fn("add", False), init, x) + gm = backend.graphs[0] + + self.assertExpectedInline( + gm.code.strip(), + """\ +def forward(self, L_init_ : torch.Tensor, L_xs_ : torch.Tensor): + l_init_ = L_init_ + l_xs_ = L_xs_ + slice_1 = torch.ops.aten.slice(l_xs_, 0, 0, 1, 1) + out_l = l_init_ + slice_1; out_l = None + add_1 = l_init_ + slice_1; slice_1 = add_1 = None + child = l_init_.new_empty((1, 10, 2), dtype = torch.float32, device = device(type='cpu'), requires_grad = False); child = None + child_1 = l_xs_.new_empty((1, 10, 2), dtype = torch.float32, device = device(type='cpu'), requires_grad = False); child_1 = None + scan_combine_fn_0 = self.scan_combine_fn_0 + scan = torch.ops.higher_order.scan(scan_combine_fn_0, [l_init_], [l_xs_], 0, True); scan_combine_fn_0 = l_init_ = l_xs_ = None + getitem = scan[0] + getitem_1 = getitem[0]; getitem = None + getitem_2 = scan[1]; scan = None + getitem_3 = getitem_2[0]; getitem_2 = None + return (getitem_1, getitem_3)""", # noqa: B950 + ) + @unittest.skipIf(IS_WINDOWS, "Windows not supported for this test") @skipIfNoDynamoSupport @@ -3912,6 +5339,69 @@ class WrongHop(torch._ops.HigherOrderOperator): with self.assertRaisesRegex(TypeError, "WrongHop"): wrong_hop = WrongHop("wrong_hop") + def test_scan_functionalized(self): + def f(init, xs): + return scan(get_scan_combine_fn("add", False), init, xs, dim=1) + + example_inputs = torch.ones(5, 7, 4) + example_init = torch.ones(5, 1, 4) + functional_f = torch.func.functionalize(f) + self.assertEqual( + functional_f(example_init, example_inputs), f(example_init, example_inputs) + ) + + # https://github.com/pytorch/pytorch/issues/126988 + @xfailIfTorchDynamo + def test_scan_functionalized_elem_mutation(self): + def add1(x, y): + x.add_(4) + return x + y, x + y + + def f(init, xs): + return scan(add1, init, xs, dim=1) + + example_inputs = torch.ones(5, 7, 4) + example_init = torch.ones(5, 1, 4) + functional_f = torch.func.functionalize(f) + with self.assertRaisesRegex( + UnsupportedAliasMutationException, + "Combine_fn might be modifying the input!", + ): + functional_f(example_init, example_inputs) + + def add2(x, y): + y.add_(4) + return x + y, x + y + + def f(init, xs): + return scan(add2, init, xs, dim=1) + + example_inputs = torch.ones(5, 7, 4) + example_init = torch.ones(5, 1, 4) + functional_f = torch.func.functionalize(f) + with self.assertRaisesRegex( + UnsupportedAliasMutationException, + "Combine_fn might be modifying the input!", + ): + functional_f(example_init, example_inputs) + + # https://github.com/pytorch/pytorch/issues/126988 + @xfailIfTorchDynamo + def test_scan_functionalized_elem_alias(self): + def add(x, y): + return x, x + + def f(init, xs): + return scan(add, init, xs, dim=1) + + example_inputs = torch.ones(5, 7, 4) + example_init = torch.ones(5, 1, 4) + functional_f = torch.func.functionalize(f) + with self.assertRaisesRegex( + UnsupportedAliasMutationException, "Combine_fn might be aliasing the input!" + ): + functional_f(example_init, example_inputs) + _hop_schema_test_schema_types = [ "bool", diff --git a/test/test_fake_tensor.py b/test/test_fake_tensor.py index f16dc1c0862744..bf756f7b30fcdc 100644 --- a/test/test_fake_tensor.py +++ b/test/test_fake_tensor.py @@ -23,6 +23,7 @@ from torch._C._functorch import _add_batch_dim, get_unwrapped, is_batchedtensor from torch._dynamo.testing import make_test_cls_with_patches, rand_strided from torch._guards import tracing, TracingContext +from torch._higher_order_ops.scan import scan from torch._subclasses.fake_tensor import ( DynamicOutputShapeException, extract_tensor_metadata, @@ -923,6 +924,21 @@ def f(x): self.assertIsInstance(r, FakeTensor) self.assertEqual(r.size(), [3]) + @parametrize("reverse", [False, True]) + def test_scan(self, reverse): + def add(x, y): + return x + y, x + y + + with torch._subclasses.fake_tensor.FakeTensorMode(): + x = torch.randn((3, 5, 7), device="cpu") + init = torch.randn((3, 1, 7), device="cpu") + r = scan(add, init, x, dim=1, reverse=reverse) + + self.assertIsInstance(r[0], FakeTensor) + self.assertIsInstance(r[1], FakeTensor) + self.assertEqual(r[0].size(), init.size()) + self.assertEqual(r[1].size(), x.size()) + instantiate_parametrized_tests(FakeTensorTest) diff --git a/torch/_dynamo/trace_rules.py b/torch/_dynamo/trace_rules.py index e1ee77fe63a085..cf0731911612ba 100644 --- a/torch/_dynamo/trace_rules.py +++ b/torch/_dynamo/trace_rules.py @@ -3188,6 +3188,7 @@ def _module_dir(m: types.ModuleType): "torch._higher_order_ops.cond", "torch._higher_order_ops.while_loop", "torch._higher_order_ops.associative_scan", + "torch._higher_order_ops.scan", "torch.nn.attention.flex_attention", "torch.ao.quantization.pt2e.export_utils", "torch.ao.quantization.pt2e.qat_utils", @@ -3228,6 +3229,7 @@ def _module_dir(m: types.ModuleType): "torch._functorch.functional_call", "torch._functorch.vmap", "torch._higher_order_ops.associative_scan", + "torch._higher_order_ops.scan", "torch._higher_order_ops.strict_mode", "torch._higher_order_ops.while_loop", "torch._inductor.test_operators", diff --git a/torch/_dynamo/variables/higher_order_ops.py b/torch/_dynamo/variables/higher_order_ops.py index fa0c606296e8d3..0a384246736736 100644 --- a/torch/_dynamo/variables/higher_order_ops.py +++ b/torch/_dynamo/variables/higher_order_ops.py @@ -602,6 +602,8 @@ def make(value, source=None, **kwargs): return RunWithRNGStateHigherOrderVariable(value, source, **kwargs) elif value.__name__ == "associative_scan": return AssociativeScanHigherOrderVariable(value, source, **kwargs) + elif value.__name__ == "scan": + return ScanHigherOrderVariable(value, source, **kwargs) elif value.__name__ == "call_torchbind": return CallTorchbindHigherOrderVariable(value, source, **kwargs) elif value.__name__ == "wrap_with_set_grad_enabled": @@ -1022,16 +1024,16 @@ def call_function( args, kwargs = LazyVariableTracker.realize_all((args, kwargs)) - def arg_extractor(combine_fn, input, dim): - return combine_fn, input, dim + def arg_extractor(combine_fn, xs, dim): + return combine_fn, xs, dim - combine_fn, input, dim = arg_extractor(*args, **kwargs) + combine_fn, xs, dim = arg_extractor(*args, **kwargs) - if input.python_type() != list: + if xs.python_type() != list: unimplemented( - f"Expected input to be a list of tensors but got {input.python_type()}", + f"Expected xs to be a list of tensors but got {xs.python_type()}", ) - assert isinstance(input, torch._dynamo.variables.lists.BaseListVariable) + assert isinstance(xs, torch._dynamo.variables.lists.BaseListVariable) # Trace the subgraph # TODO: Fix these pointless new_empty calls appearing in the dynamo output graph. @@ -1054,7 +1056,7 @@ def arg_extractor(combine_fn, input, dim): "requires_grad": SourcelessBuilder.create(tx, leaf.requires_grad), }, ) - for leaf in itertools.chain(input.items, input.items) + for leaf in itertools.chain(xs.items, xs.items) ] ( (combine_result, combine_treespec), @@ -1065,7 +1067,7 @@ def arg_extractor(combine_fn, input, dim): combine_fn, sub_args, sub_kwargs={}, - description="scan_combine", + description="associative_scan_combine_fn", source_target=self.value, set_subgraph_inputs="flatten_manual", ) @@ -1080,9 +1082,9 @@ def arg_extractor(combine_fn, input, dim): f"Expected combine_fn to return a list if tensor but got {combine_result.python_type()}", ) - input_proxy = input.as_proxy() + xs_proxy = xs.as_proxy() combine_result_proxy = combine_result.as_proxy() - for result, inp_proxy in zip(combine_result_proxy, input_proxy): + for result, inp_proxy in zip(combine_result_proxy, xs_proxy): inp_meta = inp_proxy.node.meta["example_value"] combine_result_meta = result.node.meta["example_value"] if combine_result_meta.device != inp_meta.device: @@ -1097,18 +1099,17 @@ def arg_extractor(combine_fn, input, dim): ) combine_gm = torch.fx.GraphModule(dict(tx.output.nn_modules), combine_graph) - combine_fn_name = add_subgraph(tx, "scan_combine", combine_gm) + combine_fn_name = add_subgraph(tx, "associative_scan_combine_fn", combine_gm) p_args = ( make_attr(tx, combine_fn_name), - input_proxy, + xs_proxy, dim.as_proxy(), ) with tx.fake_mode: out_meta = tuple( - inp_proxy.node.meta["example_value"].clone() - for inp_proxy in input_proxy + inp_proxy.node.meta["example_value"].clone() for inp_proxy in xs_proxy ) return wrap_fx_proxy( tx=tx, @@ -1119,6 +1120,196 @@ def arg_extractor(combine_fn, input, dim): ) +class ScanHigherOrderVariable(TorchHigherOrderOperatorVariable): + @raise_hard_error_if_graph_break( + reason="scan must be captured completely with torch.compile." + ) + def call_function( + self, + tx: "InstructionTranslator", + args: List[VariableTracker], + kwargs: Dict[str, VariableTracker], + ) -> VariableTracker: + from torch._higher_order_ops.scan import make_expanded_output_shape + + from .builder import SourcelessBuilder, wrap_fx_proxy + + args, kwargs = LazyVariableTracker.realize_all((args, kwargs)) + + def arg_extractor(combine_fn, init, xs, dim, reverse): + return combine_fn, init, xs, dim, reverse + + combine_fn, init, xs, dim, reverse = arg_extractor(*args, **kwargs) + + if xs.python_type() != list: + unimplemented( + f"Expected xs to be a list of tensors but got {xs.python_type()}", + ) + assert isinstance(xs, torch._dynamo.variables.lists.BaseListVariable) + if init.python_type() != list: + unimplemented( + f"Expected init to be a list of tensors but got {init.python_type()}", + ) + assert isinstance(init, torch._dynamo.variables.lists.BaseListVariable) + + dim_fake = ( + dim.as_proxy() + if type(dim.as_proxy()) == int + else get_fake_value(dim.as_proxy().node, tx) + ) + scan_length = get_fake_value(xs.items[0].as_proxy().node, tx).size()[dim_fake] + if scan_length == 0: + unimplemented( + "scan() operator doesn't support zero-sized tensors during tracing." + ) + + init_len = len(init.items) + if init_len == 0: + unimplemented("scan() operator requires init leaves.") + + # Trace the subgraph + # TODO: Fix these pointless new_empty calls appearing in the dynamo output graph. + # TODO: Unify handling of sub_args across control flow ops, such as cond, while_loop, etc. + sub_args_init = [ + ini.call_method( + tx, + "new_empty", + args=( + SourcelessBuilder.create( + tx, + ini.size + if ini.size is not None + else tuple( + BuiltinVariable(getattr) + .call_function( + tx, [ini, ConstantVariable.create("shape")], {} + ) + .items + ), + ), + ), + kwargs={ + "dtype": SourcelessBuilder.create(tx, ini.dtype), + "device": SourcelessBuilder.create(tx, ini.device), + "requires_grad": SourcelessBuilder.create(tx, ini.requires_grad), + }, + ) + for ini in init.items + ] + sub_args_inp_shapes = make_expanded_output_shape( + dim_fake, + 1, + [ + tuple( + BuiltinVariable(getattr) + .call_function(tx, [inp, ConstantVariable.create("shape")], {}) + .items + ) + for inp in xs.items + ], + True, + ) + sub_args_inp = [ + inp.call_method( + tx, + "new_empty", + args=(SourcelessBuilder.create(tx, inp_sh),), + kwargs={ + "dtype": SourcelessBuilder.create(tx, inp.dtype), + "device": SourcelessBuilder.create(tx, inp.device), + "requires_grad": SourcelessBuilder.create(tx, inp.requires_grad), + }, + ) + for inp, inp_sh in zip(xs.items, sub_args_inp_shapes) + ] + sub_args = sub_args_init + sub_args_inp + ( + (combine_result, combine_treespec), + combine_graph, + combine_lifted_freevars, + ) = speculate_subgraph( + tx, + combine_fn, + sub_args, + sub_kwargs={}, + description="scan_combine_fn", + source_target=self.value, + set_subgraph_inputs="flatten_manual", + ) + + if combine_lifted_freevars: + unimplemented( + f"Combine fn had unexpected freevars: {combine_lifted_freevars}" + ) + + if any(cr.python_type() != list for cr in combine_result.items): + unimplemented( + f"Expected combine_fn to return a list if tensor but got {combine_result.python_type()}", + ) + + xs_proxy = xs.as_proxy() + init_proxy = init.as_proxy() + combine_carry_proxy = combine_result.items[0].as_proxy() + + # Checks for carry and init + for ini_proxy, carry in zip(init_proxy, combine_carry_proxy): + ini_meta = ini_proxy.node.meta["example_value"] + carry_meta = carry.node.meta["example_value"] + if ( + carry_meta.device != ini_meta.device + or carry_meta.dtype != ini_meta.dtype + or carry_meta.shape != ini_meta.shape + ): + unimplemented( + f"Expected metadata of the combine_fn result {carry_meta} to be the same as " + + f"the metadata of init with {ini_meta}" + ) + + combine_gm = torch.fx.GraphModule(dict(tx.output.nn_modules), combine_graph) + combine_fn_name = add_subgraph(tx, "scan_combine_fn", combine_gm) + + p_args = ( + make_attr(tx, combine_fn_name), + init_proxy, + xs_proxy, + dim.as_proxy(), + reverse.as_proxy(), + ) + + with tx.fake_mode: + # For the fake mode, we need to duplicate the init tensor along the dim + # to have the same size as the xs arguments + # We also do a clone with contiguous_format. This is to be consistent with + # eager semantic of map, which stacks the outputs. The result is contiguous + # as a result of the stack operation. + fake_out_shapes = make_expanded_output_shape( + dim_fake, + scan_length, + [ + get_fake_value(o.as_proxy().node, tx).size() + for o in combine_result.items[1].items + ], + ) + out_meta = ( + [init_p.node.meta["example_value"].clone() for init_p in init_proxy], + list( # noqa: C400 + t.as_proxy() + .node.meta["example_value"] + .expand(*sh) + .clone(memory_format=torch.contiguous_format) + for t, sh in zip(combine_result.items[1].items, fake_out_shapes) + ), + ) + + return wrap_fx_proxy( + tx=tx, + proxy=tx.output.create_proxy( + "call_function", torch.ops.higher_order.scan, p_args, {} + ), + example_value=out_meta, + ) + + def non_single_tensor_return_unsupported(api, ret): from . import TensorVariable diff --git a/torch/_higher_order_ops/associative_scan.py b/torch/_higher_order_ops/associative_scan.py index 11dcca5eb4959f..d58d6b26bd33f7 100644 --- a/torch/_higher_order_ops/associative_scan.py +++ b/torch/_higher_order_ops/associative_scan.py @@ -73,8 +73,8 @@ class AssociativeScanOp(HigherOrderOperator): def __init__(self): super().__init__("associative_scan") - def __call__(self, combine_fn, input, dim): - return super().__call__(combine_fn, input, dim) + def __call__(self, combine_fn, xs, dim): + return super().__call__(combine_fn, xs, dim) associative_scan_op = AssociativeScanOp() @@ -82,13 +82,13 @@ def __call__(self, combine_fn, input, dim): def associative_scan( combine_fn: Callable[[pytree.PyTree, pytree.PyTree], pytree.PyTree], - input: pytree.PyTree, + xs: pytree.PyTree, dim: int, reverse: bool = False, combine_mode: str = "pointwise", ) -> torch.Tensor: r""" - Performs an inclusive scan with an associative pointwise combine function. + Performs an inclusive scan with an associative combine function. .. warning:: `torch.associative_scan` is a prototype feature in PyTorch. It currently @@ -102,14 +102,15 @@ def associative_scan( Args: combine_fn (Callable): A binary callable with type ``(Tensor, Tensor) -> Tensor``, or if input is a pytree ``(pytree, pytree) -> pytree``. - This function must be pure, pointwise, and satisfy the associative property. - input (torch.Tensor): The input tensor, or nested pytree of tensors. + This function must be pure, i.e., no lifted arguments are supported at the moment, + satisfy the associative property and have no side-effects. + xs (torch.Tensor): The input tensor, or nested pytree of tensors. All inputs are expected to have the same shape. dim (int): the dimension to scan over - reverse (bool): A boolean stating if the scan should be reversed with respect to the dimension. - combine_mode (str): A string indicating whether the ``combine_fn`` is ``pointwise`` or ``generic``. + reverse (bool): A boolean stating if the scan should be reversed with respect to ``dim``, default ``False``. + combine_mode (str): A string indicating whether the ``combine_fn`` is ``pointwise`` or ``generic``, default ``pointwise``. If ``combine_mode=pointwise``, ``combine_fn`` must be pure, may only contain pointwise operations - and ``input`` must be CUDA tensors. + and ``xs`` must be CUDA tensors. In all other cases ``combine_mode=generic`` should be used. Note: ``combine_mode=pointwise`` is more efficient than ``combine_mode=generic``. @@ -122,27 +123,32 @@ def add(x: torch.Tensor, y: torch.Tensor): cumsum = associative_scan(add, x, dim) """ - assert callable(combine_fn), "combine_fn must be a callable, but got {combine_fn}" - assert isinstance(dim, int), "dim must be an int, but got {type(dim)}" - assert combine_mode in ["pointwise", "generic"] + if not callable(combine_fn): + raise RuntimeError("Combine_fn must be a callable, but got {combine_fn}") + if not isinstance(dim, int): + raise RuntimeError("Dim must be an int, but got " + str(type(dim))) + if combine_mode not in ["pointwise", "generic"]: + raise RuntimeError( + "Combine_mode must either 'pointwise' or 'generic', but got {combine_mode}" + ) if not torch._dynamo.is_compiling(): with _set_compilation_env(), torch._dynamo.utils.disable_cache_limit(): return torch.compile(associative_scan, fullgraph=True)( - combine_fn, input, dim, reverse=reverse, combine_mode=combine_mode + combine_fn, xs, dim, reverse=reverse, combine_mode=combine_mode ) - leaves, spec = pytree.tree_flatten(input) + leaves, spec = pytree.tree_flatten(xs) if combine_mode == "pointwise" and not all(l.device.type == "cuda" for l in leaves): raise ValueError( "For combine_mode='pointwise', all input tensors need to be on CUDA" ) - assert len(leaves) >= 1, "expected at least 1 input leaf" - assert all( - isinstance(x, torch.Tensor) for x in leaves - ), "input leaves must be a Tensor" + if len(leaves) == 0: + raise RuntimeError("Expected at least 1 xs leaf") + if any(not isinstance(x, torch.Tensor) for x in leaves): + raise RuntimeError("xs leaves must be a Tensor") if reverse: leaves = [torch.flip(elem, [dim]) for elem in leaves] @@ -152,20 +158,21 @@ def add(x: torch.Tensor, y: torch.Tensor): dim = utils.canonicalize_dim(ndim, dim) for x in leaves[1:]: - assert x.shape == shape, "All input tensors must have the same shape" + assert x.shape == shape, "All xs tensors must have the same shape" out = combine_fn( pytree.tree_unflatten(leaves, spec), pytree.tree_unflatten(leaves, spec), ) out_leaves, tree_out = pytree.tree_flatten(out) - assert len(leaves) == len( - out_leaves - ), "The pytree of the output of the operator needs to match the input pytree" - for x in out_leaves: - assert ( - x.shape == shape - ), "The pytree of the output of the operator needs to match the input pytree" + if len(leaves) != len(out_leaves): + raise RuntimeError( + "The number of leaves of the pytree of the output of the operator needs to match the length of the pytree of the input" + ) + if any(x.shape != shape for x in out_leaves): + raise RuntimeError( + "The pytree of the output of the operator needs to match the xs pytree" + ) combine_fn = functools.partial( wrap_combine_fn_flat, combine_fn=combine_fn, spec=spec, num_leaves=len(leaves) @@ -194,7 +201,7 @@ def generic_associative_scan(operator, elems_flat, dim=0): or if input is a pytree ``(pytree, pytree) -> pytree``. This function must be pure, pointwise, and satisfy the associative property. elems_flat (torch.Tensor): A list of torch.Tensors converted from the pytree of - ``input`` provided to ``associative_scan``. + ``xs`` provided to ``associative_scan``. All inputs are expected to have the same shape. dim (int): the dimension to scan over @@ -279,19 +286,19 @@ def _scan(elems): def trace_associative_scan( - proxy_mode, func_overload, combine_fn: Callable, input: List[torch.Tensor], dim: int + proxy_mode, func_overload, combine_fn: Callable, xs: List[torch.Tensor], dim: int ): with disable_proxy_modes_tracing(): - sample_inputs = [ + sample_xs = [ torch.empty_like( x, dtype=x.dtype, device=x.device, requires_grad=x.requires_grad, ) - for x in itertools.chain(input, input) + for x in itertools.chain(xs, xs) ] - combine_graph = reenter_make_fx(combine_fn)(*sample_inputs) + combine_graph = reenter_make_fx(combine_fn)(*sample_xs) outputs = None for node in combine_graph.graph.nodes: @@ -307,10 +314,10 @@ def trace_associative_scan( assert outputs is not None assert len(outputs) == len( - input - ), f"expected combine_fn to return {len(input)} results but got {len(outputs)}" + xs + ), f"expected combine_fn to return {len(xs)} results but got {len(outputs)}" - for i, o in zip(input, outputs): + for i, o in zip(xs, outputs): o_meta = o.meta["tensor_meta"] assert o_meta.dtype == i.dtype, ( f"combine_fn output type mismatch, expected {i.dtype} " @@ -321,20 +328,20 @@ def trace_associative_scan( proxy_mode.tracer.root.register_module(combine_graph_name, combine_graph) - args = (combine_graph, input, dim) + args = (combine_graph, xs, dim) proxy_args = pytree.tree_map(proxy_mode.tracer.unwrap_proxy, args) out_proxy = proxy_mode.tracer.create_proxy( "call_function", func_overload, proxy_args, {}, name="associative_scan" ) with disable_proxy_modes_tracing(): - out = [aten.clone(x) for x in input] + out = [aten.clone(x) for x in xs] return track_tensor_tree(out, out_proxy, constant=None, tracer=proxy_mode.tracer) @associative_scan_op.py_impl(DispatchKey.CompositeExplicitAutograd) -def associative_scan_op_dense(combine_fn, input, dim): +def associative_scan_op_dense(combine_fn, xs, dim): raise NotImplementedError("associative_scan is not implemented for eager") @@ -344,22 +351,22 @@ def associative_scan_op_dense(combine_fn, input, dim): @associative_scan_op.py_impl(ProxyTorchDispatchMode) -def associative_scan_proxy_mode(mode, combine_fn, input, dim): - return trace_associative_scan(mode, associative_scan_op, combine_fn, input, dim) +def associative_scan_proxy_mode(mode, combine_fn, xs, dim): + return trace_associative_scan(mode, associative_scan_op, combine_fn, xs, dim) @associative_scan_op.py_impl(FakeTensorMode) -def assoiciative_scan_fake_tensor_mode(mode, combine_fn, input, dim): +def assoiciative_scan_fake_tensor_mode(mode, combine_fn, xs, dim): with mode: - return [x.clone() for x in input] + return [x.clone() for x in xs] @associative_scan_op.py_functionalize_impl -def associative_scan_functionalize(ctx, combine_fn, input, dim): - unwrapped_input = ctx.unwrap_tensors(input) +def associative_scan_functionalize(ctx, combine_fn, xs, dim): + unwrapped_xs = ctx.unwrap_tensors(xs) with ctx.redispatch_to_next() as m: functional_combine_fn = ctx.functionalize( _maybe_run_with_interpreter(combine_fn) ) - ret = associative_scan_op(functional_combine_fn, unwrapped_input, dim) + ret = associative_scan_op(functional_combine_fn, unwrapped_xs, dim) return ctx.wrap_tensors(ret) diff --git a/torch/_higher_order_ops/scan.py b/torch/_higher_order_ops/scan.py new file mode 100644 index 00000000000000..d6eecb0a9f29f9 --- /dev/null +++ b/torch/_higher_order_ops/scan.py @@ -0,0 +1,438 @@ +# mypy: allow-untyped-defs +import functools +import itertools +from typing import Callable, List, Tuple + +import torch +import torch._prims_common as utils +import torch._subclasses.functional_tensor +import torch.utils._pytree as pytree +from torch._C import DispatchKey +from torch._higher_order_ops.utils import ( + _has_potential_branch_input_alias, + _has_potential_branch_input_mutation, + _set_compilation_env, + autograd_not_implemented, + reenter_make_fx, + unique_graph_id, + UnsupportedAliasMutationException, +) +from torch._ops import HigherOrderOperator +from torch._subclasses.fake_tensor import FakeTensorMode +from torch.fx.experimental.proxy_tensor import ( + disable_proxy_modes_tracing, + ProxyTorchDispatchMode, + track_tensor_tree, +) +from torch.utils._python_dispatch import _get_current_dispatch_mode + + +aten = torch._ops.ops.aten + + +def wrap_combine_fn_flat( + *args, combine_fn, spec_init, spec_xs, num_init_leaves, num_inp_leaves +): + assert len(args) == (num_init_leaves + num_inp_leaves) + carry = pytree.tree_unflatten(args[:num_init_leaves], spec_init) + xs = pytree.tree_unflatten(args[num_init_leaves:], spec_xs) + carry, combined = combine_fn(carry, xs) + carry_flat = pytree.tree_leaves(carry) + combined_flat = pytree.tree_leaves(combined) + assert num_init_leaves == len(carry_flat) + return (carry_flat, combined_flat) + + +def scan( + combine_fn: Callable[ + [pytree.PyTree, pytree.PyTree], Tuple[pytree.PyTree, pytree.PyTree] + ], + init: pytree.PyTree, + xs: pytree.PyTree, + /, + *, + dim: int = 0, + reverse: bool = False, +) -> Tuple[pytree.PyTree, pytree.PyTree]: + r""" + Performs an inclusive scan with a combine function. + + .. warning:: + `torch.scan` is a prototype feature in PyTorch. It currently + does not support autograd and you may run into miscompiles. + Read more about feature classification at: + https://pytorch.org/blog/pytorch-feature-classification-changes/#prototype + + Args: + combine_fn (Callable): A binary callable with type ``(Tensor, Tensor) -> (Tensor, Tensor)``, + or if xs is a pytree ``(pytree, pytree) -> (pytree, pytree)``. + The first input to ``combine_fn`` is the previous or initial scan carry + and the second input element to ``combine_fn`` is a slice of the input along dim. + The first output element of ``combine_fn`` is the next scan carry + and the second output of ``combine_fn`` represents a slice of the output. + This function must be pure, i.e., no lifted arguments are supported at the moment + and may not have any side effects. + init (torch.Tensor or pytree with tensor leaves): The inital scan carry, a tensor, or nested pytree of tensors. + The ``init`` is expected to have the same pytree structure as the first output element (i.e. carry) + of ``combine_fn``. + xs (torch.Tensor or pytree with tensor leaves): The input tensor, or nested pytree of tensors. + + Kwargs: + dim (int): the dimension to scan over, default 0. + reverse (bool): A boolean stating if the scan should be reversed with respect to ``dim``, default ``False``. + + Returns: + final_carry (torch.Tensor or pytree with tensor leaves), + the final carry of the scan operation with same pytree structure as init. + out (torch.Tensor or pytree with tensor leaves), + each tensor leaf is a stacked output along dim, where each slice is the output of a scan iteration. + + Example:: + + def add(x: torch.Tensor, y: torch.Tensor): + next_carry = y = x + y + return next_carry, y + + i0 = torch.zeros(1) + xs = torch.arange(1, 5) + # returns torch.tensor([10]), torch.tensor([1., 3., 6., 10.]) + last_carry, cumsum = scan(add, init=i0, xs=xs) + + + """ + if not callable(combine_fn): + raise RuntimeError("Combine_fn must be a callable, but got {combine_fn}") + if not isinstance(dim, int): + raise RuntimeError("Dim must be an int, but got " + str(type(dim))) + if not isinstance(reverse, bool): + raise RuntimeError("Reverse must be a bool, but got " + str(type(reverse))) + + # TODO: Support closures/nn_modules in order to be able represent RNNs with scan + # TODO: Support _inductor lowering + # TODO: Support Autograd + # TODO: Unify handling of pytrees for control flow ops, such as cond, while_loop, etc. + + # Dynamo is expecting a callable with "__code__" attribute. + # We cannot directly pass cond_op to it. So we wrap it in a dummy function. + def _scan_op_wrapper(*args, **kwargs): + return scan(*args, **kwargs) + + if not torch._dynamo.is_compiling(): + with _set_compilation_env(), torch._dynamo.utils.disable_cache_limit(): + return torch.compile(_scan_op_wrapper, backend="eager", fullgraph=True)( + combine_fn, init, xs, dim=dim, reverse=reverse + ) + + leaves_init, spec_init = pytree.tree_flatten(init) + leaves_xs, spec_xs = pytree.tree_flatten(xs) + + if len(leaves_init) == 0: + raise RuntimeError("Init tensors must be provided") + if any(not isinstance(x, torch.Tensor) for x in leaves_init): + raise RuntimeError("All init leaves must be a Tensor") + if any(not isinstance(x, torch.Tensor) for x in leaves_xs): + raise RuntimeError("All xs leaves must be a Tensor") + if any(x.shape[dim] == 0 for x in leaves_xs): + raise RuntimeError("All xs leaves must have a scan dimension > 0") + + if len(leaves_xs) > 0: + shape = leaves_xs[0].shape + ndim = len(shape) + dim = utils.canonicalize_dim(ndim, dim) + + out = combine_fn( + pytree.tree_unflatten(leaves_init, spec_init), + pytree.tree_unflatten( + [aten.slice(elem, dim, 0, 1, 1) for elem in leaves_xs], spec_xs + ), + ) + + # The first output needs to have the same pytree as init + carry_leaves = pytree.tree_leaves(out[0]) + if len(carry_leaves) != len(leaves_init): + raise RuntimeError( + "The number of leaves of the pytree of the new carry produced by the operator\ + needs to match the length of the pytree of the init" + ) + if any( + in_l.shape != out_l.shape for in_l, out_l in zip(leaves_init, carry_leaves) + ): + raise RuntimeError( + "The pytree of the new carry produced by the operator needs to match the pytree of the init" + ) + + # There are no pytree restrictions on the second output of the operator + out_leaves, tree_out = pytree.tree_flatten(out[1]) + + combine_fn = functools.partial( + wrap_combine_fn_flat, + combine_fn=combine_fn, + spec_init=spec_init, + spec_xs=spec_xs, + num_init_leaves=len(leaves_init), + num_inp_leaves=len(leaves_xs), + ) + + result_carry, result_flat = scan_op( + combine_fn, leaves_init, leaves_xs, dim, reverse + ) + + return pytree.tree_unflatten(result_carry, spec_init), pytree.tree_unflatten( + result_flat, tree_out + ) + + else: + return pytree.tree_unflatten(leaves_init, spec_init), xs + + +class ScanOp(HigherOrderOperator): + def __init__(self): + super().__init__("scan") + + def __call__(self, combine_fn, init, xs, dim, reverse): + return super().__call__(combine_fn, init, xs, dim, reverse) + + +scan_op = ScanOp() + + +def generic_scan(operator, init, xs, dim=0, reverse=False): + def _scan(init, xs): + """Perform scan on `elems` using `elems_init.""" + carry = init + if len(xs) == 0: + return carry, [] + + num_elems = xs[0].shape[dim] + if reverse: + ind = num_elems - 1 + else: + ind = 0 + + # Compute dummy shapes for the pre-allocation + dummy_carry, dummy_out = operator( + *carry, *[aten.slice(elem, dim, 0, 1, 1) for elem in xs] + ) + output_scanned_dim = dummy_out[0].shape[dim] + + # Pre-alocate + # outs -> Output matrix + # idxs -> Index matrix for scatter_ + outs, outs_idxs = zip( + *[ + [ + torch.zeros( + list(e.size())[:dim] + + [list(e.size())[dim] * num_elems] + + list(e.size())[dim + 1 :], + dtype=e.dtype, + device=e.device, + ), + torch.cat( + [ + id * t + for id, t in zip( + range(output_scanned_dim), + torch.tensor_split( + torch.ones_like(e, dtype=torch.int64), + output_scanned_dim, + dim=dim, + ), + ) + ], + dim, + ), + ] + for i, e in enumerate(dummy_out) + ] + ) + + def store_in_mat(mat, out, d, index, index_modifier): + # Store the intermediate out in the outs matrix + for o, x, idx in zip(mat, out, index): + o.scatter_(d, idx + index_modifier, x) + + def cond(i, n, r): + if (r and i < 0) or (not r and i > (n - 1)): + return False + else: + return True + + def op(i): + if reverse: + return i - 1 + else: + return i + 1 + + while cond(ind, num_elems, reverse): + carry, out = operator( + *carry, + *[aten.slice(elem, dim, ind, ind + 1, 1) for elem in xs], + ) + + # Store the inits in the outs matrix. + store_in_mat(outs, out, dim, outs_idxs, ind * output_scanned_dim) + + ind = op(ind) + + return (carry, list(outs)) + + scans = _scan(init, xs) + return scans + + +def make_expanded_output_shape(dim, scan_length, shapes, use_sh=False): + expanded_shapes = [ + tuple( + (s if use_sh else -1) if i != dim else scan_length for i, s in enumerate(sh) + ) + for sh in shapes + ] + return expanded_shapes + + +def trace_scan( + proxy_mode, + func_overload, + combine_fn: Callable, + init: List[torch.Tensor], + xs: List[torch.Tensor], + dim: int, + reverse: bool, +): + with disable_proxy_modes_tracing(): + sample_inits = [ + torch.empty_like( + x_init, + dtype=x_init.dtype, + device=x_init.device, + requires_grad=x_init.requires_grad, + ) + for x_init in init + ] + sample_xs = [ + torch.empty_like( + aten.slice(x, dim, 0, 1, 1), + dtype=x.dtype, + device=x.device, + requires_grad=x.requires_grad, + ) + for x in xs + ] + combine_graph = reenter_make_fx(combine_fn)(*sample_inits, *sample_xs) + + outputs = None + for node in combine_graph.graph.nodes: + if node.op == "output": + assert outputs is None + assert len(node.args) == 1 + outputs = node.args[0] + + assert outputs is not None + if len(outputs) != 2: + raise RuntimeError( + f"Expected to return 2 outputs: carry, out_matrix, but got:" + f"\n {len(outputs)} elements" + ) + + for ini, carry in zip(init, outputs[0]): + ini_meta = ini + carry_meta = carry.meta["tensor_meta"] + carry_val = carry.meta["val"] + if ( + carry_val.device != ini_meta.device + or carry_meta.dtype != ini_meta.dtype + or carry_meta.shape != ini_meta.shape + ): + raise RuntimeError( + f"Expected metadata of the combine_fn result {carry_meta} to be the same as " + + f"the metadata of init with {ini_meta}" + ) + + _, combine_graph_name = unique_graph_id(proxy_mode, prefix="scan_combine_graph") + + proxy_mode.tracer.root.register_module(combine_graph_name, combine_graph) + + args = (combine_graph, init, xs, dim, reverse) + proxy_args = pytree.tree_map(proxy_mode.tracer.unwrap_proxy, args) + out_proxy = proxy_mode.tracer.create_proxy( + "call_function", func_overload, proxy_args, {}, name="scan" + ) + + with disable_proxy_modes_tracing(): + scan_length = xs[0].shape[dim] + fake_out_shapes = make_expanded_output_shape( + dim, scan_length, [o.meta["val"].size() for o in outputs[1]] + ) + + def expand_tensor(t, sh): + if isinstance(t, torch.Tensor): + return t.expand(*sh) + return t + + expanded_outs = [ + pytree.tree_map(expand_tensor, t.meta["val"], sh) + for t, sh in zip(outputs[1], fake_out_shapes) + ] + out = (init, expanded_outs) + + return track_tensor_tree(out, out_proxy, constant=None, tracer=proxy_mode.tracer) + + +@scan_op.py_impl(DispatchKey.CompositeExplicitAutograd) +def scan_op_dense(combine_fn, init, xs, dim, reverse): + mode = _get_current_dispatch_mode() + assert mode is None, "Mode should never be enabled for CPU/CUDA key" + return generic_scan(combine_fn, init, xs, dim, reverse) + + +scan_op.py_impl(DispatchKey.Autograd)( + autograd_not_implemented(scan_op, deferred_error=True) +) + + +@scan_op.py_impl(ProxyTorchDispatchMode) +def scan_proxy_mode(mode, combine_fn, init, xs, dim, reverse): + return trace_scan(mode, scan_op, combine_fn, init, xs, dim, reverse) + + +@scan_op.py_impl(FakeTensorMode) +def scan_fake_tensor_mode(mode, combine_fn, init, xs, dim, reverse): + with mode: + dim_len = xs[0].shape[dim] + carry, outputs = combine_fn( + *init, *[aten.slice(inp, dim, 0, 1, 1) for inp in xs] + ) + fake_out_shapes = [ + tuple(-1 if i != dim else dim_len for i, sh in enumerate(o.size())) + for o in outputs + ] + out = ( + carry, + tuple(t.expand(*sh).clone() for t, sh in zip(outputs, fake_out_shapes)), + ) + return out + + +@scan_op.py_functionalize_impl +def scan_functionalize(ctx, combine_fn, init, xs, dim, reverse): + unwrapped_xs = ctx.unwrap_tensors(xs) + unwrapped_init = ctx.unwrap_tensors(init) + with ctx.redispatch_to_next() as m: + functional_combine_fn = ctx.functionalize(combine_fn) + pre_dispatch = hasattr(ctx, "mode") and ctx.mode.pre_dispatch + sample_xs = list(itertools.chain(unwrapped_init, unwrapped_init)) + if _has_potential_branch_input_mutation( + functional_combine_fn, sample_xs, pre_dispatch=pre_dispatch + ): + raise UnsupportedAliasMutationException( + "Combine_fn might be modifying the input!" + ) + if _has_potential_branch_input_alias( + functional_combine_fn, sample_xs, pre_dispatch=pre_dispatch + ): + raise UnsupportedAliasMutationException( + "Combine_fn might be aliasing the input!" + ) + ret = scan_op(functional_combine_fn, unwrapped_init, unwrapped_xs, dim, reverse) + return ctx.wrap_tensors(ret) diff --git a/torch/_higher_order_ops/utils.py b/torch/_higher_order_ops/utils.py index 3b88d5f7dbac8a..139e9a160cbe20 100644 --- a/torch/_higher_order_ops/utils.py +++ b/torch/_higher_order_ops/utils.py @@ -224,6 +224,7 @@ def _from_fun(t): t.stride(), dtype=t.dtype, requires_grad=t.requires_grad, + device=t.device, ) else: # clone of a functional tensor produces a functional tensor diff --git a/torch/_inductor/lowering.py b/torch/_inductor/lowering.py index a1eade0aab773c..81be111f1e63f4 100644 --- a/torch/_inductor/lowering.py +++ b/torch/_inductor/lowering.py @@ -6218,12 +6218,12 @@ def while_loop(cond_fn, body_fn, carried_inputs, additional_inputs): @register_lowering(associative_scan_op, type_promotion_kind=None) -def associative_scan(combine_fn: ir.Subgraph, input, dim: int): +def associative_scan(combine_fn: ir.Subgraph, xs, dim: int): from .subgraph_lowering import InputDescriptor, lower_pointwise_subgraph subgraph_inputs = [ InputDescriptor(dtype=x.get_dtype(), device=x.get_device()) - for x in itertools.chain(input, input) + for x in itertools.chain(xs, xs) ] lowered_combine_fn = lower_pointwise_subgraph(combine_fn, subgraph_inputs) # type: ignore[var-annotated] @@ -6233,9 +6233,9 @@ def wrapped_combine_fn(lhs, rhs): *pytree.tree_leaves(rhs), ) - kwargs = _make_scan_inner(input[0], axis=dim, dtype=None) - kwargs["dtypes"] = tuple(x.get_dtype() for x in input) - kwargs["inner_fns"] = tuple(x.make_loader() for x in input) + kwargs = _make_scan_inner(xs[0], axis=dim, dtype=None) + kwargs["dtypes"] = tuple(x.get_dtype() for x in xs) + kwargs["inner_fns"] = tuple(x.make_loader() for x in xs) result = ir.Scan.create( combine_fn=wrapped_combine_fn, can_fallback_to_aten=False, From 0267eb9d478990887a0ac8de5c3cbf0e6ef94866 Mon Sep 17 00:00:00 2001 From: Yichen Yan Date: Tue, 10 Sep 2024 05:54:35 +0000 Subject: [PATCH 0660/1018] Increase `TRITON_MAX_BLOCK['X']` (#135181) Fixes #135028 As title, increase `TRITON_MAX_BLOCK['X']` to 4096 and fix an error, thanks to @Chillee: https://github.com/pytorch/pytorch/pull/133300/files#r1744706189 Pull Request resolved: https://github.com/pytorch/pytorch/pull/135181 Approved by: https://github.com/jansel --- test/inductor/test_torchinductor_strided_blocks.py | 2 +- torch/_inductor/runtime/hints.py | 2 +- torch/_inductor/runtime/triton_heuristics.py | 8 ++++---- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/test/inductor/test_torchinductor_strided_blocks.py b/test/inductor/test_torchinductor_strided_blocks.py index 33782d76b7f934..16424fe1b4f008 100644 --- a/test/inductor/test_torchinductor_strided_blocks.py +++ b/test/inductor/test_torchinductor_strided_blocks.py @@ -241,7 +241,7 @@ def get_input(view_size: Tuple[int]) -> torch.Tensor: ((3 * max_block, 2), 3, 2), # Multiple of max block. Uses loops. ( (2, 3 * max_block), - 3, + 2, 2, ), # Multiple of max block. Uses loops. ((128, 128), 3, 2), # Test a large size, with loops. diff --git a/torch/_inductor/runtime/hints.py b/torch/_inductor/runtime/hints.py index 90f9e5a1cc5966..d3c9f560fd2a1e 100644 --- a/torch/_inductor/runtime/hints.py +++ b/torch/_inductor/runtime/hints.py @@ -8,7 +8,7 @@ # NOTE: if these fail asserts submit a PR to increase them TRITON_MAX_BLOCK = { - "X": 2048, + "X": 4096, "Y": 1024, "Z": 1024, "R": 4096 * 16, # * 16 is multi-kernel only diff --git a/torch/_inductor/runtime/triton_heuristics.py b/torch/_inductor/runtime/triton_heuristics.py index a3b8e611f3852b..0c499956ff7406 100644 --- a/torch/_inductor/runtime/triton_heuristics.py +++ b/torch/_inductor/runtime/triton_heuristics.py @@ -1140,10 +1140,10 @@ def _check_max_grid_x(size_hints, x, num_warps): while (num_blocks * num_warps * warp_size) > max_grid_x and x < size_hints[0]: x *= 2 # Scale up XBLOCK if grid exceeds limits num_blocks = num_blocks // 2 - if x >= max_grid_x: - raise AssertionError( - "Reduction config exceeds cudaDeviceProp maxGridSize. Please raise a pytorch issue" - ) + if (num_blocks * num_warps * warp_size) > max_grid_x: + raise AssertionError( + "Reduction config exceeds cudaDeviceProp maxGridSize. Please raise a pytorch issue" + ) return x, num_blocks From 9009160fb52ee52fa7afd549eea6f03d63c31090 Mon Sep 17 00:00:00 2001 From: Tugsbayasgalan Manlaibaatar Date: Mon, 9 Sep 2024 17:26:20 -0700 Subject: [PATCH 0661/1018] Add some minor doc improvement and ban using training IR for unflattener (#135549) Title Differential Revision: [D62412490](https://our.internmc.facebook.com/intern/diff/D62412490/) Pull Request resolved: https://github.com/pytorch/pytorch/pull/135549 Approved by: https://github.com/yushangdi --- torch/export/__init__.py | 6 +++++- torch/export/unflatten.py | 2 ++ 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/torch/export/__init__.py b/torch/export/__init__.py index 367b3f127a16bb..3b0a36ef1ea18a 100644 --- a/torch/export/__init__.py +++ b/torch/export/__init__.py @@ -88,7 +88,11 @@ def export_for_training( flow and data structures (with certain exceptions), and (3) records the set of shape constraints needed to show that this normalization and control-flow elimination is sound for future inputs. This API is intended for PT2 quantization training use cases - and will soon be the default IR of torch.export.export in the near future. + and will soon be the default IR of torch.export.export in the near future. To read further about + the motivation behind this change, please refer to + https://dev-discuss.pytorch.org/t/why-pytorch-does-not-need-a-new-standardized-operator-set/2206 + With this API, and :func:`run_decompositions()`, you should be able to get inference IR with + your custom decomposition behaviour. **Soundness Guarantee** diff --git a/torch/export/unflatten.py b/torch/export/unflatten.py index 992818f1758048..a655c48fd9f6ca 100644 --- a/torch/export/unflatten.py +++ b/torch/export/unflatten.py @@ -548,6 +548,8 @@ def unflatten( An instance of :class:`UnflattenedModule`, which has the same module hierarchy as the original eager module pre-export. """ + if module.verifier.dialect == "TRAINING": + raise RuntimeError("Unflattener doesn't support non-functional training IR yet") module = _remove_effect_tokens(module) return UnflattenedModule(module, flat_args_adapter) From 6fe181c9fc1704d4078ec1908340b6d11053f2f9 Mon Sep 17 00:00:00 2001 From: Sam Larsen Date: Mon, 9 Sep 2024 19:25:51 -0700 Subject: [PATCH 0662/1018] [inductor] print triton float64 constants correctly (#135260) Pull Request resolved: https://github.com/pytorch/pytorch/pull/135260 Approved by: https://github.com/jansel --- test/inductor/test_triton_kernels.py | 19 ++++++++++++++++++- torch/_inductor/codegen/triton.py | 6 ++++++ 2 files changed, 24 insertions(+), 1 deletion(-) diff --git a/test/inductor/test_triton_kernels.py b/test/inductor/test_triton_kernels.py index 14eec775cdec31..64aaccdbf085a3 100644 --- a/test/inductor/test_triton_kernels.py +++ b/test/inductor/test_triton_kernels.py @@ -16,7 +16,12 @@ from torch._inductor.utils import run_and_get_code from torch._library import capture_triton from torch.testing._internal import common_utils -from torch.testing._internal.common_utils import skipIfRocm, skipIfXpu, TEST_WITH_ROCM +from torch.testing._internal.common_utils import ( + parametrize, + skipIfRocm, + skipIfXpu, + TEST_WITH_ROCM, +) from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_CUDA, HAS_GPU, HAS_XPU from torch.testing._internal.logging_utils import logs_to_string @@ -1545,6 +1550,18 @@ def f(x, y): x = torch.randn(4, device=GPU_TYPE) f(x, x) + @requires_gpu + @parametrize("dtype", (torch.float16, torch.float32, torch.float64)) + def test_triton_kernel_float64_constant(self, dtype): + def f(x): + return x * (0.12 * x.shape[0]) + + x = torch.ones(200, device=GPU_TYPE, dtype=torch.float64) + + eager_out = f(x) + compiled_out = torch.compile(f, dynamic=True)(x) + self.assertEqual(compiled_out, eager_out) + def make_mutation_test(fn): @requires_gpu diff --git a/torch/_inductor/codegen/triton.py b/torch/_inductor/codegen/triton.py index 2596f3d59793e9..2d4b2b75ae263e 100644 --- a/torch/_inductor/codegen/triton.py +++ b/torch/_inductor/codegen/triton.py @@ -385,6 +385,12 @@ def _print_TruncToInt(self, expr): f"libdevice.trunc({self._print(expr.args[0])}).to({V.kernel.index_dtype})" ) + def _print_Float(self, expr): + # Use a tensor here to get float64. Otherwise the constant is + # truncated to float32. + ret = f"tl.full([1], {expr}, tl.float64)" + return ret + def _print_ToFloat(self, expr): assert len(expr.args) == 1 return f"{self.paren(self._print(expr.args[0]))}.to(tl.float64)" From 5f5fa2e51b24a3b7543fa5b6d206e1ff3d01ed70 Mon Sep 17 00:00:00 2001 From: Tugsbayasgalan Manlaibaatar Date: Mon, 9 Sep 2024 17:26:20 -0700 Subject: [PATCH 0663/1018] Fix doc (#135551) Differential Revision: [D62412667](https://our.internmc.facebook.com/intern/diff/D62412667/) Pull Request resolved: https://github.com/pytorch/pytorch/pull/135551 Approved by: https://github.com/yushangdi ghstack dependencies: #135549 --- torch/export/__init__.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/torch/export/__init__.py b/torch/export/__init__.py index 3b0a36ef1ea18a..b280fa29ff9789 100644 --- a/torch/export/__init__.py +++ b/torch/export/__init__.py @@ -175,8 +175,7 @@ def export( preserve_module_call_signature: Tuple[str, ...] = (), ) -> ExportedProgram: """ - :func:`export` takes an arbitrary Python callable (an nn.Module, a function or - a method) along with example inputs, and produces a traced graph representing + :func:`export` takes any nn.Module along with example inputs, and produces a traced graph representing only the Tensor computation of the function in an Ahead-of-Time (AOT) fashion, which can subsequently be executed with different inputs or serialized. The traced graph (1) produces normalized operators in the functional ATen operator set From 33428d52c5255e76e2f9ad326eca84a1c482002e Mon Sep 17 00:00:00 2001 From: Victor Tao Date: Tue, 10 Sep 2024 07:27:54 +0000 Subject: [PATCH 0664/1018] [Inductor] Make static_input_idxs a set for faster lookup (#135314) `static_input_idxs` is only used for lookups. With large models, this is a large list. This takes over a millisecond in some cases. Profile before change: image Profile after change: gaps are smaller, 1ms speedup before launching the cuda graph image Pull Request resolved: https://github.com/pytorch/pytorch/pull/135314 Approved by: https://github.com/oulgen --- torch/_inductor/compile_fx.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/torch/_inductor/compile_fx.py b/torch/_inductor/compile_fx.py index db4c1e0eddfece..08ff9c281fb90d 100644 --- a/torch/_inductor/compile_fx.py +++ b/torch/_inductor/compile_fx.py @@ -62,6 +62,7 @@ from torch.fx.experimental.symbolic_shapes import free_unbacked_symbols, SymExprPrinter from torch.fx.passes.fake_tensor_prop import FakeTensorProp from torch.monitor import _WaitCounter +from torch.utils._ordered_set import OrderedSet from .._dynamo.backends.common import aot_autograd from ..fx._lazy_graph_module import _use_lazy_graph_module # type: ignore[attr-defined] @@ -1037,7 +1038,9 @@ def cudagraphify_impl( Assumes inputs[static_input_idxs[i]] are always the same memory address """ check_input_idxs = get_input_idxs_to_check(inputs, static_input_idxs) # type: ignore[arg-type] - static_input_idxs = remove_unaligned_input_idxs(inputs, static_input_idxs) # type: ignore[arg-type] + static_input_idxs: OrderedSet[int] = OrderedSet( + remove_unaligned_input_idxs(inputs, static_input_idxs) # type: ignore[arg-type] + ) copy_misaligned_inputs(inputs, check_input_idxs) # type: ignore[arg-type] assert isinstance(inputs, list) From 51409ed7a47b2d82958f26391d761bdda0c8d31b Mon Sep 17 00:00:00 2001 From: Roy Hvaara Date: Tue, 10 Sep 2024 08:37:57 +0000 Subject: [PATCH 0665/1018] [MPS] Allow nan mean reduction in `nll_loss` (#135434) This PR allows results from `nn_loss` to be `nan`, which is the same behavior as with CUDA and CPU https://github.com/pytorch/pytorch/pull/64572#issuecomment-926504162. Fixes #134431 Ref #64572 #119108 Pull Request resolved: https://github.com/pytorch/pytorch/pull/135434 Approved by: https://github.com/malfet --- .../src/ATen/native/mps/operations/LossOps.mm | 18 ++++++- test/test_nn.py | 52 ++++++++++--------- 2 files changed, 44 insertions(+), 26 deletions(-) diff --git a/aten/src/ATen/native/mps/operations/LossOps.mm b/aten/src/ATen/native/mps/operations/LossOps.mm index 271b4b2c492b34..c45054f8bc1638 100644 --- a/aten/src/ATen/native/mps/operations/LossOps.mm +++ b/aten/src/ATen/native/mps/operations/LossOps.mm @@ -437,6 +437,20 @@ static void nllnd_loss_forward_impl(Tensor& output, if (output.numel() == 0) return; + // https://github.com/pytorch/pytorch/blob/042f2f7746a064f1527d95d1f1d712b4f0b34186/aten/src/ATen/native/cuda/Loss.cu#L335-L346 + if (target_arg.numel() == 0) { + // Here target (and input) have zero elements + // Mean reduction on empty tensors produces NaN. See the discussion in + // https://github.com/pytorch/pytorch/pull/64572#issuecomment-926504162 + if (reduction == Reduction::Mean) { + output.fill_(std::numeric_limits::quiet_NaN()); + } else { + output.zero_(); + } + total_weight.zero_(); + return; + } + struct CachedGraph : public MPSCachedGraph { CachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {} MPSGraphTensor* inputTensor_ = nil; @@ -537,7 +551,9 @@ static void nllnd_loss_forward_impl(Tensor& output, mpsGraphBatchSizeTensor = [mpsGraph reductionSumWithTensor:mpsSelectOneTensor axes:nil name:@"batchSizeReductionTensor"]; - mpsGraphReducedTensor = divisionNoNaN(mpsGraph, mpsGraphReducedTensor, mpsGraphBatchSizeTensor); + mpsGraphReducedTensor = [mpsGraph divisionWithPrimaryTensor:mpsGraphReducedTensor + secondaryTensor:mpsGraphBatchSizeTensor + name:@"divisionTensor"]; } } diff --git a/test/test_nn.py b/test/test_nn.py index 8bb14f3188a2bd..533f889975b99b 100644 --- a/test/test_nn.py +++ b/test/test_nn.py @@ -11549,8 +11549,8 @@ def test_cross_entropy_64bit(self, device, reduction): print(logits.numel(), labels.numel(), loss.numel()) self.assertTrue(torch.allclose(loss_cpu, loss.cpu(), rtol=1e-4, atol=1e-4)) - def _nll_loss_helper(self, input_size, reduction, expected, device): - input = torch.rand(input_size, requires_grad=True, device=device) + def _nll_loss_helper(self, input_size, reduction, expected, device, dtype): + input = torch.rand(input_size, requires_grad=True, device=device, dtype=dtype) num_channels = input_size[1] target_size = (input_size[0], ) + tuple(input_size[2:]) target = torch.randint(num_channels, target_size, device=device) @@ -11561,32 +11561,35 @@ def _nll_loss_helper(self, input_size, reduction, expected, device): output.sum().backward() self.assertEqual(input.grad.size(), input.size()) - def test_nll_loss_empty_tensor_reduction_none(self, device): - self._nll_loss_helper([0, 3], "none", torch.empty([0], device=device), device) - self._nll_loss_helper([0, 3, 5, 7], "none", torch.empty([0, 5, 7], device=device), device) - self._nll_loss_helper([2, 3, 0, 7], "none", torch.empty([2, 0, 7], device=device), device) - self._nll_loss_helper([2, 3, 5, 0], "none", torch.empty([2, 5, 0], device=device), device) - self._nll_loss_helper([2, 3, 5, 7, 0], "none", torch.empty([2, 5, 7, 0], device=device), device) + @dtypesIfMPS(torch.half, torch.float) + @dtypes(torch.float) + def test_nll_loss_empty_tensor_reduction_none(self, device, dtype): + self._nll_loss_helper([0, 3], "none", torch.empty([0], device=device), device, dtype) + self._nll_loss_helper([0, 3, 5, 7], "none", torch.empty([0, 5, 7], device=device), device, dtype) + self._nll_loss_helper([2, 3, 0, 7], "none", torch.empty([2, 0, 7], device=device), device, dtype) + self._nll_loss_helper([2, 3, 5, 0], "none", torch.empty([2, 5, 0], device=device), device, dtype) + self._nll_loss_helper([2, 3, 5, 7, 0], "none", torch.empty([2, 5, 7, 0], device=device), device, dtype) - @expectedFailureMPS # RuntimeError: [srcBuf length] > 0 INTERNAL ASSERT FAILED https://github.com/pytorch/pytorch/issues/134431 - def test_nll_loss_empty_tensor_reduction_mean(self, device): + @dtypesIfMPS(torch.half, torch.float) + @dtypes(torch.float) + def test_nll_loss_empty_tensor_reduction_mean(self, device, dtype): nan = torch.tensor(float('nan'), device=device) - self._nll_loss_helper([0, 3], "mean", nan, device) - self._nll_loss_helper([0, 3, 5, 7], "mean", nan, device) - self._nll_loss_helper([2, 3, 0, 7], "mean", nan, device) - self._nll_loss_helper([2, 3, 5, 0], "mean", nan, device) - self._nll_loss_helper([2, 3, 5, 7, 0], "mean", nan, device) - - @expectedFailureMPS # RuntimeError: [srcBuf length] > 0 INTERNAL ASSERT FAILED https://github.com/pytorch/pytorch/issues/134431 - def test_nll_loss_empty_tensor_reduction_sum(self, device): + self._nll_loss_helper([0, 3], "mean", nan, device, dtype) + self._nll_loss_helper([0, 3, 5, 7], "mean", nan, device, dtype) + self._nll_loss_helper([2, 3, 0, 7], "mean", nan, device, dtype) + self._nll_loss_helper([2, 3, 5, 0], "mean", nan, device, dtype) + self._nll_loss_helper([2, 3, 5, 7, 0], "mean", nan, device, dtype) + + @dtypesIfMPS(torch.half, torch.float) + @dtypes(torch.float) + def test_nll_loss_empty_tensor_reduction_sum(self, device, dtype): zero = torch.tensor(0, device=device) - self._nll_loss_helper([0, 3], "sum", zero, device) - self._nll_loss_helper([0, 3, 5, 7], "sum", zero, device) - self._nll_loss_helper([2, 3, 0, 7], "sum", zero, device) - self._nll_loss_helper([2, 3, 5, 0], "sum", zero, device) - self._nll_loss_helper([2, 3, 5, 7, 0], "sum", zero, device) + self._nll_loss_helper([0, 3], "sum", zero, device, dtype) + self._nll_loss_helper([0, 3, 5, 7], "sum", zero, device, dtype) + self._nll_loss_helper([2, 3, 0, 7], "sum", zero, device, dtype) + self._nll_loss_helper([2, 3, 5, 0], "sum", zero, device, dtype) + self._nll_loss_helper([2, 3, 5, 7, 0], "sum", zero, device, dtype) - @expectedFailureMPS # AssertionError: Expected nan but got 0.0. def test_nll_loss_total_weight_is_zero(self, device): def helper(input_size): @@ -11603,7 +11606,6 @@ def helper(input_size): helper([2, 3, 5, 7]) helper([2, 3, 5, 7, 9]) - @expectedFailureMPS # AssertionError: Expected nan but got 0.0. def test_nll_loss_all_ignored(self, device): def helper(input_size): From 5132379ba5fd288ebb47fa2391a044973b6890fc Mon Sep 17 00:00:00 2001 From: Boyuan Feng Date: Tue, 10 Sep 2024 09:30:00 +0000 Subject: [PATCH 0666/1018] [FlexAttention] Add broadcast support for kv batch dimension (#135505) This PR adds broadcast support for KV batch dimension. ## Details Consider Q of shape `[Bq, Hq, Q_LEN, D]`, and K, V of shape `[Bkv, Hkv, KV_LEN, D]`. Prior to this diff, we require `Bq == Bkv`. However, for some use cases, we may have Bkv < Bq. For example, in paged attention, we provide K, V of shape `[1, Hkv, MAX_LEN, D]`, while still providing Q of shape `[Bq, Hq, Q_LEN, D]`. Here, MAX_LEN is the maximal number of tokens supported by paged attention. This PR relax this requirement to be `Bq == Bkv or (Bq > 1 and Bkv == 0)`. This support covers both flex decoding, flex attention forward and backward. ## Benchmark GPU: H100 We see negligible (1%~2%) performance change from this PR when `Bq == Bkv`. ``` python benchmarks/transformer/score_mod.py --calculate-bwd ``` ### Perf before this PR **FWD** | Type | Speedup | score_mod | mask_mod | dtype | shape(B,Hq,M,Hkv,N,D) | |---------|-----------|---------------|------------|----------------|------------------------------| | Average | 0.743 | | | | | | Max | 0.955 | head_bias | None | torch.bfloat16 | (2, 16, 1024, 2, 1024, 64) | | Min | 0.548 | relative_bias | None | torch.bfloat16 | (16, 16, 1024, 2, 1024, 128) | **BWD** | Type | Speedup | score_mod | mask_mod | dtype | shape(B,Hq,M,Hkv,N,D) | |---------|-----------|-------------|------------|----------------|-----------------------------| | Average | 0.834 | | | | | | Max | 1.261 | head_bias | None | torch.bfloat16 | (8, 16, 512, 16, 512, 64) | | Min | 0.456 | None | causal | torch.bfloat16 | (2, 16, 1024, 2, 1024, 128) |
Full performance sweep | score_mod | mask_mod | dtype | shape(B,Hq,M,Hkv,N,D) | fwd_eager_time | fwd_compiled_time | bwd_eager_time | bwd_compiled_time | fwd_speedup | bwd_speedup | |---------------|------------|----------------|-------------------------------|------------------|---------------------|------------------|---------------------|---------------|---------------| | None | None | torch.bfloat16 | (2, 16, 512, 16, 512, 64) | 15.264 | 17.184 | 107.040 | 140.800 | 0.888 | 0.760 | | None | causal | torch.bfloat16 | (2, 16, 512, 16, 512, 64) | 15.840 | 19.744 | 112.576 | 140.064 | 0.802 | 0.804 | | relative_bias | None | torch.bfloat16 | (2, 16, 512, 16, 512, 64) | 15.232 | 17.344 | 87.744 | 142.496 | 0.878 | 0.616 | | head_bias | None | torch.bfloat16 | (2, 16, 512, 16, 512, 64) | 15.264 | 17.184 | 108.192 | 143.328 | 0.888 | 0.755 | | None | None | torch.bfloat16 | (2, 16, 512, 16, 512, 128) | 19.904 | 22.400 | 106.432 | 136.512 | 0.889 | 0.780 | | None | causal | torch.bfloat16 | (2, 16, 512, 16, 512, 128) | 19.424 | 26.752 | 91.712 | 106.688 | 0.726 | 0.860 | | relative_bias | None | torch.bfloat16 | (2, 16, 512, 16, 512, 128) | 19.808 | 22.432 | 89.024 | 101.920 | 0.883 | 0.873 | | head_bias | None | torch.bfloat16 | (2, 16, 512, 16, 512, 128) | 19.840 | 22.272 | 88.896 | 102.592 | 0.891 | 0.867 | | None | None | torch.bfloat16 | (2, 16, 1024, 16, 1024, 64) | 30.240 | 32.416 | 116.768 | 112.256 | 0.933 | 1.040 | | None | causal | torch.bfloat16 | (2, 16, 1024, 16, 1024, 64) | 29.536 | 37.024 | 113.664 | 102.688 | 0.798 | 1.107 | | relative_bias | None | torch.bfloat16 | (2, 16, 1024, 16, 1024, 64) | 30.656 | 32.800 | 116.992 | 127.008 | 0.935 | 0.921 | | head_bias | None | torch.bfloat16 | (2, 16, 1024, 16, 1024, 64) | 30.592 | 32.480 | 116.928 | 112.160 | 0.942 | 1.043 | | None | None | torch.bfloat16 | (2, 16, 1024, 16, 1024, 128) | 40.448 | 61.920 | 198.656 | 204.512 | 0.653 | 0.971 | | None | causal | torch.bfloat16 | (2, 16, 1024, 16, 1024, 128) | 37.760 | 62.528 | 189.536 | 170.624 | 0.604 | 1.111 | | relative_bias | None | torch.bfloat16 | (2, 16, 1024, 16, 1024, 128) | 40.896 | 62.368 | 198.304 | 205.824 | 0.656 | 0.963 | | head_bias | None | torch.bfloat16 | (2, 16, 1024, 16, 1024, 128) | 40.448 | 61.952 | 198.432 | 203.648 | 0.653 | 0.974 | | None | None | torch.bfloat16 | (2, 16, 4096, 16, 4096, 64) | 318.528 | 355.904 | 947.232 | 1162.496 | 0.895 | 0.815 | | None | causal | torch.bfloat16 | (2, 16, 4096, 16, 4096, 64) | 199.776 | 252.128 | 677.792 | 813.184 | 0.792 | 0.834 | | relative_bias | None | torch.bfloat16 | (2, 16, 4096, 16, 4096, 64) | 316.512 | 363.328 | 947.712 | 1361.984 | 0.871 | 0.696 | | head_bias | None | torch.bfloat16 | (2, 16, 4096, 16, 4096, 64) | 317.984 | 356.864 | 947.264 | 1165.024 | 0.891 | 0.813 | | None | None | torch.bfloat16 | (2, 16, 4096, 16, 4096, 128) | 446.656 | 734.656 | 1664.288 | 2172.960 | 0.608 | 0.766 | | None | causal | torch.bfloat16 | (2, 16, 4096, 16, 4096, 128) | 278.688 | 467.648 | 1182.624 | 1339.296 | 0.596 | 0.883 | | relative_bias | None | torch.bfloat16 | (2, 16, 4096, 16, 4096, 128) | 447.872 | 744.096 | 1662.944 | 2196.544 | 0.602 | 0.757 | | head_bias | None | torch.bfloat16 | (2, 16, 4096, 16, 4096, 128) | 448.128 | 732.928 | 1663.072 | 2156.800 | 0.611 | 0.771 | | None | None | torch.bfloat16 | (2, 16, 512, 2, 512, 64) | 15.648 | 16.640 | 107.520 | 143.008 | 0.940 | 0.752 | | None | causal | torch.bfloat16 | (2, 16, 512, 2, 512, 64) | 15.776 | 18.240 | 129.056 | 141.920 | 0.865 | 0.909 | | relative_bias | None | torch.bfloat16 | (2, 16, 512, 2, 512, 64) | 15.168 | 16.640 | 103.616 | 139.648 | 0.912 | 0.742 | | head_bias | None | torch.bfloat16 | (2, 16, 512, 2, 512, 64) | 15.616 | 16.640 | 128.608 | 164.448 | 0.938 | 0.782 | | None | None | torch.bfloat16 | (2, 16, 512, 2, 512, 128) | 19.776 | 21.952 | 125.344 | 170.304 | 0.901 | 0.736 | | None | causal | torch.bfloat16 | (2, 16, 512, 2, 512, 128) | 19.776 | 23.712 | 104.288 | 196.896 | 0.834 | 0.530 | | relative_bias | None | torch.bfloat16 | (2, 16, 512, 2, 512, 128) | 19.072 | 21.952 | 102.080 | 177.056 | 0.869 | 0.577 | | head_bias | None | torch.bfloat16 | (2, 16, 512, 2, 512, 128) | 19.648 | 21.920 | 109.920 | 170.848 | 0.896 | 0.643 | | None | None | torch.bfloat16 | (2, 16, 1024, 2, 1024, 64) | 30.464 | 31.936 | 127.808 | 228.832 | 0.954 | 0.559 | | None | causal | torch.bfloat16 | (2, 16, 1024, 2, 1024, 64) | 29.472 | 33.856 | 113.152 | 215.072 | 0.871 | 0.526 | | relative_bias | None | torch.bfloat16 | (2, 16, 1024, 2, 1024, 64) | 30.496 | 32.160 | 116.576 | 231.744 | 0.948 | 0.503 | | head_bias | None | torch.bfloat16 | (2, 16, 1024, 2, 1024, 64) | 30.464 | 31.904 | 116.320 | 229.824 | 0.955 | 0.506 | | None | None | torch.bfloat16 | (2, 16, 1024, 2, 1024, 128) | 40.480 | 61.440 | 176.448 | 345.312 | 0.659 | 0.511 | | None | causal | torch.bfloat16 | (2, 16, 1024, 2, 1024, 128) | 38.304 | 59.424 | 169.312 | 371.360 | 0.645 | 0.456 | | relative_bias | None | torch.bfloat16 | (2, 16, 1024, 2, 1024, 128) | 40.960 | 61.760 | 176.512 | 358.912 | 0.663 | 0.492 | | head_bias | None | torch.bfloat16 | (2, 16, 1024, 2, 1024, 128) | 40.352 | 61.696 | 176.512 | 344.928 | 0.654 | 0.512 | | None | None | torch.bfloat16 | (2, 16, 4096, 2, 4096, 64) | 316.224 | 357.728 | 905.728 | 1668.448 | 0.884 | 0.543 | | None | causal | torch.bfloat16 | (2, 16, 4096, 2, 4096, 64) | 199.904 | 248.416 | 636.544 | 1109.088 | 0.805 | 0.574 | | relative_bias | None | torch.bfloat16 | (2, 16, 4096, 2, 4096, 64) | 314.880 | 363.616 | 906.304 | 1658.176 | 0.866 | 0.547 | | head_bias | None | torch.bfloat16 | (2, 16, 4096, 2, 4096, 64) | 316.160 | 354.368 | 906.080 | 1649.024 | 0.892 | 0.549 | | None | None | torch.bfloat16 | (2, 16, 4096, 2, 4096, 128) | 446.912 | 739.840 | 1555.808 | 2521.952 | 0.604 | 0.617 | | None | causal | torch.bfloat16 | (2, 16, 4096, 2, 4096, 128) | 279.776 | 463.904 | 1068.928 | 1849.888 | 0.603 | 0.578 | | relative_bias | None | torch.bfloat16 | (2, 16, 4096, 2, 4096, 128) | 446.080 | 748.960 | 1553.504 | 2629.888 | 0.596 | 0.591 | | head_bias | None | torch.bfloat16 | (2, 16, 4096, 2, 4096, 128) | 446.208 | 740.608 | 1558.880 | 2524.960 | 0.602 | 0.617 | | None | None | torch.bfloat16 | (8, 16, 512, 16, 512, 64) | 33.568 | 41.280 | 170.016 | 147.584 | 0.813 | 1.152 | | None | causal | torch.bfloat16 | (8, 16, 512, 16, 512, 64) | 30.688 | 43.040 | 159.552 | 146.720 | 0.713 | 1.087 | | relative_bias | None | torch.bfloat16 | (8, 16, 512, 16, 512, 64) | 34.112 | 41.504 | 170.112 | 152.672 | 0.822 | 1.114 | | head_bias | None | torch.bfloat16 | (8, 16, 512, 16, 512, 64) | 34.240 | 41.152 | 170.272 | 134.976 | 0.832 | 1.261 | | None | None | torch.bfloat16 | (8, 16, 512, 16, 512, 128) | 48.672 | 76.416 | 295.296 | 263.648 | 0.637 | 1.120 | | None | causal | torch.bfloat16 | (8, 16, 512, 16, 512, 128) | 45.088 | 72.576 | 281.920 | 237.664 | 0.621 | 1.186 | | relative_bias | None | torch.bfloat16 | (8, 16, 512, 16, 512, 128) | 48.032 | 76.672 | 295.520 | 265.248 | 0.626 | 1.114 | | head_bias | None | torch.bfloat16 | (8, 16, 512, 16, 512, 128) | 48.096 | 76.096 | 295.456 | 262.112 | 0.632 | 1.127 | | None | None | torch.bfloat16 | (8, 16, 1024, 16, 1024, 64) | 93.920 | 111.232 | 401.568 | 382.944 | 0.844 | 1.049 | | None | causal | torch.bfloat16 | (8, 16, 1024, 16, 1024, 64) | 68.192 | 95.232 | 338.752 | 326.816 | 0.716 | 1.037 | | relative_bias | None | torch.bfloat16 | (8, 16, 1024, 16, 1024, 64) | 93.984 | 111.840 | 401.856 | 444.224 | 0.840 | 0.905 | | head_bias | None | torch.bfloat16 | (8, 16, 1024, 16, 1024, 64) | 94.176 | 110.496 | 401.600 | 383.136 | 0.852 | 1.048 | | None | None | torch.bfloat16 | (8, 16, 1024, 16, 1024, 128) | 131.488 | 227.040 | 727.424 | 739.712 | 0.579 | 0.983 | | None | causal | torch.bfloat16 | (8, 16, 1024, 16, 1024, 128) | 95.616 | 169.760 | 616.864 | 574.112 | 0.563 | 1.074 | | relative_bias | None | torch.bfloat16 | (8, 16, 1024, 16, 1024, 128) | 131.680 | 228.672 | 727.616 | 746.048 | 0.576 | 0.975 | | head_bias | None | torch.bfloat16 | (8, 16, 1024, 16, 1024, 128) | 131.104 | 225.696 | 727.904 | 735.392 | 0.581 | 0.990 | | None | None | torch.bfloat16 | (8, 16, 4096, 16, 4096, 64) | 1227.296 | 1386.656 | 3720.192 | 4539.904 | 0.885 | 0.819 | | None | causal | torch.bfloat16 | (8, 16, 4096, 16, 4096, 64) | 691.360 | 831.712 | 2515.872 | 3067.808 | 0.831 | 0.820 | | relative_bias | None | torch.bfloat16 | (8, 16, 4096, 16, 4096, 64) | 1228.192 | 1403.136 | 3715.520 | 5309.280 | 0.875 | 0.700 | | head_bias | None | torch.bfloat16 | (8, 16, 4096, 16, 4096, 64) | 1229.024 | 1384.992 | 3715.904 | 4550.368 | 0.887 | 0.817 | | None | None | torch.bfloat16 | (8, 16, 4096, 16, 4096, 128) | 1784.832 | 2865.888 | 6539.840 | 8460.224 | 0.623 | 0.773 | | None | causal | torch.bfloat16 | (8, 16, 4096, 16, 4096, 128) | 1017.408 | 1660.480 | 4369.824 | 5056.992 | 0.613 | 0.864 | | relative_bias | None | torch.bfloat16 | (8, 16, 4096, 16, 4096, 128) | 1792.448 | 2904.864 | 6546.080 | 8537.024 | 0.617 | 0.767 | | head_bias | None | torch.bfloat16 | (8, 16, 4096, 16, 4096, 128) | 1795.552 | 2856.864 | 6544.672 | 8400.160 | 0.629 | 0.779 | | None | None | torch.bfloat16 | (8, 16, 512, 2, 512, 64) | 34.240 | 38.880 | 148.832 | 179.936 | 0.881 | 0.827 | | None | causal | torch.bfloat16 | (8, 16, 512, 2, 512, 64) | 31.168 | 38.080 | 138.528 | 167.552 | 0.818 | 0.827 | | relative_bias | None | torch.bfloat16 | (8, 16, 512, 2, 512, 64) | 34.240 | 39.168 | 148.512 | 181.248 | 0.874 | 0.819 | | head_bias | None | torch.bfloat16 | (8, 16, 512, 2, 512, 64) | 34.240 | 38.784 | 148.864 | 180.224 | 0.883 | 0.826 | | None | None | torch.bfloat16 | (8, 16, 512, 2, 512, 128) | 48.832 | 76.352 | 253.632 | 295.968 | 0.640 | 0.857 | | None | causal | torch.bfloat16 | (8, 16, 512, 2, 512, 128) | 45.760 | 65.792 | 239.040 | 290.752 | 0.696 | 0.822 | | relative_bias | None | torch.bfloat16 | (8, 16, 512, 2, 512, 128) | 48.768 | 76.576 | 253.312 | 304.032 | 0.637 | 0.833 | | head_bias | None | torch.bfloat16 | (8, 16, 512, 2, 512, 128) | 48.768 | 76.192 | 253.600 | 296.096 | 0.640 | 0.856 | | None | None | torch.bfloat16 | (8, 16, 1024, 2, 1024, 64) | 93.728 | 109.728 | 357.696 | 498.912 | 0.854 | 0.717 | | None | causal | torch.bfloat16 | (8, 16, 1024, 2, 1024, 64) | 68.704 | 92.288 | 295.616 | 386.240 | 0.744 | 0.765 | | relative_bias | None | torch.bfloat16 | (8, 16, 1024, 2, 1024, 64) | 93.632 | 111.392 | 357.408 | 512.448 | 0.841 | 0.697 | | head_bias | None | torch.bfloat16 | (8, 16, 1024, 2, 1024, 64) | 93.280 | 109.952 | 357.696 | 501.440 | 0.848 | 0.713 | | None | None | torch.bfloat16 | (8, 16, 1024, 2, 1024, 128) | 131.392 | 230.496 | 612.224 | 807.552 | 0.570 | 0.758 | | None | causal | torch.bfloat16 | (8, 16, 1024, 2, 1024, 128) | 96.512 | 165.184 | 502.624 | 672.384 | 0.584 | 0.748 | | relative_bias | None | torch.bfloat16 | (8, 16, 1024, 2, 1024, 128) | 131.360 | 232.608 | 612.064 | 832.320 | 0.565 | 0.735 | | head_bias | None | torch.bfloat16 | (8, 16, 1024, 2, 1024, 128) | 131.008 | 230.528 | 612.640 | 804.320 | 0.568 | 0.762 | | None | None | torch.bfloat16 | (8, 16, 4096, 2, 4096, 64) | 1227.968 | 1377.408 | 3477.920 | 5324.384 | 0.892 | 0.653 | | None | causal | torch.bfloat16 | (8, 16, 4096, 2, 4096, 64) | 695.264 | 824.544 | 2268.224 | 3210.208 | 0.843 | 0.707 | | relative_bias | None | torch.bfloat16 | (8, 16, 4096, 2, 4096, 64) | 1228.640 | 1404.576 | 3476.832 | 5463.456 | 0.875 | 0.636 | | head_bias | None | torch.bfloat16 | (8, 16, 4096, 2, 4096, 64) | 1228.416 | 1378.752 | 3478.048 | 5367.712 | 0.891 | 0.648 | | None | None | torch.bfloat16 | (8, 16, 4096, 2, 4096, 128) | 1788.736 | 2867.712 | 6039.520 | 8616.256 | 0.624 | 0.701 | | None | causal | torch.bfloat16 | (8, 16, 4096, 2, 4096, 128) | 1021.952 | 1653.824 | 3866.208 | 5306.848 | 0.618 | 0.729 | | relative_bias | None | torch.bfloat16 | (8, 16, 4096, 2, 4096, 128) | 1786.752 | 2896.352 | 6044.128 | 8871.360 | 0.617 | 0.681 | | head_bias | None | torch.bfloat16 | (8, 16, 4096, 2, 4096, 128) | 1786.080 | 2868.672 | 6040.160 | 8550.144 | 0.623 | 0.706 | | None | None | torch.bfloat16 | (16, 16, 512, 16, 512, 64) | 57.504 | 71.552 | 312.768 | 255.040 | 0.804 | 1.226 | | None | causal | torch.bfloat16 | (16, 16, 512, 16, 512, 64) | 49.472 | 71.104 | 285.696 | 243.520 | 0.696 | 1.173 | | relative_bias | None | torch.bfloat16 | (16, 16, 512, 16, 512, 64) | 58.112 | 72.896 | 312.768 | 288.256 | 0.797 | 1.085 | | head_bias | None | torch.bfloat16 | (16, 16, 512, 16, 512, 64) | 57.952 | 71.680 | 312.768 | 255.552 | 0.808 | 1.224 | | None | None | torch.bfloat16 | (16, 16, 512, 16, 512, 128) | 82.336 | 144.256 | 580.128 | 500.160 | 0.571 | 1.160 | | None | causal | torch.bfloat16 | (16, 16, 512, 16, 512, 128) | 76.160 | 123.712 | 552.544 | 447.648 | 0.616 | 1.234 | | relative_bias | None | torch.bfloat16 | (16, 16, 512, 16, 512, 128) | 82.400 | 145.184 | 580.032 | 504.032 | 0.568 | 1.151 | | head_bias | None | torch.bfloat16 | (16, 16, 512, 16, 512, 128) | 82.368 | 143.904 | 580.192 | 499.936 | 0.572 | 1.161 | | None | None | torch.bfloat16 | (16, 16, 1024, 16, 1024, 64) | 177.216 | 209.568 | 787.872 | 747.712 | 0.846 | 1.054 | | None | causal | torch.bfloat16 | (16, 16, 1024, 16, 1024, 64) | 121.984 | 168.256 | 651.968 | 628.256 | 0.725 | 1.038 | | relative_bias | None | torch.bfloat16 | (16, 16, 1024, 16, 1024, 64) | 177.088 | 211.488 | 788.320 | 864.352 | 0.837 | 0.912 | | head_bias | None | torch.bfloat16 | (16, 16, 1024, 16, 1024, 64) | 177.440 | 208.576 | 787.424 | 749.120 | 0.851 | 1.051 | | None | None | torch.bfloat16 | (16, 16, 1024, 16, 1024, 128) | 249.472 | 441.376 | 1405.440 | 1431.648 | 0.565 | 0.982 | | None | causal | torch.bfloat16 | (16, 16, 1024, 16, 1024, 128) | 172.960 | 312.064 | 1172.064 | 1096.448 | 0.554 | 1.069 | | relative_bias | None | torch.bfloat16 | (16, 16, 1024, 16, 1024, 128) | 249.632 | 446.336 | 1405.408 | 1448.480 | 0.559 | 0.970 | | head_bias | None | torch.bfloat16 | (16, 16, 1024, 16, 1024, 128) | 250.944 | 440.128 | 1406.624 | 1421.952 | 0.570 | 0.989 | | None | None | torch.bfloat16 | (16, 16, 4096, 16, 4096, 64) | 2418.720 | 2747.936 | 7330.432 | 9023.712 | 0.880 | 0.812 | | None | causal | torch.bfloat16 | (16, 16, 4096, 16, 4096, 64) | 1353.696 | 1608.480 | 4941.696 | 6078.752 | 0.842 | 0.813 | | relative_bias | None | torch.bfloat16 | (16, 16, 4096, 16, 4096, 64) | 2427.456 | 2746.816 | 7329.792 | 10539.968 | 0.884 | 0.695 | | head_bias | None | torch.bfloat16 | (16, 16, 4096, 16, 4096, 64) | 2426.688 | 2763.168 | 7336.256 | 9057.536 | 0.878 | 0.810 | | None | None | torch.bfloat16 | (16, 16, 4096, 16, 4096, 128) | 3554.240 | 5634.400 | 12919.872 | 16843.489 | 0.631 | 0.767 | | None | causal | torch.bfloat16 | (16, 16, 4096, 16, 4096, 128) | 2003.648 | 3250.784 | 8610.144 | 10015.424 | 0.616 | 0.860 | | relative_bias | None | torch.bfloat16 | (16, 16, 4096, 16, 4096, 128) | 3582.080 | 5710.944 | 12923.328 | 17011.871 | 0.627 | 0.760 | | head_bias | None | torch.bfloat16 | (16, 16, 4096, 16, 4096, 128) | 3581.920 | 5618.144 | 12934.528 | 16745.888 | 0.638 | 0.772 | | None | None | torch.bfloat16 | (16, 16, 512, 2, 512, 64) | 57.120 | 71.232 | 269.760 | 295.680 | 0.802 | 0.912 | | None | causal | torch.bfloat16 | (16, 16, 512, 2, 512, 64) | 49.408 | 65.312 | 242.304 | 253.952 | 0.756 | 0.954 | | relative_bias | None | torch.bfloat16 | (16, 16, 512, 2, 512, 64) | 57.504 | 72.544 | 269.632 | 298.976 | 0.793 | 0.902 | | head_bias | None | torch.bfloat16 | (16, 16, 512, 2, 512, 64) | 57.760 | 71.040 | 269.600 | 296.640 | 0.813 | 0.909 | | None | None | torch.bfloat16 | (16, 16, 512, 2, 512, 128) | 82.336 | 147.168 | 466.080 | 487.456 | 0.559 | 0.956 | | None | causal | torch.bfloat16 | (16, 16, 512, 2, 512, 128) | 76.704 | 115.040 | 435.392 | 453.248 | 0.667 | 0.961 | | relative_bias | None | torch.bfloat16 | (16, 16, 512, 2, 512, 128) | 81.856 | 147.424 | 465.920 | 499.552 | 0.555 | 0.933 | | head_bias | None | torch.bfloat16 | (16, 16, 512, 2, 512, 128) | 81.760 | 146.656 | 466.176 | 485.984 | 0.557 | 0.959 | | None | None | torch.bfloat16 | (16, 16, 1024, 2, 1024, 64) | 176.608 | 206.976 | 678.080 | 866.976 | 0.853 | 0.782 | | None | causal | torch.bfloat16 | (16, 16, 1024, 2, 1024, 64) | 121.664 | 164.768 | 538.240 | 636.160 | 0.738 | 0.846 | | relative_bias | None | torch.bfloat16 | (16, 16, 1024, 2, 1024, 64) | 176.608 | 209.664 | 677.696 | 883.424 | 0.842 | 0.767 | | head_bias | None | torch.bfloat16 | (16, 16, 1024, 2, 1024, 64) | 177.440 | 207.840 | 677.248 | 868.288 | 0.854 | 0.780 | | None | None | torch.bfloat16 | (16, 16, 1024, 2, 1024, 128) | 250.272 | 449.536 | 1163.424 | 1420.832 | 0.557 | 0.819 | | None | causal | torch.bfloat16 | (16, 16, 1024, 2, 1024, 128) | 173.472 | 305.376 | 929.408 | 1104.544 | 0.568 | 0.841 | | relative_bias | None | torch.bfloat16 | (16, 16, 1024, 2, 1024, 128) | 249.376 | 454.976 | 1163.648 | 1455.296 | 0.548 | 0.800 | | head_bias | None | torch.bfloat16 | (16, 16, 1024, 2, 1024, 128) | 250.368 | 450.144 | 1163.520 | 1409.984 | 0.556 | 0.825 | | None | None | torch.bfloat16 | (16, 16, 4096, 2, 4096, 64) | 2416.576 | 2726.208 | 6835.520 | 10442.784 | 0.886 | 0.655 | | None | causal | torch.bfloat16 | (16, 16, 4096, 2, 4096, 64) | 1357.440 | 1590.752 | 4433.664 | 5975.296 | 0.853 | 0.742 | | relative_bias | None | torch.bfloat16 | (16, 16, 4096, 2, 4096, 64) | 2427.360 | 2747.040 | 6853.056 | 10670.784 | 0.884 | 0.642 | | head_bias | None | torch.bfloat16 | (16, 16, 4096, 2, 4096, 64) | 2441.120 | 2718.944 | 6836.640 | 10433.792 | 0.898 | 0.655 | | None | None | torch.bfloat16 | (16, 16, 4096, 2, 4096, 128) | 3555.392 | 5620.960 | 11944.000 | 16504.801 | 0.633 | 0.724 | | None | causal | torch.bfloat16 | (16, 16, 4096, 2, 4096, 128) | 2010.848 | 3241.152 | 7636.064 | 9870.464 | 0.620 | 0.774 | | relative_bias | None | torch.bfloat16 | (16, 16, 4096, 2, 4096, 128) | 3557.440 | 5688.352 | 11935.744 | 17090.496 | 0.625 | 0.698 | | head_bias | None | torch.bfloat16 | (16, 16, 4096, 2, 4096, 128) | 3562.720 | 5630.432 | 11939.168 | 16392.033 | 0.633 | 0.728 |
### Perf after this PR **FWD** | Type | Speedup | score_mod | mask_mod | dtype | shape(B,Hq,M,Hkv,N,D) | |---------|-----------|---------------|------------|----------------|----------------------------| | Average | 0.776 | | | | | | Max | 1.006 | None | None | torch.bfloat16 | (2, 16, 1024, 2, 1024, 64) | | Min | 0.566 | relative_bias | None | torch.bfloat16 | (16, 16, 512, 2, 512, 128) | **BWD** | Type | Speedup | score_mod | mask_mod | dtype | shape(B,Hq,M,Hkv,N,D) | |---------|-----------|-------------|------------|----------------|-----------------------------| | Average | 0.817 | | | | | | Max | 1.150 | None | causal | torch.bfloat16 | (16, 16, 512, 16, 512, 128) | | Min | 0.454 | None | causal | torch.bfloat16 | (2, 16, 1024, 2, 1024, 128) |
Full performance sweep | score_mod | mask_mod | dtype | shape(B,Hq,M,Hkv,N,D) | fwd_eager_time | fwd_compiled_time | bwd_eager_time | bwd_compiled_time | fwd_speedup | bwd_speedup | |---------------|------------|----------------|-------------------------------|------------------|---------------------|------------------|---------------------|---------------|---------------| | None | None | torch.bfloat16 | (2, 16, 512, 16, 512, 64) | 15.680 | 17.056 | 64.544 | 73.376 | 0.919 | 0.880 | | None | causal | torch.bfloat16 | (2, 16, 512, 16, 512, 64) | 15.712 | 19.872 | 65.408 | 72.864 | 0.791 | 0.898 | | relative_bias | None | torch.bfloat16 | (2, 16, 512, 16, 512, 64) | 16.160 | 17.280 | 64.896 | 73.888 | 0.935 | 0.878 | | head_bias | None | torch.bfloat16 | (2, 16, 512, 16, 512, 64) | 16.192 | 17.120 | 64.896 | 75.424 | 0.946 | 0.860 | | None | None | torch.bfloat16 | (2, 16, 512, 16, 512, 128) | 19.648 | 22.496 | 89.184 | 82.592 | 0.873 | 1.080 | | None | causal | torch.bfloat16 | (2, 16, 512, 16, 512, 128) | 20.320 | 26.816 | 91.264 | 82.880 | 0.758 | 1.101 | | relative_bias | None | torch.bfloat16 | (2, 16, 512, 16, 512, 128) | 20.096 | 22.528 | 89.184 | 83.776 | 0.892 | 1.065 | | head_bias | None | torch.bfloat16 | (2, 16, 512, 16, 512, 128) | 19.680 | 22.432 | 89.184 | 120.096 | 0.877 | 0.743 | | None | None | torch.bfloat16 | (2, 16, 1024, 16, 1024, 64) | 32.384 | 32.512 | 119.232 | 128.960 | 0.996 | 0.925 | | None | causal | torch.bfloat16 | (2, 16, 1024, 16, 1024, 64) | 30.176 | 37.248 | 113.664 | 119.520 | 0.810 | 0.951 | | relative_bias | None | torch.bfloat16 | (2, 16, 1024, 16, 1024, 64) | 32.512 | 32.928 | 119.264 | 131.456 | 0.987 | 0.907 | | head_bias | None | torch.bfloat16 | (2, 16, 1024, 16, 1024, 64) | 32.448 | 32.704 | 119.200 | 128.352 | 0.992 | 0.929 | | None | None | torch.bfloat16 | (2, 16, 1024, 16, 1024, 128) | 41.952 | 62.176 | 199.040 | 214.304 | 0.675 | 0.929 | | None | causal | torch.bfloat16 | (2, 16, 1024, 16, 1024, 128) | 39.744 | 62.880 | 189.504 | 179.968 | 0.632 | 1.053 | | relative_bias | None | torch.bfloat16 | (2, 16, 1024, 16, 1024, 128) | 41.472 | 62.784 | 199.136 | 217.664 | 0.661 | 0.915 | | head_bias | None | torch.bfloat16 | (2, 16, 1024, 16, 1024, 128) | 42.048 | 61.952 | 199.168 | 214.496 | 0.679 | 0.929 | | None | None | torch.bfloat16 | (2, 16, 4096, 16, 4096, 64) | 341.184 | 357.632 | 980.256 | 1328.896 | 0.954 | 0.738 | | None | causal | torch.bfloat16 | (2, 16, 4096, 16, 4096, 64) | 212.576 | 252.960 | 673.888 | 824.864 | 0.840 | 0.817 | | relative_bias | None | torch.bfloat16 | (2, 16, 4096, 16, 4096, 64) | 340.000 | 363.296 | 980.768 | 1375.808 | 0.936 | 0.713 | | head_bias | None | torch.bfloat16 | (2, 16, 4096, 16, 4096, 64) | 340.768 | 356.832 | 980.960 | 1326.272 | 0.955 | 0.740 | | None | None | torch.bfloat16 | (2, 16, 4096, 16, 4096, 128) | 459.392 | 737.120 | 1678.240 | 2205.248 | 0.623 | 0.761 | | None | causal | torch.bfloat16 | (2, 16, 4096, 16, 4096, 128) | 292.672 | 468.096 | 1178.016 | 1371.584 | 0.625 | 0.859 | | relative_bias | None | torch.bfloat16 | (2, 16, 4096, 16, 4096, 128) | 462.144 | 745.312 | 1680.000 | 2252.512 | 0.620 | 0.746 | | head_bias | None | torch.bfloat16 | (2, 16, 4096, 16, 4096, 128) | 462.112 | 736.576 | 1679.008 | 2216.480 | 0.627 | 0.758 | | None | None | torch.bfloat16 | (2, 16, 512, 2, 512, 64) | 16.064 | 16.704 | 105.120 | 120.768 | 0.962 | 0.870 | | None | causal | torch.bfloat16 | (2, 16, 512, 2, 512, 64) | 15.552 | 18.144 | 107.136 | 121.696 | 0.857 | 0.880 | | relative_bias | None | torch.bfloat16 | (2, 16, 512, 2, 512, 64) | 16.096 | 16.768 | 102.688 | 120.864 | 0.960 | 0.850 | | head_bias | None | torch.bfloat16 | (2, 16, 512, 2, 512, 64) | 16.032 | 16.576 | 104.736 | 124.672 | 0.967 | 0.840 | | None | None | torch.bfloat16 | (2, 16, 512, 2, 512, 128) | 19.392 | 21.952 | 104.736 | 174.656 | 0.883 | 0.600 | | None | causal | torch.bfloat16 | (2, 16, 512, 2, 512, 128) | 20.128 | 23.712 | 105.216 | 199.008 | 0.849 | 0.529 | | relative_bias | None | torch.bfloat16 | (2, 16, 512, 2, 512, 128) | 19.904 | 21.888 | 103.744 | 179.520 | 0.909 | 0.578 | | head_bias | None | torch.bfloat16 | (2, 16, 512, 2, 512, 128) | 19.968 | 21.952 | 104.640 | 177.312 | 0.910 | 0.590 | | None | None | torch.bfloat16 | (2, 16, 1024, 2, 1024, 64) | 32.096 | 31.904 | 118.720 | 231.968 | 1.006 | 0.512 | | None | causal | torch.bfloat16 | (2, 16, 1024, 2, 1024, 64) | 30.528 | 33.952 | 112.480 | 218.304 | 0.899 | 0.515 | | relative_bias | None | torch.bfloat16 | (2, 16, 1024, 2, 1024, 64) | 32.160 | 32.224 | 118.752 | 237.312 | 0.998 | 0.500 | | head_bias | None | torch.bfloat16 | (2, 16, 1024, 2, 1024, 64) | 32.128 | 32.032 | 118.240 | 233.120 | 1.003 | 0.507 | | None | None | torch.bfloat16 | (2, 16, 1024, 2, 1024, 128) | 41.312 | 61.280 | 177.408 | 350.688 | 0.674 | 0.506 | | None | causal | torch.bfloat16 | (2, 16, 1024, 2, 1024, 128) | 39.552 | 59.360 | 168.832 | 371.488 | 0.666 | 0.454 | | relative_bias | None | torch.bfloat16 | (2, 16, 1024, 2, 1024, 128) | 41.984 | 61.696 | 177.376 | 360.416 | 0.680 | 0.492 | | head_bias | None | torch.bfloat16 | (2, 16, 1024, 2, 1024, 128) | 41.312 | 61.760 | 177.184 | 355.744 | 0.669 | 0.498 | | None | None | torch.bfloat16 | (2, 16, 4096, 2, 4096, 64) | 339.744 | 357.888 | 939.712 | 1665.376 | 0.949 | 0.564 | | None | causal | torch.bfloat16 | (2, 16, 4096, 2, 4096, 64) | 212.608 | 248.832 | 633.280 | 1122.848 | 0.854 | 0.564 | | relative_bias | None | torch.bfloat16 | (2, 16, 4096, 2, 4096, 64) | 339.712 | 363.232 | 940.448 | 1689.440 | 0.935 | 0.557 | | head_bias | None | torch.bfloat16 | (2, 16, 4096, 2, 4096, 64) | 341.056 | 355.264 | 940.128 | 1641.152 | 0.960 | 0.573 | | None | None | torch.bfloat16 | (2, 16, 4096, 2, 4096, 128) | 460.736 | 741.024 | 1569.824 | 2559.552 | 0.622 | 0.613 | | None | causal | torch.bfloat16 | (2, 16, 4096, 2, 4096, 128) | 293.856 | 464.192 | 1066.240 | 1840.416 | 0.633 | 0.579 | | relative_bias | None | torch.bfloat16 | (2, 16, 4096, 2, 4096, 128) | 460.704 | 753.152 | 1570.112 | 2641.088 | 0.612 | 0.594 | | head_bias | None | torch.bfloat16 | (2, 16, 4096, 2, 4096, 128) | 460.832 | 745.536 | 1570.144 | 2602.560 | 0.618 | 0.603 | | None | None | torch.bfloat16 | (8, 16, 512, 16, 512, 64) | 35.680 | 41.280 | 171.840 | 158.176 | 0.864 | 1.086 | | None | causal | torch.bfloat16 | (8, 16, 512, 16, 512, 64) | 31.360 | 42.976 | 158.912 | 139.264 | 0.730 | 1.141 | | relative_bias | None | torch.bfloat16 | (8, 16, 512, 16, 512, 64) | 35.168 | 41.600 | 171.648 | 161.344 | 0.845 | 1.064 | | head_bias | None | torch.bfloat16 | (8, 16, 512, 16, 512, 64) | 35.136 | 41.152 | 171.808 | 158.336 | 0.854 | 1.085 | | None | None | torch.bfloat16 | (8, 16, 512, 16, 512, 128) | 48.832 | 76.384 | 295.680 | 277.696 | 0.639 | 1.065 | | None | causal | torch.bfloat16 | (8, 16, 512, 16, 512, 128) | 45.632 | 72.512 | 281.760 | 250.752 | 0.629 | 1.124 | | relative_bias | None | torch.bfloat16 | (8, 16, 512, 16, 512, 128) | 49.504 | 76.608 | 295.584 | 279.712 | 0.646 | 1.057 | | head_bias | None | torch.bfloat16 | (8, 16, 512, 16, 512, 128) | 48.864 | 75.904 | 295.456 | 277.568 | 0.644 | 1.064 | | None | None | torch.bfloat16 | (8, 16, 1024, 16, 1024, 64) | 99.392 | 111.232 | 408.640 | 442.656 | 0.894 | 0.923 | | None | causal | torch.bfloat16 | (8, 16, 1024, 16, 1024, 64) | 71.392 | 95.168 | 338.784 | 341.760 | 0.750 | 0.991 | | relative_bias | None | torch.bfloat16 | (8, 16, 1024, 16, 1024, 64) | 99.808 | 112.256 | 408.608 | 456.160 | 0.889 | 0.896 | | head_bias | None | torch.bfloat16 | (8, 16, 1024, 16, 1024, 64) | 100.032 | 110.816 | 408.512 | 444.192 | 0.903 | 0.920 | | None | None | torch.bfloat16 | (8, 16, 1024, 16, 1024, 128) | 135.040 | 226.112 | 726.880 | 774.176 | 0.597 | 0.939 | | None | causal | torch.bfloat16 | (8, 16, 1024, 16, 1024, 128) | 99.904 | 169.696 | 616.448 | 607.104 | 0.589 | 1.015 | | relative_bias | None | torch.bfloat16 | (8, 16, 1024, 16, 1024, 128) | 135.488 | 228.384 | 727.776 | 782.368 | 0.593 | 0.930 | | head_bias | None | torch.bfloat16 | (8, 16, 1024, 16, 1024, 128) | 135.744 | 225.664 | 728.000 | 773.600 | 0.602 | 0.941 | | None | None | torch.bfloat16 | (8, 16, 4096, 16, 4096, 64) | 1324.192 | 1387.808 | 3866.944 | 5217.184 | 0.954 | 0.741 | | None | causal | torch.bfloat16 | (8, 16, 4096, 16, 4096, 64) | 738.464 | 832.608 | 2507.392 | 3146.688 | 0.887 | 0.797 | | relative_bias | None | torch.bfloat16 | (8, 16, 4096, 16, 4096, 64) | 1326.016 | 1404.256 | 3867.872 | 5382.624 | 0.944 | 0.719 | | head_bias | None | torch.bfloat16 | (8, 16, 4096, 16, 4096, 64) | 1326.144 | 1386.688 | 3867.552 | 5203.264 | 0.956 | 0.743 | | None | None | torch.bfloat16 | (8, 16, 4096, 16, 4096, 128) | 1847.488 | 2866.336 | 6612.704 | 8597.696 | 0.645 | 0.769 | | None | causal | torch.bfloat16 | (8, 16, 4096, 16, 4096, 128) | 1066.592 | 1660.640 | 4357.696 | 5174.016 | 0.642 | 0.842 | | relative_bias | None | torch.bfloat16 | (8, 16, 4096, 16, 4096, 128) | 1850.464 | 2905.408 | 6616.928 | 8793.280 | 0.637 | 0.752 | | head_bias | None | torch.bfloat16 | (8, 16, 4096, 16, 4096, 128) | 1848.896 | 2834.720 | 6623.872 | 8637.920 | 0.652 | 0.767 | | None | None | torch.bfloat16 | (8, 16, 512, 2, 512, 64) | 36.384 | 38.656 | 150.336 | 182.624 | 0.941 | 0.823 | | None | causal | torch.bfloat16 | (8, 16, 512, 2, 512, 64) | 31.360 | 38.112 | 137.664 | 171.840 | 0.823 | 0.801 | | relative_bias | None | torch.bfloat16 | (8, 16, 512, 2, 512, 64) | 36.608 | 39.040 | 150.528 | 183.872 | 0.938 | 0.819 | | head_bias | None | torch.bfloat16 | (8, 16, 512, 2, 512, 64) | 36.064 | 38.656 | 150.560 | 183.520 | 0.933 | 0.820 | | None | None | torch.bfloat16 | (8, 16, 512, 2, 512, 128) | 49.344 | 76.352 | 253.920 | 301.440 | 0.646 | 0.842 | | None | causal | torch.bfloat16 | (8, 16, 512, 2, 512, 128) | 46.720 | 65.824 | 239.424 | 296.384 | 0.710 | 0.808 | | relative_bias | None | torch.bfloat16 | (8, 16, 512, 2, 512, 128) | 49.248 | 76.416 | 253.728 | 307.808 | 0.644 | 0.824 | | head_bias | None | torch.bfloat16 | (8, 16, 512, 2, 512, 128) | 49.376 | 76.288 | 253.728 | 304.736 | 0.647 | 0.833 | | None | None | torch.bfloat16 | (8, 16, 1024, 2, 1024, 64) | 99.264 | 110.144 | 364.960 | 503.072 | 0.901 | 0.725 | | None | causal | torch.bfloat16 | (8, 16, 1024, 2, 1024, 64) | 71.136 | 92.384 | 294.432 | 393.056 | 0.770 | 0.749 | | relative_bias | None | torch.bfloat16 | (8, 16, 1024, 2, 1024, 64) | 99.200 | 111.360 | 365.152 | 512.640 | 0.891 | 0.712 | | head_bias | None | torch.bfloat16 | (8, 16, 1024, 2, 1024, 64) | 99.264 | 110.240 | 365.088 | 504.224 | 0.900 | 0.724 | | None | None | torch.bfloat16 | (8, 16, 1024, 2, 1024, 128) | 135.680 | 230.336 | 613.472 | 816.896 | 0.589 | 0.751 | | None | causal | torch.bfloat16 | (8, 16, 1024, 2, 1024, 128) | 100.256 | 165.088 | 502.144 | 676.480 | 0.607 | 0.742 | | relative_bias | None | torch.bfloat16 | (8, 16, 1024, 2, 1024, 128) | 135.008 | 232.480 | 613.184 | 836.672 | 0.581 | 0.733 | | head_bias | None | torch.bfloat16 | (8, 16, 1024, 2, 1024, 128) | 135.232 | 230.624 | 613.536 | 827.136 | 0.586 | 0.742 | | None | None | torch.bfloat16 | (8, 16, 4096, 2, 4096, 64) | 1324.064 | 1378.688 | 3631.808 | 5308.384 | 0.960 | 0.684 | | None | causal | torch.bfloat16 | (8, 16, 4096, 2, 4096, 64) | 731.776 | 826.688 | 2263.168 | 3241.344 | 0.885 | 0.698 | | relative_bias | None | torch.bfloat16 | (8, 16, 4096, 2, 4096, 64) | 1316.128 | 1403.200 | 3625.088 | 5550.688 | 0.938 | 0.653 | | head_bias | None | torch.bfloat16 | (8, 16, 4096, 2, 4096, 64) | 1311.904 | 1378.880 | 3616.320 | 5353.696 | 0.951 | 0.675 | | None | None | torch.bfloat16 | (8, 16, 4096, 2, 4096, 128) | 1837.856 | 2887.392 | 6121.632 | 8586.656 | 0.637 | 0.713 | | None | causal | torch.bfloat16 | (8, 16, 4096, 2, 4096, 128) | 1066.976 | 1654.368 | 3843.136 | 5291.040 | 0.645 | 0.726 | | relative_bias | None | torch.bfloat16 | (8, 16, 4096, 2, 4096, 128) | 1854.208 | 2896.832 | 6130.112 | 8745.984 | 0.640 | 0.701 | | head_bias | None | torch.bfloat16 | (8, 16, 4096, 2, 4096, 128) | 1860.512 | 2889.344 | 6135.648 | 8750.592 | 0.644 | 0.701 | | None | None | torch.bfloat16 | (16, 16, 512, 16, 512, 64) | 60.640 | 71.552 | 315.968 | 296.512 | 0.847 | 1.066 | | None | causal | torch.bfloat16 | (16, 16, 512, 16, 512, 64) | 50.784 | 71.040 | 284.288 | 258.880 | 0.715 | 1.098 | | relative_bias | None | torch.bfloat16 | (16, 16, 512, 16, 512, 64) | 61.312 | 72.704 | 315.680 | 302.016 | 0.843 | 1.045 | | head_bias | None | torch.bfloat16 | (16, 16, 512, 16, 512, 64) | 60.800 | 71.776 | 316.320 | 297.152 | 0.847 | 1.065 | | None | None | torch.bfloat16 | (16, 16, 512, 16, 512, 128) | 84.576 | 144.416 | 580.576 | 535.936 | 0.586 | 1.083 | | None | causal | torch.bfloat16 | (16, 16, 512, 16, 512, 128) | 76.064 | 123.648 | 553.344 | 481.376 | 0.615 | 1.150 | | relative_bias | None | torch.bfloat16 | (16, 16, 512, 16, 512, 128) | 84.160 | 145.248 | 581.024 | 540.000 | 0.579 | 1.076 | | head_bias | None | torch.bfloat16 | (16, 16, 512, 16, 512, 128) | 84.512 | 143.552 | 581.088 | 535.776 | 0.589 | 1.085 | | None | None | torch.bfloat16 | (16, 16, 1024, 16, 1024, 64) | 189.152 | 209.408 | 798.400 | 868.704 | 0.903 | 0.919 | | None | causal | torch.bfloat16 | (16, 16, 1024, 16, 1024, 64) | 127.552 | 168.800 | 650.816 | 663.328 | 0.756 | 0.981 | | relative_bias | None | torch.bfloat16 | (16, 16, 1024, 16, 1024, 64) | 189.376 | 211.360 | 798.080 | 895.552 | 0.896 | 0.891 | | head_bias | None | torch.bfloat16 | (16, 16, 1024, 16, 1024, 64) | 189.440 | 208.576 | 797.888 | 873.152 | 0.908 | 0.914 | | None | None | torch.bfloat16 | (16, 16, 1024, 16, 1024, 128) | 257.536 | 441.760 | 1408.960 | 1514.720 | 0.583 | 0.930 | | None | causal | torch.bfloat16 | (16, 16, 1024, 16, 1024, 128) | 179.328 | 312.096 | 1170.368 | 1177.472 | 0.575 | 0.994 | | relative_bias | None | torch.bfloat16 | (16, 16, 1024, 16, 1024, 128) | 259.264 | 446.944 | 1408.768 | 1530.400 | 0.580 | 0.921 | | head_bias | None | torch.bfloat16 | (16, 16, 1024, 16, 1024, 128) | 258.080 | 440.480 | 1408.864 | 1514.144 | 0.586 | 0.930 | | None | None | torch.bfloat16 | (16, 16, 4096, 16, 4096, 64) | 2595.808 | 2771.456 | 7616.704 | 10405.248 | 0.937 | 0.732 | | None | causal | torch.bfloat16 | (16, 16, 4096, 16, 4096, 64) | 1435.744 | 1610.336 | 4927.520 | 6220.000 | 0.892 | 0.792 | | relative_bias | None | torch.bfloat16 | (16, 16, 4096, 16, 4096, 64) | 2595.264 | 2745.056 | 7611.232 | 10631.392 | 0.945 | 0.716 | | head_bias | None | torch.bfloat16 | (16, 16, 4096, 16, 4096, 64) | 2576.256 | 2735.456 | 7626.400 | 10346.976 | 0.942 | 0.737 | | None | None | torch.bfloat16 | (16, 16, 4096, 16, 4096, 128) | 3679.744 | 5634.816 | 13077.056 | 17182.528 | 0.653 | 0.761 | | None | causal | torch.bfloat16 | (16, 16, 4096, 16, 4096, 128) | 2099.360 | 3250.176 | 8589.664 | 10236.672 | 0.646 | 0.839 | | relative_bias | None | torch.bfloat16 | (16, 16, 4096, 16, 4096, 128) | 3676.800 | 5716.288 | 13073.088 | 17311.071 | 0.643 | 0.755 | | head_bias | None | torch.bfloat16 | (16, 16, 4096, 16, 4096, 128) | 3679.136 | 5570.496 | 13070.720 | 17192.863 | 0.660 | 0.760 | | None | None | torch.bfloat16 | (16, 16, 512, 2, 512, 64) | 61.600 | 71.008 | 272.320 | 300.000 | 0.868 | 0.908 | | None | causal | torch.bfloat16 | (16, 16, 512, 2, 512, 64) | 50.176 | 65.344 | 241.568 | 258.912 | 0.768 | 0.933 | | relative_bias | None | torch.bfloat16 | (16, 16, 512, 2, 512, 64) | 61.120 | 72.512 | 272.672 | 305.408 | 0.843 | 0.893 | | head_bias | None | torch.bfloat16 | (16, 16, 512, 2, 512, 64) | 61.248 | 71.136 | 272.640 | 301.120 | 0.861 | 0.905 | | None | None | torch.bfloat16 | (16, 16, 512, 2, 512, 128) | 83.872 | 146.784 | 466.912 | 496.832 | 0.571 | 0.940 | | None | causal | torch.bfloat16 | (16, 16, 512, 2, 512, 128) | 76.704 | 115.072 | 435.584 | 462.112 | 0.667 | 0.943 | | relative_bias | None | torch.bfloat16 | (16, 16, 512, 2, 512, 128) | 83.392 | 147.392 | 466.656 | 504.448 | 0.566 | 0.925 | | head_bias | None | torch.bfloat16 | (16, 16, 512, 2, 512, 128) | 83.360 | 146.688 | 466.656 | 499.040 | 0.568 | 0.935 | | None | None | torch.bfloat16 | (16, 16, 1024, 2, 1024, 64) | 189.024 | 207.584 | 684.768 | 873.568 | 0.911 | 0.784 | | None | causal | torch.bfloat16 | (16, 16, 1024, 2, 1024, 64) | 126.944 | 164.288 | 536.192 | 645.984 | 0.773 | 0.830 | | relative_bias | None | torch.bfloat16 | (16, 16, 1024, 2, 1024, 64) | 188.768 | 209.760 | 684.096 | 897.504 | 0.900 | 0.762 | | head_bias | None | torch.bfloat16 | (16, 16, 1024, 2, 1024, 64) | 189.408 | 207.776 | 685.024 | 876.384 | 0.912 | 0.782 | | None | None | torch.bfloat16 | (16, 16, 1024, 2, 1024, 128) | 259.168 | 449.536 | 1167.936 | 1433.280 | 0.577 | 0.815 | | None | causal | torch.bfloat16 | (16, 16, 1024, 2, 1024, 128) | 180.000 | 305.312 | 928.000 | 1113.920 | 0.590 | 0.833 | | relative_bias | None | torch.bfloat16 | (16, 16, 1024, 2, 1024, 128) | 258.464 | 455.136 | 1167.808 | 1462.848 | 0.568 | 0.798 | | head_bias | None | torch.bfloat16 | (16, 16, 1024, 2, 1024, 128) | 257.824 | 450.208 | 1167.744 | 1448.000 | 0.573 | 0.806 | | None | None | torch.bfloat16 | (16, 16, 4096, 2, 4096, 64) | 2598.368 | 2729.120 | 7134.400 | 10381.632 | 0.952 | 0.687 | | None | causal | torch.bfloat16 | (16, 16, 4096, 2, 4096, 64) | 1435.456 | 1591.040 | 4424.768 | 6035.808 | 0.902 | 0.733 | | relative_bias | None | torch.bfloat16 | (16, 16, 4096, 2, 4096, 64) | 2594.752 | 2725.952 | 7128.384 | 10822.496 | 0.952 | 0.659 | | head_bias | None | torch.bfloat16 | (16, 16, 4096, 2, 4096, 64) | 2597.888 | 2716.960 | 7101.568 | 10385.440 | 0.956 | 0.684 | | None | None | torch.bfloat16 | (16, 16, 4096, 2, 4096, 128) | 3647.648 | 5581.632 | 12089.952 | 16667.233 | 0.654 | 0.725 | | None | causal | torch.bfloat16 | (16, 16, 4096, 2, 4096, 128) | 2093.952 | 3241.440 | 7579.392 | 9847.936 | 0.646 | 0.770 | | relative_bias | None | torch.bfloat16 | (16, 16, 4096, 2, 4096, 128) | 3650.528 | 5650.688 | 12105.568 | 16963.680 | 0.646 | 0.714 | | head_bias | None | torch.bfloat16 | (16, 16, 4096, 2, 4096, 128) | 3680.064 | 5585.312 | 12117.504 | 16935.040 | 0.659 | 0.716 |
Pull Request resolved: https://github.com/pytorch/pytorch/pull/135505 Approved by: https://github.com/Chillee --- test/inductor/test_flex_attention.py | 83 +++++++++++++++- test/inductor/test_flex_decoding.py | 40 +++++++- torch/_higher_order_ops/flex_attention.py | 23 +++++ torch/_inductor/kernel/flex_attention.py | 116 ++++++++++++++-------- torch/_inductor/kernel/flex_decoding.py | 11 +- 5 files changed, 225 insertions(+), 48 deletions(-) diff --git a/test/inductor/test_flex_attention.py b/test/inductor/test_flex_attention.py index 06b012348d9f88..cc821d1467c329 100644 --- a/test/inductor/test_flex_attention.py +++ b/test/inductor/test_flex_attention.py @@ -5,7 +5,7 @@ import string from collections import namedtuple from contextlib import contextmanager, nullcontext -from typing import Callable, Optional +from typing import Callable, Optional, Tuple from unittest import expectedFailure, skip, skipUnless from unittest.mock import patch @@ -212,6 +212,17 @@ def _trig2(score, b, h, m, n): S = 2048 D = 64 +test_Hq_Hkv = [ + (4, 2), + (4, 1), +] + +test_Bq_Bkv = [ + (3, 1), + (4, 1), + (5, 1), +] + def query_key_value_clones( query: torch.Tensor, @@ -618,6 +629,76 @@ def test_builtin_score_mods_different_seqlen( D, ) + @supported_platform + @common_utils.parametrize("dtype", test_dtypes_fast) + @common_utils.parametrize("batch_dims", test_Bq_Bkv) + @common_utils.parametrize("head_dims", test_Hq_Hkv) + @common_utils.parametrize("score_mod", test_score_mods) + def test_kv_batch_broadcast( + self, + dtype: torch.dtype, + batch_dims: Tuple[int, int], + head_dims: Tuple[int, int], + score_mod: Callable, + ): + Hq, Hkv = head_dims + assert Hq % Hkv == 0 + + Bq, Bkv = batch_dims + assert Bq > 1 and Bkv == 1 + + self.run_test( + score_mod, + dtype, + Bq, + Hq, + S, + D, + Bkv, + Hkv, + S, + D, + ) + + @supported_platform + @common_utils.parametrize("dtype", test_dtypes_fast) + @common_utils.parametrize("batch_dims", test_Bq_Bkv) + @common_utils.parametrize("head_dims", test_Hq_Hkv) + @common_utils.parametrize("score_mod", test_score_mods) + def test_kv_batch_broadcast_causal_mask( + self, + dtype: torch.dtype, + batch_dims: Tuple[int, int], + head_dims: Tuple[int, int], + score_mod: Callable, + ): + Hq, Hkv = head_dims + assert Hq % Hkv == 0 + + Bq, Bkv = batch_dims + assert Bq > 1 and Bkv == 1 + + def mask_mod(b, h, q, kv): + return q >= kv + + block_mask = create_block_mask(mask_mod, 1, 1, S, S) + attention = functools.partial( + flex_attention, block_mask=block_mask, enable_gqa=(not Hq == Hkv) + ) + + self.run_test_with_call( + attention, + torch.float16, + Bq, + Hq, + S, + D, + Bkv, + Hkv, + S, + D, + ) + @supported_platform @common_utils.parametrize("dtype", test_dtypes_fast) @common_utils.parametrize("score_mod", test_score_mods) diff --git a/test/inductor/test_flex_decoding.py b/test/inductor/test_flex_decoding.py index 4d22fd877ecabb..27d0f3083be4d7 100644 --- a/test/inductor/test_flex_decoding.py +++ b/test/inductor/test_flex_decoding.py @@ -4,7 +4,7 @@ import functools from collections import namedtuple from contextlib import nullcontext -from typing import Callable, Optional +from typing import Callable, Optional, Tuple from unittest import expectedFailure, skipUnless from unittest.mock import patch @@ -186,6 +186,13 @@ def _trig2(score, b, h, m, n): (16, 16), ] +test_Bq_Bkv = [ + (3, 1), + (5, 1), + (8, 1), + (16, 1), +] + (Hq, Hkv) = (16, 8) @@ -449,6 +456,37 @@ def test_strided_inputs(self, dtype: torch.dtype, k_s, v_s, head_dims): ref_out, compiled_out, atol=tolerance.atol, rtol=tolerance.rtol ) + @supported_platform + @common_utils.parametrize("dtype", test_dtypes_fast) + @common_utils.parametrize("head_dims", test_Hq_Hkv) + @common_utils.parametrize("batch_dims", test_Bq_Bkv) + @common_utils.parametrize("score_mod", test_score_mods) + def test_kv_batch_broadcast( + self, + dtype: torch.dtype, + head_dims: Tuple[int, int], + batch_dims: Tuple[int, int], + score_mod: Callable, + ): + Hq, Hkv = head_dims + assert Hq % Hkv == 0 + + Bq, Bkv = batch_dims + assert Bq > 1 and Bkv == 1 + + self.run_test( + score_mod, + dtype, + Bq, + Hq, + 1, + D, + Bkv, + Hkv, + S, + D, + ) + @supported_platform @common_utils.parametrize("dtype", test_dtypes) def test_skip_odd_keys(self, dtype: torch.dtype): diff --git a/torch/_higher_order_ops/flex_attention.py b/torch/_higher_order_ops/flex_attention.py index cd1c9d4fd92c5e..3783b29f117e1e 100644 --- a/torch/_higher_order_ops/flex_attention.py +++ b/torch/_higher_order_ops/flex_attention.py @@ -192,6 +192,13 @@ def math_attention( value = torch.repeat_interleave(value, G, dim=1) key = torch.repeat_interleave(key, G, dim=1) + Bq, Bkv = query.size(0), key.size(0) + if not ((Bq == Bkv) or (Bq > 1 and Bkv == 1)): + raise RuntimeError(f"Bq and Bkv must broadcast. Got Bq={Bq} and Bkv={Bkv}") + + key = key.expand((Bq, *key.size()[1:])) + value = value.expand((Bq, *value.size()[1:])) + _, post_mod_scores = _math_attention_inner( query, key, @@ -691,6 +698,13 @@ def sdpa_dense_backward( actual_grad_key = torch.empty_like(key) actual_grad_value = torch.empty_like(value) + Bq, Bkv = query.size(0), key.size(0) + if not ((Bq == Bkv) or (Bq > 1 and Bkv == 1)): + raise RuntimeError(f"Bq and Bkv must broadcast. Got Bq={Bq} and Bkv={Bkv}") + + key = key.expand((Bq, *key.size()[1:])) + value = value.expand((Bq, *value.size()[1:])) + G = query.size(1) // key.size(1) key = torch.repeat_interleave(key, G, dim=1) value = torch.repeat_interleave(value, G, dim=1) @@ -772,6 +786,15 @@ def sdpa_dense_backward( grad_key = torch.sum(grad_key, 2, keepdim=False) grad_value = torch.sum(grad_value, 2, keepdim=False) + if Bq != Bkv: + assert ( + Bq > 1 and Bkv == 1 + ), f"Bq and Bkv must broadcast. Got Bq={Bq} and Bkv={Bkv}" + + # Reduce DK, DV along broadcasted batches. + grad_key = torch.sum(grad_key, 0, keepdim=True) + grad_value = torch.sum(grad_value, 0, keepdim=True) + actual_grad_query.copy_(grad_query) actual_grad_key.copy_(grad_key) actual_grad_value.copy_(grad_value) diff --git a/torch/_inductor/kernel/flex_attention.py b/torch/_inductor/kernel/flex_attention.py index 4c1c4de3338cf0..ec79b8097cb846 100644 --- a/torch/_inductor/kernel/flex_attention.py +++ b/torch/_inductor/kernel/flex_attention.py @@ -179,22 +179,27 @@ def get_offset_for_next_block(loop_iter, col_indices, total_blocks, SPARSE_BLOCK stride_kz, stride_kh, stride_kn, stride_kk = {{stride("K")}} stride_vz, stride_vh, stride_vn, stride_vk = {{stride("V")}} - Z = {{size("Q", 0)}} + ZQ = {{size("Q", 0)}} HQ = {{size("Q", 1)}} Q_LEN = {{size("Q", 2)}} + ZKV = {{size("K", 0)}} KV_LEN = {{size("K", 2)}} MATMUL_PRECISION = Q.dtype.element_ty q_start = tl.program_id(0) - off_z = tl.program_id(1) // HQ + off_zq = tl.program_id(1) // HQ off_hq = tl.program_id(1) % HQ + + # We support two cases for batch dimension. a) (ZKV == ZQ) where off_zkv = off_zq. + # b) (ZKV == 1 and ZQ > 1) where KV is broadcasted along the batch dimension and off_zkv=0. + off_zkv = off_zq % ZKV off_hkv = off_hq // GQA_SHARED_HEADS off_g = off_hq % GQA_SHARED_HEADS - q_offset = off_z * stride_qz + off_hq * stride_qh - k_offset = off_z * stride_kz + off_hkv * stride_kh - v_offset = off_z * stride_vz + off_hkv * stride_vh + q_offset = off_zq * stride_qz + off_hq * stride_qh + k_offset = off_zkv * stride_kz + off_hkv * stride_kh + v_offset = off_zkv * stride_vz + off_hkv * stride_vh Q = Q + q_offset K = K + k_offset @@ -203,7 +208,7 @@ def get_offset_for_next_block(loop_iter, col_indices, total_blocks, SPARSE_BLOCK SPARSE_Z = {{size("KV_NUM_BLKS", 0)}} SPARSE_HQ = {{size("KV_NUM_BLKS", 1)}} - sparse_idx_z = off_z % SPARSE_Z + sparse_idx_z = off_zq % SPARSE_Z sparse_idx_hq = off_hq % SPARSE_HQ SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M) @@ -270,7 +275,7 @@ def get_offset_for_next_block(loop_iter, col_indices, total_blocks, SPARSE_BLOCK {{gen_argdefs()}}, q, K_block_ptr, V_block_ptr, Q_LEN, KV_LEN, acc, l_i, m_i, - off_z, off_hq, offs_m[:, None], offs_n[None, :], + off_zq, off_hq, offs_m[:, None], offs_n[None, :], kv_indices, kv_num_blocks, 0, block_n_end, MATMUL_PRECISION, @@ -309,7 +314,7 @@ def get_offset_for_next_block(loop_iter, col_indices, total_blocks, SPARSE_BLOCK {{gen_argdefs()}}, q, K_block_ptr, V_block_ptr, Q_LEN, KV_LEN, acc, l_i, m_i, - off_z, off_hq, offs_m[:, None], offs_n[None, :], + off_zq, off_hq, offs_m[:, None], offs_n[None, :], kv_indices, kv_num_blocks, 0, block_n_end, MATMUL_PRECISION, @@ -323,14 +328,14 @@ def get_offset_for_next_block(loop_iter, col_indices, total_blocks, SPARSE_BLOCK l_i = tl.where(l_i == 0.0, 1, l_i) acc = acc / l_i[:, None] - idx_z = tl.program_id(1) // HQ + idx_zq = tl.program_id(1) // HQ idx_hq = tl.program_id(1) % HQ idx_m = offs_m[:, None] idx_d = tl.arange(0, V_HEAD_DIM)[None, :] mask = idx_m < Q_LEN # TODO generalize and add proper mask support - {{store_output(("idx_z", "idx_hq", "idx_m", "idx_d"), "acc", "mask")}} + {{store_output(("idx_zq", "idx_hq", "idx_m", "idx_d"), "acc", "mask")}} # TODO dont want to write this if we dont require grad if OUTPUT_LOGSUMEXP: @@ -741,7 +746,10 @@ def flex_attention( Bq, Hq, seq_len_q, qk_head_dim = query.get_size() Bkv, Hkv, seq_len_kv, v_head_dim = value.get_size() - assert Bq == Bkv, "Batch dimension must match" + + if not ((Bq == Bkv) or (Bq > 1 and Bkv == 1)): + raise RuntimeError(f"Bq and Bkv must broadcast. Got Bq={Bq} and Bkv={Bkv}") + B = Bq if seq_len_q % 128 != 0 or seq_len_kv % 128 != 0: @@ -940,10 +948,11 @@ def flex_attention_backward_grid( stride_dqz, stride_dqh, stride_dqm, stride_dqd = {{stride("DQ")}} stride_dvz, stride_dvh, stride_dvm, stride_dvd = {{stride("DV")}} - Z = {{size("Q", 0)}} + ZQ = {{size("Q", 0)}} HQ = {{size("Q", 1)}} HKV = {{size("K", 1)}} Q_LEN = {{size("Q", 2)}} + ZKV = {{size("K", 0)}} KV_LEN = {{size("K", 2)}} MATMUL_PRECISION = Q.dtype.element_ty @@ -953,17 +962,20 @@ def flex_attention_backward_grid( NUM_Q_BLOCKS = tl.cdiv(Q_LEN, BLOCK_M2) off_hz = tl.program_id(2) - off_z = off_hz // HKV # batch idx + off_zq = off_hz // HKV # q batch idx off_hkv = off_hz % HKV # kv head idx + off_zkv = off_zq % ZKV # kv batch idx SPARSE_Z = {{size("KV_NUM_BLKS", 0)}} SPARSE_HQ = {{size("KV_NUM_BLKS", 1)}} - sparse_idx_z = off_z % SPARSE_Z + sparse_idx_z = off_zq % SPARSE_Z - k_adj = (stride_kh * off_hkv + stride_kz * off_z).to(tl.int64) - v_adj = (stride_vh * off_hkv + stride_vz * off_z).to(tl.int64) - dv_adj = (stride_dvh * off_hkv + stride_dvz * off_z).to(tl.int64) + k_adj = (stride_kh * off_hkv + stride_kz * off_zkv).to(tl.int64) + v_adj = (stride_vh * off_hkv + stride_vz * off_zkv).to(tl.int64) + # first compute broadcasted dv of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dv of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + dv_adj = (stride_dvh * off_hkv + stride_dvz * off_zq).to(tl.int64) # offset K, V, DV pointers for batch/kv-head K += k_adj @@ -993,10 +1005,10 @@ def flex_attention_backward_grid( sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + off_pid_mask * stride_kv_idx_m # noqa: B950 # Offset Q, DQ, DO, DELTA & LSE. These inputs are offseted by query heads. - q_adj2 = (stride_qh * off_hq2 + stride_qz * off_z).to(tl.int64) - do_adj2 = (stride_doh * off_hq2 + stride_doz * off_z).to(tl.int64) - dq_adj2 = (stride_dqh * off_hq2 + stride_dqz * off_z).to(tl.int64) - off_chz2 = ((off_z * HQ + off_hq2) * Q_LEN).to(tl.int64) + q_adj2 = (stride_qh * off_hq2 + stride_qz * off_zq).to(tl.int64) + do_adj2 = (stride_doh * off_hq2 + stride_doz * off_zq).to(tl.int64) + dq_adj2 = (stride_dqh * off_hq2 + stride_dqz * off_zq).to(tl.int64) + off_chz2 = ((off_zq * HQ + off_hq2) * Q_LEN).to(tl.int64) Q2 = Q + q_adj2 DO2 = DO + do_adj2 @@ -1042,7 +1054,7 @@ def flex_attention_backward_grid( {{gen_argdefs()}}, K, V, dq, q, do, Di, lse, - off_z, off_hq2, offs_m2, offs_n2, + off_zq, off_hq2, offs_m2, offs_n2, stride_kn, stride_kd, stride_vn, stride_vd, kv_indices, sparse_kv_num_blocks, MATMUL_PRECISION, @@ -1061,7 +1073,7 @@ def flex_attention_backward_grid( {{gen_argdefs()}}, K, V, dq, q, do, Di, lse, - off_z, off_hq2, offs_m2, offs_n2, + off_zq, off_hq2, offs_m2, offs_n2, stride_kn, stride_kd, stride_vn, stride_vd, kv_indices, sparse_kv_num_blocks, MATMUL_PRECISION, @@ -1106,10 +1118,10 @@ def flex_attention_backward_grid( off_hq1 = off_hkv * GQA_SHARED_HEADS + off_g # Offset Q, DQ, DO, DELTA & LSE. These inputs are offseted by query heads. - q_adj1 = (stride_qh * off_hq1 + stride_qz * off_z).to(tl.int64) - do_adj1 = (stride_doh * off_hq1 + stride_doz * off_z).to(tl.int64) - dq_adj1 = (stride_dqh * off_hq1 + stride_dqz * off_z).to(tl.int64) - off_chz1 = ((off_z * HQ + off_hq1) * Q_LEN).to(tl.int64) + q_adj1 = (stride_qh * off_hq1 + stride_qz * off_zq).to(tl.int64) + do_adj1 = (stride_doh * off_hq1 + stride_doz * off_zq).to(tl.int64) + dq_adj1 = (stride_dqh * off_hq1 + stride_dqz * off_zq).to(tl.int64) + off_chz1 = ((off_zq * HQ + off_hq1) * Q_LEN).to(tl.int64) Q1 = Q + q_adj1 DO1 = DO + do_adj1 @@ -1135,7 +1147,7 @@ def flex_attention_backward_grid( {{gen_argdefs()}}, Q1, DO1, DELTA1, LSE1, dk, dv, k, v, - off_z, off_hq1, offs_n1, offs_m1, + off_zq, off_hq1, offs_n1, offs_m1, stride_qm, stride_qd, stride_dom, stride_dod, q_indices, sparse_q_num_blocks, MATMUL_PRECISION, @@ -1155,7 +1167,7 @@ def flex_attention_backward_grid( {{gen_argdefs()}}, Q1, DO1, DELTA1, LSE1, dk, dv, k, v, - off_z, off_hq1, offs_n1, offs_m1, + off_zq, off_hq1, offs_n1, offs_m1, stride_qm, stride_qd, stride_dom, stride_dod, q_indices, sparse_q_num_blocks, MATMUL_PRECISION, @@ -1175,7 +1187,10 @@ def flex_attention_backward_grid( dk *= SM_SCALE mask = index_n < KV_LEN - {{store_output(("off_z", "off_hkv", "index_n", "index_k"), "dk", "mask", indent_width=8)}} + + # first compute broadcasted dk of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dk of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + {{store_output(("off_zq", "off_hkv", "index_n", "index_k"), "dk", "mask", indent_width=8)}} @triton.jit def bwd_dq_inner( @@ -1618,8 +1633,9 @@ def flex_attention_backward(*args, **kwargs): dtype = query.get_dtype() Bq, Hq, seq_len_q, qk_head_dim = query.get_size() Bkv, Hkv, seq_len_kv, v_head_dim = value.get_size() - assert Bq == Bkv, "Batch dimension must match" - B = Bq + + if not ((Bq == Bkv) or (Bq > 1 and Bkv == 1)): + raise RuntimeError(f"Bq and Bkv must broadcast. Got Bq={Bq} and Bkv={Bkv}") kernel_options = dict(kernel_options) kernel_options.setdefault("FLOAT32_PRECISION", get_float32_precision()) @@ -1662,10 +1678,10 @@ def flex_attention_backward(*args, **kwargs): mask_graph_placeholder_inps + list(mask_mod_other_buffers), mask_graph ) - layout_k = FixedLayout( + layout_broadcasted_k = FixedLayout( key.get_device(), key.get_dtype(), - key.get_size(), + [Bq, Hkv, seq_len_kv, qk_head_dim], key.get_stride(), ) @@ -1682,8 +1698,11 @@ def flex_attention_backward(*args, **kwargs): grad_query = empty_strided( query.get_size(), query.get_stride(), dtype=dtype, device=device ) - grad_value = empty_strided( - value.get_size(), value.get_stride(), dtype=dtype, device=device + broadcasted_grad_value = empty_strided( + (Bq, *value.get_size()[1:]), + value.get_stride(), + dtype=dtype, + device=device, ) kernel_options.setdefault("SM_SCALE", scale) @@ -1746,7 +1765,7 @@ def flex_attention_backward(*args, **kwargs): delta, grad_out, grad_query, - grad_value, + broadcasted_grad_value, kv_num_blocks, kv_indices, q_num_blocks, @@ -1756,9 +1775,9 @@ def flex_attention_backward(*args, **kwargs): full_q_num_blocks, full_q_indices, ], - layout=layout_k, # We use store_output only for grad_key + layout=layout_broadcasted_k, # We use store_output only for grad_key subgraphs=[fw_subgraph_buffer, joint_subgraph_buffer, mask_graph_buffer], - mutated_inputs=[grad_query, grad_value], + mutated_inputs=[grad_query, broadcasted_grad_value], call_sizes=query.get_size() + key.get_size()[1:3], num_stages=num_stages, num_warps=num_warps, @@ -1773,7 +1792,7 @@ def flex_attention_backward(*args, **kwargs): delta, grad_out, grad_query, - grad_value, + broadcasted_grad_value, kv_num_blocks, kv_indices, q_num_blocks, @@ -1797,13 +1816,24 @@ def flex_attention_backward(*args, **kwargs): 15: create_indices_fake, } - grad_key = autotune_select_algorithm( + broadcasted_grad_key = autotune_select_algorithm( "flex_attention_backward", choices, inputs_for_autotuning, - layout_k, + layout_broadcasted_k, input_gen_fns=input_gen_fns, - ) + ) # [Bq, Hkv, seq_len_kv, k_head_dim] + + if Bq == Bkv: + grad_key = broadcasted_grad_key + grad_value = broadcasted_grad_value + else: + assert ( + Bq > 1 and Bkv == 1 + ), f"Bq and Bkv must broadcast. Got Bq={Bq} and Bkv={Bkv}" + grad_key = lowerings[aten.sum](broadcasted_grad_key, axis=0, keepdims=True) + grad_value = lowerings[aten.sum](broadcasted_grad_value, axis=0, keepdims=True) + return ( grad_query, grad_key, diff --git a/torch/_inductor/kernel/flex_decoding.py b/torch/_inductor/kernel/flex_decoding.py index 8a5b7c60eafdda..ef57f208a93665 100644 --- a/torch/_inductor/kernel/flex_decoding.py +++ b/torch/_inductor/kernel/flex_decoding.py @@ -83,6 +83,7 @@ def flex_decoding_grid(batch_size, kv_heads, gqa_group_size, n_keys, d_model, me Z = {{size("Q", 0)}} + ZKV = {{size("K", 0)}} HKV = {{size("Q", 1)}} G: tl.constexpr = GQA_SHARED_HEADS HQ = HKV * G @@ -97,12 +98,13 @@ def flex_decoding_grid(batch_size, kv_heads, gqa_group_size, n_keys, d_model, me TILE_KV_MULTIPLE: tl.constexpr = (TILE_KV // BLOCK_N) off_z = tl.program_id(0) // HKV + off_zkv = off_z % ZKV off_hkv = tl.program_id(0) % HKV off_t = tl.program_id(1) q_offset = off_z * stride_qz + off_hkv * stride_qh - k_offset = off_z * stride_kz + off_hkv * stride_kh - v_offset = off_z * stride_vz + off_hkv * stride_vh + k_offset = off_zkv * stride_kz + off_hkv * stride_kh + v_offset = off_zkv * stride_vz + off_hkv * stride_vh SPARSE_Z = {{size("KV_NUM_BLKS", 0)}} SPARSE_HQ = {{size("KV_NUM_BLKS", 1)}} @@ -338,7 +340,10 @@ def create_flex_decoding_kernel(*args, **kwargs): Bq, Hq, seq_len_q, qk_head_dim = query.get_size() Bkv, Hkv, seq_len_kv, v_head_dim = value.get_size() - assert Bq == Bkv, "Batch dimension must match" + + if not ((Bq == Bkv) or (Bq > 1 and Bkv == 1)): + raise RuntimeError(f"Bq and Bkv must broadcast. Got Bq={Bq} and Bkv={Bkv}") + B = Bq kernel_options = dict(kernel_options) From f2ca0c017bf2e570a91703841734508e4296b12b Mon Sep 17 00:00:00 2001 From: "xinan.lin" Date: Mon, 9 Sep 2024 23:40:32 -0700 Subject: [PATCH 0667/1018] [Inductor] Generalize `is_cuda` to specific device_type to make cpp_wrapper mode be extensible (#134693) Pull Request resolved: https://github.com/pytorch/pytorch/pull/134693 Approved by: https://github.com/ezyang, https://github.com/EikanWang, https://github.com/jansel --- test/inductor/test_aot_inductor.py | 2 +- torch/_inductor/autotune_process.py | 4 +- torch/_inductor/codecache.py | 60 +++++++++-------- torch/_inductor/codegen/cpp_wrapper_cpu.py | 10 ++- torch/_inductor/codegen/cpp_wrapper_cuda.py | 1 - torch/_inductor/cpp_builder.py | 75 ++++++++++++--------- torch/_inductor/graph.py | 29 ++++---- torch/utils/cpp_extension.py | 68 ++++++++++++++++--- 8 files changed, 156 insertions(+), 93 deletions(-) diff --git a/test/inductor/test_aot_inductor.py b/test/inductor/test_aot_inductor.py index cfac539c63b879..4b270a3bb21bda 100644 --- a/test/inductor/test_aot_inductor.py +++ b/test/inductor/test_aot_inductor.py @@ -1745,7 +1745,7 @@ def forward(self, x, y): torch._export.aot_compile(Model(), example_inputs) supported_dtype_of_cpp_wrapper_mock.assert_called_once_with( - torch.float32, self.device == "cuda" + torch.float32, self.device ) def test_consecutive_compiles(self): diff --git a/torch/_inductor/autotune_process.py b/torch/_inductor/autotune_process.py index eea4f8d6573d82..94efbaf4e32fdd 100644 --- a/torch/_inductor/autotune_process.py +++ b/torch/_inductor/autotune_process.py @@ -825,14 +825,14 @@ def precompile(self): # Prepopulate CppCodeCache # may happen in separate Threadpool log.debug("Precompiling %s", self) - CppCodeCache.load(self.source_code, cuda=False) + CppCodeCache.load(self.source_code, device_type="cpu") log.debug("Done precompiling %s", self) def make_run_fn( self, *input_tensors: torch.Tensor, output_tensor: torch.Tensor ) -> Callable[[], None]: # TODO(jgong5): use CppPythonBindingsCodeCache for better binding perf - self.DLL = CppCodeCache.load(self.source_code, cuda=False) + self.DLL = CppCodeCache.load(self.source_code, device_type="cpu") args = [tensor.data_ptr() for tensor in list(input_tensors) + [output_tensor]] log.debug( "make_run_fn: self.kernel_name=%s, self.DLL=%s, args=%s, self.extra_args=%s", diff --git a/torch/_inductor/codecache.py b/torch/_inductor/codecache.py index 59cc47ac06c94e..6d77ea959e5f36 100644 --- a/torch/_inductor/codecache.py +++ b/torch/_inductor/codecache.py @@ -83,7 +83,7 @@ _transform_cuda_paths, CppBuilder, CppOptions, - CppTorchCudaOptions, + CppTorchDeviceOptions, get_compiler_version_info, get_cpp_compiler, get_name_and_dir_from_output_file_path, @@ -1512,7 +1512,6 @@ def set(cls, key: str, params: Dict[str, str], cubin: str, bin_type: str) -> Non config.aot_inductor.output_path )[0], ) - params[get_cpp_wrapper_cubin_path_name()] = path cls.cache[key] = params @@ -1533,7 +1532,7 @@ def compile( graph: GraphLowering, source_code: str, serialized_extern_kernel_nodes: Optional[str], - cuda: bool, + device_type: str, ) -> str: if sys.platform == "win32": raise RuntimeError("AotCodeCompiler not yet supported for inductor") @@ -1544,9 +1543,9 @@ def compile( vec_isa_cmd_gen = CppBuilder( name="o", sources="i", - BuildOption=CppTorchCudaOptions( + BuildOption=CppTorchDeviceOptions( vec_isa=picked_vec_isa, - cuda=cuda, + device_type=device_type, aot_mode=graph.aot_mode, ), ) @@ -1561,7 +1560,7 @@ def compile( use_absolute_path = False if config.is_fbcode(): ld_command = build_paths.ld() - if not cuda and graph.aot_mode: # Meta internal AOTInductor CPU + if device_type == "cpu" and graph.aot_mode: # Meta internal AOTInductor CPU objcopy_command = build_paths.objcopy_fallback() fbcode_aot_cpu_re = True use_absolute_path = True @@ -1761,9 +1760,9 @@ def get_nbytes_of_tensor(tensor: torch.Tensor, all_cuda: bool) -> int: object_output_name, object_output_dir, ) = get_name_and_dir_from_output_file_path(input_path) - object_build_options = CppTorchCudaOptions( + object_build_options = CppTorchDeviceOptions( vec_isa=picked_vec_isa, - cuda=cuda, + device_type=device_type, aot_mode=graph.aot_mode, compile_only=True, use_absolute_path=use_absolute_path, @@ -1786,9 +1785,9 @@ def get_nbytes_of_tensor(tensor: torch.Tensor, all_cuda: bool) -> int: object_output_name, object_output_dir, ) = get_name_and_dir_from_output_file_path(input_path) - object_build_options = CppTorchCudaOptions( + object_build_options = CppTorchDeviceOptions( vec_isa=picked_vec_isa, - cuda=cuda, + device_type=device_type, aot_mode=graph.aot_mode, compile_only=True, use_absolute_path=use_absolute_path, @@ -1864,9 +1863,9 @@ def _pad_to_alignment(raw_bytes: bytes) -> bytes: output_name, output_dir = get_name_and_dir_from_output_file_path( output_so ) - so_build_options = CppTorchCudaOptions( + so_build_options = CppTorchDeviceOptions( vec_isa=picked_vec_isa, - cuda=cuda, + device_type=device_type, aot_mode=graph.aot_mode, use_absolute_path=use_absolute_path, ) @@ -1898,9 +1897,9 @@ def _pad_to_alignment(raw_bytes: bytes) -> bytes: output_name, output_dir = get_name_and_dir_from_output_file_path( output_so ) - so_build_options = CppTorchCudaOptions( + so_build_options = CppTorchDeviceOptions( vec_isa=picked_vec_isa, - cuda=cuda, + device_type=device_type, aot_mode=graph.aot_mode, use_absolute_path=use_absolute_path, ) @@ -2111,13 +2110,13 @@ def _load_library(cls, path: str, key: str) -> Union[CDLL, ModuleType]: def load_async( cls, source_code: str, - cuda: bool = False, + device_type: str = "cpu", submit_fn: Any = None, extra_flags: Sequence[str] = (), ) -> Any: compile_command = { **cls.cpp_compile_command_flags, - "cuda": cuda, + "device_type": device_type, "vec_isa": pick_vec_isa(), "extra_flags": extra_flags, } @@ -2125,7 +2124,7 @@ def load_async( _set_gpu_runtime_env() # cpp_extension consults the env command_gen = CppBuilder( - name="o", sources="i", BuildOption=CppTorchCudaOptions(**compile_command) + name="o", sources="i", BuildOption=CppTorchDeviceOptions(**compile_command) ) # write function will calc source_code hash, the same source code with different # ISA level should be generate different hash. @@ -2148,7 +2147,7 @@ def load_async( future: Optional[Future[Any]] = None lib = None - cpp_build_option = CppTorchCudaOptions(**compile_command) + cpp_build_option = CppTorchDeviceOptions(**compile_command) cpp_builder = CppBuilder( name=output_name, sources=input_path, @@ -2191,8 +2190,8 @@ def load_fn() -> Any: return cls.cache[key] @classmethod - def load(cls, source_code: str, cuda: bool = False) -> Any: - return cls.load_async(source_code, cuda)() + def load(cls, source_code: str, device_type: str = "cpu") -> Any: + return cls.load_async(source_code, device_type)() def _worker_compile_cpp( @@ -2335,7 +2334,7 @@ def load_pybinding_async( cls, argtypes: List[str], source_code: str, - cuda: bool = False, + device_type: str = "cpu", num_outputs: int = -1, submit_fn: Any = None, extra_flags: Sequence[str] = (), @@ -2367,7 +2366,10 @@ def load_pybinding_async( cls.entry_function, ) get_result = cls.load_async( - source_code + suffix, cuda, submit_fn=submit_fn, extra_flags=extra_flags + source_code + suffix, + device_type, + submit_fn=submit_fn, + extra_flags=extra_flags, ) result = None @@ -2718,7 +2720,7 @@ def generate_halide_async( cls._codegen_glue(meta, headerfile), extra_flags=(libfile, cls.build_standalone_runtime()), submit_fn=jobs.append if need_compile else None, - cuda=meta.is_cuda(), + device_type="cuda" if meta.is_cuda() else "cpu", ) if need_compile: @@ -2746,9 +2748,9 @@ def build_standalone_runtime(cls) -> str: cls._standalone_runtime_path ): return cls._standalone_runtime_path - is_cuda = torch.cuda.is_available() + device_type = "cuda" if torch.cuda.is_available() else "cpu" libname = "libStandaloneHalideRuntime.so" - target = "host-cuda" if is_cuda else "host" + target = "host-cuda" if device_type == "cuda" else "host" if cls._standalone_runtime_path: assert not os.path.exists(cls._standalone_runtime_path) # We hit this case in unittests when we run with fresh_inductor_cache() @@ -2772,7 +2774,7 @@ def build_standalone_runtime(cls) -> str: with filelock.FileLock(lockfile, LOCK_TIMEOUT): if not os.path.exists(donefile): with open(hookfile, "w") as f: - if is_cuda: + if device_type == "cuda": f.write( cls.standalone_runtime_cuda_init.format( cls.find_header("HalideRuntimeCuda.h") @@ -2785,8 +2787,8 @@ def build_standalone_runtime(cls) -> str: name=name, sources=[hookfile, afile], output_dir=output_dir, - BuildOption=CppTorchCudaOptions( - cuda=is_cuda, + BuildOption=CppTorchDeviceOptions( + device_type=device_type, ), ) @@ -2958,7 +2960,7 @@ def _cuda_lib_options() -> List[str]: _set_gpu_runtime_env() # cpp_extension consults the env from torch.utils import cpp_extension - lpaths = cpp_extension.library_paths(cuda=True) + [ + lpaths = cpp_extension.library_paths(device_type="cuda") + [ sysconfig.get_config_var("LIBDIR") ] extra_ldflags: List[str] = [] diff --git a/torch/_inductor/codegen/cpp_wrapper_cpu.py b/torch/_inductor/codegen/cpp_wrapper_cpu.py index 07d84ac249d5eb..fe178978d53be4 100644 --- a/torch/_inductor/codegen/cpp_wrapper_cpu.py +++ b/torch/_inductor/codegen/cpp_wrapper_cpu.py @@ -50,7 +50,6 @@ def __init__(self): self.extern_call_ops = set() self.size = "sizes()" self.stride = "strides()" - self.cuda = False self.supports_intermediate_hooks = False self.outputs_need_copy = set() self.kernel_callsite_id = count() @@ -1137,7 +1136,7 @@ def generate_end(self, result): result.splice( f""" inductor_entry = CppWrapperCodeCache.load_pybinding( - ["std::vector"], cpp_wrapper_src, {self.cuda}, {len(V.graph.graph_outputs)}) + ["std::vector"], cpp_wrapper_src, "{self.device}", {len(V.graph.graph_outputs)}) """ ) @@ -1647,6 +1646,11 @@ def make_allocation( f"at::Tensor {name} = at::detail::empty_strided_cuda(" f"{size}, {stride}, {dtype_code}, c10::DeviceType::CUDA);" ) + if device.type == "xpu": + return ( + f"at::Tensor {name} = at::detail::empty_strided_xpu(" + f"{size}, {stride}, {dtype_code}, c10::DeviceType::XPU);" + ) return ( f"{self.declare}{name} = {self.namespace}empty_strided(" f"{size}, {stride}, at::TensorOptions({tensor_device}).dtype({dtype_code})){self.ending}" @@ -1735,7 +1739,7 @@ def create_dtypeview_call(reinterpret_call): ) call_strs = [f"AtenTensorHandle {tmp_AtenTensorHandle};"] dtype_name = str(dtype).split(".")[-1] - device_name = "cuda" if data.layout.device.type == "cuda" else "cpu" + device_name = data.layout.device.type get_dtype_function = f"aoti_torch_dtype_{dtype_name}" dtypeview_function = f"aoti_torch_{device_name}_view_dtype" call_strs.append( diff --git a/torch/_inductor/codegen/cpp_wrapper_cuda.py b/torch/_inductor/codegen/cpp_wrapper_cuda.py index b82a707961df9e..89c28bab380949 100644 --- a/torch/_inductor/codegen/cpp_wrapper_cuda.py +++ b/torch/_inductor/codegen/cpp_wrapper_cuda.py @@ -168,7 +168,6 @@ def __init__(self) -> None: self.device = "cuda" super().__init__() self.grid_id = count() - self.cuda = True def write_header(self): if V.graph.is_const_graph: diff --git a/torch/_inductor/cpp_builder.py b/torch/_inductor/cpp_builder.py index 95a0bff86fd8a4..e6bc67b5289b92 100644 --- a/torch/_inductor/cpp_builder.py +++ b/torch/_inductor/cpp_builder.py @@ -1166,8 +1166,8 @@ def _transform_cuda_paths(lpaths: List[str]) -> None: break -def get_cpp_torch_cuda_options( - cuda: bool, +def get_cpp_torch_device_options( + device_type: str, aot_mode: bool = False, compile_only: bool = False, ) -> Tuple[List[str], List[str], List[str], List[str], List[str], List[str], List[str]]: @@ -1190,10 +1190,10 @@ def get_cpp_torch_cuda_options( _set_gpu_runtime_env() from torch.utils import cpp_extension - include_dirs = cpp_extension.include_paths(cuda) - libraries_dirs = cpp_extension.library_paths(cuda) + include_dirs = cpp_extension.include_paths(device_type) + libraries_dirs = cpp_extension.library_paths(device_type) - if cuda: + if device_type == "cuda": definations.append(" USE_ROCM" if torch.version.hip else " USE_CUDA") if torch.version.hip is not None: @@ -1208,6 +1208,11 @@ def get_cpp_torch_cuda_options( else: libraries += ["c10_cuda", "cuda", "torch_cuda"] + if device_type == "xpu": + definations.append(" USE_XPU") + cflags += ["fsycl"] + libraries += ["c10_xpu", "sycl", "ze_loader", "torch_xpu"] + if aot_mode: if config.is_fbcode(): from torch._inductor.codecache import cpp_prefix_path @@ -1215,7 +1220,7 @@ def get_cpp_torch_cuda_options( cpp_prefix_include_dir = [f"{os.path.dirname(cpp_prefix_path())}"] include_dirs += cpp_prefix_include_dir - if cuda and torch.version.hip is None: + if device_type == "cuda" and torch.version.hip is None: _transform_cuda_paths(libraries_dirs) if config.is_fbcode(): @@ -1224,7 +1229,7 @@ def get_cpp_torch_cuda_options( else: include_dirs.append(os.path.join(build_paths.cuda(), "include")) - if aot_mode and cuda: + if aot_mode and device_type == "cuda": if torch.version.hip is None: if not compile_only: # Only add link args, when compile_only is false. @@ -1241,18 +1246,18 @@ def get_cpp_torch_cuda_options( ) -class CppTorchCudaOptions(CppTorchOptions): +class CppTorchDeviceOptions(CppTorchOptions): """ This class is inherited from CppTorchOptions, which automatic contains base cxx build options and torch common build options. And then it will - maintains cuda device related build args. + maintains cuda/xpu device related build args. """ def __init__( self, vec_isa: VecISA = invalid_vec_isa, include_pytorch: bool = False, - cuda: bool = True, + device_type: str = "cuda", aot_mode: bool = False, compile_only: bool = False, use_absolute_path: bool = False, @@ -1269,33 +1274,37 @@ def __init__( use_mmap_weights=use_mmap_weights, extra_flags=extra_flags, ) + if device_type == "xpu": + from torch.utils.cpp_extension import _join_sycl_home + + self._compiler = _join_sycl_home("bin", "icpx") - cuda_definations: List[str] = [] - cuda_include_dirs: List[str] = [] - cuda_cflags: List[str] = [] - cuda_ldflags: List[str] = [] - cuda_libraries_dirs: List[str] = [] - cuda_libraries: List[str] = [] - cuda_passthough_args: List[str] = [] + device_definations: List[str] = [] + device_include_dirs: List[str] = [] + device_cflags: List[str] = [] + device_ldflags: List[str] = [] + device_libraries_dirs: List[str] = [] + device_libraries: List[str] = [] + device_passthough_args: List[str] = [] ( - cuda_definations, - cuda_include_dirs, - cuda_cflags, - cuda_ldflags, - cuda_libraries_dirs, - cuda_libraries, - cuda_passthough_args, - ) = get_cpp_torch_cuda_options( - cuda=cuda, aot_mode=aot_mode, compile_only=compile_only + device_definations, + device_include_dirs, + device_cflags, + device_ldflags, + device_libraries_dirs, + device_libraries, + device_passthough_args, + ) = get_cpp_torch_device_options( + device_type=device_type, aot_mode=aot_mode, compile_only=compile_only ) - _append_list(self._definations, cuda_definations) - _append_list(self._include_dirs, cuda_include_dirs) - _append_list(self._cflags, cuda_cflags) - _append_list(self._ldflags, cuda_ldflags) - _append_list(self._libraries_dirs, cuda_libraries_dirs) - _append_list(self._libraries, cuda_libraries) - _append_list(self._passthough_args, cuda_passthough_args) + _append_list(self._definations, device_definations) + _append_list(self._include_dirs, device_include_dirs) + _append_list(self._cflags, device_cflags) + _append_list(self._ldflags, device_ldflags) + _append_list(self._libraries_dirs, device_libraries_dirs) + _append_list(self._libraries, device_libraries) + _append_list(self._passthough_args, device_passthough_args) self._finalize_options() diff --git a/torch/_inductor/graph.py b/torch/_inductor/graph.py index 6a4f8ce741d6f4..26a897bf806ed4 100644 --- a/torch/_inductor/graph.py +++ b/torch/_inductor/graph.py @@ -124,7 +124,7 @@ def log_module_code(*args: Any, **kwargs: Any) -> None: pass -def supported_dtype_of_cpp_wrapper(dtype: torch.device, cuda: bool) -> bool: +def supported_dtype_of_cpp_wrapper(dtype: torch.device, device_type: str) -> bool: supported_dtype = { torch.float32, torch.float64, @@ -140,7 +140,7 @@ def supported_dtype_of_cpp_wrapper(dtype: torch.device, cuda: bool) -> bool: torch.complex128, torch.float16, } - if cuda: + if device_type == "cuda": supported_dtype.add(torch.float8_e4m3fn) supported_dtype.add(torch.float8_e5m2) supported_dtype.add(torch.float8_e4m3fnuz) @@ -365,7 +365,7 @@ def __init__( self.device_idxs: OrderedSet[int] = ( const_module.device_idxs if const_module else OrderedSet() ) - self.cuda = False + self.device_type = "cpu" self.buffers: List[ir.Buffer] = [] self.operations: List[ir.Operation] = [] self.const_output_index: Dict[str, int] = ( @@ -1636,14 +1636,10 @@ def validate_can_generate_cpp_wrapper(self) -> None: ): dtype = may_get_constant_buffer_dtype(value) - if not supported_dtype_of_cpp_wrapper(dtype, self.cuda): + if not supported_dtype_of_cpp_wrapper(dtype, self.device_type): raise CppWrapperCodeGenError(f"Unsupported input dtype {dtype}") def init_wrapper_code(self) -> None: - self.cuda = "cuda" in self.device_types - if self.cpp_wrapper: - self.validate_can_generate_cpp_wrapper() - device_types = self.device_types.copy() device_types.discard("cpu") device_types.discard("meta") @@ -1652,13 +1648,18 @@ def init_wrapper_code(self) -> None: "+".join(device_types) ) only_cpu = len(device_types) == 0 - device_type = "cpu" if only_cpu else device_types.pop() + self.device_type = "cpu" if only_cpu else device_types.pop() + + if self.cpp_wrapper: + self.validate_can_generate_cpp_wrapper() - self.device_ops = get_device_op_overrides(device_type) + self.device_ops = get_device_op_overrides(self.device_type) wrapper_code_gen_cls = get_wrapper_codegen_for_device( - device_type, self.cpp_wrapper + self.device_type, self.cpp_wrapper ) - assert wrapper_code_gen_cls is not None, f"Device {device_type} not supported" + assert ( + wrapper_code_gen_cls is not None + ), f"Device {self.device_type} not supported" self.wrapper_code = wrapper_code_gen_cls() if self.const_module: @@ -1676,7 +1677,7 @@ def codegen_with_cpp_wrapper(self) -> Tuple[str, List[Tuple[int, Node]]]: wrapper code and run it to generate autotuned kernel binaries in the first pass; and then generate cpp wrapper code and compile it to a dynamic library in the second pass. """ - if "cuda" in self.device_types: + if any(device in self.device_types for device in ["cuda", "xpu"]): # first pass self.cpp_wrapper = False # Although triton.store_cubin was OrderedSet in compile_fx, the backward pass didn't pick @@ -1906,7 +1907,7 @@ def compile_to_fn(self) -> Any: # Directly return the file path with the compiled code return AotCodeCompiler.compile( - self, code, serialized_extern_kernel_nodes, cuda=self.cuda + self, code, serialized_extern_kernel_nodes, device_type=self.device_type ) else: return self.compile_to_module().call diff --git a/torch/utils/cpp_extension.py b/torch/utils/cpp_extension.py index aaa45ea4c909a3..4c7366193cc5c8 100644 --- a/torch/utils/cpp_extension.py +++ b/torch/utils/cpp_extension.py @@ -140,6 +140,21 @@ def _find_rocm_home() -> Optional[str]: file=sys.stderr) return rocm_home +def _find_sycl_home() -> Optional[str]: + """Find the OneAPI install path.""" + # Guess #1 + sycl_home = os.environ.get('ONEAPI_ROOT') + if sycl_home is None: + # Guess #2 + icpx_path = shutil.which('icpx') + if icpx_path is not None: + sycl_home = os.path.dirname(os.path.dirname( + os.path.realpath(icpx_path))) + + if sycl_home and not torch.xpu.is_available(): + print(f"No XPU runtime is found, using ONEAPI_ROOT='{sycl_home}'", + file=sys.stderr) + return sycl_home def _join_rocm_home(*paths) -> str: """ @@ -156,6 +171,20 @@ def _join_rocm_home(*paths) -> str: 'ROCm and Windows is not supported.') return os.path.join(ROCM_HOME, *paths) +def _join_sycl_home(*paths) -> str: + """ + Join paths with SYCL_HOME, or raises an error if it SYCL_HOME is not set. + + This is basically a lazy way of raising an error for missing $SYCL_HOME + only once we need to get any SYCL-specific path. + """ + if SYCL_HOME is None: + raise OSError('SYCL_HOME environment variable is not set. ' + 'Please set it to your OneAPI install root.') + + return os.path.join(SYCL_HOME, *paths) + + ABI_INCOMPATIBILITY_WARNING = ''' @@ -207,6 +236,8 @@ def _join_rocm_home(*paths) -> str: CUDA_HOME = _find_cuda_home() if torch.cuda._is_compiled() else None CUDNN_HOME = os.environ.get('CUDNN_HOME') or os.environ.get('CUDNN_PATH') +SYCL_HOME = _find_sycl_home() if torch.xpu._is_compiled() else None + # PyTorch releases have the version pattern major.minor.patch, whereas when # PyTorch is built from source, we append the git commit hash, which gives # it the below pattern. @@ -1075,7 +1106,7 @@ def CUDAExtension(name, sources, *args, **kwargs): ... 'nvcc': ['-O2', '-rdc=true']}) """ library_dirs = kwargs.get('library_dirs', []) - library_dirs += library_paths(cuda=True) + library_dirs += library_paths(device_type="cuda") kwargs['library_dirs'] = library_dirs libraries = kwargs.get('libraries', []) @@ -1119,7 +1150,7 @@ def CUDAExtension(name, sources, *args, **kwargs): sources = list(hipified_sources) - include_dirs += include_paths(cuda=True) + include_dirs += include_paths(device_type="cuda") kwargs['include_dirs'] = include_dirs kwargs['language'] = 'c++' @@ -1144,9 +1175,9 @@ def CUDAExtension(name, sources, *args, **kwargs): return setuptools.Extension(name, sources, *args, **kwargs) -def include_paths(cuda: bool = False) -> List[str]: +def include_paths(device_type: str = "cpu") -> List[str]: """ - Get the include paths required to build a C++ or CUDA extension. + Get the include paths required to build a C++ or CUDA or SYCL extension. Args: cuda: If `True`, includes CUDA-specific include paths. @@ -1164,10 +1195,10 @@ def include_paths(cuda: bool = False) -> List[str]: os.path.join(lib_include, 'TH'), os.path.join(lib_include, 'THC') ] - if cuda and IS_HIP_EXTENSION: + if device_type == "cuda" and IS_HIP_EXTENSION: paths.append(os.path.join(lib_include, 'THH')) paths.append(_join_rocm_home('include')) - elif cuda: + elif device_type == "cuda": cuda_home_include = _join_cuda_home('include') # if we have the Debian/Ubuntu packages for cuda, we get /usr as cuda home. # but gcc doesn't like having /usr/include passed explicitly @@ -1180,10 +1211,12 @@ def include_paths(cuda: bool = False) -> List[str]: paths.append(cuda_inc_path) if CUDNN_HOME is not None: paths.append(os.path.join(CUDNN_HOME, 'include')) + elif device_type == "xpu": + paths.append(_join_sycl_home('include')) return paths -def library_paths(cuda: bool = False) -> List[str]: +def library_paths(device_type: str = "cpu") -> List[str]: """ Get the library paths required to build a C++ or CUDA extension. @@ -1196,12 +1229,12 @@ def library_paths(cuda: bool = False) -> List[str]: # We need to link against libtorch.so paths = [TORCH_LIB_PATH] - if cuda and IS_HIP_EXTENSION: + if device_type == "cuda" and IS_HIP_EXTENSION: lib_dir = 'lib' paths.append(_join_rocm_home(lib_dir)) if HIP_HOME is not None: paths.append(os.path.join(HIP_HOME, 'lib')) - elif cuda: + elif device_type == "cuda": if IS_WINDOWS: lib_dir = os.path.join('lib', 'x64') else: @@ -1216,6 +1249,17 @@ def library_paths(cuda: bool = False) -> List[str]: paths.append(_join_cuda_home(lib_dir)) if CUDNN_HOME is not None: paths.append(os.path.join(CUDNN_HOME, lib_dir)) + elif device_type == "xpu": + if IS_WINDOWS: + lib_dir = os.path.join('lib', 'x64') + else: + lib_dir = 'lib64' + if (not os.path.exists(_join_sycl_home(lib_dir)) and + os.path.exists(_join_sycl_home('lib'))): + lib_dir = 'lib' + + paths.append(_join_sycl_home(lib_dir)) + return paths @@ -2165,7 +2209,11 @@ def _write_ninja_file_to_build_library(path, user_includes = [os.path.abspath(file) for file in extra_include_paths] # include_paths() gives us the location of torch/extension.h - system_includes = include_paths(with_cuda) + # TODO generalize with_cuda as specific device type. + if with_cuda: + system_includes = include_paths("cuda") + else: + system_includes = include_paths("cpu") # sysconfig.get_path('include') gives us the location of Python.h # Explicitly specify 'posix_prefix' scheme on non-Windows platforms to workaround error on some MacOS # installations where default `get_path` points to non-existing `/Library/Python/M.m/include` folder From f9e74f33c3416c9aa80673e9115f2d80f294b579 Mon Sep 17 00:00:00 2001 From: "xinan.lin" Date: Mon, 9 Sep 2024 23:40:33 -0700 Subject: [PATCH 0668/1018] [Inductor] Generalize device guard codegen for cpp_wrapper mode. (#134761) Pull Request resolved: https://github.com/pytorch/pytorch/pull/134761 Approved by: https://github.com/jansel, https://github.com/EikanWang ghstack dependencies: #134693 --- torch/_inductor/codegen/common.py | 21 ++++++++++++++++++- .../codegen/cuda/device_op_overrides.py | 15 +++++++++++++ torch/_inductor/codegen/wrapper.py | 10 ++++----- .../codegen/xpu/device_op_overrides.py | 15 +++++++++++++ 4 files changed, 55 insertions(+), 6 deletions(-) diff --git a/torch/_inductor/codegen/common.py b/torch/_inductor/codegen/common.py index f80d1f71e380d0..82fb318f2ca814 100644 --- a/torch/_inductor/codegen/common.py +++ b/torch/_inductor/codegen/common.py @@ -109,6 +109,21 @@ def synchronize(self): def device_guard(self, device_idx): raise NotImplementedError + def cpp_device_guard(self): + raise NotImplementedError + + def cpp_aoti_device_guard(self): + raise NotImplementedError + + def cpp_stream_guard(self): + raise NotImplementedError + + def cpp_aoti_stream_guard(self): + raise NotImplementedError + + def cpp_getStreamFromExternal(self): + raise NotImplementedError + device_op_overrides_dict: Dict[str, DeviceOpOverrides] = {} @@ -222,7 +237,11 @@ def init_backend_registration(): ) if get_scheduling_for_device("xpu") is None: - register_backend_for_device("xpu", TritonScheduling, WrapperCodeGen) + register_backend_for_device( + "xpu", + TritonScheduling, + WrapperCodeGen, + ) private_backend = torch._C._get_privateuse1_backend_name() if ( diff --git a/torch/_inductor/codegen/cuda/device_op_overrides.py b/torch/_inductor/codegen/cuda/device_op_overrides.py index 7ff99b871c82d6..03e760ddc3da86 100644 --- a/torch/_inductor/codegen/cuda/device_op_overrides.py +++ b/torch/_inductor/codegen/cuda/device_op_overrides.py @@ -15,5 +15,20 @@ def synchronize(self): def device_guard(self, device_idx): return f"torch.cuda._DeviceGuard({device_idx})" + def cpp_device_guard(self): + return "at::cuda::CUDAGuard" + + def cpp_aoti_device_guard(self): + return "AOTICudaGuard" + + def cpp_stream_guard(self): + return "at::cuda::CUDAStreamGuard" + + def cpp_aoti_stream_guard(self): + return "AOTICudaStreamGuard" + + def cpp_getStreamFromExternal(self): + return "at::cuda::getStreamFromExternal" + register_device_op_overrides("cuda", CUDADeviceOpOverrides()) diff --git a/torch/_inductor/codegen/wrapper.py b/torch/_inductor/codegen/wrapper.py index 500532d432b39f..f3a6d07647627f 100644 --- a/torch/_inductor/codegen/wrapper.py +++ b/torch/_inductor/codegen/wrapper.py @@ -310,13 +310,13 @@ def codegen(self, code: IndentedBuffer) -> None: if self.last_seen_device_guard_index is None: if config.abi_compatible: code.writeline( - "AOTICudaStreamGuard stream_guard(stream, this->device_idx_);" + f"{V.graph.device_ops.cpp_aoti_stream_guard()} stream_guard(stream, this->device_idx_);" ) else: code.writeline( maybe_hipify_code_wrapper( - "at::cuda::CUDAStreamGuard stream_guard(" - + "at::cuda::getStreamFromExternal(stream, this->device_idx_));" + f"{V.graph.device_ops.cpp_stream_guard()} stream_guard(" + + f"{V.graph.device_ops.cpp_getStreamFromExternal()}(stream, this->device_idx_));" ) ) else: @@ -326,10 +326,10 @@ def codegen(self, code: IndentedBuffer) -> None: else: if self.last_seen_device_guard_index is None: code.writeline( - f"AOTICudaGuard device_guard({self.device_idx});" + f"{V.graph.device_ops.cpp_aoti_device_guard()} device_guard({self.device_idx});" if config.abi_compatible else maybe_hipify_code_wrapper( - f"at::cuda::CUDAGuard device_guard({self.device_idx});" + f"{V.graph.device_ops.cpp_device_guard()} device_guard({self.device_idx});" ) ) else: diff --git a/torch/_inductor/codegen/xpu/device_op_overrides.py b/torch/_inductor/codegen/xpu/device_op_overrides.py index 6eec71344ae8cc..a0c9a952991ecb 100644 --- a/torch/_inductor/codegen/xpu/device_op_overrides.py +++ b/torch/_inductor/codegen/xpu/device_op_overrides.py @@ -15,5 +15,20 @@ def synchronize(self): def device_guard(self, device_idx): return f"torch.xpu._DeviceGuard({device_idx})" + def cpp_device_guard(self): + return "at::xpu::XPUGuard" + + def cpp_aoti_device_guard(self): + return "AOTIXpuGuard" + + def cpp_stream_guard(self): + return "at::xpu::XPUStreamGuard" + + def cpp_aoti_stream_guard(self): + return "AOTIXpuStreamGuard" + + def cpp_getStreamFromExternal(self): + return "at::xpu::getStreamFromExternal" + register_device_op_overrides("xpu", XPUDeviceOpOverrides()) From 6760029cda8c7e18e73e771e0e93decc1069b48b Mon Sep 17 00:00:00 2001 From: torotoki Date: Tue, 10 Sep 2024 11:37:57 +0000 Subject: [PATCH 0669/1018] Add dynamo itertools.pairwise support (#135416) Fixes #133766 Pull Request resolved: https://github.com/pytorch/pytorch/pull/135416 Approved by: https://github.com/XuehaiPan, https://github.com/jansel Co-authored-by: Xuehai Pan --- test/dynamo/test_functions.py | 11 +++++++++++ torch/_dynamo/polyfills/itertools.py | 18 ++++++++++++++++++ 2 files changed, 29 insertions(+) diff --git a/test/dynamo/test_functions.py b/test/dynamo/test_functions.py index c99948a902dc3b..d215f21fbb654c 100644 --- a/test/dynamo/test_functions.py +++ b/test/dynamo/test_functions.py @@ -315,6 +315,17 @@ def test_itertools_combinations(a, b): combs.append(torch.ones(size)) return combs + @unittest.skipIf( + sys.version_info < (3, 10), + "itertools.pairwise was added at Python 3.10", + ) + @make_test + def test_itertools_pairwise(a): + pairs = [] + for size in itertools.pairwise((1, 2, 3, 4)): + pairs.append(torch.ones(size)) + return pairs + @make_test def test_np_iinfo(a): max_dim = np.iinfo(np.int16).max diff --git a/torch/_dynamo/polyfills/itertools.py b/torch/_dynamo/polyfills/itertools.py index a829d2e9b88081..63266e85a2b8a1 100644 --- a/torch/_dynamo/polyfills/itertools.py +++ b/torch/_dynamo/polyfills/itertools.py @@ -5,6 +5,7 @@ from __future__ import annotations import itertools +import sys from typing import Iterable, Iterator, TypeVar from ..decorators import substitute_in_graph @@ -65,6 +66,23 @@ def islice(iterable: Iterable[_T], /, *args: int | None) -> Iterator[_T]: next_i += step +# Reference: https://docs.python.org/3/library/itertools.html#itertools.pairwise +if sys.version_info >= (3, 10): + + @substitute_in_graph(itertools.pairwise, is_embedded_type=True) # type: ignore[arg-type] + def pairwise(iterable: Iterable[_T], /) -> Iterator[tuple[_T, _T]]: + a = None + first = True + for b in iterable: + if first: + first = False + else: + yield a, b # type: ignore[misc] + a = b + + __all__ += ["pairwise"] + + # Reference: https://docs.python.org/3/library/itertools.html#itertools.tee @substitute_in_graph(itertools.tee) def tee(iterable: Iterable[_T], n: int = 2, /) -> tuple[Iterator[_T], ...]: From 43d25d6eca7234005134b5ccb68041a9e0cb9d21 Mon Sep 17 00:00:00 2001 From: "Edward Z. Yang" Date: Fri, 6 Sep 2024 12:02:56 -0700 Subject: [PATCH 0670/1018] Handle KeyError for compiler collective in scalars too (#135385) Signed-off-by: Edward Z. Yang Pull Request resolved: https://github.com/pytorch/pytorch/pull/135385 Approved by: https://github.com/jansel --- test/distributed/test_dynamo_distributed.py | 56 +++++++++++++++++++++ torch/_dynamo/variables/builder.py | 3 +- 2 files changed, 58 insertions(+), 1 deletion(-) diff --git a/test/distributed/test_dynamo_distributed.py b/test/distributed/test_dynamo_distributed.py index 52943afe984971..c689b8b321f6a1 100644 --- a/test/distributed/test_dynamo_distributed.py +++ b/test/distributed/test_dynamo_distributed.py @@ -959,6 +959,62 @@ def f(rank, xs): for r in res[1:]: self.assertEqual(res[0], r) + @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") + @config.patch(enable_compiler_collectives=True) + def test_compiler_collectives_scalar_missing_source(self): + with _dynamo_dist_per_rank_init(self.rank, self.world_size): + torch._dynamo.utils.clear_compilation_metrics() + + @torch.compile() + def f(rank, xs): + return torch.tensor(xs[rank], device=self.rank) + + xs = [] + for i in range(self.world_size): + xs.append(10 + i) + + f(self.rank, xs) + + metrics = torch._dynamo.utils.get_compilation_metrics() + res = [None] * self.world_size + torch.distributed.all_gather_object(res, len(metrics)) + for r in res[1:]: + self.assertEqual(res[0], r) + + @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") + @config.patch(enable_compiler_collectives=True) + def test_compiler_collectives_type_mismatch(self): + with _dynamo_dist_per_rank_init(self.rank, self.world_size): + torch._dynamo.utils.clear_compilation_metrics() + + @torch.compile() + def f(x): + if isinstance(x, int): + return torch.tensor(x, device=self.rank) + else: + return x.sum() + + if self.rank == 0: + x = torch.randn(10, device=self.rank) + else: + x = 12 + f(x) + + # This deadlocks, I guess we don't support this + """ + if self.rank == 0: + x = torch.randn(12, device=self.rank) + else: + x = 10 + f(x) + """ + + metrics = torch._dynamo.utils.get_compilation_metrics() + res = [None] * self.world_size + torch.distributed.all_gather_object(res, len(metrics)) + for r in res[1:]: + self.assertEqual(res[0], r) + @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") @patch.object(torch._inductor.config, "fx_graph_cache", False) @patch.object(torch._inductor.config, "fx_graph_remote_cache", False) diff --git a/torch/_dynamo/variables/builder.py b/torch/_dynamo/variables/builder.py index 5644f45e098ed1..f7ac4d073b534b 100644 --- a/torch/_dynamo/variables/builder.py +++ b/torch/_dynamo/variables/builder.py @@ -1778,7 +1778,8 @@ def update_frame_state(value): else: # Apply the updates for sub_state in st.all_states: - update_frame_state(sub_state.input_sizes[name]) + if name in sub_state.input_sizes: + update_frame_state(sub_state.input_sizes[name]) frame_state_entry = self.tx.output.frame_state[name] # TODO: This should be dynamic, as we in general do not From c8bb1af615b6fab96045005d9473fa256ecc1610 Mon Sep 17 00:00:00 2001 From: rzou Date: Mon, 9 Sep 2024 07:30:54 -0700 Subject: [PATCH 0671/1018] [inductor] Flip custom_op_default_layout_constraint (#135239) By default, Inductor should respect the stride order of input Tensors to custom operators. Test Plan: - new tests Pull Request resolved: https://github.com/pytorch/pytorch/pull/135239 Approved by: https://github.com/albanD ghstack dependencies: #135391 --- test/inductor/test_torchinductor.py | 1 + test/test_custom_ops.py | 2 +- torch/_inductor/config.py | 2 +- 3 files changed, 3 insertions(+), 2 deletions(-) diff --git a/test/inductor/test_torchinductor.py b/test/inductor/test_torchinductor.py index 5bc9a7050686d2..881c2cf6db7d65 100644 --- a/test/inductor/test_torchinductor.py +++ b/test/inductor/test_torchinductor.py @@ -10627,6 +10627,7 @@ def test_mutable_custom_op_fixed_layout2(self): lib.define( "bar(Tensor x, bool is_compiling) -> Tensor", + tags=torch.Tag.flexible_layout, ) bar_strides = [] diff --git a/test/test_custom_ops.py b/test/test_custom_ops.py index 05afe8fec38a38..524095f741a70c 100644 --- a/test/test_custom_ops.py +++ b/test/test_custom_ops.py @@ -3436,7 +3436,7 @@ def test_layout_constraint_tags(self): ({needs_fixed_stride_order}, needs_fixed_stride_order), ({flexible_layout}, flexible_layout), # If no tags are provided, then the following is the default - (set(), flexible_layout), + (set(), needs_fixed_stride_order), # If multiple tags are provided, then we use the most constrained tag. ({flexible_layout, needs_fixed_stride_order}, needs_fixed_stride_order), ] diff --git a/torch/_inductor/config.py b/torch/_inductor/config.py index 28664d54f4e58a..fb089e03dad2f2 100644 --- a/torch/_inductor/config.py +++ b/torch/_inductor/config.py @@ -70,7 +70,7 @@ def autotune_remote_cache_default() -> Optional[bool]: # (that is, one of {"needs_fixed_stride_order", "flexible_layout"}), # If the custom op does not have a layout constraint tag already # then we assume the following applies. -custom_op_default_layout_constraint = "flexible_layout" +custom_op_default_layout_constraint = "needs_fixed_stride_order" # use cpp wrapper instead of python wrapper cpp_wrapper = os.environ.get("TORCHINDUCTOR_CPP_WRAPPER", "0") == "1" From f8a1e92ef33f6271963ab31b27b672780c001715 Mon Sep 17 00:00:00 2001 From: cyy Date: Tue, 10 Sep 2024 16:05:19 +0000 Subject: [PATCH 0672/1018] Check function declarations of Core ML code (#135467) Relax the restrictions. Pull Request resolved: https://github.com/pytorch/pytorch/pull/135467 Approved by: https://github.com/ezyang --- caffe2/CMakeLists.txt | 2 +- torch/csrc/jit/backends/coreml/cpp/context.cpp | 13 ++++--------- torch/csrc/jit/backends/coreml/cpp/context.h | 11 ++--------- 3 files changed, 7 insertions(+), 19 deletions(-) diff --git a/caffe2/CMakeLists.txt b/caffe2/CMakeLists.txt index 2160399a3ea296..e498f2380fbf2d 100644 --- a/caffe2/CMakeLists.txt +++ b/caffe2/CMakeLists.txt @@ -767,7 +767,7 @@ if(NOT MSVC) set_source_files_properties(${PROJECT_SOURCE_DIR}/torch/csrc/distributed/c10d/socket.cpp PROPERTIES COMPILE_OPTIONS "-Wno-error=deprecated") endif() -if("${CMAKE_CXX_COMPILER_ID}" MATCHES "Clang" AND NOT USE_IOS AND NOT USE_COREML_DELEGATE) +if("${CMAKE_CXX_COMPILER_ID}" MATCHES "Clang") target_compile_options_if_supported(torch_cpu "-Wmissing-prototypes") target_compile_options_if_supported(torch_cpu "-Werror=missing-prototypes") get_target_property(TORCH_CPU_SOURCES torch_cpu SOURCES) diff --git a/torch/csrc/jit/backends/coreml/cpp/context.cpp b/torch/csrc/jit/backends/coreml/cpp/context.cpp index 3c63acce711343..a8c385eef525a8 100644 --- a/torch/csrc/jit/backends/coreml/cpp/context.cpp +++ b/torch/csrc/jit/backends/coreml/cpp/context.cpp @@ -1,10 +1,8 @@ #include #include +#include -namespace torch { -namespace jit { -namespace mobile { -namespace coreml { +namespace torch::jit::mobile::coreml { std::atomic g_coreml_ctx_registry; @@ -15,11 +13,8 @@ BackendRegistrar::BackendRegistrar(ContextInterface* ctx) { void setModelCacheDirectory(std::string path) { auto p = g_coreml_ctx_registry.load(); if (p) { - p->setModelCacheDirectory(path); + p->setModelCacheDirectory(std::move(path)); } } -} // namespace coreml -} // namespace mobile -} // namespace jit -} // namespace torch +} // namespace torch::jit::mobile::coreml diff --git a/torch/csrc/jit/backends/coreml/cpp/context.h b/torch/csrc/jit/backends/coreml/cpp/context.h index 644e3428af3e1d..a07a8b81fc7d28 100644 --- a/torch/csrc/jit/backends/coreml/cpp/context.h +++ b/torch/csrc/jit/backends/coreml/cpp/context.h @@ -1,13 +1,9 @@ #ifndef PTM_COREML_Context_h #define PTM_COREML_Context_h -#include #include -namespace torch { -namespace jit { -namespace mobile { -namespace coreml { +namespace torch::jit::mobile::coreml { struct ContextInterface { virtual ~ContextInterface() = default; @@ -21,9 +17,6 @@ class BackendRegistrar { void setModelCacheDirectory(std::string path); -} // namespace coreml -} // namespace mobile -} // namespace jit -} // namespace torch +} // namespace torch::jit::mobile::coreml #endif From f9788a642fd16bc4d6d6f98679b8ebeaadd8af17 Mon Sep 17 00:00:00 2001 From: Zixi Qi Date: Tue, 10 Sep 2024 17:20:59 +0000 Subject: [PATCH 0673/1018] Add more logging to TunableOp validators (#135396) Summary: Add more logging to TunableOp validators Test Plan: Verified additional logging when loading kernel selections: ``` ROCBLAS_VERSION validation: expect 4.0.0-72e57364-dirty to match 4.0.0-72e57364-dirty GCN_ARCH_NAME validation: expect gfx942:sramecc+:xnack- to match gfx942:sramecc+:xnack- HIPBLASLT_VERSION validation: expect 800-a15e4178 to match 800-a15e4178 ROCM_VERSION validation: expect 6.0.0.0-12969-1544e39 to match 6.0.0.0-12969-1544e39 PT_VERSION validation: expect 2.5.0 to match 2.5.0 ``` ``` [qizixi@devgpu039.atn3 /data/users/qizixi/fbsource/fbcode (f9305317d|remote/master)]$ PYTORCH_TUNABLEOP_VERBOSE=1 buck2 run mode/{opt,amd-gpu} -c fbcode.e nable_gpu_sections=true //scripts/xdwang/example:fc_llama -- --enable-tuning File changed: fbcode//hipblas_tuning_pt_llama0.csv Buck UI: https://www.internalfb.com/buck2/1ed2fac4-743e-49ef-805f-7fb6b9300022 Network: Up: 0B Down: 0B Jobs completed: 4189. Time elapsed: 0.2s. BUILD SUCCEEDED Enabled tuning - Run Linear (matmul) 2 x 1280 x 8192, dtype = torch.bfloat16 INFO:2024-09-06 14:38:07 2834864:2835138 CuptiActivityProfiler.cpp:260] HIP versions. Roctracer: 4.1; Runtime: 60032830; Driver: 60032830 INFO:2024-09-06 14:38:07 2834864:2836083 DynoConfigLoader.cpp:61] Setting communication fabric enabled = 0 reading tuning results from hipblas_tuning_pt_llama0.csv Validator PT_VERSION=2.5.0 Validator ROCM_VERSION=6.0.0.0-12969-1544e39 Validator HIPBLASLT_VERSION=800-a15e4178 Validator GCN_ARCH_NAME=gfx942:sramecc+:xnack- Validator ROCBLAS_VERSION=4.0.0-72e57364-dirty ROCBLAS_VERSION validation: expect 4.0.0-72e57364-dirty to match 4.0.0-72e57364-dirty GCN_ARCH_NAME validation: expect gfx942:sramecc+:xnack- to match gfx942:sramecc+:xnack- HIPBLASLT_VERSION validation: expect 800-a15e4178 to match 800-a15e4178 ROCM_VERSION validation: expect 6.0.0.0-12969-1544e39 to match 6.0.0.0-12969-1544e39 PT_VERSION validation: expect 2.5.0 to match 2.5.0 Loading results Avg time: 13.165860176086426 us, Achieved 3.19 TFLOPS, 1598.24 GB/s - Run Linear (matmul) 2 x 8192 x 1024, dtype = torch.bfloat16 Avg time: 13.230760097503662 us, Achieved 2.54 TFLOPS, 1271.14 GB/s - Run Linear (matmul) 2 x 7168 x 8192, dtype = torch.bfloat16 Avg time: 26.804399490356445 us, Achieved 8.76 TFLOPS, 4384.90 GB/s - Run Linear (matmul) 2 x 8192 x 3584, dtype = torch.bfloat16 Avg time: 13.407809734344482 us, Achieved 8.76 TFLOPS, 4384.14 GB/s 2x1280x8192-torch.bfloat16,13.165860176086426,3.18574247630113,1598.237845349412 2x8192x1024-torch.bfloat16,13.230760097503662,2.536092541374924,1271.1420867780075 2x7168x8192-torch.bfloat16,26.804399490356445,8.762778814892096,4384.9040543618985 2x8192x3584-torch.bfloat16,13.407809734344482,8.759112362638383,4384.138585247748 ``` Reviewed By: leitian Differential Revision: D62322830 Pull Request resolved: https://github.com/pytorch/pytorch/pull/135396 Approved by: https://github.com/eqy --- aten/src/ATen/cuda/tunable/Tunable.cpp | 20 ++++++++++++++++---- 1 file changed, 16 insertions(+), 4 deletions(-) diff --git a/aten/src/ATen/cuda/tunable/Tunable.cpp b/aten/src/ATen/cuda/tunable/Tunable.cpp index 5053f6693b96da..1b7c898758558f 100644 --- a/aten/src/ATen/cuda/tunable/Tunable.cpp +++ b/aten/src/ATen/cuda/tunable/Tunable.cpp @@ -188,7 +188,10 @@ TuningResultsValidator::TuningResultsValidator() { RegisterValidator( "ROCM_VERSION", [rocm_version]() { return rocm_version; }, - [rocm_version](auto&& k) { return rocm_version == k ? OK : FAIL; }); + [rocm_version](auto&& k) { + TUNABLE_LOG1("ROCM_VERSION validation: expect ", k, " to match ", rocm_version); + return rocm_version == k ? OK : FAIL; + }); } // gfx arch { @@ -196,7 +199,10 @@ TuningResultsValidator::TuningResultsValidator() { RegisterValidator( "GCN_ARCH_NAME", [gcn_arch_name]() { return gcn_arch_name; }, - [gcn_arch_name](auto&& k) { return gcn_arch_name == k ? OK : FAIL; }); + [gcn_arch_name](auto&& k) { + TUNABLE_LOG1("GCN_ARCH_NAME validation: expect ", k, " to match ", gcn_arch_name); + return gcn_arch_name == k ? OK : FAIL; + }); } // rocblas { @@ -212,7 +218,10 @@ TuningResultsValidator::TuningResultsValidator() { RegisterValidator( "ROCBLAS_VERSION", [rocblas_version]() { return rocblas_version; }, - [rocblas_version](auto&& k) { return rocblas_version == k ? OK : FAIL; }); + [rocblas_version](auto&& k) { + TUNABLE_LOG1("ROCBLAS_VERSION validation: expect ", k, " to match ", rocblas_version); + return rocblas_version == k ? OK : FAIL; + }); } // hipblaslt { @@ -226,7 +235,10 @@ TuningResultsValidator::TuningResultsValidator() { RegisterValidator( "HIPBLASLT_VERSION", [hipblaslt_version]() { return hipblaslt_version; }, - [hipblaslt_version](auto&& k) { return hipblaslt_version == k ? OK : FAIL; }); + [hipblaslt_version](auto&& k) { + TUNABLE_LOG1("HIPBLASLT_VERSION validation: expect ", k, " to match ", hipblaslt_version); + return hipblaslt_version == k ? OK : FAIL; + }); } #endif } From f285bd99e1b535ed6def3e9d2f43f685aaee004c Mon Sep 17 00:00:00 2001 From: "Zhou, Lingzhi" Date: Tue, 10 Sep 2024 17:45:27 +0000 Subject: [PATCH 0674/1018] [Partitioner] Reuse partition to check whether nodes exist (#135317) The time complexity of find node whether in NodeList is O(n). Reuse partition to speed up due to partition.nodes is hash table and has same elements. Pull Request resolved: https://github.com/pytorch/pytorch/pull/135317 Approved by: https://github.com/ezyang --- test/test_fx_passes.py | 4 ++-- torch/fx/passes/infra/partitioner.py | 2 +- torch/fx/passes/utils/fuser_utils.py | 13 +++++++------ 3 files changed, 10 insertions(+), 9 deletions(-) diff --git a/test/test_fx_passes.py b/test/test_fx_passes.py index e5ed6e078c9eab..ac78aef325b4ce 100644 --- a/test/test_fx_passes.py +++ b/test/test_fx_passes.py @@ -360,7 +360,7 @@ def test_fuser_util(self, partition): partitions = [] for node_names in partition: - partitions.append([nodes_by_name[name] for name in node_names]) + partitions.append(dict.fromkeys([nodes_by_name[name] for name in node_names])) fused_graph = fuse_by_partitions(gm, partitions) @@ -385,7 +385,7 @@ def test_fuser_util_xfail(self, partition): partitions = [] for node_names in partition: - partitions.append([nodes_by_name[name] for name in node_names]) + partitions.append(dict.fromkeys([nodes_by_name[name] for name in node_names])) with self.assertRaises(Exception): fuse_by_partitions(gm, partitions) diff --git a/torch/fx/passes/infra/partitioner.py b/torch/fx/passes/infra/partitioner.py index 271f90a7b75e8b..0df74ceba71e8b 100644 --- a/torch/fx/passes/infra/partitioner.py +++ b/torch/fx/passes/infra/partitioner.py @@ -268,7 +268,7 @@ def fuse_partitions(self, partitions: List[Partition], prefix: str = "fused_") - # fuse_by_partitions expects partitions in List[List[Node]]: [ [node0, node1], [node2, node3] ] return fuse_by_partitions( self.graph_module, - [list(partition.nodes) for partition in partitions], + [partition.nodes for partition in partitions], prefix=prefix, ) diff --git a/torch/fx/passes/utils/fuser_utils.py b/torch/fx/passes/utils/fuser_utils.py index 11a9cfa34898ab..3375472528b6a6 100644 --- a/torch/fx/passes/utils/fuser_utils.py +++ b/torch/fx/passes/utils/fuser_utils.py @@ -92,6 +92,7 @@ def bfs_find_cycle(root_nodes: NodeList) -> bool: @compatibility(is_backward_compatible=False) def fuse_as_graphmodule(gm: GraphModule, nodes: NodeList, + partition: Dict[Node, None], module_name: str) -> Tuple[GraphModule, Tuple[Node, ...], Tuple[Node, ...]]: """ @@ -135,7 +136,7 @@ def remap_inputs(x): # do something here pass - if x in nodes: + if x in partition: # x is inside subgraph, return the copied node # the node should have been copied aleady, as we are copying graph in the topological order return node_map[x] @@ -159,7 +160,7 @@ def remap_inputs(x): for node in nodes: for user_node in node.users: - if user_node not in nodes: + if user_node not in partition: # external user node, need to expose as an output output_mapping[node] = node_map[node] @@ -219,12 +220,12 @@ def erase_nodes(gm: GraphModule, nodes: NodeList): @compatibility(is_backward_compatible=False) -def fuse_by_partitions(gm: GraphModule, partitions: List[NodeList], prefix: str = "fused_") -> GraphModule: - for partition_id, nodes in enumerate(partitions): - sorted_nodes = topo_sort(nodes) +def fuse_by_partitions(gm: GraphModule, partitions: List[Dict[Node, None]], prefix: str = "fused_") -> GraphModule: + for partition_id, partition in enumerate(partitions): + sorted_nodes = topo_sort(list(partition)) submodule_name = prefix + str(partition_id) - sub_gm, orig_inputs, orig_outputs = fuse_as_graphmodule(gm, sorted_nodes, submodule_name) + sub_gm, orig_inputs, orig_outputs = fuse_as_graphmodule(gm, sorted_nodes, partition, submodule_name) insert_subgm(gm, sub_gm, orig_inputs, orig_outputs) From 349d30a6378206ab5a451d181172461b268b5f62 Mon Sep 17 00:00:00 2001 From: Ivan Zaitsev Date: Tue, 10 Sep 2024 18:07:11 +0000 Subject: [PATCH 0675/1018] Revert "[fx] Bypass custom __setattr__ in Node.__init__ (#135079)" (#135562) This reverts commit 66da3b3b2acacb116a9b23e91b24934830eaf6b8. #135079 breaks internal tests and needs to be reverted. Revert with mergebot doesn't work as this PR is technically part of the stack, but, according to @jansel, it should be possible to revert it individually. Pull Request resolved: https://github.com/pytorch/pytorch/pull/135562 Approved by: https://github.com/jansel, https://github.com/seemethere --- torch/fx/graph.py | 4 +-- torch/fx/node.py | 77 ++++++++++++++++++----------------------------- 2 files changed, 32 insertions(+), 49 deletions(-) diff --git a/torch/fx/graph.py b/torch/fx/graph.py index 1601eee07f2717..e4fdd79fcbb280 100644 --- a/torch/fx/graph.py +++ b/torch/fx/graph.py @@ -816,10 +816,10 @@ def remove(self, node: Node) -> None: def find_nodes(self, *, op: str, target: Optional['Target'] = None): if op == "call_function": assert target is not None - return [*self.table[(op, target)].keys()] + return dict(self.table[(op, target)]).keys() if target is None: - return [*self.table[(op, None)].keys()] + return dict(self.table[(op, None)]).keys() # op is call_method, get_attr, call_module return [node for node in self.table[(op, None)].keys() if node.target == target] diff --git a/torch/fx/node.py b/torch/fx/node.py index 323f954d1a3e2e..f84b23e29ddf85 100644 --- a/torch/fx/node.py +++ b/torch/fx/node.py @@ -168,16 +168,6 @@ class Node(_NodeBase): """ _args: Tuple['Argument', ...] _kwargs: Dict[str, 'Argument'] - graph: 'Graph' - name: str - op: str - target: 'Target' - _input_nodes: Dict['Node', None] - users: Dict['Node', None] - type: Optional[Any] - _sort_key: Any - _repr_fn: Optional[Callable[['Node'], str]] - meta: Dict[str, Any] @compatibility(is_backward_compatible=True) def __init__(self, graph: 'Graph', name: str, op: str, target: 'Target', @@ -209,7 +199,11 @@ def __init__(self, graph: 'Graph', name: str, op: str, target: 'Target', annotation of values in the generated code or for other types of analyses. """ + super().__init__() + self.graph = graph + self.name = name # unique name of value being created assert op in _legal_ops + self.op = op # the kind of operation = placeholder|call_method|call_module|call_function|get_attr if op == 'call_function': if not callable(target): raise ValueError(f'Node [graph = {graph}, name = \'{name}\'] target {target} has type {torch.typename(target)} ' @@ -218,22 +212,13 @@ def __init__(self, graph: 'Graph', name: str, op: str, target: 'Target', if not isinstance(target, str): raise ValueError(f'Node [graph = {graph}, name = \'{name}\'] target {target} has type {torch.typename(target)} ' 'but a str is expected') - super().__init__() - - # bypass Node.__setattr__ for perf and so that it doesn't need to handle half-built objects - assign = object.__setattr__ - - assign(self, "graph", graph) - assign(self, "name", name) # unique name of value being created - assign(self, "op", op) # the kind of operation = placeholder|call_method|call_module|call_function|get_attr - - assign(self, "target", target) # for method/module/function, the name of the method/module/function/attr + self.target = target # for method/module/function, the name of the method/module/function/attr # being invoked, e.g add, layer1, or torch.add # All `Node`-valued inputs. Key is the Node, value is don't-care. # The public API for this is `all_input_nodes`, this private attribute # should not be accessed directly. - assign(self, "_input_nodes", {}) + self._input_nodes : Dict[Node, None] = {} self.__update_args_kwargs(args, kwargs) # All of the nodes that use the value produced by this Node @@ -241,8 +226,7 @@ def __init__(self, graph: 'Graph', name: str, op: str, target: 'Target', # would appear once here, but represents two uses. # # Is a dict to act as an "ordered set". Keys are significant, value dont-care - assign(self, "users", {}) - + self.users : Dict[Node, None] = {} # Type expression representing the output value of this node. # This should contain the same class of Type objects that would appear # as type annotations for function inputs/outputs. @@ -253,15 +237,15 @@ def __init__(self, graph: 'Graph', name: str, op: str, target: 'Target', # generated function return type. (Note this is a special case. ``return`` # does not produce a value, it's more of a notation. Thus, this value # describes the type of args[0] in the ``return`` node. - assign(self, "type", return_type) - assign(self, "_sort_key", ()) + self.type : Optional[Any] = return_type + self._sort_key: Any = () # If set, use this fn to print this node - assign(self, "_repr_fn", None) + self._repr_fn : Optional[Callable[[Node], str]] = None # Dictionary to store metadata passes need to do their # transformations. This metadata is preserved across node copies - assign(self, "meta", {}) + self.meta : Dict[str, Any] = {} def __getstate__(self) -> Dict[str, Any]: state = self.__dict__.copy() @@ -506,14 +490,14 @@ def update_users_and_input_nodes(n: Any) -> Any: # Clear prior users and input_nodes for old_use in self._input_nodes.keys(): old_use.users.pop(self) - object.__setattr__(self, "_input_nodes", {}) # bypass Node.__setattr__ + self._input_nodes = {} # We do three things in a single pass of the args # - Normalize list->immutable_list, dict->immutable_dict, etc # - Populate self._input_nodes # - Populate arg.users[self] for each arg - object.__setattr__(self, "_args", map_aggregate(new_args, update_users_and_input_nodes)) - object.__setattr__(self, "_kwargs", map_aggregate(new_kwargs, update_users_and_input_nodes)) + self._args = map_aggregate(new_args, update_users_and_input_nodes) # type: ignore[assignment] + self._kwargs = map_aggregate(new_kwargs, update_users_and_input_nodes) # type: ignore[assignment] def __repr__(self) -> str: if self._repr_fn: @@ -756,24 +740,23 @@ def _rename(self, candidate: str) -> None: self.graph._graph_namespace._rename_object(self, name) def __setattr__(self, name: str, value: Any) -> None: - if name in _setattr_custom_handling: - if name == "name": - m = self.graph.owning_module - if getattr(m, "_replace_hook", None): - assert isinstance(value, str) - for user in self.users: - m._replace_hook(old=self, new=value, user=user) - elif name in ("op", "target"): - table = getattr(self.graph, "_find_nodes_lookup_table", None) - if table and self in table: - self.graph._find_nodes_lookup_table.remove(self) - object.__setattr__(self, name, value) - self.graph._find_nodes_lookup_table.insert(self) - return - + if name == 'name' and hasattr(self, "name"): + m = self.graph.owning_module + if getattr(m, "_replace_hook", None): + assert isinstance(value, str) + for user in self.users: + m._replace_hook(old=self, new=value, user=user) + update = False + if ( + hasattr(self, name) and + hasattr(self.graph, "_find_nodes_lookup_table") and + self in self.graph._find_nodes_lookup_table + ): + update = True + self.graph._find_nodes_lookup_table.remove(self) object.__setattr__(self, name, value) - -_setattr_custom_handling = dict.fromkeys(["name", "op", "target"]) + if update: + self.graph._find_nodes_lookup_table.insert(self) @compatibility(is_backward_compatible=True) def map_arg(a: Argument, fn: Callable[[Node], Argument]) -> Argument: From f8b869dfc33d003866911fd61d6dc72828747c10 Mon Sep 17 00:00:00 2001 From: Thanh Ha Date: Tue, 10 Sep 2024 18:13:59 +0000 Subject: [PATCH 0676/1018] Migrate remaining jobs to use runner determinator (#134867) At this point all self-hosted runner jobs should be using the runner determinator to switch between LF and Meta runners. This change updates the remaining jobs that have not yet been migrated over. Issue: https://lf-pytorch.atlassian.net/browse/PC-25 Pull Request resolved: https://github.com/pytorch/pytorch/pull/134867 Approved by: https://github.com/ZainRizvi --- .github/workflows/build-libtorch-images.yml | 18 +++++++-- .github/workflows/build-manywheel-images.yml | 39 ++++++++++++++----- .github/workflows/build-triton-wheel.yml | 15 ++++++- .github/workflows/create_release.yml | 15 ++++++- .github/workflows/docker-builds.yml | 12 +++++- .github/workflows/docker-release.yml | 18 +++++++-- .github/workflows/inductor-cu124.yml | 11 ++++++ .../workflows/inductor-micro-benchmark.yml | 11 ++++++ .github/workflows/inductor-perf-compare.yml | 11 ++++++ .../inductor-perf-test-nightly-a10g.yml | 11 ++++++ .../inductor-perf-test-nightly-aarch64.yml | 11 ++++++ .../inductor-perf-test-nightly-x86.yml | 11 ++++++ .../workflows/inductor-perf-test-nightly.yml | 11 ++++++ .github/workflows/inductor-periodic.yml | 13 +++++++ .github/workflows/inductor-rocm.yml | 11 ++++++ .../target-determination-indexer.yml | 12 +++++- .github/workflows/torchbench.yml | 11 ++++++ 17 files changed, 219 insertions(+), 22 deletions(-) diff --git a/.github/workflows/build-libtorch-images.yml b/.github/workflows/build-libtorch-images.yml index 170b35107ac240..5146e7593a5fdd 100644 --- a/.github/workflows/build-libtorch-images.yml +++ b/.github/workflows/build-libtorch-images.yml @@ -29,9 +29,19 @@ concurrency: cancel-in-progress: true jobs: + get-label-type: + name: get-label-type + uses: ./.github/workflows/_runner-determinator.yml + with: + triggering_actor: ${{ github.triggering_actor }} + issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }} + curr_branch: ${{ github.head_ref || github.ref_name }} + curr_ref_type: ${{ github.ref_type }} + build-docker-cuda: environment: ${{ (github.ref == 'refs/heads/main' || startsWith(github.event.ref, 'refs/tags/v')) && 'docker-build' || '' }} - runs-on: linux.9xlarge.ephemeral + needs: get-label-type + runs-on: "${{ needs.get-label-type.outputs.label-type }}linux.9xlarge.ephemeral" strategy: matrix: cuda_version: ["12.4", "12.1", "11.8"] @@ -66,7 +76,8 @@ jobs: .ci/docker/libtorch/build.sh libtorch-cxx11-builder:cuda${{matrix.cuda_version}} build-docker-rocm: environment: ${{ (github.ref == 'refs/heads/main' || startsWith(github.event.ref, 'refs/tags/v')) && 'docker-build' || '' }} - runs-on: linux.9xlarge.ephemeral + needs: get-label-type + runs-on: "${{ needs.get-label-type.outputs.label-type }}linux.9xlarge.ephemeral" strategy: matrix: rocm_version: ["6.1", "6.2"] @@ -101,7 +112,8 @@ jobs: .ci/docker/libtorch/build.sh libtorch-cxx11-builder:rocm${{matrix.rocm_version}} build-docker-cpu: environment: ${{ (github.ref == 'refs/heads/main' || startsWith(github.event.ref, 'refs/tags/v')) && 'docker-build' || '' }} - runs-on: linux.9xlarge.ephemeral + needs: get-label-type + runs-on: "${{ needs.get-label-type.outputs.label-type }}linux.9xlarge.ephemeral" steps: - name: Checkout PyTorch uses: pytorch/pytorch/.github/actions/checkout-pytorch@main diff --git a/.github/workflows/build-manywheel-images.yml b/.github/workflows/build-manywheel-images.yml index 8e64d612d7bb83..750ee99d52e38d 100644 --- a/.github/workflows/build-manywheel-images.yml +++ b/.github/workflows/build-manywheel-images.yml @@ -33,9 +33,19 @@ concurrency: cancel-in-progress: true jobs: + get-label-type: + name: get-label-type + uses: ./.github/workflows/_runner-determinator.yml + with: + triggering_actor: ${{ github.triggering_actor }} + issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }} + curr_branch: ${{ github.head_ref || github.ref_name }} + curr_ref_type: ${{ github.ref_type }} + build-docker-cuda: environment: ${{ (github.ref == 'refs/heads/main' || startsWith(github.event.ref, 'refs/tags/v')) && 'docker-build' || '' }} - runs-on: am2.linux.9xlarge.ephemeral + needs: get-label-type + runs-on: "${{ needs.get-label-type.outputs.label-type }}am2.linux.9xlarge.ephemeral" strategy: matrix: cuda_version: ["12.4", "12.1", "11.8"] @@ -73,7 +83,8 @@ jobs: # NOTE: manylinux_2_28 are still experimental, see https://github.com/pytorch/pytorch/issues/123649 build-docker-cuda-manylinux_2_28: environment: ${{ (github.ref == 'refs/heads/main' || startsWith(github.event.ref, 'refs/tags/v')) && 'docker-build' || '' }} - runs-on: linux.9xlarge.ephemeral + needs: get-label-type + runs-on: "${{ needs.get-label-type.outputs.label-type }}linux.9xlarge.ephemeral" strategy: matrix: cuda_version: ["12.4", "12.1", "11.8"] @@ -110,7 +121,8 @@ jobs: .ci/docker/manywheel/build.sh manylinux2_28-builder:cuda${{matrix.cuda_version}} build-docker-cuda-aarch64: environment: ${{ (github.ref == 'refs/heads/main' || startsWith(github.event.ref, 'refs/tags/v')) && 'docker-build' || '' }} - runs-on: linux.arm64.2xlarge.ephemeral + needs: get-label-type + runs-on: "${{ needs.get-label-type.outputs.label-type }}linux.arm64.2xlarge.ephemeral" strategy: matrix: cuda_version: ["12.4"] @@ -143,7 +155,8 @@ jobs: .ci/docker/manywheel/build.sh manylinuxaarch64-builder:cuda${{matrix.cuda_version}} build-docker-rocm: environment: ${{ (github.ref == 'refs/heads/main' || startsWith(github.event.ref, 'refs/tags/v')) && 'docker-build' || '' }} - runs-on: am2.linux.9xlarge.ephemeral + needs: get-label-type + runs-on: "${{ needs.get-label-type.outputs.label-type }}am2.linux.9xlarge.ephemeral" strategy: matrix: rocm_version: ["6.1", "6.2"] @@ -178,7 +191,8 @@ jobs: .ci/docker/manywheel/build.sh manylinux-builder:rocm${{matrix.rocm_version}} build-docker-cpu: environment: ${{ (github.ref == 'refs/heads/main' || startsWith(github.event.ref, 'refs/tags/v')) && 'docker-build' || '' }} - runs-on: am2.linux.9xlarge.ephemeral + needs: get-label-type + runs-on: "${{ needs.get-label-type.outputs.label-type }}am2.linux.9xlarge.ephemeral" steps: - name: Checkout PyTorch uses: pytorch/pytorch/.github/actions/checkout-pytorch@main @@ -207,7 +221,8 @@ jobs: .ci/docker/manywheel/build.sh manylinux-builder:cpu build-docker-cpu-manylinux_2_28: environment: ${{ (github.ref == 'refs/heads/main' || startsWith(github.event.ref, 'refs/tags/v')) && 'docker-build' || '' }} - runs-on: linux.9xlarge.ephemeral + needs: get-label-type + runs-on: "${{ needs.get-label-type.outputs.label-type }}linux.9xlarge.ephemeral" env: GPU_ARCH_TYPE: cpu-manylinux_2_28 steps: @@ -238,7 +253,8 @@ jobs: .ci/docker/manywheel/build.sh manylinux2_28-builder:cpu build-docker-cpu-aarch64: environment: ${{ (github.ref == 'refs/heads/main' || startsWith(github.event.ref, 'refs/tags/v')) && 'docker-build' || '' }} - runs-on: linux.arm64.2xlarge.ephemeral + needs: get-label-type + runs-on: "${{ needs.get-label-type.outputs.label-type }}linux.arm64.2xlarge.ephemeral" env: GPU_ARCH_TYPE: cpu-aarch64 steps: @@ -269,7 +285,8 @@ jobs: .ci/docker/manywheel/build.sh manylinuxaarch64-builder:cpu-aarch64 build-docker-cpu-aarch64-2_28: environment: ${{ (github.ref == 'refs/heads/main' || startsWith(github.event.ref, 'refs/tags/v')) && 'docker-build' || '' }} - runs-on: linux.arm64.2xlarge.ephemeral + needs: get-label-type + runs-on: "${{ needs.get-label-type.outputs.label-type }}linux.arm64.2xlarge.ephemeral" env: GPU_ARCH_TYPE: cpu-aarch64-2_28 steps: @@ -303,7 +320,8 @@ jobs: .ci/docker/manywheel/build.sh manylinux2_28_aarch64-builder:cpu-aarch64 build-docker-cpu-cxx11-abi: environment: ${{ (github.ref == 'refs/heads/main' || startsWith(github.event.ref, 'refs/tags/v')) && 'docker-build' || '' }} - runs-on: linux.9xlarge.ephemeral + needs: get-label-type + runs-on: "${{ needs.get-label-type.outputs.label-type }}linux.9xlarge.ephemeral" env: GPU_ARCH_TYPE: cpu-cxx11-abi steps: @@ -334,7 +352,8 @@ jobs: .ci/docker/manywheel/build.sh manylinuxcxx11-abi-builder:cpu-cxx11-abi build-docker-xpu: environment: ${{ (github.ref == 'refs/heads/main' || startsWith(github.event.ref, 'refs/tags/v')) && 'docker-build' || '' }} - runs-on: linux.9xlarge.ephemeral + needs: get-label-type + runs-on: "${{ needs.get-label-type.outputs.label-type }}linux.9xlarge.ephemeral" env: GPU_ARCH_TYPE: xpu steps: diff --git a/.github/workflows/build-triton-wheel.yml b/.github/workflows/build-triton-wheel.yml index 73abadd4f12e26..8fe307c3b4c86e 100644 --- a/.github/workflows/build-triton-wheel.yml +++ b/.github/workflows/build-triton-wheel.yml @@ -27,9 +27,19 @@ concurrency: cancel-in-progress: true jobs: + get-label-type: + name: get-label-type + uses: ./.github/workflows/_runner-determinator.yml + with: + triggering_actor: ${{ github.triggering_actor }} + issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }} + curr_branch: ${{ github.head_ref || github.ref_name }} + curr_ref_type: ${{ github.ref_type }} + build-wheel: name: "Build Triton Wheel" - runs-on: [self-hosted, linux.4xlarge] + needs: get-label-type + runs-on: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge" strategy: fail-fast: false matrix: @@ -199,7 +209,8 @@ jobs: build-conda: name: "Build Triton Conda" - runs-on: [self-hosted, linux.2xlarge] + needs: get-label-type + runs-on: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" strategy: fail-fast: false matrix: diff --git a/.github/workflows/create_release.yml b/.github/workflows/create_release.yml index 43cb2ac5cb2ca3..2c83b8cb571961 100644 --- a/.github/workflows/create_release.yml +++ b/.github/workflows/create_release.yml @@ -16,6 +16,15 @@ on: paths: [.github/workflows/create_release.yml] jobs: + get-label-type: + name: get-label-type + uses: ./.github/workflows/_runner-determinator.yml + with: + triggering_actor: ${{ github.triggering_actor }} + issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }} + curr_branch: ${{ github.head_ref || github.ref_name }} + curr_ref_type: ${{ github.ref_type }} + release: if: ${{ github.repository == 'pytorch/pytorch' }} name: Create Release @@ -73,12 +82,14 @@ jobs: upload_source_code_to_s3: if: ${{ github.repository == 'pytorch/pytorch' && github.event_name == 'push' && startsWith(github.ref, 'refs/tags/v') && contains(github.ref, 'rc') }} - runs-on: linux.2xlarge + runs-on: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" environment: sourcecode-upload name: Upload source code to S3 for release tags permissions: id-token: write - needs: release + needs: + - get-label-type + - release steps: - uses: actions/download-artifact@v4.1.7 with: diff --git a/.github/workflows/docker-builds.yml b/.github/workflows/docker-builds.yml index 8b46e314c20da1..2e7f041a23e3d8 100644 --- a/.github/workflows/docker-builds.yml +++ b/.github/workflows/docker-builds.yml @@ -30,8 +30,18 @@ env: permissions: read-all jobs: + get-label-type: + name: get-label-type + uses: ./.github/workflows/_runner-determinator.yml + with: + triggering_actor: ${{ github.triggering_actor }} + issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }} + curr_branch: ${{ github.head_ref || github.ref_name }} + curr_ref_type: ${{ github.ref_type }} + docker-build: environment: ${{ (github.ref == 'refs/heads/main' || startsWith(github.event.ref, 'refs/tags/v')) && 'docker-build' || '' }} + needs: get-label-type timeout-minutes: 240 strategy: fail-fast: false @@ -68,7 +78,7 @@ jobs: - docker-image-name: pytorch-linux-jammy-aarch64-py3.10-gcc11-inductor-benchmarks runner: linux.arm64.m7g.4xlarge timeout-minutes: 600 - runs-on: [self-hosted, "${{ matrix.runner }}"] + runs-on: "${{ needs.get-label-type.outputs.label-type }}${{ matrix.runner }}" env: DOCKER_IMAGE_BASE: 308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/${{ matrix.docker-image-name }} steps: diff --git a/.github/workflows/docker-release.yml b/.github/workflows/docker-release.yml index e9c8db722432ea..41c5b40860303a 100644 --- a/.github/workflows/docker-release.yml +++ b/.github/workflows/docker-release.yml @@ -34,9 +34,19 @@ env: permissions: read-all jobs: + get-label-type: + name: get-label-type + uses: ./.github/workflows/_runner-determinator.yml + with: + triggering_actor: ${{ github.triggering_actor }} + issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }} + curr_branch: ${{ github.head_ref || github.ref_name }} + curr_ref_type: ${{ github.ref_type }} + generate-matrix: if: github.repository_owner == 'pytorch' - runs-on: [self-hosted, linux.large] + needs: get-label-type + runs-on: "${{ needs.get-label-type.outputs.label-type }}linux.large" outputs: matrix: ${{ steps.generate-matrix.outputs.matrix }} steps: @@ -54,10 +64,12 @@ jobs: build: if: ${{ github.repository == 'pytorch/pytorch' }} - runs-on: [self-hosted, linux.2xlarge] + runs-on: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" environment: ${{ (github.ref == 'refs/heads/nightly' || startsWith(github.event.ref, 'refs/tags/v')) && 'docker-build' || '' }} timeout-minutes: 240 - needs: generate-matrix + needs: + - generate-matrix + - get-label-type strategy: matrix: ${{ fromJson(needs.generate-matrix.outputs.matrix) }} fail-fast: false diff --git a/.github/workflows/inductor-cu124.yml b/.github/workflows/inductor-cu124.yml index 39d1204a4e7033..950afbf0b591e8 100644 --- a/.github/workflows/inductor-cu124.yml +++ b/.github/workflows/inductor-cu124.yml @@ -18,11 +18,22 @@ concurrency: permissions: read-all jobs: + get-label-type: + name: get-label-type + uses: ./.github/workflows/_runner-determinator.yml + with: + triggering_actor: ${{ github.triggering_actor }} + issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }} + curr_branch: ${{ github.head_ref || github.ref_name }} + curr_ref_type: ${{ github.ref_type }} + linux-focal-cuda12_4-py3_10-gcc9-inductor-build: # Should be synced with the one in inductor.yml, but this doesn't run inductor_timm name: cuda12.4-py3.10-gcc9-sm86 uses: ./.github/workflows/_linux-build.yml + needs: get-label-type with: + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" sync-tag: linux-focal-cuda12_4-py3_10-gcc9-inductor-build build-environment: linux-focal-cuda12.4-py3.10-gcc9-sm86 docker-image-name: pytorch-linux-focal-cuda12.4-cudnn9-py3-gcc9-inductor-benchmarks diff --git a/.github/workflows/inductor-micro-benchmark.yml b/.github/workflows/inductor-micro-benchmark.yml index 431545ea6d0dc7..fad0538d107559 100644 --- a/.github/workflows/inductor-micro-benchmark.yml +++ b/.github/workflows/inductor-micro-benchmark.yml @@ -16,10 +16,21 @@ concurrency: permissions: read-all jobs: + get-label-type: + name: get-label-type + uses: ./.github/workflows/_runner-determinator.yml + with: + triggering_actor: ${{ github.triggering_actor }} + issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }} + curr_branch: ${{ github.head_ref || github.ref_name }} + curr_ref_type: ${{ github.ref_type }} + linux-focal-cuda12_1-py3_10-gcc9-inductor-micro-benchmark-build: name: cuda12.1-py3.10-gcc9-sm80 uses: ./.github/workflows/_linux-build.yml + needs: get-label-type with: + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build-environment: linux-focal-cuda12.1-py3.10-gcc9-sm80 docker-image-name: pytorch-linux-focal-cuda12.1-cudnn9-py3-gcc9-inductor-benchmarks cuda-arch-list: '8.0' diff --git a/.github/workflows/inductor-perf-compare.yml b/.github/workflows/inductor-perf-compare.yml index a5e4ad1781aa06..a38bcadf7e5f7a 100644 --- a/.github/workflows/inductor-perf-compare.yml +++ b/.github/workflows/inductor-perf-compare.yml @@ -13,10 +13,21 @@ concurrency: permissions: read-all jobs: + get-label-type: + name: get-label-type + uses: ./.github/workflows/_runner-determinator.yml + with: + triggering_actor: ${{ github.triggering_actor }} + issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }} + curr_branch: ${{ github.head_ref || github.ref_name }} + curr_ref_type: ${{ github.ref_type }} + linux-focal-cuda12_1-py3_10-gcc9-inductor-build: name: cuda12.1-py3.10-gcc9-sm80 uses: ./.github/workflows/_linux-build.yml + needs: get-label-type with: + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build-environment: linux-focal-cuda12.1-py3.10-gcc9-sm80 docker-image-name: pytorch-linux-focal-cuda12.1-cudnn9-py3-gcc9-inductor-benchmarks cuda-arch-list: '8.0' diff --git a/.github/workflows/inductor-perf-test-nightly-a10g.yml b/.github/workflows/inductor-perf-test-nightly-a10g.yml index 4acfd43aac3637..e42d7d4148c228 100644 --- a/.github/workflows/inductor-perf-test-nightly-a10g.yml +++ b/.github/workflows/inductor-perf-test-nightly-a10g.yml @@ -68,10 +68,21 @@ concurrency: permissions: read-all jobs: + get-label-type: + name: get-label-type + uses: ./.github/workflows/_runner-determinator.yml + with: + triggering_actor: ${{ github.triggering_actor }} + issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }} + curr_branch: ${{ github.head_ref || github.ref_name }} + curr_ref_type: ${{ github.ref_type }} + linux-focal-cuda12_1-py3_10-gcc9-inductor-build: name: cuda12.1-py3.10-gcc9-sm80 uses: ./.github/workflows/_linux-build.yml + needs: get-label-type with: + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build-environment: linux-focal-cuda12.1-py3.10-gcc9-sm80 docker-image-name: pytorch-linux-focal-cuda12.1-cudnn9-py3-gcc9-inductor-benchmarks cuda-arch-list: '8.0' diff --git a/.github/workflows/inductor-perf-test-nightly-aarch64.yml b/.github/workflows/inductor-perf-test-nightly-aarch64.yml index 73e76486381935..39bc85752ae67a 100644 --- a/.github/workflows/inductor-perf-test-nightly-aarch64.yml +++ b/.github/workflows/inductor-perf-test-nightly-aarch64.yml @@ -50,10 +50,21 @@ concurrency: permissions: read-all jobs: + get-label-type: + name: get-label-type + uses: ./.github/workflows/_runner-determinator.yml + with: + triggering_actor: ${{ github.triggering_actor }} + issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }} + curr_branch: ${{ github.head_ref || github.ref_name }} + curr_ref_type: ${{ github.ref_type }} + linux-jammy-aarch64-py3_10-inductor-build: name: linux-jammy-aarch64-py3.10-inductor uses: ./.github/workflows/_linux-build.yml + needs: get-label-type with: + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" runner: linux.arm64.m7g.4xlarge build-environment: linux-jammy-aarch64-py3.10 docker-image-name: pytorch-linux-jammy-aarch64-py3.10-gcc11-inductor-benchmarks diff --git a/.github/workflows/inductor-perf-test-nightly-x86.yml b/.github/workflows/inductor-perf-test-nightly-x86.yml index 27a1ce5db4554a..83e8b26dd628e2 100644 --- a/.github/workflows/inductor-perf-test-nightly-x86.yml +++ b/.github/workflows/inductor-perf-test-nightly-x86.yml @@ -48,10 +48,21 @@ concurrency: permissions: read-all jobs: + get-label-type: + name: get-label-type + uses: ./.github/workflows/_runner-determinator.yml + with: + triggering_actor: ${{ github.triggering_actor }} + issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }} + curr_branch: ${{ github.head_ref || github.ref_name }} + curr_ref_type: ${{ github.ref_type }} + linux-jammy-cpu-py3_9-gcc11-inductor-build: name: linux-jammy-cpu-py3.9-gcc11-inductor uses: ./.github/workflows/_linux-build.yml + needs: get-label-type with: + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build-environment: linux-jammy-py3.9-gcc11-build docker-image-name: pytorch-linux-jammy-py3.9-gcc11-inductor-benchmarks test-matrix: | diff --git a/.github/workflows/inductor-perf-test-nightly.yml b/.github/workflows/inductor-perf-test-nightly.yml index 2f129c52fe135b..5c7651d516d8b1 100644 --- a/.github/workflows/inductor-perf-test-nightly.yml +++ b/.github/workflows/inductor-perf-test-nightly.yml @@ -66,10 +66,21 @@ concurrency: permissions: read-all jobs: + get-label-type: + name: get-label-type + uses: ./.github/workflows/_runner-determinator.yml + with: + triggering_actor: ${{ github.triggering_actor }} + issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }} + curr_branch: ${{ github.head_ref || github.ref_name }} + curr_ref_type: ${{ github.ref_type }} + linux-focal-cuda12_1-py3_10-gcc9-inductor-build: name: cuda12.1-py3.10-gcc9-sm80 uses: ./.github/workflows/_linux-build.yml + needs: get-label-type with: + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build-environment: linux-focal-cuda12.1-py3.10-gcc9-sm80 docker-image-name: pytorch-linux-focal-cuda12.1-cudnn9-py3-gcc9-inductor-benchmarks cuda-arch-list: '8.0' diff --git a/.github/workflows/inductor-periodic.yml b/.github/workflows/inductor-periodic.yml index 7f83ce1eb54862..6bcb1be5ef0944 100644 --- a/.github/workflows/inductor-periodic.yml +++ b/.github/workflows/inductor-periodic.yml @@ -18,10 +18,21 @@ concurrency: permissions: read-all jobs: + get-label-type: + name: get-label-type + uses: ./.github/workflows/_runner-determinator.yml + with: + triggering_actor: ${{ github.triggering_actor }} + issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }} + curr_branch: ${{ github.head_ref || github.ref_name }} + curr_ref_type: ${{ github.ref_type }} + linux-focal-cuda12_1-py3_10-gcc9-periodic-dynamo-benchmarks-build: name: cuda12.1-py3.10-gcc9-sm86-periodic-dynamo-benchmarks uses: ./.github/workflows/_linux-build.yml + needs: get-label-type with: + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build-environment: linux-focal-cuda12.1-py3.10-gcc9-sm86 docker-image-name: pytorch-linux-focal-cuda12.1-cudnn9-py3-gcc9-inductor-benchmarks cuda-arch-list: '8.6' @@ -60,7 +71,9 @@ jobs: linux-focal-cuda12_1-py3_10-gcc9-inductor-build-gcp: name: cuda12.1-py3.10-gcc9-sm80 uses: ./.github/workflows/_linux-build.yml + needs: get-label-type with: + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build-environment: linux-focal-cuda12.1-py3.10-gcc9-sm80 docker-image-name: pytorch-linux-focal-cuda12.1-cudnn9-py3-gcc9-inductor-benchmarks cuda-arch-list: '8.0' diff --git a/.github/workflows/inductor-rocm.yml b/.github/workflows/inductor-rocm.yml index 2b1ca7bbd1b3e1..01789ea8942fe5 100644 --- a/.github/workflows/inductor-rocm.yml +++ b/.github/workflows/inductor-rocm.yml @@ -22,10 +22,21 @@ concurrency: permissions: read-all jobs: + get-label-type: + name: get-label-type + uses: ./.github/workflows/_runner-determinator.yml + with: + triggering_actor: ${{ github.triggering_actor }} + issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }} + curr_branch: ${{ github.head_ref || github.ref_name }} + curr_ref_type: ${{ github.ref_type }} + linux-focal-rocm6_1-py3_8-inductor-build: name: rocm6.1-py3.8-inductor uses: ./.github/workflows/_linux-build.yml + needs: get-label-type with: + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build-environment: linux-focal-rocm6.1-py3.8 docker-image-name: pytorch-linux-focal-rocm-n-py3 test-matrix: | diff --git a/.github/workflows/target-determination-indexer.yml b/.github/workflows/target-determination-indexer.yml index e8bf91c8d9ee91..373f464eae1390 100644 --- a/.github/workflows/target-determination-indexer.yml +++ b/.github/workflows/target-determination-indexer.yml @@ -10,8 +10,18 @@ permissions: contents: read jobs: + get-label-type: + name: get-label-type + uses: ./.github/workflows/_runner-determinator.yml + with: + triggering_actor: ${{ github.triggering_actor }} + issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }} + curr_branch: ${{ github.head_ref || github.ref_name }} + curr_ref_type: ${{ github.ref_type }} + index: - runs-on: linux.g5.4xlarge.nvidia.gpu # 1 GPU A10G 24GB each + needs: get-label-type + runs-on: "${{ needs.get-label-type.outputs.label-type }}linux.g5.4xlarge.nvidia.gpu" # 1 GPU A10G 24GB each environment: target-determinator-env steps: - name: Clone PyTorch diff --git a/.github/workflows/torchbench.yml b/.github/workflows/torchbench.yml index ac581496689977..4e2098e589238d 100644 --- a/.github/workflows/torchbench.yml +++ b/.github/workflows/torchbench.yml @@ -11,10 +11,21 @@ concurrency: cancel-in-progress: true jobs: + get-label-type: + name: get-label-type + uses: ./.github/workflows/_runner-determinator.yml + with: + triggering_actor: ${{ github.triggering_actor }} + issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }} + curr_branch: ${{ github.head_ref || github.ref_name }} + curr_ref_type: ${{ github.ref_type }} + linux-focal-cuda12_1-py3_10-gcc9-torchbench-build-gcp: name: cuda12.1-py3.10-gcc9-sm80 uses: ./.github/workflows/_linux-build.yml + needs: get-label-type with: + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build-environment: linux-focal-cuda12.1-py3.10-gcc9-sm80 docker-image-name: pytorch-linux-focal-cuda12.1-cudnn9-py3-gcc9-inductor-benchmarks cuda-arch-list: '8.0' From c15f5c81e4725ae23c71aa0491ee35ecd70458d0 Mon Sep 17 00:00:00 2001 From: Sam Larsen Date: Fri, 6 Sep 2024 10:22:38 -0700 Subject: [PATCH 0677/1018] Fix flaky TestCudaWrapper.test_randint_cuda_cuda_wrapper (#135370) Summary: This test is flaky when run after `test_dynamic_shapes_persistent_reduction_mixed_x_dim_cuda_cuda_wrapper` because the TestCase sets config options globally in its setUp() that stick around for subsequent tests. For test isolation, we use a contextlib.ExitStack pattern in other tests to patch the config options and restore them in tearDown(). Update all TestCases in `test/inductor/test_combo_kernels.py` to use that pattern. Test Plan: ``` python test/inductor/test_combo_kernels.py python test/inductor/test_cuda_cpp_wrapper.py TestCudaWrapper.test_dynamic_shapes_persistent_reduction_mixed_x_dim_cuda_cuda_wrapper TestCudaWrapper.test_randint_cuda_cuda_wrapper ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/135370 Approved by: https://github.com/jansel --- test/inductor/test_combo_kernels.py | 53 +++++++++++++++++++++++------ 1 file changed, 42 insertions(+), 11 deletions(-) diff --git a/test/inductor/test_combo_kernels.py b/test/inductor/test_combo_kernels.py index ddf2052703818f..0f997d01535721 100644 --- a/test/inductor/test_combo_kernels.py +++ b/test/inductor/test_combo_kernels.py @@ -1,5 +1,6 @@ # Owner(s): ["module: inductor"] +import contextlib import sys import unittest @@ -36,12 +37,20 @@ class ComboKernelTests(TestCase): def setUp(self): super().setUp() torch._inductor.metrics.reset() - torch._inductor.config.combo_kernels = True - torch._inductor.config.benchmark_combo_kernel = False + self._test_stack = contextlib.ExitStack() + self._test_stack.enter_context( + torch._inductor.config.patch( + { + "combo_kernels": True, + "benchmark_combo_kernel": False, + } + ) + ) def tearDown(self): - super().tearDown() + self._test_stack.close() torch._inductor.metrics.reset() + super().tearDown() @requires_cuda def test_activation_functions(self): @@ -157,12 +166,20 @@ class ComboKernelBenchmarkTests(TestCase): def setUp(self): super().setUp() torch._inductor.metrics.reset() - torch._inductor.config.combo_kernels = True - torch._inductor.config.benchmark_combo_kernel = True + self._test_stack = contextlib.ExitStack() + self._test_stack.enter_context( + torch._inductor.config.patch( + { + "combo_kernels": True, + "benchmark_combo_kernel": True, + } + ) + ) def tearDown(self): - super().tearDown() + self._test_stack.close() torch._inductor.metrics.reset() + super().tearDown() @requires_cuda def test_activation_benchmark(self): @@ -303,14 +320,28 @@ class ComboKernelDynamicShapesTests(TestCase): def setUp(self): super().setUp() torch._inductor.metrics.reset() - torch._inductor.config.combo_kernels = True - torch._inductor.config.benchmark_combo_kernel = True - torch._dynamo.config.automatic_dynamic_shapes = False - torch._dynamo.config.assume_static_by_default = False + self._test_stack = contextlib.ExitStack() + self._test_stack.enter_context( + torch._inductor.config.patch( + { + "combo_kernels": True, + "benchmark_combo_kernel": True, + } + ) + ) + self._test_stack.enter_context( + torch._dynamo.config.patch( + { + "automatic_dynamic_shapes": False, + "assume_static_by_default": False, + } + ) + ) def tearDown(self): - super().tearDown() + self._test_stack.close() torch._inductor.metrics.reset() + super().tearDown() @requires_cuda def test_dynamic_shapes_activations(self): From d8fcfc8d62cf13a92c8f38a619898708ee2a0464 Mon Sep 17 00:00:00 2001 From: Shivam Raikundalia Date: Tue, 10 Sep 2024 18:44:03 +0000 Subject: [PATCH 0678/1018] [Profiler] Harden Record Function Kwargs (#135365) Summary: In S445839, we had HTA break because of the "stream" parameter that was added to gpu traces. This brought up discussions regarding hardening our post processing of said inputs as to not break JSON schema as well as downstream tools. For this reason, this diff does the following. 1. Only allow int, double, bool and string values to be processed as kwinputs for JSON output. We can handle lists if needed in the future. 2. Make sure that any boolean is lowercase when a string so that the JSON does not break when parsing it 3. Force stream parameter to be an int Test Plan: Added unit tests to ensure that the list of requirements above is true for kwargs only. Differential Revision: D62304843 Pull Request resolved: https://github.com/pytorch/pytorch/pull/135365 Approved by: https://github.com/aaronenyeshi --- test/profiler/test_profiler.py | 33 ++++++++++++++++++++++++- torch/csrc/autograd/profiler_kineto.cpp | 14 +++++++++++ torch/csrc/profiler/util.cpp | 8 ++++++ 3 files changed, 54 insertions(+), 1 deletion(-) diff --git a/test/profiler/test_profiler.py b/test/profiler/test_profiler.py index 27818947ddc0bc..295f08fbd10a5c 100644 --- a/test/profiler/test_profiler.py +++ b/test/profiler/test_profiler.py @@ -1644,7 +1644,12 @@ def test_profiler_op_event_kwargs(self): cm = torch._C._profiler._RecordFunctionFast( "add_test_kwinputs", [x, y], - {"stream": 0, "grid": "lambda x : x + 1", "debug": 'debug"'}, + { + "stream": 0, + "grid": "lambda x : x + 1", + "debug": 'debug"', + "boolean": True, + }, ) for _ in range(4): with cm: @@ -1658,12 +1663,38 @@ def test_profiler_op_event_kwargs(self): ] for e in op_events: if e["name"] == "add_test_kwinputs": + print(e["args"]) args = e["args"] self.assertTrue("stream" in args) self.assertTrue("grid" in args) + self.assertTrue("boolean" in args) self.assertTrue(args["stream"] == 0) self.assertTrue(args["grid"] == "lambda x : x + 1") self.assertTrue(args["debug"] == "None") + self.assertTrue(args["boolean"]) + + with profile(record_shapes=True) as p1: + cm = torch._C._profiler._RecordFunctionFast( + "add_test_kwinputs", + [x, y], + {"stream": "test", "grid": [1, 2]}, + ) + for _ in range(4): + with cm: + x.add(y) + with TemporaryFileName(mode="w+") as fname1: + p1.export_chrome_trace(fname1) + with open(fname1) as f1: + j = json.load(f1) + op_events = [ + e for e in j["traceEvents"] if e.get("cat", "") == "cpu_op" + ] + for e in op_events: + if e["name"] == "add_test_kwinputs": + print(e["args"]) + args = e["args"] + self.assertTrue("stream" not in args) + self.assertTrue("grid" not in args) def test_is_profiler_enabled(self): self.assertFalse(torch.autograd.profiler._is_profiler_enabled) diff --git a/torch/csrc/autograd/profiler_kineto.cpp b/torch/csrc/autograd/profiler_kineto.cpp index bb5a7958f286bd..10d1c2e7ef786c 100644 --- a/torch/csrc/autograd/profiler_kineto.cpp +++ b/torch/csrc/autograd/profiler_kineto.cpp @@ -261,6 +261,20 @@ struct AddGenericMetadata : public MetadataBase { // Add metadata for kwinputs if exist for (const auto& [key, val] : op_event.kwinputs_) { + if (key == "stream" && !val.isInt()) { + LOG(WARNING) << "Inputted stream is not an int for op: " + << op_event.name_ << " skipping"; + continue; + } + + // Until needed, lets limit the kwargs to only ints, doubles, strings and + // bools + if (!val.isInt() && !val.isDouble() && !val.isString() && !val.isBool()) { + LOG(WARNING) << "Inputted kwarg: " << key + << " is not an int, double, string, or bool for op: " + << op_event.name_ << " skipping"; + continue; + } bool isString = val.isString(); addMetadata(key, ivalueToStr(val, isString)); } diff --git a/torch/csrc/profiler/util.cpp b/torch/csrc/profiler/util.cpp index d03ffa5c27ac20..0542309f81cf95 100644 --- a/torch/csrc/profiler/util.cpp +++ b/torch/csrc/profiler/util.cpp @@ -308,6 +308,14 @@ std::string ivalueToStr(const c10::IValue& val, bool isString) { } std::string mystr = ss.str(); + // For boolean the values that ivalue gives is "True" and "False" but + // json only takes "true" and "false" so we convert the string to lower case + if (val.isBool()) { + for (char& c : mystr) { + c = std::tolower(c); + } + } + // A double quote can cause issues with the chrome tracing so force // all inputs to not contain more than the 2 we add in this function int count = std::count(mystr.begin(), mystr.end(), '\"'); From 38a3a315b8209ef09db0eb4997ac2c48107bed2c Mon Sep 17 00:00:00 2001 From: Sam Larsen Date: Tue, 3 Sep 2024 15:11:52 -0700 Subject: [PATCH 0679/1018] [inductor] Enable subprocess parallel compile internally with killswitch (#132467) Differential Revision: [D60629630](https://our.internmc.facebook.com/intern/diff/D60629630) Pull Request resolved: https://github.com/pytorch/pytorch/pull/132467 Approved by: https://github.com/eellison --- torch/_inductor/async_compile.py | 17 +++++++++++++++-- torch/_inductor/config.py | 21 +++++++++++++++++---- 2 files changed, 32 insertions(+), 6 deletions(-) diff --git a/torch/_inductor/async_compile.py b/torch/_inductor/async_compile.py index 0794ceb3eed571..26759df1c59c17 100644 --- a/torch/_inductor/async_compile.py +++ b/torch/_inductor/async_compile.py @@ -118,6 +118,16 @@ def after_fork(): pass # register_at_fork does not exists on windows +def get_worker_start_method() -> str: + """ + Temporary for internal subprocess pool rollout. Assign config.worker_start_method + lazily and return it. TODO: remove after rollout. + """ + if config.worker_start_method is None: + config.worker_start_method = config.decide_worker_start_method() + return config.worker_start_method + + class AsyncCompile: def __init__(self) -> None: pass @@ -138,12 +148,12 @@ def _get_ready(): def process_pool() -> AnyPool: assert config.compile_threads > 1 pool: AnyPool - if config.worker_start_method == "subprocess": + if get_worker_start_method() == "subprocess": # Wrapper around ProcessPoolExecutor forks in a new process we control pool = SubprocPool(config.compile_threads) else: pre_fork_setup() - ctx = multiprocessing.get_context(config.worker_start_method) + ctx = multiprocessing.get_context(get_worker_start_method()) pool = ProcessPoolExecutor( config.compile_threads, mp_context=ctx, @@ -291,6 +301,9 @@ def wait(self, scope: Dict[str, Any]) -> None: or os.environ.get("TORCH_WARM_POOL", "1") != "1" # The subprocess pool is only used for the Triton backend or not has_triton_package() + # Skip for fbcode so we can query the worker_start_method lazily. + # TODO: remove once "subprocess" has rolled out internally. + or config.is_fbcode() ): pass else: diff --git a/torch/_inductor/config.py b/torch/_inductor/config.py index fb089e03dad2f2..966a48c29d06e3 100644 --- a/torch/_inductor/config.py +++ b/torch/_inductor/config.py @@ -514,9 +514,18 @@ def use_autoheuristic(name: str) -> bool: # The multiprocessing start method to use for inductor workers in the codecache. # Can be "subprocess" or "fork". def decide_worker_start_method() -> str: - start_method = os.environ.get( - "TORCHINDUCTOR_WORKER_START", "fork" if is_fbcode() else "subprocess" - ) + # TODO: For internal rollout, we use a killswitch to disable the "subprocess" + # start method. The justknob check should not be performed at import, however, + # so for fbcode, we assign worker_start_method to None below and call this method + # lazily in async_compile.py. Remove this after "subprocess" rollout completes. + if "TORCHINDUCTOR_WORKER_START" in os.environ: + start_method = os.environ["TORCHINDUCTOR_WORKER_START"] + elif is_fbcode() and not torch._utils_internal.justknobs_check( + "pytorch/inductor:subprocess_parallel_compile" + ): + start_method = "fork" + else: + start_method = "subprocess" assert start_method in ( "subprocess", "fork", @@ -524,7 +533,10 @@ def decide_worker_start_method() -> str: return start_method -worker_start_method = decide_worker_start_method() +# TODO: Set start method directly after internal rollout of "subprocess". +worker_start_method: Optional[str] = ( + None if is_fbcode() else decide_worker_start_method() +) # Flags to turn on all_reduce fusion. These 2 flags should be automaticaly turned # on by DDP and should not be set by the users. @@ -1228,6 +1240,7 @@ class trace: # uses absolute path "cuda.cutlass_dir", # not relevant + "worker_start_method", "compile_threads", ] From 06f6ef93de33c930a3df51655c81423905fc1b5d Mon Sep 17 00:00:00 2001 From: Dan Zimmerman Date: Tue, 10 Sep 2024 19:15:39 +0000 Subject: [PATCH 0680/1018] [amdsmi][torch] Update amdsmi API usages (#135504) Summary: In ROCm 6.2.0 there were API name changes-- we check if the new APIs exist and use them in this diff; see https://github.com/ROCm/amdsmi/commit/7b2463abe01739cc6e9d88438fefc84bdbe4a1d5 for the changes Test Plan: CI Differential Revision: D62325661 Pull Request resolved: https://github.com/pytorch/pytorch/pull/135504 Approved by: https://github.com/eqy, https://github.com/houseroad --- torch/cuda/__init__.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/torch/cuda/__init__.py b/torch/cuda/__init__.py index 5fb0038c23e145..cd25b2ab3de6fd 100644 --- a/torch/cuda/__init__.py +++ b/torch/cuda/__init__.py @@ -1122,7 +1122,8 @@ def _get_amdsmi_power_draw(device: Optional[Union[Device, int]] = None) -> int: def _get_amdsmi_clock_rate(device: Optional[Union[Device, int]] = None) -> int: handle = _get_amdsmi_handler(device) - return amdsmi.amdsmi_get_clock_info(handle, amdsmi.AmdSmiClkType.GFX)["cur_clk"] + clk_info = amdsmi.amdsmi_get_clock_info(handle, amdsmi.AmdSmiClkType.GFX) + return clk_info["clk"] if "clk" in clk_info else clk_info["cur_clk"] def memory_usage(device: Optional[Union[Device, int]] = None) -> int: From 1276e2fe8c5c7c333e0de4e77547bbcefc85690e Mon Sep 17 00:00:00 2001 From: Rachel Guo Date: Tue, 10 Sep 2024 19:24:58 +0000 Subject: [PATCH 0681/1018] [AOTI][Tooling][7/n] Add debug printing support for JIT inductor codegen path as well (#135285) Summary: 1. Add the debug printer call to a level lower for triton kernel python wrapper codegen path 2. Add `torch.save()` for jit inductor as well 3. This also fixes the issue introduced in D61949020 (at python wrapper code level for triton kernel not printing) Test Plan: ``` AOT_INDUCTOR_DEBUG_INTERMEDIATE_VALUE_PRINTER=1 TORCHINDUCTOR_FORCE_DISABLE_CACHES=1 TORCHINDUCTOR_ABI_COMPATIBLE=1 TORCH_COMPILE_DEBUG=1 TORCH_LOGS="+graph, inductor, +schedule, output_code" buck2 run -c fbcode.enable_gpu_sections=true -c fbcode.nvcc_arch=h100 @//mode/opt fbcode//caffe2/test/inductor:test_aot_inductor -- -r test_addmm_abi_compatible_cuda ``` Differential Revision: D62272588 Pull Request resolved: https://github.com/pytorch/pytorch/pull/135285 Approved by: https://github.com/chenyang78 --- torch/_inductor/codegen/debug_utils.py | 24 ++++++++++++++++--- torch/_inductor/codegen/wrapper.py | 10 ++++++-- .../csrc/inductor/aoti_torch/shim_common.cpp | 2 +- 3 files changed, 30 insertions(+), 6 deletions(-) diff --git a/torch/_inductor/codegen/debug_utils.py b/torch/_inductor/codegen/debug_utils.py index 7b3dcd8eea2cf5..f129e3f31590be 100644 --- a/torch/_inductor/codegen/debug_utils.py +++ b/torch/_inductor/codegen/debug_utils.py @@ -3,6 +3,7 @@ import functools import logging +import os from enum import Enum from typing import List, Optional @@ -136,8 +137,23 @@ def codegen_intermediate_tensor_value_save( # TODO: add non-abi compatible mode debug printing info pass else: - # currently, not cpp wrapper codegen mode not supported. - pass + cwd = os.getcwd() + saved_dir = cwd + "/tmp/jit_inductor/" + if not os.path.exists(saved_dir): + log.info( + "Creating directory to save inductor intermediate tensor values." + ) + os.makedirs(saved_dir) + # Save the model to the directory + saved_path = saved_dir + f"{launch_prefix}_{kernel_name}_{arg}.pt" + log.info( + "Saved intermediate tensor %s for %s to %s", + arg, + kernel_name, + saved_path, + ) + line = f"torch.save({arg}, '{saved_path}')" + V.graph.wrapper_code.writeline(line) def codegen_intermediate_tensor_value_print( self, @@ -171,5 +187,7 @@ def codegen_intermediate_tensor_value_print( # TODO: add non-abi compatible mode debug printing info pass else: - line = f"print('{launch_prefix} - {kernel_name} - {arg}', {arg})" + line = ( + f"print('inductor: {launch_prefix} - {kernel_name} - {arg}', {arg})" + ) V.graph.wrapper_code.writeline(line) diff --git a/torch/_inductor/codegen/wrapper.py b/torch/_inductor/codegen/wrapper.py index f3a6d07647627f..0586076c948fb6 100644 --- a/torch/_inductor/codegen/wrapper.py +++ b/torch/_inductor/codegen/wrapper.py @@ -1679,9 +1679,15 @@ def generate_kernel_call( if grid_extra_kwargs: grid_str = f"{grid_str}, {grid_extra_kwargs}" grid_str = f"{grid_fn}({grid_str})" - self.writeline( - f"{kernel_name}.run({call_args_str}, grid={grid_str}, stream={stream_name})" + # add debug printer code for triton kernel calls at (jit) inductor level + debug_printer_manager = V.graph.wrapper_code.debug_printer + debug_printer_manager.set_printer_args( + call_args, kernel_name, arg_types, None ) + with debug_printer_manager: + self.writeline( + f"{kernel_name}.run({call_args_str}, grid={grid_str}, stream={stream_name})" + ) if ( config.triton.autotune_at_compile_time and kernel_name not in self.kernel_autotune_names diff --git a/torch/csrc/inductor/aoti_torch/shim_common.cpp b/torch/csrc/inductor/aoti_torch/shim_common.cpp index 434da9ad7d9a32..964a976775e6e8 100644 --- a/torch/csrc/inductor/aoti_torch/shim_common.cpp +++ b/torch/csrc/inductor/aoti_torch/shim_common.cpp @@ -1025,7 +1025,7 @@ AOTI_TORCH_EXPORT void aoti_torch_save_tensor_handle( fout.write(bytes.data(), bytes.size()); fout.close(); - std::cout << "aoti_torch_save_tensor_handle: Saved tensor to: " + std::cout << "aoti_torch_save_tensor_handle: Saved tensor to " << tensor_filepath_to_save << std::endl; #endif // !defined(C10_MOBILE) } From 118f80ee3e7392251e6fdee45e04099e5b78845a Mon Sep 17 00:00:00 2001 From: Andrew Gu Date: Mon, 9 Sep 2024 13:41:09 -0700 Subject: [PATCH 0682/1018] [FSDP2] Added `_set_unshard_async_op` (#135523) This PR adds a private API `_set_unshard_async_op` that allows for running pre-forward and pre-backward all-gathers using the `async_op=True` path so that all-gather allocations happen in the default stream to avoid inter-stream fragmentation. If using this option, forward requires explicit prefetching e.g. via the `unshard(async_op=True)` API for overlap. fp32 -> bf16 casts and the all-gather copy-in will not overlap with compute. Differential Revision: [D62401551](https://our.internmc.facebook.com/intern/diff/D62401551) Pull Request resolved: https://github.com/pytorch/pytorch/pull/135523 Approved by: https://github.com/weifengpy --- .../fsdp/test_fully_shard_training.py | 26 +++++++++++++++++++ .../_composable/fsdp/_fsdp_param_group.py | 25 +++++++++++++----- .../_composable/fsdp/fully_shard.py | 19 ++++++++++++++ 3 files changed, 63 insertions(+), 7 deletions(-) diff --git a/test/distributed/_composable/fsdp/test_fully_shard_training.py b/test/distributed/_composable/fsdp/test_fully_shard_training.py index f8bad4c5d5fa3a..ab52cb925709af 100644 --- a/test/distributed/_composable/fsdp/test_fully_shard_training.py +++ b/test/distributed/_composable/fsdp/test_fully_shard_training.py @@ -285,6 +285,7 @@ def test_train_parity_multi_group(self): "delay_before_all_gather": [False, True], "delay_before_reduce_scatter": [False, True], "delay_before_optim": [False, True], + "unshard_async_op": [False], }, self._test_train_parity_multi_group, ) @@ -307,6 +308,28 @@ def test_train_parity_multi_group_cpu_offload_eager(self): "delay_before_all_gather": [False, True], "delay_before_reduce_scatter": [False, True], "delay_before_optim": [False, True], + "unshard_async_op": [False], + }, + self._test_train_parity_multi_group, + ) + + @skip_if_lt_x_gpu(2) + @test_compiled_fsdp(compile_compute_on_module=Transformer) + def test_train_parity_multi_group_unshard_async_op(self): + """ + Tests train parity against DDP when using multiple parameter groups for + communication and setting ``unshard_async_op=True``. + """ + self.run_subtests( + { + "reshard_after_forward": [True], + "device_type": ["cuda"], + "offload_policy": [OffloadPolicy()], + "delay_after_forward": [False, True], + "delay_before_all_gather": [False, True], + "delay_before_reduce_scatter": [False, True], + "delay_before_optim": [False, True], + "unshard_async_op": [True], }, self._test_train_parity_multi_group, ) @@ -320,6 +343,7 @@ def _test_train_parity_multi_group( delay_before_all_gather: bool, delay_before_reduce_scatter: bool, delay_before_optim: bool, + unshard_async_op: bool, ): # Only test individual delays or all four delays to save test time if ( @@ -359,6 +383,8 @@ def _test_train_parity_multi_group( if isinstance(module, TransformerBlock): fully_shard_fn(module) fully_shard_fn(model) + if unshard_async_op: + model._set_unshard_async_op(unshard_async_op) optim = torch.optim.Adam(model.parameters(), lr=1e-2) delay_in_ms = 100 diff --git a/torch/distributed/_composable/fsdp/_fsdp_param_group.py b/torch/distributed/_composable/fsdp/_fsdp_param_group.py index 3cb4a31c28d2f3..ae8a513f05af05 100644 --- a/torch/distributed/_composable/fsdp/_fsdp_param_group.py +++ b/torch/distributed/_composable/fsdp/_fsdp_param_group.py @@ -73,9 +73,12 @@ def lazy_init(self): self.post_forward_order: List[FSDPParamGroup] = [] # will cause ref cycles def get_all_gather_streams( - self, training_state: TrainingState + self, async_op: bool, training_state: TrainingState ) -> Tuple[torch.cuda.Stream, torch.cuda.Stream]: - if training_state in (TrainingState.FORWARD, TrainingState.PRE_BACKWARD): + if not async_op and training_state in ( + TrainingState.FORWARD, + TrainingState.PRE_BACKWARD, + ): # Use separate streams for implicit prefetching return self.all_gather_copy_in_stream, self.all_gather_stream current_stream = torch.cuda.current_stream() @@ -155,6 +158,10 @@ def __init__( # Optional custom reduce-scatter reduce op (e.g. to divide by a # factor other than the shard world size) self.reduce_scatter_reduce_op: Optional[dist.ReduceOp] = None + # `async_op` arg used for pre-forward/pre-backward unshard; can be + # overridden to only do explicit prefetching and avoid inter-stream + # fragmentation from using separate unshard streams + self.unshard_async_op: bool = False # - CUDA events for stream synchronization # Holds the all-gather output buffer, sync objects, and metadata @@ -234,7 +241,7 @@ def unshard(self, async_op: bool = False): self.fsdp_params, self._all_gather_process_group, async_op, - *self.comm_ctx.get_all_gather_streams(self._training_state), + *self.comm_ctx.get_all_gather_streams(async_op, self._training_state), self.device, ) @@ -249,6 +256,7 @@ def wait_for_unshard(self): """ if not self._all_gather_result: return # no preceding unshard + async_op = self._all_gather_result.all_gather_work is not None if self._training_state == TrainingState.FORWARD: # implicit prefetch if prev_all_gather_state := self.comm_ctx.all_gather_state: self._wait_all_gather_streams_on_event(prev_all_gather_state.event) @@ -264,7 +272,9 @@ def wait_for_unshard(self): self._to_unsharded() all_gather_copy_out_event = torch.cuda.Event() all_gather_copy_out_event.record() - if self._training_state == TrainingState.FORWARD: + if not async_op and self._training_state == TrainingState.FORWARD: + # Defer free to allow for overlap of this copy-out with next + # all-gather collective self.comm_ctx.all_gather_state = AllGatherState( self._all_gather_result, all_gather_copy_out_event ) @@ -297,7 +307,7 @@ def pre_forward( logger.debug("%s", self._with_fqn("FSDP::pre_forward")) with record_function(self._with_fqn("FSDP::pre_forward")): self._training_state = TrainingState.FORWARD - self.unshard() + self.unshard(self.unshard_async_op) self.wait_for_unshard() args, kwargs = self._register_post_backward_hook(args, kwargs) return args, kwargs @@ -325,7 +335,7 @@ def pre_backward(self, default_prefetch: bool, *unused: Any): logger.debug("%s", self._with_fqn("FSDP::pre_backward")) with record_function(self._with_fqn("FSDP::pre_backward")): self._training_state = TrainingState.PRE_BACKWARD - self.unshard() # no-op if prefetched + self.unshard(self.unshard_async_op) # no-op if prefetched self.wait_for_unshard() if default_prefetch and not ca.compiled_autograd_enabled: self._backward_prefetch() @@ -430,7 +440,8 @@ def _prefetch_unshard( with record_function( f"FSDP::{pass_type}_prefetch for {target_fqn}" ), target_fsdp_param_group.use_training_state(training_state): - target_fsdp_param_group.unshard() + async_op = target_fsdp_param_group.unshard_async_op + target_fsdp_param_group.unshard(async_op) # Utilities # def _to_sharded(self): diff --git a/torch/distributed/_composable/fsdp/fully_shard.py b/torch/distributed/_composable/fsdp/fully_shard.py index 5da4f83f9e5594..49c3da8fbfd02a 100644 --- a/torch/distributed/_composable/fsdp/fully_shard.py +++ b/torch/distributed/_composable/fsdp/fully_shard.py @@ -358,6 +358,25 @@ def set_reduce_scatter_divide_factor(self, factor: float) -> None: reduce_op = torch.distributed._make_nccl_premul_sum(mul_factor) fsdp_param_group.reduce_scatter_reduce_op = reduce_op + def _set_unshard_async_op(self, async_op: bool): + """ + Sets whether to use ``async_op=True`` or ``False`` for the pre-forward + and pre-backward unshard op. This defaults to ``False`` but can be set + to ``True`` with this method. + + Setting this to ``True`` allows the all-gather allocations to happen in + the default stream, avoiding inter-stream memory fragmentation. + However, you must use explicit prefetching (e.g. via :meth:`unshard`) + in forward to still get overlap, and the pre-all-gather ops like dtype + casting and copy-in will not overlap with compute. + """ + self_module = cast(nn.Module, self) + for module in self_module.modules(): + if isinstance(module, FSDPModule): + state = module._get_fsdp_state() + if fsdp_param_group := state._fsdp_param_group: + fsdp_param_group.unshard_async_op = async_op + def _get_fsdp_state(self) -> FSDPState: if (state := _get_module_fsdp_state(cast(nn.Module, self))) is None: raise AssertionError(f"No FSDP state found on {self}") From edf329a80a9df0a55db802e9fb077348a616cb4a Mon Sep 17 00:00:00 2001 From: Catherine Lee Date: Tue, 10 Sep 2024 19:45:11 +0000 Subject: [PATCH 0683/1018] [CI] Increase sharding for jobs that are timing out (#135582) Increase sharding for * slow grad check * slow cuda tests slow / linux-focal-cuda12.1-py3.10-gcc9-sm86 / test * avx Pull Request resolved: https://github.com/pytorch/pytorch/pull/135582 Approved by: https://github.com/huydhn, https://github.com/malfet --- .github/workflows/periodic.yml | 18 ++++++++++++------ .github/workflows/slow.yml | 19 +++++++++++-------- .github/workflows/trunk.yml | 6 ++++-- 3 files changed, 27 insertions(+), 16 deletions(-) diff --git a/.github/workflows/periodic.yml b/.github/workflows/periodic.yml index c5f6d8c853e112..57a16a5d3b9fec 100644 --- a/.github/workflows/periodic.yml +++ b/.github/workflows/periodic.yml @@ -57,8 +57,10 @@ jobs: docker-image-name: pytorch-linux-focal-cuda12.1-cudnn9-py3-gcc9 test-matrix: | { include: [ - { config: "nogpu_AVX512", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" }, - { config: "nogpu_NO_AVX2", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" }, + { config: "nogpu_AVX512", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" }, + { config: "nogpu_AVX512", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" }, + { config: "nogpu_NO_AVX2", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" }, + { config: "nogpu_NO_AVX2", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" }, { config: "jit_legacy", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge.nvidia.gpu" }, ]} linux-focal-cuda12_1-py3_10-gcc9-test: @@ -87,8 +89,10 @@ jobs: { config: "default", shard: 3, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge.nvidia.gpu" }, { config: "default", shard: 4, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge.nvidia.gpu" }, { config: "default", shard: 5, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge.nvidia.gpu" }, - { config: "nogpu_AVX512", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" }, - { config: "nogpu_NO_AVX2", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" }, + { config: "nogpu_AVX512", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" }, + { config: "nogpu_AVX512", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" }, + { config: "nogpu_NO_AVX2", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" }, + { config: "nogpu_NO_AVX2", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" }, { config: "jit_legacy", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge.nvidia.gpu" }, ]} @@ -333,8 +337,10 @@ jobs: docker-image-name: pytorch-linux-focal-cuda12.1-cudnn9-py3-gcc9 test-matrix: | { include: [ - { config: "nogpu_AVX512", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" }, - { config: "nogpu_NO_AVX2", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" }, + { config: "nogpu_AVX512", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" }, + { config: "nogpu_AVX512", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" }, + { config: "nogpu_NO_AVX2", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" }, + { config: "nogpu_NO_AVX2", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" }, { config: "jit_legacy", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge.nvidia.gpu" }, ]} diff --git a/.github/workflows/slow.yml b/.github/workflows/slow.yml index 3b253bfbb89f6f..ffad7e8ca3637a 100644 --- a/.github/workflows/slow.yml +++ b/.github/workflows/slow.yml @@ -56,12 +56,14 @@ jobs: cuda-arch-list: 8.6 test-matrix: | { include: [ - { config: "default", shard: 1, num_shards: 6, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g5.4xlarge.nvidia.gpu" }, - { config: "default", shard: 2, num_shards: 6, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g5.4xlarge.nvidia.gpu" }, - { config: "default", shard: 3, num_shards: 6, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g5.4xlarge.nvidia.gpu" }, - { config: "default", shard: 4, num_shards: 6, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g5.4xlarge.nvidia.gpu" }, - { config: "default", shard: 5, num_shards: 6, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g5.4xlarge.nvidia.gpu" }, - { config: "default", shard: 6, num_shards: 6, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g5.4xlarge.nvidia.gpu" }, + { config: "default", shard: 1, num_shards: 8, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g5.4xlarge.nvidia.gpu" }, + { config: "default", shard: 2, num_shards: 8, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g5.4xlarge.nvidia.gpu" }, + { config: "default", shard: 3, num_shards: 8, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g5.4xlarge.nvidia.gpu" }, + { config: "default", shard: 4, num_shards: 8, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g5.4xlarge.nvidia.gpu" }, + { config: "default", shard: 5, num_shards: 8, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g5.4xlarge.nvidia.gpu" }, + { config: "default", shard: 6, num_shards: 8, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g5.4xlarge.nvidia.gpu" }, + { config: "default", shard: 7, num_shards: 8, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g5.4xlarge.nvidia.gpu" }, + { config: "default", shard: 8, num_shards: 8, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g5.4xlarge.nvidia.gpu" }, ]} linux-focal-cuda12_1-py3-gcc9-slow-gradcheck-test: @@ -87,8 +89,9 @@ jobs: cuda-arch-list: 8.6 test-matrix: | { include: [ - { config: "slow", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g5.4xlarge.nvidia.gpu" }, - { config: "slow", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g5.4xlarge.nvidia.gpu" }, + { config: "slow", shard: 1, num_shards: 3, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g5.4xlarge.nvidia.gpu" }, + { config: "slow", shard: 2, num_shards: 3, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g5.4xlarge.nvidia.gpu" }, + { config: "slow", shard: 3, num_shards: 3, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g5.4xlarge.nvidia.gpu" }, ]} linux-focal-cuda12_1-py3_10-gcc9-sm86-test: diff --git a/.github/workflows/trunk.yml b/.github/workflows/trunk.yml index 8696ba99957258..2b31568aae5733 100644 --- a/.github/workflows/trunk.yml +++ b/.github/workflows/trunk.yml @@ -266,8 +266,10 @@ jobs: docker-image-name: pytorch-linux-focal-cuda12.4-cudnn9-py3-gcc9 test-matrix: | { include: [ - { config: "nogpu_AVX512", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" }, - { config: "nogpu_NO_AVX2", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" }, + { config: "nogpu_AVX512", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" }, + { config: "nogpu_AVX512", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" }, + { config: "nogpu_NO_AVX2", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" }, + { config: "nogpu_NO_AVX2", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" }, { config: "jit_legacy", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge.nvidia.gpu" }, { config: "default", shard: 1, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge.nvidia.gpu" }, { config: "default", shard: 2, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge.nvidia.gpu" }, From a135766db26a7c0d94a0a3e82ed466e041d588b3 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Tue, 10 Sep 2024 19:51:16 +0000 Subject: [PATCH 0684/1018] Revert "[AOTI] Fix assert_function call in cpu autotune template (#135086)" This reverts commit 16c3b8f87cfa9cb5acee8104820baa389e7ee2bd. Reverted https://github.com/pytorch/pytorch/pull/135086 on behalf of https://github.com/izaitsevfb due to breaks internal tests, see D62405818 ([comment](https://github.com/pytorch/pytorch/pull/135086#issuecomment-2341889428)) --- test/inductor/test_aot_inductor.py | 4 +--- test/inductor/test_cpu_cpp_wrapper.py | 13 +++++++------ torch/_inductor/codegen/cpp.py | 4 +++- torch/_inductor/codegen/cpp_micro_gemm.py | 10 +++++----- torch/_inductor/codegen/cpp_template.py | 9 +++++---- torch/_inductor/graph.py | 3 ++- 6 files changed, 23 insertions(+), 20 deletions(-) diff --git a/test/inductor/test_aot_inductor.py b/test/inductor/test_aot_inductor.py index 4b270a3bb21bda..42d771f7bb6fff 100644 --- a/test/inductor/test_aot_inductor.py +++ b/test/inductor/test_aot_inductor.py @@ -3260,9 +3260,7 @@ def forward(self, values, offsets): model, example_inputs_list, dynamic_shapes=dynamic_shapes ) - # max_autotune is disabled due to https://github.com/pytorch/pytorch/issues/135106 - # @common_utils.parametrize("max_autotune", [False, True]) - @common_utils.parametrize("max_autotune", [False]) + @common_utils.parametrize("max_autotune", [False, True]) def test_misc_1(self, max_autotune): if self.device == "cpu" and IS_MACOS and max_autotune: raise unittest.SkipTest("max_autotune not supported on macos") diff --git a/test/inductor/test_cpu_cpp_wrapper.py b/test/inductor/test_cpu_cpp_wrapper.py index 8ea3e55468019f..d02ca75cfa5c44 100644 --- a/test/inductor/test_cpu_cpp_wrapper.py +++ b/test/inductor/test_cpu_cpp_wrapper.py @@ -87,7 +87,13 @@ class DynamicShapesCppWrapperCpuTests(InductorTestCase): } ) if config.abi_compatible: - xfail_list = [] + xfail_list = [ + *[ + func + for func in dir(test_cpu_select_algorithm.TestSelectAlgorithmCPU()) + if func.startswith("test_linear_with_pointwise") + ], + ] for test_name in xfail_list: test_failures_cpp_wrapper[test_name] = test_torchinductor.TestFailure( ("cpp_wrapper",), is_skip=False @@ -97,11 +103,6 @@ class DynamicShapesCppWrapperCpuTests(InductorTestCase): ] = test_torchinductor.TestFailure(("cpp_wrapper",), is_skip=False) skip_list = [ "test_multihead_attention_cpu", - *[ - func - for func in dir(test_cpu_select_algorithm.TestSelectAlgorithmCPU()) - if func.startswith("test_linear_with_pointwise") - ], ] for test_name in skip_list: test_failures_cpp_wrapper[test_name] = test_torchinductor.TestFailure( diff --git a/torch/_inductor/codegen/cpp.py b/torch/_inductor/codegen/cpp.py index 940bc5f73cd838..e1d54762d9e8fd 100644 --- a/torch/_inductor/codegen/cpp.py +++ b/torch/_inductor/codegen/cpp.py @@ -2146,7 +2146,9 @@ def codegen_loops(self, code, worksharing): @property def assert_function(self) -> str: - if config.abi_compatible: + if V.graph.aot_mode: + # TODO: Using AOTI_TORCH_CHECK is causing performance drop for some models + # compared with JIT Inductor which uses TORCH_CHECK return "AOTI_TORCH_CHECK" else: return "TORCH_CHECK" diff --git a/torch/_inductor/codegen/cpp_micro_gemm.py b/torch/_inductor/codegen/cpp_micro_gemm.py index 2f0c3e78f56791..cb26c48fce53c6 100644 --- a/torch/_inductor/codegen/cpp_micro_gemm.py +++ b/torch/_inductor/codegen/cpp_micro_gemm.py @@ -332,8 +332,8 @@ class CppMicroGemmFP32Vec(CppMicroGemm): TEMPLATE_ENTRY = r""" {{declare_kernel}} { - {{kernel.assert_function}}(N % {{block_n}} == 0, "N dimension must be multiple of {{block_n}}"); - {{kernel.assert_function}}(K % {{block_k}} == 0, "K dimension must be multiple of {{block_k}}"); + TORCH_CHECK(N % {{block_n}} == 0, "N dimension must be multiple of {{block_n}}"); + TORCH_CHECK(K % {{block_k}} == 0, "K dimension must be multiple of {{block_k}}"); // TODO(jgong5): loop unroll for M and N for (int64_t m = 0; m < M; m += {{block_m}}) { int64_t block_m = std::min(M - m, {{block_m}}); @@ -364,7 +364,7 @@ class CppMicroGemmFP32Vec(CppMicroGemm): break; {%- endfor %} default: - {{kernel.assert_function}}(false, "Unsupported block_m: {{block_m}}"); + {{kernel.assert_function}}(false, "Unsupported block_m: ", block_m); } } } @@ -509,8 +509,8 @@ class CppMicroGemmAMX(CppMicroGemm): TEMPLATE_ENTRY = r""" {{declare_kernel}} { - {{kernel.assert_function}}(N % {{block_n}} == 0, "N dimension must be multiple of {{block_n}}"); - {{kernel.assert_function}}(K % 2 == 0, "K dimension must be multiple of 2"); + TORCH_CHECK(N % {{block_n}} == 0, "N dimension must be multiple of {{block_n}}"); + TORCH_CHECK(K % 2 == 0, "K dimension must be multiple of 2"); // TODO(jgong5): loop unroll for M and N for (int64_t m = 0; m < M; m += {{block_m}}) { int64_t block_m = std::min(M - m, {{block_m}}); diff --git a/torch/_inductor/codegen/cpp_template.py b/torch/_inductor/codegen/cpp_template.py index a237924b9182d4..2bce16b2ad1538 100644 --- a/torch/_inductor/codegen/cpp_template.py +++ b/torch/_inductor/codegen/cpp_template.py @@ -111,10 +111,11 @@ def make_kernel_render( def header(self) -> IndentedBuffer: res = IndentedBuffer() res.writeline(codecache.cpp_prefix()) - # TODO: add c10::ForcedUnroll test to test_aoti_abi_check - res.splice("""#include """) - if config.abi_compatible: - res.splice("""#include """) + res.splice( + """ + #include "c10/util/Unroll.h" + """ + ) enable_kernel_profile = config.cpp.enable_kernel_profile and sys.platform in [ "linux", "win32", diff --git a/torch/_inductor/graph.py b/torch/_inductor/graph.py index 26a897bf806ed4..45ca55c7ab0111 100644 --- a/torch/_inductor/graph.py +++ b/torch/_inductor/graph.py @@ -854,6 +854,7 @@ def get_original_value_of_constant(self, name: str) -> torch.Tensor: def allocate_non_dup_const_name( self, name: Optional[str], data: Union[Tensor] ) -> str: + orig_name = name if not config.aot_inductor.use_runtime_constant_folding: for constant_name, value in self.constants.items(): if ( @@ -870,7 +871,7 @@ def allocate_non_dup_const_name( if name is None: name = f"constant{len(self.constants)}" - orig_name = name + assert name is not None if name[0].isdigit(): name = f"constant_{name}" name = self.qualify_name(name) From 8b12e696619ac709376198a493d23613dbe468c0 Mon Sep 17 00:00:00 2001 From: Yiming Zhou Date: Tue, 10 Sep 2024 20:15:02 +0000 Subject: [PATCH 0685/1018] [export] fix re-export custom metadata (#135282) Fixes #134778 When a model is exported and debug handles are added to the "custom" field of non-placeholder and non-output nodes in the graph, re-exporting it will change the metadata of placeholder nodes (the "custom" field will be added or copied to these nodes, depending whether `ExportedProgram` or `ExportedProgram.module()` is passed to `generate_numeric_debug_handle()`). This occurs because when we re-export the model, `placeholder` nodes are unlifted to `get_attr` nodes. These nodes remain as `get_attr` after being exported to `gm_torch_level`. Their metadata are modified [here](https://github.com/pytorch/pytorch/blob/main/torch/export/_trace.py#L1347) based on `params_buffers_to_node_meta` which is collected [here](https://github.com/pytorch/pytorch/blob/main/torch/export/_trace.py#L1312). Pull Request resolved: https://github.com/pytorch/pytorch/pull/135282 Approved by: https://github.com/jerryzh168, https://github.com/zhxchen17, https://github.com/tugsbayasgalan --- test/quantization/pt2e/test_numeric_debugger.py | 9 +++------ torch/_export/utils.py | 12 ------------ 2 files changed, 3 insertions(+), 18 deletions(-) diff --git a/test/quantization/pt2e/test_numeric_debugger.py b/test/quantization/pt2e/test_numeric_debugger.py index e6cc4dc3252f44..85b5e9137bdb5c 100644 --- a/test/quantization/pt2e/test_numeric_debugger.py +++ b/test/quantization/pt2e/test_numeric_debugger.py @@ -20,6 +20,7 @@ get_symmetric_quantization_config, XNNPACKQuantizer, ) +from torch.export import export_for_training from torch.testing._internal.common_quantization import TestHelperModules from torch.testing._internal.common_utils import IS_WINDOWS, TestCase @@ -116,22 +117,18 @@ def test_deepcopy_preserve_handle(self): self.assertEqual(debug_handle_map, debug_handle_map_ref) - @unittest.skip("All nodes' meta are preserved but get_attr nodes' meta are wrong.") def test_re_export_preserve_handle(self): m = TestHelperModules.Conv2dThenConv1d() example_inputs = m.example_inputs() - m = capture_pre_autograd_graph(m, example_inputs) + m = export_for_training(m, example_inputs).module() generate_numeric_debug_handle(m) debug_handle_map_ref = _extract_debug_handles(m) - m_export = capture_pre_autograd_graph(m, example_inputs) + m_export = export_for_training(m, example_inputs).module() debug_handle_map = _extract_debug_handles(m_export) self.assertEqual(debug_handle_map, debug_handle_map_ref) - @unittest.skip( - "All nodes' meta are preserved but the first arg for the first node seems to be dropped" - ) def test_run_decompositions_preserve_handle(self): m = TestHelperModules.Conv2dThenConv1d() example_inputs = m.example_inputs() diff --git a/torch/_export/utils.py b/torch/_export/utils.py index e085e18b68a207..7755099dfdb22a 100644 --- a/torch/_export/utils.py +++ b/torch/_export/utils.py @@ -130,18 +130,6 @@ def _getattr(model: torch.fx.GraphModule, attr_name: str): if not isinstance(submodule, torch.fx.GraphModule): params_buffers_to_node_meta[target] = meta - # If the call_function uses param as input, we also need to update params' meta - # with this call_function node's meta. - # This is basically the same flow as torch.fx.traceback.preserve_meta() - if node.op == "call_function" and not isinstance( - node.target, torch._ops.HigherOrderOperator - ): - for arg in node._input_nodes: - if arg.op == "get_attr": - for entry in torch.fx.proxy._COPY_META_FIELDS: - if entry in meta: - params_buffers_to_node_meta[arg.target][entry] = meta[entry] - return params_buffers_to_node_meta From 5a2ce15d955594ffe4e2fb603f9319edb07aa052 Mon Sep 17 00:00:00 2001 From: Shunting Zhang Date: Mon, 9 Sep 2024 13:57:01 -0700 Subject: [PATCH 0686/1018] [inductor] fix a device sync issue for benchmarking fusion (#135531) Fix https://github.com/pytorch/pytorch/issues/134768 . When we benchmark the latency for a fused node set, we do benchmarking twice: 1. benchmark the latency of the kernel including cloning mutated args 2. benchmark the latency of cloning mutated args without running the kernel We subtract result 2 from result 1 to get the latency of the kernel itself. But when the tensors are not on the cuda device 0, we get equal number for result 1 and result 2 no matter how much work the kernel does. The root cause is, in `triton.testing.do_bench` the `torch.cuda.synchronize` call sync the current cuda device (which is device 0 if it's not overriden). But since the tensors and kernels are located on another device, the sync actually does nothing (unless there happens to be other kernels on the device 0). The fix is to set the correct current device in our benchmarking code. Pull Request resolved: https://github.com/pytorch/pytorch/pull/135531 Approved by: https://github.com/jansel --- test/inductor/test_benchmark_fusion.py | 32 ++++++++++++++++++++++++++ torch/_inductor/codegen/triton.py | 4 +++- 2 files changed, 35 insertions(+), 1 deletion(-) diff --git a/test/inductor/test_benchmark_fusion.py b/test/inductor/test_benchmark_fusion.py index 9cd98735f7c10e..706c2d9ae76546 100644 --- a/test/inductor/test_benchmark_fusion.py +++ b/test/inductor/test_benchmark_fusion.py @@ -4,6 +4,7 @@ import sys import torch +from torch._inductor.codegen.triton import TritonScheduling from torch._inductor.test_case import TestCase as InductorTestCase from torch._inductor.test_operators import realize from torch._inductor.utils import fresh_inductor_cache, is_big_gpu, run_and_get_code @@ -184,6 +185,37 @@ class BenchmarkFusionCudaTest(TestCase): copy_tests(BenchmarkFusionTestTemplate, BenchmarkFusionCudaTest, "cuda") + class BenchmarkingTest(TestCase): + @unittest.skipIf( + torch.cuda.device_count() < 2, "The test need at least 2 devices" + ) + def test_benchmark_on_non_zero_device(self): + hit_count = 0 + with torch.cuda.device("cuda:0"): + + @torch.compile + def relu(x): + return realize(x.relu()) + x + + x = torch.randn(int(16e6), device="cuda:1") + + orig_benchmark_fused_nodes = TritonScheduling.benchmark_fused_nodes + + def mock_benchmark_fused_nodes(*args, **kwargs): + nonlocal hit_count + hit_count += 1 + ms, path = orig_benchmark_fused_nodes(*args, **kwargs) + self.assertTrue(ms > 0) + return ms, path + + with unittest.mock.patch.object( + TritonScheduling, + "benchmark_fused_nodes", + mock_benchmark_fused_nodes, + ): + relu(x) + self.assertTrue(hit_count > 0) + class BenchmarkMultiTemplateFusionCudaTest(InductorTestCase): @classmethod def setUpClass(cls): diff --git a/torch/_inductor/codegen/triton.py b/torch/_inductor/codegen/triton.py index 2d4b2b75ae263e..ea68e68b6ef062 100644 --- a/torch/_inductor/codegen/triton.py +++ b/torch/_inductor/codegen/triton.py @@ -3066,7 +3066,9 @@ def define_kernel(self, src_code, node_schedule, kernel): return kernel_name def benchmark_fused_nodes(self, nodes): - with preserve_rng_state(): + with preserve_rng_state(), torch.cuda.device( + self.scheduler.get_current_device_or_throw() + ): src_code = self.generate_kernel_code_from_nodes( nodes, benchmark_kernel=True ) From 55c0b60a12c92a16319577e3b25c8693ec849e39 Mon Sep 17 00:00:00 2001 From: Shunting Zhang Date: Mon, 9 Sep 2024 13:58:51 -0700 Subject: [PATCH 0687/1018] [ez][inductor] don't benchmark cloning if there are no mutated args (#135533) When a kernel does not have mutated args (this is quite common?), benchmarking the cost of cloning actually benchmarks a no-op. This still takes >100ms since triton.testing.do_bench will allocate 100 ms budget to run the kernel. Skipping this benchmarking can save quite some compilation time if the code path is hit multiple times. Let's say, if the code path is hit 100 times when the graph is large, we would save >10s. Pull Request resolved: https://github.com/pytorch/pytorch/pull/135533 Approved by: https://github.com/jansel ghstack dependencies: #135531 --- torch/_inductor/codegen/triton.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/torch/_inductor/codegen/triton.py b/torch/_inductor/codegen/triton.py index ea68e68b6ef062..8fe368048c9c77 100644 --- a/torch/_inductor/codegen/triton.py +++ b/torch/_inductor/codegen/triton.py @@ -3132,9 +3132,10 @@ def store_cache(): # in the case of mutating/in-placeable second fusion # TODO - would be better as a hook in triton do_bench that reset # the input values between benchmarking - ms = ms - benchmarker.benchmark_gpu( - lambda: wrapped_jit_function.clone_args(*args) - ) + if len(wrapped_jit_function.mutated_arg_names) > 0: + ms = ms - benchmarker.benchmark_gpu( + lambda: wrapped_jit_function.clone_args(*args) + ) log.debug( "The fused kernel for %s took %.3f ms to run", From 6709e6dc4ebb21be62f3a75fbc1341e91f898d33 Mon Sep 17 00:00:00 2001 From: hongxyan Date: Tue, 10 Sep 2024 21:03:01 +0000 Subject: [PATCH 0688/1018] [ROCm] slow torch.sum optimization by increasing max_values_per_thread in reduce config (#135397) Fixes #132964 This change is to optimize torch.sum() performance by increasing max_values_per_thread in setReduceConfig() for ROCm platform. By increasing this parameter, it uses fewer threadblocks and improved the performance. Test: Tested on MI300x and H100, and now the MI300x perf improved to 3205GByte/s from ~1690GByte/s for the test case and is slightly better than H100 (3136GByte/s). Also tested with other different sizes of tensors and also see perf improvement. ```python import torch from triton.testing import do_bench x = torch.randn(2**30, device='cuda') ms = do_bench(lambda: x.sum(dim=-1)) bandwidth_gbyte = x.numel() * x.dtype.itemsize / (10**9) time_s = ms / 1000 bw_per_second = bandwidth_gbyte / time_s print(bw_per_second) ``` Co-author: @carlobertolli Pull Request resolved: https://github.com/pytorch/pytorch/pull/135397 Approved by: https://github.com/eqy, https://github.com/malfet --- aten/src/ATen/native/cuda/Reduce.cuh | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/aten/src/ATen/native/cuda/Reduce.cuh b/aten/src/ATen/native/cuda/Reduce.cuh index 79b23ea8afd4fc..7a25975b624b04 100644 --- a/aten/src/ATen/native/cuda/Reduce.cuh +++ b/aten/src/ATen/native/cuda/Reduce.cuh @@ -1092,7 +1092,11 @@ ReduceConfig setReduceConfig(const TensorIterator& iter){ } constexpr int min_values_per_thread = 16; +#ifndef USE_ROCM constexpr int max_values_per_thread = 256; +#else + constexpr int max_values_per_thread = 1024; +#endif if (config.values_per_thread() >= block_height * 16 || config.values_per_thread() >= max_values_per_thread) { // Divide the input across warps in a thread-block, if that leaves at least From a210aa2c461505915d737ea0a87d2307c8e9ae29 Mon Sep 17 00:00:00 2001 From: titaiwangms Date: Tue, 10 Sep 2024 23:03:59 +0000 Subject: [PATCH 0689/1018] [ONNX] Fix scaled_dot_product_attention with float scale (#135594) Fixes #125158 Pull Request resolved: https://github.com/pytorch/pytorch/pull/135594 Approved by: https://github.com/justinchuby --- test/onnx/test_pytorch_onnx_onnxruntime.py | 21 +++++++++++++++++++++ torch/onnx/symbolic_opset14.py | 1 - 2 files changed, 21 insertions(+), 1 deletion(-) diff --git a/test/onnx/test_pytorch_onnx_onnxruntime.py b/test/onnx/test_pytorch_onnx_onnxruntime.py index 091209a8a4f697..c812b8e18b30c3 100644 --- a/test/onnx/test_pytorch_onnx_onnxruntime.py +++ b/test/onnx/test_pytorch_onnx_onnxruntime.py @@ -12534,6 +12534,27 @@ def forward(self, x, y): self.run_test(M(), (x, y)) + @skipIfUnsupportedMinOpsetVersion(14) + def test_scaled_dot_product_attention(self): + class M(torch.nn.Module): + def forward(self, q, k, v): + return torch.nn.functional.scaled_dot_product_attention( + q, k, v, scale=1.0 + ) + + # Parameters + batch_size = 2 # Number of samples in the batch + num_heads = 4 # Number of attention heads + seq_length = 5 # Sequence length + head_dim = 8 # Dimensionality of each head + + # Create random query, key, and value tensors + q = torch.randn(batch_size, num_heads, seq_length, head_dim) + k = torch.randn(batch_size, num_heads, seq_length, head_dim) + v = torch.randn(batch_size, num_heads, seq_length, head_dim) + + self.run_test(M(), (q, k, v)) + @skipScriptTest() @skipIfUnsupportedMinOpsetVersion(11) def test_dist_normal(self): diff --git a/torch/onnx/symbolic_opset14.py b/torch/onnx/symbolic_opset14.py index 1b10cc28531a0d..ae33ddf58c6e09 100644 --- a/torch/onnx/symbolic_opset14.py +++ b/torch/onnx/symbolic_opset14.py @@ -150,7 +150,6 @@ def scaled_dot_product_attention( ), "is_causal and attn_mask cannot be set at the same time" assert not enable_gqa, "conversion of scaled_dot_product_attention not implemented if enable_gqa is True" - scale = symbolic_helper._maybe_get_const(scale, "f") if symbolic_helper._is_none(scale): scale = _attention_scale(g, query) From 9c2a43f1632c8f95a67a34941579b4fc4232b7ce Mon Sep 17 00:00:00 2001 From: atalman Date: Tue, 10 Sep 2024 23:46:36 +0000 Subject: [PATCH 0690/1018] Bump triton pin and release version (#135627) Update the pin and release version to sync with https://github.com/triton-lang/triton/tree/release/3.1.x Pull Request resolved: https://github.com/pytorch/pytorch/pull/135627 Approved by: https://github.com/Chillee, https://github.com/drisspg, https://github.com/malfet --- .ci/docker/ci_commit_pins/triton.txt | 2 +- .ci/docker/triton_version.txt | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/.ci/docker/ci_commit_pins/triton.txt b/.ci/docker/ci_commit_pins/triton.txt index 378d625784fbc5..7148e9c99cec22 100644 --- a/.ci/docker/ci_commit_pins/triton.txt +++ b/.ci/docker/ci_commit_pins/triton.txt @@ -1 +1 @@ -757b6a61e7df814ba806f498f8bb3160f84b120c +5fe38ffd73c2ac6ed6323b554205186696631c6f diff --git a/.ci/docker/triton_version.txt b/.ci/docker/triton_version.txt index 4a36342fcab700..fd2a01863fdd30 100644 --- a/.ci/docker/triton_version.txt +++ b/.ci/docker/triton_version.txt @@ -1 +1 @@ -3.0.0 +3.1.0 From eb500d05c56fe2761afadd1466d3e5bff70456b6 Mon Sep 17 00:00:00 2001 From: Animesh Jain Date: Tue, 10 Sep 2024 13:29:11 -0700 Subject: [PATCH 0691/1018] [dynamo][dicts][nv-embed] Support update with kwargs (#135588) Pull Request resolved: https://github.com/pytorch/pytorch/pull/135588 Approved by: https://github.com/yanboliang --- test/dynamo/test_functions.py | 6 ++++++ torch/_dynamo/variables/dicts.py | 37 +++++++++++++++++--------------- 2 files changed, 26 insertions(+), 17 deletions(-) diff --git a/test/dynamo/test_functions.py b/test/dynamo/test_functions.py index d215f21fbb654c..d6879488c4fd24 100644 --- a/test/dynamo/test_functions.py +++ b/test/dynamo/test_functions.py @@ -699,6 +699,12 @@ def test_dict_setdefault3(x): else: return x - 1 + @make_test + def test_dict_update_kwargs(x): + d = {"a": 2} + d.update(b=4) + return x * d["a"] * d["b"] + @make_test def test_defaultdict_setdefault1(x): d = collections.defaultdict.fromkeys("a", "b") diff --git a/torch/_dynamo/variables/dicts.py b/torch/_dynamo/variables/dicts.py index e3ca0b43c0a6a7..e8323ec6f70d76 100644 --- a/torch/_dynamo/variables/dicts.py +++ b/torch/_dynamo/variables/dicts.py @@ -304,10 +304,8 @@ def call_method( tx.output.side_effects.mutation(self) self.items.clear() return ConstantVariable.create(None) - elif ( - name == "update" - and len(args) == 1 - and isinstance( + elif name == "update" and self.mutable_local: + is_args_supported = len(args) == 1 and isinstance( args[0], ( ConstDictVariable, @@ -318,20 +316,25 @@ def call_method( UserDefinedObjectVariable, ), ) - and self.mutable_local - ): - tx.output.side_effects.mutation(self) - if isinstance(args[0], ConstDictVariable): - dict_vt = args[0] + + is_kwargs_supported = len(kwargs) > 0 and len(args) == 0 + + if is_args_supported or is_kwargs_supported: + tx.output.side_effects.mutation(self) + if len(args) == 1: + if isinstance(args[0], ConstDictVariable): + dict_vt = args[0] + else: + dict_vt = BuiltinVariable.call_custom_dict(tx, dict, args[0]) + self.items.update(dict_vt.items) + # Wrap strings + kwargs = { + Hashable(ConstantVariable.create(k)): v for k, v in kwargs.items() + } + self.items.update(kwargs) + return ConstantVariable.create(None) else: - dict_vt = BuiltinVariable.call_custom_dict(tx, dict, args[0]) - self.items.update(dict_vt.items) - # Wrap strings - kwargs = { - Hashable(ConstantVariable.create(k)): v for k, v in kwargs.items() - } - self.items.update(kwargs) - return ConstantVariable.create(None) + return super().call_method(tx, name, args, kwargs) elif name in ("get", "__getattr__") and args[0] in self: return self.getitem_const(tx, args[0]) elif name == "__contains__" and len(args) == 1: From be08b0da7277154b0ce321338392d0e8b5870b26 Mon Sep 17 00:00:00 2001 From: Haoming Lu Date: Wed, 11 Sep 2024 00:03:15 +0000 Subject: [PATCH 0692/1018] Add boolean support in pack segments ops for both cpu and cuda impls (#132897) (#135620) Summary: Same as int types, forward only. bypass-github-export-checks diff has been synced to github Test Plan: buck test mode/dev-nosan //caffe2/torch/fb/sparsenn:test -- test_pack_segments https://www.internalfb.com/intern/testinfra/testconsole/testrun/16888498646804437/ Reviewed By: garroud Differential Revision: D60785563 Pull Request resolved: https://github.com/pytorch/pytorch/pull/135620 Approved by: https://github.com/kit1980 Co-authored-by: Haoming Lu --- aten/src/ATen/Dispatch.h | 29 +++++++++++++++++++++++++++++ 1 file changed, 29 insertions(+) diff --git a/aten/src/ATen/Dispatch.h b/aten/src/ATen/Dispatch.h index b98d648c684546..db2eccf7954be0 100644 --- a/aten/src/ATen/Dispatch.h +++ b/aten/src/ATen/Dispatch.h @@ -299,6 +299,15 @@ inline void deprecated_AT_DISPATCH_ALL_TYPES_AND_HALF_AND_COMPLEX() {} AT_DISPATCH_CASE(SCALARTYPE3, __VA_ARGS__) \ AT_DISPATCH_CASE(SCALARTYPE4, __VA_ARGS__) +#define AT_DISPATCH_CASE_FLOATING_TYPES_AND5( \ + SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, SCALARTYPE4, SCALARTYPE5, ...) \ + AT_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__) \ + AT_DISPATCH_CASE(SCALARTYPE1, __VA_ARGS__) \ + AT_DISPATCH_CASE(SCALARTYPE2, __VA_ARGS__) \ + AT_DISPATCH_CASE(SCALARTYPE3, __VA_ARGS__) \ + AT_DISPATCH_CASE(SCALARTYPE4, __VA_ARGS__) \ + AT_DISPATCH_CASE(SCALARTYPE5, __VA_ARGS__) + #define AT_DISPATCH_FLOATING_TYPES_AND4( \ SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, SCALARTYPE4, TYPE, NAME, ...) \ AT_DISPATCH_SWITCH( \ @@ -307,6 +316,26 @@ inline void deprecated_AT_DISPATCH_ALL_TYPES_AND_HALF_AND_COMPLEX() {} AT_DISPATCH_CASE_FLOATING_TYPES_AND4( \ SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, SCALARTYPE4, __VA_ARGS__)) +#define AT_DISPATCH_FLOATING_TYPES_AND5( \ + SCALARTYPE1, \ + SCALARTYPE2, \ + SCALARTYPE3, \ + SCALARTYPE4, \ + SCALARTYPE5, \ + TYPE, \ + NAME, \ + ...) \ + AT_DISPATCH_SWITCH( \ + TYPE, \ + NAME, \ + AT_DISPATCH_CASE_FLOATING_TYPES_AND5( \ + SCALARTYPE1, \ + SCALARTYPE2, \ + SCALARTYPE3, \ + SCALARTYPE4, \ + SCALARTYPE5, \ + __VA_ARGS__)) + #define AT_DISPATCH_CASE_COMPLEX_TYPES(...) \ AT_DISPATCH_CASE(at::ScalarType::ComplexDouble, __VA_ARGS__) \ AT_DISPATCH_CASE(at::ScalarType::ComplexFloat, __VA_ARGS__) From 3bdfc4fd94532d3e5012888d178f165a0400a6aa Mon Sep 17 00:00:00 2001 From: rzou Date: Tue, 10 Sep 2024 07:43:48 -0700 Subject: [PATCH 0693/1018] Add option to tweak inductor stride settings for user-defined triton kernels (#135530) Previously, Inductor was allowed to modify the stride/storage_offset (layout) for inputs to user-defined triton kernels. This can cause silent incorrectness because most triton kernels are written for a specific striding pattern (usually contiguous). This PR adds a config to allow the user to choose Inductor's behavior on this. The options are: - "flexible_layout" (default): Inductor can modify the layout for inputs to user-defined triton kernels as much as it wants. - "needs_fixed_stride_order": Inductor must preserve the stride order (when compared to tracing) for inputs to user-defined triton kernels. This matches our handling for custom operators. In the future, we'll want a "needs_exact_strides" option (this is the safest option). Test Plan: - new test Pull Request resolved: https://github.com/pytorch/pytorch/pull/135530 Approved by: https://github.com/FindHao, https://github.com/oulgen --- test/inductor/test_triton_kernels.py | 60 +++++++++++++++++++ torch/_higher_order_ops/triton_kernel_wrap.py | 12 ++-- torch/_inductor/config.py | 4 ++ torch/_inductor/graph.py | 41 ++++++++++++- torch/_inductor/lowering.py | 2 + 5 files changed, 114 insertions(+), 5 deletions(-) diff --git a/test/inductor/test_triton_kernels.py b/test/inductor/test_triton_kernels.py index 64aaccdbf085a3..365f43cdd83bee 100644 --- a/test/inductor/test_triton_kernels.py +++ b/test/inductor/test_triton_kernels.py @@ -1015,6 +1015,66 @@ def f(inp): compiled_out = torch.compile(f)(inp) self.assertEqual(compiled_out, eager_out) + @torch._inductor.config.patch( + triton_kernel_default_layout_constraint="needs_fixed_stride_order" + ) + @requires_gpu + def test_layout_constraint_needs_fixed_stride_order(self): + # Construct a custom op whose output strides are (1, 2) + @torch.library.custom_op("mylib::weird_op_with_lowering", mutates_args={}) + def weird_op_with_lowering(x: torch.Tensor) -> torch.Tensor: + return torch.empty_strided((2, 2), (1, 2), dtype=x.dtype, device=x.device) + + @weird_op_with_lowering.register_fake + def _(x): + return torch.empty_strided((2, 2), (1, 2), dtype=x.dtype, device=x.device) + + # The lowering for the custom op produces output strides (2, 1). + from torch._inductor.lowering import empty_strided, register_lowering + + @register_lowering(torch.ops.mylib.weird_op_with_lowering) + def _(x): + return empty_strided( + x.shape, (2, 1), dtype=x.dtype, device=torch.device("cuda:0") + ) + + # Triton kernel that has different behavior depending on the input strides. + @triton.jit + def kernel( + in_ptr0, + out_ptr, + n_elements, + BLOCK_SIZE: "tl.constexpr", + ): + pid = tl.program_id(axis=0) + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_elements + output = offsets + tl.store(out_ptr + offsets, output, mask=mask) + + def arange_out(x, out): + n_elements = x.numel() + grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),) + kernel[grid](x, out, n_elements, BLOCK_SIZE=4) + + def f(x): + y = weird_op_with_lowering(x) + # Inductor lowering will decide that y is better having strides (2, 1). + # This is different from the strides at tracing time (1, 2). + # Under the "needs_fixed_stride_order" config, inductor will coerce + # y to have strides (1, 2) before passing it to arange_out. + # If it doesn't, then the result will be different from eager mode. + arange_out(x, y) + return x + y + + x = torch.randn(2, 2, device="cuda") + eager_out = f(x) + + compiled_inductor_f = torch.compile(f, backend="inductor", fullgraph=True) + compiled_inductor_out = compiled_inductor_f(x) + self.assertEqual(compiled_inductor_out, eager_out) + @requires_gpu def test_triton_kernel_strided_input_nonzero_offset(self): def f(inp): diff --git a/torch/_higher_order_ops/triton_kernel_wrap.py b/torch/_higher_order_ops/triton_kernel_wrap.py index 5a1ad4405c5ec3..a0548a1c8ee125 100644 --- a/torch/_higher_order_ops/triton_kernel_wrap.py +++ b/torch/_higher_order_ops/triton_kernel_wrap.py @@ -624,19 +624,23 @@ def triton_kernel_wrapper_mutation_proxy_torch_dispatch_mode( return None +def get_mutated_tensors(kernel_idx, constant_args_idx, kwargs): + kernel = kernel_side_table.get_kernel(kernel_idx) + constant_args = kernel_side_table.get_constant_args(constant_args_idx) + return identify_mutated_tensors(kernel, {**kwargs, **constant_args}) + + @triton_kernel_wrapper_mutation.py_functionalize_impl def triton_kernel_wrapper_mutation_functionalize( ctx, kernel_idx, constant_args_idx, grid, kwargs ): unwrapped_kwargs = ctx.unwrap_tensors(kwargs) - kernel = kernel_side_table.get_kernel(kernel_idx) - constant_args = kernel_side_table.get_constant_args(constant_args_idx) # TODO(oulgen): Preexisting bug, if two kernel inputs are views of each # other, and one gets mutated in kernel, and later another gets mutated, # they are no longer equal. Fix this by graph breaking on this condition # earlier in dynamo. - tensors_to_clone = identify_mutated_tensors( - kernel, {**unwrapped_kwargs, **constant_args} + tensors_to_clone = get_mutated_tensors( + kernel_idx, constant_args_idx, unwrapped_kwargs ) with ctx.redispatch_to_next(): unwrapped_outputs = triton_kernel_wrapper_functional( diff --git a/torch/_inductor/config.py b/torch/_inductor/config.py index 966a48c29d06e3..0ced92c3271bd1 100644 --- a/torch/_inductor/config.py +++ b/torch/_inductor/config.py @@ -72,6 +72,10 @@ def autotune_remote_cache_default() -> Optional[bool]: # then we assume the following applies. custom_op_default_layout_constraint = "needs_fixed_stride_order" +# The default layout constraint for user-defined triton kernels. +# See "The default layout constraint for custom operators" for options. +triton_kernel_default_layout_constraint = "flexible_layout" + # use cpp wrapper instead of python wrapper cpp_wrapper = os.environ.get("TORCHINDUCTOR_CPP_WRAPPER", "0") == "1" diff --git a/torch/_inductor/graph.py b/torch/_inductor/graph.py index 45ca55c7ab0111..9fdcf4a5e98da7 100644 --- a/torch/_inductor/graph.py +++ b/torch/_inductor/graph.py @@ -1238,10 +1238,30 @@ def propagate_mutation( If fx_node mutates any of new_args/new_kwargs, and they are different from old_args/old_kwargs, then we need to update the original tensor. """ - assert isinstance(fx_node.target, torch._ops.OpOverload) assert len(old_args) == len(new_args) assert len(old_kwargs) == len(new_kwargs) + if fx_node.target is torch.ops.higher_order.triton_kernel_wrapper_mutation: + kwargs = fx_node.kwargs["kwargs"] + assert isinstance(kwargs, dict) + mutated = torch._higher_order_ops.triton_kernel_wrap.get_mutated_tensors( + old_kwargs["kernel_idx"], + old_kwargs["constant_args_idx"], + { + k: v.meta["val"] if isinstance(v, torch.fx.Node) else v + for k, v in kwargs.items() + }, + ) + for name in mutated: + old_arg = old_kwargs["kwargs"][name] + new_arg = new_kwargs["kwargs"][name] + if old_arg is new_args: + continue + self.call_function(torch.ops.aten.copy_.default, (old_arg, new_arg), {}) + return + + assert isinstance(fx_node.target, torch._ops.OpOverload) + def maybe_propagate( schema_arg: torch._C.Argument, old_arg: ir.IRNode, new_arg: ir.IRNode ) -> None: @@ -1303,6 +1323,25 @@ def debug(msg: str) -> None: # if they do, and if the target is mutable, then we need to # write the new values back into the original inputs. self.propagate_mutation(n, old_args, old_kwargs, args, kwargs) # type: ignore[possibly-undefined] + elif ( + n.op == "call_function" + and n.target is torch.ops.higher_order.triton_kernel_wrapper_mutation + and config.triton_kernel_default_layout_constraint != "flexible_layout" + ): + debug("user_defined_triton_kernel_layout_constraints") + if ( + config.triton_kernel_default_layout_constraint + == "needs_fixed_stride_order" + ): + old_args = args # type: ignore[possibly-undefined] + old_kwargs = kwargs # type: ignore[possibly-undefined] + args, kwargs = torch._inductor.lowering.constrain_to_fx_strides(n, *args, **kwargs) # type: ignore[index] + result = self.call_function(n.target, args, kwargs) # type: ignore[arg-type] + self.propagate_mutation(n, old_args, old_kwargs, args, kwargs) # type: ignore[possibly-undefined] + else: + raise RuntimeError( + f"Unknown triton_kernel_default_layout_constraint: {config.triton_kernel_default_layout_constraint}" + ) elif is_magic_method(n.target): # TODO: this is sus, it probably should be handled in the # lowerings themselves similarly to sym_size/sym-stride diff --git a/torch/_inductor/lowering.py b/torch/_inductor/lowering.py index 81be111f1e63f4..988453a86735de 100644 --- a/torch/_inductor/lowering.py +++ b/torch/_inductor/lowering.py @@ -2114,6 +2114,8 @@ def apply_constraint(arg, fx_arg): if isinstance(arg, ir.IRNode): stride_order = ir.get_stride_order(fx_arg.meta["val"].stride()) return ir.ExternKernel.require_stride_order(arg, stride_order) + if isinstance(arg, dict): + return {key: apply_constraint(arg[key], fx_arg[key]) for key in arg.keys()} return arg args = tuple( From 2143664d0818c498c98e74ffc2a05c63cd5c09c1 Mon Sep 17 00:00:00 2001 From: titaiwangms Date: Wed, 11 Sep 2024 00:18:17 +0000 Subject: [PATCH 0694/1018] [ONNX] Add assertion nodes to ignoring list (#135591) Fixes #135419 PS: there are 104 empty output nodes, I suggest we add them one by one when we run into them. Pull Request resolved: https://github.com/pytorch/pytorch/pull/135591 Approved by: https://github.com/justinchuby --- test/onnx/exporter/test_api.py | 9 +++++++++ torch/onnx/_internal/exporter/_fx_passes.py | 3 +++ 2 files changed, 12 insertions(+) diff --git a/test/onnx/exporter/test_api.py b/test/onnx/exporter/test_api.py index 18330755510454..4a0ead2ed829b4 100644 --- a/test/onnx/exporter/test_api.py +++ b/test/onnx/exporter/test_api.py @@ -196,6 +196,15 @@ def forward(self, x, y): TestRefineDynamicShapeModel(), inps, dynamic_shapes=dynamic_shapes ) + def test_zero_output_aten_node(self): + class Model(torch.nn.Module): + def forward(self, x): + torch.ops.aten._assert_async.msg(torch.tensor(True), "assertion failed") + return x + x + + input = torch.randn(2) + self.assert_export(Model(), (input)) + if __name__ == "__main__": common_utils.run_tests() diff --git a/torch/onnx/_internal/exporter/_fx_passes.py b/torch/onnx/_internal/exporter/_fx_passes.py index bcab0d595b433e..4c19b552d0e4f0 100644 --- a/torch/onnx/_internal/exporter/_fx_passes.py +++ b/torch/onnx/_internal/exporter/_fx_passes.py @@ -39,7 +39,10 @@ def remove_assertion_nodes(graph_module: torch.fx.GraphModule) -> torch.fx.Graph """Remove all assertion and check nodes from the FX graph""" aten_assertion_targets = { torch.ops.aten.sym_constrain_range_for_size.default, + torch.ops.aten._assert_async.default, torch.ops.aten._assert_async.msg, + torch.ops.aten._assert_scalar.default, + torch.ops.aten._assert_tensor_metadata.default, } for node in graph_module.graph.nodes: if node.op == "call_function" and node.target in aten_assertion_targets: From a6fa5d1a36f2572a3beead22272bbec5e7b3a4a7 Mon Sep 17 00:00:00 2001 From: Nikita Shulga <2453524+malfet@users.noreply.github.com> Date: Wed, 11 Sep 2024 00:20:53 +0000 Subject: [PATCH 0695/1018] [MPS] Add missing dispatch to rshift.Tensor (#135607) Missed it while working on https://github.com/pytorch/pytorch/pull/131813 Test plan: `python -c "import torch;print(torch.randint(100, 500, (64,), device='mps') >> torch.tensor([3,], device='mps'))"` Pull Request resolved: https://github.com/pytorch/pytorch/pull/135607 Approved by: https://github.com/manuelcandales --- aten/src/ATen/native/native_functions.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index 9342da626beb03..c7533e4ef854c9 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -8554,7 +8554,7 @@ device_check: NoCheck # TensorIterator variants: method, function dispatch: - CPU, CUDA: __rshift__ + CPU, CUDA, MPS: __rshift__ tags: pointwise - func: __irshift__.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!) From 8bc0427ad3d14f7f0cf40da9cac90195606fb0a8 Mon Sep 17 00:00:00 2001 From: Animesh Jain Date: Wed, 11 Sep 2024 00:43:24 +0000 Subject: [PATCH 0696/1018] =?UTF-8?q?[dynamo]=20Missing=20guard=20source?= =?UTF-8?q?=20keys=20for=20corner=20case=20of=20NNModuleVariabl=E2=80=A6?= =?UTF-8?q?=20(#135041)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Potentially fixes - https://fb.workplace.com/groups/1286739428954016/permalink/1319662695661689/ Pull Request resolved: https://github.com/pytorch/pytorch/pull/135041 Approved by: https://github.com/ezyang --- test/distributed/test_dynamo_distributed.py | 47 +++++++++++++++++++++ torch/_dynamo/source.py | 6 +++ 2 files changed, 53 insertions(+) diff --git a/test/distributed/test_dynamo_distributed.py b/test/distributed/test_dynamo_distributed.py index c689b8b321f6a1..d6357678c94f97 100644 --- a/test/distributed/test_dynamo_distributed.py +++ b/test/distributed/test_dynamo_distributed.py @@ -127,6 +127,24 @@ def get_mutating_model( return m, inputs, outputs +class ForcedGetAttrMod(torch.nn.Module): + def __init__(self, device): + super().__init__() + self.linear = torch.nn.Linear(1, 1) + self.__dict__["forced_linear"] = torch.nn.Linear(1, 1).to(device=device) + self.counter = 0 + + def forward(self, x): + self.counter += 1 + return x * self.linear(x) * self.forced_linear.weight + + +def get_forced_getattr_module(device): + mod = ForcedGetAttrMod(device).to(device=device) + x = torch.randn(1, 1, device=device) + return mod, x, mod(x) + + class ToyInnerModel(nn.Module): def __init__(self) -> None: super().__init__() @@ -627,6 +645,35 @@ def test_fsdp_setattr(self): first_graph_break = list(counters["graph_break"].keys())[0] # noqa: RUF015 self.assertTrue("setattr" not in first_graph_break) + @config.patch(inline_inbuilt_nn_modules=False) + @config.patch(enable_compiler_collectives=True) + @skip_if_lt_x_gpu(1) + def test_fsdp_unspecialized_forced_getattr_no_inline(self): + with _dynamo_dist_per_rank_init(self.rank, self.world_size): + # Test with basic FSDP wrapping (outer wrap around whole model) + from torch._dynamo.utils import counters + + counters.clear() + m, inputs, correct_outputs = get_forced_getattr_module(f"cuda:{self.rank}") + fsdp_m = FSDP(m, use_orig_params=True) + fsdp_m = torch.compile(fsdp_m, backend="eager", fullgraph=False) + outputs = fsdp_m(inputs) + self.assertTrue(same(correct_outputs, outputs)) + + @config.patch(enable_compiler_collectives=True) + @skip_if_lt_x_gpu(1) + def test_fsdp_unspecialized_forced_getattr_inline(self): + with _dynamo_dist_per_rank_init(self.rank, self.world_size): + # Test with basic FSDP wrapping (outer wrap around whole model) + from torch._dynamo.utils import counters + + counters.clear() + m, inputs, correct_outputs = get_forced_getattr_module(f"cuda:{self.rank}") + fsdp_m = FSDP(m, use_orig_params=True) + fsdp_m = torch.compile(fsdp_m, backend="eager", fullgraph=False) + outputs = fsdp_m(inputs) + self.assertTrue(same(correct_outputs, outputs)) + @config.patch(enable_compiler_collectives=True) @skip_if_lt_x_gpu(1) @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") diff --git a/torch/_dynamo/source.py b/torch/_dynamo/source.py index 860e7d3429f487..6ed9d97f6a1210 100644 --- a/torch/_dynamo/source.py +++ b/torch/_dynamo/source.py @@ -25,6 +25,8 @@ GuardSource.GLOBAL_UNSPECIALIZED_NN_MODULE: GuardSource.GLOBAL_UNSPECIALIZED_NN_MODULE, GuardSource.LOCAL_UNSPECIALIZED_BUILTIN_NN_MODULE: GuardSource.LOCAL_UNSPECIALIZED_BUILTIN_NN_MODULE, GuardSource.GLOBAL_UNSPECIALIZED_BUILTIN_NN_MODULE: GuardSource.GLOBAL_UNSPECIALIZED_BUILTIN_NN_MODULE, + GuardSource.LOCAL_FSDP_MODULE: GuardSource.LOCAL_FSDP_MODULE, + GuardSource.GLOBAL_FSDP_MODULE: GuardSource.GLOBAL_FSDP_MODULE, } # represents nn.Modules tracked with UnspecializedNNModuleVariable @@ -39,6 +41,8 @@ # Just to ensure that guard_source() works GuardSource.LOCAL_UNSPECIALIZED_BUILTIN_NN_MODULE: GuardSource.LOCAL_UNSPECIALIZED_BUILTIN_NN_MODULE, GuardSource.GLOBAL_UNSPECIALIZED_BUILTIN_NN_MODULE: GuardSource.GLOBAL_UNSPECIALIZED_BUILTIN_NN_MODULE, + GuardSource.LOCAL_FSDP_MODULE: GuardSource.LOCAL_FSDP_MODULE, + GuardSource.GLOBAL_FSDP_MODULE: GuardSource.GLOBAL_FSDP_MODULE, } # represents nn.Modules tracked with UnspecializedBuiltinNNModuleVariable @@ -52,6 +56,8 @@ # Just to ensure that guard_source() works GuardSource.LOCAL_UNSPECIALIZED_BUILTIN_NN_MODULE: GuardSource.LOCAL_UNSPECIALIZED_BUILTIN_NN_MODULE, GuardSource.GLOBAL_UNSPECIALIZED_BUILTIN_NN_MODULE: GuardSource.GLOBAL_UNSPECIALIZED_BUILTIN_NN_MODULE, + GuardSource.LOCAL_FSDP_MODULE: GuardSource.LOCAL_FSDP_MODULE, + GuardSource.GLOBAL_FSDP_MODULE: GuardSource.GLOBAL_FSDP_MODULE, } _GUARD_SOURCE_FSDP_MODULE = { From 4f07269a91f99c9d255f24c132b5f3509529e214 Mon Sep 17 00:00:00 2001 From: chuanqiw Date: Wed, 11 Sep 2024 00:56:15 +0000 Subject: [PATCH 0697/1018] Bump triton xpu pin and release version (#135638) Similar with https://github.com/pytorch/pytorch/pull/135627 Pull Request resolved: https://github.com/pytorch/pytorch/pull/135638 Approved by: https://github.com/atalman --- .ci/docker/ci_commit_pins/triton-xpu.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.ci/docker/ci_commit_pins/triton-xpu.txt b/.ci/docker/ci_commit_pins/triton-xpu.txt index 4b77bbfc0f8fa6..e25f89a42f2e6f 100644 --- a/.ci/docker/ci_commit_pins/triton-xpu.txt +++ b/.ci/docker/ci_commit_pins/triton-xpu.txt @@ -1 +1 @@ -cc981feba10a3f4c2e46f3fe368e8fcf5f5643df +91b14bf5593cf58a8541f3e6b9125600a867d4ef From fab6d9131b95137cc46b2798c716effed375a292 Mon Sep 17 00:00:00 2001 From: angelayi Date: Tue, 10 Sep 2024 16:08:07 -0700 Subject: [PATCH 0698/1018] [aoti] Add cpp loader (#135374) * Added a cpp loader, AOTIModelPackageLoader, which can load the .pt2, build the .so, and create a runner. The python-facing API is that users can directly call the `run` function, whereas in cpp users can directly access the `runner_` if they are more familiar with that. I couldn't figure out how to bind the `get_runner()` function to python... * Added a new config, `aot_inductor.package_cpp_only` which will **not** package the so. This means that whenever the package is loaded, we will need to build the so. This is turned off by default so that new environments do not need to rebuild their so. The `package_cpp_only` is a feature which torchchat intends to use to provide flexibility to users. * Added a new config, `aot_inductor.metadata` which stores user-provided metadata, serialized to the pt2 as a json file. It also stores the device used when exporting, "cuda" or "cpu", so that during load time, we can use that data to determine which AOTIModelContainerRunner to use. The metadata can be accessed through `loader.get_metadata()`. TODO is to move this metadata to the toplevel `package_aoti` function so that we can remove the metadata as a config. * Separated out `package_aoti` as a standalone function, instead of it automatically being called in inductor. This is to prepare for the case where users will compile multiple models, and want to bundle it in one package. The specific use case is in torchchat, where we want to package the separately-exported encoder and decoder layers. An example of how to use this is in `test_multiple_methods`. * `load_package` will load a singular model, given the model name. * The loader doesn't support windows for now, I think I need to add some more casing to make the build commands work on windows? Differential Revision: [D62329906](https://our.internmc.facebook.com/intern/diff/D62329906) Pull Request resolved: https://github.com/pytorch/pytorch/pull/135374 Approved by: https://github.com/desertfire, https://github.com/malfet --- .ci/docker/requirements-ci.txt | 5 + .ci/pytorch/win-test.sh | 3 + .../requirements/pip-requirements-macOS.txt | 1 + build_variables.bzl | 2 + setup.py | 1 + test/inductor/test_aot_inductor_package.py | 217 +++++---- torch/_C/_aoti.pyi | 3 + torch/_inductor/__init__.py | 42 ++ torch/_inductor/codecache.py | 166 +++---- torch/_inductor/compile_fx.py | 1 + torch/_inductor/config.py | 7 + torch/_inductor/package/__init__.py | 2 +- torch/_inductor/package/package.py | 140 +++--- torch/csrc/Module.cpp | 2 + .../aoti_package/model_package_loader.cpp | 411 ++++++++++++++++++ .../aoti_package/model_package_loader.h | 40 ++ torch/csrc/inductor/aoti_package/pybind.cpp | 24 + torch/csrc/inductor/aoti_package/pybind.h | 7 + .../aoti_runner/model_container_runner.h | 2 +- .../model_container_runner_cpu.cpp | 16 + .../model_container_runner_cuda.cpp | 13 + .../csrc/inductor/aoti_torch/shim_common.cpp | 39 +- 22 files changed, 894 insertions(+), 250 deletions(-) create mode 100644 torch/csrc/inductor/aoti_package/model_package_loader.cpp create mode 100644 torch/csrc/inductor/aoti_package/model_package_loader.h create mode 100644 torch/csrc/inductor/aoti_package/pybind.cpp create mode 100644 torch/csrc/inductor/aoti_package/pybind.h diff --git a/.ci/docker/requirements-ci.txt b/.ci/docker/requirements-ci.txt index 33166eb1518723..88ecc0d1a7edf3 100644 --- a/.ci/docker/requirements-ci.txt +++ b/.ci/docker/requirements-ci.txt @@ -337,3 +337,8 @@ onnxscript==0.1.0.dev20240817 #Description: Required by mypy and test_public_bindings.py when checking torch.onnx._internal #Pinned versions: #test that import: + +parameterized==0.8.1 +#Description: Parameterizes unittests, both the tests themselves and the entire testing class +#Pinned versions: +#test that import: diff --git a/.ci/pytorch/win-test.sh b/.ci/pytorch/win-test.sh index 69f0c88aa30744..09b624183c7ae7 100755 --- a/.ci/pytorch/win-test.sh +++ b/.ci/pytorch/win-test.sh @@ -43,6 +43,9 @@ python -m pip install z3-solver==4.12.2.0 # Install tlparse for test\dynamo\test_structured_trace.py UTs. python -m pip install tlparse==0.3.25 +# Install parameterized +python -m pip install parameterized==0.8.1 + run_tests() { # Run nvidia-smi if available for path in '/c/Program Files/NVIDIA Corporation/NVSMI/nvidia-smi.exe' /c/Windows/System32/nvidia-smi.exe; do diff --git a/.github/requirements/pip-requirements-macOS.txt b/.github/requirements/pip-requirements-macOS.txt index caa831b882312c..c72d1b568ca112 100644 --- a/.github/requirements/pip-requirements-macOS.txt +++ b/.github/requirements/pip-requirements-macOS.txt @@ -31,3 +31,4 @@ optree==0.12.1 # NB: test_hparams_* from test_tensorboard is failing with protobuf 5.26.0 in # which the stringify metadata is wrong when escaping double quote protobuf==3.20.2 +parameterized==0.8.1 diff --git a/build_variables.bzl b/build_variables.bzl index 11f5a5c9108e35..d11bba1ae1f37a 100644 --- a/build_variables.bzl +++ b/build_variables.bzl @@ -466,6 +466,7 @@ lazy_tensor_core_python_sources = [ ] inductor_core_resources = [ + "torch/csrc/inductor/aoti_package/model_package_loader.cpp", "torch/csrc/inductor/aoti_runner/model_container_runner.cpp", "torch/csrc/inductor/aoti_runner/model_container_runner_cpu.cpp", "torch/csrc/inductor/aoti_torch/shim_common.cpp", @@ -845,6 +846,7 @@ libtorch_python_core_sources = [ "torch/csrc/fx/node.cpp", "torch/csrc/mps/Module.cpp", "torch/csrc/mtia/Module.cpp", + "torch/csrc/inductor/aoti_package/pybind.cpp", "torch/csrc/inductor/aoti_runner/pybind.cpp", "torch/csrc/inductor/aoti_eager/kernel_holder.cpp", "torch/csrc/inductor/aoti_eager/kernel_meta_info.cpp", diff --git a/setup.py b/setup.py index 89097e9fdba3e4..b2cb4cd5926172 100644 --- a/setup.py +++ b/setup.py @@ -1328,6 +1328,7 @@ def main(): "include/torch/csrc/distributed/autograd/rpc_messages/*.h", "include/torch/csrc/dynamo/*.h", "include/torch/csrc/inductor/*.h", + "include/torch/csrc/inductor/aoti_package/*.h", "include/torch/csrc/inductor/aoti_runner/*.h", "include/torch/csrc/inductor/aoti_runtime/*.h", "include/torch/csrc/inductor/aoti_torch/*.h", diff --git a/test/inductor/test_aot_inductor_package.py b/test/inductor/test_aot_inductor_package.py index 0e20045cdbfc92..490b0e0324733c 100644 --- a/test/inductor/test_aot_inductor_package.py +++ b/test/inductor/test_aot_inductor_package.py @@ -1,78 +1,95 @@ # Owner(s): ["module: inductor"] import copy import sys +import tempfile import unittest +from parameterized import parameterized_class + import torch -from torch._inductor import config -from torch._inductor.package import load_package +from torch._inductor.package import AOTICompiledModel, load_package, package_aoti from torch._inductor.test_case import TestCase -from torch.testing._internal import common_utils +from torch.export import Dim from torch.testing._internal.common_utils import IS_FBCODE from torch.testing._internal.triton_utils import HAS_CUDA -try: - try: - from .test_torchinductor import copy_tests - except ImportError: - from test_torchinductor import copy_tests -except (unittest.SkipTest, ImportError) as e: - if __name__ == "__main__": - sys.exit(0) - raise - - -def compile(model, example_inputs, dynamic_shapes, options, device): +def compile( + model, + args, + kwargs=None, + *, + dynamic_shapes=None, + package_path=None, + inductor_configs=None, +) -> AOTICompiledModel: ep = torch.export.export( model, - example_inputs, + args, + kwargs, dynamic_shapes=dynamic_shapes, strict=False, ) - gm = ep.module() - package_path = torch._inductor.aot_compile(gm, example_inputs, options=options) # type: ignore[arg-type] - compiled_model = load_package(package_path, device) - return compiled_model - - -def check_model( - self: TestCase, - model, - example_inputs, - options=None, - dynamic_shapes=None, - disable_constraint_solver=False, - atol=None, - rtol=None, -): - with torch.no_grad(), config.patch( - { - "aot_inductor.package": True, - # TODO: "aot_inductor.force_mmap_weights": True, - } - ): - torch.manual_seed(0) - model = model.to(self.device) - ref_model = copy.deepcopy(model) - ref_inputs = copy.deepcopy(example_inputs) - expected = ref_model(*ref_inputs) - - torch.manual_seed(0) - compiled_model = compile( - model, - example_inputs, - dynamic_shapes, - options, - self.device, - ) - - actual = compiled_model(*example_inputs) + package_path = torch._inductor.aoti_compile_and_package( + ep, args, kwargs, package_path=package_path, inductor_configs=inductor_configs + ) # type: ignore[arg-type] + loaded = load_package(package_path) + return loaded - self.assertEqual(actual, expected, atol=atol, rtol=rtol) +@unittest.skipIf(sys.platform == "darwin", "No CUDA on MacOS") +@unittest.skipIf(IS_FBCODE, "This is for OSS only") +@parameterized_class( + [ + {"device": "cpu", "package_cpp_only": False}, + {"device": "cpu", "package_cpp_only": True}, + ] + + ( + [ + {"device": "cuda", "package_cpp_only": False}, + {"device": "cuda", "package_cpp_only": True}, + ] + if sys.platform != "darwin" + else [] + ), + class_name_func=lambda cls, _, params: f"{cls.__name__}{'Cpp' if params['package_cpp_only'] else ''}_{params['device']}", +) +class TestAOTInductorPackage(TestCase): + def check_model( + self: TestCase, + model, + example_inputs, + inductor_configs=None, + dynamic_shapes=None, + disable_constraint_solver=False, + atol=None, + rtol=None, + ) -> AOTICompiledModel: + with torch.no_grad(): + torch.manual_seed(0) + model = model.to(self.device) + ref_model = copy.deepcopy(model) + ref_inputs = copy.deepcopy(example_inputs) + expected = ref_model(*ref_inputs) + + inductor_configs = inductor_configs or {} + inductor_configs["aot_inductor.package_cpp_only"] = self.package_cpp_only + + torch.manual_seed(0) + with tempfile.NamedTemporaryFile(suffix=".pt2") as f: + compiled_model = compile( + model, + example_inputs, + dynamic_shapes=dynamic_shapes, + inductor_configs=inductor_configs, + package_path=f.name, + ) + + actual = compiled_model(*example_inputs) + + self.assertEqual(actual, expected, atol=atol, rtol=rtol) + return compiled_model -class AOTInductorTestsTemplate: def test_add(self): class Model(torch.nn.Module): def forward(self, x, y): @@ -99,34 +116,84 @@ def forward(self, x, y): ) self.check_model(Model(), example_inputs) + def test_metadata(self): + class Model(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.linear = torch.nn.Linear(10, 10) + + def forward(self, x, y): + return x + self.linear(y) -common_utils.instantiate_parametrized_tests(AOTInductorTestsTemplate) + example_inputs = ( + torch.randn(10, 10, device=self.device), + torch.randn(10, 10, device=self.device), + ) + metadata = {"dummy": "moo"} + compiled_model = self.check_model( + Model(), + example_inputs, + inductor_configs={"aot_inductor.metadata": metadata}, + ) + loaded_metadata = compiled_model.get_metadata() # type: ignore[attr-defined] -@unittest.skipIf(sys.platform == "darwin" or IS_FBCODE, "No CUDA on MacOS") -class AOTInductorTestPackagedABICompatibleCuda(TestCase): - device = "cuda" - check_model = check_model + self.assertEqual(loaded_metadata.get("dummy"), "moo") + def test_multiple_methods(self): + options = { + "aot_inductor.package": True, + "aot_inductor.package_cpp_only": self.package_cpp_only, + } -copy_tests( - AOTInductorTestsTemplate, - AOTInductorTestPackagedABICompatibleCuda, - "packaged_abi_compatible_cuda", -) + class Model1(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + def forward(self, a, b): + return torch.cat([a, b], dim=0) -@unittest.skipIf(IS_FBCODE, "This is for OSS only") -class AOTInductorTestPackagedABICompatibleCpu(TestCase): - device = "cpu" - check_model = check_model + b = torch.randn(3, 4, device=self.device) + dim0_a = Dim("dim0_a", min=1, max=10) + dim0_b = Dim("dim0_b", min=1, max=20) + dynamic_shapes = {"a": {0: dim0_a}, "b": {0: dim0_b}} + example_inputs1 = ( + torch.randn(2, 4, device=self.device), + torch.randn(3, 4, device=self.device), + ) + ep1 = torch.export.export( + Model1(), example_inputs1, dynamic_shapes=dynamic_shapes + ) + aoti_files1 = torch._inductor.aot_compile( + ep1.module(), example_inputs1, options=options + ) + + class Model2(torch.nn.Module): + def __init__(self, device): + super().__init__() + self.device = device + def forward(self, x): + t = torch.tensor(x.size(-1), device=self.device, dtype=torch.float) + t = torch.sqrt(t * 3) + return x * t + + example_inputs2 = (torch.randn(5, 5, device=self.device),) + ep2 = torch.export.export(Model2(self.device), example_inputs2) + aoti_files2 = torch._inductor.aot_compile( + ep2.module(), example_inputs2, options=options + ) + + with tempfile.NamedTemporaryFile(suffix=".pt2") as f: + package_path = package_aoti( + f.name, {"model1": aoti_files1, "model2": aoti_files2} + ) + loaded1 = load_package(package_path, "model1") + loaded2 = load_package(package_path, "model2") + + self.assertEqual(loaded1(*example_inputs1), ep1.module()(*example_inputs1)) + self.assertEqual(loaded2(*example_inputs2), ep2.module()(*example_inputs2)) -copy_tests( - AOTInductorTestsTemplate, - AOTInductorTestPackagedABICompatibleCpu, - "packaged_abi_compatible_cpu", -) if __name__ == "__main__": from torch._inductor.test_case import run_tests diff --git a/torch/_C/_aoti.pyi b/torch/_C/_aoti.pyi index a5e782fe62123e..4e9f5e7c8671b1 100644 --- a/torch/_C/_aoti.pyi +++ b/torch/_C/_aoti.pyi @@ -18,3 +18,6 @@ def alloc_tensor_by_stealing_from_void_ptr( class AOTIModelContainerRunnerCpu: ... class AOTIModelContainerRunnerCuda: ... + +# Defined in torch/csrc/inductor/aoti_package/pybind.cpp +class AOTIModelPackageLoader: ... diff --git a/torch/_inductor/__init__.py b/torch/_inductor/__init__.py index f95e7caaf71e95..404869debf1766 100644 --- a/torch/_inductor/__init__.py +++ b/torch/_inductor/__init__.py @@ -30,6 +30,48 @@ def compile( return compile_fx(gm, example_inputs, config_patches=options) +def aoti_compile_and_package( + exported_program, + args: Tuple[Any], + kwargs: Optional[Dict[str, Any]] = None, + *, + package_path: Optional[str] = None, + inductor_configs: Optional[Dict[str, Any]] = None, +) -> str: + """ + Compiles the exported program with AOTInductor, and packages it into a .pt2 + file specified by the input package_path. + """ + from torch._inductor.package import package_aoti + from torch.export import ExportedProgram + + if not isinstance(exported_program, ExportedProgram): + raise ValueError("Only ExportedProgram is supported") + + assert package_path is None or package_path.endswith(".pt2") + + inductor_configs = inductor_configs or {} + + if inductor_configs.get("aot_inductor.output_path"): + raise RuntimeError( + "Please pass in a package path to aot_inductor_compile() instead " + "of setting the aot_inductor.output_path config." + ) + inductor_configs["aot_inductor.package"] = True + + m = exported_program.module() + assert isinstance(m, torch.fx.GraphModule) + + aoti_files = aot_compile(m, args, kwargs, options=inductor_configs) # type: ignore[arg-type] + + if package_path is None: + package_path = aoti_files + ".pt2" + + res = package_aoti(package_path, aoti_files) + assert res == package_path + return package_path + + def aot_compile( gm: torch.fx.GraphModule, args: Tuple[Any], diff --git a/torch/_inductor/codecache.py b/torch/_inductor/codecache.py index 6d77ea959e5f36..b7f3bbf1e9dc99 100644 --- a/torch/_inductor/codecache.py +++ b/torch/_inductor/codecache.py @@ -1719,10 +1719,23 @@ def _compile_consts_darwin(consts: bytes) -> str: # Currently, this only support serializing extern nodes in fbcode # Eventually, we should also have a serializer for OSS. if serialized_extern_kernel_nodes: - output_json = os.path.splitext(input_path)[0] + ".json" - with open(output_json, "w") as f: + extern_kernel_nodes_json = os.path.splitext(input_path)[0] + ".json" + with open(extern_kernel_nodes_json, "w") as f: f.write(serialized_extern_kernel_nodes) + metadata = config.aot_inductor.metadata + metadata["AOTI_DEVICE_KEY"] = device_type + + # Save user provided metadata + meta_json = os.path.splitext(input_path)[0] + "_metadata.json" + for k, v in config.aot_inductor.metadata.items(): + assert isinstance(k, str) and isinstance( + v, (str) + ), "Metadata must only contain strings" + + with open(meta_json, "w") as f: + f.write(json.dumps(config.aot_inductor.metadata)) + output_so = ( config.aot_inductor.output_path if specified_so_name @@ -1755,54 +1768,29 @@ def get_nbytes_of_tensor(tensor: torch.Tensor, all_cuda: bool) -> int: if config.aot_inductor.force_mmap_weights: use_mmap_weights = True - if config.aot_inductor.package: - ( - object_output_name, - object_output_dir, - ) = get_name_and_dir_from_output_file_path(input_path) - object_build_options = CppTorchDeviceOptions( - vec_isa=picked_vec_isa, - device_type=device_type, - aot_mode=graph.aot_mode, - compile_only=True, - use_absolute_path=use_absolute_path, - use_mmap_weights=use_mmap_weights, - ) - object_builder = CppBuilder( - name=object_output_name, - sources=input_path, - output_dir=object_output_dir, - BuildOption=object_build_options, - ) - compile_cmd = object_builder.get_command_line() - output_o = object_builder.get_target_file_path() - - compile_flags = os.path.splitext(input_path)[0] + "_compile_flags.json" - object_build_options.save_flags_to_file(compile_flags) - - else: - ( - object_output_name, - object_output_dir, - ) = get_name_and_dir_from_output_file_path(input_path) - object_build_options = CppTorchDeviceOptions( - vec_isa=picked_vec_isa, - device_type=device_type, - aot_mode=graph.aot_mode, - compile_only=True, - use_absolute_path=use_absolute_path, - use_mmap_weights=use_mmap_weights, - ) - object_builder = CppBuilder( - name=object_output_name, - sources=input_path, - output_dir=object_output_dir, - BuildOption=object_build_options, - ) - compile_cmd = object_builder.get_command_line() - output_o = object_builder.get_target_file_path() + ( + object_output_name, + object_output_dir, + ) = get_name_and_dir_from_output_file_path(input_path) + object_build_options = CppTorchDeviceOptions( + vec_isa=picked_vec_isa, + device_type=device_type, + aot_mode=graph.aot_mode, + compile_only=True, + use_absolute_path=use_absolute_path, + use_mmap_weights=use_mmap_weights, + ) + object_builder = CppBuilder( + name=object_output_name, + sources=input_path, + output_dir=object_output_dir, + BuildOption=object_build_options, + ) + compile_cmd = object_builder.get_command_line() + output_o = object_builder.get_target_file_path() - log.debug("aot compilation command: %s", compile_cmd) + log.debug("aot compilation command: %s", compile_cmd) + if not config.aot_inductor.package_cpp_only: if fbcode_aot_cpu_re: output_o = os.path.splitext(input_path)[0] + ".o" compile_file(input_path, output_o, compile_cmd.split()) @@ -1810,6 +1798,10 @@ def get_nbytes_of_tensor(tensor: torch.Tensor, all_cuda: bool) -> int: else: run_command_and_check(compile_cmd) + if config.aot_inductor.package: + compile_flags = os.path.splitext(input_path)[0] + "_compile_flags.json" + object_build_options.save_flags_to_file(compile_flags) + def _to_bytes(t: torch.Tensor, all_cuda: bool) -> bytes: def _pad_to_alignment(raw_bytes: bytes) -> bytes: padded_bytes = raw_bytes.ljust( @@ -1859,29 +1851,37 @@ def _pad_to_alignment(raw_bytes: bytes) -> bytes: "darwin": _compile_consts_darwin, }[sys.platform](aot_constants) - if config.aot_inductor.package: - output_name, output_dir = get_name_and_dir_from_output_file_path( - output_so - ) - so_build_options = CppTorchDeviceOptions( - vec_isa=picked_vec_isa, - device_type=device_type, - aot_mode=graph.aot_mode, - use_absolute_path=use_absolute_path, - ) - so_builder = CppBuilder( - name=output_name, - sources=[output_o, consts_o], - output_dir=output_dir, - BuildOption=so_build_options, - ) - link_cmd = so_builder.get_command_line() - output_so = so_builder.get_target_file_path() + output_name, output_dir = get_name_and_dir_from_output_file_path(output_so) + so_build_options = CppTorchDeviceOptions( + vec_isa=picked_vec_isa, + device_type=device_type, + aot_mode=graph.aot_mode, + use_absolute_path=use_absolute_path, + ) + so_builder = CppBuilder( + name=output_name, + sources=[output_o, consts_o], + output_dir=output_dir, + BuildOption=so_build_options, + ) + link_cmd = so_builder.get_command_line() + output_so = so_builder.get_target_file_path() + + log.debug("aot linkage command: %s", link_cmd) + + # Append cmds to the end of codegen-ed wrapper file + with open(input_path, "a") as f: + f.write("\n") + f.write(f"// Compile cmd\n// {compile_cmd}\n") + f.write(f"// Link cmd\n// {link_cmd}\n") + if config.aot_inductor.package: linker_flags = os.path.splitext(input_path)[0] + "_linker_flags.json" so_build_options.save_flags_to_file(linker_flags) - from torch._inductor.package import package_aoti + if config.aot_inductor.package_cpp_only: + # If we only want to package the cpp, then we need to save the + # weights separately into a bin, and we also need to prevent compiling the so if use_mmap_weights: weight_file = ( @@ -1891,28 +1891,7 @@ def _pad_to_alignment(raw_bytes: bytes) -> bytes: f_weights.write(serialized_weights) f_weights.write(struct.pack("q", magic_number)) - archive_path = package_aoti(os.path.split(input_path)[0]) - return archive_path else: - output_name, output_dir = get_name_and_dir_from_output_file_path( - output_so - ) - so_build_options = CppTorchDeviceOptions( - vec_isa=picked_vec_isa, - device_type=device_type, - aot_mode=graph.aot_mode, - use_absolute_path=use_absolute_path, - ) - so_builder = CppBuilder( - name=output_name, - sources=[output_o, consts_o], - output_dir=output_dir, - BuildOption=so_build_options, - ) - link_cmd = so_builder.get_command_line() - output_so = so_builder.get_target_file_path() - - log.debug("aot linkage command: %s", link_cmd) if fbcode_aot_cpu_re: output_so = ( config.aot_inductor.output_path @@ -1937,11 +1916,10 @@ def _pad_to_alignment(raw_bytes: bytes) -> bytes: f_so.write(serialized_weights) f_so.write(struct.pack("q", magic_number)) - # Append cmds to the end of codegen-ed wrapper file - with open(input_path, "a") as f: - f.write("\n") - f.write(f"// Compile cmd\n// {compile_cmd}\n") - f.write(f"// Link cmd\n// {link_cmd}\n") + if config.aot_inductor.package: + # We want to return the directory that contains all the AOTI + # generated files, not just the so + return os.path.split(output_so)[0] return output_so diff --git a/torch/_inductor/compile_fx.py b/torch/_inductor/compile_fx.py index 08ff9c281fb90d..4c583a4b2a90a9 100644 --- a/torch/_inductor/compile_fx.py +++ b/torch/_inductor/compile_fx.py @@ -1132,6 +1132,7 @@ def compile_fx_aot( if config_patches is None else {**config_patches, "cpp_wrapper": True} ) + if ( "aot_inductor.output_path" not in config_patches and not config.aot_inductor.output_path diff --git a/torch/_inductor/config.py b/torch/_inductor/config.py index 0ced92c3271bd1..32144d4ef97560 100644 --- a/torch/_inductor/config.py +++ b/torch/_inductor/config.py @@ -1006,9 +1006,11 @@ class aot_inductor: ) # Serialized tree spec for flattening inputs + # TODO: Move this into metadata serialized_in_spec = "" # Serialized tree spec for flattening outputs + # TODO: Move this into metadata serialized_out_spec = "" # flag to decide whether to create a submodule for constant graph. @@ -1019,6 +1021,11 @@ class aot_inductor: force_mmap_weights: bool = False package: bool = False + package_cpp_only: bool = False + + # Dictionary of metadata users might want to save to pass to the runtime. + # TODO: Move this somewhere else, since it's no longer really a config + metadata: Dict[str, str] = {} class cuda: diff --git a/torch/_inductor/package/__init__.py b/torch/_inductor/package/__init__.py index c088562100cdae..15587401b72358 100644 --- a/torch/_inductor/package/__init__.py +++ b/torch/_inductor/package/__init__.py @@ -1 +1 @@ -from .package import load_package, package_aoti +from .package import AOTICompiledModel, load_package, package_aoti diff --git a/torch/_inductor/package/package.py b/torch/_inductor/package/package.py index d1304293e3cd88..ca62b7172e664d 100644 --- a/torch/_inductor/package/package.py +++ b/torch/_inductor/package/package.py @@ -1,24 +1,25 @@ -import glob import json +import logging import os import shlex import subprocess -import tempfile import zipfile from pathlib import Path -from typing import Callable, List, Optional, Union +from typing import Dict, List, Optional, Union import torch import torch._inductor import torch.utils._pytree as pytree -from torch._inductor import config, exc +from torch._inductor import exc from torch._inductor.cpp_builder import BuildOptionsBase, CppBuilder from torch.export._tree_utils import reorder_kwargs -from .build_package import build_package_contents from .pt2_archive_constants import AOTINDUCTOR_DIR, ARCHIVE_VERSION +log = logging.getLogger(__name__) + + class PT2ArchiveWriter: def __init__(self, archive_path: str) -> None: self.archive_path: str = archive_path @@ -154,84 +155,71 @@ def get_aoti_file_with_suffix(suffix: str) -> str: return output_so -def package_aoti(aoti_output_dir: str) -> str: +def package_aoti(archive_file: str, aoti_files: Union[str, Dict[str, str]]) -> str: """ Saves the AOTInductor generated files to the PT2Archive format. - """ - - # Add a makefile and python script - build_package_filename = "build_package.py" - with open(os.path.join(aoti_output_dir, build_package_filename), "w") as f: - f.write(build_package_contents) - - with open(os.path.join(aoti_output_dir, "Makefile"), "w") as f: - f.write(f"all:\n\tpython3 {build_package_filename}\n") - - if config.aot_inductor.output_path.endswith(".so"): - raise RuntimeError( - "Unable to save package as a .so. It should be a .pt2 format or a directory." - ) - elif config.aot_inductor.output_path.endswith(".pt2"): - # Save using the PT2 packaging format - # (https://docs.google.com/document/d/1jLPp8MN8Whs0-VW9PmJ93Yg02W85tpujvHrTa1pc5x8/edit#heading=h.v2y2jgnwc56a) - archive_path = config.aot_inductor.output_path - - with PT2ArchiveWriter(archive_path) as archive_writer: - package_files = glob.glob(f"{aoti_output_dir}/*") - - for path in package_files: - filename = os.path.basename(path) - archive_writer.write_file(f"{AOTINDUCTOR_DIR}{filename}", path) - return archive_path - - else: - # Directly put the files into the directory, without any archiving - return aoti_output_dir + Args: + archive_file: The file name to save the package to. + aoti_files: This can either be a singular path to a directory containing + the AOTInductor files, or a dictionary mapping the model name to the + path to its AOTInductor generated files. + """ + if isinstance(aoti_files, str): + aoti_files = {"model": aoti_files} + + assert isinstance(aoti_files, dict) + assert archive_file.endswith(".pt2") + + # Save using the PT2 packaging format + # (https://docs.google.com/document/d/1jLPp8MN8Whs0-VW9PmJ93Yg02W85tpujvHrTa1pc5x8/edit#heading=h.v2y2jgnwc56a) + + with PT2ArchiveWriter(archive_file) as archive_writer: + for model_name, aoti_output_dir in aoti_files.items(): + log.debug( + "Packaging AOTInductor files from %s with model name, %s", + aoti_output_dir, + model_name, + ) + for root, dirs, files in os.walk(aoti_output_dir): + for file in files: + log.debug( + "Saving AOTI generated file %s to archive in %s%s/%s", + os.path.join(root, file), + AOTINDUCTOR_DIR, + model_name, + file, + ) + archive_writer.write_file( + f"{AOTINDUCTOR_DIR}{model_name}/{file}", + os.path.join(root, file), + ) + return archive_file + + +class AOTICompiledModel: + """ + Callable AOT Inductor loaded model from a .pt2 + """ + def __init__(self, loader: torch._C._aoti.AOTIModelPackageLoader) -> None: + self.loader = loader -def load_package(path: str, device: str) -> Callable: # type: ignore[type-arg] - if path.endswith(".so"): - raise RuntimeError( - "Unable to load .so. It should be a .pt2 format or a directory." - ) - - elif path.endswith(".pt2"): - so_path = os.path.splitext(path)[0] - with PT2ArchiveReader(path) as archive_reader: - file_names = archive_reader.get_file_names() - - with tempfile.TemporaryDirectory() as tmp_dir: - archive_reader.extractall(tmp_dir) - file_names = archive_reader.get_file_names() - aoti_files = [ - file for file in file_names if file.startswith(AOTINDUCTOR_DIR) - ] - - so_path = compile_so(tmp_dir, aoti_files, so_path) - - else: - assert os.path.isdir(path), "Must specify a directory or a .pt2 file" - aoti_files = [ - os.path.join(root, file) - for root, dirs, files in os.walk(path) - for file in files - ] - so_path = compile_so(path, aoti_files, path) - - if device == "cpu": - runner = torch._C._aoti.AOTIModelContainerRunnerCpu(so_path, 1) # type: ignore[call-arg] - elif device == "cuda" or device.startswith("cuda:"): - runner = torch._C._aoti.AOTIModelContainerRunnerCuda(so_path, 1, device) # type: ignore[assignment, call-arg] - else: - raise RuntimeError("Unsupported device " + device) - - def optimized(*args, **kwargs): # type: ignore[no-untyped-def] - call_spec = runner.get_call_spec() # type: ignore[attr-defined] + def __call__(self, *args, **kwargs): # type: ignore[no-untyped-def] + call_spec = self.loader.get_call_spec() # type: ignore[attr-defined] in_spec = pytree.treespec_loads(call_spec[0]) out_spec = pytree.treespec_loads(call_spec[1]) flat_inputs = pytree.tree_flatten((args, reorder_kwargs(kwargs, in_spec)))[0] - flat_outputs = runner.run(flat_inputs) # type: ignore[attr-defined] + flat_outputs = self.loader.run(flat_inputs) # type: ignore[attr-defined] return pytree.tree_unflatten(flat_outputs, out_spec) - return optimized + def get_metadata(self) -> Dict[str, str]: + return self.loader.get_metadata() # type: ignore[attr-defined] + + +def load_package(path: str, model_name: str = "model") -> AOTICompiledModel: # type: ignore[type-arg] + if not path.endswith(".pt2"): + raise RuntimeError("Unable to load package. Path must be a .pt2 file.") + + loader = torch._C._aoti.AOTIModelPackageLoader(path, model_name) # type: ignore[call-arg] + return AOTICompiledModel(loader) diff --git a/torch/csrc/Module.cpp b/torch/csrc/Module.cpp index 3dff5fb65ddef0..19433d62985fd6 100644 --- a/torch/csrc/Module.cpp +++ b/torch/csrc/Module.cpp @@ -69,6 +69,7 @@ #include #include #include +#include #include #include #include @@ -1687,6 +1688,7 @@ PyObject* initModule() { torch::python::init_bindings(module); torch::lazy::initLazyBindings(module); torch::inductor::initAOTIRunnerBindings(module); + torch::inductor::initAOTIPackageBindings(module); #ifdef USE_ITT torch::profiler::initIttBindings(module); #endif diff --git a/torch/csrc/inductor/aoti_package/model_package_loader.cpp b/torch/csrc/inductor/aoti_package/model_package_loader.cpp new file mode 100644 index 00000000000000..1c09d9186d0a15 --- /dev/null +++ b/torch/csrc/inductor/aoti_package/model_package_loader.cpp @@ -0,0 +1,411 @@ +#if !defined(C10_MOBILE) && !defined(ANDROID) + +#include +#include +#include + +#include +#include +#include +#include +#include + +// TODO: Investigate why this is necessary, but fixes build problems in FRL +#if __has_include("filesystem") +#include +namespace fs = std::filesystem; +#else +#include +namespace fs = std::experimental::filesystem; +#endif + +#ifndef _WIN32 +#include +#endif + +// TODO: C++17 has the filesystem header, which may replace these +#ifdef _WIN32 +// On Windows, the POSIX implementations are considered deprecated. We simply +// map to the newer variant. +#include +#include +#include +#define access _access +#define F_OK 0 +#else +#include +#include +#endif + +namespace { +bool file_exists(std::string& path) { +#ifdef _WIN32 + return fs::exists(path); +#else + struct stat rc; + return lstat(path.c_str(), &rc) == 0; +#endif +} + +std::string create_temp_dir() { +#ifdef _WIN32 + throw std::runtime_error("Not implemented"); +#else + std::string temp_dir = "/tmp/XXXXXX"; + if (mkdtemp(temp_dir.data()) == nullptr) { + throw std::runtime_error( + std::string("Failed to create temporary directory: ") + + strerror(errno)); + } + return temp_dir; +#endif +} +} // namespace + +namespace torch::inductor { + +const nlohmann::json& AOTIModelPackageLoader::load_json_file( + std::string json_path) { + if (!file_exists(json_path)) { + throw std::runtime_error("File found: " + json_path); + } + + std::ifstream json_file(json_path); + TORCH_CHECK(json_file.is_open()); + static nlohmann::json json_obj; + json_file >> json_obj; + + return json_obj; +} + +void AOTIModelPackageLoader::load_metadata(const std::string& cpp_filename) { + // Parse metadata json file (if it exists) into the metadata_ map + size_t lastindex = cpp_filename.find_last_of('.'); + std::string metadata_json_path = + cpp_filename.substr(0, lastindex) + "_metadata.json"; + + const nlohmann::json metadata_json_obj = load_json_file(metadata_json_path); + + for (auto& item : metadata_json_obj.items()) { + metadata_[item.key()] = item.value().get(); + } +} + +std::tuple AOTIModelPackageLoader:: + get_cpp_compile_command( + const std::string& filename, + const std::vector& sources, + const nlohmann::json& compile_options, + const std::string& output_dir = "") { + // Construct the cpp command + + std::string compiler = compile_options["compiler"].get(); + bool compile_only = compile_options["compile_only"].get(); + + std::string source_args = ""; + for (const std::string& source : sources) { + source_args += source + " "; + } + + std::string file_ext = compile_only ? ".o" : ".so"; + std::string target_file = output_dir + filename + file_ext; + + std::string cflags_args = ""; + for (auto& arg : compile_options["cflags"]) { + cflags_args += "-" + arg.get() + " "; + } + + std::string definitions_args = ""; + for (auto& arg : compile_options["definitions"]) { + definitions_args += "-D " + arg.get() + " "; + } + + std::string include_dirs_args = ""; + for (auto& arg : compile_options["include_dirs"]) { + include_dirs_args += "-I" + arg.get() + " "; + } + + std::string ldflags_args = ""; + for (auto& arg : compile_options["ldflags"]) { + ldflags_args += "-" + arg.get() + " "; + } + + std::string libraries_dirs_args = ""; + for (auto& arg : compile_options["libraries_dirs"]) { + libraries_dirs_args += "-L" + arg.get() + " "; + } + + std::string libraries_args = ""; + for (auto& arg : compile_options["libraries"]) { + libraries_args += "-l" + arg.get() + " "; + } + + std::string passthrough_parameters_args = ""; + for (auto& arg : compile_options["passthrough_args"]) { + passthrough_parameters_args += arg.get() + " "; + } + + std::string compile_only_arg = compile_only ? "-c" : ""; + + std::string cmd = fmt::format( + "{} {} {} {} {} {} {} {} {} {} -o {}", + compiler, + source_args, + definitions_args, + cflags_args, + include_dirs_args, + passthrough_parameters_args, + ldflags_args, + libraries_args, + libraries_dirs_args, + compile_only_arg, + target_file); + + return std::make_tuple(cmd, target_file); +} + +bool AOTIModelPackageLoader::recursive_mkdir(const std::string& dir) { + // Creates directories recursively, copied from jit_utils.cpp + // Check if current dir exists + const char* p_dir = dir.c_str(); + const bool dir_exists = (access(p_dir, F_OK) == 0); + if (dir_exists) { + return true; + } + + // Try to create current directory +#ifdef _WIN32 + int ret = _mkdir(dir.c_str()); +#else + int ret = mkdir(dir.c_str(), S_IRWXU | S_IRWXG | S_IRWXO); +#endif + // Success + if (ret == 0) { + return true; + } + + // Find folder separator and check if we are at the top + auto pos = dir.find_last_of("/\\"); + if (pos == std::string::npos) { + return false; + } + + // Try to create parent directory + if (!(recursive_mkdir(dir.substr(0, pos)))) { + return false; + } + + // Try to create complete path again +#ifdef _WIN32 + ret = _mkdir(dir.c_str()); +#else + ret = mkdir(dir.c_str(), S_IRWXU | S_IRWXG | S_IRWXO); +#endif + return ret == 0; +} + +std::string AOTIModelPackageLoader::compile_so( + const std::string& cpp_filename, + const std::string& consts_filename) { + // Compile the cpp file into a .so + + size_t lastindex = cpp_filename.find_last_of('.'); + std::string filename = cpp_filename.substr(0, lastindex); + + std::string compile_flags_path = filename + "_compile_flags.json"; + const nlohmann::json compile_flags = load_json_file(compile_flags_path); + + auto compile_result = + get_cpp_compile_command(filename, {cpp_filename}, compile_flags); + std::string compile_cmd = std::get<0>(compile_result); + std::string output_o = std::get<1>(compile_result); + + std::string linker_flags_path = + cpp_filename.substr(0, lastindex) + "_linker_flags.json"; + const nlohmann::json linker_flags = load_json_file(linker_flags_path); + + auto link_result = get_cpp_compile_command( + filename, {output_o, consts_filename}, linker_flags); + std::string link_cmd = std::get<0>(link_result); + std::string output_so = std::get<1>(link_result); + + // Run the commands to generate a .so file + int status = system(compile_cmd.c_str()); + if (status != 0) { + throw std::runtime_error("Failed to compile cpp file."); + } + status = system(link_cmd.c_str()); + if (status != 0) { + throw std::runtime_error("Failed to link files."); + } + + // Move the mmapped weights onto the .so + std::string serialized_weights_path = filename + "_serialized_weights.bin"; + if (file_exists(serialized_weights_path)) { + std::ifstream serialized_weights_file( + serialized_weights_path, std::ios::binary); + if (!serialized_weights_file.is_open()) { + throw std::runtime_error("Failed to open serialized weights file"); + } + std::vector serialized_weights( + (std::istreambuf_iterator(serialized_weights_file)), + std::istreambuf_iterator()); + serialized_weights_file.close(); + + std::ofstream output_so_file(output_so, std::ios::binary | std::ios::app); + if (!output_so_file.is_open()) { + throw std::runtime_error("Failed to open output .so file"); + } + // Page align the weights + std::streampos so_size = output_so_file.tellp(); + std::vector padding(16384 - so_size % 16384, ' '); + output_so_file.write( + padding.data(), static_cast(padding.size())); + output_so_file.write( + serialized_weights.data(), + static_cast(serialized_weights.size())); + output_so_file.close(); + } + + return output_so; +} + +AOTIModelPackageLoader::AOTIModelPackageLoader( + const std::string& model_package_path) + : AOTIModelPackageLoader(model_package_path, "model") {} + +AOTIModelPackageLoader::AOTIModelPackageLoader( + const std::string& model_package_path, + const std::string& model_name = "model") { + // Extract all files within the zipfile to a temporary directory + mz_zip_archive zip_archive; + memset(&zip_archive, 0, sizeof(zip_archive)); + + if (!mz_zip_reader_init_file(&zip_archive, model_package_path.c_str(), 0)) { + throw std::runtime_error( + std::string("Failed to initialize zip archive: ") + + mz_zip_get_error_string(mz_zip_get_last_error(&zip_archive))); + } + + std::string temp_dir = create_temp_dir(); + std::string so_filename = ""; + std::string cpp_filename = ""; + std::string consts_filename = ""; + std::string found_filenames = ""; // Saving for bookkeeping + + for (uint32_t i = 0; i < zip_archive.m_total_files; i++) { + uint32_t filename_len = + mz_zip_reader_get_filename(&zip_archive, i, nullptr, 0); + if (filename_len == 0) { + throw std::runtime_error("Failed to read filename"); + } + char* filename = new char[filename_len + 1]; + if (!mz_zip_reader_get_filename(&zip_archive, i, filename, filename_len)) { + throw std::runtime_error("Failed to read filename"); + } + + std::string filename_str(filename); + found_filenames += filename_str; + found_filenames += " "; + + // Only compile files in the specified model directory + std::string model_directory = "data/aotinductor/" + model_name; + if (filename_str.length() >= model_directory.length() && + filename_str.substr(0, model_directory.length()) == model_directory) { + std::string output_path_str = temp_dir; + output_path_str += "/"; + output_path_str += filename_str; + + // Create the parent directory if it doesn't exist + size_t parent_path_idx = output_path_str.find_last_of("/\\"); + if (parent_path_idx == std::string::npos) { + throw std::runtime_error( + "Failed to find parent path in " + output_path_str); + } + std::string parent_path = output_path_str.substr(0, parent_path_idx); + if (!recursive_mkdir(parent_path.c_str())) { + throw std::runtime_error(fmt::format( + "Failed to create directory {}: {}", parent_path, strerror(errno))); + } + + // Extracts file to the temp directory + mz_zip_reader_extract_file_to_file( + &zip_archive, filename, output_path_str.c_str(), 0); + + // Save the file for bookkeeping + size_t extension_idx = output_path_str.find_last_of('.'); + if (extension_idx != std::string::npos) { + std::string filename_extension = output_path_str.substr(extension_idx); + if (filename_extension == ".cpp") { + cpp_filename = output_path_str; + } + if (filename_extension == ".o") { + consts_filename = output_path_str; + } + if (filename_extension == ".so") { + so_filename = output_path_str; + } + } + } + } + + // Close the zip archive as we have extracted all files to the temp directory + if (!mz_zip_reader_end(&zip_archive)) { + throw std::runtime_error( + std::string("Failed to close zip archive: {}") + + mz_zip_get_error_string(mz_zip_get_last_error(&zip_archive))); + } + + if (cpp_filename.empty() && so_filename.empty()) { + throw std::runtime_error( + "No AOTInductor generate cpp file or so file found in zip archive. Loaded the following:\n" + + found_filenames); + } + + // Compile the .so + std::string so_path = !so_filename.empty() + ? so_filename + : compile_so(cpp_filename, consts_filename); + + // Load metadata which can be queried by user + load_metadata(cpp_filename); + + // Construct the runner depending on the device information + std::string device = metadata_["AOTI_DEVICE_KEY"]; + + if (device.empty()) { + throw std::runtime_error("No device information found."); + } + + std::unordered_map + registered_aoti_runner = getAOTIModelRunnerRegistry(); + + if (registered_aoti_runner.find(device) == registered_aoti_runner.end()) { + throw std::runtime_error("Unsupported device found: " + device); + } + + runner_ = registered_aoti_runner[device](so_path, 1, device, ""); + + std::remove(temp_dir.c_str()); +} + +AOTIModelContainerRunner* AOTIModelPackageLoader::get_runner() { + return runner_.get(); +} + +std::vector AOTIModelPackageLoader::run( + std::vector& inputs) { + return runner_->run(inputs); +} + +std::unordered_map AOTIModelPackageLoader:: + get_metadata() { + return metadata_; +} + +std::vector AOTIModelPackageLoader::get_call_spec() { + return runner_->get_call_spec(); +} + +} // namespace torch::inductor +#endif diff --git a/torch/csrc/inductor/aoti_package/model_package_loader.h b/torch/csrc/inductor/aoti_package/model_package_loader.h new file mode 100644 index 00000000000000..03dc1c64018ddd --- /dev/null +++ b/torch/csrc/inductor/aoti_package/model_package_loader.h @@ -0,0 +1,40 @@ +#if !defined(C10_MOBILE) && !defined(ANDROID) +#pragma once + +#include +#include + +#include + +namespace torch::inductor { +class TORCH_API AOTIModelPackageLoader { + public: + AOTIModelPackageLoader(const std::string& model_package_path); + AOTIModelPackageLoader( + const std::string& model_package_path, + const std::string& model_name); + + AOTIModelContainerRunner* get_runner(); + std::unordered_map get_metadata(); + std::vector run(std::vector& inputs); + std::vector get_call_spec(); + + private: + std::unique_ptr runner_; + std::unordered_map metadata_; + + void load_metadata(const std::string& cpp_filename); + std::string compile_so( + const std::string& cpp_filename, + const std::string& consts_filename); + const nlohmann::json& load_json_file(std::string json_path); + std::tuple get_cpp_compile_command( + const std::string& filename, + const std::vector& sources, + const nlohmann::json& compile_options, + const std::string& output_dir); + bool recursive_mkdir(const std::string& dir); +}; + +} // namespace torch::inductor +#endif diff --git a/torch/csrc/inductor/aoti_package/pybind.cpp b/torch/csrc/inductor/aoti_package/pybind.cpp new file mode 100644 index 00000000000000..3d2154a7549327 --- /dev/null +++ b/torch/csrc/inductor/aoti_package/pybind.cpp @@ -0,0 +1,24 @@ +#include +#include +#include +#ifdef USE_CUDA +#include +#endif + +#include +#include + +namespace torch::inductor { + +void initAOTIPackageBindings(PyObject* module) { + auto rootModule = py::handle(module).cast(); + auto m = rootModule.def_submodule("_aoti"); + + py::class_(m, "AOTIModelPackageLoader") + .def(py::init()) + .def(py::init()) + .def("get_metadata", &AOTIModelPackageLoader::get_metadata) + .def("run", &AOTIModelPackageLoader::run) + .def("get_call_spec", &AOTIModelPackageLoader::get_call_spec); +} +} // namespace torch::inductor diff --git a/torch/csrc/inductor/aoti_package/pybind.h b/torch/csrc/inductor/aoti_package/pybind.h new file mode 100644 index 00000000000000..1eb7818c00e906 --- /dev/null +++ b/torch/csrc/inductor/aoti_package/pybind.h @@ -0,0 +1,7 @@ +#include + +namespace torch::inductor { + +void initAOTIPackageBindings(PyObject* module); + +} // namespace torch::inductor diff --git a/torch/csrc/inductor/aoti_runner/model_container_runner.h b/torch/csrc/inductor/aoti_runner/model_container_runner.h index 99669aec14ac09..6e6339d3dd273c 100644 --- a/torch/csrc/inductor/aoti_runner/model_container_runner.h +++ b/torch/csrc/inductor/aoti_runner/model_container_runner.h @@ -82,7 +82,7 @@ class TORCH_API AOTIModelContainerRunner { std::unique_ptr proxy_executor_; }; -using CreateAOTIModelRunnerFunc = std::shared_ptr (*)( +using CreateAOTIModelRunnerFunc = std::unique_ptr (*)( const std::string& model_so_path, size_t num_models, const std::string& device_str, diff --git a/torch/csrc/inductor/aoti_runner/model_container_runner_cpu.cpp b/torch/csrc/inductor/aoti_runner/model_container_runner_cpu.cpp index 25decff00c458d..f40545d04c4935 100644 --- a/torch/csrc/inductor/aoti_runner/model_container_runner_cpu.cpp +++ b/torch/csrc/inductor/aoti_runner/model_container_runner_cpu.cpp @@ -17,5 +17,21 @@ std::vector AOTIModelContainerRunnerCpu::run( return AOTIModelContainerRunner::run(inputs); } +namespace { +std::unique_ptr create_aoti_runner_cpu( + const std::string& model_so_path, + size_t num_models, + const std::string& device_str, + const std::string& cubin_dir) { + if (device_str != "cpu") { + throw std::runtime_error("Incorrect device passed to aoti_runner_cpu"); + } + return std::make_unique( + model_so_path, num_models); +} +} // namespace + +RegisterAOTIModelRunner register_cpu_runner("cpu", &create_aoti_runner_cpu); + } // namespace torch::inductor #endif diff --git a/torch/csrc/inductor/aoti_runner/model_container_runner_cuda.cpp b/torch/csrc/inductor/aoti_runner/model_container_runner_cuda.cpp index 596b436bb703d1..3ddad0885aa53d 100644 --- a/torch/csrc/inductor/aoti_runner/model_container_runner_cuda.cpp +++ b/torch/csrc/inductor/aoti_runner/model_container_runner_cuda.cpp @@ -30,5 +30,18 @@ std::vector AOTIModelContainerRunnerCuda::run_with_cuda_stream( inputs, reinterpret_cast(cuda_stream.stream())); } +namespace { +std::unique_ptr create_aoti_runner_cuda( + const std::string& model_so_path, + size_t num_models, + const std::string& device_str, + const std::string& cubin_dir) { + return std::make_unique( + model_so_path, num_models, device_str, cubin_dir); +} +} // namespace + +RegisterAOTIModelRunner register_cuda_runner("cuda", &create_aoti_runner_cuda); + } // namespace torch::inductor #endif diff --git a/torch/csrc/inductor/aoti_torch/shim_common.cpp b/torch/csrc/inductor/aoti_torch/shim_common.cpp index 964a976775e6e8..f49bf23b9ce422 100644 --- a/torch/csrc/inductor/aoti_torch/shim_common.cpp +++ b/torch/csrc/inductor/aoti_torch/shim_common.cpp @@ -55,6 +55,39 @@ namespace fs = std::filesystem; namespace fs = std::experimental::filesystem; #endif +#ifndef _WIN32 +#include +#endif + +// HACK for failed builds in ARVR, where it cannot find these symbols within +// std::experimental::filesystem +namespace { +fs::path get_current_path() { +#if __has_include("filesystem") + return fs::current_path(); +#else + throw std::runtime_error("Not implemented"); +#endif +} + +bool file_exists(std::string& path) { +#ifdef _WIN32 + return fs::exists(path); +#else + struct stat rc; + return lstat(path.c_str(), &rc) == 0; +#endif +} + +bool create_directories(const std::string& path) { +#if __has_include("filesystem") + return fs::create_directories(path); +#else + throw std::runtime_error("Not implemented"); +#endif +} +} // namespace + using namespace torch::aot_inductor; namespace { @@ -1004,14 +1037,14 @@ AOTI_TORCH_EXPORT void aoti_torch_save_tensor_handle( at::Tensor* t = tensor_handle_to_tensor_pointer(self); #ifndef C10_MOBILE // Save tensor to tmp .pt file for tensors and can be torch.load'ed later - std::string cwd = fs::current_path().string(); + std::string cwd = get_current_path().string(); std::string tmp_folder = cwd + "/tmp/aoti_torch/"; - if (!fs::exists(tmp_folder)) { + if (!file_exists(tmp_folder)) { std::cout << "aoti_torch_save_tensor_handle: Path does not exist, creating it..." << tmp_folder << std::endl; - if (!fs::create_directories(tmp_folder)) { + if (!create_directories(tmp_folder)) { std::cout << "aoti_torch_save_tensor_handle: Error creating directory: " << tmp_folder << std::endl; return; From 20949c73f7e4d790c6b9c451146f1e9c8729db62 Mon Sep 17 00:00:00 2001 From: Sathyanarayanan Saravanamuthu Date: Wed, 11 Sep 2024 03:35:02 +0000 Subject: [PATCH 0699/1018] Adding entry-point based support for out-of-tree rendezvous plugins (#132633) Fixes #127519 Currently in torchrun rendezvous, there are only two rendezvous backends supported out of the box: `C10d` and `Etcd`. The changes in this PR enables the distributed elastic users to bring their out-of-tree rendezvous backend implementations as Python packages. #### AUTHORING NEW PLUGIN Any new plugin will be a python package exposing entry-points. For example, the structure of redis plugin is as follows: ``` plugin_root |_ pyproject.toml |_ src |_ redis |_ __init__.py |_ redis_store.py |_ redis_backend.py ``` The contents of the `pyproject.toml` should indicate that this is exposes a torchrun entry-point by mentioning the group name `torchrun.plugins`. The `pyproject.toml` for redis plugin would be as follows: ``` [project] name = "redis" version = "0.0.1" [project.entry-points.'torchrun.plugins'] redis = 'redis' ``` The `src/redis/__init__.py` file would contain functions that return the plugin name and plugin handler. The contents of `__init__.py` for redis would be as follows: ``` def getPluginHandler(): def _create_redis_handler(params: RendezvousParameters): from redis_rendezvous_backend import create_backend backend, store = create_backend(params) return create_handler(store, backend, params) return _create_redis_handler ``` The files `redis_store` and `redis_backend` contain the implementation of [Store](https://github.com/pytorch/pytorch/blob/41189b0da4f793d506c3964d70ea2fe16f06daf6/torch/_C/_distributed_c10d.pyi#L171) and [RendezvousBackend](https://github.com/pytorch/pytorch/blob/e782918b8eeffc465acff40f0c55c8d80bc20ce2/torch/distributed/elastic/rendezvous/dynamic_rendezvous.py#L61) respectively. #### USER EXPERIENCE Before using the plugin for the first time, the user has to install the plugin packages. For example, the published packages can be installed using `pip3 install ` and the plugin is in local file systemcan be installed using `pip3 install -e `. Once installed, the new backend can be used in torchrun as follows: ``` torchrun --rdzv-backend=redis --rdzv-endpoint=redis-container:6379 --nnodes=3 --nproc-per-node=1 --max-restarts=3 --rdzv-id=1 test.py ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/132633 Approved by: https://github.com/fduwjj --- .../rendezvous/out_of_tree_rendezvous_test.py | 38 +++++++++++++++++++ .../out_of_tree_test_package/pyproject.toml | 6 +++ .../src/testbackend/__init__.py | 2 + .../elastic/rendezvous/__init__.py | 3 +- .../elastic/rendezvous/registry.py | 25 ++++++++++++ 5 files changed, 73 insertions(+), 1 deletion(-) create mode 100644 test/distributed/elastic/rendezvous/out_of_tree_rendezvous_test.py create mode 100644 test/distributed/elastic/rendezvous/out_of_tree_test_package/pyproject.toml create mode 100644 test/distributed/elastic/rendezvous/out_of_tree_test_package/src/testbackend/__init__.py diff --git a/test/distributed/elastic/rendezvous/out_of_tree_rendezvous_test.py b/test/distributed/elastic/rendezvous/out_of_tree_rendezvous_test.py new file mode 100644 index 00000000000000..4d304bef1bc234 --- /dev/null +++ b/test/distributed/elastic/rendezvous/out_of_tree_rendezvous_test.py @@ -0,0 +1,38 @@ +# Owner(s): ["oncall: r2p"] + +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +import pathlib +import sys +import unittest + +import torch.distributed.elastic.rendezvous as rdvz + + +BACKEND_NAME = "testbackend" +TEST_PACKAGE_PATH = "/out_of_tree_test_package/src" + + +class OutOfTreeRendezvousTest(unittest.TestCase): + def test_out_of_tree_handler_loading(self): + current_path = str(pathlib.Path(__file__).parent.resolve()) + rdvz._register_out_of_tree_handlers() + registry_dict = rdvz.rendezvous_handler_registry._registry + + # test backend should not be registered as a backend + self.assertFalse(BACKEND_NAME in registry_dict) + + # Including testbackend in python path + sys.path.append(current_path + TEST_PACKAGE_PATH) + + # Registering the out of tree handlers again + rdvz._register_out_of_tree_handlers() + + # test backend should be registered as a backend + self.assertTrue(BACKEND_NAME in registry_dict) + + # Removing testbackend from python path + sys.path.remove(current_path + TEST_PACKAGE_PATH) diff --git a/test/distributed/elastic/rendezvous/out_of_tree_test_package/pyproject.toml b/test/distributed/elastic/rendezvous/out_of_tree_test_package/pyproject.toml new file mode 100644 index 00000000000000..32e177bf73fe6b --- /dev/null +++ b/test/distributed/elastic/rendezvous/out_of_tree_test_package/pyproject.toml @@ -0,0 +1,6 @@ +[project] +name = "testbackend" +version = "0.0.1" + +[project.entry-points.'torchrun.handlers'] +testbackend = 'testbackend:test_handler' \ No newline at end of file diff --git a/test/distributed/elastic/rendezvous/out_of_tree_test_package/src/testbackend/__init__.py b/test/distributed/elastic/rendezvous/out_of_tree_test_package/src/testbackend/__init__.py new file mode 100644 index 00000000000000..1fbf5b4c2dfb10 --- /dev/null +++ b/test/distributed/elastic/rendezvous/out_of_tree_test_package/src/testbackend/__init__.py @@ -0,0 +1,2 @@ +def test_handler(): + return "" diff --git a/torch/distributed/elastic/rendezvous/__init__.py b/torch/distributed/elastic/rendezvous/__init__.py index 62a31adab27b01..22ec0c9a0f6758 100644 --- a/torch/distributed/elastic/rendezvous/__init__.py +++ b/torch/distributed/elastic/rendezvous/__init__.py @@ -143,10 +143,11 @@ class that implements the rendezvous mechanism described above. It is a backend- RendezvousStoreInfo, RendezvousTimeoutError, ) -from .registry import _register_default_handlers +from .registry import _register_default_handlers, _register_out_of_tree_handlers _register_default_handlers() +_register_out_of_tree_handlers() __all__ = [ diff --git a/torch/distributed/elastic/rendezvous/registry.py b/torch/distributed/elastic/rendezvous/registry.py index 1a91d0a8ff7946..d038ab95eabae6 100644 --- a/torch/distributed/elastic/rendezvous/registry.py +++ b/torch/distributed/elastic/rendezvous/registry.py @@ -4,6 +4,9 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +import logging +import sys + from .api import ( rendezvous_handler_registry as handler_registry, RendezvousHandler, @@ -12,6 +15,13 @@ from .dynamic_rendezvous import create_handler +if sys.version_info < (3, 10): + from importlib_metadata import entry_points +else: + from importlib.metadata import entry_points + +log = logging.getLogger(__name__) + __all__ = ["get_rendezvous_handler"] @@ -50,6 +60,21 @@ def _register_default_handlers() -> None: handler_registry.register("static", _create_static_handler) +def _register_out_of_tree_handlers() -> None: + discovered_handler_generators = entry_points(group="torchrun.handlers") + + for handler_generator in discovered_handler_generators: + try: + get_handler = discovered_handler_generators[handler_generator.name].load() + handler_registry.register(handler_generator.name, get_handler()) + except Exception: + log.warning( + "Exception while registering out of tree plugin %s: ", + handler_generator.name, + exc_info=True, + ) + + def get_rendezvous_handler(params: RendezvousParameters) -> RendezvousHandler: """ Obtain a reference to a :py:class`RendezvousHandler`. From 9b7feaff6c00a936a4c25ee1eb21498b476cdb6f Mon Sep 17 00:00:00 2001 From: penguin-wwy <940375606@qq.com> Date: Wed, 11 Sep 2024 04:13:22 +0000 Subject: [PATCH 0700/1018] Fix the output of FileCheck when not run and add unit tests (#135345) When FileCheck is destructed without execution, it should output all rules. For example: ``` >>> fc = FileCheck().check("test") >>> del fc You have not run this instance of FileCheck! FileCheck checks: CHECK: test ``` Additionally, unit tests for the Python interface of FileCheck will be added. Pull Request resolved: https://github.com/pytorch/pytorch/pull/135345 Approved by: https://github.com/eellison --- test/test_file_check.py | 53 +++++++++++++++++++++++++++ torch/csrc/jit/testing/file_check.cpp | 1 + 2 files changed, 54 insertions(+) create mode 100644 test/test_file_check.py diff --git a/test/test_file_check.py b/test/test_file_check.py new file mode 100644 index 00000000000000..6aea06536781e5 --- /dev/null +++ b/test/test_file_check.py @@ -0,0 +1,53 @@ +# Owner(s): ["module: unknown"] + +from torch.testing import FileCheck +from torch.testing._internal.common_utils import run_tests, TestCase + + +class TestFileCheck(TestCase): + def test_not_run(self): + stdout, stderr = self.run_process_no_exception( + """\ +from torch.testing import FileCheck +file_check = FileCheck().check("not run") +del file_check +""", + ) + FileCheck().check("You have not run this instance of FileCheck!").check_next( + "FileCheck checks:" + ).check_next("\tCHECK: not run").run(stdout) + + def test_all_python_api(self): + test_string = """ +check check_same +check_next +check_count +check_dag +check_source_highlighted +~~~~~~~~~~~~~~~~~~~~~~~~ +check_regex +""" + FileCheck().check("check").check_not("check_not").check_same( + "check_same" + ).check_next("check_next").check_count("check_count", 1).check_dag( + "check_dag" + ).check_source_highlighted("check_source_highlighted").check_regex( + r"check_.+" + ).run(test_string) + + FileCheck().run( + """ +# CHECK: check +# CHECK-NOT: check_not +# CHECK-SAME: check_same +# CHECK-NEXT: check_next +# CHECK-DAG: check_dag +# CHECK-SOURCE-HIGHLIGHTED: check_source_highlighted +# CHECK-REGEX: check_.+ + """, + test_string, + ) + + +if __name__ == "__main__": + run_tests() diff --git a/torch/csrc/jit/testing/file_check.cpp b/torch/csrc/jit/testing/file_check.cpp index 027eb2aa0acf6b..e9bf764c31575d 100644 --- a/torch/csrc/jit/testing/file_check.cpp +++ b/torch/csrc/jit/testing/file_check.cpp @@ -228,6 +228,7 @@ struct FileCheckImpl { groups.push_back({check}); } } + checks.push_back(check); has_run = false; } From 8810a8d37dfda1ebe157f6843c4915885b2bcfb9 Mon Sep 17 00:00:00 2001 From: Michael Lazos Date: Sat, 7 Sep 2024 23:50:38 -0700 Subject: [PATCH 0701/1018] [Dynamo] Use custom backend to reenter metadata tf mode when tracing while/cond (#134732) For tracing cond/while in eager, we trace the HOP with the eager backend with metadata torchfunction mode enabled. HOPs disallow the mutation that occurs in this torch function mode, so it is not able to be traced. As a result, we use a custom backend which enters this mode for tracing these HOPs. Thanks to @ydwu4 for the help with implementing this Pull Request resolved: https://github.com/pytorch/pytorch/pull/134732 Approved by: https://github.com/ydwu4 --- torch/_dynamo/backends/debugging.py | 12 +++++ torch/_higher_order_ops/cond.py | 20 +++++--- torch/_higher_order_ops/while_loop.py | 21 +++++++-- torch/fx/experimental/proxy_tensor.py | 67 ++++++++++++++++----------- torch/nn/attention/flex_attention.py | 42 +++++++++++------ 5 files changed, 111 insertions(+), 51 deletions(-) diff --git a/torch/_dynamo/backends/debugging.py b/torch/_dynamo/backends/debugging.py index 87214f7d59774e..18784498862e39 100644 --- a/torch/_dynamo/backends/debugging.py +++ b/torch/_dynamo/backends/debugging.py @@ -31,6 +31,18 @@ def eager(gm, fake_tensor_inputs, **kwargs): return gm.forward +def make_eager_backend_with_torch_function_mode(mode): + """Used to trace HOPs (cond and while) for eager exectution, the metadata + TF mode mutates vars outside of the scope of the HOP, and we can't have graph breaks + in the HOP, so we need to externally run this mode and not trace it.""" + + def fn(gm, fake_tensor_inputs, **kwargs): + with mode: + return gm.forward + + return fn + + @register_backend def eager_noexcept(gm, fake_tensor_inputs, **kwargs): if kwargs: diff --git a/torch/_higher_order_ops/cond.py b/torch/_higher_order_ops/cond.py index dee400d76f5964..0467e2899adc28 100644 --- a/torch/_higher_order_ops/cond.py +++ b/torch/_higher_order_ops/cond.py @@ -28,6 +28,7 @@ from torch._subclasses.fake_tensor import FakeTensorMode from torch._subclasses.functional_tensor import disable_functional_mode from torch.fx.experimental.proxy_tensor import ( + _temp_remove_metadata_torch_function_mode, _temp_remove_pre_dispatch_torch_function_mode, disable_proxy_modes_tracing, ProxyTorchDispatchMode, @@ -129,6 +130,10 @@ def false_fn(x: torch.Tensor): if torch.compiler.is_dynamo_compiling(): return cond_op(pred, true_fn, false_fn, operands) + from torch._dynamo.backends.debugging import ( + make_eager_backend_with_torch_function_mode, + ) + if isinstance(pred, (bool, int, float)): log.warning( "Pred is a Python constant. When used with torch.cond, it executes only one of the branches." @@ -169,12 +174,15 @@ def _validate_input(pred, true_fn, false_fn, operands): def _cond_op_wrapper(*args, **kwargs): return cond_op(*args, **kwargs) - with _set_compilation_env(): - with torch._dynamo.utils.disable_cache_limit(): - with _temp_remove_pre_dispatch_torch_function_mode(): - return torch.compile(_cond_op_wrapper, backend="eager", fullgraph=True)( - pred, true_fn, false_fn, operands - ) + with _set_compilation_env(), torch._dynamo.utils.disable_cache_limit(), _temp_remove_pre_dispatch_torch_function_mode(): + with _temp_remove_metadata_torch_function_mode() as metadata_mode: + if metadata_mode: + backend = make_eager_backend_with_torch_function_mode(metadata_mode) + else: + backend = "eager" + return torch.compile(_cond_op_wrapper, backend=backend, fullgraph=True)( + pred, true_fn, false_fn, operands + ) def create_fw_bw_graph_branches(true_fn, false_fn, *operands): diff --git a/torch/_higher_order_ops/while_loop.py b/torch/_higher_order_ops/while_loop.py index d64f512399d9d5..f14321842f40b9 100644 --- a/torch/_higher_order_ops/while_loop.py +++ b/torch/_higher_order_ops/while_loop.py @@ -15,7 +15,11 @@ ) from torch._ops import HigherOrderOperator from torch._subclasses.fake_tensor import FakeTensorMode -from torch.fx.experimental.proxy_tensor import ProxyTorchDispatchMode, track_tensor_tree +from torch.fx.experimental.proxy_tensor import ( + _temp_remove_metadata_torch_function_mode, + ProxyTorchDispatchMode, + track_tensor_tree, +) class WhileLoopOp(HigherOrderOperator): @@ -113,6 +117,9 @@ def body_fn(iter, x): - 'while_loop' only supports **inference** right now. Autograd will be supported in the future. """ + from torch._dynamo.backends.debugging import ( + make_eager_backend_with_torch_function_mode, + ) # Currently, additional_inputs is not a user-facing input. It will be automatically set in dynamo. # parameters and buffers accessed in cond_fn or body_fn or tensor closures will become additional_inputs. @@ -140,9 +147,15 @@ def _while_loop_op_wrapper(*args, **kwargs): return while_loop_op(*args, **kwargs) with _set_compilation_env(), torch._dynamo.utils.disable_cache_limit(): - return torch.compile(_while_loop_op_wrapper, backend="eager", fullgraph=True)( - cond_fn, body_fn, carried_inputs, additional_inputs - ) + with _temp_remove_metadata_torch_function_mode() as metadata_mode: + with _temp_remove_metadata_torch_function_mode() as metadata_mode: + if metadata_mode: + backend = make_eager_backend_with_torch_function_mode(metadata_mode) + else: + backend = "eager" + return torch.compile( + _while_loop_op_wrapper, backend=backend, fullgraph=True + )(cond_fn, body_fn, carried_inputs, additional_inputs) @while_loop_op.py_impl(DispatchKey.CompositeExplicitAutograd) diff --git a/torch/fx/experimental/proxy_tensor.py b/torch/fx/experimental/proxy_tensor.py index 5a5771d2ae8f7a..213672e216e6ce 100644 --- a/torch/fx/experimental/proxy_tensor.py +++ b/torch/fx/experimental/proxy_tensor.py @@ -17,7 +17,7 @@ import warnings import weakref from collections import defaultdict -from contextlib import contextmanager, ExitStack, nullcontext +from contextlib import _GeneratorContextManager, contextmanager, ExitStack, nullcontext from dataclasses import dataclass from typing import ( Any, @@ -1084,38 +1084,43 @@ def unwrap_proxy(self, e: T) -> object: return e -@contextmanager -def _temp_remove_pre_dispatch_torch_function_mode() -> Generator[None, None, None]: - from torch.overrides import _len_torch_function_stack, _pop_mode, _push_mode +def _make_temp_remove_mode_context_manager( + mode_ty: Type[TorchFunctionMode], +) -> Callable[[], _GeneratorContextManager[Optional[TorchFunctionMode]]]: + @contextmanager + def context_manager_fn() -> Generator[Optional[TorchFunctionMode], None, None]: + from torch.overrides import _len_torch_function_stack, _pop_mode, _push_mode - temp_elements = [] - pre_dispatch_mode = None + temp_elements = [] + removed_mode = None - while _len_torch_function_stack() > 0: - mode = _pop_mode() - if isinstance(mode, PreDispatchTorchFunctionMode): - pre_dispatch_mode = mode - break - else: - temp_elements.append(mode) + while _len_torch_function_stack() > 0: + mode = _pop_mode() + if isinstance(mode, mode_ty): + removed_mode = mode + break + else: + temp_elements.append(mode) - for mode in reversed(temp_elements): - _push_mode(mode) + for mode in reversed(temp_elements): + _push_mode(mode) - try: - yield + try: + yield removed_mode - finally: - if pre_dispatch_mode is not None: - count = len(temp_elements) - while count > 0: - mode = _pop_mode() - count -= 1 + finally: + if removed_mode is not None: + count = len(temp_elements) + while count > 0: + mode = _pop_mode() + count -= 1 + + temp_elements.append(removed_mode) - temp_elements.append(pre_dispatch_mode) + for mode in reversed(temp_elements): + _push_mode(mode) - for mode in reversed(temp_elements): - _push_mode(mode) + return context_manager_fn @torch._disable_dynamo @@ -1230,6 +1235,11 @@ def __torch_function__( return func(*args, **kwargs) +_temp_remove_metadata_torch_function_mode = _make_temp_remove_mode_context_manager( + TorchFunctionMetadataMode +) + + # This mode is **only** used for pre_dispatch tracing. # In particular, we need to make sure that autograd/autocast API's # that do not desugar into dispatcher operators stay in the graph. @@ -1258,6 +1268,11 @@ def __torch_function__( return func(*args, **kwargs) +_temp_remove_pre_dispatch_torch_function_mode = _make_temp_remove_mode_context_manager( + PreDispatchTorchFunctionMode +) + + class ProxyTorchDispatchMode(TorchDispatchMode): # Ensure this is read-only; this exists only for legacy reasons @property diff --git a/torch/nn/attention/flex_attention.py b/torch/nn/attention/flex_attention.py index 7b0a3354642a7f..f71b86222ee5c1 100644 --- a/torch/nn/attention/flex_attention.py +++ b/torch/nn/attention/flex_attention.py @@ -19,6 +19,7 @@ ) from torch._higher_order_ops.utils import _set_compilation_env from torch.fx.experimental.proxy_tensor import ( + _temp_remove_metadata_torch_function_mode, _temp_remove_pre_dispatch_torch_function_mode, ) from torch.nn.attention._utils import _supported_head_dim, _validate_sdpa_input @@ -1027,6 +1028,10 @@ def score_mod( if not torch._dynamo.is_dynamo_supported(): raise RuntimeError("flex_attention requires dynamo support") + from torch._dynamo.backends.debugging import ( + make_eager_backend_with_torch_function_mode, + ) + # Dynamo is expecting a callable with "__code__" attribute. # We cannot directly pass hop to it. So we wrap it in a dummy function. def _flex_attention_hop_wrapper(*args, **kwargs): @@ -1035,18 +1040,25 @@ def _flex_attention_hop_wrapper(*args, **kwargs): with _set_compilation_env(): with torch._dynamo.utils.disable_cache_limit(): with _temp_remove_pre_dispatch_torch_function_mode(): - out, lse = torch.compile( - _flex_attention_hop_wrapper, backend="eager", fullgraph=True - )( - query, - key, - value, - score_mod, - block_mask.as_tuple(), - scale, - kernel_options, - ) - if return_lse: - return out, lse * math.log(2) - else: - return out + with _temp_remove_metadata_torch_function_mode() as metadata_mode: + if metadata_mode: + backend = make_eager_backend_with_torch_function_mode( + metadata_mode + ) + else: + backend = "eager" + out, lse = torch.compile( + _flex_attention_hop_wrapper, backend="eager", fullgraph=True + )( + query, + key, + value, + score_mod, + block_mask.as_tuple(), + scale, + kernel_options, + ) + if return_lse: + return out, lse * math.log(2) + else: + return out From 5342baa301a199f2b8b220132fb95eab5f8c1928 Mon Sep 17 00:00:00 2001 From: Michael Lazos Date: Mon, 9 Sep 2024 16:02:11 -0700 Subject: [PATCH 0702/1018] [Dynamo] Trace torch function modes entered outside of torch.compile (#133137) This PR adds initial tracing for torch function modes. Details: In essence, this adds tracing into the torch function of modes entered outside of the torch.compile call. This does not yet support tracing enter/exit of a torch function mode/ tracing set_default_device properly using the new mode infra (this will be a very good stress test for modes). I am adding more PRs to this stack to support these. The overall plan is to support tracing enter/exit and handling graph breaks like we do other torch.* context managers. Previously landed: https://github.com/pytorch/pytorch/pull/133135 https://github.com/pytorch/pytorch/pull/133136 https://github.com/pytorch/pytorch/pull/133134 https://github.com/pytorch/pytorch/pull/133133 https://github.com/pytorch/pytorch/pull/133132 https://github.com/pytorch/pytorch/pull/133131 https://github.com/pytorch/pytorch/pull/133729 https://github.com/pytorch/pytorch/pull/133130 Pull Request resolved: https://github.com/pytorch/pytorch/pull/133137 Approved by: https://github.com/jansel, https://github.com/zou3519 ghstack dependencies: #134732 --- test/dynamo/test_modes.py | 135 +++++++++ test/functorch/test_control_flow.py | 3 + .../pt2e/test_metadata_porting.py | 4 +- torch/_dynamo/convert_frame.py | 5 + torch/_dynamo/guards.py | 14 + torch/_dynamo/source.py | 2 +- torch/_dynamo/symbolic_convert.py | 78 ++---- torch/_dynamo/trace_rules.py | 5 +- torch/_dynamo/utils.py | 10 +- torch/_dynamo/variables/ctx_manager.py | 2 + torch/_dynamo/variables/torch.py | 260 ++++++++++-------- torch/_dynamo/variables/torch_function.py | 129 +++++++-- torch/_dynamo/variables/user_defined.py | 7 - torch/_export/non_strict_utils.py | 6 +- 14 files changed, 456 insertions(+), 204 deletions(-) diff --git a/test/dynamo/test_modes.py b/test/dynamo/test_modes.py index ee8a9b579ff124..bad5a788903440 100644 --- a/test/dynamo/test_modes.py +++ b/test/dynamo/test_modes.py @@ -14,6 +14,17 @@ from torch.utils._python_dispatch import TorchDispatchMode +class TestMode(BaseTorchFunctionMode): + def __torch_function__(self, func, types, args, kwargs=None): + if not kwargs: + kwargs = {} + + if func == torch.add: + return torch.zeros(2, 2) + + return super().__torch_function__(func, types, args, kwargs) + + class TorchDispatchModeTests(torch._dynamo.test_case.TestCase): @classmethod def setUpClass(cls): @@ -324,6 +335,130 @@ def fn(x): fn(inp) self.assertEqual(cnt.frame_count, 2) + def test_nested_torch_function_mode(self): + mode_1_called = False + mode_2_called = False + + def reset_state(): + nonlocal mode_1_called + nonlocal mode_2_called + mode_1_called = False + mode_2_called = False + + ones = torch.ones(2, 2) + zeros = torch.zeros(2, 2) + + class TestMode1(BaseTorchFunctionMode): + def __torch_function__(self, func, types, args, kwargs=None): + if not kwargs: + kwargs = {} + + nonlocal mode_1_called + + mode_1_called = True + + if func == torch.add: + return zeros + + return super().__torch_function__(func, types, args, kwargs) + + class TestMode2(BaseTorchFunctionMode): + def __torch_function__(self, func, types, args, kwargs=None): + if not kwargs: + kwargs = {} + + nonlocal mode_2_called + + mode_2_called = True + + if func == torch.mul: + return ones + + return super().__torch_function__(func, types, args, kwargs) + + def fn(x): + return torch.add(x, 3) + + def fn_2(x): + return torch.mul(x, 3) + torch.add(x, 3) + + inp = torch.ones(2, 2) + 1 + + for fn_i in [fn, fn_2]: + fn_opt = torch.compile(fn_i, fullgraph=True) + with TestMode1(), TestMode2(): + expected = fn_i(inp), mode_1_called, mode_2_called + reset_state() + actual = fn_opt(inp), mode_1_called, mode_2_called + reset_state() + + self.assertEqual(expected, actual) + + def test_torch_function_mode_disable(self): + class TestSubclass(torch.Tensor): + @classmethod + def __torch_function__(cls, func, types, args, kwargs=None): + if not kwargs: + kwargs = {} + if func == torch.add: + return torch.ones(2, 2) + return super().__torch_function__(func, types, args, kwargs) + + class TestMode(BaseTorchFunctionMode): + def __torch_function__(self, func, types, args, kwargs=None): + if not kwargs: + kwargs = {} + + if func == torch.add: + return torch.zeros(2, 2) + + return super().__torch_function__(func, types, args, kwargs) + + def fn(x): + return torch.add(x, 3) + + inp = (torch.ones(2, 2) + 1).as_subclass(TestSubclass) + + fn_opt = torch.compile(fn, fullgraph=True) + with TestMode(), torch._dynamo.config.patch( + "traceable_tensor_subclasses", {TestSubclass} + ): + with torch._C.DisableTorchFunctionSubclass(): + expected = fn(inp) + actual = fn_opt(inp) + + self.assertEqual(expected, actual) + + with torch._C.DisableTorchFunction(): + expected = fn(inp) + actual = fn_opt(inp) + + self.assertEqual(expected, actual) + + def test_torch_function_mode_highest_priority(self): + class TestSubclass(torch.Tensor): + @classmethod + def __torch_function__(cls, func, types, args, kwargs=None): + if not kwargs: + kwargs = {} + if func == torch.add: + return torch.ones(2, 2) + return super().__torch_function__(func, types, args, kwargs) + + def fn(x): + return torch.add(x, 3) + + inp = (torch.ones(2, 2) + 1).as_subclass(TestSubclass) + + fn_opt = torch.compile(fn, fullgraph=True) + with TestMode(), torch._dynamo.config.patch( + "traceable_tensor_subclasses", {TestSubclass} + ): + expected = fn(inp) + actual = fn_opt(inp) + + self.assertEqual(expected, actual) + if __name__ == "__main__": from torch._dynamo.test_case import run_tests diff --git a/test/functorch/test_control_flow.py b/test/functorch/test_control_flow.py index e07510828f4584..ad1f6721fc98bf 100644 --- a/test/functorch/test_control_flow.py +++ b/test/functorch/test_control_flow.py @@ -26,6 +26,7 @@ parametrize, requires_cuda, run_tests, + skipIfCrossRef, skipIfTorchDynamo, TEST_WITH_TORCHDYNAMO, TestCase, @@ -2984,6 +2985,7 @@ def f(x, y): self.assertEqual(graph(x, torch.tensor(True)), f(x, torch.tensor(True))) @skipIfTorchDynamo("Graph is not captured by backend if test with dynamo") + @skipIfCrossRef # Arg order changes with crossref def test_cond_simple_with_linear_compile_check_graph(self): from torch._dynamo.testing import EagerAndRecordGraphs @@ -3246,6 +3248,7 @@ def test_while_loop_compile(self, backend, while_loop_test): self._check_compile(fn, inp, backend=backend) @skipIfTorchDynamo("Graph is not captured by backend if test with dynamo") + @skipIfCrossRef # Arg order changes with cross ref def test_while_loop_simple_with_linear_compile_check_graph(self): fn, inp = WHILE_LOOP_TESTS["simple_with_linear"] from torch._dynamo.testing import EagerAndRecordGraphs diff --git a/test/quantization/pt2e/test_metadata_porting.py b/test/quantization/pt2e/test_metadata_porting.py index 5f2d1e2d3cf574..e8e00e36e58e10 100644 --- a/test/quantization/pt2e/test_metadata_porting.py +++ b/test/quantization/pt2e/test_metadata_porting.py @@ -13,7 +13,7 @@ from torch.ao.quantization.quantizer.xnnpack_quantizer_utils import OP_TO_ANNOTATOR from torch.fx import Node from torch.testing._internal.common_quantization import QuantizationTestCase -from torch.testing._internal.common_utils import IS_WINDOWS +from torch.testing._internal.common_utils import IS_WINDOWS, skipIfCrossRef class TestHelperModules: @@ -139,6 +139,8 @@ def _test_metadata_porting( self.assertEqual(v, node_tags[k]) return m + @skipIfCrossRef # mlazos: retracing FX graph with torch function mode doesn't propagate metadata, because the stack + # trace of the mode torch function impl doesn't match the traced graph stored lineno. def test_simple_metadata_porting(self): """ Model under test diff --git a/torch/_dynamo/convert_frame.py b/torch/_dynamo/convert_frame.py index d15f5c4aa43dcf..9dabb78211c3b5 100644 --- a/torch/_dynamo/convert_frame.py +++ b/torch/_dynamo/convert_frame.py @@ -605,6 +605,10 @@ def _compile( output: Optional[OutputGraph] = None tracer: Optional[InstructionTranslator] = None + tf_mode_stack: List[ + torch.overrides.TorchFunctionMode + ] = torch.overrides._get_current_function_mode_stack() + @preserve_global_state def transform( instructions: List[Instruction], code_options: Dict[str, object] @@ -618,6 +622,7 @@ def transform( locals, globals, builtins, + tf_mode_stack, code_options, compiler_fn, one_graph, diff --git a/torch/_dynamo/guards.py b/torch/_dynamo/guards.py index b0bd273e96bbb5..4437ef4038190d 100644 --- a/torch/_dynamo/guards.py +++ b/torch/_dynamo/guards.py @@ -98,6 +98,7 @@ ScriptObjectQualifiedNameSource, ShapeEnvSource, SubclassAttrListSource, + TorchFunctionModeStackSource, TupleIteratorGetItemSource, TypeSource, UnspecializedBuiltinNNModuleSource, @@ -111,6 +112,7 @@ dict_keys_repr, get_custom_getattr, get_torch_function_mode_stack, + get_torch_function_mode_stack_at, guard_failures, istype, key_is_id, @@ -314,6 +316,7 @@ def uninteresting_files(): "___dict_contains": lambda a, b: a in b, "___tuple_iterator_len": tuple_iterator_len, "___tuple_iterator_getitem": tuple_iterator_getitem, + "___get_torch_function_mode_stack_at": get_torch_function_mode_stack_at, "__math_isnan": math.isnan, "__numpy_isnan": None if np is None else np.isnan, "inf": float("inf"), @@ -901,6 +904,15 @@ def get_guard_manager_from_source(self, source): ): assert base_guard_manager # to make mypy happy out = base_guard_manager + elif istype(source, TorchFunctionModeStackSource): + out = root_guard_manager.lambda_manager( + python_lambda=lambda _: get_torch_function_mode_stack_at( + source._get_index() + ), + source=source_name, + example_value=example_value, + guard_manager_enum=guard_manager_enum, + ) elif istype(source, GradSource): assert base_guard_manager # to make mypy happy out = base_guard_manager.grad_manager( @@ -2214,6 +2226,8 @@ def __init__( self.output_graph = output_graph w_builder = None + # NB: Until we trace device contexts, we need to use the stack recorded at the beginning of tracing + # in case a set default device call was made in the graph. self.torch_function_mode_stack = ( output_graph.torch_function_mode_stack if output_graph else None ) diff --git a/torch/_dynamo/source.py b/torch/_dynamo/source.py index 6ed9d97f6a1210..c44f8d9e69abf5 100644 --- a/torch/_dynamo/source.py +++ b/torch/_dynamo/source.py @@ -619,7 +619,7 @@ class TorchFunctionModeStackSource(Source): ind: int def name(self): - return "" + return f"___get_torch_function_mode_stack_at({self._get_index()})" def _get_index(self): from .variables.torch_function import TorchFunctionModeStackVariable diff --git a/torch/_dynamo/symbolic_convert.py b/torch/_dynamo/symbolic_convert.py index ab92f82aa0f630..415179a5ed32a6 100644 --- a/torch/_dynamo/symbolic_convert.py +++ b/torch/_dynamo/symbolic_convert.py @@ -19,20 +19,7 @@ import types import typing import weakref -from typing import ( - Any, - Callable, - cast, - Deque, - Dict, - List, - Optional, - Set, - Tuple, - Type, - TYPE_CHECKING, - Union, -) +from typing import Any, Callable, cast, Dict, List, Optional, Set, Tuple, Type, Union from unittest.mock import patch import torch @@ -72,14 +59,12 @@ GlobalWeakRefSource, LocalSource, Source, - TorchFunctionModeStackSource, ) from .trace_rules import is_builtin_constant, is_forbidden from .utils import ( counters, get_fake_value, get_instruction_source_311, - get_torch_function_mode_stack, graph_break_dup_warning_checker, istype, LazyString, @@ -120,11 +105,10 @@ ) from .variables.nn_module import NNModuleVariable, UnspecializedNNModuleVariable from .variables.tensor import supported_comparison_ops, SymNodeVariable, TensorVariable - - -if TYPE_CHECKING: - from .variables.torch_function import TorchFunctionModeVariable - +from .variables.torch_function import ( + SymbolicTorchFunctionState, + TorchFunctionModeVariable, +) from .variables.user_defined import ( RemovableHandleVariable, UserDefinedClassVariable, @@ -284,6 +268,10 @@ def resume_fn(self): return ReenterWith(self.stack_index) def exit(self, tx): + if hasattr(self, "graph_break") and isinstance( + self.with_context, TorchFunctionModeVariable + ): + return assert self.with_context is not None return self.with_context.exit(tx) @@ -652,7 +640,9 @@ def handle_graph_break( # Reconstruct the context variable CLASS in the block stack for b in self.block_stack: assert b.with_context is not None - assert isinstance(b.with_context, ContextWrappingVariable) + assert isinstance( + b.with_context, (ContextWrappingVariable, TorchFunctionModeVariable) + ) b.with_context.reconstruct_type(cg) cg.extend_output(b.resume_fn().try_except(cg.code_options, cleanup)) self.output.add_output_instructions(cg.get_instructions()) @@ -728,7 +718,7 @@ class InstructionTranslatorBase( output: OutputGraph symbolic_locals: Dict[str, VariableTracker] symbolic_globals: Dict[str, VariableTracker] - symbolic_torch_function_mode_stack: Deque["TorchFunctionModeVariable"] + symbolic_torch_function_state: SymbolicTorchFunctionState stack: List[VariableTracker] instruction_pointer: Optional[int] current_instruction: Instruction @@ -2548,7 +2538,7 @@ def __init__( code_options: Dict[str, Any], symbolic_locals: Dict[str, VariableTracker], symbolic_globals: Dict[str, VariableTracker], - symbolic_torch_function_mode_stack: Deque["TorchFunctionModeVariable"], + symbolic_torch_function_state: SymbolicTorchFunctionState, f_code: types.CodeType, export: bool, inline_depth: int, @@ -2563,7 +2553,7 @@ def __init__( self.output = output self.symbolic_locals = symbolic_locals self.symbolic_globals = symbolic_globals - self.symbolic_torch_function_mode_stack = symbolic_torch_function_mode_stack + self.symbolic_torch_function_state = symbolic_torch_function_state self.stack = [] # stack of variable names for tracking 3.13 closures self.name_stack: list[Any] = [] @@ -2652,6 +2642,7 @@ def __init__( f_locals, f_globals, f_builtins, + torch_function_mode_stack, code_options, compiler_fn, one_graph, @@ -2686,7 +2677,7 @@ def __init__( symbolic_locals={}, # set below # A global var is inserted only after a STORE_GLOBAL happens to it symbolic_globals={}, - symbolic_torch_function_mode_stack=collections.deque(), + symbolic_torch_function_state=None, # type: ignore[arg-type] # set below f_code=f_code, export=export, inline_depth=0, @@ -2721,7 +2712,9 @@ def __init__( if k in f_locals } - self._init_torch_function_mode_stack() + self.symbolic_torch_function_state = SymbolicTorchFunctionState( + torch_function_mode_stack + ) self.debug_locals: List[Tuple[VariableTracker, List[VariableTracker]]] = [] if export: @@ -2762,29 +2755,6 @@ def _throw_if_in_functorch(self): ) unimplemented(msg) - def _init_torch_function_mode_stack(self): - from .variables.torch_function import TorchFunctionModeStackVariable - - TorchFunctionModeStackVariable.reset() - - self.symbolic_torch_function_mode_stack: Deque[ - TorchFunctionModeVariable - ] = collections.deque() - # We want to retrieve all modes to properly reconstruct the stack if needed - py_stack = get_torch_function_mode_stack(filter_ignored=False) - - if py_stack: - has_device_context = isinstance( - py_stack[0], torch.utils._device.DeviceContext - ) - - for i, val in enumerate(py_stack): - self.symbolic_torch_function_mode_stack.append( - variables.LazyVariableTracker.create( - val, source=TorchFunctionModeStackSource(i) - ) - ) - def get_example_value(self, source: Source): if isinstance(source, LocalSource): return self.f_locals[source.local_name] @@ -3116,7 +3086,7 @@ def get_trace_call_log_str(): code, sub_locals, parent.symbolic_globals, - parent.symbolic_torch_function_mode_stack, + parent.symbolic_torch_function_state, closure_cells, func, ) @@ -3126,7 +3096,7 @@ def get_trace_call_log_str(): code, sub_locals, parent.symbolic_globals, - parent.symbolic_torch_function_mode_stack, + parent.symbolic_torch_function_state, closure_cells, func, ) @@ -3179,7 +3149,7 @@ def __init__( code: types.CodeType, symbolic_locals: Dict[str, VariableTracker], symbolic_globals: Dict[str, VariableTracker], - symbolic_torch_function_mode_stack: Deque["TorchFunctionModeVariable"], + symbolic_torch_function_state: SymbolicTorchFunctionState, closure_cells: Dict[str, VariableTracker], funcvar: BaseUserFunctionVariable, ) -> None: @@ -3196,7 +3166,7 @@ def __init__( f_builtins=f_builtins, symbolic_locals=symbolic_locals, symbolic_globals=symbolic_globals, - symbolic_torch_function_mode_stack=symbolic_torch_function_mode_stack, + symbolic_torch_function_state=symbolic_torch_function_state, instructions=instructions, code_options={k: getattr(code, k) for k in get_code_keys()}, f_code=code, diff --git a/torch/_dynamo/trace_rules.py b/torch/_dynamo/trace_rules.py index cf0731911612ba..23a859e426f8df 100644 --- a/torch/_dynamo/trace_rules.py +++ b/torch/_dynamo/trace_rules.py @@ -3256,6 +3256,7 @@ def _module_dir(m: types.ModuleType): "torch.testing", "torch.utils._content_store", "torch.utils._contextlib", + "torch.utils._device", "torch.utils._foreach_utils", "torch.utils._python_dispatch", "torch.utils._pytree", @@ -3590,7 +3591,9 @@ def lookup_inner( if reasons is not None: reasons.add("func name is patched_init") return SkipFunctionVariable - elif name == "__torch_function__": + elif name == "__torch_function__" or ( + obj and obj.__name__ == "__torch_function__" + ): if reasons is not None: reasons.add("func name is __torch_function__") return UserFunctionVariable diff --git a/torch/_dynamo/utils.py b/torch/_dynamo/utils.py index 0ec548d788f868..ef63ddbea85619 100644 --- a/torch/_dynamo/utils.py +++ b/torch/_dynamo/utils.py @@ -63,7 +63,6 @@ import torch.utils._pytree as pytree from torch import fx from torch._C import ( - _get_function_stack_at, _instruction_counter, _len_torch_function_stack, _pop_torch_function_stack, @@ -3065,7 +3064,9 @@ def is_parameter_freezing(): def get_torch_function_mode_stack(filter_ignored=True): from .variables.torch_function import IGNORED_MODES - stack = [_get_function_stack_at(i) for i in range(_len_torch_function_stack())] + stack = [ + get_torch_function_mode_stack_at(i) for i in range(_len_torch_function_stack()) + ] if filter_ignored: stack = [mode for mode in stack if type(mode) not in IGNORED_MODES] @@ -3085,6 +3086,11 @@ def set_torch_function_mode_stack(stack): _push_on_torch_function_stack(mode) +def clear_torch_function_mode_stack(): + for i in range(_len_torch_function_stack()): + _pop_torch_function_stack() + + def verify_guard_fn_signature(value): fn = value.__metadata_guard__ sig = inspect.signature(fn) diff --git a/torch/_dynamo/variables/ctx_manager.py b/torch/_dynamo/variables/ctx_manager.py index 301b7f3e819345..838f413a16e306 100644 --- a/torch/_dynamo/variables/ctx_manager.py +++ b/torch/_dynamo/variables/ctx_manager.py @@ -637,6 +637,8 @@ def enter(self, tx): def _call_func(self, tx: "InstructionTranslator", values): assert len(values) == 1 + tx.symbolic_torch_function_state.torch_function_subclass_enabled = values[0] + tx.symbolic_torch_function_state.torch_function_mode_enabled = values[0] tx.output.set_torch_function_state(values[0]) diff --git a/torch/_dynamo/variables/torch.py b/torch/_dynamo/variables/torch.py index b32ac9af6ccfd9..288bf14aebf928 100644 --- a/torch/_dynamo/variables/torch.py +++ b/torch/_dynamo/variables/torch.py @@ -154,6 +154,15 @@ bin_ops = dict.fromkeys(["add", "sub", "mul", "div", "sqrt"]) +@functools.lru_cache(None) +def get_overridable_functions(): + from itertools import chain + + from torch.overrides import get_overridable_functions as get_overridable_functions_ + + return set(chain(*get_overridable_functions_().values())) + + class BaseTorchVariable(VariableTracker): """common base for all torch.* functions, classes, modules and other things""" @@ -787,10 +796,10 @@ def handle_pop_torch_function( self, tx: "InstructionTranslator", *args, **kwargs ): assert not args and not kwargs - if not tx.symbolic_torch_function_mode_stack: + if not tx.symbolic_torch_function_state.mode_stack: raise unimplemented("Popping from an empty torch function mode stack") TorchFunctionModeStackVariable.register_mutation(tx) - return tx.symbolic_torch_function_mode_stack.pop() + return tx.symbolic_torch_function_state.pop_torch_function_mode() @register(torch._C._push_on_torch_function_stack) def handle_push_torch_function( @@ -798,7 +807,7 @@ def handle_push_torch_function( ): assert len(args) == 1 and not kwargs TorchFunctionModeStackVariable.register_mutation(tx) - tx.symbolic_torch_function_mode_stack.append(args[0]) + tx.symbolic_torch_function_state.push_torch_function_mode(args[0]) return ConstantVariable.create(None) @register(torch._C._len_torch_function_stack) @@ -806,7 +815,9 @@ def handle_len_torch_function( self, tx: "InstructionTranslator", *args, **kwargs ): assert not args and not kwargs - return ConstantVariable.create(len(tx.symbolic_torch_function_mode_stack)) + return ConstantVariable.create( + len(tx.symbolic_torch_function_state.mode_stack) + ) @register(torch.set_default_device) def handle_set_default_device( @@ -838,6 +849,9 @@ def call_function( from . import ConstantVariable, SymNodeVariable, TensorVariable from .builder import wrap_fx_proxy + if self.torch_function_override_enabled(tx, args, kwargs): + return dispatch_torch_function(tx, self, args, kwargs) + if self.can_constant_fold_through() and check_unspec_or_constant_args( args, kwargs ): @@ -859,147 +873,144 @@ def call_function( if result: return result - if can_dispatch_torch_function(tx, args, kwargs): - return dispatch_torch_function(tx, self, args, kwargs) - else: - any_symints_or_symfloats = any(isinstance(x, SymNodeVariable) for x in args) + any_symints_or_symfloats = any(isinstance(x, SymNodeVariable) for x in args) - all_ints_or_floats = all( - isinstance(x, (variables.ConstantVariable, variables.SymNodeVariable)) - for x in args - ) - if ( - getattr(self.value, "__module__", "") == "torch" - and self.value.__name__ in bin_ops - and any_symints_or_symfloats - and all_ints_or_floats - ): - msg = f"""\ + all_ints_or_floats = all( + isinstance(x, (variables.ConstantVariable, variables.SymNodeVariable)) + for x in args + ) + if ( + getattr(self.value, "__module__", "") == "torch" + and self.value.__name__ in bin_ops + and any_symints_or_symfloats + and all_ints_or_floats + ): + msg = f"""\ Calling {str(self.value)} on only torch.SymInt arguments is not yet supported. To support this behavior, we need to allow const-propping tensors that store symint data. For now, dynamo will explicitly graph break when it encounters user code with this behavior. """ - log.warning(msg) - unimplemented(msg) - - # TODO(voz): Replace w/ dynamic shape rewrite table. - # Ideally, we would be able to do this at ctor time, but alas we need a combination - # of value + args to determine this. - fn_ = self.value - if any_symints_or_symfloats: - torch_sym_op = f"_sym_{self.value.__name__}" - if getattr(self.value, "__module__", None) == "math" and hasattr( - torch, torch_sym_op - ): - fn_ = getattr(torch, torch_sym_op) - - fake_out_shape = None - if "out" in kwargs and isinstance(kwargs["out"], variables.TensorVariable): - # Calling fake tensor propagation can mutate the out= tensor in - # tx.output.tracked_fakes. tracked_fakes are used to apply - # symbolic_shape guards. Mutating them destroys the information - # prior to tracing, which is essential for creating right - # guards. So save the shape now, and check later if it has - # changed. If it has, graph break. - fake_out_shape = kwargs["out"].proxy.node.meta["example_value"].shape - - tensor_variable = wrap_fx_proxy( - tx=tx, - proxy=tx.output.create_proxy( - "call_function", - fn_, - *proxy_args_kwargs(args, kwargs), - ), - ) - - if ( - isinstance(tensor_variable, TensorVariable) - and "requires_grad" in kwargs - and kwargs["requires_grad"].as_python_constant() + log.warning(msg) + unimplemented(msg) + + # TODO(voz): Replace w/ dynamic shape rewrite table. + # Ideally, we would be able to do this at ctor time, but alas we need a combination + # of value + args to determine this. + fn_ = self.value + if any_symints_or_symfloats: + torch_sym_op = f"_sym_{self.value.__name__}" + if getattr(self.value, "__module__", None) == "math" and hasattr( + torch, torch_sym_op ): - unimplemented( - """factory functions that return tensors that require grad are not supported. + fn_ = getattr(torch, torch_sym_op) + + fake_out_shape = None + if "out" in kwargs and isinstance(kwargs["out"], variables.TensorVariable): + # Calling fake tensor propagation can mutate the out= tensor in + # tx.output.tracked_fakes. tracked_fakes are used to apply + # symbolic_shape guards. Mutating them destroys the information + # prior to tracing, which is essential for creating right + # guards. So save the shape now, and check later if it has + # changed. If it has, graph break. + fake_out_shape = kwargs["out"].proxy.node.meta["example_value"].shape + + tensor_variable = wrap_fx_proxy( + tx=tx, + proxy=tx.output.create_proxy( + "call_function", + fn_, + *proxy_args_kwargs(args, kwargs), + ), + ) + + if ( + isinstance(tensor_variable, TensorVariable) + and "requires_grad" in kwargs + and kwargs["requires_grad"].as_python_constant() + ): + unimplemented( + """factory functions that return tensors that require grad are not supported. Either create the tensor outside the compiled region, or do not set the tensor to require_grad""" - ) + ) - if "out" in kwargs and not ( - isinstance(kwargs["out"], variables.ConstantVariable) - and kwargs["out"].as_python_constant() is None - ): - # out variants of torch operators like torch.sort and - # torch.sigmoid mutate the tensors in the out field. Track such - # tensors and rewrite the symbolic locals. - if isinstance(tensor_variable, TupleVariable): - assert isinstance(kwargs["out"], (TupleVariable, ListVariable)) - output_tensor_names = [ - tx.find_symbolic_locals_name(x) for x in kwargs["out"].items - ] - for idx, name in enumerate(output_tensor_names): - if name in tx.symbolic_locals: - tx.symbolic_locals[name] = tensor_variable.items[idx] - for out_tensor, result_tensor in zip( - kwargs["out"].items, tensor_variable.items - ): - if ( - out_tensor.source - and out_tensor in tx.output.graphargs - and isinstance(out_tensor, variables.TensorVariable) - and isinstance(result_tensor, variables.TensorVariable) - and out_tensor.size != result_tensor.size - ): - # It's hard to get out variants with resizing on graph inputs work - # properly across dynamo/aot/inductor, just fall back. - unimplemented("out variants with resizing on graph inputs") - elif isinstance(tensor_variable, TensorVariable): - assert isinstance(kwargs["out"], TensorVariable) - assert "example_value" in kwargs["out"].proxy.node.meta - fake_tensor = tensor_variable.proxy.node.meta["example_value"] - fake_out = kwargs["out"].proxy.node.meta["example_value"] + if "out" in kwargs and not ( + isinstance(kwargs["out"], variables.ConstantVariable) + and kwargs["out"].as_python_constant() is None + ): + # out variants of torch operators like torch.sort and + # torch.sigmoid mutate the tensors in the out field. Track such + # tensors and rewrite the symbolic locals. + if isinstance(tensor_variable, TupleVariable): + assert isinstance(kwargs["out"], (TupleVariable, ListVariable)) + output_tensor_names = [ + tx.find_symbolic_locals_name(x) for x in kwargs["out"].items + ] + for idx, name in enumerate(output_tensor_names): + if name in tx.symbolic_locals: + tx.symbolic_locals[name] = tensor_variable.items[idx] + for out_tensor, result_tensor in zip( + kwargs["out"].items, tensor_variable.items + ): if ( - kwargs["out"].source - and kwargs["out"] in tx.output.graphargs - and fake_out_shape != fake_tensor.shape + out_tensor.source + and out_tensor in tx.output.graphargs + and isinstance(out_tensor, variables.TensorVariable) + and isinstance(result_tensor, variables.TensorVariable) + and out_tensor.size != result_tensor.size ): # It's hard to get out variants with resizing on graph inputs work # properly across dynamo/aot/inductor, just fall back. unimplemented("out variants with resizing on graph inputs") + elif isinstance(tensor_variable, TensorVariable): + assert isinstance(kwargs["out"], TensorVariable) + assert "example_value" in kwargs["out"].proxy.node.meta + fake_tensor = tensor_variable.proxy.node.meta["example_value"] + fake_out = kwargs["out"].proxy.node.meta["example_value"] + if ( + kwargs["out"].source + and kwargs["out"] in tx.output.graphargs + and fake_out_shape != fake_tensor.shape + ): + # It's hard to get out variants with resizing on graph inputs work + # properly across dynamo/aot/inductor, just fall back. + unimplemented("out variants with resizing on graph inputs") + if not torch._prims_common.is_contiguous(fake_out): + # It's difficult to handle strides correctly in functionalization + # when calling an out= op with a non-contiguous out argument + unimplemented( + "out= op was called where output tensor was non-contiguous" + ) + name = tx.find_symbolic_locals_name(kwargs["out"]) + if name in tx.symbolic_locals: + tx.symbolic_locals[name] = tensor_variable + elif ( + isinstance(tensor_variable, ConstantVariable) + and tensor_variable.value is None + ): + # Handle out-variant custom ops that return None. + if isinstance(kwargs["out"], TensorVariable): + assert "example_value" in kwargs["out"].proxy.node.meta + fake_out = kwargs["out"].proxy.node.meta["example_value"] if not torch._prims_common.is_contiguous(fake_out): # It's difficult to handle strides correctly in functionalization # when calling an out= op with a non-contiguous out argument unimplemented( "out= op was called where output tensor was non-contiguous" ) - name = tx.find_symbolic_locals_name(kwargs["out"]) - if name in tx.symbolic_locals: - tx.symbolic_locals[name] = tensor_variable - elif ( - isinstance(tensor_variable, ConstantVariable) - and tensor_variable.value is None - ): - # Handle out-variant custom ops that return None. - if isinstance(kwargs["out"], TensorVariable): - assert "example_value" in kwargs["out"].proxy.node.meta - fake_out = kwargs["out"].proxy.node.meta["example_value"] + elif isinstance(kwargs["out"], ListVariable): + for idx, x in enumerate(kwargs["out"].items): + assert "example_value" in x.proxy.node.meta # type: ignore[attr-defined] + fake_out = x.proxy.node.meta["example_value"] # type: ignore[attr-defined] if not torch._prims_common.is_contiguous(fake_out): # It's difficult to handle strides correctly in functionalization # when calling an out= op with a non-contiguous out argument unimplemented( - "out= op was called where output tensor was non-contiguous" + "out= op was called where some of the output tensors were non-contiguous" ) - elif isinstance(kwargs["out"], ListVariable): - for idx, x in enumerate(kwargs["out"].items): - assert "example_value" in x.proxy.node.meta # type: ignore[attr-defined] - fake_out = x.proxy.node.meta["example_value"] # type: ignore[attr-defined] - if not torch._prims_common.is_contiguous(fake_out): - # It's difficult to handle strides correctly in functionalization - # when calling an out= op with a non-contiguous out argument - unimplemented( - "out= op was called where some of the output tensors were non-contiguous" - ) - else: - unimplemented(f"out variant of {type(kwargs['out'])}") + else: + unimplemented(f"out variant of {type(kwargs['out'])}") - return tensor_variable + return tensor_variable def _call_ntuple(self, tx: "InstructionTranslator", args, kwargs): """inline behavior of torch.nn.modules.utils._ntuple""" @@ -1127,3 +1138,12 @@ def _nn_param_via_prefix_insert(tx: "InstructionTranslator", data, requires_grad source ) return result + + def torch_function_override_enabled(self, tx, args, kwargs): + return ( + self.get_function() in get_overridable_functions() + or isinstance( + self.get_function(), + (torch._ops.OpOverload, torch._ops.OpOverloadPacket), + ) + ) and can_dispatch_torch_function(tx, args, kwargs) diff --git a/torch/_dynamo/variables/torch_function.py b/torch/_dynamo/variables/torch_function.py index 4b3188507fe4b5..6e52cc688e0309 100644 --- a/torch/_dynamo/variables/torch_function.py +++ b/torch/_dynamo/variables/torch_function.py @@ -1,8 +1,11 @@ # mypy: ignore-errors +import collections +import contextlib import inspect -from typing import Dict, List, TYPE_CHECKING +from typing import Deque, Dict, List, TYPE_CHECKING +import torch._C import torch.utils._pytree as pytree from torch._guards import Source from torch.overrides import _get_overloaded_args, get_default_nowrap_functions @@ -15,6 +18,7 @@ from .base import VariableTracker from .constant import ConstantVariable from .ctx_manager import ContextWrappingVariable +from .lazy import LazyVariableTracker from .lists import TupleVariable from .tensor import TensorSubclassVariable, TensorVariable from .user_defined import UserDefinedObjectVariable @@ -59,6 +63,67 @@ IGNORED_MODES = {DeviceContext} +class SymbolicTorchFunctionState: + def __init__(self, py_stack): + # This is annoyingly complicated because of how the torch function subclass + mode C API was designed + # There are two exposed C knobs here as contexts: torch._C.DisableTorchFunction and torch._C.DisableTorchFunctionSubclass + # These are their definitions: + # 1) torch._C._is_torch_function_enabled indicates that neither of the above knobs have been entered + # (if either are entered, this will be False) + # 2) torch._C._is_torch_function_mode_enabled indicates that either the torch mode stack is empty OR + # torch._C.DisableTorchFunction has been entered + # To disambiguate these and keep myself sane I added a C API to check whether all torch function + # concepts (modes and subclasses) are enabled. + # This only returns true iff we have not entered torch._C.DisableTorchFunction and allows us to separate + # the stack length from the enablement state of torch function modes. + # This is important because now if a mode is pushed while dynamo is tracing, we know whether + # or not torch function modes are enabled and whether we should trace it. + self.torch_function_subclass_enabled = torch._C._is_torch_function_enabled() + + # This differs from the C API of the same name + # this will only be false iff we have entered torch._C.DisableTorchFunction + # and does not take into account the mode stack length, while the C API bundles these + # two concepts + self.torch_function_mode_enabled = ( + not torch._C._is_torch_function_all_disabled() + ) + + self.cur_mode = None + + TorchFunctionModeStackVariable.reset() + + self.mode_stack: Deque[TorchFunctionModeVariable] = collections.deque() + + for i, val in enumerate(py_stack): + self.mode_stack.append( + LazyVariableTracker.create(val, source=TorchFunctionModeStackSource(i)) + ) + + def in_torch_function_mode(self): + return len(self.mode_stack) > 0 + + def pop_torch_function_mode(self): + return self.mode_stack.pop() + + def push_torch_function_mode(self, mode_var): + self.mode_stack.append(mode_var) + + def call_torch_function_mode(self, tx, fn, types, args, kwargs): + with self._pop_mode_for_inlining() as cur_mode: + return cur_mode.call_torch_function(tx, fn, types, args, kwargs) + + @contextlib.contextmanager + def _pop_mode_for_inlining(self): + old_mode = self.cur_mode + self.cur_mode = self.pop_torch_function_mode() + try: + yield self.cur_mode + finally: + mode = self.cur_mode + self.cur_mode = old_mode + self.push_torch_function_mode(mode) + + class TorchFunctionModeStackVariable(VariableTracker): """Fake VT to use as a dummy object, indicating the presence of torch function mode stack mutation""" @@ -88,19 +153,20 @@ def reset(cls): def register_mutation(cls, tx: "InstructionTranslator"): if cls.stack_value_singleton not in tx.output.side_effects: var = cls( - source=Source(), symbolic_stack=tx.symbolic_torch_function_mode_stack + source=Source(), + symbolic_stack=tx.symbolic_torch_function_state.mode_stack, ) tx.output.side_effects.track_mutable(cls.stack_value_singleton, var) tx.output.side_effects.mutation(var) @classmethod def register_device_context_insertion(cls, tx: "InstructionTranslator"): - stack = tx.symbolic_torch_function_mode_stack + stack = tx.symbolic_torch_function_state.mode_stack if stack and cls.is_device_context(stack[0]): return else: cls.offset += 1 - tx.symbolic_torch_function_mode_stack.insert( + stack.insert( 0, TorchFunctionModeVariable( None, source=TorchFunctionModeStackSource(-cls.offset) @@ -109,7 +175,7 @@ def register_device_context_insertion(cls, tx: "InstructionTranslator"): @classmethod def clear_default_device(cls, tx: "InstructionTranslator"): - stack = tx.symbolic_torch_function_mode_stack + stack = tx.symbolic_torch_function_state.mode_stack if stack and cls.is_device_context(stack[0]): stack.popleft() cls.offset -= 1 @@ -124,23 +190,39 @@ def get_mode_index(cls, ind): class TorchFunctionModeVariable(ContextWrappingVariable): - def __init__(self, value, **kwargs): + def __init__(self, value, source=None, **kwargs): super().__init__(value, **kwargs) self.value = value - - @staticmethod - def get_global_mangled_name(tx, val): - return get_safe_global_name( - tx, f"__torch_function_mode_{val.__class__.__name__}", val - ) + self.cm_obj = value # needed for BC with calling enter from CM code + self.source = source def reconstruct(self, codegen): - # We don't support locally created torch function modes yet + # This shouldn't be called unless we have a source assert self.source self.source.reconstruct(codegen) - def _call_func(self, tx, values): - unimplemented("torch function mode context manager is not supported yet") + def module_name(self): + return self.value.__module__ + + def fn_name(self): + return type(self.value).__name__ + + def python_type(self): + return type(self.value) + + def call_torch_function(self, tx: "InstructionTranslator", fn, types, args, kwargs): + return call_torch_function( + tx, + self, + build_torch_function_fn(tx, self.value, self.source), + fn, + types, + args, + kwargs, + ) + + def _call_func(self, tx: "InstructionTranslator", values): + unimplemented("enter/exit for torch function mode NYI") def _get_all_args(args, kwargs): @@ -231,9 +313,13 @@ def build_torch_function_fn(tx: "InstructionTranslator", value, source): def can_dispatch_torch_function(tx: "InstructionTranslator", args, kwargs): - return tx.output.torch_function_enabled and any( + has_overridden_args = any( has_torch_function(arg) for arg in _get_all_args(args, kwargs) ) + tf_state = tx.symbolic_torch_function_state + return (has_overridden_args and tf_state.torch_function_subclass_enabled) or ( + tf_state.torch_function_mode_enabled and tf_state.in_torch_function_mode() + ) def dispatch_torch_function(tx: "InstructionTranslator", fn, args, kwargs): @@ -245,11 +331,20 @@ def dispatch_torch_function(tx: "InstructionTranslator", fn, args, kwargs): _get_subclass_type, ) + types = TupleVariable([_get_subclass_type_var(tx, arg) for arg in overloaded_args]) + + if tx.symbolic_torch_function_state.in_torch_function_mode(): + res = tx.symbolic_torch_function_state.call_torch_function_mode( + tx, fn, types, args, kwargs + ) + if not (isinstance(res, ConstantVariable) and res.value is NotImplemented): + return res + for arg in overloaded_args: res = arg.call_torch_function( tx, fn, - TupleVariable([_get_subclass_type_var(tx, arg) for arg in overloaded_args]), + types, args, kwargs, ) diff --git a/torch/_dynamo/variables/user_defined.py b/torch/_dynamo/variables/user_defined.py index 34e1d0d10c9f72..d069d2e060a108 100644 --- a/torch/_dynamo/variables/user_defined.py +++ b/torch/_dynamo/variables/user_defined.py @@ -82,11 +82,6 @@ def is_forbidden_context_manager(ctx): from _pytest.python_api import RaisesContext from _pytest.recwarn import WarningsChecker - # TODO mlazos: Temporary to get this stack to pass - # remove in subsequent PR - from torch.overrides import BaseTorchFunctionMode - - f_ctxs.append(BaseTorchFunctionMode) f_ctxs.append(RaisesContext) f_ctxs.append(WarningsChecker) except ImportError: @@ -413,7 +408,6 @@ def call_function( and self.source and not is_forbidden_context_manager(self.value) ): - # import here to avoid an unfortunate circular dependency. from .ctx_manager import GenericContextWrappingVariable cm_obj = tx.output.side_effects.track_object_new( @@ -421,7 +415,6 @@ def call_function( ) cm_obj.call_method(tx, "__init__", args, kwargs) return cm_obj - elif is_namedtuple_cls(self.value): fields = namedtuple_fields(self.value) # check if this a quasi-namedtuple or a real one diff --git a/torch/_export/non_strict_utils.py b/torch/_export/non_strict_utils.py index 5c7331659d26a0..ef15e5fea9e973 100644 --- a/torch/_export/non_strict_utils.py +++ b/torch/_export/non_strict_utils.py @@ -506,7 +506,11 @@ class _NonStrictTorchFunctionHandler(torch.overrides.TorchFunctionMode): def __torch_function__(self, func, types, args=(), kwargs=None): kwargs = kwargs or {} - if log.isEnabledFor(logging.DEBUG) and config.extended_debug_current_loc: + if ( + not torch.compiler.is_dynamo_compiling() + and log.isEnabledFor(logging.DEBUG) + and config.extended_debug_current_loc + ): frame = _find_user_code_frame() if frame is not None: log.debug( From 813af87ad47fd92e8a38acf6886f98dd4dc0f47e Mon Sep 17 00:00:00 2001 From: Michael Lazos Date: Mon, 9 Sep 2024 16:02:11 -0700 Subject: [PATCH 0703/1018] [Dynamo] Support thread local setattr (#135443) In preparation for tracing through DeviceContext (https://github.com/pytorch/pytorch/blob/defb515306fc53ec62e92937a5a76fa5cbc05b84/torch/utils/_device.py#L66) This PR adds support for calling the setattr of thread local objects. These objects have a slots impl, and since this doesn't appear to have any side effects, we call this setattr impl when replaying mutations, since calling `object.__setattr__` on these objects results in a type error. Pull Request resolved: https://github.com/pytorch/pytorch/pull/135443 Approved by: https://github.com/anijain2305 ghstack dependencies: #134732, #133137 --- test/dynamo/test_misc.py | 15 +++++++++++++++ torch/_dynamo/variables/user_defined.py | 5 +++-- 2 files changed, 18 insertions(+), 2 deletions(-) diff --git a/test/dynamo/test_misc.py b/test/dynamo/test_misc.py index ab1b343044a26d..0c26eea44ec266 100644 --- a/test/dynamo/test_misc.py +++ b/test/dynamo/test_misc.py @@ -3380,6 +3380,21 @@ def fn4(x) -> None: self.assertTrue(same(obj41.y, obj42.y)) self.assertEqual(cnts.frame_count, 1) + def test_thread_local_setattr(self): + from threading import local + + loc = local() + + @torch.compile(fullgraph=True) + def fn(x, l): + l.x = x + return x + 1 + + x = torch.ones(2, 2) + fn(x, loc) + + self.assertTrue(loc.x is x) + def test_user_defined_class_name(self): class MyClassFoo: pass diff --git a/torch/_dynamo/variables/user_defined.py b/torch/_dynamo/variables/user_defined.py index d069d2e060a108..ad92277cf70ff4 100644 --- a/torch/_dynamo/variables/user_defined.py +++ b/torch/_dynamo/variables/user_defined.py @@ -9,6 +9,7 @@ import itertools import random import sys +import threading import types import warnings from typing import Dict, Generic, List, TYPE_CHECKING @@ -704,7 +705,7 @@ def call_method( if method is object.__init__: return ConstantVariable.create(None) - if is_standard_setattr(method): + if is_standard_setattr(method) or isinstance(self.value, threading.local): return self.method_setattr_standard(tx, *args, **kwargs) # [NOTE] OrderedDict, dict subtypes must always have source @@ -802,7 +803,7 @@ def method_setattr_standard(self, tx: "InstructionTranslator", name, value): def needs_slow_setattr(self): return not is_standard_setattr( inspect.getattr_static(self.value, "__setattr__", None) - ) + ) and not isinstance(self.value, threading.local) def unpack_var_sequence(self, tx): if ( From 4018bfd531b73f48d80514420bdd3a39b0a8f4ff Mon Sep 17 00:00:00 2001 From: Michael Lazos Date: Mon, 9 Sep 2024 16:02:11 -0700 Subject: [PATCH 0704/1018] [Dynamo] Simplify torch function mode stack guard (#135444) The semantics of ignored modes previously had edge cases, this eliminates these by in essence filtering any ignored modes out of both the ref stack and the current torch function mode stack. This is purely to fix complexity in #135422. The ignored modes handling will be removed in a future PR after https://github.com/pytorch/pytorch/pull/135422 lands, since we will then trace through DeviceContexts vs inserting them into the graph which needed these extra workarounds for correctness. Pull Request resolved: https://github.com/pytorch/pytorch/pull/135444 Approved by: https://github.com/anijain2305, https://github.com/williamwen42 ghstack dependencies: #134732, #133137, #135443 --- test/dynamo/test_modes.py | 2 ++ torch/_dynamo/guards.py | 10 ++++---- torch/csrc/dynamo/guards.cpp | 44 +++++++++++++++++++++++++++++------- 3 files changed, 44 insertions(+), 12 deletions(-) diff --git a/test/dynamo/test_modes.py b/test/dynamo/test_modes.py index bad5a788903440..fa4c23fd320ddb 100644 --- a/test/dynamo/test_modes.py +++ b/test/dynamo/test_modes.py @@ -68,9 +68,11 @@ def tearDownClass(cls): def setUp(self): torch.set_default_device(None) + torch._dynamo.reset() def tearDown(self): torch.set_default_device(None) + torch._dynamo.reset() def _run_torch_function_mode_guard_test(self): class TestMode1(BaseTorchFunctionMode): diff --git a/torch/_dynamo/guards.py b/torch/_dynamo/guards.py index 4437ef4038190d..1544bdd6dc697b 100644 --- a/torch/_dynamo/guards.py +++ b/torch/_dynamo/guards.py @@ -2663,12 +2663,14 @@ def make_torch_function_mode_stack_guard(intial_stack): def check_torch_function_mode_stack(): cur_stack = get_torch_function_mode_stack() - if len(cur_stack) != len(types): + + types_ = [ty for ty in types if ty not in IGNORED_MODES] + cur_stack_ = [mode for mode in cur_stack if type(mode) not in IGNORED_MODES] + + if len(cur_stack_) != len(types_): return False - for ty, mode in zip(types, cur_stack): - if ty in IGNORED_MODES: - continue + for ty, mode in zip(types_, cur_stack_): if ty != type(mode): return False diff --git a/torch/csrc/dynamo/guards.cpp b/torch/csrc/dynamo/guards.cpp index afc3182c632640..771de44427b669 100644 --- a/torch/csrc/dynamo/guards.cpp +++ b/torch/csrc/dynamo/guards.cpp @@ -2523,7 +2523,8 @@ class TORCH_FUNCTION_MODE_STACK : public LeafGuard { Py_ssize_t len = PyList_Size(initial_stack.ptr()); for (Py_ssize_t idx = 0; idx < len; idx++) { PyObject* mode = PyList_GetItem(initial_stack.ptr(), idx); // borrowed ref - this->_ref_stack.push_back(Py_TYPE(mode)); + auto type = Py_TYPE(mode); + this->_ref_stack.push_back(type); } len = PyList_Size(ignored_types.ptr()); @@ -2543,29 +2544,56 @@ class TORCH_FUNCTION_MODE_STACK : public LeafGuard { bool check_nopybind(PyObject* value) override { // Ignore value arg, only used to satisfy the interface size_t ref_ind = 0; - int64_t len = at::impl::PythonTorchFunctionTLS::stack_len(); + const int64_t len = at::impl::PythonTorchFunctionTLS::stack_len(); const size_t ref_stack_size = this->_ref_stack.size(); - for (int64_t idx = 0; idx < len; idx++) { + int64_t idx = 0; + while ((idx < len) && (ref_ind < ref_stack_size)) { std::shared_ptr mode = at::impl::PythonTorchFunctionTLS::get_stack_at(idx); PyTypeObject* mode_type = Py_TYPE(mode->ptr(getPyInterpreter())); + bool act_ignored = this->_ignored_types.count(mode_type) > 0; + bool ref_ignored = + this->_ignored_types.count(this->_ref_stack.at(ref_ind)) > 0; // skip ignored types - if (this->_ignored_types.count(mode_type) > 0) { + if (act_ignored && ref_ignored) { + idx++; + ref_ind++; + continue; + } else if (ref_ignored) { + ref_ind++; + continue; + } else if (act_ignored) { + idx++; continue; } // if we already have more non-ignored modes than the ref stack // or if the mode doesn't match at the current index, return false - else if ( - (ref_stack_size == 0) || (ref_ind > ref_stack_size - 1) || - mode_type != _ref_stack[ref_ind]) { + else if (mode_type != _ref_stack.at(ref_ind)) { return false; } ref_ind++; + idx++; + } + + for (; ref_ind < ref_stack_size; ref_ind++) { + if (!(this->_ignored_types.count(this->_ref_stack.at(ref_ind)) > 0)) { + return false; + } + } + + for (; idx < len; idx++) { + std::shared_ptr mode = + at::impl::PythonTorchFunctionTLS::get_stack_at(idx); + + PyTypeObject* mode_type = Py_TYPE(mode->ptr(getPyInterpreter())); + if (!(this->_ignored_types.count(mode_type) > 0)) { + return false; + } } - return ref_ind == this->_ref_stack.size(); + return ref_ind == ref_stack_size && idx == len; } private: From 21616d1e996ead661542c298cbc21808abaa351b Mon Sep 17 00:00:00 2001 From: fduwjj Date: Tue, 10 Sep 2024 19:18:11 -0700 Subject: [PATCH 0705/1018] [PTD][BE][c10d] Add some code documents for TCPStore code and cosmetic changes to libUVStore code (#130496) While designing something else when TCPStore is needed. I spent some time digging into the codebase of TCPStore and found that the code is a little bit challenging to understand without proper documents. Although people from OSS community must be smarter than me, I still want to document my findings in the code so that devs and users can use them as a reference down the road. Also for libuv, we need to make private variables with a "_", so it's a pure renaming of private variables such as `tcpServer`. Pull Request resolved: https://github.com/pytorch/pytorch/pull/130496 Approved by: https://github.com/wconstab --- torch/csrc/distributed/c10d/TCPStore.hpp | 27 +++++++ .../distributed/c10d/TCPStoreLibUvBackend.cpp | 70 +++++++++---------- 2 files changed, 62 insertions(+), 35 deletions(-) diff --git a/torch/csrc/distributed/c10d/TCPStore.hpp b/torch/csrc/distributed/c10d/TCPStore.hpp index 5ac48d8205e5da..b2839f62c7a3a1 100644 --- a/torch/csrc/distributed/c10d/TCPStore.hpp +++ b/torch/csrc/distributed/c10d/TCPStore.hpp @@ -9,6 +9,33 @@ namespace c10d { namespace detail { +// TCPStore is a key-value store used by PyTorch mainly for distributed +// rendezvous, but for other purposes as well. (e.g., a centralized storage for +// synchronization among different processes.) +// +// It is run via a classic client-server architecture, where the server runs +// a separate background thread (alternatively we call it daemon thread). The +// client and server communicate via TCP sockets. +// +// Currently we have two types of server backends: +// 1. TCPStoreBackend: a single thread to handle all incoming request +// synchronously. +// 2. LibUVTCPStoreBackend: an event-driven asynchronous stream processing that +// leverages libuv library (https://github.com/libuv/libuv) for better +// performance. And this backend now is recommended to users. (We set the +// default value of `useLibUV` inside `TCPStoreOptions` to true now, so users +// should get it by default). +// +// Code structure: +// ├── TCPStore client side API and server setup code: +// │ TCPStore.hpp/TCPStore.cpp +// ├── TCPStoreBackend server side API implementation code: +// │ TCPStoreBackend.hpp/TCPStoreBackend.cpp +// | (actual class:`TCPStoreMasterDaemon`) +// ├── LibUVTCPStoreBackend +// │ TCPStoreLibUvBackend.cpp +// | (actual class: `LibUVStoreDaemon`) + class TCPServer; class TCPClient; diff --git a/torch/csrc/distributed/c10d/TCPStoreLibUvBackend.cpp b/torch/csrc/distributed/c10d/TCPStoreLibUvBackend.cpp index 0cac94bcd13dba..588827a3a422de 100644 --- a/torch/csrc/distributed/c10d/TCPStoreLibUvBackend.cpp +++ b/torch/csrc/distributed/c10d/TCPStoreLibUvBackend.cpp @@ -100,7 +100,6 @@ class UvTcpSocket : public UvHandle { } static void alloc_buffer( - uv_handle_t* handle, size_t suggested_size, uv_buf_t* buf) { @@ -185,7 +184,7 @@ class UvTcpServer : public UvTcpSocket { public: typedef std::function OnConnectCallback; explicit UvTcpServer(uv_loop_t* loop) - : UvTcpSocket(loop), onConnectCb(missingOnConnect) {} + : UvTcpSocket(loop), onConnectCb_(missingOnConnect) {} static c10::intrusive_ptr makeWithSocket( uv_loop_t* loop, @@ -216,7 +215,7 @@ class UvTcpServer : public UvTcpSocket { } void setOnConnectCallback(OnConnectCallback&& callback) { - onConnectCb = std::move(callback); + onConnectCb_ = std::move(callback); } static c10::intrusive_ptr makeWithPort( @@ -289,7 +288,7 @@ class UvTcpServer : public UvTcpSocket { } uint16_t port() const { - return portNum; + return portNum_; } void accept(const c10::intrusive_ptr& socket) { @@ -307,8 +306,8 @@ class UvTcpServer : public UvTcpSocket { } private: - OnConnectCallback onConnectCb; - uint16_t portNum{}; + OnConnectCallback onConnectCb_; + uint16_t portNum_{}; c10::intrusive_ptr iptr() { return c10::intrusive_ptr::reclaim_copy(this); @@ -333,9 +332,9 @@ class UvTcpServer : public UvTcpSocket { } if (addr_s.ss_family == AF_INET) { - portNum = ntohs(reinterpret_cast(&addr_s)->sin_port); + portNum_ = ntohs(reinterpret_cast(&addr_s)->sin_port); } else { - portNum = ntohs(reinterpret_cast(&addr_s)->sin6_port); + portNum_ = ntohs(reinterpret_cast(&addr_s)->sin6_port); } } @@ -344,7 +343,7 @@ class UvTcpServer : public UvTcpSocket { } static void on_new_connection(uv_stream_t* server, int status) { - borrow(server)->onConnectCb(status); + borrow(server)->onConnectCb_(status); } }; @@ -628,10 +627,10 @@ class LibUVStoreDaemon : public BackgroundThread { void stop() override; private: - uv_loop_t loop{}; - c10::intrusive_ptr tcpServer; + uv_loop_t loop_{}; + c10::intrusive_ptr tcpServer_; - uv_async_t exit_handle{}; + uv_async_t exit_handle_{}; std::unordered_map> tcpStore_; // From key -> the list of UvClient waiting on the key std::unordered_map>> @@ -1018,10 +1017,10 @@ class UvClient : public UvTcpSocket { }; void LibUVStoreDaemon::onConnect(int status) { - auto client = UvClient::make(&loop, this); + auto client = UvClient::make(&loop_, this); registerClient(client); try { - tcpServer->accept(client); + tcpServer_->accept(client); client->startRead(); } catch (std::exception& e) { C10D_WARNING("Failed to accept client due to {}", e.what()); @@ -1031,27 +1030,28 @@ void LibUVStoreDaemon::onConnect(int status) { void LibUVStoreDaemon::onExitRequest() { C10D_DEBUG("Store exit requested\n"); - uv_close((uv_handle_t*)&exit_handle, nullptr); - uv_stop(&loop); + uv_close((uv_handle_t*)&exit_handle_, nullptr); + uv_stop(&loop_); } void LibUVStoreDaemon::init(const TCPStoreOptions& opts) { if (opts.masterListenFd.has_value()) { - tcpServer = UvTcpServer::makeWithSocket(&loop, *opts.masterListenFd); + tcpServer_ = UvTcpServer::makeWithSocket(&loop_, *opts.masterListenFd); } else { try { - tcpServer = UvTcpServer::makeWithPort(&loop, opts.port, /*useIpv6=*/true); + tcpServer_ = + UvTcpServer::makeWithPort(&loop_, opts.port, /*useIpv6=*/true); } catch (std::exception& ex) { C10D_INFO( "Failed to bind to ipv6 address, trying ipv4. Error: {}", ex.what()); - tcpServer = - UvTcpServer::makeWithPort(&loop, opts.port, /*useIpv6=*/false); + tcpServer_ = + UvTcpServer::makeWithPort(&loop_, opts.port, /*useIpv6=*/false); } } - tcpServer->setOnConnectCallback( + tcpServer_->setOnConnectCallback( [this](auto status) { this->onConnect(status); }); - port_ = tcpServer->port(); + port_ = tcpServer_->port(); TORCH_CHECK( port_ == opts.port || opts.port == 0, // zero means use any port "listen fd ", @@ -1063,19 +1063,19 @@ void LibUVStoreDaemon::init(const TCPStoreOptions& opts) { } LibUVStoreDaemon::LibUVStoreDaemon(int port) : port_(port) { - TORCH_CHECK(uv_loop_init(&loop) == 0, "Failed to init uv loop"); + TORCH_CHECK(uv_loop_init(&loop_) == 0, "Failed to init uv loop"); TORCH_CHECK( - uv_async_init(&loop, &exit_handle, LibUVStoreDaemon::on_exit_request) == + uv_async_init(&loop_, &exit_handle_, LibUVStoreDaemon::on_exit_request) == 0, "Failed to init uv async event"); - uv_handle_set_data((uv_handle_t*)&exit_handle, this); + uv_handle_set_data((uv_handle_t*)&exit_handle_, this); } LibUVStoreDaemon::~LibUVStoreDaemon() { if (!is_running()) { - uv_close((uv_handle_t*)&exit_handle, nullptr); - uv_run(&loop, UV_RUN_NOWAIT); - TORCH_CHECK(uv_loop_close(&loop) == 0, "loop cleanup didn't work"); + uv_close((uv_handle_t*)&exit_handle_, nullptr); + uv_run(&loop_, UV_RUN_NOWAIT); + TORCH_CHECK(uv_loop_close(&loop_) == 0, "loop cleanup didn't work"); } else { // the daemon thread cleanup libuv dispose(); @@ -1098,7 +1098,7 @@ void LibUVStoreDaemon::run() { c10::setThreadName("pt_tcpstore_uv"); C10D_DEBUG("Uv main loop running"); - int res = uv_run(&loop, UV_RUN_DEFAULT); + int res = uv_run(&loop_, UV_RUN_DEFAULT); if (res) { C10D_DEBUG("UV main loop done: res:{}", res); } @@ -1107,21 +1107,21 @@ void LibUVStoreDaemon::run() { if (debug_enabled) { C10D_DEBUG("Walking live handles prior to closing clients"); - uv_walk(&loop, LibUVStoreDaemon::print_active_handles, nullptr); + uv_walk(&loop_, LibUVStoreDaemon::print_active_handles, nullptr); } for (const auto& client : clients_) { client->close(); } - tcpServer->close(); + tcpServer_->close(); if (debug_enabled) { C10D_DEBUG("Walking live handles after closing clients"); - uv_walk(&loop, LibUVStoreDaemon::print_active_handles, nullptr); + uv_walk(&loop_, LibUVStoreDaemon::print_active_handles, nullptr); } while (true) { - res = uv_loop_close(&loop); + res = uv_loop_close(&loop_); if (res == 0) { break; } @@ -1130,7 +1130,7 @@ void LibUVStoreDaemon::run() { res, uv_err_name(res), uv_strerror(res)); - res = uv_run(&loop, UV_RUN_NOWAIT); + res = uv_run(&loop_, UV_RUN_NOWAIT); if (res != 0) { std::this_thread::sleep_for(std::chrono::milliseconds(500)); } @@ -1139,7 +1139,7 @@ void LibUVStoreDaemon::run() { } void LibUVStoreDaemon::stop() { - int res = uv_async_send(&exit_handle); + int res = uv_async_send(&exit_handle_); if (res) { C10D_WARNING( "uv_async_send failed with:{} errn:{} desc:{}\n", From e6a385498ca9efc74161794d336b4ef9bef991a8 Mon Sep 17 00:00:00 2001 From: Animesh Jain Date: Tue, 10 Sep 2024 13:29:11 -0700 Subject: [PATCH 0706/1018] [dynamo] Bug fix for _torchdynamo_inline source handling (#135612) Pull Request resolved: https://github.com/pytorch/pytorch/pull/135612 Approved by: https://github.com/drisspg ghstack dependencies: #135588 --- test/dynamo/test_repros.py | 16 ++++++++++++++++ torch/_dynamo/variables/builder.py | 12 +++++++----- torch/_dynamo/variables/functions.py | 2 -- 3 files changed, 23 insertions(+), 7 deletions(-) diff --git a/test/dynamo/test_repros.py b/test/dynamo/test_repros.py index a3c4226ed6a9f8..1b9bf3c83a8eb1 100644 --- a/test/dynamo/test_repros.py +++ b/test/dynamo/test_repros.py @@ -5912,6 +5912,22 @@ def fn(val: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: actual[1].untyped_storage().data_ptr(), ) + def test_torch_compile_in_compile_frame(self): + # TODO(anijain2305/yanboliang) - Dont graph break on torch.compile. + def gn(x, c=None): + if c is None: + c = 2 + return c * x + + def outer_func(x): + return torch.compile(gn)(x) + + compile_outer = torch.compile(outer_func, backend="eager") + x = torch.randn(4) + ref = outer_func(x) + res = compile_outer(x) + self.assertEqual(ref, res) + instantiate_parametrized_tests(ReproTests) diff --git a/torch/_dynamo/variables/builder.py b/torch/_dynamo/variables/builder.py index f7ac4d073b534b..bf1cbed10b8e5c 100644 --- a/torch/_dynamo/variables/builder.py +++ b/torch/_dynamo/variables/builder.py @@ -547,11 +547,6 @@ class Autotuner: if id_dispatch is not None: return id_dispatch(self, value) - # Note - There are some nested values where types mismatch! - # We want to get those out and wrap those. - if is_function_or_wrapper(value): - value = inspect.getattr_static(value, "_torchdynamo_inline", value) - # Everything else (NB: order matters!) if is_traceable_wrapper_subclass(value) or istype( value, config.traceable_tensor_subclasses @@ -988,6 +983,13 @@ def build_key_value(i, k, v): elif is_lru_cache_wrapped_function(value): self.install_guards(GuardBuilder.TYPE_MATCH) return WrapperUserFunctionVariable(value, "__wrapped__", source=self.source) + elif is_function_or_wrapper(value) and inspect.getattr_static( + value, "_torchdynamo_inline", False + ): + self.install_guards(GuardBuilder.TYPE_MATCH) + return WrapperUserFunctionVariable( + value, "_torchdynamo_inline", source=self.source + ) elif is_function_or_wrapper(value): value, attr_name = unwrap_with_attr_name_if_wrapper(value) # For these wrappers, Dynamo points to the wrapped function, diff --git a/torch/_dynamo/variables/functions.py b/torch/_dynamo/variables/functions.py index 90c2d0665aebc2..77662e1f4b595a 100644 --- a/torch/_dynamo/variables/functions.py +++ b/torch/_dynamo/variables/functions.py @@ -152,8 +152,6 @@ def __init__(self, fn, is_constant=False, **kwargs) -> None: assert isinstance( fn, (types.FunctionType, torch.jit.ScriptFunction) ), f"expected FunctionType found {typestr(fn)} {fn}" - # unpack @torch._dynamo.optimize()(fn) wrapped function - fn = inspect.getattr_static(fn, "_torchdynamo_inline", fn) self.fn: types.FunctionType = fn def as_python_constant(self): From be1b1dc82e9efe195cddf802a71304d41ea5ae8e Mon Sep 17 00:00:00 2001 From: Eddie Yan Date: Wed, 11 Sep 2024 05:59:25 +0000 Subject: [PATCH 0707/1018] [cuDNN][SDPA] Support `attn_bias` in cuDNN (#130482) CC @drisspg Pull Request resolved: https://github.com/pytorch/pytorch/pull/130482 Approved by: https://github.com/drisspg, https://github.com/Skylion007, https://github.com/malfet Co-authored-by: Nikita Shulga <2453524+malfet@users.noreply.github.com> --- aten/src/ATen/native/cudnn/MHA.cpp | 243 ++++++++++-------- aten/src/ATen/native/cudnn/MHA.h | 2 + .../ATen/native/transformers/attention.cpp | 29 ++- .../native/transformers/cuda/attention.cu | 15 +- .../transformers/cuda/attention_backward.cu | 24 +- .../native/transformers/cuda/sdp_utils.cpp | 2 +- .../ATen/native/transformers/sdp_utils_cpp.h | 11 - test/test_transformers.py | 4 +- 8 files changed, 200 insertions(+), 130 deletions(-) diff --git a/aten/src/ATen/native/cudnn/MHA.cpp b/aten/src/ATen/native/cudnn/MHA.cpp index d00a8eb30698b8..c70a96f937cb81 100644 --- a/aten/src/ATen/native/cudnn/MHA.cpp +++ b/aten/src/ATen/native/cudnn/MHA.cpp @@ -22,6 +22,7 @@ void run_cudnn_SDP_fprop( const Tensor& q, const Tensor& k, const Tensor& v, + const std::optional& attn_bias, Tensor& softmaxstats, Tensor& o, Tensor& dropoutseed, @@ -43,6 +44,7 @@ void run_cudnn_SDP_bprop( const Tensor& q, const Tensor& k, const Tensor& v, + const std::optional& attn_bias, const Tensor& o, const Tensor& dO, const Tensor& softmaxstats, @@ -86,9 +88,9 @@ using graph_and_tensors = std::tuple< std::shared_ptr, // Q, std::shared_ptr, // K, std::shared_ptr, // V, + std::optional>, // Bias std::shared_ptr, // Attn_scale, // TODO(eqy): additional options - // std::shared_ptr, // Bias, // std::shared_ptr, // SEQ_LEN_Q, // std::shared_ptr, // SEQ_LEN_KV, std::shared_ptr, // Seed, @@ -104,7 +106,8 @@ using graph_and_tensors_backward = std::tuple< std::shared_ptr, // Q, std::shared_ptr, // K, std::shared_ptr, // V, - std::shared_ptr, // Attn_scale + std::optional>, // Bias, + std::shared_ptr, // Attn_scale, std::shared_ptr, // Seed, std::shared_ptr, // Offset, std::shared_ptr, // O, @@ -126,6 +129,8 @@ struct MHAParams { std::array q_stride; std::array k_stride; std::array v_stride; + std::array bias_dim; + std::array bias_stride; int64_t b; int64_t h; int64_t s_q; @@ -135,6 +140,9 @@ struct MHAParams { double dropout_probability; bool is_causal; bool return_softmaxstats; + // might be redundant if we take 0 dim/stride + // as signaling no-bias + bool has_attn_bias; }; void setMHAParams( @@ -148,6 +156,7 @@ void setMHAParams( const Tensor& q, const Tensor& k, const Tensor& v, + const std::optional& attn_bias, double dropout_probability, bool is_causal, bool return_softmaxstats) { @@ -166,6 +175,7 @@ void setMHAParams( params.dropout_probability = dropout_probability; params.is_causal = is_causal; params.return_softmaxstats = return_softmaxstats; + params.has_attn_bias = attn_bias.has_value(); TORCH_INTERNAL_ASSERT( q.sizes().size() == MAX_MHA_DIM, "Q tensor has unexpected number of dims, please report a bug to PyTorch."); @@ -190,6 +200,17 @@ void setMHAParams( std::copy(k.strides().begin(), k.strides().end(), params.k_stride.begin()); std::copy(v.sizes().begin(), v.sizes().end(), params.v_dim.begin()); std::copy(v.strides().begin(), v.strides().end(), params.v_stride.begin()); + // uninit is OK as the struct is memset 0'd + if (params.has_attn_bias) { + std::copy( + attn_bias.value().sizes().begin(), + attn_bias.value().sizes().end(), + params.bias_dim.begin()); + std::copy( + attn_bias.value().strides().begin(), + attn_bias.value().strides().end(), + params.bias_stride.begin()); + } } struct MHACacheKeyWrapper : ParamsWrapper { @@ -203,6 +224,7 @@ struct MHACacheKeyWrapper : ParamsWrapper { const Tensor& q, const Tensor& k, const Tensor& v, + const std::optional& attn_bias, double dropout_probability, bool is_causal, bool return_softmaxstats) { @@ -217,6 +239,7 @@ struct MHACacheKeyWrapper : ParamsWrapper { q, k, v, + attn_bias, dropout_probability, is_causal, return_softmaxstats); @@ -285,6 +308,7 @@ auto build_graph_and_tensors( const Tensor& q, const Tensor& k, const Tensor& v, + const std::optional& attn_bias, Tensor& softmaxstats, Tensor& o, Tensor& dropoutseed, @@ -301,36 +325,6 @@ auto build_graph_and_tensors( mha_graph->set_io_data_type(dtype) .set_intermediate_data_type(fe::DataType_t::FLOAT) .set_compute_data_type(fe::DataType_t::FLOAT); - auto Q = mha_graph->tensor( - fe::graph::Tensor_attributes() - .set_name("Q") - .set_dim(std::vector( - q.sizes().data(), q.sizes().data() + q.sizes().size())) - .set_stride(fixSizeOneDimStrideSDPA( - q.sizes(), - std::vector( - q.strides().data(), - q.strides().data() + q.strides().size())))); - auto K = mha_graph->tensor( - fe::graph::Tensor_attributes() - .set_name("K") - .set_dim(std::vector( - k.sizes().data(), k.sizes().data() + k.sizes().size())) - .set_stride(fixSizeOneDimStrideSDPA( - k.sizes(), - std::vector( - k.strides().data(), - k.strides().data() + k.strides().size())))); - auto V = mha_graph->tensor( - fe::graph::Tensor_attributes() - .set_name("V") - .set_dim(std::vector( - v.sizes().data(), v.sizes().data() + v.sizes().size())) - .set_stride(fixSizeOneDimStrideSDPA( - v.sizes(), - std::vector( - v.strides().data(), - v.strides().data() + v.strides().size())))); auto attn_scale = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("Attn_scale") @@ -338,11 +332,6 @@ auto build_graph_and_tensors( .set_stride({1, 1, 1, 1}) .set_is_pass_by_value(true) .set_data_type(fe::DataType_t::FLOAT)); - // TODO(eqy): support bias in the future in a follow-up PR - // auto bias = mha_graph->tensor(fe::graph::Tensor_attributes() - // .set_name("bias") - // .set_dim({b, 1, s_q, s_kv}) - // .set_stride({s_q * s_kv, s_q * s_kv, s_kv, 1})); auto seed = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("Seed") .set_dim({1, 1, 1, 1}) @@ -360,11 +349,30 @@ auto build_graph_and_tensors( .set_causal_mask(is_causal) .set_attn_scale(attn_scale) .set_dropout(dropout_probability, seed, offset); - // Optional bias in flash attention is only supported 8.9.3 onwards - if (cudnnGetVersion() >= 8904) { - // scaled_dot_product_flash_attention_options.set_alibi_mask(true); + auto Q = mha_graph->tensor( + fe::graph::Tensor_attributes() + .set_name("Q") + .set_dim(q.sizes().vec()) + .set_stride(fixSizeOneDimStrideSDPA(q.sizes(), q.strides().vec()))); + auto K = mha_graph->tensor( + fe::graph::Tensor_attributes() + .set_name("K") + .set_dim(k.sizes().vec()) + .set_stride(fixSizeOneDimStrideSDPA(k.sizes(), k.strides().vec()))); + auto V = mha_graph->tensor( + fe::graph::Tensor_attributes() + .set_name("V") + .set_dim(v.sizes().vec()) + .set_stride(fixSizeOneDimStrideSDPA(v.sizes(), v.strides().vec()))); + std::optional> bias; + if (attn_bias.has_value()) { + bias = + mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("bias") + .set_dim(attn_bias.value().sizes().vec()) + .set_stride(attn_bias.value().strides().vec())); + scaled_dot_product_flash_attention_options.set_bias(bias.value()); } - auto seq_q = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("Seq_q") .set_dim({b, 1, 1, 1}) @@ -376,20 +384,9 @@ auto build_graph_and_tensors( .set_stride({1, 1, 1, 1}) .set_data_type(fe::DataType_t::INT32)); - // if (cudnnGetVersion() >= 8903) { - // scaled_dot_product_flash_attention_options.set_bias(bias) - // .set_padding_mask(true) - // .set_seq_len_q(seq_q) - // .set_seq_len_kv(seq_kv); - // } - auto [O, Stats] = mha_graph->sdpa(Q, K, V, scaled_dot_product_flash_attention_options); - O->set_output(true) - .set_dim(std::vector( - o.sizes().data(), o.sizes().data() + o.sizes().size())) - .set_stride(std::vector( - o.strides().data(), o.strides().data() + o.strides().size())); + O->set_output(true).set_dim(o.sizes().vec()).set_stride(o.strides().vec()); if (Stats) { Stats->set_output(true).set_data_type(fe::DataType_t::FLOAT); @@ -407,6 +404,7 @@ auto build_graph_and_tensors( std::move(Q), std::move(K), std::move(V), + std::move(bias), std::move(attn_scale), std::move(seed), std::move(offset), @@ -427,6 +425,7 @@ auto build_graph_and_tensors_backward( const Tensor& q, const Tensor& k, const Tensor& v, + const std::optional& attn_bias, const Tensor& o, const Tensor& dO, const Tensor& softmaxstats, @@ -447,24 +446,6 @@ auto build_graph_and_tensors_backward( mha_graph->set_io_data_type(dtype) .set_intermediate_data_type(fe::DataType_t::FLOAT) .set_compute_data_type(fe::DataType_t::FLOAT); - auto Q = mha_graph->tensor( - fe::graph::Tensor_attributes() - .set_name("Q") - .set_dim(std::vector(q.sizes().begin(), q.sizes().end())) - .set_stride( - std::vector(q.strides().begin(), q.strides().end()))); - auto K = mha_graph->tensor( - fe::graph::Tensor_attributes() - .set_name("K") - .set_dim(std::vector(k.sizes().begin(), k.sizes().end())) - .set_stride( - std::vector(k.strides().begin(), k.strides().end()))); - auto V = mha_graph->tensor( - fe::graph::Tensor_attributes() - .set_name("V") - .set_dim(std::vector(v.sizes().begin(), v.sizes().end())) - .set_stride( - std::vector(v.strides().begin(), v.strides().end()))); auto attn_scale = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("Attn_scale") @@ -472,6 +453,31 @@ auto build_graph_and_tensors_backward( .set_stride({1, 1, 1, 1}) .set_is_pass_by_value(true) .set_data_type(fe::DataType_t::FLOAT)); + auto sdpa_backward_options = fe::graph::SDPA_backward_attributes() + .set_name("CUDNN_SDPA_BACKWARD") + .set_causal_mask(is_causal) + .set_attn_scale(attn_scale); + auto Q = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("Q") + .set_dim(q.sizes().vec()) + .set_stride(q.strides().vec())); + auto K = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("K") + .set_dim(k.sizes().vec()) + .set_stride(k.strides().vec())); + auto V = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("V") + .set_dim(v.sizes().vec()) + .set_stride(v.strides().vec())); + std::optional> bias; + if (attn_bias.has_value()) { + bias = + mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("bias") + .set_dim(attn_bias.value().sizes().vec()) + .set_stride(attn_bias.value().strides().vec())); + sdpa_backward_options.set_bias(bias.value()); + } auto Seed = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("Seed") .set_dim({1, 1, 1, 1}) @@ -482,47 +488,27 @@ auto build_graph_and_tensors_backward( .set_dim({1, 1, 1, 1}) .set_stride({1, 1, 1, 1}) .set_data_type(fe::DataType_t::INT32)); - auto O = mha_graph->tensor( - fe::graph::Tensor_attributes() - .set_name("O") - .set_dim(std::vector(o.sizes().begin(), o.sizes().end())) - .set_stride( - std::vector(o.strides().begin(), o.strides().end()))); - auto STATS = mha_graph->tensor( - fe::graph::Tensor_attributes() - .set_name("Stats") - .set_dim(std::vector( - softmaxstats.sizes().begin(), softmaxstats.sizes().end())) - .set_stride(std::vector( - softmaxstats.strides().begin(), softmaxstats.strides().end())) - .set_data_type(fe::DataType_t::FLOAT)); - auto DO = mha_graph->tensor( - fe::graph::Tensor_attributes() - .set_name("DO") - .set_dim(std::vector(dO.sizes().begin(), dO.sizes().end())) - .set_stride( - std::vector(dO.strides().begin(), dO.strides().end()))); - auto sdpa_backward_options = fe::graph::SDPA_backward_attributes() - .set_name("CUDNN_SDPA_BACKWARD") - .set_causal_mask(is_causal) - .set_attn_scale(attn_scale); + auto O = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("O") + .set_dim(o.sizes().vec()) + .set_stride(o.strides().vec())); + auto STATS = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("Stats") + .set_dim(softmaxstats.sizes().vec()) + .set_stride(softmaxstats.strides().vec()) + .set_data_type(fe::DataType_t::FLOAT)); + auto DO = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("DO") + .set_dim(dO.sizes().vec()) + .set_stride(dO.strides().vec())); if (dropout_probability != 0.0f) { sdpa_backward_options.set_dropout(dropout_probability, Seed, Offset); } auto [DQ, DK, DV] = mha_graph->sdpa_backward(Q, K, V, O, DO, STATS, sdpa_backward_options); - DQ->set_output(true) - .set_dim(std::vector(dQ.sizes().begin(), dQ.sizes().end())) - .set_stride( - std::vector(dQ.strides().begin(), dQ.strides().end())); - DK->set_output(true) - .set_dim(std::vector(dK.sizes().begin(), dK.sizes().end())) - .set_stride( - std::vector(dK.strides().begin(), dK.strides().end())); - DV->set_output(true) - .set_dim(std::vector(dV.sizes().begin(), dV.sizes().end())) - .set_stride( - std::vector(dV.strides().begin(), dV.strides().end())); + DQ->set_output(true).set_dim(dQ.sizes().vec()).set_stride(dQ.strides().vec()); + DK->set_output(true).set_dim(dK.sizes().vec()).set_stride(dK.strides().vec()); + DV->set_output(true).set_dim(dV.sizes().vec()).set_stride(dV.strides().vec()); AT_CUDNN_FRONTEND_CHECK(mha_graph->validate()); AT_CUDNN_FRONTEND_CHECK(mha_graph->build_operation_graph(handle)); AT_CUDNN_FRONTEND_CHECK( @@ -534,6 +520,7 @@ auto build_graph_and_tensors_backward( std::move(Q), std::move(K), std::move(V), + std::move(bias), std::move(attn_scale), std::move(Seed), std::move(Offset), @@ -559,6 +546,7 @@ void run_cudnn_SDP_fprop( const Tensor& q, const Tensor& k, const Tensor& v, + const std::optional& attn_bias, Tensor& softmaxstats, Tensor& o, Tensor& dropoutseed, @@ -573,6 +561,11 @@ void run_cudnn_SDP_fprop( softmaxstats = at::empty({b, h, s_q}, q.options().dtype(kFloat)); } + // do nothing if we got 0-element tensors + if (!q.numel() || !k.numel() || !v.numel()) { + return; + } + auto key = MHACacheKeyWrapper( b, h, @@ -583,6 +576,7 @@ void run_cudnn_SDP_fprop( q, k, v, + attn_bias, dropout_probability, is_causal, return_softmaxstats); @@ -605,13 +599,14 @@ void run_cudnn_SDP_fprop( q, k, v, + attn_bias, softmaxstats, o, dropoutseed, dropoutoffset, handle); } - auto [mha_graph, Q, K, V, attn_scale, seed, offset, O, Stats] = + auto [mha_graph, Q, K, V, bias, attn_scale, seed, offset, O, Stats] = graph_and_tensors_values; std::unordered_map, void*> variant_pack = { @@ -619,13 +614,15 @@ void run_cudnn_SDP_fprop( {K, k.data_ptr()}, {V, v.data_ptr()}, {attn_scale, &scaling_factor}, - //{bias, bias.data_ptr()}, {seed, dropoutseed.data_ptr()}, {offset, dropoutoffset.data_ptr()}, {O, o.data_ptr()}}; if (return_softmaxstats) { variant_pack[Stats] = softmaxstats.data_ptr(); } + if (attn_bias.has_value()) { + variant_pack[bias.value()] = attn_bias.value().data_ptr(); + } auto workspace_size = mha_graph->get_workspace_size(); auto workspace_ptr = c10::cuda::CUDACachingAllocator::get()->allocate(workspace_size); @@ -647,6 +644,7 @@ void run_cudnn_SDP_bprop( const Tensor& q, const Tensor& k, const Tensor& v, + const std::optional& attn_bias, const Tensor& o, const Tensor& dO, const Tensor& softmaxstats, @@ -655,6 +653,12 @@ void run_cudnn_SDP_bprop( Tensor& dV, const Tensor& dropoutseed, const Tensor& dropoutoffset) { + // do nothing if we got 0-element tensors + if (!q.numel() || !k.numel() || !v.numel() || !o.numel() || !dO.numel() || + !softmaxstats.numel()) { + return; + } + Tensor dO_ = dO; if (!dO.strides()[dO.strides().size() - 1]) { TORCH_WARN( @@ -694,6 +698,7 @@ void run_cudnn_SDP_bprop( q, k, v, + attn_bias, dropout_probability, is_causal, true); @@ -715,6 +720,7 @@ void run_cudnn_SDP_bprop( q, k, v, + attn_bias, o, dO_, softmaxstats, @@ -726,8 +732,20 @@ void run_cudnn_SDP_bprop( handle); } auto - [mha_graph, Q, K, V, attn_scale, Seed, Offset, O, Do, Stats, Dq, Dk, Dv] = - graph_and_tensors_backward_values; + [mha_graph, + Q, + K, + V, + bias, + attn_scale, + Seed, + Offset, + O, + Do, + Stats, + Dq, + Dk, + Dv] = graph_and_tensors_backward_values; std::unordered_map, void*> variant_pack = {// inputs {Q, q.data_ptr()}, @@ -746,6 +764,9 @@ void run_cudnn_SDP_bprop( variant_pack[Seed] = dropoutseed.data_ptr(); variant_pack[Offset] = dropoutoffset.data_ptr(); } + if (attn_bias.has_value()) { + variant_pack[bias.value()] = attn_bias.value().data_ptr(); + } auto workspace_size = mha_graph->get_workspace_size(); auto workspace_ptr = c10::cuda::CUDACachingAllocator::get()->allocate(workspace_size); diff --git a/aten/src/ATen/native/cudnn/MHA.h b/aten/src/ATen/native/cudnn/MHA.h index 8b9315a5a3d85c..3ae1a03b2a7741 100644 --- a/aten/src/ATen/native/cudnn/MHA.h +++ b/aten/src/ATen/native/cudnn/MHA.h @@ -18,6 +18,7 @@ void run_cudnn_SDP_fprop( const Tensor& q, const Tensor& k, const Tensor& v, + const std::optional& attn_bias, Tensor& softmaxstats, Tensor& o, Tensor& dropoutseed, @@ -36,6 +37,7 @@ void run_cudnn_SDP_bprop( const Tensor& q, const Tensor& k, const Tensor& v, + const std::optional& attn_bias, const Tensor& o, const Tensor& dO, const Tensor& softmaxstats, diff --git a/aten/src/ATen/native/transformers/attention.cpp b/aten/src/ATen/native/transformers/attention.cpp index cab9dbd520d61b..c1054ccce07e48 100644 --- a/aten/src/ATen/native/transformers/attention.cpp +++ b/aten/src/ATen/native/transformers/attention.cpp @@ -536,6 +536,24 @@ std::optional convert_boolean_attn_mask(const std::optional& att // Otherwise, attn_mask represents an additive attention tensor return attn_mask; } + +// alternate version to workaround -inf issue with cuDNN +// TODO(eqy): delete this when cuDNN -inf issue is resolved +std::optional convert_boolean_attn_mask_cudnn(const std::optional& attn_mask, caffe2::TypeMeta dtype) { + // Pass through + if(!attn_mask.has_value()){ + return std::nullopt; + } + // Convert boolean mask to additive mask; need to invert mask to indicate what + // to mask *out*. + if (attn_mask->dtype() == at::kBool) { + // TODO Use the max type of the input and output + return at::where(attn_mask->logical_not(), -65504.0, at::scalar_tensor(0.0, at::TensorOptions().dtype(dtype))); + } + // Otherwise, attn_mask represents an additive attention tensor + return attn_mask; +} + // Memory Efficient Attention requires a padded attn mask bias // This function pads the attn_mask bias to be a multiple of 16 // Then slices the padded bias to the original size @@ -698,15 +716,16 @@ Tensor scaled_dot_product_attention( query_, key, value, attn_mask_, dropout_p, is_causal, scale, enable_gqa); } sdp::SDPBackend backend = static_cast(choice_int); - std::optional attn_mask = convert_boolean_attn_mask(attn_mask_, query_.dtype()); switch (backend) { case sdp::SDPBackend::cudnn_attention: { + std::optional attn_mask = convert_boolean_attn_mask_cudnn(attn_mask_, query_.dtype()); bool compute_logsumexp = should_compute_logsumexp(query_, key, value); auto out_lse_softmax = at::_scaled_dot_product_cudnn_attention( - query_, key, value, attn_mask_, compute_logsumexp, dropout_p, is_causal, false /*return_debug_mask*/, scale); + query_, key, value, attn_mask, compute_logsumexp, dropout_p, is_causal, false /*return_debug_mask*/, scale); return std::get<0>(out_lse_softmax); } case sdp::SDPBackend::flash_attention: { + std::optional attn_mask = convert_boolean_attn_mask(attn_mask_, query_.dtype()); if(query_.device().type() == DeviceType::CUDA){ c10::SymInt og_size = query_.sym_size(-1); Tensor query_padded = pad_last_dim<8, false>(query_); @@ -723,6 +742,7 @@ Tensor scaled_dot_product_attention( query_, key, value, dropout_p, is_causal, attn_mask, scale)); } case sdp::SDPBackend::efficient_attention: { + std::optional attn_mask = convert_boolean_attn_mask(attn_mask_, query_.dtype()); bool compute_logsumexp = should_compute_logsumexp(query_, key, value); if (attn_mask.has_value()) { attn_mask.value() = preprocess_mask(attn_mask.value(), query_, key, value);; @@ -732,11 +752,13 @@ Tensor scaled_dot_product_attention( return std::get<0>(out_and_lse); } case sdp::SDPBackend::overrideable: { + std::optional attn_mask = convert_boolean_attn_mask(attn_mask_, query_.dtype()); auto out_lse_softmax = at::_scaled_dot_product_fused_attention_overrideable( query_, key, value, attn_mask, dropout_p, is_causal, false /*return_debug_mask*/, scale); return std::get<0>(out_lse_softmax); } - case sdp::SDPBackend::math: + case sdp::SDPBackend::math: { + std::optional attn_mask = convert_boolean_attn_mask(attn_mask_, query_.dtype()); if ((!GradMode::is_enabled() || (!query_.requires_grad() && !key.requires_grad() && !value.requires_grad())) && query_.device().type() == DeviceType::MPS && dropout_p == 0.0 && query_.is_contiguous() && key.is_contiguous() && value.is_contiguous() @@ -761,6 +783,7 @@ Tensor scaled_dot_product_attention( std::nullopt, /*dropout_mask*/ scale, enable_gqa)); + } default: TORCH_CHECK( false, diff --git a/aten/src/ATen/native/transformers/cuda/attention.cu b/aten/src/ATen/native/transformers/cuda/attention.cu index 78566555865c6d..cf648151a5ba8d 100644 --- a/aten/src/ATen/native/transformers/cuda/attention.cu +++ b/aten/src/ATen/native/transformers/cuda/attention.cu @@ -774,6 +774,18 @@ std::tuple */, log_sumexp/*Tensor softmaxstats*/, attention/*Tensor o*/, cudnn_seed/*Tensor dropoutseed*/, cudnn_offset/*Tensor dropoutoffset*/); // TODO(eqy): support debug_attn_mask - return std::make_tuple(attention, log_sumexp, Tensor(), Tensor(), max_seqlen_batch_q, max_seqlen_batch_k, cudnn_seed, cudnn_offset, Tensor()); + return std::make_tuple(std::move(attention), std::move(log_sumexp), Tensor(), Tensor(), max_seqlen_batch_q, max_seqlen_batch_k, std::move(cudnn_seed), std::move(cudnn_offset), Tensor()); } std::tuple _scaled_dot_product_efficient_attention_cuda( diff --git a/aten/src/ATen/native/transformers/cuda/attention_backward.cu b/aten/src/ATen/native/transformers/cuda/attention_backward.cu index 33b95945988b49..75ab28d21dbc3f 100644 --- a/aten/src/ATen/native/transformers/cuda/attention_backward.cu +++ b/aten/src/ATen/native/transformers/cuda/attention_backward.cu @@ -195,6 +195,27 @@ std::tuple _scaled_dot_product_cudnn_attention_backward_ const int64_t num_heads = query.size(1); const int64_t head_dim_qk = query.size(3); const int64_t head_dim_v = value.size(3); + const int64_t max_seqlen_batch_q = query.size(2); + const int64_t max_seqlen_batch_k = key.size(2); + + // This is needed because SaveVariable automatically converts + // std::optional to undefined tensor + std::optional attn_bias_; + if (attn_bias.defined()) { + attn_bias_ = attn_bias; + } + if (attn_bias_.has_value()) { + const auto bias_dim = attn_bias_.value().dim(); + if (bias_dim == 2) { + attn_bias_ = attn_bias_.value().expand({batch_size, 1, max_seqlen_batch_q, max_seqlen_batch_k}); + } else if (bias_dim == 3) { + attn_bias_ = attn_bias_.value().expand({batch_size, 1, max_seqlen_batch_q, max_seqlen_batch_k}); + } else { + attn_bias_ = attn_bias_.value().expand({batch_size, attn_bias_.value().size(1), max_seqlen_batch_q, max_seqlen_batch_k}); + TORCH_CHECK(bias_dim == 4, "cuDNN SDPA expects either a 2D, 3D, or 4D attn_bias but got ", attn_bias_.value().dim(), "D"); + } + } + const auto softmax_scale = sdp::calculate_scale(query, scale).as_float_unchecked(); auto dq = at::empty_like(query); auto dk = at::empty_like(key); @@ -211,6 +232,7 @@ std::tuple _scaled_dot_product_cudnn_attention_backward_ query /*const Tensor& q*/, key /*const Tensor& k*/, value /*const Tensor& v*/, + attn_bias_ /*const std::optional& attn_bias*/, out /*const Tensor& o*/, grad_out/*const Tensor& dO*/, logsumexp.unsqueeze(-1)/*const Tensor& softmaxstats*/, @@ -219,7 +241,7 @@ std::tuple _scaled_dot_product_cudnn_attention_backward_ dv/*Tensor& dV*/, philox_seed/*Tensor& dropoutseed*/, philox_offset/*Tensor& dropoutoffset*/); - return std::make_tuple(dq, dk, dv); + return std::make_tuple(std::move(dq), std::move(dk), std::move(dv)); } std::tuple diff --git a/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp b/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp index 66178a8d8791bd..88bcf29291cd73 100644 --- a/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp +++ b/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp @@ -561,7 +561,7 @@ bool can_use_cudnn_attention(const sdp_params& params, bool debug) { check_cudnn_deterministic, // check_is_causal, check_dtypes_low_precision, - check_for_attn_mask_cudnn, + check_attn_mask_shape, check_cudnn_hardware_support ); for (auto& constraint : general_constraints) { diff --git a/aten/src/ATen/native/transformers/sdp_utils_cpp.h b/aten/src/ATen/native/transformers/sdp_utils_cpp.h index b22e75301e8772..272700392e1d71 100644 --- a/aten/src/ATen/native/transformers/sdp_utils_cpp.h +++ b/aten/src/ATen/native/transformers/sdp_utils_cpp.h @@ -275,17 +275,6 @@ inline bool check_for_attn_mask(sdp_params const& params, bool debug) { return true; } -// TODO(eqy): remove this once support is added -inline bool check_for_attn_mask_cudnn(sdp_params const& params, bool debug) { - if (params.attn_mask.has_value()) { - if (debug) { - TORCH_WARN("cuDNN Attention does not support non-null attn_mask."); - } - return false; - } - return true; -} - inline bool check_attn_mask_shape(sdp_params const& params, bool debug) { auto attn_mask = params.attn_mask; if (!attn_mask.has_value()) { diff --git a/test/test_transformers.py b/test/test_transformers.py index ae70b7dc09e9d8..30810997cfd47e 100644 --- a/test/test_transformers.py +++ b/test/test_transformers.py @@ -3319,10 +3319,10 @@ def get_dropout_mask(output, fused_kernel, batch_size, n_heads, q_len, kv_len, d (out_ref, out_lp_ref, out), *zip(grads_ref, grads_ref_lp, grads), fudge_factors={ - 'out': 2.0, + 'out': 3.0, 'grad_query': 100.0, 'grad_key': 8.0, - 'grad_value': 2.0, + 'grad_value': 3.0, } ) From 96f0093c9961a71895e1f2609393760a719914c5 Mon Sep 17 00:00:00 2001 From: Yiming Zhou Date: Wed, 11 Sep 2024 06:16:26 +0000 Subject: [PATCH 0708/1018] Back out "[pytorch][PR] [export] fix re-export custom metadata" (#135634) Summary: Broke some tests. Revert this diff Test Plan: CI Differential Revision: D62474337 Pull Request resolved: https://github.com/pytorch/pytorch/pull/135634 Approved by: https://github.com/tugsbayasgalan --- test/quantization/pt2e/test_numeric_debugger.py | 9 ++++++--- torch/_export/utils.py | 12 ++++++++++++ 2 files changed, 18 insertions(+), 3 deletions(-) diff --git a/test/quantization/pt2e/test_numeric_debugger.py b/test/quantization/pt2e/test_numeric_debugger.py index 85b5e9137bdb5c..e6cc4dc3252f44 100644 --- a/test/quantization/pt2e/test_numeric_debugger.py +++ b/test/quantization/pt2e/test_numeric_debugger.py @@ -20,7 +20,6 @@ get_symmetric_quantization_config, XNNPACKQuantizer, ) -from torch.export import export_for_training from torch.testing._internal.common_quantization import TestHelperModules from torch.testing._internal.common_utils import IS_WINDOWS, TestCase @@ -117,18 +116,22 @@ def test_deepcopy_preserve_handle(self): self.assertEqual(debug_handle_map, debug_handle_map_ref) + @unittest.skip("All nodes' meta are preserved but get_attr nodes' meta are wrong.") def test_re_export_preserve_handle(self): m = TestHelperModules.Conv2dThenConv1d() example_inputs = m.example_inputs() - m = export_for_training(m, example_inputs).module() + m = capture_pre_autograd_graph(m, example_inputs) generate_numeric_debug_handle(m) debug_handle_map_ref = _extract_debug_handles(m) - m_export = export_for_training(m, example_inputs).module() + m_export = capture_pre_autograd_graph(m, example_inputs) debug_handle_map = _extract_debug_handles(m_export) self.assertEqual(debug_handle_map, debug_handle_map_ref) + @unittest.skip( + "All nodes' meta are preserved but the first arg for the first node seems to be dropped" + ) def test_run_decompositions_preserve_handle(self): m = TestHelperModules.Conv2dThenConv1d() example_inputs = m.example_inputs() diff --git a/torch/_export/utils.py b/torch/_export/utils.py index 7755099dfdb22a..e085e18b68a207 100644 --- a/torch/_export/utils.py +++ b/torch/_export/utils.py @@ -130,6 +130,18 @@ def _getattr(model: torch.fx.GraphModule, attr_name: str): if not isinstance(submodule, torch.fx.GraphModule): params_buffers_to_node_meta[target] = meta + # If the call_function uses param as input, we also need to update params' meta + # with this call_function node's meta. + # This is basically the same flow as torch.fx.traceback.preserve_meta() + if node.op == "call_function" and not isinstance( + node.target, torch._ops.HigherOrderOperator + ): + for arg in node._input_nodes: + if arg.op == "get_attr": + for entry in torch.fx.proxy._COPY_META_FIELDS: + if entry in meta: + params_buffers_to_node_meta[arg.target][entry] = meta[entry] + return params_buffers_to_node_meta From 7cb9204c36bce791f83cfa26efabc75c2e84ef08 Mon Sep 17 00:00:00 2001 From: Ke Wen Date: Wed, 11 Sep 2024 07:53:40 +0000 Subject: [PATCH 0709/1018] [Distributed] Improve efficiency of NaN checker (#135414) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Some customers would like to run the NaN checks on the fly, so we are improving its efficiency. ## Benchmarking Allreduce 2G floats. `TORCH_NCCL_NAN_CHECK=1` Red kernel: ncclAllreduce Blue kernel: Nan check Screenshot 2024-09-06 at 10 00 05 PM ## Comparison with torch ops: Let's say a user manually check for NaNs with the following torch ops before all-reduce: ``` torch.any(torch.isnan(x)) ``` Screenshot 2024-09-06 at 10 14 53 PM So our perf is on-par with torch ops. ## Changes - Load from vidmem using "big packs" of 16 bytes - Bump `blockDim.x` from 256 to 512 - Separate loads and checks into two loops, each of 8 iterations - Unroll the loops - Templated functions for checking NaN in a "big pack" based on dtype Special thanks to @jbachan from NCCL! Pull Request resolved: https://github.com/pytorch/pytorch/pull/135414 Approved by: https://github.com/wconstab --- test/distributed/test_c10d_nccl.py | 17 ++- torch/csrc/distributed/c10d/NanCheck.cu | 139 +++++++++++++++++++++++- 2 files changed, 146 insertions(+), 10 deletions(-) diff --git a/test/distributed/test_c10d_nccl.py b/test/distributed/test_c10d_nccl.py index c12a2e733c1339..5dabf3adc2c552 100644 --- a/test/distributed/test_c10d_nccl.py +++ b/test/distributed/test_c10d_nccl.py @@ -355,12 +355,19 @@ def test_nan_assert(self, type): store = c10d.FileStore(self.file_name, self.world_size) pg = self._create_process_group_nccl(store, self.opts()) device = self.rank_to_GPU[self.rank][0] - size = (10, 10) - nan_tensor = torch.full(size, self.rank, dtype=type, device=device) + # Cover different buffer sizes + if type == torch.float64: + size = (1024,) # 1K elements + elif type == torch.float32: + size = (1024, 1024) # 1M elements + elif type == torch.float16: + size = (1024, 1024, 1024) # 1G elements + else: + size = (1,) # 1 element + nan_tensor = torch.zeros(*size, dtype=type, device=device) # randomly pick an nan element - i = random.randint(0, nan_tensor.size(0) - 1) - j = random.randint(0, nan_tensor.size(1) - 1) - nan_tensor[i, j] = float("nan") + index = tuple([random.randrange(size[i]) for i in range(len(size))]) + nan_tensor[index] = float("nan") with self.assertRaises(RuntimeError): pg.allreduce(nan_tensor) dist.destroy_process_group() diff --git a/torch/csrc/distributed/c10d/NanCheck.cu b/torch/csrc/distributed/c10d/NanCheck.cu index a8d22b9a44e21e..85abd733affc95 100644 --- a/torch/csrc/distributed/c10d/NanCheck.cu +++ b/torch/csrc/distributed/c10d/NanCheck.cu @@ -11,13 +11,142 @@ namespace c10d { // CUDA kernel to check if data has NAN, device side assert // is raised if NAN is found + +// Using ulong2 as a "byte pack", with 16 bytes, for efficient data load +typedef ulong2 BytePack; + +// AMD HIP doesn't define `__trap()`, using `assert` instead +#ifdef USE_ROCM +#define __trap() assert(0) +#endif + +//// Start of templated functions for checking NaNs inside a BytePack + +// (i) General implementation (aka fallback) +// We use a for loop to iterate over the elements in a BytePack. +// EltPerPack would be greater than 8 if falling in this case. + +template +struct CheckBytePack { + static __device__ __forceinline__ void check(BytePack* tmp) { + T* data = (T*)tmp; + #pragma unroll 8 + for (int i = 0; i < EltPerPack; i++) { + if (isnan(data[i])) __trap(); + } + } +}; + +// (ii) Template Specialization for 8-byte data types, e.g. double +// EltPerPack = 16 / 8 = 2 + +template +struct CheckBytePack { + static __device__ __forceinline__ void check(BytePack* tmp) { + T* data = (T*)tmp; + if (isnan(data[0]) || isnan(data[1])) __trap(); + } +}; + +// (iii) Template specialization for 4-byte data types, e.g. float32 +// EltPerPack = 16 / 4 = 4 + +template +struct CheckBytePack { + static __device__ __forceinline__ void check(BytePack* tmp) { + T* data = (T*)tmp; + if (isnan(data[0]) || isnan(data[1]) || isnan(data[2]) || isnan(data[3])) __trap(); + } +}; + +// (iv) Template specialization for 2-byte data types, e.g. float16, bfloat16, half. +// EltPerPack = 16 / 2 = 8 + +template +struct CheckBytePack { + static __device__ __forceinline__ void check(BytePack* tmp) { + T* data = (T*)tmp; + if (isnan(data[0]) || isnan(data[1]) || isnan(data[2]) || isnan(data[3]) || + isnan(data[4]) || isnan(data[5]) || isnan(data[6]) || isnan(data[7])) { + __trap(); + } + } +}; + +//// End of templated functions for checking NaNs inside a BytePack + + +// Fast-path check routine: +// each thread will load and check 8 BytePacks in this routine + +// Create a tmp buffer of size 8, also unroll for loop by 8 +#define UNROLL 8 + +template +__device__ __forceinline__ void checkChunk(BytePack* ptr) { + BytePack tmp[UNROLL]; + int nWorkers = blockDim.x * gridDim.x; + // First load values from global memory into tmp buffer + #pragma unroll 8 + for (int j = 0; j < UNROLL; j++) { + tmp[j] = ptr[nWorkers * j]; + } + // Then check each BytePack in the tmp buffer + #pragma unroll 8 + for (int j = 0; j < UNROLL; j++) { + CheckBytePack::check(tmp + j); + } + // Note: we separate the check from the load for efficient loading +} + +// Align address of `ptr` up, to the alignment of `T` +#define ALIGN_UP(ptr, T) (((uintptr_t)ptr + sizeof(T) - 1) / sizeof(T) * sizeof(T)) + +// This is the host-facing kernel + template __global__ void checkForNaN(T* data, size_t size) { - size_t tid = blockIdx.x * blockDim.x + threadIdx.x; - size_t stride = blockDim.x * gridDim.x; + constexpr int EltPerPack = sizeof(BytePack) / sizeof(T); + // Offset of current thread + size_t offset = blockIdx.x * blockDim.x + threadIdx.x; + + // Align input address up to BytePack in case it is not + T* ptrAlign = (T*)ALIGN_UP(data, BytePack); + // Pre-process the data before alignment + size_t preProcElts = min(ptrAlign - data, size); + // Read memory by T (slow). One iter is enough bc the number of threads would + // be bigger than `preProcElts` + if (offset < preProcElts) { + if (isnan(data[offset])) __trap(); + } + // We have processes this amount of data + size -= preProcElts; + + // Start BytePack processing + BytePack* ptr = (BytePack*)ptrAlign; + // Size of input data in unit of BytePack + size_t sizeInBP = size * sizeof(T) / sizeof(BytePack); + // Number of BytePacks processed in one fast-path iteration + size_t loopSize = blockDim.x * gridDim.x * UNROLL; + + // Fast path + // The condition below makes sure there is enough data to process (`loopSize`) + for (; offset + loopSize <= sizeInBP; offset += loopSize) { + checkChunk(ptr + offset); + } + + // The rest data goes on slow path + // We just do regular load and check + for (; offset < sizeInBP; offset += blockDim.x * gridDim.x) { + BytePack tmp = ptr[offset]; + CheckBytePack::check(&tmp); + } - for (size_t i = tid; i < size; i += stride) { - CUDA_KERNEL_ASSERT(!isnan(data[i])); + // We can still have a tail smaller than 1 BytePack + // TODO: merge this tail check with head check to make them concurrent + if (threadIdx.x < size % EltPerPack) { + T* tailPtr = (T*)(ptr + sizeInBP); + if (isnan(tailPtr[threadIdx.x])) __trap(); } } @@ -27,7 +156,7 @@ void checkForNan(const at::Tensor& tensor, at::cuda::CUDAStream& stream) { if (!torch::is_floating_point(tensor)) { return; } - const size_t maxNumThreadsPerBlock = 256; + const size_t maxNumThreadsPerBlock = 512; const size_t maxNumBlocks = 24; const size_t numThreadsPerBlock = std::min(maxNumThreadsPerBlock, tensor.numel()); From 6387f201f8824f9192a30f48a23c21b097c0047c Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Tue, 10 Sep 2024 21:57:09 -0700 Subject: [PATCH 0710/1018] [Inductor][FlexAttention] Supports dynamic shapes with block mask (#135629) Fixes #134560 and #135206 Pull Request resolved: https://github.com/pytorch/pytorch/pull/135629 Approved by: https://github.com/drisspg --- test/inductor/test_flex_attention.py | 27 +++++++++++++++----- torch/_inductor/kernel/flex_attention.py | 32 ++++++++++++++++-------- torch/_inductor/kernel/flex_decoding.py | 2 ++ 3 files changed, 45 insertions(+), 16 deletions(-) diff --git a/test/inductor/test_flex_attention.py b/test/inductor/test_flex_attention.py index cc821d1467c329..e0b975c2139e36 100644 --- a/test/inductor/test_flex_attention.py +++ b/test/inductor/test_flex_attention.py @@ -411,7 +411,11 @@ def run_dynamic_test( S: int = S, D: int = D, ): - sdpa_partial = create_attention(score_mod) + # If the seqlen becomes smaller than the seqlen of the previous batch, + # we can still reuse the block_mask created from a larger seqlen. + MAX_S = S + block_mask = create_block_mask(noop_mask, 1, 1, MAX_S, MAX_S) + sdpa_partial = create_attention(score_mod, block_mask=block_mask) # The first eager batch, shape (B, H, S, D) q1 = torch.randn((B, H, S, D), dtype=dtype, device="cuda", requires_grad=True) k1 = torch.randn((B, H, S, D), dtype=dtype, device="cuda", requires_grad=True) @@ -485,6 +489,17 @@ def run_dynamic_test( ) self.assertEqual(torch._dynamo.utils.counters["frames"]["ok"], 1) + # The third iteration, shape (B * 2, H, S * 2, D) + # Since seqlen is larger than the seqlen in block_mask, throw errors. + S = int(S * 4) + q3 = torch.randn((B, H, S, D), dtype=dtype, device="cuda", requires_grad=True) + k3 = torch.randn((B, H, S, D), dtype=dtype, device="cuda", requires_grad=True) + v3 = torch.randn((B, H, S, D), dtype=dtype, device="cuda", requires_grad=True) + with self.assertRaisesRegex( + torch._dynamo.exc.BackendCompilerFailed, "Q seqlen must be smaller than" + ): + compiled_sdpa(q3, k3, v3) + def run_automatic_dynamic_test( self, score_mod: Callable, @@ -494,7 +509,9 @@ def run_automatic_dynamic_test( S: int = S, D: int = D, ): - sdpa_partial = create_attention(score_mod) + MAX_S = S + block_mask = create_block_mask(noop_mask, 1, 1, MAX_S, MAX_S) + sdpa_partial = create_attention(score_mod, block_mask=block_mask) # The first eager batch, shape (B, H, S, D) q1 = torch.randn((B, H, S, D), dtype=dtype, device="cuda") k1 = torch.randn((B, H, S, D), dtype=dtype, device="cuda") @@ -594,16 +611,14 @@ def causal_mask(b, h, q, kv): ) self.run_test_with_call(attention, dtype, B, H, 64, D, B, H, 64, D) - @expectedFailure # TODO: supports block sparsity with dynamic shapes @supported_platform - @common_utils.parametrize("dtype", test_dtypes) + @common_utils.parametrize("dtype", test_dtypes_fast) @common_utils.parametrize("score_mod", test_score_mods) def test_builtin_score_mods_dynamic(self, dtype: torch.dtype, score_mod: Callable): self.run_dynamic_test(score_mod, dtype) - @expectedFailure # TODO: supports block sparsity with dynamic shapes @supported_platform - @common_utils.parametrize("dtype", test_dtypes) + @common_utils.parametrize("dtype", test_dtypes_fast) @common_utils.parametrize("score_mod", test_score_mods) def test_builtin_score_mods_automatic_dynamic( self, dtype: torch.dtype, score_mod: Callable diff --git a/torch/_inductor/kernel/flex_attention.py b/torch/_inductor/kernel/flex_attention.py index ec79b8097cb846..33113434e732e7 100644 --- a/torch/_inductor/kernel/flex_attention.py +++ b/torch/_inductor/kernel/flex_attention.py @@ -214,8 +214,9 @@ def get_offset_for_next_block(loop_iter, col_indices, total_blocks, SPARSE_BLOCK SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M) SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) - SPARSE_Q_BLOCK_CNT: tl.constexpr = tl.cdiv(Q_LEN, SPARSE_Q_BLOCK_SIZE) - SPARSE_KV_BLOCK_CNT: tl.constexpr = tl.cdiv(KV_LEN, SPARSE_KV_BLOCK_SIZE) + stride_kv_num_blks_h = {{stride("KV_NUM_BLKS", 1)}} + stride_kv_idx_h = {{stride("KV_IDX", 1)}} + stride_kv_idx_m = {{stride("KV_IDX", 2)}} # initialize pointer to m and l m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") @@ -226,8 +227,8 @@ def get_offset_for_next_block(loop_iter, col_indices, total_blocks, SPARSE_BLOCK # KV_IDX and KV_NUM_BLKS are always contiguous. sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq - sparse_kv_num_blks_offset = sparse_hz_offset * SPARSE_Q_BLOCK_CNT + q_start // SPARSE_Q_MULTIPLE - sparse_kv_idx_offset = sparse_hz_offset * SPARSE_Q_BLOCK_CNT * SPARSE_KV_BLOCK_CNT + (q_start // SPARSE_Q_MULTIPLE) * SPARSE_KV_BLOCK_CNT # noqa: B950 + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + q_start // SPARSE_Q_MULTIPLE + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + (q_start // SPARSE_Q_MULTIPLE) * stride_kv_idx_m # noqa: B950 Q_block_ptr = tl.make_block_ptr( base=Q, @@ -746,10 +747,9 @@ def flex_attention( Bq, Hq, seq_len_q, qk_head_dim = query.get_size() Bkv, Hkv, seq_len_kv, v_head_dim = value.get_size() - - if not ((Bq == Bkv) or (Bq > 1 and Bkv == 1)): - raise RuntimeError(f"Bq and Bkv must broadcast. Got Bq={Bq} and Bkv={Bkv}") - + assert V.graph.sizevars.evaluate_expr( + sympy.Eq(Bq, Bkv) | sympy.Eq(Bkv, 1) + ), f"Bq and Bkv must broadcastable. Got Bq={Bq} and Bkv={Bkv}" B = Bq if seq_len_q % 128 != 0 or seq_len_kv % 128 != 0: @@ -804,6 +804,16 @@ def flex_attention( (64, 64, 4, 3), ] + # Mark SPARSE_KV_BLOCK_SIZE & SPARSE_Q_BLOCK_SIZE as static shapes and add guards. + SPARSE_KV_BLOCK_SIZE = V.graph.sizevars.evaluate_static_shape(SPARSE_KV_BLOCK_SIZE) + SPARSE_Q_BLOCK_SIZE = V.graph.sizevars.evaluate_static_shape(SPARSE_Q_BLOCK_SIZE) + assert V.graph.sizevars.evaluate_expr( + sympy.Le(seq_len_q, sympy.Mul(kv_indices.get_size()[-2], SPARSE_Q_BLOCK_SIZE)) + ), "Q seqlen must be smaller than the block_mask size in the Q dimension, considering pass a larger block_mask." + assert V.graph.sizevars.evaluate_expr( + sympy.Le(seq_len_kv, sympy.Mul(kv_indices.get_size()[-1], SPARSE_KV_BLOCK_SIZE)) + ), "KV seqlen must be smaller than the block_mask size in the KV dimension, considering pass a larger block_mask." + # Note, we don't need to pass in the captured buffers explicitly # because they're implicitly added by the score_mod function # We do need to explicitly pass it in for autotuning though. @@ -1634,8 +1644,10 @@ def flex_attention_backward(*args, **kwargs): Bq, Hq, seq_len_q, qk_head_dim = query.get_size() Bkv, Hkv, seq_len_kv, v_head_dim = value.get_size() - if not ((Bq == Bkv) or (Bq > 1 and Bkv == 1)): - raise RuntimeError(f"Bq and Bkv must broadcast. Got Bq={Bq} and Bkv={Bkv}") + assert V.graph.sizevars.evaluate_expr( + sympy.Eq(Bq, Bkv) | sympy.Eq(Bkv, 1) + ), f"Bq and Bkv must broadcastable. Got Bq={Bq} and Bkv={Bkv}" + B = Bq kernel_options = dict(kernel_options) kernel_options.setdefault("FLOAT32_PRECISION", get_float32_precision()) diff --git a/torch/_inductor/kernel/flex_decoding.py b/torch/_inductor/kernel/flex_decoding.py index ef57f208a93665..c758a3bcfbc961 100644 --- a/torch/_inductor/kernel/flex_decoding.py +++ b/torch/_inductor/kernel/flex_decoding.py @@ -474,6 +474,8 @@ def create_flex_decoding_kernel(*args, **kwargs): ) # TODO: This feels sketchy kernel_options.setdefault("SAFE_N_BOUNDARY", True) + # Mark SPARSE_KV_BLOCK_SIZE as static shapes and add guards. + SPARSE_KV_BLOCK_SIZE = V.graph.sizevars.evaluate_static_shape(SPARSE_KV_BLOCK_SIZE) # Note, we don't need to pass in the captured buffers explicitly # because they're implicitly added by the score_mod function From 61351a9ba989914c4875eb0e3258278971248087 Mon Sep 17 00:00:00 2001 From: Amadeusz Skrzypczak Date: Wed, 11 Sep 2024 11:05:38 +0000 Subject: [PATCH 0711/1018] Disable cuda specific restrictions in _scaled_mm for other devices (#135579) Fixes #135576 Pull Request resolved: https://github.com/pytorch/pytorch/pull/135579 Approved by: https://github.com/drisspg --- torch/_meta_registrations.py | 115 ++++++++++++++++++----------------- 1 file changed, 59 insertions(+), 56 deletions(-) diff --git a/torch/_meta_registrations.py b/torch/_meta_registrations.py index 61e01fe3b1d11e..d4d2a52f2888c4 100644 --- a/torch/_meta_registrations.py +++ b/torch/_meta_registrations.py @@ -5600,12 +5600,6 @@ def meta_scaled_mm( out_dtype: Optional[torch.dtype] = None, use_fast_accum: bool = False, ): - def is_row_major(stride): - return stride[0] > stride[1] and stride[1] == 1 - - def is_col_major(stride): - return stride[0] == 1 and stride[1] > 1 - def is_fp8_type(dtype): return dtype in ( torch.float8_e4m3fn, @@ -5618,68 +5612,77 @@ def is_fp8_type(dtype): self.dim() == 2 and mat2.dim() == 2, lambda: f"Inputs must be 2D but got self.dim()={self.dim()} and mat2.dim()={mat2.dim()}", ) - torch._check( - is_row_major(self.stride()), - lambda: "self must be row_major", - ) - torch._check( - is_col_major(mat2.stride()), - lambda: "mat2 must be col_major", - ) - torch._check( - self.size(1) % 16 == 0, - lambda: f"Expected self.size(1) to be divisible by 16, but got self.size(1)={self.size(1)}", - ) - torch._check( - mat2.size(0) % 16 == 0 and mat2.size(1) % 16 == 0, - lambda: f"Expected both dimensions of mat2 to be divisble by 16 but got {mat2.shape}", - ) torch._check( is_fp8_type(self.dtype) and is_fp8_type(mat2.dtype), lambda: f"Expected both inputs to be fp8 types but got self.dtype={self.dtype} and mat2.dtype={mat2.dtype}", ) - # determine scaling type and check input dimensions (refer to Blas.cpp op) - torch._check( - scale_a.dtype == torch.float32 and scale_b.dtype == torch.float32, - lambda: "Both scale_a and scale_b must be float (fp32) tensors.", - ) - m, k = self.shape - n = mat2.size(1) - if scale_a.numel() == 1 and scale_b.numel() == 1: - # tensorwise scaling - pass - else: - # for non-tensorwise scaling, enforce 2D input tensors + if device_hint(self) == "cuda": + + def is_row_major(stride): + return stride[0] > stride[1] and stride[1] == 1 + + def is_col_major(stride): + return stride[0] == 1 and stride[1] > 1 + + torch._check( + is_row_major(self.stride()), + lambda: "self must be row_major", + ) + torch._check( + is_col_major(mat2.stride()), + lambda: "mat2 must be col_major", + ) + torch._check( + self.size(1) % 16 == 0, + lambda: f"Expected self.size(1) to be divisible by 16, but got self.size(1)={self.size(1)}", + ) torch._check( - scale_a.dim() == 2 and scale_b.dim() == 2, - lambda: f"For non-tensorwise scaling, scale tensors must be 2D, but got {scale_a.dim()=} and {scale_b.dim()=}", + mat2.size(0) % 16 == 0 and mat2.size(1) % 16 == 0, + lambda: f"Expected both dimensions of mat2 to be divisble by 16 but got {mat2.shape}", ) - if ( - scale_a.size(0) == m - and scale_a.size(1) == 1 - and scale_b.size(0) == 1 - and scale_b.size(1) == n - ): - # rowwise scaling - torch._check( - scale_a.is_contiguous() and scale_b.is_contiguous(), - lambda: "Both scale_a and scale_b must be contiguous for rowwise scaling.", - ) + # determine scaling type and check input dimensions (refer to Blas.cpp op) + torch._check( + scale_a.dtype == torch.float32 and scale_b.dtype == torch.float32, + lambda: "Both scale_a and scale_b must be float (fp32) tensors.", + ) + m, k = self.shape + n = mat2.size(1) + if scale_a.numel() == 1 and scale_b.numel() == 1: + # tensorwise scaling + pass else: - # does not match any valid scaling type + # for non-tensorwise scaling, enforce 2D input tensors torch._check( - False, - lambda: ( - "Invalid scaling configuration. " - "For tensorwise scaling, both scales should be scalar. " - f"For rowwise scaling, scale_a should be ({m}, 1), scale_b should be (1, {n}). " - f"Got scale_a.size()=({scale_a.size(0)}, {scale_a.size(1)}) " - f"and scale_b.size()=({scale_b.size(0)}, {scale_b.size(1)})" - ), + scale_a.dim() == 2 and scale_b.dim() == 2, + lambda: f"For non-tensorwise scaling, scale tensors must be 2D, but got {scale_a.dim()=} and {scale_b.dim()=}", ) + if ( + scale_a.size(0) == m + and scale_a.size(1) == 1 + and scale_b.size(0) == 1 + and scale_b.size(1) == n + ): + # rowwise scaling + torch._check( + scale_a.is_contiguous() and scale_b.is_contiguous(), + lambda: "Both scale_a and scale_b must be contiguous for rowwise scaling.", + ) + else: + # does not match any valid scaling type + torch._check( + False, + lambda: ( + "Invalid scaling configuration. " + "For tensorwise scaling, both scales should be scalar. " + f"For rowwise scaling, scale_a should be ({m}, 1), scale_b should be (1, {n}). " + f"Got scale_a.size()=({scale_a.size(0)}, {scale_a.size(1)}) " + f"and scale_b.size()=({scale_b.size(0)}, {scale_b.size(1)})" + ), + ) + _out_dtype = out_dtype if out_dtype is not None else self.dtype return torch.empty(self.size(0), mat2.size(1), dtype=_out_dtype, device=self.device) From 3ce994c4328741c2871fe728dcf3470d73a3f7d9 Mon Sep 17 00:00:00 2001 From: Nikita Lutsenko Date: Wed, 11 Sep 2024 11:12:23 +0000 Subject: [PATCH 0712/1018] ATen | Fix MPSCNNNeuron creation on Mac Catalyst. (#135595) Summary: These are still utilized directly when using relu/sigmoid/tanh tensors directly from here: https://fburl.com/code/k6n7ofzd However, on Mac Catalyst we always were returning `nil`, as such in most cases yielding the entire graph completely useless and most often just stray `MPSTemporaryImage` references that were never written into. This fixes the issue completely by making sure that we always return the valid kernels back, so they can be executed. Test Plan: Test with segmentation net that uses a combination of relu and other tensors together - run this via Mac Catalyst build - it works! {F1858576745} Reviewed By: MichaelTay Differential Revision: D62430010 Pull Request resolved: https://github.com/pytorch/pytorch/pull/135595 Approved by: https://github.com/MichaelTay --- .../ATen/native/metal/mpscnn/MPSCNNNeuronOp.h | 8 +- .../native/metal/mpscnn/MPSCNNNeuronOp.mm | 103 +++++++++--------- 2 files changed, 54 insertions(+), 57 deletions(-) diff --git a/aten/src/ATen/native/metal/mpscnn/MPSCNNNeuronOp.h b/aten/src/ATen/native/metal/mpscnn/MPSCNNNeuronOp.h index e1a9b2617bd3e2..639e4b6b746b7c 100644 --- a/aten/src/ATen/native/metal/mpscnn/MPSCNNNeuronOp.h +++ b/aten/src/ATen/native/metal/mpscnn/MPSCNNNeuronOp.h @@ -2,10 +2,10 @@ @interface MPSCNNNeuronOp : NSObject -+ (MPSCNNNeuronHardSigmoid*)hardSigmoid API_AVAILABLE(ios(11.0), macos(10.13)); -+ (MPSCNNNeuronReLU*)relu; -+ (MPSCNNNeuronSigmoid*)sigmoid; -+ (MPSCNNNeuronTanH*)tanh; ++ (MPSCNNNeuron*)hardSigmoid API_AVAILABLE(ios(11.0), macos(10.13)); ++ (MPSCNNNeuron*)relu; ++ (MPSCNNNeuron*)sigmoid; ++ (MPSCNNNeuron*)tanh; @end diff --git a/aten/src/ATen/native/metal/mpscnn/MPSCNNNeuronOp.mm b/aten/src/ATen/native/metal/mpscnn/MPSCNNNeuronOp.mm index e722f2765c0893..62eb51661a9a62 100644 --- a/aten/src/ATen/native/metal/mpscnn/MPSCNNNeuronOp.mm +++ b/aten/src/ATen/native/metal/mpscnn/MPSCNNNeuronOp.mm @@ -8,69 +8,67 @@ @implementation MPSCNNNeuronOp -+ (MPSCNNNeuronHardSigmoid*)hardSigmoid API_AVAILABLE(ios(11.0), macos(10.13)) { -// Remove this once we support iOS 11.3 -#if TARGET_OS_MACCATALYST - return nil; -#else ++ (MPSCNNNeuron*)hardSigmoid API_AVAILABLE(ios(11.0), macos(10.13)) { + static MPSCNNNeuron* neuron = nil; static dispatch_once_t onceToken; - static MPSCNNNeuronHardSigmoid* neuron = nil; dispatch_once(&onceToken, ^{ +#if TARGET_OS_MACCATALYST + neuron = [[MPSCNNNeuron alloc] initWithDevice:[MetalContext sharedInstance].device neuronDescriptor:[MPSCNNNeuronOpDescriptor hardSigmoidDescriptor]]; +#else neuron = [[MPSCNNNeuronHardSigmoid alloc] - initWithDevice:[MetalContext sharedInstance].device - a:1.0 / 6.0 - b:0.5]; + initWithDevice:[MetalContext sharedInstance].device + a:1.0 / 6.0 + b:0.5]; +#endif }); return neuron; -#endif } -+ (MPSCNNNeuronReLU*)relu { -// Remove this once we support iOS 11.3 -#if TARGET_OS_MACCATALYST - return nil; -#else - static MPSCNNNeuronReLU* relu = nil; ++ (MPSCNNNeuron*)relu { + static MPSCNNNeuron* neuron = nil; static dispatch_once_t onceToken; dispatch_once(&onceToken, ^{ - relu = [[MPSCNNNeuronReLU alloc] - initWithDevice:[MetalContext sharedInstance].device - a:0]; - }); - return relu; +#if TARGET_OS_MACCATALYST + neuron = [[MPSCNNNeuron alloc] + initWithDevice:[MetalContext sharedInstance].device + neuronDescriptor:[MPSCNNNeuronOpDescriptor reluDescriptor]]; +#else + neuron = [[MPSCNNNeuronReLU alloc] + initWithDevice:[MetalContext sharedInstance].device + a:0]; #endif + }); + return neuron; } -+ (MPSCNNNeuronSigmoid*)sigmoid { -// Remove this once we support iOS 11.3 -#if TARGET_OS_MACCATALYST - return nil; -#else ++ (MPSCNNNeuron*)sigmoid { + static MPSCNNNeuron* neuron = nil; static dispatch_once_t onceToken; - static MPSCNNNeuronSigmoid* sigmoid = nil; dispatch_once(&onceToken, ^{ - sigmoid = [[MPSCNNNeuronSigmoid alloc] - initWithDevice:[MetalContext sharedInstance].device]; - }); - return sigmoid; +#if TARGET_OS_MACCATALYST + neuron = [[MPSCNNNeuron alloc] initWithDevice:[MetalContext sharedInstance].device neuronDescriptor:[MPSCNNNeuronOpDescriptor sigmoidDescriptor]]; +#else + neuron = [[MPSCNNNeuronSigmoid alloc] + initWithDevice:[MetalContext sharedInstance].device]; #endif + }); + return neuron; } -+ (MPSCNNNeuronTanH*)tanh { -// Remove this once we support iOS 11.3 -#if TARGET_OS_MACCATALYST - return nil; -#else ++ (MPSCNNNeuron*)tanh { + static MPSCNNNeuron* neuron = nil; static dispatch_once_t onceToken; - static MPSCNNNeuronTanH* tanh = nil; dispatch_once(&onceToken, ^{ - tanh = [[MPSCNNNeuronTanH alloc] - initWithDevice:[MetalContext sharedInstance].device - a:1 - b:1]; - }); - return tanh; +#if TARGET_OS_MACCATALYST + neuron = [[MPSCNNNeuron alloc] initWithDevice:[MetalContext sharedInstance].device neuronDescriptor:[MPSCNNNeuronOpDescriptor tanhDescriptor]]; +#else + neuron = [[MPSCNNNeuronTanH alloc] + initWithDevice:[MetalContext sharedInstance].device + a:1 + b:1]; #endif + }); + return neuron; } @end @@ -85,9 +83,9 @@ + (MPSNNNeuronDescriptor*)hardSigmoidDescriptor { static MPSNNNeuronDescriptor* neuronDesc = nil; dispatch_once(&onceToken, ^{ neuronDesc = [MPSNNNeuronDescriptor - cnnNeuronDescriptorWithType:MPSCNNNeuronTypeHardSigmoid - a:1.0 / 6.0 - b:0.5]; + cnnNeuronDescriptorWithType:MPSCNNNeuronTypeHardSigmoid + a:1.0 / 6.0 + b:0.5]; }); return neuronDesc; } @@ -97,8 +95,8 @@ + (MPSNNNeuronDescriptor*)reluDescriptor { static MPSNNNeuronDescriptor* neuronDesc = nil; dispatch_once(&onceToken, ^{ neuronDesc = - [MPSNNNeuronDescriptor cnnNeuronDescriptorWithType:MPSCNNNeuronTypeReLU - a:0]; + [MPSNNNeuronDescriptor cnnNeuronDescriptorWithType:MPSCNNNeuronTypeReLU + a:0]; }); return neuronDesc; } @@ -108,7 +106,7 @@ + (MPSNNNeuronDescriptor*)sigmoidDescriptor { static MPSNNNeuronDescriptor* neuronDesc = nil; dispatch_once(&onceToken, ^{ neuronDesc = [MPSNNNeuronDescriptor - cnnNeuronDescriptorWithType:MPSCNNNeuronTypeSigmoid]; + cnnNeuronDescriptorWithType:MPSCNNNeuronTypeSigmoid]; }); return neuronDesc; } @@ -117,10 +115,9 @@ + (MPSNNNeuronDescriptor*)tanhDescriptor { static dispatch_once_t onceToken; static MPSNNNeuronDescriptor* neuronDesc = nil; dispatch_once(&onceToken, ^{ - neuronDesc = - [MPSNNNeuronDescriptor cnnNeuronDescriptorWithType:MPSCNNNeuronTypeTanH - a:1.0 - b:1.0]; + neuronDesc = [MPSNNNeuronDescriptor cnnNeuronDescriptorWithType:MPSCNNNeuronTypeTanH + a:1.0 + b:1.0]; }); return neuronDesc; } From b0cb7e9ecb5b72b7265cb852b32cbd8446e131bc Mon Sep 17 00:00:00 2001 From: Pushpak Raj Gautam Date: Wed, 11 Sep 2024 14:08:40 +0000 Subject: [PATCH 0713/1018] For S444023: Back out "deprecate `search_autotune_cache` (#133628)" (#135186) Summary: For S444023 Test Plan: Revert prevented the NaN errors - f639391901 Training job ran for 7767 iterations. NaN errors show up within the first 1k. Reviewed By: nmacchioni Differential Revision: D62224747 Pull Request resolved: https://github.com/pytorch/pytorch/pull/135186 Approved by: https://github.com/kit1980 --- test/inductor/test_cpu_cpp_wrapper.py | 2 +- test/inductor/test_cuda_cpp_wrapper.py | 2 +- test/inductor/test_max_autotune.py | 16 ++++++++++++++++ torch/_inductor/config.py | 4 ++-- torch/_inductor/select_algorithm.py | 5 +++++ torch/_inductor/utils.py | 4 +++- 6 files changed, 28 insertions(+), 5 deletions(-) diff --git a/test/inductor/test_cpu_cpp_wrapper.py b/test/inductor/test_cpu_cpp_wrapper.py index d02ca75cfa5c44..8063d0545ac032 100644 --- a/test/inductor/test_cpu_cpp_wrapper.py +++ b/test/inductor/test_cpu_cpp_wrapper.py @@ -130,7 +130,7 @@ def make_test_case( assert callable(func), "not a callable" func = slowTest(func) if slow else func - @config.patch(cpp_wrapper=True) + @config.patch(cpp_wrapper=True, search_autotune_cache=False) def fn(self): tests.setUpClass() tests.setUp() diff --git a/test/inductor/test_cuda_cpp_wrapper.py b/test/inductor/test_cuda_cpp_wrapper.py index 136483d27f9ece..9b8b802065590d 100644 --- a/test/inductor/test_cuda_cpp_wrapper.py +++ b/test/inductor/test_cuda_cpp_wrapper.py @@ -100,7 +100,7 @@ def make_test_case( assert callable(func), "not a callable" func = slowTest(func) if slow else func - @config.patch(cpp_wrapper=True) + @config.patch(cpp_wrapper=True, search_autotune_cache=False) def fn(self): tests.setUpClass() tests.setUp() diff --git a/test/inductor/test_max_autotune.py b/test/inductor/test_max_autotune.py index 5a74bb3048490c..8f5eacc0c14ae6 100644 --- a/test/inductor/test_max_autotune.py +++ b/test/inductor/test_max_autotune.py @@ -377,6 +377,22 @@ def fn(a, b, c): fn_c = torch.compile(mode="max-autotune-no-cudagraphs")(fn) self.assertEqual(counters["inductor"]["select_algorithm_precompile"], 0) + @skipIfRocm + @fresh_inductor_cache() + @config.patch(search_autotune_cache=True) + def test_search_autotune_cache(self): + def fn(a, b, c): + a = (a @ b) @ c + a, b, c = (t.to(torch.float16) for t in [a, b, c]) + return (a @ b) @ c + + fn_c = torch.compile()(fn) + inputs = [torch.rand([256, 256], device="cuda") for _ in range(3)] + from torch._dynamo.utils import counters + + self.assertEqual(fn(*inputs), fn_c(*inputs), atol=1e-2, rtol=1e-2) + self.assertEqual(counters["inductor"]["select_algorithm_precompile"], 0) + @skipIfRocm @fresh_inductor_cache() @config.patch(max_autotune=True, max_fusion_size=2) diff --git a/torch/_inductor/config.py b/torch/_inductor/config.py index 32144d4ef97560..4dc7f9b7964fcf 100644 --- a/torch/_inductor/config.py +++ b/torch/_inductor/config.py @@ -330,8 +330,8 @@ def autotune_remote_cache_default() -> Optional[bool]: # that can appear in the input shapes (e.g., in autotuning) unbacked_symint_fallback = 8192 -# DEPRECATED, DO NOT USE -search_autotune_cache = False +# enable searching global and local cache regardless of `max_autotune` +search_autotune_cache = os.environ.get("TORCHINDUCTOR_SEARCH_AUTOTUNE_CACHE") == "1" save_args = os.environ.get("TORCHINDUCTOR_SAVE_ARGS") == "1" diff --git a/torch/_inductor/select_algorithm.py b/torch/_inductor/select_algorithm.py index 67d9651a1e39fa..7b90690cf50b19 100644 --- a/torch/_inductor/select_algorithm.py +++ b/torch/_inductor/select_algorithm.py @@ -1258,6 +1258,11 @@ def no_op(*args, **kwargs): if timings: return no_op + if config.search_autotune_cache and not ( + config.max_autotune or config.max_autotune_gemm + ): + return no_op + precompile_key = create_precompile_key(name, inputs_key, choices) if precompile_func := self.precompile_cache.get(precompile_key): return precompile_func diff --git a/torch/_inductor/utils.py b/torch/_inductor/utils.py index b17f3a68559a60..28b8549408f3e4 100644 --- a/torch/_inductor/utils.py +++ b/torch/_inductor/utils.py @@ -1054,7 +1054,9 @@ def is_big_gpu(index) -> bool: def use_max_autotune() -> bool: - return config.max_autotune or config.max_autotune_gemm + return ( + config.max_autotune or config.max_autotune_gemm or config.search_autotune_cache + ) def _use_template_for_cuda(layout, allowed_layout_dtypes: List[torch.dtype]) -> bool: From e96e203084f328d6d44f938d88e3a03df133d3c0 Mon Sep 17 00:00:00 2001 From: Sam Larsen Date: Tue, 10 Sep 2024 09:33:06 -0700 Subject: [PATCH 0714/1018] Fix test_triton_kernel_float64_constant (#135583) Summary: Landed https://github.com/pytorch/pytorch/pull/135260 too soon and the test in that PR doesn't do exactly what I tested (actually test different dtypes). Test Plan: `python test/inductor/test_triton_kernels.py -k float64_constant` Pull Request resolved: https://github.com/pytorch/pytorch/pull/135583 Approved by: https://github.com/isuruf, https://github.com/eellison, https://github.com/Skylion007 --- test/inductor/test_triton_kernels.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/inductor/test_triton_kernels.py b/test/inductor/test_triton_kernels.py index 365f43cdd83bee..9d426926d7c019 100644 --- a/test/inductor/test_triton_kernels.py +++ b/test/inductor/test_triton_kernels.py @@ -1616,7 +1616,7 @@ def test_triton_kernel_float64_constant(self, dtype): def f(x): return x * (0.12 * x.shape[0]) - x = torch.ones(200, device=GPU_TYPE, dtype=torch.float64) + x = torch.ones(200, device=GPU_TYPE, dtype=dtype) eager_out = f(x) compiled_out = torch.compile(f, dynamic=True)(x) From 159ebbf775ffb69ac0b67f5af6e4133de24d7229 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Wed, 11 Sep 2024 15:39:21 +0000 Subject: [PATCH 0715/1018] Revert "[Dynamo] Use custom backend to reenter metadata tf mode when tracing while/cond (#134732)" This reverts commit 6a3edfcc1e474e6ebd0c06624000a6d6bf1a0dee. Reverted https://github.com/pytorch/pytorch/pull/134732 on behalf of https://github.com/clee2000 due to broke functorch/test_control_flow.py::TestControlFlow::test_scan_simple_graph [GH job link](https://github.com/pytorch/pytorch/actions/runs/10804912306/job/29980571390) [HUD commit link](https://hud.pytorch.org/pytorch/pytorch/commit/444b52ff40cf4afce7bc3fdcf021a88eab3b954c), newly added test yesterday ([comment](https://github.com/pytorch/pytorch/pull/134732#issuecomment-2344016694)) --- torch/_dynamo/backends/debugging.py | 12 ----- torch/_higher_order_ops/cond.py | 20 +++----- torch/_higher_order_ops/while_loop.py | 21 ++------- torch/fx/experimental/proxy_tensor.py | 67 +++++++++++---------------- torch/nn/attention/flex_attention.py | 42 ++++++----------- 5 files changed, 51 insertions(+), 111 deletions(-) diff --git a/torch/_dynamo/backends/debugging.py b/torch/_dynamo/backends/debugging.py index 18784498862e39..87214f7d59774e 100644 --- a/torch/_dynamo/backends/debugging.py +++ b/torch/_dynamo/backends/debugging.py @@ -31,18 +31,6 @@ def eager(gm, fake_tensor_inputs, **kwargs): return gm.forward -def make_eager_backend_with_torch_function_mode(mode): - """Used to trace HOPs (cond and while) for eager exectution, the metadata - TF mode mutates vars outside of the scope of the HOP, and we can't have graph breaks - in the HOP, so we need to externally run this mode and not trace it.""" - - def fn(gm, fake_tensor_inputs, **kwargs): - with mode: - return gm.forward - - return fn - - @register_backend def eager_noexcept(gm, fake_tensor_inputs, **kwargs): if kwargs: diff --git a/torch/_higher_order_ops/cond.py b/torch/_higher_order_ops/cond.py index 0467e2899adc28..dee400d76f5964 100644 --- a/torch/_higher_order_ops/cond.py +++ b/torch/_higher_order_ops/cond.py @@ -28,7 +28,6 @@ from torch._subclasses.fake_tensor import FakeTensorMode from torch._subclasses.functional_tensor import disable_functional_mode from torch.fx.experimental.proxy_tensor import ( - _temp_remove_metadata_torch_function_mode, _temp_remove_pre_dispatch_torch_function_mode, disable_proxy_modes_tracing, ProxyTorchDispatchMode, @@ -130,10 +129,6 @@ def false_fn(x: torch.Tensor): if torch.compiler.is_dynamo_compiling(): return cond_op(pred, true_fn, false_fn, operands) - from torch._dynamo.backends.debugging import ( - make_eager_backend_with_torch_function_mode, - ) - if isinstance(pred, (bool, int, float)): log.warning( "Pred is a Python constant. When used with torch.cond, it executes only one of the branches." @@ -174,15 +169,12 @@ def _validate_input(pred, true_fn, false_fn, operands): def _cond_op_wrapper(*args, **kwargs): return cond_op(*args, **kwargs) - with _set_compilation_env(), torch._dynamo.utils.disable_cache_limit(), _temp_remove_pre_dispatch_torch_function_mode(): - with _temp_remove_metadata_torch_function_mode() as metadata_mode: - if metadata_mode: - backend = make_eager_backend_with_torch_function_mode(metadata_mode) - else: - backend = "eager" - return torch.compile(_cond_op_wrapper, backend=backend, fullgraph=True)( - pred, true_fn, false_fn, operands - ) + with _set_compilation_env(): + with torch._dynamo.utils.disable_cache_limit(): + with _temp_remove_pre_dispatch_torch_function_mode(): + return torch.compile(_cond_op_wrapper, backend="eager", fullgraph=True)( + pred, true_fn, false_fn, operands + ) def create_fw_bw_graph_branches(true_fn, false_fn, *operands): diff --git a/torch/_higher_order_ops/while_loop.py b/torch/_higher_order_ops/while_loop.py index f14321842f40b9..d64f512399d9d5 100644 --- a/torch/_higher_order_ops/while_loop.py +++ b/torch/_higher_order_ops/while_loop.py @@ -15,11 +15,7 @@ ) from torch._ops import HigherOrderOperator from torch._subclasses.fake_tensor import FakeTensorMode -from torch.fx.experimental.proxy_tensor import ( - _temp_remove_metadata_torch_function_mode, - ProxyTorchDispatchMode, - track_tensor_tree, -) +from torch.fx.experimental.proxy_tensor import ProxyTorchDispatchMode, track_tensor_tree class WhileLoopOp(HigherOrderOperator): @@ -117,9 +113,6 @@ def body_fn(iter, x): - 'while_loop' only supports **inference** right now. Autograd will be supported in the future. """ - from torch._dynamo.backends.debugging import ( - make_eager_backend_with_torch_function_mode, - ) # Currently, additional_inputs is not a user-facing input. It will be automatically set in dynamo. # parameters and buffers accessed in cond_fn or body_fn or tensor closures will become additional_inputs. @@ -147,15 +140,9 @@ def _while_loop_op_wrapper(*args, **kwargs): return while_loop_op(*args, **kwargs) with _set_compilation_env(), torch._dynamo.utils.disable_cache_limit(): - with _temp_remove_metadata_torch_function_mode() as metadata_mode: - with _temp_remove_metadata_torch_function_mode() as metadata_mode: - if metadata_mode: - backend = make_eager_backend_with_torch_function_mode(metadata_mode) - else: - backend = "eager" - return torch.compile( - _while_loop_op_wrapper, backend=backend, fullgraph=True - )(cond_fn, body_fn, carried_inputs, additional_inputs) + return torch.compile(_while_loop_op_wrapper, backend="eager", fullgraph=True)( + cond_fn, body_fn, carried_inputs, additional_inputs + ) @while_loop_op.py_impl(DispatchKey.CompositeExplicitAutograd) diff --git a/torch/fx/experimental/proxy_tensor.py b/torch/fx/experimental/proxy_tensor.py index 213672e216e6ce..5a5771d2ae8f7a 100644 --- a/torch/fx/experimental/proxy_tensor.py +++ b/torch/fx/experimental/proxy_tensor.py @@ -17,7 +17,7 @@ import warnings import weakref from collections import defaultdict -from contextlib import _GeneratorContextManager, contextmanager, ExitStack, nullcontext +from contextlib import contextmanager, ExitStack, nullcontext from dataclasses import dataclass from typing import ( Any, @@ -1084,43 +1084,38 @@ def unwrap_proxy(self, e: T) -> object: return e -def _make_temp_remove_mode_context_manager( - mode_ty: Type[TorchFunctionMode], -) -> Callable[[], _GeneratorContextManager[Optional[TorchFunctionMode]]]: - @contextmanager - def context_manager_fn() -> Generator[Optional[TorchFunctionMode], None, None]: - from torch.overrides import _len_torch_function_stack, _pop_mode, _push_mode - - temp_elements = [] - removed_mode = None +@contextmanager +def _temp_remove_pre_dispatch_torch_function_mode() -> Generator[None, None, None]: + from torch.overrides import _len_torch_function_stack, _pop_mode, _push_mode - while _len_torch_function_stack() > 0: - mode = _pop_mode() - if isinstance(mode, mode_ty): - removed_mode = mode - break - else: - temp_elements.append(mode) + temp_elements = [] + pre_dispatch_mode = None - for mode in reversed(temp_elements): - _push_mode(mode) + while _len_torch_function_stack() > 0: + mode = _pop_mode() + if isinstance(mode, PreDispatchTorchFunctionMode): + pre_dispatch_mode = mode + break + else: + temp_elements.append(mode) - try: - yield removed_mode + for mode in reversed(temp_elements): + _push_mode(mode) - finally: - if removed_mode is not None: - count = len(temp_elements) - while count > 0: - mode = _pop_mode() - count -= 1 + try: + yield - temp_elements.append(removed_mode) + finally: + if pre_dispatch_mode is not None: + count = len(temp_elements) + while count > 0: + mode = _pop_mode() + count -= 1 - for mode in reversed(temp_elements): - _push_mode(mode) + temp_elements.append(pre_dispatch_mode) - return context_manager_fn + for mode in reversed(temp_elements): + _push_mode(mode) @torch._disable_dynamo @@ -1235,11 +1230,6 @@ def __torch_function__( return func(*args, **kwargs) -_temp_remove_metadata_torch_function_mode = _make_temp_remove_mode_context_manager( - TorchFunctionMetadataMode -) - - # This mode is **only** used for pre_dispatch tracing. # In particular, we need to make sure that autograd/autocast API's # that do not desugar into dispatcher operators stay in the graph. @@ -1268,11 +1258,6 @@ def __torch_function__( return func(*args, **kwargs) -_temp_remove_pre_dispatch_torch_function_mode = _make_temp_remove_mode_context_manager( - PreDispatchTorchFunctionMode -) - - class ProxyTorchDispatchMode(TorchDispatchMode): # Ensure this is read-only; this exists only for legacy reasons @property diff --git a/torch/nn/attention/flex_attention.py b/torch/nn/attention/flex_attention.py index f71b86222ee5c1..7b0a3354642a7f 100644 --- a/torch/nn/attention/flex_attention.py +++ b/torch/nn/attention/flex_attention.py @@ -19,7 +19,6 @@ ) from torch._higher_order_ops.utils import _set_compilation_env from torch.fx.experimental.proxy_tensor import ( - _temp_remove_metadata_torch_function_mode, _temp_remove_pre_dispatch_torch_function_mode, ) from torch.nn.attention._utils import _supported_head_dim, _validate_sdpa_input @@ -1028,10 +1027,6 @@ def score_mod( if not torch._dynamo.is_dynamo_supported(): raise RuntimeError("flex_attention requires dynamo support") - from torch._dynamo.backends.debugging import ( - make_eager_backend_with_torch_function_mode, - ) - # Dynamo is expecting a callable with "__code__" attribute. # We cannot directly pass hop to it. So we wrap it in a dummy function. def _flex_attention_hop_wrapper(*args, **kwargs): @@ -1040,25 +1035,18 @@ def _flex_attention_hop_wrapper(*args, **kwargs): with _set_compilation_env(): with torch._dynamo.utils.disable_cache_limit(): with _temp_remove_pre_dispatch_torch_function_mode(): - with _temp_remove_metadata_torch_function_mode() as metadata_mode: - if metadata_mode: - backend = make_eager_backend_with_torch_function_mode( - metadata_mode - ) - else: - backend = "eager" - out, lse = torch.compile( - _flex_attention_hop_wrapper, backend="eager", fullgraph=True - )( - query, - key, - value, - score_mod, - block_mask.as_tuple(), - scale, - kernel_options, - ) - if return_lse: - return out, lse * math.log(2) - else: - return out + out, lse = torch.compile( + _flex_attention_hop_wrapper, backend="eager", fullgraph=True + )( + query, + key, + value, + score_mod, + block_mask.as_tuple(), + scale, + kernel_options, + ) + if return_lse: + return out, lse * math.log(2) + else: + return out From 123ead2559d8768b5e5dbacab76f40e1133e5e50 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Wed, 11 Sep 2024 15:48:27 +0000 Subject: [PATCH 0716/1018] Revert "[Dynamo] Simplify torch function mode stack guard (#135444)" This reverts commit 444b52ff40cf4afce7bc3fdcf021a88eab3b954c. Reverted https://github.com/pytorch/pytorch/pull/135444 on behalf of https://github.com/clee2000 due to something in this stack broke functorch/test_control_flow.py::TestControlFlow::test_scan_simple_graph [GH job link](https://github.com/pytorch/pytorch/actions/runs/10804912306/job/29980571390) [HUD commit link](https://hud.pytorch.org/pytorch/pytorch/commit/444b52ff40cf4afce7bc3fdcf021a88eab3b954c), newly added test yesterday ([comment](https://github.com/pytorch/pytorch/pull/135444#issuecomment-2344036843)) --- test/dynamo/test_modes.py | 2 -- torch/_dynamo/guards.py | 10 ++++---- torch/csrc/dynamo/guards.cpp | 44 +++++++----------------------------- 3 files changed, 12 insertions(+), 44 deletions(-) diff --git a/test/dynamo/test_modes.py b/test/dynamo/test_modes.py index fa4c23fd320ddb..bad5a788903440 100644 --- a/test/dynamo/test_modes.py +++ b/test/dynamo/test_modes.py @@ -68,11 +68,9 @@ def tearDownClass(cls): def setUp(self): torch.set_default_device(None) - torch._dynamo.reset() def tearDown(self): torch.set_default_device(None) - torch._dynamo.reset() def _run_torch_function_mode_guard_test(self): class TestMode1(BaseTorchFunctionMode): diff --git a/torch/_dynamo/guards.py b/torch/_dynamo/guards.py index 1544bdd6dc697b..4437ef4038190d 100644 --- a/torch/_dynamo/guards.py +++ b/torch/_dynamo/guards.py @@ -2663,14 +2663,12 @@ def make_torch_function_mode_stack_guard(intial_stack): def check_torch_function_mode_stack(): cur_stack = get_torch_function_mode_stack() - - types_ = [ty for ty in types if ty not in IGNORED_MODES] - cur_stack_ = [mode for mode in cur_stack if type(mode) not in IGNORED_MODES] - - if len(cur_stack_) != len(types_): + if len(cur_stack) != len(types): return False - for ty, mode in zip(types_, cur_stack_): + for ty, mode in zip(types, cur_stack): + if ty in IGNORED_MODES: + continue if ty != type(mode): return False diff --git a/torch/csrc/dynamo/guards.cpp b/torch/csrc/dynamo/guards.cpp index 771de44427b669..afc3182c632640 100644 --- a/torch/csrc/dynamo/guards.cpp +++ b/torch/csrc/dynamo/guards.cpp @@ -2523,8 +2523,7 @@ class TORCH_FUNCTION_MODE_STACK : public LeafGuard { Py_ssize_t len = PyList_Size(initial_stack.ptr()); for (Py_ssize_t idx = 0; idx < len; idx++) { PyObject* mode = PyList_GetItem(initial_stack.ptr(), idx); // borrowed ref - auto type = Py_TYPE(mode); - this->_ref_stack.push_back(type); + this->_ref_stack.push_back(Py_TYPE(mode)); } len = PyList_Size(ignored_types.ptr()); @@ -2544,56 +2543,29 @@ class TORCH_FUNCTION_MODE_STACK : public LeafGuard { bool check_nopybind(PyObject* value) override { // Ignore value arg, only used to satisfy the interface size_t ref_ind = 0; - const int64_t len = at::impl::PythonTorchFunctionTLS::stack_len(); + int64_t len = at::impl::PythonTorchFunctionTLS::stack_len(); const size_t ref_stack_size = this->_ref_stack.size(); - int64_t idx = 0; - while ((idx < len) && (ref_ind < ref_stack_size)) { + for (int64_t idx = 0; idx < len; idx++) { std::shared_ptr mode = at::impl::PythonTorchFunctionTLS::get_stack_at(idx); PyTypeObject* mode_type = Py_TYPE(mode->ptr(getPyInterpreter())); - bool act_ignored = this->_ignored_types.count(mode_type) > 0; - bool ref_ignored = - this->_ignored_types.count(this->_ref_stack.at(ref_ind)) > 0; // skip ignored types - if (act_ignored && ref_ignored) { - idx++; - ref_ind++; - continue; - } else if (ref_ignored) { - ref_ind++; - continue; - } else if (act_ignored) { - idx++; + if (this->_ignored_types.count(mode_type) > 0) { continue; } // if we already have more non-ignored modes than the ref stack // or if the mode doesn't match at the current index, return false - else if (mode_type != _ref_stack.at(ref_ind)) { + else if ( + (ref_stack_size == 0) || (ref_ind > ref_stack_size - 1) || + mode_type != _ref_stack[ref_ind]) { return false; } ref_ind++; - idx++; - } - - for (; ref_ind < ref_stack_size; ref_ind++) { - if (!(this->_ignored_types.count(this->_ref_stack.at(ref_ind)) > 0)) { - return false; - } - } - - for (; idx < len; idx++) { - std::shared_ptr mode = - at::impl::PythonTorchFunctionTLS::get_stack_at(idx); - - PyTypeObject* mode_type = Py_TYPE(mode->ptr(getPyInterpreter())); - if (!(this->_ignored_types.count(mode_type) > 0)) { - return false; - } } - return ref_ind == ref_stack_size && idx == len; + return ref_ind == this->_ref_stack.size(); } private: From ae8ed52dd8bc42132db06246c54895aa734a5998 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Wed, 11 Sep 2024 15:51:10 +0000 Subject: [PATCH 0717/1018] Revert "[dynamo] Bug fix for _torchdynamo_inline source handling (#135612)" This reverts commit 5c3d0a2dedbc0e85f3b256ce56ac674078a5fae1. Reverted https://github.com/pytorch/pytorch/pull/135612 on behalf of https://github.com/clee2000 due to broke inductor/test_cpu_select_algorithm.py::TestSelectAlgorithmCPU::test_linear_input_transpose_bias_True_cpu_float32 [GH job link](https://github.com/pytorch/pytorch/actions/runs/10805518363/job/29982386304) [HUD commit link](https://hud.pytorch.org/pytorch/pytorch/commit/5c3d0a2dedbc0e85f3b256ce56ac674078a5fae1), bad TD ([comment](https://github.com/pytorch/pytorch/pull/135612#issuecomment-2344039370)) --- test/dynamo/test_repros.py | 16 ---------------- torch/_dynamo/variables/builder.py | 12 +++++------- torch/_dynamo/variables/functions.py | 2 ++ 3 files changed, 7 insertions(+), 23 deletions(-) diff --git a/test/dynamo/test_repros.py b/test/dynamo/test_repros.py index 1b9bf3c83a8eb1..a3c4226ed6a9f8 100644 --- a/test/dynamo/test_repros.py +++ b/test/dynamo/test_repros.py @@ -5912,22 +5912,6 @@ def fn(val: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: actual[1].untyped_storage().data_ptr(), ) - def test_torch_compile_in_compile_frame(self): - # TODO(anijain2305/yanboliang) - Dont graph break on torch.compile. - def gn(x, c=None): - if c is None: - c = 2 - return c * x - - def outer_func(x): - return torch.compile(gn)(x) - - compile_outer = torch.compile(outer_func, backend="eager") - x = torch.randn(4) - ref = outer_func(x) - res = compile_outer(x) - self.assertEqual(ref, res) - instantiate_parametrized_tests(ReproTests) diff --git a/torch/_dynamo/variables/builder.py b/torch/_dynamo/variables/builder.py index bf1cbed10b8e5c..f7ac4d073b534b 100644 --- a/torch/_dynamo/variables/builder.py +++ b/torch/_dynamo/variables/builder.py @@ -547,6 +547,11 @@ class Autotuner: if id_dispatch is not None: return id_dispatch(self, value) + # Note - There are some nested values where types mismatch! + # We want to get those out and wrap those. + if is_function_or_wrapper(value): + value = inspect.getattr_static(value, "_torchdynamo_inline", value) + # Everything else (NB: order matters!) if is_traceable_wrapper_subclass(value) or istype( value, config.traceable_tensor_subclasses @@ -983,13 +988,6 @@ def build_key_value(i, k, v): elif is_lru_cache_wrapped_function(value): self.install_guards(GuardBuilder.TYPE_MATCH) return WrapperUserFunctionVariable(value, "__wrapped__", source=self.source) - elif is_function_or_wrapper(value) and inspect.getattr_static( - value, "_torchdynamo_inline", False - ): - self.install_guards(GuardBuilder.TYPE_MATCH) - return WrapperUserFunctionVariable( - value, "_torchdynamo_inline", source=self.source - ) elif is_function_or_wrapper(value): value, attr_name = unwrap_with_attr_name_if_wrapper(value) # For these wrappers, Dynamo points to the wrapped function, diff --git a/torch/_dynamo/variables/functions.py b/torch/_dynamo/variables/functions.py index 77662e1f4b595a..90c2d0665aebc2 100644 --- a/torch/_dynamo/variables/functions.py +++ b/torch/_dynamo/variables/functions.py @@ -152,6 +152,8 @@ def __init__(self, fn, is_constant=False, **kwargs) -> None: assert isinstance( fn, (types.FunctionType, torch.jit.ScriptFunction) ), f"expected FunctionType found {typestr(fn)} {fn}" + # unpack @torch._dynamo.optimize()(fn) wrapped function + fn = inspect.getattr_static(fn, "_torchdynamo_inline", fn) self.fn: types.FunctionType = fn def as_python_constant(self): From 476c4040ae0c2c036100f32026b10c6500c986e4 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Wed, 11 Sep 2024 15:53:53 +0000 Subject: [PATCH 0718/1018] Revert "[Dynamo] Support thread local setattr (#135443)" This reverts commit 160c228a4bd60ceffa62b045a6b0a6f9413835c5. Reverted https://github.com/pytorch/pytorch/pull/135443 on behalf of https://github.com/clee2000 due to something in this stack broke functorch/test_control_flow.py::TestControlFlow::test_scan_simple_graph [GH job link](https://github.com/pytorch/pytorch/actions/runs/10804912306/job/29980571390) [HUD commit link](https://hud.pytorch.org/pytorch/pytorch/commit/444b52ff40cf4afce7bc3fdcf021a88eab3b954c), newly added test yesterday ([comment](https://github.com/pytorch/pytorch/pull/135443#issuecomment-2344042800)) --- test/dynamo/test_misc.py | 15 --------------- torch/_dynamo/variables/user_defined.py | 5 ++--- 2 files changed, 2 insertions(+), 18 deletions(-) diff --git a/test/dynamo/test_misc.py b/test/dynamo/test_misc.py index 0c26eea44ec266..ab1b343044a26d 100644 --- a/test/dynamo/test_misc.py +++ b/test/dynamo/test_misc.py @@ -3380,21 +3380,6 @@ def fn4(x) -> None: self.assertTrue(same(obj41.y, obj42.y)) self.assertEqual(cnts.frame_count, 1) - def test_thread_local_setattr(self): - from threading import local - - loc = local() - - @torch.compile(fullgraph=True) - def fn(x, l): - l.x = x - return x + 1 - - x = torch.ones(2, 2) - fn(x, loc) - - self.assertTrue(loc.x is x) - def test_user_defined_class_name(self): class MyClassFoo: pass diff --git a/torch/_dynamo/variables/user_defined.py b/torch/_dynamo/variables/user_defined.py index ad92277cf70ff4..d069d2e060a108 100644 --- a/torch/_dynamo/variables/user_defined.py +++ b/torch/_dynamo/variables/user_defined.py @@ -9,7 +9,6 @@ import itertools import random import sys -import threading import types import warnings from typing import Dict, Generic, List, TYPE_CHECKING @@ -705,7 +704,7 @@ def call_method( if method is object.__init__: return ConstantVariable.create(None) - if is_standard_setattr(method) or isinstance(self.value, threading.local): + if is_standard_setattr(method): return self.method_setattr_standard(tx, *args, **kwargs) # [NOTE] OrderedDict, dict subtypes must always have source @@ -803,7 +802,7 @@ def method_setattr_standard(self, tx: "InstructionTranslator", name, value): def needs_slow_setattr(self): return not is_standard_setattr( inspect.getattr_static(self.value, "__setattr__", None) - ) and not isinstance(self.value, threading.local) + ) def unpack_var_sequence(self, tx): if ( From 4c28aafd3542a85752749bdecaa7fb7b029af2ed Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Wed, 11 Sep 2024 15:57:00 +0000 Subject: [PATCH 0719/1018] Revert "[Dynamo] Trace torch function modes entered outside of torch.compile (#133137)" This reverts commit 0d15122092c27fec1143b800bab7c996d126b547. Reverted https://github.com/pytorch/pytorch/pull/133137 on behalf of https://github.com/clee2000 due to something in this stack broke functorch/test_control_flow.py::TestControlFlow::test_scan_simple_graph [GH job link](https://github.com/pytorch/pytorch/actions/runs/10804912306/job/29980571390) [HUD commit link](https://hud.pytorch.org/pytorch/pytorch/commit/444b52ff40cf4afce7bc3fdcf021a88eab3b954c), newly added test yesterday ([comment](https://github.com/pytorch/pytorch/pull/133137#issuecomment-2344054339)) --- test/dynamo/test_modes.py | 135 --------- test/functorch/test_control_flow.py | 3 - .../pt2e/test_metadata_porting.py | 4 +- torch/_dynamo/convert_frame.py | 5 - torch/_dynamo/guards.py | 14 - torch/_dynamo/source.py | 2 +- torch/_dynamo/symbolic_convert.py | 78 ++++-- torch/_dynamo/trace_rules.py | 5 +- torch/_dynamo/utils.py | 10 +- torch/_dynamo/variables/ctx_manager.py | 2 - torch/_dynamo/variables/torch.py | 260 ++++++++---------- torch/_dynamo/variables/torch_function.py | 129 ++------- torch/_dynamo/variables/user_defined.py | 7 + torch/_export/non_strict_utils.py | 6 +- 14 files changed, 204 insertions(+), 456 deletions(-) diff --git a/test/dynamo/test_modes.py b/test/dynamo/test_modes.py index bad5a788903440..ee8a9b579ff124 100644 --- a/test/dynamo/test_modes.py +++ b/test/dynamo/test_modes.py @@ -14,17 +14,6 @@ from torch.utils._python_dispatch import TorchDispatchMode -class TestMode(BaseTorchFunctionMode): - def __torch_function__(self, func, types, args, kwargs=None): - if not kwargs: - kwargs = {} - - if func == torch.add: - return torch.zeros(2, 2) - - return super().__torch_function__(func, types, args, kwargs) - - class TorchDispatchModeTests(torch._dynamo.test_case.TestCase): @classmethod def setUpClass(cls): @@ -335,130 +324,6 @@ def fn(x): fn(inp) self.assertEqual(cnt.frame_count, 2) - def test_nested_torch_function_mode(self): - mode_1_called = False - mode_2_called = False - - def reset_state(): - nonlocal mode_1_called - nonlocal mode_2_called - mode_1_called = False - mode_2_called = False - - ones = torch.ones(2, 2) - zeros = torch.zeros(2, 2) - - class TestMode1(BaseTorchFunctionMode): - def __torch_function__(self, func, types, args, kwargs=None): - if not kwargs: - kwargs = {} - - nonlocal mode_1_called - - mode_1_called = True - - if func == torch.add: - return zeros - - return super().__torch_function__(func, types, args, kwargs) - - class TestMode2(BaseTorchFunctionMode): - def __torch_function__(self, func, types, args, kwargs=None): - if not kwargs: - kwargs = {} - - nonlocal mode_2_called - - mode_2_called = True - - if func == torch.mul: - return ones - - return super().__torch_function__(func, types, args, kwargs) - - def fn(x): - return torch.add(x, 3) - - def fn_2(x): - return torch.mul(x, 3) + torch.add(x, 3) - - inp = torch.ones(2, 2) + 1 - - for fn_i in [fn, fn_2]: - fn_opt = torch.compile(fn_i, fullgraph=True) - with TestMode1(), TestMode2(): - expected = fn_i(inp), mode_1_called, mode_2_called - reset_state() - actual = fn_opt(inp), mode_1_called, mode_2_called - reset_state() - - self.assertEqual(expected, actual) - - def test_torch_function_mode_disable(self): - class TestSubclass(torch.Tensor): - @classmethod - def __torch_function__(cls, func, types, args, kwargs=None): - if not kwargs: - kwargs = {} - if func == torch.add: - return torch.ones(2, 2) - return super().__torch_function__(func, types, args, kwargs) - - class TestMode(BaseTorchFunctionMode): - def __torch_function__(self, func, types, args, kwargs=None): - if not kwargs: - kwargs = {} - - if func == torch.add: - return torch.zeros(2, 2) - - return super().__torch_function__(func, types, args, kwargs) - - def fn(x): - return torch.add(x, 3) - - inp = (torch.ones(2, 2) + 1).as_subclass(TestSubclass) - - fn_opt = torch.compile(fn, fullgraph=True) - with TestMode(), torch._dynamo.config.patch( - "traceable_tensor_subclasses", {TestSubclass} - ): - with torch._C.DisableTorchFunctionSubclass(): - expected = fn(inp) - actual = fn_opt(inp) - - self.assertEqual(expected, actual) - - with torch._C.DisableTorchFunction(): - expected = fn(inp) - actual = fn_opt(inp) - - self.assertEqual(expected, actual) - - def test_torch_function_mode_highest_priority(self): - class TestSubclass(torch.Tensor): - @classmethod - def __torch_function__(cls, func, types, args, kwargs=None): - if not kwargs: - kwargs = {} - if func == torch.add: - return torch.ones(2, 2) - return super().__torch_function__(func, types, args, kwargs) - - def fn(x): - return torch.add(x, 3) - - inp = (torch.ones(2, 2) + 1).as_subclass(TestSubclass) - - fn_opt = torch.compile(fn, fullgraph=True) - with TestMode(), torch._dynamo.config.patch( - "traceable_tensor_subclasses", {TestSubclass} - ): - expected = fn(inp) - actual = fn_opt(inp) - - self.assertEqual(expected, actual) - if __name__ == "__main__": from torch._dynamo.test_case import run_tests diff --git a/test/functorch/test_control_flow.py b/test/functorch/test_control_flow.py index ad1f6721fc98bf..e07510828f4584 100644 --- a/test/functorch/test_control_flow.py +++ b/test/functorch/test_control_flow.py @@ -26,7 +26,6 @@ parametrize, requires_cuda, run_tests, - skipIfCrossRef, skipIfTorchDynamo, TEST_WITH_TORCHDYNAMO, TestCase, @@ -2985,7 +2984,6 @@ def f(x, y): self.assertEqual(graph(x, torch.tensor(True)), f(x, torch.tensor(True))) @skipIfTorchDynamo("Graph is not captured by backend if test with dynamo") - @skipIfCrossRef # Arg order changes with crossref def test_cond_simple_with_linear_compile_check_graph(self): from torch._dynamo.testing import EagerAndRecordGraphs @@ -3248,7 +3246,6 @@ def test_while_loop_compile(self, backend, while_loop_test): self._check_compile(fn, inp, backend=backend) @skipIfTorchDynamo("Graph is not captured by backend if test with dynamo") - @skipIfCrossRef # Arg order changes with cross ref def test_while_loop_simple_with_linear_compile_check_graph(self): fn, inp = WHILE_LOOP_TESTS["simple_with_linear"] from torch._dynamo.testing import EagerAndRecordGraphs diff --git a/test/quantization/pt2e/test_metadata_porting.py b/test/quantization/pt2e/test_metadata_porting.py index e8e00e36e58e10..5f2d1e2d3cf574 100644 --- a/test/quantization/pt2e/test_metadata_porting.py +++ b/test/quantization/pt2e/test_metadata_porting.py @@ -13,7 +13,7 @@ from torch.ao.quantization.quantizer.xnnpack_quantizer_utils import OP_TO_ANNOTATOR from torch.fx import Node from torch.testing._internal.common_quantization import QuantizationTestCase -from torch.testing._internal.common_utils import IS_WINDOWS, skipIfCrossRef +from torch.testing._internal.common_utils import IS_WINDOWS class TestHelperModules: @@ -139,8 +139,6 @@ def _test_metadata_porting( self.assertEqual(v, node_tags[k]) return m - @skipIfCrossRef # mlazos: retracing FX graph with torch function mode doesn't propagate metadata, because the stack - # trace of the mode torch function impl doesn't match the traced graph stored lineno. def test_simple_metadata_porting(self): """ Model under test diff --git a/torch/_dynamo/convert_frame.py b/torch/_dynamo/convert_frame.py index 9dabb78211c3b5..d15f5c4aa43dcf 100644 --- a/torch/_dynamo/convert_frame.py +++ b/torch/_dynamo/convert_frame.py @@ -605,10 +605,6 @@ def _compile( output: Optional[OutputGraph] = None tracer: Optional[InstructionTranslator] = None - tf_mode_stack: List[ - torch.overrides.TorchFunctionMode - ] = torch.overrides._get_current_function_mode_stack() - @preserve_global_state def transform( instructions: List[Instruction], code_options: Dict[str, object] @@ -622,7 +618,6 @@ def transform( locals, globals, builtins, - tf_mode_stack, code_options, compiler_fn, one_graph, diff --git a/torch/_dynamo/guards.py b/torch/_dynamo/guards.py index 4437ef4038190d..b0bd273e96bbb5 100644 --- a/torch/_dynamo/guards.py +++ b/torch/_dynamo/guards.py @@ -98,7 +98,6 @@ ScriptObjectQualifiedNameSource, ShapeEnvSource, SubclassAttrListSource, - TorchFunctionModeStackSource, TupleIteratorGetItemSource, TypeSource, UnspecializedBuiltinNNModuleSource, @@ -112,7 +111,6 @@ dict_keys_repr, get_custom_getattr, get_torch_function_mode_stack, - get_torch_function_mode_stack_at, guard_failures, istype, key_is_id, @@ -316,7 +314,6 @@ def uninteresting_files(): "___dict_contains": lambda a, b: a in b, "___tuple_iterator_len": tuple_iterator_len, "___tuple_iterator_getitem": tuple_iterator_getitem, - "___get_torch_function_mode_stack_at": get_torch_function_mode_stack_at, "__math_isnan": math.isnan, "__numpy_isnan": None if np is None else np.isnan, "inf": float("inf"), @@ -904,15 +901,6 @@ def get_guard_manager_from_source(self, source): ): assert base_guard_manager # to make mypy happy out = base_guard_manager - elif istype(source, TorchFunctionModeStackSource): - out = root_guard_manager.lambda_manager( - python_lambda=lambda _: get_torch_function_mode_stack_at( - source._get_index() - ), - source=source_name, - example_value=example_value, - guard_manager_enum=guard_manager_enum, - ) elif istype(source, GradSource): assert base_guard_manager # to make mypy happy out = base_guard_manager.grad_manager( @@ -2226,8 +2214,6 @@ def __init__( self.output_graph = output_graph w_builder = None - # NB: Until we trace device contexts, we need to use the stack recorded at the beginning of tracing - # in case a set default device call was made in the graph. self.torch_function_mode_stack = ( output_graph.torch_function_mode_stack if output_graph else None ) diff --git a/torch/_dynamo/source.py b/torch/_dynamo/source.py index c44f8d9e69abf5..6ed9d97f6a1210 100644 --- a/torch/_dynamo/source.py +++ b/torch/_dynamo/source.py @@ -619,7 +619,7 @@ class TorchFunctionModeStackSource(Source): ind: int def name(self): - return f"___get_torch_function_mode_stack_at({self._get_index()})" + return "" def _get_index(self): from .variables.torch_function import TorchFunctionModeStackVariable diff --git a/torch/_dynamo/symbolic_convert.py b/torch/_dynamo/symbolic_convert.py index 415179a5ed32a6..ab92f82aa0f630 100644 --- a/torch/_dynamo/symbolic_convert.py +++ b/torch/_dynamo/symbolic_convert.py @@ -19,7 +19,20 @@ import types import typing import weakref -from typing import Any, Callable, cast, Dict, List, Optional, Set, Tuple, Type, Union +from typing import ( + Any, + Callable, + cast, + Deque, + Dict, + List, + Optional, + Set, + Tuple, + Type, + TYPE_CHECKING, + Union, +) from unittest.mock import patch import torch @@ -59,12 +72,14 @@ GlobalWeakRefSource, LocalSource, Source, + TorchFunctionModeStackSource, ) from .trace_rules import is_builtin_constant, is_forbidden from .utils import ( counters, get_fake_value, get_instruction_source_311, + get_torch_function_mode_stack, graph_break_dup_warning_checker, istype, LazyString, @@ -105,10 +120,11 @@ ) from .variables.nn_module import NNModuleVariable, UnspecializedNNModuleVariable from .variables.tensor import supported_comparison_ops, SymNodeVariable, TensorVariable -from .variables.torch_function import ( - SymbolicTorchFunctionState, - TorchFunctionModeVariable, -) + + +if TYPE_CHECKING: + from .variables.torch_function import TorchFunctionModeVariable + from .variables.user_defined import ( RemovableHandleVariable, UserDefinedClassVariable, @@ -268,10 +284,6 @@ def resume_fn(self): return ReenterWith(self.stack_index) def exit(self, tx): - if hasattr(self, "graph_break") and isinstance( - self.with_context, TorchFunctionModeVariable - ): - return assert self.with_context is not None return self.with_context.exit(tx) @@ -640,9 +652,7 @@ def handle_graph_break( # Reconstruct the context variable CLASS in the block stack for b in self.block_stack: assert b.with_context is not None - assert isinstance( - b.with_context, (ContextWrappingVariable, TorchFunctionModeVariable) - ) + assert isinstance(b.with_context, ContextWrappingVariable) b.with_context.reconstruct_type(cg) cg.extend_output(b.resume_fn().try_except(cg.code_options, cleanup)) self.output.add_output_instructions(cg.get_instructions()) @@ -718,7 +728,7 @@ class InstructionTranslatorBase( output: OutputGraph symbolic_locals: Dict[str, VariableTracker] symbolic_globals: Dict[str, VariableTracker] - symbolic_torch_function_state: SymbolicTorchFunctionState + symbolic_torch_function_mode_stack: Deque["TorchFunctionModeVariable"] stack: List[VariableTracker] instruction_pointer: Optional[int] current_instruction: Instruction @@ -2538,7 +2548,7 @@ def __init__( code_options: Dict[str, Any], symbolic_locals: Dict[str, VariableTracker], symbolic_globals: Dict[str, VariableTracker], - symbolic_torch_function_state: SymbolicTorchFunctionState, + symbolic_torch_function_mode_stack: Deque["TorchFunctionModeVariable"], f_code: types.CodeType, export: bool, inline_depth: int, @@ -2553,7 +2563,7 @@ def __init__( self.output = output self.symbolic_locals = symbolic_locals self.symbolic_globals = symbolic_globals - self.symbolic_torch_function_state = symbolic_torch_function_state + self.symbolic_torch_function_mode_stack = symbolic_torch_function_mode_stack self.stack = [] # stack of variable names for tracking 3.13 closures self.name_stack: list[Any] = [] @@ -2642,7 +2652,6 @@ def __init__( f_locals, f_globals, f_builtins, - torch_function_mode_stack, code_options, compiler_fn, one_graph, @@ -2677,7 +2686,7 @@ def __init__( symbolic_locals={}, # set below # A global var is inserted only after a STORE_GLOBAL happens to it symbolic_globals={}, - symbolic_torch_function_state=None, # type: ignore[arg-type] # set below + symbolic_torch_function_mode_stack=collections.deque(), f_code=f_code, export=export, inline_depth=0, @@ -2712,9 +2721,7 @@ def __init__( if k in f_locals } - self.symbolic_torch_function_state = SymbolicTorchFunctionState( - torch_function_mode_stack - ) + self._init_torch_function_mode_stack() self.debug_locals: List[Tuple[VariableTracker, List[VariableTracker]]] = [] if export: @@ -2755,6 +2762,29 @@ def _throw_if_in_functorch(self): ) unimplemented(msg) + def _init_torch_function_mode_stack(self): + from .variables.torch_function import TorchFunctionModeStackVariable + + TorchFunctionModeStackVariable.reset() + + self.symbolic_torch_function_mode_stack: Deque[ + TorchFunctionModeVariable + ] = collections.deque() + # We want to retrieve all modes to properly reconstruct the stack if needed + py_stack = get_torch_function_mode_stack(filter_ignored=False) + + if py_stack: + has_device_context = isinstance( + py_stack[0], torch.utils._device.DeviceContext + ) + + for i, val in enumerate(py_stack): + self.symbolic_torch_function_mode_stack.append( + variables.LazyVariableTracker.create( + val, source=TorchFunctionModeStackSource(i) + ) + ) + def get_example_value(self, source: Source): if isinstance(source, LocalSource): return self.f_locals[source.local_name] @@ -3086,7 +3116,7 @@ def get_trace_call_log_str(): code, sub_locals, parent.symbolic_globals, - parent.symbolic_torch_function_state, + parent.symbolic_torch_function_mode_stack, closure_cells, func, ) @@ -3096,7 +3126,7 @@ def get_trace_call_log_str(): code, sub_locals, parent.symbolic_globals, - parent.symbolic_torch_function_state, + parent.symbolic_torch_function_mode_stack, closure_cells, func, ) @@ -3149,7 +3179,7 @@ def __init__( code: types.CodeType, symbolic_locals: Dict[str, VariableTracker], symbolic_globals: Dict[str, VariableTracker], - symbolic_torch_function_state: SymbolicTorchFunctionState, + symbolic_torch_function_mode_stack: Deque["TorchFunctionModeVariable"], closure_cells: Dict[str, VariableTracker], funcvar: BaseUserFunctionVariable, ) -> None: @@ -3166,7 +3196,7 @@ def __init__( f_builtins=f_builtins, symbolic_locals=symbolic_locals, symbolic_globals=symbolic_globals, - symbolic_torch_function_state=symbolic_torch_function_state, + symbolic_torch_function_mode_stack=symbolic_torch_function_mode_stack, instructions=instructions, code_options={k: getattr(code, k) for k in get_code_keys()}, f_code=code, diff --git a/torch/_dynamo/trace_rules.py b/torch/_dynamo/trace_rules.py index 23a859e426f8df..cf0731911612ba 100644 --- a/torch/_dynamo/trace_rules.py +++ b/torch/_dynamo/trace_rules.py @@ -3256,7 +3256,6 @@ def _module_dir(m: types.ModuleType): "torch.testing", "torch.utils._content_store", "torch.utils._contextlib", - "torch.utils._device", "torch.utils._foreach_utils", "torch.utils._python_dispatch", "torch.utils._pytree", @@ -3591,9 +3590,7 @@ def lookup_inner( if reasons is not None: reasons.add("func name is patched_init") return SkipFunctionVariable - elif name == "__torch_function__" or ( - obj and obj.__name__ == "__torch_function__" - ): + elif name == "__torch_function__": if reasons is not None: reasons.add("func name is __torch_function__") return UserFunctionVariable diff --git a/torch/_dynamo/utils.py b/torch/_dynamo/utils.py index ef63ddbea85619..0ec548d788f868 100644 --- a/torch/_dynamo/utils.py +++ b/torch/_dynamo/utils.py @@ -63,6 +63,7 @@ import torch.utils._pytree as pytree from torch import fx from torch._C import ( + _get_function_stack_at, _instruction_counter, _len_torch_function_stack, _pop_torch_function_stack, @@ -3064,9 +3065,7 @@ def is_parameter_freezing(): def get_torch_function_mode_stack(filter_ignored=True): from .variables.torch_function import IGNORED_MODES - stack = [ - get_torch_function_mode_stack_at(i) for i in range(_len_torch_function_stack()) - ] + stack = [_get_function_stack_at(i) for i in range(_len_torch_function_stack())] if filter_ignored: stack = [mode for mode in stack if type(mode) not in IGNORED_MODES] @@ -3086,11 +3085,6 @@ def set_torch_function_mode_stack(stack): _push_on_torch_function_stack(mode) -def clear_torch_function_mode_stack(): - for i in range(_len_torch_function_stack()): - _pop_torch_function_stack() - - def verify_guard_fn_signature(value): fn = value.__metadata_guard__ sig = inspect.signature(fn) diff --git a/torch/_dynamo/variables/ctx_manager.py b/torch/_dynamo/variables/ctx_manager.py index 838f413a16e306..301b7f3e819345 100644 --- a/torch/_dynamo/variables/ctx_manager.py +++ b/torch/_dynamo/variables/ctx_manager.py @@ -637,8 +637,6 @@ def enter(self, tx): def _call_func(self, tx: "InstructionTranslator", values): assert len(values) == 1 - tx.symbolic_torch_function_state.torch_function_subclass_enabled = values[0] - tx.symbolic_torch_function_state.torch_function_mode_enabled = values[0] tx.output.set_torch_function_state(values[0]) diff --git a/torch/_dynamo/variables/torch.py b/torch/_dynamo/variables/torch.py index 288bf14aebf928..b32ac9af6ccfd9 100644 --- a/torch/_dynamo/variables/torch.py +++ b/torch/_dynamo/variables/torch.py @@ -154,15 +154,6 @@ bin_ops = dict.fromkeys(["add", "sub", "mul", "div", "sqrt"]) -@functools.lru_cache(None) -def get_overridable_functions(): - from itertools import chain - - from torch.overrides import get_overridable_functions as get_overridable_functions_ - - return set(chain(*get_overridable_functions_().values())) - - class BaseTorchVariable(VariableTracker): """common base for all torch.* functions, classes, modules and other things""" @@ -796,10 +787,10 @@ def handle_pop_torch_function( self, tx: "InstructionTranslator", *args, **kwargs ): assert not args and not kwargs - if not tx.symbolic_torch_function_state.mode_stack: + if not tx.symbolic_torch_function_mode_stack: raise unimplemented("Popping from an empty torch function mode stack") TorchFunctionModeStackVariable.register_mutation(tx) - return tx.symbolic_torch_function_state.pop_torch_function_mode() + return tx.symbolic_torch_function_mode_stack.pop() @register(torch._C._push_on_torch_function_stack) def handle_push_torch_function( @@ -807,7 +798,7 @@ def handle_push_torch_function( ): assert len(args) == 1 and not kwargs TorchFunctionModeStackVariable.register_mutation(tx) - tx.symbolic_torch_function_state.push_torch_function_mode(args[0]) + tx.symbolic_torch_function_mode_stack.append(args[0]) return ConstantVariable.create(None) @register(torch._C._len_torch_function_stack) @@ -815,9 +806,7 @@ def handle_len_torch_function( self, tx: "InstructionTranslator", *args, **kwargs ): assert not args and not kwargs - return ConstantVariable.create( - len(tx.symbolic_torch_function_state.mode_stack) - ) + return ConstantVariable.create(len(tx.symbolic_torch_function_mode_stack)) @register(torch.set_default_device) def handle_set_default_device( @@ -849,9 +838,6 @@ def call_function( from . import ConstantVariable, SymNodeVariable, TensorVariable from .builder import wrap_fx_proxy - if self.torch_function_override_enabled(tx, args, kwargs): - return dispatch_torch_function(tx, self, args, kwargs) - if self.can_constant_fold_through() and check_unspec_or_constant_args( args, kwargs ): @@ -873,144 +859,147 @@ def call_function( if result: return result - any_symints_or_symfloats = any(isinstance(x, SymNodeVariable) for x in args) + if can_dispatch_torch_function(tx, args, kwargs): + return dispatch_torch_function(tx, self, args, kwargs) + else: + any_symints_or_symfloats = any(isinstance(x, SymNodeVariable) for x in args) - all_ints_or_floats = all( - isinstance(x, (variables.ConstantVariable, variables.SymNodeVariable)) - for x in args - ) - if ( - getattr(self.value, "__module__", "") == "torch" - and self.value.__name__ in bin_ops - and any_symints_or_symfloats - and all_ints_or_floats - ): - msg = f"""\ + all_ints_or_floats = all( + isinstance(x, (variables.ConstantVariable, variables.SymNodeVariable)) + for x in args + ) + if ( + getattr(self.value, "__module__", "") == "torch" + and self.value.__name__ in bin_ops + and any_symints_or_symfloats + and all_ints_or_floats + ): + msg = f"""\ Calling {str(self.value)} on only torch.SymInt arguments is not yet supported. To support this behavior, we need to allow const-propping tensors that store symint data. For now, dynamo will explicitly graph break when it encounters user code with this behavior. """ - log.warning(msg) - unimplemented(msg) - - # TODO(voz): Replace w/ dynamic shape rewrite table. - # Ideally, we would be able to do this at ctor time, but alas we need a combination - # of value + args to determine this. - fn_ = self.value - if any_symints_or_symfloats: - torch_sym_op = f"_sym_{self.value.__name__}" - if getattr(self.value, "__module__", None) == "math" and hasattr( - torch, torch_sym_op - ): - fn_ = getattr(torch, torch_sym_op) - - fake_out_shape = None - if "out" in kwargs and isinstance(kwargs["out"], variables.TensorVariable): - # Calling fake tensor propagation can mutate the out= tensor in - # tx.output.tracked_fakes. tracked_fakes are used to apply - # symbolic_shape guards. Mutating them destroys the information - # prior to tracing, which is essential for creating right - # guards. So save the shape now, and check later if it has - # changed. If it has, graph break. - fake_out_shape = kwargs["out"].proxy.node.meta["example_value"].shape - - tensor_variable = wrap_fx_proxy( - tx=tx, - proxy=tx.output.create_proxy( - "call_function", - fn_, - *proxy_args_kwargs(args, kwargs), - ), - ) + log.warning(msg) + unimplemented(msg) + + # TODO(voz): Replace w/ dynamic shape rewrite table. + # Ideally, we would be able to do this at ctor time, but alas we need a combination + # of value + args to determine this. + fn_ = self.value + if any_symints_or_symfloats: + torch_sym_op = f"_sym_{self.value.__name__}" + if getattr(self.value, "__module__", None) == "math" and hasattr( + torch, torch_sym_op + ): + fn_ = getattr(torch, torch_sym_op) + + fake_out_shape = None + if "out" in kwargs and isinstance(kwargs["out"], variables.TensorVariable): + # Calling fake tensor propagation can mutate the out= tensor in + # tx.output.tracked_fakes. tracked_fakes are used to apply + # symbolic_shape guards. Mutating them destroys the information + # prior to tracing, which is essential for creating right + # guards. So save the shape now, and check later if it has + # changed. If it has, graph break. + fake_out_shape = kwargs["out"].proxy.node.meta["example_value"].shape + + tensor_variable = wrap_fx_proxy( + tx=tx, + proxy=tx.output.create_proxy( + "call_function", + fn_, + *proxy_args_kwargs(args, kwargs), + ), + ) - if ( - isinstance(tensor_variable, TensorVariable) - and "requires_grad" in kwargs - and kwargs["requires_grad"].as_python_constant() - ): - unimplemented( - """factory functions that return tensors that require grad are not supported. + if ( + isinstance(tensor_variable, TensorVariable) + and "requires_grad" in kwargs + and kwargs["requires_grad"].as_python_constant() + ): + unimplemented( + """factory functions that return tensors that require grad are not supported. Either create the tensor outside the compiled region, or do not set the tensor to require_grad""" - ) + ) - if "out" in kwargs and not ( - isinstance(kwargs["out"], variables.ConstantVariable) - and kwargs["out"].as_python_constant() is None - ): - # out variants of torch operators like torch.sort and - # torch.sigmoid mutate the tensors in the out field. Track such - # tensors and rewrite the symbolic locals. - if isinstance(tensor_variable, TupleVariable): - assert isinstance(kwargs["out"], (TupleVariable, ListVariable)) - output_tensor_names = [ - tx.find_symbolic_locals_name(x) for x in kwargs["out"].items - ] - for idx, name in enumerate(output_tensor_names): - if name in tx.symbolic_locals: - tx.symbolic_locals[name] = tensor_variable.items[idx] - for out_tensor, result_tensor in zip( - kwargs["out"].items, tensor_variable.items - ): + if "out" in kwargs and not ( + isinstance(kwargs["out"], variables.ConstantVariable) + and kwargs["out"].as_python_constant() is None + ): + # out variants of torch operators like torch.sort and + # torch.sigmoid mutate the tensors in the out field. Track such + # tensors and rewrite the symbolic locals. + if isinstance(tensor_variable, TupleVariable): + assert isinstance(kwargs["out"], (TupleVariable, ListVariable)) + output_tensor_names = [ + tx.find_symbolic_locals_name(x) for x in kwargs["out"].items + ] + for idx, name in enumerate(output_tensor_names): + if name in tx.symbolic_locals: + tx.symbolic_locals[name] = tensor_variable.items[idx] + for out_tensor, result_tensor in zip( + kwargs["out"].items, tensor_variable.items + ): + if ( + out_tensor.source + and out_tensor in tx.output.graphargs + and isinstance(out_tensor, variables.TensorVariable) + and isinstance(result_tensor, variables.TensorVariable) + and out_tensor.size != result_tensor.size + ): + # It's hard to get out variants with resizing on graph inputs work + # properly across dynamo/aot/inductor, just fall back. + unimplemented("out variants with resizing on graph inputs") + elif isinstance(tensor_variable, TensorVariable): + assert isinstance(kwargs["out"], TensorVariable) + assert "example_value" in kwargs["out"].proxy.node.meta + fake_tensor = tensor_variable.proxy.node.meta["example_value"] + fake_out = kwargs["out"].proxy.node.meta["example_value"] if ( - out_tensor.source - and out_tensor in tx.output.graphargs - and isinstance(out_tensor, variables.TensorVariable) - and isinstance(result_tensor, variables.TensorVariable) - and out_tensor.size != result_tensor.size + kwargs["out"].source + and kwargs["out"] in tx.output.graphargs + and fake_out_shape != fake_tensor.shape ): # It's hard to get out variants with resizing on graph inputs work # properly across dynamo/aot/inductor, just fall back. unimplemented("out variants with resizing on graph inputs") - elif isinstance(tensor_variable, TensorVariable): - assert isinstance(kwargs["out"], TensorVariable) - assert "example_value" in kwargs["out"].proxy.node.meta - fake_tensor = tensor_variable.proxy.node.meta["example_value"] - fake_out = kwargs["out"].proxy.node.meta["example_value"] - if ( - kwargs["out"].source - and kwargs["out"] in tx.output.graphargs - and fake_out_shape != fake_tensor.shape - ): - # It's hard to get out variants with resizing on graph inputs work - # properly across dynamo/aot/inductor, just fall back. - unimplemented("out variants with resizing on graph inputs") - if not torch._prims_common.is_contiguous(fake_out): - # It's difficult to handle strides correctly in functionalization - # when calling an out= op with a non-contiguous out argument - unimplemented( - "out= op was called where output tensor was non-contiguous" - ) - name = tx.find_symbolic_locals_name(kwargs["out"]) - if name in tx.symbolic_locals: - tx.symbolic_locals[name] = tensor_variable - elif ( - isinstance(tensor_variable, ConstantVariable) - and tensor_variable.value is None - ): - # Handle out-variant custom ops that return None. - if isinstance(kwargs["out"], TensorVariable): - assert "example_value" in kwargs["out"].proxy.node.meta - fake_out = kwargs["out"].proxy.node.meta["example_value"] if not torch._prims_common.is_contiguous(fake_out): # It's difficult to handle strides correctly in functionalization # when calling an out= op with a non-contiguous out argument unimplemented( "out= op was called where output tensor was non-contiguous" ) - elif isinstance(kwargs["out"], ListVariable): - for idx, x in enumerate(kwargs["out"].items): - assert "example_value" in x.proxy.node.meta # type: ignore[attr-defined] - fake_out = x.proxy.node.meta["example_value"] # type: ignore[attr-defined] + name = tx.find_symbolic_locals_name(kwargs["out"]) + if name in tx.symbolic_locals: + tx.symbolic_locals[name] = tensor_variable + elif ( + isinstance(tensor_variable, ConstantVariable) + and tensor_variable.value is None + ): + # Handle out-variant custom ops that return None. + if isinstance(kwargs["out"], TensorVariable): + assert "example_value" in kwargs["out"].proxy.node.meta + fake_out = kwargs["out"].proxy.node.meta["example_value"] if not torch._prims_common.is_contiguous(fake_out): # It's difficult to handle strides correctly in functionalization # when calling an out= op with a non-contiguous out argument unimplemented( - "out= op was called where some of the output tensors were non-contiguous" + "out= op was called where output tensor was non-contiguous" ) - else: - unimplemented(f"out variant of {type(kwargs['out'])}") + elif isinstance(kwargs["out"], ListVariable): + for idx, x in enumerate(kwargs["out"].items): + assert "example_value" in x.proxy.node.meta # type: ignore[attr-defined] + fake_out = x.proxy.node.meta["example_value"] # type: ignore[attr-defined] + if not torch._prims_common.is_contiguous(fake_out): + # It's difficult to handle strides correctly in functionalization + # when calling an out= op with a non-contiguous out argument + unimplemented( + "out= op was called where some of the output tensors were non-contiguous" + ) + else: + unimplemented(f"out variant of {type(kwargs['out'])}") - return tensor_variable + return tensor_variable def _call_ntuple(self, tx: "InstructionTranslator", args, kwargs): """inline behavior of torch.nn.modules.utils._ntuple""" @@ -1138,12 +1127,3 @@ def _nn_param_via_prefix_insert(tx: "InstructionTranslator", data, requires_grad source ) return result - - def torch_function_override_enabled(self, tx, args, kwargs): - return ( - self.get_function() in get_overridable_functions() - or isinstance( - self.get_function(), - (torch._ops.OpOverload, torch._ops.OpOverloadPacket), - ) - ) and can_dispatch_torch_function(tx, args, kwargs) diff --git a/torch/_dynamo/variables/torch_function.py b/torch/_dynamo/variables/torch_function.py index 6e52cc688e0309..4b3188507fe4b5 100644 --- a/torch/_dynamo/variables/torch_function.py +++ b/torch/_dynamo/variables/torch_function.py @@ -1,11 +1,8 @@ # mypy: ignore-errors -import collections -import contextlib import inspect -from typing import Deque, Dict, List, TYPE_CHECKING +from typing import Dict, List, TYPE_CHECKING -import torch._C import torch.utils._pytree as pytree from torch._guards import Source from torch.overrides import _get_overloaded_args, get_default_nowrap_functions @@ -18,7 +15,6 @@ from .base import VariableTracker from .constant import ConstantVariable from .ctx_manager import ContextWrappingVariable -from .lazy import LazyVariableTracker from .lists import TupleVariable from .tensor import TensorSubclassVariable, TensorVariable from .user_defined import UserDefinedObjectVariable @@ -63,67 +59,6 @@ IGNORED_MODES = {DeviceContext} -class SymbolicTorchFunctionState: - def __init__(self, py_stack): - # This is annoyingly complicated because of how the torch function subclass + mode C API was designed - # There are two exposed C knobs here as contexts: torch._C.DisableTorchFunction and torch._C.DisableTorchFunctionSubclass - # These are their definitions: - # 1) torch._C._is_torch_function_enabled indicates that neither of the above knobs have been entered - # (if either are entered, this will be False) - # 2) torch._C._is_torch_function_mode_enabled indicates that either the torch mode stack is empty OR - # torch._C.DisableTorchFunction has been entered - # To disambiguate these and keep myself sane I added a C API to check whether all torch function - # concepts (modes and subclasses) are enabled. - # This only returns true iff we have not entered torch._C.DisableTorchFunction and allows us to separate - # the stack length from the enablement state of torch function modes. - # This is important because now if a mode is pushed while dynamo is tracing, we know whether - # or not torch function modes are enabled and whether we should trace it. - self.torch_function_subclass_enabled = torch._C._is_torch_function_enabled() - - # This differs from the C API of the same name - # this will only be false iff we have entered torch._C.DisableTorchFunction - # and does not take into account the mode stack length, while the C API bundles these - # two concepts - self.torch_function_mode_enabled = ( - not torch._C._is_torch_function_all_disabled() - ) - - self.cur_mode = None - - TorchFunctionModeStackVariable.reset() - - self.mode_stack: Deque[TorchFunctionModeVariable] = collections.deque() - - for i, val in enumerate(py_stack): - self.mode_stack.append( - LazyVariableTracker.create(val, source=TorchFunctionModeStackSource(i)) - ) - - def in_torch_function_mode(self): - return len(self.mode_stack) > 0 - - def pop_torch_function_mode(self): - return self.mode_stack.pop() - - def push_torch_function_mode(self, mode_var): - self.mode_stack.append(mode_var) - - def call_torch_function_mode(self, tx, fn, types, args, kwargs): - with self._pop_mode_for_inlining() as cur_mode: - return cur_mode.call_torch_function(tx, fn, types, args, kwargs) - - @contextlib.contextmanager - def _pop_mode_for_inlining(self): - old_mode = self.cur_mode - self.cur_mode = self.pop_torch_function_mode() - try: - yield self.cur_mode - finally: - mode = self.cur_mode - self.cur_mode = old_mode - self.push_torch_function_mode(mode) - - class TorchFunctionModeStackVariable(VariableTracker): """Fake VT to use as a dummy object, indicating the presence of torch function mode stack mutation""" @@ -153,20 +88,19 @@ def reset(cls): def register_mutation(cls, tx: "InstructionTranslator"): if cls.stack_value_singleton not in tx.output.side_effects: var = cls( - source=Source(), - symbolic_stack=tx.symbolic_torch_function_state.mode_stack, + source=Source(), symbolic_stack=tx.symbolic_torch_function_mode_stack ) tx.output.side_effects.track_mutable(cls.stack_value_singleton, var) tx.output.side_effects.mutation(var) @classmethod def register_device_context_insertion(cls, tx: "InstructionTranslator"): - stack = tx.symbolic_torch_function_state.mode_stack + stack = tx.symbolic_torch_function_mode_stack if stack and cls.is_device_context(stack[0]): return else: cls.offset += 1 - stack.insert( + tx.symbolic_torch_function_mode_stack.insert( 0, TorchFunctionModeVariable( None, source=TorchFunctionModeStackSource(-cls.offset) @@ -175,7 +109,7 @@ def register_device_context_insertion(cls, tx: "InstructionTranslator"): @classmethod def clear_default_device(cls, tx: "InstructionTranslator"): - stack = tx.symbolic_torch_function_state.mode_stack + stack = tx.symbolic_torch_function_mode_stack if stack and cls.is_device_context(stack[0]): stack.popleft() cls.offset -= 1 @@ -190,39 +124,23 @@ def get_mode_index(cls, ind): class TorchFunctionModeVariable(ContextWrappingVariable): - def __init__(self, value, source=None, **kwargs): + def __init__(self, value, **kwargs): super().__init__(value, **kwargs) self.value = value - self.cm_obj = value # needed for BC with calling enter from CM code - self.source = source + + @staticmethod + def get_global_mangled_name(tx, val): + return get_safe_global_name( + tx, f"__torch_function_mode_{val.__class__.__name__}", val + ) def reconstruct(self, codegen): - # This shouldn't be called unless we have a source + # We don't support locally created torch function modes yet assert self.source self.source.reconstruct(codegen) - def module_name(self): - return self.value.__module__ - - def fn_name(self): - return type(self.value).__name__ - - def python_type(self): - return type(self.value) - - def call_torch_function(self, tx: "InstructionTranslator", fn, types, args, kwargs): - return call_torch_function( - tx, - self, - build_torch_function_fn(tx, self.value, self.source), - fn, - types, - args, - kwargs, - ) - - def _call_func(self, tx: "InstructionTranslator", values): - unimplemented("enter/exit for torch function mode NYI") + def _call_func(self, tx, values): + unimplemented("torch function mode context manager is not supported yet") def _get_all_args(args, kwargs): @@ -313,13 +231,9 @@ def build_torch_function_fn(tx: "InstructionTranslator", value, source): def can_dispatch_torch_function(tx: "InstructionTranslator", args, kwargs): - has_overridden_args = any( + return tx.output.torch_function_enabled and any( has_torch_function(arg) for arg in _get_all_args(args, kwargs) ) - tf_state = tx.symbolic_torch_function_state - return (has_overridden_args and tf_state.torch_function_subclass_enabled) or ( - tf_state.torch_function_mode_enabled and tf_state.in_torch_function_mode() - ) def dispatch_torch_function(tx: "InstructionTranslator", fn, args, kwargs): @@ -331,20 +245,11 @@ def dispatch_torch_function(tx: "InstructionTranslator", fn, args, kwargs): _get_subclass_type, ) - types = TupleVariable([_get_subclass_type_var(tx, arg) for arg in overloaded_args]) - - if tx.symbolic_torch_function_state.in_torch_function_mode(): - res = tx.symbolic_torch_function_state.call_torch_function_mode( - tx, fn, types, args, kwargs - ) - if not (isinstance(res, ConstantVariable) and res.value is NotImplemented): - return res - for arg in overloaded_args: res = arg.call_torch_function( tx, fn, - types, + TupleVariable([_get_subclass_type_var(tx, arg) for arg in overloaded_args]), args, kwargs, ) diff --git a/torch/_dynamo/variables/user_defined.py b/torch/_dynamo/variables/user_defined.py index d069d2e060a108..34e1d0d10c9f72 100644 --- a/torch/_dynamo/variables/user_defined.py +++ b/torch/_dynamo/variables/user_defined.py @@ -82,6 +82,11 @@ def is_forbidden_context_manager(ctx): from _pytest.python_api import RaisesContext from _pytest.recwarn import WarningsChecker + # TODO mlazos: Temporary to get this stack to pass + # remove in subsequent PR + from torch.overrides import BaseTorchFunctionMode + + f_ctxs.append(BaseTorchFunctionMode) f_ctxs.append(RaisesContext) f_ctxs.append(WarningsChecker) except ImportError: @@ -408,6 +413,7 @@ def call_function( and self.source and not is_forbidden_context_manager(self.value) ): + # import here to avoid an unfortunate circular dependency. from .ctx_manager import GenericContextWrappingVariable cm_obj = tx.output.side_effects.track_object_new( @@ -415,6 +421,7 @@ def call_function( ) cm_obj.call_method(tx, "__init__", args, kwargs) return cm_obj + elif is_namedtuple_cls(self.value): fields = namedtuple_fields(self.value) # check if this a quasi-namedtuple or a real one diff --git a/torch/_export/non_strict_utils.py b/torch/_export/non_strict_utils.py index ef15e5fea9e973..5c7331659d26a0 100644 --- a/torch/_export/non_strict_utils.py +++ b/torch/_export/non_strict_utils.py @@ -506,11 +506,7 @@ class _NonStrictTorchFunctionHandler(torch.overrides.TorchFunctionMode): def __torch_function__(self, func, types, args=(), kwargs=None): kwargs = kwargs or {} - if ( - not torch.compiler.is_dynamo_compiling() - and log.isEnabledFor(logging.DEBUG) - and config.extended_debug_current_loc - ): + if log.isEnabledFor(logging.DEBUG) and config.extended_debug_current_loc: frame = _find_user_code_frame() if frame is not None: log.debug( From 1239af290652131278c02e5f36313fea52fc677d Mon Sep 17 00:00:00 2001 From: Isuru Fernando Date: Wed, 11 Sep 2024 13:20:58 +0000 Subject: [PATCH 0720/1018] Improve performance of canonicalize_bool_expr (#135621) Pull Request resolved: https://github.com/pytorch/pytorch/pull/135621 Approved by: https://github.com/ezyang --- torch/fx/experimental/symbolic_shapes.py | 37 +++++++++++++++++------- 1 file changed, 26 insertions(+), 11 deletions(-) diff --git a/torch/fx/experimental/symbolic_shapes.py b/torch/fx/experimental/symbolic_shapes.py index 21e72542d34cda..78f0dbe833e14c 100644 --- a/torch/fx/experimental/symbolic_shapes.py +++ b/torch/fx/experimental/symbolic_shapes.py @@ -404,7 +404,7 @@ def _canonicalize_bool_expr_impl(expr: SympyBoolean) -> SympyBoolean: t = type(expr) def is_neg(t): - return t.is_negative or (isinstance(t, sympy.Mul) and t.args[0].is_negative) + return (t.is_Number and t.is_negative) or (isinstance(t, sympy.Mul) and t.args[0].is_Number and t.args[0].is_negative) lhs = 0 rhs = _reduce_to_lowest_terms(rhs) @@ -416,12 +416,14 @@ def is_neg(t): neg.append(-term) else: pos.append(term) - lhs = sympy.Add(*neg) - rhs = sympy.Add(*pos) + lhs = sympy.Add._from_args(neg) + rhs = sympy.Add._from_args(pos) elif is_neg(rhs): # lhs == 0 lhs, rhs = -rhs, 0 - return t(lhs, rhs) + # We don't have to evaluate here because lhs, rhs came from a Boolean + # and it was already simplified + return t(lhs, rhs, evaluate=False) def _reduce_to_lowest_terms(expr: sympy.Expr) -> sympy.Expr: @@ -432,20 +434,33 @@ def _reduce_to_lowest_terms(expr: sympy.Expr) -> sympy.Expr: Useful when an expression is == or != to 0. """ def integer_coefficient(x): - if isinstance(x, sympy.Integer): + if x.is_Integer: return abs(int(x)) - elif isinstance(x, sympy.Mul): - return math.prod([abs(int(arg)) for arg in x.args if isinstance(arg, sympy.Integer)]) + elif x.is_Mul: + # If one of the args of a Mul is an Integer, it is the + # first arg. eg: args(2*x*3*y) == (6, x, y) + return abs(int(x.args[0])) if x.args[0].is_Integer else 1 else: return 1 - if isinstance(expr, sympy.Add): + def div_by_factor(x, factor): + if x.is_Integer: + return x / factor + elif x.is_Mul: + return sympy.Mul._from_args((x.args[0] / factor, *x.args[1:]), is_commutative=x.is_commutative) + + if expr.is_Add: atoms = expr.args factor = functools.reduce(math.gcd, map(integer_coefficient, atoms)) + if factor == 1: + return expr atoms = [x / factor for x in atoms] - return sympy.Add(*atoms) - else: - return expr / integer_coefficient(expr) + return sympy.Add._from_args(atoms) + elif expr.is_Integer: + return sympy.One + elif expr.is_Mul: + return div_by_factor(expr, integer_coefficient(expr)) + return expr def is_concrete_bool(a: Union[bool, SymBool]) -> bool: From 96099cc3dc73b061bdb9abb5cb6f899da34f3df9 Mon Sep 17 00:00:00 2001 From: Isuru Fernando Date: Wed, 11 Sep 2024 13:20:58 +0000 Subject: [PATCH 0721/1018] Improve performance of sympy_generic_le (#135622) Pull Request resolved: https://github.com/pytorch/pytorch/pull/135622 Approved by: https://github.com/ezyang ghstack dependencies: #135621 --- torch/utils/_sympy/value_ranges.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/torch/utils/_sympy/value_ranges.py b/torch/utils/_sympy/value_ranges.py index 9080800979e67b..2dbad5241d039d 100644 --- a/torch/utils/_sympy/value_ranges.py +++ b/torch/utils/_sympy/value_ranges.py @@ -87,7 +87,9 @@ def simple_sympify(e): def sympy_generic_le(lower, upper): if isinstance(lower, sympy.Expr): assert isinstance(upper, sympy.Expr) - return lower <= upper + # instead of lower <= upper, we do upper >= lower since upper is mostly int_oo + # and we have better code paths there. + return upper >= lower else: # only negative condition is True > False assert isinstance(lower, SympyBoolean) and isinstance(upper, SympyBoolean), ( From a10836c53c43ebcff8b1dcdc04c95c4b9f5fd984 Mon Sep 17 00:00:00 2001 From: Isuru Fernando Date: Wed, 11 Sep 2024 13:20:58 +0000 Subject: [PATCH 0722/1018] Optimize ShapeEnv.replace (#135652) Pull Request resolved: https://github.com/pytorch/pytorch/pull/135652 Approved by: https://github.com/ezyang ghstack dependencies: #135621, #135622 --- torch/fx/experimental/symbolic_shapes.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/torch/fx/experimental/symbolic_shapes.py b/torch/fx/experimental/symbolic_shapes.py index 78f0dbe833e14c..f8ca28ab0d7b9f 100644 --- a/torch/fx/experimental/symbolic_shapes.py +++ b/torch/fx/experimental/symbolic_shapes.py @@ -4584,7 +4584,14 @@ def _maybe_evaluate_static( def replace(self, expr: "sympy.Expr") -> "sympy.Expr": """Apply symbol replacements to any symbols in the given expression """ - replacements = {s: self._find(cast(sympy.Symbol, s)) for s in expr.free_symbols} + replacements = {} + for s in expr.free_symbols: + r = self._find(cast(sympy.Symbol, s)) + # Micro-optimization: only do replacements if r and s are different + # Otherwise, xreplace is not a no-op and will trigger expensive + # assumption queries if expr has a relational node. + if not r.is_Symbol or r != s: + replacements[s] = r return safe_expand(expr.xreplace(replacements)) @_lru_cache From 33cf962075ac2e614fd50661ab4bce503cb8ee0d Mon Sep 17 00:00:00 2001 From: Catherine Lee Date: Wed, 11 Sep 2024 17:02:17 +0000 Subject: [PATCH 0723/1018] [CI] Fix update slow tests (#135390) * Add pytorchbot to list of approvers for file * Add labels to the auto created PR The auto generated PR is currently not merging due to some failing tests on slow workflow that were supposed to be moved back to normal idk if this has much value, clearly we've been managing without the update Pull Request resolved: https://github.com/pytorch/pytorch/pull/135390 Approved by: https://github.com/ZainRizvi --- .github/merge_rules.yaml | 12 ++++++++++++ tools/testing/update_slow_tests.py | 14 ++++++++++++-- 2 files changed, 24 insertions(+), 2 deletions(-) diff --git a/.github/merge_rules.yaml b/.github/merge_rules.yaml index bf80880e604144..350010708be630 100644 --- a/.github/merge_rules.yaml +++ b/.github/merge_rules.yaml @@ -86,6 +86,18 @@ - pull - inductor +- name: OSS CI / pytorchbot / slow tests + patterns: + - test/slow_tests.json + approved_by: + - pytorchbot + ignore_flaky_failures: false + mandatory_checks_name: + - EasyCLA + - Lint + - pull + - slow + - name: OSS CI /pytorchbot / Executorch patterns: - .ci/docker/ci_commit_pins/executorch.txt diff --git a/tools/testing/update_slow_tests.py b/tools/testing/update_slow_tests.py index e88e85f5c5a090..d10daf6a8386b6 100644 --- a/tools/testing/update_slow_tests.py +++ b/tools/testing/update_slow_tests.py @@ -3,7 +3,7 @@ import subprocess import time from pathlib import Path -from typing import Any, cast, Dict, Optional, Tuple +from typing import Any, cast, Dict, List, Optional, Tuple import requests import rockset # type: ignore[import] @@ -93,7 +93,7 @@ def git_api( - url: str, params: Dict[str, str], type: str = "get", token: str = UPDATEBOT_TOKEN + url: str, params: Dict[str, Any], type: str = "get", token: str = UPDATEBOT_TOKEN ) -> Any: headers = { "Accept": "application/vnd.github.v3+json", @@ -147,6 +147,15 @@ def make_comment(source_repo: str, pr_number: int, msg: str) -> None: ) +def add_labels(source_repo: str, pr_number: int, labels: List[str]) -> None: + params = {"labels": labels} + git_api( + f"/repos/{source_repo}/issues/{pr_number}/labels", + params, + type="post", + ) + + def search_for_open_pr( source_repo: str, search_string: str ) -> Optional[Tuple[int, str]]: @@ -204,5 +213,6 @@ def search_for_open_pr( # no existing pr, so make a new one and approve it pr_num = make_pr("pytorch/pytorch", params) time.sleep(5) + add_labels("pytorch/pytorch", pr_num, ["ciflow/slow", "ci-no-td"]) approve_pr("pytorch/pytorch", pr_num) make_comment("pytorch/pytorch", pr_num, "@pytorchbot merge") From d9345aea7981373b71c9d0d520878371fc1aa926 Mon Sep 17 00:00:00 2001 From: Jithun Nair Date: Wed, 11 Sep 2024 17:21:05 +0000 Subject: [PATCH 0724/1018] [CI] [ROCm] Run rocm workflow on every push to main branch (#135644) Dial the frequency back up from https://github.com/pytorch/pytorch/pull/131637 Pull Request resolved: https://github.com/pytorch/pytorch/pull/135644 Approved by: https://github.com/huydhn --- .github/workflows/rocm.yml | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/.github/workflows/rocm.yml b/.github/workflows/rocm.yml index 7486162e620a38..781f2878b6eae6 100644 --- a/.github/workflows/rocm.yml +++ b/.github/workflows/rocm.yml @@ -3,18 +3,12 @@ name: rocm on: push: branches: -# - main + - main - release/* tags: - ciflow/rocm/* workflow_dispatch: schedule: - # We have several schedules so jobs can check github.event.schedule to activate only for a fraction of the runs. - # Also run less frequently on weekends. - - cron: 45 0,8,16 * * 1-5 - - cron: 45 4 * * 0,6 - - cron: 45 4,12,20 * * 1-5 - - cron: 45 12 * * 0,6 - cron: 29 8 * * * # about 1:29am PDT concurrency: From e25ea54d81f4e3b218d1308a5d1da6ca52afae70 Mon Sep 17 00:00:00 2001 From: Jack Taylor Date: Wed, 11 Sep 2024 17:21:39 +0000 Subject: [PATCH 0725/1018] [ROCm] [BUGFIX] Re-enable rocm-specific tuning parameters v2 (#133852) Small bug fix - https://github.com/pytorch/pytorch/pull/124592 replaced the torch.version.hip with device_props but made a mistake in porting the original logic. The original code was: `if torch.version.hip is not None:` Which was incorrectly replaced by: `if self.device_props.type != "hip":` Another occurence of https://github.com/pytorch/pytorch/pull/130617 Pull Request resolved: https://github.com/pytorch/pytorch/pull/133852 Approved by: https://github.com/masnesral, https://github.com/malfet --- torch/_inductor/runtime/triton_heuristics.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch/_inductor/runtime/triton_heuristics.py b/torch/_inductor/runtime/triton_heuristics.py index 0c499956ff7406..bdc6a1e10bcaeb 100644 --- a/torch/_inductor/runtime/triton_heuristics.py +++ b/torch/_inductor/runtime/triton_heuristics.py @@ -408,7 +408,7 @@ def _precompile_config(self, cfg: Config, warm_cache_only: bool): "num_stages": compile_meta["num_stages"], "debug": compile_meta["debug"], } - if self.device_props.type != "hip": + if self.device_props.type == "hip": if "waves_per_eu" in compile_meta: options["waves_per_eu"] = compile_meta["waves_per_eu"] if "matrix_instr_nonkdim" in compile_meta: From b6f08fe8c1494010cbc4332946e5fd0d82fdff7d Mon Sep 17 00:00:00 2001 From: Bob Ren Date: Tue, 10 Sep 2024 10:49:54 -0700 Subject: [PATCH 0726/1018] Don't use exception chaining for BackendCompilerFailed (#135545) Commandeered from https://github.com/pytorch/pytorch/pull/135496 as I'm now helping @ezyang ship dynamic float arguments in PT2. Pull Request resolved: https://github.com/pytorch/pytorch/pull/135545 Approved by: https://github.com/ezyang --- test/dynamo/test_logging.py | 9 --------- torch/_dynamo/output_graph.py | 4 +++- 2 files changed, 3 insertions(+), 10 deletions(-) diff --git a/test/dynamo/test_logging.py b/test/dynamo/test_logging.py index b71f8b7fce95bf..fe64ac745545f1 100644 --- a/test/dynamo/test_logging.py +++ b/test/dynamo/test_logging.py @@ -189,15 +189,6 @@ def throw(x): Traceback (most recent call last): File "test_logging.py", line N, in throw raise AssertionError -torch._inductor.exc.LoweringException: AssertionError: - target: aten.round.default - args[0]: TensorBox(StorageBox( - InputBuffer(name='primals_1', layout=FixedLayout('cpu', torch.float32, size=[1000, 1000], stride=[1000, 1])) - )) - -The above exception was the direct cause of the following exception: - -Traceback (most recent call last): torch._dynamo.exc.BackendCompilerFailed: backend='inductor' raised: LoweringException: AssertionError: target: aten.round.default diff --git a/torch/_dynamo/output_graph.py b/torch/_dynamo/output_graph.py index 5684de69601565..ba8cfa5a03dfba 100644 --- a/torch/_dynamo/output_graph.py +++ b/torch/_dynamo/output_graph.py @@ -1462,7 +1462,9 @@ def _call_user_compiler(self, gm: fx.GraphModule) -> CompiledFn: # aborting execution. raise e except Exception as e: - raise BackendCompilerFailed(self.compiler_fn, e) from e + raise BackendCompilerFailed(self.compiler_fn, e).with_traceback( + e.__traceback__ + ) from None signpost_event( "dynamo", From ee36dc911daa71dc82df8232a2bf485a0d9765fb Mon Sep 17 00:00:00 2001 From: Zain Rizvi Date: Wed, 11 Sep 2024 12:31:55 -0500 Subject: [PATCH 0727/1018] Support rolling over a percentage of workflows (#134816) In order to support adding a rollover percentage, this ended up being a complete rewrite of runner_determinator.py. Details of the new format are in the comments up top. On the plus side, this now includes some unit tests. Pull Request resolved: https://github.com/pytorch/pytorch/pull/134816 Approved by: https://github.com/PaliC, https://github.com/zxiiro --- .github/scripts/runner_determinator.py | 340 +++++++++++++------- .github/scripts/test_runner_determinator.py | 237 ++++++++++++++ .github/workflows/_runner-determinator.yml | 340 +++++++++++++------- 3 files changed, 701 insertions(+), 216 deletions(-) create mode 100644 .github/scripts/test_runner_determinator.py diff --git a/.github/scripts/runner_determinator.py b/.github/scripts/runner_determinator.py index e52b19c64b4e3c..641e438b784510 100644 --- a/.github/scripts/runner_determinator.py +++ b/.github/scripts/runner_determinator.py @@ -3,49 +3,94 @@ """ This runner determinator is used to determine which set of runners to run a GitHub job on. It uses the first comment of a GitHub issue (by default -https://github.com/pytorch/test-infra/issues/5132) as a user list to determine -which users will get their jobs to run on experimental runners. This user list -is also a comma separated list of additional features or experiments which the -user could be opted in to. +https://github.com/pytorch/test-infra/issues/5132) to define the configuration +of which runners should be used to run which job. + +The configuration has two parts, the settings and a list of opted-in users, +separated by a line containing "---". If the line is not present, the +settings are considered to be empty with only the second part, the user +list, defined. + +The first part is a YAML block that defines the rollout settings. This can be +used to define any settings that are needed to determine which runners to use. +It's fields are defined by the RolloutSettings class below. + +The second part is a list of users who are explicitly opted in to the LF fleet. +The user list is also a comma separated list of additional features or +experiments which the user could be opted in to. The user list has the following rules: -- Users are GitHub usernames with the @ prefix -- If the first line is a "*" then all users will use the new runners -- If the first line is a "!" then all users will use the old runners +- Users are GitHub usernames, which must start with the @ prefix - Each user is also a comma-separated list of features/experiments to enable -- A "#" prefix indicates the user is opted out of the new runners but is opting - into features/experiments. +- A "#" prefix opts the user out of all experiments + +Example config: + # A list of experiments that can be opted into. + # This defines the behavior they'll induce when opted into. + # Expected syntax is: + # [experiment_name]: # Name of the experiment. Also used for the label prefix. + # rollout_perc: [int] # % of workflows to run with this experiment when users are not opted in. + + experiments: + lf: + rollout_percent: 25 -Example user list: + --- - @User1 - @User2,amz2023 - #@UserOptOutOfNewRunner,amz2023 + # Opt-ins: + # Users can opt into the LF fleet by adding their GitHub username to this list + # and specifying experiments to enable in a comma-separated list. + # Experiments should be from the above list. + + @User1,lf,split_build + @User2,lf + @User3,split_build """ import logging import os +import random from argparse import ArgumentParser from logging import LogRecord -from typing import Any, Iterable +from typing import Any, Dict, Iterable, List, NamedTuple, Tuple +import yaml from github import Auth, Github from github.Issue import Issue -WORKFLOW_LABEL_META = "" # use meta runners +DEFAULT_LABEL_PREFIX = "" # use meta runners WORKFLOW_LABEL_LF = "lf." # use runners from the linux foundation WORKFLOW_LABEL_LF_CANARY = "lf.c." # use canary runners from the linux foundation -RUNNER_AMI_LEGACY = "" -RUNNER_AMI_AMZ2023 = "amz2023" - GITHUB_OUTPUT = os.getenv("GITHUB_OUTPUT", "") GH_OUTPUT_KEY_AMI = "runner-ami" GH_OUTPUT_KEY_LABEL_TYPE = "label-type" +SETTING_EXPERIMENTS = "experiments" + +LF_FLEET_EXPERIMENT = "lf" +CANARY_FLEET_SUFFIX = ".c" + + +class Experiment(NamedTuple): + rollout_perc: float = ( + 0 # Percentage of workflows to experiment on when user is not opted-in. + ) + + # Add more fields as needed + + +class Settings(NamedTuple): + """ + Settings for the experiments that can be opted into. + """ + + experiments: Dict[str, Experiment] = {} + + class ColorFormatter(logging.Formatter): """Color codes the log messages based on the log level""" @@ -172,85 +217,180 @@ def is_exception_branch(branch: str) -> bool: return branch.split("/")[0] in {"main", "nightly", "release", "landchecks"} -def get_fleet(rollout_state: str, workflow_requestors: Iterable[str]) -> str: +def load_yaml(yaml_text: str) -> Any: + try: + data = yaml.safe_load(yaml_text) + return data + except yaml.YAMLError as exc: + log.exception("Error loading YAML") + raise + + +def extract_settings_user_opt_in_from_text(rollout_state: str) -> Tuple[str, str]: """ - Determines if the job should run on the LF fleet or the Meta fleet + Extracts the text with settings, if any, and the opted in users from the rollout state. + + If the issue body contains "---" then the text above that is the settings + and the text below is the list of opted in users. - Returns: - The appropriate label prefix for the runner, corresponding to the fleet to use. - This gets prefixed to the very start of the runner label. + If it doesn't contain "---" then the settings are empty and the rest is the users. """ + rollout_state_parts = rollout_state.split("---") + if len(rollout_state_parts) >= 2: + return rollout_state_parts[0], rollout_state_parts[1] + else: + return "", rollout_state - try: - if rollout_state[0] == "!": - log.info("LF Workflows are disabled for everyone. Using meta runners.") - return WORKFLOW_LABEL_META - elif rollout_state[0] == "*": - log.info("LF Workflows are enabled for everyone. Using LF runners.") - return WORKFLOW_LABEL_LF - else: - all_opted_in_users = { - usr_raw.strip("\n\t@ ").split(",")[0] - for usr_raw in rollout_state.split() - } - opted_in_requestors = { - usr for usr in workflow_requestors if usr in all_opted_in_users - } - if opted_in_requestors: - log.info( - f"LF Workflows are enabled for {', '.join(opted_in_requestors)}. Using LF runners." - ) - return WORKFLOW_LABEL_LF - else: - log.info( - f"LF Workflows are disabled for {', '.join(workflow_requestors)}. Using meta runners." - ) - return WORKFLOW_LABEL_META - except Exception as e: - log.error( - f"Failed to get determine workflow type. Falling back to meta runners. Exception: {e}" - ) - return WORKFLOW_LABEL_META +class UserOptins(Dict[str, List[str]]): + """ + Dictionary of users with a list of features they have opted into + """ -def get_optin_feature( - rollout_state: str, workflow_requestors: Iterable[str], feature: str, fallback: str -) -> str: +def parse_user_opt_in_from_text(user_optin_text: str) -> UserOptins: """ - Used to dynamically opt in jobs to specific runner-type variants. + Parse the user opt-in text into a key value pair of username and the list of features they have opted into + + Users are GitHub usernames with the @ prefix. Each user is also a comma-separated list of features/experiments to enable. + - Example line: "@User1,lf,split_build" + - A "#" prefix indicates the user is opted out of all experiments + - Returns: - The runner-type's variant name if the user has opted in to the feature, otherwise returns an empty string. - This variant name is prefixed to the runner-type in the label. + """ + optins = UserOptins() + for user in user_optin_text.split("\n"): + user = user.strip("\r\n\t -") + if not user or not user.startswith("@"): + # Not a valid user. Skip + continue + + if user: + usr_name = user.split(",")[0].strip("@") + optins[usr_name] = [exp.strip(" ") for exp in user.split(",")[1:]] + + return optins + + +def parse_settings_from_text(settings_text: str) -> Settings: + """ + Parse the experiments from the issue body into a list of ExperimentSettings """ try: - userlist = {u.lstrip("#").strip("\n\t@ ") for u in rollout_state.split()} - all_opted_in_users = set() - for user in userlist: - for i in user.split(","): - if i == feature: - all_opted_in_users.add(user.split(",")[0]) - opted_in_requestors = { - usr for usr in workflow_requestors if usr in all_opted_in_users - } - - if opted_in_requestors: - log.info( - f"Feature {feature} is enabled for {', '.join(opted_in_requestors)}. Using feature {feature}." - ) - return feature - else: + if settings_text: + # Escape the backtick as well so that we can have the settings in a code block on the GH issue + # for easy reading + # Note: Using ascii for the backtick so that the cat step in _runner-determinator.yml doesn't choke on + # the backtick character in shell commands. + backtick = chr(96) # backtick character + settings_text = settings_text.strip(f"\r\n\t{backtick} ") + settings = load_yaml(settings_text) + + # For now we just load experiments. We can expand this if/when we add more settings + experiments = {} + + for exp_name, exp_settings in settings.get(SETTING_EXPERIMENTS).items(): + valid_settings = {} + for setting in exp_settings: + if setting not in Experiment._fields: + log.warning( + f"Unexpected setting in experiment: {setting} = {exp_settings[setting]}" + ) + else: + valid_settings[setting] = exp_settings[setting] + + experiments[exp_name] = Experiment(**valid_settings) + return Settings(experiments) + + except Exception: + log.exception("Failed to parse settings") + + return Settings() + + +def parse_settings(rollout_state: str) -> Settings: + """ + Parse settings, if any, from the rollout state. + + If the issue body contains "---" then the text above that is the settings + and the text below is the list of opted in users. + + If it doesn't contain "---" then the settings are empty and the default values are used. + """ + settings_text, _ = extract_settings_user_opt_in_from_text(rollout_state) + return parse_settings_from_text(settings_text) + + +def parse_users(rollout_state: str) -> UserOptins: + """ + Parse users from the rollout state. + + """ + _, users_text = extract_settings_user_opt_in_from_text(rollout_state) + return parse_user_opt_in_from_text(users_text) + + +def is_user_opted_in(user: str, user_optins: UserOptins, experiment_name: str) -> bool: + """ + Check if a user is opted into an experiment + """ + return experiment_name in user_optins.get(user, []) + + +def get_runner_prefix( + rollout_state: str, workflow_requestors: Iterable[str], is_canary: bool = False +) -> str: + settings = parse_settings(rollout_state) + user_optins = parse_users(rollout_state) + + fleet_prefix = "" + prefixes = [] + for experiment_name, experiment_settings in settings.experiments.items(): + enabled = False + + # Is any workflow_requestor opted in to this experiment? + opted_in_users = [ + requestor + for requestor in workflow_requestors + if is_user_opted_in(requestor, user_optins, experiment_name) + ] + + if opted_in_users: log.info( - f"Feature {feature} is disabled for {', '.join(workflow_requestors)}. Using fallback \"{fallback}\"." + f"{', '.join(opted_in_users)} have opted into experiment {experiment_name}." ) - return fallback + enabled = True + elif experiment_settings.rollout_perc: + # If no user is opted in, then we randomly enable the experiment based on the rollout percentage + if random.uniform(0, 100) <= experiment_settings.rollout_perc: + log.info( + f"Based on rollout percentage of {experiment_settings.rollout_perc}%, enabling experiment {experiment_name}." + ) + enabled = True + + if enabled: + label = experiment_name + if experiment_name == LF_FLEET_EXPERIMENT: + # We give some special treatment to the "lf" experiment since determines the fleet we use + # - If it's enabled, then we always list it's prefix first + # - If we're in the canary branch, then we append ".c" to the lf prefix + if is_canary: + label += CANARY_FLEET_SUFFIX + fleet_prefix = label + else: + prefixes.append(label) - except Exception as e: + if len(prefixes) > 1: log.error( - f'Failed to determine if user has opted-in to feature {feature}. Using fallback "{fallback}". Exception: {e}' + f"Only a fleet and one other experiment can be enabled for a job at any time. Enabling {prefixes[0]} and ignoring the rest, which are {', '.join(prefixes[1:])}" ) - return fallback + prefixes = prefixes[:1] + + # Fleet always comes first + if fleet_prefix: + prefixes.insert(0, fleet_prefix) + + return ".".join(prefixes) + "." if prefixes else "" def get_rollout_state_from_issue(github_token: str, repo: str, issue_num: int) -> str: @@ -268,9 +408,10 @@ def main() -> None: args = parse_args() if args.github_ref_type == "branch" and is_exception_branch(args.github_branch): - log.info(f"Exception branch: '{args.github_branch}', using meta runners") - label_type = WORKFLOW_LABEL_META - runner_ami = RUNNER_AMI_LEGACY + log.info( + f"Exception branch: '{args.github_branch}', using Meta runners and no experiments." + ) + runner_label_prefix = DEFAULT_LABEL_PREFIX else: try: rollout_state = get_rollout_state_from_issue( @@ -285,35 +426,18 @@ def main() -> None: args.github_branch, ) - label_type = get_fleet( - rollout_state, - ( - args.github_issue_owner, - username, - ), - ) - runner_ami = get_optin_feature( - rollout_state=rollout_state, - workflow_requestors=( - args.github_issue_owner, - username, - ), - feature=RUNNER_AMI_AMZ2023, - fallback=RUNNER_AMI_LEGACY, + is_canary = args.github_repo == "pytorch/pytorch-canary" + + runner_label_prefix = get_runner_prefix( + rollout_state, (args.github_issue_owner, username), is_canary ) + except Exception as e: log.error( - f"Failed to get issue. Falling back to meta runners. Exception: {e}" + f"Failed to get issue. Defaulting to Meta runners and no experiments. Exception: {e}" ) - label_type = WORKFLOW_LABEL_META - runner_ami = RUNNER_AMI_LEGACY - - # For Canary builds use canary runners - if args.github_repo == "pytorch/pytorch-canary" and label_type == WORKFLOW_LABEL_LF: - label_type = WORKFLOW_LABEL_LF_CANARY - set_github_output(GH_OUTPUT_KEY_LABEL_TYPE, label_type) - set_github_output(GH_OUTPUT_KEY_AMI, runner_ami) + set_github_output(GH_OUTPUT_KEY_LABEL_TYPE, runner_label_prefix) if __name__ == "__main__": diff --git a/.github/scripts/test_runner_determinator.py b/.github/scripts/test_runner_determinator.py new file mode 100644 index 00000000000000..b20b8a68cbe084 --- /dev/null +++ b/.github/scripts/test_runner_determinator.py @@ -0,0 +1,237 @@ +from unittest import main, TestCase +from unittest.mock import Mock, patch + +import runner_determinator as rd + + +class TestRunnerDeterminatorIssueParser(TestCase): + def test_parse_settings(self) -> None: + settings_text = """ + experiments: + lf: + rollout_perc: 25 + otherExp: + rollout_perc: 0 + --- + + Users: + @User1,lf + @User2,lf,otherExp + + """ + + settings = rd.parse_settings(settings_text) + + self.assertTupleEqual( + rd.Experiment(rollout_perc=25), + settings.experiments["lf"], + "lf settings not parsed correctly", + ) + self.assertTupleEqual( + rd.Experiment(rollout_perc=0), + settings.experiments["otherExp"], + "otherExp settings not parsed correctly", + ) + + def test_parse_settings_in_code_block(self) -> None: + settings_text = """ + + ``` + experiments: + lf: + rollout_perc: 25 + otherExp: + rollout_perc: 0 + + ``` + + --- + + Users: + @User1,lf + @User2,lf,otherExp + + """ + + settings = rd.parse_settings(settings_text) + + self.assertTupleEqual( + rd.Experiment(rollout_perc=25), + settings.experiments["lf"], + "lf settings not parsed correctly", + ) + self.assertTupleEqual( + rd.Experiment(rollout_perc=0), + settings.experiments["otherExp"], + "otherExp settings not parsed correctly", + ) + + def test_parse_users(self) -> None: + settings_text = """ + experiments: + lf: + rollout_perc: 0 + otherExp: + rollout_perc: 0 + --- + + Users: + @User1,lf + @User2,lf,otherExp + + """ + + users = rd.parse_users(settings_text) + self.assertDictEqual( + {"User1": ["lf"], "User2": ["lf", "otherExp"]}, + users, + "Users not parsed correctly", + ) + + def test_parse_users_without_settings(self) -> None: + settings_text = """ + + @User1,lf + @User2,lf,otherExp + + """ + + users = rd.parse_users(settings_text) + self.assertDictEqual( + {"User1": ["lf"], "User2": ["lf", "otherExp"]}, + users, + "Users not parsed correctly", + ) + + +class TestRunnerDeterminatorGetRunnerPrefix(TestCase): + def test_opted_in_user(self) -> None: + settings_text = """ + experiments: + lf: + rollout_perc: 0 + otherExp: + rollout_perc: 0 + --- + + Users: + @User1,lf + @User2,lf,otherExp + + """ + prefix = rd.get_runner_prefix(settings_text, ["User1"]) + self.assertEqual("lf.", prefix, "Runner prefix not correct for User1") + + def test_opted_in_user_two_experiments(self) -> None: + settings_text = """ + experiments: + lf: + rollout_perc: 0 + otherExp: + rollout_perc: 0 + --- + + Users: + @User1,lf + @User2,lf,otherExp + + """ + prefix = rd.get_runner_prefix(settings_text, ["User2"]) + self.assertEqual("lf.otherExp.", prefix, "Runner prefix not correct for User2") + + @patch("random.uniform", return_value=50) + def test_opted_out_user(self, mock_uniform: Mock) -> None: + settings_text = """ + experiments: + lf: + rollout_perc: 25 + otherExp: + rollout_perc: 25 + --- + + Users: + @User1,lf + @User2,lf,otherExp + + """ + prefix = rd.get_runner_prefix(settings_text, ["User3"]) + self.assertEqual("", prefix, "Runner prefix not correct for user") + + @patch("random.uniform", return_value=10) + def test_opted_out_user_was_pulled_in_by_rollout(self, mock_uniform: Mock) -> None: + settings_text = """ + experiments: + lf: + rollout_perc: 25 + otherExp: + rollout_perc: 25 + --- + + Users: + @User1,lf + @User2,lf,otherExp + + """ + + # User3 is opted out, but is pulled into both experiments by the 10% rollout + prefix = rd.get_runner_prefix(settings_text, ["User3"]) + self.assertEqual("lf.otherExp.", prefix, "Runner prefix not correct for user") + + def test_lf_prefix_always_comes_first(self) -> None: + settings_text = """ + experiments: + otherExp: + rollout_perc: 0 + lf: + rollout_perc: 0 + --- + + Users: + @User1,lf + @User2,otherExp,lf + + """ + + prefix = rd.get_runner_prefix(settings_text, ["User2"]) + self.assertEqual("lf.otherExp.", prefix, "Runner prefix not correct for user") + + def test_ignores_commented_users(self) -> None: + settings_text = """ + experiments: + lf: + rollout_perc: 0 + otherExp: + rollout_perc: 0 + --- + + Users: + #@User1,lf + @User2,lf,otherExp + + """ + + prefix = rd.get_runner_prefix(settings_text, ["User1"]) + self.assertEqual("", prefix, "Runner prefix not correct for user") + + def test_ignores_extra_experiments(self) -> None: + settings_text = """ + experiments: + lf: + rollout_perc: 0 + otherExp: + rollout_perc: 0 + foo: + rollout_perc: 0 + --- + + Users: + @User1,lf,otherExp,foo + + """ + + prefix = rd.get_runner_prefix(settings_text, ["User1"]) + self.assertEqual("lf.otherExp.", prefix, "Runner prefix not correct for user") + + +if __name__ == "__main__": + main() diff --git a/.github/workflows/_runner-determinator.yml b/.github/workflows/_runner-determinator.yml index 8ba709358c5ea1..862ceceec181fb 100644 --- a/.github/workflows/_runner-determinator.yml +++ b/.github/workflows/_runner-determinator.yml @@ -62,49 +62,94 @@ jobs: """ This runner determinator is used to determine which set of runners to run a GitHub job on. It uses the first comment of a GitHub issue (by default - https://github.com/pytorch/test-infra/issues/5132) as a user list to determine - which users will get their jobs to run on experimental runners. This user list - is also a comma separated list of additional features or experiments which the - user could be opted in to. + https://github.com/pytorch/test-infra/issues/5132) to define the configuration + of which runners should be used to run which job. + + The configuration has two parts, the settings and a list of opted-in users, + separated by a line containing "---". If the line is not present, the + settings are considered to be empty with only the second part, the user + list, defined. + + The first part is a YAML block that defines the rollout settings. This can be + used to define any settings that are needed to determine which runners to use. + It's fields are defined by the RolloutSettings class below. + + The second part is a list of users who are explicitly opted in to the LF fleet. + The user list is also a comma separated list of additional features or + experiments which the user could be opted in to. The user list has the following rules: - - Users are GitHub usernames with the @ prefix - - If the first line is a "*" then all users will use the new runners - - If the first line is a "!" then all users will use the old runners + - Users are GitHub usernames, which must start with the @ prefix - Each user is also a comma-separated list of features/experiments to enable - - A "#" prefix indicates the user is opted out of the new runners but is opting - into features/experiments. + - A "#" prefix opts the user out of all experiments + + Example config: + # A list of experiments that can be opted into. + # This defines the behavior they'll induce when opted into. + # Expected syntax is: + # [experiment_name]: # Name of the experiment. Also used for the label prefix. + # rollout_perc: [int] # % of workflows to run with this experiment when users are not opted in. + + experiments: + lf: + rollout_percent: 25 - Example user list: + --- - @User1 - @User2,amz2023 - #@UserOptOutOfNewRunner,amz2023 + # Opt-ins: + # Users can opt into the LF fleet by adding their GitHub username to this list + # and specifying experiments to enable in a comma-separated list. + # Experiments should be from the above list. + + @User1,lf,split_build + @User2,lf + @User3,split_build """ import logging import os + import random from argparse import ArgumentParser from logging import LogRecord - from typing import Any, Iterable + from typing import Any, Dict, Iterable, List, NamedTuple, Tuple + import yaml from github import Auth, Github from github.Issue import Issue - WORKFLOW_LABEL_META = "" # use meta runners + DEFAULT_LABEL_PREFIX = "" # use meta runners WORKFLOW_LABEL_LF = "lf." # use runners from the linux foundation WORKFLOW_LABEL_LF_CANARY = "lf.c." # use canary runners from the linux foundation - RUNNER_AMI_LEGACY = "" - RUNNER_AMI_AMZ2023 = "amz2023" - GITHUB_OUTPUT = os.getenv("GITHUB_OUTPUT", "") GH_OUTPUT_KEY_AMI = "runner-ami" GH_OUTPUT_KEY_LABEL_TYPE = "label-type" + SETTING_EXPERIMENTS = "experiments" + + LF_FLEET_EXPERIMENT = "lf" + CANARY_FLEET_SUFFIX = ".c" + + + class Experiment(NamedTuple): + rollout_perc: float = ( + 0 # Percentage of workflows to experiment on when user is not opted-in. + ) + + # Add more fields as needed + + + class Settings(NamedTuple): + """ + Settings for the experiments that can be opted into. + """ + + experiments: Dict[str, Experiment] = {} + + class ColorFormatter(logging.Formatter): """Color codes the log messages based on the log level""" @@ -231,85 +276,180 @@ jobs: return branch.split("/")[0] in {"main", "nightly", "release", "landchecks"} - def get_fleet(rollout_state: str, workflow_requestors: Iterable[str]) -> str: + def load_yaml(yaml_text: str) -> Any: + try: + data = yaml.safe_load(yaml_text) + return data + except yaml.YAMLError as exc: + log.exception("Error loading YAML") + raise + + + def extract_settings_user_opt_in_from_text(rollout_state: str) -> Tuple[str, str]: """ - Determines if the job should run on the LF fleet or the Meta fleet + Extracts the text with settings, if any, and the opted in users from the rollout state. + + If the issue body contains "---" then the text above that is the settings + and the text below is the list of opted in users. - Returns: - The appropriate label prefix for the runner, corresponding to the fleet to use. - This gets prefixed to the very start of the runner label. + If it doesn't contain "---" then the settings are empty and the rest is the users. """ + rollout_state_parts = rollout_state.split("---") + if len(rollout_state_parts) >= 2: + return rollout_state_parts[0], rollout_state_parts[1] + else: + return "", rollout_state - try: - if rollout_state[0] == "!": - log.info("LF Workflows are disabled for everyone. Using meta runners.") - return WORKFLOW_LABEL_META - elif rollout_state[0] == "*": - log.info("LF Workflows are enabled for everyone. Using LF runners.") - return WORKFLOW_LABEL_LF - else: - all_opted_in_users = { - usr_raw.strip("\n\t@ ").split(",")[0] - for usr_raw in rollout_state.split() - } - opted_in_requestors = { - usr for usr in workflow_requestors if usr in all_opted_in_users - } - if opted_in_requestors: - log.info( - f"LF Workflows are enabled for {', '.join(opted_in_requestors)}. Using LF runners." - ) - return WORKFLOW_LABEL_LF - else: - log.info( - f"LF Workflows are disabled for {', '.join(workflow_requestors)}. Using meta runners." - ) - return WORKFLOW_LABEL_META - except Exception as e: - log.error( - f"Failed to get determine workflow type. Falling back to meta runners. Exception: {e}" - ) - return WORKFLOW_LABEL_META + class UserOptins(Dict[str, List[str]]): + """ + Dictionary of users with a list of features they have opted into + """ - def get_optin_feature( - rollout_state: str, workflow_requestors: Iterable[str], feature: str, fallback: str - ) -> str: + def parse_user_opt_in_from_text(user_optin_text: str) -> UserOptins: """ - Used to dynamically opt in jobs to specific runner-type variants. + Parse the user opt-in text into a key value pair of username and the list of features they have opted into + + Users are GitHub usernames with the @ prefix. Each user is also a comma-separated list of features/experiments to enable. + - Example line: "@User1,lf,split_build" + - A "#" prefix indicates the user is opted out of all experiments + - Returns: - The runner-type's variant name if the user has opted in to the feature, otherwise returns an empty string. - This variant name is prefixed to the runner-type in the label. + """ + optins = UserOptins() + for user in user_optin_text.split("\n"): + user = user.strip("\r\n\t -") + if not user or not user.startswith("@"): + # Not a valid user. Skip + continue + + if user: + usr_name = user.split(",")[0].strip("@") + optins[usr_name] = [exp.strip(" ") for exp in user.split(",")[1:]] + + return optins + + + def parse_settings_from_text(settings_text: str) -> Settings: + """ + Parse the experiments from the issue body into a list of ExperimentSettings """ try: - userlist = {u.lstrip("#").strip("\n\t@ ") for u in rollout_state.split()} - all_opted_in_users = set() - for user in userlist: - for i in user.split(","): - if i == feature: - all_opted_in_users.add(user.split(",")[0]) - opted_in_requestors = { - usr for usr in workflow_requestors if usr in all_opted_in_users - } - - if opted_in_requestors: - log.info( - f"Feature {feature} is enabled for {', '.join(opted_in_requestors)}. Using feature {feature}." - ) - return feature - else: + if settings_text: + # Escape the backtick as well so that we can have the settings in a code block on the GH issue + # for easy reading + # Note: Using ascii for the backtick so that the cat step in _runner-determinator.yml doesn't choke on + # the backtick character in shell commands. + backtick = chr(96) # backtick character + settings_text = settings_text.strip(f"\r\n\t{backtick} ") + settings = load_yaml(settings_text) + + # For now we just load experiments. We can expand this if/when we add more settings + experiments = {} + + for exp_name, exp_settings in settings.get(SETTING_EXPERIMENTS).items(): + valid_settings = {} + for setting in exp_settings: + if setting not in Experiment._fields: + log.warning( + f"Unexpected setting in experiment: {setting} = {exp_settings[setting]}" + ) + else: + valid_settings[setting] = exp_settings[setting] + + experiments[exp_name] = Experiment(**valid_settings) + return Settings(experiments) + + except Exception: + log.exception("Failed to parse settings") + + return Settings() + + + def parse_settings(rollout_state: str) -> Settings: + """ + Parse settings, if any, from the rollout state. + + If the issue body contains "---" then the text above that is the settings + and the text below is the list of opted in users. + + If it doesn't contain "---" then the settings are empty and the default values are used. + """ + settings_text, _ = extract_settings_user_opt_in_from_text(rollout_state) + return parse_settings_from_text(settings_text) + + + def parse_users(rollout_state: str) -> UserOptins: + """ + Parse users from the rollout state. + + """ + _, users_text = extract_settings_user_opt_in_from_text(rollout_state) + return parse_user_opt_in_from_text(users_text) + + + def is_user_opted_in(user: str, user_optins: UserOptins, experiment_name: str) -> bool: + """ + Check if a user is opted into an experiment + """ + return experiment_name in user_optins.get(user, []) + + + def get_runner_prefix( + rollout_state: str, workflow_requestors: Iterable[str], is_canary: bool = False + ) -> str: + settings = parse_settings(rollout_state) + user_optins = parse_users(rollout_state) + + fleet_prefix = "" + prefixes = [] + for experiment_name, experiment_settings in settings.experiments.items(): + enabled = False + + # Is any workflow_requestor opted in to this experiment? + opted_in_users = [ + requestor + for requestor in workflow_requestors + if is_user_opted_in(requestor, user_optins, experiment_name) + ] + + if opted_in_users: log.info( - f"Feature {feature} is disabled for {', '.join(workflow_requestors)}. Using fallback \"{fallback}\"." + f"{', '.join(opted_in_users)} have opted into experiment {experiment_name}." ) - return fallback + enabled = True + elif experiment_settings.rollout_perc: + # If no user is opted in, then we randomly enable the experiment based on the rollout percentage + if random.uniform(0, 100) <= experiment_settings.rollout_perc: + log.info( + f"Based on rollout percentage of {experiment_settings.rollout_perc}%, enabling experiment {experiment_name}." + ) + enabled = True + + if enabled: + label = experiment_name + if experiment_name == LF_FLEET_EXPERIMENT: + # We give some special treatment to the "lf" experiment since determines the fleet we use + # - If it's enabled, then we always list it's prefix first + # - If we're in the canary branch, then we append ".c" to the lf prefix + if is_canary: + label += CANARY_FLEET_SUFFIX + fleet_prefix = label + else: + prefixes.append(label) - except Exception as e: + if len(prefixes) > 1: log.error( - f'Failed to determine if user has opted-in to feature {feature}. Using fallback "{fallback}". Exception: {e}' + f"Only a fleet and one other experiment can be enabled for a job at any time. Enabling {prefixes[0]} and ignoring the rest, which are {', '.join(prefixes[1:])}" ) - return fallback + prefixes = prefixes[:1] + + # Fleet always comes first + if fleet_prefix: + prefixes.insert(0, fleet_prefix) + + return ".".join(prefixes) + "." if prefixes else "" def get_rollout_state_from_issue(github_token: str, repo: str, issue_num: int) -> str: @@ -327,9 +467,10 @@ jobs: args = parse_args() if args.github_ref_type == "branch" and is_exception_branch(args.github_branch): - log.info(f"Exception branch: '{args.github_branch}', using meta runners") - label_type = WORKFLOW_LABEL_META - runner_ami = RUNNER_AMI_LEGACY + log.info( + f"Exception branch: '{args.github_branch}', using Meta runners and no experiments." + ) + runner_label_prefix = DEFAULT_LABEL_PREFIX else: try: rollout_state = get_rollout_state_from_issue( @@ -344,35 +485,18 @@ jobs: args.github_branch, ) - label_type = get_fleet( - rollout_state, - ( - args.github_issue_owner, - username, - ), - ) - runner_ami = get_optin_feature( - rollout_state=rollout_state, - workflow_requestors=( - args.github_issue_owner, - username, - ), - feature=RUNNER_AMI_AMZ2023, - fallback=RUNNER_AMI_LEGACY, + is_canary = args.github_repo == "pytorch/pytorch-canary" + + runner_label_prefix = get_runner_prefix( + rollout_state, (args.github_issue_owner, username), is_canary ) + except Exception as e: log.error( - f"Failed to get issue. Falling back to meta runners. Exception: {e}" + f"Failed to get issue. Defaulting to Meta runners and no experiments. Exception: {e}" ) - label_type = WORKFLOW_LABEL_META - runner_ami = RUNNER_AMI_LEGACY - - # For Canary builds use canary runners - if args.github_repo == "pytorch/pytorch-canary" and label_type == WORKFLOW_LABEL_LF: - label_type = WORKFLOW_LABEL_LF_CANARY - set_github_output(GH_OUTPUT_KEY_LABEL_TYPE, label_type) - set_github_output(GH_OUTPUT_KEY_AMI, runner_ami) + set_github_output(GH_OUTPUT_KEY_LABEL_TYPE, runner_label_prefix) if __name__ == "__main__": From 306987d62f719be08f57c06fd7cc128e4c2196cf Mon Sep 17 00:00:00 2001 From: "Edward Z. Yang" Date: Wed, 11 Sep 2024 07:54:34 -0700 Subject: [PATCH 0728/1018] Log full exception trace when error raised in Dynamo (#135697) Signed-off-by: Edward Z. Yang Pull Request resolved: https://github.com/pytorch/pytorch/pull/135697 Approved by: https://github.com/Skylion007 --- test/dynamo/test_structured_trace.py | 2 ++ torch/_dynamo/convert_frame.py | 9 +++++++++ 2 files changed, 11 insertions(+) diff --git a/test/dynamo/test_structured_trace.py b/test/dynamo/test_structured_trace.py index 2ec2310acf3274..cdb7bba77fe91c 100644 --- a/test/dynamo/test_structured_trace.py +++ b/test/dynamo/test_structured_trace.py @@ -298,6 +298,7 @@ def test_dynamo_error(self): {"describe_storage": {"id": 0, "describer_id": "ID", "size": 4000000}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0} {"describe_tensor": {"id": 0, "ndim": 2, "dtype": "torch.float32", "device": "device(type='cpu')", "size": [1000, 1000], "is_leaf": true, "requires_grad": true, "stride": [1000, 1], "storage": 0, "view_func": "VIEW_FUNC", "describer_id": "ID"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0} {"describe_source": {"describer_id": "ID", "id": 0, "source": "L['a']"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0} +{"artifact": {"name": "dynamo_error", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"compilation_metrics": "METRICS", "frame_id": 0, "frame_compile_id": 0, "attempt": 0} """, # noqa: B950 ) @@ -337,6 +338,7 @@ def throw(x): {"aot_backward_graph": {}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"artifact": {"name": "fx_graph_runnable", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"inductor_post_grad_graph": {}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} +{"artifact": {"name": "dynamo_error", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"compilation_metrics": "METRICS", "frame_id": 0, "frame_compile_id": 0, "attempt": 0} """, # noqa: B950 ) diff --git a/torch/_dynamo/convert_frame.py b/torch/_dynamo/convert_frame.py index d15f5c4aa43dcf..3ab2937a9496d4 100644 --- a/torch/_dynamo/convert_frame.py +++ b/torch/_dynamo/convert_frame.py @@ -929,6 +929,15 @@ def format_guard_failures() -> str: # NB: e's msg is mutated here to add user stack, but we DON'T want # that stack in the Scuba logged fail_reason exception_handler(e, code, frame, export=export) + # NB: this is the post-mutation exception + torch._logging.trace_structured( + "artifact", + metadata_fn=lambda: { + "name": "dynamo_error", + "encoding": "string", + }, + payload_fn=lambda: traceback.format_exc(), + ) fail_user_frame_filename, fail_user_frame_lineno = exc.get_exc_message( e, compile_id ) From cdf110286ba3fbd62622cab39e547a1295030c86 Mon Sep 17 00:00:00 2001 From: rzou Date: Tue, 10 Sep 2024 08:00:03 -0700 Subject: [PATCH 0729/1018] Flip triton kernel default layout constraint to "needs_fixed_stride_order" (#135581) This is to match the default layout constraint for custom operators. By default, Inductor should match the stride order of inputs to a triton kernel. Test Plan: - existing tests Pull Request resolved: https://github.com/pytorch/pytorch/pull/135581 Approved by: https://github.com/eellison ghstack dependencies: #135530 --- torch/_inductor/config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch/_inductor/config.py b/torch/_inductor/config.py index 4dc7f9b7964fcf..d3d58c551ec374 100644 --- a/torch/_inductor/config.py +++ b/torch/_inductor/config.py @@ -74,7 +74,7 @@ def autotune_remote_cache_default() -> Optional[bool]: # The default layout constraint for user-defined triton kernels. # See "The default layout constraint for custom operators" for options. -triton_kernel_default_layout_constraint = "flexible_layout" +triton_kernel_default_layout_constraint = "needs_fixed_stride_order" # use cpp wrapper instead of python wrapper cpp_wrapper = os.environ.get("TORCHINDUCTOR_CPP_WRAPPER", "0") == "1" From a199b13e885a5c5678e9c058c35ec719674ef05a Mon Sep 17 00:00:00 2001 From: Shangdi Yu Date: Wed, 11 Sep 2024 19:23:08 +0000 Subject: [PATCH 0730/1018] Replace capture_pre_autograd_graph with export_for_training in torch tests (#135623) Summary: as title Test Plan: ``` buck2 run 'fbcode//mode/dev-nosan' fbcode//caffe2/test:test_export -- -r test_conv_dynamic buck2 run 'fbcode//mode/dev-nosan' fbcode//caffe2/test:fx -- -r matcher buck2 run 'fbcode//mode/dev-nosan' fbcode//caffe2/test/quantization:test_quantization -- -r x86 ``` CI Differential Revision: D62448302 Pull Request resolved: https://github.com/pytorch/pytorch/pull/135623 Approved by: https://github.com/tugsbayasgalan --- test/export/test_export.py | 6 ++--- test/fx/test_matcher_utils.py | 23 +++++++++---------- test/quantization/pt2e/test_duplicate_dq.py | 6 ++--- .../pt2e/test_metadata_porting.py | 4 ++-- .../pt2e/test_x86inductor_quantizer.py | 18 +++++++-------- test/test_model_exports_to_core_aten.py | 2 +- .../utils/matcher_with_name_node_map_utils.py | 4 ++-- 7 files changed, 30 insertions(+), 33 deletions(-) diff --git a/test/export/test_export.py b/test/export/test_export.py index 7ee5e5c982b0e4..674aa0060dced3 100644 --- a/test/export/test_export.py +++ b/test/export/test_export.py @@ -509,11 +509,9 @@ def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: args = (torch.randn(15, 3, 256, 256), torch.ones(15, 32, 256, 256)) self.assertEqual(exported_program.module()(*args), m(*args)) - from torch._export import capture_pre_autograd_graph - - gm: torch.fx.GraphModule = capture_pre_autograd_graph( + gm: torch.fx.GraphModule = torch.export.export_for_training( m, args=example_args, dynamic_shapes=dynamic_shapes - ) + ).module() args = (torch.randn(17, 3, 256, 256), torch.ones(17, 32, 256, 256)) self.assertEqual(gm(*args), m(*args)) diff --git a/test/fx/test_matcher_utils.py b/test/fx/test_matcher_utils.py index 83db868e4d5a97..f1cb6105b94ff2 100644 --- a/test/fx/test_matcher_utils.py +++ b/test/fx/test_matcher_utils.py @@ -6,6 +6,7 @@ import torch import torch.nn.functional as F +from torch.export import export_for_training from torch.fx import symbolic_trace from torch.fx.experimental.proxy_tensor import make_fx @@ -167,13 +168,13 @@ def pattern(x, weight): relu_mul_by_two = relu * 2 return relu, relu_mul_by_two, {"conv": conv, "relu": relu} - from torch._export import capture_pre_autograd_graph - example_inputs = ( torch.randn(1, 3, 3, 3) * 10, torch.randn(3, 3, 3, 3), ) - pattern_gm = capture_pre_autograd_graph(WrapperModule(pattern), example_inputs) + pattern_gm = export_for_training( + WrapperModule(pattern), example_inputs + ).module() before_split_res = pattern_gm(*example_inputs) pattern_gm, name_node_map = _split_to_graph_and_name_node_map(pattern_gm) after_split_res = pattern_gm(*example_inputs) @@ -198,17 +199,17 @@ def pattern(x, weight): relu_mul_by_two = relu * 2 return relu, relu_mul_by_two, {"conv": conv, "relu": relu} - from torch._export import capture_pre_autograd_graph - example_inputs = ( torch.randn(1, 3, 3, 3) * 10, torch.randn(3, 3, 3, 3), ) - pattern_gm = capture_pre_autograd_graph(WrapperModule(pattern), example_inputs) + pattern_gm = export_for_training( + WrapperModule(pattern), example_inputs + ).module() matcher = SubgraphMatcherWithNameNodeMap(pattern_gm) - target_gm = capture_pre_autograd_graph( + target_gm = export_for_training( WrapperModule(target_graph), example_inputs - ) + ).module() internal_matches = matcher.match(target_gm.graph) for internal_match in internal_matches: name_node_map = internal_match.name_node_map @@ -246,12 +247,10 @@ def forward(self, x): # nn.Parameter is not an allowed output type in dynamo return linear, {"linear": linear, "x": x} - from torch._export import capture_pre_autograd_graph - example_inputs = (torch.randn(3, 5),) - pattern_gm = capture_pre_autograd_graph(Pattern(), example_inputs) + pattern_gm = export_for_training(Pattern(), example_inputs).module() matcher = SubgraphMatcherWithNameNodeMap(pattern_gm) - target_gm = capture_pre_autograd_graph(M(), example_inputs) + target_gm = export_for_training(M(), example_inputs).module() internal_matches = matcher.match(target_gm.graph) for internal_match in internal_matches: name_node_map = internal_match.name_node_map diff --git a/test/quantization/pt2e/test_duplicate_dq.py b/test/quantization/pt2e/test_duplicate_dq.py index 905098c3e6aca0..e2b7236d2ef093 100644 --- a/test/quantization/pt2e/test_duplicate_dq.py +++ b/test/quantization/pt2e/test_duplicate_dq.py @@ -4,7 +4,6 @@ from typing import Any, Dict import torch -from torch._export import capture_pre_autograd_graph from torch.ao.quantization.observer import ( HistogramObserver, MinMaxObserver, @@ -24,6 +23,7 @@ OP_TO_ANNOTATOR, QuantizationConfig, ) +from torch.export import export_for_training from torch.testing._internal.common_quantization import QuantizationTestCase from torch.testing._internal.common_utils import IS_WINDOWS @@ -100,10 +100,10 @@ def _test_duplicate_dq( # program capture m = copy.deepcopy(m_eager) - m = capture_pre_autograd_graph( + m = export_for_training( m, example_inputs, - ) + ).module() m = prepare_pt2e(m, quantizer) # Calibrate diff --git a/test/quantization/pt2e/test_metadata_porting.py b/test/quantization/pt2e/test_metadata_porting.py index 5f2d1e2d3cf574..27df58e159f3b6 100644 --- a/test/quantization/pt2e/test_metadata_porting.py +++ b/test/quantization/pt2e/test_metadata_porting.py @@ -99,10 +99,10 @@ def _test_metadata_porting( # program capture m = copy.deepcopy(m_eager) - m = torch._export.capture_pre_autograd_graph( + m = torch.export.export_for_training( m, example_inputs, - ) + ).module() m = prepare_pt2e(m, quantizer) # Calibrate diff --git a/test/quantization/pt2e/test_x86inductor_quantizer.py b/test/quantization/pt2e/test_x86inductor_quantizer.py index aa254762db114f..8b357bd9b311ae 100644 --- a/test/quantization/pt2e/test_x86inductor_quantizer.py +++ b/test/quantization/pt2e/test_x86inductor_quantizer.py @@ -7,7 +7,6 @@ import torch import torch.ao.quantization.quantizer.x86_inductor_quantizer as xiq import torch.nn as nn -from torch._export import capture_pre_autograd_graph from torch.ao.quantization import ObserverBase from torch.ao.quantization.quantize_pt2e import ( convert_pt2e, @@ -18,6 +17,7 @@ QUANT_ANNOTATION_KEY, X86InductorQuantizer, ) +from torch.export import export_for_training from torch.testing._internal.common_quantization import ( NodeSpec as ns, QuantizationTestCase, @@ -550,10 +550,10 @@ def _test_quantizer( # program capture m = copy.deepcopy(m_eager) - m = capture_pre_autograd_graph( + m = export_for_training( m, example_inputs, - ) + ).module() # QAT Model failed to deepcopy export_model = m if is_qat else copy.deepcopy(m) @@ -1222,7 +1222,7 @@ def _test_linear_unary_helper( Test pattern of linear with unary post ops (e.g. relu) with X86InductorQuantizer. """ use_bias_list = [True, False] - # TODO test for inplace add after refactoring of capture_pre_autograd_graph + # TODO test for inplace add after refactoring of export_for_training inplace_list = [False] if post_op_algo_list is None: post_op_algo_list = [None] @@ -1362,7 +1362,7 @@ def _test_linear_binary_helper(self, is_qat=False, is_dynamic=False): Currently, only add as binary post op is supported. """ linear_pos_list = [NodePosType.left, NodePosType.right, NodePosType.both] - # TODO test for inplace add after refactoring of capture_pre_autograd_graph + # TODO test for inplace add after refactoring of export_for_training inplace_add_list = [False] example_inputs = (torch.randn(2, 16),) quantizer = X86InductorQuantizer().set_global( @@ -1466,7 +1466,7 @@ def test_linear_binary2(self): Since linear_1 has 2 users, we should annotate linear_2 for binary fusion instead of linear_1 """ example_inputs = (torch.randn(2, 16),) - # TODO test for inplace add after refactoring of capture_pre_autograd_graph + # TODO test for inplace add after refactoring of export_for_training inplace_add_list = [False] is_qat_list = [False, True] is_dynamic_list = [False, True] @@ -1535,9 +1535,9 @@ def _test_linear_binary_unary_helper(self, is_qat=False, is_dynamic=False): Currently, only add as binary post op and relu as unary post op are supported. """ linear_pos_list = [NodePosType.left, NodePosType.right, NodePosType.both] - # TODO test for inplace add after refactoring of capture_pre_autograd_graph + # TODO test for inplace add after refactoring of export_for_training inplace_add_list = [False] - # TODO test for inplace relu after refactoring of capture_pre_autograd_graph + # TODO test for inplace relu after refactoring of export_for_training inplace_relu_list = [False] example_inputs = (torch.randn(2, 16),) quantizer = X86InductorQuantizer().set_global( @@ -2110,7 +2110,7 @@ def forward(self, x): ) example_inputs = (torch.randn(2, 2),) m = M().eval() - m = capture_pre_autograd_graph(m, example_inputs) + m = export_for_training(m, example_inputs).module() m = prepare_pt2e(m, quantizer) # Use a linear count instead of names because the names might change, but # the order should be the same. diff --git a/test/test_model_exports_to_core_aten.py b/test/test_model_exports_to_core_aten.py index bbf0a8ba3c04ac..aae14c28b8d6e0 100644 --- a/test/test_model_exports_to_core_aten.py +++ b/test/test_model_exports_to_core_aten.py @@ -27,7 +27,7 @@ def test_vit_aten_export(self): m = m.eval() input_shape = (1, 3, 224, 224) example_inputs = (torch.randn(input_shape),) - m = export.capture_pre_autograd_graph(m, copy.deepcopy(example_inputs)) + m = torch.export.export_for_training(m, copy.deepcopy(example_inputs)).module() m(*example_inputs) m = export.export(m, copy.deepcopy(example_inputs)) ops = _get_ops_list(m.graph_module) diff --git a/torch/fx/passes/utils/matcher_with_name_node_map_utils.py b/torch/fx/passes/utils/matcher_with_name_node_map_utils.py index 8482dca74b1002..78b063ff8a7aaa 100644 --- a/torch/fx/passes/utils/matcher_with_name_node_map_utils.py +++ b/torch/fx/passes/utils/matcher_with_name_node_map_utils.py @@ -61,8 +61,8 @@ def target_graph(x, weight): relu *= 2 return relu - pattern_gm = capture_pre_autograd_graph(pattern, example_inputs) - target_gm = capture_pre_autograd_graph(target_graph, example_inputs) + pattern_gm = export_for_training(pattern, example_inputs).module() + target_gm = export_for_training(target_graph, example_inputs).module() matcher = SubgraphMatcherWithNameNodeMap(pattern_gm) matches = matcher.match(target_gm) for match in matches: From 9e43d5903c9ca9acfafbdbc60b7e5884ddf79c81 Mon Sep 17 00:00:00 2001 From: Tom Ritchford Date: Tue, 10 Sep 2024 13:33:01 +0000 Subject: [PATCH 0731/1018] Add decomposition for transpose_copy (#130943) * Extracted from #128416 Pull Request resolved: https://github.com/pytorch/pytorch/pull/130943 Approved by: https://github.com/amjames, https://github.com/eellison --- ...asDecompTest.test_has_decomposition.expect | 2 -- test/functorch/test_ops.py | 2 ++ test/functorch/test_vmap.py | 1 + test/test_mps.py | 1 + tools/autograd/gen_variable_type.py | 1 + torch/_decomp/__init__.py | 1 + torch/_refs/__init__.py | 2 ++ .../_internal/common_methods_invocations.py | 27 +++++++++++++++++++ 8 files changed, 35 insertions(+), 2 deletions(-) diff --git a/test/expect/HasDecompTest.test_has_decomposition.expect b/test/expect/HasDecompTest.test_has_decomposition.expect index 58d52f34345e6f..5b491d1bc73f80 100644 --- a/test/expect/HasDecompTest.test_has_decomposition.expect +++ b/test/expect/HasDecompTest.test_has_decomposition.expect @@ -1301,8 +1301,6 @@ aten::to_padded_tensor.out aten::topk aten::topk.values aten::transpose_ -aten::transpose_copy.int -aten::transpose_copy.int_out aten::triangular_solve aten::triangular_solve.X aten::unbind_copy.int diff --git a/test/functorch/test_ops.py b/test/functorch/test_ops.py index 03744b7a8ef8fe..ad48b7581ea053 100644 --- a/test/functorch/test_ops.py +++ b/test/functorch/test_ops.py @@ -1429,6 +1429,7 @@ def test_vmapjvpall(self, device, dtype, op): xfail("masked.cumprod", ""), xfail("renorm"), # hit vmap fallback, which is disabled xfail("t_copy"), + xfail("transpose_copy"), xfail("unsqueeze_copy"), } ), @@ -1567,6 +1568,7 @@ def test(): "index_fill" ), # aten::_unique hit the vmap fallback which is currently disabled xfail("t_copy"), + xfail("transpose_copy"), xfail("unsqueeze_copy"), } ), diff --git a/test/functorch/test_vmap.py b/test/functorch/test_vmap.py index 7b7ae117f95c63..31bae28eaca7a2 100644 --- a/test/functorch/test_vmap.py +++ b/test/functorch/test_vmap.py @@ -4446,6 +4446,7 @@ def test_vmap_exhaustive(self, device, dtype, op): xfail("resize_as_"), xfail("take"), xfail("tensor_split"), + xfail("transpose_copy"), xfail("to_sparse"), # TypeError: expected Tensor as element 0 in argument 0, but got float xfail("item"), diff --git a/test/test_mps.py b/test/test_mps.py index eb99a6afdc6490..2c3039a1602a77 100644 --- a/test/test_mps.py +++ b/test/test_mps.py @@ -344,6 +344,7 @@ def mps_ops_modifier(ops): 'tanh', 'tensor_split', 'transpose', + 'transpose_copy', 'T', 'unbind', 'unflatten', diff --git a/tools/autograd/gen_variable_type.py b/tools/autograd/gen_variable_type.py index 4bec1871ae483e..fa6c578dea04a2 100644 --- a/tools/autograd/gen_variable_type.py +++ b/tools/autograd/gen_variable_type.py @@ -197,6 +197,7 @@ "nanmean", "nansum", "transpose", + "transpose_copy", "permute", "squeeze", "unsqueeze", diff --git a/torch/_decomp/__init__.py b/torch/_decomp/__init__.py index 93bbec04a425be..46ffa579aec564 100644 --- a/torch/_decomp/__init__.py +++ b/torch/_decomp/__init__.py @@ -452,6 +452,7 @@ def core_aten_decompositions() -> Dict[torch._ops.OperatorBase, Callable]: aten.threshold_backward, aten.trace, aten.transpose.int, + aten.transpose_copy, aten.tril, aten.tril_, aten.triu, diff --git a/torch/_refs/__init__.py b/torch/_refs/__init__.py index 15b18fefacb777..9b4d10cd5a6ad4 100644 --- a/torch/_refs/__init__.py +++ b/torch/_refs/__init__.py @@ -290,6 +290,7 @@ "take_along_dim", "tensor_split", "transpose", + "transpose_copy", "unfold", "unfold_copy", "unsqueeze", @@ -6335,6 +6336,7 @@ def select_scatter(x: TensorLikeType, src: TensorLikeType, dim: int, index: int) # no sparse support. See narrow_copy_sparse in core. narrow_copy = _make_copy_from_view(aten.narrow) t_copy = _make_copy_from_view(aten.t) +transpose_copy = _make_copy_from_view(aten.transpose) unsqueeze_copy = _make_copy_from_view(aten.unsqueeze) view_copy = _make_copy_from_view(aten.view) diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py index 15102a2fba28eb..d0a0a09d7d214c 100644 --- a/torch/testing/_internal/common_methods_invocations.py +++ b/torch/testing/_internal/common_methods_invocations.py @@ -19621,6 +19621,24 @@ def sample_inputs_alias_copy(op_info, device, dtype, requires_grad, **kwargs): # vmap does not support inplace views check_inplace_batched_forward_grad=False, sample_inputs_func=sample_inputs_transpose_swapdims), + OpInfo('transpose_copy', + assert_jit_shape_analysis=True, + dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.half, torch.chalf), + supports_out=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + # vmap does not support inplace views + check_inplace_batched_forward_grad=False, + sample_inputs_func=sample_inputs_transpose_swapdims, + skips=( + DecorateInfo(unittest.expectedFailure, 'TestDTensorOps', 'test_dtensor_op_db'), + DecorateInfo( + unittest.expectedFailure, + 'TestJit', + 'test_variant_consistency_jit', + dtypes=(torch.float32,) + ), + )), OpInfo('T', op=lambda x: x.T, dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.half, torch.chalf), @@ -23835,6 +23853,15 @@ def sample_inputs_alias_copy(op_info, device, dtype, requires_grad, **kwargs): "_refs.transpose", torch_opinfo_name="transpose", ), + PythonRefInfo( + "_refs.transpose_copy", + torch_opinfo_name="transpose_copy", + skips=( + # RuntimeError: no _refs support for torch.Tensor.is_conj + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref'), + ), + supports_out=True, + ), PythonRefInfo( "_refs.t", torch_opinfo_name="t", From 4b88f1ac3f2d61dc37902499a2af4b5dc8041a68 Mon Sep 17 00:00:00 2001 From: Bob Ren Date: Wed, 11 Sep 2024 09:28:29 -0700 Subject: [PATCH 0732/1018] Simplify expr before getting implications in _maybe_evaluate_static (#135499) Fixes #134268 Previously we weren't simplifying these expressions before calling get_implications, resulting in inconsistent application of FloorDiv/CleanDiv. See #134268 for more details. Pull Request resolved: https://github.com/pytorch/pytorch/pull/135499 Approved by: https://github.com/ezyang --- test/dynamo/test_misc.py | 15 +++++++++++++++ torch/fx/experimental/symbolic_shapes.py | 2 +- 2 files changed, 16 insertions(+), 1 deletion(-) diff --git a/test/dynamo/test_misc.py b/test/dynamo/test_misc.py index ab1b343044a26d..9d2bb374621134 100644 --- a/test/dynamo/test_misc.py +++ b/test/dynamo/test_misc.py @@ -10430,6 +10430,21 @@ def fn(x): c2 = _debug_get_cache_entry_list(fn.__code__) self.assertEqual(len(c2), 0) + def test_guard_size_oblivious_simplification(self): + @torch.compile(backend="eager", fullgraph=True) + def fn(x): + u0, u1 = x.tolist() + torch._check_is_size(u0) + torch._check_is_size(u1) + torch._check((2 * u0) % (u0 + u1) == 0) + torch._check((2 * u0) // (u0 + u1) != 0) + if guard_size_oblivious((2 * u0) // (u0 + u1) == 0): + return torch.tensor(True) + else: + return torch.tensor(False) + + fn(torch.tensor([3, 3])) + @torch._dynamo.config.patch(capture_scalar_outputs=True) def test_guard_size_oblivious(self): # This code, in fact, does NOT work in eager diff --git a/torch/fx/experimental/symbolic_shapes.py b/torch/fx/experimental/symbolic_shapes.py index f8ca28ab0d7b9f..c892fe846662cb 100644 --- a/torch/fx/experimental/symbolic_shapes.py +++ b/torch/fx/experimental/symbolic_shapes.py @@ -4477,7 +4477,7 @@ def _maybe_evaluate_static( subst = {} for e in axioms: if e.free_symbols.issubset(expr.free_symbols): - subst.update(dict(self.get_implications(e))) + subst.update(dict(self.get_implications(self.simplify(e)))) expr = expr.xreplace(subst) From 2d6d998a254dafd86bef2050de4b7f2559902b49 Mon Sep 17 00:00:00 2001 From: Bin Bao Date: Wed, 11 Sep 2024 20:07:11 +0000 Subject: [PATCH 0733/1018] [Inductor] add _dynamo.reset to test_cat_slice_cat_cuda (#135694) Summary: test_cat_slice_cat_cuda runs inductor multiple times and check counters["inductor"] in between, and thus we need to reset properly. Differential Revision: D62500331 Pull Request resolved: https://github.com/pytorch/pytorch/pull/135694 Approved by: https://github.com/masnesral --- test/inductor/test_pattern_matcher.py | 1 + 1 file changed, 1 insertion(+) diff --git a/test/inductor/test_pattern_matcher.py b/test/inductor/test_pattern_matcher.py index 3ad1b5a6ab1fef..4c7ad746c6cc3f 100644 --- a/test/inductor/test_pattern_matcher.py +++ b/test/inductor/test_pattern_matcher.py @@ -748,6 +748,7 @@ def fn(a, b): torch.randn(2, 8, device="cuda"), torch.randn(2, 16, device="cuda"), ] + torch._dynamo.reset() counters.clear() expected = fn(*args) actual = torch.compile(fn)(*args) From f260e4d2296cd6df36f7a0b209548f11d4b765db Mon Sep 17 00:00:00 2001 From: Sidney Tsang Date: Wed, 11 Sep 2024 20:21:59 +0000 Subject: [PATCH 0734/1018] [Export] Fix SDPA decomposition (#135297) Summary: Update SDPA decomposition to match updated stride from D62009189 which aligns strides with the `aten._scaled_dot_product_attention_math.default`, which makes `t.permute().continuous().permute()` no longer necessary. Test Plan: CI Differential Revision: D62278378 Pull Request resolved: https://github.com/pytorch/pytorch/pull/135297 Approved by: https://github.com/drisspg --- test/test_decomp.py | 29 +++++++---------------------- torch/_decomp/decompositions.py | 20 +++++++------------- 2 files changed, 14 insertions(+), 35 deletions(-) diff --git a/test/test_decomp.py b/test/test_decomp.py index b3ccf298516939..d458892851ef9f 100644 --- a/test/test_decomp.py +++ b/test/test_decomp.py @@ -1179,24 +1179,6 @@ def test_weight_norm_interface(self, device): def test_sdpa(self, device, dtype, op): # SDPA doesn't support float16, this is aligned with aten/src/ATen/native/transformers/attention.cpp. If we # add support for float16 over there we should update this test as well. - - class ScaledDotProductAttention(torch.nn.Module): - def __init__(self) -> None: - super().__init__() - - def forward( - self, query_layer, key_layer, value_layer, mask=None, is_causal=True - ): - attn_output = op( - query_layer, - key_layer, - value_layer, - attn_mask=mask, - dropout_p=0.0, - is_causal=is_causal, - ) - return attn_output - query_layer = torch.randn(1, 128, 100, 64, device=device, dtype=dtype) key_layer = torch.randn(1, 128, 100, 64, device=device, dtype=dtype) value_layer = torch.randn(1, 128, 100, 64, device=device, dtype=dtype) @@ -1206,12 +1188,17 @@ def forward( for mask in masks: is_causal = mask is None - attention = ScaledDotProductAttention() decomposed_res = ( torch._decomp.decompositions.scaled_dot_product_flash_attention_for_cpu( query_layer, key_layer, value_layer, 0.0, is_causal, attn_mask=mask ) ) + actual_res = decomposed_res[0] + # Output has form (N, H, L, E), but should be continuous on (L, N, H, E) + # in order for subsequent view(L * N, H * E) to be valid. + # So permute(2, 0, 1, 3) before checking that tensor is contiguous + self.assertTrue(actual_res.permute(2, 0, 1, 3).is_contiguous()) + eager_res = op( query_layer, key_layer, @@ -1221,9 +1208,7 @@ def forward( is_causal=is_causal, ) - self.assertTrue( - torch.allclose(decomposed_res[0], eager_res, atol=atol, rtol=rtol) - ) + self.assertTrue(torch.allclose(actual_res, eager_res, atol=atol, rtol=rtol)) instantiate_device_type_tests(DecompOneOffTests, globals()) diff --git a/torch/_decomp/decompositions.py b/torch/_decomp/decompositions.py index c35d7a72774f16..1aa13da9cacd7e 100644 --- a/torch/_decomp/decompositions.py +++ b/torch/_decomp/decompositions.py @@ -4905,9 +4905,7 @@ def scaled_dot_product_flash_attention_for_cpu( # Why this change? # In pre-dispatch export scaled_dot_product_attention is executed via # * flash_attention. - # flash_attention allocates output tensor as (N, L, H, E) - # it then transposes that to get (N, H, L, E) which is supposed to be the return - # tensor dim for scaled_dot_product_attention + # flash_attention allocates output tensor as (N, H, L, E) (see PR #134656) # assume x: [N, H, L, E] is the output sdpa # In MHA code, this output is then permuted via (2, 0, 1, 3) to get # (L, N, H, E) dim tensor @@ -4923,20 +4921,16 @@ def scaled_dot_product_flash_attention_for_cpu( # subsequent view is not valid and the export fails # solution is to maintain the return tensor view from the decomp to be # exactly same as *flash* variant. - # flash variants output is contiguous as [N, L, H, E] - # _match variant out is contiguous as [N, H, L, E] - # out = out.transpose(1, 2).contiguous gets output as contiguous - # in [N, L, H, E]. - # Subsrequent transpose(1, 2) then returns a view on which - # aforementioned code snippet, as showm below, is valid - # x = x.permute(2, 0, 1, 3).contiguous() and the viewed via - # x = x.view(L * N, H * E) # Really the invariant you want to maintain is: # pre-dispatch op-output and its decomposed representation must # return tensor with same view and dims - output = output.transpose(1, 2).contiguous(memory_format=torch.contiguous_format) - return (output.transpose(1, 2), attn) + output = ( + output.permute(2, 0, 1, 3) + .contiguous(memory_format=torch.contiguous_format) + .permute(1, 2, 0, 3) + ) + return output, attn def register_inplace(aten_op, outplace_op): From c5c98dc8789668749402f48f8242918a031971a9 Mon Sep 17 00:00:00 2001 From: Xinya Zhang Date: Wed, 11 Sep 2024 20:34:01 +0000 Subject: [PATCH 0735/1018] [ROCm] Update to AOTriton 0.7b (#134498) Notable changes: 1. Enable CudaGraph related tests 2. Fix UT problems 3. EXPERIMENTAL Navi31 support. User should enable Navi31 support with Env Var `TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1` Know Problem: 1. `test/test_transformers.py` will massive failures and/or NaN outputs with `--use-pytest` + Update: Confirmed skip `class TestSDPAPrivateUse1Only` can fix the problem with `--use-pytest` Note: AOTriton 0.7b adds support to nestedtenosrs+SDPA but need more work (and consequently a separate PR) to enable it. Fixes #133540 Pull Request resolved: https://github.com/pytorch/pytorch/pull/134498 Approved by: https://github.com/pruthvistony, https://github.com/jeffdaily, https://github.com/malfet --- .ci/docker/aotriton_version.txt | 6 +- .ci/docker/common/install_aotriton.sh | 4 +- .../native/transformers/cuda/attention.cu | 29 ++++-- .../transformers/cuda/attention_backward.cu | 9 +- .../native/transformers/cuda/sdp_utils.cpp | 29 +++++- .../transformers/hip/aotriton_adapter.h | 12 +++ .../transformers/hip/flash_attn/flash_api.hip | 60 ++++++------ test/functorch/test_control_flow.py | 4 + test/inductor/test_flex_attention.py | 4 + test/inductor/test_flex_decoding.py | 4 + test/nn/test_multihead_attention.py | 2 + test/run_test.py | 1 - test/test_native_mha.py | 7 +- test/test_nn.py | 2 + test/test_transformers.py | 92 ++++++++++++++----- .../_internal/common_methods_invocations.py | 22 +++-- 16 files changed, 203 insertions(+), 84 deletions(-) diff --git a/.ci/docker/aotriton_version.txt b/.ci/docker/aotriton_version.txt index 69b02c700e32ce..602b77d3b853a5 100644 --- a/.ci/docker/aotriton_version.txt +++ b/.ci/docker/aotriton_version.txt @@ -1,5 +1,5 @@ -0.6b +0.7b manylinux_2_17 rocm6.2 -7f07e8a1cb1f99627eb6d77f5c0e9295c775f3c7 -e4ab195d2bd19e939c675a13280c29714c6ef9f2cf420690da150fa0cac043b1 +9be04068c3c0857a4cfd17d7e39e71d0423ebac2 +3e9e1959d23b93d78a08fcc5f868125dc3854dece32fd9458be9ef4467982291 diff --git a/.ci/docker/common/install_aotriton.sh b/.ci/docker/common/install_aotriton.sh index 8b340ee219de2a..2aee95c48d4794 100755 --- a/.ci/docker/common/install_aotriton.sh +++ b/.ci/docker/common/install_aotriton.sh @@ -4,12 +4,12 @@ set -ex source "$(dirname "${BASH_SOURCE[0]}")/common_utils.sh" -TARBALL='aotriton.tar.bz2' +TARBALL='aotriton.tar.gz' # This read command alwasy returns with exit code 1 read -d "\n" VER MANYLINUX ROCMBASE PINNED_COMMIT SHA256 < aotriton_version.txt || true ARCH=$(uname -m) AOTRITON_INSTALL_PREFIX="$1" -AOTRITON_URL="https://github.com/ROCm/aotriton/releases/download/${VER}/aotriton-${VER}-${MANYLINUX}_${ARCH}-${ROCMBASE}-shared.tar.bz2" +AOTRITON_URL="https://github.com/ROCm/aotriton/releases/download/${VER}/aotriton-${VER}-${MANYLINUX}_${ARCH}-${ROCMBASE}-shared.tar.gz" cd "${AOTRITON_INSTALL_PREFIX}" # Must use -L to follow redirects diff --git a/aten/src/ATen/native/transformers/cuda/attention.cu b/aten/src/ATen/native/transformers/cuda/attention.cu index cf648151a5ba8d..914915da01a03d 100644 --- a/aten/src/ATen/native/transformers/cuda/attention.cu +++ b/aten/src/ATen/native/transformers/cuda/attention.cu @@ -1115,10 +1115,13 @@ std::tuple _efficient_ offset_t = at::empty({}, at::dtype(at::kLong).device(device)); } else { auto [seed, offset] = at::cuda::philox::unpack(philox_state); - seed_t = at::scalar_tensor( - at::Scalar(static_cast(seed)), at::dtype(at::kLong)); - offset_t = at::scalar_tensor( - at::Scalar(static_cast(offset)), at::dtype(at::kLong)); +#ifdef USE_ROCM + const auto options = at::dtype(at::kLong).device(at::kCUDA); +#else + const auto options = at::dtype(at::kLong); +#endif + seed_t = at::scalar_tensor(at::Scalar(static_cast(seed)), options); + offset_t = at::scalar_tensor(at::Scalar(static_cast(offset)), options); } } else { // Not using dropout @@ -1131,7 +1134,8 @@ std::tuple _efficient_ auto ret = aotriton::v2::flash::check_gpu(stream); if (hipSuccess != ret) { TORCH_CHECK(false, - "[AOTriton] Accelerated SDPA only supports MI200/MI300X GPUs (gfx90a:sramecc+:xnack- or gfx94a:sramecc+:xnack-)") + "[AOTriton] Accelerated SDPA only supports MI200/MI300X/Navi31 GPUs" + " (gfx90a:sramecc+:xnack-/gfx942:sramecc+:xnack-/gfx1100)") } // AOTriton may accept aligned on logsumexp tensor in the future for better @@ -1160,8 +1164,16 @@ std::tuple _efficient_ using aotriton::v2::flash::attn_fwd; using sdp::aotriton_adapter::mk_aotensor; + using sdp::aotriton_adapter::mk_aoscalartensor; + using sdp::aotriton_adapter::mk_philoxtensor; aotriton::TensorView<4> empty_t4(0, {0, 0, 0, 0}, {0, 0, 0, 0}, aotriton::DType::kFloat16); at::Tensor softmax_fa_t = at::empty({ 0, 0, 0, 0 }, query.options()); + const bool use_philox_state = in_capture_stream; + auto seed = use_philox_state ? mk_philoxtensor(philox_state.seed_.ptr) : mk_aoscalartensor(seed_t); + auto offset1 = use_philox_state ? mk_philoxtensor(philox_state.offset_.ptr) : mk_aoscalartensor(offset_t); + auto offset2 = use_philox_state ? philox_state.offset_intragraph_ : 0; + auto seed_output = use_philox_state ? mk_philoxtensor(seed_t.data_ptr()) : mk_philoxtensor(nullptr); + auto offset_output = use_philox_state ? mk_philoxtensor(offset_t.data_ptr()) : mk_philoxtensor(nullptr); hipError_t err; // TODO: Error handling err = attn_fwd(mk_aotensor(q_t, "q"), mk_aotensor(k_t, "k"), @@ -1171,8 +1183,11 @@ std::tuple _efficient_ mk_aotensor<2>(softmax_lse, "M"), mk_aotensor(output_t, "Out"), dropout_p, - use_dropout ? *seed_t.data_ptr() : 0, - use_dropout ? *offset_t.data_ptr() : 0, + seed, + offset1, + offset2, + seed_output, + offset_output, mk_aotensor(softmax_fa_t, "encoded_softmax"), is_causal, stream); diff --git a/aten/src/ATen/native/transformers/cuda/attention_backward.cu b/aten/src/ATen/native/transformers/cuda/attention_backward.cu index 75ab28d21dbc3f..07a398573665ac 100644 --- a/aten/src/ATen/native/transformers/cuda/attention_backward.cu +++ b/aten/src/ATen/native/transformers/cuda/attention_backward.cu @@ -416,7 +416,8 @@ _efficient_attention_backward( auto ret = aotriton::v2::flash::check_gpu(stream); if (hipSuccess != ret) { TORCH_CHECK(false, - "[AOTriton] Accelerated SDPA only supports MI200/MI300X GPUs (gfx90a:sramecc+:xnack- or gfx942:sramecc+:xnack-)") + "[AOTriton] Accelerated SDPA only supports MI200/MI300X/Navi31 GPUs" + " (gfx90a:sramecc+:xnack-/gfx942:sramecc+:xnack-/gfx1100)") } const auto softmax_scale = sdp::calculate_scale(query, scale).as_float_unchecked(); bool is_causal; @@ -441,6 +442,7 @@ _efficient_attention_backward( hipError_t err; using aotriton::v2::flash::attn_bwd; using sdp::aotriton_adapter::mk_aotensor; + using sdp::aotriton_adapter::mk_aoscalartensor; using sdp::aotriton_adapter::cast_dtype; aotriton::TensorView<4> empty_t4(0, {0, 0, 0, 0}, {0, 0, 0, 0}, cast_dtype(query.dtype())); err = attn_bwd(mk_aotensor(q_t, "q"), @@ -457,8 +459,9 @@ _efficient_attention_backward( mk_aotensor<2>(softmax_lse, "L"), mk_aotensor<2>(delta, "delta"), float(dropout_p), - rng_engine_inputs.seed_.val, - rng_engine_inputs.offset_.val, + mk_aoscalartensor(philox_seed), + mk_aoscalartensor(philox_offset), + 0, is_causal, stream); #else diff --git a/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp b/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp index 88bcf29291cd73..d84d9417692166 100644 --- a/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp +++ b/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp @@ -210,6 +210,7 @@ bool check_flash_attention_hardware_support(sdp_params const& params, bool debug // Check that the gpu is capable of running flash attention using sm80 = SMVersion<8, 0>; using sm90 = SMVersion<9, 0>; + auto dprops = at::cuda::getCurrentDeviceProperties(); #if USE_ROCM #if USE_AOTRITON auto stream = at::cuda::getCurrentCUDAStream().stream(); @@ -221,11 +222,19 @@ bool check_flash_attention_hardware_support(sdp_params const& params, bool debug } return false; } + c10::string_view arch(dprops->gcnArchName); + if (arch == "gfx1100") { + static const bool enable_navi3x = c10::utils::check_env("TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL") == true; + if (!enable_navi3x) { + TORCH_WARN_ONCE("Flash attention support on Navi31 GPU is still experimental." + " Enable it with TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1."); + return false; + } + } #else return false; #endif #else - auto dprops = at::cuda::getCurrentDeviceProperties(); if (!check_sm_version(dprops)) { if (debug) { TORCH_WARN( @@ -245,6 +254,7 @@ bool check_mem_efficient_hardware_support(sdp_params const& params, bool debug) // Mem Efficient attention supports hardware in the range [sm_50, sm_90] using sm50 = SMVersion<5, 0>; using sm90 = SMVersion<9, 0>; + auto dprops = at::cuda::getCurrentDeviceProperties(); #if USE_ROCM #if USE_AOTRITON auto stream = at::cuda::getCurrentCUDAStream().stream(); @@ -256,11 +266,19 @@ bool check_mem_efficient_hardware_support(sdp_params const& params, bool debug) } return false; } + c10::string_view arch(dprops->gcnArchName); + if (arch == "gfx1100") { + static const bool enable_navi3x = c10::utils::check_env("TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL") == true; + if (!enable_navi3x) { + TORCH_WARN_ONCE("Memory Efficient attention on Navi31 GPU is still experimental." + " Enable it with TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1."); + return false; + } + } #else return false; #endif #else - auto dprops = at::cuda::getCurrentDeviceProperties(); if (!check_sm_version(dprops)) { if (debug) { TORCH_WARN( @@ -616,9 +634,14 @@ bool can_use_flash_attention(sdp_params const& params, bool debug) { } } } +#if USE_ROCM + constexpr bool backend_supports_grouped_query_attention = false; +#else + constexpr bool backend_supports_grouped_query_attention = true; +#endif if (has_only_dense_inputs(params)) { constexpr auto dense_constraints = array_of( - check_batch_size_and_num_heads_dense, + check_batch_size_and_num_heads_dense, check_nonzero_sequence_lengths_dense, check_last_dim_stride_equals_1_dense); for (auto& constraint : dense_constraints) { diff --git a/aten/src/ATen/native/transformers/hip/aotriton_adapter.h b/aten/src/ATen/native/transformers/hip/aotriton_adapter.h index 1c238c751a05c9..57d5c34444390d 100644 --- a/aten/src/ATen/native/transformers/hip/aotriton_adapter.h +++ b/aten/src/ATen/native/transformers/hip/aotriton_adapter.h @@ -115,6 +115,18 @@ aotriton::TensorView mk_aotensor(const at::Tensor& q, c10::string_view ten cast_dtype(q.dtype())); } +inline aotriton::TensorView<0> mk_aoscalartensor(const at::Tensor& q) +{ + return aotriton::TensorView<0>(reinterpret_cast(q.data_ptr()), + cast_dtype(q.dtype())); +} + +inline aotriton::TensorView<0> mk_philoxtensor(const int64_t* ptr) +{ + return aotriton::TensorView<0>(reinterpret_cast(ptr), + aotriton::DType::kUInt64); // AOTriton excepts unsigned int64 +} + } // namespace aotriton_adapter } // namespace sdp diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/flash_api.hip b/aten/src/ATen/native/transformers/hip/flash_attn/flash_api.hip index 6d363d5efa0752..92e51f85d8e540 100644 --- a/aten/src/ATen/native/transformers/hip/flash_attn/flash_api.hip +++ b/aten/src/ATen/native/transformers/hip/flash_attn/flash_api.hip @@ -72,7 +72,8 @@ void check_gpu_arch(hipStream_t stream) { auto ret = aotriton::v2::flash::check_gpu(stream); if (hipSuccess != ret) { TORCH_CHECK(false, - "FlashAttention only supports MI200/MI300X GPUs (gfx90a:sramecc+:xnack- or gfx942:sramecc+:xnack-)") + "[AOTriton] Accelerated SDPA only supports MI200/MI300X/Navi31 GPUs" + " (gfx90a:sramecc+:xnack-/gfx942:sramecc+:xnack-/gfx1100)") } } @@ -164,6 +165,8 @@ mha_fwd(const at::Tensor &q, // batch_size x seqlen_q x num_heads x head auto gen = at::get_generator_or_default(std::nullopt, at::cuda::detail::getDefaultCUDAGenerator()); at::Tensor seed_t, offset_t; + at::PhiloxCudaState philox_state; + bool use_philox_state = false; if (p_dropout > 0.0) { // number of times random will be generated per thread, to offset philox counter in thc random // state @@ -171,12 +174,14 @@ mha_fwd(const at::Tensor &q, // batch_size x seqlen_q x num_heads x head int64_t counter_offset = batch_size * num_heads * 32; // See Note [Acquire lock when using random generators] std::lock_guard lock(gen->mutex_); - at::PhiloxCudaState philox_state = gen->philox_cuda_state(counter_offset); + philox_state = gen->philox_cuda_state(counter_offset); if (at::cuda::currentStreamCaptureStatus() == at::cuda::CaptureStatus::None) { auto [seed, offset] = at::cuda::philox::unpack(philox_state); - seed_t = at::scalar_tensor(at::Scalar(static_cast(seed)), at::dtype(at::kLong)); - offset_t = at::scalar_tensor(at::Scalar(static_cast(offset)), at::dtype(at::kLong)); + seed_t = at::scalar_tensor(at::Scalar(static_cast(seed)), at::dtype(at::kLong).device(at::kCUDA)); + offset_t = at::scalar_tensor(at::Scalar(static_cast(offset)), at::dtype(at::kLong).device(at::kCUDA)); } else { + // See Note [CUDA Graph-safe RNG states] about the design + use_philox_state = true; seed_t = at::empty({}, at::dtype(at::kLong).device(at::kCUDA)); offset_t = at::empty({}, at::dtype(at::kLong).device(at::kCUDA)); } @@ -185,19 +190,8 @@ mha_fwd(const at::Tensor &q, // batch_size x seqlen_q x num_heads x head seed_t = at::empty({}, at::dtype(at::kLong).device(at::kCUDA)); offset_t = at::empty({}, at::dtype(at::kLong).device(at::kCUDA)); } else { - seed_t = at::empty({}, at::dtype(at::kLong)); - offset_t = at::empty({}, at::dtype(at::kLong)); - } - } - - at::PhiloxCudaState philox_args; - if (p_dropout > 0.0) { - if (at::cuda::currentStreamCaptureStatus() == - at::cuda::CaptureStatus::None) - { - philox_args = at::PhiloxCudaState(*seed_t.data_ptr(), *offset_t.data_ptr()); - } else { // dropout + capture - philox_args = at::PhiloxCudaState(seed_t.data_ptr(), offset_t.data_ptr(), 0); + seed_t = at::empty({}, at::dtype(at::kLong).device(at::kCUDA)); + offset_t = at::empty({}, at::dtype(at::kLong).device(at::kCUDA)); } } @@ -219,9 +213,17 @@ mha_fwd(const at::Tensor &q, // batch_size x seqlen_q x num_heads x head hipError_t err; // TODO: Error handling using aotriton::v2::flash::attn_fwd; + using aotriton::TensorView; using sdp::aotriton_adapter::mk_aotensor; + using sdp::aotriton_adapter::mk_aoscalartensor; + using sdp::aotriton_adapter::mk_philoxtensor; using sdp::aotriton_adapter::cast_dtype; aotriton::TensorView<4> empty_bias(0, {0,0,0,0}, {0,0,0,0}, cast_dtype(q.dtype())); + auto seed = use_philox_state ? mk_philoxtensor(philox_state.seed_.ptr) : mk_aoscalartensor(seed_t); + auto offset1 = use_philox_state ? mk_philoxtensor(philox_state.offset_.ptr) : mk_aoscalartensor(offset_t); + auto offset2 = use_philox_state ? philox_state.offset_intragraph_ : 0; + auto seed_output = use_philox_state ? mk_philoxtensor(seed_t.data_ptr()) : mk_philoxtensor(nullptr); + auto offset_output = use_philox_state ? mk_philoxtensor(offset_t.data_ptr()) : mk_philoxtensor(nullptr); err = attn_fwd(mk_aotensor(q_t, "q"), mk_aotensor(k_t, "k"), mk_aotensor(v_t, "v"), @@ -230,8 +232,11 @@ mha_fwd(const at::Tensor &q, // batch_size x seqlen_q x num_heads x head mk_aotensor<2>(M, "M"), mk_aotensor(output_t, "Out"), p_dropout, - philox_args.seed_.val, - philox_args.offset_.val, + seed, + offset1, + offset2, + seed_output, + offset_output, mk_aotensor(softmax_fa_t, "encoded_softmax"), is_causal, stream); @@ -392,17 +397,6 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_si dv_expanded = dv; } - at::PhiloxCudaState philox_args; - if (p_dropout > 0.0) { - if (at::cuda::currentStreamCaptureStatus() == - at::cuda::CaptureStatus::None) - { - philox_args = at::PhiloxCudaState(*philox_seed.data_ptr(), *philox_offset.data_ptr()); - } else { // dropout + capture - philox_args = at::PhiloxCudaState(philox_seed.data_ptr(), philox_offset.data_ptr(), 0); - } - } - at::Tensor q_t = q.permute({0,2,1,3}); at::Tensor k_t = k.permute({0,2,1,3}); at::Tensor v_t = v.permute({0,2,1,3}); @@ -420,6 +414,7 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_si { using aotriton::v2::flash::attn_bwd; using sdp::aotriton_adapter::mk_aotensor; + using sdp::aotriton_adapter::mk_aoscalartensor; using sdp::aotriton_adapter::cast_dtype; aotriton::TensorView<4> empty_bias(0, {0,0,0,0}, {0,0,0,0}, cast_dtype(q.dtype())); err = attn_bwd(mk_aotensor(q_t, "q"), @@ -436,8 +431,9 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_si mk_aotensor<2>(softmax_lse_cont, "L"), mk_aotensor<2>(delta, "delta"), p_dropout, - philox_args.seed_.val, - philox_args.offset_.val, + mk_aoscalartensor(philox_seed), + mk_aoscalartensor(philox_offset), + 0, is_causal, stream); } diff --git a/test/functorch/test_control_flow.py b/test/functorch/test_control_flow.py index e07510828f4584..16740fdeddf3a6 100644 --- a/test/functorch/test_control_flow.py +++ b/test/functorch/test_control_flow.py @@ -26,6 +26,7 @@ parametrize, requires_cuda, run_tests, + skipIfRocm, skipIfTorchDynamo, TEST_WITH_TORCHDYNAMO, TestCase, @@ -1617,6 +1618,7 @@ def test_scan_dim(self, reverse, device): result_exp_PT = op_pt(x, rnd_scan_dim) self.assertEqual(result[1], result_exp_PT) + @skipIfRocm(msg="Unsupported on ROCM yet") @unittest.skipIf(not SM70OrLater, "triton") @requires_cuda @parametrize("combine_mode", ["pointwise", "generic"]) @@ -1692,6 +1694,7 @@ def test_scan_binary_operator(self, reverse, device): ) self.assertEqual(result, expected_result) + @skipIfRocm(msg="Unsupported on ROCM yet") @unittest.skipIf(not SM70OrLater, "triton") @requires_cuda @parametrize("combine_mode", ["pointwise", "generic"]) @@ -1723,6 +1726,7 @@ def test_associative_scan_tuple(self, combine_mode, reverse, device): ) self.assertEqual(result1, expected_result) + @skipIfRocm(msg="Unsupported on ROCM yet") @requires_cuda @parametrize("reverse", [False, True]) @parametrize("device", [torch.device("cpu"), torch.device("cuda")]) diff --git a/test/inductor/test_flex_attention.py b/test/inductor/test_flex_attention.py index e0b975c2139e36..cf37f89a5175b2 100644 --- a/test/inductor/test_flex_attention.py +++ b/test/inductor/test_flex_attention.py @@ -29,6 +29,7 @@ from torch.testing import FileCheck from torch.testing._internal import common_utils from torch.testing._internal.common_cuda import PLATFORM_SUPPORTS_BF16 +from torch.testing._internal.common_utils import skipIfRocm, TEST_WITH_ROCM from torch.utils._triton import has_triton @@ -312,6 +313,8 @@ def run_test( V_D: int = D, block_mask: Optional[BlockMask] = None, ): + if TEST_WITH_ROCM and Q_H != KV_H: + self.skipTest("enable_gqa=True is unsupported on ROCM, for now") q = torch.randn( (Q_B, Q_H, Q_S, Q_D), dtype=dtype, device="cuda", requires_grad=True ) @@ -1360,6 +1363,7 @@ def mask_mod(b, h, q, kv): self.run_test_with_call(attention) + @skipIfRocm @supported_platform def test_GQA_causal_mask(self): def mask_mod(b, h, q, kv): diff --git a/test/inductor/test_flex_decoding.py b/test/inductor/test_flex_decoding.py index 27d0f3083be4d7..04ca1429bebf19 100644 --- a/test/inductor/test_flex_decoding.py +++ b/test/inductor/test_flex_decoding.py @@ -21,6 +21,7 @@ from torch.testing import FileCheck from torch.testing._internal import common_utils from torch.testing._internal.common_cuda import PLATFORM_SUPPORTS_BF16 +from torch.testing._internal.common_utils import skipIfRocm, TEST_WITH_ROCM from torch.utils._triton import has_triton @@ -277,6 +278,8 @@ def run_test( score_mod is not None or block_mask is not None ), "Must provide score_mod or block_mask" assert Q_H % KV_H == 0 + if TEST_WITH_ROCM and Q_H != KV_H: + self.skipTest("enable_gqa=True is unsupported on ROCM, for now") q = torch.randn( (Q_B, Q_H, Q_S, Q_D), dtype=dtype, @@ -822,6 +825,7 @@ def bias_mod(score, batch, head, token_q, token_kv): self.run_test(bias_mod) + @skipIfRocm @supported_platform def test_fully_masked_out_rows_0_check_gqa(self): # Ensure fully masked out rows won't cause NaNs. diff --git a/test/nn/test_multihead_attention.py b/test/nn/test_multihead_attention.py index 40dca90b164887..c0419664d0098e 100644 --- a/test/nn/test_multihead_attention.py +++ b/test/nn/test_multihead_attention.py @@ -17,6 +17,7 @@ instantiate_parametrized_tests, parametrize as parametrize_test, run_tests, + skipIfRocm, TEST_NUMPY, TEST_WITH_CROSSREF, ) @@ -745,6 +746,7 @@ def test_multihead_attn_nested_tensor_outside_fast_path(self): class TestMultiheadAttentionNNDeviceType(NNTestCase): + @skipIfRocm(msg="To investigate: yields NaN") def test_multihead_self_attn_two_masks_fast_path(self, device): """ Multihead self-attention should give the same result on the fast path (BetterTransformer) as on the slow path diff --git a/test/run_test.py b/test/run_test.py index b459b577947693..231a1b2b7ca012 100755 --- a/test/run_test.py +++ b/test/run_test.py @@ -183,7 +183,6 @@ def __contains__(self, item): "test_cuda_nvml_based_avail", "test_jit_cuda_fuser", "distributed/_tensor/test_attention", - "test_transformers", ] XPU_BLOCKLIST = [ diff --git a/test/test_native_mha.py b/test/test_native_mha.py index 9a07485cb2e946..307115147852ff 100644 --- a/test/test_native_mha.py +++ b/test/test_native_mha.py @@ -276,8 +276,11 @@ def do_pad_all(tensors): @torch.no_grad() def test_native_multihead_self_attention(self, device, dtype, use_nt, need_weights, average_attn_weights, use_padding, pad_all, fused): - if TEST_WITH_ROCM and use_nt: - self.skipTest("ROCM does not support nested tensors for Flash Attention for now.") + if TEST_WITH_ROCM: + if use_nt: + self.skipTest("ROCM does not support nested tensors for Flash Attention for now.") + if use_padding and not pad_all and fused: + self.skipTest("Large numerical errors on ROCM to investigate.") for need_weights in (False, not pad_all): with self.subTest(use_padding=use_padding, pad_all=pad_all, use_nt=use_nt, need_weights=need_weights, diff --git a/test/test_nn.py b/test/test_nn.py index 533f889975b99b..7d0241f97dc210 100644 --- a/test/test_nn.py +++ b/test/test_nn.py @@ -3078,6 +3078,7 @@ def perm_fn(x): [2.42240309, 0.0354595, -0.60659063, -0.05378816]]])) torch.testing.assert_close(result, ref_output, rtol=1e-5, atol=0) + @skipIfRocm(msg='Large numerical errors') def test_transformerdecoder(self): def get_a_test_layer(use_cuda, activation, batch_first=False): d_model = 4 @@ -12409,6 +12410,7 @@ def test_skip_init(self, device): self.assertEqual(m_initialized.weight.device, m_uninitialized.weight.device) self.assertFalse(torch.allclose(m_initialized.weight, m_uninitialized.weight)) + @skipIfRocm(msg='Not our bug: TransformerEncoderLayer._sa_block still uses FA/ME and effectively takes fastpath') @skipIfMps # TODO(hvaara): Investigate as possible bug. macOS 13 passes, while 14 and 15 fails. @dtypes(torch.float) @dtypesIfCUDA(torch.double, torch.float, torch.half) diff --git a/test/test_transformers.py b/test/test_transformers.py index 30810997cfd47e..168e4f903b0fda 100644 --- a/test/test_transformers.py +++ b/test/test_transformers.py @@ -347,6 +347,9 @@ def test_train_with_pad_and_catch_error(self, device): @parametrize("key_padding_mask_dim", [2, None]) @parametrize("mask_dtype", [torch.bool, torch.float32]) def test_multiheadattention_fastpath_attn_mask(self, device, attn_mask_dim, key_padding_mask_dim, mask_dtype): + if TEST_WITH_ROCM: + if attn_mask_dim is not None and mask_dtype == torch.bool: + self.skipTest("boolean mask is not fully supported on ROCm yet.") # MHA converts all with torch.no_grad(): B = 2 @@ -429,6 +432,7 @@ def hook(module, inputs, output): # remove hook handle.remove() + @skipIfRocm @tf32_on_and_off(0.001) @parametrize("use_torchscript", [False]) @parametrize("enable_nested_tensor", [True, False]) @@ -1585,7 +1589,7 @@ def test_invalid_last_dim_stride(self, device, kernel: SDPBackend): q, k, v, None, 0.0, False)) @onlyCUDA - @skipIfRocm # Nested Tensor + @skipIfRocm(msg='enable_gqa=True unsupported') @unittest.skipIf(not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, "Does not support SDPA or pre-SM80 hardware") @parametrize("fused_kernel", [SDPBackend.EFFICIENT_ATTENTION]) def test_invalid_sdpa_kernel_grouped_query_attention_cuda(self, device, fused_kernel): @@ -1601,7 +1605,7 @@ def test_invalid_sdpa_kernel_grouped_query_attention_cuda(self, device, fused_ke is_causal=False, enable_gqa=True) @onlyCPU - @skipIfRocm # Nested Tensor + @skipIfRocm(msg='enable_gqa=True unsupported') def test_invalid_sdpa_kernel_grouped_query_attention_cpu(self, device): rand_query = torch.rand(8, 8, 64, 64, device=device, dtype=torch.float16, requires_grad=True) rand_key = torch.rand(8, 4, 64, 64, device=device, dtype=torch.float16, requires_grad=True) @@ -2910,6 +2914,8 @@ def _get_mem_eff_drop_mask(batch_size, n_heads, q_len, kv_len, p, seed, offset, return if TEST_WITH_ROCM and seq_len_q * seq_len_k * head_dim * batch_size > 1024 * 1024 * 128: torch.cuda.empty_cache() # Prevent memory fragmentation + if TEST_WITH_ROCM and is_causal and seq_len_q != seq_len_k: + self.skipTest("ROCm does not accept is_casual when seq_len_q != seq_len_k") seed = 42 scale = scale if scale is None else (1 / head_dim) n_heads = 4 @@ -2957,15 +2963,27 @@ def _get_mem_eff_drop_mask(batch_size, n_heads, q_len, kv_len, p, seed, offset, grads_ref_lp = torch.autograd.grad(out_lp_ref, (query, key, value), upstream_grad) grads_ref = torch.autograd.grad(out_ref, (query_ref, key_ref, value_ref), upstream_grad) + fudge_factors = { + 'out': 3.0 , + 'grad_query': 150.0 , + 'grad_key': 25.0, + 'grad_value': 8.5, + } + if TEST_WITH_ROCM: + fudge_factors['grad_key'] = 45.0 + fudge_factors['grad_query'] = 360.0 + if seq_len_k >= 1024: + fudge_factors['grad_key'] = 70.0 + if seq_len_k >= 2048: + fudge_factors['grad_key'] = 160.0 + fudge_factors['grad_query'] = 650.0 + if dtype == torch.float32: + fudge_factors['grad_key'] = 90.0 + check_out_and_grad( (out_ref, out_lp_ref, out), *zip(grads_ref, grads_ref_lp, grads), - fudge_factors={ - 'out': 3.0 , - 'grad_query': 150.0 , - 'grad_key': 25.0, - 'grad_value': 8.5, - } + fudge_factors=fudge_factors, ) @unittest.skipIf(not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, "Does not support SDPA") @@ -3054,16 +3072,28 @@ def _get_mem_eff_drop_mask(batch_size, n_heads, q_len, kv_len, p, seed, offset, grads_ref_lp = torch.autograd.grad(out_lp_ref, (query, key, value, attn_mask), upstream_grad) grads_ref = torch.autograd.grad(out_ref, (query_ref, key_ref, value_ref, attn_mask_ref), upstream_grad) + fudge_factors = { + "out": 4, + "grad_query": 150.0, + "grad_key": 25.0, + "grad_value": 8.0, + "grad_attn_mask": 45.0, + } + if TEST_WITH_ROCM: + fudge_factors['grad_key'] = 45.0 + fudge_factors['grad_query'] = 360.0 + if seq_len_k >= 1024: + fudge_factors['grad_key'] = 70.0 + if seq_len_k >= 2048: + fudge_factors['grad_key'] = 160.0 + fudge_factors['grad_query'] = 650.0 + if dtype == torch.float32: + fudge_factors['grad_key'] = 90.0 + check_out_and_grad( (out_ref, out_lp_ref, out), *zip(grads_ref, grads_ref_lp, grads), - fudge_factors={ - "out": 4, - "grad_query": 160.0, - "grad_key": 25.0, - "grad_value": 8.0, - "grad_attn_mask": 45.0, - }, + fudge_factors=fudge_factors, ) @unittest.skipIf(not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Does not support SDPA or pre-SM80 hardware") @@ -3076,7 +3106,7 @@ def _get_mem_eff_drop_mask(batch_size, n_heads, q_len, kv_len, p, seed, offset, @parametrize("dropout_p", [0.0, 0.22, 0.48]) @parametrize("dtype", [torch.float16, torch.bfloat16]) @parametrize("scale", [None, "l1"]) - @parametrize("enable_gqa", [True, False]) + @parametrize("enable_gqa", [True, False] if not TEST_WITH_ROCM else [False]) @parametrize("n_heads", [[16, 8], [10, 2]]) def test_flash_attention_vs_math_ref_grads(self, device, batch_size: int, seq_len_q: int, seq_len_k: int, head_dim: int, is_causal: bool, dropout_p: float, dtype: torch.dtype, @@ -3164,18 +3194,31 @@ def test_flash_attention_vs_math_ref_grads(self, device, batch_size: int, seq_le grads_ref_lp = torch.autograd.grad(out_lp_ref, (query, key, value), upstream_grad) grads_ref = torch.autograd.grad(out_ref, (query_ref, key_ref, value_ref), upstream_grad) + fudge_factors = { + 'out': 4, + 'grad_query': 160.0, + 'grad_key': 16, + 'grad_value': 4, + } + if TEST_WITH_ROCM: + fudge_factors['grad_key'] = 45.0 + fudge_factors['grad_query'] = 360.0 + if seq_len_k >= 1024: + fudge_factors['grad_key'] = 70.0 + if seq_len_k >= 2048: + fudge_factors['grad_key'] = 190.0 + fudge_factors['grad_query'] = 650.0 + if seq_len_q >= 2048: + fudge_factors['grad_query'] = 1100.0 + if dtype == torch.float32: + fudge_factors['grad_key'] = 90.0 + check_out_and_grad( (out_ref, out_lp_ref, out), *zip(grads_ref, grads_ref_lp, grads), - fudge_factors={ - 'out': 4, - 'grad_query': 160.0, - 'grad_key': 16, - 'grad_value': 4, - } + fudge_factors=fudge_factors, ) - @skipIfRocm # FIXME: "capturing stream has unjoined work" @unittest.skipIf(not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Does not support SDPA or pre-SM80 hardware") @parametrize("batch_size", [1, 8]) @parametrize("seq_len_q", [256, 1024]) @@ -3223,6 +3266,8 @@ def get_dropout_mask(output, fused_kernel, batch_size, n_heads, q_len, kv_len, d if fused_kernel == SDPBackend.FLASH_ATTENTION and is_causal and seq_len_q != seq_len_k: self.skipTest("Flash V2 does not accept is_casual when seq_len_q != seq_len_k") + if TEST_WITH_ROCM and is_causal and seq_len_q != seq_len_k: + self.skipTest("ROCm does not accept is_casual when seq_len_q != seq_len_k") seed = 42 n_heads = 4 @@ -3722,6 +3767,7 @@ def test_causal_variants_compile(self, device, causal_variant: CausalVariant, sh self.run_test(device, make_q_tensor, make_kv_tensor, attn_bias, forw_tol, grad_tol, backend=cnts) self.assertEqual(cnts.frame_count, 1, "Compiled graph should have 1 frame!") + @skipIfRocm @parametrize("shape", [(16, 16, 128, 128, 16), (16, 16, 128, 256, 32), (16, 16, 256, 128, 32), (1, 1, 23, 56, 15)]) def test_is_causal_equals_upper_left(self, device, shape: List[Tuple[int]]): make_tensor = partial( diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py index d0a0a09d7d214c..a6f67d75659689 100644 --- a/torch/testing/_internal/common_methods_invocations.py +++ b/torch/testing/_internal/common_methods_invocations.py @@ -8776,8 +8776,13 @@ def sample_inputs_scaled_dot_product_attention(op_info, device, dtype, requires_ qkv_shapes = [(dim_3_q_shape, dim_3_kv_shape), (dim_4_q_shape, dim_4_kv_shape), broadcast_tuple] samples = [] + gqa_options = [False] if TEST_WITH_ROCM else [True, False] # TODO: GQA support + if TEST_WITH_ROCM and dtype == torch.float32: + causal_options = [False] # FIXME: Large errors with causal+fp32 + else: + causal_options = [True, False] for qkv_shape, is_causal, dropout_p, enable_gqa in product( - qkv_shapes, [True, False], [0.0, 0.5], [True, False]): + qkv_shapes, causal_options, [0.0, 0.5], gqa_options): shape_q, shape_kv = qkv_shape samples.append(SampleInput( make(shape_q), @@ -8807,14 +8812,15 @@ def sample_inputs_scaled_dot_product_attention(op_info, device, dtype, requires_ dropout_p=0.0) ) - samples.append( - SampleInput( - make((batch, num_heads_q_gqa, seq_q, head_dim)), - make((batch, num_heads_kv_gqa, seq_kv, head_dim)), - make((batch, num_heads_kv_gqa, seq_kv, head_dim)), - enable_gqa=True + if not TEST_WITH_ROCM: + samples.append( + SampleInput( + make((batch, num_heads_q_gqa, seq_q, head_dim)), + make((batch, num_heads_kv_gqa, seq_kv, head_dim)), + make((batch, num_heads_kv_gqa, seq_kv, head_dim)), + enable_gqa=True + ) ) - ) yield from samples From 46a9c676cb94b22dd28168676455bbc5b0a29ba4 Mon Sep 17 00:00:00 2001 From: Nikita Shulga Date: Wed, 11 Sep 2024 20:48:52 +0000 Subject: [PATCH 0736/1018] Expand bitwise ops to unsigned types (#135525) Fixes https://github.com/pytorch/pytorch/issues/135436 Pull Request resolved: https://github.com/pytorch/pytorch/pull/135525 Approved by: https://github.com/ezyang --- aten/src/ATen/native/cpu/BinaryOpsKernel.cpp | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/aten/src/ATen/native/cpu/BinaryOpsKernel.cpp b/aten/src/ATen/native/cpu/BinaryOpsKernel.cpp index 51417d0e6851eb..838b7fbd097fed 100644 --- a/aten/src/ATen/native/cpu/BinaryOpsKernel.cpp +++ b/aten/src/ATen/native/cpu/BinaryOpsKernel.cpp @@ -81,6 +81,12 @@ void atan2_kernel(TensorIteratorBase& iter) { } #if !defined(C10_MOBILE) +#define _AT_DISPATCH_INTEGRAL_TYPES_V2(TYPE, NAME, ...) \ + AT_DISPATCH_V2( \ + TYPE, \ + NAME, \ + AT_WRAP(__VA_ARGS__), \ + AT_EXPAND(AT_INTEGRAL_TYPES_V2)) #define _AT_DISPATCH_ALL_TYPES_AND_BOOL(TYPE, NAME, ...) \ AT_DISPATCH_V2( \ TYPE, \ @@ -104,6 +110,8 @@ void atan2_kernel(TensorIteratorBase& iter) { AT_DISPATCH_V2(TYPE, NAME, AT_WRAP(__VA_ARGS__), \ kHalf, kBFloat16, AT_EXPAND(AT_FLOAT8_TYPES), AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES)) #else +#define _AT_DISPATCH_INTEGRAL_TYPES_V2(TYPE, NAME, ...) \ + AT_DISPATCH_INTEGRAL_TYPES(TYPE, NAME, __VA_ARGS__) #define _AT_DISPATCH_ALL_TYPES_AND_BOOL(TYPE, NAME, ...) \ AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4( \ kComplexHalf, kHalf, kBool, kBFloat16, TYPE, NAME, __VA_ARGS__) @@ -382,7 +390,7 @@ void bitwise_and_kernel(TensorIteratorBase& iter) { if (iter.dtype() == ScalarType::Bool) { cpu_kernel(iter, [](bool a, bool b) { return a && b; }); } else { - AT_DISPATCH_INTEGRAL_TYPES(iter.dtype(), "bitwise_and_cpu", [&]() { + _AT_DISPATCH_INTEGRAL_TYPES_V2(iter.dtype(), "bitwise_and_cpu", [&]() { cpu_kernel_vec( iter, [](scalar_t a, scalar_t b) -> scalar_t { return a & b; }, @@ -395,7 +403,7 @@ void bitwise_or_kernel(TensorIteratorBase& iter) { if (iter.dtype() == ScalarType::Bool) { cpu_kernel(iter, [](bool a, bool b) { return a || b; }); } else { - AT_DISPATCH_INTEGRAL_TYPES(iter.dtype(), "bitwise_or_cpu", [&]() { + _AT_DISPATCH_INTEGRAL_TYPES_V2(iter.dtype(), "bitwise_or_cpu", [&]() { cpu_kernel_vec( iter, [](scalar_t a, scalar_t b) -> scalar_t { return a | b; }, @@ -410,7 +418,7 @@ void bitwise_xor_kernel(TensorIteratorBase& iter) { // this operation for both Boolean and integral types. cpu_kernel(iter, [](bool a, bool b) { return a != b; }); } else { - AT_DISPATCH_INTEGRAL_TYPES(iter.dtype(), "bitwise_xor_cpu", [&]() { + _AT_DISPATCH_INTEGRAL_TYPES_V2(iter.dtype(), "bitwise_xor_cpu", [&]() { cpu_kernel_vec( iter, [](scalar_t a, scalar_t b) -> scalar_t { return a ^ b; }, From 7ed3a6e5f1d7c32bab337dce856b17b644989678 Mon Sep 17 00:00:00 2001 From: FFFrog Date: Tue, 10 Sep 2024 16:31:26 +0800 Subject: [PATCH 0737/1018] Refactoring byte_order (#135558) Pull Request resolved: https://github.com/pytorch/pytorch/pull/135558 Approved by: https://github.com/mikaylagawarecki --- torch/csrc/StorageMethods.cpp | 92 ++++----- torch/csrc/distributed/rpc/utils.cpp | 4 +- torch/csrc/serialization.cpp | 16 +- torch/csrc/utils/byte_order.cpp | 267 +++++++++------------------ torch/csrc/utils/byte_order.h | 145 +-------------- 5 files changed, 131 insertions(+), 393 deletions(-) diff --git a/torch/csrc/StorageMethods.cpp b/torch/csrc/StorageMethods.cpp index 4bd16066143fe4..19f02b464530d0 100644 --- a/torch/csrc/StorageMethods.cpp +++ b/torch/csrc/StorageMethods.cpp @@ -190,6 +190,16 @@ static PyObject* THPStorage_fill_(PyObject* self, PyObject* number_arg) { END_HANDLE_TH_ERRORS } +template +static void decodeWrapper( + void* data, + const uint8_t* src, + bool do_byte_swap, + size_t count) { + torch::utils::THP_decodeBuffer( + static_cast(data), src, do_byte_swap, count); +} + static PyObject* THPStorage_fromBuffer( PyObject* _unused, PyObject* args, @@ -308,70 +318,30 @@ static PyObject* THPStorage_fromBuffer( c10::GetDefaultCPUAllocator(), /*resizable=*/true); + static const std::unordered_map< + at::ScalarType, + std::function> + decode_map = { + {at::kBool, decodeWrapper}, + {at::kShort, decodeWrapper}, + {at::kInt, decodeWrapper}, + {at::kLong, decodeWrapper}, + {at::kHalf, decodeWrapper}, + {at::kBFloat16, decodeWrapper}, + {at::kFloat, decodeWrapper}, + {at::kDouble, decodeWrapper}, + {at::kComplexFloat, decodeWrapper>}, + {at::kComplexDouble, decodeWrapper>}}; + if (is_endian_independent) { memcpy(storage->mutable_data(), src + offset, count); - } else if (scalar_type == at::kBool) { - // Because of ASAN checks, that are failing whenever - // we are trying to get a value which is not 0 or 1, we have to manually - // convert original values to boolean ones. - torch::utils::THP_decodeBoolBuffer( - static_cast(storage->mutable_data()), src + offset, count); - } else if (scalar_type == at::kShort) { - torch::utils::THP_decodeInt16Buffer( - static_cast(storage->mutable_data()), - src + offset, - do_byte_swap, - count); - } else if (scalar_type == at::kInt) { - torch::utils::THP_decodeInt32Buffer( - static_cast(storage->mutable_data()), - src + offset, - do_byte_swap, - count); - } else if (scalar_type == at::kLong) { - torch::utils::THP_decodeInt64Buffer( - static_cast(storage->mutable_data()), - src + offset, - do_byte_swap, - count); - } else if (scalar_type == at::kHalf) { - torch::utils::THP_decodeHalfBuffer( - static_cast(storage->mutable_data()), - src + offset, - do_byte_swap, - count); - } else if (scalar_type == at::kBFloat16) { - torch::utils::THP_decodeBFloat16Buffer( - static_cast(storage->mutable_data()), - src + offset, - do_byte_swap, - count); - } else if (scalar_type == at::kFloat) { - torch::utils::THP_decodeFloatBuffer( - static_cast(storage->mutable_data()), - src + offset, - do_byte_swap, - count); - } else if (scalar_type == at::kDouble) { - torch::utils::THP_decodeDoubleBuffer( - static_cast(storage->mutable_data()), - src + offset, - do_byte_swap, - count); - } else if (scalar_type == at::kComplexFloat) { - torch::utils::THP_decodeComplexFloatBuffer( - static_cast*>(storage->mutable_data()), - src + offset, - do_byte_swap, - count); - } else if (scalar_type == at::kComplexDouble) { - torch::utils::THP_decodeComplexDoubleBuffer( - static_cast*>(storage->mutable_data()), - src + offset, - do_byte_swap, - count); } else { - TORCH_CHECK(false, "Unknown type: ", scalar_type); + auto it = decode_map.find(scalar_type); + if (it != decode_map.end()) { + it->second(storage->mutable_data(), src + offset, do_byte_swap, count); + } else { + TORCH_CHECK(false, "Unknown type: ", scalar_type); + } } PyBuffer_Release(&buffer); diff --git a/torch/csrc/distributed/rpc/utils.cpp b/torch/csrc/distributed/rpc/utils.cpp index e65b503e574f16..e93be25483051c 100644 --- a/torch/csrc/distributed/rpc/utils.cpp +++ b/torch/csrc/distributed/rpc/utils.cpp @@ -477,7 +477,7 @@ void writeWrappedPayload( int64_t indexToWrite = originalPayload.size(); originalPayload.resize(originalPayload.size() + sizeof(int64_t)); const int64_t additionalPayloadSize = additionalPayload.size(); - torch::utils::THP_encodeInt64Buffer( + torch::utils::THP_encodeBuffer( reinterpret_cast(originalPayload.data()) + indexToWrite, &additionalPayloadSize, torch::utils::THPByteOrder::THP_BIG_ENDIAN, @@ -492,7 +492,7 @@ std::vector readWrappedPayload( int64_t additionalPayloadSize; TORCH_INTERNAL_ASSERT(payload.size() >= sizeof(int64_t)); size_t indexToRead = payload.size() - sizeof(int64_t); - torch::utils::THP_decodeInt64Buffer( + torch::utils::THP_decodeBuffer( &additionalPayloadSize, reinterpret_cast(payload.data()) + indexToRead, torch::utils::THPByteOrder::THP_BIG_ENDIAN, diff --git a/torch/csrc/serialization.cpp b/torch/csrc/serialization.cpp index 10ba23a656f08c..c922da900613d9 100644 --- a/torch/csrc/serialization.cpp +++ b/torch/csrc/serialization.cpp @@ -254,7 +254,7 @@ void THPStorage_writeFileRaw( doWrite(fd, &numel, sizeof(int64_t)); else { int64_t nsize{}; // convert big endian cpu to little endian storage - torch::utils::THP_encodeInt64Buffer( + torch::utils::THP_encodeBuffer( (uint8_t*)&nsize, (const int64_t*)&numel, torch::utils::THPByteOrder::THP_LITTLE_ENDIAN, @@ -274,19 +274,19 @@ void THPStorage_writeFileRaw( for (size_t i = 0; i < numel; i += buffer_size) { size_t to_convert = std::min(numel - i, buffer_size); if (element_size == 2) { - torch::utils::THP_encodeInt16Buffer( + torch::utils::THP_encodeBuffer( le_buffer.data(), (const int16_t*)data + i, torch::utils::THPByteOrder::THP_LITTLE_ENDIAN, to_convert); } else if (element_size == 4) { - torch::utils::THP_encodeInt32Buffer( + torch::utils::THP_encodeBuffer( le_buffer.data(), (const int32_t*)data + i, torch::utils::THPByteOrder::THP_LITTLE_ENDIAN, to_convert); } else if (element_size == 8) { - torch::utils::THP_encodeInt64Buffer( + torch::utils::THP_encodeBuffer( le_buffer.data(), (const int64_t*)data + i, torch::utils::THPByteOrder::THP_LITTLE_ENDIAN, @@ -322,7 +322,7 @@ c10::intrusive_ptr THPStorage_readFileRaw( if (torch::utils::THP_nativeByteOrder() == torch::utils::THPByteOrder::THP_BIG_ENDIAN) { int64_t tsize = size; // convert little endian storage to big endian cpu - torch::utils::THP_decodeInt64Buffer(&size, (const uint8_t*)&tsize, true, 1); + torch::utils::THP_decodeBuffer(&size, (const uint8_t*)&tsize, true, 1); } size_t nbytes = element_size * size; if (!storage.defined()) { @@ -369,13 +369,13 @@ c10::intrusive_ptr THPStorage_readFileRaw( // NOLINTNEXTLINE(bugprone-branch-clone) if (element_size == 2) { - torch::utils::THP_decodeInt16Buffer( + torch::utils::THP_decodeBuffer( (int16_t*)data + i, le_buffer.get(), true, to_convert); } else if (element_size == 4) { - torch::utils::THP_decodeInt32Buffer( + torch::utils::THP_decodeBuffer( (int32_t*)data + i, le_buffer.get(), true, to_convert); } else if (element_size == 8) { - torch::utils::THP_decodeInt64Buffer( + torch::utils::THP_decodeBuffer( (int64_t*)data + i, le_buffer.get(), true, to_convert); } } diff --git a/torch/csrc/utils/byte_order.cpp b/torch/csrc/utils/byte_order.cpp index 6c41253fbb8664..4432e74bd06c10 100644 --- a/torch/csrc/utils/byte_order.cpp +++ b/torch/csrc/utils/byte_order.cpp @@ -117,49 +117,38 @@ THPByteOrder THP_nativeByteOrder() { return *(uint8_t*)&x ? THP_LITTLE_ENDIAN : THP_BIG_ENDIAN; } -void THP_decodeInt16Buffer( - int16_t* dst, - const uint8_t* src, - bool do_byte_swap, - size_t len) { - for (const auto i : c10::irange(len)) { - dst[i] = (int16_t)(do_byte_swap ? decodeUInt16ByteSwapped(src) - : decodeUInt16(src)); - src += sizeof(int16_t); - } -} - -void THP_decodeInt32Buffer( - int32_t* dst, - const uint8_t* src, - bool do_byte_swap, - size_t len) { - for (const auto i : c10::irange(len)) { - dst[i] = (int32_t)(do_byte_swap ? decodeUInt32ByteSwapped(src) - : decodeUInt32(src)); - src += sizeof(int32_t); - } -} +template +void THP_decodeBuffer(T* dst, const uint8_t* src, U type, size_t len) { + if constexpr (std::is_same_v) + THP_decodeBuffer(dst, src, type != THP_nativeByteOrder(), len); + else { + auto func = [&](const uint8_t* src_data) { + if constexpr (std::is_same_v) { + return type ? decodeUInt16ByteSwapped(src_data) + : decodeUInt16(src_data); + } else if constexpr (std::is_same_v) { + return type ? decodeUInt32ByteSwapped(src_data) + : decodeUInt32(src_data); + } else if constexpr (std::is_same_v) { + return type ? decodeUInt64ByteSwapped(src_data) + : decodeUInt64(src_data); + } + }; -void THP_decodeInt64Buffer( - int64_t* dst, - const uint8_t* src, - bool do_byte_swap, - size_t len) { - for (const auto i : c10::irange(len)) { - dst[i] = (int64_t)(do_byte_swap ? decodeUInt64ByteSwapped(src) - : decodeUInt64(src)); - src += sizeof(int64_t); + for (const auto i : c10::irange(len)) { + dst[i] = static_cast(func(src)); + src += sizeof(T); + } } } -void THP_decodeHalfBuffer( +template <> +TORCH_API void THP_decodeBuffer( c10::Half* dst, const uint8_t* src, bool do_byte_swap, size_t len) { for (const auto i : c10::irange(len)) { - // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) union { uint16_t x; c10::Half f; @@ -170,7 +159,8 @@ void THP_decodeHalfBuffer( } } -void THP_decodeBFloat16Buffer( +template <> +TORCH_API void THP_decodeBuffer( at::BFloat16* dst, const uint8_t* src, bool do_byte_swap, @@ -183,19 +173,24 @@ void THP_decodeBFloat16Buffer( } } -void THP_decodeBoolBuffer(bool* dst, const uint8_t* src, size_t len) { +template <> +TORCH_API void THP_decodeBuffer( + bool* dst, + const uint8_t* src, + bool, + size_t len) { for (const auto i : c10::irange(len)) { dst[i] = (int)src[i] != 0 ? true : false; } } -void THP_decodeFloatBuffer( +template <> +TORCH_API void THP_decodeBuffer( float* dst, const uint8_t* src, bool do_byte_swap, size_t len) { for (const auto i : c10::irange(len)) { - // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) union { uint32_t x; float f; @@ -206,13 +201,13 @@ void THP_decodeFloatBuffer( } } -void THP_decodeDoubleBuffer( +template <> +TORCH_API void THP_decodeBuffer( double* dst, const uint8_t* src, bool do_byte_swap, size_t len) { for (const auto i : c10::irange(len)) { - // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) union { uint64_t x; double d; @@ -223,18 +218,17 @@ void THP_decodeDoubleBuffer( } } -void THP_decodeComplexFloatBuffer( +template <> +TORCH_API void THP_decodeBuffer, bool>( c10::complex* dst, const uint8_t* src, bool do_byte_swap, size_t len) { for (const auto i : c10::irange(len)) { - // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) union { uint32_t x; float re; }; - // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) union { uint32_t y; float im; @@ -249,18 +243,17 @@ void THP_decodeComplexFloatBuffer( } } -void THP_decodeComplexDoubleBuffer( +template <> +TORCH_API void THP_decodeBuffer, bool>( c10::complex* dst, const uint8_t* src, bool do_byte_swap, size_t len) { for (const auto i : c10::irange(len)) { - // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) union { uint64_t x; double re; }; - // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) union { uint64_t y; double im; @@ -276,150 +269,46 @@ void THP_decodeComplexDoubleBuffer( } } -void THP_decodeInt16Buffer( - int16_t* dst, - const uint8_t* src, - THPByteOrder order, - size_t len) { - THP_decodeInt16Buffer(dst, src, (order != THP_nativeByteOrder()), len); -} - -void THP_decodeInt32Buffer( - int32_t* dst, - const uint8_t* src, - THPByteOrder order, - size_t len) { - THP_decodeInt32Buffer(dst, src, (order != THP_nativeByteOrder()), len); -} - -void THP_decodeInt64Buffer( - int64_t* dst, - const uint8_t* src, - THPByteOrder order, - size_t len) { - THP_decodeInt64Buffer(dst, src, (order != THP_nativeByteOrder()), len); -} +#define DEFINE_DECODE(TYPE, ORDER) \ + template TORCH_API void THP_decodeBuffer( \ + TYPE * dst, const uint8_t* src, ORDER type, size_t len); -void THP_decodeHalfBuffer( - c10::Half* dst, - const uint8_t* src, - THPByteOrder order, - size_t len) { - THP_decodeHalfBuffer(dst, src, (order != THP_nativeByteOrder()), len); -} - -void THP_decodeBFloat16Buffer( - at::BFloat16* dst, - const uint8_t* src, - THPByteOrder order, - size_t len) { - THP_decodeBFloat16Buffer(dst, src, (order != THP_nativeByteOrder()), len); -} +DEFINE_DECODE(int16_t, THPByteOrder) +DEFINE_DECODE(int32_t, THPByteOrder) +DEFINE_DECODE(int64_t, THPByteOrder) +DEFINE_DECODE(c10::Half, THPByteOrder) +DEFINE_DECODE(float, THPByteOrder) +DEFINE_DECODE(double, THPByteOrder) +DEFINE_DECODE(c10::BFloat16, THPByteOrder) +DEFINE_DECODE(c10::complex, THPByteOrder) +DEFINE_DECODE(c10::complex, THPByteOrder) -void THP_decodeFloatBuffer( - float* dst, - const uint8_t* src, - THPByteOrder order, - size_t len) { - THP_decodeFloatBuffer(dst, src, (order != THP_nativeByteOrder()), len); -} +DEFINE_DECODE(int16_t, bool) +DEFINE_DECODE(int32_t, bool) +DEFINE_DECODE(int64_t, bool) -void THP_decodeDoubleBuffer( - double* dst, - const uint8_t* src, - THPByteOrder order, - size_t len) { - THP_decodeDoubleBuffer(dst, src, (order != THP_nativeByteOrder()), len); -} +#undef DEFINE_DECODE -void THP_decodeComplexFloatBuffer( - c10::complex* dst, - const uint8_t* src, - THPByteOrder order, - size_t len) { - THP_decodeComplexFloatBuffer(dst, src, (order != THP_nativeByteOrder()), len); -} - -void THP_decodeComplexDoubleBuffer( - c10::complex* dst, - const uint8_t* src, - THPByteOrder order, - size_t len) { - THP_decodeComplexDoubleBuffer( - dst, src, (order != THP_nativeByteOrder()), len); -} - -void THP_encodeInt16Buffer( - uint8_t* dst, - const int16_t* src, - THPByteOrder order, - size_t len) { - memcpy(dst, src, sizeof(int16_t) * len); - if (order != THP_nativeByteOrder()) { - for (const auto i : c10::irange(len)) { - (void)i; - swapBytes16(dst); - dst += sizeof(int16_t); - } - } -} - -void THP_encodeInt32Buffer( - uint8_t* dst, - const int32_t* src, - THPByteOrder order, - size_t len) { - memcpy(dst, src, sizeof(int32_t) * len); - if (order != THP_nativeByteOrder()) { - for (const auto i : c10::irange(len)) { - (void)i; - swapBytes32(dst); - dst += sizeof(int32_t); - } - } -} - -void THP_encodeInt64Buffer( - uint8_t* dst, - const int64_t* src, - THPByteOrder order, - size_t len) { - memcpy(dst, src, sizeof(int64_t) * len); - if (order != THP_nativeByteOrder()) { - for (const auto i : c10::irange(len)) { - (void)i; - swapBytes64(dst); - dst += sizeof(int64_t); - } - } -} - -void THP_encodeFloatBuffer( - uint8_t* dst, - const float* src, - THPByteOrder order, - size_t len) { - memcpy(dst, src, sizeof(float) * len); - if (order != THP_nativeByteOrder()) { - for (const auto i : c10::irange(len)) { - (void)i; - swapBytes32(dst); - dst += sizeof(float); - } - } -} - -void THP_encodeDoubleBuffer( +template +void THP_encodeBuffer( uint8_t* dst, - const double* src, + const T* src, THPByteOrder order, size_t len) { - memcpy(dst, src, sizeof(double) * len); + memcpy(dst, src, sizeof(T) * len); if (order != THP_nativeByteOrder()) { for (const auto i : c10::irange(len)) { (void)i; - swapBytes64(dst); - dst += sizeof(double); + if constexpr (std::is_same_v) { + swapBytes16(dst); + } else if constexpr ( + std::is_same_v || std::is_same_v) { + swapBytes32(dst); + } else if constexpr ( + std::is_same_v || std::is_same_v) { + swapBytes64(dst); + } + dst += sizeof(T); } } } @@ -436,7 +325,8 @@ std::vector complex_to_float(const c10::complex* src, size_t len) { return new_src; } -void THP_encodeComplexFloatBuffer( +template <> +TORCH_API void THP_encodeBuffer>( uint8_t* dst, const c10::complex* src, THPByteOrder order, @@ -452,7 +342,8 @@ void THP_encodeComplexFloatBuffer( } } -void THP_encodeComplexDoubleBuffer( +template <> +TORCH_API void THP_encodeBuffer>( uint8_t* dst, const c10::complex* src, THPByteOrder order, @@ -468,4 +359,16 @@ void THP_encodeComplexDoubleBuffer( } } +#define DEFINE_ENCODE(TYPE) \ + template TORCH_API void THP_encodeBuffer( \ + uint8_t * dst, const TYPE* src, THPByteOrder order, size_t len); + +DEFINE_ENCODE(int16_t) +DEFINE_ENCODE(int32_t) +DEFINE_ENCODE(int64_t) +DEFINE_ENCODE(float) +DEFINE_ENCODE(double) + +#undef DEFINE_ENCODE + } // namespace torch::utils diff --git a/torch/csrc/utils/byte_order.h b/torch/csrc/utils/byte_order.h index 7fdbe4da8fe44f..d586270fbfc7c2 100644 --- a/torch/csrc/utils/byte_order.h +++ b/torch/csrc/utils/byte_order.h @@ -68,148 +68,13 @@ enum THPByteOrder { THP_LITTLE_ENDIAN = 0, THP_BIG_ENDIAN = 1 }; TORCH_API THPByteOrder THP_nativeByteOrder(); -TORCH_API void THP_decodeInt16Buffer( - int16_t* dst, - const uint8_t* src, - bool do_byte_swap, - size_t len); -TORCH_API void THP_decodeInt32Buffer( - int32_t* dst, - const uint8_t* src, - bool do_byte_swap, - size_t len); -TORCH_API void THP_decodeInt64Buffer( - int64_t* dst, - const uint8_t* src, - bool do_byte_swap, - size_t len); -TORCH_API void THP_decodeHalfBuffer( - c10::Half* dst, - const uint8_t* src, - bool do_byte_swap, - size_t len); -TORCH_API void THP_decodeFloatBuffer( - float* dst, - const uint8_t* src, - bool do_byte_swap, - size_t len); -TORCH_API void THP_decodeDoubleBuffer( - double* dst, - const uint8_t* src, - bool do_byte_swap, - size_t len); -TORCH_API void THP_decodeBoolBuffer(bool* dst, const uint8_t* src, size_t len); -TORCH_API void THP_decodeBFloat16Buffer( - at::BFloat16* dst, - const uint8_t* src, - bool do_byte_swap, - size_t len); -TORCH_API void THP_decodeComplexFloatBuffer( - c10::complex* dst, - const uint8_t* src, - bool do_byte_swap, - size_t len); -TORCH_API void THP_decodeComplexDoubleBuffer( - c10::complex* dst, - const uint8_t* src, - bool do_byte_swap, - size_t len); - -TORCH_API void THP_decodeInt16Buffer( - int16_t* dst, - const uint8_t* src, - THPByteOrder order, - size_t len); -TORCH_API void THP_decodeInt32Buffer( - int32_t* dst, - const uint8_t* src, - THPByteOrder order, - size_t len); -TORCH_API void THP_decodeInt64Buffer( - int64_t* dst, - const uint8_t* src, - THPByteOrder order, - size_t len); -TORCH_API void THP_decodeHalfBuffer( - c10::Half* dst, - const uint8_t* src, - THPByteOrder order, - size_t len); -TORCH_API void THP_decodeFloatBuffer( - float* dst, - const uint8_t* src, - THPByteOrder order, - size_t len); -TORCH_API void THP_decodeDoubleBuffer( - double* dst, - const uint8_t* src, - THPByteOrder order, - size_t len); -TORCH_API void THP_decodeBFloat16Buffer( - at::BFloat16* dst, - const uint8_t* src, - THPByteOrder order, - size_t len); -TORCH_API void THP_decodeFloat8_e5m2Buffer( - at::Float8_e5m2* dst, - const uint8_t* src, - size_t len); -TORCH_API void THP_decodeFloat8_e4m3fnBuffer( - at::Float8_e4m3fn* dst, - const uint8_t* src, - size_t len); -TORCH_API void THP_decodeFloat8_e5m2fnuzBuffer( - at::Float8_e5m2fnuz* dst, - const uint8_t* src, - size_t len); -TORCH_API void THP_decodeFloat8_e4m3fnuzBuffer( - at::Float8_e4m3fnuz* dst, - const uint8_t* src, - size_t len); -TORCH_API void THP_decodeComplexFloatBuffer( - c10::complex* dst, - const uint8_t* src, - THPByteOrder order, - size_t len); -TORCH_API void THP_decodeComplexDoubleBuffer( - c10::complex* dst, - const uint8_t* src, - THPByteOrder order, - size_t len); +template +TORCH_API void THP_decodeBuffer(T* dst, const uint8_t* src, U type, size_t len); -TORCH_API void THP_encodeInt16Buffer( - uint8_t* dst, - const int16_t* src, - THPByteOrder order, - size_t len); -TORCH_API void THP_encodeInt32Buffer( - uint8_t* dst, - const int32_t* src, - THPByteOrder order, - size_t len); -TORCH_API void THP_encodeInt64Buffer( - uint8_t* dst, - const int64_t* src, - THPByteOrder order, - size_t len); -TORCH_API void THP_encodeFloatBuffer( - uint8_t* dst, - const float* src, - THPByteOrder order, - size_t len); -TORCH_API void THP_encodeDoubleBuffer( - uint8_t* dst, - const double* src, - THPByteOrder order, - size_t len); -TORCH_API void THP_encodeComplexFloatBuffer( - uint8_t* dst, - const c10::complex* src, - THPByteOrder order, - size_t len); -TORCH_API void THP_encodeComplexDoubleBuffer( +template +TORCH_API void THP_encodeBuffer( uint8_t* dst, - const c10::complex* src, + const T* src, THPByteOrder order, size_t len); From 9e491037b8e66fa1a8375e289bc855f0755fa655 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Wed, 11 Sep 2024 21:27:53 +0000 Subject: [PATCH 0738/1018] Revert "[Partitioner] Reuse partition to check whether nodes exist (#135317)" This reverts commit e004d539da3335d97a8134c9081245628f18eb67. Reverted https://github.com/pytorch/pytorch/pull/135317 on behalf of https://github.com/izaitsevfb due to BC-breaking, breaks executorch and internal meta builds ([comment](https://github.com/pytorch/pytorch/pull/135317#issuecomment-2344730294)) --- test/test_fx_passes.py | 4 ++-- torch/fx/passes/infra/partitioner.py | 2 +- torch/fx/passes/utils/fuser_utils.py | 13 ++++++------- 3 files changed, 9 insertions(+), 10 deletions(-) diff --git a/test/test_fx_passes.py b/test/test_fx_passes.py index ac78aef325b4ce..e5ed6e078c9eab 100644 --- a/test/test_fx_passes.py +++ b/test/test_fx_passes.py @@ -360,7 +360,7 @@ def test_fuser_util(self, partition): partitions = [] for node_names in partition: - partitions.append(dict.fromkeys([nodes_by_name[name] for name in node_names])) + partitions.append([nodes_by_name[name] for name in node_names]) fused_graph = fuse_by_partitions(gm, partitions) @@ -385,7 +385,7 @@ def test_fuser_util_xfail(self, partition): partitions = [] for node_names in partition: - partitions.append(dict.fromkeys([nodes_by_name[name] for name in node_names])) + partitions.append([nodes_by_name[name] for name in node_names]) with self.assertRaises(Exception): fuse_by_partitions(gm, partitions) diff --git a/torch/fx/passes/infra/partitioner.py b/torch/fx/passes/infra/partitioner.py index 0df74ceba71e8b..271f90a7b75e8b 100644 --- a/torch/fx/passes/infra/partitioner.py +++ b/torch/fx/passes/infra/partitioner.py @@ -268,7 +268,7 @@ def fuse_partitions(self, partitions: List[Partition], prefix: str = "fused_") - # fuse_by_partitions expects partitions in List[List[Node]]: [ [node0, node1], [node2, node3] ] return fuse_by_partitions( self.graph_module, - [partition.nodes for partition in partitions], + [list(partition.nodes) for partition in partitions], prefix=prefix, ) diff --git a/torch/fx/passes/utils/fuser_utils.py b/torch/fx/passes/utils/fuser_utils.py index 3375472528b6a6..11a9cfa34898ab 100644 --- a/torch/fx/passes/utils/fuser_utils.py +++ b/torch/fx/passes/utils/fuser_utils.py @@ -92,7 +92,6 @@ def bfs_find_cycle(root_nodes: NodeList) -> bool: @compatibility(is_backward_compatible=False) def fuse_as_graphmodule(gm: GraphModule, nodes: NodeList, - partition: Dict[Node, None], module_name: str) -> Tuple[GraphModule, Tuple[Node, ...], Tuple[Node, ...]]: """ @@ -136,7 +135,7 @@ def remap_inputs(x): # do something here pass - if x in partition: + if x in nodes: # x is inside subgraph, return the copied node # the node should have been copied aleady, as we are copying graph in the topological order return node_map[x] @@ -160,7 +159,7 @@ def remap_inputs(x): for node in nodes: for user_node in node.users: - if user_node not in partition: + if user_node not in nodes: # external user node, need to expose as an output output_mapping[node] = node_map[node] @@ -220,12 +219,12 @@ def erase_nodes(gm: GraphModule, nodes: NodeList): @compatibility(is_backward_compatible=False) -def fuse_by_partitions(gm: GraphModule, partitions: List[Dict[Node, None]], prefix: str = "fused_") -> GraphModule: - for partition_id, partition in enumerate(partitions): - sorted_nodes = topo_sort(list(partition)) +def fuse_by_partitions(gm: GraphModule, partitions: List[NodeList], prefix: str = "fused_") -> GraphModule: + for partition_id, nodes in enumerate(partitions): + sorted_nodes = topo_sort(nodes) submodule_name = prefix + str(partition_id) - sub_gm, orig_inputs, orig_outputs = fuse_as_graphmodule(gm, sorted_nodes, partition, submodule_name) + sub_gm, orig_inputs, orig_outputs = fuse_as_graphmodule(gm, sorted_nodes, submodule_name) insert_subgm(gm, sub_gm, orig_inputs, orig_outputs) From adddc6dd42e9dc2730b4b2a333d7ce1e49979aa8 Mon Sep 17 00:00:00 2001 From: Shubham Bhokare <32080845+shubhambhokare1@users.noreply.github.com> Date: Wed, 11 Sep 2024 21:29:02 +0000 Subject: [PATCH 0739/1018] [ONNX] Update fake mode usage in onnx docs (#135512) Update fake mode usage in onnx docs Pull Request resolved: https://github.com/pytorch/pytorch/pull/135512 Approved by: https://github.com/justinchuby Co-authored-by: Justin Chu --- torch/onnx/_internal/_exporter_legacy.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/torch/onnx/_internal/_exporter_legacy.py b/torch/onnx/_internal/_exporter_legacy.py index 71b828b9bceb7a..0222da61cfef48 100644 --- a/torch/onnx/_internal/_exporter_legacy.py +++ b/torch/onnx/_internal/_exporter_legacy.py @@ -395,8 +395,7 @@ def enable_fake_mode(): are too large to fit into memory. Returns: - A :class:`ONNXFakeContext` object that must be passed to :func:`dynamo_export` - through the :attr:`ExportOptions.fake_context` argument. + A :class:`ONNXFakeContext` object. Example:: @@ -414,15 +413,16 @@ def enable_fake_mode(): ... my_nn_module = MyModel() ... arg1 = torch.randn(2, 2, 2) # positional input 1 >>> export_options = torch.onnx.ExportOptions(fake_context=fake_context) - >>> onnx_program = torch.onnx.dynamo_export( - ... my_nn_module, - ... arg1, - ... export_options=export_options - ... ) + >>> onnx_program = torch.onnx.export(my_nn_module, (arg1,), dynamo=True) + >>> onnx_program.apply_weights(MyModel().state_dict()) >>> # Saving model WITHOUT initializers - >>> onnx_program.save("my_model_without_initializers.onnx") + >>> onnx_program.save( + ... "my_model_without_initializers.onnx", + ... include_initializers=False, + ... keep_initializers_as_inputs=True, + ... ) >>> # Saving model WITH initializers - >>> onnx_program.save("my_model_with_initializers.onnx", model_state=MyModel().state_dict()) + >>> onnx_program.save("my_model_with_initializers.onnx") .. warning:: This API is experimental and is *NOT* backward-compatible. From 782ca2651fceb67d3d5a66a23b4a7a59e9ee248c Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Wed, 11 Sep 2024 22:02:04 +0000 Subject: [PATCH 0740/1018] fix for fp16 (#134106) This PR is a replacement for https://github.com/pytorch/pytorch/pull/133085 for pushing a quick fix for RMSNorm. The original author is @kkontny Previous PR summary: Since FP16 has quite small dynamic range it is very easy to overflow while computing `at::pow(input, 2)` , and it happens in real world computation. I've tried to use `nn.RMSNorm` fused implementation instead of `LlamaRMSNorm` inside `transformers` implementation of Llama (`src/transformers/models/llama/modeling_llama.py`). It started to give wrong answers in Fp16 while still giving good in FP32. I figured out happens due to overflow while computing square of the input tensor. Original `LLamaRMSNorm` implementation upcasts input to fp32 to prevent this and give better numerical stability. ``` class LlamaRMSNorm(nn.Module): def __init__(self, hidden_size, eps=1e-6): """ LlamaRMSNorm is equivalent to T5LayerNorm """ super().__init__() self.weight = nn.Parameter(torch.ones(hidden_size)) self.variance_epsilon = eps def forward(self, hidden_states): input_dtype = hidden_states.dtype hidden_states = hidden_states.to(torch.float32) variance = hidden_states.pow(2).mean(-1, keepdim=True) hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) return self.weight * hidden_states.to(input_dtype) ``` Proposed commit fixed the issue. FP16 in RMSNorm has to be treated in special way, to be usable in real world implementations. Pull Request resolved: https://github.com/pytorch/pytorch/pull/134106 Approved by: https://github.com/mikaylagawarecki, https://github.com/eqy --- aten/src/ATen/native/layer_norm.cpp | 8 +++++++- torch/testing/_internal/common_methods_invocations.py | 2 +- torch/testing/_internal/common_modules.py | 4 +++- 3 files changed, 11 insertions(+), 3 deletions(-) diff --git a/aten/src/ATen/native/layer_norm.cpp b/aten/src/ATen/native/layer_norm.cpp index b11bcaba38e689..6e0abb14f8be91 100644 --- a/aten/src/ATen/native/layer_norm.cpp +++ b/aten/src/ATen/native/layer_norm.cpp @@ -6,6 +6,7 @@ #include #include #include +#include #ifndef AT_PER_OPERATOR_HEADERS #include @@ -295,7 +296,12 @@ Tensor rms_norm( eps_val = eps.value(); } - auto result = input.mul(at::rsqrt(at::pow(input, 2).mean(dims_to_reduce_ref, /*keep_dim=*/true).add_(eps_val))); + // upcast is needed for fp16 and bf16 + c10::ScalarType opmath_t = toOpMathType(input.scalar_type()); + Tensor upcasted_input = input.to(opmath_t); + + Tensor rqrst_input = rsqrt(at::pow(upcasted_input, 2).mean(dims_to_reduce_ref, /*keep_dim=*/true).add_(eps_val)); + Tensor result = upcasted_input.mul(rqrst_input).type_as(input); if (weight_opt.has_value()) { result = result.mul(weight_opt.value()); diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py index a6f67d75659689..a174154349f8e4 100644 --- a/torch/testing/_internal/common_methods_invocations.py +++ b/torch/testing/_internal/common_methods_invocations.py @@ -4541,7 +4541,7 @@ def sample_inputs_native_layer_norm(opinfo, device, dtype, requires_grad, **kwar ) def sample_inputs_rms_norm(opinfo, device, dtype, requires_grad, **kwargs): - make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad, high=1000) # Ordered as input shape, normalized_shape and a kwarg dict for eps cases: Tuple[Tuple[int], Tuple[int], dict] = ( # type: ignore[assignment] diff --git a/torch/testing/_internal/common_modules.py b/torch/testing/_internal/common_modules.py index 00251ca264f840..63963bab1b050a 100644 --- a/torch/testing/_internal/common_modules.py +++ b/torch/testing/_internal/common_modules.py @@ -1937,7 +1937,9 @@ def rms_norm_reference_fn(m, p, i): normalized_shape = m.normalized_shape weight = m.weight dims = [ndim - i - 1 for i in range(len(normalized_shape))] - result = i * torch.rsqrt(i.pow(2).mean(dim=dims, keepdim=True) + m.eps) + upcasted_i = i.float() + result = upcasted_i * torch.rsqrt(upcasted_i.pow(2).mean(dim=dims, keepdim=True) + m.eps) + result = result.type_as(i) if weight is not None: result *= weight return result From 8618a6ea3c507b85711d5387bce577007bcabd1e Mon Sep 17 00:00:00 2001 From: Alexander Jipa Date: Wed, 11 Sep 2024 22:32:52 +0000 Subject: [PATCH 0741/1018] Fix/torch cat doc attr (#135698) The `torch.cat` attr name for tensors in the docs differs from the method signature, unlike other methods. Pull Request resolved: https://github.com/pytorch/pytorch/pull/135698 Approved by: https://github.com/albanD Co-authored-by: Alexander Jipa --- torch/_torch_docs.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch/_torch_docs.py b/torch/_torch_docs.py index 70f7192ee2d73c..8d83872f863802 100644 --- a/torch/_torch_docs.py +++ b/torch/_torch_docs.py @@ -2276,7 +2276,7 @@ def merge_dicts(*dicts): r""" cat(tensors, dim=0, *, out=None) -> Tensor -Concatenates the given sequence of :attr:`seq` tensors in the given dimension. +Concatenates the given sequence of tensors in :attr:`tensors` in the given dimension. All tensors must either have the same shape (except in the concatenating dimension) or be a 1-D empty tensor with size ``(0,)``. From 0f8a70829061d867a0a182b04dcee8d9c7812195 Mon Sep 17 00:00:00 2001 From: Will Feng Date: Wed, 11 Sep 2024 12:39:07 -0700 Subject: [PATCH 0742/1018] [Traceable FSDP2] Use .copy_ instead of .set_ for unsharded_param inplace update; Replace unsharded_param graph input usage with graph intermediate; Support FSDP2+LoRA (#133730) Using `fsdp.set_` for unsharded_param inplace update causes difficult-to-debug errors when enabling Traceable FSDP2 on TorchTune models. In this PR, we change it to use `fsdp.copy_` which fixes the error and also strictly follows eager semantics (i.e. if user explictly stores an alias of the unsharded_param during execution of the user's module code, that alias will get updated correctly when the unsharded_param is copy_ into; whereas if we just swap out unsharded_param storage via set_, that user-saved alias will not get updated, which is not good). This PR also implements the graph pass to remove the resizes and copy if there is a resize_(full) -> copy_ -> resize_(0) pattern. ------ Test commands: - `pytest -rA test/distributed/_composable/fsdp/test_fully_shard_compile.py::TestFullyShardCompile::test_transformer_backend_inductor` - `pytest -rA test/distributed/_composable/fsdp/test_fully_shard_compile.py::TestFullyShardCompile::test_nested_fully_shard_backend_inductor` - `pytest -rA test/distributed/_composable/fsdp/test_fully_shard_compile.py::TestFullyShardCompile::test_trace_fsdp_copy_` - `pytest -rA test/dynamo/test_repros.py::ReproTests::test_partitioner_cse_respects_mutation_boundaries` - `pytest -rA test/dynamo/test_repros.py::ReproTests::test_fsdp_set_input_mutation_applied_when_input_gets_no_gradients` - `pytest -rA test/inductor/test_pattern_matcher.py::TestPatternMatcher::test_mutation_op_matching` - `python test/inductor/test_distributed_patterns.py DistributedPatternTests.test_fake_distributed_aot_eager` - `PYTORCH_OPINFO_SAMPLE_INPUT_INDEX=1 PYTORCH_TEST_WITH_CROSSREF=1 python test/functorch/test_aotdispatch.py TestEagerFusionOpInfoCPU.test_aot_autograd_exhaustive_norm_cpu_float32` - `python test/distributed/test_inductor_collectives.py TestCollectivesInductor.test_backwards` Pull Request resolved: https://github.com/pytorch/pytorch/pull/133730 Approved by: https://github.com/bdhirsh --- .../fsdp/test_fully_shard_compile.py | 287 ++++++++++++------ test/dynamo/test_repros.py | 4 +- test/functorch/test_aotdispatch.py | 53 ---- test/inductor/test_distributed_patterns.py | 47 +-- test/inductor/test_pattern_matcher.py | 2 +- .../_aot_autograd/functional_utils.py | 4 +- torch/_functorch/partitioners.py | 1 + torch/_inductor/comms.py | 204 ++++++++++++- torch/_inductor/fx_passes/post_grad.py | 4 + torch/_inductor/lowering.py | 16 +- .../_composable/fsdp/_fsdp_param.py | 96 +++--- 11 files changed, 482 insertions(+), 236 deletions(-) diff --git a/test/distributed/_composable/fsdp/test_fully_shard_compile.py b/test/distributed/_composable/fsdp/test_fully_shard_compile.py index b187c867694171..873fc4ba8bfaa2 100644 --- a/test/distributed/_composable/fsdp/test_fully_shard_compile.py +++ b/test/distributed/_composable/fsdp/test_fully_shard_compile.py @@ -4,7 +4,10 @@ import contextlib import copy import functools +import itertools +import logging import unittest +from collections import defaultdict from unittest import mock import torch @@ -30,8 +33,11 @@ from torch.utils._triton import has_triton -def _is_op_in_graph(graph, op): - return any(node.target is op for node in graph.nodes) +log = logging.getLogger(__name__) + + +def _count_op_in_graph(graph, op): + return sum(1 for node in graph.nodes if node.target is op) def _is_fallback_op_in_snodes(snodes, op): @@ -130,7 +136,7 @@ def f(x): self.assertEqual(cnt.op_count, 1) self.assertEqual(len(cnt.graphs), 1) - def test_trace_fsdp_set_(self): + def test_trace_fsdp_copy_(self): @torch.library.custom_op("mylib::add_one_out", mutates_args={"out"}) def add_one_out(x: torch.Tensor, out: torch.Tensor) -> None: torch.add(x, 1, out=out) @@ -140,7 +146,7 @@ def f(x): buf_view = buf.view(-1) torch.ops.mylib.add_one_out(x, out=buf_view) buf_view2 = buf.view(-1) - torch.ops.fsdp.set_(x, buf_view2) + torch.ops.fsdp.copy_(x, buf_view2) ref_x = torch.zeros(2) x = copy.deepcopy(ref_x) @@ -148,26 +154,80 @@ def f(x): torch.compile(f, backend="aot_eager")(x) self.assertEqual(x, ref_x) + def _assert_no_aliased_graph_inputs(self, graph: torch.fx.Graph) -> None: + storage_id_to_graph_inputs = defaultdict(list) + for node in graph.nodes: + if node.op == "placeholder" and isinstance( + node.meta.get("val", None), torch.Tensor + ): + storage_id_to_graph_inputs[ + id(node.meta["val"].untyped_storage()) + ].append(node) + no_aliased_graph_inputs = True + err_msg = "" + for aliased_graph_inputs in storage_id_to_graph_inputs.values(): + if len(aliased_graph_inputs) > 1: + no_aliased_graph_inputs = False + err_msg += f"""\n +Found aliased graph inputs: {aliased_graph_inputs}, +val.shape: {[node.meta['val'].shape for node in aliased_graph_inputs]}, +""" + self.assertTrue(no_aliased_graph_inputs, err_msg) + + def _check_fsdp_copy_and_resize_ops_count_in_graph( + self, + graph, + *, + fwd_copy_count, + fwd_resize_count, + bwd_copy_count, + bwd_resize_count, + ): + def _check_count(copy_count, resize_count): + actual_copy_count = _count_op_in_graph(graph, torch.ops.fsdp.copy_.default) + self.assertEqual( + actual_copy_count, + copy_count, + f"Unexpected number of `fsdp.copy_` ops (expected {copy_count}, got {actual_copy_count}) in graph: {graph}", + ) + + actual_resize_count = _count_op_in_graph( + graph, torch.ops.inductor.resize_storage_bytes_.default + ) + self.assertEqual( + actual_resize_count, + resize_count, + f"Unexpected number of `inductor.resize_storage_bytes_` ops (expected {resize_count}, got {actual_resize_count}) in graph: {graph}", # noqa: B950 + ) + + if not torch._dynamo.compiled_autograd.in_compiled_autograd_region: + _check_count(fwd_copy_count, fwd_resize_count) # fwd graph + else: + _check_count(bwd_copy_count, bwd_resize_count) # bwd graph + def _reinplace_all_gather_with_optional_checks(self, fullgraph): def _run_with_checks(graph, orig_fn): - self.assertTrue( - _is_op_in_graph( - graph, - torch.ops._c10d_functional.all_gather_into_tensor.default, - ) + self.assertGreater( + _count_op_in_graph( + graph, torch.ops._c10d_functional.all_gather_into_tensor.default + ), + 0, ) + orig_fn(graph) - self.assertFalse( - _is_op_in_graph( - graph, - torch.ops._c10d_functional.all_gather_into_tensor.default, - ) + + self.assertEqual( + _count_op_in_graph( + graph, torch.ops._c10d_functional.all_gather_into_tensor.default + ), + 0, ) - self.assertTrue( - _is_op_in_graph( - graph, - torch.ops._c10d_functional.all_gather_into_tensor_out.default, - ) + + self.assertGreater( + _count_op_in_graph( + graph, torch.ops._c10d_functional.all_gather_into_tensor_out.default + ), + 0, ) if fullgraph: @@ -266,8 +326,6 @@ def inductor_code_check_fsdp_all_gather( self, file_check, overlapped_compute_op_str, - num_resize, - num_set, last_all_gather=False, ): file_check = file_check.check("torch.ops.fsdp.all_gather_copy_in.") @@ -278,16 +336,9 @@ def inductor_code_check_fsdp_all_gather( # Checks that AGWait is delayed, making the AG overlap with some compute op. if overlapped_compute_op_str is not None: file_check = file_check.check(f"{overlapped_compute_op_str}") - file_check = file_check.check_count( - "inductor_ops.resize_storage_bytes_(", num_resize, exactly=True - ) file_check = file_check.check("torch.ops._c10d_functional.wait_tensor.") file_check = self.inductor_code_check_no_compute_op(file_check) file_check = file_check.check("torch.ops.fsdp.split_with_sizes_copy.") - file_check = self.inductor_code_check_no_compute_op(file_check) - file_check = file_check.check_count( - "torch.ops.aten.set_.", num_set, exactly=True - ) if not last_all_gather: # Checks that there is no compute op between this AGWait and next AG. file_check = self.inductor_code_check_no_compute_op(file_check) @@ -307,20 +358,6 @@ def inductor_code_check_fsdp_reduce_scatter( file_check = file_check.check("torch.ops._c10d_functional.wait_tensor.") return file_check - @torch._dynamo.config.patch( - inline_inbuilt_nn_modules=True, - skip_fsdp_hooks=False, - ) - @torch._functorch.config.patch(recompute_views=True) - @torch._functorch.config.patch(cse=False) - @torch._inductor.config.patch( - reorder_for_compute_comm_overlap=True, - reorder_for_compute_comm_overlap_passes=[ - "sink_waits", - "raise_comms", - "reorder_compute_for_overlap", - ], - ) def _test_traceable_fsdp( self, model_init_fn, input_creation_fn, backend, fullgraph ): @@ -334,7 +371,12 @@ def _fn(gm): return _fn - def run_iters(model, optim, n_iter=10, compiled_autograd_backend=None): + def run_iters( + model, + optim, + n_iter=10, + compiled_autograd_backend=None, + ): torch.manual_seed(42) losses = [] for i in range(n_iter): @@ -360,7 +402,11 @@ def test_compiled(): run_iters(model, optim, n_iter=1) model_compiled = torch.compile(model, backend=backend, fullgraph=fullgraph) - res = run_iters(model_compiled, optim, compiled_autograd_backend=backend) + res = run_iters( + model_compiled, + optim, + compiled_autograd_backend=backend, + ) return res def test_eager(): @@ -371,7 +417,23 @@ def test_eager(): res = run_iters(model, optim) return res - losses_compiled = test_compiled() + with torch._dynamo.config.patch( + inline_inbuilt_nn_modules=True, + skip_fsdp_hooks=False, + ), torch._functorch.config.patch( + recompute_views=True, cse=False + ), torch._inductor.config.patch( + reorder_for_compute_comm_overlap=True, + reorder_for_compute_comm_overlap_passes=[ + "sink_waits", + "raise_comms", + "reorder_compute_for_overlap", + ], + post_grad_custom_pre_pass=self._assert_no_aliased_graph_inputs + if fullgraph + else None, + ): + losses_compiled = test_compiled() losses_eager = test_eager() if not self.fake_pg: for loss_compiled, loss_eager in zip(losses_compiled, losses_eager): @@ -448,9 +510,9 @@ def __init__(self, hidden_dim): ) def forward(self, x): + ret = torch.matmul(x, self.param1) if not fullgraph: torch._dynamo.graph_break() - ret = torch.matmul(x, self.param1) ret = ret * self.param2 ret = torch.relu(ret) return ret @@ -519,7 +581,19 @@ def test_nested_fully_shard_backend_inductor(self): for fullgraph in [True, False]: with self._reinplace_all_gather_with_optional_checks( fullgraph - ), self._maybe_run_decide_global_ordering_of_comms_with_checks(fullgraph): + ), self._maybe_run_decide_global_ordering_of_comms_with_checks( + fullgraph + ), torch._inductor.config.patch( + post_grad_custom_post_pass=functools.partial( + self._check_fsdp_copy_and_resize_ops_count_in_graph, + fwd_copy_count=0, + fwd_resize_count=0, + bwd_copy_count=0, + bwd_resize_count=0, + ) + if fullgraph + else None + ): _, triton_codes = run_and_get_code( lambda: self._test_traceable_fsdp( *self._create_nested_fully_shard_factory_fns( @@ -537,46 +611,30 @@ def test_nested_fully_shard_backend_inductor(self): fwd_code = triton_codes[0] file_check = FileCheck().check("def call(args):") for fwd_ag_block_info in [ - dict(overlapped_compute_op_str=None, num_resize=0, num_set=2), + dict(overlapped_compute_op_str=None), dict( overlapped_compute_op_str="extern_kernels.mm(", - num_resize=2, - num_set=2, ), dict( overlapped_compute_op_str="extern_kernels.mm(", - num_resize=2, - num_set=2, ), dict( overlapped_compute_op_str="extern_kernels.mm(", - num_resize=2, - num_set=2, ), dict( overlapped_compute_op_str="extern_kernels.mm(", - num_resize=2, - num_set=2, ), dict( overlapped_compute_op_str="extern_kernels.mm(", - num_resize=2, - num_set=2, ), dict( overlapped_compute_op_str="extern_kernels.mm(", - num_resize=2, - num_set=2, ), dict( overlapped_compute_op_str="extern_kernels.mm(", - num_resize=2, - num_set=2, ), dict( overlapped_compute_op_str="extern_kernels.mm(", - num_resize=2, - num_set=2, last_all_gather=True, ), ]: @@ -588,16 +646,12 @@ def test_nested_fully_shard_backend_inductor(self): bwd_code = triton_codes[1] file_check = FileCheck().check("def call(args):") for bwd_ag_block_info in [ - dict(overlapped_compute_op_str=None, num_resize=0, num_set=2), + dict(overlapped_compute_op_str=None), dict( overlapped_compute_op_str="extern_kernels.mm(", - num_resize=0, - num_set=2, ), dict( overlapped_compute_op_str="extern_kernels.mm(", - num_resize=0, - num_set=2, last_all_gather=True, ), ]: @@ -605,7 +659,7 @@ def test_nested_fully_shard_backend_inductor(self): file_check, **bwd_ag_block_info ) for bwd_rs_block_info in [ - dict(overlapped_compute_op_str="extern_kernels.mm("), + dict(overlapped_compute_op_str="extern_kernels.addmm("), dict( overlapped_compute_op_str=None ), # TODO: improve compute/comm overlap, so that `overlapped_compute_op_str` is not None @@ -623,9 +677,10 @@ def test_nested_fully_shard_backend_inductor(self): "Expected at least 3 separate lowerings to Triton code, which means at least 1 graph break in FWD graph", ) - def _create_transformer_factory_fns(self): + def _create_transformer_factory_fns(self, all_requires_grad): seq_len = 16 vocab_size = 8 + n_layers = 3 def model_init_fn(): torch.manual_seed(self.rank) @@ -633,9 +688,20 @@ def model_init_fn(): mesh = init_device_mesh("cuda", (self.world_size,)) model_args = ModelArgs( vocab_size=vocab_size, - n_layers=3, + n_layers=n_layers, ) model = Transformer(model_args) + if not all_requires_grad: + requires_grad_params = ["attention.wq", "attention.wv"] + requires_grad_param_count = 0 + for k, v in model.named_parameters(): + for substring in requires_grad_params: + if substring in k: + v.requires_grad_(True) + requires_grad_param_count += 1 + else: + v.requires_grad_(False) + assert requires_grad_param_count == n_layers * len(requires_grad_params) for layer_id, mod in enumerate(model.layers): fully_shard(mod, mesh=mesh, reshard_after_forward=True, **fsdp_config) model = fully_shard( @@ -672,12 +738,16 @@ def _sdpa_with_graph_break(orig_fn, fullgraph, *args, **kwargs): @skipIfRocm @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") def test_transformer_backend_aot_eager(self): - for fullgraph in [True, False]: + for fullgraph, all_requires_grad in itertools.product( + [True, False], [True, False] + ): with self._maybe_add_graph_break_to_sdpa( fullgraph ), self._reinplace_all_gather_with_optional_checks(fullgraph): self._test_traceable_fsdp( - *self._create_transformer_factory_fns(), + *self._create_transformer_factory_fns( + all_requires_grad=all_requires_grad + ), "aot_eager", fullgraph=fullgraph, ) @@ -687,10 +757,14 @@ def test_transformer_backend_aot_eager(self): # TODO: native_dropout has worse accuracy after decomp, need to figure out why @torch._inductor.config.patch(fallback_random=True) def test_transformer_backend_aot_eager_decomp_partition(self): - for fullgraph in [True, False]: + for fullgraph, all_requires_grad in itertools.product( + [True, False], [True, False] + ): with self._maybe_add_graph_break_to_sdpa(fullgraph): self._test_traceable_fsdp( - *self._create_transformer_factory_fns(), + *self._create_transformer_factory_fns( + all_requires_grad=all_requires_grad + ), "aot_eager_decomp_partition", fullgraph=fullgraph, ) @@ -700,17 +774,36 @@ def test_transformer_backend_aot_eager_decomp_partition(self): # TODO: native_dropout causes CUDA IMA error, need to figure out why @torch._inductor.config.patch(fallback_random=True) def test_transformer_backend_inductor(self): - for fullgraph in [True, False]: + # TODO: enable fullgraph=False case + for fullgraph, all_requires_grad in itertools.product([True], [True, False]): + log.warning( + f"fullgraph={fullgraph}, all_requires_grad={all_requires_grad}" # noqa: G004, G001 + ) with self._maybe_add_graph_break_to_sdpa( fullgraph ), self._reinplace_all_gather_with_optional_checks( fullgraph ), self._maybe_run_decide_global_ordering_of_comms_with_checks( fullgraph + ), torch._inductor.config.patch( + post_grad_custom_post_pass=functools.partial( + self._check_fsdp_copy_and_resize_ops_count_in_graph, + # NOTE: For the root unsharded params, we don't reshard after forward since for training, + # the parameters would be freed and all-gathered immediately. Hence we still have + # their resize and copy ops in the graph. + fwd_copy_count=4, + fwd_resize_count=4, + bwd_copy_count=0, + bwd_resize_count=4, + ) + if fullgraph + else None ): _, triton_codes = run_and_get_code( lambda: self._test_traceable_fsdp( - *self._create_transformer_factory_fns(), + *self._create_transformer_factory_fns( + all_requires_grad=all_requires_grad + ), "inductor", fullgraph=fullgraph, ) @@ -723,21 +816,19 @@ def test_transformer_backend_inductor(self): fwd_code = triton_codes[0] file_check = FileCheck().check("def call(args):") for fwd_ag_block_info in [ - dict(overlapped_compute_op_str="triton_", num_resize=0, num_set=4), + dict( + overlapped_compute_op_str="triton_" + if all_requires_grad + else None, + ), dict( overlapped_compute_op_str="aten.native_dropout.", - num_resize=0, - num_set=12, ), dict( overlapped_compute_op_str="aten._scaled_dot_product_efficient_attention.", - num_resize=12, - num_set=12, ), dict( overlapped_compute_op_str="aten._scaled_dot_product_efficient_attention.", - num_resize=12, - num_set=12, last_all_gather=True, ), ]: @@ -751,35 +842,33 @@ def test_transformer_backend_inductor(self): for bwd_ag_block_info in [ dict( overlapped_compute_op_str="extern_kernels.mm(", - num_resize=0, - num_set=12, ), dict( overlapped_compute_op_str="aten._scaled_dot_product_efficient_attention_backward.", - num_resize=0, - num_set=12, ), dict( overlapped_compute_op_str="aten._scaled_dot_product_efficient_attention_backward.", - num_resize=0, - num_set=12, last_all_gather=True, ), ]: - file_check = self.inductor_code_check_fsdp_all_gather( - file_check, **bwd_ag_block_info - ) + if bwd_ag_block_info is not None: + file_check = self.inductor_code_check_fsdp_all_gather( + file_check, **bwd_ag_block_info + ) for bwd_rs_block_info in [ - dict(overlapped_compute_op_str="extern_kernels.mm("), + dict(overlapped_compute_op_str="extern_kernels.mm(") + if all_requires_grad + else None, dict( overlapped_compute_op_str=None ), # TODO: improve compute/comm overlap, so that `overlapped_compute_op_str` is not None dict(overlapped_compute_op_str=None), - dict(overlapped_compute_op_str=None), + dict(overlapped_compute_op_str=None) if all_requires_grad else None, ]: - file_check = self.inductor_code_check_fsdp_reduce_scatter( - file_check, **bwd_rs_block_info - ) + if bwd_rs_block_info is not None: + file_check = self.inductor_code_check_fsdp_reduce_scatter( + file_check, **bwd_rs_block_info + ) file_check.run(bwd_code) else: # TODO: when fullgraph=False and there is graph break in FWD graph, diff --git a/test/dynamo/test_repros.py b/test/dynamo/test_repros.py index a3c4226ed6a9f8..66ca040dd4170a 100644 --- a/test/dynamo/test_repros.py +++ b/test/dynamo/test_repros.py @@ -5557,7 +5557,7 @@ def f(x, l): z0 = x.sin() z1 = x.sin() y = x + 1 - torch.ops.fsdp.set_.default(x, y) + torch.ops.fsdp.copy_.default(x, y) # z3 and z3 can be CSEd with each other, # but *not* with z0/z1 (they cross a mutation boundary) z2 = x.sin() @@ -5589,7 +5589,7 @@ def f(x, l): z = x.sin() y = x + 1 # graph input has its storage mutated - torch.ops.fsdp.set_.default(x, y) + torch.ops.fsdp.copy_.default(x, y) z2 = x.sin() return z2, l**2 diff --git a/test/functorch/test_aotdispatch.py b/test/functorch/test_aotdispatch.py index f62d46df470bac..45462674fe8eec 100644 --- a/test/functorch/test_aotdispatch.py +++ b/test/functorch/test_aotdispatch.py @@ -725,59 +725,6 @@ def forward(self, primals_1): return (add,)""", ) - @unittest.skipIf(IS_WINDOWS, "TODO: need to fix the test case") - @unittest.skipIf(IS_MACOS, "TODO: need to fix the test case") - def test_input_mutation_fsdp_set__into_same_input(self): - import torch.distributed._composable.fsdp._fsdp_param - - def f(a): - b = torch.arange(9, dtype=a.dtype).view(3, 3) - c = torch.arange(9, dtype=a.dtype).view(3, 3) - d = torch.arange(9, dtype=a.dtype).view(3, 3) - with torch.no_grad(), torch.autograd._unsafe_preserve_version_counter(a): - torch.ops.fsdp.set_.default(a, b) - x = a * a - with torch.no_grad(), torch.autograd._unsafe_preserve_version_counter(a): - torch.ops.fsdp.set_.default(a, c) - y = a * a - with torch.no_grad(), torch.autograd._unsafe_preserve_version_counter(a): - torch.ops.fsdp.set_.default(a, c) - z = a * a - return x + y + z - - inp = [torch.ones(3, 3, requires_grad=True)] - fw_graph = self.verify_aot_autograd( - f, inp, test_mutation=True, keep_inp_mutations=True - ) - inp = [torch.ones(3, 3, requires_grad=False)] - self.verify_aot_autograd(f, inp, test_mutation=True, keep_inp_mutations=True) - """ - Expected behavior: - (1) When there are multiple set_() calls on the same graph input primal_X, - we want those set_() calls to all show up with primal_X as the first arg in the graph. - (2) Behavior (1) is not the case today with normal aten.set_ (blocked on #129892), - but using a custom fsdp.set_ op with no returns is a simple workaround to achieve that behavior. - """ - self.assertExpectedInline( - fw_graph.code.strip(), - """\ -def forward(self, primals_1): - arange = torch.ops.aten.arange.default(9, dtype = torch.float32, device = device(type='cpu'), pin_memory = False) - view = torch.ops.aten.view.default(arange, [3, 3]); arange = None - arange_1 = torch.ops.aten.arange.default(9, dtype = torch.float32, device = device(type='cpu'), pin_memory = False) - view_1 = torch.ops.aten.view.default(arange_1, [3, 3]); arange_1 = None - set_ = torch.ops.fsdp.set_.default(primals_1, view); view = set_ = None - mul = torch.ops.aten.mul.Tensor(primals_1, primals_1) - set__1 = torch.ops.fsdp.set_.default(primals_1, view_1); set__1 = None - mul_1 = torch.ops.aten.mul.Tensor(primals_1, primals_1) - set__2 = torch.ops.fsdp.set_.default(primals_1, view_1); view_1 = set__2 = None - mul_2 = torch.ops.aten.mul.Tensor(primals_1, primals_1) - add = torch.ops.aten.add.Tensor(mul, mul_1); mul = mul_1 = None - add_1 = torch.ops.aten.add.Tensor(add, mul_2); add = mul_2 = None - return (add_1, primals_1)""", - ) - self.assertEqual(torch.compile(f, backend="inductor")(*inp), f(*inp)) - def test_input_mutation_simple_with_none_and_nontensor(self): # Tensor, None, int def f(a, b, c): diff --git a/test/inductor/test_distributed_patterns.py b/test/inductor/test_distributed_patterns.py index 9aa070e979a0ae..fd446370434e94 100644 --- a/test/inductor/test_distributed_patterns.py +++ b/test/inductor/test_distributed_patterns.py @@ -26,22 +26,13 @@ def reduce_scatter(t): return t.narrow(0, 0, t.size(0) // WORLD_SIZE).clone() def fw_pre_hook(mod, inp): - if not compiled_autograd.compiled_autograd_enabled: - # torch.ops.fsdp.set_ doesn't work well in eager mode, so use the slow copy_ path instead. - mod.unsharded_weight.untyped_storage().resize_( - mod.unsharded_weight.nelement() * mod.unsharded_weight.element_size() - ) - with torch.no_grad(), torch.autograd._unsafe_preserve_version_counter( - mod.unsharded_weight - ): - mod.unsharded_weight.copy_(all_gather(mod.sharded_weight)) - else: - with torch.no_grad(), torch.autograd._unsafe_preserve_version_counter( - mod.unsharded_weight - ): - torch.ops.fsdp.set_( - mod.unsharded_weight, all_gather(mod.sharded_weight) - ) + mod.unsharded_weight.untyped_storage().resize_( + mod.unsharded_weight.nelement() * mod.unsharded_weight.element_size() + ) + with torch.no_grad(), torch.autograd._unsafe_preserve_version_counter( + mod.unsharded_weight + ): + torch.ops.fsdp.copy_(mod.unsharded_weight, all_gather(mod.sharded_weight)) mod._parameters["weight"] = mod.unsharded_weight # Forward: @@ -58,22 +49,13 @@ def fw_post_hook(mod, inp, out): mod.unsharded_weight.untyped_storage().resize_(0) def bw_pre_hook(mod, gO): - if not compiled_autograd.compiled_autograd_enabled: - # torch.ops.fsdp.set_ doesn't work well in eager mode, so use the slow copy_ path instead. - mod.unsharded_weight.untyped_storage().resize_( - mod.unsharded_weight.nelement() * mod.unsharded_weight.element_size() - ) - with torch.no_grad(), torch.autograd._unsafe_preserve_version_counter( - mod.unsharded_weight - ): - mod.unsharded_weight.copy_(all_gather(mod.sharded_weight)) - else: - with torch.no_grad(), torch.autograd._unsafe_preserve_version_counter( - mod.unsharded_weight - ): - torch.ops.fsdp.set_( - mod.unsharded_weight, all_gather(mod.sharded_weight) - ) + mod.unsharded_weight.untyped_storage().resize_( + mod.unsharded_weight.nelement() * mod.unsharded_weight.element_size() + ) + with torch.no_grad(), torch.autograd._unsafe_preserve_version_counter( + mod.unsharded_weight + ): + torch.ops.fsdp.copy_(mod.unsharded_weight, all_gather(mod.sharded_weight)) mod._parameters["weight"] = mod.unsharded_weight # Backward: @@ -465,7 +447,6 @@ def test_fake_distributed_aot_eager(self): @requires_gpu() @torch._functorch.config.patch(recompute_views=True) def test_fake_distributed_inductor(self): - # TODO: fix .set_ lowering in CPU inductor, and enable the CPU test. m1, inp1 = init_fake_distributed(GPU_TYPE) out1 = steps(m1, inp1) diff --git a/test/inductor/test_pattern_matcher.py b/test/inductor/test_pattern_matcher.py index 4c7ad746c6cc3f..5f34b9ceb69a25 100644 --- a/test/inductor/test_pattern_matcher.py +++ b/test/inductor/test_pattern_matcher.py @@ -1447,7 +1447,7 @@ def check(type, func_name, args, kwargs, expect=True): (t, [64, 128, 8, 8]), {"dim": 1, "out": [t, t, t, t]}, ) - check("call_function", torch.ops.fsdp.set_, (t, t), {}) + check("call_function", torch.ops.fsdp.copy_, (t, t), {}) check( "call_function", torch.ops.aten.__rshift__.Scalar, (t, 2), {}, expect=False ) diff --git a/torch/_functorch/_aot_autograd/functional_utils.py b/torch/_functorch/_aot_autograd/functional_utils.py index 48e4ba84cf3c02..71862997ae0717 100644 --- a/torch/_functorch/_aot_autograd/functional_utils.py +++ b/torch/_functorch/_aot_autograd/functional_utils.py @@ -405,8 +405,8 @@ def assert_functional_graph(fx_g: torch.fx.Graph) -> int: torch.ops.aten.copy_.default, torch.ops.aten.set_.source_Tensor, ] - if hasattr(torch.ops.fsdp, "set_"): - allowed_mutation_ops.append(torch.ops.fsdp.set_.default) + if hasattr(torch.ops.fsdp, "copy_"): + allowed_mutation_ops.append(torch.ops.fsdp.copy_.default) placeholders = set() mutation_count = 0 diff --git a/torch/_functorch/partitioners.py b/torch/_functorch/partitioners.py index b0f609dc75615c..4263a2f2b10a33 100644 --- a/torch/_functorch/partitioners.py +++ b/torch/_functorch/partitioners.py @@ -1272,6 +1272,7 @@ def get_default_op_list() -> OpTypes: aten.expand, aten.as_strided, aten.permute, + aten.select, ] view_ops = recomputable_view_ops default_recomputable_ops += [ diff --git a/torch/_inductor/comms.py b/torch/_inductor/comms.py index dcad1e1bf67c98..7851e154d7386c 100644 --- a/torch/_inductor/comms.py +++ b/torch/_inductor/comms.py @@ -3,12 +3,14 @@ from __future__ import annotations import heapq +import logging import operator import sys from collections import defaultdict from typing import Dict, List, Set, TYPE_CHECKING import torch +from torch.multiprocessing.reductions import StorageWeakRef from . import config, ir from .dependencies import WeakDep @@ -23,6 +25,7 @@ ) +log = logging.getLogger(__name__) overlap_log = torch._logging.getArtifactLogger(__name__, "overlap") if TYPE_CHECKING: @@ -342,6 +345,202 @@ def reorder_compute_and_comm_for_overlap( return order +def remove_fsdp2_unsharded_param_graph_input_usage(graph: torch.fx.Graph): + """ + This FX graph pass replaces uses of FSDP2 unsharded params with their corresponding + graph intermediates that were fsdp.copy_ into the unsharded params in the original graph. + + NOTE: Can only apply this pass to any of the FSDP2 unsharded params that have this pattern + (or repetition of): `resize_(full) -> copy_ -> resize_(0)`. Because of this, for partial-graph case + where `resize_(full) -> copy_` is in one graph and `resize_(0)` is in another graph, we can't + remove these resize and copy ops and thus we will have worse performance there. + + In other words, "do we try to remove all the resize_(full) -> copy_ -> resize_(0) nodes for this unsharded param" + is actually a per-unsharded-param decision, since for each unsharded param, we look at its resize sequence pattern + (in `check_resize_pattern()`) to determine if its set of resize and copy nodes can be removed. + """ + node_list = list(graph.nodes) + + # Find all graph inputs and their resize counts + graph_input_to_resized_to_full_node_idxes = defaultdict(list) + graph_input_to_resized_to_0_node_idxes = defaultdict(list) + for idx, node in enumerate(node_list): + if ( + node.op == "call_function" + and node.target == torch.ops.inductor.resize_storage_bytes_.default + ): + assert ( + node.args[0].op == "placeholder" + ), f"""\ +Resize can only operate on graph inputs, but got {node} which is resizing non-graph-input {node.args[0]} +""" + graph_input = node.args[0] + new_size = node.args[1] + if new_size > 0: + graph_input_to_resized_to_full_node_idxes[graph_input].append(idx) + else: + graph_input_to_resized_to_0_node_idxes[graph_input].append(idx) + + def check_resize_pattern(graph_input): + # Check the number of resize-to-full and resize-to-0 nodes are equal, + # and that for each (resize-to-full, resize-to-0) pair, the resize-to-full node + # always happens before the resize-to-0 node. + # This is the precondition for being able to remove all the resize and copy nodes + # for this specific unsharded param. + resized_to_full_idxes = graph_input_to_resized_to_full_node_idxes.get( + graph_input, [] + ) + resized_to_0_idxes = graph_input_to_resized_to_0_node_idxes.get(graph_input, []) + + if not len(resized_to_full_idxes) == len(resized_to_0_idxes): + log.warning( + f""" +Unequal number of resize-to-full and resize-to-0 nodes for graph input {graph_input}: +{len(resized_to_full_idxes)} vs. {len(resized_to_0_idxes)}. +Skipping `remove_fsdp2_unsharded_param_graph_input_usage` FX graph pass. +""" # noqa: G004 + ) + return False + + # Check the sequence: (resize_to_full -> resize_to_0)+ + for resize_to_full_idx, resize_to_0_idx in zip( + resized_to_full_idxes, resized_to_0_idxes + ): + if resize_to_full_idx >= resize_to_0_idx: + log.warning( + f""" +For graph input {graph_input}: resize-to-full node {node_list[resize_to_full_idx]} at index {resize_to_full_idx} +happens after resize-to-0 node {node_list[resize_to_0_idx]} at index {resize_to_0_idx}. +Skipping `remove_fsdp2_unsharded_param_graph_input_usage` FX graph pass for that unsharded param. +""" # noqa: G004 + ) + return False + return True + + # Find all eligible unsharded params and their corresponding graph intermediates. + unsharded_param_to_fsdp_copy_node_idxes = defaultdict(list) + for idx, node in enumerate(node_list): + if node.op == "call_function" and node.target == torch.ops.fsdp.copy_.default: + fsdp_copy_node = node + unsharded_param = node.args[0] + assert ( + unsharded_param.op == "placeholder" + ), f""" +Assumed all FSDP2 `unsharded_param`s to be graph input, but it's not true! +Offending node: {unsharded_param}. Graph: {graph} +""" + if check_resize_pattern(unsharded_param): + unsharded_param_to_fsdp_copy_node_idxes[unsharded_param].append(idx) + + def is_allowed_mutation(node): + return ( + node.target == torch.ops.fsdp.copy_.default + or node.target == torch.ops.inductor.resize_storage_bytes_.default + ) + + def is_node_mutating_unsharded_param_or_its_alias(node, unsharded_params): + # Check whether the node is mutating any of the unsharded params or their aliases. + mutated_arg_idxes = ( + [ + i + for i, x in enumerate(node.target._schema.arguments) + if x.alias_info is not None and x.alias_info.is_write + ] + if isinstance(node.target, torch._ops.OpOverload) + else [] + ) + mutated_node_arg_storages = { + StorageWeakRef(node.args[i].meta["val"].untyped_storage()) + for i in mutated_arg_idxes + } + storages_of_unsharded_params = { + StorageWeakRef(unsharded_param.meta["val"].untyped_storage()) + for unsharded_param in unsharded_params + } + return len(mutated_node_arg_storages & storages_of_unsharded_params) > 0 + + # Check no user mutation on any unsharded_param + for node in node_list: + if ( + node.op == "call_function" + and isinstance(node.target, torch._ops.OpOverload) + and node.target._schema.is_mutable + and not is_allowed_mutation(node) + ): + assert not is_node_mutating_unsharded_param_or_its_alias( + node, unsharded_param_to_fsdp_copy_node_idxes.keys() + ), f"""\ +User mutation on FSDP2 unsharded param is not allowed when Traceable FSDP2 is used. Violating node: {node} +""" + + # For each `fsdp.copy_(unsharded_param, Y)`, replace downstream usage of `unsharded_param` with `Y`. + # + # NOTE: Because of "layer reuse" use case, there could be multiple `fsdp.copy_` to the same `unsharded_param` graph input. + # e.g. + # ``` + # fsdp_copy_1 = fsdp.copy_(unsharded_param_1, Y1) + # ... (use of unsharded_param_1) -> Subgraph 1 + # fsdp_copy_2 = fsdp.copy_(unsharded_param_1, Y2) + # ... (use of unsharded_param_1) -> Subgraph 2 + # fsdp_copy_3 = fsdp.copy_(unsharded_param_1, Y3) + # ... (use of unsharded_param_1) -> Subgraph 3 + # ``` + # We must do the replacement only within each subgraph. + for ( + unsharded_param, + fsdp_copy_node_idxes, + ) in unsharded_param_to_fsdp_copy_node_idxes.items(): + for i, fsdp_copy_node_idx in enumerate(fsdp_copy_node_idxes): + fsdp_copy_node = node_list[fsdp_copy_node_idx] + assert fsdp_copy_node.args[0] is unsharded_param + _, replacement = fsdp_copy_node.args + # subgraph_start_idx is exclusive + subgraph_start_idx = fsdp_copy_node_idx + 1 + # subgraph_end_idx is exclusive (also intentionally don't replace args in return op) + subgraph_end_idx = ( + fsdp_copy_node_idxes[i + 1] + if i < len(fsdp_copy_node_idxes) - 1 + else len(node_list) - 1 + ) + subgraph_nodes = node_list[subgraph_start_idx:subgraph_end_idx] + assert not any( + is_node_mutating_unsharded_param_or_its_alias(node, [unsharded_param]) + for node in subgraph_nodes + ), f"""\ +Assumed no ops mutating unsharded param {unsharded_param} in subgraph {subgraph_nodes}, but it's not true! +Graph: {graph} +""" + for node in subgraph_nodes: + if ( + node.op == "call_function" + and unsharded_param in node.args + and node.target != torch.ops.inductor.resize_storage_bytes_.default + ): # TODO(yf225): implement replacement in kwargs + new_args = tuple( + replacement if arg is unsharded_param else arg + for arg in node.args + ) + node.args = new_args + + # Delete `fsdp.copy_(unsharded_param, Y)` nodes + for ( + unsharded_param, + fsdp_copy_node_idxes, + ) in unsharded_param_to_fsdp_copy_node_idxes.items(): + for i, fsdp_copy_node_idx in enumerate(fsdp_copy_node_idxes): + fsdp_copy_node = node_list[fsdp_copy_node_idx] + graph.erase_node(fsdp_copy_node) + + # Delete `resize_(unsharded_param, ...)` nodes + for node in node_list: + if ( + node.op == "call_function" + and node.target == torch.ops.inductor.resize_storage_bytes_.default + and node.args[0] in unsharded_param_to_fsdp_copy_node_idxes + ): + graph.erase_node(node) + + def reinplace_fsdp_all_gather(graph: torch.fx.Graph) -> None: try: import torch.distributed._composable.fsdp._fsdp_collectives @@ -509,12 +708,11 @@ def _create_group_node(snodes_to_group): name_to_fused_node, ) - # Find the "all_gather + all_gather_wait_tensor + copy_out + set_" code block + # Find the "all_gather + all_gather_wait_tensor + copy_out" code block allowed_ops = { torch.ops._c10d_functional.all_gather_into_tensor_out.default, torch.ops._c10d_functional.wait_tensor.default, torch.ops.fsdp.split_with_sizes_copy.default, - torch.ops.aten.set_.source_Tensor, } find_recursive_users_of_node( ag_snode, @@ -560,7 +758,7 @@ def _create_group_node(snodes_to_group): assert wait_node_idx is not None ag_group_node = _create_group_node(ag_related_snodes[:wait_node_idx]) - # Group "all_gather_wait_tensor + copy_out + set_" into one GroupedSchedulerNode + # Group "all_gather_wait_tensor + copy_out" into one GroupedSchedulerNode ag_wait_group_node = _create_group_node(ag_related_snodes[wait_node_idx:]) ag_grouped_node_to_wait_grouped_node[ag_group_node] = ag_wait_group_node diff --git a/torch/_inductor/fx_passes/post_grad.py b/torch/_inductor/fx_passes/post_grad.py index 4e73931fcef23b..640c0a8ffa18bd 100644 --- a/torch/_inductor/fx_passes/post_grad.py +++ b/torch/_inductor/fx_passes/post_grad.py @@ -22,6 +22,7 @@ from .. import config, ir, pattern_matcher from ..codegen.common import BackendFeature, has_backend_feature +from ..comms import remove_fsdp2_unsharded_param_graph_input_usage from ..fx_utils import FakeTensorUpdater, get_fake_args_kwargs, get_node_storage from ..lowering import lowerings as L from ..pattern_matcher import ( @@ -76,6 +77,9 @@ def post_grad_passes(gm: torch.fx.GraphModule, is_inference: bool): The IR here has been normalized and functionalized. """ + if not torch._dynamo.config.skip_fsdp_hooks: + remove_fsdp2_unsharded_param_graph_input_usage(gm.graph) + if config.dce: # has some issues with mutation in inference mode gm.graph.eliminate_dead_code() diff --git a/torch/_inductor/lowering.py b/torch/_inductor/lowering.py index 988453a86735de..2bdeac572ddd9c 100644 --- a/torch/_inductor/lowering.py +++ b/torch/_inductor/lowering.py @@ -6102,13 +6102,17 @@ def set__source_tensor(self, source_tensor): return TensorBox.create(ir.SetSourceTensorKernel(self, source_tensor)) -if hasattr(torch.ops.fsdp, "set_"): +if hasattr(torch.ops.fsdp, "copy_"): - @register_lowering(torch.ops.fsdp.set_.default) - def fsdp_set_(self, source_tensor): - self.realize() - source_tensor.realize() - ir.SetSourceTensorKernel(self, source_tensor) + @register_lowering(torch.ops.fsdp.copy_.default) + def fsdp_copy_(dst, src): + if dst is src: + # dst.copy_(dst) can happen from the reinplacing pass + return dst + src = to_device(src, dst.get_device()) + src = to_dtype(src, dst.get_dtype()) + src = expand(src, dst.get_size()) + return mutate_to(dst, src) @register_lowering(torch.ops.aten.resize) diff --git a/torch/distributed/_composable/fsdp/_fsdp_param.py b/torch/distributed/_composable/fsdp/_fsdp_param.py index f6fc631ec8ac46..59c2ff7589d3a3 100644 --- a/torch/distributed/_composable/fsdp/_fsdp_param.py +++ b/torch/distributed/_composable/fsdp/_fsdp_param.py @@ -65,73 +65,77 @@ lib = torch.library.Library("fsdp", "FRAGMENT") # noqa: TOR901 -lib.define("set_(Tensor(a!) tensor, Tensor data) -> ()") +lib.define("copy_(Tensor(a!) tensor, Tensor data) -> ()") -@torch.library.impl(lib, "set_", "Meta") -@torch.library.impl(lib, "set_", "CUDA") -@torch.library.impl(lib, "set_", "CPU") -def set_(tensor, data): - tensor.set_(data) +@torch.library.impl(lib, "copy_", "Meta") +@torch.library.impl(lib, "copy_", "CUDA") +@torch.library.impl(lib, "copy_", "CPU") +def copy_(tensor, data): + tensor.copy_(data) """ -[Note: Avoiding functionalization for fsdp.set_ and inductor.resize_storage_bytes_(0)] +[Note: Avoiding functionalization for fsdp.copy_ and inductor.resize_storage_bytes_] -Currently we don't functionalize `fsdp.set_` op or `inductor.resize_storage_bytes_(0)` op +Currently we don't functionalize `fsdp.copy_` op or `inductor.resize_storage_bytes_` op (i.e. they show up as a mutation op in the middle of the AOT joint graph). Reason: Traceable FSDP2 compiled autograd BWD graph have the following traits: (1) Two inputs of the graph were aliased to each other (one from hook closed-over tensors, one from FWD saved tensors). -(2) One of them is mutated (set_ and resize_(0) to handle the all-gathered param). +(2) One of them is mutated (copy_ and resize_ to handle the all-gathered param). (3) They are both subclasses. The combination of these traits is not supported by AOTAutograd (it's difficult to reason about subclass aliasing). So this doesn't work at all for Traceable FSDP2. -The compromise we use is to avoid functionalization for the FSDP2 set_ and resize_(0) ops. +The compromise we use is to avoid functionalization for the FSDP2 copy_ and resize_ ops. This avoids the problem above, because from AOTAutograd point-of-view there are no mutations that functionalization needs to handle. (Although we need to be careful not to DCE those mutable ops.) We can avoid this functionalization because: -(1) The nn.Parameter is never used before its .set_() is called in eager code (i.e. no alias of it is created), -so it's safe to call .set_() in the middle of the graph to swap out its storage and start using the nn.Parameter downstream. +(1) The nn.Parameter is never used before its .copy_() is called in eager code (i.e. no alias of it is created), +so it's safe to call .copy_() in the middle of the graph to update its content and start using the nn.Parameter downstream. (2) We always re-allocate the buffer for nn.Parameter to store the AllGather output and to be used in downstream user ops. So calling resize-to-0 in the middle of the graph to free nn.Parameter memory after use should always be okay (since we always allocate anew next time we need it, we strictly don't need to keep the old tensor storage around anymore). -Q: But doesn't the torch.compile stack have the "functional graph" assumption in many places? -A: Yes - this is WIP but we will try to get back to functional graph as early as possible in the lowering process. -Specifically, we believe we can move both .set_ and .resize_(0) ops to end of graph in AOT joint graph before partitioner -(i.e. effectively "re-functionalizing" those ops). Put it in another way, we avoid functionalization for those two ops just to -make AOTAutograd alias analysis happy, and as soon as we are past that point, we "re-functionalize" the graph. -This requires a custom FX pass but we believe it's not hard to write and maintain. - -Q: What's the importance of partitioner not saving views of nn.Parameter as FWD saved tensors? -A: This is critical: we do want to save FWD nn.Parameter graph input (instead of its view) for BWD use, -so that downstream ops in BWD graph uses the post-`.set_` nn.Parameter instead of any of its saved views as input. -This is because .set_ will not update any of the nn.Parameter's views, so BWD downstream ops must use the original -nn.Parameter in order to see the result of .set_. +Q: Wouldn't the extra resize_ and copy_ ops hurt both memory usage and performance? +A: Yes it would. As an optimization, we have an Inductor post-grad FX pass to remove those resize_ and copy_ ops +for unsharded params that have this pattern: resize_(full) -> copy_ -> resize_(0). + +TODO: +Now that we are maintaining the invariant of "no aliased + mutated graph inputs" in both the forward and backward, +it is now more feasible to functionalize all of the mutable FSDP ops. Some of the pros and cons are: + +Cons (of functionalizing those ops): +(1) By not functionalizing them as we are today, we are making it more likely that they will run at the "correct" time +in the generated code. If we start to functionalize them, we will need to make sure that Inductor reinplaces them +in a way where it properly moves the mutations back to exactly where they should have run, or we risk suffering worse +peak memory than eager. (We probably already need to do something similar in Inductor's reinplacing for copy_: +https://github.com/pytorch/pytorch/issues/135305#issuecomment-2334888089) + +Pros (of functionalizing): +(1) Better safety, we don't need to worry about the graph passes in inductor/partitioning handling input mutations +mid-graph quite as much (to be fair we've already done some amount of auditing, but we might have to do some more). +(2) Better perf: each mutation midway through the graph prevents Inductor from pattern matching across it. +But maybe there are few enough mutations induced by FSDP for this to matter. """ -@torch.library.impl(lib, "set_", "Functionalize") -def set__functionalize(tensor, data): +@torch.library.impl(lib, "copy_", "Functionalize") +def copy__functionalize(tensor, data): torch._sync(tensor) torch._sync(data) - # AOTDispatcher needs to know if any inputs had their storages mutated. - # (Why? It sometimes detaches inputs before sending them into the graph, - # when it sees that they do not need to have any gradients computed) - torch._functionalize_set_storage_changed(tensor) tensor_inner = torch._from_functional_tensor(tensor) data_inner = torch._from_functional_tensor(data) with torch._C._ExcludeDispatchKeyGuard( torch._C.DispatchKeySet(torch._C.DispatchKey.Functionalize) ): - torch.ops.fsdp.set_.default(tensor_inner, data_inner) + torch.ops.fsdp.copy_.default(tensor_inner, data_inner) -torch.fx.node.has_side_effect(torch.ops.fsdp.set_.default) +torch.fx.node.has_side_effect(torch.ops.fsdp.copy_.default) class ShardedState(Enum): @@ -475,7 +479,12 @@ def init_unsharded_param(self): with torch.no_grad(), torch.autograd._unsafe_preserve_version_counter( self._unsharded_param ): - torch.ops.fsdp.set_.default(self._unsharded_param, unsharded_param) + # NOTE: Under compile, if an unsharded param goes through + # resize_(full) -> copy_ -> resize_(0) pattern, we will remove those + # resize_ and copy_ ops in a compiler graph pass + # `remove_fsdp2_unsharded_param_graph_input_usage` to recover performance. + alloc_storage(self._unsharded_param) + torch.ops.fsdp.copy_(self._unsharded_param, unsharded_param) else: self._unsharded_param = nn.Parameter( unsharded_param, requires_grad=self.sharded_param.requires_grad @@ -605,13 +614,26 @@ def alloc_all_gather_outputs(self) -> None: alloc_storage(tensor) def free_unsharded_param(self) -> None: - for tensor in itertools.chain( - self.all_gather_outputs, self._unsharded_inner_tensors - ): - free_storage(tensor) if ca.compiled_autograd_enabled: + """ + Assumptions under compile: + - `self._unsharded_param` is NOT an alias of `self.all_gather_outputs`. + Instead, we resize `self._unsharded_param` storage size to full and then + explicitly *copy* the data from `self.all_gather_outputs` to `self._unsharded_param` + in `init_unsharded_param()`. (For full-graph FSDP2 case, we will then remove + the resize_ and copy_ ops in a compiler graph pass to recover performance.) + - `self.all_gather_outputs` and `self._unsharded_inner_tensors` are NOT + graph inputs. They are created within the graph and is guaranteed to be freed + by the end of the graph. They don't leak outside of the graph. + """ + self._unsharded_param.untyped_storage().resize_(0) self.all_gather_outputs = [] self._unsharded_inner_tensors = [] + else: + for tensor in itertools.chain( + self.all_gather_outputs, self._unsharded_inner_tensors + ): + free_storage(tensor) @property def all_gather_inputs(self) -> List[torch.Tensor]: # 1D From 3277ba420161c026253d484ab62c0729fd2de35c Mon Sep 17 00:00:00 2001 From: "xinan.lin" Date: Tue, 10 Sep 2024 03:13:14 -0700 Subject: [PATCH 0743/1018] [Inductor] Generalize cuda cpp wrapper as common triton based GPU cpp wrapper, will be reused by xpu in next PR. (#135312) [Inductor] Generalize cuda cpp wrapper as common triton based GPU cpp wrapper, will be reused by xpu in next PR. Pull Request resolved: https://github.com/pytorch/pytorch/pull/135312 Approved by: https://github.com/jansel, https://github.com/desertfire, https://github.com/eellison --- test/inductor/test_cpp_wrapper_hipify.py | 5 +- .../codegen/codegen_device_driver.py | 91 ---------------- torch/_inductor/codegen/common.py | 25 ++++- torch/_inductor/codegen/cpp.py | 4 +- .../_inductor/codegen/cpp_template_kernel.py | 2 +- torch/_inductor/codegen/cpp_wrapper_cpu.py | 10 +- torch/_inductor/codegen/cpp_wrapper_cuda.py | 85 ++++++++------- torch/_inductor/codegen/cuda/cuda_kernel.py | 2 +- .../codegen/cuda/device_op_overrides.py | 102 ++++++++++++++++++ torch/_inductor/codegen/halide.py | 2 +- torch/_inductor/codegen/rocm/rocm_kernel.py | 2 +- torch/_inductor/codegen/triton.py | 2 +- .../_inductor/codegen/triton_combo_kernel.py | 2 +- torch/_inductor/codegen/wrapper.py | 12 +-- 14 files changed, 195 insertions(+), 151 deletions(-) delete mode 100644 torch/_inductor/codegen/codegen_device_driver.py diff --git a/test/inductor/test_cpp_wrapper_hipify.py b/test/inductor/test_cpp_wrapper_hipify.py index 43b1909e927df8..62f23ad3abc768 100644 --- a/test/inductor/test_cpp_wrapper_hipify.py +++ b/test/inductor/test_cpp_wrapper_hipify.py @@ -1,7 +1,7 @@ # Owner(s): ["module: inductor"] import torch from torch._inductor.codegen.aoti_hipify_utils import maybe_hipify_code_wrapper -from torch._inductor.codegen.codegen_device_driver import cuda_kernel_driver +from torch._inductor.codegen.common import get_device_op_overrides from torch._inductor.test_case import run_tests, TestCase @@ -34,7 +34,8 @@ def test_hipify_basic_declaration(self) -> None: self.assertEqual(result, expected) def test_hipify_aoti_driver_header(self) -> None: - header = cuda_kernel_driver() + cuda_codegen = get_device_op_overrides("cuda") + header = cuda_codegen.kernel_driver() expected = """ #define CUDA_DRIVER_CHECK(EXPR) \\ do { \\ diff --git a/torch/_inductor/codegen/codegen_device_driver.py b/torch/_inductor/codegen/codegen_device_driver.py deleted file mode 100644 index c31017fe6471cb..00000000000000 --- a/torch/_inductor/codegen/codegen_device_driver.py +++ /dev/null @@ -1,91 +0,0 @@ -import torch - - -# Provide aoti module launch hip/cuda drivers. This file is also used for unit testing purpose - - -def cuda_kernel_driver() -> str: - source_codes = """ - #define CUDA_DRIVER_CHECK(EXPR) \\ - do { \\ - CUresult code = EXPR; \\ - const char *msg; \\ - cuGetErrorString(code, &msg); \\ - if (code != CUDA_SUCCESS) { \\ - throw std::runtime_error( \\ - std::string("CUDA driver error: ") + \\ - std::string(msg)); \\ - } \\ - } while (0); - - namespace { - - struct Grid { - Grid(uint32_t x, uint32_t y, uint32_t z) - : grid_x(x), grid_y(y), grid_z(z) {} - uint32_t grid_x; - uint32_t grid_y; - uint32_t grid_z; - - bool is_non_zero() { - return grid_x > 0 && grid_y > 0 && grid_z > 0; - } - }; - - } // anonymous namespace - - static inline CUfunction loadKernel( - std::string filePath, - const std::string &funcName, - uint32_t sharedMemBytes, - const std::optional &cubinDir = std::nullopt) { - if (cubinDir) { - std::filesystem::path p1{*cubinDir}; - std::filesystem::path p2{filePath}; - filePath = (p1 / p2.filename()).string(); - } - - CUmodule mod; - CUfunction func; - CUDA_DRIVER_CHECK(cuModuleLoad(&mod, filePath.c_str())); - CUDA_DRIVER_CHECK(cuModuleGetFunction(&func, mod, funcName.c_str())); - if (sharedMemBytes > 0) { - CUDA_DRIVER_CHECK(cuFuncSetAttribute( - func, - CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, - sharedMemBytes - )) - } - return func; - } - - static inline void launchKernel( - CUfunction func, - uint32_t gridX, - uint32_t gridY, - uint32_t gridZ, - uint32_t numWarps, - uint32_t sharedMemBytes, - void* args[], - cudaStream_t stream) { - CUDA_DRIVER_CHECK(cuLaunchKernel( - func, gridX, gridY, gridZ, 32*numWarps, 1, 1, sharedMemBytes, stream, args, nullptr - )); - } - """ - if torch.version.hip is not None: - # Adjusting the warp size to GPU supported wavefront size on AMD GPU - prop = torch.cuda.get_device_properties(torch.cuda.current_device()) - source_codes = source_codes.replace( - "32*numWarps", str(prop.warp_size) + "*numWarps" - ) - return source_codes - - -def cuda_kernel_header() -> str: - source_codes = """ - #include - #include - #include - """ - return source_codes diff --git a/torch/_inductor/codegen/common.py b/torch/_inductor/codegen/common.py index 82fb318f2ca814..efbef7f461c8c2 100644 --- a/torch/_inductor/codegen/common.py +++ b/torch/_inductor/codegen/common.py @@ -124,6 +124,27 @@ def cpp_aoti_stream_guard(self): def cpp_getStreamFromExternal(self): raise NotImplementedError + def kernel_header(self): + raise NotImplementedError + + def kernel_driver(self): + raise NotImplementedError + + def abi_compatible_header(self): + raise NotImplementedError + + def cpp_stream_type(self): + raise NotImplementedError + + def aoti_get_stream(self): + raise NotImplementedError + + def cpp_kernel_type(self): + raise NotImplementedError + + def cpp_device_ptr(self): + raise NotImplementedError + device_op_overrides_dict: Dict[str, DeviceOpOverrides] = {} @@ -211,7 +232,7 @@ def get_wrapper_codegen_for_device(device: str, cpp_wrapper: bool = False): def init_backend_registration(): from .cpp import CppScheduling from .cpp_wrapper_cpu import CppWrapperCpu - from .cpp_wrapper_cuda import CppWrapperCuda + from .cpp_wrapper_cuda import CppWrapperGpu from .cuda_combined_scheduling import CUDACombinedScheduling from .halide import HalideScheduling from .triton import TritonScheduling @@ -233,7 +254,7 @@ def init_backend_registration(): "cuda", lambda *args, **kwargs: cuda_backends[config.cuda_backend](*args, **kwargs), WrapperCodeGen, - CppWrapperCuda, + CppWrapperGpu, ) if get_scheduling_for_device("xpu") is None: diff --git a/torch/_inductor/codegen/cpp.py b/torch/_inductor/codegen/cpp.py index e1d54762d9e8fd..cef37d18ac20db 100644 --- a/torch/_inductor/codegen/cpp.py +++ b/torch/_inductor/codegen/cpp.py @@ -4564,7 +4564,7 @@ def define_kernel(self, src_code, nodes, kernel_args=None): compile_wrapper.splice(src_code, strip=True) if not V.graph.cpp_wrapper: compile_wrapper.writeline("''')") - wrapper.define_kernel(kernel_name, compile_wrapper.getvalue(), cuda=False) + wrapper.define_kernel(kernel_name, compile_wrapper.getvalue(), gpu=False) return kernel_name def flush(self): @@ -4645,7 +4645,7 @@ def codegen_group(self, name=None) -> str: def call_kernel(self, wrapper, kernel_name): _, call_args, arg_types = self.args.cpp_argdefs() wrapper.generate_kernel_call( - kernel_name, call_args, cuda=False, arg_types=arg_types + kernel_name, call_args, gpu=False, arg_types=arg_types ) diff --git a/torch/_inductor/codegen/cpp_template_kernel.py b/torch/_inductor/codegen/cpp_template_kernel.py index 0333720bbdc689..e72a895dc44d59 100644 --- a/torch/_inductor/codegen/cpp_template_kernel.py +++ b/torch/_inductor/codegen/cpp_template_kernel.py @@ -107,7 +107,7 @@ def hook(): def call_kernel(self, name: str, node: ir.CppTemplateBuffer): wrapper = V.graph.wrapper_code _, call_args, arg_types = self.args.cpp_argdefs() - wrapper.generate_kernel_call(name, call_args, cuda=False, arg_types=arg_types) + wrapper.generate_kernel_call(name, call_args, gpu=False, arg_types=arg_types) def dtype(self, node: ir.Buffer) -> str: return DTYPE_TO_CPP[node.get_dtype()] diff --git a/torch/_inductor/codegen/cpp_wrapper_cpu.py b/torch/_inductor/codegen/cpp_wrapper_cpu.py index fe178978d53be4..3824d71d894823 100644 --- a/torch/_inductor/codegen/cpp_wrapper_cpu.py +++ b/torch/_inductor/codegen/cpp_wrapper_cpu.py @@ -75,7 +75,7 @@ def generate_kernel_call( call_args, grid=None, device_index=None, - cuda=True, + gpu=True, triton=True, arg_types=None, raw_args=None, @@ -87,19 +87,19 @@ def generate_kernel_call( """ Generates kernel call code. - cuda: Defines whether the backend is GPU. Otherwise the backend is CPU. + gpu: Defines whether the backend is GPU. Otherwise the backend is CPU. triton: Defines whether the GPU backend uses Triton for codegen. Otherwise it uses the CUDA language for codegen. Only valid when cuda == True. """ - if cuda: + if gpu: return super().generate_kernel_call( kernel_name, call_args, grid, device_index, - cuda, + gpu, triton, arg_types, raw_args, @@ -920,7 +920,7 @@ def finalize_prefix(self): self.prefix = cached_dtypes_buffer def define_kernel( - self, name: str, kernel: str, metadata: Optional[str] = None, cuda=False + self, name: str, kernel: str, metadata: Optional[str] = None, gpu=False ): self.header.splice(f"\n{kernel}\n") diff --git a/torch/_inductor/codegen/cpp_wrapper_cuda.py b/torch/_inductor/codegen/cpp_wrapper_cuda.py index 89c28bab380949..5719d3eba589f8 100644 --- a/torch/_inductor/codegen/cpp_wrapper_cuda.py +++ b/torch/_inductor/codegen/cpp_wrapper_cuda.py @@ -12,10 +12,10 @@ from .. import config from ..codecache import CudaKernelParamCache -from ..utils import DeferredLineBase +from ..utils import DeferredLineBase, get_gpu_type from ..virtualized import V from .aoti_hipify_utils import maybe_hipify_code_wrapper -from .codegen_device_driver import cuda_kernel_driver, cuda_kernel_header +from .common import get_device_op_overrides from .cpp_utils import cexpr, DTYPE_TO_CPP from .cpp_wrapper_cpu import CppWrapperCpu from .wrapper import SymbolicCallArg @@ -25,9 +25,9 @@ from ..graph import GraphLowering -class DeferredCudaKernelLine(DeferredLineBase): +class DeferredGpuKernelLine(DeferredLineBase): """ - When using cpp wrapper, CUDA kernel load and launch needs to wait for Triton kernels + When using cpp wrapper, GPU kernel load and launch needs to wait for Triton kernels to be tuned and stored as cubin files, so use a deferred line to backfill those information """ @@ -58,10 +58,10 @@ def __call__(self): return self.line_template % tuple(params[key] for key in self.keys) def _new_line(self, line): - return DeferredCudaKernelLine(self.kernel_name, line, self.keys) + return DeferredGpuKernelLine(self.kernel_name, line, self.keys) -class DeferredCudaDefaultGrid: +class DeferredGpuDefaultGrid: """ A container for the default grid, which may be used by DeferredCudaGridLine """ @@ -106,9 +106,9 @@ def __call__(self): return grid_fn(block_cfg) -class DeferredCudaGridLine(DeferredLineBase): +class DeferredGpuGridLine(DeferredLineBase): """ - When using cpp wrapper, CUDA kernel load and launch needs to wait for Triton kernels + When using cpp wrapper, GPU kernel load and launch needs to wait for Triton kernels to be tuned and stored as cubin files, so use a deferred line to backfill those information """ @@ -142,7 +142,7 @@ def __call__(self): grid = self.grid[i] break assert grid is not None - elif isinstance(self.grid, DeferredCudaDefaultGrid): + elif isinstance(self.grid, DeferredGpuDefaultGrid): grid = self.grid() else: grid = self.grid @@ -154,18 +154,19 @@ def __call__(self): return f" Grid {self.grid_var} = Grid({grid_args_str});" def _new_line(self, line): - return DeferredCudaGridLine( + return DeferredGpuGridLine( self.kernel_name, self.grid_var, self.grid, self.autotune_configs ) -class CppWrapperCuda(CppWrapperCpu): +class CppWrapperGpu(CppWrapperCpu): """ Generates cpp wrapper for running on GPU and calls CUDA kernels """ def __init__(self) -> None: - self.device = "cuda" + self.device = get_gpu_type() + self.device_codegen = get_device_op_overrides(self.device) super().__init__() self.grid_id = count() @@ -178,26 +179,32 @@ def write_header(self): self.header.splice("#include ") if config.abi_compatible: + self.header.splice(self.device_codegen.abi_compatible_header()) + else: self.header.splice( - "#include " + maybe_hipify_code_wrapper(self.device_codegen.kernel_header()) ) - else: - self.header.splice(maybe_hipify_code_wrapper(cuda_kernel_header())) - self.header.splice(maybe_hipify_code_wrapper(cuda_kernel_driver())) + self.header.splice( + maybe_hipify_code_wrapper(self.device_codegen.kernel_driver()) + ) def write_get_raw_stream(self, index, graph=None): name = f"stream{index}" - self.writeline(maybe_hipify_code_wrapper(f"cudaStream_t {name};")) self.writeline( - f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_get_current_cuda_stream({index}, (void**)&{name}));" + maybe_hipify_code_wrapper( + f"{self.device_codegen.cpp_stream_type()} {name};" + ) + ) + self.writeline( + f"AOTI_TORCH_ERROR_CODE_CHECK({self.device_codegen.aoti_get_stream()}({index}, (void**)&{name}));" ) return name def define_kernel( - self, name: str, kernel: str, metadata: Optional[str] = None, cuda=True + self, name: str, kernel: str, metadata: Optional[str] = None, gpu=True ): - if not cuda: - return super().define_kernel(name, kernel, metadata, cuda) + if not gpu: + return super().define_kernel(name, kernel, metadata, gpu) def generate(self, is_inference): self.prefix.writeline("\n") @@ -207,7 +214,9 @@ def generate(self, is_inference): sorted([entry[0] for entry in self.user_defined_kernel_cache.values()]), ): self.prefix.writeline( - maybe_hipify_code_wrapper(f"static CUfunction {kernel} = nullptr;") + maybe_hipify_code_wrapper( + f"static {self.device_codegen.cpp_kernel_type()} {kernel} = nullptr;" + ) ) self.prefix.writeline("\n") return super().generate(is_inference) @@ -238,7 +247,7 @@ def generate_user_defined_triton_kernel( arg_types=arg_types, raw_args=raw_args, grid=grid, - cuda=True, + gpu=True, triton=True, triton_meta=triton_meta, autotune_configs=configs, @@ -254,7 +263,7 @@ def generate_load_kernel_once( kernel_var_name = f"kernels.{kernel_name}" if V.graph.aot_mode else kernel_name self.writeline(f"if ({kernel_var_name} == nullptr) {{") self.writeline( - DeferredCudaKernelLine( + DeferredGpuKernelLine( kernel_name, """ """ + kernel_var_name @@ -300,7 +309,9 @@ def generate_args_decl(self, call_args, arg_types): else: if config.abi_compatible: self.writeline( - maybe_hipify_code_wrapper(f"CUdeviceptr {var_name};") + maybe_hipify_code_wrapper( + f"{self.device_codegen.cpp_device_ptr()} {var_name};" + ) ) self.writeline( f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_get_data_ptr({arg}, reinterpret_cast(&{var_name})));" @@ -308,7 +319,8 @@ def generate_args_decl(self, call_args, arg_types): else: self.writeline( maybe_hipify_code_wrapper( - f"CUdeviceptr {var_name} = reinterpret_cast({arg}.data_ptr());" + f"{self.device_codegen.cpp_device_ptr()} {var_name} = \ + reinterpret_cast<{self.device_codegen.cpp_device_ptr()}>({arg}.data_ptr());" ) ) elif arg_type in (sympy.Integer, int): @@ -325,7 +337,7 @@ def generate_default_grid( self, kernel_name: str, grid: List[Any], - cuda: bool = True, + gpu: bool = True, grid_callable: Optional[Callable[..., Any]] = None, **grid_extra_kwargs, ): @@ -333,11 +345,11 @@ def generate_default_grid( Generate grid configs for launching a CUDA kernel using the grid function from triton_heuristics. Because its computation needs to read kernel config after autotune, it is done in a deferred way - using DeferredCudaDefaultGrid. + using DeferredGpuDefaultGrid. """ - if not cuda: + if not gpu: return grid - return DeferredCudaDefaultGrid( + return DeferredGpuDefaultGrid( kernel_name, grid, grid_callable, **grid_extra_kwargs ) @@ -347,7 +359,7 @@ def generate_kernel_call( call_args, grid=None, device_index=None, - cuda=True, + gpu=True, triton=True, arg_types=None, raw_args=None, @@ -360,14 +372,14 @@ def generate_kernel_call( arg_types ), "call_args and arg_types do not match" - if not cuda: - # Even in CppWrapperCuda, we may see cpp kernels + if not gpu: + # Even in CppWrapperGpu, we may see cpp kernels return super().generate_kernel_call( kernel_name, call_args, grid, device_index, - cuda, + gpu, triton, arg_types, raw_args, @@ -402,10 +414,9 @@ def generate_kernel_call( if V.graph.aot_mode else self.write_get_raw_stream(device_index, V.graph) ) - grid_var = f"{kernel_name}_grid_{next(self.grid_id)}" self.writeline( - DeferredCudaGridLine(kernel_name, grid_var, grid, autotune_configs) + DeferredGpuGridLine(kernel_name, grid_var, grid, autotune_configs) ) kernel_var_name = f"kernels.{kernel_name}" if V.graph.aot_mode else kernel_name @@ -415,7 +426,7 @@ def generate_kernel_call( with debug_printer_manager: self.writeline(f"if ({grid_var}.is_non_zero()) {{") self.writeline( - DeferredCudaKernelLine( + DeferredGpuKernelLine( kernel_name, r" launchKernel({}, {}, {}, {}, %s, %s, {}, {});".format( kernel_var_name, diff --git a/torch/_inductor/codegen/cuda/cuda_kernel.py b/torch/_inductor/codegen/cuda/cuda_kernel.py index d6472a48f1e08c..dfb0b159e2f7b8 100644 --- a/torch/_inductor/codegen/cuda/cuda_kernel.py +++ b/torch/_inductor/codegen/cuda/cuda_kernel.py @@ -180,7 +180,7 @@ def call_kernel( wrapper.generate_kernel_call( name, call_args, - cuda=True, + gpu=True, triton=False, arg_types=arg_types, ) diff --git a/torch/_inductor/codegen/cuda/device_op_overrides.py b/torch/_inductor/codegen/cuda/device_op_overrides.py index 03e760ddc3da86..011e503a7b889d 100644 --- a/torch/_inductor/codegen/cuda/device_op_overrides.py +++ b/torch/_inductor/codegen/cuda/device_op_overrides.py @@ -1,4 +1,6 @@ # mypy: allow-untyped-defs +import torch + from ..common import DeviceOpOverrides, register_device_op_overrides @@ -30,5 +32,105 @@ def cpp_aoti_stream_guard(self): def cpp_getStreamFromExternal(self): return "at::cuda::getStreamFromExternal" + def kernel_header(self): + source_codes = """ + #include + #include + #include + """ + return source_codes + + def kernel_driver(self): + source_codes = """ + #define CUDA_DRIVER_CHECK(EXPR) \\ + do { \\ + CUresult code = EXPR; \\ + const char *msg; \\ + cuGetErrorString(code, &msg); \\ + if (code != CUDA_SUCCESS) { \\ + throw std::runtime_error( \\ + std::string("CUDA driver error: ") + \\ + std::string(msg)); \\ + } \\ + } while (0); + + namespace { + + struct Grid { + Grid(uint32_t x, uint32_t y, uint32_t z) + : grid_x(x), grid_y(y), grid_z(z) {} + uint32_t grid_x; + uint32_t grid_y; + uint32_t grid_z; + + bool is_non_zero() { + return grid_x > 0 && grid_y > 0 && grid_z > 0; + } + }; + + } // anonymous namespace + + static inline CUfunction loadKernel( + std::string filePath, + const std::string &funcName, + uint32_t sharedMemBytes, + const std::optional &cubinDir = std::nullopt) { + if (cubinDir) { + std::filesystem::path p1{*cubinDir}; + std::filesystem::path p2{filePath}; + filePath = (p1 / p2.filename()).string(); + } + + CUmodule mod; + CUfunction func; + CUDA_DRIVER_CHECK(cuModuleLoad(&mod, filePath.c_str())); + CUDA_DRIVER_CHECK(cuModuleGetFunction(&func, mod, funcName.c_str())); + if (sharedMemBytes > 0) { + CUDA_DRIVER_CHECK(cuFuncSetAttribute( + func, + CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, + sharedMemBytes + )) + } + return func; + } + + static inline void launchKernel( + CUfunction func, + uint32_t gridX, + uint32_t gridY, + uint32_t gridZ, + uint32_t numWarps, + uint32_t sharedMemBytes, + void* args[], + cudaStream_t stream) { + CUDA_DRIVER_CHECK(cuLaunchKernel( + func, gridX, gridY, gridZ, 32*numWarps, 1, 1, sharedMemBytes, stream, args, nullptr + )); + } + """ + if torch.version.hip is not None: + # Adjusting the warp size to GPU supported wavefront size on AMD GPU + prop = torch.cuda.get_device_properties(torch.cuda.current_device()) + source_codes = source_codes.replace( + "32*numWarps", str(prop.warp_size) + "*numWarps" + ) + return source_codes + + def abi_compatible_header(self): + return "#include " + + def cpp_stream_type(self): + return "cudaStream_t" + + def aoti_get_stream(self): + return "aoti_torch_get_current_cuda_stream" + + def cpp_kernel_type(self): + return "CUfunction" + + def cpp_device_ptr(self): + return "CUdeviceptr" + register_device_op_overrides("cuda", CUDADeviceOpOverrides()) diff --git a/torch/_inductor/codegen/halide.py b/torch/_inductor/codegen/halide.py index 20968a57a4444f..337aa544b0d10a 100644 --- a/torch/_inductor/codegen/halide.py +++ b/torch/_inductor/codegen/halide.py @@ -1638,7 +1638,7 @@ def call_kernel(self, name: str, node=None): wrapper.generate_kernel_call( name, call_args, - cuda=False, # grid/stream is handled internally in halide + gpu=False, # grid/stream is handled internally in halide ) def generate_assert(self, check): diff --git a/torch/_inductor/codegen/rocm/rocm_kernel.py b/torch/_inductor/codegen/rocm/rocm_kernel.py index 9029dbe644a5ba..ace9910685ca2b 100644 --- a/torch/_inductor/codegen/rocm/rocm_kernel.py +++ b/torch/_inductor/codegen/rocm/rocm_kernel.py @@ -157,7 +157,7 @@ def call_kernel( name, kernel_args, device_index=current_device.index, - cuda=True, + gpu=True, triton=False, arg_types=arg_types, ) diff --git a/torch/_inductor/codegen/triton.py b/torch/_inductor/codegen/triton.py index 8fe368048c9c77..629b6286b2242c 100644 --- a/torch/_inductor/codegen/triton.py +++ b/torch/_inductor/codegen/triton.py @@ -2834,7 +2834,7 @@ def call_kernel(self, name: str, node: Optional[IRNode] = None): call_args, grid, current_device.index, - cuda=True, + gpu=True, triton=True, arg_types=arg_types, grid_fn=self._get_grid_fn(), diff --git a/torch/_inductor/codegen/triton_combo_kernel.py b/torch/_inductor/codegen/triton_combo_kernel.py index abfd1dc76c698e..946d318fac11c5 100644 --- a/torch/_inductor/codegen/triton_combo_kernel.py +++ b/torch/_inductor/codegen/triton_combo_kernel.py @@ -1087,7 +1087,7 @@ def call_kernel(self, code: IndentedBuffer, name: str) -> None: call_args, grid, V.graph.scheduler.get_current_device_or_throw().index, - cuda=True, + gpu=True, triton=True, arg_types=arg_types, grid_fn="grid_combo_kernels", diff --git a/torch/_inductor/codegen/wrapper.py b/torch/_inductor/codegen/wrapper.py index 0586076c948fb6..e04d3dfa1430e2 100644 --- a/torch/_inductor/codegen/wrapper.py +++ b/torch/_inductor/codegen/wrapper.py @@ -1252,7 +1252,7 @@ def add_benchmark_harness(self, output): ) def define_kernel( - self, name: str, kernel: str, metadata: Optional[str] = None, cuda=True + self, name: str, kernel: str, metadata: Optional[str] = None, gpu=True ): metadata_comment = f"{metadata}\n" if metadata else "" body = f"\n\n{metadata_comment}{name} = {kernel}" @@ -1554,7 +1554,7 @@ def generate_default_grid( self, kernel_name: str, grid: List[Any], - cuda: bool = True, + gpu: bool = True, grid_callable: Optional[Callable[..., Any]] = None, **grid_extra_kwags, ): @@ -1646,7 +1646,7 @@ def generate_kernel_call( call_args, grid=None, device_index=None, - cuda=True, + gpu=True, triton=True, arg_types=None, raw_args=None, @@ -1658,13 +1658,13 @@ def generate_kernel_call( """ Generates kernel call code. - cuda: Defines whether the backend is GPU. Otherwise the backend is CPU. + gpu: Defines whether the backend is GPU. Otherwise the backend is CPU. triton: Defines whether the GPU backend uses Triton for codegen. Otherwise it uses the CUDA language for codegen. - Only valid when cuda == True. + Only valid when gpu == True. """ - if cuda: + if gpu: device_index, call_args_str = self.prepare_triton_kernel_call( device_index, call_args ) From a684766b21bcfc30913c78ebba9ac4ffeb6fac1d Mon Sep 17 00:00:00 2001 From: "xinan.lin" Date: Tue, 10 Sep 2024 03:13:14 -0700 Subject: [PATCH 0744/1018] [Inductor] Rename `cpp_wrapper_cuda.py` as `cpp_wrapper_gpu.py` (#135313) Pull Request resolved: https://github.com/pytorch/pytorch/pull/135313 Approved by: https://github.com/jansel, https://github.com/desertfire ghstack dependencies: #135312 --- tools/amd_build/build_amd.py | 2 +- torch/_inductor/codegen/common.py | 2 +- .../codegen/{cpp_wrapper_cuda.py => cpp_wrapper_gpu.py} | 0 3 files changed, 2 insertions(+), 2 deletions(-) rename torch/_inductor/codegen/{cpp_wrapper_cuda.py => cpp_wrapper_gpu.py} (100%) diff --git a/tools/amd_build/build_amd.py b/tools/amd_build/build_amd.py index 967f63ae2c8f15..60a1be73fbb2d1 100755 --- a/tools/amd_build/build_amd.py +++ b/tools/amd_build/build_amd.py @@ -207,7 +207,7 @@ def remove_hcc(line: str) -> str: ignores=ignores, extra_files=[ "torch/_inductor/codegen/cpp_wrapper_cpu.py", - "torch/_inductor/codegen/cpp_wrapper_cuda.py", + "torch/_inductor/codegen/cpp_wrapper_gpu.py", "torch/_inductor/codegen/wrapper.py", ], out_of_place_only=args.out_of_place_only, diff --git a/torch/_inductor/codegen/common.py b/torch/_inductor/codegen/common.py index efbef7f461c8c2..95183c1484fe00 100644 --- a/torch/_inductor/codegen/common.py +++ b/torch/_inductor/codegen/common.py @@ -232,7 +232,7 @@ def get_wrapper_codegen_for_device(device: str, cpp_wrapper: bool = False): def init_backend_registration(): from .cpp import CppScheduling from .cpp_wrapper_cpu import CppWrapperCpu - from .cpp_wrapper_cuda import CppWrapperGpu + from .cpp_wrapper_gpu import CppWrapperGpu from .cuda_combined_scheduling import CUDACombinedScheduling from .halide import HalideScheduling from .triton import TritonScheduling diff --git a/torch/_inductor/codegen/cpp_wrapper_cuda.py b/torch/_inductor/codegen/cpp_wrapper_gpu.py similarity index 100% rename from torch/_inductor/codegen/cpp_wrapper_cuda.py rename to torch/_inductor/codegen/cpp_wrapper_gpu.py From 36e4d543c75be38d19ccb3b98b832795643c2f7b Mon Sep 17 00:00:00 2001 From: Wei Feng Date: Wed, 11 Sep 2024 09:28:27 -0700 Subject: [PATCH 0745/1018] [FSDP2] better error msg for cpu offloading (#135156) when cpu offloading is enabled, if user load a gpu state dict, FSDP2 will throw a less obvious error at backward ``` RuntimeError: attempting to assign a gradient with device type 'cpu' to a tensor with device type 'cuda'. Please ensure that the gradient and the tensor are on the same device ``` this PR throws error more explicitly by specifying which parameters should be moved because of cpu offloading ``` FSDP parameters should be materialized on cpu when enabling cpu offloading. For example, load cpu state dict or call module.to_empty(device="cpu"). Found following parameters on non-cpu device: ['0.weight'] ``` `pytest -s test/distributed/_composable/fsdp/test_fully_shard_state_dict.py -k test_dp_state_dict_cpu_offload` Pull Request resolved: https://github.com/pytorch/pytorch/pull/135156 Approved by: https://github.com/awgu --- .../fsdp/test_fully_shard_state_dict.py | 35 ++++++++++++-- .../_composable/fsdp/_fsdp_param_group.py | 47 ++++++++++++++----- 2 files changed, 64 insertions(+), 18 deletions(-) diff --git a/test/distributed/_composable/fsdp/test_fully_shard_state_dict.py b/test/distributed/_composable/fsdp/test_fully_shard_state_dict.py index 46f22a12d9d703..3ed5853908298d 100644 --- a/test/distributed/_composable/fsdp/test_fully_shard_state_dict.py +++ b/test/distributed/_composable/fsdp/test_fully_shard_state_dict.py @@ -3,6 +3,7 @@ import copy import functools import unittest +from contextlib import nullcontext from typing import Dict import torch @@ -77,8 +78,21 @@ def _test_dp_state_dict_save_load(self, mlp_dim: int, mesh: DeviceMesh): @skip_if_lt_x_gpu(2) def test_dp_state_dict_cpu_offload(self): + self.run_subtests( + { + "offload_policy": [ + CPUOffloadPolicy(pin_memory=True), + CPUOffloadPolicy(pin_memory=False), + ], + "cpu_state_dict": [True, False], + }, + self._test_dp_state_dict_cpu_offload, + ) + + def _test_dp_state_dict_cpu_offload( + self, offload_policy: CPUOffloadPolicy, cpu_state_dict: bool + ): mlp_dim = 4 - offload_policy = CPUOffloadPolicy(pin_memory=True) torch.manual_seed(42) with torch.device("meta"): model = nn.Sequential( @@ -97,6 +111,8 @@ def test_dp_state_dict_cpu_offload(self): sharded_tensor = distribute_tensor( full_tensor, dtensor.device_mesh, dtensor.placements ) + if cpu_state_dict: + sharded_tensor = sharded_tensor.cpu() state_dicts.append({name: sharded_tensor}) # check that we can load with some parameters still on meta device @@ -105,11 +121,20 @@ def test_dp_state_dict_cpu_offload(self): # lazy init without error inp = torch.rand((mlp_dim, mlp_dim), device="cuda") - model(inp) - state_dict = model.state_dict() - for name, dtensor in state_dict.items(): - self.assertEqual(dtensor.device.type, "cpu") + context = ( + self.assertRaisesRegex( + RuntimeError, + r"Found following parameters on non-CPU device: \[\('0.weight', device\(type='cuda'", + ) + if not cpu_state_dict + else nullcontext() + ) + with context: + model(inp).sum() + state_dict = model.state_dict() + for name, dtensor in state_dict.items(): + self.assertEqual(dtensor.device.type, "cpu") def test_2d_state_dict_correctness(self): dp_size = 2 diff --git a/torch/distributed/_composable/fsdp/_fsdp_param_group.py b/torch/distributed/_composable/fsdp/_fsdp_param_group.py index ae8a513f05af05..5f22e2d34b3229 100644 --- a/torch/distributed/_composable/fsdp/_fsdp_param_group.py +++ b/torch/distributed/_composable/fsdp/_fsdp_param_group.py @@ -12,7 +12,7 @@ from torch.utils._pytree import tree_flatten, tree_unflatten from torch.utils.hooks import RemovableHandle -from ._fsdp_api import MixedPrecisionPolicy, OffloadPolicy +from ._fsdp_api import CPUOffloadPolicy, MixedPrecisionPolicy, OffloadPolicy from ._fsdp_collectives import ( AllGatherResult, foreach_all_gather, @@ -130,6 +130,7 @@ def __init__( self.post_forward_mesh_info = post_forward_mesh_info self.device = device self.mp_policy = mp_policy + self.offload_policy = offload_policy self._training_state = TrainingState.IDLE # Group's sharded state always matches its parameters' sharded states self._sharded_state = ShardedState.SHARDED @@ -208,18 +209,8 @@ def lazy_init(self): for fsdp_param in self.fsdp_params: fsdp_param.reset_sharded_param() self._reset_sharded_params = True - param_names_on_meta = [ - fsdp_param._param_fqn - for fsdp_param in self.fsdp_params - if fsdp_param.sharded_param.device.type == "meta" - ] - if param_names_on_meta: - raise RuntimeError( - "FSDP parameters should be materialized from meta device before training, " - f"but the following were still on meta device: {param_names_on_meta}\n" - "For example, call module.to_empty(device) to materialize to device and " - "call module.reset_parameters() on each module to initialize values." - ) + self._validate_no_meta_params() + self._validate_cpu_offload_params() # Initialize mixed precision attributes lazily in case the user changes # the parameter dtypes after construction time but before forward self._init_mp_dtypes() @@ -580,6 +571,36 @@ def _with_fqn(self, label: str) -> str: def __repr__(self): return f"FSDPParamGroup(fqn={self._module_fqn})" + def _validate_no_meta_params(self): + param_names_on_meta = [ + fsdp_param._param_fqn + for fsdp_param in self.fsdp_params + if fsdp_param.sharded_param.device.type == "meta" + ] + if param_names_on_meta: + raise RuntimeError( + "FSDP parameters should be materialized from meta device before training, " + f"but the following were still on meta device: {param_names_on_meta}\n" + "For example, call module.to_empty(device) to materialize to device and " + "call module.reset_parameters() on each module to initialize values." + ) + + def _validate_cpu_offload_params(self): + if not isinstance(self.offload_policy, CPUOffloadPolicy): + return + fsdp_params_not_on_cpu = [ + fsdp_param + for fsdp_param in self.fsdp_params + if fsdp_param.sharded_param.device.type != "cpu" + ] + if fsdp_params_not_on_cpu: + raise RuntimeError( + "FSDP parameters should be materialized on CPU when enabling CPU offloading. " + 'For example, load a CPU state dict or call module.to_empty(device="cpu"). ' + "Found following parameters on non-CPU device: " + f"{[(fsdp_param._param_fqn, fsdp_param.sharded_param.device) for fsdp_param in fsdp_params_not_on_cpu]}\n" + ) + def _get_param_module_infos( params: List[nn.Parameter], modules: Tuple[nn.Module, ...] From ed75c13f97063c14386bc07d2c49c42172701fb0 Mon Sep 17 00:00:00 2001 From: Menglu Yu Date: Thu, 12 Sep 2024 00:51:34 +0000 Subject: [PATCH 0746/1018] [PT2][inductor][Optimus] Add pad_aten_mm_pass pattern to resolve long computation kernel in LCE (#135167) Summary: We observed another long computation issue for OBA_AFOC pyper model, thus adding a pattern to avoid the perf regression - Only happens in A100 - Do not want to use force_shape_pad since it will pad all GEMMs, which may not be optimal. Optimus pass has more flexisibility to customized GEMM shape and do corresponding padding - To enable, we pass the pass to config, where "k_threshold_to_pad" can be customized inductor_config.patch(post_grad_fusion_options={"pad_aten_mm_pass": {"k_threshold_to_pad" : 8388608}}) Test Plan: # unit test ``` buck2 test mode/opt //caffe2/test/inductor:pad_mm ``` Buck UI: https://www.internalfb.com/buck2/58b0f272-f405-45be-bc8d-aec2dc4d5841 Test UI: https://www.internalfb.com/intern/testinfra/testrun/10133099209954651 Network: Up: 9.0KiB Down: 142B (reSessionID-8eb71a37-a5ca-4aff-a4f1-93ade3e47e4e) Jobs completed: 9. Time elapsed: 3:18.0s. Cache hits: 0%. Commands: 3 (cached: 0, remote: 0, local: 3) Tests finished: Pass 17. Fail 0. Fatal 0. Skip 0. Build failure 0 # e2e test see [D62388582](https://www.internalfb.com/diff/D62388582) Differential Revision: D62220158 Pull Request resolved: https://github.com/pytorch/pytorch/pull/135167 Approved by: https://github.com/jackiexu1992 --- test/inductor/test_pad_mm.py | 34 ++++++++++++++++++++++++++ torch/_inductor/fx_passes/pad_mm.py | 23 +++++++++++++++++ torch/_inductor/fx_passes/split_cat.py | 1 + 3 files changed, 58 insertions(+) diff --git a/test/inductor/test_pad_mm.py b/test/inductor/test_pad_mm.py index dbe1170994fdbe..957ec59916de6c 100644 --- a/test/inductor/test_pad_mm.py +++ b/test/inductor/test_pad_mm.py @@ -9,6 +9,7 @@ get_pad_cache, get_padded_length, should_pad_common, + should_pad_mm_bf16, ) from torch._inductor.test_case import run_tests, TestCase from torch._inductor.utils import fresh_inductor_cache, is_big_gpu, run_and_get_code @@ -449,6 +450,39 @@ def mm(inps, b): repr(get_pad_cache().get_local_cache()) ) + @fresh_inductor_cache() + @inductor_config.patch( + post_grad_fusion_options={"pad_aten_mm_pass": {"k_threshold_to_pad": 8388608}} + ) + def test_pad_mm_bf16(self): + m = 2 + n = 13 + k = 15691904 + mat1 = torch.ones((m, k), device="cuda", dtype=torch.bfloat16) + mat2 = torch.ones((k, n), device="cuda", dtype=torch.bfloat16) + expected_alignment = get_alignment_size(mat1) + + assert expected_alignment == 8, "Alignment for bfloat16 should be 8" + assert should_pad_common( + mat1, mat2 + ), "This should pass the common padding criteria" + assert should_pad_mm_bf16( + mat1.dtype, m, n, k + ), "This should pass the should_pad_mm_bf16 padding criteria" + + @torch.compile() + def mm(mat1, mat2): + return torch.mm(mat1, mat2) + + res2, (code,) = run_and_get_code(mm, mat1, mat2) + mm_expected_result = torch.mm(mat1, mat2) + # in call code, expect to see a single pad per input, and then we should see padded allocation for output + FileCheck().check("del async_compile").check_count( + ".run(", 2, exactly=True + ).check("empty_strided_cuda((8, 16)").run(code) + + assert torch.allclose(res2, mm_expected_result), "MM results are not identical" + if __name__ == "__main__": if HAS_CUDA: diff --git a/torch/_inductor/fx_passes/pad_mm.py b/torch/_inductor/fx_passes/pad_mm.py index 87450b34e7ee27..a00bad39747918 100644 --- a/torch/_inductor/fx_passes/pad_mm.py +++ b/torch/_inductor/fx_passes/pad_mm.py @@ -364,6 +364,23 @@ def should_pad(key: str, ori_time, pad_time) -> bool: return should_pad +def should_pad_mm_bf16(dtype, M, N, K): + # always force pad for mm with bf16 when the following are satisfied to avoid perf regression + large_k_threshold_to_pad = torch._inductor.config.post_grad_fusion_options[ + "pad_aten_mm_pass" + ].get("k_threshold_to_pad", 8388608) + if ( + dtype is torch.bfloat16 + and K > M + and K > N + and N % 2 == 1 + and K >= large_k_threshold_to_pad + and torch.cuda.get_device_capability() < (9, 0) + ): # doesnt repro on h100s: + return True + return False + + def should_pad_bench( match, mat1: Tensor, mat2: Tensor, op, input: Optional[Tensor] = None ) -> bool: @@ -410,6 +427,12 @@ def realize_symbols(ds): if torch._inductor.config.force_shape_pad: return True + if ( + "pad_aten_mm_pass" in torch._inductor.config.post_grad_fusion_options + and should_pad_mm_bf16(mat1.dtype, m, n, k) + ): + return True + if not has_triton(): return False diff --git a/torch/_inductor/fx_passes/split_cat.py b/torch/_inductor/fx_passes/split_cat.py index d0098a535ee295..b3cef2ea079814 100644 --- a/torch/_inductor/fx_passes/split_cat.py +++ b/torch/_inductor/fx_passes/split_cat.py @@ -65,6 +65,7 @@ "decompose_mm_pass", "unbind_stack_aten_pass", "shape_padding_multiplier", + "pad_aten_mm_pass", ] for pass_name in pre_grad_pass_names: From ada7f0d2a14dd9ef416ab9aa4a910f59dc64b98b Mon Sep 17 00:00:00 2001 From: Huy Do Date: Thu, 12 Sep 2024 01:16:38 +0000 Subject: [PATCH 0747/1018] Fix the upload of x86 micro benchmark results (#135780) Upload stats workflow currently skips this https://github.com/pytorch/pytorch/actions/runs/10807251335/job/29977650639, this is a miss from https://github.com/pytorch/pytorch/pull/135042. So, the workflow is running but nothing has been uploaded yet. Pull Request resolved: https://github.com/pytorch/pytorch/pull/135780 Approved by: https://github.com/atalman --- .github/workflows/upload-test-stats.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/upload-test-stats.yml b/.github/workflows/upload-test-stats.yml index c09d76559a9355..630e971e77532a 100644 --- a/.github/workflows/upload-test-stats.yml +++ b/.github/workflows/upload-test-stats.yml @@ -96,7 +96,7 @@ jobs: python3 -m tools.stats.check_disabled_tests --workflow-run-id "${WORKFLOW_RUN_ID}" --workflow-run-attempt "${WORKFLOW_RUN_ATTEMPT}" --repo "${REPO_FULLNAME}" - name: Upload gpt-fast benchmark results to Rockset - if: steps.upload-s3.outcome && steps.upload-s3.outcome == 'success' && github.event.workflow_run.name == 'inductor-micro-benchmark' + if: steps.upload-s3.outcome && steps.upload-s3.outcome == 'success' && contains('inductor-micro-benchmark', github.event.workflow_run.name) env: ROCKSET_API_KEY: ${{ secrets.ROCKSET_API_KEY }} WORKFLOW_RUN_ID: ${{ github.event.workflow_run.id }} From 5ea4788eba417d2e148fb10af7f7dfc23447d83d Mon Sep 17 00:00:00 2001 From: Shangdi Yu Date: Thu, 12 Sep 2024 01:22:06 +0000 Subject: [PATCH 0748/1018] "Remove BLOCK_LIST" (#135729) Summary: Skip test_prepare_qat_conv_bn_fusion_getitem_placeholder when we use training ir, since it's only for bn-getitem pattern, but the pattern doesn't exist in training ir. Remove BLOCK_LIST since it's empty. Now all internal unittests will use training ir. Test Plan: ``` buck2 run 'fbcode//mode/dev-nosan' caffe2/test/quantization:test_quantization -- -r test_prepare_qat_conv_bn_fusion_getitem_placeholder buck2 run 'fbcode//mode/dev-nosan' caffe2/test:quantization_pt2e_qat -- -r test_prepare_qat_conv_bn_fusion_getitem_placeholder ``` Differential Revision: D62387987 Pull Request resolved: https://github.com/pytorch/pytorch/pull/135729 Approved by: https://github.com/tugsbayasgalan --- test/quantization/pt2e/test_quantize_pt2e_qat.py | 3 +++ torch/_export/__init__.py | 2 +- 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/test/quantization/pt2e/test_quantize_pt2e_qat.py b/test/quantization/pt2e/test_quantize_pt2e_qat.py index c0dc5fb32e82d1..2a9aea449bc71a 100644 --- a/test/quantization/pt2e/test_quantize_pt2e_qat.py +++ b/test/quantization/pt2e/test_quantize_pt2e_qat.py @@ -604,6 +604,9 @@ def test_prepare_qat_conv_bn_fusion_getitem_placeholder(self): is returned as part of the match anyway (as a placeholder). """ + if capture_pre_autograd_graph_using_training_ir(): + self.skipTest("Not applicable to training IR") + class M(torch.nn.Module): def __init__(self, conv_class, bn_class): super().__init__() diff --git a/torch/_export/__init__.py b/torch/_export/__init__.py index 91d893e05cb89c..88106676c7a7bc 100644 --- a/torch/_export/__init__.py +++ b/torch/_export/__init__.py @@ -67,7 +67,7 @@ def capture_pre_autograd_graph_warning(): log.warning("capture_pre_autograd_graph() is deprecated and doesn't provide any function guarantee moving forward.") log.warning("Please switch to use torch.export.export_for_training instead.") if config.is_fbcode(): - log.warning("Unless the unittest is in the blocklist, capture_pre_autograd_graph() will fallback to torch.export.export_for_training.") # noqa: B950 + log.warning("For unittest, capture_pre_autograd_graph() will fallback to torch.export.export_for_training.") # noqa: B950 @compatibility(is_backward_compatible=False) From f9cdef8ab05ecb7688c73565879f9a7ff8e73f75 Mon Sep 17 00:00:00 2001 From: Aaron Orenstein Date: Thu, 12 Sep 2024 01:34:09 +0000 Subject: [PATCH 0749/1018] [BE] typing for decorators - masked/_ops (#135108) Differential Revision: D62184735 Pull Request resolved: https://github.com/pytorch/pytorch/pull/135108 Approved by: https://github.com/Skylion007 --- torch/masked/_ops.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/torch/masked/_ops.py b/torch/masked/_ops.py index 3d8151197a207e..4962d0430992e6 100644 --- a/torch/masked/_ops.py +++ b/torch/masked/_ops.py @@ -1,7 +1,7 @@ -# mypy: allow-untyped-decorators # mypy: allow-untyped-defs import warnings -from typing import Any, List, Optional, Tuple, TYPE_CHECKING, Union +from typing import Any, Callable, List, Optional, Tuple, TYPE_CHECKING, TypeVar, Union +from typing_extensions import ParamSpec import torch from torch import sym_float, Tensor @@ -23,13 +23,16 @@ __all__: List[str] = [] +_T = TypeVar("_T") +_P = ParamSpec("_P") + # All masked reduction/normalization operations have the same # signatures. Here we introduce docstring templates that are applied # to docstrings of reduction/normalization functions via # _apply_docstring_templates decorator. -def _apply_docstring_templates(func): +def _apply_docstring_templates(func: Callable[_P, _T]) -> Callable[_P, _T]: """Decorator that applies docstring templates to function docstring and returns the function instance. """ From 542236c9af2a508624742f7fc87c91c166e0897b Mon Sep 17 00:00:00 2001 From: Feng Yuan Date: Thu, 12 Sep 2024 03:16:08 +0000 Subject: [PATCH 0750/1018] Update torch-xpu-ops pin (ATen XPU implementation) (#135647) Release cycle for PyTorch 2.5 1. Fixing runtime error on Windows: Fail to load torch_xpu_ops_unary_binary_kernels.dll as the bin size is large. Pull Request resolved: https://github.com/pytorch/pytorch/pull/135647 Approved by: https://github.com/EikanWang --- third_party/xpu.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/third_party/xpu.txt b/third_party/xpu.txt index 5043495ce9680b..69606f14a7af3a 100644 --- a/third_party/xpu.txt +++ b/third_party/xpu.txt @@ -1 +1 @@ -2326220aa49235ff87c382a85eb7f7cf115b4338 +7e3d00acea9f0d3728048a5b2743de20d55c64ba From ad1ab151d1090c31fc8c1a6f0138feb946a86ed2 Mon Sep 17 00:00:00 2001 From: cyy Date: Thu, 12 Sep 2024 03:27:06 +0000 Subject: [PATCH 0751/1018] Fix clang-tidy warnings in Caffe2 code (#134935) Fixes #ISSUE_NUMBER Pull Request resolved: https://github.com/pytorch/pytorch/pull/134935 Approved by: https://github.com/ezyang --- .lintrunner.toml | 2 + BUCK.oss | 1 - BUILD.bazel | 1 - build.bzl | 1 - caffe2/core/timer.h | 2 +- caffe2/perfkernels/common_avx.cc | 2 +- caffe2/perfkernels/common_avx2.cc | 2 +- caffe2/serialize/CMakeLists.txt | 3 +- caffe2/serialize/crc_alt.h | 2 +- caffe2/serialize/file_adapter.cc | 11 +-- caffe2/serialize/file_adapter.h | 10 +-- caffe2/serialize/in_memory_adapter.h | 21 +++-- caffe2/serialize/inline_container.cc | 26 +++--- caffe2/serialize/inline_container.h | 14 ++-- caffe2/serialize/inline_container_test.cc | 24 +++--- caffe2/serialize/istream_adapter.cc | 9 +-- caffe2/serialize/istream_adapter.h | 6 +- caffe2/serialize/read_adapter_interface.cc | 10 --- caffe2/serialize/read_adapter_interface.h | 12 ++- caffe2/serialize/versions.h | 6 +- caffe2/utils/string_utils.cc | 2 +- caffe2/utils/string_utils.h | 2 +- caffe2/utils/threadpool/ThreadPool.h | 6 +- caffe2/utils/threadpool/pthreadpool.cc | 80 +++++++++---------- caffe2/utils/threadpool/pthreadpool_impl.cc | 4 +- test/cpp/jit/test_custom_class.cpp | 1 + .../mobile/compatibility/backport_manager.cpp | 1 + .../compatibility/model_compatibility.cpp | 1 + 28 files changed, 112 insertions(+), 150 deletions(-) delete mode 100644 caffe2/serialize/read_adapter_interface.cc diff --git a/.lintrunner.toml b/.lintrunner.toml index 4789991736be17..5fa37ef3b6b019 100644 --- a/.lintrunner.toml +++ b/.lintrunner.toml @@ -210,6 +210,8 @@ include_patterns = [ 'aten/src/ATen/native/nested/*.h', 'c10/**/*.cpp', 'c10/**/*.h', + 'caffe2/**/*.cc', + 'caffe2/**/*.h', 'torch/*.h', 'torch/csrc/*.h', 'torch/csrc/*.cpp', diff --git a/BUCK.oss b/BUCK.oss index 13d5518801a249..448cce08cd854a 100644 --- a/BUCK.oss +++ b/BUCK.oss @@ -65,7 +65,6 @@ cxx_library( "caffe2/serialize/file_adapter.cc", "caffe2/serialize/inline_container.cc", "caffe2/serialize/istream_adapter.cc", - "caffe2/serialize/read_adapter_interface.cc", ], visibility = ["PUBLIC"], deps = [ diff --git a/BUILD.bazel b/BUILD.bazel index 1018f7907adecd..e5ab0d0e29b856 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -473,7 +473,6 @@ filegroup( "caffe2/serialize/file_adapter.cc", "caffe2/serialize/inline_container.cc", "caffe2/serialize/istream_adapter.cc", - "caffe2/serialize/read_adapter_interface.cc", ], ) diff --git a/build.bzl b/build.bzl index dbb1866ac54823..3b72864f047ae1 100644 --- a/build.bzl +++ b/build.bzl @@ -34,7 +34,6 @@ def define_targets(rules): "caffe2/serialize/file_adapter.cc", "caffe2/serialize/inline_container.cc", "caffe2/serialize/istream_adapter.cc", - "caffe2/serialize/read_adapter_interface.cc", ], copts = ["-fexceptions"], tags = [ diff --git a/caffe2/core/timer.h b/caffe2/core/timer.h index a0384b0dbdbd02..28d373f468d0df 100644 --- a/caffe2/core/timer.h +++ b/caffe2/core/timer.h @@ -3,7 +3,7 @@ #include -#include "caffe2/core/common.h" +#include namespace caffe2 { diff --git a/caffe2/perfkernels/common_avx.cc b/caffe2/perfkernels/common_avx.cc index 08c70d8966bc0f..cdb8b760a0d9e2 100644 --- a/caffe2/perfkernels/common_avx.cc +++ b/caffe2/perfkernels/common_avx.cc @@ -2,7 +2,7 @@ // example, if your compiler did not specify -mavx, you should not provide // the CAFFE2_PERF_WITH_AVX macro. -#include "caffe2/core/common.h" +#include "caffe2/core/macros.h" #ifdef CAFFE2_PERF_WITH_AVX #ifndef __AVX__ diff --git a/caffe2/perfkernels/common_avx2.cc b/caffe2/perfkernels/common_avx2.cc index 3db34aa549d883..a45d887b2b49ad 100644 --- a/caffe2/perfkernels/common_avx2.cc +++ b/caffe2/perfkernels/common_avx2.cc @@ -2,7 +2,7 @@ // example, if your compiler did not specify -mavx2, you should not provide // the CAFFE2_PERF_WITH_AVX2 macro. -#include "caffe2/core/common.h" +#include "caffe2/core/macros.h" #ifdef CAFFE2_PERF_WITH_AVX2 #ifndef __AVX2__ diff --git a/caffe2/serialize/CMakeLists.txt b/caffe2/serialize/CMakeLists.txt index 1552b59d0d4419..30303e8b744cb5 100644 --- a/caffe2/serialize/CMakeLists.txt +++ b/caffe2/serialize/CMakeLists.txt @@ -6,8 +6,7 @@ list(APPEND Caffe2_CPU_SRCS ${CMAKE_CURRENT_SOURCE_DIR}/inline_container.cc ${CMAKE_CURRENT_SOURCE_DIR}/istream_adapter.cc ${CMAKE_CURRENT_SOURCE_DIR}/file_adapter.cc - ${CMAKE_CURRENT_SOURCE_DIR}/crc.cc - ${CMAKE_CURRENT_SOURCE_DIR}/read_adapter_interface.cc) + ${CMAKE_CURRENT_SOURCE_DIR}/crc.cc) list(APPEND Caffe2_CPU_INCLUDE ${PROJECT_SOURCE_DIR}/third_party/miniz-2.1.0) set(Caffe2_CPU_TEST_SRCS ${Caffe2_CPU_TEST_SRCS} PARENT_SCOPE) diff --git a/caffe2/serialize/crc_alt.h b/caffe2/serialize/crc_alt.h index 9d1c4f1dc7ddc8..10f9c6bdb380ef 100644 --- a/caffe2/serialize/crc_alt.h +++ b/caffe2/serialize/crc_alt.h @@ -25,7 +25,7 @@ // using the aforementioned #defines the table is automatically fitted to your needs // uint8_t, uint32_t, int32_t -#include +#include // size_t #include diff --git a/caffe2/serialize/file_adapter.cc b/caffe2/serialize/file_adapter.cc index 67634d7f7fd278..f112dd61e5dbf4 100644 --- a/caffe2/serialize/file_adapter.cc +++ b/caffe2/serialize/file_adapter.cc @@ -3,13 +3,11 @@ #include #include #include -#include "caffe2/core/common.h" -namespace caffe2 { -namespace serialize { +namespace caffe2::serialize { -FileAdapter::RAIIFile::RAIIFile(const std::string& file_name) { - fp_ = fopen(file_name.c_str(), "rb"); +FileAdapter::RAIIFile::RAIIFile(const std::string& file_name) + : fp_(fopen(file_name.c_str(), "rb")) { if (fp_ == nullptr) { auto old_errno = errno; #if defined(_WIN32) && (defined(__MINGW32__) || defined(_MSC_VER)) @@ -77,5 +75,4 @@ size_t FileAdapter::read(uint64_t pos, void* buf, size_t n, const char* what) FileAdapter::~FileAdapter() = default; -} // namespace serialize -} // namespace caffe2 +} // namespace caffe2::serialize diff --git a/caffe2/serialize/file_adapter.h b/caffe2/serialize/file_adapter.h index 36bc5de2710034..36663e3ea8dee2 100644 --- a/caffe2/serialize/file_adapter.h +++ b/caffe2/serialize/file_adapter.h @@ -1,14 +1,11 @@ #pragma once -#include -#include #include +#include -#include "caffe2/serialize/istream_adapter.h" #include "caffe2/serialize/read_adapter_interface.h" -namespace caffe2 { -namespace serialize { +namespace caffe2::serialize { class TORCH_API FileAdapter final : public ReadAdapterInterface { public: @@ -32,5 +29,4 @@ class TORCH_API FileAdapter final : public ReadAdapterInterface { uint64_t size_; }; -} // namespace serialize -} // namespace caffe2 +} // namespace caffe2::serialize diff --git a/caffe2/serialize/in_memory_adapter.h b/caffe2/serialize/in_memory_adapter.h index 817f3f5d677493..869ead3b24563a 100644 --- a/caffe2/serialize/in_memory_adapter.h +++ b/caffe2/serialize/in_memory_adapter.h @@ -1,10 +1,9 @@ #pragma once -#include #include +#include +#include - -namespace caffe2 { -namespace serialize { +namespace caffe2::serialize { class MemoryReadAdapter final : public caffe2::serialize::ReadAdapterInterface { public: @@ -15,18 +14,18 @@ class MemoryReadAdapter final : public caffe2::serialize::ReadAdapterInterface { return size_; } - size_t read(uint64_t pos, void* buf, size_t n, const char* what = "") - const override { - (void) what; + size_t read( + uint64_t pos, + void* buf, + size_t n, + const char* what [[maybe_unused]] = "") const override { memcpy(buf, (int8_t*)(data_) + pos, n); return n; } private: const void* data_; - off_t size_; + off_t size_{}; }; - -} // namespace serialize -} // namespace caffe2 +} // namespace caffe2::serialize diff --git a/caffe2/serialize/inline_container.cc b/caffe2/serialize/inline_container.cc index 2761147cf333da..c232838dec0c42 100644 --- a/caffe2/serialize/inline_container.cc +++ b/caffe2/serialize/inline_container.cc @@ -10,15 +10,11 @@ #include #include -#include -#include #include -#include #include #include #include -#include "caffe2/core/common.h" #include "caffe2/serialize/file_adapter.h" #include "caffe2/serialize/inline_container.h" #include "caffe2/serialize/istream_adapter.h" @@ -27,8 +23,8 @@ #include "caffe2/serialize/versions.h" #include "miniz.h" -namespace caffe2 { -namespace serialize { + +namespace caffe2::serialize { constexpr c10::string_view kDebugPklSuffix(".debug_pkl"); struct MzZipReaderIterWrapper { @@ -194,8 +190,7 @@ void PyTorchStreamReader::init() { // version check at::DataPtr version_ptr; - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - size_t version_size; + size_t version_size = 0; if (hasRecord(".data/version")) { std::tie(version_ptr, version_size) = getRecord(".data/version"); } else { @@ -204,7 +199,7 @@ void PyTorchStreamReader::init() { } std::string version(static_cast(version_ptr.get()), version_size); try { - version_ = std::stoull(version); + version_ = std::stoll(version); } catch (const std::invalid_argument& e) { CAFFE_THROW("Couldn't parse the version ", version, @@ -627,7 +622,7 @@ PyTorchStreamWriter::PyTorchStreamWriter(const std::string& file_name) } PyTorchStreamWriter::PyTorchStreamWriter( - const std::function writer_func) + const std::function& writer_func) : archive_name_("archive"), writer_func_(writer_func) { setup(archive_name_); @@ -638,7 +633,7 @@ void PyTorchStreamWriter::setup(const string& file_name) { memset(ar_.get(), 0, sizeof(mz_zip_archive)); archive_name_plus_slash_ = archive_name_ + "/"; // for writeRecord(). - if (archive_name_.size() == 0) { + if (archive_name_.empty()) { CAFFE_THROW("invalid file name: ", file_name); } if (!writer_func_) { @@ -649,7 +644,7 @@ void PyTorchStreamWriter::setup(const string& file_name) { const std::string dir_name = parentdir(file_name); if(!dir_name.empty()) { - struct stat st; + struct stat st{}; bool dir_exists = (stat(dir_name.c_str(), &st) == 0 && (st.st_mode & S_IFDIR)); TORCH_CHECK(dir_exists, "Parent directory ", dir_name, " does not exist."); } @@ -706,8 +701,8 @@ void PyTorchStreamWriter::writeRecord( /*uncomp_size=*/0, /*uncomp_crc32=*/0, /*last_modified=*/nullptr, - /*user_extra_data=*/padding_.c_str(), - /*user_extra_data_len=*/padding_size, + /*user_extra_data_local=*/padding_.c_str(), + /*user_extra_data_local_len=*/padding_size, /*user_extra_data_central=*/nullptr, /*user_extra_data_central_len=*/0); valid("writing file ", name.c_str()); @@ -820,5 +815,4 @@ PyTorchStreamWriter::~PyTorchStreamWriter() { } } -} // namespace serialize -} // namespace caffe2 +} // namespace caffe2::serialize diff --git a/caffe2/serialize/inline_container.h b/caffe2/serialize/inline_container.h index 6a13d414feb9e6..fe319e9011bc32 100644 --- a/caffe2/serialize/inline_container.h +++ b/caffe2/serialize/inline_container.h @@ -6,7 +6,6 @@ #include #include #include -#include #include #include @@ -91,8 +90,8 @@ typedef struct mz_zip_archive mz_zip_archive; // model.json as the last file when writing after we have accumulated all // other information. -namespace caffe2 { -namespace serialize { + +namespace caffe2::serialize { static constexpr const char* kSerializationIdRecordName = ".data/serialization_id"; @@ -196,18 +195,18 @@ class TORCH_API PyTorchStreamReader final { std::string archive_name_; std::string archive_name_plus_slash_; std::shared_ptr in_; - int64_t version_; + int64_t version_{}; std::mutex reader_lock_; bool load_debug_symbol_ = true; std::string serialization_id_; - size_t additional_reader_size_threshold_; + size_t additional_reader_size_threshold_{}; }; class TORCH_API PyTorchStreamWriter final { public: explicit PyTorchStreamWriter(const std::string& archive_name); explicit PyTorchStreamWriter( - const std::function writer_func); + const std::function& writer_func); void setMinVersion(const uint64_t version); @@ -274,5 +273,4 @@ size_t getPadding( std::string& padding_buf); } // namespace detail -} // namespace serialize -} // namespace caffe2 +} // namespace caffe2::serialize diff --git a/caffe2/serialize/inline_container_test.cc b/caffe2/serialize/inline_container_test.cc index 4e027f681961d0..4e58c0d609932b 100644 --- a/caffe2/serialize/inline_container_test.cc +++ b/caffe2/serialize/inline_container_test.cc @@ -5,12 +5,14 @@ #include -#include "caffe2/serialize/inline_container.h" #include -#include "c10/util/irange.h" +#include +#include "caffe2/serialize/inline_container.h" +#include "caffe2/serialize/istream_adapter.h" + -namespace caffe2 { -namespace serialize { +// NOLINTBEGIN(*-narrowing-conversions) +namespace caffe2::serialize { namespace { TEST(PyTorchStreamWriterAndReader, SaveAndLoad) { @@ -19,7 +21,7 @@ TEST(PyTorchStreamWriterAndReader, SaveAndLoad) { std::ostringstream oss; // write records through writers PyTorchStreamWriter writer([&](const void* b, size_t n) -> size_t { - oss.write(static_cast(b), n); + oss.write(static_cast(b), static_cast(n)); return oss ? n : 0; }); // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init,cppcoreguidelines-avoid-magic-numbers) @@ -28,14 +30,14 @@ TEST(PyTorchStreamWriterAndReader, SaveAndLoad) { std::vector buf(data1.size()); for (auto i : c10::irange(data1.size())) { - data1[i] = data1.size() - i; + data1[i] = static_cast(data1.size() - i); } writer.writeRecord("key1", data1.data(), data1.size()); // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init,cppcoreguidelines-avoid-magic-numbers) std::array data2; for (auto i : c10::irange(data2.size())) { - data2[i] = data2.size() - i; + data2[i] = static_cast(data2.size() - i); } writer.writeRecord("key2", data2.data(), data2.size()); @@ -149,7 +151,7 @@ TEST(PyTorchStreamWriterAndReader, LoadWithMultiThreads) { PyTorchStreamReader reader(&iss); reader.setAdditionalReaderSizeThreshold(0); // before testing, sanity check - int64_t size1, size2, ret; + int64_t size1 = 0, size2 = 0, ret = 0; at::DataPtr data_ptr; std::tie(data_ptr, size1) = reader.getRecord("key1"); std::tie(data_ptr, size2) = reader.getRecord("key2"); @@ -296,7 +298,7 @@ TEST(PytorchStreamWriterAndReader, SkipDebugRecords) { reader.setShouldLoadDebugSymbol(false); EXPECT_FALSE(reader.hasRecord("key1.debug_pkl")); at::DataPtr ptr; - size_t size; + size_t size = 0; std::tie(ptr, size) = reader.getRecord("key1.debug_pkl"); EXPECT_EQ(size, 0); std::vector dst(data1.size()); @@ -479,5 +481,5 @@ TEST_P(ChunkRecordIteratorTest, ChunkRead) { } } // namespace -} // namespace serialize -} // namespace caffe2 +} // namespace caffe2::serialize +// NOLINTEND(*-narrowing-conversions) diff --git a/caffe2/serialize/istream_adapter.cc b/caffe2/serialize/istream_adapter.cc index 9509a088736ef8..baad0e9b2ee69b 100644 --- a/caffe2/serialize/istream_adapter.cc +++ b/caffe2/serialize/istream_adapter.cc @@ -1,8 +1,7 @@ #include "caffe2/serialize/istream_adapter.h" #include -namespace caffe2 { -namespace serialize { +namespace caffe2::serialize { IStreamAdapter::IStreamAdapter(std::istream* istream) : istream_(istream) {} @@ -33,8 +32,6 @@ void IStreamAdapter::validate(const char* what) const { } } -// NOLINTNEXTLINE(modernize-use-equals-default) -IStreamAdapter::~IStreamAdapter() {} +IStreamAdapter::~IStreamAdapter() = default; -} // namespace serialize -} // namespace caffe2 +} // namespace caffe2::serialize diff --git a/caffe2/serialize/istream_adapter.h b/caffe2/serialize/istream_adapter.h index 680c288a15f2e8..69f91faf3a79f3 100644 --- a/caffe2/serialize/istream_adapter.h +++ b/caffe2/serialize/istream_adapter.h @@ -5,8 +5,7 @@ #include "c10/macros/Macros.h" #include "caffe2/serialize/read_adapter_interface.h" -namespace caffe2 { -namespace serialize { +namespace caffe2::serialize { // this is a reader implemented by std::istream class TORCH_API IStreamAdapter final : public ReadAdapterInterface { @@ -23,5 +22,4 @@ class TORCH_API IStreamAdapter final : public ReadAdapterInterface { void validate(const char* what) const; }; -} // namespace serialize -} // namespace caffe2 +} // namespace caffe2::serialize diff --git a/caffe2/serialize/read_adapter_interface.cc b/caffe2/serialize/read_adapter_interface.cc deleted file mode 100644 index e9b84660efd72f..00000000000000 --- a/caffe2/serialize/read_adapter_interface.cc +++ /dev/null @@ -1,10 +0,0 @@ -#include "caffe2/serialize/read_adapter_interface.h" - -namespace caffe2 { -namespace serialize { - -// NOLINTNEXTLINE(modernize-use-equals-default) -ReadAdapterInterface::~ReadAdapterInterface() {} - -} // namespace serialize -} // namespace caffe2 diff --git a/caffe2/serialize/read_adapter_interface.h b/caffe2/serialize/read_adapter_interface.h index 0a6b5b74a762e7..450db5c9b739ca 100644 --- a/caffe2/serialize/read_adapter_interface.h +++ b/caffe2/serialize/read_adapter_interface.h @@ -3,21 +3,19 @@ #include #include -#include "c10/macros/Macros.h" +#include -namespace caffe2 { -namespace serialize { +namespace caffe2::serialize { // this is the interface for the (file/stream/memory) reader in -// PyTorchStreamReader. with this interface, we can extend the support +// PyTorchStreamReader. With this interface, we can extend the support // besides standard istream class TORCH_API ReadAdapterInterface { public: virtual size_t size() const = 0; virtual size_t read(uint64_t pos, void* buf, size_t n, const char* what = "") const = 0; - virtual ~ReadAdapterInterface(); + virtual ~ReadAdapterInterface() = default; }; -} // namespace serialize -} // namespace caffe2 +} // namespace caffe2::serialize diff --git a/caffe2/serialize/versions.h b/caffe2/serialize/versions.h index 6e2c27adc8fae8..aa3a6b1753aa67 100644 --- a/caffe2/serialize/versions.h +++ b/caffe2/serialize/versions.h @@ -1,8 +1,7 @@ #pragma once #include -namespace caffe2 { -namespace serialize { +namespace caffe2::serialize { constexpr uint64_t kMinSupportedFileFormatVersion = 0x1L; @@ -129,5 +128,4 @@ constexpr uint64_t kProducedBytecodeVersion = 0x8L; constexpr uint64_t kMinSupportedBytecodeVersion = 0x4L; constexpr uint64_t kMaxSupportedBytecodeVersion = 0x9L; -} // namespace serialize -} // namespace caffe2 +} // namespace caffe2::serialize diff --git a/caffe2/utils/string_utils.cc b/caffe2/utils/string_utils.cc index ba763738f7bb04..8e31ada6e12ec7 100644 --- a/caffe2/utils/string_utils.cc +++ b/caffe2/utils/string_utils.cc @@ -51,7 +51,7 @@ size_t editDistance( (c)=(uint8_t)(s)[(i)++]; \ } -int32_t editDistanceHelper(const char* s1, +size_t editDistanceHelper(const char* s1, size_t s1_len, const char* s2, size_t s2_len, diff --git a/caffe2/utils/string_utils.h b/caffe2/utils/string_utils.h index 6bf3b867b5434e..677e25deed76e2 100644 --- a/caffe2/utils/string_utils.h +++ b/caffe2/utils/string_utils.h @@ -39,7 +39,7 @@ TORCH_API inline bool EndsWith( } } -TORCH_API int32_t editDistanceHelper( +TORCH_API size_t editDistanceHelper( const char* s1, size_t s1_len, const char* s2, diff --git a/caffe2/utils/threadpool/ThreadPool.h b/caffe2/utils/threadpool/ThreadPool.h index dbaefc51389ff9..b5bc9dd44fddaa 100644 --- a/caffe2/utils/threadpool/ThreadPool.h +++ b/caffe2/utils/threadpool/ThreadPool.h @@ -1,15 +1,11 @@ #ifndef CAFFE2_UTILS_THREADPOOL_H_ #define CAFFE2_UTILS_THREADPOOL_H_ -#include "ThreadPoolCommon.h" - -#include #include #include #include -#include -#include "caffe2/core/common.h" +#include // // A work-stealing threadpool loosely based off of pthreadpool diff --git a/caffe2/utils/threadpool/pthreadpool.cc b/caffe2/utils/threadpool/pthreadpool.cc index b8c6c7cebb8e8a..2f2866a77d4a9d 100644 --- a/caffe2/utils/threadpool/pthreadpool.cc +++ b/caffe2/utils/threadpool/pthreadpool.cc @@ -1,9 +1,9 @@ /* Standard C headers */ -#include -#include -#include -#include -#include +#include + +#include +#include +#include #include #ifdef _MSC_VER @@ -71,8 +71,8 @@ void legacy_pthreadpool_compute_1d_tiled( } struct compute_2d_context { - legacy_pthreadpool_function_2d_t function; - void* argument; + legacy_pthreadpool_function_2d_t function{}; + void* argument{}; caffe2::FixedDivisor range_j; }; @@ -80,8 +80,8 @@ static void compute_2d(void* context_, size_t linear_index) { TORCH_DCHECK_LE(linear_index, std::numeric_limits::max()); const struct compute_2d_context* context = static_cast(context_); - int32_t q; - int32_t r; + int32_t q = 0; + int32_t r = 0; context->range_j.DivMod(static_cast(linear_index), &q, &r); context->function(context->argument, q, r); } @@ -112,18 +112,18 @@ void legacy_pthreadpool_compute_2d( } struct compute_2d_tiled_context { - legacy_pthreadpool_function_2d_tiled_t function; - void* argument; + legacy_pthreadpool_function_2d_tiled_t function{}; + void* argument{}; caffe2::FixedDivisor tile_range_j; - size_t range_i; - size_t range_j; - size_t tile_i; - size_t tile_j; + size_t range_i{}; + size_t range_j{}; + size_t tile_i{}; + size_t tile_j{}; }; static void compute_2d_tiled(void* context_, size_t linear_index) { - int32_t q; - int32_t r; + int32_t q = 0; + int32_t r = 0; const struct compute_2d_tiled_context* context = static_cast(context_); context->tile_range_j.DivMod(linear_index, &q, &r); @@ -172,26 +172,26 @@ void legacy_pthreadpool_compute_2d_tiled( } struct compute_3d_tiled_context { - legacy_pthreadpool_function_3d_tiled_t function; - void* argument; + legacy_pthreadpool_function_3d_tiled_t function{}; + void* argument{}; caffe2::FixedDivisor tile_range_j; caffe2::FixedDivisor tile_range_k; - size_t range_i; - size_t range_j; - size_t range_k; - size_t tile_i; - size_t tile_j; - size_t tile_k; + size_t range_i{}; + size_t range_j{}; + size_t range_k{}; + size_t tile_i{}; + size_t tile_j{}; + size_t tile_k{}; }; static void compute_3d_tiled( void* context_, size_t linear_index) { - int32_t tile_index_ij, tile_index_k; + int32_t tile_index_ij = 0, tile_index_k = 0; const struct compute_3d_tiled_context* context = static_cast(context_); context->tile_range_k.DivMod( static_cast(linear_index), &tile_index_ij, &tile_index_k); - int32_t tile_index_i, tile_index_j; + int32_t tile_index_i = 0, tile_index_j = 0; context->tile_range_j.DivMod(tile_index_ij, &tile_index_i, &tile_index_j); const size_t max_tile_i = context->tile_i; const size_t max_tile_j = context->tile_j; @@ -261,31 +261,31 @@ void legacy_pthreadpool_compute_3d_tiled( } struct compute_4d_tiled_context { - legacy_pthreadpool_function_4d_tiled_t function; - void* argument; + legacy_pthreadpool_function_4d_tiled_t function{}; + void* argument{}; caffe2::FixedDivisor tile_range_kl; caffe2::FixedDivisor tile_range_j; caffe2::FixedDivisor tile_range_l; - size_t range_i; - size_t range_j; - size_t range_k; - size_t range_l; - size_t tile_i; - size_t tile_j; - size_t tile_k; - size_t tile_l; + size_t range_i{}; + size_t range_j{}; + size_t range_k{}; + size_t range_l{}; + size_t tile_i{}; + size_t tile_j{}; + size_t tile_k{}; + size_t tile_l{}; }; static void compute_4d_tiled( void* context_, size_t linear_index) { - int32_t tile_index_ij, tile_index_kl; + int32_t tile_index_ij = 0, tile_index_kl = 0; const struct compute_4d_tiled_context* context = static_cast(context_); context->tile_range_kl.DivMod( static_cast(linear_index), &tile_index_ij, &tile_index_kl); - int32_t tile_index_i, tile_index_j; + int32_t tile_index_i = 0, tile_index_j = 0; context->tile_range_j.DivMod(tile_index_ij, &tile_index_i, &tile_index_j); - int32_t tile_index_k, tile_index_l; + int32_t tile_index_k = 0, tile_index_l = 0; context->tile_range_l.DivMod(tile_index_kl, &tile_index_k, &tile_index_l); const size_t max_tile_i = context->tile_i; const size_t max_tile_j = context->tile_j; diff --git a/caffe2/utils/threadpool/pthreadpool_impl.cc b/caffe2/utils/threadpool/pthreadpool_impl.cc index ae031ca2ae7e60..79597b5a2457e7 100644 --- a/caffe2/utils/threadpool/pthreadpool_impl.cc +++ b/caffe2/utils/threadpool/pthreadpool_impl.cc @@ -7,9 +7,7 @@ namespace caffe2 { namespace { static thread_local bool using_new_threadpool{false}; } -WithCastToNewThreadPool::WithCastToNewThreadPool(bool use_new_threadpool) { - use_new_threadpool_ = using_new_threadpool; - using_new_threadpool = use_new_threadpool; +WithCastToNewThreadPool::WithCastToNewThreadPool(bool use_new_threadpool) : use_new_threadpool_(using_new_threadpool) { } WithCastToNewThreadPool::~WithCastToNewThreadPool() { using_new_threadpool = use_new_threadpool_; diff --git a/test/cpp/jit/test_custom_class.cpp b/test/cpp/jit/test_custom_class.cpp index f7fe339b561435..42234cddf4e5d3 100644 --- a/test/cpp/jit/test_custom_class.cpp +++ b/test/cpp/jit/test_custom_class.cpp @@ -1,5 +1,6 @@ #include +#include #include #include #include diff --git a/torch/csrc/jit/mobile/compatibility/backport_manager.cpp b/torch/csrc/jit/mobile/compatibility/backport_manager.cpp index 3d71329e61240d..17145fee78a6b5 100644 --- a/torch/csrc/jit/mobile/compatibility/backport_manager.cpp +++ b/torch/csrc/jit/mobile/compatibility/backport_manager.cpp @@ -2,6 +2,7 @@ #include #include #include +#include #include #include #include diff --git a/torch/csrc/jit/mobile/compatibility/model_compatibility.cpp b/torch/csrc/jit/mobile/compatibility/model_compatibility.cpp index b8b1ca6adc0dc4..d87660238232a0 100644 --- a/torch/csrc/jit/mobile/compatibility/model_compatibility.cpp +++ b/torch/csrc/jit/mobile/compatibility/model_compatibility.cpp @@ -1,6 +1,7 @@ #include #include #include +#include #include // removed after using simple type_resolver/obj_loader #include #include From 23ef29ca5cf9098fe0ec70313e43cbd7d04cd19c Mon Sep 17 00:00:00 2001 From: He Kai Date: Thu, 12 Sep 2024 03:36:03 +0000 Subject: [PATCH 0752/1018] Avoid inserting extra transpose when the input to group norm is NHWC (#135575) When the input format for group norm is NHWC and the device is privateuseone, it introduces an additional transpose operation. To avoid this issue, a check for the privateuseone device needs to be added here. Pull Request resolved: https://github.com/pytorch/pytorch/pull/135575 Approved by: https://github.com/ezyang --- aten/src/ATen/native/group_norm.cpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/aten/src/ATen/native/group_norm.cpp b/aten/src/ATen/native/group_norm.cpp index 655b2c1e72173f..627fa71382e209 100644 --- a/aten/src/ATen/native/group_norm.cpp +++ b/aten/src/ATen/native/group_norm.cpp @@ -197,7 +197,8 @@ Tensor group_norm( const Tensor kEmpty; auto memory_format = input.suggest_memory_format(); - const auto& X = input.device().is_cpu() ? input.contiguous(memory_format) : input.contiguous(); + const auto& X = input.device().is_cpu() || input.is_privateuseone() ? + input.contiguous(memory_format) : input.contiguous(); const auto& gamma = weight.defined() ? weight.contiguous() : kEmpty; const auto& beta = bias.defined() ? bias.contiguous() : kEmpty; TORCH_CHECK(!gamma.defined() || gamma.sym_numel() == C); From 1815d13274c33ae1cea50202185dc06b0e0476a9 Mon Sep 17 00:00:00 2001 From: "Adam J. Stewart" Date: Thu, 12 Sep 2024 03:53:27 +0000 Subject: [PATCH 0753/1018] torch.hub: add get_dir/set_dir type hints (#134906) Pull Request resolved: https://github.com/pytorch/pytorch/pull/134906 Approved by: https://github.com/Skylion007 --- torch/hub.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/torch/hub.py b/torch/hub.py index 096e463fcba10f..c037c6e9dc139a 100644 --- a/torch/hub.py +++ b/torch/hub.py @@ -12,7 +12,7 @@ import warnings import zipfile from pathlib import Path -from typing import Any, Dict, Optional +from typing import Any, Dict, Optional, Union from typing_extensions import deprecated from urllib.error import HTTPError, URLError from urllib.parse import urlparse # noqa: F401 @@ -389,7 +389,7 @@ def _load_entry_from_hubconf(m, model): return func -def get_dir(): +def get_dir() -> str: r""" Get the Torch Hub cache directory used for storing downloaded models & weights. @@ -408,7 +408,7 @@ def get_dir(): return os.path.join(_get_torch_home(), "hub") -def set_dir(d): +def set_dir(d: Union[str, os.PathLike]) -> None: r""" Optionally set the Torch Hub directory used to save downloaded models & weights. From 2c99d6c82561b28af1dd185b58da91bcf55a2df3 Mon Sep 17 00:00:00 2001 From: cyy Date: Thu, 12 Sep 2024 03:59:32 +0000 Subject: [PATCH 0754/1018] Fix build warnings for torch_python (#134981) Fixes #ISSUE_NUMBER Pull Request resolved: https://github.com/pytorch/pytorch/pull/134981 Approved by: https://github.com/ezyang --- .../quantized/cpu/qlinear_deserialize.cpp | 24 +++++++++---------- .../native/quantized/cpu/conv_serialization.h | 6 ++--- c10/util/StringUtil.cpp | 5 ++-- .../csrc/distributed/c10d/PyProcessGroup.hpp | 1 + torch/csrc/dynamo/cache_entry.cpp | 7 +++++- torch/csrc/dynamo/cache_entry.h | 5 ++++ 6 files changed, 29 insertions(+), 19 deletions(-) diff --git a/aten/src/ATen/native/ao_sparse/quantized/cpu/qlinear_deserialize.cpp b/aten/src/ATen/native/ao_sparse/quantized/cpu/qlinear_deserialize.cpp index aa2bab7c6b9455..bdc3a554b85b7e 100644 --- a/aten/src/ATen/native/ao_sparse/quantized/cpu/qlinear_deserialize.cpp +++ b/aten/src/ATen/native/ao_sparse/quantized/cpu/qlinear_deserialize.cpp @@ -11,18 +11,18 @@ namespace ao { namespace sparse { namespace { -constexpr int64_t serialization_version_index = 0; -constexpr int64_t bias_index = 1; -constexpr int64_t out_features_block_size_index = 2; -constexpr int64_t in_features_block_size_index = 3; -constexpr int64_t weight_scales_index = 4; -constexpr int64_t weight_zero_point_index = 5; -constexpr int64_t quantization_scheme_index = 6; -constexpr int64_t row_block_indices_index = 7; -constexpr int64_t col_block_indices_index = 8; -constexpr int64_t weight_values_index = 9; -constexpr int64_t num_output_channels_index = 10; -constexpr int64_t num_input_channels_index = 11; +constexpr int64_t serialization_version_index [[maybe_unused]] = 0; +constexpr int64_t bias_index [[maybe_unused]] = 1; +constexpr int64_t out_features_block_size_index [[maybe_unused]] = 2; +constexpr int64_t in_features_block_size_index [[maybe_unused]] = 3; +constexpr int64_t weight_scales_index [[maybe_unused]] = 4; +constexpr int64_t weight_zero_point_index [[maybe_unused]] = 5; +constexpr int64_t quantization_scheme_index [[maybe_unused]] = 6; +constexpr int64_t row_block_indices_index [[maybe_unused]] = 7; +constexpr int64_t col_block_indices_index [[maybe_unused]] = 8; +constexpr int64_t weight_values_index [[maybe_unused]] = 9; +constexpr int64_t num_output_channels_index [[maybe_unused]] = 10; +constexpr int64_t num_input_channels_index [[maybe_unused]] = 11; template std::vector unwrap_vector(at::Tensor tensor) { diff --git a/aten/src/ATen/native/quantized/cpu/conv_serialization.h b/aten/src/ATen/native/quantized/cpu/conv_serialization.h index 85451fb57482a3..9f2dfd26118acf 100644 --- a/aten/src/ATen/native/quantized/cpu/conv_serialization.h +++ b/aten/src/ATen/native/quantized/cpu/conv_serialization.h @@ -313,9 +313,9 @@ c10::intrusive_ptr> deserialize_conv( output_padding.emplace_back(config_vals.at(idx)); idx++; } - int64_t groups = config_vals.at(idx); + int64_t groups [[maybe_unused]] = config_vals.at(idx); idx++; - int64_t flags = config_vals.at(idx); + int64_t flags [[maybe_unused]] = config_vals.at(idx); idx++; TORCH_INTERNAL_ASSERT(idx == static_cast(config_vals.size()), "Unexpected length of config_vals, expected ", @@ -323,7 +323,7 @@ c10::intrusive_ptr> deserialize_conv( " got ", config_vals.size()); - bool transpose = flags & (1 << 0); + bool transpose [[maybe_unused]] = flags & (1 << 0); int64_t other_flags = flags & ~(1 << 0); TORCH_INTERNAL_ASSERT(other_flags == 0, "Unexpected flags set in ", flags, "."); diff --git a/c10/util/StringUtil.cpp b/c10/util/StringUtil.cpp index 1f5254a3dedaa2..b92802d956c806 100644 --- a/c10/util/StringUtil.cpp +++ b/c10/util/StringUtil.cpp @@ -41,15 +41,14 @@ std::ostream& _strFromWide(std::ostream& ss, const std::wstring& wString); #ifndef _WIN32 -#pragma GCC diagnostic push -#pragma GCC diagnostic ignored "-Wdeprecated-declarations" +C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wdeprecated-declarations") // TODO (huydhn) https://en.cppreference.com/w/cpp/header/codecvt has been // deprecated in C++17 but there is no alternative yet, so I just ack it std::ostream& _strFromWide(std::ostream& ss, const std::wstring& wString) { std::wstring_convert> converter; return _str(ss, converter.to_bytes(wString)); } -#pragma GCC diagnostic pop +C10_DIAGNOSTIC_POP() #else // #ifndef _WIN32 // The WIN32 implementation of wstring_convert leaks memory; see diff --git a/torch/csrc/distributed/c10d/PyProcessGroup.hpp b/torch/csrc/distributed/c10d/PyProcessGroup.hpp index 265c78f1b78cf9..3655984d452a92 100644 --- a/torch/csrc/distributed/c10d/PyProcessGroup.hpp +++ b/torch/csrc/distributed/c10d/PyProcessGroup.hpp @@ -210,6 +210,7 @@ class TORCH_PYTHON_API PythonOnCompletionHook { // Wraps a py::object hook and acquires Python GIL in dtor before // destructing the hook object. PythonOnCompletionHook(py::object hook) : hook_(std::move(hook)) {} + PythonOnCompletionHook(const PythonOnCompletionHook&) = default; ~PythonOnCompletionHook() { py::gil_scoped_acquire ag; diff --git a/torch/csrc/dynamo/cache_entry.cpp b/torch/csrc/dynamo/cache_entry.cpp index 1142187ffc21d1..ea61825d614e86 100644 --- a/torch/csrc/dynamo/cache_entry.cpp +++ b/torch/csrc/dynamo/cache_entry.cpp @@ -5,7 +5,7 @@ #include CacheEntry::CacheEntry(const py::handle& guarded_code, PyObject* backend) - : backend(backend) { + : backend{backend} { this->check_fn = guarded_code.attr("check_fn"); this->code = guarded_code.attr("code"); this->compile_id = guarded_code.attr("compile_id"); @@ -16,12 +16,17 @@ CacheEntry::CacheEntry(const py::handle& guarded_code, PyObject* backend) } } +C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED( + "-Wdeprecated-copy-with-user-provided-dtor") +C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wdeprecated-copy-dtor") // NOLINTNEXTLINE(bugprone-exception-escape) CacheEntry::~CacheEntry() { // prevent check_fn from use-after-free when invalidating this->check_fn.attr("cache_entry") = py::none(); this->check_fn.attr("extra_state") = py::none(); } +C10_DIAGNOSTIC_POP() +C10_DIAGNOSTIC_POP() py::object CacheEntry::next() { NULL_CHECK(this->_owner); diff --git a/torch/csrc/dynamo/cache_entry.h b/torch/csrc/dynamo/cache_entry.h index af909dc749b444..3d2391d23f8475 100644 --- a/torch/csrc/dynamo/cache_entry.h +++ b/torch/csrc/dynamo/cache_entry.h @@ -36,6 +36,9 @@ typedef struct ExtraState ExtraState; #ifdef __cplusplus +C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED( + "-Wdeprecated-copy-with-user-provided-dtor") +C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wdeprecated-copy-dtor") typedef struct VISIBILITY_HIDDEN CacheEntry { // check the guards: lambda: : bool py::object check_fn; @@ -58,6 +61,8 @@ typedef struct VISIBILITY_HIDDEN CacheEntry { // Warning: returns a reference whose lifetime is controlled by C++ py::object next(); } CacheEntry; +C10_DIAGNOSTIC_POP() +C10_DIAGNOSTIC_POP() #endif From 86a223fa66221e41f2d68a079624ab98cf91c8f2 Mon Sep 17 00:00:00 2001 From: Animesh Jain Date: Wed, 11 Sep 2024 14:18:34 -0700 Subject: [PATCH 0755/1018] [dynamo] Bug fix for _torchdynamo_inline source handling (#135612) Pull Request resolved: https://github.com/pytorch/pytorch/pull/135612 Approved by: https://github.com/drisspg --- test/dynamo/test_repros.py | 16 ++++++++++++++++ torch/_dynamo/variables/builder.py | 12 +++++++----- torch/_dynamo/variables/functions.py | 2 ++ 3 files changed, 25 insertions(+), 5 deletions(-) diff --git a/test/dynamo/test_repros.py b/test/dynamo/test_repros.py index 66ca040dd4170a..74f5e3749c6c75 100644 --- a/test/dynamo/test_repros.py +++ b/test/dynamo/test_repros.py @@ -5912,6 +5912,22 @@ def fn(val: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: actual[1].untyped_storage().data_ptr(), ) + def test_torch_compile_in_compile_frame(self): + # TODO(anijain2305/yanboliang) - Dont graph break on torch.compile. + def gn(x, c=None): + if c is None: + c = 2 + return c * x + + def outer_func(x): + return torch.compile(gn)(x) + + compile_outer = torch.compile(outer_func, backend="eager") + x = torch.randn(4) + ref = outer_func(x) + res = compile_outer(x) + self.assertEqual(ref, res) + instantiate_parametrized_tests(ReproTests) diff --git a/torch/_dynamo/variables/builder.py b/torch/_dynamo/variables/builder.py index f7ac4d073b534b..bf1cbed10b8e5c 100644 --- a/torch/_dynamo/variables/builder.py +++ b/torch/_dynamo/variables/builder.py @@ -547,11 +547,6 @@ class Autotuner: if id_dispatch is not None: return id_dispatch(self, value) - # Note - There are some nested values where types mismatch! - # We want to get those out and wrap those. - if is_function_or_wrapper(value): - value = inspect.getattr_static(value, "_torchdynamo_inline", value) - # Everything else (NB: order matters!) if is_traceable_wrapper_subclass(value) or istype( value, config.traceable_tensor_subclasses @@ -988,6 +983,13 @@ def build_key_value(i, k, v): elif is_lru_cache_wrapped_function(value): self.install_guards(GuardBuilder.TYPE_MATCH) return WrapperUserFunctionVariable(value, "__wrapped__", source=self.source) + elif is_function_or_wrapper(value) and inspect.getattr_static( + value, "_torchdynamo_inline", False + ): + self.install_guards(GuardBuilder.TYPE_MATCH) + return WrapperUserFunctionVariable( + value, "_torchdynamo_inline", source=self.source + ) elif is_function_or_wrapper(value): value, attr_name = unwrap_with_attr_name_if_wrapper(value) # For these wrappers, Dynamo points to the wrapped function, diff --git a/torch/_dynamo/variables/functions.py b/torch/_dynamo/variables/functions.py index 90c2d0665aebc2..c7324960fb989d 100644 --- a/torch/_dynamo/variables/functions.py +++ b/torch/_dynamo/variables/functions.py @@ -152,6 +152,8 @@ def __init__(self, fn, is_constant=False, **kwargs) -> None: assert isinstance( fn, (types.FunctionType, torch.jit.ScriptFunction) ), f"expected FunctionType found {typestr(fn)} {fn}" + # TODO(anijain2305) - Replace directly calling UserFunctionVariable with + # VariableBuilder, which handles the wrapping of _torchdynamo_inline. # unpack @torch._dynamo.optimize()(fn) wrapped function fn = inspect.getattr_static(fn, "_torchdynamo_inline", fn) self.fn: types.FunctionType = fn From d131da35528f138763974e759c1fa1e7f27546c5 Mon Sep 17 00:00:00 2001 From: Jason Ansel Date: Thu, 12 Sep 2024 04:29:37 +0000 Subject: [PATCH 0756/1018] [reland 1/3][fx] Bypass custom __setattr__ in Node.__init__ (#135733) Relands #135079 whcih was reverted by #135562 I broke this up into three parts to test internally. Pull Request resolved: https://github.com/pytorch/pytorch/pull/135733 Approved by: https://github.com/oulgen --- torch/fx/node.py | 44 ++++++++++++++++++++++++++++++-------------- 1 file changed, 30 insertions(+), 14 deletions(-) diff --git a/torch/fx/node.py b/torch/fx/node.py index f84b23e29ddf85..8c3461cbe23c7c 100644 --- a/torch/fx/node.py +++ b/torch/fx/node.py @@ -168,6 +168,16 @@ class Node(_NodeBase): """ _args: Tuple['Argument', ...] _kwargs: Dict[str, 'Argument'] + graph: 'Graph' + name: str + op: str + target: 'Target' + _input_nodes: Dict['Node', None] + users: Dict['Node', None] + type: Optional[Any] + _sort_key: Any + _repr_fn: Optional[Callable[['Node'], str]] + meta: Dict[str, Any] @compatibility(is_backward_compatible=True) def __init__(self, graph: 'Graph', name: str, op: str, target: 'Target', @@ -199,11 +209,7 @@ def __init__(self, graph: 'Graph', name: str, op: str, target: 'Target', annotation of values in the generated code or for other types of analyses. """ - super().__init__() - self.graph = graph - self.name = name # unique name of value being created assert op in _legal_ops - self.op = op # the kind of operation = placeholder|call_method|call_module|call_function|get_attr if op == 'call_function': if not callable(target): raise ValueError(f'Node [graph = {graph}, name = \'{name}\'] target {target} has type {torch.typename(target)} ' @@ -212,13 +218,22 @@ def __init__(self, graph: 'Graph', name: str, op: str, target: 'Target', if not isinstance(target, str): raise ValueError(f'Node [graph = {graph}, name = \'{name}\'] target {target} has type {torch.typename(target)} ' 'but a str is expected') - self.target = target # for method/module/function, the name of the method/module/function/attr + super().__init__() + + # bypass Node.__setattr__ for perf and so that it doesn't need to handle half-built objects + assign = object.__setattr__ + + assign(self, "graph", graph) + assign(self, "name", name) # unique name of value being created + assign(self, "op", op) # the kind of operation = placeholder|call_method|call_module|call_function|get_attr + + assign(self, "target", target) # for method/module/function, the name of the method/module/function/attr # being invoked, e.g add, layer1, or torch.add # All `Node`-valued inputs. Key is the Node, value is don't-care. # The public API for this is `all_input_nodes`, this private attribute # should not be accessed directly. - self._input_nodes : Dict[Node, None] = {} + assign(self, "_input_nodes", {}) self.__update_args_kwargs(args, kwargs) # All of the nodes that use the value produced by this Node @@ -226,7 +241,8 @@ def __init__(self, graph: 'Graph', name: str, op: str, target: 'Target', # would appear once here, but represents two uses. # # Is a dict to act as an "ordered set". Keys are significant, value dont-care - self.users : Dict[Node, None] = {} + assign(self, "users", {}) + # Type expression representing the output value of this node. # This should contain the same class of Type objects that would appear # as type annotations for function inputs/outputs. @@ -237,15 +253,15 @@ def __init__(self, graph: 'Graph', name: str, op: str, target: 'Target', # generated function return type. (Note this is a special case. ``return`` # does not produce a value, it's more of a notation. Thus, this value # describes the type of args[0] in the ``return`` node. - self.type : Optional[Any] = return_type - self._sort_key: Any = () + assign(self, "type", return_type) + assign(self, "_sort_key", ()) # If set, use this fn to print this node - self._repr_fn : Optional[Callable[[Node], str]] = None + assign(self, "_repr_fn", None) # Dictionary to store metadata passes need to do their # transformations. This metadata is preserved across node copies - self.meta : Dict[str, Any] = {} + assign(self, "meta", {}) def __getstate__(self) -> Dict[str, Any]: state = self.__dict__.copy() @@ -490,14 +506,14 @@ def update_users_and_input_nodes(n: Any) -> Any: # Clear prior users and input_nodes for old_use in self._input_nodes.keys(): old_use.users.pop(self) - self._input_nodes = {} + object.__setattr__(self, "_input_nodes", {}) # bypass Node.__setattr__ # We do three things in a single pass of the args # - Normalize list->immutable_list, dict->immutable_dict, etc # - Populate self._input_nodes # - Populate arg.users[self] for each arg - self._args = map_aggregate(new_args, update_users_and_input_nodes) # type: ignore[assignment] - self._kwargs = map_aggregate(new_kwargs, update_users_and_input_nodes) # type: ignore[assignment] + object.__setattr__(self, "_args", map_aggregate(new_args, update_users_and_input_nodes)) + object.__setattr__(self, "_kwargs", map_aggregate(new_kwargs, update_users_and_input_nodes)) def __repr__(self) -> str: if self._repr_fn: From 8adc14f79087b57378906c8bd0ef10e53516e7bb Mon Sep 17 00:00:00 2001 From: Zhenbin Lin Date: Thu, 12 Sep 2024 05:03:43 +0000 Subject: [PATCH 0757/1018] OpenReg: Split the daemon into drvier/executor (#135646) Split the daemon into a proper user-process driver vs device-process executor. Pull Request resolved: https://github.com/pytorch/pytorch/pull/135646 Approved by: https://github.com/albanD --- .../pytorch_openreg/_aten_impl.py | 14 +- .../pytorch_openreg/_device_daemon.py | 148 +++++++++++------- 2 files changed, 96 insertions(+), 66 deletions(-) diff --git a/test/cpp_extensions/open_registration_extension/pytorch_openreg/_aten_impl.py b/test/cpp_extensions/open_registration_extension/pytorch_openreg/_aten_impl.py index 534b810f9999c5..894e6ea10d129d 100644 --- a/test/cpp_extensions/open_registration_extension/pytorch_openreg/_aten_impl.py +++ b/test/cpp_extensions/open_registration_extension/pytorch_openreg/_aten_impl.py @@ -6,7 +6,7 @@ log = logging.getLogger(__name__) -from ._device_daemon import daemon +from ._device_daemon import driver from ._meta_parser import prepare_for_sending, to_device_no_copy @@ -18,7 +18,7 @@ def _register_same_name(name, with_log=False): def _(*args, **kwargs): if with_log: log.info("Calling hook %s", name) - return daemon.exec(name, *args, **kwargs) + return driver.exec(name, *args, **kwargs) _IMPL_REGISTRY[name] = _ @@ -45,11 +45,11 @@ def _openreg_kernel_fallback(op, *args, **kwargs): # handled below as a regular copy elif from_.device.type == "openreg": args, _ = prepare_for_sending((from_,), {}) - host_mem = daemon.exec("send_data", *args) + host_mem = driver.exec("send_data", *args) return to_.copy_(host_mem) elif to_.device.type == "openreg": args, _ = prepare_for_sending((to_,), {}) - daemon.exec("recv_data", from_, *args) + driver.exec("recv_data", from_, *args) return to_ else: raise RuntimeError("Should not happen") @@ -63,7 +63,7 @@ def _openreg_kernel_fallback(op, *args, **kwargs): ) elif op is torch.ops.aten._local_scalar_dense.default: args, _ = prepare_for_sending(args, {}) - host_mem = daemon.exec("send_data", *args) + host_mem = driver.exec("send_data", *args) return host_mem.item() op_name = None @@ -125,7 +125,7 @@ def _post_process(): # Slow version for data-dependent functions: # Run the op on the device just to get the output shape args_, kwargs_ = prepare_for_sending(args, kwargs) - shape = daemon.exec( + shape = driver.exec( "get_op_output_shape", op.overloadpacket._qualified_op_name, args_, @@ -142,7 +142,7 @@ def _post_process(): # 4. Run the compute and populate the output on the device args, kwargs = prepare_for_sending(args, kwargs) - daemon.exec("run_op", op_name, args, kwargs) + driver.exec("run_op", op_name, args, kwargs) if post_process is not None: post_process() diff --git a/test/cpp_extensions/open_registration_extension/pytorch_openreg/_device_daemon.py b/test/cpp_extensions/open_registration_extension/pytorch_openreg/_device_daemon.py index 8cf727cfa2c826..3b6ee8638939bb 100644 --- a/test/cpp_extensions/open_registration_extension/pytorch_openreg/_device_daemon.py +++ b/test/cpp_extensions/open_registration_extension/pytorch_openreg/_device_daemon.py @@ -13,13 +13,6 @@ log = logging.getLogger(__name__) mp_context = torch.multiprocessing.get_context("spawn") -# Constant properties of our device -NUM_DEVICES = 7 - -# Global state of our driver -CURR_DEVICE_IDX = 0 -CURR_STREAM = 0 - # Our allocator class Allocator: @@ -75,13 +68,15 @@ def tensor_from_meta(self, meta): return view -def run_op(allocator, op_name, args, kwargs): - op, _ = torch._C._jit_get_operation(op_name) - args, kwargs = receive_after_sending(allocator, args, kwargs) - return op(*args, **kwargs) +def register(registry): + def func(fn): + registry[fn.__name__] = fn + return fn + return func -class _Daemon: + +class Driver: def __init__(self): super().__init__() self.is_initialized = False @@ -89,11 +84,20 @@ def __init__(self): def _lazy_init(self): if self.is_initialized: return + + # State of our driver + self.curr_device_idx = 0 + self.curr_stream = 0 + # Constant properties of our device + self.num_devices = 7 + self.req_queue = mp_context.Queue() self.ans_queue = mp_context.Queue() self.runner = mp_context.Process( - target=self.run_forever, args=(self.req_queue, self.ans_queue), daemon=True + target=_Executor().run_forever, + args=(self.req_queue, self.ans_queue), + daemon=True, ) self.runner.start() self.is_initialized = True @@ -101,68 +105,94 @@ def _lazy_init(self): def exec(self, cmd, *args): self._lazy_init() log.info("Main process launched: %s(*%s)", cmd, safe_str(args)) - validate_send_queue_args(cmd, args) - self.req_queue.put((cmd,) + args) - res = self.ans_queue.get() + + if cmd in Driver.registry: + res = Driver.registry[cmd](self, *args) + else: + validate_send_queue_args(cmd, args) + self.req_queue.put((cmd,) + args) + res = self.ans_queue.get() + log.info("Main process result for %s received: %s", cmd, safe_str(res)) if res == "ERROR": raise RuntimeError(f"Error in daemon while executing {cmd}, see logs") else: return res - @staticmethod - def run_forever(req_queue, ans_queue): - # Initialize our device - global CURR_DEVICE_IDX - empty_res = object() - allocator = Allocator() + registry = {} + + @register(registry) + def deviceCount(self, *args): + assert len(args) == 0 + return self.num_devices + + @register(registry) + def getDevice(self): + return self.curr_device_idx + + @register(registry) + def uncheckedSetDevice(self, *args): + assert len(args) == 1 + self.curr_device_idx = int(args[0]) + + @register(registry) + def exchangeDevice(self, *args): + assert len(args) == 1 + res = self.curr_device_idx + self.curr_device_idx = int(args[0]) + return res + +class _Executor: + def __init__(self): + self.allocator = Allocator() + + def run_forever(self, req_queue, ans_queue): # Serve all requests while True: cmd, *args = req_queue.get() log.info("Worker executing: %s", cmd) - res = empty_res - if cmd == "deviceCount": - assert len(args) == 0 - res = NUM_DEVICES - elif cmd == "getDevice": - res = CURR_DEVICE_IDX - elif cmd == "uncheckedSetDevice": - assert len(args) == 1 - CURR_DEVICE_IDX = int(args[0]) - res = None - elif cmd == "exchangeDevice": - assert len(args) == 1 - res = CURR_DEVICE_IDX - CURR_DEVICE_IDX = int(args[0]) - elif cmd == "malloc": - res = allocator.malloc(*args) - elif cmd == "free": - res = allocator.free(*args) - elif cmd == "run_op": - op_name, args, kwargs = args - run_op(allocator, op_name, args, kwargs) - res = None - elif cmd == "send_data": - assert len(args) == 1 - res = OpenRegTensorData.from_meta(allocator, args[0]) - elif cmd == "recv_data": - assert len(args) == 2 - host_tensor, dev_mem = args - dev_tensor = OpenRegTensorData.from_meta(allocator, dev_mem) - dev_tensor.copy_(host_tensor) - res = None - elif cmd == "get_op_output_shape": - op_name, args, kwargs = args - res = run_op(allocator, op_name, args, kwargs).size() + if cmd in _Executor.registry: + res = _Executor.registry[cmd](self, *args) else: log.warning("Bad command in worker") res = "ERROR" - if res == empty_res: - raise RuntimeError("Bad impl didn't return anything") log.info("Worker answering to: %s", cmd) ans_queue.put(res) + registry = {} + + @register(registry) + def malloc(self, size): + return self.allocator.malloc(size) + + @register(registry) + def free(self, ptr): + return self.allocator.free(ptr) + + def _run_op(self, op_name, args, kwargs): + op, _ = torch._C._jit_get_operation(op_name) + args, kwargs = receive_after_sending(self.allocator, args, kwargs) + return op(*args, **kwargs) + + @register(registry) + def run_op(self, op_name, args, kwargs): + self._run_op(op_name, args, kwargs) + + @register(registry) + def get_op_output_shape(self, op_name, args, kwargs): + return self._run_op(op_name, args, kwargs).size() + + @register(registry) + def send_data(self, *args): + assert len(args) == 1 + return OpenRegTensorData.from_meta(self.allocator, args[0]) + + @register(registry) + def recv_data(self, host_tensor, dev_mem): + dev_tensor = OpenRegTensorData.from_meta(self.allocator, dev_mem) + dev_tensor.copy_(host_tensor) + -daemon = _Daemon() +driver = Driver() From 893b04b50dca068421643e8412612ca75dd5f6b3 Mon Sep 17 00:00:00 2001 From: Jason Ansel Date: Wed, 11 Sep 2024 18:35:48 -0700 Subject: [PATCH 0758/1018] [inductor] Optimize cache_on_self (#135445) This is a small compile time win, but also makes profiles more readable. Pull Request resolved: https://github.com/pytorch/pytorch/pull/135445 Approved by: https://github.com/oulgen --- torch/_inductor/utils.py | 24 +++++++++++++++++------- 1 file changed, 17 insertions(+), 7 deletions(-) diff --git a/torch/_inductor/utils.py b/torch/_inductor/utils.py index 28b8549408f3e4..1004817d6e7dbc 100644 --- a/torch/_inductor/utils.py +++ b/torch/_inductor/utils.py @@ -464,13 +464,23 @@ def __call__(self, *args: P.args, **kwargs: P.kwargs) -> RV: # See https://github.com/python/mypy/issues/13222#issuecomment-1193073470 to understand the type signature def cache_on_self(fn: Callable[Concatenate[Any, P], RV]) -> CachedMethod[P, RV]: - key = f"__{fn.__name__}_cache" - - @functools.wraps(fn) - def wrapper(self): - if not hasattr(self, key): - setattr(self, key, fn(self)) - return getattr(self, key) + name = fn.__name__ + key = f"__{name}_cache" + + # wrapper is likely on the hot path, compile a specialized version of it + ctx = {"fn": fn} + exec( + f"""\ + def {name}_cache_on_self(self): + try: + return self.{key} + except AttributeError: + self.{key} = rv = fn(self) + return rv + """.lstrip(), + ctx, + ) + wrapper = functools.wraps(fn)(ctx[f"{name}_cache_on_self"]) def clear_cache(self): if hasattr(self, key): From 14a98ed1f2fb2be427d9614297fb8eab47fe17c1 Mon Sep 17 00:00:00 2001 From: Jason Ansel Date: Wed, 11 Sep 2024 18:35:48 -0700 Subject: [PATCH 0759/1018] [inductor] Cache get_operation_names/get_buffer_names (#135446) Before: ![image](https://github.com/user-attachments/assets/db5b6fce-d849-4512-a21d-7a09efc72311) After: ![image](https://github.com/user-attachments/assets/097e340c-03b2-491e-ad36-132350b37892) Pull Request resolved: https://github.com/pytorch/pytorch/pull/135446 Approved by: https://github.com/oulgen ghstack dependencies: #135445 --- torch/_inductor/scheduler.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/torch/_inductor/scheduler.py b/torch/_inductor/scheduler.py index 564a9b4ccfd833..5c05f1200f6749 100644 --- a/torch/_inductor/scheduler.py +++ b/torch/_inductor/scheduler.py @@ -329,11 +329,13 @@ def get_name(self) -> str: def get_first_name(self) -> str: return self.get_name() + @cache_on_self def get_operation_names(self) -> OrderedSet[str]: - return OrderedSet(node.get_name() for node in self.get_nodes()) + return OrderedSet([node.get_name() for node in self.get_nodes()]) + @cache_on_self def get_buffer_names(self) -> OrderedSet[str]: - return OrderedSet(out.get_name() for out in self.outputs) + return OrderedSet([out.get_name() for out in self.outputs]) def get_nodes(self) -> Sequence[BaseSchedulerNode]: return [self] From f818b410bba96cb7e73eb67786754ac67cbe7065 Mon Sep 17 00:00:00 2001 From: Jason Ansel Date: Wed, 11 Sep 2024 18:35:48 -0700 Subject: [PATCH 0760/1018] [inductor] Skip unused call to get_estimated_runtime() (#135776) Pull Request resolved: https://github.com/pytorch/pytorch/pull/135776 Approved by: https://github.com/oulgen ghstack dependencies: #135445, #135446 --- torch/_inductor/scheduler.py | 25 ++++++++++++++----------- 1 file changed, 14 insertions(+), 11 deletions(-) diff --git a/torch/_inductor/scheduler.py b/torch/_inductor/scheduler.py index 5c05f1200f6749..1fb91a3e938936 100644 --- a/torch/_inductor/scheduler.py +++ b/torch/_inductor/scheduler.py @@ -504,6 +504,7 @@ def codegen_originating_info( buffer.writelines(out_lines) self.written = True + @cache_on_self def get_read_write_buffers_sizes(self) -> int: """ Counting the number of bytes accessed for a kernel is @@ -609,6 +610,7 @@ def get_buf_bytes(buf: Optional[Union[ir.Buffer, ir.TensorBox]]) -> int: return node_bytes + @cache_on_self def get_estimated_runtime(self) -> float: """ Returns estimated op runtime in nanoseconds (ns) @@ -3406,17 +3408,18 @@ def _codegen(self) -> None: seen.add(key) for node in self.nodes: - try: - log.debug( - "Generating code for node %s with estimated runtime %f", - node.get_name(), - node.get_estimated_runtime(), - ) - except Exception as e: - log.debug( - "Generating code for node %s with estimated runtime 0.0", - node.get_name(), - ) + if log.isEnabledFor(logging.DEBUG): + try: + log.debug( + "Generating code for node %s with estimated runtime %f", + node.get_name(), + node.get_estimated_runtime(), + ) + except Exception as e: + log.debug( + "Generating code for node %s with estimated runtime 0.0", + node.get_name(), + ) self.enter_context(node) From 4beb08e8936f5ca4c906ca7ada120a7d3de36c56 Mon Sep 17 00:00:00 2001 From: Dmitry Rogozhkin Date: Thu, 12 Sep 2024 05:30:59 +0000 Subject: [PATCH 0761/1018] xpu: fix 3rd party builds on systems with cmake<3.25 (#135767) Cmake LINUX variable is available on starting from cmake 3.25. Better to use CMAKE_SYSTEM_NAME instead to relax cmake version requirement. See: https://cmake.org/cmake/help/v3.25/variable/LINUX.html Fixes: #135766 Pull Request resolved: https://github.com/pytorch/pytorch/pull/135767 Approved by: https://github.com/malfet, https://github.com/guangyey --- cmake/Modules/FindSYCLToolkit.cmake | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/cmake/Modules/FindSYCLToolkit.cmake b/cmake/Modules/FindSYCLToolkit.cmake index d96b4c8d45cfef..ec46a111eaed9f 100644 --- a/cmake/Modules/FindSYCLToolkit.cmake +++ b/cmake/Modules/FindSYCLToolkit.cmake @@ -49,7 +49,10 @@ find_file( ) # Find SYCL library fullname. -if(LINUX) +# Don't use if(LINUX) here since this requires cmake>=3.25 and file is installed +# and used by other projects. +# See: https://cmake.org/cmake/help/v3.25/variable/LINUX.html +if(CMAKE_SYSTEM_NAME MATCHES "Linux") find_library( SYCL_LIBRARY NAMES sycl-preview From 5b615475fbeb75c62f6ba25c37f6fc51f336c028 Mon Sep 17 00:00:00 2001 From: angelayi Date: Wed, 11 Sep 2024 15:05:40 -0700 Subject: [PATCH 0762/1018] [aoti] Remove nlohmann/json.hpp from header (#135765) Pull Request resolved: https://github.com/pytorch/pytorch/pull/135765 Approved by: https://github.com/malfet --- .../aoti_package/model_package_loader.cpp | 46 +++++++++---------- .../aoti_package/model_package_loader.h | 12 ----- 2 files changed, 23 insertions(+), 35 deletions(-) diff --git a/torch/csrc/inductor/aoti_package/model_package_loader.cpp b/torch/csrc/inductor/aoti_package/model_package_loader.cpp index 1c09d9186d0a15..5f758787e658fa 100644 --- a/torch/csrc/inductor/aoti_package/model_package_loader.cpp +++ b/torch/csrc/inductor/aoti_package/model_package_loader.cpp @@ -64,8 +64,8 @@ std::string create_temp_dir() { namespace torch::inductor { -const nlohmann::json& AOTIModelPackageLoader::load_json_file( - std::string json_path) { +namespace { +const nlohmann::json& load_json_file(std::string json_path) { if (!file_exists(json_path)) { throw std::runtime_error("File found: " + json_path); } @@ -78,25 +78,11 @@ const nlohmann::json& AOTIModelPackageLoader::load_json_file( return json_obj; } -void AOTIModelPackageLoader::load_metadata(const std::string& cpp_filename) { - // Parse metadata json file (if it exists) into the metadata_ map - size_t lastindex = cpp_filename.find_last_of('.'); - std::string metadata_json_path = - cpp_filename.substr(0, lastindex) + "_metadata.json"; - - const nlohmann::json metadata_json_obj = load_json_file(metadata_json_path); - - for (auto& item : metadata_json_obj.items()) { - metadata_[item.key()] = item.value().get(); - } -} - -std::tuple AOTIModelPackageLoader:: - get_cpp_compile_command( - const std::string& filename, - const std::vector& sources, - const nlohmann::json& compile_options, - const std::string& output_dir = "") { +std::tuple get_cpp_compile_command( + const std::string& filename, + const std::vector& sources, + const nlohmann::json& compile_options, + const std::string& output_dir = "") { // Construct the cpp command std::string compiler = compile_options["compiler"].get(); @@ -164,7 +150,7 @@ std::tuple AOTIModelPackageLoader:: return std::make_tuple(cmd, target_file); } -bool AOTIModelPackageLoader::recursive_mkdir(const std::string& dir) { +bool recursive_mkdir(const std::string& dir) { // Creates directories recursively, copied from jit_utils.cpp // Check if current dir exists const char* p_dir = dir.c_str(); @@ -204,7 +190,7 @@ bool AOTIModelPackageLoader::recursive_mkdir(const std::string& dir) { return ret == 0; } -std::string AOTIModelPackageLoader::compile_so( +std::string compile_so( const std::string& cpp_filename, const std::string& consts_filename) { // Compile the cpp file into a .so @@ -269,6 +255,20 @@ std::string AOTIModelPackageLoader::compile_so( return output_so; } +} // namespace + +void AOTIModelPackageLoader::load_metadata(const std::string& cpp_filename) { + // Parse metadata json file (if it exists) into the metadata_ map + size_t lastindex = cpp_filename.find_last_of('.'); + std::string metadata_json_path = + cpp_filename.substr(0, lastindex) + "_metadata.json"; + + const nlohmann::json metadata_json_obj = load_json_file(metadata_json_path); + + for (auto& item : metadata_json_obj.items()) { + metadata_[item.key()] = item.value().get(); + } +} AOTIModelPackageLoader::AOTIModelPackageLoader( const std::string& model_package_path) diff --git a/torch/csrc/inductor/aoti_package/model_package_loader.h b/torch/csrc/inductor/aoti_package/model_package_loader.h index 03dc1c64018ddd..70c9514849da5b 100644 --- a/torch/csrc/inductor/aoti_package/model_package_loader.h +++ b/torch/csrc/inductor/aoti_package/model_package_loader.h @@ -4,8 +4,6 @@ #include #include -#include - namespace torch::inductor { class TORCH_API AOTIModelPackageLoader { public: @@ -24,16 +22,6 @@ class TORCH_API AOTIModelPackageLoader { std::unordered_map metadata_; void load_metadata(const std::string& cpp_filename); - std::string compile_so( - const std::string& cpp_filename, - const std::string& consts_filename); - const nlohmann::json& load_json_file(std::string json_path); - std::tuple get_cpp_compile_command( - const std::string& filename, - const std::vector& sources, - const nlohmann::json& compile_options, - const std::string& output_dir); - bool recursive_mkdir(const std::string& dir); }; } // namespace torch::inductor From e623dedc94972fc7a7dea9e30488f45f60341a67 Mon Sep 17 00:00:00 2001 From: Jason Ansel Date: Thu, 12 Sep 2024 05:50:39 +0000 Subject: [PATCH 0763/1018] [reland 3/3][fx] Bypass custom __setattr__ in Node.__init__ (#135735) Relands #135079 whcih was reverted by #135562 I broke this up into three parts to test internally. Pull Request resolved: https://github.com/pytorch/pytorch/pull/135735 Approved by: https://github.com/oulgen --- torch/fx/graph.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torch/fx/graph.py b/torch/fx/graph.py index e4fdd79fcbb280..1601eee07f2717 100644 --- a/torch/fx/graph.py +++ b/torch/fx/graph.py @@ -816,10 +816,10 @@ def remove(self, node: Node) -> None: def find_nodes(self, *, op: str, target: Optional['Target'] = None): if op == "call_function": assert target is not None - return dict(self.table[(op, target)]).keys() + return [*self.table[(op, target)].keys()] if target is None: - return dict(self.table[(op, None)]).keys() + return [*self.table[(op, None)].keys()] # op is call_method, get_attr, call_module return [node for node in self.table[(op, None)].keys() if node.target == target] From 9cb243cece50dd511e2e0708e079c431623d855b Mon Sep 17 00:00:00 2001 From: Xilun Wu <12968408+XilunWu@users.noreply.github.com> Date: Tue, 10 Sep 2024 16:20:57 -0700 Subject: [PATCH 0764/1018] [dtensor][BE] replace compute_local_shape with compute_local_shape_and_global_offset (#135554) **Summary** 1. This PR removes the public API `compute_local_shape` and replace its use with the more general API `compute_local_shape_and_global_offset`. 2. To keep `compute_local_shape_and_global_offset` consistent with `compute_local_shape` on empty shards, it now returns local tensor shape `(0,)` for empty shards which is more aligned with DTensor's semantics on non-participating ranks. **Test** `pytest test/distributed/_tensor/test_dtensor.py` `pytest test/distributed/_tensor/test_init.py` `pytest test/distributed/_tensor/test_tensor_ops.py` Differential Revision: [D62415591](https://our.internmc.facebook.com/intern/diff/D62415591) Pull Request resolved: https://github.com/pytorch/pytorch/pull/135554 Approved by: https://github.com/tianyu-l, https://github.com/wz337 --- test/distributed/_tensor/test_utils.py | 50 +++---------------- torch/distributed/tensor/_api.py | 7 ++- .../distributed/tensor/_ops/_common_rules.py | 4 +- torch/distributed/tensor/_sharding_prop.py | 4 +- torch/distributed/tensor/_utils.py | 34 +------------ 5 files changed, 18 insertions(+), 81 deletions(-) diff --git a/test/distributed/_tensor/test_utils.py b/test/distributed/_tensor/test_utils.py index ea05d2027b5829..e41c990f21a583 100644 --- a/test/distributed/_tensor/test_utils.py +++ b/test/distributed/_tensor/test_utils.py @@ -4,11 +4,8 @@ import torch from torch.distributed._tensor import distribute_tensor, DTensor -from torch.distributed._tensor._utils import ( - compute_local_shape, - compute_local_shape_and_global_offset, -) -from torch.distributed.device_mesh import DeviceMesh, init_device_mesh +from torch.distributed._tensor._utils import compute_local_shape_and_global_offset +from torch.distributed.device_mesh import init_device_mesh from torch.distributed.tensor._dtensor_spec import DTensorSpec, TensorMeta from torch.distributed.tensor.debug import CommDebugMode from torch.distributed.tensor.placement_types import _StridedShard, Replicate, Shard @@ -27,48 +24,17 @@ class UtilTest(DTensorTestBase): def world_size(self): return 8 - @with_comms - def test_compute_local_shape_2d_uneven(self): - # mesh: 4 * 2 - mesh_tensor = torch.arange(self.world_size).reshape(4, 2) - mesh = DeviceMesh(self.device_type, mesh_tensor) - size = torch.Size([7, 7]) - rank_coordinates = mesh.get_coordinate() - - # replicate, shard - placements2 = [Replicate(), Shard(0)] - local_size2 = compute_local_shape(size, mesh, placements2) - if rank_coordinates[1] < 1: - self.assertEqual(local_size2, torch.Size([4, 7])) - else: - self.assertEqual(local_size2, torch.Size([3, 7])) - - # shard, shard - placements3 = [Shard(0), Shard(1)] - local_size3 = compute_local_shape(size, mesh, placements3) - # first dim - if rank_coordinates[0] < 3: - self.assertEqual(local_size3[0], 2) - else: - self.assertEqual(local_size3[0], 1) - # second dim - if rank_coordinates[1] < 1: - self.assertEqual(local_size3[1], 4) - else: - self.assertEqual(local_size3[1], 3) - @with_comms def test_compute_local_shape_and_global_offset_1D(self): one_d_placements = [[Shard(0)], [Replicate()]] + device_mesh = init_device_mesh(self.device_type, (self.world_size,)) for placements in one_d_placements: # When the placements is [Shard(0)], we test for three different scenarios: # 1) sharding resulting in empty shards on all or some of the ranks # 2) sharding resulting in shards of different size across different ranks # 3) sharding resulting in non-empty shards of same size across all ranks for size in range(self.world_size * 2 + 1): - mesh_tensor = torch.arange(self.world_size) - device_mesh = DeviceMesh(self.device_type, mesh_tensor) global_tensor = torch.arange(size) global_shape = global_tensor.size() @@ -96,12 +62,12 @@ def test_compute_local_shape_and_global_offset_2D(self): itertools.combinations_with_replacement(two_d_placements_options, 2) ) + # mesh: 2 * 4 + device_mesh = init_device_mesh(self.device_type, (2, 4)) for placements in two_d_placements: - for dim_0_size in (1, 2, 4, 8): - # mesh: 2 * 4 - mesh_tensor = torch.arange(self.world_size).reshape(2, 4) - device_mesh = DeviceMesh(self.device_type, mesh_tensor) - global_tensor = torch.arange(64).view(dim_0_size, -1) + for dim_0_size in range(1, 9): + nelem = 64 // dim_0_size * dim_0_size + global_tensor = torch.arange(nelem).view(dim_0_size, -1) global_shape = global_tensor.size() dtensor = distribute_tensor(global_tensor, device_mesh, placements) diff --git a/torch/distributed/tensor/_api.py b/torch/distributed/tensor/_api.py index ec2993690e74e2..583aa20c1dc188 100644 --- a/torch/distributed/tensor/_api.py +++ b/torch/distributed/tensor/_api.py @@ -22,7 +22,7 @@ ) from torch.distributed.tensor._utils import ( compute_global_tensor_info, - compute_local_shape, + compute_local_shape_and_global_offset, normalize_to_torch_size, ) from torch.distributed.tensor.placement_types import ( @@ -931,7 +931,10 @@ def _dtensor_init_helper( # type: ignore[no-untyped-def] torch_stride = torch._prims_common.make_contiguous_strides_for(size) # get local tensor shape - local_shape = compute_local_shape(size, device_mesh, placements) + local_shape, _ = compute_local_shape_and_global_offset( + size, device_mesh, placements + ) + # initialize the local tensor if init_op == torch.full: fill_value = kwargs.pop("fill_value", 0) diff --git a/torch/distributed/tensor/_ops/_common_rules.py b/torch/distributed/tensor/_ops/_common_rules.py index 2a41252be40e76..059dd04bd2f4d4 100644 --- a/torch/distributed/tensor/_ops/_common_rules.py +++ b/torch/distributed/tensor/_ops/_common_rules.py @@ -10,7 +10,7 @@ OutputSharding, ) from torch.distributed.tensor._ops.utils import prod -from torch.distributed.tensor._utils import compute_local_shape +from torch.distributed.tensor._utils import compute_local_shape_and_global_offset def _replace_char_in_str(string: str, new_char: str, idx: int) -> str: @@ -171,7 +171,7 @@ def merge_sharding(dim: str, a: int, b: int) -> int: ): assert input_spec.tensor_meta is not None global_shape = input_spec.tensor_meta.shape - local_shape = compute_local_shape( + local_shape, _ = compute_local_shape_and_global_offset( global_shape, input_spec.mesh, input_spec.placements ) cost += prod(local_shape) * input_spec.mesh.size(mesh_dim) diff --git a/torch/distributed/tensor/_sharding_prop.py b/torch/distributed/tensor/_sharding_prop.py index c277655ff4bf0f..ec538d35ed66d4 100644 --- a/torch/distributed/tensor/_sharding_prop.py +++ b/torch/distributed/tensor/_sharding_prop.py @@ -21,7 +21,7 @@ TupleStrategy, ) from torch.distributed.tensor._utils import ( - compute_local_shape, + compute_local_shape_and_global_offset, compute_local_stride, try_find_mesh_from_args, ) @@ -484,7 +484,7 @@ def _adjust_shape_and_stride_args( expected_input_schema = list(schema.args_schema) # adjust shape to be the same as that of the _local_tensor # of the DTensor input arg at index 0, which is inferred - expected_input_schema[shape_idx] = compute_local_shape( + expected_input_schema[shape_idx], _ = compute_local_shape_and_global_offset( out_tensor_meta.shape, mesh, spec.placements ) diff --git a/torch/distributed/tensor/_utils.py b/torch/distributed/tensor/_utils.py index 55a892063db926..182d7f07552751 100644 --- a/torch/distributed/tensor/_utils.py +++ b/torch/distributed/tensor/_utils.py @@ -14,38 +14,6 @@ ) -# TODO: audit existing code base to see if we can safely remove this API. -def compute_local_shape( - global_shape: ShapeType, mesh: DeviceMesh, placements: Sequence[Placement] -) -> Tuple[int, ...]: - """ - Compute the shape of a local shard of the given DTensor on its current - coordinate of the mesh. - """ - my_coordinate = mesh.get_coordinate() - - if my_coordinate is None: - # if rank not in the mesh, return empty shape - return (0,) - else: - local_shape = list(global_shape) # start with global shape - ndim = len(global_shape) - for idx, placement in enumerate(placements): - mesh_dim_size = mesh.size(idx) - if isinstance(placement, Shard): - shard_dim = placement.dim - assert ( - shard_dim < ndim - ), f"Sharding dim {shard_dim} greater than tensor ndim {ndim}" - local_shard_size, _ = placement._local_shard_size_on_dim( - local_shape[shard_dim], mesh_dim_size, my_coordinate[idx] - ) - assert isinstance(local_shard_size, int) - local_shape[shard_dim] = local_shard_size - - return tuple(local_shape) - - def compute_local_shape_and_global_offset( global_shape: ShapeType, mesh: DeviceMesh, placements: Sequence[Placement] ) -> Tuple[Tuple[int, ...], Tuple[int, ...]]: @@ -90,7 +58,7 @@ def compute_local_shape_and_global_offset( if my_coordinate is None: # if rank not in the mesh, return empty offset - return ((), ()) + return ((0,), ()) else: local_shape = list(global_shape) global_offset = [0] * len(global_shape) From c5a07bf622dc22decf9fa80ffe8e19fb56555b6b Mon Sep 17 00:00:00 2001 From: "Sun, Jiayi" Date: Wed, 11 Sep 2024 01:27:39 -0700 Subject: [PATCH 0765/1018] [Inductor] simplify indexing_exprs in LoopBody._init_with_copy (#135574) This PR uses `var_ranges` information to simplify `indexing_exprs` in `LoopBody._init_with_copy` to to reduce occurrences of `FloorDiv` and `ModularIndexing` in the `indexing_exprs`. Pull Request resolved: https://github.com/pytorch/pytorch/pull/135574 Approved by: https://github.com/jgong5, https://github.com/leslie-fang-intel, https://github.com/jansel --- test/inductor/test_cpu_repro.py | 1 - torch/_inductor/loop_body.py | 6 +++++- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/test/inductor/test_cpu_repro.py b/test/inductor/test_cpu_repro.py index 669d8ad598b1bd..f90b9da1f11a34 100644 --- a/test/inductor/test_cpu_repro.py +++ b/test/inductor/test_cpu_repro.py @@ -3422,7 +3422,6 @@ def fn(): dtype if dtype else torch.float32, ) - @config.patch("cpp.enable_tiling_heuristics", False) def test_group_norm_vec(self): class M(torch.nn.Module): def __init__(self) -> None: diff --git a/torch/_inductor/loop_body.py b/torch/_inductor/loop_body.py index dcf0b61ff5adac..2f0b00b724e71f 100644 --- a/torch/_inductor/loop_body.py +++ b/torch/_inductor/loop_body.py @@ -115,7 +115,11 @@ def _init_with_copy(self, other: LoopBody, args): where we are just reordering/merging/splitting the args of an existing LoopBody. """ - self.indexing_exprs = other.indexing_from_args(args) + indexing_exprs = other.indexing_from_args(args) + self.indexing_exprs = { + name: V.graph.sizevars.simplify_with_ranges(expr, self.var_ranges) + for name, expr in indexing_exprs.items() + } self.subblocks = {k: v.clone(self) for k, v in other.subblocks.items()} self.indirect_vars = other.indirect_vars self.indirect_var_ranges = other.indirect_var_ranges From 56a63335fe96fbd595117ef339dcc4352c30ac0d Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Wed, 11 Sep 2024 20:53:10 -0700 Subject: [PATCH 0766/1018] [GPT-fast] Update compilation time target for Llama & Mixtral (#135817) Pull Request resolved: https://github.com/pytorch/pytorch/pull/135817 Approved by: https://github.com/xmfan, https://github.com/huydhn --- benchmarks/gpt_fast/generate.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/benchmarks/gpt_fast/generate.py b/benchmarks/gpt_fast/generate.py index bbfffa75ac9fdc..56e6cff1cf8ced 100644 --- a/benchmarks/gpt_fast/generate.py +++ b/benchmarks/gpt_fast/generate.py @@ -264,7 +264,7 @@ def run_llama2_7b_bf16(device: str = "cuda"): LLaMAWeightOnlyInt8QuantHandler, 94, 1253, - 162, + 133, ) token_per_sec, memory_bandwidth, compilation_time = run_experiment( model, device=device @@ -314,7 +314,7 @@ def run_llama2_7b_int8(device: str = "cuda"): LLaMAWeightOnlyInt8QuantHandler, 144, 957, - 172, + 136, ) token_per_sec, memory_bandwidth, compilation_time = run_experiment( model, device=device @@ -365,7 +365,7 @@ def run_mixtral_8x7b_int8(device: str = "cuda"): MixtralMoEWeightOnlyInt8QuantHandler, 175, 1130, - 162, + 133, ) token_per_sec, memory_bandwidth, compilation_time = run_experiment( model, device=device From 92a5b4e5ff593ce58fc8b98b246f7a84e4577c0c Mon Sep 17 00:00:00 2001 From: Jason Ansel Date: Wed, 11 Sep 2024 22:11:49 -0700 Subject: [PATCH 0767/1018] [inductor] Split reduction loops when there is no shared reads (#134307) Fixes #129102 ![image](https://github.com/user-attachments/assets/0d00f75b-2bb9-4ce6-a0d9-2daceaff539c) Pull Request resolved: https://github.com/pytorch/pytorch/pull/134307 Approved by: https://github.com/shunting314 --- torch/_inductor/codegen/simd.py | 25 +++++++++++++++++++++++++ 1 file changed, 25 insertions(+) diff --git a/torch/_inductor/codegen/simd.py b/torch/_inductor/codegen/simd.py index 59c126b993e3d4..148d062b7a7a3e 100644 --- a/torch/_inductor/codegen/simd.py +++ b/torch/_inductor/codegen/simd.py @@ -1068,6 +1068,8 @@ def generate_node_schedule(self, nodes, numel, rnumel): # Writes with a reduced shape, meaning they are only present once the # reduction loop has ended not_ready_yet_nodes: OrderedSet[str] = OrderedSet() + current_loop_buffer_usage: OrderedSet[str] = OrderedSet() + maybe_split_index: Optional[int] = None def fits_in_main_body(n): _, (node_numel, node_rnumel) = n.group @@ -1079,9 +1081,17 @@ def fits_outside_reduction(n): _, (node_numel, node_rnumel) = n.group return node_numel == numel and node_rnumel == 1 and rnumel != 1 + def expect_improved_memory_usage(n): + for read in n.read_writes.reads: + if read.name in current_loop_buffer_usage: + return True + return False + def schedule_node_in_loop(n): done.add(n) node_schedule.append(n) + current_loop_buffer_usage.update([x.name for x in n.read_writes.reads]) + # A scan is modelled as a reduction in the scheduler but has a # full sized output that can be used inside the loop body if ( @@ -1091,16 +1101,24 @@ def schedule_node_in_loop(n): and not isinstance(n.node.data, ir.Scan) ): not_ready_yet_nodes.add(n.get_name()) + else: # this node is available within the loop + current_loop_buffer_usage.update([x.name for x in n.read_writes.writes]) @contextlib.contextmanager def end_current_reduction_loop(): + nonlocal maybe_split_index if node_schedule and node_schedule[-1] is EnableReduction: node_schedule.pop() else: node_schedule.append(DisableReduction) + if maybe_split_index: + node_schedule.insert(maybe_split_index, DisableReduction) + node_schedule.insert(maybe_split_index + 1, EnableReduction) + maybe_split_index = None yield node_schedule.append(EnableReduction) not_ready_yet_nodes.clear() + current_loop_buffer_usage.clear() def requires_closing_previous_reduction(node, node_schedule): if rnumel == 1: @@ -1122,6 +1140,13 @@ def requires_closing_previous_reduction(node, node_schedule): with end_current_reduction_loop(): pass # need to start a new reduction loop + if current_loop_buffer_usage and not expect_improved_memory_usage(node): + # If we don't improve memory usage, then it is better to split into two loops + maybe_split_index = maybe_split_index or len(node_schedule) + else: + # Memory usage got improved, cancel the loop split + maybe_split_index = None + schedule_node_in_loop(node) elif fits_outside_reduction(node): with end_current_reduction_loop(): From 142bcd35c85bbc840725c9976af5b70ad8e58a59 Mon Sep 17 00:00:00 2001 From: Animesh Jain Date: Wed, 11 Sep 2024 21:23:13 -0700 Subject: [PATCH 0768/1018] [dynamo] Dont graph break on inner torch.compile (#135819) Pull Request resolved: https://github.com/pytorch/pytorch/pull/135819 Approved by: https://github.com/jansel --- test/dynamo/test_repros.py | 5 ++--- torch/_dynamo/variables/torch.py | 8 ++++++++ 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/test/dynamo/test_repros.py b/test/dynamo/test_repros.py index 74f5e3749c6c75..e1e1dd1837bef4 100644 --- a/test/dynamo/test_repros.py +++ b/test/dynamo/test_repros.py @@ -5913,16 +5913,15 @@ def fn(val: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: ) def test_torch_compile_in_compile_frame(self): - # TODO(anijain2305/yanboliang) - Dont graph break on torch.compile. def gn(x, c=None): if c is None: c = 2 return c * x def outer_func(x): - return torch.compile(gn)(x) + return torch.compile(gn, backend="eager")(x) - compile_outer = torch.compile(outer_func, backend="eager") + compile_outer = torch.compile(outer_func, backend="eager", fullgraph=True) x = torch.randn(4) ref = outer_func(x) res = compile_outer(x) diff --git a/torch/_dynamo/variables/torch.py b/torch/_dynamo/variables/torch.py index b32ac9af6ccfd9..eca9d25adfe4ac 100644 --- a/torch/_dynamo/variables/torch.py +++ b/torch/_dynamo/variables/torch.py @@ -446,6 +446,14 @@ def handle_numel(self, tx: "InstructionTranslator", input): # Workaround dynamic shapes issue return input.call_method(tx, "numel", [], {}) + @register(torch.compile) + def handle_torch_compile(self, tx: "InstructionTranslator", *args, **kwargs): + if len(args) == 1: + # torch.compile is a no-op in dynamo + return args[0] + + unimplemented("torch.compile is used as a decorator in the compiled frame") + @register(*REWRITE_OPS_TO_TENSOR_SIZE_METHOD) def handle_tensor_size_rewrites(self, tx: "InstructionTranslator", input): assert isinstance(input, TensorVariable) From 0f8c7e4dc76a3a355356df60162b8442a4adac45 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Thu, 12 Sep 2024 14:24:43 +0000 Subject: [PATCH 0769/1018] [ONNX] Fix symbolic values and numpy implementation (#135786) 1. Remove `__eq__` to make `SymbolicTensor` hashable and test for that 2. Update the `__array__` method so that it works for tensor on GPU Fixes https://github.com/pytorch/pytorch/issues/135700 Pull Request resolved: https://github.com/pytorch/pytorch/pull/135786 Approved by: https://github.com/titaiwangms --- test/onnx/exporter/test_tensors.py | 22 ++++++++++++++++++++++ torch/onnx/_internal/exporter/_core.py | 7 ++++--- torch/onnx/_internal/exporter/_tensors.py | 3 --- 3 files changed, 26 insertions(+), 6 deletions(-) create mode 100644 test/onnx/exporter/test_tensors.py diff --git a/test/onnx/exporter/test_tensors.py b/test/onnx/exporter/test_tensors.py new file mode 100644 index 00000000000000..5f1fb0e88b6782 --- /dev/null +++ b/test/onnx/exporter/test_tensors.py @@ -0,0 +1,22 @@ +# Owner(s): ["module: onnx"] +"""Unit tests for the _tensors module.""" + +from __future__ import annotations + +import onnxscript + +from torch.onnx._internal.exporter import _tensors +from torch.testing._internal import common_utils + + +class SymbolicTensorTest(common_utils.TestCase): + def test_it_is_hashable(self): + tensor = _tensors.SymbolicTensor( + opset=onnxscript.values.Opset(domain="test", version=1) + ) + self.assertEqual(hash(tensor), hash(tensor)) + self.assertIn(tensor, {tensor}) + + +if __name__ == "__main__": + common_utils.run_tests() diff --git a/torch/onnx/_internal/exporter/_core.py b/torch/onnx/_internal/exporter/_core.py index ae216c61f13c93..3f64b34816e187 100644 --- a/torch/onnx/_internal/exporter/_core.py +++ b/torch/onnx/_internal/exporter/_core.py @@ -99,8 +99,9 @@ def __init__(self, tensor: torch.Tensor, name: str | None = None): def __array__(self, dtype: Any = None) -> np.ndarray: # numpy() calls __array__ in ir.Tensor + self.raw: torch.Tensor if self.dtype == ir.DataType.BFLOAT16: - return self.raw.view(torch.uint16).__array__(dtype) + return self.raw.view(torch.uint16).numpy(force=True).__array__(dtype) if self.dtype in { ir.DataType.FLOAT8E4M3FN, ir.DataType.FLOAT8E4M3FNUZ, @@ -108,8 +109,8 @@ def __array__(self, dtype: Any = None) -> np.ndarray: ir.DataType.FLOAT8E5M2FNUZ, }: # TODO: Use ml_dtypes - return self.raw.view(torch.uint8).__array__(dtype) - return self.raw.__array__(dtype) + return self.raw.view(torch.uint8).numpy(force=True).__array__(dtype) + return self.raw.numpy(force=True).__array__(dtype) def tobytes(self) -> bytes: # Implement tobytes to support native PyTorch types so we can use types like bloat16 diff --git a/torch/onnx/_internal/exporter/_tensors.py b/torch/onnx/_internal/exporter/_tensors.py index cfe8f7dc2a661a..2fdafacbe06f4b 100644 --- a/torch/onnx/_internal/exporter/_tensors.py +++ b/torch/onnx/_internal/exporter/_tensors.py @@ -88,9 +88,6 @@ def __lt__(self, other): def __le__(self, other): return self._opset.LessOrEqual(self, other) - def __eq__(self, other): - return self._opset.Equal(self, other) - def __ge__(self, other): return self._opset.GreaterOrEqual(self, other) From f5d575978958f1421895c41d859c272e40e4208a Mon Sep 17 00:00:00 2001 From: whywhy-rtx3090 Date: Thu, 12 Sep 2024 15:22:06 +0000 Subject: [PATCH 0770/1018] Allow optional positional arguments for `torch.func.functional_call` (#134643) This PR resolves #134408. Add an additional test and have passed the local test. Do you think we should add a post-check to ensure `args` and `kwargs` are not both `None`? It seems to be possible to have modules without inputs. This PR does not include any such post-check. Pull Request resolved: https://github.com/pytorch/pytorch/pull/134643 Approved by: https://github.com/zou3519 --- test/test_stateless.py | 2 ++ torch/_functorch/functional_call.py | 2 +- torch/nn/utils/stateless.py | 8 +++++--- 3 files changed, 8 insertions(+), 4 deletions(-) diff --git a/test/test_stateless.py b/test/test_stateless.py index a68d8ad9b5f703..a62e88d2caf034 100644 --- a/test/test_stateless.py +++ b/test/test_stateless.py @@ -738,6 +738,8 @@ def forward(self, inp, *, other_inp): self.assertEqual(res, other_inp) res_1 = functional_call(mod, a, (), {'inp': inp, 'other_inp': other_inp}) self.assertEqual(res, res_1) + res_2 = functional_call(mod, a, kwargs={'inp': inp, 'other_inp': other_inp}) + self.assertEqual(res, res_2) def test_functional_call_tuple_dicts(self): mod = MockModule() diff --git a/torch/_functorch/functional_call.py b/torch/_functorch/functional_call.py index 86c63be17fc9d0..7ecc4a2d7449ff 100644 --- a/torch/_functorch/functional_call.py +++ b/torch/_functorch/functional_call.py @@ -12,7 +12,7 @@ def functional_call( module: "torch.nn.Module", parameter_and_buffer_dicts: Union[Dict[str, Tensor], Sequence[Dict[str, Tensor]]], - args: Union[Any, Tuple], + args: Optional[Union[Any, Tuple]] = None, kwargs: Optional[Dict[str, Any]] = None, *, tie_weights: bool = True, diff --git a/torch/nn/utils/stateless.py b/torch/nn/utils/stateless.py index 69994f315f77e6..dcb80a94c950d1 100644 --- a/torch/nn/utils/stateless.py +++ b/torch/nn/utils/stateless.py @@ -186,7 +186,7 @@ def _reparametrize_module( def functional_call( module: "torch.nn.Module", parameters_and_buffers: Dict[str, Tensor], - args: Union[Any, Tuple], + args: Optional[Union[Any, Tuple]] = None, kwargs: Optional[Dict[str, Any]] = None, *, tie_weights: bool = True, @@ -264,7 +264,7 @@ def functional_call( def _functional_call( module: "torch.nn.Module", parameters_and_buffers: Dict[str, Tensor], - args: Union[Any, Tuple], + args: Optional[Union[Any, Tuple]] = None, kwargs: Optional[Dict[str, Any]] = None, *, tie_weights: bool = True, @@ -290,7 +290,9 @@ def _functional_call( ) if kwargs is None: kwargs = {} - if not isinstance(args, tuple): + if args is None: + args = () + elif not isinstance(args, tuple): args = (args,) with _reparametrize_module( module, parameters_and_buffers, tie_weights=tie_weights, strict=strict From ea51a8b49fad8e38c4dd945854fd59f9b34b28e1 Mon Sep 17 00:00:00 2001 From: Isuru Fernando Date: Wed, 11 Sep 2024 19:44:41 +0000 Subject: [PATCH 0771/1018] Use a better decomposition for split_with_sizes (#135728) This decomposition has less checks and improves the performance of torch.compile. Pull Request resolved: https://github.com/pytorch/pytorch/pull/135728 Approved by: https://github.com/ezyang --- torch/_decomp/decompositions.py | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/torch/_decomp/decompositions.py b/torch/_decomp/decompositions.py index 1aa13da9cacd7e..ec092a5234096d 100644 --- a/torch/_decomp/decompositions.py +++ b/torch/_decomp/decompositions.py @@ -1409,14 +1409,17 @@ def split_with_sizes( sum(split_sizes) == self.shape[dim], lambda: f"Split sizes add up to {sum(split_sizes)} but got the tensor's size of {self.shape[dim]}", ) - num_splits = len(split_sizes) - splits = [] - start_idx = 0 - for i in range(num_splits): - length = split_sizes[i] - splits.append(self.narrow(dim, start_idx, length)) - start_idx += length + splits = [] + offset = self.storage_offset() + + for split_size in split_sizes: + new_shape = list(self.shape) + new_shape[dim] = split_size + # We reimplement narrow here to avoid a lot of checks in the + # decomposition of narrow which calls slice_in_dim and slice + splits.append(self.as_strided(new_shape, self.stride(), offset)) + offset = offset + self.stride()[dim] * split_size return splits From f6bc78acda375a64d7992653233f228e917eccbe Mon Sep 17 00:00:00 2001 From: Rachel Guo Date: Thu, 12 Sep 2024 17:31:10 +0000 Subject: [PATCH 0772/1018] [AOTI][Tooling] Support debug printing for inductor level extern kernel call such as externkernel.addmm, bmm, etc. (#135731) Summary: As title. Effect after merging this diff would look something like this: ``` print('inductor: before_launch - triton_poi_fused_0 - buf0', buf0) triton_poi_fused_0.run(buf0, 6, grid=grid(6), stream=stream0) print('inductor: after_launch - triton_poi_fused_0 - buf0', buf0) buf1 = empty_strided_cuda((16, 6), (6, 1), torch.float32) # Topologically Sorted Source Nodes: [linear], Original ATen: [aten.addmm] print('inductor: before_launch - extern_kernels.addmm - buf0', buf0) extern_kernels.addmm(buf0, reinterpret_tensor(arg2_1, (16, 16), (16, 1), 0), reinterpret_tensor(L__self___weight, (16, 6), (1, 16), 0), alpha=1, beta=1, out=buf1) print('inductor: after_launch - extern_kernels.addmm - buf0', buf0) ``` Context: D62272588 only support major triton kernel jit inductor debug printing codegen Test Plan: CI & OSS CI Reviewed By: chenyang78, ColinPeppler Differential Revision: D62397017 Pull Request resolved: https://github.com/pytorch/pytorch/pull/135731 Approved by: https://github.com/ColinPeppler --- torch/_inductor/codegen/cpp_wrapper_cpu.py | 20 +------------------- torch/_inductor/codegen/debug_utils.py | 14 +++++++++++++- torch/_inductor/codegen/wrapper.py | 6 +++++- 3 files changed, 19 insertions(+), 21 deletions(-) diff --git a/torch/_inductor/codegen/cpp_wrapper_cpu.py b/torch/_inductor/codegen/cpp_wrapper_cpu.py index 3824d71d894823..f3d4e53dc9110c 100644 --- a/torch/_inductor/codegen/cpp_wrapper_cpu.py +++ b/torch/_inductor/codegen/cpp_wrapper_cpu.py @@ -12,7 +12,6 @@ import torch import torch._inductor.async_compile # noqa: F401 required to warm up AsyncCompile pools import torch._ops -from torch._inductor.codegen.debug_utils import IntermediateValueDebuggingLevel from torch.fx.experimental.symbolic_shapes import ConvertIntKey, DivideByKey, SymTypes from .. import config, ir @@ -1214,14 +1213,7 @@ def generate_c_shim_extern_kernel_call(self, kernel, args): self.allow_stack_allocation = False wrapped_args = [] - - args_to_print_or_save = None debug_printer_manager = V.graph.wrapper_code.debug_printer - if ( - debug_printer_manager.debug_printer_level - != IntermediateValueDebuggingLevel.OFF - ): - args_to_print_or_save = [] for x in args: pieces = x.split(", ") @@ -1235,20 +1227,10 @@ def generate_c_shim_extern_kernel_call(self, kernel, args): if isinstance(piece, str) and piece.startswith( ("buf", "arg", "wrap_with_raii_handle_if_needed") ): - # TODO: The current way to find a 'tensor' type arg is hacky also as mentioned above - # Find a more reliable way to detect tensor kernel args for extern kernel calls - if ( - debug_printer_manager.debug_printer_level - != IntermediateValueDebuggingLevel.OFF - ): - if piece.startswith(("buf", "arg")): - args_to_print_or_save.append(piece) piece = f"convert_arrayref_tensor_to_tensor({piece})" wrapped_args.append(piece) - debug_printer_manager.set_printer_args( - args_to_print_or_save, kernel, None, None - ) + debug_printer_manager.set_printer_args(args, kernel, None, None, "extern") with debug_printer_manager: shim_fn = self.get_c_shim_func_name(kernel) self.writeline( diff --git a/torch/_inductor/codegen/debug_utils.py b/torch/_inductor/codegen/debug_utils.py index f129e3f31590be..56347f77f5cc4e 100644 --- a/torch/_inductor/codegen/debug_utils.py +++ b/torch/_inductor/codegen/debug_utils.py @@ -102,6 +102,7 @@ def set_printer_args( kernel_name: str, arg_signatures: Optional[List[type]], kernel, + kernel_type=None, ): # Note: MultiKernel debug printing is not supported for now if isinstance(kernel, MultiKernel): @@ -109,7 +110,17 @@ def set_printer_args( "MultiKernel type is not supported in AOTI debug printer tool yet." ) self.debug_printer_level = IntermediateValueDebuggingLevel.OFF - self.args_to_print_or_save = args_to_print_or_save + + # Note: if the kernel type is an extern kernel, we do a special handling to get the list of args_to_print_or_save + # TODO: Find a more reliable way to detect kernel args types to print for extern kernel calls + if kernel_type == "extern": + args_to_print_or_save_extern = [] + for arg in args_to_print_or_save: + if arg.startswith(("buf", "arg")): + args_to_print_or_save_extern.append(arg) + self.args_to_print_or_save = args_to_print_or_save_extern + else: + self.args_to_print_or_save = args_to_print_or_save self.kernel_name = kernel_name self.arg_signatures = arg_signatures self.kernel = kernel @@ -187,6 +198,7 @@ def codegen_intermediate_tensor_value_print( # TODO: add non-abi compatible mode debug printing info pass else: + # TODO: add mean/min/max instead of naively print the tensor value line = ( f"print('inductor: {launch_prefix} - {kernel_name} - {arg}', {arg})" ) diff --git a/torch/_inductor/codegen/wrapper.py b/torch/_inductor/codegen/wrapper.py index e04d3dfa1430e2..19e4bb293c7eff 100644 --- a/torch/_inductor/codegen/wrapper.py +++ b/torch/_inductor/codegen/wrapper.py @@ -787,8 +787,12 @@ def generate_extern_kernel_alloc(self, extern_kernel, args): def generate_extern_kernel_out( self, kernel: str, out: str, out_view: Optional[str], args: List[str] ): + # add debug printer code for triton kernel calls at (jit) inductor level + debug_printer_manager = V.graph.wrapper_code.debug_printer + debug_printer_manager.set_printer_args(args, kernel, None, None, "extern") args.append(f"out={out_view if out_view else out}") - self.writeline(f"{kernel}({', '.join(args)})") + with debug_printer_manager: + self.writeline(f"{kernel}({', '.join(args)})") def generate_user_defined_triton_kernel( self, From ced4af177925db4dfd67ecaf45aa6444bb09906e Mon Sep 17 00:00:00 2001 From: wz337 Date: Wed, 11 Sep 2024 19:05:22 -0700 Subject: [PATCH 0773/1018] [DSD] Fix distributed state dict full_state_dict option hang during set_state_dict (#135725) Fix https://github.com/pytorch/pytorch/issues/134095 This fix distributed state dict full_state_dict option hang during set_state_dict. We switch `_distribute_tensors` in _state_dict_utils.py to use `DTensor.from_local` instead of `distribute_tensor` to support FSDP2+TP 2D strided sharding use case, as `distribute_tensor` cannot handle strided sharding yet. `distribute_tensor` incurs a scatter behind the scenes, while `DTensor.from_local` takes the local slice from the full tensor on each rank to create the DTensor (no collective). This means it's the user's responsibility to make sure the full_tensor from the full_state_dict is the same across all ranks. Pull Request resolved: https://github.com/pytorch/pytorch/pull/135725 Approved by: https://github.com/fegin --- .../test_2d_composability.py | 58 +++++++++++++++++++ torch/distributed/_state_dict_utils.py | 11 +++- 2 files changed, 67 insertions(+), 2 deletions(-) diff --git a/test/distributed/_composable/test_composability/test_2d_composability.py b/test/distributed/_composable/test_composability/test_2d_composability.py index 754705b567a324..c4e89fe3421633 100644 --- a/test/distributed/_composable/test_composability/test_2d_composability.py +++ b/test/distributed/_composable/test_composability/test_2d_composability.py @@ -17,7 +17,9 @@ from torch.distributed.checkpoint.state_dict import ( get_model_state_dict, get_optimizer_state_dict, + set_model_state_dict, set_optimizer_state_dict, + StateDictOptions, ) from torch.distributed.device_mesh import DeviceMesh from torch.distributed.fsdp import FullyShardedDataParallel as FSDP @@ -335,6 +337,57 @@ def parallelize(_model: Transformer, mesh: DeviceMesh, use_seq_parallel: bool): self.assertEqual(loss_no_cp2, loss_cp2) +class TestFullyShard2DStateDict(DTensorTestBase): + @property + def backend(self): + # need to specify gloo backend for testing cpu offload + return "cpu:gloo,cuda:nccl" + + @with_comms + @skip_if_lt_x_gpu(4) + def test_fully_shard_tp_2d_set_full_state_dict(self): + dummy_model = SimpleModel().cuda() + mesh_2d = init_device_mesh( + "cuda", + (2, self.world_size // 2), + mesh_dim_names=("dp", "tp"), + ) + tp_mesh = mesh_2d["tp"] + dp_mesh = mesh_2d["dp"] + parallelize_plan = { + "net1": ColwiseParallel(), + "net2": RowwiseParallel(), + "net3": ColwiseParallel(), + } + model = parallelize_module(dummy_model, tp_mesh, parallelize_plan) + fully_shard(model, mesh=dp_mesh) + optim = torch.optim.Adam(model.parameters(), lr=0.01) + model(model.get_input()).sum().backward() + optim.step() + # ref_msd, ref_osd are both the default sharded state dict + ref_msd = copy.deepcopy(get_model_state_dict(model)) + ref_osd = copy.deepcopy(get_optimizer_state_dict(model, optimizers=optim)) + + options = StateDictOptions( + full_state_dict=True, cpu_offload=True, broadcast_from_rank0=True + ) + full_msd = get_model_state_dict(model, options=options) + full_osd = get_optimizer_state_dict(model, optimizers=optim, options=options) + # load full_msd and full_osd into model and optim. + # this loads the slice of full tensor into each rank's local DTensor. + set_model_state_dict(model, full_msd, options=options) + set_optimizer_state_dict( + model, optimizers=optim, optim_state_dict=full_osd, options=options + ) + + # check after setting full state dict, the model and optim default sharded state dict + # are the same as the initial default sharded state dict. + new_msd = get_model_state_dict(model) + new_osd = get_optimizer_state_dict(model, optimizers=optim) + self.assertEqual(ref_msd, new_msd) + self.assertEqual(ref_osd, new_osd) + + class Test2dFSDP1ParallelIntegration(DTensorTestBase): def init_model(self, device_type, model_parallel_size=2): torch.manual_seed(0) @@ -544,6 +597,11 @@ def test_2d_e2e_training_not_use_orig_params(self): # TODO: update all state dict unit tests to use distributed.checkpoint.state_dict, # and consolidate all the state_dict test in test.distributed.checkpoint. class TestNew2dParallelStateDict(DTensorTestBase): + @property + def backend(self): + # need to specify gloo backend for testing cpu offload + return "cpu:gloo,cuda:nccl" + @with_comms @skip_if_lt_x_gpu(4) def test_fsdp_2d_extension(self): diff --git a/torch/distributed/_state_dict_utils.py b/torch/distributed/_state_dict_utils.py index f22a13cb6ec0f4..9520b195f9ed5c 100644 --- a/torch/distributed/_state_dict_utils.py +++ b/torch/distributed/_state_dict_utils.py @@ -28,6 +28,7 @@ from torch.distributed import distributed_c10d from torch.distributed._shard.sharded_tensor import ShardedTensor from torch.distributed.tensor import distribute_tensor, DTensor, Replicate + from torch.distributed.tensor._utils import compute_local_shape_and_global_offset def _identity_func( @@ -551,8 +552,14 @@ def _distribute_tensors( local_state = _local_state[0] full_tensor = _local_state[1] - local_state_dict[key] = distribute_tensor( - full_tensor, local_state.device_mesh, local_state.placements + + shape, offset = compute_local_shape_and_global_offset( + full_tensor.shape, local_state.device_mesh, local_state.placements + ) + slices = [slice(offset[i], shape[i] + offset[i]) for i in range(len(shape))] + local_tensor = full_tensor[slices] + local_state_dict[key] = DTensor.from_local( + local_tensor, local_state.device_mesh, local_state.placements ) From 4e4539da464bf3225d865300360670b02e39b5b0 Mon Sep 17 00:00:00 2001 From: Joel Schlosser Date: Thu, 12 Sep 2024 10:31:56 -0400 Subject: [PATCH 0774/1018] NJT <-> padded dense conversions (#125947) This PR: * Implements the pre-existing `nt.to_padded_tensor(padding_val)` ATen op via the FBGEMM kernel + appropriate view gymnastics (since that kernel only handles 2D values) * Introduces a new `_nested_from_padded_tensor` op for the reverse conversion, implemented via the reverse FBGEMM kernel + view gymnastics * Note: there is currently no public API for this; design booted to a future PR TODO: * ~~Propagate min / max sequence length via the new factory function `_nested_from_padded_tensor`~~ * ~~Verify that Inductor does computation fusion via test logic~~ Pull Request resolved: https://github.com/pytorch/pytorch/pull/125947 Approved by: https://github.com/soulitzer --- aten/src/ATen/native/native_functions.yaml | 5 + ...asDecompTest.test_has_decomposition.expect | 1 + test/test_nestedtensor.py | 176 ++++++++++++++++++ tools/autograd/derivatives.yaml | 7 +- torch/_dynamo/trace_rules.py | 4 +- torch/_meta_registrations.py | 22 --- torch/_subclasses/fake_impls.py | 32 ++++ torch/nested/_internal/nested_tensor.py | 25 +++ torch/nested/_internal/ops.py | 92 +++++++++ 9 files changed, 340 insertions(+), 24 deletions(-) diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index c7533e4ef854c9..8a68423495d607 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -14706,6 +14706,11 @@ CUDA: _fbgemm_dense_to_jagged_forward_symint CPU: _padded_dense_to_jagged_forward_cpu +- func: _nested_from_padded_tensor(Tensor padded, Tensor offsets, Tensor dummy, int ragged_idx=1, Tensor? min_seqlen=None, Tensor? max_seqlen=None, SymInt? sum_S=None) -> Tensor + variants: function + device_check: NoCheck + dispatch: {} + - func: _nested_tensor_softmax_with_shape(Tensor self, Tensor query) -> Tensor dispatch: NestedTensorCPU: NestedTensor_softmax_dropout diff --git a/test/expect/HasDecompTest.test_has_decomposition.expect b/test/expect/HasDecompTest.test_has_decomposition.expect index 5b491d1bc73f80..a759903d065591 100644 --- a/test/expect/HasDecompTest.test_has_decomposition.expect +++ b/test/expect/HasDecompTest.test_has_decomposition.expect @@ -445,6 +445,7 @@ aten::_nested_from_padded aten::_nested_from_padded.out aten::_nested_from_padded_and_nested_example aten::_nested_from_padded_and_nested_example.out +aten::_nested_from_padded_tensor aten::_nested_get_jagged_dummy aten::_nested_get_lengths aten::_nested_get_max_seqlen diff --git a/test/test_nestedtensor.py b/test/test_nestedtensor.py index 49cb422f8720cc..82db5d038bee71 100644 --- a/test/test_nestedtensor.py +++ b/test/test_nestedtensor.py @@ -1,5 +1,6 @@ # Owner(s): ["module: nestedtensor"] +import ast import io import itertools import math @@ -7164,6 +7165,181 @@ def check(nt): check(nt) + @dtypes(torch.float32, torch.double, torch.half) + @parametrize("nt_dim", [2, 3, 4]) + @parametrize("requires_grad", [False, True]) + def test_to_padded_tensor(self, device, dtype, nt_dim, requires_grad): + if nt_dim == 2: + post_seq_len_shape = () + elif nt_dim == 3: + post_seq_len_shape = (10,) + elif nt_dim == 4: + post_seq_len_shape = (9, 10) + + nt = torch.nested.nested_tensor( + [ + torch.randn(n, *post_seq_len_shape, device=device, dtype=dtype) + for n in range(2, 9) + ], + layout=torch.jagged, + requires_grad=requires_grad, + ) + + PADDING_VAL = 4.2 + expected_padded = nt._values.new_full((7, 8, *post_seq_len_shape), PADDING_VAL) + for i, component in enumerate(nt.unbind()): + expected_padded[i, : component.shape[0]].copy_(component) + + padded = nt.to_padded_tensor(PADDING_VAL) + self.assertEqual(expected_padded, padded) + + # convert padded dense -> NJT + from torch.nested._internal.nested_tensor import nested_from_padded + + nt2 = nested_from_padded(padded, nt.offsets()) + self.assertEqual(nt, nt2) + + if requires_grad: + # ensure gradients flow through conversions + nt2.backward(torch.ones_like(nt2)) + self.assertEqual(nt.grad, torch.ones_like(nt)) + + # blows up due to test parametrization otherwise + @torch._dynamo.utils.disable_cache_limit() + @skipIfTorchDynamo("SDPA test compiles internally") + @unittest.skipIf(IS_WINDOWS, reason="Windows not yet supported for torch.compile") + @skipCUDAIf(not SM70OrLater, "GPU capability is < SM70") + @skipCUDAIfRocm + @dtypes(torch.float32, torch.double, torch.half) + @parametrize("nt_dim", [2, 3, 4]) + @parametrize("requires_grad", [False, True]) + def test_to_padded_tensor_compile(self, device, dtype, nt_dim, requires_grad): + if nt_dim == 2: + post_seq_len_shape = () + elif nt_dim == 3: + post_seq_len_shape = (10,) + elif nt_dim == 4: + post_seq_len_shape = (9, 10) + + nt = torch.nested.nested_tensor( + [ + torch.randn(n, *post_seq_len_shape, device=device, dtype=dtype) + for n in range(2, 9) + ], + layout=torch.jagged, + requires_grad=requires_grad, + ) + + def f(x): + return x.sin() + 1 + + from torch.nested._internal.nested_tensor import nested_from_padded + + @torch.compile(fullgraph=True) + def g(nt): + def _g(nt): + PADDING_VAL = 4.2 + padded = nt.to_padded_tensor(PADDING_VAL) + padded = f(padded) + # NB: sum_S must be specified to use the lowering for dense -> jagged + # and get full fusion + return nested_from_padded( + padded, nt.offsets(), sum_S=nt.values().shape[0] + ) + + # NB: use checkpointing to force fusion + return torch.utils.checkpoint.checkpoint(_g, nt, use_reentrant=False) + + expected_output = f(nt) + if requires_grad: + expected_output.backward(torch.ones_like(expected_output)) + expected_grad = nt.grad.clone().detach() + nt.grad = None + + from torch._inductor.utils import run_and_get_code + + compiled_output, generated_code = run_and_get_code(g, nt) + if requires_grad: + compiled_output.backward(torch.ones_like(compiled_output)) + compiled_grad = nt.grad.clone().detach() + self.assertEqual(compiled_grad, expected_grad, rtol=1e-3, atol=1e-3) + + self.assertEqual(compiled_output, expected_output, rtol=1e-3, atol=1e-3) + + # === Verify that computation fusion happens. === + # Fallback op call -> fusion didn't happen. + fallback_op_calls_present = any( + "torch.ops.aten._padded_dense_to_jagged_forward.default(" + in generated_code[i] + or "torch.ops.aten._jagged_to_padded_dense_forward.default(" + in generated_code[i] + for i in range(len(generated_code)) + ) + + # NB: Fusion isn't supported on CPU. + self.assertEqual("cuda" in device, not fallback_op_calls_present) + + for i in range(len(generated_code)): + # Examine buffer construction lines in the generated code to determine + # whether fusion occurred. If fusion happens, a 3D buffer with shape + # (B, max_seqlen, D) should never be materialized. + buffer_constructions = [ + line.strip() + for line in generated_code[i].split("\n") + if "empty_strided_cuda(" in line + ] + + buffer_dims = [ + # buffer dim == number of elements in the tensor size tuple arg + len(ast.parse(t).body[0].value.args[0].elts) + for t in buffer_constructions + ] + + if "cuda" in device: + self.assertFalse(any(d == 3 for d in buffer_dims)) + + @dtypes(torch.float32) + @skipIfTorchDynamo("Test compiles internally") + @unittest.skipIf( + sys.version_info >= (3, 12), "torch.compile is not supported on python 3.12+" + ) + @unittest.skipIf(IS_WINDOWS, reason="Windows not yet supported for torch.compile") + @skipCUDAIf(not SM70OrLater, "GPU capability is < SM70") + @skipCUDAIfRocm + def test_compile_padded_dense_conversion_preserves_metadata_cache( + self, device, dtype + ): + # shape (B, *, D) + nt = random_nt_from_dims( + [4, None, 3, 16], + device=device, + dtype=dtype, + layout=torch.jagged, + requires_grad=True, + ) + + # expect min / max seqlen to be stored here + cache = dict(nt._metadata_cache) + + @torch.compile + def g(nt): + padded = nt.to_padded_tensor(0.3) + intermediate = padded.sin() + 1 + + from torch.nested._internal.nested_tensor import nested_from_padded + + return nested_from_padded( + intermediate, + nt.offsets(), + min_seqlen=nt._min_seqlen, + max_seqlen=nt._max_seqlen, + sum_S=nt.values().shape[0], + ) + + output = g(nt) + output.backward(torch.ones_like(output)) + self.assertEqual(output._metadata_cache, cache) + FORWARD_FAILURES = { # === BEGIN NotImplementedError SECTION === diff --git a/tools/autograd/derivatives.yaml b/tools/autograd/derivatives.yaml index aaa6c3a55a43ef..9df4d965d9f788 100644 --- a/tools/autograd/derivatives.yaml +++ b/tools/autograd/derivatives.yaml @@ -2826,9 +2826,14 @@ cpu_nested_shape_example: non_differentiable - name: to_padded_tensor(Tensor self, float padding, SymInt[]? output_size=None) -> Tensor - self: at::_nested_from_padded(grad, self._nested_tensor_size()) + self: "self.layout() == c10::kJagged ? at::_nested_from_padded_tensor_symint(grad, at::_nested_get_offsets(self), at::_nested_get_jagged_dummy(self), at::_nested_get_ragged_idx(self), at::_nested_get_min_seqlen(self).defined() ? std::optional(at::_nested_get_min_seqlen(self)) : ::std::nullopt, at::_nested_get_max_seqlen(self).defined() ? std::optional(at::_nested_get_max_seqlen(self)) : ::std::nullopt, std::optional(at::_nested_get_values(self).sym_size(0))) : at::_nested_from_padded(grad, self._nested_tensor_size())" padding: non_differentiable +- name: _nested_from_padded_tensor(Tensor padded, Tensor offsets, Tensor dummy, int ragged_idx=1, Tensor? min_seqlen=None, Tensor? max_seqlen=None, SymInt? sum_S=None) -> Tensor + padded: grad.to_padded_tensor_symint(0.0, at::OptionalArrayRef(padded.sym_sizes())) + offsets: non_differentiable + dummy: non_differentiable + - name: _nested_view_from_buffer(Tensor(a) self, Tensor nested_size, Tensor nested_strides, Tensor offsets) -> Tensor(a) self: grad.values() nested_size: non_differentiable diff --git a/torch/_dynamo/trace_rules.py b/torch/_dynamo/trace_rules.py index cf0731911612ba..391beb07cd600d 100644 --- a/torch/_dynamo/trace_rules.py +++ b/torch/_dynamo/trace_rules.py @@ -178,8 +178,9 @@ "torch.nn.Parameter": TorchInGraphFunctionVariable, "torch.nn.Buffer": TorchInGraphFunctionVariable, "torch._nested_tensor_from_mask": SkipFunctionVariable, - "torch._nested_from_padded": SkipFunctionVariable, + "torch.nested._internal.nested_tensor.nested_from_padded": TorchInGraphFunctionVariable, "torch.nested.nested_tensor_from_jagged": UserFunctionVariable, + "torch.nested.nested_tensor_from_padded": UserFunctionVariable, # symbol operators implemented in Python "torch.sym_not": TorchInGraphFunctionVariable, "torch.sym_float": TorchInGraphFunctionVariable, @@ -1525,6 +1526,7 @@ "torch._neg_view_copy", "torch._neg_view", "torch._nested_from_padded_and_nested_example", + "torch._nested_from_padded_tensor", "torch._nested_tensor_from_mask_left_aligned", "torch._nested_tensor_from_tensor_list", "torch._nested_tensor_softmax_with_shape", diff --git a/torch/_meta_registrations.py b/torch/_meta_registrations.py index d4d2a52f2888c4..ce37b0a96fbfb5 100644 --- a/torch/_meta_registrations.py +++ b/torch/_meta_registrations.py @@ -6491,28 +6491,6 @@ def meta__jagged_to_padded_dense_forward( return values.new_empty(output_shape) -@register_meta(aten._padded_dense_to_jagged_forward.default) -def meta__padded_dense_to_jagged_forward( - padded: Tensor, - offsets: List[Tensor], - total_L: Optional[int] = None, -): - # only one jagged dim is supported for now - assert len(offsets) == 1 - - if not total_L: - assert isinstance(padded, torch._subclasses.FakeTensor) - shape_env = padded.fake_mode.shape_env - assert shape_env is not None - total_L = shape_env.create_unbacked_symint() - torch.fx.experimental.symbolic_shapes._constrain_range_for_size( - total_L, min=0, max=None - ) - - output_shape = (total_L, *padded.shape[2:]) - return padded.new_empty(output_shape) - - def _create_unary_float_meta_func(func): @register_meta(func) @out_wrapper() diff --git a/torch/_subclasses/fake_impls.py b/torch/_subclasses/fake_impls.py index 1ba37b6f06bddf..b80b200a3c52ba 100644 --- a/torch/_subclasses/fake_impls.py +++ b/torch/_subclasses/fake_impls.py @@ -437,6 +437,38 @@ def nonzero(fake_mode, func, arg): return arg.new_empty((nnz, arg.dim()), dtype=torch.int64) +@register_op_impl(torch.ops.aten._padded_dense_to_jagged_forward.default) +def _padded_dense_to_jagged_forward(fake_mode, func, padded, offsets, total_L=None): + # only one jagged dim is supported for now + assert len(offsets) == 1 + + if not total_L: + if ( + fake_mode.shape_env is None + or not fake_mode.shape_env.allow_dynamic_output_shape_ops + ): + # Without symints/symfloats, cannot handle this + raise DynamicOutputShapeException(func) + + total_L = fake_mode.shape_env.create_unbacked_symint() + + maxval = sys.maxsize - 1 + + # Avoid importing sympy at a module level + from torch.fx.experimental.symbolic_shapes import ( + _constrain_range_for_size, + has_free_symbols, + ) + + if not has_free_symbols(padded.numel()): + maxval = int(padded.numel()) + + _constrain_range_for_size(total_L, min=0, max=maxval) + + output_shape = (total_L, *padded.shape[2:]) + return padded.new_empty(output_shape) + + @register_op_impl(torch.ops.aten.masked_select.default) def masked_select(fake_mode, func, self, mask): if ( diff --git a/torch/nested/_internal/nested_tensor.py b/torch/nested/_internal/nested_tensor.py index 6a0425a13d43af..76f8d5a29e2d8f 100644 --- a/torch/nested/_internal/nested_tensor.py +++ b/torch/nested/_internal/nested_tensor.py @@ -562,3 +562,28 @@ def nested_view_from_values_offsets_lengths( min_seqlen_tensor, max_seqlen_tensor, ) # type: ignore[return-value] + + +def nested_from_padded( + padded, offsets, ragged_idx=1, min_seqlen=None, max_seqlen=None, sum_S=None +): + if ragged_idx != 1: + raise RuntimeError("nested_from_padded(): only ragged_idx=1 supported for now") + + min_seqlen_tensor = None + if min_seqlen is not None: + min_seqlen_tensor = _store_val_in_tensor(min_seqlen) + + max_seqlen_tensor = None + if max_seqlen is not None: + max_seqlen_tensor = _store_val_in_tensor(max_seqlen) + + return torch._nested_from_padded_tensor( + padded, + offsets, + _nt_view_dummy(), + ragged_idx, + min_seqlen_tensor, + max_seqlen_tensor, + sum_S, + ) diff --git a/torch/nested/_internal/ops.py b/torch/nested/_internal/ops.py index ed9a54f9dca932..f8053d3bcf7baf 100644 --- a/torch/nested/_internal/ops.py +++ b/torch/nested/_internal/ops.py @@ -1553,6 +1553,98 @@ def all_default(func, *args, **kwargs): return func(inp._values) +@register_jagged_func( + torch.ops.aten.to_padded_tensor.default, "self: jt, padding: any, output_size: any?" +) +def to_padded_tensor_default(func, *args, **kwargs): + _, new_kwargs = normalize_function( # type: ignore[misc] + func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True + ) + + inp = new_kwargs.pop("input") + + # TODO: Handle the rest of output_size + output_size = new_kwargs["output_size"] + if output_size is not None: + max_seq_len = output_size[inp._ragged_idx] + else: + max_seq_len = inp._max_seqlen + + # only 2D values is supported by the underlying FBGEMM kernel so do shape + # gymnastics if needed + values = inp.values() + values_shape = values.shape + if values.dim() > 2: + values = values.flatten(start_dim=1) + elif values.dim() == 1: + values = values.unsqueeze(-1) + + padded_out = torch.ops.aten._jagged_to_padded_dense_forward( + values, + [inp._offsets], + [max_seq_len], + new_kwargs["padding"], + ) + + # shape gymnastics part 2 + if len(values_shape) > 2: + padded_out = padded_out.unflatten(-1, values_shape[1:]) + elif len(values_shape) == 1: + padded_out = padded_out.squeeze(-1) + + return padded_out + + +@register_jagged_func( + torch.ops.aten._nested_from_padded_tensor.default, + "padded: t, offsets: t, dummy: jt, ragged_idx: any?, min_seqlen: any?, max_seqlen: any?, sum_S: any?", +) +def _nested_from_padded_tensor_default(func, *args, **kwargs): + _, new_kwargs = normalize_function( # type: ignore[misc] + func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True + ) + + if new_kwargs["ragged_idx"] != 1: + raise RuntimeError( + "_nested_from_padded_tensor(): only ragged_idx=1 supported for jagged layout" + ) + + padded, offsets = new_kwargs["padded"], new_kwargs["offsets"] + + # non-3D padded is not supported by the underlying FBGEMM kernel so do shape gymnastics + padded_shape = padded.shape + if padded.dim() > 3: + padded = padded.flatten(start_dim=2) + elif padded.dim() < 3: + padded = padded.unsqueeze(-1) + + values = torch.ops.aten._padded_dense_to_jagged_forward( + padded, [offsets], new_kwargs["sum_S"] + ) + + # shape gymnastics part 2 + if len(padded_shape) > 3: + values = values.unflatten(-1, padded_shape[2:]) + elif len(padded_shape) < 3: + values = values.squeeze(-1) + + ragged_idx = new_kwargs["ragged_idx"] + min_seqlen = new_kwargs["min_seqlen"] + max_seqlen = new_kwargs["max_seqlen"] + metadata_cache = {} + if min_seqlen is not None: + metadata_cache["min_seqlen"] = min_seqlen + if max_seqlen is not None: + metadata_cache["max_seqlen"] = max_seqlen + + return NestedTensor( + values, + offsets, + _ragged_idx=ragged_idx, + _metadata_cache=metadata_cache, + ) + + @register_jagged_func( torch.ops.aten._nested_view_from_jagged.default, "values: t, offsets: t, dummy: jt_all, lengths: t?, ragged_idx: any?, min_seqlen: t?, max_seqlen: t?", From c67fe13380ed866b57ae9bd0a334e86504f2deed Mon Sep 17 00:00:00 2001 From: Guilherme Leobas Date: Tue, 10 Sep 2024 12:38:32 -0300 Subject: [PATCH 0775/1018] Add batching rule for torch.scatter_reduce (#135547) Fixes #134797 Pull Request resolved: https://github.com/pytorch/pytorch/pull/135547 Approved by: https://github.com/zou3519 --- .../ATen/functorch/BatchRulesScatterOps.cpp | 24 ++++++++++++++++++ test/functorch/test_ops.py | 25 ------------------- test/functorch/test_vmap.py | 9 ------- 3 files changed, 24 insertions(+), 34 deletions(-) diff --git a/aten/src/ATen/functorch/BatchRulesScatterOps.cpp b/aten/src/ATen/functorch/BatchRulesScatterOps.cpp index 8626f4eb9fe4f9..e3e9a980f30b6c 100644 --- a/aten/src/ATen/functorch/BatchRulesScatterOps.cpp +++ b/aten/src/ATen/functorch/BatchRulesScatterOps.cpp @@ -779,6 +779,28 @@ std::tuple> scatter_reduce_batch_rule( self, self_bdim, dim, index, index_bdim, src, src_bdim, reduce); } +std::tuple> scatter_reduce_two_batch_rule( + const Tensor& self, std::optional self_bdim, + int64_t dim, + const Tensor& index, std::optional index_bdim, + const Tensor& src, std::optional src_bdim, + const c10::string_view reduce, + bool include_self) { + return scatter_batch_rule(ATEN_FN2(scatter_reduce, two), + self, self_bdim, dim, index, index_bdim, src, src_bdim, reduce, include_self); +} + +std::tuple> scatter_reduce__two_batch_rule( + const Tensor& self, std::optional self_bdim, + int64_t dim, + const Tensor& index, std::optional index_bdim, + const Tensor& src, std::optional src_bdim, + const c10::string_view reduce, + bool include_self) { + return scatter_batch_rule(ATEN_FN2(scatter_reduce_, two), + self, self_bdim, dim, index, index_bdim, src, src_bdim, reduce, include_self); +} + std::tuple> scatter_value_reduce_batch_rule( const Tensor& self, std::optional self_bdim, int64_t dim, @@ -1250,6 +1272,8 @@ TORCH_LIBRARY_IMPL(aten, FuncTorchBatched, m) { VMAP_SUPPORT(scatter_add, scatter_add_batch_rule); VMAP_SUPPORT2(scatter, reduce, scatter_reduce_batch_rule); VMAP_SUPPORT2(scatter, value_reduce, scatter_value_reduce_batch_rule); + VMAP_SUPPORT2(scatter_reduce, two, scatter_reduce_two_batch_rule); + VMAP_SUPPORT2(scatter_reduce_, two, scatter_reduce__two_batch_rule); // as_strided_scatter does not work with the for-loop fallback today, // because as_strided_scatter will return an output that matches // the strides/storage_offset of its input. diff --git a/test/functorch/test_ops.py b/test/functorch/test_ops.py index ad48b7581ea053..5cd18a072d8daf 100644 --- a/test/functorch/test_ops.py +++ b/test/functorch/test_ops.py @@ -1409,18 +1409,6 @@ def test_vmapjvpall(self, device, dtype, op): xfail("nn.functional.soft_margin_loss", ""), xfail("nn.functional.max_unpool1d", "grad"), xfail("nn.functional.embedding", ""), - xfail( - "scatter_reduce", "sum" - ), # aten::scatter_reduce.two hit the vmap fallback - xfail( - "scatter_reduce", "mean" - ), # aten::scatter_reduce.two hit the vmap fallback - xfail( - "scatter_reduce", "amin" - ), # aten::scatter_reduce.two hit the vmap fallback - xfail( - "scatter_reduce", "amax" - ), # aten::scatter_reduce.two hit the vmap fallback xfail("nn.functional.glu"), xfail("nn.functional.bilinear"), # trilinear doesn't have batching rule xfail("linalg.lu", ""), @@ -1492,18 +1480,6 @@ def test(): xfail("nanquantile"), xfail("ormqr"), xfail("put"), - xfail( - "scatter_reduce", "sum" - ), # aten::scatter_reduce.two hit the vmap fallback - xfail( - "scatter_reduce", "mean" - ), # aten::scatter_reduce.two hit the vmap fallback - xfail( - "scatter_reduce", "amin" - ), # aten::scatter_reduce.two hit the vmap fallback - xfail( - "scatter_reduce", "amax" - ), # aten::scatter_reduce.two hit the vmap fallback xfail("quantile"), xfail("renorm"), xfail("take"), @@ -1530,7 +1506,6 @@ def test(): xfail("nn.functional.multi_margin_loss", ""), xfail("nn.functional.multilabel_margin_loss", ""), xfail("nn.functional.pdist", ""), - xfail("scatter_reduce", "prod"), xfail("nn.functional.max_unpool1d", ""), xfail("nn.functional.max_unpool3d", ""), xfail("nn.functional.max_unpool3d", "grad"), diff --git a/test/functorch/test_vmap.py b/test/functorch/test_vmap.py index 31bae28eaca7a2..1bfc31fe521bb5 100644 --- a/test/functorch/test_vmap.py +++ b/test/functorch/test_vmap.py @@ -4427,10 +4427,6 @@ def test_vmap_exhaustive(self, device, dtype, op): # TODO: implement batching rule xfail("_batch_norm_with_update"), xfail("histogram"), - xfail("scatter_reduce", "sum"), - xfail("scatter_reduce", "mean"), - xfail("scatter_reduce", "amax"), - xfail("scatter_reduce", "amin"), # `index_put` OpInfo in pytorch/pytorch has # masked index as input which is not supported xfail("index_put", ""), @@ -4502,20 +4498,15 @@ def test_vmap_exhaustive(self, device, dtype, op): ), # Batching rule not implemented for aten::narrow.Tensor xfail("nn.functional.triplet_margin_loss", ""), xfail("nn.functional.pdist", ""), - xfail("scatter_reduce", "sum"), - xfail("scatter_reduce", "amax"), xfail("nn.functional.max_unpool1d", "grad"), xfail("nn.functional.multi_margin_loss", ""), - xfail("scatter_reduce", "prod"), xfail("nn.functional.multilabel_margin_loss", ""), - xfail("scatter_reduce", "amin"), xfail("nn.functional.max_unpool3d", "grad"), xfail("nn.functional.max_unpool2d", ""), xfail("nn.functional.max_unpool2d", "grad"), xfail("nn.functional.margin_ranking_loss", ""), xfail("nn.functional.max_unpool1d", ""), xfail("nn.functional.soft_margin_loss", ""), - xfail("scatter_reduce", "mean"), xfail("nn.functional.max_unpool3d", ""), xfail("linalg.ldl_solve", "", device_type="cpu"), xfail("chalf", ""), From 6bfcaa60de916d3a0f1018054eef134547f57536 Mon Sep 17 00:00:00 2001 From: Riley Dulin Date: Thu, 12 Sep 2024 18:52:13 +0000 Subject: [PATCH 0776/1018] [torch][fx] Add new replacement_callback to materialize a replacement just in time (#135553) Summary: Sometimes we only want to generate a replacement for a matched pattern once we know some information about the nodes in the pattern. So far, we have found this the most useful to do matches based on specific shapes of tensors flowing into functions. Use a callback function similar to `match_filters`. By default this isn't used. Had to make `replacement` a None-able parameter because Callable was already used to detect a case where a graph needed to be traced. Differential Revision: D62412628 Pull Request resolved: https://github.com/pytorch/pytorch/pull/135553 Approved by: https://github.com/SherlockNoMad --- test/fx/test_subgraph_rewriter.py | 35 ++++++++++++++++++++++++++++ torch/fx/subgraph_rewriter.py | 38 +++++++++++++++++++++---------- 2 files changed, 61 insertions(+), 12 deletions(-) diff --git a/test/fx/test_subgraph_rewriter.py b/test/fx/test_subgraph_rewriter.py index a4e14fbfab4488..7f23e706216a10 100644 --- a/test/fx/test_subgraph_rewriter.py +++ b/test/fx/test_subgraph_rewriter.py @@ -980,3 +980,38 @@ def check_replacement_nodes(self, traced, matches): return len(replacement_nodes_in_graph) self.assertEqual(check_replacement_nodes(self, traced, matches), 2) + + def test_replace_pattern_with_callback(self) -> None: + class M(torch.nn.Module): + def forward(self, x, y): + return torch.add(x, y) + + def pattern(x, y): + return torch.add(x, y) + + def replacement(x, y): + return torch.sub(torch.mul(x, y), y) + + traced = symbolic_trace(M()) + # Return the same replacement graph for all matches, but have it be a unique + # object each time. + matches = subgraph_rewriter.replace_pattern_with_filters( + traced, + pattern, + replacement_callback=lambda *args: symbolic_trace(replacement).graph, + ) + + def check_replacement_nodes(self, traced, matches): + replacement_nodes_in_graph = [ + node + for node in traced.graph.nodes + if node.target in {torch.sub, torch.mul} + ] + replacement_nodes_in_res = [r for m in matches for r in m.replacements] + self.assertEqual( + len(replacement_nodes_in_graph), len(replacement_nodes_in_res) + ) + self.assertEqual(replacement_nodes_in_graph, replacement_nodes_in_res) + return len(replacement_nodes_in_graph) + + self.assertEqual(check_replacement_nodes(self, traced, matches), 2) diff --git a/torch/fx/subgraph_rewriter.py b/torch/fx/subgraph_rewriter.py index 7f2cb743d2cdd9..c0d88821d7fafc 100644 --- a/torch/fx/subgraph_rewriter.py +++ b/torch/fx/subgraph_rewriter.py @@ -207,9 +207,11 @@ def forward(self, x, w1, w2): def replace_pattern_with_filters( gm: GraphModule, pattern: Union[Callable, Graph, GraphModule], - replacement: Union[Callable, Graph, GraphModule], + replacement: Union[Callable, Graph, GraphModule, None] = None, match_filters: Optional[List[Callable[["InternalMatch", Graph, Graph], bool]]] = None, ignore_literals: bool = False, + # Placed at the end to avoid breaking backward compatibility + replacement_callback: Optional[Callable[["InternalMatch", Graph, Graph], Graph]] = None, ) -> List[ReplacedPatterns]: """ See replace_pattern for documentation. This function is an overload with an additional match_filter argument. @@ -219,17 +221,22 @@ def replace_pattern_with_filters( (match: InternalMatch, original_graph: Graph, pattern_graph: Graph) and return a boolean indicating whether the match satisfies the condition. See matcher_utils.py for definition of InternalMatch. + ``replacement_callback``: A function that takes in a match and returns a + Graph to be used as the replacement. This allows you to construct a + replacement graph based on the match. """ - return _replace_pattern(gm, pattern, replacement, match_filters, ignore_literals) + return _replace_pattern(gm, pattern, replacement, match_filters, ignore_literals, replacement_callback) def _replace_pattern( gm: GraphModule, pattern: Union[Callable, Graph, GraphModule], - replacement: Union[Callable, Graph, GraphModule], + replacement: Union[Callable, Graph, GraphModule, None] = None, match_filters: Optional[List[Callable[["InternalMatch", Graph, Graph], bool]]] = None, ignore_literals: bool = False, + # Placed at the end to avoid breaking backward compatibility + replacement_callback: Optional[Callable[["InternalMatch", Graph, Graph], Graph]] = None, ) -> List[ReplacedPatterns]: from torch.fx.passes.utils.matcher_utils import SubgraphMatcher, InternalMatch @@ -247,13 +254,6 @@ def _replace_pattern( else: pattern_graph = symbolic_trace(pattern).graph - if isinstance(replacement, GraphModule): - replacement_graph = replacement.graph - elif isinstance(replacement, Graph): - replacement_graph = replacement - else: - replacement_graph = symbolic_trace(replacement).graph - matcher = SubgraphMatcher(pattern_graph, match_output=False, match_placeholder=False, remove_overlapping_matches=True, ignore_literals=ignore_literals) _matches: List[InternalMatch] = matcher.match(original_graph) @@ -265,13 +265,27 @@ def _replace_pattern( for match_filter in match_filters) ] - replacement_placeholders = [n for n in replacement_graph.nodes if n.op == "placeholder"] + if isinstance(replacement, GraphModule): + common_replacement_graph = replacement.graph + elif isinstance(replacement, Graph): + common_replacement_graph = replacement + elif callable(replacement): + common_replacement_graph = symbolic_trace(replacement).graph + else: + assert replacement_callback is not None, "Must provide either a replacement GraphModule or a replacement callback" + common_replacement_graph = None # As we progressively replace nodes, we'll need to keep track of how the match results should change match_changed_node: Dict[Node, Node] = {} match_and_replacements = [] - for match in _matches: + for i, match in enumerate(_matches): + if replacement_callback is not None: + replacement_graph = replacement_callback(match, original_graph, pattern_graph) + else: + assert common_replacement_graph is not None, "Must provide either a replacement GraphModule or a replacement callback" + replacement_graph = common_replacement_graph + replacement_placeholders = [n for n in replacement_graph.nodes if n.op == "placeholder"] # Build connecting between replacement graph's input and original graph input producer node From 0519f2d8f45b934ca84b651a92b5c7c6c3159a0d Mon Sep 17 00:00:00 2001 From: Fadi Arafeh Date: Thu, 12 Sep 2024 20:30:20 +0000 Subject: [PATCH 0777/1018] Pass ideep:lowp_kind to matmul_forward::compute on cache misses (#135058) Optimized dynamic quantization for aarch64 was enabled by #126687 and #134897 This PR fixes an issue for aarch64 where on a [cache miss](https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/quantized/cpu/qlinear_dynamic.cpp#L592) (e.g. if input dimensions change) [ideep::matmul_forward::compute ](https://github.com/intel/ideep/blob/pytorch-rls-v3.5.3-2/include/ideep/operators/matmul.hpp#L160) (wrongly) runs with the [default lowp_kind (u8s8)](https://github.com/intel/ideep/blob/pytorch-rls-v3.5.3-2/include/ideep/operators/matmul.hpp#L174) which is not supported by oneDNN+ACL (Arm Compute Library), causing the workload to fall back to a much slower oneDNN gemm:jit kernel Example: ```python import torch DIM = 4096 INPUT_SIZE1 = 32 INPUT_SIZE2 = 16 class LinearNet(torch.nn.Module): def __init__(self): super().__init__() self.fc1 = torch.nn.Linear(DIM, DIM, bias=False) def forward(self, x): x = self.fc1(x) return x input1 = torch.randn(size=(INPUT_SIZE1, DIM)) input2 = torch.randn(size=(INPUT_SIZE2, DIM)) with torch.no_grad(): model = LinearNet() model = torch.ao.quantization.quantize_dynamic(model,{torch.nn.Linear}) model(input1) # this goes to ACL lowp_gemm print("="*50) model(input2) # this goes to gemm:jit without this PR, and to ACL with this PR ``` In the code snippet above: - The matmul from `model(input1)` goes to oneDNN+ACL (in both cases, with and without the PR) - The matmul from `model(input2)`: **Without this PR**: there's a cache miss (different input shapes) and matmul_forward::compute is run with the default lowp_kind (u8s8). Hence the matmul falls back to gemm:jit in oneDNN. However, **With this PR** the matmul goes to oneDNN+ACL which is around 10x faster than oneDNN+jit. Pull Request resolved: https://github.com/pytorch/pytorch/pull/135058 Approved by: https://github.com/jondea, https://github.com/malfet --- .../native/quantized/cpu/qlinear_dynamic.cpp | 19 +++++++++++++++---- 1 file changed, 15 insertions(+), 4 deletions(-) diff --git a/aten/src/ATen/native/quantized/cpu/qlinear_dynamic.cpp b/aten/src/ATen/native/quantized/cpu/qlinear_dynamic.cpp index b78d3cc9a4f566..8a100d516446b8 100644 --- a/aten/src/ATen/native/quantized/cpu/qlinear_dynamic.cpp +++ b/aten/src/ATen/native/quantized/cpu/qlinear_dynamic.cpp @@ -590,10 +590,21 @@ at::Tensor PackedLinearWeightsOnednn::apply_dynamic_impl( LinearParams& params = get_cache().get_param(); ideep::matmul_forward::compute(params, x, w, b, y, src_scales, src_zero_point); } else { - ideep::matmul_forward::compute(x, w, b, y, - src_scales, weights_scales, ideep::scale_t(), - src_zero_point, ideep::zero_point_t(), - 1.0f, 1.0f, op_attr); + ideep::matmul_forward::compute( + x, + w, + b, + y, + src_scales, + weights_scales, + ideep::scale_t(), + src_zero_point, + ideep::zero_point_t(), + 1.0f, + 1.0f, + op_attr, + ideep::tensor::data_type::undef, + std::is_signed_v ? ideep::s8s8 : ideep::u8s8); } auto out_sizes = input.sizes().vec(); out_sizes.back() = w.get_dim(1); From 925767073133c0d6c48da01f5e19ee2c6cd251a7 Mon Sep 17 00:00:00 2001 From: Pian Pawakapan Date: Thu, 12 Sep 2024 21:22:19 +0000 Subject: [PATCH 0778/1018] [export] ignore mark_dynamic() in export (#135536) Previously we were accomodating `torch._dynamo.mark_dynamic()` for export's dynamic shapes. Here we clean things up and ignore it, requiring users to specify an export input for `dynamic_shapes`. Note: there's 4 decorators relevant to export, `mark_dynamic, maybe_mark_dynamic, mark_static, mark_unbacked`. User calls that involve export have only been `mark_dynamic()`, and we use `maybe_mark_dynamic` under the hood for `Dim.AUTO`, but we could start using others. One reason I decided to not warn and just silently ignore is these decorators cause the tensors to carry dynamic info, and it'll be hard to tell whether the markers are from export or user calls when re-exporting with the same inputs. Pull Request resolved: https://github.com/pytorch/pytorch/pull/135536 Approved by: https://github.com/avikchaudhuri --- benchmarks/dynamo/common.py | 28 +++++++++++++++++++++++++++- test/dynamo/test_export.py | 17 ----------------- test/export/test_export.py | 30 ------------------------------ test/inductor/test_aot_inductor.py | 8 ++++++-- torch/export/dynamic_shapes.py | 20 +++++--------------- 5 files changed, 38 insertions(+), 65 deletions(-) diff --git a/benchmarks/dynamo/common.py b/benchmarks/dynamo/common.py index b3de936bad1b48..80e7d74cef7701 100644 --- a/benchmarks/dynamo/common.py +++ b/benchmarks/dynamo/common.py @@ -1338,6 +1338,16 @@ def try_script(model, example_inputs): return None +def _produce_dynamic_shapes_for_export(path, x): + # mark_dynamic() is ignored for export. + # use this to produce dynamic_shapes spec instead. + from torch.export.dynamic_shapes import Dim + + if not isinstance(x, torch.Tensor): + return None + return {i: Dim.AUTO for i in getattr(x, "_dynamo_dynamic_indices", {})} + + class AOTInductorModelCache: cache = {} @@ -1345,6 +1355,7 @@ class AOTInductorModelCache: def load(cls, model, example_inputs, device): import torch._inductor import torch.export._trace + from torch.export.dynamic_shapes import _tree_map_with_path key = weakref.ref(model) if key not in cls.cache: @@ -1364,10 +1375,16 @@ def load(cls, model, example_inputs, device): else: _register_dataclass_output_as_pytree(example_outputs) + combined_args = tuple(example_args) + tuple(example_kwargs.values()) + dynamic_shapes = _tree_map_with_path( + _produce_dynamic_shapes_for_export, combined_args + ) + gm = torch.export._trace._export( model, example_args, example_kwargs, + dynamic_shapes=dynamic_shapes, pre_dispatch=True, strict=False, ).module() @@ -1382,11 +1399,20 @@ def load(cls, model, example_inputs, device): def export(model, example_inputs): + from torch.export.dynamic_shapes import _tree_map_with_path + example_args, example_kwargs = _normalize_bench_inputs(example_inputs) example_outputs = model(*example_args, **example_kwargs) _register_dataclass_output_as_pytree(example_outputs) - ep = torch.export.export(model, example_args, example_kwargs) + combined_args = tuple(example_args) + tuple(example_kwargs.values()) + dynamic_shapes = _tree_map_with_path( + _produce_dynamic_shapes_for_export, combined_args + ) + + ep = torch.export.export( + model, example_args, example_kwargs, dynamic_shapes=dynamic_shapes + ) def opt_export(_, example_inputs): example_args, example_kwargs = _normalize_bench_inputs(example_inputs) diff --git a/test/dynamo/test_export.py b/test/dynamo/test_export.py index e0ee9fab6d1375..03d40e377335b4 100644 --- a/test/dynamo/test_export.py +++ b/test/dynamo/test_export.py @@ -3048,23 +3048,6 @@ def forward(self, *args): res = gm(input_tensor, input_tensor2) self.assertTrue(torch._dynamo.utils.same(ref, res)) - def test_export_mark_dynamic_conflict_dynamic_dim(self): - y = torch.randn([3, 3, 3]) - - def my_dyn_fn(x): - if x.shape[0] > 3: - return x.sin() - return x.cos() - - torch._dynamo.mark_dynamic(y, 0) - with self.assertRaisesRegex( - RuntimeError, - "Constraints violated", - ): - torch._dynamo.export( - my_dyn_fn, dynamic_shapes=({0: torch.export.Dim("dim")},) - )(y) - def test_export_dynamic_dim_cleanup(self): y = torch.randn([3, 3, 3]) diff --git a/test/export/test_export.py b/test/export/test_export.py index 674aa0060dced3..8896ef3983ac39 100644 --- a/test/export/test_export.py +++ b/test/export/test_export.py @@ -2619,37 +2619,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: ): em.module()(x) - def test_mark_and_auto_dynamic(self): - # for this use case, mark_dynamic() and AUTO should have same effect. - # check that same symbol gets allocated to both dims without raising constraint violation. - AUTO, STATIC = Dim.AUTO, Dim.STATIC - - class Foo(torch.nn.Module): - def forward(self, x, y): - torch._check(x.shape[0] == y.shape[0]) - torch._check(x.shape[0] <= 64) - return x + 2, y + 2 - - inputs = (torch.randn(4, 4), torch.randn(4, 4)) - ep_auto = torch.export.export( - Foo(), inputs, dynamic_shapes={"x": (AUTO, None), "y": (AUTO, None)} - ) - torch._dynamo.mark_dynamic(inputs[0], 0) - torch._dynamo.mark_dynamic(inputs[1], 0) - ep_dynamic = torch.export.export(Foo(), inputs) - - # test both programs have same effect - for ep in [ep_auto, ep_dynamic]: - gm = ep.module() - gm(torch.randn(32, 4), torch.randn(32, 4)) - gm(torch.randn(1, 4), torch.randn(1, 4)) - with self.assertRaises(RuntimeError): - gm(torch.randn(33, 4), torch.randn(32, 4)) - gm(torch.randn(128, 4), torch.randn(128, 4)) - def test_dont_duck_size_for_auto_dynamic(self): - # for this use case, mark_dynamic() and AUTO should have same effect. - # check that same symbol gets allocated to both dims without raising constraint violation. AUTO, STATIC = Dim.AUTO, Dim.STATIC class Foo(torch.nn.Module): diff --git a/test/inductor/test_aot_inductor.py b/test/inductor/test_aot_inductor.py index 42d771f7bb6fff..2c85c488b2bbaa 100644 --- a/test/inductor/test_aot_inductor.py +++ b/test/inductor/test_aot_inductor.py @@ -1336,15 +1336,19 @@ def forward(self, x, b): return x + b example_inputs = ( - x := torch.randn((3, 2), device=self.device), + torch.randn((3, 2), device=self.device), torch.randn((1, 2), device=self.device), ) - torch._dynamo.mark_dynamic(x, index=0) # Create dynamic symbol + dynamic_shapes = { + "x": {0: Dim("dx"), 1: Dim.STATIC}, + "b": None, + } # Compile & run model where dynamic dim size > 0. so_path: str = AOTIRunnerUtil.compile( Repro(), example_inputs, + dynamic_shapes=dynamic_shapes, ) aot_inductor_module = AOTIRunnerUtil.load("cuda", so_path) aot_inductor_module(*example_inputs) diff --git a/torch/export/dynamic_shapes.py b/torch/export/dynamic_shapes.py index 18fc1e73f0cf05..a84467767ac1eb 100644 --- a/torch/export/dynamic_shapes.py +++ b/torch/export/dynamic_shapes.py @@ -846,21 +846,15 @@ def _tree_map_helper(tree, val): combined_args = type(dynamic_shapes)(combined_args.values()) # type: ignore[assignment, misc] def transform_shapes(path, tensor, shape): - def _marked_dynamic(tensor, i): - # TODO(pianpwk): deprecate mark_dynamic() usage for export - return i in getattr(tensor, "_dynamo_dynamic_indices", set()) - out: Union[None, List[Any], Dict[int, Any]] = None if isinstance(shape, dict): out = {} for i, val in enumerate(tensor.shape): dim = shape.get(i, _DimHint.STATIC) - if _marked_dynamic(tensor, i) or dim == _DimHint.AUTO: + if dim == _DimHint.AUTO: # don't have to specify anything if dynamic # None also works, since assume_static_by_default=False - if dim == _DimHint.AUTO: - torch._dynamo.maybe_mark_dynamic(tensor, i) # avoid duck sizing - continue + torch._dynamo.maybe_mark_dynamic(tensor, i) # avoid duck sizing elif isinstance(dim, _Dim): out[i] = dim elif isinstance(dim, int): @@ -878,9 +872,8 @@ def _marked_dynamic(tensor, i): out = [] for i, val in enumerate(tensor.shape): dim = shape[i] - if _marked_dynamic(tensor, i) or dim == _DimHint.AUTO: - if dim == _DimHint.AUTO: - torch._dynamo.maybe_mark_dynamic(tensor, i) # avoid duck sizing + if dim == _DimHint.AUTO: + torch._dynamo.maybe_mark_dynamic(tensor, i) # avoid duck sizing out.append(None) elif isinstance(dim, _Dim): out.append(dim) @@ -896,10 +889,7 @@ def _marked_dynamic(tensor, i): else: assert shape is None if isinstance(tensor, torch.Tensor): - out = [] - for i, val in enumerate(tensor.shape): - out.append(None if _marked_dynamic(tensor, i) else val) - out = out or None + out = list(tensor.shape) or None else: out = None return out From 396b4440fb2c39a23c945bc69b3be9282bdd500d Mon Sep 17 00:00:00 2001 From: Sanskar Modi Date: Thu, 12 Sep 2024 21:28:37 +0000 Subject: [PATCH 0779/1018] Validate input types for `torch.nn.Linear` and `torch.nn.Bilinear` (#135596) Adding validation checks to check the input types and display better error messages for the same. Fixes https://github.com/pytorch/pytorch/issues/135463 Pull Request resolved: https://github.com/pytorch/pytorch/pull/135596 Approved by: https://github.com/malfet --- torch/nn/modules/linear.py | 20 ++++++++ torch/testing/_internal/common_modules.py | 62 +++++++++++++++++++++++ 2 files changed, 82 insertions(+) diff --git a/torch/nn/modules/linear.py b/torch/nn/modules/linear.py index dc5185b7eec0bf..aa4e67c827ae4f 100644 --- a/torch/nn/modules/linear.py +++ b/torch/nn/modules/linear.py @@ -98,6 +98,13 @@ def __init__( device=None, dtype=None, ) -> None: + # validation checks for input types + if not isinstance(in_features, int): + raise TypeError(f"Expected int for in_features but got {type(in_features)}") + if not isinstance(out_features, int): + raise TypeError( + f"Expected int for out_features but got {type(out_features)}" + ) factory_kwargs = {"device": device, "dtype": dtype} super().__init__() self.in_features = in_features @@ -200,6 +207,19 @@ def __init__( device=None, dtype=None, ) -> None: + # validation checks for input types + if not isinstance(in1_features, int): + raise TypeError( + f"Expected int for in1_features but got {type(in1_features)}" + ) + if not isinstance(in2_features, int): + raise TypeError( + f"Expected int for in2_features but got {type(in2_features)}" + ) + if not isinstance(out_features, int): + raise TypeError( + f"Expected int for out_features but got {type(out_features)}" + ) factory_kwargs = {"device": device, "dtype": dtype} super().__init__() self.in1_features = in1_features diff --git a/torch/testing/_internal/common_modules.py b/torch/testing/_internal/common_modules.py index 63963bab1b050a..238b67e476713d 100644 --- a/torch/testing/_internal/common_modules.py +++ b/torch/testing/_internal/common_modules.py @@ -3180,6 +3180,66 @@ def padding3d_circular_ref(inp, pad): # Start of module error inputs functions. +def module_error_inputs_torch_nn_Linear(module_info, device, dtype, requires_grad, training, **kwargs): + make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + samples = [ + ErrorModuleInput( + ModuleInput( + constructor_input=FunctionInput("10", 20), + forward_input=FunctionInput(make_input(3, 10)), + ), + error_on=ModuleErrorEnum.CONSTRUCTION_ERROR, + error_type=TypeError, + error_regex=r"Expected int for in_features but got " + ), + ErrorModuleInput( + ModuleInput( + constructor_input=FunctionInput(10, 20.7), + forward_input=FunctionInput(make_input(3, 10)), + ), + error_on=ModuleErrorEnum.CONSTRUCTION_ERROR, + error_type=TypeError, + error_regex=r"Expected int for out_features but got " + ), + ] + return samples + + +def module_error_inputs_torch_nn_Bilinear(module_info, device, dtype, requires_grad, training, **kwargs): + make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + samples = [ + ErrorModuleInput( + ModuleInput( + constructor_input=FunctionInput("10", 20, 30), + forward_input=FunctionInput(make_input(3, 10), make_input(3, 20)), + ), + error_on=ModuleErrorEnum.CONSTRUCTION_ERROR, + error_type=TypeError, + error_regex=r"Expected int for in1_features but got " + ), + ErrorModuleInput( + ModuleInput( + constructor_input=FunctionInput(10, 20.7, 30), + forward_input=FunctionInput(make_input(3, 10), make_input(3, 20)), + ), + error_on=ModuleErrorEnum.CONSTRUCTION_ERROR, + error_type=TypeError, + error_regex=r"Expected int for in2_features but got " + ), + ErrorModuleInput( + ModuleInput( + constructor_input=FunctionInput(10, 20, "30"), + forward_input=FunctionInput(make_input(3, 10), make_input(3, 20)), + ), + error_on=ModuleErrorEnum.CONSTRUCTION_ERROR, + error_type=TypeError, + error_regex=r"Expected int for out_features but got " + ), + ] + return samples + + + def module_error_inputs_torch_nn_RNN_GRU_Cell(module_info, device, dtype, requires_grad, training, **kwargs): make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) samples = [ @@ -3832,12 +3892,14 @@ def module_error_inputs_torch_nn_Pad3d(module_info, device, dtype, requires_grad )), ModuleInfo(torch.nn.Linear, module_inputs_func=module_inputs_torch_nn_Linear, + module_error_inputs_func=module_error_inputs_torch_nn_Linear, skips=( # No channels_last support for Linear currently. DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format'),) ), ModuleInfo(torch.nn.Bilinear, module_inputs_func=module_inputs_torch_nn_Bilinear, + module_error_inputs_func=module_error_inputs_torch_nn_Bilinear, decorators=[ DecorateInfo( toleranceOverride({ From e2bcb9dc647c8b7b2314a5c16c4efcdb7f81f91f Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Thu, 12 Sep 2024 21:47:38 +0000 Subject: [PATCH 0780/1018] Revert "[DSD] Fix distributed state dict full_state_dict option hang during set_state_dict (#135725)" This reverts commit 83c594ebd6dfa517fdd67ae23929cc60d5fa325d. Reverted https://github.com/pytorch/pytorch/pull/135725 on behalf of https://github.com/ZainRizvi due to This is breaking lint. See [GH job link](https://github.com/pytorch/pytorch/actions/runs/10835983999/job/30068709508) [HUD commit link](https://hud.pytorch.org/pytorch/pytorch/commit/83c594ebd6dfa517fdd67ae23929cc60d5fa325d) ([comment](https://github.com/pytorch/pytorch/pull/135725#issuecomment-2347303272)) --- .../test_2d_composability.py | 58 ------------------- torch/distributed/_state_dict_utils.py | 11 +--- 2 files changed, 2 insertions(+), 67 deletions(-) diff --git a/test/distributed/_composable/test_composability/test_2d_composability.py b/test/distributed/_composable/test_composability/test_2d_composability.py index c4e89fe3421633..754705b567a324 100644 --- a/test/distributed/_composable/test_composability/test_2d_composability.py +++ b/test/distributed/_composable/test_composability/test_2d_composability.py @@ -17,9 +17,7 @@ from torch.distributed.checkpoint.state_dict import ( get_model_state_dict, get_optimizer_state_dict, - set_model_state_dict, set_optimizer_state_dict, - StateDictOptions, ) from torch.distributed.device_mesh import DeviceMesh from torch.distributed.fsdp import FullyShardedDataParallel as FSDP @@ -337,57 +335,6 @@ def parallelize(_model: Transformer, mesh: DeviceMesh, use_seq_parallel: bool): self.assertEqual(loss_no_cp2, loss_cp2) -class TestFullyShard2DStateDict(DTensorTestBase): - @property - def backend(self): - # need to specify gloo backend for testing cpu offload - return "cpu:gloo,cuda:nccl" - - @with_comms - @skip_if_lt_x_gpu(4) - def test_fully_shard_tp_2d_set_full_state_dict(self): - dummy_model = SimpleModel().cuda() - mesh_2d = init_device_mesh( - "cuda", - (2, self.world_size // 2), - mesh_dim_names=("dp", "tp"), - ) - tp_mesh = mesh_2d["tp"] - dp_mesh = mesh_2d["dp"] - parallelize_plan = { - "net1": ColwiseParallel(), - "net2": RowwiseParallel(), - "net3": ColwiseParallel(), - } - model = parallelize_module(dummy_model, tp_mesh, parallelize_plan) - fully_shard(model, mesh=dp_mesh) - optim = torch.optim.Adam(model.parameters(), lr=0.01) - model(model.get_input()).sum().backward() - optim.step() - # ref_msd, ref_osd are both the default sharded state dict - ref_msd = copy.deepcopy(get_model_state_dict(model)) - ref_osd = copy.deepcopy(get_optimizer_state_dict(model, optimizers=optim)) - - options = StateDictOptions( - full_state_dict=True, cpu_offload=True, broadcast_from_rank0=True - ) - full_msd = get_model_state_dict(model, options=options) - full_osd = get_optimizer_state_dict(model, optimizers=optim, options=options) - # load full_msd and full_osd into model and optim. - # this loads the slice of full tensor into each rank's local DTensor. - set_model_state_dict(model, full_msd, options=options) - set_optimizer_state_dict( - model, optimizers=optim, optim_state_dict=full_osd, options=options - ) - - # check after setting full state dict, the model and optim default sharded state dict - # are the same as the initial default sharded state dict. - new_msd = get_model_state_dict(model) - new_osd = get_optimizer_state_dict(model, optimizers=optim) - self.assertEqual(ref_msd, new_msd) - self.assertEqual(ref_osd, new_osd) - - class Test2dFSDP1ParallelIntegration(DTensorTestBase): def init_model(self, device_type, model_parallel_size=2): torch.manual_seed(0) @@ -597,11 +544,6 @@ def test_2d_e2e_training_not_use_orig_params(self): # TODO: update all state dict unit tests to use distributed.checkpoint.state_dict, # and consolidate all the state_dict test in test.distributed.checkpoint. class TestNew2dParallelStateDict(DTensorTestBase): - @property - def backend(self): - # need to specify gloo backend for testing cpu offload - return "cpu:gloo,cuda:nccl" - @with_comms @skip_if_lt_x_gpu(4) def test_fsdp_2d_extension(self): diff --git a/torch/distributed/_state_dict_utils.py b/torch/distributed/_state_dict_utils.py index 9520b195f9ed5c..f22a13cb6ec0f4 100644 --- a/torch/distributed/_state_dict_utils.py +++ b/torch/distributed/_state_dict_utils.py @@ -28,7 +28,6 @@ from torch.distributed import distributed_c10d from torch.distributed._shard.sharded_tensor import ShardedTensor from torch.distributed.tensor import distribute_tensor, DTensor, Replicate - from torch.distributed.tensor._utils import compute_local_shape_and_global_offset def _identity_func( @@ -552,14 +551,8 @@ def _distribute_tensors( local_state = _local_state[0] full_tensor = _local_state[1] - - shape, offset = compute_local_shape_and_global_offset( - full_tensor.shape, local_state.device_mesh, local_state.placements - ) - slices = [slice(offset[i], shape[i] + offset[i]) for i in range(len(shape))] - local_tensor = full_tensor[slices] - local_state_dict[key] = DTensor.from_local( - local_tensor, local_state.device_mesh, local_state.placements + local_state_dict[key] = distribute_tensor( + full_tensor, local_state.device_mesh, local_state.placements ) From 9b1cdf5380d77d8021166f2c5f0a76aefd00a318 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Thu, 12 Sep 2024 22:04:35 +0000 Subject: [PATCH 0781/1018] Revert "Check function declarations of Core ML code (#135467)" This reverts commit bc1b8f094d24de27432f4c29f0729e85a6b5ba63. Reverted https://github.com/pytorch/pytorch/pull/135467 on behalf of https://github.com/malfet due to This breaks ios periodic jobs, see https://github.com/pytorch/pytorch/actions/runs/10797026668/job/29947377532 ([comment](https://github.com/pytorch/pytorch/pull/135467#issuecomment-2347322784)) --- caffe2/CMakeLists.txt | 2 +- torch/csrc/jit/backends/coreml/cpp/context.cpp | 13 +++++++++---- torch/csrc/jit/backends/coreml/cpp/context.h | 11 +++++++++-- 3 files changed, 19 insertions(+), 7 deletions(-) diff --git a/caffe2/CMakeLists.txt b/caffe2/CMakeLists.txt index e498f2380fbf2d..2160399a3ea296 100644 --- a/caffe2/CMakeLists.txt +++ b/caffe2/CMakeLists.txt @@ -767,7 +767,7 @@ if(NOT MSVC) set_source_files_properties(${PROJECT_SOURCE_DIR}/torch/csrc/distributed/c10d/socket.cpp PROPERTIES COMPILE_OPTIONS "-Wno-error=deprecated") endif() -if("${CMAKE_CXX_COMPILER_ID}" MATCHES "Clang") +if("${CMAKE_CXX_COMPILER_ID}" MATCHES "Clang" AND NOT USE_IOS AND NOT USE_COREML_DELEGATE) target_compile_options_if_supported(torch_cpu "-Wmissing-prototypes") target_compile_options_if_supported(torch_cpu "-Werror=missing-prototypes") get_target_property(TORCH_CPU_SOURCES torch_cpu SOURCES) diff --git a/torch/csrc/jit/backends/coreml/cpp/context.cpp b/torch/csrc/jit/backends/coreml/cpp/context.cpp index a8c385eef525a8..3c63acce711343 100644 --- a/torch/csrc/jit/backends/coreml/cpp/context.cpp +++ b/torch/csrc/jit/backends/coreml/cpp/context.cpp @@ -1,8 +1,10 @@ #include #include -#include -namespace torch::jit::mobile::coreml { +namespace torch { +namespace jit { +namespace mobile { +namespace coreml { std::atomic g_coreml_ctx_registry; @@ -13,8 +15,11 @@ BackendRegistrar::BackendRegistrar(ContextInterface* ctx) { void setModelCacheDirectory(std::string path) { auto p = g_coreml_ctx_registry.load(); if (p) { - p->setModelCacheDirectory(std::move(path)); + p->setModelCacheDirectory(path); } } -} // namespace torch::jit::mobile::coreml +} // namespace coreml +} // namespace mobile +} // namespace jit +} // namespace torch diff --git a/torch/csrc/jit/backends/coreml/cpp/context.h b/torch/csrc/jit/backends/coreml/cpp/context.h index a07a8b81fc7d28..644e3428af3e1d 100644 --- a/torch/csrc/jit/backends/coreml/cpp/context.h +++ b/torch/csrc/jit/backends/coreml/cpp/context.h @@ -1,9 +1,13 @@ #ifndef PTM_COREML_Context_h #define PTM_COREML_Context_h +#include #include -namespace torch::jit::mobile::coreml { +namespace torch { +namespace jit { +namespace mobile { +namespace coreml { struct ContextInterface { virtual ~ContextInterface() = default; @@ -17,6 +21,9 @@ class BackendRegistrar { void setModelCacheDirectory(std::string path); -} // namespace torch::jit::mobile::coreml +} // namespace coreml +} // namespace mobile +} // namespace jit +} // namespace torch #endif From aed43118c18bb4a9add6ef900f911fe108e8a1ea Mon Sep 17 00:00:00 2001 From: William Wen Date: Thu, 12 Sep 2024 11:23:05 -0700 Subject: [PATCH 0782/1018] [dynamo] support torch.nn.attention.sdpa_kernel context manager (#135404) Fixes https://github.com/pytorch/pytorch/issues/134608 Pull Request resolved: https://github.com/pytorch/pytorch/pull/135404 Approved by: https://github.com/jansel, https://github.com/drisspg --- test/dynamo/test_ctx_manager.py | 111 +++++++++++++++++++++++++ torch/_dynamo/variables/__init__.py | 1 + torch/_dynamo/variables/builder.py | 3 + torch/_dynamo/variables/ctx_manager.py | 74 +++++++++++++++++ torch/_dynamo/variables/torch.py | 11 +++ torch/nn/attention/__init__.py | 64 ++++++++------ 6 files changed, 238 insertions(+), 26 deletions(-) diff --git a/test/dynamo/test_ctx_manager.py b/test/dynamo/test_ctx_manager.py index 1b654e7228aa89..239174df715874 100644 --- a/test/dynamo/test_ctx_manager.py +++ b/test/dynamo/test_ctx_manager.py @@ -1533,6 +1533,117 @@ def fn(x): self.assertEqual(fn(x).requires_grad, opt_fn(x).requires_grad) self.assertEqual(cnts.frame_count, 2) + def test_sdpa_kernel_ctx_manager1(self): + modified_backend_state = [torch.nn.attention.SDPBackend.MATH] + + @torch._dynamo.allow_in_graph + def check_backend_state_is_modified(): + self.assertEqual( + torch.nn.attention._cur_sdpa_kernel_backends(), modified_backend_state + ) + + def f(x): + with torch.nn.attention.sdpa_kernel( + # pyre-fixme[16]: Module `torch.nn.attention` has no attribute `SDPBackend`. + [torch.nn.attention.SDPBackend.MATH] + ): + output = torch.nn.functional.scaled_dot_product_attention(x, x, x).to( + torch.float32 + ) + check_backend_state_is_modified() + + return output + + opt_f = torch.compile(f, backend="eager", fullgraph=True) + opt_f(torch.randn(2, 2, 2, 2).to(dtype=torch.float16)) + + def test_sdpa_kernel_ctx_manager2(self): + original_backend_state = set(torch.nn.attention._cur_sdpa_kernel_backends()) + modified_backend_state = [torch.nn.attention.SDPBackend.MATH] + + @torch._dynamo.allow_in_graph + def check_backend_state_is_original(): + self.assertEqual( + set(torch.nn.attention._cur_sdpa_kernel_backends()), + original_backend_state, + ) + + @torch._dynamo.allow_in_graph + def check_backend_state_is_modified(): + self.assertEqual( + torch.nn.attention._cur_sdpa_kernel_backends(), modified_backend_state + ) + + def g(x): + torch._dynamo.graph_break() + output = torch.nn.functional.scaled_dot_product_attention(x, x, x).to( + torch.float32 + ) + check_backend_state_is_modified() + return output + + def f(x): + check_backend_state_is_original() + with torch.nn.attention.sdpa_kernel( + # pyre-fixme[16]: Module `torch.nn.attention` has no attribute `SDPBackend`. + [torch.nn.attention.SDPBackend.MATH] + ): + output1 = torch.nn.functional.scaled_dot_product_attention(x, x, x).to( + torch.float32 + ) + check_backend_state_is_modified() + + # graph break + output2 = g(x) + + output3 = torch.nn.functional.scaled_dot_product_attention(x, x, x).to( + torch.float32 + ) + check_backend_state_is_modified() + + check_backend_state_is_original() + + return output1 + output2 + output3 + + cnts = torch._dynamo.testing.CompileCounter() + opt_f = torch.compile(f, backend=cnts) + opt_f(torch.randn(2, 2, 2, 2).to(dtype=torch.float16)) + self.assertEqual(cnts.frame_count, 3) + + # test sdpa_kernel graph break with 2 arguments + def test_sdpa_kernel_ctx_manager3(self): + modified_backend_state = { + torch.nn.attention.SDPBackend.MATH, + torch.nn.attention.SDPBackend.FLASH_ATTENTION, + } + + @torch._dynamo.allow_in_graph + def check_backend_state_is_modified(): + self.assertEqual( + set(torch.nn.attention._cur_sdpa_kernel_backends()), + modified_backend_state, + ) + + def f(x): + with torch.nn.attention.sdpa_kernel( + # pyre-fixme[16]: Module `torch.nn.attention` has no attribute `SDPBackend`. + [ + torch.nn.attention.SDPBackend.MATH, + torch.nn.attention.SDPBackend.FLASH_ATTENTION, + ] + ): + # FLASH_ATTENTION may not be supported, but we're not actually + # doing any sdpa + x = x + 1 + torch._dynamo.graph_break() + check_backend_state_is_modified() + x = x + 1 + + return x + + opt_f = torch.compile(f, backend="eager") + opt_f(torch.randn(2, 2)) + if __name__ == "__main__": from torch._dynamo.test_case import run_tests diff --git a/torch/_dynamo/variables/__init__.py b/torch/_dynamo/variables/__init__.py index d522d773e6e865..5a8522e68c4c06 100644 --- a/torch/_dynamo/variables/__init__.py +++ b/torch/_dynamo/variables/__init__.py @@ -14,6 +14,7 @@ GradModeVariable, InferenceModeVariable, JvpIncrementNestingCtxManagerVariable, + SDPAKernelVariable, SetFwdGradEnabledContextManager, StreamContextVariable, StreamVariable, diff --git a/torch/_dynamo/variables/builder.py b/torch/_dynamo/variables/builder.py index bf1cbed10b8e5c..5a8a9b613b899d 100644 --- a/torch/_dynamo/variables/builder.py +++ b/torch/_dynamo/variables/builder.py @@ -843,6 +843,9 @@ def build_key_value(i, k, v): elif isinstance(value, (torch._C._SDPAParams)): self.install_guards(GuardBuilder.TYPE_MATCH) return SDPAParamsVariable.create(self.tx, value, self.source) + elif isinstance(value, torch._C._SDPBackend): + self.install_guards(GuardBuilder.ID_MATCH) + return ConstantVariable(value) elif isinstance(value, _EventBase): self.install_guards(GuardBuilder.ID_MATCH) torch._dynamo.utils.store_user_object_weakref(value) diff --git a/torch/_dynamo/variables/ctx_manager.py b/torch/_dynamo/variables/ctx_manager.py index 301b7f3e819345..eebe90929b379a 100644 --- a/torch/_dynamo/variables/ctx_manager.py +++ b/torch/_dynamo/variables/ctx_manager.py @@ -977,6 +977,80 @@ def fn_name(self): return "use_training_state" +class SDPAKernelVariable(ContextWrappingVariable): + """represents torch.nn.attention.sdpa_kernel""" + + @staticmethod + def create(tx: "InstructionTranslator", backends, **kwargs): + if isinstance(backends, torch.nn.attention.SDPBackend): + backends = [backends] + var = SDPAKernelVariable( + target_values=backends, + initial_values=None, + **kwargs, + ) + return var + + def __init__( + self, + target_values: List[torch.nn.attention.SDPBackend], + initial_values=None, + **kwargs, + ) -> None: + super().__init__( + target_values=target_values, initial_values=initial_values, **kwargs + ) + + @staticmethod + def _backends_to_nodes(tx, backends): + nodes = [] + for backend in backends: + # convert to/from string in order to bake the backend into FX graph + nodes.append( + tx.output.create_node( + "call_function", + torch.nn.attention._backend_from_string, + (backend.name,), + {}, + ) + ) + return nodes + + def enter(self, tx): + self.prev_backends = torch.nn.attention._cur_sdpa_kernel_backends() + self.set_cleanup_hook( + tx, lambda: torch.nn.attention._sdpa_kernel(self.prev_backends) + ) + torch.nn.attention._sdpa_kernel(self.target_values) + arg = self._backends_to_nodes(tx, self.target_values) + tx.output.create_node( + "call_function", + torch.nn.attention._sdpa_kernel, + (arg,), + {}, + ) + return variables.ConstantVariable.create(None) + + def exit(self, tx: "InstructionTranslator", *args): + self.state.cleanup_assert() + arg = self._backends_to_nodes(tx, self.prev_backends) + tx.output.create_node( + "call_function", + torch.nn.attention._sdpa_kernel, + (arg,), + {}, + ) + return variables.ConstantVariable.create(None) + + def module_name(self): + return "torch.nn.attention" + + # use a private version of sdpa_kernel that accepts variadic arguments + # since dynamo reconstructs the contents of target_values one-by-one + def fn_name(self): + return "_sdpa_kernel_variadic" + + class StreamVariable(VariableTracker): def __init__(self, proxy, value, device, **kwargs) -> None: if proxy is not None and "example_value" in proxy.node.meta: diff --git a/torch/_dynamo/variables/torch.py b/torch/_dynamo/variables/torch.py index eca9d25adfe4ac..68508f001dae35 100644 --- a/torch/_dynamo/variables/torch.py +++ b/torch/_dynamo/variables/torch.py @@ -89,6 +89,8 @@ torch.autograd.graph.disable_saved_tensors_hooks, torch.cpu.amp.autocast_mode.autocast, torch.cuda.amp.autocast_mode.autocast, + torch.nn.attention.sdpa_kernel, + torch.nn.attention._sdpa_kernel_variadic, ] ) @@ -229,6 +231,7 @@ def call_function( GradModeVariable, InferenceModeVariable, JvpIncrementNestingCtxManagerVariable, + SDPAKernelVariable, SetFwdGradEnabledContextManager, StreamVariable, VmapIncrementNestingCtxManagerVariable, @@ -329,6 +332,14 @@ def call_function( return FSDPParamGroupUseTrainingStateVariable.create( tx, args[0], args[1].as_python_constant() ) + elif self.value is torch.nn.attention.sdpa_kernel: + assert len(args) == 1 or (len(kwargs) == 1 and "backends" in kwargs) + backends = args[0] if len(args) == 1 else kwargs["backends"] + return SDPAKernelVariable.create(tx, backends.as_python_constant()) + elif self.value is torch.nn.attention._sdpa_kernel_variadic: + return SDPAKernelVariable.create( + tx, [arg.as_python_constant() for arg in args] + ) return super().call_function(tx, args, kwargs) diff --git a/torch/nn/attention/__init__.py b/torch/nn/attention/__init__.py index 5567db5919c18c..8cb0a9ac2dea8e 100644 --- a/torch/nn/attention/__init__.py +++ b/torch/nn/attention/__init__.py @@ -1,21 +1,14 @@ # mypy: allow-untyped-defs """ This module contains functions and classes that alter the behavior of torch.nn.functional.scaled_dot_product_attention """ import contextlib -from typing import List, Union +from typing import Iterable, List, Union from warnings import warn +import torch.backends.cuda from torch._C import _SDPBackend as SDPBackend from torch.backends.cuda import ( can_use_efficient_attention, can_use_flash_attention, - cudnn_sdp_enabled, - enable_cudnn_sdp, - enable_flash_sdp, - enable_math_sdp, - enable_mem_efficient_sdp, - flash_sdp_enabled, - math_sdp_enabled, - mem_efficient_sdp_enabled, SDPAParams, ) @@ -67,6 +60,32 @@ def _raise_kernel_warnings(params: SDPAParams) -> None: can_use_flash_attention(params, True) +_backend_names = { + "cudnn": "CUDNN_ATTENTION", + "flash": "FLASH_ATTENTION", + "mem_efficient": "EFFICIENT_ATTENTION", + "math": "MATH", +} + + +def _backend_from_string(name: str): + return getattr(SDPBackend, name) + + +def _cur_sdpa_kernel_backends(): + backends: List[SDPBackend] = [] + for name, val in _backend_names.items(): + if getattr(torch.backends.cuda, f"{name}_sdp_enabled")(): + backends.append(getattr(SDPBackend, val)) + return backends + + +def _sdpa_kernel(backends: Iterable[SDPBackend]): + for name, val in _backend_names.items(): + enabled = getattr(SDPBackend, val) in backends + getattr(torch.backends.cuda, f"enable_{name}_sdp")(enabled) + + @contextlib.contextmanager def sdpa_kernel(backends: Union[List[SDPBackend], SDPBackend]): r""" @@ -102,26 +121,19 @@ def sdpa_kernel(backends: Union[List[SDPBackend], SDPBackend]): backends = [backends] backends = set(backends) - previous_cudnn: bool = cudnn_sdp_enabled() - previous_flash: bool = flash_sdp_enabled() - previous_mem_efficient: bool = mem_efficient_sdp_enabled() - previous_math: bool = math_sdp_enabled() + previous_backends = _cur_sdpa_kernel_backends() try: - enable_cudnn = SDPBackend.CUDNN_ATTENTION in backends - enable_flash = SDPBackend.FLASH_ATTENTION in backends - enable_mem_efficient = SDPBackend.EFFICIENT_ATTENTION in backends - enable_math = SDPBackend.MATH in backends - - enable_cudnn_sdp(enable_cudnn) - enable_flash_sdp(enable_flash) - enable_mem_efficient_sdp(enable_mem_efficient) - enable_math_sdp(enable_math) + _sdpa_kernel(backends) yield {} finally: - enable_cudnn_sdp(previous_cudnn) - enable_flash_sdp(previous_flash) - enable_mem_efficient_sdp(previous_mem_efficient) - enable_math_sdp(previous_math) + _sdpa_kernel(previous_backends) + + +# variadic version of sdpa_kernel for dynamo to use while reconstructing +@contextlib.contextmanager +def _sdpa_kernel_variadic(*backends: SDPBackend): + with sdpa_kernel(list(backends)): + yield def _get_flash_version() -> str: From 254d7d91d2fbd387bc22b9b1d149ab0b1ec06025 Mon Sep 17 00:00:00 2001 From: eellison Date: Thu, 12 Sep 2024 11:46:32 -0700 Subject: [PATCH 0783/1018] [Easy] Check if quant registered in constant folding (#135875) Belated fix for https://github.com/pytorch/pytorch/issues/110904 Pull Request resolved: https://github.com/pytorch/pytorch/pull/135875 Approved by: https://github.com/shunting314 --- torch/_inductor/constant_folding.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/torch/_inductor/constant_folding.py b/torch/_inductor/constant_folding.py index 72f34d32475f39..25d4b872d7737f 100644 --- a/torch/_inductor/constant_folding.py +++ b/torch/_inductor/constant_folding.py @@ -94,7 +94,12 @@ def is_impure(self, node: torch.fx.node.Node) -> bool: ): # For int8_weight -> dq -> bf16_weight return True - if node.target in [ + + quant_registered = ( + getattr(torch.ops.quantized_decomposed, "dequantize_per_channel", None) + is not None + ) + if quant_registered and node.target in [ torch.ops.quantized_decomposed.dequantize_per_channel.default, torch.ops.quantized_decomposed.dequantize_per_tensor.default, torch.ops.quantized_decomposed.dequantize_per_tensor.tensor, From 2e123c81aec9969b62e7131e65bb8bae69ef2c53 Mon Sep 17 00:00:00 2001 From: Jack Taylor Date: Thu, 12 Sep 2024 23:00:48 +0000 Subject: [PATCH 0784/1018] [ROCm] Use ieee precision for fp32 in flex attention (#135702) https://github.com/pytorch/pytorch/commit/3bebc09be9845c0779f190489e8d4caa9e2653c8 Brought in a change to flex_attention to allow TF32 precision, this largely lacks support on ROCm side and we should use ieee. Pull Request resolved: https://github.com/pytorch/pytorch/pull/135702 Approved by: https://github.com/jeffdaily, https://github.com/drisspg --- torch/_inductor/kernel/flex_attention.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch/_inductor/kernel/flex_attention.py b/torch/_inductor/kernel/flex_attention.py index 33113434e732e7..0e3d6ae49186df 100644 --- a/torch/_inductor/kernel/flex_attention.py +++ b/torch/_inductor/kernel/flex_attention.py @@ -56,7 +56,7 @@ def maybe_realize(args: List[Optional[IRNode]]): def get_float32_precision(): - if torch.get_float32_matmul_precision() == "highest": + if torch.get_float32_matmul_precision() == "highest" or torch.version.hip: return "'ieee'" else: return "'tf32'" From 1828571b5331b3718b5bb0534a5e6a289c9d84c2 Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Thu, 12 Sep 2024 10:28:56 -0700 Subject: [PATCH 0785/1018] [FlexAttention] Ensure q/k/v and block_mask on excact the same device (#135823) Fixes #134739 Pull Request resolved: https://github.com/pytorch/pytorch/pull/135823 Approved by: https://github.com/BoyuanFeng --- test/inductor/test_flex_attention.py | 23 ++++++++++++++++++++++- torch/nn/attention/flex_attention.py | 6 ++++++ 2 files changed, 28 insertions(+), 1 deletion(-) diff --git a/test/inductor/test_flex_attention.py b/test/inductor/test_flex_attention.py index cf37f89a5175b2..a6a8bfe13d7958 100644 --- a/test/inductor/test_flex_attention.py +++ b/test/inductor/test_flex_attention.py @@ -3,6 +3,7 @@ import functools import string +import unittest from collections import namedtuple from contextlib import contextmanager, nullcontext from typing import Callable, Optional, Tuple @@ -28,7 +29,7 @@ ) from torch.testing import FileCheck from torch.testing._internal import common_utils -from torch.testing._internal.common_cuda import PLATFORM_SUPPORTS_BF16 +from torch.testing._internal.common_cuda import PLATFORM_SUPPORTS_BF16, TEST_MULTIGPU from torch.testing._internal.common_utils import skipIfRocm, TEST_WITH_ROCM from torch.utils._triton import has_triton @@ -1952,6 +1953,26 @@ def mask_mod(b, h, q, kv): self.run_test_with_call(attention, Q_S=Q_S, KV_S=KV_S) + @unittest.skipIf(not TEST_MULTIGPU, "detected only one GPU") + def test_qkv_and_block_mask_on_the_same_device(self): + make_tensor = functools.partial( + torch.ones, + (2, 2, 256, 32), + device="cuda:0", + dtype=torch.float32, + requires_grad=True, + ) + query, key, value = make_tensor(), make_tensor(), make_tensor() + + def mask_mod(b, h, q, kv): + return q >= kv + + block_mask = create_block_mask(mask_mod, 1, 1, 256, 256, device="cuda:1") + with self.assertRaisesRegex( + RuntimeError, "Expect q/k/v and block_mask to be on the same device" + ): + torch.compile(flex_attention)(query, key, value, block_mask=block_mask) + @supported_platform def test_fw_bw_graph_correctness(self): cnt = CompileCounterWithBackend("aot_eager") diff --git a/torch/nn/attention/flex_attention.py b/torch/nn/attention/flex_attention.py index 7b0a3354642a7f..bb04ce53d0fcab 100644 --- a/torch/nn/attention/flex_attention.py +++ b/torch/nn/attention/flex_attention.py @@ -1003,6 +1003,12 @@ def score_mod( if scale is None: scale = 1.0 / math.sqrt(query.size(-1)) + if query.device != block_mask.kv_num_blocks.device: + raise RuntimeError( + f"Expect q/k/v and block_mask to be on the same device " + f"but got {query.device} and {block_mask.kv_num_blocks.device}." + ) + kernel_options = _apply_kernel_options( query, key, From b1f308d9620da32c635ab4c80a52ee868c728d17 Mon Sep 17 00:00:00 2001 From: Ma Jian Date: Thu, 12 Sep 2024 23:24:55 +0000 Subject: [PATCH 0786/1018] fix compiled_autograd deadlock throw (#135795) Fixes #135298 Pull Request resolved: https://github.com/pytorch/pytorch/pull/135795 Approved by: https://github.com/xmfan --- test/inductor/test_compiled_autograd.py | 5 +---- torch/_dynamo/compiled_autograd.py | 4 ++++ torch/autograd/graph.py | 7 ++++--- 3 files changed, 9 insertions(+), 7 deletions(-) diff --git a/test/inductor/test_compiled_autograd.py b/test/inductor/test_compiled_autograd.py index ea6675276e212c..672ea6bb17b74b 100644 --- a/test/inductor/test_compiled_autograd.py +++ b/test/inductor/test_compiled_autograd.py @@ -2677,10 +2677,7 @@ def fn(x): inp = torch.rand(10, 10, requires_grad=True) out = torch.utils.checkpoint.checkpoint(fn, inp, use_reentrant=True) - with self.assertRaisesRegex( - RuntimeError, - r"\(e.g. reentrant checkpointing\), this is not supported yet\.", - ), torch._dynamo.compiled_autograd.enable(torch.compile): + with torch._dynamo.compiled_autograd.enable(torch.compile): out.backward() diff --git a/torch/_dynamo/compiled_autograd.py b/torch/_dynamo/compiled_autograd.py index e7c5d2414f6e2b..675f99e46fb78d 100644 --- a/torch/_dynamo/compiled_autograd.py +++ b/torch/_dynamo/compiled_autograd.py @@ -525,6 +525,10 @@ def disable(): torch._C._dynamo.compiled_autograd.set_autograd_compiler(prior) +def maybe_disable_compiled_autograd(): + return disable() if in_compiled_autograd_region else contextlib.nullcontext() + + # return to starting state of a new process def reset() -> None: compiled_autograd_enable = False diff --git a/torch/autograd/graph.py b/torch/autograd/graph.py index 4792ced32b2703..991b00e8b50cfb 100644 --- a/torch/autograd/graph.py +++ b/torch/autograd/graph.py @@ -822,9 +822,10 @@ def _engine_run_backward( if attach_logging_hooks: unregister_hooks = _register_logging_hooks_on_whole_graph(t_outputs) try: - return Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass - t_outputs, *args, **kwargs - ) # Calls into the C++ engine to run the backward pass + with torch._dynamo.compiled_autograd.maybe_disable_compiled_autograd(): + return Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass + t_outputs, *args, **kwargs + ) # Calls into the C++ engine to run the backward pass finally: if attach_logging_hooks: unregister_hooks() # type: ignore[possibly-undefined] From 6e646e0148c9b1a1308243703afbdc803def81a9 Mon Sep 17 00:00:00 2001 From: Shangdi Yu Date: Thu, 12 Sep 2024 23:53:07 +0000 Subject: [PATCH 0787/1018] [aoti] Fix workspace generation for triton (#135552) Fixes #131337 - add `arg_type` for workspace_arg, the type is consistent with the type in `generate_workspace_allocation()`. - do not generate example tensors for `workspace`, and use `generate_workspace_allocation()` instead. - add workspace allocation generation code to `kernel_autotune_calls`. e.g. ```python workspace = empty_strided_cuda((1280, ), (1, ), torch.uint8) workspace.zero_() ..... triton_spl_fused_add_cumprod_0.run(buf2, arg0_1, arg1_1, workspace, 1, 10000, grid=split_scan_grid(1, 10000), stream=stream0) del buf2, arg0_1, arg1_1, workspace ``` - add `empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda` to the header of triton autotune code. The generated cpp has lines like below, so we also implement a `zero_()` for ` AtenTensorHandle `. ```cpp static constexpr int64_t int_array_0[] = {1280L, }; static constexpr int64_t int_array_1[] = {1L, }; AtenTensorHandle workspace_handle; AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_empty_strided(1, int_array_0, int_array_1, cached_torch_dtype_uint8, cached_torch_device_type_cuda, 0, &workspace_handle)); RAIIAtenTensorHandle workspace(workspace_handle); workspace.zero_(); ``` - Fix handle grid_fn for grid computation. Pass in "RBLOCK" to `split_scan_grid` - Fix dynamic shapes: Without the fix we generate code that looks like this `workspace = empty_strided_cuda((32*((255 + s0) // 256), ), (1, ), torch.uint8)` when doing triton autotune and `s0` is not defined. The solution approach is to use `V.graph.sizevars.size_hint(nbytes)` to realize the workspace size for triton autotune. Note that we only realize it for triton autotune code, but not for the cpp cuda code. - We also generate slightly different cpp code depending on if `abi_compatible` is turned on. ```cpp RAIIAtenTensorHandle workspace(workspace_handle); AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_zero_(workspace.get())); ``` vs ```cpp at::Tensor workspace = at::detail::empty_strided_cuda({8L*(c10::div_floor_integer(static_cast((255L + s0)), static_cast(256L))), }, {1L, }, at::kByte, c10::DeviceType::CUDA); workspace.zero_(); ``` Test Plan: ``` TORCHINDUCTOR_ABI_COMPATIBLE=1 TORCHINDUCTOR_CPP_WRAPPER=1 python test/inductor/test_torchinductor.py -k GPUTests.test_consecutive_split_cumprod_cuda python test/inductor/test_cuda_cpp_wrapper.py TestCudaWrapper.test_consecutive_split_cumprod_cuda_cuda_wrapper python test/inductor/test_cuda_cpp_wrapper.py DynamicShapesCudaWrapperCudaTests.test_consecutive_split_cumprod_cuda_dynamic_shapes_cuda_wrapper TORCHINDUCTOR_ABI_COMPATIBLE=1 python test/inductor/test_cuda_cpp_wrapper.py TestCudaWrapper.test_consecutive_split_cumprod_cuda_cuda_wrapper TORCHINDUCTOR_CPP_WRAPPER=1 python test/inductor/test_torchinductor.py -k GPUTests.test_consecutive_split_cumprod_cuda ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/135552 Approved by: https://github.com/desertfire --- test/inductor/test_cuda_cpp_wrapper.py | 1 + torch/_inductor/codegen/common.py | 1 + torch/_inductor/codegen/cpp_wrapper_gpu.py | 32 +++++++++++++++---- torch/_inductor/codegen/triton.py | 12 +++++-- torch/_inductor/codegen/triton_split_scan.py | 6 +++- torch/_inductor/codegen/wrapper.py | 14 +++++++- torch/_inductor/runtime/triton_heuristics.py | 1 + torch/csrc/inductor/aoti_torch/c/shim.h | 2 ++ .../csrc/inductor/aoti_torch/shim_common.cpp | 7 ++++ 9 files changed, 65 insertions(+), 11 deletions(-) diff --git a/test/inductor/test_cuda_cpp_wrapper.py b/test/inductor/test_cuda_cpp_wrapper.py index 9b8b802065590d..241bad352ccb5c 100644 --- a/test/inductor/test_cuda_cpp_wrapper.py +++ b/test/inductor/test_cuda_cpp_wrapper.py @@ -188,6 +188,7 @@ class BaseTest(NamedTuple): BaseTest("test_sum_int"), # bool, int64, int8, uint8 BaseTest("test_transpose"), # multiple outputs, buffer clear BaseTest("test_unspec_inputs"), + BaseTest("test_consecutive_split_cumprod"), BaseTest("test_pointwise_hermite_polynomial_he"), BaseTest("test_pointwise_hermite_polynomial_h"), BaseTest( diff --git a/torch/_inductor/codegen/common.py b/torch/_inductor/codegen/common.py index 95183c1484fe00..02f13667b70c79 100644 --- a/torch/_inductor/codegen/common.py +++ b/torch/_inductor/codegen/common.py @@ -1420,6 +1420,7 @@ def python_argdefs(self): arg_defs.append("ws_ptr") call_args.append("workspace") precompile_args.append(self.workspace_arg) + arg_types.append(torch.uint8) return arg_defs, call_args, precompile_args, arg_types def aliases(self): diff --git a/torch/_inductor/codegen/cpp_wrapper_gpu.py b/torch/_inductor/codegen/cpp_wrapper_gpu.py index 5719d3eba589f8..13f13f44be9eca 100644 --- a/torch/_inductor/codegen/cpp_wrapper_gpu.py +++ b/torch/_inductor/codegen/cpp_wrapper_gpu.py @@ -6,9 +6,9 @@ import sympy -from torch import dtype as torch_dtype +from torch import dtype as torch_dtype, uint8 from torch._inductor.codecache import get_cpp_wrapper_cubin_path_name -from torch._inductor.runtime.triton_heuristics import grid as default_grid +from torch._inductor.runtime.triton_heuristics import grid as default_grid_fn from .. import config from ..codecache import CudaKernelParamCache @@ -88,11 +88,11 @@ def __call__(self): grid = self.grid assert isinstance(grid, (list, tuple)), f"expected {grid=} to be a list" grid = self._process_grid(grid) - grid_callable = self.grid_callable or default_grid + assert self.grid_callable is not None, "grid_callable can't be None" if not self.grid_extra_kwargs: - grid_fn = grid_callable(*grid) + grid_fn = self.grid_callable(*grid) else: - grid_fn = grid_callable(*grid, **self.grid_extra_kwargs) + grid_fn = self.grid_callable(*grid, **self.grid_extra_kwargs) params = CudaKernelParamCache.get(self.kernel_name) assert ( @@ -102,6 +102,7 @@ def __call__(self): "XBLOCK": params["x_block"], "YBLOCK": params["y_block"], "ZBLOCK": params["z_block"], + "RBLOCK": params["r_block"], } return grid_fn(block_cfg) @@ -338,7 +339,7 @@ def generate_default_grid( kernel_name: str, grid: List[Any], gpu: bool = True, - grid_callable: Optional[Callable[..., Any]] = None, + grid_callable: Optional[Callable[..., Any]] = default_grid_fn, **grid_extra_kwargs, ): """ @@ -440,3 +441,22 @@ def generate_kernel_call( ), ) self.writeline("}") + + def generate_workspace_allocation(self, nbytes, device, zero_fill): + line = self.make_allocation( + "workspace", device, uint8, shape=(nbytes,), stride=(1,) + ) + self.writeline(line) + if config.triton.autotune_at_compile_time: + self.kernel_autotune_calls.writeline(line) + if zero_fill: + if config.abi_compatible: + # TODO: remove this function to use the default WrapperCodegen behavior after service platform has zero_() symbol + # default behavior is f"workspace.zero_(){self.ending}" + self.writeline( + f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_zero_(workspace.get())){self.ending}" + ) + else: + self.writeline(f"workspace.zero_(){self.ending}") + if config.triton.autotune_at_compile_time: + self.kernel_autotune_calls.writeline(f"workspace.zero_(){self.ending}") diff --git a/torch/_inductor/codegen/triton.py b/torch/_inductor/codegen/triton.py index 629b6286b2242c..25b267dbc3ee56 100644 --- a/torch/_inductor/codegen/triton.py +++ b/torch/_inductor/codegen/triton.py @@ -27,6 +27,7 @@ import torch._logging from torch._dynamo.utils import preserve_rng_state from torch._inductor.runtime.hints import AutotuneHint, DeviceProperties +from torch._inductor.runtime.triton_heuristics import grid as default_grid_fn from torch._prims_common import is_integer_dtype from torch.utils._ordered_set import OrderedSet from torch.utils._sympy.functions import CeilDiv, FloorDiv, ModularIndexing @@ -2797,9 +2798,12 @@ def KERNEL_NAME(in_ptr0, in_ptr1, out_ptr2, xnumel, rnumel, XBLOCK : tl.constexp if tree.prefix == "x" and self.no_x_dim: code.writeline("XBLOCK: tl.constexpr = 1") - def _get_grid_fn(self): + def _get_grid_fn_str(self): return "grid" + def _get_grid_fn(self): + return default_grid_fn + def add_numel_to_call_args_and_grid(self, name, call_args, arg_types, grid): # TODO(jansel): if there are constants, we shouldn't bother passing them as args for tree in self.range_trees: @@ -2828,7 +2832,9 @@ def call_kernel(self, name: str, node: Optional[IRNode] = None): ws.nbytes, current_device, ws.zero_fill ) - grid = wrapper.generate_default_grid(name, grid) + grid = wrapper.generate_default_grid( + name, grid, grid_callable=self._get_grid_fn() + ) wrapper.generate_kernel_call( name, call_args, @@ -2837,7 +2843,7 @@ def call_kernel(self, name: str, node: Optional[IRNode] = None): gpu=True, triton=True, arg_types=arg_types, - grid_fn=self._get_grid_fn(), + grid_fn=self._get_grid_fn_str(), triton_meta=self.triton_meta, ) diff --git a/torch/_inductor/codegen/triton_split_scan.py b/torch/_inductor/codegen/triton_split_scan.py index 9eea00fbb8d6b6..31dab399236492 100644 --- a/torch/_inductor/codegen/triton_split_scan.py +++ b/torch/_inductor/codegen/triton_split_scan.py @@ -6,6 +6,7 @@ from torch._inductor import config from torch._inductor.codegen.simd import IterationRangesRoot from torch._inductor.codegen.triton import triton_compute_type, TritonKernel +from torch._inductor.runtime.triton_heuristics import split_scan_grid from torch._prims_common import prod from torch.utils._ordered_set import OrderedSet from torch.utils._sympy.functions import CeilDiv @@ -170,5 +171,8 @@ def scan(self, dtypes, combine_fn, values): def _get_heuristic(self): return "split_scan" - def _get_grid_fn(self): + def _get_grid_fn_str(self): return "split_scan_grid" + + def _get_grid_fn(self): + return split_scan_grid diff --git a/torch/_inductor/codegen/wrapper.py b/torch/_inductor/codegen/wrapper.py index 19e4bb293c7eff..5422198194bfae 100644 --- a/torch/_inductor/codegen/wrapper.py +++ b/torch/_inductor/codegen/wrapper.py @@ -595,6 +595,7 @@ def write_kernel_autotune_defs_header(self) -> None: async_compile = AsyncCompile() generate_example_value = AlgorithmSelectorCache.generate_example_value + empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda """ ) @@ -1496,12 +1497,18 @@ def generate_numel_expr(self, kernel_name: str, tree, suffix: Optional[str] = No return SymbolicCallArg(expr, tree.numel) def generate_workspace_allocation(self, nbytes, device, zero_fill): + if isinstance(nbytes, sympy.Expr): + nbytes = V.graph.sizevars.size_hint(nbytes) line = self.make_allocation( "workspace", device, torch.uint8, shape=(nbytes,), stride=(1,) ) self.writeline(line) + if config.triton.autotune_at_compile_time: + self.kernel_autotune_calls.writeline(line) if zero_fill: self.writeline(f"workspace.zero_(){self.ending}") + if config.triton.autotune_at_compile_time: + self.kernel_autotune_calls.writeline(f"workspace.zero_(){self.ending}") def wrap_kernel_call(self, name, call_args): return f"{name}({', '.join(call_args)}){self.ending}" @@ -1720,7 +1727,12 @@ def generate_kernel_call( key, arg = arg.split("=") if isinstance(arg_type, torch_dtype): - if arg not in tensor_args: + # workspace allocation is already generated by `generate_workspace_allocation()` + # in `TritonKernel.call_kernel()`. + if arg == "workspace": + arg_str = "workspace" + tensor_args[arg] = arg_str + elif arg not in tensor_args: arg_str = self.generate_example_arg_value( arg, arg_type, raw_arg, i ) diff --git a/torch/_inductor/runtime/triton_heuristics.py b/torch/_inductor/runtime/triton_heuristics.py index bdc6a1e10bcaeb..ce5ac0b920e556 100644 --- a/torch/_inductor/runtime/triton_heuristics.py +++ b/torch/_inductor/runtime/triton_heuristics.py @@ -753,6 +753,7 @@ def save_gpu_kernel(self, grid, stream, launcher): "x_block": launcher.config.kwargs.get("XBLOCK", 1), "y_block": launcher.config.kwargs.get("YBLOCK", None), "z_block": launcher.config.kwargs.get("ZBLOCK", None), + "r_block": launcher.config.kwargs.get("RBLOCK", None), "num_warps": launcher.bin.num_warps if hasattr(launcher.bin, "num_warps") else launcher.bin.metadata.num_warps, diff --git a/torch/csrc/inductor/aoti_torch/c/shim.h b/torch/csrc/inductor/aoti_torch/c/shim.h index b470b5f10061b8..c204fc27460bdc 100644 --- a/torch/csrc/inductor/aoti_torch/c/shim.h +++ b/torch/csrc/inductor/aoti_torch/c/shim.h @@ -528,6 +528,8 @@ aoti_torch_cpu__wrapped_quantized_linear_prepacked( AOTI_TORCH_EXPORT AOTITorchError aoti_torch_nonzero(AtenTensorHandle self, AtenTensorHandle* out); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_zero_(AtenTensorHandle self); + AOTI_TORCH_EXPORT AOTITorchError aoti_torch_repeat_interleave_Tensor( AtenTensorHandle repeats, int64_t* output_size, diff --git a/torch/csrc/inductor/aoti_torch/shim_common.cpp b/torch/csrc/inductor/aoti_torch/shim_common.cpp index f49bf23b9ce422..2860972f88dac4 100644 --- a/torch/csrc/inductor/aoti_torch/shim_common.cpp +++ b/torch/csrc/inductor/aoti_torch/shim_common.cpp @@ -1166,3 +1166,10 @@ AOTITorchError aoti_torch__alloc_from_pool( strides)); }); } + +AOTITorchError aoti_torch_zero_(AtenTensorHandle tensor) { + AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({ + at::Tensor* t = tensor_handle_to_tensor_pointer(tensor); + t->zero_(); + }); +} From 57bda416e873dc1e32326f8c279960303d0dbbca Mon Sep 17 00:00:00 2001 From: Nikita Shulga <2453524+malfet@users.noreply.github.com> Date: Fri, 13 Sep 2024 00:08:52 +0000 Subject: [PATCH 0788/1018] Update nightly PyTorch version to 2.6.0 (#135916) Fixes #ISSUE_NUMBER Pull Request resolved: https://github.com/pytorch/pytorch/pull/135916 Approved by: https://github.com/kit1980 --- version.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/version.txt b/version.txt index b8feefb940f931..3d87ca93f8a9bf 100644 --- a/version.txt +++ b/version.txt @@ -1 +1 @@ -2.5.0a0 +2.6.0a0 From 1ba62e40fced26569de600dff8dc6e0c74d734dd Mon Sep 17 00:00:00 2001 From: Isuru Fernando Date: Thu, 12 Sep 2024 13:48:51 +0000 Subject: [PATCH 0789/1018] do not expand in replace/simplify if no changes (#135863) Pull Request resolved: https://github.com/pytorch/pytorch/pull/135863 Approved by: https://github.com/ezyang --- torch/fx/experimental/symbolic_shapes.py | 25 ++++++++++++++---------- 1 file changed, 15 insertions(+), 10 deletions(-) diff --git a/torch/fx/experimental/symbolic_shapes.py b/torch/fx/experimental/symbolic_shapes.py index c892fe846662cb..b85dd0f6a5e6b6 100644 --- a/torch/fx/experimental/symbolic_shapes.py +++ b/torch/fx/experimental/symbolic_shapes.py @@ -4592,7 +4592,10 @@ def replace(self, expr: "sympy.Expr") -> "sympy.Expr": # assumption queries if expr has a relational node. if not r.is_Symbol or r != s: replacements[s] = r - return safe_expand(expr.xreplace(replacements)) + if replacements: + return safe_expand(expr.xreplace(replacements)) + else: + return expr @_lru_cache def _update_divisible(self): @@ -4626,8 +4629,9 @@ def simplify(self, expr: "sympy.Expr") -> "sympy.Expr": if self.replace(Mod(base, divisor)) in self.divisible and \ base == base1 and self.replace(Mod(base1, divisor1)) in self.divisible: div_replacements[atom] = divisor1 - expr = expr.xreplace(div_replacements) - expr = safe_expand(expr) + if div_replacements: + expr = expr.xreplace(div_replacements) + expr = safe_expand(expr) if expr.has(FloorDiv): div_replacements = {} pows = expr.atoms(sympy.Pow) @@ -4636,13 +4640,14 @@ def simplify(self, expr: "sympy.Expr") -> "sympy.Expr": base, divisor = fd.args if self.replace(Mod(base, divisor)) in self.divisible: div_replacements[fd] = CleanDiv(base, divisor) - new_expr = expr.xreplace(div_replacements) - new_expr = safe_expand(new_expr) - new_pows = new_expr.atoms(sympy.Pow) - new_rationals = new_expr.atoms(sympy.Rational).difference(new_expr.atoms(sympy.Integer)) - # divisions simplified away - if new_pows.issubset(pows) and new_rationals.issubset(rationals): - expr = new_expr + if div_replacements: + new_expr = expr.xreplace(div_replacements) + new_expr = safe_expand(new_expr) + new_pows = new_expr.atoms(sympy.Pow) + new_rationals = new_expr.atoms(sympy.Rational).difference(new_expr.atoms(sympy.Integer)) + # divisions simplified away + if new_pows.issubset(pows) and new_rationals.issubset(rationals): + expr = new_expr return expr @lru_cache(256) From eb5570ae902daba579aa2d4f86a498358f0140e7 Mon Sep 17 00:00:00 2001 From: Jason Ansel Date: Thu, 12 Sep 2024 10:06:06 -0700 Subject: [PATCH 0790/1018] [inductor] Remove unused check (#135787) I think this is unreachable code because mode is always None on reads. Pull Request resolved: https://github.com/pytorch/pytorch/pull/135787 Approved by: https://github.com/oulgen --- torch/_inductor/scheduler.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/torch/_inductor/scheduler.py b/torch/_inductor/scheduler.py index 1fb91a3e938936..60ed27a1065c8c 100644 --- a/torch/_inductor/scheduler.py +++ b/torch/_inductor/scheduler.py @@ -3080,8 +3080,6 @@ def fusable_weak_dep( # if there's indirect indexing, don't match it def fusable_read_and_write(self, read: Dep, write: MemoryDep) -> bool: if isinstance(read, MemoryDep): - if read.mode == write.mode and write.mode is not None: - return True read_name = self.mutation_renames.get(read.name, read.name) if ( From cc08e0e659f666f22d5550c2d813b4735722670a Mon Sep 17 00:00:00 2001 From: Jason Ansel Date: Thu, 12 Sep 2024 10:06:07 -0700 Subject: [PATCH 0791/1018] [inductor] Optimize can_fuse_vertical() (#135788) An O(n^2) to O(n) improvement by not comparing all pairs of deps. Before: ![image](https://github.com/user-attachments/assets/797cd1bd-5d53-4374-8e76-ffce4232d7f9) After: ![image](https://github.com/user-attachments/assets/1e61bf29-adba-41a4-839e-f028130fa979) Pull Request resolved: https://github.com/pytorch/pytorch/pull/135788 Approved by: https://github.com/oulgen ghstack dependencies: #135787 --- torch/_inductor/scheduler.py | 34 ++++++++++++++++++++++++---------- 1 file changed, 24 insertions(+), 10 deletions(-) diff --git a/torch/_inductor/scheduler.py b/torch/_inductor/scheduler.py index 60ed27a1065c8c..5f8d35cbcab5eb 100644 --- a/torch/_inductor/scheduler.py +++ b/torch/_inductor/scheduler.py @@ -13,6 +13,7 @@ import textwrap import traceback import typing +from collections import defaultdict from typing import ( Any, Callable, @@ -3008,24 +3009,35 @@ def can_fuse_vertical( be scheduled before the fusion of node1 and node2. """ node1_buf_names = node1.get_buffer_names() - node1_op_names = node1.get_operation_names() - computed_deps: OrderedSet[Dep] = OrderedSet() why = WhyNoFuse(node1, node2) + remaining_deps_by_name: Dict[str, List[Dep]] = defaultdict(list) + + for dep in node2.unmet_dependencies: + name = self.mutation_renames.get(dep.name, dep.name) + if isinstance(dep, WeakDep) and self.fusable_weak_dep(dep, node1, node2): + continue + remaining_deps_by_name[name].append(dep) for cd in node1.read_writes.writes: if not isinstance(cd, MemoryDep): continue - for rd in node2.unmet_dependencies: - if self.fusable_read_and_write(rd, cd): - computed_deps.add(rd) - - for dep in node2.unmet_dependencies: - if isinstance(dep, WeakDep) and self.fusable_weak_dep(dep, node1, node2): - computed_deps.add(dep) + remaining = remaining_deps_by_name.get( + self.mutation_renames.get(cd.name, cd.name) + ) + if remaining: + for rd in remaining: + if self.fusable_read_and_write(rd, cd): + remaining.remove(rd) remaining_deps = OrderedSet( - dep.name for dep in node2.unmet_dependencies - computed_deps + [ + dep.name + for dep in itertools.chain.from_iterable( + remaining_deps_by_name.values() + ) + ] ) + if remaining_deps & node1_buf_names: # MemoryDeps didn't match and read different locations of the same buffer. # Examples here include: @@ -3033,6 +3045,8 @@ def can_fuse_vertical( # - MemoryDep("foo", x) != StarDep("foo") why("memory deps did not match") return False + + node1_op_names = node1.get_operation_names() for name in remaining_deps: op_name = self.name_to_buf[name].defining_op.get_name() if node1_op_names & self.name_to_fused_node[op_name].ancestors: From 50980461442f19e4dbffb02876f6354b9e27c5de Mon Sep 17 00:00:00 2001 From: Jason Ansel Date: Thu, 12 Sep 2024 10:06:07 -0700 Subject: [PATCH 0792/1018] [inductor] Use TracerBase directly in LoopBody (#135820) This skips some unneeded work in the subclass. Pull Request resolved: https://github.com/pytorch/pytorch/pull/135820 Approved by: https://github.com/oulgen ghstack dependencies: #135787, #135788 --- torch/_inductor/loop_body.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/torch/_inductor/loop_body.py b/torch/_inductor/loop_body.py index 2f0b00b724e71f..70a7d2fdd7967b 100644 --- a/torch/_inductor/loop_body.py +++ b/torch/_inductor/loop_body.py @@ -11,6 +11,7 @@ import torch.fx from torch._dynamo.utils import identity +from torch.fx.proxy import Scope, TracerBase from torch.utils._sympy.symbol import SymT from . import config, dependencies @@ -45,6 +46,16 @@ def run(self, *args, **kwargs): return super().run(*args, **kwargs) +# We don't need the nn.Module and constant handling in Tracer +class LightTracer(TracerBase): + def __init__(self): + super().__init__() + self.graph = torch.fx.Graph(tracer_cls=self.__class__) # type: ignore[arg-type] + self.scope = Scope("", None) + self.module_stack = {} # type: ignore[assignment] + self.node_name_to_scope = {} + + class MemoryEntry(NamedTuple): index_name: str # LoopBody.indexing_exprs[index_name] buffer_name: Optional[str] @@ -545,8 +556,7 @@ def indirect_indexing(index_proxy, size, check=True, wrap_neg=True): def output(result): tracer.create_proxy("output", "output", (result,), {}) - tracer = torch.fx.Tracer() - tracer.graph = torch.fx.Graph(tracer_cls=tracer.__class__) + tracer = LightTracer() proxy_ops = tracer.create_proxy("placeholder", "ops", (), {}) from .index_propagation import IndexPropagation From 76997812f81921c6c92b3985d17dfade50903a3c Mon Sep 17 00:00:00 2001 From: Jason Ansel Date: Thu, 12 Sep 2024 10:06:08 -0700 Subject: [PATCH 0793/1018] [fx] Minor optimization in create_arg (#135821) Pull Request resolved: https://github.com/pytorch/pytorch/pull/135821 Approved by: https://github.com/oulgen ghstack dependencies: #135787, #135788, #135820 --- torch/fx/graph.py | 4 +++- torch/fx/proxy.py | 40 +++++++++++++++++++++------------------- 2 files changed, 24 insertions(+), 20 deletions(-) diff --git a/torch/fx/graph.py b/torch/fx/graph.py index 1601eee07f2717..84de9906d0d71f 100644 --- a/torch/fx/graph.py +++ b/torch/fx/graph.py @@ -37,6 +37,8 @@ tuple: Tuple, } +_legal_ops = dict.fromkeys(['call_function', 'call_method', 'get_attr', 'call_module', 'placeholder', 'output']) + # Signature for functions thattransforms the body (`list[str]`) of the # generated code @@ -1010,7 +1012,7 @@ def create_node(self, op: str, target: 'Target', The newly-created and inserted node. """ - assert op in ('call_function', 'call_method', 'get_attr', 'call_module', 'placeholder', 'output') + assert op in _legal_ops args = () if args is None else args kwargs = {} if kwargs is None else kwargs assert isinstance(args, tuple), "args must be a tuple" diff --git a/torch/fx/proxy.py b/torch/fx/proxy.py index 2b86a1c609f918..dc8ab7aa77e75e 100644 --- a/torch/fx/proxy.py +++ b/torch/fx/proxy.py @@ -152,6 +152,7 @@ def create_node(self, kind : str, target : Target, self.scope.module_path, self.scope.module_type, ) + # Optionally set stack trace on the created Node for debugging purposes if fx_traceback.has_preserved_node_meta(): current_meta: Dict[str, Any] = fx_traceback.get_current_meta() @@ -260,29 +261,33 @@ def create_arg(self, a: Any) -> Argument: Can be override to support more trace-specific types. """ - if not isinstance(a, Proxy) and hasattr(a, '__fx_create_arg__'): + if isinstance(a, Proxy): + return a.node # most common arg type goes first + elif hasattr(a, '__fx_create_arg__'): return a.__fx_create_arg__(self) # aggregates - elif isinstance(a, tuple) and hasattr(a, '_fields'): - # NamedTuple constructors don't seem to like getting a generator - # expression as an argument to their constructor, so build this - # intermediate tuple and unpack it into the NamedTuple constructor - args = tuple(self.create_arg(elem) for elem in a) - return type(a)(*args) # type: ignore[arg-type] - elif isinstance(a, (tuple, list)): - return type(a)(self.create_arg(elem) for elem in a) + elif isinstance(a, tuple): + if hasattr(a, '_fields'): + # NamedTuple constructors don't seem to like getting a generator + # expression as an argument to their constructor, so build this + # intermediate tuple and unpack it into the NamedTuple constructor + args = [self.create_arg(elem) for elem in a] + return type(a)(*args) # type: ignore[arg-type] + return type(a)([self.create_arg(elem) for elem in a]) + elif isinstance(a, list): + return [self.create_arg(elem) for elem in a] elif isinstance(a, dict): + def no_node(arg): + if isinstance(arg, Node): + raise RuntimeError("Keys for dictionaries used as an argument cannot contain a " + f"Node. Got key: {k}") + r = {} for k, v in a.items(): # Check for invalid dict keys. We do not want a Proxy to appear # anywhere within the key. Since keys can be collection types, # we iterate through the key with map_aggregate k = self.create_arg(k) - - def no_node(arg): - if isinstance(arg, Node): - raise RuntimeError("Keys for dictionaries used as an argument cannot contain a " - f"Node. Got key: {k}") map_aggregate(k, no_node) r[k] = self.create_arg(v) @@ -296,16 +301,13 @@ def no_node(arg): elif isinstance(a, (torch._ops.OpOverload, torch._ops.HigherOrderOperator)): return a - if isinstance(a, Proxy): - # base case: we unwrap the Proxy object - return a.node - - if is_dataclass(a): + elif is_dataclass(a): kwargs = {field.name: self.create_arg(getattr(a, field.name)) for field in fields(a)} return self.create_node("call_function", a.__class__, (), kwargs) elif isinstance(a, (*base_types, enum.Enum)) or a is None or a is ...: return a + raise NotImplementedError(f"argument of type: {type(a)}") @compatibility(is_backward_compatible=True) From 4a78f2aae5d9b89fa27fa5ee97531e599a798e6b Mon Sep 17 00:00:00 2001 From: Jason Ansel Date: Thu, 12 Sep 2024 10:06:08 -0700 Subject: [PATCH 0794/1018] [fx] Replace _snake_case with a regexp (#135822) ~2x speedup on this function, though saves <0.5s overall Pull Request resolved: https://github.com/pytorch/pytorch/pull/135822 Approved by: https://github.com/oulgen ghstack dependencies: #135787, #135788, #135820, #135821 --- torch/fx/graph.py | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/torch/fx/graph.py b/torch/fx/graph.py index 84de9906d0d71f..b0df9f02fcb8e6 100644 --- a/torch/fx/graph.py +++ b/torch/fx/graph.py @@ -20,6 +20,7 @@ import math import warnings import inspect +import functools __all__ = ["PythonCode", "CodeGen", "Graph"] @@ -86,14 +87,11 @@ def _snake_case(s: str) -> str: ``mod.pascalCase``-> ``mod.pascal_case`` ``mod.ALL_CAPS`` -> ``mod.all_caps`` """ - chars = [] - prev_lower = False - for c in s: - if prev_lower and c.isupper(): - chars.append('_') - chars.append(c.lower()) - prev_lower = c.islower() - return ''.join(chars) + return _snake_case_sub(s).lower() + + +# Replace occurrences where a lowercase letter is followed by an uppercase letter +_snake_case_sub = functools.partial(re.compile(r'(?<=[a-z])([A-Z])').sub, r'_\1') def _is_from_torch(obj: Any) -> bool: From cebd19b280f96b733fad0d6176916d764e0808e0 Mon Sep 17 00:00:00 2001 From: Sahan Paliskara Date: Fri, 13 Sep 2024 00:19:40 +0000 Subject: [PATCH 0795/1018] [Docs] fix inconsistent docs in conv1d, conv2d, and conv3d (#135894) Addresses https://github.com/pytorch/pytorch/issues/135880 Pull Request resolved: https://github.com/pytorch/pytorch/pull/135894 Approved by: https://github.com/mikaylagawarecki, https://github.com/malfet --- torch/nn/modules/conv.py | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/torch/nn/modules/conv.py b/torch/nn/modules/conv.py index 8501d3f79d2df6..170c806e47fefc 100644 --- a/torch/nn/modules/conv.py +++ b/torch/nn/modules/conv.py @@ -270,14 +270,14 @@ class Conv1d(_ConvNd): stride (int or tuple, optional): Stride of the convolution. Default: 1 padding (int, tuple or str, optional): Padding added to both sides of the input. Default: 0 - padding_mode (str, optional): ``'zeros'``, ``'reflect'``, - ``'replicate'`` or ``'circular'``. Default: ``'zeros'`` dilation (int or tuple, optional): Spacing between kernel elements. Default: 1 groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1 bias (bool, optional): If ``True``, adds a learnable bias to the output. Default: ``True`` + padding_mode (str, optional): ``'zeros'``, ``'reflect'``, + ``'replicate'`` or ``'circular'``. Default: ``'zeros'`` """.format( **reproducibility_notes, **convolution_notes @@ -443,13 +443,13 @@ class Conv2d(_ConvNd): stride (int or tuple, optional): Stride of the convolution. Default: 1 padding (int, tuple or str, optional): Padding added to all four sides of the input. Default: 0 - padding_mode (str, optional): ``'zeros'``, ``'reflect'``, - ``'replicate'`` or ``'circular'``. Default: ``'zeros'`` dilation (int or tuple, optional): Spacing between kernel elements. Default: 1 groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1 bias (bool, optional): If ``True``, adds a learnable bias to the output. Default: ``True`` + padding_mode (str, optional): ``'zeros'``, ``'reflect'``, + ``'replicate'`` or ``'circular'``. Default: ``'zeros'`` """.format( **reproducibility_notes, **convolution_notes ) @@ -615,10 +615,10 @@ class Conv3d(_ConvNd): stride (int or tuple, optional): Stride of the convolution. Default: 1 padding (int, tuple or str, optional): Padding added to all six sides of the input. Default: 0 - padding_mode (str, optional): ``'zeros'``, ``'reflect'``, ``'replicate'`` or ``'circular'``. Default: ``'zeros'`` dilation (int or tuple, optional): Spacing between kernel elements. Default: 1 groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1 bias (bool, optional): If ``True``, adds a learnable bias to the output. Default: ``True`` + padding_mode (str, optional): ``'zeros'``, ``'reflect'``, ``'replicate'`` or ``'circular'``. Default: ``'zeros'`` """.format( **reproducibility_notes, **convolution_notes ) @@ -1467,14 +1467,14 @@ class LazyConv1d(_LazyConvXdMixin, Conv1d): # type: ignore[misc] stride (int or tuple, optional): Stride of the convolution. Default: 1 padding (int or tuple, optional): Zero-padding added to both sides of the input. Default: 0 - padding_mode (str, optional): ``'zeros'``, ``'reflect'``, - ``'replicate'`` or ``'circular'``. Default: ``'zeros'`` dilation (int or tuple, optional): Spacing between kernel elements. Default: 1 groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1 bias (bool, optional): If ``True``, adds a learnable bias to the output. Default: ``True`` + padding_mode (str, optional): ``'zeros'``, ``'reflect'``, + ``'replicate'`` or ``'circular'``. Default: ``'zeros'`` .. seealso:: :class:`torch.nn.Conv1d` and :class:`torch.nn.modules.lazy.LazyModuleMixin` """ @@ -1536,14 +1536,14 @@ class LazyConv2d(_LazyConvXdMixin, Conv2d): # type: ignore[misc] stride (int or tuple, optional): Stride of the convolution. Default: 1 padding (int or tuple, optional): Zero-padding added to both sides of the input. Default: 0 - padding_mode (str, optional): ``'zeros'``, ``'reflect'``, - ``'replicate'`` or ``'circular'``. Default: ``'zeros'`` dilation (int or tuple, optional): Spacing between kernel elements. Default: 1 groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1 bias (bool, optional): If ``True``, adds a learnable bias to the output. Default: ``True`` + padding_mode (str, optional): ``'zeros'``, ``'reflect'``, + ``'replicate'`` or ``'circular'``. Default: ``'zeros'`` .. seealso:: :class:`torch.nn.Conv2d` and :class:`torch.nn.modules.lazy.LazyModuleMixin` """ @@ -1606,14 +1606,14 @@ class LazyConv3d(_LazyConvXdMixin, Conv3d): # type: ignore[misc] stride (int or tuple, optional): Stride of the convolution. Default: 1 padding (int or tuple, optional): Zero-padding added to both sides of the input. Default: 0 - padding_mode (str, optional): ``'zeros'``, ``'reflect'``, - ``'replicate'`` or ``'circular'``. Default: ``'zeros'`` dilation (int or tuple, optional): Spacing between kernel elements. Default: 1 groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1 bias (bool, optional): If ``True``, adds a learnable bias to the output. Default: ``True`` + padding_mode (str, optional): ``'zeros'``, ``'reflect'``, + ``'replicate'`` or ``'circular'``. Default: ``'zeros'`` .. seealso:: :class:`torch.nn.Conv3d` and :class:`torch.nn.modules.lazy.LazyModuleMixin` """ From e36c7f81b67f5a11e917af8416d6a347a3d9d11c Mon Sep 17 00:00:00 2001 From: eqy Date: Fri, 13 Sep 2024 00:30:43 +0000 Subject: [PATCH 0796/1018] [Inductor] Disable TF32 in `test_slice_scatter_reinplace` (#135709) TF32 linear/matmul numerics seem unrelated to test functionality so disabling it here to abate noisy failures Pull Request resolved: https://github.com/pytorch/pytorch/pull/135709 Approved by: https://github.com/eellison --- test/inductor/test_torchinductor.py | 1 + 1 file changed, 1 insertion(+) diff --git a/test/inductor/test_torchinductor.py b/test/inductor/test_torchinductor.py index 881c2cf6db7d65..f2240ff64922d4 100644 --- a/test/inductor/test_torchinductor.py +++ b/test/inductor/test_torchinductor.py @@ -7404,6 +7404,7 @@ def fn(a, b): b = torch.empty(0) self.common(fn, [a, b]) + @with_tf32_off def test_slice_scatter_reinplace(self): class M(nn.Module): def __init__(self, device): From 6d1285ff7d094366697a0cb06ee2f220dc5c8eea Mon Sep 17 00:00:00 2001 From: Jason Ansel Date: Thu, 12 Sep 2024 21:55:23 +0000 Subject: [PATCH 0797/1018] [dynamo] Fix support for classmethod(property(...)) (#134968) Fixes #134451 Pull Request resolved: https://github.com/pytorch/pytorch/pull/134968 Approved by: https://github.com/yanboliang --- test/dynamo/test_repros.py | 77 ++++++++++++++++++++++++- torch/_dynamo/variables/constant.py | 2 + torch/_dynamo/variables/user_defined.py | 4 ++ 3 files changed, 82 insertions(+), 1 deletion(-) diff --git a/test/dynamo/test_repros.py b/test/dynamo/test_repros.py index e1e1dd1837bef4..43b2278cc5c085 100644 --- a/test/dynamo/test_repros.py +++ b/test/dynamo/test_repros.py @@ -20,7 +20,7 @@ from abc import ABC from collections import namedtuple from copy import deepcopy -from enum import Enum +from enum import Enum, IntEnum from functools import wraps from typing import Any, Dict, Iterator, List, Tuple from unittest import mock @@ -4546,6 +4546,81 @@ def f(*args): f(*args) self.assertEqual(num_compiles, 1) + def test_issue134451(self): + class BoundingBox2DIndex(IntEnum): + _X = 0 + _Y = 1 + _HEADING = 2 + _LENGTH = 3 + _WIDTH = 4 + + @classmethod + def size(cls): + return 5 + + @classmethod + @property + def X(cls): + return cls._X + + @classmethod + @property + def Y(cls): + return cls._Y + + @classmethod + @property + def HEADING(cls): + return cls._HEADING + + @classmethod + @property + def LENGTH(cls): + return cls._LENGTH + + @classmethod + @property + def WIDTH(cls): + return cls._WIDTH + + @classmethod + @property + def POINT(cls): + # assumes X, Y have subsequent indices + return slice(cls._X, cls._Y + 1) + + @classmethod + @property + def STATE_SE2(cls): + # assumes X, Y, HEADING have subsequent indices + return slice(cls._X, cls._HEADING + 1) + + class SimpleModel(nn.Module): + def __init__(self): + super().__init__() + self._mlp_states = nn.Sequential( + nn.Linear(10, 20), + nn.ReLU(), + nn.Linear(20, BoundingBox2DIndex.size()), + ) + + def forward(self, x): + agent_states = self._mlp_states(x) + agent_states[..., BoundingBox2DIndex.POINT] = ( + agent_states[..., BoundingBox2DIndex.POINT].tanh() * 32 + ) + agent_states[..., BoundingBox2DIndex.HEADING] = ( + agent_states[..., BoundingBox2DIndex.HEADING].tanh() * torch.pi + ) + return agent_states + + model = SimpleModel().eval() + input_tensor = torch.randn(1, 10, dtype=torch.float32) + opt = torch.compile(model.eval(), backend="eager", fullgraph=True) + actual = opt(input_tensor) + expected = model(input_tensor) + self.assertEqual(actual, expected) + def test_invalid_seq_unpack(self): def myfn(arg): (a, b) = arg diff --git a/torch/_dynamo/variables/constant.py b/torch/_dynamo/variables/constant.py index 65de0aab6ef3c7..de357cf8094f3a 100644 --- a/torch/_dynamo/variables/constant.py +++ b/torch/_dynamo/variables/constant.py @@ -222,6 +222,8 @@ def create(cls, cls_type, value_vt, options): unimplemented("Enum variable is constructed with non constant values") def as_proxy(self): + if isinstance(self.value, int): + return int(self.value) # convert IntEnum to a normal int return self.value def __str__(self) -> str: diff --git a/torch/_dynamo/variables/user_defined.py b/torch/_dynamo/variables/user_defined.py index 34e1d0d10c9f72..9aa64a5a6a8d43 100644 --- a/torch/_dynamo/variables/user_defined.py +++ b/torch/_dynamo/variables/user_defined.py @@ -197,6 +197,10 @@ def var_getattr(self, tx: "InstructionTranslator", name: str) -> "VariableTracke else: return SourcelessBuilder.create(tx, func) elif isinstance(obj, classmethod): + if isinstance(obj.__func__, property): + return variables.UserFunctionVariable(obj.__func__.fget).call_function( + tx, [self], {} + ) return variables.UserMethodVariable(obj.__func__, self, source=source) elif isinstance(obj, types.ClassMethodDescriptorType): # e.g.: inspect.getattr_static(dict, "fromkeys") From 077dc250a9c5e90130b30e18789fbd143d43820b Mon Sep 17 00:00:00 2001 From: Aaron Orenstein Date: Fri, 13 Sep 2024 02:04:34 +0000 Subject: [PATCH 0798/1018] Fix lint errors in fbcode (#135614) Summary: Fixed a bunch of fbcode imports that happened to work but confused autodeps. After this autodeps still suggests "improvements" to TARGETS (which breaks our builds) but at least it can find all the imports. Test Plan: ``` fbpython fbcode/tools/build/buck/linters/lint_autoformat.py --linter=autodeps --default-exec-timeout=1800 -- fbcode/caffe2/TARGETS fbcode/caffe2/test/TARGETS ``` Before: ``` ERROR while processing caffe2/test/TARGETS: Cannot find an owner for "test_export" (from caffe2/test/export/testing.py:229) when processing rule "test_export". Please make sure it's listed in the srcs parameter of another rule. See https://fbur$ ERROR while processing caffe2/test/TARGETS: Cannot find an owner for "testing" (from caffe2/test/export/test_export.py:87) when processing rule "test_export". Please make sure it's listed in the srcs parameter of another rule. See https://fburl$ ERROR while processing caffe2/test/TARGETS: Cannot find an owner for "test_export" (from caffe2/test/export/test_serdes.py:9) when processing rule "test_export". Please make sure it's listed in the srcs parameter of another rule. See https://fb$ ERROR while processing caffe2/test/TARGETS: Cannot find an owner for "testing" (from caffe2/test/export/test_serdes.py:10) when processing rule "test_export". Please make sure it's listed in the srcs parameter of another rule. See https://fburl$ ERROR while processing caffe2/test/TARGETS: Cannot find an owner for "testing" (from caffe2/test/export/test_retraceability.py:7) when processing rule "test_export". Please make sure it's listed in the srcs parameter of another rule. See https:$ ERROR while processing caffe2/test/TARGETS: Cannot find an owner for "test_export" (from caffe2/test/export/test_retraceability.py:6) when processing rule "test_export". Please make sure it's listed in the srcs parameter of another rule. See ht$ ERROR while processing caffe2/test/TARGETS: Cannot find an owner for "testing" (from caffe2/test/export/test_export_nonstrict.py:7) when processing rule "test_export". Please make sure it's listed in the srcs parameter of another rule. See http$ ERROR while processing caffe2/test/TARGETS: Cannot find an owner for "test_export" (from caffe2/test/export/test_export_nonstrict.py:6) when processing rule "test_export". Please make sure it's listed in the srcs parameter of another rule. See $ ERROR while processing caffe2/test/TARGETS: Cannot find an owner for "test_export" (from caffe2/test/export/test_export_training_ir_to_run_decomp.py:8) when processing rule "test_export". Please make sure it's listed in the srcs parameter of an$ ERROR while processing caffe2/test/TARGETS: Cannot find an owner for "testing" (from caffe2/test/export/test_export_training_ir_to_run_decomp.py:10) when processing rule "test_export". Please make sure it's listed in the srcs parameter of anoth$ ERROR while processing caffe2/test/TARGETS: Found "//python/typeshed_internal:typeshed_internal_library" owner for "cv2" but it is protected by visibility rules: [] (from caffe2/test/test_bundled_images.py:7) when processing rule "test_bundled_$ ERROR while processing caffe2/test/TARGETS: Cannot find an owner for "caffe2.test.profiler_test_cpp_thread_lib" (from caffe2/test/profiler/test_cpp_thread.py:29) when processing rule "profiler_test_cpp_thread". Please make sure it's listed in t$ ERROR while processing caffe2/test/TARGETS: Cannot find an owner for "torch._utils_internal.get_file_path_2" (from caffe2/test/test_custom_ops.py:23) when processing rule "custom_ops". Please make sure it's listed in the srcs parameter of anoth$ ERROR while processing caffe2/test/TARGETS: Cannot find an owner for "torch._utils_internal.get_file_path_2" (from caffe2/test/test_public_bindings.py:13) when processing rule "public_bindings". Please make sure it's listed in the srcs paramete$ ERROR while processing caffe2/test/TARGETS: Cannot find an owner for "torch._C._profiler.symbolize_tracebacks" (from caffe2/test/test_cuda.py:3348) when processing rule "test_cuda". Please make sure it's listed in the srcs parameter of another $ ERROR while processing caffe2/test/TARGETS: Cannot find an owner for "torch._C._profiler.gather_traceback" (from caffe2/test/test_cuda.py:3348) when processing rule "test_cuda". Please make sure it's listed in the srcs parameter of another rule$ ERROR while processing caffe2/test/TARGETS: Cannot find an owner for include (from caffe2/test/profiler/test_cpp_thread.cpp:2) when processing profiler_test_cpp_thread_lib. Some things to try: ``` Differential Revision: D62049222 Pull Request resolved: https://github.com/pytorch/pytorch/pull/135614 Approved by: https://github.com/oulgen, https://github.com/laithsakka --- test/export/test_export.py | 2 +- test/export/test_export_nonstrict.py | 4 ++-- .../test_export_training_ir_to_run_decomp.py | 4 ++-- test/export/test_retraceability.py | 4 ++-- test/export/test_serdes.py | 4 ++-- test/export/testing.py | 2 +- test/inductor/custom_ops.cpp | 2 +- test/inductor/test_aot_inductor.py | 14 ++++++++++---- test/inductor/test_aot_inductor_utils.py | 2 +- test/inductor/test_benchmark_fusion.py | 6 +++++- test/inductor/test_binary_folding.py | 10 ++++++++-- test/inductor/test_ck_backend.py | 2 +- test/inductor/test_combo_kernels.py | 5 ++++- test/inductor/test_compiled_optimizers.py | 7 +++++-- test/inductor/test_coordinate_descent_tuner.py | 2 +- test/inductor/test_cpu_cpp_wrapper.py | 10 +++++----- test/inductor/test_cpu_repro.py | 2 +- test/inductor/test_cpu_select_algorithm.py | 4 ++-- test/inductor/test_cuda_cpp_wrapper.py | 10 +++++----- test/inductor/test_cuda_repro.py | 6 +++--- .../test_cudagraph_trees_expandable_segments.py | 6 ++++-- test/inductor/test_debug_trace.py | 2 +- test/inductor/test_efficient_conv_bn_eval.py | 4 +++- test/inductor/test_extension_backend.py | 4 ++-- test/inductor/test_foreach.py | 5 ++++- test/inductor/test_halide.py | 4 ++-- test/inductor/test_inductor_freezing.py | 6 +++++- test/inductor/test_inplacing_pass.py | 4 ++-- test/inductor/test_memory_planning.py | 4 +++- test/inductor/test_perf.py | 4 ++-- test/inductor/test_profiler.py | 2 +- .../test_torchinductor_codegen_dynamic_shapes.py | 4 ++-- test/inductor/test_torchinductor_dynamic_shapes.py | 2 +- test/inductor/test_torchinductor_opinfo.py | 5 ++++- test/inductor/test_triton_extension_backend.py | 8 +++++--- test/inductor/test_triton_heuristics.py | 6 +++--- test/inductor/test_triton_kernels.py | 8 ++++---- test/inductor/test_xpu_basic.py | 5 ++++- test/profiler/test_cpp_thread.cpp | 2 +- test/profiler/test_cpp_thread.py | 2 +- test/profiler/test_cpp_thread_lib.pyi | 0 test/test_bundled_images.py | 2 +- test/test_cuda.py | 2 +- test/test_custom_ops.py | 2 +- test/test_public_bindings.py | 2 +- 45 files changed, 120 insertions(+), 77 deletions(-) create mode 100644 test/profiler/test_cpp_thread_lib.pyi diff --git a/test/export/test_export.py b/test/export/test_export.py index 8896ef3983ac39..db67266632ce58 100644 --- a/test/export/test_export.py +++ b/test/export/test_export.py @@ -84,7 +84,7 @@ try: from . import testing except ImportError: - import testing + import testing # @manual=fbcode//caffe2/test:test_export-library # The following import pattern matters as `test_export.export` is patched # in other files (like test_export_nonstrict.py). `torch.export.export` # will invalidate the patch. diff --git a/test/export/test_export_nonstrict.py b/test/export/test_export_nonstrict.py index c368eb069a2bd1..99944e4841b72e 100644 --- a/test/export/test_export_nonstrict.py +++ b/test/export/test_export_nonstrict.py @@ -3,8 +3,8 @@ try: from . import test_export, testing except ImportError: - import test_export - import testing + import test_export # @manual=fbcode//caffe2/test:test_export-library + import testing # @manual=fbcode//caffe2/test:test_export-library from torch.export import export diff --git a/test/export/test_export_training_ir_to_run_decomp.py b/test/export/test_export_training_ir_to_run_decomp.py index 6f780c1fec85bf..3cab41c6d095c5 100644 --- a/test/export/test_export_training_ir_to_run_decomp.py +++ b/test/export/test_export_training_ir_to_run_decomp.py @@ -5,9 +5,9 @@ try: from . import test_export, testing except ImportError: - import test_export + import test_export # @manual=fbcode//caffe2/test:test_export-library - import testing + import testing # @manual=fbcode//caffe2/test:test_export-library test_classes = {} diff --git a/test/export/test_retraceability.py b/test/export/test_retraceability.py index d3f914188ccfce..e7f243fd9fb7e4 100644 --- a/test/export/test_retraceability.py +++ b/test/export/test_retraceability.py @@ -3,8 +3,8 @@ try: from . import test_export, testing except ImportError: - import test_export - import testing + import test_export # @manual=fbcode//caffe2/test:test_export-library + import testing # @manual=fbcode//caffe2/test:test_export-library from torch.export import export diff --git a/test/export/test_serdes.py b/test/export/test_serdes.py index 59b83f22c3d23f..a1ced9dd4e5e6c 100644 --- a/test/export/test_serdes.py +++ b/test/export/test_serdes.py @@ -6,8 +6,8 @@ try: from . import test_export, testing except ImportError: - import test_export - import testing + import test_export # @manual=fbcode//caffe2/test:test_export-library + import testing # @manual=fbcode//caffe2/test:test_export-library from torch.export import export, load, save diff --git a/test/export/testing.py b/test/export/testing.py index 054e3a611dff86..3647d4c9edd86a 100644 --- a/test/export/testing.py +++ b/test/export/testing.py @@ -226,7 +226,7 @@ def _fn(*args, **kwargs): try: from . import test_export except ImportError: - import test_export + import test_export # @manual=fbcode//caffe2/test:test_export-library with patch(f"{test_export.__name__}.export", mocked_export_fn): return fn(*args, **kwargs) diff --git a/test/inductor/custom_ops.cpp b/test/inductor/custom_ops.cpp index 360a2d0b862384..39c1098d95b817 100644 --- a/test/inductor/custom_ops.cpp +++ b/test/inductor/custom_ops.cpp @@ -1,4 +1,4 @@ -#include +#include // @manual=fbcode//caffe2:libtorch #include #include diff --git a/test/inductor/test_aot_inductor.py b/test/inductor/test_aot_inductor.py index 2c85c488b2bbaa..4dee53c4dc41ae 100644 --- a/test/inductor/test_aot_inductor.py +++ b/test/inductor/test_aot_inductor.py @@ -45,7 +45,7 @@ if HAS_CUDA: - import triton + import triton # @manual from torch.testing._internal.triton_utils import ( add_kernel, @@ -76,14 +76,20 @@ ) from .test_torchinductor import copy_tests, requires_multigpu, TestFailure except ImportError: - from test_aot_inductor_utils import AOTIRunnerUtil - from test_control_flow import ( + from test_aot_inductor_utils import ( + AOTIRunnerUtil, # @manual=fbcode//caffe2/test/inductor:aot_inductor_utils-library + ) + from test_control_flow import ( # @manual=fbcode//caffe2/test/inductor:control_flow-library CondModels, prepend_counters, prepend_predicates, WhileLoopModels, ) - from test_torchinductor import copy_tests, requires_multigpu, TestFailure + from test_torchinductor import ( # @manual=fbcode//caffe2/test/inductor:test_inductor-library + copy_tests, + requires_multigpu, + TestFailure, + ) except (unittest.SkipTest, ImportError) as e: if __name__ == "__main__": sys.exit(0) diff --git a/test/inductor/test_aot_inductor_utils.py b/test/inductor/test_aot_inductor_utils.py index 27433876c9ca45..425e87bcfbb07d 100644 --- a/test/inductor/test_aot_inductor_utils.py +++ b/test/inductor/test_aot_inductor_utils.py @@ -67,7 +67,7 @@ def compile( @staticmethod def load_runner(device, so_path): if IS_FBCODE: - from .fb import test_aot_inductor_model_runner_pybind + from .fb import test_aot_inductor_model_runner_pybind # @manual with tempfile.TemporaryDirectory() as temp_dir: # copy *.so file to a unique path just before loading diff --git a/test/inductor/test_benchmark_fusion.py b/test/inductor/test_benchmark_fusion.py index 706c2d9ae76546..9eb25aa305a1a5 100644 --- a/test/inductor/test_benchmark_fusion.py +++ b/test/inductor/test_benchmark_fusion.py @@ -20,7 +20,11 @@ import contextlib import unittest -from inductor.test_torchinductor import check_model, check_model_cuda, copy_tests +from inductor.test_torchinductor import ( # @manual=fbcode//caffe2/test/inductor:test_inductor-library + check_model, + check_model_cuda, + copy_tests, +) from torch._inductor import config from torch._inductor.scheduler import Scheduler diff --git a/test/inductor/test_binary_folding.py b/test/inductor/test_binary_folding.py index a8b39c1bba055e..20f613fc746f37 100644 --- a/test/inductor/test_binary_folding.py +++ b/test/inductor/test_binary_folding.py @@ -15,8 +15,14 @@ pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) sys.path.append(pytorch_test_dir) -from inductor.test_inductor_freezing import TestCase -from inductor.test_torchinductor import check_model, check_model_gpu, copy_tests +from inductor.test_inductor_freezing import ( + TestCase, # @manual=fbcode//caffe2/test/inductor:inductor_freezing-library +) +from inductor.test_torchinductor import ( # @manual=fbcode//caffe2/test/inductor:test_inductor-library + check_model, + check_model_gpu, + copy_tests, +) from torch.testing._internal.common_utils import TEST_WITH_ASAN from torch.testing._internal.inductor_utils import skipCUDAIf diff --git a/test/inductor/test_ck_backend.py b/test/inductor/test_ck_backend.py index 94e53ad1b3563a..dc507e42aeb315 100644 --- a/test/inductor/test_ck_backend.py +++ b/test/inductor/test_ck_backend.py @@ -43,7 +43,7 @@ def setUp(self): torch.random.manual_seed(1234) try: - import ck4inductor + import ck4inductor # @manual self.ck_dir = os.path.dirname(ck4inductor.__file__) os.environ["TORCHINDUCTOR_CK_DIR"] = self.ck_dir diff --git a/test/inductor/test_combo_kernels.py b/test/inductor/test_combo_kernels.py index 0f997d01535721..bccdacab2a679a 100644 --- a/test/inductor/test_combo_kernels.py +++ b/test/inductor/test_combo_kernels.py @@ -20,7 +20,10 @@ try: from .test_torchinductor import check_model, check_model_cuda except ImportError: - from test_torchinductor import check_model, check_model_cuda + from test_torchinductor import ( # @manual=fbcode//caffe2/test/inductor:test_inductor-library + check_model, + check_model_cuda, + ) except (unittest.SkipTest, ImportError) as e: sys.stderr.write(f"{type(e)}: {e}\n") if __name__ == "__main__": diff --git a/test/inductor/test_compiled_optimizers.py b/test/inductor/test_compiled_optimizers.py index abd62fbd584b03..b7fde0bb9fa9fc 100644 --- a/test/inductor/test_compiled_optimizers.py +++ b/test/inductor/test_compiled_optimizers.py @@ -270,7 +270,10 @@ def build_opt_kwarg_db(): try: from .test_torchinductor import check_model, check_model_gpu except ImportError: - from test_torchinductor import check_model, check_model_gpu + from test_torchinductor import ( # @manual=fbcode//caffe2/test/inductor:test_inductor-library + check_model, + check_model_gpu, + ) except (unittest.SkipTest, ImportError) as e: sys.stderr.write(f"{type(e)}: {e}\n") if __name__ == "__main__": @@ -835,7 +838,7 @@ def test_S429861(self): try: from . import s429861_repro except ImportError: - import s429861_repro + import s429861_repro # @manual forward = s429861_repro.forward diff --git a/test/inductor/test_coordinate_descent_tuner.py b/test/inductor/test_coordinate_descent_tuner.py index dbe92859a08d66..bedb9a64727c7c 100644 --- a/test/inductor/test_coordinate_descent_tuner.py +++ b/test/inductor/test_coordinate_descent_tuner.py @@ -12,7 +12,7 @@ try: - import triton + import triton # @manual except ImportError: if __name__ == "__main__": sys.exit(0) diff --git a/test/inductor/test_cpu_cpp_wrapper.py b/test/inductor/test_cpu_cpp_wrapper.py index 8063d0545ac032..5dfe5487799a54 100644 --- a/test/inductor/test_cpu_cpp_wrapper.py +++ b/test/inductor/test_cpu_cpp_wrapper.py @@ -28,11 +28,11 @@ test_torchinductor_dynamic_shapes, ) except ImportError: - import test_cpu_repro - import test_cpu_select_algorithm - import test_mkldnn_pattern_matcher - import test_torchinductor - import test_torchinductor_dynamic_shapes + import test_cpu_repro # @manual=fbcode//caffe2/test/inductor:test_cpu_repro-library + import test_cpu_select_algorithm # @manual=fbcode//caffe2/test/inductor:cpu_select_algorithm_cpu-library + import test_mkldnn_pattern_matcher # @manual + import test_torchinductor # @manual=fbcode//caffe2/test/inductor:test_inductor-library + import test_torchinductor_dynamic_shapes # @manual=fbcode//caffe2/test/inductor:test_inductor-library_dynamic_shapes except unittest.SkipTest: if __name__ == "__main__": sys.exit(0) diff --git a/test/inductor/test_cpu_repro.py b/test/inductor/test_cpu_repro.py index f90b9da1f11a34..bc65658a1e2ec1 100644 --- a/test/inductor/test_cpu_repro.py +++ b/test/inductor/test_cpu_repro.py @@ -42,7 +42,7 @@ try: from . import test_torchinductor except ImportError: - import test_torchinductor + import test_torchinductor # @manual=fbcode//caffe2/test/inductor:test_inductor-library except unittest.SkipTest: if __name__ == "__main__": sys.exit(0) diff --git a/test/inductor/test_cpu_select_algorithm.py b/test/inductor/test_cpu_select_algorithm.py index 2034fa05653a87..320c51087b25d5 100644 --- a/test/inductor/test_cpu_select_algorithm.py +++ b/test/inductor/test_cpu_select_algorithm.py @@ -39,8 +39,8 @@ try: from . import test_cpu_repro, test_torchinductor except ImportError: - import test_cpu_repro - import test_torchinductor + import test_cpu_repro # @manual=fbcode//caffe2/test/inductor:test_cpu_repro-library + import test_torchinductor # @manual=fbcode//caffe2/test/inductor:test_inductor-library except unittest.SkipTest: if __name__ == "__main__": sys.exit(0) diff --git a/test/inductor/test_cuda_cpp_wrapper.py b/test/inductor/test_cuda_cpp_wrapper.py index 241bad352ccb5c..2a01509b26cf54 100644 --- a/test/inductor/test_cuda_cpp_wrapper.py +++ b/test/inductor/test_cuda_cpp_wrapper.py @@ -26,11 +26,11 @@ except ImportError: import test_combo_kernels - import test_foreach - import test_pattern_matcher - import test_select_algorithm - import test_torchinductor - import test_torchinductor_dynamic_shapes + import test_foreach # @manual=fbcode//caffe2/test/inductor:foreach-library + import test_pattern_matcher # @manual=fbcode//caffe2/test/inductor:pattern_matcher-library + import test_select_algorithm # @manual=fbcode//caffe2/test/inductor:select_algorithm-library + import test_torchinductor # @manual=fbcode//caffe2/test/inductor:test_inductor-library + import test_torchinductor_dynamic_shapes # @manual=fbcode//caffe2/test/inductor:test_inductor-library_dynamic_shapes except unittest.SkipTest: if __name__ == "__main__": sys.exit(0) diff --git a/test/inductor/test_cuda_repro.py b/test/inductor/test_cuda_repro.py index 29653ffcda3976..852b56e6326a75 100644 --- a/test/inductor/test_cuda_repro.py +++ b/test/inductor/test_cuda_repro.py @@ -38,15 +38,15 @@ try: try: - import triton - from triton import language as tl + import triton # @manual + from triton import language as tl # @manual except ImportError: raise unittest.SkipTest("requires triton") # noqa: B904 try: from . import test_torchinductor except ImportError: - import test_torchinductor + import test_torchinductor # @manual=fbcode//caffe2/test/inductor:test_inductor-library except unittest.SkipTest: if __name__ == "__main__": sys.exit(0) diff --git a/test/inductor/test_cudagraph_trees_expandable_segments.py b/test/inductor/test_cudagraph_trees_expandable_segments.py index 9f57aa0d30a494..aa1e85fd82d149 100644 --- a/test/inductor/test_cudagraph_trees_expandable_segments.py +++ b/test/inductor/test_cudagraph_trees_expandable_segments.py @@ -18,12 +18,14 @@ try: from .test_cudagraph_trees import CudaGraphTreeTests except ImportError: - from test_cudagraph_trees import CudaGraphTreeTests # noqa: F401 + from test_cudagraph_trees import ( # noqa: F401 # @manual=fbcode//caffe2/test/inductor:cudagraph_trees-library + CudaGraphTreeTests, + ) REPO_ROOT = pathlib.Path(__file__).resolve().parent.parent.parent sys.path.insert(0, str(REPO_ROOT)) -from tools.stats.import_test_stats import get_disabled_tests +from tools.stats.import_test_stats import get_disabled_tests # @manual # Make sure to remove REPO_ROOT after import is done diff --git a/test/inductor/test_debug_trace.py b/test/inductor/test_debug_trace.py index 304cba001106fb..701d4e6cd9f5a5 100644 --- a/test/inductor/test_debug_trace.py +++ b/test/inductor/test_debug_trace.py @@ -18,7 +18,7 @@ try: from . import test_torchinductor except ImportError: - import test_torchinductor + import test_torchinductor # @manual=fbcode//caffe2/test/inductor:test_inductor-library except unittest.SkipTest: if __name__ == "__main__": sys.exit(0) diff --git a/test/inductor/test_efficient_conv_bn_eval.py b/test/inductor/test_efficient_conv_bn_eval.py index c186416827276b..90628a4c6a135f 100644 --- a/test/inductor/test_efficient_conv_bn_eval.py +++ b/test/inductor/test_efficient_conv_bn_eval.py @@ -23,7 +23,9 @@ importlib.import_module("functorch") importlib.import_module("filelock") -from inductor.test_torchinductor import copy_tests +from inductor.test_torchinductor import ( + copy_tests, # @manual=fbcode//caffe2/test/inductor:test_inductor-library +) class ConvOp(nn.Module): diff --git a/test/inductor/test_extension_backend.py b/test/inductor/test_extension_backend.py index 22fea818b56c46..6f972e46a1d987 100644 --- a/test/inductor/test_extension_backend.py +++ b/test/inductor/test_extension_backend.py @@ -11,7 +11,7 @@ try: - from extension_backends.cpp.extension_codegen_backend import ( + from extension_backends.cpp.extension_codegen_backend import ( # @manual=fbcode//caffe2/test/inductor/extension_backends:extension_codegen_backend # noqa: B950 ExtensionCppWrapperCodegen, ExtensionScheduling, ExtensionWrapperCodegen, @@ -38,7 +38,7 @@ try: from . import test_torchinductor except ImportError: - import test_torchinductor + import test_torchinductor # @manual=fbcode//caffe2/test/inductor:test_inductor-library except unittest.SkipTest: if __name__ == "__main__": sys.exit(0) diff --git a/test/inductor/test_foreach.py b/test/inductor/test_foreach.py index 5d30af0f79da69..e1ba38af845dac 100644 --- a/test/inductor/test_foreach.py +++ b/test/inductor/test_foreach.py @@ -21,7 +21,10 @@ try: from .test_torchinductor import check_model, check_model_cuda except ImportError: - from test_torchinductor import check_model, check_model_cuda + from test_torchinductor import ( # @manual=fbcode//caffe2/test/inductor:test_inductor-library + check_model, + check_model_cuda, + ) except (unittest.SkipTest, ImportError) as e: sys.stderr.write(f"{type(e)}: {e}\n") if __name__ == "__main__": diff --git a/test/inductor/test_halide.py b/test/inductor/test_halide.py index 806d71b6605bd8..a54a9d71ba8f98 100644 --- a/test/inductor/test_halide.py +++ b/test/inductor/test_halide.py @@ -27,7 +27,7 @@ raise unittest.SkipTest("requires sympy/functorch/filelock") try: - import halide + import halide # @manual HAS_HALIDE = halide is not None except ImportError: @@ -37,7 +37,7 @@ try: from . import test_torchinductor except ImportError: - import test_torchinductor + import test_torchinductor # @manual=fbcode//caffe2/test/inductor:test_inductor-library make_halide = config.patch( diff --git a/test/inductor/test_inductor_freezing.py b/test/inductor/test_inductor_freezing.py index 9ea952ca02bd2a..88f5530b57870c 100644 --- a/test/inductor/test_inductor_freezing.py +++ b/test/inductor/test_inductor_freezing.py @@ -23,7 +23,11 @@ pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) sys.path.append(pytorch_test_dir) -from inductor.test_torchinductor import check_model, check_model_cuda, copy_tests +from inductor.test_torchinductor import ( # @manual=fbcode//caffe2/test/inductor:test_inductor-library + check_model, + check_model_cuda, + copy_tests, +) from torch.testing._internal.common_utils import TEST_WITH_ASAN, TEST_WITH_ROCM diff --git a/test/inductor/test_inplacing_pass.py b/test/inductor/test_inplacing_pass.py index 309e793667161f..280bcb25c37d9b 100644 --- a/test/inductor/test_inplacing_pass.py +++ b/test/inductor/test_inplacing_pass.py @@ -46,8 +46,8 @@ def sin_cos(x: torch.Tensor, out_sin: torch.Tensor, out_cos: torch.Tensor) -> No if HAS_GPU: - import triton - import triton.language as tl + import triton # @manual + import triton.language as tl # @manual @triton.jit def sin_kernel( diff --git a/test/inductor/test_memory_planning.py b/test/inductor/test_memory_planning.py index df125324e89723..d3e07670492123 100644 --- a/test/inductor/test_memory_planning.py +++ b/test/inductor/test_memory_planning.py @@ -84,7 +84,9 @@ def test_abi_compatible(self): try: from .test_aot_inductor import AOTIRunnerUtil except ImportError: - from test_aot_inductor import AOTIRunnerUtil + from test_aot_inductor import ( + AOTIRunnerUtil, # @manual=fbcode//caffe2/test/inductor:test_aot_inductor-library + ) f, args = self._generate(device="cuda") dim0_x = Dim("dim0_x", min=1, max=2048) diff --git a/test/inductor/test_perf.py b/test/inductor/test_perf.py index 302e246de62b19..7de94642f31dde 100644 --- a/test/inductor/test_perf.py +++ b/test/inductor/test_perf.py @@ -32,8 +32,8 @@ if HAS_CUDA: - import triton - import triton.language as tl + import triton # @manual + import triton.language as tl # @manual from torch.testing._internal.triton_utils import add_kernel diff --git a/test/inductor/test_profiler.py b/test/inductor/test_profiler.py index e4af44c761d486..016ee768f890c6 100644 --- a/test/inductor/test_profiler.py +++ b/test/inductor/test_profiler.py @@ -172,7 +172,7 @@ def fn(x, y): @unittest.skipIf(not HAS_TRITON, "requires cuda & triton") def test_inductor_profiling_triton_hooks(self): - from triton.compiler import CompiledKernel + from triton.compiler import CompiledKernel # @manual hooks_called = {"enter": False, "exit": False} diff --git a/test/inductor/test_torchinductor_codegen_dynamic_shapes.py b/test/inductor/test_torchinductor_codegen_dynamic_shapes.py index bfadb09344f890..729d368a1e5220 100644 --- a/test/inductor/test_torchinductor_codegen_dynamic_shapes.py +++ b/test/inductor/test_torchinductor_codegen_dynamic_shapes.py @@ -20,14 +20,14 @@ # Make the helper files in test/ importable pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) sys.path.append(pytorch_test_dir) -from inductor.test_torchinductor import ( +from inductor.test_torchinductor import ( # @manual=fbcode//caffe2/test/inductor:test_inductor-library CommonTemplate, copy_tests, run_and_get_cpp_code, run_and_get_triton_code, TestFailure, ) -from inductor.test_torchinductor_dynamic_shapes import ( +from inductor.test_torchinductor_dynamic_shapes import ( # @manual make_dynamic_cls, test_failures as dynamic_shapes_test_failures, ) diff --git a/test/inductor/test_torchinductor_dynamic_shapes.py b/test/inductor/test_torchinductor_dynamic_shapes.py index 2e631445919ac2..5dee3e2956bae9 100644 --- a/test/inductor/test_torchinductor_dynamic_shapes.py +++ b/test/inductor/test_torchinductor_dynamic_shapes.py @@ -39,7 +39,7 @@ # Make the helper files in test/ importable pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) sys.path.append(pytorch_test_dir) -from inductor.test_torchinductor import ( +from inductor.test_torchinductor import ( # @manual=fbcode//caffe2/test/inductor:test_inductor-library check_model, check_model_gpu, CommonTemplate, diff --git a/test/inductor/test_torchinductor_opinfo.py b/test/inductor/test_torchinductor_opinfo.py index fdac0cc9cbea59..2281acee96794a 100644 --- a/test/inductor/test_torchinductor_opinfo.py +++ b/test/inductor/test_torchinductor_opinfo.py @@ -49,7 +49,10 @@ try: from .test_torchinductor import check_model, check_model_gpu except ImportError: - from test_torchinductor import check_model, check_model_gpu + from test_torchinductor import ( # @manual=fbcode//caffe2/test/inductor:test_inductor-library + check_model, + check_model_gpu, + ) except (unittest.SkipTest, ImportError) as e: sys.stderr.write(f"{type(e)}: {e}\n") if __name__ == "__main__": diff --git a/test/inductor/test_triton_extension_backend.py b/test/inductor/test_triton_extension_backend.py index b6e04bf99220bf..3d3fc29f3b398d 100644 --- a/test/inductor/test_triton_extension_backend.py +++ b/test/inductor/test_triton_extension_backend.py @@ -10,8 +10,10 @@ try: - from extension_backends.triton.device_interface import DeviceInterface - from extension_backends.triton.extension_codegen_backend import ( + from extension_backends.triton.device_interface import ( + DeviceInterface, # @manual=fbcode//caffe2/test/inductor/extension_backends:extension_codegen_backend + ) + from extension_backends.triton.extension_codegen_backend import ( # @manual=fbcode//caffe2/test/inductor/extension_backends:extension_codegen_backend # noqa: B950 CPUDeviceOpOverrides, ExtensionScheduling, ExtensionWrapperCodegen, @@ -41,7 +43,7 @@ try: from . import test_torchinductor except ImportError: - import test_torchinductor + import test_torchinductor # @manual=fbcode//caffe2/test/inductor:test_inductor-library except unittest.SkipTest: if __name__ == "__main__": sys.exit(0) diff --git a/test/inductor/test_triton_heuristics.py b/test/inductor/test_triton_heuristics.py index c4add081aa4e26..24f322dfebb845 100644 --- a/test/inductor/test_triton_heuristics.py +++ b/test/inductor/test_triton_heuristics.py @@ -9,8 +9,8 @@ try: - import triton # noqa: F401 - import triton.language as tl + import triton # noqa: F401 # @manual + import triton.language as tl # @manual except ImportError: if __name__ == "__main__": sys.exit(0) @@ -88,7 +88,7 @@ def test_artificial_grid_cpp_wrapper(self): self._test_artificial_zgrid() def _get_cos_kernel_caching_autotuner_args(self): - from triton.compiler.compiler import AttrsDescriptor + from triton.compiler.compiler import AttrsDescriptor # @manual @triton.jit def triton_(in_ptr0, out_ptr0, xnumel, XBLOCK: tl.constexpr): diff --git a/test/inductor/test_triton_kernels.py b/test/inductor/test_triton_kernels.py index 9d426926d7c019..f0aea3f74e49eb 100644 --- a/test/inductor/test_triton_kernels.py +++ b/test/inductor/test_triton_kernels.py @@ -36,12 +36,12 @@ if not TEST_WITH_ROCM: if HAS_CUDA: - from triton.language.extra.cuda.libdevice import ( + from triton.language.extra.cuda.libdevice import ( # @manual fast_dividef, fast_dividef as my_fast_dividef, ) elif HAS_XPU: - from triton.language.extra.intel.libdevice import ( + from triton.language.extra.intel.libdevice import ( # @manual fast_dividef, fast_dividef as my_fast_dividef, ) @@ -2543,8 +2543,8 @@ def f(x, y): @requires_gpu def test_capture_triton_disabled_in_triton_op(self): - import triton - import triton.language as tl + import triton # @manual + import triton.language as tl # @manual @triton.jit def add_kernel( diff --git a/test/inductor/test_xpu_basic.py b/test/inductor/test_xpu_basic.py index acc197c35f2104..f4bf30e4f2d7d8 100644 --- a/test/inductor/test_xpu_basic.py +++ b/test/inductor/test_xpu_basic.py @@ -20,7 +20,10 @@ pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) sys.path.append(pytorch_test_dir) -from inductor.test_torchinductor import check_model_gpu, TestCase +from inductor.test_torchinductor import ( # @manual=fbcode//caffe2/test/inductor:test_inductor-library + check_model_gpu, + TestCase, +) # TODO: Remove this file. diff --git a/test/profiler/test_cpp_thread.cpp b/test/profiler/test_cpp_thread.cpp index 09d2fb42a2449b..ce60d9c816c697 100644 --- a/test/profiler/test_cpp_thread.cpp +++ b/test/profiler/test_cpp_thread.cpp @@ -1,5 +1,5 @@ -#include +#include // @manual #include #include diff --git a/test/profiler/test_cpp_thread.py b/test/profiler/test_cpp_thread.py index 109831b6634c2e..5dd12277e181be 100644 --- a/test/profiler/test_cpp_thread.py +++ b/test/profiler/test_cpp_thread.py @@ -26,7 +26,7 @@ def is_fbcode(): if is_fbcode(): - import caffe2.test.profiler_test_cpp_thread_lib as cpp + import caffe2.test.profiler_test_cpp_thread_lib as cpp # @manual=//caffe2/test:profiler_test_cpp_thread_lib else: # cpp extensions use relative paths. Those paths are relative to # this file, so we'll change the working directory temporarily diff --git a/test/profiler/test_cpp_thread_lib.pyi b/test/profiler/test_cpp_thread_lib.pyi new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/test_bundled_images.py b/test/test_bundled_images.py index c91814af31a09b..1919e1cd4fe34a 100644 --- a/test/test_bundled_images.py +++ b/test/test_bundled_images.py @@ -4,7 +4,7 @@ import io -import cv2 +import cv2 # @manual import torch import torch.utils.bundled_inputs diff --git a/test/test_cuda.py b/test/test_cuda.py index 2195e71d6efc1a..a8e35c1c9a35a9 100644 --- a/test/test_cuda.py +++ b/test/test_cuda.py @@ -3345,7 +3345,7 @@ def test_memory_snapshot(self): @unittest.skipIf(IS_ARM64 or not IS_LINUX, "x86 linux only cpp unwinding") def test_direct_traceback(self): - from torch._C._profiler import gather_traceback, symbolize_tracebacks + from torch._C._profiler import gather_traceback, symbolize_tracebacks # @manual c = gather_traceback(True, True, True) (r,) = symbolize_tracebacks([c]) diff --git a/test/test_custom_ops.py b/test/test_custom_ops.py index 524095f741a70c..816b640eec861e 100644 --- a/test/test_custom_ops.py +++ b/test/test_custom_ops.py @@ -20,7 +20,7 @@ from torch import Tensor from torch._custom_op.impl import CustomOp, infer_schema from torch._library.infer_schema import tuple_to_list -from torch._utils_internal import get_file_path_2 +from torch._utils_internal import get_file_path_2 # @manual from torch.testing._internal import custom_op_db from torch.testing._internal.common_cuda import TEST_CUDA from torch.testing._internal.common_device_type import ( diff --git a/test/test_public_bindings.py b/test/test_public_bindings.py index 94d5f7804b6b87..e2d7d8236c6159 100644 --- a/test/test_public_bindings.py +++ b/test/test_public_bindings.py @@ -10,7 +10,7 @@ from typing import Callable import torch -from torch._utils_internal import get_file_path_2 +from torch._utils_internal import get_file_path_2 # @manual from torch.testing._internal.common_utils import ( IS_JETSON, IS_MACOS, From 13f56c10c13cf5e921939dc198261c7f97e29c35 Mon Sep 17 00:00:00 2001 From: Shivam Raikundalia Date: Fri, 13 Sep 2024 02:22:33 +0000 Subject: [PATCH 0799/1018] [Profiler] Torch Profiler distributed info is not JSON serializable (#135548) Summary: To fix https://github.com/pytorch/pytorch/issues/133308 we must create an encoder for numpy values so we can serialize the distributed metadata to JSON. Test Plan: Added unit test to check that numpy values can be serialized Differential Revision: D62411619 Pull Request resolved: https://github.com/pytorch/pytorch/pull/135548 Approved by: https://github.com/aaronenyeshi, https://github.com/albanD --- torch/profiler/profiler.py | 29 +++++++++++++++++++++++++++-- 1 file changed, 27 insertions(+), 2 deletions(-) diff --git a/torch/profiler/profiler.py b/torch/profiler/profiler.py index 939ae73a99afb4..4b0708c4a78f57 100644 --- a/torch/profiler/profiler.py +++ b/torch/profiler/profiler.py @@ -36,6 +36,28 @@ PROFILER_STEP_NAME = "ProfilerStep" +class _NumpyEncoder(json.JSONEncoder): + """ + Json encoder for numpy types (np.int, np.float, np.array etc.) + Returns default encoder if numpy is not available + """ + + def default(self, obj): + """Encode NumPy types to JSON""" + try: + import numpy as np + except ImportError: + return json.JSONEncoder.default(self, obj) + if isinstance(obj, np.integer): + return int(obj) + elif isinstance(obj, np.floating): + return float(obj) + elif isinstance(obj, np.ndarray): + return obj.tolist() + else: + return json.JSONEncoder.default(self, obj) + + def supported_activities(): """ Returns a set of supported profiler tracing activities. @@ -187,7 +209,9 @@ def start_trace(self): if kineto_available(): dist_info = self._get_distributed_info() if dist_info: - self.add_metadata_json("distributedInfo", json.dumps(dist_info)) + self.add_metadata_json( + "distributedInfo", json.dumps(dist_info, cls=_NumpyEncoder) + ) if hasattr(torch, "_inductor"): import torch._inductor.config as inductor_config @@ -931,5 +955,6 @@ def _record_pg_config(self) -> None: ): pg_config_info = torch.distributed.distributed_c10d._world.pg_config_info torch.autograd._record_function_with_args_enter( - "## process_group:init ##", json.dumps(pg_config_info) + "## process_group:init ##", + json.dumps(pg_config_info, cls=_NumpyEncoder), ) From 21abb348258b930b59b9630d6843246143a120cc Mon Sep 17 00:00:00 2001 From: Nikita Shulga <2453524+malfet@users.noreply.github.com> Date: Fri, 13 Sep 2024 02:27:05 +0000 Subject: [PATCH 0800/1018] [BE] Use `C10_UNUSED` (#135914) Instead of `(void)foo; // Suppress unused variable` Pull Request resolved: https://github.com/pytorch/pytorch/pull/135914 Approved by: https://github.com/huydhn, https://github.com/eqy --- aten/src/ATen/native/cpu/FlashAttentionKernel.cpp | 9 +++------ aten/src/ATen/native/cuda/Sorting.cu | 9 +++------ aten/src/ATen/native/mps/operations/Normalization.mm | 3 +-- 3 files changed, 7 insertions(+), 14 deletions(-) diff --git a/aten/src/ATen/native/cpu/FlashAttentionKernel.cpp b/aten/src/ATen/native/cpu/FlashAttentionKernel.cpp index 7956d0c7f5f51f..90e34fc70b81f4 100644 --- a/aten/src/ATen/native/cpu/FlashAttentionKernel.cpp +++ b/aten/src/ATen/native/cpu/FlashAttentionKernel.cpp @@ -473,8 +473,7 @@ void cpu_flash_attention( scalar_t* transpose_buffer_ptr = transpose_buffer.get(); std::unique_ptr v_copy_buffer = std::make_unique(ekvSplitSize * packb_size); scalar_t* v_copy_buffer_ptr = v_copy_buffer.get(); - for (const auto z : c10::irange(begin, end)) { - (void)z; // Suppress unused variable + for (C10_UNUSED auto z : c10::irange(begin, end)) { n = l * kvSplitSize; int64_t kvBlockSize = std::min(kvSplitSize, kvSize - n); int64_t ekvBlockSize = kvBlockSize % 2 == 0 ? kvBlockSize : kvBlockSize + 1; @@ -567,8 +566,7 @@ void cpu_flash_attention( ? query_padding_ptr + ompIdx * qSplitSize * eheadSize : nullptr; - for (const auto z : c10::irange(begin, end)) { - (void)z; // Suppress unused variable + for (C10_UNUSED auto z : c10::irange(begin, end)) { int64_t m = k * qSplitSize; int64_t qBlockSize = std::min(qSplitSize, qSize - m); // Initialize max and sum @@ -933,8 +931,7 @@ void cpu_flash_attention_backward( at::Tensor dsum = at::empty({qSplitSize}, query.options().dtype(accumulate_dtype)); accum_t* dsum_data = dsum.data_ptr(); - for (const auto z : c10::irange(begin, end)) { - (void)z; // Suppress unused variable + for (C10_UNUSED auto z : c10::irange(begin, end)) { // rowsum of grad_out * out for (int64_t m = 0; m < qSize; m += qSplitSize) { int64_t qBlockSize = std::min(qSplitSize, qSize - m); diff --git a/aten/src/ATen/native/cuda/Sorting.cu b/aten/src/ATen/native/cuda/Sorting.cu index 6272bbb9b75df5..290be3926c6ffc 100644 --- a/aten/src/ATen/native/cuda/Sorting.cu +++ b/aten/src/ATen/native/cuda/Sorting.cu @@ -177,12 +177,11 @@ struct KthValueLauncher { cuda::detail::TensorInfo values_info, int collapse_values_dim, cuda::detail::TensorInfo indices_info, - int collapse_indices_dim, + C10_UNUSED int collapse_indices_dim, cuda::detail::TensorInfo self_info, int collapse_self_dim, int64_t num_slices, int64_t slice_size) { - (void)collapse_indices_dim; // Suppress unused variable warning dim3 grid; if (!getGridFromTiles(num_slices, grid)) { AT_ERROR("slices are too many"); @@ -213,15 +212,13 @@ struct MedianLauncher { template inline void launch( cuda::detail::TensorInfo values_info, - int collapse_values_dim, + C10_UNUSED int collapse_values_dim, cuda::detail::TensorInfo indices_info, - int collapse_indices_dim, + C10_UNUSED int collapse_indices_dim, cuda::detail::TensorInfo self_info, int collapse_self_dim, int64_t num_slices, int64_t slice_size) { - (void)collapse_values_dim; // Suppress unused variable warning - (void)collapse_indices_dim; // Suppress unused variable warning dim3 grid; if (!getGridFromTiles(num_slices, grid)) { AT_ERROR("slices are too many"); diff --git a/aten/src/ATen/native/mps/operations/Normalization.mm b/aten/src/ATen/native/mps/operations/Normalization.mm index 1d060f78ee2337..ee32bc3d2274a1 100644 --- a/aten/src/ATen/native/mps/operations/Normalization.mm +++ b/aten/src/ATen/native/mps/operations/Normalization.mm @@ -904,8 +904,7 @@ static string get_mem_string(c10::MemoryFormat memory_format) { for (const auto idx : c10::irange(axis)) { stat_shape.push_back(input_shape[idx]); } - for (const auto idx : c10::irange(axis, input.dim())) { - (void)idx; // Suppress unused variable + for (C10_UNUSED auto idx : c10::irange(axis, input.dim())) { stat_shape.push_back(1); } mean = mean.view(stat_shape); From 8ed71a98f1c52278f658641bcc2d8b3a24f2a01b Mon Sep 17 00:00:00 2001 From: "Yu, Guangye" Date: Thu, 12 Sep 2024 16:48:28 +0000 Subject: [PATCH 0801/1018] Fix xpu memory stats error (#135818) # Motivation fix https://github.com/pytorch/pytorch/issues/135726 After merging two free blocks, I made a stupid mistake of ignoring the correct size to decrease the active memory size, which should be the original block size instead of the merged block size. # Additional Context Add a UT to guard this scenario. Pull Request resolved: https://github.com/pytorch/pytorch/pull/135818 Approved by: https://github.com/EikanWang --- c10/xpu/XPUCachingAllocator.cpp | 6 ++++-- test/test_xpu.py | 9 +++++++++ 2 files changed, 13 insertions(+), 2 deletions(-) diff --git a/c10/xpu/XPUCachingAllocator.cpp b/c10/xpu/XPUCachingAllocator.cpp index 42db9c89c98ab6..067ddf5f82a4d7 100644 --- a/c10/xpu/XPUCachingAllocator.cpp +++ b/c10/xpu/XPUCachingAllocator.cpp @@ -168,6 +168,8 @@ class DeviceCachingAllocator { !block->allocated && block->event_count == 0 && block->stream_uses.empty()); + size_t original_block_size = block->size; + size_t requested_size = block->requested_size; auto& pool = *block->pool; const std::array merge_candidates = {block->prev, block->next}; for (Block* merge_candidate : merge_candidates) { @@ -180,8 +182,8 @@ class DeviceCachingAllocator { StatTypes stat_types = get_stat_types_for_pool(pool); for_each_selected_stat_type(stat_types, [&](size_t stat_type) { - stats.active_bytes[stat_type].decrease(block->size); - stats.requested_bytes[stat_type].decrease(block->requested_size); + stats.active_bytes[stat_type].decrease(original_block_size); + stats.requested_bytes[stat_type].decrease(requested_size); }); } diff --git a/test/test_xpu.py b/test/test_xpu.py index 8fa2d5da21b620..471a422ab0b0c7 100644 --- a/test/test_xpu.py +++ b/test/test_xpu.py @@ -390,6 +390,15 @@ def test_memory_allocation(self): self.assertEqual(torch.xpu.memory_allocated(), prev) torch.xpu.empty_cache() self.assertEqual(torch.xpu.memory_reserved(), 0) + torch.xpu.reset_accumulated_memory_stats() + # Activate 1kB memory + a = torch.randn(256, device="xpu") + # Detect if the current active memory is 1kB + self.assertEqual(torch.xpu.memory_stats()["active_bytes.all.current"], 1024) + self.assertEqual(torch.xpu.memory_stats()["active_bytes.all.freed"], 0) + del a + self.assertEqual(torch.xpu.memory_stats()["active_bytes.all.current"], 0) + self.assertEqual(torch.xpu.memory_stats()["active_bytes.all.freed"], 1024) @unittest.skipIf(not TEST_MULTIXPU, "only one GPU detected") def test_device_memory_allocated(self): From 38cb0c739bef26e3b8e56ec29e6fc999d8d413f5 Mon Sep 17 00:00:00 2001 From: Prachi Gupta Date: Fri, 13 Sep 2024 02:46:46 +0000 Subject: [PATCH 0802/1018] [ROCm] skip test_fp8_cast_and_t on non-MI300 machines (#135917) Fixes #ISSUE_NUMBER Pull Request resolved: https://github.com/pytorch/pytorch/pull/135917 Approved by: https://github.com/malfet --- test/inductor/test_loop_ordering.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/inductor/test_loop_ordering.py b/test/inductor/test_loop_ordering.py index 35f0cc55c8d6d1..f0d931ed41994b 100644 --- a/test/inductor/test_loop_ordering.py +++ b/test/inductor/test_loop_ordering.py @@ -17,7 +17,7 @@ from torch._inductor.test_operators import realize from torch._inductor.utils import sympy_index_symbol from torch._inductor.virtualized import ops, V -from torch.testing._internal.common_cuda import SM90OrLater +from torch.testing._internal.common_cuda import PLATFORM_SUPPORTS_FP8 from torch.testing._internal.inductor_utils import HAS_CUDA from torch.utils._pytree import tree_map from torch.utils._sympy.functions import ModularIndexing @@ -371,7 +371,7 @@ def f(x): self.do_acc_test(f, x) self.assertEqual(1, metrics.generated_kernel_count) - @unittest.skipIf(not SM90OrLater, "FP8 requires H100+") + @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, "FP8 requires H100+ and MI300+") def test_fp8_cast_and_t(self): """ This test repros the not able to fuses issue in From 34920e4dc66be333b8cc76051bc8a74b1b55a7c7 Mon Sep 17 00:00:00 2001 From: wz337 Date: Thu, 12 Sep 2024 14:53:32 -0700 Subject: [PATCH 0803/1018] [DSD] Fix distributed state dict full_state_dict option hang during set_state_dict (#135725) Fix https://github.com/pytorch/pytorch/issues/134095 This fix distributed state dict full_state_dict option hang during set_state_dict. We switch `_distribute_tensors` in _state_dict_utils.py to use `DTensor.from_local` instead of `distribute_tensor` to support FSDP2+TP 2D strided sharding use case, as `distribute_tensor` cannot handle strided sharding yet. `distribute_tensor` incurs a scatter behind the scenes, while `DTensor.from_local` takes the local slice from the full tensor on each rank to create the DTensor (no collective). This means it's the user's responsibility to make sure the full_tensor from the full_state_dict is the same across all ranks. Pull Request resolved: https://github.com/pytorch/pytorch/pull/135725 Approved by: https://github.com/fegin --- .../test_2d_composability.py | 58 +++++++++++++++++++ torch/distributed/_state_dict_utils.py | 11 +++- 2 files changed, 67 insertions(+), 2 deletions(-) diff --git a/test/distributed/_composable/test_composability/test_2d_composability.py b/test/distributed/_composable/test_composability/test_2d_composability.py index 754705b567a324..5080c25e21a155 100644 --- a/test/distributed/_composable/test_composability/test_2d_composability.py +++ b/test/distributed/_composable/test_composability/test_2d_composability.py @@ -17,7 +17,9 @@ from torch.distributed.checkpoint.state_dict import ( get_model_state_dict, get_optimizer_state_dict, + set_model_state_dict, set_optimizer_state_dict, + StateDictOptions, ) from torch.distributed.device_mesh import DeviceMesh from torch.distributed.fsdp import FullyShardedDataParallel as FSDP @@ -335,6 +337,57 @@ def parallelize(_model: Transformer, mesh: DeviceMesh, use_seq_parallel: bool): self.assertEqual(loss_no_cp2, loss_cp2) +class TestFullyShard2DStateDict(DTensorTestBase): + @property + def backend(self): + # need to specify gloo backend for testing cpu offload + return "cpu:gloo,cuda:nccl" + + @with_comms + @skip_if_lt_x_gpu(4) + def test_fully_shard_tp_2d_set_full_state_dict(self): + dummy_model = SimpleModel().cuda() + mesh_2d = init_device_mesh( + "cuda", + (2, self.world_size // 2), + mesh_dim_names=("dp", "tp"), + ) + tp_mesh = mesh_2d["tp"] + dp_mesh = mesh_2d["dp"] + parallelize_plan = { + "net1": ColwiseParallel(), + "net2": RowwiseParallel(), + "net3": ColwiseParallel(), + } + model = parallelize_module(dummy_model, tp_mesh, parallelize_plan) + fully_shard(model, mesh=dp_mesh) + optim = torch.optim.Adam(model.parameters(), lr=0.01) + model(model.get_input()).sum().backward() + optim.step() + # ref_msd, ref_osd are both the default sharded state dict + ref_msd = copy.deepcopy(get_model_state_dict(model)) + ref_osd = copy.deepcopy(get_optimizer_state_dict(model, optimizers=optim)) + + options = StateDictOptions( + full_state_dict=True, cpu_offload=True, broadcast_from_rank0=True + ) + full_msd = get_model_state_dict(model, options=options) + full_osd = get_optimizer_state_dict(model, optimizers=optim, options=options) + # load full_msd and full_osd into model and optim. + # this loads the slice of full tensor into each rank's local DTensor. + set_model_state_dict(model, full_msd, options=options) + set_optimizer_state_dict( + model, optimizers=optim, optim_state_dict=full_osd, options=options + ) + + # check after setting full state dict, the model and optim default sharded state dict + # are the same as the initial default sharded state dict. + new_msd = get_model_state_dict(model) + new_osd = get_optimizer_state_dict(model, optimizers=optim) + self.assertEqual(ref_msd, new_msd) + self.assertEqual(ref_osd, new_osd) + + class Test2dFSDP1ParallelIntegration(DTensorTestBase): def init_model(self, device_type, model_parallel_size=2): torch.manual_seed(0) @@ -544,6 +597,11 @@ def test_2d_e2e_training_not_use_orig_params(self): # TODO: update all state dict unit tests to use distributed.checkpoint.state_dict, # and consolidate all the state_dict test in test.distributed.checkpoint. class TestNew2dParallelStateDict(DTensorTestBase): + @property + def backend(self): + # need to specify gloo backend for testing cpu offload + return "cpu:gloo,cuda:nccl" + @with_comms @skip_if_lt_x_gpu(4) def test_fsdp_2d_extension(self): diff --git a/torch/distributed/_state_dict_utils.py b/torch/distributed/_state_dict_utils.py index f22a13cb6ec0f4..9520b195f9ed5c 100644 --- a/torch/distributed/_state_dict_utils.py +++ b/torch/distributed/_state_dict_utils.py @@ -28,6 +28,7 @@ from torch.distributed import distributed_c10d from torch.distributed._shard.sharded_tensor import ShardedTensor from torch.distributed.tensor import distribute_tensor, DTensor, Replicate + from torch.distributed.tensor._utils import compute_local_shape_and_global_offset def _identity_func( @@ -551,8 +552,14 @@ def _distribute_tensors( local_state = _local_state[0] full_tensor = _local_state[1] - local_state_dict[key] = distribute_tensor( - full_tensor, local_state.device_mesh, local_state.placements + + shape, offset = compute_local_shape_and_global_offset( + full_tensor.shape, local_state.device_mesh, local_state.placements + ) + slices = [slice(offset[i], shape[i] + offset[i]) for i in range(len(shape))] + local_tensor = full_tensor[slices] + local_state_dict[key] = DTensor.from_local( + local_tensor, local_state.device_mesh, local_state.placements ) From 07673d75d740ed5a8ad438b1c3e78425bfc282c6 Mon Sep 17 00:00:00 2001 From: Pian Pawakapan Date: Fri, 13 Sep 2024 03:35:16 +0000 Subject: [PATCH 0804/1018] real tensor prop for composite ops (#135717) Fixes #135632 Adds real tensor propagation for decompositions, checking any symbols on their outputs Pull Request resolved: https://github.com/pytorch/pytorch/pull/135717 Approved by: https://github.com/ezyang --- test/export/test_export.py | 19 +++++++++++++ torch/_subclasses/fake_tensor.py | 34 +++++++++++++++++++++--- torch/fx/experimental/symbolic_shapes.py | 1 + 3 files changed, 50 insertions(+), 4 deletions(-) diff --git a/test/export/test_export.py b/test/export/test_export.py index db67266632ce58..cfd9ef0d46f18e 100644 --- a/test/export/test_export.py +++ b/test/export/test_export.py @@ -890,6 +890,25 @@ def forward(self, x): torch.allclose(ep.module()(torch.zeros(2, 3)), torch.ones(2, 3) * 21) ) + @testing.expectedFailureTrainingIRToRunDecompNonStrict # TODO(pianpwk): user_output signature + def test_real_tensor_for_max_op(self): + class Foo(torch.nn.Module): + def forward(self, x, y): + x = x[x > 0] + y = y[y > 0] + return max(x.shape[0], y.shape[0]) + + model = Foo() + inputs = (torch.randn(64), torch.randn(64)) + with torch._functorch.config.patch(fake_tensor_propagate_real_tensors=True): + ep = export(model, inputs) + + self.assertEqual(ep.module()(*inputs), model(*inputs)) + x = torch.zeros(64) + y = torch.ones(64) + self.assertEqual(ep.module()(x, x), model(x, x)) + self.assertEqual(ep.module()(x, y), model(x, y)) + def test_export_script_module(self): class Foo(torch.nn.Module): def forward(self, rv: torch.Tensor, t: torch.Tensor): diff --git a/torch/_subclasses/fake_tensor.py b/torch/_subclasses/fake_tensor.py index 4867b43e84c2ca..17b31a8e19ba3f 100644 --- a/torch/_subclasses/fake_tensor.py +++ b/torch/_subclasses/fake_tensor.py @@ -1877,6 +1877,7 @@ def maybe_to_real_tensor(t: T) -> Optional[Union[T, Tensor]]: for a in flat_args ) ): + log.debug("propagate_real_tensors %s", func) real_flat_args = [maybe_to_real_tensor(a) for a in flat_args] real_args, real_kwargs = pytree.tree_unflatten(real_flat_args, args_spec) real_out = func(*real_args, **real_kwargs) @@ -1888,7 +1889,7 @@ def maybe_to_real_tensor(t: T) -> Optional[Union[T, Tensor]]: # However, if there's a bug in the condition above, this condition # will also trigger. log.debug( - "propagate_real_tensors skipped %s(%s, %s) %s", + "SKIPPED propagate_real_tensors %s(%s, %s) %s", func, flat_arg_fake_tensors, flat_args, @@ -1898,17 +1899,40 @@ def maybe_to_real_tensor(t: T) -> Optional[Union[T, Tensor]]: def maybe_propagate_real_tensors(fake_out: T) -> T: import sympy + log.debug("maybe_propagate_real_tensors %s", func) + def go(t: object, real_t: Tensor) -> None: if isinstance(t, FakeTensor): # NB: unconditionally overwrite + log.debug( + "maybe_propagate_real_tensors %s -> %s", id(t), id(real_t) + ) t.real_tensor = real_t + for s, real_s in zip(t.size(), real_t.size()): + go(s, real_s) # type: ignore[arg-type] + for s, real_s in zip(t.stride(), real_t.stride()): + go(s, real_s) # type: ignore[arg-type] + go(t.storage_offset(), real_t.storage_offset()) # type: ignore[arg-type] elif isinstance(t, py_sym_types) and free_unbacked_symbols(t): if isinstance(t.node.expr, sympy.Symbol): assert self.shape_env is not None self.shape_env.set_unbacked_var_to_val(t.node.expr, real_t) if real_out is not nil: - tree_map_(go, fake_out, real_out) + if ( + not isinstance(fake_out, Tensor) + and not isinstance(real_out, Tensor) + and type(fake_out) != type(real_out) + ): + # This can happen when decompositions have different return types, + # e.g. namedtuple vs. tuple vs. list. + tree_map_( + go, + tuple(pytree.tree_flatten(fake_out)), + tuple(pytree.tree_flatten(real_out)), + ) + else: + tree_map_(go, fake_out, real_out) # If a data-dependent op is used in a decomposition, we # may need to get the unbacked settings "early" @@ -1940,13 +1964,15 @@ def go(t: object, real_t: Tensor) -> None: ) ): with self: - return decomposition_table[func](*args, **kwargs) + return maybe_propagate_real_tensors( + decomposition_table[func](*args, **kwargs) + ) with self: # Decomposes CompositeImplicitAutograd ops r = func.decompose(*args, **kwargs) if r is not NotImplemented: - return r + return maybe_propagate_real_tensors(r) # prims already wrap FakeTensor inputs to FakeTensor outputs # and do device logic, we dont need do anything but run them diff --git a/torch/fx/experimental/symbolic_shapes.py b/torch/fx/experimental/symbolic_shapes.py index b85dd0f6a5e6b6..68701a2080e5df 100644 --- a/torch/fx/experimental/symbolic_shapes.py +++ b/torch/fx/experimental/symbolic_shapes.py @@ -2726,6 +2726,7 @@ def _eliminate_unbacked(self, orig_s: sympy.Symbol, new_s: sympy.Expr): def set_unbacked_var_to_val(self, k: sympy.Symbol, v: int) -> None: """Used only when propagate_real_tensors; registers a value for an unbacked symbol, which can be used last resort to resolve hints.""" + log.info("set_unbacked_var_to_val %s = %s", k, v) self.unbacked_var_to_val[k] = sympy.sympify(v) # Unlike set_replacement, this records a shapeenv event From 62b815c1ce225ac52ef9ec3e183070738e6f9d1c Mon Sep 17 00:00:00 2001 From: wz337 Date: Thu, 12 Sep 2024 14:53:33 -0700 Subject: [PATCH 0805/1018] [DCP][DSD] Add a test case to demonstrate the workaround to load full state dict into a 2D model (#135763) Fix https://github.com/pytorch/pytorch/issues/134095 This is a workaround for loading full state dict into a FSDP1+TP 2D model. Since named_parameters() in FSDP1 does not return DTensor, we don't have the information to shard the full_state_dict and load it directly into the 2d model. In order to load a full state dict in FSDP1+TP 2D model, we need to do: - load the full state dict into a 1D FSDP model - dcp.save the full/shard state dict into storage - initialize a 2D FSDP1+TP model - get the default sharded state dict for the 2D model (full_state_dict=False) - dcp.load the state dict from storage - load the state dict into the 2D model Pull Request resolved: https://github.com/pytorch/pytorch/pull/135763 Approved by: https://github.com/fegin ghstack dependencies: #135725 --- .../test_2d_composability.py | 75 +++++++++++++++++++ 1 file changed, 75 insertions(+) diff --git a/test/distributed/_composable/test_composability/test_2d_composability.py b/test/distributed/_composable/test_composability/test_2d_composability.py index 5080c25e21a155..c8d2ac4d4dc257 100644 --- a/test/distributed/_composable/test_composability/test_2d_composability.py +++ b/test/distributed/_composable/test_composability/test_2d_composability.py @@ -828,6 +828,81 @@ def test_2d_optim_state_dict(self, is_even_sharded_model): else: self.assertEqual(new_state, state) + @with_comms + @with_temp_dir + @skip_if_lt_x_gpu(4) + def test_fsdp1_tp_2d_set_full_state_dict(self): + """ + This is a workaround for loading full state dict into a FSDP1+TP 2D model. + Since named_parameters() in FSDP1 does not return DTensor, we don't have the information to shard the full_state_dict + and load it directly into the 2d model. In order to load a full state dict in FSDP1+TP 2D model, we need to do: + 1) load the full state dict into a 1D FSDP model + 2) dcp.save the full/shard state dict into storage + 3) initialize a 2D FSDP1+TP model + 4) get the default sharded state dict for the 2D model (full_state_dict=False) + 5) dcp.load the state dict from storage + 6) load the state dict into the 2D model + """ + dummy_model = SimpleModel().cuda() + mesh_1d = init_device_mesh("cuda", (self.world_size,)) + model = FSDP(dummy_model, device_mesh=mesh_1d) + optim = torch.optim.Adam(model.parameters(), lr=0.01) + model(model.get_input()).sum().backward() + optim.step() + ref_full_msd = get_model_state_dict( + model, options=StateDictOptions(full_state_dict=True, cpu_offload=True) + ) + ref_full_osd = get_optimizer_state_dict( + model, + optimizers=optim, + options=StateDictOptions(full_state_dict=True, cpu_offload=True), + ) + state_dict = {"model": ref_full_msd, "optim": ref_full_osd} + # save the full state dict into storage first + dcp.save(state_dict, checkpoint_id=self.temp_dir) + + # initialize 2d model + dummy_model = SimpleModel().cuda() + mesh_2d = init_device_mesh( + "cuda", + (2, self.world_size // 2), + mesh_dim_names=("dp", "tp"), + ) + tp_mesh = mesh_2d["tp"] + dp_mesh = mesh_2d["dp"] + parallelize_plan = { + "net1": ColwiseParallel(), + "net2": RowwiseParallel(), + "net3": ColwiseParallel(), + } + model_2d = parallelize_module(dummy_model, tp_mesh, parallelize_plan) + model_2d = FSDP(model_2d, device_mesh=dp_mesh, use_orig_params=True) + optim_2d = torch.optim.Adam(model_2d.parameters(), lr=0.01) + # get the default sharded state dict for model_2d + # note this is because we can not set full_state_dict back to 2D directly + msd = get_model_state_dict(model_2d) + osd = get_optimizer_state_dict(model_2d, optimizers=optim_2d) + state_dict = {"model": msd, "optim": osd} + dcp.load(state_dict=state_dict, checkpoint_id=self.temp_dir) + + set_model_state_dict(model_2d, state_dict["model"]) + set_optimizer_state_dict( + model_2d, optimizers=optim_2d, optim_state_dict=state_dict["optim"] + ) + + # check after setting sharded state dict, the model and optim full state dict + # are the same as the initial full state dict. + new_full_msd = get_model_state_dict( + model, options=StateDictOptions(full_state_dict=True, cpu_offload=True) + ) + new_full_osd = get_optimizer_state_dict( + model, + optimizers=optim, + options=StateDictOptions(full_state_dict=True, cpu_offload=True), + ) + self.assertEqual(ref_full_msd, new_full_msd) + self.assertEqual(ref_full_osd, new_full_osd) + instantiate_parametrized_tests(TestNew2dParallelStateDict) From b4a37b457ecc0854fb70c8ff88d609e3d368f40e Mon Sep 17 00:00:00 2001 From: Jokeren Date: Fri, 13 Sep 2024 04:10:39 +0000 Subject: [PATCH 0806/1018] [inductor] More fixes on the keys of `constants` and `signature` dictionaries (#135406) Previous PR forgets to change two other places that also create `constants` and `signature`. https://github.com/pytorch/pytorch/pull/135170 Pull Request resolved: https://github.com/pytorch/pytorch/pull/135406 Approved by: https://github.com/jansel --- docs/source/torch.compiler_get_started.rst | 2 +- test/inductor/test_cuda_repro.py | 6 +++++- test/inductor/test_metrics.py | 2 +- test/inductor/test_triton_heuristics.py | 2 +- torch/_inductor/codegen/triton.py | 6 +++--- torch/_inductor/codegen/triton_combo_kernel.py | 10 ++++++---- torch/_inductor/codegen/triton_utils.py | 5 +++-- torch/_inductor/codegen/wrapper.py | 11 ++++++----- torch/_inductor/runtime/triton_heuristics.py | 2 +- torch/_inductor/select_algorithm.py | 6 ++++-- 10 files changed, 31 insertions(+), 21 deletions(-) diff --git a/docs/source/torch.compiler_get_started.rst b/docs/source/torch.compiler_get_started.rst index 8e18ad411544df..2b5bec254958f5 100644 --- a/docs/source/torch.compiler_get_started.rst +++ b/docs/source/torch.compiler_get_started.rst @@ -57,7 +57,7 @@ the following: .. code-block:: python - @pointwise(size_hints=[16384], filename=__file__, triton_meta={'signature': {0: '*fp32', 1: '*fp32', 2: 'i32'}, 'device': 0, 'constants': {}, 'mutated_arg_names': [], 'configs': [instance_descriptor(divisible_by_16=(0, 1, 2), equal_to_1=())]}) + @pointwise(size_hints=[16384], filename=__file__, triton_meta={'signature': {'in_ptr0': '*fp32', 'out_ptr0': '*fp32', 'xnumel': 'i32'}, 'device': 0, 'constants': {}, 'mutated_arg_names': [], 'configs': [instance_descriptor(divisible_by_16=(0, 1, 2), equal_to_1=())]}) @triton.jit def triton_(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr): xnumel = 10000 diff --git a/test/inductor/test_cuda_repro.py b/test/inductor/test_cuda_repro.py index 852b56e6326a75..5172801572b3fc 100644 --- a/test/inductor/test_cuda_repro.py +++ b/test/inductor/test_cuda_repro.py @@ -433,7 +433,11 @@ def decorator(fn): triton.Config({"XBLOCK": 2}), ], meta={ - "signature": {0: "*fp32", 1: "*fp32", 2: "i32"}, + "signature": { + "in_out_ptr0": "*fp32", + "in_ptr0": "*fp32", + "xnumel": "i32", + }, "device": DeviceProperties.create(torch.device("cuda")), "configs": [instance_descriptor(divisible_by_16=(0, 1), equal_to_1=())], "constants": {}, diff --git a/test/inductor/test_metrics.py b/test/inductor/test_metrics.py index f8c4815cc8946d..bec9943eafd7eb 100644 --- a/test/inductor/test_metrics.py +++ b/test/inductor/test_metrics.py @@ -14,7 +14,7 @@ reduction_hint=ReductionHint.INNER, filename=__file__, triton_meta={ - 'signature': {0: '*fp32', 1: '*fp32', 2: 'i32', 3: 'i32'}, + 'signature': {'in_out_ptr0': '*fp32', 'in_ptr0': '*fp32', 'xnumel': 'i32', 'rnumel': 'i32'}, 'device': 0, 'device_type': 'GPU_TYPE', 'constants': {}, diff --git a/test/inductor/test_triton_heuristics.py b/test/inductor/test_triton_heuristics.py index 24f322dfebb845..65daa638520021 100644 --- a/test/inductor/test_triton_heuristics.py +++ b/test/inductor/test_triton_heuristics.py @@ -102,7 +102,7 @@ def triton_(in_ptr0, out_ptr0, xnumel, XBLOCK: tl.constexpr): tl.store(out_ptr0 + (x0), tmp1, xmask) triton_meta = { - "signature": {0: "*fp32", 1: "*fp32", 2: "i32"}, + "signature": {"in_ptr0": "*fp32", "out_ptr0": "*fp32", "xnumel": "i32"}, "device": DeviceProperties.create(torch.device("cuda")), "constants": {}, "configs": [AttrsDescriptor(divisible_by_16=(0, 1, 2), equal_to_1=())], diff --git a/torch/_inductor/codegen/triton.py b/torch/_inductor/codegen/triton.py index 25b267dbc3ee56..c9b4c38c160bad 100644 --- a/torch/_inductor/codegen/triton.py +++ b/torch/_inductor/codegen/triton.py @@ -2648,7 +2648,7 @@ def codegen_kernel(self, name=None): mutated_args = sorted(mutated_args) triton_meta_signature = signature_to_meta( - signature, size_dtype=self.index_dtype + signature, size_dtype=self.index_dtype, argdefs=argdefs ) triton_meta = { "signature": triton_meta_signature, @@ -2676,7 +2676,7 @@ def codegen_kernel(self, name=None): for tree in self.active_range_trees(): sizearg = SizeArg(f"{tree.prefix}numel", tree.numel) signature.append(sizearg) - triton_meta_signature[len(argdefs)] = signature_of( + triton_meta_signature[sizearg.name] = signature_of( sizearg, size_dtype=self.index_dtype ) argdefs.append(f"{tree.prefix}numel") @@ -2694,7 +2694,7 @@ def codegen_kernel(self, name=None): # https://github.com/pytorch/pytorch/issues/120478#issuecomment-1962822307 # https://github.com/openai/triton/blob/231efe9ed2d200be0f69a07c298e4342b08efe3d/python/triton/runtime/jit.py#L384 for arg_num in triton_meta["configs"][0].equal_to_1: # type: ignore[index] - triton_meta["constants"][arg_num] = 1 # type: ignore[index] + triton_meta["constants"][signature[arg_num].name] = 1 # type: ignore[index] self.triton_meta = triton_meta diff --git a/torch/_inductor/codegen/triton_combo_kernel.py b/torch/_inductor/codegen/triton_combo_kernel.py index 946d318fac11c5..b58d61eb326588 100644 --- a/torch/_inductor/codegen/triton_combo_kernel.py +++ b/torch/_inductor/codegen/triton_combo_kernel.py @@ -660,18 +660,19 @@ def jit_line( heuristics: str, size_hints: List[int], selected_kernel: TritonKernel, + signature: List[Any], + argdefs: List[str], pointwise_with_reduce: bool = False, - signature: Optional[List[Any]] = None, ) -> str: can_use_32bit = all(k.index_dtype == "tl.int32" for k in self.sub_kernels) size_dtype = "tl.int32" if can_use_32bit else "tl.int64" - if signature is None: - _, _, signature, _ = self.args.python_argdefs() for i, sub in enumerate(self.sub_kernels): self.min_x_blocks_sub_kernel(sub, i) self.select_dispatch_strategy() triton_meta = { - "signature": signature_to_meta(signature, size_dtype=size_dtype), + "signature": signature_to_meta( + signature, size_dtype=size_dtype, argdefs=argdefs + ), "device": DeviceProperties.create( V.graph.scheduler.get_current_device_or_throw() ), @@ -850,6 +851,7 @@ def codegen_kernel(self, name: Optional[str] = None) -> str: selected_kernel, pointwise_with_reduce=pointwise_with_reduction, signature=signature, + argdefs=argdefs, ) ) code.writeline( diff --git a/torch/_inductor/codegen/triton_utils.py b/torch/_inductor/codegen/triton_utils.py index c8dbd516f13595..77ef3abe025d8a 100644 --- a/torch/_inductor/codegen/triton_utils.py +++ b/torch/_inductor/codegen/triton_utils.py @@ -68,12 +68,13 @@ def signature_to_meta( signature: List[KernelArgType], *, size_dtype: str, + argdefs: List[str], indices: Optional[List[int]] = None, -) -> Dict[int, str]: +) -> Dict[str, str]: if indices is None: indices = list(range(len(signature))) return { - i: signature_of(arg, size_dtype=size_dtype) + argdefs[i]: signature_of(arg, size_dtype=size_dtype) for i, arg in zip(indices, signature) } diff --git a/torch/_inductor/codegen/wrapper.py b/torch/_inductor/codegen/wrapper.py index 5422198194bfae..c3257f4d638c7a 100644 --- a/torch/_inductor/codegen/wrapper.py +++ b/torch/_inductor/codegen/wrapper.py @@ -1275,15 +1275,15 @@ def define_user_defined_triton_kernel(self, kernel, configs, kwargs): from .common import KernelArgType, SizeArg, TensorArg signature: List[KernelArgType] = [] - constants: Dict[int, Any] = {} + constants: Dict[str, Any] = {} non_constant_indices = [] - equal_to_1_arg_idx: List[int] = [] + equal_to_1_args: List[str] = [] for idx, key in enumerate(kernel.arg_names): if key not in kwargs: continue arg = kwargs[key] if idx in kernel.constexprs: - constants[idx] = arg + constants[key] = arg else: non_constant_indices.append(idx) if isinstance(arg, ir.Buffer): @@ -1313,13 +1313,14 @@ def define_user_defined_triton_kernel(self, kernel, configs, kwargs): ) and V.graph.sizevars.statically_known_equals( arg, 1 # type: ignore[arg-type] ): - equal_to_1_arg_idx.append(idx) + equal_to_1_args.append(key) index_dtype = "tl.int32" triton_meta = { "signature": signature_to_meta( signature, size_dtype=index_dtype, indices=non_constant_indices, + argdefs=kernel.arg_names, ), "device": DeviceProperties.create( V.graph.scheduler.get_current_device_or_throw() @@ -1333,7 +1334,7 @@ def define_user_defined_triton_kernel(self, kernel, configs, kwargs): # https://github.com/openai/triton/blob/231efe9ed2d200be0f69a07c298e4342b08efe3d/python/triton/runtime/jit.py#L384 "constants": { **constants, - **dict.fromkeys(equal_to_1_arg_idx, 1), + **dict.fromkeys(equal_to_1_args, 1), }, "configs": [ config_of( diff --git a/torch/_inductor/runtime/triton_heuristics.py b/torch/_inductor/runtime/triton_heuristics.py index ce5ac0b920e556..008572d65bba32 100644 --- a/torch/_inductor/runtime/triton_heuristics.py +++ b/torch/_inductor/runtime/triton_heuristics.py @@ -359,7 +359,7 @@ def _precompile_config(self, cfg: Config, warm_cache_only: bool): if k == "waves_per_eu": compile_meta["waves_per_eu"] = v continue - compile_meta["constants"][self.fn.arg_names.index(k)] = v + compile_meta["constants"][k] = v compile_meta["num_warps"] = cfg.num_warps compile_meta["num_stages"] = cfg.num_stages compile_meta["debug"] = self.inductor_meta.get( diff --git a/torch/_inductor/select_algorithm.py b/torch/_inductor/select_algorithm.py index 7b90690cf50b19..9514f04476611f 100644 --- a/torch/_inductor/select_algorithm.py +++ b/torch/_inductor/select_algorithm.py @@ -214,13 +214,15 @@ def jit_lines(self): argdefs, _, signature, _ = self.args.python_argdefs() triton_meta = { - "signature": signature_to_meta(signature, size_dtype=self.index_dtype), + "signature": signature_to_meta( + signature, size_dtype=self.index_dtype, argdefs=argdefs + ), "device": DeviceProperties.create(self.output_node.get_device()), "constants": {}, } triton_meta["configs"] = [config_of(signature)] for arg_num in triton_meta["configs"][0].equal_to_1: # type: ignore[index] - triton_meta["constants"][arg_num] = 1 # type: ignore[index] + triton_meta["constants"][signature[arg_num].name] = 1 # type: ignore[index] matrix_instr_nonkdim = self.meta.get("matrix_instr_nonkdim", 0) if matrix_instr_nonkdim != 0: triton_meta["matrix_instr_nonkdim"] = matrix_instr_nonkdim From e86ba1340f39d2a9ef266be1f75f6ddbdd56e7f6 Mon Sep 17 00:00:00 2001 From: xingyuan li <108672484+hoshibara@users.noreply.github.com> Date: Fri, 13 Sep 2024 05:16:26 +0000 Subject: [PATCH 0807/1018] [Inductor UT] Generalize inductor UT for intel GPU (Part 2) (#134556) [Inductor UT] Reuse Inductor test case for Intel GPU. Reuse `test/inductor/test_torchinductor_opinfo.py` Reuse `test/inductor/test_minifier_isolate.py` Pull Request resolved: https://github.com/pytorch/pytorch/pull/134556 Approved by: https://github.com/etaf, https://github.com/eellison --- test/inductor/test_minifier_isolate.py | 14 +- test/inductor/test_torchinductor_opinfo.py | 727 ++++++++++++++++----- 2 files changed, 569 insertions(+), 172 deletions(-) diff --git a/test/inductor/test_minifier_isolate.py b/test/inductor/test_minifier_isolate.py index 5681498ef9cc6f..61cf6e39611336 100644 --- a/test/inductor/test_minifier_isolate.py +++ b/test/inductor/test_minifier_isolate.py @@ -8,12 +8,11 @@ IS_MACOS, skipIfRocm, skipIfWindows, + skipIfXpu, TEST_WITH_ASAN, ) -from torch.testing._internal.inductor_utils import HAS_CUDA - - -requires_cuda = unittest.skipUnless(HAS_CUDA, "requires cuda") +from torch.testing._internal.inductor_utils import GPU_TYPE +from torch.testing._internal.triton_utils import requires_gpu # These minifier tests are slow, because they must be run in separate @@ -41,10 +40,11 @@ def test_after_aot_cpu_runtime_error(self): self._test_after_aot_runtime_error("cpu", "") @skipIfRocm - @requires_cuda + @skipIfXpu + @requires_gpu @inductor_config.patch("triton.inject_relu_bug_TESTING_ONLY", "runtime_error") - def test_after_aot_cuda_runtime_error(self): - self._test_after_aot_runtime_error("cuda", "device-side assert") + def test_after_aot_gpu_runtime_error(self): + self._test_after_aot_runtime_error(GPU_TYPE, "device-side assert") if __name__ == "__main__": diff --git a/test/inductor/test_torchinductor_opinfo.py b/test/inductor/test_torchinductor_opinfo.py index 2281acee96794a..24c4392b4e533e 100644 --- a/test/inductor/test_torchinductor_opinfo.py +++ b/test/inductor/test_torchinductor_opinfo.py @@ -26,6 +26,7 @@ ops, skipCPUIf, skipCUDAIf, + skipXPUIf, ) from torch.testing._internal.common_methods_invocations import op_db, skipOps from torch.testing._internal.common_utils import ( @@ -40,7 +41,7 @@ TEST_WITH_ASAN, TEST_WITH_ROCM, ) -from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_CPU, HAS_CUDA +from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_CPU, HAS_CUDA, HAS_XPU from torch.utils._python_dispatch import TorchDispatchMode from torch.utils._pytree import tree_map @@ -207,6 +208,8 @@ def format_op(op): inductor_skips["cuda"]["logcumsumexp"] = {f32} inductor_skips["cuda"]["special.modified_bessel_i1"] = {f64} +inductor_skips["xpu"] = {} + inductor_expected_failures_single_sample = defaultdict(dict) inductor_expected_failures_single_sample["cpu"] = { @@ -256,6 +259,109 @@ def format_op(op): }, # NYI: could not find kernel for aten.view.default at dispatch key DispatchKey.SparseCUDA } +inductor_expected_failures_single_sample["xpu"] = { + "_upsample_bilinear2d_aa": {f16, f32, f64}, + "cholesky": {f32, f64}, + "multinomial": {f16, f32, f64}, + ("normal", "in_place"): {f16, f32, f64}, + ("normal", "number_mean"): {f16, f32, f64}, + "normal": {f16, f32, f64}, + "sparse.sampled_addmm": {f32, f64}, + "tan": {f16}, + "torch.ops.aten._flash_attention_forward": {f16}, + "torch.ops.aten._efficient_attention_forward": {f16, f32}, + "to_sparse": {f16, f32, f64, b8, i32, i64}, + "linalg.eig": {f32, f64}, + "linalg.eigvals": {f32, f64}, + # Double and complex datatype matmul is not supported in oneDNN + "__rmatmul__": {f64}, + ("addmm", "decomposed"): {f64}, + "addr": {f64}, + "baddbmm": {f64}, + "bmm": {f64}, + "byte": {f16, f32}, + "cdist": {f64}, + "corrcoef": {f64}, + "cov": {f64}, + "einsum": {f64}, + "inner": {f64}, + "linalg.cholesky_ex": {f64}, + "linalg.cholesky": {f64}, + ("linalg.det", "singular"): {f64}, + "linalg.ldl_factor_ex": {f64}, + "linalg.ldl_factor": {f64}, + "linalg.ldl_solve": {f64}, + "linalg.matrix_power": {f64}, + "linalg.multi_dot": {f64}, + "matmul": {f64}, + "mm": {f64}, + "mv": {f64}, + "nn.functional.bilinear": {f64}, + "nn.functional.linear": {f64}, + "pca_lowrank": {f64}, + "svd_lowrank": {f64}, + "tensordot": {f64}, + "triangular_solve": {f64}, + "svd": {f64}, + "qr": {f64}, + "pinverse": {f64}, + "ormqr": {f64}, + ("norm", "nuc"): {f64}, + "lu": {f64}, + "lu_solve": {f64}, + "logdet": {f64}, + "linalg.tensorsolve": {f64}, + "linalg.tensorinv": {f64}, + "linalg.svdvals": {f64}, + "linalg.svd": {f64}, + "linalg.solve": {f64}, + "linalg.solve_triangular": {f64}, + "linalg.solve_ex": {f64}, + "linalg.slogdet": {f64}, + "linalg.qr": {f64}, + "linalg.pinv": {f64}, + ("linalg.pinv", "hermitian"): {f64}, + "linalg.norm": {f64}, + ("linalg.norm", "subgradients_at_zero"): {f64}, + "linalg.matrix_rank": {f64}, + ("linalg.matrix_rank", "hermitian"): {f64}, + "linalg.matrix_norm": {f64}, + "linalg.lu": {f64}, + "linalg.lu_solve": {f64}, + "linalg.lu_factor": {f64}, + "linalg.lu_factor_ex": {f64}, + "linalg.lstsq": {f64}, + ("linalg.lstsq", "grad_oriented"): {f64}, + "linalg.inv": {f64}, + "linalg.inv_ex": {f64}, + "linalg.householder_product": {f64}, + "linalg.eigvalsh": {f64}, + "linalg.eigh": {f64}, + "linalg.det": {f64}, + "linalg.cond": {f64}, + "geqrf": {f64}, + "cholesky_solve": {f64}, + "cholesky_inverse": {f64}, + # could not create a primitive + "addbmm": {f16, f32, f64}, + "addmm": {f16, f32, f64}, + "addmv": {f32, f64}, + # could not create a primitive descriptor for + # a deconvolution forward propagation primitive + "nn.functional.conv_transpose2d": {f32, f64}, + "nn.functional.conv_transpose3d": {f32, f64}, + # rrelu not supported on XPU now + "nn.functional.rrelu": {f16, f32, f64}, + "histc": {i32, i64}, + # not implemented for 'Half' + "nn.functional.multilabel_margin_loss": {f16}, + "nn.functional.multi_margin_loss": {f16}, + "nn.functional.avg_pool3d": {f16}, + "nn.functional.adaptive_max_pool3d": {f16}, + # not implemented for 'Bool' + "nn.functional.unfold": {b8}, +} + # intentionally not handled intentionally_not_handled = { @@ -278,11 +384,13 @@ def format_op(op): } inductor_expected_failures_single_sample["cuda"].update(intentionally_not_handled) +inductor_expected_failures_single_sample["xpu"].update(intentionally_not_handled) inductor_gradient_expected_failures_single_sample = defaultdict(dict) inductor_gradient_expected_failures_single_sample["cuda"] = {} +inductor_gradient_expected_failures_single_sample["xpu"] = {} if not TEST_MKL: inductor_expected_failures_single_sample["cpu"].update({}) @@ -290,6 +398,7 @@ def format_op(op): inductor_should_fail_with_exception = defaultdict(dict) inductor_should_fail_with_exception["cpu"] = {} inductor_should_fail_with_exception["cuda"] = {} +inductor_should_fail_with_exception["xpu"] = {} def get_skips_and_xfails(from_dict, xfails=True): @@ -324,9 +433,10 @@ def wrapper_noop_set_seed(op, *args, **kwargs): wrapper_noop_set_seed ) +# key can be either op_name, or (op_name, dtype) +inductor_override_kwargs = defaultdict(dict) -# key can be either op_name, or (op_name, deivce_type), or (op_name, device_type, dtype) -inductor_override_kwargs = { +inductor_override_kwargs["cpu"] = { # the return value of empty is undefined "empty": {"assert_equal": False}, "empty_permuted": {"assert_equal": False}, @@ -335,92 +445,222 @@ def wrapper_noop_set_seed(op, *args, **kwargs): "empty_strided": {"assert_equal": False}, "new_empty_strided": {"assert_equal": False}, "randn": {"assert_equal": False}, - ("cross", "cuda", f16): {"reference_in_float": True}, - ("linalg.cross", "cuda", f16): {"reference_in_float": True}, - ("addr", "cuda", f16): {"reference_in_float": True}, - ("baddbmm", "cuda", f16): {"atol": 2e-3, "rtol": 0.002}, # decomp affects accuracy - ("angle", "cuda", f64): {"reference_in_float": True}, - ("asin", "cuda", f16): {"reference_in_float": True}, - ("atanh", "cuda", f16): {"reference_in_float": True}, - ("cauchy", "cuda"): {"reference_in_float": True}, - ("cummax", "cuda", f16): {"atol": 5e-4, "rtol": 0.002}, - ("cumsum", "cuda", f16): {"reference_in_float": True}, - ("cumprod", "cuda"): {"reference_in_float": True, "atol": 7e-5, "rtol": 0.002}, - ("logcumsumexp", "cuda"): {"grad_atol": 8e-4, "grad_rtol": 0.001}, - ("exponential", "cuda"): {"reference_in_float": True}, - ("geometric", "cuda"): {"reference_in_float": True}, - ("kron", "cuda", f16): {"reference_in_float": True}, - ("log_normal", "cuda"): {"reference_in_float": True}, - ("masked.softmin", "cuda", f16): {"atol": 1e-4, "rtol": 0.01}, - ("nn.functional.batch_norm", "cuda", f16): {"reference_in_float": True}, - ("nn.functional.batch_norm.without_cudnn", "cuda", f16): { - "reference_in_float": True - }, - ("nn.functional.cosine_similarity", "cuda", f16): {"reference_in_float": True}, - ("nn.functional.instance_norm", "cuda", f16): {"reference_in_float": True}, - ("nn.functional.local_response_norm", "cuda", f16): {"reference_in_float": True}, - ("nn.functional.normalize", "cuda", f16): {"atol": 1e-3, "rtol": 0.05}, - ("nn.functional.rms_norm", "cuda", f16): {"reference_in_float": True}, - ("nn.functional.soft_margin_loss", "cuda", f16): {"reference_in_float": True}, - ("nn.functional.softmin", "cuda", f16): {"atol": 1e-4, "rtol": 0.01}, - ("nn.functional.softsign", "cuda", f16): {"reference_in_float": True}, - ("nn.functional.tanhshrink", "cuda", f16): {"atol": 3e-4, "rtol": 0.001}, - ("nn.functional.multilabel_soft_margin_loss", "cpu", f16): { + ("nn.functional.multilabel_soft_margin_loss", f16): { "atol": 3e-4, "rtol": 0.002, }, - ("outer", "cuda", f16): {"reference_in_float": True}, - ("round.decimals_3", "cuda", f16): {"reference_in_float": True}, - ("nn.functional.triplet_margin_loss", "cuda", f16): {"atol": 1e-4, "rtol": 0.02}, - ("nn.functional.triplet_margin_with_distance_loss", "cuda", f16): { - "atol": 1e-4, - "rtol": 0.02, - }, - ("sinc", "cuda", f16): {"atol": 0.008, "rtol": 0.002}, - ("torch.ops.aten._safe_softmax.default", "cuda", f16): {"atol": 5e-4, "rtol": 0.02}, - ("softmax", "cpu", f16): {"atol": 1e-4, "rtol": 0.02}, - ("softmax", "cuda", f16): {"atol": 1e-4, "rtol": 0.02}, - ("_softmax_backward_data", "cuda", f16): {"atol": 0.008, "rtol": 0.002}, - ("special.log_ndtr", "cuda", f64): {"atol": 1e-6, "rtol": 1e-5}, - ("polygamma.polygamma_n_0", "cpu", f32): {"atol": 1e-3, "rtol": 1e-4}, - ("polygamma.polygamma_n_1", "cpu", f32): {"atol": 1e-3, "rtol": 1e-4}, - ("polygamma.polygamma_n_2", "cpu", f32): {"atol": 1e-3, "rtol": 1e-4}, - ("polygamma.polygamma_n_3", "cpu", f32): {"atol": 1e-3, "rtol": 1e-4}, - ("polygamma.polygamma_n_4", "cpu", f32): {"atol": 1e-3, "rtol": 1e-4}, - ("special.polygamma.special_polygamma_n_0", "cpu", f32): { + ("softmax", f16): {"atol": 1e-4, "rtol": 0.02}, + ("polygamma.polygamma_n_0", f32): {"atol": 1e-3, "rtol": 1e-4}, + ("polygamma.polygamma_n_1", f32): {"atol": 1e-3, "rtol": 1e-4}, + ("polygamma.polygamma_n_2", f32): {"atol": 1e-3, "rtol": 1e-4}, + ("polygamma.polygamma_n_3", f32): {"atol": 1e-3, "rtol": 1e-4}, + ("polygamma.polygamma_n_4", f32): {"atol": 1e-3, "rtol": 1e-4}, + ("special.polygamma.special_polygamma_n_0", f32): { "atol": 1e-3, "rtol": 1e-4, }, - ("std_mean.unbiased", "cuda", f16): {"reference_in_float": True}, - ("uniform", "cuda"): {"reference_in_float": True}, - ("_unsafe_masked_index_put_accumulate", "cuda", f16): {"atol": 1e-4, "rtol": 0.01}, - ("_unsafe_masked_index_put_accumulate", "cpu", f16): {"atol": 1e-4, "rtol": 0.01}, + ("_unsafe_masked_index_put_accumulate", f16): {"atol": 1e-4, "rtol": 0.01}, # Following tests are failing with strict comparision but atol=1 is acceptable due roundings errors - ("nn.functional.interpolate.bilinear", "cpu", u8): {"atol": 1, "rtol": 0}, - ("nn.functional.upsample_bilinear", "cpu", u8): {"atol": 1, "rtol": 0}, - ("nn.functional.interpolate.bicubic", "cpu", u8): {"atol": 1, "rtol": 0}, + ("nn.functional.interpolate.bilinear", u8): {"atol": 1, "rtol": 0}, + ("nn.functional.upsample_bilinear", u8): {"atol": 1, "rtol": 0}, + ("nn.functional.interpolate.bicubic", u8): {"atol": 1, "rtol": 0}, + # High atol due to precision loss + ("nn.functional.interpolate.bicubic", f32): {"atol": 5e-3, "rtol": 0}, +} + +inductor_override_kwargs["cuda"] = { + # the return value of empty is undefined + "empty": {"assert_equal": False}, + "empty_permuted": {"assert_equal": False}, + "empty_like": {"assert_equal": False}, + "new_empty": {"assert_equal": False}, + "empty_strided": {"assert_equal": False}, + "new_empty_strided": {"assert_equal": False}, + "randn": {"assert_equal": False}, + ("cross", f16): {"reference_in_float": True}, + ("linalg.cross", f16): {"reference_in_float": True}, + ("addr", f16): {"reference_in_float": True}, + ("baddbmm", f16): {"atol": 2e-3, "rtol": 0.002}, # decomp affects accuracy + ("angle", f64): {"reference_in_float": True}, + ("asin", f16): {"reference_in_float": True}, + ("atanh", f16): {"reference_in_float": True}, + "cauchy": {"reference_in_float": True}, + ("cummax", f16): {"atol": 5e-4, "rtol": 0.002}, + ("cumsum", f16): {"reference_in_float": True}, + "cumprod": {"reference_in_float": True, "atol": 7e-5, "rtol": 0.002}, + "logcumsumexp": {"grad_atol": 8e-4, "grad_rtol": 0.001}, + "exponential": {"reference_in_float": True}, + "geometric": {"reference_in_float": True}, + ("kron", f16): {"reference_in_float": True}, + "log_normal": {"reference_in_float": True}, + ("masked.softmin", f16): {"atol": 1e-4, "rtol": 0.01}, + ("nn.functional.batch_norm", f16): {"reference_in_float": True}, + ("nn.functional.batch_norm.without_cudnn", f16): {"reference_in_float": True}, + ("nn.functional.cosine_similarity", f16): {"reference_in_float": True}, + ("nn.functional.instance_norm", f16): {"reference_in_float": True}, + ("nn.functional.local_response_norm", f16): {"reference_in_float": True}, + ("nn.functional.normalize", f16): {"atol": 1e-3, "rtol": 0.05}, + ("nn.functional.rms_norm", f16): {"reference_in_float": True}, + ("nn.functional.soft_margin_loss", f16): {"reference_in_float": True}, + ("nn.functional.softmin", f16): {"atol": 1e-4, "rtol": 0.01}, + ("nn.functional.softsign", f16): {"reference_in_float": True}, + ("nn.functional.tanhshrink", f16): {"atol": 3e-4, "rtol": 0.001}, + ("outer", f16): {"reference_in_float": True}, + ("round.decimals_3", f16): {"reference_in_float": True}, + ("nn.functional.triplet_margin_loss", f16): {"atol": 1e-4, "rtol": 0.02}, + ("nn.functional.triplet_margin_with_distance_loss", f16): { + "atol": 1e-4, + "rtol": 0.02, + }, + ("sinc", f16): {"atol": 0.008, "rtol": 0.002}, + ("torch.ops.aten._safe_softmax.default", f16): {"atol": 5e-4, "rtol": 0.02}, + ("softmax", f16): {"atol": 1e-4, "rtol": 0.02}, + ("_softmax_backward_data", f16): {"atol": 0.008, "rtol": 0.002}, + ("special.log_ndtr", f64): {"atol": 1e-6, "rtol": 1e-5}, + ("std_mean.unbiased", f16): {"reference_in_float": True}, + "uniform": {"reference_in_float": True}, + ("_unsafe_masked_index_put_accumulate", f16): {"atol": 1e-4, "rtol": 0.01}, # High atol due to precision loss - ("nn.functional.interpolate.bilinear", "cuda", f64): {"atol": 5e-4, "rtol": 0}, - ("nn.functional.upsample_bilinear", "cuda", f64): {"atol": 5e-4, "rtol": 0}, - ("nn.functional.interpolate.bicubic", "cpu", f32): {"atol": 5e-3, "rtol": 0}, - ("nn.functional.interpolate.bicubic", "cuda", f64): {"atol": 1e-3, "rtol": 0}, + ("nn.functional.interpolate.bilinear", f64): {"atol": 5e-4, "rtol": 0}, + ("nn.functional.upsample_bilinear", f64): {"atol": 5e-4, "rtol": 0}, + ("nn.functional.interpolate.bicubic", f64): {"atol": 1e-3, "rtol": 0}, # Unreasonably high atol requirement: - ("index_reduce.mean", "cuda", f16): {"check_gradient": False}, - ("index_reduce.mean", "cuda", f32): {"check_gradient": False}, - ("index_reduce.mean", "cuda", f64): {"check_gradient": False}, + ("index_reduce.mean", f16): {"check_gradient": False}, + ("index_reduce.mean", f32): {"check_gradient": False}, + ("index_reduce.mean", f64): {"check_gradient": False}, # Gradient contains non-finite entries: - ("index_reduce.amin", "cuda", f64): {"check_gradient": False}, - ("index_reduce.amin", "cuda", f32): {"check_gradient": False}, - ("index_reduce.amin", "cuda", f16): {"check_gradient": False}, - ("index_reduce.amax", "cuda", f64): {"check_gradient": False}, - ("index_reduce.amax", "cuda", f32): {"check_gradient": False}, - ("index_reduce.amax", "cuda", f16): {"check_gradient": False}, - ("tanh", "cuda", f16): {"atol": 1e-4, "rtol": 1e-2}, + ("index_reduce.amin", f64): {"check_gradient": False}, + ("index_reduce.amin", f32): {"check_gradient": False}, + ("index_reduce.amin", f16): {"check_gradient": False}, + ("index_reduce.amax", f64): {"check_gradient": False}, + ("index_reduce.amax", f32): {"check_gradient": False}, + ("index_reduce.amax", f16): {"check_gradient": False}, + ("tanh", f16): {"atol": 1e-4, "rtol": 1e-2}, } +inductor_override_kwargs["xpu"] = { + # the return value of empty is undefined + "empty": {"assert_equal": False}, + "empty_permuted": {"assert_equal": False}, + "empty_like": {"assert_equal": False}, + "new_empty": {"assert_equal": False}, + "empty_strided": {"assert_equal": False}, + "new_empty_strided": {"assert_equal": False}, + "randn": {"assert_equal": False}, + # XPU + ("cross", f16): {"reference_in_float": True}, + ("addr", f16): {"reference_in_float": True}, + ("baddbmm", f16): {"atol": 2e-3, "rtol": 0.002}, # decomp affects accuracy + ("angle", f64): {"reference_in_float": True}, + ("asin", f16): {"reference_in_float": True}, + ("atanh", f16): {"reference_in_float": True}, + "cauchy": {"reference_in_float": True}, + ("cummax", f16): {"atol": 5e-4, "rtol": 0.002}, + ("cumsum", f16): {"reference_in_float": True}, + "cumprod": {"reference_in_float": True, "atol": 7e-5, "rtol": 0.002}, + ("dot", f16): {"atol": 1e-5, "rtol": 0.002}, + "logcumsumexp": {"grad_atol": 8e-4, "grad_rtol": 0.001}, + "exponential": {"reference_in_float": True}, + "geometric": {"reference_in_float": True}, + ("kron", f16): {"reference_in_float": True}, + ("linalg.cross", f16): {"reference_in_float": True}, + ("linalg.vecdot", f16): {"atol": 1e-5, "rtol": 2e-2}, + "log_normal": {"reference_in_float": True}, + ("logsumexp", f16): {"atol": 1e-5, "rtol": 1e-2}, + ("masked.cumprod", f16): {"atol": 1e-5, "rtol": 5e-2}, + ("masked.cumsum", f16): {"atol": 1e-5, "rtol": 5e-3}, + ("masked.softmin", f16): {"atol": 1e-4, "rtol": 0.01}, + ("masked.softmax", f16): {"atol": 2e-4, "rtol": 0.01}, + ("masked.var", f16): {"atol": 2e-5, "rtol": 5e-3}, + ("native_batch_norm", f64): {"atol": 1e-7, "rtol": 1e-5}, + ("_native_batch_norm_legit", f64): {"atol": 1e-7, "rtol": 5e-6}, + ("_batch_norm_with_update", f64): {"atol": 1e-7, "rtol": 1e-6}, + ("native_layer_norm", f16): {"atol": 5e-3, "rtol": 5e-3}, + ("native_layer_norm", f32): {"atol": 5e-3, "rtol": 5e-3}, + ("nn.functional.batch_norm", f16): {"reference_in_float": True}, + ("nn.functional.batch_norm", f64): {"atol": 1e-6, "rtol": 1e-6}, + ("nn.functional.batch_norm.without_cudnn", f16): {"reference_in_float": True}, + ("nn.functional.conv1d", f16): {"atol": 1e-5, "rtol": 6e-3}, + ("nn.functional.conv3d", f16): {"atol": 1e-5, "rtol": 2e-3}, + ("nn.functional.conv_transpose2d", f16): {"atol": 1e-5, "rtol": 2e-3}, + ("nn.functional.conv_transpose3d", f16): {"atol": 1e-5, "rtol": 5e-3}, + ("nn.functional.cosine_embedding_loss", f16): {"atol": 1e-5, "rtol": 2e-3}, + ("nn.functional.cosine_similarity", f16): { + "reference_in_float": True, + "atol": 1e-5, + "rtol": 5e-3, + }, + ("nn.functional.instance_norm", f16): {"reference_in_float": True}, + ("nn.functional.instance_norm", f64): {"atol": 1e-6, "rtol": 1e-6}, + ("nn.functional.layer_norm", f16): {"atol": 5e-3, "rtol": 2e-3}, + ("nn.functional.layer_norm", f32): {"atol": 5e-5, "rtol": 2e-3}, + ("nn.functional.local_response_norm", f16): {"reference_in_float": True}, + ("nn.functional.multilabel_soft_margin_loss", f16): { + "atol": 3e-4, + "rtol": 2e-3, + }, + ("nn.functional.normalize", f16): {"atol": 1e-3, "rtol": 0.05}, + ("nn.functional.rms_norm", f16): {"reference_in_float": True}, + ("nn.functional.soft_margin_loss", f16): {"reference_in_float": True}, + ("nn.functional.softmin", f16): {"atol": 1e-4, "rtol": 0.01}, + ("nn.functional.softsign", f16): { + "reference_in_float": True, + "atol": 1e-5, + "rtol": 0.005, + }, + ("nn.functional.tanhshrink", f16): {"atol": 3e-4, "rtol": 0.001}, + ("outer", f16): {"reference_in_float": True}, + ("round.decimals_3", f16): {"reference_in_float": True}, + ("nn.functional.triplet_margin_loss", f16): {"atol": 1e-4, "rtol": 0.02}, + ("nn.functional.triplet_margin_with_distance_loss", f16): { + "atol": 1e-4, + "rtol": 0.02, + }, + ("remainder", f16): {"atol": 1e-4, "rtol": 0.005}, + ("nn.functional.upsample_bilinear", f16): {"atol": 1e-5, "rtol": 0.002}, + ("sinc", f16): {"atol": 0.008, "rtol": 0.002}, + ("softmax", f16): {"atol": 1e-4, "rtol": 0.02}, + ("_softmax_backward_data", f16): {"atol": 0.008, "rtol": 0.002}, + ("special.log_ndtr", f64): {"atol": 1e-6, "rtol": 1e-5}, + ("std_mean.unbiased", f16): { + "reference_in_float": True, + "atol": 5e-5, + "rtol": 5e-3, + }, + ("trapezoid", f16): {"atol": 1e-5, "rtol": 5e-3}, + ("trapz", f16): {"atol": 1e-5, "rtol": 5e-3}, + "uniform": {"reference_in_float": True}, + ("var_mean", f16): {"atol": 1e-5, "rtol": 2e-3}, + ("var_mean.unbiased", f16): {"atol": 1e-5, "rtol": 2e-3}, + ("vdot", f16): {"atol": 1e-5, "rtol": 2e-3}, + # Following tests are failing with strict comparision but atol=1 is acceptable due roundings errors + # High atol due to precision loss + ("nn.functional.interpolate.bilinear", f64): {"atol": 5e-4, "rtol": 0}, + ("nn.functional.upsample_bilinear", f64): {"atol": 5e-4, "rtol": 0}, + ("nn.functional.interpolate.bicubic", f64): {"atol": 1e-3, "rtol": 0}, + # Unreasonably high atol requirement: + ("index_reduce.mean", f16): {"check_gradient": False}, + ("index_reduce.mean", f32): {"check_gradient": False}, + ("index_reduce.mean", f64): {"check_gradient": False}, + # Gradient contains non-finite entries: + ("index_reduce.amin", f64): {"check_gradient": False}, + ("index_reduce.amin", f32): {"check_gradient": False}, + ("index_reduce.amin", f16): {"check_gradient": False}, + ("index_reduce.amax", f64): {"check_gradient": False}, + ("index_reduce.amax", f32): {"check_gradient": False}, + ("index_reduce.amax", f16): {"check_gradient": False}, + ("tanh", f16): {"atol": 1e-4, "rtol": 1e-2}, + ("nn.functional.embedding_bag", f16): {"check_gradient": False}, + ("nn.functional.embedding_bag", f32): {"check_gradient": False}, + ("nn.functional.embedding_bag", f64): {"check_gradient": False}, + ("_unsafe_masked_index", f16): {"atol": 1e-5, "rtol": 2e-3}, + ("_unsafe_masked_index_put_accumulate", f16): {"atol": 1e-5, "rtol": 5e-3}, +} # Test with one sample only for following ops -inductor_one_sample = { +inductor_one_sample = defaultdict(dict) + +inductor_one_sample["cpu"] = { "_segment_reduce.lengths": {f16}, "_segment_reduce.offsets": {f16}, "addmv": {f16}, @@ -456,89 +696,244 @@ def wrapper_noop_set_seed(op, *args, **kwargs): "normal": {f16, f32, f64}, "put": {f16, f32, f64}, "take": {b8, f16, f32, f64, i32, i64}, - ("__rdiv__", "cuda"): {f16}, - ("__rmod__", "cuda"): {f16, i64}, - ("__rmul__", "cuda"): {f16}, - ("__rpow__", "cuda"): {f16}, - ("_unsafe_masked_index", "cuda"): {f16}, - ("_unsafe_masked_index_put_accumulate", "cuda"): {f16}, - ("addcdiv", "cuda"): {f16}, - ("addcmul", "cuda"): {f16}, - ("atan2", "cuda"): {f16}, - ("cumsum", "cuda"): {f16}, - ("cumulative_trapezoid", "cuda"): {f16}, - ("dist", "cuda"): {f16}, - ("div.no_rounding_mode", "cuda"): {f16}, - ("fmod", "cuda"): {f16}, - ("grid_sampler_2d", "cuda"): {f16}, - ("index_fill", "cuda"): {f16, f32, f64}, - ("ldexp", "cuda"): {f16}, - ("lerp", "cuda"): {f16}, - ("linalg.householder_product", "cuda"): {f32}, - ("linalg.matrix_norm", "cuda"): {f16}, - ("linalg.vector_norm", "cuda"): {f16}, - ("logspace", "cuda"): {i32, i64}, - ("masked.cumsum", "cuda"): {f16}, - ("masked.logsumexp", "cuda"): {f16}, - ("masked.mean", "cuda"): {b8}, - ("masked.normalize", "cuda"): {f16}, - ("masked.prod", "cuda"): {f16}, - ("masked.std", "cuda"): {f16}, - ("masked.var", "cuda"): {f16}, - ("mul", "cuda"): {f16}, - ("nn.functional.alpha_dropout", "cuda"): {f16, f32, f64}, - ("nn.functional.avg_pool1d", "cuda"): {f16, f32, f64}, - ("nn.functional.avg_pool2d", "cuda"): {f16, f32, f64}, - ("nn.functional.avg_pool3d", "cuda"): {f16, f32, f64}, - ("nn.functional.binary_cross_entropy", "cuda"): {f16}, - ("nn.functional.binary_cross_entropy_with_logits", "cuda"): {f16}, - ("nn.functional.conv2d", "cuda"): {f16}, - ("nn.functional.cosine_embedding_loss", "cuda"): {f16}, - ("nn.functional.dropout2d", "cuda"): {f16, f32, f64}, - ("nn.functional.dropout3d", "cuda"): {f16, f32, f64}, - ("nn.functional.dropout", "cuda"): {f16, f32, f64}, - ("nn.functional.feature_alpha_dropout.with_train", "cuda"): {f16, f32, f64}, - ("nn.functional.fractional_max_pool2d", "cuda"): {f16, f32, f64}, - ("nn.functional.fractional_max_pool3d", "cuda"): {f16, f32, f64}, - ("nn.functional.grid_sample", "cuda"): {f16}, - ("nn.functional.group_norm", "cuda"): {f16}, - ("nn.functional.hinge_embedding_loss", "cuda"): {f16}, +} + +inductor_one_sample["cuda"] = { + "_segment_reduce.lengths": {f16}, + "_segment_reduce.offsets": {f16}, + "addmv": {f16}, + "as_strided.partial_views": {f16}, + "corrcoef": {f16}, + "diff": {f16}, + "einsum": {f16, i32}, + "gradient": {f16}, + "histogram": {f32, f64}, + "histogramdd": {f32, f64}, + "index_put": {f16, f32, f64}, + "linalg.eig": {f32, f64}, + "linspace": {f16, i32, i64}, + "linspace.tensor_overload": {f16, f32, f64, i32, i64}, + "logspace": {f16, i32, i64}, + "logspace.tensor_overload": {f16, f32, f64, i32, i64}, + "masked_logsumexp": {i64}, + "max_pool2d_with_indices_backward": {f16, f32, f64}, + "new_empty_strided": {f16}, + "nn.functional.adaptive_avg_pool3d": {f16}, + "nn.functional.adaptive_max_pool1d": {f16, f32}, + "nn.functional.adaptive_max_pool2d": {f16, f32}, + "nn.functional.bilinear": {f16}, + "nn.functional.conv_transpose1d": {f16}, + "nn.functional.conv_transpose2d": {f16}, + "nn.functional.conv_transpose3d": {f16}, + "nn.functional.cosine_similarity": {f16}, + "nn.functional.cross_entropy": {f16, f32, f64}, + "nn.functional.gaussian_nll_loss": {f16}, + "nn.functional.grid_sample": {f16, f32, f64}, + "nn.functional.interpolate.area": {f16}, + "nn.functional.nll_loss": {f16, f32, f64}, + "normal": {f16, f32, f64}, + "put": {f16, f32, f64}, + "take": {b8, f16, f32, f64, i32, i64}, + "__rdiv__": {f16}, + "__rmod__": {f16, i64}, + "__rmul__": {f16}, + "__rpow__": {f16}, + "_unsafe_masked_index": {f16}, + "_unsafe_masked_index_put_accumulate": {f16}, + "addcdiv": {f16}, + "addcmul": {f16}, + "atan2": {f16}, + "cumsum": {f16}, + "cumulative_trapezoid": {f16}, + "dist": {f16}, + "div.no_rounding_mode": {f16}, + "fmod": {f16}, + "grid_sampler_2d": {f16}, + "index_fill": {f16, f32, f64}, + "ldexp": {f16}, + "lerp": {f16}, + "linalg.householder_product": {f32}, + "linalg.matrix_norm": {f16}, + "linalg.vector_norm": {f16}, + "masked.cumsum": {f16}, + "masked.logsumexp": {f16}, + "masked.mean": {b8}, + "masked.normalize": {f16}, + "masked.prod": {f16}, + "masked.std": {f16}, + "masked.var": {f16}, + "mul": {f16}, + "nn.functional.alpha_dropout": {f16, f32, f64}, + "nn.functional.avg_pool1d": {f16, f32, f64}, + "nn.functional.avg_pool2d": {f16, f32, f64}, + "nn.functional.avg_pool3d": {f16, f32, f64}, + "nn.functional.binary_cross_entropy": {f16}, + "nn.functional.binary_cross_entropy_with_logits": {f16}, + "nn.functional.conv2d": {f16}, + "nn.functional.cosine_embedding_loss": {f16}, + "nn.functional.dropout2d": {f16, f32, f64}, + "nn.functional.dropout3d": {f16, f32, f64}, + "nn.functional.dropout": {f16, f32, f64}, + "nn.functional.feature_alpha_dropout.with_train": {f16, f32, f64}, + "nn.functional.fractional_max_pool2d": {f16, f32, f64}, + "nn.functional.fractional_max_pool3d": {f16, f32, f64}, + "nn.functional.group_norm": {f16}, + "nn.functional.hinge_embedding_loss": {f16}, # Enabling all tests for this test fails randomly # See https://github.com/pytorch/pytorch/issues/129238 - ("nn.functional.huber_loss", "cuda"): {f16}, - ("nn.functional.interpolate.bicubic", "cuda"): {f16}, - ("nn.functional.interpolate.bilinear", "cuda"): {f16}, - ("nn.functional.interpolate.trilinear", "cuda"): {f16}, - ("nn.functional.kl_div", "cuda"): {f16}, - ("nn.functional.margin_ranking_loss", "cuda"): {f16}, - ("nn.functional.max_pool1d", "cuda"): {f16, f32, f64}, - ("nn.functional.max_pool3d", "cuda"): {f16}, - ("nn.functional.mse_loss", "cuda"): {f16}, - ("nn.functional.multi_margin_loss", "cuda"): {f16}, - ("nn.functional.multilabel_margin_loss", "cuda"): {f16}, - ("nn.functional.multilabel_soft_margin_loss", "cuda"): {f16}, - ("nn.functional.normalize", "cuda"): {f16}, - ("nn.functional.pad.replicate", "cuda"): {f16, f32, f64}, - ("nn.functional.pad.reflect", "cuda"): {f16}, - ("nn.functional.pairwise_distance", "cuda"): {f16}, - ("nn.functional.poisson_nll_loss", "cuda"): {f16}, - ("nn.functional.rms_norm", "cuda"): {f16}, - ("norm", "cuda"): {f16}, - ("pow", "cuda"): {f16}, - ("prod", "cuda"): {f16}, - ("scatter_reduce.amax", "cuda"): {f16, f32, f64}, - ("scatter_reduce.amin", "cuda"): {f16, f32, f64}, - ("scatter_reduce.mean", "cuda"): {f16, f32, f64}, - ("special.xlog1py", "cuda"): {f16}, - ("std", "cuda"): {f16}, - ("std_mean", "cuda"): {f16}, - ("svd_lowrank", "cuda"): {f32, f64}, - ("trapezoid", "cuda"): {f16}, - ("trapz", "cuda"): {f16}, - ("true_divide", "cuda"): {f16}, - ("var", "cuda"): {f16}, - ("var_mean", "cuda"): {f16}, - ("xlogy", "cuda"): {f16}, + "nn.functional.huber_loss": {f16}, + "nn.functional.interpolate.bicubic": {f16}, + "nn.functional.interpolate.bilinear": {f16}, + "nn.functional.interpolate.trilinear": {f16}, + "nn.functional.kl_div": {f16}, + "nn.functional.margin_ranking_loss": {f16}, + "nn.functional.max_pool1d": {f16, f32, f64}, + "nn.functional.max_pool3d": {f16}, + "nn.functional.mse_loss": {f16}, + "nn.functional.multi_margin_loss": {f16}, + "nn.functional.multilabel_margin_loss": {f16}, + "nn.functional.multilabel_soft_margin_loss": {f16}, + "nn.functional.normalize": {f16}, + "nn.functional.pad.replicate": {f16, f32, f64}, + "nn.functional.pad.reflect": {f16}, + "nn.functional.pairwise_distance": {f16}, + "nn.functional.poisson_nll_loss": {f16}, + "nn.functional.rms_norm": {f16}, + "norm": {f16}, + "pow": {f16}, + "prod": {f16}, + "scatter_reduce.amax": {f16, f32, f64}, + "scatter_reduce.amin": {f16, f32, f64}, + "scatter_reduce.mean": {f16, f32, f64}, + "special.xlog1py": {f16}, + "std": {f16}, + "std_mean": {f16}, + "svd_lowrank": {f32, f64}, + "trapezoid": {f16}, + "trapz": {f16}, + "true_divide": {f16}, + "var": {f16}, + "var_mean": {f16}, + "xlogy": {f16}, +} + +inductor_one_sample["xpu"] = { + "_segment_reduce.lengths": {f16}, + "_segment_reduce.offsets": {f16}, + "addmv": {f16}, + "as_strided.partial_views": {f16}, + "corrcoef": {f16}, + "diff": {f16}, + "einsum": {f16, i32}, + "gradient": {f16}, + "histogram": {f32, f64}, + "histogramdd": {f32, f64}, + "index_put": {f16, f32, f64}, + "linalg.eig": {f32, f64}, + "linspace": {f16, i32, i64}, + "linspace.tensor_overload": {f16, f32, f64, i32, i64}, + "logspace": {f16, i32, i64}, + "logspace.tensor_overload": {f16, f32, f64, i32, i64}, + "masked_logsumexp": {i64}, + "max_pool2d_with_indices_backward": {f16, f32, f64}, + "new_empty_strided": {f16}, + "nn.functional.adaptive_avg_pool3d": {f16}, + "nn.functional.adaptive_max_pool1d": {f16, f32}, + "nn.functional.adaptive_max_pool2d": {f16, f32}, + "nn.functional.bilinear": {f16}, + "nn.functional.conv_transpose1d": {f16}, + "nn.functional.conv_transpose2d": {f16}, + "nn.functional.conv_transpose3d": {f16}, + "nn.functional.cosine_similarity": {f16}, + "nn.functional.cross_entropy": {f16, f32, f64}, + "nn.functional.gaussian_nll_loss": {f16}, + "nn.functional.grid_sample": {f16, f32, f64}, + "nn.functional.interpolate.area": {f16}, + "nn.functional.nll_loss": {f16, f32, f64}, + "normal": {f16, f32, f64}, + "put": {f16, f32, f64}, + "take": {b8, f16, f32, f64, i32, i64}, + "__rdiv__": {f16}, + "__rmod__": {f16, i64}, + "__rmul__": {f16}, + "__rpow__": {f16}, + "_unsafe_masked_index": {f16}, + "_unsafe_masked_index_put_accumulate": {f16}, + "addcdiv": {f16}, + "addcmul": {f16}, + "atan2": {f16}, + "cumsum": {f16}, + "cumulative_trapezoid": {f16}, + "dist": {f16}, + "div.no_rounding_mode": {f16}, + "fmod": {f16}, + "grid_sampler_2d": {f16}, + "index_fill": {f16, f32, f64}, + "ldexp": {f16}, + "lerp": {f16}, + "linalg.householder_product": {f32}, + "linalg.matrix_norm": {f16}, + "linalg.vector_norm": {f16}, + "masked.cumsum": {f16}, + "masked.logsumexp": {f16}, + "masked.mean": {b8}, + "masked.normalize": {f16}, + "masked.prod": {f16}, + "masked.std": {f16}, + "masked.var": {f16}, + "mul": {f16}, + "nn.functional.alpha_dropout": {f16, f32, f64}, + "nn.functional.avg_pool1d": {f16, f32, f64}, + "nn.functional.avg_pool2d": {f16, f32, f64}, + "nn.functional.avg_pool3d": {f16, f32, f64}, + "nn.functional.binary_cross_entropy": {f16}, + "nn.functional.binary_cross_entropy_with_logits": {f16}, + "nn.functional.conv2d": {f16}, + "nn.functional.cosine_embedding_loss": {f16}, + "nn.functional.dropout2d": {f16, f32, f64}, + "nn.functional.dropout3d": {f16, f32, f64}, + "nn.functional.dropout": {f16, f32, f64}, + "nn.functional.feature_alpha_dropout.with_train": {f16, f32, f64}, + "nn.functional.fractional_max_pool2d": {f16, f32, f64}, + "nn.functional.fractional_max_pool3d": {f16, f32, f64}, + "nn.functional.group_norm": {f16}, + "nn.functional.hinge_embedding_loss": {f16}, + # Enabling all tests for this test fails randomly + # See https://github.com/pytorch/pytorch/issues/129238 + "nn.functional.huber_loss": {f16}, + "nn.functional.interpolate.bicubic": {f16}, + "nn.functional.interpolate.bilinear": {f16}, + "nn.functional.interpolate.trilinear": {f16}, + "nn.functional.kl_div": {f16}, + "nn.functional.margin_ranking_loss": {f16}, + "nn.functional.max_pool1d": {f16, f32, f64}, + "nn.functional.max_pool3d": {f16}, + "nn.functional.mse_loss": {f16}, + "nn.functional.multi_margin_loss": {f16}, + "nn.functional.multilabel_margin_loss": {f16}, + "nn.functional.multilabel_soft_margin_loss": {f16}, + "nn.functional.normalize": {f16}, + "nn.functional.pad.replicate": {f16, f32, f64}, + "nn.functional.pad.reflect": {f16}, + "nn.functional.pairwise_distance": {f16}, + "nn.functional.poisson_nll_loss": {f16}, + "nn.functional.rms_norm": {f16}, + "norm": {f16}, + "pow": {f16}, + "prod": {f16}, + "scatter_reduce.amax": {f16, f32, f64}, + "scatter_reduce.amin": {f16, f32, f64}, + "scatter_reduce.mean": {f16, f32, f64}, + "special.xlog1py": {f16}, + "std": {f16}, + "std_mean": {f16}, + "svd_lowrank": {f32, f64}, + "trapezoid": {f16}, + "trapz": {f16}, + "true_divide": {f16}, + "var": {f16}, + "var_mean": {f16}, + "xlogy": {f16}, } @@ -572,6 +967,7 @@ def tearDown(self): True ) # inductor kernels failing this test intermittently @skipCUDAIf(not HAS_CUDA, "Skipped! Triton not found") + @skipXPUIf(not HAS_XPU, "Skipped! Supported XPU compiler not found") @skipCPUIf(not HAS_CPU, "Skipped! Supported CPU compiler not found") @unittest.skipIf(TEST_WITH_ASAN, "Skipped under ASAN") @skipIfTorchDynamo("Test uses dynamo already") @@ -593,6 +989,8 @@ def test_comprehensive(self, device, dtype, op): # TODO: should we move empty_cache to the common device interface if device_type == "cuda": torch.cuda.empty_cache() + elif device == "xpu": + torch.xpu.empty_cache() op_name = op.name if op.variant_test_name: op_name += f".{op.variant_test_name}" @@ -630,12 +1028,12 @@ def test_comprehensive(self, device, dtype, op): test_expect = ExpectedTestResult.SUCCESS overridden_kwargs = {} - if op_name in inductor_override_kwargs: - overridden_kwargs = inductor_override_kwargs[op_name] - elif (op_name, device_type) in inductor_override_kwargs: - overridden_kwargs = inductor_override_kwargs[(op_name, device_type)] - elif (op_name, device_type, dtype) in inductor_override_kwargs: - overridden_kwargs = inductor_override_kwargs[(op_name, device_type, dtype)] + overridden_kwargs.update( + inductor_override_kwargs.get(device_type, {}).get(op_name, {}) + ) + overridden_kwargs.update( + inductor_override_kwargs.get(device_type, {}).get((op_name, dtype), {}) + ) func = op.get_op() def fn(*args, **kwargs): @@ -653,8 +1051,7 @@ def fn(*args, **kwargs): samples = op.sample_inputs(device, dtype, requires_grad=requires_grad) if ( - dtype in inductor_one_sample.get(op_name, {}) - or dtype in inductor_one_sample.get((op_name, device_type), {}) + dtype in inductor_one_sample.get(device_type, {}).get(op_name, {}) ) and not ALL_SAMPLES: if isinstance(samples, (list, tuple)): samples = [samples[0]] @@ -796,7 +1193,7 @@ def _get_tolerances(dtype): # print(f"SUCCEEDED OP {op_name} on {device_type} with {dtype}", flush=True, file=f) -instantiate_device_type_tests(TestInductorOpInfo, globals()) +instantiate_device_type_tests(TestInductorOpInfo, globals(), allow_xpu=True) if __name__ == "__main__": run_tests() From 6ee0026032aff514887d02a69a7967a28e37d921 Mon Sep 17 00:00:00 2001 From: "xinan.lin" Date: Thu, 12 Sep 2024 15:40:08 +0000 Subject: [PATCH 0808/1018] [Inductor UT] Generalize device-bias code in test_triton_kernels.py introduced in #135530 (#135656) Pull Request resolved: https://github.com/pytorch/pytorch/pull/135656 Approved by: https://github.com/EikanWang, https://github.com/zou3519 --- test/inductor/test_triton_kernels.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/inductor/test_triton_kernels.py b/test/inductor/test_triton_kernels.py index f0aea3f74e49eb..c11bfbcb790c70 100644 --- a/test/inductor/test_triton_kernels.py +++ b/test/inductor/test_triton_kernels.py @@ -1035,7 +1035,7 @@ def _(x): @register_lowering(torch.ops.mylib.weird_op_with_lowering) def _(x): return empty_strided( - x.shape, (2, 1), dtype=x.dtype, device=torch.device("cuda:0") + x.shape, (2, 1), dtype=x.dtype, device=torch.device(GPU_TYPE, 0) ) # Triton kernel that has different behavior depending on the input strides. @@ -1068,7 +1068,7 @@ def f(x): arange_out(x, y) return x + y - x = torch.randn(2, 2, device="cuda") + x = torch.randn(2, 2, device=GPU_TYPE) eager_out = f(x) compiled_inductor_f = torch.compile(f, backend="inductor", fullgraph=True) From f3613cab1db0ae6ff1cd9b282ca2639764de05d1 Mon Sep 17 00:00:00 2001 From: Animesh Jain Date: Thu, 12 Sep 2024 17:50:48 -0700 Subject: [PATCH 0809/1018] [inductor] Remove the batch fusion passes from being a default (#135922) Ads team do a search internally to figure out which fusion passes to use. Pull Request resolved: https://github.com/pytorch/pytorch/pull/135922 Approved by: https://github.com/eellison, https://github.com/yanboliang ghstack dependencies: #135819 --- torch/_inductor/config.py | 9 +-------- 1 file changed, 1 insertion(+), 8 deletions(-) diff --git a/torch/_inductor/config.py b/torch/_inductor/config.py index d3d58c551ec374..1d247bcc5865c9 100644 --- a/torch/_inductor/config.py +++ b/torch/_inductor/config.py @@ -198,14 +198,7 @@ def autotune_remote_cache_default() -> Optional[bool]: # merge_splits_pass # mutate_cat_pass # split_cat_pass -pre_grad_fusion_options: Dict[str, Dict[str, Any]] = { - "batch_linear": {}, - "batch_linear_lhs": {}, - "batch_layernorm": {}, - "batch_tanh": {}, - "batch_relu": {}, - "batch_sigmoid": {}, -} +pre_grad_fusion_options: Dict[str, Dict[str, Any]] = {} # Post grad fusion and options, set to empty dict to disable fusion. # Call `torch._inductor.fx_passes.group_batch_fusion.list_group_batch_fusions(False)` to see available fusions. From a203fc2e5f0d60e931c575b7c41b19be9fb49587 Mon Sep 17 00:00:00 2001 From: Jez Ng Date: Thu, 12 Sep 2024 18:52:20 -0700 Subject: [PATCH 0810/1018] remove fast_flush arguments (#135387) I've removed them from upstream Triton in https://github.com/triton-lang/triton/pull/4485. It looks like most places in the code use the default value of `fast_flush=True` anyway, though there are two PRs from @pearu that use `False`. To my knowledge, there's no reason to use the `False` value. Differential Revision: [D62325778](https://our.internmc.facebook.com/intern/diff/D62325778) Pull Request resolved: https://github.com/pytorch/pytorch/pull/135387 Approved by: https://github.com/nmacchioni, https://github.com/jansel --- benchmarks/sparse/triton_ops.py | 4 +--- torch/_inductor/codegen/multi_kernel.py | 2 +- torch/_inductor/codegen/triton.py | 2 +- torch/_inductor/codegen/triton_combo_kernel.py | 2 +- torch/_inductor/runtime/triton_heuristics.py | 2 +- torch/_inductor/wrapper_benchmark.py | 4 +--- torch/sparse/_triton_ops_meta.py | 10 +++------- 7 files changed, 9 insertions(+), 17 deletions(-) diff --git a/benchmarks/sparse/triton_ops.py b/benchmarks/sparse/triton_ops.py index fad89c280f8975..2493e1e0f74019 100644 --- a/benchmarks/sparse/triton_ops.py +++ b/benchmarks/sparse/triton_ops.py @@ -28,9 +28,7 @@ def create_blocked_tensor(B, M, N, blocksize, sparsity, dtype, device): def _test_worker(test_func): - ms, ms_min, ms_max = benchmarker.benchmark_gpu( - test_func, warmup=500, rep=100, fast_flush=False - ) + ms, ms_min, ms_max = benchmarker.benchmark_gpu(test_func, warmup=500, rep=100) tflops = 2 * m * k * n * 1e-12 / (ms * 1e-3) return ms, tflops diff --git a/torch/_inductor/codegen/multi_kernel.py b/torch/_inductor/codegen/multi_kernel.py index 020c38dcc77e4e..6081530fd98e54 100644 --- a/torch/_inductor/codegen/multi_kernel.py +++ b/torch/_inductor/codegen/multi_kernel.py @@ -309,7 +309,7 @@ def inner(): return inner return [ - benchmarker.benchmark_gpu(wrap_fn(kernel), rep=40, fast_flush=True) + benchmarker.benchmark_gpu(wrap_fn(kernel), rep=40) for kernel in self.kernels ] diff --git a/torch/_inductor/codegen/triton.py b/torch/_inductor/codegen/triton.py index c9b4c38c160bad..fa64fc54ffb179 100644 --- a/torch/_inductor/codegen/triton.py +++ b/torch/_inductor/codegen/triton.py @@ -2503,7 +2503,7 @@ def codegen_kernel_benchmark(self, num_gb, grid=None): result.writeline("args = get_args()") result.writeline( - "ms = benchmarker.benchmark_gpu(lambda: call(args), rep=40, fast_flush=True)" + "ms = benchmarker.benchmark_gpu(lambda: call(args), rep=40)" ) result.writeline(f"num_gb = {num_gb}") result.writeline("gb_per_s = num_gb / (ms / 1e3)") diff --git a/torch/_inductor/codegen/triton_combo_kernel.py b/torch/_inductor/codegen/triton_combo_kernel.py index b58d61eb326588..c5ab3a922d5706 100644 --- a/torch/_inductor/codegen/triton_combo_kernel.py +++ b/torch/_inductor/codegen/triton_combo_kernel.py @@ -994,7 +994,7 @@ def codegen_kernel_benchmark( result.writeline("args = get_args()") result.writeline( - "ms = benchmarker.benchmark_gpu(lambda: call(args), rep=40, fast_flush=True)" + "ms = benchmarker.benchmark_gpu(lambda: call(args), rep=40)" ) result.writeline(f"num_gb = {num_gb}") result.writeline("gb_per_s = num_gb / (ms / 1e3)") diff --git a/torch/_inductor/runtime/triton_heuristics.py b/torch/_inductor/runtime/triton_heuristics.py index 008572d65bba32..c0b0e20e906d5e 100644 --- a/torch/_inductor/runtime/triton_heuristics.py +++ b/torch/_inductor/runtime/triton_heuristics.py @@ -672,7 +672,7 @@ def kernel_call(): return do_bench_using_profiling(kernel_call, warmup=10, rep=40) - return benchmarker.benchmark_gpu(kernel_call, rep=40, fast_flush=True) + return benchmarker.benchmark_gpu(kernel_call, rep=40) def clone_args(self, *args, **kwargs) -> Tuple[List[Any], Dict[str, Any]]: from ..compile_fx import clone_preserve_strides diff --git a/torch/_inductor/wrapper_benchmark.py b/torch/_inductor/wrapper_benchmark.py index bdd1f0fc95b7f9..ddbfbcf19609c8 100644 --- a/torch/_inductor/wrapper_benchmark.py +++ b/torch/_inductor/wrapper_benchmark.py @@ -120,9 +120,7 @@ def get_info_str(ms, n_regs, n_spills, shared, prefix=""): f" {get_info_str(ms, launcher.n_regs, launcher.n_spills, launcher.shared)} @ {launcher.config}" ) else: - ms = benchmarker.benchmark_gpu( - lambda: kernel_mod.call(args), rep=40, fast_flush=True - ) + ms = benchmarker.benchmark_gpu(lambda: kernel_mod.call(args), rep=40) assert ( len(triton_kernel.launchers) == 1 ), "Autotuner should have selected the best config" diff --git a/torch/sparse/_triton_ops_meta.py b/torch/sparse/_triton_ops_meta.py index a97d9c502dc349..c7234641886234 100644 --- a/torch/sparse/_triton_ops_meta.py +++ b/torch/sparse/_triton_ops_meta.py @@ -507,9 +507,7 @@ def bench(meta, bsr=bsr, dense=dense): def test_func(): return bsr_scatter_mm(bsr, dense, indices_data=indices_data) - ms_min = triton.testing.do_bench( - test_func, warmup=500, rep=100, fast_flush=False - ) + ms_min = triton.testing.do_bench(test_func, warmup=500, rep=100) return ms_min @@ -663,7 +661,7 @@ def test_func(): input, bsr, dense, beta=beta, alpha=alpha, meta=meta, out=out ) - return triton.testing.do_bench(test_func, warmup=500, rep=100, fast_flush=False) + return triton.testing.do_bench(test_func, warmup=500, rep=100) # The step function that increments a specified meta parameter: def step_meta_parameter(name, value, direction, meta, M=M, N=N, K=K, BM=BM, BK=BK): @@ -866,9 +864,7 @@ def test_func(): else: raise NotImplementedError(op) - ms_min = triton.testing.do_bench( - test_func, warmup=500, rep=100, fast_flush=False - ) + ms_min = triton.testing.do_bench(test_func, warmup=500, rep=100) return ms_min From 3b411e379a11fadf4d4b55892921c218bda1b684 Mon Sep 17 00:00:00 2001 From: Michael Lazos Date: Wed, 11 Sep 2024 14:43:46 -0700 Subject: [PATCH 0811/1018] [Dynamo] Use custom backend to reenter metadata tf mode when tracing while/cond (#134732) For tracing cond/while in eager, we trace the HOP with the eager backend with metadata torchfunction mode enabled. HOPs disallow the mutation that occurs in this torch function mode, so it is not able to be traced. As a result, we use a custom backend which enters this mode for tracing these HOPs. Thanks to @ydwu4 for the help with implementing this Pull Request resolved: https://github.com/pytorch/pytorch/pull/134732 Approved by: https://github.com/ydwu4 --- torch/_dynamo/backends/debugging.py | 12 +++++ torch/_higher_order_ops/cond.py | 20 +++++--- torch/_higher_order_ops/scan.py | 16 +++++-- torch/_higher_order_ops/while_loop.py | 21 +++++++-- torch/fx/experimental/proxy_tensor.py | 67 ++++++++++++++++----------- torch/nn/attention/flex_attention.py | 42 +++++++++++------ 6 files changed, 124 insertions(+), 54 deletions(-) diff --git a/torch/_dynamo/backends/debugging.py b/torch/_dynamo/backends/debugging.py index 87214f7d59774e..18784498862e39 100644 --- a/torch/_dynamo/backends/debugging.py +++ b/torch/_dynamo/backends/debugging.py @@ -31,6 +31,18 @@ def eager(gm, fake_tensor_inputs, **kwargs): return gm.forward +def make_eager_backend_with_torch_function_mode(mode): + """Used to trace HOPs (cond and while) for eager exectution, the metadata + TF mode mutates vars outside of the scope of the HOP, and we can't have graph breaks + in the HOP, so we need to externally run this mode and not trace it.""" + + def fn(gm, fake_tensor_inputs, **kwargs): + with mode: + return gm.forward + + return fn + + @register_backend def eager_noexcept(gm, fake_tensor_inputs, **kwargs): if kwargs: diff --git a/torch/_higher_order_ops/cond.py b/torch/_higher_order_ops/cond.py index dee400d76f5964..0467e2899adc28 100644 --- a/torch/_higher_order_ops/cond.py +++ b/torch/_higher_order_ops/cond.py @@ -28,6 +28,7 @@ from torch._subclasses.fake_tensor import FakeTensorMode from torch._subclasses.functional_tensor import disable_functional_mode from torch.fx.experimental.proxy_tensor import ( + _temp_remove_metadata_torch_function_mode, _temp_remove_pre_dispatch_torch_function_mode, disable_proxy_modes_tracing, ProxyTorchDispatchMode, @@ -129,6 +130,10 @@ def false_fn(x: torch.Tensor): if torch.compiler.is_dynamo_compiling(): return cond_op(pred, true_fn, false_fn, operands) + from torch._dynamo.backends.debugging import ( + make_eager_backend_with_torch_function_mode, + ) + if isinstance(pred, (bool, int, float)): log.warning( "Pred is a Python constant. When used with torch.cond, it executes only one of the branches." @@ -169,12 +174,15 @@ def _validate_input(pred, true_fn, false_fn, operands): def _cond_op_wrapper(*args, **kwargs): return cond_op(*args, **kwargs) - with _set_compilation_env(): - with torch._dynamo.utils.disable_cache_limit(): - with _temp_remove_pre_dispatch_torch_function_mode(): - return torch.compile(_cond_op_wrapper, backend="eager", fullgraph=True)( - pred, true_fn, false_fn, operands - ) + with _set_compilation_env(), torch._dynamo.utils.disable_cache_limit(), _temp_remove_pre_dispatch_torch_function_mode(): + with _temp_remove_metadata_torch_function_mode() as metadata_mode: + if metadata_mode: + backend = make_eager_backend_with_torch_function_mode(metadata_mode) + else: + backend = "eager" + return torch.compile(_cond_op_wrapper, backend=backend, fullgraph=True)( + pred, true_fn, false_fn, operands + ) def create_fw_bw_graph_branches(true_fn, false_fn, *operands): diff --git a/torch/_higher_order_ops/scan.py b/torch/_higher_order_ops/scan.py index d6eecb0a9f29f9..d66cff067f6682 100644 --- a/torch/_higher_order_ops/scan.py +++ b/torch/_higher_order_ops/scan.py @@ -20,6 +20,7 @@ from torch._ops import HigherOrderOperator from torch._subclasses.fake_tensor import FakeTensorMode from torch.fx.experimental.proxy_tensor import ( + _temp_remove_metadata_torch_function_mode, disable_proxy_modes_tracing, ProxyTorchDispatchMode, track_tensor_tree, @@ -118,10 +119,19 @@ def _scan_op_wrapper(*args, **kwargs): return scan(*args, **kwargs) if not torch._dynamo.is_compiling(): + from torch._dynamo.backends.debugging import ( + make_eager_backend_with_torch_function_mode, + ) + with _set_compilation_env(), torch._dynamo.utils.disable_cache_limit(): - return torch.compile(_scan_op_wrapper, backend="eager", fullgraph=True)( - combine_fn, init, xs, dim=dim, reverse=reverse - ) + with _temp_remove_metadata_torch_function_mode() as metadata_mode: + if metadata_mode: + backend = make_eager_backend_with_torch_function_mode(metadata_mode) + else: + backend = "eager" + return torch.compile(_scan_op_wrapper, backend=backend, fullgraph=True)( + combine_fn, init, xs, dim=dim, reverse=reverse + ) leaves_init, spec_init = pytree.tree_flatten(init) leaves_xs, spec_xs = pytree.tree_flatten(xs) diff --git a/torch/_higher_order_ops/while_loop.py b/torch/_higher_order_ops/while_loop.py index d64f512399d9d5..f14321842f40b9 100644 --- a/torch/_higher_order_ops/while_loop.py +++ b/torch/_higher_order_ops/while_loop.py @@ -15,7 +15,11 @@ ) from torch._ops import HigherOrderOperator from torch._subclasses.fake_tensor import FakeTensorMode -from torch.fx.experimental.proxy_tensor import ProxyTorchDispatchMode, track_tensor_tree +from torch.fx.experimental.proxy_tensor import ( + _temp_remove_metadata_torch_function_mode, + ProxyTorchDispatchMode, + track_tensor_tree, +) class WhileLoopOp(HigherOrderOperator): @@ -113,6 +117,9 @@ def body_fn(iter, x): - 'while_loop' only supports **inference** right now. Autograd will be supported in the future. """ + from torch._dynamo.backends.debugging import ( + make_eager_backend_with_torch_function_mode, + ) # Currently, additional_inputs is not a user-facing input. It will be automatically set in dynamo. # parameters and buffers accessed in cond_fn or body_fn or tensor closures will become additional_inputs. @@ -140,9 +147,15 @@ def _while_loop_op_wrapper(*args, **kwargs): return while_loop_op(*args, **kwargs) with _set_compilation_env(), torch._dynamo.utils.disable_cache_limit(): - return torch.compile(_while_loop_op_wrapper, backend="eager", fullgraph=True)( - cond_fn, body_fn, carried_inputs, additional_inputs - ) + with _temp_remove_metadata_torch_function_mode() as metadata_mode: + with _temp_remove_metadata_torch_function_mode() as metadata_mode: + if metadata_mode: + backend = make_eager_backend_with_torch_function_mode(metadata_mode) + else: + backend = "eager" + return torch.compile( + _while_loop_op_wrapper, backend=backend, fullgraph=True + )(cond_fn, body_fn, carried_inputs, additional_inputs) @while_loop_op.py_impl(DispatchKey.CompositeExplicitAutograd) diff --git a/torch/fx/experimental/proxy_tensor.py b/torch/fx/experimental/proxy_tensor.py index 5a5771d2ae8f7a..213672e216e6ce 100644 --- a/torch/fx/experimental/proxy_tensor.py +++ b/torch/fx/experimental/proxy_tensor.py @@ -17,7 +17,7 @@ import warnings import weakref from collections import defaultdict -from contextlib import contextmanager, ExitStack, nullcontext +from contextlib import _GeneratorContextManager, contextmanager, ExitStack, nullcontext from dataclasses import dataclass from typing import ( Any, @@ -1084,38 +1084,43 @@ def unwrap_proxy(self, e: T) -> object: return e -@contextmanager -def _temp_remove_pre_dispatch_torch_function_mode() -> Generator[None, None, None]: - from torch.overrides import _len_torch_function_stack, _pop_mode, _push_mode +def _make_temp_remove_mode_context_manager( + mode_ty: Type[TorchFunctionMode], +) -> Callable[[], _GeneratorContextManager[Optional[TorchFunctionMode]]]: + @contextmanager + def context_manager_fn() -> Generator[Optional[TorchFunctionMode], None, None]: + from torch.overrides import _len_torch_function_stack, _pop_mode, _push_mode - temp_elements = [] - pre_dispatch_mode = None + temp_elements = [] + removed_mode = None - while _len_torch_function_stack() > 0: - mode = _pop_mode() - if isinstance(mode, PreDispatchTorchFunctionMode): - pre_dispatch_mode = mode - break - else: - temp_elements.append(mode) + while _len_torch_function_stack() > 0: + mode = _pop_mode() + if isinstance(mode, mode_ty): + removed_mode = mode + break + else: + temp_elements.append(mode) - for mode in reversed(temp_elements): - _push_mode(mode) + for mode in reversed(temp_elements): + _push_mode(mode) - try: - yield + try: + yield removed_mode - finally: - if pre_dispatch_mode is not None: - count = len(temp_elements) - while count > 0: - mode = _pop_mode() - count -= 1 + finally: + if removed_mode is not None: + count = len(temp_elements) + while count > 0: + mode = _pop_mode() + count -= 1 + + temp_elements.append(removed_mode) - temp_elements.append(pre_dispatch_mode) + for mode in reversed(temp_elements): + _push_mode(mode) - for mode in reversed(temp_elements): - _push_mode(mode) + return context_manager_fn @torch._disable_dynamo @@ -1230,6 +1235,11 @@ def __torch_function__( return func(*args, **kwargs) +_temp_remove_metadata_torch_function_mode = _make_temp_remove_mode_context_manager( + TorchFunctionMetadataMode +) + + # This mode is **only** used for pre_dispatch tracing. # In particular, we need to make sure that autograd/autocast API's # that do not desugar into dispatcher operators stay in the graph. @@ -1258,6 +1268,11 @@ def __torch_function__( return func(*args, **kwargs) +_temp_remove_pre_dispatch_torch_function_mode = _make_temp_remove_mode_context_manager( + PreDispatchTorchFunctionMode +) + + class ProxyTorchDispatchMode(TorchDispatchMode): # Ensure this is read-only; this exists only for legacy reasons @property diff --git a/torch/nn/attention/flex_attention.py b/torch/nn/attention/flex_attention.py index bb04ce53d0fcab..2317dbac83bdda 100644 --- a/torch/nn/attention/flex_attention.py +++ b/torch/nn/attention/flex_attention.py @@ -19,6 +19,7 @@ ) from torch._higher_order_ops.utils import _set_compilation_env from torch.fx.experimental.proxy_tensor import ( + _temp_remove_metadata_torch_function_mode, _temp_remove_pre_dispatch_torch_function_mode, ) from torch.nn.attention._utils import _supported_head_dim, _validate_sdpa_input @@ -1033,6 +1034,10 @@ def score_mod( if not torch._dynamo.is_dynamo_supported(): raise RuntimeError("flex_attention requires dynamo support") + from torch._dynamo.backends.debugging import ( + make_eager_backend_with_torch_function_mode, + ) + # Dynamo is expecting a callable with "__code__" attribute. # We cannot directly pass hop to it. So we wrap it in a dummy function. def _flex_attention_hop_wrapper(*args, **kwargs): @@ -1041,18 +1046,25 @@ def _flex_attention_hop_wrapper(*args, **kwargs): with _set_compilation_env(): with torch._dynamo.utils.disable_cache_limit(): with _temp_remove_pre_dispatch_torch_function_mode(): - out, lse = torch.compile( - _flex_attention_hop_wrapper, backend="eager", fullgraph=True - )( - query, - key, - value, - score_mod, - block_mask.as_tuple(), - scale, - kernel_options, - ) - if return_lse: - return out, lse * math.log(2) - else: - return out + with _temp_remove_metadata_torch_function_mode() as metadata_mode: + if metadata_mode: + backend = make_eager_backend_with_torch_function_mode( + metadata_mode + ) + else: + backend = "eager" + out, lse = torch.compile( + _flex_attention_hop_wrapper, backend=backend, fullgraph=True + )( + query, + key, + value, + score_mod, + block_mask.as_tuple(), + scale, + kernel_options, + ) + if return_lse: + return out, lse * math.log(2) + else: + return out From 18665e500a8df283dc9c1b893b43baa3e1c2bdb2 Mon Sep 17 00:00:00 2001 From: Michael Lazos Date: Thu, 12 Sep 2024 01:28:22 -0700 Subject: [PATCH 0812/1018] [Dynamo] Trace torch function modes entered outside of torch.compile (#133137) This PR adds initial tracing for torch function modes. Details: In essence, this adds tracing into the torch function of modes entered outside of the torch.compile call. This does not yet support tracing enter/exit of a torch function mode/ tracing set_default_device properly using the new mode infra (this will be a very good stress test for modes). I am adding more PRs to this stack to support these. The overall plan is to support tracing enter/exit and handling graph breaks like we do other torch.* context managers. Previously landed: https://github.com/pytorch/pytorch/pull/133135 https://github.com/pytorch/pytorch/pull/133136 https://github.com/pytorch/pytorch/pull/133134 https://github.com/pytorch/pytorch/pull/133133 https://github.com/pytorch/pytorch/pull/133132 https://github.com/pytorch/pytorch/pull/133131 https://github.com/pytorch/pytorch/pull/133729 https://github.com/pytorch/pytorch/pull/133130 Pull Request resolved: https://github.com/pytorch/pytorch/pull/133137 Approved by: https://github.com/jansel, https://github.com/zou3519 ghstack dependencies: #134732 --- test/dynamo/test_modes.py | 135 +++++++++ test/functorch/test_control_flow.py | 4 + .../pt2e/test_metadata_porting.py | 4 +- torch/_dynamo/convert_frame.py | 5 + torch/_dynamo/guards.py | 14 + torch/_dynamo/source.py | 2 +- torch/_dynamo/symbolic_convert.py | 78 ++---- torch/_dynamo/trace_rules.py | 5 +- torch/_dynamo/utils.py | 10 +- torch/_dynamo/variables/ctx_manager.py | 2 + torch/_dynamo/variables/torch.py | 260 ++++++++++-------- torch/_dynamo/variables/torch_function.py | 129 +++++++-- torch/_dynamo/variables/user_defined.py | 7 - torch/_export/non_strict_utils.py | 6 +- 14 files changed, 457 insertions(+), 204 deletions(-) diff --git a/test/dynamo/test_modes.py b/test/dynamo/test_modes.py index ee8a9b579ff124..bad5a788903440 100644 --- a/test/dynamo/test_modes.py +++ b/test/dynamo/test_modes.py @@ -14,6 +14,17 @@ from torch.utils._python_dispatch import TorchDispatchMode +class TestMode(BaseTorchFunctionMode): + def __torch_function__(self, func, types, args, kwargs=None): + if not kwargs: + kwargs = {} + + if func == torch.add: + return torch.zeros(2, 2) + + return super().__torch_function__(func, types, args, kwargs) + + class TorchDispatchModeTests(torch._dynamo.test_case.TestCase): @classmethod def setUpClass(cls): @@ -324,6 +335,130 @@ def fn(x): fn(inp) self.assertEqual(cnt.frame_count, 2) + def test_nested_torch_function_mode(self): + mode_1_called = False + mode_2_called = False + + def reset_state(): + nonlocal mode_1_called + nonlocal mode_2_called + mode_1_called = False + mode_2_called = False + + ones = torch.ones(2, 2) + zeros = torch.zeros(2, 2) + + class TestMode1(BaseTorchFunctionMode): + def __torch_function__(self, func, types, args, kwargs=None): + if not kwargs: + kwargs = {} + + nonlocal mode_1_called + + mode_1_called = True + + if func == torch.add: + return zeros + + return super().__torch_function__(func, types, args, kwargs) + + class TestMode2(BaseTorchFunctionMode): + def __torch_function__(self, func, types, args, kwargs=None): + if not kwargs: + kwargs = {} + + nonlocal mode_2_called + + mode_2_called = True + + if func == torch.mul: + return ones + + return super().__torch_function__(func, types, args, kwargs) + + def fn(x): + return torch.add(x, 3) + + def fn_2(x): + return torch.mul(x, 3) + torch.add(x, 3) + + inp = torch.ones(2, 2) + 1 + + for fn_i in [fn, fn_2]: + fn_opt = torch.compile(fn_i, fullgraph=True) + with TestMode1(), TestMode2(): + expected = fn_i(inp), mode_1_called, mode_2_called + reset_state() + actual = fn_opt(inp), mode_1_called, mode_2_called + reset_state() + + self.assertEqual(expected, actual) + + def test_torch_function_mode_disable(self): + class TestSubclass(torch.Tensor): + @classmethod + def __torch_function__(cls, func, types, args, kwargs=None): + if not kwargs: + kwargs = {} + if func == torch.add: + return torch.ones(2, 2) + return super().__torch_function__(func, types, args, kwargs) + + class TestMode(BaseTorchFunctionMode): + def __torch_function__(self, func, types, args, kwargs=None): + if not kwargs: + kwargs = {} + + if func == torch.add: + return torch.zeros(2, 2) + + return super().__torch_function__(func, types, args, kwargs) + + def fn(x): + return torch.add(x, 3) + + inp = (torch.ones(2, 2) + 1).as_subclass(TestSubclass) + + fn_opt = torch.compile(fn, fullgraph=True) + with TestMode(), torch._dynamo.config.patch( + "traceable_tensor_subclasses", {TestSubclass} + ): + with torch._C.DisableTorchFunctionSubclass(): + expected = fn(inp) + actual = fn_opt(inp) + + self.assertEqual(expected, actual) + + with torch._C.DisableTorchFunction(): + expected = fn(inp) + actual = fn_opt(inp) + + self.assertEqual(expected, actual) + + def test_torch_function_mode_highest_priority(self): + class TestSubclass(torch.Tensor): + @classmethod + def __torch_function__(cls, func, types, args, kwargs=None): + if not kwargs: + kwargs = {} + if func == torch.add: + return torch.ones(2, 2) + return super().__torch_function__(func, types, args, kwargs) + + def fn(x): + return torch.add(x, 3) + + inp = (torch.ones(2, 2) + 1).as_subclass(TestSubclass) + + fn_opt = torch.compile(fn, fullgraph=True) + with TestMode(), torch._dynamo.config.patch( + "traceable_tensor_subclasses", {TestSubclass} + ): + expected = fn(inp) + actual = fn_opt(inp) + + self.assertEqual(expected, actual) + if __name__ == "__main__": from torch._dynamo.test_case import run_tests diff --git a/test/functorch/test_control_flow.py b/test/functorch/test_control_flow.py index 16740fdeddf3a6..c1fb51af8e2f90 100644 --- a/test/functorch/test_control_flow.py +++ b/test/functorch/test_control_flow.py @@ -26,6 +26,7 @@ parametrize, requires_cuda, run_tests, + skipIfCrossRef, skipIfRocm, skipIfTorchDynamo, TEST_WITH_TORCHDYNAMO, @@ -2882,6 +2883,7 @@ def f(fct, init, xs): gm = make_fx(f, tracing_mode="symbolic")(add_wrong_dtype, init, x) @skipIfNoDynamoSupport + @skipIfCrossRef # Arg order changes with crossref def test_scan_simple_graph(self): from torch._dynamo.testing import EagerAndRecordGraphs @@ -2988,6 +2990,7 @@ def f(x, y): self.assertEqual(graph(x, torch.tensor(True)), f(x, torch.tensor(True))) @skipIfTorchDynamo("Graph is not captured by backend if test with dynamo") + @skipIfCrossRef # Arg order changes with crossref def test_cond_simple_with_linear_compile_check_graph(self): from torch._dynamo.testing import EagerAndRecordGraphs @@ -3250,6 +3253,7 @@ def test_while_loop_compile(self, backend, while_loop_test): self._check_compile(fn, inp, backend=backend) @skipIfTorchDynamo("Graph is not captured by backend if test with dynamo") + @skipIfCrossRef # Arg order changes with cross ref def test_while_loop_simple_with_linear_compile_check_graph(self): fn, inp = WHILE_LOOP_TESTS["simple_with_linear"] from torch._dynamo.testing import EagerAndRecordGraphs diff --git a/test/quantization/pt2e/test_metadata_porting.py b/test/quantization/pt2e/test_metadata_porting.py index 27df58e159f3b6..21488e8cacdd42 100644 --- a/test/quantization/pt2e/test_metadata_porting.py +++ b/test/quantization/pt2e/test_metadata_porting.py @@ -13,7 +13,7 @@ from torch.ao.quantization.quantizer.xnnpack_quantizer_utils import OP_TO_ANNOTATOR from torch.fx import Node from torch.testing._internal.common_quantization import QuantizationTestCase -from torch.testing._internal.common_utils import IS_WINDOWS +from torch.testing._internal.common_utils import IS_WINDOWS, skipIfCrossRef class TestHelperModules: @@ -139,6 +139,8 @@ def _test_metadata_porting( self.assertEqual(v, node_tags[k]) return m + @skipIfCrossRef # mlazos: retracing FX graph with torch function mode doesn't propagate metadata, because the stack + # trace of the mode torch function impl doesn't match the traced graph stored lineno. def test_simple_metadata_porting(self): """ Model under test diff --git a/torch/_dynamo/convert_frame.py b/torch/_dynamo/convert_frame.py index 3ab2937a9496d4..4a166e1ff80251 100644 --- a/torch/_dynamo/convert_frame.py +++ b/torch/_dynamo/convert_frame.py @@ -605,6 +605,10 @@ def _compile( output: Optional[OutputGraph] = None tracer: Optional[InstructionTranslator] = None + tf_mode_stack: List[ + torch.overrides.TorchFunctionMode + ] = torch.overrides._get_current_function_mode_stack() + @preserve_global_state def transform( instructions: List[Instruction], code_options: Dict[str, object] @@ -618,6 +622,7 @@ def transform( locals, globals, builtins, + tf_mode_stack, code_options, compiler_fn, one_graph, diff --git a/torch/_dynamo/guards.py b/torch/_dynamo/guards.py index b0bd273e96bbb5..4437ef4038190d 100644 --- a/torch/_dynamo/guards.py +++ b/torch/_dynamo/guards.py @@ -98,6 +98,7 @@ ScriptObjectQualifiedNameSource, ShapeEnvSource, SubclassAttrListSource, + TorchFunctionModeStackSource, TupleIteratorGetItemSource, TypeSource, UnspecializedBuiltinNNModuleSource, @@ -111,6 +112,7 @@ dict_keys_repr, get_custom_getattr, get_torch_function_mode_stack, + get_torch_function_mode_stack_at, guard_failures, istype, key_is_id, @@ -314,6 +316,7 @@ def uninteresting_files(): "___dict_contains": lambda a, b: a in b, "___tuple_iterator_len": tuple_iterator_len, "___tuple_iterator_getitem": tuple_iterator_getitem, + "___get_torch_function_mode_stack_at": get_torch_function_mode_stack_at, "__math_isnan": math.isnan, "__numpy_isnan": None if np is None else np.isnan, "inf": float("inf"), @@ -901,6 +904,15 @@ def get_guard_manager_from_source(self, source): ): assert base_guard_manager # to make mypy happy out = base_guard_manager + elif istype(source, TorchFunctionModeStackSource): + out = root_guard_manager.lambda_manager( + python_lambda=lambda _: get_torch_function_mode_stack_at( + source._get_index() + ), + source=source_name, + example_value=example_value, + guard_manager_enum=guard_manager_enum, + ) elif istype(source, GradSource): assert base_guard_manager # to make mypy happy out = base_guard_manager.grad_manager( @@ -2214,6 +2226,8 @@ def __init__( self.output_graph = output_graph w_builder = None + # NB: Until we trace device contexts, we need to use the stack recorded at the beginning of tracing + # in case a set default device call was made in the graph. self.torch_function_mode_stack = ( output_graph.torch_function_mode_stack if output_graph else None ) diff --git a/torch/_dynamo/source.py b/torch/_dynamo/source.py index 6ed9d97f6a1210..c44f8d9e69abf5 100644 --- a/torch/_dynamo/source.py +++ b/torch/_dynamo/source.py @@ -619,7 +619,7 @@ class TorchFunctionModeStackSource(Source): ind: int def name(self): - return "" + return f"___get_torch_function_mode_stack_at({self._get_index()})" def _get_index(self): from .variables.torch_function import TorchFunctionModeStackVariable diff --git a/torch/_dynamo/symbolic_convert.py b/torch/_dynamo/symbolic_convert.py index ab92f82aa0f630..415179a5ed32a6 100644 --- a/torch/_dynamo/symbolic_convert.py +++ b/torch/_dynamo/symbolic_convert.py @@ -19,20 +19,7 @@ import types import typing import weakref -from typing import ( - Any, - Callable, - cast, - Deque, - Dict, - List, - Optional, - Set, - Tuple, - Type, - TYPE_CHECKING, - Union, -) +from typing import Any, Callable, cast, Dict, List, Optional, Set, Tuple, Type, Union from unittest.mock import patch import torch @@ -72,14 +59,12 @@ GlobalWeakRefSource, LocalSource, Source, - TorchFunctionModeStackSource, ) from .trace_rules import is_builtin_constant, is_forbidden from .utils import ( counters, get_fake_value, get_instruction_source_311, - get_torch_function_mode_stack, graph_break_dup_warning_checker, istype, LazyString, @@ -120,11 +105,10 @@ ) from .variables.nn_module import NNModuleVariable, UnspecializedNNModuleVariable from .variables.tensor import supported_comparison_ops, SymNodeVariable, TensorVariable - - -if TYPE_CHECKING: - from .variables.torch_function import TorchFunctionModeVariable - +from .variables.torch_function import ( + SymbolicTorchFunctionState, + TorchFunctionModeVariable, +) from .variables.user_defined import ( RemovableHandleVariable, UserDefinedClassVariable, @@ -284,6 +268,10 @@ def resume_fn(self): return ReenterWith(self.stack_index) def exit(self, tx): + if hasattr(self, "graph_break") and isinstance( + self.with_context, TorchFunctionModeVariable + ): + return assert self.with_context is not None return self.with_context.exit(tx) @@ -652,7 +640,9 @@ def handle_graph_break( # Reconstruct the context variable CLASS in the block stack for b in self.block_stack: assert b.with_context is not None - assert isinstance(b.with_context, ContextWrappingVariable) + assert isinstance( + b.with_context, (ContextWrappingVariable, TorchFunctionModeVariable) + ) b.with_context.reconstruct_type(cg) cg.extend_output(b.resume_fn().try_except(cg.code_options, cleanup)) self.output.add_output_instructions(cg.get_instructions()) @@ -728,7 +718,7 @@ class InstructionTranslatorBase( output: OutputGraph symbolic_locals: Dict[str, VariableTracker] symbolic_globals: Dict[str, VariableTracker] - symbolic_torch_function_mode_stack: Deque["TorchFunctionModeVariable"] + symbolic_torch_function_state: SymbolicTorchFunctionState stack: List[VariableTracker] instruction_pointer: Optional[int] current_instruction: Instruction @@ -2548,7 +2538,7 @@ def __init__( code_options: Dict[str, Any], symbolic_locals: Dict[str, VariableTracker], symbolic_globals: Dict[str, VariableTracker], - symbolic_torch_function_mode_stack: Deque["TorchFunctionModeVariable"], + symbolic_torch_function_state: SymbolicTorchFunctionState, f_code: types.CodeType, export: bool, inline_depth: int, @@ -2563,7 +2553,7 @@ def __init__( self.output = output self.symbolic_locals = symbolic_locals self.symbolic_globals = symbolic_globals - self.symbolic_torch_function_mode_stack = symbolic_torch_function_mode_stack + self.symbolic_torch_function_state = symbolic_torch_function_state self.stack = [] # stack of variable names for tracking 3.13 closures self.name_stack: list[Any] = [] @@ -2652,6 +2642,7 @@ def __init__( f_locals, f_globals, f_builtins, + torch_function_mode_stack, code_options, compiler_fn, one_graph, @@ -2686,7 +2677,7 @@ def __init__( symbolic_locals={}, # set below # A global var is inserted only after a STORE_GLOBAL happens to it symbolic_globals={}, - symbolic_torch_function_mode_stack=collections.deque(), + symbolic_torch_function_state=None, # type: ignore[arg-type] # set below f_code=f_code, export=export, inline_depth=0, @@ -2721,7 +2712,9 @@ def __init__( if k in f_locals } - self._init_torch_function_mode_stack() + self.symbolic_torch_function_state = SymbolicTorchFunctionState( + torch_function_mode_stack + ) self.debug_locals: List[Tuple[VariableTracker, List[VariableTracker]]] = [] if export: @@ -2762,29 +2755,6 @@ def _throw_if_in_functorch(self): ) unimplemented(msg) - def _init_torch_function_mode_stack(self): - from .variables.torch_function import TorchFunctionModeStackVariable - - TorchFunctionModeStackVariable.reset() - - self.symbolic_torch_function_mode_stack: Deque[ - TorchFunctionModeVariable - ] = collections.deque() - # We want to retrieve all modes to properly reconstruct the stack if needed - py_stack = get_torch_function_mode_stack(filter_ignored=False) - - if py_stack: - has_device_context = isinstance( - py_stack[0], torch.utils._device.DeviceContext - ) - - for i, val in enumerate(py_stack): - self.symbolic_torch_function_mode_stack.append( - variables.LazyVariableTracker.create( - val, source=TorchFunctionModeStackSource(i) - ) - ) - def get_example_value(self, source: Source): if isinstance(source, LocalSource): return self.f_locals[source.local_name] @@ -3116,7 +3086,7 @@ def get_trace_call_log_str(): code, sub_locals, parent.symbolic_globals, - parent.symbolic_torch_function_mode_stack, + parent.symbolic_torch_function_state, closure_cells, func, ) @@ -3126,7 +3096,7 @@ def get_trace_call_log_str(): code, sub_locals, parent.symbolic_globals, - parent.symbolic_torch_function_mode_stack, + parent.symbolic_torch_function_state, closure_cells, func, ) @@ -3179,7 +3149,7 @@ def __init__( code: types.CodeType, symbolic_locals: Dict[str, VariableTracker], symbolic_globals: Dict[str, VariableTracker], - symbolic_torch_function_mode_stack: Deque["TorchFunctionModeVariable"], + symbolic_torch_function_state: SymbolicTorchFunctionState, closure_cells: Dict[str, VariableTracker], funcvar: BaseUserFunctionVariable, ) -> None: @@ -3196,7 +3166,7 @@ def __init__( f_builtins=f_builtins, symbolic_locals=symbolic_locals, symbolic_globals=symbolic_globals, - symbolic_torch_function_mode_stack=symbolic_torch_function_mode_stack, + symbolic_torch_function_state=symbolic_torch_function_state, instructions=instructions, code_options={k: getattr(code, k) for k in get_code_keys()}, f_code=code, diff --git a/torch/_dynamo/trace_rules.py b/torch/_dynamo/trace_rules.py index 391beb07cd600d..1a206ae339c2db 100644 --- a/torch/_dynamo/trace_rules.py +++ b/torch/_dynamo/trace_rules.py @@ -3258,6 +3258,7 @@ def _module_dir(m: types.ModuleType): "torch.testing", "torch.utils._content_store", "torch.utils._contextlib", + "torch.utils._device", "torch.utils._foreach_utils", "torch.utils._python_dispatch", "torch.utils._pytree", @@ -3592,7 +3593,9 @@ def lookup_inner( if reasons is not None: reasons.add("func name is patched_init") return SkipFunctionVariable - elif name == "__torch_function__": + elif name == "__torch_function__" or ( + obj and obj.__name__ == "__torch_function__" + ): if reasons is not None: reasons.add("func name is __torch_function__") return UserFunctionVariable diff --git a/torch/_dynamo/utils.py b/torch/_dynamo/utils.py index 0ec548d788f868..ef63ddbea85619 100644 --- a/torch/_dynamo/utils.py +++ b/torch/_dynamo/utils.py @@ -63,7 +63,6 @@ import torch.utils._pytree as pytree from torch import fx from torch._C import ( - _get_function_stack_at, _instruction_counter, _len_torch_function_stack, _pop_torch_function_stack, @@ -3065,7 +3064,9 @@ def is_parameter_freezing(): def get_torch_function_mode_stack(filter_ignored=True): from .variables.torch_function import IGNORED_MODES - stack = [_get_function_stack_at(i) for i in range(_len_torch_function_stack())] + stack = [ + get_torch_function_mode_stack_at(i) for i in range(_len_torch_function_stack()) + ] if filter_ignored: stack = [mode for mode in stack if type(mode) not in IGNORED_MODES] @@ -3085,6 +3086,11 @@ def set_torch_function_mode_stack(stack): _push_on_torch_function_stack(mode) +def clear_torch_function_mode_stack(): + for i in range(_len_torch_function_stack()): + _pop_torch_function_stack() + + def verify_guard_fn_signature(value): fn = value.__metadata_guard__ sig = inspect.signature(fn) diff --git a/torch/_dynamo/variables/ctx_manager.py b/torch/_dynamo/variables/ctx_manager.py index eebe90929b379a..8c4eb3dc4e715b 100644 --- a/torch/_dynamo/variables/ctx_manager.py +++ b/torch/_dynamo/variables/ctx_manager.py @@ -637,6 +637,8 @@ def enter(self, tx): def _call_func(self, tx: "InstructionTranslator", values): assert len(values) == 1 + tx.symbolic_torch_function_state.torch_function_subclass_enabled = values[0] + tx.symbolic_torch_function_state.torch_function_mode_enabled = values[0] tx.output.set_torch_function_state(values[0]) diff --git a/torch/_dynamo/variables/torch.py b/torch/_dynamo/variables/torch.py index 68508f001dae35..f6cb565ef021c0 100644 --- a/torch/_dynamo/variables/torch.py +++ b/torch/_dynamo/variables/torch.py @@ -156,6 +156,15 @@ bin_ops = dict.fromkeys(["add", "sub", "mul", "div", "sqrt"]) +@functools.lru_cache(None) +def get_overridable_functions(): + from itertools import chain + + from torch.overrides import get_overridable_functions as get_overridable_functions_ + + return set(chain(*get_overridable_functions_().values())) + + class BaseTorchVariable(VariableTracker): """common base for all torch.* functions, classes, modules and other things""" @@ -806,10 +815,10 @@ def handle_pop_torch_function( self, tx: "InstructionTranslator", *args, **kwargs ): assert not args and not kwargs - if not tx.symbolic_torch_function_mode_stack: + if not tx.symbolic_torch_function_state.mode_stack: raise unimplemented("Popping from an empty torch function mode stack") TorchFunctionModeStackVariable.register_mutation(tx) - return tx.symbolic_torch_function_mode_stack.pop() + return tx.symbolic_torch_function_state.pop_torch_function_mode() @register(torch._C._push_on_torch_function_stack) def handle_push_torch_function( @@ -817,7 +826,7 @@ def handle_push_torch_function( ): assert len(args) == 1 and not kwargs TorchFunctionModeStackVariable.register_mutation(tx) - tx.symbolic_torch_function_mode_stack.append(args[0]) + tx.symbolic_torch_function_state.push_torch_function_mode(args[0]) return ConstantVariable.create(None) @register(torch._C._len_torch_function_stack) @@ -825,7 +834,9 @@ def handle_len_torch_function( self, tx: "InstructionTranslator", *args, **kwargs ): assert not args and not kwargs - return ConstantVariable.create(len(tx.symbolic_torch_function_mode_stack)) + return ConstantVariable.create( + len(tx.symbolic_torch_function_state.mode_stack) + ) @register(torch.set_default_device) def handle_set_default_device( @@ -857,6 +868,9 @@ def call_function( from . import ConstantVariable, SymNodeVariable, TensorVariable from .builder import wrap_fx_proxy + if self.torch_function_override_enabled(tx, args, kwargs): + return dispatch_torch_function(tx, self, args, kwargs) + if self.can_constant_fold_through() and check_unspec_or_constant_args( args, kwargs ): @@ -878,147 +892,144 @@ def call_function( if result: return result - if can_dispatch_torch_function(tx, args, kwargs): - return dispatch_torch_function(tx, self, args, kwargs) - else: - any_symints_or_symfloats = any(isinstance(x, SymNodeVariable) for x in args) + any_symints_or_symfloats = any(isinstance(x, SymNodeVariable) for x in args) - all_ints_or_floats = all( - isinstance(x, (variables.ConstantVariable, variables.SymNodeVariable)) - for x in args - ) - if ( - getattr(self.value, "__module__", "") == "torch" - and self.value.__name__ in bin_ops - and any_symints_or_symfloats - and all_ints_or_floats - ): - msg = f"""\ + all_ints_or_floats = all( + isinstance(x, (variables.ConstantVariable, variables.SymNodeVariable)) + for x in args + ) + if ( + getattr(self.value, "__module__", "") == "torch" + and self.value.__name__ in bin_ops + and any_symints_or_symfloats + and all_ints_or_floats + ): + msg = f"""\ Calling {str(self.value)} on only torch.SymInt arguments is not yet supported. To support this behavior, we need to allow const-propping tensors that store symint data. For now, dynamo will explicitly graph break when it encounters user code with this behavior. """ - log.warning(msg) - unimplemented(msg) - - # TODO(voz): Replace w/ dynamic shape rewrite table. - # Ideally, we would be able to do this at ctor time, but alas we need a combination - # of value + args to determine this. - fn_ = self.value - if any_symints_or_symfloats: - torch_sym_op = f"_sym_{self.value.__name__}" - if getattr(self.value, "__module__", None) == "math" and hasattr( - torch, torch_sym_op - ): - fn_ = getattr(torch, torch_sym_op) - - fake_out_shape = None - if "out" in kwargs and isinstance(kwargs["out"], variables.TensorVariable): - # Calling fake tensor propagation can mutate the out= tensor in - # tx.output.tracked_fakes. tracked_fakes are used to apply - # symbolic_shape guards. Mutating them destroys the information - # prior to tracing, which is essential for creating right - # guards. So save the shape now, and check later if it has - # changed. If it has, graph break. - fake_out_shape = kwargs["out"].proxy.node.meta["example_value"].shape - - tensor_variable = wrap_fx_proxy( - tx=tx, - proxy=tx.output.create_proxy( - "call_function", - fn_, - *proxy_args_kwargs(args, kwargs), - ), - ) - - if ( - isinstance(tensor_variable, TensorVariable) - and "requires_grad" in kwargs - and kwargs["requires_grad"].as_python_constant() + log.warning(msg) + unimplemented(msg) + + # TODO(voz): Replace w/ dynamic shape rewrite table. + # Ideally, we would be able to do this at ctor time, but alas we need a combination + # of value + args to determine this. + fn_ = self.value + if any_symints_or_symfloats: + torch_sym_op = f"_sym_{self.value.__name__}" + if getattr(self.value, "__module__", None) == "math" and hasattr( + torch, torch_sym_op ): - unimplemented( - """factory functions that return tensors that require grad are not supported. + fn_ = getattr(torch, torch_sym_op) + + fake_out_shape = None + if "out" in kwargs and isinstance(kwargs["out"], variables.TensorVariable): + # Calling fake tensor propagation can mutate the out= tensor in + # tx.output.tracked_fakes. tracked_fakes are used to apply + # symbolic_shape guards. Mutating them destroys the information + # prior to tracing, which is essential for creating right + # guards. So save the shape now, and check later if it has + # changed. If it has, graph break. + fake_out_shape = kwargs["out"].proxy.node.meta["example_value"].shape + + tensor_variable = wrap_fx_proxy( + tx=tx, + proxy=tx.output.create_proxy( + "call_function", + fn_, + *proxy_args_kwargs(args, kwargs), + ), + ) + + if ( + isinstance(tensor_variable, TensorVariable) + and "requires_grad" in kwargs + and kwargs["requires_grad"].as_python_constant() + ): + unimplemented( + """factory functions that return tensors that require grad are not supported. Either create the tensor outside the compiled region, or do not set the tensor to require_grad""" - ) + ) - if "out" in kwargs and not ( - isinstance(kwargs["out"], variables.ConstantVariable) - and kwargs["out"].as_python_constant() is None - ): - # out variants of torch operators like torch.sort and - # torch.sigmoid mutate the tensors in the out field. Track such - # tensors and rewrite the symbolic locals. - if isinstance(tensor_variable, TupleVariable): - assert isinstance(kwargs["out"], (TupleVariable, ListVariable)) - output_tensor_names = [ - tx.find_symbolic_locals_name(x) for x in kwargs["out"].items - ] - for idx, name in enumerate(output_tensor_names): - if name in tx.symbolic_locals: - tx.symbolic_locals[name] = tensor_variable.items[idx] - for out_tensor, result_tensor in zip( - kwargs["out"].items, tensor_variable.items - ): - if ( - out_tensor.source - and out_tensor in tx.output.graphargs - and isinstance(out_tensor, variables.TensorVariable) - and isinstance(result_tensor, variables.TensorVariable) - and out_tensor.size != result_tensor.size - ): - # It's hard to get out variants with resizing on graph inputs work - # properly across dynamo/aot/inductor, just fall back. - unimplemented("out variants with resizing on graph inputs") - elif isinstance(tensor_variable, TensorVariable): - assert isinstance(kwargs["out"], TensorVariable) - assert "example_value" in kwargs["out"].proxy.node.meta - fake_tensor = tensor_variable.proxy.node.meta["example_value"] - fake_out = kwargs["out"].proxy.node.meta["example_value"] + if "out" in kwargs and not ( + isinstance(kwargs["out"], variables.ConstantVariable) + and kwargs["out"].as_python_constant() is None + ): + # out variants of torch operators like torch.sort and + # torch.sigmoid mutate the tensors in the out field. Track such + # tensors and rewrite the symbolic locals. + if isinstance(tensor_variable, TupleVariable): + assert isinstance(kwargs["out"], (TupleVariable, ListVariable)) + output_tensor_names = [ + tx.find_symbolic_locals_name(x) for x in kwargs["out"].items + ] + for idx, name in enumerate(output_tensor_names): + if name in tx.symbolic_locals: + tx.symbolic_locals[name] = tensor_variable.items[idx] + for out_tensor, result_tensor in zip( + kwargs["out"].items, tensor_variable.items + ): if ( - kwargs["out"].source - and kwargs["out"] in tx.output.graphargs - and fake_out_shape != fake_tensor.shape + out_tensor.source + and out_tensor in tx.output.graphargs + and isinstance(out_tensor, variables.TensorVariable) + and isinstance(result_tensor, variables.TensorVariable) + and out_tensor.size != result_tensor.size ): # It's hard to get out variants with resizing on graph inputs work # properly across dynamo/aot/inductor, just fall back. unimplemented("out variants with resizing on graph inputs") + elif isinstance(tensor_variable, TensorVariable): + assert isinstance(kwargs["out"], TensorVariable) + assert "example_value" in kwargs["out"].proxy.node.meta + fake_tensor = tensor_variable.proxy.node.meta["example_value"] + fake_out = kwargs["out"].proxy.node.meta["example_value"] + if ( + kwargs["out"].source + and kwargs["out"] in tx.output.graphargs + and fake_out_shape != fake_tensor.shape + ): + # It's hard to get out variants with resizing on graph inputs work + # properly across dynamo/aot/inductor, just fall back. + unimplemented("out variants with resizing on graph inputs") + if not torch._prims_common.is_contiguous(fake_out): + # It's difficult to handle strides correctly in functionalization + # when calling an out= op with a non-contiguous out argument + unimplemented( + "out= op was called where output tensor was non-contiguous" + ) + name = tx.find_symbolic_locals_name(kwargs["out"]) + if name in tx.symbolic_locals: + tx.symbolic_locals[name] = tensor_variable + elif ( + isinstance(tensor_variable, ConstantVariable) + and tensor_variable.value is None + ): + # Handle out-variant custom ops that return None. + if isinstance(kwargs["out"], TensorVariable): + assert "example_value" in kwargs["out"].proxy.node.meta + fake_out = kwargs["out"].proxy.node.meta["example_value"] if not torch._prims_common.is_contiguous(fake_out): # It's difficult to handle strides correctly in functionalization # when calling an out= op with a non-contiguous out argument unimplemented( "out= op was called where output tensor was non-contiguous" ) - name = tx.find_symbolic_locals_name(kwargs["out"]) - if name in tx.symbolic_locals: - tx.symbolic_locals[name] = tensor_variable - elif ( - isinstance(tensor_variable, ConstantVariable) - and tensor_variable.value is None - ): - # Handle out-variant custom ops that return None. - if isinstance(kwargs["out"], TensorVariable): - assert "example_value" in kwargs["out"].proxy.node.meta - fake_out = kwargs["out"].proxy.node.meta["example_value"] + elif isinstance(kwargs["out"], ListVariable): + for idx, x in enumerate(kwargs["out"].items): + assert "example_value" in x.proxy.node.meta # type: ignore[attr-defined] + fake_out = x.proxy.node.meta["example_value"] # type: ignore[attr-defined] if not torch._prims_common.is_contiguous(fake_out): # It's difficult to handle strides correctly in functionalization # when calling an out= op with a non-contiguous out argument unimplemented( - "out= op was called where output tensor was non-contiguous" + "out= op was called where some of the output tensors were non-contiguous" ) - elif isinstance(kwargs["out"], ListVariable): - for idx, x in enumerate(kwargs["out"].items): - assert "example_value" in x.proxy.node.meta # type: ignore[attr-defined] - fake_out = x.proxy.node.meta["example_value"] # type: ignore[attr-defined] - if not torch._prims_common.is_contiguous(fake_out): - # It's difficult to handle strides correctly in functionalization - # when calling an out= op with a non-contiguous out argument - unimplemented( - "out= op was called where some of the output tensors were non-contiguous" - ) - else: - unimplemented(f"out variant of {type(kwargs['out'])}") + else: + unimplemented(f"out variant of {type(kwargs['out'])}") - return tensor_variable + return tensor_variable def _call_ntuple(self, tx: "InstructionTranslator", args, kwargs): """inline behavior of torch.nn.modules.utils._ntuple""" @@ -1146,3 +1157,12 @@ def _nn_param_via_prefix_insert(tx: "InstructionTranslator", data, requires_grad source ) return result + + def torch_function_override_enabled(self, tx, args, kwargs): + return ( + self.get_function() in get_overridable_functions() + or isinstance( + self.get_function(), + (torch._ops.OpOverload, torch._ops.OpOverloadPacket), + ) + ) and can_dispatch_torch_function(tx, args, kwargs) diff --git a/torch/_dynamo/variables/torch_function.py b/torch/_dynamo/variables/torch_function.py index 4b3188507fe4b5..6e52cc688e0309 100644 --- a/torch/_dynamo/variables/torch_function.py +++ b/torch/_dynamo/variables/torch_function.py @@ -1,8 +1,11 @@ # mypy: ignore-errors +import collections +import contextlib import inspect -from typing import Dict, List, TYPE_CHECKING +from typing import Deque, Dict, List, TYPE_CHECKING +import torch._C import torch.utils._pytree as pytree from torch._guards import Source from torch.overrides import _get_overloaded_args, get_default_nowrap_functions @@ -15,6 +18,7 @@ from .base import VariableTracker from .constant import ConstantVariable from .ctx_manager import ContextWrappingVariable +from .lazy import LazyVariableTracker from .lists import TupleVariable from .tensor import TensorSubclassVariable, TensorVariable from .user_defined import UserDefinedObjectVariable @@ -59,6 +63,67 @@ IGNORED_MODES = {DeviceContext} +class SymbolicTorchFunctionState: + def __init__(self, py_stack): + # This is annoyingly complicated because of how the torch function subclass + mode C API was designed + # There are two exposed C knobs here as contexts: torch._C.DisableTorchFunction and torch._C.DisableTorchFunctionSubclass + # These are their definitions: + # 1) torch._C._is_torch_function_enabled indicates that neither of the above knobs have been entered + # (if either are entered, this will be False) + # 2) torch._C._is_torch_function_mode_enabled indicates that either the torch mode stack is empty OR + # torch._C.DisableTorchFunction has been entered + # To disambiguate these and keep myself sane I added a C API to check whether all torch function + # concepts (modes and subclasses) are enabled. + # This only returns true iff we have not entered torch._C.DisableTorchFunction and allows us to separate + # the stack length from the enablement state of torch function modes. + # This is important because now if a mode is pushed while dynamo is tracing, we know whether + # or not torch function modes are enabled and whether we should trace it. + self.torch_function_subclass_enabled = torch._C._is_torch_function_enabled() + + # This differs from the C API of the same name + # this will only be false iff we have entered torch._C.DisableTorchFunction + # and does not take into account the mode stack length, while the C API bundles these + # two concepts + self.torch_function_mode_enabled = ( + not torch._C._is_torch_function_all_disabled() + ) + + self.cur_mode = None + + TorchFunctionModeStackVariable.reset() + + self.mode_stack: Deque[TorchFunctionModeVariable] = collections.deque() + + for i, val in enumerate(py_stack): + self.mode_stack.append( + LazyVariableTracker.create(val, source=TorchFunctionModeStackSource(i)) + ) + + def in_torch_function_mode(self): + return len(self.mode_stack) > 0 + + def pop_torch_function_mode(self): + return self.mode_stack.pop() + + def push_torch_function_mode(self, mode_var): + self.mode_stack.append(mode_var) + + def call_torch_function_mode(self, tx, fn, types, args, kwargs): + with self._pop_mode_for_inlining() as cur_mode: + return cur_mode.call_torch_function(tx, fn, types, args, kwargs) + + @contextlib.contextmanager + def _pop_mode_for_inlining(self): + old_mode = self.cur_mode + self.cur_mode = self.pop_torch_function_mode() + try: + yield self.cur_mode + finally: + mode = self.cur_mode + self.cur_mode = old_mode + self.push_torch_function_mode(mode) + + class TorchFunctionModeStackVariable(VariableTracker): """Fake VT to use as a dummy object, indicating the presence of torch function mode stack mutation""" @@ -88,19 +153,20 @@ def reset(cls): def register_mutation(cls, tx: "InstructionTranslator"): if cls.stack_value_singleton not in tx.output.side_effects: var = cls( - source=Source(), symbolic_stack=tx.symbolic_torch_function_mode_stack + source=Source(), + symbolic_stack=tx.symbolic_torch_function_state.mode_stack, ) tx.output.side_effects.track_mutable(cls.stack_value_singleton, var) tx.output.side_effects.mutation(var) @classmethod def register_device_context_insertion(cls, tx: "InstructionTranslator"): - stack = tx.symbolic_torch_function_mode_stack + stack = tx.symbolic_torch_function_state.mode_stack if stack and cls.is_device_context(stack[0]): return else: cls.offset += 1 - tx.symbolic_torch_function_mode_stack.insert( + stack.insert( 0, TorchFunctionModeVariable( None, source=TorchFunctionModeStackSource(-cls.offset) @@ -109,7 +175,7 @@ def register_device_context_insertion(cls, tx: "InstructionTranslator"): @classmethod def clear_default_device(cls, tx: "InstructionTranslator"): - stack = tx.symbolic_torch_function_mode_stack + stack = tx.symbolic_torch_function_state.mode_stack if stack and cls.is_device_context(stack[0]): stack.popleft() cls.offset -= 1 @@ -124,23 +190,39 @@ def get_mode_index(cls, ind): class TorchFunctionModeVariable(ContextWrappingVariable): - def __init__(self, value, **kwargs): + def __init__(self, value, source=None, **kwargs): super().__init__(value, **kwargs) self.value = value - - @staticmethod - def get_global_mangled_name(tx, val): - return get_safe_global_name( - tx, f"__torch_function_mode_{val.__class__.__name__}", val - ) + self.cm_obj = value # needed for BC with calling enter from CM code + self.source = source def reconstruct(self, codegen): - # We don't support locally created torch function modes yet + # This shouldn't be called unless we have a source assert self.source self.source.reconstruct(codegen) - def _call_func(self, tx, values): - unimplemented("torch function mode context manager is not supported yet") + def module_name(self): + return self.value.__module__ + + def fn_name(self): + return type(self.value).__name__ + + def python_type(self): + return type(self.value) + + def call_torch_function(self, tx: "InstructionTranslator", fn, types, args, kwargs): + return call_torch_function( + tx, + self, + build_torch_function_fn(tx, self.value, self.source), + fn, + types, + args, + kwargs, + ) + + def _call_func(self, tx: "InstructionTranslator", values): + unimplemented("enter/exit for torch function mode NYI") def _get_all_args(args, kwargs): @@ -231,9 +313,13 @@ def build_torch_function_fn(tx: "InstructionTranslator", value, source): def can_dispatch_torch_function(tx: "InstructionTranslator", args, kwargs): - return tx.output.torch_function_enabled and any( + has_overridden_args = any( has_torch_function(arg) for arg in _get_all_args(args, kwargs) ) + tf_state = tx.symbolic_torch_function_state + return (has_overridden_args and tf_state.torch_function_subclass_enabled) or ( + tf_state.torch_function_mode_enabled and tf_state.in_torch_function_mode() + ) def dispatch_torch_function(tx: "InstructionTranslator", fn, args, kwargs): @@ -245,11 +331,20 @@ def dispatch_torch_function(tx: "InstructionTranslator", fn, args, kwargs): _get_subclass_type, ) + types = TupleVariable([_get_subclass_type_var(tx, arg) for arg in overloaded_args]) + + if tx.symbolic_torch_function_state.in_torch_function_mode(): + res = tx.symbolic_torch_function_state.call_torch_function_mode( + tx, fn, types, args, kwargs + ) + if not (isinstance(res, ConstantVariable) and res.value is NotImplemented): + return res + for arg in overloaded_args: res = arg.call_torch_function( tx, fn, - TupleVariable([_get_subclass_type_var(tx, arg) for arg in overloaded_args]), + types, args, kwargs, ) diff --git a/torch/_dynamo/variables/user_defined.py b/torch/_dynamo/variables/user_defined.py index 9aa64a5a6a8d43..5d259fda55907d 100644 --- a/torch/_dynamo/variables/user_defined.py +++ b/torch/_dynamo/variables/user_defined.py @@ -82,11 +82,6 @@ def is_forbidden_context_manager(ctx): from _pytest.python_api import RaisesContext from _pytest.recwarn import WarningsChecker - # TODO mlazos: Temporary to get this stack to pass - # remove in subsequent PR - from torch.overrides import BaseTorchFunctionMode - - f_ctxs.append(BaseTorchFunctionMode) f_ctxs.append(RaisesContext) f_ctxs.append(WarningsChecker) except ImportError: @@ -417,7 +412,6 @@ def call_function( and self.source and not is_forbidden_context_manager(self.value) ): - # import here to avoid an unfortunate circular dependency. from .ctx_manager import GenericContextWrappingVariable cm_obj = tx.output.side_effects.track_object_new( @@ -425,7 +419,6 @@ def call_function( ) cm_obj.call_method(tx, "__init__", args, kwargs) return cm_obj - elif is_namedtuple_cls(self.value): fields = namedtuple_fields(self.value) # check if this a quasi-namedtuple or a real one diff --git a/torch/_export/non_strict_utils.py b/torch/_export/non_strict_utils.py index 5c7331659d26a0..ef15e5fea9e973 100644 --- a/torch/_export/non_strict_utils.py +++ b/torch/_export/non_strict_utils.py @@ -506,7 +506,11 @@ class _NonStrictTorchFunctionHandler(torch.overrides.TorchFunctionMode): def __torch_function__(self, func, types, args=(), kwargs=None): kwargs = kwargs or {} - if log.isEnabledFor(logging.DEBUG) and config.extended_debug_current_loc: + if ( + not torch.compiler.is_dynamo_compiling() + and log.isEnabledFor(logging.DEBUG) + and config.extended_debug_current_loc + ): frame = _find_user_code_frame() if frame is not None: log.debug( From 2f830e695934be736cb87e03b0fba96f227db2de Mon Sep 17 00:00:00 2001 From: Michael Lazos Date: Thu, 12 Sep 2024 01:28:23 -0700 Subject: [PATCH 0813/1018] [Dynamo] Support thread local setattr (#135443) In preparation for tracing through DeviceContext (https://github.com/pytorch/pytorch/blob/defb515306fc53ec62e92937a5a76fa5cbc05b84/torch/utils/_device.py#L66) This PR adds support for calling the setattr of thread local objects. These objects have a slots impl, and since this doesn't appear to have any side effects, we call this setattr impl when replaying mutations, since calling `object.__setattr__` on these objects results in a type error. Pull Request resolved: https://github.com/pytorch/pytorch/pull/135443 Approved by: https://github.com/anijain2305 ghstack dependencies: #134732, #133137 --- test/dynamo/test_misc.py | 15 +++++++++++++++ torch/_dynamo/variables/user_defined.py | 5 +++-- 2 files changed, 18 insertions(+), 2 deletions(-) diff --git a/test/dynamo/test_misc.py b/test/dynamo/test_misc.py index 9d2bb374621134..96047d3436ebe1 100644 --- a/test/dynamo/test_misc.py +++ b/test/dynamo/test_misc.py @@ -3380,6 +3380,21 @@ def fn4(x) -> None: self.assertTrue(same(obj41.y, obj42.y)) self.assertEqual(cnts.frame_count, 1) + def test_thread_local_setattr(self): + from threading import local + + loc = local() + + @torch.compile(fullgraph=True) + def fn(x, l): + l.x = x + return x + 1 + + x = torch.ones(2, 2) + fn(x, loc) + + self.assertTrue(loc.x is x) + def test_user_defined_class_name(self): class MyClassFoo: pass diff --git a/torch/_dynamo/variables/user_defined.py b/torch/_dynamo/variables/user_defined.py index 5d259fda55907d..cfb50a5d7268a1 100644 --- a/torch/_dynamo/variables/user_defined.py +++ b/torch/_dynamo/variables/user_defined.py @@ -9,6 +9,7 @@ import itertools import random import sys +import threading import types import warnings from typing import Dict, Generic, List, TYPE_CHECKING @@ -708,7 +709,7 @@ def call_method( if method is object.__init__: return ConstantVariable.create(None) - if is_standard_setattr(method): + if is_standard_setattr(method) or isinstance(self.value, threading.local): return self.method_setattr_standard(tx, *args, **kwargs) # [NOTE] OrderedDict, dict subtypes must always have source @@ -806,7 +807,7 @@ def method_setattr_standard(self, tx: "InstructionTranslator", name, value): def needs_slow_setattr(self): return not is_standard_setattr( inspect.getattr_static(self.value, "__setattr__", None) - ) + ) and not isinstance(self.value, threading.local) def unpack_var_sequence(self, tx): if ( From 6ec2fa8d4e617eff52d2b13573afa6035c8f3632 Mon Sep 17 00:00:00 2001 From: Michael Lazos Date: Thu, 12 Sep 2024 01:28:23 -0700 Subject: [PATCH 0814/1018] [Dynamo] Simplify torch function mode stack guard (#135444) The semantics of ignored modes previously had edge cases, this eliminates these by in essence filtering any ignored modes out of both the ref stack and the current torch function mode stack. This is purely to fix complexity in #135422. The ignored modes handling will be removed in a future PR after https://github.com/pytorch/pytorch/pull/135422 lands, since we will then trace through DeviceContexts vs inserting them into the graph which needed these extra workarounds for correctness. Pull Request resolved: https://github.com/pytorch/pytorch/pull/135444 Approved by: https://github.com/anijain2305, https://github.com/williamwen42 ghstack dependencies: #134732, #133137, #135443 --- test/dynamo/test_modes.py | 2 ++ torch/_dynamo/guards.py | 10 ++++---- torch/csrc/dynamo/guards.cpp | 44 +++++++++++++++++++++++++++++------- 3 files changed, 44 insertions(+), 12 deletions(-) diff --git a/test/dynamo/test_modes.py b/test/dynamo/test_modes.py index bad5a788903440..fa4c23fd320ddb 100644 --- a/test/dynamo/test_modes.py +++ b/test/dynamo/test_modes.py @@ -68,9 +68,11 @@ def tearDownClass(cls): def setUp(self): torch.set_default_device(None) + torch._dynamo.reset() def tearDown(self): torch.set_default_device(None) + torch._dynamo.reset() def _run_torch_function_mode_guard_test(self): class TestMode1(BaseTorchFunctionMode): diff --git a/torch/_dynamo/guards.py b/torch/_dynamo/guards.py index 4437ef4038190d..1544bdd6dc697b 100644 --- a/torch/_dynamo/guards.py +++ b/torch/_dynamo/guards.py @@ -2663,12 +2663,14 @@ def make_torch_function_mode_stack_guard(intial_stack): def check_torch_function_mode_stack(): cur_stack = get_torch_function_mode_stack() - if len(cur_stack) != len(types): + + types_ = [ty for ty in types if ty not in IGNORED_MODES] + cur_stack_ = [mode for mode in cur_stack if type(mode) not in IGNORED_MODES] + + if len(cur_stack_) != len(types_): return False - for ty, mode in zip(types, cur_stack): - if ty in IGNORED_MODES: - continue + for ty, mode in zip(types_, cur_stack_): if ty != type(mode): return False diff --git a/torch/csrc/dynamo/guards.cpp b/torch/csrc/dynamo/guards.cpp index afc3182c632640..771de44427b669 100644 --- a/torch/csrc/dynamo/guards.cpp +++ b/torch/csrc/dynamo/guards.cpp @@ -2523,7 +2523,8 @@ class TORCH_FUNCTION_MODE_STACK : public LeafGuard { Py_ssize_t len = PyList_Size(initial_stack.ptr()); for (Py_ssize_t idx = 0; idx < len; idx++) { PyObject* mode = PyList_GetItem(initial_stack.ptr(), idx); // borrowed ref - this->_ref_stack.push_back(Py_TYPE(mode)); + auto type = Py_TYPE(mode); + this->_ref_stack.push_back(type); } len = PyList_Size(ignored_types.ptr()); @@ -2543,29 +2544,56 @@ class TORCH_FUNCTION_MODE_STACK : public LeafGuard { bool check_nopybind(PyObject* value) override { // Ignore value arg, only used to satisfy the interface size_t ref_ind = 0; - int64_t len = at::impl::PythonTorchFunctionTLS::stack_len(); + const int64_t len = at::impl::PythonTorchFunctionTLS::stack_len(); const size_t ref_stack_size = this->_ref_stack.size(); - for (int64_t idx = 0; idx < len; idx++) { + int64_t idx = 0; + while ((idx < len) && (ref_ind < ref_stack_size)) { std::shared_ptr mode = at::impl::PythonTorchFunctionTLS::get_stack_at(idx); PyTypeObject* mode_type = Py_TYPE(mode->ptr(getPyInterpreter())); + bool act_ignored = this->_ignored_types.count(mode_type) > 0; + bool ref_ignored = + this->_ignored_types.count(this->_ref_stack.at(ref_ind)) > 0; // skip ignored types - if (this->_ignored_types.count(mode_type) > 0) { + if (act_ignored && ref_ignored) { + idx++; + ref_ind++; + continue; + } else if (ref_ignored) { + ref_ind++; + continue; + } else if (act_ignored) { + idx++; continue; } // if we already have more non-ignored modes than the ref stack // or if the mode doesn't match at the current index, return false - else if ( - (ref_stack_size == 0) || (ref_ind > ref_stack_size - 1) || - mode_type != _ref_stack[ref_ind]) { + else if (mode_type != _ref_stack.at(ref_ind)) { return false; } ref_ind++; + idx++; + } + + for (; ref_ind < ref_stack_size; ref_ind++) { + if (!(this->_ignored_types.count(this->_ref_stack.at(ref_ind)) > 0)) { + return false; + } + } + + for (; idx < len; idx++) { + std::shared_ptr mode = + at::impl::PythonTorchFunctionTLS::get_stack_at(idx); + + PyTypeObject* mode_type = Py_TYPE(mode->ptr(getPyInterpreter())); + if (!(this->_ignored_types.count(mode_type) > 0)) { + return false; + } } - return ref_ind == this->_ref_stack.size(); + return ref_ind == ref_stack_size && idx == len; } private: From 99a4bdbc3743383f3d4ebc4eff1c1a0950989d5f Mon Sep 17 00:00:00 2001 From: Michael Lazos Date: Thu, 12 Sep 2024 01:28:23 -0700 Subject: [PATCH 0815/1018] [Dynamo] Trace enter/exit of TorchFunctionModes (#135422) This PR implements tracing of with contexts with TorchFunction modes which have the default enter/exit behavior (ie pushing/popping the mode) Typically the bytecode for a context manager looks like this during a graph break: 1. graph call 2. enter context 3. unsupported code 4. exit context 5. resume call resume fn structure: 1. enter context 2. jump ... 3. exit context The issue with torch function modes is that side effects will replay any mutations to the torch function stack performed during tracing. So, we do not need to enter and exit around the unsupported code in the original function (doing so would result in a duplicate torch function mode entry during execution of the unsupported code), and we don't need to enter again in the resume function (the mode that was pushed from the side effects bytecode would still be on the stack). So for torch function modes the structure of our output code is this: 1. graph call 2. mutate tf mode stack to replay mutations 4. unsupported code 5. on exception restore stack 6. resume function Then our resume fn looks like this: 1. no-op enter torch function mode 2. jump 3. exit tf mode To implement the no-op enter of the torch function mode I added torch function mode in polyfill which no-op enters, but normally exits. This is needed because we still want to trace the with context in the resume function, and exit properly (the exit instructions will still be in the function, so we need to generate instructions to set up the context). Separately from the bytecode, dynamo also tracks contexts on the block stack, which is how the SETUP_* instructions are implemented. Naturally at a graph break, we exit these block stacks to properly reset the contexts entirely, so that we can re-enter around the unsupported code soundly. However once again, in the torch function mode case, in the event of a graph we do not want to perform any exit side effects because we want to preserve the state of the mode stack as is so that we will properly update the stack with bytecode mentioned in the first section. If we exited here, dynamo would pop the mode off of the symbolic stack, and not update the true python torch function mode stack with the suffix bytecode. All in all, for torch function modes we enter exactly once, update the global torch function mode stack with side effects bytecode, re-read this stack when compiling the resume function, and exit exactly once in the resume function. This matches the semantics of eager exactly. Pull Request resolved: https://github.com/pytorch/pytorch/pull/135422 Approved by: https://github.com/williamwen42 ghstack dependencies: #134732, #133137, #135443, #135444 --- test/dynamo/test_modes.py | 88 ++++++++++++++++++ torch/_dynamo/convert_frame.py | 6 +- torch/_dynamo/output_graph.py | 6 +- torch/_dynamo/polyfills/__init__.py | 20 ++++ torch/_dynamo/resume_execution.py | 104 +++++++++++++++++++++ torch/_dynamo/side_effects.py | 11 +++ torch/_dynamo/symbolic_convert.py | 30 ++++-- torch/_dynamo/testing.py | 1 + torch/_dynamo/trace_rules.py | 2 +- torch/_dynamo/variables/builder.py | 20 ++-- torch/_dynamo/variables/ctx_manager.py | 12 +++ torch/_dynamo/variables/torch.py | 14 ++- torch/_dynamo/variables/torch_function.py | 108 ++++++++++++++++++++-- torch/_dynamo/variables/user_defined.py | 14 ++- 14 files changed, 402 insertions(+), 34 deletions(-) diff --git a/test/dynamo/test_modes.py b/test/dynamo/test_modes.py index fa4c23fd320ddb..63e06a515ed186 100644 --- a/test/dynamo/test_modes.py +++ b/test/dynamo/test_modes.py @@ -461,6 +461,94 @@ def fn(x): self.assertEqual(expected, actual) + def test_torch_function_mode_enter_exit(self): + def fn(x, y): + with TestMode(): + o = torch.add(x, 3) + + return torch.add(o, y) + + inp = (torch.ones(2, 2) + 1, torch.ones(2, 2) + 2) + fn_opt = torch.compile(fn, fullgraph=True) + + expected = fn(*inp) + actual = fn_opt(*inp) + + self.assertEqual(expected, actual) + + def test_torch_function_mode_graph_break(self): + def fn(x, y): + with TestMode(): + torch._dynamo.graph_break() + o = torch.add(x, 3) + + return torch.add(o, y) + + inp = (torch.ones(2, 2) + 1, torch.ones(2, 2) + 2) + fn_opt = torch.compile(fn) + + expected = fn(*inp) + actual = fn_opt(*inp) + + self.assertEqual(expected, actual) + + def test_torch_function_mode_and_pop_graph_break(self): + def fn(x, y): + with TestMode(): + z = _pop_torch_function_stack() + torch._dynamo.graph_break() + _push_on_torch_function_stack(z) + o = torch.add(x, 3) + + return torch.add(o, y) + + inp = (torch.ones(2, 2) + 1, torch.ones(2, 2) + 2) + fn_opt = torch.compile(fn) + + expected = fn(*inp) + actual = fn_opt(*inp) + + self.assertEqual(expected, actual) + + def test_torch_function_mode_restore_on_exc(self): + @torch._dynamo.disable() + def err(): + raise RuntimeError("test") + + @torch.compile() + def fn(x): + with TestMode(): + x += 1 + err() + x += 2 + return x + + try: + fn(torch.ones(2, 2)) + except RuntimeError: + pass + self.assertEqual(_len_torch_function_stack(), 0) + + def test_torch_function_mode_and_pop_graph_break_mutation(self): + def fn(x, y): + with TestMode(): + z = _pop_torch_function_stack() + z.y = 5 + torch._dynamo.graph_break() + _push_on_torch_function_stack(z) + o = torch.add(x, 3) + o = torch.mul(o, z.y) + + return torch.add(o, y) + + inp = (torch.ones(2, 2) + 1, torch.ones(2, 2) + 2) + fn_opt = torch.compile(fn) + + expected = fn(*inp) + actual = fn_opt(*inp) + + self.assertEqual(expected, actual) + if __name__ == "__main__": from torch._dynamo.test_case import run_tests diff --git a/torch/_dynamo/convert_frame.py b/torch/_dynamo/convert_frame.py index 4a166e1ff80251..5f543545ecdba3 100644 --- a/torch/_dynamo/convert_frame.py +++ b/torch/_dynamo/convert_frame.py @@ -112,6 +112,7 @@ troubleshooting_url, write_record_to_file, ) +from .variables.torch_function import torch_function_mode_stack_state_mgr np: Optional[ModuleType] @@ -210,15 +211,18 @@ def _fn(*args: _P.args, **kwargs: _P.kwargs) -> _T: prior_fwd_from_src = torch.fx.graph_module._forward_from_src torch.fx.graph_module._forward_from_src = fx_forward_from_src_skip_result cleanup = setup_compile_debug() - exit_stack = contextlib.ExitStack() exit_stack.enter_context( torch.fx._symbolic_trace._maybe_revert_all_patches() ) + exit_stack.enter_context(torch_function_mode_stack_state_mgr) try: return fn(*args, **kwargs) finally: cleanup.close() + assert ( + torch._C._len_torch_function_stack() == 0 + ), "Torch function mode stack state changed while dynamo tracing, please report a bug" exit_stack.close() torch._C._set_grad_enabled(prior_grad_mode) torch.autograd.grad_mode._enter_inference_mode(prior_inference_mode) diff --git a/torch/_dynamo/output_graph.py b/torch/_dynamo/output_graph.py index ba8cfa5a03dfba..0ea19e1cad6586 100644 --- a/torch/_dynamo/output_graph.py +++ b/torch/_dynamo/output_graph.py @@ -78,7 +78,6 @@ get_instruction_source_311, get_locals_to_steal, get_static_address_type, - get_torch_function_mode_stack, graph_break_reasons, increment_op_count, lazy_format_graph_code, @@ -250,6 +249,7 @@ def __init__( local_scope: Scope, global_scope: Scope, f_code, + torch_function_mode_stack, ): super().__init__() self.tracers = [SubgraphTracer(self, export_root=export)] @@ -368,7 +368,7 @@ def __init__( # This returns false if TF Overall (both mode and subclass) is disabled OR that TF Mode stack is empty self.torch_function_mode_enabled = torch._C._is_torch_function_mode_enabled() # This records the initial torch function mode stack for guarding - self.torch_function_mode_stack = get_torch_function_mode_stack() + self.torch_function_mode_stack = torch_function_mode_stack # Tracks if the output graph has a user defined allowed function in the # graph. This is used later to determine if we should fallback to eager @@ -1020,7 +1020,7 @@ def append_prefix_insts(): prefix_insts.clear() for block in reversed(tx.block_stack): - block.exit(tx) + block.exit(tx, is_graph_break=reason.graph_break) self.cleanup_graph() tx.prune_dead_locals() diff --git a/torch/_dynamo/polyfills/__init__.py b/torch/_dynamo/polyfills/__init__.py index 2b3f38920a0c2a..5b2812bc08c9e5 100644 --- a/torch/_dynamo/polyfills/__init__.py +++ b/torch/_dynamo/polyfills/__init__.py @@ -25,6 +25,26 @@ sys as sys, ) +from torch.overrides import BaseTorchFunctionMode + + +# These classes handle support for TorchFunctionModes across +# graph breaks +# Today the TorchFunctionMode enter (for the classes we support) +# simply pushes the mode onto the stack. Since after this occurs +# the stack is mutated, and we replay these mutations, we don't need +# any cleanup logic to be run once the graph break occurs, we simply replay +# these mutations to ensure at the graph break the torch function mode stack is correct +# and reconstruct the torch function mode stack normally +# when we compile the resume function on the other side of the break. +# However, to ensure we exit properly +# in the resume function, we need to re-enter the contexts as we do other contexts. +# These contexts do nothing on enter, but provide the correct exit logic to ensure +# the stack state is correct. +class NoEnterTorchFunctionMode(BaseTorchFunctionMode): + def __enter__(self): + pass + def index(iterator, item, start=0, end=None): from itertools import islice diff --git a/torch/_dynamo/resume_execution.py b/torch/_dynamo/resume_execution.py index 132e9e4081bceb..6f7db66514c487 100644 --- a/torch/_dynamo/resume_execution.py +++ b/torch/_dynamo/resume_execution.py @@ -6,6 +6,7 @@ from typing import Any, cast, Dict, List, Optional, Tuple from .bytecode_transformation import ( + add_push_null, create_call_function, create_call_method, create_dup_top, @@ -48,6 +49,109 @@ class ReenterWith: stack_index: int target_values: Optional[Tuple[Any, ...]] = None + def try_except_torch_function_mode(self, code_options, cleanup: List[Instruction]): + """ + Codegen based off of: + try: + (rest) + except: + (restore previous stack) + + """ + from .variables.torch_function import get_prev_stack_var_name + + except_jump_target = create_instruction( + "NOP" if sys.version_info < (3, 11) else "PUSH_EXC_INFO" + ) + cleanup_complete_jump_target = create_instruction("NOP") + + setup_finally: List[Instruction] = [] + + if sys.version_info < (3, 11): + setup_finally.append( + create_instruction("SETUP_FINALLY", target=except_jump_target) + ) + else: + exn_tab_begin = create_instruction("NOP") + exn_tab_end = create_instruction("NOP") + exn_tab_begin.exn_tab_entry = InstructionExnTabEntry( + exn_tab_begin, + exn_tab_end, + except_jump_target, + self.stack_index + 1, + False, + ) + setup_finally.append(exn_tab_begin) + + def create_reset(): + insts = [ + create_instruction( + "LOAD_GLOBAL", argval="__import_torch_dot__dynamo_dot_utils" + ), + create_instruction("LOAD_ATTR", argval="set_torch_function_mode_stack"), + ] + add_push_null(insts) + return [ + *insts, + create_instruction("LOAD_FAST", argval=get_prev_stack_var_name()), + *create_call_function(1, False), + create_instruction("POP_TOP"), + ] + + if sys.version_info < (3, 9): + epilogue = [ + create_instruction("POP_BLOCK"), + create_instruction("JUMP_FORWARD", target=cleanup_complete_jump_target), + except_jump_target, + *create_reset(), + create_instruction("POP_TOP"), + create_instruction("POP_TOP"), + create_instruction("POP_TOP"), + *create_reset(), + create_instruction("RAISE_VARARGS", argval=0), + create_instruction("POP_EXCEPT", argval=0), + create_instruction("END_FINALLY"), + cleanup_complete_jump_target, + ] + elif sys.version_info < (3, 11): + epilogue = [ + create_instruction("POP_BLOCK"), + create_instruction("JUMP_FORWARD", target=cleanup_complete_jump_target), + except_jump_target, + create_instruction("POP_TOP"), + create_instruction("POP_TOP"), + create_instruction("POP_TOP"), + *create_reset(), + create_instruction("RAISE_VARARGS", argval=0), + create_instruction("POP_EXCEPT", argval=0), + cleanup_complete_jump_target, + ] + else: + finally_exn_tab_end = create_instruction("RAISE_VARARGS", argval=0) + finally_exn_tab_target = create_instruction("COPY", arg=3) + except_jump_target.exn_tab_entry = InstructionExnTabEntry( + except_jump_target, + finally_exn_tab_end, + finally_exn_tab_target, + self.stack_index + 2, + True, + ) + epilogue = [ + exn_tab_end, + create_instruction("JUMP_FORWARD", target=cleanup_complete_jump_target), + except_jump_target, # PUSH_EXC_INFO + create_instruction("POP_TOP"), + *create_reset(), + finally_exn_tab_end, + finally_exn_tab_target, # COPY 3 + create_instruction("POP_EXCEPT"), + create_instruction("RERAISE", arg=1), # RERAISE 1 + cleanup_complete_jump_target, + ] + + cleanup[:] = epilogue + cleanup + return setup_finally + # If we do not want to destroy the stack, we can do the same thing as a # `SETUP_WITH` block, only that we store the context manager in a local_symbol def try_except(self, code_options, cleanup: List[Instruction]): diff --git a/torch/_dynamo/side_effects.py b/torch/_dynamo/side_effects.py index fc1bd976ff57df..fd5b54e4f7f7af 100644 --- a/torch/_dynamo/side_effects.py +++ b/torch/_dynamo/side_effects.py @@ -593,11 +593,22 @@ def codegen_update_mutated(self, cg: PyCodegen): elif isinstance( var, variables.torch_function.TorchFunctionModeStackVariable ): + # Needed in the finally block for stack restoration + cg.add_push_null( + lambda: cg.load_import_from( + utils.__name__, "get_torch_function_mode_stack" + ) + ) + cg.call_function(0, False) + name = variables.torch_function.get_prev_stack_var_name() + cg.code_options["co_varnames"] += (name,) + cg.append_output(create_instruction("STORE_FAST", argval=name)) cg.add_push_null( lambda: cg.load_import_from( utils.__name__, "set_torch_function_mode_stack" ) ) + cg.foreach(var.symbolic_stack) cg.append_output( create_instruction("BUILD_LIST", arg=len(var.symbolic_stack)) diff --git a/torch/_dynamo/symbolic_convert.py b/torch/_dynamo/symbolic_convert.py index 415179a5ed32a6..61acaec995757a 100644 --- a/torch/_dynamo/symbolic_convert.py +++ b/torch/_dynamo/symbolic_convert.py @@ -267,13 +267,12 @@ def resume_fn(self): else: return ReenterWith(self.stack_index) - def exit(self, tx): - if hasattr(self, "graph_break") and isinstance( - self.with_context, TorchFunctionModeVariable - ): - return + def exit(self, tx, is_graph_break): assert self.with_context is not None - return self.with_context.exit(tx) + if ( + is_graph_break and self.with_context.exit_on_graph_break() + ) or not is_graph_break: + return self.with_context.exit(tx) class ReturnValueOp(Exception): @@ -639,10 +638,17 @@ def handle_graph_break( cleanup: List[Instruction] = [] # Reconstruct the context variable CLASS in the block stack for b in self.block_stack: + # Don't exit any modes we have entered, + # output bytecode will mutate the tf mode stack accordingly + if isinstance(b.with_context, TorchFunctionModeVariable): + cg.extend_output( + b.resume_fn().try_except_torch_function_mode( + cg.code_options, cleanup + ) + ) + continue assert b.with_context is not None - assert isinstance( - b.with_context, (ContextWrappingVariable, TorchFunctionModeVariable) - ) + assert isinstance(b.with_context, (ContextWrappingVariable)) b.with_context.reconstruct_type(cg) cg.extend_output(b.resume_fn().try_except(cg.code_options, cleanup)) self.output.add_output_instructions(cg.get_instructions()) @@ -2295,7 +2301,10 @@ def setup_or_before_with(self, inst): ): unimplemented(f"{inst.opname} {ctx}") - if isinstance(ctx, GenericContextWrappingVariable): + if ( + isinstance(ctx, GenericContextWrappingVariable) + and not ctx.supports_graph_breaks() + ): self.generic_context_manager_depth += 1 # Need this redundant check for mypy @@ -2668,6 +2677,7 @@ def __init__( local_scope=f_locals, global_scope=f_globals, f_code=f_code, + torch_function_mode_stack=torch_function_mode_stack, ), instructions=instructions, f_locals=f_locals, diff --git a/torch/_dynamo/testing.py b/torch/_dynamo/testing.py index 4922b521bada34..704a3889707233 100644 --- a/torch/_dynamo/testing.py +++ b/torch/_dynamo/testing.py @@ -163,6 +163,7 @@ def insert_nops(instructions, code_options): local_scope=locals(), global_scope=globals(), f_code=frame.f_code, + torch_function_mode_stack=[], ) return GuardedCode(code, CheckFunctionManager(graph).check_fn, CompileId(0, 0)) diff --git a/torch/_dynamo/trace_rules.py b/torch/_dynamo/trace_rules.py index 1a206ae339c2db..ed81b26191b4de 100644 --- a/torch/_dynamo/trace_rules.py +++ b/torch/_dynamo/trace_rules.py @@ -304,6 +304,7 @@ "torch.fx.experimental.symbolic_shapes.guard_size_oblivious": TorchInGraphFunctionVariable, "torch.cuda._get_device_properties": TorchInGraphFunctionVariable, "torch.utils.hooks.BackwardHook": TorchInGraphFunctionVariable, + "torch.set_default_device": UserFunctionVariable, "torch.sparse_bsc_tensor": SkipFunctionVariable, "torch.sparse_bsr_tensor": SkipFunctionVariable, "torch.sparse_csc_tensor": SkipFunctionVariable, @@ -2797,7 +2798,6 @@ "torch.random.initial_seed", "torch.random.seed", "torch.return_types.pytree_register_structseq", - "torch.set_default_device", "torch.set_default_dtype", "torch.set_default_tensor_type", "torch.set_deterministic_debug_mode", diff --git a/torch/_dynamo/variables/builder.py b/torch/_dynamo/variables/builder.py index 5a8a9b613b899d..193e227267ec3b 100644 --- a/torch/_dynamo/variables/builder.py +++ b/torch/_dynamo/variables/builder.py @@ -204,6 +204,7 @@ from .torch_function import ( build_torch_function_fn, TensorWithTFOverrideVariable, + torch_function_mode_stack_state_mgr, TorchFunctionModeVariable, ) from .user_defined import ( @@ -1668,15 +1669,16 @@ def wrap_numpy_ndarray(self, value): # but warning is not the end of the world assert isinstance(value.base, np.nditer) - try: - tensor_value = _util._try_convert_to_tensor(value) - if readonly: - from torch._prims_common import clone_preserve_strides - - tensor_value = clone_preserve_strides(tensor_value) - except NotImplementedError as e: - # failed to convert to tensor, graph break - unimplemented(str(e)) + with torch_function_mode_stack_state_mgr.temp_restore_stack(): + try: + tensor_value = _util._try_convert_to_tensor(value) + if readonly: + from torch._prims_common import clone_preserve_strides + + tensor_value = clone_preserve_strides(tensor_value) + except NotImplementedError as e: + # failed to convert to tensor, graph break + unimplemented(str(e)) # We do this because we want the full behavior of guarding the numpy ndarray as if it were # a tensor. It's a little annoying to make a VT to throw out, but there's so many side effects here diff --git a/torch/_dynamo/variables/ctx_manager.py b/torch/_dynamo/variables/ctx_manager.py index 8c4eb3dc4e715b..e19c4e254c647e 100644 --- a/torch/_dynamo/variables/ctx_manager.py +++ b/torch/_dynamo/variables/ctx_manager.py @@ -125,6 +125,12 @@ def call_function( if isinstance(args[0], UserFunctionVariable): return WrappedUserFunctionVariable(args[0], self) + def supports_graph_breaks(self): + return True + + def exit_on_graph_break(self): + return True + class GenericContextWrappingVariable(UserDefinedObjectVariable): # Some methods in ContextWrappingVariable assumes the arguments are @@ -183,6 +189,12 @@ def exit(self, tx: "InstructionTranslator", *args): tx.generic_context_manager_depth -= 1 return x + def supports_graph_breaks(self): + return False + + def exit_on_graph_break(self): + return True + class GradInplaceRequiresGradCtxManagerVariable(ContextWrappingVariable): """represents torch grad requries grad""" diff --git a/torch/_dynamo/variables/torch.py b/torch/_dynamo/variables/torch.py index f6cb565ef021c0..48f4f4101ba96a 100644 --- a/torch/_dynamo/variables/torch.py +++ b/torch/_dynamo/variables/torch.py @@ -162,7 +162,10 @@ def get_overridable_functions(): from torch.overrides import get_overridable_functions as get_overridable_functions_ - return set(chain(*get_overridable_functions_().values())) + funcs = set(chain(*get_overridable_functions_().values())) + more = {torch.ones, torch.ones_like, torch.zeros, torch.zeros_like, torch.empty} + funcs.update(more) + return funcs class BaseTorchVariable(VariableTracker): @@ -838,6 +841,13 @@ def handle_len_torch_function( len(tx.symbolic_torch_function_state.mode_stack) ) + @register(torch._C._get_function_stack_at) + def handle_get_stack_at(self, tx: "InstructionTranslator", *args, **kwargs): + assert len(args) == 1 and not kwargs + ind = args[0].as_python_constant() + assert ind >= 0 and ind < len(tx.symbolic_torch_function_state.mode_stack) + return tx.symbolic_torch_function_state.mode_stack[ind] + @register(torch.set_default_device) def handle_set_default_device( self, tx: "InstructionTranslator", *args, **kwargs @@ -855,7 +865,7 @@ def handle_set_default_device( else: TorchFunctionModeStackVariable.register_device_context_insertion(tx) - return None + return ConstantVariable.create(None) return handlers diff --git a/torch/_dynamo/variables/torch_function.py b/torch/_dynamo/variables/torch_function.py index 6e52cc688e0309..b7fd83a3d24f15 100644 --- a/torch/_dynamo/variables/torch_function.py +++ b/torch/_dynamo/variables/torch_function.py @@ -2,22 +2,35 @@ import collections import contextlib +import functools import inspect from typing import Deque, Dict, List, TYPE_CHECKING import torch._C import torch.utils._pytree as pytree from torch._guards import Source -from torch.overrides import _get_overloaded_args, get_default_nowrap_functions +from torch.overrides import ( + _get_overloaded_args, + get_default_nowrap_functions, + TorchFunctionMode, +) from torch.utils._device import DeviceContext from ..exc import unimplemented from ..guards import GuardBuilder, install_guard +from ..polyfills import NoEnterTorchFunctionMode from ..source import AttrSource, GlobalSource, TorchFunctionModeStackSource, TypeSource -from ..utils import get_safe_global_name, has_torch_function, is_tensor_base_attr_getter +from ..utils import ( + class_has_getattribute, + clear_torch_function_mode_stack, + get_safe_global_name, + has_torch_function, + is_tensor_base_attr_getter, + set_torch_function_mode_stack, +) from .base import VariableTracker from .constant import ConstantVariable -from .ctx_manager import ContextWrappingVariable +from .ctx_manager import GenericContextWrappingVariable from .lazy import LazyVariableTracker from .lists import TupleVariable from .tensor import TensorSubclassVariable, TensorVariable @@ -63,6 +76,39 @@ IGNORED_MODES = {DeviceContext} +@functools.lru_cache(None) +def get_prev_stack_var_name(): + from ..bytecode_transformation import unique_id + + return unique_id("___prev_torch_function_mode_stack") + + +# Used to clear/restore the python torch function mode stack and temporarily restore it as needed +class TorchFunctionModeStackStateManager: + def __init__(self): + self.stack = [] + + def __enter__(self): + self.stack = torch.overrides._get_current_function_mode_stack() + clear_torch_function_mode_stack() + + def __exit__(self, exc_type, exc_value, traceback): + set_torch_function_mode_stack(self.stack) + self.stack = [] + + @contextlib.contextmanager + def temp_restore_stack(self): + prev = torch.overrides._get_current_function_mode_stack() + set_torch_function_mode_stack(self.stack) + try: + yield + finally: + set_torch_function_mode_stack(prev) + + +torch_function_mode_stack_state_mgr = TorchFunctionModeStackStateManager() + + class SymbolicTorchFunctionState: def __init__(self, py_stack): # This is annoyingly complicated because of how the torch function subclass + mode C API was designed @@ -189,9 +235,26 @@ def get_mode_index(cls, ind): return ind + cls.offset -class TorchFunctionModeVariable(ContextWrappingVariable): +class TorchFunctionModeVariable(GenericContextWrappingVariable): + @staticmethod + def is_supported_torch_function_mode(ty): + # Supported in this sense means we can support graph breaks under the + # context. + # We are able to trace custom modes but if there are graph breaks under them + # and they have a custom __enter__/__exit__ we don't handle this for the + # same reason we don't handle generic context managers: there may be side effects + # that are now affected by executing the funtion across two frames instead of one + # Today we support the enter/exit of the default TorchFunctionMode as well as + # DeviceContext (which is used for set_default_device) + return issubclass(ty, (NoEnterTorchFunctionMode, DeviceContext)) or ( + not class_has_getattribute(ty) + and inspect.getattr_static(ty, "__enter__") == TorchFunctionMode.__enter__ + and inspect.getattr_static(ty, "__exit__") == TorchFunctionMode.__exit__ + ) + def __init__(self, value, source=None, **kwargs): - super().__init__(value, **kwargs) + if value is not None: + super().__init__(value, **kwargs) self.value = value self.cm_obj = value # needed for BC with calling enter from CM code self.source = source @@ -221,8 +284,39 @@ def call_torch_function(self, tx: "InstructionTranslator", fn, types, args, kwar kwargs, ) - def _call_func(self, tx: "InstructionTranslator", values): - unimplemented("enter/exit for torch function mode NYI") + def enter(self, tx): + from .torch import TorchInGraphFunctionVariable + + if isinstance(self.value, NoEnterTorchFunctionMode): + return ConstantVariable.create(None) + + TorchInGraphFunctionVariable( + torch._C._push_on_torch_function_stack + ).call_function(tx, [self], {}) + return ConstantVariable.create(None) + + def exit(self, tx: "InstructionTranslator", *args): + from .torch import TorchInGraphFunctionVariable + + TorchInGraphFunctionVariable(torch._C._pop_torch_function_stack).call_function( + tx, [], {} + ) + return ConstantVariable.create(None) + + def reconstruct_type(self, codegen): + ty = NoEnterTorchFunctionMode + codegen( + AttrSource( + codegen.tx.import_source(ty.__module__), + ty.__name__, + ) + ) + + def supports_graph_breaks(self): + return True + + def exit_on_graph_break(self): + return False def _get_all_args(args, kwargs): diff --git a/torch/_dynamo/variables/user_defined.py b/torch/_dynamo/variables/user_defined.py index cfb50a5d7268a1..32057c838d74c8 100644 --- a/torch/_dynamo/variables/user_defined.py +++ b/torch/_dynamo/variables/user_defined.py @@ -413,10 +413,22 @@ def call_function( and self.source and not is_forbidden_context_manager(self.value) ): + from torch.overrides import TorchFunctionMode + from .ctx_manager import GenericContextWrappingVariable + from .torch_function import TorchFunctionModeVariable + + if issubclass( + self.value, TorchFunctionMode + ) and TorchFunctionModeVariable.is_supported_torch_function_mode( + self.value + ): + var_cls = TorchFunctionModeVariable + else: + var_cls = GenericContextWrappingVariable cm_obj = tx.output.side_effects.track_object_new( - self.source, self.value, GenericContextWrappingVariable, {} + self.source, self.value, var_cls, {} ) cm_obj.call_method(tx, "__init__", args, kwargs) return cm_obj From 1d173838985775f744cff20989c4d97775869748 Mon Sep 17 00:00:00 2001 From: Michael Lazos Date: Thu, 12 Sep 2024 01:28:24 -0700 Subject: [PATCH 0816/1018] [Dynamo] Remove ignored modes workaround (#135502) Pull Request resolved: https://github.com/pytorch/pytorch/pull/135502 Approved by: https://github.com/anijain2305 ghstack dependencies: #134732, #133137, #135443, #135444, #135422 --- test/dynamo/test_modes.py | 65 ----------------------- torch/_dynamo/guards.py | 12 ++--- torch/_dynamo/utils.py | 10 +--- torch/_dynamo/variables/torch_function.py | 6 --- 4 files changed, 5 insertions(+), 88 deletions(-) diff --git a/test/dynamo/test_modes.py b/test/dynamo/test_modes.py index 63e06a515ed186..4d1f2bbea389ef 100644 --- a/test/dynamo/test_modes.py +++ b/test/dynamo/test_modes.py @@ -1,5 +1,4 @@ # Owner(s): ["module: dynamo"] -from unittest.mock import patch import torch import torch._dynamo.test_case @@ -107,70 +106,6 @@ def fn(x): fn(inp) self.assertEqual(cnt.frame_count, 4) - def _run_ignored_mode_types_test(self): - class IgnoredMode(BaseTorchFunctionMode): - pass - - cnt = torch._dynamo.testing.CompileCounter() - - @torch.compile(backend=cnt.__call__, fullgraph=True) - def fn(x): - return x + 1 - - inp = torch.ones(2, 2) - - with patch( - "torch._dynamo.variables.torch_function.IGNORED_MODES", {IgnoredMode} - ): - # initial compile - fn(inp) - - # no recompile, mode ignored - # note: the ref stack is length 0, and the stack we are checking against has length 2 - # we want to check both ref stack len > runtime stack, and ref stack len < runtime stack - with IgnoredMode(), IgnoredMode(): - fn(inp) - - self.assertEqual(cnt.frame_count, 1) - - # recompile due to new mode on the stack - with BaseTorchFunctionMode(), BaseTorchFunctionMode(), BaseTorchFunctionMode(): - fn(inp) - - self.assertEqual(cnt.frame_count, 2) - - # recompile - # tests both ref stack len > runtime stack len for the above guard check - # and ref stack len < runtime stack len for the initial zero mode case - with BaseTorchFunctionMode(), IgnoredMode(), BaseTorchFunctionMode(): - fn(inp) - - self.assertEqual(cnt.frame_count, 3) - - # no recompile - with IgnoredMode(), IgnoredMode(), BaseTorchFunctionMode(), BaseTorchFunctionMode(): - fn(inp) - - self.assertEqual(cnt.frame_count, 3) - - # This is tricky, basically the ignored modes are baked into the guard - # IgnoredMode will be ignored forever by that guard. - # This is okay since we don't expect to be modifying IGNORED_MODES - # in the middle of execution except for the purposes of testing. - torch._dynamo.reset() - - with IgnoredMode(): - fn(inp) - - self.assertEqual(cnt.frame_count, 4) - - @torch._dynamo.config.patch("enable_cpp_guard_manager", False) - def test_torch_function_mode_guards_ignored_types_py(self): - self._run_ignored_mode_types_test() - - def test_torch_function_mode_guards_ignored_types_cpp(self): - self._run_ignored_mode_types_test() - @torch._dynamo.config.patch("enable_cpp_guard_manager", False) def test_torch_function_mode_guards_py(self): self._run_torch_function_mode_guard_test() diff --git a/torch/_dynamo/guards.py b/torch/_dynamo/guards.py index 1544bdd6dc697b..66deeb3cbc0e79 100644 --- a/torch/_dynamo/guards.py +++ b/torch/_dynamo/guards.py @@ -2344,15 +2344,13 @@ def compile_check_fn(self, builder, guards_out, guard_fail_fn): ) if config.enable_cpp_guard_manager: - from .variables.torch_function import IGNORED_MODES - # Insert the global_state guard assert self.guard_manager # to make mypy happy self.guard_manager.root.add_global_state_guard(["___check_global_state()"]) self.guard_manager.root.add_torch_function_mode_stack_guard( self.torch_function_mode_stack, - list(IGNORED_MODES), + list(), ["___check_torch_function_mode_stack()"], ) # Clear references to torch_function modes held in the list @@ -2659,18 +2657,14 @@ def is_recompiles_verbose_enabled(): # this will only be used if cpp guards are disabled def make_torch_function_mode_stack_guard(intial_stack): types = [type(x) for x in intial_stack] - from .variables.torch_function import IGNORED_MODES def check_torch_function_mode_stack(): cur_stack = get_torch_function_mode_stack() - types_ = [ty for ty in types if ty not in IGNORED_MODES] - cur_stack_ = [mode for mode in cur_stack if type(mode) not in IGNORED_MODES] - - if len(cur_stack_) != len(types_): + if len(cur_stack) != len(types): return False - for ty, mode in zip(types_, cur_stack_): + for ty, mode in zip(types, cur_stack): if ty != type(mode): return False diff --git a/torch/_dynamo/utils.py b/torch/_dynamo/utils.py index ef63ddbea85619..6f3a74eb2625d2 100644 --- a/torch/_dynamo/utils.py +++ b/torch/_dynamo/utils.py @@ -3061,16 +3061,10 @@ def is_parameter_freezing(): return torch._inductor.config.freezing and not torch.is_grad_enabled() -def get_torch_function_mode_stack(filter_ignored=True): - from .variables.torch_function import IGNORED_MODES - - stack = [ +def get_torch_function_mode_stack(): + return [ get_torch_function_mode_stack_at(i) for i in range(_len_torch_function_stack()) ] - if filter_ignored: - stack = [mode for mode in stack if type(mode) not in IGNORED_MODES] - - return stack def get_torch_function_mode_stack_at(ind): diff --git a/torch/_dynamo/variables/torch_function.py b/torch/_dynamo/variables/torch_function.py index b7fd83a3d24f15..ffb3d27d4d703b 100644 --- a/torch/_dynamo/variables/torch_function.py +++ b/torch/_dynamo/variables/torch_function.py @@ -69,12 +69,6 @@ if is_tensor_base_attr_getter(fn) ] -# Today set default device is placed in the graph and guarded on separately -# so we should not trace through it. In the future we can trace it once -# mode tracing is implemented and not put in the graph, but this is more -# of a BE project and can be evaluated later -IGNORED_MODES = {DeviceContext} - @functools.lru_cache(None) def get_prev_stack_var_name(): From 3f7be503cdd03a5666a294043340713ff8ac965e Mon Sep 17 00:00:00 2001 From: Michael Lazos Date: Thu, 12 Sep 2024 01:28:24 -0700 Subject: [PATCH 0817/1018] [Dynamo] Remove ignored modes from torch function mode stack guard (#135503) Pull Request resolved: https://github.com/pytorch/pytorch/pull/135503 Approved by: https://github.com/anijain2305 ghstack dependencies: #134732, #133137, #135443, #135444, #135422, #135502 --- torch/_C/_dynamo/guards.pyi | 2 +- torch/_dynamo/guards.py | 1 - torch/csrc/dynamo/guards.cpp | 69 +++++------------------------------- 3 files changed, 10 insertions(+), 62 deletions(-) diff --git a/torch/_C/_dynamo/guards.pyi b/torch/_C/_dynamo/guards.pyi index d6aab0d547f727..918d913068e6df 100644 --- a/torch/_C/_dynamo/guards.pyi +++ b/torch/_C/_dynamo/guards.pyi @@ -67,7 +67,7 @@ class GuardManager: ) -> None: ... def add_global_state_guard(self, verbose_code_parts: list[str]) -> None: ... def add_torch_function_mode_stack_guard( - self, initial_stack, ignored_types, verbose_code_parts: list[str] + self, initial_stack, verbose_code_parts: list[str] ) -> None: ... class RootGuardManager(GuardManager): diff --git a/torch/_dynamo/guards.py b/torch/_dynamo/guards.py index 66deeb3cbc0e79..3dcc9a032f208f 100644 --- a/torch/_dynamo/guards.py +++ b/torch/_dynamo/guards.py @@ -2350,7 +2350,6 @@ def compile_check_fn(self, builder, guards_out, guard_fail_fn): self.guard_manager.root.add_torch_function_mode_stack_guard( self.torch_function_mode_stack, - list(), ["___check_torch_function_mode_stack()"], ) # Clear references to torch_function modes held in the list diff --git a/torch/csrc/dynamo/guards.cpp b/torch/csrc/dynamo/guards.cpp index 771de44427b669..1d1cf2fb0c60a6 100644 --- a/torch/csrc/dynamo/guards.cpp +++ b/torch/csrc/dynamo/guards.cpp @@ -2515,90 +2515,40 @@ class TORCH_FUNCTION_MODE_STACK : public LeafGuard { public: TORCH_FUNCTION_MODE_STACK( const py::list& initial_stack, - const py::list& ignored_types, py::object verbose_code_parts) - : LeafGuard(std::move(verbose_code_parts)), - _ref_stack(), - _ignored_types() { + : LeafGuard(std::move(verbose_code_parts)), _ref_stack() { Py_ssize_t len = PyList_Size(initial_stack.ptr()); for (Py_ssize_t idx = 0; idx < len; idx++) { PyObject* mode = PyList_GetItem(initial_stack.ptr(), idx); // borrowed ref auto type = Py_TYPE(mode); this->_ref_stack.push_back(type); } - - len = PyList_Size(ignored_types.ptr()); - for (Py_ssize_t idx = 0; idx < len; idx++) { - PyObject* type_obj = - PyList_GetItem(ignored_types.ptr(), idx); // borrowed ref - if (PyType_Check(type_obj) == 0) { - PyErr_SetString( - PyExc_TypeError, "ignored_types should contain a list of types"); - return; - } - PyTypeObject* type = (PyTypeObject*)type_obj; - this->_ignored_types.insert(type); - } } bool check_nopybind(PyObject* value) override { // Ignore value arg, only used to satisfy the interface - size_t ref_ind = 0; - const int64_t len = at::impl::PythonTorchFunctionTLS::stack_len(); + const size_t len = (size_t)at::impl::PythonTorchFunctionTLS::stack_len(); const size_t ref_stack_size = this->_ref_stack.size(); - int64_t idx = 0; - while ((idx < len) && (ref_ind < ref_stack_size)) { - std::shared_ptr mode = - at::impl::PythonTorchFunctionTLS::get_stack_at(idx); - - PyTypeObject* mode_type = Py_TYPE(mode->ptr(getPyInterpreter())); - bool act_ignored = this->_ignored_types.count(mode_type) > 0; - bool ref_ignored = - this->_ignored_types.count(this->_ref_stack.at(ref_ind)) > 0; - // skip ignored types - if (act_ignored && ref_ignored) { - idx++; - ref_ind++; - continue; - } else if (ref_ignored) { - ref_ind++; - continue; - } else if (act_ignored) { - idx++; - continue; - } - // if we already have more non-ignored modes than the ref stack - // or if the mode doesn't match at the current index, return false - else if (mode_type != _ref_stack.at(ref_ind)) { - return false; - } - ref_ind++; - idx++; - } - - for (; ref_ind < ref_stack_size; ref_ind++) { - if (!(this->_ignored_types.count(this->_ref_stack.at(ref_ind)) > 0)) { - return false; - } + if (len != ref_stack_size) { + return false; } - for (; idx < len; idx++) { + for (int64_t idx = 0; (size_t)idx < len; idx++) { std::shared_ptr mode = at::impl::PythonTorchFunctionTLS::get_stack_at(idx); PyTypeObject* mode_type = Py_TYPE(mode->ptr(getPyInterpreter())); - if (!(this->_ignored_types.count(mode_type) > 0)) { + if (mode_type != _ref_stack.at(idx)) { return false; } } - return ref_ind == ref_stack_size && idx == len; + return true; } private: std::vector _ref_stack; - std::set _ignored_types; }; class TENSOR_MATCH : public LeafGuard { @@ -3763,7 +3713,7 @@ PyObject* torch_c_dynamo_guards_init() { LeafGuard, std::shared_ptr>( py_m, "TORCH_FUNCTION_MODE_STACK") - .def(py::init()) + .def(py::init()) .def("__call__", &TORCH_FUNCTION_MODE_STACK::check); py::class_>( py_m, "DATA_PTR_MATCH") @@ -4000,10 +3950,9 @@ PyObject* torch_c_dynamo_guards_init() { "add_torch_function_mode_stack_guard", [](GuardManager& self, const py::list& initial_stack, - const py::list& ignored_types, py::object verbose_code_parts) -> void { self.add_leaf_guard(std::make_shared( - initial_stack, ignored_types, std::move(verbose_code_parts))); + initial_stack, std::move(verbose_code_parts))); }) .def( "add_data_ptr_guard", From 611cc4220797b74cc7b5297d82cedc61dd720066 Mon Sep 17 00:00:00 2001 From: Ke Wen Date: Thu, 12 Sep 2024 22:47:25 -0700 Subject: [PATCH 0818/1018] [Distributed] add FP8 support to NaN checker (#135891) Adding support for `torch.float8_e4m3fn` and `torch.float8_e5m2` Pull Request resolved: https://github.com/pytorch/pytorch/pull/135891 Approved by: https://github.com/wconstab --- test/distributed/test_c10d_nccl.py | 33 ++++++++++++++++++++++--- torch/csrc/distributed/c10d/NanCheck.cu | 4 ++- 2 files changed, 33 insertions(+), 4 deletions(-) diff --git a/test/distributed/test_c10d_nccl.py b/test/distributed/test_c10d_nccl.py index 5dabf3adc2c552..3b1cd8e56d7e8c 100644 --- a/test/distributed/test_c10d_nccl.py +++ b/test/distributed/test_c10d_nccl.py @@ -237,6 +237,8 @@ def setUp(self): self.test_nan_assert_float32.__wrapped__, self.test_nan_assert_float64.__wrapped__, self.test_nan_assert_bfloat16.__wrapped__, + self.test_nan_assert_float8_e4m3fn.__wrapped__, + self.test_nan_assert_float8_e5m2.__wrapped__, ] # TORCH_NCCL_BLOCKING_WAIT overrides TORCH_NCCL_ASYNC_ERROR_HANDLING hence tests @@ -347,7 +349,17 @@ def test_close_pg(self): not (TEST_MULTIGPU and CUDA_12_AND_ABOVE), "NCCL test requires 2+ GPUs and Device side assert could cause unexpected errors in lower versions of CUDA", ) - @parametrize("type", [torch.float16, torch.float32, torch.float64, torch.bfloat16]) + @parametrize( + "type", + [ + torch.float16, + torch.float32, + torch.float64, + torch.bfloat16, + torch.float8_e4m3fn, + torch.float8_e5m2, + ], + ) @skip_if_rocm def test_nan_assert(self, type): # Expecting a device-side error when NaN is detected @@ -364,12 +376,27 @@ def test_nan_assert(self, type): size = (1024, 1024, 1024) # 1G elements else: size = (1,) # 1 element - nan_tensor = torch.zeros(*size, dtype=type, device=device) + + # Note: currently we cannot fill values into a FP8 tensor, thus we + # create the NaN tensor in float32 type and cast it to FP8 + if type == torch.float8_e4m3fn or type == torch.float8_e5m2: + init_type = torch.float32 + else: + init_type = type + + nan_tensor = torch.zeros(*size, dtype=init_type, device=device) # randomly pick an nan element index = tuple([random.randrange(size[i]) for i in range(len(size))]) nan_tensor[index] = float("nan") + if init_type != type: + # Now cast to the targeted dtype + nan_tensor = nan_tensor.to(type) + + output = torch.empty(self.world_size, *size, dtype=type, device=device) with self.assertRaises(RuntimeError): - pg.allreduce(nan_tensor) + # Note: using all-gather here bc FP8 types do not support reduce ops + # at the moment + pg._allgather_base(output, nan_tensor) dist.destroy_process_group() # reset env os.environ["TORCH_NCCL_NAN_CHECK"] = "0" diff --git a/torch/csrc/distributed/c10d/NanCheck.cu b/torch/csrc/distributed/c10d/NanCheck.cu index 85abd733affc95..0aab4ca809260d 100644 --- a/torch/csrc/distributed/c10d/NanCheck.cu +++ b/torch/csrc/distributed/c10d/NanCheck.cu @@ -165,9 +165,11 @@ void checkForNan(const at::Tensor& tensor, at::cuda::CUDAStream& stream) { maxNumBlocks, (tensor.numel() + numThreadsPerBlock - 1) / numThreadsPerBlock); - AT_DISPATCH_FLOATING_TYPES_AND2( + AT_DISPATCH_FLOATING_TYPES_AND4( at::ScalarType::Half, at::ScalarType::BFloat16, + at::ScalarType::Float8_e4m3fn, + at::ScalarType::Float8_e5m2, tensor.scalar_type(), "checkForNaN", [&] { From f1bee4c9c20d3fd625484a7e3e4ed77d9ddba95e Mon Sep 17 00:00:00 2001 From: CaoE Date: Fri, 13 Sep 2024 09:11:47 +0000 Subject: [PATCH 0819/1018] Update document for autocast on CPU (#135299) Update document for autocast on CPU due to the support of float16 and changes in the operator list. Pull Request resolved: https://github.com/pytorch/pytorch/pull/135299 Approved by: https://github.com/jgong5, https://github.com/leslie-fang-intel, https://github.com/svekars --- docs/source/amp.rst | 39 ++++++++++++++++++++++-------- docs/source/notes/amp_examples.rst | 2 +- 2 files changed, 30 insertions(+), 11 deletions(-) diff --git a/docs/source/amp.rst b/docs/source/amp.rst index 7192a47fb24b60..8698742a9367d0 100644 --- a/docs/source/amp.rst +++ b/docs/source/amp.rst @@ -95,6 +95,11 @@ updates the parameters, so the scale factor does not interfere with the learning .. currentmodule:: torch.cuda.amp +.. autoclass:: GradScaler + :members: + +.. currentmodule:: torch.cpu.amp + .. autoclass:: GradScaler :members: @@ -365,7 +370,7 @@ in which unlisted ops run if they're downstream from autocasted ops. If an op is unlisted, we assume it's numerically stable in ``bfloat16``. If you believe an unlisted op is numerically unstable in ``bfloat16``, -please file an issue. +please file an issue. ``float16`` shares the lists of ``bfloat16``. CPU Ops that can autocast to ``bfloat16`` """"""""""""""""""""""""""""""""""""""""" @@ -375,19 +380,25 @@ CPU Ops that can autocast to ``bfloat16`` ``conv3d``, ``bmm``, ``mm``, +``linalg_vecdot``, ``baddbmm``, ``addmm``, ``addbmm``, ``linear``, ``matmul``, -``_convolution`` +``_convolution``, +``conv_tbc``, +``mkldnn_rnn_layer``, +``conv_transpose1d``, +``conv_transpose2d``, +``conv_transpose3d``, +``prelu``, +``scaled_dot_product_attention``, +``_native_multi_head_attention`` CPU Ops that can autocast to ``float32`` """""""""""""""""""""""""""""""""""""""" -``conv_transpose1d``, -``conv_transpose2d``, -``conv_transpose3d``, ``avg_pool3d``, ``binary_cross_entropy``, ``grid_sampler``, @@ -421,9 +432,22 @@ CPU Ops that can autocast to ``float32`` ``replication_pad2d``, ``replication_pad3d``, ``mse_loss``, +``cosine_embedding_loss``, +``nll_loss``, +``nll_loss2d``, +``hinge_embedding_loss``, +``poisson_nll_loss``, +``cross_entropy_loss``, +``l1_loss``, +``huber_loss``, +``margin_ranking_loss``, +``soft_margin_loss``, +``triplet_margin_loss``, +``multi_margin_loss``, ``ctc_loss``, ``kl_div``, ``multilabel_margin_loss``, +``binary_cross_entropy_with_logits``, ``fft_fft``, ``fft_ifft``, ``fft_fft2``, @@ -438,7 +462,6 @@ CPU Ops that can autocast to ``float32`` ``fft_irfftn``, ``fft_hfft``, ``fft_ihfft``, -``linalg_matrix_norm``, ``linalg_cond``, ``linalg_matrix_rank``, ``linalg_solve``, @@ -451,14 +474,10 @@ CPU Ops that can autocast to ``float32`` ``linalg_tensorinv``, ``linalg_tensorsolve``, ``fake_quantize_per_tensor_affine``, -``eig``, ``geqrf``, -``lstsq``, ``_lu_with_info``, ``qr``, -``solve``, ``svd``, -``symeig``, ``triangular_solve``, ``fractional_max_pool2d``, ``fractional_max_pool3d``, diff --git a/docs/source/notes/amp_examples.rst b/docs/source/notes/amp_examples.rst index f95f99b7ac2fa4..1ae63c4396cc6a 100644 --- a/docs/source/notes/amp_examples.rst +++ b/docs/source/notes/amp_examples.rst @@ -9,7 +9,7 @@ Ordinarily, "automatic mixed precision training" means training with :class:`torch.autocast` and :class:`torch.amp.GradScaler` together. Instances of :class:`torch.autocast` enable autocasting for chosen regions. -Autocasting automatically chooses the precision for GPU operations to improve performance +Autocasting automatically chooses the precision for operations to improve performance while maintaining accuracy. Instances of :class:`torch.amp.GradScaler` help perform the steps of From fc90d7733482f7c948be1a579fcbe3e9a9cbe67c Mon Sep 17 00:00:00 2001 From: Bin Bao Date: Fri, 13 Sep 2024 12:21:57 +0000 Subject: [PATCH 0820/1018] [AOTI][reland] Fix assert_function call in cpu autotune template (#135920) Summary: Reland https://github.com/pytorch/pytorch/pull/135086. In the ABI-compatible mode, assert_function should be AOTI_TORCH_CHECK. Test Plan: CI Differential Revision: D62500592 Pull Request resolved: https://github.com/pytorch/pytorch/pull/135920 Approved by: https://github.com/chenyang78 --- test/inductor/test_aot_inductor.py | 4 +++- test/inductor/test_cpu_cpp_wrapper.py | 13 ++++++------- torch/_inductor/codegen/cpp.py | 4 +--- torch/_inductor/codegen/cpp_micro_gemm.py | 10 +++++----- torch/_inductor/codegen/cpp_prefix.h | 1 + torch/_inductor/codegen/cpp_template.py | 9 ++++----- torch/_inductor/graph.py | 3 +-- 7 files changed, 21 insertions(+), 23 deletions(-) diff --git a/test/inductor/test_aot_inductor.py b/test/inductor/test_aot_inductor.py index 4dee53c4dc41ae..27885796b86766 100644 --- a/test/inductor/test_aot_inductor.py +++ b/test/inductor/test_aot_inductor.py @@ -3270,7 +3270,9 @@ def forward(self, values, offsets): model, example_inputs_list, dynamic_shapes=dynamic_shapes ) - @common_utils.parametrize("max_autotune", [False, True]) + # max_autotune is disabled due to https://github.com/pytorch/pytorch/issues/135106 + # @common_utils.parametrize("max_autotune", [False, True]) + @common_utils.parametrize("max_autotune", [False]) def test_misc_1(self, max_autotune): if self.device == "cpu" and IS_MACOS and max_autotune: raise unittest.SkipTest("max_autotune not supported on macos") diff --git a/test/inductor/test_cpu_cpp_wrapper.py b/test/inductor/test_cpu_cpp_wrapper.py index 5dfe5487799a54..50a21411a28144 100644 --- a/test/inductor/test_cpu_cpp_wrapper.py +++ b/test/inductor/test_cpu_cpp_wrapper.py @@ -87,13 +87,7 @@ class DynamicShapesCppWrapperCpuTests(InductorTestCase): } ) if config.abi_compatible: - xfail_list = [ - *[ - func - for func in dir(test_cpu_select_algorithm.TestSelectAlgorithmCPU()) - if func.startswith("test_linear_with_pointwise") - ], - ] + xfail_list = [] for test_name in xfail_list: test_failures_cpp_wrapper[test_name] = test_torchinductor.TestFailure( ("cpp_wrapper",), is_skip=False @@ -103,6 +97,11 @@ class DynamicShapesCppWrapperCpuTests(InductorTestCase): ] = test_torchinductor.TestFailure(("cpp_wrapper",), is_skip=False) skip_list = [ "test_multihead_attention_cpu", + *[ + func + for func in dir(test_cpu_select_algorithm.TestSelectAlgorithmCPU()) + if func.startswith("test_linear_with_pointwise") + ], ] for test_name in skip_list: test_failures_cpp_wrapper[test_name] = test_torchinductor.TestFailure( diff --git a/torch/_inductor/codegen/cpp.py b/torch/_inductor/codegen/cpp.py index cef37d18ac20db..8bbe8023a312af 100644 --- a/torch/_inductor/codegen/cpp.py +++ b/torch/_inductor/codegen/cpp.py @@ -2146,9 +2146,7 @@ def codegen_loops(self, code, worksharing): @property def assert_function(self) -> str: - if V.graph.aot_mode: - # TODO: Using AOTI_TORCH_CHECK is causing performance drop for some models - # compared with JIT Inductor which uses TORCH_CHECK + if config.abi_compatible: return "AOTI_TORCH_CHECK" else: return "TORCH_CHECK" diff --git a/torch/_inductor/codegen/cpp_micro_gemm.py b/torch/_inductor/codegen/cpp_micro_gemm.py index cb26c48fce53c6..2f0c3e78f56791 100644 --- a/torch/_inductor/codegen/cpp_micro_gemm.py +++ b/torch/_inductor/codegen/cpp_micro_gemm.py @@ -332,8 +332,8 @@ class CppMicroGemmFP32Vec(CppMicroGemm): TEMPLATE_ENTRY = r""" {{declare_kernel}} { - TORCH_CHECK(N % {{block_n}} == 0, "N dimension must be multiple of {{block_n}}"); - TORCH_CHECK(K % {{block_k}} == 0, "K dimension must be multiple of {{block_k}}"); + {{kernel.assert_function}}(N % {{block_n}} == 0, "N dimension must be multiple of {{block_n}}"); + {{kernel.assert_function}}(K % {{block_k}} == 0, "K dimension must be multiple of {{block_k}}"); // TODO(jgong5): loop unroll for M and N for (int64_t m = 0; m < M; m += {{block_m}}) { int64_t block_m = std::min(M - m, {{block_m}}); @@ -364,7 +364,7 @@ class CppMicroGemmFP32Vec(CppMicroGemm): break; {%- endfor %} default: - {{kernel.assert_function}}(false, "Unsupported block_m: ", block_m); + {{kernel.assert_function}}(false, "Unsupported block_m: {{block_m}}"); } } } @@ -509,8 +509,8 @@ class CppMicroGemmAMX(CppMicroGemm): TEMPLATE_ENTRY = r""" {{declare_kernel}} { - TORCH_CHECK(N % {{block_n}} == 0, "N dimension must be multiple of {{block_n}}"); - TORCH_CHECK(K % 2 == 0, "K dimension must be multiple of 2"); + {{kernel.assert_function}}(N % {{block_n}} == 0, "N dimension must be multiple of {{block_n}}"); + {{kernel.assert_function}}(K % 2 == 0, "K dimension must be multiple of 2"); // TODO(jgong5): loop unroll for M and N for (int64_t m = 0; m < M; m += {{block_m}}) { int64_t block_m = std::min(M - m, {{block_m}}); diff --git a/torch/_inductor/codegen/cpp_prefix.h b/torch/_inductor/codegen/cpp_prefix.h index 8972e5ef522d42..fd5c380cd7711e 100644 --- a/torch/_inductor/codegen/cpp_prefix.h +++ b/torch/_inductor/codegen/cpp_prefix.h @@ -26,6 +26,7 @@ #include #include #include +#include #if defined(CPU_CAPABILITY_AVX512) || defined(CPU_CAPABILITY_AVX2) || defined(CPU_CAPABILITY_ZVECTOR) || defined(CPU_CAPABILITY_NEON) || defined(CPU_CAPABILITY_VSX) #define INDUCTOR_USE_VECTOR_TYPES() 1 diff --git a/torch/_inductor/codegen/cpp_template.py b/torch/_inductor/codegen/cpp_template.py index 2bce16b2ad1538..a237924b9182d4 100644 --- a/torch/_inductor/codegen/cpp_template.py +++ b/torch/_inductor/codegen/cpp_template.py @@ -111,11 +111,10 @@ def make_kernel_render( def header(self) -> IndentedBuffer: res = IndentedBuffer() res.writeline(codecache.cpp_prefix()) - res.splice( - """ - #include "c10/util/Unroll.h" - """ - ) + # TODO: add c10::ForcedUnroll test to test_aoti_abi_check + res.splice("""#include """) + if config.abi_compatible: + res.splice("""#include """) enable_kernel_profile = config.cpp.enable_kernel_profile and sys.platform in [ "linux", "win32", diff --git a/torch/_inductor/graph.py b/torch/_inductor/graph.py index 9fdcf4a5e98da7..989977897f9d92 100644 --- a/torch/_inductor/graph.py +++ b/torch/_inductor/graph.py @@ -854,7 +854,6 @@ def get_original_value_of_constant(self, name: str) -> torch.Tensor: def allocate_non_dup_const_name( self, name: Optional[str], data: Union[Tensor] ) -> str: - orig_name = name if not config.aot_inductor.use_runtime_constant_folding: for constant_name, value in self.constants.items(): if ( @@ -871,7 +870,7 @@ def allocate_non_dup_const_name( if name is None: name = f"constant{len(self.constants)}" - assert name is not None + orig_name = name if name[0].isdigit(): name = f"constant_{name}" name = self.qualify_name(name) From 46097f0dda7cf736d5fbfcf2f11ce05cbbfef67a Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Fri, 13 Sep 2024 12:29:03 +0000 Subject: [PATCH 0821/1018] Revert "[dynamo] Fix support for classmethod(property(...)) (#134968)" This reverts commit bf68e16e94fc05f10d434cdc162a14d02c6ad23c. Reverted https://github.com/pytorch/pytorch/pull/134968 on behalf of https://github.com/jithunnair-amd due to Broke ROCm CI: eg. https://github.com/pytorch/pytorch/actions/runs/10845542664/job/30097956613 ([comment](https://github.com/pytorch/pytorch/pull/134968#issuecomment-2348837553)) --- test/dynamo/test_repros.py | 77 +------------------------ torch/_dynamo/variables/constant.py | 2 - torch/_dynamo/variables/user_defined.py | 4 -- 3 files changed, 1 insertion(+), 82 deletions(-) diff --git a/test/dynamo/test_repros.py b/test/dynamo/test_repros.py index 43b2278cc5c085..e1e1dd1837bef4 100644 --- a/test/dynamo/test_repros.py +++ b/test/dynamo/test_repros.py @@ -20,7 +20,7 @@ from abc import ABC from collections import namedtuple from copy import deepcopy -from enum import Enum, IntEnum +from enum import Enum from functools import wraps from typing import Any, Dict, Iterator, List, Tuple from unittest import mock @@ -4546,81 +4546,6 @@ def f(*args): f(*args) self.assertEqual(num_compiles, 1) - def test_issue134451(self): - class BoundingBox2DIndex(IntEnum): - _X = 0 - _Y = 1 - _HEADING = 2 - _LENGTH = 3 - _WIDTH = 4 - - @classmethod - def size(cls): - return 5 - - @classmethod - @property - def X(cls): - return cls._X - - @classmethod - @property - def Y(cls): - return cls._Y - - @classmethod - @property - def HEADING(cls): - return cls._HEADING - - @classmethod - @property - def LENGTH(cls): - return cls._LENGTH - - @classmethod - @property - def WIDTH(cls): - return cls._WIDTH - - @classmethod - @property - def POINT(cls): - # assumes X, Y have subsequent indices - return slice(cls._X, cls._Y + 1) - - @classmethod - @property - def STATE_SE2(cls): - # assumes X, Y, HEADING have subsequent indices - return slice(cls._X, cls._HEADING + 1) - - class SimpleModel(nn.Module): - def __init__(self): - super().__init__() - self._mlp_states = nn.Sequential( - nn.Linear(10, 20), - nn.ReLU(), - nn.Linear(20, BoundingBox2DIndex.size()), - ) - - def forward(self, x): - agent_states = self._mlp_states(x) - agent_states[..., BoundingBox2DIndex.POINT] = ( - agent_states[..., BoundingBox2DIndex.POINT].tanh() * 32 - ) - agent_states[..., BoundingBox2DIndex.HEADING] = ( - agent_states[..., BoundingBox2DIndex.HEADING].tanh() * torch.pi - ) - return agent_states - - model = SimpleModel().eval() - input_tensor = torch.randn(1, 10, dtype=torch.float32) - opt = torch.compile(model.eval(), backend="eager", fullgraph=True) - actual = opt(input_tensor) - expected = model(input_tensor) - self.assertEqual(actual, expected) - def test_invalid_seq_unpack(self): def myfn(arg): (a, b) = arg diff --git a/torch/_dynamo/variables/constant.py b/torch/_dynamo/variables/constant.py index de357cf8094f3a..65de0aab6ef3c7 100644 --- a/torch/_dynamo/variables/constant.py +++ b/torch/_dynamo/variables/constant.py @@ -222,8 +222,6 @@ def create(cls, cls_type, value_vt, options): unimplemented("Enum variable is constructed with non constant values") def as_proxy(self): - if isinstance(self.value, int): - return int(self.value) # convert IntEnum to a normal int return self.value def __str__(self) -> str: diff --git a/torch/_dynamo/variables/user_defined.py b/torch/_dynamo/variables/user_defined.py index 32057c838d74c8..14499b4d2e4bf7 100644 --- a/torch/_dynamo/variables/user_defined.py +++ b/torch/_dynamo/variables/user_defined.py @@ -193,10 +193,6 @@ def var_getattr(self, tx: "InstructionTranslator", name: str) -> "VariableTracke else: return SourcelessBuilder.create(tx, func) elif isinstance(obj, classmethod): - if isinstance(obj.__func__, property): - return variables.UserFunctionVariable(obj.__func__.fget).call_function( - tx, [self], {} - ) return variables.UserMethodVariable(obj.__func__, self, source=source) elif isinstance(obj, types.ClassMethodDescriptorType): # e.g.: inspect.getattr_static(dict, "fromkeys") From e2ad410b5d1a3eb7d1ea2448f5608d882a79f24c Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Fri, 13 Sep 2024 12:35:05 +0000 Subject: [PATCH 0822/1018] Revert "[PT2][inductor][Optimus] Add pad_aten_mm_pass pattern to resolve long computation kernel in LCE (#135167)" This reverts commit eb0fe029337b31bcb3d4b2d1e539895393975d68. Reverted https://github.com/pytorch/pytorch/pull/135167 on behalf of https://github.com/jithunnair-amd due to Broke ROCm CI eg. https://github.com/pytorch/pytorch/actions/runs/10845542664/job/30097957154 ([comment](https://github.com/pytorch/pytorch/pull/135167#issuecomment-2348847595)) --- test/inductor/test_pad_mm.py | 34 -------------------------- torch/_inductor/fx_passes/pad_mm.py | 23 ----------------- torch/_inductor/fx_passes/split_cat.py | 1 - 3 files changed, 58 deletions(-) diff --git a/test/inductor/test_pad_mm.py b/test/inductor/test_pad_mm.py index 957ec59916de6c..dbe1170994fdbe 100644 --- a/test/inductor/test_pad_mm.py +++ b/test/inductor/test_pad_mm.py @@ -9,7 +9,6 @@ get_pad_cache, get_padded_length, should_pad_common, - should_pad_mm_bf16, ) from torch._inductor.test_case import run_tests, TestCase from torch._inductor.utils import fresh_inductor_cache, is_big_gpu, run_and_get_code @@ -450,39 +449,6 @@ def mm(inps, b): repr(get_pad_cache().get_local_cache()) ) - @fresh_inductor_cache() - @inductor_config.patch( - post_grad_fusion_options={"pad_aten_mm_pass": {"k_threshold_to_pad": 8388608}} - ) - def test_pad_mm_bf16(self): - m = 2 - n = 13 - k = 15691904 - mat1 = torch.ones((m, k), device="cuda", dtype=torch.bfloat16) - mat2 = torch.ones((k, n), device="cuda", dtype=torch.bfloat16) - expected_alignment = get_alignment_size(mat1) - - assert expected_alignment == 8, "Alignment for bfloat16 should be 8" - assert should_pad_common( - mat1, mat2 - ), "This should pass the common padding criteria" - assert should_pad_mm_bf16( - mat1.dtype, m, n, k - ), "This should pass the should_pad_mm_bf16 padding criteria" - - @torch.compile() - def mm(mat1, mat2): - return torch.mm(mat1, mat2) - - res2, (code,) = run_and_get_code(mm, mat1, mat2) - mm_expected_result = torch.mm(mat1, mat2) - # in call code, expect to see a single pad per input, and then we should see padded allocation for output - FileCheck().check("del async_compile").check_count( - ".run(", 2, exactly=True - ).check("empty_strided_cuda((8, 16)").run(code) - - assert torch.allclose(res2, mm_expected_result), "MM results are not identical" - if __name__ == "__main__": if HAS_CUDA: diff --git a/torch/_inductor/fx_passes/pad_mm.py b/torch/_inductor/fx_passes/pad_mm.py index a00bad39747918..87450b34e7ee27 100644 --- a/torch/_inductor/fx_passes/pad_mm.py +++ b/torch/_inductor/fx_passes/pad_mm.py @@ -364,23 +364,6 @@ def should_pad(key: str, ori_time, pad_time) -> bool: return should_pad -def should_pad_mm_bf16(dtype, M, N, K): - # always force pad for mm with bf16 when the following are satisfied to avoid perf regression - large_k_threshold_to_pad = torch._inductor.config.post_grad_fusion_options[ - "pad_aten_mm_pass" - ].get("k_threshold_to_pad", 8388608) - if ( - dtype is torch.bfloat16 - and K > M - and K > N - and N % 2 == 1 - and K >= large_k_threshold_to_pad - and torch.cuda.get_device_capability() < (9, 0) - ): # doesnt repro on h100s: - return True - return False - - def should_pad_bench( match, mat1: Tensor, mat2: Tensor, op, input: Optional[Tensor] = None ) -> bool: @@ -427,12 +410,6 @@ def realize_symbols(ds): if torch._inductor.config.force_shape_pad: return True - if ( - "pad_aten_mm_pass" in torch._inductor.config.post_grad_fusion_options - and should_pad_mm_bf16(mat1.dtype, m, n, k) - ): - return True - if not has_triton(): return False diff --git a/torch/_inductor/fx_passes/split_cat.py b/torch/_inductor/fx_passes/split_cat.py index b3cef2ea079814..d0098a535ee295 100644 --- a/torch/_inductor/fx_passes/split_cat.py +++ b/torch/_inductor/fx_passes/split_cat.py @@ -65,7 +65,6 @@ "decompose_mm_pass", "unbind_stack_aten_pass", "shape_padding_multiplier", - "pad_aten_mm_pass", ] for pass_name in pre_grad_pass_names: From f4df4fbd64cd139694c41bb9f9cee4a696161ef1 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Fri, 13 Sep 2024 12:52:57 +0000 Subject: [PATCH 0823/1018] Revert "[Dynamo] Remove ignored modes from torch function mode stack guard (#135503)" This reverts commit c56728b643e2b7d796abd7ec45803319e1c5967d. Reverted https://github.com/pytorch/pytorch/pull/135503 on behalf of https://github.com/albanD due to Broke tests on main ([comment](https://github.com/pytorch/pytorch/pull/134732#issuecomment-2348886378)) --- torch/_C/_dynamo/guards.pyi | 2 +- torch/_dynamo/guards.py | 1 + torch/csrc/dynamo/guards.cpp | 69 +++++++++++++++++++++++++++++++----- 3 files changed, 62 insertions(+), 10 deletions(-) diff --git a/torch/_C/_dynamo/guards.pyi b/torch/_C/_dynamo/guards.pyi index 918d913068e6df..d6aab0d547f727 100644 --- a/torch/_C/_dynamo/guards.pyi +++ b/torch/_C/_dynamo/guards.pyi @@ -67,7 +67,7 @@ class GuardManager: ) -> None: ... def add_global_state_guard(self, verbose_code_parts: list[str]) -> None: ... def add_torch_function_mode_stack_guard( - self, initial_stack, verbose_code_parts: list[str] + self, initial_stack, ignored_types, verbose_code_parts: list[str] ) -> None: ... class RootGuardManager(GuardManager): diff --git a/torch/_dynamo/guards.py b/torch/_dynamo/guards.py index 3dcc9a032f208f..66deeb3cbc0e79 100644 --- a/torch/_dynamo/guards.py +++ b/torch/_dynamo/guards.py @@ -2350,6 +2350,7 @@ def compile_check_fn(self, builder, guards_out, guard_fail_fn): self.guard_manager.root.add_torch_function_mode_stack_guard( self.torch_function_mode_stack, + list(), ["___check_torch_function_mode_stack()"], ) # Clear references to torch_function modes held in the list diff --git a/torch/csrc/dynamo/guards.cpp b/torch/csrc/dynamo/guards.cpp index 1d1cf2fb0c60a6..771de44427b669 100644 --- a/torch/csrc/dynamo/guards.cpp +++ b/torch/csrc/dynamo/guards.cpp @@ -2515,40 +2515,90 @@ class TORCH_FUNCTION_MODE_STACK : public LeafGuard { public: TORCH_FUNCTION_MODE_STACK( const py::list& initial_stack, + const py::list& ignored_types, py::object verbose_code_parts) - : LeafGuard(std::move(verbose_code_parts)), _ref_stack() { + : LeafGuard(std::move(verbose_code_parts)), + _ref_stack(), + _ignored_types() { Py_ssize_t len = PyList_Size(initial_stack.ptr()); for (Py_ssize_t idx = 0; idx < len; idx++) { PyObject* mode = PyList_GetItem(initial_stack.ptr(), idx); // borrowed ref auto type = Py_TYPE(mode); this->_ref_stack.push_back(type); } + + len = PyList_Size(ignored_types.ptr()); + for (Py_ssize_t idx = 0; idx < len; idx++) { + PyObject* type_obj = + PyList_GetItem(ignored_types.ptr(), idx); // borrowed ref + if (PyType_Check(type_obj) == 0) { + PyErr_SetString( + PyExc_TypeError, "ignored_types should contain a list of types"); + return; + } + PyTypeObject* type = (PyTypeObject*)type_obj; + this->_ignored_types.insert(type); + } } bool check_nopybind(PyObject* value) override { // Ignore value arg, only used to satisfy the interface - const size_t len = (size_t)at::impl::PythonTorchFunctionTLS::stack_len(); + size_t ref_ind = 0; + const int64_t len = at::impl::PythonTorchFunctionTLS::stack_len(); const size_t ref_stack_size = this->_ref_stack.size(); - if (len != ref_stack_size) { - return false; + int64_t idx = 0; + while ((idx < len) && (ref_ind < ref_stack_size)) { + std::shared_ptr mode = + at::impl::PythonTorchFunctionTLS::get_stack_at(idx); + + PyTypeObject* mode_type = Py_TYPE(mode->ptr(getPyInterpreter())); + bool act_ignored = this->_ignored_types.count(mode_type) > 0; + bool ref_ignored = + this->_ignored_types.count(this->_ref_stack.at(ref_ind)) > 0; + // skip ignored types + if (act_ignored && ref_ignored) { + idx++; + ref_ind++; + continue; + } else if (ref_ignored) { + ref_ind++; + continue; + } else if (act_ignored) { + idx++; + continue; + } + // if we already have more non-ignored modes than the ref stack + // or if the mode doesn't match at the current index, return false + else if (mode_type != _ref_stack.at(ref_ind)) { + return false; + } + ref_ind++; + idx++; } - for (int64_t idx = 0; (size_t)idx < len; idx++) { + for (; ref_ind < ref_stack_size; ref_ind++) { + if (!(this->_ignored_types.count(this->_ref_stack.at(ref_ind)) > 0)) { + return false; + } + } + + for (; idx < len; idx++) { std::shared_ptr mode = at::impl::PythonTorchFunctionTLS::get_stack_at(idx); PyTypeObject* mode_type = Py_TYPE(mode->ptr(getPyInterpreter())); - if (mode_type != _ref_stack.at(idx)) { + if (!(this->_ignored_types.count(mode_type) > 0)) { return false; } } - return true; + return ref_ind == ref_stack_size && idx == len; } private: std::vector _ref_stack; + std::set _ignored_types; }; class TENSOR_MATCH : public LeafGuard { @@ -3713,7 +3763,7 @@ PyObject* torch_c_dynamo_guards_init() { LeafGuard, std::shared_ptr>( py_m, "TORCH_FUNCTION_MODE_STACK") - .def(py::init()) + .def(py::init()) .def("__call__", &TORCH_FUNCTION_MODE_STACK::check); py::class_>( py_m, "DATA_PTR_MATCH") @@ -3950,9 +4000,10 @@ PyObject* torch_c_dynamo_guards_init() { "add_torch_function_mode_stack_guard", [](GuardManager& self, const py::list& initial_stack, + const py::list& ignored_types, py::object verbose_code_parts) -> void { self.add_leaf_guard(std::make_shared( - initial_stack, std::move(verbose_code_parts))); + initial_stack, ignored_types, std::move(verbose_code_parts))); }) .def( "add_data_ptr_guard", From 1d2bdcee0fdcdf8925e5501b8401c96c01b88bb9 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Fri, 13 Sep 2024 12:52:57 +0000 Subject: [PATCH 0824/1018] Revert "[Dynamo] Remove ignored modes workaround (#135502)" This reverts commit 7d5e0dd4b1a8d20fc8624b3085a6f5ddedd89a2e. Reverted https://github.com/pytorch/pytorch/pull/135502 on behalf of https://github.com/albanD due to Broke tests on main ([comment](https://github.com/pytorch/pytorch/pull/134732#issuecomment-2348886378)) --- test/dynamo/test_modes.py | 65 +++++++++++++++++++++++ torch/_dynamo/guards.py | 12 +++-- torch/_dynamo/utils.py | 10 +++- torch/_dynamo/variables/torch_function.py | 6 +++ 4 files changed, 88 insertions(+), 5 deletions(-) diff --git a/test/dynamo/test_modes.py b/test/dynamo/test_modes.py index 4d1f2bbea389ef..63e06a515ed186 100644 --- a/test/dynamo/test_modes.py +++ b/test/dynamo/test_modes.py @@ -1,4 +1,5 @@ # Owner(s): ["module: dynamo"] +from unittest.mock import patch import torch import torch._dynamo.test_case @@ -106,6 +107,70 @@ def fn(x): fn(inp) self.assertEqual(cnt.frame_count, 4) + def _run_ignored_mode_types_test(self): + class IgnoredMode(BaseTorchFunctionMode): + pass + + cnt = torch._dynamo.testing.CompileCounter() + + @torch.compile(backend=cnt.__call__, fullgraph=True) + def fn(x): + return x + 1 + + inp = torch.ones(2, 2) + + with patch( + "torch._dynamo.variables.torch_function.IGNORED_MODES", {IgnoredMode} + ): + # initial compile + fn(inp) + + # no recompile, mode ignored + # note: the ref stack is length 0, and the stack we are checking against has length 2 + # we want to check both ref stack len > runtime stack, and ref stack len < runtime stack + with IgnoredMode(), IgnoredMode(): + fn(inp) + + self.assertEqual(cnt.frame_count, 1) + + # recompile due to new mode on the stack + with BaseTorchFunctionMode(), BaseTorchFunctionMode(), BaseTorchFunctionMode(): + fn(inp) + + self.assertEqual(cnt.frame_count, 2) + + # recompile + # tests both ref stack len > runtime stack len for the above guard check + # and ref stack len < runtime stack len for the initial zero mode case + with BaseTorchFunctionMode(), IgnoredMode(), BaseTorchFunctionMode(): + fn(inp) + + self.assertEqual(cnt.frame_count, 3) + + # no recompile + with IgnoredMode(), IgnoredMode(), BaseTorchFunctionMode(), BaseTorchFunctionMode(): + fn(inp) + + self.assertEqual(cnt.frame_count, 3) + + # This is tricky, basically the ignored modes are baked into the guard + # IgnoredMode will be ignored forever by that guard. + # This is okay since we don't expect to be modifying IGNORED_MODES + # in the middle of execution except for the purposes of testing. + torch._dynamo.reset() + + with IgnoredMode(): + fn(inp) + + self.assertEqual(cnt.frame_count, 4) + + @torch._dynamo.config.patch("enable_cpp_guard_manager", False) + def test_torch_function_mode_guards_ignored_types_py(self): + self._run_ignored_mode_types_test() + + def test_torch_function_mode_guards_ignored_types_cpp(self): + self._run_ignored_mode_types_test() + @torch._dynamo.config.patch("enable_cpp_guard_manager", False) def test_torch_function_mode_guards_py(self): self._run_torch_function_mode_guard_test() diff --git a/torch/_dynamo/guards.py b/torch/_dynamo/guards.py index 66deeb3cbc0e79..1544bdd6dc697b 100644 --- a/torch/_dynamo/guards.py +++ b/torch/_dynamo/guards.py @@ -2344,13 +2344,15 @@ def compile_check_fn(self, builder, guards_out, guard_fail_fn): ) if config.enable_cpp_guard_manager: + from .variables.torch_function import IGNORED_MODES + # Insert the global_state guard assert self.guard_manager # to make mypy happy self.guard_manager.root.add_global_state_guard(["___check_global_state()"]) self.guard_manager.root.add_torch_function_mode_stack_guard( self.torch_function_mode_stack, - list(), + list(IGNORED_MODES), ["___check_torch_function_mode_stack()"], ) # Clear references to torch_function modes held in the list @@ -2657,14 +2659,18 @@ def is_recompiles_verbose_enabled(): # this will only be used if cpp guards are disabled def make_torch_function_mode_stack_guard(intial_stack): types = [type(x) for x in intial_stack] + from .variables.torch_function import IGNORED_MODES def check_torch_function_mode_stack(): cur_stack = get_torch_function_mode_stack() - if len(cur_stack) != len(types): + types_ = [ty for ty in types if ty not in IGNORED_MODES] + cur_stack_ = [mode for mode in cur_stack if type(mode) not in IGNORED_MODES] + + if len(cur_stack_) != len(types_): return False - for ty, mode in zip(types, cur_stack): + for ty, mode in zip(types_, cur_stack_): if ty != type(mode): return False diff --git a/torch/_dynamo/utils.py b/torch/_dynamo/utils.py index 6f3a74eb2625d2..ef63ddbea85619 100644 --- a/torch/_dynamo/utils.py +++ b/torch/_dynamo/utils.py @@ -3061,10 +3061,16 @@ def is_parameter_freezing(): return torch._inductor.config.freezing and not torch.is_grad_enabled() -def get_torch_function_mode_stack(): - return [ +def get_torch_function_mode_stack(filter_ignored=True): + from .variables.torch_function import IGNORED_MODES + + stack = [ get_torch_function_mode_stack_at(i) for i in range(_len_torch_function_stack()) ] + if filter_ignored: + stack = [mode for mode in stack if type(mode) not in IGNORED_MODES] + + return stack def get_torch_function_mode_stack_at(ind): diff --git a/torch/_dynamo/variables/torch_function.py b/torch/_dynamo/variables/torch_function.py index ffb3d27d4d703b..b7fd83a3d24f15 100644 --- a/torch/_dynamo/variables/torch_function.py +++ b/torch/_dynamo/variables/torch_function.py @@ -69,6 +69,12 @@ if is_tensor_base_attr_getter(fn) ] +# Today set default device is placed in the graph and guarded on separately +# so we should not trace through it. In the future we can trace it once +# mode tracing is implemented and not put in the graph, but this is more +# of a BE project and can be evaluated later +IGNORED_MODES = {DeviceContext} + @functools.lru_cache(None) def get_prev_stack_var_name(): From e6f44acc9c38b8e8a0f0373004469f8dba97c234 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Fri, 13 Sep 2024 12:52:57 +0000 Subject: [PATCH 0825/1018] Revert "[Dynamo] Trace enter/exit of TorchFunctionModes (#135422)" This reverts commit 2af3b8ffd84e36b91279174e9106f84b2d2a11f2. Reverted https://github.com/pytorch/pytorch/pull/135422 on behalf of https://github.com/albanD due to Broke tests on main ([comment](https://github.com/pytorch/pytorch/pull/134732#issuecomment-2348886378)) --- test/dynamo/test_modes.py | 88 ------------------ torch/_dynamo/convert_frame.py | 6 +- torch/_dynamo/output_graph.py | 6 +- torch/_dynamo/polyfills/__init__.py | 20 ---- torch/_dynamo/resume_execution.py | 104 --------------------- torch/_dynamo/side_effects.py | 11 --- torch/_dynamo/symbolic_convert.py | 30 ++---- torch/_dynamo/testing.py | 1 - torch/_dynamo/trace_rules.py | 2 +- torch/_dynamo/variables/builder.py | 20 ++-- torch/_dynamo/variables/ctx_manager.py | 12 --- torch/_dynamo/variables/torch.py | 14 +-- torch/_dynamo/variables/torch_function.py | 108 ++-------------------- torch/_dynamo/variables/user_defined.py | 14 +-- 14 files changed, 34 insertions(+), 402 deletions(-) diff --git a/test/dynamo/test_modes.py b/test/dynamo/test_modes.py index 63e06a515ed186..fa4c23fd320ddb 100644 --- a/test/dynamo/test_modes.py +++ b/test/dynamo/test_modes.py @@ -461,94 +461,6 @@ def fn(x): self.assertEqual(expected, actual) - def test_torch_function_mode_enter_exit(self): - def fn(x, y): - with TestMode(): - o = torch.add(x, 3) - - return torch.add(o, y) - - inp = (torch.ones(2, 2) + 1, torch.ones(2, 2) + 2) - fn_opt = torch.compile(fn, fullgraph=True) - - expected = fn(*inp) - actual = fn_opt(*inp) - - self.assertEqual(expected, actual) - - def test_torch_function_mode_graph_break(self): - def fn(x, y): - with TestMode(): - torch._dynamo.graph_break() - o = torch.add(x, 3) - - return torch.add(o, y) - - inp = (torch.ones(2, 2) + 1, torch.ones(2, 2) + 2) - fn_opt = torch.compile(fn) - - expected = fn(*inp) - actual = fn_opt(*inp) - - self.assertEqual(expected, actual) - - def test_torch_function_mode_and_pop_graph_break(self): - def fn(x, y): - with TestMode(): - z = _pop_torch_function_stack() - torch._dynamo.graph_break() - _push_on_torch_function_stack(z) - o = torch.add(x, 3) - - return torch.add(o, y) - - inp = (torch.ones(2, 2) + 1, torch.ones(2, 2) + 2) - fn_opt = torch.compile(fn) - - expected = fn(*inp) - actual = fn_opt(*inp) - - self.assertEqual(expected, actual) - - def test_torch_function_mode_restore_on_exc(self): - @torch._dynamo.disable() - def err(): - raise RuntimeError("test") - - @torch.compile() - def fn(x): - with TestMode(): - x += 1 - err() - x += 2 - return x - - try: - fn(torch.ones(2, 2)) - except RuntimeError: - pass - self.assertEqual(_len_torch_function_stack(), 0) - - def test_torch_function_mode_and_pop_graph_break_mutation(self): - def fn(x, y): - with TestMode(): - z = _pop_torch_function_stack() - z.y = 5 - torch._dynamo.graph_break() - _push_on_torch_function_stack(z) - o = torch.add(x, 3) - o = torch.mul(o, z.y) - - return torch.add(o, y) - - inp = (torch.ones(2, 2) + 1, torch.ones(2, 2) + 2) - fn_opt = torch.compile(fn) - - expected = fn(*inp) - actual = fn_opt(*inp) - - self.assertEqual(expected, actual) - if __name__ == "__main__": from torch._dynamo.test_case import run_tests diff --git a/torch/_dynamo/convert_frame.py b/torch/_dynamo/convert_frame.py index 5f543545ecdba3..4a166e1ff80251 100644 --- a/torch/_dynamo/convert_frame.py +++ b/torch/_dynamo/convert_frame.py @@ -112,7 +112,6 @@ troubleshooting_url, write_record_to_file, ) -from .variables.torch_function import torch_function_mode_stack_state_mgr np: Optional[ModuleType] @@ -211,18 +210,15 @@ def _fn(*args: _P.args, **kwargs: _P.kwargs) -> _T: prior_fwd_from_src = torch.fx.graph_module._forward_from_src torch.fx.graph_module._forward_from_src = fx_forward_from_src_skip_result cleanup = setup_compile_debug() + exit_stack = contextlib.ExitStack() exit_stack.enter_context( torch.fx._symbolic_trace._maybe_revert_all_patches() ) - exit_stack.enter_context(torch_function_mode_stack_state_mgr) try: return fn(*args, **kwargs) finally: cleanup.close() - assert ( - torch._C._len_torch_function_stack() == 0 - ), "Torch function mode stack state changed while dynamo tracing, please report a bug" exit_stack.close() torch._C._set_grad_enabled(prior_grad_mode) torch.autograd.grad_mode._enter_inference_mode(prior_inference_mode) diff --git a/torch/_dynamo/output_graph.py b/torch/_dynamo/output_graph.py index 0ea19e1cad6586..ba8cfa5a03dfba 100644 --- a/torch/_dynamo/output_graph.py +++ b/torch/_dynamo/output_graph.py @@ -78,6 +78,7 @@ get_instruction_source_311, get_locals_to_steal, get_static_address_type, + get_torch_function_mode_stack, graph_break_reasons, increment_op_count, lazy_format_graph_code, @@ -249,7 +250,6 @@ def __init__( local_scope: Scope, global_scope: Scope, f_code, - torch_function_mode_stack, ): super().__init__() self.tracers = [SubgraphTracer(self, export_root=export)] @@ -368,7 +368,7 @@ def __init__( # This returns false if TF Overall (both mode and subclass) is disabled OR that TF Mode stack is empty self.torch_function_mode_enabled = torch._C._is_torch_function_mode_enabled() # This records the initial torch function mode stack for guarding - self.torch_function_mode_stack = torch_function_mode_stack + self.torch_function_mode_stack = get_torch_function_mode_stack() # Tracks if the output graph has a user defined allowed function in the # graph. This is used later to determine if we should fallback to eager @@ -1020,7 +1020,7 @@ def append_prefix_insts(): prefix_insts.clear() for block in reversed(tx.block_stack): - block.exit(tx, is_graph_break=reason.graph_break) + block.exit(tx) self.cleanup_graph() tx.prune_dead_locals() diff --git a/torch/_dynamo/polyfills/__init__.py b/torch/_dynamo/polyfills/__init__.py index 5b2812bc08c9e5..2b3f38920a0c2a 100644 --- a/torch/_dynamo/polyfills/__init__.py +++ b/torch/_dynamo/polyfills/__init__.py @@ -25,26 +25,6 @@ sys as sys, ) -from torch.overrides import BaseTorchFunctionMode - - -# These classes handle support for TorchFunctionModes across -# graph breaks -# Today the TorchFunctionMode enter (for the classes we support) -# simply pushes the mode onto the stack. Since after this occurs -# the stack is mutated, and we replay these mutations, we don't need -# any cleanup logic to be run once the graph break occurs, we simply replay -# these mutations to ensure at the graph break the torch function mode stack is correct -# and reconstruct the torch function mode stack normally -# when we compile the resume function on the other side of the break. -# However, to ensure we exit properly -# in the resume function, we need to re-enter the contexts as we do other contexts. -# These contexts do nothing on enter, but provide the correct exit logic to ensure -# the stack state is correct. -class NoEnterTorchFunctionMode(BaseTorchFunctionMode): - def __enter__(self): - pass - def index(iterator, item, start=0, end=None): from itertools import islice diff --git a/torch/_dynamo/resume_execution.py b/torch/_dynamo/resume_execution.py index 6f7db66514c487..132e9e4081bceb 100644 --- a/torch/_dynamo/resume_execution.py +++ b/torch/_dynamo/resume_execution.py @@ -6,7 +6,6 @@ from typing import Any, cast, Dict, List, Optional, Tuple from .bytecode_transformation import ( - add_push_null, create_call_function, create_call_method, create_dup_top, @@ -49,109 +48,6 @@ class ReenterWith: stack_index: int target_values: Optional[Tuple[Any, ...]] = None - def try_except_torch_function_mode(self, code_options, cleanup: List[Instruction]): - """ - Codegen based off of: - try: - (rest) - except: - (restore previous stack) - - """ - from .variables.torch_function import get_prev_stack_var_name - - except_jump_target = create_instruction( - "NOP" if sys.version_info < (3, 11) else "PUSH_EXC_INFO" - ) - cleanup_complete_jump_target = create_instruction("NOP") - - setup_finally: List[Instruction] = [] - - if sys.version_info < (3, 11): - setup_finally.append( - create_instruction("SETUP_FINALLY", target=except_jump_target) - ) - else: - exn_tab_begin = create_instruction("NOP") - exn_tab_end = create_instruction("NOP") - exn_tab_begin.exn_tab_entry = InstructionExnTabEntry( - exn_tab_begin, - exn_tab_end, - except_jump_target, - self.stack_index + 1, - False, - ) - setup_finally.append(exn_tab_begin) - - def create_reset(): - insts = [ - create_instruction( - "LOAD_GLOBAL", argval="__import_torch_dot__dynamo_dot_utils" - ), - create_instruction("LOAD_ATTR", argval="set_torch_function_mode_stack"), - ] - add_push_null(insts) - return [ - *insts, - create_instruction("LOAD_FAST", argval=get_prev_stack_var_name()), - *create_call_function(1, False), - create_instruction("POP_TOP"), - ] - - if sys.version_info < (3, 9): - epilogue = [ - create_instruction("POP_BLOCK"), - create_instruction("JUMP_FORWARD", target=cleanup_complete_jump_target), - except_jump_target, - *create_reset(), - create_instruction("POP_TOP"), - create_instruction("POP_TOP"), - create_instruction("POP_TOP"), - *create_reset(), - create_instruction("RAISE_VARARGS", argval=0), - create_instruction("POP_EXCEPT", argval=0), - create_instruction("END_FINALLY"), - cleanup_complete_jump_target, - ] - elif sys.version_info < (3, 11): - epilogue = [ - create_instruction("POP_BLOCK"), - create_instruction("JUMP_FORWARD", target=cleanup_complete_jump_target), - except_jump_target, - create_instruction("POP_TOP"), - create_instruction("POP_TOP"), - create_instruction("POP_TOP"), - *create_reset(), - create_instruction("RAISE_VARARGS", argval=0), - create_instruction("POP_EXCEPT", argval=0), - cleanup_complete_jump_target, - ] - else: - finally_exn_tab_end = create_instruction("RAISE_VARARGS", argval=0) - finally_exn_tab_target = create_instruction("COPY", arg=3) - except_jump_target.exn_tab_entry = InstructionExnTabEntry( - except_jump_target, - finally_exn_tab_end, - finally_exn_tab_target, - self.stack_index + 2, - True, - ) - epilogue = [ - exn_tab_end, - create_instruction("JUMP_FORWARD", target=cleanup_complete_jump_target), - except_jump_target, # PUSH_EXC_INFO - create_instruction("POP_TOP"), - *create_reset(), - finally_exn_tab_end, - finally_exn_tab_target, # COPY 3 - create_instruction("POP_EXCEPT"), - create_instruction("RERAISE", arg=1), # RERAISE 1 - cleanup_complete_jump_target, - ] - - cleanup[:] = epilogue + cleanup - return setup_finally - # If we do not want to destroy the stack, we can do the same thing as a # `SETUP_WITH` block, only that we store the context manager in a local_symbol def try_except(self, code_options, cleanup: List[Instruction]): diff --git a/torch/_dynamo/side_effects.py b/torch/_dynamo/side_effects.py index fd5b54e4f7f7af..fc1bd976ff57df 100644 --- a/torch/_dynamo/side_effects.py +++ b/torch/_dynamo/side_effects.py @@ -593,22 +593,11 @@ def codegen_update_mutated(self, cg: PyCodegen): elif isinstance( var, variables.torch_function.TorchFunctionModeStackVariable ): - # Needed in the finally block for stack restoration - cg.add_push_null( - lambda: cg.load_import_from( - utils.__name__, "get_torch_function_mode_stack" - ) - ) - cg.call_function(0, False) - name = variables.torch_function.get_prev_stack_var_name() - cg.code_options["co_varnames"] += (name,) - cg.append_output(create_instruction("STORE_FAST", argval=name)) cg.add_push_null( lambda: cg.load_import_from( utils.__name__, "set_torch_function_mode_stack" ) ) - cg.foreach(var.symbolic_stack) cg.append_output( create_instruction("BUILD_LIST", arg=len(var.symbolic_stack)) diff --git a/torch/_dynamo/symbolic_convert.py b/torch/_dynamo/symbolic_convert.py index 61acaec995757a..415179a5ed32a6 100644 --- a/torch/_dynamo/symbolic_convert.py +++ b/torch/_dynamo/symbolic_convert.py @@ -267,12 +267,13 @@ def resume_fn(self): else: return ReenterWith(self.stack_index) - def exit(self, tx, is_graph_break): + def exit(self, tx): + if hasattr(self, "graph_break") and isinstance( + self.with_context, TorchFunctionModeVariable + ): + return assert self.with_context is not None - if ( - is_graph_break and self.with_context.exit_on_graph_break() - ) or not is_graph_break: - return self.with_context.exit(tx) + return self.with_context.exit(tx) class ReturnValueOp(Exception): @@ -638,17 +639,10 @@ def handle_graph_break( cleanup: List[Instruction] = [] # Reconstruct the context variable CLASS in the block stack for b in self.block_stack: - # Don't exit any modes we have entered, - # output bytecode will mutate the tf mode stack accordingly - if isinstance(b.with_context, TorchFunctionModeVariable): - cg.extend_output( - b.resume_fn().try_except_torch_function_mode( - cg.code_options, cleanup - ) - ) - continue assert b.with_context is not None - assert isinstance(b.with_context, (ContextWrappingVariable)) + assert isinstance( + b.with_context, (ContextWrappingVariable, TorchFunctionModeVariable) + ) b.with_context.reconstruct_type(cg) cg.extend_output(b.resume_fn().try_except(cg.code_options, cleanup)) self.output.add_output_instructions(cg.get_instructions()) @@ -2301,10 +2295,7 @@ def setup_or_before_with(self, inst): ): unimplemented(f"{inst.opname} {ctx}") - if ( - isinstance(ctx, GenericContextWrappingVariable) - and not ctx.supports_graph_breaks() - ): + if isinstance(ctx, GenericContextWrappingVariable): self.generic_context_manager_depth += 1 # Need this redundant check for mypy @@ -2677,7 +2668,6 @@ def __init__( local_scope=f_locals, global_scope=f_globals, f_code=f_code, - torch_function_mode_stack=torch_function_mode_stack, ), instructions=instructions, f_locals=f_locals, diff --git a/torch/_dynamo/testing.py b/torch/_dynamo/testing.py index 704a3889707233..4922b521bada34 100644 --- a/torch/_dynamo/testing.py +++ b/torch/_dynamo/testing.py @@ -163,7 +163,6 @@ def insert_nops(instructions, code_options): local_scope=locals(), global_scope=globals(), f_code=frame.f_code, - torch_function_mode_stack=[], ) return GuardedCode(code, CheckFunctionManager(graph).check_fn, CompileId(0, 0)) diff --git a/torch/_dynamo/trace_rules.py b/torch/_dynamo/trace_rules.py index ed81b26191b4de..1a206ae339c2db 100644 --- a/torch/_dynamo/trace_rules.py +++ b/torch/_dynamo/trace_rules.py @@ -304,7 +304,6 @@ "torch.fx.experimental.symbolic_shapes.guard_size_oblivious": TorchInGraphFunctionVariable, "torch.cuda._get_device_properties": TorchInGraphFunctionVariable, "torch.utils.hooks.BackwardHook": TorchInGraphFunctionVariable, - "torch.set_default_device": UserFunctionVariable, "torch.sparse_bsc_tensor": SkipFunctionVariable, "torch.sparse_bsr_tensor": SkipFunctionVariable, "torch.sparse_csc_tensor": SkipFunctionVariable, @@ -2798,6 +2797,7 @@ "torch.random.initial_seed", "torch.random.seed", "torch.return_types.pytree_register_structseq", + "torch.set_default_device", "torch.set_default_dtype", "torch.set_default_tensor_type", "torch.set_deterministic_debug_mode", diff --git a/torch/_dynamo/variables/builder.py b/torch/_dynamo/variables/builder.py index 193e227267ec3b..5a8a9b613b899d 100644 --- a/torch/_dynamo/variables/builder.py +++ b/torch/_dynamo/variables/builder.py @@ -204,7 +204,6 @@ from .torch_function import ( build_torch_function_fn, TensorWithTFOverrideVariable, - torch_function_mode_stack_state_mgr, TorchFunctionModeVariable, ) from .user_defined import ( @@ -1669,16 +1668,15 @@ def wrap_numpy_ndarray(self, value): # but warning is not the end of the world assert isinstance(value.base, np.nditer) - with torch_function_mode_stack_state_mgr.temp_restore_stack(): - try: - tensor_value = _util._try_convert_to_tensor(value) - if readonly: - from torch._prims_common import clone_preserve_strides - - tensor_value = clone_preserve_strides(tensor_value) - except NotImplementedError as e: - # failed to convert to tensor, graph break - unimplemented(str(e)) + try: + tensor_value = _util._try_convert_to_tensor(value) + if readonly: + from torch._prims_common import clone_preserve_strides + + tensor_value = clone_preserve_strides(tensor_value) + except NotImplementedError as e: + # failed to convert to tensor, graph break + unimplemented(str(e)) # We do this because we want the full behavior of guarding the numpy ndarray as if it were # a tensor. It's a little annoying to make a VT to throw out, but there's so many side effects here diff --git a/torch/_dynamo/variables/ctx_manager.py b/torch/_dynamo/variables/ctx_manager.py index e19c4e254c647e..8c4eb3dc4e715b 100644 --- a/torch/_dynamo/variables/ctx_manager.py +++ b/torch/_dynamo/variables/ctx_manager.py @@ -125,12 +125,6 @@ def call_function( if isinstance(args[0], UserFunctionVariable): return WrappedUserFunctionVariable(args[0], self) - def supports_graph_breaks(self): - return True - - def exit_on_graph_break(self): - return True - class GenericContextWrappingVariable(UserDefinedObjectVariable): # Some methods in ContextWrappingVariable assumes the arguments are @@ -189,12 +183,6 @@ def exit(self, tx: "InstructionTranslator", *args): tx.generic_context_manager_depth -= 1 return x - def supports_graph_breaks(self): - return False - - def exit_on_graph_break(self): - return True - class GradInplaceRequiresGradCtxManagerVariable(ContextWrappingVariable): """represents torch grad requries grad""" diff --git a/torch/_dynamo/variables/torch.py b/torch/_dynamo/variables/torch.py index 48f4f4101ba96a..f6cb565ef021c0 100644 --- a/torch/_dynamo/variables/torch.py +++ b/torch/_dynamo/variables/torch.py @@ -162,10 +162,7 @@ def get_overridable_functions(): from torch.overrides import get_overridable_functions as get_overridable_functions_ - funcs = set(chain(*get_overridable_functions_().values())) - more = {torch.ones, torch.ones_like, torch.zeros, torch.zeros_like, torch.empty} - funcs.update(more) - return funcs + return set(chain(*get_overridable_functions_().values())) class BaseTorchVariable(VariableTracker): @@ -841,13 +838,6 @@ def handle_len_torch_function( len(tx.symbolic_torch_function_state.mode_stack) ) - @register(torch._C._get_function_stack_at) - def handle_get_stack_at(self, tx: "InstructionTranslator", *args, **kwargs): - assert len(args) == 1 and not kwargs - ind = args[0].as_python_constant() - assert ind >= 0 and ind < len(tx.symbolic_torch_function_state.mode_stack) - return tx.symbolic_torch_function_state.mode_stack[ind] - @register(torch.set_default_device) def handle_set_default_device( self, tx: "InstructionTranslator", *args, **kwargs @@ -865,7 +855,7 @@ def handle_set_default_device( else: TorchFunctionModeStackVariable.register_device_context_insertion(tx) - return ConstantVariable.create(None) + return None return handlers diff --git a/torch/_dynamo/variables/torch_function.py b/torch/_dynamo/variables/torch_function.py index b7fd83a3d24f15..6e52cc688e0309 100644 --- a/torch/_dynamo/variables/torch_function.py +++ b/torch/_dynamo/variables/torch_function.py @@ -2,35 +2,22 @@ import collections import contextlib -import functools import inspect from typing import Deque, Dict, List, TYPE_CHECKING import torch._C import torch.utils._pytree as pytree from torch._guards import Source -from torch.overrides import ( - _get_overloaded_args, - get_default_nowrap_functions, - TorchFunctionMode, -) +from torch.overrides import _get_overloaded_args, get_default_nowrap_functions from torch.utils._device import DeviceContext from ..exc import unimplemented from ..guards import GuardBuilder, install_guard -from ..polyfills import NoEnterTorchFunctionMode from ..source import AttrSource, GlobalSource, TorchFunctionModeStackSource, TypeSource -from ..utils import ( - class_has_getattribute, - clear_torch_function_mode_stack, - get_safe_global_name, - has_torch_function, - is_tensor_base_attr_getter, - set_torch_function_mode_stack, -) +from ..utils import get_safe_global_name, has_torch_function, is_tensor_base_attr_getter from .base import VariableTracker from .constant import ConstantVariable -from .ctx_manager import GenericContextWrappingVariable +from .ctx_manager import ContextWrappingVariable from .lazy import LazyVariableTracker from .lists import TupleVariable from .tensor import TensorSubclassVariable, TensorVariable @@ -76,39 +63,6 @@ IGNORED_MODES = {DeviceContext} -@functools.lru_cache(None) -def get_prev_stack_var_name(): - from ..bytecode_transformation import unique_id - - return unique_id("___prev_torch_function_mode_stack") - - -# Used to clear/restore the python torch function mode stack and temporarily restore it as needed -class TorchFunctionModeStackStateManager: - def __init__(self): - self.stack = [] - - def __enter__(self): - self.stack = torch.overrides._get_current_function_mode_stack() - clear_torch_function_mode_stack() - - def __exit__(self, exc_type, exc_value, traceback): - set_torch_function_mode_stack(self.stack) - self.stack = [] - - @contextlib.contextmanager - def temp_restore_stack(self): - prev = torch.overrides._get_current_function_mode_stack() - set_torch_function_mode_stack(self.stack) - try: - yield - finally: - set_torch_function_mode_stack(prev) - - -torch_function_mode_stack_state_mgr = TorchFunctionModeStackStateManager() - - class SymbolicTorchFunctionState: def __init__(self, py_stack): # This is annoyingly complicated because of how the torch function subclass + mode C API was designed @@ -235,26 +189,9 @@ def get_mode_index(cls, ind): return ind + cls.offset -class TorchFunctionModeVariable(GenericContextWrappingVariable): - @staticmethod - def is_supported_torch_function_mode(ty): - # Supported in this sense means we can support graph breaks under the - # context. - # We are able to trace custom modes but if there are graph breaks under them - # and they have a custom __enter__/__exit__ we don't handle this for the - # same reason we don't handle generic context managers: there may be side effects - # that are now affected by executing the funtion across two frames instead of one - # Today we support the enter/exit of the default TorchFunctionMode as well as - # DeviceContext (which is used for set_default_device) - return issubclass(ty, (NoEnterTorchFunctionMode, DeviceContext)) or ( - not class_has_getattribute(ty) - and inspect.getattr_static(ty, "__enter__") == TorchFunctionMode.__enter__ - and inspect.getattr_static(ty, "__exit__") == TorchFunctionMode.__exit__ - ) - +class TorchFunctionModeVariable(ContextWrappingVariable): def __init__(self, value, source=None, **kwargs): - if value is not None: - super().__init__(value, **kwargs) + super().__init__(value, **kwargs) self.value = value self.cm_obj = value # needed for BC with calling enter from CM code self.source = source @@ -284,39 +221,8 @@ def call_torch_function(self, tx: "InstructionTranslator", fn, types, args, kwar kwargs, ) - def enter(self, tx): - from .torch import TorchInGraphFunctionVariable - - if isinstance(self.value, NoEnterTorchFunctionMode): - return ConstantVariable.create(None) - - TorchInGraphFunctionVariable( - torch._C._push_on_torch_function_stack - ).call_function(tx, [self], {}) - return ConstantVariable.create(None) - - def exit(self, tx: "InstructionTranslator", *args): - from .torch import TorchInGraphFunctionVariable - - TorchInGraphFunctionVariable(torch._C._pop_torch_function_stack).call_function( - tx, [], {} - ) - return ConstantVariable.create(None) - - def reconstruct_type(self, codegen): - ty = NoEnterTorchFunctionMode - codegen( - AttrSource( - codegen.tx.import_source(ty.__module__), - ty.__name__, - ) - ) - - def supports_graph_breaks(self): - return True - - def exit_on_graph_break(self): - return False + def _call_func(self, tx: "InstructionTranslator", values): + unimplemented("enter/exit for torch function mode NYI") def _get_all_args(args, kwargs): diff --git a/torch/_dynamo/variables/user_defined.py b/torch/_dynamo/variables/user_defined.py index 14499b4d2e4bf7..ad92277cf70ff4 100644 --- a/torch/_dynamo/variables/user_defined.py +++ b/torch/_dynamo/variables/user_defined.py @@ -409,22 +409,10 @@ def call_function( and self.source and not is_forbidden_context_manager(self.value) ): - from torch.overrides import TorchFunctionMode - from .ctx_manager import GenericContextWrappingVariable - from .torch_function import TorchFunctionModeVariable - - if issubclass( - self.value, TorchFunctionMode - ) and TorchFunctionModeVariable.is_supported_torch_function_mode( - self.value - ): - var_cls = TorchFunctionModeVariable - else: - var_cls = GenericContextWrappingVariable cm_obj = tx.output.side_effects.track_object_new( - self.source, self.value, var_cls, {} + self.source, self.value, GenericContextWrappingVariable, {} ) cm_obj.call_method(tx, "__init__", args, kwargs) return cm_obj From 317e0d187eae9f6ffbd6b1f66b43a2dbde1e141b Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Fri, 13 Sep 2024 12:52:57 +0000 Subject: [PATCH 0826/1018] Revert "[Dynamo] Simplify torch function mode stack guard (#135444)" This reverts commit 0c080cb2c78a85a5320fbeadbbb9a2cc640fd89d. Reverted https://github.com/pytorch/pytorch/pull/135444 on behalf of https://github.com/albanD due to Broke tests on main ([comment](https://github.com/pytorch/pytorch/pull/134732#issuecomment-2348886378)) --- test/dynamo/test_modes.py | 2 -- torch/_dynamo/guards.py | 10 ++++---- torch/csrc/dynamo/guards.cpp | 44 +++++++----------------------------- 3 files changed, 12 insertions(+), 44 deletions(-) diff --git a/test/dynamo/test_modes.py b/test/dynamo/test_modes.py index fa4c23fd320ddb..bad5a788903440 100644 --- a/test/dynamo/test_modes.py +++ b/test/dynamo/test_modes.py @@ -68,11 +68,9 @@ def tearDownClass(cls): def setUp(self): torch.set_default_device(None) - torch._dynamo.reset() def tearDown(self): torch.set_default_device(None) - torch._dynamo.reset() def _run_torch_function_mode_guard_test(self): class TestMode1(BaseTorchFunctionMode): diff --git a/torch/_dynamo/guards.py b/torch/_dynamo/guards.py index 1544bdd6dc697b..4437ef4038190d 100644 --- a/torch/_dynamo/guards.py +++ b/torch/_dynamo/guards.py @@ -2663,14 +2663,12 @@ def make_torch_function_mode_stack_guard(intial_stack): def check_torch_function_mode_stack(): cur_stack = get_torch_function_mode_stack() - - types_ = [ty for ty in types if ty not in IGNORED_MODES] - cur_stack_ = [mode for mode in cur_stack if type(mode) not in IGNORED_MODES] - - if len(cur_stack_) != len(types_): + if len(cur_stack) != len(types): return False - for ty, mode in zip(types_, cur_stack_): + for ty, mode in zip(types, cur_stack): + if ty in IGNORED_MODES: + continue if ty != type(mode): return False diff --git a/torch/csrc/dynamo/guards.cpp b/torch/csrc/dynamo/guards.cpp index 771de44427b669..afc3182c632640 100644 --- a/torch/csrc/dynamo/guards.cpp +++ b/torch/csrc/dynamo/guards.cpp @@ -2523,8 +2523,7 @@ class TORCH_FUNCTION_MODE_STACK : public LeafGuard { Py_ssize_t len = PyList_Size(initial_stack.ptr()); for (Py_ssize_t idx = 0; idx < len; idx++) { PyObject* mode = PyList_GetItem(initial_stack.ptr(), idx); // borrowed ref - auto type = Py_TYPE(mode); - this->_ref_stack.push_back(type); + this->_ref_stack.push_back(Py_TYPE(mode)); } len = PyList_Size(ignored_types.ptr()); @@ -2544,56 +2543,29 @@ class TORCH_FUNCTION_MODE_STACK : public LeafGuard { bool check_nopybind(PyObject* value) override { // Ignore value arg, only used to satisfy the interface size_t ref_ind = 0; - const int64_t len = at::impl::PythonTorchFunctionTLS::stack_len(); + int64_t len = at::impl::PythonTorchFunctionTLS::stack_len(); const size_t ref_stack_size = this->_ref_stack.size(); - int64_t idx = 0; - while ((idx < len) && (ref_ind < ref_stack_size)) { + for (int64_t idx = 0; idx < len; idx++) { std::shared_ptr mode = at::impl::PythonTorchFunctionTLS::get_stack_at(idx); PyTypeObject* mode_type = Py_TYPE(mode->ptr(getPyInterpreter())); - bool act_ignored = this->_ignored_types.count(mode_type) > 0; - bool ref_ignored = - this->_ignored_types.count(this->_ref_stack.at(ref_ind)) > 0; // skip ignored types - if (act_ignored && ref_ignored) { - idx++; - ref_ind++; - continue; - } else if (ref_ignored) { - ref_ind++; - continue; - } else if (act_ignored) { - idx++; + if (this->_ignored_types.count(mode_type) > 0) { continue; } // if we already have more non-ignored modes than the ref stack // or if the mode doesn't match at the current index, return false - else if (mode_type != _ref_stack.at(ref_ind)) { + else if ( + (ref_stack_size == 0) || (ref_ind > ref_stack_size - 1) || + mode_type != _ref_stack[ref_ind]) { return false; } ref_ind++; - idx++; - } - - for (; ref_ind < ref_stack_size; ref_ind++) { - if (!(this->_ignored_types.count(this->_ref_stack.at(ref_ind)) > 0)) { - return false; - } - } - - for (; idx < len; idx++) { - std::shared_ptr mode = - at::impl::PythonTorchFunctionTLS::get_stack_at(idx); - - PyTypeObject* mode_type = Py_TYPE(mode->ptr(getPyInterpreter())); - if (!(this->_ignored_types.count(mode_type) > 0)) { - return false; - } } - return ref_ind == ref_stack_size && idx == len; + return ref_ind == this->_ref_stack.size(); } private: From 46c3acc2d13e9db41953fba9369c92be1449418a Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Fri, 13 Sep 2024 12:52:57 +0000 Subject: [PATCH 0827/1018] Revert "[Dynamo] Support thread local setattr (#135443)" This reverts commit 30b007bea329f512af3dc4fd4e6c7d145e807b71. Reverted https://github.com/pytorch/pytorch/pull/135443 on behalf of https://github.com/albanD due to Broke tests on main ([comment](https://github.com/pytorch/pytorch/pull/134732#issuecomment-2348886378)) --- test/dynamo/test_misc.py | 15 --------------- torch/_dynamo/variables/user_defined.py | 5 ++--- 2 files changed, 2 insertions(+), 18 deletions(-) diff --git a/test/dynamo/test_misc.py b/test/dynamo/test_misc.py index 96047d3436ebe1..9d2bb374621134 100644 --- a/test/dynamo/test_misc.py +++ b/test/dynamo/test_misc.py @@ -3380,21 +3380,6 @@ def fn4(x) -> None: self.assertTrue(same(obj41.y, obj42.y)) self.assertEqual(cnts.frame_count, 1) - def test_thread_local_setattr(self): - from threading import local - - loc = local() - - @torch.compile(fullgraph=True) - def fn(x, l): - l.x = x - return x + 1 - - x = torch.ones(2, 2) - fn(x, loc) - - self.assertTrue(loc.x is x) - def test_user_defined_class_name(self): class MyClassFoo: pass diff --git a/torch/_dynamo/variables/user_defined.py b/torch/_dynamo/variables/user_defined.py index ad92277cf70ff4..d069d2e060a108 100644 --- a/torch/_dynamo/variables/user_defined.py +++ b/torch/_dynamo/variables/user_defined.py @@ -9,7 +9,6 @@ import itertools import random import sys -import threading import types import warnings from typing import Dict, Generic, List, TYPE_CHECKING @@ -705,7 +704,7 @@ def call_method( if method is object.__init__: return ConstantVariable.create(None) - if is_standard_setattr(method) or isinstance(self.value, threading.local): + if is_standard_setattr(method): return self.method_setattr_standard(tx, *args, **kwargs) # [NOTE] OrderedDict, dict subtypes must always have source @@ -803,7 +802,7 @@ def method_setattr_standard(self, tx: "InstructionTranslator", name, value): def needs_slow_setattr(self): return not is_standard_setattr( inspect.getattr_static(self.value, "__setattr__", None) - ) and not isinstance(self.value, threading.local) + ) def unpack_var_sequence(self, tx): if ( From 56d4ea8abc10832098ff3f62ba4a808c72a83c5b Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Fri, 13 Sep 2024 12:52:58 +0000 Subject: [PATCH 0828/1018] Revert "[Dynamo] Trace torch function modes entered outside of torch.compile (#133137)" This reverts commit fafdd588f27e1d56090c6d260d0382c255eaf9eb. Reverted https://github.com/pytorch/pytorch/pull/133137 on behalf of https://github.com/albanD due to Broke tests on main ([comment](https://github.com/pytorch/pytorch/pull/134732#issuecomment-2348886378)) --- test/dynamo/test_modes.py | 135 --------- test/functorch/test_control_flow.py | 4 - .../pt2e/test_metadata_porting.py | 4 +- torch/_dynamo/convert_frame.py | 5 - torch/_dynamo/guards.py | 14 - torch/_dynamo/source.py | 2 +- torch/_dynamo/symbolic_convert.py | 78 ++++-- torch/_dynamo/trace_rules.py | 5 +- torch/_dynamo/utils.py | 10 +- torch/_dynamo/variables/ctx_manager.py | 2 - torch/_dynamo/variables/torch.py | 260 ++++++++---------- torch/_dynamo/variables/torch_function.py | 129 ++------- torch/_dynamo/variables/user_defined.py | 7 + torch/_export/non_strict_utils.py | 6 +- 14 files changed, 204 insertions(+), 457 deletions(-) diff --git a/test/dynamo/test_modes.py b/test/dynamo/test_modes.py index bad5a788903440..ee8a9b579ff124 100644 --- a/test/dynamo/test_modes.py +++ b/test/dynamo/test_modes.py @@ -14,17 +14,6 @@ from torch.utils._python_dispatch import TorchDispatchMode -class TestMode(BaseTorchFunctionMode): - def __torch_function__(self, func, types, args, kwargs=None): - if not kwargs: - kwargs = {} - - if func == torch.add: - return torch.zeros(2, 2) - - return super().__torch_function__(func, types, args, kwargs) - - class TorchDispatchModeTests(torch._dynamo.test_case.TestCase): @classmethod def setUpClass(cls): @@ -335,130 +324,6 @@ def fn(x): fn(inp) self.assertEqual(cnt.frame_count, 2) - def test_nested_torch_function_mode(self): - mode_1_called = False - mode_2_called = False - - def reset_state(): - nonlocal mode_1_called - nonlocal mode_2_called - mode_1_called = False - mode_2_called = False - - ones = torch.ones(2, 2) - zeros = torch.zeros(2, 2) - - class TestMode1(BaseTorchFunctionMode): - def __torch_function__(self, func, types, args, kwargs=None): - if not kwargs: - kwargs = {} - - nonlocal mode_1_called - - mode_1_called = True - - if func == torch.add: - return zeros - - return super().__torch_function__(func, types, args, kwargs) - - class TestMode2(BaseTorchFunctionMode): - def __torch_function__(self, func, types, args, kwargs=None): - if not kwargs: - kwargs = {} - - nonlocal mode_2_called - - mode_2_called = True - - if func == torch.mul: - return ones - - return super().__torch_function__(func, types, args, kwargs) - - def fn(x): - return torch.add(x, 3) - - def fn_2(x): - return torch.mul(x, 3) + torch.add(x, 3) - - inp = torch.ones(2, 2) + 1 - - for fn_i in [fn, fn_2]: - fn_opt = torch.compile(fn_i, fullgraph=True) - with TestMode1(), TestMode2(): - expected = fn_i(inp), mode_1_called, mode_2_called - reset_state() - actual = fn_opt(inp), mode_1_called, mode_2_called - reset_state() - - self.assertEqual(expected, actual) - - def test_torch_function_mode_disable(self): - class TestSubclass(torch.Tensor): - @classmethod - def __torch_function__(cls, func, types, args, kwargs=None): - if not kwargs: - kwargs = {} - if func == torch.add: - return torch.ones(2, 2) - return super().__torch_function__(func, types, args, kwargs) - - class TestMode(BaseTorchFunctionMode): - def __torch_function__(self, func, types, args, kwargs=None): - if not kwargs: - kwargs = {} - - if func == torch.add: - return torch.zeros(2, 2) - - return super().__torch_function__(func, types, args, kwargs) - - def fn(x): - return torch.add(x, 3) - - inp = (torch.ones(2, 2) + 1).as_subclass(TestSubclass) - - fn_opt = torch.compile(fn, fullgraph=True) - with TestMode(), torch._dynamo.config.patch( - "traceable_tensor_subclasses", {TestSubclass} - ): - with torch._C.DisableTorchFunctionSubclass(): - expected = fn(inp) - actual = fn_opt(inp) - - self.assertEqual(expected, actual) - - with torch._C.DisableTorchFunction(): - expected = fn(inp) - actual = fn_opt(inp) - - self.assertEqual(expected, actual) - - def test_torch_function_mode_highest_priority(self): - class TestSubclass(torch.Tensor): - @classmethod - def __torch_function__(cls, func, types, args, kwargs=None): - if not kwargs: - kwargs = {} - if func == torch.add: - return torch.ones(2, 2) - return super().__torch_function__(func, types, args, kwargs) - - def fn(x): - return torch.add(x, 3) - - inp = (torch.ones(2, 2) + 1).as_subclass(TestSubclass) - - fn_opt = torch.compile(fn, fullgraph=True) - with TestMode(), torch._dynamo.config.patch( - "traceable_tensor_subclasses", {TestSubclass} - ): - expected = fn(inp) - actual = fn_opt(inp) - - self.assertEqual(expected, actual) - if __name__ == "__main__": from torch._dynamo.test_case import run_tests diff --git a/test/functorch/test_control_flow.py b/test/functorch/test_control_flow.py index c1fb51af8e2f90..16740fdeddf3a6 100644 --- a/test/functorch/test_control_flow.py +++ b/test/functorch/test_control_flow.py @@ -26,7 +26,6 @@ parametrize, requires_cuda, run_tests, - skipIfCrossRef, skipIfRocm, skipIfTorchDynamo, TEST_WITH_TORCHDYNAMO, @@ -2883,7 +2882,6 @@ def f(fct, init, xs): gm = make_fx(f, tracing_mode="symbolic")(add_wrong_dtype, init, x) @skipIfNoDynamoSupport - @skipIfCrossRef # Arg order changes with crossref def test_scan_simple_graph(self): from torch._dynamo.testing import EagerAndRecordGraphs @@ -2990,7 +2988,6 @@ def f(x, y): self.assertEqual(graph(x, torch.tensor(True)), f(x, torch.tensor(True))) @skipIfTorchDynamo("Graph is not captured by backend if test with dynamo") - @skipIfCrossRef # Arg order changes with crossref def test_cond_simple_with_linear_compile_check_graph(self): from torch._dynamo.testing import EagerAndRecordGraphs @@ -3253,7 +3250,6 @@ def test_while_loop_compile(self, backend, while_loop_test): self._check_compile(fn, inp, backend=backend) @skipIfTorchDynamo("Graph is not captured by backend if test with dynamo") - @skipIfCrossRef # Arg order changes with cross ref def test_while_loop_simple_with_linear_compile_check_graph(self): fn, inp = WHILE_LOOP_TESTS["simple_with_linear"] from torch._dynamo.testing import EagerAndRecordGraphs diff --git a/test/quantization/pt2e/test_metadata_porting.py b/test/quantization/pt2e/test_metadata_porting.py index 21488e8cacdd42..27df58e159f3b6 100644 --- a/test/quantization/pt2e/test_metadata_porting.py +++ b/test/quantization/pt2e/test_metadata_porting.py @@ -13,7 +13,7 @@ from torch.ao.quantization.quantizer.xnnpack_quantizer_utils import OP_TO_ANNOTATOR from torch.fx import Node from torch.testing._internal.common_quantization import QuantizationTestCase -from torch.testing._internal.common_utils import IS_WINDOWS, skipIfCrossRef +from torch.testing._internal.common_utils import IS_WINDOWS class TestHelperModules: @@ -139,8 +139,6 @@ def _test_metadata_porting( self.assertEqual(v, node_tags[k]) return m - @skipIfCrossRef # mlazos: retracing FX graph with torch function mode doesn't propagate metadata, because the stack - # trace of the mode torch function impl doesn't match the traced graph stored lineno. def test_simple_metadata_porting(self): """ Model under test diff --git a/torch/_dynamo/convert_frame.py b/torch/_dynamo/convert_frame.py index 4a166e1ff80251..3ab2937a9496d4 100644 --- a/torch/_dynamo/convert_frame.py +++ b/torch/_dynamo/convert_frame.py @@ -605,10 +605,6 @@ def _compile( output: Optional[OutputGraph] = None tracer: Optional[InstructionTranslator] = None - tf_mode_stack: List[ - torch.overrides.TorchFunctionMode - ] = torch.overrides._get_current_function_mode_stack() - @preserve_global_state def transform( instructions: List[Instruction], code_options: Dict[str, object] @@ -622,7 +618,6 @@ def transform( locals, globals, builtins, - tf_mode_stack, code_options, compiler_fn, one_graph, diff --git a/torch/_dynamo/guards.py b/torch/_dynamo/guards.py index 4437ef4038190d..b0bd273e96bbb5 100644 --- a/torch/_dynamo/guards.py +++ b/torch/_dynamo/guards.py @@ -98,7 +98,6 @@ ScriptObjectQualifiedNameSource, ShapeEnvSource, SubclassAttrListSource, - TorchFunctionModeStackSource, TupleIteratorGetItemSource, TypeSource, UnspecializedBuiltinNNModuleSource, @@ -112,7 +111,6 @@ dict_keys_repr, get_custom_getattr, get_torch_function_mode_stack, - get_torch_function_mode_stack_at, guard_failures, istype, key_is_id, @@ -316,7 +314,6 @@ def uninteresting_files(): "___dict_contains": lambda a, b: a in b, "___tuple_iterator_len": tuple_iterator_len, "___tuple_iterator_getitem": tuple_iterator_getitem, - "___get_torch_function_mode_stack_at": get_torch_function_mode_stack_at, "__math_isnan": math.isnan, "__numpy_isnan": None if np is None else np.isnan, "inf": float("inf"), @@ -904,15 +901,6 @@ def get_guard_manager_from_source(self, source): ): assert base_guard_manager # to make mypy happy out = base_guard_manager - elif istype(source, TorchFunctionModeStackSource): - out = root_guard_manager.lambda_manager( - python_lambda=lambda _: get_torch_function_mode_stack_at( - source._get_index() - ), - source=source_name, - example_value=example_value, - guard_manager_enum=guard_manager_enum, - ) elif istype(source, GradSource): assert base_guard_manager # to make mypy happy out = base_guard_manager.grad_manager( @@ -2226,8 +2214,6 @@ def __init__( self.output_graph = output_graph w_builder = None - # NB: Until we trace device contexts, we need to use the stack recorded at the beginning of tracing - # in case a set default device call was made in the graph. self.torch_function_mode_stack = ( output_graph.torch_function_mode_stack if output_graph else None ) diff --git a/torch/_dynamo/source.py b/torch/_dynamo/source.py index c44f8d9e69abf5..6ed9d97f6a1210 100644 --- a/torch/_dynamo/source.py +++ b/torch/_dynamo/source.py @@ -619,7 +619,7 @@ class TorchFunctionModeStackSource(Source): ind: int def name(self): - return f"___get_torch_function_mode_stack_at({self._get_index()})" + return "" def _get_index(self): from .variables.torch_function import TorchFunctionModeStackVariable diff --git a/torch/_dynamo/symbolic_convert.py b/torch/_dynamo/symbolic_convert.py index 415179a5ed32a6..ab92f82aa0f630 100644 --- a/torch/_dynamo/symbolic_convert.py +++ b/torch/_dynamo/symbolic_convert.py @@ -19,7 +19,20 @@ import types import typing import weakref -from typing import Any, Callable, cast, Dict, List, Optional, Set, Tuple, Type, Union +from typing import ( + Any, + Callable, + cast, + Deque, + Dict, + List, + Optional, + Set, + Tuple, + Type, + TYPE_CHECKING, + Union, +) from unittest.mock import patch import torch @@ -59,12 +72,14 @@ GlobalWeakRefSource, LocalSource, Source, + TorchFunctionModeStackSource, ) from .trace_rules import is_builtin_constant, is_forbidden from .utils import ( counters, get_fake_value, get_instruction_source_311, + get_torch_function_mode_stack, graph_break_dup_warning_checker, istype, LazyString, @@ -105,10 +120,11 @@ ) from .variables.nn_module import NNModuleVariable, UnspecializedNNModuleVariable from .variables.tensor import supported_comparison_ops, SymNodeVariable, TensorVariable -from .variables.torch_function import ( - SymbolicTorchFunctionState, - TorchFunctionModeVariable, -) + + +if TYPE_CHECKING: + from .variables.torch_function import TorchFunctionModeVariable + from .variables.user_defined import ( RemovableHandleVariable, UserDefinedClassVariable, @@ -268,10 +284,6 @@ def resume_fn(self): return ReenterWith(self.stack_index) def exit(self, tx): - if hasattr(self, "graph_break") and isinstance( - self.with_context, TorchFunctionModeVariable - ): - return assert self.with_context is not None return self.with_context.exit(tx) @@ -640,9 +652,7 @@ def handle_graph_break( # Reconstruct the context variable CLASS in the block stack for b in self.block_stack: assert b.with_context is not None - assert isinstance( - b.with_context, (ContextWrappingVariable, TorchFunctionModeVariable) - ) + assert isinstance(b.with_context, ContextWrappingVariable) b.with_context.reconstruct_type(cg) cg.extend_output(b.resume_fn().try_except(cg.code_options, cleanup)) self.output.add_output_instructions(cg.get_instructions()) @@ -718,7 +728,7 @@ class InstructionTranslatorBase( output: OutputGraph symbolic_locals: Dict[str, VariableTracker] symbolic_globals: Dict[str, VariableTracker] - symbolic_torch_function_state: SymbolicTorchFunctionState + symbolic_torch_function_mode_stack: Deque["TorchFunctionModeVariable"] stack: List[VariableTracker] instruction_pointer: Optional[int] current_instruction: Instruction @@ -2538,7 +2548,7 @@ def __init__( code_options: Dict[str, Any], symbolic_locals: Dict[str, VariableTracker], symbolic_globals: Dict[str, VariableTracker], - symbolic_torch_function_state: SymbolicTorchFunctionState, + symbolic_torch_function_mode_stack: Deque["TorchFunctionModeVariable"], f_code: types.CodeType, export: bool, inline_depth: int, @@ -2553,7 +2563,7 @@ def __init__( self.output = output self.symbolic_locals = symbolic_locals self.symbolic_globals = symbolic_globals - self.symbolic_torch_function_state = symbolic_torch_function_state + self.symbolic_torch_function_mode_stack = symbolic_torch_function_mode_stack self.stack = [] # stack of variable names for tracking 3.13 closures self.name_stack: list[Any] = [] @@ -2642,7 +2652,6 @@ def __init__( f_locals, f_globals, f_builtins, - torch_function_mode_stack, code_options, compiler_fn, one_graph, @@ -2677,7 +2686,7 @@ def __init__( symbolic_locals={}, # set below # A global var is inserted only after a STORE_GLOBAL happens to it symbolic_globals={}, - symbolic_torch_function_state=None, # type: ignore[arg-type] # set below + symbolic_torch_function_mode_stack=collections.deque(), f_code=f_code, export=export, inline_depth=0, @@ -2712,9 +2721,7 @@ def __init__( if k in f_locals } - self.symbolic_torch_function_state = SymbolicTorchFunctionState( - torch_function_mode_stack - ) + self._init_torch_function_mode_stack() self.debug_locals: List[Tuple[VariableTracker, List[VariableTracker]]] = [] if export: @@ -2755,6 +2762,29 @@ def _throw_if_in_functorch(self): ) unimplemented(msg) + def _init_torch_function_mode_stack(self): + from .variables.torch_function import TorchFunctionModeStackVariable + + TorchFunctionModeStackVariable.reset() + + self.symbolic_torch_function_mode_stack: Deque[ + TorchFunctionModeVariable + ] = collections.deque() + # We want to retrieve all modes to properly reconstruct the stack if needed + py_stack = get_torch_function_mode_stack(filter_ignored=False) + + if py_stack: + has_device_context = isinstance( + py_stack[0], torch.utils._device.DeviceContext + ) + + for i, val in enumerate(py_stack): + self.symbolic_torch_function_mode_stack.append( + variables.LazyVariableTracker.create( + val, source=TorchFunctionModeStackSource(i) + ) + ) + def get_example_value(self, source: Source): if isinstance(source, LocalSource): return self.f_locals[source.local_name] @@ -3086,7 +3116,7 @@ def get_trace_call_log_str(): code, sub_locals, parent.symbolic_globals, - parent.symbolic_torch_function_state, + parent.symbolic_torch_function_mode_stack, closure_cells, func, ) @@ -3096,7 +3126,7 @@ def get_trace_call_log_str(): code, sub_locals, parent.symbolic_globals, - parent.symbolic_torch_function_state, + parent.symbolic_torch_function_mode_stack, closure_cells, func, ) @@ -3149,7 +3179,7 @@ def __init__( code: types.CodeType, symbolic_locals: Dict[str, VariableTracker], symbolic_globals: Dict[str, VariableTracker], - symbolic_torch_function_state: SymbolicTorchFunctionState, + symbolic_torch_function_mode_stack: Deque["TorchFunctionModeVariable"], closure_cells: Dict[str, VariableTracker], funcvar: BaseUserFunctionVariable, ) -> None: @@ -3166,7 +3196,7 @@ def __init__( f_builtins=f_builtins, symbolic_locals=symbolic_locals, symbolic_globals=symbolic_globals, - symbolic_torch_function_state=symbolic_torch_function_state, + symbolic_torch_function_mode_stack=symbolic_torch_function_mode_stack, instructions=instructions, code_options={k: getattr(code, k) for k in get_code_keys()}, f_code=code, diff --git a/torch/_dynamo/trace_rules.py b/torch/_dynamo/trace_rules.py index 1a206ae339c2db..391beb07cd600d 100644 --- a/torch/_dynamo/trace_rules.py +++ b/torch/_dynamo/trace_rules.py @@ -3258,7 +3258,6 @@ def _module_dir(m: types.ModuleType): "torch.testing", "torch.utils._content_store", "torch.utils._contextlib", - "torch.utils._device", "torch.utils._foreach_utils", "torch.utils._python_dispatch", "torch.utils._pytree", @@ -3593,9 +3592,7 @@ def lookup_inner( if reasons is not None: reasons.add("func name is patched_init") return SkipFunctionVariable - elif name == "__torch_function__" or ( - obj and obj.__name__ == "__torch_function__" - ): + elif name == "__torch_function__": if reasons is not None: reasons.add("func name is __torch_function__") return UserFunctionVariable diff --git a/torch/_dynamo/utils.py b/torch/_dynamo/utils.py index ef63ddbea85619..0ec548d788f868 100644 --- a/torch/_dynamo/utils.py +++ b/torch/_dynamo/utils.py @@ -63,6 +63,7 @@ import torch.utils._pytree as pytree from torch import fx from torch._C import ( + _get_function_stack_at, _instruction_counter, _len_torch_function_stack, _pop_torch_function_stack, @@ -3064,9 +3065,7 @@ def is_parameter_freezing(): def get_torch_function_mode_stack(filter_ignored=True): from .variables.torch_function import IGNORED_MODES - stack = [ - get_torch_function_mode_stack_at(i) for i in range(_len_torch_function_stack()) - ] + stack = [_get_function_stack_at(i) for i in range(_len_torch_function_stack())] if filter_ignored: stack = [mode for mode in stack if type(mode) not in IGNORED_MODES] @@ -3086,11 +3085,6 @@ def set_torch_function_mode_stack(stack): _push_on_torch_function_stack(mode) -def clear_torch_function_mode_stack(): - for i in range(_len_torch_function_stack()): - _pop_torch_function_stack() - - def verify_guard_fn_signature(value): fn = value.__metadata_guard__ sig = inspect.signature(fn) diff --git a/torch/_dynamo/variables/ctx_manager.py b/torch/_dynamo/variables/ctx_manager.py index 8c4eb3dc4e715b..eebe90929b379a 100644 --- a/torch/_dynamo/variables/ctx_manager.py +++ b/torch/_dynamo/variables/ctx_manager.py @@ -637,8 +637,6 @@ def enter(self, tx): def _call_func(self, tx: "InstructionTranslator", values): assert len(values) == 1 - tx.symbolic_torch_function_state.torch_function_subclass_enabled = values[0] - tx.symbolic_torch_function_state.torch_function_mode_enabled = values[0] tx.output.set_torch_function_state(values[0]) diff --git a/torch/_dynamo/variables/torch.py b/torch/_dynamo/variables/torch.py index f6cb565ef021c0..68508f001dae35 100644 --- a/torch/_dynamo/variables/torch.py +++ b/torch/_dynamo/variables/torch.py @@ -156,15 +156,6 @@ bin_ops = dict.fromkeys(["add", "sub", "mul", "div", "sqrt"]) -@functools.lru_cache(None) -def get_overridable_functions(): - from itertools import chain - - from torch.overrides import get_overridable_functions as get_overridable_functions_ - - return set(chain(*get_overridable_functions_().values())) - - class BaseTorchVariable(VariableTracker): """common base for all torch.* functions, classes, modules and other things""" @@ -815,10 +806,10 @@ def handle_pop_torch_function( self, tx: "InstructionTranslator", *args, **kwargs ): assert not args and not kwargs - if not tx.symbolic_torch_function_state.mode_stack: + if not tx.symbolic_torch_function_mode_stack: raise unimplemented("Popping from an empty torch function mode stack") TorchFunctionModeStackVariable.register_mutation(tx) - return tx.symbolic_torch_function_state.pop_torch_function_mode() + return tx.symbolic_torch_function_mode_stack.pop() @register(torch._C._push_on_torch_function_stack) def handle_push_torch_function( @@ -826,7 +817,7 @@ def handle_push_torch_function( ): assert len(args) == 1 and not kwargs TorchFunctionModeStackVariable.register_mutation(tx) - tx.symbolic_torch_function_state.push_torch_function_mode(args[0]) + tx.symbolic_torch_function_mode_stack.append(args[0]) return ConstantVariable.create(None) @register(torch._C._len_torch_function_stack) @@ -834,9 +825,7 @@ def handle_len_torch_function( self, tx: "InstructionTranslator", *args, **kwargs ): assert not args and not kwargs - return ConstantVariable.create( - len(tx.symbolic_torch_function_state.mode_stack) - ) + return ConstantVariable.create(len(tx.symbolic_torch_function_mode_stack)) @register(torch.set_default_device) def handle_set_default_device( @@ -868,9 +857,6 @@ def call_function( from . import ConstantVariable, SymNodeVariable, TensorVariable from .builder import wrap_fx_proxy - if self.torch_function_override_enabled(tx, args, kwargs): - return dispatch_torch_function(tx, self, args, kwargs) - if self.can_constant_fold_through() and check_unspec_or_constant_args( args, kwargs ): @@ -892,144 +878,147 @@ def call_function( if result: return result - any_symints_or_symfloats = any(isinstance(x, SymNodeVariable) for x in args) + if can_dispatch_torch_function(tx, args, kwargs): + return dispatch_torch_function(tx, self, args, kwargs) + else: + any_symints_or_symfloats = any(isinstance(x, SymNodeVariable) for x in args) - all_ints_or_floats = all( - isinstance(x, (variables.ConstantVariable, variables.SymNodeVariable)) - for x in args - ) - if ( - getattr(self.value, "__module__", "") == "torch" - and self.value.__name__ in bin_ops - and any_symints_or_symfloats - and all_ints_or_floats - ): - msg = f"""\ + all_ints_or_floats = all( + isinstance(x, (variables.ConstantVariable, variables.SymNodeVariable)) + for x in args + ) + if ( + getattr(self.value, "__module__", "") == "torch" + and self.value.__name__ in bin_ops + and any_symints_or_symfloats + and all_ints_or_floats + ): + msg = f"""\ Calling {str(self.value)} on only torch.SymInt arguments is not yet supported. To support this behavior, we need to allow const-propping tensors that store symint data. For now, dynamo will explicitly graph break when it encounters user code with this behavior. """ - log.warning(msg) - unimplemented(msg) - - # TODO(voz): Replace w/ dynamic shape rewrite table. - # Ideally, we would be able to do this at ctor time, but alas we need a combination - # of value + args to determine this. - fn_ = self.value - if any_symints_or_symfloats: - torch_sym_op = f"_sym_{self.value.__name__}" - if getattr(self.value, "__module__", None) == "math" and hasattr( - torch, torch_sym_op - ): - fn_ = getattr(torch, torch_sym_op) - - fake_out_shape = None - if "out" in kwargs and isinstance(kwargs["out"], variables.TensorVariable): - # Calling fake tensor propagation can mutate the out= tensor in - # tx.output.tracked_fakes. tracked_fakes are used to apply - # symbolic_shape guards. Mutating them destroys the information - # prior to tracing, which is essential for creating right - # guards. So save the shape now, and check later if it has - # changed. If it has, graph break. - fake_out_shape = kwargs["out"].proxy.node.meta["example_value"].shape - - tensor_variable = wrap_fx_proxy( - tx=tx, - proxy=tx.output.create_proxy( - "call_function", - fn_, - *proxy_args_kwargs(args, kwargs), - ), - ) + log.warning(msg) + unimplemented(msg) + + # TODO(voz): Replace w/ dynamic shape rewrite table. + # Ideally, we would be able to do this at ctor time, but alas we need a combination + # of value + args to determine this. + fn_ = self.value + if any_symints_or_symfloats: + torch_sym_op = f"_sym_{self.value.__name__}" + if getattr(self.value, "__module__", None) == "math" and hasattr( + torch, torch_sym_op + ): + fn_ = getattr(torch, torch_sym_op) + + fake_out_shape = None + if "out" in kwargs and isinstance(kwargs["out"], variables.TensorVariable): + # Calling fake tensor propagation can mutate the out= tensor in + # tx.output.tracked_fakes. tracked_fakes are used to apply + # symbolic_shape guards. Mutating them destroys the information + # prior to tracing, which is essential for creating right + # guards. So save the shape now, and check later if it has + # changed. If it has, graph break. + fake_out_shape = kwargs["out"].proxy.node.meta["example_value"].shape + + tensor_variable = wrap_fx_proxy( + tx=tx, + proxy=tx.output.create_proxy( + "call_function", + fn_, + *proxy_args_kwargs(args, kwargs), + ), + ) - if ( - isinstance(tensor_variable, TensorVariable) - and "requires_grad" in kwargs - and kwargs["requires_grad"].as_python_constant() - ): - unimplemented( - """factory functions that return tensors that require grad are not supported. + if ( + isinstance(tensor_variable, TensorVariable) + and "requires_grad" in kwargs + and kwargs["requires_grad"].as_python_constant() + ): + unimplemented( + """factory functions that return tensors that require grad are not supported. Either create the tensor outside the compiled region, or do not set the tensor to require_grad""" - ) + ) - if "out" in kwargs and not ( - isinstance(kwargs["out"], variables.ConstantVariable) - and kwargs["out"].as_python_constant() is None - ): - # out variants of torch operators like torch.sort and - # torch.sigmoid mutate the tensors in the out field. Track such - # tensors and rewrite the symbolic locals. - if isinstance(tensor_variable, TupleVariable): - assert isinstance(kwargs["out"], (TupleVariable, ListVariable)) - output_tensor_names = [ - tx.find_symbolic_locals_name(x) for x in kwargs["out"].items - ] - for idx, name in enumerate(output_tensor_names): - if name in tx.symbolic_locals: - tx.symbolic_locals[name] = tensor_variable.items[idx] - for out_tensor, result_tensor in zip( - kwargs["out"].items, tensor_variable.items - ): + if "out" in kwargs and not ( + isinstance(kwargs["out"], variables.ConstantVariable) + and kwargs["out"].as_python_constant() is None + ): + # out variants of torch operators like torch.sort and + # torch.sigmoid mutate the tensors in the out field. Track such + # tensors and rewrite the symbolic locals. + if isinstance(tensor_variable, TupleVariable): + assert isinstance(kwargs["out"], (TupleVariable, ListVariable)) + output_tensor_names = [ + tx.find_symbolic_locals_name(x) for x in kwargs["out"].items + ] + for idx, name in enumerate(output_tensor_names): + if name in tx.symbolic_locals: + tx.symbolic_locals[name] = tensor_variable.items[idx] + for out_tensor, result_tensor in zip( + kwargs["out"].items, tensor_variable.items + ): + if ( + out_tensor.source + and out_tensor in tx.output.graphargs + and isinstance(out_tensor, variables.TensorVariable) + and isinstance(result_tensor, variables.TensorVariable) + and out_tensor.size != result_tensor.size + ): + # It's hard to get out variants with resizing on graph inputs work + # properly across dynamo/aot/inductor, just fall back. + unimplemented("out variants with resizing on graph inputs") + elif isinstance(tensor_variable, TensorVariable): + assert isinstance(kwargs["out"], TensorVariable) + assert "example_value" in kwargs["out"].proxy.node.meta + fake_tensor = tensor_variable.proxy.node.meta["example_value"] + fake_out = kwargs["out"].proxy.node.meta["example_value"] if ( - out_tensor.source - and out_tensor in tx.output.graphargs - and isinstance(out_tensor, variables.TensorVariable) - and isinstance(result_tensor, variables.TensorVariable) - and out_tensor.size != result_tensor.size + kwargs["out"].source + and kwargs["out"] in tx.output.graphargs + and fake_out_shape != fake_tensor.shape ): # It's hard to get out variants with resizing on graph inputs work # properly across dynamo/aot/inductor, just fall back. unimplemented("out variants with resizing on graph inputs") - elif isinstance(tensor_variable, TensorVariable): - assert isinstance(kwargs["out"], TensorVariable) - assert "example_value" in kwargs["out"].proxy.node.meta - fake_tensor = tensor_variable.proxy.node.meta["example_value"] - fake_out = kwargs["out"].proxy.node.meta["example_value"] - if ( - kwargs["out"].source - and kwargs["out"] in tx.output.graphargs - and fake_out_shape != fake_tensor.shape - ): - # It's hard to get out variants with resizing on graph inputs work - # properly across dynamo/aot/inductor, just fall back. - unimplemented("out variants with resizing on graph inputs") - if not torch._prims_common.is_contiguous(fake_out): - # It's difficult to handle strides correctly in functionalization - # when calling an out= op with a non-contiguous out argument - unimplemented( - "out= op was called where output tensor was non-contiguous" - ) - name = tx.find_symbolic_locals_name(kwargs["out"]) - if name in tx.symbolic_locals: - tx.symbolic_locals[name] = tensor_variable - elif ( - isinstance(tensor_variable, ConstantVariable) - and tensor_variable.value is None - ): - # Handle out-variant custom ops that return None. - if isinstance(kwargs["out"], TensorVariable): - assert "example_value" in kwargs["out"].proxy.node.meta - fake_out = kwargs["out"].proxy.node.meta["example_value"] if not torch._prims_common.is_contiguous(fake_out): # It's difficult to handle strides correctly in functionalization # when calling an out= op with a non-contiguous out argument unimplemented( "out= op was called where output tensor was non-contiguous" ) - elif isinstance(kwargs["out"], ListVariable): - for idx, x in enumerate(kwargs["out"].items): - assert "example_value" in x.proxy.node.meta # type: ignore[attr-defined] - fake_out = x.proxy.node.meta["example_value"] # type: ignore[attr-defined] + name = tx.find_symbolic_locals_name(kwargs["out"]) + if name in tx.symbolic_locals: + tx.symbolic_locals[name] = tensor_variable + elif ( + isinstance(tensor_variable, ConstantVariable) + and tensor_variable.value is None + ): + # Handle out-variant custom ops that return None. + if isinstance(kwargs["out"], TensorVariable): + assert "example_value" in kwargs["out"].proxy.node.meta + fake_out = kwargs["out"].proxy.node.meta["example_value"] if not torch._prims_common.is_contiguous(fake_out): # It's difficult to handle strides correctly in functionalization # when calling an out= op with a non-contiguous out argument unimplemented( - "out= op was called where some of the output tensors were non-contiguous" + "out= op was called where output tensor was non-contiguous" ) - else: - unimplemented(f"out variant of {type(kwargs['out'])}") + elif isinstance(kwargs["out"], ListVariable): + for idx, x in enumerate(kwargs["out"].items): + assert "example_value" in x.proxy.node.meta # type: ignore[attr-defined] + fake_out = x.proxy.node.meta["example_value"] # type: ignore[attr-defined] + if not torch._prims_common.is_contiguous(fake_out): + # It's difficult to handle strides correctly in functionalization + # when calling an out= op with a non-contiguous out argument + unimplemented( + "out= op was called where some of the output tensors were non-contiguous" + ) + else: + unimplemented(f"out variant of {type(kwargs['out'])}") - return tensor_variable + return tensor_variable def _call_ntuple(self, tx: "InstructionTranslator", args, kwargs): """inline behavior of torch.nn.modules.utils._ntuple""" @@ -1157,12 +1146,3 @@ def _nn_param_via_prefix_insert(tx: "InstructionTranslator", data, requires_grad source ) return result - - def torch_function_override_enabled(self, tx, args, kwargs): - return ( - self.get_function() in get_overridable_functions() - or isinstance( - self.get_function(), - (torch._ops.OpOverload, torch._ops.OpOverloadPacket), - ) - ) and can_dispatch_torch_function(tx, args, kwargs) diff --git a/torch/_dynamo/variables/torch_function.py b/torch/_dynamo/variables/torch_function.py index 6e52cc688e0309..4b3188507fe4b5 100644 --- a/torch/_dynamo/variables/torch_function.py +++ b/torch/_dynamo/variables/torch_function.py @@ -1,11 +1,8 @@ # mypy: ignore-errors -import collections -import contextlib import inspect -from typing import Deque, Dict, List, TYPE_CHECKING +from typing import Dict, List, TYPE_CHECKING -import torch._C import torch.utils._pytree as pytree from torch._guards import Source from torch.overrides import _get_overloaded_args, get_default_nowrap_functions @@ -18,7 +15,6 @@ from .base import VariableTracker from .constant import ConstantVariable from .ctx_manager import ContextWrappingVariable -from .lazy import LazyVariableTracker from .lists import TupleVariable from .tensor import TensorSubclassVariable, TensorVariable from .user_defined import UserDefinedObjectVariable @@ -63,67 +59,6 @@ IGNORED_MODES = {DeviceContext} -class SymbolicTorchFunctionState: - def __init__(self, py_stack): - # This is annoyingly complicated because of how the torch function subclass + mode C API was designed - # There are two exposed C knobs here as contexts: torch._C.DisableTorchFunction and torch._C.DisableTorchFunctionSubclass - # These are their definitions: - # 1) torch._C._is_torch_function_enabled indicates that neither of the above knobs have been entered - # (if either are entered, this will be False) - # 2) torch._C._is_torch_function_mode_enabled indicates that either the torch mode stack is empty OR - # torch._C.DisableTorchFunction has been entered - # To disambiguate these and keep myself sane I added a C API to check whether all torch function - # concepts (modes and subclasses) are enabled. - # This only returns true iff we have not entered torch._C.DisableTorchFunction and allows us to separate - # the stack length from the enablement state of torch function modes. - # This is important because now if a mode is pushed while dynamo is tracing, we know whether - # or not torch function modes are enabled and whether we should trace it. - self.torch_function_subclass_enabled = torch._C._is_torch_function_enabled() - - # This differs from the C API of the same name - # this will only be false iff we have entered torch._C.DisableTorchFunction - # and does not take into account the mode stack length, while the C API bundles these - # two concepts - self.torch_function_mode_enabled = ( - not torch._C._is_torch_function_all_disabled() - ) - - self.cur_mode = None - - TorchFunctionModeStackVariable.reset() - - self.mode_stack: Deque[TorchFunctionModeVariable] = collections.deque() - - for i, val in enumerate(py_stack): - self.mode_stack.append( - LazyVariableTracker.create(val, source=TorchFunctionModeStackSource(i)) - ) - - def in_torch_function_mode(self): - return len(self.mode_stack) > 0 - - def pop_torch_function_mode(self): - return self.mode_stack.pop() - - def push_torch_function_mode(self, mode_var): - self.mode_stack.append(mode_var) - - def call_torch_function_mode(self, tx, fn, types, args, kwargs): - with self._pop_mode_for_inlining() as cur_mode: - return cur_mode.call_torch_function(tx, fn, types, args, kwargs) - - @contextlib.contextmanager - def _pop_mode_for_inlining(self): - old_mode = self.cur_mode - self.cur_mode = self.pop_torch_function_mode() - try: - yield self.cur_mode - finally: - mode = self.cur_mode - self.cur_mode = old_mode - self.push_torch_function_mode(mode) - - class TorchFunctionModeStackVariable(VariableTracker): """Fake VT to use as a dummy object, indicating the presence of torch function mode stack mutation""" @@ -153,20 +88,19 @@ def reset(cls): def register_mutation(cls, tx: "InstructionTranslator"): if cls.stack_value_singleton not in tx.output.side_effects: var = cls( - source=Source(), - symbolic_stack=tx.symbolic_torch_function_state.mode_stack, + source=Source(), symbolic_stack=tx.symbolic_torch_function_mode_stack ) tx.output.side_effects.track_mutable(cls.stack_value_singleton, var) tx.output.side_effects.mutation(var) @classmethod def register_device_context_insertion(cls, tx: "InstructionTranslator"): - stack = tx.symbolic_torch_function_state.mode_stack + stack = tx.symbolic_torch_function_mode_stack if stack and cls.is_device_context(stack[0]): return else: cls.offset += 1 - stack.insert( + tx.symbolic_torch_function_mode_stack.insert( 0, TorchFunctionModeVariable( None, source=TorchFunctionModeStackSource(-cls.offset) @@ -175,7 +109,7 @@ def register_device_context_insertion(cls, tx: "InstructionTranslator"): @classmethod def clear_default_device(cls, tx: "InstructionTranslator"): - stack = tx.symbolic_torch_function_state.mode_stack + stack = tx.symbolic_torch_function_mode_stack if stack and cls.is_device_context(stack[0]): stack.popleft() cls.offset -= 1 @@ -190,39 +124,23 @@ def get_mode_index(cls, ind): class TorchFunctionModeVariable(ContextWrappingVariable): - def __init__(self, value, source=None, **kwargs): + def __init__(self, value, **kwargs): super().__init__(value, **kwargs) self.value = value - self.cm_obj = value # needed for BC with calling enter from CM code - self.source = source + + @staticmethod + def get_global_mangled_name(tx, val): + return get_safe_global_name( + tx, f"__torch_function_mode_{val.__class__.__name__}", val + ) def reconstruct(self, codegen): - # This shouldn't be called unless we have a source + # We don't support locally created torch function modes yet assert self.source self.source.reconstruct(codegen) - def module_name(self): - return self.value.__module__ - - def fn_name(self): - return type(self.value).__name__ - - def python_type(self): - return type(self.value) - - def call_torch_function(self, tx: "InstructionTranslator", fn, types, args, kwargs): - return call_torch_function( - tx, - self, - build_torch_function_fn(tx, self.value, self.source), - fn, - types, - args, - kwargs, - ) - - def _call_func(self, tx: "InstructionTranslator", values): - unimplemented("enter/exit for torch function mode NYI") + def _call_func(self, tx, values): + unimplemented("torch function mode context manager is not supported yet") def _get_all_args(args, kwargs): @@ -313,13 +231,9 @@ def build_torch_function_fn(tx: "InstructionTranslator", value, source): def can_dispatch_torch_function(tx: "InstructionTranslator", args, kwargs): - has_overridden_args = any( + return tx.output.torch_function_enabled and any( has_torch_function(arg) for arg in _get_all_args(args, kwargs) ) - tf_state = tx.symbolic_torch_function_state - return (has_overridden_args and tf_state.torch_function_subclass_enabled) or ( - tf_state.torch_function_mode_enabled and tf_state.in_torch_function_mode() - ) def dispatch_torch_function(tx: "InstructionTranslator", fn, args, kwargs): @@ -331,20 +245,11 @@ def dispatch_torch_function(tx: "InstructionTranslator", fn, args, kwargs): _get_subclass_type, ) - types = TupleVariable([_get_subclass_type_var(tx, arg) for arg in overloaded_args]) - - if tx.symbolic_torch_function_state.in_torch_function_mode(): - res = tx.symbolic_torch_function_state.call_torch_function_mode( - tx, fn, types, args, kwargs - ) - if not (isinstance(res, ConstantVariable) and res.value is NotImplemented): - return res - for arg in overloaded_args: res = arg.call_torch_function( tx, fn, - types, + TupleVariable([_get_subclass_type_var(tx, arg) for arg in overloaded_args]), args, kwargs, ) diff --git a/torch/_dynamo/variables/user_defined.py b/torch/_dynamo/variables/user_defined.py index d069d2e060a108..34e1d0d10c9f72 100644 --- a/torch/_dynamo/variables/user_defined.py +++ b/torch/_dynamo/variables/user_defined.py @@ -82,6 +82,11 @@ def is_forbidden_context_manager(ctx): from _pytest.python_api import RaisesContext from _pytest.recwarn import WarningsChecker + # TODO mlazos: Temporary to get this stack to pass + # remove in subsequent PR + from torch.overrides import BaseTorchFunctionMode + + f_ctxs.append(BaseTorchFunctionMode) f_ctxs.append(RaisesContext) f_ctxs.append(WarningsChecker) except ImportError: @@ -408,6 +413,7 @@ def call_function( and self.source and not is_forbidden_context_manager(self.value) ): + # import here to avoid an unfortunate circular dependency. from .ctx_manager import GenericContextWrappingVariable cm_obj = tx.output.side_effects.track_object_new( @@ -415,6 +421,7 @@ def call_function( ) cm_obj.call_method(tx, "__init__", args, kwargs) return cm_obj + elif is_namedtuple_cls(self.value): fields = namedtuple_fields(self.value) # check if this a quasi-namedtuple or a real one diff --git a/torch/_export/non_strict_utils.py b/torch/_export/non_strict_utils.py index ef15e5fea9e973..5c7331659d26a0 100644 --- a/torch/_export/non_strict_utils.py +++ b/torch/_export/non_strict_utils.py @@ -506,11 +506,7 @@ class _NonStrictTorchFunctionHandler(torch.overrides.TorchFunctionMode): def __torch_function__(self, func, types, args=(), kwargs=None): kwargs = kwargs or {} - if ( - not torch.compiler.is_dynamo_compiling() - and log.isEnabledFor(logging.DEBUG) - and config.extended_debug_current_loc - ): + if log.isEnabledFor(logging.DEBUG) and config.extended_debug_current_loc: frame = _find_user_code_frame() if frame is not None: log.debug( From 12fc5c484c4493ac6c1707d6ac37839c5f1570cb Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Fri, 13 Sep 2024 12:52:58 +0000 Subject: [PATCH 0829/1018] Revert "[Dynamo] Use custom backend to reenter metadata tf mode when tracing while/cond (#134732)" This reverts commit e504fb70693d4a3741c3380b6a989d441e84f737. Reverted https://github.com/pytorch/pytorch/pull/134732 on behalf of https://github.com/albanD due to Broke tests on main ([comment](https://github.com/pytorch/pytorch/pull/134732#issuecomment-2348886378)) --- torch/_dynamo/backends/debugging.py | 12 ----- torch/_higher_order_ops/cond.py | 20 +++----- torch/_higher_order_ops/scan.py | 16 ++----- torch/_higher_order_ops/while_loop.py | 21 ++------- torch/fx/experimental/proxy_tensor.py | 67 +++++++++++---------------- torch/nn/attention/flex_attention.py | 42 ++++++----------- 6 files changed, 54 insertions(+), 124 deletions(-) diff --git a/torch/_dynamo/backends/debugging.py b/torch/_dynamo/backends/debugging.py index 18784498862e39..87214f7d59774e 100644 --- a/torch/_dynamo/backends/debugging.py +++ b/torch/_dynamo/backends/debugging.py @@ -31,18 +31,6 @@ def eager(gm, fake_tensor_inputs, **kwargs): return gm.forward -def make_eager_backend_with_torch_function_mode(mode): - """Used to trace HOPs (cond and while) for eager exectution, the metadata - TF mode mutates vars outside of the scope of the HOP, and we can't have graph breaks - in the HOP, so we need to externally run this mode and not trace it.""" - - def fn(gm, fake_tensor_inputs, **kwargs): - with mode: - return gm.forward - - return fn - - @register_backend def eager_noexcept(gm, fake_tensor_inputs, **kwargs): if kwargs: diff --git a/torch/_higher_order_ops/cond.py b/torch/_higher_order_ops/cond.py index 0467e2899adc28..dee400d76f5964 100644 --- a/torch/_higher_order_ops/cond.py +++ b/torch/_higher_order_ops/cond.py @@ -28,7 +28,6 @@ from torch._subclasses.fake_tensor import FakeTensorMode from torch._subclasses.functional_tensor import disable_functional_mode from torch.fx.experimental.proxy_tensor import ( - _temp_remove_metadata_torch_function_mode, _temp_remove_pre_dispatch_torch_function_mode, disable_proxy_modes_tracing, ProxyTorchDispatchMode, @@ -130,10 +129,6 @@ def false_fn(x: torch.Tensor): if torch.compiler.is_dynamo_compiling(): return cond_op(pred, true_fn, false_fn, operands) - from torch._dynamo.backends.debugging import ( - make_eager_backend_with_torch_function_mode, - ) - if isinstance(pred, (bool, int, float)): log.warning( "Pred is a Python constant. When used with torch.cond, it executes only one of the branches." @@ -174,15 +169,12 @@ def _validate_input(pred, true_fn, false_fn, operands): def _cond_op_wrapper(*args, **kwargs): return cond_op(*args, **kwargs) - with _set_compilation_env(), torch._dynamo.utils.disable_cache_limit(), _temp_remove_pre_dispatch_torch_function_mode(): - with _temp_remove_metadata_torch_function_mode() as metadata_mode: - if metadata_mode: - backend = make_eager_backend_with_torch_function_mode(metadata_mode) - else: - backend = "eager" - return torch.compile(_cond_op_wrapper, backend=backend, fullgraph=True)( - pred, true_fn, false_fn, operands - ) + with _set_compilation_env(): + with torch._dynamo.utils.disable_cache_limit(): + with _temp_remove_pre_dispatch_torch_function_mode(): + return torch.compile(_cond_op_wrapper, backend="eager", fullgraph=True)( + pred, true_fn, false_fn, operands + ) def create_fw_bw_graph_branches(true_fn, false_fn, *operands): diff --git a/torch/_higher_order_ops/scan.py b/torch/_higher_order_ops/scan.py index d66cff067f6682..d6eecb0a9f29f9 100644 --- a/torch/_higher_order_ops/scan.py +++ b/torch/_higher_order_ops/scan.py @@ -20,7 +20,6 @@ from torch._ops import HigherOrderOperator from torch._subclasses.fake_tensor import FakeTensorMode from torch.fx.experimental.proxy_tensor import ( - _temp_remove_metadata_torch_function_mode, disable_proxy_modes_tracing, ProxyTorchDispatchMode, track_tensor_tree, @@ -119,19 +118,10 @@ def _scan_op_wrapper(*args, **kwargs): return scan(*args, **kwargs) if not torch._dynamo.is_compiling(): - from torch._dynamo.backends.debugging import ( - make_eager_backend_with_torch_function_mode, - ) - with _set_compilation_env(), torch._dynamo.utils.disable_cache_limit(): - with _temp_remove_metadata_torch_function_mode() as metadata_mode: - if metadata_mode: - backend = make_eager_backend_with_torch_function_mode(metadata_mode) - else: - backend = "eager" - return torch.compile(_scan_op_wrapper, backend=backend, fullgraph=True)( - combine_fn, init, xs, dim=dim, reverse=reverse - ) + return torch.compile(_scan_op_wrapper, backend="eager", fullgraph=True)( + combine_fn, init, xs, dim=dim, reverse=reverse + ) leaves_init, spec_init = pytree.tree_flatten(init) leaves_xs, spec_xs = pytree.tree_flatten(xs) diff --git a/torch/_higher_order_ops/while_loop.py b/torch/_higher_order_ops/while_loop.py index f14321842f40b9..d64f512399d9d5 100644 --- a/torch/_higher_order_ops/while_loop.py +++ b/torch/_higher_order_ops/while_loop.py @@ -15,11 +15,7 @@ ) from torch._ops import HigherOrderOperator from torch._subclasses.fake_tensor import FakeTensorMode -from torch.fx.experimental.proxy_tensor import ( - _temp_remove_metadata_torch_function_mode, - ProxyTorchDispatchMode, - track_tensor_tree, -) +from torch.fx.experimental.proxy_tensor import ProxyTorchDispatchMode, track_tensor_tree class WhileLoopOp(HigherOrderOperator): @@ -117,9 +113,6 @@ def body_fn(iter, x): - 'while_loop' only supports **inference** right now. Autograd will be supported in the future. """ - from torch._dynamo.backends.debugging import ( - make_eager_backend_with_torch_function_mode, - ) # Currently, additional_inputs is not a user-facing input. It will be automatically set in dynamo. # parameters and buffers accessed in cond_fn or body_fn or tensor closures will become additional_inputs. @@ -147,15 +140,9 @@ def _while_loop_op_wrapper(*args, **kwargs): return while_loop_op(*args, **kwargs) with _set_compilation_env(), torch._dynamo.utils.disable_cache_limit(): - with _temp_remove_metadata_torch_function_mode() as metadata_mode: - with _temp_remove_metadata_torch_function_mode() as metadata_mode: - if metadata_mode: - backend = make_eager_backend_with_torch_function_mode(metadata_mode) - else: - backend = "eager" - return torch.compile( - _while_loop_op_wrapper, backend=backend, fullgraph=True - )(cond_fn, body_fn, carried_inputs, additional_inputs) + return torch.compile(_while_loop_op_wrapper, backend="eager", fullgraph=True)( + cond_fn, body_fn, carried_inputs, additional_inputs + ) @while_loop_op.py_impl(DispatchKey.CompositeExplicitAutograd) diff --git a/torch/fx/experimental/proxy_tensor.py b/torch/fx/experimental/proxy_tensor.py index 213672e216e6ce..5a5771d2ae8f7a 100644 --- a/torch/fx/experimental/proxy_tensor.py +++ b/torch/fx/experimental/proxy_tensor.py @@ -17,7 +17,7 @@ import warnings import weakref from collections import defaultdict -from contextlib import _GeneratorContextManager, contextmanager, ExitStack, nullcontext +from contextlib import contextmanager, ExitStack, nullcontext from dataclasses import dataclass from typing import ( Any, @@ -1084,43 +1084,38 @@ def unwrap_proxy(self, e: T) -> object: return e -def _make_temp_remove_mode_context_manager( - mode_ty: Type[TorchFunctionMode], -) -> Callable[[], _GeneratorContextManager[Optional[TorchFunctionMode]]]: - @contextmanager - def context_manager_fn() -> Generator[Optional[TorchFunctionMode], None, None]: - from torch.overrides import _len_torch_function_stack, _pop_mode, _push_mode - - temp_elements = [] - removed_mode = None +@contextmanager +def _temp_remove_pre_dispatch_torch_function_mode() -> Generator[None, None, None]: + from torch.overrides import _len_torch_function_stack, _pop_mode, _push_mode - while _len_torch_function_stack() > 0: - mode = _pop_mode() - if isinstance(mode, mode_ty): - removed_mode = mode - break - else: - temp_elements.append(mode) + temp_elements = [] + pre_dispatch_mode = None - for mode in reversed(temp_elements): - _push_mode(mode) + while _len_torch_function_stack() > 0: + mode = _pop_mode() + if isinstance(mode, PreDispatchTorchFunctionMode): + pre_dispatch_mode = mode + break + else: + temp_elements.append(mode) - try: - yield removed_mode + for mode in reversed(temp_elements): + _push_mode(mode) - finally: - if removed_mode is not None: - count = len(temp_elements) - while count > 0: - mode = _pop_mode() - count -= 1 + try: + yield - temp_elements.append(removed_mode) + finally: + if pre_dispatch_mode is not None: + count = len(temp_elements) + while count > 0: + mode = _pop_mode() + count -= 1 - for mode in reversed(temp_elements): - _push_mode(mode) + temp_elements.append(pre_dispatch_mode) - return context_manager_fn + for mode in reversed(temp_elements): + _push_mode(mode) @torch._disable_dynamo @@ -1235,11 +1230,6 @@ def __torch_function__( return func(*args, **kwargs) -_temp_remove_metadata_torch_function_mode = _make_temp_remove_mode_context_manager( - TorchFunctionMetadataMode -) - - # This mode is **only** used for pre_dispatch tracing. # In particular, we need to make sure that autograd/autocast API's # that do not desugar into dispatcher operators stay in the graph. @@ -1268,11 +1258,6 @@ def __torch_function__( return func(*args, **kwargs) -_temp_remove_pre_dispatch_torch_function_mode = _make_temp_remove_mode_context_manager( - PreDispatchTorchFunctionMode -) - - class ProxyTorchDispatchMode(TorchDispatchMode): # Ensure this is read-only; this exists only for legacy reasons @property diff --git a/torch/nn/attention/flex_attention.py b/torch/nn/attention/flex_attention.py index 2317dbac83bdda..bb04ce53d0fcab 100644 --- a/torch/nn/attention/flex_attention.py +++ b/torch/nn/attention/flex_attention.py @@ -19,7 +19,6 @@ ) from torch._higher_order_ops.utils import _set_compilation_env from torch.fx.experimental.proxy_tensor import ( - _temp_remove_metadata_torch_function_mode, _temp_remove_pre_dispatch_torch_function_mode, ) from torch.nn.attention._utils import _supported_head_dim, _validate_sdpa_input @@ -1034,10 +1033,6 @@ def score_mod( if not torch._dynamo.is_dynamo_supported(): raise RuntimeError("flex_attention requires dynamo support") - from torch._dynamo.backends.debugging import ( - make_eager_backend_with_torch_function_mode, - ) - # Dynamo is expecting a callable with "__code__" attribute. # We cannot directly pass hop to it. So we wrap it in a dummy function. def _flex_attention_hop_wrapper(*args, **kwargs): @@ -1046,25 +1041,18 @@ def _flex_attention_hop_wrapper(*args, **kwargs): with _set_compilation_env(): with torch._dynamo.utils.disable_cache_limit(): with _temp_remove_pre_dispatch_torch_function_mode(): - with _temp_remove_metadata_torch_function_mode() as metadata_mode: - if metadata_mode: - backend = make_eager_backend_with_torch_function_mode( - metadata_mode - ) - else: - backend = "eager" - out, lse = torch.compile( - _flex_attention_hop_wrapper, backend=backend, fullgraph=True - )( - query, - key, - value, - score_mod, - block_mask.as_tuple(), - scale, - kernel_options, - ) - if return_lse: - return out, lse * math.log(2) - else: - return out + out, lse = torch.compile( + _flex_attention_hop_wrapper, backend="eager", fullgraph=True + )( + query, + key, + value, + score_mod, + block_mask.as_tuple(), + scale, + kernel_options, + ) + if return_lse: + return out, lse * math.log(2) + else: + return out From 844bdd35613205c1fb4684aef9c09965683d2f92 Mon Sep 17 00:00:00 2001 From: Laith Sakka Date: Fri, 13 Sep 2024 15:05:41 +0000 Subject: [PATCH 0830/1018] Remove cycle dependency by localizing the import. (#135926) Summary: Since https://www.internalfb.com/diff/D62215095 landed there has been many silence errors due to the dependency between functional_tensor and config. ``` File "/tmp/torch_deploy_zip5YRJC1/torch_python_modules.zip/torch/export/__init__.py", line 64, in File "/tmp/torch_deploy_zip5YRJC1/torch_python_modules.zip/torch/export/dynamic_shapes.py", line 23, in File "/tmp/torch_deploy_zip5YRJC1/torch_python_modules.zip/torch/export/exported_program.py", line 26, in File "/tmp/torch_deploy_zip5YRJC1/torch_python_modules.zip/torch/_higher_order_ops/__init__.py", line 1, in File "/tmp/torch_deploy_zip5YRJC1/torch_python_modules.zip/torch/_higher_order_ops/cond.py", line 6, in File "/tmp/torch_deploy_zip5YRJC1/torch_python_modules.zip/torch/_subclasses/functional_tensor.py", line 9, in File "/tmp/torch_deploy_zip5YRJC1/torch_python_modules.zip/torch/_inductor/config.py", line 44, in ``` https://fburl.com/logarithm/ol5kx0ee complaining about a cycle dependency this fix it. Test Plan: buck test multipy/runtime:test_deploy_embedded_cuda_interp_without_cuda_available -- --run-disabled TorchpyTest.AcquireMultipleSessionsInDifferentPackages Reviewed By: aorenste Differential Revision: D62616765 Pull Request resolved: https://github.com/pytorch/pytorch/pull/135926 Approved by: https://github.com/aorenste, https://github.com/oulgen, https://github.com/Skylion007 --- torch/_subclasses/functional_tensor.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/torch/_subclasses/functional_tensor.py b/torch/_subclasses/functional_tensor.py index 7bfa16a295a5d2..bec1ebc559c328 100644 --- a/torch/_subclasses/functional_tensor.py +++ b/torch/_subclasses/functional_tensor.py @@ -6,7 +6,6 @@ from typing import Any, Callable, ContextManager, Dict, List, Optional, Tuple, Union import torch -import torch._inductor.config as inductor_config import torch.utils._pytree as pytree from torch._C import _functionalization_reapply_views_tls as _reapply_views from torch._ops import _get_dispatch_mode_pre_dispatch @@ -471,6 +470,8 @@ def unwrap(x): # it doesn't matter what mode we use here because # the implementation of do_auto_functionalize doesn't # interact with FunctionalTensorMode at all + import torch._inductor.config as inductor_config + if self.export or not inductor_config.enable_auto_functionalized_v2: return do_auto_functionalize(func, args, kwargs) else: From 566040efd7ac0e32d8b60c986b96a80d32c4852a Mon Sep 17 00:00:00 2001 From: eqy Date: Fri, 13 Sep 2024 15:07:18 +0000 Subject: [PATCH 0831/1018] [CUDA][FP8] Skip rowwise scaling test on sm89 (#135718) Same reason as #https://github.com/pytorch/pytorch/pull/133612, rowwise scaling implementation is sm90+ specific (e.g., uses TMA) Pull Request resolved: https://github.com/pytorch/pytorch/pull/135718 Approved by: https://github.com/Skylion007 --- test/test_matmul_cuda.py | 1 + 1 file changed, 1 insertion(+) diff --git a/test/test_matmul_cuda.py b/test/test_matmul_cuda.py index 70235665cb6fef..dc82e517ca30bd 100644 --- a/test/test_matmul_cuda.py +++ b/test/test_matmul_cuda.py @@ -560,6 +560,7 @@ def test_float8_scale_fast_accum(self, device) -> None: self.assertEqual(out_fp8, out_fp8_s) @unittest.skipIf(not PLATFORM_SUPPORTS_FP8 or IS_WINDOWS, f8_msg) + @unittest.skipIf(not SM90OrLater, "rowwise implementation is currently sm90 specific") @skipIfRocm() @parametrize("use_fast_accum", [True, False]) def test_float8_rowwise_scaling_sanity(self, device, use_fast_accum: bool) -> None: From 1a96117b5454fa5f0cf111a23540e6c08a1da056 Mon Sep 17 00:00:00 2001 From: "Edward Z. Yang" Date: Thu, 12 Sep 2024 20:03:29 -0700 Subject: [PATCH 0832/1018] Fix "expand: SymIntArrayRef expected to contain only concrete integers" in AOTInductor (#135933) Internal xref: https://fb.workplace.com/groups/1075192433118967/permalink/1501860707118802/ Signed-off-by: Edward Z. Yang Pull Request resolved: https://github.com/pytorch/pytorch/pull/135933 Approved by: https://github.com/angelayi --- test/inductor/test_aot_inductor.py | 35 ++++++++++++++++++++++++++++++ torch/_inductor/compile_fx.py | 29 ++++++++++++++----------- 2 files changed, 51 insertions(+), 13 deletions(-) diff --git a/test/inductor/test_aot_inductor.py b/test/inductor/test_aot_inductor.py index 27885796b86766..fe49c04e8469f7 100644 --- a/test/inductor/test_aot_inductor.py +++ b/test/inductor/test_aot_inductor.py @@ -3301,6 +3301,40 @@ def forward(self, x, y): Model(), example_inputs, options=dict(max_autotune=max_autotune) ) + @skip_if_no_torchvision + def test_torchvision_transforms_functional_tensor_resize(self): + import torchvision + + # https://fb.workplace.com/groups/1075192433118967/permalink/1501860707118802/ + class A(torch.nn.Module): + def forward(self, image: torch.Tensor, target_size: torch.Tensor): + target_h, target_w = target_size.tolist() + torch._check(target_h > 0) + torch._check(target_w > 0) + torch._check(target_h <= 4000) + torch._check(target_w <= 4000) + + return torchvision.transforms._functional_tensor.resize( + image, + size=[target_h, target_w], + interpolation="bilinear", + antialias=False, + ) + + model = A() + example_inputs = ( + torch.ones([3, 800, 600], device=self.device), + torch.tensor([448, 336], device=self.device), + ) + dynamic_shapes = { + "image": { + 1: torch.export.Dim("height", min=1, max=4000), + 2: torch.export.Dim("width", min=1, max=4000), + }, + "target_size": None, + } + self.check_model(model, example_inputs, dynamic_shapes=dynamic_shapes) + def test_aoti_debug_printer_codegen(self): # basic addmm model to test codegen for aoti intermediate debug printer class Model(torch.nn.Module): @@ -3627,6 +3661,7 @@ def fail_non_abi_compatible_cuda(is_skip=False): is_skip=True ), "test_size_from_multi_output": fail_stack_allocation(is_skip=True), + "test_torchvision_transforms_functional_tensor_resize": fail_minimal_arrayref_interface(), } # test_failures, xfail by default, set is_skip=True to skip diff --git a/torch/_inductor/compile_fx.py b/torch/_inductor/compile_fx.py index 4c583a4b2a90a9..f79f247db35fc9 100644 --- a/torch/_inductor/compile_fx.py +++ b/torch/_inductor/compile_fx.py @@ -17,6 +17,7 @@ import torch.fx import torch.utils._pytree as pytree from functorch.compile import min_cut_rematerialization_partition +from torch._dispatch.python import enable_python_dispatcher from torch._dynamo import ( compiled_autograd, config as dynamo_config, @@ -400,20 +401,22 @@ def fake_tensor_prop( The created fake mode will be returned. """ - fake_mode = detect_fake_mode(example_inputs) - if not fake_mode: - fake_mode = torch._subclasses.FakeTensorMode(allow_non_fake_inputs=True) - FakeTensorProp(gm, mode=fake_mode).propagate(*example_inputs) - else: - ctx = ( - contextlib.nullcontext() - if not force_allow_non_fake_inputs - else mock.patch.object(fake_mode, "allow_non_fake_inputs", True) - ) - with ctx: # type: ignore[attr-defined] - FakeTensorProp(gm, mode=fake_mode).propagate_dont_convert_inputs( - *example_inputs + # Ensure that decomps that support symbolic shapes are used + with enable_python_dispatcher(): + fake_mode = detect_fake_mode(example_inputs) + if not fake_mode: + fake_mode = torch._subclasses.FakeTensorMode(allow_non_fake_inputs=True) + FakeTensorProp(gm, mode=fake_mode).propagate(*example_inputs) + else: + ctx = ( + contextlib.nullcontext() + if not force_allow_non_fake_inputs + else mock.patch.object(fake_mode, "allow_non_fake_inputs", True) ) + with ctx: # type: ignore[attr-defined] + FakeTensorProp(gm, mode=fake_mode).propagate_dont_convert_inputs( + *example_inputs + ) return fake_mode From 7e1d468885378b0a668ba88592bb1234869ba0e3 Mon Sep 17 00:00:00 2001 From: James Wu Date: Fri, 13 Sep 2024 16:35:51 +0000 Subject: [PATCH 0833/1018] Add remote cache time saved to compilation metrics (#135490) Summary: Record remote cache time saved via frame_phase_timing We add to the "phase" when remote cache hits and saves us time, so that we have a 1:1 correspondence between a frame and time saved. Test Plan: Internally run benchmark, see that it's populated in sandbox table after previous diff lands and logger config is actualized. Show that column exists in table: https://fburl.com/scuba/logger_staging_jjwu_30582a48f1ff9cf5f4ac50a4c40af/fp2te0ff Note that an earlier version of D62105258 had the column as a string so the staging table is a bit messed up. But you can see the most recent samples have the column populates as a float. Reviewed By: aorenste Differential Revision: D62106921 Pull Request resolved: https://github.com/pytorch/pytorch/pull/135490 Approved by: https://github.com/aorenste --- torch/_dynamo/convert_frame.py | 5 +++++ torch/_dynamo/utils.py | 22 ++++++++++++++++++++++ torch/_inductor/codecache.py | 8 +++++++- 3 files changed, 34 insertions(+), 1 deletion(-) diff --git a/torch/_dynamo/convert_frame.py b/torch/_dynamo/convert_frame.py index 3ab2937a9496d4..4d1557e15a7d23 100644 --- a/torch/_dynamo/convert_frame.py +++ b/torch/_dynamo/convert_frame.py @@ -998,6 +998,9 @@ def format_guard_failures() -> str: ] - start_possibly_missed_reinplacing_opportunities ) + remote_cache_time_saved = frame_phase_timing[frame_key].get( + "remote_cache_time_saved", 0 + ) else: guard_count = None shape_env_guard_count = None @@ -1014,6 +1017,7 @@ def format_guard_failures() -> str: # If compilation failed, the entire time is wasted dynamo_time_before_restart = time.time() - start_time possibly_missed_reinplacing_opportunities = None + remote_cache_time_saved = None metrics = CompilationMetrics( str(compile_id), @@ -1043,6 +1047,7 @@ def format_guard_failures() -> str: dynamo_time_before_restart, guarded_code is not None, possibly_missed_reinplacing_opportunities, + remote_cache_time_saved, ) record_compilation_metrics(metrics) torch._dynamo.callback_handler.run_end_callbacks() diff --git a/torch/_dynamo/utils.py b/torch/_dynamo/utils.py index 0ec548d788f868..6ff258ff5657c0 100644 --- a/torch/_dynamo/utils.py +++ b/torch/_dynamo/utils.py @@ -220,6 +220,21 @@ def _add_time_spent(key: str, phase_name: str, time_spent: float) -> None: frame_phase_timing[key][phase_name] += time_spent +# Use frame_phase_timing to record remote_cache_time_saved +# This follows the same principles of key as the other frame phase timings, +# but is incremented by FxGraphCache (and later AOTAutogradCache) directly +def add_remote_cache_time_saved(time_saved_ns: int, is_backward: bool = False) -> None: + key = None + if is_backward: + # Use compile id as the frame key for backwards compilation + key = str(torch._guards.CompileContext.current_compile_id()) + else: + key = str(curr_frame) + # Convert to seconds (as a float) + time_saved = time_saved_ns / 1e9 + _add_time_spent(key, "remote_cache_time_saved", time_saved) + + def get_cache_stats() -> Dict[str, Any]: """Get a bunch of metadata about cache hits and misses to use in chromium events""" cache_stats = { @@ -332,15 +347,20 @@ def dynamo_timed( code_gen_time = frame_phase_timing[compile_id].get( "code_gen", None ) + remote_cache_time_saved = frame_phase_timing[ + compile_id + ].get("remote_cache_time_saved", None) else: inductor_compile_time = None code_gen_time = None + remote_cache_time_saved = None metrics = BwdCompilationMetrics( compile_id, inductor_compile_time, code_gen_time, fail_type, fail_reason, + remote_cache_time_saved, ) record_compilation_metrics(metrics) @@ -779,6 +799,7 @@ class CompilationMetrics: # a compiled frame has_guarded_code: bool possibly_missed_reinplacing_opportunities: Optional[int] + remote_cache_time_saved_s: Optional[float] @dataclasses.dataclass @@ -788,6 +809,7 @@ class BwdCompilationMetrics: code_gen_time_s: Optional[float] fail_type: Optional[str] fail_reason: Optional[str] + remote_cache_time_saved_s: Optional[float] DEFAULT_COMPILATION_METRICS_LIMIT = 64 diff --git a/torch/_inductor/codecache.py b/torch/_inductor/codecache.py index b7f3bbf1e9dc99..26a35a2d801c93 100644 --- a/torch/_inductor/codecache.py +++ b/torch/_inductor/codecache.py @@ -53,7 +53,12 @@ import torch import torch.distributed as dist from torch import SymInt, Tensor -from torch._dynamo.utils import counters, dynamo_timed, get_chromium_event_logger +from torch._dynamo.utils import ( + add_remote_cache_time_saved, + counters, + dynamo_timed, + get_chromium_event_logger, +) from torch._inductor import config, exc, metrics from torch._inductor.codegen.cuda import cuda_env from torch._inductor.codegen.rocm.compile_command import ( @@ -1350,6 +1355,7 @@ def load( # type: ignore[no-untyped-def] cache_event_time = time_ns() if (time_saved_ns := compiled_graph._time_taken_ns) is not None: cache_info["time_saved_ns"] = time_saved_ns + add_remote_cache_time_saved(time_saved_ns, fx_kwargs["is_backward"]) if ( ephemeral_increase := add_ephemeral_timeout_increase_for_distributed( time_saved_ns From 507e9af95442b32689f2bd67eed2b8f9bc1ef3c0 Mon Sep 17 00:00:00 2001 From: drisspg Date: Thu, 12 Sep 2024 19:05:42 -0700 Subject: [PATCH 0834/1018] [FlexAttention] Fix output layout (#135882) We previously only supported the same v_head dim and + qk_head dim. When allowed for different head-dims I accidently kept the same query strides for the output. This PR fixes this bug as well it ensures that we always produce output in the same stride order as the input query. Pull Request resolved: https://github.com/pytorch/pytorch/pull/135882 Approved by: https://github.com/yanboliang, https://github.com/Chillee --- test/inductor/test_flex_attention.py | 46 +++++++++++++++++++ torch/_higher_order_ops/flex_attention.py | 55 +++++++++++++++++++++-- torch/_inductor/kernel/flex_attention.py | 36 ++++++++++++++- 3 files changed, 132 insertions(+), 5 deletions(-) diff --git a/test/inductor/test_flex_attention.py b/test/inductor/test_flex_attention.py index a6a8bfe13d7958..5556d3ee4edb60 100644 --- a/test/inductor/test_flex_attention.py +++ b/test/inductor/test_flex_attention.py @@ -1666,6 +1666,52 @@ def mask_mod(b, h, q, kv): out = func(query, key, value, block_mask=block_mask) out.sum().backward() + @supported_platform + @common_utils.parametrize("mode", ["eager", "inductor"]) + @common_utils.parametrize( + "permute_order", + [ + (0, 1, 2, 3), # Default order + (1, 0, 2, 3), # Reverse order + (0, 2, 1, 3), # Mixed order + (2, 0, 1, 3), # Another mixed order + ], + ) + @common_utils.parametrize("shape", [(2, 1, 128, 16), (4, 2, 64, 16)]) + def test_flex_attention_stride_ordering(self, mode, permute_order, shape): + from torch._inductor.ir import get_stride_order + + # Setup + make_tensor = functools.partial( + torch.randn, + shape, + device="cuda", + dtype=torch.float32, + requires_grad=True, + ) + + # Create and permute tensors + query, key, value = make_tensor(), make_tensor(), make_tensor() + query = query.permute(permute_order) + key = key.permute(permute_order) + value = value.permute(permute_order) + + if mode == "inductor": + func = torch.compile(flex_attention, backend=mode, fullgraph=True) + else: + func = flex_attention + + out = func(query, key, value) + + out_stride_order = get_stride_order(out.stride()) + query_stride_order = get_stride_order(query.stride()) + + self.assertEqual( + out_stride_order, + query_stride_order, + f"Stride order mismatch: out {out_stride_order}, query {query_stride_order}", + ) + @supported_platform @common_utils.parametrize("compile", [True, False]) def test_fully_masked_out_rows_0_check(self, compile: bool): diff --git a/torch/_higher_order_ops/flex_attention.py b/torch/_higher_order_ops/flex_attention.py index 3783b29f117e1e..b28f657b9d5d8f 100644 --- a/torch/_higher_order_ops/flex_attention.py +++ b/torch/_higher_order_ops/flex_attention.py @@ -1,7 +1,7 @@ # mypy: allow-untyped-decorators # mypy: allow-untyped-defs import math -from typing import Any, Callable, Dict, Tuple, Union +from typing import Any, Callable, Dict, Sequence, Tuple, Union import torch import torch.utils._pytree as pytree @@ -23,6 +23,53 @@ from torch.overrides import TorchFunctionMode +# Duplicate of _inductor/kernel/flex_attention.py to avoid circular import +def _construct_strides( + sizes: Sequence[int], + fill_order: Sequence[int], +) -> Sequence[int]: + """From a list of sizes and a fill order, construct the strides of the permuted tensor.""" + # Initialize strides + assert len(sizes) == len( + fill_order + ), "Length of sizes must match the length of the fill order" + strides = [0] * len(sizes) + + # Start with stride 1 for the innermost dimension + current_stride = 1 + + # Iterate through the fill order populating strides + for dim in fill_order: + strides[dim] = current_stride + current_stride *= sizes[dim] + + return strides + + +def _permute_strides(out: torch.Tensor, query_strides: Tuple[int, ...]) -> torch.Tensor: + """ + Create a new tensor with the same data and shape as the input, + but with strides permuted based on the input tensor's stride order. + + Args: + out (torch.Tensor): The output tensor of attention. + query_strides (List[int]): The stride order of the input query tensor + + Returns: + torch.Tensor: A new tensor with same shape and data as the input, + but with strides permuted based on the query tensor's stride order. + """ + from torch._inductor.ir import get_stride_order, stride_order2fill_order + + stride_order = get_stride_order(query_strides) + fill_order = stride_order2fill_order(stride_order) + assert out.storage_offset() == 0, "Only support storage_offset == 0" + out_strides = _construct_strides(out.shape, fill_order) + new_out = out.new_empty(out.shape).as_strided(out.shape, out_strides) + new_out.copy_(out) + return new_out + + class TransformGetItemToIndex(TorchFunctionMode): # This is needed since we want to support calling # A[q_idx], where q_idx is a scalar tensor in score_mod. @@ -244,7 +291,7 @@ def sdpa_dense( score_mod_other_buffers, mask_mod_other_buffers, ) - out = out.contiguous() + out = _permute_strides(out, query.stride()) return out, lse @@ -432,7 +479,9 @@ def flex_attention_fake_tensor_mode( batch_size, num_heads, seq_len_q, dtype=torch.float32 ) out_shape = (batch_size, num_heads, seq_len_q, v_head_dim) - return query.new_empty(out_shape), logsumexp + out = query.new_empty(out_shape) + out = _permute_strides(out, query.stride()) + return out, logsumexp # ---------------------------- Autograd Implementation ---------------------------- diff --git a/torch/_inductor/kernel/flex_attention.py b/torch/_inductor/kernel/flex_attention.py index 0e3d6ae49186df..d6dfca28662b3f 100644 --- a/torch/_inductor/kernel/flex_attention.py +++ b/torch/_inductor/kernel/flex_attention.py @@ -3,7 +3,7 @@ import logging import math -from typing import Any, List, Optional, Tuple +from typing import Any, List, Optional, Sequence, Tuple import sympy @@ -17,9 +17,11 @@ ExternKernel, FixedLayout, FlexibleLayout, + get_stride_order, InputBuffer, IRNode, StorageBox, + stride_order2fill_order, Subgraph, TensorBox, ) @@ -29,6 +31,29 @@ log = logging.getLogger(__name__) aten = torch.ops.aten +Expr = sympy.Expr + + +def construct_strides( + sizes: Sequence[int], + fill_order: Sequence[int], +) -> Sequence[int]: + """From a list of sizes and a fill order, construct the strides of the permuted tensor.""" + # Initialize strides + assert len(sizes) == len( + fill_order + ), "Length of sizes must match the length of the fill order" + strides = [0] * len(sizes) + + # Start with stride 1 for the innermost dimension + current_stride = 1 + + # Iterate through the fill order populating strides + for dim in fill_order: + strides[dim] = current_stride + current_stride *= sizes[dim] + + return strides def flex_attention_grid(batch_size, q_heads, num_queries, d_model, meta): @@ -761,11 +786,18 @@ def flex_attention( # This works because only the last dim differs and we check it is contiguous. q_strides = query.get_stride() assert q_strides[-1] == 1, "Query must be contiguous in the last dimension" + + # Construct output layout with strides matching the query. + out_size = [B, Hq, seq_len_q, v_head_dim] + stride_order = get_stride_order(query.get_stride()) + fill_order = stride_order2fill_order(stride_order) + out_strides = construct_strides(out_size, fill_order) + layout = FixedLayout( query.get_device(), query.get_dtype(), [B, Hq, seq_len_q, v_head_dim], - query.get_stride(), + stride=out_strides, ) # see NOTE:[TritonTemplates with multiple outputs] logsumexp_shape = [B, Hq, seq_len_q] From c16141ecb99f0b52915492c5b80456195f9cc0a9 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Fri, 13 Sep 2024 16:42:37 +0000 Subject: [PATCH 0835/1018] Revert "Fix clang-tidy warnings in Caffe2 code (#134935)" This reverts commit 7cfd23636c8fa6fcbb8bf3ea34e15b847ec9ad9d. Reverted https://github.com/pytorch/pytorch/pull/134935 on behalf of https://github.com/izaitsevfb due to breaks internal builds, caffe2 is still used internally ([comment](https://github.com/pytorch/pytorch/pull/134935#issuecomment-2349368152)) --- .lintrunner.toml | 2 - BUCK.oss | 1 + BUILD.bazel | 1 + build.bzl | 1 + caffe2/core/timer.h | 2 +- caffe2/perfkernels/common_avx.cc | 2 +- caffe2/perfkernels/common_avx2.cc | 2 +- caffe2/serialize/CMakeLists.txt | 3 +- caffe2/serialize/crc_alt.h | 2 +- caffe2/serialize/file_adapter.cc | 11 ++- caffe2/serialize/file_adapter.h | 10 ++- caffe2/serialize/in_memory_adapter.h | 21 ++--- caffe2/serialize/inline_container.cc | 26 +++--- caffe2/serialize/inline_container.h | 14 ++-- caffe2/serialize/inline_container_test.cc | 24 +++--- caffe2/serialize/istream_adapter.cc | 9 ++- caffe2/serialize/istream_adapter.h | 6 +- caffe2/serialize/read_adapter_interface.cc | 10 +++ caffe2/serialize/read_adapter_interface.h | 12 +-- caffe2/serialize/versions.h | 6 +- caffe2/utils/string_utils.cc | 2 +- caffe2/utils/string_utils.h | 2 +- caffe2/utils/threadpool/ThreadPool.h | 6 +- caffe2/utils/threadpool/pthreadpool.cc | 80 +++++++++---------- caffe2/utils/threadpool/pthreadpool_impl.cc | 4 +- test/cpp/jit/test_custom_class.cpp | 1 - .../mobile/compatibility/backport_manager.cpp | 1 - .../compatibility/model_compatibility.cpp | 1 - 28 files changed, 150 insertions(+), 112 deletions(-) create mode 100644 caffe2/serialize/read_adapter_interface.cc diff --git a/.lintrunner.toml b/.lintrunner.toml index 5fa37ef3b6b019..4789991736be17 100644 --- a/.lintrunner.toml +++ b/.lintrunner.toml @@ -210,8 +210,6 @@ include_patterns = [ 'aten/src/ATen/native/nested/*.h', 'c10/**/*.cpp', 'c10/**/*.h', - 'caffe2/**/*.cc', - 'caffe2/**/*.h', 'torch/*.h', 'torch/csrc/*.h', 'torch/csrc/*.cpp', diff --git a/BUCK.oss b/BUCK.oss index 448cce08cd854a..13d5518801a249 100644 --- a/BUCK.oss +++ b/BUCK.oss @@ -65,6 +65,7 @@ cxx_library( "caffe2/serialize/file_adapter.cc", "caffe2/serialize/inline_container.cc", "caffe2/serialize/istream_adapter.cc", + "caffe2/serialize/read_adapter_interface.cc", ], visibility = ["PUBLIC"], deps = [ diff --git a/BUILD.bazel b/BUILD.bazel index e5ab0d0e29b856..1018f7907adecd 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -473,6 +473,7 @@ filegroup( "caffe2/serialize/file_adapter.cc", "caffe2/serialize/inline_container.cc", "caffe2/serialize/istream_adapter.cc", + "caffe2/serialize/read_adapter_interface.cc", ], ) diff --git a/build.bzl b/build.bzl index 3b72864f047ae1..dbb1866ac54823 100644 --- a/build.bzl +++ b/build.bzl @@ -34,6 +34,7 @@ def define_targets(rules): "caffe2/serialize/file_adapter.cc", "caffe2/serialize/inline_container.cc", "caffe2/serialize/istream_adapter.cc", + "caffe2/serialize/read_adapter_interface.cc", ], copts = ["-fexceptions"], tags = [ diff --git a/caffe2/core/timer.h b/caffe2/core/timer.h index 28d373f468d0df..a0384b0dbdbd02 100644 --- a/caffe2/core/timer.h +++ b/caffe2/core/timer.h @@ -3,7 +3,7 @@ #include -#include +#include "caffe2/core/common.h" namespace caffe2 { diff --git a/caffe2/perfkernels/common_avx.cc b/caffe2/perfkernels/common_avx.cc index cdb8b760a0d9e2..08c70d8966bc0f 100644 --- a/caffe2/perfkernels/common_avx.cc +++ b/caffe2/perfkernels/common_avx.cc @@ -2,7 +2,7 @@ // example, if your compiler did not specify -mavx, you should not provide // the CAFFE2_PERF_WITH_AVX macro. -#include "caffe2/core/macros.h" +#include "caffe2/core/common.h" #ifdef CAFFE2_PERF_WITH_AVX #ifndef __AVX__ diff --git a/caffe2/perfkernels/common_avx2.cc b/caffe2/perfkernels/common_avx2.cc index a45d887b2b49ad..3db34aa549d883 100644 --- a/caffe2/perfkernels/common_avx2.cc +++ b/caffe2/perfkernels/common_avx2.cc @@ -2,7 +2,7 @@ // example, if your compiler did not specify -mavx2, you should not provide // the CAFFE2_PERF_WITH_AVX2 macro. -#include "caffe2/core/macros.h" +#include "caffe2/core/common.h" #ifdef CAFFE2_PERF_WITH_AVX2 #ifndef __AVX2__ diff --git a/caffe2/serialize/CMakeLists.txt b/caffe2/serialize/CMakeLists.txt index 30303e8b744cb5..1552b59d0d4419 100644 --- a/caffe2/serialize/CMakeLists.txt +++ b/caffe2/serialize/CMakeLists.txt @@ -6,7 +6,8 @@ list(APPEND Caffe2_CPU_SRCS ${CMAKE_CURRENT_SOURCE_DIR}/inline_container.cc ${CMAKE_CURRENT_SOURCE_DIR}/istream_adapter.cc ${CMAKE_CURRENT_SOURCE_DIR}/file_adapter.cc - ${CMAKE_CURRENT_SOURCE_DIR}/crc.cc) + ${CMAKE_CURRENT_SOURCE_DIR}/crc.cc + ${CMAKE_CURRENT_SOURCE_DIR}/read_adapter_interface.cc) list(APPEND Caffe2_CPU_INCLUDE ${PROJECT_SOURCE_DIR}/third_party/miniz-2.1.0) set(Caffe2_CPU_TEST_SRCS ${Caffe2_CPU_TEST_SRCS} PARENT_SCOPE) diff --git a/caffe2/serialize/crc_alt.h b/caffe2/serialize/crc_alt.h index 10f9c6bdb380ef..9d1c4f1dc7ddc8 100644 --- a/caffe2/serialize/crc_alt.h +++ b/caffe2/serialize/crc_alt.h @@ -25,7 +25,7 @@ // using the aforementioned #defines the table is automatically fitted to your needs // uint8_t, uint32_t, int32_t -#include +#include // size_t #include diff --git a/caffe2/serialize/file_adapter.cc b/caffe2/serialize/file_adapter.cc index f112dd61e5dbf4..67634d7f7fd278 100644 --- a/caffe2/serialize/file_adapter.cc +++ b/caffe2/serialize/file_adapter.cc @@ -3,11 +3,13 @@ #include #include #include +#include "caffe2/core/common.h" -namespace caffe2::serialize { +namespace caffe2 { +namespace serialize { -FileAdapter::RAIIFile::RAIIFile(const std::string& file_name) - : fp_(fopen(file_name.c_str(), "rb")) { +FileAdapter::RAIIFile::RAIIFile(const std::string& file_name) { + fp_ = fopen(file_name.c_str(), "rb"); if (fp_ == nullptr) { auto old_errno = errno; #if defined(_WIN32) && (defined(__MINGW32__) || defined(_MSC_VER)) @@ -75,4 +77,5 @@ size_t FileAdapter::read(uint64_t pos, void* buf, size_t n, const char* what) FileAdapter::~FileAdapter() = default; -} // namespace caffe2::serialize +} // namespace serialize +} // namespace caffe2 diff --git a/caffe2/serialize/file_adapter.h b/caffe2/serialize/file_adapter.h index 36663e3ea8dee2..36bc5de2710034 100644 --- a/caffe2/serialize/file_adapter.h +++ b/caffe2/serialize/file_adapter.h @@ -1,11 +1,14 @@ #pragma once +#include +#include #include -#include +#include "caffe2/serialize/istream_adapter.h" #include "caffe2/serialize/read_adapter_interface.h" -namespace caffe2::serialize { +namespace caffe2 { +namespace serialize { class TORCH_API FileAdapter final : public ReadAdapterInterface { public: @@ -29,4 +32,5 @@ class TORCH_API FileAdapter final : public ReadAdapterInterface { uint64_t size_; }; -} // namespace caffe2::serialize +} // namespace serialize +} // namespace caffe2 diff --git a/caffe2/serialize/in_memory_adapter.h b/caffe2/serialize/in_memory_adapter.h index 869ead3b24563a..817f3f5d677493 100644 --- a/caffe2/serialize/in_memory_adapter.h +++ b/caffe2/serialize/in_memory_adapter.h @@ -1,9 +1,10 @@ #pragma once -#include -#include #include +#include -namespace caffe2::serialize { + +namespace caffe2 { +namespace serialize { class MemoryReadAdapter final : public caffe2::serialize::ReadAdapterInterface { public: @@ -14,18 +15,18 @@ class MemoryReadAdapter final : public caffe2::serialize::ReadAdapterInterface { return size_; } - size_t read( - uint64_t pos, - void* buf, - size_t n, - const char* what [[maybe_unused]] = "") const override { + size_t read(uint64_t pos, void* buf, size_t n, const char* what = "") + const override { + (void) what; memcpy(buf, (int8_t*)(data_) + pos, n); return n; } private: const void* data_; - off_t size_{}; + off_t size_; }; -} // namespace caffe2::serialize + +} // namespace serialize +} // namespace caffe2 diff --git a/caffe2/serialize/inline_container.cc b/caffe2/serialize/inline_container.cc index c232838dec0c42..2761147cf333da 100644 --- a/caffe2/serialize/inline_container.cc +++ b/caffe2/serialize/inline_container.cc @@ -10,11 +10,15 @@ #include #include +#include +#include #include +#include #include #include #include +#include "caffe2/core/common.h" #include "caffe2/serialize/file_adapter.h" #include "caffe2/serialize/inline_container.h" #include "caffe2/serialize/istream_adapter.h" @@ -23,8 +27,8 @@ #include "caffe2/serialize/versions.h" #include "miniz.h" - -namespace caffe2::serialize { +namespace caffe2 { +namespace serialize { constexpr c10::string_view kDebugPklSuffix(".debug_pkl"); struct MzZipReaderIterWrapper { @@ -190,7 +194,8 @@ void PyTorchStreamReader::init() { // version check at::DataPtr version_ptr; - size_t version_size = 0; + // NOLINTNEXTLINE(cppcoreguidelines-init-variables) + size_t version_size; if (hasRecord(".data/version")) { std::tie(version_ptr, version_size) = getRecord(".data/version"); } else { @@ -199,7 +204,7 @@ void PyTorchStreamReader::init() { } std::string version(static_cast(version_ptr.get()), version_size); try { - version_ = std::stoll(version); + version_ = std::stoull(version); } catch (const std::invalid_argument& e) { CAFFE_THROW("Couldn't parse the version ", version, @@ -622,7 +627,7 @@ PyTorchStreamWriter::PyTorchStreamWriter(const std::string& file_name) } PyTorchStreamWriter::PyTorchStreamWriter( - const std::function& writer_func) + const std::function writer_func) : archive_name_("archive"), writer_func_(writer_func) { setup(archive_name_); @@ -633,7 +638,7 @@ void PyTorchStreamWriter::setup(const string& file_name) { memset(ar_.get(), 0, sizeof(mz_zip_archive)); archive_name_plus_slash_ = archive_name_ + "/"; // for writeRecord(). - if (archive_name_.empty()) { + if (archive_name_.size() == 0) { CAFFE_THROW("invalid file name: ", file_name); } if (!writer_func_) { @@ -644,7 +649,7 @@ void PyTorchStreamWriter::setup(const string& file_name) { const std::string dir_name = parentdir(file_name); if(!dir_name.empty()) { - struct stat st{}; + struct stat st; bool dir_exists = (stat(dir_name.c_str(), &st) == 0 && (st.st_mode & S_IFDIR)); TORCH_CHECK(dir_exists, "Parent directory ", dir_name, " does not exist."); } @@ -701,8 +706,8 @@ void PyTorchStreamWriter::writeRecord( /*uncomp_size=*/0, /*uncomp_crc32=*/0, /*last_modified=*/nullptr, - /*user_extra_data_local=*/padding_.c_str(), - /*user_extra_data_local_len=*/padding_size, + /*user_extra_data=*/padding_.c_str(), + /*user_extra_data_len=*/padding_size, /*user_extra_data_central=*/nullptr, /*user_extra_data_central_len=*/0); valid("writing file ", name.c_str()); @@ -815,4 +820,5 @@ PyTorchStreamWriter::~PyTorchStreamWriter() { } } -} // namespace caffe2::serialize +} // namespace serialize +} // namespace caffe2 diff --git a/caffe2/serialize/inline_container.h b/caffe2/serialize/inline_container.h index fe319e9011bc32..6a13d414feb9e6 100644 --- a/caffe2/serialize/inline_container.h +++ b/caffe2/serialize/inline_container.h @@ -6,6 +6,7 @@ #include #include #include +#include #include #include @@ -90,8 +91,8 @@ typedef struct mz_zip_archive mz_zip_archive; // model.json as the last file when writing after we have accumulated all // other information. - -namespace caffe2::serialize { +namespace caffe2 { +namespace serialize { static constexpr const char* kSerializationIdRecordName = ".data/serialization_id"; @@ -195,18 +196,18 @@ class TORCH_API PyTorchStreamReader final { std::string archive_name_; std::string archive_name_plus_slash_; std::shared_ptr in_; - int64_t version_{}; + int64_t version_; std::mutex reader_lock_; bool load_debug_symbol_ = true; std::string serialization_id_; - size_t additional_reader_size_threshold_{}; + size_t additional_reader_size_threshold_; }; class TORCH_API PyTorchStreamWriter final { public: explicit PyTorchStreamWriter(const std::string& archive_name); explicit PyTorchStreamWriter( - const std::function& writer_func); + const std::function writer_func); void setMinVersion(const uint64_t version); @@ -273,4 +274,5 @@ size_t getPadding( std::string& padding_buf); } // namespace detail -} // namespace caffe2::serialize +} // namespace serialize +} // namespace caffe2 diff --git a/caffe2/serialize/inline_container_test.cc b/caffe2/serialize/inline_container_test.cc index 4e58c0d609932b..4e027f681961d0 100644 --- a/caffe2/serialize/inline_container_test.cc +++ b/caffe2/serialize/inline_container_test.cc @@ -5,14 +5,12 @@ #include -#include -#include #include "caffe2/serialize/inline_container.h" -#include "caffe2/serialize/istream_adapter.h" - +#include +#include "c10/util/irange.h" -// NOLINTBEGIN(*-narrowing-conversions) -namespace caffe2::serialize { +namespace caffe2 { +namespace serialize { namespace { TEST(PyTorchStreamWriterAndReader, SaveAndLoad) { @@ -21,7 +19,7 @@ TEST(PyTorchStreamWriterAndReader, SaveAndLoad) { std::ostringstream oss; // write records through writers PyTorchStreamWriter writer([&](const void* b, size_t n) -> size_t { - oss.write(static_cast(b), static_cast(n)); + oss.write(static_cast(b), n); return oss ? n : 0; }); // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init,cppcoreguidelines-avoid-magic-numbers) @@ -30,14 +28,14 @@ TEST(PyTorchStreamWriterAndReader, SaveAndLoad) { std::vector buf(data1.size()); for (auto i : c10::irange(data1.size())) { - data1[i] = static_cast(data1.size() - i); + data1[i] = data1.size() - i; } writer.writeRecord("key1", data1.data(), data1.size()); // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init,cppcoreguidelines-avoid-magic-numbers) std::array data2; for (auto i : c10::irange(data2.size())) { - data2[i] = static_cast(data2.size() - i); + data2[i] = data2.size() - i; } writer.writeRecord("key2", data2.data(), data2.size()); @@ -151,7 +149,7 @@ TEST(PyTorchStreamWriterAndReader, LoadWithMultiThreads) { PyTorchStreamReader reader(&iss); reader.setAdditionalReaderSizeThreshold(0); // before testing, sanity check - int64_t size1 = 0, size2 = 0, ret = 0; + int64_t size1, size2, ret; at::DataPtr data_ptr; std::tie(data_ptr, size1) = reader.getRecord("key1"); std::tie(data_ptr, size2) = reader.getRecord("key2"); @@ -298,7 +296,7 @@ TEST(PytorchStreamWriterAndReader, SkipDebugRecords) { reader.setShouldLoadDebugSymbol(false); EXPECT_FALSE(reader.hasRecord("key1.debug_pkl")); at::DataPtr ptr; - size_t size = 0; + size_t size; std::tie(ptr, size) = reader.getRecord("key1.debug_pkl"); EXPECT_EQ(size, 0); std::vector dst(data1.size()); @@ -481,5 +479,5 @@ TEST_P(ChunkRecordIteratorTest, ChunkRead) { } } // namespace -} // namespace caffe2::serialize -// NOLINTEND(*-narrowing-conversions) +} // namespace serialize +} // namespace caffe2 diff --git a/caffe2/serialize/istream_adapter.cc b/caffe2/serialize/istream_adapter.cc index baad0e9b2ee69b..9509a088736ef8 100644 --- a/caffe2/serialize/istream_adapter.cc +++ b/caffe2/serialize/istream_adapter.cc @@ -1,7 +1,8 @@ #include "caffe2/serialize/istream_adapter.h" #include -namespace caffe2::serialize { +namespace caffe2 { +namespace serialize { IStreamAdapter::IStreamAdapter(std::istream* istream) : istream_(istream) {} @@ -32,6 +33,8 @@ void IStreamAdapter::validate(const char* what) const { } } -IStreamAdapter::~IStreamAdapter() = default; +// NOLINTNEXTLINE(modernize-use-equals-default) +IStreamAdapter::~IStreamAdapter() {} -} // namespace caffe2::serialize +} // namespace serialize +} // namespace caffe2 diff --git a/caffe2/serialize/istream_adapter.h b/caffe2/serialize/istream_adapter.h index 69f91faf3a79f3..680c288a15f2e8 100644 --- a/caffe2/serialize/istream_adapter.h +++ b/caffe2/serialize/istream_adapter.h @@ -5,7 +5,8 @@ #include "c10/macros/Macros.h" #include "caffe2/serialize/read_adapter_interface.h" -namespace caffe2::serialize { +namespace caffe2 { +namespace serialize { // this is a reader implemented by std::istream class TORCH_API IStreamAdapter final : public ReadAdapterInterface { @@ -22,4 +23,5 @@ class TORCH_API IStreamAdapter final : public ReadAdapterInterface { void validate(const char* what) const; }; -} // namespace caffe2::serialize +} // namespace serialize +} // namespace caffe2 diff --git a/caffe2/serialize/read_adapter_interface.cc b/caffe2/serialize/read_adapter_interface.cc new file mode 100644 index 00000000000000..e9b84660efd72f --- /dev/null +++ b/caffe2/serialize/read_adapter_interface.cc @@ -0,0 +1,10 @@ +#include "caffe2/serialize/read_adapter_interface.h" + +namespace caffe2 { +namespace serialize { + +// NOLINTNEXTLINE(modernize-use-equals-default) +ReadAdapterInterface::~ReadAdapterInterface() {} + +} // namespace serialize +} // namespace caffe2 diff --git a/caffe2/serialize/read_adapter_interface.h b/caffe2/serialize/read_adapter_interface.h index 450db5c9b739ca..0a6b5b74a762e7 100644 --- a/caffe2/serialize/read_adapter_interface.h +++ b/caffe2/serialize/read_adapter_interface.h @@ -3,19 +3,21 @@ #include #include -#include +#include "c10/macros/Macros.h" -namespace caffe2::serialize { +namespace caffe2 { +namespace serialize { // this is the interface for the (file/stream/memory) reader in -// PyTorchStreamReader. With this interface, we can extend the support +// PyTorchStreamReader. with this interface, we can extend the support // besides standard istream class TORCH_API ReadAdapterInterface { public: virtual size_t size() const = 0; virtual size_t read(uint64_t pos, void* buf, size_t n, const char* what = "") const = 0; - virtual ~ReadAdapterInterface() = default; + virtual ~ReadAdapterInterface(); }; -} // namespace caffe2::serialize +} // namespace serialize +} // namespace caffe2 diff --git a/caffe2/serialize/versions.h b/caffe2/serialize/versions.h index aa3a6b1753aa67..6e2c27adc8fae8 100644 --- a/caffe2/serialize/versions.h +++ b/caffe2/serialize/versions.h @@ -1,7 +1,8 @@ #pragma once #include -namespace caffe2::serialize { +namespace caffe2 { +namespace serialize { constexpr uint64_t kMinSupportedFileFormatVersion = 0x1L; @@ -128,4 +129,5 @@ constexpr uint64_t kProducedBytecodeVersion = 0x8L; constexpr uint64_t kMinSupportedBytecodeVersion = 0x4L; constexpr uint64_t kMaxSupportedBytecodeVersion = 0x9L; -} // namespace caffe2::serialize +} // namespace serialize +} // namespace caffe2 diff --git a/caffe2/utils/string_utils.cc b/caffe2/utils/string_utils.cc index 8e31ada6e12ec7..ba763738f7bb04 100644 --- a/caffe2/utils/string_utils.cc +++ b/caffe2/utils/string_utils.cc @@ -51,7 +51,7 @@ size_t editDistance( (c)=(uint8_t)(s)[(i)++]; \ } -size_t editDistanceHelper(const char* s1, +int32_t editDistanceHelper(const char* s1, size_t s1_len, const char* s2, size_t s2_len, diff --git a/caffe2/utils/string_utils.h b/caffe2/utils/string_utils.h index 677e25deed76e2..6bf3b867b5434e 100644 --- a/caffe2/utils/string_utils.h +++ b/caffe2/utils/string_utils.h @@ -39,7 +39,7 @@ TORCH_API inline bool EndsWith( } } -TORCH_API size_t editDistanceHelper( +TORCH_API int32_t editDistanceHelper( const char* s1, size_t s1_len, const char* s2, diff --git a/caffe2/utils/threadpool/ThreadPool.h b/caffe2/utils/threadpool/ThreadPool.h index b5bc9dd44fddaa..dbaefc51389ff9 100644 --- a/caffe2/utils/threadpool/ThreadPool.h +++ b/caffe2/utils/threadpool/ThreadPool.h @@ -1,11 +1,15 @@ #ifndef CAFFE2_UTILS_THREADPOOL_H_ #define CAFFE2_UTILS_THREADPOOL_H_ +#include "ThreadPoolCommon.h" + +#include #include #include #include +#include -#include +#include "caffe2/core/common.h" // // A work-stealing threadpool loosely based off of pthreadpool diff --git a/caffe2/utils/threadpool/pthreadpool.cc b/caffe2/utils/threadpool/pthreadpool.cc index 2f2866a77d4a9d..b8c6c7cebb8e8a 100644 --- a/caffe2/utils/threadpool/pthreadpool.cc +++ b/caffe2/utils/threadpool/pthreadpool.cc @@ -1,9 +1,9 @@ /* Standard C headers */ -#include - -#include -#include -#include +#include +#include +#include +#include +#include #include #ifdef _MSC_VER @@ -71,8 +71,8 @@ void legacy_pthreadpool_compute_1d_tiled( } struct compute_2d_context { - legacy_pthreadpool_function_2d_t function{}; - void* argument{}; + legacy_pthreadpool_function_2d_t function; + void* argument; caffe2::FixedDivisor range_j; }; @@ -80,8 +80,8 @@ static void compute_2d(void* context_, size_t linear_index) { TORCH_DCHECK_LE(linear_index, std::numeric_limits::max()); const struct compute_2d_context* context = static_cast(context_); - int32_t q = 0; - int32_t r = 0; + int32_t q; + int32_t r; context->range_j.DivMod(static_cast(linear_index), &q, &r); context->function(context->argument, q, r); } @@ -112,18 +112,18 @@ void legacy_pthreadpool_compute_2d( } struct compute_2d_tiled_context { - legacy_pthreadpool_function_2d_tiled_t function{}; - void* argument{}; + legacy_pthreadpool_function_2d_tiled_t function; + void* argument; caffe2::FixedDivisor tile_range_j; - size_t range_i{}; - size_t range_j{}; - size_t tile_i{}; - size_t tile_j{}; + size_t range_i; + size_t range_j; + size_t tile_i; + size_t tile_j; }; static void compute_2d_tiled(void* context_, size_t linear_index) { - int32_t q = 0; - int32_t r = 0; + int32_t q; + int32_t r; const struct compute_2d_tiled_context* context = static_cast(context_); context->tile_range_j.DivMod(linear_index, &q, &r); @@ -172,26 +172,26 @@ void legacy_pthreadpool_compute_2d_tiled( } struct compute_3d_tiled_context { - legacy_pthreadpool_function_3d_tiled_t function{}; - void* argument{}; + legacy_pthreadpool_function_3d_tiled_t function; + void* argument; caffe2::FixedDivisor tile_range_j; caffe2::FixedDivisor tile_range_k; - size_t range_i{}; - size_t range_j{}; - size_t range_k{}; - size_t tile_i{}; - size_t tile_j{}; - size_t tile_k{}; + size_t range_i; + size_t range_j; + size_t range_k; + size_t tile_i; + size_t tile_j; + size_t tile_k; }; static void compute_3d_tiled( void* context_, size_t linear_index) { - int32_t tile_index_ij = 0, tile_index_k = 0; + int32_t tile_index_ij, tile_index_k; const struct compute_3d_tiled_context* context = static_cast(context_); context->tile_range_k.DivMod( static_cast(linear_index), &tile_index_ij, &tile_index_k); - int32_t tile_index_i = 0, tile_index_j = 0; + int32_t tile_index_i, tile_index_j; context->tile_range_j.DivMod(tile_index_ij, &tile_index_i, &tile_index_j); const size_t max_tile_i = context->tile_i; const size_t max_tile_j = context->tile_j; @@ -261,31 +261,31 @@ void legacy_pthreadpool_compute_3d_tiled( } struct compute_4d_tiled_context { - legacy_pthreadpool_function_4d_tiled_t function{}; - void* argument{}; + legacy_pthreadpool_function_4d_tiled_t function; + void* argument; caffe2::FixedDivisor tile_range_kl; caffe2::FixedDivisor tile_range_j; caffe2::FixedDivisor tile_range_l; - size_t range_i{}; - size_t range_j{}; - size_t range_k{}; - size_t range_l{}; - size_t tile_i{}; - size_t tile_j{}; - size_t tile_k{}; - size_t tile_l{}; + size_t range_i; + size_t range_j; + size_t range_k; + size_t range_l; + size_t tile_i; + size_t tile_j; + size_t tile_k; + size_t tile_l; }; static void compute_4d_tiled( void* context_, size_t linear_index) { - int32_t tile_index_ij = 0, tile_index_kl = 0; + int32_t tile_index_ij, tile_index_kl; const struct compute_4d_tiled_context* context = static_cast(context_); context->tile_range_kl.DivMod( static_cast(linear_index), &tile_index_ij, &tile_index_kl); - int32_t tile_index_i = 0, tile_index_j = 0; + int32_t tile_index_i, tile_index_j; context->tile_range_j.DivMod(tile_index_ij, &tile_index_i, &tile_index_j); - int32_t tile_index_k = 0, tile_index_l = 0; + int32_t tile_index_k, tile_index_l; context->tile_range_l.DivMod(tile_index_kl, &tile_index_k, &tile_index_l); const size_t max_tile_i = context->tile_i; const size_t max_tile_j = context->tile_j; diff --git a/caffe2/utils/threadpool/pthreadpool_impl.cc b/caffe2/utils/threadpool/pthreadpool_impl.cc index 79597b5a2457e7..ae031ca2ae7e60 100644 --- a/caffe2/utils/threadpool/pthreadpool_impl.cc +++ b/caffe2/utils/threadpool/pthreadpool_impl.cc @@ -7,7 +7,9 @@ namespace caffe2 { namespace { static thread_local bool using_new_threadpool{false}; } -WithCastToNewThreadPool::WithCastToNewThreadPool(bool use_new_threadpool) : use_new_threadpool_(using_new_threadpool) { +WithCastToNewThreadPool::WithCastToNewThreadPool(bool use_new_threadpool) { + use_new_threadpool_ = using_new_threadpool; + using_new_threadpool = use_new_threadpool; } WithCastToNewThreadPool::~WithCastToNewThreadPool() { using_new_threadpool = use_new_threadpool_; diff --git a/test/cpp/jit/test_custom_class.cpp b/test/cpp/jit/test_custom_class.cpp index 42234cddf4e5d3..f7fe339b561435 100644 --- a/test/cpp/jit/test_custom_class.cpp +++ b/test/cpp/jit/test_custom_class.cpp @@ -1,6 +1,5 @@ #include -#include #include #include #include diff --git a/torch/csrc/jit/mobile/compatibility/backport_manager.cpp b/torch/csrc/jit/mobile/compatibility/backport_manager.cpp index 17145fee78a6b5..3d71329e61240d 100644 --- a/torch/csrc/jit/mobile/compatibility/backport_manager.cpp +++ b/torch/csrc/jit/mobile/compatibility/backport_manager.cpp @@ -2,7 +2,6 @@ #include #include #include -#include #include #include #include diff --git a/torch/csrc/jit/mobile/compatibility/model_compatibility.cpp b/torch/csrc/jit/mobile/compatibility/model_compatibility.cpp index d87660238232a0..b8b1ca6adc0dc4 100644 --- a/torch/csrc/jit/mobile/compatibility/model_compatibility.cpp +++ b/torch/csrc/jit/mobile/compatibility/model_compatibility.cpp @@ -1,7 +1,6 @@ #include #include #include -#include #include // removed after using simple type_resolver/obj_loader #include #include From 2423d452eb9ea8db6c99171f5133fca56ba160a8 Mon Sep 17 00:00:00 2001 From: Jack Taylor <108682042+jataylo@users.noreply.github.com> Date: Fri, 13 Sep 2024 16:45:39 +0000 Subject: [PATCH 0836/1018] [ROCm] Enable ROCm support for inductor's dynamic_rblock_scaling (#129663) As of ROCm 6.1 [hipDeviceProp_t::regsPerMultiprocessor](https://rocm.docs.amd.com/projects/HIP/en/latest/doxygen/html/structhip_device_prop__t.html#a7390d5b180d63978c81aa971060270b4) is now available allowing us to enable this attribute on ROCm. ``` >>> torch.cuda.get_device_properties(0) _CudaDeviceProperties(name='AMD Instinct MI250X/MI250', major=9, minor=0, gcnArchName='gfx90a:sramecc+:xnack-', total_memory=65520MB, multi_processor_count=104) >>> torch.cuda.get_device_properties(0).regs_per_multiprocessor 65536 ``` With https://github.com/triton-lang/triton/pull/3962we can extract n_regs and n_spells from a triton binary with AMD backend allowing us to enable inductor's dynamic_rblock_scaling on ROCm initially implemented in https://github.com/pytorch/pytorch/pull/115094 Leaving this in draft until following PRs have landed: - https://github.com/pytorch/pytorch/pull/129361 to bump the triton commit pin - https://github.com/pytorch/pytorch/pull/128449 to allow us to grab warp_size from device properties instead of hard coding 64 on ROCm. Pull Request resolved: https://github.com/pytorch/pytorch/pull/129663 Approved by: https://github.com/jansel, https://github.com/shunting314 --- torch/_inductor/runtime/hints.py | 14 +++++++++++--- torch/_inductor/runtime/triton_heuristics.py | 9 +++++---- torch/csrc/cuda/Module.cpp | 5 ++--- 3 files changed, 18 insertions(+), 10 deletions(-) diff --git a/torch/_inductor/runtime/hints.py b/torch/_inductor/runtime/hints.py index d3c9f560fd2a1e..0f1495e49972c7 100644 --- a/torch/_inductor/runtime/hints.py +++ b/torch/_inductor/runtime/hints.py @@ -104,24 +104,32 @@ class DeviceProperties(typing.NamedTuple): regs_per_multiprocessor: Optional[int] = None max_threads_per_multi_processor: Optional[int] = None multi_processor_count: Optional[int] = None + warp_size: Optional[int] = None @classmethod def create(cls, device): import torch from torch._dynamo.device_interface import get_interface_for_device - device_type = device.type if torch.version.hip is None else "hip" + device_type = device.type + + if torch.version.hip and device_type == "cuda": + device_type = "hip" + device_interface = get_interface_for_device(device) - if device_type == "cuda": + if device_type in ["cuda", "hip"]: props = device_interface.get_device_properties(device) return cls( type=device_type, index=device.index, cc=device_interface.get_compute_capability(device), major=props.major, - regs_per_multiprocessor=props.regs_per_multiprocessor, + regs_per_multiprocessor=props.regs_per_multiprocessor + if hasattr(props, "regs_per_multiprocessor") + else None, max_threads_per_multi_processor=props.max_threads_per_multi_processor, multi_processor_count=props.multi_processor_count, + warp_size=props.warp_size, ) return cls( type=device_type, diff --git a/torch/_inductor/runtime/triton_heuristics.py b/torch/_inductor/runtime/triton_heuristics.py index c0b0e20e906d5e..89c47838da43f3 100644 --- a/torch/_inductor/runtime/triton_heuristics.py +++ b/torch/_inductor/runtime/triton_heuristics.py @@ -265,10 +265,11 @@ def precompile(self, warm_cache_only=False): self.inductor_meta.get("dynamic_scale_rblock", True) and self.heuristic_type == HeuristicType.REDUCTION and self.size_hints is not None - # Disable for AMDGPU/Intel as Triton is not ready to return n_regs for a compiled_binary. - and device_prop.type == "cuda" + # Disable for Intel as Triton is not ready to return n_regs for a compiled_binary. + and device_prop.type in ["cuda", "hip"] and device_prop.major - and device_prop.major >= 8 + and (device_prop.major >= 8 or torch.version.hip) + and device_prop.regs_per_multiprocessor is not None ): assert device_prop.regs_per_multiprocessor assert device_prop.max_threads_per_multi_processor @@ -305,7 +306,7 @@ def precompile(self, warm_cache_only=False): ): continue - nreg_per_warp = nreg * 32 + nreg_per_warp = nreg * device_prop.warp_size nreg_per_block = nreg_per_warp * triton_config.num_warps # Previously we set max_blocks_per_sm to 'max_threads_per_multi_processo / (32 * num_warps)' diff --git a/torch/csrc/cuda/Module.cpp b/torch/csrc/cuda/Module.cpp index 22cbadc42fc09f..461a23e651924b 100644 --- a/torch/csrc/cuda/Module.cpp +++ b/torch/csrc/cuda/Module.cpp @@ -989,11 +989,10 @@ static void registerCudaDeviceProperties(PyObject* module) { "max_threads_per_multi_processor", &cudaDeviceProp::maxThreadsPerMultiProcessor) .def_readonly("warp_size", &cudaDeviceProp::warpSize) -#if !USE_ROCM - // NVIDA only property +#if (defined(USE_ROCM) && ROCM_VERSION >= 60100) || !USE_ROCM .def_readonly( "regs_per_multiprocessor", &cudaDeviceProp::regsPerMultiprocessor) -#endif // USE_ROCM +#endif // HIP-only property; reuse name attribute for CUDA builds .def_readonly( "gcnArchName", From 4eae0d92efdabe65f0d7a848945ac4da56a28def Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Fri, 13 Sep 2024 17:09:45 +0000 Subject: [PATCH 0837/1018] Revert "Pass ideep:lowp_kind to matmul_forward::compute on cache misses (#135058)" This reverts commit 3d2431380999252d5401f83d5010b398a32e7597. Reverted https://github.com/pytorch/pytorch/pull/135058 on behalf of https://github.com/malfet due to It regresses x86 performance ([comment](https://github.com/pytorch/pytorch/pull/135058#issuecomment-2349480861)) --- .../native/quantized/cpu/qlinear_dynamic.cpp | 19 ++++--------------- 1 file changed, 4 insertions(+), 15 deletions(-) diff --git a/aten/src/ATen/native/quantized/cpu/qlinear_dynamic.cpp b/aten/src/ATen/native/quantized/cpu/qlinear_dynamic.cpp index 8a100d516446b8..b78d3cc9a4f566 100644 --- a/aten/src/ATen/native/quantized/cpu/qlinear_dynamic.cpp +++ b/aten/src/ATen/native/quantized/cpu/qlinear_dynamic.cpp @@ -590,21 +590,10 @@ at::Tensor PackedLinearWeightsOnednn::apply_dynamic_impl( LinearParams& params = get_cache().get_param(); ideep::matmul_forward::compute(params, x, w, b, y, src_scales, src_zero_point); } else { - ideep::matmul_forward::compute( - x, - w, - b, - y, - src_scales, - weights_scales, - ideep::scale_t(), - src_zero_point, - ideep::zero_point_t(), - 1.0f, - 1.0f, - op_attr, - ideep::tensor::data_type::undef, - std::is_signed_v ? ideep::s8s8 : ideep::u8s8); + ideep::matmul_forward::compute(x, w, b, y, + src_scales, weights_scales, ideep::scale_t(), + src_zero_point, ideep::zero_point_t(), + 1.0f, 1.0f, op_attr); } auto out_sizes = input.sizes().vec(); out_sizes.back() = w.get_dim(1); From f613c03608e20e5ba25ccc747d5854e655195c49 Mon Sep 17 00:00:00 2001 From: Rachel Guo Date: Fri, 13 Sep 2024 17:19:25 +0000 Subject: [PATCH 0838/1018] [AOTI][Tooling] Add stats summary (mean/min/max, etc) for jit inductor tensor value printing (#135887) Summary: As title. Follow up to add stats summary (mean/min/max, etc) for jit inductor tensor value printing as well. The inductor python wrapper code level printing would look something like this: {F1859224287} Test Plan: CI Reviewed By: chenyang78 Differential Revision: D62415575 Pull Request resolved: https://github.com/pytorch/pytorch/pull/135887 Approved by: https://github.com/chenyang78 --- torch/_inductor/codegen/debug_utils.py | 24 ++++++++++++++++++++---- torch/_inductor/codegen/wrapper.py | 4 ++++ 2 files changed, 24 insertions(+), 4 deletions(-) diff --git a/torch/_inductor/codegen/debug_utils.py b/torch/_inductor/codegen/debug_utils.py index 56347f77f5cc4e..985678cadb285c 100644 --- a/torch/_inductor/codegen/debug_utils.py +++ b/torch/_inductor/codegen/debug_utils.py @@ -17,6 +17,24 @@ log = logging.getLogger(__name__) +def _print_debugging_tensor_value_info(msg, arg): + # helper for printing debugging stats for intermediate tensor values + # at jit inductor level codegen + max_numel_to_print = 64 + print(msg) + numel = arg.float().numel() + # print the debug printing stats + if numel <= max_numel_to_print: + print(arg) + print("Number of elements: ", numel) + print("Size: ", arg.float().size()) + print("Dtype: ", arg.float().mean().item()) + print("Mean: ", arg.float().mean().item()) + print("Min: ", arg.float().min().item()) + print("Max: ", arg.float().max().item()) + print("Std: ", arg.float().std().item()) + + # AOTI debug printing related configs class IntermediateValueDebuggingLevel(Enum): # OFF: No intermediate tensor value debug info will be printed or saved. @@ -198,8 +216,6 @@ def codegen_intermediate_tensor_value_print( # TODO: add non-abi compatible mode debug printing info pass else: - # TODO: add mean/min/max instead of naively print the tensor value - line = ( - f"print('inductor: {launch_prefix} - {kernel_name} - {arg}', {arg})" + V.graph.wrapper_code.writeline( + f'_print_debugging_tensor_value_info("inductor: {launch_prefix} - {kernel_name} - {arg}", {arg})' ) - V.graph.wrapper_code.writeline(line) diff --git a/torch/_inductor/codegen/wrapper.py b/torch/_inductor/codegen/wrapper.py index c3257f4d638c7a..0ac62adb016208 100644 --- a/torch/_inductor/codegen/wrapper.py +++ b/torch/_inductor/codegen/wrapper.py @@ -545,6 +545,9 @@ def write_header(self) -> None: aot_config_comment = "" if context is not None and context.aot_graph_name is not None: aot_config_comment = f"# AOT ID: {context.aot_graph_name}" + aot_inductor_debug_utils = "" + if int(config.aot_inductor.debug_intermediate_value_printer) > 0: + aot_inductor_debug_utils = "from torch._inductor.codegen.debug_utils import _print_debugging_tensor_value_info" self.imports.splice( f""" {aot_config_comment} @@ -562,6 +565,7 @@ def write_header(self) -> None: from {async_compile.__name__} import AsyncCompile from torch._inductor.select_algorithm import extern_kernels from torch._inductor.codegen.multi_kernel import MultiKernelCall + {aot_inductor_debug_utils} """, strip=True, ) From aab706d21c2bcbc032f2da2da7e21a167b594364 Mon Sep 17 00:00:00 2001 From: Catherine Lee Date: Fri, 13 Sep 2024 17:29:24 +0000 Subject: [PATCH 0839/1018] [trymerge] Manually close merged PR when Github fails (#135890) Manually close merged PR when Github fails to do it. Consequences of current design: Sleeping for 1 min uses up the machine, might result in race conditions, results in merging label to removed a bit later, pr still left open if this api fails too (ie no async clean up job) Tested in https://github.com/malfet/deleteme/pull/92 by removing the part of the commit message that has "resolved #pr num" Pull Request resolved: https://github.com/pytorch/pytorch/pull/135890 Approved by: https://github.com/malfet, https://github.com/huydhn --- .github/scripts/github_utils.py | 8 ++++++ .github/scripts/trymerge.py | 48 ++++++++++++++++++++++++++++++--- 2 files changed, 52 insertions(+), 4 deletions(-) diff --git a/.github/scripts/github_utils.py b/.github/scripts/github_utils.py index 79f6c4732ba558..a5206bd675fe6d 100644 --- a/.github/scripts/github_utils.py +++ b/.github/scripts/github_utils.py @@ -168,6 +168,14 @@ def gh_post_commit_comment( ) +def gh_close_pr(org: str, repo: str, pr_num: int, dry_run: bool = False) -> None: + url = f"{GITHUB_API_URL}/repos/{org}/{repo}/pulls/{pr_num}" + if dry_run: + print(f"Dry run closing PR {pr_num}") + else: + gh_fetch_url(url, method="PATCH", data={"state": "closed"}) + + def gh_delete_comment(org: str, repo: str, comment_id: int) -> None: url = f"{GITHUB_API_URL}/repos/{org}/{repo}/issues/comments/{comment_id}" gh_fetch_url(url, method="DELETE") diff --git a/.github/scripts/trymerge.py b/.github/scripts/trymerge.py index 066472a8042100..a1a60698af7950 100755 --- a/.github/scripts/trymerge.py +++ b/.github/scripts/trymerge.py @@ -36,6 +36,7 @@ import yaml from github_utils import ( + gh_close_pr, gh_fetch_json_list, gh_fetch_merge_base, gh_fetch_url, @@ -1174,11 +1175,11 @@ def merge_into( for pr in additional_merged_prs: pr.add_numbered_label(MERGE_COMPLETE_LABEL, dry_run) - if comment_id and self.pr_num: - # When the merge process reaches this part, we can assume that the commit - # has been successfully pushed to trunk - merge_commit_sha = repo.rev_parse(name=REMOTE_MAIN_BRANCH) + # When the merge process reaches this part, we can assume that the commit + # has been successfully pushed to trunk + merge_commit_sha = repo.rev_parse(name=self.default_branch()) + if comment_id and self.pr_num: # Finally, upload the record to Rockset. The list of pending and failed # checks are at the time of the merge save_merge_record( @@ -1203,6 +1204,17 @@ def merge_into( else: print("Missing comment ID or PR number, couldn't upload to Rockset") + # Usually Github will see that the commit has "resolves " in the + # commit message and close the PR, but sometimes it doesn't, leading to + # confusion. When it doesn't, we close it manually. + time.sleep(60) # Give Github some time to close the PR + manually_close_merged_pr( + pr=self, + additional_merged_prs=additional_merged_prs, + merge_commit_sha=merge_commit_sha, + dry_run=dry_run, + ) + def merge_changes( self, repo: GitRepo, @@ -1503,6 +1515,34 @@ def checks_to_markdown_bullets( ] +def manually_close_merged_pr( + pr: GitHubPR, + additional_merged_prs: List[GitHubPR], + merge_commit_sha: str, + dry_run: bool, +) -> None: + def _comment_and_close(pr: GitHubPR, comment: str) -> None: + pr = GitHubPR(pr.org, pr.project, pr.pr_num) # Refresh the PR + if not pr.is_closed(): + gh_post_pr_comment(pr.org, pr.project, pr.pr_num, comment, dry_run) + gh_close_pr(pr.org, pr.project, pr.pr_num, dry_run) + + message = ( + f"This PR (#{pr.pr_num}) was merged in {merge_commit_sha} but it is still open, likely due to a Github bug, " + "so mergebot is closing it manually. If you think this is a mistake, please feel free to reopen and contact Dev Infra." + ) + _comment_and_close(pr, message) + for additional_pr in additional_merged_prs: + message = ( + f"This PR (#{additional_pr.pr_num}) was merged as part of PR #{pr.pr_num} in the stack under {merge_commit_sha} " + "but it is still open, likely due to a Github bug, so mergebot is closing it manually. " + "If you think this is a mistake, please feel free to reopen and contact Dev Infra." + ) + _comment_and_close(additional_pr, message) + + print(f"PR {pr.pr_num} and all additional PRs in the stack have been closed.") + + @retries_decorator() def save_merge_record( comment_id: int, From 227bfeca37ebf9abb59c8392afe76aac7f0ad66c Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Fri, 13 Sep 2024 17:47:36 +0000 Subject: [PATCH 0840/1018] Revert "[aoti] Fix workspace generation for triton (#135552)" This reverts commit d3833253928f29ed760b2dccac2b730028a868ca. Reverted https://github.com/pytorch/pytorch/pull/135552 on behalf of https://github.com/izaitsevfb due to blocks revert of #135313, internal failures, see D62511427 ([comment](https://github.com/pytorch/pytorch/pull/135552#issuecomment-2349641372)) --- test/inductor/test_cuda_cpp_wrapper.py | 1 - torch/_inductor/codegen/common.py | 1 - torch/_inductor/codegen/cpp_wrapper_gpu.py | 32 ++++--------------- torch/_inductor/codegen/triton.py | 12 ++----- torch/_inductor/codegen/triton_split_scan.py | 6 +--- torch/_inductor/codegen/wrapper.py | 14 +------- torch/_inductor/runtime/triton_heuristics.py | 1 - torch/csrc/inductor/aoti_torch/c/shim.h | 2 -- .../csrc/inductor/aoti_torch/shim_common.cpp | 7 ---- 9 files changed, 11 insertions(+), 65 deletions(-) diff --git a/test/inductor/test_cuda_cpp_wrapper.py b/test/inductor/test_cuda_cpp_wrapper.py index 2a01509b26cf54..bdc3066c2818d4 100644 --- a/test/inductor/test_cuda_cpp_wrapper.py +++ b/test/inductor/test_cuda_cpp_wrapper.py @@ -188,7 +188,6 @@ class BaseTest(NamedTuple): BaseTest("test_sum_int"), # bool, int64, int8, uint8 BaseTest("test_transpose"), # multiple outputs, buffer clear BaseTest("test_unspec_inputs"), - BaseTest("test_consecutive_split_cumprod"), BaseTest("test_pointwise_hermite_polynomial_he"), BaseTest("test_pointwise_hermite_polynomial_h"), BaseTest( diff --git a/torch/_inductor/codegen/common.py b/torch/_inductor/codegen/common.py index 02f13667b70c79..95183c1484fe00 100644 --- a/torch/_inductor/codegen/common.py +++ b/torch/_inductor/codegen/common.py @@ -1420,7 +1420,6 @@ def python_argdefs(self): arg_defs.append("ws_ptr") call_args.append("workspace") precompile_args.append(self.workspace_arg) - arg_types.append(torch.uint8) return arg_defs, call_args, precompile_args, arg_types def aliases(self): diff --git a/torch/_inductor/codegen/cpp_wrapper_gpu.py b/torch/_inductor/codegen/cpp_wrapper_gpu.py index 13f13f44be9eca..5719d3eba589f8 100644 --- a/torch/_inductor/codegen/cpp_wrapper_gpu.py +++ b/torch/_inductor/codegen/cpp_wrapper_gpu.py @@ -6,9 +6,9 @@ import sympy -from torch import dtype as torch_dtype, uint8 +from torch import dtype as torch_dtype from torch._inductor.codecache import get_cpp_wrapper_cubin_path_name -from torch._inductor.runtime.triton_heuristics import grid as default_grid_fn +from torch._inductor.runtime.triton_heuristics import grid as default_grid from .. import config from ..codecache import CudaKernelParamCache @@ -88,11 +88,11 @@ def __call__(self): grid = self.grid assert isinstance(grid, (list, tuple)), f"expected {grid=} to be a list" grid = self._process_grid(grid) - assert self.grid_callable is not None, "grid_callable can't be None" + grid_callable = self.grid_callable or default_grid if not self.grid_extra_kwargs: - grid_fn = self.grid_callable(*grid) + grid_fn = grid_callable(*grid) else: - grid_fn = self.grid_callable(*grid, **self.grid_extra_kwargs) + grid_fn = grid_callable(*grid, **self.grid_extra_kwargs) params = CudaKernelParamCache.get(self.kernel_name) assert ( @@ -102,7 +102,6 @@ def __call__(self): "XBLOCK": params["x_block"], "YBLOCK": params["y_block"], "ZBLOCK": params["z_block"], - "RBLOCK": params["r_block"], } return grid_fn(block_cfg) @@ -339,7 +338,7 @@ def generate_default_grid( kernel_name: str, grid: List[Any], gpu: bool = True, - grid_callable: Optional[Callable[..., Any]] = default_grid_fn, + grid_callable: Optional[Callable[..., Any]] = None, **grid_extra_kwargs, ): """ @@ -441,22 +440,3 @@ def generate_kernel_call( ), ) self.writeline("}") - - def generate_workspace_allocation(self, nbytes, device, zero_fill): - line = self.make_allocation( - "workspace", device, uint8, shape=(nbytes,), stride=(1,) - ) - self.writeline(line) - if config.triton.autotune_at_compile_time: - self.kernel_autotune_calls.writeline(line) - if zero_fill: - if config.abi_compatible: - # TODO: remove this function to use the default WrapperCodegen behavior after service platform has zero_() symbol - # default behavior is f"workspace.zero_(){self.ending}" - self.writeline( - f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_zero_(workspace.get())){self.ending}" - ) - else: - self.writeline(f"workspace.zero_(){self.ending}") - if config.triton.autotune_at_compile_time: - self.kernel_autotune_calls.writeline(f"workspace.zero_(){self.ending}") diff --git a/torch/_inductor/codegen/triton.py b/torch/_inductor/codegen/triton.py index fa64fc54ffb179..69daf2cad522a6 100644 --- a/torch/_inductor/codegen/triton.py +++ b/torch/_inductor/codegen/triton.py @@ -27,7 +27,6 @@ import torch._logging from torch._dynamo.utils import preserve_rng_state from torch._inductor.runtime.hints import AutotuneHint, DeviceProperties -from torch._inductor.runtime.triton_heuristics import grid as default_grid_fn from torch._prims_common import is_integer_dtype from torch.utils._ordered_set import OrderedSet from torch.utils._sympy.functions import CeilDiv, FloorDiv, ModularIndexing @@ -2798,11 +2797,8 @@ def KERNEL_NAME(in_ptr0, in_ptr1, out_ptr2, xnumel, rnumel, XBLOCK : tl.constexp if tree.prefix == "x" and self.no_x_dim: code.writeline("XBLOCK: tl.constexpr = 1") - def _get_grid_fn_str(self): - return "grid" - def _get_grid_fn(self): - return default_grid_fn + return "grid" def add_numel_to_call_args_and_grid(self, name, call_args, arg_types, grid): # TODO(jansel): if there are constants, we shouldn't bother passing them as args @@ -2832,9 +2828,7 @@ def call_kernel(self, name: str, node: Optional[IRNode] = None): ws.nbytes, current_device, ws.zero_fill ) - grid = wrapper.generate_default_grid( - name, grid, grid_callable=self._get_grid_fn() - ) + grid = wrapper.generate_default_grid(name, grid) wrapper.generate_kernel_call( name, call_args, @@ -2843,7 +2837,7 @@ def call_kernel(self, name: str, node: Optional[IRNode] = None): gpu=True, triton=True, arg_types=arg_types, - grid_fn=self._get_grid_fn_str(), + grid_fn=self._get_grid_fn(), triton_meta=self.triton_meta, ) diff --git a/torch/_inductor/codegen/triton_split_scan.py b/torch/_inductor/codegen/triton_split_scan.py index 31dab399236492..9eea00fbb8d6b6 100644 --- a/torch/_inductor/codegen/triton_split_scan.py +++ b/torch/_inductor/codegen/triton_split_scan.py @@ -6,7 +6,6 @@ from torch._inductor import config from torch._inductor.codegen.simd import IterationRangesRoot from torch._inductor.codegen.triton import triton_compute_type, TritonKernel -from torch._inductor.runtime.triton_heuristics import split_scan_grid from torch._prims_common import prod from torch.utils._ordered_set import OrderedSet from torch.utils._sympy.functions import CeilDiv @@ -171,8 +170,5 @@ def scan(self, dtypes, combine_fn, values): def _get_heuristic(self): return "split_scan" - def _get_grid_fn_str(self): - return "split_scan_grid" - def _get_grid_fn(self): - return split_scan_grid + return "split_scan_grid" diff --git a/torch/_inductor/codegen/wrapper.py b/torch/_inductor/codegen/wrapper.py index 0ac62adb016208..348f1a2a7bc673 100644 --- a/torch/_inductor/codegen/wrapper.py +++ b/torch/_inductor/codegen/wrapper.py @@ -599,7 +599,6 @@ def write_kernel_autotune_defs_header(self) -> None: async_compile = AsyncCompile() generate_example_value = AlgorithmSelectorCache.generate_example_value - empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda """ ) @@ -1502,18 +1501,12 @@ def generate_numel_expr(self, kernel_name: str, tree, suffix: Optional[str] = No return SymbolicCallArg(expr, tree.numel) def generate_workspace_allocation(self, nbytes, device, zero_fill): - if isinstance(nbytes, sympy.Expr): - nbytes = V.graph.sizevars.size_hint(nbytes) line = self.make_allocation( "workspace", device, torch.uint8, shape=(nbytes,), stride=(1,) ) self.writeline(line) - if config.triton.autotune_at_compile_time: - self.kernel_autotune_calls.writeline(line) if zero_fill: self.writeline(f"workspace.zero_(){self.ending}") - if config.triton.autotune_at_compile_time: - self.kernel_autotune_calls.writeline(f"workspace.zero_(){self.ending}") def wrap_kernel_call(self, name, call_args): return f"{name}({', '.join(call_args)}){self.ending}" @@ -1732,12 +1725,7 @@ def generate_kernel_call( key, arg = arg.split("=") if isinstance(arg_type, torch_dtype): - # workspace allocation is already generated by `generate_workspace_allocation()` - # in `TritonKernel.call_kernel()`. - if arg == "workspace": - arg_str = "workspace" - tensor_args[arg] = arg_str - elif arg not in tensor_args: + if arg not in tensor_args: arg_str = self.generate_example_arg_value( arg, arg_type, raw_arg, i ) diff --git a/torch/_inductor/runtime/triton_heuristics.py b/torch/_inductor/runtime/triton_heuristics.py index 89c47838da43f3..658dcbf33d99e5 100644 --- a/torch/_inductor/runtime/triton_heuristics.py +++ b/torch/_inductor/runtime/triton_heuristics.py @@ -754,7 +754,6 @@ def save_gpu_kernel(self, grid, stream, launcher): "x_block": launcher.config.kwargs.get("XBLOCK", 1), "y_block": launcher.config.kwargs.get("YBLOCK", None), "z_block": launcher.config.kwargs.get("ZBLOCK", None), - "r_block": launcher.config.kwargs.get("RBLOCK", None), "num_warps": launcher.bin.num_warps if hasattr(launcher.bin, "num_warps") else launcher.bin.metadata.num_warps, diff --git a/torch/csrc/inductor/aoti_torch/c/shim.h b/torch/csrc/inductor/aoti_torch/c/shim.h index c204fc27460bdc..b470b5f10061b8 100644 --- a/torch/csrc/inductor/aoti_torch/c/shim.h +++ b/torch/csrc/inductor/aoti_torch/c/shim.h @@ -528,8 +528,6 @@ aoti_torch_cpu__wrapped_quantized_linear_prepacked( AOTI_TORCH_EXPORT AOTITorchError aoti_torch_nonzero(AtenTensorHandle self, AtenTensorHandle* out); -AOTI_TORCH_EXPORT AOTITorchError aoti_torch_zero_(AtenTensorHandle self); - AOTI_TORCH_EXPORT AOTITorchError aoti_torch_repeat_interleave_Tensor( AtenTensorHandle repeats, int64_t* output_size, diff --git a/torch/csrc/inductor/aoti_torch/shim_common.cpp b/torch/csrc/inductor/aoti_torch/shim_common.cpp index 2860972f88dac4..f49bf23b9ce422 100644 --- a/torch/csrc/inductor/aoti_torch/shim_common.cpp +++ b/torch/csrc/inductor/aoti_torch/shim_common.cpp @@ -1166,10 +1166,3 @@ AOTITorchError aoti_torch__alloc_from_pool( strides)); }); } - -AOTITorchError aoti_torch_zero_(AtenTensorHandle tensor) { - AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({ - at::Tensor* t = tensor_handle_to_tensor_pointer(tensor); - t->zero_(); - }); -} From c5d026a5d08d56178e301d6cfced0aad369aa8e7 Mon Sep 17 00:00:00 2001 From: Daohang Shi Date: Fri, 13 Sep 2024 17:48:25 +0000 Subject: [PATCH 0841/1018] [gpu-profiler] Expose active and repeat in os env var (#135757) Summary: https://fb.workplace.com/groups/ai.efficiency.tools.users/permalink/1855136444971825/ Test Plan: `buck2 test mode/opt caffe2/test:profiler -- -r test_kineto_profiler_api ` eyes Differential Revision: D62529249 Pull Request resolved: https://github.com/pytorch/pytorch/pull/135757 Approved by: https://github.com/Yuzhen11 --- test/profiler/test_profiler.py | 24 +++++++++++++++++++++++- 1 file changed, 23 insertions(+), 1 deletion(-) diff --git a/test/profiler/test_profiler.py b/test/profiler/test_profiler.py index 295f08fbd10a5c..f4f4e2e99270a4 100644 --- a/test/profiler/test_profiler.py +++ b/test/profiler/test_profiler.py @@ -990,19 +990,41 @@ def trace_handler(p): # print(output) test_schedule = torch.profiler.schedule( - skip_first=2, wait=1, warmup=1, active=2, repeat=2 + skip_first=3, wait=2, warmup=1, active=4, repeat=2 ) test_schedule_expected_outputs = [ + # skip first 3 ProfilerAction.NONE, ProfilerAction.NONE, ProfilerAction.NONE, + # ---- + # repeat No. 1 begin + # wait 2 + ProfilerAction.NONE, + ProfilerAction.NONE, + # warmup 1 ProfilerAction.WARMUP, + # active 2 begin + ProfilerAction.RECORD, + ProfilerAction.RECORD, ProfilerAction.RECORD, ProfilerAction.RECORD_AND_SAVE, + # active 2 end + # repeat No. 1 end + # --- + # repeat No. 2 begin + # wait 2 ProfilerAction.NONE, + ProfilerAction.NONE, + # warmup 1 ProfilerAction.WARMUP, + # active 2 begin + ProfilerAction.RECORD, + ProfilerAction.RECORD, ProfilerAction.RECORD, ProfilerAction.RECORD_AND_SAVE, + # active 2 end + # repeat No. 2 end ProfilerAction.NONE, ProfilerAction.NONE, ProfilerAction.NONE, From 41725697fa1141ec5ef2b38f6d5975a57fbd4d59 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Fri, 13 Sep 2024 17:53:21 +0000 Subject: [PATCH 0842/1018] Revert "[Inductor] Rename `cpp_wrapper_cuda.py` as `cpp_wrapper_gpu.py` (#135313)" This reverts commit 16b37b309f64ddd4e498c57a99191e1d9b3dfdac. Reverted https://github.com/pytorch/pytorch/pull/135313 on behalf of https://github.com/izaitsevfb due to breaks internal builds ([comment](https://github.com/pytorch/pytorch/pull/135313#issuecomment-2349662091)) --- tools/amd_build/build_amd.py | 2 +- torch/_inductor/codegen/common.py | 2 +- .../codegen/{cpp_wrapper_gpu.py => cpp_wrapper_cuda.py} | 0 3 files changed, 2 insertions(+), 2 deletions(-) rename torch/_inductor/codegen/{cpp_wrapper_gpu.py => cpp_wrapper_cuda.py} (100%) diff --git a/tools/amd_build/build_amd.py b/tools/amd_build/build_amd.py index 60a1be73fbb2d1..967f63ae2c8f15 100755 --- a/tools/amd_build/build_amd.py +++ b/tools/amd_build/build_amd.py @@ -207,7 +207,7 @@ def remove_hcc(line: str) -> str: ignores=ignores, extra_files=[ "torch/_inductor/codegen/cpp_wrapper_cpu.py", - "torch/_inductor/codegen/cpp_wrapper_gpu.py", + "torch/_inductor/codegen/cpp_wrapper_cuda.py", "torch/_inductor/codegen/wrapper.py", ], out_of_place_only=args.out_of_place_only, diff --git a/torch/_inductor/codegen/common.py b/torch/_inductor/codegen/common.py index 95183c1484fe00..efbef7f461c8c2 100644 --- a/torch/_inductor/codegen/common.py +++ b/torch/_inductor/codegen/common.py @@ -232,7 +232,7 @@ def get_wrapper_codegen_for_device(device: str, cpp_wrapper: bool = False): def init_backend_registration(): from .cpp import CppScheduling from .cpp_wrapper_cpu import CppWrapperCpu - from .cpp_wrapper_gpu import CppWrapperGpu + from .cpp_wrapper_cuda import CppWrapperGpu from .cuda_combined_scheduling import CUDACombinedScheduling from .halide import HalideScheduling from .triton import TritonScheduling diff --git a/torch/_inductor/codegen/cpp_wrapper_gpu.py b/torch/_inductor/codegen/cpp_wrapper_cuda.py similarity index 100% rename from torch/_inductor/codegen/cpp_wrapper_gpu.py rename to torch/_inductor/codegen/cpp_wrapper_cuda.py From 7d6f92c1aaa52d000451a4b8ffcc2ab638ac3668 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Fri, 13 Sep 2024 18:06:56 +0000 Subject: [PATCH 0843/1018] Revert "Validate input types for `torch.nn.Linear` and `torch.nn.Bilinear` (#135596)" This reverts commit e157ce3ebbb3f30d008c15914e82eb74217562f0. Reverted https://github.com/pytorch/pytorch/pull/135596 on behalf of https://github.com/malfet due to It's too restrictive, should allow other int-like types, such as `numpy.int64` ([comment](https://github.com/pytorch/pytorch/pull/135596#issuecomment-2349714104)) --- torch/nn/modules/linear.py | 20 -------- torch/testing/_internal/common_modules.py | 62 ----------------------- 2 files changed, 82 deletions(-) diff --git a/torch/nn/modules/linear.py b/torch/nn/modules/linear.py index aa4e67c827ae4f..dc5185b7eec0bf 100644 --- a/torch/nn/modules/linear.py +++ b/torch/nn/modules/linear.py @@ -98,13 +98,6 @@ def __init__( device=None, dtype=None, ) -> None: - # validation checks for input types - if not isinstance(in_features, int): - raise TypeError(f"Expected int for in_features but got {type(in_features)}") - if not isinstance(out_features, int): - raise TypeError( - f"Expected int for out_features but got {type(out_features)}" - ) factory_kwargs = {"device": device, "dtype": dtype} super().__init__() self.in_features = in_features @@ -207,19 +200,6 @@ def __init__( device=None, dtype=None, ) -> None: - # validation checks for input types - if not isinstance(in1_features, int): - raise TypeError( - f"Expected int for in1_features but got {type(in1_features)}" - ) - if not isinstance(in2_features, int): - raise TypeError( - f"Expected int for in2_features but got {type(in2_features)}" - ) - if not isinstance(out_features, int): - raise TypeError( - f"Expected int for out_features but got {type(out_features)}" - ) factory_kwargs = {"device": device, "dtype": dtype} super().__init__() self.in1_features = in1_features diff --git a/torch/testing/_internal/common_modules.py b/torch/testing/_internal/common_modules.py index 238b67e476713d..63963bab1b050a 100644 --- a/torch/testing/_internal/common_modules.py +++ b/torch/testing/_internal/common_modules.py @@ -3180,66 +3180,6 @@ def padding3d_circular_ref(inp, pad): # Start of module error inputs functions. -def module_error_inputs_torch_nn_Linear(module_info, device, dtype, requires_grad, training, **kwargs): - make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) - samples = [ - ErrorModuleInput( - ModuleInput( - constructor_input=FunctionInput("10", 20), - forward_input=FunctionInput(make_input(3, 10)), - ), - error_on=ModuleErrorEnum.CONSTRUCTION_ERROR, - error_type=TypeError, - error_regex=r"Expected int for in_features but got " - ), - ErrorModuleInput( - ModuleInput( - constructor_input=FunctionInput(10, 20.7), - forward_input=FunctionInput(make_input(3, 10)), - ), - error_on=ModuleErrorEnum.CONSTRUCTION_ERROR, - error_type=TypeError, - error_regex=r"Expected int for out_features but got " - ), - ] - return samples - - -def module_error_inputs_torch_nn_Bilinear(module_info, device, dtype, requires_grad, training, **kwargs): - make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) - samples = [ - ErrorModuleInput( - ModuleInput( - constructor_input=FunctionInput("10", 20, 30), - forward_input=FunctionInput(make_input(3, 10), make_input(3, 20)), - ), - error_on=ModuleErrorEnum.CONSTRUCTION_ERROR, - error_type=TypeError, - error_regex=r"Expected int for in1_features but got " - ), - ErrorModuleInput( - ModuleInput( - constructor_input=FunctionInput(10, 20.7, 30), - forward_input=FunctionInput(make_input(3, 10), make_input(3, 20)), - ), - error_on=ModuleErrorEnum.CONSTRUCTION_ERROR, - error_type=TypeError, - error_regex=r"Expected int for in2_features but got " - ), - ErrorModuleInput( - ModuleInput( - constructor_input=FunctionInput(10, 20, "30"), - forward_input=FunctionInput(make_input(3, 10), make_input(3, 20)), - ), - error_on=ModuleErrorEnum.CONSTRUCTION_ERROR, - error_type=TypeError, - error_regex=r"Expected int for out_features but got " - ), - ] - return samples - - - def module_error_inputs_torch_nn_RNN_GRU_Cell(module_info, device, dtype, requires_grad, training, **kwargs): make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) samples = [ @@ -3892,14 +3832,12 @@ def module_error_inputs_torch_nn_Pad3d(module_info, device, dtype, requires_grad )), ModuleInfo(torch.nn.Linear, module_inputs_func=module_inputs_torch_nn_Linear, - module_error_inputs_func=module_error_inputs_torch_nn_Linear, skips=( # No channels_last support for Linear currently. DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format'),) ), ModuleInfo(torch.nn.Bilinear, module_inputs_func=module_inputs_torch_nn_Bilinear, - module_error_inputs_func=module_error_inputs_torch_nn_Bilinear, decorators=[ DecorateInfo( toleranceOverride({ From bf783dc0b06a8953c831dc27b1c79680535b8684 Mon Sep 17 00:00:00 2001 From: Jing Xu Date: Fri, 13 Sep 2024 18:43:55 +0000 Subject: [PATCH 0844/1018] fix requirements.txt installation failure issue on Windows (#134567) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Fixes #134564 Root cause: The `lintrunner` wheel released on [pypi.org](https://pypi.org/project/lintrunner/#files) only supports Windows 32bit and Linux 64bit. Since compilation of pytorch requires a 64bit env, on windows, the `lintrunner` has to be compiled from source distribution. `Rust` is its dependency for compilation, as indicated in the error message. Meanwhile, Visual Studio environment is needed for linking libraries.. ![image](https://github.com/user-attachments/assets/180cd899-8886-43b5-b42f-031f41e81683) Issue when performing `pip install lintrunner` without a Visual Studio environment activated is shown below. ```bash >python -m pip install lintrunner Collecting lintrunner Downloading lintrunner-0.12.5.tar.gz (62 kB) Installing build dependencies ... done Getting requirements to build wheel ... done Preparing metadata (pyproject.toml) ... done Building wheels for collected packages: lintrunner Building wheel for lintrunner (pyproject.toml) ... error error: subprocess-exited-with-error × Building wheel for lintrunner (pyproject.toml) did not run successfully. │ exit code: 1 ╰─> [137 lines of output] Running `maturin pep517 build-wheel -i C:\Users\\miniforge3\envs\py310\python.exe --compatibility off` 📡 Using build options bindings from pyproject.toml Compiling proc-macro2 v1.0.79 Compiling unicode-ident v1.0.12 Compiling version_check v0.9.4 Compiling windows_x86_64_msvc v0.52.4 Compiling winapi v0.3.9 Compiling serde v1.0.197 Compiling autocfg v1.2.0 Compiling syn v1.0.109 Compiling lazy_static v1.4.0 Compiling libc v0.2.153 Compiling equivalent v1.0.1 Compiling hashbrown v0.14.3 Compiling memchr v2.7.2 Compiling yansi v1.0.1 Compiling unicode-width v0.1.11 Compiling regex-syntax v0.8.3 Compiling encode_unicode v0.3.6 Compiling cfg-if v1.0.0 Compiling winnow v0.6.5 Compiling cc v1.0.92 error: could not compile `windows_x86_64_msvc` (build script) due to 2 previous errors warning: build failed, waiting for other jobs to finish... error: could not compile `serde` (build script) due to 2 previous errors error: could not compile `proc-macro2` (build script) due to 2 previous errors error: could not compile `syn` (build script) due to 2 previous errors error: could not compile `libc` (build script) due to 2 previous errors error: could not compile `winapi` (build script) due to 2 previous errors 💥 maturin failed Caused by: Failed to build a native library through cargo Caused by: Cargo build finished with "exit code: 101": `cargo rustc --manifest-path Cargo.toml --message-format json --release --bins --` 📦 Including license file "LICENSE" 🔗 Found bin bindings error: linker `link.exe` not found | = note: program not found note: the msvc targets depend on the msvc linker but `link.exe` was not found note: please ensure that Visual Studio 2017 or later, or Build Tools for Visual Studio were installed with the Visual C++ option. note: VS Code is a different product, and is not sufficient. error: aborting due to 1 previous error error: linker `link.exe` not found | = note: program not found note: the msvc targets depend on the msvc linker but `link.exe` was not found note: please ensure that Visual Studio 2017 or later, or Build Tools for Visual Studio were installed with the Visual C++ option. note: VS Code is a different product, and is not sufficient. error: aborting due to 1 previous error error: linker `link.exe` not found | = note: program not found note: the msvc targets depend on the msvc linker but `link.exe` was not found note: please ensure that Visual Studio 2017 or later, or Build Tools for Visual Studio were installed with the Visual C++ option. note: VS Code is a different product, and is not sufficient. error: aborting due to 1 previous error error: linker `link.exe` not found | = note: program not found note: the msvc targets depend on the msvc linker but `link.exe` was not found note: please ensure that Visual Studio 2017 or later, or Build Tools for Visual Studio were installed with the Visual C++ option. note: VS Code is a different product, and is not sufficient. error: aborting due to 1 previous error error: linker `link.exe` not found | = note: program not found note: the msvc targets depend on the msvc linker but `link.exe` was not found note: please ensure that Visual Studio 2017 or later, or Build Tools for Visual Studio were installed with the Visual C++ option. note: VS Code is a different product, and is not sufficient. error: aborting due to 1 previous error error: linker `link.exe` not found | = note: program not found note: the msvc targets depend on the msvc linker but `link.exe` was not found note: please ensure that Visual Studio 2017 or later, or Build Tools for Visual Studio were installed with the Visual C++ option. note: VS Code is a different product, and is not sufficient. error: aborting due to 1 previous error Error: command ['maturin', 'pep517', 'build-wheel', '-i', 'C:\\Users\\\\miniforge3\\envs\\py310\\python.exe', '--compatibility', 'off'] returned non-zero exit status 1 [end of output] note: This error originates from a subprocess, and is likely not a problem with pip. ERROR: Failed building wheel for lintrunner Failed to build lintrunner ERROR: ERROR: Failed to build installable wheels for some pyproject.toml based projects (lintrunner) ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/134567 Approved by: https://github.com/malfet --- README.md | 59 ++++++++++++++++++++++++++++++++++++------------------- 1 file changed, 39 insertions(+), 20 deletions(-) diff --git a/README.md b/README.md index d7567419acb77b..c7dd72ccc77c64 100644 --- a/README.md +++ b/README.md @@ -27,8 +27,8 @@ Our trunk health (Continuous Integration signals) can be found at [hud.pytorch.o - [NVIDIA CUDA Support](#nvidia-cuda-support) - [AMD ROCm Support](#amd-rocm-support) - [Intel GPU Support](#intel-gpu-support) - - [Install Dependencies](#install-dependencies) - [Get the PyTorch Source](#get-the-pytorch-source) + - [Install Dependencies](#install-dependencies) - [Install PyTorch](#install-pytorch) - [Adjust Build Options (Optional)](#adjust-build-options-optional) - [Docker Image](#docker-image) @@ -161,9 +161,34 @@ They require JetPack 4.2 and above, and [@dusty-nv](https://github.com/dusty-nv) #### Prerequisites If you are installing from source, you will need: - Python 3.8 or later (for Linux, Python 3.8.1+ is needed) -- A compiler that fully supports C++17, such as clang or gcc (gcc 9.4.0 or newer is required) +- A compiler that fully supports C++17, such as clang or gcc (gcc 9.4.0 or newer is required, on Linux) +- Visual Studio or Visual Studio Build Tool on Windows + +\* PyTorch CI uses Visual C++ BuildTools, which come with Visual Studio Enterprise, +Professional, or Community Editions. You can also install the build tools from +https://visualstudio.microsoft.com/visual-cpp-build-tools/. The build tools *do not* +come with Visual Studio Code by default. + +\* We highly recommend installing an [Anaconda](https://www.anaconda.com/download) environment. You will get a high-quality BLAS library (MKL) and you get controlled dependency versions regardless of your Linux distro. + +An example of environment setup is shown below: -We highly recommend installing an [Anaconda](https://www.anaconda.com/download) environment. You will get a high-quality BLAS library (MKL) and you get controlled dependency versions regardless of your Linux distro. +* Linux: + +```bash +$ source /bin/activate +$ conda create -y -n +$ conda activate +``` + +* Windows: + +```bash +$ source \Scripts\activate.bat +$ conda create -y -n +$ conda activate +$ call "C:\Program Files\Microsoft Visual Studio\\Community\VC\Auxiliary\Build\vcvarsall.bat" x64 +``` ##### NVIDIA CUDA Support If you want to compile with CUDA support, [select a supported version of CUDA from our support matrix](https://pytorch.org/get-started/locally/), then install the following: @@ -194,12 +219,23 @@ If you want to compile with Intel GPU support, follow these If you want to disable Intel GPU support, export the environment variable `USE_XPU=0`. Other potentially useful environment variables may be found in `setup.py`. +#### Get the PyTorch Source +```bash +git clone --recursive https://github.com/pytorch/pytorch +cd pytorch +# if you are updating an existing checkout +git submodule sync +git submodule update --init --recursive +``` + #### Install Dependencies **Common** ```bash conda install cmake ninja +# Run this command on native Windows +conda install rust # Run this command from the PyTorch directory after cloning the source code using the “Get the PyTorch Source“ section below pip install -r requirements.txt ``` @@ -235,15 +271,6 @@ pip install mkl-static mkl-include conda install -c conda-forge libuv=1.39 ``` -#### Get the PyTorch Source -```bash -git clone --recursive https://github.com/pytorch/pytorch -cd pytorch -# if you are updating an existing checkout -git submodule sync -git submodule update --init --recursive -``` - #### Install PyTorch **On Linux** @@ -284,13 +311,6 @@ python3 setup.py develop **On Windows** -Choose Correct Visual Studio Version. - -PyTorch CI uses Visual C++ BuildTools, which come with Visual Studio Enterprise, -Professional, or Community Editions. You can also install the build tools from -https://visualstudio.microsoft.com/visual-cpp-build-tools/. The build tools *do not* -come with Visual Studio Code by default. - If you want to build legacy python code, please refer to [Building on legacy code and CUDA](https://github.com/pytorch/pytorch/blob/main/CONTRIBUTING.md#building-on-legacy-code-and-cuda) **CPU-only builds** @@ -298,7 +318,6 @@ If you want to build legacy python code, please refer to [Building on legacy cod In this mode PyTorch computations will run on your CPU, not your GPU ```cmd -conda activate python setup.py develop ``` From 5c2dbaed02b6f5a0553ceb693d718f201889718a Mon Sep 17 00:00:00 2001 From: Sergii Dymchenko Date: Fri, 13 Sep 2024 19:59:47 +0000 Subject: [PATCH 0845/1018] Fix script name in the comments (#135507) Pull Request resolved: https://github.com/pytorch/pytorch/pull/135507 Approved by: https://github.com/atalman --- scripts/release/tag-docker-images.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/release/tag-docker-images.sh b/scripts/release/tag-docker-images.sh index ab366ecc0e3e2b..f2299d6c463ee2 100644 --- a/scripts/release/tag-docker-images.sh +++ b/scripts/release/tag-docker-images.sh @@ -12,7 +12,7 @@ # git submodule update --init --recursive # # Usage (run from root of project): -# DRY_RUN=disabled ./scripts/release/tag_docker_images.sh +# DRY_RUN=disabled ./scripts/release/tag-docker-images.sh # set -eou pipefail From 03f617f6ba9498bda1152b63d17769b88a0b421a Mon Sep 17 00:00:00 2001 From: Yiming Zhou Date: Fri, 13 Sep 2024 20:15:15 +0000 Subject: [PATCH 0846/1018] [reland][export] fix re-export custom metadata (#135720) Summary: Fixes https://github.com/pytorch/pytorch/issues/134778 The previous D62304294 broke some executorch tests. It has already been reverted. In this diff, `_collect_param_buffer_metadata()` is modified in a way that when a `call_function` node is encountered and its input nodes include `get_attr`. We skip the fields that have been collected previously and only collect rest of the fields. This prevents over-writing. Test Plan: ``` buck2 test 'fbcode//mode/dev-nosan' fbcode//executorch/backends/xnnpack/test:test_xnnpack_ops buck2 test 'fbcode//mode/dev-nosan' fbcode//caffe2/test/quantization:test_quantization -- -r test_re_export_preserve_handle buck2 test 'fbcode//mode/dev-nosan' fbcode//caffe2/test/quantization:test_quantization -- -r test_run_decompositions_preserve_handle ``` Differential Revision: D62514208 Pull Request resolved: https://github.com/pytorch/pytorch/pull/135720 Approved by: https://github.com/zhxchen17, https://github.com/jerryzh168 --- test/quantization/pt2e/test_numeric_debugger.py | 9 +++------ torch/_export/utils.py | 3 +++ 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/test/quantization/pt2e/test_numeric_debugger.py b/test/quantization/pt2e/test_numeric_debugger.py index e6cc4dc3252f44..85b5e9137bdb5c 100644 --- a/test/quantization/pt2e/test_numeric_debugger.py +++ b/test/quantization/pt2e/test_numeric_debugger.py @@ -20,6 +20,7 @@ get_symmetric_quantization_config, XNNPACKQuantizer, ) +from torch.export import export_for_training from torch.testing._internal.common_quantization import TestHelperModules from torch.testing._internal.common_utils import IS_WINDOWS, TestCase @@ -116,22 +117,18 @@ def test_deepcopy_preserve_handle(self): self.assertEqual(debug_handle_map, debug_handle_map_ref) - @unittest.skip("All nodes' meta are preserved but get_attr nodes' meta are wrong.") def test_re_export_preserve_handle(self): m = TestHelperModules.Conv2dThenConv1d() example_inputs = m.example_inputs() - m = capture_pre_autograd_graph(m, example_inputs) + m = export_for_training(m, example_inputs).module() generate_numeric_debug_handle(m) debug_handle_map_ref = _extract_debug_handles(m) - m_export = capture_pre_autograd_graph(m, example_inputs) + m_export = export_for_training(m, example_inputs).module() debug_handle_map = _extract_debug_handles(m_export) self.assertEqual(debug_handle_map, debug_handle_map_ref) - @unittest.skip( - "All nodes' meta are preserved but the first arg for the first node seems to be dropped" - ) def test_run_decompositions_preserve_handle(self): m = TestHelperModules.Conv2dThenConv1d() example_inputs = m.example_inputs() diff --git a/torch/_export/utils.py b/torch/_export/utils.py index e085e18b68a207..5b6590166bff8b 100644 --- a/torch/_export/utils.py +++ b/torch/_export/utils.py @@ -139,6 +139,9 @@ def _getattr(model: torch.fx.GraphModule, attr_name: str): for arg in node._input_nodes: if arg.op == "get_attr": for entry in torch.fx.proxy._COPY_META_FIELDS: + # the custom field should not be copied + if entry == "custom": + continue if entry in meta: params_buffers_to_node_meta[arg.target][entry] = meta[entry] From 8a5f0f0e289d058ae73c505e42cb124fe89f0fa4 Mon Sep 17 00:00:00 2001 From: atalman Date: Fri, 13 Sep 2024 20:27:11 +0000 Subject: [PATCH 0847/1018] Use python 3.11 for Large Wheel build (#136042) Use Python 3.11 in nightly Large wheel builds. Required for Colab testing Pull Request resolved: https://github.com/pytorch/pytorch/pull/136042 Approved by: https://github.com/kit1980, https://github.com/malfet Co-authored-by: Sergii Dymchenko --- .../scripts/generate_binary_build_matrix.py | 4 +- ...nerated-linux-binary-manywheel-nightly.yml | 140 +++++++++--------- ...rated-linux-binary-manywheel-split-nightly | 140 +++++++++--------- 3 files changed, 142 insertions(+), 142 deletions(-) diff --git a/.github/scripts/generate_binary_build_matrix.py b/.github/scripts/generate_binary_build_matrix.py index 365f750e1ad7bb..993459bf320471 100644 --- a/.github/scripts/generate_binary_build_matrix.py +++ b/.github/scripts/generate_binary_build_matrix.py @@ -412,8 +412,8 @@ def generate_wheels_matrix( ), } ) - # Special build building to use on Colab. PyThon 3.10 for 12.1 CUDA - if python_version == "3.10" and arch_version == "12.1": + # Special build building to use on Colab. Python 3.11 for 12.1 CUDA + if python_version == "3.11" and arch_version == "12.1": ret.append( { "python_version": python_version, diff --git a/.github/workflows/generated-linux-binary-manywheel-nightly.yml b/.github/workflows/generated-linux-binary-manywheel-nightly.yml index f817e5a009122d..5a86872b3e288a 100644 --- a/.github/workflows/generated-linux-binary-manywheel-nightly.yml +++ b/.github/workflows/generated-linux-binary-manywheel-nightly.yml @@ -1010,76 +1010,6 @@ jobs: conda-pytorchbot-token-test: ${{ secrets.CONDA_PYTORCHBOT_TOKEN_TEST }} uses: ./.github/workflows/_binary-upload.yml - manywheel-py3_10-cuda12_1-full-build: - if: ${{ github.repository_owner == 'pytorch' }} - uses: ./.github/workflows/_binary-build-linux.yml - needs: get-label-type - with: - PYTORCH_ROOT: /pytorch - BUILDER_ROOT: /builder - PACKAGE_TYPE: manywheel - # TODO: This is a legacy variable that we eventually want to get rid of in - # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu121 - GPU_ARCH_VERSION: 12.1 - GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.1-main - use_split_build: False - DESIRED_PYTHON: "3.10" - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - build_name: manywheel-py3_10-cuda12_1-full - build_environment: linux-binary-manywheel - secrets: - github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_10-cuda12_1-full-test: # Testing - if: ${{ github.repository_owner == 'pytorch' }} - needs: - - manywheel-py3_10-cuda12_1-full-build - - get-label-type - uses: ./.github/workflows/_binary-test-linux.yml - with: - PYTORCH_ROOT: /pytorch - BUILDER_ROOT: /builder - PACKAGE_TYPE: manywheel - # TODO: This is a legacy variable that we eventually want to get rid of in - # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu121 - GPU_ARCH_VERSION: 12.1 - GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.1-main - use_split_build: False - DESIRED_PYTHON: "3.10" - build_name: manywheel-py3_10-cuda12_1-full - build_environment: linux-binary-manywheel - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - runs_on: linux.4xlarge.nvidia.gpu - secrets: - github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_10-cuda12_1-full-upload: # Uploading - if: ${{ github.repository_owner == 'pytorch' }} - permissions: - id-token: write - contents: read - needs: manywheel-py3_10-cuda12_1-full-test - with: - PYTORCH_ROOT: /pytorch - BUILDER_ROOT: /builder - PACKAGE_TYPE: manywheel - # TODO: This is a legacy variable that we eventually want to get rid of in - # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu121 - GPU_ARCH_VERSION: 12.1 - GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.1-main - use_split_build: False - DESIRED_PYTHON: "3.10" - build_name: manywheel-py3_10-cuda12_1-full - secrets: - github-token: ${{ secrets.GITHUB_TOKEN }} - conda-pytorchbot-token: ${{ secrets.CONDA_PYTORCHBOT_TOKEN }} - conda-pytorchbot-token-test: ${{ secrets.CONDA_PYTORCHBOT_TOKEN_TEST }} - uses: ./.github/workflows/_binary-upload.yml - manywheel-py3_10-cuda12_4-build: if: ${{ github.repository_owner == 'pytorch' }} uses: ./.github/workflows/_binary-build-linux.yml @@ -1766,6 +1696,76 @@ jobs: conda-pytorchbot-token-test: ${{ secrets.CONDA_PYTORCHBOT_TOKEN_TEST }} uses: ./.github/workflows/_binary-upload.yml + manywheel-py3_11-cuda12_1-full-build: + if: ${{ github.repository_owner == 'pytorch' }} + uses: ./.github/workflows/_binary-build-linux.yml + needs: get-label-type + with: + PYTORCH_ROOT: /pytorch + BUILDER_ROOT: /builder + PACKAGE_TYPE: manywheel + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: cu121 + GPU_ARCH_VERSION: 12.1 + GPU_ARCH_TYPE: cuda + DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.1-main + use_split_build: False + DESIRED_PYTHON: "3.11" + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" + build_name: manywheel-py3_11-cuda12_1-full + build_environment: linux-binary-manywheel + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} + manywheel-py3_11-cuda12_1-full-test: # Testing + if: ${{ github.repository_owner == 'pytorch' }} + needs: + - manywheel-py3_11-cuda12_1-full-build + - get-label-type + uses: ./.github/workflows/_binary-test-linux.yml + with: + PYTORCH_ROOT: /pytorch + BUILDER_ROOT: /builder + PACKAGE_TYPE: manywheel + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: cu121 + GPU_ARCH_VERSION: 12.1 + GPU_ARCH_TYPE: cuda + DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.1-main + use_split_build: False + DESIRED_PYTHON: "3.11" + build_name: manywheel-py3_11-cuda12_1-full + build_environment: linux-binary-manywheel + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" + runs_on: linux.4xlarge.nvidia.gpu + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} + manywheel-py3_11-cuda12_1-full-upload: # Uploading + if: ${{ github.repository_owner == 'pytorch' }} + permissions: + id-token: write + contents: read + needs: manywheel-py3_11-cuda12_1-full-test + with: + PYTORCH_ROOT: /pytorch + BUILDER_ROOT: /builder + PACKAGE_TYPE: manywheel + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: cu121 + GPU_ARCH_VERSION: 12.1 + GPU_ARCH_TYPE: cuda + DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.1-main + use_split_build: False + DESIRED_PYTHON: "3.11" + build_name: manywheel-py3_11-cuda12_1-full + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} + conda-pytorchbot-token: ${{ secrets.CONDA_PYTORCHBOT_TOKEN }} + conda-pytorchbot-token-test: ${{ secrets.CONDA_PYTORCHBOT_TOKEN_TEST }} + uses: ./.github/workflows/_binary-upload.yml + manywheel-py3_11-cuda12_4-build: if: ${{ github.repository_owner == 'pytorch' }} uses: ./.github/workflows/_binary-build-linux.yml diff --git a/.github/workflows/generated-linux-binary-manywheel-split-nightly b/.github/workflows/generated-linux-binary-manywheel-split-nightly index 0d29acf56a0835..c3e0dbdd07c198 100644 --- a/.github/workflows/generated-linux-binary-manywheel-split-nightly +++ b/.github/workflows/generated-linux-binary-manywheel-split-nightly @@ -467,76 +467,6 @@ jobs: conda-pytorchbot-token-test: ${{ secrets.CONDA_PYTORCHBOT_TOKEN_TEST }} uses: ./.github/workflows/_binary-upload.yml - manywheel-py3_10-cuda12_1-full-build: - if: ${{ github.repository_owner == 'pytorch' }} - uses: ./.github/workflows/_binary-build-linux.yml - needs: get-label-type - with: - PYTORCH_ROOT: /pytorch - BUILDER_ROOT: /builder - PACKAGE_TYPE: manywheel - # TODO: This is a legacy variable that we eventually want to get rid of in - # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu121 - GPU_ARCH_VERSION: 12.1 - GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.1-main - use_split_build: True - DESIRED_PYTHON: "3.10" - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - build_name: manywheel-py3_10-cuda12_1-full - build_environment: linux-binary-manywheel-split - secrets: - github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_10-cuda12_1-full-test: # Testing - if: ${{ github.repository_owner == 'pytorch' }} - needs: - - manywheel-py3_10-cuda12_1-full-build - - get-label-type - uses: ./.github/workflows/_binary-test-linux.yml - with: - PYTORCH_ROOT: /pytorch - BUILDER_ROOT: /builder - PACKAGE_TYPE: manywheel - # TODO: This is a legacy variable that we eventually want to get rid of in - # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu121 - GPU_ARCH_VERSION: 12.1 - GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.1-main - use_split_build: True - DESIRED_PYTHON: "3.10" - build_name: manywheel-py3_10-cuda12_1-full - build_environment: linux-binary-manywheel-split - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - runs_on: linux.4xlarge.nvidia.gpu - secrets: - github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_10-cuda12_1-full-upload: # Uploading - if: ${{ github.repository_owner == 'pytorch' }} - permissions: - id-token: write - contents: read - needs: manywheel-py3_10-cuda12_1-full-test - with: - PYTORCH_ROOT: /pytorch - BUILDER_ROOT: /builder - PACKAGE_TYPE: manywheel - # TODO: This is a legacy variable that we eventually want to get rid of in - # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu121 - GPU_ARCH_VERSION: 12.1 - GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.1-main - use_split_build: True - DESIRED_PYTHON: "3.10" - build_name: manywheel-py3_10-cuda12_1-full - secrets: - github-token: ${{ secrets.GITHUB_TOKEN }} - conda-pytorchbot-token: ${{ secrets.CONDA_PYTORCHBOT_TOKEN }} - conda-pytorchbot-token-test: ${{ secrets.CONDA_PYTORCHBOT_TOKEN_TEST }} - uses: ./.github/workflows/_binary-upload.yml - manywheel-py3_10-cuda12_4-build: if: ${{ github.repository_owner == 'pytorch' }} uses: ./.github/workflows/_binary-build-linux.yml @@ -817,6 +747,76 @@ jobs: conda-pytorchbot-token-test: ${{ secrets.CONDA_PYTORCHBOT_TOKEN_TEST }} uses: ./.github/workflows/_binary-upload.yml + manywheel-py3_11-cuda12_1-full-build: + if: ${{ github.repository_owner == 'pytorch' }} + uses: ./.github/workflows/_binary-build-linux.yml + needs: get-label-type + with: + PYTORCH_ROOT: /pytorch + BUILDER_ROOT: /builder + PACKAGE_TYPE: manywheel + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: cu121 + GPU_ARCH_VERSION: 12.1 + GPU_ARCH_TYPE: cuda + DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.1-main + use_split_build: True + DESIRED_PYTHON: "3.11" + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" + build_name: manywheel-py3_11-cuda12_1-full + build_environment: linux-binary-manywheel-split + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} + manywheel-py3_11-cuda12_1-full-test: # Testing + if: ${{ github.repository_owner == 'pytorch' }} + needs: + - manywheel-py3_11-cuda12_1-full-build + - get-label-type + uses: ./.github/workflows/_binary-test-linux.yml + with: + PYTORCH_ROOT: /pytorch + BUILDER_ROOT: /builder + PACKAGE_TYPE: manywheel + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: cu121 + GPU_ARCH_VERSION: 12.1 + GPU_ARCH_TYPE: cuda + DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.1-main + use_split_build: True + DESIRED_PYTHON: "3.11" + build_name: manywheel-py3_11-cuda12_1-full + build_environment: linux-binary-manywheel-split + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" + runs_on: linux.4xlarge.nvidia.gpu + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} + manywheel-py3_11-cuda12_1-full-upload: # Uploading + if: ${{ github.repository_owner == 'pytorch' }} + permissions: + id-token: write + contents: read + needs: manywheel-py3_11-cuda12_1-full-test + with: + PYTORCH_ROOT: /pytorch + BUILDER_ROOT: /builder + PACKAGE_TYPE: manywheel + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: cu121 + GPU_ARCH_VERSION: 12.1 + GPU_ARCH_TYPE: cuda + DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.1-main + use_split_build: True + DESIRED_PYTHON: "3.11" + build_name: manywheel-py3_11-cuda12_1-full + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} + conda-pytorchbot-token: ${{ secrets.CONDA_PYTORCHBOT_TOKEN }} + conda-pytorchbot-token-test: ${{ secrets.CONDA_PYTORCHBOT_TOKEN_TEST }} + uses: ./.github/workflows/_binary-upload.yml + manywheel-py3_11-cuda12_4-build: if: ${{ github.repository_owner == 'pytorch' }} uses: ./.github/workflows/_binary-build-linux.yml From b0230d3fdfc792280ad800cf7c4c2243d54169cb Mon Sep 17 00:00:00 2001 From: Laith Sakka Date: Thu, 12 Sep 2024 18:10:03 -0700 Subject: [PATCH 0848/1018] Add gpu and gpu_dynamic versions of add_loop (#135809) I am thinking maybe 3 iterations are enough for this one? - so I am keeping eager and inductor since inductor is 2X eager time - Eager dynamic is 2X eager so keeping this as well. - inductor have three tests. (dynamic gpu, gpu and cpu) I am unsure if am over profiling here happy to trim if anyone have suggestions. ``` collecting compile time instruction count for add_loop_eager compile time instruction count for iteration 0 is 8213664211 compile time instruction count for iteration 1 is 2798628246 compile time instruction count for iteration 2 is 2796811362 compile time instruction count for iteration 3 is 2794438188 compile time instruction count for iteration 4 is 2794634117 collecting compile time instruction count for add_loop_eager_dynamic compile time instruction count for iteration 0 is 5724108021 compile time instruction count for iteration 1 is 5499908609 compile time instruction count for iteration 2 is 5569101366 compile time instruction count for iteration 3 is 5493806364 compile time instruction count for iteration 4 is 5493169851 collecting compile time instruction count for add_loop_inductor compile time instruction count for iteration 0 is 49789381222 compile time instruction count for iteration 1 is 25769347393 compile time instruction count for iteration 2 is 25772594322 compile time instruction count for iteration 3 is 25768695952 compile time instruction count for iteration 4 is 25768032314 collecting compile time instruction count for add_loop_inductor_gpu compile time instruction count for iteration 0 is 23966942581 compile time instruction count for iteration 1 is 23771950919 compile time instruction count for iteration 2 is 23770784286 compile time instruction count for iteration 3 is 23780160875 compile time instruction count for iteration 4 is 23774634465 collecting compile time instruction count for add_loop_inductor_dynamic_gpu compile time instruction count for iteration 0 is 41505055086 compile time instruction count for iteration 1 is 41293654089 compile time instruction count for iteration 2 is 41301016100 compile time instruction count for iteration 3 is 41306056207 compile time instruction count for iteration 4 is 41308171566 ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/135809 Approved by: https://github.com/ezyang, https://github.com/anijain2305 --- .../pr_time_benchmarks/benchmarks/add_loop.py | 43 +++++++++++++------ 1 file changed, 29 insertions(+), 14 deletions(-) diff --git a/benchmarks/dynamo/pr_time_benchmarks/benchmarks/add_loop.py b/benchmarks/dynamo/pr_time_benchmarks/benchmarks/add_loop.py index 9959ab452db62c..f56d12a6d1e410 100644 --- a/benchmarks/dynamo/pr_time_benchmarks/benchmarks/add_loop.py +++ b/benchmarks/dynamo/pr_time_benchmarks/benchmarks/add_loop.py @@ -4,21 +4,29 @@ from benchmark_base import BenchmarkBase import torch +from torch._inductor.utils import fresh_inductor_cache class Benchmark(BenchmarkBase): - def __init__(self, backend): - self.backend = backend + def __init__(self, backend, dynamic=False, is_gpu=False): + self._backend = backend + self._dynamic = dynamic + self._device = "cuda" if is_gpu else "cpu" def name(self): - return f"add_loop_{self.backend}" + prefix = f"add_loop_{self._backend}" + if self._dynamic: + prefix += "_dynamic" + if self._device == "cuda": + prefix += "_gpu" + return prefix def description(self): return "a loop over 100 add node" - def _prepare_once(self, dynamic=True): - self.a = torch.ones(1000) - self.b = torch.torch.ones(1000) + def _prepare_once(self): + self.a = torch.ones(1000, device=self._device) + self.b = torch.torch.ones(1000, device=self._device) def _prepare(self): torch._dynamo.reset() @@ -26,7 +34,7 @@ def _prepare(self): gc.disable() def _work(self): - @torch.compile(backend=self.backend, fullgraph=True) + @torch.compile(backend=self._backend, fullgraph=True, dynamic=self._dynamic) def f(a, b): result = a.clone() for i in range(1000): @@ -38,17 +46,24 @@ def f(a, b): result = result.sin() return result - f(self.a, self.b) + with fresh_inductor_cache(): + f(self.a, self.b) def main(): result_path = sys.argv[1] - Benchmark( - "eager" - ).enable_compile_time_instruction_count().collect_all().append_results(result_path) - Benchmark( - "inductor" - ).enable_compile_time_instruction_count().collect_all().append_results(result_path) + all = [ + Benchmark("eager"), + Benchmark("eager", dynamic=True), + Benchmark("inductor"), + Benchmark("inductor", is_gpu=True), + Benchmark("inductor", is_gpu=True, dynamic=True), + ] + + for benchmark in all: + benchmark.enable_compile_time_instruction_count().collect_all().append_results( + result_path + ) if __name__ == "__main__": From 02b75c1d471dbb09f1e9ce894ffc64fe8d6850e1 Mon Sep 17 00:00:00 2001 From: Laith Sakka Date: Thu, 12 Sep 2024 19:02:50 -0700 Subject: [PATCH 0849/1018] Only measure compile time instruction count for sum_floordiv benchmark (#135785) there was a recent strange noise +5%, -5%. using only compile time : 1) avoid gc time . 2) avoid other operations that are not what we try to measure by this. ==> less probable noise. ``` collecting compile time instruction count for sum_floordiv_regression compile time instruction count for iteration 0 is 8899290248 compile time instruction count for iteration 1 is 1188830489 compile time instruction count for iteration 2 is 1180579615 compile time instruction count for iteration 3 is 1176263131 ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/135785 Approved by: https://github.com/avikchaudhuri, https://github.com/anijain2305 --- .../dynamo/pr_time_benchmarks/benchmarks/sum_floordiv.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/benchmarks/dynamo/pr_time_benchmarks/benchmarks/sum_floordiv.py b/benchmarks/dynamo/pr_time_benchmarks/benchmarks/sum_floordiv.py index 16d2f8ed4eec16..3bd22d18bad00e 100644 --- a/benchmarks/dynamo/pr_time_benchmarks/benchmarks/sum_floordiv.py +++ b/benchmarks/dynamo/pr_time_benchmarks/benchmarks/sum_floordiv.py @@ -32,7 +32,9 @@ def _work(self): def main(): result_path = sys.argv[1] - Benchmark().enable_instruction_count().collect_all().append_results(result_path) + Benchmark().enable_compile_time_instruction_count().collect_all().append_results( + result_path + ) if __name__ == "__main__": From b611999bfa42e8275ba4f6515aea1a5d81ea73f5 Mon Sep 17 00:00:00 2001 From: Laith Sakka Date: Thu, 12 Sep 2024 18:10:36 -0700 Subject: [PATCH 0850/1018] Reduce default iterations to 5 . (#135773) running all benchmarks takes around 15 mins rn, this is the data https://www.internalfb.com/phabricator/paste/view/P1583590240 the data looks mostly stable, and 5 iterations should be good, specially with our 1.5% threshold. that said, the diff also add a way to increase the number of iterations for a specific benchmark. after the change results https://www.internalfb.com/phabricator/paste/view/P1583618969 time is down to half (7 mins) Pull Request resolved: https://github.com/pytorch/pytorch/pull/135773 Approved by: https://github.com/ezyang, https://github.com/anijain2305 --- .../dynamo/pr_time_benchmarks/benchmark_base.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/benchmarks/dynamo/pr_time_benchmarks/benchmark_base.py b/benchmarks/dynamo/pr_time_benchmarks/benchmark_base.py index f15a87b07c10a9..d279922ea1aa69 100644 --- a/benchmarks/dynamo/pr_time_benchmarks/benchmark_base.py +++ b/benchmarks/dynamo/pr_time_benchmarks/benchmark_base.py @@ -60,6 +60,13 @@ class BenchmarkBase(ABC): # TODO is there other parts we need to add ? _enable_compile_time_instruction_count = False + # number of iterations used to run when collecting instruction_count or compile_time_instruction_count. + _num_iterations = 5 + + def with_iterations(self, value): + self._num_iterations = value + return self + def enable_instruction_count(self): self._enable_instruction_count = True return self @@ -88,7 +95,7 @@ def _prepare_once(self): # noqa: B027 def _count_instructions(self): print(f"collecting instruction count for {self.name()}") results = [] - for i in range(10): + for i in range(self._num_iterations): self._prepare() id = i_counter.start() self._work() @@ -102,7 +109,7 @@ def _count_compile_time_instructions(self): config.record_compile_time_instruction_count = True results = [] - for i in range(10): + for i in range(self._num_iterations): self._prepare() # CompileTimeInstructionCounter.record is only called on convert_frame._compile_inner # hence this will only count instruction count spent in compile_inner. From fd28c983da1ff1a3fbaadc4eba605b22de195511 Mon Sep 17 00:00:00 2001 From: Nikita Shulga Date: Fri, 13 Sep 2024 21:29:06 +0000 Subject: [PATCH 0851/1018] Fix bug in split-build workflows codegen (#136043) By just deleting a few rogue lines left out in https://github.com/pytorch/pytorch/pull/135510 If file in workflows folder does not have a `.yml` extensions it will not be launched at all, will it? Pull Request resolved: https://github.com/pytorch/pytorch/pull/136043 Approved by: https://github.com/kit1980, https://github.com/atalman --- .github/scripts/generate_ci_workflows.py | 5 ----- ...-main => generated-linux-binary-manywheel-split-main.yml} | 0 ...ly => generated-linux-binary-manywheel-split-nightly.yml} | 0 3 files changed, 5 deletions(-) rename .github/workflows/{generated-linux-binary-manywheel-split-main => generated-linux-binary-manywheel-split-main.yml} (100%) rename .github/workflows/{generated-linux-binary-manywheel-split-nightly => generated-linux-binary-manywheel-split-nightly.yml} (100%) diff --git a/.github/scripts/generate_ci_workflows.py b/.github/scripts/generate_ci_workflows.py index 302d8263846440..f9c857a3ed9cb9 100755 --- a/.github/scripts/generate_ci_workflows.py +++ b/.github/scripts/generate_ci_workflows.py @@ -79,11 +79,6 @@ def generate_workflow_file(self, workflow_template: jinja2.Template) -> None: GITHUB_DIR / f"workflows/generated-{self.build_environment}-{self.branches}.yml" ) - if self.use_split_build: - output_file_path = ( - GITHUB_DIR - / f"workflows/generated-{self.build_environment}-{self.branches}" - ) with open(output_file_path, "w") as output_file: GENERATED = "generated" # Note that please keep the variable GENERATED otherwise phabricator will hide the whole file output_file.writelines([f"# @{GENERATED} DO NOT EDIT MANUALLY\n"]) diff --git a/.github/workflows/generated-linux-binary-manywheel-split-main b/.github/workflows/generated-linux-binary-manywheel-split-main.yml similarity index 100% rename from .github/workflows/generated-linux-binary-manywheel-split-main rename to .github/workflows/generated-linux-binary-manywheel-split-main.yml diff --git a/.github/workflows/generated-linux-binary-manywheel-split-nightly b/.github/workflows/generated-linux-binary-manywheel-split-nightly.yml similarity index 100% rename from .github/workflows/generated-linux-binary-manywheel-split-nightly rename to .github/workflows/generated-linux-binary-manywheel-split-nightly.yml From de1216bb747ac73f3ed88b3d339d5972f159c467 Mon Sep 17 00:00:00 2001 From: Huy Do Date: Fri, 13 Sep 2024 22:05:10 +0000 Subject: [PATCH 0852/1018] Fix inductor-micro-benchmark results upload (take 2) (#136052) I had a brain freeze when I wrote the original fix. The parameters were in the wrong order. Pull Request resolved: https://github.com/pytorch/pytorch/pull/136052 Approved by: https://github.com/clee2000, https://github.com/kit1980, https://github.com/malfet --- .github/workflows/upload-test-stats.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/upload-test-stats.yml b/.github/workflows/upload-test-stats.yml index 630e971e77532a..f9e5593bf66ff1 100644 --- a/.github/workflows/upload-test-stats.yml +++ b/.github/workflows/upload-test-stats.yml @@ -96,7 +96,7 @@ jobs: python3 -m tools.stats.check_disabled_tests --workflow-run-id "${WORKFLOW_RUN_ID}" --workflow-run-attempt "${WORKFLOW_RUN_ATTEMPT}" --repo "${REPO_FULLNAME}" - name: Upload gpt-fast benchmark results to Rockset - if: steps.upload-s3.outcome && steps.upload-s3.outcome == 'success' && contains('inductor-micro-benchmark', github.event.workflow_run.name) + if: steps.upload-s3.outcome && steps.upload-s3.outcome == 'success' && contains(github.event.workflow_run.name, 'inductor-micro-benchmark') env: ROCKSET_API_KEY: ${{ secrets.ROCKSET_API_KEY }} WORKFLOW_RUN_ID: ${{ github.event.workflow_run.id }} From b74119c3f7966956c4dd636c72e66616fba5de3b Mon Sep 17 00:00:00 2001 From: Jessica Vandebon Date: Fri, 13 Sep 2024 22:13:58 +0000 Subject: [PATCH 0853/1018] [MTIA tensor] allow shallow copy between CPU and MTIA tensors (#135871) Reviewed By: egienvalue, hanzlfs Differential Revision: D61662214 Pull Request resolved: https://github.com/pytorch/pytorch/pull/135871 Approved by: https://github.com/egienvalue, https://github.com/nautsimon --- c10/core/TensorImpl.h | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/c10/core/TensorImpl.h b/c10/core/TensorImpl.h index f1abeb0c33eae4..a8d05dddcfa26c 100644 --- a/c10/core/TensorImpl.h +++ b/c10/core/TensorImpl.h @@ -2034,7 +2034,8 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { BackendComponent::MPSBit, BackendComponent::HIPBit, BackendComponent::XPUBit, - BackendComponent::HPUBit}); + BackendComponent::HPUBit, + BackendComponent::MTIABit}); constexpr auto dense_k = DispatchKeySet(DispatchKey::Dense); return ts.has_any(dense_k) && ts.has_any(dense_backends); }; From a20c3a421e38fc2a56b15fd9322c81860ee6ddcc Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Fri, 13 Sep 2024 22:19:51 +0000 Subject: [PATCH 0854/1018] [ONNX] Remove logging apis from public (#133825) Remove - torch.onnx.enable_log - torch.onnx.disable_log - torch.onnx.set_log_stream - torch.onnx.log Because they are not meant for public consumption and has been marked for deprecation. Pull Request resolved: https://github.com/pytorch/pytorch/pull/133825 Approved by: https://github.com/titaiwangms --- docs/source/onnx_torchscript.rst | 2 -- torch/onnx/__init__.py | 36 -------------------------------- torch/onnx/utils.py | 16 +++++++------- 3 files changed, 8 insertions(+), 46 deletions(-) diff --git a/docs/source/onnx_torchscript.rst b/docs/source/onnx_torchscript.rst index 2009aea813d038..8c8032bd26b4da 100644 --- a/docs/source/onnx_torchscript.rst +++ b/docs/source/onnx_torchscript.rst @@ -702,8 +702,6 @@ Functions .. autofunction:: unregister_custom_op_symbolic .. autofunction:: select_model_mode_for_export .. autofunction:: is_in_onnx_export -.. autofunction:: enable_log -.. autofunction:: disable_log .. autofunction:: torch.onnx.verification.find_mismatch Classes diff --git a/torch/onnx/__init__.py b/torch/onnx/__init__.py index 27c8e2c6240c18..73979c014bf54b 100644 --- a/torch/onnx/__init__.py +++ b/torch/onnx/__init__.py @@ -36,8 +36,6 @@ "select_model_mode_for_export", "register_custom_op_symbolic", "unregister_custom_op_symbolic", - "disable_log", - "enable_log", # Base error "OnnxExporterError", # Dynamo Exporter @@ -517,37 +515,3 @@ def _to_dynamic_shape(x): return dynamo_export( model, *model_args, export_options=export_options, **model_kwargs ) - - -# TODO(justinchuby): Deprecate these logging functions in favor of the new diagnostic module. - -# Returns True iff ONNX logging is turned on. -is_onnx_log_enabled = _C._jit_is_onnx_log_enabled - - -def enable_log() -> None: - r"""Enables ONNX logging.""" - _C._jit_set_onnx_log_enabled(True) - - -def disable_log() -> None: - r"""Disables ONNX logging.""" - _C._jit_set_onnx_log_enabled(False) - - -"""Sets output stream for ONNX logging. - -Args: - stream_name (str, default "stdout"): Only 'stdout' and 'stderr' are supported - as ``stream_name``. -""" -set_log_stream = _C._jit_set_onnx_log_output_stream - - -"""A simple logging facility for ONNX exporter. - -Args: - args: Arguments are converted to string, concatenated together with a newline - character appended to the end, and flushed to output stream. -""" -log = _C._jit_onnx_log diff --git a/torch/onnx/utils.py b/torch/onnx/utils.py index 2b7848e5805845..37d7ee4b35a739 100644 --- a/torch/onnx/utils.py +++ b/torch/onnx/utils.py @@ -139,14 +139,14 @@ def disable_apex_o2_state_dict_hook(model: torch.nn.Module | torch.jit.ScriptFun @contextlib.contextmanager def setup_onnx_logging(verbose: bool): - is_originally_enabled = torch.onnx.is_onnx_log_enabled() - if is_originally_enabled or verbose: - torch.onnx.enable_log() + is_originally_enabled = _C._jit_is_onnx_log_enabled + if is_originally_enabled or verbose: # type: ignore[truthy-function] + _C._jit_set_onnx_log_enabled(True) try: yield finally: - if not is_originally_enabled: - torch.onnx.disable_log() + if not is_originally_enabled: # type: ignore[truthy-function] + _C._jit_set_onnx_log_enabled(False) @contextlib.contextmanager @@ -1125,7 +1125,7 @@ def _model_to_graph( module=module, ) except Exception as e: - torch.onnx.log("Torch IR graph at exception: ", graph) + _C._jit_onnx_log("Torch IR graph at exception: ", graph) raise is_script = isinstance(model, (torch.jit.ScriptFunction, torch.jit.ScriptModule)) @@ -1452,7 +1452,7 @@ def _get_module_attributes(module): try: attrs[k] = getattr(module, k) except AttributeError: - torch.onnx.log(f"Skipping module attribute '{k}'") + _C._jit_onnx_log(f"Skipping module attribute '{k}'") continue return attrs @@ -1642,7 +1642,7 @@ def _export( custom_opsets, ) if verbose: - torch.onnx.log("Exported graph: ", graph) + _C._jit_onnx_log("Exported graph: ", graph) onnx_proto_utils._export_file(proto, f, export_type, export_map) finally: assert GLOBALS.in_onnx_export From ba9a8f799d762cfaef3d6255bfbe8e549cdf1092 Mon Sep 17 00:00:00 2001 From: Zain Rizvi Date: Fri, 13 Sep 2024 22:50:07 +0000 Subject: [PATCH 0855/1018] Cleanup unused runner variants (#136058) Cleaning up unused runner variants, leaving behind only the few that are actually referenced by workflows For more details see description in the PR that generated these code changes: - https://github.com/pytorch/test-infra/pull/5665 Pull Request resolved: https://github.com/pytorch/pytorch/pull/136058 Approved by: https://github.com/wdvr, https://github.com/malfet --- .github/lf-canary-scale-config.yml | 120 ++--------------------------- .github/lf-scale-config.yml | 120 ++--------------------------- 2 files changed, 14 insertions(+), 226 deletions(-) diff --git a/.github/lf-canary-scale-config.yml b/.github/lf-canary-scale-config.yml index b7f7c7d9cd3efd..482b55e04423e9 100644 --- a/.github/lf-canary-scale-config.yml +++ b/.github/lf-canary-scale-config.yml @@ -7,10 +7,14 @@ # runners. Runners listed here will be available as self hosted # runners, configuration is directly pulled from the main branch. # -# NOTE (Apr, 5, 2021): Linux runners are currently all an amazonlinux2 # -# NOTE (Jan 5, 2021): Linux runners are all non-ephemeral to reduce the amount of CreateInstaces calls -# to avoid RequestLimitExceeded issues +# NOTES: +# - Linux runners are by default non-ephemeral to reduce the amount of CreateInstaces calls +# to avoid RequestLimitExceeded issues +# - When updating this file, run the following command to validate the YAML and to generate +# corresponding versions of scale-config for the pytorch/pytorch repo and merge the +# pytorch/pytorch changes before merging these changes. +# `python .github/scripts/validate_scale_config.py --test-infra-repo-root [path_to_test-infra_root] --pytorch-repo-root [path_to_pytorch_root]`` # # TODO: Add some documentation on how the auto-scaling works # @@ -35,8 +39,6 @@ runner_types: variants: amz2023: ami: al2023-ami-2023.5.20240701.0-kernel-6.1-x86_64 - am2: - ami: amzn2-ami-hvm-2.0.20240306.2-x86_64-ebs lf.c.linux.10xlarge.avx2: disk_size: 200 instance_type: m4.10xlarge @@ -44,11 +46,6 @@ runner_types: max_available: 450 os: linux ami: al2023-ami-2023.5.20240701.0-kernel-6.1-x86_64 - variants: - amz2023: - ami: al2023-ami-2023.5.20240701.0-kernel-6.1-x86_64 - am2: - ami: amzn2-ami-hvm-2.0.20240306.2-x86_64-ebs lf.c.linux.24xl.spr-metal: disk_size: 200 instance_type: c7i.metal-24xl @@ -56,11 +53,6 @@ runner_types: max_available: 150 os: linux ami: al2023-ami-2023.5.20240701.0-kernel-6.1-x86_64 - variants: - amz2023: - ami: al2023-ami-2023.5.20240701.0-kernel-6.1-x86_64 - am2: - ami: amzn2-ami-hvm-2.0.20240306.2-x86_64-ebs lf.c.linux.16xlarge.spr: disk_size: 200 instance_type: c7i.16xlarge @@ -68,11 +60,6 @@ runner_types: max_available: 150 os: linux ami: al2023-ami-2023.5.20240701.0-kernel-6.1-x86_64 - variants: - amz2023: - ami: al2023-ami-2023.5.20240701.0-kernel-6.1-x86_64 - am2: - ami: amzn2-ami-hvm-2.0.20240306.2-x86_64-ebs lf.c.linux.9xlarge.ephemeral: disk_size: 200 instance_type: c5.9xlarge @@ -81,8 +68,6 @@ runner_types: os: linux ami: al2023-ami-2023.5.20240701.0-kernel-6.1-x86_64 variants: - amz2023: - ami: al2023-ami-2023.5.20240701.0-kernel-6.1-x86_64 am2: ami: amzn2-ami-hvm-2.0.20240306.2-x86_64-ebs lf.c.linux.12xlarge.ephemeral: @@ -92,11 +77,6 @@ runner_types: max_available: 300 os: linux ami: al2023-ami-2023.5.20240701.0-kernel-6.1-x86_64 - variants: - amz2023: - ami: al2023-ami-2023.5.20240701.0-kernel-6.1-x86_64 - am2: - ami: amzn2-ami-hvm-2.0.20240306.2-x86_64-ebs lf.c.linux.16xlarge.nvidia.gpu: disk_size: 150 instance_type: g3.16xlarge @@ -104,11 +84,6 @@ runner_types: max_available: 150 os: linux ami: al2023-ami-2023.5.20240701.0-kernel-6.1-x86_64 - variants: - amz2023: - ami: al2023-ami-2023.5.20240701.0-kernel-6.1-x86_64 - am2: - ami: amzn2-ami-hvm-2.0.20240306.2-x86_64-ebs lf.c.linux.24xlarge: disk_size: 150 instance_type: c5.24xlarge @@ -116,11 +91,6 @@ runner_types: max_available: 500 os: linux ami: al2023-ami-2023.5.20240701.0-kernel-6.1-x86_64 - variants: - amz2023: - ami: al2023-ami-2023.5.20240701.0-kernel-6.1-x86_64 - am2: - ami: amzn2-ami-hvm-2.0.20240306.2-x86_64-ebs lf.c.linux.24xlarge.ephemeral: disk_size: 150 instance_type: c5.24xlarge @@ -128,11 +98,6 @@ runner_types: max_available: 200 os: linux ami: al2023-ami-2023.5.20240701.0-kernel-6.1-x86_64 - variants: - amz2023: - ami: al2023-ami-2023.5.20240701.0-kernel-6.1-x86_64 - am2: - ami: amzn2-ami-hvm-2.0.20240306.2-x86_64-ebs lf.c.linux.2xlarge: disk_size: 150 instance_type: c5.2xlarge @@ -140,11 +105,6 @@ runner_types: max_available: 3120 os: linux ami: al2023-ami-2023.5.20240701.0-kernel-6.1-x86_64 - variants: - amz2023: - ami: al2023-ami-2023.5.20240701.0-kernel-6.1-x86_64 - am2: - ami: amzn2-ami-hvm-2.0.20240306.2-x86_64-ebs lf.c.linux.4xlarge: disk_size: 150 instance_type: c5.4xlarge @@ -155,8 +115,6 @@ runner_types: variants: amz2023: ami: al2023-ami-2023.5.20240701.0-kernel-6.1-x86_64 - am2: - ami: amzn2-ami-hvm-2.0.20240306.2-x86_64-ebs lf.c.linux.4xlarge.nvidia.gpu: disk_size: 150 instance_type: g3.4xlarge @@ -164,11 +122,6 @@ runner_types: max_available: 1000 os: linux ami: al2023-ami-2023.5.20240701.0-kernel-6.1-x86_64 - variants: - amz2023: - ami: al2023-ami-2023.5.20240701.0-kernel-6.1-x86_64 - am2: - ami: amzn2-ami-hvm-2.0.20240306.2-x86_64-ebs lf.c.linux.8xlarge.nvidia.gpu: disk_size: 150 instance_type: g3.8xlarge @@ -179,8 +132,6 @@ runner_types: variants: amz2023: ami: al2023-ami-2023.5.20240701.0-kernel-6.1-x86_64 - am2: - ami: amzn2-ami-hvm-2.0.20240306.2-x86_64-ebs lf.c.linux.g4dn.12xlarge.nvidia.gpu: disk_size: 150 instance_type: g4dn.12xlarge @@ -188,11 +139,6 @@ runner_types: max_available: 250 os: linux ami: al2023-ami-2023.5.20240701.0-kernel-6.1-x86_64 - variants: - amz2023: - ami: al2023-ami-2023.5.20240701.0-kernel-6.1-x86_64 - am2: - ami: amzn2-ami-hvm-2.0.20240306.2-x86_64-ebs lf.c.linux.g4dn.metal.nvidia.gpu: disk_size: 150 instance_type: g4dn.metal @@ -200,11 +146,6 @@ runner_types: max_available: 300 os: linux ami: al2023-ami-2023.5.20240701.0-kernel-6.1-x86_64 - variants: - amz2023: - ami: al2023-ami-2023.5.20240701.0-kernel-6.1-x86_64 - am2: - ami: amzn2-ami-hvm-2.0.20240306.2-x86_64-ebs lf.c.linux.g5.48xlarge.nvidia.gpu: disk_size: 150 instance_type: g5.48xlarge @@ -212,11 +153,6 @@ runner_types: max_available: 200 os: linux ami: al2023-ami-2023.5.20240701.0-kernel-6.1-x86_64 - variants: - amz2023: - ami: al2023-ami-2023.5.20240701.0-kernel-6.1-x86_64 - am2: - ami: amzn2-ami-hvm-2.0.20240306.2-x86_64-ebs lf.c.linux.g5.12xlarge.nvidia.gpu: disk_size: 150 instance_type: g5.12xlarge @@ -224,11 +160,6 @@ runner_types: max_available: 150 os: linux ami: al2023-ami-2023.5.20240701.0-kernel-6.1-x86_64 - variants: - amz2023: - ami: al2023-ami-2023.5.20240701.0-kernel-6.1-x86_64 - am2: - ami: amzn2-ami-hvm-2.0.20240306.2-x86_64-ebs lf.c.linux.g5.4xlarge.nvidia.gpu: disk_size: 150 instance_type: g5.4xlarge @@ -236,11 +167,6 @@ runner_types: max_available: 2400 os: linux ami: al2023-ami-2023.5.20240701.0-kernel-6.1-x86_64 - variants: - amz2023: - ami: al2023-ami-2023.5.20240701.0-kernel-6.1-x86_64 - am2: - ami: amzn2-ami-hvm-2.0.20240306.2-x86_64-ebs lf.c.linux.g6.4xlarge.experimental.nvidia.gpu: disk_size: 150 instance_type: g6.4xlarge @@ -251,8 +177,6 @@ runner_types: variants: amz2023: ami: al2023-ami-2023.5.20240701.0-kernel-6.1-x86_64 - am2: - ami: amzn2-ami-hvm-2.0.20240306.2-x86_64-ebs lf.c.linux.large: max_available: 1200 disk_size: 15 @@ -260,11 +184,6 @@ runner_types: is_ephemeral: false os: linux ami: al2023-ami-2023.5.20240701.0-kernel-6.1-x86_64 - variants: - amz2023: - ami: al2023-ami-2023.5.20240701.0-kernel-6.1-x86_64 - am2: - ami: amzn2-ami-hvm-2.0.20240306.2-x86_64-ebs lf.c.linux.arm64.2xlarge: disk_size: 256 instance_type: t4g.2xlarge @@ -272,11 +191,6 @@ runner_types: max_available: 200 os: linux ami: al2023-ami-2023.5.20240701.0-kernel-6.1-arm64 - variants: - amz2023: - ami: al2023-ami-2023.5.20240701.0-kernel-6.1-arm64 - am2: - ami: amzn2-ami-hvm-2.0.20240306.2-arm64-gp2 lf.c.linux.arm64.m7g.4xlarge: disk_size: 256 instance_type: m7g.4xlarge @@ -284,11 +198,6 @@ runner_types: max_available: 200 os: linux ami: al2023-ami-2023.5.20240701.0-kernel-6.1-arm64 - variants: - amz2023: - ami: al2023-ami-2023.5.20240701.0-kernel-6.1-arm64 - am2: - ami: amzn2-ami-hvm-2.0.20240306.2-arm64-gp2 lf.c.linux.arm64.2xlarge.ephemeral: disk_size: 256 instance_type: t4g.2xlarge @@ -296,11 +205,6 @@ runner_types: max_available: 200 os: linux ami: al2023-ami-2023.5.20240701.0-kernel-6.1-arm64 - variants: - amz2023: - ami: al2023-ami-2023.5.20240701.0-kernel-6.1-arm64 - am2: - ami: amzn2-ami-hvm-2.0.20240306.2-arm64-gp2 lf.c.linux.arm64.m7g.4xlarge.ephemeral: disk_size: 256 instance_type: m7g.4xlarge @@ -308,11 +212,6 @@ runner_types: max_available: 200 os: linux ami: al2023-ami-2023.5.20240701.0-kernel-6.1-arm64 - variants: - amz2023: - ami: al2023-ami-2023.5.20240701.0-kernel-6.1-arm64 - am2: - ami: amzn2-ami-hvm-2.0.20240306.2-arm64-gp2 lf.c.linux.arm64.m7g.metal: disk_size: 256 instance_type: m7g.metal @@ -320,11 +219,6 @@ runner_types: max_available: 100 os: linux ami: al2023-ami-2023.5.20240701.0-kernel-6.1-arm64 - variants: - amz2023: - ami: al2023-ami-2023.5.20240701.0-kernel-6.1-arm64 - am2: - ami: amzn2-ami-hvm-2.0.20240306.2-arm64-gp2 lf.c.windows.g4dn.xlarge: disk_size: 256 instance_type: g4dn.xlarge diff --git a/.github/lf-scale-config.yml b/.github/lf-scale-config.yml index a0b6e30375be02..7c352157cb464a 100644 --- a/.github/lf-scale-config.yml +++ b/.github/lf-scale-config.yml @@ -7,10 +7,14 @@ # runners. Runners listed here will be available as self hosted # runners, configuration is directly pulled from the main branch. # -# NOTE (Apr, 5, 2021): Linux runners are currently all an amazonlinux2 # -# NOTE (Jan 5, 2021): Linux runners are all non-ephemeral to reduce the amount of CreateInstaces calls -# to avoid RequestLimitExceeded issues +# NOTES: +# - Linux runners are by default non-ephemeral to reduce the amount of CreateInstaces calls +# to avoid RequestLimitExceeded issues +# - When updating this file, run the following command to validate the YAML and to generate +# corresponding versions of scale-config for the pytorch/pytorch repo and merge the +# pytorch/pytorch changes before merging these changes. +# `python .github/scripts/validate_scale_config.py --test-infra-repo-root [path_to_test-infra_root] --pytorch-repo-root [path_to_pytorch_root]`` # # TODO: Add some documentation on how the auto-scaling works # @@ -35,8 +39,6 @@ runner_types: variants: amz2023: ami: al2023-ami-2023.5.20240701.0-kernel-6.1-x86_64 - am2: - ami: amzn2-ami-hvm-2.0.20240306.2-x86_64-ebs lf.linux.10xlarge.avx2: disk_size: 200 instance_type: m4.10xlarge @@ -44,11 +46,6 @@ runner_types: max_available: 450 os: linux ami: al2023-ami-2023.5.20240701.0-kernel-6.1-x86_64 - variants: - amz2023: - ami: al2023-ami-2023.5.20240701.0-kernel-6.1-x86_64 - am2: - ami: amzn2-ami-hvm-2.0.20240306.2-x86_64-ebs lf.linux.24xl.spr-metal: disk_size: 200 instance_type: c7i.metal-24xl @@ -56,11 +53,6 @@ runner_types: max_available: 150 os: linux ami: al2023-ami-2023.5.20240701.0-kernel-6.1-x86_64 - variants: - amz2023: - ami: al2023-ami-2023.5.20240701.0-kernel-6.1-x86_64 - am2: - ami: amzn2-ami-hvm-2.0.20240306.2-x86_64-ebs lf.linux.16xlarge.spr: disk_size: 200 instance_type: c7i.16xlarge @@ -68,11 +60,6 @@ runner_types: max_available: 150 os: linux ami: al2023-ami-2023.5.20240701.0-kernel-6.1-x86_64 - variants: - amz2023: - ami: al2023-ami-2023.5.20240701.0-kernel-6.1-x86_64 - am2: - ami: amzn2-ami-hvm-2.0.20240306.2-x86_64-ebs lf.linux.9xlarge.ephemeral: disk_size: 200 instance_type: c5.9xlarge @@ -81,8 +68,6 @@ runner_types: os: linux ami: al2023-ami-2023.5.20240701.0-kernel-6.1-x86_64 variants: - amz2023: - ami: al2023-ami-2023.5.20240701.0-kernel-6.1-x86_64 am2: ami: amzn2-ami-hvm-2.0.20240306.2-x86_64-ebs lf.linux.12xlarge.ephemeral: @@ -92,11 +77,6 @@ runner_types: max_available: 300 os: linux ami: al2023-ami-2023.5.20240701.0-kernel-6.1-x86_64 - variants: - amz2023: - ami: al2023-ami-2023.5.20240701.0-kernel-6.1-x86_64 - am2: - ami: amzn2-ami-hvm-2.0.20240306.2-x86_64-ebs lf.linux.16xlarge.nvidia.gpu: disk_size: 150 instance_type: g3.16xlarge @@ -104,11 +84,6 @@ runner_types: max_available: 150 os: linux ami: al2023-ami-2023.5.20240701.0-kernel-6.1-x86_64 - variants: - amz2023: - ami: al2023-ami-2023.5.20240701.0-kernel-6.1-x86_64 - am2: - ami: amzn2-ami-hvm-2.0.20240306.2-x86_64-ebs lf.linux.24xlarge: disk_size: 150 instance_type: c5.24xlarge @@ -116,11 +91,6 @@ runner_types: max_available: 500 os: linux ami: al2023-ami-2023.5.20240701.0-kernel-6.1-x86_64 - variants: - amz2023: - ami: al2023-ami-2023.5.20240701.0-kernel-6.1-x86_64 - am2: - ami: amzn2-ami-hvm-2.0.20240306.2-x86_64-ebs lf.linux.24xlarge.ephemeral: disk_size: 150 instance_type: c5.24xlarge @@ -128,11 +98,6 @@ runner_types: max_available: 200 os: linux ami: al2023-ami-2023.5.20240701.0-kernel-6.1-x86_64 - variants: - amz2023: - ami: al2023-ami-2023.5.20240701.0-kernel-6.1-x86_64 - am2: - ami: amzn2-ami-hvm-2.0.20240306.2-x86_64-ebs lf.linux.2xlarge: disk_size: 150 instance_type: c5.2xlarge @@ -140,11 +105,6 @@ runner_types: max_available: 3120 os: linux ami: al2023-ami-2023.5.20240701.0-kernel-6.1-x86_64 - variants: - amz2023: - ami: al2023-ami-2023.5.20240701.0-kernel-6.1-x86_64 - am2: - ami: amzn2-ami-hvm-2.0.20240306.2-x86_64-ebs lf.linux.4xlarge: disk_size: 150 instance_type: c5.4xlarge @@ -155,8 +115,6 @@ runner_types: variants: amz2023: ami: al2023-ami-2023.5.20240701.0-kernel-6.1-x86_64 - am2: - ami: amzn2-ami-hvm-2.0.20240306.2-x86_64-ebs lf.linux.4xlarge.nvidia.gpu: disk_size: 150 instance_type: g3.4xlarge @@ -164,11 +122,6 @@ runner_types: max_available: 1000 os: linux ami: al2023-ami-2023.5.20240701.0-kernel-6.1-x86_64 - variants: - amz2023: - ami: al2023-ami-2023.5.20240701.0-kernel-6.1-x86_64 - am2: - ami: amzn2-ami-hvm-2.0.20240306.2-x86_64-ebs lf.linux.8xlarge.nvidia.gpu: disk_size: 150 instance_type: g3.8xlarge @@ -179,8 +132,6 @@ runner_types: variants: amz2023: ami: al2023-ami-2023.5.20240701.0-kernel-6.1-x86_64 - am2: - ami: amzn2-ami-hvm-2.0.20240306.2-x86_64-ebs lf.linux.g4dn.12xlarge.nvidia.gpu: disk_size: 150 instance_type: g4dn.12xlarge @@ -188,11 +139,6 @@ runner_types: max_available: 250 os: linux ami: al2023-ami-2023.5.20240701.0-kernel-6.1-x86_64 - variants: - amz2023: - ami: al2023-ami-2023.5.20240701.0-kernel-6.1-x86_64 - am2: - ami: amzn2-ami-hvm-2.0.20240306.2-x86_64-ebs lf.linux.g4dn.metal.nvidia.gpu: disk_size: 150 instance_type: g4dn.metal @@ -200,11 +146,6 @@ runner_types: max_available: 300 os: linux ami: al2023-ami-2023.5.20240701.0-kernel-6.1-x86_64 - variants: - amz2023: - ami: al2023-ami-2023.5.20240701.0-kernel-6.1-x86_64 - am2: - ami: amzn2-ami-hvm-2.0.20240306.2-x86_64-ebs lf.linux.g5.48xlarge.nvidia.gpu: disk_size: 150 instance_type: g5.48xlarge @@ -212,11 +153,6 @@ runner_types: max_available: 200 os: linux ami: al2023-ami-2023.5.20240701.0-kernel-6.1-x86_64 - variants: - amz2023: - ami: al2023-ami-2023.5.20240701.0-kernel-6.1-x86_64 - am2: - ami: amzn2-ami-hvm-2.0.20240306.2-x86_64-ebs lf.linux.g5.12xlarge.nvidia.gpu: disk_size: 150 instance_type: g5.12xlarge @@ -224,11 +160,6 @@ runner_types: max_available: 150 os: linux ami: al2023-ami-2023.5.20240701.0-kernel-6.1-x86_64 - variants: - amz2023: - ami: al2023-ami-2023.5.20240701.0-kernel-6.1-x86_64 - am2: - ami: amzn2-ami-hvm-2.0.20240306.2-x86_64-ebs lf.linux.g5.4xlarge.nvidia.gpu: disk_size: 150 instance_type: g5.4xlarge @@ -236,11 +167,6 @@ runner_types: max_available: 2400 os: linux ami: al2023-ami-2023.5.20240701.0-kernel-6.1-x86_64 - variants: - amz2023: - ami: al2023-ami-2023.5.20240701.0-kernel-6.1-x86_64 - am2: - ami: amzn2-ami-hvm-2.0.20240306.2-x86_64-ebs lf.linux.g6.4xlarge.experimental.nvidia.gpu: disk_size: 150 instance_type: g6.4xlarge @@ -251,8 +177,6 @@ runner_types: variants: amz2023: ami: al2023-ami-2023.5.20240701.0-kernel-6.1-x86_64 - am2: - ami: amzn2-ami-hvm-2.0.20240306.2-x86_64-ebs lf.linux.large: max_available: 1200 disk_size: 15 @@ -260,11 +184,6 @@ runner_types: is_ephemeral: false os: linux ami: al2023-ami-2023.5.20240701.0-kernel-6.1-x86_64 - variants: - amz2023: - ami: al2023-ami-2023.5.20240701.0-kernel-6.1-x86_64 - am2: - ami: amzn2-ami-hvm-2.0.20240306.2-x86_64-ebs lf.linux.arm64.2xlarge: disk_size: 256 instance_type: t4g.2xlarge @@ -272,11 +191,6 @@ runner_types: max_available: 200 os: linux ami: al2023-ami-2023.5.20240701.0-kernel-6.1-arm64 - variants: - amz2023: - ami: al2023-ami-2023.5.20240701.0-kernel-6.1-arm64 - am2: - ami: amzn2-ami-hvm-2.0.20240306.2-arm64-gp2 lf.linux.arm64.m7g.4xlarge: disk_size: 256 instance_type: m7g.4xlarge @@ -284,11 +198,6 @@ runner_types: max_available: 200 os: linux ami: al2023-ami-2023.5.20240701.0-kernel-6.1-arm64 - variants: - amz2023: - ami: al2023-ami-2023.5.20240701.0-kernel-6.1-arm64 - am2: - ami: amzn2-ami-hvm-2.0.20240306.2-arm64-gp2 lf.linux.arm64.2xlarge.ephemeral: disk_size: 256 instance_type: t4g.2xlarge @@ -296,11 +205,6 @@ runner_types: max_available: 200 os: linux ami: al2023-ami-2023.5.20240701.0-kernel-6.1-arm64 - variants: - amz2023: - ami: al2023-ami-2023.5.20240701.0-kernel-6.1-arm64 - am2: - ami: amzn2-ami-hvm-2.0.20240306.2-arm64-gp2 lf.linux.arm64.m7g.4xlarge.ephemeral: disk_size: 256 instance_type: m7g.4xlarge @@ -308,11 +212,6 @@ runner_types: max_available: 200 os: linux ami: al2023-ami-2023.5.20240701.0-kernel-6.1-arm64 - variants: - amz2023: - ami: al2023-ami-2023.5.20240701.0-kernel-6.1-arm64 - am2: - ami: amzn2-ami-hvm-2.0.20240306.2-arm64-gp2 lf.linux.arm64.m7g.metal: disk_size: 256 instance_type: m7g.metal @@ -320,11 +219,6 @@ runner_types: max_available: 100 os: linux ami: al2023-ami-2023.5.20240701.0-kernel-6.1-arm64 - variants: - amz2023: - ami: al2023-ami-2023.5.20240701.0-kernel-6.1-arm64 - am2: - ami: amzn2-ami-hvm-2.0.20240306.2-arm64-gp2 lf.windows.g4dn.xlarge: disk_size: 256 instance_type: g4dn.xlarge From 677a68663d2e424076ff8d66d16f8e01f90a5493 Mon Sep 17 00:00:00 2001 From: Menglu Yu Date: Fri, 13 Sep 2024 22:53:08 +0000 Subject: [PATCH 0856/1018] [PT2][Inductor][Optimus] Fix a corner case in remove_split_with_size_one (#135962) Summary: see context in https://fb.workplace.com/groups/1075192433118967/permalink/1501768230461383/ Test Plan: # local reproduce ``` CUDA_VISIBLE_DEVICES=3 OC_CAUSE=1 buck2 run mode/opt //scripts/jackiexu0313/pt2:local_model_with_pt2 -- --test_mode batch-split --model_type "mai" --flow_id 642153776 ``` P1586356950 # e2e before fix f642153776 after fix Differential Revision: D62625318 Pull Request resolved: https://github.com/pytorch/pytorch/pull/135962 Approved by: https://github.com/jackiexu1992 --- torch/_inductor/fx_passes/split_cat.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/torch/_inductor/fx_passes/split_cat.py b/torch/_inductor/fx_passes/split_cat.py index d0098a535ee295..b098781677a84e 100644 --- a/torch/_inductor/fx_passes/split_cat.py +++ b/torch/_inductor/fx_passes/split_cat.py @@ -238,7 +238,9 @@ def remove_split_with_size_one(match: Match, *args, **kwargs): # TODO dynamic_shapes with assume_static_by_default=False fails while AOT Autograd tracing. return # remove the dummy split whose split sections size is one - if len(split_sections) == 1: + # theoretically nodes with no users should be removed, but we have seen the corner case + # thus we add its uers check to walk around the StopIteration error. + if len(split_sections) == 1 and len(split_node.users.keys()) > 0: # find the grand children of the split_node next_users = find_next_users(split_node) user = next(iter(split_node.users.keys())) From 37283a2cfdede375f40a65d81b910abc0229c4b6 Mon Sep 17 00:00:00 2001 From: Jerry Zhang Date: Wed, 11 Sep 2024 21:17:13 -0700 Subject: [PATCH 0857/1018] Fix attr check for quantization spec (#135736) Summary: Previously we only checked dtype and is_dynamic to decide if two quantization spec are equivalent this may not work in some cases, e.g. when people use different qscheme or quant_min/quant_max This PR added checks for other fields as well Test Plan: regression tests Reviewers: Subscribers: Tasks: Tags: Differential Revision: [D62530974](https://our.internmc.facebook.com/intern/diff/D62530974) Pull Request resolved: https://github.com/pytorch/pytorch/pull/135736 Approved by: https://github.com/sxu --- test/quantization/pt2e/test_quantize_pt2e.py | 23 +++++-------- torch/ao/quantization/pt2e/prepare.py | 36 +++++++++++--------- 2 files changed, 27 insertions(+), 32 deletions(-) diff --git a/test/quantization/pt2e/test_quantize_pt2e.py b/test/quantization/pt2e/test_quantize_pt2e.py index 62fbbf7162747e..06a3c56f16dbec 100644 --- a/test/quantization/pt2e/test_quantize_pt2e.py +++ b/test/quantization/pt2e/test_quantize_pt2e.py @@ -600,6 +600,14 @@ def validate(self, model: torch.fx.GraphModule) -> None: def test_fixed_qparams_qspec_observer_dedup(self): class BackendAQuantizer(Quantizer): def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule: + act_qspec = FixedQParamsQuantizationSpec( + dtype=torch.uint8, + quant_min=0, + quant_max=255, + qscheme=torch.per_tensor_affine, + scale=1.0 / 256.0, + zero_point=0, + ) for node in model.graph.nodes: if ( node.op == "call_function" @@ -607,14 +615,6 @@ def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule: ): input_act = node.args[0] assert isinstance(input_act, Node) - act_qspec = FixedQParamsQuantizationSpec( - dtype=torch.uint8, - quant_min=0, - quant_max=255, - qscheme=torch.per_tensor_affine, - scale=1.0 / 256.0, - zero_point=0, - ) node.meta["quantization_annotation"] = QuantizationAnnotation( input_qspec_map={ input_act: act_qspec, @@ -630,13 +630,6 @@ def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule: assert isinstance(input_act, Node) input_act1 = node.args[1] assert isinstance(input_act, Node) - act_qspec = QuantizationSpec( - observer_or_fake_quant_ctr=observer.default_observer, - dtype=torch.uint8, - quant_min=0, - quant_max=255, - qscheme=torch.per_tensor_affine, - ) node.meta["quantization_annotation"] = QuantizationAnnotation( input_qspec_map={ input_act0: act_qspec, diff --git a/torch/ao/quantization/pt2e/prepare.py b/torch/ao/quantization/pt2e/prepare.py index 88ab2be9ff5a83..290e3d771c3230 100644 --- a/torch/ao/quantization/pt2e/prepare.py +++ b/torch/ao/quantization/pt2e/prepare.py @@ -98,20 +98,14 @@ def _unwrap_shared_qspec( return qspec -def _has_same_dtype(qspec_a: QuantizationSpecBase, qspec_b: QuantizationSpecBase): - return ( - hasattr(qspec_a, "dtype") - and hasattr(qspec_b, "dtype") - and qspec_a.dtype == qspec_b.dtype - ) - - -def _has_same_is_dynamic(qspec_a: QuantizationSpecBase, qspec_b: QuantizationSpecBase): +def _has_same_attr( + qspec_a: QuantizationSpecBase, qspec_b: QuantizationSpecBase, attr_name: str +): return ( - hasattr(qspec_a, "is_dynamic") - and hasattr(qspec_b, "is_dynamic") - and qspec_a.is_dynamic == qspec_b.is_dynamic - ) + hasattr(qspec_a, attr_name) + and hasattr(qspec_b, attr_name) + and getattr(qspec_a, attr_name) == getattr(qspec_b, attr_name) + ) or (not hasattr(qspec_a, attr_name) and not hasattr(qspec_b, attr_name)) def _get_edge_or_node_to_qspec( @@ -148,10 +142,18 @@ def _union_input_edge_with( qspec = edge_or_node_to_qspec[edge_or_node] root_qspec = _unwrap_shared_qspec(qspec, edge_or_node_to_qspec, shared_with_map) # TODO: add assertions for types of root qspecs - if ( - root_qspec is not None - and _has_same_dtype(root_qspec, input_edge_root_qspec) - and _has_same_is_dynamic(root_qspec, input_edge_root_qspec) + if root_qspec is not None and all( + _has_same_attr(root_qspec, input_edge_root_qspec, attr) + for attr in [ + "dtype", + "is_dynamic", + "quant_min", + "quant_max", + "qscheme", + "ch_axis", + "scale", + "zero_point", + ] ): # the input arg to the node should reuse the existing output observer for arg # since dtype is the same (we may want to extend this to be a more strict check From df40c424f11cf46c654ae40d2d922d686465afb6 Mon Sep 17 00:00:00 2001 From: eellison Date: Fri, 13 Sep 2024 09:41:53 -0700 Subject: [PATCH 0858/1018] [Easy] Dont match to mm_plus_mm if not in max autotune (#135929) It's only an optimization when we tune the triton template. Pull Request resolved: https://github.com/pytorch/pytorch/pull/135929 Approved by: https://github.com/FindHao --- test/inductor/test_pattern_matcher.py | 1 + torch/_inductor/fx_passes/post_grad.py | 3 +++ 2 files changed, 4 insertions(+) diff --git a/test/inductor/test_pattern_matcher.py b/test/inductor/test_pattern_matcher.py index 5f34b9ceb69a25..e63e8fcfb4e96d 100644 --- a/test/inductor/test_pattern_matcher.py +++ b/test/inductor/test_pattern_matcher.py @@ -66,6 +66,7 @@ def common( additional_check(codes) counters.clear() + @inductor_config.patch(max_autotune_gemm=True) def test_mm_plus_mm(self): def fn(a, b, c, d): return torch.add(torch.mm(a, b), torch.mm(c, d)) diff --git a/torch/_inductor/fx_passes/post_grad.py b/torch/_inductor/fx_passes/post_grad.py index 640c0a8ffa18bd..064c6de94aed61 100644 --- a/torch/_inductor/fx_passes/post_grad.py +++ b/torch/_inductor/fx_passes/post_grad.py @@ -211,6 +211,9 @@ def register_lowering_pattern(pattern, extra_check=_return_true, pass_number=1): def is_valid_mm_plus_mm(match: Match): + if not torch._inductor.utils.use_max_autotune(): + return False + *b1, m1, k1 = match.kwargs["mat1"].meta.get("tensor_meta").shape *b2, k2, n1 = match.kwargs["mat2"].meta.get("tensor_meta").shape if k1 != k2: From 3212730ef7f3bfdbb5291a27865cb16182864e7d Mon Sep 17 00:00:00 2001 From: William Wen Date: Fri, 13 Sep 2024 13:02:33 -0700 Subject: [PATCH 0859/1018] [3.13] fix 3.13 pickle error in serialization.py (#136034) Error encountered when adding dynamo 3.13 support. Pull Request resolved: https://github.com/pytorch/pytorch/pull/136034 Approved by: https://github.com/albanD --- torch/serialization.py | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/torch/serialization.py b/torch/serialization.py index d936d31d6f5209..d937680c031c78 100644 --- a/torch/serialization.py +++ b/torch/serialization.py @@ -1005,8 +1005,12 @@ def persistent_id(obj: Any) -> Optional[Tuple]: pickle_module.dump(MAGIC_NUMBER, f, protocol=pickle_protocol) pickle_module.dump(PROTOCOL_VERSION, f, protocol=pickle_protocol) pickle_module.dump(sys_info, f, protocol=pickle_protocol) - pickler = pickle_module.Pickler(f, protocol=pickle_protocol) - pickler.persistent_id = persistent_id + + class PyTorchLegacyPickler(pickle_module.Pickler): + def persistent_id(self, obj): + return persistent_id(obj) + + pickler = PyTorchLegacyPickler(f, protocol=pickle_protocol) pickler.dump(obj) serialized_storage_keys = sorted(serialized_storages.keys()) @@ -1083,8 +1087,12 @@ def persistent_id(obj): # Write the pickle data for `obj` data_buf = io.BytesIO() - pickler = pickle_module.Pickler(data_buf, protocol=pickle_protocol) - pickler.persistent_id = persistent_id + + class PyTorchPickler(pickle_module.Pickler): # type: ignore[name-defined] + def persistent_id(self, obj): + return persistent_id(obj) + + pickler = PyTorchPickler(data_buf, protocol=pickle_protocol) pickler.dump(obj) data_value = data_buf.getvalue() zip_file.write_record("data.pkl", data_value, len(data_value)) From 546a0ebb222063328b045ac5292076d808ce740c Mon Sep 17 00:00:00 2001 From: Ke Wen Date: Fri, 13 Sep 2024 11:38:48 -0700 Subject: [PATCH 0860/1018] [Distributed] add pack-check method for float8_e4m3fn (#135961) We check 8 x FP8 simultaneously, at size of 8 bytes. Pull Request resolved: https://github.com/pytorch/pytorch/pull/135961 Approved by: https://github.com/yifuwang, https://github.com/Skylion007 ghstack dependencies: #135891 --- torch/csrc/distributed/c10d/NanCheck.cu | 38 ++++++++++++++++++++++++- 1 file changed, 37 insertions(+), 1 deletion(-) diff --git a/torch/csrc/distributed/c10d/NanCheck.cu b/torch/csrc/distributed/c10d/NanCheck.cu index 0aab4ca809260d..02a4b6e963a4b7 100644 --- a/torch/csrc/distributed/c10d/NanCheck.cu +++ b/torch/csrc/distributed/c10d/NanCheck.cu @@ -6,6 +6,7 @@ #include #include #include +#include namespace c10d { @@ -13,7 +14,12 @@ namespace c10d { // is raised if NAN is found // Using ulong2 as a "byte pack", with 16 bytes, for efficient data load -typedef ulong2 BytePack; +union BytePack16 { + ulong2 ul2; + uint64_t ul[2]; +}; + +typedef union BytePack16 BytePack; // AMD HIP doesn't define `__trap()`, using `assert` instead #ifdef USE_ROCM @@ -73,6 +79,36 @@ struct CheckBytePack { } }; +// (v) Template specialization for Float8_e4m3fn. +// EltPerPack = 16 / 1 = 16 + +// isnan condition for Float8_e4m3fn: +// (x & 0b01111111) == 0b01111111 +// i.e. +// (x & 0x7f) == 0x7f + +// We want to check 8 x FP8 simultaneously. The algorithm is as follows: +// (1) Mask out the most significant bit with mask 0x7f. +// (2) If the result is 0x7f (is nan), the following arithmetic would cause the +// 8th bit to be 1: x[i] = x[i] + 0x01 +// (3) Only leave the 8th bit by masking with 0x80. +// (4) If any x[i] is nan, then the whole x != 0. + +template<> +struct CheckBytePack { + static __device__ __forceinline__ bool hasNanFP8x8(uint64_t fp8x8) { + auto t = fp8x8 & 0x7F7F7F7F7F7F7F7FULL; + auto incremented = t + 0x0101010101010101ULL; + auto overflow = incremented & 0x8080808080808080ULL; + return overflow != 0; + } + + static __device__ __forceinline__ void check(BytePack* tmp) { + if (hasNanFP8x8(tmp->ul[0]) || hasNanFP8x8(tmp->ul[1])) + __trap(); + } +}; + //// End of templated functions for checking NaNs inside a BytePack From 3f9ead0abe5c4e5a164995a0e69f0121c04a2c17 Mon Sep 17 00:00:00 2001 From: Nikita Shulga Date: Sat, 14 Sep 2024 00:35:37 +0000 Subject: [PATCH 0861/1018] [BE] Use squeeze/unsqueeze in im2col (#136006) And move unsqeeze out of the dispatch, as it's dtype agnostic Pull Request resolved: https://github.com/pytorch/pytorch/pull/136006 Approved by: https://github.com/Skylion007, https://github.com/eqy --- aten/src/ATen/native/cuda/Col2Im.cu | 8 ++++---- aten/src/ATen/native/cuda/Im2Col.cu | 8 ++++---- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/aten/src/ATen/native/cuda/Col2Im.cu b/aten/src/ATen/native/cuda/Col2Im.cu index e390666b307aa1..4d8f4935039165 100644 --- a/aten/src/ATen/native/cuda/Col2Im.cu +++ b/aten/src/ATen/native/cuda/Col2Im.cu @@ -91,7 +91,7 @@ void col2im_out_cuda_template( if (input.dim() == 2) { // Force batch batched_input = false; - input = input.view({1, input.size(0), input.size(1)}); + input = input.unsqueeze(0); } int64_t batch_size = input.size(0); @@ -134,10 +134,10 @@ void col2im_out_cuda_template( output.mutable_data_ptr(), output_batch_stride); - if (!batched_input) { - output.resize_({n_output_plane, output_height, output_width}); - } }); + if (!batched_input) { + output = output.squeeze(0); + } } } // namespace diff --git a/aten/src/ATen/native/cuda/Im2Col.cu b/aten/src/ATen/native/cuda/Im2Col.cu index 796eda97b37331..d74a5f5f641a06 100644 --- a/aten/src/ATen/native/cuda/Im2Col.cu +++ b/aten/src/ATen/native/cuda/Im2Col.cu @@ -81,7 +81,7 @@ static void im2col_out_cuda_template( if (input.dim() == 3) { batched_input = false; - input = input.view({1, input.size(0), input.size(1), input.size(2)}); + input = input.unsqueeze(0); } int64_t batch_size = input.size(0); @@ -131,10 +131,10 @@ static void im2col_out_cuda_template( output_n.mutable_data_ptr()); } - if (!batched_input) { - output.resize_({n_output_plane, output_length}); - } }); + if (!batched_input) { + output = output.squeeze(0); + } } } // namespace From 9d3f7ead8bf97a04dc53a3ce71b35f91b93d8b86 Mon Sep 17 00:00:00 2001 From: Joel Schlosser Date: Fri, 13 Sep 2024 16:29:23 -0400 Subject: [PATCH 0862/1018] Fix sum() forward for NJT (#131945) This PR solves two problems with `sum()` support in NJT: * `sum()` over a dim with `keepdim=True` returns the wrong shape (i.e. it'll keep the wrong dim). This is a long-standing bug from way back in #112519. * Historically, we've only supported `sum()` over a dim and not a full reduction. This PR adds the full reduction form (forward only, backward still fails). Pull Request resolved: https://github.com/pytorch/pytorch/pull/131945 Approved by: https://github.com/davidberard98, https://github.com/jananisriram --- test/test_nestedtensor.py | 15 +++++++++++---- torch/nested/_internal/ops.py | 18 +++++++++++++++--- 2 files changed, 26 insertions(+), 7 deletions(-) diff --git a/test/test_nestedtensor.py b/test/test_nestedtensor.py index 82db5d038bee71..408c9b4559dd2c 100644 --- a/test/test_nestedtensor.py +++ b/test/test_nestedtensor.py @@ -4337,12 +4337,14 @@ def test_op_dim_reduce_ragged_idx_1_different_output_shape( out_expected = torch.cat( [func(t, dim=(reduce_dim[0] - 1)).unsqueeze(0) for t in nt.unbind()] ) + if keepdim: + out_expected = out_expected.unsqueeze(reduce_dim[0]) self.assertFalse( out_actual.is_nested, f"{op_name}(): the result of reducing a nested tensor along the ragged dimension is a dense tensor", ) # output is a dense tensor - self.assertTrue(torch.allclose(out_actual, out_expected)) + self.assertEqual(out_actual, out_expected) @dtypes(torch.float32) @parametrize("requires_grad", [False, True]) @@ -4602,12 +4604,14 @@ def test_op_dim_reduce_ragged_idx_greater_than_1_different_output_shape( for t in nt_transposed.unbind() ] ) + if keepdim: + out_expected = out_expected.unsqueeze(reduce_dim[0]) self.assertFalse( out_actual.is_nested, f"{op_name}(): the result of reducing a nested tensor along the ragged dimension is a dense tensor", ) # output is a dense tensor - self.assertTrue(torch.allclose(out_actual, out_expected, rtol=1e-4)) + self.assertEqual(out_actual, out_expected) @dtypes(torch.float32) @parametrize( @@ -7407,8 +7411,6 @@ def g(nt): "jiterator_binary", "jiterator_binary_return_by_ref", "jiterator_unary", - # Bug found: sum() with keepdim=True returns invalid shape - "sum", # RuntimeError: prod(): keepdim=True must be set for NestedTensor "prod", # RuntimeError: "jagged_to_padded_dense" not implemented for 'Bool' @@ -7442,6 +7444,8 @@ def g(nt): "clone", # Calling into torch.ops.aten.size directly "masked_select", + # NotImplementedError: aten._nested_sum_backward.default. Need to fix the backward pass. + "sum", } COMPILE_FORWARD_FAILURES = { @@ -7449,6 +7453,9 @@ def g(nt): # clone() on non-contiguous with holes NJTs currently use unbind(), leading to # data-dependent error in torch.compile "clone", + # torch._dynamo.exc.Unsupported: data dependent operator: aten._local_scalar_dense.default + # for inputs where min / max seqlen are not cached + "sum", } COMPARE_TENSOR_COMPONENT_EQUALITY = { diff --git a/torch/nested/_internal/ops.py b/torch/nested/_internal/ops.py index f8053d3bcf7baf..9b7544ec2fc06d 100644 --- a/torch/nested/_internal/ops.py +++ b/torch/nested/_internal/ops.py @@ -1017,6 +1017,17 @@ def is_same_size_default(func, *args, **kwargs): return args[0]._size == args[1]._size +@register_jagged_func(torch.ops.aten.sum.default, "self: jt_all, dtype: any?") +def sum_default(func, *args, **kwargs): + _, new_kwargs = normalize_function( # type: ignore[misc] + func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True + ) + + inp = new_kwargs.pop("input") + + return func(inp._values, **new_kwargs) + + @register_jagged_func( torch.ops.aten.sum.dim_IntList, "self: jt_all, dim: any?, keepdim: any?, dtype: any?", @@ -1097,8 +1108,7 @@ def sum_dim_IntList(func, *args, **kwargs): ) # need to read offsets --> pad jagged dimension and apply sum if new_kwargs["keepdim"]: - # TODO: Fix this; it's a bug. should be unsqueezing on ragged_idx - out = out.unsqueeze(0) + out = out.unsqueeze(inp._ragged_idx) return out else: # raggedness preserved --> return nested tensor if ( @@ -1463,8 +1473,10 @@ def mean_dim(func, *args, **kwargs): # for every non-batch dimension, # unsqueeze lengths into the same shape as the PyTorch sum, # as the extra dimensions must all be divided by the same length + # Note: keepdim=True is on at this point so lengths has to be unsqueezed for + # that 1-size dim as well. lengths = inp._offsets.diff() - for _ in range(inp.dim() - 2): + for _ in range(inp.dim() - 1): lengths = lengths.unsqueeze(-1) return torch_sum / lengths.broadcast_to(torch_sum.shape) From 154834d39b79ad19ef374ee2d04244c9fc26a62a Mon Sep 17 00:00:00 2001 From: "Li, Xingyuan" Date: Sat, 14 Sep 2024 01:43:05 +0000 Subject: [PATCH 0863/1018] [Inductor UT] Generalize inductor UT for intel GPU (Part 3) (#135827) [Inductor UT] Reuse Inductor test case for Intel GPU. Reuse `test/inductor/test_compiled_autograd.py` Pull Request resolved: https://github.com/pytorch/pytorch/pull/135827 Approved by: https://github.com/etaf, https://github.com/desertfire --- test/inductor/test_compiled_autograd.py | 49 +++++++++++++++---------- torch/_dynamo/device_interface.py | 6 +++ 2 files changed, 35 insertions(+), 20 deletions(-) diff --git a/test/inductor/test_compiled_autograd.py b/test/inductor/test_compiled_autograd.py index 672ea6bb17b74b..3e4fb849494936 100644 --- a/test/inductor/test_compiled_autograd.py +++ b/test/inductor/test_compiled_autograd.py @@ -20,11 +20,12 @@ from torch import _inductor as inductor from torch._dynamo import compiled_autograd, config from torch._dynamo.backends.debugging import aot_eager +from torch._dynamo.device_interface import get_interface_for_device from torch._dynamo.utils import counters from torch._inductor import config as inductor_config from torch._inductor.test_case import run_tests, TestCase from torch.testing._internal.common_utils import skipIfWindows -from torch.testing._internal.inductor_utils import HAS_CPU, HAS_CUDA +from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_CPU, HAS_CUDA, HAS_GPU from torch.testing._internal.logging_utils import logs_to_string @@ -810,9 +811,9 @@ def fn(): self.check_output_and_recompiles(fn, count=1) - @unittest.skipIf(not HAS_CUDA, "requires cuda") + @unittest.skipIf(not HAS_GPU, "requires gpu") def test_issue106555(self): - DEVICE = torch.device("cuda:0") + DEVICE = torch.device(GPU_TYPE, 0) NUM_FEATURES = 256 def bias_sigmoid_mul(x1, x2, bias): @@ -855,7 +856,8 @@ def _forward(self, x): x = x + self.module_with_jit_2(x.transpose(-2, -3)).transpose(-2, -3) return x - torch.cuda.set_device(device=DEVICE) + device_interface = get_interface_for_device(GPU_TYPE) + device_interface.set_device(device=DEVICE) torch.manual_seed(1234567890) model = Model() model.train() @@ -1075,7 +1077,7 @@ def backward(ctx, gO_1, gO_2, gO_3): self.check_output_and_recompiles(fn, count=2) - @unittest.skipIf(not HAS_CUDA, "requires cuda") + @unittest.skipIf(not HAS_GPU, "requires gpu") def test_logging_tensor_flaky(self) -> None: # when you first run some test using triton and then run test_inputs_aliasing_bytecode_stack_restore # resulting in: @@ -1088,7 +1090,7 @@ def _fn(x): return x x = torch.arange( - 1, 10, requires_grad=True, dtype=torch.float16, device="cuda" + 1, 10, requires_grad=True, dtype=torch.float16, device=GPU_TYPE ) out = _fn(x) loss = out.sum() @@ -1121,7 +1123,7 @@ def forward(inputs): compiled_fn(inputs) - @unittest.skipIf(not HAS_CUDA, "requires cuda") + @unittest.skipIf(not HAS_GPU, "requires gpu") def test_custom_fn_output_metadata(self): def my_compiler_fn(gm): for node in gm.graph.nodes: @@ -1149,7 +1151,7 @@ def backward(ctx, gO): return gO x = torch.arange( - 1, 10, requires_grad=True, dtype=torch.float16, device="cuda" + 1, 10, requires_grad=True, dtype=torch.float16, device=GPU_TYPE ) x_view = x.view(3, 3) out = MyFn.apply(x_view) @@ -2091,17 +2093,20 @@ def fn(): self.check_output_and_recompiles(fn, 3) - @unittest.skipIf(not HAS_CUDA, "requires cuda") + @unittest.skipIf(not HAS_GPU, "requires gpu") def test_free_activation_memory(self): script = """ import torch +from torch._dynamo.device_interface import get_interface_for_device +from torch.testing._internal.inductor_utils import GPU_TYPE def main(): - assert(torch.cuda.memory_allocated() == 0) + device_interface = get_interface_for_device(GPU_TYPE) + assert(device_interface.memory_allocated() == 0) # Use an op to check that the memory is freed by the time the op is executed def assertion_impl(to_clone): - mem_allocated = torch.cuda.memory_allocated() + mem_allocated = device_interface.memory_allocated() assert mem_allocated < 4000000 # some activations should be freed return to_clone.clone() @@ -2124,8 +2129,8 @@ def forward(activations): compiled_fn = torch.compile(gm) # allocate at least 4,000,000 bytes (1,000,000 * 4 bytes) - activations = [torch.ones(1000000, dtype=torch.float32, device="cuda")] - assert torch.cuda.memory_allocated() > 4000000 + activations = [torch.ones(1000000, dtype=torch.float32, device=GPU_TYPE)] + assert device_interface.memory_allocated() > 4000000 out = compiled_fn(activations) assert len(activations) == 0 @@ -2134,19 +2139,22 @@ def forward(activations): """ self.run_as_subprocess(script) - @unittest.skipIf(not HAS_CUDA, "requires cuda") + @unittest.skipIf(not HAS_GPU, "requires gpu") def test_free_activation_memory_subclass(self): # cover the case when aot inputs have subclasses, resulting in a different runtime wrapper script = """ import torch +from torch._dynamo.device_interface import get_interface_for_device +from torch.testing._internal.inductor_utils import GPU_TYPE def main(): - assert torch.cuda.memory_allocated() == 0 + device_interface = get_interface_for_device(GPU_TYPE) + assert device_interface.memory_allocated() == 0 # Use an op to check that the memory is freed by the time the op is executed def assertion_impl(to_clone): - mem_allocated = torch.cuda.memory_allocated() + mem_allocated = device_interface.memory_allocated() assert mem_allocated < 1200000 # some activations should be freed assert mem_allocated > 800000 # currently subclasses don't seem to be freed in inductor return to_clone.clone() @@ -2174,23 +2182,24 @@ def fn(inputs): activations = [ jagged_from_list( [ - torch.ones((1, 100000), device="cuda"), # 400,000 bytes - torch.ones((1, 100000), device="cuda"), # 400,000 bytes + torch.ones((1, 100000), device=GPU_TYPE), # 400,000 bytes + torch.ones((1, 100000), device=GPU_TYPE), # 400,000 bytes ], None, )[ 0 ], # NestedTensor - torch.ones((1, 100000), device="cuda"), # 400,000 bytes + torch.ones((1, 100000), device=GPU_TYPE), # 400,000 bytes ] # 1,200,000 bytes (3 * 4 * 100,000 bytes) - assert torch.cuda.memory_allocated() > 1200000 + assert device_interface.memory_allocated() > 1200000 out = compiled_fn(activations) assert len(activations) == 0 main() """ + self.run_as_subprocess(script) def test_callback_graph_break_throws_error(self): called = [0] diff --git a/torch/_dynamo/device_interface.py b/torch/_dynamo/device_interface.py index 5670172c49c52c..00975cda550818 100644 --- a/torch/_dynamo/device_interface.py +++ b/torch/_dynamo/device_interface.py @@ -123,6 +123,10 @@ def get_compute_capability(device: _device_t = None): def is_bf16_supported(including_emulation: bool = False): raise NotImplementedError + @staticmethod + def memory_allocated(device: _device_t = None) -> int: + raise NotImplementedError + class DeviceGuard: """ @@ -202,6 +206,7 @@ def get_device_properties(device: _device_t = None): get_raw_stream = staticmethod(get_cuda_stream) # type: ignore[assignment, arg-type] exchange_device = staticmethod(torch.cuda._exchange_device) # type: ignore[arg-type] maybe_exchange_device = staticmethod(torch.cuda._maybe_exchange_device) # type: ignore[arg-type] + memory_allocated = staticmethod(torch.cuda.memory_allocated) is_bf16_supported = staticmethod(torch.cuda.is_bf16_supported) # type: ignore[arg-type] # Can be mock patched by @patch decorator. @@ -273,6 +278,7 @@ def get_device_properties(device: _device_t = None): get_raw_stream = staticmethod(get_xpu_stream) # type: ignore[assignment, arg-type] exchange_device = staticmethod(torch.xpu._exchange_device) # type: ignore[arg-type] maybe_exchange_device = staticmethod(torch.xpu._maybe_exchange_device) # type: ignore[arg-type] + memory_allocated = staticmethod(torch.xpu.memory_allocated) # Can be mock patched by @patch decorator. @staticmethod From fe70664dc0f4eee996d39c8f436aac3e869cfbc2 Mon Sep 17 00:00:00 2001 From: Nikita Shulga <2453524+malfet@users.noreply.github.com> Date: Sat, 14 Sep 2024 01:51:36 +0000 Subject: [PATCH 0864/1018] [CI] Check that PyTorch is built with OpenMP (#136060) Restriction for x86 only builds should have been removed long time ago Pull Request resolved: https://github.com/pytorch/pytorch/pull/136060 Approved by: https://github.com/clee2000, https://github.com/kit1980, https://github.com/ZainRizvi --- .ci/pytorch/macos-test.sh | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/.ci/pytorch/macos-test.sh b/.ci/pytorch/macos-test.sh index a54b8c360eba56..5fa57485a0a7a1 100755 --- a/.ci/pytorch/macos-test.sh +++ b/.ci/pytorch/macos-test.sh @@ -9,15 +9,13 @@ if [[ -n "$CONDA_ENV" ]]; then export PATH="$CONDA_ENV/bin":$PATH fi -# Test that OpenMP is enabled for non-arm64 build -if [[ ${BUILD_ENVIRONMENT} != *arm64* ]]; then - pushd test - if [[ ! $(python -c "import torch; print(int(torch.backends.openmp.is_available()))") == "1" ]]; then - echo "Build should have OpenMP enabled, but torch.backends.openmp.is_available() is False" - exit 1 - fi - popd +# Test that OpenMP is enabled +pushd test +if [[ ! $(python -c "import torch; print(int(torch.backends.openmp.is_available()))") == "1" ]]; then + echo "Build should have OpenMP enabled, but torch.backends.openmp.is_available() is False" + exit 1 fi +popd setup_test_python() { # The CircleCI worker hostname doesn't resolve to an address. From 73e5c700d5fa58329afe5cf46d198a8ccbf3b053 Mon Sep 17 00:00:00 2001 From: "Yu, Guangye" Date: Tue, 10 Sep 2024 14:13:47 +0000 Subject: [PATCH 0865/1018] Fix tensor.data_ptr() representation overflow (#135567) # Motivation fix https://github.com/pytorch/pytorch/issues/135550 In PyTorch, [`tensor.data_ptr()`](https://github.com/pytorch/pytorch/blob/e889252493558a56263618faae9a9ef421c2a47d/tools/autograd/templates/python_variable_methods.cpp#L204) is reinterpreted by a [signed int64](https://github.com/pytorch/pytorch/blob/e889252493558a56263618faae9a9ef421c2a47d/torch/csrc/autograd/utils/wrap_outputs.h#L50) data type, which could result in an **overflow issue**, like below: ```python import torch a = torch.randn(2).to('xpu') a.data_ptr() # one possible output is -23453392437248 # this is inconsistent with storage.data_ptr() a.untyped_storage().data_ptr() # one possible output is 18446720620317114368 ``` This PR aims to fix this representation overflow issue to make `tensor.data_ptr()` consistent with [`tensor.untyped_storage().data_ptr()`](https://github.com/pytorch/pytorch/blob/c0d2f991b14d50f8081d788d4a3dc6584ee15502/torch/csrc/StorageMethods.cpp#L62). With this PR, the output will become: ```python import torch a = torch.randn(2).to('xpu') a.data_ptr() # one possible output is 18446720620317114368 # this is consistent with storage.data_ptr() a.untyped_storage().data_ptr() # one possible output is 18446720620317114368 ``` # Solution Use `PyLong_FromVoidPtr` to prevent the overflow issue and fit the semantic of `wrap`. Pull Request resolved: https://github.com/pytorch/pytorch/pull/135567 Approved by: https://github.com/dvrogozh, https://github.com/EikanWang, https://github.com/albanD --- torch/csrc/StorageMethods.cpp | 5 +---- torch/csrc/autograd/utils/wrap_outputs.h | 2 +- 2 files changed, 2 insertions(+), 5 deletions(-) diff --git a/torch/csrc/StorageMethods.cpp b/torch/csrc/StorageMethods.cpp index 19f02b464530d0..a9d3f64f914550 100644 --- a/torch/csrc/StorageMethods.cpp +++ b/torch/csrc/StorageMethods.cpp @@ -49,9 +49,6 @@ static PyObject* THPStorage_nbytes(PyObject* self, PyObject* noargs) { static PyObject* THPStorage_dataPtr(PyObject* self, PyObject* noargs) { HANDLE_TH_ERRORS - // PyLong_FromVoidPtr should not need to mutate the pointer in order - // to extract a new long object from it. - auto self_ = THPStorage_Unpack(self); // See Note [Invalid Python Storages] auto invalid = self_.data() == nullptr && @@ -59,7 +56,7 @@ static PyObject* THPStorage_dataPtr(PyObject* self, PyObject* noargs) { TORCH_CHECK( !invalid, "Attempted to access the data pointer on an invalid python storage.") - return PyLong_FromVoidPtr(self_.mutable_data()); + return torch::autograd::utils::wrap(self_.mutable_data()); END_HANDLE_TH_ERRORS } diff --git a/torch/csrc/autograd/utils/wrap_outputs.h b/torch/csrc/autograd/utils/wrap_outputs.h index 5d1d17c597af32..72d7a6c76d7441 100644 --- a/torch/csrc/autograd/utils/wrap_outputs.h +++ b/torch/csrc/autograd/utils/wrap_outputs.h @@ -47,7 +47,7 @@ inline PyObject* wrap(c10::complex value) { } inline PyObject* wrap(void* value) { - return THPUtils_packInt64(reinterpret_cast(value)); + return PyLong_FromVoidPtr(value); } inline PyObject* wrap(THPDtype* dtype) { From 3f2693f41f0c6038999071404f958958f5a2c179 Mon Sep 17 00:00:00 2001 From: CaoE Date: Sat, 14 Sep 2024 02:20:56 +0000 Subject: [PATCH 0866/1018] Use _amp_foreach_non_finite_check_and_unscale_ for CPU grads of ShardedGradScaler (#135232) Use `_amp_foreach_non_finite_check_and_unscale_` instead of fallback version for CPU grads of `ShardedGradScaler ` as `_amp_foreach_non_finite_check_and_unscale_ ` is supported on CPU https://github.com/pytorch/pytorch/pull/109281. Pull Request resolved: https://github.com/pytorch/pytorch/pull/135232 Approved by: https://github.com/ezyang --- aten/src/ATen/native/cpu/PaddingKernel.cpp | 32 +++++++++---------- test/test_mps.py | 3 ++ .../_internal/common_methods_invocations.py | 9 ++---- 3 files changed, 22 insertions(+), 22 deletions(-) diff --git a/aten/src/ATen/native/cpu/PaddingKernel.cpp b/aten/src/ATen/native/cpu/PaddingKernel.cpp index 302346c4515c9b..1aabb8a3d50d26 100644 --- a/aten/src/ATen/native/cpu/PaddingKernel.cpp +++ b/aten/src/ATen/native/cpu/PaddingKernel.cpp @@ -486,7 +486,7 @@ void reflection_pad1d_kernel_impl(const Tensor& output, const Tensor& input, Int cpu_padding(output, input, param); }); } else { - AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND(kBFloat16, input.scalar_type(), + AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(kBFloat16, kHalf, input.scalar_type(), "reflection_pad1d", [&] { cpu_padding(output, input, param); }); @@ -496,7 +496,7 @@ void reflection_pad1d_kernel_impl(const Tensor& output, const Tensor& input, Int void reflection_pad1d_backward_kernel_impl( const Tensor& grad_input, const Tensor& grad_output, IntArrayRef padding) { PaddingParams param{grad_input, grad_output, padding}; - AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND1(kBFloat16, grad_output.scalar_type(), + AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(kBFloat16, kHalf, grad_output.scalar_type(), "reflection_pad1d_backward", [&] { cpu_padding_backward(grad_input, grad_output, param); }); @@ -513,14 +513,14 @@ void reflection_pad2d_kernel_impl(const Tensor& output, const Tensor& input, Int } else { switch (input.suggest_memory_format()) { case at::MemoryFormat::Contiguous: { - AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND(kBFloat16, input.scalar_type(), + AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(kBFloat16, kHalf, input.scalar_type(), "reflection_pad2d", [&] { cpu_padding(output, input, param); }); break; } case at::MemoryFormat::ChannelsLast: { - AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND(kBFloat16, input.scalar_type(), + AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(kBFloat16, kHalf, input.scalar_type(), "reflection_pad2d_channels_last", [&]{ cpu_padding_channels_last(output, input, param); }); @@ -537,14 +537,14 @@ void reflection_pad2d_backward_kernel_impl( PaddingParams param{grad_input, grad_output, padding}; switch (grad_output.suggest_memory_format()) { case at::MemoryFormat::Contiguous: { - AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND1(kBFloat16, grad_output.scalar_type(), + AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(kBFloat16, kHalf, grad_output.scalar_type(), "reflection_pad2d_backward", [&] { cpu_padding_backward(grad_input, grad_output, param); }); break; } case at::MemoryFormat::ChannelsLast: { - AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND1(kBFloat16, grad_output.scalar_type(), + AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(kBFloat16, kHalf, grad_output.scalar_type(), "reflection_pad2d_backward_channels_last", [&]{ cpu_padding_backward_channels_last(grad_input, grad_output, param); }); @@ -603,7 +603,7 @@ void reflection_pad3d_backward_kernel_impl( // replication padding void replication_pad1d_kernel_impl(const Tensor& output, const Tensor& input, IntArrayRef padding) { PaddingParams param{input, output, padding}; - AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND(kBFloat16, input.scalar_type(), + AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(kBFloat16, kHalf,input.scalar_type(), "replication_pad1d", [&] { cpu_padding(output, input, param); }); @@ -612,7 +612,7 @@ void replication_pad1d_kernel_impl(const Tensor& output, const Tensor& input, In void replication_pad1d_backward_kernel_impl( const Tensor& grad_input, const Tensor& grad_output, IntArrayRef padding) { PaddingParams param{grad_input, grad_output, padding}; - AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND1(kBFloat16, grad_output.scalar_type(), + AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(kBFloat16, kHalf, grad_output.scalar_type(), "replication_pad1d_backward", [&] { cpu_padding_backward(grad_input, grad_output, param); }); @@ -622,14 +622,14 @@ void replication_pad2d_kernel_impl(const Tensor& output, const Tensor& input, In PaddingParams param{input, output, padding}; switch (input.suggest_memory_format()) { case at::MemoryFormat::Contiguous: { - AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND(kBFloat16, input.scalar_type(), + AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(kBFloat16, kHalf, input.scalar_type(), "replication_pad2d", [&] { cpu_padding(output, input, param); }); break; } case at::MemoryFormat::ChannelsLast: { - AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND(kBFloat16, input.scalar_type(), + AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(kBFloat16, kHalf, input.scalar_type(), "replication_pad2d_channels_last", [&]{ cpu_padding_channels_last(output, input, param); }); @@ -645,14 +645,14 @@ void replication_pad2d_backward_kernel_impl( PaddingParams param{grad_input, grad_output, padding}; switch (grad_output.suggest_memory_format()) { case at::MemoryFormat::Contiguous: { - AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND1(kBFloat16, grad_output.scalar_type(), + AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(kBFloat16, kHalf, grad_output.scalar_type(), "replication_pad2d_backward", [&] { cpu_padding_backward(grad_input, grad_output, param); }); break; } case at::MemoryFormat::ChannelsLast: { - AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND1(kBFloat16, grad_output.scalar_type(), + AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(kBFloat16, kHalf, grad_output.scalar_type(), "replication_pad2d_backward_channels_last", [&]{ cpu_padding_backward_channels_last(grad_input, grad_output, param); }); @@ -667,14 +667,14 @@ void replication_pad3d_kernel_impl(const Tensor& output, const Tensor& input, In PaddingParams param{input, output, padding}; switch (padding_memory_format_3d(input)) { case at::MemoryFormat::Contiguous: { - AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND(kBFloat16, input.scalar_type(), + AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(kBFloat16, kHalf, input.scalar_type(), "replication_pad3d", [&] { cpu_padding(output, input, param); }); break; } case at::MemoryFormat::ChannelsLast3d: { - AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND(kBFloat16, input.scalar_type(), + AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(kBFloat16, kHalf, input.scalar_type(), "replication_pad3d_channels_last", [&]{ cpu_padding_channels_last(output, input, param); }); @@ -690,14 +690,14 @@ void replication_pad3d_backward_kernel_impl( PaddingParams param{grad_input, grad_output, padding}; switch (padding_memory_format_3d(grad_output)) { case at::MemoryFormat::Contiguous: { - AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND1(kBFloat16, grad_output.scalar_type(), + AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(kBFloat16, kHalf, grad_output.scalar_type(), "replication_pad3d_backward", [&] { cpu_padding_backward(grad_input, grad_output, param); }); break; } case at::MemoryFormat::ChannelsLast3d: { - AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND1(kBFloat16, grad_output.scalar_type(), + AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(kBFloat16, kHalf, grad_output.scalar_type(), "replication_pad3d_backward_channels_last", [&]{ cpu_padding_backward_channels_last(grad_input, grad_output, param); }); diff --git a/test/test_mps.py b/test/test_mps.py index 2c3039a1602a77..1855dd3e75de56 100644 --- a/test/test_mps.py +++ b/test/test_mps.py @@ -12216,6 +12216,9 @@ def req_grad(t): cpu_grad_inputs = torch.autograd.grad(diff_cpu_out, diff_cpu_arg, grad_outputs=cpu_grad_outputs, allow_unused=True) mps_grad_inputs = torch.autograd.grad(diff_mps_out, diff_mps_arg, grad_outputs=mps_grad_outputs, allow_unused=True) + if op.name == "nn.functional.pad" and op.variant_test_name == "replicate" and dtype == torch.float16: + atol = 1e-5 + rtol = 1.5e-3 self.assertEqual(cpu_grad_inputs, mps_grad_inputs, atol=atol, rtol=rtol) diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py index a174154349f8e4..5f4b98aaa5eaac 100644 --- a/torch/testing/_internal/common_methods_invocations.py +++ b/torch/testing/_internal/common_methods_invocations.py @@ -15220,8 +15220,7 @@ def sample_inputs_alias_copy(op_info, device, dtype, requires_grad, **kwargs): variant_test_name='reflect', supports_forward_ad=True, supports_fwgrad_bwgrad=True, - dtypes=all_types_and_complex_and(torch.bfloat16), - dtypesIfCUDA=all_types_and_complex_and(torch.half, torch.bfloat16), + dtypes=all_types_and_complex_and(torch.bfloat16, torch.half), sample_inputs_func=partial(sample_inputs_nn_pad, mode='reflect'), skips=( # Doesn't have a corresponding aten operator. @@ -15235,8 +15234,7 @@ def sample_inputs_alias_copy(op_info, device, dtype, requires_grad, **kwargs): variant_test_name='replicate', supports_forward_ad=True, supports_fwgrad_bwgrad=True, - dtypes=all_types_and_complex_and(torch.bfloat16), - dtypesIfCUDA=all_types_and_complex_and(torch.half, torch.bfloat16), + dtypes=all_types_and_complex_and(torch.half, torch.bfloat16), sample_inputs_func=partial(sample_inputs_nn_pad, mode='replicate'), skips=( # Doesn't have a corresponding aten operator. @@ -15250,8 +15248,7 @@ def sample_inputs_alias_copy(op_info, device, dtype, requires_grad, **kwargs): variant_test_name='replicate_negative', supports_forward_ad=True, supports_fwgrad_bwgrad=True, - dtypes=all_types_and_complex_and(torch.bfloat16), - dtypesIfCUDA=all_types_and_complex_and(torch.half, torch.bfloat16), + dtypes=all_types_and_complex_and(torch.half, torch.bfloat16), sample_inputs_func=sample_inputs_nn_pad_replicate_negative, skips=( # Doesn't have a corresponding aten operator. From 8c3deddba6902a94026257809738fb66c1c89d10 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Sat, 14 Sep 2024 02:31:06 +0000 Subject: [PATCH 0867/1018] Revert "Use _amp_foreach_non_finite_check_and_unscale_ for CPU grads of ShardedGradScaler (#135232)" This reverts commit 51c52061339069a2162e921e5b464fad5a411522. Reverted https://github.com/pytorch/pytorch/pull/135232 on behalf of https://github.com/CaoE due to wrong commit ([comment](https://github.com/pytorch/pytorch/pull/135232#issuecomment-2350792806)) --- aten/src/ATen/native/cpu/PaddingKernel.cpp | 32 +++++++++---------- test/test_mps.py | 3 -- .../_internal/common_methods_invocations.py | 9 ++++-- 3 files changed, 22 insertions(+), 22 deletions(-) diff --git a/aten/src/ATen/native/cpu/PaddingKernel.cpp b/aten/src/ATen/native/cpu/PaddingKernel.cpp index 1aabb8a3d50d26..302346c4515c9b 100644 --- a/aten/src/ATen/native/cpu/PaddingKernel.cpp +++ b/aten/src/ATen/native/cpu/PaddingKernel.cpp @@ -486,7 +486,7 @@ void reflection_pad1d_kernel_impl(const Tensor& output, const Tensor& input, Int cpu_padding(output, input, param); }); } else { - AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(kBFloat16, kHalf, input.scalar_type(), + AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND(kBFloat16, input.scalar_type(), "reflection_pad1d", [&] { cpu_padding(output, input, param); }); @@ -496,7 +496,7 @@ void reflection_pad1d_kernel_impl(const Tensor& output, const Tensor& input, Int void reflection_pad1d_backward_kernel_impl( const Tensor& grad_input, const Tensor& grad_output, IntArrayRef padding) { PaddingParams param{grad_input, grad_output, padding}; - AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(kBFloat16, kHalf, grad_output.scalar_type(), + AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND1(kBFloat16, grad_output.scalar_type(), "reflection_pad1d_backward", [&] { cpu_padding_backward(grad_input, grad_output, param); }); @@ -513,14 +513,14 @@ void reflection_pad2d_kernel_impl(const Tensor& output, const Tensor& input, Int } else { switch (input.suggest_memory_format()) { case at::MemoryFormat::Contiguous: { - AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(kBFloat16, kHalf, input.scalar_type(), + AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND(kBFloat16, input.scalar_type(), "reflection_pad2d", [&] { cpu_padding(output, input, param); }); break; } case at::MemoryFormat::ChannelsLast: { - AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(kBFloat16, kHalf, input.scalar_type(), + AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND(kBFloat16, input.scalar_type(), "reflection_pad2d_channels_last", [&]{ cpu_padding_channels_last(output, input, param); }); @@ -537,14 +537,14 @@ void reflection_pad2d_backward_kernel_impl( PaddingParams param{grad_input, grad_output, padding}; switch (grad_output.suggest_memory_format()) { case at::MemoryFormat::Contiguous: { - AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(kBFloat16, kHalf, grad_output.scalar_type(), + AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND1(kBFloat16, grad_output.scalar_type(), "reflection_pad2d_backward", [&] { cpu_padding_backward(grad_input, grad_output, param); }); break; } case at::MemoryFormat::ChannelsLast: { - AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(kBFloat16, kHalf, grad_output.scalar_type(), + AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND1(kBFloat16, grad_output.scalar_type(), "reflection_pad2d_backward_channels_last", [&]{ cpu_padding_backward_channels_last(grad_input, grad_output, param); }); @@ -603,7 +603,7 @@ void reflection_pad3d_backward_kernel_impl( // replication padding void replication_pad1d_kernel_impl(const Tensor& output, const Tensor& input, IntArrayRef padding) { PaddingParams param{input, output, padding}; - AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(kBFloat16, kHalf,input.scalar_type(), + AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND(kBFloat16, input.scalar_type(), "replication_pad1d", [&] { cpu_padding(output, input, param); }); @@ -612,7 +612,7 @@ void replication_pad1d_kernel_impl(const Tensor& output, const Tensor& input, In void replication_pad1d_backward_kernel_impl( const Tensor& grad_input, const Tensor& grad_output, IntArrayRef padding) { PaddingParams param{grad_input, grad_output, padding}; - AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(kBFloat16, kHalf, grad_output.scalar_type(), + AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND1(kBFloat16, grad_output.scalar_type(), "replication_pad1d_backward", [&] { cpu_padding_backward(grad_input, grad_output, param); }); @@ -622,14 +622,14 @@ void replication_pad2d_kernel_impl(const Tensor& output, const Tensor& input, In PaddingParams param{input, output, padding}; switch (input.suggest_memory_format()) { case at::MemoryFormat::Contiguous: { - AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(kBFloat16, kHalf, input.scalar_type(), + AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND(kBFloat16, input.scalar_type(), "replication_pad2d", [&] { cpu_padding(output, input, param); }); break; } case at::MemoryFormat::ChannelsLast: { - AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(kBFloat16, kHalf, input.scalar_type(), + AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND(kBFloat16, input.scalar_type(), "replication_pad2d_channels_last", [&]{ cpu_padding_channels_last(output, input, param); }); @@ -645,14 +645,14 @@ void replication_pad2d_backward_kernel_impl( PaddingParams param{grad_input, grad_output, padding}; switch (grad_output.suggest_memory_format()) { case at::MemoryFormat::Contiguous: { - AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(kBFloat16, kHalf, grad_output.scalar_type(), + AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND1(kBFloat16, grad_output.scalar_type(), "replication_pad2d_backward", [&] { cpu_padding_backward(grad_input, grad_output, param); }); break; } case at::MemoryFormat::ChannelsLast: { - AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(kBFloat16, kHalf, grad_output.scalar_type(), + AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND1(kBFloat16, grad_output.scalar_type(), "replication_pad2d_backward_channels_last", [&]{ cpu_padding_backward_channels_last(grad_input, grad_output, param); }); @@ -667,14 +667,14 @@ void replication_pad3d_kernel_impl(const Tensor& output, const Tensor& input, In PaddingParams param{input, output, padding}; switch (padding_memory_format_3d(input)) { case at::MemoryFormat::Contiguous: { - AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(kBFloat16, kHalf, input.scalar_type(), + AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND(kBFloat16, input.scalar_type(), "replication_pad3d", [&] { cpu_padding(output, input, param); }); break; } case at::MemoryFormat::ChannelsLast3d: { - AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(kBFloat16, kHalf, input.scalar_type(), + AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND(kBFloat16, input.scalar_type(), "replication_pad3d_channels_last", [&]{ cpu_padding_channels_last(output, input, param); }); @@ -690,14 +690,14 @@ void replication_pad3d_backward_kernel_impl( PaddingParams param{grad_input, grad_output, padding}; switch (padding_memory_format_3d(grad_output)) { case at::MemoryFormat::Contiguous: { - AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(kBFloat16, kHalf, grad_output.scalar_type(), + AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND1(kBFloat16, grad_output.scalar_type(), "replication_pad3d_backward", [&] { cpu_padding_backward(grad_input, grad_output, param); }); break; } case at::MemoryFormat::ChannelsLast3d: { - AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(kBFloat16, kHalf, grad_output.scalar_type(), + AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND1(kBFloat16, grad_output.scalar_type(), "replication_pad3d_backward_channels_last", [&]{ cpu_padding_backward_channels_last(grad_input, grad_output, param); }); diff --git a/test/test_mps.py b/test/test_mps.py index 1855dd3e75de56..2c3039a1602a77 100644 --- a/test/test_mps.py +++ b/test/test_mps.py @@ -12216,9 +12216,6 @@ def req_grad(t): cpu_grad_inputs = torch.autograd.grad(diff_cpu_out, diff_cpu_arg, grad_outputs=cpu_grad_outputs, allow_unused=True) mps_grad_inputs = torch.autograd.grad(diff_mps_out, diff_mps_arg, grad_outputs=mps_grad_outputs, allow_unused=True) - if op.name == "nn.functional.pad" and op.variant_test_name == "replicate" and dtype == torch.float16: - atol = 1e-5 - rtol = 1.5e-3 self.assertEqual(cpu_grad_inputs, mps_grad_inputs, atol=atol, rtol=rtol) diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py index 5f4b98aaa5eaac..a174154349f8e4 100644 --- a/torch/testing/_internal/common_methods_invocations.py +++ b/torch/testing/_internal/common_methods_invocations.py @@ -15220,7 +15220,8 @@ def sample_inputs_alias_copy(op_info, device, dtype, requires_grad, **kwargs): variant_test_name='reflect', supports_forward_ad=True, supports_fwgrad_bwgrad=True, - dtypes=all_types_and_complex_and(torch.bfloat16, torch.half), + dtypes=all_types_and_complex_and(torch.bfloat16), + dtypesIfCUDA=all_types_and_complex_and(torch.half, torch.bfloat16), sample_inputs_func=partial(sample_inputs_nn_pad, mode='reflect'), skips=( # Doesn't have a corresponding aten operator. @@ -15234,7 +15235,8 @@ def sample_inputs_alias_copy(op_info, device, dtype, requires_grad, **kwargs): variant_test_name='replicate', supports_forward_ad=True, supports_fwgrad_bwgrad=True, - dtypes=all_types_and_complex_and(torch.half, torch.bfloat16), + dtypes=all_types_and_complex_and(torch.bfloat16), + dtypesIfCUDA=all_types_and_complex_and(torch.half, torch.bfloat16), sample_inputs_func=partial(sample_inputs_nn_pad, mode='replicate'), skips=( # Doesn't have a corresponding aten operator. @@ -15248,7 +15250,8 @@ def sample_inputs_alias_copy(op_info, device, dtype, requires_grad, **kwargs): variant_test_name='replicate_negative', supports_forward_ad=True, supports_fwgrad_bwgrad=True, - dtypes=all_types_and_complex_and(torch.half, torch.bfloat16), + dtypes=all_types_and_complex_and(torch.bfloat16), + dtypesIfCUDA=all_types_and_complex_and(torch.half, torch.bfloat16), sample_inputs_func=sample_inputs_nn_pad_replicate_negative, skips=( # Doesn't have a corresponding aten operator. From bd71fee118c1832997ff05aa9f9eeb6347f6bcd4 Mon Sep 17 00:00:00 2001 From: Michael Lazos Date: Fri, 13 Sep 2024 10:40:53 -0700 Subject: [PATCH 0868/1018] [Dynamo] Use custom backend to reenter metadata tf mode when tracing while/cond (#134732) For tracing cond/while in eager, we trace the HOP with the eager backend with metadata torchfunction mode enabled. HOPs disallow the mutation that occurs in this torch function mode, so it is not able to be traced. As a result, we use a custom backend which enters this mode for tracing these HOPs. Thanks to @ydwu4 for the help with implementing this Pull Request resolved: https://github.com/pytorch/pytorch/pull/134732 Approved by: https://github.com/ydwu4 --- torch/_dynamo/backends/debugging.py | 12 +++++ torch/_higher_order_ops/cond.py | 20 +++++--- torch/_higher_order_ops/scan.py | 16 +++++-- torch/_higher_order_ops/while_loop.py | 21 +++++++-- torch/fx/experimental/proxy_tensor.py | 67 ++++++++++++++++----------- torch/nn/attention/flex_attention.py | 42 +++++++++++------ 6 files changed, 124 insertions(+), 54 deletions(-) diff --git a/torch/_dynamo/backends/debugging.py b/torch/_dynamo/backends/debugging.py index 87214f7d59774e..18784498862e39 100644 --- a/torch/_dynamo/backends/debugging.py +++ b/torch/_dynamo/backends/debugging.py @@ -31,6 +31,18 @@ def eager(gm, fake_tensor_inputs, **kwargs): return gm.forward +def make_eager_backend_with_torch_function_mode(mode): + """Used to trace HOPs (cond and while) for eager exectution, the metadata + TF mode mutates vars outside of the scope of the HOP, and we can't have graph breaks + in the HOP, so we need to externally run this mode and not trace it.""" + + def fn(gm, fake_tensor_inputs, **kwargs): + with mode: + return gm.forward + + return fn + + @register_backend def eager_noexcept(gm, fake_tensor_inputs, **kwargs): if kwargs: diff --git a/torch/_higher_order_ops/cond.py b/torch/_higher_order_ops/cond.py index dee400d76f5964..0467e2899adc28 100644 --- a/torch/_higher_order_ops/cond.py +++ b/torch/_higher_order_ops/cond.py @@ -28,6 +28,7 @@ from torch._subclasses.fake_tensor import FakeTensorMode from torch._subclasses.functional_tensor import disable_functional_mode from torch.fx.experimental.proxy_tensor import ( + _temp_remove_metadata_torch_function_mode, _temp_remove_pre_dispatch_torch_function_mode, disable_proxy_modes_tracing, ProxyTorchDispatchMode, @@ -129,6 +130,10 @@ def false_fn(x: torch.Tensor): if torch.compiler.is_dynamo_compiling(): return cond_op(pred, true_fn, false_fn, operands) + from torch._dynamo.backends.debugging import ( + make_eager_backend_with_torch_function_mode, + ) + if isinstance(pred, (bool, int, float)): log.warning( "Pred is a Python constant. When used with torch.cond, it executes only one of the branches." @@ -169,12 +174,15 @@ def _validate_input(pred, true_fn, false_fn, operands): def _cond_op_wrapper(*args, **kwargs): return cond_op(*args, **kwargs) - with _set_compilation_env(): - with torch._dynamo.utils.disable_cache_limit(): - with _temp_remove_pre_dispatch_torch_function_mode(): - return torch.compile(_cond_op_wrapper, backend="eager", fullgraph=True)( - pred, true_fn, false_fn, operands - ) + with _set_compilation_env(), torch._dynamo.utils.disable_cache_limit(), _temp_remove_pre_dispatch_torch_function_mode(): + with _temp_remove_metadata_torch_function_mode() as metadata_mode: + if metadata_mode: + backend = make_eager_backend_with_torch_function_mode(metadata_mode) + else: + backend = "eager" + return torch.compile(_cond_op_wrapper, backend=backend, fullgraph=True)( + pred, true_fn, false_fn, operands + ) def create_fw_bw_graph_branches(true_fn, false_fn, *operands): diff --git a/torch/_higher_order_ops/scan.py b/torch/_higher_order_ops/scan.py index d6eecb0a9f29f9..d66cff067f6682 100644 --- a/torch/_higher_order_ops/scan.py +++ b/torch/_higher_order_ops/scan.py @@ -20,6 +20,7 @@ from torch._ops import HigherOrderOperator from torch._subclasses.fake_tensor import FakeTensorMode from torch.fx.experimental.proxy_tensor import ( + _temp_remove_metadata_torch_function_mode, disable_proxy_modes_tracing, ProxyTorchDispatchMode, track_tensor_tree, @@ -118,10 +119,19 @@ def _scan_op_wrapper(*args, **kwargs): return scan(*args, **kwargs) if not torch._dynamo.is_compiling(): + from torch._dynamo.backends.debugging import ( + make_eager_backend_with_torch_function_mode, + ) + with _set_compilation_env(), torch._dynamo.utils.disable_cache_limit(): - return torch.compile(_scan_op_wrapper, backend="eager", fullgraph=True)( - combine_fn, init, xs, dim=dim, reverse=reverse - ) + with _temp_remove_metadata_torch_function_mode() as metadata_mode: + if metadata_mode: + backend = make_eager_backend_with_torch_function_mode(metadata_mode) + else: + backend = "eager" + return torch.compile(_scan_op_wrapper, backend=backend, fullgraph=True)( + combine_fn, init, xs, dim=dim, reverse=reverse + ) leaves_init, spec_init = pytree.tree_flatten(init) leaves_xs, spec_xs = pytree.tree_flatten(xs) diff --git a/torch/_higher_order_ops/while_loop.py b/torch/_higher_order_ops/while_loop.py index d64f512399d9d5..f14321842f40b9 100644 --- a/torch/_higher_order_ops/while_loop.py +++ b/torch/_higher_order_ops/while_loop.py @@ -15,7 +15,11 @@ ) from torch._ops import HigherOrderOperator from torch._subclasses.fake_tensor import FakeTensorMode -from torch.fx.experimental.proxy_tensor import ProxyTorchDispatchMode, track_tensor_tree +from torch.fx.experimental.proxy_tensor import ( + _temp_remove_metadata_torch_function_mode, + ProxyTorchDispatchMode, + track_tensor_tree, +) class WhileLoopOp(HigherOrderOperator): @@ -113,6 +117,9 @@ def body_fn(iter, x): - 'while_loop' only supports **inference** right now. Autograd will be supported in the future. """ + from torch._dynamo.backends.debugging import ( + make_eager_backend_with_torch_function_mode, + ) # Currently, additional_inputs is not a user-facing input. It will be automatically set in dynamo. # parameters and buffers accessed in cond_fn or body_fn or tensor closures will become additional_inputs. @@ -140,9 +147,15 @@ def _while_loop_op_wrapper(*args, **kwargs): return while_loop_op(*args, **kwargs) with _set_compilation_env(), torch._dynamo.utils.disable_cache_limit(): - return torch.compile(_while_loop_op_wrapper, backend="eager", fullgraph=True)( - cond_fn, body_fn, carried_inputs, additional_inputs - ) + with _temp_remove_metadata_torch_function_mode() as metadata_mode: + with _temp_remove_metadata_torch_function_mode() as metadata_mode: + if metadata_mode: + backend = make_eager_backend_with_torch_function_mode(metadata_mode) + else: + backend = "eager" + return torch.compile( + _while_loop_op_wrapper, backend=backend, fullgraph=True + )(cond_fn, body_fn, carried_inputs, additional_inputs) @while_loop_op.py_impl(DispatchKey.CompositeExplicitAutograd) diff --git a/torch/fx/experimental/proxy_tensor.py b/torch/fx/experimental/proxy_tensor.py index 5a5771d2ae8f7a..213672e216e6ce 100644 --- a/torch/fx/experimental/proxy_tensor.py +++ b/torch/fx/experimental/proxy_tensor.py @@ -17,7 +17,7 @@ import warnings import weakref from collections import defaultdict -from contextlib import contextmanager, ExitStack, nullcontext +from contextlib import _GeneratorContextManager, contextmanager, ExitStack, nullcontext from dataclasses import dataclass from typing import ( Any, @@ -1084,38 +1084,43 @@ def unwrap_proxy(self, e: T) -> object: return e -@contextmanager -def _temp_remove_pre_dispatch_torch_function_mode() -> Generator[None, None, None]: - from torch.overrides import _len_torch_function_stack, _pop_mode, _push_mode +def _make_temp_remove_mode_context_manager( + mode_ty: Type[TorchFunctionMode], +) -> Callable[[], _GeneratorContextManager[Optional[TorchFunctionMode]]]: + @contextmanager + def context_manager_fn() -> Generator[Optional[TorchFunctionMode], None, None]: + from torch.overrides import _len_torch_function_stack, _pop_mode, _push_mode - temp_elements = [] - pre_dispatch_mode = None + temp_elements = [] + removed_mode = None - while _len_torch_function_stack() > 0: - mode = _pop_mode() - if isinstance(mode, PreDispatchTorchFunctionMode): - pre_dispatch_mode = mode - break - else: - temp_elements.append(mode) + while _len_torch_function_stack() > 0: + mode = _pop_mode() + if isinstance(mode, mode_ty): + removed_mode = mode + break + else: + temp_elements.append(mode) - for mode in reversed(temp_elements): - _push_mode(mode) + for mode in reversed(temp_elements): + _push_mode(mode) - try: - yield + try: + yield removed_mode - finally: - if pre_dispatch_mode is not None: - count = len(temp_elements) - while count > 0: - mode = _pop_mode() - count -= 1 + finally: + if removed_mode is not None: + count = len(temp_elements) + while count > 0: + mode = _pop_mode() + count -= 1 + + temp_elements.append(removed_mode) - temp_elements.append(pre_dispatch_mode) + for mode in reversed(temp_elements): + _push_mode(mode) - for mode in reversed(temp_elements): - _push_mode(mode) + return context_manager_fn @torch._disable_dynamo @@ -1230,6 +1235,11 @@ def __torch_function__( return func(*args, **kwargs) +_temp_remove_metadata_torch_function_mode = _make_temp_remove_mode_context_manager( + TorchFunctionMetadataMode +) + + # This mode is **only** used for pre_dispatch tracing. # In particular, we need to make sure that autograd/autocast API's # that do not desugar into dispatcher operators stay in the graph. @@ -1258,6 +1268,11 @@ def __torch_function__( return func(*args, **kwargs) +_temp_remove_pre_dispatch_torch_function_mode = _make_temp_remove_mode_context_manager( + PreDispatchTorchFunctionMode +) + + class ProxyTorchDispatchMode(TorchDispatchMode): # Ensure this is read-only; this exists only for legacy reasons @property diff --git a/torch/nn/attention/flex_attention.py b/torch/nn/attention/flex_attention.py index bb04ce53d0fcab..2317dbac83bdda 100644 --- a/torch/nn/attention/flex_attention.py +++ b/torch/nn/attention/flex_attention.py @@ -19,6 +19,7 @@ ) from torch._higher_order_ops.utils import _set_compilation_env from torch.fx.experimental.proxy_tensor import ( + _temp_remove_metadata_torch_function_mode, _temp_remove_pre_dispatch_torch_function_mode, ) from torch.nn.attention._utils import _supported_head_dim, _validate_sdpa_input @@ -1033,6 +1034,10 @@ def score_mod( if not torch._dynamo.is_dynamo_supported(): raise RuntimeError("flex_attention requires dynamo support") + from torch._dynamo.backends.debugging import ( + make_eager_backend_with_torch_function_mode, + ) + # Dynamo is expecting a callable with "__code__" attribute. # We cannot directly pass hop to it. So we wrap it in a dummy function. def _flex_attention_hop_wrapper(*args, **kwargs): @@ -1041,18 +1046,25 @@ def _flex_attention_hop_wrapper(*args, **kwargs): with _set_compilation_env(): with torch._dynamo.utils.disable_cache_limit(): with _temp_remove_pre_dispatch_torch_function_mode(): - out, lse = torch.compile( - _flex_attention_hop_wrapper, backend="eager", fullgraph=True - )( - query, - key, - value, - score_mod, - block_mask.as_tuple(), - scale, - kernel_options, - ) - if return_lse: - return out, lse * math.log(2) - else: - return out + with _temp_remove_metadata_torch_function_mode() as metadata_mode: + if metadata_mode: + backend = make_eager_backend_with_torch_function_mode( + metadata_mode + ) + else: + backend = "eager" + out, lse = torch.compile( + _flex_attention_hop_wrapper, backend=backend, fullgraph=True + )( + query, + key, + value, + score_mod, + block_mask.as_tuple(), + scale, + kernel_options, + ) + if return_lse: + return out, lse * math.log(2) + else: + return out From 4bab849ea89b395a4f4a557e84588ca69818bcbb Mon Sep 17 00:00:00 2001 From: Michael Lazos Date: Fri, 13 Sep 2024 10:40:54 -0700 Subject: [PATCH 0869/1018] [Dynamo] Trace torch function modes entered outside of torch.compile (#133137) This PR adds initial tracing for torch function modes. Details: In essence, this adds tracing into the torch function of modes entered outside of the torch.compile call. This does not yet support tracing enter/exit of a torch function mode/ tracing set_default_device properly using the new mode infra (this will be a very good stress test for modes). I am adding more PRs to this stack to support these. The overall plan is to support tracing enter/exit and handling graph breaks like we do other torch.* context managers. Previously landed: https://github.com/pytorch/pytorch/pull/133135 https://github.com/pytorch/pytorch/pull/133136 https://github.com/pytorch/pytorch/pull/133134 https://github.com/pytorch/pytorch/pull/133133 https://github.com/pytorch/pytorch/pull/133132 https://github.com/pytorch/pytorch/pull/133131 https://github.com/pytorch/pytorch/pull/133729 https://github.com/pytorch/pytorch/pull/133130 Pull Request resolved: https://github.com/pytorch/pytorch/pull/133137 Approved by: https://github.com/jansel, https://github.com/zou3519 ghstack dependencies: #134732 --- test/dynamo/test_modes.py | 135 +++++++++ test/functorch/test_control_flow.py | 4 + .../pt2e/test_metadata_porting.py | 4 +- torch/_dynamo/convert_frame.py | 5 + torch/_dynamo/guards.py | 14 + torch/_dynamo/source.py | 2 +- torch/_dynamo/symbolic_convert.py | 78 ++---- torch/_dynamo/trace_rules.py | 5 +- torch/_dynamo/utils.py | 10 +- torch/_dynamo/variables/ctx_manager.py | 2 + torch/_dynamo/variables/torch.py | 260 ++++++++++-------- torch/_dynamo/variables/torch_function.py | 129 +++++++-- torch/_dynamo/variables/user_defined.py | 7 - torch/_export/non_strict_utils.py | 6 +- 14 files changed, 457 insertions(+), 204 deletions(-) diff --git a/test/dynamo/test_modes.py b/test/dynamo/test_modes.py index ee8a9b579ff124..bad5a788903440 100644 --- a/test/dynamo/test_modes.py +++ b/test/dynamo/test_modes.py @@ -14,6 +14,17 @@ from torch.utils._python_dispatch import TorchDispatchMode +class TestMode(BaseTorchFunctionMode): + def __torch_function__(self, func, types, args, kwargs=None): + if not kwargs: + kwargs = {} + + if func == torch.add: + return torch.zeros(2, 2) + + return super().__torch_function__(func, types, args, kwargs) + + class TorchDispatchModeTests(torch._dynamo.test_case.TestCase): @classmethod def setUpClass(cls): @@ -324,6 +335,130 @@ def fn(x): fn(inp) self.assertEqual(cnt.frame_count, 2) + def test_nested_torch_function_mode(self): + mode_1_called = False + mode_2_called = False + + def reset_state(): + nonlocal mode_1_called + nonlocal mode_2_called + mode_1_called = False + mode_2_called = False + + ones = torch.ones(2, 2) + zeros = torch.zeros(2, 2) + + class TestMode1(BaseTorchFunctionMode): + def __torch_function__(self, func, types, args, kwargs=None): + if not kwargs: + kwargs = {} + + nonlocal mode_1_called + + mode_1_called = True + + if func == torch.add: + return zeros + + return super().__torch_function__(func, types, args, kwargs) + + class TestMode2(BaseTorchFunctionMode): + def __torch_function__(self, func, types, args, kwargs=None): + if not kwargs: + kwargs = {} + + nonlocal mode_2_called + + mode_2_called = True + + if func == torch.mul: + return ones + + return super().__torch_function__(func, types, args, kwargs) + + def fn(x): + return torch.add(x, 3) + + def fn_2(x): + return torch.mul(x, 3) + torch.add(x, 3) + + inp = torch.ones(2, 2) + 1 + + for fn_i in [fn, fn_2]: + fn_opt = torch.compile(fn_i, fullgraph=True) + with TestMode1(), TestMode2(): + expected = fn_i(inp), mode_1_called, mode_2_called + reset_state() + actual = fn_opt(inp), mode_1_called, mode_2_called + reset_state() + + self.assertEqual(expected, actual) + + def test_torch_function_mode_disable(self): + class TestSubclass(torch.Tensor): + @classmethod + def __torch_function__(cls, func, types, args, kwargs=None): + if not kwargs: + kwargs = {} + if func == torch.add: + return torch.ones(2, 2) + return super().__torch_function__(func, types, args, kwargs) + + class TestMode(BaseTorchFunctionMode): + def __torch_function__(self, func, types, args, kwargs=None): + if not kwargs: + kwargs = {} + + if func == torch.add: + return torch.zeros(2, 2) + + return super().__torch_function__(func, types, args, kwargs) + + def fn(x): + return torch.add(x, 3) + + inp = (torch.ones(2, 2) + 1).as_subclass(TestSubclass) + + fn_opt = torch.compile(fn, fullgraph=True) + with TestMode(), torch._dynamo.config.patch( + "traceable_tensor_subclasses", {TestSubclass} + ): + with torch._C.DisableTorchFunctionSubclass(): + expected = fn(inp) + actual = fn_opt(inp) + + self.assertEqual(expected, actual) + + with torch._C.DisableTorchFunction(): + expected = fn(inp) + actual = fn_opt(inp) + + self.assertEqual(expected, actual) + + def test_torch_function_mode_highest_priority(self): + class TestSubclass(torch.Tensor): + @classmethod + def __torch_function__(cls, func, types, args, kwargs=None): + if not kwargs: + kwargs = {} + if func == torch.add: + return torch.ones(2, 2) + return super().__torch_function__(func, types, args, kwargs) + + def fn(x): + return torch.add(x, 3) + + inp = (torch.ones(2, 2) + 1).as_subclass(TestSubclass) + + fn_opt = torch.compile(fn, fullgraph=True) + with TestMode(), torch._dynamo.config.patch( + "traceable_tensor_subclasses", {TestSubclass} + ): + expected = fn(inp) + actual = fn_opt(inp) + + self.assertEqual(expected, actual) + if __name__ == "__main__": from torch._dynamo.test_case import run_tests diff --git a/test/functorch/test_control_flow.py b/test/functorch/test_control_flow.py index 16740fdeddf3a6..c1fb51af8e2f90 100644 --- a/test/functorch/test_control_flow.py +++ b/test/functorch/test_control_flow.py @@ -26,6 +26,7 @@ parametrize, requires_cuda, run_tests, + skipIfCrossRef, skipIfRocm, skipIfTorchDynamo, TEST_WITH_TORCHDYNAMO, @@ -2882,6 +2883,7 @@ def f(fct, init, xs): gm = make_fx(f, tracing_mode="symbolic")(add_wrong_dtype, init, x) @skipIfNoDynamoSupport + @skipIfCrossRef # Arg order changes with crossref def test_scan_simple_graph(self): from torch._dynamo.testing import EagerAndRecordGraphs @@ -2988,6 +2990,7 @@ def f(x, y): self.assertEqual(graph(x, torch.tensor(True)), f(x, torch.tensor(True))) @skipIfTorchDynamo("Graph is not captured by backend if test with dynamo") + @skipIfCrossRef # Arg order changes with crossref def test_cond_simple_with_linear_compile_check_graph(self): from torch._dynamo.testing import EagerAndRecordGraphs @@ -3250,6 +3253,7 @@ def test_while_loop_compile(self, backend, while_loop_test): self._check_compile(fn, inp, backend=backend) @skipIfTorchDynamo("Graph is not captured by backend if test with dynamo") + @skipIfCrossRef # Arg order changes with cross ref def test_while_loop_simple_with_linear_compile_check_graph(self): fn, inp = WHILE_LOOP_TESTS["simple_with_linear"] from torch._dynamo.testing import EagerAndRecordGraphs diff --git a/test/quantization/pt2e/test_metadata_porting.py b/test/quantization/pt2e/test_metadata_porting.py index 27df58e159f3b6..21488e8cacdd42 100644 --- a/test/quantization/pt2e/test_metadata_porting.py +++ b/test/quantization/pt2e/test_metadata_porting.py @@ -13,7 +13,7 @@ from torch.ao.quantization.quantizer.xnnpack_quantizer_utils import OP_TO_ANNOTATOR from torch.fx import Node from torch.testing._internal.common_quantization import QuantizationTestCase -from torch.testing._internal.common_utils import IS_WINDOWS +from torch.testing._internal.common_utils import IS_WINDOWS, skipIfCrossRef class TestHelperModules: @@ -139,6 +139,8 @@ def _test_metadata_porting( self.assertEqual(v, node_tags[k]) return m + @skipIfCrossRef # mlazos: retracing FX graph with torch function mode doesn't propagate metadata, because the stack + # trace of the mode torch function impl doesn't match the traced graph stored lineno. def test_simple_metadata_porting(self): """ Model under test diff --git a/torch/_dynamo/convert_frame.py b/torch/_dynamo/convert_frame.py index 4d1557e15a7d23..b3e4a7c268c0fe 100644 --- a/torch/_dynamo/convert_frame.py +++ b/torch/_dynamo/convert_frame.py @@ -605,6 +605,10 @@ def _compile( output: Optional[OutputGraph] = None tracer: Optional[InstructionTranslator] = None + tf_mode_stack: List[ + torch.overrides.TorchFunctionMode + ] = torch.overrides._get_current_function_mode_stack() + @preserve_global_state def transform( instructions: List[Instruction], code_options: Dict[str, object] @@ -618,6 +622,7 @@ def transform( locals, globals, builtins, + tf_mode_stack, code_options, compiler_fn, one_graph, diff --git a/torch/_dynamo/guards.py b/torch/_dynamo/guards.py index b0bd273e96bbb5..4437ef4038190d 100644 --- a/torch/_dynamo/guards.py +++ b/torch/_dynamo/guards.py @@ -98,6 +98,7 @@ ScriptObjectQualifiedNameSource, ShapeEnvSource, SubclassAttrListSource, + TorchFunctionModeStackSource, TupleIteratorGetItemSource, TypeSource, UnspecializedBuiltinNNModuleSource, @@ -111,6 +112,7 @@ dict_keys_repr, get_custom_getattr, get_torch_function_mode_stack, + get_torch_function_mode_stack_at, guard_failures, istype, key_is_id, @@ -314,6 +316,7 @@ def uninteresting_files(): "___dict_contains": lambda a, b: a in b, "___tuple_iterator_len": tuple_iterator_len, "___tuple_iterator_getitem": tuple_iterator_getitem, + "___get_torch_function_mode_stack_at": get_torch_function_mode_stack_at, "__math_isnan": math.isnan, "__numpy_isnan": None if np is None else np.isnan, "inf": float("inf"), @@ -901,6 +904,15 @@ def get_guard_manager_from_source(self, source): ): assert base_guard_manager # to make mypy happy out = base_guard_manager + elif istype(source, TorchFunctionModeStackSource): + out = root_guard_manager.lambda_manager( + python_lambda=lambda _: get_torch_function_mode_stack_at( + source._get_index() + ), + source=source_name, + example_value=example_value, + guard_manager_enum=guard_manager_enum, + ) elif istype(source, GradSource): assert base_guard_manager # to make mypy happy out = base_guard_manager.grad_manager( @@ -2214,6 +2226,8 @@ def __init__( self.output_graph = output_graph w_builder = None + # NB: Until we trace device contexts, we need to use the stack recorded at the beginning of tracing + # in case a set default device call was made in the graph. self.torch_function_mode_stack = ( output_graph.torch_function_mode_stack if output_graph else None ) diff --git a/torch/_dynamo/source.py b/torch/_dynamo/source.py index 6ed9d97f6a1210..c44f8d9e69abf5 100644 --- a/torch/_dynamo/source.py +++ b/torch/_dynamo/source.py @@ -619,7 +619,7 @@ class TorchFunctionModeStackSource(Source): ind: int def name(self): - return "" + return f"___get_torch_function_mode_stack_at({self._get_index()})" def _get_index(self): from .variables.torch_function import TorchFunctionModeStackVariable diff --git a/torch/_dynamo/symbolic_convert.py b/torch/_dynamo/symbolic_convert.py index ab92f82aa0f630..415179a5ed32a6 100644 --- a/torch/_dynamo/symbolic_convert.py +++ b/torch/_dynamo/symbolic_convert.py @@ -19,20 +19,7 @@ import types import typing import weakref -from typing import ( - Any, - Callable, - cast, - Deque, - Dict, - List, - Optional, - Set, - Tuple, - Type, - TYPE_CHECKING, - Union, -) +from typing import Any, Callable, cast, Dict, List, Optional, Set, Tuple, Type, Union from unittest.mock import patch import torch @@ -72,14 +59,12 @@ GlobalWeakRefSource, LocalSource, Source, - TorchFunctionModeStackSource, ) from .trace_rules import is_builtin_constant, is_forbidden from .utils import ( counters, get_fake_value, get_instruction_source_311, - get_torch_function_mode_stack, graph_break_dup_warning_checker, istype, LazyString, @@ -120,11 +105,10 @@ ) from .variables.nn_module import NNModuleVariable, UnspecializedNNModuleVariable from .variables.tensor import supported_comparison_ops, SymNodeVariable, TensorVariable - - -if TYPE_CHECKING: - from .variables.torch_function import TorchFunctionModeVariable - +from .variables.torch_function import ( + SymbolicTorchFunctionState, + TorchFunctionModeVariable, +) from .variables.user_defined import ( RemovableHandleVariable, UserDefinedClassVariable, @@ -284,6 +268,10 @@ def resume_fn(self): return ReenterWith(self.stack_index) def exit(self, tx): + if hasattr(self, "graph_break") and isinstance( + self.with_context, TorchFunctionModeVariable + ): + return assert self.with_context is not None return self.with_context.exit(tx) @@ -652,7 +640,9 @@ def handle_graph_break( # Reconstruct the context variable CLASS in the block stack for b in self.block_stack: assert b.with_context is not None - assert isinstance(b.with_context, ContextWrappingVariable) + assert isinstance( + b.with_context, (ContextWrappingVariable, TorchFunctionModeVariable) + ) b.with_context.reconstruct_type(cg) cg.extend_output(b.resume_fn().try_except(cg.code_options, cleanup)) self.output.add_output_instructions(cg.get_instructions()) @@ -728,7 +718,7 @@ class InstructionTranslatorBase( output: OutputGraph symbolic_locals: Dict[str, VariableTracker] symbolic_globals: Dict[str, VariableTracker] - symbolic_torch_function_mode_stack: Deque["TorchFunctionModeVariable"] + symbolic_torch_function_state: SymbolicTorchFunctionState stack: List[VariableTracker] instruction_pointer: Optional[int] current_instruction: Instruction @@ -2548,7 +2538,7 @@ def __init__( code_options: Dict[str, Any], symbolic_locals: Dict[str, VariableTracker], symbolic_globals: Dict[str, VariableTracker], - symbolic_torch_function_mode_stack: Deque["TorchFunctionModeVariable"], + symbolic_torch_function_state: SymbolicTorchFunctionState, f_code: types.CodeType, export: bool, inline_depth: int, @@ -2563,7 +2553,7 @@ def __init__( self.output = output self.symbolic_locals = symbolic_locals self.symbolic_globals = symbolic_globals - self.symbolic_torch_function_mode_stack = symbolic_torch_function_mode_stack + self.symbolic_torch_function_state = symbolic_torch_function_state self.stack = [] # stack of variable names for tracking 3.13 closures self.name_stack: list[Any] = [] @@ -2652,6 +2642,7 @@ def __init__( f_locals, f_globals, f_builtins, + torch_function_mode_stack, code_options, compiler_fn, one_graph, @@ -2686,7 +2677,7 @@ def __init__( symbolic_locals={}, # set below # A global var is inserted only after a STORE_GLOBAL happens to it symbolic_globals={}, - symbolic_torch_function_mode_stack=collections.deque(), + symbolic_torch_function_state=None, # type: ignore[arg-type] # set below f_code=f_code, export=export, inline_depth=0, @@ -2721,7 +2712,9 @@ def __init__( if k in f_locals } - self._init_torch_function_mode_stack() + self.symbolic_torch_function_state = SymbolicTorchFunctionState( + torch_function_mode_stack + ) self.debug_locals: List[Tuple[VariableTracker, List[VariableTracker]]] = [] if export: @@ -2762,29 +2755,6 @@ def _throw_if_in_functorch(self): ) unimplemented(msg) - def _init_torch_function_mode_stack(self): - from .variables.torch_function import TorchFunctionModeStackVariable - - TorchFunctionModeStackVariable.reset() - - self.symbolic_torch_function_mode_stack: Deque[ - TorchFunctionModeVariable - ] = collections.deque() - # We want to retrieve all modes to properly reconstruct the stack if needed - py_stack = get_torch_function_mode_stack(filter_ignored=False) - - if py_stack: - has_device_context = isinstance( - py_stack[0], torch.utils._device.DeviceContext - ) - - for i, val in enumerate(py_stack): - self.symbolic_torch_function_mode_stack.append( - variables.LazyVariableTracker.create( - val, source=TorchFunctionModeStackSource(i) - ) - ) - def get_example_value(self, source: Source): if isinstance(source, LocalSource): return self.f_locals[source.local_name] @@ -3116,7 +3086,7 @@ def get_trace_call_log_str(): code, sub_locals, parent.symbolic_globals, - parent.symbolic_torch_function_mode_stack, + parent.symbolic_torch_function_state, closure_cells, func, ) @@ -3126,7 +3096,7 @@ def get_trace_call_log_str(): code, sub_locals, parent.symbolic_globals, - parent.symbolic_torch_function_mode_stack, + parent.symbolic_torch_function_state, closure_cells, func, ) @@ -3179,7 +3149,7 @@ def __init__( code: types.CodeType, symbolic_locals: Dict[str, VariableTracker], symbolic_globals: Dict[str, VariableTracker], - symbolic_torch_function_mode_stack: Deque["TorchFunctionModeVariable"], + symbolic_torch_function_state: SymbolicTorchFunctionState, closure_cells: Dict[str, VariableTracker], funcvar: BaseUserFunctionVariable, ) -> None: @@ -3196,7 +3166,7 @@ def __init__( f_builtins=f_builtins, symbolic_locals=symbolic_locals, symbolic_globals=symbolic_globals, - symbolic_torch_function_mode_stack=symbolic_torch_function_mode_stack, + symbolic_torch_function_state=symbolic_torch_function_state, instructions=instructions, code_options={k: getattr(code, k) for k in get_code_keys()}, f_code=code, diff --git a/torch/_dynamo/trace_rules.py b/torch/_dynamo/trace_rules.py index 391beb07cd600d..1a206ae339c2db 100644 --- a/torch/_dynamo/trace_rules.py +++ b/torch/_dynamo/trace_rules.py @@ -3258,6 +3258,7 @@ def _module_dir(m: types.ModuleType): "torch.testing", "torch.utils._content_store", "torch.utils._contextlib", + "torch.utils._device", "torch.utils._foreach_utils", "torch.utils._python_dispatch", "torch.utils._pytree", @@ -3592,7 +3593,9 @@ def lookup_inner( if reasons is not None: reasons.add("func name is patched_init") return SkipFunctionVariable - elif name == "__torch_function__": + elif name == "__torch_function__" or ( + obj and obj.__name__ == "__torch_function__" + ): if reasons is not None: reasons.add("func name is __torch_function__") return UserFunctionVariable diff --git a/torch/_dynamo/utils.py b/torch/_dynamo/utils.py index 6ff258ff5657c0..8130b6aa7371b2 100644 --- a/torch/_dynamo/utils.py +++ b/torch/_dynamo/utils.py @@ -63,7 +63,6 @@ import torch.utils._pytree as pytree from torch import fx from torch._C import ( - _get_function_stack_at, _instruction_counter, _len_torch_function_stack, _pop_torch_function_stack, @@ -3087,7 +3086,9 @@ def is_parameter_freezing(): def get_torch_function_mode_stack(filter_ignored=True): from .variables.torch_function import IGNORED_MODES - stack = [_get_function_stack_at(i) for i in range(_len_torch_function_stack())] + stack = [ + get_torch_function_mode_stack_at(i) for i in range(_len_torch_function_stack()) + ] if filter_ignored: stack = [mode for mode in stack if type(mode) not in IGNORED_MODES] @@ -3107,6 +3108,11 @@ def set_torch_function_mode_stack(stack): _push_on_torch_function_stack(mode) +def clear_torch_function_mode_stack(): + for i in range(_len_torch_function_stack()): + _pop_torch_function_stack() + + def verify_guard_fn_signature(value): fn = value.__metadata_guard__ sig = inspect.signature(fn) diff --git a/torch/_dynamo/variables/ctx_manager.py b/torch/_dynamo/variables/ctx_manager.py index eebe90929b379a..8c4eb3dc4e715b 100644 --- a/torch/_dynamo/variables/ctx_manager.py +++ b/torch/_dynamo/variables/ctx_manager.py @@ -637,6 +637,8 @@ def enter(self, tx): def _call_func(self, tx: "InstructionTranslator", values): assert len(values) == 1 + tx.symbolic_torch_function_state.torch_function_subclass_enabled = values[0] + tx.symbolic_torch_function_state.torch_function_mode_enabled = values[0] tx.output.set_torch_function_state(values[0]) diff --git a/torch/_dynamo/variables/torch.py b/torch/_dynamo/variables/torch.py index 68508f001dae35..f6cb565ef021c0 100644 --- a/torch/_dynamo/variables/torch.py +++ b/torch/_dynamo/variables/torch.py @@ -156,6 +156,15 @@ bin_ops = dict.fromkeys(["add", "sub", "mul", "div", "sqrt"]) +@functools.lru_cache(None) +def get_overridable_functions(): + from itertools import chain + + from torch.overrides import get_overridable_functions as get_overridable_functions_ + + return set(chain(*get_overridable_functions_().values())) + + class BaseTorchVariable(VariableTracker): """common base for all torch.* functions, classes, modules and other things""" @@ -806,10 +815,10 @@ def handle_pop_torch_function( self, tx: "InstructionTranslator", *args, **kwargs ): assert not args and not kwargs - if not tx.symbolic_torch_function_mode_stack: + if not tx.symbolic_torch_function_state.mode_stack: raise unimplemented("Popping from an empty torch function mode stack") TorchFunctionModeStackVariable.register_mutation(tx) - return tx.symbolic_torch_function_mode_stack.pop() + return tx.symbolic_torch_function_state.pop_torch_function_mode() @register(torch._C._push_on_torch_function_stack) def handle_push_torch_function( @@ -817,7 +826,7 @@ def handle_push_torch_function( ): assert len(args) == 1 and not kwargs TorchFunctionModeStackVariable.register_mutation(tx) - tx.symbolic_torch_function_mode_stack.append(args[0]) + tx.symbolic_torch_function_state.push_torch_function_mode(args[0]) return ConstantVariable.create(None) @register(torch._C._len_torch_function_stack) @@ -825,7 +834,9 @@ def handle_len_torch_function( self, tx: "InstructionTranslator", *args, **kwargs ): assert not args and not kwargs - return ConstantVariable.create(len(tx.symbolic_torch_function_mode_stack)) + return ConstantVariable.create( + len(tx.symbolic_torch_function_state.mode_stack) + ) @register(torch.set_default_device) def handle_set_default_device( @@ -857,6 +868,9 @@ def call_function( from . import ConstantVariable, SymNodeVariable, TensorVariable from .builder import wrap_fx_proxy + if self.torch_function_override_enabled(tx, args, kwargs): + return dispatch_torch_function(tx, self, args, kwargs) + if self.can_constant_fold_through() and check_unspec_or_constant_args( args, kwargs ): @@ -878,147 +892,144 @@ def call_function( if result: return result - if can_dispatch_torch_function(tx, args, kwargs): - return dispatch_torch_function(tx, self, args, kwargs) - else: - any_symints_or_symfloats = any(isinstance(x, SymNodeVariable) for x in args) + any_symints_or_symfloats = any(isinstance(x, SymNodeVariable) for x in args) - all_ints_or_floats = all( - isinstance(x, (variables.ConstantVariable, variables.SymNodeVariable)) - for x in args - ) - if ( - getattr(self.value, "__module__", "") == "torch" - and self.value.__name__ in bin_ops - and any_symints_or_symfloats - and all_ints_or_floats - ): - msg = f"""\ + all_ints_or_floats = all( + isinstance(x, (variables.ConstantVariable, variables.SymNodeVariable)) + for x in args + ) + if ( + getattr(self.value, "__module__", "") == "torch" + and self.value.__name__ in bin_ops + and any_symints_or_symfloats + and all_ints_or_floats + ): + msg = f"""\ Calling {str(self.value)} on only torch.SymInt arguments is not yet supported. To support this behavior, we need to allow const-propping tensors that store symint data. For now, dynamo will explicitly graph break when it encounters user code with this behavior. """ - log.warning(msg) - unimplemented(msg) - - # TODO(voz): Replace w/ dynamic shape rewrite table. - # Ideally, we would be able to do this at ctor time, but alas we need a combination - # of value + args to determine this. - fn_ = self.value - if any_symints_or_symfloats: - torch_sym_op = f"_sym_{self.value.__name__}" - if getattr(self.value, "__module__", None) == "math" and hasattr( - torch, torch_sym_op - ): - fn_ = getattr(torch, torch_sym_op) - - fake_out_shape = None - if "out" in kwargs and isinstance(kwargs["out"], variables.TensorVariable): - # Calling fake tensor propagation can mutate the out= tensor in - # tx.output.tracked_fakes. tracked_fakes are used to apply - # symbolic_shape guards. Mutating them destroys the information - # prior to tracing, which is essential for creating right - # guards. So save the shape now, and check later if it has - # changed. If it has, graph break. - fake_out_shape = kwargs["out"].proxy.node.meta["example_value"].shape - - tensor_variable = wrap_fx_proxy( - tx=tx, - proxy=tx.output.create_proxy( - "call_function", - fn_, - *proxy_args_kwargs(args, kwargs), - ), - ) - - if ( - isinstance(tensor_variable, TensorVariable) - and "requires_grad" in kwargs - and kwargs["requires_grad"].as_python_constant() + log.warning(msg) + unimplemented(msg) + + # TODO(voz): Replace w/ dynamic shape rewrite table. + # Ideally, we would be able to do this at ctor time, but alas we need a combination + # of value + args to determine this. + fn_ = self.value + if any_symints_or_symfloats: + torch_sym_op = f"_sym_{self.value.__name__}" + if getattr(self.value, "__module__", None) == "math" and hasattr( + torch, torch_sym_op ): - unimplemented( - """factory functions that return tensors that require grad are not supported. + fn_ = getattr(torch, torch_sym_op) + + fake_out_shape = None + if "out" in kwargs and isinstance(kwargs["out"], variables.TensorVariable): + # Calling fake tensor propagation can mutate the out= tensor in + # tx.output.tracked_fakes. tracked_fakes are used to apply + # symbolic_shape guards. Mutating them destroys the information + # prior to tracing, which is essential for creating right + # guards. So save the shape now, and check later if it has + # changed. If it has, graph break. + fake_out_shape = kwargs["out"].proxy.node.meta["example_value"].shape + + tensor_variable = wrap_fx_proxy( + tx=tx, + proxy=tx.output.create_proxy( + "call_function", + fn_, + *proxy_args_kwargs(args, kwargs), + ), + ) + + if ( + isinstance(tensor_variable, TensorVariable) + and "requires_grad" in kwargs + and kwargs["requires_grad"].as_python_constant() + ): + unimplemented( + """factory functions that return tensors that require grad are not supported. Either create the tensor outside the compiled region, or do not set the tensor to require_grad""" - ) + ) - if "out" in kwargs and not ( - isinstance(kwargs["out"], variables.ConstantVariable) - and kwargs["out"].as_python_constant() is None - ): - # out variants of torch operators like torch.sort and - # torch.sigmoid mutate the tensors in the out field. Track such - # tensors and rewrite the symbolic locals. - if isinstance(tensor_variable, TupleVariable): - assert isinstance(kwargs["out"], (TupleVariable, ListVariable)) - output_tensor_names = [ - tx.find_symbolic_locals_name(x) for x in kwargs["out"].items - ] - for idx, name in enumerate(output_tensor_names): - if name in tx.symbolic_locals: - tx.symbolic_locals[name] = tensor_variable.items[idx] - for out_tensor, result_tensor in zip( - kwargs["out"].items, tensor_variable.items - ): - if ( - out_tensor.source - and out_tensor in tx.output.graphargs - and isinstance(out_tensor, variables.TensorVariable) - and isinstance(result_tensor, variables.TensorVariable) - and out_tensor.size != result_tensor.size - ): - # It's hard to get out variants with resizing on graph inputs work - # properly across dynamo/aot/inductor, just fall back. - unimplemented("out variants with resizing on graph inputs") - elif isinstance(tensor_variable, TensorVariable): - assert isinstance(kwargs["out"], TensorVariable) - assert "example_value" in kwargs["out"].proxy.node.meta - fake_tensor = tensor_variable.proxy.node.meta["example_value"] - fake_out = kwargs["out"].proxy.node.meta["example_value"] + if "out" in kwargs and not ( + isinstance(kwargs["out"], variables.ConstantVariable) + and kwargs["out"].as_python_constant() is None + ): + # out variants of torch operators like torch.sort and + # torch.sigmoid mutate the tensors in the out field. Track such + # tensors and rewrite the symbolic locals. + if isinstance(tensor_variable, TupleVariable): + assert isinstance(kwargs["out"], (TupleVariable, ListVariable)) + output_tensor_names = [ + tx.find_symbolic_locals_name(x) for x in kwargs["out"].items + ] + for idx, name in enumerate(output_tensor_names): + if name in tx.symbolic_locals: + tx.symbolic_locals[name] = tensor_variable.items[idx] + for out_tensor, result_tensor in zip( + kwargs["out"].items, tensor_variable.items + ): if ( - kwargs["out"].source - and kwargs["out"] in tx.output.graphargs - and fake_out_shape != fake_tensor.shape + out_tensor.source + and out_tensor in tx.output.graphargs + and isinstance(out_tensor, variables.TensorVariable) + and isinstance(result_tensor, variables.TensorVariable) + and out_tensor.size != result_tensor.size ): # It's hard to get out variants with resizing on graph inputs work # properly across dynamo/aot/inductor, just fall back. unimplemented("out variants with resizing on graph inputs") + elif isinstance(tensor_variable, TensorVariable): + assert isinstance(kwargs["out"], TensorVariable) + assert "example_value" in kwargs["out"].proxy.node.meta + fake_tensor = tensor_variable.proxy.node.meta["example_value"] + fake_out = kwargs["out"].proxy.node.meta["example_value"] + if ( + kwargs["out"].source + and kwargs["out"] in tx.output.graphargs + and fake_out_shape != fake_tensor.shape + ): + # It's hard to get out variants with resizing on graph inputs work + # properly across dynamo/aot/inductor, just fall back. + unimplemented("out variants with resizing on graph inputs") + if not torch._prims_common.is_contiguous(fake_out): + # It's difficult to handle strides correctly in functionalization + # when calling an out= op with a non-contiguous out argument + unimplemented( + "out= op was called where output tensor was non-contiguous" + ) + name = tx.find_symbolic_locals_name(kwargs["out"]) + if name in tx.symbolic_locals: + tx.symbolic_locals[name] = tensor_variable + elif ( + isinstance(tensor_variable, ConstantVariable) + and tensor_variable.value is None + ): + # Handle out-variant custom ops that return None. + if isinstance(kwargs["out"], TensorVariable): + assert "example_value" in kwargs["out"].proxy.node.meta + fake_out = kwargs["out"].proxy.node.meta["example_value"] if not torch._prims_common.is_contiguous(fake_out): # It's difficult to handle strides correctly in functionalization # when calling an out= op with a non-contiguous out argument unimplemented( "out= op was called where output tensor was non-contiguous" ) - name = tx.find_symbolic_locals_name(kwargs["out"]) - if name in tx.symbolic_locals: - tx.symbolic_locals[name] = tensor_variable - elif ( - isinstance(tensor_variable, ConstantVariable) - and tensor_variable.value is None - ): - # Handle out-variant custom ops that return None. - if isinstance(kwargs["out"], TensorVariable): - assert "example_value" in kwargs["out"].proxy.node.meta - fake_out = kwargs["out"].proxy.node.meta["example_value"] + elif isinstance(kwargs["out"], ListVariable): + for idx, x in enumerate(kwargs["out"].items): + assert "example_value" in x.proxy.node.meta # type: ignore[attr-defined] + fake_out = x.proxy.node.meta["example_value"] # type: ignore[attr-defined] if not torch._prims_common.is_contiguous(fake_out): # It's difficult to handle strides correctly in functionalization # when calling an out= op with a non-contiguous out argument unimplemented( - "out= op was called where output tensor was non-contiguous" + "out= op was called where some of the output tensors were non-contiguous" ) - elif isinstance(kwargs["out"], ListVariable): - for idx, x in enumerate(kwargs["out"].items): - assert "example_value" in x.proxy.node.meta # type: ignore[attr-defined] - fake_out = x.proxy.node.meta["example_value"] # type: ignore[attr-defined] - if not torch._prims_common.is_contiguous(fake_out): - # It's difficult to handle strides correctly in functionalization - # when calling an out= op with a non-contiguous out argument - unimplemented( - "out= op was called where some of the output tensors were non-contiguous" - ) - else: - unimplemented(f"out variant of {type(kwargs['out'])}") + else: + unimplemented(f"out variant of {type(kwargs['out'])}") - return tensor_variable + return tensor_variable def _call_ntuple(self, tx: "InstructionTranslator", args, kwargs): """inline behavior of torch.nn.modules.utils._ntuple""" @@ -1146,3 +1157,12 @@ def _nn_param_via_prefix_insert(tx: "InstructionTranslator", data, requires_grad source ) return result + + def torch_function_override_enabled(self, tx, args, kwargs): + return ( + self.get_function() in get_overridable_functions() + or isinstance( + self.get_function(), + (torch._ops.OpOverload, torch._ops.OpOverloadPacket), + ) + ) and can_dispatch_torch_function(tx, args, kwargs) diff --git a/torch/_dynamo/variables/torch_function.py b/torch/_dynamo/variables/torch_function.py index 4b3188507fe4b5..6e52cc688e0309 100644 --- a/torch/_dynamo/variables/torch_function.py +++ b/torch/_dynamo/variables/torch_function.py @@ -1,8 +1,11 @@ # mypy: ignore-errors +import collections +import contextlib import inspect -from typing import Dict, List, TYPE_CHECKING +from typing import Deque, Dict, List, TYPE_CHECKING +import torch._C import torch.utils._pytree as pytree from torch._guards import Source from torch.overrides import _get_overloaded_args, get_default_nowrap_functions @@ -15,6 +18,7 @@ from .base import VariableTracker from .constant import ConstantVariable from .ctx_manager import ContextWrappingVariable +from .lazy import LazyVariableTracker from .lists import TupleVariable from .tensor import TensorSubclassVariable, TensorVariable from .user_defined import UserDefinedObjectVariable @@ -59,6 +63,67 @@ IGNORED_MODES = {DeviceContext} +class SymbolicTorchFunctionState: + def __init__(self, py_stack): + # This is annoyingly complicated because of how the torch function subclass + mode C API was designed + # There are two exposed C knobs here as contexts: torch._C.DisableTorchFunction and torch._C.DisableTorchFunctionSubclass + # These are their definitions: + # 1) torch._C._is_torch_function_enabled indicates that neither of the above knobs have been entered + # (if either are entered, this will be False) + # 2) torch._C._is_torch_function_mode_enabled indicates that either the torch mode stack is empty OR + # torch._C.DisableTorchFunction has been entered + # To disambiguate these and keep myself sane I added a C API to check whether all torch function + # concepts (modes and subclasses) are enabled. + # This only returns true iff we have not entered torch._C.DisableTorchFunction and allows us to separate + # the stack length from the enablement state of torch function modes. + # This is important because now if a mode is pushed while dynamo is tracing, we know whether + # or not torch function modes are enabled and whether we should trace it. + self.torch_function_subclass_enabled = torch._C._is_torch_function_enabled() + + # This differs from the C API of the same name + # this will only be false iff we have entered torch._C.DisableTorchFunction + # and does not take into account the mode stack length, while the C API bundles these + # two concepts + self.torch_function_mode_enabled = ( + not torch._C._is_torch_function_all_disabled() + ) + + self.cur_mode = None + + TorchFunctionModeStackVariable.reset() + + self.mode_stack: Deque[TorchFunctionModeVariable] = collections.deque() + + for i, val in enumerate(py_stack): + self.mode_stack.append( + LazyVariableTracker.create(val, source=TorchFunctionModeStackSource(i)) + ) + + def in_torch_function_mode(self): + return len(self.mode_stack) > 0 + + def pop_torch_function_mode(self): + return self.mode_stack.pop() + + def push_torch_function_mode(self, mode_var): + self.mode_stack.append(mode_var) + + def call_torch_function_mode(self, tx, fn, types, args, kwargs): + with self._pop_mode_for_inlining() as cur_mode: + return cur_mode.call_torch_function(tx, fn, types, args, kwargs) + + @contextlib.contextmanager + def _pop_mode_for_inlining(self): + old_mode = self.cur_mode + self.cur_mode = self.pop_torch_function_mode() + try: + yield self.cur_mode + finally: + mode = self.cur_mode + self.cur_mode = old_mode + self.push_torch_function_mode(mode) + + class TorchFunctionModeStackVariable(VariableTracker): """Fake VT to use as a dummy object, indicating the presence of torch function mode stack mutation""" @@ -88,19 +153,20 @@ def reset(cls): def register_mutation(cls, tx: "InstructionTranslator"): if cls.stack_value_singleton not in tx.output.side_effects: var = cls( - source=Source(), symbolic_stack=tx.symbolic_torch_function_mode_stack + source=Source(), + symbolic_stack=tx.symbolic_torch_function_state.mode_stack, ) tx.output.side_effects.track_mutable(cls.stack_value_singleton, var) tx.output.side_effects.mutation(var) @classmethod def register_device_context_insertion(cls, tx: "InstructionTranslator"): - stack = tx.symbolic_torch_function_mode_stack + stack = tx.symbolic_torch_function_state.mode_stack if stack and cls.is_device_context(stack[0]): return else: cls.offset += 1 - tx.symbolic_torch_function_mode_stack.insert( + stack.insert( 0, TorchFunctionModeVariable( None, source=TorchFunctionModeStackSource(-cls.offset) @@ -109,7 +175,7 @@ def register_device_context_insertion(cls, tx: "InstructionTranslator"): @classmethod def clear_default_device(cls, tx: "InstructionTranslator"): - stack = tx.symbolic_torch_function_mode_stack + stack = tx.symbolic_torch_function_state.mode_stack if stack and cls.is_device_context(stack[0]): stack.popleft() cls.offset -= 1 @@ -124,23 +190,39 @@ def get_mode_index(cls, ind): class TorchFunctionModeVariable(ContextWrappingVariable): - def __init__(self, value, **kwargs): + def __init__(self, value, source=None, **kwargs): super().__init__(value, **kwargs) self.value = value - - @staticmethod - def get_global_mangled_name(tx, val): - return get_safe_global_name( - tx, f"__torch_function_mode_{val.__class__.__name__}", val - ) + self.cm_obj = value # needed for BC with calling enter from CM code + self.source = source def reconstruct(self, codegen): - # We don't support locally created torch function modes yet + # This shouldn't be called unless we have a source assert self.source self.source.reconstruct(codegen) - def _call_func(self, tx, values): - unimplemented("torch function mode context manager is not supported yet") + def module_name(self): + return self.value.__module__ + + def fn_name(self): + return type(self.value).__name__ + + def python_type(self): + return type(self.value) + + def call_torch_function(self, tx: "InstructionTranslator", fn, types, args, kwargs): + return call_torch_function( + tx, + self, + build_torch_function_fn(tx, self.value, self.source), + fn, + types, + args, + kwargs, + ) + + def _call_func(self, tx: "InstructionTranslator", values): + unimplemented("enter/exit for torch function mode NYI") def _get_all_args(args, kwargs): @@ -231,9 +313,13 @@ def build_torch_function_fn(tx: "InstructionTranslator", value, source): def can_dispatch_torch_function(tx: "InstructionTranslator", args, kwargs): - return tx.output.torch_function_enabled and any( + has_overridden_args = any( has_torch_function(arg) for arg in _get_all_args(args, kwargs) ) + tf_state = tx.symbolic_torch_function_state + return (has_overridden_args and tf_state.torch_function_subclass_enabled) or ( + tf_state.torch_function_mode_enabled and tf_state.in_torch_function_mode() + ) def dispatch_torch_function(tx: "InstructionTranslator", fn, args, kwargs): @@ -245,11 +331,20 @@ def dispatch_torch_function(tx: "InstructionTranslator", fn, args, kwargs): _get_subclass_type, ) + types = TupleVariable([_get_subclass_type_var(tx, arg) for arg in overloaded_args]) + + if tx.symbolic_torch_function_state.in_torch_function_mode(): + res = tx.symbolic_torch_function_state.call_torch_function_mode( + tx, fn, types, args, kwargs + ) + if not (isinstance(res, ConstantVariable) and res.value is NotImplemented): + return res + for arg in overloaded_args: res = arg.call_torch_function( tx, fn, - TupleVariable([_get_subclass_type_var(tx, arg) for arg in overloaded_args]), + types, args, kwargs, ) diff --git a/torch/_dynamo/variables/user_defined.py b/torch/_dynamo/variables/user_defined.py index 34e1d0d10c9f72..d069d2e060a108 100644 --- a/torch/_dynamo/variables/user_defined.py +++ b/torch/_dynamo/variables/user_defined.py @@ -82,11 +82,6 @@ def is_forbidden_context_manager(ctx): from _pytest.python_api import RaisesContext from _pytest.recwarn import WarningsChecker - # TODO mlazos: Temporary to get this stack to pass - # remove in subsequent PR - from torch.overrides import BaseTorchFunctionMode - - f_ctxs.append(BaseTorchFunctionMode) f_ctxs.append(RaisesContext) f_ctxs.append(WarningsChecker) except ImportError: @@ -413,7 +408,6 @@ def call_function( and self.source and not is_forbidden_context_manager(self.value) ): - # import here to avoid an unfortunate circular dependency. from .ctx_manager import GenericContextWrappingVariable cm_obj = tx.output.side_effects.track_object_new( @@ -421,7 +415,6 @@ def call_function( ) cm_obj.call_method(tx, "__init__", args, kwargs) return cm_obj - elif is_namedtuple_cls(self.value): fields = namedtuple_fields(self.value) # check if this a quasi-namedtuple or a real one diff --git a/torch/_export/non_strict_utils.py b/torch/_export/non_strict_utils.py index 5c7331659d26a0..ef15e5fea9e973 100644 --- a/torch/_export/non_strict_utils.py +++ b/torch/_export/non_strict_utils.py @@ -506,7 +506,11 @@ class _NonStrictTorchFunctionHandler(torch.overrides.TorchFunctionMode): def __torch_function__(self, func, types, args=(), kwargs=None): kwargs = kwargs or {} - if log.isEnabledFor(logging.DEBUG) and config.extended_debug_current_loc: + if ( + not torch.compiler.is_dynamo_compiling() + and log.isEnabledFor(logging.DEBUG) + and config.extended_debug_current_loc + ): frame = _find_user_code_frame() if frame is not None: log.debug( From 04996be1671e3a4e23c00c869fa868392634c8d3 Mon Sep 17 00:00:00 2001 From: Michael Lazos Date: Fri, 13 Sep 2024 10:40:54 -0700 Subject: [PATCH 0870/1018] [Dynamo] Support thread local setattr (#135443) In preparation for tracing through DeviceContext (https://github.com/pytorch/pytorch/blob/defb515306fc53ec62e92937a5a76fa5cbc05b84/torch/utils/_device.py#L66) This PR adds support for calling the setattr of thread local objects. These objects have a slots impl, and since this doesn't appear to have any side effects, we call this setattr impl when replaying mutations, since calling `object.__setattr__` on these objects results in a type error. Pull Request resolved: https://github.com/pytorch/pytorch/pull/135443 Approved by: https://github.com/anijain2305 ghstack dependencies: #134732, #133137 --- test/dynamo/test_misc.py | 15 +++++++++++++++ torch/_dynamo/variables/user_defined.py | 5 +++-- 2 files changed, 18 insertions(+), 2 deletions(-) diff --git a/test/dynamo/test_misc.py b/test/dynamo/test_misc.py index 9d2bb374621134..96047d3436ebe1 100644 --- a/test/dynamo/test_misc.py +++ b/test/dynamo/test_misc.py @@ -3380,6 +3380,21 @@ def fn4(x) -> None: self.assertTrue(same(obj41.y, obj42.y)) self.assertEqual(cnts.frame_count, 1) + def test_thread_local_setattr(self): + from threading import local + + loc = local() + + @torch.compile(fullgraph=True) + def fn(x, l): + l.x = x + return x + 1 + + x = torch.ones(2, 2) + fn(x, loc) + + self.assertTrue(loc.x is x) + def test_user_defined_class_name(self): class MyClassFoo: pass diff --git a/torch/_dynamo/variables/user_defined.py b/torch/_dynamo/variables/user_defined.py index d069d2e060a108..ad92277cf70ff4 100644 --- a/torch/_dynamo/variables/user_defined.py +++ b/torch/_dynamo/variables/user_defined.py @@ -9,6 +9,7 @@ import itertools import random import sys +import threading import types import warnings from typing import Dict, Generic, List, TYPE_CHECKING @@ -704,7 +705,7 @@ def call_method( if method is object.__init__: return ConstantVariable.create(None) - if is_standard_setattr(method): + if is_standard_setattr(method) or isinstance(self.value, threading.local): return self.method_setattr_standard(tx, *args, **kwargs) # [NOTE] OrderedDict, dict subtypes must always have source @@ -802,7 +803,7 @@ def method_setattr_standard(self, tx: "InstructionTranslator", name, value): def needs_slow_setattr(self): return not is_standard_setattr( inspect.getattr_static(self.value, "__setattr__", None) - ) + ) and not isinstance(self.value, threading.local) def unpack_var_sequence(self, tx): if ( From e684dc67480245d905ec0d8e9c13d6cabd00720c Mon Sep 17 00:00:00 2001 From: Michael Lazos Date: Fri, 13 Sep 2024 10:40:55 -0700 Subject: [PATCH 0871/1018] [Dynamo] Simplify torch function mode stack guard (#135444) The semantics of ignored modes previously had edge cases, this eliminates these by in essence filtering any ignored modes out of both the ref stack and the current torch function mode stack. This is purely to fix complexity in #135422. The ignored modes handling will be removed in a future PR after https://github.com/pytorch/pytorch/pull/135422 lands, since we will then trace through DeviceContexts vs inserting them into the graph which needed these extra workarounds for correctness. Pull Request resolved: https://github.com/pytorch/pytorch/pull/135444 Approved by: https://github.com/anijain2305, https://github.com/williamwen42 ghstack dependencies: #134732, #133137, #135443 --- test/dynamo/test_modes.py | 2 ++ torch/_dynamo/guards.py | 10 ++++---- torch/csrc/dynamo/guards.cpp | 44 +++++++++++++++++++++++++++++------- 3 files changed, 44 insertions(+), 12 deletions(-) diff --git a/test/dynamo/test_modes.py b/test/dynamo/test_modes.py index bad5a788903440..fa4c23fd320ddb 100644 --- a/test/dynamo/test_modes.py +++ b/test/dynamo/test_modes.py @@ -68,9 +68,11 @@ def tearDownClass(cls): def setUp(self): torch.set_default_device(None) + torch._dynamo.reset() def tearDown(self): torch.set_default_device(None) + torch._dynamo.reset() def _run_torch_function_mode_guard_test(self): class TestMode1(BaseTorchFunctionMode): diff --git a/torch/_dynamo/guards.py b/torch/_dynamo/guards.py index 4437ef4038190d..1544bdd6dc697b 100644 --- a/torch/_dynamo/guards.py +++ b/torch/_dynamo/guards.py @@ -2663,12 +2663,14 @@ def make_torch_function_mode_stack_guard(intial_stack): def check_torch_function_mode_stack(): cur_stack = get_torch_function_mode_stack() - if len(cur_stack) != len(types): + + types_ = [ty for ty in types if ty not in IGNORED_MODES] + cur_stack_ = [mode for mode in cur_stack if type(mode) not in IGNORED_MODES] + + if len(cur_stack_) != len(types_): return False - for ty, mode in zip(types, cur_stack): - if ty in IGNORED_MODES: - continue + for ty, mode in zip(types_, cur_stack_): if ty != type(mode): return False diff --git a/torch/csrc/dynamo/guards.cpp b/torch/csrc/dynamo/guards.cpp index afc3182c632640..771de44427b669 100644 --- a/torch/csrc/dynamo/guards.cpp +++ b/torch/csrc/dynamo/guards.cpp @@ -2523,7 +2523,8 @@ class TORCH_FUNCTION_MODE_STACK : public LeafGuard { Py_ssize_t len = PyList_Size(initial_stack.ptr()); for (Py_ssize_t idx = 0; idx < len; idx++) { PyObject* mode = PyList_GetItem(initial_stack.ptr(), idx); // borrowed ref - this->_ref_stack.push_back(Py_TYPE(mode)); + auto type = Py_TYPE(mode); + this->_ref_stack.push_back(type); } len = PyList_Size(ignored_types.ptr()); @@ -2543,29 +2544,56 @@ class TORCH_FUNCTION_MODE_STACK : public LeafGuard { bool check_nopybind(PyObject* value) override { // Ignore value arg, only used to satisfy the interface size_t ref_ind = 0; - int64_t len = at::impl::PythonTorchFunctionTLS::stack_len(); + const int64_t len = at::impl::PythonTorchFunctionTLS::stack_len(); const size_t ref_stack_size = this->_ref_stack.size(); - for (int64_t idx = 0; idx < len; idx++) { + int64_t idx = 0; + while ((idx < len) && (ref_ind < ref_stack_size)) { std::shared_ptr mode = at::impl::PythonTorchFunctionTLS::get_stack_at(idx); PyTypeObject* mode_type = Py_TYPE(mode->ptr(getPyInterpreter())); + bool act_ignored = this->_ignored_types.count(mode_type) > 0; + bool ref_ignored = + this->_ignored_types.count(this->_ref_stack.at(ref_ind)) > 0; // skip ignored types - if (this->_ignored_types.count(mode_type) > 0) { + if (act_ignored && ref_ignored) { + idx++; + ref_ind++; + continue; + } else if (ref_ignored) { + ref_ind++; + continue; + } else if (act_ignored) { + idx++; continue; } // if we already have more non-ignored modes than the ref stack // or if the mode doesn't match at the current index, return false - else if ( - (ref_stack_size == 0) || (ref_ind > ref_stack_size - 1) || - mode_type != _ref_stack[ref_ind]) { + else if (mode_type != _ref_stack.at(ref_ind)) { return false; } ref_ind++; + idx++; + } + + for (; ref_ind < ref_stack_size; ref_ind++) { + if (!(this->_ignored_types.count(this->_ref_stack.at(ref_ind)) > 0)) { + return false; + } + } + + for (; idx < len; idx++) { + std::shared_ptr mode = + at::impl::PythonTorchFunctionTLS::get_stack_at(idx); + + PyTypeObject* mode_type = Py_TYPE(mode->ptr(getPyInterpreter())); + if (!(this->_ignored_types.count(mode_type) > 0)) { + return false; + } } - return ref_ind == this->_ref_stack.size(); + return ref_ind == ref_stack_size && idx == len; } private: From 54999a570b7ac8ce623afba3a0c0c3b4274eb90a Mon Sep 17 00:00:00 2001 From: Michael Lazos Date: Fri, 13 Sep 2024 10:40:55 -0700 Subject: [PATCH 0872/1018] [Dynamo] Trace enter/exit of TorchFunctionModes (#135422) This PR implements tracing of with contexts with TorchFunction modes which have the default enter/exit behavior (ie pushing/popping the mode) Typically the bytecode for a context manager looks like this during a graph break: 1. graph call 2. enter context 3. unsupported code 4. exit context 5. resume call resume fn structure: 1. enter context 2. jump ... 3. exit context The issue with torch function modes is that side effects will replay any mutations to the torch function stack performed during tracing. So, we do not need to enter and exit around the unsupported code in the original function (doing so would result in a duplicate torch function mode entry during execution of the unsupported code), and we don't need to enter again in the resume function (the mode that was pushed from the side effects bytecode would still be on the stack). So for torch function modes the structure of our output code is this: 1. graph call 2. mutate tf mode stack to replay mutations 4. unsupported code 5. on exception restore stack 6. resume function Then our resume fn looks like this: 1. no-op enter torch function mode 2. jump 3. exit tf mode To implement the no-op enter of the torch function mode I added torch function mode in polyfill which no-op enters, but normally exits. This is needed because we still want to trace the with context in the resume function, and exit properly (the exit instructions will still be in the function, so we need to generate instructions to set up the context). Separately from the bytecode, dynamo also tracks contexts on the block stack, which is how the SETUP_* instructions are implemented. Naturally at a graph break, we exit these block stacks to properly reset the contexts entirely, so that we can re-enter around the unsupported code soundly. However once again, in the torch function mode case, in the event of a graph we do not want to perform any exit side effects because we want to preserve the state of the mode stack as is so that we will properly update the stack with bytecode mentioned in the first section. If we exited here, dynamo would pop the mode off of the symbolic stack, and not update the true python torch function mode stack with the suffix bytecode. All in all, for torch function modes we enter exactly once, update the global torch function mode stack with side effects bytecode, re-read this stack when compiling the resume function, and exit exactly once in the resume function. This matches the semantics of eager exactly. Pull Request resolved: https://github.com/pytorch/pytorch/pull/135422 Approved by: https://github.com/williamwen42 ghstack dependencies: #134732, #133137, #135443, #135444 --- test/dynamo/test_modes.py | 88 ++++++++++++++++++ torch/_dynamo/convert_frame.py | 6 +- torch/_dynamo/output_graph.py | 6 +- torch/_dynamo/polyfills/__init__.py | 20 ++++ torch/_dynamo/resume_execution.py | 104 +++++++++++++++++++++ torch/_dynamo/side_effects.py | 11 +++ torch/_dynamo/symbolic_convert.py | 30 ++++-- torch/_dynamo/testing.py | 1 + torch/_dynamo/trace_rules.py | 2 +- torch/_dynamo/variables/builder.py | 20 ++-- torch/_dynamo/variables/ctx_manager.py | 12 +++ torch/_dynamo/variables/torch.py | 21 ++++- torch/_dynamo/variables/torch_function.py | 108 ++++++++++++++++++++-- torch/_dynamo/variables/user_defined.py | 14 ++- 14 files changed, 409 insertions(+), 34 deletions(-) diff --git a/test/dynamo/test_modes.py b/test/dynamo/test_modes.py index fa4c23fd320ddb..63e06a515ed186 100644 --- a/test/dynamo/test_modes.py +++ b/test/dynamo/test_modes.py @@ -461,6 +461,94 @@ def fn(x): self.assertEqual(expected, actual) + def test_torch_function_mode_enter_exit(self): + def fn(x, y): + with TestMode(): + o = torch.add(x, 3) + + return torch.add(o, y) + + inp = (torch.ones(2, 2) + 1, torch.ones(2, 2) + 2) + fn_opt = torch.compile(fn, fullgraph=True) + + expected = fn(*inp) + actual = fn_opt(*inp) + + self.assertEqual(expected, actual) + + def test_torch_function_mode_graph_break(self): + def fn(x, y): + with TestMode(): + torch._dynamo.graph_break() + o = torch.add(x, 3) + + return torch.add(o, y) + + inp = (torch.ones(2, 2) + 1, torch.ones(2, 2) + 2) + fn_opt = torch.compile(fn) + + expected = fn(*inp) + actual = fn_opt(*inp) + + self.assertEqual(expected, actual) + + def test_torch_function_mode_and_pop_graph_break(self): + def fn(x, y): + with TestMode(): + z = _pop_torch_function_stack() + torch._dynamo.graph_break() + _push_on_torch_function_stack(z) + o = torch.add(x, 3) + + return torch.add(o, y) + + inp = (torch.ones(2, 2) + 1, torch.ones(2, 2) + 2) + fn_opt = torch.compile(fn) + + expected = fn(*inp) + actual = fn_opt(*inp) + + self.assertEqual(expected, actual) + + def test_torch_function_mode_restore_on_exc(self): + @torch._dynamo.disable() + def err(): + raise RuntimeError("test") + + @torch.compile() + def fn(x): + with TestMode(): + x += 1 + err() + x += 2 + return x + + try: + fn(torch.ones(2, 2)) + except RuntimeError: + pass + self.assertEqual(_len_torch_function_stack(), 0) + + def test_torch_function_mode_and_pop_graph_break_mutation(self): + def fn(x, y): + with TestMode(): + z = _pop_torch_function_stack() + z.y = 5 + torch._dynamo.graph_break() + _push_on_torch_function_stack(z) + o = torch.add(x, 3) + o = torch.mul(o, z.y) + + return torch.add(o, y) + + inp = (torch.ones(2, 2) + 1, torch.ones(2, 2) + 2) + fn_opt = torch.compile(fn) + + expected = fn(*inp) + actual = fn_opt(*inp) + + self.assertEqual(expected, actual) + if __name__ == "__main__": from torch._dynamo.test_case import run_tests diff --git a/torch/_dynamo/convert_frame.py b/torch/_dynamo/convert_frame.py index b3e4a7c268c0fe..171af02e564b36 100644 --- a/torch/_dynamo/convert_frame.py +++ b/torch/_dynamo/convert_frame.py @@ -112,6 +112,7 @@ troubleshooting_url, write_record_to_file, ) +from .variables.torch_function import torch_function_mode_stack_state_mgr np: Optional[ModuleType] @@ -210,15 +211,18 @@ def _fn(*args: _P.args, **kwargs: _P.kwargs) -> _T: prior_fwd_from_src = torch.fx.graph_module._forward_from_src torch.fx.graph_module._forward_from_src = fx_forward_from_src_skip_result cleanup = setup_compile_debug() - exit_stack = contextlib.ExitStack() exit_stack.enter_context( torch.fx._symbolic_trace._maybe_revert_all_patches() ) + exit_stack.enter_context(torch_function_mode_stack_state_mgr) try: return fn(*args, **kwargs) finally: cleanup.close() + assert ( + torch._C._len_torch_function_stack() == 0 + ), "Torch function mode stack state changed while dynamo tracing, please report a bug" exit_stack.close() torch._C._set_grad_enabled(prior_grad_mode) torch.autograd.grad_mode._enter_inference_mode(prior_inference_mode) diff --git a/torch/_dynamo/output_graph.py b/torch/_dynamo/output_graph.py index ba8cfa5a03dfba..0ea19e1cad6586 100644 --- a/torch/_dynamo/output_graph.py +++ b/torch/_dynamo/output_graph.py @@ -78,7 +78,6 @@ get_instruction_source_311, get_locals_to_steal, get_static_address_type, - get_torch_function_mode_stack, graph_break_reasons, increment_op_count, lazy_format_graph_code, @@ -250,6 +249,7 @@ def __init__( local_scope: Scope, global_scope: Scope, f_code, + torch_function_mode_stack, ): super().__init__() self.tracers = [SubgraphTracer(self, export_root=export)] @@ -368,7 +368,7 @@ def __init__( # This returns false if TF Overall (both mode and subclass) is disabled OR that TF Mode stack is empty self.torch_function_mode_enabled = torch._C._is_torch_function_mode_enabled() # This records the initial torch function mode stack for guarding - self.torch_function_mode_stack = get_torch_function_mode_stack() + self.torch_function_mode_stack = torch_function_mode_stack # Tracks if the output graph has a user defined allowed function in the # graph. This is used later to determine if we should fallback to eager @@ -1020,7 +1020,7 @@ def append_prefix_insts(): prefix_insts.clear() for block in reversed(tx.block_stack): - block.exit(tx) + block.exit(tx, is_graph_break=reason.graph_break) self.cleanup_graph() tx.prune_dead_locals() diff --git a/torch/_dynamo/polyfills/__init__.py b/torch/_dynamo/polyfills/__init__.py index 2b3f38920a0c2a..5b2812bc08c9e5 100644 --- a/torch/_dynamo/polyfills/__init__.py +++ b/torch/_dynamo/polyfills/__init__.py @@ -25,6 +25,26 @@ sys as sys, ) +from torch.overrides import BaseTorchFunctionMode + + +# These classes handle support for TorchFunctionModes across +# graph breaks +# Today the TorchFunctionMode enter (for the classes we support) +# simply pushes the mode onto the stack. Since after this occurs +# the stack is mutated, and we replay these mutations, we don't need +# any cleanup logic to be run once the graph break occurs, we simply replay +# these mutations to ensure at the graph break the torch function mode stack is correct +# and reconstruct the torch function mode stack normally +# when we compile the resume function on the other side of the break. +# However, to ensure we exit properly +# in the resume function, we need to re-enter the contexts as we do other contexts. +# These contexts do nothing on enter, but provide the correct exit logic to ensure +# the stack state is correct. +class NoEnterTorchFunctionMode(BaseTorchFunctionMode): + def __enter__(self): + pass + def index(iterator, item, start=0, end=None): from itertools import islice diff --git a/torch/_dynamo/resume_execution.py b/torch/_dynamo/resume_execution.py index 132e9e4081bceb..6f7db66514c487 100644 --- a/torch/_dynamo/resume_execution.py +++ b/torch/_dynamo/resume_execution.py @@ -6,6 +6,7 @@ from typing import Any, cast, Dict, List, Optional, Tuple from .bytecode_transformation import ( + add_push_null, create_call_function, create_call_method, create_dup_top, @@ -48,6 +49,109 @@ class ReenterWith: stack_index: int target_values: Optional[Tuple[Any, ...]] = None + def try_except_torch_function_mode(self, code_options, cleanup: List[Instruction]): + """ + Codegen based off of: + try: + (rest) + except: + (restore previous stack) + + """ + from .variables.torch_function import get_prev_stack_var_name + + except_jump_target = create_instruction( + "NOP" if sys.version_info < (3, 11) else "PUSH_EXC_INFO" + ) + cleanup_complete_jump_target = create_instruction("NOP") + + setup_finally: List[Instruction] = [] + + if sys.version_info < (3, 11): + setup_finally.append( + create_instruction("SETUP_FINALLY", target=except_jump_target) + ) + else: + exn_tab_begin = create_instruction("NOP") + exn_tab_end = create_instruction("NOP") + exn_tab_begin.exn_tab_entry = InstructionExnTabEntry( + exn_tab_begin, + exn_tab_end, + except_jump_target, + self.stack_index + 1, + False, + ) + setup_finally.append(exn_tab_begin) + + def create_reset(): + insts = [ + create_instruction( + "LOAD_GLOBAL", argval="__import_torch_dot__dynamo_dot_utils" + ), + create_instruction("LOAD_ATTR", argval="set_torch_function_mode_stack"), + ] + add_push_null(insts) + return [ + *insts, + create_instruction("LOAD_FAST", argval=get_prev_stack_var_name()), + *create_call_function(1, False), + create_instruction("POP_TOP"), + ] + + if sys.version_info < (3, 9): + epilogue = [ + create_instruction("POP_BLOCK"), + create_instruction("JUMP_FORWARD", target=cleanup_complete_jump_target), + except_jump_target, + *create_reset(), + create_instruction("POP_TOP"), + create_instruction("POP_TOP"), + create_instruction("POP_TOP"), + *create_reset(), + create_instruction("RAISE_VARARGS", argval=0), + create_instruction("POP_EXCEPT", argval=0), + create_instruction("END_FINALLY"), + cleanup_complete_jump_target, + ] + elif sys.version_info < (3, 11): + epilogue = [ + create_instruction("POP_BLOCK"), + create_instruction("JUMP_FORWARD", target=cleanup_complete_jump_target), + except_jump_target, + create_instruction("POP_TOP"), + create_instruction("POP_TOP"), + create_instruction("POP_TOP"), + *create_reset(), + create_instruction("RAISE_VARARGS", argval=0), + create_instruction("POP_EXCEPT", argval=0), + cleanup_complete_jump_target, + ] + else: + finally_exn_tab_end = create_instruction("RAISE_VARARGS", argval=0) + finally_exn_tab_target = create_instruction("COPY", arg=3) + except_jump_target.exn_tab_entry = InstructionExnTabEntry( + except_jump_target, + finally_exn_tab_end, + finally_exn_tab_target, + self.stack_index + 2, + True, + ) + epilogue = [ + exn_tab_end, + create_instruction("JUMP_FORWARD", target=cleanup_complete_jump_target), + except_jump_target, # PUSH_EXC_INFO + create_instruction("POP_TOP"), + *create_reset(), + finally_exn_tab_end, + finally_exn_tab_target, # COPY 3 + create_instruction("POP_EXCEPT"), + create_instruction("RERAISE", arg=1), # RERAISE 1 + cleanup_complete_jump_target, + ] + + cleanup[:] = epilogue + cleanup + return setup_finally + # If we do not want to destroy the stack, we can do the same thing as a # `SETUP_WITH` block, only that we store the context manager in a local_symbol def try_except(self, code_options, cleanup: List[Instruction]): diff --git a/torch/_dynamo/side_effects.py b/torch/_dynamo/side_effects.py index fc1bd976ff57df..fd5b54e4f7f7af 100644 --- a/torch/_dynamo/side_effects.py +++ b/torch/_dynamo/side_effects.py @@ -593,11 +593,22 @@ def codegen_update_mutated(self, cg: PyCodegen): elif isinstance( var, variables.torch_function.TorchFunctionModeStackVariable ): + # Needed in the finally block for stack restoration + cg.add_push_null( + lambda: cg.load_import_from( + utils.__name__, "get_torch_function_mode_stack" + ) + ) + cg.call_function(0, False) + name = variables.torch_function.get_prev_stack_var_name() + cg.code_options["co_varnames"] += (name,) + cg.append_output(create_instruction("STORE_FAST", argval=name)) cg.add_push_null( lambda: cg.load_import_from( utils.__name__, "set_torch_function_mode_stack" ) ) + cg.foreach(var.symbolic_stack) cg.append_output( create_instruction("BUILD_LIST", arg=len(var.symbolic_stack)) diff --git a/torch/_dynamo/symbolic_convert.py b/torch/_dynamo/symbolic_convert.py index 415179a5ed32a6..61acaec995757a 100644 --- a/torch/_dynamo/symbolic_convert.py +++ b/torch/_dynamo/symbolic_convert.py @@ -267,13 +267,12 @@ def resume_fn(self): else: return ReenterWith(self.stack_index) - def exit(self, tx): - if hasattr(self, "graph_break") and isinstance( - self.with_context, TorchFunctionModeVariable - ): - return + def exit(self, tx, is_graph_break): assert self.with_context is not None - return self.with_context.exit(tx) + if ( + is_graph_break and self.with_context.exit_on_graph_break() + ) or not is_graph_break: + return self.with_context.exit(tx) class ReturnValueOp(Exception): @@ -639,10 +638,17 @@ def handle_graph_break( cleanup: List[Instruction] = [] # Reconstruct the context variable CLASS in the block stack for b in self.block_stack: + # Don't exit any modes we have entered, + # output bytecode will mutate the tf mode stack accordingly + if isinstance(b.with_context, TorchFunctionModeVariable): + cg.extend_output( + b.resume_fn().try_except_torch_function_mode( + cg.code_options, cleanup + ) + ) + continue assert b.with_context is not None - assert isinstance( - b.with_context, (ContextWrappingVariable, TorchFunctionModeVariable) - ) + assert isinstance(b.with_context, (ContextWrappingVariable)) b.with_context.reconstruct_type(cg) cg.extend_output(b.resume_fn().try_except(cg.code_options, cleanup)) self.output.add_output_instructions(cg.get_instructions()) @@ -2295,7 +2301,10 @@ def setup_or_before_with(self, inst): ): unimplemented(f"{inst.opname} {ctx}") - if isinstance(ctx, GenericContextWrappingVariable): + if ( + isinstance(ctx, GenericContextWrappingVariable) + and not ctx.supports_graph_breaks() + ): self.generic_context_manager_depth += 1 # Need this redundant check for mypy @@ -2668,6 +2677,7 @@ def __init__( local_scope=f_locals, global_scope=f_globals, f_code=f_code, + torch_function_mode_stack=torch_function_mode_stack, ), instructions=instructions, f_locals=f_locals, diff --git a/torch/_dynamo/testing.py b/torch/_dynamo/testing.py index 4922b521bada34..704a3889707233 100644 --- a/torch/_dynamo/testing.py +++ b/torch/_dynamo/testing.py @@ -163,6 +163,7 @@ def insert_nops(instructions, code_options): local_scope=locals(), global_scope=globals(), f_code=frame.f_code, + torch_function_mode_stack=[], ) return GuardedCode(code, CheckFunctionManager(graph).check_fn, CompileId(0, 0)) diff --git a/torch/_dynamo/trace_rules.py b/torch/_dynamo/trace_rules.py index 1a206ae339c2db..ed81b26191b4de 100644 --- a/torch/_dynamo/trace_rules.py +++ b/torch/_dynamo/trace_rules.py @@ -304,6 +304,7 @@ "torch.fx.experimental.symbolic_shapes.guard_size_oblivious": TorchInGraphFunctionVariable, "torch.cuda._get_device_properties": TorchInGraphFunctionVariable, "torch.utils.hooks.BackwardHook": TorchInGraphFunctionVariable, + "torch.set_default_device": UserFunctionVariable, "torch.sparse_bsc_tensor": SkipFunctionVariable, "torch.sparse_bsr_tensor": SkipFunctionVariable, "torch.sparse_csc_tensor": SkipFunctionVariable, @@ -2797,7 +2798,6 @@ "torch.random.initial_seed", "torch.random.seed", "torch.return_types.pytree_register_structseq", - "torch.set_default_device", "torch.set_default_dtype", "torch.set_default_tensor_type", "torch.set_deterministic_debug_mode", diff --git a/torch/_dynamo/variables/builder.py b/torch/_dynamo/variables/builder.py index 5a8a9b613b899d..193e227267ec3b 100644 --- a/torch/_dynamo/variables/builder.py +++ b/torch/_dynamo/variables/builder.py @@ -204,6 +204,7 @@ from .torch_function import ( build_torch_function_fn, TensorWithTFOverrideVariable, + torch_function_mode_stack_state_mgr, TorchFunctionModeVariable, ) from .user_defined import ( @@ -1668,15 +1669,16 @@ def wrap_numpy_ndarray(self, value): # but warning is not the end of the world assert isinstance(value.base, np.nditer) - try: - tensor_value = _util._try_convert_to_tensor(value) - if readonly: - from torch._prims_common import clone_preserve_strides - - tensor_value = clone_preserve_strides(tensor_value) - except NotImplementedError as e: - # failed to convert to tensor, graph break - unimplemented(str(e)) + with torch_function_mode_stack_state_mgr.temp_restore_stack(): + try: + tensor_value = _util._try_convert_to_tensor(value) + if readonly: + from torch._prims_common import clone_preserve_strides + + tensor_value = clone_preserve_strides(tensor_value) + except NotImplementedError as e: + # failed to convert to tensor, graph break + unimplemented(str(e)) # We do this because we want the full behavior of guarding the numpy ndarray as if it were # a tensor. It's a little annoying to make a VT to throw out, but there's so many side effects here diff --git a/torch/_dynamo/variables/ctx_manager.py b/torch/_dynamo/variables/ctx_manager.py index 8c4eb3dc4e715b..e19c4e254c647e 100644 --- a/torch/_dynamo/variables/ctx_manager.py +++ b/torch/_dynamo/variables/ctx_manager.py @@ -125,6 +125,12 @@ def call_function( if isinstance(args[0], UserFunctionVariable): return WrappedUserFunctionVariable(args[0], self) + def supports_graph_breaks(self): + return True + + def exit_on_graph_break(self): + return True + class GenericContextWrappingVariable(UserDefinedObjectVariable): # Some methods in ContextWrappingVariable assumes the arguments are @@ -183,6 +189,12 @@ def exit(self, tx: "InstructionTranslator", *args): tx.generic_context_manager_depth -= 1 return x + def supports_graph_breaks(self): + return False + + def exit_on_graph_break(self): + return True + class GradInplaceRequiresGradCtxManagerVariable(ContextWrappingVariable): """represents torch grad requries grad""" diff --git a/torch/_dynamo/variables/torch.py b/torch/_dynamo/variables/torch.py index f6cb565ef021c0..c1e0dec0fbc415 100644 --- a/torch/_dynamo/variables/torch.py +++ b/torch/_dynamo/variables/torch.py @@ -162,7 +162,17 @@ def get_overridable_functions(): from torch.overrides import get_overridable_functions as get_overridable_functions_ - return set(chain(*get_overridable_functions_().values())) + funcs = set(chain(*get_overridable_functions_().values())) + more = { + torch.ones, + torch.ones_like, + torch.zeros, + torch.zeros_like, + torch.empty, + torch.full, + } + funcs.update(more) + return funcs class BaseTorchVariable(VariableTracker): @@ -838,6 +848,13 @@ def handle_len_torch_function( len(tx.symbolic_torch_function_state.mode_stack) ) + @register(torch._C._get_function_stack_at) + def handle_get_stack_at(self, tx: "InstructionTranslator", *args, **kwargs): + assert len(args) == 1 and not kwargs + ind = args[0].as_python_constant() + assert ind >= 0 and ind < len(tx.symbolic_torch_function_state.mode_stack) + return tx.symbolic_torch_function_state.mode_stack[ind] + @register(torch.set_default_device) def handle_set_default_device( self, tx: "InstructionTranslator", *args, **kwargs @@ -855,7 +872,7 @@ def handle_set_default_device( else: TorchFunctionModeStackVariable.register_device_context_insertion(tx) - return None + return ConstantVariable.create(None) return handlers diff --git a/torch/_dynamo/variables/torch_function.py b/torch/_dynamo/variables/torch_function.py index 6e52cc688e0309..b7fd83a3d24f15 100644 --- a/torch/_dynamo/variables/torch_function.py +++ b/torch/_dynamo/variables/torch_function.py @@ -2,22 +2,35 @@ import collections import contextlib +import functools import inspect from typing import Deque, Dict, List, TYPE_CHECKING import torch._C import torch.utils._pytree as pytree from torch._guards import Source -from torch.overrides import _get_overloaded_args, get_default_nowrap_functions +from torch.overrides import ( + _get_overloaded_args, + get_default_nowrap_functions, + TorchFunctionMode, +) from torch.utils._device import DeviceContext from ..exc import unimplemented from ..guards import GuardBuilder, install_guard +from ..polyfills import NoEnterTorchFunctionMode from ..source import AttrSource, GlobalSource, TorchFunctionModeStackSource, TypeSource -from ..utils import get_safe_global_name, has_torch_function, is_tensor_base_attr_getter +from ..utils import ( + class_has_getattribute, + clear_torch_function_mode_stack, + get_safe_global_name, + has_torch_function, + is_tensor_base_attr_getter, + set_torch_function_mode_stack, +) from .base import VariableTracker from .constant import ConstantVariable -from .ctx_manager import ContextWrappingVariable +from .ctx_manager import GenericContextWrappingVariable from .lazy import LazyVariableTracker from .lists import TupleVariable from .tensor import TensorSubclassVariable, TensorVariable @@ -63,6 +76,39 @@ IGNORED_MODES = {DeviceContext} +@functools.lru_cache(None) +def get_prev_stack_var_name(): + from ..bytecode_transformation import unique_id + + return unique_id("___prev_torch_function_mode_stack") + + +# Used to clear/restore the python torch function mode stack and temporarily restore it as needed +class TorchFunctionModeStackStateManager: + def __init__(self): + self.stack = [] + + def __enter__(self): + self.stack = torch.overrides._get_current_function_mode_stack() + clear_torch_function_mode_stack() + + def __exit__(self, exc_type, exc_value, traceback): + set_torch_function_mode_stack(self.stack) + self.stack = [] + + @contextlib.contextmanager + def temp_restore_stack(self): + prev = torch.overrides._get_current_function_mode_stack() + set_torch_function_mode_stack(self.stack) + try: + yield + finally: + set_torch_function_mode_stack(prev) + + +torch_function_mode_stack_state_mgr = TorchFunctionModeStackStateManager() + + class SymbolicTorchFunctionState: def __init__(self, py_stack): # This is annoyingly complicated because of how the torch function subclass + mode C API was designed @@ -189,9 +235,26 @@ def get_mode_index(cls, ind): return ind + cls.offset -class TorchFunctionModeVariable(ContextWrappingVariable): +class TorchFunctionModeVariable(GenericContextWrappingVariable): + @staticmethod + def is_supported_torch_function_mode(ty): + # Supported in this sense means we can support graph breaks under the + # context. + # We are able to trace custom modes but if there are graph breaks under them + # and they have a custom __enter__/__exit__ we don't handle this for the + # same reason we don't handle generic context managers: there may be side effects + # that are now affected by executing the funtion across two frames instead of one + # Today we support the enter/exit of the default TorchFunctionMode as well as + # DeviceContext (which is used for set_default_device) + return issubclass(ty, (NoEnterTorchFunctionMode, DeviceContext)) or ( + not class_has_getattribute(ty) + and inspect.getattr_static(ty, "__enter__") == TorchFunctionMode.__enter__ + and inspect.getattr_static(ty, "__exit__") == TorchFunctionMode.__exit__ + ) + def __init__(self, value, source=None, **kwargs): - super().__init__(value, **kwargs) + if value is not None: + super().__init__(value, **kwargs) self.value = value self.cm_obj = value # needed for BC with calling enter from CM code self.source = source @@ -221,8 +284,39 @@ def call_torch_function(self, tx: "InstructionTranslator", fn, types, args, kwar kwargs, ) - def _call_func(self, tx: "InstructionTranslator", values): - unimplemented("enter/exit for torch function mode NYI") + def enter(self, tx): + from .torch import TorchInGraphFunctionVariable + + if isinstance(self.value, NoEnterTorchFunctionMode): + return ConstantVariable.create(None) + + TorchInGraphFunctionVariable( + torch._C._push_on_torch_function_stack + ).call_function(tx, [self], {}) + return ConstantVariable.create(None) + + def exit(self, tx: "InstructionTranslator", *args): + from .torch import TorchInGraphFunctionVariable + + TorchInGraphFunctionVariable(torch._C._pop_torch_function_stack).call_function( + tx, [], {} + ) + return ConstantVariable.create(None) + + def reconstruct_type(self, codegen): + ty = NoEnterTorchFunctionMode + codegen( + AttrSource( + codegen.tx.import_source(ty.__module__), + ty.__name__, + ) + ) + + def supports_graph_breaks(self): + return True + + def exit_on_graph_break(self): + return False def _get_all_args(args, kwargs): diff --git a/torch/_dynamo/variables/user_defined.py b/torch/_dynamo/variables/user_defined.py index ad92277cf70ff4..14499b4d2e4bf7 100644 --- a/torch/_dynamo/variables/user_defined.py +++ b/torch/_dynamo/variables/user_defined.py @@ -409,10 +409,22 @@ def call_function( and self.source and not is_forbidden_context_manager(self.value) ): + from torch.overrides import TorchFunctionMode + from .ctx_manager import GenericContextWrappingVariable + from .torch_function import TorchFunctionModeVariable + + if issubclass( + self.value, TorchFunctionMode + ) and TorchFunctionModeVariable.is_supported_torch_function_mode( + self.value + ): + var_cls = TorchFunctionModeVariable + else: + var_cls = GenericContextWrappingVariable cm_obj = tx.output.side_effects.track_object_new( - self.source, self.value, GenericContextWrappingVariable, {} + self.source, self.value, var_cls, {} ) cm_obj.call_method(tx, "__init__", args, kwargs) return cm_obj From 88fa33b156e14779677730498a88c50f1a43b127 Mon Sep 17 00:00:00 2001 From: Michael Lazos Date: Fri, 13 Sep 2024 10:40:55 -0700 Subject: [PATCH 0873/1018] [Dynamo] Remove ignored modes workaround (#135502) Pull Request resolved: https://github.com/pytorch/pytorch/pull/135502 Approved by: https://github.com/anijain2305 ghstack dependencies: #134732, #133137, #135443, #135444, #135422 --- test/dynamo/test_modes.py | 65 ----------------------- torch/_dynamo/guards.py | 12 ++--- torch/_dynamo/utils.py | 10 +--- torch/_dynamo/variables/torch_function.py | 6 --- 4 files changed, 5 insertions(+), 88 deletions(-) diff --git a/test/dynamo/test_modes.py b/test/dynamo/test_modes.py index 63e06a515ed186..4d1f2bbea389ef 100644 --- a/test/dynamo/test_modes.py +++ b/test/dynamo/test_modes.py @@ -1,5 +1,4 @@ # Owner(s): ["module: dynamo"] -from unittest.mock import patch import torch import torch._dynamo.test_case @@ -107,70 +106,6 @@ def fn(x): fn(inp) self.assertEqual(cnt.frame_count, 4) - def _run_ignored_mode_types_test(self): - class IgnoredMode(BaseTorchFunctionMode): - pass - - cnt = torch._dynamo.testing.CompileCounter() - - @torch.compile(backend=cnt.__call__, fullgraph=True) - def fn(x): - return x + 1 - - inp = torch.ones(2, 2) - - with patch( - "torch._dynamo.variables.torch_function.IGNORED_MODES", {IgnoredMode} - ): - # initial compile - fn(inp) - - # no recompile, mode ignored - # note: the ref stack is length 0, and the stack we are checking against has length 2 - # we want to check both ref stack len > runtime stack, and ref stack len < runtime stack - with IgnoredMode(), IgnoredMode(): - fn(inp) - - self.assertEqual(cnt.frame_count, 1) - - # recompile due to new mode on the stack - with BaseTorchFunctionMode(), BaseTorchFunctionMode(), BaseTorchFunctionMode(): - fn(inp) - - self.assertEqual(cnt.frame_count, 2) - - # recompile - # tests both ref stack len > runtime stack len for the above guard check - # and ref stack len < runtime stack len for the initial zero mode case - with BaseTorchFunctionMode(), IgnoredMode(), BaseTorchFunctionMode(): - fn(inp) - - self.assertEqual(cnt.frame_count, 3) - - # no recompile - with IgnoredMode(), IgnoredMode(), BaseTorchFunctionMode(), BaseTorchFunctionMode(): - fn(inp) - - self.assertEqual(cnt.frame_count, 3) - - # This is tricky, basically the ignored modes are baked into the guard - # IgnoredMode will be ignored forever by that guard. - # This is okay since we don't expect to be modifying IGNORED_MODES - # in the middle of execution except for the purposes of testing. - torch._dynamo.reset() - - with IgnoredMode(): - fn(inp) - - self.assertEqual(cnt.frame_count, 4) - - @torch._dynamo.config.patch("enable_cpp_guard_manager", False) - def test_torch_function_mode_guards_ignored_types_py(self): - self._run_ignored_mode_types_test() - - def test_torch_function_mode_guards_ignored_types_cpp(self): - self._run_ignored_mode_types_test() - @torch._dynamo.config.patch("enable_cpp_guard_manager", False) def test_torch_function_mode_guards_py(self): self._run_torch_function_mode_guard_test() diff --git a/torch/_dynamo/guards.py b/torch/_dynamo/guards.py index 1544bdd6dc697b..66deeb3cbc0e79 100644 --- a/torch/_dynamo/guards.py +++ b/torch/_dynamo/guards.py @@ -2344,15 +2344,13 @@ def compile_check_fn(self, builder, guards_out, guard_fail_fn): ) if config.enable_cpp_guard_manager: - from .variables.torch_function import IGNORED_MODES - # Insert the global_state guard assert self.guard_manager # to make mypy happy self.guard_manager.root.add_global_state_guard(["___check_global_state()"]) self.guard_manager.root.add_torch_function_mode_stack_guard( self.torch_function_mode_stack, - list(IGNORED_MODES), + list(), ["___check_torch_function_mode_stack()"], ) # Clear references to torch_function modes held in the list @@ -2659,18 +2657,14 @@ def is_recompiles_verbose_enabled(): # this will only be used if cpp guards are disabled def make_torch_function_mode_stack_guard(intial_stack): types = [type(x) for x in intial_stack] - from .variables.torch_function import IGNORED_MODES def check_torch_function_mode_stack(): cur_stack = get_torch_function_mode_stack() - types_ = [ty for ty in types if ty not in IGNORED_MODES] - cur_stack_ = [mode for mode in cur_stack if type(mode) not in IGNORED_MODES] - - if len(cur_stack_) != len(types_): + if len(cur_stack) != len(types): return False - for ty, mode in zip(types_, cur_stack_): + for ty, mode in zip(types, cur_stack): if ty != type(mode): return False diff --git a/torch/_dynamo/utils.py b/torch/_dynamo/utils.py index 8130b6aa7371b2..2d4a771166cfbe 100644 --- a/torch/_dynamo/utils.py +++ b/torch/_dynamo/utils.py @@ -3083,16 +3083,10 @@ def is_parameter_freezing(): return torch._inductor.config.freezing and not torch.is_grad_enabled() -def get_torch_function_mode_stack(filter_ignored=True): - from .variables.torch_function import IGNORED_MODES - - stack = [ +def get_torch_function_mode_stack(): + return [ get_torch_function_mode_stack_at(i) for i in range(_len_torch_function_stack()) ] - if filter_ignored: - stack = [mode for mode in stack if type(mode) not in IGNORED_MODES] - - return stack def get_torch_function_mode_stack_at(ind): diff --git a/torch/_dynamo/variables/torch_function.py b/torch/_dynamo/variables/torch_function.py index b7fd83a3d24f15..ffb3d27d4d703b 100644 --- a/torch/_dynamo/variables/torch_function.py +++ b/torch/_dynamo/variables/torch_function.py @@ -69,12 +69,6 @@ if is_tensor_base_attr_getter(fn) ] -# Today set default device is placed in the graph and guarded on separately -# so we should not trace through it. In the future we can trace it once -# mode tracing is implemented and not put in the graph, but this is more -# of a BE project and can be evaluated later -IGNORED_MODES = {DeviceContext} - @functools.lru_cache(None) def get_prev_stack_var_name(): From aeae1d449e03e2c91e429e1d588323a4982f8c7c Mon Sep 17 00:00:00 2001 From: Michael Lazos Date: Fri, 13 Sep 2024 10:40:56 -0700 Subject: [PATCH 0874/1018] [Dynamo] Remove ignored modes from torch function mode stack guard (#135503) Pull Request resolved: https://github.com/pytorch/pytorch/pull/135503 Approved by: https://github.com/anijain2305 ghstack dependencies: #134732, #133137, #135443, #135444, #135422, #135502 --- torch/_C/_dynamo/guards.pyi | 2 +- torch/_dynamo/guards.py | 1 - torch/csrc/dynamo/guards.cpp | 69 +++++------------------------------- 3 files changed, 10 insertions(+), 62 deletions(-) diff --git a/torch/_C/_dynamo/guards.pyi b/torch/_C/_dynamo/guards.pyi index d6aab0d547f727..918d913068e6df 100644 --- a/torch/_C/_dynamo/guards.pyi +++ b/torch/_C/_dynamo/guards.pyi @@ -67,7 +67,7 @@ class GuardManager: ) -> None: ... def add_global_state_guard(self, verbose_code_parts: list[str]) -> None: ... def add_torch_function_mode_stack_guard( - self, initial_stack, ignored_types, verbose_code_parts: list[str] + self, initial_stack, verbose_code_parts: list[str] ) -> None: ... class RootGuardManager(GuardManager): diff --git a/torch/_dynamo/guards.py b/torch/_dynamo/guards.py index 66deeb3cbc0e79..3dcc9a032f208f 100644 --- a/torch/_dynamo/guards.py +++ b/torch/_dynamo/guards.py @@ -2350,7 +2350,6 @@ def compile_check_fn(self, builder, guards_out, guard_fail_fn): self.guard_manager.root.add_torch_function_mode_stack_guard( self.torch_function_mode_stack, - list(), ["___check_torch_function_mode_stack()"], ) # Clear references to torch_function modes held in the list diff --git a/torch/csrc/dynamo/guards.cpp b/torch/csrc/dynamo/guards.cpp index 771de44427b669..1d1cf2fb0c60a6 100644 --- a/torch/csrc/dynamo/guards.cpp +++ b/torch/csrc/dynamo/guards.cpp @@ -2515,90 +2515,40 @@ class TORCH_FUNCTION_MODE_STACK : public LeafGuard { public: TORCH_FUNCTION_MODE_STACK( const py::list& initial_stack, - const py::list& ignored_types, py::object verbose_code_parts) - : LeafGuard(std::move(verbose_code_parts)), - _ref_stack(), - _ignored_types() { + : LeafGuard(std::move(verbose_code_parts)), _ref_stack() { Py_ssize_t len = PyList_Size(initial_stack.ptr()); for (Py_ssize_t idx = 0; idx < len; idx++) { PyObject* mode = PyList_GetItem(initial_stack.ptr(), idx); // borrowed ref auto type = Py_TYPE(mode); this->_ref_stack.push_back(type); } - - len = PyList_Size(ignored_types.ptr()); - for (Py_ssize_t idx = 0; idx < len; idx++) { - PyObject* type_obj = - PyList_GetItem(ignored_types.ptr(), idx); // borrowed ref - if (PyType_Check(type_obj) == 0) { - PyErr_SetString( - PyExc_TypeError, "ignored_types should contain a list of types"); - return; - } - PyTypeObject* type = (PyTypeObject*)type_obj; - this->_ignored_types.insert(type); - } } bool check_nopybind(PyObject* value) override { // Ignore value arg, only used to satisfy the interface - size_t ref_ind = 0; - const int64_t len = at::impl::PythonTorchFunctionTLS::stack_len(); + const size_t len = (size_t)at::impl::PythonTorchFunctionTLS::stack_len(); const size_t ref_stack_size = this->_ref_stack.size(); - int64_t idx = 0; - while ((idx < len) && (ref_ind < ref_stack_size)) { - std::shared_ptr mode = - at::impl::PythonTorchFunctionTLS::get_stack_at(idx); - - PyTypeObject* mode_type = Py_TYPE(mode->ptr(getPyInterpreter())); - bool act_ignored = this->_ignored_types.count(mode_type) > 0; - bool ref_ignored = - this->_ignored_types.count(this->_ref_stack.at(ref_ind)) > 0; - // skip ignored types - if (act_ignored && ref_ignored) { - idx++; - ref_ind++; - continue; - } else if (ref_ignored) { - ref_ind++; - continue; - } else if (act_ignored) { - idx++; - continue; - } - // if we already have more non-ignored modes than the ref stack - // or if the mode doesn't match at the current index, return false - else if (mode_type != _ref_stack.at(ref_ind)) { - return false; - } - ref_ind++; - idx++; - } - - for (; ref_ind < ref_stack_size; ref_ind++) { - if (!(this->_ignored_types.count(this->_ref_stack.at(ref_ind)) > 0)) { - return false; - } + if (len != ref_stack_size) { + return false; } - for (; idx < len; idx++) { + for (int64_t idx = 0; (size_t)idx < len; idx++) { std::shared_ptr mode = at::impl::PythonTorchFunctionTLS::get_stack_at(idx); PyTypeObject* mode_type = Py_TYPE(mode->ptr(getPyInterpreter())); - if (!(this->_ignored_types.count(mode_type) > 0)) { + if (mode_type != _ref_stack.at(idx)) { return false; } } - return ref_ind == ref_stack_size && idx == len; + return true; } private: std::vector _ref_stack; - std::set _ignored_types; }; class TENSOR_MATCH : public LeafGuard { @@ -3763,7 +3713,7 @@ PyObject* torch_c_dynamo_guards_init() { LeafGuard, std::shared_ptr>( py_m, "TORCH_FUNCTION_MODE_STACK") - .def(py::init()) + .def(py::init()) .def("__call__", &TORCH_FUNCTION_MODE_STACK::check); py::class_>( py_m, "DATA_PTR_MATCH") @@ -4000,10 +3950,9 @@ PyObject* torch_c_dynamo_guards_init() { "add_torch_function_mode_stack_guard", [](GuardManager& self, const py::list& initial_stack, - const py::list& ignored_types, py::object verbose_code_parts) -> void { self.add_leaf_guard(std::make_shared( - initial_stack, ignored_types, std::move(verbose_code_parts))); + initial_stack, std::move(verbose_code_parts))); }) .def( "add_data_ptr_guard", From 274a5ae212cb0c139e3a15b8f136ea8fd457fd43 Mon Sep 17 00:00:00 2001 From: fduwjj Date: Fri, 13 Sep 2024 09:49:05 -0700 Subject: [PATCH 0875/1018] [TCPStore] Remove deprecated constructor (#136004) While looking at TCPStore code again and found it confusing that we still keep the deprecated constructor for TCPStore in cpp while we don't expose it in python via pybind already. I checked both internal and external, all use cases in cpp (aside from unit test fixed in this PR) already moved to using option. So let's remove this legacy constructor to avoid confusion. Differential Revision: [D62653634](https://our.internmc.facebook.com/intern/diff/D62653634) Pull Request resolved: https://github.com/pytorch/pytorch/pull/136004 Approved by: https://github.com/Skylion007, https://github.com/XilunWu --- test/cpp/c10d/TCPStoreTest.cpp | 11 ++++++----- torch/csrc/distributed/c10d/TCPStore.cpp | 19 +++---------------- torch/csrc/distributed/c10d/TCPStore.hpp | 8 -------- 3 files changed, 9 insertions(+), 29 deletions(-) diff --git a/test/cpp/c10d/TCPStoreTest.cpp b/test/cpp/c10d/TCPStoreTest.cpp index c0a7b2b64d601e..7351984f36c997 100644 --- a/test/cpp/c10d/TCPStoreTest.cpp +++ b/test/cpp/c10d/TCPStoreTest.cpp @@ -178,11 +178,12 @@ TEST(TCPStoreTest, testCleanShutdown) { auto serverTCPStore = std::make_unique( "127.0.0.1", - 0, - numWorkers, - true, - std::chrono::seconds(defaultTimeout), - /* wait */ false); + c10d::TCPStoreOptions{ + /* port */ 0, + /* isServer */ true, + numWorkers, + /* waitWorkers */ false, + /* timeout */ std::chrono::seconds(defaultTimeout)}); c10d::test::set(*serverTCPStore, "key", "val"); auto clientTCPStore = c10::make_intrusive( diff --git a/torch/csrc/distributed/c10d/TCPStore.cpp b/torch/csrc/distributed/c10d/TCPStore.cpp index 2855ced299201c..68c5da982c2573 100644 --- a/torch/csrc/distributed/c10d/TCPStore.cpp +++ b/torch/csrc/distributed/c10d/TCPStore.cpp @@ -249,23 +249,10 @@ class SendBuffer { using detail::Socket; // TCPStore class methods -TCPStore::TCPStore( - const std::string& masterAddr, - std::uint16_t masterPort, - std::optional numWorkers, - bool isServer, - const std::chrono::milliseconds& timeout, - bool waitWorkers) - : TCPStore{ - masterAddr, - TCPStoreOptions{ - masterPort, - isServer, - numWorkers ? std::optional(*numWorkers) - : std::nullopt, - waitWorkers, - timeout}} {} +// Although we still allow multi-params in ctor in Python, that behavior is +// removed from cpp and we construct the opts implicitly for users in the pybind +// of TCPStore. TCPStore::TCPStore(std::string host, const TCPStoreOptions& opts) : Store{opts.timeout}, addr_{std::move(host)}, diff --git a/torch/csrc/distributed/c10d/TCPStore.hpp b/torch/csrc/distributed/c10d/TCPStore.hpp index b2839f62c7a3a1..37862fc25fac33 100644 --- a/torch/csrc/distributed/c10d/TCPStore.hpp +++ b/torch/csrc/distributed/c10d/TCPStore.hpp @@ -75,14 +75,6 @@ class TORCH_API TCPStore : public Store { explicit TCPStore(std::string host, const TCPStoreOptions& opts = {}); - [[deprecated("Use TCPStore(host, opts) instead.")]] explicit TCPStore( - const std::string& masterAddr, - std::uint16_t masterPort, - std::optional numWorkers = std::nullopt, - bool isServer = false, - const std::chrono::milliseconds& timeout = kDefaultTimeout, - bool waitWorkers = true); - ~TCPStore() override; void set(const std::string& key, const std::vector& value) override; From af3506af91ed00d5ce6ab2501e3bfac57f84ab35 Mon Sep 17 00:00:00 2001 From: Jack Taylor Date: Sat, 14 Sep 2024 05:40:10 +0000 Subject: [PATCH 0876/1018] [ROCm] Skip pointwise associative scan tests due to regression (#135995) https://github.com/pytorch/pytorch/pull/133012 caused a regression on ROCm causing pointwise scan tests to fail ``` ERROR: test_pointwise_associative_scan_tuple_reverse_True_combine_mode_pointwise_cuda ERROR: test_pointwise_associative_scan_tuple_reverse_False_combine_mode_pointwise_cuda ERROR: test_pointwise_associative_scan_complex_pytree_reverse_True_combine_mode_pointwise_cuda ERROR: test_pointwise_associative_scan_complex_pytree_reverse_False_combine_mode_pointwise_cuda ERROR: test_pointwise_associative_scan_binary_operator_reverse_True_combine_mode_pointwise_cuda ERROR: test_pointwise_associative_scan_binary_operator_reverse_False_combine_mode_pointwise_cuda ``` Skipping temporarily while triage is underway. Full log: https://ossci-raw-job-status.s3.amazonaws.com/log/30067645445 ``` File "/opt/conda/envs/py_3.8/lib/python3.8/site-packages/torch/_inductor/graph.py", line 1020, in call_function out = lowerings[target](*args, **kwargs) # type: ignore[index] File "/opt/conda/envs/py_3.8/lib/python3.8/site-packages/torch/_inductor/lowering.py", line 363, in wrapped out = decomp_fn(*args, **kwargs) File "/opt/conda/envs/py_3.8/lib/python3.8/site-packages/torch/_inductor/lowering.py", line 6245, in associative_scan raise RuntimeError("Unable to generate code for associative_scan op") torch._inductor.exc.LoweringException: RuntimeError: Unable to generate code for associative_scan op ``` NOTE: even "eager" backend fails ``` File "/opt/conda/envs/py_3.8/lib/python3.8/site-packages/torch/_higher_order_ops/associative_scan.py", line 338, in associative_scan_op_dense raise NotImplementedError("associative_scan is not implemented for eager") NotImplementedError: associative_scan is not implemented for eager ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/135995 Approved by: https://github.com/malfet --- test/functorch/test_control_flow.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/test/functorch/test_control_flow.py b/test/functorch/test_control_flow.py index c1fb51af8e2f90..cf639b6abcce4e 100644 --- a/test/functorch/test_control_flow.py +++ b/test/functorch/test_control_flow.py @@ -1334,7 +1334,7 @@ def fwbw(map_op, f, x, y): unittest.skip, lambda params: ( params["combine_mode"] == "pointwise" - and params["device"] == torch.device("cpu") + and (params["device"] == torch.device("cpu") or torch.version.hip) ), ) def test_associative_scan_compile( @@ -1557,7 +1557,7 @@ def test_scan_dtype(self, reverse, compile_mode, device, dtype): unittest.skip, lambda params: ( params["combine_mode"] == "pointwise" - and params["device"] == torch.device("cpu") + and (params["device"] == torch.device("cpu") or torch.version.hip) ), ) def test_associative_scan_dim(self, combine_mode, reverse, device): @@ -1631,7 +1631,7 @@ def test_scan_dim(self, reverse, device): unittest.skip, lambda params: ( params["combine_mode"] == "pointwise" - and params["device"] == torch.device("cpu") + and (params["device"] == torch.device("cpu") or torch.version.hip) ), ) def test_associative_scan_binary_operator(self, combine_mode, reverse, device): @@ -1707,7 +1707,7 @@ def test_scan_binary_operator(self, reverse, device): unittest.skip, lambda params: ( params["combine_mode"] == "pointwise" - and params["device"] == torch.device("cpu") + and (params["device"] == torch.device("cpu") or torch.version.hip) ), ) def test_associative_scan_tuple(self, combine_mode, reverse, device): @@ -1804,7 +1804,7 @@ def fct_wrong_pytree(x, y): unittest.skip, lambda params: ( params["combine_mode"] == "pointwise" - and params["device"] == torch.device("cpu") + and (params["device"] == torch.device("cpu") or torch.version.hip) ), ) def test_associative_scan_complex_pytree(self, combine_mode, reverse, device): @@ -1912,7 +1912,7 @@ def test_scan_complex_pytree(self, reverse, device): unittest.skip, lambda params: ( params["combine_mode"] == "pointwise" - and params["device"] == torch.device("cpu") + and (params["device"] == torch.device("cpu") or torch.version.hip) ), ) def test_associative_scan_downstream_scan_matmul( @@ -1952,7 +1952,7 @@ def chain_fct(inp): unittest.skip, lambda params: ( params["combine_mode"] == "pointwise" - and params["device"] == torch.device("cpu") + and (params["device"] == torch.device("cpu") or torch.version.hip) ), ) def test_associative_scan_downstream_scan_scan( @@ -2004,7 +2004,7 @@ def chain_fct_same_dim(inp): unittest.skip, lambda params: ( params["combine_mode"] == "pointwise" - and params["device"] == torch.device("cpu") + and (params["device"] == torch.device("cpu") or torch.version.hip) ), ) def test_associative_scan_downstream_scan_scan_different_dim( From b634f52245f4d0ad1773f72fa22e30fd5f78dc67 Mon Sep 17 00:00:00 2001 From: Nikita Shulga Date: Sat, 14 Sep 2024 06:09:36 +0000 Subject: [PATCH 0877/1018] [MPS] Add native im2col (#135706) It's called from `torch.unfold` and one of the few remaining vestiges in `MPSFallback.mm` Strongly inspired by CUDA implementation from https://github.com/pytorch/pytorch/blob/09519eb1959978af7f2f909acf768485772596ab/aten/src/ATen/native/cuda/im2col.cuh#L40-L61 Pull Request resolved: https://github.com/pytorch/pytorch/pull/135706 Approved by: https://github.com/albanD --- aten/src/ATen/mps/MPSFallback.mm | 1 - aten/src/ATen/native/mps/OperationUtils.mm | 4 + aten/src/ATen/native/mps/operations/Im2Col.mm | 198 ++++++++++++++++++ aten/src/ATen/native/native_functions.yaml | 2 + 4 files changed, 204 insertions(+), 1 deletion(-) create mode 100644 aten/src/ATen/native/mps/operations/Im2Col.mm diff --git a/aten/src/ATen/mps/MPSFallback.mm b/aten/src/ATen/mps/MPSFallback.mm index 26e4452eb7b3bd..1aafdbfd37a05f 100644 --- a/aten/src/ATen/mps/MPSFallback.mm +++ b/aten/src/ATen/mps/MPSFallback.mm @@ -88,7 +88,6 @@ static Tensor slow_conv2d_forward_mps(const Tensor& self, m.impl("embedding_renorm_", torch::CppFunction::makeFromBoxedFunction<&mps_fallback>()); m.impl("linalg_svd", torch::CppFunction::makeFromBoxedFunction<&mps_fallback>()); m.impl("linalg_svd.U", torch::CppFunction::makeFromBoxedFunction<&mps_fallback>()); - m.impl("im2col", torch::CppFunction::makeFromBoxedFunction<&mps_fallback>()); // Used in preprocessing by nn.Unfold m.impl("col2im", torch::CppFunction::makeFromBoxedFunction<&mps_fallback>()); m.impl("_slow_conv2d_forward", slow_conv2d_forward_mps); m.impl("upsample_nearest3d.vec", torch::CppFunction::makeFromBoxedFunction<&mps_fallback>()); diff --git a/aten/src/ATen/native/mps/OperationUtils.mm b/aten/src/ATen/native/mps/OperationUtils.mm index d588e03afbb4f7..db469733fcf3b9 100644 --- a/aten/src/ATen/native/mps/OperationUtils.mm +++ b/aten/src/ATen/native/mps/OperationUtils.mm @@ -227,6 +227,10 @@ MPSDataType getMPSScalarType(ScalarType scalar_type) { return "uchar"; case ScalarType::Bool: return "bool"; + case ScalarType::ComplexHalf: + return "half2"; + case ScalarType::ComplexFloat: + return "float2"; default: TORCH_CHECK(false, "Undefined type ", scalar_type); return "Undefined"; diff --git a/aten/src/ATen/native/mps/operations/Im2Col.mm b/aten/src/ATen/native/mps/operations/Im2Col.mm new file mode 100644 index 00000000000000..8cba89490f5683 --- /dev/null +++ b/aten/src/ATen/native/mps/operations/Im2Col.mm @@ -0,0 +1,198 @@ +#define TORCH_ASSERT_ONLY_METHOD_OPERATORS +#include +#include + +#ifndef AT_PER_OPERATOR_HEADERS +#include +#include +#else +#include +#include +#include +#endif + +namespace at::native { +using namespace mps; +static MetalShaderLibrary lib(R"IM2COL_METAL( +// Heavily inspired by https://github.com/pytorch/pytorch/blob/09519eb19/aten/src/ATen/native/cuda/im2col.cuh#L51 +template +void im2col_kernel( + constant T * input, + device T * output, + uint2 kernel_size, + long2 input_offset, + long2 input_size, + long2 dilation, + ulong2 input_strides, + ulong output_stride) { + for (ulong i = 0; i < kernel_size.y; ++i) { + for (ulong j = 0; j < kernel_size.x; ++j) { + auto input_pos = input_offset + long2(j, i) * dilation; + if (input_pos.x < 0 || input_pos.y < 0 || input_pos.x >= input_size.x || input_pos.y >= input_size.y) { + *output = T(0); + } else { + auto offset = input_pos.x * input_strides.x + input_pos.y * input_strides.y; + *output = input[offset]; + } + output += output_stride; + } + } +} + +template +kernel void im2col( + constant T * inputData [[buffer(0)]], + device T * outputData [[buffer(1)]], + constant uint4 & kernel_dilation [[buffer(2)]], + constant int4 & padding_stride [[buffer(3)]], + constant ulong4 & input_strides [[buffer(4)]], + constant ulong4 & output_strides [[buffer(5)]], + constant long4 & input_sizes [[buffer(6)]], + uint3 thread_index [[thread_position_in_grid]]) { + // thread_index is (output_length, input_channels, input_batch) + const auto N = thread_index.z; + const auto C = thread_index.y; + const auto L = thread_index.x; + const auto output_width = output_strides.w; + const auto o_x = L % output_width; + const auto o_y = L / output_width; + auto i_x = o_x * padding_stride.z - padding_stride.x; + auto i_y = o_y * padding_stride.w - padding_stride.y; + ulong kernel_size = kernel_dilation.x * kernel_dilation.y; + outputData += N * output_strides.z + C * kernel_size * output_strides.y + L * output_strides.x; + inputData += N * input_strides.w + C * input_strides.z; + im2col_kernel(inputData, outputData, kernel_dilation.xy, long2(i_x, i_y), input_sizes.xy, long2(kernel_dilation.zw), input_strides.xy, output_strides.y); +} + +#define INSTANTIATE_IM2COL(DTYPE) \ +template \ +[[host_name("im2col_" #DTYPE)]] \ +kernel void im2col( \ + constant DTYPE * inputData [[buffer(0)]], \ + device DTYPE * outputData [[buffer(1)]], \ + constant uint4 & kernel_dilation [[buffer(2)]], \ + constant int4 & padding_stride [[buffer(3)]], \ + constant ulong4 & input_strides [[buffer(4)]], \ + constant ulong4 & output_strides [[buffer(5)]], \ + constant long4 & input_sizes [[buffer(6)]], \ + uint3 thread_index [[thread_position_in_grid]]) + +INSTANTIATE_IM2COL(bool); +INSTANTIATE_IM2COL(float); +INSTANTIATE_IM2COL(float2); +INSTANTIATE_IM2COL(half); +INSTANTIATE_IM2COL(half2); +#if __METAL_VERSION__ >= 310 +INSTANTIATE_IM2COL(bfloat); +#endif +)IM2COL_METAL"); + +namespace { +static void im2col_out_mps_template(Tensor& output, + const Tensor& input_, + IntArrayRef kernel_size, + IntArrayRef dilation, + IntArrayRef padding, + IntArrayRef stride) { + TORCH_CHECK(kernel_size.size() == 2, "It is expected kernel_size equals to 2, but got size ", kernel_size.size()); + + TORCH_CHECK(dilation.size() == 2, "It is expected dilation equals to 2, but got size ", dilation.size()); + + TORCH_CHECK(padding.size() == 2, "It is expected padding equals to 2, but got size ", padding.size()); + + TORCH_CHECK(stride.size() == 2, "It is expected stride equals to 2, but got size ", stride.size()); + + const auto kernel_height = kernel_size[0]; + int64_t kernel_width = kernel_size[1]; + int64_t dilation_height = dilation[0]; + int64_t dilation_width = dilation[1]; + int64_t pad_height = padding[0]; + int64_t pad_width = padding[1]; + int64_t stride_height = stride[0]; + int64_t stride_width = stride[1]; + + Tensor input = input_.contiguous(); + + bool batched_input = true; + + if (input.dim() == 3) { + batched_input = false; + input = input.unsqueeze(0); + } + + int64_t batch_size = input.size(0); + int64_t n_input_plane = input.size(1); + int64_t input_height = input.size(2); + int64_t input_width = input.size(3); + + int64_t output_height = + (input_height + 2 * pad_height - (dilation_height * (kernel_height - 1) + 1)) / stride_height + 1; + int64_t output_width = (input_width + 2 * pad_width - (dilation_width * (kernel_width - 1) + 1)) / stride_width + 1; + int64_t n_output_plane = n_input_plane * kernel_width * kernel_height; + int64_t output_length = output_height * output_width; + + output.resize_({batch_size, n_output_plane, output_length}); + auto stream = getCurrentMPSStream(); + auto device = MPSDevice::getInstance()->device(); + auto im2colPSO = lib.getPipelineStateForFunc("im2col_" + mps::scalarToMetalTypeString(input)); + dispatch_sync_with_rethrow(stream->queue(), ^() { + @autoreleasepool { + std::array kernel_dilation = {static_cast(kernel_width), + static_cast(kernel_height), + static_cast(dilation_width), + static_cast(dilation_height)}; + std::array padding_stride = {static_cast(pad_width), + static_cast(pad_height), + static_cast(stride_width), + static_cast(stride_height)}; + std::array input_sizes = {input_width, input_height, n_input_plane, batch_size}; + std::array input_strides = {input.stride(3), input.stride(2), input.stride(1), input.stride(0)}; + std::array output_strides = {output.stride(2), output.stride(1), output.stride(0), output_width}; + getMPSProfiler().beginProfileKernel(im2colPSO, "im2col", {input, output}); + + if (getMPSProfiler().isCaptureEnabled()) { + getMPSProfiler().startCapture("im2col", stream); + } + auto computeEncoder = stream->commandEncoder(); + [computeEncoder setComputePipelineState:im2colPSO]; + mtl_setBuffer(computeEncoder, input, 0); + mtl_setBuffer(computeEncoder, output, 1); + mtl_setBytes(computeEncoder, kernel_dilation, 2); + mtl_setBytes(computeEncoder, padding_stride, 3); + mtl_setBytes(computeEncoder, input_strides, 4); + mtl_setBytes(computeEncoder, output_strides, 5); + mtl_setBytes(computeEncoder, input_sizes, 6); + [computeEncoder dispatchThreads:MTLSizeMake(output_length, n_input_plane, batch_size) + threadsPerThreadgroup:MTLSizeMake(64, 1, 1)]; + if (getMPSProfiler().isCapturing()) { + getMPSProfiler().stopCapture(stream); + } + getMPSProfiler().endProfileKernel(im2colPSO); + } + }); + if (!batched_input) { + output = output.squeeze(0); + } +} + +} // anonymous namespace +Tensor& im2col_out_mps(const Tensor& input, + IntArrayRef kernel_size, + IntArrayRef dilation, + IntArrayRef padding, + IntArrayRef stride, + Tensor& output) { + im2col_out_mps_template(output, input, kernel_size, dilation, padding, stride); + return output; +} + +Tensor im2col_mps(const Tensor& input, + IntArrayRef kernel_size, + IntArrayRef dilation, + IntArrayRef padding, + IntArrayRef stride) { + Tensor output = at::empty_like(input, LEGACY_CONTIGUOUS_MEMORY_FORMAT); + im2col_out_mps_template(output, input, kernel_size, dilation, padding, stride); + return output; +} +} // namespace at::native diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index 8a68423495d607..e05aa43b17417f 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -13043,12 +13043,14 @@ dispatch: CPU: im2col_out_cpu CUDA: im2col_out_cuda + MPS: im2col_out_mps - func: im2col(Tensor self, int[2] kernel_size, int[2] dilation, int[2] padding, int[2] stride) -> Tensor python_module: nn dispatch: CPU: im2col_cpu CUDA: im2col_cuda + MPS: im2col_mps - func: isfinite(Tensor self) -> Tensor variants: function, method From 90422d9110f3da2e47a5423ddc9c2bac6c2d7ada Mon Sep 17 00:00:00 2001 From: Laith Sakka Date: Fri, 13 Sep 2024 16:34:34 -0700 Subject: [PATCH 0878/1018] Disable garbage collection during compile_time_instructions count in benchmark base by default. (#135768) When we measure compile time instruction count, probably we do want in most cases to measure gc instructions disabling it here by default. if it is needed we can add an option to allow it, or someone can use the regular total instruction count instead of compile time instruction count. Pull Request resolved: https://github.com/pytorch/pytorch/pull/135768 Approved by: https://github.com/ezyang, https://github.com/anijain2305 --- .../pr_time_benchmarks/benchmark_base.py | 47 +++++++++++-------- .../pr_time_benchmarks/benchmarks/add_loop.py | 3 -- .../benchmarks/update_hint_benchmark.py | 3 -- 3 files changed, 27 insertions(+), 26 deletions(-) diff --git a/benchmarks/dynamo/pr_time_benchmarks/benchmark_base.py b/benchmarks/dynamo/pr_time_benchmarks/benchmark_base.py index d279922ea1aa69..0c0d0e7ecfa7b4 100644 --- a/benchmarks/dynamo/pr_time_benchmarks/benchmark_base.py +++ b/benchmarks/dynamo/pr_time_benchmarks/benchmark_base.py @@ -1,4 +1,5 @@ import csv +import gc from abc import ABC, abstractmethod from fbscribelogger import make_scribe_logger @@ -105,26 +106,32 @@ def _count_instructions(self): return min(results) def _count_compile_time_instructions(self): - print(f"collecting compile time instruction count for {self.name()}") - config.record_compile_time_instruction_count = True - - results = [] - for i in range(self._num_iterations): - self._prepare() - # CompileTimeInstructionCounter.record is only called on convert_frame._compile_inner - # hence this will only count instruction count spent in compile_inner. - CompileTimeInstructionCounter.clear() - self._work() - count = CompileTimeInstructionCounter.value() - if count == 0: - raise RuntimeError( - "compile time instruction count is 0, please check your benchmarks" - ) - print(f"compile time instruction count for iteration {i} is {count}") - results.append(count) - - config.record_compile_time_instruction_count = False - return min(results) + gc.disable() + + try: + print(f"collecting compile time instruction count for {self.name()}") + config.record_compile_time_instruction_count = True + + results = [] + for i in range(self._num_iterations): + self._prepare() + gc.collect() + # CompileTimeInstructionCounter.record is only called on convert_frame._compile_inner + # hence this will only count instruction count spent in compile_inner. + CompileTimeInstructionCounter.clear() + self._work() + count = CompileTimeInstructionCounter.value() + if count == 0: + raise RuntimeError( + "compile time instruction count is 0, please check your benchmarks" + ) + print(f"compile time instruction count for iteration {i} is {count}") + results.append(count) + + config.record_compile_time_instruction_count = False + return min(results) + finally: + gc.enable() def append_results(self, path): with open(path, "a", newline="") as csvfile: diff --git a/benchmarks/dynamo/pr_time_benchmarks/benchmarks/add_loop.py b/benchmarks/dynamo/pr_time_benchmarks/benchmarks/add_loop.py index f56d12a6d1e410..f28d59f154ea28 100644 --- a/benchmarks/dynamo/pr_time_benchmarks/benchmarks/add_loop.py +++ b/benchmarks/dynamo/pr_time_benchmarks/benchmarks/add_loop.py @@ -1,4 +1,3 @@ -import gc import sys from benchmark_base import BenchmarkBase @@ -30,8 +29,6 @@ def _prepare_once(self): def _prepare(self): torch._dynamo.reset() - gc.collect() - gc.disable() def _work(self): @torch.compile(backend=self._backend, fullgraph=True, dynamic=self._dynamic) diff --git a/benchmarks/dynamo/pr_time_benchmarks/benchmarks/update_hint_benchmark.py b/benchmarks/dynamo/pr_time_benchmarks/benchmarks/update_hint_benchmark.py index d7dfa0615d25b7..7957836b6a9d1d 100644 --- a/benchmarks/dynamo/pr_time_benchmarks/benchmarks/update_hint_benchmark.py +++ b/benchmarks/dynamo/pr_time_benchmarks/benchmarks/update_hint_benchmark.py @@ -1,4 +1,3 @@ -import gc import sys from benchmark_base import BenchmarkBase @@ -25,8 +24,6 @@ def _prepare_once(self): def _prepare(self): torch._dynamo.reset() - gc.collect() - gc.disable() def _work(self): @torch.compile(fullgraph=True) From 3e6c90486f9f731fcbe9be62bf191d30fffde8f2 Mon Sep 17 00:00:00 2001 From: Nikita Shulga <2453524+malfet@users.noreply.github.com> Date: Sat, 14 Sep 2024 06:16:10 +0000 Subject: [PATCH 0879/1018] [CI] Increase open file handles limit to 16K on MacOS (#136061) May be it will help with flaky failures tracked in https://github.com/pytorch/pytorch/issues/135885 Pull Request resolved: https://github.com/pytorch/pytorch/pull/136061 Approved by: https://github.com/clee2000, https://github.com/kit1980, https://github.com/huydhn, https://github.com/ZainRizvi --- .ci/pytorch/macos-test.sh | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/.ci/pytorch/macos-test.sh b/.ci/pytorch/macos-test.sh index 5fa57485a0a7a1..47707afdc2b14a 100755 --- a/.ci/pytorch/macos-test.sh +++ b/.ci/pytorch/macos-test.sh @@ -25,8 +25,9 @@ setup_test_python() { echo "Ninja version: $(ninja --version)" echo "Python version: $(which python) ($(python --version))" - # Increase default limit on open file handles from 256 to 1024 - ulimit -n 1024 + # Set the limit on open file handles to 16384 + # might help with intermittent compiler test failures + ulimit -n 16384 } test_python_all() { From 293eb3727e046d949f0a977b99ff29840c85ea95 Mon Sep 17 00:00:00 2001 From: Will Feng Date: Fri, 13 Sep 2024 13:59:17 -0700 Subject: [PATCH 0880/1018] [Traceable FSDP2] Don't register RegisterPostBackwardFunction if user intends to use Traceable FSDP2, and assert that compiled autograd is not used when entering RegisterPostBackwardFunction (#135824) During enablement of Traceable FSDP2 on internal models, sometimes the user only applies torch.compile to some of the FSDP2 instances but not all of them. Such mixed usage pattern is not supported by compiled autograd. Here we try to catch and throw error at such usage pattern, so that the user can fix the usage. Pull Request resolved: https://github.com/pytorch/pytorch/pull/135824 Approved by: https://github.com/awgu --- .../_composable/fsdp/_fsdp_param_group.py | 18 ++++++++++++++++-- 1 file changed, 16 insertions(+), 2 deletions(-) diff --git a/torch/distributed/_composable/fsdp/_fsdp_param_group.py b/torch/distributed/_composable/fsdp/_fsdp_param_group.py index 5f22e2d34b3229..e90863479a8d1e 100644 --- a/torch/distributed/_composable/fsdp/_fsdp_param_group.py +++ b/torch/distributed/_composable/fsdp/_fsdp_param_group.py @@ -478,9 +478,9 @@ def use_training_state(self, training_state: TrainingState): def _register_post_backward_hook( self, args: Tuple[Any, ...], kwargs: Dict[str, Any] ) -> Tuple[Tuple[Any, ...], Dict[str, Any]]: - # Compile relies on `root_post_backward_callback` to call each + # Traceable FSDP2 relies on `root_post_backward_callback` to call each # `FSDPParamGroup.post_backward` - if ca.compiled_autograd_enabled: + if (not torch._dynamo.config.skip_fsdp_hooks) or ca.compiled_autograd_enabled: return args, kwargs if not torch.is_grad_enabled(): return args, kwargs @@ -634,13 +634,27 @@ def _get_param_module_infos( class RegisterPostBackwardFunction(torch.autograd.Function): + @staticmethod + def _assert_not_tracing_fsdp(): + if ca.compiled_autograd_enabled: + # TODO: Find a way to print the offending FSDP2 module. + msg = """\ +When Traceable FSDP2 is enabled, we rely on `root_post_backward_callback` to call +each `FSDPParamGroup.post_backward`, and we should not be calling into `RegisterPostBackwardFunction`. +If you are here, it means the forward part of this FSDP2 instance is not compiled, and you must also +compile the forward part if you want to use Traceable FSDP2.""" + torch._dynamo.comptime.comptime.print(msg) + raise RuntimeError(msg) + @staticmethod def forward(ctx, param_group: FSDPParamGroup, *inputs: torch.Tensor): # All tensors in `inputs` should require gradient + RegisterPostBackwardFunction._assert_not_tracing_fsdp() ctx.param_group = param_group return inputs @staticmethod def backward(ctx, *grads: torch.Tensor): + RegisterPostBackwardFunction._assert_not_tracing_fsdp() ctx.param_group.post_backward() return (None,) + grads From 7b52775f4b7dd4886d2fcc2ce3d916b87500e04a Mon Sep 17 00:00:00 2001 From: Oguz Ulgen Date: Fri, 13 Sep 2024 14:08:17 -0700 Subject: [PATCH 0881/1018] Add higher order operator name to the cache bypass exception (#135876) Pull Request resolved: https://github.com/pytorch/pytorch/pull/135876 Approved by: https://github.com/jamesjwu, https://github.com/zou3519 --- torch/_inductor/codecache.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/torch/_inductor/codecache.py b/torch/_inductor/codecache.py index 26a35a2d801c93..dd1be42e4c7443 100644 --- a/torch/_inductor/codecache.py +++ b/torch/_inductor/codecache.py @@ -538,7 +538,7 @@ def _reduce_tensor( # TODO: These tensors don't currently pickle, so we can't cache a # compiled graph containing them. Just fail now. If mkldnn tensors # get pickling support, we can remove this. - raise BypassFxGraphCache("mkldnn tensors unpickleable.") + raise BypassFxGraphCache("mkldnn tensors unpickleable") # Very large tensors could be expensive to copy to cpu and hash. Let's # at least report if we find slowness. @@ -569,7 +569,7 @@ def _reduce_unsupported(s: Any) -> NoReturn: See FxGraphCachePickler. Custom reducer to handle any objects that we don't support and therefore raise to bypass caching. """ - raise BypassFxGraphCache("Reduce unsupported.") + raise BypassFxGraphCache("Reduce unsupported") class FxGraphCachePickler(pickle.Pickler): @@ -607,7 +607,7 @@ def dumps(cls, obj: Any) -> bytes: # Some configs options are callables, e.g., post_grad_custom_pre_pass, # and may not pickle. log.warning("Can't pickle", exc_info=True) - raise BypassFxGraphCache("Config options may be unpickleable.") from e + raise BypassFxGraphCache("Config options may be unpickleable") from e return stream.getvalue() @classmethod @@ -1259,25 +1259,27 @@ def _check_can_cache(gm: torch.fx.GraphModule) -> None: # Freezing can embed constants that wouldn't be static across runs. if config.freezing or config.aot_inductor.use_runtime_constant_folding: raise BypassFxGraphCache( - "Freezing may introduce constants that aren't static across runs." + "Freezing may introduce constants that aren't static across runs" ) # The treatment of guards in the caching implementation requires that # we have a shape env. if FxGraphCache._get_shape_env() is None: log.debug("fx graph cache no shape env") - raise BypassFxGraphCache("No shape env.") + raise BypassFxGraphCache("No shape env") # HigherOrderOperators should be handled on a case-by-case basis. # Currently, we just skip caching if we have any. # We also skip if there are any torchbind objects. for node in gm.graph.nodes: if isinstance(node.target, torch._ops.HigherOrderOperator): - raise BypassFxGraphCache("Can't cache HigherOrderOperators.") + raise BypassFxGraphCache( + f"Can't cache HigherOrderOperator: {node.target.name()}" + ) if node.op == "getattr" and isinstance( getattr(gm, node.target), torch._C.ScriptObject ): - raise BypassFxGraphCache("Can't cache torchbind objects.") + raise BypassFxGraphCache("Can't cache torchbind objects") @staticmethod def load( # type: ignore[no-untyped-def] From e9114254fd015fb294f3a2c11d58287ced759d9c Mon Sep 17 00:00:00 2001 From: Will Feng Date: Fri, 13 Sep 2024 17:49:59 -0700 Subject: [PATCH 0882/1018] [Traceable FSDP2][Partitioner] Must save AC output if output has a backward hook (#135727) If node is AC region output and has a backward hook on it, we intentionally choose to save it. This is to work around circular dependencies in Traceable FSDP2+AC. Example: ``` out = fully_shard(utils.checkpoint(module))(x) norm_out = layer_norm(out) ``` and there is a circular dependency: 1. In backward, grad_input of layer_norm aka. `out_grad` is actually dependent on `out`. 2. `out` depends on `out`'s backward hook created by FSDP2 (which does all-gather for `module` weights) in order to be recomputed. 3. `out`'s FSDP2 backward hook, as is the case for all eager backward hooks, depends on `out_grad` -> circular dependency with (1)! Solution: check whether `out` has a backward hook, and if so, intentionally save `out` in forward graph outputs. With this, we can break the above circular dependency. ---- Pull Request resolved: https://github.com/pytorch/pytorch/pull/135727 Approved by: https://github.com/Chillee --- test/dynamo/test_activation_checkpointing.py | 116 ++++++++++++++++++- torch/_dynamo/variables/tensor.py | 5 +- torch/_functorch/partitioners.py | 20 ++++ torch/fx/proxy.py | 1 + 4 files changed, 138 insertions(+), 4 deletions(-) diff --git a/test/dynamo/test_activation_checkpointing.py b/test/dynamo/test_activation_checkpointing.py index 843f11b51c990f..c600ca88fd2a84 100644 --- a/test/dynamo/test_activation_checkpointing.py +++ b/test/dynamo/test_activation_checkpointing.py @@ -1,9 +1,11 @@ # Owner(s): ["module: dynamo"] +import contextlib import copy import functools import math import unittest # noqa: F811 from importlib import import_module +from typing import Set import torch import torch._dynamo.config @@ -83,6 +85,14 @@ def match_rng_op(node, op): return gm +def collect_fwd_graph_outputs(graph: torch.fx.Graph, *, fwd_outputs: Set[str]): + if not torch._dynamo.compiled_autograd.in_compiled_autograd_region: # fwd graph + return_node = list(graph.nodes)[-1] + assert return_node.target == "output" + for x in return_node.args[0]: + fwd_outputs.add(str(x)) + + class _InvalidContext: def __init__(self) -> None: pass @@ -126,18 +136,35 @@ def _custom_policy(ctx, func, *args, **kwargs): class ActivationCheckpointingViaTagsTests(torch._dynamo.test_case.TestCase): - def _validate(self, fn, backend, *args, skip_check=False, fullgraph=True): + def _validate( + self, + fn, + backend, + *args, + skip_check=False, + fullgraph=True, + compiled_autograd=False, + ): cloned_args = [] for arg in args: cloned_args.append(arg.clone().detach().requires_grad_(arg.requires_grad)) + cloned_fn = copy.deepcopy(fn) + torch.manual_seed(0) expected = fn(*args) expected.sum().backward() torch.manual_seed(0) - result = torch.compile(fn, fullgraph=fullgraph, backend=backend)(*cloned_args) - result.sum().backward() + compiled_fn = torch.compile(cloned_fn, fullgraph=fullgraph, backend=backend) + ctx = contextlib.nullcontext() + if compiled_autograd: + ctx = torch._dynamo.compiled_autograd.enable( + lambda gm: torch.compile(gm, fullgraph=fullgraph, backend=backend) + ) + with ctx: + result = compiled_fn(*cloned_args) + result.sum().backward() if not skip_check: self.assertEqual( @@ -442,6 +469,89 @@ def fn(x): # rand decomps do not have have numerical results as eager self._validate(fn, backend, x, skip_check=True) + @torch._functorch.config.patch(recompute_views=True) + @torch._inductor.config.patch(fx_graph_cache=False) + def test_tags_must_save_tensor_that_has_backward_hook(self): + def my_post_forward_hook(submod, args, output): + output.register_hook(my_backward_hook) + return output + + def my_backward_hook(grad): + return grad + + class MySubmod(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + y = torch.matmul(x, x) + z = y * y + return z + + class MyMod(torch.nn.Module): + def __init__(self): + super().__init__() + self.submod = MySubmod() + self.norm = torch.nn.LayerNorm(4) + + def forward(self, x): + out = torch.utils.checkpoint.checkpoint( + self.submod, x, use_reentrant=False + ) + norm_out = self.norm(out) + return norm_out + + def _factory_fn(): + mod = MyMod() + x = torch.ones(4, 4, dtype=torch.float32, requires_grad=True) + backend = "inductor" + return mod, x, backend + + mod_no_hook, x, backend = _factory_fn() + mod_no_hook_fwd_outputs = set() + + with torch._inductor.config.patch( + post_grad_custom_pre_pass=functools.partial( + collect_fwd_graph_outputs, fwd_outputs=mod_no_hook_fwd_outputs + ) + ): + self._validate( + mod_no_hook, backend, x, fullgraph=True, compiled_autograd=True + ) + + mod_with_hook, x, backend = _factory_fn() + mod_with_hook.submod.register_forward_hook(my_post_forward_hook) + mod_with_hook_fwd_outputs = set() + + with torch._inductor.config.patch( + post_grad_custom_pre_pass=functools.partial( + collect_fwd_graph_outputs, fwd_outputs=mod_with_hook_fwd_outputs + ) + ): + self._validate( + mod_with_hook, backend, x, fullgraph=True, compiled_autograd=True + ) + + # If `z` has a backward hook, result of `z = y * y` should also be saved in addition to the usual saved tensors. + mod_no_hook_fwd_outputs_no_primal = { + x for x in mod_no_hook_fwd_outputs if not x.startswith("primals_") + } + mod_with_hook_fwd_outputs_no_primal = { + x for x in mod_with_hook_fwd_outputs if not x.startswith("primals_") + } + additional_saved_tensors = ( + mod_with_hook_fwd_outputs_no_primal - mod_no_hook_fwd_outputs_no_primal + ) + expected_additional_saved_tensors = {"mul"} + self.assertEqual( + additional_saved_tensors, + expected_additional_saved_tensors, + f""" +Expected additional saved tensors: {expected_additional_saved_tensors} but got: {additional_saved_tensors}. +Non-primal fwd outputs from model w/ backward hook: {mod_with_hook_fwd_outputs_no_primal}. +Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no_primal}.""", + ) + @requires_cuda def test_fallback(self): def gn(x, y): diff --git a/torch/_dynamo/variables/tensor.py b/torch/_dynamo/variables/tensor.py index 088313bdbf551d..5cb02c077cbc30 100644 --- a/torch/_dynamo/variables/tensor.py +++ b/torch/_dynamo/variables/tensor.py @@ -1006,12 +1006,15 @@ def _register_hook_trampoline(tensor, bw_state): from .builder import wrap_fx_proxy + self_proxy = self.as_proxy() + self_proxy.node.meta["has_backward_hook"] = True + return wrap_fx_proxy( tx, tx.output.create_proxy( "call_function", _register_hook_trampoline, - (self.as_proxy(), bw_state_proxy), + (self_proxy, bw_state_proxy), {}, ), ) diff --git a/torch/_functorch/partitioners.py b/torch/_functorch/partitioners.py index 4263a2f2b10a33..81e2f297f6fd2c 100644 --- a/torch/_functorch/partitioners.py +++ b/torch/_functorch/partitioners.py @@ -784,6 +784,26 @@ def cleanup_recompute_tags(joint_module: fx.GraphModule) -> fx.GraphModule: and user.meta["ac_graph_id"] > node.meta["ac_graph_id"] ): node.meta["recompute"] = CheckpointPolicy.MUST_SAVE + if node.meta.get("has_backward_hook", False) and not any( + must_recompute(user) for user in node.users + ): + # If node is AC region output and has a backward hook on it, we intentionally choose to save it. + # This is to work around circular dependencies in Traceable FSDP2+AC. + # Example: + # ``` + # out = fully_shard(utils.checkpoint(module))(x) + # norm_out = layer_norm(out) + # ``` + # Here there is a circular dependency: + # 1. In backward, grad_input of layer_norm aka. `out_grad` is actually dependent on `out`. + # 2. `out` depends on `out`'s backward hook created by FSDP2 (which does all-gather for `module` weights) + # in order to be recomputed. + # 3. `out`'s backward hook, as is the case for all eager backward hooks, depends on `out_grad` + # -> circular dependency with (1)! + # + # Solution: check whether `out` has a backward hook, and if so, intentionally save `out` + # in forward graph outputs. With this, we can break the above circular dependency. + node.meta["recompute"] = CheckpointPolicy.MUST_SAVE return joint_module diff --git a/torch/fx/proxy.py b/torch/fx/proxy.py index dc8ab7aa77e75e..86927595eac91e 100644 --- a/torch/fx/proxy.py +++ b/torch/fx/proxy.py @@ -97,6 +97,7 @@ def __exit__(self, *args): "original_aten", "recompute", "ac_graph_id", + "has_backward_hook", "from_node", "quantization_tag", # TODO deprecated "_numeric_debug_handle", # TODO deprecated From d6a5fdb02ac1fa5bea66cac8cf4c8791d7978b6a Mon Sep 17 00:00:00 2001 From: CaoE Date: Sat, 14 Sep 2024 09:53:15 +0000 Subject: [PATCH 0883/1018] Use _amp_foreach_non_finite_check_and_unscale_ for CPU grads of ShardedGradScaler (#135232) Use `_amp_foreach_non_finite_check_and_unscale_` instead of fallback version for CPU grads of `ShardedGradScaler ` as `_amp_foreach_non_finite_check_and_unscale_ ` is supported on CPU https://github.com/pytorch/pytorch/pull/109281. Pull Request resolved: https://github.com/pytorch/pytorch/pull/135232 Approved by: https://github.com/ezyang --- torch/distributed/fsdp/sharded_grad_scaler.py | 49 +++---------------- 1 file changed, 6 insertions(+), 43 deletions(-) diff --git a/torch/distributed/fsdp/sharded_grad_scaler.py b/torch/distributed/fsdp/sharded_grad_scaler.py index c4f1f1dae4e82c..988ecb3533f5b9 100644 --- a/torch/distributed/fsdp/sharded_grad_scaler.py +++ b/torch/distributed/fsdp/sharded_grad_scaler.py @@ -1,7 +1,7 @@ # mypy: allow-untyped-defs import logging from collections import abc, defaultdict -from typing import Any, Dict, Iterable, List, Optional, overload, Sequence, Tuple, Union +from typing import Any, Dict, Iterable, List, Optional, overload, Tuple, Union import torch import torch.distributed as dist @@ -168,36 +168,6 @@ def apply_scale(val: Union[torch.Tensor, Iterable[torch.Tensor]]): return apply_scale(outputs) - def _foreach_non_finite_check_and_unscale_cpu_( - self, - grads: Sequence[torch.Tensor], - found_inf: torch.Tensor, - inv_scale: torch.Tensor, - ) -> None: - if len(grads) == 0: - return - assert inv_scale.numel() == 1, "inv_scale must be a 1-element tensor." - assert found_inf.numel() == 1, "found_inf must be a 1-element tensor." - - for grad in grads: - if grad.device.type != "cpu": - logger.error( - "tensor device is %s but was expected to be ``cpu``", - grad.device, - ) - raise ValueError( - "Gradients were found on a non-CPU device when" - " expected to be on CPU." - ) - if ( - torch.isinf(grad).any().item() is True - or torch.isnan(grad).any().item() is True - ): - found_inf.data = torch.tensor([1.0]) - break - else: - grad.data *= inv_scale.item() - def _unscale_grads_( self, optimizer: torch.optim.Optimizer, @@ -241,18 +211,11 @@ def _unscale_grads_( for device, per_dtype_grads in per_device_and_dtype_grads.items(): for grads in per_dtype_grads.values(): - if grads[0].device.type == "cpu": - self._foreach_non_finite_check_and_unscale_cpu_( - grads, - per_device_found_inf.get(device), - per_device_inv_scale.get(device), - ) - else: - torch._amp_foreach_non_finite_check_and_unscale_( - grads, - per_device_found_inf.get(device), - per_device_inv_scale.get(device), - ) + torch._amp_foreach_non_finite_check_and_unscale_( + grads, + per_device_found_inf.get(device), + per_device_inv_scale.get(device), + ) # There exist contexts (e.g. w/ `use_orig_params=True`) wherein some # ranks may have no (non-zero sized) parameter shards, necessitating the # initialization of `per_device_found_inf._per_device_tensors` here From 41bf3bec9f801365cc68d31589fbd8a722ba7945 Mon Sep 17 00:00:00 2001 From: Zhenbin Lin Date: Sat, 14 Sep 2024 09:57:43 +0000 Subject: [PATCH 0884/1018] OpenReg: Fix issue when copying on the same device (#135956) Current copy gets wrong value when src and dst are both openreg. Pull Request resolved: https://github.com/pytorch/pytorch/pull/135956 Approved by: https://github.com/albanD --- .../open_registration_extension/pytorch_openreg/_aten_impl.py | 1 + .../open_registration_extension/test/test_openreg.py | 4 ++++ 2 files changed, 5 insertions(+) diff --git a/test/cpp_extensions/open_registration_extension/pytorch_openreg/_aten_impl.py b/test/cpp_extensions/open_registration_extension/pytorch_openreg/_aten_impl.py index 894e6ea10d129d..7103655185ba89 100644 --- a/test/cpp_extensions/open_registration_extension/pytorch_openreg/_aten_impl.py +++ b/test/cpp_extensions/open_registration_extension/pytorch_openreg/_aten_impl.py @@ -42,6 +42,7 @@ def _openreg_kernel_fallback(op, *args, **kwargs): if from_.device.type == to_.device.type: assert from_.device.type == "openreg" op = torch.ops.aten.copy_.default + args = to_, from_ # handled below as a regular copy elif from_.device.type == "openreg": args, _ = prepare_for_sending((from_,), {}) diff --git a/test/cpp_extensions/open_registration_extension/test/test_openreg.py b/test/cpp_extensions/open_registration_extension/test/test_openreg.py index 18484e61bef04b..27689c7559aca7 100644 --- a/test/cpp_extensions/open_registration_extension/test/test_openreg.py +++ b/test/cpp_extensions/open_registration_extension/test/test_openreg.py @@ -51,6 +51,10 @@ def test_cross_device_copy(self): b = a.to(device="openreg").add(2).to(device="cpu") self.assertEqual(b, a + 2) + def test_copy_same_device(self): + a = torch.ones(10, device="openreg").clone() + self.assertEqual(a, torch.ones(10, device="openreg")) + def test_data_dependent_output(self): cpu_a = torch.randn(10) a = cpu_a.to(device="openreg") From ae4e145b39e7f79ba33adf27b958b68a647f19dc Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Sat, 14 Sep 2024 10:02:54 +0000 Subject: [PATCH 0885/1018] Revert "[Dynamo] Remove ignored modes from torch function mode stack guard (#135503)" This reverts commit e77bd0ebd20e96990ccd40518e68bbcfe7fda855. Reverted https://github.com/pytorch/pytorch/pull/135503 on behalf of https://github.com/mlazos due to broke python test/quantization/pt2e/test_numeric_debugger.py TestNumericDebugger.test_re_export_preserve_handle modified yesterday ([comment](https://github.com/pytorch/pytorch/pull/134732#issuecomment-2350937008)) --- torch/_C/_dynamo/guards.pyi | 2 +- torch/_dynamo/guards.py | 1 + torch/csrc/dynamo/guards.cpp | 69 +++++++++++++++++++++++++++++++----- 3 files changed, 62 insertions(+), 10 deletions(-) diff --git a/torch/_C/_dynamo/guards.pyi b/torch/_C/_dynamo/guards.pyi index 918d913068e6df..d6aab0d547f727 100644 --- a/torch/_C/_dynamo/guards.pyi +++ b/torch/_C/_dynamo/guards.pyi @@ -67,7 +67,7 @@ class GuardManager: ) -> None: ... def add_global_state_guard(self, verbose_code_parts: list[str]) -> None: ... def add_torch_function_mode_stack_guard( - self, initial_stack, verbose_code_parts: list[str] + self, initial_stack, ignored_types, verbose_code_parts: list[str] ) -> None: ... class RootGuardManager(GuardManager): diff --git a/torch/_dynamo/guards.py b/torch/_dynamo/guards.py index 3dcc9a032f208f..66deeb3cbc0e79 100644 --- a/torch/_dynamo/guards.py +++ b/torch/_dynamo/guards.py @@ -2350,6 +2350,7 @@ def compile_check_fn(self, builder, guards_out, guard_fail_fn): self.guard_manager.root.add_torch_function_mode_stack_guard( self.torch_function_mode_stack, + list(), ["___check_torch_function_mode_stack()"], ) # Clear references to torch_function modes held in the list diff --git a/torch/csrc/dynamo/guards.cpp b/torch/csrc/dynamo/guards.cpp index 1d1cf2fb0c60a6..771de44427b669 100644 --- a/torch/csrc/dynamo/guards.cpp +++ b/torch/csrc/dynamo/guards.cpp @@ -2515,40 +2515,90 @@ class TORCH_FUNCTION_MODE_STACK : public LeafGuard { public: TORCH_FUNCTION_MODE_STACK( const py::list& initial_stack, + const py::list& ignored_types, py::object verbose_code_parts) - : LeafGuard(std::move(verbose_code_parts)), _ref_stack() { + : LeafGuard(std::move(verbose_code_parts)), + _ref_stack(), + _ignored_types() { Py_ssize_t len = PyList_Size(initial_stack.ptr()); for (Py_ssize_t idx = 0; idx < len; idx++) { PyObject* mode = PyList_GetItem(initial_stack.ptr(), idx); // borrowed ref auto type = Py_TYPE(mode); this->_ref_stack.push_back(type); } + + len = PyList_Size(ignored_types.ptr()); + for (Py_ssize_t idx = 0; idx < len; idx++) { + PyObject* type_obj = + PyList_GetItem(ignored_types.ptr(), idx); // borrowed ref + if (PyType_Check(type_obj) == 0) { + PyErr_SetString( + PyExc_TypeError, "ignored_types should contain a list of types"); + return; + } + PyTypeObject* type = (PyTypeObject*)type_obj; + this->_ignored_types.insert(type); + } } bool check_nopybind(PyObject* value) override { // Ignore value arg, only used to satisfy the interface - const size_t len = (size_t)at::impl::PythonTorchFunctionTLS::stack_len(); + size_t ref_ind = 0; + const int64_t len = at::impl::PythonTorchFunctionTLS::stack_len(); const size_t ref_stack_size = this->_ref_stack.size(); - if (len != ref_stack_size) { - return false; + int64_t idx = 0; + while ((idx < len) && (ref_ind < ref_stack_size)) { + std::shared_ptr mode = + at::impl::PythonTorchFunctionTLS::get_stack_at(idx); + + PyTypeObject* mode_type = Py_TYPE(mode->ptr(getPyInterpreter())); + bool act_ignored = this->_ignored_types.count(mode_type) > 0; + bool ref_ignored = + this->_ignored_types.count(this->_ref_stack.at(ref_ind)) > 0; + // skip ignored types + if (act_ignored && ref_ignored) { + idx++; + ref_ind++; + continue; + } else if (ref_ignored) { + ref_ind++; + continue; + } else if (act_ignored) { + idx++; + continue; + } + // if we already have more non-ignored modes than the ref stack + // or if the mode doesn't match at the current index, return false + else if (mode_type != _ref_stack.at(ref_ind)) { + return false; + } + ref_ind++; + idx++; } - for (int64_t idx = 0; (size_t)idx < len; idx++) { + for (; ref_ind < ref_stack_size; ref_ind++) { + if (!(this->_ignored_types.count(this->_ref_stack.at(ref_ind)) > 0)) { + return false; + } + } + + for (; idx < len; idx++) { std::shared_ptr mode = at::impl::PythonTorchFunctionTLS::get_stack_at(idx); PyTypeObject* mode_type = Py_TYPE(mode->ptr(getPyInterpreter())); - if (mode_type != _ref_stack.at(idx)) { + if (!(this->_ignored_types.count(mode_type) > 0)) { return false; } } - return true; + return ref_ind == ref_stack_size && idx == len; } private: std::vector _ref_stack; + std::set _ignored_types; }; class TENSOR_MATCH : public LeafGuard { @@ -3713,7 +3763,7 @@ PyObject* torch_c_dynamo_guards_init() { LeafGuard, std::shared_ptr>( py_m, "TORCH_FUNCTION_MODE_STACK") - .def(py::init()) + .def(py::init()) .def("__call__", &TORCH_FUNCTION_MODE_STACK::check); py::class_>( py_m, "DATA_PTR_MATCH") @@ -3950,9 +4000,10 @@ PyObject* torch_c_dynamo_guards_init() { "add_torch_function_mode_stack_guard", [](GuardManager& self, const py::list& initial_stack, + const py::list& ignored_types, py::object verbose_code_parts) -> void { self.add_leaf_guard(std::make_shared( - initial_stack, std::move(verbose_code_parts))); + initial_stack, ignored_types, std::move(verbose_code_parts))); }) .def( "add_data_ptr_guard", From 7831fd9c0c4a5fb5f3c17a5684e85c2429c8e5e2 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Sat, 14 Sep 2024 10:02:55 +0000 Subject: [PATCH 0886/1018] Revert "[Dynamo] Remove ignored modes workaround (#135502)" This reverts commit 5c67cf180ee53d696f95d7c45dd99a35399e4450. Reverted https://github.com/pytorch/pytorch/pull/135502 on behalf of https://github.com/mlazos due to broke python test/quantization/pt2e/test_numeric_debugger.py TestNumericDebugger.test_re_export_preserve_handle modified yesterday ([comment](https://github.com/pytorch/pytorch/pull/134732#issuecomment-2350937008)) --- test/dynamo/test_modes.py | 65 +++++++++++++++++++++++ torch/_dynamo/guards.py | 12 +++-- torch/_dynamo/utils.py | 10 +++- torch/_dynamo/variables/torch_function.py | 6 +++ 4 files changed, 88 insertions(+), 5 deletions(-) diff --git a/test/dynamo/test_modes.py b/test/dynamo/test_modes.py index 4d1f2bbea389ef..63e06a515ed186 100644 --- a/test/dynamo/test_modes.py +++ b/test/dynamo/test_modes.py @@ -1,4 +1,5 @@ # Owner(s): ["module: dynamo"] +from unittest.mock import patch import torch import torch._dynamo.test_case @@ -106,6 +107,70 @@ def fn(x): fn(inp) self.assertEqual(cnt.frame_count, 4) + def _run_ignored_mode_types_test(self): + class IgnoredMode(BaseTorchFunctionMode): + pass + + cnt = torch._dynamo.testing.CompileCounter() + + @torch.compile(backend=cnt.__call__, fullgraph=True) + def fn(x): + return x + 1 + + inp = torch.ones(2, 2) + + with patch( + "torch._dynamo.variables.torch_function.IGNORED_MODES", {IgnoredMode} + ): + # initial compile + fn(inp) + + # no recompile, mode ignored + # note: the ref stack is length 0, and the stack we are checking against has length 2 + # we want to check both ref stack len > runtime stack, and ref stack len < runtime stack + with IgnoredMode(), IgnoredMode(): + fn(inp) + + self.assertEqual(cnt.frame_count, 1) + + # recompile due to new mode on the stack + with BaseTorchFunctionMode(), BaseTorchFunctionMode(), BaseTorchFunctionMode(): + fn(inp) + + self.assertEqual(cnt.frame_count, 2) + + # recompile + # tests both ref stack len > runtime stack len for the above guard check + # and ref stack len < runtime stack len for the initial zero mode case + with BaseTorchFunctionMode(), IgnoredMode(), BaseTorchFunctionMode(): + fn(inp) + + self.assertEqual(cnt.frame_count, 3) + + # no recompile + with IgnoredMode(), IgnoredMode(), BaseTorchFunctionMode(), BaseTorchFunctionMode(): + fn(inp) + + self.assertEqual(cnt.frame_count, 3) + + # This is tricky, basically the ignored modes are baked into the guard + # IgnoredMode will be ignored forever by that guard. + # This is okay since we don't expect to be modifying IGNORED_MODES + # in the middle of execution except for the purposes of testing. + torch._dynamo.reset() + + with IgnoredMode(): + fn(inp) + + self.assertEqual(cnt.frame_count, 4) + + @torch._dynamo.config.patch("enable_cpp_guard_manager", False) + def test_torch_function_mode_guards_ignored_types_py(self): + self._run_ignored_mode_types_test() + + def test_torch_function_mode_guards_ignored_types_cpp(self): + self._run_ignored_mode_types_test() + @torch._dynamo.config.patch("enable_cpp_guard_manager", False) def test_torch_function_mode_guards_py(self): self._run_torch_function_mode_guard_test() diff --git a/torch/_dynamo/guards.py b/torch/_dynamo/guards.py index 66deeb3cbc0e79..1544bdd6dc697b 100644 --- a/torch/_dynamo/guards.py +++ b/torch/_dynamo/guards.py @@ -2344,13 +2344,15 @@ def compile_check_fn(self, builder, guards_out, guard_fail_fn): ) if config.enable_cpp_guard_manager: + from .variables.torch_function import IGNORED_MODES + # Insert the global_state guard assert self.guard_manager # to make mypy happy self.guard_manager.root.add_global_state_guard(["___check_global_state()"]) self.guard_manager.root.add_torch_function_mode_stack_guard( self.torch_function_mode_stack, - list(), + list(IGNORED_MODES), ["___check_torch_function_mode_stack()"], ) # Clear references to torch_function modes held in the list @@ -2657,14 +2659,18 @@ def is_recompiles_verbose_enabled(): # this will only be used if cpp guards are disabled def make_torch_function_mode_stack_guard(intial_stack): types = [type(x) for x in intial_stack] + from .variables.torch_function import IGNORED_MODES def check_torch_function_mode_stack(): cur_stack = get_torch_function_mode_stack() - if len(cur_stack) != len(types): + types_ = [ty for ty in types if ty not in IGNORED_MODES] + cur_stack_ = [mode for mode in cur_stack if type(mode) not in IGNORED_MODES] + + if len(cur_stack_) != len(types_): return False - for ty, mode in zip(types, cur_stack): + for ty, mode in zip(types_, cur_stack_): if ty != type(mode): return False diff --git a/torch/_dynamo/utils.py b/torch/_dynamo/utils.py index 2d4a771166cfbe..8130b6aa7371b2 100644 --- a/torch/_dynamo/utils.py +++ b/torch/_dynamo/utils.py @@ -3083,10 +3083,16 @@ def is_parameter_freezing(): return torch._inductor.config.freezing and not torch.is_grad_enabled() -def get_torch_function_mode_stack(): - return [ +def get_torch_function_mode_stack(filter_ignored=True): + from .variables.torch_function import IGNORED_MODES + + stack = [ get_torch_function_mode_stack_at(i) for i in range(_len_torch_function_stack()) ] + if filter_ignored: + stack = [mode for mode in stack if type(mode) not in IGNORED_MODES] + + return stack def get_torch_function_mode_stack_at(ind): diff --git a/torch/_dynamo/variables/torch_function.py b/torch/_dynamo/variables/torch_function.py index ffb3d27d4d703b..b7fd83a3d24f15 100644 --- a/torch/_dynamo/variables/torch_function.py +++ b/torch/_dynamo/variables/torch_function.py @@ -69,6 +69,12 @@ if is_tensor_base_attr_getter(fn) ] +# Today set default device is placed in the graph and guarded on separately +# so we should not trace through it. In the future we can trace it once +# mode tracing is implemented and not put in the graph, but this is more +# of a BE project and can be evaluated later +IGNORED_MODES = {DeviceContext} + @functools.lru_cache(None) def get_prev_stack_var_name(): From 3379c6116f55030a78754da3c7d282d8ae733c22 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Sat, 14 Sep 2024 10:02:55 +0000 Subject: [PATCH 0887/1018] Revert "[Dynamo] Trace enter/exit of TorchFunctionModes (#135422)" This reverts commit 7743149b2be4a9eba7e0997ccdc6abe552bec266. Reverted https://github.com/pytorch/pytorch/pull/135422 on behalf of https://github.com/mlazos due to broke python test/quantization/pt2e/test_numeric_debugger.py TestNumericDebugger.test_re_export_preserve_handle modified yesterday ([comment](https://github.com/pytorch/pytorch/pull/134732#issuecomment-2350937008)) --- test/dynamo/test_modes.py | 88 ------------------ torch/_dynamo/convert_frame.py | 6 +- torch/_dynamo/output_graph.py | 6 +- torch/_dynamo/polyfills/__init__.py | 20 ---- torch/_dynamo/resume_execution.py | 104 --------------------- torch/_dynamo/side_effects.py | 11 --- torch/_dynamo/symbolic_convert.py | 30 ++---- torch/_dynamo/testing.py | 1 - torch/_dynamo/trace_rules.py | 2 +- torch/_dynamo/variables/builder.py | 20 ++-- torch/_dynamo/variables/ctx_manager.py | 12 --- torch/_dynamo/variables/torch.py | 21 +---- torch/_dynamo/variables/torch_function.py | 108 ++-------------------- torch/_dynamo/variables/user_defined.py | 14 +-- 14 files changed, 34 insertions(+), 409 deletions(-) diff --git a/test/dynamo/test_modes.py b/test/dynamo/test_modes.py index 63e06a515ed186..fa4c23fd320ddb 100644 --- a/test/dynamo/test_modes.py +++ b/test/dynamo/test_modes.py @@ -461,94 +461,6 @@ def fn(x): self.assertEqual(expected, actual) - def test_torch_function_mode_enter_exit(self): - def fn(x, y): - with TestMode(): - o = torch.add(x, 3) - - return torch.add(o, y) - - inp = (torch.ones(2, 2) + 1, torch.ones(2, 2) + 2) - fn_opt = torch.compile(fn, fullgraph=True) - - expected = fn(*inp) - actual = fn_opt(*inp) - - self.assertEqual(expected, actual) - - def test_torch_function_mode_graph_break(self): - def fn(x, y): - with TestMode(): - torch._dynamo.graph_break() - o = torch.add(x, 3) - - return torch.add(o, y) - - inp = (torch.ones(2, 2) + 1, torch.ones(2, 2) + 2) - fn_opt = torch.compile(fn) - - expected = fn(*inp) - actual = fn_opt(*inp) - - self.assertEqual(expected, actual) - - def test_torch_function_mode_and_pop_graph_break(self): - def fn(x, y): - with TestMode(): - z = _pop_torch_function_stack() - torch._dynamo.graph_break() - _push_on_torch_function_stack(z) - o = torch.add(x, 3) - - return torch.add(o, y) - - inp = (torch.ones(2, 2) + 1, torch.ones(2, 2) + 2) - fn_opt = torch.compile(fn) - - expected = fn(*inp) - actual = fn_opt(*inp) - - self.assertEqual(expected, actual) - - def test_torch_function_mode_restore_on_exc(self): - @torch._dynamo.disable() - def err(): - raise RuntimeError("test") - - @torch.compile() - def fn(x): - with TestMode(): - x += 1 - err() - x += 2 - return x - - try: - fn(torch.ones(2, 2)) - except RuntimeError: - pass - self.assertEqual(_len_torch_function_stack(), 0) - - def test_torch_function_mode_and_pop_graph_break_mutation(self): - def fn(x, y): - with TestMode(): - z = _pop_torch_function_stack() - z.y = 5 - torch._dynamo.graph_break() - _push_on_torch_function_stack(z) - o = torch.add(x, 3) - o = torch.mul(o, z.y) - - return torch.add(o, y) - - inp = (torch.ones(2, 2) + 1, torch.ones(2, 2) + 2) - fn_opt = torch.compile(fn) - - expected = fn(*inp) - actual = fn_opt(*inp) - - self.assertEqual(expected, actual) - if __name__ == "__main__": from torch._dynamo.test_case import run_tests diff --git a/torch/_dynamo/convert_frame.py b/torch/_dynamo/convert_frame.py index 171af02e564b36..b3e4a7c268c0fe 100644 --- a/torch/_dynamo/convert_frame.py +++ b/torch/_dynamo/convert_frame.py @@ -112,7 +112,6 @@ troubleshooting_url, write_record_to_file, ) -from .variables.torch_function import torch_function_mode_stack_state_mgr np: Optional[ModuleType] @@ -211,18 +210,15 @@ def _fn(*args: _P.args, **kwargs: _P.kwargs) -> _T: prior_fwd_from_src = torch.fx.graph_module._forward_from_src torch.fx.graph_module._forward_from_src = fx_forward_from_src_skip_result cleanup = setup_compile_debug() + exit_stack = contextlib.ExitStack() exit_stack.enter_context( torch.fx._symbolic_trace._maybe_revert_all_patches() ) - exit_stack.enter_context(torch_function_mode_stack_state_mgr) try: return fn(*args, **kwargs) finally: cleanup.close() - assert ( - torch._C._len_torch_function_stack() == 0 - ), "Torch function mode stack state changed while dynamo tracing, please report a bug" exit_stack.close() torch._C._set_grad_enabled(prior_grad_mode) torch.autograd.grad_mode._enter_inference_mode(prior_inference_mode) diff --git a/torch/_dynamo/output_graph.py b/torch/_dynamo/output_graph.py index 0ea19e1cad6586..ba8cfa5a03dfba 100644 --- a/torch/_dynamo/output_graph.py +++ b/torch/_dynamo/output_graph.py @@ -78,6 +78,7 @@ get_instruction_source_311, get_locals_to_steal, get_static_address_type, + get_torch_function_mode_stack, graph_break_reasons, increment_op_count, lazy_format_graph_code, @@ -249,7 +250,6 @@ def __init__( local_scope: Scope, global_scope: Scope, f_code, - torch_function_mode_stack, ): super().__init__() self.tracers = [SubgraphTracer(self, export_root=export)] @@ -368,7 +368,7 @@ def __init__( # This returns false if TF Overall (both mode and subclass) is disabled OR that TF Mode stack is empty self.torch_function_mode_enabled = torch._C._is_torch_function_mode_enabled() # This records the initial torch function mode stack for guarding - self.torch_function_mode_stack = torch_function_mode_stack + self.torch_function_mode_stack = get_torch_function_mode_stack() # Tracks if the output graph has a user defined allowed function in the # graph. This is used later to determine if we should fallback to eager @@ -1020,7 +1020,7 @@ def append_prefix_insts(): prefix_insts.clear() for block in reversed(tx.block_stack): - block.exit(tx, is_graph_break=reason.graph_break) + block.exit(tx) self.cleanup_graph() tx.prune_dead_locals() diff --git a/torch/_dynamo/polyfills/__init__.py b/torch/_dynamo/polyfills/__init__.py index 5b2812bc08c9e5..2b3f38920a0c2a 100644 --- a/torch/_dynamo/polyfills/__init__.py +++ b/torch/_dynamo/polyfills/__init__.py @@ -25,26 +25,6 @@ sys as sys, ) -from torch.overrides import BaseTorchFunctionMode - - -# These classes handle support for TorchFunctionModes across -# graph breaks -# Today the TorchFunctionMode enter (for the classes we support) -# simply pushes the mode onto the stack. Since after this occurs -# the stack is mutated, and we replay these mutations, we don't need -# any cleanup logic to be run once the graph break occurs, we simply replay -# these mutations to ensure at the graph break the torch function mode stack is correct -# and reconstruct the torch function mode stack normally -# when we compile the resume function on the other side of the break. -# However, to ensure we exit properly -# in the resume function, we need to re-enter the contexts as we do other contexts. -# These contexts do nothing on enter, but provide the correct exit logic to ensure -# the stack state is correct. -class NoEnterTorchFunctionMode(BaseTorchFunctionMode): - def __enter__(self): - pass - def index(iterator, item, start=0, end=None): from itertools import islice diff --git a/torch/_dynamo/resume_execution.py b/torch/_dynamo/resume_execution.py index 6f7db66514c487..132e9e4081bceb 100644 --- a/torch/_dynamo/resume_execution.py +++ b/torch/_dynamo/resume_execution.py @@ -6,7 +6,6 @@ from typing import Any, cast, Dict, List, Optional, Tuple from .bytecode_transformation import ( - add_push_null, create_call_function, create_call_method, create_dup_top, @@ -49,109 +48,6 @@ class ReenterWith: stack_index: int target_values: Optional[Tuple[Any, ...]] = None - def try_except_torch_function_mode(self, code_options, cleanup: List[Instruction]): - """ - Codegen based off of: - try: - (rest) - except: - (restore previous stack) - - """ - from .variables.torch_function import get_prev_stack_var_name - - except_jump_target = create_instruction( - "NOP" if sys.version_info < (3, 11) else "PUSH_EXC_INFO" - ) - cleanup_complete_jump_target = create_instruction("NOP") - - setup_finally: List[Instruction] = [] - - if sys.version_info < (3, 11): - setup_finally.append( - create_instruction("SETUP_FINALLY", target=except_jump_target) - ) - else: - exn_tab_begin = create_instruction("NOP") - exn_tab_end = create_instruction("NOP") - exn_tab_begin.exn_tab_entry = InstructionExnTabEntry( - exn_tab_begin, - exn_tab_end, - except_jump_target, - self.stack_index + 1, - False, - ) - setup_finally.append(exn_tab_begin) - - def create_reset(): - insts = [ - create_instruction( - "LOAD_GLOBAL", argval="__import_torch_dot__dynamo_dot_utils" - ), - create_instruction("LOAD_ATTR", argval="set_torch_function_mode_stack"), - ] - add_push_null(insts) - return [ - *insts, - create_instruction("LOAD_FAST", argval=get_prev_stack_var_name()), - *create_call_function(1, False), - create_instruction("POP_TOP"), - ] - - if sys.version_info < (3, 9): - epilogue = [ - create_instruction("POP_BLOCK"), - create_instruction("JUMP_FORWARD", target=cleanup_complete_jump_target), - except_jump_target, - *create_reset(), - create_instruction("POP_TOP"), - create_instruction("POP_TOP"), - create_instruction("POP_TOP"), - *create_reset(), - create_instruction("RAISE_VARARGS", argval=0), - create_instruction("POP_EXCEPT", argval=0), - create_instruction("END_FINALLY"), - cleanup_complete_jump_target, - ] - elif sys.version_info < (3, 11): - epilogue = [ - create_instruction("POP_BLOCK"), - create_instruction("JUMP_FORWARD", target=cleanup_complete_jump_target), - except_jump_target, - create_instruction("POP_TOP"), - create_instruction("POP_TOP"), - create_instruction("POP_TOP"), - *create_reset(), - create_instruction("RAISE_VARARGS", argval=0), - create_instruction("POP_EXCEPT", argval=0), - cleanup_complete_jump_target, - ] - else: - finally_exn_tab_end = create_instruction("RAISE_VARARGS", argval=0) - finally_exn_tab_target = create_instruction("COPY", arg=3) - except_jump_target.exn_tab_entry = InstructionExnTabEntry( - except_jump_target, - finally_exn_tab_end, - finally_exn_tab_target, - self.stack_index + 2, - True, - ) - epilogue = [ - exn_tab_end, - create_instruction("JUMP_FORWARD", target=cleanup_complete_jump_target), - except_jump_target, # PUSH_EXC_INFO - create_instruction("POP_TOP"), - *create_reset(), - finally_exn_tab_end, - finally_exn_tab_target, # COPY 3 - create_instruction("POP_EXCEPT"), - create_instruction("RERAISE", arg=1), # RERAISE 1 - cleanup_complete_jump_target, - ] - - cleanup[:] = epilogue + cleanup - return setup_finally - # If we do not want to destroy the stack, we can do the same thing as a # `SETUP_WITH` block, only that we store the context manager in a local_symbol def try_except(self, code_options, cleanup: List[Instruction]): diff --git a/torch/_dynamo/side_effects.py b/torch/_dynamo/side_effects.py index fd5b54e4f7f7af..fc1bd976ff57df 100644 --- a/torch/_dynamo/side_effects.py +++ b/torch/_dynamo/side_effects.py @@ -593,22 +593,11 @@ def codegen_update_mutated(self, cg: PyCodegen): elif isinstance( var, variables.torch_function.TorchFunctionModeStackVariable ): - # Needed in the finally block for stack restoration - cg.add_push_null( - lambda: cg.load_import_from( - utils.__name__, "get_torch_function_mode_stack" - ) - ) - cg.call_function(0, False) - name = variables.torch_function.get_prev_stack_var_name() - cg.code_options["co_varnames"] += (name,) - cg.append_output(create_instruction("STORE_FAST", argval=name)) cg.add_push_null( lambda: cg.load_import_from( utils.__name__, "set_torch_function_mode_stack" ) ) - cg.foreach(var.symbolic_stack) cg.append_output( create_instruction("BUILD_LIST", arg=len(var.symbolic_stack)) diff --git a/torch/_dynamo/symbolic_convert.py b/torch/_dynamo/symbolic_convert.py index 61acaec995757a..415179a5ed32a6 100644 --- a/torch/_dynamo/symbolic_convert.py +++ b/torch/_dynamo/symbolic_convert.py @@ -267,12 +267,13 @@ def resume_fn(self): else: return ReenterWith(self.stack_index) - def exit(self, tx, is_graph_break): + def exit(self, tx): + if hasattr(self, "graph_break") and isinstance( + self.with_context, TorchFunctionModeVariable + ): + return assert self.with_context is not None - if ( - is_graph_break and self.with_context.exit_on_graph_break() - ) or not is_graph_break: - return self.with_context.exit(tx) + return self.with_context.exit(tx) class ReturnValueOp(Exception): @@ -638,17 +639,10 @@ def handle_graph_break( cleanup: List[Instruction] = [] # Reconstruct the context variable CLASS in the block stack for b in self.block_stack: - # Don't exit any modes we have entered, - # output bytecode will mutate the tf mode stack accordingly - if isinstance(b.with_context, TorchFunctionModeVariable): - cg.extend_output( - b.resume_fn().try_except_torch_function_mode( - cg.code_options, cleanup - ) - ) - continue assert b.with_context is not None - assert isinstance(b.with_context, (ContextWrappingVariable)) + assert isinstance( + b.with_context, (ContextWrappingVariable, TorchFunctionModeVariable) + ) b.with_context.reconstruct_type(cg) cg.extend_output(b.resume_fn().try_except(cg.code_options, cleanup)) self.output.add_output_instructions(cg.get_instructions()) @@ -2301,10 +2295,7 @@ def setup_or_before_with(self, inst): ): unimplemented(f"{inst.opname} {ctx}") - if ( - isinstance(ctx, GenericContextWrappingVariable) - and not ctx.supports_graph_breaks() - ): + if isinstance(ctx, GenericContextWrappingVariable): self.generic_context_manager_depth += 1 # Need this redundant check for mypy @@ -2677,7 +2668,6 @@ def __init__( local_scope=f_locals, global_scope=f_globals, f_code=f_code, - torch_function_mode_stack=torch_function_mode_stack, ), instructions=instructions, f_locals=f_locals, diff --git a/torch/_dynamo/testing.py b/torch/_dynamo/testing.py index 704a3889707233..4922b521bada34 100644 --- a/torch/_dynamo/testing.py +++ b/torch/_dynamo/testing.py @@ -163,7 +163,6 @@ def insert_nops(instructions, code_options): local_scope=locals(), global_scope=globals(), f_code=frame.f_code, - torch_function_mode_stack=[], ) return GuardedCode(code, CheckFunctionManager(graph).check_fn, CompileId(0, 0)) diff --git a/torch/_dynamo/trace_rules.py b/torch/_dynamo/trace_rules.py index ed81b26191b4de..1a206ae339c2db 100644 --- a/torch/_dynamo/trace_rules.py +++ b/torch/_dynamo/trace_rules.py @@ -304,7 +304,6 @@ "torch.fx.experimental.symbolic_shapes.guard_size_oblivious": TorchInGraphFunctionVariable, "torch.cuda._get_device_properties": TorchInGraphFunctionVariable, "torch.utils.hooks.BackwardHook": TorchInGraphFunctionVariable, - "torch.set_default_device": UserFunctionVariable, "torch.sparse_bsc_tensor": SkipFunctionVariable, "torch.sparse_bsr_tensor": SkipFunctionVariable, "torch.sparse_csc_tensor": SkipFunctionVariable, @@ -2798,6 +2797,7 @@ "torch.random.initial_seed", "torch.random.seed", "torch.return_types.pytree_register_structseq", + "torch.set_default_device", "torch.set_default_dtype", "torch.set_default_tensor_type", "torch.set_deterministic_debug_mode", diff --git a/torch/_dynamo/variables/builder.py b/torch/_dynamo/variables/builder.py index 193e227267ec3b..5a8a9b613b899d 100644 --- a/torch/_dynamo/variables/builder.py +++ b/torch/_dynamo/variables/builder.py @@ -204,7 +204,6 @@ from .torch_function import ( build_torch_function_fn, TensorWithTFOverrideVariable, - torch_function_mode_stack_state_mgr, TorchFunctionModeVariable, ) from .user_defined import ( @@ -1669,16 +1668,15 @@ def wrap_numpy_ndarray(self, value): # but warning is not the end of the world assert isinstance(value.base, np.nditer) - with torch_function_mode_stack_state_mgr.temp_restore_stack(): - try: - tensor_value = _util._try_convert_to_tensor(value) - if readonly: - from torch._prims_common import clone_preserve_strides - - tensor_value = clone_preserve_strides(tensor_value) - except NotImplementedError as e: - # failed to convert to tensor, graph break - unimplemented(str(e)) + try: + tensor_value = _util._try_convert_to_tensor(value) + if readonly: + from torch._prims_common import clone_preserve_strides + + tensor_value = clone_preserve_strides(tensor_value) + except NotImplementedError as e: + # failed to convert to tensor, graph break + unimplemented(str(e)) # We do this because we want the full behavior of guarding the numpy ndarray as if it were # a tensor. It's a little annoying to make a VT to throw out, but there's so many side effects here diff --git a/torch/_dynamo/variables/ctx_manager.py b/torch/_dynamo/variables/ctx_manager.py index e19c4e254c647e..8c4eb3dc4e715b 100644 --- a/torch/_dynamo/variables/ctx_manager.py +++ b/torch/_dynamo/variables/ctx_manager.py @@ -125,12 +125,6 @@ def call_function( if isinstance(args[0], UserFunctionVariable): return WrappedUserFunctionVariable(args[0], self) - def supports_graph_breaks(self): - return True - - def exit_on_graph_break(self): - return True - class GenericContextWrappingVariable(UserDefinedObjectVariable): # Some methods in ContextWrappingVariable assumes the arguments are @@ -189,12 +183,6 @@ def exit(self, tx: "InstructionTranslator", *args): tx.generic_context_manager_depth -= 1 return x - def supports_graph_breaks(self): - return False - - def exit_on_graph_break(self): - return True - class GradInplaceRequiresGradCtxManagerVariable(ContextWrappingVariable): """represents torch grad requries grad""" diff --git a/torch/_dynamo/variables/torch.py b/torch/_dynamo/variables/torch.py index c1e0dec0fbc415..f6cb565ef021c0 100644 --- a/torch/_dynamo/variables/torch.py +++ b/torch/_dynamo/variables/torch.py @@ -162,17 +162,7 @@ def get_overridable_functions(): from torch.overrides import get_overridable_functions as get_overridable_functions_ - funcs = set(chain(*get_overridable_functions_().values())) - more = { - torch.ones, - torch.ones_like, - torch.zeros, - torch.zeros_like, - torch.empty, - torch.full, - } - funcs.update(more) - return funcs + return set(chain(*get_overridable_functions_().values())) class BaseTorchVariable(VariableTracker): @@ -848,13 +838,6 @@ def handle_len_torch_function( len(tx.symbolic_torch_function_state.mode_stack) ) - @register(torch._C._get_function_stack_at) - def handle_get_stack_at(self, tx: "InstructionTranslator", *args, **kwargs): - assert len(args) == 1 and not kwargs - ind = args[0].as_python_constant() - assert ind >= 0 and ind < len(tx.symbolic_torch_function_state.mode_stack) - return tx.symbolic_torch_function_state.mode_stack[ind] - @register(torch.set_default_device) def handle_set_default_device( self, tx: "InstructionTranslator", *args, **kwargs @@ -872,7 +855,7 @@ def handle_set_default_device( else: TorchFunctionModeStackVariable.register_device_context_insertion(tx) - return ConstantVariable.create(None) + return None return handlers diff --git a/torch/_dynamo/variables/torch_function.py b/torch/_dynamo/variables/torch_function.py index b7fd83a3d24f15..6e52cc688e0309 100644 --- a/torch/_dynamo/variables/torch_function.py +++ b/torch/_dynamo/variables/torch_function.py @@ -2,35 +2,22 @@ import collections import contextlib -import functools import inspect from typing import Deque, Dict, List, TYPE_CHECKING import torch._C import torch.utils._pytree as pytree from torch._guards import Source -from torch.overrides import ( - _get_overloaded_args, - get_default_nowrap_functions, - TorchFunctionMode, -) +from torch.overrides import _get_overloaded_args, get_default_nowrap_functions from torch.utils._device import DeviceContext from ..exc import unimplemented from ..guards import GuardBuilder, install_guard -from ..polyfills import NoEnterTorchFunctionMode from ..source import AttrSource, GlobalSource, TorchFunctionModeStackSource, TypeSource -from ..utils import ( - class_has_getattribute, - clear_torch_function_mode_stack, - get_safe_global_name, - has_torch_function, - is_tensor_base_attr_getter, - set_torch_function_mode_stack, -) +from ..utils import get_safe_global_name, has_torch_function, is_tensor_base_attr_getter from .base import VariableTracker from .constant import ConstantVariable -from .ctx_manager import GenericContextWrappingVariable +from .ctx_manager import ContextWrappingVariable from .lazy import LazyVariableTracker from .lists import TupleVariable from .tensor import TensorSubclassVariable, TensorVariable @@ -76,39 +63,6 @@ IGNORED_MODES = {DeviceContext} -@functools.lru_cache(None) -def get_prev_stack_var_name(): - from ..bytecode_transformation import unique_id - - return unique_id("___prev_torch_function_mode_stack") - - -# Used to clear/restore the python torch function mode stack and temporarily restore it as needed -class TorchFunctionModeStackStateManager: - def __init__(self): - self.stack = [] - - def __enter__(self): - self.stack = torch.overrides._get_current_function_mode_stack() - clear_torch_function_mode_stack() - - def __exit__(self, exc_type, exc_value, traceback): - set_torch_function_mode_stack(self.stack) - self.stack = [] - - @contextlib.contextmanager - def temp_restore_stack(self): - prev = torch.overrides._get_current_function_mode_stack() - set_torch_function_mode_stack(self.stack) - try: - yield - finally: - set_torch_function_mode_stack(prev) - - -torch_function_mode_stack_state_mgr = TorchFunctionModeStackStateManager() - - class SymbolicTorchFunctionState: def __init__(self, py_stack): # This is annoyingly complicated because of how the torch function subclass + mode C API was designed @@ -235,26 +189,9 @@ def get_mode_index(cls, ind): return ind + cls.offset -class TorchFunctionModeVariable(GenericContextWrappingVariable): - @staticmethod - def is_supported_torch_function_mode(ty): - # Supported in this sense means we can support graph breaks under the - # context. - # We are able to trace custom modes but if there are graph breaks under them - # and they have a custom __enter__/__exit__ we don't handle this for the - # same reason we don't handle generic context managers: there may be side effects - # that are now affected by executing the funtion across two frames instead of one - # Today we support the enter/exit of the default TorchFunctionMode as well as - # DeviceContext (which is used for set_default_device) - return issubclass(ty, (NoEnterTorchFunctionMode, DeviceContext)) or ( - not class_has_getattribute(ty) - and inspect.getattr_static(ty, "__enter__") == TorchFunctionMode.__enter__ - and inspect.getattr_static(ty, "__exit__") == TorchFunctionMode.__exit__ - ) - +class TorchFunctionModeVariable(ContextWrappingVariable): def __init__(self, value, source=None, **kwargs): - if value is not None: - super().__init__(value, **kwargs) + super().__init__(value, **kwargs) self.value = value self.cm_obj = value # needed for BC with calling enter from CM code self.source = source @@ -284,39 +221,8 @@ def call_torch_function(self, tx: "InstructionTranslator", fn, types, args, kwar kwargs, ) - def enter(self, tx): - from .torch import TorchInGraphFunctionVariable - - if isinstance(self.value, NoEnterTorchFunctionMode): - return ConstantVariable.create(None) - - TorchInGraphFunctionVariable( - torch._C._push_on_torch_function_stack - ).call_function(tx, [self], {}) - return ConstantVariable.create(None) - - def exit(self, tx: "InstructionTranslator", *args): - from .torch import TorchInGraphFunctionVariable - - TorchInGraphFunctionVariable(torch._C._pop_torch_function_stack).call_function( - tx, [], {} - ) - return ConstantVariable.create(None) - - def reconstruct_type(self, codegen): - ty = NoEnterTorchFunctionMode - codegen( - AttrSource( - codegen.tx.import_source(ty.__module__), - ty.__name__, - ) - ) - - def supports_graph_breaks(self): - return True - - def exit_on_graph_break(self): - return False + def _call_func(self, tx: "InstructionTranslator", values): + unimplemented("enter/exit for torch function mode NYI") def _get_all_args(args, kwargs): diff --git a/torch/_dynamo/variables/user_defined.py b/torch/_dynamo/variables/user_defined.py index 14499b4d2e4bf7..ad92277cf70ff4 100644 --- a/torch/_dynamo/variables/user_defined.py +++ b/torch/_dynamo/variables/user_defined.py @@ -409,22 +409,10 @@ def call_function( and self.source and not is_forbidden_context_manager(self.value) ): - from torch.overrides import TorchFunctionMode - from .ctx_manager import GenericContextWrappingVariable - from .torch_function import TorchFunctionModeVariable - - if issubclass( - self.value, TorchFunctionMode - ) and TorchFunctionModeVariable.is_supported_torch_function_mode( - self.value - ): - var_cls = TorchFunctionModeVariable - else: - var_cls = GenericContextWrappingVariable cm_obj = tx.output.side_effects.track_object_new( - self.source, self.value, var_cls, {} + self.source, self.value, GenericContextWrappingVariable, {} ) cm_obj.call_method(tx, "__init__", args, kwargs) return cm_obj From 7af61fa5899d5d23c6236abf68f0e12b972c2269 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Sat, 14 Sep 2024 10:02:55 +0000 Subject: [PATCH 0888/1018] Revert "[Dynamo] Simplify torch function mode stack guard (#135444)" This reverts commit ce3c74f2744cbc134b95cf8bd53ae5e3fbc67c29. Reverted https://github.com/pytorch/pytorch/pull/135444 on behalf of https://github.com/mlazos due to broke python test/quantization/pt2e/test_numeric_debugger.py TestNumericDebugger.test_re_export_preserve_handle modified yesterday ([comment](https://github.com/pytorch/pytorch/pull/134732#issuecomment-2350937008)) --- test/dynamo/test_modes.py | 2 -- torch/_dynamo/guards.py | 10 ++++---- torch/csrc/dynamo/guards.cpp | 44 +++++++----------------------------- 3 files changed, 12 insertions(+), 44 deletions(-) diff --git a/test/dynamo/test_modes.py b/test/dynamo/test_modes.py index fa4c23fd320ddb..bad5a788903440 100644 --- a/test/dynamo/test_modes.py +++ b/test/dynamo/test_modes.py @@ -68,11 +68,9 @@ def tearDownClass(cls): def setUp(self): torch.set_default_device(None) - torch._dynamo.reset() def tearDown(self): torch.set_default_device(None) - torch._dynamo.reset() def _run_torch_function_mode_guard_test(self): class TestMode1(BaseTorchFunctionMode): diff --git a/torch/_dynamo/guards.py b/torch/_dynamo/guards.py index 1544bdd6dc697b..4437ef4038190d 100644 --- a/torch/_dynamo/guards.py +++ b/torch/_dynamo/guards.py @@ -2663,14 +2663,12 @@ def make_torch_function_mode_stack_guard(intial_stack): def check_torch_function_mode_stack(): cur_stack = get_torch_function_mode_stack() - - types_ = [ty for ty in types if ty not in IGNORED_MODES] - cur_stack_ = [mode for mode in cur_stack if type(mode) not in IGNORED_MODES] - - if len(cur_stack_) != len(types_): + if len(cur_stack) != len(types): return False - for ty, mode in zip(types_, cur_stack_): + for ty, mode in zip(types, cur_stack): + if ty in IGNORED_MODES: + continue if ty != type(mode): return False diff --git a/torch/csrc/dynamo/guards.cpp b/torch/csrc/dynamo/guards.cpp index 771de44427b669..afc3182c632640 100644 --- a/torch/csrc/dynamo/guards.cpp +++ b/torch/csrc/dynamo/guards.cpp @@ -2523,8 +2523,7 @@ class TORCH_FUNCTION_MODE_STACK : public LeafGuard { Py_ssize_t len = PyList_Size(initial_stack.ptr()); for (Py_ssize_t idx = 0; idx < len; idx++) { PyObject* mode = PyList_GetItem(initial_stack.ptr(), idx); // borrowed ref - auto type = Py_TYPE(mode); - this->_ref_stack.push_back(type); + this->_ref_stack.push_back(Py_TYPE(mode)); } len = PyList_Size(ignored_types.ptr()); @@ -2544,56 +2543,29 @@ class TORCH_FUNCTION_MODE_STACK : public LeafGuard { bool check_nopybind(PyObject* value) override { // Ignore value arg, only used to satisfy the interface size_t ref_ind = 0; - const int64_t len = at::impl::PythonTorchFunctionTLS::stack_len(); + int64_t len = at::impl::PythonTorchFunctionTLS::stack_len(); const size_t ref_stack_size = this->_ref_stack.size(); - int64_t idx = 0; - while ((idx < len) && (ref_ind < ref_stack_size)) { + for (int64_t idx = 0; idx < len; idx++) { std::shared_ptr mode = at::impl::PythonTorchFunctionTLS::get_stack_at(idx); PyTypeObject* mode_type = Py_TYPE(mode->ptr(getPyInterpreter())); - bool act_ignored = this->_ignored_types.count(mode_type) > 0; - bool ref_ignored = - this->_ignored_types.count(this->_ref_stack.at(ref_ind)) > 0; // skip ignored types - if (act_ignored && ref_ignored) { - idx++; - ref_ind++; - continue; - } else if (ref_ignored) { - ref_ind++; - continue; - } else if (act_ignored) { - idx++; + if (this->_ignored_types.count(mode_type) > 0) { continue; } // if we already have more non-ignored modes than the ref stack // or if the mode doesn't match at the current index, return false - else if (mode_type != _ref_stack.at(ref_ind)) { + else if ( + (ref_stack_size == 0) || (ref_ind > ref_stack_size - 1) || + mode_type != _ref_stack[ref_ind]) { return false; } ref_ind++; - idx++; - } - - for (; ref_ind < ref_stack_size; ref_ind++) { - if (!(this->_ignored_types.count(this->_ref_stack.at(ref_ind)) > 0)) { - return false; - } - } - - for (; idx < len; idx++) { - std::shared_ptr mode = - at::impl::PythonTorchFunctionTLS::get_stack_at(idx); - - PyTypeObject* mode_type = Py_TYPE(mode->ptr(getPyInterpreter())); - if (!(this->_ignored_types.count(mode_type) > 0)) { - return false; - } } - return ref_ind == ref_stack_size && idx == len; + return ref_ind == this->_ref_stack.size(); } private: From 94cdd7f0e4f43effbe0c64cc1a220b47690b0cc1 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Sat, 14 Sep 2024 10:02:55 +0000 Subject: [PATCH 0889/1018] Revert "[Dynamo] Support thread local setattr (#135443)" This reverts commit 149d0b716173787df4543186ff74b605aca54e3e. Reverted https://github.com/pytorch/pytorch/pull/135443 on behalf of https://github.com/mlazos due to broke python test/quantization/pt2e/test_numeric_debugger.py TestNumericDebugger.test_re_export_preserve_handle modified yesterday ([comment](https://github.com/pytorch/pytorch/pull/134732#issuecomment-2350937008)) --- test/dynamo/test_misc.py | 15 --------------- torch/_dynamo/variables/user_defined.py | 5 ++--- 2 files changed, 2 insertions(+), 18 deletions(-) diff --git a/test/dynamo/test_misc.py b/test/dynamo/test_misc.py index 96047d3436ebe1..9d2bb374621134 100644 --- a/test/dynamo/test_misc.py +++ b/test/dynamo/test_misc.py @@ -3380,21 +3380,6 @@ def fn4(x) -> None: self.assertTrue(same(obj41.y, obj42.y)) self.assertEqual(cnts.frame_count, 1) - def test_thread_local_setattr(self): - from threading import local - - loc = local() - - @torch.compile(fullgraph=True) - def fn(x, l): - l.x = x - return x + 1 - - x = torch.ones(2, 2) - fn(x, loc) - - self.assertTrue(loc.x is x) - def test_user_defined_class_name(self): class MyClassFoo: pass diff --git a/torch/_dynamo/variables/user_defined.py b/torch/_dynamo/variables/user_defined.py index ad92277cf70ff4..d069d2e060a108 100644 --- a/torch/_dynamo/variables/user_defined.py +++ b/torch/_dynamo/variables/user_defined.py @@ -9,7 +9,6 @@ import itertools import random import sys -import threading import types import warnings from typing import Dict, Generic, List, TYPE_CHECKING @@ -705,7 +704,7 @@ def call_method( if method is object.__init__: return ConstantVariable.create(None) - if is_standard_setattr(method) or isinstance(self.value, threading.local): + if is_standard_setattr(method): return self.method_setattr_standard(tx, *args, **kwargs) # [NOTE] OrderedDict, dict subtypes must always have source @@ -803,7 +802,7 @@ def method_setattr_standard(self, tx: "InstructionTranslator", name, value): def needs_slow_setattr(self): return not is_standard_setattr( inspect.getattr_static(self.value, "__setattr__", None) - ) and not isinstance(self.value, threading.local) + ) def unpack_var_sequence(self, tx): if ( From 5253820352458300ed4adc9ad29edb290ea67daf Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Sat, 14 Sep 2024 10:02:55 +0000 Subject: [PATCH 0890/1018] Revert "[Dynamo] Trace torch function modes entered outside of torch.compile (#133137)" This reverts commit 4528777e034b157a8329d1879daf52290eea199a. Reverted https://github.com/pytorch/pytorch/pull/133137 on behalf of https://github.com/mlazos due to broke python test/quantization/pt2e/test_numeric_debugger.py TestNumericDebugger.test_re_export_preserve_handle modified yesterday ([comment](https://github.com/pytorch/pytorch/pull/134732#issuecomment-2350937008)) --- test/dynamo/test_modes.py | 135 --------- test/functorch/test_control_flow.py | 4 - .../pt2e/test_metadata_porting.py | 4 +- torch/_dynamo/convert_frame.py | 5 - torch/_dynamo/guards.py | 14 - torch/_dynamo/source.py | 2 +- torch/_dynamo/symbolic_convert.py | 78 ++++-- torch/_dynamo/trace_rules.py | 5 +- torch/_dynamo/utils.py | 10 +- torch/_dynamo/variables/ctx_manager.py | 2 - torch/_dynamo/variables/torch.py | 260 ++++++++---------- torch/_dynamo/variables/torch_function.py | 129 ++------- torch/_dynamo/variables/user_defined.py | 7 + torch/_export/non_strict_utils.py | 6 +- 14 files changed, 204 insertions(+), 457 deletions(-) diff --git a/test/dynamo/test_modes.py b/test/dynamo/test_modes.py index bad5a788903440..ee8a9b579ff124 100644 --- a/test/dynamo/test_modes.py +++ b/test/dynamo/test_modes.py @@ -14,17 +14,6 @@ from torch.utils._python_dispatch import TorchDispatchMode -class TestMode(BaseTorchFunctionMode): - def __torch_function__(self, func, types, args, kwargs=None): - if not kwargs: - kwargs = {} - - if func == torch.add: - return torch.zeros(2, 2) - - return super().__torch_function__(func, types, args, kwargs) - - class TorchDispatchModeTests(torch._dynamo.test_case.TestCase): @classmethod def setUpClass(cls): @@ -335,130 +324,6 @@ def fn(x): fn(inp) self.assertEqual(cnt.frame_count, 2) - def test_nested_torch_function_mode(self): - mode_1_called = False - mode_2_called = False - - def reset_state(): - nonlocal mode_1_called - nonlocal mode_2_called - mode_1_called = False - mode_2_called = False - - ones = torch.ones(2, 2) - zeros = torch.zeros(2, 2) - - class TestMode1(BaseTorchFunctionMode): - def __torch_function__(self, func, types, args, kwargs=None): - if not kwargs: - kwargs = {} - - nonlocal mode_1_called - - mode_1_called = True - - if func == torch.add: - return zeros - - return super().__torch_function__(func, types, args, kwargs) - - class TestMode2(BaseTorchFunctionMode): - def __torch_function__(self, func, types, args, kwargs=None): - if not kwargs: - kwargs = {} - - nonlocal mode_2_called - - mode_2_called = True - - if func == torch.mul: - return ones - - return super().__torch_function__(func, types, args, kwargs) - - def fn(x): - return torch.add(x, 3) - - def fn_2(x): - return torch.mul(x, 3) + torch.add(x, 3) - - inp = torch.ones(2, 2) + 1 - - for fn_i in [fn, fn_2]: - fn_opt = torch.compile(fn_i, fullgraph=True) - with TestMode1(), TestMode2(): - expected = fn_i(inp), mode_1_called, mode_2_called - reset_state() - actual = fn_opt(inp), mode_1_called, mode_2_called - reset_state() - - self.assertEqual(expected, actual) - - def test_torch_function_mode_disable(self): - class TestSubclass(torch.Tensor): - @classmethod - def __torch_function__(cls, func, types, args, kwargs=None): - if not kwargs: - kwargs = {} - if func == torch.add: - return torch.ones(2, 2) - return super().__torch_function__(func, types, args, kwargs) - - class TestMode(BaseTorchFunctionMode): - def __torch_function__(self, func, types, args, kwargs=None): - if not kwargs: - kwargs = {} - - if func == torch.add: - return torch.zeros(2, 2) - - return super().__torch_function__(func, types, args, kwargs) - - def fn(x): - return torch.add(x, 3) - - inp = (torch.ones(2, 2) + 1).as_subclass(TestSubclass) - - fn_opt = torch.compile(fn, fullgraph=True) - with TestMode(), torch._dynamo.config.patch( - "traceable_tensor_subclasses", {TestSubclass} - ): - with torch._C.DisableTorchFunctionSubclass(): - expected = fn(inp) - actual = fn_opt(inp) - - self.assertEqual(expected, actual) - - with torch._C.DisableTorchFunction(): - expected = fn(inp) - actual = fn_opt(inp) - - self.assertEqual(expected, actual) - - def test_torch_function_mode_highest_priority(self): - class TestSubclass(torch.Tensor): - @classmethod - def __torch_function__(cls, func, types, args, kwargs=None): - if not kwargs: - kwargs = {} - if func == torch.add: - return torch.ones(2, 2) - return super().__torch_function__(func, types, args, kwargs) - - def fn(x): - return torch.add(x, 3) - - inp = (torch.ones(2, 2) + 1).as_subclass(TestSubclass) - - fn_opt = torch.compile(fn, fullgraph=True) - with TestMode(), torch._dynamo.config.patch( - "traceable_tensor_subclasses", {TestSubclass} - ): - expected = fn(inp) - actual = fn_opt(inp) - - self.assertEqual(expected, actual) - if __name__ == "__main__": from torch._dynamo.test_case import run_tests diff --git a/test/functorch/test_control_flow.py b/test/functorch/test_control_flow.py index cf639b6abcce4e..d870b4b921af05 100644 --- a/test/functorch/test_control_flow.py +++ b/test/functorch/test_control_flow.py @@ -26,7 +26,6 @@ parametrize, requires_cuda, run_tests, - skipIfCrossRef, skipIfRocm, skipIfTorchDynamo, TEST_WITH_TORCHDYNAMO, @@ -2883,7 +2882,6 @@ def f(fct, init, xs): gm = make_fx(f, tracing_mode="symbolic")(add_wrong_dtype, init, x) @skipIfNoDynamoSupport - @skipIfCrossRef # Arg order changes with crossref def test_scan_simple_graph(self): from torch._dynamo.testing import EagerAndRecordGraphs @@ -2990,7 +2988,6 @@ def f(x, y): self.assertEqual(graph(x, torch.tensor(True)), f(x, torch.tensor(True))) @skipIfTorchDynamo("Graph is not captured by backend if test with dynamo") - @skipIfCrossRef # Arg order changes with crossref def test_cond_simple_with_linear_compile_check_graph(self): from torch._dynamo.testing import EagerAndRecordGraphs @@ -3253,7 +3250,6 @@ def test_while_loop_compile(self, backend, while_loop_test): self._check_compile(fn, inp, backend=backend) @skipIfTorchDynamo("Graph is not captured by backend if test with dynamo") - @skipIfCrossRef # Arg order changes with cross ref def test_while_loop_simple_with_linear_compile_check_graph(self): fn, inp = WHILE_LOOP_TESTS["simple_with_linear"] from torch._dynamo.testing import EagerAndRecordGraphs diff --git a/test/quantization/pt2e/test_metadata_porting.py b/test/quantization/pt2e/test_metadata_porting.py index 21488e8cacdd42..27df58e159f3b6 100644 --- a/test/quantization/pt2e/test_metadata_porting.py +++ b/test/quantization/pt2e/test_metadata_porting.py @@ -13,7 +13,7 @@ from torch.ao.quantization.quantizer.xnnpack_quantizer_utils import OP_TO_ANNOTATOR from torch.fx import Node from torch.testing._internal.common_quantization import QuantizationTestCase -from torch.testing._internal.common_utils import IS_WINDOWS, skipIfCrossRef +from torch.testing._internal.common_utils import IS_WINDOWS class TestHelperModules: @@ -139,8 +139,6 @@ def _test_metadata_porting( self.assertEqual(v, node_tags[k]) return m - @skipIfCrossRef # mlazos: retracing FX graph with torch function mode doesn't propagate metadata, because the stack - # trace of the mode torch function impl doesn't match the traced graph stored lineno. def test_simple_metadata_porting(self): """ Model under test diff --git a/torch/_dynamo/convert_frame.py b/torch/_dynamo/convert_frame.py index b3e4a7c268c0fe..4d1557e15a7d23 100644 --- a/torch/_dynamo/convert_frame.py +++ b/torch/_dynamo/convert_frame.py @@ -605,10 +605,6 @@ def _compile( output: Optional[OutputGraph] = None tracer: Optional[InstructionTranslator] = None - tf_mode_stack: List[ - torch.overrides.TorchFunctionMode - ] = torch.overrides._get_current_function_mode_stack() - @preserve_global_state def transform( instructions: List[Instruction], code_options: Dict[str, object] @@ -622,7 +618,6 @@ def transform( locals, globals, builtins, - tf_mode_stack, code_options, compiler_fn, one_graph, diff --git a/torch/_dynamo/guards.py b/torch/_dynamo/guards.py index 4437ef4038190d..b0bd273e96bbb5 100644 --- a/torch/_dynamo/guards.py +++ b/torch/_dynamo/guards.py @@ -98,7 +98,6 @@ ScriptObjectQualifiedNameSource, ShapeEnvSource, SubclassAttrListSource, - TorchFunctionModeStackSource, TupleIteratorGetItemSource, TypeSource, UnspecializedBuiltinNNModuleSource, @@ -112,7 +111,6 @@ dict_keys_repr, get_custom_getattr, get_torch_function_mode_stack, - get_torch_function_mode_stack_at, guard_failures, istype, key_is_id, @@ -316,7 +314,6 @@ def uninteresting_files(): "___dict_contains": lambda a, b: a in b, "___tuple_iterator_len": tuple_iterator_len, "___tuple_iterator_getitem": tuple_iterator_getitem, - "___get_torch_function_mode_stack_at": get_torch_function_mode_stack_at, "__math_isnan": math.isnan, "__numpy_isnan": None if np is None else np.isnan, "inf": float("inf"), @@ -904,15 +901,6 @@ def get_guard_manager_from_source(self, source): ): assert base_guard_manager # to make mypy happy out = base_guard_manager - elif istype(source, TorchFunctionModeStackSource): - out = root_guard_manager.lambda_manager( - python_lambda=lambda _: get_torch_function_mode_stack_at( - source._get_index() - ), - source=source_name, - example_value=example_value, - guard_manager_enum=guard_manager_enum, - ) elif istype(source, GradSource): assert base_guard_manager # to make mypy happy out = base_guard_manager.grad_manager( @@ -2226,8 +2214,6 @@ def __init__( self.output_graph = output_graph w_builder = None - # NB: Until we trace device contexts, we need to use the stack recorded at the beginning of tracing - # in case a set default device call was made in the graph. self.torch_function_mode_stack = ( output_graph.torch_function_mode_stack if output_graph else None ) diff --git a/torch/_dynamo/source.py b/torch/_dynamo/source.py index c44f8d9e69abf5..6ed9d97f6a1210 100644 --- a/torch/_dynamo/source.py +++ b/torch/_dynamo/source.py @@ -619,7 +619,7 @@ class TorchFunctionModeStackSource(Source): ind: int def name(self): - return f"___get_torch_function_mode_stack_at({self._get_index()})" + return "" def _get_index(self): from .variables.torch_function import TorchFunctionModeStackVariable diff --git a/torch/_dynamo/symbolic_convert.py b/torch/_dynamo/symbolic_convert.py index 415179a5ed32a6..ab92f82aa0f630 100644 --- a/torch/_dynamo/symbolic_convert.py +++ b/torch/_dynamo/symbolic_convert.py @@ -19,7 +19,20 @@ import types import typing import weakref -from typing import Any, Callable, cast, Dict, List, Optional, Set, Tuple, Type, Union +from typing import ( + Any, + Callable, + cast, + Deque, + Dict, + List, + Optional, + Set, + Tuple, + Type, + TYPE_CHECKING, + Union, +) from unittest.mock import patch import torch @@ -59,12 +72,14 @@ GlobalWeakRefSource, LocalSource, Source, + TorchFunctionModeStackSource, ) from .trace_rules import is_builtin_constant, is_forbidden from .utils import ( counters, get_fake_value, get_instruction_source_311, + get_torch_function_mode_stack, graph_break_dup_warning_checker, istype, LazyString, @@ -105,10 +120,11 @@ ) from .variables.nn_module import NNModuleVariable, UnspecializedNNModuleVariable from .variables.tensor import supported_comparison_ops, SymNodeVariable, TensorVariable -from .variables.torch_function import ( - SymbolicTorchFunctionState, - TorchFunctionModeVariable, -) + + +if TYPE_CHECKING: + from .variables.torch_function import TorchFunctionModeVariable + from .variables.user_defined import ( RemovableHandleVariable, UserDefinedClassVariable, @@ -268,10 +284,6 @@ def resume_fn(self): return ReenterWith(self.stack_index) def exit(self, tx): - if hasattr(self, "graph_break") and isinstance( - self.with_context, TorchFunctionModeVariable - ): - return assert self.with_context is not None return self.with_context.exit(tx) @@ -640,9 +652,7 @@ def handle_graph_break( # Reconstruct the context variable CLASS in the block stack for b in self.block_stack: assert b.with_context is not None - assert isinstance( - b.with_context, (ContextWrappingVariable, TorchFunctionModeVariable) - ) + assert isinstance(b.with_context, ContextWrappingVariable) b.with_context.reconstruct_type(cg) cg.extend_output(b.resume_fn().try_except(cg.code_options, cleanup)) self.output.add_output_instructions(cg.get_instructions()) @@ -718,7 +728,7 @@ class InstructionTranslatorBase( output: OutputGraph symbolic_locals: Dict[str, VariableTracker] symbolic_globals: Dict[str, VariableTracker] - symbolic_torch_function_state: SymbolicTorchFunctionState + symbolic_torch_function_mode_stack: Deque["TorchFunctionModeVariable"] stack: List[VariableTracker] instruction_pointer: Optional[int] current_instruction: Instruction @@ -2538,7 +2548,7 @@ def __init__( code_options: Dict[str, Any], symbolic_locals: Dict[str, VariableTracker], symbolic_globals: Dict[str, VariableTracker], - symbolic_torch_function_state: SymbolicTorchFunctionState, + symbolic_torch_function_mode_stack: Deque["TorchFunctionModeVariable"], f_code: types.CodeType, export: bool, inline_depth: int, @@ -2553,7 +2563,7 @@ def __init__( self.output = output self.symbolic_locals = symbolic_locals self.symbolic_globals = symbolic_globals - self.symbolic_torch_function_state = symbolic_torch_function_state + self.symbolic_torch_function_mode_stack = symbolic_torch_function_mode_stack self.stack = [] # stack of variable names for tracking 3.13 closures self.name_stack: list[Any] = [] @@ -2642,7 +2652,6 @@ def __init__( f_locals, f_globals, f_builtins, - torch_function_mode_stack, code_options, compiler_fn, one_graph, @@ -2677,7 +2686,7 @@ def __init__( symbolic_locals={}, # set below # A global var is inserted only after a STORE_GLOBAL happens to it symbolic_globals={}, - symbolic_torch_function_state=None, # type: ignore[arg-type] # set below + symbolic_torch_function_mode_stack=collections.deque(), f_code=f_code, export=export, inline_depth=0, @@ -2712,9 +2721,7 @@ def __init__( if k in f_locals } - self.symbolic_torch_function_state = SymbolicTorchFunctionState( - torch_function_mode_stack - ) + self._init_torch_function_mode_stack() self.debug_locals: List[Tuple[VariableTracker, List[VariableTracker]]] = [] if export: @@ -2755,6 +2762,29 @@ def _throw_if_in_functorch(self): ) unimplemented(msg) + def _init_torch_function_mode_stack(self): + from .variables.torch_function import TorchFunctionModeStackVariable + + TorchFunctionModeStackVariable.reset() + + self.symbolic_torch_function_mode_stack: Deque[ + TorchFunctionModeVariable + ] = collections.deque() + # We want to retrieve all modes to properly reconstruct the stack if needed + py_stack = get_torch_function_mode_stack(filter_ignored=False) + + if py_stack: + has_device_context = isinstance( + py_stack[0], torch.utils._device.DeviceContext + ) + + for i, val in enumerate(py_stack): + self.symbolic_torch_function_mode_stack.append( + variables.LazyVariableTracker.create( + val, source=TorchFunctionModeStackSource(i) + ) + ) + def get_example_value(self, source: Source): if isinstance(source, LocalSource): return self.f_locals[source.local_name] @@ -3086,7 +3116,7 @@ def get_trace_call_log_str(): code, sub_locals, parent.symbolic_globals, - parent.symbolic_torch_function_state, + parent.symbolic_torch_function_mode_stack, closure_cells, func, ) @@ -3096,7 +3126,7 @@ def get_trace_call_log_str(): code, sub_locals, parent.symbolic_globals, - parent.symbolic_torch_function_state, + parent.symbolic_torch_function_mode_stack, closure_cells, func, ) @@ -3149,7 +3179,7 @@ def __init__( code: types.CodeType, symbolic_locals: Dict[str, VariableTracker], symbolic_globals: Dict[str, VariableTracker], - symbolic_torch_function_state: SymbolicTorchFunctionState, + symbolic_torch_function_mode_stack: Deque["TorchFunctionModeVariable"], closure_cells: Dict[str, VariableTracker], funcvar: BaseUserFunctionVariable, ) -> None: @@ -3166,7 +3196,7 @@ def __init__( f_builtins=f_builtins, symbolic_locals=symbolic_locals, symbolic_globals=symbolic_globals, - symbolic_torch_function_state=symbolic_torch_function_state, + symbolic_torch_function_mode_stack=symbolic_torch_function_mode_stack, instructions=instructions, code_options={k: getattr(code, k) for k in get_code_keys()}, f_code=code, diff --git a/torch/_dynamo/trace_rules.py b/torch/_dynamo/trace_rules.py index 1a206ae339c2db..391beb07cd600d 100644 --- a/torch/_dynamo/trace_rules.py +++ b/torch/_dynamo/trace_rules.py @@ -3258,7 +3258,6 @@ def _module_dir(m: types.ModuleType): "torch.testing", "torch.utils._content_store", "torch.utils._contextlib", - "torch.utils._device", "torch.utils._foreach_utils", "torch.utils._python_dispatch", "torch.utils._pytree", @@ -3593,9 +3592,7 @@ def lookup_inner( if reasons is not None: reasons.add("func name is patched_init") return SkipFunctionVariable - elif name == "__torch_function__" or ( - obj and obj.__name__ == "__torch_function__" - ): + elif name == "__torch_function__": if reasons is not None: reasons.add("func name is __torch_function__") return UserFunctionVariable diff --git a/torch/_dynamo/utils.py b/torch/_dynamo/utils.py index 8130b6aa7371b2..6ff258ff5657c0 100644 --- a/torch/_dynamo/utils.py +++ b/torch/_dynamo/utils.py @@ -63,6 +63,7 @@ import torch.utils._pytree as pytree from torch import fx from torch._C import ( + _get_function_stack_at, _instruction_counter, _len_torch_function_stack, _pop_torch_function_stack, @@ -3086,9 +3087,7 @@ def is_parameter_freezing(): def get_torch_function_mode_stack(filter_ignored=True): from .variables.torch_function import IGNORED_MODES - stack = [ - get_torch_function_mode_stack_at(i) for i in range(_len_torch_function_stack()) - ] + stack = [_get_function_stack_at(i) for i in range(_len_torch_function_stack())] if filter_ignored: stack = [mode for mode in stack if type(mode) not in IGNORED_MODES] @@ -3108,11 +3107,6 @@ def set_torch_function_mode_stack(stack): _push_on_torch_function_stack(mode) -def clear_torch_function_mode_stack(): - for i in range(_len_torch_function_stack()): - _pop_torch_function_stack() - - def verify_guard_fn_signature(value): fn = value.__metadata_guard__ sig = inspect.signature(fn) diff --git a/torch/_dynamo/variables/ctx_manager.py b/torch/_dynamo/variables/ctx_manager.py index 8c4eb3dc4e715b..eebe90929b379a 100644 --- a/torch/_dynamo/variables/ctx_manager.py +++ b/torch/_dynamo/variables/ctx_manager.py @@ -637,8 +637,6 @@ def enter(self, tx): def _call_func(self, tx: "InstructionTranslator", values): assert len(values) == 1 - tx.symbolic_torch_function_state.torch_function_subclass_enabled = values[0] - tx.symbolic_torch_function_state.torch_function_mode_enabled = values[0] tx.output.set_torch_function_state(values[0]) diff --git a/torch/_dynamo/variables/torch.py b/torch/_dynamo/variables/torch.py index f6cb565ef021c0..68508f001dae35 100644 --- a/torch/_dynamo/variables/torch.py +++ b/torch/_dynamo/variables/torch.py @@ -156,15 +156,6 @@ bin_ops = dict.fromkeys(["add", "sub", "mul", "div", "sqrt"]) -@functools.lru_cache(None) -def get_overridable_functions(): - from itertools import chain - - from torch.overrides import get_overridable_functions as get_overridable_functions_ - - return set(chain(*get_overridable_functions_().values())) - - class BaseTorchVariable(VariableTracker): """common base for all torch.* functions, classes, modules and other things""" @@ -815,10 +806,10 @@ def handle_pop_torch_function( self, tx: "InstructionTranslator", *args, **kwargs ): assert not args and not kwargs - if not tx.symbolic_torch_function_state.mode_stack: + if not tx.symbolic_torch_function_mode_stack: raise unimplemented("Popping from an empty torch function mode stack") TorchFunctionModeStackVariable.register_mutation(tx) - return tx.symbolic_torch_function_state.pop_torch_function_mode() + return tx.symbolic_torch_function_mode_stack.pop() @register(torch._C._push_on_torch_function_stack) def handle_push_torch_function( @@ -826,7 +817,7 @@ def handle_push_torch_function( ): assert len(args) == 1 and not kwargs TorchFunctionModeStackVariable.register_mutation(tx) - tx.symbolic_torch_function_state.push_torch_function_mode(args[0]) + tx.symbolic_torch_function_mode_stack.append(args[0]) return ConstantVariable.create(None) @register(torch._C._len_torch_function_stack) @@ -834,9 +825,7 @@ def handle_len_torch_function( self, tx: "InstructionTranslator", *args, **kwargs ): assert not args and not kwargs - return ConstantVariable.create( - len(tx.symbolic_torch_function_state.mode_stack) - ) + return ConstantVariable.create(len(tx.symbolic_torch_function_mode_stack)) @register(torch.set_default_device) def handle_set_default_device( @@ -868,9 +857,6 @@ def call_function( from . import ConstantVariable, SymNodeVariable, TensorVariable from .builder import wrap_fx_proxy - if self.torch_function_override_enabled(tx, args, kwargs): - return dispatch_torch_function(tx, self, args, kwargs) - if self.can_constant_fold_through() and check_unspec_or_constant_args( args, kwargs ): @@ -892,144 +878,147 @@ def call_function( if result: return result - any_symints_or_symfloats = any(isinstance(x, SymNodeVariable) for x in args) + if can_dispatch_torch_function(tx, args, kwargs): + return dispatch_torch_function(tx, self, args, kwargs) + else: + any_symints_or_symfloats = any(isinstance(x, SymNodeVariable) for x in args) - all_ints_or_floats = all( - isinstance(x, (variables.ConstantVariable, variables.SymNodeVariable)) - for x in args - ) - if ( - getattr(self.value, "__module__", "") == "torch" - and self.value.__name__ in bin_ops - and any_symints_or_symfloats - and all_ints_or_floats - ): - msg = f"""\ + all_ints_or_floats = all( + isinstance(x, (variables.ConstantVariable, variables.SymNodeVariable)) + for x in args + ) + if ( + getattr(self.value, "__module__", "") == "torch" + and self.value.__name__ in bin_ops + and any_symints_or_symfloats + and all_ints_or_floats + ): + msg = f"""\ Calling {str(self.value)} on only torch.SymInt arguments is not yet supported. To support this behavior, we need to allow const-propping tensors that store symint data. For now, dynamo will explicitly graph break when it encounters user code with this behavior. """ - log.warning(msg) - unimplemented(msg) - - # TODO(voz): Replace w/ dynamic shape rewrite table. - # Ideally, we would be able to do this at ctor time, but alas we need a combination - # of value + args to determine this. - fn_ = self.value - if any_symints_or_symfloats: - torch_sym_op = f"_sym_{self.value.__name__}" - if getattr(self.value, "__module__", None) == "math" and hasattr( - torch, torch_sym_op - ): - fn_ = getattr(torch, torch_sym_op) - - fake_out_shape = None - if "out" in kwargs and isinstance(kwargs["out"], variables.TensorVariable): - # Calling fake tensor propagation can mutate the out= tensor in - # tx.output.tracked_fakes. tracked_fakes are used to apply - # symbolic_shape guards. Mutating them destroys the information - # prior to tracing, which is essential for creating right - # guards. So save the shape now, and check later if it has - # changed. If it has, graph break. - fake_out_shape = kwargs["out"].proxy.node.meta["example_value"].shape - - tensor_variable = wrap_fx_proxy( - tx=tx, - proxy=tx.output.create_proxy( - "call_function", - fn_, - *proxy_args_kwargs(args, kwargs), - ), - ) + log.warning(msg) + unimplemented(msg) + + # TODO(voz): Replace w/ dynamic shape rewrite table. + # Ideally, we would be able to do this at ctor time, but alas we need a combination + # of value + args to determine this. + fn_ = self.value + if any_symints_or_symfloats: + torch_sym_op = f"_sym_{self.value.__name__}" + if getattr(self.value, "__module__", None) == "math" and hasattr( + torch, torch_sym_op + ): + fn_ = getattr(torch, torch_sym_op) + + fake_out_shape = None + if "out" in kwargs and isinstance(kwargs["out"], variables.TensorVariable): + # Calling fake tensor propagation can mutate the out= tensor in + # tx.output.tracked_fakes. tracked_fakes are used to apply + # symbolic_shape guards. Mutating them destroys the information + # prior to tracing, which is essential for creating right + # guards. So save the shape now, and check later if it has + # changed. If it has, graph break. + fake_out_shape = kwargs["out"].proxy.node.meta["example_value"].shape + + tensor_variable = wrap_fx_proxy( + tx=tx, + proxy=tx.output.create_proxy( + "call_function", + fn_, + *proxy_args_kwargs(args, kwargs), + ), + ) - if ( - isinstance(tensor_variable, TensorVariable) - and "requires_grad" in kwargs - and kwargs["requires_grad"].as_python_constant() - ): - unimplemented( - """factory functions that return tensors that require grad are not supported. + if ( + isinstance(tensor_variable, TensorVariable) + and "requires_grad" in kwargs + and kwargs["requires_grad"].as_python_constant() + ): + unimplemented( + """factory functions that return tensors that require grad are not supported. Either create the tensor outside the compiled region, or do not set the tensor to require_grad""" - ) + ) - if "out" in kwargs and not ( - isinstance(kwargs["out"], variables.ConstantVariable) - and kwargs["out"].as_python_constant() is None - ): - # out variants of torch operators like torch.sort and - # torch.sigmoid mutate the tensors in the out field. Track such - # tensors and rewrite the symbolic locals. - if isinstance(tensor_variable, TupleVariable): - assert isinstance(kwargs["out"], (TupleVariable, ListVariable)) - output_tensor_names = [ - tx.find_symbolic_locals_name(x) for x in kwargs["out"].items - ] - for idx, name in enumerate(output_tensor_names): - if name in tx.symbolic_locals: - tx.symbolic_locals[name] = tensor_variable.items[idx] - for out_tensor, result_tensor in zip( - kwargs["out"].items, tensor_variable.items - ): + if "out" in kwargs and not ( + isinstance(kwargs["out"], variables.ConstantVariable) + and kwargs["out"].as_python_constant() is None + ): + # out variants of torch operators like torch.sort and + # torch.sigmoid mutate the tensors in the out field. Track such + # tensors and rewrite the symbolic locals. + if isinstance(tensor_variable, TupleVariable): + assert isinstance(kwargs["out"], (TupleVariable, ListVariable)) + output_tensor_names = [ + tx.find_symbolic_locals_name(x) for x in kwargs["out"].items + ] + for idx, name in enumerate(output_tensor_names): + if name in tx.symbolic_locals: + tx.symbolic_locals[name] = tensor_variable.items[idx] + for out_tensor, result_tensor in zip( + kwargs["out"].items, tensor_variable.items + ): + if ( + out_tensor.source + and out_tensor in tx.output.graphargs + and isinstance(out_tensor, variables.TensorVariable) + and isinstance(result_tensor, variables.TensorVariable) + and out_tensor.size != result_tensor.size + ): + # It's hard to get out variants with resizing on graph inputs work + # properly across dynamo/aot/inductor, just fall back. + unimplemented("out variants with resizing on graph inputs") + elif isinstance(tensor_variable, TensorVariable): + assert isinstance(kwargs["out"], TensorVariable) + assert "example_value" in kwargs["out"].proxy.node.meta + fake_tensor = tensor_variable.proxy.node.meta["example_value"] + fake_out = kwargs["out"].proxy.node.meta["example_value"] if ( - out_tensor.source - and out_tensor in tx.output.graphargs - and isinstance(out_tensor, variables.TensorVariable) - and isinstance(result_tensor, variables.TensorVariable) - and out_tensor.size != result_tensor.size + kwargs["out"].source + and kwargs["out"] in tx.output.graphargs + and fake_out_shape != fake_tensor.shape ): # It's hard to get out variants with resizing on graph inputs work # properly across dynamo/aot/inductor, just fall back. unimplemented("out variants with resizing on graph inputs") - elif isinstance(tensor_variable, TensorVariable): - assert isinstance(kwargs["out"], TensorVariable) - assert "example_value" in kwargs["out"].proxy.node.meta - fake_tensor = tensor_variable.proxy.node.meta["example_value"] - fake_out = kwargs["out"].proxy.node.meta["example_value"] - if ( - kwargs["out"].source - and kwargs["out"] in tx.output.graphargs - and fake_out_shape != fake_tensor.shape - ): - # It's hard to get out variants with resizing on graph inputs work - # properly across dynamo/aot/inductor, just fall back. - unimplemented("out variants with resizing on graph inputs") - if not torch._prims_common.is_contiguous(fake_out): - # It's difficult to handle strides correctly in functionalization - # when calling an out= op with a non-contiguous out argument - unimplemented( - "out= op was called where output tensor was non-contiguous" - ) - name = tx.find_symbolic_locals_name(kwargs["out"]) - if name in tx.symbolic_locals: - tx.symbolic_locals[name] = tensor_variable - elif ( - isinstance(tensor_variable, ConstantVariable) - and tensor_variable.value is None - ): - # Handle out-variant custom ops that return None. - if isinstance(kwargs["out"], TensorVariable): - assert "example_value" in kwargs["out"].proxy.node.meta - fake_out = kwargs["out"].proxy.node.meta["example_value"] if not torch._prims_common.is_contiguous(fake_out): # It's difficult to handle strides correctly in functionalization # when calling an out= op with a non-contiguous out argument unimplemented( "out= op was called where output tensor was non-contiguous" ) - elif isinstance(kwargs["out"], ListVariable): - for idx, x in enumerate(kwargs["out"].items): - assert "example_value" in x.proxy.node.meta # type: ignore[attr-defined] - fake_out = x.proxy.node.meta["example_value"] # type: ignore[attr-defined] + name = tx.find_symbolic_locals_name(kwargs["out"]) + if name in tx.symbolic_locals: + tx.symbolic_locals[name] = tensor_variable + elif ( + isinstance(tensor_variable, ConstantVariable) + and tensor_variable.value is None + ): + # Handle out-variant custom ops that return None. + if isinstance(kwargs["out"], TensorVariable): + assert "example_value" in kwargs["out"].proxy.node.meta + fake_out = kwargs["out"].proxy.node.meta["example_value"] if not torch._prims_common.is_contiguous(fake_out): # It's difficult to handle strides correctly in functionalization # when calling an out= op with a non-contiguous out argument unimplemented( - "out= op was called where some of the output tensors were non-contiguous" + "out= op was called where output tensor was non-contiguous" ) - else: - unimplemented(f"out variant of {type(kwargs['out'])}") + elif isinstance(kwargs["out"], ListVariable): + for idx, x in enumerate(kwargs["out"].items): + assert "example_value" in x.proxy.node.meta # type: ignore[attr-defined] + fake_out = x.proxy.node.meta["example_value"] # type: ignore[attr-defined] + if not torch._prims_common.is_contiguous(fake_out): + # It's difficult to handle strides correctly in functionalization + # when calling an out= op with a non-contiguous out argument + unimplemented( + "out= op was called where some of the output tensors were non-contiguous" + ) + else: + unimplemented(f"out variant of {type(kwargs['out'])}") - return tensor_variable + return tensor_variable def _call_ntuple(self, tx: "InstructionTranslator", args, kwargs): """inline behavior of torch.nn.modules.utils._ntuple""" @@ -1157,12 +1146,3 @@ def _nn_param_via_prefix_insert(tx: "InstructionTranslator", data, requires_grad source ) return result - - def torch_function_override_enabled(self, tx, args, kwargs): - return ( - self.get_function() in get_overridable_functions() - or isinstance( - self.get_function(), - (torch._ops.OpOverload, torch._ops.OpOverloadPacket), - ) - ) and can_dispatch_torch_function(tx, args, kwargs) diff --git a/torch/_dynamo/variables/torch_function.py b/torch/_dynamo/variables/torch_function.py index 6e52cc688e0309..4b3188507fe4b5 100644 --- a/torch/_dynamo/variables/torch_function.py +++ b/torch/_dynamo/variables/torch_function.py @@ -1,11 +1,8 @@ # mypy: ignore-errors -import collections -import contextlib import inspect -from typing import Deque, Dict, List, TYPE_CHECKING +from typing import Dict, List, TYPE_CHECKING -import torch._C import torch.utils._pytree as pytree from torch._guards import Source from torch.overrides import _get_overloaded_args, get_default_nowrap_functions @@ -18,7 +15,6 @@ from .base import VariableTracker from .constant import ConstantVariable from .ctx_manager import ContextWrappingVariable -from .lazy import LazyVariableTracker from .lists import TupleVariable from .tensor import TensorSubclassVariable, TensorVariable from .user_defined import UserDefinedObjectVariable @@ -63,67 +59,6 @@ IGNORED_MODES = {DeviceContext} -class SymbolicTorchFunctionState: - def __init__(self, py_stack): - # This is annoyingly complicated because of how the torch function subclass + mode C API was designed - # There are two exposed C knobs here as contexts: torch._C.DisableTorchFunction and torch._C.DisableTorchFunctionSubclass - # These are their definitions: - # 1) torch._C._is_torch_function_enabled indicates that neither of the above knobs have been entered - # (if either are entered, this will be False) - # 2) torch._C._is_torch_function_mode_enabled indicates that either the torch mode stack is empty OR - # torch._C.DisableTorchFunction has been entered - # To disambiguate these and keep myself sane I added a C API to check whether all torch function - # concepts (modes and subclasses) are enabled. - # This only returns true iff we have not entered torch._C.DisableTorchFunction and allows us to separate - # the stack length from the enablement state of torch function modes. - # This is important because now if a mode is pushed while dynamo is tracing, we know whether - # or not torch function modes are enabled and whether we should trace it. - self.torch_function_subclass_enabled = torch._C._is_torch_function_enabled() - - # This differs from the C API of the same name - # this will only be false iff we have entered torch._C.DisableTorchFunction - # and does not take into account the mode stack length, while the C API bundles these - # two concepts - self.torch_function_mode_enabled = ( - not torch._C._is_torch_function_all_disabled() - ) - - self.cur_mode = None - - TorchFunctionModeStackVariable.reset() - - self.mode_stack: Deque[TorchFunctionModeVariable] = collections.deque() - - for i, val in enumerate(py_stack): - self.mode_stack.append( - LazyVariableTracker.create(val, source=TorchFunctionModeStackSource(i)) - ) - - def in_torch_function_mode(self): - return len(self.mode_stack) > 0 - - def pop_torch_function_mode(self): - return self.mode_stack.pop() - - def push_torch_function_mode(self, mode_var): - self.mode_stack.append(mode_var) - - def call_torch_function_mode(self, tx, fn, types, args, kwargs): - with self._pop_mode_for_inlining() as cur_mode: - return cur_mode.call_torch_function(tx, fn, types, args, kwargs) - - @contextlib.contextmanager - def _pop_mode_for_inlining(self): - old_mode = self.cur_mode - self.cur_mode = self.pop_torch_function_mode() - try: - yield self.cur_mode - finally: - mode = self.cur_mode - self.cur_mode = old_mode - self.push_torch_function_mode(mode) - - class TorchFunctionModeStackVariable(VariableTracker): """Fake VT to use as a dummy object, indicating the presence of torch function mode stack mutation""" @@ -153,20 +88,19 @@ def reset(cls): def register_mutation(cls, tx: "InstructionTranslator"): if cls.stack_value_singleton not in tx.output.side_effects: var = cls( - source=Source(), - symbolic_stack=tx.symbolic_torch_function_state.mode_stack, + source=Source(), symbolic_stack=tx.symbolic_torch_function_mode_stack ) tx.output.side_effects.track_mutable(cls.stack_value_singleton, var) tx.output.side_effects.mutation(var) @classmethod def register_device_context_insertion(cls, tx: "InstructionTranslator"): - stack = tx.symbolic_torch_function_state.mode_stack + stack = tx.symbolic_torch_function_mode_stack if stack and cls.is_device_context(stack[0]): return else: cls.offset += 1 - stack.insert( + tx.symbolic_torch_function_mode_stack.insert( 0, TorchFunctionModeVariable( None, source=TorchFunctionModeStackSource(-cls.offset) @@ -175,7 +109,7 @@ def register_device_context_insertion(cls, tx: "InstructionTranslator"): @classmethod def clear_default_device(cls, tx: "InstructionTranslator"): - stack = tx.symbolic_torch_function_state.mode_stack + stack = tx.symbolic_torch_function_mode_stack if stack and cls.is_device_context(stack[0]): stack.popleft() cls.offset -= 1 @@ -190,39 +124,23 @@ def get_mode_index(cls, ind): class TorchFunctionModeVariable(ContextWrappingVariable): - def __init__(self, value, source=None, **kwargs): + def __init__(self, value, **kwargs): super().__init__(value, **kwargs) self.value = value - self.cm_obj = value # needed for BC with calling enter from CM code - self.source = source + + @staticmethod + def get_global_mangled_name(tx, val): + return get_safe_global_name( + tx, f"__torch_function_mode_{val.__class__.__name__}", val + ) def reconstruct(self, codegen): - # This shouldn't be called unless we have a source + # We don't support locally created torch function modes yet assert self.source self.source.reconstruct(codegen) - def module_name(self): - return self.value.__module__ - - def fn_name(self): - return type(self.value).__name__ - - def python_type(self): - return type(self.value) - - def call_torch_function(self, tx: "InstructionTranslator", fn, types, args, kwargs): - return call_torch_function( - tx, - self, - build_torch_function_fn(tx, self.value, self.source), - fn, - types, - args, - kwargs, - ) - - def _call_func(self, tx: "InstructionTranslator", values): - unimplemented("enter/exit for torch function mode NYI") + def _call_func(self, tx, values): + unimplemented("torch function mode context manager is not supported yet") def _get_all_args(args, kwargs): @@ -313,13 +231,9 @@ def build_torch_function_fn(tx: "InstructionTranslator", value, source): def can_dispatch_torch_function(tx: "InstructionTranslator", args, kwargs): - has_overridden_args = any( + return tx.output.torch_function_enabled and any( has_torch_function(arg) for arg in _get_all_args(args, kwargs) ) - tf_state = tx.symbolic_torch_function_state - return (has_overridden_args and tf_state.torch_function_subclass_enabled) or ( - tf_state.torch_function_mode_enabled and tf_state.in_torch_function_mode() - ) def dispatch_torch_function(tx: "InstructionTranslator", fn, args, kwargs): @@ -331,20 +245,11 @@ def dispatch_torch_function(tx: "InstructionTranslator", fn, args, kwargs): _get_subclass_type, ) - types = TupleVariable([_get_subclass_type_var(tx, arg) for arg in overloaded_args]) - - if tx.symbolic_torch_function_state.in_torch_function_mode(): - res = tx.symbolic_torch_function_state.call_torch_function_mode( - tx, fn, types, args, kwargs - ) - if not (isinstance(res, ConstantVariable) and res.value is NotImplemented): - return res - for arg in overloaded_args: res = arg.call_torch_function( tx, fn, - types, + TupleVariable([_get_subclass_type_var(tx, arg) for arg in overloaded_args]), args, kwargs, ) diff --git a/torch/_dynamo/variables/user_defined.py b/torch/_dynamo/variables/user_defined.py index d069d2e060a108..34e1d0d10c9f72 100644 --- a/torch/_dynamo/variables/user_defined.py +++ b/torch/_dynamo/variables/user_defined.py @@ -82,6 +82,11 @@ def is_forbidden_context_manager(ctx): from _pytest.python_api import RaisesContext from _pytest.recwarn import WarningsChecker + # TODO mlazos: Temporary to get this stack to pass + # remove in subsequent PR + from torch.overrides import BaseTorchFunctionMode + + f_ctxs.append(BaseTorchFunctionMode) f_ctxs.append(RaisesContext) f_ctxs.append(WarningsChecker) except ImportError: @@ -408,6 +413,7 @@ def call_function( and self.source and not is_forbidden_context_manager(self.value) ): + # import here to avoid an unfortunate circular dependency. from .ctx_manager import GenericContextWrappingVariable cm_obj = tx.output.side_effects.track_object_new( @@ -415,6 +421,7 @@ def call_function( ) cm_obj.call_method(tx, "__init__", args, kwargs) return cm_obj + elif is_namedtuple_cls(self.value): fields = namedtuple_fields(self.value) # check if this a quasi-namedtuple or a real one diff --git a/torch/_export/non_strict_utils.py b/torch/_export/non_strict_utils.py index ef15e5fea9e973..5c7331659d26a0 100644 --- a/torch/_export/non_strict_utils.py +++ b/torch/_export/non_strict_utils.py @@ -506,11 +506,7 @@ class _NonStrictTorchFunctionHandler(torch.overrides.TorchFunctionMode): def __torch_function__(self, func, types, args=(), kwargs=None): kwargs = kwargs or {} - if ( - not torch.compiler.is_dynamo_compiling() - and log.isEnabledFor(logging.DEBUG) - and config.extended_debug_current_loc - ): + if log.isEnabledFor(logging.DEBUG) and config.extended_debug_current_loc: frame = _find_user_code_frame() if frame is not None: log.debug( From 5b5901da4c493e6d307a25b7cc9f40e9fbdfb300 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Sat, 14 Sep 2024 10:02:55 +0000 Subject: [PATCH 0891/1018] Revert "[Dynamo] Use custom backend to reenter metadata tf mode when tracing while/cond (#134732)" This reverts commit 731b178b56c83966d6e8cdfb0015d22d8f91b4d2. Reverted https://github.com/pytorch/pytorch/pull/134732 on behalf of https://github.com/mlazos due to broke python test/quantization/pt2e/test_numeric_debugger.py TestNumericDebugger.test_re_export_preserve_handle modified yesterday ([comment](https://github.com/pytorch/pytorch/pull/134732#issuecomment-2350937008)) --- torch/_dynamo/backends/debugging.py | 12 ----- torch/_higher_order_ops/cond.py | 20 +++----- torch/_higher_order_ops/scan.py | 16 ++----- torch/_higher_order_ops/while_loop.py | 21 ++------- torch/fx/experimental/proxy_tensor.py | 67 +++++++++++---------------- torch/nn/attention/flex_attention.py | 42 ++++++----------- 6 files changed, 54 insertions(+), 124 deletions(-) diff --git a/torch/_dynamo/backends/debugging.py b/torch/_dynamo/backends/debugging.py index 18784498862e39..87214f7d59774e 100644 --- a/torch/_dynamo/backends/debugging.py +++ b/torch/_dynamo/backends/debugging.py @@ -31,18 +31,6 @@ def eager(gm, fake_tensor_inputs, **kwargs): return gm.forward -def make_eager_backend_with_torch_function_mode(mode): - """Used to trace HOPs (cond and while) for eager exectution, the metadata - TF mode mutates vars outside of the scope of the HOP, and we can't have graph breaks - in the HOP, so we need to externally run this mode and not trace it.""" - - def fn(gm, fake_tensor_inputs, **kwargs): - with mode: - return gm.forward - - return fn - - @register_backend def eager_noexcept(gm, fake_tensor_inputs, **kwargs): if kwargs: diff --git a/torch/_higher_order_ops/cond.py b/torch/_higher_order_ops/cond.py index 0467e2899adc28..dee400d76f5964 100644 --- a/torch/_higher_order_ops/cond.py +++ b/torch/_higher_order_ops/cond.py @@ -28,7 +28,6 @@ from torch._subclasses.fake_tensor import FakeTensorMode from torch._subclasses.functional_tensor import disable_functional_mode from torch.fx.experimental.proxy_tensor import ( - _temp_remove_metadata_torch_function_mode, _temp_remove_pre_dispatch_torch_function_mode, disable_proxy_modes_tracing, ProxyTorchDispatchMode, @@ -130,10 +129,6 @@ def false_fn(x: torch.Tensor): if torch.compiler.is_dynamo_compiling(): return cond_op(pred, true_fn, false_fn, operands) - from torch._dynamo.backends.debugging import ( - make_eager_backend_with_torch_function_mode, - ) - if isinstance(pred, (bool, int, float)): log.warning( "Pred is a Python constant. When used with torch.cond, it executes only one of the branches." @@ -174,15 +169,12 @@ def _validate_input(pred, true_fn, false_fn, operands): def _cond_op_wrapper(*args, **kwargs): return cond_op(*args, **kwargs) - with _set_compilation_env(), torch._dynamo.utils.disable_cache_limit(), _temp_remove_pre_dispatch_torch_function_mode(): - with _temp_remove_metadata_torch_function_mode() as metadata_mode: - if metadata_mode: - backend = make_eager_backend_with_torch_function_mode(metadata_mode) - else: - backend = "eager" - return torch.compile(_cond_op_wrapper, backend=backend, fullgraph=True)( - pred, true_fn, false_fn, operands - ) + with _set_compilation_env(): + with torch._dynamo.utils.disable_cache_limit(): + with _temp_remove_pre_dispatch_torch_function_mode(): + return torch.compile(_cond_op_wrapper, backend="eager", fullgraph=True)( + pred, true_fn, false_fn, operands + ) def create_fw_bw_graph_branches(true_fn, false_fn, *operands): diff --git a/torch/_higher_order_ops/scan.py b/torch/_higher_order_ops/scan.py index d66cff067f6682..d6eecb0a9f29f9 100644 --- a/torch/_higher_order_ops/scan.py +++ b/torch/_higher_order_ops/scan.py @@ -20,7 +20,6 @@ from torch._ops import HigherOrderOperator from torch._subclasses.fake_tensor import FakeTensorMode from torch.fx.experimental.proxy_tensor import ( - _temp_remove_metadata_torch_function_mode, disable_proxy_modes_tracing, ProxyTorchDispatchMode, track_tensor_tree, @@ -119,19 +118,10 @@ def _scan_op_wrapper(*args, **kwargs): return scan(*args, **kwargs) if not torch._dynamo.is_compiling(): - from torch._dynamo.backends.debugging import ( - make_eager_backend_with_torch_function_mode, - ) - with _set_compilation_env(), torch._dynamo.utils.disable_cache_limit(): - with _temp_remove_metadata_torch_function_mode() as metadata_mode: - if metadata_mode: - backend = make_eager_backend_with_torch_function_mode(metadata_mode) - else: - backend = "eager" - return torch.compile(_scan_op_wrapper, backend=backend, fullgraph=True)( - combine_fn, init, xs, dim=dim, reverse=reverse - ) + return torch.compile(_scan_op_wrapper, backend="eager", fullgraph=True)( + combine_fn, init, xs, dim=dim, reverse=reverse + ) leaves_init, spec_init = pytree.tree_flatten(init) leaves_xs, spec_xs = pytree.tree_flatten(xs) diff --git a/torch/_higher_order_ops/while_loop.py b/torch/_higher_order_ops/while_loop.py index f14321842f40b9..d64f512399d9d5 100644 --- a/torch/_higher_order_ops/while_loop.py +++ b/torch/_higher_order_ops/while_loop.py @@ -15,11 +15,7 @@ ) from torch._ops import HigherOrderOperator from torch._subclasses.fake_tensor import FakeTensorMode -from torch.fx.experimental.proxy_tensor import ( - _temp_remove_metadata_torch_function_mode, - ProxyTorchDispatchMode, - track_tensor_tree, -) +from torch.fx.experimental.proxy_tensor import ProxyTorchDispatchMode, track_tensor_tree class WhileLoopOp(HigherOrderOperator): @@ -117,9 +113,6 @@ def body_fn(iter, x): - 'while_loop' only supports **inference** right now. Autograd will be supported in the future. """ - from torch._dynamo.backends.debugging import ( - make_eager_backend_with_torch_function_mode, - ) # Currently, additional_inputs is not a user-facing input. It will be automatically set in dynamo. # parameters and buffers accessed in cond_fn or body_fn or tensor closures will become additional_inputs. @@ -147,15 +140,9 @@ def _while_loop_op_wrapper(*args, **kwargs): return while_loop_op(*args, **kwargs) with _set_compilation_env(), torch._dynamo.utils.disable_cache_limit(): - with _temp_remove_metadata_torch_function_mode() as metadata_mode: - with _temp_remove_metadata_torch_function_mode() as metadata_mode: - if metadata_mode: - backend = make_eager_backend_with_torch_function_mode(metadata_mode) - else: - backend = "eager" - return torch.compile( - _while_loop_op_wrapper, backend=backend, fullgraph=True - )(cond_fn, body_fn, carried_inputs, additional_inputs) + return torch.compile(_while_loop_op_wrapper, backend="eager", fullgraph=True)( + cond_fn, body_fn, carried_inputs, additional_inputs + ) @while_loop_op.py_impl(DispatchKey.CompositeExplicitAutograd) diff --git a/torch/fx/experimental/proxy_tensor.py b/torch/fx/experimental/proxy_tensor.py index 213672e216e6ce..5a5771d2ae8f7a 100644 --- a/torch/fx/experimental/proxy_tensor.py +++ b/torch/fx/experimental/proxy_tensor.py @@ -17,7 +17,7 @@ import warnings import weakref from collections import defaultdict -from contextlib import _GeneratorContextManager, contextmanager, ExitStack, nullcontext +from contextlib import contextmanager, ExitStack, nullcontext from dataclasses import dataclass from typing import ( Any, @@ -1084,43 +1084,38 @@ def unwrap_proxy(self, e: T) -> object: return e -def _make_temp_remove_mode_context_manager( - mode_ty: Type[TorchFunctionMode], -) -> Callable[[], _GeneratorContextManager[Optional[TorchFunctionMode]]]: - @contextmanager - def context_manager_fn() -> Generator[Optional[TorchFunctionMode], None, None]: - from torch.overrides import _len_torch_function_stack, _pop_mode, _push_mode - - temp_elements = [] - removed_mode = None +@contextmanager +def _temp_remove_pre_dispatch_torch_function_mode() -> Generator[None, None, None]: + from torch.overrides import _len_torch_function_stack, _pop_mode, _push_mode - while _len_torch_function_stack() > 0: - mode = _pop_mode() - if isinstance(mode, mode_ty): - removed_mode = mode - break - else: - temp_elements.append(mode) + temp_elements = [] + pre_dispatch_mode = None - for mode in reversed(temp_elements): - _push_mode(mode) + while _len_torch_function_stack() > 0: + mode = _pop_mode() + if isinstance(mode, PreDispatchTorchFunctionMode): + pre_dispatch_mode = mode + break + else: + temp_elements.append(mode) - try: - yield removed_mode + for mode in reversed(temp_elements): + _push_mode(mode) - finally: - if removed_mode is not None: - count = len(temp_elements) - while count > 0: - mode = _pop_mode() - count -= 1 + try: + yield - temp_elements.append(removed_mode) + finally: + if pre_dispatch_mode is not None: + count = len(temp_elements) + while count > 0: + mode = _pop_mode() + count -= 1 - for mode in reversed(temp_elements): - _push_mode(mode) + temp_elements.append(pre_dispatch_mode) - return context_manager_fn + for mode in reversed(temp_elements): + _push_mode(mode) @torch._disable_dynamo @@ -1235,11 +1230,6 @@ def __torch_function__( return func(*args, **kwargs) -_temp_remove_metadata_torch_function_mode = _make_temp_remove_mode_context_manager( - TorchFunctionMetadataMode -) - - # This mode is **only** used for pre_dispatch tracing. # In particular, we need to make sure that autograd/autocast API's # that do not desugar into dispatcher operators stay in the graph. @@ -1268,11 +1258,6 @@ def __torch_function__( return func(*args, **kwargs) -_temp_remove_pre_dispatch_torch_function_mode = _make_temp_remove_mode_context_manager( - PreDispatchTorchFunctionMode -) - - class ProxyTorchDispatchMode(TorchDispatchMode): # Ensure this is read-only; this exists only for legacy reasons @property diff --git a/torch/nn/attention/flex_attention.py b/torch/nn/attention/flex_attention.py index 2317dbac83bdda..bb04ce53d0fcab 100644 --- a/torch/nn/attention/flex_attention.py +++ b/torch/nn/attention/flex_attention.py @@ -19,7 +19,6 @@ ) from torch._higher_order_ops.utils import _set_compilation_env from torch.fx.experimental.proxy_tensor import ( - _temp_remove_metadata_torch_function_mode, _temp_remove_pre_dispatch_torch_function_mode, ) from torch.nn.attention._utils import _supported_head_dim, _validate_sdpa_input @@ -1034,10 +1033,6 @@ def score_mod( if not torch._dynamo.is_dynamo_supported(): raise RuntimeError("flex_attention requires dynamo support") - from torch._dynamo.backends.debugging import ( - make_eager_backend_with_torch_function_mode, - ) - # Dynamo is expecting a callable with "__code__" attribute. # We cannot directly pass hop to it. So we wrap it in a dummy function. def _flex_attention_hop_wrapper(*args, **kwargs): @@ -1046,25 +1041,18 @@ def _flex_attention_hop_wrapper(*args, **kwargs): with _set_compilation_env(): with torch._dynamo.utils.disable_cache_limit(): with _temp_remove_pre_dispatch_torch_function_mode(): - with _temp_remove_metadata_torch_function_mode() as metadata_mode: - if metadata_mode: - backend = make_eager_backend_with_torch_function_mode( - metadata_mode - ) - else: - backend = "eager" - out, lse = torch.compile( - _flex_attention_hop_wrapper, backend=backend, fullgraph=True - )( - query, - key, - value, - score_mod, - block_mask.as_tuple(), - scale, - kernel_options, - ) - if return_lse: - return out, lse * math.log(2) - else: - return out + out, lse = torch.compile( + _flex_attention_hop_wrapper, backend="eager", fullgraph=True + )( + query, + key, + value, + score_mod, + block_mask.as_tuple(), + scale, + kernel_options, + ) + if return_lse: + return out, lse * math.log(2) + else: + return out From 28fe3fa6a575b55a726aa6cb6242ac68525246b3 Mon Sep 17 00:00:00 2001 From: CaoE Date: Sat, 14 Sep 2024 14:18:55 +0000 Subject: [PATCH 0892/1018] Add Half support for reflection and replication padding on CPU (#135931) Fixes #135680 Pull Request resolved: https://github.com/pytorch/pytorch/pull/135931 Approved by: https://github.com/Skylion007 --- aten/src/ATen/native/cpu/PaddingKernel.cpp | 32 +++++++++---------- test/test_mps.py | 7 ++++ .../_internal/common_methods_invocations.py | 9 ++---- 3 files changed, 26 insertions(+), 22 deletions(-) diff --git a/aten/src/ATen/native/cpu/PaddingKernel.cpp b/aten/src/ATen/native/cpu/PaddingKernel.cpp index 302346c4515c9b..1aabb8a3d50d26 100644 --- a/aten/src/ATen/native/cpu/PaddingKernel.cpp +++ b/aten/src/ATen/native/cpu/PaddingKernel.cpp @@ -486,7 +486,7 @@ void reflection_pad1d_kernel_impl(const Tensor& output, const Tensor& input, Int cpu_padding(output, input, param); }); } else { - AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND(kBFloat16, input.scalar_type(), + AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(kBFloat16, kHalf, input.scalar_type(), "reflection_pad1d", [&] { cpu_padding(output, input, param); }); @@ -496,7 +496,7 @@ void reflection_pad1d_kernel_impl(const Tensor& output, const Tensor& input, Int void reflection_pad1d_backward_kernel_impl( const Tensor& grad_input, const Tensor& grad_output, IntArrayRef padding) { PaddingParams param{grad_input, grad_output, padding}; - AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND1(kBFloat16, grad_output.scalar_type(), + AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(kBFloat16, kHalf, grad_output.scalar_type(), "reflection_pad1d_backward", [&] { cpu_padding_backward(grad_input, grad_output, param); }); @@ -513,14 +513,14 @@ void reflection_pad2d_kernel_impl(const Tensor& output, const Tensor& input, Int } else { switch (input.suggest_memory_format()) { case at::MemoryFormat::Contiguous: { - AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND(kBFloat16, input.scalar_type(), + AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(kBFloat16, kHalf, input.scalar_type(), "reflection_pad2d", [&] { cpu_padding(output, input, param); }); break; } case at::MemoryFormat::ChannelsLast: { - AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND(kBFloat16, input.scalar_type(), + AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(kBFloat16, kHalf, input.scalar_type(), "reflection_pad2d_channels_last", [&]{ cpu_padding_channels_last(output, input, param); }); @@ -537,14 +537,14 @@ void reflection_pad2d_backward_kernel_impl( PaddingParams param{grad_input, grad_output, padding}; switch (grad_output.suggest_memory_format()) { case at::MemoryFormat::Contiguous: { - AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND1(kBFloat16, grad_output.scalar_type(), + AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(kBFloat16, kHalf, grad_output.scalar_type(), "reflection_pad2d_backward", [&] { cpu_padding_backward(grad_input, grad_output, param); }); break; } case at::MemoryFormat::ChannelsLast: { - AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND1(kBFloat16, grad_output.scalar_type(), + AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(kBFloat16, kHalf, grad_output.scalar_type(), "reflection_pad2d_backward_channels_last", [&]{ cpu_padding_backward_channels_last(grad_input, grad_output, param); }); @@ -603,7 +603,7 @@ void reflection_pad3d_backward_kernel_impl( // replication padding void replication_pad1d_kernel_impl(const Tensor& output, const Tensor& input, IntArrayRef padding) { PaddingParams param{input, output, padding}; - AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND(kBFloat16, input.scalar_type(), + AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(kBFloat16, kHalf,input.scalar_type(), "replication_pad1d", [&] { cpu_padding(output, input, param); }); @@ -612,7 +612,7 @@ void replication_pad1d_kernel_impl(const Tensor& output, const Tensor& input, In void replication_pad1d_backward_kernel_impl( const Tensor& grad_input, const Tensor& grad_output, IntArrayRef padding) { PaddingParams param{grad_input, grad_output, padding}; - AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND1(kBFloat16, grad_output.scalar_type(), + AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(kBFloat16, kHalf, grad_output.scalar_type(), "replication_pad1d_backward", [&] { cpu_padding_backward(grad_input, grad_output, param); }); @@ -622,14 +622,14 @@ void replication_pad2d_kernel_impl(const Tensor& output, const Tensor& input, In PaddingParams param{input, output, padding}; switch (input.suggest_memory_format()) { case at::MemoryFormat::Contiguous: { - AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND(kBFloat16, input.scalar_type(), + AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(kBFloat16, kHalf, input.scalar_type(), "replication_pad2d", [&] { cpu_padding(output, input, param); }); break; } case at::MemoryFormat::ChannelsLast: { - AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND(kBFloat16, input.scalar_type(), + AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(kBFloat16, kHalf, input.scalar_type(), "replication_pad2d_channels_last", [&]{ cpu_padding_channels_last(output, input, param); }); @@ -645,14 +645,14 @@ void replication_pad2d_backward_kernel_impl( PaddingParams param{grad_input, grad_output, padding}; switch (grad_output.suggest_memory_format()) { case at::MemoryFormat::Contiguous: { - AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND1(kBFloat16, grad_output.scalar_type(), + AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(kBFloat16, kHalf, grad_output.scalar_type(), "replication_pad2d_backward", [&] { cpu_padding_backward(grad_input, grad_output, param); }); break; } case at::MemoryFormat::ChannelsLast: { - AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND1(kBFloat16, grad_output.scalar_type(), + AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(kBFloat16, kHalf, grad_output.scalar_type(), "replication_pad2d_backward_channels_last", [&]{ cpu_padding_backward_channels_last(grad_input, grad_output, param); }); @@ -667,14 +667,14 @@ void replication_pad3d_kernel_impl(const Tensor& output, const Tensor& input, In PaddingParams param{input, output, padding}; switch (padding_memory_format_3d(input)) { case at::MemoryFormat::Contiguous: { - AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND(kBFloat16, input.scalar_type(), + AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(kBFloat16, kHalf, input.scalar_type(), "replication_pad3d", [&] { cpu_padding(output, input, param); }); break; } case at::MemoryFormat::ChannelsLast3d: { - AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND(kBFloat16, input.scalar_type(), + AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(kBFloat16, kHalf, input.scalar_type(), "replication_pad3d_channels_last", [&]{ cpu_padding_channels_last(output, input, param); }); @@ -690,14 +690,14 @@ void replication_pad3d_backward_kernel_impl( PaddingParams param{grad_input, grad_output, padding}; switch (padding_memory_format_3d(grad_output)) { case at::MemoryFormat::Contiguous: { - AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND1(kBFloat16, grad_output.scalar_type(), + AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(kBFloat16, kHalf, grad_output.scalar_type(), "replication_pad3d_backward", [&] { cpu_padding_backward(grad_input, grad_output, param); }); break; } case at::MemoryFormat::ChannelsLast3d: { - AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND1(kBFloat16, grad_output.scalar_type(), + AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(kBFloat16, kHalf, grad_output.scalar_type(), "replication_pad3d_backward_channels_last", [&]{ cpu_padding_backward_channels_last(grad_input, grad_output, param); }); diff --git a/test/test_mps.py b/test/test_mps.py index 2c3039a1602a77..32adede8dddca0 100644 --- a/test/test_mps.py +++ b/test/test_mps.py @@ -12216,6 +12216,13 @@ def req_grad(t): cpu_grad_inputs = torch.autograd.grad(diff_cpu_out, diff_cpu_arg, grad_outputs=cpu_grad_outputs, allow_unused=True) mps_grad_inputs = torch.autograd.grad(diff_mps_out, diff_mps_arg, grad_outputs=mps_grad_outputs, allow_unused=True) + if ( + op.name == "nn.functional.pad" + and op.variant_test_name in ["replicate", "reflect"] + and dtype == torch.float16 + ): + atol = 1e-5 + rtol = 1.5e-3 self.assertEqual(cpu_grad_inputs, mps_grad_inputs, atol=atol, rtol=rtol) diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py index a174154349f8e4..5f4b98aaa5eaac 100644 --- a/torch/testing/_internal/common_methods_invocations.py +++ b/torch/testing/_internal/common_methods_invocations.py @@ -15220,8 +15220,7 @@ def sample_inputs_alias_copy(op_info, device, dtype, requires_grad, **kwargs): variant_test_name='reflect', supports_forward_ad=True, supports_fwgrad_bwgrad=True, - dtypes=all_types_and_complex_and(torch.bfloat16), - dtypesIfCUDA=all_types_and_complex_and(torch.half, torch.bfloat16), + dtypes=all_types_and_complex_and(torch.bfloat16, torch.half), sample_inputs_func=partial(sample_inputs_nn_pad, mode='reflect'), skips=( # Doesn't have a corresponding aten operator. @@ -15235,8 +15234,7 @@ def sample_inputs_alias_copy(op_info, device, dtype, requires_grad, **kwargs): variant_test_name='replicate', supports_forward_ad=True, supports_fwgrad_bwgrad=True, - dtypes=all_types_and_complex_and(torch.bfloat16), - dtypesIfCUDA=all_types_and_complex_and(torch.half, torch.bfloat16), + dtypes=all_types_and_complex_and(torch.half, torch.bfloat16), sample_inputs_func=partial(sample_inputs_nn_pad, mode='replicate'), skips=( # Doesn't have a corresponding aten operator. @@ -15250,8 +15248,7 @@ def sample_inputs_alias_copy(op_info, device, dtype, requires_grad, **kwargs): variant_test_name='replicate_negative', supports_forward_ad=True, supports_fwgrad_bwgrad=True, - dtypes=all_types_and_complex_and(torch.bfloat16), - dtypesIfCUDA=all_types_and_complex_and(torch.half, torch.bfloat16), + dtypes=all_types_and_complex_and(torch.half, torch.bfloat16), sample_inputs_func=sample_inputs_nn_pad_replicate_negative, skips=( # Doesn't have a corresponding aten operator. From 9ef8f627f2431b42c8c2104a8394fe818a8cb584 Mon Sep 17 00:00:00 2001 From: William Wen Date: Fri, 13 Sep 2024 13:54:26 -0700 Subject: [PATCH 0893/1018] [3.13] fix 3.13 pickle error in torch/package (#136049) Pull Request resolved: https://github.com/pytorch/pytorch/pull/136049 Approved by: https://github.com/albanD ghstack dependencies: #136034 --- torch/package/_package_pickler.py | 16 +++++++++++++--- 1 file changed, 13 insertions(+), 3 deletions(-) diff --git a/torch/package/_package_pickler.py b/torch/package/_package_pickler.py index 7845ffe39a2a2e..8856ad6c37ccf0 100644 --- a/torch/package/_package_pickler.py +++ b/torch/package/_package_pickler.py @@ -8,7 +8,6 @@ EXT2, EXT4, GLOBAL, - Pickler, PicklingError, STACK_GLOBAL, ) @@ -18,7 +17,18 @@ from .importer import Importer, ObjMismatchError, ObjNotFoundError, sys_importer -class PackagePickler(_Pickler): +class _PyTorchLegacyPickler(_Pickler): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._persistent_id = None + + def persistent_id(self, obj): + if self._persistent_id is None: + return super().persistent_id(obj) + return self._persistent_id(obj) + + +class PackagePickler(_PyTorchLegacyPickler): """Package-aware pickler. This behaves the same as a normal pickler, except it uses an `Importer` @@ -113,6 +123,6 @@ def create_pickler(data_buf, importer, protocol=4): if importer is sys_importer: # if we are using the normal import library system, then # we can use the C implementation of pickle which is faster - return Pickler(data_buf, protocol=protocol) + return _PyTorchLegacyPickler(data_buf, protocol=protocol) else: return PackagePickler(importer, data_buf, protocol=protocol) From d3eb38acb1b77df3389ff271093f7206dce34faa Mon Sep 17 00:00:00 2001 From: Suresh Babu Kolla Date: Sat, 14 Sep 2024 16:35:22 +0000 Subject: [PATCH 0894/1018] [Pytorch] Consolidate Strobelight compile time profiler between OSS and fbcode (#135953) Summary: Move towards consolidating strobelight profiler implementations between OSS and fbcode. This change is a first step towards that. - Created a new function to abstract out compile time profiling enablement. This function allows profiler to switch between different function profilers (e.g. Thrift based or CLI based) - Both OSS and Fbcode now use one compile time profiler in torch/_strobelight Test Plan: Tested OSS with following commands: ``` python torch/_strobelight/examples/compile_time_profile_example.py python torch/_strobelight/examples/cli_function_profiler_example.py TORCH_COMPILE_STROBELIGHT=TRUE TORCHINDUCTOR_FORCE_DISABLE_CACHES=1 python benchmarks/dynamo/huggingface.py --ci --accuracy --timing --explain --inductor --device cuda --training --amp --only XLNetLMHeadModel ``` See test commands for fbcode in comments. Differential Revision: D62444551 Pull Request resolved: https://github.com/pytorch/pytorch/pull/135953 Approved by: https://github.com/laithsakka --- .../examples/compile_time_profile_example.py | 5 ++--- torch/_utils_internal.py | 20 +++++++++---------- 2 files changed, 11 insertions(+), 14 deletions(-) diff --git a/torch/_strobelight/examples/compile_time_profile_example.py b/torch/_strobelight/examples/compile_time_profile_example.py index f158d00a29f929..26f659a1dbf229 100644 --- a/torch/_strobelight/examples/compile_time_profile_example.py +++ b/torch/_strobelight/examples/compile_time_profile_example.py @@ -1,11 +1,10 @@ # mypy: allow-untyped-defs import torch -from torch._strobelight.compile_time_profiler import StrobelightCompileTimeProfiler +from torch._utils_internal import enable_compiletime_strobelight if __name__ == "__main__": - # You can pass TORCH_COMPILE_STROBELIGHT=True instead. - StrobelightCompileTimeProfiler.enable() + enable_compiletime_strobelight() def fn(x, y, z): return x * y + z diff --git a/torch/_utils_internal.py b/torch/_utils_internal.py index f254217452061f..0c99601392f3d2 100644 --- a/torch/_utils_internal.py +++ b/torch/_utils_internal.py @@ -12,17 +12,6 @@ log = logging.getLogger(__name__) -if os.environ.get("TORCH_COMPILE_STROBELIGHT", False): - import shutil - - if not shutil.which("strobeclient"): - log.info( - "TORCH_COMPILE_STROBELIGHT is true, but seems like you are not on a FB machine." - ) - else: - log.info("Strobelight profiler is enabled via environment variable") - StrobelightCompileTimeProfiler.enable() - # this arbitrary-looking assortment of functionality is provided here # to have a central place for overrideable behavior. The motivating # use is the FB build environment, where this source file is replaced @@ -75,6 +64,15 @@ def throw_abstract_impl_not_imported_error(opname, module, context): ) +def enable_compiletime_strobelight(): + StrobelightCompileTimeProfiler.enable() + + +if os.environ.get("TORCH_COMPILE_STROBELIGHT", False): + log.info("Strobelight profiler is enabled via environment variable") + enable_compiletime_strobelight() + + # NB! This treats "skip" kwarg specially!! def compile_time_strobelight_meta(phase_name): def compile_time_strobelight_meta_inner(function): From f448c8292dc3aebc42ac4c8a23aa778db66ce568 Mon Sep 17 00:00:00 2001 From: Laith Sakka Date: Fri, 13 Sep 2024 16:31:26 -0700 Subject: [PATCH 0895/1018] Only keep ListOfLinears module in basic_modules_benchmarks and add gpu version. (#135730) All of the previous benchmarks are similar, ListOfLinears should be representative enough. I copied the previous benchmarks from unit tests without an intention, was just trying to create a large number of benchmarks to better observe noise. This PR keeps only one, we can add more as we see value and regressions in the future. Also this diff adds a GPU version. ``` collecting compile time instruction count for basic_modules_ListOfLinears_eager compile time instruction count for iteration 0 is 6479525851 compile time instruction count for iteration 1 is 1024432680 compile time instruction count for iteration 2 is 1019417317 compile time instruction count for iteration 3 is 1013603566 compile time instruction count for iteration 4 is 1008853980 compile time instruction count for iteration 5 is 1009541481 compile time instruction count for iteration 6 is 1005025533 compile time instruction count for iteration 7 is 1004116323 compile time instruction count for iteration 8 is 1000828633 compile time instruction count for iteration 9 is 999788323 collecting compile time instruction count for basic_modules_ListOfLinears_inductor compile time instruction count for iteration 0 is 40837529730 compile time instruction count for iteration 1 is 18411921909 compile time instruction count for iteration 2 is 18383665161 compile time instruction count for iteration 3 is 18348983522 compile time instruction count for iteration 4 is 18349276590 compile time instruction count for iteration 5 is 18353046274 compile time instruction count for iteration 6 is 18346818581 compile time instruction count for iteration 7 is 18340057998 compile time instruction count for iteration 8 is 18331267320 compile time instruction count for iteration 9 is 18328381338 collecting compile time instruction count for basic_modules_ListOfLinears_inductor_gpu compile time instruction count for iteration 0 is 15408870979 compile time instruction count for iteration 1 is 10949520859 compile time instruction count for iteration 2 is 11058786167 compile time instruction count for iteration 3 is 11003606719 compile time instruction count for iteration 4 is 10896406770 compile time instruction count for iteration 5 is 10982875189 compile time instruction count for iteration 6 is 10931848275 compile time instruction count for iteration 7 is 10956345008 compile time instruction count for iteration 8 is 11045384499 ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/135730 Approved by: https://github.com/ezyang, https://github.com/anijain2305 --- .../pr_time_benchmarks/benchmark_runner.sh | 9 ++ .../benchmarks/basic_modules_benchmarks.py | 128 +++--------------- .../log_benchmarking_time.py | 17 +++ 3 files changed, 42 insertions(+), 112 deletions(-) create mode 100644 benchmarks/dynamo/pr_time_benchmarks/log_benchmarking_time.py diff --git a/benchmarks/dynamo/pr_time_benchmarks/benchmark_runner.sh b/benchmarks/dynamo/pr_time_benchmarks/benchmark_runner.sh index 66ea3f0e0d8fd2..a5cf04173358c5 100644 --- a/benchmarks/dynamo/pr_time_benchmarks/benchmark_runner.sh +++ b/benchmarks/dynamo/pr_time_benchmarks/benchmark_runner.sh @@ -18,8 +18,17 @@ output_file=$1 # Set the directory of Python programs python_programs_dir=$2 # Loop through all files in the directory of Python programs + +start=`date +%s` + for file in $python_programs_dir/*.py do # Execute the Python program and append the output to the output file python $file $output_file done +end=`date +%s` + +runtime=$((end-start)) +echo "total time to run benchmarks is:" +echo $runtime +python benchmarks/dynamo/pr_time_benchmarks/log_benchmarking_time.py $runtime diff --git a/benchmarks/dynamo/pr_time_benchmarks/benchmarks/basic_modules_benchmarks.py b/benchmarks/dynamo/pr_time_benchmarks/benchmarks/basic_modules_benchmarks.py index be545cdedd185e..65588cb3719dea 100644 --- a/benchmarks/dynamo/pr_time_benchmarks/benchmarks/basic_modules_benchmarks.py +++ b/benchmarks/dynamo/pr_time_benchmarks/benchmarks/basic_modules_benchmarks.py @@ -4,7 +4,6 @@ import torch import torch.nn as nn -import torch.nn.functional as F from torch._inductor.utils import fresh_inductor_cache @@ -20,139 +19,44 @@ def forward(self, x): return x -class BasicModule(torch.nn.Module): - def __init__(self) -> None: - super().__init__() - self.linear1 = torch.nn.Linear(10, 10) - self.scale = torch.randn(1, 10) - - def forward(self, x): - return F.relu(self.linear1(x)) * self.scale - - -class ModuleForwardHasGraphBreak(torch.nn.Module): - def __init__(self) -> None: - super().__init__() - self.layer1 = BasicModule() - self.layer2 = BasicModule() - self.layer3 = torch.nn.Sequential(BasicModule(), BasicModule()) - self.layer4 = torch.nn.ModuleList( - [ - torch.nn.Linear(10, 10), - torch.nn.ReLU(), - torch.nn.Linear(10, 10), - torch.nn.ReLU(), - ] - ) - self.layer5 = torch.nn.ModuleDict( - { - "0": torch.nn.Linear(10, 10), - } - ) - self.scale = torch.randn(1, 10) - - def forward(self, x): - """ - This is used to test if the results of functions like `named_parameters` - can be reconstructed correctly after graph break. - - https://github.com/pytorch/torchdynamo/issues/1931 - """ - x = self.layer1(x) - params1 = dict(self.named_parameters()) - params2 = list(self.parameters()) - buffers1 = dict(self.named_buffers()) - buffers2 = list(self.buffers()) - modules1 = dict(self.named_modules()) - modules2 = list(self.modules()) - torch._dynamo.graph_break() - y = modules2 - y = modules1 - y = buffers2 - y = buffers1 - y = params2 - y = params1 - x = ( - self.layer2(x) - + y["layer3.1.linear1.weight"] - + y["layer4.2.weight"] - + y["layer5.0.weight"] - ) - return x * self.scale - - -class SequentialWithDuplicatedModule(torch.nn.Module): - # Sequential module(self.layer) contains three duplicated ReLU module. - def __init__(self) -> None: - super().__init__() - self.relu = torch.nn.ReLU() - self.layer = torch.nn.Sequential( - torch.nn.Linear(10, 20), - self.relu, - torch.nn.Linear(20, 20), - self.relu, - torch.nn.Linear(20, 10), - self.relu, - ) - - def forward(self, x): - return self.layer(x) - - -class ModuleComparison(torch.nn.Module): - def __init__(self) -> None: - super().__init__() - self.layer0 = torch.nn.Linear(10, 10) - self.layer1 = torch.nn.Linear(10, 10) - self.layer2 = torch.nn.Linear(10, 10) - - @property - def encoder_layers(self): - return [self.layer0, self.layer1, self.layer2] - - def forward(self, x): - for layer in self.encoder_layers: - output = layer(x) - if layer is None or layer == self.layer0: - output = F.relu6(output) - else: - output = F.relu(output) - return output - - class Benchmark(BenchmarkBase): - def __init__(self, ModuleClass, backend): + def __init__(self, ModuleClass, backend, is_gpu=False, dynamic=False): self.ModuleClass = ModuleClass self.backend = backend self._name = ModuleClass.__name__ + self._is_gpu = is_gpu + self._dynamic = dynamic def name(self): - return f"basic_modules_{self._name}_{self.backend}" + prefix = f"basic_modules_{self._name}_{self.backend}" + if self._dynamic: + prefix += "_dynamic" + if self._is_gpu: + prefix += "_gpu" + return prefix def _prepare_once(self): self.m = self.ModuleClass() - self.input = torch.ones(10) + torch.set_float32_matmul_precision("high") + self.input = torch.ones(10, device="cuda" if self._is_gpu else "cpu") def _prepare(self): torch._dynamo.reset() def _work(self): with fresh_inductor_cache(): - opt_m = torch.compile(backend=self.backend)(self.m) + opt_m = torch.compile(backend=self.backend, dynamic=self._dynamic)( + self.m.cuda() if self._is_gpu else self.m + ) opt_m(self.input) def main(): result_path = sys.argv[1] benchmarks = [ - Benchmark(ListOfLinears, "inductor"), Benchmark(ListOfLinears, "eager"), - Benchmark(ModuleForwardHasGraphBreak, "inductor"), - Benchmark(ModuleForwardHasGraphBreak, "eager"), - Benchmark(SequentialWithDuplicatedModule, "inductor"), - Benchmark(SequentialWithDuplicatedModule, "eager"), - Benchmark(ModuleComparison, "inductor"), - Benchmark(ModuleComparison, "eager"), + Benchmark(ListOfLinears, "inductor"), + Benchmark(ListOfLinears, "inductor", is_gpu=True), ] for b in benchmarks: b.enable_compile_time_instruction_count().collect_all().append_results( diff --git a/benchmarks/dynamo/pr_time_benchmarks/log_benchmarking_time.py b/benchmarks/dynamo/pr_time_benchmarks/log_benchmarking_time.py new file mode 100644 index 00000000000000..d011758c23ffef --- /dev/null +++ b/benchmarks/dynamo/pr_time_benchmarks/log_benchmarking_time.py @@ -0,0 +1,17 @@ +import json +import sys + +import torch._logging.scribe as scribe + + +def main(): + duration = int(sys.argv[1]) + scribe.open_source_signpost( + subsystem="pr_time_benchmarks", + name="duration", + parameters=json.dumps(duration), + ) + + +if __name__ == "__main__": + main() From 7380d470357ba1e113bdfbf16b6a76272dadab87 Mon Sep 17 00:00:00 2001 From: Bin Bao Date: Fri, 13 Sep 2024 13:01:19 -0700 Subject: [PATCH 0896/1018] [AOTI] Fix a fallback op returning None issue (#135997) Summary: Fixes https://github.com/pytorch/pytorch/issues/135781. In some cases, a fallback can return None in the place of a tensor. Differential Revision: [D62659039](https://our.internmc.facebook.com/intern/diff/D62659039) Pull Request resolved: https://github.com/pytorch/pytorch/pull/135997 Approved by: https://github.com/chenyang78 --- test/inductor/test_cpu_cpp_wrapper.py | 1 - torch/_inductor/codecache.py | 4 +++- torch/_inductor/codegen/cpp_wrapper_cpu.py | 18 ++++++++++++------ 3 files changed, 15 insertions(+), 8 deletions(-) diff --git a/test/inductor/test_cpu_cpp_wrapper.py b/test/inductor/test_cpu_cpp_wrapper.py index 50a21411a28144..91ecebd5ca11ec 100644 --- a/test/inductor/test_cpu_cpp_wrapper.py +++ b/test/inductor/test_cpu_cpp_wrapper.py @@ -96,7 +96,6 @@ class DynamicShapesCppWrapperCpuTests(InductorTestCase): f"{test_name}_dynamic_shapes" ] = test_torchinductor.TestFailure(("cpp_wrapper",), is_skip=False) skip_list = [ - "test_multihead_attention_cpu", *[ func for func in dir(test_cpu_select_algorithm.TestSelectAlgorithmCPU()) diff --git a/torch/_inductor/codecache.py b/torch/_inductor/codecache.py index dd1be42e4c7443..3cd20e9f86df6e 100644 --- a/torch/_inductor/codecache.py +++ b/torch/_inductor/codecache.py @@ -2051,7 +2051,9 @@ def convert_arg(arg: Any) -> Any: assert callable(func), op + " can not be loaded through custom_op_wrapper" result = func(*converted_args) if isinstance(result, (list, tuple)): - for r in result: + # unsafe_alloc_void_ptrs_from_tensors expects result contains tensor only + result = [torch.tensor([]) if r is None else r for r in result] + for i, r in enumerate(result): assert isinstance(r, torch.Tensor), op + " returns a list of non-tensors" return torch._C._aoti.unsafe_alloc_void_ptrs_from_tensors(result) # type: ignore[arg-type] else: diff --git a/torch/_inductor/codegen/cpp_wrapper_cpu.py b/torch/_inductor/codegen/cpp_wrapper_cpu.py index f3d4e53dc9110c..c009246243f3cf 100644 --- a/torch/_inductor/codegen/cpp_wrapper_cpu.py +++ b/torch/_inductor/codegen/cpp_wrapper_cpu.py @@ -2122,8 +2122,7 @@ def generate_extern_kernel_alloc_and_find_schema_if_needed( def extract_output_name(out): if out is None: - # Because out is not a MultiOutput, we assume the kernel returns a single output - return [buf_name] + return None elif isinstance(out, (ir.MultiOutput, ir._CollectiveKernel)): return out.get_name() elif isinstance(out, (list, tuple)): @@ -2134,9 +2133,13 @@ def extract_output_name(out): # output_args has the same pytree structure as outputs output_args = None if config.abi_compatible: - output_args = extract_output_name(outputs) - if isinstance(output_args, str): - output_args = [output_args] + if outputs is None: + # outputs is not specified, the default is to write to buf_name + output_args = [buf_name] + else: + output_args = extract_output_name(outputs) + if isinstance(output_args, str): + output_args = [output_args] if V.graph.aot_mode and config.abi_compatible: assert op_overload is not None @@ -2347,13 +2350,16 @@ def generate_extern_kernel_alloc_and_find_schema_if_needed_jit( else: # result is a tuple of tensors for idx, output_arg in enumerate(output_args): + if output_arg is None: + continue lines += f""" {output_arg} = reinterpret_cast(PyCapsule_GetPointer(PyList_GET_ITEM(py_{buf_name}.get(), {idx}), NULL));""" declarations_before_scope = [ f"RAIIAtenTensorHandle {output_arg};" - for idx, output_arg in enumerate(output_args) + for output_arg in output_args + if output_arg is not None ] scope_gil_acquire = self.generate_scoped_gil_acquire( declarations_before_scope, lines From ff59b62f59743910b23aa6ccecd28f8611201466 Mon Sep 17 00:00:00 2001 From: Michael Lazos Date: Sat, 14 Sep 2024 03:23:34 -0700 Subject: [PATCH 0897/1018] [Dynamo] Use custom backend to reenter metadata tf mode when tracing while/cond (#134732) For tracing cond/while in eager, we trace the HOP with the eager backend with metadata torchfunction mode enabled. HOPs disallow the mutation that occurs in this torch function mode, so it is not able to be traced. As a result, we use a custom backend which enters this mode for tracing these HOPs. Thanks to @ydwu4 for the help with implementing this Pull Request resolved: https://github.com/pytorch/pytorch/pull/134732 Approved by: https://github.com/ydwu4 --- torch/_dynamo/backends/debugging.py | 12 +++++ torch/_higher_order_ops/cond.py | 20 +++++--- torch/_higher_order_ops/scan.py | 16 +++++-- torch/_higher_order_ops/while_loop.py | 21 +++++++-- torch/fx/experimental/proxy_tensor.py | 67 ++++++++++++++++----------- torch/nn/attention/flex_attention.py | 42 +++++++++++------ 6 files changed, 124 insertions(+), 54 deletions(-) diff --git a/torch/_dynamo/backends/debugging.py b/torch/_dynamo/backends/debugging.py index 87214f7d59774e..18784498862e39 100644 --- a/torch/_dynamo/backends/debugging.py +++ b/torch/_dynamo/backends/debugging.py @@ -31,6 +31,18 @@ def eager(gm, fake_tensor_inputs, **kwargs): return gm.forward +def make_eager_backend_with_torch_function_mode(mode): + """Used to trace HOPs (cond and while) for eager exectution, the metadata + TF mode mutates vars outside of the scope of the HOP, and we can't have graph breaks + in the HOP, so we need to externally run this mode and not trace it.""" + + def fn(gm, fake_tensor_inputs, **kwargs): + with mode: + return gm.forward + + return fn + + @register_backend def eager_noexcept(gm, fake_tensor_inputs, **kwargs): if kwargs: diff --git a/torch/_higher_order_ops/cond.py b/torch/_higher_order_ops/cond.py index dee400d76f5964..0467e2899adc28 100644 --- a/torch/_higher_order_ops/cond.py +++ b/torch/_higher_order_ops/cond.py @@ -28,6 +28,7 @@ from torch._subclasses.fake_tensor import FakeTensorMode from torch._subclasses.functional_tensor import disable_functional_mode from torch.fx.experimental.proxy_tensor import ( + _temp_remove_metadata_torch_function_mode, _temp_remove_pre_dispatch_torch_function_mode, disable_proxy_modes_tracing, ProxyTorchDispatchMode, @@ -129,6 +130,10 @@ def false_fn(x: torch.Tensor): if torch.compiler.is_dynamo_compiling(): return cond_op(pred, true_fn, false_fn, operands) + from torch._dynamo.backends.debugging import ( + make_eager_backend_with_torch_function_mode, + ) + if isinstance(pred, (bool, int, float)): log.warning( "Pred is a Python constant. When used with torch.cond, it executes only one of the branches." @@ -169,12 +174,15 @@ def _validate_input(pred, true_fn, false_fn, operands): def _cond_op_wrapper(*args, **kwargs): return cond_op(*args, **kwargs) - with _set_compilation_env(): - with torch._dynamo.utils.disable_cache_limit(): - with _temp_remove_pre_dispatch_torch_function_mode(): - return torch.compile(_cond_op_wrapper, backend="eager", fullgraph=True)( - pred, true_fn, false_fn, operands - ) + with _set_compilation_env(), torch._dynamo.utils.disable_cache_limit(), _temp_remove_pre_dispatch_torch_function_mode(): + with _temp_remove_metadata_torch_function_mode() as metadata_mode: + if metadata_mode: + backend = make_eager_backend_with_torch_function_mode(metadata_mode) + else: + backend = "eager" + return torch.compile(_cond_op_wrapper, backend=backend, fullgraph=True)( + pred, true_fn, false_fn, operands + ) def create_fw_bw_graph_branches(true_fn, false_fn, *operands): diff --git a/torch/_higher_order_ops/scan.py b/torch/_higher_order_ops/scan.py index d6eecb0a9f29f9..d66cff067f6682 100644 --- a/torch/_higher_order_ops/scan.py +++ b/torch/_higher_order_ops/scan.py @@ -20,6 +20,7 @@ from torch._ops import HigherOrderOperator from torch._subclasses.fake_tensor import FakeTensorMode from torch.fx.experimental.proxy_tensor import ( + _temp_remove_metadata_torch_function_mode, disable_proxy_modes_tracing, ProxyTorchDispatchMode, track_tensor_tree, @@ -118,10 +119,19 @@ def _scan_op_wrapper(*args, **kwargs): return scan(*args, **kwargs) if not torch._dynamo.is_compiling(): + from torch._dynamo.backends.debugging import ( + make_eager_backend_with_torch_function_mode, + ) + with _set_compilation_env(), torch._dynamo.utils.disable_cache_limit(): - return torch.compile(_scan_op_wrapper, backend="eager", fullgraph=True)( - combine_fn, init, xs, dim=dim, reverse=reverse - ) + with _temp_remove_metadata_torch_function_mode() as metadata_mode: + if metadata_mode: + backend = make_eager_backend_with_torch_function_mode(metadata_mode) + else: + backend = "eager" + return torch.compile(_scan_op_wrapper, backend=backend, fullgraph=True)( + combine_fn, init, xs, dim=dim, reverse=reverse + ) leaves_init, spec_init = pytree.tree_flatten(init) leaves_xs, spec_xs = pytree.tree_flatten(xs) diff --git a/torch/_higher_order_ops/while_loop.py b/torch/_higher_order_ops/while_loop.py index d64f512399d9d5..f14321842f40b9 100644 --- a/torch/_higher_order_ops/while_loop.py +++ b/torch/_higher_order_ops/while_loop.py @@ -15,7 +15,11 @@ ) from torch._ops import HigherOrderOperator from torch._subclasses.fake_tensor import FakeTensorMode -from torch.fx.experimental.proxy_tensor import ProxyTorchDispatchMode, track_tensor_tree +from torch.fx.experimental.proxy_tensor import ( + _temp_remove_metadata_torch_function_mode, + ProxyTorchDispatchMode, + track_tensor_tree, +) class WhileLoopOp(HigherOrderOperator): @@ -113,6 +117,9 @@ def body_fn(iter, x): - 'while_loop' only supports **inference** right now. Autograd will be supported in the future. """ + from torch._dynamo.backends.debugging import ( + make_eager_backend_with_torch_function_mode, + ) # Currently, additional_inputs is not a user-facing input. It will be automatically set in dynamo. # parameters and buffers accessed in cond_fn or body_fn or tensor closures will become additional_inputs. @@ -140,9 +147,15 @@ def _while_loop_op_wrapper(*args, **kwargs): return while_loop_op(*args, **kwargs) with _set_compilation_env(), torch._dynamo.utils.disable_cache_limit(): - return torch.compile(_while_loop_op_wrapper, backend="eager", fullgraph=True)( - cond_fn, body_fn, carried_inputs, additional_inputs - ) + with _temp_remove_metadata_torch_function_mode() as metadata_mode: + with _temp_remove_metadata_torch_function_mode() as metadata_mode: + if metadata_mode: + backend = make_eager_backend_with_torch_function_mode(metadata_mode) + else: + backend = "eager" + return torch.compile( + _while_loop_op_wrapper, backend=backend, fullgraph=True + )(cond_fn, body_fn, carried_inputs, additional_inputs) @while_loop_op.py_impl(DispatchKey.CompositeExplicitAutograd) diff --git a/torch/fx/experimental/proxy_tensor.py b/torch/fx/experimental/proxy_tensor.py index 5a5771d2ae8f7a..213672e216e6ce 100644 --- a/torch/fx/experimental/proxy_tensor.py +++ b/torch/fx/experimental/proxy_tensor.py @@ -17,7 +17,7 @@ import warnings import weakref from collections import defaultdict -from contextlib import contextmanager, ExitStack, nullcontext +from contextlib import _GeneratorContextManager, contextmanager, ExitStack, nullcontext from dataclasses import dataclass from typing import ( Any, @@ -1084,38 +1084,43 @@ def unwrap_proxy(self, e: T) -> object: return e -@contextmanager -def _temp_remove_pre_dispatch_torch_function_mode() -> Generator[None, None, None]: - from torch.overrides import _len_torch_function_stack, _pop_mode, _push_mode +def _make_temp_remove_mode_context_manager( + mode_ty: Type[TorchFunctionMode], +) -> Callable[[], _GeneratorContextManager[Optional[TorchFunctionMode]]]: + @contextmanager + def context_manager_fn() -> Generator[Optional[TorchFunctionMode], None, None]: + from torch.overrides import _len_torch_function_stack, _pop_mode, _push_mode - temp_elements = [] - pre_dispatch_mode = None + temp_elements = [] + removed_mode = None - while _len_torch_function_stack() > 0: - mode = _pop_mode() - if isinstance(mode, PreDispatchTorchFunctionMode): - pre_dispatch_mode = mode - break - else: - temp_elements.append(mode) + while _len_torch_function_stack() > 0: + mode = _pop_mode() + if isinstance(mode, mode_ty): + removed_mode = mode + break + else: + temp_elements.append(mode) - for mode in reversed(temp_elements): - _push_mode(mode) + for mode in reversed(temp_elements): + _push_mode(mode) - try: - yield + try: + yield removed_mode - finally: - if pre_dispatch_mode is not None: - count = len(temp_elements) - while count > 0: - mode = _pop_mode() - count -= 1 + finally: + if removed_mode is not None: + count = len(temp_elements) + while count > 0: + mode = _pop_mode() + count -= 1 + + temp_elements.append(removed_mode) - temp_elements.append(pre_dispatch_mode) + for mode in reversed(temp_elements): + _push_mode(mode) - for mode in reversed(temp_elements): - _push_mode(mode) + return context_manager_fn @torch._disable_dynamo @@ -1230,6 +1235,11 @@ def __torch_function__( return func(*args, **kwargs) +_temp_remove_metadata_torch_function_mode = _make_temp_remove_mode_context_manager( + TorchFunctionMetadataMode +) + + # This mode is **only** used for pre_dispatch tracing. # In particular, we need to make sure that autograd/autocast API's # that do not desugar into dispatcher operators stay in the graph. @@ -1258,6 +1268,11 @@ def __torch_function__( return func(*args, **kwargs) +_temp_remove_pre_dispatch_torch_function_mode = _make_temp_remove_mode_context_manager( + PreDispatchTorchFunctionMode +) + + class ProxyTorchDispatchMode(TorchDispatchMode): # Ensure this is read-only; this exists only for legacy reasons @property diff --git a/torch/nn/attention/flex_attention.py b/torch/nn/attention/flex_attention.py index bb04ce53d0fcab..2317dbac83bdda 100644 --- a/torch/nn/attention/flex_attention.py +++ b/torch/nn/attention/flex_attention.py @@ -19,6 +19,7 @@ ) from torch._higher_order_ops.utils import _set_compilation_env from torch.fx.experimental.proxy_tensor import ( + _temp_remove_metadata_torch_function_mode, _temp_remove_pre_dispatch_torch_function_mode, ) from torch.nn.attention._utils import _supported_head_dim, _validate_sdpa_input @@ -1033,6 +1034,10 @@ def score_mod( if not torch._dynamo.is_dynamo_supported(): raise RuntimeError("flex_attention requires dynamo support") + from torch._dynamo.backends.debugging import ( + make_eager_backend_with_torch_function_mode, + ) + # Dynamo is expecting a callable with "__code__" attribute. # We cannot directly pass hop to it. So we wrap it in a dummy function. def _flex_attention_hop_wrapper(*args, **kwargs): @@ -1041,18 +1046,25 @@ def _flex_attention_hop_wrapper(*args, **kwargs): with _set_compilation_env(): with torch._dynamo.utils.disable_cache_limit(): with _temp_remove_pre_dispatch_torch_function_mode(): - out, lse = torch.compile( - _flex_attention_hop_wrapper, backend="eager", fullgraph=True - )( - query, - key, - value, - score_mod, - block_mask.as_tuple(), - scale, - kernel_options, - ) - if return_lse: - return out, lse * math.log(2) - else: - return out + with _temp_remove_metadata_torch_function_mode() as metadata_mode: + if metadata_mode: + backend = make_eager_backend_with_torch_function_mode( + metadata_mode + ) + else: + backend = "eager" + out, lse = torch.compile( + _flex_attention_hop_wrapper, backend=backend, fullgraph=True + )( + query, + key, + value, + score_mod, + block_mask.as_tuple(), + scale, + kernel_options, + ) + if return_lse: + return out, lse * math.log(2) + else: + return out From 659bee2813094dc001b63607e486c426d7df5618 Mon Sep 17 00:00:00 2001 From: Michael Lazos Date: Sat, 14 Sep 2024 03:23:35 -0700 Subject: [PATCH 0898/1018] [Dynamo] Trace torch function modes entered outside of torch.compile (#133137) This PR adds initial tracing for torch function modes. Details: In essence, this adds tracing into the torch function of modes entered outside of the torch.compile call. This does not yet support tracing enter/exit of a torch function mode/ tracing set_default_device properly using the new mode infra (this will be a very good stress test for modes). I am adding more PRs to this stack to support these. The overall plan is to support tracing enter/exit and handling graph breaks like we do other torch.* context managers. Previously landed: https://github.com/pytorch/pytorch/pull/133135 https://github.com/pytorch/pytorch/pull/133136 https://github.com/pytorch/pytorch/pull/133134 https://github.com/pytorch/pytorch/pull/133133 https://github.com/pytorch/pytorch/pull/133132 https://github.com/pytorch/pytorch/pull/133131 https://github.com/pytorch/pytorch/pull/133729 https://github.com/pytorch/pytorch/pull/133130 Pull Request resolved: https://github.com/pytorch/pytorch/pull/133137 Approved by: https://github.com/jansel, https://github.com/zou3519 ghstack dependencies: #134732 --- test/dynamo/test_modes.py | 135 +++++++++ test/functorch/test_control_flow.py | 4 + .../pt2e/test_metadata_porting.py | 4 +- .../pt2e/test_numeric_debugger.py | 4 +- torch/_dynamo/convert_frame.py | 5 + torch/_dynamo/guards.py | 14 + torch/_dynamo/source.py | 2 +- torch/_dynamo/symbolic_convert.py | 78 ++---- torch/_dynamo/trace_rules.py | 5 +- torch/_dynamo/utils.py | 10 +- torch/_dynamo/variables/ctx_manager.py | 2 + torch/_dynamo/variables/torch.py | 260 ++++++++++-------- torch/_dynamo/variables/torch_function.py | 129 +++++++-- torch/_dynamo/variables/user_defined.py | 7 - torch/_export/non_strict_utils.py | 6 +- 15 files changed, 460 insertions(+), 205 deletions(-) diff --git a/test/dynamo/test_modes.py b/test/dynamo/test_modes.py index ee8a9b579ff124..bad5a788903440 100644 --- a/test/dynamo/test_modes.py +++ b/test/dynamo/test_modes.py @@ -14,6 +14,17 @@ from torch.utils._python_dispatch import TorchDispatchMode +class TestMode(BaseTorchFunctionMode): + def __torch_function__(self, func, types, args, kwargs=None): + if not kwargs: + kwargs = {} + + if func == torch.add: + return torch.zeros(2, 2) + + return super().__torch_function__(func, types, args, kwargs) + + class TorchDispatchModeTests(torch._dynamo.test_case.TestCase): @classmethod def setUpClass(cls): @@ -324,6 +335,130 @@ def fn(x): fn(inp) self.assertEqual(cnt.frame_count, 2) + def test_nested_torch_function_mode(self): + mode_1_called = False + mode_2_called = False + + def reset_state(): + nonlocal mode_1_called + nonlocal mode_2_called + mode_1_called = False + mode_2_called = False + + ones = torch.ones(2, 2) + zeros = torch.zeros(2, 2) + + class TestMode1(BaseTorchFunctionMode): + def __torch_function__(self, func, types, args, kwargs=None): + if not kwargs: + kwargs = {} + + nonlocal mode_1_called + + mode_1_called = True + + if func == torch.add: + return zeros + + return super().__torch_function__(func, types, args, kwargs) + + class TestMode2(BaseTorchFunctionMode): + def __torch_function__(self, func, types, args, kwargs=None): + if not kwargs: + kwargs = {} + + nonlocal mode_2_called + + mode_2_called = True + + if func == torch.mul: + return ones + + return super().__torch_function__(func, types, args, kwargs) + + def fn(x): + return torch.add(x, 3) + + def fn_2(x): + return torch.mul(x, 3) + torch.add(x, 3) + + inp = torch.ones(2, 2) + 1 + + for fn_i in [fn, fn_2]: + fn_opt = torch.compile(fn_i, fullgraph=True) + with TestMode1(), TestMode2(): + expected = fn_i(inp), mode_1_called, mode_2_called + reset_state() + actual = fn_opt(inp), mode_1_called, mode_2_called + reset_state() + + self.assertEqual(expected, actual) + + def test_torch_function_mode_disable(self): + class TestSubclass(torch.Tensor): + @classmethod + def __torch_function__(cls, func, types, args, kwargs=None): + if not kwargs: + kwargs = {} + if func == torch.add: + return torch.ones(2, 2) + return super().__torch_function__(func, types, args, kwargs) + + class TestMode(BaseTorchFunctionMode): + def __torch_function__(self, func, types, args, kwargs=None): + if not kwargs: + kwargs = {} + + if func == torch.add: + return torch.zeros(2, 2) + + return super().__torch_function__(func, types, args, kwargs) + + def fn(x): + return torch.add(x, 3) + + inp = (torch.ones(2, 2) + 1).as_subclass(TestSubclass) + + fn_opt = torch.compile(fn, fullgraph=True) + with TestMode(), torch._dynamo.config.patch( + "traceable_tensor_subclasses", {TestSubclass} + ): + with torch._C.DisableTorchFunctionSubclass(): + expected = fn(inp) + actual = fn_opt(inp) + + self.assertEqual(expected, actual) + + with torch._C.DisableTorchFunction(): + expected = fn(inp) + actual = fn_opt(inp) + + self.assertEqual(expected, actual) + + def test_torch_function_mode_highest_priority(self): + class TestSubclass(torch.Tensor): + @classmethod + def __torch_function__(cls, func, types, args, kwargs=None): + if not kwargs: + kwargs = {} + if func == torch.add: + return torch.ones(2, 2) + return super().__torch_function__(func, types, args, kwargs) + + def fn(x): + return torch.add(x, 3) + + inp = (torch.ones(2, 2) + 1).as_subclass(TestSubclass) + + fn_opt = torch.compile(fn, fullgraph=True) + with TestMode(), torch._dynamo.config.patch( + "traceable_tensor_subclasses", {TestSubclass} + ): + expected = fn(inp) + actual = fn_opt(inp) + + self.assertEqual(expected, actual) + if __name__ == "__main__": from torch._dynamo.test_case import run_tests diff --git a/test/functorch/test_control_flow.py b/test/functorch/test_control_flow.py index d870b4b921af05..cf639b6abcce4e 100644 --- a/test/functorch/test_control_flow.py +++ b/test/functorch/test_control_flow.py @@ -26,6 +26,7 @@ parametrize, requires_cuda, run_tests, + skipIfCrossRef, skipIfRocm, skipIfTorchDynamo, TEST_WITH_TORCHDYNAMO, @@ -2882,6 +2883,7 @@ def f(fct, init, xs): gm = make_fx(f, tracing_mode="symbolic")(add_wrong_dtype, init, x) @skipIfNoDynamoSupport + @skipIfCrossRef # Arg order changes with crossref def test_scan_simple_graph(self): from torch._dynamo.testing import EagerAndRecordGraphs @@ -2988,6 +2990,7 @@ def f(x, y): self.assertEqual(graph(x, torch.tensor(True)), f(x, torch.tensor(True))) @skipIfTorchDynamo("Graph is not captured by backend if test with dynamo") + @skipIfCrossRef # Arg order changes with crossref def test_cond_simple_with_linear_compile_check_graph(self): from torch._dynamo.testing import EagerAndRecordGraphs @@ -3250,6 +3253,7 @@ def test_while_loop_compile(self, backend, while_loop_test): self._check_compile(fn, inp, backend=backend) @skipIfTorchDynamo("Graph is not captured by backend if test with dynamo") + @skipIfCrossRef # Arg order changes with cross ref def test_while_loop_simple_with_linear_compile_check_graph(self): fn, inp = WHILE_LOOP_TESTS["simple_with_linear"] from torch._dynamo.testing import EagerAndRecordGraphs diff --git a/test/quantization/pt2e/test_metadata_porting.py b/test/quantization/pt2e/test_metadata_porting.py index 27df58e159f3b6..21488e8cacdd42 100644 --- a/test/quantization/pt2e/test_metadata_porting.py +++ b/test/quantization/pt2e/test_metadata_porting.py @@ -13,7 +13,7 @@ from torch.ao.quantization.quantizer.xnnpack_quantizer_utils import OP_TO_ANNOTATOR from torch.fx import Node from torch.testing._internal.common_quantization import QuantizationTestCase -from torch.testing._internal.common_utils import IS_WINDOWS +from torch.testing._internal.common_utils import IS_WINDOWS, skipIfCrossRef class TestHelperModules: @@ -139,6 +139,8 @@ def _test_metadata_porting( self.assertEqual(v, node_tags[k]) return m + @skipIfCrossRef # mlazos: retracing FX graph with torch function mode doesn't propagate metadata, because the stack + # trace of the mode torch function impl doesn't match the traced graph stored lineno. def test_simple_metadata_porting(self): """ Model under test diff --git a/test/quantization/pt2e/test_numeric_debugger.py b/test/quantization/pt2e/test_numeric_debugger.py index 85b5e9137bdb5c..7808eb892579e9 100644 --- a/test/quantization/pt2e/test_numeric_debugger.py +++ b/test/quantization/pt2e/test_numeric_debugger.py @@ -22,7 +22,7 @@ ) from torch.export import export_for_training from torch.testing._internal.common_quantization import TestHelperModules -from torch.testing._internal.common_utils import IS_WINDOWS, TestCase +from torch.testing._internal.common_utils import IS_WINDOWS, skipIfCrossRef, TestCase def _extract_debug_handles(model) -> Dict[torch.fx.Node, int]: @@ -117,6 +117,8 @@ def test_deepcopy_preserve_handle(self): self.assertEqual(debug_handle_map, debug_handle_map_ref) + @skipIfCrossRef # mlazos: retracing FX graph with torch function mode doesn't propagate metadata, because the stack + # trace of the mode torch function impl doesn't match the traced graph stored lineno. def test_re_export_preserve_handle(self): m = TestHelperModules.Conv2dThenConv1d() example_inputs = m.example_inputs() diff --git a/torch/_dynamo/convert_frame.py b/torch/_dynamo/convert_frame.py index 4d1557e15a7d23..b3e4a7c268c0fe 100644 --- a/torch/_dynamo/convert_frame.py +++ b/torch/_dynamo/convert_frame.py @@ -605,6 +605,10 @@ def _compile( output: Optional[OutputGraph] = None tracer: Optional[InstructionTranslator] = None + tf_mode_stack: List[ + torch.overrides.TorchFunctionMode + ] = torch.overrides._get_current_function_mode_stack() + @preserve_global_state def transform( instructions: List[Instruction], code_options: Dict[str, object] @@ -618,6 +622,7 @@ def transform( locals, globals, builtins, + tf_mode_stack, code_options, compiler_fn, one_graph, diff --git a/torch/_dynamo/guards.py b/torch/_dynamo/guards.py index b0bd273e96bbb5..4437ef4038190d 100644 --- a/torch/_dynamo/guards.py +++ b/torch/_dynamo/guards.py @@ -98,6 +98,7 @@ ScriptObjectQualifiedNameSource, ShapeEnvSource, SubclassAttrListSource, + TorchFunctionModeStackSource, TupleIteratorGetItemSource, TypeSource, UnspecializedBuiltinNNModuleSource, @@ -111,6 +112,7 @@ dict_keys_repr, get_custom_getattr, get_torch_function_mode_stack, + get_torch_function_mode_stack_at, guard_failures, istype, key_is_id, @@ -314,6 +316,7 @@ def uninteresting_files(): "___dict_contains": lambda a, b: a in b, "___tuple_iterator_len": tuple_iterator_len, "___tuple_iterator_getitem": tuple_iterator_getitem, + "___get_torch_function_mode_stack_at": get_torch_function_mode_stack_at, "__math_isnan": math.isnan, "__numpy_isnan": None if np is None else np.isnan, "inf": float("inf"), @@ -901,6 +904,15 @@ def get_guard_manager_from_source(self, source): ): assert base_guard_manager # to make mypy happy out = base_guard_manager + elif istype(source, TorchFunctionModeStackSource): + out = root_guard_manager.lambda_manager( + python_lambda=lambda _: get_torch_function_mode_stack_at( + source._get_index() + ), + source=source_name, + example_value=example_value, + guard_manager_enum=guard_manager_enum, + ) elif istype(source, GradSource): assert base_guard_manager # to make mypy happy out = base_guard_manager.grad_manager( @@ -2214,6 +2226,8 @@ def __init__( self.output_graph = output_graph w_builder = None + # NB: Until we trace device contexts, we need to use the stack recorded at the beginning of tracing + # in case a set default device call was made in the graph. self.torch_function_mode_stack = ( output_graph.torch_function_mode_stack if output_graph else None ) diff --git a/torch/_dynamo/source.py b/torch/_dynamo/source.py index 6ed9d97f6a1210..c44f8d9e69abf5 100644 --- a/torch/_dynamo/source.py +++ b/torch/_dynamo/source.py @@ -619,7 +619,7 @@ class TorchFunctionModeStackSource(Source): ind: int def name(self): - return "" + return f"___get_torch_function_mode_stack_at({self._get_index()})" def _get_index(self): from .variables.torch_function import TorchFunctionModeStackVariable diff --git a/torch/_dynamo/symbolic_convert.py b/torch/_dynamo/symbolic_convert.py index ab92f82aa0f630..415179a5ed32a6 100644 --- a/torch/_dynamo/symbolic_convert.py +++ b/torch/_dynamo/symbolic_convert.py @@ -19,20 +19,7 @@ import types import typing import weakref -from typing import ( - Any, - Callable, - cast, - Deque, - Dict, - List, - Optional, - Set, - Tuple, - Type, - TYPE_CHECKING, - Union, -) +from typing import Any, Callable, cast, Dict, List, Optional, Set, Tuple, Type, Union from unittest.mock import patch import torch @@ -72,14 +59,12 @@ GlobalWeakRefSource, LocalSource, Source, - TorchFunctionModeStackSource, ) from .trace_rules import is_builtin_constant, is_forbidden from .utils import ( counters, get_fake_value, get_instruction_source_311, - get_torch_function_mode_stack, graph_break_dup_warning_checker, istype, LazyString, @@ -120,11 +105,10 @@ ) from .variables.nn_module import NNModuleVariable, UnspecializedNNModuleVariable from .variables.tensor import supported_comparison_ops, SymNodeVariable, TensorVariable - - -if TYPE_CHECKING: - from .variables.torch_function import TorchFunctionModeVariable - +from .variables.torch_function import ( + SymbolicTorchFunctionState, + TorchFunctionModeVariable, +) from .variables.user_defined import ( RemovableHandleVariable, UserDefinedClassVariable, @@ -284,6 +268,10 @@ def resume_fn(self): return ReenterWith(self.stack_index) def exit(self, tx): + if hasattr(self, "graph_break") and isinstance( + self.with_context, TorchFunctionModeVariable + ): + return assert self.with_context is not None return self.with_context.exit(tx) @@ -652,7 +640,9 @@ def handle_graph_break( # Reconstruct the context variable CLASS in the block stack for b in self.block_stack: assert b.with_context is not None - assert isinstance(b.with_context, ContextWrappingVariable) + assert isinstance( + b.with_context, (ContextWrappingVariable, TorchFunctionModeVariable) + ) b.with_context.reconstruct_type(cg) cg.extend_output(b.resume_fn().try_except(cg.code_options, cleanup)) self.output.add_output_instructions(cg.get_instructions()) @@ -728,7 +718,7 @@ class InstructionTranslatorBase( output: OutputGraph symbolic_locals: Dict[str, VariableTracker] symbolic_globals: Dict[str, VariableTracker] - symbolic_torch_function_mode_stack: Deque["TorchFunctionModeVariable"] + symbolic_torch_function_state: SymbolicTorchFunctionState stack: List[VariableTracker] instruction_pointer: Optional[int] current_instruction: Instruction @@ -2548,7 +2538,7 @@ def __init__( code_options: Dict[str, Any], symbolic_locals: Dict[str, VariableTracker], symbolic_globals: Dict[str, VariableTracker], - symbolic_torch_function_mode_stack: Deque["TorchFunctionModeVariable"], + symbolic_torch_function_state: SymbolicTorchFunctionState, f_code: types.CodeType, export: bool, inline_depth: int, @@ -2563,7 +2553,7 @@ def __init__( self.output = output self.symbolic_locals = symbolic_locals self.symbolic_globals = symbolic_globals - self.symbolic_torch_function_mode_stack = symbolic_torch_function_mode_stack + self.symbolic_torch_function_state = symbolic_torch_function_state self.stack = [] # stack of variable names for tracking 3.13 closures self.name_stack: list[Any] = [] @@ -2652,6 +2642,7 @@ def __init__( f_locals, f_globals, f_builtins, + torch_function_mode_stack, code_options, compiler_fn, one_graph, @@ -2686,7 +2677,7 @@ def __init__( symbolic_locals={}, # set below # A global var is inserted only after a STORE_GLOBAL happens to it symbolic_globals={}, - symbolic_torch_function_mode_stack=collections.deque(), + symbolic_torch_function_state=None, # type: ignore[arg-type] # set below f_code=f_code, export=export, inline_depth=0, @@ -2721,7 +2712,9 @@ def __init__( if k in f_locals } - self._init_torch_function_mode_stack() + self.symbolic_torch_function_state = SymbolicTorchFunctionState( + torch_function_mode_stack + ) self.debug_locals: List[Tuple[VariableTracker, List[VariableTracker]]] = [] if export: @@ -2762,29 +2755,6 @@ def _throw_if_in_functorch(self): ) unimplemented(msg) - def _init_torch_function_mode_stack(self): - from .variables.torch_function import TorchFunctionModeStackVariable - - TorchFunctionModeStackVariable.reset() - - self.symbolic_torch_function_mode_stack: Deque[ - TorchFunctionModeVariable - ] = collections.deque() - # We want to retrieve all modes to properly reconstruct the stack if needed - py_stack = get_torch_function_mode_stack(filter_ignored=False) - - if py_stack: - has_device_context = isinstance( - py_stack[0], torch.utils._device.DeviceContext - ) - - for i, val in enumerate(py_stack): - self.symbolic_torch_function_mode_stack.append( - variables.LazyVariableTracker.create( - val, source=TorchFunctionModeStackSource(i) - ) - ) - def get_example_value(self, source: Source): if isinstance(source, LocalSource): return self.f_locals[source.local_name] @@ -3116,7 +3086,7 @@ def get_trace_call_log_str(): code, sub_locals, parent.symbolic_globals, - parent.symbolic_torch_function_mode_stack, + parent.symbolic_torch_function_state, closure_cells, func, ) @@ -3126,7 +3096,7 @@ def get_trace_call_log_str(): code, sub_locals, parent.symbolic_globals, - parent.symbolic_torch_function_mode_stack, + parent.symbolic_torch_function_state, closure_cells, func, ) @@ -3179,7 +3149,7 @@ def __init__( code: types.CodeType, symbolic_locals: Dict[str, VariableTracker], symbolic_globals: Dict[str, VariableTracker], - symbolic_torch_function_mode_stack: Deque["TorchFunctionModeVariable"], + symbolic_torch_function_state: SymbolicTorchFunctionState, closure_cells: Dict[str, VariableTracker], funcvar: BaseUserFunctionVariable, ) -> None: @@ -3196,7 +3166,7 @@ def __init__( f_builtins=f_builtins, symbolic_locals=symbolic_locals, symbolic_globals=symbolic_globals, - symbolic_torch_function_mode_stack=symbolic_torch_function_mode_stack, + symbolic_torch_function_state=symbolic_torch_function_state, instructions=instructions, code_options={k: getattr(code, k) for k in get_code_keys()}, f_code=code, diff --git a/torch/_dynamo/trace_rules.py b/torch/_dynamo/trace_rules.py index 391beb07cd600d..1a206ae339c2db 100644 --- a/torch/_dynamo/trace_rules.py +++ b/torch/_dynamo/trace_rules.py @@ -3258,6 +3258,7 @@ def _module_dir(m: types.ModuleType): "torch.testing", "torch.utils._content_store", "torch.utils._contextlib", + "torch.utils._device", "torch.utils._foreach_utils", "torch.utils._python_dispatch", "torch.utils._pytree", @@ -3592,7 +3593,9 @@ def lookup_inner( if reasons is not None: reasons.add("func name is patched_init") return SkipFunctionVariable - elif name == "__torch_function__": + elif name == "__torch_function__" or ( + obj and obj.__name__ == "__torch_function__" + ): if reasons is not None: reasons.add("func name is __torch_function__") return UserFunctionVariable diff --git a/torch/_dynamo/utils.py b/torch/_dynamo/utils.py index 6ff258ff5657c0..8130b6aa7371b2 100644 --- a/torch/_dynamo/utils.py +++ b/torch/_dynamo/utils.py @@ -63,7 +63,6 @@ import torch.utils._pytree as pytree from torch import fx from torch._C import ( - _get_function_stack_at, _instruction_counter, _len_torch_function_stack, _pop_torch_function_stack, @@ -3087,7 +3086,9 @@ def is_parameter_freezing(): def get_torch_function_mode_stack(filter_ignored=True): from .variables.torch_function import IGNORED_MODES - stack = [_get_function_stack_at(i) for i in range(_len_torch_function_stack())] + stack = [ + get_torch_function_mode_stack_at(i) for i in range(_len_torch_function_stack()) + ] if filter_ignored: stack = [mode for mode in stack if type(mode) not in IGNORED_MODES] @@ -3107,6 +3108,11 @@ def set_torch_function_mode_stack(stack): _push_on_torch_function_stack(mode) +def clear_torch_function_mode_stack(): + for i in range(_len_torch_function_stack()): + _pop_torch_function_stack() + + def verify_guard_fn_signature(value): fn = value.__metadata_guard__ sig = inspect.signature(fn) diff --git a/torch/_dynamo/variables/ctx_manager.py b/torch/_dynamo/variables/ctx_manager.py index eebe90929b379a..8c4eb3dc4e715b 100644 --- a/torch/_dynamo/variables/ctx_manager.py +++ b/torch/_dynamo/variables/ctx_manager.py @@ -637,6 +637,8 @@ def enter(self, tx): def _call_func(self, tx: "InstructionTranslator", values): assert len(values) == 1 + tx.symbolic_torch_function_state.torch_function_subclass_enabled = values[0] + tx.symbolic_torch_function_state.torch_function_mode_enabled = values[0] tx.output.set_torch_function_state(values[0]) diff --git a/torch/_dynamo/variables/torch.py b/torch/_dynamo/variables/torch.py index 68508f001dae35..f6cb565ef021c0 100644 --- a/torch/_dynamo/variables/torch.py +++ b/torch/_dynamo/variables/torch.py @@ -156,6 +156,15 @@ bin_ops = dict.fromkeys(["add", "sub", "mul", "div", "sqrt"]) +@functools.lru_cache(None) +def get_overridable_functions(): + from itertools import chain + + from torch.overrides import get_overridable_functions as get_overridable_functions_ + + return set(chain(*get_overridable_functions_().values())) + + class BaseTorchVariable(VariableTracker): """common base for all torch.* functions, classes, modules and other things""" @@ -806,10 +815,10 @@ def handle_pop_torch_function( self, tx: "InstructionTranslator", *args, **kwargs ): assert not args and not kwargs - if not tx.symbolic_torch_function_mode_stack: + if not tx.symbolic_torch_function_state.mode_stack: raise unimplemented("Popping from an empty torch function mode stack") TorchFunctionModeStackVariable.register_mutation(tx) - return tx.symbolic_torch_function_mode_stack.pop() + return tx.symbolic_torch_function_state.pop_torch_function_mode() @register(torch._C._push_on_torch_function_stack) def handle_push_torch_function( @@ -817,7 +826,7 @@ def handle_push_torch_function( ): assert len(args) == 1 and not kwargs TorchFunctionModeStackVariable.register_mutation(tx) - tx.symbolic_torch_function_mode_stack.append(args[0]) + tx.symbolic_torch_function_state.push_torch_function_mode(args[0]) return ConstantVariable.create(None) @register(torch._C._len_torch_function_stack) @@ -825,7 +834,9 @@ def handle_len_torch_function( self, tx: "InstructionTranslator", *args, **kwargs ): assert not args and not kwargs - return ConstantVariable.create(len(tx.symbolic_torch_function_mode_stack)) + return ConstantVariable.create( + len(tx.symbolic_torch_function_state.mode_stack) + ) @register(torch.set_default_device) def handle_set_default_device( @@ -857,6 +868,9 @@ def call_function( from . import ConstantVariable, SymNodeVariable, TensorVariable from .builder import wrap_fx_proxy + if self.torch_function_override_enabled(tx, args, kwargs): + return dispatch_torch_function(tx, self, args, kwargs) + if self.can_constant_fold_through() and check_unspec_or_constant_args( args, kwargs ): @@ -878,147 +892,144 @@ def call_function( if result: return result - if can_dispatch_torch_function(tx, args, kwargs): - return dispatch_torch_function(tx, self, args, kwargs) - else: - any_symints_or_symfloats = any(isinstance(x, SymNodeVariable) for x in args) + any_symints_or_symfloats = any(isinstance(x, SymNodeVariable) for x in args) - all_ints_or_floats = all( - isinstance(x, (variables.ConstantVariable, variables.SymNodeVariable)) - for x in args - ) - if ( - getattr(self.value, "__module__", "") == "torch" - and self.value.__name__ in bin_ops - and any_symints_or_symfloats - and all_ints_or_floats - ): - msg = f"""\ + all_ints_or_floats = all( + isinstance(x, (variables.ConstantVariable, variables.SymNodeVariable)) + for x in args + ) + if ( + getattr(self.value, "__module__", "") == "torch" + and self.value.__name__ in bin_ops + and any_symints_or_symfloats + and all_ints_or_floats + ): + msg = f"""\ Calling {str(self.value)} on only torch.SymInt arguments is not yet supported. To support this behavior, we need to allow const-propping tensors that store symint data. For now, dynamo will explicitly graph break when it encounters user code with this behavior. """ - log.warning(msg) - unimplemented(msg) - - # TODO(voz): Replace w/ dynamic shape rewrite table. - # Ideally, we would be able to do this at ctor time, but alas we need a combination - # of value + args to determine this. - fn_ = self.value - if any_symints_or_symfloats: - torch_sym_op = f"_sym_{self.value.__name__}" - if getattr(self.value, "__module__", None) == "math" and hasattr( - torch, torch_sym_op - ): - fn_ = getattr(torch, torch_sym_op) - - fake_out_shape = None - if "out" in kwargs and isinstance(kwargs["out"], variables.TensorVariable): - # Calling fake tensor propagation can mutate the out= tensor in - # tx.output.tracked_fakes. tracked_fakes are used to apply - # symbolic_shape guards. Mutating them destroys the information - # prior to tracing, which is essential for creating right - # guards. So save the shape now, and check later if it has - # changed. If it has, graph break. - fake_out_shape = kwargs["out"].proxy.node.meta["example_value"].shape - - tensor_variable = wrap_fx_proxy( - tx=tx, - proxy=tx.output.create_proxy( - "call_function", - fn_, - *proxy_args_kwargs(args, kwargs), - ), - ) - - if ( - isinstance(tensor_variable, TensorVariable) - and "requires_grad" in kwargs - and kwargs["requires_grad"].as_python_constant() + log.warning(msg) + unimplemented(msg) + + # TODO(voz): Replace w/ dynamic shape rewrite table. + # Ideally, we would be able to do this at ctor time, but alas we need a combination + # of value + args to determine this. + fn_ = self.value + if any_symints_or_symfloats: + torch_sym_op = f"_sym_{self.value.__name__}" + if getattr(self.value, "__module__", None) == "math" and hasattr( + torch, torch_sym_op ): - unimplemented( - """factory functions that return tensors that require grad are not supported. + fn_ = getattr(torch, torch_sym_op) + + fake_out_shape = None + if "out" in kwargs and isinstance(kwargs["out"], variables.TensorVariable): + # Calling fake tensor propagation can mutate the out= tensor in + # tx.output.tracked_fakes. tracked_fakes are used to apply + # symbolic_shape guards. Mutating them destroys the information + # prior to tracing, which is essential for creating right + # guards. So save the shape now, and check later if it has + # changed. If it has, graph break. + fake_out_shape = kwargs["out"].proxy.node.meta["example_value"].shape + + tensor_variable = wrap_fx_proxy( + tx=tx, + proxy=tx.output.create_proxy( + "call_function", + fn_, + *proxy_args_kwargs(args, kwargs), + ), + ) + + if ( + isinstance(tensor_variable, TensorVariable) + and "requires_grad" in kwargs + and kwargs["requires_grad"].as_python_constant() + ): + unimplemented( + """factory functions that return tensors that require grad are not supported. Either create the tensor outside the compiled region, or do not set the tensor to require_grad""" - ) + ) - if "out" in kwargs and not ( - isinstance(kwargs["out"], variables.ConstantVariable) - and kwargs["out"].as_python_constant() is None - ): - # out variants of torch operators like torch.sort and - # torch.sigmoid mutate the tensors in the out field. Track such - # tensors and rewrite the symbolic locals. - if isinstance(tensor_variable, TupleVariable): - assert isinstance(kwargs["out"], (TupleVariable, ListVariable)) - output_tensor_names = [ - tx.find_symbolic_locals_name(x) for x in kwargs["out"].items - ] - for idx, name in enumerate(output_tensor_names): - if name in tx.symbolic_locals: - tx.symbolic_locals[name] = tensor_variable.items[idx] - for out_tensor, result_tensor in zip( - kwargs["out"].items, tensor_variable.items - ): - if ( - out_tensor.source - and out_tensor in tx.output.graphargs - and isinstance(out_tensor, variables.TensorVariable) - and isinstance(result_tensor, variables.TensorVariable) - and out_tensor.size != result_tensor.size - ): - # It's hard to get out variants with resizing on graph inputs work - # properly across dynamo/aot/inductor, just fall back. - unimplemented("out variants with resizing on graph inputs") - elif isinstance(tensor_variable, TensorVariable): - assert isinstance(kwargs["out"], TensorVariable) - assert "example_value" in kwargs["out"].proxy.node.meta - fake_tensor = tensor_variable.proxy.node.meta["example_value"] - fake_out = kwargs["out"].proxy.node.meta["example_value"] + if "out" in kwargs and not ( + isinstance(kwargs["out"], variables.ConstantVariable) + and kwargs["out"].as_python_constant() is None + ): + # out variants of torch operators like torch.sort and + # torch.sigmoid mutate the tensors in the out field. Track such + # tensors and rewrite the symbolic locals. + if isinstance(tensor_variable, TupleVariable): + assert isinstance(kwargs["out"], (TupleVariable, ListVariable)) + output_tensor_names = [ + tx.find_symbolic_locals_name(x) for x in kwargs["out"].items + ] + for idx, name in enumerate(output_tensor_names): + if name in tx.symbolic_locals: + tx.symbolic_locals[name] = tensor_variable.items[idx] + for out_tensor, result_tensor in zip( + kwargs["out"].items, tensor_variable.items + ): if ( - kwargs["out"].source - and kwargs["out"] in tx.output.graphargs - and fake_out_shape != fake_tensor.shape + out_tensor.source + and out_tensor in tx.output.graphargs + and isinstance(out_tensor, variables.TensorVariable) + and isinstance(result_tensor, variables.TensorVariable) + and out_tensor.size != result_tensor.size ): # It's hard to get out variants with resizing on graph inputs work # properly across dynamo/aot/inductor, just fall back. unimplemented("out variants with resizing on graph inputs") + elif isinstance(tensor_variable, TensorVariable): + assert isinstance(kwargs["out"], TensorVariable) + assert "example_value" in kwargs["out"].proxy.node.meta + fake_tensor = tensor_variable.proxy.node.meta["example_value"] + fake_out = kwargs["out"].proxy.node.meta["example_value"] + if ( + kwargs["out"].source + and kwargs["out"] in tx.output.graphargs + and fake_out_shape != fake_tensor.shape + ): + # It's hard to get out variants with resizing on graph inputs work + # properly across dynamo/aot/inductor, just fall back. + unimplemented("out variants with resizing on graph inputs") + if not torch._prims_common.is_contiguous(fake_out): + # It's difficult to handle strides correctly in functionalization + # when calling an out= op with a non-contiguous out argument + unimplemented( + "out= op was called where output tensor was non-contiguous" + ) + name = tx.find_symbolic_locals_name(kwargs["out"]) + if name in tx.symbolic_locals: + tx.symbolic_locals[name] = tensor_variable + elif ( + isinstance(tensor_variable, ConstantVariable) + and tensor_variable.value is None + ): + # Handle out-variant custom ops that return None. + if isinstance(kwargs["out"], TensorVariable): + assert "example_value" in kwargs["out"].proxy.node.meta + fake_out = kwargs["out"].proxy.node.meta["example_value"] if not torch._prims_common.is_contiguous(fake_out): # It's difficult to handle strides correctly in functionalization # when calling an out= op with a non-contiguous out argument unimplemented( "out= op was called where output tensor was non-contiguous" ) - name = tx.find_symbolic_locals_name(kwargs["out"]) - if name in tx.symbolic_locals: - tx.symbolic_locals[name] = tensor_variable - elif ( - isinstance(tensor_variable, ConstantVariable) - and tensor_variable.value is None - ): - # Handle out-variant custom ops that return None. - if isinstance(kwargs["out"], TensorVariable): - assert "example_value" in kwargs["out"].proxy.node.meta - fake_out = kwargs["out"].proxy.node.meta["example_value"] + elif isinstance(kwargs["out"], ListVariable): + for idx, x in enumerate(kwargs["out"].items): + assert "example_value" in x.proxy.node.meta # type: ignore[attr-defined] + fake_out = x.proxy.node.meta["example_value"] # type: ignore[attr-defined] if not torch._prims_common.is_contiguous(fake_out): # It's difficult to handle strides correctly in functionalization # when calling an out= op with a non-contiguous out argument unimplemented( - "out= op was called where output tensor was non-contiguous" + "out= op was called where some of the output tensors were non-contiguous" ) - elif isinstance(kwargs["out"], ListVariable): - for idx, x in enumerate(kwargs["out"].items): - assert "example_value" in x.proxy.node.meta # type: ignore[attr-defined] - fake_out = x.proxy.node.meta["example_value"] # type: ignore[attr-defined] - if not torch._prims_common.is_contiguous(fake_out): - # It's difficult to handle strides correctly in functionalization - # when calling an out= op with a non-contiguous out argument - unimplemented( - "out= op was called where some of the output tensors were non-contiguous" - ) - else: - unimplemented(f"out variant of {type(kwargs['out'])}") + else: + unimplemented(f"out variant of {type(kwargs['out'])}") - return tensor_variable + return tensor_variable def _call_ntuple(self, tx: "InstructionTranslator", args, kwargs): """inline behavior of torch.nn.modules.utils._ntuple""" @@ -1146,3 +1157,12 @@ def _nn_param_via_prefix_insert(tx: "InstructionTranslator", data, requires_grad source ) return result + + def torch_function_override_enabled(self, tx, args, kwargs): + return ( + self.get_function() in get_overridable_functions() + or isinstance( + self.get_function(), + (torch._ops.OpOverload, torch._ops.OpOverloadPacket), + ) + ) and can_dispatch_torch_function(tx, args, kwargs) diff --git a/torch/_dynamo/variables/torch_function.py b/torch/_dynamo/variables/torch_function.py index 4b3188507fe4b5..6e52cc688e0309 100644 --- a/torch/_dynamo/variables/torch_function.py +++ b/torch/_dynamo/variables/torch_function.py @@ -1,8 +1,11 @@ # mypy: ignore-errors +import collections +import contextlib import inspect -from typing import Dict, List, TYPE_CHECKING +from typing import Deque, Dict, List, TYPE_CHECKING +import torch._C import torch.utils._pytree as pytree from torch._guards import Source from torch.overrides import _get_overloaded_args, get_default_nowrap_functions @@ -15,6 +18,7 @@ from .base import VariableTracker from .constant import ConstantVariable from .ctx_manager import ContextWrappingVariable +from .lazy import LazyVariableTracker from .lists import TupleVariable from .tensor import TensorSubclassVariable, TensorVariable from .user_defined import UserDefinedObjectVariable @@ -59,6 +63,67 @@ IGNORED_MODES = {DeviceContext} +class SymbolicTorchFunctionState: + def __init__(self, py_stack): + # This is annoyingly complicated because of how the torch function subclass + mode C API was designed + # There are two exposed C knobs here as contexts: torch._C.DisableTorchFunction and torch._C.DisableTorchFunctionSubclass + # These are their definitions: + # 1) torch._C._is_torch_function_enabled indicates that neither of the above knobs have been entered + # (if either are entered, this will be False) + # 2) torch._C._is_torch_function_mode_enabled indicates that either the torch mode stack is empty OR + # torch._C.DisableTorchFunction has been entered + # To disambiguate these and keep myself sane I added a C API to check whether all torch function + # concepts (modes and subclasses) are enabled. + # This only returns true iff we have not entered torch._C.DisableTorchFunction and allows us to separate + # the stack length from the enablement state of torch function modes. + # This is important because now if a mode is pushed while dynamo is tracing, we know whether + # or not torch function modes are enabled and whether we should trace it. + self.torch_function_subclass_enabled = torch._C._is_torch_function_enabled() + + # This differs from the C API of the same name + # this will only be false iff we have entered torch._C.DisableTorchFunction + # and does not take into account the mode stack length, while the C API bundles these + # two concepts + self.torch_function_mode_enabled = ( + not torch._C._is_torch_function_all_disabled() + ) + + self.cur_mode = None + + TorchFunctionModeStackVariable.reset() + + self.mode_stack: Deque[TorchFunctionModeVariable] = collections.deque() + + for i, val in enumerate(py_stack): + self.mode_stack.append( + LazyVariableTracker.create(val, source=TorchFunctionModeStackSource(i)) + ) + + def in_torch_function_mode(self): + return len(self.mode_stack) > 0 + + def pop_torch_function_mode(self): + return self.mode_stack.pop() + + def push_torch_function_mode(self, mode_var): + self.mode_stack.append(mode_var) + + def call_torch_function_mode(self, tx, fn, types, args, kwargs): + with self._pop_mode_for_inlining() as cur_mode: + return cur_mode.call_torch_function(tx, fn, types, args, kwargs) + + @contextlib.contextmanager + def _pop_mode_for_inlining(self): + old_mode = self.cur_mode + self.cur_mode = self.pop_torch_function_mode() + try: + yield self.cur_mode + finally: + mode = self.cur_mode + self.cur_mode = old_mode + self.push_torch_function_mode(mode) + + class TorchFunctionModeStackVariable(VariableTracker): """Fake VT to use as a dummy object, indicating the presence of torch function mode stack mutation""" @@ -88,19 +153,20 @@ def reset(cls): def register_mutation(cls, tx: "InstructionTranslator"): if cls.stack_value_singleton not in tx.output.side_effects: var = cls( - source=Source(), symbolic_stack=tx.symbolic_torch_function_mode_stack + source=Source(), + symbolic_stack=tx.symbolic_torch_function_state.mode_stack, ) tx.output.side_effects.track_mutable(cls.stack_value_singleton, var) tx.output.side_effects.mutation(var) @classmethod def register_device_context_insertion(cls, tx: "InstructionTranslator"): - stack = tx.symbolic_torch_function_mode_stack + stack = tx.symbolic_torch_function_state.mode_stack if stack and cls.is_device_context(stack[0]): return else: cls.offset += 1 - tx.symbolic_torch_function_mode_stack.insert( + stack.insert( 0, TorchFunctionModeVariable( None, source=TorchFunctionModeStackSource(-cls.offset) @@ -109,7 +175,7 @@ def register_device_context_insertion(cls, tx: "InstructionTranslator"): @classmethod def clear_default_device(cls, tx: "InstructionTranslator"): - stack = tx.symbolic_torch_function_mode_stack + stack = tx.symbolic_torch_function_state.mode_stack if stack and cls.is_device_context(stack[0]): stack.popleft() cls.offset -= 1 @@ -124,23 +190,39 @@ def get_mode_index(cls, ind): class TorchFunctionModeVariable(ContextWrappingVariable): - def __init__(self, value, **kwargs): + def __init__(self, value, source=None, **kwargs): super().__init__(value, **kwargs) self.value = value - - @staticmethod - def get_global_mangled_name(tx, val): - return get_safe_global_name( - tx, f"__torch_function_mode_{val.__class__.__name__}", val - ) + self.cm_obj = value # needed for BC with calling enter from CM code + self.source = source def reconstruct(self, codegen): - # We don't support locally created torch function modes yet + # This shouldn't be called unless we have a source assert self.source self.source.reconstruct(codegen) - def _call_func(self, tx, values): - unimplemented("torch function mode context manager is not supported yet") + def module_name(self): + return self.value.__module__ + + def fn_name(self): + return type(self.value).__name__ + + def python_type(self): + return type(self.value) + + def call_torch_function(self, tx: "InstructionTranslator", fn, types, args, kwargs): + return call_torch_function( + tx, + self, + build_torch_function_fn(tx, self.value, self.source), + fn, + types, + args, + kwargs, + ) + + def _call_func(self, tx: "InstructionTranslator", values): + unimplemented("enter/exit for torch function mode NYI") def _get_all_args(args, kwargs): @@ -231,9 +313,13 @@ def build_torch_function_fn(tx: "InstructionTranslator", value, source): def can_dispatch_torch_function(tx: "InstructionTranslator", args, kwargs): - return tx.output.torch_function_enabled and any( + has_overridden_args = any( has_torch_function(arg) for arg in _get_all_args(args, kwargs) ) + tf_state = tx.symbolic_torch_function_state + return (has_overridden_args and tf_state.torch_function_subclass_enabled) or ( + tf_state.torch_function_mode_enabled and tf_state.in_torch_function_mode() + ) def dispatch_torch_function(tx: "InstructionTranslator", fn, args, kwargs): @@ -245,11 +331,20 @@ def dispatch_torch_function(tx: "InstructionTranslator", fn, args, kwargs): _get_subclass_type, ) + types = TupleVariable([_get_subclass_type_var(tx, arg) for arg in overloaded_args]) + + if tx.symbolic_torch_function_state.in_torch_function_mode(): + res = tx.symbolic_torch_function_state.call_torch_function_mode( + tx, fn, types, args, kwargs + ) + if not (isinstance(res, ConstantVariable) and res.value is NotImplemented): + return res + for arg in overloaded_args: res = arg.call_torch_function( tx, fn, - TupleVariable([_get_subclass_type_var(tx, arg) for arg in overloaded_args]), + types, args, kwargs, ) diff --git a/torch/_dynamo/variables/user_defined.py b/torch/_dynamo/variables/user_defined.py index 34e1d0d10c9f72..d069d2e060a108 100644 --- a/torch/_dynamo/variables/user_defined.py +++ b/torch/_dynamo/variables/user_defined.py @@ -82,11 +82,6 @@ def is_forbidden_context_manager(ctx): from _pytest.python_api import RaisesContext from _pytest.recwarn import WarningsChecker - # TODO mlazos: Temporary to get this stack to pass - # remove in subsequent PR - from torch.overrides import BaseTorchFunctionMode - - f_ctxs.append(BaseTorchFunctionMode) f_ctxs.append(RaisesContext) f_ctxs.append(WarningsChecker) except ImportError: @@ -413,7 +408,6 @@ def call_function( and self.source and not is_forbidden_context_manager(self.value) ): - # import here to avoid an unfortunate circular dependency. from .ctx_manager import GenericContextWrappingVariable cm_obj = tx.output.side_effects.track_object_new( @@ -421,7 +415,6 @@ def call_function( ) cm_obj.call_method(tx, "__init__", args, kwargs) return cm_obj - elif is_namedtuple_cls(self.value): fields = namedtuple_fields(self.value) # check if this a quasi-namedtuple or a real one diff --git a/torch/_export/non_strict_utils.py b/torch/_export/non_strict_utils.py index 5c7331659d26a0..ef15e5fea9e973 100644 --- a/torch/_export/non_strict_utils.py +++ b/torch/_export/non_strict_utils.py @@ -506,7 +506,11 @@ class _NonStrictTorchFunctionHandler(torch.overrides.TorchFunctionMode): def __torch_function__(self, func, types, args=(), kwargs=None): kwargs = kwargs or {} - if log.isEnabledFor(logging.DEBUG) and config.extended_debug_current_loc: + if ( + not torch.compiler.is_dynamo_compiling() + and log.isEnabledFor(logging.DEBUG) + and config.extended_debug_current_loc + ): frame = _find_user_code_frame() if frame is not None: log.debug( From a72c29acab6c7cc38fa1c5e6f2e8b395e6d314e8 Mon Sep 17 00:00:00 2001 From: Michael Lazos Date: Sat, 14 Sep 2024 03:23:35 -0700 Subject: [PATCH 0899/1018] [Dynamo] Support thread local setattr (#135443) In preparation for tracing through DeviceContext (https://github.com/pytorch/pytorch/blob/defb515306fc53ec62e92937a5a76fa5cbc05b84/torch/utils/_device.py#L66) This PR adds support for calling the setattr of thread local objects. These objects have a slots impl, and since this doesn't appear to have any side effects, we call this setattr impl when replaying mutations, since calling `object.__setattr__` on these objects results in a type error. Pull Request resolved: https://github.com/pytorch/pytorch/pull/135443 Approved by: https://github.com/anijain2305 ghstack dependencies: #134732, #133137 --- test/dynamo/test_misc.py | 15 +++++++++++++++ torch/_dynamo/variables/user_defined.py | 5 +++-- 2 files changed, 18 insertions(+), 2 deletions(-) diff --git a/test/dynamo/test_misc.py b/test/dynamo/test_misc.py index 9d2bb374621134..96047d3436ebe1 100644 --- a/test/dynamo/test_misc.py +++ b/test/dynamo/test_misc.py @@ -3380,6 +3380,21 @@ def fn4(x) -> None: self.assertTrue(same(obj41.y, obj42.y)) self.assertEqual(cnts.frame_count, 1) + def test_thread_local_setattr(self): + from threading import local + + loc = local() + + @torch.compile(fullgraph=True) + def fn(x, l): + l.x = x + return x + 1 + + x = torch.ones(2, 2) + fn(x, loc) + + self.assertTrue(loc.x is x) + def test_user_defined_class_name(self): class MyClassFoo: pass diff --git a/torch/_dynamo/variables/user_defined.py b/torch/_dynamo/variables/user_defined.py index d069d2e060a108..ad92277cf70ff4 100644 --- a/torch/_dynamo/variables/user_defined.py +++ b/torch/_dynamo/variables/user_defined.py @@ -9,6 +9,7 @@ import itertools import random import sys +import threading import types import warnings from typing import Dict, Generic, List, TYPE_CHECKING @@ -704,7 +705,7 @@ def call_method( if method is object.__init__: return ConstantVariable.create(None) - if is_standard_setattr(method): + if is_standard_setattr(method) or isinstance(self.value, threading.local): return self.method_setattr_standard(tx, *args, **kwargs) # [NOTE] OrderedDict, dict subtypes must always have source @@ -802,7 +803,7 @@ def method_setattr_standard(self, tx: "InstructionTranslator", name, value): def needs_slow_setattr(self): return not is_standard_setattr( inspect.getattr_static(self.value, "__setattr__", None) - ) + ) and not isinstance(self.value, threading.local) def unpack_var_sequence(self, tx): if ( From 0371158ec325bd3b5da086e015ba95960101e8c8 Mon Sep 17 00:00:00 2001 From: Michael Lazos Date: Sat, 14 Sep 2024 03:23:36 -0700 Subject: [PATCH 0900/1018] [Dynamo] Simplify torch function mode stack guard (#135444) The semantics of ignored modes previously had edge cases, this eliminates these by in essence filtering any ignored modes out of both the ref stack and the current torch function mode stack. This is purely to fix complexity in #135422. The ignored modes handling will be removed in a future PR after https://github.com/pytorch/pytorch/pull/135422 lands, since we will then trace through DeviceContexts vs inserting them into the graph which needed these extra workarounds for correctness. Pull Request resolved: https://github.com/pytorch/pytorch/pull/135444 Approved by: https://github.com/anijain2305, https://github.com/williamwen42 ghstack dependencies: #134732, #133137, #135443 --- test/dynamo/test_modes.py | 2 ++ torch/_dynamo/guards.py | 10 ++++---- torch/csrc/dynamo/guards.cpp | 44 +++++++++++++++++++++++++++++------- 3 files changed, 44 insertions(+), 12 deletions(-) diff --git a/test/dynamo/test_modes.py b/test/dynamo/test_modes.py index bad5a788903440..fa4c23fd320ddb 100644 --- a/test/dynamo/test_modes.py +++ b/test/dynamo/test_modes.py @@ -68,9 +68,11 @@ def tearDownClass(cls): def setUp(self): torch.set_default_device(None) + torch._dynamo.reset() def tearDown(self): torch.set_default_device(None) + torch._dynamo.reset() def _run_torch_function_mode_guard_test(self): class TestMode1(BaseTorchFunctionMode): diff --git a/torch/_dynamo/guards.py b/torch/_dynamo/guards.py index 4437ef4038190d..1544bdd6dc697b 100644 --- a/torch/_dynamo/guards.py +++ b/torch/_dynamo/guards.py @@ -2663,12 +2663,14 @@ def make_torch_function_mode_stack_guard(intial_stack): def check_torch_function_mode_stack(): cur_stack = get_torch_function_mode_stack() - if len(cur_stack) != len(types): + + types_ = [ty for ty in types if ty not in IGNORED_MODES] + cur_stack_ = [mode for mode in cur_stack if type(mode) not in IGNORED_MODES] + + if len(cur_stack_) != len(types_): return False - for ty, mode in zip(types, cur_stack): - if ty in IGNORED_MODES: - continue + for ty, mode in zip(types_, cur_stack_): if ty != type(mode): return False diff --git a/torch/csrc/dynamo/guards.cpp b/torch/csrc/dynamo/guards.cpp index afc3182c632640..771de44427b669 100644 --- a/torch/csrc/dynamo/guards.cpp +++ b/torch/csrc/dynamo/guards.cpp @@ -2523,7 +2523,8 @@ class TORCH_FUNCTION_MODE_STACK : public LeafGuard { Py_ssize_t len = PyList_Size(initial_stack.ptr()); for (Py_ssize_t idx = 0; idx < len; idx++) { PyObject* mode = PyList_GetItem(initial_stack.ptr(), idx); // borrowed ref - this->_ref_stack.push_back(Py_TYPE(mode)); + auto type = Py_TYPE(mode); + this->_ref_stack.push_back(type); } len = PyList_Size(ignored_types.ptr()); @@ -2543,29 +2544,56 @@ class TORCH_FUNCTION_MODE_STACK : public LeafGuard { bool check_nopybind(PyObject* value) override { // Ignore value arg, only used to satisfy the interface size_t ref_ind = 0; - int64_t len = at::impl::PythonTorchFunctionTLS::stack_len(); + const int64_t len = at::impl::PythonTorchFunctionTLS::stack_len(); const size_t ref_stack_size = this->_ref_stack.size(); - for (int64_t idx = 0; idx < len; idx++) { + int64_t idx = 0; + while ((idx < len) && (ref_ind < ref_stack_size)) { std::shared_ptr mode = at::impl::PythonTorchFunctionTLS::get_stack_at(idx); PyTypeObject* mode_type = Py_TYPE(mode->ptr(getPyInterpreter())); + bool act_ignored = this->_ignored_types.count(mode_type) > 0; + bool ref_ignored = + this->_ignored_types.count(this->_ref_stack.at(ref_ind)) > 0; // skip ignored types - if (this->_ignored_types.count(mode_type) > 0) { + if (act_ignored && ref_ignored) { + idx++; + ref_ind++; + continue; + } else if (ref_ignored) { + ref_ind++; + continue; + } else if (act_ignored) { + idx++; continue; } // if we already have more non-ignored modes than the ref stack // or if the mode doesn't match at the current index, return false - else if ( - (ref_stack_size == 0) || (ref_ind > ref_stack_size - 1) || - mode_type != _ref_stack[ref_ind]) { + else if (mode_type != _ref_stack.at(ref_ind)) { return false; } ref_ind++; + idx++; + } + + for (; ref_ind < ref_stack_size; ref_ind++) { + if (!(this->_ignored_types.count(this->_ref_stack.at(ref_ind)) > 0)) { + return false; + } + } + + for (; idx < len; idx++) { + std::shared_ptr mode = + at::impl::PythonTorchFunctionTLS::get_stack_at(idx); + + PyTypeObject* mode_type = Py_TYPE(mode->ptr(getPyInterpreter())); + if (!(this->_ignored_types.count(mode_type) > 0)) { + return false; + } } - return ref_ind == this->_ref_stack.size(); + return ref_ind == ref_stack_size && idx == len; } private: From a87bf5b88d2602459ff2f52126e03e940ba6aa84 Mon Sep 17 00:00:00 2001 From: Michael Lazos Date: Sat, 14 Sep 2024 03:23:36 -0700 Subject: [PATCH 0901/1018] [Dynamo] Trace enter/exit of TorchFunctionModes (#135422) This PR implements tracing of with contexts with TorchFunction modes which have the default enter/exit behavior (ie pushing/popping the mode) Typically the bytecode for a context manager looks like this during a graph break: 1. graph call 2. enter context 3. unsupported code 4. exit context 5. resume call resume fn structure: 1. enter context 2. jump ... 3. exit context The issue with torch function modes is that side effects will replay any mutations to the torch function stack performed during tracing. So, we do not need to enter and exit around the unsupported code in the original function (doing so would result in a duplicate torch function mode entry during execution of the unsupported code), and we don't need to enter again in the resume function (the mode that was pushed from the side effects bytecode would still be on the stack). So for torch function modes the structure of our output code is this: 1. graph call 2. mutate tf mode stack to replay mutations 4. unsupported code 5. on exception restore stack 6. resume function Then our resume fn looks like this: 1. no-op enter torch function mode 2. jump 3. exit tf mode To implement the no-op enter of the torch function mode I added torch function mode in polyfill which no-op enters, but normally exits. This is needed because we still want to trace the with context in the resume function, and exit properly (the exit instructions will still be in the function, so we need to generate instructions to set up the context). Separately from the bytecode, dynamo also tracks contexts on the block stack, which is how the SETUP_* instructions are implemented. Naturally at a graph break, we exit these block stacks to properly reset the contexts entirely, so that we can re-enter around the unsupported code soundly. However once again, in the torch function mode case, in the event of a graph we do not want to perform any exit side effects because we want to preserve the state of the mode stack as is so that we will properly update the stack with bytecode mentioned in the first section. If we exited here, dynamo would pop the mode off of the symbolic stack, and not update the true python torch function mode stack with the suffix bytecode. All in all, for torch function modes we enter exactly once, update the global torch function mode stack with side effects bytecode, re-read this stack when compiling the resume function, and exit exactly once in the resume function. This matches the semantics of eager exactly. Pull Request resolved: https://github.com/pytorch/pytorch/pull/135422 Approved by: https://github.com/williamwen42 ghstack dependencies: #134732, #133137, #135443, #135444 --- test/dynamo/test_modes.py | 88 ++++++++++++++++++ torch/_dynamo/convert_frame.py | 6 +- torch/_dynamo/output_graph.py | 6 +- torch/_dynamo/polyfills/__init__.py | 20 ++++ torch/_dynamo/resume_execution.py | 104 +++++++++++++++++++++ torch/_dynamo/side_effects.py | 11 +++ torch/_dynamo/symbolic_convert.py | 30 ++++-- torch/_dynamo/testing.py | 1 + torch/_dynamo/trace_rules.py | 2 +- torch/_dynamo/variables/builder.py | 20 ++-- torch/_dynamo/variables/ctx_manager.py | 12 +++ torch/_dynamo/variables/torch.py | 21 ++++- torch/_dynamo/variables/torch_function.py | 108 ++++++++++++++++++++-- torch/_dynamo/variables/user_defined.py | 14 ++- 14 files changed, 409 insertions(+), 34 deletions(-) diff --git a/test/dynamo/test_modes.py b/test/dynamo/test_modes.py index fa4c23fd320ddb..63e06a515ed186 100644 --- a/test/dynamo/test_modes.py +++ b/test/dynamo/test_modes.py @@ -461,6 +461,94 @@ def fn(x): self.assertEqual(expected, actual) + def test_torch_function_mode_enter_exit(self): + def fn(x, y): + with TestMode(): + o = torch.add(x, 3) + + return torch.add(o, y) + + inp = (torch.ones(2, 2) + 1, torch.ones(2, 2) + 2) + fn_opt = torch.compile(fn, fullgraph=True) + + expected = fn(*inp) + actual = fn_opt(*inp) + + self.assertEqual(expected, actual) + + def test_torch_function_mode_graph_break(self): + def fn(x, y): + with TestMode(): + torch._dynamo.graph_break() + o = torch.add(x, 3) + + return torch.add(o, y) + + inp = (torch.ones(2, 2) + 1, torch.ones(2, 2) + 2) + fn_opt = torch.compile(fn) + + expected = fn(*inp) + actual = fn_opt(*inp) + + self.assertEqual(expected, actual) + + def test_torch_function_mode_and_pop_graph_break(self): + def fn(x, y): + with TestMode(): + z = _pop_torch_function_stack() + torch._dynamo.graph_break() + _push_on_torch_function_stack(z) + o = torch.add(x, 3) + + return torch.add(o, y) + + inp = (torch.ones(2, 2) + 1, torch.ones(2, 2) + 2) + fn_opt = torch.compile(fn) + + expected = fn(*inp) + actual = fn_opt(*inp) + + self.assertEqual(expected, actual) + + def test_torch_function_mode_restore_on_exc(self): + @torch._dynamo.disable() + def err(): + raise RuntimeError("test") + + @torch.compile() + def fn(x): + with TestMode(): + x += 1 + err() + x += 2 + return x + + try: + fn(torch.ones(2, 2)) + except RuntimeError: + pass + self.assertEqual(_len_torch_function_stack(), 0) + + def test_torch_function_mode_and_pop_graph_break_mutation(self): + def fn(x, y): + with TestMode(): + z = _pop_torch_function_stack() + z.y = 5 + torch._dynamo.graph_break() + _push_on_torch_function_stack(z) + o = torch.add(x, 3) + o = torch.mul(o, z.y) + + return torch.add(o, y) + + inp = (torch.ones(2, 2) + 1, torch.ones(2, 2) + 2) + fn_opt = torch.compile(fn) + + expected = fn(*inp) + actual = fn_opt(*inp) + + self.assertEqual(expected, actual) + if __name__ == "__main__": from torch._dynamo.test_case import run_tests diff --git a/torch/_dynamo/convert_frame.py b/torch/_dynamo/convert_frame.py index b3e4a7c268c0fe..171af02e564b36 100644 --- a/torch/_dynamo/convert_frame.py +++ b/torch/_dynamo/convert_frame.py @@ -112,6 +112,7 @@ troubleshooting_url, write_record_to_file, ) +from .variables.torch_function import torch_function_mode_stack_state_mgr np: Optional[ModuleType] @@ -210,15 +211,18 @@ def _fn(*args: _P.args, **kwargs: _P.kwargs) -> _T: prior_fwd_from_src = torch.fx.graph_module._forward_from_src torch.fx.graph_module._forward_from_src = fx_forward_from_src_skip_result cleanup = setup_compile_debug() - exit_stack = contextlib.ExitStack() exit_stack.enter_context( torch.fx._symbolic_trace._maybe_revert_all_patches() ) + exit_stack.enter_context(torch_function_mode_stack_state_mgr) try: return fn(*args, **kwargs) finally: cleanup.close() + assert ( + torch._C._len_torch_function_stack() == 0 + ), "Torch function mode stack state changed while dynamo tracing, please report a bug" exit_stack.close() torch._C._set_grad_enabled(prior_grad_mode) torch.autograd.grad_mode._enter_inference_mode(prior_inference_mode) diff --git a/torch/_dynamo/output_graph.py b/torch/_dynamo/output_graph.py index ba8cfa5a03dfba..0ea19e1cad6586 100644 --- a/torch/_dynamo/output_graph.py +++ b/torch/_dynamo/output_graph.py @@ -78,7 +78,6 @@ get_instruction_source_311, get_locals_to_steal, get_static_address_type, - get_torch_function_mode_stack, graph_break_reasons, increment_op_count, lazy_format_graph_code, @@ -250,6 +249,7 @@ def __init__( local_scope: Scope, global_scope: Scope, f_code, + torch_function_mode_stack, ): super().__init__() self.tracers = [SubgraphTracer(self, export_root=export)] @@ -368,7 +368,7 @@ def __init__( # This returns false if TF Overall (both mode and subclass) is disabled OR that TF Mode stack is empty self.torch_function_mode_enabled = torch._C._is_torch_function_mode_enabled() # This records the initial torch function mode stack for guarding - self.torch_function_mode_stack = get_torch_function_mode_stack() + self.torch_function_mode_stack = torch_function_mode_stack # Tracks if the output graph has a user defined allowed function in the # graph. This is used later to determine if we should fallback to eager @@ -1020,7 +1020,7 @@ def append_prefix_insts(): prefix_insts.clear() for block in reversed(tx.block_stack): - block.exit(tx) + block.exit(tx, is_graph_break=reason.graph_break) self.cleanup_graph() tx.prune_dead_locals() diff --git a/torch/_dynamo/polyfills/__init__.py b/torch/_dynamo/polyfills/__init__.py index 2b3f38920a0c2a..5b2812bc08c9e5 100644 --- a/torch/_dynamo/polyfills/__init__.py +++ b/torch/_dynamo/polyfills/__init__.py @@ -25,6 +25,26 @@ sys as sys, ) +from torch.overrides import BaseTorchFunctionMode + + +# These classes handle support for TorchFunctionModes across +# graph breaks +# Today the TorchFunctionMode enter (for the classes we support) +# simply pushes the mode onto the stack. Since after this occurs +# the stack is mutated, and we replay these mutations, we don't need +# any cleanup logic to be run once the graph break occurs, we simply replay +# these mutations to ensure at the graph break the torch function mode stack is correct +# and reconstruct the torch function mode stack normally +# when we compile the resume function on the other side of the break. +# However, to ensure we exit properly +# in the resume function, we need to re-enter the contexts as we do other contexts. +# These contexts do nothing on enter, but provide the correct exit logic to ensure +# the stack state is correct. +class NoEnterTorchFunctionMode(BaseTorchFunctionMode): + def __enter__(self): + pass + def index(iterator, item, start=0, end=None): from itertools import islice diff --git a/torch/_dynamo/resume_execution.py b/torch/_dynamo/resume_execution.py index 132e9e4081bceb..6f7db66514c487 100644 --- a/torch/_dynamo/resume_execution.py +++ b/torch/_dynamo/resume_execution.py @@ -6,6 +6,7 @@ from typing import Any, cast, Dict, List, Optional, Tuple from .bytecode_transformation import ( + add_push_null, create_call_function, create_call_method, create_dup_top, @@ -48,6 +49,109 @@ class ReenterWith: stack_index: int target_values: Optional[Tuple[Any, ...]] = None + def try_except_torch_function_mode(self, code_options, cleanup: List[Instruction]): + """ + Codegen based off of: + try: + (rest) + except: + (restore previous stack) + + """ + from .variables.torch_function import get_prev_stack_var_name + + except_jump_target = create_instruction( + "NOP" if sys.version_info < (3, 11) else "PUSH_EXC_INFO" + ) + cleanup_complete_jump_target = create_instruction("NOP") + + setup_finally: List[Instruction] = [] + + if sys.version_info < (3, 11): + setup_finally.append( + create_instruction("SETUP_FINALLY", target=except_jump_target) + ) + else: + exn_tab_begin = create_instruction("NOP") + exn_tab_end = create_instruction("NOP") + exn_tab_begin.exn_tab_entry = InstructionExnTabEntry( + exn_tab_begin, + exn_tab_end, + except_jump_target, + self.stack_index + 1, + False, + ) + setup_finally.append(exn_tab_begin) + + def create_reset(): + insts = [ + create_instruction( + "LOAD_GLOBAL", argval="__import_torch_dot__dynamo_dot_utils" + ), + create_instruction("LOAD_ATTR", argval="set_torch_function_mode_stack"), + ] + add_push_null(insts) + return [ + *insts, + create_instruction("LOAD_FAST", argval=get_prev_stack_var_name()), + *create_call_function(1, False), + create_instruction("POP_TOP"), + ] + + if sys.version_info < (3, 9): + epilogue = [ + create_instruction("POP_BLOCK"), + create_instruction("JUMP_FORWARD", target=cleanup_complete_jump_target), + except_jump_target, + *create_reset(), + create_instruction("POP_TOP"), + create_instruction("POP_TOP"), + create_instruction("POP_TOP"), + *create_reset(), + create_instruction("RAISE_VARARGS", argval=0), + create_instruction("POP_EXCEPT", argval=0), + create_instruction("END_FINALLY"), + cleanup_complete_jump_target, + ] + elif sys.version_info < (3, 11): + epilogue = [ + create_instruction("POP_BLOCK"), + create_instruction("JUMP_FORWARD", target=cleanup_complete_jump_target), + except_jump_target, + create_instruction("POP_TOP"), + create_instruction("POP_TOP"), + create_instruction("POP_TOP"), + *create_reset(), + create_instruction("RAISE_VARARGS", argval=0), + create_instruction("POP_EXCEPT", argval=0), + cleanup_complete_jump_target, + ] + else: + finally_exn_tab_end = create_instruction("RAISE_VARARGS", argval=0) + finally_exn_tab_target = create_instruction("COPY", arg=3) + except_jump_target.exn_tab_entry = InstructionExnTabEntry( + except_jump_target, + finally_exn_tab_end, + finally_exn_tab_target, + self.stack_index + 2, + True, + ) + epilogue = [ + exn_tab_end, + create_instruction("JUMP_FORWARD", target=cleanup_complete_jump_target), + except_jump_target, # PUSH_EXC_INFO + create_instruction("POP_TOP"), + *create_reset(), + finally_exn_tab_end, + finally_exn_tab_target, # COPY 3 + create_instruction("POP_EXCEPT"), + create_instruction("RERAISE", arg=1), # RERAISE 1 + cleanup_complete_jump_target, + ] + + cleanup[:] = epilogue + cleanup + return setup_finally + # If we do not want to destroy the stack, we can do the same thing as a # `SETUP_WITH` block, only that we store the context manager in a local_symbol def try_except(self, code_options, cleanup: List[Instruction]): diff --git a/torch/_dynamo/side_effects.py b/torch/_dynamo/side_effects.py index fc1bd976ff57df..fd5b54e4f7f7af 100644 --- a/torch/_dynamo/side_effects.py +++ b/torch/_dynamo/side_effects.py @@ -593,11 +593,22 @@ def codegen_update_mutated(self, cg: PyCodegen): elif isinstance( var, variables.torch_function.TorchFunctionModeStackVariable ): + # Needed in the finally block for stack restoration + cg.add_push_null( + lambda: cg.load_import_from( + utils.__name__, "get_torch_function_mode_stack" + ) + ) + cg.call_function(0, False) + name = variables.torch_function.get_prev_stack_var_name() + cg.code_options["co_varnames"] += (name,) + cg.append_output(create_instruction("STORE_FAST", argval=name)) cg.add_push_null( lambda: cg.load_import_from( utils.__name__, "set_torch_function_mode_stack" ) ) + cg.foreach(var.symbolic_stack) cg.append_output( create_instruction("BUILD_LIST", arg=len(var.symbolic_stack)) diff --git a/torch/_dynamo/symbolic_convert.py b/torch/_dynamo/symbolic_convert.py index 415179a5ed32a6..61acaec995757a 100644 --- a/torch/_dynamo/symbolic_convert.py +++ b/torch/_dynamo/symbolic_convert.py @@ -267,13 +267,12 @@ def resume_fn(self): else: return ReenterWith(self.stack_index) - def exit(self, tx): - if hasattr(self, "graph_break") and isinstance( - self.with_context, TorchFunctionModeVariable - ): - return + def exit(self, tx, is_graph_break): assert self.with_context is not None - return self.with_context.exit(tx) + if ( + is_graph_break and self.with_context.exit_on_graph_break() + ) or not is_graph_break: + return self.with_context.exit(tx) class ReturnValueOp(Exception): @@ -639,10 +638,17 @@ def handle_graph_break( cleanup: List[Instruction] = [] # Reconstruct the context variable CLASS in the block stack for b in self.block_stack: + # Don't exit any modes we have entered, + # output bytecode will mutate the tf mode stack accordingly + if isinstance(b.with_context, TorchFunctionModeVariable): + cg.extend_output( + b.resume_fn().try_except_torch_function_mode( + cg.code_options, cleanup + ) + ) + continue assert b.with_context is not None - assert isinstance( - b.with_context, (ContextWrappingVariable, TorchFunctionModeVariable) - ) + assert isinstance(b.with_context, (ContextWrappingVariable)) b.with_context.reconstruct_type(cg) cg.extend_output(b.resume_fn().try_except(cg.code_options, cleanup)) self.output.add_output_instructions(cg.get_instructions()) @@ -2295,7 +2301,10 @@ def setup_or_before_with(self, inst): ): unimplemented(f"{inst.opname} {ctx}") - if isinstance(ctx, GenericContextWrappingVariable): + if ( + isinstance(ctx, GenericContextWrappingVariable) + and not ctx.supports_graph_breaks() + ): self.generic_context_manager_depth += 1 # Need this redundant check for mypy @@ -2668,6 +2677,7 @@ def __init__( local_scope=f_locals, global_scope=f_globals, f_code=f_code, + torch_function_mode_stack=torch_function_mode_stack, ), instructions=instructions, f_locals=f_locals, diff --git a/torch/_dynamo/testing.py b/torch/_dynamo/testing.py index 4922b521bada34..704a3889707233 100644 --- a/torch/_dynamo/testing.py +++ b/torch/_dynamo/testing.py @@ -163,6 +163,7 @@ def insert_nops(instructions, code_options): local_scope=locals(), global_scope=globals(), f_code=frame.f_code, + torch_function_mode_stack=[], ) return GuardedCode(code, CheckFunctionManager(graph).check_fn, CompileId(0, 0)) diff --git a/torch/_dynamo/trace_rules.py b/torch/_dynamo/trace_rules.py index 1a206ae339c2db..ed81b26191b4de 100644 --- a/torch/_dynamo/trace_rules.py +++ b/torch/_dynamo/trace_rules.py @@ -304,6 +304,7 @@ "torch.fx.experimental.symbolic_shapes.guard_size_oblivious": TorchInGraphFunctionVariable, "torch.cuda._get_device_properties": TorchInGraphFunctionVariable, "torch.utils.hooks.BackwardHook": TorchInGraphFunctionVariable, + "torch.set_default_device": UserFunctionVariable, "torch.sparse_bsc_tensor": SkipFunctionVariable, "torch.sparse_bsr_tensor": SkipFunctionVariable, "torch.sparse_csc_tensor": SkipFunctionVariable, @@ -2797,7 +2798,6 @@ "torch.random.initial_seed", "torch.random.seed", "torch.return_types.pytree_register_structseq", - "torch.set_default_device", "torch.set_default_dtype", "torch.set_default_tensor_type", "torch.set_deterministic_debug_mode", diff --git a/torch/_dynamo/variables/builder.py b/torch/_dynamo/variables/builder.py index 5a8a9b613b899d..193e227267ec3b 100644 --- a/torch/_dynamo/variables/builder.py +++ b/torch/_dynamo/variables/builder.py @@ -204,6 +204,7 @@ from .torch_function import ( build_torch_function_fn, TensorWithTFOverrideVariable, + torch_function_mode_stack_state_mgr, TorchFunctionModeVariable, ) from .user_defined import ( @@ -1668,15 +1669,16 @@ def wrap_numpy_ndarray(self, value): # but warning is not the end of the world assert isinstance(value.base, np.nditer) - try: - tensor_value = _util._try_convert_to_tensor(value) - if readonly: - from torch._prims_common import clone_preserve_strides - - tensor_value = clone_preserve_strides(tensor_value) - except NotImplementedError as e: - # failed to convert to tensor, graph break - unimplemented(str(e)) + with torch_function_mode_stack_state_mgr.temp_restore_stack(): + try: + tensor_value = _util._try_convert_to_tensor(value) + if readonly: + from torch._prims_common import clone_preserve_strides + + tensor_value = clone_preserve_strides(tensor_value) + except NotImplementedError as e: + # failed to convert to tensor, graph break + unimplemented(str(e)) # We do this because we want the full behavior of guarding the numpy ndarray as if it were # a tensor. It's a little annoying to make a VT to throw out, but there's so many side effects here diff --git a/torch/_dynamo/variables/ctx_manager.py b/torch/_dynamo/variables/ctx_manager.py index 8c4eb3dc4e715b..e19c4e254c647e 100644 --- a/torch/_dynamo/variables/ctx_manager.py +++ b/torch/_dynamo/variables/ctx_manager.py @@ -125,6 +125,12 @@ def call_function( if isinstance(args[0], UserFunctionVariable): return WrappedUserFunctionVariable(args[0], self) + def supports_graph_breaks(self): + return True + + def exit_on_graph_break(self): + return True + class GenericContextWrappingVariable(UserDefinedObjectVariable): # Some methods in ContextWrappingVariable assumes the arguments are @@ -183,6 +189,12 @@ def exit(self, tx: "InstructionTranslator", *args): tx.generic_context_manager_depth -= 1 return x + def supports_graph_breaks(self): + return False + + def exit_on_graph_break(self): + return True + class GradInplaceRequiresGradCtxManagerVariable(ContextWrappingVariable): """represents torch grad requries grad""" diff --git a/torch/_dynamo/variables/torch.py b/torch/_dynamo/variables/torch.py index f6cb565ef021c0..c1e0dec0fbc415 100644 --- a/torch/_dynamo/variables/torch.py +++ b/torch/_dynamo/variables/torch.py @@ -162,7 +162,17 @@ def get_overridable_functions(): from torch.overrides import get_overridable_functions as get_overridable_functions_ - return set(chain(*get_overridable_functions_().values())) + funcs = set(chain(*get_overridable_functions_().values())) + more = { + torch.ones, + torch.ones_like, + torch.zeros, + torch.zeros_like, + torch.empty, + torch.full, + } + funcs.update(more) + return funcs class BaseTorchVariable(VariableTracker): @@ -838,6 +848,13 @@ def handle_len_torch_function( len(tx.symbolic_torch_function_state.mode_stack) ) + @register(torch._C._get_function_stack_at) + def handle_get_stack_at(self, tx: "InstructionTranslator", *args, **kwargs): + assert len(args) == 1 and not kwargs + ind = args[0].as_python_constant() + assert ind >= 0 and ind < len(tx.symbolic_torch_function_state.mode_stack) + return tx.symbolic_torch_function_state.mode_stack[ind] + @register(torch.set_default_device) def handle_set_default_device( self, tx: "InstructionTranslator", *args, **kwargs @@ -855,7 +872,7 @@ def handle_set_default_device( else: TorchFunctionModeStackVariable.register_device_context_insertion(tx) - return None + return ConstantVariable.create(None) return handlers diff --git a/torch/_dynamo/variables/torch_function.py b/torch/_dynamo/variables/torch_function.py index 6e52cc688e0309..b7fd83a3d24f15 100644 --- a/torch/_dynamo/variables/torch_function.py +++ b/torch/_dynamo/variables/torch_function.py @@ -2,22 +2,35 @@ import collections import contextlib +import functools import inspect from typing import Deque, Dict, List, TYPE_CHECKING import torch._C import torch.utils._pytree as pytree from torch._guards import Source -from torch.overrides import _get_overloaded_args, get_default_nowrap_functions +from torch.overrides import ( + _get_overloaded_args, + get_default_nowrap_functions, + TorchFunctionMode, +) from torch.utils._device import DeviceContext from ..exc import unimplemented from ..guards import GuardBuilder, install_guard +from ..polyfills import NoEnterTorchFunctionMode from ..source import AttrSource, GlobalSource, TorchFunctionModeStackSource, TypeSource -from ..utils import get_safe_global_name, has_torch_function, is_tensor_base_attr_getter +from ..utils import ( + class_has_getattribute, + clear_torch_function_mode_stack, + get_safe_global_name, + has_torch_function, + is_tensor_base_attr_getter, + set_torch_function_mode_stack, +) from .base import VariableTracker from .constant import ConstantVariable -from .ctx_manager import ContextWrappingVariable +from .ctx_manager import GenericContextWrappingVariable from .lazy import LazyVariableTracker from .lists import TupleVariable from .tensor import TensorSubclassVariable, TensorVariable @@ -63,6 +76,39 @@ IGNORED_MODES = {DeviceContext} +@functools.lru_cache(None) +def get_prev_stack_var_name(): + from ..bytecode_transformation import unique_id + + return unique_id("___prev_torch_function_mode_stack") + + +# Used to clear/restore the python torch function mode stack and temporarily restore it as needed +class TorchFunctionModeStackStateManager: + def __init__(self): + self.stack = [] + + def __enter__(self): + self.stack = torch.overrides._get_current_function_mode_stack() + clear_torch_function_mode_stack() + + def __exit__(self, exc_type, exc_value, traceback): + set_torch_function_mode_stack(self.stack) + self.stack = [] + + @contextlib.contextmanager + def temp_restore_stack(self): + prev = torch.overrides._get_current_function_mode_stack() + set_torch_function_mode_stack(self.stack) + try: + yield + finally: + set_torch_function_mode_stack(prev) + + +torch_function_mode_stack_state_mgr = TorchFunctionModeStackStateManager() + + class SymbolicTorchFunctionState: def __init__(self, py_stack): # This is annoyingly complicated because of how the torch function subclass + mode C API was designed @@ -189,9 +235,26 @@ def get_mode_index(cls, ind): return ind + cls.offset -class TorchFunctionModeVariable(ContextWrappingVariable): +class TorchFunctionModeVariable(GenericContextWrappingVariable): + @staticmethod + def is_supported_torch_function_mode(ty): + # Supported in this sense means we can support graph breaks under the + # context. + # We are able to trace custom modes but if there are graph breaks under them + # and they have a custom __enter__/__exit__ we don't handle this for the + # same reason we don't handle generic context managers: there may be side effects + # that are now affected by executing the funtion across two frames instead of one + # Today we support the enter/exit of the default TorchFunctionMode as well as + # DeviceContext (which is used for set_default_device) + return issubclass(ty, (NoEnterTorchFunctionMode, DeviceContext)) or ( + not class_has_getattribute(ty) + and inspect.getattr_static(ty, "__enter__") == TorchFunctionMode.__enter__ + and inspect.getattr_static(ty, "__exit__") == TorchFunctionMode.__exit__ + ) + def __init__(self, value, source=None, **kwargs): - super().__init__(value, **kwargs) + if value is not None: + super().__init__(value, **kwargs) self.value = value self.cm_obj = value # needed for BC with calling enter from CM code self.source = source @@ -221,8 +284,39 @@ def call_torch_function(self, tx: "InstructionTranslator", fn, types, args, kwar kwargs, ) - def _call_func(self, tx: "InstructionTranslator", values): - unimplemented("enter/exit for torch function mode NYI") + def enter(self, tx): + from .torch import TorchInGraphFunctionVariable + + if isinstance(self.value, NoEnterTorchFunctionMode): + return ConstantVariable.create(None) + + TorchInGraphFunctionVariable( + torch._C._push_on_torch_function_stack + ).call_function(tx, [self], {}) + return ConstantVariable.create(None) + + def exit(self, tx: "InstructionTranslator", *args): + from .torch import TorchInGraphFunctionVariable + + TorchInGraphFunctionVariable(torch._C._pop_torch_function_stack).call_function( + tx, [], {} + ) + return ConstantVariable.create(None) + + def reconstruct_type(self, codegen): + ty = NoEnterTorchFunctionMode + codegen( + AttrSource( + codegen.tx.import_source(ty.__module__), + ty.__name__, + ) + ) + + def supports_graph_breaks(self): + return True + + def exit_on_graph_break(self): + return False def _get_all_args(args, kwargs): diff --git a/torch/_dynamo/variables/user_defined.py b/torch/_dynamo/variables/user_defined.py index ad92277cf70ff4..14499b4d2e4bf7 100644 --- a/torch/_dynamo/variables/user_defined.py +++ b/torch/_dynamo/variables/user_defined.py @@ -409,10 +409,22 @@ def call_function( and self.source and not is_forbidden_context_manager(self.value) ): + from torch.overrides import TorchFunctionMode + from .ctx_manager import GenericContextWrappingVariable + from .torch_function import TorchFunctionModeVariable + + if issubclass( + self.value, TorchFunctionMode + ) and TorchFunctionModeVariable.is_supported_torch_function_mode( + self.value + ): + var_cls = TorchFunctionModeVariable + else: + var_cls = GenericContextWrappingVariable cm_obj = tx.output.side_effects.track_object_new( - self.source, self.value, GenericContextWrappingVariable, {} + self.source, self.value, var_cls, {} ) cm_obj.call_method(tx, "__init__", args, kwargs) return cm_obj From 4019089e79a62b4132f2708aa28ec41ba08841b1 Mon Sep 17 00:00:00 2001 From: Michael Lazos Date: Sat, 14 Sep 2024 03:23:36 -0700 Subject: [PATCH 0902/1018] [Dynamo] Remove ignored modes workaround (#135502) Pull Request resolved: https://github.com/pytorch/pytorch/pull/135502 Approved by: https://github.com/anijain2305 ghstack dependencies: #134732, #133137, #135443, #135444, #135422 --- test/dynamo/test_modes.py | 65 ----------------------- torch/_dynamo/guards.py | 12 ++--- torch/_dynamo/utils.py | 10 +--- torch/_dynamo/variables/torch_function.py | 6 --- 4 files changed, 5 insertions(+), 88 deletions(-) diff --git a/test/dynamo/test_modes.py b/test/dynamo/test_modes.py index 63e06a515ed186..4d1f2bbea389ef 100644 --- a/test/dynamo/test_modes.py +++ b/test/dynamo/test_modes.py @@ -1,5 +1,4 @@ # Owner(s): ["module: dynamo"] -from unittest.mock import patch import torch import torch._dynamo.test_case @@ -107,70 +106,6 @@ def fn(x): fn(inp) self.assertEqual(cnt.frame_count, 4) - def _run_ignored_mode_types_test(self): - class IgnoredMode(BaseTorchFunctionMode): - pass - - cnt = torch._dynamo.testing.CompileCounter() - - @torch.compile(backend=cnt.__call__, fullgraph=True) - def fn(x): - return x + 1 - - inp = torch.ones(2, 2) - - with patch( - "torch._dynamo.variables.torch_function.IGNORED_MODES", {IgnoredMode} - ): - # initial compile - fn(inp) - - # no recompile, mode ignored - # note: the ref stack is length 0, and the stack we are checking against has length 2 - # we want to check both ref stack len > runtime stack, and ref stack len < runtime stack - with IgnoredMode(), IgnoredMode(): - fn(inp) - - self.assertEqual(cnt.frame_count, 1) - - # recompile due to new mode on the stack - with BaseTorchFunctionMode(), BaseTorchFunctionMode(), BaseTorchFunctionMode(): - fn(inp) - - self.assertEqual(cnt.frame_count, 2) - - # recompile - # tests both ref stack len > runtime stack len for the above guard check - # and ref stack len < runtime stack len for the initial zero mode case - with BaseTorchFunctionMode(), IgnoredMode(), BaseTorchFunctionMode(): - fn(inp) - - self.assertEqual(cnt.frame_count, 3) - - # no recompile - with IgnoredMode(), IgnoredMode(), BaseTorchFunctionMode(), BaseTorchFunctionMode(): - fn(inp) - - self.assertEqual(cnt.frame_count, 3) - - # This is tricky, basically the ignored modes are baked into the guard - # IgnoredMode will be ignored forever by that guard. - # This is okay since we don't expect to be modifying IGNORED_MODES - # in the middle of execution except for the purposes of testing. - torch._dynamo.reset() - - with IgnoredMode(): - fn(inp) - - self.assertEqual(cnt.frame_count, 4) - - @torch._dynamo.config.patch("enable_cpp_guard_manager", False) - def test_torch_function_mode_guards_ignored_types_py(self): - self._run_ignored_mode_types_test() - - def test_torch_function_mode_guards_ignored_types_cpp(self): - self._run_ignored_mode_types_test() - @torch._dynamo.config.patch("enable_cpp_guard_manager", False) def test_torch_function_mode_guards_py(self): self._run_torch_function_mode_guard_test() diff --git a/torch/_dynamo/guards.py b/torch/_dynamo/guards.py index 1544bdd6dc697b..66deeb3cbc0e79 100644 --- a/torch/_dynamo/guards.py +++ b/torch/_dynamo/guards.py @@ -2344,15 +2344,13 @@ def compile_check_fn(self, builder, guards_out, guard_fail_fn): ) if config.enable_cpp_guard_manager: - from .variables.torch_function import IGNORED_MODES - # Insert the global_state guard assert self.guard_manager # to make mypy happy self.guard_manager.root.add_global_state_guard(["___check_global_state()"]) self.guard_manager.root.add_torch_function_mode_stack_guard( self.torch_function_mode_stack, - list(IGNORED_MODES), + list(), ["___check_torch_function_mode_stack()"], ) # Clear references to torch_function modes held in the list @@ -2659,18 +2657,14 @@ def is_recompiles_verbose_enabled(): # this will only be used if cpp guards are disabled def make_torch_function_mode_stack_guard(intial_stack): types = [type(x) for x in intial_stack] - from .variables.torch_function import IGNORED_MODES def check_torch_function_mode_stack(): cur_stack = get_torch_function_mode_stack() - types_ = [ty for ty in types if ty not in IGNORED_MODES] - cur_stack_ = [mode for mode in cur_stack if type(mode) not in IGNORED_MODES] - - if len(cur_stack_) != len(types_): + if len(cur_stack) != len(types): return False - for ty, mode in zip(types_, cur_stack_): + for ty, mode in zip(types, cur_stack): if ty != type(mode): return False diff --git a/torch/_dynamo/utils.py b/torch/_dynamo/utils.py index 8130b6aa7371b2..2d4a771166cfbe 100644 --- a/torch/_dynamo/utils.py +++ b/torch/_dynamo/utils.py @@ -3083,16 +3083,10 @@ def is_parameter_freezing(): return torch._inductor.config.freezing and not torch.is_grad_enabled() -def get_torch_function_mode_stack(filter_ignored=True): - from .variables.torch_function import IGNORED_MODES - - stack = [ +def get_torch_function_mode_stack(): + return [ get_torch_function_mode_stack_at(i) for i in range(_len_torch_function_stack()) ] - if filter_ignored: - stack = [mode for mode in stack if type(mode) not in IGNORED_MODES] - - return stack def get_torch_function_mode_stack_at(ind): diff --git a/torch/_dynamo/variables/torch_function.py b/torch/_dynamo/variables/torch_function.py index b7fd83a3d24f15..ffb3d27d4d703b 100644 --- a/torch/_dynamo/variables/torch_function.py +++ b/torch/_dynamo/variables/torch_function.py @@ -69,12 +69,6 @@ if is_tensor_base_attr_getter(fn) ] -# Today set default device is placed in the graph and guarded on separately -# so we should not trace through it. In the future we can trace it once -# mode tracing is implemented and not put in the graph, but this is more -# of a BE project and can be evaluated later -IGNORED_MODES = {DeviceContext} - @functools.lru_cache(None) def get_prev_stack_var_name(): From 23664698380a34a01dd822e7936e2f106b48cc99 Mon Sep 17 00:00:00 2001 From: Michael Lazos Date: Sat, 14 Sep 2024 03:23:37 -0700 Subject: [PATCH 0903/1018] [Dynamo] Remove ignored modes from torch function mode stack guard (#135503) Pull Request resolved: https://github.com/pytorch/pytorch/pull/135503 Approved by: https://github.com/anijain2305 ghstack dependencies: #134732, #133137, #135443, #135444, #135422, #135502 --- torch/_C/_dynamo/guards.pyi | 2 +- torch/_dynamo/guards.py | 1 - torch/csrc/dynamo/guards.cpp | 69 +++++------------------------------- 3 files changed, 10 insertions(+), 62 deletions(-) diff --git a/torch/_C/_dynamo/guards.pyi b/torch/_C/_dynamo/guards.pyi index d6aab0d547f727..918d913068e6df 100644 --- a/torch/_C/_dynamo/guards.pyi +++ b/torch/_C/_dynamo/guards.pyi @@ -67,7 +67,7 @@ class GuardManager: ) -> None: ... def add_global_state_guard(self, verbose_code_parts: list[str]) -> None: ... def add_torch_function_mode_stack_guard( - self, initial_stack, ignored_types, verbose_code_parts: list[str] + self, initial_stack, verbose_code_parts: list[str] ) -> None: ... class RootGuardManager(GuardManager): diff --git a/torch/_dynamo/guards.py b/torch/_dynamo/guards.py index 66deeb3cbc0e79..3dcc9a032f208f 100644 --- a/torch/_dynamo/guards.py +++ b/torch/_dynamo/guards.py @@ -2350,7 +2350,6 @@ def compile_check_fn(self, builder, guards_out, guard_fail_fn): self.guard_manager.root.add_torch_function_mode_stack_guard( self.torch_function_mode_stack, - list(), ["___check_torch_function_mode_stack()"], ) # Clear references to torch_function modes held in the list diff --git a/torch/csrc/dynamo/guards.cpp b/torch/csrc/dynamo/guards.cpp index 771de44427b669..1d1cf2fb0c60a6 100644 --- a/torch/csrc/dynamo/guards.cpp +++ b/torch/csrc/dynamo/guards.cpp @@ -2515,90 +2515,40 @@ class TORCH_FUNCTION_MODE_STACK : public LeafGuard { public: TORCH_FUNCTION_MODE_STACK( const py::list& initial_stack, - const py::list& ignored_types, py::object verbose_code_parts) - : LeafGuard(std::move(verbose_code_parts)), - _ref_stack(), - _ignored_types() { + : LeafGuard(std::move(verbose_code_parts)), _ref_stack() { Py_ssize_t len = PyList_Size(initial_stack.ptr()); for (Py_ssize_t idx = 0; idx < len; idx++) { PyObject* mode = PyList_GetItem(initial_stack.ptr(), idx); // borrowed ref auto type = Py_TYPE(mode); this->_ref_stack.push_back(type); } - - len = PyList_Size(ignored_types.ptr()); - for (Py_ssize_t idx = 0; idx < len; idx++) { - PyObject* type_obj = - PyList_GetItem(ignored_types.ptr(), idx); // borrowed ref - if (PyType_Check(type_obj) == 0) { - PyErr_SetString( - PyExc_TypeError, "ignored_types should contain a list of types"); - return; - } - PyTypeObject* type = (PyTypeObject*)type_obj; - this->_ignored_types.insert(type); - } } bool check_nopybind(PyObject* value) override { // Ignore value arg, only used to satisfy the interface - size_t ref_ind = 0; - const int64_t len = at::impl::PythonTorchFunctionTLS::stack_len(); + const size_t len = (size_t)at::impl::PythonTorchFunctionTLS::stack_len(); const size_t ref_stack_size = this->_ref_stack.size(); - int64_t idx = 0; - while ((idx < len) && (ref_ind < ref_stack_size)) { - std::shared_ptr mode = - at::impl::PythonTorchFunctionTLS::get_stack_at(idx); - - PyTypeObject* mode_type = Py_TYPE(mode->ptr(getPyInterpreter())); - bool act_ignored = this->_ignored_types.count(mode_type) > 0; - bool ref_ignored = - this->_ignored_types.count(this->_ref_stack.at(ref_ind)) > 0; - // skip ignored types - if (act_ignored && ref_ignored) { - idx++; - ref_ind++; - continue; - } else if (ref_ignored) { - ref_ind++; - continue; - } else if (act_ignored) { - idx++; - continue; - } - // if we already have more non-ignored modes than the ref stack - // or if the mode doesn't match at the current index, return false - else if (mode_type != _ref_stack.at(ref_ind)) { - return false; - } - ref_ind++; - idx++; - } - - for (; ref_ind < ref_stack_size; ref_ind++) { - if (!(this->_ignored_types.count(this->_ref_stack.at(ref_ind)) > 0)) { - return false; - } + if (len != ref_stack_size) { + return false; } - for (; idx < len; idx++) { + for (int64_t idx = 0; (size_t)idx < len; idx++) { std::shared_ptr mode = at::impl::PythonTorchFunctionTLS::get_stack_at(idx); PyTypeObject* mode_type = Py_TYPE(mode->ptr(getPyInterpreter())); - if (!(this->_ignored_types.count(mode_type) > 0)) { + if (mode_type != _ref_stack.at(idx)) { return false; } } - return ref_ind == ref_stack_size && idx == len; + return true; } private: std::vector _ref_stack; - std::set _ignored_types; }; class TENSOR_MATCH : public LeafGuard { @@ -3763,7 +3713,7 @@ PyObject* torch_c_dynamo_guards_init() { LeafGuard, std::shared_ptr>( py_m, "TORCH_FUNCTION_MODE_STACK") - .def(py::init()) + .def(py::init()) .def("__call__", &TORCH_FUNCTION_MODE_STACK::check); py::class_>( py_m, "DATA_PTR_MATCH") @@ -4000,10 +3950,9 @@ PyObject* torch_c_dynamo_guards_init() { "add_torch_function_mode_stack_guard", [](GuardManager& self, const py::list& initial_stack, - const py::list& ignored_types, py::object verbose_code_parts) -> void { self.add_leaf_guard(std::make_shared( - initial_stack, ignored_types, std::move(verbose_code_parts))); + initial_stack, std::move(verbose_code_parts))); }) .def( "add_data_ptr_guard", From 7056c5bb479ee3c26050b0500be75f0daeac84ad Mon Sep 17 00:00:00 2001 From: Aaron Gokaslan Date: Sat, 14 Sep 2024 20:48:42 +0000 Subject: [PATCH 0904/1018] [BE][Ez]: Update pybind11 to 2.13.6. Exposes new conduit cross-compat API (#136087) Updates pybind11 submodule. The major patchnote is an experimental new function that is added to all pybind11 objects that will make them more compatible across pybind11 version, settings, and frameworks (such as nanobind) called cpp_conduit. No code changes needed on our end except to update Pull Request resolved: https://github.com/pytorch/pytorch/pull/136087 Approved by: https://github.com/malfet --- third_party/pybind11 | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/third_party/pybind11 b/third_party/pybind11 index 7c33cdc2d39c7b..a2e59f0e706540 160000 --- a/third_party/pybind11 +++ b/third_party/pybind11 @@ -1 +1 @@ -Subproject commit 7c33cdc2d39c7b99a122579f53bc94c8eb3332ff +Subproject commit a2e59f0e7065404b44dfe92a28aca47ba1378dc4 From 6d4db5bc5213d18515b97ce7ed8a55a2882266ce Mon Sep 17 00:00:00 2001 From: Jason Ansel Date: Fri, 13 Sep 2024 20:12:05 -0700 Subject: [PATCH 0905/1018] [dynamo] Fix support for classmethod(property(...)) (#134968) Fixes #134451 Pull Request resolved: https://github.com/pytorch/pytorch/pull/134968 Approved by: https://github.com/yanboliang --- test/dynamo/test_repros.py | 79 ++++++++++++++++++++++++- torch/_dynamo/variables/constant.py | 2 + torch/_dynamo/variables/user_defined.py | 4 ++ 3 files changed, 84 insertions(+), 1 deletion(-) diff --git a/test/dynamo/test_repros.py b/test/dynamo/test_repros.py index e1e1dd1837bef4..bd661990402622 100644 --- a/test/dynamo/test_repros.py +++ b/test/dynamo/test_repros.py @@ -14,13 +14,14 @@ import itertools import os import random +import sys import unittest import warnings import weakref from abc import ABC from collections import namedtuple from copy import deepcopy -from enum import Enum +from enum import Enum, IntEnum from functools import wraps from typing import Any, Dict, Iterator, List, Tuple from unittest import mock @@ -4546,6 +4547,82 @@ def f(*args): f(*args) self.assertEqual(num_compiles, 1) + @unittest.skipIf(sys.version_info < (3, 9), "requires python 3.9+") + def test_issue134451(self): + class BoundingBox2DIndex(IntEnum): + _X = 0 + _Y = 1 + _HEADING = 2 + _LENGTH = 3 + _WIDTH = 4 + + @classmethod + def size(cls): + return 5 + + @classmethod + @property + def X(cls): + return cls._X + + @classmethod + @property + def Y(cls): + return cls._Y + + @classmethod + @property + def HEADING(cls): + return cls._HEADING + + @classmethod + @property + def LENGTH(cls): + return cls._LENGTH + + @classmethod + @property + def WIDTH(cls): + return cls._WIDTH + + @classmethod + @property + def POINT(cls): + # assumes X, Y have subsequent indices + return slice(cls._X, cls._Y + 1) + + @classmethod + @property + def STATE_SE2(cls): + # assumes X, Y, HEADING have subsequent indices + return slice(cls._X, cls._HEADING + 1) + + class SimpleModel(nn.Module): + def __init__(self): + super().__init__() + self._mlp_states = nn.Sequential( + nn.Linear(10, 20), + nn.ReLU(), + nn.Linear(20, BoundingBox2DIndex.size()), + ) + + def forward(self, x): + agent_states = self._mlp_states(x) + agent_states[..., BoundingBox2DIndex.POINT] = ( + agent_states[..., BoundingBox2DIndex.POINT].tanh() * 32 + ) + agent_states[..., BoundingBox2DIndex.HEADING] = ( + agent_states[..., BoundingBox2DIndex.HEADING].tanh() * torch.pi + ) + return agent_states + + model = SimpleModel().eval() + input_tensor = torch.randn(1, 10, dtype=torch.float32) + opt = torch.compile(model.eval(), backend="eager", fullgraph=True) + actual = opt(input_tensor) + expected = model(input_tensor) + self.assertEqual(actual, expected) + def test_invalid_seq_unpack(self): def myfn(arg): (a, b) = arg diff --git a/torch/_dynamo/variables/constant.py b/torch/_dynamo/variables/constant.py index 65de0aab6ef3c7..de357cf8094f3a 100644 --- a/torch/_dynamo/variables/constant.py +++ b/torch/_dynamo/variables/constant.py @@ -222,6 +222,8 @@ def create(cls, cls_type, value_vt, options): unimplemented("Enum variable is constructed with non constant values") def as_proxy(self): + if isinstance(self.value, int): + return int(self.value) # convert IntEnum to a normal int return self.value def __str__(self) -> str: diff --git a/torch/_dynamo/variables/user_defined.py b/torch/_dynamo/variables/user_defined.py index 14499b4d2e4bf7..32057c838d74c8 100644 --- a/torch/_dynamo/variables/user_defined.py +++ b/torch/_dynamo/variables/user_defined.py @@ -193,6 +193,10 @@ def var_getattr(self, tx: "InstructionTranslator", name: str) -> "VariableTracke else: return SourcelessBuilder.create(tx, func) elif isinstance(obj, classmethod): + if isinstance(obj.__func__, property): + return variables.UserFunctionVariable(obj.__func__.fget).call_function( + tx, [self], {} + ) return variables.UserMethodVariable(obj.__func__, self, source=source) elif isinstance(obj, types.ClassMethodDescriptorType): # e.g.: inspect.getattr_static(dict, "fromkeys") From d01f579c5f491d1e5eb18b779edc79aadce6a912 Mon Sep 17 00:00:00 2001 From: Aaron Gokaslan Date: Sat, 14 Sep 2024 21:40:34 +0000 Subject: [PATCH 0906/1018] [BE]: Update mypy to 1.11.2 (#133816) Updates mypy to 1.11.1 to improve type inference Pull Request resolved: https://github.com/pytorch/pytorch/pull/133816 Approved by: https://github.com/ezyang --- .ci/docker/requirements-ci.txt | 2 +- .lintrunner.toml | 2 +- torch/_dynamo/repro/after_aot.py | 2 +- torch/_dynamo/repro/after_dynamo.py | 9 +- torch/_dynamo/resume_execution.py | 14 +-- torch/_dynamo/source.py | 2 +- torch/_dynamo/symbolic_convert.py | 2 +- torch/_dynamo/utils.py | 2 + torch/_export/passes/constant_folding.py | 2 +- torch/_higher_order_ops/flex_attention.py | 2 +- torch/_inductor/codegen/cpp_wrapper_cpu.py | 3 +- .../codegen/cuda_combined_scheduling.py | 2 +- torch/_inductor/dependencies.py | 10 +- torch/_inductor/kernel/mm_common.py | 4 +- torch/_inductor/remote_cache.py | 4 +- torch/_inductor/runtime/triton_heuristics.py | 4 +- torch/_inductor/scheduler.py | 4 +- torch/_library/infer_schema.py | 2 +- torch/_subclasses/functional_tensor.py | 4 +- torch/_tensor.py | 2 +- torch/ao/nn/intrinsic/modules/fused.py | 4 +- .../ao/nn/intrinsic/qat/modules/conv_fused.py | 93 +++++++++---------- .../intrinsic/quantized/modules/conv_add.py | 4 +- torch/ao/nn/qat/modules/conv.py | 18 ++-- torch/ao/nn/quantized/dynamic/modules/conv.py | 25 ++--- torch/ao/nn/quantized/modules/conv.py | 41 ++++---- torch/ao/ns/_numeric_suite.py | 2 +- torch/ao/ns/_numeric_suite_fx.py | 2 +- torch/ao/ns/fx/mappings.py | 6 +- .../data_sparsifier/base_data_sparsifier.py | 6 +- .../data_sparsifier/data_norm_sparsifier.py | 2 +- .../_experimental/pruner/FPGM_pruner.py | 4 +- .../sparsifier/nearly_diagonal_sparsifier.py | 4 +- .../sparsifier/weight_norm_sparsifier.py | 2 +- torch/ao/quantization/__init__.py | 2 +- torch/ao/quantization/_correct_bias.py | 2 +- .../ao/quantization/experimental/observer.py | 2 +- .../quantizer/x86_inductor_quantizer.py | 2 +- torch/distributed/_functional_collectives.py | 4 +- .../distributed/_shard/sharded_tensor/api.py | 2 +- .../algorithms/_comm_hooks/default_hooks.py | 2 + .../optim/post_localSGD_optimizer.py | 2 +- torch/distributed/pipelining/_IR.py | 5 +- torch/distributed/tensor/_api.py | 2 +- torch/distributed/tensor/_shards_wrapper.py | 2 +- .../tensor/parallel/input_reshard.py | 1 + torch/distributions/constraints.py | 2 +- torch/distributions/von_mises.py | 2 +- torch/masked/maskedtensor/core.py | 4 +- torch/multiprocessing/reductions.py | 2 +- torch/nested/_internal/nested_tensor.py | 2 +- torch/nn/attention/bias.py | 2 +- torch/optim/lr_scheduler.py | 4 +- torch/sparse/semi_structured.py | 4 +- torch/utils/_sympy/functions.py | 22 ++--- torch/utils/weak.py | 2 +- 56 files changed, 185 insertions(+), 180 deletions(-) diff --git a/.ci/docker/requirements-ci.txt b/.ci/docker/requirements-ci.txt index 88ecc0d1a7edf3..bd48d696c9230b 100644 --- a/.ci/docker/requirements-ci.txt +++ b/.ci/docker/requirements-ci.txt @@ -90,7 +90,7 @@ librosa>=0.6.2 ; python_version < "3.11" #Pinned versions: #test that import: -mypy==1.10.0 +mypy==1.11.2 # Pin MyPy version because new errors are likely to appear with each release #Description: linter #Pinned versions: 1.10.0 diff --git a/.lintrunner.toml b/.lintrunner.toml index 4789991736be17..9b43b382809212 100644 --- a/.lintrunner.toml +++ b/.lintrunner.toml @@ -139,7 +139,7 @@ init_command = [ 'numpy==1.24.3 ; python_version == "3.8"', 'numpy==1.26.0 ; python_version >= "3.9"', 'expecttest==0.2.1', - 'mypy==1.10.0', + 'mypy==1.11.2', 'sympy==1.12.1 ; python_version == "3.8"', 'sympy==1.13.0 ; python_version >= "3.9"', 'types-requests==2.27.25', diff --git a/torch/_dynamo/repro/after_aot.py b/torch/_dynamo/repro/after_aot.py index a21a20cdf2e611..19aa69d084db36 100644 --- a/torch/_dynamo/repro/after_aot.py +++ b/torch/_dynamo/repro/after_aot.py @@ -521,7 +521,7 @@ def repro_common(options, mod, load_args): def repro_minifier_query(options, mod, load_args): mod, args = repro_common(options, mod, load_args) fail_fn = functools.partial( - ACCURACY_FAILS[options.accuracy], check_str=options.check_str + ACCURACY_FAILS[options.accuracy], check_str=options.check_str # type: ignore[call-arg] ) if fail_fn(mod, args): sys.exit(1) diff --git a/torch/_dynamo/repro/after_dynamo.py b/torch/_dynamo/repro/after_dynamo.py index b1bb950d11c11b..214480b02c9332 100644 --- a/torch/_dynamo/repro/after_dynamo.py +++ b/torch/_dynamo/repro/after_dynamo.py @@ -12,6 +12,7 @@ import torch import torch.fx as fx +from torch._dynamo.backends.registry import CompiledFn from torch._dynamo.debug_utils import ( AccuracyError, backend_accuracy_fails, @@ -271,8 +272,10 @@ def dump_to_minify_after_dynamo(gm, args, compiler_name): # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # -@register_debug_backend -def dynamo_minifier_backend(gm, example_inputs, compiler_name): +@register_debug_backend # type: ignore[arg-type] +def dynamo_minifier_backend( + gm: fx.GraphModule, example_inputs, compiler_name: CompiledFn +): from functorch.compile import minifier compiler_fn = lookup_backend(compiler_name) @@ -311,7 +314,7 @@ def dynamo_minifier_backend(gm, example_inputs, compiler_name): return gm -@register_debug_backend +@register_debug_backend # type: ignore[arg-type] def dynamo_accuracy_minifier_backend(gm, example_inputs, compiler_name): from functorch.compile import minifier diff --git a/torch/_dynamo/resume_execution.py b/torch/_dynamo/resume_execution.py index 6f7db66514c487..2013b3992eb54f 100644 --- a/torch/_dynamo/resume_execution.py +++ b/torch/_dynamo/resume_execution.py @@ -471,14 +471,14 @@ def generate( code, lineno, offset: int, - setup_fn_target_offsets: Tuple[int], # only used in Python 3.11+ + setup_fn_target_offsets: Tuple[int, ...], # only used in Python 3.11+ nstack: int, - argnames: Tuple[str], - argnames_null: Tuple[str], - setup_fns: Tuple[ReenterWith], - stack_ctx_vars: Tuple[int, Tuple[Any]], - argnames_ctx_vars: Tuple[str, Tuple[Any]], - null_idxes: Tuple[int], + argnames: Tuple[str, ...], + argnames_null: Tuple[str, ...], + setup_fns: Tuple[ReenterWith, ...], + stack_ctx_vars: Tuple[Tuple[int, Tuple[Any]], ...], + argnames_ctx_vars: Tuple[Tuple[str, Tuple[Any]], ...], + null_idxes: Tuple[int, ...], ) -> types.CodeType: assert offset is not None assert not ( diff --git a/torch/_dynamo/source.py b/torch/_dynamo/source.py index c44f8d9e69abf5..9febc69f42cda2 100644 --- a/torch/_dynamo/source.py +++ b/torch/_dynamo/source.py @@ -272,7 +272,7 @@ def guard_source(self): def name(self): return f"" - def make_guard(self): + def make_guard(self, fn): raise NotImplementedError def is_ephemeral(self): diff --git a/torch/_dynamo/symbolic_convert.py b/torch/_dynamo/symbolic_convert.py index 61acaec995757a..4a83b455a1252b 100644 --- a/torch/_dynamo/symbolic_convert.py +++ b/torch/_dynamo/symbolic_convert.py @@ -3198,7 +3198,7 @@ def fake_mode(self): def run_ctx_mgr(self): return TracingContext.current_frame(self.parent.frame_summary()) - def STORE_DEREF(self, inst): + def STORE_DEREF(self, inst): # type: ignore[override] if inst.argval in self.closure_cells: cell = self.closure_cells[inst.argval] val = self.pop() diff --git a/torch/_dynamo/utils.py b/torch/_dynamo/utils.py index 2d4a771166cfbe..bed3564ec15ce0 100644 --- a/torch/_dynamo/utils.py +++ b/torch/_dynamo/utils.py @@ -2915,6 +2915,8 @@ def is_frozen_dataclass(value): not object_has_getattribute(value) and not class_has_getattribute(value) and is_dataclass(value) + and hasattr(value, "__dataclass_params__") + and hasattr(value.__dataclass_params__, "frozen") and value.__dataclass_params__.frozen ) diff --git a/torch/_export/passes/constant_folding.py b/torch/_export/passes/constant_folding.py index b1491ca5d47946..083cd69a970df6 100644 --- a/torch/_export/passes/constant_folding.py +++ b/torch/_export/passes/constant_folding.py @@ -195,7 +195,7 @@ def insertable_tensor_check(self, tensor: torch.Tensor) -> bool: def add_node_replacement(self, node: torch.fx.Node, tensor: torch.Tensor) -> None: self.node_replacements[node] = tensor - def run(self): + def run(self): # type: ignore[override] env = {} for n in self.module.graph.find_nodes(op="placeholder"): env[n] = self.unknown_value diff --git a/torch/_higher_order_ops/flex_attention.py b/torch/_higher_order_ops/flex_attention.py index b28f657b9d5d8f..ed99ae7beca12a 100644 --- a/torch/_higher_order_ops/flex_attention.py +++ b/torch/_higher_order_ops/flex_attention.py @@ -77,7 +77,7 @@ class TransformGetItemToIndex(TorchFunctionMode): # scalar and create a view. We do not want that behavior in this case, so we # use this torchfunctionmode to override that behavior for score_mod # wherever we're running it. - def __torch_function__(self, func, types, args, kwargs=None): + def __torch_function__(self, func, types, args=(), kwargs=None): if func == torch.Tensor.__getitem__: index_args = pytree.tree_leaves(args[1]) if all(isinstance(x, torch.Tensor) for x in index_args): diff --git a/torch/_inductor/codegen/cpp_wrapper_cpu.py b/torch/_inductor/codegen/cpp_wrapper_cpu.py index c009246243f3cf..4e88a6f91f7e22 100644 --- a/torch/_inductor/codegen/cpp_wrapper_cpu.py +++ b/torch/_inductor/codegen/cpp_wrapper_cpu.py @@ -1480,7 +1480,8 @@ def generate_profiler_mark_wrapper_call(self, stack): 'RECORD_FUNCTION("inductor_wrapper_call", c10::ArrayRef());' ) - def write_triton_header_once(self): + @cache_on_self + def write_triton_header_once(self) -> None: pass def generate_start_graph(self): diff --git a/torch/_inductor/codegen/cuda_combined_scheduling.py b/torch/_inductor/codegen/cuda_combined_scheduling.py index 90971045379656..9273eeb5f9824f 100644 --- a/torch/_inductor/codegen/cuda_combined_scheduling.py +++ b/torch/_inductor/codegen/cuda_combined_scheduling.py @@ -30,7 +30,7 @@ def __init__(self, scheduler: Scheduler) -> None: self._cuda_cpp_scheduling = CUDACPPScheduling(scheduler) self._rocm_cpp_scheduling = ROCmCPPScheduling(scheduler) - def get_backend_features(self, device): + def get_backend_features(self, device): # type:ignore[override] return self._triton_scheduling.get_backend_features(device) def choose_node_backend(self, node: BaseSchedulerNode) -> BaseScheduling: diff --git a/torch/_inductor/dependencies.py b/torch/_inductor/dependencies.py index 95a48281977c87..643bf686bd8f19 100644 --- a/torch/_inductor/dependencies.py +++ b/torch/_inductor/dependencies.py @@ -578,22 +578,22 @@ def extract_read_writes( repl = {v: sympy.Symbol(f"tmp{i}") for i, v in enumerate(fn.indirect_vars)} name_to_index = {k: sympy_subs(v, repl) for k, v in name_to_index.items()} for entry in fn.memory_usage[MemoryUsageType.LOAD]: - inner.load(entry.buffer_name, name_to_index[entry.index_name]) + inner.load(entry.buffer_name, name_to_index[entry.index_name]) # type: ignore[arg-type] for entry in fn.memory_usage[MemoryUsageType.LOAD_SEED]: - inner.load_seed(entry.buffer_name, int(name_to_index[entry.index_name])) + inner.load_seed(entry.buffer_name, int(name_to_index[entry.index_name])) # type: ignore[arg-type] for entry in fn.memory_usage[MemoryUsageType.STORE]: inner.store( - entry.buffer_name, name_to_index[entry.index_name], None, entry.mode + entry.buffer_name, name_to_index[entry.index_name], None, entry.mode # type: ignore[arg-type] ) for entry in fn.memory_usage[MemoryUsageType.STORE_REDUCTION]: inner.store_reduction( - entry.buffer_name, name_to_index[entry.index_name], None + entry.buffer_name, name_to_index[entry.index_name], None # type: ignore[arg-type] ) for entry in fn.memory_usage[MemoryUsageType.INDEX_EXPR]: inner.index_expr(name_to_index[entry.index_name], None) for entry in fn.memory_usage[MemoryUsageType.BUCKETIZE]: inner.bucketize( - None, entry.buffer_name, name_to_index[entry.index_name], None, None + None, entry.buffer_name, name_to_index[entry.index_name], None, None # type: ignore[arg-type] ) # fn.memory_usage[MemoryUsageType.CHECK_BOUNDS] intentionally skipped else: diff --git a/torch/_inductor/kernel/mm_common.py b/torch/_inductor/kernel/mm_common.py index 0c66da4ca4f314..21ba6c1e215dbe 100644 --- a/torch/_inductor/kernel/mm_common.py +++ b/torch/_inductor/kernel/mm_common.py @@ -2,7 +2,7 @@ import functools import itertools import logging -from typing import cast, List, Tuple +from typing import cast, Sequence, Tuple import sympy @@ -28,7 +28,7 @@ def filtered_configs( m: int, n: int, k: int, - configs: List[Tuple[int, int, int, int, int]], + configs: Sequence[Tuple[int, int, int, int, int]], has_int8_tensor=False, ): """Heuristic to shrink configs when they are bigger than the input size""" diff --git a/torch/_inductor/remote_cache.py b/torch/_inductor/remote_cache.py index 14963a1bf5d435..056e24f1b4e26c 100644 --- a/torch/_inductor/remote_cache.py +++ b/torch/_inductor/remote_cache.py @@ -101,8 +101,8 @@ def put(self, key: str, value: _T) -> None: self._put(key, value, sample) self._log_sample(sample) - def _decode(self, data: _U, sample: Optional[Sample]) -> _T: - return self.serde.decode(data) + def _decode(self, data: _U, sample: Optional[Sample]) -> _T: # type: ignore[override] + return self.serde.decode(data) # type: ignore[arg-type] def _encode(self, value: _T, sample: Optional[Sample]) -> Any: # returns _U return self.serde.encode(value) diff --git a/torch/_inductor/runtime/triton_heuristics.py b/torch/_inductor/runtime/triton_heuristics.py index 658dcbf33d99e5..a7a2bdce32da64 100644 --- a/torch/_inductor/runtime/triton_heuristics.py +++ b/torch/_inductor/runtime/triton_heuristics.py @@ -827,7 +827,7 @@ def benchmark_one_config(config): ) return config2launcher.get(best_config) - def run(self, *args, grid, stream, **kwargs): + def run(self, *args, grid, stream, **kwargs): # type:ignore[override] if len(self.launchers) != 1: if len(self.launchers) == 0: start_time = time.time_ns() @@ -958,7 +958,7 @@ def __init__(self, *args, regex_filter="", with_profiler=False, **kwargs): super().__init__(*args, **kwargs) self.cached = None - def run(self, *args, grid, stream): + def run(self, *args, grid, stream): # type: ignore[override] possible_names = _find_names(self) kernel_name = f"{max(possible_names, key=len)}" if not re.match(self.regex_filter, kernel_name): diff --git a/torch/_inductor/scheduler.py b/torch/_inductor/scheduler.py index 5f8d35cbcab5eb..3d2676e0b2e000 100644 --- a/torch/_inductor/scheduler.py +++ b/torch/_inductor/scheduler.py @@ -2857,9 +2857,7 @@ def has_shared_data_after_reordering_loop( return False # Pick the largest buffer to guide the loop reordering - numel, lhs_dep, rhs_dep = sorted(candidates, reverse=True, key=lambda x: x[0])[ - 0 - ] + numel, lhs_dep, rhs_dep = max(candidates, key=lambda x: x[0]) if lhs_dep.num_vars != rhs_dep.num_vars: # this can happen due to we don't merge loops. diff --git a/torch/_library/infer_schema.py b/torch/_library/infer_schema.py index b2eeb24521d382..9845a64874803a 100644 --- a/torch/_library/infer_schema.py +++ b/torch/_library/infer_schema.py @@ -268,4 +268,4 @@ def tuple_to_list(tuple_type: typing.Type[typing.Tuple]) -> typing.Type[typing.L elif len(type_args) == 2 and type_args[1] is Ellipsis: # type: ignore[valid-type] return typing.List[type_args[0]] # type: ignore[valid-type] else: - return typing.List[typing.Union[tuple(type_args)]] # type: ignore[misc] + return typing.List[typing.Union[tuple(type_args)]] # type: ignore[misc, return-value] diff --git a/torch/_subclasses/functional_tensor.py b/torch/_subclasses/functional_tensor.py index bec1ebc559c328..cd5bfa655ea0b5 100644 --- a/torch/_subclasses/functional_tensor.py +++ b/torch/_subclasses/functional_tensor.py @@ -221,7 +221,7 @@ def __torch_dispatch__(self, func, types, args=(), kwargs=None): "Attempting to use FunctionalTensor on its own. Instead, please use it with a corresponding FunctionalTensorMode()" ) - def __repr__(self): + def __repr__(self) -> str: # type: ignore[override] return f"FunctionalTensor({repr(self.elem)})" @staticmethod @@ -302,7 +302,7 @@ def cuda(self, device=None, *args, **kwargs): long = _conversion_method_template(dtype=torch.int64) # TODO(sparse-team): fixes #133174 but can we do without the relay? - def to_dense(self): + def to_dense(self): # type: ignore[override] return self.elem.to_dense() @property diff --git a/torch/_tensor.py b/torch/_tensor.py index 7d8010081abc77..e22c6a92d92d0f 100644 --- a/torch/_tensor.py +++ b/torch/_tensor.py @@ -1353,7 +1353,7 @@ def align_to(self, *names): [name for name in names if not is_ellipsis(name)], ellipsis_idx ) - def unflatten(self, dim, sizes): + def unflatten(self, dim, sizes): # type: ignore[override] r""" unflatten(dim, sizes) -> Tensor diff --git a/torch/ao/nn/intrinsic/modules/fused.py b/torch/ao/nn/intrinsic/modules/fused.py index 010b9a701a3353..c7ae2ce3319d3c 100644 --- a/torch/ao/nn/intrinsic/modules/fused.py +++ b/torch/ao/nn/intrinsic/modules/fused.py @@ -228,7 +228,7 @@ def __init__(self, conv, add): super().__init__(conv) self.add = add - def forward(self, x1, x2): + def forward(self, x1, x2): # type: ignore[override] return self.add(self[0](x1), x2) @@ -241,5 +241,5 @@ def __init__(self, conv, add, relu): self.add = add self.relu = relu - def forward(self, x1, x2): + def forward(self, x1, x2): # type: ignore[override] return self.relu(self.add(self[0](x1), x2)) diff --git a/torch/ao/nn/intrinsic/qat/modules/conv_fused.py b/torch/ao/nn/intrinsic/qat/modules/conv_fused.py index 3c700b30f24720..6ceee89757ca75 100644 --- a/torch/ao/nn/intrinsic/qat/modules/conv_fused.py +++ b/torch/ao/nn/intrinsic/qat/modules/conv_fused.py @@ -1,6 +1,6 @@ # mypy: allow-untyped-defs import math -from typing import TypeVar +from typing import ClassVar, Optional, Type import torch import torch.ao.nn.intrinsic as nni @@ -33,12 +33,9 @@ } -MOD = TypeVar("MOD", bound=nn.modules.conv._ConvNd) - - class _ConvBnNd(nn.modules.conv._ConvNd, nni._FusedModule): _version = 2 - _FLOAT_MODULE = MOD + _FLOAT_MODULE: ClassVar[Type[nn.modules.conv._ConvNd]] def __init__( self, @@ -365,7 +362,7 @@ def from_float(cls, mod, use_precomputed_fake_quant=False): assert hasattr(mod, "qconfig"), "Input float module must have qconfig defined" assert mod.qconfig, "Input float module must have a valid qconfig" qconfig = mod.qconfig - conv, bn = mod[0], mod[1] + conv, bn = mod[0], mod[1] # type: ignore[index] qat_convbn = cls( conv.in_channels, conv.out_channels, @@ -434,7 +431,7 @@ def to_float(self): return conv -class ConvBn1d(_ConvBnNd, nn.Conv1d): +class ConvBn1d(_ConvBnNd, nn.Conv1d): # type: ignore[misc] r""" A ConvBn1d module is a module fused from Conv1d and BatchNorm1d, attached with FakeQuantize modules for weight, @@ -451,10 +448,10 @@ class ConvBn1d(_ConvBnNd, nn.Conv1d): weight_fake_quant: fake quant module for weight """ - _FLOAT_BN_MODULE = nn.BatchNorm1d - _FLOAT_RELU_MODULE: None = None - _FLOAT_MODULE = nni.ConvBn1d - _FLOAT_CONV_MODULE = nn.Conv1d + _FLOAT_BN_MODULE: ClassVar[Type[nn.BatchNorm1d]] = nn.BatchNorm1d + _FLOAT_RELU_MODULE: ClassVar[Optional[Type[nn.Module]]] = None + _FLOAT_MODULE: ClassVar[Type[nn.Module]] = nni.ConvBn1d # type: ignore[assignment,misc] + _FLOAT_CONV_MODULE: ClassVar[Type[nn.Conv1d]] = nn.Conv1d def __init__( self, @@ -520,12 +517,12 @@ class ConvBnReLU1d(ConvBn1d): """ # base class defines _FLOAT_MODULE as "ConvBn1d" - _FLOAT_MODULE = nni.ConvBnReLU1d # type: ignore[assignment] - _FLOAT_CONV_MODULE = nn.Conv1d - _FLOAT_BN_MODULE = nn.BatchNorm1d - _FLOAT_RELU_MODULE = nn.ReLU # type: ignore[assignment] + _FLOAT_MODULE: ClassVar[Type[nn.Module]] = nni.ConvBnReLU1d # type: ignore[assignment,misc] + _FLOAT_CONV_MODULE: ClassVar[Type[nn.Conv1d]] = nn.Conv1d + _FLOAT_BN_MODULE: ClassVar[Type[nn.BatchNorm1d]] = nn.BatchNorm1d + _FLOAT_RELU_MODULE: ClassVar[Optional[Type[nn.Module]]] = nn.ReLU # module class after fusing bn into conv - _FUSED_FLOAT_MODULE = nni.ConvReLU1d + _FUSED_FLOAT_MODULE: ClassVar[Optional[Type[nn.Module]]] = nni.ConvReLU1d def __init__( self, @@ -585,10 +582,10 @@ class ConvReLU1d(nnqat.Conv1d, nni._FusedModule): weight_fake_quant: fake quant module for weight """ - _FLOAT_MODULE = nni.ConvReLU1d # type: ignore[assignment] - _FLOAT_CONV_MODULE = nn.Conv1d - _FLOAT_BN_MODULE: None = None - _FLOAT_RELU_MODULE = nn.ReLU + _FLOAT_MODULE: ClassVar[Type[nni.ConvReLU1d]] = nni.ConvReLU1d # type: ignore[assignment, misc] + _FLOAT_CONV_MODULE: ClassVar[Type[nn.Conv1d]] = nn.Conv1d + _FLOAT_BN_MODULE: ClassVar[Optional[Type[nn.Module]]] = None + _FLOAT_RELU_MODULE: ClassVar[Optional[Type[nn.Module]]] = nn.ReLU def __init__( self, @@ -631,7 +628,7 @@ def from_float(cls, mod, use_precomputed_fake_quant=False): ) -class ConvBn2d(_ConvBnNd, nn.Conv2d): +class ConvBn2d(_ConvBnNd, nn.Conv2d): # type: ignore[misc] r""" A ConvBn2d module is a module fused from Conv2d and BatchNorm2d, attached with FakeQuantize modules for weight, @@ -648,10 +645,10 @@ class ConvBn2d(_ConvBnNd, nn.Conv2d): weight_fake_quant: fake quant module for weight """ - _FLOAT_MODULE = nni.ConvBn2d - _FLOAT_CONV_MODULE = nn.Conv2d - _FLOAT_BN_MODULE = nn.BatchNorm2d - _FLOAT_RELU_MODULE: None = None + _FLOAT_MODULE: ClassVar[Type[nni.ConvBn2d]] = nni.ConvBn2d # type: ignore[assignment,misc] + _FLOAT_CONV_MODULE: ClassVar[Type[nn.Conv2d]] = nn.Conv2d + _FLOAT_BN_MODULE: ClassVar[Optional[Type[nn.Module]]] = nn.BatchNorm2d + _FLOAT_RELU_MODULE: ClassVar[Optional[Type[nn.Module]]] = None def __init__( self, @@ -717,12 +714,12 @@ class ConvBnReLU2d(ConvBn2d): """ # base class defines _FLOAT_MODULE as "ConvBn2d" - _FLOAT_MODULE = nni.ConvBnReLU2d # type: ignore[assignment] - _FLOAT_CONV_MODULE = nn.Conv2d - _FLOAT_BN_MODULE = nn.BatchNorm2d - _FLOAT_RELU_MODULE = nn.ReLU # type: ignore[assignment] + _FLOAT_MODULE: ClassVar[Type[nni.ConvBnReLU2d]] = nni.ConvBnReLU2d # type: ignore[assignment, misc] + _FLOAT_CONV_MODULE: ClassVar[Type[nn.Conv2d]] = nn.Conv2d + _FLOAT_BN_MODULE: ClassVar[Type[nn.BatchNorm2d]] = nn.BatchNorm2d + _FLOAT_RELU_MODULE: ClassVar[Optional[Type[nn.Module]]] = nn.ReLU # type: ignore[assignment,misc] # module class after fusing bn into conv - _FUSED_FLOAT_MODULE = nni.ConvReLU2d + _FUSED_FLOAT_MODULE: ClassVar[Optional[Type[nni.ConvReLU2d]]] = nni.ConvReLU2d def __init__( self, @@ -782,10 +779,10 @@ class ConvReLU2d(nnqat.Conv2d, nni._FusedModule): weight_fake_quant: fake quant module for weight """ - _FLOAT_MODULE = nni.ConvReLU2d # type: ignore[assignment] - _FLOAT_CONV_MODULE = nn.Conv2d - _FLOAT_BN_MODULE: None = None - _FLOAT_RELU_MODULE = nn.ReLU + _FLOAT_MODULE: ClassVar[Type[nn.Module]] = nni.ConvReLU2d # type: ignore[assignment, misc] + _FLOAT_CONV_MODULE: ClassVar[Type[nn.Conv2d]] = nn.Conv2d + _FLOAT_BN_MODULE: ClassVar[Optional[Type[nn.Module]]] = None + _FLOAT_RELU_MODULE: ClassVar[Optional[Type[nn.Module]]] = nn.ReLU def __init__( self, @@ -828,7 +825,7 @@ def from_float(cls, mod, use_precomputed_fake_quant=False): ) -class ConvBn3d(_ConvBnNd, nn.Conv3d): +class ConvBn3d(_ConvBnNd, nn.Conv3d): # type: ignore[misc] r""" A ConvBn3d module is a module fused from Conv3d and BatchNorm3d, attached with FakeQuantize modules for weight, @@ -845,10 +842,10 @@ class ConvBn3d(_ConvBnNd, nn.Conv3d): weight_fake_quant: fake quant module for weight """ - _FLOAT_MODULE = nni.ConvBn3d - _FLOAT_CONV_MODULE = nn.Conv3d - _FLOAT_BN_MODULE = nn.BatchNorm3d - _FLOAT_RELU_MODULE: None = None + _FLOAT_MODULE: ClassVar[Type[nni.ConvBn3d]] = nni.ConvBn3d # type: ignore[assignment,misc] + _FLOAT_CONV_MODULE: ClassVar[Type[nn.Conv3d]] = nn.Conv3d + _FLOAT_BN_MODULE: ClassVar[Optional[Type[nn.Module]]] = nn.BatchNorm3d + _FLOAT_RELU_MODULE: ClassVar[Optional[Type[nn.Module]]] = None def __init__( self, @@ -913,12 +910,12 @@ class ConvBnReLU3d(ConvBn3d): weight_fake_quant: fake quant module for weight """ - _FLOAT_MODULE = nni.ConvBnReLU3d # type: ignore[assignment] - _FLOAT_CONV_MODULE = nn.Conv3d - _FLOAT_BN_MODULE = nn.BatchNorm3d - _FLOAT_RELU_MODULE = nn.ReLU # type: ignore[assignment] + _FLOAT_MODULE: ClassVar[Type[nni.ConvBnReLU3d]] = nni.ConvBnReLU3d # type: ignore[assignment, misc] + _FLOAT_CONV_MODULE: ClassVar[Type[nn.Conv3d]] = nn.Conv3d + _FLOAT_BN_MODULE: ClassVar[Type[nn.BatchNorm3d]] = nn.BatchNorm3d + _FLOAT_RELU_MODULE: ClassVar[Optional[Type[nn.ReLU]]] = nn.ReLU # type: ignore[assignment, misc] # module class after fusing bn into conv - _FUSED_FLOAT_MODULE = nni.ConvReLU3d + _FUSED_FLOAT_MODULE: ClassVar[Optional[Type[nni.ConvReLU3d]]] = nni.ConvReLU3d def __init__( self, @@ -980,10 +977,10 @@ class ConvReLU3d(nnqat.Conv3d, nni._FusedModule): weight_fake_quant: fake quant module for weight """ - _FLOAT_MODULE = nni.ConvReLU3d # type: ignore[assignment] - _FLOAT_CONV_MODULE = nn.Conv3d - _FLOAT_BN_MODULE: None = None - _FLOAT_RELU_MODULE = nn.ReLU + _FLOAT_MODULE: ClassVar[Type[nni.ConvReLU3d]] = nni.ConvReLU3d # type: ignore[assignment,misc] + _FLOAT_CONV_MODULE: ClassVar[Type[nn.Conv3d]] = nn.Conv3d + _FLOAT_BN_MODULE: ClassVar[Optional[Type[nn.Module]]] = None + _FLOAT_RELU_MODULE: ClassVar[Optional[Type[nn.Module]]] = nn.ReLU def __init__( self, diff --git a/torch/ao/nn/intrinsic/quantized/modules/conv_add.py b/torch/ao/nn/intrinsic/quantized/modules/conv_add.py index 66299394091b7c..0d1b7e01f4479f 100644 --- a/torch/ao/nn/intrinsic/quantized/modules/conv_add.py +++ b/torch/ao/nn/intrinsic/quantized/modules/conv_add.py @@ -49,7 +49,7 @@ def __init__( dtype=dtype, ) - def forward(self, input, extra_input): + def forward(self, input, extra_input): # type: ignore[override] # Temporarily using len(shape) instead of ndim due to JIT issue # https://github.com/pytorch/pytorch/issues/23890 if len(input.shape) != 4: @@ -117,7 +117,7 @@ def __init__( dtype=dtype, ) - def forward(self, input, extra_input): + def forward(self, input, extra_input): # type: ignore[override] # Temporarily using len(shape) instead of ndim due to JIT issue # https://github.com/pytorch/pytorch/issues/23890 if len(input.shape) != 4: diff --git a/torch/ao/nn/qat/modules/conv.py b/torch/ao/nn/qat/modules/conv.py index 3be5a7cec8162e..49b559136bbc6e 100644 --- a/torch/ao/nn/qat/modules/conv.py +++ b/torch/ao/nn/qat/modules/conv.py @@ -1,5 +1,5 @@ # mypy: allow-untyped-defs -from typing import Tuple, TypeVar, Union +from typing import ClassVar, Tuple, Type, Union import torch import torch.nn as nn @@ -10,11 +10,9 @@ __all__ = ["Conv1d", "Conv2d", "Conv3d"] -MOD = TypeVar("MOD", bound=nn.modules.conv._ConvNd) - class _ConvNd(nn.modules.conv._ConvNd): - _FLOAT_MODULE = MOD + _FLOAT_MODULE: ClassVar[Type[nn.modules.conv._ConvNd]] def __init__( self, @@ -136,8 +134,8 @@ class Conv1d(_ConvNd, nn.Conv1d): Attributes: weight_fake_quant: fake quant module for weight """ - _FLOAT_MODULE = nn.Conv1d - _FLOAT_CONV_MODULE = nn.Conv1d + _FLOAT_MODULE: ClassVar[Type[nn.Conv1d]] = nn.Conv1d # type: ignore[assignment,misc] + _FLOAT_CONV_MODULE: ClassVar[Type[nn.Conv1d]] = nn.Conv1d def __init__( self, @@ -197,8 +195,8 @@ class Conv2d(_ConvNd, nn.Conv2d): Attributes: weight_fake_quant: fake quant module for weight """ - _FLOAT_MODULE = nn.Conv2d - _FLOAT_CONV_MODULE = nn.Conv2d + _FLOAT_MODULE: ClassVar[Type[nn.Conv2d]] = nn.Conv2d # type: ignore[assignment,misc] + _FLOAT_CONV_MODULE: ClassVar[Type[nn.Conv2d]] = nn.Conv2d def __init__( self, @@ -261,8 +259,8 @@ class Conv3d(_ConvNd, nn.Conv3d): Attributes: weight_fake_quant: fake quant module for weight """ - _FLOAT_MODULE = nn.Conv3d - _FLOAT_CONV_MODULE = nn.Conv3d + _FLOAT_MODULE: ClassVar[Type[nn.Conv3d]] = nn.Conv3d # type: ignore[assignment,misc] + _FLOAT_CONV_MODULE: ClassVar[Type[nn.Conv3d]] = nn.Conv3d def __init__( self, diff --git a/torch/ao/nn/quantized/dynamic/modules/conv.py b/torch/ao/nn/quantized/dynamic/modules/conv.py index fa86dcdd93e17d..f379e487f65de8 100644 --- a/torch/ao/nn/quantized/dynamic/modules/conv.py +++ b/torch/ao/nn/quantized/dynamic/modules/conv.py @@ -2,6 +2,7 @@ r"""Dynamically quantized convolution modules.""" import warnings +from typing import ClassVar, Optional, Type import torch import torch.ao.nn.quantized as nnq @@ -47,9 +48,9 @@ class Conv1d(nnq.Conv1d): """ - _FLOAT_MODULE = nn.Conv1d - _NNIQAT_CONV_BN_MODULE = None # type: ignore[assignment] - _NNI_CONV_RELU_MODULE = None # type: ignore[assignment] + _FLOAT_MODULE: ClassVar[Type[nn.Conv1d]] = nn.Conv1d + _NNIQAT_CONV_BN_MODULE: ClassVar[Optional[Type[nn.Module]]] = None + _NNI_CONV_RELU_MODULE: ClassVar[Optional[Type[nn.Module]]] = None def __init__( self, @@ -132,9 +133,9 @@ class Conv2d(nnq.Conv2d): >>> output = m(input) """ - _FLOAT_MODULE = nn.Conv2d - _NNIQAT_CONV_BN_MODULE = None # type: ignore[assignment] - _NNI_CONV_RELU_MODULE = None # type: ignore[assignment] + _FLOAT_MODULE: ClassVar[Type[nn.Conv2d]] = nn.Conv2d + _NNIQAT_CONV_BN_MODULE: ClassVar[Optional[Type[nn.Module]]] = None + _NNI_CONV_RELU_MODULE: ClassVar[Optional[Type[nn.Module]]] = None def __init__( self, @@ -216,9 +217,9 @@ class Conv3d(nnq.Conv3d): >>> output = m(input) """ - _FLOAT_MODULE = nn.Conv3d - _NNIQAT_CONV_BN_MODULE = None # type: ignore[assignment] - _NNI_CONV_RELU_MODULE = None # type: ignore[assignment] + _FLOAT_MODULE: ClassVar[Type[nn.Conv3d]] = nn.Conv3d + _NNIQAT_CONV_BN_MODULE: ClassVar[Optional[Type[nn.Module]]] = None + _NNI_CONV_RELU_MODULE: ClassVar[Optional[Type[nn.Module]]] = None def __init__( self, @@ -308,7 +309,7 @@ class ConvTranspose1d(nnq.ConvTranspose1d): torch.Size([1, 16, 12]) """ - _FLOAT_MODULE = nn.ConvTranspose1d + _FLOAT_MODULE: ClassVar[Type[nn.ConvTranspose1d]] = nn.ConvTranspose1d def __init__( self, @@ -390,7 +391,7 @@ class ConvTranspose2d(nnq.ConvTranspose2d): torch.Size([1, 16, 12, 12]) """ - _FLOAT_MODULE = nn.ConvTranspose2d + _FLOAT_MODULE: ClassVar[Type[nn.ConvTranspose2d]] = nn.ConvTranspose2d def __init__( self, @@ -472,7 +473,7 @@ class ConvTranspose3d(nnq.ConvTranspose3d): torch.Size([1, 16, 12, 12, 12]) """ - _FLOAT_MODULE = nn.ConvTranspose3d + _FLOAT_MODULE: ClassVar[Type[nn.ConvTranspose3d]] = nn.ConvTranspose3d def __init__( self, diff --git a/torch/ao/nn/quantized/modules/conv.py b/torch/ao/nn/quantized/modules/conv.py index 48ca18028acf94..0ef2f3af1dcd44 100644 --- a/torch/ao/nn/quantized/modules/conv.py +++ b/torch/ao/nn/quantized/modules/conv.py @@ -1,7 +1,7 @@ # mypy: allow-untyped-defs r"""Quantized convolution modules.""" -from typing import List, Optional, TypeVar +from typing import ClassVar, List, Optional, Type import torch import torch.ao.nn.intrinsic as nni @@ -386,11 +386,11 @@ class Conv1d(_ConvNd): """ - _FLOAT_MODULE = nn.Conv1d - _NNIQAT_CONV_BN_MODULE = nniqat.ConvBn1d - _NNI_CONV_RELU_MODULE = nni.ConvReLU1d - _NNI_CONV_ADD_MODULE: None = None - _NNI_CONV_ADD_RELU_MODULE: None = None + _FLOAT_MODULE: ClassVar[Type[nn.Conv1d]] = nn.Conv1d + _NNIQAT_CONV_BN_MODULE: ClassVar[Optional[Type[nn.Module]]] = nniqat.ConvBn1d + _NNI_CONV_RELU_MODULE: ClassVar[Optional[Type[nn.Module]]] = nni.ConvReLU1d + _NNI_CONV_ADD_MODULE: ClassVar[Optional[Type[nn.Module]]] = None + _NNI_CONV_ADD_RELU_MODULE: ClassVar[Optional[Type[nn.Module]]] = None def __init__( self, @@ -518,11 +518,11 @@ class Conv2d(_ConvNd): >>> output = m(q_input) """ - _FLOAT_MODULE = nn.Conv2d - _NNIQAT_CONV_BN_MODULE = nniqat.ConvBn2d - _NNI_CONV_RELU_MODULE = nni.ConvReLU2d - _NNI_CONV_ADD_MODULE = nni.ConvAdd2d - _NNI_CONV_ADD_RELU_MODULE = nni.ConvAddReLU2d + _FLOAT_MODULE: ClassVar[Type[nn.Conv2d]] = nn.Conv2d + _NNIQAT_CONV_BN_MODULE: ClassVar[Optional[Type[nn.Module]]] = nniqat.ConvBn2d + _NNI_CONV_RELU_MODULE: ClassVar[Optional[Type[nn.Module]]] = nni.ConvReLU2d + _NNI_CONV_ADD_MODULE: ClassVar[Type[nni.ConvAdd2d]] = nni.ConvAdd2d + _NNI_CONV_ADD_RELU_MODULE: ClassVar[Type[nni.ConvAddReLU2d]] = nni.ConvAddReLU2d def __init__( self, @@ -647,11 +647,11 @@ class Conv3d(_ConvNd): >>> output = m(q_input) """ - _FLOAT_MODULE = nn.Conv3d - _NNIQAT_CONV_BN_MODULE = nniqat.ConvBn3d - _NNI_CONV_RELU_MODULE = nni.ConvReLU3d - _NNI_CONV_ADD_MODULE: None = None - _NNI_CONV_ADD_RELU_MODULE: None = None + _FLOAT_MODULE: ClassVar[Type[nn.Conv3d]] = nn.Conv3d + _NNIQAT_CONV_BN_MODULE: ClassVar[Optional[Type[nn.Module]]] = nniqat.ConvBn3d + _NNI_CONV_RELU_MODULE: ClassVar[Optional[Type[nn.Module]]] = nni.ConvReLU3d + _NNI_CONV_ADD_MODULE: ClassVar[Optional[Type[nn.Module]]] = None + _NNI_CONV_ADD_RELU_MODULE: ClassVar[Optional[Type[nn.Module]]] = None def __init__( self, @@ -740,11 +740,10 @@ def from_float(cls, mod, use_precomputed_fake_quant=False): # === Transposed Convolutions === -MOD = TypeVar("MOD", bound=nn.modules.conv._ConvNd) class _ConvTransposeNd(_ConvNd): - _FLOAT_MODULE = MOD + _FLOAT_MODULE: ClassVar[Type[nn.modules.conv._ConvNd]] def __init__( self, @@ -914,7 +913,7 @@ class ConvTranspose1d(_ConvTransposeNd): torch.Size([1, 16, 12]) """ - _FLOAT_MODULE = nn.ConvTranspose1d + _FLOAT_MODULE: ClassVar[Type[nn.ConvTranspose1d]] = nn.ConvTranspose1d def __init__( self, @@ -1037,7 +1036,7 @@ class ConvTranspose2d(_ConvTransposeNd): torch.Size([1, 16, 12, 12]) """ - _FLOAT_MODULE = nn.ConvTranspose2d + _FLOAT_MODULE: ClassVar[Type[nn.ConvTranspose2d]] = nn.ConvTranspose2d def __init__( self, @@ -1162,7 +1161,7 @@ class ConvTranspose3d(_ConvTransposeNd): torch.Size([1, 16, 12, 12, 12]) """ - _FLOAT_MODULE = nn.ConvTranspose3d + _FLOAT_MODULE: ClassVar[Type[nn.ConvTranspose3d]] = nn.ConvTranspose3d def __init__( self, diff --git a/torch/ao/ns/_numeric_suite.py b/torch/ao/ns/_numeric_suite.py index 9a80b4af6f4794..9e4cf94303c9f1 100644 --- a/torch/ao/ns/_numeric_suite.py +++ b/torch/ao/ns/_numeric_suite.py @@ -198,7 +198,7 @@ def __init__(self): self.stats["float"] = [] self.stats["quantized"] = [] - def forward(self, x, y): + def forward(self, x, y): # type: ignore[override] # fmt: off """ """ # blank docblock to make autodoc happy diff --git a/torch/ao/ns/_numeric_suite_fx.py b/torch/ao/ns/_numeric_suite_fx.py index 3cea146b8031fd..da79b03ce88089 100644 --- a/torch/ao/ns/_numeric_suite_fx.py +++ b/torch/ao/ns/_numeric_suite_fx.py @@ -261,7 +261,7 @@ def __init__(self, *args, **kwargs): self.comparisons = [] # precalculated comparisons function - def forward(self, x, x_ref): + def forward(self, x, x_ref): # type: ignore[override] # fmt: off """ """ # blank docblock to make autodoc happy diff --git a/torch/ao/ns/fx/mappings.py b/torch/ao/ns/fx/mappings.py index 875af8c875b0a4..56784dd1780308 100644 --- a/torch/ao/ns/fx/mappings.py +++ b/torch/ao/ns/fx/mappings.py @@ -410,7 +410,7 @@ def get_base_name_to_sets_of_related_ops() -> Dict[str, Set[NSNodeTargetType]]: # Add function swaps from default lowering path # - for source, ( + for source, ( # type:ignore[assignment] target1, target2, ) in _lower_to_native_backend.STATIC_LOWER_FUNCTIONAL_MAP.items(): @@ -422,7 +422,7 @@ def get_base_name_to_sets_of_related_ops() -> Dict[str, Set[NSNodeTargetType]]: _lower_to_native_backend.QBIN_RELU_OP_MAPPING, quantization_mappings.DEFAULT_FLOAT_TO_QUANTIZED_OPERATOR_MAPPINGS, ): - for source, target in source_to_target.items(): + for source, target in source_to_target.items(): # type:ignore[assignment] new_connections.append((source, target)) # @@ -432,7 +432,7 @@ def get_base_name_to_sets_of_related_ops() -> Dict[str, Set[NSNodeTargetType]]: for source_to_target in ( quantization_mappings.DEFAULT_DYNAMIC_QUANT_MODULE_MAPPINGS, ): - for source, target in source_to_target.items(): + for source, target in source_to_target.items(): # type:ignore[assignment] new_connections.append((source, target)) # add the new connections from backend_config diff --git a/torch/ao/pruning/_experimental/data_sparsifier/base_data_sparsifier.py b/torch/ao/pruning/_experimental/data_sparsifier/base_data_sparsifier.py index b81ee658e55318..75d86737d832d5 100644 --- a/torch/ao/pruning/_experimental/data_sparsifier/base_data_sparsifier.py +++ b/torch/ao/pruning/_experimental/data_sparsifier/base_data_sparsifier.py @@ -71,7 +71,7 @@ def __init__(self, data_list: Optional[List[Tuple[str, Any]]] = None, **defaults # add data with default config here [self.add_data(name, data, **self.defaults) for name, data in data_list] - def prepare(self): + def prepare(self, model, config): raise NotImplementedError("this function is undefined for this class") def _extract_weight(self, data): @@ -266,7 +266,7 @@ def __getstate__(self): "_container": self._container.state_dict(), } - def __repr__(self): + def __repr__(self): # type:ignore[override] format_string = self.__class__.__name__ + " (" for name, sparse_args in self.data_groups.items(): format_string += "\n" @@ -299,7 +299,7 @@ def squash_mask(self, *args, leave_parametrized=True, names=None, **kwargs): self._container, name, leave_parametrized=leave_parametrized ) - def step(self): + def step(self): # type:ignore[override] if not self.enable_mask_update: return with torch.no_grad(): diff --git a/torch/ao/pruning/_experimental/data_sparsifier/data_norm_sparsifier.py b/torch/ao/pruning/_experimental/data_sparsifier/data_norm_sparsifier.py index 34444fa79df0fe..ec867205f64ab1 100644 --- a/torch/ao/pruning/_experimental/data_sparsifier/data_norm_sparsifier.py +++ b/torch/ao/pruning/_experimental/data_sparsifier/data_norm_sparsifier.py @@ -156,7 +156,7 @@ def __get_data_level_mask(self, data, sparsity_level, sparse_block_shape): ] # squeeze only the first 2 dimension return mask - def update_mask( + def update_mask( # type: ignore[override] self, name, data, sparsity_level, sparse_block_shape, zeros_per_block, **kwargs ): values_per_block = reduce(operator.mul, sparse_block_shape) diff --git a/torch/ao/pruning/_experimental/pruner/FPGM_pruner.py b/torch/ao/pruning/_experimental/pruner/FPGM_pruner.py index 54177272ce4f2a..3da27ba38df55b 100644 --- a/torch/ao/pruning/_experimental/pruner/FPGM_pruner.py +++ b/torch/ao/pruning/_experimental/pruner/FPGM_pruner.py @@ -75,7 +75,9 @@ def _compute_distance(self, t): return distance - def update_mask(self, module, tensor_name, sparsity_level, **kwargs): + def update_mask( # type: ignore[override] + self, module, tensor_name, sparsity_level, **kwargs + ): tensor_weight = getattr(module, tensor_name) mask = getattr(module.parametrizations, tensor_name)[0].mask diff --git a/torch/ao/pruning/sparsifier/nearly_diagonal_sparsifier.py b/torch/ao/pruning/sparsifier/nearly_diagonal_sparsifier.py index 47c567ec6d93eb..a4d42ea803289c 100644 --- a/torch/ao/pruning/sparsifier/nearly_diagonal_sparsifier.py +++ b/torch/ao/pruning/sparsifier/nearly_diagonal_sparsifier.py @@ -33,7 +33,9 @@ def __init__(self, nearliness: int = 1): defaults = {"nearliness": nearliness} super().__init__(defaults=defaults) - def update_mask(self, module, tensor_name, nearliness, **kwargs): + def update_mask( # type:ignore[override] + self, module, tensor_name, nearliness, **kwargs + ): mask = getattr(module.parametrizations, tensor_name)[0].mask mask.data = torch.zeros_like(mask) if nearliness <= 0: diff --git a/torch/ao/pruning/sparsifier/weight_norm_sparsifier.py b/torch/ao/pruning/sparsifier/weight_norm_sparsifier.py index 3c2eef0dd15c79..a25b7ffbca61cb 100644 --- a/torch/ao/pruning/sparsifier/weight_norm_sparsifier.py +++ b/torch/ao/pruning/sparsifier/weight_norm_sparsifier.py @@ -208,7 +208,7 @@ def _make_block_mask(self, data, sparse_block_shape, zeros_per_block, mask=None) mask.data = mask_reshape.squeeze().reshape(mask.shape).contiguous() return mask - def update_mask( + def update_mask( # type: ignore[call-override, override] self, module, tensor_name, diff --git a/torch/ao/quantization/__init__.py b/torch/ao/quantization/__init__.py index 840d2715c7717f..503755d383bb93 100644 --- a/torch/ao/quantization/__init__.py +++ b/torch/ao/quantization/__init__.py @@ -216,5 +216,5 @@ def __init__( def forward(self, x: Tensor) -> Tensor: return x - def calculate_qparams(self): + def calculate_qparams(self): # type:ignore[override] return self.derive_qparams_fn(self.obs_or_fqs) diff --git a/torch/ao/quantization/_correct_bias.py b/torch/ao/quantization/_correct_bias.py index 4be37467de1818..03d259a0ad26ca 100644 --- a/torch/ao/quantization/_correct_bias.py +++ b/torch/ao/quantization/_correct_bias.py @@ -61,7 +61,7 @@ def __init__(self): self.float_sum = None self.quant_sum = None - def forward(self, x, y): + def forward(self, x, y): # type: ignore[override] """Compute the average of quantized and floating-point data from modules. The inputs x,y are output data from the quantized and floating-point modules. diff --git a/torch/ao/quantization/experimental/observer.py b/torch/ao/quantization/experimental/observer.py index 631ec11b1669d1..1cf27c97054c1e 100644 --- a/torch/ao/quantization/experimental/observer.py +++ b/torch/ao/quantization/experimental/observer.py @@ -34,7 +34,7 @@ def __init__(self, b, k, dtype=torch.quint8) -> None: # min_val and max_val are optional args to override # the min_val and max_val observed by forward - def calculate_qparams(self, signed): + def calculate_qparams(self, signed): # type:ignore[override] return self._calculate_qparams(signed, self.min_val, self.max_val) r""" Calculates nonuniform quantization parameters according to APoT paper: diff --git a/torch/ao/quantization/quantizer/x86_inductor_quantizer.py b/torch/ao/quantization/quantizer/x86_inductor_quantizer.py index 574af30a7159b9..6042fd2ee5adb9 100644 --- a/torch/ao/quantization/quantizer/x86_inductor_quantizer.py +++ b/torch/ao/quantization/quantizer/x86_inductor_quantizer.py @@ -225,7 +225,7 @@ def _map_module_function_to_aten_operator_type(): ), ) for map_item in map_list: - module_function_to_aten_operator.update(dict.fromkeys(map_item[0], map_item[1])) # type: ignore[call-overload] + module_function_to_aten_operator.update(dict.fromkeys(map_item[0], map_item[1])) # type: ignore[arg-type, call-overload] return module_function_to_aten_operator diff --git a/torch/distributed/_functional_collectives.py b/torch/distributed/_functional_collectives.py index 77b962bf3df47c..4127885dccc1f6 100644 --- a/torch/distributed/_functional_collectives.py +++ b/torch/distributed/_functional_collectives.py @@ -600,7 +600,7 @@ def __tensor_unflatten__(inner_tensors, meta, outer_size, outer_stride): elem = inner_tensors["elem"] return AsyncCollectiveTensor(elem) - def __repr__(self): + def __repr__(self) -> str: # type: ignore[override] return f"AsyncCollectiveTensor({self.trigger_wait()})" def trigger_wait(self): @@ -653,7 +653,7 @@ def wrap(e: torch.Tensor): return out - def numpy(self): + def numpy(self): # type: ignore[override] return self.wait().numpy() diff --git a/torch/distributed/_shard/sharded_tensor/api.py b/torch/distributed/_shard/sharded_tensor/api.py index 68df582cd5145e..d50160ca8ecc31 100644 --- a/torch/distributed/_shard/sharded_tensor/api.py +++ b/torch/distributed/_shard/sharded_tensor/api.py @@ -1200,7 +1200,7 @@ def remote_shards(self) -> Dict[int, List[rpc.RRef[Shard]]]: def __hash__(self): return id(self) - def __repr__(self): + def __repr__(self) -> str: # type: ignore[override] return f"ShardedTensor({self._metadata})" @dataclass diff --git a/torch/distributed/algorithms/_comm_hooks/default_hooks.py b/torch/distributed/algorithms/_comm_hooks/default_hooks.py index 0acafd6868d3bb..032b1979e576ce 100644 --- a/torch/distributed/algorithms/_comm_hooks/default_hooks.py +++ b/torch/distributed/algorithms/_comm_hooks/default_hooks.py @@ -168,6 +168,7 @@ def fp16_compress_hook( output (torch.Tensor): Stores a single shard of the gradient after ``reduce_scatter``. """ fp16_hook = functools.partial(_low_precision_hook, torch.float16) + assert output is not None return fp16_hook(state, grad, output) @@ -189,4 +190,5 @@ def bf16_compress_hook( output (torch.Tensor): Stores a single shard of the gradient after ``reduce_scatter``. """ bf16_hook = functools.partial(_low_precision_hook, torch.bfloat16) + assert output is not None return bf16_hook(state, grad, output) diff --git a/torch/distributed/optim/post_localSGD_optimizer.py b/torch/distributed/optim/post_localSGD_optimizer.py index db65856e32adad..3c0027d1124073 100644 --- a/torch/distributed/optim/post_localSGD_optimizer.py +++ b/torch/distributed/optim/post_localSGD_optimizer.py @@ -96,7 +96,7 @@ def load_state_dict(self, state_dict): ) self.averager.step = 0 - def step(self): + def step(self): # type: ignore[override] r""" Performs a single optimization step (parameter update). """ diff --git a/torch/distributed/pipelining/_IR.py b/torch/distributed/pipelining/_IR.py index 036ad6cf8621eb..3010bccd377c94 100644 --- a/torch/distributed/pipelining/_IR.py +++ b/torch/distributed/pipelining/_IR.py @@ -364,7 +364,7 @@ def __init__(self, module, garbage_collect_values=True): super().__init__(module, garbage_collect_values) self.value_remap = {} - def run(self, *args, initial_env=None): + def run(self, *args, initial_env=None): # type: ignore[override] self.value_remap = {} return super().run(*args, initial_env=initial_env) @@ -932,8 +932,7 @@ def move_param_to_callee( if node.op == "get_attr": # get_attr might get access deeper level attribute fqn = scope + "." + node.target if scope else node.target - if fqn in unused_attributes: # used, remove it - unused_attributes.remove(fqn) + unused_attributes.discard(fqn) for _name, _submod in _mod.named_children(): stack.append((scope + "." + _name if scope else _name, _submod)) # delete unused attributes diff --git a/torch/distributed/tensor/_api.py b/torch/distributed/tensor/_api.py index 583aa20c1dc188..8684d7b0cafa0f 100644 --- a/torch/distributed/tensor/_api.py +++ b/torch/distributed/tensor/_api.py @@ -283,7 +283,7 @@ def __new__( # pyre-fixme[14]: `__repr__` overrides method defined in `DTensor` inconsistently. # pyre-fixme[3]: Return type must be annotated. - def __repr__(self): + def __repr__(self): # type: ignore[override] # TODO: consider all_gather the local tensors for better debugging return f"DTensor(local_tensor={self._local_tensor}, device_mesh={self._spec.mesh}, placements={self._spec.placements})" diff --git a/torch/distributed/tensor/_shards_wrapper.py b/torch/distributed/tensor/_shards_wrapper.py index de396473b77c91..df8c7d09e38a4a 100644 --- a/torch/distributed/tensor/_shards_wrapper.py +++ b/torch/distributed/tensor/_shards_wrapper.py @@ -309,7 +309,7 @@ def __hash__(self): # pyre-fixme[14]: `__repr__` overrides method defined in `torch._tensor.Tensor` inconsistently. # pyre-fixme[3]: Return type must be annotated. - def __repr__(self): + def __repr__(self) -> str: # type: ignore[override] return f"LocalShardsWrapper:{self._local_shards} {self._storage_meta}" def __str__(self) -> str: diff --git a/torch/distributed/tensor/parallel/input_reshard.py b/torch/distributed/tensor/parallel/input_reshard.py index 630c287cae88f5..401ca3d3fdaf69 100644 --- a/torch/distributed/tensor/parallel/input_reshard.py +++ b/torch/distributed/tensor/parallel/input_reshard.py @@ -42,6 +42,7 @@ def input_reshard( cx: Optional[torch.autograd.graph.saved_tensors_hooks] = None def input_reshard_forward_pre_hook(_: torch.nn.Module, _i: Tuple[Any, ...]) -> None: + assert input_reshard_dim is not None saved_tensor_hooks = torch.autograd.graph.saved_tensors_hooks( partial(_pack_hook_tp, tp_device_mesh, input_reshard_dim), partial(_unpack_hook_tp, tp_device_mesh, input_reshard_dim), diff --git a/torch/distributions/constraints.py b/torch/distributions/constraints.py index 3c510bd32abc62..e6f730b123afd1 100644 --- a/torch/distributions/constraints.py +++ b/torch/distributions/constraints.py @@ -204,7 +204,7 @@ def __init__( self._is_discrete = is_discrete self._event_dim = event_dim - def __call__(self, fn): + def __call__(self, fn): # type: ignore[override] """ Support for syntax to customize static attributes:: diff --git a/torch/distributions/von_mises.py b/torch/distributions/von_mises.py index bd8fa87f2619a4..a4d403383d9cb0 100644 --- a/torch/distributions/von_mises.py +++ b/torch/distributions/von_mises.py @@ -177,7 +177,7 @@ def sample(self, sample_shape=torch.Size()): self._loc, self._concentration, self._proposal_r, x ).to(self.loc.dtype) - def expand(self, batch_shape): + def expand(self, batch_shape, _instance=None): try: return super().expand(batch_shape) except NotImplementedError: diff --git a/torch/masked/maskedtensor/core.py b/torch/masked/maskedtensor/core.py index 69850df81c65a1..366cf45eb2d50f 100644 --- a/torch/masked/maskedtensor/core.py +++ b/torch/masked/maskedtensor/core.py @@ -258,7 +258,7 @@ def _set_data_mask(self, data, mask): self._masked_mask = mask self._validate_members() - def __repr__(self): + def __repr__(self): # type: ignore[override] formatter = "{0:8.4f}" if self.dim() == 0: scalar_data = self.get_data().item() @@ -350,7 +350,7 @@ def get_mask(self): def is_sparse_coo(self): return self.layout == torch.sparse_coo - def is_sparse_csr(self): + def is_sparse_csr(self): # type: ignore[override] return self.layout == torch.sparse_csr # Update later to support more sparse layouts diff --git a/torch/multiprocessing/reductions.py b/torch/multiprocessing/reductions.py index fa0818571a93c0..c02ad3dc2362ba 100644 --- a/torch/multiprocessing/reductions.py +++ b/torch/multiprocessing/reductions.py @@ -74,7 +74,7 @@ def __init__(self) -> None: def _after_fork(self): self.lock = threading.Lock() - def get(self, key): + def get(self, key): # type: ignore[override] with self.lock: return dict.get(self, key) diff --git a/torch/nested/_internal/nested_tensor.py b/torch/nested/_internal/nested_tensor.py index 76f8d5a29e2d8f..d39eb12d919c87 100644 --- a/torch/nested/_internal/nested_tensor.py +++ b/torch/nested/_internal/nested_tensor.py @@ -209,7 +209,7 @@ def _max_seqlen(self): def _min_seqlen(self): return self._get_min_seqlen() - def __repr__(self): + def __repr__(self): # type: ignore[override] # We should implement this in torch/_tensor_str.py instead grad_fn_str = ( f", requires_grad={self.requires_grad}" if self.requires_grad else "" diff --git a/torch/nn/attention/bias.py b/torch/nn/attention/bias.py index 2fed60030a6898..da7acb957d96d7 100644 --- a/torch/nn/attention/bias.py +++ b/torch/nn/attention/bias.py @@ -289,7 +289,7 @@ def __torch_function__(cls, func, types, args=(), kwargs=None): ) return cls._dispatch(*args, **kwargs) - def __repr__(self): + def __repr__(self): # type:ignore[override] return self._materialize().__repr__() diff --git a/torch/optim/lr_scheduler.py b/torch/optim/lr_scheduler.py index 57dcbd85a83164..820a57f3e0b756 100644 --- a/torch/optim/lr_scheduler.py +++ b/torch/optim/lr_scheduler.py @@ -910,7 +910,7 @@ def __init__( self._last_lr = schedulers[0].get_last_lr() - def step(self): + def step(self): # type: ignore[override] """Perform a step.""" self.last_epoch += 1 idx = bisect_right(self._milestones, self.last_epoch) @@ -1179,7 +1179,7 @@ def __init__( group["lr"] for group in self._schedulers[-1].optimizer.param_groups ] - def step(self): + def step(self): # type: ignore[override] """Perform a step.""" for scheduler in self._schedulers: scheduler.step() diff --git a/torch/sparse/semi_structured.py b/torch/sparse/semi_structured.py index 0017a10e6771bb..5e8a632633e09a 100644 --- a/torch/sparse/semi_structured.py +++ b/torch/sparse/semi_structured.py @@ -295,7 +295,7 @@ def _pad_dense_input(cls, dense_input: torch.Tensor) -> torch.Tensor: else: return dense_input - def to_dense(self): + def to_dense(self): # type:ignore[override] col = self.shape[-1] return torch.mm(self, torch.eye(col, dtype=self.dtype, device=self.device)) @@ -420,7 +420,7 @@ def from_dense( requires_grad=original_tensor.requires_grad, ) - def to_dense(self): + def to_dense(self): # type: ignore[override] assert self.meta is not None and self.packed is not None return ( sparse_semi_structured_to_dense_cutlass( diff --git a/torch/utils/_sympy/functions.py b/torch/utils/_sympy/functions.py index 4924fff6dea6be..ed597402d8b637 100644 --- a/torch/utils/_sympy/functions.py +++ b/torch/utils/_sympy/functions.py @@ -310,11 +310,11 @@ def eval(cls, base, divisor, modulus): if isinstance(base, FloorDiv): return ModularIndexing(base.args[0], base.args[1] * divisor, modulus) - def _eval_is_nonnegative(self): + def _eval_is_nonnegative(self): # type:ignore[override] p, q = self.args[:2] return fuzzy_eq(p.is_nonnegative, q.is_nonnegative) # type: ignore[attr-defined] - def _eval_is_positive(self): + def _eval_is_positive(self): # type:ignore[override] p, q = self.args[:2] return fuzzy_eq(p.is_positive, q.is_positive) # type: ignore[attr-defined] @@ -329,14 +329,14 @@ class Where(sympy.Function): def _eval_is_integer(self): return True if self.args[1].is_integer and self.args[2].is_integer else None # type: ignore[attr-defined] - def _eval_is_nonnegative(self): + def _eval_is_nonnegative(self): # type:ignore[override] return ( True if self.args[1].is_nonnegative and self.args[2].is_nonnegative # type: ignore[attr-defined] else None ) - def _eval_is_positive(self): + def _eval_is_positive(self): # type:ignore[override] return True if self.args[1].is_positive and self.args[2].is_positive else None # type: ignore[attr-defined] @classmethod @@ -397,7 +397,7 @@ def eval(cls, p, q): return S.Zero # NB: args[1] for PythonMod - def _eval_is_nonnegative(self): + def _eval_is_nonnegative(self): # type:ignore[override] return True if self.args[1].is_positive else None # type: ignore[attr-defined] def _eval_is_nonpositive(self): @@ -823,13 +823,13 @@ class Max(MinMaxBase, Application): # type: ignore[misc] zero = S.Infinity identity = S.NegativeInfinity - def _eval_is_positive(self): + def _eval_is_positive(self): # type:ignore[override] return fuzzy_or(a.is_positive for a in self.args) # type: ignore[attr-defined] - def _eval_is_nonnegative(self): + def _eval_is_nonnegative(self): # type:ignore[override] return fuzzy_or(a.is_nonnegative for a in self.args) # type: ignore[attr-defined] - def _eval_is_negative(self): + def _eval_is_negative(self): # type:ignore[override] return fuzzy_and(a.is_negative for a in self.args) @@ -841,13 +841,13 @@ class Min(MinMaxBase, Application): # type: ignore[misc] zero = S.NegativeInfinity identity = S.Infinity - def _eval_is_positive(self): + def _eval_is_positive(self): # type:ignore[override] return fuzzy_and(a.is_positive for a in self.args) # type: ignore[attr-defined] - def _eval_is_nonnegative(self): + def _eval_is_nonnegative(self): # type:ignore[override] return fuzzy_and(a.is_nonnegative for a in self.args) # type: ignore[attr-defined] - def _eval_is_negative(self): + def _eval_is_negative(self): # type:ignore[override] return fuzzy_or(a.is_negative for a in self.args) diff --git a/torch/utils/weak.py b/torch/utils/weak.py index cc272a7f26375a..f729ff06489ffe 100644 --- a/torch/utils/weak.py +++ b/torch/utils/weak.py @@ -263,7 +263,7 @@ def pop(self, key, *args): def setdefault(self, key, default=None): return self.data.setdefault(self.ref_type(key, self._remove), default) # CHANGED - def update(self, dict=None, **kwargs): + def update(self, dict=None, **kwargs): # type: ignore[override] d = self.data if dict is not None: if not hasattr(dict, "items"): From 41bd750278fee0e8dd76c28cfe30f855e43ec92d Mon Sep 17 00:00:00 2001 From: Jez Ng Date: Fri, 13 Sep 2024 16:33:38 -0700 Subject: [PATCH 0907/1018] Add Triton CPU as an Inductor backend (#133408) The goal is to use Inductor-generated kernels to stress test the new Triton CPU backend. Pull Request resolved: https://github.com/pytorch/pytorch/pull/133408 Approved by: https://github.com/jansel --- .ci/docker/build.sh | 1 + .../fsdp/test_fully_shard_compile.py | 22 +-- .../fully_shard/test_fully_shard_compile.py | 4 +- .../test_replicate_with_compiler.py | 14 +- .../_tensor/test_dtensor_compile.py | 10 +- .../fsdp/test_fsdp_use_orig_params.py | 4 +- .../tensor/parallel/test_micro_pipeline_tp.py | 18 +- .../test_c10d_functional_native.py | 28 +-- .../test_compute_comm_reordering.py | 14 +- test/distributed/test_dynamo_distributed.py | 50 +++--- test/distributed/test_functional_api.py | 8 +- test/distributed/test_inductor_collectives.py | 38 ++-- test/inductor/test_cuda_repro.py | 4 +- test/inductor/test_memory_planning.py | 3 +- test/inductor/test_triton_cpu_backend.py | 34 ++++ test/test_sparse_semi_structured.py | 8 +- torch/_dynamo/device_interface.py | 49 +++++ torch/_inductor/autotune_process.py | 72 ++++---- torch/_inductor/codegen/common.py | 7 +- torch/_inductor/codegen/cpp.py | 2 +- .../_inductor/codegen/cpp_template_kernel.py | 4 +- .../codegen/cpu_device_op_overrides.py | 26 +++ torch/_inductor/codegen/halide.py | 1 + .../codegen/rocm/rocm_benchmark_request.py | 8 +- torch/_inductor/codegen/triton.py | 7 +- torch/_inductor/codegen/wrapper.py | 168 +++++++++--------- torch/_inductor/config.py | 2 +- torch/_inductor/runtime/benchmarking.py | 2 +- torch/_inductor/runtime/triton_helpers.py | 15 ++ torch/_inductor/runtime/triton_heuristics.py | 3 + torch/_inductor/scheduler.py | 3 +- torch/_inductor/select_algorithm.py | 17 +- torch/_inductor/utils.py | 8 +- torch/utils/_triton.py | 16 +- 34 files changed, 422 insertions(+), 248 deletions(-) create mode 100644 test/inductor/test_triton_cpu_backend.py create mode 100644 torch/_inductor/codegen/cpu_device_op_overrides.py diff --git a/.ci/docker/build.sh b/.ci/docker/build.sh index 61f7acfe7633fe..8aafbb89d811ec 100755 --- a/.ci/docker/build.sh +++ b/.ci/docker/build.sh @@ -379,6 +379,7 @@ case "$image" in GCC_VERSION=11 CONDA_CMAKE=yes HALIDE=yes + TRITON=yes ;; pytorch-linux-focal-linter) # TODO: Use 3.9 here because of this issue https://github.com/python/mypy/issues/13627. diff --git a/test/distributed/_composable/fsdp/test_fully_shard_compile.py b/test/distributed/_composable/fsdp/test_fully_shard_compile.py index 873fc4ba8bfaa2..2f038c713ff225 100644 --- a/test/distributed/_composable/fsdp/test_fully_shard_compile.py +++ b/test/distributed/_composable/fsdp/test_fully_shard_compile.py @@ -30,7 +30,7 @@ ModelArgs, Transformer, ) -from torch.utils._triton import has_triton +from torch.testing._internal.inductor_utils import HAS_GPU log = logging.getLogger(__name__) @@ -45,7 +45,7 @@ def _is_fallback_op_in_snodes(snodes, op): class TestFullyShardCompileCompute(FSDPTest): - @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") + @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") @skip_if_lt_x_gpu(2) def test_disable_compiling_hooks(self): self.run_subtests( @@ -472,14 +472,14 @@ def input_creation_fn(): return model_init_fn, input_creation_fn @skipIfRocm - @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") + @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") def test_simple_mlp_fullgraph_backend_aot_eager(self): self._test_traceable_fsdp( *self._create_simple_mlp_factory_fns(), "aot_eager", fullgraph=True ) @skipIfRocm - @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") + @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") def test_simple_mlp_fullgraph_backend_aot_eager_decomp_partition(self): self._test_traceable_fsdp( *self._create_simple_mlp_factory_fns(), @@ -488,7 +488,7 @@ def test_simple_mlp_fullgraph_backend_aot_eager_decomp_partition(self): ) @skipIfRocm - @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") + @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") def test_simple_mlp_fullgraph_backend_inductor(self): self._test_traceable_fsdp( *self._create_simple_mlp_factory_fns(), "inductor", fullgraph=True @@ -556,7 +556,7 @@ def input_creation_fn(): return model_init_fn, input_creation_fn @skipIfRocm - @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") + @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") def test_nested_fully_shard_backend_aot_eager(self): for fullgraph in [True, False]: self._test_traceable_fsdp( @@ -566,7 +566,7 @@ def test_nested_fully_shard_backend_aot_eager(self): ) @skipIfRocm - @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") + @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") def test_nested_fully_shard_backend_aot_eager_decomp_partition(self): for fullgraph in [True, False]: self._test_traceable_fsdp( @@ -576,7 +576,7 @@ def test_nested_fully_shard_backend_aot_eager_decomp_partition(self): ) @skipIfRocm - @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") + @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") def test_nested_fully_shard_backend_inductor(self): for fullgraph in [True, False]: with self._reinplace_all_gather_with_optional_checks( @@ -736,7 +736,7 @@ def _sdpa_with_graph_break(orig_fn, fullgraph, *args, **kwargs): ) @skipIfRocm - @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") + @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") def test_transformer_backend_aot_eager(self): for fullgraph, all_requires_grad in itertools.product( [True, False], [True, False] @@ -753,7 +753,7 @@ def test_transformer_backend_aot_eager(self): ) @skipIfRocm - @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") + @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") # TODO: native_dropout has worse accuracy after decomp, need to figure out why @torch._inductor.config.patch(fallback_random=True) def test_transformer_backend_aot_eager_decomp_partition(self): @@ -770,7 +770,7 @@ def test_transformer_backend_aot_eager_decomp_partition(self): ) @skipIfRocm - @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") + @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") # TODO: native_dropout causes CUDA IMA error, need to figure out why @torch._inductor.config.patch(fallback_random=True) def test_transformer_backend_inductor(self): diff --git a/test/distributed/_composable/fully_shard/test_fully_shard_compile.py b/test/distributed/_composable/fully_shard/test_fully_shard_compile.py index 05af94ff11e06f..1ad5b3be8462c2 100644 --- a/test/distributed/_composable/fully_shard/test_fully_shard_compile.py +++ b/test/distributed/_composable/fully_shard/test_fully_shard_compile.py @@ -18,7 +18,7 @@ TransformerWithSharedParams, ) from torch.testing._internal.common_utils import run_tests, TEST_WITH_DEV_DBG_ASAN -from torch.utils._triton import has_triton +from torch.testing._internal.inductor_utils import HAS_GPU if not dist.is_available(): @@ -38,7 +38,7 @@ class TestCompile(FSDPTest): def world_size(self) -> int: return torch.cuda.device_count() - @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") + @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") @skip_if_lt_x_gpu(2) def test_compile(self): self.run_subtests( diff --git a/test/distributed/_composable/test_replicate_with_compiler.py b/test/distributed/_composable/test_replicate_with_compiler.py index 14d9cc47271f59..d3a519e8e6f93f 100644 --- a/test/distributed/_composable/test_replicate_with_compiler.py +++ b/test/distributed/_composable/test_replicate_with_compiler.py @@ -33,7 +33,7 @@ ) from torch.testing._internal.common_utils import run_tests from torch.testing._internal.distributed.fake_pg import FakeStore -from torch.utils._triton import has_triton +from torch.testing._internal.inductor_utils import HAS_GPU from torch.utils.checkpoint import checkpoint @@ -216,21 +216,21 @@ def test_compile_cpu_no_sync(self): ] self._test_compile(use_gpu=False, no_sync=True) - @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") + @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") @skip_if_rocm @skip_if_lt_x_gpu(2) @torch._inductor.config.patch(reorder_for_locality=False) def test_compile_gpu(self): self._test_compile(use_gpu=True, no_sync=False, checkpoint=False) - @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") + @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") @skip_if_rocm @skip_if_lt_x_gpu(2) @torch._inductor.config.patch(reorder_for_locality=False) def test_compile_gpu_ac(self): self._test_compile(use_gpu=True, no_sync=False, checkpoint=True) - @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") + @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") @skip_if_rocm @skip_if_lt_x_gpu(2) def test_compile_bf16(self): @@ -244,7 +244,7 @@ def setup(model, compiled_replicate_model, compiled_ddp_model) -> None: self._test_compile(use_gpu=True, no_sync=False, setup_func=setup) - @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") + @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") @skip_if_rocm @skip_if_lt_x_gpu(2) def test_compile_fp16(self): @@ -261,7 +261,7 @@ def setup(model, compiled_replicate_model, compiled_ddp_model) -> None: use_gpu=True, no_sync=False, setup_func=setup, no_inductor=True ) - @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") + @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") @skip_if_rocm @skip_if_lt_x_gpu(2) def test_compile_backward_only(self): @@ -385,7 +385,7 @@ def setUp(self): def tearDown(self): dist.destroy_process_group() - @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") + @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") @skip_if_rocm def test_ddp_tp(self): ref_model = Net() diff --git a/test/distributed/_tensor/test_dtensor_compile.py b/test/distributed/_tensor/test_dtensor_compile.py index eaef0716ba30dc..60a12d1935f103 100644 --- a/test/distributed/_tensor/test_dtensor_compile.py +++ b/test/distributed/_tensor/test_dtensor_compile.py @@ -46,7 +46,7 @@ with_comms, ) from torch.testing._internal.distributed.fake_pg import FakeStore -from torch.utils._triton import has_triton +from torch.testing._internal.inductor_utils import HAS_GPU from torch.utils.checkpoint import checkpoint @@ -420,7 +420,7 @@ def fn(x): tmp_dt._local_tensor.stride(), tmp_dt_fake._local_tensor.stride() ) - @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") + @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") def test_dtensor_contiguous_dtensor_noncontiguous_local_as_tangent(self): # Partial -> Shard on an unbalanced tensor results in: # - A contiguous DTensor @@ -496,7 +496,7 @@ def fw_hook(module, inp, out): out_test = opt_mod(dt) self.assertEqual(out_ref, out_test) - @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") + @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") def test_dtensor_different_gradient_placement(self): mesh = DeviceMesh(self.device_type, torch.arange(self.world_size)) @@ -628,7 +628,7 @@ def forward(self, primals_1): return (sin_1, primals_1, wait_tensor)""", ) - @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") + @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") def test_dtensor_partial_placement_graph_output(self): mesh = DeviceMesh(self.device_type, torch.arange(self.world_size)) @@ -646,7 +646,7 @@ def fn(x): out_dt = torch.matmul(tmp_dt, y_dt) out_dt.sum().backward() - @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") + @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") @skip_if_lt_x_gpu(1) # TODO: somehow inductor bg compile threads are causing hangs at exit with distributed work dtor @patch.object(torch._inductor.config, "compile_threads", 1) diff --git a/test/distributed/fsdp/test_fsdp_use_orig_params.py b/test/distributed/fsdp/test_fsdp_use_orig_params.py index 33fae8dea8a561..e19a522e157002 100644 --- a/test/distributed/fsdp/test_fsdp_use_orig_params.py +++ b/test/distributed/fsdp/test_fsdp_use_orig_params.py @@ -43,7 +43,7 @@ TEST_WITH_DEV_DBG_ASAN, TestCase, ) -from torch.utils._triton import has_triton +from torch.testing._internal.inductor_utils import HAS_GPU if not dist.is_available(): @@ -218,7 +218,7 @@ def _get_sharding_strategy_from_str( raise ValueError(f"Invalid string: {sharding_strategy_str}") return sharding_strategy - @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") + @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") @skip_if_lt_x_gpu(2) def test_fsdp_compile(self): self.run_subtests( diff --git a/test/distributed/tensor/parallel/test_micro_pipeline_tp.py b/test/distributed/tensor/parallel/test_micro_pipeline_tp.py index 951e77188364c3..3a4b2eb946fc51 100644 --- a/test/distributed/tensor/parallel/test_micro_pipeline_tp.py +++ b/test/distributed/tensor/parallel/test_micro_pipeline_tp.py @@ -37,7 +37,7 @@ ) from torch.testing._internal.distributed._tensor.common_dtensor import MLPModule from torch.testing._internal.distributed.fake_pg import FakeStore -from torch.utils._triton import has_triton +from torch.testing._internal.inductor_utils import HAS_GPU def _make_post_grad_fx(f, *inps): @@ -78,7 +78,7 @@ def setUp(self): def tearDown(self): dist.destroy_process_group() - @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") + @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") @fresh_inductor_cache() def test_find_all_gather_patterns(self): group = dist.group.WORLD @@ -129,7 +129,7 @@ def func(inp: torch.Tensor) -> torch.Tensor: torch.ops.aten.view.dtype, ) - @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") + @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") @fresh_inductor_cache() def test_find_reduce_scatter_patterns(self): group = dist.group.WORLD @@ -168,7 +168,7 @@ def func(inp: torch.Tensor) -> torch.Tensor: self.assertEqual(reduce_scatters[1].reduce_op, "avg") self.assertEqual(reduce_scatters[1].scatter_dim, 1) - @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") + @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") @fresh_inductor_cache() def test_get_unexposed_collectives(self): group = dist.group.WORLD @@ -193,7 +193,7 @@ def func(inp: torch.Tensor) -> torch.Tensor: ["all_gather_into_tensor", "reduce_scatter_tensor"], ) - @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") + @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") @parametrize("A_dims", [2, 3]) @parametrize("gather_dim", [0, 1, 2]) @fresh_inductor_cache() @@ -231,7 +231,7 @@ def func(A_shard: torch.Tensor, B: torch.Tensor) -> torch.Tensor: self.assertNotIn("all_gather_into_tensor", code) @runOnRocmArch(MI300_ARCH) - @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") + @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") @parametrize("A_dims", [2, 3]) @parametrize("gather_dim", [0, 1, 2]) @fresh_inductor_cache() @@ -299,7 +299,7 @@ def func( self.assertIn("fused_all_gather_scaled_matmul", code) self.assertNotIn("all_gather_into_tensor", code) - @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") + @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") @parametrize("A_dims", [2, 3]) @parametrize("scatter_dim", [0, 1, 2]) @fresh_inductor_cache() @@ -328,7 +328,7 @@ def func(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor: self.assertNotIn("reduce_scatter_tensor", code) @runOnRocmArch(MI300_ARCH) - @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") + @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") @parametrize("A_dims", [2, 3]) @parametrize("scatter_dim", [0, 1, 2]) @fresh_inductor_cache() @@ -381,7 +381,7 @@ def func( self.assertIn("fused_scaled_matmul_reduce_scatter", code) self.assertNotIn("reduce_scatter_tensor", code) - @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") + @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") @parametrize("shard_dim", [0, 1]) @fresh_inductor_cache() def test_dtensor_seq_par(self, shard_dim: int): diff --git a/test/distributed/test_c10d_functional_native.py b/test/distributed/test_c10d_functional_native.py index 619c4b7887f7f1..0bea36967b880e 100644 --- a/test/distributed/test_c10d_functional_native.py +++ b/test/distributed/test_c10d_functional_native.py @@ -28,7 +28,7 @@ TestCase, ) from torch.testing._internal.distributed.fake_pg import FakeStore -from torch.utils._triton import has_triton +from torch.testing._internal.inductor_utils import HAS_GPU def load_test_module(name): @@ -218,7 +218,7 @@ def test_all_gather_into_tensor_single(self) -> None: assert output.eq(expect).all() assert output.completed - @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") + @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") @skip_if_lt_x_gpu(2) # https://github.com/pytorch/pytorch/issues/126338 def test_inductor_dtypeview_memory_leak(self): @@ -434,7 +434,7 @@ def wait(self, _): torch.ops._c10d_functional.wait_tensor(tensor) self.assertTrue(wait_called) - @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") + @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") @skip_if_lt_x_gpu(2) @fresh_inductor_cache() def test_threading(self): @@ -498,7 +498,7 @@ def setUp(self): def tearDown(self): dist.destroy_process_group() - @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") + @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") @fresh_inductor_cache() def test_inductor_all_reduce_single(self): def func(arg: torch.Tensor) -> torch.Tensor: @@ -535,7 +535,7 @@ def func(arg: torch.Tensor) -> torch.Tensor: out = AOTIRunnerUtil.run("cuda", func, (arg,)) torch.cuda.synchronize() - @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") + @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") @fresh_inductor_cache() def test_inductor_all_reduce_coalesced(self): def func(args: List[torch.Tensor]) -> torch.Tensor: @@ -581,7 +581,7 @@ def func(args: List[torch.Tensor]) -> torch.Tensor: out = AOTIRunnerUtil.run("cuda", func, (args,)) torch.cuda.synchronize() - @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") + @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") @fresh_inductor_cache() def test_inductor_inplace_op_on_view(self): def func(arg: torch.Tensor) -> torch.Tensor: @@ -608,7 +608,7 @@ def func(arg: torch.Tensor) -> torch.Tensor: .run(code) ) - @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") + @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") @fresh_inductor_cache() def test_inductor_reuse_buffer_after_inplace_collective(self): def func(arg: torch.Tensor) -> torch.Tensor: @@ -643,7 +643,7 @@ def func(arg: torch.Tensor) -> torch.Tensor: ) assert "= torch.ops._c10d_functional.wait_tensor.default" not in code - @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") + @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") @fresh_inductor_cache() def test_inductor_all_gather_into_tensor_single(self): def func(arg: torch.Tensor) -> torch.Tensor: @@ -670,7 +670,7 @@ def func(arg: torch.Tensor) -> torch.Tensor: out = AOTIRunnerUtil.run("cuda", func, (arg,)) torch.cuda.synchronize() - @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") + @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") @fresh_inductor_cache() def test_inductor_all_gather_into_tensor_coalesced(self): def func(args: List[torch.Tensor]) -> torch.Tensor: @@ -704,7 +704,7 @@ def func(args: List[torch.Tensor]) -> torch.Tensor: out = AOTIRunnerUtil.run("cuda", func, (args,)) torch.cuda.synchronize() - @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") + @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") @fresh_inductor_cache() def test_inductor_reduce_scatter_tensor_single(self): def func(arg: torch.Tensor) -> torch.Tensor: @@ -730,7 +730,7 @@ def func(arg: torch.Tensor) -> torch.Tensor: out = AOTIRunnerUtil.run("cuda", func, (arg,)) torch.cuda.synchronize() - @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") + @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") @fresh_inductor_cache() def test_inductor_reduce_scatter_tensor_coalesced(self): def func(args: List[torch.Tensor]) -> torch.Tensor: @@ -766,7 +766,7 @@ def func(args: List[torch.Tensor]) -> torch.Tensor: AOTIRunnerUtil.run("cuda", func, (args,)) torch.cuda.synchronize() - @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") + @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") @fresh_inductor_cache() def test_inductor_all_to_all_single(self): def _tolist_with_constrain_as_size(tensor): @@ -814,7 +814,7 @@ def func( .run(code) ) - @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") + @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") @fresh_inductor_cache() def test_inductor_broadcast(self): def func(arg: torch.Tensor) -> torch.Tensor: @@ -850,7 +850,7 @@ def func(arg: torch.Tensor) -> torch.Tensor: out = AOTIRunnerUtil.run("cuda", func, (arg,)) torch.cuda.synchronize() - @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") + @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") @fresh_inductor_cache() def test_ranks_and_tag(self): def func(arg: torch.Tensor) -> torch.Tensor: diff --git a/test/distributed/test_compute_comm_reordering.py b/test/distributed/test_compute_comm_reordering.py index db0d6c8becbff0..bb0e54484b94e5 100644 --- a/test/distributed/test_compute_comm_reordering.py +++ b/test/distributed/test_compute_comm_reordering.py @@ -28,7 +28,7 @@ DynamoDistributedMultiProcTestCase, requires_nccl, ) -from torch.utils._triton import has_triton +from torch.testing._internal.inductor_utils import HAS_GPU def get_snode_runtime_for_reorder_compute_test(snode): @@ -92,7 +92,7 @@ def world_size(self) -> int: # works around issue with skipif<2 and workers with unpredictable #s gpu return 2 - @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") + @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") @patch.object(torch._inductor.config, "allow_buffer_reuse", True) # TODO: somehow inductor bg compile threads are causing hangs at exit with distributed work dtor @patch.object(torch._inductor.config, "compile_threads", 1) @@ -131,7 +131,7 @@ def func(a): correct = func(inputs) self.assertTrue(same(out, correct)) - @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") + @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") @patch.object(torch._inductor.config, "allow_buffer_reuse", True) # TODO: somehow inductor bg compile threads are causing hangs at exit with distributed work dtor @patch.object(torch._inductor.config, "compile_threads", 1) @@ -178,7 +178,7 @@ def func(a): correct = func(inputs) self.assertTrue(same(out, correct)) - @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") + @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") @patch.object(torch._inductor.config, "allow_buffer_reuse", True) # TODO: somehow inductor bg compile threads are causing hangs at exit with distributed work dtor @patch.object(torch._inductor.config, "compile_threads", 1) @@ -231,7 +231,7 @@ def func(a, *, tag, ranks, group_size): correct = func(inputs, **self.get_world_trs()) self.assertTrue(same(out, correct)) - @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") + @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") @patch.object(torch._inductor.config, "allow_buffer_reuse", True) # TODO: somehow inductor bg compile threads are causing hangs at exit with distributed work dtor @patch.object(torch._inductor.config, "compile_threads", 1) @@ -283,7 +283,7 @@ def func(a, *, tag, ranks, group_size): correct = func(inputs, **self.get_world_trs()) self.assertTrue(same(out, correct)) - @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") + @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") @patch.object(torch._inductor.config, "allow_buffer_reuse", True) # TODO: somehow inductor bg compile threads are causing hangs at exit with distributed work dtor @patch.object(torch._inductor.config, "compile_threads", 1) @@ -340,7 +340,7 @@ def func(a, *, tag, ranks, group_size): correct = func(inputs, **self.get_world_trs()) self.assertTrue(same(out, correct)) - @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") + @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") # TODO: somehow inductor bg compile threads are causing hangs at exit with distributed work dtor @patch.object(torch._inductor.config, "compile_threads", 1) @patch.object( diff --git a/test/distributed/test_dynamo_distributed.py b/test/distributed/test_dynamo_distributed.py index d6357678c94f97..1cf98a4d8b4047 100644 --- a/test/distributed/test_dynamo_distributed.py +++ b/test/distributed/test_dynamo_distributed.py @@ -46,7 +46,7 @@ skip_if_lt_x_gpu, ) from torch.testing._internal.common_utils import requires_cuda -from torch.utils._triton import has_triton +from torch.testing._internal.inductor_utils import HAS_GPU def reset_rng_state(): @@ -325,7 +325,7 @@ def run_hf_bert_ddp(self, model, inputs, backend): class TestFakeDistributedSingleProc(torch._dynamo.test_case.TestCase): - @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") + @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") @patch.object(config, "optimize_ddp", True) @patch.object(torch._inductor.config, "fallback_random", True) def test_hf_bert_ddp_inductor(self): @@ -528,7 +528,7 @@ def _test_hf_bert_ddp_inductor(self, static_graph): @skip_if_lt_x_gpu(2) @import_transformers_or_skip() - @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") + @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") @config.patch(optimize_ddp=True, enable_compiler_collectives=True) @patch.object(torch._inductor.config, "fallback_random", True) def test_hf_bert_ddp_inductor(self): @@ -536,7 +536,7 @@ def test_hf_bert_ddp_inductor(self): @skip_if_lt_x_gpu(2) @import_transformers_or_skip() - @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") + @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") @config.patch(optimize_ddp=True, enable_compiler_collectives=True) @patch.object(torch._inductor.config, "fallback_random", True) def test_hf_bert_ddp_inductor_static_graph(self): @@ -561,7 +561,7 @@ def test_hf_bert_ddp_aot_eager_static_graph(self): self._test_hf_bert_aot_eager(static_graph=True) @skip_if_lt_x_gpu(2) - @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") + @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") @config.patch(optimize_ddp=False, enable_compiler_collectives=True) def test_ddp_activation_checkpointing(self): from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import ( @@ -676,7 +676,7 @@ def test_fsdp_unspecialized_forced_getattr_inline(self): @config.patch(enable_compiler_collectives=True) @skip_if_lt_x_gpu(1) - @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") + @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") def test_fsdp_inductor(self): with _dynamo_dist_per_rank_init(self.rank, self.world_size): # Test with basic FSDP wrapping (outer wrap around whole model) @@ -701,7 +701,7 @@ def test_fsdp_inductor(self): @config.patch(enable_compiler_collectives=True) @skip_if_lt_x_gpu(1) - @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") + @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") def test_fsdp_activation_checkpointing(self): with _dynamo_dist_per_rank_init(self.rank, self.world_size): model, inputs = get_toy_model_for_activation_checkpointing( @@ -722,7 +722,7 @@ def test_fsdp_activation_checkpointing(self): ) @import_transformers_or_skip() - @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") + @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") # TODO(whc) Investigate why cudagraphs breaks inductor+fsdp for hf_bert @patch.object(torch._inductor.config.triton, "cudagraphs", False) @patch.object(torch._inductor.config, "fallback_random", True) @@ -767,7 +767,7 @@ def apply_fsdp(model, wrap_policy): self.assertTrue(same(correct_results, opt_results)) @import_transformers_or_skip() - @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") + @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") # TODO(whc) Investigate why cudagraphs breaks inductor+fsdp for hf_bert @patch.object(torch._inductor.config.triton, "cudagraphs", False) @patch.object(torch._inductor.config, "fallback_random", True) @@ -815,7 +815,7 @@ def test_hf_bert_fsdp_activation_checkpointing(self): ) self.assertTrue(same(correct_results, opt_results)) - @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") + @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") @config.patch(enable_compiler_collectives=True) def test_compiler_collectives_automatic_dynamic_tensor(self): with _dynamo_dist_per_rank_init(self.rank, self.world_size): @@ -860,7 +860,7 @@ def B(s): for r in res[1:]: self.assertEqual(res[0], r) - @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") + @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") @config.patch(enable_compiler_collectives=True) def test_compiler_collectives_automatic_dynamic_scalar(self): with _dynamo_dist_per_rank_init(self.rank, self.world_size): @@ -888,7 +888,7 @@ def f(x, y): for r in res[1:]: self.assertEqual(res[0], r) - @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") + @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") @config.patch(enable_compiler_collectives=True) def test_compiler_collectives_automatic_dynamic_speculation_divergence(self): with _dynamo_dist_per_rank_init(self.rank, self.world_size): @@ -921,7 +921,7 @@ def f(x, y): for r in res[1:]: self.assertEqual(res[0], r) - @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") + @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") @config.patch(enable_compiler_collectives=True) def test_compiler_collectives_graph_break_empty_graph_still_collective(self): with _dynamo_dist_per_rank_init(self.rank, self.world_size): @@ -955,7 +955,7 @@ def f(x, y): for r in res[1:]: self.assertEqual(res[0], r) - @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") + @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") @config.patch(enable_compiler_collectives=True) def test_compiler_collectives_dim_mismatch(self): with _dynamo_dist_per_rank_init(self.rank, self.world_size): @@ -984,7 +984,7 @@ def f(x, y): for r in res[1:]: self.assertEqual(res[0], r) - @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") + @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") @config.patch(enable_compiler_collectives=True) def test_compiler_collectives_missing_source(self): with _dynamo_dist_per_rank_init(self.rank, self.world_size): @@ -1006,7 +1006,7 @@ def f(rank, xs): for r in res[1:]: self.assertEqual(res[0], r) - @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") + @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") @config.patch(enable_compiler_collectives=True) def test_compiler_collectives_scalar_missing_source(self): with _dynamo_dist_per_rank_init(self.rank, self.world_size): @@ -1028,7 +1028,7 @@ def f(rank, xs): for r in res[1:]: self.assertEqual(res[0], r) - @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") + @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") @config.patch(enable_compiler_collectives=True) def test_compiler_collectives_type_mismatch(self): with _dynamo_dist_per_rank_init(self.rank, self.world_size): @@ -1062,7 +1062,7 @@ def f(x): for r in res[1:]: self.assertEqual(res[0], r) - @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") + @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") @patch.object(torch._inductor.config, "fx_graph_cache", False) @patch.object(torch._inductor.config, "fx_graph_remote_cache", False) def test_asymmetric_compilation(self): @@ -1113,7 +1113,7 @@ def f(x): for r in res[1:]: self.assertEqual(res[0], r) - @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") + @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") @patch.object(torch._inductor.config, "fx_graph_cache", True) @patch.object(torch._inductor.config, "fx_graph_remote_cache", False) @patch.object(torch._inductor.config, "sleep_sec_TESTING_ONLY", 10) @@ -1203,7 +1203,7 @@ def test_ddp_baseline_aot_eager(self): outputs = ddp_m(inputs) self.assertTrue(same(correct_outputs, outputs)) - @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") + @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") @patch.object(config, "optimize_ddp", False) def test_ddp_baseline_inductor(self): from torch.nn.parallel import DistributedDataParallel as DDP @@ -1299,7 +1299,7 @@ def opt_fn(inputs): self.assertTrue(all("DDPOptimizer" in r.reason for r in break_reasons)) @patch.object(config, "optimize_ddp", True) - @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") + @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") def test_graph_split_inductor(self): assert config.optimize_ddp """ @@ -1368,18 +1368,18 @@ def opt_fn(inputs): opt_outputs = opt_fn(inputs) self.assertTrue(same(correct_outputs, opt_outputs)) - @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") + @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") def test_graph_split_inductor_layout_optimizations_training(self): self._test_graph_split_inductor_layout_optimizations_impl( contextlib.nullcontext ) - @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") + @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") def test_graph_split_inductor_layout_optimizations_inference(self): self._test_graph_split_inductor_layout_optimizations_impl(torch.no_grad) @patch.object(config, "optimize_ddp", True) - @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") + @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") def test_graph_split_inductor_transpose(self): assert config.optimize_ddp @@ -1470,7 +1470,7 @@ def opt_fn(inputs): self.assertTrue(same(correct_outputs, opt_outputs)) self.assertEqual(check_splits_compiler.compiler_called, 3) - @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") + @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") def test_empty_graph_inductor(self): def fn(): get_world_size = torch.distributed.distributed_c10d.get_world_size() diff --git a/test/distributed/test_functional_api.py b/test/distributed/test_functional_api.py index c4972a46405862..2624ce87089f37 100644 --- a/test/distributed/test_functional_api.py +++ b/test/distributed/test_functional_api.py @@ -14,7 +14,7 @@ from torch._inductor.utils import run_and_get_code from torch.testing import FileCheck from torch.testing._internal.distributed.fake_pg import FakeStore -from torch.utils._triton import has_triton +from torch.testing._internal.inductor_utils import HAS_GPU if not dist.is_available(): @@ -564,7 +564,7 @@ def test_all_to_all_single_split_sizes_none(self): expected = torch.cat(expected) self.assertEqual(y, expected) - @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") + @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") @requires_nccl() @with_comms() def test_tracing(self): @@ -574,7 +574,7 @@ def allreduce(t, pg): compiled_allreduce = torch.compile(allreduce, fullgraph=True) compiled_allreduce(torch.randn(8, device=self.device), self.process_group) - @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") + @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") def test_tracing_with_fakepg(self): exit_if_lt_x_gpu(self.world_size) @@ -590,7 +590,7 @@ def allreduce(t, pg): ) allreduce(torch.randn(8, device=self.device), pg=dist.group.WORLD) - @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") + @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") @requires_nccl() @with_comms() def test_tracing_with_dce_code(self): diff --git a/test/distributed/test_inductor_collectives.py b/test/distributed/test_inductor_collectives.py index a19183425d3907..216643e59cdec5 100644 --- a/test/distributed/test_inductor_collectives.py +++ b/test/distributed/test_inductor_collectives.py @@ -29,7 +29,7 @@ parametrize, requires_cuda, ) -from torch.utils._triton import has_triton +from torch.testing._internal.inductor_utils import HAS_GPU def _tolist_with_constrain_as_size(tensor): @@ -58,7 +58,7 @@ def world_size(self) -> int: # works around issue with skipif<2 and workers with unpredictable #s gpu return 2 - @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") + @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") @skip_if_lt_x_gpu(2) def test_broadcast_inductor(self): """ @@ -90,7 +90,7 @@ def compile(func, example_inputs): compiled_out = compiled_func(*inputs) self.assertTrue(same(eager_out, compiled_out)) - @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") + @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") @skip_if_lt_x_gpu(2) def test_allreduce_inductor(self): """ @@ -123,7 +123,7 @@ def compile(func, example_inputs): inductor_out = compiled_matmul_cat_col(*inputs) self.assertTrue(same(eager_out, inductor_out, tol=0.001)) - @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") + @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") @skip_if_lt_x_gpu(2) def test_allreduce_inductor_cudagraph_trees(self): """ @@ -169,7 +169,7 @@ def test_c10d_functional_tagged_pt2_compliant(self): op = torch.ops.c10d_functional.all_reduce.default self.assertIn(torch.Tag.pt2_compliant_tag, op.tags) - @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") + @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") @skip_if_lt_x_gpu(2) def test_eager_allreduce_inductor_wait(self): def eager_func(a, b, c, d, *, tag, ranks, group_size): @@ -208,7 +208,7 @@ def compile(func, example_inputs): print(f"inductor_out, {inductor_out}") self.assertTrue(same(eager_out, inductor_out, tol=0.001)) - @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") + @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") @skip_if_lt_x_gpu(2) def test_inductor_allreduce_eager_wait(self): def inductor_func(a, b, c, d, *, tag, ranks, group_size): @@ -243,7 +243,7 @@ def compile(func, example_inputs): ) self.assertTrue(same(eager_out, inductor_out, tol=0.001)) - @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") + @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") @skip_if_lt_x_gpu(2) @patch.object(torch._inductor.config, "allow_buffer_reuse", True) def test_allreduce_input_buffer_reuse(self): @@ -261,7 +261,7 @@ def func(a, *, tag, ranks, group_size): correct = func(inputs, **self.get_world_trs()) self.assertTrue(same(out, correct)) - @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") + @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") @skip_if_lt_x_gpu(2) def test_permute_tensor(self): def func(tensor, src_dst_pairs, *, tag, ranks, group_size): @@ -287,7 +287,7 @@ def func(tensor, src_dst_pairs, *, tag, ranks, group_size): self.assertEqual(out, expected) self.assertEqual(correct, expected) - @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") + @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") @skip_if_lt_x_gpu(2) @patch.object(torch._inductor.config, "allow_buffer_reuse", True) def test_allgather_output_buffer_reuse(self): @@ -311,7 +311,7 @@ def forward(self, x, world_size, tag, ranks, group_size): correct = model(inp, self.world_size, **self.get_world_trs()) self.assertTrue(same(out, correct)) - @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") + @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") @skip_if_lt_x_gpu(2) def test_allgather_contiguous_input(self): class Model(torch.nn.Module): @@ -335,7 +335,7 @@ def forward(self, x, world_size, tag, ranks, group_size): correct = model(inp, self.world_size, **self.get_world_trs()) self.assertTrue(same(out, correct)) - @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") + @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") @skip_if_lt_x_gpu(2) def test_allgather_into_tensor_inductor(self): """ @@ -366,7 +366,7 @@ def compile(func, example_inputs): inductor_out = compiled_matmul_cat_col(*inputs) self.assertTrue(same(eager_out, inductor_out, tol=0.001)) - @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") + @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") @skip_if_lt_x_gpu(2) def test_reduce_scatter_tensor_inductor(self): def example(a, b, *, tag, ranks, group_size): @@ -393,7 +393,7 @@ def compile(func, example_inputs): inductor_out = compiled_fn(*inputs) self.assertTrue(same(eager_out, inductor_out, tol=0.001)) - @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") + @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") @skip_if_lt_x_gpu(2) @patch.object(torch._dynamo.config, "capture_scalar_outputs", True) def test_all_to_all_single_inductor(self): @@ -462,7 +462,7 @@ def example( inductor_out = compiled_fn(*inputs, **trs) self.assertTrue(same(eager_out, inductor_out, tol=0.001)) - @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") + @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") @skip_if_lt_x_gpu(2) def test_all_to_all_single_inductor_split_sizes_none(self): def example(inp, *, tag, ranks, group_size): @@ -518,7 +518,7 @@ def get_world_trs(self, world_size=1): "group_size": world_size, } - @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") + @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") @torch._inductor.config.patch(debug=True) def test_inductor_single_op(self): def func(inp, *, tag, ranks, group_size): @@ -547,7 +547,7 @@ def func(inp, *, tag, ranks, group_size): correct = func(inputs, **self.get_world_trs()) self.assertTrue(same(out, correct)) - @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") + @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") @torch._inductor.config.patch(debug=True) def test_inductor_steal_buffer(self): """ @@ -582,7 +582,7 @@ def func(inp, *, tag, ranks, group_size): correct = func(inputs, **self.get_world_trs()) self.assertTrue(same(out, correct)) - @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") + @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") @torch._inductor.config.patch({"debug": True, "triton.descriptive_names": False}) def test_inductor_doesnt_mutate_shared(self): """ @@ -1030,7 +1030,7 @@ def test_meta(self): out = torch.ops.c10d_functional.all_reduce(x, "sum", **self.get_world_trs()) self.assertEqual(x.size(), out.size()) - @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") + @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") @torch._inductor.config.patch({"debug": True, "triton.descriptive_names": False}) def test_inductor_all_gather_coalesced(self): """ @@ -1076,7 +1076,7 @@ def func(inp, *, tag, ranks, group_size): correct = func(inputs, **self.get_world_trs()) assert same(out, correct), f"{out} va {correct}" - @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") + @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") @torch._inductor.config.patch({"debug": True, "triton.descriptive_names": False}) def test_inductor_reduce_scatter_coalesced(self): """ diff --git a/test/inductor/test_cuda_repro.py b/test/inductor/test_cuda_repro.py index 5172801572b3fc..b493ce92539a2f 100644 --- a/test/inductor/test_cuda_repro.py +++ b/test/inductor/test_cuda_repro.py @@ -1319,7 +1319,9 @@ def fn(x, y): self.assertEqual(expect, actual) # Expect the code iterates in contiguous order, and is not tiled - kernel_code = "\n".join(code[0].split("\n")[60:74]) + lines = code[0].split("\n") + start = lines.index("@triton.jit") + kernel_code = "\n".join(lines[start : start + 14]) self.assertExpectedInline( kernel_code, """\ diff --git a/test/inductor/test_memory_planning.py b/test/inductor/test_memory_planning.py index d3e07670492123..fd8bfd5c82bcc1 100644 --- a/test/inductor/test_memory_planning.py +++ b/test/inductor/test_memory_planning.py @@ -22,10 +22,9 @@ from torch._inductor.test_case import run_tests, TestCase from torch._inductor.utils import run_and_get_cpp_code from torch.export import Dim -from torch.utils._triton import has_triton -@unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") +@unittest.skipIf(not HAS_CUDA, "Inductor+gpu needs triton and CUDA") @config.patch(memory_planning=True) class TestMemoryPlanning(TestCase): def _generate(self, *, device): diff --git a/test/inductor/test_triton_cpu_backend.py b/test/inductor/test_triton_cpu_backend.py new file mode 100644 index 00000000000000..8be4d85caea6a6 --- /dev/null +++ b/test/inductor/test_triton_cpu_backend.py @@ -0,0 +1,34 @@ +# Owner(s): ["module: inductor"] +from torch._inductor import config +from torch._inductor.test_case import run_tests +from torch.testing._internal.inductor_utils import HAS_CPU +from torch.utils._triton import has_triton + + +try: + from . import test_torchinductor +except ImportError: + import test_torchinductor + +if has_triton(): + import triton + + TRITON_HAS_CPU = "cpu" in triton.backends.backends +else: + TRITON_HAS_CPU = False + + +if HAS_CPU and TRITON_HAS_CPU: + + @config.patch(cpu_backend="triton") + class SweepInputsCpuTritonTest(test_torchinductor.SweepInputsCpuTest): + pass + + @config.patch(cpu_backend="triton") + class CpuTritonTests(test_torchinductor.CpuTests): + pass + + +if __name__ == "__main__": + if HAS_CPU and TRITON_HAS_CPU: + run_tests(needs="filelock") diff --git a/test/test_sparse_semi_structured.py b/test/test_sparse_semi_structured.py index 0f7c96065bcdd6..6bac38723d3013 100644 --- a/test/test_sparse_semi_structured.py +++ b/test/test_sparse_semi_structured.py @@ -38,9 +38,9 @@ IS_WINDOWS, ) -import pytest +from torch.testing._internal.inductor_utils import HAS_GPU -from torch.utils._triton import has_triton +import pytest SEMI_STRUCTURED_SUPPORTED_BACKENDS = dict() @@ -981,7 +981,7 @@ def run_test(m, n, k, device, dtype, dtype_out, use_input, rtol, atol): torch.backends.cuda.matmul.allow_tf32 = orig - @unittest.skipIf(not has_triton(), "Test needs triton and recent GPU arch") + @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") @inference_dtypes def test_conversions(self, device, dtype): @@ -1009,7 +1009,7 @@ def run_test(r, c, device, dtype): for r, c in shapes: run_test(r, c, device, dtype) - @unittest.skipIf(not has_triton(), "Test needs triton and recent GPU arch") + @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") @inference_dtypes def test_conversions_all_patterns(self, device, dtype): r, c = 32, 128 diff --git a/torch/_dynamo/device_interface.py b/torch/_dynamo/device_interface.py index 00975cda550818..ad58a11bbd5642 100644 --- a/torch/_dynamo/device_interface.py +++ b/torch/_dynamo/device_interface.py @@ -1,5 +1,7 @@ # mypy: allow-untyped-defs import inspect +import time +from dataclasses import dataclass from typing import Any, Callable, Dict, Iterable, Optional, Tuple, Type, Union import torch @@ -295,6 +297,51 @@ def is_bf16_supported(including_emulation: bool = False) -> bool: return torch.xpu.is_bf16_supported() +@dataclass +class CpuDeviceProperties: + multi_processor_count: int + + +class CpuInterface(DeviceInterface): + class Event(_EventBase): + def __init__(self, enable_timing=True): + self.time = 0.0 + + def elapsed_time(self, end_event) -> float: + return (end_event.time - self.time) * 1000 + + def record(self): + self.time = time.perf_counter() + + @staticmethod + def is_available() -> bool: + return True + + @staticmethod + def get_compute_capability(device: _device_t = None) -> str: + return "" + + @staticmethod + def get_raw_stream(device_idx) -> int: + return 0 + + @staticmethod + def current_device(): + return 0 + + @staticmethod + def synchronize(device: _device_t = None): + pass + + class Worker: + @staticmethod + def get_device_properties(device: _device_t = None): + import multiprocessing + + cpu_count = multiprocessing.cpu_count() + return CpuDeviceProperties(cpu_count) + + device_interfaces: Dict[str, Type[DeviceInterface]] = {} _device_initialized = False @@ -333,4 +380,6 @@ def init_device_reg(): for i in range(torch.xpu.device_count()): register_interface_for_device(f"xpu:{i}", XpuInterface) + register_interface_for_device("cpu", CpuInterface) + _device_initialized = True diff --git a/torch/_inductor/autotune_process.py b/torch/_inductor/autotune_process.py index 94efbaf4e32fdd..8dcda8fa3d7dca 100644 --- a/torch/_inductor/autotune_process.py +++ b/torch/_inductor/autotune_process.py @@ -574,7 +574,7 @@ def benchmark( return self.value -class GPUDeviceBenchmarkRequest(BenchmarkRequest): +class GPUDeviceBenchmarkMixin: def do_bench( self, fn, @@ -601,7 +601,17 @@ def do_bench( return out -class TritonBenchmarkRequest(GPUDeviceBenchmarkRequest): +class CPUDeviceBenchmarkMixin: + def do_bench( + self, + fn, + *input_tensors: torch.Tensor, + output_tensor: Optional[torch.Tensor] = None, + ) -> float: + return benchmarker.benchmark_cpu(fn) + + +class TritonBenchmarkRequest(BenchmarkRequest): # Important: Instances of this class have to be serializable # across process boundaries. Do not put CUDA Tensors in here! def __init__( @@ -646,28 +656,22 @@ def make_run_fn( if "warmup" in inspect.signature(run_method).parameters: warmup_arg["warmup"] = False - from torch._C import _cuda_getCurrentRawStream as get_raw_stream - - if torch.version.hip and self.matrix_instr_nonkdim != 0: - return functools.partial( - run_method, - *input_tensors, - output_tensor, - *self.extra_args, - grid=self.grid, - **warmup_arg, - stream=get_raw_stream(self.output_tensor_meta.device.index), - ) + if output_tensor.device.type == "cpu": + stream = 0 else: - return functools.partial( - run_method, - *input_tensors, - output_tensor, - *self.extra_args, - grid=self.grid, - **warmup_arg, - stream=get_raw_stream(self.output_tensor_meta.device.index), - ) + from torch._C import _cuda_getCurrentRawStream as get_raw_stream + + stream = get_raw_stream(self.output_tensor_meta.device.index) + + return functools.partial( + run_method, + *input_tensors, + output_tensor, + *self.extra_args, + grid=self.grid, + **warmup_arg, + stream=stream, + ) def precompile(self): mod = PyCodeCache.load_by_key_path(self.module_cache_key, self.module_path) @@ -677,7 +681,15 @@ def __str__(self) -> str: return f"{self.kernel_name=}, {self.module_path=}, {self.module_cache_key=}" -class CUDABenchmarkRequest(GPUDeviceBenchmarkRequest): +class TritonGPUBenchmarkRequest(GPUDeviceBenchmarkMixin, TritonBenchmarkRequest): + pass + + +class TritonCPUBenchmarkRequest(CPUDeviceBenchmarkMixin, TritonBenchmarkRequest): + pass + + +class CUDABenchmarkRequest(GPUDeviceBenchmarkMixin, BenchmarkRequest): # Important: Instances of this class have to be serializable # across process boundaries. Do not put CUDA Tensors in here! @@ -794,17 +806,7 @@ def __str__(self) -> str: return f"{self.kernel_name=}, {self.source_file=}, {self.hash_key=}" -class CPUDeviceBenchmarkRequest(BenchmarkRequest): - def do_bench( - self, - fn, - *input_tensors: torch.Tensor, - output_tensor: Optional[torch.Tensor] = None, - ) -> float: - return benchmarker.benchmark_cpu(fn) - - -class CppBenchmarkRequest(CPUDeviceBenchmarkRequest): +class CppBenchmarkRequest(CPUDeviceBenchmarkMixin, BenchmarkRequest): # Important: Instances of this class have to be serializable # across process boundaries. Do not put Tensors in here! diff --git a/torch/_inductor/codegen/common.py b/torch/_inductor/codegen/common.py index efbef7f461c8c2..3902d41b97ce3d 100644 --- a/torch/_inductor/codegen/common.py +++ b/torch/_inductor/codegen/common.py @@ -239,7 +239,11 @@ def init_backend_registration(): from .wrapper import WrapperCodeGen if get_scheduling_for_device("cpu") is None: - cpu_backends = {"cpp": CppScheduling, "halide": HalideScheduling} + cpu_backends = { + "cpp": CppScheduling, + "halide": HalideScheduling, + "triton": TritonScheduling, + } register_backend_for_device( "cpu", lambda *args, **kwargs: cpu_backends[config.cpu_backend](*args, **kwargs), @@ -301,6 +305,7 @@ def get_device_op_overrides(device: str): assert isinstance(device, str) if not device_op_overrides_dict.keys(): + from . import cpu_device_op_overrides # noqa: F401 from .cuda import device_op_overrides # noqa: F401 from .xpu import device_op_overrides as xpu_op_overrides # noqa: F401 diff --git a/torch/_inductor/codegen/cpp.py b/torch/_inductor/codegen/cpp.py index 8bbe8023a312af..e4011322fbe88b 100644 --- a/torch/_inductor/codegen/cpp.py +++ b/torch/_inductor/codegen/cpp.py @@ -4643,7 +4643,7 @@ def codegen_group(self, name=None) -> str: def call_kernel(self, wrapper, kernel_name): _, call_args, arg_types = self.args.cpp_argdefs() wrapper.generate_kernel_call( - kernel_name, call_args, gpu=False, arg_types=arg_types + kernel_name, call_args, gpu=False, triton=False, arg_types=arg_types ) diff --git a/torch/_inductor/codegen/cpp_template_kernel.py b/torch/_inductor/codegen/cpp_template_kernel.py index e72a895dc44d59..7ac125c76cc6fb 100644 --- a/torch/_inductor/codegen/cpp_template_kernel.py +++ b/torch/_inductor/codegen/cpp_template_kernel.py @@ -107,7 +107,9 @@ def hook(): def call_kernel(self, name: str, node: ir.CppTemplateBuffer): wrapper = V.graph.wrapper_code _, call_args, arg_types = self.args.cpp_argdefs() - wrapper.generate_kernel_call(name, call_args, gpu=False, arg_types=arg_types) + wrapper.generate_kernel_call( + name, call_args, triton=False, gpu=False, arg_types=arg_types + ) def dtype(self, node: ir.Buffer) -> str: return DTYPE_TO_CPP[node.get_dtype()] diff --git a/torch/_inductor/codegen/cpu_device_op_overrides.py b/torch/_inductor/codegen/cpu_device_op_overrides.py new file mode 100644 index 00000000000000..1944c0e6beb81e --- /dev/null +++ b/torch/_inductor/codegen/cpu_device_op_overrides.py @@ -0,0 +1,26 @@ +# mypy: allow-untyped-defs +from textwrap import dedent + +from .common import DeviceOpOverrides, register_device_op_overrides + + +class CpuDeviceOpOverrides(DeviceOpOverrides): + def import_get_raw_stream_as(self, name): + return dedent( + """ + def get_raw_stream(_): + return 0 + """ + ) + + def set_device(self, device_idx): + return "pass" + + def synchronize(self): + return "pass" + + def device_guard(self, device_idx): + return "pass" + + +register_device_op_overrides("cpu", CpuDeviceOpOverrides()) diff --git a/torch/_inductor/codegen/halide.py b/torch/_inductor/codegen/halide.py index 337aa544b0d10a..bbf6866fdc3818 100644 --- a/torch/_inductor/codegen/halide.py +++ b/torch/_inductor/codegen/halide.py @@ -1639,6 +1639,7 @@ def call_kernel(self, name: str, node=None): name, call_args, gpu=False, # grid/stream is handled internally in halide + triton=False, ) def generate_assert(self, check): diff --git a/torch/_inductor/codegen/rocm/rocm_benchmark_request.py b/torch/_inductor/codegen/rocm/rocm_benchmark_request.py index a70f45b7033d6d..bdd68d9a2e6ba6 100644 --- a/torch/_inductor/codegen/rocm/rocm_benchmark_request.py +++ b/torch/_inductor/codegen/rocm/rocm_benchmark_request.py @@ -7,14 +7,18 @@ from typing import Any, Callable, Iterable, List, Optional, Union import torch -from torch._inductor.autotune_process import GPUDeviceBenchmarkRequest, TensorMeta +from torch._inductor.autotune_process import ( + BenchmarkRequest, + GPUDeviceBenchmarkMixin, + TensorMeta, +) from torch._inductor.codecache import DLLWrapper, ROCmCodeCache log = logging.getLogger(__name__) -class ROCmBenchmarkRequest(GPUDeviceBenchmarkRequest): +class ROCmBenchmarkRequest(GPUDeviceBenchmarkMixin, BenchmarkRequest): # Important: Instances of this class have to be serializable # across process boundaries. Do not put CUDA Tensors in here! diff --git a/torch/_inductor/codegen/triton.py b/torch/_inductor/codegen/triton.py index 69daf2cad522a6..a85fa4172829fb 100644 --- a/torch/_inductor/codegen/triton.py +++ b/torch/_inductor/codegen/triton.py @@ -2602,6 +2602,11 @@ def codegen_kernel(self, name=None): if name is None: code.splice(gen_common_triton_imports()) + device_type = V.graph.scheduler.get_current_device_or_throw().type + if device_type == "cpu": + code.splice("triton_helpers.set_driver_to_cpu()") + else: + code.splice("triton_helpers.set_driver_to_gpu()") if config.benchmark_kernel: code.splice(self.imports_for_benchmark_kernel()) @@ -2834,7 +2839,7 @@ def call_kernel(self, name: str, node: Optional[IRNode] = None): call_args, grid, current_device.index, - gpu=True, + gpu=current_device.type != "cpu", triton=True, arg_types=arg_types, grid_fn=self._get_grid_fn(), diff --git a/torch/_inductor/codegen/wrapper.py b/torch/_inductor/codegen/wrapper.py index 348f1a2a7bc673..75cbec0c8610a1 100644 --- a/torch/_inductor/codegen/wrapper.py +++ b/torch/_inductor/codegen/wrapper.py @@ -1669,99 +1669,97 @@ def generate_kernel_call( gpu: Defines whether the backend is GPU. Otherwise the backend is CPU. - triton: Defines whether the GPU backend uses Triton for codegen. - Otherwise it uses the CUDA language for codegen. - Only valid when gpu == True. + triton: Defines whether the backend uses Triton for codegen. Otherwise it uses the CUDA language when gpu=True, + and C++ when gpu=False. """ - if gpu: - device_index, call_args_str = self.prepare_triton_kernel_call( - device_index, call_args + if not (triton or gpu): + self.writeline(self.wrap_kernel_call(kernel_name, call_args)) + return + + device_index, call_args_str = self.prepare_triton_kernel_call( + device_index, call_args + ) + call_args_str = ", ".join(call_args_str) + stream_name = self.write_get_raw_stream(device_index, V.graph) + + if not triton: + stream_ptr = f"c_void_p({stream_name})" + self.writeline( + f"{kernel_name}.{kernel_name}({call_args_str}, {stream_ptr})" ) - call_args_str = ", ".join(call_args_str) - stream_name = self.write_get_raw_stream(device_index, V.graph) - if triton: - self.write_triton_header_once() - if grid is None: - grid_str = grid_fn - else: - grid_str = ", ".join(self._grid_dim_str(item) for item in grid) - if grid_extra_kwargs: - grid_str = f"{grid_str}, {grid_extra_kwargs}" - grid_str = f"{grid_fn}({grid_str})" - # add debug printer code for triton kernel calls at (jit) inductor level - debug_printer_manager = V.graph.wrapper_code.debug_printer - debug_printer_manager.set_printer_args( - call_args, kernel_name, arg_types, None - ) - with debug_printer_manager: - self.writeline( - f"{kernel_name}.run({call_args_str}, grid={grid_str}, stream={stream_name})" - ) - if ( - config.triton.autotune_at_compile_time - and kernel_name not in self.kernel_autotune_names - ): - # Create example args for autotune in a separate epilogue - assert arg_types is not None and len(call_args) == len( - arg_types - ), "call_args and arg_types do not match" - - tensor_args = {} - all_args = [] - if raw_args is None: - # create a dummy raw_args for uniform behavior in the following loop - raw_args = [None] * len(call_args) - else: - assert len(raw_args) == len( - call_args - ), "call_args and raw_args do not match" + return - for i, (arg, arg_type, raw_arg) in enumerate( - zip(call_args, arg_types, raw_args) - ): - key = None - if isinstance(arg, str) and "=" in str(arg): - # arg may be passed in a kwarg style, and then we need to extract its value - key, arg = arg.split("=") - - if isinstance(arg_type, torch_dtype): - if arg not in tensor_args: - arg_str = self.generate_example_arg_value( - arg, arg_type, raw_arg, i - ) - tensor_args[arg] = arg_str - else: - arg_str = tensor_args[arg] - else: - arg_str = self.generate_example_arg_value( - arg, arg_type, raw_arg, i - ) - all_args.append(arg_str if key is None else f"{key}={arg_str}") + self.write_triton_header_once() + if grid is None: + grid_str = grid_fn + else: + grid_str = ", ".join(self._grid_dim_str(item) for item in grid) + if grid_extra_kwargs: + grid_str = f"{grid_str}, {grid_extra_kwargs}" + grid_str = f"{grid_fn}({grid_str})" + # add debug printer code for triton kernel calls at (jit) inductor level + debug_printer_manager = V.graph.wrapper_code.debug_printer + debug_printer_manager.set_printer_args(call_args, kernel_name, arg_types, None) + with debug_printer_manager: + self.writeline( + f"{kernel_name}.run({call_args_str}, grid={grid_str}, stream={stream_name})" + ) + if ( + config.triton.autotune_at_compile_time + and kernel_name not in self.kernel_autotune_names + ): + # Create example args for autotune in a separate epilogue + assert arg_types is not None and len(call_args) == len( + arg_types + ), "call_args and arg_types do not match" + + tensor_args = {} + all_args = [] + if raw_args is None: + # create a dummy raw_args for uniform behavior in the following loop + raw_args = [None] * len(call_args) + else: + assert len(raw_args) == len( + call_args + ), "call_args and raw_args do not match" - if grid is None: - grid_str = grid_fn - else: - grid_str = ", ".join( - self.generate_example_arg_value(g, type(g)) for g in grid + for i, (arg, arg_type, raw_arg) in enumerate( + zip(call_args, arg_types, raw_args) + ): + key = None + if isinstance(arg, str) and "=" in str(arg): + # arg may be passed in a kwarg style, and then we need to extract its value + key, arg = arg.split("=") + + if isinstance(arg_type, torch_dtype): + if arg not in tensor_args: + arg_str = self.generate_example_arg_value( + arg, arg_type, raw_arg, i ) - if grid_extra_kwargs: - grid_str = f"{grid_str}, {grid_extra_kwargs}" - grid_str = f"{grid_fn}({grid_str})" + tensor_args[arg] = arg_str + else: + arg_str = tensor_args[arg] + else: + arg_str = self.generate_example_arg_value(arg, arg_type, raw_arg, i) + all_args.append(arg_str if key is None else f"{key}={arg_str}") - self.kernel_autotune_calls.writeline( - f"{kernel_name}.run({', '.join(all_args)}, grid={grid_str}, stream={stream_name})" - ) - self.kernel_autotune_calls.writeline( - f"del {', '.join(arg for arg in tensor_args.values())}\n", - ) - self.kernel_autotune_names.add(kernel_name) + if grid is None: + grid_str = grid_fn else: - stream_ptr = f"c_void_p({stream_name})" - self.writeline( - f"{kernel_name}.{kernel_name}({call_args_str}, {stream_ptr})" + grid_str = ", ".join( + self.generate_example_arg_value(g, type(g)) for g in grid ) - else: - self.writeline(self.wrap_kernel_call(kernel_name, call_args)) + if grid_extra_kwargs: + grid_str = f"{grid_str}, {grid_extra_kwargs}" + grid_str = f"{grid_fn}({grid_str})" + + self.kernel_autotune_calls.writeline( + f"{kernel_name}.run({', '.join(all_args)}, grid={grid_str}, stream={stream_name})" + ) + self.kernel_autotune_calls.writeline( + f"del {', '.join(arg for arg in tensor_args.values())}\n", + ) + self.kernel_autotune_names.add(kernel_name) def writeline(self, line): self.lines.append(line) diff --git a/torch/_inductor/config.py b/torch/_inductor/config.py index 1d247bcc5865c9..15c3446099229f 100644 --- a/torch/_inductor/config.py +++ b/torch/_inductor/config.py @@ -1136,7 +1136,7 @@ class rocm: use_preselected_instances: bool = False -# Backend to use for CPU codegen either "cpp" or "halide" (experimental) +# Backend to use for CPU codegen either "cpp" or "triton" (experimental) or "halide" (experimental) cpu_backend = "cpp" # Backend to use for CUDA codegen either "triton" or "halide" (experimental) diff --git a/torch/_inductor/runtime/benchmarking.py b/torch/_inductor/runtime/benchmarking.py index 44c222c3f936ef..8d521c0d714f61 100644 --- a/torch/_inductor/runtime/benchmarking.py +++ b/torch/_inductor/runtime/benchmarking.py @@ -77,7 +77,7 @@ def __init__(self: Self) -> None: def benchmark( self: Self, fn: Callable[..., Any], - fn_args: Tuple[Any], + fn_args: Tuple[Any, ...], fn_kwargs: Dict[str, Any], **kwargs: Any, ) -> float: diff --git a/torch/_inductor/runtime/triton_helpers.py b/torch/_inductor/runtime/triton_helpers.py index 1a8ff2ba2408d7..37f17924331912 100644 --- a/torch/_inductor/runtime/triton_helpers.py +++ b/torch/_inductor/runtime/triton_helpers.py @@ -31,6 +31,21 @@ def _log2(x): raise NotImplementedError +def set_driver_to_cpu(): + if backend := triton.backends.backends.get("cpu", None): + triton.runtime.driver.set_active(backend.driver()) + return + raise RuntimeError("Could not find an active CPU backend") + + +def set_driver_to_gpu(): + for name, backend in triton.backends.backends.items(): + if backend.driver.is_active() and name != "cpu": + triton.runtime.driver.set_active(backend.driver()) + return + raise RuntimeError("Could not find an active GPU backend") + + @triton.jit def promote_to_tensor(x): # Addition promotes to tensor for us diff --git a/torch/_inductor/runtime/triton_heuristics.py b/torch/_inductor/runtime/triton_heuristics.py index a7a2bdce32da64..d011acdba1310d 100644 --- a/torch/_inductor/runtime/triton_heuristics.py +++ b/torch/_inductor/runtime/triton_heuristics.py @@ -673,6 +673,9 @@ def kernel_call(): return do_bench_using_profiling(kernel_call, warmup=10, rep=40) + if self.device_props.type == "cpu": + return benchmarker.benchmark_cpu(kernel_call) + return benchmarker.benchmark_gpu(kernel_call, rep=40) def clone_args(self, *args, **kwargs) -> Tuple[List[Any], Dict[str, Any]]: diff --git a/torch/_inductor/scheduler.py b/torch/_inductor/scheduler.py index 3d2676e0b2e000..480ee83cef4536 100644 --- a/torch/_inductor/scheduler.py +++ b/torch/_inductor/scheduler.py @@ -3447,12 +3447,11 @@ def _codegen(self) -> None: self.current_device.type ): V.graph.wrapper_code.codegen_device_guard_exit() + self.current_device = device if device_need_guard(device.type): assert device.index is not None, "device should have an index" V.graph.wrapper_code.codegen_device_guard_enter(device.index) - self.current_device = device - self.buffer_names_to_free.update(node.last_usage) if node.is_template(): diff --git a/torch/_inductor/select_algorithm.py b/torch/_inductor/select_algorithm.py index 9514f04476611f..a4e46acceef727 100644 --- a/torch/_inductor/select_algorithm.py +++ b/torch/_inductor/select_algorithm.py @@ -15,7 +15,7 @@ from collections import namedtuple from concurrent.futures import as_completed, ThreadPoolExecutor from io import StringIO -from typing import Any, Callable, Dict, List, Optional, Tuple, Union +from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union from unittest.mock import patch import sympy @@ -27,7 +27,12 @@ from torch._dynamo.utils import counters, identity, preserve_rng_state from . import config, ir -from .autotune_process import TensorMeta, TritonBenchmarkRequest +from .autotune_process import ( + TensorMeta, + TritonBenchmarkRequest, + TritonCPUBenchmarkRequest, + TritonGPUBenchmarkRequest, +) from .codecache import code_hash, PersistentCache, PyCodeCache from .codegen.common import IndentedBuffer, KernelTemplate from .codegen.triton import ( @@ -575,6 +580,7 @@ def call_kernel(self, name: str, node: Optional[ir.IRNode] = None): grid_fn=f"{self.grid_fn.__module__}.{self.grid_fn.__name__}", arg_types=arg_types, triton_meta=self.triton_meta, + gpu="cpu" not in V.graph.device_types, ) @@ -744,7 +750,12 @@ def make_kernel_render(out_node): ), kwargs, ) - bmreq = TritonBenchmarkRequest( + bmreq_cls: Type[TritonBenchmarkRequest] + if layout.device.type == "cpu": + bmreq_cls = TritonCPUBenchmarkRequest + else: + bmreq_cls = TritonGPUBenchmarkRequest + bmreq = bmreq_cls( module_path=mod.__file__, module_cache_key=mod.key, kernel_name=kernel_name, diff --git a/torch/_inductor/utils.py b/torch/_inductor/utils.py index 1004817d6e7dbc..82aa6a2ac001e7 100644 --- a/torch/_inductor/utils.py +++ b/torch/_inductor/utils.py @@ -1099,7 +1099,13 @@ def use_triton_template(layout, *, enable_int32=False, enable_float8=False): if enable_float8: layout_dtypes.extend([torch.float8_e4m3fn, torch.float8_e5m2]) return ( - _use_template_for_cuda(layout, layout_dtypes) + ( + ( + layout.device.type == "cuda" + and _use_template_for_cuda(layout, layout_dtypes) + ) + or (layout.device.type == "cpu" and layout.dtype in layout_dtypes) + ) and _use_autotune_backend("TRITON") and has_backend_feature(layout.device, BackendFeature.TRITON_TEMPLATES) ) diff --git a/torch/utils/_triton.py b/torch/utils/_triton.py index 92b42320c3c668..3cc977c8483823 100644 --- a/torch/utils/_triton.py +++ b/torch/utils/_triton.py @@ -17,15 +17,27 @@ def has_triton_package() -> bool: @functools.lru_cache(None) def has_triton() -> bool: + if not has_triton_package(): + return False + from torch._dynamo.device_interface import get_interface_for_device def cuda_extra_check(device_interface): return device_interface.Worker.get_device_properties().major >= 7 + def cpu_extra_check(device_interface): + import triton.backends + + return "cpu" in triton.backends.backends + def _return_true(device_interface): return True - triton_supported_devices = {"cuda": cuda_extra_check, "xpu": _return_true} + triton_supported_devices = { + "cuda": cuda_extra_check, + "xpu": _return_true, + "cpu": cpu_extra_check, + } def is_device_compatible_with_triton(): for device, extra_check in triton_supported_devices.items(): @@ -34,7 +46,7 @@ def is_device_compatible_with_triton(): return True return False - return is_device_compatible_with_triton() and has_triton_package() + return is_device_compatible_with_triton() @functools.lru_cache(None) From 4d548a6518c634fdbda8677f11033fbc57fb3d3e Mon Sep 17 00:00:00 2001 From: Jez Ng Date: Fri, 13 Sep 2024 16:33:39 -0700 Subject: [PATCH 0908/1018] Add CI for Triton CPU backend (#135342) Where possible, I have marked failing tests (which we intend to fix or triage) as `@xfail_if_triton_cpu`. This will help us track progress of the Triton CPU backend over time. Tests that I don't think we need to address, or that are flaky, have been marked as skips. Successful CI run: https://github.com/pytorch/pytorch/actions/runs/10822238062/job/30028284549 Pull Request resolved: https://github.com/pytorch/pytorch/pull/135342 Approved by: https://github.com/jansel ghstack dependencies: #133408 --- .ci/docker/build.sh | 8 + .ci/docker/ci_commit_pins/triton-cpu.txt | 1 + .ci/docker/common/install_triton.sh | 8 +- .ci/docker/ubuntu/Dockerfile | 7 + .ci/pytorch/test.sh | 7 + .github/workflows/inductor.yml | 22 +++ test/inductor/test_torchinductor.py | 182 ++++++++++++++++------ test/inductor/test_triton_cpu_backend.py | 12 +- torch/testing/_internal/inductor_utils.py | 1 + 9 files changed, 197 insertions(+), 51 deletions(-) create mode 100644 .ci/docker/ci_commit_pins/triton-cpu.txt diff --git a/.ci/docker/build.sh b/.ci/docker/build.sh index 8aafbb89d811ec..9e0abc37ab8578 100755 --- a/.ci/docker/build.sh +++ b/.ci/docker/build.sh @@ -381,6 +381,13 @@ case "$image" in HALIDE=yes TRITON=yes ;; + pytorch-linux-jammy-py3.12-triton-cpu) + CUDA_VERSION=12.4 + ANACONDA_PYTHON_VERSION=3.12 + GCC_VERSION=11 + CONDA_CMAKE=yes + TRITON_CPU=yes + ;; pytorch-linux-focal-linter) # TODO: Use 3.9 here because of this issue https://github.com/python/mypy/issues/13627. # We will need to update mypy version eventually, but that's for another day. The task @@ -510,6 +517,7 @@ docker build \ --build-arg "UCC_COMMIT=${UCC_COMMIT}" \ --build-arg "CONDA_CMAKE=${CONDA_CMAKE}" \ --build-arg "TRITON=${TRITON}" \ + --build-arg "TRITON_CPU=${TRITON_CPU}" \ --build-arg "ONNX=${ONNX}" \ --build-arg "DOCS=${DOCS}" \ --build-arg "INDUCTOR_BENCHMARKS=${INDUCTOR_BENCHMARKS}" \ diff --git a/.ci/docker/ci_commit_pins/triton-cpu.txt b/.ci/docker/ci_commit_pins/triton-cpu.txt new file mode 100644 index 00000000000000..36e9eea06792a7 --- /dev/null +++ b/.ci/docker/ci_commit_pins/triton-cpu.txt @@ -0,0 +1 @@ +6a333f1b05671f6fada4ba7bbfae4a02a9d96f4f diff --git a/.ci/docker/common/install_triton.sh b/.ci/docker/common/install_triton.sh index 6e5fb8839c686e..cb2d1edc71c995 100755 --- a/.ci/docker/common/install_triton.sh +++ b/.ci/docker/common/install_triton.sh @@ -15,8 +15,11 @@ conda_reinstall() { if [ -n "${XPU_VERSION}" ]; then TRITON_REPO="https://github.com/intel/intel-xpu-backend-for-triton" TRITON_TEXT_FILE="triton-xpu" +elif [ -n "${TRITON_CPU}" ]; then + TRITON_REPO="https://github.com/triton-lang/triton-cpu" + TRITON_TEXT_FILE="triton-cpu" else - TRITON_REPO="https://github.com/openai/triton" + TRITON_REPO="https://github.com/triton-lang/triton" TRITON_TEXT_FILE="triton" fi @@ -44,9 +47,10 @@ chown -R jenkins /var/lib/jenkins/triton chgrp -R jenkins /var/lib/jenkins/triton pushd /var/lib/jenkins/ -as_jenkins git clone ${TRITON_REPO} triton +as_jenkins git clone --recursive ${TRITON_REPO} triton cd triton as_jenkins git checkout ${TRITON_PINNED_COMMIT} +as_jenkins git submodule update --init --recursive cd python # TODO: remove patch setup.py once we have a proper fix for https://github.com/triton-lang/triton/issues/4527 diff --git a/.ci/docker/ubuntu/Dockerfile b/.ci/docker/ubuntu/Dockerfile index 80410be3cbc69f..1be2fccca41922 100644 --- a/.ci/docker/ubuntu/Dockerfile +++ b/.ci/docker/ubuntu/Dockerfile @@ -147,6 +147,13 @@ COPY ci_commit_pins/triton.txt triton.txt RUN if [ -n "${TRITON}" ]; then bash ./install_triton.sh; fi RUN rm install_triton.sh common_utils.sh triton.txt +ARG TRITON_CPU +COPY ./common/install_triton.sh install_triton.sh +COPY ./common/common_utils.sh common_utils.sh +COPY ci_commit_pins/triton-cpu.txt triton-cpu.txt +RUN if [ -n "${TRITON_CPU}" ]; then bash ./install_triton.sh; fi +RUN rm install_triton.sh common_utils.sh triton-cpu.txt + ARG EXECUTORCH # Build and install executorch COPY ./common/install_executorch.sh install_executorch.sh diff --git a/.ci/pytorch/test.sh b/.ci/pytorch/test.sh index 9d2f2246ad7509..594c5e4ef69252 100755 --- a/.ci/pytorch/test.sh +++ b/.ci/pytorch/test.sh @@ -607,6 +607,11 @@ test_inductor_halide() { assert_git_not_dirty } +test_inductor_triton_cpu() { + python test/run_test.py --include inductor/test_triton_cpu_backend.py --verbose + assert_git_not_dirty +} + test_dynamo_benchmark() { # Usage: test_dynamo_benchmark huggingface 0 TEST_REPORTS_DIR=$(pwd)/test/test-reports @@ -1433,6 +1438,8 @@ elif [[ "${TEST_CONFIG}" == *inductor_distributed* ]]; then test_inductor_distributed elif [[ "${TEST_CONFIG}" == *inductor-halide* ]]; then test_inductor_halide +elif [[ "${TEST_CONFIG}" == *inductor-triton-cpu* ]]; then + test_inductor_triton_cpu elif [[ "${TEST_CONFIG}" == *inductor-micro-benchmark* ]]; then test_inductor_micro_benchmark elif [[ "${TEST_CONFIG}" == *huggingface* ]]; then diff --git a/.github/workflows/inductor.yml b/.github/workflows/inductor.yml index 88ffe090fd8897..0c630878fea538 100644 --- a/.github/workflows/inductor.yml +++ b/.github/workflows/inductor.yml @@ -118,6 +118,28 @@ jobs: docker-image: ${{ needs.linux-jammy-cpu-py3_12-inductor-halide-build.outputs.docker-image }} test-matrix: ${{ needs.linux-jammy-cpu-py3_12-inductor-halide-build.outputs.test-matrix }} + linux-jammy-cpu-py3_12-inductor-triton-cpu-build: + name: linux-jammy-cpu-py3.12-gcc11-inductor-triton-cpu + uses: ./.github/workflows/_linux-build.yml + needs: get-label-type + with: + build-environment: linux-jammy-py3.12-gcc11 + docker-image-name: pytorch-linux-jammy-py3.12-triton-cpu + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" + test-matrix: | + { include: [ + { config: "inductor-triton-cpu", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.12xlarge" }, + ]} + + linux-jammy-cpu-py3_12-inductor-triton-cpu-test: + name: linux-jammy-cpu-py3.12-gcc11-inductor-triton-cpu + uses: ./.github/workflows/_linux-test.yml + needs: linux-jammy-cpu-py3_12-inductor-triton-cpu-build + with: + build-environment: linux-jammy-py3.12-gcc11 + docker-image: ${{ needs.linux-jammy-cpu-py3_12-inductor-triton-cpu-build.outputs.docker-image }} + test-matrix: ${{ needs.linux-jammy-cpu-py3_12-inductor-triton-cpu-build.outputs.test-matrix }} + linux-focal-cuda12_4-py3_10-gcc9-inductor-build: # Should be synced with the one in inductor-periodic.yml but this only runs inductor_timm name: cuda12.4-py3.10-gcc9-sm86 diff --git a/test/inductor/test_torchinductor.py b/test/inductor/test_torchinductor.py index f2240ff64922d4..f2c333bac132be 100644 --- a/test/inductor/test_torchinductor.py +++ b/test/inductor/test_torchinductor.py @@ -738,12 +738,6 @@ def is_cpp_backend(device): return getattr(device, "type", device) == "cpu" and config.cpu_backend == "cpp" -def is_halide_backend(device): - if getattr(device, "type", device) == "cpu": - return config.cpu_backend == "halide" - return config.cuda_backend == "halide" - - def skip_if_halide(fn): @functools.wraps(fn) def wrapper(self): @@ -754,6 +748,42 @@ def wrapper(self): return wrapper +def is_halide_backend(device): + if getattr(device, "type", device) == "cpu": + return config.cpu_backend == "halide" + return config.cuda_backend == "halide" + + +def is_triton_cpu_backend(device): + return getattr(device, "type", device) == "cpu" and config.cpu_backend == "triton" + + +def skip_if_triton_cpu(fn): + import types + + reason = "Triton CPU not supported" + + def decorator(fn): + @functools.wraps(fn) + def wrapper(self): + if is_triton_cpu_backend(self.device): + raise unittest.SkipTest(reason) + return fn(self) + + return wrapper + + if isinstance(fn, types.FunctionType): + return decorator(fn) + else: + reason = fn + return decorator + + +def xfail_if_triton_cpu(fn): + fn._expected_failure_triton_cpu = True + return fn + + def skip_if_gpu_halide(fn): @functools.wraps(fn) def wrapper(self): @@ -793,6 +823,7 @@ def fn(a, b): @skipCUDAIf(not SM80OrLater, "Requires sm80") @skip_if_halide # aoti + @skip_if_triton_cpu # aoti @skipIfWindows(msg="aoti not support on Windows") def test_aoti_eager_dtype_device_layout(self): ns = "aten" @@ -835,6 +866,7 @@ def test_aoti_eager_dtype_device_layout(self): @skipCUDAIf(not SM80OrLater, "Requires sm80") @skip_if_halide # aoti + @skip_if_triton_cpu # aoti @skipIfWindows(msg="aoti not support on Windows") def test_aoti_eager_support_out(self): ns = "aten" @@ -889,6 +921,7 @@ def test_aoti_eager_support_out(self): @skipCUDAIf(not SM80OrLater, "Requires sm80") @skip_if_halide # aoti + @skip_if_triton_cpu # aoti @skipIfWindows(msg="aoti not support on Windows") def test_aoti_eager_support_str(self): ns = "aten" @@ -928,6 +961,7 @@ def test_aoti_eager_support_str(self): @skipCUDAIf(not SM80OrLater, "Requires sm80") @skip_if_halide # aoti + @skip_if_triton_cpu # aoti @skipIfWindows(msg="aoti not support on Windows") def test_aoti_eager_cache_hit(self): ns = "aten" @@ -971,6 +1005,7 @@ def test_aoti_eager_cache_hit(self): @skipCUDAIf(not SM80OrLater, "Requires sm80") @skip_if_halide # aoti + @skip_if_triton_cpu # aoti @skipIfWindows(msg="aoti not support on Windows") def test_aoti_eager_with_persistent_cache(self): def fn(a): @@ -1017,6 +1052,7 @@ def fn(a): @skipCUDAIf(not SM80OrLater, "Requires sm80") @skip_if_halide # aoti + @skip_if_triton_cpu # aoti @skipIfWindows(msg="aoti not support on Windows") def test_aoti_eager_with_scalar(self): namespace_name = "aten" @@ -1089,6 +1125,7 @@ def test_aoti_eager_with_scalar(self): @skipCUDAIf(not SM80OrLater, "Requires sm80") @skip_if_halide # aoti + @skip_if_triton_cpu # aoti @skipIfWindows(msg="aoti not support on Windows") def test_aoti_eager_override_registration(self): namespace_name = "aten" @@ -1247,6 +1284,7 @@ def fn(a): self.common(fn, (torch.randn(17),)) + @xfail_if_triton_cpu def test_angle(self): def fn(a, b, c): return torch.angle(a), torch.angle(b), torch.angle(c) @@ -1436,6 +1474,7 @@ def nested(x, repeats): actual = nested_opt(*example_inputs) self.assertEqual(expect, actual) + @xfail_if_triton_cpu def test_index_propagation_flip(self): def flip(x): i = torch.arange(x.size(0) - 1, -1, -1, device=x.device) @@ -1521,7 +1560,7 @@ def test( fn_opt = torch.compile(fn) if is_halide_backend(self.device): pass # no device asserts in halide - elif self.device == "cpu": + elif self.device == "cpu" and not is_triton_cpu_backend(self.device): _, code = run_and_get_cpp_code(fn_opt, *inps) self.assertTrue((") ? (" in code or "blendv" in code) is has_wrapping) self.assertTrue(("TORCH_CHECK" in code) is has_assert) @@ -1617,6 +1656,7 @@ def constant_propagation_neg(a): vectorize=False, # There's no loop to vectorize! ) + @xfail_if_triton_cpu def test_computed_buffer_inlining(self): def flip(x): idx = torch.arange(x.size(0) - 1, -1, -1, device=x.device) @@ -2224,6 +2264,7 @@ def make_tensor(shape): ) self.assertEqual(cfn(inp), fn(inp)) + @xfail_if_triton_cpu def test_logcumsumexp(self): def fn(x): return x.logcumsumexp(0), x.logcumsumexp(1) @@ -2271,6 +2312,7 @@ def fn(a): self.common(fn, (torch.randint(4, (4,)),)) @skip_if_gpu_halide + @xfail_if_triton_cpu def test_dist(self): def fn(a, b): return ( @@ -2520,6 +2562,7 @@ def fn(a, b): self.common(fn, (torch.randn(8, 8), torch.randn(8, 8))) + @xfail_if_triton_cpu def test_round(self): def fn(a, b): return torch.round(a), torch.round(b + 1), torch.round(a, decimals=2) @@ -2531,6 +2574,7 @@ def fn(a, b): # with *100 we are always getting a number exactly at .5 which we don't do right in half self.common(fn, (torch.randn(8, 8) * 100, torch.randn(8, 8) * 10)) + @xfail_if_triton_cpu def test_round_correctness(self): if self.device == "cuda": raise unittest.SkipTest("need to debug tl.libdevice on A100/V100") @@ -2544,6 +2588,7 @@ def fn(a): check_lowp=False, ) + @xfail_if_triton_cpu def test_builtins_round(self): def fn(x, i): return x[: round(i / 2 + 1)] + round(i / 2) @@ -2555,6 +2600,7 @@ def fn(x, i): for i in range(1, 6): self.assertEqual(cfn(x, i), fn(x, i)) + @xfail_if_triton_cpu def test_builtins_round_float_ndigits_pos(self): def fn(x, i): return x + round(i / 2 * 123.4567, 1) @@ -2567,6 +2613,7 @@ def fn(x, i): with torch.no_grad(): self.assertEqual(cfn(x, i), fn(x, i)) + @xfail_if_triton_cpu def test_builtins_round_float_ndigits_zero(self): def fn(x, i): return x + round(i / 2 * 123.4567, 0) @@ -2579,6 +2626,7 @@ def fn(x, i): with torch.no_grad(): self.assertEqual(cfn(x, i), fn(x, i)) + @xfail_if_triton_cpu def test_builtins_round_float_ndigits_neg(self): def fn(x, i): return x + round(i / 2 * 123.4567, -1) @@ -2730,6 +2778,7 @@ def fn(a, b): (torch.ones([8, 8], dtype=torch.bool), torch.randint(-100, -1, [8, 8])), ) + @skip_if_triton_cpu # divide by zero; cannot xfail because it crashes process def test_div7(self): def fn(a, b): return ( @@ -2768,6 +2817,7 @@ def fn(x): self.common(fn, (torch.randn(8),)) + @skip_if_triton_cpu # divide by zero; cannot xfail because it crashes process def test_div_zero_dim(self): def fn(a, b): return ( @@ -2794,6 +2844,7 @@ def fn(a, b): ), ) + @skip_if_triton_cpu # divide by zero; cannot xfail because it crashes process def test_div_prim(self): def fn(a, b): return (torch.ops.prims.div(a, b),) @@ -3947,6 +3998,7 @@ def fn(a): ) @requires_gpu() + @xfail_if_triton_cpu def test_multi_device(self): def fn(x): x = x + 1 @@ -4344,6 +4396,7 @@ def run_weights_sharing_model(m, inp): thread.join() @unittest.skipIf(config.is_fbcode(), "fbcode triton error, needs debugging") + @skip_if_triton_cpu("Flaky on Triton CPU") @skip_if_gpu_halide # https://github.com/halide/Halide/issues/8311 def test_adaptive_avg_pool2d_low_prec(self): class Model(torch.nn.Module): @@ -4773,6 +4826,7 @@ def fn(x): ) @skip_if_halide # lgamma not implemented + @xfail_if_triton_cpu def test_lgamma(self): def fn(x): return aten.lgamma(x) + 2, aten.cos(x + 1) @@ -5312,6 +5366,7 @@ def fn(x): (torch.randn([16, 16]),), ) + @xfail_if_triton_cpu def test_pow2(self): def fn(x): return aten.pow(1000, x), aten.pow(x, 1000) @@ -5361,6 +5416,7 @@ def fn(x, y): ), ) + @xfail_if_triton_cpu def test_pow_symfloat(self): def fn(x): r = math.sqrt(x.size(0)) @@ -5869,6 +5925,7 @@ def fn(x): self.common(fn, (torch.randn([1, 2, 6, 6]),)) + @xfail_if_triton_cpu def test_fmod(self): def fn(a, b): return torch.fmod(a, b), torch.fmod(3.0 * a, b) - 2.0 @@ -5876,6 +5933,7 @@ def fn(a, b): shape = [1, 2, 6, 6] self.common(fn, (torch.randn(shape), torch.randn(shape))) + @xfail_if_triton_cpu def test_fmod_zero_dim(self): def fn(a, b): return (torch.fmod(a, b),) @@ -7296,6 +7354,7 @@ def fn(x): actual_out = compiled_fn(view) self.assertEqual(reference_out.stride(), actual_out.stride()) + @xfail_if_triton_cpu def test_like_channels_last(self): def foo(): randn = torch.randn((4, 3, 8, 8), device=self.device, dtype=torch.float32) @@ -7405,6 +7464,7 @@ def fn(a, b): self.common(fn, [a, b]) @with_tf32_off + @xfail_if_triton_cpu def test_slice_scatter_reinplace(self): class M(nn.Module): def __init__(self, device): @@ -8182,6 +8242,7 @@ def fn(x): self.assertFalse(torch.allclose(a0, a1)) @requires_gpu() + @skip_if_triton_cpu("Flaky on Triton CPU") def test_like_rands3(self): # rand_like with `device` which is different from `x.device` def test_like_rands_on_different_device(device1, device2): @@ -8543,6 +8604,7 @@ def fn(a, b): ) assertGeneratedKernelCountEqual(self, 0) + @xfail_if_triton_cpu @config.patch(search_autotune_cache=False) def test_mm_views(self): def fn(a, b): @@ -8630,6 +8692,7 @@ def check(r, g): self.assertTrue(same(r2, r3)) self.assertTrue(same(g2, g3)) + @xfail_if_triton_cpu @config.patch(search_autotune_cache=False) def test_dropout3(self): m = torch.nn.Sequential( @@ -9417,6 +9480,7 @@ def fn(a, b): ], ) + @xfail_if_triton_cpu def test_index_dynamic_shapes(self): # Repro from vision_maskrcnn def fn(arg0_1): @@ -9589,6 +9653,7 @@ def forward(self, arg0_1, arg1_1): self.assertEqual(inductor_out, eager_out) @skipIfRocm + @xfail_if_triton_cpu def test_require_stride_expanded(self): def forward(arg6, arg7, arg16): convolution = torch.ops.aten.convolution( @@ -9858,6 +9923,7 @@ def fn(a, b): self.common(fn, (torch.rand(1), torch.rand(2))) + @xfail_if_triton_cpu def test_view_on_aliased(self): # https://github.com/pytorch/pytorch/issues/96728 def fn1(a, b): @@ -9929,6 +9995,7 @@ def fn(x): self.common(fn, (torch.ones(1, 1, 13, dtype=dtype),)) @unittest.skipIf(not HAS_CPU, "requires C++ compiler") + @xfail_if_triton_cpu # bf16 @skip_if_halide # bf16 def test_data_type_propogation(self): from torch._dynamo.utils import detect_fake_mode @@ -10229,6 +10296,7 @@ def fn(x, y): opt_fn = torch._dynamo.optimize("inductor")(fn) same(fn(x, y), opt_fn(x_clone, y)) + @xfail_if_triton_cpu def test_erfc(self): def fn(x): return torch.erfc(x) @@ -10236,6 +10304,7 @@ def fn(x): self.common(fn, (torch.randn(8, 8),)) @skip_if_halide # erfinv not implemented + @xfail_if_triton_cpu def test_erfinv(self): def fn(x): return torch.erfinv(x) @@ -10619,6 +10688,7 @@ def fn(x): self.common(fn, (inp,), check_lowp=False) @requires_gpu() + @xfail_if_triton_cpu @config.patch(implicit_fallbacks=True) def test_mutable_custom_op_fixed_layout2(self): with torch.library._scoped_library("mylib", "DEF") as lib: @@ -10940,6 +11010,7 @@ def fn(x): @skipCUDAIf(not SM80OrLater, "uses bfloat16 which requires SM >= 80") # We only support dtypeview for abi_conpatible aoti + @skip_if_triton_cpu("Compile time crash in Triton CPU CI") @torch._inductor.config.patch(abi_compatible=True) def test_dtypeview(self): if TEST_WITH_ASAN: @@ -11016,6 +11087,7 @@ def fn(x, x2): _, code = run_and_get_code(fn, x, x2) FileCheck().check("aten.view.dtype(reinterpret_tensor").run(code[0]) + @xfail_if_triton_cpu @requires_gpu() def test_scalar_cpu_tensor_arg(self): def fn(x, y): @@ -11050,6 +11122,7 @@ def fn(x): @skipCUDAIf(not SM80OrLater, "uses bfloat16 which requires SM >= 80") @skip_if_gpu_halide # https://github.com/halide/Halide/issues/8311 + @xfail_if_triton_cpu def test_bfloat16_to_int16(self): def fn(a, b): x = a + b @@ -11152,47 +11225,60 @@ def test_pointwise(self, name, op): # _cuda not implemented for Half check_lowp = False - if is_halide_backend(self.device) and name in ( - "erfinv", - "airy_ai", - "bessel_j0", - "bessel_j1", - "bessel_y0", - "bessel_y1", - "chebyshev_polynomial_t", - "chebyshev_polynomial_u", - "chebyshev_polynomial_v", - "chebyshev_polynomial_w", - "digamma", - "gammainc", - "gammaincc", - "gammaln", - "hermite_polynomial_h", - "hermite_polynomial_he", - "i0", - "i0e", - "i1", - "i1e", - "laguerre_polynomial_l", - "legendre_polynomial_p", - "modified_bessel_i0", - "modified_bessel_i1", - "modified_bessel_k0", - "modified_bessel_k1", - "multigammaln", - "ndtri", - "polygamma", - "psi", - "scaled_modified_bessel_k0", - "scaled_modified_bessel_k1", - "shifted_chebyshev_polynomial_t", - "shifted_chebyshev_polynomial_u", - "shifted_chebyshev_polynomial_v", - "shifted_chebyshev_polynomial_w", - "spherical_bessel_j0", - "zeta", + if ( + is_halide_backend(self.device) + or is_triton_cpu_backend(self.device) + and name + in ( + "erfinv", + "airy_ai", + "bessel_j0", + "bessel_j1", + "bessel_y0", + "bessel_y1", + "chebyshev_polynomial_t", + "chebyshev_polynomial_u", + "chebyshev_polynomial_v", + "chebyshev_polynomial_w", + "digamma", + "gammainc", + "gammaincc", + "gammaln", + "hermite_polynomial_h", + "hermite_polynomial_he", + "i0", + "i0e", + "i1", + "i1e", + "laguerre_polynomial_l", + "legendre_polynomial_p", + "modified_bessel_i0", + "modified_bessel_i1", + "modified_bessel_k0", + "modified_bessel_k1", + "multigammaln", + "ndtri", + "polygamma", + "psi", + "scaled_modified_bessel_k0", + "scaled_modified_bessel_k1", + "shifted_chebyshev_polynomial_t", + "shifted_chebyshev_polynomial_u", + "shifted_chebyshev_polynomial_v", + "shifted_chebyshev_polynomial_w", + "spherical_bessel_j0", + "zeta", + ) ): - raise unittest.SkipTest(f"halide does not support {name}") + raise unittest.SkipTest(f"Halide & Triton CPU do not support {name}") + + if is_triton_cpu_backend(self.device) and name in [ + "erfc", + "erfcx", + "round", + "log_ndtr", + ]: + raise unittest.SkipTest(f"Triton CPU does not support {name}") if name in {"gammainc", "gammaincc"}: args = ( @@ -11314,6 +11400,7 @@ def test_generate_rand_fp8(self): t = rand_strided((2, 3), (3, 1), device=self.device, dtype=torch.float8_e4m3fn) self.assertTrue(t.dtype is torch.float8_e4m3fn) + @skip_if_triton_cpu("Triton CPU: Cannot xfail because it crashes process") def test_large_grid(self): # https://github.com/pytorch/pytorch/issues/123210 def fn(primals_5): @@ -11368,6 +11455,7 @@ def forward(): self.common(forward, ()) + @xfail_if_triton_cpu def test_flip_cat(self): def forward(unsqueeze, unsqueeze_1): cat_1 = torch.ops.aten.cat.default([unsqueeze, unsqueeze_1], 1) diff --git a/test/inductor/test_triton_cpu_backend.py b/test/inductor/test_triton_cpu_backend.py index 8be4d85caea6a6..cb738163b3d120 100644 --- a/test/inductor/test_triton_cpu_backend.py +++ b/test/inductor/test_triton_cpu_backend.py @@ -25,8 +25,16 @@ class SweepInputsCpuTritonTest(test_torchinductor.SweepInputsCpuTest): pass @config.patch(cpu_backend="triton") - class CpuTritonTests(test_torchinductor.CpuTests): - pass + class CpuTritonTests(test_torchinductor.TestCase): + common = test_torchinductor.check_model + device = "cpu" + + test_torchinductor.copy_tests( + test_torchinductor.CommonTemplate, + CpuTritonTests, + "cpu", + xfail_prop="_expected_failure_triton_cpu", + ) if __name__ == "__main__": diff --git a/torch/testing/_internal/inductor_utils.py b/torch/testing/_internal/inductor_utils.py index 00f9ba0dd2ee29..d8e86651fa9a56 100644 --- a/torch/testing/_internal/inductor_utils.py +++ b/torch/testing/_internal/inductor_utils.py @@ -74,6 +74,7 @@ def _check_has_dynamic_shape( def skipDeviceIf(cond, msg, *, device): if cond: def decorate_fn(fn): + @functools.wraps(fn) def inner(self, *args, **kwargs): if not hasattr(self, "device"): warn_msg = "Expect the test class to have attribute device but not found. " From c0325213003ec960b9821149237ce1afabc0d2fe Mon Sep 17 00:00:00 2001 From: Isuru Fernando Date: Sat, 14 Sep 2024 06:44:48 +0000 Subject: [PATCH 0909/1018] Fix dividing Mul by factor (#136079) Fixes https://github.com/pytorch/pytorch/issues/136032 Pull Request resolved: https://github.com/pytorch/pytorch/pull/136079 Approved by: https://github.com/ezyang --- torch/fx/experimental/symbolic_shapes.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/torch/fx/experimental/symbolic_shapes.py b/torch/fx/experimental/symbolic_shapes.py index 68701a2080e5df..e5ed7ac1bd27a2 100644 --- a/torch/fx/experimental/symbolic_shapes.py +++ b/torch/fx/experimental/symbolic_shapes.py @@ -447,14 +447,20 @@ def div_by_factor(x, factor): if x.is_Integer: return x / factor elif x.is_Mul: - return sympy.Mul._from_args((x.args[0] / factor, *x.args[1:]), is_commutative=x.is_commutative) + if x.args[0] != factor: + args = (x.args[0] / factor, *x.args[1:]) + else: + # Mul._from_args require a canonical list of args + # so we remove the first arg (x.args[0] / factor) if it was 1 + args = x.args[1:] + return sympy.Mul._from_args(args, is_commutative=x.is_commutative) if expr.is_Add: atoms = expr.args factor = functools.reduce(math.gcd, map(integer_coefficient, atoms)) if factor == 1: return expr - atoms = [x / factor for x in atoms] + atoms = [div_by_factor(x, factor) for x in atoms] return sympy.Add._from_args(atoms) elif expr.is_Integer: return sympy.One From 75925aee8145bfc89a898975c429521981fe22a9 Mon Sep 17 00:00:00 2001 From: Isuru Fernando Date: Fri, 13 Sep 2024 17:08:10 +0000 Subject: [PATCH 0910/1018] use a fast expand algorithm (#135999) Pull Request resolved: https://github.com/pytorch/pytorch/pull/135999 Approved by: https://github.com/ezyang --- torch/fx/experimental/symbolic_shapes.py | 55 +++++++++++++++++++++++- 1 file changed, 53 insertions(+), 2 deletions(-) diff --git a/torch/fx/experimental/symbolic_shapes.py b/torch/fx/experimental/symbolic_shapes.py index e5ed7ac1bd27a2..580d3ed741e3e4 100644 --- a/torch/fx/experimental/symbolic_shapes.py +++ b/torch/fx/experimental/symbolic_shapes.py @@ -1414,13 +1414,64 @@ def is_symbolic(val: Union[int, SymInt, float, SymFloat, bool, SymBool]) -> bool IndicatorTypes = (IsNonOverlappingAndDenseIndicator,) + +def _expandsums(args: List[sympy.Expr]) -> Tuple[sympy.Expr, bool]: + adds, other = [], [] + for arg in args: + if arg.is_Add: + adds.append(arg) + else: + other.append(arg) + + result = [sympy.Mul(*other)] + for add in adds: + result = [a * b for a, b in itertools.product(result, add.args)] + + result = sympy.Add(*result) + return result, len(adds) > 1 or (len(adds) > 0 and len(other) > 0) + + +def _fast_expand(expr: sympy.Expr) -> sympy.Expr: + # The expand algorithm in sympy is slow due to all the features is supports + # For eg: e^(-x)*(x-1)/(x+1) is expanded to (x-1)/(e^x + e^x*x) if x is + # positive and (e^(-x)*x-e^(-x))/(x+1) if x is negative. We do not implement + # such features here to avoid expensive checks. We also make sure that we + # only re-create the objects if any of the args changed to avoid expensive + # checks when re-creating objects. + new_args = [_fast_expand(arg) for arg in expr.args] + if any(arg is not new_arg for arg, new_arg in zip(expr.args, new_args)): + return _fast_expand(expr.func(*new_args)) + + if expr.is_Pow: + base, exp = expr.args + if exp.is_Integer: + if exp > 1: + return sympy.expand_multinomial(expr, deep=False) + elif exp < 0: + return 1 / sympy.expand_multinomial(1 / expr, deep=False) + elif expr.is_Mul: + num, den = [], [] + for arg in expr.args: + if arg.is_Pow and arg.args[1] == -1: + den.append(1 / arg) + else: + num.append(arg) + + num, num_changed = _expandsums(num) + den, den_changed = _expandsums(den) + if num_changed or den_changed: + return num / den + + return expr + + @lru_cache(256) def safe_expand(r): if hasattr(r, 'expand'): try: - return sympy.expand(r) + return _fast_expand(r) except RecursionError: - log.warning("RecursionError in sympy.expand(%s)", r) + log.warning("RecursionError in _fast_expand(%s)", r) return r else: return r From 8b5957fffec526e83d9d0986478fd64eda5a6517 Mon Sep 17 00:00:00 2001 From: Bob Ren Date: Sat, 14 Sep 2024 12:45:52 -0700 Subject: [PATCH 0911/1018] Add TensorReferenceAnalysis and some tests (#135886) Split out and modified from https://github.com/pytorch/pytorch/pull/130228. There were a bunch of subtle bugs eg. sometimes we need to use torch.ops.aten.{operator}.Tensor vs other times using torch.ops.aten.{operator}.default. Or in the case of pow we need to use Tensor_Tensor. I figured it'd be easier to split out adding TensorReferenceAnalysis and add some tests and do the actual integration in a separate diff. Pull Request resolved: https://github.com/pytorch/pytorch/pull/135886 Approved by: https://github.com/ezyang --- test/test_sympy_utils.py | 186 ++++++++++++++++++++++++----- torch/utils/_sympy/reference.py | 200 ++++++++++++++++++++++++++++++++ 2 files changed, 356 insertions(+), 30 deletions(-) diff --git a/test/test_sympy_utils.py b/test/test_sympy_utils.py index 6be74e6493ba87..81ed1126dcbb27 100644 --- a/test/test_sympy_utils.py +++ b/test/test_sympy_utils.py @@ -1,30 +1,35 @@ # Owner(s): ["oncall: pt2"] +import functools import itertools import math import sys +from typing import Callable, List, Tuple, Type import sympy -from typing import Callable, List, Tuple, Type + +import torch +import torch.fx as fx +from sympy.core.relational import is_ge, is_gt, is_le, is_lt from torch.testing._internal.common_device_type import skipIf from torch.testing._internal.common_utils import ( - TEST_Z3, instantiate_parametrized_tests, parametrize, run_tests, + TEST_Z3, TestCase, ) from torch.utils._sympy.functions import FloorDiv, simple_floordiv_gcd -from torch.utils._sympy.solve import INEQUALITY_TYPES, mirror_rel_op, try_solve -from torch.utils._sympy.value_ranges import ValueRangeAnalysis, ValueRanges -from torch.utils._sympy.reference import ReferenceAnalysis, PythonReferenceAnalysis from torch.utils._sympy.interp import sympy_interp -from torch.utils._sympy.singleton_int import SingletonInt from torch.utils._sympy.numbers import int_oo, IntInfinity, NegativeIntInfinity -from sympy.core.relational import is_ge, is_le, is_gt, is_lt -import functools -import torch.fx as fx - +from torch.utils._sympy.reference import ( + PythonReferenceAnalysis, + ReferenceAnalysis, + TensorReferenceAnalysis, +) +from torch.utils._sympy.singleton_int import SingletonInt +from torch.utils._sympy.solve import INEQUALITY_TYPES, mirror_rel_op, try_solve +from torch.utils._sympy.value_ranges import ValueRangeAnalysis, ValueRanges UNARY_OPS = [ @@ -39,10 +44,18 @@ "ceil", ] BINARY_OPS = [ - "truediv", "floordiv", + "truediv", + "floordiv", # "truncdiv", # TODO # NB: pow is float_pow - "add", "mul", "sub", "pow", "pow_by_natural", "minimum", "maximum", "mod" + "add", + "mul", + "sub", + "pow", + "pow_by_natural", + "minimum", + "maximum", + "mod", ] UNARY_BOOL_OPS = ["not_"] @@ -152,12 +165,12 @@ def test_int_infinity(self): self.assertIs(-(-int_oo), int_oo) # noqa: B002 self.assertIs(abs(int_oo), int_oo) self.assertIs(abs(-int_oo), int_oo) - self.assertIs(int_oo ** 2, int_oo) + self.assertIs(int_oo**2, int_oo) self.assertIs((-int_oo) ** 2, int_oo) self.assertIs((-int_oo) ** 3, -int_oo) - self.assertEqual(int_oo ** -1, 0) + self.assertEqual(int_oo**-1, 0) self.assertEqual((-int_oo) ** -1, 0) - self.assertIs(int_oo ** int_oo, int_oo) + self.assertIs(int_oo**int_oo, int_oo) self.assertTrue(int_oo == int_oo) self.assertFalse(int_oo != int_oo) self.assertTrue(-int_oo == -int_oo) @@ -325,7 +338,9 @@ def test_binary_ref_range(self, fn): class TestSympyInterp(TestCase): - @parametrize("fn", UNARY_OPS + BINARY_OPS + UNARY_BOOL_OPS + BINARY_BOOL_OPS + COMPARE_OPS) + @parametrize( + "fn", UNARY_OPS + BINARY_OPS + UNARY_BOOL_OPS + BINARY_BOOL_OPS + COMPARE_OPS + ) def test_interp(self, fn): # SymPy does not implement truncation for Expressions if fn in ("div", "truncdiv", "minimum", "maximum", "mod"): @@ -335,8 +350,8 @@ def test_interp(self, fn): if fn == "pow_by_natural": is_integer = True - x = sympy.Dummy('x', integer=is_integer) - y = sympy.Dummy('y', integer=is_integer) + x = sympy.Dummy("x", integer=is_integer) + y = sympy.Dummy("y", integer=is_integer) vals = CONSTANTS if fn in {*UNARY_BOOL_OPS, *BINARY_BOOL_OPS}: @@ -358,10 +373,14 @@ def test_interp(self, fn): ref_r = getattr(ReferenceAnalysis, fn)(*sargs) # Yes, I know this is a longwinded way of saying xreplace; the # point is to test sympy_interp - r = sympy_interp(ReferenceAnalysis, dict(zip(symbols, sargs)), sympy_expr) + r = sympy_interp( + ReferenceAnalysis, dict(zip(symbols, sargs)), sympy_expr + ) self.assertEqual(ref_r, r) - @parametrize("fn", UNARY_OPS + BINARY_OPS + UNARY_BOOL_OPS + BINARY_BOOL_OPS + COMPARE_OPS) + @parametrize( + "fn", UNARY_OPS + BINARY_OPS + UNARY_BOOL_OPS + BINARY_BOOL_OPS + COMPARE_OPS + ) def test_python_interp_fx(self, fn): # These never show up from symbolic_shapes if fn in ("log", "exp"): @@ -383,8 +402,8 @@ def test_python_interp_fx(self, fn): if fn == "pow_by_natural": is_integer = True - x = sympy.Dummy('x', integer=is_integer) - y = sympy.Dummy('y', integer=is_integer) + x = sympy.Dummy("x", integer=is_integer) + y = sympy.Dummy("y", integer=is_integer) symbols = [x] if arity == 2: @@ -411,26 +430,125 @@ def test_python_interp_fx(self, fn): sympy_expr = getattr(ReferenceAnalysis, fn)(*symbols) if arity == 1: + def trace_f(px): - return sympy_interp(PythonReferenceAnalysis, {x: px}, sympy_expr) + return sympy_interp( + PythonReferenceAnalysis, {x: px}, sympy_expr + ) + else: + def trace_f(px, py): - return sympy_interp(PythonReferenceAnalysis, {x: px, y: py}, sympy_expr) + return sympy_interp( + PythonReferenceAnalysis, {x: px, y: py}, sympy_expr + ) gm = fx.symbolic_trace(trace_f) self.assertEqual( - sympy_interp(PythonReferenceAnalysis, dict(zip(symbols, args)), sympy_expr), - gm(*args) + sympy_interp( + PythonReferenceAnalysis, dict(zip(symbols, args)), sympy_expr + ), + gm(*args), ) + @parametrize( + "fn", UNARY_OPS + BINARY_OPS + UNARY_BOOL_OPS + BINARY_BOOL_OPS + COMPARE_OPS + ) + def test_tensor_interp(self, fn): + # Skip operations not implemented or not applicable for tensors + if fn in ("div", "truncdiv", "int_truediv", "mod", "round_decimal"): + return + + is_integer = None + if fn == "pow_by_natural": + is_integer = True + + x = sympy.Symbol("x", integer=is_integer) + y = sympy.Symbol("y", integer=is_integer) + + vals = CONSTANTS + if fn in {*UNARY_BOOL_OPS, *BINARY_BOOL_OPS}: + vals = [True, False] + + arity = 1 + if fn in {*BINARY_OPS, *BINARY_BOOL_OPS, *COMPARE_OPS}: + arity = 2 + + symbols = [x] + if arity == 2: + symbols = [x, y] + + for args in itertools.product(vals, repeat=arity): + if arity == 1 and not valid_unary(fn, *args): + continue + elif arity == 2 and not valid_binary(fn, *args): + continue + + with self.subTest(args=args): + tensor_args = [ + torch.tensor( + a, dtype=torch.double if isinstance(a, float) else torch.int64 + ) + for a in args + ] + + try: + tensor_fn = getattr(TensorReferenceAnalysis, fn) + sympy_expr = getattr(ReferenceAnalysis, fn)(*symbols) + direct_result = tensor_fn(*tensor_args) + interp_result = sympy_interp( + TensorReferenceAnalysis, + dict(zip(symbols, tensor_args)), + sympy_expr, + ) + + # Ensure both results are of the same dtype for comparison + if direct_result.dtype != interp_result.dtype: + if ( + direct_result.dtype == torch.bool + or interp_result.dtype == torch.bool + ): + direct_result = direct_result.to(torch.bool) + interp_result = interp_result.to(torch.bool) + else: + direct_result = direct_result.to(torch.double) + interp_result = interp_result.to(torch.double) + + self.assertTrue( + torch.allclose( + direct_result, interp_result, rtol=1e-5, atol=1e-8 + ), + f"Mismatch for {fn}{args}: direct={direct_result}, interp={interp_result}", + ) + + if fn in UNARY_BOOL_OPS + BINARY_BOOL_OPS + COMPARE_OPS: + self.assertEqual(direct_result.dtype, torch.bool) + self.assertEqual(interp_result.dtype, torch.bool) + + if fn in ( + "floor_to_int", + "ceil_to_int", + "round_to_int", + "trunc_to_int", + ): + self.assertEqual(direct_result.dtype, torch.int64) + self.assertEqual(interp_result.dtype, torch.int64) + + except NotImplementedError: + print(f"Operation {fn} not implemented for TensorReferenceAnalysis") + except Exception as e: + self.fail(f"Unexpected error for {fn}{args}: {str(e)}") + def type_name_fn(type: Type) -> str: return type.__name__ + def parametrize_relational_types(*types): def wrapper(f: Callable): return parametrize("op", types or RELATIONAL_TYPES, name_fn=type_name_fn)(f) + return wrapper @@ -492,7 +610,13 @@ def test_noop_rhs(self, op): self.assertEqual(r_expr, mirror(rhs, lhs)) self.assertEqual(r_rhs, lhs) - def _test_cases(self, cases: List[Tuple[sympy.Basic, sympy.Basic]], thing: sympy.Basic, op: Type[sympy.Rel], **kwargs): + def _test_cases( + self, + cases: List[Tuple[sympy.Basic, sympy.Basic]], + thing: sympy.Basic, + op: Type[sympy.Rel], + **kwargs, + ): for source, expected in cases: r = try_solve(source, thing, **kwargs) @@ -563,7 +687,7 @@ def test_multiplication_division_inequality(self, op): @parametrize_relational_types() def test_floordiv(self, op): - from sympy import Eq, Ne, Gt, Ge, Lt, Le + from sympy import Eq, Ge, Gt, Le, Lt, Ne a, b, c = sympy.symbols("a b c") pos = sympy.Symbol("pos", positive=True) @@ -603,10 +727,12 @@ def test_floordiv(self, op): r_op = op self._test_cases([special_case, *cases], a, r_op) - self._test_cases([(special_case[0], None), *cases], a, r_op, floordiv_inequality=False) + self._test_cases( + [(special_case[0], None), *cases], a, r_op, floordiv_inequality=False + ) def test_floordiv_eq_simplify(self): - from sympy import Eq, Lt, Le + from sympy import Eq, Le, Lt a = sympy.Symbol("a", positive=True, integer=True) diff --git a/torch/utils/_sympy/reference.py b/torch/utils/_sympy/reference.py index 11fcded3452408..f845484db489e1 100644 --- a/torch/utils/_sympy/reference.py +++ b/torch/utils/_sympy/reference.py @@ -1,6 +1,7 @@ # mypy: allow-untyped-defs import math import operator +from typing import Union import sympy @@ -281,3 +282,202 @@ def round_to_int(a, dtype): @staticmethod def round_decimal(a, b): return round(a, ndigits=b) + + +def _to_dtype(x: torch.Tensor, dtype: torch.dtype) -> torch.Tensor: + return torch.ops.aten._to_copy(x, dtype=dtype) + + +# Suppose we have some int/float arguments. This diagram commutes: +# +# int/float -- PythonReferenceAnalysis.op --> int/float +# | | +# | | +# torch.tensor(..., dtype=torch.int64/torch.float64) +# | | +# V V +# Tensor -- TensorReferenceAnalysis.op --> Tensor +# +# NB: int before and after must be representable in int64 (we will +# insert guards accordingly.) +# +# This is guaranteed to be FX traceable with OpOverloads only. +class TensorReferenceAnalysis: + # NB: This is actually dead, because with Proxy tracing the factory + # function isn't traced correctly. Here for completeness. + @staticmethod + def constant(c, dtype): + d: Union[int, float, bool] + if dtype is torch.int64: + d = int(c) + elif dtype is torch.double: + d = float(c) + elif dtype is torch.bool: + d = bool(c) + else: + raise AssertionError(f"unrecognized dtype {dtype}") + return torch.ops.aten.scalar_tensor.default(d, dtype=dtype) + + @staticmethod + def or_(a, b): + return torch.ops.aten.logical_or.default(a, b) + + @staticmethod + def and_(a, b): + return torch.ops.aten.logical_and.default(a, b) + + @staticmethod + def eq(a, b): + return torch.ops.aten.eq.Tensor(a, b) + + @classmethod + def ne(cls, a, b): + return torch.ops.aten.ne.Tensor(a, b) + + @staticmethod + def lt(a, b): + return torch.ops.aten.lt.Tensor(a, b) + + @staticmethod + def gt(a, b): + return torch.ops.aten.gt.Tensor(a, b) + + @staticmethod + def le(a, b): + return torch.ops.aten.le.Tensor(a, b) + + @staticmethod + def ge(a, b): + return torch.ops.aten.ge.Tensor(a, b) + + @staticmethod + def not_(a): + return torch.ops.aten.logical_not.default(a) + + @staticmethod + def reciprocal(x): + return torch.ops.aten.reciprocal.default(x) + + @staticmethod + def square(x): + # TODO: maybe composite implicit autograd doesn't work here? + return torch.ops.aten.square.default(x) + + @staticmethod + def trunc_to_int(x, dtype): + return _to_dtype(torch.ops.aten.trunc.default(x), dtype) + + @staticmethod + def ceil_to_int(x, dtype): + return _to_dtype(torch.ops.aten.ceil.default(x), dtype) + + @staticmethod + def floor_to_int(x, dtype): + return _to_dtype(torch.ops.aten.floor.default(x), dtype) + + @staticmethod + def floor(x): + return torch.ops.aten.floor.default(x) + + @staticmethod + def ceil(x): + return torch.ops.aten.ceil.default(x) + + @staticmethod + def to_dtype(x, dtype): + return _to_dtype(x, dtype) + + @staticmethod + def mod(x, y): + # TODO: https://github.com/pytorch/pytorch/pull/133654 + raise NotImplementedError( + "no C-style modulus operation available from frontend atm" + ) + + @staticmethod + def abs(x): + return torch.ops.aten.abs.default(x) + + @staticmethod + def neg(x): + return torch.ops.aten.neg.default(x) + + @staticmethod + def truediv(a, b): + return torch.ops.aten.true_divide.Tensor(a, b) + + @staticmethod + def int_truediv(a, b): + raise NotImplementedError( + "Python int truediv difficult to implement in PyTorch atm" + ) + + # TODO: This is wrong, CPython has a custom implementation of true + # division that results in higher precision when the floats are + # sufficiently large. Short term fix: add a guard here + return torch.ops.aten.true_divide.default( + _to_dtype(a, torch.float64), _to_dtype(b, torch.float64) + ) + + @staticmethod + def floordiv(a, b): + return torch.ops.aten.floor_divide(a, b) + + @staticmethod + def truncdiv(a, b): + raise NotImplementedError( + "no C-style truncdiv operation available from frontend atm" + ) + + @staticmethod + def add(a, b): + return torch.ops.aten.add.Tensor(a, b) + + @staticmethod + def mul(a, b): + return torch.ops.aten.mul.Tensor(a, b) + + @staticmethod + def sub(a, b): + return torch.ops.aten.sub.Tensor(a, b) + + @staticmethod + def exp(x): + return torch.ops.aten.exp.default(x) + + @staticmethod + def log(x): + return torch.ops.aten.log.default(x) + + @staticmethod + def sqrt(x): + return torch.ops.aten.sqrt.default(x) + + @staticmethod + def pow(a, b): + return torch.ops.aten.pow.Tensor_Tensor(a, b) + + @staticmethod + def pow_by_natural(a, b): + # NB: pow handles int x int fine + return torch.ops.aten.pow.Tensor_Tensor(a, b) + + @staticmethod + def minimum(a, b): + return torch.ops.aten.minimum.default(a, b) + + @staticmethod + def maximum(a, b): + return torch.ops.aten.maximum.default(a, b) + + @staticmethod + def round_to_int(a, dtype): + return torch.ops.aten.round.default(a) + + @staticmethod + def round_decimal(a, b): + raise NotImplementedError( + "round decimal doesn't support Tensor second argument atm" + ) + + # return torch.ops.aten.round.decimals(a, b) From 1199f5f43b2e6c428ae7510b8b81380732dccb7a Mon Sep 17 00:00:00 2001 From: Guilherme Leobas Date: Thu, 12 Sep 2024 16:42:34 +0000 Subject: [PATCH 0912/1018] Optimize dict reconstruct to not codegen untouched values (#134876) PR changes how `reconstruct` is done for a ConstDict. As of today, it works as follow: (1) codegen(...) each pair of key/value (2) create a new dictionary to hold the new items (3) clear the original dictionary (4) update the original dict with the one created in (2) We do a micro optimization in the generated bytecode to: - Only codegen the items that changed. - Only clear the original dictionary if a key was removed. Fixes: #133487 Pull Request resolved: https://github.com/pytorch/pytorch/pull/134876 Approved by: https://github.com/zou3519 --- test/dynamo/test_reconstruct.py | 257 +++++++++++++++++++++++++++++ torch/_dynamo/side_effects.py | 23 ++- torch/_dynamo/variables/builder.py | 4 +- torch/_dynamo/variables/dicts.py | 41 ++++- 4 files changed, 314 insertions(+), 11 deletions(-) create mode 100644 test/dynamo/test_reconstruct.py diff --git a/test/dynamo/test_reconstruct.py b/test/dynamo/test_reconstruct.py new file mode 100644 index 00000000000000..18a18b6151b527 --- /dev/null +++ b/test/dynamo/test_reconstruct.py @@ -0,0 +1,257 @@ +# Owner(s): ["module: dynamo"] + +import contextlib +import dis +import unittest +from typing import List + +import torch +import torch._dynamo.test_case + + +def _filter_instructions(instructions, opname): + return list(filter(lambda x: x.opname == opname, instructions)) + + +class ReconstructTest(torch._dynamo.test_case.TestCase): + @contextlib.contextmanager + def register_bytecode_hook(self, check_fn): + def hook(code, out_code): + check_fn(list(dis.get_instructions(out_code))) + return code + + torch._dynamo.reset() + handle = torch._dynamo.convert_frame.register_bytecode_hook(hook) + try: + yield + finally: + handle.remove() + + def test_ConstDict_optimize_reconstruct(self): + """ + Emit code to reconstruct only the key that changed + """ + + def hook(instructions: List[dis.Instruction]): + build_map = _filter_instructions(instructions, "BUILD_MAP") + self.assertEqual(len(build_map), 1) + # reconstruct only d[40] + self.assertEqual(build_map[0].argval, 1) + + def f(d, t): + d[1] = t + d[40] = t + 1 + + t = torch.randn(3, 4) + d = {1: t} + d_opt = d.copy() + f(d, t) + + with self.register_bytecode_hook(hook): + opt_f = torch._dynamo.optimize("eager", nopython=True)(f) + opt_f(d_opt, t) + self.assertEqual(d, d_opt) + + def test_ConstDict_pop_reconstruct(self): + """ + If something is pop'ed from the dict, we reconstruct everything + """ + + def hook(instructions: List[dis.Instruction]): + build_map = _filter_instructions(instructions, "BUILD_MAP") + self.assertEqual(len(build_map), 1) + # reconstruct everything + self.assertEqual(build_map[0].argval, 2) + + def f(d, t): + d.pop(2) + d[40] = t + 1 + + t = torch.randn(3, 4) + d = {1: t, 2: t + 1} + d_opt = d.copy() + + f(d, t) + + with self.register_bytecode_hook(hook): + opt_f = torch._dynamo.optimize("eager", nopython=True)(f) + opt_f(d_opt, t) + self.assertEqual(d, d_opt) + + @unittest.expectedFailure + def test_ConstDict_popitem_reconstruct(self): + """ + If something is pop'ed from the dict, we reconstruct everything + """ + + def hook(instructions: List[dis.Instruction]): + build_map = _filter_instructions(instructions, "BUILD_MAP") + self.assertEqual(len(build_map), 1) + # reconstruct everything + self.assertEqual(build_map[0].argval, 1) + + def f(d, t): + d.popitem() + + t = torch.randn(3, 4) + d = {1: t, 2: t + 1} + d_opt = d.copy() + + f(d, t) + + with self.register_bytecode_hook(hook): + opt_f = torch._dynamo.optimize("eager", nopython=True)(f) + opt_f(d_opt, t) + self.assertEqual(d, d_opt) + + def test_ConstDict_popitem_reconstruct_graph_break(self): + """ + If something is pop'ed from the dict, we reconstruct everything. + Calling dict.popitem will graph break. + """ + + def f(d, t): + d.popitem() + + t = torch.randn(3, 4) + d = {1: t, 2: t + 1} + d_opt = d.copy() + + f(d, t) + + opt_f = torch.compile(backend="eager")(f) + opt_f(d_opt, t) + self.assertEqual(d, d_opt) + + def test_ConstDict_del_reconstruct(self): + """ + If something is deleted from the dict, we reconstruct everything + """ + + def hook(instructions: List[dis.Instruction]): + build_map = _filter_instructions(instructions, "BUILD_MAP") + self.assertEqual(len(build_map), 1) + # reconstruct everything + self.assertEqual(build_map[0].argval, 2) + + def f(d, t): + del d[2] + d[40] = t + 1 + + t = torch.randn(3, 4) + d = {1: t, 2: t + 1} + d_opt = d.copy() + + f(d, t) + + with self.register_bytecode_hook(hook): + opt_f = torch._dynamo.optimize("eager", nopython=True)(f) + opt_f(d_opt, t) + self.assertEqual(d, d_opt) + + def test_ConstDict_get_reconstruct(self): + """ + dict.get shouldn't affect anything + """ + + def hook(instructions: List[dis.Instruction]): + build_map = _filter_instructions(instructions, "BUILD_MAP") + self.assertEqual(len(build_map), 1) + self.assertEqual(build_map[0].argval, 1) + load_const = _filter_instructions(instructions, "LOAD_CONST") + self.assertNotIn(123, load_const) + + def f(d, t): + d[456] = d.get(456) + t + + t = torch.randn(3, 4) + d = {123: t, 456: t + 1} + d_opt = d.copy() + + f(d, t) + + with self.register_bytecode_hook(hook): + opt_f = torch._dynamo.optimize("eager", nopython=True)(f) + opt_f(d_opt, t) + self.assertEqual(d, d_opt) + + def test_ConstDict_clear_reconstruct(self): + """ + If dict.clear() is used, we reconstruct everything + """ + + def hook(instructions: List[dis.Instruction]): + build_map = _filter_instructions(instructions, "BUILD_MAP") + self.assertEqual(len(build_map), 1) + # reconstruct everything + self.assertEqual(build_map[0].argval, 1) + + def f(d, t): + d.clear() + d[3] = t + 3 + + t = torch.randn(3, 4) + d = {1: t, 2: t + 1} + d_opt = d.copy() + + f(d, t) + + with self.register_bytecode_hook(hook): + opt_f = torch._dynamo.optimize("eager", nopython=True)(f) + opt_f(d_opt, t) + self.assertEqual(d, d_opt) + + def test_create_dict_reconstruct(self): + """ + If dict is created inside a function, everything needs to be reconstructed + """ + + def hook(instructions: List[dis.Instruction]): + build_map = _filter_instructions(instructions, "BUILD_MAP") + self.assertEqual(len(build_map), 1) + # reconstruct everything + self.assertEqual(build_map[0].argval, 2) + + def f(t): + return {1: t, 2: t + 1} + + t = torch.randn(3, 4) + d = f(t) + + with self.register_bytecode_hook(hook): + opt_f = torch._dynamo.optimize("eager", nopython=True)(f) + d_opt = opt_f(t) + self.assertEqual(d, d_opt) + + def test_functional_call_reconstruct(self): + """ + PyTorch shouldn't codegen any key/value when functional_call is used + """ + + def hook(instructions: List[dis.Instruction]): + build_map = _filter_instructions(instructions, "BUILD_MAP") + self.assertEqual(len(build_map), 1) + # don't reconstruct anything + self.assertEqual(build_map[0].argval, 0) + + m = torch.nn.Linear(3, 3) + new_bias = torch.randn(3) + new_weight = torch.randn(3, 3) + + def fn(new_weight, new_bias, x): + return torch.func.functional_call( + m, {"weight": new_weight, "bias": new_bias}, x + ) + + x = torch.randn(2, 3) + expected = torch.nn.functional.linear(x, new_weight, new_bias) + with self.register_bytecode_hook(hook): + opt_fn = torch._dynamo.optimize("eager", nopython=True)(fn) + got = opt_fn(new_weight, new_bias, x) + self.assertEqual(expected, got) + + +if __name__ == "__main__": + from torch._dynamo.test_case import run_tests + + run_tests() diff --git a/torch/_dynamo/side_effects.py b/torch/_dynamo/side_effects.py index fd5b54e4f7f7af..fb5782853784b6 100644 --- a/torch/_dynamo/side_effects.py +++ b/torch/_dynamo/side_effects.py @@ -575,21 +575,36 @@ def codegen_update_mutated(self, cg: PyCodegen): ) elif isinstance(var, variables.ConstDictVariable): + # Reconstruct works as follow: + # (1) codegen(...) each pair of key/value + # (2) create a new dictionary with the pairs of key/values above + # (3) clear the original dictionary + # + only if a key was removed from the input dict + # (4) update the original dictionary with the dict created in (2) + cg(var.mutable_local.source) # type: ignore[attr-defined] cg.load_method("update") cg(var, allow_cache=False) - cg(var.mutable_local.source) # type: ignore[attr-defined] - cg.load_method("clear") + if var.should_reconstruct_all: + cg(var.mutable_local.source) # type: ignore[attr-defined] + cg.load_method("clear") suffixes.append( [ - *create_call_method(0), # clear - create_instruction("POP_TOP"), *create_call_method(1), # update create_instruction("POP_TOP"), ] ) + + if var.should_reconstruct_all: + suffixes.append( + [ + *create_call_method(0), # clear + create_instruction("POP_TOP"), + ] + ) + elif isinstance( var, variables.torch_function.TorchFunctionModeStackVariable ): diff --git a/torch/_dynamo/variables/builder.py b/torch/_dynamo/variables/builder.py index 193e227267ec3b..41c5f4c004df98 100644 --- a/torch/_dynamo/variables/builder.py +++ b/torch/_dynamo/variables/builder.py @@ -640,7 +640,9 @@ def build_key_value(i, k, v): source=self.source, ) else: - result = ConstDictVariable(result, type(value), source=self.source) + result = ConstDictVariable( + result, user_cls=type(value), source=self.source + ) return self.set_source_and_track_mutable(value, result) elif isinstance(value, torch.nn.Module): diff --git a/torch/_dynamo/variables/dicts.py b/torch/_dynamo/variables/dicts.py index e8323ec6f70d76..4626f68bb44c06 100644 --- a/torch/_dynamo/variables/dicts.py +++ b/torch/_dynamo/variables/dicts.py @@ -14,7 +14,7 @@ from ..eval_frame import skip_code from ..exc import raise_observed_exception, unimplemented from ..guards import GuardBuilder, install_guard -from ..source import AttrSource, GetItemSource +from ..source import AttrSource, GetItemSource, is_from_local_source from ..utils import dict_keys, dict_values, istype, specialize_symnode from .base import MutableLocal, VariableTracker from .constant import ConstantVariable @@ -128,8 +128,18 @@ def __eq__(self, other: "ConstDictVariable._HashableTracker") -> bool: return Hashable._eq_impl(self.underlying_value, other) def __init__( - self, items: Dict[VariableTracker, VariableTracker], user_cls=dict, **kwargs + self, + items: Dict[VariableTracker, VariableTracker], + user_cls=dict, + **kwargs, ) -> None: + # .clone() pass these arguments in kwargs but they're recreated a few + # lines below + if "original_items" in kwargs: + kwargs.pop("original_items") + if "should_reconstruct_all" in kwargs: + kwargs.pop("should_reconstruct_all") + super().__init__(**kwargs) Hashable = ConstDictVariable._HashableTracker @@ -145,6 +155,10 @@ def make_hashable(key): return key if isinstance(key, Hashable) else Hashable(key) self.items = {make_hashable(x): v for x, v in items.items()} + # need to reconstruct everything if the dictionary is an intermediate value + # or if a pop/delitem was executed + self.should_reconstruct_all = not is_from_local_source(self.source) + self.original_items = items.copy() self.user_cls = user_cls def as_proxy(self): @@ -189,6 +203,9 @@ def len(self): ] ) + def _maybe_realize(self, item): + return item.realize() if item else item + def reconstruct(self, codegen): # instructions to load collections.OrderedDict if necessary if self.user_cls is collections.OrderedDict: @@ -201,20 +218,29 @@ def reconstruct(self, codegen): ) ) # instructions to build the dict keys and values + num_args = 0 for key, value in self.items.items(): - codegen(key.vt) - codegen(value) + # We can safely call realize() here as it won't introduce any new guards + is_new_item = ( + self._maybe_realize(self.original_items.get(key.vt)) != value.realize() + ) + + if is_new_item or self.should_reconstruct_all: + codegen(key.vt) + codegen(value) + num_args += 1 + # BUILD_MAP and calling collections.OrderedDict if necessary if self.user_cls is collections.OrderedDict: codegen.extend_output( [ - create_instruction("BUILD_MAP", arg=len(self.items)), + create_instruction("BUILD_MAP", arg=num_args), *create_call_function(1, False), ] ) # BUILD_MAP only if user_cls is dict else: - codegen.append_output(create_instruction("BUILD_MAP", arg=len(self.items))) + codegen.append_output(create_instruction("BUILD_MAP", arg=num_args)) def getitem_const_raise_exception_if_absent( self, tx: "InstructionTranslator", arg: VariableTracker @@ -288,6 +314,7 @@ def call_method( self.items[Hashable(args[0])] = args[1] return ConstantVariable.create(None) elif name == "__delitem__" and arg_hashable and self.mutable_local: + self.should_reconstruct_all = True tx.output.side_effects.mutation(self) self.items.__delitem__(Hashable(args[0])) return ConstantVariable.create(None) @@ -298,9 +325,11 @@ def call_method( else: return args[1] elif name == "pop" and arg_hashable and self.mutable_local: + self.should_reconstruct_all = True tx.output.side_effects.mutation(self) return self.items.pop(Hashable(args[0])) elif name == "clear": + self.should_reconstruct_all = True tx.output.side_effects.mutation(self) self.items.clear() return ConstantVariable.create(None) From d9185e1ee6f6aebfa521c63191054216cf940bf2 Mon Sep 17 00:00:00 2001 From: leslie-fang-intel Date: Thu, 12 Sep 2024 22:32:08 -0700 Subject: [PATCH 0913/1018] SKIP llama for dynamic size testing (#135960) Running Torchbench llama with dynamic size failed with ``` File "/localdisk/leslie/torch_inductor_community/pytorch/torch/fx/experimental/symbolic_shapes.py", line 4182, in produce_guards raise ConstraintViolationError( torch.fx.experimental.symbolic_shapes.ConstraintViolationError: Constraints violated (L['inputs'][0].size()[0])! For more information, run with TORCH_LOGS="+dynamic". - Not all values of RelaxedUnspecConstraint(L['inputs'][0].size()[0]) are valid because L['inputs'][0].size()[0] was inferred to be a constant (32). ``` Skip this model for marking dynamic dim. Pull Request resolved: https://github.com/pytorch/pytorch/pull/135960 Approved by: https://github.com/ezyang --- benchmarks/dynamo/common.py | 1 + 1 file changed, 1 insertion(+) diff --git a/benchmarks/dynamo/common.py b/benchmarks/dynamo/common.py index 80e7d74cef7701..81ace9e9c1d197 100644 --- a/benchmarks/dynamo/common.py +++ b/benchmarks/dynamo/common.py @@ -158,6 +158,7 @@ class CI(NamedTuple): "detectron2_fasterrcnn_r_50_fpn", "hf_T5_generate", "Reformer", + "llama", }.union(INTERNAL_CI_SKIP_DYNAMIC_BATCH_ONLY) # These models currently fail accuracy with eager Adam optimizer From 1054a5f6fbe4e290e7ebc60d704ab71a7d060420 Mon Sep 17 00:00:00 2001 From: Will Feng Date: Sat, 14 Sep 2024 14:56:30 -0700 Subject: [PATCH 0914/1018] [Traceable FSDP2] Ignore FSDP2 forward hook side-effects in AC; Support FSDP2 + AC (#134997) > Ignore FSDP2 forward hook side-effects in AC Under AC, FSDP2 does not rely on forward hook to all-gather weights to do recomputation, instead it relies on pre-backward hook to do this job: https://github.com/pytorch/pytorch/blob/451eaf0ff247090ca5a9648fd1e17c3c011737e1/torch/distributed/_composable/fsdp/_fsdp_state.py#L219-L220 So when we use `speculate_subgraph` to trace the utils.checkpoint AC region, we don't actually need to worry about FSDP2 forward hook's side effects and can safely ignore it, because we are not and we don't expect to re-run the FSDP2 forward hook during backward recomputation. ---- Test commands: - `pytest -rA test/distributed/_composable/fsdp/test_fully_shard_compile.py::TestFullyShardCompile::test_nested_fully_shard_backend_inductor` - `pytest -rA test/distributed/_composable/fsdp/test_fully_shard_compile.py::TestFullyShardCompile::test_transformer_backend_inductor` Pull Request resolved: https://github.com/pytorch/pytorch/pull/134997 Approved by: https://github.com/zou3519 ghstack dependencies: #135727 --- .../fsdp/test_fully_shard_compile.py | 106 +++++++++++++----- torch/_dynamo/output_graph.py | 10 +- torch/_dynamo/side_effects.py | 26 +++++ torch/_dynamo/variables/functions.py | 15 ++- torch/_dynamo/variables/higher_order_ops.py | 32 +++++- 5 files changed, 157 insertions(+), 32 deletions(-) diff --git a/test/distributed/_composable/fsdp/test_fully_shard_compile.py b/test/distributed/_composable/fsdp/test_fully_shard_compile.py index 2f038c713ff225..1bf26d1b67b3e6 100644 --- a/test/distributed/_composable/fsdp/test_fully_shard_compile.py +++ b/test/distributed/_composable/fsdp/test_fully_shard_compile.py @@ -154,25 +154,64 @@ def f(x): torch.compile(f, backend="aot_eager")(x) self.assertEqual(x, ref_x) - def _assert_no_aliased_graph_inputs(self, graph: torch.fx.Graph) -> None: + def _assert_no_aliased_unsharded_params_in_graph_inputs( + self, model, graph: torch.fx.Graph + ) -> None: + # FSDP2 unsharded params are mutated in the graph without going through functionalization. + # Therefore, we want to make sure they don't have aliases in the graph inputs, to make it easier + # for us to do the replacement of unsharded params with the all-gathered temporary buffer directly + # in downstream users in the graph. storage_id_to_graph_inputs = defaultdict(list) + unsharded_param_graph_inputs = set() for node in graph.nodes: - if node.op == "placeholder" and isinstance( - node.meta.get("val", None), torch.Tensor + if ( + node.op == "call_function" + and node.target + in [ + torch.ops.inductor.resize_storage_bytes_.default, + torch.ops.fsdp.copy_.default, + ] + and node.args[0].op == "placeholder" ): - storage_id_to_graph_inputs[ - id(node.meta["val"].untyped_storage()) - ].append(node) - no_aliased_graph_inputs = True + unsharded_param_graph_inputs.add(node.args[0]) + assert len(unsharded_param_graph_inputs) > 0 + assert len(unsharded_param_graph_inputs) == len( + list(model.parameters()) + ), """\ +Expected all model parameters to be wrapped by FSDP2 and +have their unsharded version as graph input, but it's not true! +""" + no_aliased_unsharded_params_in_graph_inputs = True err_msg = "" for aliased_graph_inputs in storage_id_to_graph_inputs.values(): - if len(aliased_graph_inputs) > 1: - no_aliased_graph_inputs = False + if len(aliased_graph_inputs) > 1 and any( + x in unsharded_param_graph_inputs for x in aliased_graph_inputs + ): + no_aliased_unsharded_params_in_graph_inputs = False err_msg += f"""\n -Found aliased graph inputs: {aliased_graph_inputs}, +Found aliased unsharded param in graph inputs: {aliased_graph_inputs}, val.shape: {[node.meta['val'].shape for node in aliased_graph_inputs]}, """ - self.assertTrue(no_aliased_graph_inputs, err_msg) + self.assertTrue(no_aliased_unsharded_params_in_graph_inputs, err_msg) + + def _remove_fsdp2_unsharded_param_graph_input_usage_with_optional_checks( + self, model, fullgraph + ): + def _run_with_checks(graph, orig_fn): + self._assert_no_aliased_unsharded_params_in_graph_inputs(model, graph) + orig_fn(graph) + + if fullgraph: + return mock.patch.object( + comms, + "remove_fsdp2_unsharded_param_graph_input_usage", + functools.partial( + _run_with_checks, + orig_fn=comms.remove_fsdp2_unsharded_param_graph_input_usage, + ), + ) + else: + return contextlib.nullcontext() def _check_fsdp_copy_and_resize_ops_count_in_graph( self, @@ -359,7 +398,11 @@ def inductor_code_check_fsdp_reduce_scatter( return file_check def _test_traceable_fsdp( - self, model_init_fn, input_creation_fn, backend, fullgraph + self, + model_init_fn, + input_creation_fn, + backend, + fullgraph, ): def compiler_fn(compiled_autograd_backend): def _fn(gm): @@ -401,13 +444,18 @@ def test_compiled(): # FSDP2 does lazy init using 1st run, so run it once to init using eager mode run_iters(model, optim, n_iter=1) - model_compiled = torch.compile(model, backend=backend, fullgraph=fullgraph) - res = run_iters( - model_compiled, - optim, - compiled_autograd_backend=backend, - ) - return res + with self._remove_fsdp2_unsharded_param_graph_input_usage_with_optional_checks( + model, fullgraph + ): + model_compiled = torch.compile( + model, backend=backend, fullgraph=fullgraph + ) + res = run_iters( + model_compiled, + optim, + compiled_autograd_backend=backend, + ) + return res def test_eager(): model, optim = model_init_fn() @@ -421,7 +469,8 @@ def test_eager(): inline_inbuilt_nn_modules=True, skip_fsdp_hooks=False, ), torch._functorch.config.patch( - recompute_views=True, cse=False + recompute_views=True, + cse=False, ), torch._inductor.config.patch( reorder_for_compute_comm_overlap=True, reorder_for_compute_comm_overlap_passes=[ @@ -429,9 +478,6 @@ def test_eager(): "raise_comms", "reorder_compute_for_overlap", ], - post_grad_custom_pre_pass=self._assert_no_aliased_graph_inputs - if fullgraph - else None, ): losses_compiled = test_compiled() losses_eager = test_eager() @@ -677,7 +723,9 @@ def test_nested_fully_shard_backend_inductor(self): "Expected at least 3 separate lowerings to Triton code, which means at least 1 graph break in FWD graph", ) - def _create_transformer_factory_fns(self, all_requires_grad): + def _create_transformer_factory_fns( + self, all_requires_grad, *, activation_checkpoint=False + ): seq_len = 16 vocab_size = 8 n_layers = 3 @@ -689,6 +737,7 @@ def model_init_fn(): model_args = ModelArgs( vocab_size=vocab_size, n_layers=n_layers, + checkpoint_activations=activation_checkpoint, ) model = Transformer(model_args) if not all_requires_grad: @@ -775,9 +824,11 @@ def test_transformer_backend_aot_eager_decomp_partition(self): @torch._inductor.config.patch(fallback_random=True) def test_transformer_backend_inductor(self): # TODO: enable fullgraph=False case - for fullgraph, all_requires_grad in itertools.product([True], [True, False]): + for fullgraph, all_requires_grad, activation_checkpoint in itertools.product( + [True], [True, False], [True, False] + ): log.warning( - f"fullgraph={fullgraph}, all_requires_grad={all_requires_grad}" # noqa: G004, G001 + f"fullgraph={fullgraph}, all_requires_grad={all_requires_grad}, activation_checkpoint={activation_checkpoint}" # noqa: G004, G001 ) with self._maybe_add_graph_break_to_sdpa( fullgraph @@ -802,7 +853,8 @@ def test_transformer_backend_inductor(self): _, triton_codes = run_and_get_code( lambda: self._test_traceable_fsdp( *self._create_transformer_factory_fns( - all_requires_grad=all_requires_grad + all_requires_grad=all_requires_grad, + activation_checkpoint=activation_checkpoint, ), "inductor", fullgraph=fullgraph, diff --git a/torch/_dynamo/output_graph.py b/torch/_dynamo/output_graph.py index 0ea19e1cad6586..76be81a088c3c9 100644 --- a/torch/_dynamo/output_graph.py +++ b/torch/_dynamo/output_graph.py @@ -326,7 +326,7 @@ def __init__( ] = collections.defaultdict(list) # Stores the full fqn of a param or buffer to the relevant source. self.param_name_to_source: Optional[Dict[str, Source]] = {} - self.side_effects = SideEffects() + self.side_effects = SideEffects(self) # Cached variable trackers. This makes symbolic analysis of LOAD_GLOBAL # and LOAD_ATTR for same python objects free. self.variable_tracker_cache = VariableTrackerCache() @@ -1834,6 +1834,14 @@ def __init__( # Dicts maintain the order of args for the HigherOrderOperator call. self.lifted_freevars = {} self.prev_inst = None + # True if this tracer is currently tracing into torch.utils.checkpoint + # as part of speculate_subgraph. + self.under_activation_checkpoint = False + # True if we want to allow side-effects (doesn't throw error on their existence) + # during this tracer's tracing of torch.utils.checkpoint (via speculate_subgraph). + # Only safe if we know for sure that *NOT* replaying these side-effects during + # backward recomputation of the checkpoint region doesn't affect its correctness. + self.allow_side_effects_under_checkpoint = False self._cur_code = None self._orig_gm_meta = None diff --git a/torch/_dynamo/side_effects.py b/torch/_dynamo/side_effects.py index fb5782853784b6..3b83eb17d407f4 100644 --- a/torch/_dynamo/side_effects.py +++ b/torch/_dynamo/side_effects.py @@ -1,7 +1,9 @@ # mypy: allow-untyped-defs +import contextlib import functools import inspect import warnings +import weakref from collections.abc import MutableMapping from typing import Any, Dict, List, Optional, Type, Union @@ -79,6 +81,7 @@ class SideEffects: def __init__( self, + output_graph, id_to_variable=None, store_attr_mutations=None, keepalive=None, @@ -86,6 +89,7 @@ def __init__( tensor_hooks=None, ): super().__init__() + self.output_graph_weakref = weakref.ref(output_graph) self.id_to_variable = id_to_variable or {} self.store_attr_mutations = store_attr_mutations or {} self.keepalive = keepalive or [] @@ -130,6 +134,7 @@ def diff(self, other: "SideEffects") -> Optional[str]: def clone(self): """Create a shallow copy""" return self.__class__( + output_graph=self.output_graph_weakref(), id_to_variable=dict(self.id_to_variable), store_attr_mutations={ k: dict(v) for k, v in self.store_attr_mutations.items() @@ -145,6 +150,14 @@ def __contains__(self, item): def __getitem__(self, item): return self.id_to_variable[id(item)] + def should_allow_side_effects_under_checkpoint(self): + output_graph = self.output_graph_weakref() + return ( + output_graph + and output_graph.current_tx.output.current_tracer.under_activation_checkpoint + and output_graph.current_tx.output.current_tracer.allow_side_effects_under_checkpoint + ) + def check_allowed_side_effect(self, item): from torch._dynamo.variables.misc import AutogradFunctionContextVariable @@ -152,6 +165,8 @@ def check_allowed_side_effect(self, item): # These are benign. if isinstance(item, AutogradFunctionContextVariable): return True + if self.should_allow_side_effects_under_checkpoint(): + return True if not is_side_effect_safe(item.mutable_local): unimplemented( "HigherOrderOperator: Mutating a variable not in the current scope (SideEffects)" @@ -725,3 +740,14 @@ def is_empty(self): def clear(self): self.keepalive.clear() self.id_to_variable.clear() + + +@contextlib.contextmanager +def allow_side_effects_under_checkpoint(tx: "InstructionTranslator"): # type: ignore[name-defined] # noqa: F821 + assert tx.output.current_tracer.under_activation_checkpoint + orig_val = tx.output.current_tracer.allow_side_effects_under_checkpoint + try: + tx.output.current_tracer.allow_side_effects_under_checkpoint = True + yield + finally: + tx.output.current_tracer.allow_side_effects_under_checkpoint = orig_val diff --git a/torch/_dynamo/variables/functions.py b/torch/_dynamo/variables/functions.py index c7324960fb989d..6c1e4a13c94595 100644 --- a/torch/_dynamo/variables/functions.py +++ b/torch/_dynamo/variables/functions.py @@ -322,7 +322,20 @@ def call_function( return invoke_and_store_as_constant( tx, self.fn, self.get_name(), args, kwargs ) - + if ( + tx.output.current_tracer.under_activation_checkpoint + and not tx.output.current_tracer.allow_side_effects_under_checkpoint + ): + try: + from torch.distributed._composable.fsdp._fsdp_state import FSDPState + except Exception: + FSDPState = None + if FSDPState is not None and self.fn in [ + FSDPState._pre_forward, + FSDPState._post_forward, + ]: + with torch._dynamo.side_effects.allow_side_effects_under_checkpoint(tx): + return super().call_function(tx, args, kwargs) return super().call_function(tx, args, kwargs) diff --git a/torch/_dynamo/variables/higher_order_ops.py b/torch/_dynamo/variables/higher_order_ops.py index 0a384246736736..c8d82b15134461 100644 --- a/torch/_dynamo/variables/higher_order_ops.py +++ b/torch/_dynamo/variables/higher_order_ops.py @@ -71,6 +71,16 @@ def dynamo_enable_grad(tx: "InstructionTranslator", enable=True): GradModeVariable.create(tx, org_value, initialized=True) +@contextlib.contextmanager +def dynamo_under_activation_checkpoint(tx: "InstructionTranslator"): + orig_val = tx.output.current_tracer.under_activation_checkpoint + try: + tx.output.current_tracer.under_activation_checkpoint = True + yield + finally: + tx.output.current_tracer.under_activation_checkpoint = orig_val + + def only_consist_of(var, types, allow_none=False): if isinstance(var, types): return True @@ -388,6 +398,7 @@ def speculate_subgraph( set_subgraph_inputs="automatic", restore_side_effects=True, should_flatten_outputs=False, + under_activation_checkpoint=False, # Pass in an originating tracer - this is needed for preserving context # across fwd-bwd for autograd.Function tracer=None, @@ -439,6 +450,11 @@ def speculate_subgraph( if enable_grad is not None else contextlib.nullcontext() ) + checkpoint_ctx = ( + dynamo_under_activation_checkpoint(tx) + if under_activation_checkpoint + else contextlib.nullcontext() + ) # For handling side effects, we can make an argument that we don't # have to do anything here. The side effects infra does a good job @@ -458,7 +474,7 @@ def speculate_subgraph( if restore_side_effects: prev_side_effects = tx.output.side_effects.clone() - with autograd_ctx: + with autograd_ctx, checkpoint_ctx: output = f.call_function(tx, args, sub_kwargs) if restore_side_effects: @@ -1504,7 +1520,12 @@ def call_function( class WrapHigherOrderVariable(TorchHigherOrderOperatorVariable): def create_wrapped_node( - self, tx: "InstructionTranslator", args, kwargs, description + self, + tx: "InstructionTranslator", + args, + kwargs, + description, + under_activation_checkpoint=False, ): # See NOTE [HigherOrderOperator tracing design] for more details @@ -1520,6 +1541,7 @@ def create_wrapped_node( description, source_target=self.value, should_flatten_outputs=True, + under_activation_checkpoint=under_activation_checkpoint, ) body_gmod = torch.fx.GraphModule(tx.output.nn_modules, body_graph) @@ -1856,7 +1878,11 @@ def call_function( treespec, checkpointed_gmod, ) = self.create_wrapped_node( - tx, args, gmod_kwargs, "torch.utils.checkpoint.checkpoint" + tx, + args, + gmod_kwargs, + "torch.utils.checkpoint.checkpoint", + under_activation_checkpoint=True, ) if context_fn is not None: checkpointed_gmod.meta["_checkpoint_context_fn"] = context_fn From e63e794788b32505448135b2113971a1c2d0af51 Mon Sep 17 00:00:00 2001 From: PyTorch UpdateBot Date: Sun, 15 Sep 2024 04:31:34 +0000 Subject: [PATCH 0915/1018] [audio hash update] update the pinned audio hash (#136106) This PR is auto-generated nightly by [this action](https://github.com/pytorch/pytorch/blob/main/.github/workflows/nightly.yml). Update the pinned audio hash. Pull Request resolved: https://github.com/pytorch/pytorch/pull/136106 Approved by: https://github.com/pytorchbot --- .github/ci_commit_pins/audio.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/ci_commit_pins/audio.txt b/.github/ci_commit_pins/audio.txt index d4276f969ac8ca..c835e7c2838715 100644 --- a/.github/ci_commit_pins/audio.txt +++ b/.github/ci_commit_pins/audio.txt @@ -1 +1 @@ -97ed7b36b7a741253d4e41e4da3c901d83294503 +ba696ea3dfec4cbe693bf06a84c75dc196077f5b From 343d15308b8ce91817f52d140c3872fb882433f8 Mon Sep 17 00:00:00 2001 From: cyy Date: Sun, 15 Sep 2024 05:28:19 +0000 Subject: [PATCH 0916/1018] Fix redundant move warnings by g++ (#134987) Fixes #ISSUE_NUMBER Pull Request resolved: https://github.com/pytorch/pytorch/pull/134987 Approved by: https://github.com/ezyang --- torch/csrc/jit/frontend/tree_views.h | 4 ++++ torch/csrc/jit/python/pybind_utils.h | 4 ++++ 2 files changed, 8 insertions(+) diff --git a/torch/csrc/jit/frontend/tree_views.h b/torch/csrc/jit/frontend/tree_views.h index d7e758e34e38b8..9f3a926fe5a908 100644 --- a/torch/csrc/jit/frontend/tree_views.h +++ b/torch/csrc/jit/frontend/tree_views.h @@ -1262,7 +1262,11 @@ inline Expr pep604union_to_union(const Expr& expr) { expr.range(), Var::create(expr.range(), Ident::create(expr.range(), "Union")), List::create(expr.range(), members)); +#if defined(__clang__) return std::move(synthesised_union); +#else + return synthesised_union; +#endif } } // namespace torch::jit diff --git a/torch/csrc/jit/python/pybind_utils.h b/torch/csrc/jit/python/pybind_utils.h index 96a6f23d6dad54..eee1cf05b1201c 100644 --- a/torch/csrc/jit/python/pybind_utils.h +++ b/torch/csrc/jit/python/pybind_utils.h @@ -1073,7 +1073,11 @@ inline py::object createPyObjectForStack(Stack&& stack) { return_values[ret] = toPyObject(std::move(stack[ret])); } +#if defined(__clang__) return std::move(return_values); +#else + return return_values; +#endif } // TODO: Remove once we clean up the GraphExecutor usage. From c509f72329dae380d70df7651b217c89cc367955 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Sun, 15 Sep 2024 05:32:38 +0000 Subject: [PATCH 0917/1018] Revert "[Pytorch] Consolidate Strobelight compile time profiler between OSS and fbcode (#135953)" This reverts commit b8637503c036abb898f6b880b325aeffe6f09c03. Reverted https://github.com/pytorch/pytorch/pull/135953 on behalf of https://github.com/kollasb due to Broke internal module factory compatibility, revert from Phabricator failed ([comment](https://github.com/pytorch/pytorch/pull/135953#issuecomment-2351381777)) --- .../examples/compile_time_profile_example.py | 5 +++-- torch/_utils_internal.py | 20 ++++++++++--------- 2 files changed, 14 insertions(+), 11 deletions(-) diff --git a/torch/_strobelight/examples/compile_time_profile_example.py b/torch/_strobelight/examples/compile_time_profile_example.py index 26f659a1dbf229..f158d00a29f929 100644 --- a/torch/_strobelight/examples/compile_time_profile_example.py +++ b/torch/_strobelight/examples/compile_time_profile_example.py @@ -1,10 +1,11 @@ # mypy: allow-untyped-defs import torch -from torch._utils_internal import enable_compiletime_strobelight +from torch._strobelight.compile_time_profiler import StrobelightCompileTimeProfiler if __name__ == "__main__": - enable_compiletime_strobelight() + # You can pass TORCH_COMPILE_STROBELIGHT=True instead. + StrobelightCompileTimeProfiler.enable() def fn(x, y, z): return x * y + z diff --git a/torch/_utils_internal.py b/torch/_utils_internal.py index 0c99601392f3d2..f254217452061f 100644 --- a/torch/_utils_internal.py +++ b/torch/_utils_internal.py @@ -12,6 +12,17 @@ log = logging.getLogger(__name__) +if os.environ.get("TORCH_COMPILE_STROBELIGHT", False): + import shutil + + if not shutil.which("strobeclient"): + log.info( + "TORCH_COMPILE_STROBELIGHT is true, but seems like you are not on a FB machine." + ) + else: + log.info("Strobelight profiler is enabled via environment variable") + StrobelightCompileTimeProfiler.enable() + # this arbitrary-looking assortment of functionality is provided here # to have a central place for overrideable behavior. The motivating # use is the FB build environment, where this source file is replaced @@ -64,15 +75,6 @@ def throw_abstract_impl_not_imported_error(opname, module, context): ) -def enable_compiletime_strobelight(): - StrobelightCompileTimeProfiler.enable() - - -if os.environ.get("TORCH_COMPILE_STROBELIGHT", False): - log.info("Strobelight profiler is enabled via environment variable") - enable_compiletime_strobelight() - - # NB! This treats "skip" kwarg specially!! def compile_time_strobelight_meta(phase_name): def compile_time_strobelight_meta_inner(function): From 4bc7d452caff09219dcf532cc4af9696d08f6f6b Mon Sep 17 00:00:00 2001 From: Tugsbayasgalan Manlaibaatar Date: Fri, 13 Sep 2024 22:09:00 -0700 Subject: [PATCH 0918/1018] Deprecate _preserve_ops and consolidate with decomp_table (#135080) In this PR, we deprecate _preserve_ops feature in run_decomposition API. We can't kill this API completely because Executorch team depends on it. As the syncing between two repos is non-trivial, I just leave this argument as deprecated for now. In the next PR, i will immediately remove it. After this PR, run_decompositions will only decompose what's inside the decomp table and preserve the rest by default. Note that this feature is only rolled out to OSS for now. Old code path is protected under IS_FBCODE flag. Differential Revision: [D62163161](https://our.internmc.facebook.com/intern/diff/D62163161/) Pull Request resolved: https://github.com/pytorch/pytorch/pull/135080 Approved by: https://github.com/justinchuby, https://github.com/avikchaudhuri, https://github.com/bdhirsh --- test/export/test_export.py | 300 ++++++++++++++---- .../test_export_training_ir_to_run_decomp.py | 18 +- test/test_decomp.py | 4 +- torch/_decomp/__init__.py | 190 ++++++++++- torch/export/experimental/__init__.py | 4 +- torch/export/exported_program.py | 214 ++++++++----- torch/onnx/_internal/exporter/_decomp.py | 14 +- torch/onnx/_internal/exporter/_fx_passes.py | 6 +- 8 files changed, 593 insertions(+), 157 deletions(-) diff --git a/test/export/test_export.py b/test/export/test_export.py index cfd9ef0d46f18e..e2711e120ea602 100644 --- a/test/export/test_export.py +++ b/test/export/test_export.py @@ -18,7 +18,11 @@ import torch.nn.functional as F from functorch.experimental.control_flow import cond, map from torch import Tensor -from torch._decomp import get_decompositions +from torch._decomp import ( + _decomp_table_to_post_autograd_aten, + core_aten_decompositions, + get_decompositions, +) from torch._dynamo.test_case import TestCase from torch._dynamo.testing import normalize_gm from torch._export.pass_base import _ExportPassBaseDeprecatedDoNotUse @@ -1068,14 +1072,17 @@ def forward(self, x): x = self.linear(x) return torch.ops.aten.chunk.default(x, 3, 0) - gm = ( - torch.export.export( - Foo(), - (torch.randn(3, 3),), + ep = torch.export.export(Foo(), (torch.randn(3, 3),)) + if IS_FBCODE: + ep = ep.run_decompositions( + {}, _preserve_ops=(torch.ops.aten.linear.default,) ) - .run_decompositions({}, _preserve_ops=(torch.ops.aten.linear.default,)) - .graph_module - ) + else: + decomp_table = _decomp_table_to_post_autograd_aten() + del decomp_table[torch.ops.aten.linear.default] + ep = ep.run_decompositions(decomp_table) + + gm = ep.graph_module # linear is CompositeImplicitAutograd functional op so we should preserve it # chunk is CompositeImplicitAutograd non-functional op we decompose. self.assertExpectedInline( @@ -1436,6 +1443,7 @@ def forward(self, x, y): "dy - 6 = 6" not in exc.args[0] ) # don't suggest fix for non-root dim + @unittest.skip("See https://github.com/pytorch/pytorch/issues/135759") def test_keep_composite_ops_invalid(self): class Foo(torch.nn.Module): def __init__(self) -> None: @@ -1446,32 +1454,29 @@ def forward(self, x): x = self.linear(x) return torch.ops.aten.chunk.default(x, 3, 0) - with self.assertRaisesRegex( - RuntimeError, "aten.chunk.default is a mutating/aliasing op" - ): + def _(*args, **kwargs): + return NotImplemented + + with self.assertWarnsRegex(UserWarning, "The op aten.chunk.default"): _ = torch.export.export( Foo(), (torch.randn(3, 3),), - ).run_decompositions({}, _preserve_ops=(torch.ops.aten.chunk.default,)) + ).run_decompositions({torch.ops.aten.chunk.default: _}) - with self.assertRaisesRegex( - RuntimeError, "aten.sym_size.default is a metadata query function" - ): + with self.assertWarnsRegex(UserWarning, "The op aten.sym_size.default"): _ = torch.export.export( Foo(), (torch.randn(3, 3),), - ).run_decompositions({}, _preserve_ops=(torch.ops.aten.sym_size.default,)) + ).run_decompositions({torch.ops.aten.sym_size.default: _}) - with self.assertRaisesRegex( - RuntimeError, - "We can't detect aten.native_batch_norm.default as a functional op statically", + with self.assertWarnsRegex( + UserWarning, + "The op aten.native_batch_norm.default", ): _ = torch.export.export( Foo(), (torch.randn(3, 3),), - ).run_decompositions( - {}, _preserve_ops=(torch.ops.aten.native_batch_norm.default,) - ) + ).run_decompositions({torch.ops.aten.native_batch_norm.default: _}) def test_keep_composite_ops_linear_convd(self): class MyLinear(torch.nn.Module): @@ -1499,10 +1504,14 @@ def forward(self, x, y): ep = torch.export.export( Foo(), (torch.randn(20, 16, 50, 100), torch.randn(20, 16, 50)) ) - ep_has_linear_convd = ep.run_decompositions( - decomp_table={}, - _preserve_ops=testing._COMPOSITE_OPS_THAT_CAN_BE_PRESERVED_TESTING_ONLY, - ) + if IS_FBCODE: + ep_has_linear_convd = ep.run_decompositions( + {}, + _preserve_ops=testing._COMPOSITE_OPS_THAT_CAN_BE_PRESERVED_TESTING_ONLY, + ) + else: + ep_has_linear_convd = ep.run_decompositions({}) + self.assertExpectedInline( str(ep_has_linear_convd.graph_module.code).strip(), """\ @@ -1516,13 +1525,19 @@ def forward(self, p_conv_weight, p_conv_bias, p_conv1d_weight, p_conv1d_bias, c_ return (add,)""", ) - ep_has_convd = ep.run_decompositions( - decomp_table=None, - _preserve_ops=[ - torch.ops.aten.conv2d.default, - torch.ops.aten.conv1d.default, - ], - ) + if IS_FBCODE: + ep_has_convd = ep.run_decompositions( + _preserve_ops=( + torch.ops.aten.conv2d.default, + torch.ops.aten.conv1d.default, + ) + ) + else: + decomp_table = core_aten_decompositions() + del decomp_table[torch.ops.aten.conv2d.default] + del decomp_table[torch.ops.aten.conv1d.default] + + ep_has_convd = ep.run_decompositions(decomp_table=decomp_table) self.assertExpectedInline( str(ep_has_convd.graph_module.code).strip(), """\ @@ -1538,10 +1553,15 @@ def forward(self, p_conv_weight, p_conv_bias, p_conv1d_weight, p_conv1d_bias, c_ add = torch.ops.aten.add.Tensor(cos, sum_1); cos = sum_1 = None return (add,)""", ) + if IS_FBCODE: + ep_has_convd = ep_has_convd.run_decompositions( + _preserve_ops=(torch.ops.aten.conv2d.default,) + ) + else: + decomp_table = core_aten_decompositions() + del decomp_table[torch.ops.aten.conv2d.default] - ep_has_convd = ep_has_convd.run_decompositions( - decomp_table=None, _preserve_ops=[torch.ops.aten.conv2d.default] - ) + ep_has_convd = ep_has_convd.run_decompositions(decomp_table=decomp_table) self.assertExpectedInline( str(ep_has_convd.graph_module.code).strip(), """\ @@ -1584,10 +1604,16 @@ def forward(self, x, y): ep = torch.export.export_for_training( Foo(), (torch.randn(20, 16, 50, 100), torch.randn(20, 16, 50)) ) - ep_has_linear_convd = ep.run_decompositions( - decomp_table={}, - _preserve_ops=testing._COMPOSITE_OPS_THAT_CAN_BE_PRESERVED_TESTING_ONLY, - ) + + if IS_FBCODE: + ep_has_linear_convd = ep.run_decompositions( + {}, + _preserve_ops=testing._COMPOSITE_OPS_THAT_CAN_BE_PRESERVED_TESTING_ONLY, + ) + else: + ep_has_linear_convd = ep.run_decompositions( + decomp_table={}, + ) self.assertExpectedInline( str(ep_has_linear_convd.graph_module.code).strip(), @@ -1602,13 +1628,19 @@ def forward(self, p_conv_weight, p_conv_bias, p_conv1d_weight, p_conv1d_bias, b_ return (add,)""", ) - ep_has_convd = ep.run_decompositions( - decomp_table=None, - _preserve_ops=[ - torch.ops.aten.conv2d.default, - torch.ops.aten.conv1d.default, - ], - ) + if IS_FBCODE: + ep_has_convd = ep.run_decompositions( + _preserve_ops=( + torch.ops.aten.conv2d.default, + torch.ops.aten.conv1d.default, + ) + ) + else: + decomp_table = core_aten_decompositions() + del decomp_table[torch.ops.aten.conv2d.default] + del decomp_table[torch.ops.aten.conv1d.default] + + ep_has_convd = ep.run_decompositions(decomp_table=decomp_table) self.assertExpectedInline( str(ep_has_convd.graph_module.code).strip(), @@ -1626,9 +1658,14 @@ def forward(self, p_conv_weight, p_conv_bias, p_conv1d_weight, p_conv1d_bias, b_ return (add,)""", ) - ep_has_convd = ep_has_convd.run_decompositions( - decomp_table=None, _preserve_ops=[torch.ops.aten.conv2d.default] - ) + if IS_FBCODE: + ep_has_convd = ep_has_convd.run_decompositions( + _preserve_ops=(torch.ops.aten.conv2d.default,) + ) + else: + decomp_table = core_aten_decompositions() + del decomp_table[torch.ops.aten.conv2d.default] + ep_has_convd = ep_has_convd.run_decompositions(decomp_table=decomp_table) self.assertExpectedInline( str(ep_has_convd.graph_module.code).strip(), @@ -1646,6 +1683,57 @@ def forward(self, p_conv_weight, p_conv_bias, p_conv1d_weight, p_conv1d_bias, b_ return (add,)""", ) + @unittest.skip("See https://github.com/pytorch/pytorch/issues/135759") + def test_error_when_passing_mutating_primitive_op(self): + class Foo(torch.nn.Module): + def forward(self, x): + return x.sin() + + ep = export(Foo(), (torch.ones(3, 3),)) + with self.assertWarnsRegex( + UserWarning, + "The op aten.index_put_.default", + ): + ep.run_decompositions({torch.ops.aten.index_put_.default: None}) + + def test_if_post_autograd_op_preserved(self): + class Foo(torch.nn.Module): + def forward(self, x): + return x.sin() + x.sum() + + ep = export(Foo(), (torch.ones(3, 3),)) + if IS_FBCODE: + ep_preserve_sum = ep.run_decompositions( + _preserve_ops=(torch.ops.aten.sum.default,) + ) + else: + decomp_table = core_aten_decompositions() + del decomp_table[torch.ops.aten.sum.default] + ep_preserve_sum = ep.run_decompositions(decomp_table) + + # Even though we are decomposing to core aten which should make + # sum into sum.dim_IntList, we explicitly marked it to not do that. + self.assertExpectedInline( + str(ep_preserve_sum.graph_module.code).strip(), + """\ +def forward(self, x): + sin = torch.ops.aten.sin.default(x) + sum_1 = torch.ops.aten.sum.default(x); x = None + add = torch.ops.aten.add.Tensor(sin, sum_1); sin = sum_1 = None + return (add,)""", + ) + + ep_no_preserve_sum = ep.run_decompositions() + self.assertExpectedInline( + str(ep_no_preserve_sum.graph_module.code).strip(), + """\ +def forward(self, x): + sin = torch.ops.aten.sin.default(x) + sum_1 = torch.ops.aten.sum.dim_IntList(x, []); x = None + add = torch.ops.aten.add.Tensor(sin, sum_1); sin = sum_1 = None + return (add,)""", + ) + def test_set_grad_empty(self): class M(torch.nn.Module): def forward(self, x): @@ -4674,6 +4762,84 @@ def forward(self, x): self.assertTrue(torch.allclose(core_aten_ep.module()(*inp), m(*inp))) self.assertEqual(id(state_dict), id(ep.state_dict)) + @unittest.skipIf(IS_FBCODE, "We can't customize decomp in fbcode") + def test_export_decomp_torture_case_1(self): + class M(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.lin = torch.nn.Linear(10, 1) + + def forward(self, x): + return self.lin(x) + + inp = (torch.randn(5, 10),) + m = M() + ep = export(m, inp) + + def custom_decomp_callable(x, weight, bias): + return x + bias + + decomp_table = core_aten_decompositions() + decomp_table[torch.ops.aten.linear.default] = custom_decomp_callable + core_aten_ep = ep.run_decompositions(decomp_table) + self.assertExpectedInline( + str(core_aten_ep.graph_module.code).strip(), + """\ +def forward(self, p_lin_weight, p_lin_bias, x): + add = torch.ops.aten.add.Tensor(x, p_lin_bias); x = p_lin_bias = None + return (add,)""", + ) + + @unittest.skipIf(IS_FBCODE, "We can't customize decomp in fbcode") + def test_export_decomp_torture_case_2(self): + class MyLinear(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.weight = torch.randn(20, 98) + self.bias = torch.randn(20) + + def forward(self, x): + return torch.nn.functional.linear(x, self.weight, self.bias) + + class Foo(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.conv = torch.nn.Conv2d(16, 33, 3) + self.conv1d = torch.nn.Conv1d(16, 33, 3) + self.linear = MyLinear() + + def forward(self, x, y): + x_conv = self.conv(x) + y_conv_1d = self.conv1d(y) + x_linear = self.linear(x_conv) + return x_linear.cos() + y_conv_1d.sum() + + ep = export(Foo(), (torch.randn(20, 16, 50, 100), torch.randn(20, 16, 50))) + ep_has_linear_convd = ep.run_decompositions(decomp_table={}) + + def _decompose_linear_custom(x, weight, bias): + return torch.matmul(x, weight.T) + 2 * bias + + ep_decompose_linear = ep_has_linear_convd.run_decompositions( + decomp_table={torch.ops.aten.linear.default: _decompose_linear_custom} + ) + + self.assertExpectedInline( + str(ep_decompose_linear.graph_module.code).strip(), + """\ +def forward(self, p_conv_weight, p_conv_bias, p_conv1d_weight, p_conv1d_bias, c_linear_weight, c_linear_bias, x, y): + conv2d = torch.ops.aten.conv2d.default(x, p_conv_weight, p_conv_bias); x = p_conv_weight = p_conv_bias = None + conv1d = torch.ops.aten.conv1d.default(y, p_conv1d_weight, p_conv1d_bias); y = p_conv1d_weight = p_conv1d_bias = None + permute = torch.ops.aten.permute.default(c_linear_weight, [1, 0]); c_linear_weight = None + matmul = torch.ops.aten.matmul.default(conv2d, permute); conv2d = permute = None + mul = torch.ops.aten.mul.Tensor(c_linear_bias, 2); c_linear_bias = None + add = torch.ops.aten.add.Tensor(matmul, mul); matmul = mul = None + cos = torch.ops.aten.cos.default(add); add = None + sum_1 = torch.ops.aten.sum.default(conv1d); conv1d = None + add_1 = torch.ops.aten.add.Tensor(cos, sum_1); cos = sum_1 = None + return (add_1,)""", + ) + def test_export_decomps_dynamic(self): class M(torch.nn.Module): def __init__(self) -> None: @@ -5244,7 +5410,7 @@ def forward( eps, ), ) - ep.run_decompositions(decomp_table=torch._decomp.decomposition_table) + ep.run_decompositions() self.assertEqual( ep.module()( input, weight, bias, running_mean, running_var, training, momentum, eps @@ -5474,7 +5640,7 @@ def forward(self, t, dim, index, src, **kwargs): output = model(t, dim, index, src) ep = torch.export.export(model, args=(t, dim, index, src)) - ep.run_decompositions(decomp_table=torch._decomp.decomposition_table) + ep = ep.run_decompositions() self.assertEqual(ep.module()(t, dim, index, src), output) def test_fqn(self): @@ -8041,10 +8207,15 @@ def forward(self, x): ep.graph_module.code ) - ep = ep.run_decompositions( - decomp_table=get_decompositions([torch.ops.aten.elu.default]), - _preserve_ops=[torch.ops.aten.elu.default], - ) + if IS_FBCODE: + ep = ep.run_decompositions(_preserve_ops=(torch.ops.aten.elu.default,)) + else: + decomp_table = core_aten_decompositions() + del decomp_table[torch.ops.aten.elu.default] + + ep = ep.run_decompositions( + decomp_table=decomp_table, + ) FileCheck().check_count("torch.ops.aten.elu.default", 1, exactly=True).run( ep.graph_module.code ) @@ -8066,12 +8237,17 @@ def forward(self, x): "torch.ops.aten.upsample_bilinear2d.vec", 1, exactly=True ).run(ep.graph_module.code) - decomp_table = get_decompositions([torch.ops.aten.upsample_bilinear2d.vec]) - ep = ep.run_decompositions( - decomp_table=decomp_table, - _preserve_ops=[torch.ops.aten.upsample_bilinear2d.vec], - ) - assert torch.ops.aten.upsample_bilinear2d.vec in decomp_table + if IS_FBCODE: + ep = ep.run_decompositions( + _preserve_ops=(torch.ops.aten.upsample_bilinear2d.vec,) + ) + else: + decomp_table = core_aten_decompositions() + del decomp_table[torch.ops.aten.upsample_bilinear2d.vec] + ep = ep.run_decompositions( + decomp_table=decomp_table, + ) + FileCheck().check_count( "torch.ops.aten.upsample_bilinear2d.vec", 1, exactly=True ).run(ep.graph_module.code) diff --git a/test/export/test_export_training_ir_to_run_decomp.py b/test/export/test_export_training_ir_to_run_decomp.py index 3cab41c6d095c5..b1168f54bb2277 100644 --- a/test/export/test_export_training_ir_to_run_decomp.py +++ b/test/export/test_export_training_ir_to_run_decomp.py @@ -1,5 +1,6 @@ # Owner(s): ["oncall: export"] import torch +from torch.testing._internal.common_utils import IS_FBCODE try: @@ -15,9 +16,11 @@ def mocked_training_ir_to_run_decomp_export_strict(*args, **kwargs): ep = torch.export.export_for_training(*args, **kwargs) - return ep.run_decompositions( - {}, _preserve_ops=testing._COMPOSITE_OPS_THAT_CAN_BE_PRESERVED_TESTING_ONLY - ) + if IS_FBCODE: + return ep.run_decompositions( + {}, _preserve_ops=testing._COMPOSITE_OPS_THAT_CAN_BE_PRESERVED_TESTING_ONLY + ) + return ep.run_decompositions({}) def mocked_training_ir_to_run_decomp_export_non_strict(*args, **kwargs): @@ -25,9 +28,12 @@ def mocked_training_ir_to_run_decomp_export_non_strict(*args, **kwargs): ep = torch.export.export_for_training(*args, **kwargs) else: ep = torch.export.export_for_training(*args, **kwargs, strict=False) - return ep.run_decompositions( - {}, _preserve_ops=testing._COMPOSITE_OPS_THAT_CAN_BE_PRESERVED_TESTING_ONLY - ) + + if IS_FBCODE: + return ep.run_decompositions( + {}, _preserve_ops=testing._COMPOSITE_OPS_THAT_CAN_BE_PRESERVED_TESTING_ONLY + ) + return ep.run_decompositions({}) def make_dynamic_cls(cls, strict): diff --git a/test/test_decomp.py b/test/test_decomp.py index d458892851ef9f..7ab3859454ff6b 100644 --- a/test/test_decomp.py +++ b/test/test_decomp.py @@ -10,7 +10,7 @@ import torch._inductor.decomposition import torch.autograd from torch import Tensor -from torch._decomp import core_aten_decompositions, decomposition_table +from torch._decomp import _is_cia_op, core_aten_decompositions, decomposition_table from torch._dispatch.python import enable_python_dispatcher from torch._ops import DispatchKey from torch.testing import make_tensor @@ -62,7 +62,7 @@ def overload_to_aten_name(op): core_decomposition_names = { overload_to_aten_name(k) for k in core_aten_decompositions() - if isinstance(k, torch._ops.OpOverload) + if isinstance(k, torch._ops.OpOverload) and not _is_cia_op(k) } _decomp_test_ops = [ op diff --git a/torch/_decomp/__init__.py b/torch/_decomp/__init__.py index 46ffa579aec564..a22289b75c4012 100644 --- a/torch/_decomp/__init__.py +++ b/torch/_decomp/__init__.py @@ -1,15 +1,26 @@ # mypy: allow-untyped-defs import inspect from collections import defaultdict -from functools import wraps +from functools import lru_cache, partial, wraps from itertools import chain -from typing import Callable, Dict, List, Sequence, TypeVar, Union +from typing import ( + Callable, + Dict, + FrozenSet, + List, + Optional, + Sequence, + Set, + TypeVar, + Union, +) from typing_extensions import ParamSpec import torch import torch.library -from torch._ops import HigherOrderOperator, OpOverload, OpOverloadPacket +from torch._ops import HigherOrderOperator, OperatorBase, OpOverload, OpOverloadPacket from torch._prims_common import CustomOutParamAnnotation +from torch._subclasses.functional_tensor import FunctionalTensor from torch.utils import _pytree as pytree @@ -20,6 +31,8 @@ "register_decomposition", "get_decompositions", "core_aten_decompositions", + "_decomp_table_to_post_autograd_aten", + "_special_op_to_preserve_cia", ] _T = TypeVar("_T") @@ -250,13 +263,184 @@ def remove_decompositions( import torch._refs +# Our strategy for deciding if we can preserve a op is following: +# 1. The op should be known statically that it is functional +# 2. If it is maybe aliasing, we decompose because we must know if an op +# is mutating or aliasing. +# TODO (tmanlaibaatar) make this utility function and share it with functional_tensor +# decomp part. (https://github.com/pytorch/pytorch/issues/129431) +def _check_valid_to_preserve(op_overload): + if op_overload in FunctionalTensor.maybe_aliasing_or_mutating_ops: + return False + if op_overload in FunctionalTensor.metadata_fns: + return False + + alias_info = len( + [i for i in op_overload._schema.arguments if i.alias_info is not None] + ) + + is_mutating_or_aliasing = alias_info != 0 or op_overload._schema.is_mutable + + if is_mutating_or_aliasing: + return False + + if not torch._C._dispatch_has_kernel(op_overload.name()): + return False + + return True + + +def _is_cia_op(op: "OpOverload") -> bool: + return ( + torch._C._dispatch_has_kernel_for_dispatch_key( + op.name(), torch._C.DispatchKey.CompositeImplicitAutograd + ) + or torch._C.DispatchKey.CompositeImplicitAutograd in op.py_kernels + ) + + +@lru_cache(maxsize=1) +def _collect_all_valid_cia_ops() -> Set["OperatorBase"]: + """ + This is an util function that gets the all CIA functional ops. + + The algorithm is in 2 steps: + 1. We first query C++ dispatcher to get the list of CIA ops + and then we call getattr on torch.ops.aten to lazily populate + them. + + 2. Sometimes, handful of ops have CIA registered in python dispatcher + but not on the C++ side, these can't be caught at the first step. + So we walk again to get the final list. + + Note that the output of this function should never be modified + """ + # First step to lazily populate torch.ops.aten + cia_ops = torch._C._dispatch_get_registrations_for_dispatch_key( + "CompositeImplicitAutograd" + ) + # Ignore quantized namespace ops + cia_ops = [name[6:] for name in cia_ops if name.startswith("aten::")] + # Materialize all CIA ops first + for op in cia_ops: + split_list = op.split(".") + # Sometime overload could be missing + assert len(split_list) == 1 or len(split_list) == 2 + op_name = split_list[0] + op_overload_name = "default" + if len(split_list) == 2: + op_overload_name = split_list[1] + + _ = getattr(getattr(torch.ops.aten, op_name), op_overload_name) + + # Second step to finally compile the list of all valid ops + cia_ops = set() + for op in torch.ops.aten: + op_packet = getattr(torch.ops.aten, op) + for overload in op_packet.overloads(): + op_overload = getattr(op_packet, overload) + if _check_valid_to_preserve(op_overload) and _is_cia_op(op_overload): + cia_ops.add(op_overload) + return cia_ops + + +def _get_decomp_for_cia(op): + # [NOTE] Seperating out func.decompose + # Ideally we should be able to just register func.decompose but + # we can't as this decomp is gonna be registered to the py_impl. + # As a result it will infinitely recurse. So we first check if the op + # has py_impl entry for CIA and if it is we use that first. If not, + # we register C++ query to py_impl. + dk = torch._C.DispatchKey.CompositeImplicitAutograd + if dk in op.py_kernels and not isinstance(op.py_kernels[dk], torch._C.DispatchKey): + return op.py_kernels[dk] + + def _special_op_to_decompose_cia(*args, **kwargs): + kernel = kwargs["kernel"] + del kwargs["kernel"] + # Can't call kernel.decompose due to infinite recursion as + # we register this kernel to py_impl directly + dk = torch._C.DispatchKey.CompositeImplicitAutograd + if torch._C._dispatch_has_kernel_for_dispatch_key( + kernel.name(), torch._C.DispatchKey.CompositeImplicitAutograd + ): + return kernel._op_dk(dk, *args, **kwargs) + else: + raise AssertionError( + f"Expected {kernel} to have CompositeImplicitAutograd kernel" + ) + + return partial(_special_op_to_decompose_cia, kernel=op) + + # See NOTE [Core ATen Ops] # # list was copied from torch/_inductor/decomposition.py # excluding decompositions that results in prim ops # Resulting opset of decomposition is core aten ops def core_aten_decompositions() -> Dict[torch._ops.OperatorBase, Callable]: + decomp_table = _core_aten_decompositions_post_autograd() + + # If it is fbcode change, we return the old decomposition list + from torch._inductor import config + + if config.is_fbcode(): + return decomp_table + + aten = torch.ops.aten + + # We are deleting custom decomp in core_aten_decomp + # for CIA ops but it should be fine technically + # because this table is only meant to be used in export context + # in which we really carefully control the decomp behaviour + # In any case, C++ decomps should be preferred + cia_ops_that_should_be_removed = [ + aten.all.dimname, + aten.index_add.dimname, + aten.index_copy.dimname, + aten.index_fill.Dimname_Scalar, + aten.index_fill.Dimname_Tensor, + aten.norm.names_ScalarOpt_dim_dtype, + aten.norm.names_ScalarOpt_dim, + aten.silu_backward.default, + aten.std.default, + aten.std.dim, + aten.std.names_dim, + aten.std.correction_names, + aten.std_mean.default, + aten.std_mean.dim, + aten.std_mean.names_dim, + aten.std_mean.correction_names, + aten.upsample_bilinear2d.vec, + aten.upsample_trilinear3d.vec, + ] + + for k in list(decomp_table.keys()): + if k in cia_ops_that_should_be_removed: + del decomp_table[k] + + for op in _collect_all_valid_cia_ops(): + decomp_table[op] = _get_decomp_for_cia(op) + return decomp_table + + +# This table is a stop-gap table which replicates +# the old behaviour of post-dispatch IR. +# This table contains all functional CIA ops mapping +# to their default decomp. In old export, this will +# be decomposed implicitly. +def _decomp_table_to_post_autograd_aten(): + decomp_table = {} + for k in _collect_all_valid_cia_ops(): + decomp_table[k] = _get_decomp_for_cia(k) + return decomp_table + + +def _core_aten_decompositions_post_autograd() -> ( + Dict[torch._ops.OperatorBase, Callable] +): aten = torch.ops.aten + # TODO Delete all mutating or CIA ops from this list return get_decompositions( [ aten.addcdiv, diff --git a/torch/export/experimental/__init__.py b/torch/export/experimental/__init__.py index c501a5b9953583..d830ed6dc65837 100644 --- a/torch/export/experimental/__init__.py +++ b/torch/export/experimental/__init__.py @@ -57,8 +57,8 @@ def _export_forward_backward( ep = _decompose_exported_program( ep, - decomp_table=core_aten_decompositions(), - _preserve_ops=(), # type: ignore[arg-type] + cia_to_decomp={}, + python_decomp_table=core_aten_decompositions(), joint_loss_index=joint_loss_index, ) gm, new_graph_signature = _copy_graph_module_and_signature(ep) diff --git a/torch/export/exported_program.py b/torch/export/exported_program.py index b9c71bdc91deca..c5ee650a21a96e 100644 --- a/torch/export/exported_program.py +++ b/torch/export/exported_program.py @@ -54,7 +54,6 @@ from torch._export.verifier import Verifier from torch._guards import detect_fake_mode from torch._subclasses.fake_tensor import unset_fake_temporarily -from torch._subclasses.functional_tensor import FunctionalTensor from torch.export._tree_utils import is_equivalent, reorder_kwargs from torch.fx._compatibility import compatibility from torch.fx.passes.infra.pass_base import PassResult @@ -174,7 +173,7 @@ def _register_cia_to_meta(*args, **kwargs): @contextmanager -def _override_composite_implicit_decomp(ops_to_preserve, decomp_table, safe=True): +def _override_composite_implicit_decomp(cia_ops_to_callable, safe=True): # This function overrides CompositeImplicitAutograd decomp for # functional composite ops that user specified. Ideally we want to not-decompose # ALL composite ops but today's C++ functinalization relies on @@ -192,48 +191,7 @@ def _override_composite_implicit_decomp(ops_to_preserve, decomp_table, safe=True saved_tables = {} patched_ops = set() - removed_decomps = {} - for op_overload in ops_to_preserve: - # Our strategy for deciding if we can preserve CIA is following: - # 1. The op should be known statically that it is functional - # 2. If it is maybe aliasing, we decompose because we must know if an op - # is mutating or aliasing. - # TODO (tmanlaibaatar) make this utility function and share it with functional_tensor - # decomp part. (https://github.com/pytorch/pytorch/issues/129431) - def assert_valid_to_preserve(op_overload): - if op_overload in FunctionalTensor.maybe_aliasing_or_mutating_ops: - raise RuntimeError( - f"We can't detect {op_overload} as a functional op statically, so we can't preserve it" - ) - if op_overload in FunctionalTensor.metadata_fns: - raise RuntimeError( - f"{op_overload} is a metadata query function, " - "it will be preserved implicitly in our tracing system. " - "Please file an issue on github if you see otherwise" - ) - - alias_info = len( - [i for i in op_overload._schema.arguments if i.alias_info is not None] - ) - - is_mutating_or_aliasing = alias_info != 0 or op_overload._schema.is_mutable - - if is_mutating_or_aliasing: - raise RuntimeError( - f"{op_overload} is a mutating/aliasing op, we can't preserve it as is" - ) - - if not torch._C._dispatch_has_kernel(op_overload.name()): - raise RuntimeError( - f"{op_overload} is a TorchScript op, we can't preserve it as is" - ) - - return True - - if safe: - # If we didn't error, it means we can go ahead - assert_valid_to_preserve(op_overload) - + for op_overload, decomp_callable in cia_ops_to_callable.items(): saved_tables[op_overload] = op_overload.py_kernels.copy() patched_ops.add(op_overload) @@ -247,11 +205,9 @@ def assert_valid_to_preserve(op_overload): del op_overload.py_kernels[torch._C.DispatchKey.CompositeImplicitAutograd] if safe: - - def _(*args, **kwargs): - return NotImplemented - - op_overload.py_impl(torch._C.DispatchKey.CompositeImplicitAutograd)(_) + op_overload.py_impl(torch._C.DispatchKey.CompositeImplicitAutograd)( + decomp_callable + ) # For fake tensor prop, we do want to register meta kernel directly if torch._C.DispatchKey.Meta not in op_overload.py_kernels: @@ -259,10 +215,6 @@ def _(*args, **kwargs): functools.partial(_register_cia_to_meta, kernel=op_overload) ) - if op_overload in decomp_table: - removed_decomps[op_overload] = decomp_table[op_overload] - del decomp_table[op_overload] - try: yield finally: @@ -271,8 +223,10 @@ def _(*args, **kwargs): op.py_kernels.update(saved_tables[op]) op._dispatch_cache.clear() - for op, decomp in removed_decomps.items(): - decomp_table[op] = decomp + +def _special_op_to_preserve_cia(*args, **kwargs): + "This is an special marker that tells our infra that we shouldn't decompose this op" + return NotImplemented @contextmanager @@ -281,18 +235,65 @@ def _override_decomp_aten_to_variants(): # and their CompositeImplicitAutograd kernels will not become NotImplemented. # We will later replace them with aten._to_copy when functionalizing. with _override_composite_implicit_decomp( - (torch.ops.aten.to.dtype_layout, torch.ops.aten.to.dtype), - {}, + { + torch.ops.aten.to.dtype_layout: _special_op_to_preserve_cia, + torch.ops.aten.to.dtype: _special_op_to_preserve_cia, + }, safe=False, ): yield +def _split_decomp_table_to_cia_and_python_decomp( + decomp_table: Dict[torch._ops.OperatorBase, Callable] +) -> Tuple[Dict[torch._ops.OperatorBase, Callable], ...]: + from torch._decomp import _collect_all_valid_cia_ops, _get_decomp_for_cia + + all_preservable_cia_ops = set(_collect_all_valid_cia_ops()) + cia_ops_to_callable = {} + + for op in list(decomp_table.keys()): + # TODO we are silently allowing non-safe(non-functional) ops through a crack + # due to core aten decomp table having non-functional entries. Once we have + # a tigher check around core aten decomp, we should warn users about them. + # Tracking issue: (https://github.com/pytorch/pytorch/issues/135759) + + # if it is a valid CIA op we can mess with in export, we check if it is: + # 1. Has been marked as to be decomposed. Example: + # decomp_table = decomp_table_to_core_aten() + # del decomp_table[aten.linear] + # In this case, user says decompose everything except for aten.linear + # 2. Has been marked with custom decomp behavour. Example: + # decomp_table = {aten.linear: some_op} + # For (1), we want to remove all the CIA ops that weren't handled by user as + # it suggests they are safe to decompose, so we should remove from preservable_list. + # for (2), we just plumb the custom decomp to AOTDIspatcher. + # In both cases, we want to remove this CIA op from the decomp_table as it is special + # handled. + if op in all_preservable_cia_ops: + # TODO this is annpying case where aten.item has + # prim decomposition which later calls into aten.item + # and recurses infinitely. (https://github.com/pytorch/pytorch/issues/136050) + if op == torch.ops.aten.item.default: + cia_ops_to_callable[op] = _get_decomp_for_cia(op) + else: + cia_ops_to_callable[op] = decomp_table[op] + all_preservable_cia_ops.remove(op) + del decomp_table[op] + + # If we reached here, it means user intentionally deleted these CIA ops from + # decomp table. + for k in all_preservable_cia_ops: + cia_ops_to_callable[k] = _special_op_to_preserve_cia + + return cia_ops_to_callable, decomp_table + + def _decompose_and_get_gm_with_new_signature_constants( ep, *, - decomp_table: Dict[torch._ops.OperatorBase, Callable], - _preserve_ops: Tuple[torch._ops.OpOverload], + cia_to_decomp: Dict[torch._ops.OperatorBase, Callable], + python_decomp_table: Dict[torch._ops.OperatorBase, Callable], joint_loss_index: Optional[int], ): from torch._functorch.aot_autograd import aot_export_module @@ -356,8 +357,7 @@ def _decompose_and_get_gm_with_new_signature_constants( with _ignore_backend_decomps(), ( fake_mode ), _override_decomp_aten_to_variants(), _override_composite_implicit_decomp( - _preserve_ops, - decomp_table, + cia_to_decomp, ): aten_export_artifact = _export_to_aten_ir( mod, @@ -369,7 +369,7 @@ def _decompose_and_get_gm_with_new_signature_constants( {}, fake_params_buffers, constant_attrs, - decomp_table=decomp_table, + decomp_table=python_decomp_table, _check_autograd_state=False, ) @@ -404,13 +404,12 @@ def _decompose_and_get_gm_with_new_signature_constants( fake_mode = detect_fake_mode(fake_args) fake_mode = contextlib.nullcontext() if fake_mode is None else fake_mode with _ignore_backend_decomps(), fake_mode, _override_composite_implicit_decomp( - _preserve_ops, - decomp_table, + cia_to_decomp ): gm, graph_signature = aot_export_module( ep.graph_module, fake_args, - decompositions=decomp_table, + decompositions=python_decomp_table, trace_joint=True if joint_loss_index is not None else False, output_loss_index=joint_loss_index if joint_loss_index is not None @@ -610,14 +609,14 @@ def _common_getitem_elimination_pass( def _decompose_exported_program( ep, *, - decomp_table: Dict[torch._ops.OperatorBase, Callable], - _preserve_ops: Tuple[torch._ops.OpOverload], + cia_to_decomp: Dict[torch._ops.OperatorBase, Callable], + python_decomp_table: Dict[torch._ops.OperatorBase, Callable], joint_loss_index: Optional[int], ): gm, new_graph_signature = _decompose_and_get_gm_with_new_signature_constants( ep, - decomp_table=decomp_table, - _preserve_ops=_preserve_ops, + cia_to_decomp=cia_to_decomp, + python_decomp_table=python_decomp_table, joint_loss_index=joint_loss_index, ) @@ -994,16 +993,83 @@ def run_decompositions( `Core ATen Operator Set `_. For now, we do not decompose joint graphs. - """ - from torch._decomp import core_aten_decompositions - if decomp_table is None: + Args: + decomp_table: + An optional argument that specifies decomp behaviour for Aten ops + (1) If None, we decompose to core aten decompositions + (2) If empty, we don't decompose any operator + + + Some examples: + + If you don't want to decompose anything + + .. code-block:: python + + ep = torch.export.export(model, ...) + ep = ep.run_decompositions(decomp_table={}) + + If you want to get a core aten operator set except for certain operator, you can do following: + + .. code-block:: python + + ep = torch.export.export(model, ...) + from torch._decomp import core_aten_decompositions decomp_table = core_aten_decompositions() + decomp_table[your_op] = your_custom_decomp + ep = ep.run_decompositions(decomp_table=decomp_table) + """ + from torch._decomp import ( + _decomp_table_to_post_autograd_aten, + core_aten_decompositions, + ) + from torch._inductor import config + + # FIXME delete this option after PTC, Executorch syncing is + # bit annoying so can't get rid of it easily + if _preserve_ops != (): + warnings.warn( + "This API is deprecated and soon will be removed. " + "Please look at the docstring to see how to preserve " + "an operator." + ) + + _decomp_table = ( + core_aten_decompositions() if decomp_table is None else dict(decomp_table) + ) + + if config.is_fbcode(): + # This means the decomp_table would only be containing post-autograd ops + # We should manually add CIA decomps + for k, v in _decomp_table_to_post_autograd_aten().items(): + _decomp_table[k] = v + + for op in _preserve_ops: + if op in _decomp_table: + del _decomp_table[op] + + # Note [Seperating decomp_table into CIA decomps and non-CIA decomps] + # At this point, we have a decomp_table that contains decomp behaviour for + # both CIA and post-autograd ops. + # We need to separate the op into two categories: + # 1. CIA op: These are the ops that we want to override + # CompositeImplicitAutograd decomp for. For them, we need to use _override_composite_implicit_decomp + # context manager to plumb it through AOTDispatcher + # 2. Non-CIA op: These ops are only relevant after AOTDIspatcher runs, so just + # checking if they are statically functional is enough. + # For joint IR case tho, we need to use the old path because we can't register + # custom decomps this way because we can't use context manager as it installs + # autograd_error node. + ( + cia_to_decomp, + python_decomp_table, + ) = _split_decomp_table_to_cia_and_python_decomp(_decomp_table) return _decompose_exported_program( self, - decomp_table=decomp_table, - _preserve_ops=_preserve_ops, # type: ignore[arg-type] + cia_to_decomp=cia_to_decomp, + python_decomp_table=python_decomp_table, joint_loss_index=None, ) diff --git a/torch/onnx/_internal/exporter/_decomp.py b/torch/onnx/_internal/exporter/_decomp.py index 3bbff757e92ef2..87e4925b06b6a0 100644 --- a/torch/onnx/_internal/exporter/_decomp.py +++ b/torch/onnx/_internal/exporter/_decomp.py @@ -1,5 +1,3 @@ -"""Build decomp table from PyTorch.""" - # mypy: allow-untyped-defs from __future__ import annotations @@ -83,7 +81,14 @@ def create_onnx_friendly_decomposition_table( Dict[torch._ops.OperatorBase, Callable]: A dictionary that maps op overloads to their corresponding decomposition functions. """ - decomposition_table: dict[torch._ops.OperatorBase, Callable] = {} + # This table contains all CIA decomps, so we should filter out ONNX supported CIAs from it + decomposition_table: dict[torch._ops.OperatorBase, Callable] = ( + torch._decomp._decomp_table_to_post_autograd_aten() + ) + can_preserve = get_preserve_ops().intersection(onnx_registered_ops) + for op in list(decomposition_table.keys()): + if op in can_preserve: + del decomposition_table[op] # NOTE: If we import torch._decomp, we will get RuntimeError: Only a single # TORCH_LIBRARY can be used to register the namespace nvprims; please put all of your @@ -95,6 +100,9 @@ def create_onnx_friendly_decomposition_table( # not exportable anyways. if op_overload in onnx_registered_ops: continue + # If it is HOP, we filter those out as well. + if not hasattr(op_overload, "_schema"): + continue decomposition_table[op_overload] = decomp_fn return decomposition_table diff --git a/torch/onnx/_internal/exporter/_fx_passes.py b/torch/onnx/_internal/exporter/_fx_passes.py index 4c19b552d0e4f0..c1083b65f8fcea 100644 --- a/torch/onnx/_internal/exporter/_fx_passes.py +++ b/torch/onnx/_internal/exporter/_fx_passes.py @@ -17,11 +17,7 @@ def decompose_with_registry( """ onnx_registered_ops = set(_decomp.get_onnx_implemented_overloads(registry)) decomp_table = _decomp.create_onnx_friendly_decomposition_table(onnx_registered_ops) - # Try to preserve some known CompositeImplicitAutograd ops - to_preserve = _decomp.get_preserve_ops() - # We can only preserve implemented ops - can_preserve = tuple(to_preserve.intersection(onnx_registered_ops)) - return exported_program.run_decompositions(decomp_table, _preserve_ops=can_preserve) + return exported_program.run_decompositions(decomp_table) def insert_type_promotion_nodes( From af4116213c61ccd0a983028a8019f15aba83a11d Mon Sep 17 00:00:00 2001 From: Tugsbayasgalan Manlaibaatar Date: Sat, 14 Sep 2024 22:28:17 -0700 Subject: [PATCH 0919/1018] Create export_for_inference API and expose core_aten as public facing API (#135912) Differential Revision: [D62606908](https://our.internmc.facebook.com/intern/diff/D62606908) Pull Request resolved: https://github.com/pytorch/pytorch/pull/135912 Approved by: https://github.com/avikchaudhuri ghstack dependencies: #135080 --- docs/source/export.rst | 1 + test/export/test_export.py | 50 +++++++++++++++++ torch/export/__init__.py | 95 +++++++++++++++++++++++++++++++- torch/export/exported_program.py | 15 ++++- 4 files changed, 158 insertions(+), 3 deletions(-) diff --git a/docs/source/export.rst b/docs/source/export.rst index 3908beeb9b345e..a21058d62ad945 100644 --- a/docs/source/export.rst +++ b/docs/source/export.rst @@ -677,6 +677,7 @@ API Reference .. autofunction:: load .. autofunction:: register_dataclass .. autofunction:: torch.export.dynamic_shapes.Dim +.. autofunction:: torch.export.exported_program.core_aten_decompositions .. autofunction:: dims .. autoclass:: torch.export.dynamic_shapes.ShapesCollection diff --git a/test/export/test_export.py b/test/export/test_export.py index e2711e120ea602..0ab2f49428b29d 100644 --- a/test/export/test_export.py +++ b/test/export/test_export.py @@ -4762,6 +4762,56 @@ def forward(self, x): self.assertTrue(torch.allclose(core_aten_ep.module()(*inp), m(*inp))) self.assertEqual(id(state_dict), id(ep.state_dict)) + @unittest.skipIf(IS_FBCODE, "We can't customize decomp in fbcode") + def test_export_for_inference_e2e(self): + class M(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.lin = torch.nn.Linear(10, 1) + + def forward(self, x): + return self.lin(x) + + inp = (torch.randn(5, 10),) + m = M() + + decomp_table = torch.export.core_aten_decompositions() + + def _custom_decomp_for_linear(x, weight, bias): + return x + bias.sum() + + decomp_table[torch.ops.aten.linear.default] = _custom_decomp_for_linear + del decomp_table[torch.ops.aten.sum.default] + ep = torch.export.export_for_inference( + m, inp, decomp_table=decomp_table, dynamic_shapes={"x": {0: Dim("batch")}} + ) + + self.assertExpectedInline( + str(ep.graph_module.code).strip(), + """\ +def forward(self, p_lin_weight, p_lin_bias, x): + sum_1 = torch.ops.aten.sum.default(p_lin_bias); p_lin_bias = None + add = torch.ops.aten.add.Tensor(x, sum_1); x = sum_1 = None + return (add,)""", + ) + + ep_core = ep.run_decompositions() + + self.assertExpectedInline( + str(ep_core.graph_module.code).strip(), + """\ +def forward(self, p_lin_weight, p_lin_bias, x): + sum_1 = torch.ops.aten.sum.dim_IntList(p_lin_bias, []); p_lin_bias = None + add = torch.ops.aten.add.Tensor(x, sum_1); x = sum_1 = None + return (add,)""", + ) + + with self.assertRaisesRegex(RuntimeError, "Expected input"): + ep.module()(torch.randn(4, 12)) + + with self.assertRaisesRegex(RuntimeError, "Expected input"): + ep_core.module()(torch.randn(4, 12)) + @unittest.skipIf(IS_FBCODE, "We can't customize decomp in fbcode") def test_export_decomp_torture_case_1(self): class M(torch.nn.Module): diff --git a/torch/export/__init__.py b/torch/export/__init__.py index b280fa29ff9789..336b36424f31ce 100644 --- a/torch/export/__init__.py +++ b/torch/export/__init__.py @@ -38,6 +38,7 @@ if TYPE_CHECKING: # Import the following modules during type checking to enable code intelligence features, # Do not import unconditionally, as they import sympy and importing sympy is very slow + from torch._ops import OpOverload from torch.fx.experimental.symbolic_shapes import StrictMinMaxConstraint @@ -49,9 +50,11 @@ "ExportedProgram", "ModuleCallEntry", "ModuleCallSignature", + "core_aten_decompositions", "dims", "export", "export_for_training", + "export_for_inference", "load", "register_dataclass", "save", @@ -62,7 +65,12 @@ from .dynamic_shapes import Constraint, Dim, dims, ShapesCollection -from .exported_program import ExportedProgram, ModuleCallEntry, ModuleCallSignature +from .exported_program import ( + core_aten_decompositions, + ExportedProgram, + ModuleCallEntry, + ModuleCallSignature, +) from .graph_signature import ExportBackwardSignature, ExportGraphSignature from .unflatten import FlatArgsAdapter, unflatten, UnflattenedModule @@ -165,6 +173,91 @@ def export_for_training( ) +def export_for_inference( + mod: torch.nn.Module, + args: Tuple[Any, ...], + kwargs: Optional[Dict[str, Any]] = None, + *, + dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any], List[Any]]] = None, + strict: bool = True, + preserve_module_call_signature: Tuple[str, ...] = (), + decomp_table: Optional[Dict["OpOverload", Optional[Callable]]] = None, +) -> ExportedProgram: + """ + :func:`export_for_inference` takes any nn.Module along with example inputs, and produces a traced graph representing + only the Tensor computation of the function in an Ahead-of-Time (AOT) fashion, + which can subsequently be executed with different inputs or serialized. The + traced graph (1) produces normalized operators in the ATen operator set + (as well as any user-specified custom operators) which is customizable via decomp_table, + (2) has eliminated all Python control flow and data structures (with certain exceptions), + and (3) records the set of shape constraints needed to show that this normalization and control-flow + elimination is sound for future inputs. This API is for convenience use as it combines :func:`export_for_training` and + :func:`run_decompositions`. + + **Soundness Guarantee** + + See :func:`export()` docstring for more details. + + Args: + mod: We will trace the forward method of this module. + + args: Example positional inputs. + + kwargs: Optional example keyword inputs. + + dynamic_shapes: + An optional argument where the type should either be: + 1) a dict from argument names of ``f`` to their dynamic shape specifications, + 2) a tuple that specifies dynamic shape specifications for each input in original order. + If you are specifying dynamism on keyword args, you will need to pass them in the order that + is defined in the original function signature. + + The dynamic shape of a tensor argument can be specified as either + (1) a dict from dynamic dimension indices to :func:`Dim` types, where it is + not required to include static dimension indices in this dict, but when they are, + they should be mapped to None; or (2) a tuple / list of :func:`Dim` types or None, + where the :func:`Dim` types correspond to dynamic dimensions, and static dimensions + are denoted by None. Arguments that are dicts or tuples / lists of tensors are + recursively specified by using mappings or sequences of contained specifications. + + strict: When enabled (default), the export function will trace the program through + TorchDynamo which will ensure the soundness of the resulting graph. Otherwise, the + exported program will not validate the implicit assumptions baked into the graph and + may cause behavior divergence between the original model and the exported one. This is + useful when users need to workaround bugs in the tracer, or simply want incrementally + enable safety in their models. Note that this does not affect the resulting IR spec + to be different and the model will be serialized in the same way regardless of what value + is passed here. + WARNING: This option is experimental and use this at your own risk. + + decomp_table: See :func:`run_decompositions` for more details. + + Returns: + An :class:`ExportedProgram` containing the traced callable. + + **Acceptable input/output types** + + Acceptable types of inputs (for ``args`` and ``kwargs``) and outputs include: + + - Primitive types, i.e. ``torch.Tensor``, ``int``, ``float``, ``bool`` and ``str``. + - Dataclasses, but they must be registered by calling :func:`register_dataclass` first. + - (Nested) Data structures comprising of ``dict``, ``list``, ``tuple``, ``namedtuple`` and + ``OrderedDict`` containing all above types. + + """ + + ep_for_training = export_for_training( + mod, + args, + kwargs, + dynamic_shapes=dynamic_shapes, + strict=strict, + preserve_module_call_signature=preserve_module_call_signature, + ) + + return ep_for_training.run_decompositions(decomp_table=decomp_table) + + def export( mod: torch.nn.Module, args: Tuple[Any, ...], diff --git a/torch/export/exported_program.py b/torch/export/exported_program.py index c5ee650a21a96e..2788e804257ac1 100644 --- a/torch/export/exported_program.py +++ b/torch/export/exported_program.py @@ -78,6 +78,7 @@ "ExportedProgram", "ModuleCallEntry", "ModuleCallSignature", + "core_aten_decompositions", ] @@ -289,6 +290,17 @@ def _split_decomp_table_to_cia_and_python_decomp( return cia_ops_to_callable, decomp_table +def core_aten_decompositions() -> Dict[torch._ops.OperatorBase, Callable]: + """ + This is the default decomposition table which contains decomposition of + all ATEN operators to core aten opset. Use this API together with + :func:`run_decompositions()` + """ + from torch._decomp import core_aten_decompositions + + return core_aten_decompositions() + + def _decompose_and_get_gm_with_new_signature_constants( ep, *, @@ -1015,8 +1027,7 @@ def run_decompositions( .. code-block:: python ep = torch.export.export(model, ...) - from torch._decomp import core_aten_decompositions - decomp_table = core_aten_decompositions() + decomp_table = torch.export.core_aten_decompositions() decomp_table[your_op] = your_custom_decomp ep = ep.run_decompositions(decomp_table=decomp_table) """ From c87301c8278537a5d5b21ea99f93c4e5b1f82f84 Mon Sep 17 00:00:00 2001 From: Tugsbayasgalan Manlaibaatar Date: Sat, 14 Sep 2024 22:28:17 -0700 Subject: [PATCH 0920/1018] Add some doc for export_for_training (#135918) Differential Revision: [D62610491](https://our.internmc.facebook.com/intern/diff/D62610491) Pull Request resolved: https://github.com/pytorch/pytorch/pull/135918 Approved by: https://github.com/avikchaudhuri ghstack dependencies: #135080, #135912 --- docs/source/export.rst | 172 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 172 insertions(+) diff --git a/docs/source/export.rst b/docs/source/export.rst index a21058d62ad945..603594847f061c 100644 --- a/docs/source/export.rst +++ b/docs/source/export.rst @@ -291,6 +291,178 @@ torch.export>`), but seeing as the context manager does not affect the tensor computations in the model, we can go with the non-strict mode's result. +.. _Training Export: + +Export for Training and Inference +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +In PyTorch 2.5, we introduced a new API called :func:`export_for_training`. +It's still going through hardening, so if you run into any issues, please file +them to Github with the "oncall: export" tag. + +In this API, we produce the most generic IR that contains all ATen operators +(including both functional and non-functional) which can be used to train in +eager PyTorch Autograd. This API is intended for eager training use cases such as PT2 Quantization +and will soon be the default IR of torch.export.export. To read further about +the motivation behind this change, please refer to +https://dev-discuss.pytorch.org/t/why-pytorch-does-not-need-a-new-standardized-operator-set/2206 + +When this API is combined with :func:`run_decompositions()`, you should be able to get inference IR with +any desired decomposition behavior. + +To show some examples: + +:: + + class ConvBatchnorm(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.conv = torch.nn.Conv2d(1, 3, 1, 1) + self.bn = torch.nn.BatchNorm2d(3) + + def forward(self, x): + x = self.conv(x) + x = self.bn(x) + return (x,) + + mod = ConvBatchnorm() + inp = torch.randn(1, 1, 3, 3) + + ep_for_training = torch.export.export_for_training(mod, (inp,)) + print(ep_for_training) + +.. code-block:: + + ExportedProgram: + class GraphModule(torch.nn.Module): + def forward(self, p_conv_weight: "f32[3, 1, 1, 1]", p_conv_bias: "f32[3]", p_bn_weight: "f32[3]", p_bn_bias: "f32[3]", b_bn_running_mean: "f32[3]", b_bn_running_var: "f32[3]", b_bn_num_batches_tracked: "i64[]", x: "f32[1, 1, 3, 3]"): + conv2d: "f32[1, 3, 3, 3]" = torch.ops.aten.conv2d.default(x, p_conv_weight, p_conv_bias); x = p_conv_weight = p_conv_bias = None + add_: "i64[]" = torch.ops.aten.add_.Tensor(b_bn_num_batches_tracked, 1); b_bn_num_batches_tracked = add_ = None + batch_norm: "f32[1, 3, 3, 3]" = torch.ops.aten.batch_norm.default(conv2d, p_bn_weight, p_bn_bias, b_bn_running_mean, b_bn_running_var, True, 0.1, 1e-05, True); conv2d = p_bn_weight = p_bn_bias = b_bn_running_mean = b_bn_running_var = None + return (batch_norm,) + + Graph signature: + ExportGraphSignature( + input_specs=[ + InputSpec(kind=, arg=TensorArgument(name='p_conv_weight'), target='conv.weight', persistent=None), + InputSpec(kind=, arg=TensorArgument(name='p_conv_bias'), target='conv.bias', persistent=None), + InputSpec(kind=, arg=TensorArgument(name='p_bn_weight'), target='bn.weight', persistent=None), + InputSpec(kind=, arg=TensorArgument(name='p_bn_bias'), target='bn.bias', persistent=None), + InputSpec(kind=, arg=TensorArgument(name='b_bn_running_mean'), target='bn.running_mean', persistent=True), + InputSpec(kind=, arg=TensorArgument(name='b_bn_running_var'), target='bn.running_var', persistent=True), + InputSpec(kind=, arg=TensorArgument(name='b_bn_num_batches_tracked'), target='bn.num_batches_tracked', persistent=True), + InputSpec(kind=, arg=TensorArgument(name='x'), target=None, persistent=None) + ], + output_specs=[ + OutputSpec(kind=, arg=TensorArgument(name='batch_norm'), target=None) + ] + ) + Range constraints: {} + + +From the above output, you can see that :func:`export_for_training` produces pretty much the same ExportedProgram +as :func:`export` except for the operators in the graph. You can see that we captured batch_norm in the most general +form. This op is non-functional and will be lowered to different ops when running inference. + +You can also go from this IR to an inference IR via :func:`run_decompositions` with arbitrary customizations. + +:: + + # Lower to core aten inference IR, but keep conv2d + decomp_table = torch.export.core_aten_decompositions() + del decomp_table[torch.ops.aten.conv2d.default] + ep_for_inference = ep_for_training.run_decompositions(decomp_table) + + print(ep_for_inference) + +.. code-block:: + + ExportedProgram: + class GraphModule(torch.nn.Module): + def forward(self, p_conv_weight: "f32[3, 1, 1, 1]", p_conv_bias: "f32[3]", p_bn_weight: "f32[3]", p_bn_bias: "f32[3]", b_bn_running_mean: "f32[3]", b_bn_running_var: "f32[3]", b_bn_num_batches_tracked: "i64[]", x: "f32[1, 1, 3, 3]"): + conv2d: "f32[1, 3, 3, 3]" = torch.ops.aten.conv2d.default(x, p_conv_weight, p_conv_bias); x = p_conv_weight = p_conv_bias = None + add: "i64[]" = torch.ops.aten.add.Tensor(b_bn_num_batches_tracked, 1); b_bn_num_batches_tracked = None + _native_batch_norm_legit_functional = torch.ops.aten._native_batch_norm_legit_functional.default(conv2d, p_bn_weight, p_bn_bias, b_bn_running_mean, b_bn_running_var, True, 0.1, 1e-05); conv2d = p_bn_weight = p_bn_bias = b_bn_running_mean = b_bn_running_var = None + getitem: "f32[1, 3, 3, 3]" = _native_batch_norm_legit_functional[0] + getitem_3: "f32[3]" = _native_batch_norm_legit_functional[3] + getitem_4: "f32[3]" = _native_batch_norm_legit_functional[4]; _native_batch_norm_legit_functional = None + return (getitem_3, getitem_4, add, getitem) + + Graph signature: ExportGraphSignature( + input_specs=[ + InputSpec(kind=, arg=TensorArgument(name='p_conv_weight'), target='conv.weight', persistent=None), + InputSpec(kind=, arg=TensorArgument(name='p_conv_bias'), target='conv.bias', persistent=None), + InputSpec(kind=, arg=TensorArgument(name='p_bn_weight'), target='bn.weight', persistent=None), + InputSpec(kind=, arg=TensorArgument(name='p_bn_bias'), target='bn.bias', persistent=None), + InputSpec(kind=, arg=TensorArgument(name='b_bn_running_mean'), target='bn.running_mean', persistent=True), + InputSpec(kind=, arg=TensorArgument(name='b_bn_running_var'), target='bn.running_var', persistent=True), + InputSpec(kind=, arg=TensorArgument(name='b_bn_num_batches_tracked'), target='bn.num_batches_tracked', persistent=True), + InputSpec(kind=, arg=TensorArgument(name='x'), target=None, persistent=None) + ], + output_specs=[ + OutputSpec(kind=, arg=TensorArgument(name='getitem_3'), target='bn.running_mean'), + OutputSpec(kind=, arg=TensorArgument(name='getitem_4'), target='bn.running_var'), + OutputSpec(kind=, arg=TensorArgument(name='add'), target='bn.num_batches_tracked'), + OutputSpec(kind=, arg=TensorArgument(name='getitem'), target=None) + ] + ) + Range constraints: {} + +Here you can see that we kept `conv2d` op in the IR while decomposing the rest. Now the IR is a functional IR +containing core aten operators except for `conv2d`. + +You can do even more customization by directly registering your chosen decomposition behaviors. + +You can do even more customizations by directly registering custom decomp behaviour + +:: + + # Lower to core aten inference IR, but customize conv2d + decomp_table = torch.export.core_aten_decompositions() + + def my_awesome_custom_conv2d_function(x, weight, bias, stride=[1, 1], padding=[0, 0], dilation=[1, 1], groups=1): + return 2 * torch.ops.aten.convolution(x, weight, bias, stride, padding, dilation, False, [0, 0], groups) + + decomp_table[torch.ops.aten.conv2d.default] = my_awesome_conv2d_function + ep_for_inference = ep_for_training.run_decompositions(decomp_table) + + print(ep_for_inference) + +.. code-block:: + + ExportedProgram: + class GraphModule(torch.nn.Module): + def forward(self, p_conv_weight: "f32[3, 1, 1, 1]", p_conv_bias: "f32[3]", p_bn_weight: "f32[3]", p_bn_bias: "f32[3]", b_bn_running_mean: "f32[3]", b_bn_running_var: "f32[3]", b_bn_num_batches_tracked: "i64[]", x: "f32[1, 1, 3, 3]"): + convolution: "f32[1, 3, 3, 3]" = torch.ops.aten.convolution.default(x, p_conv_weight, p_conv_bias, [1, 1], [0, 0], [1, 1], False, [0, 0], 1); x = p_conv_weight = p_conv_bias = None + mul: "f32[1, 3, 3, 3]" = torch.ops.aten.mul.Tensor(convolution, 2); convolution = None + add: "i64[]" = torch.ops.aten.add.Tensor(b_bn_num_batches_tracked, 1); b_bn_num_batches_tracked = None + _native_batch_norm_legit_functional = torch.ops.aten._native_batch_norm_legit_functional.default(mul, p_bn_weight, p_bn_bias, b_bn_running_mean, b_bn_running_var, True, 0.1, 1e-05); mul = p_bn_weight = p_bn_bias = b_bn_running_mean = b_bn_running_var = None + getitem: "f32[1, 3, 3, 3]" = _native_batch_norm_legit_functional[0] + getitem_3: "f32[3]" = _native_batch_norm_legit_functional[3] + getitem_4: "f32[3]" = _native_batch_norm_legit_functional[4]; _native_batch_norm_legit_functional = None + return (getitem_3, getitem_4, add, getitem) + + Graph signature: ExportGraphSignature( + input_specs=[ + InputSpec(kind=, arg=TensorArgument(name='p_conv_weight'), target='conv.weight', persistent=None), + InputSpec(kind=, arg=TensorArgument(name='p_conv_bias'), target='conv.bias', persistent=None), + InputSpec(kind=, arg=TensorArgument(name='p_bn_weight'), target='bn.weight', persistent=None), + InputSpec(kind=, arg=TensorArgument(name='p_bn_bias'), target='bn.bias', persistent=None), + InputSpec(kind=, arg=TensorArgument(name='b_bn_running_mean'), target='bn.running_mean', persistent=True), + InputSpec(kind=, arg=TensorArgument(name='b_bn_running_var'), target='bn.running_var', persistent=True), + InputSpec(kind=, arg=TensorArgument(name='b_bn_num_batches_tracked'), target='bn.num_batches_tracked', persistent=True), + InputSpec(kind=, arg=TensorArgument(name='x'), target=None, persistent=None) + ], + output_specs=[ + OutputSpec(kind=, arg=TensorArgument(name='getitem_3'), target='bn.running_mean'), + OutputSpec(kind=, arg=TensorArgument(name='getitem_4'), target='bn.running_var'), + OutputSpec(kind=, arg=TensorArgument(name='add'), target='bn.num_batches_tracked'), + OutputSpec(kind=, arg=TensorArgument(name='getitem'), target=None) + ] + ) + Range constraints: {} + + Expressing Dynamism ^^^^^^^^^^^^^^^^^^^ From 634d85ac85d13195c48863c3ef28ecf7dc96afb1 Mon Sep 17 00:00:00 2001 From: Andrii Grynenko Date: Sun, 15 Sep 2024 18:07:49 +0000 Subject: [PATCH 0921/1018] [pytorch][monitoring] Dynamic backend for WaitCounter (#135967) Summary: This implements a default backend proxy that tries to look up a backend via dlsym. What this enables is dynamically loading a module with a backend implementation without having it statically linked with the application. Differential Revision: D62549295 Pull Request resolved: https://github.com/pytorch/pytorch/pull/135967 Approved by: https://github.com/c-p-i-o --- c10/CMakeLists.txt | 1 + c10/util/WaitCounter.cpp | 66 ++++++++++++++++++++++++++++ c10/util/WaitCounter.h | 3 ++ c10/util/WaitCounterDynamicBackend.h | 21 +++++++++ c10/util/build.bzl | 4 ++ 5 files changed, 95 insertions(+) create mode 100644 c10/util/WaitCounterDynamicBackend.h diff --git a/c10/CMakeLists.txt b/c10/CMakeLists.txt index 80e172497d5e66..34577caef2ec6a 100644 --- a/c10/CMakeLists.txt +++ b/c10/CMakeLists.txt @@ -127,6 +127,7 @@ if(NOT BUILD_LIBTORCHLESS) if(LINUX) target_link_libraries(c10 PRIVATE Threads::Threads) + target_link_libraries(c10 PRIVATE dl) endif() if(ANDROID) diff --git a/c10/util/WaitCounter.cpp b/c10/util/WaitCounter.cpp index 03c2ac72939dc8..3941942dfb3500 100644 --- a/c10/util/WaitCounter.cpp +++ b/c10/util/WaitCounter.cpp @@ -1,6 +1,7 @@ #include #include +#include #include #include @@ -8,6 +9,10 @@ #include #include +#ifndef _WIN32 +#include +#endif + namespace c10::monitor { namespace detail { @@ -19,6 +24,58 @@ Synchronized& waitCounterBackendFactories() { static auto instance = new Synchronized(); return *instance; } + +class DynamicBackendWrapper : public WaitCounterBackendIf { + public: + explicit DynamicBackendWrapper(WaitCounterDynamicBackend impl) + : impl_{impl} {} + ~DynamicBackendWrapper() override { + impl_.destroy(impl_.self); + } + + intptr_t start(std::chrono::steady_clock::time_point now) noexcept override { + return impl_.start( + impl_.self, + std::chrono::duration_cast( + now.time_since_epoch()) + .count()); + } + + void stop(std::chrono::steady_clock::time_point now, intptr_t ctx) noexcept + override { + return impl_.stop( + impl_.self, + std::chrono::duration_cast( + now.time_since_epoch()) + .count(), + ctx); + } + + private: + WaitCounterDynamicBackend impl_; +}; + +std::unique_ptr getDynamicBackend(std::string_view key) { + static auto dynamicBackendInit = + reinterpret_cast([]() -> void* { +#ifndef _WIN32 + return dlsym( + RTLD_DEFAULT, + std::string(kWaitCounterDynamicBackendInitFn).c_str()); +#else + return nullptr; +#endif + }()); + if (!dynamicBackendInit) { + return nullptr; + } + WaitCounterDynamicBackend backend; + dynamicBackendInit(&backend, &key[0], key.size()); + if (!backend.self) { + return nullptr; + } + return std::make_unique(backend); +} } // namespace class WaitCounterImpl { @@ -70,6 +127,9 @@ class WaitCounterImpl { backends_.push_back(std::move(backend)); } } + if (auto backend = getDynamicBackend(key)) { + backends_.push_back(std::move(backend)); + } } SmallVector> backends_; @@ -80,6 +140,12 @@ void registerWaitCounterBackend( waitCounterBackendFactories().withLock( [&](auto& factories) { factories.push_back(std::move(factory)); }); } + +std::vector> +getRegisteredWaitCounterBackends() { + return waitCounterBackendFactories().withLock( + [](auto& factories) { return factories; }); +} } // namespace detail WaitCounterHandle::WaitCounterHandle(std::string_view key) diff --git a/c10/util/WaitCounter.h b/c10/util/WaitCounter.h index dba8f82f3ca365..504e88720a9c12 100644 --- a/c10/util/WaitCounter.h +++ b/c10/util/WaitCounter.h @@ -36,6 +36,9 @@ class WaitCounterBackendFactoryIf { C10_API void registerWaitCounterBackend( std::unique_ptr); + +C10_API std::vector> +getRegisteredWaitCounterBackends(); } // namespace detail // A handle to a wait counter. diff --git a/c10/util/WaitCounterDynamicBackend.h b/c10/util/WaitCounterDynamicBackend.h new file mode 100644 index 00000000000000..ecbdc7f09f0341 --- /dev/null +++ b/c10/util/WaitCounterDynamicBackend.h @@ -0,0 +1,21 @@ +#pragma once + +#include +#include + +namespace c10::monitor::detail { + +struct WaitCounterDynamicBackend { + void* self{nullptr}; + intptr_t (*start)(void* self, int64_t nowUs){nullptr}; + void (*stop)(void* self, int64_t nowUs, intptr_t ctx){nullptr}; + void (*destroy)(void* self){nullptr}; +}; + +using WaitCounterDynamicBackendInit = + void (*)(WaitCounterDynamicBackend*, const char* key, std::size_t keyLen); + +// This name needs to be updated if anything in the API above is changed. +constexpr std::string_view kWaitCounterDynamicBackendInitFn = + "c10_monitor_wait_counter_dynamic_backend_init_v1"; +} // namespace c10::monitor::detail diff --git a/c10/util/build.bzl b/c10/util/build.bzl index fb06a1517834d1..a6f95ae7516d36 100644 --- a/c10/util/build.bzl +++ b/c10/util/build.bzl @@ -43,6 +43,10 @@ def define_targets(rules): "//c10:using_glog": ["@com_github_glog//:glog"], "//conditions:default": [], }), + linkopts = rules.select({ + "@bazel_tools//src/conditions:windows": [], + "//conditions:default": ["-ldl"], + }), # This library uses flags and registration. Do not let the # linker remove them. alwayslink = True, From dbe83ff5a902a4dfb8914b4a8f497df5e86ed198 Mon Sep 17 00:00:00 2001 From: Tom Ritchford Date: Tue, 10 Sep 2024 13:33:02 +0000 Subject: [PATCH 0922/1018] Add decomposition for permute_copy (#130944) * Extracted from #129476 Pull Request resolved: https://github.com/pytorch/pytorch/pull/130944 Approved by: https://github.com/amjames, https://github.com/eellison --- test/distributed/_tensor/test_dtensor_ops.py | 1 + ...HasDecompTest.test_has_decomposition.expect | 2 -- test/functorch/test_ops.py | 2 ++ test/functorch/test_vmap.py | 1 + test/test_mps.py | 1 + tools/autograd/gen_variable_type.py | 1 + torch/_decomp/__init__.py | 1 + torch/_refs/__init__.py | 2 ++ .../_internal/common_methods_invocations.py | 18 ++++++++++++++++++ 9 files changed, 27 insertions(+), 2 deletions(-) diff --git a/test/distributed/_tensor/test_dtensor_ops.py b/test/distributed/_tensor/test_dtensor_ops.py index 532bd3facae554..c6fd0e01e5ffd9 100644 --- a/test/distributed/_tensor/test_dtensor_ops.py +++ b/test/distributed/_tensor/test_dtensor_ops.py @@ -370,6 +370,7 @@ def wrapped(fn): xfail("ormqr"), xfail("ones"), xfail("pca_lowrank"), + xfail("permute_copy"), xfail("pinverse"), xfail("polar"), xfail("put"), diff --git a/test/expect/HasDecompTest.test_has_decomposition.expect b/test/expect/HasDecompTest.test_has_decomposition.expect index a759903d065591..357a5afefd4d3e 100644 --- a/test/expect/HasDecompTest.test_has_decomposition.expect +++ b/test/expect/HasDecompTest.test_has_decomposition.expect @@ -1014,8 +1014,6 @@ aten::ones.names_out aten::ones.out aten::ormqr aten::ormqr.out -aten::permute_copy -aten::permute_copy.out aten::poisson aten::poisson.out aten::polygamma diff --git a/test/functorch/test_ops.py b/test/functorch/test_ops.py index 5cd18a072d8daf..9b04fb618ce616 100644 --- a/test/functorch/test_ops.py +++ b/test/functorch/test_ops.py @@ -1415,6 +1415,7 @@ def test_vmapjvpall(self, device, dtype, op): xfail("nn.functional.dropout3d", ""), xfail("as_strided_scatter", ""), xfail("masked.cumprod", ""), + xfail("permute_copy"), xfail("renorm"), # hit vmap fallback, which is disabled xfail("t_copy"), xfail("transpose_copy"), @@ -1479,6 +1480,7 @@ def test(): xfail("masked_select"), xfail("nanquantile"), xfail("ormqr"), + xfail("permute_copy"), xfail("put"), xfail("quantile"), xfail("renorm"), diff --git a/test/functorch/test_vmap.py b/test/functorch/test_vmap.py index 1bfc31fe521bb5..329ae0013bf293 100644 --- a/test/functorch/test_vmap.py +++ b/test/functorch/test_vmap.py @@ -4470,6 +4470,7 @@ def test_vmap_exhaustive(self, device, dtype, op): xfail("histc"), xfail("as_strided"), xfail("as_strided_copy"), + xfail("permute_copy"), xfail("t_copy"), xfail("unsqueeze_copy"), xfail("istft"), diff --git a/test/test_mps.py b/test/test_mps.py index 32adede8dddca0..20dae361ba8fcf 100644 --- a/test/test_mps.py +++ b/test/test_mps.py @@ -318,6 +318,7 @@ def mps_ops_modifier(ops): 'ones', 'outer', 'permute', + 'permute_copy', 'positive', 'randn', 'ravel', diff --git a/tools/autograd/gen_variable_type.py b/tools/autograd/gen_variable_type.py index fa6c578dea04a2..d1e9245cd72f9c 100644 --- a/tools/autograd/gen_variable_type.py +++ b/tools/autograd/gen_variable_type.py @@ -199,6 +199,7 @@ "transpose", "transpose_copy", "permute", + "permute_copy", "squeeze", "unsqueeze", "unsqueeze_copy", diff --git a/torch/_decomp/__init__.py b/torch/_decomp/__init__.py index a22289b75c4012..c21ee837471178 100644 --- a/torch/_decomp/__init__.py +++ b/torch/_decomp/__init__.py @@ -569,6 +569,7 @@ def _core_aten_decompositions_post_autograd() -> ( aten.norm, aten.ones, aten.ones_like, + aten.permute_copy, aten.pixel_shuffle, aten.pixel_unshuffle, aten._prelu_kernel, diff --git a/torch/_refs/__init__.py b/torch/_refs/__init__.py index 9b4d10cd5a6ad4..41b9031a304227 100644 --- a/torch/_refs/__init__.py +++ b/torch/_refs/__init__.py @@ -274,6 +274,7 @@ "native_group_norm", "native_layer_norm", "permute", + "permute_copy", "ravel", "repeat", "reshape", @@ -6335,6 +6336,7 @@ def select_scatter(x: TensorLikeType, src: TensorLikeType, dim: int, index: int) # TODO: This must return a sparse tensor if the input is sparse, but refs have # no sparse support. See narrow_copy_sparse in core. narrow_copy = _make_copy_from_view(aten.narrow) +permute_copy = _make_copy_from_view(aten.permute) t_copy = _make_copy_from_view(aten.t) transpose_copy = _make_copy_from_view(aten.transpose) unsqueeze_copy = _make_copy_from_view(aten.unsqueeze) diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py index 5f4b98aaa5eaac..9c23e438ba1c43 100644 --- a/torch/testing/_internal/common_methods_invocations.py +++ b/torch/testing/_internal/common_methods_invocations.py @@ -16854,6 +16854,19 @@ def sample_inputs_alias_copy(op_info, device, dtype, requires_grad, **kwargs): supports_varargs=True, sample_inputs_func=sample_inputs_permute, reference_inputs_func=reference_inputs_permute), + OpInfo('permute_copy', + dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.chalf), + supports_out=True, + assert_autodiffed=True, + assert_jit_shape_analysis=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + supports_varargs=False, # torch.permute is also not varargs + sample_inputs_func=sample_inputs_permute, + reference_inputs_func=reference_inputs_permute, + skips=( + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit', dtypes=(torch.float32,)), + )), BinaryUfuncInfo('pow', dtypes=all_types_and_complex_and(torch.half, torch.bfloat16), dtypesIfCUDA=all_types_and_complex_and(torch.half, torch.bfloat16, torch.chalf), @@ -23759,6 +23772,11 @@ def sample_inputs_alias_copy(op_info, device, dtype, requires_grad, **kwargs): "_refs.permute", torch_opinfo_name="permute", ), + PythonRefInfo( + "_refs.permute_copy", + torch_opinfo_name="permute_copy", + supports_out=True, + ), ElementwiseUnaryPythonRefInfo( "_refs.rad2deg", torch_opinfo_name="rad2deg", From 17f48c5ddea15b89871f753488448cf49053cb8d Mon Sep 17 00:00:00 2001 From: Howard Huang Date: Sun, 15 Sep 2024 19:36:28 +0000 Subject: [PATCH 0923/1018] Update link in distributed.tensor.parallel.rst (#136103) dtensor folder was moved Pull Request resolved: https://github.com/pytorch/pytorch/pull/136103 Approved by: https://github.com/kwen2501, https://github.com/fegin --- docs/source/distributed.tensor.parallel.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/distributed.tensor.parallel.rst b/docs/source/distributed.tensor.parallel.rst index de7525b32e2e2d..694212296e35b0 100644 --- a/docs/source/distributed.tensor.parallel.rst +++ b/docs/source/distributed.tensor.parallel.rst @@ -5,7 +5,7 @@ Tensor Parallelism - torch.distributed.tensor.parallel ====================================================== Tensor Parallelism(TP) is built on top of the PyTorch DistributedTensor -(`DTensor `__) +(`DTensor `__) and provides different parallelism styles: Colwise, Rowwise, and Sequence Parallelism. .. warning :: From 270ba69cef0d7ef4f607f4f0aa39535b7959abeb Mon Sep 17 00:00:00 2001 From: Ke Wen Date: Sun, 15 Sep 2024 11:22:03 -0700 Subject: [PATCH 0924/1018] [Distributed] add pack-check method for float8_e5m2 (#136115) Add support for Float8_e5m2, following similar algorithm used for Float8_e4m3fn (i.e. overflow check). Made `HasNanFP8x8` a template so that it is extendable based on dtype. Pull Request resolved: https://github.com/pytorch/pytorch/pull/136115 Approved by: https://github.com/Skylion007 ghstack dependencies: #135891, #135961 --- torch/csrc/distributed/c10d/NanCheck.cu | 46 ++++++++++++++++++++++--- 1 file changed, 41 insertions(+), 5 deletions(-) diff --git a/torch/csrc/distributed/c10d/NanCheck.cu b/torch/csrc/distributed/c10d/NanCheck.cu index 02a4b6e963a4b7..d256413d60a10f 100644 --- a/torch/csrc/distributed/c10d/NanCheck.cu +++ b/torch/csrc/distributed/c10d/NanCheck.cu @@ -79,15 +79,33 @@ struct CheckBytePack { } }; -// (v) Template specialization for Float8_e4m3fn. +// (v) Template specialization for Float8 types. // EltPerPack = 16 / 1 = 16 +// We want to check 8 x FP8 simultaneously, hence this template definition. +template +struct HasNanFP8x8 { + static __device__ __forceinline__ bool check(uint64_t fp8x8) = delete; + /* + { + // `static_assert` in template definition requires c++23 onwards. + // But the error message still applies if you find yourself here. + static_assert( + false, + "You should never call this template definition because it is empty. You " + "can follow the example of Float8_e4m3fn below to implement the check for " + "your new datatype." + ); + } + */ +}; + // isnan condition for Float8_e4m3fn: // (x & 0b01111111) == 0b01111111 // i.e. // (x & 0x7f) == 0x7f -// We want to check 8 x FP8 simultaneously. The algorithm is as follows: +// The algorithm is as follows: // (1) Mask out the most significant bit with mask 0x7f. // (2) If the result is 0x7f (is nan), the following arithmetic would cause the // 8th bit to be 1: x[i] = x[i] + 0x01 @@ -95,16 +113,34 @@ struct CheckBytePack { // (4) If any x[i] is nan, then the whole x != 0. template<> -struct CheckBytePack { - static __device__ __forceinline__ bool hasNanFP8x8(uint64_t fp8x8) { +struct HasNanFP8x8 { + static __device__ __forceinline__ bool check(uint64_t fp8x8) { auto t = fp8x8 & 0x7F7F7F7F7F7F7F7FULL; auto incremented = t + 0x0101010101010101ULL; auto overflow = incremented & 0x8080808080808080ULL; return overflow != 0; } +}; + +// isnan condition for Float8_e5m2: +// (x & 0x7f) > 0x7c +// This case does not overflow: 0x7c + 0x03 == 0x7f but adding 0x03 to anything +// greater than 0x7c will overflow. + +template<> +struct HasNanFP8x8 { + static __device__ __forceinline__ bool check(uint64_t fp8x8) { + auto t = fp8x8 & 0x7F7F7F7F7F7F7F7FULL; + auto incremented = t + 0x0303030303030303ULL; + auto overflow = incremented & 0x8080808080808080ULL; + return overflow != 0; + } +}; +template +struct CheckBytePack { static __device__ __forceinline__ void check(BytePack* tmp) { - if (hasNanFP8x8(tmp->ul[0]) || hasNanFP8x8(tmp->ul[1])) + if (HasNanFP8x8::check(tmp->ul[0]) || HasNanFP8x8::check(tmp->ul[1])) __trap(); } }; From 87f092da8506d1aed1c2289483e775062c8d2fa5 Mon Sep 17 00:00:00 2001 From: Xuehai Pan Date: Sat, 14 Sep 2024 18:52:14 +0800 Subject: [PATCH 0925/1018] [dynamo] simplify implementation for `functools.reduce` (#133778) Pull Request resolved: https://github.com/pytorch/pytorch/pull/133778 Approved by: https://github.com/jansel, https://github.com/anijain2305 --- torch/_dynamo/decorators.py | 10 ++++++- torch/_dynamo/polyfills/functools.py | 42 +++++++++++++++++++++++++++- torch/_dynamo/trace_rules.py | 14 +++++++++- torch/_dynamo/variables/builtin.py | 18 +++--------- 4 files changed, 67 insertions(+), 17 deletions(-) diff --git a/torch/_dynamo/decorators.py b/torch/_dynamo/decorators.py index 235db547dc7b1c..67d6c0f27a4c26 100644 --- a/torch/_dynamo/decorators.py +++ b/torch/_dynamo/decorators.py @@ -298,7 +298,10 @@ def sig_ident(sig): ) from torch._dynamo.guards import GuardBuilder - from torch._dynamo.trace_rules import get_torch_obj_rule_map + from torch._dynamo.trace_rules import ( + _polyfilled_function_ids, + get_torch_obj_rule_map, + ) from torch._dynamo.variables import PolyfilledFunctionVariable from torch._dynamo.variables.builder import VariableBuilder @@ -309,6 +312,9 @@ def sig_ident(sig): "already registered in VariableBuilder's id dispatch map" ) + if id(original_fn) in _polyfilled_function_ids: + raise ValueError(f"Duplicate polyfilled object {original_fn}") + rule_map: Dict[Any, Type[VariableTracker]] = get_torch_obj_rule_map() if original_fn in rule_map: raise ValueError( @@ -338,6 +344,8 @@ def dispatch_fn(self, value: _F) -> PolyfilledFunctionVariable: ) id_dispatch_map[id(original_fn)] = id_dispatch_map[id(wrapped)] = dispatch_fn + _polyfilled_function_ids.add(id(original_fn)) + _polyfilled_function_ids.add(id(wrapped)) rule_map[original_fn] = rule_map[wrapped] = PolyfilledFunctionVariable polyfill_handlers[original_fn] = polyfill_handlers[wrapped] = wrapped # type: ignore[assignment] diff --git a/torch/_dynamo/polyfills/functools.py b/torch/_dynamo/polyfills/functools.py index 98fd9597773b35..329725b41e354a 100644 --- a/torch/_dynamo/polyfills/functools.py +++ b/torch/_dynamo/polyfills/functools.py @@ -2,5 +2,45 @@ Python polyfills for functools """ +import functools +from typing import Callable, Iterable, TypeVar -__all__ = [] # type: ignore[var-annotated] +from ..decorators import substitute_in_graph + + +__all__ = ["reduce"] + + +_T = TypeVar("_T") +_U = TypeVar("_U") + + +class _INITIAL_MISSING: + pass + + +# Reference: https://docs.python.org/3/library/functools.html#functools.reduce +@substitute_in_graph(functools.reduce) +def reduce( + function: Callable[[_U, _T], _U], + iterable: Iterable[_T], + initial: _U = _INITIAL_MISSING, # type: ignore[assignment] + /, +) -> _U: + it = iter(iterable) + + value: _U + if initial is _INITIAL_MISSING: + try: + value = next(it) # type: ignore[assignment] + except StopIteration: + raise TypeError( + "reduce() of empty iterable with no initial value", + ) from None + else: + value = initial + + for element in it: + value = function(value, element) + + return value diff --git a/torch/_dynamo/trace_rules.py b/torch/_dynamo/trace_rules.py index ed81b26191b4de..44e1d662efe4af 100644 --- a/torch/_dynamo/trace_rules.py +++ b/torch/_dynamo/trace_rules.py @@ -2999,13 +2999,18 @@ def _builtin_function_ids() -> Dict[int, str]: rv.update( { id(cast): "typing.cast", - id(functools.reduce): "functools.reduce", id(copy.deepcopy): "copy.deepcopy", } ) return rv +@FunctionIdSet +def _polyfilled_function_ids() -> Set[int]: + # See also @torch._dynamo.decorators.substitute_in_graph(...), which adds items in _polyfilled_function_ids + return set() + + @FunctionIdSet def _numpy_function_ids() -> Dict[int, str]: rv = {} @@ -3081,6 +3086,11 @@ def is_builtin_constant(obj) -> bool: return id(obj) in _builtin_constant_ids +def is_polyfilled_callable(obj) -> bool: + # See also @torch._dynamo.decorators.substitute_in_graph(...), which adds items in _polyfilled_function_ids + return id(obj) in _polyfilled_function_ids + + def is_numpy(obj) -> bool: if np is None: return False @@ -3533,6 +3543,8 @@ def lookup_callable(obj): return SkipFunctionVariable if is_callable_allowed(obj): return TorchInGraphFunctionVariable + if is_polyfilled_callable(obj): + return PolyfilledFunctionVariable if is_builtin_callable(obj): return BuiltinVariable return None diff --git a/torch/_dynamo/variables/builtin.py b/torch/_dynamo/variables/builtin.py index b6ff05e429d12c..73792d5eca7e90 100644 --- a/torch/_dynamo/variables/builtin.py +++ b/torch/_dynamo/variables/builtin.py @@ -1605,7 +1605,10 @@ def call_sum(self, tx: "InstructionTranslator", seq, start=_SENTINEL): if start is self._SENTINEL: start = variables.ConstantVariable.create(0) items = seq.force_unpack_var_sequence(tx) - return BuiltinVariable(functools.reduce).call_function( + + from .builder import SourcelessBuilder + + return SourcelessBuilder.create(tx, functools.reduce).call_function( tx, [ BuiltinVariable(operator.add), @@ -1615,19 +1618,6 @@ def call_sum(self, tx: "InstructionTranslator", seq, start=_SENTINEL): {}, ) - def call_reduce( - self, tx: "InstructionTranslator", function, iterable, initial=_SENTINEL - ): - if iterable.has_force_unpack_var_sequence(tx): - items = iterable.force_unpack_var_sequence(tx) - if initial is self._SENTINEL: - value, items = items[0], items[1:] - else: - value = initial - for element in items: - value = function.call_function(tx, [value, element], {}) - return value - def call_getattr( self, tx: "InstructionTranslator", From 9717767c0f97ac0391eef5aaa4b931fe613ec674 Mon Sep 17 00:00:00 2001 From: Xuehai Pan Date: Sat, 14 Sep 2024 18:52:15 +0800 Subject: [PATCH 0926/1018] [dynamo] simplify implementation for `builtins.sum` (#133779) Pull Request resolved: https://github.com/pytorch/pytorch/pull/133779 Approved by: https://github.com/jansel, https://github.com/anijain2305 ghstack dependencies: #133778 --- torch/_dynamo/polyfills/builtins.py | 8 +++++++ torch/_dynamo/variables/builtin.py | 34 ----------------------------- 2 files changed, 8 insertions(+), 34 deletions(-) diff --git a/torch/_dynamo/polyfills/builtins.py b/torch/_dynamo/polyfills/builtins.py index 737b7c13809183..62305086b804fe 100644 --- a/torch/_dynamo/polyfills/builtins.py +++ b/torch/_dynamo/polyfills/builtins.py @@ -5,6 +5,8 @@ from __future__ import annotations import builtins +import functools +import operator from typing import Iterable, TypeVar from ..decorators import substitute_in_graph @@ -14,6 +16,7 @@ "all", "any", "enumerate", + "sum", ] @@ -46,3 +49,8 @@ def enumerate(iterable: Iterable[_T], start: int = 0) -> Iterable[tuple[int, _T] for x in iterable: yield start, x start += 1 + + +@substitute_in_graph(builtins.sum, can_constant_fold_through=True) # type: ignore[arg-type] +def sum(iterable: Iterable[_T], /, start: _T = 0) -> _T: # type: ignore[assignment] + return functools.reduce(operator.add, iterable, start) diff --git a/torch/_dynamo/variables/builtin.py b/torch/_dynamo/variables/builtin.py index 73792d5eca7e90..296eb646187c12 100644 --- a/torch/_dynamo/variables/builtin.py +++ b/torch/_dynamo/variables/builtin.py @@ -1584,40 +1584,6 @@ def call_filter(self, tx: "InstructionTranslator", fn, seq): except NotImplementedError: return - def call_sum(self, tx: "InstructionTranslator", seq, start=_SENTINEL): - # Special case for sum on tuple of floats and ints - if isinstance(seq, (variables.ListVariable, variables.TupleVariable)) and all( - isinstance(x, variables.ConstantVariable) - and isinstance(x.value, (int, float)) - for x in seq.items - ): - if start is self._SENTINEL: - return variables.ConstantVariable.create( - sum(x.value for x in seq.items), - ) - if isinstance(start, variables.ConstantVariable) and isinstance( - start.value, (int, float) - ): - return variables.ConstantVariable.create( - sum((x.value for x in seq.items), start=start.value), - ) - if seq.has_force_unpack_var_sequence(tx): - if start is self._SENTINEL: - start = variables.ConstantVariable.create(0) - items = seq.force_unpack_var_sequence(tx) - - from .builder import SourcelessBuilder - - return SourcelessBuilder.create(tx, functools.reduce).call_function( - tx, - [ - BuiltinVariable(operator.add), - variables.TupleVariable(items), - start, - ], - {}, - ) - def call_getattr( self, tx: "InstructionTranslator", From 2e23a4b6ccb8ba763033d53502193841be20cff9 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Mon, 16 Sep 2024 09:11:16 +0000 Subject: [PATCH 0927/1018] Revert "[BE]: Update mypy to 1.11.2 (#133816)" This reverts commit 55299cfc223fa838aadd8d6d6fa3ed541fa5acd1. Reverted https://github.com/pytorch/pytorch/pull/133816 on behalf of https://github.com/jeanschmidt due to seems to have broken https://github.com/pytorch/pytorch/actions/runs/10865710499/job/30155699792 on main ([comment](https://github.com/pytorch/pytorch/pull/133816#issuecomment-2352377684)) --- .ci/docker/requirements-ci.txt | 2 +- .lintrunner.toml | 2 +- torch/_dynamo/repro/after_aot.py | 2 +- torch/_dynamo/repro/after_dynamo.py | 9 +- torch/_dynamo/resume_execution.py | 14 +-- torch/_dynamo/source.py | 2 +- torch/_dynamo/symbolic_convert.py | 2 +- torch/_dynamo/utils.py | 2 - torch/_export/passes/constant_folding.py | 2 +- torch/_higher_order_ops/flex_attention.py | 2 +- torch/_inductor/codegen/cpp_wrapper_cpu.py | 3 +- .../codegen/cuda_combined_scheduling.py | 2 +- torch/_inductor/dependencies.py | 10 +- torch/_inductor/kernel/mm_common.py | 4 +- torch/_inductor/remote_cache.py | 4 +- torch/_inductor/runtime/triton_heuristics.py | 4 +- torch/_inductor/scheduler.py | 4 +- torch/_library/infer_schema.py | 2 +- torch/_subclasses/functional_tensor.py | 4 +- torch/_tensor.py | 2 +- torch/ao/nn/intrinsic/modules/fused.py | 4 +- .../ao/nn/intrinsic/qat/modules/conv_fused.py | 93 ++++++++++--------- .../intrinsic/quantized/modules/conv_add.py | 4 +- torch/ao/nn/qat/modules/conv.py | 18 ++-- torch/ao/nn/quantized/dynamic/modules/conv.py | 25 +++-- torch/ao/nn/quantized/modules/conv.py | 41 ++++---- torch/ao/ns/_numeric_suite.py | 2 +- torch/ao/ns/_numeric_suite_fx.py | 2 +- torch/ao/ns/fx/mappings.py | 6 +- .../data_sparsifier/base_data_sparsifier.py | 6 +- .../data_sparsifier/data_norm_sparsifier.py | 2 +- .../_experimental/pruner/FPGM_pruner.py | 4 +- .../sparsifier/nearly_diagonal_sparsifier.py | 4 +- .../sparsifier/weight_norm_sparsifier.py | 2 +- torch/ao/quantization/__init__.py | 2 +- torch/ao/quantization/_correct_bias.py | 2 +- .../ao/quantization/experimental/observer.py | 2 +- .../quantizer/x86_inductor_quantizer.py | 2 +- torch/distributed/_functional_collectives.py | 4 +- .../distributed/_shard/sharded_tensor/api.py | 2 +- .../algorithms/_comm_hooks/default_hooks.py | 2 - .../optim/post_localSGD_optimizer.py | 2 +- torch/distributed/pipelining/_IR.py | 5 +- torch/distributed/tensor/_api.py | 2 +- torch/distributed/tensor/_shards_wrapper.py | 2 +- .../tensor/parallel/input_reshard.py | 1 - torch/distributions/constraints.py | 2 +- torch/distributions/von_mises.py | 2 +- torch/masked/maskedtensor/core.py | 4 +- torch/multiprocessing/reductions.py | 2 +- torch/nested/_internal/nested_tensor.py | 2 +- torch/nn/attention/bias.py | 2 +- torch/optim/lr_scheduler.py | 4 +- torch/sparse/semi_structured.py | 4 +- torch/utils/_sympy/functions.py | 22 ++--- torch/utils/weak.py | 2 +- 56 files changed, 180 insertions(+), 185 deletions(-) diff --git a/.ci/docker/requirements-ci.txt b/.ci/docker/requirements-ci.txt index bd48d696c9230b..88ecc0d1a7edf3 100644 --- a/.ci/docker/requirements-ci.txt +++ b/.ci/docker/requirements-ci.txt @@ -90,7 +90,7 @@ librosa>=0.6.2 ; python_version < "3.11" #Pinned versions: #test that import: -mypy==1.11.2 +mypy==1.10.0 # Pin MyPy version because new errors are likely to appear with each release #Description: linter #Pinned versions: 1.10.0 diff --git a/.lintrunner.toml b/.lintrunner.toml index 9b43b382809212..4789991736be17 100644 --- a/.lintrunner.toml +++ b/.lintrunner.toml @@ -139,7 +139,7 @@ init_command = [ 'numpy==1.24.3 ; python_version == "3.8"', 'numpy==1.26.0 ; python_version >= "3.9"', 'expecttest==0.2.1', - 'mypy==1.11.2', + 'mypy==1.10.0', 'sympy==1.12.1 ; python_version == "3.8"', 'sympy==1.13.0 ; python_version >= "3.9"', 'types-requests==2.27.25', diff --git a/torch/_dynamo/repro/after_aot.py b/torch/_dynamo/repro/after_aot.py index 19aa69d084db36..a21a20cdf2e611 100644 --- a/torch/_dynamo/repro/after_aot.py +++ b/torch/_dynamo/repro/after_aot.py @@ -521,7 +521,7 @@ def repro_common(options, mod, load_args): def repro_minifier_query(options, mod, load_args): mod, args = repro_common(options, mod, load_args) fail_fn = functools.partial( - ACCURACY_FAILS[options.accuracy], check_str=options.check_str # type: ignore[call-arg] + ACCURACY_FAILS[options.accuracy], check_str=options.check_str ) if fail_fn(mod, args): sys.exit(1) diff --git a/torch/_dynamo/repro/after_dynamo.py b/torch/_dynamo/repro/after_dynamo.py index 214480b02c9332..b1bb950d11c11b 100644 --- a/torch/_dynamo/repro/after_dynamo.py +++ b/torch/_dynamo/repro/after_dynamo.py @@ -12,7 +12,6 @@ import torch import torch.fx as fx -from torch._dynamo.backends.registry import CompiledFn from torch._dynamo.debug_utils import ( AccuracyError, backend_accuracy_fails, @@ -272,10 +271,8 @@ def dump_to_minify_after_dynamo(gm, args, compiler_name): # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # -@register_debug_backend # type: ignore[arg-type] -def dynamo_minifier_backend( - gm: fx.GraphModule, example_inputs, compiler_name: CompiledFn -): +@register_debug_backend +def dynamo_minifier_backend(gm, example_inputs, compiler_name): from functorch.compile import minifier compiler_fn = lookup_backend(compiler_name) @@ -314,7 +311,7 @@ def dynamo_minifier_backend( return gm -@register_debug_backend # type: ignore[arg-type] +@register_debug_backend def dynamo_accuracy_minifier_backend(gm, example_inputs, compiler_name): from functorch.compile import minifier diff --git a/torch/_dynamo/resume_execution.py b/torch/_dynamo/resume_execution.py index 2013b3992eb54f..6f7db66514c487 100644 --- a/torch/_dynamo/resume_execution.py +++ b/torch/_dynamo/resume_execution.py @@ -471,14 +471,14 @@ def generate( code, lineno, offset: int, - setup_fn_target_offsets: Tuple[int, ...], # only used in Python 3.11+ + setup_fn_target_offsets: Tuple[int], # only used in Python 3.11+ nstack: int, - argnames: Tuple[str, ...], - argnames_null: Tuple[str, ...], - setup_fns: Tuple[ReenterWith, ...], - stack_ctx_vars: Tuple[Tuple[int, Tuple[Any]], ...], - argnames_ctx_vars: Tuple[Tuple[str, Tuple[Any]], ...], - null_idxes: Tuple[int, ...], + argnames: Tuple[str], + argnames_null: Tuple[str], + setup_fns: Tuple[ReenterWith], + stack_ctx_vars: Tuple[int, Tuple[Any]], + argnames_ctx_vars: Tuple[str, Tuple[Any]], + null_idxes: Tuple[int], ) -> types.CodeType: assert offset is not None assert not ( diff --git a/torch/_dynamo/source.py b/torch/_dynamo/source.py index 9febc69f42cda2..c44f8d9e69abf5 100644 --- a/torch/_dynamo/source.py +++ b/torch/_dynamo/source.py @@ -272,7 +272,7 @@ def guard_source(self): def name(self): return f"" - def make_guard(self, fn): + def make_guard(self): raise NotImplementedError def is_ephemeral(self): diff --git a/torch/_dynamo/symbolic_convert.py b/torch/_dynamo/symbolic_convert.py index 4a83b455a1252b..61acaec995757a 100644 --- a/torch/_dynamo/symbolic_convert.py +++ b/torch/_dynamo/symbolic_convert.py @@ -3198,7 +3198,7 @@ def fake_mode(self): def run_ctx_mgr(self): return TracingContext.current_frame(self.parent.frame_summary()) - def STORE_DEREF(self, inst): # type: ignore[override] + def STORE_DEREF(self, inst): if inst.argval in self.closure_cells: cell = self.closure_cells[inst.argval] val = self.pop() diff --git a/torch/_dynamo/utils.py b/torch/_dynamo/utils.py index bed3564ec15ce0..2d4a771166cfbe 100644 --- a/torch/_dynamo/utils.py +++ b/torch/_dynamo/utils.py @@ -2915,8 +2915,6 @@ def is_frozen_dataclass(value): not object_has_getattribute(value) and not class_has_getattribute(value) and is_dataclass(value) - and hasattr(value, "__dataclass_params__") - and hasattr(value.__dataclass_params__, "frozen") and value.__dataclass_params__.frozen ) diff --git a/torch/_export/passes/constant_folding.py b/torch/_export/passes/constant_folding.py index 083cd69a970df6..b1491ca5d47946 100644 --- a/torch/_export/passes/constant_folding.py +++ b/torch/_export/passes/constant_folding.py @@ -195,7 +195,7 @@ def insertable_tensor_check(self, tensor: torch.Tensor) -> bool: def add_node_replacement(self, node: torch.fx.Node, tensor: torch.Tensor) -> None: self.node_replacements[node] = tensor - def run(self): # type: ignore[override] + def run(self): env = {} for n in self.module.graph.find_nodes(op="placeholder"): env[n] = self.unknown_value diff --git a/torch/_higher_order_ops/flex_attention.py b/torch/_higher_order_ops/flex_attention.py index ed99ae7beca12a..b28f657b9d5d8f 100644 --- a/torch/_higher_order_ops/flex_attention.py +++ b/torch/_higher_order_ops/flex_attention.py @@ -77,7 +77,7 @@ class TransformGetItemToIndex(TorchFunctionMode): # scalar and create a view. We do not want that behavior in this case, so we # use this torchfunctionmode to override that behavior for score_mod # wherever we're running it. - def __torch_function__(self, func, types, args=(), kwargs=None): + def __torch_function__(self, func, types, args, kwargs=None): if func == torch.Tensor.__getitem__: index_args = pytree.tree_leaves(args[1]) if all(isinstance(x, torch.Tensor) for x in index_args): diff --git a/torch/_inductor/codegen/cpp_wrapper_cpu.py b/torch/_inductor/codegen/cpp_wrapper_cpu.py index 4e88a6f91f7e22..c009246243f3cf 100644 --- a/torch/_inductor/codegen/cpp_wrapper_cpu.py +++ b/torch/_inductor/codegen/cpp_wrapper_cpu.py @@ -1480,8 +1480,7 @@ def generate_profiler_mark_wrapper_call(self, stack): 'RECORD_FUNCTION("inductor_wrapper_call", c10::ArrayRef());' ) - @cache_on_self - def write_triton_header_once(self) -> None: + def write_triton_header_once(self): pass def generate_start_graph(self): diff --git a/torch/_inductor/codegen/cuda_combined_scheduling.py b/torch/_inductor/codegen/cuda_combined_scheduling.py index 9273eeb5f9824f..90971045379656 100644 --- a/torch/_inductor/codegen/cuda_combined_scheduling.py +++ b/torch/_inductor/codegen/cuda_combined_scheduling.py @@ -30,7 +30,7 @@ def __init__(self, scheduler: Scheduler) -> None: self._cuda_cpp_scheduling = CUDACPPScheduling(scheduler) self._rocm_cpp_scheduling = ROCmCPPScheduling(scheduler) - def get_backend_features(self, device): # type:ignore[override] + def get_backend_features(self, device): return self._triton_scheduling.get_backend_features(device) def choose_node_backend(self, node: BaseSchedulerNode) -> BaseScheduling: diff --git a/torch/_inductor/dependencies.py b/torch/_inductor/dependencies.py index 643bf686bd8f19..95a48281977c87 100644 --- a/torch/_inductor/dependencies.py +++ b/torch/_inductor/dependencies.py @@ -578,22 +578,22 @@ def extract_read_writes( repl = {v: sympy.Symbol(f"tmp{i}") for i, v in enumerate(fn.indirect_vars)} name_to_index = {k: sympy_subs(v, repl) for k, v in name_to_index.items()} for entry in fn.memory_usage[MemoryUsageType.LOAD]: - inner.load(entry.buffer_name, name_to_index[entry.index_name]) # type: ignore[arg-type] + inner.load(entry.buffer_name, name_to_index[entry.index_name]) for entry in fn.memory_usage[MemoryUsageType.LOAD_SEED]: - inner.load_seed(entry.buffer_name, int(name_to_index[entry.index_name])) # type: ignore[arg-type] + inner.load_seed(entry.buffer_name, int(name_to_index[entry.index_name])) for entry in fn.memory_usage[MemoryUsageType.STORE]: inner.store( - entry.buffer_name, name_to_index[entry.index_name], None, entry.mode # type: ignore[arg-type] + entry.buffer_name, name_to_index[entry.index_name], None, entry.mode ) for entry in fn.memory_usage[MemoryUsageType.STORE_REDUCTION]: inner.store_reduction( - entry.buffer_name, name_to_index[entry.index_name], None # type: ignore[arg-type] + entry.buffer_name, name_to_index[entry.index_name], None ) for entry in fn.memory_usage[MemoryUsageType.INDEX_EXPR]: inner.index_expr(name_to_index[entry.index_name], None) for entry in fn.memory_usage[MemoryUsageType.BUCKETIZE]: inner.bucketize( - None, entry.buffer_name, name_to_index[entry.index_name], None, None # type: ignore[arg-type] + None, entry.buffer_name, name_to_index[entry.index_name], None, None ) # fn.memory_usage[MemoryUsageType.CHECK_BOUNDS] intentionally skipped else: diff --git a/torch/_inductor/kernel/mm_common.py b/torch/_inductor/kernel/mm_common.py index 21ba6c1e215dbe..0c66da4ca4f314 100644 --- a/torch/_inductor/kernel/mm_common.py +++ b/torch/_inductor/kernel/mm_common.py @@ -2,7 +2,7 @@ import functools import itertools import logging -from typing import cast, Sequence, Tuple +from typing import cast, List, Tuple import sympy @@ -28,7 +28,7 @@ def filtered_configs( m: int, n: int, k: int, - configs: Sequence[Tuple[int, int, int, int, int]], + configs: List[Tuple[int, int, int, int, int]], has_int8_tensor=False, ): """Heuristic to shrink configs when they are bigger than the input size""" diff --git a/torch/_inductor/remote_cache.py b/torch/_inductor/remote_cache.py index 056e24f1b4e26c..14963a1bf5d435 100644 --- a/torch/_inductor/remote_cache.py +++ b/torch/_inductor/remote_cache.py @@ -101,8 +101,8 @@ def put(self, key: str, value: _T) -> None: self._put(key, value, sample) self._log_sample(sample) - def _decode(self, data: _U, sample: Optional[Sample]) -> _T: # type: ignore[override] - return self.serde.decode(data) # type: ignore[arg-type] + def _decode(self, data: _U, sample: Optional[Sample]) -> _T: + return self.serde.decode(data) def _encode(self, value: _T, sample: Optional[Sample]) -> Any: # returns _U return self.serde.encode(value) diff --git a/torch/_inductor/runtime/triton_heuristics.py b/torch/_inductor/runtime/triton_heuristics.py index d011acdba1310d..e76e77b3cccf89 100644 --- a/torch/_inductor/runtime/triton_heuristics.py +++ b/torch/_inductor/runtime/triton_heuristics.py @@ -830,7 +830,7 @@ def benchmark_one_config(config): ) return config2launcher.get(best_config) - def run(self, *args, grid, stream, **kwargs): # type:ignore[override] + def run(self, *args, grid, stream, **kwargs): if len(self.launchers) != 1: if len(self.launchers) == 0: start_time = time.time_ns() @@ -961,7 +961,7 @@ def __init__(self, *args, regex_filter="", with_profiler=False, **kwargs): super().__init__(*args, **kwargs) self.cached = None - def run(self, *args, grid, stream): # type: ignore[override] + def run(self, *args, grid, stream): possible_names = _find_names(self) kernel_name = f"{max(possible_names, key=len)}" if not re.match(self.regex_filter, kernel_name): diff --git a/torch/_inductor/scheduler.py b/torch/_inductor/scheduler.py index 480ee83cef4536..0c3e69efa5eecf 100644 --- a/torch/_inductor/scheduler.py +++ b/torch/_inductor/scheduler.py @@ -2857,7 +2857,9 @@ def has_shared_data_after_reordering_loop( return False # Pick the largest buffer to guide the loop reordering - numel, lhs_dep, rhs_dep = max(candidates, key=lambda x: x[0]) + numel, lhs_dep, rhs_dep = sorted(candidates, reverse=True, key=lambda x: x[0])[ + 0 + ] if lhs_dep.num_vars != rhs_dep.num_vars: # this can happen due to we don't merge loops. diff --git a/torch/_library/infer_schema.py b/torch/_library/infer_schema.py index 9845a64874803a..b2eeb24521d382 100644 --- a/torch/_library/infer_schema.py +++ b/torch/_library/infer_schema.py @@ -268,4 +268,4 @@ def tuple_to_list(tuple_type: typing.Type[typing.Tuple]) -> typing.Type[typing.L elif len(type_args) == 2 and type_args[1] is Ellipsis: # type: ignore[valid-type] return typing.List[type_args[0]] # type: ignore[valid-type] else: - return typing.List[typing.Union[tuple(type_args)]] # type: ignore[misc, return-value] + return typing.List[typing.Union[tuple(type_args)]] # type: ignore[misc] diff --git a/torch/_subclasses/functional_tensor.py b/torch/_subclasses/functional_tensor.py index cd5bfa655ea0b5..bec1ebc559c328 100644 --- a/torch/_subclasses/functional_tensor.py +++ b/torch/_subclasses/functional_tensor.py @@ -221,7 +221,7 @@ def __torch_dispatch__(self, func, types, args=(), kwargs=None): "Attempting to use FunctionalTensor on its own. Instead, please use it with a corresponding FunctionalTensorMode()" ) - def __repr__(self) -> str: # type: ignore[override] + def __repr__(self): return f"FunctionalTensor({repr(self.elem)})" @staticmethod @@ -302,7 +302,7 @@ def cuda(self, device=None, *args, **kwargs): long = _conversion_method_template(dtype=torch.int64) # TODO(sparse-team): fixes #133174 but can we do without the relay? - def to_dense(self): # type: ignore[override] + def to_dense(self): return self.elem.to_dense() @property diff --git a/torch/_tensor.py b/torch/_tensor.py index e22c6a92d92d0f..7d8010081abc77 100644 --- a/torch/_tensor.py +++ b/torch/_tensor.py @@ -1353,7 +1353,7 @@ def align_to(self, *names): [name for name in names if not is_ellipsis(name)], ellipsis_idx ) - def unflatten(self, dim, sizes): # type: ignore[override] + def unflatten(self, dim, sizes): r""" unflatten(dim, sizes) -> Tensor diff --git a/torch/ao/nn/intrinsic/modules/fused.py b/torch/ao/nn/intrinsic/modules/fused.py index c7ae2ce3319d3c..010b9a701a3353 100644 --- a/torch/ao/nn/intrinsic/modules/fused.py +++ b/torch/ao/nn/intrinsic/modules/fused.py @@ -228,7 +228,7 @@ def __init__(self, conv, add): super().__init__(conv) self.add = add - def forward(self, x1, x2): # type: ignore[override] + def forward(self, x1, x2): return self.add(self[0](x1), x2) @@ -241,5 +241,5 @@ def __init__(self, conv, add, relu): self.add = add self.relu = relu - def forward(self, x1, x2): # type: ignore[override] + def forward(self, x1, x2): return self.relu(self.add(self[0](x1), x2)) diff --git a/torch/ao/nn/intrinsic/qat/modules/conv_fused.py b/torch/ao/nn/intrinsic/qat/modules/conv_fused.py index 6ceee89757ca75..3c700b30f24720 100644 --- a/torch/ao/nn/intrinsic/qat/modules/conv_fused.py +++ b/torch/ao/nn/intrinsic/qat/modules/conv_fused.py @@ -1,6 +1,6 @@ # mypy: allow-untyped-defs import math -from typing import ClassVar, Optional, Type +from typing import TypeVar import torch import torch.ao.nn.intrinsic as nni @@ -33,9 +33,12 @@ } +MOD = TypeVar("MOD", bound=nn.modules.conv._ConvNd) + + class _ConvBnNd(nn.modules.conv._ConvNd, nni._FusedModule): _version = 2 - _FLOAT_MODULE: ClassVar[Type[nn.modules.conv._ConvNd]] + _FLOAT_MODULE = MOD def __init__( self, @@ -362,7 +365,7 @@ def from_float(cls, mod, use_precomputed_fake_quant=False): assert hasattr(mod, "qconfig"), "Input float module must have qconfig defined" assert mod.qconfig, "Input float module must have a valid qconfig" qconfig = mod.qconfig - conv, bn = mod[0], mod[1] # type: ignore[index] + conv, bn = mod[0], mod[1] qat_convbn = cls( conv.in_channels, conv.out_channels, @@ -431,7 +434,7 @@ def to_float(self): return conv -class ConvBn1d(_ConvBnNd, nn.Conv1d): # type: ignore[misc] +class ConvBn1d(_ConvBnNd, nn.Conv1d): r""" A ConvBn1d module is a module fused from Conv1d and BatchNorm1d, attached with FakeQuantize modules for weight, @@ -448,10 +451,10 @@ class ConvBn1d(_ConvBnNd, nn.Conv1d): # type: ignore[misc] weight_fake_quant: fake quant module for weight """ - _FLOAT_BN_MODULE: ClassVar[Type[nn.BatchNorm1d]] = nn.BatchNorm1d - _FLOAT_RELU_MODULE: ClassVar[Optional[Type[nn.Module]]] = None - _FLOAT_MODULE: ClassVar[Type[nn.Module]] = nni.ConvBn1d # type: ignore[assignment,misc] - _FLOAT_CONV_MODULE: ClassVar[Type[nn.Conv1d]] = nn.Conv1d + _FLOAT_BN_MODULE = nn.BatchNorm1d + _FLOAT_RELU_MODULE: None = None + _FLOAT_MODULE = nni.ConvBn1d + _FLOAT_CONV_MODULE = nn.Conv1d def __init__( self, @@ -517,12 +520,12 @@ class ConvBnReLU1d(ConvBn1d): """ # base class defines _FLOAT_MODULE as "ConvBn1d" - _FLOAT_MODULE: ClassVar[Type[nn.Module]] = nni.ConvBnReLU1d # type: ignore[assignment,misc] - _FLOAT_CONV_MODULE: ClassVar[Type[nn.Conv1d]] = nn.Conv1d - _FLOAT_BN_MODULE: ClassVar[Type[nn.BatchNorm1d]] = nn.BatchNorm1d - _FLOAT_RELU_MODULE: ClassVar[Optional[Type[nn.Module]]] = nn.ReLU + _FLOAT_MODULE = nni.ConvBnReLU1d # type: ignore[assignment] + _FLOAT_CONV_MODULE = nn.Conv1d + _FLOAT_BN_MODULE = nn.BatchNorm1d + _FLOAT_RELU_MODULE = nn.ReLU # type: ignore[assignment] # module class after fusing bn into conv - _FUSED_FLOAT_MODULE: ClassVar[Optional[Type[nn.Module]]] = nni.ConvReLU1d + _FUSED_FLOAT_MODULE = nni.ConvReLU1d def __init__( self, @@ -582,10 +585,10 @@ class ConvReLU1d(nnqat.Conv1d, nni._FusedModule): weight_fake_quant: fake quant module for weight """ - _FLOAT_MODULE: ClassVar[Type[nni.ConvReLU1d]] = nni.ConvReLU1d # type: ignore[assignment, misc] - _FLOAT_CONV_MODULE: ClassVar[Type[nn.Conv1d]] = nn.Conv1d - _FLOAT_BN_MODULE: ClassVar[Optional[Type[nn.Module]]] = None - _FLOAT_RELU_MODULE: ClassVar[Optional[Type[nn.Module]]] = nn.ReLU + _FLOAT_MODULE = nni.ConvReLU1d # type: ignore[assignment] + _FLOAT_CONV_MODULE = nn.Conv1d + _FLOAT_BN_MODULE: None = None + _FLOAT_RELU_MODULE = nn.ReLU def __init__( self, @@ -628,7 +631,7 @@ def from_float(cls, mod, use_precomputed_fake_quant=False): ) -class ConvBn2d(_ConvBnNd, nn.Conv2d): # type: ignore[misc] +class ConvBn2d(_ConvBnNd, nn.Conv2d): r""" A ConvBn2d module is a module fused from Conv2d and BatchNorm2d, attached with FakeQuantize modules for weight, @@ -645,10 +648,10 @@ class ConvBn2d(_ConvBnNd, nn.Conv2d): # type: ignore[misc] weight_fake_quant: fake quant module for weight """ - _FLOAT_MODULE: ClassVar[Type[nni.ConvBn2d]] = nni.ConvBn2d # type: ignore[assignment,misc] - _FLOAT_CONV_MODULE: ClassVar[Type[nn.Conv2d]] = nn.Conv2d - _FLOAT_BN_MODULE: ClassVar[Optional[Type[nn.Module]]] = nn.BatchNorm2d - _FLOAT_RELU_MODULE: ClassVar[Optional[Type[nn.Module]]] = None + _FLOAT_MODULE = nni.ConvBn2d + _FLOAT_CONV_MODULE = nn.Conv2d + _FLOAT_BN_MODULE = nn.BatchNorm2d + _FLOAT_RELU_MODULE: None = None def __init__( self, @@ -714,12 +717,12 @@ class ConvBnReLU2d(ConvBn2d): """ # base class defines _FLOAT_MODULE as "ConvBn2d" - _FLOAT_MODULE: ClassVar[Type[nni.ConvBnReLU2d]] = nni.ConvBnReLU2d # type: ignore[assignment, misc] - _FLOAT_CONV_MODULE: ClassVar[Type[nn.Conv2d]] = nn.Conv2d - _FLOAT_BN_MODULE: ClassVar[Type[nn.BatchNorm2d]] = nn.BatchNorm2d - _FLOAT_RELU_MODULE: ClassVar[Optional[Type[nn.Module]]] = nn.ReLU # type: ignore[assignment,misc] + _FLOAT_MODULE = nni.ConvBnReLU2d # type: ignore[assignment] + _FLOAT_CONV_MODULE = nn.Conv2d + _FLOAT_BN_MODULE = nn.BatchNorm2d + _FLOAT_RELU_MODULE = nn.ReLU # type: ignore[assignment] # module class after fusing bn into conv - _FUSED_FLOAT_MODULE: ClassVar[Optional[Type[nni.ConvReLU2d]]] = nni.ConvReLU2d + _FUSED_FLOAT_MODULE = nni.ConvReLU2d def __init__( self, @@ -779,10 +782,10 @@ class ConvReLU2d(nnqat.Conv2d, nni._FusedModule): weight_fake_quant: fake quant module for weight """ - _FLOAT_MODULE: ClassVar[Type[nn.Module]] = nni.ConvReLU2d # type: ignore[assignment, misc] - _FLOAT_CONV_MODULE: ClassVar[Type[nn.Conv2d]] = nn.Conv2d - _FLOAT_BN_MODULE: ClassVar[Optional[Type[nn.Module]]] = None - _FLOAT_RELU_MODULE: ClassVar[Optional[Type[nn.Module]]] = nn.ReLU + _FLOAT_MODULE = nni.ConvReLU2d # type: ignore[assignment] + _FLOAT_CONV_MODULE = nn.Conv2d + _FLOAT_BN_MODULE: None = None + _FLOAT_RELU_MODULE = nn.ReLU def __init__( self, @@ -825,7 +828,7 @@ def from_float(cls, mod, use_precomputed_fake_quant=False): ) -class ConvBn3d(_ConvBnNd, nn.Conv3d): # type: ignore[misc] +class ConvBn3d(_ConvBnNd, nn.Conv3d): r""" A ConvBn3d module is a module fused from Conv3d and BatchNorm3d, attached with FakeQuantize modules for weight, @@ -842,10 +845,10 @@ class ConvBn3d(_ConvBnNd, nn.Conv3d): # type: ignore[misc] weight_fake_quant: fake quant module for weight """ - _FLOAT_MODULE: ClassVar[Type[nni.ConvBn3d]] = nni.ConvBn3d # type: ignore[assignment,misc] - _FLOAT_CONV_MODULE: ClassVar[Type[nn.Conv3d]] = nn.Conv3d - _FLOAT_BN_MODULE: ClassVar[Optional[Type[nn.Module]]] = nn.BatchNorm3d - _FLOAT_RELU_MODULE: ClassVar[Optional[Type[nn.Module]]] = None + _FLOAT_MODULE = nni.ConvBn3d + _FLOAT_CONV_MODULE = nn.Conv3d + _FLOAT_BN_MODULE = nn.BatchNorm3d + _FLOAT_RELU_MODULE: None = None def __init__( self, @@ -910,12 +913,12 @@ class ConvBnReLU3d(ConvBn3d): weight_fake_quant: fake quant module for weight """ - _FLOAT_MODULE: ClassVar[Type[nni.ConvBnReLU3d]] = nni.ConvBnReLU3d # type: ignore[assignment, misc] - _FLOAT_CONV_MODULE: ClassVar[Type[nn.Conv3d]] = nn.Conv3d - _FLOAT_BN_MODULE: ClassVar[Type[nn.BatchNorm3d]] = nn.BatchNorm3d - _FLOAT_RELU_MODULE: ClassVar[Optional[Type[nn.ReLU]]] = nn.ReLU # type: ignore[assignment, misc] + _FLOAT_MODULE = nni.ConvBnReLU3d # type: ignore[assignment] + _FLOAT_CONV_MODULE = nn.Conv3d + _FLOAT_BN_MODULE = nn.BatchNorm3d + _FLOAT_RELU_MODULE = nn.ReLU # type: ignore[assignment] # module class after fusing bn into conv - _FUSED_FLOAT_MODULE: ClassVar[Optional[Type[nni.ConvReLU3d]]] = nni.ConvReLU3d + _FUSED_FLOAT_MODULE = nni.ConvReLU3d def __init__( self, @@ -977,10 +980,10 @@ class ConvReLU3d(nnqat.Conv3d, nni._FusedModule): weight_fake_quant: fake quant module for weight """ - _FLOAT_MODULE: ClassVar[Type[nni.ConvReLU3d]] = nni.ConvReLU3d # type: ignore[assignment,misc] - _FLOAT_CONV_MODULE: ClassVar[Type[nn.Conv3d]] = nn.Conv3d - _FLOAT_BN_MODULE: ClassVar[Optional[Type[nn.Module]]] = None - _FLOAT_RELU_MODULE: ClassVar[Optional[Type[nn.Module]]] = nn.ReLU + _FLOAT_MODULE = nni.ConvReLU3d # type: ignore[assignment] + _FLOAT_CONV_MODULE = nn.Conv3d + _FLOAT_BN_MODULE: None = None + _FLOAT_RELU_MODULE = nn.ReLU def __init__( self, diff --git a/torch/ao/nn/intrinsic/quantized/modules/conv_add.py b/torch/ao/nn/intrinsic/quantized/modules/conv_add.py index 0d1b7e01f4479f..66299394091b7c 100644 --- a/torch/ao/nn/intrinsic/quantized/modules/conv_add.py +++ b/torch/ao/nn/intrinsic/quantized/modules/conv_add.py @@ -49,7 +49,7 @@ def __init__( dtype=dtype, ) - def forward(self, input, extra_input): # type: ignore[override] + def forward(self, input, extra_input): # Temporarily using len(shape) instead of ndim due to JIT issue # https://github.com/pytorch/pytorch/issues/23890 if len(input.shape) != 4: @@ -117,7 +117,7 @@ def __init__( dtype=dtype, ) - def forward(self, input, extra_input): # type: ignore[override] + def forward(self, input, extra_input): # Temporarily using len(shape) instead of ndim due to JIT issue # https://github.com/pytorch/pytorch/issues/23890 if len(input.shape) != 4: diff --git a/torch/ao/nn/qat/modules/conv.py b/torch/ao/nn/qat/modules/conv.py index 49b559136bbc6e..3be5a7cec8162e 100644 --- a/torch/ao/nn/qat/modules/conv.py +++ b/torch/ao/nn/qat/modules/conv.py @@ -1,5 +1,5 @@ # mypy: allow-untyped-defs -from typing import ClassVar, Tuple, Type, Union +from typing import Tuple, TypeVar, Union import torch import torch.nn as nn @@ -10,9 +10,11 @@ __all__ = ["Conv1d", "Conv2d", "Conv3d"] +MOD = TypeVar("MOD", bound=nn.modules.conv._ConvNd) + class _ConvNd(nn.modules.conv._ConvNd): - _FLOAT_MODULE: ClassVar[Type[nn.modules.conv._ConvNd]] + _FLOAT_MODULE = MOD def __init__( self, @@ -134,8 +136,8 @@ class Conv1d(_ConvNd, nn.Conv1d): Attributes: weight_fake_quant: fake quant module for weight """ - _FLOAT_MODULE: ClassVar[Type[nn.Conv1d]] = nn.Conv1d # type: ignore[assignment,misc] - _FLOAT_CONV_MODULE: ClassVar[Type[nn.Conv1d]] = nn.Conv1d + _FLOAT_MODULE = nn.Conv1d + _FLOAT_CONV_MODULE = nn.Conv1d def __init__( self, @@ -195,8 +197,8 @@ class Conv2d(_ConvNd, nn.Conv2d): Attributes: weight_fake_quant: fake quant module for weight """ - _FLOAT_MODULE: ClassVar[Type[nn.Conv2d]] = nn.Conv2d # type: ignore[assignment,misc] - _FLOAT_CONV_MODULE: ClassVar[Type[nn.Conv2d]] = nn.Conv2d + _FLOAT_MODULE = nn.Conv2d + _FLOAT_CONV_MODULE = nn.Conv2d def __init__( self, @@ -259,8 +261,8 @@ class Conv3d(_ConvNd, nn.Conv3d): Attributes: weight_fake_quant: fake quant module for weight """ - _FLOAT_MODULE: ClassVar[Type[nn.Conv3d]] = nn.Conv3d # type: ignore[assignment,misc] - _FLOAT_CONV_MODULE: ClassVar[Type[nn.Conv3d]] = nn.Conv3d + _FLOAT_MODULE = nn.Conv3d + _FLOAT_CONV_MODULE = nn.Conv3d def __init__( self, diff --git a/torch/ao/nn/quantized/dynamic/modules/conv.py b/torch/ao/nn/quantized/dynamic/modules/conv.py index f379e487f65de8..fa86dcdd93e17d 100644 --- a/torch/ao/nn/quantized/dynamic/modules/conv.py +++ b/torch/ao/nn/quantized/dynamic/modules/conv.py @@ -2,7 +2,6 @@ r"""Dynamically quantized convolution modules.""" import warnings -from typing import ClassVar, Optional, Type import torch import torch.ao.nn.quantized as nnq @@ -48,9 +47,9 @@ class Conv1d(nnq.Conv1d): """ - _FLOAT_MODULE: ClassVar[Type[nn.Conv1d]] = nn.Conv1d - _NNIQAT_CONV_BN_MODULE: ClassVar[Optional[Type[nn.Module]]] = None - _NNI_CONV_RELU_MODULE: ClassVar[Optional[Type[nn.Module]]] = None + _FLOAT_MODULE = nn.Conv1d + _NNIQAT_CONV_BN_MODULE = None # type: ignore[assignment] + _NNI_CONV_RELU_MODULE = None # type: ignore[assignment] def __init__( self, @@ -133,9 +132,9 @@ class Conv2d(nnq.Conv2d): >>> output = m(input) """ - _FLOAT_MODULE: ClassVar[Type[nn.Conv2d]] = nn.Conv2d - _NNIQAT_CONV_BN_MODULE: ClassVar[Optional[Type[nn.Module]]] = None - _NNI_CONV_RELU_MODULE: ClassVar[Optional[Type[nn.Module]]] = None + _FLOAT_MODULE = nn.Conv2d + _NNIQAT_CONV_BN_MODULE = None # type: ignore[assignment] + _NNI_CONV_RELU_MODULE = None # type: ignore[assignment] def __init__( self, @@ -217,9 +216,9 @@ class Conv3d(nnq.Conv3d): >>> output = m(input) """ - _FLOAT_MODULE: ClassVar[Type[nn.Conv3d]] = nn.Conv3d - _NNIQAT_CONV_BN_MODULE: ClassVar[Optional[Type[nn.Module]]] = None - _NNI_CONV_RELU_MODULE: ClassVar[Optional[Type[nn.Module]]] = None + _FLOAT_MODULE = nn.Conv3d + _NNIQAT_CONV_BN_MODULE = None # type: ignore[assignment] + _NNI_CONV_RELU_MODULE = None # type: ignore[assignment] def __init__( self, @@ -309,7 +308,7 @@ class ConvTranspose1d(nnq.ConvTranspose1d): torch.Size([1, 16, 12]) """ - _FLOAT_MODULE: ClassVar[Type[nn.ConvTranspose1d]] = nn.ConvTranspose1d + _FLOAT_MODULE = nn.ConvTranspose1d def __init__( self, @@ -391,7 +390,7 @@ class ConvTranspose2d(nnq.ConvTranspose2d): torch.Size([1, 16, 12, 12]) """ - _FLOAT_MODULE: ClassVar[Type[nn.ConvTranspose2d]] = nn.ConvTranspose2d + _FLOAT_MODULE = nn.ConvTranspose2d def __init__( self, @@ -473,7 +472,7 @@ class ConvTranspose3d(nnq.ConvTranspose3d): torch.Size([1, 16, 12, 12, 12]) """ - _FLOAT_MODULE: ClassVar[Type[nn.ConvTranspose3d]] = nn.ConvTranspose3d + _FLOAT_MODULE = nn.ConvTranspose3d def __init__( self, diff --git a/torch/ao/nn/quantized/modules/conv.py b/torch/ao/nn/quantized/modules/conv.py index 0ef2f3af1dcd44..48ca18028acf94 100644 --- a/torch/ao/nn/quantized/modules/conv.py +++ b/torch/ao/nn/quantized/modules/conv.py @@ -1,7 +1,7 @@ # mypy: allow-untyped-defs r"""Quantized convolution modules.""" -from typing import ClassVar, List, Optional, Type +from typing import List, Optional, TypeVar import torch import torch.ao.nn.intrinsic as nni @@ -386,11 +386,11 @@ class Conv1d(_ConvNd): """ - _FLOAT_MODULE: ClassVar[Type[nn.Conv1d]] = nn.Conv1d - _NNIQAT_CONV_BN_MODULE: ClassVar[Optional[Type[nn.Module]]] = nniqat.ConvBn1d - _NNI_CONV_RELU_MODULE: ClassVar[Optional[Type[nn.Module]]] = nni.ConvReLU1d - _NNI_CONV_ADD_MODULE: ClassVar[Optional[Type[nn.Module]]] = None - _NNI_CONV_ADD_RELU_MODULE: ClassVar[Optional[Type[nn.Module]]] = None + _FLOAT_MODULE = nn.Conv1d + _NNIQAT_CONV_BN_MODULE = nniqat.ConvBn1d + _NNI_CONV_RELU_MODULE = nni.ConvReLU1d + _NNI_CONV_ADD_MODULE: None = None + _NNI_CONV_ADD_RELU_MODULE: None = None def __init__( self, @@ -518,11 +518,11 @@ class Conv2d(_ConvNd): >>> output = m(q_input) """ - _FLOAT_MODULE: ClassVar[Type[nn.Conv2d]] = nn.Conv2d - _NNIQAT_CONV_BN_MODULE: ClassVar[Optional[Type[nn.Module]]] = nniqat.ConvBn2d - _NNI_CONV_RELU_MODULE: ClassVar[Optional[Type[nn.Module]]] = nni.ConvReLU2d - _NNI_CONV_ADD_MODULE: ClassVar[Type[nni.ConvAdd2d]] = nni.ConvAdd2d - _NNI_CONV_ADD_RELU_MODULE: ClassVar[Type[nni.ConvAddReLU2d]] = nni.ConvAddReLU2d + _FLOAT_MODULE = nn.Conv2d + _NNIQAT_CONV_BN_MODULE = nniqat.ConvBn2d + _NNI_CONV_RELU_MODULE = nni.ConvReLU2d + _NNI_CONV_ADD_MODULE = nni.ConvAdd2d + _NNI_CONV_ADD_RELU_MODULE = nni.ConvAddReLU2d def __init__( self, @@ -647,11 +647,11 @@ class Conv3d(_ConvNd): >>> output = m(q_input) """ - _FLOAT_MODULE: ClassVar[Type[nn.Conv3d]] = nn.Conv3d - _NNIQAT_CONV_BN_MODULE: ClassVar[Optional[Type[nn.Module]]] = nniqat.ConvBn3d - _NNI_CONV_RELU_MODULE: ClassVar[Optional[Type[nn.Module]]] = nni.ConvReLU3d - _NNI_CONV_ADD_MODULE: ClassVar[Optional[Type[nn.Module]]] = None - _NNI_CONV_ADD_RELU_MODULE: ClassVar[Optional[Type[nn.Module]]] = None + _FLOAT_MODULE = nn.Conv3d + _NNIQAT_CONV_BN_MODULE = nniqat.ConvBn3d + _NNI_CONV_RELU_MODULE = nni.ConvReLU3d + _NNI_CONV_ADD_MODULE: None = None + _NNI_CONV_ADD_RELU_MODULE: None = None def __init__( self, @@ -740,10 +740,11 @@ def from_float(cls, mod, use_precomputed_fake_quant=False): # === Transposed Convolutions === +MOD = TypeVar("MOD", bound=nn.modules.conv._ConvNd) class _ConvTransposeNd(_ConvNd): - _FLOAT_MODULE: ClassVar[Type[nn.modules.conv._ConvNd]] + _FLOAT_MODULE = MOD def __init__( self, @@ -913,7 +914,7 @@ class ConvTranspose1d(_ConvTransposeNd): torch.Size([1, 16, 12]) """ - _FLOAT_MODULE: ClassVar[Type[nn.ConvTranspose1d]] = nn.ConvTranspose1d + _FLOAT_MODULE = nn.ConvTranspose1d def __init__( self, @@ -1036,7 +1037,7 @@ class ConvTranspose2d(_ConvTransposeNd): torch.Size([1, 16, 12, 12]) """ - _FLOAT_MODULE: ClassVar[Type[nn.ConvTranspose2d]] = nn.ConvTranspose2d + _FLOAT_MODULE = nn.ConvTranspose2d def __init__( self, @@ -1161,7 +1162,7 @@ class ConvTranspose3d(_ConvTransposeNd): torch.Size([1, 16, 12, 12, 12]) """ - _FLOAT_MODULE: ClassVar[Type[nn.ConvTranspose3d]] = nn.ConvTranspose3d + _FLOAT_MODULE = nn.ConvTranspose3d def __init__( self, diff --git a/torch/ao/ns/_numeric_suite.py b/torch/ao/ns/_numeric_suite.py index 9e4cf94303c9f1..9a80b4af6f4794 100644 --- a/torch/ao/ns/_numeric_suite.py +++ b/torch/ao/ns/_numeric_suite.py @@ -198,7 +198,7 @@ def __init__(self): self.stats["float"] = [] self.stats["quantized"] = [] - def forward(self, x, y): # type: ignore[override] + def forward(self, x, y): # fmt: off """ """ # blank docblock to make autodoc happy diff --git a/torch/ao/ns/_numeric_suite_fx.py b/torch/ao/ns/_numeric_suite_fx.py index da79b03ce88089..3cea146b8031fd 100644 --- a/torch/ao/ns/_numeric_suite_fx.py +++ b/torch/ao/ns/_numeric_suite_fx.py @@ -261,7 +261,7 @@ def __init__(self, *args, **kwargs): self.comparisons = [] # precalculated comparisons function - def forward(self, x, x_ref): # type: ignore[override] + def forward(self, x, x_ref): # fmt: off """ """ # blank docblock to make autodoc happy diff --git a/torch/ao/ns/fx/mappings.py b/torch/ao/ns/fx/mappings.py index 56784dd1780308..875af8c875b0a4 100644 --- a/torch/ao/ns/fx/mappings.py +++ b/torch/ao/ns/fx/mappings.py @@ -410,7 +410,7 @@ def get_base_name_to_sets_of_related_ops() -> Dict[str, Set[NSNodeTargetType]]: # Add function swaps from default lowering path # - for source, ( # type:ignore[assignment] + for source, ( target1, target2, ) in _lower_to_native_backend.STATIC_LOWER_FUNCTIONAL_MAP.items(): @@ -422,7 +422,7 @@ def get_base_name_to_sets_of_related_ops() -> Dict[str, Set[NSNodeTargetType]]: _lower_to_native_backend.QBIN_RELU_OP_MAPPING, quantization_mappings.DEFAULT_FLOAT_TO_QUANTIZED_OPERATOR_MAPPINGS, ): - for source, target in source_to_target.items(): # type:ignore[assignment] + for source, target in source_to_target.items(): new_connections.append((source, target)) # @@ -432,7 +432,7 @@ def get_base_name_to_sets_of_related_ops() -> Dict[str, Set[NSNodeTargetType]]: for source_to_target in ( quantization_mappings.DEFAULT_DYNAMIC_QUANT_MODULE_MAPPINGS, ): - for source, target in source_to_target.items(): # type:ignore[assignment] + for source, target in source_to_target.items(): new_connections.append((source, target)) # add the new connections from backend_config diff --git a/torch/ao/pruning/_experimental/data_sparsifier/base_data_sparsifier.py b/torch/ao/pruning/_experimental/data_sparsifier/base_data_sparsifier.py index 75d86737d832d5..b81ee658e55318 100644 --- a/torch/ao/pruning/_experimental/data_sparsifier/base_data_sparsifier.py +++ b/torch/ao/pruning/_experimental/data_sparsifier/base_data_sparsifier.py @@ -71,7 +71,7 @@ def __init__(self, data_list: Optional[List[Tuple[str, Any]]] = None, **defaults # add data with default config here [self.add_data(name, data, **self.defaults) for name, data in data_list] - def prepare(self, model, config): + def prepare(self): raise NotImplementedError("this function is undefined for this class") def _extract_weight(self, data): @@ -266,7 +266,7 @@ def __getstate__(self): "_container": self._container.state_dict(), } - def __repr__(self): # type:ignore[override] + def __repr__(self): format_string = self.__class__.__name__ + " (" for name, sparse_args in self.data_groups.items(): format_string += "\n" @@ -299,7 +299,7 @@ def squash_mask(self, *args, leave_parametrized=True, names=None, **kwargs): self._container, name, leave_parametrized=leave_parametrized ) - def step(self): # type:ignore[override] + def step(self): if not self.enable_mask_update: return with torch.no_grad(): diff --git a/torch/ao/pruning/_experimental/data_sparsifier/data_norm_sparsifier.py b/torch/ao/pruning/_experimental/data_sparsifier/data_norm_sparsifier.py index ec867205f64ab1..34444fa79df0fe 100644 --- a/torch/ao/pruning/_experimental/data_sparsifier/data_norm_sparsifier.py +++ b/torch/ao/pruning/_experimental/data_sparsifier/data_norm_sparsifier.py @@ -156,7 +156,7 @@ def __get_data_level_mask(self, data, sparsity_level, sparse_block_shape): ] # squeeze only the first 2 dimension return mask - def update_mask( # type: ignore[override] + def update_mask( self, name, data, sparsity_level, sparse_block_shape, zeros_per_block, **kwargs ): values_per_block = reduce(operator.mul, sparse_block_shape) diff --git a/torch/ao/pruning/_experimental/pruner/FPGM_pruner.py b/torch/ao/pruning/_experimental/pruner/FPGM_pruner.py index 3da27ba38df55b..54177272ce4f2a 100644 --- a/torch/ao/pruning/_experimental/pruner/FPGM_pruner.py +++ b/torch/ao/pruning/_experimental/pruner/FPGM_pruner.py @@ -75,9 +75,7 @@ def _compute_distance(self, t): return distance - def update_mask( # type: ignore[override] - self, module, tensor_name, sparsity_level, **kwargs - ): + def update_mask(self, module, tensor_name, sparsity_level, **kwargs): tensor_weight = getattr(module, tensor_name) mask = getattr(module.parametrizations, tensor_name)[0].mask diff --git a/torch/ao/pruning/sparsifier/nearly_diagonal_sparsifier.py b/torch/ao/pruning/sparsifier/nearly_diagonal_sparsifier.py index a4d42ea803289c..47c567ec6d93eb 100644 --- a/torch/ao/pruning/sparsifier/nearly_diagonal_sparsifier.py +++ b/torch/ao/pruning/sparsifier/nearly_diagonal_sparsifier.py @@ -33,9 +33,7 @@ def __init__(self, nearliness: int = 1): defaults = {"nearliness": nearliness} super().__init__(defaults=defaults) - def update_mask( # type:ignore[override] - self, module, tensor_name, nearliness, **kwargs - ): + def update_mask(self, module, tensor_name, nearliness, **kwargs): mask = getattr(module.parametrizations, tensor_name)[0].mask mask.data = torch.zeros_like(mask) if nearliness <= 0: diff --git a/torch/ao/pruning/sparsifier/weight_norm_sparsifier.py b/torch/ao/pruning/sparsifier/weight_norm_sparsifier.py index a25b7ffbca61cb..3c2eef0dd15c79 100644 --- a/torch/ao/pruning/sparsifier/weight_norm_sparsifier.py +++ b/torch/ao/pruning/sparsifier/weight_norm_sparsifier.py @@ -208,7 +208,7 @@ def _make_block_mask(self, data, sparse_block_shape, zeros_per_block, mask=None) mask.data = mask_reshape.squeeze().reshape(mask.shape).contiguous() return mask - def update_mask( # type: ignore[call-override, override] + def update_mask( self, module, tensor_name, diff --git a/torch/ao/quantization/__init__.py b/torch/ao/quantization/__init__.py index 503755d383bb93..840d2715c7717f 100644 --- a/torch/ao/quantization/__init__.py +++ b/torch/ao/quantization/__init__.py @@ -216,5 +216,5 @@ def __init__( def forward(self, x: Tensor) -> Tensor: return x - def calculate_qparams(self): # type:ignore[override] + def calculate_qparams(self): return self.derive_qparams_fn(self.obs_or_fqs) diff --git a/torch/ao/quantization/_correct_bias.py b/torch/ao/quantization/_correct_bias.py index 03d259a0ad26ca..4be37467de1818 100644 --- a/torch/ao/quantization/_correct_bias.py +++ b/torch/ao/quantization/_correct_bias.py @@ -61,7 +61,7 @@ def __init__(self): self.float_sum = None self.quant_sum = None - def forward(self, x, y): # type: ignore[override] + def forward(self, x, y): """Compute the average of quantized and floating-point data from modules. The inputs x,y are output data from the quantized and floating-point modules. diff --git a/torch/ao/quantization/experimental/observer.py b/torch/ao/quantization/experimental/observer.py index 1cf27c97054c1e..631ec11b1669d1 100644 --- a/torch/ao/quantization/experimental/observer.py +++ b/torch/ao/quantization/experimental/observer.py @@ -34,7 +34,7 @@ def __init__(self, b, k, dtype=torch.quint8) -> None: # min_val and max_val are optional args to override # the min_val and max_val observed by forward - def calculate_qparams(self, signed): # type:ignore[override] + def calculate_qparams(self, signed): return self._calculate_qparams(signed, self.min_val, self.max_val) r""" Calculates nonuniform quantization parameters according to APoT paper: diff --git a/torch/ao/quantization/quantizer/x86_inductor_quantizer.py b/torch/ao/quantization/quantizer/x86_inductor_quantizer.py index 6042fd2ee5adb9..574af30a7159b9 100644 --- a/torch/ao/quantization/quantizer/x86_inductor_quantizer.py +++ b/torch/ao/quantization/quantizer/x86_inductor_quantizer.py @@ -225,7 +225,7 @@ def _map_module_function_to_aten_operator_type(): ), ) for map_item in map_list: - module_function_to_aten_operator.update(dict.fromkeys(map_item[0], map_item[1])) # type: ignore[arg-type, call-overload] + module_function_to_aten_operator.update(dict.fromkeys(map_item[0], map_item[1])) # type: ignore[call-overload] return module_function_to_aten_operator diff --git a/torch/distributed/_functional_collectives.py b/torch/distributed/_functional_collectives.py index 4127885dccc1f6..77b962bf3df47c 100644 --- a/torch/distributed/_functional_collectives.py +++ b/torch/distributed/_functional_collectives.py @@ -600,7 +600,7 @@ def __tensor_unflatten__(inner_tensors, meta, outer_size, outer_stride): elem = inner_tensors["elem"] return AsyncCollectiveTensor(elem) - def __repr__(self) -> str: # type: ignore[override] + def __repr__(self): return f"AsyncCollectiveTensor({self.trigger_wait()})" def trigger_wait(self): @@ -653,7 +653,7 @@ def wrap(e: torch.Tensor): return out - def numpy(self): # type: ignore[override] + def numpy(self): return self.wait().numpy() diff --git a/torch/distributed/_shard/sharded_tensor/api.py b/torch/distributed/_shard/sharded_tensor/api.py index d50160ca8ecc31..68df582cd5145e 100644 --- a/torch/distributed/_shard/sharded_tensor/api.py +++ b/torch/distributed/_shard/sharded_tensor/api.py @@ -1200,7 +1200,7 @@ def remote_shards(self) -> Dict[int, List[rpc.RRef[Shard]]]: def __hash__(self): return id(self) - def __repr__(self) -> str: # type: ignore[override] + def __repr__(self): return f"ShardedTensor({self._metadata})" @dataclass diff --git a/torch/distributed/algorithms/_comm_hooks/default_hooks.py b/torch/distributed/algorithms/_comm_hooks/default_hooks.py index 032b1979e576ce..0acafd6868d3bb 100644 --- a/torch/distributed/algorithms/_comm_hooks/default_hooks.py +++ b/torch/distributed/algorithms/_comm_hooks/default_hooks.py @@ -168,7 +168,6 @@ def fp16_compress_hook( output (torch.Tensor): Stores a single shard of the gradient after ``reduce_scatter``. """ fp16_hook = functools.partial(_low_precision_hook, torch.float16) - assert output is not None return fp16_hook(state, grad, output) @@ -190,5 +189,4 @@ def bf16_compress_hook( output (torch.Tensor): Stores a single shard of the gradient after ``reduce_scatter``. """ bf16_hook = functools.partial(_low_precision_hook, torch.bfloat16) - assert output is not None return bf16_hook(state, grad, output) diff --git a/torch/distributed/optim/post_localSGD_optimizer.py b/torch/distributed/optim/post_localSGD_optimizer.py index 3c0027d1124073..db65856e32adad 100644 --- a/torch/distributed/optim/post_localSGD_optimizer.py +++ b/torch/distributed/optim/post_localSGD_optimizer.py @@ -96,7 +96,7 @@ def load_state_dict(self, state_dict): ) self.averager.step = 0 - def step(self): # type: ignore[override] + def step(self): r""" Performs a single optimization step (parameter update). """ diff --git a/torch/distributed/pipelining/_IR.py b/torch/distributed/pipelining/_IR.py index 3010bccd377c94..036ad6cf8621eb 100644 --- a/torch/distributed/pipelining/_IR.py +++ b/torch/distributed/pipelining/_IR.py @@ -364,7 +364,7 @@ def __init__(self, module, garbage_collect_values=True): super().__init__(module, garbage_collect_values) self.value_remap = {} - def run(self, *args, initial_env=None): # type: ignore[override] + def run(self, *args, initial_env=None): self.value_remap = {} return super().run(*args, initial_env=initial_env) @@ -932,7 +932,8 @@ def move_param_to_callee( if node.op == "get_attr": # get_attr might get access deeper level attribute fqn = scope + "." + node.target if scope else node.target - unused_attributes.discard(fqn) + if fqn in unused_attributes: # used, remove it + unused_attributes.remove(fqn) for _name, _submod in _mod.named_children(): stack.append((scope + "." + _name if scope else _name, _submod)) # delete unused attributes diff --git a/torch/distributed/tensor/_api.py b/torch/distributed/tensor/_api.py index 8684d7b0cafa0f..583aa20c1dc188 100644 --- a/torch/distributed/tensor/_api.py +++ b/torch/distributed/tensor/_api.py @@ -283,7 +283,7 @@ def __new__( # pyre-fixme[14]: `__repr__` overrides method defined in `DTensor` inconsistently. # pyre-fixme[3]: Return type must be annotated. - def __repr__(self): # type: ignore[override] + def __repr__(self): # TODO: consider all_gather the local tensors for better debugging return f"DTensor(local_tensor={self._local_tensor}, device_mesh={self._spec.mesh}, placements={self._spec.placements})" diff --git a/torch/distributed/tensor/_shards_wrapper.py b/torch/distributed/tensor/_shards_wrapper.py index df8c7d09e38a4a..de396473b77c91 100644 --- a/torch/distributed/tensor/_shards_wrapper.py +++ b/torch/distributed/tensor/_shards_wrapper.py @@ -309,7 +309,7 @@ def __hash__(self): # pyre-fixme[14]: `__repr__` overrides method defined in `torch._tensor.Tensor` inconsistently. # pyre-fixme[3]: Return type must be annotated. - def __repr__(self) -> str: # type: ignore[override] + def __repr__(self): return f"LocalShardsWrapper:{self._local_shards} {self._storage_meta}" def __str__(self) -> str: diff --git a/torch/distributed/tensor/parallel/input_reshard.py b/torch/distributed/tensor/parallel/input_reshard.py index 401ca3d3fdaf69..630c287cae88f5 100644 --- a/torch/distributed/tensor/parallel/input_reshard.py +++ b/torch/distributed/tensor/parallel/input_reshard.py @@ -42,7 +42,6 @@ def input_reshard( cx: Optional[torch.autograd.graph.saved_tensors_hooks] = None def input_reshard_forward_pre_hook(_: torch.nn.Module, _i: Tuple[Any, ...]) -> None: - assert input_reshard_dim is not None saved_tensor_hooks = torch.autograd.graph.saved_tensors_hooks( partial(_pack_hook_tp, tp_device_mesh, input_reshard_dim), partial(_unpack_hook_tp, tp_device_mesh, input_reshard_dim), diff --git a/torch/distributions/constraints.py b/torch/distributions/constraints.py index e6f730b123afd1..3c510bd32abc62 100644 --- a/torch/distributions/constraints.py +++ b/torch/distributions/constraints.py @@ -204,7 +204,7 @@ def __init__( self._is_discrete = is_discrete self._event_dim = event_dim - def __call__(self, fn): # type: ignore[override] + def __call__(self, fn): """ Support for syntax to customize static attributes:: diff --git a/torch/distributions/von_mises.py b/torch/distributions/von_mises.py index a4d403383d9cb0..bd8fa87f2619a4 100644 --- a/torch/distributions/von_mises.py +++ b/torch/distributions/von_mises.py @@ -177,7 +177,7 @@ def sample(self, sample_shape=torch.Size()): self._loc, self._concentration, self._proposal_r, x ).to(self.loc.dtype) - def expand(self, batch_shape, _instance=None): + def expand(self, batch_shape): try: return super().expand(batch_shape) except NotImplementedError: diff --git a/torch/masked/maskedtensor/core.py b/torch/masked/maskedtensor/core.py index 366cf45eb2d50f..69850df81c65a1 100644 --- a/torch/masked/maskedtensor/core.py +++ b/torch/masked/maskedtensor/core.py @@ -258,7 +258,7 @@ def _set_data_mask(self, data, mask): self._masked_mask = mask self._validate_members() - def __repr__(self): # type: ignore[override] + def __repr__(self): formatter = "{0:8.4f}" if self.dim() == 0: scalar_data = self.get_data().item() @@ -350,7 +350,7 @@ def get_mask(self): def is_sparse_coo(self): return self.layout == torch.sparse_coo - def is_sparse_csr(self): # type: ignore[override] + def is_sparse_csr(self): return self.layout == torch.sparse_csr # Update later to support more sparse layouts diff --git a/torch/multiprocessing/reductions.py b/torch/multiprocessing/reductions.py index c02ad3dc2362ba..fa0818571a93c0 100644 --- a/torch/multiprocessing/reductions.py +++ b/torch/multiprocessing/reductions.py @@ -74,7 +74,7 @@ def __init__(self) -> None: def _after_fork(self): self.lock = threading.Lock() - def get(self, key): # type: ignore[override] + def get(self, key): with self.lock: return dict.get(self, key) diff --git a/torch/nested/_internal/nested_tensor.py b/torch/nested/_internal/nested_tensor.py index d39eb12d919c87..76f8d5a29e2d8f 100644 --- a/torch/nested/_internal/nested_tensor.py +++ b/torch/nested/_internal/nested_tensor.py @@ -209,7 +209,7 @@ def _max_seqlen(self): def _min_seqlen(self): return self._get_min_seqlen() - def __repr__(self): # type: ignore[override] + def __repr__(self): # We should implement this in torch/_tensor_str.py instead grad_fn_str = ( f", requires_grad={self.requires_grad}" if self.requires_grad else "" diff --git a/torch/nn/attention/bias.py b/torch/nn/attention/bias.py index da7acb957d96d7..2fed60030a6898 100644 --- a/torch/nn/attention/bias.py +++ b/torch/nn/attention/bias.py @@ -289,7 +289,7 @@ def __torch_function__(cls, func, types, args=(), kwargs=None): ) return cls._dispatch(*args, **kwargs) - def __repr__(self): # type:ignore[override] + def __repr__(self): return self._materialize().__repr__() diff --git a/torch/optim/lr_scheduler.py b/torch/optim/lr_scheduler.py index 820a57f3e0b756..57dcbd85a83164 100644 --- a/torch/optim/lr_scheduler.py +++ b/torch/optim/lr_scheduler.py @@ -910,7 +910,7 @@ def __init__( self._last_lr = schedulers[0].get_last_lr() - def step(self): # type: ignore[override] + def step(self): """Perform a step.""" self.last_epoch += 1 idx = bisect_right(self._milestones, self.last_epoch) @@ -1179,7 +1179,7 @@ def __init__( group["lr"] for group in self._schedulers[-1].optimizer.param_groups ] - def step(self): # type: ignore[override] + def step(self): """Perform a step.""" for scheduler in self._schedulers: scheduler.step() diff --git a/torch/sparse/semi_structured.py b/torch/sparse/semi_structured.py index 5e8a632633e09a..0017a10e6771bb 100644 --- a/torch/sparse/semi_structured.py +++ b/torch/sparse/semi_structured.py @@ -295,7 +295,7 @@ def _pad_dense_input(cls, dense_input: torch.Tensor) -> torch.Tensor: else: return dense_input - def to_dense(self): # type:ignore[override] + def to_dense(self): col = self.shape[-1] return torch.mm(self, torch.eye(col, dtype=self.dtype, device=self.device)) @@ -420,7 +420,7 @@ def from_dense( requires_grad=original_tensor.requires_grad, ) - def to_dense(self): # type: ignore[override] + def to_dense(self): assert self.meta is not None and self.packed is not None return ( sparse_semi_structured_to_dense_cutlass( diff --git a/torch/utils/_sympy/functions.py b/torch/utils/_sympy/functions.py index ed597402d8b637..4924fff6dea6be 100644 --- a/torch/utils/_sympy/functions.py +++ b/torch/utils/_sympy/functions.py @@ -310,11 +310,11 @@ def eval(cls, base, divisor, modulus): if isinstance(base, FloorDiv): return ModularIndexing(base.args[0], base.args[1] * divisor, modulus) - def _eval_is_nonnegative(self): # type:ignore[override] + def _eval_is_nonnegative(self): p, q = self.args[:2] return fuzzy_eq(p.is_nonnegative, q.is_nonnegative) # type: ignore[attr-defined] - def _eval_is_positive(self): # type:ignore[override] + def _eval_is_positive(self): p, q = self.args[:2] return fuzzy_eq(p.is_positive, q.is_positive) # type: ignore[attr-defined] @@ -329,14 +329,14 @@ class Where(sympy.Function): def _eval_is_integer(self): return True if self.args[1].is_integer and self.args[2].is_integer else None # type: ignore[attr-defined] - def _eval_is_nonnegative(self): # type:ignore[override] + def _eval_is_nonnegative(self): return ( True if self.args[1].is_nonnegative and self.args[2].is_nonnegative # type: ignore[attr-defined] else None ) - def _eval_is_positive(self): # type:ignore[override] + def _eval_is_positive(self): return True if self.args[1].is_positive and self.args[2].is_positive else None # type: ignore[attr-defined] @classmethod @@ -397,7 +397,7 @@ def eval(cls, p, q): return S.Zero # NB: args[1] for PythonMod - def _eval_is_nonnegative(self): # type:ignore[override] + def _eval_is_nonnegative(self): return True if self.args[1].is_positive else None # type: ignore[attr-defined] def _eval_is_nonpositive(self): @@ -823,13 +823,13 @@ class Max(MinMaxBase, Application): # type: ignore[misc] zero = S.Infinity identity = S.NegativeInfinity - def _eval_is_positive(self): # type:ignore[override] + def _eval_is_positive(self): return fuzzy_or(a.is_positive for a in self.args) # type: ignore[attr-defined] - def _eval_is_nonnegative(self): # type:ignore[override] + def _eval_is_nonnegative(self): return fuzzy_or(a.is_nonnegative for a in self.args) # type: ignore[attr-defined] - def _eval_is_negative(self): # type:ignore[override] + def _eval_is_negative(self): return fuzzy_and(a.is_negative for a in self.args) @@ -841,13 +841,13 @@ class Min(MinMaxBase, Application): # type: ignore[misc] zero = S.NegativeInfinity identity = S.Infinity - def _eval_is_positive(self): # type:ignore[override] + def _eval_is_positive(self): return fuzzy_and(a.is_positive for a in self.args) # type: ignore[attr-defined] - def _eval_is_nonnegative(self): # type:ignore[override] + def _eval_is_nonnegative(self): return fuzzy_and(a.is_nonnegative for a in self.args) # type: ignore[attr-defined] - def _eval_is_negative(self): # type:ignore[override] + def _eval_is_negative(self): return fuzzy_or(a.is_negative for a in self.args) diff --git a/torch/utils/weak.py b/torch/utils/weak.py index f729ff06489ffe..cc272a7f26375a 100644 --- a/torch/utils/weak.py +++ b/torch/utils/weak.py @@ -263,7 +263,7 @@ def pop(self, key, *args): def setdefault(self, key, default=None): return self.data.setdefault(self.ref_type(key, self._remove), default) # CHANGED - def update(self, dict=None, **kwargs): # type: ignore[override] + def update(self, dict=None, **kwargs): d = self.data if dict is not None: if not hasattr(dict, "items"): From bade051cb2cf911beb7e2bcb550477e6da17e93b Mon Sep 17 00:00:00 2001 From: atalman Date: Mon, 16 Sep 2024 12:49:36 +0000 Subject: [PATCH 0928/1018] Add python 3.13.0t build to Docker images (#136001) Adds 3.13t python to Docker images Pull Request resolved: https://github.com/pytorch/pytorch/pull/136001 Approved by: https://github.com/albanD --- .ci/docker/common/install_cpython.sh | 22 +++++++++++++++++++--- 1 file changed, 19 insertions(+), 3 deletions(-) diff --git a/.ci/docker/common/install_cpython.sh b/.ci/docker/common/install_cpython.sh index 5e0c0f2bb2815d..caf0467c523f97 100755 --- a/.ci/docker/common/install_cpython.sh +++ b/.ci/docker/common/install_cpython.sh @@ -7,7 +7,7 @@ PYTHON_DOWNLOAD_GITHUB_BRANCH=https://github.com/python/cpython/archive/refs/hea GET_PIP_URL=https://bootstrap.pypa.io/get-pip.py # Python versions to be installed in /opt/$VERSION_NO -CPYTHON_VERSIONS=${CPYTHON_VERSIONS:-"3.8.1 3.9.0 3.10.1 3.11.0 3.12.0 3.13.0"} +CPYTHON_VERSIONS=${CPYTHON_VERSIONS:-"3.8.1 3.9.0 3.10.1 3.11.0 3.12.0 3.13.0 3.13.0t"} function check_var { if [ -z "$1" ]; then @@ -22,6 +22,13 @@ function do_cpython_build { check_var $py_ver check_var $py_folder tar -xzf Python-$py_ver.tgz + + local additional_flags="" + if [ "$py_ver" == "3.13.0t" ]; then + additional_flags=" --disable-gil" + mv cpython-3.13/ cpython-3.13t/ + fi + pushd $py_folder local prefix="/opt/_internal/cpython-${py_ver}" @@ -37,8 +44,10 @@ function do_cpython_build { local openssl_flags="--with-openssl=${WITH_OPENSSL} --with-openssl-rpath=auto" fi + + # -Wformat added for https://bugs.python.org/issue17547 on Python 2.6 - CFLAGS="-Wformat" ./configure --prefix=${prefix} ${openssl_flags} ${shared_flags} > /dev/null + CFLAGS="-Wformat" ./configure --prefix=${prefix} ${openssl_flags} ${shared_flags} ${additional_flags} > /dev/null make -j40 > /dev/null make install > /dev/null @@ -69,7 +78,14 @@ function build_cpython { check_var $py_ver check_var $PYTHON_DOWNLOAD_URL local py_ver_folder=$py_ver - if [ "$py_ver" = "3.13.0" ]; then + + if [ "$py_ver" = "3.13.0t" ]; then + PY_VER_SHORT="3.13" + PYT_VER_SHORT="3.13t" + check_var $PYTHON_DOWNLOAD_GITHUB_BRANCH + wget $PYTHON_DOWNLOAD_GITHUB_BRANCH/$PY_VER_SHORT.tar.gz -O Python-$py_ver.tgz + do_cpython_build $py_ver cpython-$PYT_VER_SHORT + elif [ "$py_ver" = "3.13.0" ]; then PY_VER_SHORT="3.13" check_var $PYTHON_DOWNLOAD_GITHUB_BRANCH wget $PYTHON_DOWNLOAD_GITHUB_BRANCH/$PY_VER_SHORT.tar.gz -O Python-$py_ver.tgz From 4ca094b7d88f9561d9d0a21a001134094e1bda4f Mon Sep 17 00:00:00 2001 From: Bin Bao Date: Sat, 14 Sep 2024 10:43:55 -0700 Subject: [PATCH 0929/1018] [AOTI] Refactor how cpp_wrapper specific options are set (#136035) Summary: 1) When cpp-wrapper is turned on, certain triton specific options need to be set, both for forward and backward. This PR considate the settings in one place. 2) Change config.triton.autotune_at_compile_time to default to None. If the flag is not explicitly set by user, default it to True for cpp-wrapper. Differential Revision: [D62689940](https://our.internmc.facebook.com/intern/diff/D62689940) Pull Request resolved: https://github.com/pytorch/pytorch/pull/136035 Approved by: https://github.com/chenyang78 --- torch/_inductor/compile_fx.py | 49 +++++++++++++++++++---------------- torch/_inductor/config.py | 3 ++- torch/_inductor/graph.py | 6 +---- 3 files changed, 30 insertions(+), 28 deletions(-) diff --git a/torch/_inductor/compile_fx.py b/torch/_inductor/compile_fx.py index f79f247db35fc9..763f7cab3e2cfa 100644 --- a/torch/_inductor/compile_fx.py +++ b/torch/_inductor/compile_fx.py @@ -1246,6 +1246,18 @@ def wrapper(args): return wrapper +def get_cpp_wrapper_config(): + return { + # Set autotune_at_compile_time to True as default if the option is not explicitly set + "triton.autotune_at_compile_time": config.triton.autotune_at_compile_time + if config.triton.autotune_at_compile_time is not None + else True, + "triton.autotune_cublasLt": False, + "triton.cudagraphs": False, # TODO: to be removed + "triton.store_cubin": True, + } + + def compile_fx( model_: torch.fx.GraphModule, example_inputs_: List[torch.Tensor], @@ -1268,18 +1280,8 @@ def compile_fx( if config.cpp_wrapper: with config.patch( { - "cpp_wrapper": False, - # For triton.autotune_at_compile_time, disable by default for - # FBCode, but enabled by default for OSS. - "triton.autotune_at_compile_time": config.triton.autotune_at_compile_time - if config.is_fbcode() - else os.environ.get( - "TORCHINDUCTOR_TRITON_AUTOTUNE_AT_COMPILE_TIME", "1" - ) - == "1", - "triton.autotune_cublasLt": False, - "triton.cudagraphs": False, - "triton.store_cubin": True, + "cpp_wrapper": False, # reset to break recursive call to compile_fx + **get_cpp_wrapper_config(), } ), V.set_real_inputs(example_inputs_): inputs_ = example_inputs_ @@ -1470,16 +1472,19 @@ def bw_compiler( n.name for n in model_outputs if isinstance(n, torch.fx.Node) ) fixed = count_tangents(model) - return inner_compile( - model, - example_inputs, - static_input_idxs=list(range(fixed)), - cudagraphs=cudagraphs, - is_backward=True, - graph_id=graph_id, - boxed_forward_device_index=forward_device, - user_visible_outputs=user_visible_outputs, - ) + with config.patch( + get_cpp_wrapper_config() + ) if config.cpp_wrapper else contextlib.nullcontext(): + return inner_compile( + model, + example_inputs, + static_input_idxs=list(range(fixed)), + cudagraphs=cudagraphs, + is_backward=True, + graph_id=graph_id, + boxed_forward_device_index=forward_device, + user_visible_outputs=user_visible_outputs, + ) # TODO: can add logging before/after the call to create_aot_dispatcher_function # in torch._functorch/aot_autograd.py::aot_module_simplified::aot_function_simplified::new_func diff --git a/torch/_inductor/config.py b/torch/_inductor/config.py index 15c3446099229f..a6c9f0bfece500 100644 --- a/torch/_inductor/config.py +++ b/torch/_inductor/config.py @@ -905,7 +905,8 @@ class triton: autotune_cublasLt = True # Tune the generated Triton kernels at compile time instead of first time they run - autotune_at_compile_time = False + # Setting to None means uninitialized + autotune_at_compile_time: Optional[bool] = None # should we stop a fusion to allow better tiling? tiling_prevents_pointwise_fusion = True diff --git a/torch/_inductor/graph.py b/torch/_inductor/graph.py index 989977897f9d92..6b44723a342def 100644 --- a/torch/_inductor/graph.py +++ b/torch/_inductor/graph.py @@ -1719,11 +1719,7 @@ def codegen_with_cpp_wrapper(self) -> Tuple[str, List[Tuple[int, Node]]]: if any(device in self.device_types for device in ["cuda", "xpu"]): # first pass self.cpp_wrapper = False - # Although triton.store_cubin was OrderedSet in compile_fx, the backward pass didn't pick - # that up. In theory it should work by only setting triton.store_cubin to True here, - # but that will cause a problem when use_runtime_constant_folding is OrderedSet. - with config.patch({"triton.store_cubin": True}): - compiled = self.compile_to_module().call + compiled = self.compile_to_module().call if not config.triton.autotune_at_compile_time: From ac9b029c2f46a5f58d2d22068d695022398f0245 Mon Sep 17 00:00:00 2001 From: Bin Bao Date: Mon, 16 Sep 2024 14:35:19 +0000 Subject: [PATCH 0930/1018] [reland][Inductor] Rename `cpp_wrapper_cuda.py` as `cpp_wrapper_gpu.py` (#136046) Summary: Reland https://github.com/pytorch/pytorch/pull/135313 after fixing internal build issues Test Plan: CI Differential Revision: D62658837 Pull Request resolved: https://github.com/pytorch/pytorch/pull/136046 Approved by: https://github.com/chenyang78, https://github.com/etaf, https://github.com/jansel --- tools/amd_build/build_amd.py | 2 +- torch/_inductor/codegen/common.py | 2 +- .../codegen/{cpp_wrapper_cuda.py => cpp_wrapper_gpu.py} | 0 3 files changed, 2 insertions(+), 2 deletions(-) rename torch/_inductor/codegen/{cpp_wrapper_cuda.py => cpp_wrapper_gpu.py} (100%) diff --git a/tools/amd_build/build_amd.py b/tools/amd_build/build_amd.py index 967f63ae2c8f15..60a1be73fbb2d1 100755 --- a/tools/amd_build/build_amd.py +++ b/tools/amd_build/build_amd.py @@ -207,7 +207,7 @@ def remove_hcc(line: str) -> str: ignores=ignores, extra_files=[ "torch/_inductor/codegen/cpp_wrapper_cpu.py", - "torch/_inductor/codegen/cpp_wrapper_cuda.py", + "torch/_inductor/codegen/cpp_wrapper_gpu.py", "torch/_inductor/codegen/wrapper.py", ], out_of_place_only=args.out_of_place_only, diff --git a/torch/_inductor/codegen/common.py b/torch/_inductor/codegen/common.py index 3902d41b97ce3d..e653d75715248a 100644 --- a/torch/_inductor/codegen/common.py +++ b/torch/_inductor/codegen/common.py @@ -232,7 +232,7 @@ def get_wrapper_codegen_for_device(device: str, cpp_wrapper: bool = False): def init_backend_registration(): from .cpp import CppScheduling from .cpp_wrapper_cpu import CppWrapperCpu - from .cpp_wrapper_cuda import CppWrapperGpu + from .cpp_wrapper_gpu import CppWrapperGpu from .cuda_combined_scheduling import CUDACombinedScheduling from .halide import HalideScheduling from .triton import TritonScheduling diff --git a/torch/_inductor/codegen/cpp_wrapper_cuda.py b/torch/_inductor/codegen/cpp_wrapper_gpu.py similarity index 100% rename from torch/_inductor/codegen/cpp_wrapper_cuda.py rename to torch/_inductor/codegen/cpp_wrapper_gpu.py From c916df01ea15fdbb42e0e85c141f9530bcaf1abc Mon Sep 17 00:00:00 2001 From: Jon Janzen Date: Mon, 16 Sep 2024 15:32:17 +0000 Subject: [PATCH 0931/1018] Delete stable prototype (#135911) This project ended up going in an entirely different direction, so we can close out all this Pull Request resolved: https://github.com/pytorch/pytorch/pull/135911 Approved by: https://github.com/izaitsevfb, https://github.com/malfet --- .../sync_distributed_folder_prototype.sh | 35 ------------------- .../sync_distributed_folder_prototype.yml | 30 ---------------- 2 files changed, 65 deletions(-) delete mode 100755 .github/scripts/sync_distributed_folder_prototype.sh delete mode 100644 .github/workflows/sync_distributed_folder_prototype.yml diff --git a/.github/scripts/sync_distributed_folder_prototype.sh b/.github/scripts/sync_distributed_folder_prototype.sh deleted file mode 100755 index d31fef5c79c428..00000000000000 --- a/.github/scripts/sync_distributed_folder_prototype.sh +++ /dev/null @@ -1,35 +0,0 @@ -#!/bin/bash - -set -eoux pipefail - -SYNC_BRANCH=pytorch-stable-prototype - -git config user.email "fake@example.com" -git config user.name "PyTorch Stable Bot" - -git fetch origin main -git fetch origin "$SYNC_BRANCH" -git checkout "$SYNC_BRANCH" - -# Using a hardcoded SHA here is a massive speedup as we can skip the entire history of the pytorch GitHub repo. -# This specific SHA was chosen as it was before the "branch point" of the stable branch -for SHA in $(git log ba3b05fdf37ddbc3c301294d6a560a816335e717..origin/main --pretty="%h" -- torch/distributed torch/csrc/distributed test/distributed test/cpp/c10d benchmarks/distributed) -do - # `git merge-base --is-ancestor` exits with code 0 if the given SHA is an ancestor, and non-0 otherwise - if git merge-base --is-ancestor $SHA HEAD || [[ $(git log --grep="(cherry picked from commit $SHA") ]] - then - echo "Skipping $SHA" - continue - fi - echo "Copying $SHA" - git cherry-pick -x "$SHA" -X theirs - git reset --soft HEAD~1 - git add torch/distributed torch/csrc/distributed test/distributed test/cpp/c10d benchmarks/distributed - git checkout . - git commit --reuse-message=HEAD@{1} - git clean -f -done - -if [[ "${WITH_PUSH}" == true ]]; then - git push -fi diff --git a/.github/workflows/sync_distributed_folder_prototype.yml b/.github/workflows/sync_distributed_folder_prototype.yml deleted file mode 100644 index 5b7d8e6b1f17f4..00000000000000 --- a/.github/workflows/sync_distributed_folder_prototype.yml +++ /dev/null @@ -1,30 +0,0 @@ -name: Sync Distributed Folder - -on: - #push: - # branches: - # - 'main' - # paths: - # - 'torch/distributed/**' - workflow_dispatch: - pull_request: - paths: - - '.github/scripts/sync_distributed_folder_prototype.sh' - - '.github/workflows/sync_distributed_folder_prototype.yml' - -env: - WITH_PUSH: ${{ github.event_name == 'push' && github.ref == 'refs/heads/main' }} - -permissions: - contents: write - -concurrency: - group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.sha }}-${{ github.event_name == 'workflow_dispatch' }} - cancel-in-progress: true - -jobs: - sync: - runs-on: ubuntu-latest - steps: - - uses: actions/checkout@v4 - - run: .github/scripts/sync_distributed_folder_prototype.sh From 0f22d785d16068c246a8474cf3c9664a3e984eb0 Mon Sep 17 00:00:00 2001 From: Tom Ritchford Date: Mon, 16 Sep 2024 12:09:58 +0000 Subject: [PATCH 0932/1018] Add decomposition for squeeze_copy (#130941) * Extracted from #128416 Pull Request resolved: https://github.com/pytorch/pytorch/pull/130941 Approved by: https://github.com/amjames, https://github.com/eellison --- test/distributed/_tensor/test_dtensor_ops.py | 1 + ...asDecompTest.test_has_decomposition.expect | 6 ----- test/functorch/test_ops.py | 3 +++ test/functorch/test_vmap.py | 1 + test/test_mps.py | 1 + tools/autograd/gen_variable_type.py | 1 + torch/_decomp/__init__.py | 1 + torch/_refs/__init__.py | 2 ++ .../_internal/common_methods_invocations.py | 25 +++++++++++++++++++ 9 files changed, 35 insertions(+), 6 deletions(-) diff --git a/test/distributed/_tensor/test_dtensor_ops.py b/test/distributed/_tensor/test_dtensor_ops.py index c6fd0e01e5ffd9..d76fccee216d1f 100644 --- a/test/distributed/_tensor/test_dtensor_ops.py +++ b/test/distributed/_tensor/test_dtensor_ops.py @@ -427,6 +427,7 @@ def wrapped(fn): xfail("special.xlog1py"), xfail("special.zeta"), xfail("squeeze", "multiple"), + xfail("squeeze_copy"), xfail("signal.windows.bartlett"), xfail("signal.windows.blackman"), xfail("signal.windows.cosine"), diff --git a/test/expect/HasDecompTest.test_has_decomposition.expect b/test/expect/HasDecompTest.test_has_decomposition.expect index 357a5afefd4d3e..2ecc1aef4c3954 100644 --- a/test/expect/HasDecompTest.test_has_decomposition.expect +++ b/test/expect/HasDecompTest.test_has_decomposition.expect @@ -1285,12 +1285,6 @@ aten::split_copy.Tensor_out aten::squeeze_ aten::squeeze_.dim aten::squeeze_.dims -aten::squeeze_copy -aten::squeeze_copy.dim -aten::squeeze_copy.dim_out -aten::squeeze_copy.dims -aten::squeeze_copy.dims_out -aten::squeeze_copy.out aten::sspaddmm.out aten::t_ aten::to_mkldnn diff --git a/test/functorch/test_ops.py b/test/functorch/test_ops.py index 9b04fb618ce616..1c94fae00c98ce 100644 --- a/test/functorch/test_ops.py +++ b/test/functorch/test_ops.py @@ -1417,6 +1417,7 @@ def test_vmapjvpall(self, device, dtype, op): xfail("masked.cumprod", ""), xfail("permute_copy"), xfail("renorm"), # hit vmap fallback, which is disabled + xfail("squeeze_copy"), xfail("t_copy"), xfail("transpose_copy"), xfail("unsqueeze_copy"), @@ -1484,6 +1485,7 @@ def test(): xfail("put"), xfail("quantile"), xfail("renorm"), + xfail("squeeze_copy"), xfail("take"), xfail("tensor_split"), xfail("to_sparse"), @@ -1544,6 +1546,7 @@ def test(): xfail( "index_fill" ), # aten::_unique hit the vmap fallback which is currently disabled + xfail("squeeze_copy"), xfail("t_copy"), xfail("transpose_copy"), xfail("unsqueeze_copy"), diff --git a/test/functorch/test_vmap.py b/test/functorch/test_vmap.py index 329ae0013bf293..3f369226183a9f 100644 --- a/test/functorch/test_vmap.py +++ b/test/functorch/test_vmap.py @@ -4439,6 +4439,7 @@ def test_vmap_exhaustive(self, device, dtype, op): xfail("put"), xfail("quantile"), xfail("renorm"), + xfail("squeeze_copy"), xfail("resize_as_"), xfail("take"), xfail("tensor_split"), diff --git a/test/test_mps.py b/test/test_mps.py index 20dae361ba8fcf..d6e19e9eeb92e2 100644 --- a/test/test_mps.py +++ b/test/test_mps.py @@ -337,6 +337,7 @@ def mps_ops_modifier(ops): 'split_with_sizes_copy', 'splitlist_args', 'squeeze', + 'squeeze_copy', 'squeezemultiple', 'sub', 'svd', diff --git a/tools/autograd/gen_variable_type.py b/tools/autograd/gen_variable_type.py index d1e9245cd72f9c..c456b127168baa 100644 --- a/tools/autograd/gen_variable_type.py +++ b/tools/autograd/gen_variable_type.py @@ -201,6 +201,7 @@ "permute", "permute_copy", "squeeze", + "squeeze_copy", "unsqueeze", "unsqueeze_copy", "resize", diff --git a/torch/_decomp/__init__.py b/torch/_decomp/__init__.py index c21ee837471178..34e8935be1f22c 100644 --- a/torch/_decomp/__init__.py +++ b/torch/_decomp/__init__.py @@ -621,6 +621,7 @@ def _core_aten_decompositions_post_autograd() -> ( aten.special_xlog1py, aten.split.Tensor, aten.split_with_sizes_copy, + aten.squeeze_copy, aten.squeeze.default, aten.squeeze.dim, aten.std, diff --git a/torch/_refs/__init__.py b/torch/_refs/__init__.py index 41b9031a304227..d6ebf7e55c2394 100644 --- a/torch/_refs/__init__.py +++ b/torch/_refs/__init__.py @@ -285,6 +285,7 @@ "stack", "swap_axes", # alias for transpose "squeeze", + "squeeze_copy", "t", "t_copy", "T", @@ -6337,6 +6338,7 @@ def select_scatter(x: TensorLikeType, src: TensorLikeType, dim: int, index: int) # no sparse support. See narrow_copy_sparse in core. narrow_copy = _make_copy_from_view(aten.narrow) permute_copy = _make_copy_from_view(aten.permute) +squeeze_copy = _make_copy_from_view(aten.squeeze) t_copy = _make_copy_from_view(aten.t) transpose_copy = _make_copy_from_view(aten.transpose) unsqueeze_copy = _make_copy_from_view(aten.unsqueeze) diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py index 9c23e438ba1c43..b5907f4025d967 100644 --- a/torch/testing/_internal/common_methods_invocations.py +++ b/torch/testing/_internal/common_methods_invocations.py @@ -19433,6 +19433,26 @@ def sample_inputs_alias_copy(op_info, device, dtype, requires_grad, **kwargs): # https://github.com/pytorch/pytorch/issues/66357 check_batched_forward_grad=False, sample_inputs_func=sample_inputs_squeeze_multiple), + OpInfo('squeeze_copy', + ref=_squeeze_ref, + dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.chalf), + supports_out=True, + assert_autodiffed=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + # vmap does not support inplace views + check_inplace_batched_forward_grad=False, + # https://github.com/pytorch/pytorch/issues/66357 + check_batched_forward_grad=False, + sample_inputs_func=sample_inputs_squeeze, + skips=( + DecorateInfo( + unittest.expectedFailure, + 'TestJit', + 'test_variant_consistency_jit', + dtypes=(torch.float32,), + ), + )), UnaryUfuncInfo( 'fill', ref=_fill_np, @@ -23827,6 +23847,11 @@ def sample_inputs_alias_copy(op_info, device, dtype, requires_grad, **kwargs): "_refs.squeeze", torch_opinfo_name="squeeze", ), + PythonRefInfo( + "_refs.squeeze_copy", + torch_opinfo_name="squeeze_copy", + supports_out=True, + ), PythonRefInfo( "_refs.squeeze", torch_opinfo_name="squeeze", From 15c04bd125c54bc33869ec06a56d60f890a8aea1 Mon Sep 17 00:00:00 2001 From: IvanKobzarev Date: Fri, 13 Sep 2024 12:25:41 -0700 Subject: [PATCH 0933/1018] [effects] Turn off dtype promotion for with_effects lowering (#136039) By default inductor promotes arguments to the common highest dtype. Having empty token with dtype=torch.float32 results in dtype promotion for effectful ops during lowering of with_effects. Disabling dtype promotion for this lowering. Removing previous workaround making token dtype torch.bool. Testing: ``` python test/distributed/test_c10d_functional_native.py -k test_inductor_dtypeview_memory_lea ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/136039 Approved by: https://github.com/bdhirsh, https://github.com/eellison, https://github.com/zou3519 --- torch/_higher_order_ops/effects.py | 3 +-- torch/_inductor/lowering.py | 2 +- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/torch/_higher_order_ops/effects.py b/torch/_higher_order_ops/effects.py index bb3d93bc69eb6b..651b3d6e2caa85 100644 --- a/torch/_higher_order_ops/effects.py +++ b/torch/_higher_order_ops/effects.py @@ -127,8 +127,7 @@ def get_effect_key(op, args, kwargs) -> Optional[_EffectType]: def new_token_tensor() -> torch.Tensor: - # Use dtype bool to not affect Inductor dtype promotions - return torch.tensor([], dtype=torch.bool) + return torch.tensor([]) @with_effects.py_impl(DispatchKey.CompositeExplicitAutograd) diff --git a/torch/_inductor/lowering.py b/torch/_inductor/lowering.py index 2bdeac572ddd9c..adc199a20c1bcf 100644 --- a/torch/_inductor/lowering.py +++ b/torch/_inductor/lowering.py @@ -6257,7 +6257,7 @@ def _sink_tokens(tokens): return None -@register_lowering(torch.ops.higher_order.with_effects) +@register_lowering(torch.ops.higher_order.with_effects, type_promotion_kind=None) def with_effects(token, op, *args, **kwargs): result = ir.EffectfulKernel.create(op, *args, **kwargs) From ac3c0a0f8524fad124614502f9cec81025eda1ac Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Mon, 16 Sep 2024 16:46:13 +0000 Subject: [PATCH 0934/1018] [ONNX] Run type promotion test in CI and update the table (#135915) Pull Request resolved: https://github.com/pytorch/pytorch/pull/135915 Approved by: https://github.com/gramalingam, https://github.com/xadupre --- test/onnx/test_fx_type_promotion.py | 12 +++--------- torch/onnx/_internal/fx/passes/type_promotion.py | 14 +++++++------- 2 files changed, 10 insertions(+), 16 deletions(-) diff --git a/test/onnx/test_fx_type_promotion.py b/test/onnx/test_fx_type_promotion.py index 1e3860ad2a8efa..fc7dc21fba0069 100644 --- a/test/onnx/test_fx_type_promotion.py +++ b/test/onnx/test_fx_type_promotion.py @@ -1,22 +1,16 @@ # Owner(s): ["module: onnx"] -import pytorch_test_common - from torch.onnx._internal.fx.passes import type_promotion from torch.testing._internal import common_utils class TestGeneratedTypePromotionRuleSet(common_utils.TestCase): - @pytorch_test_common.skip_in_ci( - "Reduce noise in CI. " - "The test serves as a tool to validate if the generated rule set is current. " - ) def test_generated_rule_set_is_up_to_date(self): generated_set = type_promotion._GENERATED_ATEN_TYPE_PROMOTION_RULE_SET - latest_set = ( - type_promotion.TypePromotionRuleSetGenerator.generate_from_torch_refs() - ) + latest_set = type_promotion.ElementwiseTypePromotionRuleSetGenerator.generate_from_torch_refs() + # Please update the list in torch/onnx/_internal/fx/passes/type_promotion.py following the instruction + # if this test fails self.assertEqual(generated_set, latest_set) def test_initialize_type_promotion_table_succeeds(self): diff --git a/torch/onnx/_internal/fx/passes/type_promotion.py b/torch/onnx/_internal/fx/passes/type_promotion.py index 6397beb5f089a4..81cb6ccb7439d9 100644 --- a/torch/onnx/_internal/fx/passes/type_promotion.py +++ b/torch/onnx/_internal/fx/passes/type_promotion.py @@ -554,6 +554,9 @@ def preview_type_promotion( ElementwiseTypePromotionRule( "aten", "digamma_", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT ), + ElementwiseTypePromotionRule( + "aten", "dot", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT + ), ElementwiseTypePromotionRule( "aten", "elu", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT ), @@ -870,10 +873,7 @@ def preview_type_promotion( "aten", "nll_loss", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT ), ElementwiseTypePromotionRule( - "aten", "normal", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT - ), - ElementwiseTypePromotionRule( - "aten", "normal_", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT + "aten", "normal", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT ), ElementwiseTypePromotionRule( "aten", "pdist", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT @@ -924,9 +924,6 @@ def preview_type_promotion( ElementwiseTypePromotionRule( "aten", "rsqrt_", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT ), - ElementwiseTypePromotionRule( - "aten", "rsub", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT - ), ElementwiseTypePromotionRule( "aten", "selu", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT ), @@ -1030,6 +1027,9 @@ def preview_type_promotion( ElementwiseTypePromotionRule( "aten", "trunc_", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT ), + ElementwiseTypePromotionRule( + "aten", "vdot", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT + ), ElementwiseTypePromotionRule( "aten", "where", [1, 2], [], ELEMENTWISE_TYPE_PROMOTION_KIND.NO_OPMATH ), From 992cad855302be302f11fbae6a5e6c890ba755be Mon Sep 17 00:00:00 2001 From: Aaron Gokaslan Date: Mon, 16 Sep 2024 17:49:12 +0000 Subject: [PATCH 0935/1018] [BE][Ez]: Add full half/bfloat16 dtype for `unique` and `isin` (#136114) Fixes #136090 * Add support for isin to tensor half dtypes for CPU (just add a few extra dispatches). * Seems like the CUDA implementation for bfloat16 was mostly compiled and available all along (it just calls sort internally AND unique). To enable it, we just need to remove an assert to access it (since sort's functionality was updated since the assert was added) and add missing dtype support to unique. * This unlocks more GPU functionality with minimal code bloat. I also added CPU kernels for the dtypes for parity. Pull Request resolved: https://github.com/pytorch/pytorch/pull/136114 Approved by: https://github.com/malfet --- aten/src/ATen/cuda/cub-RadixSortKeys.cu | 2 +- aten/src/ATen/native/TensorCompare.cpp | 1 - aten/src/ATen/native/cpu/TensorCompareKernel.cpp | 2 +- aten/src/ATen/native/cuda/Unique.cu | 10 +++++----- aten/src/ATen/native/cuda/UniqueCub.cu | 1 + torch/_meta_registrations.py | 2 +- torch/testing/_internal/common_methods_invocations.py | 7 ++----- 7 files changed, 11 insertions(+), 14 deletions(-) diff --git a/aten/src/ATen/cuda/cub-RadixSortKeys.cu b/aten/src/ATen/cuda/cub-RadixSortKeys.cu index 74e82ae55cdee0..56571295155268 100644 --- a/aten/src/ATen/cuda/cub-RadixSortKeys.cu +++ b/aten/src/ATen/cuda/cub-RadixSortKeys.cu @@ -50,7 +50,7 @@ void radix_sort_keys( int64_t begin_bit, \ int64_t end_bit); -AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, AT_INSTATIATE_CUB_TEMPLATES) +AT_FORALL_SCALAR_TYPES_AND3(Bool, BFloat16, Half, AT_INSTATIATE_CUB_TEMPLATES) AT_INSTATIATE_CUB_TEMPLATES(uint16_t, UInt16) AT_INSTATIATE_CUB_TEMPLATES(uint32_t, UInt32) AT_INSTATIATE_CUB_TEMPLATES(uint64_t, UInt64) diff --git a/aten/src/ATen/native/TensorCompare.cpp b/aten/src/ATen/native/TensorCompare.cpp index 6d6db1477f1f8f..c82e429621812f 100644 --- a/aten/src/ATen/native/TensorCompare.cpp +++ b/aten/src/ATen/native/TensorCompare.cpp @@ -82,7 +82,6 @@ namespace at::meta { static inline void check_for_unsupported_isin_dtype(const ScalarType type) { // Bail out for dtypes unsupported by the sorting algorithm to keep the interface consistent. TORCH_CHECK(type != ScalarType::Bool && - type != ScalarType::BFloat16 && type != ScalarType::ComplexFloat && type != ScalarType::ComplexDouble, "Unsupported input type encountered for isin(): ", type); diff --git a/aten/src/ATen/native/cpu/TensorCompareKernel.cpp b/aten/src/ATen/native/cpu/TensorCompareKernel.cpp index 44722034ab3baa..b374935036dad9 100644 --- a/aten/src/ATen/native/cpu/TensorCompareKernel.cpp +++ b/aten/src/ATen/native/cpu/TensorCompareKernel.cpp @@ -325,7 +325,7 @@ static void isin_default_kernel_cpu( .check_all_same_dtype(false) .build(); // Dispatch based on promoted type. - AT_DISPATCH_ALL_TYPES(iter.dtype(1), "isin_default_cpu", [&]() { + AT_DISPATCH_ALL_TYPES_AND2(kBFloat16, kHalf, iter.dtype(1), "isin_default_cpu", [&]() { cpu_kernel(iter, [&](scalar_t element_val) -> bool { const auto* test_element_data = test_elements_flat.const_data_ptr(); for (const auto j : c10::irange(test_elements_flat.numel())) { diff --git a/aten/src/ATen/native/cuda/Unique.cu b/aten/src/ATen/native/cuda/Unique.cu index 39e80e0a68c3c8..67eeacee29a850 100644 --- a/aten/src/ATen/native/cuda/Unique.cu +++ b/aten/src/ATen/native/cuda/Unique.cu @@ -191,7 +191,7 @@ _unique_cuda(const Tensor& self, const bool sorted, const bool return_inverse) { // lack of hashtable implementation in thrust auto [output, inverse, _] = internal::unique_cuda_template(self, false, return_inverse, false); return std::make_tuple(output, inverse); - }), AT_EXPAND(AT_ALL_TYPES), kBool, kHalf, AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES)); + }), AT_EXPAND(AT_ALL_TYPES), kBool, kBFloat16, kHalf, AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES)); } std::tuple @@ -200,21 +200,21 @@ _unique2_cuda(const Tensor& self, const bool sorted, const bool return_inverse, // The current CUDA implementation of unique always sort due to the // lack of hashtable implementation in thrust return internal::unique_cuda_template(self, false, return_inverse, return_counts); - }), AT_EXPAND(AT_ALL_TYPES), kBool, kHalf, AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES)); + }), AT_EXPAND(AT_ALL_TYPES), kBool, kBFloat16, kHalf, AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES)); } std::tuple unique_dim_cuda(const Tensor& self, const int64_t dim, const bool sorted, const bool return_inverse, const bool return_counts) { return AT_DISPATCH_V2(self.scalar_type(), "unique_dim", AT_WRAP([&] { return unique_dim_cuda_template(self, dim, false, return_inverse, return_counts); - }), AT_EXPAND(AT_ALL_TYPES), kBool, kHalf, AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES)); + }), AT_EXPAND(AT_ALL_TYPES), kBool, kBFloat16, kHalf, AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES)); } std::tuple unique_dim_consecutive_cuda(const Tensor& self, const int64_t dim, const bool return_inverse, const bool return_counts) { return AT_DISPATCH_V2(self.scalar_type(), "unique_dim", AT_WRAP([&] { return unique_dim_cuda_template(self, dim, true, return_inverse, return_counts); - }), AT_EXPAND(AT_ALL_TYPES), kBool, kHalf, AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES)); + }), AT_EXPAND(AT_ALL_TYPES), kBool, kBFloat16, kHalf, AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES)); } std::tuple @@ -224,7 +224,7 @@ unique_consecutive_cuda(const Tensor& self, const bool return_inverse, const boo // The current CUDA implementation of unique always sort due to the // lack of hashtable implementation in thrust return internal::unique_cuda_template(self, true, return_inverse, return_counts); - }), AT_EXPAND(AT_ALL_TYPES), kBool, kHalf, AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES)); + }), AT_EXPAND(AT_ALL_TYPES), kBool, kBFloat16, kHalf, AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES)); } return unique_dim_consecutive_cuda(self, dim.value(), return_inverse, return_counts); } diff --git a/aten/src/ATen/native/cuda/UniqueCub.cu b/aten/src/ATen/native/cuda/UniqueCub.cu index bbd8673bcf5a65..1bda65815d6d49 100644 --- a/aten/src/ATen/native/cuda/UniqueCub.cu +++ b/aten/src/ATen/native/cuda/UniqueCub.cu @@ -339,6 +339,7 @@ INSTANTIATE_UNIQUE_CUDA_TEMPLATE(uint32_t); INSTANTIATE_UNIQUE_CUDA_TEMPLATE(uint64_t); INSTANTIATE_UNIQUE_CUDA_TEMPLATE(uint16_t); INSTANTIATE_UNIQUE_CUDA_TEMPLATE(bool); +INSTANTIATE_UNIQUE_CUDA_TEMPLATE(BFloat16); INSTANTIATE_UNIQUE_CUDA_TEMPLATE(at::Half); #undef INSTANTIATE diff --git a/torch/_meta_registrations.py b/torch/_meta_registrations.py index ce37b0a96fbfb5..820c21218cb1c3 100644 --- a/torch/_meta_registrations.py +++ b/torch/_meta_registrations.py @@ -6343,7 +6343,7 @@ def meta_searchsorted( def _check_for_unsupported_isin_dtype(dtype): torch._check( - dtype not in [torch.bool, torch.bfloat16, torch.complex128, torch.complex64], + dtype not in (torch.bool, torch.complex128, torch.complex64), lambda: f"Unsupported input type encountered for isin(): {dtype}", ) diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py index b5907f4025d967..06467f6459d840 100644 --- a/torch/testing/_internal/common_methods_invocations.py +++ b/torch/testing/_internal/common_methods_invocations.py @@ -20,7 +20,7 @@ from torch.testing._internal.common_dtype import ( _dispatch_dtypes, floating_types, floating_types_and, complex_types, floating_and_complex_types, floating_and_complex_types_and, all_types_and_complex_and, all_types_and, all_types_and_complex, integral_types_and, - all_types, empty_types, complex_types_and, integral_types, custom_types, + empty_types, complex_types_and, integral_types, custom_types, ) from torch.testing._internal.common_device_type import \ (onlyCPU, onlyCUDA, onlyNativeDeviceTypes, disablecuDNN, skipCUDAIfNoMagma, skipCUDAIfNoMagmaAndNoCusolver, @@ -13579,8 +13579,7 @@ def sample_inputs_alias_copy(op_info, device, dtype, requires_grad, **kwargs): sample_inputs_func=sample_inputs_gradient, error_inputs_func=error_inputs_gradient), OpInfo('isin', - dtypes=all_types(), - dtypesIfCUDA=all_types_and(torch.half), + dtypes=all_types_and(torch.bfloat16, torch.half), supports_autograd=False, sample_inputs_func=sample_inputs_isin), OpInfo('kthvalue', @@ -18335,7 +18334,6 @@ def sample_inputs_alias_copy(op_info, device, dtype, requires_grad, **kwargs): )), OpInfo('unique', dtypes=all_types_and(torch.bool, torch.float16, torch.bfloat16, torch.uint16, torch.uint32, torch.uint64), - dtypesIfCUDA=all_types_and(torch.bool, torch.float16, torch.uint16, torch.uint32, torch.uint64), sample_inputs_func=sample_inputs_unique, supports_out=False, supports_autograd=False, @@ -18347,7 +18345,6 @@ def sample_inputs_alias_copy(op_info, device, dtype, requires_grad, **kwargs): )), OpInfo('unique_consecutive', dtypes=all_types_and(torch.bool, torch.float16, torch.bfloat16), - dtypesIfCUDA=all_types_and(torch.bool, torch.float16), sample_inputs_func=sample_inputs_unique_consecutive, supports_out=False, supports_autograd=False, From c1fd5c02818c1a83da3ab10620096f92830f8998 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Mon, 16 Sep 2024 17:58:02 +0000 Subject: [PATCH 0936/1018] Revert "[inductor] More fixes on the keys of `constants` and `signature` dictionaries (#135406)" This reverts commit e54b559e8860e343692bb5534777b2384a57a613. Reverted https://github.com/pytorch/pytorch/pull/135406 on behalf of https://github.com/jeanschmidt due to Reverting as it is breaking triton_mtia internal signals @jansel could you have a look and help get those changes merged? ([comment](https://github.com/pytorch/pytorch/pull/135406#issuecomment-2353557481)) --- docs/source/torch.compiler_get_started.rst | 2 +- test/inductor/test_cuda_repro.py | 6 +----- test/inductor/test_metrics.py | 2 +- test/inductor/test_triton_heuristics.py | 2 +- torch/_inductor/codegen/triton.py | 6 +++--- torch/_inductor/codegen/triton_combo_kernel.py | 10 ++++------ torch/_inductor/codegen/triton_utils.py | 5 ++--- torch/_inductor/codegen/wrapper.py | 11 +++++------ torch/_inductor/runtime/triton_heuristics.py | 2 +- torch/_inductor/select_algorithm.py | 6 ++---- 10 files changed, 21 insertions(+), 31 deletions(-) diff --git a/docs/source/torch.compiler_get_started.rst b/docs/source/torch.compiler_get_started.rst index 2b5bec254958f5..8e18ad411544df 100644 --- a/docs/source/torch.compiler_get_started.rst +++ b/docs/source/torch.compiler_get_started.rst @@ -57,7 +57,7 @@ the following: .. code-block:: python - @pointwise(size_hints=[16384], filename=__file__, triton_meta={'signature': {'in_ptr0': '*fp32', 'out_ptr0': '*fp32', 'xnumel': 'i32'}, 'device': 0, 'constants': {}, 'mutated_arg_names': [], 'configs': [instance_descriptor(divisible_by_16=(0, 1, 2), equal_to_1=())]}) + @pointwise(size_hints=[16384], filename=__file__, triton_meta={'signature': {0: '*fp32', 1: '*fp32', 2: 'i32'}, 'device': 0, 'constants': {}, 'mutated_arg_names': [], 'configs': [instance_descriptor(divisible_by_16=(0, 1, 2), equal_to_1=())]}) @triton.jit def triton_(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr): xnumel = 10000 diff --git a/test/inductor/test_cuda_repro.py b/test/inductor/test_cuda_repro.py index b493ce92539a2f..98489a383435e4 100644 --- a/test/inductor/test_cuda_repro.py +++ b/test/inductor/test_cuda_repro.py @@ -433,11 +433,7 @@ def decorator(fn): triton.Config({"XBLOCK": 2}), ], meta={ - "signature": { - "in_out_ptr0": "*fp32", - "in_ptr0": "*fp32", - "xnumel": "i32", - }, + "signature": {0: "*fp32", 1: "*fp32", 2: "i32"}, "device": DeviceProperties.create(torch.device("cuda")), "configs": [instance_descriptor(divisible_by_16=(0, 1), equal_to_1=())], "constants": {}, diff --git a/test/inductor/test_metrics.py b/test/inductor/test_metrics.py index bec9943eafd7eb..f8c4815cc8946d 100644 --- a/test/inductor/test_metrics.py +++ b/test/inductor/test_metrics.py @@ -14,7 +14,7 @@ reduction_hint=ReductionHint.INNER, filename=__file__, triton_meta={ - 'signature': {'in_out_ptr0': '*fp32', 'in_ptr0': '*fp32', 'xnumel': 'i32', 'rnumel': 'i32'}, + 'signature': {0: '*fp32', 1: '*fp32', 2: 'i32', 3: 'i32'}, 'device': 0, 'device_type': 'GPU_TYPE', 'constants': {}, diff --git a/test/inductor/test_triton_heuristics.py b/test/inductor/test_triton_heuristics.py index 65daa638520021..24f322dfebb845 100644 --- a/test/inductor/test_triton_heuristics.py +++ b/test/inductor/test_triton_heuristics.py @@ -102,7 +102,7 @@ def triton_(in_ptr0, out_ptr0, xnumel, XBLOCK: tl.constexpr): tl.store(out_ptr0 + (x0), tmp1, xmask) triton_meta = { - "signature": {"in_ptr0": "*fp32", "out_ptr0": "*fp32", "xnumel": "i32"}, + "signature": {0: "*fp32", 1: "*fp32", 2: "i32"}, "device": DeviceProperties.create(torch.device("cuda")), "constants": {}, "configs": [AttrsDescriptor(divisible_by_16=(0, 1, 2), equal_to_1=())], diff --git a/torch/_inductor/codegen/triton.py b/torch/_inductor/codegen/triton.py index a85fa4172829fb..b52ca309dfbb28 100644 --- a/torch/_inductor/codegen/triton.py +++ b/torch/_inductor/codegen/triton.py @@ -2652,7 +2652,7 @@ def codegen_kernel(self, name=None): mutated_args = sorted(mutated_args) triton_meta_signature = signature_to_meta( - signature, size_dtype=self.index_dtype, argdefs=argdefs + signature, size_dtype=self.index_dtype ) triton_meta = { "signature": triton_meta_signature, @@ -2680,7 +2680,7 @@ def codegen_kernel(self, name=None): for tree in self.active_range_trees(): sizearg = SizeArg(f"{tree.prefix}numel", tree.numel) signature.append(sizearg) - triton_meta_signature[sizearg.name] = signature_of( + triton_meta_signature[len(argdefs)] = signature_of( sizearg, size_dtype=self.index_dtype ) argdefs.append(f"{tree.prefix}numel") @@ -2698,7 +2698,7 @@ def codegen_kernel(self, name=None): # https://github.com/pytorch/pytorch/issues/120478#issuecomment-1962822307 # https://github.com/openai/triton/blob/231efe9ed2d200be0f69a07c298e4342b08efe3d/python/triton/runtime/jit.py#L384 for arg_num in triton_meta["configs"][0].equal_to_1: # type: ignore[index] - triton_meta["constants"][signature[arg_num].name] = 1 # type: ignore[index] + triton_meta["constants"][arg_num] = 1 # type: ignore[index] self.triton_meta = triton_meta diff --git a/torch/_inductor/codegen/triton_combo_kernel.py b/torch/_inductor/codegen/triton_combo_kernel.py index c5ab3a922d5706..134c226d38effe 100644 --- a/torch/_inductor/codegen/triton_combo_kernel.py +++ b/torch/_inductor/codegen/triton_combo_kernel.py @@ -660,19 +660,18 @@ def jit_line( heuristics: str, size_hints: List[int], selected_kernel: TritonKernel, - signature: List[Any], - argdefs: List[str], pointwise_with_reduce: bool = False, + signature: Optional[List[Any]] = None, ) -> str: can_use_32bit = all(k.index_dtype == "tl.int32" for k in self.sub_kernels) size_dtype = "tl.int32" if can_use_32bit else "tl.int64" + if signature is None: + _, _, signature, _ = self.args.python_argdefs() for i, sub in enumerate(self.sub_kernels): self.min_x_blocks_sub_kernel(sub, i) self.select_dispatch_strategy() triton_meta = { - "signature": signature_to_meta( - signature, size_dtype=size_dtype, argdefs=argdefs - ), + "signature": signature_to_meta(signature, size_dtype=size_dtype), "device": DeviceProperties.create( V.graph.scheduler.get_current_device_or_throw() ), @@ -851,7 +850,6 @@ def codegen_kernel(self, name: Optional[str] = None) -> str: selected_kernel, pointwise_with_reduce=pointwise_with_reduction, signature=signature, - argdefs=argdefs, ) ) code.writeline( diff --git a/torch/_inductor/codegen/triton_utils.py b/torch/_inductor/codegen/triton_utils.py index 77ef3abe025d8a..c8dbd516f13595 100644 --- a/torch/_inductor/codegen/triton_utils.py +++ b/torch/_inductor/codegen/triton_utils.py @@ -68,13 +68,12 @@ def signature_to_meta( signature: List[KernelArgType], *, size_dtype: str, - argdefs: List[str], indices: Optional[List[int]] = None, -) -> Dict[str, str]: +) -> Dict[int, str]: if indices is None: indices = list(range(len(signature))) return { - argdefs[i]: signature_of(arg, size_dtype=size_dtype) + i: signature_of(arg, size_dtype=size_dtype) for i, arg in zip(indices, signature) } diff --git a/torch/_inductor/codegen/wrapper.py b/torch/_inductor/codegen/wrapper.py index 75cbec0c8610a1..ce7bf65bd5f665 100644 --- a/torch/_inductor/codegen/wrapper.py +++ b/torch/_inductor/codegen/wrapper.py @@ -1278,15 +1278,15 @@ def define_user_defined_triton_kernel(self, kernel, configs, kwargs): from .common import KernelArgType, SizeArg, TensorArg signature: List[KernelArgType] = [] - constants: Dict[str, Any] = {} + constants: Dict[int, Any] = {} non_constant_indices = [] - equal_to_1_args: List[str] = [] + equal_to_1_arg_idx: List[int] = [] for idx, key in enumerate(kernel.arg_names): if key not in kwargs: continue arg = kwargs[key] if idx in kernel.constexprs: - constants[key] = arg + constants[idx] = arg else: non_constant_indices.append(idx) if isinstance(arg, ir.Buffer): @@ -1316,14 +1316,13 @@ def define_user_defined_triton_kernel(self, kernel, configs, kwargs): ) and V.graph.sizevars.statically_known_equals( arg, 1 # type: ignore[arg-type] ): - equal_to_1_args.append(key) + equal_to_1_arg_idx.append(idx) index_dtype = "tl.int32" triton_meta = { "signature": signature_to_meta( signature, size_dtype=index_dtype, indices=non_constant_indices, - argdefs=kernel.arg_names, ), "device": DeviceProperties.create( V.graph.scheduler.get_current_device_or_throw() @@ -1337,7 +1336,7 @@ def define_user_defined_triton_kernel(self, kernel, configs, kwargs): # https://github.com/openai/triton/blob/231efe9ed2d200be0f69a07c298e4342b08efe3d/python/triton/runtime/jit.py#L384 "constants": { **constants, - **dict.fromkeys(equal_to_1_args, 1), + **dict.fromkeys(equal_to_1_arg_idx, 1), }, "configs": [ config_of( diff --git a/torch/_inductor/runtime/triton_heuristics.py b/torch/_inductor/runtime/triton_heuristics.py index e76e77b3cccf89..037f67fb429ee0 100644 --- a/torch/_inductor/runtime/triton_heuristics.py +++ b/torch/_inductor/runtime/triton_heuristics.py @@ -360,7 +360,7 @@ def _precompile_config(self, cfg: Config, warm_cache_only: bool): if k == "waves_per_eu": compile_meta["waves_per_eu"] = v continue - compile_meta["constants"][k] = v + compile_meta["constants"][self.fn.arg_names.index(k)] = v compile_meta["num_warps"] = cfg.num_warps compile_meta["num_stages"] = cfg.num_stages compile_meta["debug"] = self.inductor_meta.get( diff --git a/torch/_inductor/select_algorithm.py b/torch/_inductor/select_algorithm.py index a4e46acceef727..2685f730d659be 100644 --- a/torch/_inductor/select_algorithm.py +++ b/torch/_inductor/select_algorithm.py @@ -219,15 +219,13 @@ def jit_lines(self): argdefs, _, signature, _ = self.args.python_argdefs() triton_meta = { - "signature": signature_to_meta( - signature, size_dtype=self.index_dtype, argdefs=argdefs - ), + "signature": signature_to_meta(signature, size_dtype=self.index_dtype), "device": DeviceProperties.create(self.output_node.get_device()), "constants": {}, } triton_meta["configs"] = [config_of(signature)] for arg_num in triton_meta["configs"][0].equal_to_1: # type: ignore[index] - triton_meta["constants"][signature[arg_num].name] = 1 # type: ignore[index] + triton_meta["constants"][arg_num] = 1 # type: ignore[index] matrix_instr_nonkdim = self.meta.get("matrix_instr_nonkdim", 0) if matrix_instr_nonkdim != 0: triton_meta["matrix_instr_nonkdim"] = matrix_instr_nonkdim From 6ca6c636fb7eab9fc30f1bb190d3aff01cbe84d5 Mon Sep 17 00:00:00 2001 From: Suresh Babu Kolla Date: Mon, 16 Sep 2024 18:10:33 +0000 Subject: [PATCH 0937/1018] [Pytorch] Cleanup Strobelight URL and shorten for readability (#136102) Summary: - Converted strobelight URL prefix to more readable and editable json - Dump shortened URLs when possible for easier readability Test Plan: ``` python ./torch/_strobelight/examples/compile_time_profile_example.py python torch/_strobelight/examples/cli_function_profiler_example.py ``` Differential Revision: D62690292 Pull Request resolved: https://github.com/pytorch/pytorch/pull/136102 Approved by: https://github.com/laithsakka --- torch/_strobelight/compile_time_profiler.py | 79 +++++++++++++++------ 1 file changed, 58 insertions(+), 21 deletions(-) diff --git a/torch/_strobelight/compile_time_profiler.py b/torch/_strobelight/compile_time_profiler.py index 197fcad97165e1..13132188a1930e 100644 --- a/torch/_strobelight/compile_time_profiler.py +++ b/torch/_strobelight/compile_time_profiler.py @@ -1,7 +1,9 @@ # mypy: disallow-untyped-defs +import json import logging import os +import subprocess from datetime import datetime from socket import gethostname from typing import Any, Optional @@ -22,6 +24,60 @@ logger.propagate = False +def get_fburl(url: str) -> str: + short_url = url + # Attempt to shorten the URL + try: + result = subprocess.run( + ["fburl", url], capture_output=True, stdin=subprocess.DEVNULL + ) + if result.returncode == 0: + short_url = result.stdout.decode("utf-8") + except Exception as e: + logger.warning("URL shortening failed: %s, using long URL", repr(e)) + return short_url + + +def get_strobelight_url(identifier: str) -> str: + scuba_json = { + "aggregateList": [], + "aggregation_field": "async_stack_complete", + "b_constraints": [[]], + "c_constraints": [[]], + "cols": ["namespace_id", "namespace_process_id"], + "compare": "none", + "constraints": [ + [{"column": "sample_tags", "op": "all", "value": [f'["{identifier}"]']}] + ], + "derivedCols": [], + "end": "now", + "enumCols": [], + "filterMode": "DEFAULT", + "hideEmptyColumns": "false", + "ignoreGroupByInComparison": "false", + "is_timeseries": "false", + "mappedCols": [], + "metric": "count", + "modifiers": [], + "order": "weight", + "order_desc": "true", + "param_dimensions": [ + {"dim": "py_async_stack", "op": "edge", "param": "0", "anchor": "0"} + ], + "purposes": [], + "return_remainder": "false", + "samplingRatio": "1", + "should_pivot": "false", + "start": "-30 days", + "timezone": "America/Los_Angeles", + "top": 10000, + } + scuba_url_prefix = "https://www.internalfb.com/intern/scuba/query/?dataset=pyperf_experimental/on_demand&drillstate=" + scuba_url_suff = "&view=GraphProfilerView&&normalized=1726332703&pool=uber" + long_url = scuba_url_prefix + json.dumps(scuba_json) + scuba_url_suff + return get_fburl(long_url) + + class StrobelightCompileTimeProfiler: success_profile_count: int = 0 failed_profile_count: int = 0 @@ -89,27 +145,8 @@ def _cls_init(cls) -> None: logger.info("Unique sample tag for this run is: %s", cls.identifier) logger.info( - "You can use the following link to access the strobelight profile at the end of the run: %s", - ( - "https://www.internalfb.com/intern/scuba/query/?dataset=pyperf_experime" - "ntal%2Fon_demand&drillstate=%7B%22purposes%22%3A[]%2C%22end%22%3A%22no" - "w%22%2C%22start%22%3A%22-30%20days%22%2C%22filterMode%22%3A%22DEFAULT%" - "22%2C%22modifiers%22%3A[]%2C%22sampleCols%22%3A[]%2C%22cols%22%3A[%22n" - "amespace_id%22%2C%22namespace_process_id%22]%2C%22derivedCols%22%3A[]%" - "2C%22mappedCols%22%3A[]%2C%22enumCols%22%3A[]%2C%22return_remainder%22" - "%3Afalse%2C%22should_pivot%22%3Afalse%2C%22is_timeseries%22%3Afalse%2C" - "%22hideEmptyColumns%22%3Afalse%2C%22timezone%22%3A%22America%2FLos_Ang" - "eles%22%2C%22compare%22%3A%22none%22%2C%22samplingRatio%22%3A%221%22%2" - "C%22metric%22%3A%22count%22%2C%22aggregation_field%22%3A%22async_stack" - "_complete%22%2C%22top%22%3A10000%2C%22aggregateList%22%3A[]%2C%22param" - "_dimensions%22%3A[%7B%22dim%22%3A%22py_async_stack%22%2C%22op%22%3A%22" - "edge%22%2C%22param%22%3A%220%22%2C%22anchor%22%3A%220%22%7D]%2C%22orde" - "r%22%3A%22weight%22%2C%22order_desc%22%3Atrue%2C%22constraints%22%3A[[" - "%7B%22column%22%3A%22sample_tags%22%2C%22op%22%3A%22all%22%2C%22value%" - f"22%3A[%22[%5C%22{cls.identifier}%5C%22]%22]%7D]]%2C%22c_constraints%22%3A[[]]%2C%22b" - "_constraints%22%3A[[]]%2C%22ignoreGroupByInComparison%22%3Afalse%7D&vi" - "ew=GraphProfilerView&&normalized=1712358002&pool=uber" - ), + "URL to access the strobelight profile at the end of the run: %s", + get_strobelight_url(cls.identifier), ) @classmethod From 1712cd6a7609aa944167001f44a946d75dba7bc6 Mon Sep 17 00:00:00 2001 From: Aaron Gokaslan Date: Mon, 16 Sep 2024 18:22:14 +0000 Subject: [PATCH 0938/1018] [BE][Ez]: Fix missing float16 coverage for adaptive_pool3d_cpu (#136091) Testing if op info coverage has issues Pull Request resolved: https://github.com/pytorch/pytorch/pull/136091 Approved by: https://github.com/ezyang --- aten/src/ATen/native/AdaptiveMaxPooling3d.cpp | 8 ++++---- torch/testing/_internal/common_methods_invocations.py | 4 +--- 2 files changed, 5 insertions(+), 7 deletions(-) diff --git a/aten/src/ATen/native/AdaptiveMaxPooling3d.cpp b/aten/src/ATen/native/AdaptiveMaxPooling3d.cpp index 1c037c31c6f64d..c0f2399138cea1 100644 --- a/aten/src/ATen/native/AdaptiveMaxPooling3d.cpp +++ b/aten/src/ATen/native/AdaptiveMaxPooling3d.cpp @@ -297,7 +297,7 @@ TORCH_IMPL_FUNC(adaptive_max_pool3d_out_cpu) int64_t osizeW = output_size[2]; if (input.ndimension() == 4) { - AT_DISPATCH_FLOATING_TYPES_AND(kBFloat16, + AT_DISPATCH_FLOATING_TYPES_AND2(kBFloat16, kHalf, input.scalar_type(), "adaptive_max_pool3d_cpu", [&] { auto input_data = input.const_data_ptr(); auto output_data = output.data_ptr(); @@ -320,7 +320,7 @@ TORCH_IMPL_FUNC(adaptive_max_pool3d_out_cpu) istrideW); }); } else { - AT_DISPATCH_FLOATING_TYPES_AND(kBFloat16, + AT_DISPATCH_FLOATING_TYPES_AND2(kBFloat16, kHalf, input.scalar_type(), "adaptive_max_pool3d_cpu", [&] { auto input_data = input.const_data_ptr(); auto output_data = output.data_ptr(); @@ -390,7 +390,7 @@ TORCH_IMPL_FUNC(adaptive_max_pool3d_backward_out_cpu) /* backprop */ if (input.ndimension() == 4) { - AT_DISPATCH_FLOATING_TYPES_AND(kBFloat16, + AT_DISPATCH_FLOATING_TYPES_AND2(kBFloat16, kHalf, input.scalar_type(), "adaptive_max_pool3d_backward", [&] { /* get raw pointers */ scalar_t* gradInput_data = gradInput.data_ptr(); @@ -410,7 +410,7 @@ TORCH_IMPL_FUNC(adaptive_max_pool3d_backward_out_cpu) osizeW); }); } else { - AT_DISPATCH_FLOATING_TYPES_AND(kBFloat16, + AT_DISPATCH_FLOATING_TYPES_AND2(kBFloat16, kHalf, input.scalar_type(), "adaptive_max_pool3d_backward", [&] { /* get raw pointers */ scalar_t* gradInput_data = gradInput.data_ptr(); diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py index 06467f6459d840..cc61f8bbc69863 100644 --- a/torch/testing/_internal/common_methods_invocations.py +++ b/torch/testing/_internal/common_methods_invocations.py @@ -14709,7 +14709,6 @@ def sample_inputs_alias_copy(op_info, device, dtype, requires_grad, **kwargs): sample_inputs_func=sample_inputs_adaptive_avg_pool2d), OpInfo('nn.functional.adaptive_avg_pool3d', dtypes=floating_types_and(torch.half, torch.bfloat16), - dtypesIfCUDA=floating_types_and(torch.half, torch.bfloat16), decorators=( # RuntimeError: # adaptive_avg_pool3d(Tensor input, int[3] output_size) -> (Tensor): @@ -14763,8 +14762,7 @@ def sample_inputs_alias_copy(op_info, device, dtype, requires_grad, **kwargs): error_inputs_func=error_inputs_adaptive_max_pool2d, sample_inputs_func=sample_inputs_adaptive_max_pool2d), OpInfo('nn.functional.adaptive_max_pool3d', - dtypes=floating_types_and(torch.bfloat16), - dtypesIfCUDA=floating_types_and(torch.half, torch.bfloat16), + dtypes=floating_types_and(torch.bfloat16, torch.half), decorators=( # RuntimeError: # adaptive_max_pool3d(Tensor input, int[3] output_size) -> (Tensor): From 7136fe1d0d25e6967624c0816fe3f049b800b5ba Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Mon, 16 Sep 2024 18:33:33 +0000 Subject: [PATCH 0939/1018] Revert "Add CI for Triton CPU backend (#135342)" This reverts commit 426580a67db15ec17b2b861a09667bf59927e033. Reverted https://github.com/pytorch/pytorch/pull/135342 on behalf of https://github.com/jeanschmidt due to Broke internal signals, see D62737208 for more details ([comment](https://github.com/pytorch/pytorch/pull/133408#issuecomment-2353623816)) --- .ci/docker/build.sh | 8 - .ci/docker/ci_commit_pins/triton-cpu.txt | 1 - .ci/docker/common/install_triton.sh | 8 +- .ci/docker/ubuntu/Dockerfile | 7 - .ci/pytorch/test.sh | 7 - .github/workflows/inductor.yml | 22 --- test/inductor/test_torchinductor.py | 182 ++++++---------------- test/inductor/test_triton_cpu_backend.py | 12 +- torch/testing/_internal/inductor_utils.py | 1 - 9 files changed, 51 insertions(+), 197 deletions(-) delete mode 100644 .ci/docker/ci_commit_pins/triton-cpu.txt diff --git a/.ci/docker/build.sh b/.ci/docker/build.sh index 9e0abc37ab8578..8aafbb89d811ec 100755 --- a/.ci/docker/build.sh +++ b/.ci/docker/build.sh @@ -381,13 +381,6 @@ case "$image" in HALIDE=yes TRITON=yes ;; - pytorch-linux-jammy-py3.12-triton-cpu) - CUDA_VERSION=12.4 - ANACONDA_PYTHON_VERSION=3.12 - GCC_VERSION=11 - CONDA_CMAKE=yes - TRITON_CPU=yes - ;; pytorch-linux-focal-linter) # TODO: Use 3.9 here because of this issue https://github.com/python/mypy/issues/13627. # We will need to update mypy version eventually, but that's for another day. The task @@ -517,7 +510,6 @@ docker build \ --build-arg "UCC_COMMIT=${UCC_COMMIT}" \ --build-arg "CONDA_CMAKE=${CONDA_CMAKE}" \ --build-arg "TRITON=${TRITON}" \ - --build-arg "TRITON_CPU=${TRITON_CPU}" \ --build-arg "ONNX=${ONNX}" \ --build-arg "DOCS=${DOCS}" \ --build-arg "INDUCTOR_BENCHMARKS=${INDUCTOR_BENCHMARKS}" \ diff --git a/.ci/docker/ci_commit_pins/triton-cpu.txt b/.ci/docker/ci_commit_pins/triton-cpu.txt deleted file mode 100644 index 36e9eea06792a7..00000000000000 --- a/.ci/docker/ci_commit_pins/triton-cpu.txt +++ /dev/null @@ -1 +0,0 @@ -6a333f1b05671f6fada4ba7bbfae4a02a9d96f4f diff --git a/.ci/docker/common/install_triton.sh b/.ci/docker/common/install_triton.sh index cb2d1edc71c995..6e5fb8839c686e 100755 --- a/.ci/docker/common/install_triton.sh +++ b/.ci/docker/common/install_triton.sh @@ -15,11 +15,8 @@ conda_reinstall() { if [ -n "${XPU_VERSION}" ]; then TRITON_REPO="https://github.com/intel/intel-xpu-backend-for-triton" TRITON_TEXT_FILE="triton-xpu" -elif [ -n "${TRITON_CPU}" ]; then - TRITON_REPO="https://github.com/triton-lang/triton-cpu" - TRITON_TEXT_FILE="triton-cpu" else - TRITON_REPO="https://github.com/triton-lang/triton" + TRITON_REPO="https://github.com/openai/triton" TRITON_TEXT_FILE="triton" fi @@ -47,10 +44,9 @@ chown -R jenkins /var/lib/jenkins/triton chgrp -R jenkins /var/lib/jenkins/triton pushd /var/lib/jenkins/ -as_jenkins git clone --recursive ${TRITON_REPO} triton +as_jenkins git clone ${TRITON_REPO} triton cd triton as_jenkins git checkout ${TRITON_PINNED_COMMIT} -as_jenkins git submodule update --init --recursive cd python # TODO: remove patch setup.py once we have a proper fix for https://github.com/triton-lang/triton/issues/4527 diff --git a/.ci/docker/ubuntu/Dockerfile b/.ci/docker/ubuntu/Dockerfile index 1be2fccca41922..80410be3cbc69f 100644 --- a/.ci/docker/ubuntu/Dockerfile +++ b/.ci/docker/ubuntu/Dockerfile @@ -147,13 +147,6 @@ COPY ci_commit_pins/triton.txt triton.txt RUN if [ -n "${TRITON}" ]; then bash ./install_triton.sh; fi RUN rm install_triton.sh common_utils.sh triton.txt -ARG TRITON_CPU -COPY ./common/install_triton.sh install_triton.sh -COPY ./common/common_utils.sh common_utils.sh -COPY ci_commit_pins/triton-cpu.txt triton-cpu.txt -RUN if [ -n "${TRITON_CPU}" ]; then bash ./install_triton.sh; fi -RUN rm install_triton.sh common_utils.sh triton-cpu.txt - ARG EXECUTORCH # Build and install executorch COPY ./common/install_executorch.sh install_executorch.sh diff --git a/.ci/pytorch/test.sh b/.ci/pytorch/test.sh index 594c5e4ef69252..9d2f2246ad7509 100755 --- a/.ci/pytorch/test.sh +++ b/.ci/pytorch/test.sh @@ -607,11 +607,6 @@ test_inductor_halide() { assert_git_not_dirty } -test_inductor_triton_cpu() { - python test/run_test.py --include inductor/test_triton_cpu_backend.py --verbose - assert_git_not_dirty -} - test_dynamo_benchmark() { # Usage: test_dynamo_benchmark huggingface 0 TEST_REPORTS_DIR=$(pwd)/test/test-reports @@ -1438,8 +1433,6 @@ elif [[ "${TEST_CONFIG}" == *inductor_distributed* ]]; then test_inductor_distributed elif [[ "${TEST_CONFIG}" == *inductor-halide* ]]; then test_inductor_halide -elif [[ "${TEST_CONFIG}" == *inductor-triton-cpu* ]]; then - test_inductor_triton_cpu elif [[ "${TEST_CONFIG}" == *inductor-micro-benchmark* ]]; then test_inductor_micro_benchmark elif [[ "${TEST_CONFIG}" == *huggingface* ]]; then diff --git a/.github/workflows/inductor.yml b/.github/workflows/inductor.yml index 0c630878fea538..88ffe090fd8897 100644 --- a/.github/workflows/inductor.yml +++ b/.github/workflows/inductor.yml @@ -118,28 +118,6 @@ jobs: docker-image: ${{ needs.linux-jammy-cpu-py3_12-inductor-halide-build.outputs.docker-image }} test-matrix: ${{ needs.linux-jammy-cpu-py3_12-inductor-halide-build.outputs.test-matrix }} - linux-jammy-cpu-py3_12-inductor-triton-cpu-build: - name: linux-jammy-cpu-py3.12-gcc11-inductor-triton-cpu - uses: ./.github/workflows/_linux-build.yml - needs: get-label-type - with: - build-environment: linux-jammy-py3.12-gcc11 - docker-image-name: pytorch-linux-jammy-py3.12-triton-cpu - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - test-matrix: | - { include: [ - { config: "inductor-triton-cpu", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.12xlarge" }, - ]} - - linux-jammy-cpu-py3_12-inductor-triton-cpu-test: - name: linux-jammy-cpu-py3.12-gcc11-inductor-triton-cpu - uses: ./.github/workflows/_linux-test.yml - needs: linux-jammy-cpu-py3_12-inductor-triton-cpu-build - with: - build-environment: linux-jammy-py3.12-gcc11 - docker-image: ${{ needs.linux-jammy-cpu-py3_12-inductor-triton-cpu-build.outputs.docker-image }} - test-matrix: ${{ needs.linux-jammy-cpu-py3_12-inductor-triton-cpu-build.outputs.test-matrix }} - linux-focal-cuda12_4-py3_10-gcc9-inductor-build: # Should be synced with the one in inductor-periodic.yml but this only runs inductor_timm name: cuda12.4-py3.10-gcc9-sm86 diff --git a/test/inductor/test_torchinductor.py b/test/inductor/test_torchinductor.py index f2c333bac132be..f2240ff64922d4 100644 --- a/test/inductor/test_torchinductor.py +++ b/test/inductor/test_torchinductor.py @@ -738,6 +738,12 @@ def is_cpp_backend(device): return getattr(device, "type", device) == "cpu" and config.cpu_backend == "cpp" +def is_halide_backend(device): + if getattr(device, "type", device) == "cpu": + return config.cpu_backend == "halide" + return config.cuda_backend == "halide" + + def skip_if_halide(fn): @functools.wraps(fn) def wrapper(self): @@ -748,42 +754,6 @@ def wrapper(self): return wrapper -def is_halide_backend(device): - if getattr(device, "type", device) == "cpu": - return config.cpu_backend == "halide" - return config.cuda_backend == "halide" - - -def is_triton_cpu_backend(device): - return getattr(device, "type", device) == "cpu" and config.cpu_backend == "triton" - - -def skip_if_triton_cpu(fn): - import types - - reason = "Triton CPU not supported" - - def decorator(fn): - @functools.wraps(fn) - def wrapper(self): - if is_triton_cpu_backend(self.device): - raise unittest.SkipTest(reason) - return fn(self) - - return wrapper - - if isinstance(fn, types.FunctionType): - return decorator(fn) - else: - reason = fn - return decorator - - -def xfail_if_triton_cpu(fn): - fn._expected_failure_triton_cpu = True - return fn - - def skip_if_gpu_halide(fn): @functools.wraps(fn) def wrapper(self): @@ -823,7 +793,6 @@ def fn(a, b): @skipCUDAIf(not SM80OrLater, "Requires sm80") @skip_if_halide # aoti - @skip_if_triton_cpu # aoti @skipIfWindows(msg="aoti not support on Windows") def test_aoti_eager_dtype_device_layout(self): ns = "aten" @@ -866,7 +835,6 @@ def test_aoti_eager_dtype_device_layout(self): @skipCUDAIf(not SM80OrLater, "Requires sm80") @skip_if_halide # aoti - @skip_if_triton_cpu # aoti @skipIfWindows(msg="aoti not support on Windows") def test_aoti_eager_support_out(self): ns = "aten" @@ -921,7 +889,6 @@ def test_aoti_eager_support_out(self): @skipCUDAIf(not SM80OrLater, "Requires sm80") @skip_if_halide # aoti - @skip_if_triton_cpu # aoti @skipIfWindows(msg="aoti not support on Windows") def test_aoti_eager_support_str(self): ns = "aten" @@ -961,7 +928,6 @@ def test_aoti_eager_support_str(self): @skipCUDAIf(not SM80OrLater, "Requires sm80") @skip_if_halide # aoti - @skip_if_triton_cpu # aoti @skipIfWindows(msg="aoti not support on Windows") def test_aoti_eager_cache_hit(self): ns = "aten" @@ -1005,7 +971,6 @@ def test_aoti_eager_cache_hit(self): @skipCUDAIf(not SM80OrLater, "Requires sm80") @skip_if_halide # aoti - @skip_if_triton_cpu # aoti @skipIfWindows(msg="aoti not support on Windows") def test_aoti_eager_with_persistent_cache(self): def fn(a): @@ -1052,7 +1017,6 @@ def fn(a): @skipCUDAIf(not SM80OrLater, "Requires sm80") @skip_if_halide # aoti - @skip_if_triton_cpu # aoti @skipIfWindows(msg="aoti not support on Windows") def test_aoti_eager_with_scalar(self): namespace_name = "aten" @@ -1125,7 +1089,6 @@ def test_aoti_eager_with_scalar(self): @skipCUDAIf(not SM80OrLater, "Requires sm80") @skip_if_halide # aoti - @skip_if_triton_cpu # aoti @skipIfWindows(msg="aoti not support on Windows") def test_aoti_eager_override_registration(self): namespace_name = "aten" @@ -1284,7 +1247,6 @@ def fn(a): self.common(fn, (torch.randn(17),)) - @xfail_if_triton_cpu def test_angle(self): def fn(a, b, c): return torch.angle(a), torch.angle(b), torch.angle(c) @@ -1474,7 +1436,6 @@ def nested(x, repeats): actual = nested_opt(*example_inputs) self.assertEqual(expect, actual) - @xfail_if_triton_cpu def test_index_propagation_flip(self): def flip(x): i = torch.arange(x.size(0) - 1, -1, -1, device=x.device) @@ -1560,7 +1521,7 @@ def test( fn_opt = torch.compile(fn) if is_halide_backend(self.device): pass # no device asserts in halide - elif self.device == "cpu" and not is_triton_cpu_backend(self.device): + elif self.device == "cpu": _, code = run_and_get_cpp_code(fn_opt, *inps) self.assertTrue((") ? (" in code or "blendv" in code) is has_wrapping) self.assertTrue(("TORCH_CHECK" in code) is has_assert) @@ -1656,7 +1617,6 @@ def constant_propagation_neg(a): vectorize=False, # There's no loop to vectorize! ) - @xfail_if_triton_cpu def test_computed_buffer_inlining(self): def flip(x): idx = torch.arange(x.size(0) - 1, -1, -1, device=x.device) @@ -2264,7 +2224,6 @@ def make_tensor(shape): ) self.assertEqual(cfn(inp), fn(inp)) - @xfail_if_triton_cpu def test_logcumsumexp(self): def fn(x): return x.logcumsumexp(0), x.logcumsumexp(1) @@ -2312,7 +2271,6 @@ def fn(a): self.common(fn, (torch.randint(4, (4,)),)) @skip_if_gpu_halide - @xfail_if_triton_cpu def test_dist(self): def fn(a, b): return ( @@ -2562,7 +2520,6 @@ def fn(a, b): self.common(fn, (torch.randn(8, 8), torch.randn(8, 8))) - @xfail_if_triton_cpu def test_round(self): def fn(a, b): return torch.round(a), torch.round(b + 1), torch.round(a, decimals=2) @@ -2574,7 +2531,6 @@ def fn(a, b): # with *100 we are always getting a number exactly at .5 which we don't do right in half self.common(fn, (torch.randn(8, 8) * 100, torch.randn(8, 8) * 10)) - @xfail_if_triton_cpu def test_round_correctness(self): if self.device == "cuda": raise unittest.SkipTest("need to debug tl.libdevice on A100/V100") @@ -2588,7 +2544,6 @@ def fn(a): check_lowp=False, ) - @xfail_if_triton_cpu def test_builtins_round(self): def fn(x, i): return x[: round(i / 2 + 1)] + round(i / 2) @@ -2600,7 +2555,6 @@ def fn(x, i): for i in range(1, 6): self.assertEqual(cfn(x, i), fn(x, i)) - @xfail_if_triton_cpu def test_builtins_round_float_ndigits_pos(self): def fn(x, i): return x + round(i / 2 * 123.4567, 1) @@ -2613,7 +2567,6 @@ def fn(x, i): with torch.no_grad(): self.assertEqual(cfn(x, i), fn(x, i)) - @xfail_if_triton_cpu def test_builtins_round_float_ndigits_zero(self): def fn(x, i): return x + round(i / 2 * 123.4567, 0) @@ -2626,7 +2579,6 @@ def fn(x, i): with torch.no_grad(): self.assertEqual(cfn(x, i), fn(x, i)) - @xfail_if_triton_cpu def test_builtins_round_float_ndigits_neg(self): def fn(x, i): return x + round(i / 2 * 123.4567, -1) @@ -2778,7 +2730,6 @@ def fn(a, b): (torch.ones([8, 8], dtype=torch.bool), torch.randint(-100, -1, [8, 8])), ) - @skip_if_triton_cpu # divide by zero; cannot xfail because it crashes process def test_div7(self): def fn(a, b): return ( @@ -2817,7 +2768,6 @@ def fn(x): self.common(fn, (torch.randn(8),)) - @skip_if_triton_cpu # divide by zero; cannot xfail because it crashes process def test_div_zero_dim(self): def fn(a, b): return ( @@ -2844,7 +2794,6 @@ def fn(a, b): ), ) - @skip_if_triton_cpu # divide by zero; cannot xfail because it crashes process def test_div_prim(self): def fn(a, b): return (torch.ops.prims.div(a, b),) @@ -3998,7 +3947,6 @@ def fn(a): ) @requires_gpu() - @xfail_if_triton_cpu def test_multi_device(self): def fn(x): x = x + 1 @@ -4396,7 +4344,6 @@ def run_weights_sharing_model(m, inp): thread.join() @unittest.skipIf(config.is_fbcode(), "fbcode triton error, needs debugging") - @skip_if_triton_cpu("Flaky on Triton CPU") @skip_if_gpu_halide # https://github.com/halide/Halide/issues/8311 def test_adaptive_avg_pool2d_low_prec(self): class Model(torch.nn.Module): @@ -4826,7 +4773,6 @@ def fn(x): ) @skip_if_halide # lgamma not implemented - @xfail_if_triton_cpu def test_lgamma(self): def fn(x): return aten.lgamma(x) + 2, aten.cos(x + 1) @@ -5366,7 +5312,6 @@ def fn(x): (torch.randn([16, 16]),), ) - @xfail_if_triton_cpu def test_pow2(self): def fn(x): return aten.pow(1000, x), aten.pow(x, 1000) @@ -5416,7 +5361,6 @@ def fn(x, y): ), ) - @xfail_if_triton_cpu def test_pow_symfloat(self): def fn(x): r = math.sqrt(x.size(0)) @@ -5925,7 +5869,6 @@ def fn(x): self.common(fn, (torch.randn([1, 2, 6, 6]),)) - @xfail_if_triton_cpu def test_fmod(self): def fn(a, b): return torch.fmod(a, b), torch.fmod(3.0 * a, b) - 2.0 @@ -5933,7 +5876,6 @@ def fn(a, b): shape = [1, 2, 6, 6] self.common(fn, (torch.randn(shape), torch.randn(shape))) - @xfail_if_triton_cpu def test_fmod_zero_dim(self): def fn(a, b): return (torch.fmod(a, b),) @@ -7354,7 +7296,6 @@ def fn(x): actual_out = compiled_fn(view) self.assertEqual(reference_out.stride(), actual_out.stride()) - @xfail_if_triton_cpu def test_like_channels_last(self): def foo(): randn = torch.randn((4, 3, 8, 8), device=self.device, dtype=torch.float32) @@ -7464,7 +7405,6 @@ def fn(a, b): self.common(fn, [a, b]) @with_tf32_off - @xfail_if_triton_cpu def test_slice_scatter_reinplace(self): class M(nn.Module): def __init__(self, device): @@ -8242,7 +8182,6 @@ def fn(x): self.assertFalse(torch.allclose(a0, a1)) @requires_gpu() - @skip_if_triton_cpu("Flaky on Triton CPU") def test_like_rands3(self): # rand_like with `device` which is different from `x.device` def test_like_rands_on_different_device(device1, device2): @@ -8604,7 +8543,6 @@ def fn(a, b): ) assertGeneratedKernelCountEqual(self, 0) - @xfail_if_triton_cpu @config.patch(search_autotune_cache=False) def test_mm_views(self): def fn(a, b): @@ -8692,7 +8630,6 @@ def check(r, g): self.assertTrue(same(r2, r3)) self.assertTrue(same(g2, g3)) - @xfail_if_triton_cpu @config.patch(search_autotune_cache=False) def test_dropout3(self): m = torch.nn.Sequential( @@ -9480,7 +9417,6 @@ def fn(a, b): ], ) - @xfail_if_triton_cpu def test_index_dynamic_shapes(self): # Repro from vision_maskrcnn def fn(arg0_1): @@ -9653,7 +9589,6 @@ def forward(self, arg0_1, arg1_1): self.assertEqual(inductor_out, eager_out) @skipIfRocm - @xfail_if_triton_cpu def test_require_stride_expanded(self): def forward(arg6, arg7, arg16): convolution = torch.ops.aten.convolution( @@ -9923,7 +9858,6 @@ def fn(a, b): self.common(fn, (torch.rand(1), torch.rand(2))) - @xfail_if_triton_cpu def test_view_on_aliased(self): # https://github.com/pytorch/pytorch/issues/96728 def fn1(a, b): @@ -9995,7 +9929,6 @@ def fn(x): self.common(fn, (torch.ones(1, 1, 13, dtype=dtype),)) @unittest.skipIf(not HAS_CPU, "requires C++ compiler") - @xfail_if_triton_cpu # bf16 @skip_if_halide # bf16 def test_data_type_propogation(self): from torch._dynamo.utils import detect_fake_mode @@ -10296,7 +10229,6 @@ def fn(x, y): opt_fn = torch._dynamo.optimize("inductor")(fn) same(fn(x, y), opt_fn(x_clone, y)) - @xfail_if_triton_cpu def test_erfc(self): def fn(x): return torch.erfc(x) @@ -10304,7 +10236,6 @@ def fn(x): self.common(fn, (torch.randn(8, 8),)) @skip_if_halide # erfinv not implemented - @xfail_if_triton_cpu def test_erfinv(self): def fn(x): return torch.erfinv(x) @@ -10688,7 +10619,6 @@ def fn(x): self.common(fn, (inp,), check_lowp=False) @requires_gpu() - @xfail_if_triton_cpu @config.patch(implicit_fallbacks=True) def test_mutable_custom_op_fixed_layout2(self): with torch.library._scoped_library("mylib", "DEF") as lib: @@ -11010,7 +10940,6 @@ def fn(x): @skipCUDAIf(not SM80OrLater, "uses bfloat16 which requires SM >= 80") # We only support dtypeview for abi_conpatible aoti - @skip_if_triton_cpu("Compile time crash in Triton CPU CI") @torch._inductor.config.patch(abi_compatible=True) def test_dtypeview(self): if TEST_WITH_ASAN: @@ -11087,7 +11016,6 @@ def fn(x, x2): _, code = run_and_get_code(fn, x, x2) FileCheck().check("aten.view.dtype(reinterpret_tensor").run(code[0]) - @xfail_if_triton_cpu @requires_gpu() def test_scalar_cpu_tensor_arg(self): def fn(x, y): @@ -11122,7 +11050,6 @@ def fn(x): @skipCUDAIf(not SM80OrLater, "uses bfloat16 which requires SM >= 80") @skip_if_gpu_halide # https://github.com/halide/Halide/issues/8311 - @xfail_if_triton_cpu def test_bfloat16_to_int16(self): def fn(a, b): x = a + b @@ -11225,60 +11152,47 @@ def test_pointwise(self, name, op): # _cuda not implemented for Half check_lowp = False - if ( - is_halide_backend(self.device) - or is_triton_cpu_backend(self.device) - and name - in ( - "erfinv", - "airy_ai", - "bessel_j0", - "bessel_j1", - "bessel_y0", - "bessel_y1", - "chebyshev_polynomial_t", - "chebyshev_polynomial_u", - "chebyshev_polynomial_v", - "chebyshev_polynomial_w", - "digamma", - "gammainc", - "gammaincc", - "gammaln", - "hermite_polynomial_h", - "hermite_polynomial_he", - "i0", - "i0e", - "i1", - "i1e", - "laguerre_polynomial_l", - "legendre_polynomial_p", - "modified_bessel_i0", - "modified_bessel_i1", - "modified_bessel_k0", - "modified_bessel_k1", - "multigammaln", - "ndtri", - "polygamma", - "psi", - "scaled_modified_bessel_k0", - "scaled_modified_bessel_k1", - "shifted_chebyshev_polynomial_t", - "shifted_chebyshev_polynomial_u", - "shifted_chebyshev_polynomial_v", - "shifted_chebyshev_polynomial_w", - "spherical_bessel_j0", - "zeta", - ) + if is_halide_backend(self.device) and name in ( + "erfinv", + "airy_ai", + "bessel_j0", + "bessel_j1", + "bessel_y0", + "bessel_y1", + "chebyshev_polynomial_t", + "chebyshev_polynomial_u", + "chebyshev_polynomial_v", + "chebyshev_polynomial_w", + "digamma", + "gammainc", + "gammaincc", + "gammaln", + "hermite_polynomial_h", + "hermite_polynomial_he", + "i0", + "i0e", + "i1", + "i1e", + "laguerre_polynomial_l", + "legendre_polynomial_p", + "modified_bessel_i0", + "modified_bessel_i1", + "modified_bessel_k0", + "modified_bessel_k1", + "multigammaln", + "ndtri", + "polygamma", + "psi", + "scaled_modified_bessel_k0", + "scaled_modified_bessel_k1", + "shifted_chebyshev_polynomial_t", + "shifted_chebyshev_polynomial_u", + "shifted_chebyshev_polynomial_v", + "shifted_chebyshev_polynomial_w", + "spherical_bessel_j0", + "zeta", ): - raise unittest.SkipTest(f"Halide & Triton CPU do not support {name}") - - if is_triton_cpu_backend(self.device) and name in [ - "erfc", - "erfcx", - "round", - "log_ndtr", - ]: - raise unittest.SkipTest(f"Triton CPU does not support {name}") + raise unittest.SkipTest(f"halide does not support {name}") if name in {"gammainc", "gammaincc"}: args = ( @@ -11400,7 +11314,6 @@ def test_generate_rand_fp8(self): t = rand_strided((2, 3), (3, 1), device=self.device, dtype=torch.float8_e4m3fn) self.assertTrue(t.dtype is torch.float8_e4m3fn) - @skip_if_triton_cpu("Triton CPU: Cannot xfail because it crashes process") def test_large_grid(self): # https://github.com/pytorch/pytorch/issues/123210 def fn(primals_5): @@ -11455,7 +11368,6 @@ def forward(): self.common(forward, ()) - @xfail_if_triton_cpu def test_flip_cat(self): def forward(unsqueeze, unsqueeze_1): cat_1 = torch.ops.aten.cat.default([unsqueeze, unsqueeze_1], 1) diff --git a/test/inductor/test_triton_cpu_backend.py b/test/inductor/test_triton_cpu_backend.py index cb738163b3d120..8be4d85caea6a6 100644 --- a/test/inductor/test_triton_cpu_backend.py +++ b/test/inductor/test_triton_cpu_backend.py @@ -25,16 +25,8 @@ class SweepInputsCpuTritonTest(test_torchinductor.SweepInputsCpuTest): pass @config.patch(cpu_backend="triton") - class CpuTritonTests(test_torchinductor.TestCase): - common = test_torchinductor.check_model - device = "cpu" - - test_torchinductor.copy_tests( - test_torchinductor.CommonTemplate, - CpuTritonTests, - "cpu", - xfail_prop="_expected_failure_triton_cpu", - ) + class CpuTritonTests(test_torchinductor.CpuTests): + pass if __name__ == "__main__": diff --git a/torch/testing/_internal/inductor_utils.py b/torch/testing/_internal/inductor_utils.py index d8e86651fa9a56..00f9ba0dd2ee29 100644 --- a/torch/testing/_internal/inductor_utils.py +++ b/torch/testing/_internal/inductor_utils.py @@ -74,7 +74,6 @@ def _check_has_dynamic_shape( def skipDeviceIf(cond, msg, *, device): if cond: def decorate_fn(fn): - @functools.wraps(fn) def inner(self, *args, **kwargs): if not hasattr(self, "device"): warn_msg = "Expect the test class to have attribute device but not found. " From 07f1b262381d0da6d57c915be9a9d486f97894ab Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Mon, 16 Sep 2024 18:33:33 +0000 Subject: [PATCH 0940/1018] Revert "Add Triton CPU as an Inductor backend (#133408)" This reverts commit e498b02b472e45cfd6b7a08db0d6c1babec655c5. Reverted https://github.com/pytorch/pytorch/pull/133408 on behalf of https://github.com/jeanschmidt due to Broke internal signals, see D62737208 for more details ([comment](https://github.com/pytorch/pytorch/pull/133408#issuecomment-2353623816)) --- .ci/docker/build.sh | 1 - .../fsdp/test_fully_shard_compile.py | 22 +-- .../fully_shard/test_fully_shard_compile.py | 4 +- .../test_replicate_with_compiler.py | 14 +- .../_tensor/test_dtensor_compile.py | 10 +- .../fsdp/test_fsdp_use_orig_params.py | 4 +- .../tensor/parallel/test_micro_pipeline_tp.py | 18 +- .../test_c10d_functional_native.py | 28 +-- .../test_compute_comm_reordering.py | 14 +- test/distributed/test_dynamo_distributed.py | 50 +++--- test/distributed/test_functional_api.py | 8 +- test/distributed/test_inductor_collectives.py | 38 ++-- test/inductor/test_cuda_repro.py | 4 +- test/inductor/test_memory_planning.py | 3 +- test/inductor/test_triton_cpu_backend.py | 34 ---- test/test_sparse_semi_structured.py | 8 +- torch/_dynamo/device_interface.py | 49 ----- torch/_inductor/autotune_process.py | 72 ++++---- torch/_inductor/codegen/common.py | 7 +- torch/_inductor/codegen/cpp.py | 2 +- .../_inductor/codegen/cpp_template_kernel.py | 4 +- .../codegen/cpu_device_op_overrides.py | 26 --- torch/_inductor/codegen/halide.py | 1 - .../codegen/rocm/rocm_benchmark_request.py | 8 +- torch/_inductor/codegen/triton.py | 7 +- torch/_inductor/codegen/wrapper.py | 168 +++++++++--------- torch/_inductor/config.py | 2 +- torch/_inductor/runtime/benchmarking.py | 2 +- torch/_inductor/runtime/triton_helpers.py | 15 -- torch/_inductor/runtime/triton_heuristics.py | 3 - torch/_inductor/scheduler.py | 3 +- torch/_inductor/select_algorithm.py | 17 +- torch/_inductor/utils.py | 8 +- torch/utils/_triton.py | 16 +- 34 files changed, 248 insertions(+), 422 deletions(-) delete mode 100644 test/inductor/test_triton_cpu_backend.py delete mode 100644 torch/_inductor/codegen/cpu_device_op_overrides.py diff --git a/.ci/docker/build.sh b/.ci/docker/build.sh index 8aafbb89d811ec..61f7acfe7633fe 100755 --- a/.ci/docker/build.sh +++ b/.ci/docker/build.sh @@ -379,7 +379,6 @@ case "$image" in GCC_VERSION=11 CONDA_CMAKE=yes HALIDE=yes - TRITON=yes ;; pytorch-linux-focal-linter) # TODO: Use 3.9 here because of this issue https://github.com/python/mypy/issues/13627. diff --git a/test/distributed/_composable/fsdp/test_fully_shard_compile.py b/test/distributed/_composable/fsdp/test_fully_shard_compile.py index 1bf26d1b67b3e6..4a4aac1f63fde7 100644 --- a/test/distributed/_composable/fsdp/test_fully_shard_compile.py +++ b/test/distributed/_composable/fsdp/test_fully_shard_compile.py @@ -30,7 +30,7 @@ ModelArgs, Transformer, ) -from torch.testing._internal.inductor_utils import HAS_GPU +from torch.utils._triton import has_triton log = logging.getLogger(__name__) @@ -45,7 +45,7 @@ def _is_fallback_op_in_snodes(snodes, op): class TestFullyShardCompileCompute(FSDPTest): - @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") + @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") @skip_if_lt_x_gpu(2) def test_disable_compiling_hooks(self): self.run_subtests( @@ -518,14 +518,14 @@ def input_creation_fn(): return model_init_fn, input_creation_fn @skipIfRocm - @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") + @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") def test_simple_mlp_fullgraph_backend_aot_eager(self): self._test_traceable_fsdp( *self._create_simple_mlp_factory_fns(), "aot_eager", fullgraph=True ) @skipIfRocm - @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") + @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") def test_simple_mlp_fullgraph_backend_aot_eager_decomp_partition(self): self._test_traceable_fsdp( *self._create_simple_mlp_factory_fns(), @@ -534,7 +534,7 @@ def test_simple_mlp_fullgraph_backend_aot_eager_decomp_partition(self): ) @skipIfRocm - @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") + @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") def test_simple_mlp_fullgraph_backend_inductor(self): self._test_traceable_fsdp( *self._create_simple_mlp_factory_fns(), "inductor", fullgraph=True @@ -602,7 +602,7 @@ def input_creation_fn(): return model_init_fn, input_creation_fn @skipIfRocm - @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") + @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") def test_nested_fully_shard_backend_aot_eager(self): for fullgraph in [True, False]: self._test_traceable_fsdp( @@ -612,7 +612,7 @@ def test_nested_fully_shard_backend_aot_eager(self): ) @skipIfRocm - @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") + @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") def test_nested_fully_shard_backend_aot_eager_decomp_partition(self): for fullgraph in [True, False]: self._test_traceable_fsdp( @@ -622,7 +622,7 @@ def test_nested_fully_shard_backend_aot_eager_decomp_partition(self): ) @skipIfRocm - @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") + @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") def test_nested_fully_shard_backend_inductor(self): for fullgraph in [True, False]: with self._reinplace_all_gather_with_optional_checks( @@ -785,7 +785,7 @@ def _sdpa_with_graph_break(orig_fn, fullgraph, *args, **kwargs): ) @skipIfRocm - @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") + @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") def test_transformer_backend_aot_eager(self): for fullgraph, all_requires_grad in itertools.product( [True, False], [True, False] @@ -802,7 +802,7 @@ def test_transformer_backend_aot_eager(self): ) @skipIfRocm - @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") + @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") # TODO: native_dropout has worse accuracy after decomp, need to figure out why @torch._inductor.config.patch(fallback_random=True) def test_transformer_backend_aot_eager_decomp_partition(self): @@ -819,7 +819,7 @@ def test_transformer_backend_aot_eager_decomp_partition(self): ) @skipIfRocm - @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") + @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") # TODO: native_dropout causes CUDA IMA error, need to figure out why @torch._inductor.config.patch(fallback_random=True) def test_transformer_backend_inductor(self): diff --git a/test/distributed/_composable/fully_shard/test_fully_shard_compile.py b/test/distributed/_composable/fully_shard/test_fully_shard_compile.py index 1ad5b3be8462c2..05af94ff11e06f 100644 --- a/test/distributed/_composable/fully_shard/test_fully_shard_compile.py +++ b/test/distributed/_composable/fully_shard/test_fully_shard_compile.py @@ -18,7 +18,7 @@ TransformerWithSharedParams, ) from torch.testing._internal.common_utils import run_tests, TEST_WITH_DEV_DBG_ASAN -from torch.testing._internal.inductor_utils import HAS_GPU +from torch.utils._triton import has_triton if not dist.is_available(): @@ -38,7 +38,7 @@ class TestCompile(FSDPTest): def world_size(self) -> int: return torch.cuda.device_count() - @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") + @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") @skip_if_lt_x_gpu(2) def test_compile(self): self.run_subtests( diff --git a/test/distributed/_composable/test_replicate_with_compiler.py b/test/distributed/_composable/test_replicate_with_compiler.py index d3a519e8e6f93f..14d9cc47271f59 100644 --- a/test/distributed/_composable/test_replicate_with_compiler.py +++ b/test/distributed/_composable/test_replicate_with_compiler.py @@ -33,7 +33,7 @@ ) from torch.testing._internal.common_utils import run_tests from torch.testing._internal.distributed.fake_pg import FakeStore -from torch.testing._internal.inductor_utils import HAS_GPU +from torch.utils._triton import has_triton from torch.utils.checkpoint import checkpoint @@ -216,21 +216,21 @@ def test_compile_cpu_no_sync(self): ] self._test_compile(use_gpu=False, no_sync=True) - @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") + @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") @skip_if_rocm @skip_if_lt_x_gpu(2) @torch._inductor.config.patch(reorder_for_locality=False) def test_compile_gpu(self): self._test_compile(use_gpu=True, no_sync=False, checkpoint=False) - @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") + @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") @skip_if_rocm @skip_if_lt_x_gpu(2) @torch._inductor.config.patch(reorder_for_locality=False) def test_compile_gpu_ac(self): self._test_compile(use_gpu=True, no_sync=False, checkpoint=True) - @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") + @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") @skip_if_rocm @skip_if_lt_x_gpu(2) def test_compile_bf16(self): @@ -244,7 +244,7 @@ def setup(model, compiled_replicate_model, compiled_ddp_model) -> None: self._test_compile(use_gpu=True, no_sync=False, setup_func=setup) - @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") + @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") @skip_if_rocm @skip_if_lt_x_gpu(2) def test_compile_fp16(self): @@ -261,7 +261,7 @@ def setup(model, compiled_replicate_model, compiled_ddp_model) -> None: use_gpu=True, no_sync=False, setup_func=setup, no_inductor=True ) - @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") + @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") @skip_if_rocm @skip_if_lt_x_gpu(2) def test_compile_backward_only(self): @@ -385,7 +385,7 @@ def setUp(self): def tearDown(self): dist.destroy_process_group() - @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") + @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") @skip_if_rocm def test_ddp_tp(self): ref_model = Net() diff --git a/test/distributed/_tensor/test_dtensor_compile.py b/test/distributed/_tensor/test_dtensor_compile.py index 60a12d1935f103..eaef0716ba30dc 100644 --- a/test/distributed/_tensor/test_dtensor_compile.py +++ b/test/distributed/_tensor/test_dtensor_compile.py @@ -46,7 +46,7 @@ with_comms, ) from torch.testing._internal.distributed.fake_pg import FakeStore -from torch.testing._internal.inductor_utils import HAS_GPU +from torch.utils._triton import has_triton from torch.utils.checkpoint import checkpoint @@ -420,7 +420,7 @@ def fn(x): tmp_dt._local_tensor.stride(), tmp_dt_fake._local_tensor.stride() ) - @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") + @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") def test_dtensor_contiguous_dtensor_noncontiguous_local_as_tangent(self): # Partial -> Shard on an unbalanced tensor results in: # - A contiguous DTensor @@ -496,7 +496,7 @@ def fw_hook(module, inp, out): out_test = opt_mod(dt) self.assertEqual(out_ref, out_test) - @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") + @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") def test_dtensor_different_gradient_placement(self): mesh = DeviceMesh(self.device_type, torch.arange(self.world_size)) @@ -628,7 +628,7 @@ def forward(self, primals_1): return (sin_1, primals_1, wait_tensor)""", ) - @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") + @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") def test_dtensor_partial_placement_graph_output(self): mesh = DeviceMesh(self.device_type, torch.arange(self.world_size)) @@ -646,7 +646,7 @@ def fn(x): out_dt = torch.matmul(tmp_dt, y_dt) out_dt.sum().backward() - @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") + @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") @skip_if_lt_x_gpu(1) # TODO: somehow inductor bg compile threads are causing hangs at exit with distributed work dtor @patch.object(torch._inductor.config, "compile_threads", 1) diff --git a/test/distributed/fsdp/test_fsdp_use_orig_params.py b/test/distributed/fsdp/test_fsdp_use_orig_params.py index e19a522e157002..33fae8dea8a561 100644 --- a/test/distributed/fsdp/test_fsdp_use_orig_params.py +++ b/test/distributed/fsdp/test_fsdp_use_orig_params.py @@ -43,7 +43,7 @@ TEST_WITH_DEV_DBG_ASAN, TestCase, ) -from torch.testing._internal.inductor_utils import HAS_GPU +from torch.utils._triton import has_triton if not dist.is_available(): @@ -218,7 +218,7 @@ def _get_sharding_strategy_from_str( raise ValueError(f"Invalid string: {sharding_strategy_str}") return sharding_strategy - @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") + @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") @skip_if_lt_x_gpu(2) def test_fsdp_compile(self): self.run_subtests( diff --git a/test/distributed/tensor/parallel/test_micro_pipeline_tp.py b/test/distributed/tensor/parallel/test_micro_pipeline_tp.py index 3a4b2eb946fc51..951e77188364c3 100644 --- a/test/distributed/tensor/parallel/test_micro_pipeline_tp.py +++ b/test/distributed/tensor/parallel/test_micro_pipeline_tp.py @@ -37,7 +37,7 @@ ) from torch.testing._internal.distributed._tensor.common_dtensor import MLPModule from torch.testing._internal.distributed.fake_pg import FakeStore -from torch.testing._internal.inductor_utils import HAS_GPU +from torch.utils._triton import has_triton def _make_post_grad_fx(f, *inps): @@ -78,7 +78,7 @@ def setUp(self): def tearDown(self): dist.destroy_process_group() - @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") + @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") @fresh_inductor_cache() def test_find_all_gather_patterns(self): group = dist.group.WORLD @@ -129,7 +129,7 @@ def func(inp: torch.Tensor) -> torch.Tensor: torch.ops.aten.view.dtype, ) - @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") + @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") @fresh_inductor_cache() def test_find_reduce_scatter_patterns(self): group = dist.group.WORLD @@ -168,7 +168,7 @@ def func(inp: torch.Tensor) -> torch.Tensor: self.assertEqual(reduce_scatters[1].reduce_op, "avg") self.assertEqual(reduce_scatters[1].scatter_dim, 1) - @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") + @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") @fresh_inductor_cache() def test_get_unexposed_collectives(self): group = dist.group.WORLD @@ -193,7 +193,7 @@ def func(inp: torch.Tensor) -> torch.Tensor: ["all_gather_into_tensor", "reduce_scatter_tensor"], ) - @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") + @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") @parametrize("A_dims", [2, 3]) @parametrize("gather_dim", [0, 1, 2]) @fresh_inductor_cache() @@ -231,7 +231,7 @@ def func(A_shard: torch.Tensor, B: torch.Tensor) -> torch.Tensor: self.assertNotIn("all_gather_into_tensor", code) @runOnRocmArch(MI300_ARCH) - @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") + @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") @parametrize("A_dims", [2, 3]) @parametrize("gather_dim", [0, 1, 2]) @fresh_inductor_cache() @@ -299,7 +299,7 @@ def func( self.assertIn("fused_all_gather_scaled_matmul", code) self.assertNotIn("all_gather_into_tensor", code) - @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") + @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") @parametrize("A_dims", [2, 3]) @parametrize("scatter_dim", [0, 1, 2]) @fresh_inductor_cache() @@ -328,7 +328,7 @@ def func(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor: self.assertNotIn("reduce_scatter_tensor", code) @runOnRocmArch(MI300_ARCH) - @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") + @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") @parametrize("A_dims", [2, 3]) @parametrize("scatter_dim", [0, 1, 2]) @fresh_inductor_cache() @@ -381,7 +381,7 @@ def func( self.assertIn("fused_scaled_matmul_reduce_scatter", code) self.assertNotIn("reduce_scatter_tensor", code) - @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") + @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") @parametrize("shard_dim", [0, 1]) @fresh_inductor_cache() def test_dtensor_seq_par(self, shard_dim: int): diff --git a/test/distributed/test_c10d_functional_native.py b/test/distributed/test_c10d_functional_native.py index 0bea36967b880e..619c4b7887f7f1 100644 --- a/test/distributed/test_c10d_functional_native.py +++ b/test/distributed/test_c10d_functional_native.py @@ -28,7 +28,7 @@ TestCase, ) from torch.testing._internal.distributed.fake_pg import FakeStore -from torch.testing._internal.inductor_utils import HAS_GPU +from torch.utils._triton import has_triton def load_test_module(name): @@ -218,7 +218,7 @@ def test_all_gather_into_tensor_single(self) -> None: assert output.eq(expect).all() assert output.completed - @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") + @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") @skip_if_lt_x_gpu(2) # https://github.com/pytorch/pytorch/issues/126338 def test_inductor_dtypeview_memory_leak(self): @@ -434,7 +434,7 @@ def wait(self, _): torch.ops._c10d_functional.wait_tensor(tensor) self.assertTrue(wait_called) - @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") + @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") @skip_if_lt_x_gpu(2) @fresh_inductor_cache() def test_threading(self): @@ -498,7 +498,7 @@ def setUp(self): def tearDown(self): dist.destroy_process_group() - @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") + @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") @fresh_inductor_cache() def test_inductor_all_reduce_single(self): def func(arg: torch.Tensor) -> torch.Tensor: @@ -535,7 +535,7 @@ def func(arg: torch.Tensor) -> torch.Tensor: out = AOTIRunnerUtil.run("cuda", func, (arg,)) torch.cuda.synchronize() - @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") + @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") @fresh_inductor_cache() def test_inductor_all_reduce_coalesced(self): def func(args: List[torch.Tensor]) -> torch.Tensor: @@ -581,7 +581,7 @@ def func(args: List[torch.Tensor]) -> torch.Tensor: out = AOTIRunnerUtil.run("cuda", func, (args,)) torch.cuda.synchronize() - @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") + @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") @fresh_inductor_cache() def test_inductor_inplace_op_on_view(self): def func(arg: torch.Tensor) -> torch.Tensor: @@ -608,7 +608,7 @@ def func(arg: torch.Tensor) -> torch.Tensor: .run(code) ) - @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") + @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") @fresh_inductor_cache() def test_inductor_reuse_buffer_after_inplace_collective(self): def func(arg: torch.Tensor) -> torch.Tensor: @@ -643,7 +643,7 @@ def func(arg: torch.Tensor) -> torch.Tensor: ) assert "= torch.ops._c10d_functional.wait_tensor.default" not in code - @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") + @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") @fresh_inductor_cache() def test_inductor_all_gather_into_tensor_single(self): def func(arg: torch.Tensor) -> torch.Tensor: @@ -670,7 +670,7 @@ def func(arg: torch.Tensor) -> torch.Tensor: out = AOTIRunnerUtil.run("cuda", func, (arg,)) torch.cuda.synchronize() - @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") + @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") @fresh_inductor_cache() def test_inductor_all_gather_into_tensor_coalesced(self): def func(args: List[torch.Tensor]) -> torch.Tensor: @@ -704,7 +704,7 @@ def func(args: List[torch.Tensor]) -> torch.Tensor: out = AOTIRunnerUtil.run("cuda", func, (args,)) torch.cuda.synchronize() - @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") + @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") @fresh_inductor_cache() def test_inductor_reduce_scatter_tensor_single(self): def func(arg: torch.Tensor) -> torch.Tensor: @@ -730,7 +730,7 @@ def func(arg: torch.Tensor) -> torch.Tensor: out = AOTIRunnerUtil.run("cuda", func, (arg,)) torch.cuda.synchronize() - @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") + @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") @fresh_inductor_cache() def test_inductor_reduce_scatter_tensor_coalesced(self): def func(args: List[torch.Tensor]) -> torch.Tensor: @@ -766,7 +766,7 @@ def func(args: List[torch.Tensor]) -> torch.Tensor: AOTIRunnerUtil.run("cuda", func, (args,)) torch.cuda.synchronize() - @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") + @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") @fresh_inductor_cache() def test_inductor_all_to_all_single(self): def _tolist_with_constrain_as_size(tensor): @@ -814,7 +814,7 @@ def func( .run(code) ) - @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") + @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") @fresh_inductor_cache() def test_inductor_broadcast(self): def func(arg: torch.Tensor) -> torch.Tensor: @@ -850,7 +850,7 @@ def func(arg: torch.Tensor) -> torch.Tensor: out = AOTIRunnerUtil.run("cuda", func, (arg,)) torch.cuda.synchronize() - @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") + @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") @fresh_inductor_cache() def test_ranks_and_tag(self): def func(arg: torch.Tensor) -> torch.Tensor: diff --git a/test/distributed/test_compute_comm_reordering.py b/test/distributed/test_compute_comm_reordering.py index bb0e54484b94e5..db0d6c8becbff0 100644 --- a/test/distributed/test_compute_comm_reordering.py +++ b/test/distributed/test_compute_comm_reordering.py @@ -28,7 +28,7 @@ DynamoDistributedMultiProcTestCase, requires_nccl, ) -from torch.testing._internal.inductor_utils import HAS_GPU +from torch.utils._triton import has_triton def get_snode_runtime_for_reorder_compute_test(snode): @@ -92,7 +92,7 @@ def world_size(self) -> int: # works around issue with skipif<2 and workers with unpredictable #s gpu return 2 - @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") + @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") @patch.object(torch._inductor.config, "allow_buffer_reuse", True) # TODO: somehow inductor bg compile threads are causing hangs at exit with distributed work dtor @patch.object(torch._inductor.config, "compile_threads", 1) @@ -131,7 +131,7 @@ def func(a): correct = func(inputs) self.assertTrue(same(out, correct)) - @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") + @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") @patch.object(torch._inductor.config, "allow_buffer_reuse", True) # TODO: somehow inductor bg compile threads are causing hangs at exit with distributed work dtor @patch.object(torch._inductor.config, "compile_threads", 1) @@ -178,7 +178,7 @@ def func(a): correct = func(inputs) self.assertTrue(same(out, correct)) - @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") + @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") @patch.object(torch._inductor.config, "allow_buffer_reuse", True) # TODO: somehow inductor bg compile threads are causing hangs at exit with distributed work dtor @patch.object(torch._inductor.config, "compile_threads", 1) @@ -231,7 +231,7 @@ def func(a, *, tag, ranks, group_size): correct = func(inputs, **self.get_world_trs()) self.assertTrue(same(out, correct)) - @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") + @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") @patch.object(torch._inductor.config, "allow_buffer_reuse", True) # TODO: somehow inductor bg compile threads are causing hangs at exit with distributed work dtor @patch.object(torch._inductor.config, "compile_threads", 1) @@ -283,7 +283,7 @@ def func(a, *, tag, ranks, group_size): correct = func(inputs, **self.get_world_trs()) self.assertTrue(same(out, correct)) - @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") + @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") @patch.object(torch._inductor.config, "allow_buffer_reuse", True) # TODO: somehow inductor bg compile threads are causing hangs at exit with distributed work dtor @patch.object(torch._inductor.config, "compile_threads", 1) @@ -340,7 +340,7 @@ def func(a, *, tag, ranks, group_size): correct = func(inputs, **self.get_world_trs()) self.assertTrue(same(out, correct)) - @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") + @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") # TODO: somehow inductor bg compile threads are causing hangs at exit with distributed work dtor @patch.object(torch._inductor.config, "compile_threads", 1) @patch.object( diff --git a/test/distributed/test_dynamo_distributed.py b/test/distributed/test_dynamo_distributed.py index 1cf98a4d8b4047..d6357678c94f97 100644 --- a/test/distributed/test_dynamo_distributed.py +++ b/test/distributed/test_dynamo_distributed.py @@ -46,7 +46,7 @@ skip_if_lt_x_gpu, ) from torch.testing._internal.common_utils import requires_cuda -from torch.testing._internal.inductor_utils import HAS_GPU +from torch.utils._triton import has_triton def reset_rng_state(): @@ -325,7 +325,7 @@ def run_hf_bert_ddp(self, model, inputs, backend): class TestFakeDistributedSingleProc(torch._dynamo.test_case.TestCase): - @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") + @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") @patch.object(config, "optimize_ddp", True) @patch.object(torch._inductor.config, "fallback_random", True) def test_hf_bert_ddp_inductor(self): @@ -528,7 +528,7 @@ def _test_hf_bert_ddp_inductor(self, static_graph): @skip_if_lt_x_gpu(2) @import_transformers_or_skip() - @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") + @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") @config.patch(optimize_ddp=True, enable_compiler_collectives=True) @patch.object(torch._inductor.config, "fallback_random", True) def test_hf_bert_ddp_inductor(self): @@ -536,7 +536,7 @@ def test_hf_bert_ddp_inductor(self): @skip_if_lt_x_gpu(2) @import_transformers_or_skip() - @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") + @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") @config.patch(optimize_ddp=True, enable_compiler_collectives=True) @patch.object(torch._inductor.config, "fallback_random", True) def test_hf_bert_ddp_inductor_static_graph(self): @@ -561,7 +561,7 @@ def test_hf_bert_ddp_aot_eager_static_graph(self): self._test_hf_bert_aot_eager(static_graph=True) @skip_if_lt_x_gpu(2) - @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") + @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") @config.patch(optimize_ddp=False, enable_compiler_collectives=True) def test_ddp_activation_checkpointing(self): from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import ( @@ -676,7 +676,7 @@ def test_fsdp_unspecialized_forced_getattr_inline(self): @config.patch(enable_compiler_collectives=True) @skip_if_lt_x_gpu(1) - @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") + @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") def test_fsdp_inductor(self): with _dynamo_dist_per_rank_init(self.rank, self.world_size): # Test with basic FSDP wrapping (outer wrap around whole model) @@ -701,7 +701,7 @@ def test_fsdp_inductor(self): @config.patch(enable_compiler_collectives=True) @skip_if_lt_x_gpu(1) - @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") + @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") def test_fsdp_activation_checkpointing(self): with _dynamo_dist_per_rank_init(self.rank, self.world_size): model, inputs = get_toy_model_for_activation_checkpointing( @@ -722,7 +722,7 @@ def test_fsdp_activation_checkpointing(self): ) @import_transformers_or_skip() - @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") + @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") # TODO(whc) Investigate why cudagraphs breaks inductor+fsdp for hf_bert @patch.object(torch._inductor.config.triton, "cudagraphs", False) @patch.object(torch._inductor.config, "fallback_random", True) @@ -767,7 +767,7 @@ def apply_fsdp(model, wrap_policy): self.assertTrue(same(correct_results, opt_results)) @import_transformers_or_skip() - @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") + @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") # TODO(whc) Investigate why cudagraphs breaks inductor+fsdp for hf_bert @patch.object(torch._inductor.config.triton, "cudagraphs", False) @patch.object(torch._inductor.config, "fallback_random", True) @@ -815,7 +815,7 @@ def test_hf_bert_fsdp_activation_checkpointing(self): ) self.assertTrue(same(correct_results, opt_results)) - @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") + @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") @config.patch(enable_compiler_collectives=True) def test_compiler_collectives_automatic_dynamic_tensor(self): with _dynamo_dist_per_rank_init(self.rank, self.world_size): @@ -860,7 +860,7 @@ def B(s): for r in res[1:]: self.assertEqual(res[0], r) - @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") + @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") @config.patch(enable_compiler_collectives=True) def test_compiler_collectives_automatic_dynamic_scalar(self): with _dynamo_dist_per_rank_init(self.rank, self.world_size): @@ -888,7 +888,7 @@ def f(x, y): for r in res[1:]: self.assertEqual(res[0], r) - @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") + @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") @config.patch(enable_compiler_collectives=True) def test_compiler_collectives_automatic_dynamic_speculation_divergence(self): with _dynamo_dist_per_rank_init(self.rank, self.world_size): @@ -921,7 +921,7 @@ def f(x, y): for r in res[1:]: self.assertEqual(res[0], r) - @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") + @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") @config.patch(enable_compiler_collectives=True) def test_compiler_collectives_graph_break_empty_graph_still_collective(self): with _dynamo_dist_per_rank_init(self.rank, self.world_size): @@ -955,7 +955,7 @@ def f(x, y): for r in res[1:]: self.assertEqual(res[0], r) - @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") + @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") @config.patch(enable_compiler_collectives=True) def test_compiler_collectives_dim_mismatch(self): with _dynamo_dist_per_rank_init(self.rank, self.world_size): @@ -984,7 +984,7 @@ def f(x, y): for r in res[1:]: self.assertEqual(res[0], r) - @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") + @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") @config.patch(enable_compiler_collectives=True) def test_compiler_collectives_missing_source(self): with _dynamo_dist_per_rank_init(self.rank, self.world_size): @@ -1006,7 +1006,7 @@ def f(rank, xs): for r in res[1:]: self.assertEqual(res[0], r) - @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") + @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") @config.patch(enable_compiler_collectives=True) def test_compiler_collectives_scalar_missing_source(self): with _dynamo_dist_per_rank_init(self.rank, self.world_size): @@ -1028,7 +1028,7 @@ def f(rank, xs): for r in res[1:]: self.assertEqual(res[0], r) - @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") + @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") @config.patch(enable_compiler_collectives=True) def test_compiler_collectives_type_mismatch(self): with _dynamo_dist_per_rank_init(self.rank, self.world_size): @@ -1062,7 +1062,7 @@ def f(x): for r in res[1:]: self.assertEqual(res[0], r) - @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") + @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") @patch.object(torch._inductor.config, "fx_graph_cache", False) @patch.object(torch._inductor.config, "fx_graph_remote_cache", False) def test_asymmetric_compilation(self): @@ -1113,7 +1113,7 @@ def f(x): for r in res[1:]: self.assertEqual(res[0], r) - @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") + @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") @patch.object(torch._inductor.config, "fx_graph_cache", True) @patch.object(torch._inductor.config, "fx_graph_remote_cache", False) @patch.object(torch._inductor.config, "sleep_sec_TESTING_ONLY", 10) @@ -1203,7 +1203,7 @@ def test_ddp_baseline_aot_eager(self): outputs = ddp_m(inputs) self.assertTrue(same(correct_outputs, outputs)) - @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") + @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") @patch.object(config, "optimize_ddp", False) def test_ddp_baseline_inductor(self): from torch.nn.parallel import DistributedDataParallel as DDP @@ -1299,7 +1299,7 @@ def opt_fn(inputs): self.assertTrue(all("DDPOptimizer" in r.reason for r in break_reasons)) @patch.object(config, "optimize_ddp", True) - @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") + @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") def test_graph_split_inductor(self): assert config.optimize_ddp """ @@ -1368,18 +1368,18 @@ def opt_fn(inputs): opt_outputs = opt_fn(inputs) self.assertTrue(same(correct_outputs, opt_outputs)) - @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") + @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") def test_graph_split_inductor_layout_optimizations_training(self): self._test_graph_split_inductor_layout_optimizations_impl( contextlib.nullcontext ) - @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") + @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") def test_graph_split_inductor_layout_optimizations_inference(self): self._test_graph_split_inductor_layout_optimizations_impl(torch.no_grad) @patch.object(config, "optimize_ddp", True) - @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") + @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") def test_graph_split_inductor_transpose(self): assert config.optimize_ddp @@ -1470,7 +1470,7 @@ def opt_fn(inputs): self.assertTrue(same(correct_outputs, opt_outputs)) self.assertEqual(check_splits_compiler.compiler_called, 3) - @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") + @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") def test_empty_graph_inductor(self): def fn(): get_world_size = torch.distributed.distributed_c10d.get_world_size() diff --git a/test/distributed/test_functional_api.py b/test/distributed/test_functional_api.py index 2624ce87089f37..c4972a46405862 100644 --- a/test/distributed/test_functional_api.py +++ b/test/distributed/test_functional_api.py @@ -14,7 +14,7 @@ from torch._inductor.utils import run_and_get_code from torch.testing import FileCheck from torch.testing._internal.distributed.fake_pg import FakeStore -from torch.testing._internal.inductor_utils import HAS_GPU +from torch.utils._triton import has_triton if not dist.is_available(): @@ -564,7 +564,7 @@ def test_all_to_all_single_split_sizes_none(self): expected = torch.cat(expected) self.assertEqual(y, expected) - @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") + @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") @requires_nccl() @with_comms() def test_tracing(self): @@ -574,7 +574,7 @@ def allreduce(t, pg): compiled_allreduce = torch.compile(allreduce, fullgraph=True) compiled_allreduce(torch.randn(8, device=self.device), self.process_group) - @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") + @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") def test_tracing_with_fakepg(self): exit_if_lt_x_gpu(self.world_size) @@ -590,7 +590,7 @@ def allreduce(t, pg): ) allreduce(torch.randn(8, device=self.device), pg=dist.group.WORLD) - @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") + @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") @requires_nccl() @with_comms() def test_tracing_with_dce_code(self): diff --git a/test/distributed/test_inductor_collectives.py b/test/distributed/test_inductor_collectives.py index 216643e59cdec5..a19183425d3907 100644 --- a/test/distributed/test_inductor_collectives.py +++ b/test/distributed/test_inductor_collectives.py @@ -29,7 +29,7 @@ parametrize, requires_cuda, ) -from torch.testing._internal.inductor_utils import HAS_GPU +from torch.utils._triton import has_triton def _tolist_with_constrain_as_size(tensor): @@ -58,7 +58,7 @@ def world_size(self) -> int: # works around issue with skipif<2 and workers with unpredictable #s gpu return 2 - @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") + @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") @skip_if_lt_x_gpu(2) def test_broadcast_inductor(self): """ @@ -90,7 +90,7 @@ def compile(func, example_inputs): compiled_out = compiled_func(*inputs) self.assertTrue(same(eager_out, compiled_out)) - @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") + @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") @skip_if_lt_x_gpu(2) def test_allreduce_inductor(self): """ @@ -123,7 +123,7 @@ def compile(func, example_inputs): inductor_out = compiled_matmul_cat_col(*inputs) self.assertTrue(same(eager_out, inductor_out, tol=0.001)) - @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") + @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") @skip_if_lt_x_gpu(2) def test_allreduce_inductor_cudagraph_trees(self): """ @@ -169,7 +169,7 @@ def test_c10d_functional_tagged_pt2_compliant(self): op = torch.ops.c10d_functional.all_reduce.default self.assertIn(torch.Tag.pt2_compliant_tag, op.tags) - @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") + @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") @skip_if_lt_x_gpu(2) def test_eager_allreduce_inductor_wait(self): def eager_func(a, b, c, d, *, tag, ranks, group_size): @@ -208,7 +208,7 @@ def compile(func, example_inputs): print(f"inductor_out, {inductor_out}") self.assertTrue(same(eager_out, inductor_out, tol=0.001)) - @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") + @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") @skip_if_lt_x_gpu(2) def test_inductor_allreduce_eager_wait(self): def inductor_func(a, b, c, d, *, tag, ranks, group_size): @@ -243,7 +243,7 @@ def compile(func, example_inputs): ) self.assertTrue(same(eager_out, inductor_out, tol=0.001)) - @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") + @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") @skip_if_lt_x_gpu(2) @patch.object(torch._inductor.config, "allow_buffer_reuse", True) def test_allreduce_input_buffer_reuse(self): @@ -261,7 +261,7 @@ def func(a, *, tag, ranks, group_size): correct = func(inputs, **self.get_world_trs()) self.assertTrue(same(out, correct)) - @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") + @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") @skip_if_lt_x_gpu(2) def test_permute_tensor(self): def func(tensor, src_dst_pairs, *, tag, ranks, group_size): @@ -287,7 +287,7 @@ def func(tensor, src_dst_pairs, *, tag, ranks, group_size): self.assertEqual(out, expected) self.assertEqual(correct, expected) - @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") + @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") @skip_if_lt_x_gpu(2) @patch.object(torch._inductor.config, "allow_buffer_reuse", True) def test_allgather_output_buffer_reuse(self): @@ -311,7 +311,7 @@ def forward(self, x, world_size, tag, ranks, group_size): correct = model(inp, self.world_size, **self.get_world_trs()) self.assertTrue(same(out, correct)) - @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") + @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") @skip_if_lt_x_gpu(2) def test_allgather_contiguous_input(self): class Model(torch.nn.Module): @@ -335,7 +335,7 @@ def forward(self, x, world_size, tag, ranks, group_size): correct = model(inp, self.world_size, **self.get_world_trs()) self.assertTrue(same(out, correct)) - @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") + @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") @skip_if_lt_x_gpu(2) def test_allgather_into_tensor_inductor(self): """ @@ -366,7 +366,7 @@ def compile(func, example_inputs): inductor_out = compiled_matmul_cat_col(*inputs) self.assertTrue(same(eager_out, inductor_out, tol=0.001)) - @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") + @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") @skip_if_lt_x_gpu(2) def test_reduce_scatter_tensor_inductor(self): def example(a, b, *, tag, ranks, group_size): @@ -393,7 +393,7 @@ def compile(func, example_inputs): inductor_out = compiled_fn(*inputs) self.assertTrue(same(eager_out, inductor_out, tol=0.001)) - @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") + @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") @skip_if_lt_x_gpu(2) @patch.object(torch._dynamo.config, "capture_scalar_outputs", True) def test_all_to_all_single_inductor(self): @@ -462,7 +462,7 @@ def example( inductor_out = compiled_fn(*inputs, **trs) self.assertTrue(same(eager_out, inductor_out, tol=0.001)) - @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") + @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") @skip_if_lt_x_gpu(2) def test_all_to_all_single_inductor_split_sizes_none(self): def example(inp, *, tag, ranks, group_size): @@ -518,7 +518,7 @@ def get_world_trs(self, world_size=1): "group_size": world_size, } - @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") + @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") @torch._inductor.config.patch(debug=True) def test_inductor_single_op(self): def func(inp, *, tag, ranks, group_size): @@ -547,7 +547,7 @@ def func(inp, *, tag, ranks, group_size): correct = func(inputs, **self.get_world_trs()) self.assertTrue(same(out, correct)) - @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") + @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") @torch._inductor.config.patch(debug=True) def test_inductor_steal_buffer(self): """ @@ -582,7 +582,7 @@ def func(inp, *, tag, ranks, group_size): correct = func(inputs, **self.get_world_trs()) self.assertTrue(same(out, correct)) - @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") + @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") @torch._inductor.config.patch({"debug": True, "triton.descriptive_names": False}) def test_inductor_doesnt_mutate_shared(self): """ @@ -1030,7 +1030,7 @@ def test_meta(self): out = torch.ops.c10d_functional.all_reduce(x, "sum", **self.get_world_trs()) self.assertEqual(x.size(), out.size()) - @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") + @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") @torch._inductor.config.patch({"debug": True, "triton.descriptive_names": False}) def test_inductor_all_gather_coalesced(self): """ @@ -1076,7 +1076,7 @@ def func(inp, *, tag, ranks, group_size): correct = func(inputs, **self.get_world_trs()) assert same(out, correct), f"{out} va {correct}" - @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") + @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") @torch._inductor.config.patch({"debug": True, "triton.descriptive_names": False}) def test_inductor_reduce_scatter_coalesced(self): """ diff --git a/test/inductor/test_cuda_repro.py b/test/inductor/test_cuda_repro.py index 98489a383435e4..852b56e6326a75 100644 --- a/test/inductor/test_cuda_repro.py +++ b/test/inductor/test_cuda_repro.py @@ -1315,9 +1315,7 @@ def fn(x, y): self.assertEqual(expect, actual) # Expect the code iterates in contiguous order, and is not tiled - lines = code[0].split("\n") - start = lines.index("@triton.jit") - kernel_code = "\n".join(lines[start : start + 14]) + kernel_code = "\n".join(code[0].split("\n")[60:74]) self.assertExpectedInline( kernel_code, """\ diff --git a/test/inductor/test_memory_planning.py b/test/inductor/test_memory_planning.py index fd8bfd5c82bcc1..d3e07670492123 100644 --- a/test/inductor/test_memory_planning.py +++ b/test/inductor/test_memory_planning.py @@ -22,9 +22,10 @@ from torch._inductor.test_case import run_tests, TestCase from torch._inductor.utils import run_and_get_cpp_code from torch.export import Dim +from torch.utils._triton import has_triton -@unittest.skipIf(not HAS_CUDA, "Inductor+gpu needs triton and CUDA") +@unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") @config.patch(memory_planning=True) class TestMemoryPlanning(TestCase): def _generate(self, *, device): diff --git a/test/inductor/test_triton_cpu_backend.py b/test/inductor/test_triton_cpu_backend.py deleted file mode 100644 index 8be4d85caea6a6..00000000000000 --- a/test/inductor/test_triton_cpu_backend.py +++ /dev/null @@ -1,34 +0,0 @@ -# Owner(s): ["module: inductor"] -from torch._inductor import config -from torch._inductor.test_case import run_tests -from torch.testing._internal.inductor_utils import HAS_CPU -from torch.utils._triton import has_triton - - -try: - from . import test_torchinductor -except ImportError: - import test_torchinductor - -if has_triton(): - import triton - - TRITON_HAS_CPU = "cpu" in triton.backends.backends -else: - TRITON_HAS_CPU = False - - -if HAS_CPU and TRITON_HAS_CPU: - - @config.patch(cpu_backend="triton") - class SweepInputsCpuTritonTest(test_torchinductor.SweepInputsCpuTest): - pass - - @config.patch(cpu_backend="triton") - class CpuTritonTests(test_torchinductor.CpuTests): - pass - - -if __name__ == "__main__": - if HAS_CPU and TRITON_HAS_CPU: - run_tests(needs="filelock") diff --git a/test/test_sparse_semi_structured.py b/test/test_sparse_semi_structured.py index 6bac38723d3013..0f7c96065bcdd6 100644 --- a/test/test_sparse_semi_structured.py +++ b/test/test_sparse_semi_structured.py @@ -38,10 +38,10 @@ IS_WINDOWS, ) -from torch.testing._internal.inductor_utils import HAS_GPU - import pytest +from torch.utils._triton import has_triton + SEMI_STRUCTURED_SUPPORTED_BACKENDS = dict() _IS_SM8X = False @@ -981,7 +981,7 @@ def run_test(m, n, k, device, dtype, dtype_out, use_input, rtol, atol): torch.backends.cuda.matmul.allow_tf32 = orig - @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") + @unittest.skipIf(not has_triton(), "Test needs triton and recent GPU arch") @inference_dtypes def test_conversions(self, device, dtype): @@ -1009,7 +1009,7 @@ def run_test(r, c, device, dtype): for r, c in shapes: run_test(r, c, device, dtype) - @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") + @unittest.skipIf(not has_triton(), "Test needs triton and recent GPU arch") @inference_dtypes def test_conversions_all_patterns(self, device, dtype): r, c = 32, 128 diff --git a/torch/_dynamo/device_interface.py b/torch/_dynamo/device_interface.py index ad58a11bbd5642..00975cda550818 100644 --- a/torch/_dynamo/device_interface.py +++ b/torch/_dynamo/device_interface.py @@ -1,7 +1,5 @@ # mypy: allow-untyped-defs import inspect -import time -from dataclasses import dataclass from typing import Any, Callable, Dict, Iterable, Optional, Tuple, Type, Union import torch @@ -297,51 +295,6 @@ def is_bf16_supported(including_emulation: bool = False) -> bool: return torch.xpu.is_bf16_supported() -@dataclass -class CpuDeviceProperties: - multi_processor_count: int - - -class CpuInterface(DeviceInterface): - class Event(_EventBase): - def __init__(self, enable_timing=True): - self.time = 0.0 - - def elapsed_time(self, end_event) -> float: - return (end_event.time - self.time) * 1000 - - def record(self): - self.time = time.perf_counter() - - @staticmethod - def is_available() -> bool: - return True - - @staticmethod - def get_compute_capability(device: _device_t = None) -> str: - return "" - - @staticmethod - def get_raw_stream(device_idx) -> int: - return 0 - - @staticmethod - def current_device(): - return 0 - - @staticmethod - def synchronize(device: _device_t = None): - pass - - class Worker: - @staticmethod - def get_device_properties(device: _device_t = None): - import multiprocessing - - cpu_count = multiprocessing.cpu_count() - return CpuDeviceProperties(cpu_count) - - device_interfaces: Dict[str, Type[DeviceInterface]] = {} _device_initialized = False @@ -380,6 +333,4 @@ def init_device_reg(): for i in range(torch.xpu.device_count()): register_interface_for_device(f"xpu:{i}", XpuInterface) - register_interface_for_device("cpu", CpuInterface) - _device_initialized = True diff --git a/torch/_inductor/autotune_process.py b/torch/_inductor/autotune_process.py index 8dcda8fa3d7dca..94efbaf4e32fdd 100644 --- a/torch/_inductor/autotune_process.py +++ b/torch/_inductor/autotune_process.py @@ -574,7 +574,7 @@ def benchmark( return self.value -class GPUDeviceBenchmarkMixin: +class GPUDeviceBenchmarkRequest(BenchmarkRequest): def do_bench( self, fn, @@ -601,17 +601,7 @@ def do_bench( return out -class CPUDeviceBenchmarkMixin: - def do_bench( - self, - fn, - *input_tensors: torch.Tensor, - output_tensor: Optional[torch.Tensor] = None, - ) -> float: - return benchmarker.benchmark_cpu(fn) - - -class TritonBenchmarkRequest(BenchmarkRequest): +class TritonBenchmarkRequest(GPUDeviceBenchmarkRequest): # Important: Instances of this class have to be serializable # across process boundaries. Do not put CUDA Tensors in here! def __init__( @@ -656,22 +646,28 @@ def make_run_fn( if "warmup" in inspect.signature(run_method).parameters: warmup_arg["warmup"] = False - if output_tensor.device.type == "cpu": - stream = 0 + from torch._C import _cuda_getCurrentRawStream as get_raw_stream + + if torch.version.hip and self.matrix_instr_nonkdim != 0: + return functools.partial( + run_method, + *input_tensors, + output_tensor, + *self.extra_args, + grid=self.grid, + **warmup_arg, + stream=get_raw_stream(self.output_tensor_meta.device.index), + ) else: - from torch._C import _cuda_getCurrentRawStream as get_raw_stream - - stream = get_raw_stream(self.output_tensor_meta.device.index) - - return functools.partial( - run_method, - *input_tensors, - output_tensor, - *self.extra_args, - grid=self.grid, - **warmup_arg, - stream=stream, - ) + return functools.partial( + run_method, + *input_tensors, + output_tensor, + *self.extra_args, + grid=self.grid, + **warmup_arg, + stream=get_raw_stream(self.output_tensor_meta.device.index), + ) def precompile(self): mod = PyCodeCache.load_by_key_path(self.module_cache_key, self.module_path) @@ -681,15 +677,7 @@ def __str__(self) -> str: return f"{self.kernel_name=}, {self.module_path=}, {self.module_cache_key=}" -class TritonGPUBenchmarkRequest(GPUDeviceBenchmarkMixin, TritonBenchmarkRequest): - pass - - -class TritonCPUBenchmarkRequest(CPUDeviceBenchmarkMixin, TritonBenchmarkRequest): - pass - - -class CUDABenchmarkRequest(GPUDeviceBenchmarkMixin, BenchmarkRequest): +class CUDABenchmarkRequest(GPUDeviceBenchmarkRequest): # Important: Instances of this class have to be serializable # across process boundaries. Do not put CUDA Tensors in here! @@ -806,7 +794,17 @@ def __str__(self) -> str: return f"{self.kernel_name=}, {self.source_file=}, {self.hash_key=}" -class CppBenchmarkRequest(CPUDeviceBenchmarkMixin, BenchmarkRequest): +class CPUDeviceBenchmarkRequest(BenchmarkRequest): + def do_bench( + self, + fn, + *input_tensors: torch.Tensor, + output_tensor: Optional[torch.Tensor] = None, + ) -> float: + return benchmarker.benchmark_cpu(fn) + + +class CppBenchmarkRequest(CPUDeviceBenchmarkRequest): # Important: Instances of this class have to be serializable # across process boundaries. Do not put Tensors in here! diff --git a/torch/_inductor/codegen/common.py b/torch/_inductor/codegen/common.py index e653d75715248a..95183c1484fe00 100644 --- a/torch/_inductor/codegen/common.py +++ b/torch/_inductor/codegen/common.py @@ -239,11 +239,7 @@ def init_backend_registration(): from .wrapper import WrapperCodeGen if get_scheduling_for_device("cpu") is None: - cpu_backends = { - "cpp": CppScheduling, - "halide": HalideScheduling, - "triton": TritonScheduling, - } + cpu_backends = {"cpp": CppScheduling, "halide": HalideScheduling} register_backend_for_device( "cpu", lambda *args, **kwargs: cpu_backends[config.cpu_backend](*args, **kwargs), @@ -305,7 +301,6 @@ def get_device_op_overrides(device: str): assert isinstance(device, str) if not device_op_overrides_dict.keys(): - from . import cpu_device_op_overrides # noqa: F401 from .cuda import device_op_overrides # noqa: F401 from .xpu import device_op_overrides as xpu_op_overrides # noqa: F401 diff --git a/torch/_inductor/codegen/cpp.py b/torch/_inductor/codegen/cpp.py index e4011322fbe88b..8bbe8023a312af 100644 --- a/torch/_inductor/codegen/cpp.py +++ b/torch/_inductor/codegen/cpp.py @@ -4643,7 +4643,7 @@ def codegen_group(self, name=None) -> str: def call_kernel(self, wrapper, kernel_name): _, call_args, arg_types = self.args.cpp_argdefs() wrapper.generate_kernel_call( - kernel_name, call_args, gpu=False, triton=False, arg_types=arg_types + kernel_name, call_args, gpu=False, arg_types=arg_types ) diff --git a/torch/_inductor/codegen/cpp_template_kernel.py b/torch/_inductor/codegen/cpp_template_kernel.py index 7ac125c76cc6fb..e72a895dc44d59 100644 --- a/torch/_inductor/codegen/cpp_template_kernel.py +++ b/torch/_inductor/codegen/cpp_template_kernel.py @@ -107,9 +107,7 @@ def hook(): def call_kernel(self, name: str, node: ir.CppTemplateBuffer): wrapper = V.graph.wrapper_code _, call_args, arg_types = self.args.cpp_argdefs() - wrapper.generate_kernel_call( - name, call_args, triton=False, gpu=False, arg_types=arg_types - ) + wrapper.generate_kernel_call(name, call_args, gpu=False, arg_types=arg_types) def dtype(self, node: ir.Buffer) -> str: return DTYPE_TO_CPP[node.get_dtype()] diff --git a/torch/_inductor/codegen/cpu_device_op_overrides.py b/torch/_inductor/codegen/cpu_device_op_overrides.py deleted file mode 100644 index 1944c0e6beb81e..00000000000000 --- a/torch/_inductor/codegen/cpu_device_op_overrides.py +++ /dev/null @@ -1,26 +0,0 @@ -# mypy: allow-untyped-defs -from textwrap import dedent - -from .common import DeviceOpOverrides, register_device_op_overrides - - -class CpuDeviceOpOverrides(DeviceOpOverrides): - def import_get_raw_stream_as(self, name): - return dedent( - """ - def get_raw_stream(_): - return 0 - """ - ) - - def set_device(self, device_idx): - return "pass" - - def synchronize(self): - return "pass" - - def device_guard(self, device_idx): - return "pass" - - -register_device_op_overrides("cpu", CpuDeviceOpOverrides()) diff --git a/torch/_inductor/codegen/halide.py b/torch/_inductor/codegen/halide.py index bbf6866fdc3818..337aa544b0d10a 100644 --- a/torch/_inductor/codegen/halide.py +++ b/torch/_inductor/codegen/halide.py @@ -1639,7 +1639,6 @@ def call_kernel(self, name: str, node=None): name, call_args, gpu=False, # grid/stream is handled internally in halide - triton=False, ) def generate_assert(self, check): diff --git a/torch/_inductor/codegen/rocm/rocm_benchmark_request.py b/torch/_inductor/codegen/rocm/rocm_benchmark_request.py index bdd68d9a2e6ba6..a70f45b7033d6d 100644 --- a/torch/_inductor/codegen/rocm/rocm_benchmark_request.py +++ b/torch/_inductor/codegen/rocm/rocm_benchmark_request.py @@ -7,18 +7,14 @@ from typing import Any, Callable, Iterable, List, Optional, Union import torch -from torch._inductor.autotune_process import ( - BenchmarkRequest, - GPUDeviceBenchmarkMixin, - TensorMeta, -) +from torch._inductor.autotune_process import GPUDeviceBenchmarkRequest, TensorMeta from torch._inductor.codecache import DLLWrapper, ROCmCodeCache log = logging.getLogger(__name__) -class ROCmBenchmarkRequest(GPUDeviceBenchmarkMixin, BenchmarkRequest): +class ROCmBenchmarkRequest(GPUDeviceBenchmarkRequest): # Important: Instances of this class have to be serializable # across process boundaries. Do not put CUDA Tensors in here! diff --git a/torch/_inductor/codegen/triton.py b/torch/_inductor/codegen/triton.py index b52ca309dfbb28..c30f2d0bddc2f1 100644 --- a/torch/_inductor/codegen/triton.py +++ b/torch/_inductor/codegen/triton.py @@ -2602,11 +2602,6 @@ def codegen_kernel(self, name=None): if name is None: code.splice(gen_common_triton_imports()) - device_type = V.graph.scheduler.get_current_device_or_throw().type - if device_type == "cpu": - code.splice("triton_helpers.set_driver_to_cpu()") - else: - code.splice("triton_helpers.set_driver_to_gpu()") if config.benchmark_kernel: code.splice(self.imports_for_benchmark_kernel()) @@ -2839,7 +2834,7 @@ def call_kernel(self, name: str, node: Optional[IRNode] = None): call_args, grid, current_device.index, - gpu=current_device.type != "cpu", + gpu=True, triton=True, arg_types=arg_types, grid_fn=self._get_grid_fn(), diff --git a/torch/_inductor/codegen/wrapper.py b/torch/_inductor/codegen/wrapper.py index ce7bf65bd5f665..fe95b92c6877e8 100644 --- a/torch/_inductor/codegen/wrapper.py +++ b/torch/_inductor/codegen/wrapper.py @@ -1668,97 +1668,99 @@ def generate_kernel_call( gpu: Defines whether the backend is GPU. Otherwise the backend is CPU. - triton: Defines whether the backend uses Triton for codegen. Otherwise it uses the CUDA language when gpu=True, - and C++ when gpu=False. + triton: Defines whether the GPU backend uses Triton for codegen. + Otherwise it uses the CUDA language for codegen. + Only valid when gpu == True. """ - if not (triton or gpu): - self.writeline(self.wrap_kernel_call(kernel_name, call_args)) - return - - device_index, call_args_str = self.prepare_triton_kernel_call( - device_index, call_args - ) - call_args_str = ", ".join(call_args_str) - stream_name = self.write_get_raw_stream(device_index, V.graph) - - if not triton: - stream_ptr = f"c_void_p({stream_name})" - self.writeline( - f"{kernel_name}.{kernel_name}({call_args_str}, {stream_ptr})" + if gpu: + device_index, call_args_str = self.prepare_triton_kernel_call( + device_index, call_args ) - return + call_args_str = ", ".join(call_args_str) + stream_name = self.write_get_raw_stream(device_index, V.graph) + if triton: + self.write_triton_header_once() + if grid is None: + grid_str = grid_fn + else: + grid_str = ", ".join(self._grid_dim_str(item) for item in grid) + if grid_extra_kwargs: + grid_str = f"{grid_str}, {grid_extra_kwargs}" + grid_str = f"{grid_fn}({grid_str})" + # add debug printer code for triton kernel calls at (jit) inductor level + debug_printer_manager = V.graph.wrapper_code.debug_printer + debug_printer_manager.set_printer_args( + call_args, kernel_name, arg_types, None + ) + with debug_printer_manager: + self.writeline( + f"{kernel_name}.run({call_args_str}, grid={grid_str}, stream={stream_name})" + ) + if ( + config.triton.autotune_at_compile_time + and kernel_name not in self.kernel_autotune_names + ): + # Create example args for autotune in a separate epilogue + assert arg_types is not None and len(call_args) == len( + arg_types + ), "call_args and arg_types do not match" + + tensor_args = {} + all_args = [] + if raw_args is None: + # create a dummy raw_args for uniform behavior in the following loop + raw_args = [None] * len(call_args) + else: + assert len(raw_args) == len( + call_args + ), "call_args and raw_args do not match" - self.write_triton_header_once() - if grid is None: - grid_str = grid_fn - else: - grid_str = ", ".join(self._grid_dim_str(item) for item in grid) - if grid_extra_kwargs: - grid_str = f"{grid_str}, {grid_extra_kwargs}" - grid_str = f"{grid_fn}({grid_str})" - # add debug printer code for triton kernel calls at (jit) inductor level - debug_printer_manager = V.graph.wrapper_code.debug_printer - debug_printer_manager.set_printer_args(call_args, kernel_name, arg_types, None) - with debug_printer_manager: - self.writeline( - f"{kernel_name}.run({call_args_str}, grid={grid_str}, stream={stream_name})" - ) - if ( - config.triton.autotune_at_compile_time - and kernel_name not in self.kernel_autotune_names - ): - # Create example args for autotune in a separate epilogue - assert arg_types is not None and len(call_args) == len( - arg_types - ), "call_args and arg_types do not match" - - tensor_args = {} - all_args = [] - if raw_args is None: - # create a dummy raw_args for uniform behavior in the following loop - raw_args = [None] * len(call_args) - else: - assert len(raw_args) == len( - call_args - ), "call_args and raw_args do not match" + for i, (arg, arg_type, raw_arg) in enumerate( + zip(call_args, arg_types, raw_args) + ): + key = None + if isinstance(arg, str) and "=" in str(arg): + # arg may be passed in a kwarg style, and then we need to extract its value + key, arg = arg.split("=") + + if isinstance(arg_type, torch_dtype): + if arg not in tensor_args: + arg_str = self.generate_example_arg_value( + arg, arg_type, raw_arg, i + ) + tensor_args[arg] = arg_str + else: + arg_str = tensor_args[arg] + else: + arg_str = self.generate_example_arg_value( + arg, arg_type, raw_arg, i + ) + all_args.append(arg_str if key is None else f"{key}={arg_str}") - for i, (arg, arg_type, raw_arg) in enumerate( - zip(call_args, arg_types, raw_args) - ): - key = None - if isinstance(arg, str) and "=" in str(arg): - # arg may be passed in a kwarg style, and then we need to extract its value - key, arg = arg.split("=") - - if isinstance(arg_type, torch_dtype): - if arg not in tensor_args: - arg_str = self.generate_example_arg_value( - arg, arg_type, raw_arg, i - ) - tensor_args[arg] = arg_str + if grid is None: + grid_str = grid_fn else: - arg_str = tensor_args[arg] - else: - arg_str = self.generate_example_arg_value(arg, arg_type, raw_arg, i) - all_args.append(arg_str if key is None else f"{key}={arg_str}") + grid_str = ", ".join( + self.generate_example_arg_value(g, type(g)) for g in grid + ) + if grid_extra_kwargs: + grid_str = f"{grid_str}, {grid_extra_kwargs}" + grid_str = f"{grid_fn}({grid_str})" - if grid is None: - grid_str = grid_fn + self.kernel_autotune_calls.writeline( + f"{kernel_name}.run({', '.join(all_args)}, grid={grid_str}, stream={stream_name})" + ) + self.kernel_autotune_calls.writeline( + f"del {', '.join(arg for arg in tensor_args.values())}\n", + ) + self.kernel_autotune_names.add(kernel_name) else: - grid_str = ", ".join( - self.generate_example_arg_value(g, type(g)) for g in grid + stream_ptr = f"c_void_p({stream_name})" + self.writeline( + f"{kernel_name}.{kernel_name}({call_args_str}, {stream_ptr})" ) - if grid_extra_kwargs: - grid_str = f"{grid_str}, {grid_extra_kwargs}" - grid_str = f"{grid_fn}({grid_str})" - - self.kernel_autotune_calls.writeline( - f"{kernel_name}.run({', '.join(all_args)}, grid={grid_str}, stream={stream_name})" - ) - self.kernel_autotune_calls.writeline( - f"del {', '.join(arg for arg in tensor_args.values())}\n", - ) - self.kernel_autotune_names.add(kernel_name) + else: + self.writeline(self.wrap_kernel_call(kernel_name, call_args)) def writeline(self, line): self.lines.append(line) diff --git a/torch/_inductor/config.py b/torch/_inductor/config.py index a6c9f0bfece500..99a25b13d2fd2a 100644 --- a/torch/_inductor/config.py +++ b/torch/_inductor/config.py @@ -1137,7 +1137,7 @@ class rocm: use_preselected_instances: bool = False -# Backend to use for CPU codegen either "cpp" or "triton" (experimental) or "halide" (experimental) +# Backend to use for CPU codegen either "cpp" or "halide" (experimental) cpu_backend = "cpp" # Backend to use for CUDA codegen either "triton" or "halide" (experimental) diff --git a/torch/_inductor/runtime/benchmarking.py b/torch/_inductor/runtime/benchmarking.py index 8d521c0d714f61..44c222c3f936ef 100644 --- a/torch/_inductor/runtime/benchmarking.py +++ b/torch/_inductor/runtime/benchmarking.py @@ -77,7 +77,7 @@ def __init__(self: Self) -> None: def benchmark( self: Self, fn: Callable[..., Any], - fn_args: Tuple[Any, ...], + fn_args: Tuple[Any], fn_kwargs: Dict[str, Any], **kwargs: Any, ) -> float: diff --git a/torch/_inductor/runtime/triton_helpers.py b/torch/_inductor/runtime/triton_helpers.py index 37f17924331912..1a8ff2ba2408d7 100644 --- a/torch/_inductor/runtime/triton_helpers.py +++ b/torch/_inductor/runtime/triton_helpers.py @@ -31,21 +31,6 @@ def _log2(x): raise NotImplementedError -def set_driver_to_cpu(): - if backend := triton.backends.backends.get("cpu", None): - triton.runtime.driver.set_active(backend.driver()) - return - raise RuntimeError("Could not find an active CPU backend") - - -def set_driver_to_gpu(): - for name, backend in triton.backends.backends.items(): - if backend.driver.is_active() and name != "cpu": - triton.runtime.driver.set_active(backend.driver()) - return - raise RuntimeError("Could not find an active GPU backend") - - @triton.jit def promote_to_tensor(x): # Addition promotes to tensor for us diff --git a/torch/_inductor/runtime/triton_heuristics.py b/torch/_inductor/runtime/triton_heuristics.py index 037f67fb429ee0..56169905dbe3fb 100644 --- a/torch/_inductor/runtime/triton_heuristics.py +++ b/torch/_inductor/runtime/triton_heuristics.py @@ -673,9 +673,6 @@ def kernel_call(): return do_bench_using_profiling(kernel_call, warmup=10, rep=40) - if self.device_props.type == "cpu": - return benchmarker.benchmark_cpu(kernel_call) - return benchmarker.benchmark_gpu(kernel_call, rep=40) def clone_args(self, *args, **kwargs) -> Tuple[List[Any], Dict[str, Any]]: diff --git a/torch/_inductor/scheduler.py b/torch/_inductor/scheduler.py index 0c3e69efa5eecf..5f8d35cbcab5eb 100644 --- a/torch/_inductor/scheduler.py +++ b/torch/_inductor/scheduler.py @@ -3449,11 +3449,12 @@ def _codegen(self) -> None: self.current_device.type ): V.graph.wrapper_code.codegen_device_guard_exit() - self.current_device = device if device_need_guard(device.type): assert device.index is not None, "device should have an index" V.graph.wrapper_code.codegen_device_guard_enter(device.index) + self.current_device = device + self.buffer_names_to_free.update(node.last_usage) if node.is_template(): diff --git a/torch/_inductor/select_algorithm.py b/torch/_inductor/select_algorithm.py index 2685f730d659be..7b90690cf50b19 100644 --- a/torch/_inductor/select_algorithm.py +++ b/torch/_inductor/select_algorithm.py @@ -15,7 +15,7 @@ from collections import namedtuple from concurrent.futures import as_completed, ThreadPoolExecutor from io import StringIO -from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union +from typing import Any, Callable, Dict, List, Optional, Tuple, Union from unittest.mock import patch import sympy @@ -27,12 +27,7 @@ from torch._dynamo.utils import counters, identity, preserve_rng_state from . import config, ir -from .autotune_process import ( - TensorMeta, - TritonBenchmarkRequest, - TritonCPUBenchmarkRequest, - TritonGPUBenchmarkRequest, -) +from .autotune_process import TensorMeta, TritonBenchmarkRequest from .codecache import code_hash, PersistentCache, PyCodeCache from .codegen.common import IndentedBuffer, KernelTemplate from .codegen.triton import ( @@ -578,7 +573,6 @@ def call_kernel(self, name: str, node: Optional[ir.IRNode] = None): grid_fn=f"{self.grid_fn.__module__}.{self.grid_fn.__name__}", arg_types=arg_types, triton_meta=self.triton_meta, - gpu="cpu" not in V.graph.device_types, ) @@ -748,12 +742,7 @@ def make_kernel_render(out_node): ), kwargs, ) - bmreq_cls: Type[TritonBenchmarkRequest] - if layout.device.type == "cpu": - bmreq_cls = TritonCPUBenchmarkRequest - else: - bmreq_cls = TritonGPUBenchmarkRequest - bmreq = bmreq_cls( + bmreq = TritonBenchmarkRequest( module_path=mod.__file__, module_cache_key=mod.key, kernel_name=kernel_name, diff --git a/torch/_inductor/utils.py b/torch/_inductor/utils.py index 82aa6a2ac001e7..1004817d6e7dbc 100644 --- a/torch/_inductor/utils.py +++ b/torch/_inductor/utils.py @@ -1099,13 +1099,7 @@ def use_triton_template(layout, *, enable_int32=False, enable_float8=False): if enable_float8: layout_dtypes.extend([torch.float8_e4m3fn, torch.float8_e5m2]) return ( - ( - ( - layout.device.type == "cuda" - and _use_template_for_cuda(layout, layout_dtypes) - ) - or (layout.device.type == "cpu" and layout.dtype in layout_dtypes) - ) + _use_template_for_cuda(layout, layout_dtypes) and _use_autotune_backend("TRITON") and has_backend_feature(layout.device, BackendFeature.TRITON_TEMPLATES) ) diff --git a/torch/utils/_triton.py b/torch/utils/_triton.py index 3cc977c8483823..92b42320c3c668 100644 --- a/torch/utils/_triton.py +++ b/torch/utils/_triton.py @@ -17,27 +17,15 @@ def has_triton_package() -> bool: @functools.lru_cache(None) def has_triton() -> bool: - if not has_triton_package(): - return False - from torch._dynamo.device_interface import get_interface_for_device def cuda_extra_check(device_interface): return device_interface.Worker.get_device_properties().major >= 7 - def cpu_extra_check(device_interface): - import triton.backends - - return "cpu" in triton.backends.backends - def _return_true(device_interface): return True - triton_supported_devices = { - "cuda": cuda_extra_check, - "xpu": _return_true, - "cpu": cpu_extra_check, - } + triton_supported_devices = {"cuda": cuda_extra_check, "xpu": _return_true} def is_device_compatible_with_triton(): for device, extra_check in triton_supported_devices.items(): @@ -46,7 +34,7 @@ def is_device_compatible_with_triton(): return True return False - return is_device_compatible_with_triton() + return is_device_compatible_with_triton() and has_triton_package() @functools.lru_cache(None) From 1978af3594fcb8e2ff02168e3d76d2a42d1cb6f7 Mon Sep 17 00:00:00 2001 From: Nikita Shulga <2453524+malfet@users.noreply.github.com> Date: Mon, 16 Sep 2024 18:34:18 +0000 Subject: [PATCH 0941/1018] Remove accidentally committed code (#136154) Accidentally left out during rebase Pull Request resolved: https://github.com/pytorch/pytorch/pull/136154 Approved by: https://github.com/kit1980, https://github.com/albanD --- aten/src/ATen/native/mps/operations/Im2Col.mm | 7 ------- 1 file changed, 7 deletions(-) diff --git a/aten/src/ATen/native/mps/operations/Im2Col.mm b/aten/src/ATen/native/mps/operations/Im2Col.mm index 8cba89490f5683..5fd3f0d9ac36d5 100644 --- a/aten/src/ATen/native/mps/operations/Im2Col.mm +++ b/aten/src/ATen/native/mps/operations/Im2Col.mm @@ -149,10 +149,6 @@ static void im2col_out_mps_template(Tensor& output, std::array input_strides = {input.stride(3), input.stride(2), input.stride(1), input.stride(0)}; std::array output_strides = {output.stride(2), output.stride(1), output.stride(0), output_width}; getMPSProfiler().beginProfileKernel(im2colPSO, "im2col", {input, output}); - - if (getMPSProfiler().isCaptureEnabled()) { - getMPSProfiler().startCapture("im2col", stream); - } auto computeEncoder = stream->commandEncoder(); [computeEncoder setComputePipelineState:im2colPSO]; mtl_setBuffer(computeEncoder, input, 0); @@ -164,9 +160,6 @@ static void im2col_out_mps_template(Tensor& output, mtl_setBytes(computeEncoder, input_sizes, 6); [computeEncoder dispatchThreads:MTLSizeMake(output_length, n_input_plane, batch_size) threadsPerThreadgroup:MTLSizeMake(64, 1, 1)]; - if (getMPSProfiler().isCapturing()) { - getMPSProfiler().stopCapture(stream); - } getMPSProfiler().endProfileKernel(im2colPSO); } }); From 2554acd41930884a3c47a084ae5c82a06f323be1 Mon Sep 17 00:00:00 2001 From: Alexander Kurakin Date: Mon, 16 Sep 2024 18:52:16 +0000 Subject: [PATCH 0942/1018] `torch.nn.MultiheadAttention`: docs: improvement (#136111) `torch.nn.MultiheadAttention`: docs: improvement Pull Request resolved: https://github.com/pytorch/pytorch/pull/136111 Approved by: https://github.com/janeyx99 --- torch/nn/modules/activation.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/torch/nn/modules/activation.py b/torch/nn/modules/activation.py index 4889a4485c4901..02369cb4974b41 100644 --- a/torch/nn/modules/activation.py +++ b/torch/nn/modules/activation.py @@ -979,11 +979,11 @@ class MultiheadAttention(Module): Multi-Head Attention is defined as: .. math:: - \text{MultiHead}(Q, K, V) = \text{Concat}(head_1,\dots,head_h)W^O + \text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1,\dots,\text{head}_h)W^O - where :math:`head_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)`. + where :math:`\text{head}_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)`. - ``nn.MultiHeadAttention`` will use the optimized implementations of + ``nn.MultiheadAttention`` will use the optimized implementations of ``scaled_dot_product_attention()`` when possible. In addition to support for the new ``scaled_dot_product_attention()`` From e7527adfbb30c8170589bfcd2374efd339145408 Mon Sep 17 00:00:00 2001 From: eugenekoran Date: Mon, 16 Sep 2024 19:02:19 +0000 Subject: [PATCH 0943/1018] Drop outdated section 'Running clang-tidy' in CONTRIBUTING.md (#136146) Fixes #125920 [Running clang-tidy](https://github.com/pytorch/pytorch/blob/main/CONTRIBUTING.md#running-clang-tidy) section is misleading and outdated. C++ lint is done with lintrunner and covered in [local-linting](https://github.com/pytorch/pytorch/blob/main/CONTRIBUTING.md#local-linting) section. Pull Request resolved: https://github.com/pytorch/pytorch/pull/136146 Approved by: https://github.com/janeyx99 --- CONTRIBUTING.md | 33 --------------------------------- 1 file changed, 33 deletions(-) diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 12f8222ff1bd48..99e47ef502998b 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -50,7 +50,6 @@ aspects of contributing to PyTorch. - [Windows development tips](#windows-development-tips) - [Known MSVC (and MSVC with NVCC) bugs](#known-msvc-and-msvc-with-nvcc-bugs) - [Building on legacy code and CUDA](#building-on-legacy-code-and-cuda) -- [Running clang-tidy](#running-clang-tidy) - [Pre-commit tidy/linting hook](#pre-commit-tidylinting-hook) - [Building PyTorch with ASAN](#building-pytorch-with-asan) - [Getting `ccache` to work](#getting-ccache-to-work) @@ -1132,38 +1131,6 @@ CUDA, MSVC, and PyTorch versions are interdependent; please install matching ver Note: There's a [compilation issue](https://github.com/oneapi-src/oneDNN/issues/812) in several Visual Studio 2019 versions since 16.7.1, so please make sure your Visual Studio 2019 version is not in 16.7.1 ~ 16.7.5 -## Running clang-tidy - -[Clang-Tidy](https://clang.llvm.org/extra/clang-tidy/index.html) is a C++ -linter and static analysis tool based on the clang compiler. We run clang-tidy -in our CI to make sure that new C++ code is safe, sane and efficient. See the -[`clang-tidy` job in our GitHub Workflow's -lint.yml file](https://github.com/pytorch/pytorch/blob/main/.github/workflows/lint.yml) -for the simple commands we use for this. - -To run clang-tidy locally, follow these steps: - -1. Install clang-tidy. -We provide custom built binaries which have additional checks enabled. You can install it by running: -```bash -python3 -m tools.linter.clang_tidy.generate_build_files -``` -We currently only support Linux and MacOS (x86). - -2. Install clang-tidy driver script dependencies -```bash -pip3 install -r tools/linter/clang_tidy/requirements.txt -``` - -3. Run clang-tidy -```bash -# Run clang-tidy on the entire codebase -make clang-tidy -# Run clang-tidy only on your changes -make clang-tidy CHANGED_ONLY=--changed-only -``` -This internally invokes our driver script and closely mimics how clang-tidy is run on CI. - ## Pre-commit tidy/linting hook We use clang-tidy to perform additional From 18c39b9453ad3aebc08a2b5ddabeac2cc201e98a Mon Sep 17 00:00:00 2001 From: Ke Wen Date: Sun, 15 Sep 2024 23:50:45 -0700 Subject: [PATCH 0944/1018] [Distributed] fix FileSystemWriter __init__ (#136135) Fixes #135608. Pull Request resolved: https://github.com/pytorch/pytorch/pull/136135 Approved by: https://github.com/Skylion007 --- torch/distributed/checkpoint/filesystem.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/torch/distributed/checkpoint/filesystem.py b/torch/distributed/checkpoint/filesystem.py index 3faee96b238dc7..079608c7230629 100644 --- a/torch/distributed/checkpoint/filesystem.py +++ b/torch/distributed/checkpoint/filesystem.py @@ -747,15 +747,19 @@ def __init__( N. B. If sync_files is disabled, there's no guarantee that the checkpoint will be consistent in the case of a failure. """ - super().__init__( + _FileSystemWriter.__init__( + self, path=path, single_file_per_rank=single_file_per_rank, sync_files=sync_files, thread_count=thread_count, per_thread_copy_ahead=per_thread_copy_ahead, - cache_staged_state_dict=cache_staged_state_dict, overwrite=overwrite, ) + BlockingAsyncStager.__init__( + self, + cache_staged_state_dict=cache_staged_state_dict, + ) def stage(self, state_dict: STATE_DICT_TYPE) -> STATE_DICT_TYPE: """Override of AsyncStager.stage""" From 5a20276fad63c3851a0b6b5e099c91d05cced49c Mon Sep 17 00:00:00 2001 From: Nikita Shulga <2453524+malfet@users.noreply.github.com> Date: Mon, 16 Sep 2024 19:30:30 +0000 Subject: [PATCH 0945/1018] [EZ] Fix spelling typo (#136157) s/toosl/tools/ (spotted by @louie-tsai) Also, capitalize CUDA Pull Request resolved: https://github.com/pytorch/pytorch/pull/136157 Approved by: https://github.com/kit1980 --- docs/source/torch.compiler_profiling_torch_compile.rst | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/source/torch.compiler_profiling_torch_compile.rst b/docs/source/torch.compiler_profiling_torch_compile.rst index 33a6bca92faf98..0a465f1ca9b764 100644 --- a/docs/source/torch.compiler_profiling_torch_compile.rst +++ b/docs/source/torch.compiler_profiling_torch_compile.rst @@ -6,7 +6,7 @@ What to use torch.profiler for: torch.profiler is helpful for understanding the performance of your program at a kernel-level granularity - for example, it can show graph breaks and GPU utilization at the level of the program. The data provided by the profiler can often help users understand where to investigate further to understand model performance. -To understand kernel-level performance, other toosl exist. NVIDIA's ncu tool can be used, or :ref:`inductor's profiling tools `. +To understand kernel-level performance, other tools exist. NVIDIA's ncu tool can be used, or :ref:`inductor's profiling tools `. See also the `general pytorch profiler guide `_. @@ -66,7 +66,7 @@ Alternatively, turn on *all* flows with the “Flow events” dropdown at the to Working around CUDA Graph profiling issues ------------------------------------------ -When CUDA graphs are enabled, some cuda configurations (driver version under 525.85.12 or CUDA < 12) can encounter issues between the profiling tools and CUDA graphs. To fix these issues, add an empty profiling context at the top of your program: +When CUDA graphs are enabled, some CUDA configurations (driver version under 525.85.12 or CUDA < 12) can encounter issues between the profiling tools and CUDA graphs. To fix these issues, add an empty profiling context at the top of your program: .. code-block:: python From 1b7cd2ded9af860b79f0ef616ff7f2e05754a343 Mon Sep 17 00:00:00 2001 From: Aaron Gokaslan Date: Mon, 16 Sep 2024 19:44:08 +0000 Subject: [PATCH 0946/1018] [BE]: Update mypy to 1.11.2 (#133816) Updates mypy to 1.11.1 to improve type inference Pull Request resolved: https://github.com/pytorch/pytorch/pull/133816 Approved by: https://github.com/ezyang --- .ci/docker/requirements-ci.txt | 2 +- .lintrunner.toml | 2 +- torch/_dynamo/repro/after_aot.py | 2 +- torch/_dynamo/repro/after_dynamo.py | 9 +- torch/_dynamo/resume_execution.py | 14 +-- torch/_dynamo/source.py | 2 +- torch/_dynamo/symbolic_convert.py | 2 +- torch/_dynamo/utils.py | 2 + torch/_export/passes/constant_folding.py | 2 +- torch/_higher_order_ops/flex_attention.py | 2 +- torch/_inductor/codegen/cpp_wrapper_cpu.py | 3 +- .../codegen/cuda_combined_scheduling.py | 2 +- torch/_inductor/dependencies.py | 10 +- torch/_inductor/kernel/mm_common.py | 4 +- torch/_inductor/remote_cache.py | 4 +- torch/_inductor/runtime/triton_heuristics.py | 4 +- torch/_inductor/scheduler.py | 4 +- torch/_library/infer_schema.py | 2 +- torch/_subclasses/functional_tensor.py | 4 +- torch/_tensor.py | 2 +- torch/ao/nn/intrinsic/modules/fused.py | 4 +- .../ao/nn/intrinsic/qat/modules/conv_fused.py | 93 +++++++++---------- .../intrinsic/quantized/modules/conv_add.py | 4 +- torch/ao/nn/qat/modules/conv.py | 18 ++-- torch/ao/nn/quantized/dynamic/modules/conv.py | 25 ++--- torch/ao/nn/quantized/modules/conv.py | 41 ++++---- torch/ao/ns/_numeric_suite.py | 2 +- torch/ao/ns/_numeric_suite_fx.py | 2 +- torch/ao/ns/fx/mappings.py | 6 +- .../data_sparsifier/base_data_sparsifier.py | 6 +- .../data_sparsifier/data_norm_sparsifier.py | 2 +- .../_experimental/pruner/FPGM_pruner.py | 4 +- .../sparsifier/nearly_diagonal_sparsifier.py | 4 +- .../sparsifier/weight_norm_sparsifier.py | 2 +- torch/ao/quantization/__init__.py | 2 +- torch/ao/quantization/_correct_bias.py | 2 +- .../ao/quantization/experimental/observer.py | 2 +- .../quantizer/x86_inductor_quantizer.py | 2 +- torch/distributed/_functional_collectives.py | 4 +- .../distributed/_shard/sharded_tensor/api.py | 2 +- .../algorithms/_comm_hooks/default_hooks.py | 2 +- .../optim/post_localSGD_optimizer.py | 2 +- torch/distributed/pipelining/_IR.py | 5 +- torch/distributed/tensor/_api.py | 2 +- torch/distributed/tensor/_shards_wrapper.py | 2 +- .../tensor/parallel/input_reshard.py | 5 +- torch/distributions/constraints.py | 2 +- torch/distributions/von_mises.py | 2 +- torch/masked/maskedtensor/core.py | 4 +- torch/multiprocessing/reductions.py | 2 +- torch/nested/_internal/nested_tensor.py | 2 +- torch/nn/attention/bias.py | 2 +- torch/optim/lr_scheduler.py | 4 +- torch/sparse/semi_structured.py | 4 +- torch/utils/_sympy/functions.py | 22 ++--- torch/utils/weak.py | 2 +- 56 files changed, 186 insertions(+), 183 deletions(-) diff --git a/.ci/docker/requirements-ci.txt b/.ci/docker/requirements-ci.txt index 88ecc0d1a7edf3..bd48d696c9230b 100644 --- a/.ci/docker/requirements-ci.txt +++ b/.ci/docker/requirements-ci.txt @@ -90,7 +90,7 @@ librosa>=0.6.2 ; python_version < "3.11" #Pinned versions: #test that import: -mypy==1.10.0 +mypy==1.11.2 # Pin MyPy version because new errors are likely to appear with each release #Description: linter #Pinned versions: 1.10.0 diff --git a/.lintrunner.toml b/.lintrunner.toml index 4789991736be17..9b43b382809212 100644 --- a/.lintrunner.toml +++ b/.lintrunner.toml @@ -139,7 +139,7 @@ init_command = [ 'numpy==1.24.3 ; python_version == "3.8"', 'numpy==1.26.0 ; python_version >= "3.9"', 'expecttest==0.2.1', - 'mypy==1.10.0', + 'mypy==1.11.2', 'sympy==1.12.1 ; python_version == "3.8"', 'sympy==1.13.0 ; python_version >= "3.9"', 'types-requests==2.27.25', diff --git a/torch/_dynamo/repro/after_aot.py b/torch/_dynamo/repro/after_aot.py index a21a20cdf2e611..19aa69d084db36 100644 --- a/torch/_dynamo/repro/after_aot.py +++ b/torch/_dynamo/repro/after_aot.py @@ -521,7 +521,7 @@ def repro_common(options, mod, load_args): def repro_minifier_query(options, mod, load_args): mod, args = repro_common(options, mod, load_args) fail_fn = functools.partial( - ACCURACY_FAILS[options.accuracy], check_str=options.check_str + ACCURACY_FAILS[options.accuracy], check_str=options.check_str # type: ignore[call-arg] ) if fail_fn(mod, args): sys.exit(1) diff --git a/torch/_dynamo/repro/after_dynamo.py b/torch/_dynamo/repro/after_dynamo.py index b1bb950d11c11b..214480b02c9332 100644 --- a/torch/_dynamo/repro/after_dynamo.py +++ b/torch/_dynamo/repro/after_dynamo.py @@ -12,6 +12,7 @@ import torch import torch.fx as fx +from torch._dynamo.backends.registry import CompiledFn from torch._dynamo.debug_utils import ( AccuracyError, backend_accuracy_fails, @@ -271,8 +272,10 @@ def dump_to_minify_after_dynamo(gm, args, compiler_name): # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # -@register_debug_backend -def dynamo_minifier_backend(gm, example_inputs, compiler_name): +@register_debug_backend # type: ignore[arg-type] +def dynamo_minifier_backend( + gm: fx.GraphModule, example_inputs, compiler_name: CompiledFn +): from functorch.compile import minifier compiler_fn = lookup_backend(compiler_name) @@ -311,7 +314,7 @@ def dynamo_minifier_backend(gm, example_inputs, compiler_name): return gm -@register_debug_backend +@register_debug_backend # type: ignore[arg-type] def dynamo_accuracy_minifier_backend(gm, example_inputs, compiler_name): from functorch.compile import minifier diff --git a/torch/_dynamo/resume_execution.py b/torch/_dynamo/resume_execution.py index 6f7db66514c487..2013b3992eb54f 100644 --- a/torch/_dynamo/resume_execution.py +++ b/torch/_dynamo/resume_execution.py @@ -471,14 +471,14 @@ def generate( code, lineno, offset: int, - setup_fn_target_offsets: Tuple[int], # only used in Python 3.11+ + setup_fn_target_offsets: Tuple[int, ...], # only used in Python 3.11+ nstack: int, - argnames: Tuple[str], - argnames_null: Tuple[str], - setup_fns: Tuple[ReenterWith], - stack_ctx_vars: Tuple[int, Tuple[Any]], - argnames_ctx_vars: Tuple[str, Tuple[Any]], - null_idxes: Tuple[int], + argnames: Tuple[str, ...], + argnames_null: Tuple[str, ...], + setup_fns: Tuple[ReenterWith, ...], + stack_ctx_vars: Tuple[Tuple[int, Tuple[Any]], ...], + argnames_ctx_vars: Tuple[Tuple[str, Tuple[Any]], ...], + null_idxes: Tuple[int, ...], ) -> types.CodeType: assert offset is not None assert not ( diff --git a/torch/_dynamo/source.py b/torch/_dynamo/source.py index c44f8d9e69abf5..9febc69f42cda2 100644 --- a/torch/_dynamo/source.py +++ b/torch/_dynamo/source.py @@ -272,7 +272,7 @@ def guard_source(self): def name(self): return f"" - def make_guard(self): + def make_guard(self, fn): raise NotImplementedError def is_ephemeral(self): diff --git a/torch/_dynamo/symbolic_convert.py b/torch/_dynamo/symbolic_convert.py index 61acaec995757a..4a83b455a1252b 100644 --- a/torch/_dynamo/symbolic_convert.py +++ b/torch/_dynamo/symbolic_convert.py @@ -3198,7 +3198,7 @@ def fake_mode(self): def run_ctx_mgr(self): return TracingContext.current_frame(self.parent.frame_summary()) - def STORE_DEREF(self, inst): + def STORE_DEREF(self, inst): # type: ignore[override] if inst.argval in self.closure_cells: cell = self.closure_cells[inst.argval] val = self.pop() diff --git a/torch/_dynamo/utils.py b/torch/_dynamo/utils.py index 2d4a771166cfbe..bed3564ec15ce0 100644 --- a/torch/_dynamo/utils.py +++ b/torch/_dynamo/utils.py @@ -2915,6 +2915,8 @@ def is_frozen_dataclass(value): not object_has_getattribute(value) and not class_has_getattribute(value) and is_dataclass(value) + and hasattr(value, "__dataclass_params__") + and hasattr(value.__dataclass_params__, "frozen") and value.__dataclass_params__.frozen ) diff --git a/torch/_export/passes/constant_folding.py b/torch/_export/passes/constant_folding.py index b1491ca5d47946..083cd69a970df6 100644 --- a/torch/_export/passes/constant_folding.py +++ b/torch/_export/passes/constant_folding.py @@ -195,7 +195,7 @@ def insertable_tensor_check(self, tensor: torch.Tensor) -> bool: def add_node_replacement(self, node: torch.fx.Node, tensor: torch.Tensor) -> None: self.node_replacements[node] = tensor - def run(self): + def run(self): # type: ignore[override] env = {} for n in self.module.graph.find_nodes(op="placeholder"): env[n] = self.unknown_value diff --git a/torch/_higher_order_ops/flex_attention.py b/torch/_higher_order_ops/flex_attention.py index b28f657b9d5d8f..ed99ae7beca12a 100644 --- a/torch/_higher_order_ops/flex_attention.py +++ b/torch/_higher_order_ops/flex_attention.py @@ -77,7 +77,7 @@ class TransformGetItemToIndex(TorchFunctionMode): # scalar and create a view. We do not want that behavior in this case, so we # use this torchfunctionmode to override that behavior for score_mod # wherever we're running it. - def __torch_function__(self, func, types, args, kwargs=None): + def __torch_function__(self, func, types, args=(), kwargs=None): if func == torch.Tensor.__getitem__: index_args = pytree.tree_leaves(args[1]) if all(isinstance(x, torch.Tensor) for x in index_args): diff --git a/torch/_inductor/codegen/cpp_wrapper_cpu.py b/torch/_inductor/codegen/cpp_wrapper_cpu.py index c009246243f3cf..4e88a6f91f7e22 100644 --- a/torch/_inductor/codegen/cpp_wrapper_cpu.py +++ b/torch/_inductor/codegen/cpp_wrapper_cpu.py @@ -1480,7 +1480,8 @@ def generate_profiler_mark_wrapper_call(self, stack): 'RECORD_FUNCTION("inductor_wrapper_call", c10::ArrayRef());' ) - def write_triton_header_once(self): + @cache_on_self + def write_triton_header_once(self) -> None: pass def generate_start_graph(self): diff --git a/torch/_inductor/codegen/cuda_combined_scheduling.py b/torch/_inductor/codegen/cuda_combined_scheduling.py index 90971045379656..9273eeb5f9824f 100644 --- a/torch/_inductor/codegen/cuda_combined_scheduling.py +++ b/torch/_inductor/codegen/cuda_combined_scheduling.py @@ -30,7 +30,7 @@ def __init__(self, scheduler: Scheduler) -> None: self._cuda_cpp_scheduling = CUDACPPScheduling(scheduler) self._rocm_cpp_scheduling = ROCmCPPScheduling(scheduler) - def get_backend_features(self, device): + def get_backend_features(self, device): # type:ignore[override] return self._triton_scheduling.get_backend_features(device) def choose_node_backend(self, node: BaseSchedulerNode) -> BaseScheduling: diff --git a/torch/_inductor/dependencies.py b/torch/_inductor/dependencies.py index 95a48281977c87..643bf686bd8f19 100644 --- a/torch/_inductor/dependencies.py +++ b/torch/_inductor/dependencies.py @@ -578,22 +578,22 @@ def extract_read_writes( repl = {v: sympy.Symbol(f"tmp{i}") for i, v in enumerate(fn.indirect_vars)} name_to_index = {k: sympy_subs(v, repl) for k, v in name_to_index.items()} for entry in fn.memory_usage[MemoryUsageType.LOAD]: - inner.load(entry.buffer_name, name_to_index[entry.index_name]) + inner.load(entry.buffer_name, name_to_index[entry.index_name]) # type: ignore[arg-type] for entry in fn.memory_usage[MemoryUsageType.LOAD_SEED]: - inner.load_seed(entry.buffer_name, int(name_to_index[entry.index_name])) + inner.load_seed(entry.buffer_name, int(name_to_index[entry.index_name])) # type: ignore[arg-type] for entry in fn.memory_usage[MemoryUsageType.STORE]: inner.store( - entry.buffer_name, name_to_index[entry.index_name], None, entry.mode + entry.buffer_name, name_to_index[entry.index_name], None, entry.mode # type: ignore[arg-type] ) for entry in fn.memory_usage[MemoryUsageType.STORE_REDUCTION]: inner.store_reduction( - entry.buffer_name, name_to_index[entry.index_name], None + entry.buffer_name, name_to_index[entry.index_name], None # type: ignore[arg-type] ) for entry in fn.memory_usage[MemoryUsageType.INDEX_EXPR]: inner.index_expr(name_to_index[entry.index_name], None) for entry in fn.memory_usage[MemoryUsageType.BUCKETIZE]: inner.bucketize( - None, entry.buffer_name, name_to_index[entry.index_name], None, None + None, entry.buffer_name, name_to_index[entry.index_name], None, None # type: ignore[arg-type] ) # fn.memory_usage[MemoryUsageType.CHECK_BOUNDS] intentionally skipped else: diff --git a/torch/_inductor/kernel/mm_common.py b/torch/_inductor/kernel/mm_common.py index 0c66da4ca4f314..21ba6c1e215dbe 100644 --- a/torch/_inductor/kernel/mm_common.py +++ b/torch/_inductor/kernel/mm_common.py @@ -2,7 +2,7 @@ import functools import itertools import logging -from typing import cast, List, Tuple +from typing import cast, Sequence, Tuple import sympy @@ -28,7 +28,7 @@ def filtered_configs( m: int, n: int, k: int, - configs: List[Tuple[int, int, int, int, int]], + configs: Sequence[Tuple[int, int, int, int, int]], has_int8_tensor=False, ): """Heuristic to shrink configs when they are bigger than the input size""" diff --git a/torch/_inductor/remote_cache.py b/torch/_inductor/remote_cache.py index 14963a1bf5d435..056e24f1b4e26c 100644 --- a/torch/_inductor/remote_cache.py +++ b/torch/_inductor/remote_cache.py @@ -101,8 +101,8 @@ def put(self, key: str, value: _T) -> None: self._put(key, value, sample) self._log_sample(sample) - def _decode(self, data: _U, sample: Optional[Sample]) -> _T: - return self.serde.decode(data) + def _decode(self, data: _U, sample: Optional[Sample]) -> _T: # type: ignore[override] + return self.serde.decode(data) # type: ignore[arg-type] def _encode(self, value: _T, sample: Optional[Sample]) -> Any: # returns _U return self.serde.encode(value) diff --git a/torch/_inductor/runtime/triton_heuristics.py b/torch/_inductor/runtime/triton_heuristics.py index 56169905dbe3fb..abb932266c5a89 100644 --- a/torch/_inductor/runtime/triton_heuristics.py +++ b/torch/_inductor/runtime/triton_heuristics.py @@ -827,7 +827,7 @@ def benchmark_one_config(config): ) return config2launcher.get(best_config) - def run(self, *args, grid, stream, **kwargs): + def run(self, *args, grid, stream, **kwargs): # type:ignore[override] if len(self.launchers) != 1: if len(self.launchers) == 0: start_time = time.time_ns() @@ -958,7 +958,7 @@ def __init__(self, *args, regex_filter="", with_profiler=False, **kwargs): super().__init__(*args, **kwargs) self.cached = None - def run(self, *args, grid, stream): + def run(self, *args, grid, stream): # type: ignore[override] possible_names = _find_names(self) kernel_name = f"{max(possible_names, key=len)}" if not re.match(self.regex_filter, kernel_name): diff --git a/torch/_inductor/scheduler.py b/torch/_inductor/scheduler.py index 5f8d35cbcab5eb..3d2676e0b2e000 100644 --- a/torch/_inductor/scheduler.py +++ b/torch/_inductor/scheduler.py @@ -2857,9 +2857,7 @@ def has_shared_data_after_reordering_loop( return False # Pick the largest buffer to guide the loop reordering - numel, lhs_dep, rhs_dep = sorted(candidates, reverse=True, key=lambda x: x[0])[ - 0 - ] + numel, lhs_dep, rhs_dep = max(candidates, key=lambda x: x[0]) if lhs_dep.num_vars != rhs_dep.num_vars: # this can happen due to we don't merge loops. diff --git a/torch/_library/infer_schema.py b/torch/_library/infer_schema.py index b2eeb24521d382..9845a64874803a 100644 --- a/torch/_library/infer_schema.py +++ b/torch/_library/infer_schema.py @@ -268,4 +268,4 @@ def tuple_to_list(tuple_type: typing.Type[typing.Tuple]) -> typing.Type[typing.L elif len(type_args) == 2 and type_args[1] is Ellipsis: # type: ignore[valid-type] return typing.List[type_args[0]] # type: ignore[valid-type] else: - return typing.List[typing.Union[tuple(type_args)]] # type: ignore[misc] + return typing.List[typing.Union[tuple(type_args)]] # type: ignore[misc, return-value] diff --git a/torch/_subclasses/functional_tensor.py b/torch/_subclasses/functional_tensor.py index bec1ebc559c328..cd5bfa655ea0b5 100644 --- a/torch/_subclasses/functional_tensor.py +++ b/torch/_subclasses/functional_tensor.py @@ -221,7 +221,7 @@ def __torch_dispatch__(self, func, types, args=(), kwargs=None): "Attempting to use FunctionalTensor on its own. Instead, please use it with a corresponding FunctionalTensorMode()" ) - def __repr__(self): + def __repr__(self) -> str: # type: ignore[override] return f"FunctionalTensor({repr(self.elem)})" @staticmethod @@ -302,7 +302,7 @@ def cuda(self, device=None, *args, **kwargs): long = _conversion_method_template(dtype=torch.int64) # TODO(sparse-team): fixes #133174 but can we do without the relay? - def to_dense(self): + def to_dense(self): # type: ignore[override] return self.elem.to_dense() @property diff --git a/torch/_tensor.py b/torch/_tensor.py index 7d8010081abc77..e22c6a92d92d0f 100644 --- a/torch/_tensor.py +++ b/torch/_tensor.py @@ -1353,7 +1353,7 @@ def align_to(self, *names): [name for name in names if not is_ellipsis(name)], ellipsis_idx ) - def unflatten(self, dim, sizes): + def unflatten(self, dim, sizes): # type: ignore[override] r""" unflatten(dim, sizes) -> Tensor diff --git a/torch/ao/nn/intrinsic/modules/fused.py b/torch/ao/nn/intrinsic/modules/fused.py index 010b9a701a3353..c7ae2ce3319d3c 100644 --- a/torch/ao/nn/intrinsic/modules/fused.py +++ b/torch/ao/nn/intrinsic/modules/fused.py @@ -228,7 +228,7 @@ def __init__(self, conv, add): super().__init__(conv) self.add = add - def forward(self, x1, x2): + def forward(self, x1, x2): # type: ignore[override] return self.add(self[0](x1), x2) @@ -241,5 +241,5 @@ def __init__(self, conv, add, relu): self.add = add self.relu = relu - def forward(self, x1, x2): + def forward(self, x1, x2): # type: ignore[override] return self.relu(self.add(self[0](x1), x2)) diff --git a/torch/ao/nn/intrinsic/qat/modules/conv_fused.py b/torch/ao/nn/intrinsic/qat/modules/conv_fused.py index 3c700b30f24720..6ceee89757ca75 100644 --- a/torch/ao/nn/intrinsic/qat/modules/conv_fused.py +++ b/torch/ao/nn/intrinsic/qat/modules/conv_fused.py @@ -1,6 +1,6 @@ # mypy: allow-untyped-defs import math -from typing import TypeVar +from typing import ClassVar, Optional, Type import torch import torch.ao.nn.intrinsic as nni @@ -33,12 +33,9 @@ } -MOD = TypeVar("MOD", bound=nn.modules.conv._ConvNd) - - class _ConvBnNd(nn.modules.conv._ConvNd, nni._FusedModule): _version = 2 - _FLOAT_MODULE = MOD + _FLOAT_MODULE: ClassVar[Type[nn.modules.conv._ConvNd]] def __init__( self, @@ -365,7 +362,7 @@ def from_float(cls, mod, use_precomputed_fake_quant=False): assert hasattr(mod, "qconfig"), "Input float module must have qconfig defined" assert mod.qconfig, "Input float module must have a valid qconfig" qconfig = mod.qconfig - conv, bn = mod[0], mod[1] + conv, bn = mod[0], mod[1] # type: ignore[index] qat_convbn = cls( conv.in_channels, conv.out_channels, @@ -434,7 +431,7 @@ def to_float(self): return conv -class ConvBn1d(_ConvBnNd, nn.Conv1d): +class ConvBn1d(_ConvBnNd, nn.Conv1d): # type: ignore[misc] r""" A ConvBn1d module is a module fused from Conv1d and BatchNorm1d, attached with FakeQuantize modules for weight, @@ -451,10 +448,10 @@ class ConvBn1d(_ConvBnNd, nn.Conv1d): weight_fake_quant: fake quant module for weight """ - _FLOAT_BN_MODULE = nn.BatchNorm1d - _FLOAT_RELU_MODULE: None = None - _FLOAT_MODULE = nni.ConvBn1d - _FLOAT_CONV_MODULE = nn.Conv1d + _FLOAT_BN_MODULE: ClassVar[Type[nn.BatchNorm1d]] = nn.BatchNorm1d + _FLOAT_RELU_MODULE: ClassVar[Optional[Type[nn.Module]]] = None + _FLOAT_MODULE: ClassVar[Type[nn.Module]] = nni.ConvBn1d # type: ignore[assignment,misc] + _FLOAT_CONV_MODULE: ClassVar[Type[nn.Conv1d]] = nn.Conv1d def __init__( self, @@ -520,12 +517,12 @@ class ConvBnReLU1d(ConvBn1d): """ # base class defines _FLOAT_MODULE as "ConvBn1d" - _FLOAT_MODULE = nni.ConvBnReLU1d # type: ignore[assignment] - _FLOAT_CONV_MODULE = nn.Conv1d - _FLOAT_BN_MODULE = nn.BatchNorm1d - _FLOAT_RELU_MODULE = nn.ReLU # type: ignore[assignment] + _FLOAT_MODULE: ClassVar[Type[nn.Module]] = nni.ConvBnReLU1d # type: ignore[assignment,misc] + _FLOAT_CONV_MODULE: ClassVar[Type[nn.Conv1d]] = nn.Conv1d + _FLOAT_BN_MODULE: ClassVar[Type[nn.BatchNorm1d]] = nn.BatchNorm1d + _FLOAT_RELU_MODULE: ClassVar[Optional[Type[nn.Module]]] = nn.ReLU # module class after fusing bn into conv - _FUSED_FLOAT_MODULE = nni.ConvReLU1d + _FUSED_FLOAT_MODULE: ClassVar[Optional[Type[nn.Module]]] = nni.ConvReLU1d def __init__( self, @@ -585,10 +582,10 @@ class ConvReLU1d(nnqat.Conv1d, nni._FusedModule): weight_fake_quant: fake quant module for weight """ - _FLOAT_MODULE = nni.ConvReLU1d # type: ignore[assignment] - _FLOAT_CONV_MODULE = nn.Conv1d - _FLOAT_BN_MODULE: None = None - _FLOAT_RELU_MODULE = nn.ReLU + _FLOAT_MODULE: ClassVar[Type[nni.ConvReLU1d]] = nni.ConvReLU1d # type: ignore[assignment, misc] + _FLOAT_CONV_MODULE: ClassVar[Type[nn.Conv1d]] = nn.Conv1d + _FLOAT_BN_MODULE: ClassVar[Optional[Type[nn.Module]]] = None + _FLOAT_RELU_MODULE: ClassVar[Optional[Type[nn.Module]]] = nn.ReLU def __init__( self, @@ -631,7 +628,7 @@ def from_float(cls, mod, use_precomputed_fake_quant=False): ) -class ConvBn2d(_ConvBnNd, nn.Conv2d): +class ConvBn2d(_ConvBnNd, nn.Conv2d): # type: ignore[misc] r""" A ConvBn2d module is a module fused from Conv2d and BatchNorm2d, attached with FakeQuantize modules for weight, @@ -648,10 +645,10 @@ class ConvBn2d(_ConvBnNd, nn.Conv2d): weight_fake_quant: fake quant module for weight """ - _FLOAT_MODULE = nni.ConvBn2d - _FLOAT_CONV_MODULE = nn.Conv2d - _FLOAT_BN_MODULE = nn.BatchNorm2d - _FLOAT_RELU_MODULE: None = None + _FLOAT_MODULE: ClassVar[Type[nni.ConvBn2d]] = nni.ConvBn2d # type: ignore[assignment,misc] + _FLOAT_CONV_MODULE: ClassVar[Type[nn.Conv2d]] = nn.Conv2d + _FLOAT_BN_MODULE: ClassVar[Optional[Type[nn.Module]]] = nn.BatchNorm2d + _FLOAT_RELU_MODULE: ClassVar[Optional[Type[nn.Module]]] = None def __init__( self, @@ -717,12 +714,12 @@ class ConvBnReLU2d(ConvBn2d): """ # base class defines _FLOAT_MODULE as "ConvBn2d" - _FLOAT_MODULE = nni.ConvBnReLU2d # type: ignore[assignment] - _FLOAT_CONV_MODULE = nn.Conv2d - _FLOAT_BN_MODULE = nn.BatchNorm2d - _FLOAT_RELU_MODULE = nn.ReLU # type: ignore[assignment] + _FLOAT_MODULE: ClassVar[Type[nni.ConvBnReLU2d]] = nni.ConvBnReLU2d # type: ignore[assignment, misc] + _FLOAT_CONV_MODULE: ClassVar[Type[nn.Conv2d]] = nn.Conv2d + _FLOAT_BN_MODULE: ClassVar[Type[nn.BatchNorm2d]] = nn.BatchNorm2d + _FLOAT_RELU_MODULE: ClassVar[Optional[Type[nn.Module]]] = nn.ReLU # type: ignore[assignment,misc] # module class after fusing bn into conv - _FUSED_FLOAT_MODULE = nni.ConvReLU2d + _FUSED_FLOAT_MODULE: ClassVar[Optional[Type[nni.ConvReLU2d]]] = nni.ConvReLU2d def __init__( self, @@ -782,10 +779,10 @@ class ConvReLU2d(nnqat.Conv2d, nni._FusedModule): weight_fake_quant: fake quant module for weight """ - _FLOAT_MODULE = nni.ConvReLU2d # type: ignore[assignment] - _FLOAT_CONV_MODULE = nn.Conv2d - _FLOAT_BN_MODULE: None = None - _FLOAT_RELU_MODULE = nn.ReLU + _FLOAT_MODULE: ClassVar[Type[nn.Module]] = nni.ConvReLU2d # type: ignore[assignment, misc] + _FLOAT_CONV_MODULE: ClassVar[Type[nn.Conv2d]] = nn.Conv2d + _FLOAT_BN_MODULE: ClassVar[Optional[Type[nn.Module]]] = None + _FLOAT_RELU_MODULE: ClassVar[Optional[Type[nn.Module]]] = nn.ReLU def __init__( self, @@ -828,7 +825,7 @@ def from_float(cls, mod, use_precomputed_fake_quant=False): ) -class ConvBn3d(_ConvBnNd, nn.Conv3d): +class ConvBn3d(_ConvBnNd, nn.Conv3d): # type: ignore[misc] r""" A ConvBn3d module is a module fused from Conv3d and BatchNorm3d, attached with FakeQuantize modules for weight, @@ -845,10 +842,10 @@ class ConvBn3d(_ConvBnNd, nn.Conv3d): weight_fake_quant: fake quant module for weight """ - _FLOAT_MODULE = nni.ConvBn3d - _FLOAT_CONV_MODULE = nn.Conv3d - _FLOAT_BN_MODULE = nn.BatchNorm3d - _FLOAT_RELU_MODULE: None = None + _FLOAT_MODULE: ClassVar[Type[nni.ConvBn3d]] = nni.ConvBn3d # type: ignore[assignment,misc] + _FLOAT_CONV_MODULE: ClassVar[Type[nn.Conv3d]] = nn.Conv3d + _FLOAT_BN_MODULE: ClassVar[Optional[Type[nn.Module]]] = nn.BatchNorm3d + _FLOAT_RELU_MODULE: ClassVar[Optional[Type[nn.Module]]] = None def __init__( self, @@ -913,12 +910,12 @@ class ConvBnReLU3d(ConvBn3d): weight_fake_quant: fake quant module for weight """ - _FLOAT_MODULE = nni.ConvBnReLU3d # type: ignore[assignment] - _FLOAT_CONV_MODULE = nn.Conv3d - _FLOAT_BN_MODULE = nn.BatchNorm3d - _FLOAT_RELU_MODULE = nn.ReLU # type: ignore[assignment] + _FLOAT_MODULE: ClassVar[Type[nni.ConvBnReLU3d]] = nni.ConvBnReLU3d # type: ignore[assignment, misc] + _FLOAT_CONV_MODULE: ClassVar[Type[nn.Conv3d]] = nn.Conv3d + _FLOAT_BN_MODULE: ClassVar[Type[nn.BatchNorm3d]] = nn.BatchNorm3d + _FLOAT_RELU_MODULE: ClassVar[Optional[Type[nn.ReLU]]] = nn.ReLU # type: ignore[assignment, misc] # module class after fusing bn into conv - _FUSED_FLOAT_MODULE = nni.ConvReLU3d + _FUSED_FLOAT_MODULE: ClassVar[Optional[Type[nni.ConvReLU3d]]] = nni.ConvReLU3d def __init__( self, @@ -980,10 +977,10 @@ class ConvReLU3d(nnqat.Conv3d, nni._FusedModule): weight_fake_quant: fake quant module for weight """ - _FLOAT_MODULE = nni.ConvReLU3d # type: ignore[assignment] - _FLOAT_CONV_MODULE = nn.Conv3d - _FLOAT_BN_MODULE: None = None - _FLOAT_RELU_MODULE = nn.ReLU + _FLOAT_MODULE: ClassVar[Type[nni.ConvReLU3d]] = nni.ConvReLU3d # type: ignore[assignment,misc] + _FLOAT_CONV_MODULE: ClassVar[Type[nn.Conv3d]] = nn.Conv3d + _FLOAT_BN_MODULE: ClassVar[Optional[Type[nn.Module]]] = None + _FLOAT_RELU_MODULE: ClassVar[Optional[Type[nn.Module]]] = nn.ReLU def __init__( self, diff --git a/torch/ao/nn/intrinsic/quantized/modules/conv_add.py b/torch/ao/nn/intrinsic/quantized/modules/conv_add.py index 66299394091b7c..0d1b7e01f4479f 100644 --- a/torch/ao/nn/intrinsic/quantized/modules/conv_add.py +++ b/torch/ao/nn/intrinsic/quantized/modules/conv_add.py @@ -49,7 +49,7 @@ def __init__( dtype=dtype, ) - def forward(self, input, extra_input): + def forward(self, input, extra_input): # type: ignore[override] # Temporarily using len(shape) instead of ndim due to JIT issue # https://github.com/pytorch/pytorch/issues/23890 if len(input.shape) != 4: @@ -117,7 +117,7 @@ def __init__( dtype=dtype, ) - def forward(self, input, extra_input): + def forward(self, input, extra_input): # type: ignore[override] # Temporarily using len(shape) instead of ndim due to JIT issue # https://github.com/pytorch/pytorch/issues/23890 if len(input.shape) != 4: diff --git a/torch/ao/nn/qat/modules/conv.py b/torch/ao/nn/qat/modules/conv.py index 3be5a7cec8162e..49b559136bbc6e 100644 --- a/torch/ao/nn/qat/modules/conv.py +++ b/torch/ao/nn/qat/modules/conv.py @@ -1,5 +1,5 @@ # mypy: allow-untyped-defs -from typing import Tuple, TypeVar, Union +from typing import ClassVar, Tuple, Type, Union import torch import torch.nn as nn @@ -10,11 +10,9 @@ __all__ = ["Conv1d", "Conv2d", "Conv3d"] -MOD = TypeVar("MOD", bound=nn.modules.conv._ConvNd) - class _ConvNd(nn.modules.conv._ConvNd): - _FLOAT_MODULE = MOD + _FLOAT_MODULE: ClassVar[Type[nn.modules.conv._ConvNd]] def __init__( self, @@ -136,8 +134,8 @@ class Conv1d(_ConvNd, nn.Conv1d): Attributes: weight_fake_quant: fake quant module for weight """ - _FLOAT_MODULE = nn.Conv1d - _FLOAT_CONV_MODULE = nn.Conv1d + _FLOAT_MODULE: ClassVar[Type[nn.Conv1d]] = nn.Conv1d # type: ignore[assignment,misc] + _FLOAT_CONV_MODULE: ClassVar[Type[nn.Conv1d]] = nn.Conv1d def __init__( self, @@ -197,8 +195,8 @@ class Conv2d(_ConvNd, nn.Conv2d): Attributes: weight_fake_quant: fake quant module for weight """ - _FLOAT_MODULE = nn.Conv2d - _FLOAT_CONV_MODULE = nn.Conv2d + _FLOAT_MODULE: ClassVar[Type[nn.Conv2d]] = nn.Conv2d # type: ignore[assignment,misc] + _FLOAT_CONV_MODULE: ClassVar[Type[nn.Conv2d]] = nn.Conv2d def __init__( self, @@ -261,8 +259,8 @@ class Conv3d(_ConvNd, nn.Conv3d): Attributes: weight_fake_quant: fake quant module for weight """ - _FLOAT_MODULE = nn.Conv3d - _FLOAT_CONV_MODULE = nn.Conv3d + _FLOAT_MODULE: ClassVar[Type[nn.Conv3d]] = nn.Conv3d # type: ignore[assignment,misc] + _FLOAT_CONV_MODULE: ClassVar[Type[nn.Conv3d]] = nn.Conv3d def __init__( self, diff --git a/torch/ao/nn/quantized/dynamic/modules/conv.py b/torch/ao/nn/quantized/dynamic/modules/conv.py index fa86dcdd93e17d..f379e487f65de8 100644 --- a/torch/ao/nn/quantized/dynamic/modules/conv.py +++ b/torch/ao/nn/quantized/dynamic/modules/conv.py @@ -2,6 +2,7 @@ r"""Dynamically quantized convolution modules.""" import warnings +from typing import ClassVar, Optional, Type import torch import torch.ao.nn.quantized as nnq @@ -47,9 +48,9 @@ class Conv1d(nnq.Conv1d): """ - _FLOAT_MODULE = nn.Conv1d - _NNIQAT_CONV_BN_MODULE = None # type: ignore[assignment] - _NNI_CONV_RELU_MODULE = None # type: ignore[assignment] + _FLOAT_MODULE: ClassVar[Type[nn.Conv1d]] = nn.Conv1d + _NNIQAT_CONV_BN_MODULE: ClassVar[Optional[Type[nn.Module]]] = None + _NNI_CONV_RELU_MODULE: ClassVar[Optional[Type[nn.Module]]] = None def __init__( self, @@ -132,9 +133,9 @@ class Conv2d(nnq.Conv2d): >>> output = m(input) """ - _FLOAT_MODULE = nn.Conv2d - _NNIQAT_CONV_BN_MODULE = None # type: ignore[assignment] - _NNI_CONV_RELU_MODULE = None # type: ignore[assignment] + _FLOAT_MODULE: ClassVar[Type[nn.Conv2d]] = nn.Conv2d + _NNIQAT_CONV_BN_MODULE: ClassVar[Optional[Type[nn.Module]]] = None + _NNI_CONV_RELU_MODULE: ClassVar[Optional[Type[nn.Module]]] = None def __init__( self, @@ -216,9 +217,9 @@ class Conv3d(nnq.Conv3d): >>> output = m(input) """ - _FLOAT_MODULE = nn.Conv3d - _NNIQAT_CONV_BN_MODULE = None # type: ignore[assignment] - _NNI_CONV_RELU_MODULE = None # type: ignore[assignment] + _FLOAT_MODULE: ClassVar[Type[nn.Conv3d]] = nn.Conv3d + _NNIQAT_CONV_BN_MODULE: ClassVar[Optional[Type[nn.Module]]] = None + _NNI_CONV_RELU_MODULE: ClassVar[Optional[Type[nn.Module]]] = None def __init__( self, @@ -308,7 +309,7 @@ class ConvTranspose1d(nnq.ConvTranspose1d): torch.Size([1, 16, 12]) """ - _FLOAT_MODULE = nn.ConvTranspose1d + _FLOAT_MODULE: ClassVar[Type[nn.ConvTranspose1d]] = nn.ConvTranspose1d def __init__( self, @@ -390,7 +391,7 @@ class ConvTranspose2d(nnq.ConvTranspose2d): torch.Size([1, 16, 12, 12]) """ - _FLOAT_MODULE = nn.ConvTranspose2d + _FLOAT_MODULE: ClassVar[Type[nn.ConvTranspose2d]] = nn.ConvTranspose2d def __init__( self, @@ -472,7 +473,7 @@ class ConvTranspose3d(nnq.ConvTranspose3d): torch.Size([1, 16, 12, 12, 12]) """ - _FLOAT_MODULE = nn.ConvTranspose3d + _FLOAT_MODULE: ClassVar[Type[nn.ConvTranspose3d]] = nn.ConvTranspose3d def __init__( self, diff --git a/torch/ao/nn/quantized/modules/conv.py b/torch/ao/nn/quantized/modules/conv.py index 48ca18028acf94..0ef2f3af1dcd44 100644 --- a/torch/ao/nn/quantized/modules/conv.py +++ b/torch/ao/nn/quantized/modules/conv.py @@ -1,7 +1,7 @@ # mypy: allow-untyped-defs r"""Quantized convolution modules.""" -from typing import List, Optional, TypeVar +from typing import ClassVar, List, Optional, Type import torch import torch.ao.nn.intrinsic as nni @@ -386,11 +386,11 @@ class Conv1d(_ConvNd): """ - _FLOAT_MODULE = nn.Conv1d - _NNIQAT_CONV_BN_MODULE = nniqat.ConvBn1d - _NNI_CONV_RELU_MODULE = nni.ConvReLU1d - _NNI_CONV_ADD_MODULE: None = None - _NNI_CONV_ADD_RELU_MODULE: None = None + _FLOAT_MODULE: ClassVar[Type[nn.Conv1d]] = nn.Conv1d + _NNIQAT_CONV_BN_MODULE: ClassVar[Optional[Type[nn.Module]]] = nniqat.ConvBn1d + _NNI_CONV_RELU_MODULE: ClassVar[Optional[Type[nn.Module]]] = nni.ConvReLU1d + _NNI_CONV_ADD_MODULE: ClassVar[Optional[Type[nn.Module]]] = None + _NNI_CONV_ADD_RELU_MODULE: ClassVar[Optional[Type[nn.Module]]] = None def __init__( self, @@ -518,11 +518,11 @@ class Conv2d(_ConvNd): >>> output = m(q_input) """ - _FLOAT_MODULE = nn.Conv2d - _NNIQAT_CONV_BN_MODULE = nniqat.ConvBn2d - _NNI_CONV_RELU_MODULE = nni.ConvReLU2d - _NNI_CONV_ADD_MODULE = nni.ConvAdd2d - _NNI_CONV_ADD_RELU_MODULE = nni.ConvAddReLU2d + _FLOAT_MODULE: ClassVar[Type[nn.Conv2d]] = nn.Conv2d + _NNIQAT_CONV_BN_MODULE: ClassVar[Optional[Type[nn.Module]]] = nniqat.ConvBn2d + _NNI_CONV_RELU_MODULE: ClassVar[Optional[Type[nn.Module]]] = nni.ConvReLU2d + _NNI_CONV_ADD_MODULE: ClassVar[Type[nni.ConvAdd2d]] = nni.ConvAdd2d + _NNI_CONV_ADD_RELU_MODULE: ClassVar[Type[nni.ConvAddReLU2d]] = nni.ConvAddReLU2d def __init__( self, @@ -647,11 +647,11 @@ class Conv3d(_ConvNd): >>> output = m(q_input) """ - _FLOAT_MODULE = nn.Conv3d - _NNIQAT_CONV_BN_MODULE = nniqat.ConvBn3d - _NNI_CONV_RELU_MODULE = nni.ConvReLU3d - _NNI_CONV_ADD_MODULE: None = None - _NNI_CONV_ADD_RELU_MODULE: None = None + _FLOAT_MODULE: ClassVar[Type[nn.Conv3d]] = nn.Conv3d + _NNIQAT_CONV_BN_MODULE: ClassVar[Optional[Type[nn.Module]]] = nniqat.ConvBn3d + _NNI_CONV_RELU_MODULE: ClassVar[Optional[Type[nn.Module]]] = nni.ConvReLU3d + _NNI_CONV_ADD_MODULE: ClassVar[Optional[Type[nn.Module]]] = None + _NNI_CONV_ADD_RELU_MODULE: ClassVar[Optional[Type[nn.Module]]] = None def __init__( self, @@ -740,11 +740,10 @@ def from_float(cls, mod, use_precomputed_fake_quant=False): # === Transposed Convolutions === -MOD = TypeVar("MOD", bound=nn.modules.conv._ConvNd) class _ConvTransposeNd(_ConvNd): - _FLOAT_MODULE = MOD + _FLOAT_MODULE: ClassVar[Type[nn.modules.conv._ConvNd]] def __init__( self, @@ -914,7 +913,7 @@ class ConvTranspose1d(_ConvTransposeNd): torch.Size([1, 16, 12]) """ - _FLOAT_MODULE = nn.ConvTranspose1d + _FLOAT_MODULE: ClassVar[Type[nn.ConvTranspose1d]] = nn.ConvTranspose1d def __init__( self, @@ -1037,7 +1036,7 @@ class ConvTranspose2d(_ConvTransposeNd): torch.Size([1, 16, 12, 12]) """ - _FLOAT_MODULE = nn.ConvTranspose2d + _FLOAT_MODULE: ClassVar[Type[nn.ConvTranspose2d]] = nn.ConvTranspose2d def __init__( self, @@ -1162,7 +1161,7 @@ class ConvTranspose3d(_ConvTransposeNd): torch.Size([1, 16, 12, 12, 12]) """ - _FLOAT_MODULE = nn.ConvTranspose3d + _FLOAT_MODULE: ClassVar[Type[nn.ConvTranspose3d]] = nn.ConvTranspose3d def __init__( self, diff --git a/torch/ao/ns/_numeric_suite.py b/torch/ao/ns/_numeric_suite.py index 9a80b4af6f4794..9e4cf94303c9f1 100644 --- a/torch/ao/ns/_numeric_suite.py +++ b/torch/ao/ns/_numeric_suite.py @@ -198,7 +198,7 @@ def __init__(self): self.stats["float"] = [] self.stats["quantized"] = [] - def forward(self, x, y): + def forward(self, x, y): # type: ignore[override] # fmt: off """ """ # blank docblock to make autodoc happy diff --git a/torch/ao/ns/_numeric_suite_fx.py b/torch/ao/ns/_numeric_suite_fx.py index 3cea146b8031fd..da79b03ce88089 100644 --- a/torch/ao/ns/_numeric_suite_fx.py +++ b/torch/ao/ns/_numeric_suite_fx.py @@ -261,7 +261,7 @@ def __init__(self, *args, **kwargs): self.comparisons = [] # precalculated comparisons function - def forward(self, x, x_ref): + def forward(self, x, x_ref): # type: ignore[override] # fmt: off """ """ # blank docblock to make autodoc happy diff --git a/torch/ao/ns/fx/mappings.py b/torch/ao/ns/fx/mappings.py index 875af8c875b0a4..56784dd1780308 100644 --- a/torch/ao/ns/fx/mappings.py +++ b/torch/ao/ns/fx/mappings.py @@ -410,7 +410,7 @@ def get_base_name_to_sets_of_related_ops() -> Dict[str, Set[NSNodeTargetType]]: # Add function swaps from default lowering path # - for source, ( + for source, ( # type:ignore[assignment] target1, target2, ) in _lower_to_native_backend.STATIC_LOWER_FUNCTIONAL_MAP.items(): @@ -422,7 +422,7 @@ def get_base_name_to_sets_of_related_ops() -> Dict[str, Set[NSNodeTargetType]]: _lower_to_native_backend.QBIN_RELU_OP_MAPPING, quantization_mappings.DEFAULT_FLOAT_TO_QUANTIZED_OPERATOR_MAPPINGS, ): - for source, target in source_to_target.items(): + for source, target in source_to_target.items(): # type:ignore[assignment] new_connections.append((source, target)) # @@ -432,7 +432,7 @@ def get_base_name_to_sets_of_related_ops() -> Dict[str, Set[NSNodeTargetType]]: for source_to_target in ( quantization_mappings.DEFAULT_DYNAMIC_QUANT_MODULE_MAPPINGS, ): - for source, target in source_to_target.items(): + for source, target in source_to_target.items(): # type:ignore[assignment] new_connections.append((source, target)) # add the new connections from backend_config diff --git a/torch/ao/pruning/_experimental/data_sparsifier/base_data_sparsifier.py b/torch/ao/pruning/_experimental/data_sparsifier/base_data_sparsifier.py index b81ee658e55318..75d86737d832d5 100644 --- a/torch/ao/pruning/_experimental/data_sparsifier/base_data_sparsifier.py +++ b/torch/ao/pruning/_experimental/data_sparsifier/base_data_sparsifier.py @@ -71,7 +71,7 @@ def __init__(self, data_list: Optional[List[Tuple[str, Any]]] = None, **defaults # add data with default config here [self.add_data(name, data, **self.defaults) for name, data in data_list] - def prepare(self): + def prepare(self, model, config): raise NotImplementedError("this function is undefined for this class") def _extract_weight(self, data): @@ -266,7 +266,7 @@ def __getstate__(self): "_container": self._container.state_dict(), } - def __repr__(self): + def __repr__(self): # type:ignore[override] format_string = self.__class__.__name__ + " (" for name, sparse_args in self.data_groups.items(): format_string += "\n" @@ -299,7 +299,7 @@ def squash_mask(self, *args, leave_parametrized=True, names=None, **kwargs): self._container, name, leave_parametrized=leave_parametrized ) - def step(self): + def step(self): # type:ignore[override] if not self.enable_mask_update: return with torch.no_grad(): diff --git a/torch/ao/pruning/_experimental/data_sparsifier/data_norm_sparsifier.py b/torch/ao/pruning/_experimental/data_sparsifier/data_norm_sparsifier.py index 34444fa79df0fe..ec867205f64ab1 100644 --- a/torch/ao/pruning/_experimental/data_sparsifier/data_norm_sparsifier.py +++ b/torch/ao/pruning/_experimental/data_sparsifier/data_norm_sparsifier.py @@ -156,7 +156,7 @@ def __get_data_level_mask(self, data, sparsity_level, sparse_block_shape): ] # squeeze only the first 2 dimension return mask - def update_mask( + def update_mask( # type: ignore[override] self, name, data, sparsity_level, sparse_block_shape, zeros_per_block, **kwargs ): values_per_block = reduce(operator.mul, sparse_block_shape) diff --git a/torch/ao/pruning/_experimental/pruner/FPGM_pruner.py b/torch/ao/pruning/_experimental/pruner/FPGM_pruner.py index 54177272ce4f2a..3da27ba38df55b 100644 --- a/torch/ao/pruning/_experimental/pruner/FPGM_pruner.py +++ b/torch/ao/pruning/_experimental/pruner/FPGM_pruner.py @@ -75,7 +75,9 @@ def _compute_distance(self, t): return distance - def update_mask(self, module, tensor_name, sparsity_level, **kwargs): + def update_mask( # type: ignore[override] + self, module, tensor_name, sparsity_level, **kwargs + ): tensor_weight = getattr(module, tensor_name) mask = getattr(module.parametrizations, tensor_name)[0].mask diff --git a/torch/ao/pruning/sparsifier/nearly_diagonal_sparsifier.py b/torch/ao/pruning/sparsifier/nearly_diagonal_sparsifier.py index 47c567ec6d93eb..a4d42ea803289c 100644 --- a/torch/ao/pruning/sparsifier/nearly_diagonal_sparsifier.py +++ b/torch/ao/pruning/sparsifier/nearly_diagonal_sparsifier.py @@ -33,7 +33,9 @@ def __init__(self, nearliness: int = 1): defaults = {"nearliness": nearliness} super().__init__(defaults=defaults) - def update_mask(self, module, tensor_name, nearliness, **kwargs): + def update_mask( # type:ignore[override] + self, module, tensor_name, nearliness, **kwargs + ): mask = getattr(module.parametrizations, tensor_name)[0].mask mask.data = torch.zeros_like(mask) if nearliness <= 0: diff --git a/torch/ao/pruning/sparsifier/weight_norm_sparsifier.py b/torch/ao/pruning/sparsifier/weight_norm_sparsifier.py index 3c2eef0dd15c79..a25b7ffbca61cb 100644 --- a/torch/ao/pruning/sparsifier/weight_norm_sparsifier.py +++ b/torch/ao/pruning/sparsifier/weight_norm_sparsifier.py @@ -208,7 +208,7 @@ def _make_block_mask(self, data, sparse_block_shape, zeros_per_block, mask=None) mask.data = mask_reshape.squeeze().reshape(mask.shape).contiguous() return mask - def update_mask( + def update_mask( # type: ignore[call-override, override] self, module, tensor_name, diff --git a/torch/ao/quantization/__init__.py b/torch/ao/quantization/__init__.py index 840d2715c7717f..503755d383bb93 100644 --- a/torch/ao/quantization/__init__.py +++ b/torch/ao/quantization/__init__.py @@ -216,5 +216,5 @@ def __init__( def forward(self, x: Tensor) -> Tensor: return x - def calculate_qparams(self): + def calculate_qparams(self): # type:ignore[override] return self.derive_qparams_fn(self.obs_or_fqs) diff --git a/torch/ao/quantization/_correct_bias.py b/torch/ao/quantization/_correct_bias.py index 4be37467de1818..03d259a0ad26ca 100644 --- a/torch/ao/quantization/_correct_bias.py +++ b/torch/ao/quantization/_correct_bias.py @@ -61,7 +61,7 @@ def __init__(self): self.float_sum = None self.quant_sum = None - def forward(self, x, y): + def forward(self, x, y): # type: ignore[override] """Compute the average of quantized and floating-point data from modules. The inputs x,y are output data from the quantized and floating-point modules. diff --git a/torch/ao/quantization/experimental/observer.py b/torch/ao/quantization/experimental/observer.py index 631ec11b1669d1..1cf27c97054c1e 100644 --- a/torch/ao/quantization/experimental/observer.py +++ b/torch/ao/quantization/experimental/observer.py @@ -34,7 +34,7 @@ def __init__(self, b, k, dtype=torch.quint8) -> None: # min_val and max_val are optional args to override # the min_val and max_val observed by forward - def calculate_qparams(self, signed): + def calculate_qparams(self, signed): # type:ignore[override] return self._calculate_qparams(signed, self.min_val, self.max_val) r""" Calculates nonuniform quantization parameters according to APoT paper: diff --git a/torch/ao/quantization/quantizer/x86_inductor_quantizer.py b/torch/ao/quantization/quantizer/x86_inductor_quantizer.py index 574af30a7159b9..6042fd2ee5adb9 100644 --- a/torch/ao/quantization/quantizer/x86_inductor_quantizer.py +++ b/torch/ao/quantization/quantizer/x86_inductor_quantizer.py @@ -225,7 +225,7 @@ def _map_module_function_to_aten_operator_type(): ), ) for map_item in map_list: - module_function_to_aten_operator.update(dict.fromkeys(map_item[0], map_item[1])) # type: ignore[call-overload] + module_function_to_aten_operator.update(dict.fromkeys(map_item[0], map_item[1])) # type: ignore[arg-type, call-overload] return module_function_to_aten_operator diff --git a/torch/distributed/_functional_collectives.py b/torch/distributed/_functional_collectives.py index 77b962bf3df47c..4127885dccc1f6 100644 --- a/torch/distributed/_functional_collectives.py +++ b/torch/distributed/_functional_collectives.py @@ -600,7 +600,7 @@ def __tensor_unflatten__(inner_tensors, meta, outer_size, outer_stride): elem = inner_tensors["elem"] return AsyncCollectiveTensor(elem) - def __repr__(self): + def __repr__(self) -> str: # type: ignore[override] return f"AsyncCollectiveTensor({self.trigger_wait()})" def trigger_wait(self): @@ -653,7 +653,7 @@ def wrap(e: torch.Tensor): return out - def numpy(self): + def numpy(self): # type: ignore[override] return self.wait().numpy() diff --git a/torch/distributed/_shard/sharded_tensor/api.py b/torch/distributed/_shard/sharded_tensor/api.py index 68df582cd5145e..d50160ca8ecc31 100644 --- a/torch/distributed/_shard/sharded_tensor/api.py +++ b/torch/distributed/_shard/sharded_tensor/api.py @@ -1200,7 +1200,7 @@ def remote_shards(self) -> Dict[int, List[rpc.RRef[Shard]]]: def __hash__(self): return id(self) - def __repr__(self): + def __repr__(self) -> str: # type: ignore[override] return f"ShardedTensor({self._metadata})" @dataclass diff --git a/torch/distributed/algorithms/_comm_hooks/default_hooks.py b/torch/distributed/algorithms/_comm_hooks/default_hooks.py index 0acafd6868d3bb..872ad0e2a76731 100644 --- a/torch/distributed/algorithms/_comm_hooks/default_hooks.py +++ b/torch/distributed/algorithms/_comm_hooks/default_hooks.py @@ -136,7 +136,7 @@ def _low_precision_hook( prec: torch.dtype, state: LowPrecisionState, grad: torch.Tensor, - output: torch.Tensor, + output: Optional[torch.Tensor], ): if grad.dtype != prec: grad.data = grad.data.to(prec) diff --git a/torch/distributed/optim/post_localSGD_optimizer.py b/torch/distributed/optim/post_localSGD_optimizer.py index db65856e32adad..3c0027d1124073 100644 --- a/torch/distributed/optim/post_localSGD_optimizer.py +++ b/torch/distributed/optim/post_localSGD_optimizer.py @@ -96,7 +96,7 @@ def load_state_dict(self, state_dict): ) self.averager.step = 0 - def step(self): + def step(self): # type: ignore[override] r""" Performs a single optimization step (parameter update). """ diff --git a/torch/distributed/pipelining/_IR.py b/torch/distributed/pipelining/_IR.py index 036ad6cf8621eb..3010bccd377c94 100644 --- a/torch/distributed/pipelining/_IR.py +++ b/torch/distributed/pipelining/_IR.py @@ -364,7 +364,7 @@ def __init__(self, module, garbage_collect_values=True): super().__init__(module, garbage_collect_values) self.value_remap = {} - def run(self, *args, initial_env=None): + def run(self, *args, initial_env=None): # type: ignore[override] self.value_remap = {} return super().run(*args, initial_env=initial_env) @@ -932,8 +932,7 @@ def move_param_to_callee( if node.op == "get_attr": # get_attr might get access deeper level attribute fqn = scope + "." + node.target if scope else node.target - if fqn in unused_attributes: # used, remove it - unused_attributes.remove(fqn) + unused_attributes.discard(fqn) for _name, _submod in _mod.named_children(): stack.append((scope + "." + _name if scope else _name, _submod)) # delete unused attributes diff --git a/torch/distributed/tensor/_api.py b/torch/distributed/tensor/_api.py index 583aa20c1dc188..8684d7b0cafa0f 100644 --- a/torch/distributed/tensor/_api.py +++ b/torch/distributed/tensor/_api.py @@ -283,7 +283,7 @@ def __new__( # pyre-fixme[14]: `__repr__` overrides method defined in `DTensor` inconsistently. # pyre-fixme[3]: Return type must be annotated. - def __repr__(self): + def __repr__(self): # type: ignore[override] # TODO: consider all_gather the local tensors for better debugging return f"DTensor(local_tensor={self._local_tensor}, device_mesh={self._spec.mesh}, placements={self._spec.placements})" diff --git a/torch/distributed/tensor/_shards_wrapper.py b/torch/distributed/tensor/_shards_wrapper.py index de396473b77c91..df8c7d09e38a4a 100644 --- a/torch/distributed/tensor/_shards_wrapper.py +++ b/torch/distributed/tensor/_shards_wrapper.py @@ -309,7 +309,7 @@ def __hash__(self): # pyre-fixme[14]: `__repr__` overrides method defined in `torch._tensor.Tensor` inconsistently. # pyre-fixme[3]: Return type must be annotated. - def __repr__(self): + def __repr__(self) -> str: # type: ignore[override] return f"LocalShardsWrapper:{self._local_shards} {self._storage_meta}" def __str__(self) -> str: diff --git a/torch/distributed/tensor/parallel/input_reshard.py b/torch/distributed/tensor/parallel/input_reshard.py index 630c287cae88f5..ab113246b7071c 100644 --- a/torch/distributed/tensor/parallel/input_reshard.py +++ b/torch/distributed/tensor/parallel/input_reshard.py @@ -39,6 +39,9 @@ def input_reshard( Return: A :class:`nn.Module` object registered with TP input resharding. """ + if input_reshard_dim is None: + return module + cx: Optional[torch.autograd.graph.saved_tensors_hooks] = None def input_reshard_forward_pre_hook(_: torch.nn.Module, _i: Tuple[Any, ...]) -> None: @@ -56,8 +59,6 @@ def input_reshard_backward_hook( nonlocal cx cx.__exit__() # type: ignore[name-defined, union-attr] - if input_reshard_dim is None: - return module module.register_forward_pre_hook(input_reshard_forward_pre_hook) module.register_forward_hook(input_reshard_backward_hook) return module diff --git a/torch/distributions/constraints.py b/torch/distributions/constraints.py index 3c510bd32abc62..e6f730b123afd1 100644 --- a/torch/distributions/constraints.py +++ b/torch/distributions/constraints.py @@ -204,7 +204,7 @@ def __init__( self._is_discrete = is_discrete self._event_dim = event_dim - def __call__(self, fn): + def __call__(self, fn): # type: ignore[override] """ Support for syntax to customize static attributes:: diff --git a/torch/distributions/von_mises.py b/torch/distributions/von_mises.py index bd8fa87f2619a4..a4d403383d9cb0 100644 --- a/torch/distributions/von_mises.py +++ b/torch/distributions/von_mises.py @@ -177,7 +177,7 @@ def sample(self, sample_shape=torch.Size()): self._loc, self._concentration, self._proposal_r, x ).to(self.loc.dtype) - def expand(self, batch_shape): + def expand(self, batch_shape, _instance=None): try: return super().expand(batch_shape) except NotImplementedError: diff --git a/torch/masked/maskedtensor/core.py b/torch/masked/maskedtensor/core.py index 69850df81c65a1..366cf45eb2d50f 100644 --- a/torch/masked/maskedtensor/core.py +++ b/torch/masked/maskedtensor/core.py @@ -258,7 +258,7 @@ def _set_data_mask(self, data, mask): self._masked_mask = mask self._validate_members() - def __repr__(self): + def __repr__(self): # type: ignore[override] formatter = "{0:8.4f}" if self.dim() == 0: scalar_data = self.get_data().item() @@ -350,7 +350,7 @@ def get_mask(self): def is_sparse_coo(self): return self.layout == torch.sparse_coo - def is_sparse_csr(self): + def is_sparse_csr(self): # type: ignore[override] return self.layout == torch.sparse_csr # Update later to support more sparse layouts diff --git a/torch/multiprocessing/reductions.py b/torch/multiprocessing/reductions.py index fa0818571a93c0..c02ad3dc2362ba 100644 --- a/torch/multiprocessing/reductions.py +++ b/torch/multiprocessing/reductions.py @@ -74,7 +74,7 @@ def __init__(self) -> None: def _after_fork(self): self.lock = threading.Lock() - def get(self, key): + def get(self, key): # type: ignore[override] with self.lock: return dict.get(self, key) diff --git a/torch/nested/_internal/nested_tensor.py b/torch/nested/_internal/nested_tensor.py index 76f8d5a29e2d8f..d39eb12d919c87 100644 --- a/torch/nested/_internal/nested_tensor.py +++ b/torch/nested/_internal/nested_tensor.py @@ -209,7 +209,7 @@ def _max_seqlen(self): def _min_seqlen(self): return self._get_min_seqlen() - def __repr__(self): + def __repr__(self): # type: ignore[override] # We should implement this in torch/_tensor_str.py instead grad_fn_str = ( f", requires_grad={self.requires_grad}" if self.requires_grad else "" diff --git a/torch/nn/attention/bias.py b/torch/nn/attention/bias.py index 2fed60030a6898..da7acb957d96d7 100644 --- a/torch/nn/attention/bias.py +++ b/torch/nn/attention/bias.py @@ -289,7 +289,7 @@ def __torch_function__(cls, func, types, args=(), kwargs=None): ) return cls._dispatch(*args, **kwargs) - def __repr__(self): + def __repr__(self): # type:ignore[override] return self._materialize().__repr__() diff --git a/torch/optim/lr_scheduler.py b/torch/optim/lr_scheduler.py index 57dcbd85a83164..820a57f3e0b756 100644 --- a/torch/optim/lr_scheduler.py +++ b/torch/optim/lr_scheduler.py @@ -910,7 +910,7 @@ def __init__( self._last_lr = schedulers[0].get_last_lr() - def step(self): + def step(self): # type: ignore[override] """Perform a step.""" self.last_epoch += 1 idx = bisect_right(self._milestones, self.last_epoch) @@ -1179,7 +1179,7 @@ def __init__( group["lr"] for group in self._schedulers[-1].optimizer.param_groups ] - def step(self): + def step(self): # type: ignore[override] """Perform a step.""" for scheduler in self._schedulers: scheduler.step() diff --git a/torch/sparse/semi_structured.py b/torch/sparse/semi_structured.py index 0017a10e6771bb..5e8a632633e09a 100644 --- a/torch/sparse/semi_structured.py +++ b/torch/sparse/semi_structured.py @@ -295,7 +295,7 @@ def _pad_dense_input(cls, dense_input: torch.Tensor) -> torch.Tensor: else: return dense_input - def to_dense(self): + def to_dense(self): # type:ignore[override] col = self.shape[-1] return torch.mm(self, torch.eye(col, dtype=self.dtype, device=self.device)) @@ -420,7 +420,7 @@ def from_dense( requires_grad=original_tensor.requires_grad, ) - def to_dense(self): + def to_dense(self): # type: ignore[override] assert self.meta is not None and self.packed is not None return ( sparse_semi_structured_to_dense_cutlass( diff --git a/torch/utils/_sympy/functions.py b/torch/utils/_sympy/functions.py index 4924fff6dea6be..ed597402d8b637 100644 --- a/torch/utils/_sympy/functions.py +++ b/torch/utils/_sympy/functions.py @@ -310,11 +310,11 @@ def eval(cls, base, divisor, modulus): if isinstance(base, FloorDiv): return ModularIndexing(base.args[0], base.args[1] * divisor, modulus) - def _eval_is_nonnegative(self): + def _eval_is_nonnegative(self): # type:ignore[override] p, q = self.args[:2] return fuzzy_eq(p.is_nonnegative, q.is_nonnegative) # type: ignore[attr-defined] - def _eval_is_positive(self): + def _eval_is_positive(self): # type:ignore[override] p, q = self.args[:2] return fuzzy_eq(p.is_positive, q.is_positive) # type: ignore[attr-defined] @@ -329,14 +329,14 @@ class Where(sympy.Function): def _eval_is_integer(self): return True if self.args[1].is_integer and self.args[2].is_integer else None # type: ignore[attr-defined] - def _eval_is_nonnegative(self): + def _eval_is_nonnegative(self): # type:ignore[override] return ( True if self.args[1].is_nonnegative and self.args[2].is_nonnegative # type: ignore[attr-defined] else None ) - def _eval_is_positive(self): + def _eval_is_positive(self): # type:ignore[override] return True if self.args[1].is_positive and self.args[2].is_positive else None # type: ignore[attr-defined] @classmethod @@ -397,7 +397,7 @@ def eval(cls, p, q): return S.Zero # NB: args[1] for PythonMod - def _eval_is_nonnegative(self): + def _eval_is_nonnegative(self): # type:ignore[override] return True if self.args[1].is_positive else None # type: ignore[attr-defined] def _eval_is_nonpositive(self): @@ -823,13 +823,13 @@ class Max(MinMaxBase, Application): # type: ignore[misc] zero = S.Infinity identity = S.NegativeInfinity - def _eval_is_positive(self): + def _eval_is_positive(self): # type:ignore[override] return fuzzy_or(a.is_positive for a in self.args) # type: ignore[attr-defined] - def _eval_is_nonnegative(self): + def _eval_is_nonnegative(self): # type:ignore[override] return fuzzy_or(a.is_nonnegative for a in self.args) # type: ignore[attr-defined] - def _eval_is_negative(self): + def _eval_is_negative(self): # type:ignore[override] return fuzzy_and(a.is_negative for a in self.args) @@ -841,13 +841,13 @@ class Min(MinMaxBase, Application): # type: ignore[misc] zero = S.NegativeInfinity identity = S.Infinity - def _eval_is_positive(self): + def _eval_is_positive(self): # type:ignore[override] return fuzzy_and(a.is_positive for a in self.args) # type: ignore[attr-defined] - def _eval_is_nonnegative(self): + def _eval_is_nonnegative(self): # type:ignore[override] return fuzzy_and(a.is_nonnegative for a in self.args) # type: ignore[attr-defined] - def _eval_is_negative(self): + def _eval_is_negative(self): # type:ignore[override] return fuzzy_or(a.is_negative for a in self.args) diff --git a/torch/utils/weak.py b/torch/utils/weak.py index cc272a7f26375a..f729ff06489ffe 100644 --- a/torch/utils/weak.py +++ b/torch/utils/weak.py @@ -263,7 +263,7 @@ def pop(self, key, *args): def setdefault(self, key, default=None): return self.data.setdefault(self.ref_type(key, self._remove), default) # CHANGED - def update(self, dict=None, **kwargs): + def update(self, dict=None, **kwargs): # type: ignore[override] d = self.data if dict is not None: if not hasattr(dict, "items"): From 1c3418aa23589abb9157e58a2cfb3b20dc616841 Mon Sep 17 00:00:00 2001 From: James Wu Date: Mon, 16 Sep 2024 19:48:08 +0000 Subject: [PATCH 0947/1018] Refactor FxGraphCache.load into separate functions, so that AOTAutogradCache may access it correctly later (#135491) Summary: We refactor FxGraphCache.load into three phases: - prepare_key, which checks that an inductor input is cacheable and bypasses otherwise - load_with_key, which tries to lookup the key in the cache - post compile, where we do some logging and run post compile steps Splitting it along these lines will allow AOTAutogradCache to use load_with_key and still get access to all of the observability + remote cache logic when accessing FxGraphCache, without needing to pass key components, etc. Differential Revision: D62314862 Pull Request resolved: https://github.com/pytorch/pytorch/pull/135491 Approved by: https://github.com/oulgen --- torch/_inductor/codecache.py | 235 ++++++++++++++++++++++++----------- 1 file changed, 159 insertions(+), 76 deletions(-) diff --git a/torch/_inductor/codecache.py b/torch/_inductor/codecache.py index 3cd20e9f86df6e..135ebf96a86b69 100644 --- a/torch/_inductor/codecache.py +++ b/torch/_inductor/codecache.py @@ -1281,6 +1281,111 @@ def _check_can_cache(gm: torch.fx.GraphModule) -> None: ): raise BypassFxGraphCache("Can't cache torchbind objects") + @staticmethod + def prepare_key( + gm: torch.fx.GraphModule, + example_inputs: List[torch.Tensor], + fx_kwargs: Dict[str, Any], + inputs_to_check: Sequence[int], + remote: bool, + ) -> Tuple[Optional[Tuple[str, List[str]]], Dict[str, Any]]: + """ + Checks that the inductor input is cacheable, then computes + and returns the cache key for the input. + Returns (key_info, cache_info) where: + - key_info is (hash_key, debug_lines), and + - cache_info will contain debug info in the event of BypassFxGraphCache. + + NB: It is possible to have this function return a union instead. But + I personally believe it is more annoying/difficult to read in that format. + """ + try: + FxGraphCache._check_can_cache(gm) + key, debug_lines = compiled_fx_graph_hash( + gm, example_inputs, fx_kwargs, inputs_to_check + ) + except BypassFxGraphCache as e: + counters["inductor"]["fxgraph_cache_bypass"] += 1 + log.info("Bypassing FX Graph Cache because '%s'", e) + if remote: + log_cache_bypass("bypass_fx_graph", str(e)) + cache_info = { + "cache_state": "bypass", + "cache_bypass_reason": str(e), + "cache_event_time": time_ns(), + } + return None, cache_info + # If key exists, then cache_info will come from load_with_key + return (key, debug_lines), {} + + @staticmethod + def get_remote_cache() -> Optional[RemoteCache[JsonDataTy]]: + """ + Attempts to load the remote cache, returns None on error. + """ + remote_cache = None + cache_id = "fx-graph-v1" + try: + if config.is_fbcode(): + from torch._inductor.fb.remote_cache import FbRemoteFxGraphCache + + remote_cache = FbRemoteFxGraphCache(cache_id) + else: + from torch._inductor.remote_cache import RemoteFxGraphCache + + remote_cache = RemoteFxGraphCache(cache_id) + except ModuleNotFoundError as e: + # No need for a stack trace on this error + remote_cache = None + log.warning("Unable to create a remote cache: %s", e) + except Exception: + remote_cache = None + log.warning("Unable to create a remote cache", exc_info=True) + return remote_cache + + @staticmethod + def load_with_key( + key: str, + debug_lines: List[str], + example_inputs: List[torch.Tensor], + local: bool, + remote_cache: Optional[RemoteCache[JsonDataTy]], + is_backward: bool, + ) -> Tuple[Optional[CompiledFxGraph], Dict[str, Any]]: + """ + Lookup the graph with the given key, and return results and metadata. + Doesn't do any logging on its own, because AOTAutograd handles a cache miss + differently from FXGraphCache. + """ + compiled_graph = FxGraphCache._lookup_graph( + key, example_inputs, local, remote_cache + ) + cache_info = { + "key": key, + "components": debug_lines, + "cache_event_time": time_ns(), + } + if compiled_graph is not None: + log.debug("fx graph cache miss for key %s", key) + counters["inductor"]["fxgraph_cache_hit"] += 1 + cache_info["cache_state"] = "hit" + + if (time_saved_ns := compiled_graph._time_taken_ns) is not None: + cache_info["time_saved_ns"] = time_saved_ns + add_remote_cache_time_saved(time_saved_ns, is_backward) + if ( + ephemeral_increase := add_ephemeral_timeout_increase_for_distributed( + time_saved_ns + ) + ) != 0: + cache_info["ephemeral_timeout_increase"] = ephemeral_increase + else: + log.debug("fx graph cache hit for key %s", key) + counters["inductor"]["fxgraph_cache_miss"] += 1 + cache_info["cache_state"] = "miss" + + return compiled_graph, cache_info + @staticmethod def load( # type: ignore[no-untyped-def] compile_fx_fn: Callable[..., Any], @@ -1297,92 +1402,70 @@ def load( # type: ignore[no-untyped-def] """ assert local or remote, "at least one of them needs to be enabled" compiled_graph = None - cache_state = None - cache_event_time = None - cache_info: Dict[str, Any] = {} - try: - FxGraphCache._check_can_cache(gm) - key, debug_lines = compiled_fx_graph_hash( - gm, example_inputs, fx_kwargs, inputs_to_check - ) - cache_info["key"] = key - cache_info["components"] = debug_lines - - remote_cache: Optional[RemoteCache[JsonDataTy]] = None + remote_cache = None + (key_info, cache_info) = FxGraphCache.prepare_key( + gm, example_inputs, fx_kwargs, inputs_to_check, remote + ) + if key_info is not None: + key, debug_lines = key_info if remote: - cache_id = "fx-graph-v1" - try: - if config.is_fbcode(): - from torch._inductor.fb.remote_cache import FbRemoteFxGraphCache - - remote_cache = FbRemoteFxGraphCache(cache_id) - else: - from torch._inductor.remote_cache import RemoteFxGraphCache - - remote_cache = RemoteFxGraphCache(cache_id) - except ModuleNotFoundError as e: - # No need for a stack trace on this error - remote_cache = None - log.warning("Unable to create a remote cache: %s", e) - except Exception: - remote_cache = None - log.warning("Unable to create a remote cache", exc_info=True) - - compiled_graph = FxGraphCache._lookup_graph( - key, example_inputs, local, remote_cache + remote_cache = FxGraphCache.get_remote_cache() + compiled_graph, cache_info = FxGraphCache.load_with_key( + key, + debug_lines, + example_inputs, + local, + remote_cache, + is_backward=fx_kwargs.get("is_backward", False), ) - if compiled_graph is None: - log.debug("fx graph cache miss for key %s", key) - counters["inductor"]["fxgraph_cache_miss"] += 1 - cache_state = "miss" - start_time = time_ns() - cache_event_time = start_time - compiled_graph = compile_fx_fn( - gm, example_inputs, inputs_to_check, fx_kwargs - ) - compiled_graph._time_taken_ns = time_ns() - start_time - cache_info["time_taken_ns"] = compiled_graph._time_taken_ns - FxGraphCache._save_graph( - key, - compiled_graph, - example_inputs, - local, - remote_cache, - ) - else: - log.debug("fx graph cache hit for key %s", key) - counters["inductor"]["fxgraph_cache_hit"] += 1 - cache_state = "hit" - cache_event_time = time_ns() - if (time_saved_ns := compiled_graph._time_taken_ns) is not None: - cache_info["time_saved_ns"] = time_saved_ns - add_remote_cache_time_saved(time_saved_ns, fx_kwargs["is_backward"]) - if ( - ephemeral_increase := add_ephemeral_timeout_increase_for_distributed( - time_saved_ns - ) - ) != 0: - cache_info["ephemeral_timeout_increase"] = ephemeral_increase - compiled_graph._fx_graph_cache_key = key - except BypassFxGraphCache as e: - counters["inductor"]["fxgraph_cache_bypass"] += 1 - cache_state = "bypass" - log.info("Bypassing FX Graph Cache because '%s'", e) - cache_info["cache_bypass_reason"] = str(e) - if remote: - log_cache_bypass("bypass_fx_graph", str(e)) - cache_event_time = time_ns() + # CACHE BYPASS: Compile the graph, don't save it to the cache + if cache_info["cache_state"] == "bypass": + assert compiled_graph is None + compiled_graph = compile_fx_fn( + gm, example_inputs, inputs_to_check, fx_kwargs + ) - if not compiled_graph: + # CACHE MISS: Compile the graph and save to cache + elif cache_info["cache_state"] == "miss": + assert compiled_graph is None + assert key_info is not None + start_time = cache_info["cache_event_time"] compiled_graph = compile_fx_fn( gm, example_inputs, inputs_to_check, fx_kwargs ) + compiled_graph._time_taken_ns = time_ns() - start_time + cache_key = key_info[0] + compiled_graph._fx_graph_cache_key = cache_key + cache_info["time_taken_ns"] = compiled_graph._time_taken_ns + FxGraphCache._save_graph( + cache_key, + compiled_graph, + example_inputs, + local, + remote_cache, + ) + # CACHE HIT: not much to really do, just make sure the cache key + # is recorded on the graph + else: + assert cache_info["cache_state"] == "hit" + assert compiled_graph is not None + assert key_info is not None + cache_key = key_info[0] + compiled_graph._fx_graph_cache_key = cache_key + assert compiled_graph is not None - cache_info["cache_state"] = cache_state + + # Logging and observability: we log a single chromium event + # and a tlparse log for every cache action. + # In the event of a bypass, we also logged to the remote table earlier + # with log_cache_bypass. chromium_log = get_chromium_event_logger() + cache_state = cache_info["cache_state"] chromium_log.log_instant_event( - f"fx_graph_cache_{cache_state}", cache_event_time, metadata=cache_info + f"fx_graph_cache_{cache_state}", + cache_info["cache_event_time"], + metadata=cache_info, ) torch._logging.trace_structured( "artifact", From 225ea8fab732394b2b8b8e523a6b51a5a2865f8c Mon Sep 17 00:00:00 2001 From: fduwjj Date: Wed, 11 Sep 2024 14:24:25 -0700 Subject: [PATCH 0948/1018] [c10d][Reland] Remove Option for ProcessGroup and Expose backend Options to reflect the correct code structure (#132931) (#135653) We introduced the dispatchable backend for a ProcessGroup and collective in https://github.com/pytorch/pytorch/issues/86225. This PR is a follow-up cleanup to clean up the option of a ProcessGroup and ask users to either set timeout or backend later on or directly create backend after creating a PG. Also PGNCCL is using option class from ProcessGroup but we actually should use Option from backend class. So this PR is to make the type or name to be aligned with what we are doing in cpp side. I don't change the signature for the public API, so they still use args named "pg_options" We need to make changes to the test to make it aligned with the change. This is try to reland D62008954 by fixing internal errors. Differential Revision: [D62483294](https://our.internmc.facebook.com/intern/diff/D62483294/) Pull Request resolved: https://github.com/pytorch/pytorch/pull/135653 Approved by: https://github.com/wz337, https://github.com/H-Huang --- test/distributed/test_c10d_common.py | 8 +- test/distributed/test_device_mesh.py | 3 +- torch/_C/_distributed_c10d.pyi | 15 +- torch/csrc/distributed/c10d/ProcessGroup.cpp | 4785 ++++++++++++++- torch/csrc/distributed/c10d/ProcessGroup.hpp | 5635 ++++-------------- torch/csrc/distributed/c10d/init.cpp | 60 +- torch/distributed/device_mesh.py | 5 +- torch/distributed/distributed_c10d.py | 65 +- 8 files changed, 5829 insertions(+), 4747 deletions(-) diff --git a/test/distributed/test_c10d_common.py b/test/distributed/test_c10d_common.py index 3e5538d57e38ae..d96abb1ca82675 100644 --- a/test/distributed/test_c10d_common.py +++ b/test/distributed/test_c10d_common.py @@ -1820,6 +1820,7 @@ def test_init_process_group_optional_backend(self): def test_init_process_group_for_all_backends(self): for backend in dist.Backend.backend_list: + excepted_backend = backend # skip if the backend is not available on the system if backend == dist.Backend.UNDEFINED: continue @@ -1835,6 +1836,11 @@ def test_init_process_group_for_all_backends(self): elif backend == dist.Backend.UCC: if not dist.is_ucc_available(): continue + # Multi-threaded PG is defined as a pure python class. + # Its pg.name() does not going through Pybind, so its backend name + # is still "threaded" instead of "custom". + elif backend != "threaded": + excepted_backend = "custom" with tempfile.NamedTemporaryFile(delete=False) as f: store = dist.FileStore(f.name, self.world_size) @@ -1847,7 +1853,7 @@ def test_init_process_group_for_all_backends(self): pg = c10d._get_default_group() self.assertEqual(pg.rank(), self.rank) self.assertEqual(pg.size(), self.world_size) - self.assertEqual(pg.name(), str(backend)) + self.assertEqual(pg.name(), str(excepted_backend)) dist.destroy_process_group() diff --git a/test/distributed/test_device_mesh.py b/test/distributed/test_device_mesh.py index 8431bac6f6f913..7cbbe1a9e71450 100644 --- a/test/distributed/test_device_mesh.py +++ b/test/distributed/test_device_mesh.py @@ -236,7 +236,8 @@ def test_set_mesh_dim_group_options(self): mesh_tensor = torch.arange(4).reshape(2, 2) mesh = DeviceMesh(device_type, mesh_tensor) - self.assertEqual(mesh.get_group(1)._get_backend_name(), "fake") + # Fake pg only have BackendType as BackendType::CUSTOM. + self.assertEqual(mesh.get_group(1)._get_backend_name(), "custom") class DeviceMeshTestNDim(DTensorTestBase): diff --git a/torch/_C/_distributed_c10d.pyi b/torch/_C/_distributed_c10d.pyi index 6033d969925972..f89f0b50c8582e 100644 --- a/torch/_C/_distributed_c10d.pyi +++ b/torch/_C/_distributed_c10d.pyi @@ -296,15 +296,6 @@ class Backend: def _set_default_timeout(self, timeout: timedelta) -> None: ... class ProcessGroup: - class Options: - def __init__(self, backend: str, timeout: timedelta = ...) -> None: ... - @property - def backend(self) -> str: ... - @property - def _timeout(self) -> timedelta: ... - @_timeout.setter - def _timeout(self, val: timedelta) -> None: ... - class BackendType(Enum): UNDEFINED = ... GLOO = ... @@ -319,7 +310,6 @@ class ProcessGroup: store: Store, rank: int, size: int, - options: Options, ) -> None: ... def rank(self) -> int: ... def size(self) -> int: ... @@ -509,6 +499,7 @@ class ProcessGroup: @property def _device_types(self) -> list[torch.device]: ... def _get_backend(self, device: torch.device) -> Backend: ... + def _set_default_backend(self, backend_type: BackendType) -> None: ... def _register_backend( self, device: torch.device, @@ -533,7 +524,7 @@ class ProcessGroup: class ProcessGroupGloo(Backend): class Device: ... - class Options(ProcessGroup.Options): + class Options(Backend.Options): devices: list[ProcessGroupGloo.Device] threads: int @@ -563,7 +554,7 @@ class ProcessGroupNCCL(Backend): min_ctas: int max_ctas: int - class Options(ProcessGroup.Options): + class Options(Backend.Options): config: ProcessGroupNCCL.NCCLConfig is_high_priority_stream: bool split_from: ProcessGroupNCCL diff --git a/torch/csrc/distributed/c10d/ProcessGroup.cpp b/torch/csrc/distributed/c10d/ProcessGroup.cpp index 75635bc68aed4f..d8c4409c28d4b6 100644 --- a/torch/csrc/distributed/c10d/ProcessGroup.cpp +++ b/torch/csrc/distributed/c10d/ProcessGroup.cpp @@ -1,178 +1,4701 @@ -#include -#include +#ifdef USE_C10D_NCCL -#include -#include -#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include #include -#include -#include #include -#include -#include +#include +#include +#include +#include +#include namespace c10d { -static ProcessGroup::BackendType strToBackendType(std::string_view backend) { - if (backend == "undefined") { - return ProcessGroup::BackendType::UNDEFINED; - } else if (backend == "gloo") { - return ProcessGroup::BackendType::GLOO; - } else if (backend == "nccl") { - return ProcessGroup::BackendType::NCCL; - } else if (backend == "ucc") { - return ProcessGroup::BackendType::UCC; - } else if (backend == "mpi") { - return ProcessGroup::BackendType::MPI; - } else { - return ProcessGroup::BackendType::CUSTOM; - } -} - -std::string opTypeToString(OpType opType) { - switch (opType) { - case OpType::BROADCAST: - return "BROADCAST"; - case OpType::ALLREDUCE: - return "ALLREDUCE"; - case OpType::ALLREDUCE_COALESCED: - return "ALLREDUCE_COALESCED"; - case OpType::REDUCE: - return "REDUCE"; - case OpType::ALLGATHER: - return "ALLGATHER"; - case OpType::_ALLGATHER_BASE: - return "_ALLGATHER_BASE"; - case OpType::ALLGATHER_COALESCED: - return "ALLGATHER_COALESCED"; - case OpType::GATHER: - return "GATHER"; - case OpType::SCATTER: - return "SCATTER"; - case OpType::REDUCE_SCATTER: - return "REDUCE_SCATTER"; - case OpType::ALLTOALL_BASE: - return "ALLTOALL_BASE"; - case OpType::ALLTOALL: - return "ALLTOALL"; - case OpType::SEND: - return "SEND"; - case OpType::RECV: - return "RECV"; - case OpType::RECVANYSOURCE: - return "RECVANYSOURCE"; - case OpType::BARRIER: - return "BARRIER"; - case OpType::UNKNOWN: - return "UNKNOWN"; - case OpType::_REDUCE_SCATTER_BASE: - return "_REDUCE_SCATTER_BASE"; - case OpType::COALESCED: - return "COALESCED"; - case OpType::_ALLREDUCE_SPARSE: - return "_ALLREDUCE_SPARSE"; +constexpr const char* const kNCCLAbortedCommStoreKey = "NCCLABORTEDCOMM"; + +namespace { + +#if defined(NCCL_MAJOR) && \ + ((NCCL_MAJOR > 2) || (NCCL_MAJOR == 2) && (NCCL_MINOR >= 10)) +#define NCCL_HAS_AVG 1 +#endif + +// NCCL op mapping +const std::map ncclOp = { + {ReduceOp::MIN, ncclMin}, + {ReduceOp::MAX, ncclMax}, + {ReduceOp::SUM, ncclSum}, + {ReduceOp::PRODUCT, ncclProd}, +#ifdef NCCL_HAS_AVG + {ReduceOp::AVG, ncclAvg}, +#endif +}; + +// NCCL type typing +std::map ncclDataType = { + {at::kChar, ncclInt8}, + {at::kByte, ncclUint8}, + {at::kFloat, ncclFloat}, + {at::kDouble, ncclDouble}, + {at::kInt, ncclInt32}, + {at::kLong, ncclInt64}, + {at::kHalf, ncclHalf}, + {at::kBool, ncclUint8}, + {at::kFloat8_e5m2, ncclUint8}, + {at::kFloat8_e4m3fn, ncclUint8}, + {at::kFloat8_e4m3fnuz, ncclUint8}, + {at::kFloat8_e5m2fnuz, ncclUint8}, +#if HAS_NCCL_BF16_DATATYPE + {at::kBFloat16, ncclBfloat16}, +#endif +}; + +// Helper function that gets the data type and issues error if not supported +ncclDataType_t getNcclDataType(at::ScalarType type) { + auto it = ncclDataType.find(type); + TORCH_CHECK_WITH( + TypeError, + it != ncclDataType.end(), + "Input tensor data type is not supported for NCCL process group: ", + type); + return it->second; +} + +bool complexViewAsRealAllowed(const ReduceOp reduceOp) { + switch (reduceOp) { + case ReduceOp::SUM: + return true; + case ReduceOp::AVG: + return true; + case ReduceOp::PREMUL_SUM: + return true; + case ReduceOp::UNUSED: + return true; default: - TORCH_INTERNAL_ASSERT(false, "Unknown op type!"); + return false; } - return "UNKNOWN"; + return false; } -bool isP2POp(OpType opType, bool batchP2P /*= false*/) { - if (batchP2P) - return false; - return opType == OpType::SEND || opType == OpType::RECV || - opType == OpType::RECVANYSOURCE; +#ifdef ENABLE_NCCL_PREMUL_SUM_SUPPORT +template +ncclRedOpRAII unpackPreMulSum( + const ReduceOp& reduceOp, + const ncclComm_t& comm) { + const auto* preMulSupplement = + reinterpret_cast(reduceOp.supplement_.get()); + ncclRedOp_t preMulSum; + bool has_tensor = preMulSupplement->tensor_factor.defined(); + auto residence = has_tensor ? ncclScalarDevice : ncclScalarHostImmediate; + const T* ptr_factor = has_tensor + ? preMulSupplement->tensor_factor.const_data_ptr() + : nullptr; + T scalar_factor = T(preMulSupplement->double_factor); + ncclRedOpCreatePreMulSum( + &preMulSum, + // https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/api/ops.html#ncclredopcreatepremulsum + // tells us that the scalar input is strictly a multiplier. + /*scalar=*/has_tensor ? const_cast(ptr_factor) : &scalar_factor, + dataType, + residence, + comm); + return ncclRedOpRAII(preMulSum, comm); } +#endif -c10::intrusive_ptr ProcessGroup::getBackend( - c10::DeviceType deviceType) { - // If there is a backend associated with this device type then return it - if (deviceTypeToBackend_.find(deviceType) != deviceTypeToBackend_.end()) { - return deviceTypeToBackend_.at(deviceType); +ncclRedOpRAII getNcclReduceOp( + const ReduceOp& reduceOp, + at::Tensor& input, + const ncclDataType_t& dataType, + const ncclComm_t& comm) { + try { + if (input.scalar_type() == at::kBool) { + if (reduceOp == ReduceOp::SUM) { + // For bool tensors, map sum to max, which both represent a bitwise or. + // This is to prevent overflow issues with sum, since we use uint8 to + // represent a bool (see ncclDataType mapping). + return ncclMax; + } +#ifdef NCCL_HAS_AVG + if (reduceOp == ReduceOp::AVG) { + C10_THROW_ERROR( + TypeError, "Cannot use ReduceOp.AVG with boolean inputs"); + } +#endif + } + if (reduceOp == ReduceOp::PREMUL_SUM) { +#ifdef ENABLE_NCCL_PREMUL_SUM_SUPPORT + switch (dataType) { + case ncclHalf: + return unpackPreMulSum(reduceOp, comm); + case ncclFloat: + return unpackPreMulSum(reduceOp, comm); + case ncclDouble: + return unpackPreMulSum(reduceOp, comm); + default: + C10_THROW_ERROR( + TypeError, "PreMulSum Data type must be half, float, or double"); + ncclRedOp_t unused; + return unused; + } +#else + C10_THROW_ERROR(ValueError, "PreMulSum requires NCCL>=2.11.1"); +#endif + } + return ncclOp.at(reduceOp); + } catch (const std::out_of_range&) { + switch (reduceOp) { + case ReduceOp::AVG: + C10_THROW_ERROR( + ValueError, + c10::str( + "AVG requires NCCL 2.10+. The current version is ", + NCCL_MAJOR, + ".", + NCCL_MINOR)); + break; + case ReduceOp::BAND: + C10_THROW_ERROR(ValueError, "Cannot use ReduceOp.BAND with NCCL"); + break; + case ReduceOp::BOR: + C10_THROW_ERROR(ValueError, "Cannot use ReduceOp.BOR with NCCL"); + break; + case ReduceOp::BXOR: + C10_THROW_ERROR(ValueError, "Cannot use ReduceOp.BXOR with NCCL"); + break; + default: + C10_THROW_ERROR(ValueError, "Unhandled ReduceOp"); + break; + } } +} - // Get the backend type associated with the device - ProcessGroup::BackendType backendType{ProcessGroup::BackendType::UNDEFINED}; +// Get a key string from device +inline std::string getKeyFromDevice(at::Device& device) { + return std::to_string(device.index()); +} + +inline at::DeviceIndex getIndexFromDeviceKey(const std::string& deviceKey) { + // initialize the device index to -1, which is an invalid value. + int index = -1; try { - backendType = deviceTypeToBackendType_.at(deviceType); + index = std::stoi(deviceKey); + } catch (const std::invalid_argument& e) { + LOG(ERROR) << c10::str( + "Invalid deviceKey: ", deviceKey, ",", e.what(), "."); } catch (const std::out_of_range& e) { - TORCH_CHECK( - false, "No backend type associated with device type ", deviceType); + LOG(ERROR) << "Out of range: " << e.what(); + } + return static_cast(index); +} + +std::string getKeySendRecv(int myRank, int peer) { + int lowRank = myRank < peer ? myRank : peer; + int highRank = myRank < peer ? peer : myRank; + std::string sendRecvPair = + std::to_string(lowRank) + ":" + std::to_string(highRank); + return sendRecvPair; +} + +// Get device from tensor +inline at::Device getDevice(at::Tensor& tensor) { + return tensor.device(); +} + +// [Sync Streams] Helper that lets the input ncclStreams to wait for the current +// stream. NCCL communications run on ncclStreams, but input tensors are +// allocated on different streams (i.e., current streams). Communications on +// ncclStreams cannot start before pending input tensor ops on current streams +// finish. Otherwise, ops on two streams might read/write same tensors +// concurrently. +// +// The synchronization above alone is not enough. We also need to make sure +// input tensors are not freed before their usages on ncclStreams finish. This +// can be achieved by calling c10::cuda::CUDACachingAllocator::recordStream, +// which remembers the usage stream (ncclStream), creates an event on the usage +// stream when GC attempts to free the input tensor, and delays GC until that +// event is done. +void syncStream( + at::Device& device, + at::cuda::CUDAEvent& ncclEvent, + at::cuda::CUDAStream& ncclStream) { + ncclEvent.record(at::cuda::getCurrentCUDAStream(device.index())); + ncclEvent.block(ncclStream); +} + +// Given a ncclUniqueId, convert it to a string representation that can be put +// in the store. +std::string buildNcclUniqueIdStr(const ncclUniqueId& ncclID) { + const uint8_t* bytes = reinterpret_cast(&ncclID); + std::ostringstream oss; + for (const auto i : c10::irange(NCCL_UNIQUE_ID_BYTES)) { + oss << std::hex << static_cast(bytes[i]); } + return oss.str(); +} + +std::string getNcclAbortedCommStoreKey(const std::string ncclIdStr) { + return std::string(kNCCLAbortedCommStoreKey) + ":" + ncclIdStr; +} - // Check if the backend has already been initialized - if (backendTypeToBackend_.find(backendType) != backendTypeToBackend_.end()) { - auto backend = backendTypeToBackend_.at(backendType); - deviceTypeToBackend_[deviceType] = backend; - return backend; +// Returns exception's what() given an exception_ptr instance. +std::string getExceptionMsgFromExceptionPtr( + const std::exception_ptr& exceptionPtr) { + TORCH_CHECK(exceptionPtr != nullptr); + try { + std::rethrow_exception(exceptionPtr); + } catch (const std::exception& e) { + return e.what(); + } catch (...) { + return "Unknown exception type"; } +} - TORCH_CHECK( - false, - "Could not retrieve or create the backend ", - backendType, - " for device type ", - deviceType); +inline void errorIfCapturingNonCapturableNCCL(c10::cuda::CaptureStatus status) { + // parentheses avoid some compiler warnings + static const uint64_t min_version = + (((uint64_t)2) << 32) + (((uint64_t)9) << 16) + ((uint64_t)6); + static const uint64_t cur_version = torch::cuda::nccl::version(); + if (cur_version < min_version) { + TORCH_CHECK_WITH( + NotImplementedError, + status == c10::cuda::CaptureStatus::None, + "Capturing NCCL collectives is only allowed with NCCL >= 2.9.6"); + } +} + +} // namespace + +// Map from each communicator to its device index. +// This map is used when register/deregister cache segments from cache +// allocator. See design notes below: +// - Each segment should be registered only to the communicator on the +// same device. +// - We cannot reuse devNCCLCommMap_ in each ProcessGroup because the key may be +// ranks rather than device in point-to-point case. +// - This map has also to be maintained as global variable since the register +// hooks are called outside the scope of any PG, thus we need traverse +// communicators in all PGs. +static std::unordered_map, int> ncclCommDevIdxMap; +static std::mutex ncclCommDevIdxMapMutex; +static bool allocatorHooksAttached = false; + +std::atomic ProcessGroupNCCL::shouldDump_(false); + +void cacheAllocatorRegisterHook( + const c10::cuda::CUDACachingAllocator::TraceEntry& te) { + // Register after SEGMENT_ALLOC + if (te.action_ != + c10::cuda::CUDACachingAllocator::TraceEntry::Action::SEGMENT_ALLOC) { + return; + } + + std::lock_guard lock(ncclCommDevIdxMapMutex); + for (auto& it : ncclCommDevIdxMap) { + auto& ncclComm = it.first; + auto& devIdx = it.second; + if (te.device_ == devIdx) { + ncclComm->registerSegment(reinterpret_cast(te.addr_), te.size_); + } + } +} + +void cacheAllocatorDeregisterHook( + const c10::cuda::CUDACachingAllocator::TraceEntry& te) { + // deregister before SEGMENT_FREE + if (te.action_ != + c10::cuda::CUDACachingAllocator::TraceEntry::Action::SEGMENT_FREE) { + return; + } + + std::lock_guard lock(ncclCommDevIdxMapMutex); + for (auto& it : ncclCommDevIdxMap) { + auto& ncclComm = it.first; + auto& devIdx = it.second; + if (te.device_ == devIdx) { + ncclComm->deregisterSegment(reinterpret_cast(te.addr_)); + } + } +} + +std::unordered_map> +getNCCLCommDumpMap() { +#if defined(IS_NCCLX) && defined(NCCL_COMM_DUMP) + std::unordered_map< + std::string /* ncclUniqueID */, + std::unordered_map /* dump from this comm */> + ncclDumpMap; + // dump_nccl_trace is only called from the default PG (local_id_=0), but we + // want to dump from all comms so we need to iterate over ncclCommDevIdxMap, + // which is static + std::vector> allNCCLComms; + // within the critical section, we don't want to dump while holding the lock + // as dump might hang + ncclCommDevIdxMapMutex.lock(); + for (auto& [ncclComm, _] : ncclCommDevIdxMap) { + allNCCLComms.push_back(ncclComm); + } + ncclCommDevIdxMapMutex.unlock(); + for (auto& ncclComm : allNCCLComms) { + std::string ncclUniqueIDStr = buildNcclUniqueIdStr(ncclComm->getNcclId()); + ncclDumpMap[ncclUniqueIDStr] = ncclComm->ncclCommDump(); + } + return ncclDumpMap; +#else + return std::unordered_map< + std::string, + std::unordered_map>(); +#endif +} + +std::string dump_nccl_trace( + bool includeCollectives, + bool includeStackTraces, + bool onlyActive) { + auto ncclDumpMap = getNCCLCommDumpMap(); + return NCCLTraceBuffer::get()->dump( + ncclDumpMap, includeCollectives, includeStackTraces, onlyActive); +} + +std::string dump_nccl_trace_json(bool includeCollectives, bool onlyActive) { + auto ncclDumpMap = getNCCLCommDumpMap(); + return NCCLTraceBuffer::get()->dump_json( + ncclDumpMap, includeCollectives, onlyActive); +} + +std::optional)>>& +get_cpp_trace_dumper() { + static std::optional< + std::function)>> + dumper(std::nullopt); + return dumper; +} + +gil_checker_t& get_gil_checker() { + static gil_checker_t gil_checker = nullptr; + return gil_checker; +} + +std::future launchAsyncGilCheck() { + std::promise resultPromise; + std::future resultFuture = resultPromise.get_future(); + TORCH_CHECK(get_gil_checker(), "Can't check GIL with null GIL checker"); + std::thread workerThread([promise = std::move(resultPromise)]() mutable { + c10::setThreadName("pt_nccl_gil_chk"); + + try { + auto& gil_checker = get_gil_checker(); + promise.set_value((*gil_checker)()); + } catch (...) { + promise.set_exception(std::current_exception()); + } + }); + + // Detach the thread to allow it to run independently + workerThread.detach(); + + return resultFuture; +} + +const int64_t ProcessGroupNCCL::kWatchdogThreadSleepMillis = 100; +constexpr int64_t kSynchronizeBusyWaitMillis = 10; +thread_local uint64_t ProcessGroupNCCL::ncclActiveGroupCounter_ = 0; + +std::ostream& operator<<( + std::ostream& output, + const ProcessGroupNCCL::WorkNCCL& workNCCL) { + std::string workInfo; + workInfo = c10::str( + "WorkNCCL(", + "SeqNum=", + workNCCL.seq_, + ", OpType=", + opTypeToString(workNCCL.opType_), + ", NumelIn=", + workNCCL.numelIn_, + ", NumelOut=", + workNCCL.numelOut_, + ", Timeout(ms)=", + workNCCL.opTimeout_.count(), + ")"); + return output << workInfo; +} + +ProcessGroupNCCL::WorkNCCL::WorkNCCL( + const std::string& pgUID, + const std::string& pgDesc, + at::Device& device, + int rank, + OpType opType, + uint64_t seq, + const char* profilingTitle, + const std::optional>& inputs, + bool desyncDebug, + bool enableTiming, + bool cudaEventCacheEnabled, + DebugLevel distDebugLevel) + : Work(rank, opType, profilingTitle, inputs), + pgUID_(pgUID), + pgDesc_(pgDesc), + device_(device), + workStartTime_(std::chrono::steady_clock::now()), + seq_(seq), + timingEnabled_(enableTiming), + distDebugLevel_(distDebugLevel) { + // Creates the CUDA event wrappers + // Note: The actual events are lazily created when first recorded to with + // DEFAULT_FLAGS = cudaEventDisableTiming. + if (cudaEventCacheEnabled) { + ncclStartEvent_ = enableTiming + ? ProcessGroupNCCL::CUDAEventCache::get().create(enableTiming) + : nullptr; + ncclEndEvent_ = + ProcessGroupNCCL::CUDAEventCache::get().create(enableTiming); + } else { + ncclStartEvent_ = enableTiming + ? std::make_shared(cudaEventDefault) + : nullptr; + ncclEndEvent_ = std::make_shared( + enableTiming ? cudaEventDefault : cudaEventDisableTiming); + } +} + +ProcessGroupNCCL::WorkNCCL::WorkNCCL(const WorkNCCL& w) + : Work(w.rank_, w.opType_), + std::enable_shared_from_this(w), + pgUID_(w.pgUID_), + pgDesc_(w.pgDesc_), + device_(w.device_), + ncclStartEvent_(w.ncclStartEvent_), + ncclEndEvent_(w.ncclEndEvent_), + ncclComm_(w.ncclComm_), + blockingWait_(w.blockingWait_), + opTimeout_(w.opTimeout_), + ownedEphermeralTimeout_(w.ownedEphermeralTimeout_), + workStartTime_(w.workStartTime_), + seq_(w.seq_), + startTraceUpdated_(w.startTraceUpdated_), + numelIn_(w.numelIn_), + numelOut_(w.numelOut_), + store_(w.store_), + timingEnabled_(w.timingEnabled_), + trace_id_(w.trace_id_), + distDebugLevel_(w.distDebugLevel_) { + exception_ = w.exception_; +} + +ProcessGroupNCCL::WorkNCCL::~WorkNCCL() = default; + +bool ProcessGroupNCCL::WorkNCCL::isCompleted() { + if (!ncclComm_->isAborted()) { + checkAndSetException(); + } + return exception() || finishedGPUExecutionInternal(); +} + +bool ProcessGroupNCCL::WorkNCCL::isStarted() { + if (!ncclComm_->isAborted()) { + checkAndSetException(); + } + return exception() || startedGPUExecutionInternal(); +} + +bool ProcessGroupNCCL::WorkNCCL::isSuccess() const { + C10_THROW_ERROR(NotImplementedError, "WorkNCCL::isSuccess() is deprecated"); +} + +void ProcessGroupNCCL::WorkNCCL::checkAndSetException() { + if (exception()) { + // We already have an exception. + return; + } + + auto exception_ptr = checkForNCCLErrors(); + std::unique_lock lock(mutex_); + exception_ = exception_ptr; + if (exception_) { + LOG(ERROR) << logPrefix() << "Collective " << *this + << " raised the following async exception: " + << getExceptionMsgFromExceptionPtr(exception_); + } +} + +const std::string& ProcessGroupNCCL::WorkNCCL::logPrefix() const { + static std::string prefix = c10::str("[Rank ", rank_, "] "); + return prefix; +} + +void ProcessGroupNCCL::WorkNCCL::setException( + std::exception_ptr exception_ptr) { + std::unique_lock lock(mutex_); + exception_ = exception_ptr; +} + +// Helper that checks if the NCCL kernels are completed on the GPUs +bool ProcessGroupNCCL::WorkNCCL::finishedGPUExecution() { + checkAndSetException(); + return finishedGPUExecutionInternal(); +} + +bool ProcessGroupNCCL::WorkNCCL::startedGPUExecutionInternal() const { + // if timing is disabled we won't have allocated start events + if (!timingEnabled_) { + return false; + } + // Checking the work's corresponding CUDA event's status + if (!ncclStartEvent_->query()) { + return false; + } + return true; +} + +bool ProcessGroupNCCL::WorkNCCL::finishedGPUExecutionInternal() const { + // Checking the work's corresponding CUDA event's status + // It calls `cudaEventQuery` eventually. Although this seems to be a + // non-blocking call, but we did notice hangs in the past. It can + // hang if another thread is holding the CUDA global context lock. For + // example, when doing a `cudaDeviceSynchronize` or even + // `cudaStreamSynchronize`. + if (!ncclEndEvent_->query()) { + return false; + } + return true; +} + +bool ProcessGroupNCCL::WorkNCCL::checkTimeout( + std::optional timeout) { + STATIC_SCOPED_WAIT_COUNTER( + pytorch.wait_counter.ProcessGroupNCCL__checkTimeout); + auto currentTimepoint = std::chrono::steady_clock::now(); + auto timeElapsed = std::chrono::duration_cast( + currentTimepoint - workStartTime_); + auto workTimeout = timeout ? *timeout : opTimeout_; + + if (timeElapsed < workTimeout) + return false; + + // Timed out + + // There is already an error, we don't override it + if (exception()) + return true; + + std::string exceptionMsg = c10::str( + logPrefix(), + "Watchdog caught collective operation timeout: ", + *this, + " ran for ", + timeElapsed.count(), + " milliseconds before timing out."); + + LOG(ERROR) << exceptionMsg; + std::exception_ptr exception_ptr = + std::make_exception_ptr(C10_BUILD_ERROR(DistBackendError, exceptionMsg)); + setException(exception_ptr); + return true; +} + +void ProcessGroupNCCL::WorkNCCL::handleException( + ErrorHandlingMode errorHandling) { + if (exception_) { + auto exceptionMsg = c10::str( + "Some NCCL operations have failed or timed out. Due to the ", + "asynchronous nature of CUDA kernels, subsequent GPU operations ", + "might run on corrupted/incomplete data."); + LOG(ERROR) << logPrefix() << exceptionMsg; + C10_LOG_API_USAGE_ONCE("ProcessGroupNCCL.WorkNCCL.handleException"); + + if (SHOULD_TEAR_DOWN(errorHandling)) { + auto tearDownMsg = c10::str( + "To avoid data inconsistency, we are taking the entire process down."); + LOG(ERROR) << logPrefix() << tearDownMsg; + std::rethrow_exception(exception_); + } + } +} + +void ProcessGroupNCCL::WorkNCCL::synchronize() { + // Call Synchronize without a timeout. We use this method to avoid adding a + // timeout argument to the public synchronize API. + synchronizeInternal(kNoTimeout); +} + +void ProcessGroupNCCL::WorkNCCL::synchronizeStream() { + auto currentStream = at::cuda::getCurrentCUDAStream(device_.index()); + // Block the current stream on the NCCL stream + ncclEndEvent_->block(currentStream); + + if (avoidRecordStreams_) { + stashed_for_allocator_safety_->clear(); + } +} + +// Waiting on the work's corresponding CUDA events +void ProcessGroupNCCL::WorkNCCL::synchronizeInternal( + std::chrono::milliseconds timeout) { + synchronizeStream(); + + // In case of blocking, wait for the operation to complete. + if (blockingWait_) { + while (!isCompleted()) { + bool timedOut = checkTimeout( + timeout == kNoTimeout ? std::nullopt : std::make_optional(timeout)); + // Explicitly abort ncclComms here before throwing this timed out + // exception to users. + // If throwing timed out excepiton without aborting nccl communicators + // here, it was observed that CUDA GPU will have 100% utilization and + // can not run new events successfully. + if (timedOut) { + std::string exceptionMsg = c10::str( + logPrefix(), + "Work ", + (*this), + " timed out in blocking wait (TORCH_NCCL_BLOCKING_WAIT=1)."); + LOG(ERROR) << exceptionMsg; + break; + } + // Yield + std::this_thread::sleep_for( + std::chrono::milliseconds(kSynchronizeBusyWaitMillis)); + } + // exception() includes timeout and error during blocking wait + if (exception()) { + // Abort NCCL communicators + abort(); + // Throw exception (from main thread here) + handleException(TearDown); + } + } + + // Device synchronize only after we've completed timeout checks. + if (barrierTensor_.defined()) { + // If we use the work to do barrier, we should block here + // `dist.barrier()` only requires all CPU processes to enter this + // function, hence we only need to make sure the dummy all-reduce has + // completed. So we would only need to sync the **current stream** back to + // host, and do not need to synchronize the entire device (which may have + // kernels running on other streams). + // Using `cudaStreamSynchronize` instead of `cudaDeviceSynchronize` can: + // - lower chance of hang; + // - CurrentCUDAStream is usually the context of the next operation in + // Python, thus blocking current stream would already block the next + // compute kernel; + // - achieve better barrier performance. + auto currentStream = at::cuda::getCurrentCUDAStream(device_.index()); + // CUDAStream wrapper will correctly use a DeviceGuard here + currentStream.synchronize(); + } +} + +// Same as calling synchronize(). +bool ProcessGroupNCCL::WorkNCCL::wait(std::chrono::milliseconds timeout) { + RECORD_PARAM_COMMS( + static_cast(this->seq_), // seq + std::make_tuple(pgUID_, pgDesc_), // PG name tuple + rank_, // rank + "wait", // collective name + 0, // inNelems + 0, // outNelems + at::kByte, // dType + std::vector(), // inSplitSizes + std::vector(), // outSplitSizes + -1, + -1, + static_cast(1)); // number of device? + synchronizeInternal(timeout); + // TODO(kwen2501): this should be moved to c10d tests, to qualify a NCCL + // upgrade. Once a NCCL version is qualified, this code should not be needed + // at runtime. +#ifdef PGNCCL_ENABLE_HASH + if (distDebugLevel_ >= DebugLevel::Detail) { + auto numel = getTensorsNumel(*outputs_); + auto hashValue = hashTensors(*outputs_); + PRINT_COLLECTIVE_HASH_SIGNATURE( + "output", opTypeToString(opType_), numel, hashValue); + } +#endif + // Always return true, because abort API is not implemented. + return true; } -ProcessGroup::ProcessGroup( - const c10::intrusive_ptr<::c10d::Store>& store, +void ProcessGroupNCCL::WorkNCCL::abort() { + // Abort all communicators of this work + ncclComm_->ncclCommAbort(); + + ncclCommDevIdxMapMutex.lock(); + ncclCommDevIdxMap.erase(ncclComm_); + ncclCommDevIdxMapMutex.unlock(); +} + +ProcessGroupNCCL::CUDAEventCache::CUDAEventCache() {} + +// CUDA event is used to record the start/end of one Work. +// Instead of let the CUDA event gets destroyed, we now reuse it after the Work +// has been erased from workMetaList_. +// This is to avoid the potential deadlock caused by CudaEventDestroy. +std::shared_ptr ProcessGroupNCCL::CUDAEventCache::create( + bool timing) { + auto deleter = [this, timing](at::cuda::CUDAEvent* event) { + std::lock_guard lock(this->cacheMutex_); + this->eventsArray_[timing ? 1 : 0].push_back(event); + }; + at::cuda::CUDAEvent* event = nullptr; + { + std::lock_guard lock(cacheMutex_); + auto events = eventsArray_[timing ? 1 : 0]; + if (!events.empty()) { + event = events.back(); + events.pop_back(); + } + } + if (!event) { + event = new at::cuda::CUDAEvent( + timing ? cudaEventDefault : cudaEventDisableTiming); + } + return std::shared_ptr(event, std::move(deleter)); +} + +ProcessGroupNCCL::CUDAEventCache& ProcessGroupNCCL::CUDAEventCache::get() { + static ProcessGroupNCCL::CUDAEventCache cache; + return cache; +} + +static std::atomic process_group_id = 0; + +constexpr const char* MULTI_DEVICE_ERROR_MSG = + "Expecting one tensor only but got multiple. You are probably using multiple " + "devices under one thread. The support for such usage has been deprecated. " + "For details, please refer to " + "https://pytorch.org/docs/stable/distributed.html#multi-gpu-collective-functions. " + "ProcessGroupNCCL continues supporting multi-process and multi-thread modes."; + +ProcessGroupNCCL::ProcessGroupNCCL( + const c10::intrusive_ptr& store, int rank, int size, c10::intrusive_ptr options) - : store_(store), - rank_(rank), - size_(size), - options_(std::move(options)), - backendType_(strToBackendType(options_->backend)), - dist_debug_level_(debug_level()) { - C10_LOG_API_USAGE_ONCE("c10d.process_group"); + : Backend(rank, size), + store_(store), + options_(options), + ncclCommCounter_(0), + traceKeyStart_(getTraceStartKey("NCCL", rank)), + traceKeyEnd_(getTraceEndKey("NCCL", rank)), + terminateProcessGroup_(false), + terminateHeartbeatMonitorThread_(false), + collectiveDebugInfoMode_(false), + local_id_(process_group_id++), + intraNodeComm_(initIntraNodeComm()) { + TORCH_CHECK_WITH( + ValueError, + at::cuda::getNumGPUs() != 0, + "ProcessGroupNCCL is only supported with GPUs, no GPUs found!"); + + // getNcclVersion needs to get called before launching threads which can + // potentially call getenv. getNcclVersion internally calls setenv to set some + // environment variables from config file, which can race with getenv from + // other threads and cause segfaults. + const auto ncclVersion = getNcclVersion(); + this->setGroupUid(options_->group_name); + this->localDeviceCount_ = at::cuda::getNumGPUs(); + logPrefix_ = createLogPrefix(); + blockingWait_ = getCvarBool(TORCH_NCCL_BLOCKING_WAIT, false); + asyncErrorHandling_ = static_cast( + getCvarInt(TORCH_NCCL_ASYNC_ERROR_HANDLING, 3 /*SkipCleanUp*/)); + desyncDebug_ = getCvarBool(TORCH_NCCL_DESYNC_DEBUG, false) || + (dist_debug_level_ >= DebugLevel::Detail); + rethrowCUDAErrors_ = getCvarBool(TORCH_NCCL_RETHROW_CUDA_ERRORS, true); + // TODO, we should either deprecate TORCH_NCCL_DUMP_ON_TIMEOUT + // or change its name to reflect that dump happens on exception including + // both timeout and other errors. + dumpOnTimeoutOrEx_ = getCvarBool(TORCH_NCCL_DUMP_ON_TIMEOUT, false) || + (dist_debug_level_ >= DebugLevel::Detail); + sleepAfterException_ = getCvarBool(TORCH_NCCL_SLEEP_AFTER_EXCEPTION, false); + // logging C++ stack isn't safe. Introduce a variable to control it. + logCppStackOnUncleanShutdown_ = + getCvarBool(TORCH_NCCL_LOG_CPP_STACK_ON_UNCLEAN_SHUTDOWN, true); + enableNanCheck_ = getCvarBool(TORCH_NCCL_NAN_CHECK, false); + heartbeat_ = 1ULL; + monitorThreadEnabled_.store(getCvarBool(TORCH_NCCL_ENABLE_MONITORING, true)); + cudaEventCacheEnabled_.store(getCvarBool(TORCH_NCCL_CUDA_EVENT_CACHE, false)); + heartbeatTimeoutInSec_ = + getCvarInt(TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC, 60 * 8 /*8 Mins*/); + waitTimeoutDumpInMilSec_ = + getCvarInt(TORCH_NCCL_WAIT_TIMEOUT_DUMP_MILSEC, 60 * 1000 /*60 Sec*/); + coordCheckIntervalMilSec_ = getCvarInt(TORCH_NCCL_COORD_CHECK_MILSEC, 1000); + ncclTraceBufferSize_ = getCvarInt(TORCH_NCCL_TRACE_BUFFER_SIZE, 0); + enableCollecticeHashDebug_ = (dist_debug_level_ >= DebugLevel::Detail); + // store_ usually is wrapped with PrefixStore and the prefix is different + // across different ProcessGroupNCCL(PG) instances. We need to get the + // underlying non-PrefixStore for sharing global information shared across + // different PGs. + PrefixStore* prefixStore = dynamic_cast(store_.get()); + globalStore_ = + prefixStore ? prefixStore->getUnderlyingNonPrefixStore() : store_; +#ifdef ENABLE_NCCL_ERROR_CHECKING + enableTiming_.store( + getCvarBool(TORCH_NCCL_ENABLE_TIMING, false) || desyncDebug_); +#endif + avoidRecordStreams_ = getCvarBool(TORCH_NCCL_AVOID_RECORD_STREAMS, false); +#ifdef NCCL_HAS_COMM_REGISTER + useTensorRegisterAllocatorHook_ = + getCvarBool(TORCH_NCCL_USE_TENSOR_REGISTER_ALLOCATOR_HOOK, false); + if (c10::cuda::CUDACachingAllocator::CUDAAllocatorConfig:: + expandable_segments()) { + useTensorRegisterAllocatorHook_ = false; + LOG(INFO) + << logPrefix() + << "disables TORCH_NCCL_USE_TENSOR_REGISTER_ALLOCATOR_HOOK because it is not compatible with CUDA allocator expandable segments mode."; + } +#endif + + if (blockingWait_) { + if (asyncErrorHandling_ != NoHandling || desyncDebug_) { + LOG(INFO) + << logPrefix() << "TORCH_NCCL_BLOCKING_WAIT and " + << "TORCH_NCCL_ASYNC_ERROR_HANDLING|TORCH_NCCL_DESYNC_DEBUG" + << "should not both be enabled. " + << "Only TORCH_NCCL_BLOCKING_WAIT is being used in this process."; + asyncErrorHandling_ = NoHandling; + desyncDebug_ = false; + } + } else { + if (desyncDebug_ && asyncErrorHandling_ == NoHandling) { + LOG(INFO) + << logPrefix() + << "TORCH_NCCL_DESYNC_DEBUG and TORCH_NCCL_ASYNC_ERROR_HANDLING " + << "must both be enabled. " + << "Enabling TORCH_NCCL_ASYNC_ERROR_HANDLING."; + asyncErrorHandling_ = SkipCleanUp; + } + } + +#ifdef ENABLE_NCCL_ERROR_CHECKING + ncclCommWatchdogThread_ = + std::thread(&ProcessGroupNCCL::ncclCommWatchdog, this); +#endif + + init(); + const std::string OFF = "OFF"; + std::string torch_distributed_debug = + getCvarString({"TORCH_DISTRIBUTED_DEBUG"}, OFF.c_str()); + LOG(INFO) << logPrefix() << "ProcessGroupNCCL initialization options: " + << "size: " << size << ", global rank: " << globalRank() + << ", TIMEOUT(ms): " << options_->timeout.count() + << ", USE_HIGH_PRIORITY_STREAM: " + << options_->is_high_priority_stream + << ", SPLIT_FROM: " << options_->split_from + << ", SPLIT_COLOR: " << options_->split_color + << ", PG Name: " << options_->group_name; + + LOG(INFO) << logPrefix() << "ProcessGroupNCCL environments: " + << "NCCL version: " << ncclVersion + << ", TORCH_NCCL_ASYNC_ERROR_HANDLING: " << asyncErrorHandling_ + << ", TORCH_NCCL_DUMP_ON_TIMEOUT: " << dumpOnTimeoutOrEx_ + << ", TORCH_NCCL_WAIT_TIMEOUT_DUMP_MILSEC: " + << waitTimeoutDumpInMilSec_ + << ", TORCH_NCCL_DESYNC_DEBUG: " << desyncDebug_ + << ", TORCH_NCCL_ENABLE_TIMING: " << enableTiming_.load() + << ", TORCH_NCCL_BLOCKING_WAIT: " << blockingWait_ + << ", TORCH_DISTRIBUTED_DEBUG: " << torch_distributed_debug +#ifdef NCCL_HAS_COMM_REGISTER + << ", TORCH_NCCL_USE_TENSOR_REGISTER_ALLOCATOR_HOOK: " + << useTensorRegisterAllocatorHook_ +#endif + << ", TORCH_NCCL_ENABLE_MONITORING: " + << monitorThreadEnabled_.load() + << ", TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC: " << heartbeatTimeoutInSec_ + << ", TORCH_NCCL_TRACE_BUFFER_SIZE: " << ncclTraceBufferSize_ + << ", TORCH_NCCL_COORD_CHECK_MILSEC: " << coordCheckIntervalMilSec_ + << ", TORCH_NCCL_NAN_CHECK: " << enableNanCheck_ + << ", TORCH_NCCL_CUDA_EVENT_CACHE: " << cudaEventCacheEnabled_ + << ", TORCH_NCCL_LOG_CPP_STACK_ON_UNCLEAN_SHUTDOWN: " + << logCppStackOnUncleanShutdown_; + + if (options_->global_ranks_in_group.empty()) { + this->globalRankStart = 0; + } else { + this->globalRankStart = options_->global_ranks_in_group[0]; + } + + if (options_->global_ranks_in_group.empty()) { + this->globalRankStride = 1; + } else if (options_->global_ranks_in_group.size() == 1) { + this->globalRankStride = 0; + } else { + bool ranksAreStrided = true; + int startRank = options_->global_ranks_in_group[0]; + int stride = + options_->global_ranks_in_group[1] - options_->global_ranks_in_group[0]; + for (std::vector::size_type i = 0; + i < options_->global_ranks_in_group.size(); + i++) { + if (options_->global_ranks_in_group[i] != startRank + i * stride) { + ranksAreStrided = false; + break; + } + } + + if (ranksAreStrided) { + this->globalRankStride = options_->global_ranks_in_group[1] - + options_->global_ranks_in_group[0]; + } else { + this->globalRankStride = -1; + } + } + + // Attach hooks to cache allocator to trigger the hooks whenever a traced + // action is called. In the following hooks, we register a newly allocated + // segment when SEGMENT_ALLOC action occurs, and deregister a segment when + // SEGMENT_FREE action occurs. + // We attach hooks only once at the first PG creation. + // Attaching hooks fails if CUDACachingAllocator is not initialized, so + // lazyInitCUDA is called (and is a no-op if CUDA is already initialized). + if (useTensorRegisterAllocatorHook_ && !allocatorHooksAttached) { + at::globalContext().lazyInitCUDA(); + c10::cuda::CUDACachingAllocator::attachAllocatorTraceTracker( + &cacheAllocatorRegisterHook); + c10::cuda::CUDACachingAllocator::attachAllocatorTraceTracker( + &cacheAllocatorDeregisterHook); + allocatorHooksAttached = true; + } +} + +void ProcessGroupNCCL::eagerConnectSingleDevice(at::Device device) { + const auto key = getKeyFromDevice(device); + LOG(INFO) << logPrefix() << "Eagerly connecting nccl backend with device " + << device; + getNCCLComm(key, device, OpType::ALLREDUCE); +} + +void ProcessGroupNCCL::performNocolorSplit(at::Device device) { + // If our backend doesn't support splitting, this is a no-op for + // ranks not in the new subgroup (and ranks that would be in it will + // just use a new communicator rather than split). +#ifdef NCCL_HAS_COMM_SPLIT + const auto key = getKeyFromDevice(device); + LOG(INFO) << logPrefix() << "Performing nocolor split on backend device " + << device << ", key " << key << ", i am " << this; + auto comm = getNCCLComm(key, device, OpType::ALLREDUCE); + NCCLComm::split( + comm.get(), + NCCL_SPLIT_NOCOLOR, + rank_, + options_->config, + options_->global_ranks_in_group); +#endif } -ProcessGroup::ProcessGroup(int rank, int size) - : rank_(rank), size_(size), backendType_(BackendType::UNDEFINED) {} +c10::intrusive_ptr ProcessGroupNCCL:: + initIntraNodeComm() { + using IntraNodeComm = intra_node_comm::IntraNodeComm; + if (!IntraNodeComm::isEnabled()) { + return nullptr; + } + auto prefixStore = c10::make_intrusive("IntraNodeComm", store_); + auto comm = c10::make_intrusive(prefixStore, rank_, size_); + if (comm->rendezvous()) { + return comm; + } else { + return nullptr; + } +} -ProcessGroup::~ProcessGroup() = default; +void ProcessGroupNCCL::setSequenceNumberForGroup() { +} // NCCL just starts sequence numbers at 0. -void ProcessGroup::init() { - C10_LOG_API_USAGE_ONCE( - fmt::format("c10d.process_group_{}", getBackendName())); +uint64_t ProcessGroupNCCL::getSequenceNumberForGroup() { + return seqCollective_; } -const std::string& ProcessGroup::getGroupName() const { - TORCH_CHECK(!deviceTypeToBackend_.empty(), "ProcessGroup name not set"); - return deviceTypeToBackend_.begin()->second->getGroupUid(); +void ProcessGroupNCCL::registerOnCompletionHook( + std::function)>&& hook) { + TORCH_CHECK_WITH( + DistBackendError, + onCompletionHook_ == nullptr, + "ProcessGroupNCCL OnCompletion hook already registered"); + + TORCH_CHECK_WITH( + ValueError, + enableTiming_.load(), + "ProcessGroupNCCL OnCompletion hook requires recording start and end " + "events which require setting TORCH_NCCL_ENABLE_TIMING environment variable. " + "This is only available for NCCL version >= 2.4."); + onCompletionHook_ = std::move(hook); + onCompletionHookThread_ = std::thread(&ProcessGroupNCCL::runHookLoop, this); } -void ProcessGroup::setGroupName(const std::string& name) { - for (auto& kv : deviceTypeToBackend_) { - kv.second->setGroupUid(name); +// must release GIL when calling this method +void ProcessGroupNCCL::waitForPendingWorks() { + // Reasoning about hook completion: + // 1. waitForPendingWorks should be called after user code has finished + // calling + // all collectives. This means, when we got here, all of the collectives + // are either in workMetaList_ or has been erased from workMetaList_. + // 2. The watchdog thread grabs both locks to move Work object from the + // workMetaList_ to the completedWorkList_, and the hook thread only erases + // a Work object after the hook is returned. Therefore, after user code + // calls a collective, its Work object is either in workMetaList_ or in + // completedWorkList_ before it finishes. + // 3. We have three threads and two locks. + // a. main thread (this function) grabs two locks atomically + // b. watchdog thread (watchdogHandler function) always grabs + // workMetaListMutex_ + // first and then grabs completedWorkListMutex_. + // c. hook thread (runHookLoop function) only grabs + // completedWorkListMutex_. Therefore, locks are always acquired in the + // same order and hence no deadlocks. + while (true) { + { + std::lock(workMetaListMutex_, completedWorkListMutex_); + std::lock_guard lockWork(workMetaListMutex_, std::adopt_lock); + std::lock_guard lockHook( + completedWorkListMutex_, std::adopt_lock); + + if (workMetaList_.empty() && completedWorkList_.empty()) { + return; + } + } + + std::this_thread::sleep_for( + std::chrono::milliseconds(kWatchdogThreadSleepMillis)); } } -const std::string& ProcessGroup::getGroupDesc() const { - return pg_desc_; +void ProcessGroupNCCL::enableCollectivesTiming() { + enableTiming_.store(true); +} + +void ProcessGroupNCCL::waitForFutureOrTimeout( + std::future& fut, + const std::chrono::milliseconds& timeOutMilSec, + const std::string& futDescription, + bool throwException, + bool log) { + std::string errorMsg; + + ::c10d::C10dLoggingData data; + if (log) { + data.integers["pg_id"] = local_id_; + data.integers["rank"] = rank_; + data.integers["global_rank"] = globalRank(); + data.strings["flight_recorder_version"] = c10d::version_val_str; + } + + TORCH_CHECK(fut.valid(), "Expected a valid future"); + std::future_status status = fut.wait_for(timeOutMilSec); + if (status == std::future_status::ready) { + // Calling .get() will re-raise any exception from the future, and we don't + // care about the retval + try { + bool result = fut.get(); + if (result) { + LOG(INFO) << logPrefix() + << "future is successfully executed for: " << futDescription; + if (log) { + data.strings["status"] = "SUCCESS"; + } + } + } catch (const std::exception& e) { + errorMsg = c10::str( + logPrefix(), + "Exception thrown when waiting for future ", + futDescription, + ": ", + e.what()); + if (log) { + data.strings["status"] = "EXCEPTION"; + data.strings["exception"] = e.what(); + } + LOG(ERROR) << errorMsg; + } catch (...) { + errorMsg = c10::str( + logPrefix(), + "Unknown exception thrown when waiting for future ", + futDescription); + if (log) { + data.strings["status"] = "EXCEPTION"; + data.strings["exception"] = "Unknown exception"; + } + LOG(ERROR) << errorMsg; + } + } else { + errorMsg = c10::str( + logPrefix(), + "Future for ", + futDescription, + " timed out after ", + timeOutMilSec.count(), + " ms"); + data.strings["status"] = "TIMEOUT"; + LOG(ERROR) << errorMsg; + } + if (log) { + auto logger = c10d::C10dLogger::getLogger(); + if (logger) { + logger->log(data); + } + } + if (throwException && !errorMsg.empty()) { + C10_THROW_ERROR(DistBackendError, errorMsg); + } } -void ProcessGroup::setGroupDesc(const std::string& name) { - pg_desc_ = name; - // Also set the group desc for all backends - for (auto& kv : deviceTypeToBackend_) { - kv.second->setGroupDesc(name); +void ProcessGroupNCCL::abortCommsFromMap( + std::unordered_map>& ncclCommsMap, + std::optional abortReason) { + // The process may control multiple devices, loop through the communicators on + // each device + for (auto& it : ncclCommsMap) { + auto& devName = it.first; + auto& ncclComm = it.second; + at::cuda::OptionalCUDAGuard gpuGuard; + at::DeviceIndex deviceIndex = getIndexFromDeviceKey(devName); + if (deviceIndex >= 0) { + // For P2P comms, the deviceIndex could be -1 (invalid), as the keys in + // the map could be non deviceIndex, but rank to rank numbers. So we + // indeed need to check if deviceIndex >= 0 + // TODO: fix `getIndexFromDeviceKey` or fix `DeviceKey` + gpuGuard.set_index(deviceIndex); + } + LOG(INFO) << logPrefix() << "ProcessGroupNCCL destroying ncclComm_ " + << ncclComm->ncclComm_ << " on CUDA device: " << devName; + ncclComm->ncclCommAbort(abortReason); + // Note that we don't remove the aborted communicators from the + // cache. The reason is that if we do remove the communicator + // from the cache, it is possible that a new collective operation + // calls `ncclCommInitRank` to create a new communicator whereas + // other ranks might have failed/timed out and didn't enter + // `ncclCommInitRank`. As a result, when there is a failure on + // a communicator the application receives an exception and its + // their responsibility to destroy the process group and recreate + // it to recover from errors. + + LOG(INFO) << logPrefix() << "ProcessGroupNCCL destroyed " + << " communicator on CUDA device: " << devName; } } -void ProcessGroup::enableCollectivesTiming() { - for (auto& kv : deviceTypeToBackend_) { - kv.second->enableCollectivesTiming(); +// Abort all communicators on this rank +bool ProcessGroupNCCL::abort(std::optional abortReason) { + // This will log counter for how long the abort actually takes. + STATIC_SCOPED_WAIT_COUNTER(pytorch.ProcessGroupNCCL__abort); + // Remove record from global ncclCommDevIdxMapMutex before aboarting, + // so that a new cache segment would not register to already aborded + // communicators. Note that ncclCommDevIdxMap is a global container which may + // contain other PG's communicators, thus we need to only erase communicators + // for the current PG. + ncclCommDevIdxMapMutex.lock(); + for (auto& it : devNCCLCommMap_) { + auto& ncclComm = it.second; + ncclCommDevIdxMap.erase(ncclComm); } + ncclCommDevIdxMapMutex.unlock(); + + std::lock_guard lock(mutex_); + abortCommsFromMap(devNCCLCommMap_, abortReason); + abortCommsFromMap(inInitializationCommMap_, abortReason); + return true; } -void ProcessGroup::release_resources() { - store_.reset(); - deviceTypeToBackend_.clear(); - backendTypeToBackend_.clear(); +void ProcessGroupNCCL::shutdown(std::optional reason) { + // Don't join threads here since the purpose of this method is to abort all + // communicators and signal the threads to exit. Joining on the threads could + // potentially block and hence avoid it in this method. + terminateProcessGroup_.store(true); + workMetaListCV_.notify_one(); + + // lauch abort asynchrounously and wait for it to complete or timeout + LOG(INFO) << logPrefix() + << "Launching ProcessGroupNCCL abort asynchrounously."; + std::future fut = std::async( + std::launch::async, [this, &reason]() { return this->abort(reason); }); + + waitForFutureOrTimeout( + fut, options_->timeout, "ProcessGroup abort", true, false); + LOG(INFO) << logPrefix() << "ProcessGroupNCCL aborts successfully."; + + // We need to wait for abort to finish before we can safely shut down + // heartbeat monitoring thread. + terminateHeartbeatMonitorThread_.store(true); + monitorWakeUpCV_.notify_one(); } -} // namespace c10d +ProcessGroupNCCL::~ProcessGroupNCCL() { + LOG(INFO) << logPrefix() << "ProcessGroupNCCL destructor entered."; + + if (!terminateProcessGroup_.load()) { + if (rank_ % localDeviceCount_ == 0) { + TORCH_WARN_ONCE( + "WARNING: process group has NOT been destroyed before we destruct ProcessGroupNCCL. ", + "On normal program exit, the application should call destroy_process_group to ", + "ensure that any pending NCCL operations have finished in this process. " + "In rare cases this process can exit before this point and block the progress of " + "another member of the process group. This constraint has always been present, " + " but this warning has only been added since PyTorch 2.4"); + } + // If user haven't explicitly destroy/shutdown process group, destructor + // needs to do so + shutdown(); + } + + // Wait for all threads to finish before returning +#ifdef ENABLE_NCCL_ERROR_CHECKING + if (ncclCommWatchdogThread_.joinable()) { + ncclCommWatchdogThread_.join(); + LOG(INFO) << logPrefix() << "ProcessGroupNCCL watchdog thread joined."; + } + if (ncclHeartbeatMonitorThread_.joinable()) { + ncclHeartbeatMonitorThread_.join(); + LOG(INFO) << logPrefix() + << "ProcessGroupNCCL heart beat monitor thread joined."; + } +#endif + if (onCompletionHookThread_.joinable()) { + onCompletionHookThread_.join(); + LOG(INFO) << logPrefix() + << "ProcessGroupNCCL onCompletionHookThread thread joined."; + } +} + +bool ProcessGroupNCCL::dumpDebuggingInfo() { + // Serialize all calls to this function to avoid corrupting data, but allow + // multiple calls in one runtime. User is responsible for preserving the + // output file from an earlier call before a later call overwrites it. + static std::mutex writeDebugInfoMutex; + std::lock_guard lock(writeDebugInfoMutex); + LOG(ERROR) << logPrefix() << "ProcessGroupNCCL preparing to dump debug info."; + if (ncclTraceBufferSize_ > 0) { + // We dump nccl trace into local disk by default and users can register + // their customized writer by inheriting `DebugInfoWriter` via + // `registerDebugInfoWriter`. + auto ncclTrace = dump_nccl_trace(true, true, false); + DebugInfoWriter& writer = DebugInfoWriter::getWriter(globalRank()); + LOG(INFO) << logPrefix() << "ProcessGroupNCCL dumping nccl trace to " + << writer.getWriterTarget(); + writer.write(ncclTrace); + return true; + } + return false; +} + +void ProcessGroupNCCL::terminateProcess(std::string errMsg) { + // Logging with `FATAL`, after errMsg printed, it calls `std::abort()` + // to terminate the program execution. + LOG(FATAL) << logPrefix() << errMsg; +} + +int computeDeltaMS( + std::chrono::time_point start, + std::chrono::time_point end) { + return std::chrono::duration_cast(end - start) + .count(); +} + +std::string ProcessGroupNCCL::getNCCLWatchdogTimeoutErrorMsg( + const std::string& extraMsg) { + return c10::str( + logPrefix(), + "Received a dump signal due to a collective timeout from ", + extraMsg, + " and we will try our best to dump the debug info. ", + "Last enqueued NCCL work: ", + pgStatus_->lastEnqueuedSeq, + ", last completed NCCL work: ", + pgStatus_->lastCompletedSeq, + ".", + "This is most likely caused by incorrect usages of collectives, e.g., wrong ", + "sizes used across ranks, the order of collectives is not same for all ranks ", + "or the scheduled collective, for some reason, didn't run. Additionally, ", + "this can be caused by GIL deadlock or other reasons such as network errors or ", + "bugs in the communications library (e.g. NCCL), etc. "); +} + +std::string ProcessGroupNCCL::getNCCLWatchdogTimeoutExitMsg( + const std::string& exitReason) { + return c10::str( + logPrefix(), + "Terminating the process after attempting to dump debug info, due to ", + exitReason, + "."); +} + +void ProcessGroupNCCL::heartbeatMonitor() { + c10::setThreadName("pt_nccl_heartbt"); + + uint64_t heartBeatCounter = 0ULL; + std::string errorMsg; + std::string exitReason; + bool checkDumpSignal = (dumpOnTimeoutOrEx_ && local_id_ == 0); + int monitorPollInterval = checkDumpSignal ? coordCheckIntervalMilSec_ + : heartbeatTimeoutInSec_ * 1000; + auto lastTimePollStore = std::chrono::steady_clock::now(); + auto lastTimeHeartBeatCheck = std::chrono::steady_clock::now(); + std::optional dumpPipe = std::nullopt; + if (local_id_ == 0) { + // DumpPipe is one per-trainer process, and its convenient to name them + // after 'global' ranks in the system, So we assume processgroup (uid)==0 is + // the global PG and has globally unique rank ids across trainers. + dumpPipe.emplace(rank_); + } + while (true) { + // This won't have any lock since this lock is only used here. + // Please be aware that mutex `monitorMutex_` should not be used + // somewhere else to avoid the deadlock. + std::unique_lock lock(monitorMutex_); + if (monitorWakeUpCV_.wait_for( + lock, std::chrono::milliseconds(monitorPollInterval), [&] { + return terminateHeartbeatMonitorThread_.load(); + })) { + // For the normal complete or user interception, monitorWakeUpCV_ + // will get notified, we early return and exit heartbeatMonitor. + return; + } + auto currentTime = std::chrono::steady_clock::now(); + + // We put extra functionality in the thread for the default PG (aka, + // local_id_=0) because the signal is same across different PGs. We only + // need to run once per process to avoid duplicate things performed in too + // many separate threads. For example, we check a global flag on the + // TCPStore periodically to see if any PG on any rank observed a timeout and + // signaled peers to dump debugging info, and we avoid hammering the + // TCPStore from all PGs on the same rank. + if (checkDumpSignal) { + // There are two scenarios where monitor thread will dump on timeout: + // 1. The current rank is the first to observe a timeout in watchdog. + // (shouldDump_ was set to true by the watchdog thread). + // 2. Other ranks detected the timeout and signal the current rank to + // dump. In addtion, monitor threads will dump if watchdog threads has no + // heartbeat or dumpPipe is not empty. + if (shouldDump_.load()) { + errorMsg = getNCCLWatchdogTimeoutErrorMsg("this local rank"); + exitReason = "collective timeout or exception"; + break; + } + // We poll store to see if some ranks have flagged a timeout when + // we haven't polled for `heartbeat_timeout` seconds and there haven't + // any work added or removed for `watchdog_timeout` seconds. + if (computeDeltaMS(lastWorkListUpdateTime_, currentTime) >= + kWatchdogThreadSleepMillis && + computeDeltaMS(lastTimePollStore, currentTime) >= + coordCheckIntervalMilSec_) { + lastTimePollStore = currentTime; + // Wrap globalStore_->check() in a try-catch block to avoid crashing if + // the store is not available. + bool checkExceptionDump = false; + try { + checkExceptionDump = + globalStore_->check({std::string(EXCEPTION_DUMP)}); + } catch (const std::exception& e) { + LOG(WARNING) + << logPrefix() + << "Failed to check the \"should dump\" flag on TCPStore, " + << "(maybe TCPStore server has shut down too early), with error: " + << e.what(); + // We give up for now assuming TCPStore has been torn down. + return; + } + + if (checkExceptionDump) { + int timeOutRank = -1; + if (!shouldDump_.load()) { + LOG(ERROR) + << logPrefix() + << "Observed flight recorder dump signal from another rank via TCPStore."; + } + shouldDump_.store(true); + try { + auto vec = globalStore_->get(std::string(EXCEPTION_DUMP)); + TORCH_CHECK_WITH( + DistBackendError, + vec.size() == sizeof(int), + "Invalid size for the timeout rank ID"); + std::memcpy(&timeOutRank, vec.data(), vec.size()); + } catch (const std::exception& e) { + LOG(ERROR) << logPrefix() + << "Failed to get timeout rank ID from TCPStore." + << e.what(); + } + errorMsg = + getNCCLWatchdogTimeoutErrorMsg(c10::str(" rank ", timeOutRank)); + exitReason = "collective timeout or exception"; + break; + } + } + } + + if (computeDeltaMS(lastTimeHeartBeatCheck, currentTime) >= + heartbeatTimeoutInSec_ * 1000) { + // Check the heart beat of watchdog thread. + lastTimeHeartBeatCheck = currentTime; + auto heartbeat = heartbeat_.load(); + if (heartbeat != heartBeatCounter) { + heartBeatCounter = heartbeat; + } else { + shouldDump_.store(true); + // Watchdog heartbeat timeout. + errorMsg = c10::str( + logPrefix(), + "ProcessGroupNCCL's watchdog got stuck for ", + heartbeatTimeoutInSec_, + " seconds without making progress in monitoring enqueued collectives. ", + "This typically indicates a NCCL/CUDA API (e.g., CudaEventDestroy) hang blocking the watchdog, ", + "and could be triggered by another thread holding the GIL inside a ", + "CUDA api (for example, CudaEventDestroy), or other deadlock-prone behaviors.", + "If you suspect the watchdog is not actually stuck and a longer timeout would help, ", + "you can either increase the timeout (TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC) to a larger value " + "or disable the heartbeat monitor (TORCH_NCCL_ENABLE_MONITORING=0)." + "If either of aforementioned helps, feel free to file an issue to PyTorch about the short timeout " + "or false positive abort; otherwise, please attempt to debug the hang. "); + exitReason = "ProcessGroupNCCL watchdog hang"; + break; + } + } + // process a request to dump the trace. only PG uid 0 will respond to dump + // requests, but this is fine since all PG's feed into the same flight + // recorder and dump. After dump, the training should continue. + if (dumpPipe.has_value() && dumpPipe->shouldDump()) { + // best effort dump, not waiting for the dump here + std::future fut = std::async( + std::launch::async, [this]() { return this->dumpDebuggingInfo(); }); + } + } + LOG(ERROR) << errorMsg; + + auto& cpp_dumper = get_cpp_trace_dumper(); + if (logCppStackOnUncleanShutdown_ && cpp_dumper.has_value()) { + LOG(INFO) << "Dumping c++ stacktraces:"; + cpp_dumper.value()([](const std::string& line) { LOG(INFO) << line; }); + } + + if (checkDumpSignal && shouldDump_.load()) { + // Store debug info to storage if no other thread does it. (By default to + // local disk) + std::future asyncDebugDump = std::async( + std::launch::async, [this]() { return this->dumpDebuggingInfo(); }); + + // wait for the dump until timeout - log data + waitForFutureOrTimeout( + asyncDebugDump, + std::chrono::milliseconds(waitTimeoutDumpInMilSec_), + "Flight recorder dump in heartbeatMonitor", + false, + true); + } + + if (get_gil_checker() != nullptr) { + auto fut = launchAsyncGilCheck(); + auto kGilCheckTimeout = std::chrono::milliseconds(300); + auto futStatus = fut.wait_for(kGilCheckTimeout); + if (futStatus != std::future_status::ready) { + TORCH_CHECK( + futStatus != std::future_status::deferred, + "Expected the future to have been launched eagerly."); + LOG(ERROR) + << "Could not acquire GIL within 300 ms on exit, possible GIL induced hang"; + } + } else { + LOG(INFO) + << "GIL checker was not registered, perhaps this is a no-python build?"; + } + + // There are two possible cases for the watchdog thread exit: + // Case one: desync report runs quickly, and it follows the step: + // collective timeout -> desync -> exception handling -> destructors + // -> set terminateHeartbeatMonitorThread_ -> notify monitorWakeUpCV_. + // So the code either early returns above or will skip the sleep below. + // Case two: desync might be slow or get stuck. Or we get stuck in + // destructors, we will sleep for some time before calling std::abort() to + // kill the whole process. + if ((terminateProcessGroup_.load() || collectiveDebugInfoMode_.load() || + shouldDump_.load()) && + !terminateHeartbeatMonitorThread_.load()) { + // Leave another two mins for desync report generation or process group + // destroy. + std::this_thread::sleep_for(std::chrono::seconds(heartbeatTimeoutInSec_)); + LOG(INFO) << logPrefix() << "slept for " << heartbeatTimeoutInSec_ + << " waiting for desync report or process group destroy."; + } + + // At this point, we either already sleep for another `heartbeatTimeoutInSec_` + // or the thread has finished. Because we don't want to block the monitor + // thread, so We mark the thread detach and the dump of debug info becomes + // "best effort". If the process exit normally, marking it detach also makes + // sense because we don't really care about dumping the debug info. + + // We already log completion inside the thread, so it may not be necessary to + // check the return value here. We mainly use a future so we can exit early + // if done. + + if (!terminateHeartbeatMonitorThread_.load()) { + // Create a error message reported from MonitorThread, so + // we throw exception and make the whole process to be killed. + // TODO(fduwjj): After having a hang debug wiki, we need to update the wiki + // url here. + if (monitorThreadEnabled_.load()) { + terminateProcess(getNCCLWatchdogTimeoutExitMsg(exitReason)); + } else { + // Ideally we want to merge this one with the above one, but we are going + // to remove the kill switch for monitor thread soon, so we keep this one + // for now. + LOG(ERROR) + << logPrefix() + << "ProcessGroupNCCL monitor thread is disabled, but would have terminated the process" + << "after attempting to dump debug info, due to " << exitReason + << "."; + } + } +} + +void ProcessGroupNCCL::ncclCommWatchdog() { + c10::setThreadName("pt_nccl_watchdg"); + + try { + VLOG(2) << logPrefix() << "Process group watchdog thread started!"; + ncclHeartbeatMonitorThread_ = + std::thread(&ProcessGroupNCCL::heartbeatMonitor, this); + watchdogHandler(); + VLOG(2) << logPrefix() + << "Process group watchdog thread terminated normally"; + } catch (std::exception& e) { + if (std::string(e.what()).find("driver shutting down") != + std::string::npos) { + LOG(INFO) + << logPrefix() + << "main process destroyed cuda before watchdog loop exited, terminating watchdog." + << " (Watchdog caught exception: " << e.what(); + + } else { + // Append error message reported from watchdogHandler + const auto exitMsg = c10::str( + logPrefix(), + "Process group watchdog thread terminated with exception: ", + e.what()); + LOG(ERROR) << exitMsg; + if (C10_LIKELY(rethrowCUDAErrors_) || + !(std::string(e.what()).find("CUDA Error"))) { + // TODO(whc) clean up the rethrow - why is it stored in a class var and + // rethrown? + watchDogException_ = + std::make_exception_ptr(C10_BUILD_ERROR(DistBackendError, exitMsg)); + std::rethrow_exception(watchDogException_); + } + } + } catch (...) { + const auto exitMsg = c10::str( + logPrefix(), + "Process group watchdog thread terminated with exception: unknown"); + LOG(ERROR) << exitMsg; + watchDogException_ = + std::make_exception_ptr(C10_BUILD_ERROR(DistBackendError, exitMsg)); + std::rethrow_exception(watchDogException_); + } +} + +void ProcessGroupNCCL::logWorkStart(WorkNCCL& work) { + if (work.startTraceUpdated_) + return; + + if (terminateProcessGroup_.load() || storeError_) + return; + + work.startTraceUpdated_ = true; + storeError_ = !c10d::traceUpdate( + store_, traceKeyStart_, work.seq_, opTypeToString(work.opType_)); +} + +void ProcessGroupNCCL::logWorkEnd(WorkNCCL& work) { + if (terminateProcessGroup_.load() || storeError_) + return; + + // In case the start of the work hasn't been logged + if (!work.startTraceUpdated_) { + logWorkStart(work); + } + + storeError_ = !c10d::traceUpdate( + store_, traceKeyEnd_, work.seq_, opTypeToString(work.opType_)); +} + +std::string ProcessGroupNCCL::getNCCLWatchdogDebugInfo() { + return retrieveDesyncReport(store_, "NCCL", rank_, size_); +} + +// We want to have both PG ID and global unique ID (guid) for the logging +// prefix. PG ID records how many ProcessGroupNCCL objects were created on a +// specific rank and is a stable index across ranks, which lets users reason +// about, for example, the second PG we initialized on this rank is for FSDP, +// and corresponds with PG ID = 1 on other ranks as well. Unlike PG ID, guid (or +// group name) is a global unique ID across ranks. The guid is either a hash of +// all the ranks in the group or a counter of how many times +// `_process_group_name` is called, essentially it means how many times we +// have PGs users have created. Before using split_group, even if +// we are creating a new sub-PG, all ranks have to call the API at the same +// time, and this makes `group_name` a unique identifier for a group (PG). +std::string ProcessGroupNCCL::createLogPrefix() const { + if (!pg_desc_.empty() && pg_desc_ != "undefined") { + return c10::str( + "[PG ID ", + local_id_, + " PG GUID ", + pg_uid_, + "(", + pg_desc_, + ") Rank ", + rank_, + "] "); + } + return c10::str( + "[PG ID ", local_id_, " PG GUID ", pg_uid_, " Rank ", rank_, "] "); +} + +const std::string& ProcessGroupNCCL::logPrefix() const { + return logPrefix_; +} + +const int& ProcessGroupNCCL::globalRank() const { + static int globalRank = rank_; + return globalRank; +} + +const std::vector& ProcessGroupNCCL::groupRanks() const { + if (options_->global_ranks_in_group.empty() && local_id_ == 0) { + static std::vector globalRanks(size_); + std::iota(globalRanks.begin(), globalRanks.end(), 0); + return globalRanks; + } + return options_->global_ranks_in_group; +} + +void ProcessGroupNCCL::addEphemeralTimeout( + const std::chrono::milliseconds& timeout) { + std::lock_guard timeoutLock(mtxTimeoutExtension_); + ephemeralTimeoutActive_ += timeout; +} + +bool ProcessGroupNCCL::verifyWorkTimeoutForTest( + const c10::intrusive_ptr work, + const std::chrono::milliseconds& timeout) { + // Since collective returns a c10d::Work, we need to cast it to WorkNCCL. + if (auto workNCCL = c10::dynamic_intrusive_pointer_cast(work)) { + // workNCCL is now a c10::intrusive_ptr + return workNCCL->opTimeout_ == timeout; + } + C10_THROW_ERROR( + DistBackendError, "Non c10d::WorkNCCL object returned from collective"); +} + +void ProcessGroupNCCL::watchdogHandler() { + bool done = false; + lastWorkListUpdateTime_ = std::chrono::steady_clock::now(); + auto lastStatusUpdateTime = std::chrono::steady_clock::now(); + std::list completedWorkList; + + while (!done || !terminateProcessGroup_.load()) { + std::unique_lock lock(workMetaListMutex_); + // We busy-poll the work vector every kWatchdogThreadSleepMillis + // milliseconds as long as the atomic is True. + workMetaListCV_.wait_for( + lock, + std::chrono::milliseconds(kWatchdogThreadSleepMillis), + [&]() -> bool { return terminateProcessGroup_.load(); }); + // Bump up heart beat by one. + heartbeat_++; + +// Some versions of GLOG support less-spammy version of LOG_EVERY_MS +// in which case we don't want to spam the logs. +#ifdef LOG_EVERY_MS + // Log the progress of this PG periodically + C10_LOG_EVERY_MS(INFO, kWorkStatusUpdatePeriodMs) << c10::str( + logPrefix(), + "NCCL Work update periodically: ", + "last enqueued NCCL work: ", + pgStatus_->lastEnqueuedSeq, + ", last completed NCCL work: ", + pgStatus_->lastCompletedSeq, + "."); +#endif + auto logger = ::c10d::C10dLogger::getLogger(); + if (logger && + computeDeltaMS( + lastStatusUpdateTime, std::chrono::steady_clock::now()) >= + kWorkStatusUpdatePeriodMs) { + ::c10d::C10dLoggingData data; + // logging integers + data.integers["pg_id"] = local_id_; + data.integers["rank"] = rank_; + data.integers["global_rank"] = globalRank(); + data.integers["last_enqueued_work"] = pgStatus_->lastEnqueuedSeq; + data.integers["last_started_work"] = pgStatus_->lastStartedSeq; + data.integers["last_completed_work"] = pgStatus_->lastCompletedSeq; + data.integers["last_enqueued_numel_in"] = pgStatus_->lastEnqueuedNumelIn; + data.integers["last_enqueued_numel_out"] = + pgStatus_->lastEnqueuedNumelOut; + data.integers["last_completed_numel_in"] = + pgStatus_->lastCompletedNumelIn; + data.integers["last_completed_numel_out"] = + pgStatus_->lastCompletedNumelOut; + // logging strings + data.strings["last_enqueued_work_name"] = pgStatus_->lastEnqueuedWorkName; + data.strings["last_started_work_name"] = pgStatus_->lastStartedWorkName; + data.strings["last_completed_work_name"] = + pgStatus_->lastCompletedWorkName; + data.strings["pg_name"] = pg_uid_; + data.strings["pg_desc"] = pg_desc_; + logger->log(data); + lastStatusUpdateTime = std::chrono::steady_clock::now(); + } + + for (auto it = workMetaList_.begin(); it != workMetaList_.end(); + /* no increment */) { + auto& work = *it; + // When terminateProcessGroup_ is true, communicators have already been + // aborted, So cannot check exception based on them. But watchdog needs to + // finish the check for the works that have already been enqueued to + // workMetaList_ + if (!terminateProcessGroup_.load()) { + work.checkAndSetException(); + } + bool timedOut = work.checkTimeout(); + + // If work hits an exception (either an error or timeout) + if (work.exception()) { + // log as soon as exception is detected + LOG(ERROR) << c10::str( + logPrefix(), + "Exception (either an error or timeout) detected by watchdog at work: ", + work.seq_, + ", last enqueued NCCL work: ", + pgStatus_->lastEnqueuedSeq, + ", last completed NCCL work: ", + pgStatus_->lastCompletedSeq, + "."); + // try to notify other ranks via global TCPStore to dump the flight + // recorder when a collective timeout or exception happens. Flight + // recorder behavior is independent of desync Debug. + if (dumpOnTimeoutOrEx_) { + try { + auto rank = globalRank(); + auto vec = std::vector( + reinterpret_cast(&rank), + reinterpret_cast(&rank) + sizeof(rank)); + globalStore_->set(std::string(EXCEPTION_DUMP), vec); + if (!shouldDump_.load()) { + LOG(ERROR) + << logPrefix() + << "Broadcasting flight-recorder dump signal to other processes via TCPStore."; + } + // signal the monitor thread on PG0 to start dumping + shouldDump_.store(true); + if (sleepAfterException_) { + // This sleep is used to give time for dumping before throwing + // exception + std::this_thread::sleep_for( + std::chrono::seconds(heartbeatTimeoutInSec_)); + LOG(INFO) << logPrefix() << "slept for " << heartbeatTimeoutInSec_ + << " giving time for flight recorder dumps to finish."; + } + } catch (const std::exception& e) { + LOG(ERROR) << logPrefix() + << "Failed to set dump signal in tcpstore. " + << "Error: " << e.what(); + } + } + + if (SHOULD_CLEAN_UP(asyncErrorHandling_)) { + // Abort work and corresponding communicators + work.abort(); + // PG level abort, which would abort all other communicators on this + // rank + abort(); + } + + // Report desync state in case of timeout + if (timedOut) { + LOG(ERROR) << c10::str( + logPrefix(), + "Timeout at NCCL work: ", + work.seq_, + ", last enqueued NCCL work: ", + pgStatus_->lastEnqueuedSeq, + ", last completed NCCL work: ", + pgStatus_->lastCompletedSeq, + "."); + if (desyncDebug_) { + try { + collectiveDebugInfoMode_.store(true); + auto desyncMsg = getNCCLWatchdogDebugInfo(); + LOG(ERROR) << logPrefix() << desyncMsg; + } catch (const std::exception& e) { + LOG(ERROR) + << logPrefix() + << "Failed to retrieve TORCH_NCCL_DESYNC_DEBUG report. " + << " Please file an issue. Error: " << e.what(); + } catch (...) { + LOG(ERROR) + << logPrefix() + << "Failed to rerieve TORCH_NCCL_DESYNC_DEBUG report with unknown error." + << " Please file an issue."; + } + } + } + // Throw exception + work.handleException(asyncErrorHandling_); + } + + // Work status logging for desync debug + if (desyncDebug_) { + if (work.isStarted()) { + logWorkStart(work); + } + if (work.isCompleted()) { + logWorkEnd(work); + } + } + + // a work could be started but not completed, so we should not update + // lastStartedSeq and lastStartedOpName if the work state is checked + // multiple times after the start + if (pgStatus_->lastStartedSeq < static_cast(work.seq_) && + work.isStarted()) { + pgStatus_->lastStartedSeq = work.seq_; + pgStatus_->lastStartedWorkName = opTypeToString(work.opType_); + } + + // Clean up completed work + if (work.isCompleted()) { + { + // Reset the timeout and first work if the work is completed. + std::lock_guard timeoutLock(mtxTimeoutExtension_); + if (work.ownedEphermeralTimeout_.count() > 0) { + ephemeralTimeoutActive_ -= work.ownedEphermeralTimeout_; + ephemeralTimeoutInflight_ -= work.ownedEphermeralTimeout_; + } + } + pgStatus_->lastCompletedSeq = work.seq_; + pgStatus_->lastCompletedWorkName = opTypeToString(work.opType_); + pgStatus_->lastCompletedNumelIn = work.numelIn_; + pgStatus_->lastCompletedNumelOut = work.numelOut_; + NCCLTraceBuffer::get()->retire_id(work.trace_id_, true); + if (onCompletionHook_) { + // Move Work object to completedWorkList_ to be consumed by the hook + // thread + { + const std::lock_guard lock(completedWorkListMutex_); + completedWorkList_.splice( + completedWorkList_.end(), workMetaList_, it++); + } + completedWorkListCV_.notify_one(); + } else { + it = workMetaList_.erase(it); + lastWorkListUpdateTime_ = std::chrono::steady_clock::now(); + } + at::cuda::CUDAGraph::dec_pending_event_queries(); + } else { + // Increment the iterator if the current WorkNCCL object is not + // completed. + ++it; + } + // Increment heartbeat after each work processed, + // in case processing is slowed down (but not hung) by cuda api contention + heartbeat_++; + } + done = workMetaList_.empty(); + } +} + +void ProcessGroupNCCL::runHookLoop() { + c10::setThreadName("pt_nccl_runhook"); + + bool done = false; + while (!done || !terminateProcessGroup_.load()) { + std::unique_lock lock(completedWorkListMutex_); + // We busy-poll the work vector every kWatchdogThreadSleepMillis + // milliseconds as long as the atomic is True. + completedWorkListCV_.wait_for( + lock, + std::chrono::milliseconds(kWatchdogThreadSleepMillis), + [&]() -> bool { + return !completedWorkList_.empty() || terminateProcessGroup_.load(); + }); + + try { + for (auto it = completedWorkList_.begin(); it != completedWorkList_.end(); + /* no increment */) { + const WorkNCCL& work = *it; + // Hook might grab GIL, unlock first to prevent deadlock + lock.unlock(); + + auto timeStarted = + std::chrono::system_clock::now() + + std::chrono::duration_cast( + work.workStartTime_ - std::chrono::steady_clock::now()); + onCompletionHook_(std::make_shared( + work.retrieveOpType(), // OpType + work.getSequencenumber(), // seq + timeStarted, // timeStarted + std::chrono::system_clock::now(), // timeFinished + std::chrono::duration( + work.getDuration()) // activeDuration + )); + + lock.lock(); + it = completedWorkList_.erase(it); + } + } catch (std::exception& e) { + if (std::string(e.what()).find("driver shutting down") != + std::string::npos) { + LOG(INFO) + << logPrefix() + << "main process destroyed cuda before runHookLoop exited, terminating runHookLoop." + << " (runHookLoop caught exception: " << e.what(); + + } else { + // PythonOnCompletionHook has already extracted Python exception message + // and wrapped it with a cpp one. So we no longer need to acquire GIL + // here. + const auto errorStr = c10::str( + "Caught exception on rank ", + rank_, + " while running onCompletion hook for ProcessGroupNCCL: ", + e.what(), + ". Aborting all communicators."); + + // No need to call abort() on WorkNCCL here as that collective has + // already finished successfully at this point. We just need to abort + // the process Abort all NCCL Communicators on this ProcessGroupNCCL + // instance. + abort(errorStr); + } + } + + // Lock is still acquired at this point + done = completedWorkList_.empty(); + } +} + +std::exception_ptr ProcessGroupNCCL::WorkNCCL::checkForNCCLErrors() { + return checkForNCCLErrorsInternal(ncclComm_); +} + +std::exception_ptr ProcessGroupNCCL::checkForNCCLErrors( + std::shared_ptr& ncclComm) { + return checkForNCCLErrorsInternal(ncclComm); +} + +std::exception_ptr ProcessGroupNCCL::checkForNCCLErrorsInternal( + std::shared_ptr& ncclComm) { + // Prioritize commFailureReason over checkForNcclError() result if + // commFailureReason is set. + auto commFailureReason = ncclComm->getNcclCommFailureReason(); + if (commFailureReason != std::nullopt) { + return std::make_exception_ptr(C10_BUILD_ERROR( + DistBackendError, + c10::str( + "NCCL communicator encountered error set by ProcessGroupNCCL: ", + *commFailureReason))); + } + ncclResult_t ncclAsyncErr = ncclComm->checkForNcclError(); + // When nonblocking mode is enabled by TORCH_NCCL_USE_COMM_NONBLOCKING, + // ncclInProgress could be returned when there are pending NCCL calls. + // In this case, no exception should be thrown +#ifdef NCCL_HAS_COMM_NONBLOCKING + // ncclInProgress is defined only if NCCL_HAS_COMM_NONBLOCKING is defined + if (ncclAsyncErr != ncclSuccess && ncclAsyncErr != ncclInProgress) { +#else + if (ncclAsyncErr != ncclSuccess) { +#endif + return std::make_exception_ptr(C10_BUILD_ERROR( + DistBackendError, + "NCCL error: " + ncclGetErrorWithVersion(ncclAsyncErr) + "\n" + + getNcclErrorDetailStr(ncclAsyncErr))); + } + + return nullptr; +} + +void ProcessGroupNCCL::broadcastUniqueNCCLID( + ncclUniqueId* ncclID, + bool isSingleP2POp, + const std::string& p2pKey, + int p2pRank) { + // For collective operations: + // For every NCCL communicator that we create we need to broadcast + // a unique ID from rank 0 to all other ranks. This broadcast is + // done by rank 0 setting a key in the store and all other ranks + // retrieving the contents of that key. A single process group + // may create multiple NCCL communicators, so we use a sequence + // number to differentiate between them. + // For single point-to-point operations: + // The sequence number will only be increased on 2 out of all the + // processes in a Process Group. So all following collective + // operations will see different sequence numbers which will cause + // runtime errors. To avoid that, use the src:target pair instead + // of sequence number for p2p communications. + + std::string storeKey; + if (!isSingleP2POp) { + storeKey = std::to_string(ncclCommCounter_++); + } else { + storeKey = p2pKey; + } + if (rank_ == 0 || (isSingleP2POp && p2pRank == 0)) { + auto vec = std::vector( + reinterpret_cast(ncclID), + reinterpret_cast(ncclID) + NCCL_UNIQUE_ID_BYTES); + store_->set(storeKey, vec); + } else { + try { + auto vec = store_->get(storeKey); + TORCH_CHECK_WITH( + DistBackendError, + vec.size() == NCCL_UNIQUE_ID_BYTES, + "Invalid size for ncclUniqueId"); + std::memcpy(ncclID, vec.data(), vec.size()); + } catch (const std::exception& e) { + std::string exceptionMsg = c10::str( + "[", + rank_, + "] is setting up NCCL communicator and " + "retrieving ncclUniqueId from [0] via c10d key-value store by key '", + storeKey, + "', but store->get('", + storeKey, + "') got error: "); + C10_THROW_ERROR( + DistBackendError, + exceptionMsg + e.what() + + ". This may indicate a possible application crash on rank 0 or a network set up issue."); + } catch (...) { + C10_THROW_ERROR( + DistBackendError, + c10::str( + "Unknown exception while [", + rank_, + "] is setting up NCCL communicator and " + "retrieving ncclUniqueId from [0] via c10d key-value store by key '", + storeKey, + "'", + ". This may indicate a possible application crash on rank 0 or a network set up issue.")); + } + } +} + +void ProcessGroupNCCL::destroyNCCLComms(const std::string& devNCCLCommMapKey) { + std::lock_guard lock(mutex_); + if (devNCCLCommMap_.find(devNCCLCommMapKey) == devNCCLCommMap_.end()) { + TORCH_INTERNAL_ASSERT( + false, + "Expected to find key ", + devNCCLCommMapKey, + " in NCCL communicator map."); + } + std::shared_ptr& ncclComm = devNCCLCommMap_[devNCCLCommMapKey]; + // ncclCommDestroy(comm->getNcclComm()) results in segfault when PG is being + // destroyed, so using ncclCommAbort here. + ncclComm->ncclCommAbort(); + // Remove communicators from the cache. + devNCCLCommMap_.erase(devNCCLCommMapKey); + // Clear used device indices. + usedDeviceIdxs_.clear(); + + ncclCommDevIdxMapMutex.lock(); + ncclCommDevIdxMap.erase(ncclComm); + ncclCommDevIdxMapMutex.unlock(); +} + +std::shared_ptr ProcessGroupNCCL::getNCCLComm( + const std::string& deviceKey, + at::Device& device, + OpType opType, + int p2pRank, + bool isSendRecvSelf) { + // Sanity check + if (deviceKey.empty()) { + C10_THROW_ERROR( + DistBackendError, + "Not able to create/get the NCCL Communicator since " + "the GPU devices are not known"); + } + if (bound_device_id_) { + if (*bound_device_id_ != device) { + LOG(ERROR) << logPrefix() << "Tensor found on device " << device + << " but backend constrained to " << *bound_device_id_; + C10_THROW_ERROR( + DistBackendError, + "Attempt to perform collective on tensor not on device passed to init_process_group"); + } + } + + usedDeviceIdxs_.insert(device.index()); + + { + std::lock_guard lock(mutex_); + if (devNCCLCommMap_.find(deviceKey) != devNCCLCommMap_.end()) { + // Reuse the cached communicator if there is one. + return devNCCLCommMap_[deviceKey]; + } + } + + // NCCL communicator not cached, create a new entry + std::shared_ptr ncclComm; + + // Create the unique NCCL ID and broadcast it + ncclUniqueId ncclID; + + // reset log prefix to include group_desc + logPrefix_ = createLogPrefix(); + +#ifdef NCCL_COMM_DESCRIPTION + // Pass process group name and description to NCCL communicator + std::string commDesc = pg_desc_ + ':' + pg_uid_; + options_->config.commDesc = strdup(commDesc.c_str()); +#endif + + // For batch_isend_irecv, ncclGroupStart() would be called upfront + bool batchP2P = ncclActiveGroupCounter_ > 0; + bool singleP2POp = isP2POp(opType, batchP2P); + + // Get the device index + auto deviceIndex = device.index(); + at::cuda::OptionalCUDAGuard gpuGuard(device); + + // [Group Start/End Note] This is used to ensure that nccl communicator will + // be created before communication primitives are called. Let's look at this + // example: Using the batch_isend_irecv to send a tensor to a target process. + // On the sender side, the corresponding underlying NCCL calls will look like + // ncclGroupStart() // This is in batch_isend_irecv + // ncclCommInitRank() // Inside NCCLComm::create + // ncclSend() + // ncclGroupEnd() // This is in batch_isend_irecv + // With this pattern, the nccl communicator will be created in the last + // ncclGroupEnd which means when ncclSend is processed, the passed + // communicator argument is NULL which will lead to runtime error. So we need + // to "close" all active nccl groups to ensure nccl communicator is actually + // created before encountering any communication calls. This is why we need + // the following for loop. + for (const auto i : c10::irange(ncclActiveGroupCounter_)) { + (void)i; + // comms have not been initiated yet, so can only check in blocking-way + C10D_NCCL_CHECK(ncclGroupEnd(), std::nullopt); + } + + // GPU world size and GPU rank + int numRanks, rank; + + if (!singleP2POp) { + // Collective, all-to-all, or batch P2P + numRanks = getSize(); + rank = getRank(); + } else if (isSendRecvSelf) { + // Same process send and recv. + numRanks = 1; + rank = 0; + } else { + // For single point-to-point operation, there are only 2 processes + // involved so the GPU rank is either 0 or 1. + numRanks = 2; + rank = p2pRank; + } + +#ifdef NCCL_HAS_COMM_SPLIT + if (options_->split_from) { + TORCH_CHECK( + options_->split_color != 0, + "Must specify a non-zero color when splitting"); + // Find a valid, healthy communicator to split from if possible. + std::lock_guard lock(options_->split_from->mutex_); + auto& other_comms = options_->split_from->devNCCLCommMap_; + auto dit = other_comms.find(getKeyFromDevice(device)); + if (dit != other_comms.end()) { + auto& parentComm = dit->second; + if (parentComm != nullptr && !parentComm->isAborted()) { + ncclComm = NCCLComm::split( + parentComm.get(), + options_->split_color, + rank, + options_->config, + options_->global_ranks_in_group); + } + } + } +#endif + + // To simplify conditional nesting, just create the ncclComms[i] + // entry if it hasn't been yet rather than untangling the + // conditions that might have resulted in a split above. + if (!ncclComm) { + if (getCvarBool(TORCH_NCCL_BCAST_UNIQUEID, true) && !isSendRecvSelf) { + // For point-to-point communication, lower rank of the two will get unique + // id. + if (rank_ == 0 || (singleP2POp && p2pRank == 0)) { + C10D_NCCL_CHECK(ncclGetUniqueId(&ncclID), std::nullopt); + } + + // Broadcast so that each process can have a unique NCCL ID + auto timeStarted = std::chrono::steady_clock::now(); + broadcastUniqueNCCLID(&ncclID, singleP2POp, deviceKey, p2pRank); + auto timerDeltaMs = + std::chrono::duration_cast>( + std::chrono::steady_clock::now() - timeStarted) + .count() * + 1000; + LOG(INFO) << logPrefix() + << "ProcessGroupNCCL broadcast unique ID through store took " + << timerDeltaMs << " ms"; + } + +#ifdef NCCL_HAS_COMM_NONBLOCKING + ncclComm = NCCLComm::create(numRanks, rank, ncclID, options_->config); +#else + ncclComm = NCCLComm::create(numRanks, rank, ncclID); +#endif + } + + // Creates the NCCL streams + bool force_high = getCvarBool(TORCH_NCCL_HIGH_PRIORITY, false); + auto streamVal = at::cuda::getStreamFromPool( + options_->is_high_priority_stream || force_high); + + { + std::lock_guard lock(mutex_); + inInitializationCommMap_.emplace(deviceKey, ncclComm); + } + + NCCLTraceBuffer::get()->record_pg_ranks( + std::make_tuple(pg_uid_, pg_desc_), groupRanks()); + + RECORD_PARAM_COMMS( + 0, // seq + std::make_tuple(pg_uid_, pg_desc_), // PG name tuple + rank, // rank + "init", // collective name + 0, // inNelems + 0, // outNelems + at::kByte, // dType + std::vector(), // inSplitSizes + std::vector(), // outSplitSizes + globalRankStart, // globalRankStart + globalRankStride, // globalRankStride + size_); // worldSize + + LOG(INFO) << logPrefix() << "ProcessGroupNCCL created ncclComm_ " + << ncclComm->ncclComm_ << " on CUDA device: " << deviceIndex; + + // At this point NCCL should have been initialized, hence we can accurately + // get the env value even if NCCL sets it by reading from nccl.conf file + LOG(INFO) << logPrefix() + << "NCCL_DEBUG: " << getCvarString({"NCCL_DEBUG"}, "N/A"); + + // See [Group Start/End Note] + for (const auto i : c10::irange(ncclActiveGroupCounter_)) { + (void)i; + C10D_NCCL_CHECK(ncclGroupStart(), std::nullopt); + } + + ncclStreams_.emplace(deviceKey, std::move(streamVal)); + + // Note: these events are created with the (default) cudaEventDisableTiming + // flag This flag provides the best performance when used with + // cudaStreamWaitEvent() and cudaEventQuery(). Since we here don't measure the + // performance using cudaEvent, this should be set. + // TODO(kwen2501): is ncclEvents_ used anywhere else? + ncclEvents_.emplace(deviceKey, at::cuda::CUDAEvent(cudaEventDisableTiming)); + + // Move the NCCL resource to cache + auto it = inInitializationCommMap_.find(deviceKey); + // A previous thread could've already removed devicesKey from + // inInitializationCommMap_ and added it to devNCCLCommMap_ + if (it != inInitializationCommMap_.end()) { + devNCCLCommMap_.emplace(deviceKey, std::move(it->second)); + inInitializationCommMap_.erase(deviceKey); + + // Now ncclComms are fully initialized. + // Register all active CUDA memory segments in cache allocator to + // the new NCCL communicators + if (useTensorRegisterAllocatorHook_) { + auto snapshot = c10::cuda::CUDACachingAllocator::snapshot(); + // Register the segment to a new NCCL communicator if on the same device + for (const auto& segmentInfo : snapshot.segments) { + TORCH_INTERNAL_ASSERT( + segmentInfo.device == device.index(), + "Mismatch between CUDA memory segment device and current device"); + ncclComm->registerSegment( + reinterpret_cast(segmentInfo.address), + segmentInfo.total_size); + } + } + // Record the mapping between ncclComm and device index so that later + // register hook can register a newly allocated segment to communicators + // on the same device. + // NOTE: we need remove the communicator from this map when it is + // destroyed, otherwise may register onto an invalid communicator. + ncclCommDevIdxMapMutex.lock(); + ncclCommDevIdxMap.emplace(ncclComm, device.index()); + ncclCommDevIdxMapMutex.unlock(); + } + + it = devNCCLCommMap_.find(deviceKey); + TORCH_INTERNAL_ASSERT( + it != devNCCLCommMap_.end(), "Communicators not populated in cache!"); + + return it->second; +} + +uint64_t ProcessGroupNCCL::getCommSplitCounter() const { + uint64_t ret = 0; + for (const auto& i : devNCCLCommMap_) { + auto& ncclComm = i.second; + ret += ncclComm->getCommSplitCounter(); + } + return ret; +} + +namespace { + +// Check validity of tensor +void check_gpu_single_tensor( + const at::Tensor& tensor, + const bool p2p = false // whether operation is a P2P operation +) { + if (!tensor.is_cuda() || tensor.is_sparse()) { + C10_THROW_ERROR(ValueError, "Tensors must be CUDA and dense"); + } + // Skip the following requirements for P2P operations + if (!tensor.is_contiguous(tensor.suggest_memory_format())) { + if (p2p) { + TORCH_WARN_ONCE( + "Detected non-contiguous tensor in P2P operations. It is user " + "responsibility to guarantee that source and destination tensors have " + "the same contiguity format."); + } else { + C10_THROW_ERROR(ValueError, "Tensors must be contiguous"); + } + } +} + +// Checks that all `tensors' have the same type and shape and reside on the same +// GPU. +// TODO: test_c10d_nccl.py should consider adding tests for the error conditions +// here, ie, that deliberately pass invalid tensors and check the right +// exception is thrown. The "Expected list of tensors on the same device" +// condition may be a challenge because the test would need to pass tensors on +// different devices in the same process. +int64_t check_gpu_tensors_same_device(const std::vector& tensors) { + if (tensors.size() == 0) { + C10_THROW_ERROR(ValueError, "Tensor list must be nonempty"); + } + + const auto& first = tensors.front(); + + int64_t total_numel = 0; + for (const auto& t : tensors) { + if (!t.is_cuda() || t.is_sparse()) { + C10_THROW_ERROR(ValueError, "Tensors must be CUDA and dense"); + } + if (t.scalar_type() != first.scalar_type()) { + C10_THROW_ERROR(TypeError, "Tensors must have identical type"); + } + if (!t.is_non_overlapping_and_dense()) { + C10_THROW_ERROR(ValueError, "Tensors must be non-overlapping and dense"); + } + // If we're in this function, the user called a _coalesced collective + // on a set of tensors with potentially different sizes and strides. + // Therefore, we don't check for matching sizes and strides, + // but we do double-check tensors are on the same device. + TORCH_CHECK_WITH( + ValueError, + t.get_device() == tensors[0].get_device(), + "Expected list of tensors on the same device"); + total_numel += t.numel(); + } + + return total_numel; +} + +bool check_same_size(const std::vector& input_tensors) { + for (const auto& input_tensor : input_tensors) { + if (!input_tensors[0].is_same_size(input_tensor)) { + return false; + } + } + return true; +} + +} // namespace + +c10::intrusive_ptr ProcessGroupNCCL::initWork( + at::Device& device, + int rank, + OpType opType, + const char* profilingTitle, + const std::vector& inputs, + const std::vector& outputs, // TODO(kwen2501): necessary? + bool record) { + auto r = c10::make_intrusive( + pg_uid_, + pg_desc_, + device, + rank, + opType, + seqCollective_, + profilingTitle, + profilingTitle != nullptr ? std::optional>(inputs) + : std::nullopt, + desyncDebug_, + enableTiming_.load(), + cudaEventCacheEnabled_.load(), + dist_debug_level_); + if (record) { + bool isP2P = isP2POp(opType); + // Ideally record every work that we enqueue, rather than every work we + // create. + // - at the time of this PR we do not currently enqueue every created work + // - but it is unsafe to steal refs to start/end cuda events from Works that + // may go out of scope before flight recorder has retired them, + // so we must ensure that any work that is initialized via initWork will + // be enqueued + // - initially, moved record() into workEnqueue(), but found that makes it + // hard to get access to profilingTitle, + // inputs, and outputs for metadata recording, and we don't want to attach + // these objects to the Work becuase it has implications for keeping those + // tensors alive longer and adds overhead when copying Work objects + // between threads + r->trace_id_ = NCCLTraceBuffer::get()->record( + local_id_, + std::make_tuple(pg_uid_, pg_desc_), + seqCollective_, + seqP2P_, + op_id_, + profilingTitle ? profilingTitle : "", + inputs, + outputs, + r->ncclStartEvent_.get(), + r->ncclEndEvent_.get(), + options_->timeout, + pgStatus_, + isP2P); + } + return r; +} + +// TODO(kwen2501): deprecate +std::vector ProcessGroupNCCL::WorkNCCL::result() { + return *outputs_; +} + +c10::intrusive_ptr ProcessGroupNCCL::WorkNCCL:: + getFuture() { + return future_; +} + +float ProcessGroupNCCL::WorkNCCL::getDuration() const { + TORCH_CHECK(timingEnabled_, "getDuration only works if timing was enabled"); + TORCH_CHECK( + ncclStartEvent_, + "getDuration only works if ncclStartEvents_ is populated, true if timing enabled"); + TORCH_CHECK( + ncclEndEvent_, + "getDuration only works if ncclEndEvents_ is populated, which should always be true"); + return ncclStartEvent_->elapsed_time(*ncclEndEvent_); +} + +uint64_t ProcessGroupNCCL::WorkNCCL::getSequencenumber() const { + return seq_; +} + +void ProcessGroupNCCL::assignTimeoutToWork( + const c10::intrusive_ptr& work, + const c10::intrusive_ptr& option) { + std::chrono::milliseconds timeout = option->timeout; + std::lock_guard timeoutLock(mtxTimeoutExtension_); + if (ephemeralTimeoutActive_.count() > 0) { + timeout += ephemeralTimeoutActive_; + } + work->opTimeout_ = timeout; + work->ownedEphermeralTimeout_ = + ephemeralTimeoutActive_ - ephemeralTimeoutInflight_; + ephemeralTimeoutInflight_ = ephemeralTimeoutActive_; +} + +void ProcessGroupNCCL::workEnqueue( + c10::intrusive_ptr work) { + if (!terminateProcessGroup_.load()) { + std::lock_guard lock(workMetaListMutex_); + // Avoid view tensors to be processed in cleanup thread. + // View tensors' destruction invokes autograd_meta, which + // needs to be destructed in user thread. Otherwise will + // get deadlock. Here we enqueue work without outputs_. + workMetaList_.emplace_back(*work); + // update the PG status related to the last enqueued work + pgStatus_->lastEnqueuedSeq = work->seq_; + pgStatus_->lastEnqueuedWorkName = opTypeToString(work->opType_); + pgStatus_->lastEnqueuedNumelIn = work->numelIn_; + pgStatus_->lastEnqueuedNumelOut = work->numelOut_; + lastWorkListUpdateTime_ = std::chrono::steady_clock::now(); + } +} + +ProcessGroupNCCL::Options::Options(bool is_high_priority_stream) + : Backend::Options(NCCL_BACKEND_NAME, kProcessGroupNCCLDefaultTimeout), + is_high_priority_stream(is_high_priority_stream) {} + +static constexpr int CoalActive = 0x01, CoalColl = 0x02, CoalP2P = 0x04; + +void ProcessGroupNCCL::startCoalescing() { + // Other collective ops bump seq_ before creating a work. Thus, if coalesced + // ops bump seq_ only after initing a work they will collide with (reuse) the + // seq_ of the last non-coalesced collective. Previously, seq_ was bumped + // inside endCoalescing, but before initWork. Since we now record individual + // ops from a coalesce group into the flight recorder, we want to have the + // same seq_ for those ops and its 'endCoalescing' op. Hence we bump during + // start, which has one minor downside- we burn a seq_ if someone ever does a + // 'start' and 'end' coalescing region without doing an operation inbetween. + + // Don't bump op_id_ here, because startCoalescing isn't a logical operation. + // Bump it for each logical op inside the coalescing group. + if (coalescing_state_ & CoalP2P) { + seqP2P_++; + } else { + seqCollective_++; + } + + coalescedDevice_.set_index(-1); + coalescedComm_ = nullptr; + coalescing_state_ |= CoalActive; + groupStart(); +} + +// `optype` is for specifying a composite optype, such as ALLGATHER and +// REDUCE_SCATTER +c10::intrusive_ptr ProcessGroupNCCL::endCoalescing(OpType optype) { + if (coalescedComm_ == nullptr) { + // There is no actual work being coalesced, return here + groupEnd(); + coalescing_state_ = 0; + return nullptr; + } + TORCH_CHECK( + coalescedDevice_.index() >= 0, + "Somthing went wrong. Did you call end_coalescing before start_coalescing?"); + + // `coalescedComm_` should have same set of comms across collectives + auto comm = coalescedComm_; + // `coalescedDevice_` should have same set of devices across collectives + auto device = coalescedDevice_; + + // `getKeyFromDevice` is how we get keys for both collectives and batch P2P + const auto key = getKeyFromDevice(device); + auto ncclStream = ncclStreams_.at(key); + + // Create Work object + c10::cuda::CaptureStatus capture_status = + c10::cuda::currentStreamCaptureStatusMayInitCtx(); + bool enqueue = + (coalescing_state_) && capture_status == c10::cuda::CaptureStatus::None; + auto work = + initWork(device, rank_, optype, "nccl:coalesced", {}, {}, enqueue); + work->ncclComm_ = comm; + work->blockingWait_ = blockingWait_; + work->avoidRecordStreams_ = avoidRecordStreams_; + work->store_ = store_; + assignTimeoutToWork(work, options_); + + // Record start before ncclGroupEnd + if (work->timingEnabled_) { + work->ncclStartEvent_->record(ncclStream); + } + + if (nccl_use_nonblocking()) { + groupEndNonblocking(comm); + } else { + groupEnd(); + } + + // Record end after ncclGroupEnd + // TODO(eqy): is this still necessary if avoidRecordStreams_ is set? + work->ncclEndEvent_->record(ncclStream); + + if (avoidRecordStreams_) { + // other functions expect an initialized ptr if avoidRecordStreams_ is set + work->stashed_for_allocator_safety_ = + std::make_shared>(); + } + + // Notify graphs before we check the capture status preemptively + at::cuda::CUDAGraph::inc_pending_event_queries(); + + if (enqueue) { + workEnqueue(work); + } else { + at::cuda::CUDAGraph::dec_pending_event_queries(); + } + + coalescing_state_ = 0; + coalescedComm_ = nullptr; + return work; +} + +c10::intrusive_ptr ProcessGroupNCCL::endCoalescing() { + // Default OpType to COALESCED if not specified + return endCoalescing(OpType::COALESCED); +} + +template +c10::intrusive_ptr ProcessGroupNCCL::collective( + std::vector& inputs, + std::vector& outputs, + Fn fn, + PreProcess pre, + PostProcess post, + OpType opType, + const char* profilingTitle, + bool avoidRecordStreams, + bool nanCheck) { + // Environment setting by the user may add onto collective call's option + avoidRecordStreams |= avoidRecordStreams_; + nanCheck &= enableNanCheck_; + + c10::cuda::CaptureStatus capture_status = + c10::cuda::currentStreamCaptureStatusMayInitCtx(); + errorIfCapturingNonCapturableNCCL(capture_status); + + // Bump collective counter + seqCollective_++; + op_id_++; + + auto device = getDevice(inputs[0]); + const auto key = getKeyFromDevice(device); + auto ncclComm = getNCCLComm(key, device, opType); + + if (coalescing_state_ & CoalActive) { + coalescing_state_ |= CoalColl; + if (coalescedDevice_.index() < 0) { + coalescedDevice_ = device; + } else { + TORCH_CHECK( + coalescedDevice_.index() == device.index(), MULTI_DEVICE_ERROR_MSG); + } + if (coalescedComm_ == nullptr) { + coalescedComm_ = ncclComm; + } else { + TORCH_CHECK(coalescedComm_ == ncclComm, MULTI_DEVICE_ERROR_MSG); + } + } + + // Used many times below, so we stash the unordered_map lookup + auto ncclStream = ncclStreams_.at(key); + + // First let NCCL streams wait for input tensors allocation streams + syncStream(device, ncclEvents_[key], ncclStream); + + bool enqueue = + !coalescing_state_ && capture_status == c10::cuda::CaptureStatus::None; + auto work = + initWork(device, rank_, opType, profilingTitle, inputs, outputs, enqueue); + + // Store references to outputs to be used by WorkNCCL::result and operator<<. + work->outputs_ = std::make_shared>(outputs); + + if (avoidRecordStreams) { + work->stashed_for_allocator_safety_ = + std::make_shared>(inputs); + } + + at::cuda::OptionalCUDAGuard gpuGuard(device); + + if (nanCheck) { + for (const auto& input : inputs) { + checkForNan(input, ncclStream); + } + } + + // Start event should only be recorded before the ncclGroupStart() + if (work->timingEnabled_) { + work->ncclStartEvent_->record(ncclStream); + } + + pre(ncclStream, work); + + ncclComm_t comm = ncclComm->getNcclComm(); + + // Both `inputs' and `outputs' are created on a worker stream and used in + // different ncclStreams. Hence, both must record the ncclStream to + // prevent being freed before the collective finishes. + // + // We only record `inputs' here, and leave recording `outputs' to `fn' for + // operations where `inputs' and `outputs' are not the same. + // + // See [Sync Streams]. + if (!avoidRecordStreams) { + for (const auto& input : inputs) { + if (!input.is_sparse()) { + c10::cuda::CUDACachingAllocator::recordStream( + input.storage().data_ptr(), ncclStream); + } else { + // for sparse input case record streams on both index and value + // tensors + c10::cuda::CUDACachingAllocator::recordStream( + input.values().storage().data_ptr(), ncclStream); + c10::cuda::CUDACachingAllocator::recordStream( + input.indices().storage().data_ptr(), ncclStream); + } + } + } + +// Not all collectives have the same signature, e.g, all-reduce take in a Tensor +// as the input and output while all-to-all take in a vector of Tensors as input +// and output. Because we define the signature of the fn to take only single +// tensor as input and output, we need to do a hack to get the first element in +// the vector and pass it to fn. +// TODO: we should clean up this in future (by either entirely removing lambda's +// or removing input and output from lambda's signature). +#ifndef NCCL_HAS_COMM_NONBLOCKING + C10D_NCCL_CHECK( + fn(inputs[0], outputs[0], comm, ncclStream), + ncclComm->getNcclCommFailureReason()); +#else + C10D_NCCL_CHECK_TIMEOUT( + fn(inputs[0], outputs[0], comm, ncclStream), + comm, + ncclComm->getNcclCommFailureReason()); +#endif + + post(ncclStream, work); + + // End event should only be recorded after the ncclGroupEnd() + if (!coalescing_state_) { + work->ncclEndEvent_->record(ncclStream); + } + work->ncclComm_ = ncclComm; + + { + c10::cuda::CUDAMultiStreamGuard streamGuard(ncclStream); + std::vector devices{device}; + work->future_ = c10::make_intrusive( + c10::ListType::create(c10::TensorType::get()), devices); + + // Add a callback that runs profiling end callbacks. wrapCallback() in CUDA + // future blocks the stream this callback runs on the corresponding + // ncclEndEvents_ ensuring appropriate synchronization. + if (work->recordFunctionEndCallback_) { + work->future_->addCallback( + [work](at::ivalue::Future& /* unused */) { + work->recordFunctionEndCallback_(); + }, + // uses_future = false allows us to skip synchronization in + // ivalue::Future, but is only valid as long as the lambda doesn't use + // the "Future" argument. + /*uses_future=*/false); + } + work->future_->markCompleted(at::IValue(*work->outputs_)); + } + + // Set appropriate work parameters. + work->blockingWait_ = blockingWait_; + work->avoidRecordStreams_ = avoidRecordStreams; + work->store_ = store_; + assignTimeoutToWork(work, options_); + // Record size info for debug. We only record the size on the first device as + // multi-device per process is deprecated + work->numelIn_ = 0; + work->numelOut_ = 0; + for (const auto& input : inputs) { + work->numelIn_ += input.numel(); + } + for (const auto& output : outputs) { + work->numelOut_ += output.numel(); + } + + // Notify graphs before we check the capture status preemptively + at::cuda::CUDAGraph::inc_pending_event_queries(); + if (enqueue) { + workEnqueue(work); + } else { + at::cuda::CUDAGraph::dec_pending_event_queries(); + } + + return work; +} + +template +c10::intrusive_ptr ProcessGroupNCCL::collectiveCoalesced( + std::vector& inputs, + std::vector& outputs, + Fn fn, + OpType opType, + const char* profilingTitle, + bool avoidRecordStreams) { + // Environment setting by the user may add onto collective call's option + avoidRecordStreams |= avoidRecordStreams_; + c10::cuda::CaptureStatus capture_status = + c10::cuda::currentStreamCaptureStatusMayInitCtx(); + errorIfCapturingNonCapturableNCCL(capture_status); + + // Bump collective counter + seqCollective_++; + + // For coalescingManager collectives, there is no individual c++ call per + // collective so there is no flight record and we increment seq*_ and op_id_ + // together. Compare this to startCoalesing/endCoalescing flow where we + // increment seq_ once per group and increment op_id_ once per indvidual + // operation within the group + op_id_++; + + // Currently, the API permits one scenario where inputs.size() and + // outputs.size() are > 0. + // 1. If the call was a _coalesced call, all inputs must be on the same + // device. + // The group of nccl calls applies the collective separately to each input, + // but the group as a whole should be efficient, and might even execute as + // a single fused kernel. + auto device = getDevice(inputs[0]); + const auto key = getKeyFromDevice(device); + auto ncclComm = getNCCLComm(key, device, opType); + + if (coalescing_state_ & CoalActive) { + coalescing_state_ |= CoalColl; + if (coalescedDevice_.index() < 0) { + coalescedDevice_ = device; + } else { + TORCH_CHECK( + coalescedDevice_.index() == device.index(), MULTI_DEVICE_ERROR_MSG); + } + if (coalescedComm_ == nullptr) { + coalescedComm_ = ncclComm; + } else { + TORCH_CHECK(coalescedComm_ == ncclComm, MULTI_DEVICE_ERROR_MSG); + } + } + + // Used many times below, so we stash the unordered_map lookup + auto ncclStream = ncclStreams_.at(key); + + // First let NCCL streams wait for input tensors allocation streams + syncStream(device, ncclEvents_[key], ncclStream); + + auto work = initWork( + device, rank_, opType, profilingTitle, inputs, outputs, /*record=*/true); + + // Store references to outputs to be used by WorkNCCL::result and operator<<. + work->outputs_ = std::make_shared>(outputs); + + if (avoidRecordStreams) { + work->stashed_for_allocator_safety_ = + std::make_shared>(inputs); + } + + at::cuda::OptionalCUDAGuard gpuGuard(device); + + // Start event should only be recorded before the ncclGroupStart() (which + // happens inside AutoNcclGroup guard below) + if (work->timingEnabled_) { + work->ncclStartEvent_->record(ncclStream); + } + + ncclComm_t comm = ncclComm->getNcclComm(); + +// TODO(kwen2501): this should be moved to c10d tests, to qualify a NCCL +// upgrade. Once a NCCL version is qualified, this code should not be needed at +// runtime. +#ifdef PGNCCL_ENABLE_HASH + if (enableCollecticeHashDebug_.load()) { + auto numel = getTensorsNumel(inputs); + auto hashValue = hashTensors(inputs); + PRINT_COLLECTIVE_HASH_SIGNATURE( + "input", opTypeToString(opType), numel, hashValue); + } +#endif + + { + torch::cuda::nccl::AutoNcclGroup nccl_group_guard( + comm, nccl_use_nonblocking()); + for (const auto i : c10::irange(inputs.size())) { + // Both `inputs' and `outputs' are created on a worker stream and used in + // different ncclStreams. Hence, both must record the ncclStream to + // prevent being freed before the collective finishes. + // + // We only record `inputs' here, and leave recording `outputs' to `fn' for + // operations where `inputs' and `outputs' are not the same. + // + // See [Sync Streams]. + if (!avoidRecordStreams) { + if (!inputs[i].is_sparse()) { + c10::cuda::CUDACachingAllocator::recordStream( + inputs[i].storage().data_ptr(), ncclStream); + } else { + // for sparse input case record streams on both index and value + // tensors + c10::cuda::CUDACachingAllocator::recordStream( + inputs[i].values().storage().data_ptr(), ncclStream); + c10::cuda::CUDACachingAllocator::recordStream( + inputs[i].indices().storage().data_ptr(), ncclStream); + } + } +#ifndef NCCL_HAS_COMM_NONBLOCKING + C10D_NCCL_CHECK( + fn(inputs[i], outputs[i], comm, ncclStream), + ncclComm->getNcclCommFailureReason()); +#else + C10D_NCCL_CHECK_TIMEOUT( + fn(inputs[i], outputs[i], comm, ncclStream), + comm, + ncclComm->getNcclCommFailureReason()); +#endif + } + } + + work->ncclEndEvent_->record(ncclStream); + work->ncclComm_ = ncclComm; + + { + c10::cuda::CUDAMultiStreamGuard streamGuard(ncclStream); + std::vector devices{device}; + work->future_ = c10::make_intrusive( + c10::ListType::create(c10::TensorType::get()), devices); + + // Add a callback that runs profiling end callbacks. wrapCallback() in CUDA + // future blocks the stream this callback runs on the corresponding + // ncclEndEvents_ ensuring appropriate synchronization. + if (work->recordFunctionEndCallback_) { + work->future_->addCallback( + [work](at::ivalue::Future& /* unused */) { + work->recordFunctionEndCallback_(); + }, + // uses_future = false allows us to skip synchronization in + // ivalue::Future, but is only valid as long as the lambda doesn't use + // the "Future" argument. + /*uses_future=*/false); + } + work->future_->markCompleted(at::IValue(*work->outputs_)); + } + + // Set appropriate work parameters. + work->blockingWait_ = blockingWait_; + work->avoidRecordStreams_ = avoidRecordStreams; + work->store_ = store_; + assignTimeoutToWork(work, options_); + // Record size info for debug. We only record the size on the first device as + // multi-device per process is deprecated + work->numelIn_ = inputs[0].numel(); + work->numelOut_ = outputs[0].numel(); + + /* Note [cuda graph capture and workEnqueue] + + Normal behavior of the C10D watchdog is to query cuda events on work objects + periodically, but when cuda graph recording is active these event queries + would crash or mess up the recording. + + To ensure we do not enqueue a work object to the watchdog when cuda graph + capture is active, we use a one-way sync. We increment a flag pre-emptively, + indicating our intent to enqueue a work object. Then we check capture_status + to see if (a) capturing is already in progress (we cannot enqueue in this + case), (b) capturing hasn't started yet, so we can trust that no capture will + start (since a pre-condition of starting a capture is to check the event query + count is 0). + + If we are not able to enqueue the work due to capture-in-progress, we finally + decrement the counter. + + For this reason we cannot easily move the increment inside workEnqueue unless + we also change the semantic of workEnqueue to 'maybeWorkEnqueue'. + + TODO: + - Is our design for flight recorder safe in this context? are we recording + any FR events during cudagraph capture? if so, they won't be safe to poll for + completion status. + */ + at::cuda::CUDAGraph::inc_pending_event_queries(); + if (capture_status == c10::cuda::CaptureStatus::None) { + workEnqueue(work); + } else { + at::cuda::CUDAGraph::dec_pending_event_queries(); + } + // TODO(whc) if the work isn't enqueued, I don't feel great about returning + // it, since interactions with it by usercode won't behave normally - they + // won't observe work completion, for instance. Will this lead to silent + // problems during capture? + return work; +} + +template +c10::intrusive_ptr ProcessGroupNCCL::pointToPoint( + at::Tensor& tensor, + Fn fn, + int peer, + OpType opType, + PreProcess pre, + PostProcess post, + const char* profilingTitle) { + // avoidRecordStreams_ note: + // send, recv, and irecv should be ok with avoidRecordStreams, + // However, for isend, I don't think the API requires the user + // to wait() on the returned handle, so ProcessGroupNCCL can't know + // when it's safe to release the input back to the allocator, + // and the present call has no way to know it's not an isend. + // Therefore, we warn and fall back to the typical recordStream logic: + if (avoidRecordStreams_) { + TORCH_WARN_ONCE( + "TORCH_NCCL_AVOID_RECORD_STREAMS=1 has no effect for point-to-point " + "collectives."); + } + + auto device = getDevice(tensor); + std::string key; + int p2pRank = 0, p2pTargetRank = 0; + bool isSendRecvSelf = false; + // For batch_isend_irecv, ncclGroupStart() would be called upfront + bool batchP2P = ncclActiveGroupCounter_ > 0; + if (batchP2P) { + // For batch P2P, we need to treat it like a collective when selecting + // communicator, because other ranks can call into this batch other than my + // rank and my peer + key = getKeyFromDevice(device); + p2pRank = rank_; + p2pTargetRank = peer; + } else { + // For single P2P, preserve the old two-rank behavior (to avoid perf diff) + key = getKeySendRecv(rank_, peer); + p2pRank = rank_ <= peer ? 0 : 1; + isSendRecvSelf = rank_ == peer; + p2pTargetRank = isSendRecvSelf ? 0 : 1 - p2pRank; + + if (!coalescing_state_) { + // Bump P2P sequence number. Don't do so if it's a batch P2P, it will be + // bumped in `startCoalescing`. + seqP2P_++; + } + } + + // Bump the logical operation counter regardless of whether this op is + // coalesced or individual + op_id_++; + + auto ncclComm = getNCCLComm(key, device, opType, p2pRank, isSendRecvSelf); + + if (coalescing_state_ & CoalActive) { + coalescing_state_ |= CoalP2P; + if (coalescedDevice_.index() < 0) { + coalescedDevice_ = device; + } else { + TORCH_CHECK( + coalescedDevice_.index() == device.index(), MULTI_DEVICE_ERROR_MSG); + } + if (coalescedComm_ == nullptr) { + coalescedComm_ = ncclComm; + } else { + TORCH_CHECK(coalescedComm_ == ncclComm, MULTI_DEVICE_ERROR_MSG); + } + } + + // Used many times below, so we stash the unordered_map lookup + auto ncclStream = ncclStreams_.at(key); + // First let NCCL streams wait for input tensors allocation streams + syncStream(device, ncclEvents_[key], ncclStream); + + // Work itself will create the CUDA events on all GPUs of tensors + c10::intrusive_ptr work; + if (coalescing_state_) { + // When coalescing, we record events per op that lack timing/state + // information becuase there is no 'work' associated with them, and then + // later in endCoalescing we record a 'coalesced' Work which has + // timing/state updates via watchdog thread, but lacks op metadata such as + // input/output sizes and profilingTitle per-op in the group. + auto trace_id = NCCLTraceBuffer::get()->record( + local_id_, + std::make_tuple(pg_uid_, pg_desc_), + seqCollective_, + seqP2P_, + op_id_, + profilingTitle, + {tensor}, + {tensor}, + nullptr, + nullptr, + options_->timeout, + pgStatus_, + /*isP2P=*/true); + // TODO(whc) if we want to make the per-p2p-op flightrecorder entries get + // their timings/states updated by proxy when the Work obj representing the + // coalesce group gets its update, we could accumulate these trace_ids + // together and ask FlightRecorder to take the update from one Work and + // apply it to multiple entries + (void)trace_id; + } else { + // Store references to outputs to be used by WorkNCCL::result and + // operator<<. Note that these outputs are only valid for recv(), as send() + // does not modify the inputs but we still create these outputs for use + // cases such as profiling. + + work = initWork( + device, rank_, opType, profilingTitle, {tensor}, {}, /*record=*/false); + // This bypasses something in Work() that crashes if {tensor} is given as + // output, not sure what + work->outputs_ = std::make_shared>(); + work->outputs_->push_back(tensor); + // TODO(whc) because we don't pass output {tensor} to initWork, we tell + // initWork to not record, and then we manually call record passing all the + // information it wants. + work->trace_id_ = NCCLTraceBuffer::get()->record( + local_id_, + std::make_tuple(pg_uid_, pg_desc_), + seqCollective_, + seqP2P_, + op_id_, + profilingTitle, + {tensor}, + {tensor}, + work->ncclStartEvent_.get(), + work->ncclEndEvent_.get(), + options_->timeout, + pgStatus_, + /*isP2P=*/true); + } + + // is gpuGuard needed for the if block below, or can i swap them + at::cuda::OptionalCUDAGuard gpuGuard(device); + + // Only check for NaN for send ops, for recv ops `tensor` can be a random + // placeholder + if (enableNanCheck_ && opType == OpType::SEND) { + checkForNan(tensor, ncclStream); + } + + if (!coalescing_state_) { + // Start event should only be recorded before the ncclGroupStart() + if (work->timingEnabled_) { + work->ncclStartEvent_->record(ncclStream); + } + + pre(ncclStream, work); + } + + // Both send tensor and recv tensor are created on a worker stream and used + // in different ncclStreams. Hence, both must record the ncclStream to + // prevent being freed before the collective finishes. + // + // See [Sync Streams]. + c10::cuda::CUDACachingAllocator::recordStream( + tensor.storage().data_ptr(), ncclStream); + + // This part seems common to both p2p and coalesced-p2p usage? + ncclComm_t comm_ = ncclComm->getNcclComm(); + +#ifndef NCCL_HAS_COMM_NONBLOCKING + C10D_NCCL_CHECK( + fn(tensor, comm_, ncclStream, p2pTargetRank), + ncclComm->getNcclCommFailureReason()); +#else + C10D_NCCL_CHECK_TIMEOUT( + fn(tensor, comm_, ncclStream, p2pTargetRank), + ncclComm->getNcclComm(), + ncclComm->getNcclCommFailureReason()); +#endif + + if (!coalescing_state_) { + post(ncclStream); + + // End event should only be recorded after the ncclGroupEnd() + work->ncclEndEvent_->record(ncclStream); + work->ncclComm_ = ncclComm; + work->blockingWait_ = blockingWait_; + work->store_ = store_; + assignTimeoutToWork(work, options_); + // Record size info for debug. We only record the size on the first device + // as multi-device per process is deprecated + work->numelIn_ = work->numelOut_ = tensor.numel(); + + // Future only needs to be created and marked completed with outputs for + // recv(), but still create future for use cases such as profiling even for + // send(). + { + c10::cuda::CUDAMultiStreamGuard streamGuard(ncclStream); + std::vector devices{device}; + work->future_ = c10::make_intrusive( + c10::ListType::create(c10::TensorType::get()), devices); + work->future_->markCompleted(at::IValue(*work->outputs_)); + } + + // Add a callback that runs profiling end callbacks. wrapCallback() in CUDA + // future blocks the stream this callback runs on the corresponding + // ncclEndEvents_ ensuring appropriate synchronization. + if (work->recordFunctionEndCallback_) { + work->future_->addCallback( + [work](at::ivalue::Future& /* unused */) { + work->recordFunctionEndCallback_(); + }, + // uses_future = false allows us to skip synchronization in + // ivalue::Future, but is only valid as long as the lambda doesn't use + // the "Future" argument. + /*uses_future=*/false); + } + } + + // Enqueue P2P op so that it can be cancelled by NCCL watchdog + c10::cuda::CaptureStatus capture_status = + c10::cuda::currentStreamCaptureStatusMayInitCtx(); + + // Notify graphs before we check the capture status preemptively + at::cuda::CUDAGraph::inc_pending_event_queries(); + + if (!coalescing_state_ && capture_status == c10::cuda::CaptureStatus::None) { + workEnqueue(work); + return work; + } else { + at::cuda::CUDAGraph::dec_pending_event_queries(); + return nullptr; + } +} + +template +c10::intrusive_ptr ProcessGroupNCCL::collective( + at::Tensor& input, + at::Tensor& output, + Fn fn, + PreProcess pre, + PostProcess post, + OpType opType, + const char* profilingTitle, + bool avoidRecordStreams, + bool nanCheck) { + auto inputs = std::vector{input}; + auto outputs = std::vector{output}; + return collective( + inputs, + outputs, + fn, + pre, + post, + opType, + profilingTitle, + avoidRecordStreams, + nanCheck); +} + +template +c10::intrusive_ptr ProcessGroupNCCL::collective( + at::Tensor& input, + at::Tensor& output, + Fn fn, + OpType opType, + const char* profilingTitle, + bool avoidRecordStreams, + bool nanCheck) { + auto inputs = std::vector{input}; + auto outputs = std::vector{output}; + return collective( + inputs, + outputs, + fn, + [](at::cuda::CUDAStream&, + c10::intrusive_ptr& work) {}, + [](at::cuda::CUDAStream&, + c10::intrusive_ptr& work) {}, + opType, + profilingTitle, + avoidRecordStreams, + nanCheck); +} + +template +c10::intrusive_ptr ProcessGroupNCCL::pointToPoint( + at::Tensor& tensor, + Fn fn, + int peer, + OpType opType, + const char* profilingTitle) { + return pointToPoint( + tensor, + fn, + peer, + opType, + [](at::cuda::CUDAStream&, + c10::intrusive_ptr& work) {}, + [](at::cuda::CUDAStream&) {}, + profilingTitle); +} + +c10::intrusive_ptr ProcessGroupNCCL::allreduce_sparse( + std::vector& tensors, + const AllreduceOptions& opts) { + TORCH_CHECK(tensors.size() == 1, MULTI_DEVICE_ERROR_MSG); + auto tensor = tensors.back(); + TORCH_CHECK( + !isFloat8Type(tensor.scalar_type()), + "Float8 dtypes are not currenlty supported for NCCL reductions"); +#ifdef IS_NCCLX + tensor = tensor.coalesce(); + at::Tensor outputTensor = + torch::zeros(tensor.sizes(), tensor.options().layout(torch::kStrided)); + auto work = collective( + tensor, + outputTensor, + [&](at::Tensor& input, + at::Tensor& output, + ncclComm_t comm, + at::cuda::CUDAStream& stream) { + auto ncclDataType = getNcclDataType(input.scalar_type()); + auto ncclReduceOp = + getNcclReduceOp(opts.reduceOp, input, ncclDataType, comm); + + size_t num_elements = output.numel(); + auto indices = input.indices(); + auto sizes = input.sizes(); + int colSize = sizes[1]; + auto rows = indices[0]; + size_t blockCount = rows.sizes()[0]; + auto recvIndices = indices[0] * colSize; + + // prevent output and recvIndices from being freed + c10::cuda::CUDACachingAllocator::recordStream( + output.storage().data_ptr(), stream); + c10::cuda::CUDACachingAllocator::recordStream( + recvIndices.storage().data_ptr(), stream); + auto result = ncclAllReduceSparseBlock( + input._values().data_ptr(), // sendbuff + recvIndices.data_ptr(), // recv_indices + blockCount, // block_count + colSize, // block_length + output.data_ptr(), // recvbuff + output.numel(), // recv_count + ncclDataType, + ncclReduceOp, + comm, + stream.stream()); + return result; + }, + [](at::cuda::CUDAStream& ncclStream, + c10::intrusive_ptr& work) {}, + [&](at::cuda::CUDAStream& ncclStream, + c10::intrusive_ptr& work) { + // Convert output tensors to sparse and back into tensors. + at::cuda::CUDAStreamGuard guard(ncclStream); + if (opts.sparseIndices.has_value()) { + tensor = at::sparse_coo_tensor( + opts.sparseIndices.value(), outputTensor, tensor.sizes()); + } else { + tensor = outputTensor.to_sparse(); + } + }, + OpType::_ALLREDUCE_SPARSE, + "nccl:all_reduce_sparse"); + return work; +#else + // If the nccl branch is not "exp" then we just error + C10_THROW_ERROR( + Error, + "NCCL does not support all_reduce with sparse tensors. Please use dense tensors instead."); +#endif +} + +c10::intrusive_ptr ProcessGroupNCCL::allreduce_impl( + at::Tensor& tensor, + const AllreduceOptions& opts) { + return collective( + tensor, + tensor, + [&](at::Tensor& input, + at::Tensor& output, + ncclComm_t comm, + at::cuda::CUDAStream& stream) { + auto ncclDataType = getNcclDataType(input.scalar_type()); + auto ncclReduceOp = + getNcclReduceOp(opts.reduceOp, input, ncclDataType, comm); + return ncclAllReduce( + input.data_ptr(), + output.data_ptr(), + input.numel(), + ncclDataType, + ncclReduceOp, + comm, + stream.stream()); + }, + OpType::ALLREDUCE, + "nccl:all_reduce"); +} + +c10::intrusive_ptr ProcessGroupNCCL::allreduce( + std::vector& tensors, + const AllreduceOptions& opts) { + TORCH_CHECK(tensors.size() == 1, MULTI_DEVICE_ERROR_MSG); + auto tensor = tensors.back(); + if (tensor.is_complex()) { + TORCH_CHECK( + complexViewAsRealAllowed(opts.reduceOp), + "all_reduce does not support", + opts.reduceOp, + "on complex tensors"); + tensor = at::view_as_real(tensor); + } + check_gpu_single_tensor(tensor); + + if (intraNodeComm_ != nullptr && opts.reduceOp == ReduceOp::SUM) { + using namespace intra_node_comm; + auto algo = intraNodeComm_->selectAllReduceAlgo(tensor); + if (algo != intra_node_comm::AllReduceAlgo::NONE) { + intraNodeComm_->allReduce(tensor, algo); + return c10::make_intrusive(); + } + } + TORCH_CHECK( + !isFloat8Type(tensor.scalar_type()), + "Float8 dtypes are not currenlty supported for NCCL reductions"); + // @lint-ignore CLANGTIDY + RECORD_PARAM_COMMS_DATA( + static_cast( + this->getSequenceNumberForGroup() + 1), // seq + 1 to match collective + std::make_tuple(pg_uid_, pg_desc_), // PG name tuple + tensors, // inputTensors + tensors, // outputTensors + rank_, // rank + "allreduce", // collective name + tensor.numel(), // inNelems + tensor.numel(), // outNelems + tensor.scalar_type(), // dType + std::vector(), // inSplitSizes + std::vector(), // outSplitSizes + globalRankStart, // globalRankStart + globalRankStride, // globalRankStride + this->getSize()); // worldSize + + // avoidRecordStreams_ note: collective() will stash tensors. + return allreduce_impl(tensor, opts); +} + +c10::intrusive_ptr ProcessGroupNCCL::allreduce_coalesced( + std::vector& tensors, + const AllreduceCoalescedOptions& opts) { + auto total_numel = check_gpu_tensors_same_device(tensors); + TORCH_CHECK( + !isFloat8Type(tensors.back().scalar_type()), + "Float8 dtypes are not currenlty supported for NCCL reductions"); + + // @lint-ignore CLANGTIDY + RECORD_PARAM_COMMS_DATA( + static_cast( + this->getSequenceNumberForGroup() + 1), // seq + 1 to match collective + std::make_tuple(pg_uid_, pg_desc_), // PG name tuple + tensors, // inputTensors + tensors, // outputTensors + rank_, // rank + "allreduce_coalesced", // collective name + total_numel, // inNelems + total_numel, // outNelems + tensors[0].scalar_type(), // dType + // I'm not sure what in,outSplitSizes mean here. + std::vector(), // inSplitSizes + std::vector(), // outSplitSizes + globalRankStart, // globalRankStart + globalRankStride, // globalRankStride + this->getSize()); // worldSize + + // avoidRecordStreams_ note: collective() will stash tensors. + return collectiveCoalesced( + tensors, + tensors, + [&](at::Tensor& input, + at::Tensor& output, + ncclComm_t comm, + at::cuda::CUDAStream& stream) { + auto ncclDataType = getNcclDataType(input.scalar_type()); + auto ncclReduceOp = + getNcclReduceOp(opts.reduceOp, input, ncclDataType, comm); + return ncclAllReduce( + input.data_ptr(), + output.data_ptr(), + input.numel(), + ncclDataType, + ncclReduceOp, + comm, + stream.stream()); + }, + OpType::COALESCED, + "nccl:allreduce_coalesced"); +} + +c10::intrusive_ptr ProcessGroupNCCL::broadcast( + std::vector& tensors, + const BroadcastOptions& opts) { + TORCH_CHECK(tensors.size() == 1, MULTI_DEVICE_ERROR_MSG); + auto tensor = tensors.back(); + if (tensor.is_complex()) { + tensor = at::view_as_real(tensor); + } + check_gpu_single_tensor(tensor); + + // @lint-ignore CLANGTIDY + RECORD_PARAM_COMMS_DATA( + static_cast( + this->getSequenceNumberForGroup() + 1), // seq + 1 to match collective + std::make_tuple(pg_uid_, pg_desc_), // PG name tuple + tensors, // inputTensors + tensors, // outputTensors + opts.rootRank, // root rank + "broadcast", // collective name + tensor.numel(), // inNelems + tensor.numel(), // outNelems + tensor.scalar_type(), // dType + std::vector(), // inSplitSizes + std::vector(), // outSplitSizes + globalRankStart, // globalRankStart + globalRankStride, // globalRankStride + this->getSize()); // worldSize + + // avoidRecordStreams_ note: collective() will stash tensors. + bool avoidRecordStreams = avoidRecordStreams_ || (!opts.asyncOp); + + const auto root = opts.rootRank + opts.rootTensor; + bool nanCheck = (root == rank_); + + return collective( + tensor, + tensor, + [&](at::Tensor& input, + at::Tensor& output, + ncclComm_t comm, + at::cuda::CUDAStream& stream) { + return ncclBcast( + input.data_ptr(), + input.numel(), + getNcclDataType(input.scalar_type()), + root, + comm, + stream.stream()); + }, + OpType::BROADCAST, + "nccl:broadcast", + avoidRecordStreams, + nanCheck); +} + +// _broadcast_oop adds an out-of-place broadcast in PGNCCL +// Custom collectives may be implemented by coalescing broadcast operations +// One use-case is implementing a vector all_gather (all_gather_v) +// where unevenly sized inputs are gathered among participating ranks +// Since all_gather provides an out-of-place API, an all_gather_v +// semantic implemented inside pg_nccl.all_gather also needs to support +// out-of-place, for which an out-of-place broadcast is required to be added +c10::intrusive_ptr ProcessGroupNCCL::_broadcast_oop( + at::Tensor& outputTensor, + at::Tensor& inputTensor, + const BroadcastOptions& opts) { + if (outputTensor.numel() != inputTensor.numel()) { + C10_THROW_ERROR( + ValueError, + "Tensor input and output of _broadcast_oop must have the same number of elements "); + } + const auto root = opts.rootRank + opts.rootTensor; + bool nanCheck = (root == rank_); + return collective( + inputTensor, + outputTensor, + [&](at::Tensor& input, + at::Tensor& output, + ncclComm_t comm, + at::cuda::CUDAStream& stream) { + return ncclBroadcast( + input.data_ptr(), + output.data_ptr(), + input.numel(), + getNcclDataType(input.scalar_type()), + root, + comm, + stream.stream()); + }, + OpType::BROADCAST, + "nccl:_broadcast_oop", + /*avoidRecordStreams=*/false, + nanCheck); +} + +c10::intrusive_ptr ProcessGroupNCCL::reduce( + std::vector& tensors, + const ReduceOptions& opts) { + TORCH_CHECK(tensors.size() == 1, MULTI_DEVICE_ERROR_MSG); + // @lint-ignore CLANGTIDY + auto tensor = tensors.back(); + if (tensor.is_complex()) { + TORCH_CHECK( + complexViewAsRealAllowed(opts.reduceOp), + "reduce does not support", + opts.reduceOp, + "on complex tensors"); + tensor = at::view_as_real(tensor); + } + check_gpu_single_tensor(tensor); + RECORD_PARAM_COMMS_DATA( + static_cast( + this->getSequenceNumberForGroup() + 1), // seq + 1 to match collective + std::make_tuple(pg_uid_, pg_desc_), // PG name tuple + tensors, // inputTensors + tensors, // outputTensors + opts.rootRank, // root rank + "reduce", // collective name + tensor.numel(), // inNelems + tensor.numel(), // outNelems + tensor.scalar_type(), // dType + std::vector(), // inSplitSizes + std::vector(), // outSplitSizes + globalRankStart, // globalRankStart + globalRankStride, // globalRankStride + this->getSize()); // worldSize + + // avoidRecordStreams_ note: collective() will stash tensors. + return collective( + tensor, + tensor, + [&](at::Tensor& input, + at::Tensor& output, + ncclComm_t comm, + at::cuda::CUDAStream& stream) { + const auto root = opts.rootRank + opts.rootTensor; + auto ncclDataType = getNcclDataType(input.scalar_type()); + auto ncclReduceOp = + getNcclReduceOp(opts.reduceOp, input, ncclDataType, comm); + return ncclReduce( + input.data_ptr(), + output.data_ptr(), + input.numel(), + ncclDataType, + ncclReduceOp, + root, + comm, + stream.stream()); + }, + OpType::REDUCE, + "nccl:reduce"); +} + +// _reduce_oop exposes an out-of-place reduce from PGNCCL +// Custom collectives may be implemented by coalescing reduce operations +// One use-case is implementing a vector reduce_scatter (reduce_scatter_v) +// where inputs are reduced and scattered unevenly among participating ranks +// Since reduce_scatter provides an out-of-place API, a reduce_scatter_v +// semantic implemented inside pg_nccl.reduce_scatter also needs to support +// out-of-place, for which an out-of-place reduce is required to be added +c10::intrusive_ptr ProcessGroupNCCL::_reduce_oop( + at::Tensor& outputTensor, + at::Tensor& inputTensor, + const ReduceOptions& opts) { + if (outputTensor.numel() != inputTensor.numel()) { + C10_THROW_ERROR( + ValueError, + "Tensor input and output of _reduce_oop must have the same number of elements "); + } + return collective( + inputTensor, + outputTensor, + [&](at::Tensor& input, + at::Tensor& output, + ncclComm_t comm, + at::cuda::CUDAStream& stream) { + const auto root = opts.rootRank + opts.rootTensor; + const auto ncclDataType = getNcclDataType(input.scalar_type()); + const auto ncclReduceOp = + getNcclReduceOp(opts.reduceOp, input, ncclDataType, comm); + return ncclReduce( + input.data_ptr(), + output.data_ptr(), + input.numel(), + ncclDataType, + ncclReduceOp, + (int)root, + comm, + stream.stream()); + }, + OpType::REDUCE, + "nccl:_reduce_oop"); +} + +c10::intrusive_ptr ProcessGroupNCCL::allgather( + std::vector>& outputTensors, + std::vector& inputTensors, + const AllgatherOptions& opts) { + TORCH_CHECK(inputTensors.size() == 1, MULTI_DEVICE_ERROR_MSG); + // @lint-ignore CLANGTIDY + auto inputTensor = inputTensors.back(); + check_gpu_single_tensor(inputTensor); + // @lint-ignore CLANGTIDY + auto outputTensors_ = outputTensors.back(); + + RECORD_PARAM_COMMS_DATA( + static_cast( + this->getSequenceNumberForGroup() + 1), // seq + 1 to match collective + std::make_tuple(pg_uid_, pg_desc_), // PG name tuple + inputTensors, // inputTensors + outputTensors, // outputTensors + rank_, // rank + "all_gather", // collective name + inputTensor.numel(), // inNelems + inputTensor.numel() * // outNelems + this->getSize(), + inputTensor.scalar_type(), // dType + std::vector(), // inSplitSizes + std::vector(), // outSplitSize + globalRankStart, // globalRankStart + globalRankStride, // globalRankStride + this->getSize()); // worldSize + + bool same_size = check_same_size(outputTensors_); + if (same_size) { + // Flatten a vector of tensors into a single, stacked tensor. + at::Tensor outputFlattened = newLikeFlat(outputTensors_); + + return collective( + inputTensor, + outputFlattened, + [&](at::Tensor& input, + at::Tensor& output, + ncclComm_t comm, + at::cuda::CUDAStream& stream) { + if (!avoidRecordStreams_) { + c10::cuda::CUDACachingAllocator::recordStream( + output.storage().data_ptr(), stream); + } + return ncclAllGather( + input.data_ptr(), + output.data_ptr(), + input.numel(), + getNcclDataType(input.scalar_type()), + comm, + stream.stream()); + }, + [](at::cuda::CUDAStream& ncclStream, + c10::intrusive_ptr& work) { + // avoidRecordStreams_ note: We actually don't need to stash anything + // here. + // - inputTensors is stashed onto work->stashed_for_allocator_safety_ + // in collective(). + // - outputFlattened is stashed onto work->outputs_ in collective(). + // - User-facing outputTensors should be held by the user until after + // waiting on work_, or the call makes no sense. + // So all participating tensors are accounted for, and won't be + // released back to their allocation streams until after work_ is + // waited on. + }, + [&](at::cuda::CUDAStream& ncclStream, + c10::intrusive_ptr& work) { + // Copy the flattened output tensors to the outputs. + at::cuda::CUDAStreamGuard guard(ncclStream); + for (const auto j : c10::irange(outputTensors_.size())) { + // See [Sync Streams]. + if (!avoidRecordStreams_) { + c10::cuda::CUDACachingAllocator::recordStream( + outputTensors_[j].storage().data_ptr(), ncclStream); + } + outputTensors_[j].copy_(outputFlattened[j], true); + } + }, + OpType::ALLGATHER, + "nccl:all_gather"); + } else { + const auto num_reduces = outputTensors_.size(); + startCoalescing(); + for (const int i : c10::irange(num_reduces)) { + auto& output = outputTensors_[i]; + auto& input = (i == rank_) ? inputTensor : output; + auto broadcastOpts = BroadcastOptions{ + static_cast(i), static_cast(0), opts.timeout}; + _broadcast_oop(output, input, broadcastOpts); + } + auto work = endCoalescing(OpType::ALLGATHER); + return work; + } +} + +c10::intrusive_ptr ProcessGroupNCCL::allgather_coalesced( + std::vector>& /* unused */, + std::vector& /* unused */, + const AllgatherOptions& /* unused */) { + C10_THROW_ERROR( + NotImplementedError, + "ProcessGroupNCCL does not support allgather_coalesced"); +} + +c10::intrusive_ptr ProcessGroupNCCL::allgather_into_tensor_coalesced( + std::vector& outputs, + std::vector& inputs, + const AllgatherOptions& opts) { + return collectiveCoalesced( + inputs, + outputs, + [&](at::Tensor& input, + at::Tensor& output, + ncclComm_t comm, + at::cuda::CUDAStream& stream) { + return ncclAllGather( + input.data_ptr(), + output.data_ptr(), + input.numel(), + getNcclDataType(input.scalar_type()), + comm, + stream.stream()); + }, + OpType::COALESCED, + "nccl:all_gather_into_tensor_coalesced"); +} + +c10::intrusive_ptr ProcessGroupNCCL::reduce_scatter( + std::vector& outputTensors, + std::vector>& inputTensors, + const ReduceScatterOptions& opts) { + TORCH_CHECK(outputTensors.size() == 1, MULTI_DEVICE_ERROR_MSG); + // @lint-ignore CLANGTIDY + auto outputTensor = outputTensors.back(); + check_gpu_single_tensor(outputTensor); + // @lint-ignore CLANGTIDY + auto inputTensors_ = inputTensors.back(); + TORCH_CHECK( + !isFloat8Type(outputTensor.scalar_type()), + "Float8 dtypes are not currenlty supported for NCCL reductions"); + + RECORD_PARAM_COMMS_DATA( + static_cast( + this->getSequenceNumberForGroup() + 1), // seq + 1 to match collective + std::make_tuple(pg_uid_, pg_desc_), // PG name tuple + inputTensors, // inputTensors + outputTensors, // outputTensors + rank_, // rank + "reduce_scatter", // collective name + outputTensor.numel() * this->getSize(), // inNelems + outputTensor.numel(), // outNelems + outputTensor.scalar_type(), // dType + std::vector(), // inSplitSizes + std::vector(), // outSplitSizes + globalRankStart, // globalRankStart + globalRankStride, // globalRankStride + this->getSize()); // worldSize + + bool same_size = check_same_size(inputTensors_); + if (same_size) { + // Flatten a vector of tensors into a single, stacked tensor. + at::Tensor inputFlattened = newLikeFlat(inputTensors_); + + return collective( + inputFlattened, + outputTensor, + [&](at::Tensor& input, + at::Tensor& output, + ncclComm_t comm, + at::cuda::CUDAStream& stream) { + if (!avoidRecordStreams_) { + c10::cuda::CUDACachingAllocator::recordStream( + output.storage().data_ptr(), stream); + } + const auto ncclDataType = getNcclDataType(input.scalar_type()); + const auto ncclReduceOp = + getNcclReduceOp(opts.reduceOp, input, ncclDataType, comm); + return ncclReduceScatter( + input.data_ptr(), + output.data_ptr(), + output.numel(), + ncclDataType, + ncclReduceOp, + comm, + stream.stream()); + }, + [&](at::cuda::CUDAStream& ncclStream, + c10::intrusive_ptr& work) { + if (avoidRecordStreams_) { + // We only need to stash inputTensors. + // - inputFlattened is stashed onto + // work->stashed_for_allocator_safety_ + // in collective(). + // - User-facing outputTensors is stashed onto work->outputs_ in + // collective(), + // and should also be held by the user until after waiting on + // work_. + auto& v = work->stashed_for_allocator_safety_; + v->insert(v->end(), inputTensors_.begin(), inputTensors_.end()); + } + + // Copy the input tensors to the flattened inputs. + at::cuda::CUDAStreamGuard guard(ncclStream); + for (const auto j : c10::irange(inputTensors_.size())) { + // See [Sync Streams]. + if (!avoidRecordStreams_) { + c10::cuda::CUDACachingAllocator::recordStream( + inputTensors_[j].storage().data_ptr(), ncclStream); + } + inputFlattened[j].copy_(inputTensors_[j], true); + } + }, + [&](at::cuda::CUDAStream&, + c10::intrusive_ptr& work) {}, + OpType::REDUCE_SCATTER, + "nccl:reduce_scatter"); + } else { + const auto num_reduces = inputTensors_.size(); + startCoalescing(); + for (const int i : c10::irange(num_reduces)) { + auto& input = inputTensors_[i]; + auto& output = (i == rank_) ? outputTensor : input; + auto reduceOpts = ReduceOptions{ + opts.reduceOp, + static_cast(i), + static_cast(0), + opts.timeout}; + _reduce_oop(output, input, reduceOpts); + } + auto work = endCoalescing(OpType::REDUCE_SCATTER); + return work; + } +} + +c10::intrusive_ptr ProcessGroupNCCL::_reduce_scatter_base( + at::Tensor& outputTensor, + at::Tensor& inputTensor, + const ReduceScatterOptions& opts) { + if (inputTensor.dtype() != outputTensor.dtype()) { + C10_THROW_ERROR( + TypeError, "input tensor must be the same type as the output tensor."); + } + + if (inputTensor.numel() != outputTensor.numel() * size_) { + C10_THROW_ERROR( + ValueError, + "input tensor must be the same size as output size times world size"); + } + + // @lint-ignore CLANGTIDY + const auto& tensor = outputTensor; + TORCH_CHECK( + !isFloat8Type(tensor.scalar_type()), + "Float8 dtypes are not currenlty supported for NCCL reductions"); + RECORD_PARAM_COMMS_DATA( + static_cast( + this->getSequenceNumberForGroup() + 1), // seq + 1 to match collective + std::make_tuple(pg_uid_, pg_desc_), // PG name tuple + inputTensor, // inputTensor + outputTensor, // outputTensor + rank_, // rank + "_reduce_scatter_base", // collective name + inputTensor.numel(), // inNelems + tensor.numel(), // outNelems + tensor.scalar_type(), // dtype + std::vector(), // inSplitSizes + std::vector(), // outSplitSizes + globalRankStart, // globalRankStart + globalRankStride, // globalRankStride + this->getSize()); // worldSize + + // avoidRecordStreams_ note: collective() will stash inputs and outputs. + // Note 2: for asyncOp = false, we don't want to record streams because we + // know that the NCCL stream will join back to the "current" stream right + // after this op. So we might just as well keep the stream ownership of the + // input/output tensors unchanged. The benefit would be that the + // allocation/free of the tensors would look deterministic to the "current" + // stream so that the caching allocator can reuse memory pool for this stream + // in a clever way. This setting is added for libraries like FSDP which uses + // `reduce_scatter_tensor`. + bool avoidRecordStreams = avoidRecordStreams_ || (!opts.asyncOp); + + return collective( + inputTensor, + outputTensor, + [&](at::Tensor& input, + at::Tensor& output, + ncclComm_t comm, + at::cuda::CUDAStream& stream) { + if (!avoidRecordStreams) { + c10::cuda::CUDACachingAllocator::recordStream( + output.storage().data_ptr(), stream); + } + auto ncclDataType = getNcclDataType(input.scalar_type()); + auto ncclReduceOp = + getNcclReduceOp(opts.reduceOp, input, ncclDataType, comm); + return ncclReduceScatter( + input.data_ptr(), + output.data_ptr(), + output.numel(), + ncclDataType, + ncclReduceOp, + comm, + stream.stream()); + }, + OpType::_REDUCE_SCATTER_BASE, + "nccl:_reduce_scatter_base", + avoidRecordStreams); +} + +c10::intrusive_ptr ProcessGroupNCCL::reduce_scatter_tensor_coalesced( + std::vector& outputs, + std::vector& inputs, + const ReduceScatterOptions& opts) { + TORCH_CHECK( + !isFloat8Type(inputs.back().scalar_type()), + "Float8 dtypes are not currenlty supported for NCCL reductions"); + return collectiveCoalesced( + inputs, + outputs, + [&](at::Tensor& input, + at::Tensor& output, + ncclComm_t comm, + at::cuda::CUDAStream& stream) { + if (!avoidRecordStreams_) { + c10::cuda::CUDACachingAllocator::recordStream( + output.storage().data_ptr(), stream); + } + auto ncclDataType = getNcclDataType(input.scalar_type()); + auto ncclReduceOp = + getNcclReduceOp(opts.reduceOp, input, ncclDataType, comm); + return ncclReduceScatter( + input.data_ptr(), + output.data_ptr(), + output.numel(), + ncclDataType, + ncclReduceOp, + comm, + stream.stream()); + }, + OpType::COALESCED, + "nccl:reduce_scatter_tensor_coalesced"); +} + +c10::intrusive_ptr ProcessGroupNCCL::barrier(const BarrierOptions& opts) { + RECORD_PARAM_COMMS( + static_cast( + this->getSequenceNumberForGroup() + 1), // seq + 1 to match collective + std::make_tuple(pg_uid_, pg_desc_), // PG name tuple + rank_, // rank + "barrier", // collective name + 0, // inNelems + 0, // outNelems + at::kByte, // dType + std::vector(), // inSplitSizes + std::vector(), // outSplitSizes + globalRankStart, // globalRankStart + globalRankStride, // globalRankStride + this->getSize()); // worldSize + + // Device to use for barrier + int barDevIdx = -1; + + // Select device to use for barrier + // 1st choice: Use user defined GPU device ids if provided + if (!opts.device_ids.empty()) { + // Use the first device id because PG NCCL is single-device now + barDevIdx = opts.device_ids[0]; + } else if (getBoundDeviceId()) { + // 2nd choice: Use the bound GPU device id if available. + // Bounded device id can be passed to `init_process_group`. + barDevIdx = (*getBoundDeviceId()).index(); + } else if (!usedDeviceIdxs_.empty()) { + // 3rd choice: infer the device id from the used device ids. + barDevIdx = *usedDeviceIdxs_.begin(); + } else { + // This means there is not yet a NCCL collective being called + // Here we have to use the best guesses and will use a single GPU to call + // allreduce to achieve barrier. + // In case the multiple processes fall into the same node, we use rank to + // ensure that each process is on a different GPU + // Note: it is better to use global rank because the group-local rank can be + // offset wrt the device id if intra-node GPUs are sharded into multiple + // dimensions. + barDevIdx = static_cast(globalRank() % localDeviceCount_); + LOG(WARNING) + << logPrefix() + << c10::str( + " using GPU ", + barDevIdx, + " to perform barrier as devices used by this process are currently unknown. ", + "This can potentially cause a hang if this rank to GPU mapping is incorrect.", + "Specify device_ids in barrier() to force use of a particular device,", + "or call init_process_group() with a device_id."); + } + + TORCH_CHECK_WITH( + ValueError, + barDevIdx >= 0, + "Failed to infer a GPU device id to perform barrier. "); + auto barDevice = at::Device(at::DeviceType::CUDA, barDevIdx); + + // Create a dummy tensor on the device + // Note: we use zeros() instead of empty() to prevent barrier from triggering + // alarm when NaN checker is enabled. + at::Tensor barrierTensor = + at::zeros({1}, at::TensorOptions().device(barDevice).dtype(at::kFloat)); + + // All reduce to achieve the barrier + auto work = allreduce_impl(barrierTensor); + + // Work will take over barrierTensors + auto ncclWork = dynamic_cast(work.get()); + TORCH_CHECK(ncclWork); + ncclWork->barrierTensor_ = std::move(barrierTensor); + return work; +} + +c10::intrusive_ptr ProcessGroupNCCL::alltoall_base( + at::Tensor& outputTensor, + at::Tensor& inputTensor, + std::vector& outputSplitSizes, + std::vector& inputSplitSizes, + const AllToAllOptions& /* unused */) { + check_gpu_single_tensor(outputTensor, true); + check_gpu_single_tensor(inputTensor, true); + if (outputSplitSizes.size() == 0 && inputSplitSizes.size() == 0) { + RECORD_PARAM_COMMS_DATA( + static_cast( + this->getSequenceNumberForGroup() + + 1), // seq + 1 to match collective + std::make_tuple(pg_uid_, pg_desc_), // PG name tuple + inputTensor, // inputTensor + outputTensor, // outputTensor + rank_, // rank + "all_to_all", // collective name + inputTensor.numel(), // inNelems + outputTensor.numel(), // outNelems + inputTensor.scalar_type(), // dType + std::vector(), // inSplitSizes + std::vector(), // outSplitSizes + globalRankStart, // globalRankStart + globalRankStride, // globalRankStride + this->getSize()); // worldSize + + // avoidRecordStreams_ note: collective() will stash inputTensors and + // outputTensors. + return collective( + inputTensor, + outputTensor, + [&](at::Tensor& input, + at::Tensor& output, + ncclComm_t comm, + at::cuda::CUDAStream& stream) { + // See [Sync Streams]. + if (!avoidRecordStreams_) { + c10::cuda::CUDACachingAllocator::recordStream( + output.storage().data_ptr(), stream); + } + torch::cuda::nccl::all2all_single_equal_split( + input, output, this->getSize(), comm, stream); + return ncclSuccess; + }, + OpType::ALLTOALL_BASE, + "nccl:all_to_all"); + } else { + c10d::checkSplitSizes(inputSplitSizes, inputTensor, size_); + c10d::checkSplitSizes(outputSplitSizes, outputTensor, size_); + + RECORD_PARAM_COMMS_DATA( + static_cast( + this->getSequenceNumberForGroup() + + 1), // seq + 1 to match collective + std::make_tuple(pg_uid_, pg_desc_), // PG name tuple + inputTensor, // inputTensor + outputTensor, // outputTensor + rank_, // rank + "all_to_allv", // collective name + inputTensor.numel(), // inNelems + outputTensor.numel(), // outNelems + inputTensor.scalar_type(), // dType + inputSplitSizes, // inSplitSizes + outputSplitSizes, // outSplitSizes + globalRankStart, // globalRankStart + globalRankStride, // globalRankStride + this->getSize()); // worldSize + + // avoidRecordStreams_ note: collective() will stash inputTensors and + // outputTensors. + return collective( + inputTensor, + outputTensor, + [&](at::Tensor& input, + at::Tensor& output, + ncclComm_t comm, + at::cuda::CUDAStream& stream) { + std::vector send_lengths(size_); + std::vector recv_lengths(size_); + std::vector send_offsets(size_); + std::vector recv_offsets(size_); + c10d::computeLengthsAndOffsets( + inputSplitSizes, input, &send_lengths, &send_offsets); + c10d::computeLengthsAndOffsets( + outputSplitSizes, output, &recv_lengths, &recv_offsets); + // See [Sync Streams]. + if (!avoidRecordStreams_) { + c10::cuda::CUDACachingAllocator::recordStream( + output.storage().data_ptr(), stream); + } + torch::cuda::nccl::all2all_single_unequal_split( + input.data_ptr(), + send_lengths.data(), + send_offsets.data(), + output.data_ptr(), + recv_lengths.data(), + recv_offsets.data(), + input.element_size(), + input.scalar_type(), + comm, + stream); + return ncclSuccess; + }, + OpType::ALLTOALL_BASE, + "nccl:all_to_all"); + } +} + +c10::intrusive_ptr ProcessGroupNCCL::alltoall( + std::vector& outputTensors, + std::vector& inputTensors, + const AllToAllOptions& /* unused */) { + std::vector inSplitSizes; + std::vector outSplitSizes; + int64_t total_numel = 0; + + auto device = outputTensors[0].device(); + for (const auto r : c10::irange(outputTensors.size())) { + check_gpu_single_tensor(outputTensors[r], true); + check_gpu_single_tensor(inputTensors[r], true); + TORCH_CHECK( + device == outputTensors[r].device() && + device == inputTensors[r].device(), + "Tensors must be on the same device") + inSplitSizes.push_back(inputTensors[r].numel()); + outSplitSizes.push_back(outputTensors[r].numel()); + total_numel += inputTensors[r].numel(); + } + + RECORD_PARAM_COMMS_DATA( + static_cast( + this->getSequenceNumberForGroup() + 1), // seq + 1 to match collective + std::make_tuple(pg_uid_, pg_desc_), // PG name tuple + inputTensors, // inputTensors + outputTensors, // outputTensors + rank_, // rank + "all_to_all", // collective name + total_numel, // inNelems + total_numel, // outNelems + inputTensors.front().scalar_type(), // dType + inSplitSizes, // inSplitSizes + outSplitSizes, // outSplitSizes + globalRankStart, // globalRankStart + globalRankStride, // globalRankStride + this->getSize()); // worldSize + + return collective( + inputTensors, + outputTensors, + [&](at::Tensor& /* unused */, + at::Tensor& /* unused */, + ncclComm_t comm, + at::cuda::CUDAStream& stream) { + torch::cuda::nccl::all2all(outputTensors, inputTensors, comm, stream); + return ncclSuccess; + }, + [&](at::cuda::CUDAStream&, + c10::intrusive_ptr& work) { + if (avoidRecordStreams_) { + // inputTensor0 and outputTensor0 are stashed redundantly by + // collective(), but that's ok. + auto& v = work->stashed_for_allocator_safety_; + v->insert(v->end(), inputTensors.begin(), inputTensors.end()); + v->insert(v->end(), outputTensors.begin(), outputTensors.end()); + } + }, + [](at::cuda::CUDAStream&, + c10::intrusive_ptr& work) {}, + OpType::ALLTOALL, + "nccl:all_to_all"); +} + +c10::intrusive_ptr ProcessGroupNCCL::send( + std::vector& tensors, + int dstRank, + int /* unused */) { + TORCH_CHECK(tensors.size() == 1, MULTI_DEVICE_ERROR_MSG); + // @lint-ignore CLANGTIDY + auto tensor = tensors.back(); + check_gpu_single_tensor(tensor, true); + + RECORD_PARAM_COMMS_DATA( + static_cast( + this->getSequenceNumberForGroup() + 1), // seq + 1 to match collective + std::make_tuple(pg_uid_, pg_desc_), // PG name tuple + tensors, // inputTensors + tensors, // outputTensors + dstRank, // dst rank + "send", // collective name + tensor.numel(), // inNelems + tensor.numel(), // outNelems + tensor.scalar_type(), // dType + std::vector(), // inSplitSizes + std::vector(), // outSplitSizes + globalRankStart, // globalRankStart + globalRankStride, // globalRankStride + this->getSize()); // worldSize + + auto ret = pointToPoint( + tensor, + [&](at::Tensor& input, + ncclComm_t comm, + at::cuda::CUDAStream& stream, + int dst) { + torch::cuda::nccl::send(input, comm, stream, dst); + return ncclSuccess; + }, + dstRank, + OpType::SEND, + c10::str("nccl:send ", rank_, "->", dstRank).c_str()); + return ret; +} + +c10::intrusive_ptr ProcessGroupNCCL::recv( + std::vector& tensors, + int srcRank, + int /* unused */) { + TORCH_CHECK(tensors.size() == 1, MULTI_DEVICE_ERROR_MSG); + // @lint-ignore CLANGTIDY + auto tensor = tensors.back(); + check_gpu_single_tensor(tensor, true); + + RECORD_PARAM_COMMS_DATA( + static_cast( + this->getSequenceNumberForGroup() + 1), // seq + 1 to match collective + std::make_tuple(pg_uid_, pg_desc_), // PG name tuple + tensors, // inputTensors + tensors, // outputTensors + srcRank, // src rank + "recv", // collective name + tensor.numel(), // inNelems + tensor.numel(), // outNelems + tensor.scalar_type(), // dType + std::vector(), // inSplitSizes + std::vector(), // outSplitSizes + globalRankStart, // globalRankStart + globalRankStride, // globalRankStride + this->getSize()); // worldSize + + auto ret = pointToPoint( + tensor, + [&](at::Tensor& output, + ncclComm_t comm, + at::cuda::CUDAStream& stream, + int src) { + torch::cuda::nccl::recv(output, comm, stream, src); + return ncclSuccess; + }, + srcRank, + OpType::RECV, + c10::str("nccl:recv ", rank_, "<-", srcRank).c_str()); + return ret; +} + +void ProcessGroupNCCL::groupStart() { + C10D_NCCL_CHECK(ncclGroupStart(), std::nullopt); + ++ncclActiveGroupCounter_; +} + +void ProcessGroupNCCL::groupEnd() { + C10D_NCCL_CHECK(ncclGroupEnd(), std::nullopt); + --ncclActiveGroupCounter_; +} + +void ProcessGroupNCCL::groupEndNonblocking(std::shared_ptr comm) { +#ifndef NCCL_HAS_COMM_NONBLOCKING + C10D_NCCL_CHECK(ncclGroupEnd(), std::nullopt); +#else + if (!nccl_use_nonblocking()) { + C10D_NCCL_CHECK(ncclGroupEnd(), std::nullopt); + } else { + C10D_NCCL_CHECK_TIMEOUT_GROUPEND(ncclGroupEnd(), comm, std::nullopt); + } +#endif + --ncclActiveGroupCounter_; +} + +c10::intrusive_ptr ProcessGroupNCCL::gather( + std::vector>& outputTensors, + std::vector& inputTensors, + const GatherOptions& opts) { + static auto invalidArgument = [](const std::string& msg) { + C10_THROW_ERROR(ValueError, "ProcessGroupNCCL::gather: " + msg); + }; + + assertRootRank(invalidArgument, opts.rootRank, size_); + + TORCH_CHECK(inputTensors.size() == 1, MULTI_DEVICE_ERROR_MSG); + // @lint-ignore CLANGTIDY + auto inputTensor = inputTensors.back(); + + std::vector outputs; + + if (getRank() == opts.rootRank) { + if (outputTensors.size() != 1) { + std::stringstream ss; + ss << "requires a single-element output list containing a list with " + << getSize() << " tensors."; + invalidArgument(ss.str()); + } else if (outputTensors[0].size() != static_cast(getSize())) { + std::stringstream ss; + ss << "Incorrect output list size " << outputTensors[0].size() + << ". Output list size should be " << getSize() + << ", same as size of the process group."; + invalidArgument(ss.str()); + } + + const auto& options = inputTensor.options(); + const auto& sizes = inputTensor.sizes(); + assertTypeAndSizesMatch(invalidArgument, outputTensors[0], options, sizes); + outputs = outputTensors[0]; + } else { + // if not in the root rank, initialize outputs as empty list + if (outputTensors.size() != 0) { + invalidArgument("requires empty output on non-root"); + } + outputs = {}; + // append a empty tensor to the list, we don't use it but the + // `collective` template function requires it to invoke its function + outputs.emplace_back(); + } + + RECORD_PARAM_COMMS_DATA( + static_cast( + this->getSequenceNumberForGroup() + 1), // seq + 1 to match collective + std::make_tuple(pg_uid_, pg_desc_), // PG name tuple + inputTensors, // inputTensors + outputTensors, // outputTensors + opts.rootRank, // root rank + "gather", // collective name + inputTensor.numel(), // inNelems + inputTensor.numel() * this->getSize(), // outNelems + inputTensor.scalar_type(), // dType + std::vector(), // inSplitSizes + std::vector(), // outSplitSize + globalRankStart, // globalRankStart + globalRankStride, // globalRankStride + this->getSize()); // worldSize + + // avoidRecordStreams_ note: collective() will stash inputTensors and + // outputs, which == outputTensors[0] on the root rank where it matters. + + auto inputs = std::vector{inputTensor}; + return collective( + inputs, + outputs, // just to fit the collective interface + [&](at::Tensor& /* unused */, + at::Tensor& /* unused */, + ncclComm_t comm, + at::cuda::CUDAStream& stream) { + const auto root = opts.rootRank; + if (getRank() == root) { + if (!avoidRecordStreams_) { + for (auto output : outputs) { + c10::cuda::CUDACachingAllocator::recordStream( + output.storage().data_ptr(), stream); + } + } + } + torch::cuda::nccl::gather(inputTensor, outputs, comm, stream, root); + return ncclSuccess; + }, + [](at::cuda::CUDAStream&, + c10::intrusive_ptr& work) {}, + [](at::cuda::CUDAStream&, + c10::intrusive_ptr& work) {}, + OpType::GATHER, + "nccl:gather"); +} + +c10::intrusive_ptr ProcessGroupNCCL::scatter( + std::vector& outputTensors, + std::vector>& inputTensors, + const ScatterOptions& opts) { + static auto invalidArgument = [](const std::string& msg) { + C10_THROW_ERROR(ValueError, "ProcessGroupNCCL::scatter: " + msg); + }; + + assertRootRank(invalidArgument, opts.rootRank, size_); + + TORCH_CHECK(outputTensors.size() == 1, MULTI_DEVICE_ERROR_MSG); + auto outputTensor = outputTensors.back(); + + std::vector inputs; + + if (getRank() == opts.rootRank) { + if (inputTensors.size() != 1) { + std::stringstream ss; + ss << "requires a single-element input list containing a list with " + << getSize() << " tensors."; + invalidArgument(ss.str()); + } else if (inputTensors[0].size() != static_cast(getSize())) { + std::stringstream ss; + ss << "Incorrect input list size " << inputTensors[0].size() + << ". Input list size should be " << getSize() + << ", same as size of the process group."; + invalidArgument(ss.str()); + } + + const auto& options = outputTensor.options(); + const auto& sizes = outputTensor.sizes(); + assertTypeAndSizesMatch(invalidArgument, inputTensors[0], options, sizes); + inputs = inputTensors[0]; + } else { + // if not in the root rank, initialize inputTensors as empty place holder + // with an empty list + if (inputTensors.size() != 0) { + invalidArgument("requires empty input on non-root"); + } + inputs = {}; + // append a empty tensor to the list, we don't use it but the + // `collective` template function requires it to invoke its function + inputs.emplace_back(); + } + + RECORD_PARAM_COMMS_DATA( + static_cast( + this->getSequenceNumberForGroup() + 1), // seq + 1 to match collective + std::make_tuple(pg_uid_, pg_desc_), // PG name tuple + inputTensors, // inputTensors + outputTensors, // outputTensors + opts.rootRank, // root rank + "scatter", // collective name + outputTensor.numel() * this->getSize(), // inNelems + outputTensor.numel(), // outNelems + outputTensor.scalar_type(), // dType + std::vector(), // inSplitSizes + std::vector(), // outSplitSize + globalRankStart, // globalRankStart + globalRankStride, // globalRankStride + this->getSize()); // worldSize + + // avoidRecordStreams_ note: collective() will stash outputTensors and + // inputs, which == inputTensors[0] on the root rank where it matters. + bool avoidRecordStreams = avoidRecordStreams_ || (!opts.asyncOp); + + const auto root = opts.rootRank; + bool nanCheck = (rank_ == root); + + auto outputs = std::vector{outputTensor}; + return collective( + outputs, + inputs, // just to fit the collective interface + [&](at::Tensor& /* unused */, + at::Tensor& /* unused */, + ncclComm_t comm, + at::cuda::CUDAStream& stream) { + if (getRank() == root) { + if (!avoidRecordStreams) { + for (auto input : inputs) { + c10::cuda::CUDACachingAllocator::recordStream( + input.storage().data_ptr(), stream); + } + } + } + torch::cuda::nccl::scatter(inputs, outputTensor, comm, stream, root); + return ncclSuccess; + }, + [](at::cuda::CUDAStream&, + c10::intrusive_ptr& work) {}, + [](at::cuda::CUDAStream&, + c10::intrusive_ptr& work) {}, + OpType::SCATTER, + "nccl:scatter", + avoidRecordStreams, + nanCheck); +} + +c10::intrusive_ptr ProcessGroupNCCL::recvAnysource( + std::vector& /* unused */, + int /* unused */) { + C10_THROW_ERROR( + NotImplementedError, "ProcessGroupNCCL does not support recvAnysource"); +} + +c10::intrusive_ptr ProcessGroupNCCL::_allgather_base( + at::Tensor& output_tensor, + at::Tensor& input_tensor, + const AllgatherOptions& opts) { + check_gpu_single_tensor(input_tensor); + check_gpu_single_tensor(output_tensor); + + if (input_tensor.dtype() != output_tensor.dtype()) { + C10_THROW_ERROR( + TypeError, "output tensor must have the same type as input tensor"); + } + + if (input_tensor.numel() * size_ != output_tensor.numel()) { + C10_THROW_ERROR( + ValueError, + "output tensor size must be equal to world_size times input tensor size"); + } + + RECORD_PARAM_COMMS_DATA( + static_cast( + this->getSequenceNumberForGroup() + 1), // seq + 1 to match collective + std::make_tuple(pg_uid_, pg_desc_), // PG name tuple + input_tensor, // inputTensors + output_tensor, // outputTensors + rank_, // rank + "_allgather_base", // collective name + input_tensor.numel(), // inNelems + output_tensor.numel(), // outNelems + output_tensor.scalar_type(), // dType + std::vector(), // inSplitSizes + std::vector(), // outSplitSize + globalRankStart, // globalRankStart + globalRankStride, // globalRankStride + this->getSize()); // worldSize + + // avoidRecordStreams_ note: collective() will stash inputs and outputs. + // Note 2: for asyncOp = false, we don't want to record streams because we + // know that the NCCL stream will join back to the "current" stream right + // after this op. So we might just as well keep the stream ownership of the + // input/output tensors unchanged. The benefit would be that the + // allocation/free of the tensors would look deterministic to the "current" + // stream so that the caching allocator can reuse memory pool for this stream + // in a clever way. This setting is added for libraries like FSDP which uses + // `all_gather_into_tensor`. + bool avoidRecordStreams = avoidRecordStreams_ || (!opts.asyncOp); + + return collective( + input_tensor, + output_tensor, + [&](at::Tensor& input, + at::Tensor& output, + ncclComm_t comm, + at::cuda::CUDAStream& stream) { + if (!avoidRecordStreams) { + c10::cuda::CUDACachingAllocator::recordStream( + output.storage().data_ptr(), stream); + } + return ncclAllGather( + input.data_ptr(), + output.data_ptr(), + input.numel(), + getNcclDataType(input.scalar_type()), + comm, + stream.stream()); + }, + OpType::_ALLGATHER_BASE, + "nccl:_all_gather_base", + avoidRecordStreams); +} + +} // namespace c10d + +#endif // USE_C10D_NCCL \ No newline at end of file diff --git a/torch/csrc/distributed/c10d/ProcessGroup.hpp b/torch/csrc/distributed/c10d/ProcessGroup.hpp index d8c4409c28d4b6..ed1a63052d4f35 100644 --- a/torch/csrc/distributed/c10d/ProcessGroup.hpp +++ b/torch/csrc/distributed/c10d/ProcessGroup.hpp @@ -1,4701 +1,1242 @@ +#pragma once + #ifdef USE_C10D_NCCL -#include -#include -#include +#if defined(__linux__) +#include +#include +#include +#include +#endif + +#include +#include +#include +#include +#include #include -#include -#include -#include -#include -#include +#include +#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include +#include #include -#include -#include #include -#include -#include -#include -#include -#include -#include +#include +#include -namespace c10d { - -constexpr const char* const kNCCLAbortedCommStoreKey = "NCCLABORTEDCOMM"; +#include +#include +#include +#include +#include +#include +#include +#include -namespace { +#include -#if defined(NCCL_MAJOR) && \ - ((NCCL_MAJOR > 2) || (NCCL_MAJOR == 2) && (NCCL_MINOR >= 10)) -#define NCCL_HAS_AVG 1 -#endif - -// NCCL op mapping -const std::map ncclOp = { - {ReduceOp::MIN, ncclMin}, - {ReduceOp::MAX, ncclMax}, - {ReduceOp::SUM, ncclSum}, - {ReduceOp::PRODUCT, ncclProd}, -#ifdef NCCL_HAS_AVG - {ReduceOp::AVG, ncclAvg}, -#endif -}; +namespace c10d { -// NCCL type typing -std::map ncclDataType = { - {at::kChar, ncclInt8}, - {at::kByte, ncclUint8}, - {at::kFloat, ncclFloat}, - {at::kDouble, ncclDouble}, - {at::kInt, ncclInt32}, - {at::kLong, ncclInt64}, - {at::kHalf, ncclHalf}, - {at::kBool, ncclUint8}, - {at::kFloat8_e5m2, ncclUint8}, - {at::kFloat8_e4m3fn, ncclUint8}, - {at::kFloat8_e4m3fnuz, ncclUint8}, - {at::kFloat8_e5m2fnuz, ncclUint8}, -#if HAS_NCCL_BF16_DATATYPE - {at::kBFloat16, ncclBfloat16}, -#endif +// Control broadcasting of NCCL uniqueId +static std::vector TORCH_NCCL_BCAST_UNIQUEID = { + "TORCH_NCCL_BCAST_UNIQUEID"}; + +// Control whether to always use high priority streams +static std::vector TORCH_NCCL_HIGH_PRIORITY = { + "TORCH_NCCL_HIGH_PRIORITY"}; + +// Control whether or not wait() is blocking or non-blocking. +static std::vector TORCH_NCCL_BLOCKING_WAIT = { + "TORCH_NCCL_BLOCKING_WAIT", + "NCCL_BLOCKING_WAIT"}; + +// TODO: We want to eventually remove this variable and make users to use +// the default value (3 - SkipCleanUp). +// Control whether or not we perform Async Error Handling with NCCL. +static std::vector TORCH_NCCL_ASYNC_ERROR_HANDLING = { + "TORCH_NCCL_ASYNC_ERROR_HANDLING", + "NCCL_ASYNC_ERROR_HANDLING"}; + +// Control whether dumping debug info on watchdog +// timeout is enabled. This variable must be set together with +// TORCH_NCCL_ENABLE_MONITORING=1 and TORCH_NCCL_TRACE_BUFFER_SIZE > 0. +static std::vector TORCH_NCCL_DUMP_ON_TIMEOUT = { + "TORCH_NCCL_DUMP_ON_TIMEOUT"}; + +// TODO: remove this change after a safe rollout. +// Control whether we sleep after an exception is thrown. +// This change is temporary and is used to safely remove the current sleep that +// exists after an exception is thrown. +static std::vector TORCH_NCCL_SLEEP_AFTER_EXCEPTION = { + "TORCH_NCCL_SLEEP_AFTER_EXCEPTION"}; + +// Control whether Desync Debug is enabled. This variable must be set +// together with TORCH_NCCL_ASYNC_ERROR_HANDLING. +static std::vector TORCH_NCCL_DESYNC_DEBUG = { + "TORCH_NCCL_DESYNC_DEBUG", + "NCCL_DESYNC_DEBUG"}; + +// Enable recording start-events for all ProcessGroupNCCL collectives, and +// compute accurate collective timing per-collective. (Note: end-events are +// recorded by default. Turn on this flag can increase chances of a watchdog +// hang due to performing a CUDA event query which eventually calls +// cudaEventElapsedTime() API. +static std::vector TORCH_NCCL_ENABLE_TIMING = { + "TORCH_NCCL_ENABLE_TIMING", + "NCCL_ENABLE_TIMING"}; + +// Enable monitoring thread which aborts the process when the ProcessGroupNCCL +// Watchdog thread gets stuck and no heartbeat is detected after +// TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC. This can happen due to calling CUDA/NCCL +// APIs that may hang. It is Useful to prevent jobs being stuck for a prolonged +// time than necessary tying up cluster resources. +static std::vector TORCH_NCCL_ENABLE_MONITORING = { + "TORCH_NCCL_ENABLE_MONITORING"}; + +// Control the watchdog heartbeat timeout period after which the monitoring +// thread will abort the process. +static std::vector TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC = { + "TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC"}; + +// Whether to rethrow CUDA Errors in the watchdog (default true) +static std::vector TORCH_NCCL_RETHROW_CUDA_ERRORS = { + "TORCH_NCCL_RETHROW_CUDA_ERRORS"}; + +// The maximum number of events we store in the flight recorder's ring buffer. +// (One event could be the start or end of a collective, for example). +static std::vector TORCH_NCCL_TRACE_BUFFER_SIZE = { + "TORCH_NCCL_TRACE_BUFFER_SIZE"}; + +// Control how much extra time we will wait for dumping the debugging info +// before we exit and throws timeout exception. +static std::vector TORCH_NCCL_WAIT_TIMEOUT_DUMP_MILSEC = { + "TORCH_NCCL_WAIT_TIMEOUT_DUMP_MILSEC"}; + +// Control the interval inside the monitoring thread to check the coordinated +// signal from other ranks, e.g. to dump the debugging information. +static std::vector TORCH_NCCL_COORD_CHECK_MILSEC = { + "TORCH_NCCL_COORD_CHECK_MILSEC"}; + +// Whether to log C++ stack traces on unclean shutdown (default true) +static std::vector TORCH_NCCL_LOG_CPP_STACK_ON_UNCLEAN_SHUTDOWN = { + "TORCH_NCCL_LOG_CPP_STACK_ON_UNCLEAN_SHUTDOWN"}; + +// Control whether to use CudaEventCache for the collective in watchdog thread. +// We noticed in the past when cuda global lock is held, destroying CudaEvent +// can cause a hang. +static std::vector TORCH_NCCL_CUDA_EVENT_CACHE = { + "TORCH_NCCL_CUDA_EVENT_CACHE"}; + +static std::vector TORCH_NCCL_NAN_CHECK = {"TORCH_NCCL_NAN_CHECK"}; + +constexpr const char* NCCL_BACKEND_NAME = "nccl"; + +constexpr const char* EXCEPTION_DUMP = "exception_dump"; + +constexpr const int kWorkStatusUpdatePeriodMs = 30 * 1000; // 30 seconds + +constexpr auto kProcessGroupNCCLDefaultTimeout = + std::chrono::milliseconds(10 * 60 * 1000); + +// NoHandling: do not handle asynchronous NCCL errors +// TearDown: tear down process upon error, see `WorkNCCL::handleException` +// CleanUpOnly: just clean up collectives and abort communicators without +// tearing down process SkipCleanUp: (this is a temporary option and can be +// removed in future) tear down process without cleaning up NCCL communicators. +// This should be used as a last resort in case `ncclCommAbort` itself is +// hanging +enum ErrorHandlingMode { + NoHandling = 0, + TearDown = 1, + CleanUpOnly = 2, + SkipCleanUp = 3 }; -// Helper function that gets the data type and issues error if not supported -ncclDataType_t getNcclDataType(at::ScalarType type) { - auto it = ncclDataType.find(type); - TORCH_CHECK_WITH( - TypeError, - it != ncclDataType.end(), - "Input tensor data type is not supported for NCCL process group: ", - type); - return it->second; -} - -bool complexViewAsRealAllowed(const ReduceOp reduceOp) { - switch (reduceOp) { - case ReduceOp::SUM: - return true; - case ReduceOp::AVG: - return true; - case ReduceOp::PREMUL_SUM: - return true; - case ReduceOp::UNUSED: - return true; - default: - return false; - } - return false; -} - -#ifdef ENABLE_NCCL_PREMUL_SUM_SUPPORT -template -ncclRedOpRAII unpackPreMulSum( - const ReduceOp& reduceOp, - const ncclComm_t& comm) { - const auto* preMulSupplement = - reinterpret_cast(reduceOp.supplement_.get()); - ncclRedOp_t preMulSum; - bool has_tensor = preMulSupplement->tensor_factor.defined(); - auto residence = has_tensor ? ncclScalarDevice : ncclScalarHostImmediate; - const T* ptr_factor = has_tensor - ? preMulSupplement->tensor_factor.const_data_ptr() - : nullptr; - T scalar_factor = T(preMulSupplement->double_factor); - ncclRedOpCreatePreMulSum( - &preMulSum, - // https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/api/ops.html#ncclredopcreatepremulsum - // tells us that the scalar input is strictly a multiplier. - /*scalar=*/has_tensor ? const_cast(ptr_factor) : &scalar_factor, - dataType, - residence, - comm); - return ncclRedOpRAII(preMulSum, comm); -} -#endif - -ncclRedOpRAII getNcclReduceOp( - const ReduceOp& reduceOp, - at::Tensor& input, - const ncclDataType_t& dataType, - const ncclComm_t& comm) { - try { - if (input.scalar_type() == at::kBool) { - if (reduceOp == ReduceOp::SUM) { - // For bool tensors, map sum to max, which both represent a bitwise or. - // This is to prevent overflow issues with sum, since we use uint8 to - // represent a bool (see ncclDataType mapping). - return ncclMax; - } -#ifdef NCCL_HAS_AVG - if (reduceOp == ReduceOp::AVG) { - C10_THROW_ERROR( - TypeError, "Cannot use ReduceOp.AVG with boolean inputs"); - } -#endif +#define SHOULD_CLEAN_UP(a) (a != NoHandling && a != SkipCleanUp) + +#define SHOULD_TEAR_DOWN(a) (a != NoHandling && a != CleanUpOnly) + +#define PRINT_COLLECTIVE_HASH_SIGNATURE(phase, opType, numel, hashValue) \ + LOG(WARNING) << logPrefix() << "Hash of " << phase << " to NCCL " << opType \ + << " with size " << numel << " is " << hashValue; + +// If set, ProcessGroupNCCL doesn't use recordStream calls to ensure +// caching allocator safety for tensors used on both user-facing and +// internal comm streams. +// Instead, it stashes live references to those tensors until after +// user-facing streams are synced with comm streams. +// See stashed_for_allocator_safety_ below. +static std::vector TORCH_NCCL_AVOID_RECORD_STREAMS = { + "TORCH_NCCL_AVOID_RECORD_STREAMS"}; + +// If set, ProcessGroupNCCL registers postAlloc and preFree hooks to cuda cache +// allocator so that whenever a tensor is allocated or freed, ProcessGroupNCCL +// can register/deregister the tensor on all available NCCL communicators. +static std::vector TORCH_NCCL_USE_TENSOR_REGISTER_ALLOCATOR_HOOK = + {"TORCH_NCCL_USE_TENSOR_REGISTER_ALLOCATOR_HOOK", + "NCCL_USE_TENSOR_REGISTER_ALLOCATOR_HOOK"}; + +#if defined(__linux__) +struct DumpPipe { + DumpPipe(int rank) { + std::string fileStem = + getCvarString({"TORCH_NCCL_DEBUG_INFO_PIPE_FILE"}, ""); + if (fileStem.empty() || + getCvarInt({"TORCH_NCCL_TRACE_BUFFER_SIZE"}, 0) <= 0) { + return; } - if (reduceOp == ReduceOp::PREMUL_SUM) { -#ifdef ENABLE_NCCL_PREMUL_SUM_SUPPORT - switch (dataType) { - case ncclHalf: - return unpackPreMulSum(reduceOp, comm); - case ncclFloat: - return unpackPreMulSum(reduceOp, comm); - case ncclDouble: - return unpackPreMulSum(reduceOp, comm); - default: - C10_THROW_ERROR( - TypeError, "PreMulSum Data type must be half, float, or double"); - ncclRedOp_t unused; - return unused; - } -#else - C10_THROW_ERROR(ValueError, "PreMulSum requires NCCL>=2.11.1"); -#endif + TORCH_CHECK(!fileStem.empty(), "TORCH_NCCL_DEBUG_INFO_PIPE_FILE is empty"); + std::string filename = c10::str(fileStem, rank, ".pipe"); + TORCH_CHECK( + unlink(filename.c_str()) != -1 || errno == ENOENT, + "Error removing existing named pipe ", + filename); + TORCH_CHECK( + mkfifo(filename.c_str(), 0666) != -1, + "Error creating named pipe ", + filename); + fd_ = open(filename.c_str(), O_RDONLY | O_NONBLOCK); + LOG(INFO) << "Pipe file " << filename + << " has been opened, write to it to trigger NCCL Debug Dump."; + TORCH_CHECK(fd_ != -1, "Error opening named pipe ", filename); + } + bool shouldDump() { + if (fd_ == -1) { + return false; } - return ncclOp.at(reduceOp); - } catch (const std::out_of_range&) { - switch (reduceOp) { - case ReduceOp::AVG: - C10_THROW_ERROR( - ValueError, - c10::str( - "AVG requires NCCL 2.10+. The current version is ", - NCCL_MAJOR, - ".", - NCCL_MINOR)); - break; - case ReduceOp::BAND: - C10_THROW_ERROR(ValueError, "Cannot use ReduceOp.BAND with NCCL"); - break; - case ReduceOp::BOR: - C10_THROW_ERROR(ValueError, "Cannot use ReduceOp.BOR with NCCL"); - break; - case ReduceOp::BXOR: - C10_THROW_ERROR(ValueError, "Cannot use ReduceOp.BXOR with NCCL"); - break; - default: - C10_THROW_ERROR(ValueError, "Unhandled ReduceOp"); - break; + char buf[128]; + // non-blocking from O_NONBLOCK above. + // Ignore EINTR because we already will poll this + // again later. + ssize_t bytesRead = read(fd_, &buf, 128); + return bytesRead > 0; + } + ~DumpPipe() { + if (fd_ != -1) { + close(fd_); } } -} - -// Get a key string from device -inline std::string getKeyFromDevice(at::Device& device) { - return std::to_string(device.index()); -} - -inline at::DeviceIndex getIndexFromDeviceKey(const std::string& deviceKey) { - // initialize the device index to -1, which is an invalid value. - int index = -1; - try { - index = std::stoi(deviceKey); - } catch (const std::invalid_argument& e) { - LOG(ERROR) << c10::str( - "Invalid deviceKey: ", deviceKey, ",", e.what(), "."); - } catch (const std::out_of_range& e) { - LOG(ERROR) << "Out of range: " << e.what(); + + private: + int fd_ = -1; +}; +#else +struct DumpPipe { + DumpPipe(int rank) {} + bool shouldDump() { + return false; } - return static_cast(index); -} - -std::string getKeySendRecv(int myRank, int peer) { - int lowRank = myRank < peer ? myRank : peer; - int highRank = myRank < peer ? peer : myRank; - std::string sendRecvPair = - std::to_string(lowRank) + ":" + std::to_string(highRank); - return sendRecvPair; -} - -// Get device from tensor -inline at::Device getDevice(at::Tensor& tensor) { - return tensor.device(); -} - -// [Sync Streams] Helper that lets the input ncclStreams to wait for the current -// stream. NCCL communications run on ncclStreams, but input tensors are -// allocated on different streams (i.e., current streams). Communications on -// ncclStreams cannot start before pending input tensor ops on current streams -// finish. Otherwise, ops on two streams might read/write same tensors -// concurrently. +}; +#endif + +// ProcessGroupNCCL implements NCCL bindings for c10d. // -// The synchronization above alone is not enough. We also need to make sure -// input tensors are not freed before their usages on ncclStreams finish. This -// can be achieved by calling c10::cuda::CUDACachingAllocator::recordStream, -// which remembers the usage stream (ncclStream), creates an event on the usage -// stream when GC attempts to free the input tensor, and delays GC until that -// event is done. -void syncStream( - at::Device& device, - at::cuda::CUDAEvent& ncclEvent, - at::cuda::CUDAStream& ncclStream) { - ncclEvent.record(at::cuda::getCurrentCUDAStream(device.index())); - ncclEvent.block(ncclStream); -} - -// Given a ncclUniqueId, convert it to a string representation that can be put -// in the store. -std::string buildNcclUniqueIdStr(const ncclUniqueId& ncclID) { - const uint8_t* bytes = reinterpret_cast(&ncclID); - std::ostringstream oss; - for (const auto i : c10::irange(NCCL_UNIQUE_ID_BYTES)) { - oss << std::hex << static_cast(bytes[i]); - } - return oss.str(); -} - -std::string getNcclAbortedCommStoreKey(const std::string ncclIdStr) { - return std::string(kNCCLAbortedCommStoreKey) + ":" + ncclIdStr; -} - -// Returns exception's what() given an exception_ptr instance. -std::string getExceptionMsgFromExceptionPtr( - const std::exception_ptr& exceptionPtr) { - TORCH_CHECK(exceptionPtr != nullptr); - try { - std::rethrow_exception(exceptionPtr); - } catch (const std::exception& e) { - return e.what(); - } catch (...) { - return "Unknown exception type"; - } -} - -inline void errorIfCapturingNonCapturableNCCL(c10::cuda::CaptureStatus status) { - // parentheses avoid some compiler warnings - static const uint64_t min_version = - (((uint64_t)2) << 32) + (((uint64_t)9) << 16) + ((uint64_t)6); - static const uint64_t cur_version = torch::cuda::nccl::version(); - if (cur_version < min_version) { - TORCH_CHECK_WITH( - NotImplementedError, - status == c10::cuda::CaptureStatus::None, - "Capturing NCCL collectives is only allowed with NCCL >= 2.9.6"); - } -} - -} // namespace - -// Map from each communicator to its device index. -// This map is used when register/deregister cache segments from cache -// allocator. See design notes below: -// - Each segment should be registered only to the communicator on the -// same device. -// - We cannot reuse devNCCLCommMap_ in each ProcessGroup because the key may be -// ranks rather than device in point-to-point case. -// - This map has also to be maintained as global variable since the register -// hooks are called outside the scope of any PG, thus we need traverse -// communicators in all PGs. -static std::unordered_map, int> ncclCommDevIdxMap; -static std::mutex ncclCommDevIdxMapMutex; -static bool allocatorHooksAttached = false; - -std::atomic ProcessGroupNCCL::shouldDump_(false); - -void cacheAllocatorRegisterHook( - const c10::cuda::CUDACachingAllocator::TraceEntry& te) { - // Register after SEGMENT_ALLOC - if (te.action_ != - c10::cuda::CUDACachingAllocator::TraceEntry::Action::SEGMENT_ALLOC) { - return; - } +// All functions of the class are expected to be called in the same order +// across all processes in the process group. This is the only way that we +// can guarantee to match up the same calls among all processes. +// +// All NCCL functions provided by this class are asynchronous functions. More +// specifically, each NCCL call is scheduled on a separate CUDA stream that is +// different from the current CUDA stream. This is for the purpose of +// achieving potentially concurrency and better performance. As a result, +// it is the callers' responsibility to make sure that the CUDA stream their +// code works on needs to wait for the NCCL operation from +// this class. +// +// This can be done by calling: +// +// either WorkNCCL::wait() or WorkNCCL::synchronize(), both achieves the same +// functionality and are synonyms. +// +// Also note that WorkNCCL::finishedGPUExecution() is a helper function only +// provided by ProcessGroupNCCL to check if the NCCL operation of WorkNCCL has +// finished execution on the GPU (not just scheduled). +// +// Example on using the NCCL process group +// +// ProcessGroupNCCL pg(store, rank, size); +// std::shared_ptr work = pg.allreduce(tensors); +// +// // At this point, NCCL kernel has already by queued successfully +// // Now, let current stream wait for the NCCL to finish, this function is +// // async operation as well +// +// work->wait() +// +// // Now continue on other work in the current stream. +class TORCH_API ProcessGroupNCCL : public Backend { + public: + class WorkNCCL : public Work, public std::enable_shared_from_this { + public: + friend struct WorkInfo; - std::lock_guard lock(ncclCommDevIdxMapMutex); - for (auto& it : ncclCommDevIdxMap) { - auto& ncclComm = it.first; - auto& devIdx = it.second; - if (te.device_ == devIdx) { - ncclComm->registerSegment(reinterpret_cast(te.addr_), te.size_); - } - } -} - -void cacheAllocatorDeregisterHook( - const c10::cuda::CUDACachingAllocator::TraceEntry& te) { - // deregister before SEGMENT_FREE - if (te.action_ != - c10::cuda::CUDACachingAllocator::TraceEntry::Action::SEGMENT_FREE) { - return; - } + // Constructor takes a list of CUDA devices + WorkNCCL( + const std::string& pgUID, + const std::string& pgDesc, + at::Device& device, + int rank, + OpType opType, + uint64_t seq, + const char* profilingTitle = nullptr, + const std::optional>& inputs = std::nullopt, + bool desyncDebug = false, + bool enableTiming = false, + bool cudaEventCacheEnabled = false, + DebugLevel distDebugLevel = DebugLevel::Off); + // Copy constructor doing partial copy without outputs_. Cleanup thread + // monitors and removes finished works. However it will deadlock when + // destructs outputs_ tensors who are view tensors in autograd graph. + WorkNCCL(const WorkNCCL& w); - std::lock_guard lock(ncclCommDevIdxMapMutex); - for (auto& it : ncclCommDevIdxMap) { - auto& ncclComm = it.first; - auto& devIdx = it.second; - if (te.device_ == devIdx) { - ncclComm->deregisterSegment(reinterpret_cast(te.addr_)); - } - } -} - -std::unordered_map> -getNCCLCommDumpMap() { -#if defined(IS_NCCLX) && defined(NCCL_COMM_DUMP) - std::unordered_map< - std::string /* ncclUniqueID */, - std::unordered_map /* dump from this comm */> - ncclDumpMap; - // dump_nccl_trace is only called from the default PG (local_id_=0), but we - // want to dump from all comms so we need to iterate over ncclCommDevIdxMap, - // which is static - std::vector> allNCCLComms; - // within the critical section, we don't want to dump while holding the lock - // as dump might hang - ncclCommDevIdxMapMutex.lock(); - for (auto& [ncclComm, _] : ncclCommDevIdxMap) { - allNCCLComms.push_back(ncclComm); - } - ncclCommDevIdxMapMutex.unlock(); - for (auto& ncclComm : allNCCLComms) { - std::string ncclUniqueIDStr = buildNcclUniqueIdStr(ncclComm->getNcclId()); - ncclDumpMap[ncclUniqueIDStr] = ncclComm->ncclCommDump(); - } - return ncclDumpMap; -#else - return std::unordered_map< - std::string, - std::unordered_map>(); -#endif -} + ~WorkNCCL() override; -std::string dump_nccl_trace( - bool includeCollectives, - bool includeStackTraces, - bool onlyActive) { - auto ncclDumpMap = getNCCLCommDumpMap(); - return NCCLTraceBuffer::get()->dump( - ncclDumpMap, includeCollectives, includeStackTraces, onlyActive); -} - -std::string dump_nccl_trace_json(bool includeCollectives, bool onlyActive) { - auto ncclDumpMap = getNCCLCommDumpMap(); - return NCCLTraceBuffer::get()->dump_json( - ncclDumpMap, includeCollectives, onlyActive); -} - -std::optional)>>& -get_cpp_trace_dumper() { - static std::optional< - std::function)>> - dumper(std::nullopt); - return dumper; -} - -gil_checker_t& get_gil_checker() { - static gil_checker_t gil_checker = nullptr; - return gil_checker; -} - -std::future launchAsyncGilCheck() { - std::promise resultPromise; - std::future resultFuture = resultPromise.get_future(); - TORCH_CHECK(get_gil_checker(), "Can't check GIL with null GIL checker"); - std::thread workerThread([promise = std::move(resultPromise)]() mutable { - c10::setThreadName("pt_nccl_gil_chk"); - - try { - auto& gil_checker = get_gil_checker(); - promise.set_value((*gil_checker)()); - } catch (...) { - promise.set_exception(std::current_exception()); - } - }); - - // Detach the thread to allow it to run independently - workerThread.detach(); - - return resultFuture; -} - -const int64_t ProcessGroupNCCL::kWatchdogThreadSleepMillis = 100; -constexpr int64_t kSynchronizeBusyWaitMillis = 10; -thread_local uint64_t ProcessGroupNCCL::ncclActiveGroupCounter_ = 0; - -std::ostream& operator<<( - std::ostream& output, - const ProcessGroupNCCL::WorkNCCL& workNCCL) { - std::string workInfo; - workInfo = c10::str( - "WorkNCCL(", - "SeqNum=", - workNCCL.seq_, - ", OpType=", - opTypeToString(workNCCL.opType_), - ", NumelIn=", - workNCCL.numelIn_, - ", NumelOut=", - workNCCL.numelOut_, - ", Timeout(ms)=", - workNCCL.opTimeout_.count(), - ")"); - return output << workInfo; -} - -ProcessGroupNCCL::WorkNCCL::WorkNCCL( - const std::string& pgUID, - const std::string& pgDesc, - at::Device& device, - int rank, - OpType opType, - uint64_t seq, - const char* profilingTitle, - const std::optional>& inputs, - bool desyncDebug, - bool enableTiming, - bool cudaEventCacheEnabled, - DebugLevel distDebugLevel) - : Work(rank, opType, profilingTitle, inputs), - pgUID_(pgUID), - pgDesc_(pgDesc), - device_(device), - workStartTime_(std::chrono::steady_clock::now()), - seq_(seq), - timingEnabled_(enableTiming), - distDebugLevel_(distDebugLevel) { - // Creates the CUDA event wrappers - // Note: The actual events are lazily created when first recorded to with - // DEFAULT_FLAGS = cudaEventDisableTiming. - if (cudaEventCacheEnabled) { - ncclStartEvent_ = enableTiming - ? ProcessGroupNCCL::CUDAEventCache::get().create(enableTiming) - : nullptr; - ncclEndEvent_ = - ProcessGroupNCCL::CUDAEventCache::get().create(enableTiming); - } else { - ncclStartEvent_ = enableTiming - ? std::make_shared(cudaEventDefault) - : nullptr; - ncclEndEvent_ = std::make_shared( - enableTiming ? cudaEventDefault : cudaEventDisableTiming); - } -} - -ProcessGroupNCCL::WorkNCCL::WorkNCCL(const WorkNCCL& w) - : Work(w.rank_, w.opType_), - std::enable_shared_from_this(w), - pgUID_(w.pgUID_), - pgDesc_(w.pgDesc_), - device_(w.device_), - ncclStartEvent_(w.ncclStartEvent_), - ncclEndEvent_(w.ncclEndEvent_), - ncclComm_(w.ncclComm_), - blockingWait_(w.blockingWait_), - opTimeout_(w.opTimeout_), - ownedEphermeralTimeout_(w.ownedEphermeralTimeout_), - workStartTime_(w.workStartTime_), - seq_(w.seq_), - startTraceUpdated_(w.startTraceUpdated_), - numelIn_(w.numelIn_), - numelOut_(w.numelOut_), - store_(w.store_), - timingEnabled_(w.timingEnabled_), - trace_id_(w.trace_id_), - distDebugLevel_(w.distDebugLevel_) { - exception_ = w.exception_; -} - -ProcessGroupNCCL::WorkNCCL::~WorkNCCL() = default; - -bool ProcessGroupNCCL::WorkNCCL::isCompleted() { - if (!ncclComm_->isAborted()) { - checkAndSetException(); - } - return exception() || finishedGPUExecutionInternal(); -} + // Checks if the NCCL kernel has started to execute. + bool isStarted(); -bool ProcessGroupNCCL::WorkNCCL::isStarted() { - if (!ncclComm_->isAborted()) { - checkAndSetException(); - } - return exception() || startedGPUExecutionInternal(); -} + // Checks if request has completed. In this specific case of NCCL, it checks + // if the NCCL operation has completed on the GPU in its own NCCL stream. + // Non-blocking operation. + bool isCompleted() override; -bool ProcessGroupNCCL::WorkNCCL::isSuccess() const { - C10_THROW_ERROR(NotImplementedError, "WorkNCCL::isSuccess() is deprecated"); -} + bool isSuccess() const override; -void ProcessGroupNCCL::WorkNCCL::checkAndSetException() { - if (exception()) { - // We already have an exception. - return; - } + // Same as calling synchronize() for NCCL work. + bool wait(std::chrono::milliseconds timeout = kNoTimeout) override; - auto exception_ptr = checkForNCCLErrors(); - std::unique_lock lock(mutex_); - exception_ = exception_ptr; - if (exception_) { - LOG(ERROR) << logPrefix() << "Collective " << *this - << " raised the following async exception: " - << getExceptionMsgFromExceptionPtr(exception_); - } -} - -const std::string& ProcessGroupNCCL::WorkNCCL::logPrefix() const { - static std::string prefix = c10::str("[Rank ", rank_, "] "); - return prefix; -} - -void ProcessGroupNCCL::WorkNCCL::setException( - std::exception_ptr exception_ptr) { - std::unique_lock lock(mutex_); - exception_ = exception_ptr; -} - -// Helper that checks if the NCCL kernels are completed on the GPUs -bool ProcessGroupNCCL::WorkNCCL::finishedGPUExecution() { - checkAndSetException(); - return finishedGPUExecutionInternal(); -} - -bool ProcessGroupNCCL::WorkNCCL::startedGPUExecutionInternal() const { - // if timing is disabled we won't have allocated start events - if (!timingEnabled_) { - return false; - } - // Checking the work's corresponding CUDA event's status - if (!ncclStartEvent_->query()) { - return false; - } - return true; -} - -bool ProcessGroupNCCL::WorkNCCL::finishedGPUExecutionInternal() const { - // Checking the work's corresponding CUDA event's status - // It calls `cudaEventQuery` eventually. Although this seems to be a - // non-blocking call, but we did notice hangs in the past. It can - // hang if another thread is holding the CUDA global context lock. For - // example, when doing a `cudaDeviceSynchronize` or even - // `cudaStreamSynchronize`. - if (!ncclEndEvent_->query()) { - return false; - } - return true; -} - -bool ProcessGroupNCCL::WorkNCCL::checkTimeout( - std::optional timeout) { - STATIC_SCOPED_WAIT_COUNTER( - pytorch.wait_counter.ProcessGroupNCCL__checkTimeout); - auto currentTimepoint = std::chrono::steady_clock::now(); - auto timeElapsed = std::chrono::duration_cast( - currentTimepoint - workStartTime_); - auto workTimeout = timeout ? *timeout : opTimeout_; - - if (timeElapsed < workTimeout) - return false; + void abort() override; - // Timed out + // Let current stream wait on the completing of the NCCL work + // Throws on exceptions. Blocking operation, which will wait for work + // completion. + void synchronize() override; - // There is already an error, we don't override it - if (exception()) - return true; + // Synchronize streams by blocking each on the NCCL stream + void synchronizeStream(); - std::string exceptionMsg = c10::str( - logPrefix(), - "Watchdog caught collective operation timeout: ", - *this, - " ran for ", - timeElapsed.count(), - " milliseconds before timing out."); - - LOG(ERROR) << exceptionMsg; - std::exception_ptr exception_ptr = - std::make_exception_ptr(C10_BUILD_ERROR(DistBackendError, exceptionMsg)); - setException(exception_ptr); - return true; -} - -void ProcessGroupNCCL::WorkNCCL::handleException( - ErrorHandlingMode errorHandling) { - if (exception_) { - auto exceptionMsg = c10::str( - "Some NCCL operations have failed or timed out. Due to the ", - "asynchronous nature of CUDA kernels, subsequent GPU operations ", - "might run on corrupted/incomplete data."); - LOG(ERROR) << logPrefix() << exceptionMsg; - C10_LOG_API_USAGE_ONCE("ProcessGroupNCCL.WorkNCCL.handleException"); - - if (SHOULD_TEAR_DOWN(errorHandling)) { - auto tearDownMsg = c10::str( - "To avoid data inconsistency, we are taking the entire process down."); - LOG(ERROR) << logPrefix() << tearDownMsg; - std::rethrow_exception(exception_); - } - } -} + // Helper function to handle exception (throw if needed). + void handleException(ErrorHandlingMode asyncErrorHandling); -void ProcessGroupNCCL::WorkNCCL::synchronize() { - // Call Synchronize without a timeout. We use this method to avoid adding a - // timeout argument to the public synchronize API. - synchronizeInternal(kNoTimeout); -} + // Helper function that checks if the NCCL kernels have finished + // execution on the GPUs + bool finishedGPUExecution(); -void ProcessGroupNCCL::WorkNCCL::synchronizeStream() { - auto currentStream = at::cuda::getCurrentCUDAStream(device_.index()); - // Block the current stream on the NCCL stream - ncclEndEvent_->block(currentStream); + // Get a Future object that will be marked as completed internally. + c10::intrusive_ptr getFuture() override; - if (avoidRecordStreams_) { - stashed_for_allocator_safety_->clear(); - } -} - -// Waiting on the work's corresponding CUDA events -void ProcessGroupNCCL::WorkNCCL::synchronizeInternal( - std::chrono::milliseconds timeout) { - synchronizeStream(); - - // In case of blocking, wait for the operation to complete. - if (blockingWait_) { - while (!isCompleted()) { - bool timedOut = checkTimeout( - timeout == kNoTimeout ? std::nullopt : std::make_optional(timeout)); - // Explicitly abort ncclComms here before throwing this timed out - // exception to users. - // If throwing timed out excepiton without aborting nccl communicators - // here, it was observed that CUDA GPU will have 100% utilization and - // can not run new events successfully. - if (timedOut) { - std::string exceptionMsg = c10::str( - logPrefix(), - "Work ", - (*this), - " timed out in blocking wait (TORCH_NCCL_BLOCKING_WAIT=1)."); - LOG(ERROR) << exceptionMsg; - break; - } - // Yield - std::this_thread::sleep_for( - std::chrono::milliseconds(kSynchronizeBusyWaitMillis)); - } - // exception() includes timeout and error during blocking wait - if (exception()) { - // Abort NCCL communicators - abort(); - // Throw exception (from main thread here) - handleException(TearDown); - } - } + float getDuration() const override; - // Device synchronize only after we've completed timeout checks. - if (barrierTensor_.defined()) { - // If we use the work to do barrier, we should block here - // `dist.barrier()` only requires all CPU processes to enter this - // function, hence we only need to make sure the dummy all-reduce has - // completed. So we would only need to sync the **current stream** back to - // host, and do not need to synchronize the entire device (which may have - // kernels running on other streams). - // Using `cudaStreamSynchronize` instead of `cudaDeviceSynchronize` can: - // - lower chance of hang; - // - CurrentCUDAStream is usually the context of the next operation in - // Python, thus blocking current stream would already block the next - // compute kernel; - // - achieve better barrier performance. - auto currentStream = at::cuda::getCurrentCUDAStream(device_.index()); - // CUDAStream wrapper will correctly use a DeviceGuard here - currentStream.synchronize(); - } -} - -// Same as calling synchronize(). -bool ProcessGroupNCCL::WorkNCCL::wait(std::chrono::milliseconds timeout) { - RECORD_PARAM_COMMS( - static_cast(this->seq_), // seq - std::make_tuple(pgUID_, pgDesc_), // PG name tuple - rank_, // rank - "wait", // collective name - 0, // inNelems - 0, // outNelems - at::kByte, // dType - std::vector(), // inSplitSizes - std::vector(), // outSplitSizes - -1, - -1, - static_cast(1)); // number of device? - synchronizeInternal(timeout); - // TODO(kwen2501): this should be moved to c10d tests, to qualify a NCCL - // upgrade. Once a NCCL version is qualified, this code should not be needed - // at runtime. -#ifdef PGNCCL_ENABLE_HASH - if (distDebugLevel_ >= DebugLevel::Detail) { - auto numel = getTensorsNumel(*outputs_); - auto hashValue = hashTensors(*outputs_); - PRINT_COLLECTIVE_HASH_SIGNATURE( - "output", opTypeToString(opType_), numel, hashValue); - } -#endif - // Always return true, because abort API is not implemented. - return true; -} - -void ProcessGroupNCCL::WorkNCCL::abort() { - // Abort all communicators of this work - ncclComm_->ncclCommAbort(); - - ncclCommDevIdxMapMutex.lock(); - ncclCommDevIdxMap.erase(ncclComm_); - ncclCommDevIdxMapMutex.unlock(); -} - -ProcessGroupNCCL::CUDAEventCache::CUDAEventCache() {} - -// CUDA event is used to record the start/end of one Work. -// Instead of let the CUDA event gets destroyed, we now reuse it after the Work -// has been erased from workMetaList_. -// This is to avoid the potential deadlock caused by CudaEventDestroy. -std::shared_ptr ProcessGroupNCCL::CUDAEventCache::create( - bool timing) { - auto deleter = [this, timing](at::cuda::CUDAEvent* event) { - std::lock_guard lock(this->cacheMutex_); - this->eventsArray_[timing ? 1 : 0].push_back(event); - }; - at::cuda::CUDAEvent* event = nullptr; - { - std::lock_guard lock(cacheMutex_); - auto events = eventsArray_[timing ? 1 : 0]; - if (!events.empty()) { - event = events.back(); - events.pop_back(); - } - } - if (!event) { - event = new at::cuda::CUDAEvent( - timing ? cudaEventDefault : cudaEventDisableTiming); - } - return std::shared_ptr(event, std::move(deleter)); -} - -ProcessGroupNCCL::CUDAEventCache& ProcessGroupNCCL::CUDAEventCache::get() { - static ProcessGroupNCCL::CUDAEventCache cache; - return cache; -} - -static std::atomic process_group_id = 0; - -constexpr const char* MULTI_DEVICE_ERROR_MSG = - "Expecting one tensor only but got multiple. You are probably using multiple " - "devices under one thread. The support for such usage has been deprecated. " - "For details, please refer to " - "https://pytorch.org/docs/stable/distributed.html#multi-gpu-collective-functions. " - "ProcessGroupNCCL continues supporting multi-process and multi-thread modes."; - -ProcessGroupNCCL::ProcessGroupNCCL( - const c10::intrusive_ptr& store, - int rank, - int size, - c10::intrusive_ptr options) - : Backend(rank, size), - store_(store), - options_(options), - ncclCommCounter_(0), - traceKeyStart_(getTraceStartKey("NCCL", rank)), - traceKeyEnd_(getTraceEndKey("NCCL", rank)), - terminateProcessGroup_(false), - terminateHeartbeatMonitorThread_(false), - collectiveDebugInfoMode_(false), - local_id_(process_group_id++), - intraNodeComm_(initIntraNodeComm()) { - TORCH_CHECK_WITH( - ValueError, - at::cuda::getNumGPUs() != 0, - "ProcessGroupNCCL is only supported with GPUs, no GPUs found!"); - - // getNcclVersion needs to get called before launching threads which can - // potentially call getenv. getNcclVersion internally calls setenv to set some - // environment variables from config file, which can race with getenv from - // other threads and cause segfaults. - const auto ncclVersion = getNcclVersion(); - this->setGroupUid(options_->group_name); - this->localDeviceCount_ = at::cuda::getNumGPUs(); - logPrefix_ = createLogPrefix(); - blockingWait_ = getCvarBool(TORCH_NCCL_BLOCKING_WAIT, false); - asyncErrorHandling_ = static_cast( - getCvarInt(TORCH_NCCL_ASYNC_ERROR_HANDLING, 3 /*SkipCleanUp*/)); - desyncDebug_ = getCvarBool(TORCH_NCCL_DESYNC_DEBUG, false) || - (dist_debug_level_ >= DebugLevel::Detail); - rethrowCUDAErrors_ = getCvarBool(TORCH_NCCL_RETHROW_CUDA_ERRORS, true); - // TODO, we should either deprecate TORCH_NCCL_DUMP_ON_TIMEOUT - // or change its name to reflect that dump happens on exception including - // both timeout and other errors. - dumpOnTimeoutOrEx_ = getCvarBool(TORCH_NCCL_DUMP_ON_TIMEOUT, false) || - (dist_debug_level_ >= DebugLevel::Detail); - sleepAfterException_ = getCvarBool(TORCH_NCCL_SLEEP_AFTER_EXCEPTION, false); - // logging C++ stack isn't safe. Introduce a variable to control it. - logCppStackOnUncleanShutdown_ = - getCvarBool(TORCH_NCCL_LOG_CPP_STACK_ON_UNCLEAN_SHUTDOWN, true); - enableNanCheck_ = getCvarBool(TORCH_NCCL_NAN_CHECK, false); - heartbeat_ = 1ULL; - monitorThreadEnabled_.store(getCvarBool(TORCH_NCCL_ENABLE_MONITORING, true)); - cudaEventCacheEnabled_.store(getCvarBool(TORCH_NCCL_CUDA_EVENT_CACHE, false)); - heartbeatTimeoutInSec_ = - getCvarInt(TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC, 60 * 8 /*8 Mins*/); - waitTimeoutDumpInMilSec_ = - getCvarInt(TORCH_NCCL_WAIT_TIMEOUT_DUMP_MILSEC, 60 * 1000 /*60 Sec*/); - coordCheckIntervalMilSec_ = getCvarInt(TORCH_NCCL_COORD_CHECK_MILSEC, 1000); - ncclTraceBufferSize_ = getCvarInt(TORCH_NCCL_TRACE_BUFFER_SIZE, 0); - enableCollecticeHashDebug_ = (dist_debug_level_ >= DebugLevel::Detail); - // store_ usually is wrapped with PrefixStore and the prefix is different - // across different ProcessGroupNCCL(PG) instances. We need to get the - // underlying non-PrefixStore for sharing global information shared across - // different PGs. - PrefixStore* prefixStore = dynamic_cast(store_.get()); - globalStore_ = - prefixStore ? prefixStore->getUnderlyingNonPrefixStore() : store_; -#ifdef ENABLE_NCCL_ERROR_CHECKING - enableTiming_.store( - getCvarBool(TORCH_NCCL_ENABLE_TIMING, false) || desyncDebug_); -#endif - avoidRecordStreams_ = getCvarBool(TORCH_NCCL_AVOID_RECORD_STREAMS, false); -#ifdef NCCL_HAS_COMM_REGISTER - useTensorRegisterAllocatorHook_ = - getCvarBool(TORCH_NCCL_USE_TENSOR_REGISTER_ALLOCATOR_HOOK, false); - if (c10::cuda::CUDACachingAllocator::CUDAAllocatorConfig:: - expandable_segments()) { - useTensorRegisterAllocatorHook_ = false; - LOG(INFO) - << logPrefix() - << "disables TORCH_NCCL_USE_TENSOR_REGISTER_ALLOCATOR_HOOK because it is not compatible with CUDA allocator expandable segments mode."; - } -#endif + uint64_t getSequencenumber() const override; - if (blockingWait_) { - if (asyncErrorHandling_ != NoHandling || desyncDebug_) { - LOG(INFO) - << logPrefix() << "TORCH_NCCL_BLOCKING_WAIT and " - << "TORCH_NCCL_ASYNC_ERROR_HANDLING|TORCH_NCCL_DESYNC_DEBUG" - << "should not both be enabled. " - << "Only TORCH_NCCL_BLOCKING_WAIT is being used in this process."; - asyncErrorHandling_ = NoHandling; - desyncDebug_ = false; - } - } else { - if (desyncDebug_ && asyncErrorHandling_ == NoHandling) { - LOG(INFO) - << logPrefix() - << "TORCH_NCCL_DESYNC_DEBUG and TORCH_NCCL_ASYNC_ERROR_HANDLING " - << "must both be enabled. " - << "Enabling TORCH_NCCL_ASYNC_ERROR_HANDLING."; - asyncErrorHandling_ = SkipCleanUp; - } - } + const std::string& logPrefix() const; -#ifdef ENABLE_NCCL_ERROR_CHECKING - ncclCommWatchdogThread_ = - std::thread(&ProcessGroupNCCL::ncclCommWatchdog, this); -#endif + // Helper function that sets an exception_ptr on the WorkNCCL object. + void setException(std::exception_ptr exception_ptr); - init(); - const std::string OFF = "OFF"; - std::string torch_distributed_debug = - getCvarString({"TORCH_DISTRIBUTED_DEBUG"}, OFF.c_str()); - LOG(INFO) << logPrefix() << "ProcessGroupNCCL initialization options: " - << "size: " << size << ", global rank: " << globalRank() - << ", TIMEOUT(ms): " << options_->timeout.count() - << ", USE_HIGH_PRIORITY_STREAM: " - << options_->is_high_priority_stream - << ", SPLIT_FROM: " << options_->split_from - << ", SPLIT_COLOR: " << options_->split_color - << ", PG Name: " << options_->group_name; - - LOG(INFO) << logPrefix() << "ProcessGroupNCCL environments: " - << "NCCL version: " << ncclVersion - << ", TORCH_NCCL_ASYNC_ERROR_HANDLING: " << asyncErrorHandling_ - << ", TORCH_NCCL_DUMP_ON_TIMEOUT: " << dumpOnTimeoutOrEx_ - << ", TORCH_NCCL_WAIT_TIMEOUT_DUMP_MILSEC: " - << waitTimeoutDumpInMilSec_ - << ", TORCH_NCCL_DESYNC_DEBUG: " << desyncDebug_ - << ", TORCH_NCCL_ENABLE_TIMING: " << enableTiming_.load() - << ", TORCH_NCCL_BLOCKING_WAIT: " << blockingWait_ - << ", TORCH_DISTRIBUTED_DEBUG: " << torch_distributed_debug -#ifdef NCCL_HAS_COMM_REGISTER - << ", TORCH_NCCL_USE_TENSOR_REGISTER_ALLOCATOR_HOOK: " - << useTensorRegisterAllocatorHook_ -#endif - << ", TORCH_NCCL_ENABLE_MONITORING: " - << monitorThreadEnabled_.load() - << ", TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC: " << heartbeatTimeoutInSec_ - << ", TORCH_NCCL_TRACE_BUFFER_SIZE: " << ncclTraceBufferSize_ - << ", TORCH_NCCL_COORD_CHECK_MILSEC: " << coordCheckIntervalMilSec_ - << ", TORCH_NCCL_NAN_CHECK: " << enableNanCheck_ - << ", TORCH_NCCL_CUDA_EVENT_CACHE: " << cudaEventCacheEnabled_ - << ", TORCH_NCCL_LOG_CPP_STACK_ON_UNCLEAN_SHUTDOWN: " - << logCppStackOnUncleanShutdown_; - - if (options_->global_ranks_in_group.empty()) { - this->globalRankStart = 0; - } else { - this->globalRankStart = options_->global_ranks_in_group[0]; - } + // Helper function that returns True if the WorkNCCL object has timed out + // and False otherwise. + // In case of timeout, set exception on the WorkNCCL object. + bool checkTimeout( + std::optional timeout = std::nullopt); + + std::vector result() override; - if (options_->global_ranks_in_group.empty()) { - this->globalRankStride = 1; - } else if (options_->global_ranks_in_group.size() == 1) { - this->globalRankStride = 0; - } else { - bool ranksAreStrided = true; - int startRank = options_->global_ranks_in_group[0]; - int stride = - options_->global_ranks_in_group[1] - options_->global_ranks_in_group[0]; - for (std::vector::size_type i = 0; - i < options_->global_ranks_in_group.size(); - i++) { - if (options_->global_ranks_in_group[i] != startRank + i * stride) { - ranksAreStrided = false; - break; - } - } + protected: + // The process group unique id + std::string pgUID_; + + // The process group description + std::string pgDesc_; + + // The cached list of CUDA devices to operate on + at::Device device_; + + // The start CUDA event of NCCL operator tracking this work item. These + // start CUDA events are needed by desync debugging if enabled. + std::shared_ptr ncclStartEvent_; + + // The end CUDA event of NCCL operator tracking this work item. + std::shared_ptr ncclEndEvent_; + + // The NCCL communicator used for this work item. + std::shared_ptr ncclComm_; + + // Tensors used for barrier op + at::Tensor barrierTensor_; + + // Clone of blockingWait_ from ProcessGroupNCCL. + bool blockingWait_ = false; + + // Clone of avoidRecordStreams_ from ProcessGroupNCCL. + bool avoidRecordStreams_ = false; - if (ranksAreStrided) { - this->globalRankStride = options_->global_ranks_in_group[1] - - options_->global_ranks_in_group[0]; - } else { - this->globalRankStride = -1; - } - } + // Clone of opTimeout_ from ProcessGroupNCCL. + std::chrono::milliseconds opTimeout_; - // Attach hooks to cache allocator to trigger the hooks whenever a traced - // action is called. In the following hooks, we register a newly allocated - // segment when SEGMENT_ALLOC action occurs, and deregister a segment when - // SEGMENT_FREE action occurs. - // We attach hooks only once at the first PG creation. - // Attaching hooks fails if CUDACachingAllocator is not initialized, so - // lazyInitCUDA is called (and is a no-op if CUDA is already initialized). - if (useTensorRegisterAllocatorHook_ && !allocatorHooksAttached) { - at::globalContext().lazyInitCUDA(); - c10::cuda::CUDACachingAllocator::attachAllocatorTraceTracker( - &cacheAllocatorRegisterHook); - c10::cuda::CUDACachingAllocator::attachAllocatorTraceTracker( - &cacheAllocatorDeregisterHook); - allocatorHooksAttached = true; - } -} - -void ProcessGroupNCCL::eagerConnectSingleDevice(at::Device device) { - const auto key = getKeyFromDevice(device); - LOG(INFO) << logPrefix() << "Eagerly connecting nccl backend with device " - << device; - getNCCLComm(key, device, OpType::ALLREDUCE); -} - -void ProcessGroupNCCL::performNocolorSplit(at::Device device) { - // If our backend doesn't support splitting, this is a no-op for - // ranks not in the new subgroup (and ranks that would be in it will - // just use a new communicator rather than split). -#ifdef NCCL_HAS_COMM_SPLIT - const auto key = getKeyFromDevice(device); - LOG(INFO) << logPrefix() << "Performing nocolor split on backend device " - << device << ", key " << key << ", i am " << this; - auto comm = getNCCLComm(key, device, OpType::ALLREDUCE); - NCCLComm::split( - comm.get(), - NCCL_SPLIT_NOCOLOR, - rank_, - options_->config, - options_->global_ranks_in_group); -#endif -} + // Ephemeral timeouts are owned by exactly one work, + // and reset after that work completes. + // There may be more than one ephemeral timeout active at the same time, + // and this variable is used to track the ownership of ephemeral timeout. + std::chrono::milliseconds ownedEphermeralTimeout_ = + std::chrono::milliseconds(0); -c10::intrusive_ptr ProcessGroupNCCL:: - initIntraNodeComm() { - using IntraNodeComm = intra_node_comm::IntraNodeComm; - if (!IntraNodeComm::isEnabled()) { - return nullptr; - } - auto prefixStore = c10::make_intrusive("IntraNodeComm", store_); - auto comm = c10::make_intrusive(prefixStore, rank_, size_); - if (comm->rendezvous()) { - return comm; - } else { - return nullptr; - } -} - -void ProcessGroupNCCL::setSequenceNumberForGroup() { -} // NCCL just starts sequence numbers at 0. - -uint64_t ProcessGroupNCCL::getSequenceNumberForGroup() { - return seqCollective_; -} - -void ProcessGroupNCCL::registerOnCompletionHook( - std::function)>&& hook) { - TORCH_CHECK_WITH( - DistBackendError, - onCompletionHook_ == nullptr, - "ProcessGroupNCCL OnCompletion hook already registered"); - - TORCH_CHECK_WITH( - ValueError, - enableTiming_.load(), - "ProcessGroupNCCL OnCompletion hook requires recording start and end " - "events which require setting TORCH_NCCL_ENABLE_TIMING environment variable. " - "This is only available for NCCL version >= 2.4."); - onCompletionHook_ = std::move(hook); - onCompletionHookThread_ = std::thread(&ProcessGroupNCCL::runHookLoop, this); -} - -// must release GIL when calling this method -void ProcessGroupNCCL::waitForPendingWorks() { - // Reasoning about hook completion: - // 1. waitForPendingWorks should be called after user code has finished - // calling - // all collectives. This means, when we got here, all of the collectives - // are either in workMetaList_ or has been erased from workMetaList_. - // 2. The watchdog thread grabs both locks to move Work object from the - // workMetaList_ to the completedWorkList_, and the hook thread only erases - // a Work object after the hook is returned. Therefore, after user code - // calls a collective, its Work object is either in workMetaList_ or in - // completedWorkList_ before it finishes. - // 3. We have three threads and two locks. - // a. main thread (this function) grabs two locks atomically - // b. watchdog thread (watchdogHandler function) always grabs - // workMetaListMutex_ - // first and then grabs completedWorkListMutex_. - // c. hook thread (runHookLoop function) only grabs - // completedWorkListMutex_. Therefore, locks are always acquired in the - // same order and hence no deadlocks. - while (true) { - { - std::lock(workMetaListMutex_, completedWorkListMutex_); - std::lock_guard lockWork(workMetaListMutex_, std::adopt_lock); - std::lock_guard lockHook( - completedWorkListMutex_, std::adopt_lock); - - if (workMetaList_.empty() && completedWorkList_.empty()) { - return; - } - } + // Time point representing when the work started. + std::chrono::time_point workStartTime_; - std::this_thread::sleep_for( - std::chrono::milliseconds(kWatchdogThreadSleepMillis)); - } -} - -void ProcessGroupNCCL::enableCollectivesTiming() { - enableTiming_.store(true); -} - -void ProcessGroupNCCL::waitForFutureOrTimeout( - std::future& fut, - const std::chrono::milliseconds& timeOutMilSec, - const std::string& futDescription, - bool throwException, - bool log) { - std::string errorMsg; - - ::c10d::C10dLoggingData data; - if (log) { - data.integers["pg_id"] = local_id_; - data.integers["rank"] = rank_; - data.integers["global_rank"] = globalRank(); - data.strings["flight_recorder_version"] = c10d::version_val_str; - } + // Record the collective sequential number. + uint64_t seq_; - TORCH_CHECK(fut.valid(), "Expected a valid future"); - std::future_status status = fut.wait_for(timeOutMilSec); - if (status == std::future_status::ready) { - // Calling .get() will re-raise any exception from the future, and we don't - // care about the retval - try { - bool result = fut.get(); - if (result) { - LOG(INFO) << logPrefix() - << "future is successfully executed for: " << futDescription; - if (log) { - data.strings["status"] = "SUCCESS"; - } - } - } catch (const std::exception& e) { - errorMsg = c10::str( - logPrefix(), - "Exception thrown when waiting for future ", - futDescription, - ": ", - e.what()); - if (log) { - data.strings["status"] = "EXCEPTION"; - data.strings["exception"] = e.what(); - } - LOG(ERROR) << errorMsg; - } catch (...) { - errorMsg = c10::str( - logPrefix(), - "Unknown exception thrown when waiting for future ", - futDescription); - if (log) { - data.strings["status"] = "EXCEPTION"; - data.strings["exception"] = "Unknown exception"; - } - LOG(ERROR) << errorMsg; - } - } else { - errorMsg = c10::str( - logPrefix(), - "Future for ", - futDescription, - " timed out after ", - timeOutMilSec.count(), - " ms"); - data.strings["status"] = "TIMEOUT"; - LOG(ERROR) << errorMsg; - } - if (log) { - auto logger = c10d::C10dLogger::getLogger(); - if (logger) { - logger->log(data); - } - } - if (throwException && !errorMsg.empty()) { - C10_THROW_ERROR(DistBackendError, errorMsg); - } -} - -void ProcessGroupNCCL::abortCommsFromMap( - std::unordered_map>& ncclCommsMap, - std::optional abortReason) { - // The process may control multiple devices, loop through the communicators on - // each device - for (auto& it : ncclCommsMap) { - auto& devName = it.first; - auto& ncclComm = it.second; - at::cuda::OptionalCUDAGuard gpuGuard; - at::DeviceIndex deviceIndex = getIndexFromDeviceKey(devName); - if (deviceIndex >= 0) { - // For P2P comms, the deviceIndex could be -1 (invalid), as the keys in - // the map could be non deviceIndex, but rank to rank numbers. So we - // indeed need to check if deviceIndex >= 0 - // TODO: fix `getIndexFromDeviceKey` or fix `DeviceKey` - gpuGuard.set_index(deviceIndex); - } - LOG(INFO) << logPrefix() << "ProcessGroupNCCL destroying ncclComm_ " - << ncclComm->ncclComm_ << " on CUDA device: " << devName; - ncclComm->ncclCommAbort(abortReason); - // Note that we don't remove the aborted communicators from the - // cache. The reason is that if we do remove the communicator - // from the cache, it is possible that a new collective operation - // calls `ncclCommInitRank` to create a new communicator whereas - // other ranks might have failed/timed out and didn't enter - // `ncclCommInitRank`. As a result, when there is a failure on - // a communicator the application receives an exception and its - // their responsibility to destroy the process group and recreate - // it to recover from errors. - - LOG(INFO) << logPrefix() << "ProcessGroupNCCL destroyed " - << " communicator on CUDA device: " << devName; - } -} - -// Abort all communicators on this rank -bool ProcessGroupNCCL::abort(std::optional abortReason) { - // This will log counter for how long the abort actually takes. - STATIC_SCOPED_WAIT_COUNTER(pytorch.ProcessGroupNCCL__abort); - // Remove record from global ncclCommDevIdxMapMutex before aboarting, - // so that a new cache segment would not register to already aborded - // communicators. Note that ncclCommDevIdxMap is a global container which may - // contain other PG's communicators, thus we need to only erase communicators - // for the current PG. - ncclCommDevIdxMapMutex.lock(); - for (auto& it : devNCCLCommMap_) { - auto& ncclComm = it.second; - ncclCommDevIdxMap.erase(ncclComm); - } - ncclCommDevIdxMapMutex.unlock(); - - std::lock_guard lock(mutex_); - abortCommsFromMap(devNCCLCommMap_, abortReason); - abortCommsFromMap(inInitializationCommMap_, abortReason); - return true; -} - -void ProcessGroupNCCL::shutdown(std::optional reason) { - // Don't join threads here since the purpose of this method is to abort all - // communicators and signal the threads to exit. Joining on the threads could - // potentially block and hence avoid it in this method. - terminateProcessGroup_.store(true); - workMetaListCV_.notify_one(); - - // lauch abort asynchrounously and wait for it to complete or timeout - LOG(INFO) << logPrefix() - << "Launching ProcessGroupNCCL abort asynchrounously."; - std::future fut = std::async( - std::launch::async, [this, &reason]() { return this->abort(reason); }); - - waitForFutureOrTimeout( - fut, options_->timeout, "ProcessGroup abort", true, false); - LOG(INFO) << logPrefix() << "ProcessGroupNCCL aborts successfully."; - - // We need to wait for abort to finish before we can safely shut down - // heartbeat monitoring thread. - terminateHeartbeatMonitorThread_.store(true); - monitorWakeUpCV_.notify_one(); -} - -ProcessGroupNCCL::~ProcessGroupNCCL() { - LOG(INFO) << logPrefix() << "ProcessGroupNCCL destructor entered."; - - if (!terminateProcessGroup_.load()) { - if (rank_ % localDeviceCount_ == 0) { - TORCH_WARN_ONCE( - "WARNING: process group has NOT been destroyed before we destruct ProcessGroupNCCL. ", - "On normal program exit, the application should call destroy_process_group to ", - "ensure that any pending NCCL operations have finished in this process. " - "In rare cases this process can exit before this point and block the progress of " - "another member of the process group. This constraint has always been present, " - " but this warning has only been added since PyTorch 2.4"); - } - // If user haven't explicitly destroy/shutdown process group, destructor - // needs to do so - shutdown(); - } + // Indicates if the nccl start event has been updated to the store trace. + // This will be used by desync debug. + bool startTraceUpdated_{false}; - // Wait for all threads to finish before returning -#ifdef ENABLE_NCCL_ERROR_CHECKING - if (ncclCommWatchdogThread_.joinable()) { - ncclCommWatchdogThread_.join(); - LOG(INFO) << logPrefix() << "ProcessGroupNCCL watchdog thread joined."; - } - if (ncclHeartbeatMonitorThread_.joinable()) { - ncclHeartbeatMonitorThread_.join(); - LOG(INFO) << logPrefix() - << "ProcessGroupNCCL heart beat monitor thread joined."; - } -#endif - if (onCompletionHookThread_.joinable()) { - onCompletionHookThread_.join(); - LOG(INFO) << logPrefix() - << "ProcessGroupNCCL onCompletionHookThread thread joined."; - } -} - -bool ProcessGroupNCCL::dumpDebuggingInfo() { - // Serialize all calls to this function to avoid corrupting data, but allow - // multiple calls in one runtime. User is responsible for preserving the - // output file from an earlier call before a later call overwrites it. - static std::mutex writeDebugInfoMutex; - std::lock_guard lock(writeDebugInfoMutex); - LOG(ERROR) << logPrefix() << "ProcessGroupNCCL preparing to dump debug info."; - if (ncclTraceBufferSize_ > 0) { - // We dump nccl trace into local disk by default and users can register - // their customized writer by inheriting `DebugInfoWriter` via - // `registerDebugInfoWriter`. - auto ncclTrace = dump_nccl_trace(true, true, false); - DebugInfoWriter& writer = DebugInfoWriter::getWriter(globalRank()); - LOG(INFO) << logPrefix() << "ProcessGroupNCCL dumping nccl trace to " - << writer.getWriterTarget(); - writer.write(ncclTrace); - return true; - } - return false; -} - -void ProcessGroupNCCL::terminateProcess(std::string errMsg) { - // Logging with `FATAL`, after errMsg printed, it calls `std::abort()` - // to terminate the program execution. - LOG(FATAL) << logPrefix() << errMsg; -} - -int computeDeltaMS( - std::chrono::time_point start, - std::chrono::time_point end) { - return std::chrono::duration_cast(end - start) - .count(); -} - -std::string ProcessGroupNCCL::getNCCLWatchdogTimeoutErrorMsg( - const std::string& extraMsg) { - return c10::str( - logPrefix(), - "Received a dump signal due to a collective timeout from ", - extraMsg, - " and we will try our best to dump the debug info. ", - "Last enqueued NCCL work: ", - pgStatus_->lastEnqueuedSeq, - ", last completed NCCL work: ", - pgStatus_->lastCompletedSeq, - ".", - "This is most likely caused by incorrect usages of collectives, e.g., wrong ", - "sizes used across ranks, the order of collectives is not same for all ranks ", - "or the scheduled collective, for some reason, didn't run. Additionally, ", - "this can be caused by GIL deadlock or other reasons such as network errors or ", - "bugs in the communications library (e.g. NCCL), etc. "); -} - -std::string ProcessGroupNCCL::getNCCLWatchdogTimeoutExitMsg( - const std::string& exitReason) { - return c10::str( - logPrefix(), - "Terminating the process after attempting to dump debug info, due to ", - exitReason, - "."); -} - -void ProcessGroupNCCL::heartbeatMonitor() { - c10::setThreadName("pt_nccl_heartbt"); - - uint64_t heartBeatCounter = 0ULL; - std::string errorMsg; - std::string exitReason; - bool checkDumpSignal = (dumpOnTimeoutOrEx_ && local_id_ == 0); - int monitorPollInterval = checkDumpSignal ? coordCheckIntervalMilSec_ - : heartbeatTimeoutInSec_ * 1000; - auto lastTimePollStore = std::chrono::steady_clock::now(); - auto lastTimeHeartBeatCheck = std::chrono::steady_clock::now(); - std::optional dumpPipe = std::nullopt; - if (local_id_ == 0) { - // DumpPipe is one per-trainer process, and its convenient to name them - // after 'global' ranks in the system, So we assume processgroup (uid)==0 is - // the global PG and has globally unique rank ids across trainers. - dumpPipe.emplace(rank_); - } - while (true) { - // This won't have any lock since this lock is only used here. - // Please be aware that mutex `monitorMutex_` should not be used - // somewhere else to avoid the deadlock. - std::unique_lock lock(monitorMutex_); - if (monitorWakeUpCV_.wait_for( - lock, std::chrono::milliseconds(monitorPollInterval), [&] { - return terminateHeartbeatMonitorThread_.load(); - })) { - // For the normal complete or user interception, monitorWakeUpCV_ - // will get notified, we early return and exit heartbeatMonitor. - return; - } - auto currentTime = std::chrono::steady_clock::now(); - - // We put extra functionality in the thread for the default PG (aka, - // local_id_=0) because the signal is same across different PGs. We only - // need to run once per process to avoid duplicate things performed in too - // many separate threads. For example, we check a global flag on the - // TCPStore periodically to see if any PG on any rank observed a timeout and - // signaled peers to dump debugging info, and we avoid hammering the - // TCPStore from all PGs on the same rank. - if (checkDumpSignal) { - // There are two scenarios where monitor thread will dump on timeout: - // 1. The current rank is the first to observe a timeout in watchdog. - // (shouldDump_ was set to true by the watchdog thread). - // 2. Other ranks detected the timeout and signal the current rank to - // dump. In addtion, monitor threads will dump if watchdog threads has no - // heartbeat or dumpPipe is not empty. - if (shouldDump_.load()) { - errorMsg = getNCCLWatchdogTimeoutErrorMsg("this local rank"); - exitReason = "collective timeout or exception"; - break; - } - // We poll store to see if some ranks have flagged a timeout when - // we haven't polled for `heartbeat_timeout` seconds and there haven't - // any work added or removed for `watchdog_timeout` seconds. - if (computeDeltaMS(lastWorkListUpdateTime_, currentTime) >= - kWatchdogThreadSleepMillis && - computeDeltaMS(lastTimePollStore, currentTime) >= - coordCheckIntervalMilSec_) { - lastTimePollStore = currentTime; - // Wrap globalStore_->check() in a try-catch block to avoid crashing if - // the store is not available. - bool checkExceptionDump = false; - try { - checkExceptionDump = - globalStore_->check({std::string(EXCEPTION_DUMP)}); - } catch (const std::exception& e) { - LOG(WARNING) - << logPrefix() - << "Failed to check the \"should dump\" flag on TCPStore, " - << "(maybe TCPStore server has shut down too early), with error: " - << e.what(); - // We give up for now assuming TCPStore has been torn down. - return; - } - - if (checkExceptionDump) { - int timeOutRank = -1; - if (!shouldDump_.load()) { - LOG(ERROR) - << logPrefix() - << "Observed flight recorder dump signal from another rank via TCPStore."; - } - shouldDump_.store(true); - try { - auto vec = globalStore_->get(std::string(EXCEPTION_DUMP)); - TORCH_CHECK_WITH( - DistBackendError, - vec.size() == sizeof(int), - "Invalid size for the timeout rank ID"); - std::memcpy(&timeOutRank, vec.data(), vec.size()); - } catch (const std::exception& e) { - LOG(ERROR) << logPrefix() - << "Failed to get timeout rank ID from TCPStore." - << e.what(); - } - errorMsg = - getNCCLWatchdogTimeoutErrorMsg(c10::str(" rank ", timeOutRank)); - exitReason = "collective timeout or exception"; - break; - } - } - } + // Record collective sizes for debug. We only record the size on the first + // device as multi-device per process is deprecated + size_t numelIn_ = -1; + size_t numelOut_ = -1; - if (computeDeltaMS(lastTimeHeartBeatCheck, currentTime) >= - heartbeatTimeoutInSec_ * 1000) { - // Check the heart beat of watchdog thread. - lastTimeHeartBeatCheck = currentTime; - auto heartbeat = heartbeat_.load(); - if (heartbeat != heartBeatCounter) { - heartBeatCounter = heartbeat; - } else { - shouldDump_.store(true); - // Watchdog heartbeat timeout. - errorMsg = c10::str( - logPrefix(), - "ProcessGroupNCCL's watchdog got stuck for ", - heartbeatTimeoutInSec_, - " seconds without making progress in monitoring enqueued collectives. ", - "This typically indicates a NCCL/CUDA API (e.g., CudaEventDestroy) hang blocking the watchdog, ", - "and could be triggered by another thread holding the GIL inside a ", - "CUDA api (for example, CudaEventDestroy), or other deadlock-prone behaviors.", - "If you suspect the watchdog is not actually stuck and a longer timeout would help, ", - "you can either increase the timeout (TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC) to a larger value " - "or disable the heartbeat monitor (TORCH_NCCL_ENABLE_MONITORING=0)." - "If either of aforementioned helps, feel free to file an issue to PyTorch about the short timeout " - "or false positive abort; otherwise, please attempt to debug the hang. "); - exitReason = "ProcessGroupNCCL watchdog hang"; - break; - } - } - // process a request to dump the trace. only PG uid 0 will respond to dump - // requests, but this is fine since all PG's feed into the same flight - // recorder and dump. After dump, the training should continue. - if (dumpPipe.has_value() && dumpPipe->shouldDump()) { - // best effort dump, not waiting for the dump here - std::future fut = std::async( - std::launch::async, [this]() { return this->dumpDebuggingInfo(); }); - } - } - LOG(ERROR) << errorMsg; + // Wrapper method for the static checkForNCCLErrors which can be overridden + // for tests. + virtual std::exception_ptr checkForNCCLErrors(); - auto& cpp_dumper = get_cpp_trace_dumper(); - if (logCppStackOnUncleanShutdown_ && cpp_dumper.has_value()) { - LOG(INFO) << "Dumping c++ stacktraces:"; - cpp_dumper.value()([](const std::string& line) { LOG(INFO) << line; }); - } + friend std::ostream& operator<<( + std::ostream& output, + const WorkNCCL& workNCCL); - if (checkDumpSignal && shouldDump_.load()) { - // Store debug info to storage if no other thread does it. (By default to - // local disk) - std::future asyncDebugDump = std::async( - std::launch::async, [this]() { return this->dumpDebuggingInfo(); }); - - // wait for the dump until timeout - log data - waitForFutureOrTimeout( - asyncDebugDump, - std::chrono::milliseconds(waitTimeoutDumpInMilSec_), - "Flight recorder dump in heartbeatMonitor", - false, - true); - } + private: + // Helper function for synchronize + void synchronizeInternal(std::chrono::milliseconds timeout); - if (get_gil_checker() != nullptr) { - auto fut = launchAsyncGilCheck(); - auto kGilCheckTimeout = std::chrono::milliseconds(300); - auto futStatus = fut.wait_for(kGilCheckTimeout); - if (futStatus != std::future_status::ready) { - TORCH_CHECK( - futStatus != std::future_status::deferred, - "Expected the future to have been launched eagerly."); - LOG(ERROR) - << "Could not acquire GIL within 300 ms on exit, possible GIL induced hang"; - } - } else { - LOG(INFO) - << "GIL checker was not registered, perhaps this is a no-python build?"; - } + // Checks for NCCL errors and sets an appropriate exception_ptr. + void checkAndSetException(); - // There are two possible cases for the watchdog thread exit: - // Case one: desync report runs quickly, and it follows the step: - // collective timeout -> desync -> exception handling -> destructors - // -> set terminateHeartbeatMonitorThread_ -> notify monitorWakeUpCV_. - // So the code either early returns above or will skip the sleep below. - // Case two: desync might be slow or get stuck. Or we get stuck in - // destructors, we will sleep for some time before calling std::abort() to - // kill the whole process. - if ((terminateProcessGroup_.load() || collectiveDebugInfoMode_.load() || - shouldDump_.load()) && - !terminateHeartbeatMonitorThread_.load()) { - // Leave another two mins for desync report generation or process group - // destroy. - std::this_thread::sleep_for(std::chrono::seconds(heartbeatTimeoutInSec_)); - LOG(INFO) << logPrefix() << "slept for " << heartbeatTimeoutInSec_ - << " waiting for desync report or process group destroy."; - } + // Just checks whether GPU execution has started, without modifying + // exception_ptr. + bool startedGPUExecutionInternal() const; - // At this point, we either already sleep for another `heartbeatTimeoutInSec_` - // or the thread has finished. Because we don't want to block the monitor - // thread, so We mark the thread detach and the dump of debug info becomes - // "best effort". If the process exit normally, marking it detach also makes - // sense because we don't really care about dumping the debug info. - - // We already log completion inside the thread, so it may not be necessary to - // check the return value here. We mainly use a future so we can exit early - // if done. - - if (!terminateHeartbeatMonitorThread_.load()) { - // Create a error message reported from MonitorThread, so - // we throw exception and make the whole process to be killed. - // TODO(fduwjj): After having a hang debug wiki, we need to update the wiki - // url here. - if (monitorThreadEnabled_.load()) { - terminateProcess(getNCCLWatchdogTimeoutExitMsg(exitReason)); - } else { - // Ideally we want to merge this one with the above one, but we are going - // to remove the kill switch for monitor thread soon, so we keep this one - // for now. - LOG(ERROR) - << logPrefix() - << "ProcessGroupNCCL monitor thread is disabled, but would have terminated the process" - << "after attempting to dump debug info, due to " << exitReason - << "."; - } - } -} - -void ProcessGroupNCCL::ncclCommWatchdog() { - c10::setThreadName("pt_nccl_watchdg"); - - try { - VLOG(2) << logPrefix() << "Process group watchdog thread started!"; - ncclHeartbeatMonitorThread_ = - std::thread(&ProcessGroupNCCL::heartbeatMonitor, this); - watchdogHandler(); - VLOG(2) << logPrefix() - << "Process group watchdog thread terminated normally"; - } catch (std::exception& e) { - if (std::string(e.what()).find("driver shutting down") != - std::string::npos) { - LOG(INFO) - << logPrefix() - << "main process destroyed cuda before watchdog loop exited, terminating watchdog." - << " (Watchdog caught exception: " << e.what(); - - } else { - // Append error message reported from watchdogHandler - const auto exitMsg = c10::str( - logPrefix(), - "Process group watchdog thread terminated with exception: ", - e.what()); - LOG(ERROR) << exitMsg; - if (C10_LIKELY(rethrowCUDAErrors_) || - !(std::string(e.what()).find("CUDA Error"))) { - // TODO(whc) clean up the rethrow - why is it stored in a class var and - // rethrown? - watchDogException_ = - std::make_exception_ptr(C10_BUILD_ERROR(DistBackendError, exitMsg)); - std::rethrow_exception(watchDogException_); - } - } - } catch (...) { - const auto exitMsg = c10::str( - logPrefix(), - "Process group watchdog thread terminated with exception: unknown"); - LOG(ERROR) << exitMsg; - watchDogException_ = - std::make_exception_ptr(C10_BUILD_ERROR(DistBackendError, exitMsg)); - std::rethrow_exception(watchDogException_); - } -} + // Just checks whether GPU execution has completed, without modifying + // exception_ptr. + bool finishedGPUExecutionInternal() const; -void ProcessGroupNCCL::logWorkStart(WorkNCCL& work) { - if (work.startTraceUpdated_) - return; + // Reference to the store so that we can write aborted communicators + // to the store. + c10::intrusive_ptr store_; - if (terminateProcessGroup_.load() || storeError_) - return; + // Store a reference to NCCL collective's outputs, used by result and to + // give a more descriptive message when representing the Work as a string. + std::shared_ptr> outputs_; - work.startTraceUpdated_ = true; - storeError_ = !c10d::traceUpdate( - store_, traceKeyStart_, work.seq_, opTypeToString(work.opType_)); -} + // TORCH_NCCL_AVOID_RECORD_STREAMS implementation helper. + // Stores references to participating non-output tensors (ie inputs, + // flattened intermediates). + // We'll clear this list in synchronizeStream, just after user-facing + // stream(s) are synced with the nccl work stream(s). + // By keeping these refs (as well as outputs_) alive until after the + // collective's work rejoins the user-facing streams, we achieve + // caching allocator safety without any recordStream calls. + // For in-place collectives, some refs stashed here may alias outputs_, + // but that doesn't do any harm. + std::shared_ptr> stashed_for_allocator_safety_; -void ProcessGroupNCCL::logWorkEnd(WorkNCCL& work) { - if (terminateProcessGroup_.load() || storeError_) - return; + // The future returned by getFuture. + c10::intrusive_ptr future_; - // In case the start of the work hasn't been logged - if (!work.startTraceUpdated_) { - logWorkStart(work); - } + bool timingEnabled_; + // unique id used to tell the trace buffer that this + // work has completed + std::optional trace_id_; + DebugLevel distDebugLevel_; + friend class ProcessGroupNCCL; + }; - storeError_ = !c10d::traceUpdate( - store_, traceKeyEnd_, work.seq_, opTypeToString(work.opType_)); -} - -std::string ProcessGroupNCCL::getNCCLWatchdogDebugInfo() { - return retrieveDesyncReport(store_, "NCCL", rank_, size_); -} - -// We want to have both PG ID and global unique ID (guid) for the logging -// prefix. PG ID records how many ProcessGroupNCCL objects were created on a -// specific rank and is a stable index across ranks, which lets users reason -// about, for example, the second PG we initialized on this rank is for FSDP, -// and corresponds with PG ID = 1 on other ranks as well. Unlike PG ID, guid (or -// group name) is a global unique ID across ranks. The guid is either a hash of -// all the ranks in the group or a counter of how many times -// `_process_group_name` is called, essentially it means how many times we -// have PGs users have created. Before using split_group, even if -// we are creating a new sub-PG, all ranks have to call the API at the same -// time, and this makes `group_name` a unique identifier for a group (PG). -std::string ProcessGroupNCCL::createLogPrefix() const { - if (!pg_desc_.empty() && pg_desc_ != "undefined") { - return c10::str( - "[PG ID ", - local_id_, - " PG GUID ", - pg_uid_, - "(", - pg_desc_, - ") Rank ", - rank_, - "] "); - } - return c10::str( - "[PG ID ", local_id_, " PG GUID ", pg_uid_, " Rank ", rank_, "] "); -} - -const std::string& ProcessGroupNCCL::logPrefix() const { - return logPrefix_; -} - -const int& ProcessGroupNCCL::globalRank() const { - static int globalRank = rank_; - return globalRank; -} - -const std::vector& ProcessGroupNCCL::groupRanks() const { - if (options_->global_ranks_in_group.empty() && local_id_ == 0) { - static std::vector globalRanks(size_); - std::iota(globalRanks.begin(), globalRanks.end(), 0); - return globalRanks; - } - return options_->global_ranks_in_group; -} - -void ProcessGroupNCCL::addEphemeralTimeout( - const std::chrono::milliseconds& timeout) { - std::lock_guard timeoutLock(mtxTimeoutExtension_); - ephemeralTimeoutActive_ += timeout; -} - -bool ProcessGroupNCCL::verifyWorkTimeoutForTest( - const c10::intrusive_ptr work, - const std::chrono::milliseconds& timeout) { - // Since collective returns a c10d::Work, we need to cast it to WorkNCCL. - if (auto workNCCL = c10::dynamic_intrusive_pointer_cast(work)) { - // workNCCL is now a c10::intrusive_ptr - return workNCCL->opTimeout_ == timeout; - } - C10_THROW_ERROR( - DistBackendError, "Non c10d::WorkNCCL object returned from collective"); -} - -void ProcessGroupNCCL::watchdogHandler() { - bool done = false; - lastWorkListUpdateTime_ = std::chrono::steady_clock::now(); - auto lastStatusUpdateTime = std::chrono::steady_clock::now(); - std::list completedWorkList; - - while (!done || !terminateProcessGroup_.load()) { - std::unique_lock lock(workMetaListMutex_); - // We busy-poll the work vector every kWatchdogThreadSleepMillis - // milliseconds as long as the atomic is True. - workMetaListCV_.wait_for( - lock, - std::chrono::milliseconds(kWatchdogThreadSleepMillis), - [&]() -> bool { return terminateProcessGroup_.load(); }); - // Bump up heart beat by one. - heartbeat_++; - -// Some versions of GLOG support less-spammy version of LOG_EVERY_MS -// in which case we don't want to spam the logs. -#ifdef LOG_EVERY_MS - // Log the progress of this PG periodically - C10_LOG_EVERY_MS(INFO, kWorkStatusUpdatePeriodMs) << c10::str( - logPrefix(), - "NCCL Work update periodically: ", - "last enqueued NCCL work: ", - pgStatus_->lastEnqueuedSeq, - ", last completed NCCL work: ", - pgStatus_->lastCompletedSeq, - "."); -#endif - auto logger = ::c10d::C10dLogger::getLogger(); - if (logger && - computeDeltaMS( - lastStatusUpdateTime, std::chrono::steady_clock::now()) >= - kWorkStatusUpdatePeriodMs) { - ::c10d::C10dLoggingData data; - // logging integers - data.integers["pg_id"] = local_id_; - data.integers["rank"] = rank_; - data.integers["global_rank"] = globalRank(); - data.integers["last_enqueued_work"] = pgStatus_->lastEnqueuedSeq; - data.integers["last_started_work"] = pgStatus_->lastStartedSeq; - data.integers["last_completed_work"] = pgStatus_->lastCompletedSeq; - data.integers["last_enqueued_numel_in"] = pgStatus_->lastEnqueuedNumelIn; - data.integers["last_enqueued_numel_out"] = - pgStatus_->lastEnqueuedNumelOut; - data.integers["last_completed_numel_in"] = - pgStatus_->lastCompletedNumelIn; - data.integers["last_completed_numel_out"] = - pgStatus_->lastCompletedNumelOut; - // logging strings - data.strings["last_enqueued_work_name"] = pgStatus_->lastEnqueuedWorkName; - data.strings["last_started_work_name"] = pgStatus_->lastStartedWorkName; - data.strings["last_completed_work_name"] = - pgStatus_->lastCompletedWorkName; - data.strings["pg_name"] = pg_uid_; - data.strings["pg_desc"] = pg_desc_; - logger->log(data); - lastStatusUpdateTime = std::chrono::steady_clock::now(); - } + class CUDAEventCache { + public: + CUDAEventCache(); + std::shared_ptr create(bool timing); + static CUDAEventCache& get(); + + private: + std::mutex cacheMutex_; + // NOTE: We intentionaly store raw pointers so that + // we do not attempt to destroy the event objects on process exit, + // because cuda may be gone. + std::vector + eventsArray_[2]; // 0 for timing=false, 1 for timing=true + }; - for (auto it = workMetaList_.begin(); it != workMetaList_.end(); - /* no increment */) { - auto& work = *it; - // When terminateProcessGroup_ is true, communicators have already been - // aborted, So cannot check exception based on them. But watchdog needs to - // finish the check for the works that have already been enqueued to - // workMetaList_ - if (!terminateProcessGroup_.load()) { - work.checkAndSetException(); - } - bool timedOut = work.checkTimeout(); - - // If work hits an exception (either an error or timeout) - if (work.exception()) { - // log as soon as exception is detected - LOG(ERROR) << c10::str( - logPrefix(), - "Exception (either an error or timeout) detected by watchdog at work: ", - work.seq_, - ", last enqueued NCCL work: ", - pgStatus_->lastEnqueuedSeq, - ", last completed NCCL work: ", - pgStatus_->lastCompletedSeq, - "."); - // try to notify other ranks via global TCPStore to dump the flight - // recorder when a collective timeout or exception happens. Flight - // recorder behavior is independent of desync Debug. - if (dumpOnTimeoutOrEx_) { - try { - auto rank = globalRank(); - auto vec = std::vector( - reinterpret_cast(&rank), - reinterpret_cast(&rank) + sizeof(rank)); - globalStore_->set(std::string(EXCEPTION_DUMP), vec); - if (!shouldDump_.load()) { - LOG(ERROR) - << logPrefix() - << "Broadcasting flight-recorder dump signal to other processes via TCPStore."; - } - // signal the monitor thread on PG0 to start dumping - shouldDump_.store(true); - if (sleepAfterException_) { - // This sleep is used to give time for dumping before throwing - // exception - std::this_thread::sleep_for( - std::chrono::seconds(heartbeatTimeoutInSec_)); - LOG(INFO) << logPrefix() << "slept for " << heartbeatTimeoutInSec_ - << " giving time for flight recorder dumps to finish."; - } - } catch (const std::exception& e) { - LOG(ERROR) << logPrefix() - << "Failed to set dump signal in tcpstore. " - << "Error: " << e.what(); - } - } - - if (SHOULD_CLEAN_UP(asyncErrorHandling_)) { - // Abort work and corresponding communicators - work.abort(); - // PG level abort, which would abort all other communicators on this - // rank - abort(); - } - - // Report desync state in case of timeout - if (timedOut) { - LOG(ERROR) << c10::str( - logPrefix(), - "Timeout at NCCL work: ", - work.seq_, - ", last enqueued NCCL work: ", - pgStatus_->lastEnqueuedSeq, - ", last completed NCCL work: ", - pgStatus_->lastCompletedSeq, - "."); - if (desyncDebug_) { - try { - collectiveDebugInfoMode_.store(true); - auto desyncMsg = getNCCLWatchdogDebugInfo(); - LOG(ERROR) << logPrefix() << desyncMsg; - } catch (const std::exception& e) { - LOG(ERROR) - << logPrefix() - << "Failed to retrieve TORCH_NCCL_DESYNC_DEBUG report. " - << " Please file an issue. Error: " << e.what(); - } catch (...) { - LOG(ERROR) - << logPrefix() - << "Failed to rerieve TORCH_NCCL_DESYNC_DEBUG report with unknown error." - << " Please file an issue."; - } - } - } - // Throw exception - work.handleException(asyncErrorHandling_); - } - - // Work status logging for desync debug - if (desyncDebug_) { - if (work.isStarted()) { - logWorkStart(work); - } - if (work.isCompleted()) { - logWorkEnd(work); - } - } - - // a work could be started but not completed, so we should not update - // lastStartedSeq and lastStartedOpName if the work state is checked - // multiple times after the start - if (pgStatus_->lastStartedSeq < static_cast(work.seq_) && - work.isStarted()) { - pgStatus_->lastStartedSeq = work.seq_; - pgStatus_->lastStartedWorkName = opTypeToString(work.opType_); - } - - // Clean up completed work - if (work.isCompleted()) { - { - // Reset the timeout and first work if the work is completed. - std::lock_guard timeoutLock(mtxTimeoutExtension_); - if (work.ownedEphermeralTimeout_.count() > 0) { - ephemeralTimeoutActive_ -= work.ownedEphermeralTimeout_; - ephemeralTimeoutInflight_ -= work.ownedEphermeralTimeout_; - } - } - pgStatus_->lastCompletedSeq = work.seq_; - pgStatus_->lastCompletedWorkName = opTypeToString(work.opType_); - pgStatus_->lastCompletedNumelIn = work.numelIn_; - pgStatus_->lastCompletedNumelOut = work.numelOut_; - NCCLTraceBuffer::get()->retire_id(work.trace_id_, true); - if (onCompletionHook_) { - // Move Work object to completedWorkList_ to be consumed by the hook - // thread - { - const std::lock_guard lock(completedWorkListMutex_); - completedWorkList_.splice( - completedWorkList_.end(), workMetaList_, it++); - } - completedWorkListCV_.notify_one(); - } else { - it = workMetaList_.erase(it); - lastWorkListUpdateTime_ = std::chrono::steady_clock::now(); - } - at::cuda::CUDAGraph::dec_pending_event_queries(); - } else { - // Increment the iterator if the current WorkNCCL object is not - // completed. - ++it; - } - // Increment heartbeat after each work processed, - // in case processing is slowed down (but not hung) by cuda api contention - heartbeat_++; - } - done = workMetaList_.empty(); - } -} - -void ProcessGroupNCCL::runHookLoop() { - c10::setThreadName("pt_nccl_runhook"); - - bool done = false; - while (!done || !terminateProcessGroup_.load()) { - std::unique_lock lock(completedWorkListMutex_); - // We busy-poll the work vector every kWatchdogThreadSleepMillis - // milliseconds as long as the atomic is True. - completedWorkListCV_.wait_for( - lock, - std::chrono::milliseconds(kWatchdogThreadSleepMillis), - [&]() -> bool { - return !completedWorkList_.empty() || terminateProcessGroup_.load(); - }); - - try { - for (auto it = completedWorkList_.begin(); it != completedWorkList_.end(); - /* no increment */) { - const WorkNCCL& work = *it; - // Hook might grab GIL, unlock first to prevent deadlock - lock.unlock(); - - auto timeStarted = - std::chrono::system_clock::now() + - std::chrono::duration_cast( - work.workStartTime_ - std::chrono::steady_clock::now()); - onCompletionHook_(std::make_shared( - work.retrieveOpType(), // OpType - work.getSequencenumber(), // seq - timeStarted, // timeStarted - std::chrono::system_clock::now(), // timeFinished - std::chrono::duration( - work.getDuration()) // activeDuration - )); - - lock.lock(); - it = completedWorkList_.erase(it); - } - } catch (std::exception& e) { - if (std::string(e.what()).find("driver shutting down") != - std::string::npos) { - LOG(INFO) - << logPrefix() - << "main process destroyed cuda before runHookLoop exited, terminating runHookLoop." - << " (runHookLoop caught exception: " << e.what(); - - } else { - // PythonOnCompletionHook has already extracted Python exception message - // and wrapped it with a cpp one. So we no longer need to acquire GIL - // here. - const auto errorStr = c10::str( - "Caught exception on rank ", - rank_, - " while running onCompletion hook for ProcessGroupNCCL: ", - e.what(), - ". Aborting all communicators."); - - // No need to call abort() on WorkNCCL here as that collective has - // already finished successfully at this point. We just need to abort - // the process Abort all NCCL Communicators on this ProcessGroupNCCL - // instance. - abort(errorStr); - } + struct Options : Backend::Options { + // NOTE: timeout in ProcessGroupNCCL::Options denote the timeout for + // operations. This is only used when blockingWait_ is enabled. + explicit Options(bool is_high_priority_stream = false); + + // return intrusive_ptr of the object + static c10::intrusive_ptr create( + bool is_high_priority_stream = false) { + return c10::make_intrusive(is_high_priority_stream); } - // Lock is still acquired at this point - done = completedWorkList_.empty(); - } -} - -std::exception_ptr ProcessGroupNCCL::WorkNCCL::checkForNCCLErrors() { - return checkForNCCLErrorsInternal(ncclComm_); -} - -std::exception_ptr ProcessGroupNCCL::checkForNCCLErrors( - std::shared_ptr& ncclComm) { - return checkForNCCLErrorsInternal(ncclComm); -} - -std::exception_ptr ProcessGroupNCCL::checkForNCCLErrorsInternal( - std::shared_ptr& ncclComm) { - // Prioritize commFailureReason over checkForNcclError() result if - // commFailureReason is set. - auto commFailureReason = ncclComm->getNcclCommFailureReason(); - if (commFailureReason != std::nullopt) { - return std::make_exception_ptr(C10_BUILD_ERROR( - DistBackendError, - c10::str( - "NCCL communicator encountered error set by ProcessGroupNCCL: ", - *commFailureReason))); - } - ncclResult_t ncclAsyncErr = ncclComm->checkForNcclError(); - // When nonblocking mode is enabled by TORCH_NCCL_USE_COMM_NONBLOCKING, - // ncclInProgress could be returned when there are pending NCCL calls. - // In this case, no exception should be thrown + // Schedule NCCL operations on high priority CUDA streams + bool is_high_priority_stream; + #ifdef NCCL_HAS_COMM_NONBLOCKING - // ncclInProgress is defined only if NCCL_HAS_COMM_NONBLOCKING is defined - if (ncclAsyncErr != ncclSuccess && ncclAsyncErr != ncclInProgress) { -#else - if (ncclAsyncErr != ncclSuccess) { + // Configure ranks + ncclConfig_t config = NCCL_CONFIG_INITIALIZER; #endif - return std::make_exception_ptr(C10_BUILD_ERROR( - DistBackendError, - "NCCL error: " + ncclGetErrorWithVersion(ncclAsyncErr) + "\n" + - getNcclErrorDetailStr(ncclAsyncErr))); - } - return nullptr; -} + // Optional "parent" backend and color to create communicators from + // via `ncclCommSplit` + std::shared_ptr split_from; + int64_t split_color{0}; + std::vector global_ranks_in_group; + std::string group_name; + }; -void ProcessGroupNCCL::broadcastUniqueNCCLID( - ncclUniqueId* ncclID, - bool isSingleP2POp, - const std::string& p2pKey, - int p2pRank) { - // For collective operations: - // For every NCCL communicator that we create we need to broadcast - // a unique ID from rank 0 to all other ranks. This broadcast is - // done by rank 0 setting a key in the store and all other ranks - // retrieving the contents of that key. A single process group - // may create multiple NCCL communicators, so we use a sequence - // number to differentiate between them. - // For single point-to-point operations: - // The sequence number will only be increased on 2 out of all the - // processes in a Process Group. So all following collective - // operations will see different sequence numbers which will cause - // runtime errors. To avoid that, use the src:target pair instead - // of sequence number for p2p communications. - - std::string storeKey; - if (!isSingleP2POp) { - storeKey = std::to_string(ncclCommCounter_++); - } else { - storeKey = p2pKey; - } - if (rank_ == 0 || (isSingleP2POp && p2pRank == 0)) { - auto vec = std::vector( - reinterpret_cast(ncclID), - reinterpret_cast(ncclID) + NCCL_UNIQUE_ID_BYTES); - store_->set(storeKey, vec); - } else { - try { - auto vec = store_->get(storeKey); - TORCH_CHECK_WITH( - DistBackendError, - vec.size() == NCCL_UNIQUE_ID_BYTES, - "Invalid size for ncclUniqueId"); - std::memcpy(ncclID, vec.data(), vec.size()); - } catch (const std::exception& e) { - std::string exceptionMsg = c10::str( - "[", - rank_, - "] is setting up NCCL communicator and " - "retrieving ncclUniqueId from [0] via c10d key-value store by key '", - storeKey, - "', but store->get('", - storeKey, - "') got error: "); - C10_THROW_ERROR( - DistBackendError, - exceptionMsg + e.what() + - ". This may indicate a possible application crash on rank 0 or a network set up issue."); - } catch (...) { - C10_THROW_ERROR( - DistBackendError, - c10::str( - "Unknown exception while [", - rank_, - "] is setting up NCCL communicator and " - "retrieving ncclUniqueId from [0] via c10d key-value store by key '", - storeKey, - "'", - ". This may indicate a possible application crash on rank 0 or a network set up issue.")); - } - } -} - -void ProcessGroupNCCL::destroyNCCLComms(const std::string& devNCCLCommMapKey) { - std::lock_guard lock(mutex_); - if (devNCCLCommMap_.find(devNCCLCommMapKey) == devNCCLCommMap_.end()) { - TORCH_INTERNAL_ASSERT( - false, - "Expected to find key ", - devNCCLCommMapKey, - " in NCCL communicator map."); - } - std::shared_ptr& ncclComm = devNCCLCommMap_[devNCCLCommMapKey]; - // ncclCommDestroy(comm->getNcclComm()) results in segfault when PG is being - // destroyed, so using ncclCommAbort here. - ncclComm->ncclCommAbort(); - // Remove communicators from the cache. - devNCCLCommMap_.erase(devNCCLCommMapKey); - // Clear used device indices. - usedDeviceIdxs_.clear(); - - ncclCommDevIdxMapMutex.lock(); - ncclCommDevIdxMap.erase(ncclComm); - ncclCommDevIdxMapMutex.unlock(); -} - -std::shared_ptr ProcessGroupNCCL::getNCCLComm( - const std::string& deviceKey, - at::Device& device, - OpType opType, - int p2pRank, - bool isSendRecvSelf) { - // Sanity check - if (deviceKey.empty()) { - C10_THROW_ERROR( - DistBackendError, - "Not able to create/get the NCCL Communicator since " - "the GPU devices are not known"); - } - if (bound_device_id_) { - if (*bound_device_id_ != device) { - LOG(ERROR) << logPrefix() << "Tensor found on device " << device - << " but backend constrained to " << *bound_device_id_; - C10_THROW_ERROR( - DistBackendError, - "Attempt to perform collective on tensor not on device passed to init_process_group"); - } - } + // If you wish to create multiple process groups, each with a potentially + // different rank and size, you can do so by passing a new store instance + // to each one. If you have only a single store object, you can + // use the `c10d::PrefixStore` to derive scoped instances. + // This is also what the Python API in torch.distributed does. + // + // The process group instance keeps a reference to the store because + // it may be used long after the constructor runs. In fact, the constructor + // doesn't create any NCCL communicators. A single NCCL communicator can + // only be used on a specific set of devices, and are therefore created + // on-demand when a collective runs. If another collective is executed later, + // against a different set of devices, the process group creates another NCCL + // communicator. These NCCL communicators are cached and reused if possible. + // + ProcessGroupNCCL( + const c10::intrusive_ptr& store, + int rank, + int size, + c10::intrusive_ptr options = Options::create()); - usedDeviceIdxs_.insert(device.index()); + // This constructor includes the deprecated `groupName` argument. + // If you have existing code that uses the `groupName`, you can replace + // it by specifying a `c10d::PrefixStore(groupName, store)` for store. + C10_DEPRECATED ProcessGroupNCCL( + const c10::intrusive_ptr& store, + int rank, + int size, + const std::string& groupName, + c10::intrusive_ptr options = Options::create()) + : ProcessGroupNCCL(store, rank, size, options) {} - { - std::lock_guard lock(mutex_); - if (devNCCLCommMap_.find(deviceKey) != devNCCLCommMap_.end()) { - // Reuse the cached communicator if there is one. - return devNCCLCommMap_[deviceKey]; - } + ~ProcessGroupNCCL() override; + + // This function returns a local uid for ProcessGroupNCCL. + uint64_t getUid() { + return static_cast(local_id_); } - // NCCL communicator not cached, create a new entry - std::shared_ptr ncclComm; + c10::intrusive_ptr getOptions() { + return options_; + } - // Create the unique NCCL ID and broadcast it - ncclUniqueId ncclID; + const std::string getBackendName() const override { + return std::string(NCCL_BACKEND_NAME); + } - // reset log prefix to include group_desc - logPrefix_ = createLogPrefix(); + bool supportsSplitting() const override { + return true; + } -#ifdef NCCL_COMM_DESCRIPTION - // Pass process group name and description to NCCL communicator - std::string commDesc = pg_desc_ + ':' + pg_uid_; - options_->config.commDesc = strdup(commDesc.c_str()); -#endif + void startCoalescing() override; - // For batch_isend_irecv, ncclGroupStart() would be called upfront - bool batchP2P = ncclActiveGroupCounter_ > 0; - bool singleP2POp = isP2POp(opType, batchP2P); - - // Get the device index - auto deviceIndex = device.index(); - at::cuda::OptionalCUDAGuard gpuGuard(device); - - // [Group Start/End Note] This is used to ensure that nccl communicator will - // be created before communication primitives are called. Let's look at this - // example: Using the batch_isend_irecv to send a tensor to a target process. - // On the sender side, the corresponding underlying NCCL calls will look like - // ncclGroupStart() // This is in batch_isend_irecv - // ncclCommInitRank() // Inside NCCLComm::create - // ncclSend() - // ncclGroupEnd() // This is in batch_isend_irecv - // With this pattern, the nccl communicator will be created in the last - // ncclGroupEnd which means when ncclSend is processed, the passed - // communicator argument is NULL which will lead to runtime error. So we need - // to "close" all active nccl groups to ensure nccl communicator is actually - // created before encountering any communication calls. This is why we need - // the following for loop. - for (const auto i : c10::irange(ncclActiveGroupCounter_)) { - (void)i; - // comms have not been initiated yet, so can only check in blocking-way - C10D_NCCL_CHECK(ncclGroupEnd(), std::nullopt); - } + c10::intrusive_ptr endCoalescing() override; - // GPU world size and GPU rank - int numRanks, rank; - - if (!singleP2POp) { - // Collective, all-to-all, or batch P2P - numRanks = getSize(); - rank = getRank(); - } else if (isSendRecvSelf) { - // Same process send and recv. - numRanks = 1; - rank = 0; - } else { - // For single point-to-point operation, there are only 2 processes - // involved so the GPU rank is either 0 or 1. - numRanks = 2; - rank = p2pRank; - } + // For specifying a composite optype, such as ALLGATHER and REDUCE_SCATTER + c10::intrusive_ptr endCoalescing(OpType optype); -#ifdef NCCL_HAS_COMM_SPLIT - if (options_->split_from) { - TORCH_CHECK( - options_->split_color != 0, - "Must specify a non-zero color when splitting"); - // Find a valid, healthy communicator to split from if possible. - std::lock_guard lock(options_->split_from->mutex_); - auto& other_comms = options_->split_from->devNCCLCommMap_; - auto dit = other_comms.find(getKeyFromDevice(device)); - if (dit != other_comms.end()) { - auto& parentComm = dit->second; - if (parentComm != nullptr && !parentComm->isAborted()) { - ncclComm = NCCLComm::split( - parentComm.get(), - options_->split_color, - rank, - options_->config, - options_->global_ranks_in_group); - } - } - } -#endif + c10::intrusive_ptr broadcast( + std::vector& tensors, + const BroadcastOptions& opts = BroadcastOptions()) override; - // To simplify conditional nesting, just create the ncclComms[i] - // entry if it hasn't been yet rather than untangling the - // conditions that might have resulted in a split above. - if (!ncclComm) { - if (getCvarBool(TORCH_NCCL_BCAST_UNIQUEID, true) && !isSendRecvSelf) { - // For point-to-point communication, lower rank of the two will get unique - // id. - if (rank_ == 0 || (singleP2POp && p2pRank == 0)) { - C10D_NCCL_CHECK(ncclGetUniqueId(&ncclID), std::nullopt); - } - - // Broadcast so that each process can have a unique NCCL ID - auto timeStarted = std::chrono::steady_clock::now(); - broadcastUniqueNCCLID(&ncclID, singleP2POp, deviceKey, p2pRank); - auto timerDeltaMs = - std::chrono::duration_cast>( - std::chrono::steady_clock::now() - timeStarted) - .count() * - 1000; - LOG(INFO) << logPrefix() - << "ProcessGroupNCCL broadcast unique ID through store took " - << timerDeltaMs << " ms"; - } + c10::intrusive_ptr _broadcast_oop( + at::Tensor& outputTensors, + at::Tensor& inputTensors, + const BroadcastOptions& opts = BroadcastOptions()); -#ifdef NCCL_HAS_COMM_NONBLOCKING - ncclComm = NCCLComm::create(numRanks, rank, ncclID, options_->config); -#else - ncclComm = NCCLComm::create(numRanks, rank, ncclID); -#endif - } + c10::intrusive_ptr allreduce_sparse( + std::vector& tensors, + const AllreduceOptions& opts = AllreduceOptions()) override; - // Creates the NCCL streams - bool force_high = getCvarBool(TORCH_NCCL_HIGH_PRIORITY, false); - auto streamVal = at::cuda::getStreamFromPool( - options_->is_high_priority_stream || force_high); + c10::intrusive_ptr allreduce( + std::vector& tensors, + const AllreduceOptions& opts = AllreduceOptions()) override; - { - std::lock_guard lock(mutex_); - inInitializationCommMap_.emplace(deviceKey, ncclComm); - } + c10::intrusive_ptr allreduce_coalesced( + std::vector& tensors, + const AllreduceCoalescedOptions& opts = + AllreduceCoalescedOptions()) override; - NCCLTraceBuffer::get()->record_pg_ranks( - std::make_tuple(pg_uid_, pg_desc_), groupRanks()); - - RECORD_PARAM_COMMS( - 0, // seq - std::make_tuple(pg_uid_, pg_desc_), // PG name tuple - rank, // rank - "init", // collective name - 0, // inNelems - 0, // outNelems - at::kByte, // dType - std::vector(), // inSplitSizes - std::vector(), // outSplitSizes - globalRankStart, // globalRankStart - globalRankStride, // globalRankStride - size_); // worldSize - - LOG(INFO) << logPrefix() << "ProcessGroupNCCL created ncclComm_ " - << ncclComm->ncclComm_ << " on CUDA device: " << deviceIndex; - - // At this point NCCL should have been initialized, hence we can accurately - // get the env value even if NCCL sets it by reading from nccl.conf file - LOG(INFO) << logPrefix() - << "NCCL_DEBUG: " << getCvarString({"NCCL_DEBUG"}, "N/A"); - - // See [Group Start/End Note] - for (const auto i : c10::irange(ncclActiveGroupCounter_)) { - (void)i; - C10D_NCCL_CHECK(ncclGroupStart(), std::nullopt); - } + c10::intrusive_ptr reduce( + std::vector& tensors, + const ReduceOptions& opts = ReduceOptions()) override; - ncclStreams_.emplace(deviceKey, std::move(streamVal)); - - // Note: these events are created with the (default) cudaEventDisableTiming - // flag This flag provides the best performance when used with - // cudaStreamWaitEvent() and cudaEventQuery(). Since we here don't measure the - // performance using cudaEvent, this should be set. - // TODO(kwen2501): is ncclEvents_ used anywhere else? - ncclEvents_.emplace(deviceKey, at::cuda::CUDAEvent(cudaEventDisableTiming)); - - // Move the NCCL resource to cache - auto it = inInitializationCommMap_.find(deviceKey); - // A previous thread could've already removed devicesKey from - // inInitializationCommMap_ and added it to devNCCLCommMap_ - if (it != inInitializationCommMap_.end()) { - devNCCLCommMap_.emplace(deviceKey, std::move(it->second)); - inInitializationCommMap_.erase(deviceKey); - - // Now ncclComms are fully initialized. - // Register all active CUDA memory segments in cache allocator to - // the new NCCL communicators - if (useTensorRegisterAllocatorHook_) { - auto snapshot = c10::cuda::CUDACachingAllocator::snapshot(); - // Register the segment to a new NCCL communicator if on the same device - for (const auto& segmentInfo : snapshot.segments) { - TORCH_INTERNAL_ASSERT( - segmentInfo.device == device.index(), - "Mismatch between CUDA memory segment device and current device"); - ncclComm->registerSegment( - reinterpret_cast(segmentInfo.address), - segmentInfo.total_size); - } - } - // Record the mapping between ncclComm and device index so that later - // register hook can register a newly allocated segment to communicators - // on the same device. - // NOTE: we need remove the communicator from this map when it is - // destroyed, otherwise may register onto an invalid communicator. - ncclCommDevIdxMapMutex.lock(); - ncclCommDevIdxMap.emplace(ncclComm, device.index()); - ncclCommDevIdxMapMutex.unlock(); - } + c10::intrusive_ptr _reduce_oop( + at::Tensor& outputTensors, + at::Tensor& inputTensors, + const ReduceOptions& opts = ReduceOptions()); - it = devNCCLCommMap_.find(deviceKey); - TORCH_INTERNAL_ASSERT( - it != devNCCLCommMap_.end(), "Communicators not populated in cache!"); + c10::intrusive_ptr allgather( + std::vector>& outputTensors, + std::vector& inputTensors, + const AllgatherOptions& opts = AllgatherOptions()) override; - return it->second; -} + c10::intrusive_ptr _allgather_base( + at::Tensor& outputbuffer, + at::Tensor& inputbuffer, + const AllgatherOptions& opts = AllgatherOptions()) override; -uint64_t ProcessGroupNCCL::getCommSplitCounter() const { - uint64_t ret = 0; - for (const auto& i : devNCCLCommMap_) { - auto& ncclComm = i.second; - ret += ncclComm->getCommSplitCounter(); - } - return ret; -} - -namespace { - -// Check validity of tensor -void check_gpu_single_tensor( - const at::Tensor& tensor, - const bool p2p = false // whether operation is a P2P operation -) { - if (!tensor.is_cuda() || tensor.is_sparse()) { - C10_THROW_ERROR(ValueError, "Tensors must be CUDA and dense"); - } - // Skip the following requirements for P2P operations - if (!tensor.is_contiguous(tensor.suggest_memory_format())) { - if (p2p) { - TORCH_WARN_ONCE( - "Detected non-contiguous tensor in P2P operations. It is user " - "responsibility to guarantee that source and destination tensors have " - "the same contiguity format."); - } else { - C10_THROW_ERROR(ValueError, "Tensors must be contiguous"); - } - } -} - -// Checks that all `tensors' have the same type and shape and reside on the same -// GPU. -// TODO: test_c10d_nccl.py should consider adding tests for the error conditions -// here, ie, that deliberately pass invalid tensors and check the right -// exception is thrown. The "Expected list of tensors on the same device" -// condition may be a challenge because the test would need to pass tensors on -// different devices in the same process. -int64_t check_gpu_tensors_same_device(const std::vector& tensors) { - if (tensors.size() == 0) { - C10_THROW_ERROR(ValueError, "Tensor list must be nonempty"); - } + c10::intrusive_ptr allgather_coalesced( + std::vector>& outputTensorLists, + std::vector& inputTensors, + const AllgatherOptions& opts = AllgatherOptions()) override; - const auto& first = tensors.front(); + c10::intrusive_ptr allgather_into_tensor_coalesced( + std::vector& outputs, + std::vector& inputs, + const AllgatherOptions& opts = AllgatherOptions()) override; - int64_t total_numel = 0; - for (const auto& t : tensors) { - if (!t.is_cuda() || t.is_sparse()) { - C10_THROW_ERROR(ValueError, "Tensors must be CUDA and dense"); - } - if (t.scalar_type() != first.scalar_type()) { - C10_THROW_ERROR(TypeError, "Tensors must have identical type"); - } - if (!t.is_non_overlapping_and_dense()) { - C10_THROW_ERROR(ValueError, "Tensors must be non-overlapping and dense"); - } - // If we're in this function, the user called a _coalesced collective - // on a set of tensors with potentially different sizes and strides. - // Therefore, we don't check for matching sizes and strides, - // but we do double-check tensors are on the same device. - TORCH_CHECK_WITH( - ValueError, - t.get_device() == tensors[0].get_device(), - "Expected list of tensors on the same device"); - total_numel += t.numel(); - } + c10::intrusive_ptr reduce_scatter( + std::vector& outputTensors, + std::vector>& inputTensors, + const ReduceScatterOptions& opts = ReduceScatterOptions()) override; - return total_numel; -} + c10::intrusive_ptr _reduce_scatter_base( + at::Tensor& outputTensor, + at::Tensor& inputTensor, + const ReduceScatterOptions& opts = ReduceScatterOptions()) override; -bool check_same_size(const std::vector& input_tensors) { - for (const auto& input_tensor : input_tensors) { - if (!input_tensors[0].is_same_size(input_tensor)) { - return false; - } - } - return true; -} - -} // namespace - -c10::intrusive_ptr ProcessGroupNCCL::initWork( - at::Device& device, - int rank, - OpType opType, - const char* profilingTitle, - const std::vector& inputs, - const std::vector& outputs, // TODO(kwen2501): necessary? - bool record) { - auto r = c10::make_intrusive( - pg_uid_, - pg_desc_, - device, - rank, - opType, - seqCollective_, - profilingTitle, - profilingTitle != nullptr ? std::optional>(inputs) - : std::nullopt, - desyncDebug_, - enableTiming_.load(), - cudaEventCacheEnabled_.load(), - dist_debug_level_); - if (record) { - bool isP2P = isP2POp(opType); - // Ideally record every work that we enqueue, rather than every work we - // create. - // - at the time of this PR we do not currently enqueue every created work - // - but it is unsafe to steal refs to start/end cuda events from Works that - // may go out of scope before flight recorder has retired them, - // so we must ensure that any work that is initialized via initWork will - // be enqueued - // - initially, moved record() into workEnqueue(), but found that makes it - // hard to get access to profilingTitle, - // inputs, and outputs for metadata recording, and we don't want to attach - // these objects to the Work becuase it has implications for keeping those - // tensors alive longer and adds overhead when copying Work objects - // between threads - r->trace_id_ = NCCLTraceBuffer::get()->record( - local_id_, - std::make_tuple(pg_uid_, pg_desc_), - seqCollective_, - seqP2P_, - op_id_, - profilingTitle ? profilingTitle : "", - inputs, - outputs, - r->ncclStartEvent_.get(), - r->ncclEndEvent_.get(), - options_->timeout, - pgStatus_, - isP2P); - } - return r; -} - -// TODO(kwen2501): deprecate -std::vector ProcessGroupNCCL::WorkNCCL::result() { - return *outputs_; -} - -c10::intrusive_ptr ProcessGroupNCCL::WorkNCCL:: - getFuture() { - return future_; -} - -float ProcessGroupNCCL::WorkNCCL::getDuration() const { - TORCH_CHECK(timingEnabled_, "getDuration only works if timing was enabled"); - TORCH_CHECK( - ncclStartEvent_, - "getDuration only works if ncclStartEvents_ is populated, true if timing enabled"); - TORCH_CHECK( - ncclEndEvent_, - "getDuration only works if ncclEndEvents_ is populated, which should always be true"); - return ncclStartEvent_->elapsed_time(*ncclEndEvent_); -} - -uint64_t ProcessGroupNCCL::WorkNCCL::getSequencenumber() const { - return seq_; -} - -void ProcessGroupNCCL::assignTimeoutToWork( - const c10::intrusive_ptr& work, - const c10::intrusive_ptr& option) { - std::chrono::milliseconds timeout = option->timeout; - std::lock_guard timeoutLock(mtxTimeoutExtension_); - if (ephemeralTimeoutActive_.count() > 0) { - timeout += ephemeralTimeoutActive_; - } - work->opTimeout_ = timeout; - work->ownedEphermeralTimeout_ = - ephemeralTimeoutActive_ - ephemeralTimeoutInflight_; - ephemeralTimeoutInflight_ = ephemeralTimeoutActive_; -} - -void ProcessGroupNCCL::workEnqueue( - c10::intrusive_ptr work) { - if (!terminateProcessGroup_.load()) { - std::lock_guard lock(workMetaListMutex_); - // Avoid view tensors to be processed in cleanup thread. - // View tensors' destruction invokes autograd_meta, which - // needs to be destructed in user thread. Otherwise will - // get deadlock. Here we enqueue work without outputs_. - workMetaList_.emplace_back(*work); - // update the PG status related to the last enqueued work - pgStatus_->lastEnqueuedSeq = work->seq_; - pgStatus_->lastEnqueuedWorkName = opTypeToString(work->opType_); - pgStatus_->lastEnqueuedNumelIn = work->numelIn_; - pgStatus_->lastEnqueuedNumelOut = work->numelOut_; - lastWorkListUpdateTime_ = std::chrono::steady_clock::now(); - } -} - -ProcessGroupNCCL::Options::Options(bool is_high_priority_stream) - : Backend::Options(NCCL_BACKEND_NAME, kProcessGroupNCCLDefaultTimeout), - is_high_priority_stream(is_high_priority_stream) {} - -static constexpr int CoalActive = 0x01, CoalColl = 0x02, CoalP2P = 0x04; - -void ProcessGroupNCCL::startCoalescing() { - // Other collective ops bump seq_ before creating a work. Thus, if coalesced - // ops bump seq_ only after initing a work they will collide with (reuse) the - // seq_ of the last non-coalesced collective. Previously, seq_ was bumped - // inside endCoalescing, but before initWork. Since we now record individual - // ops from a coalesce group into the flight recorder, we want to have the - // same seq_ for those ops and its 'endCoalescing' op. Hence we bump during - // start, which has one minor downside- we burn a seq_ if someone ever does a - // 'start' and 'end' coalescing region without doing an operation inbetween. - - // Don't bump op_id_ here, because startCoalescing isn't a logical operation. - // Bump it for each logical op inside the coalescing group. - if (coalescing_state_ & CoalP2P) { - seqP2P_++; - } else { - seqCollective_++; - } + c10::intrusive_ptr reduce_scatter_tensor_coalesced( + std::vector& outputs, + std::vector& inputs, + const ReduceScatterOptions& opts = ReduceScatterOptions()) override; + + c10::intrusive_ptr barrier( + const BarrierOptions& opts = BarrierOptions()) override; + + c10::intrusive_ptr alltoall_base( + at::Tensor& outputTensor, + at::Tensor& inputTensor, + std::vector& outputSplitSizes, + std::vector& inputSplitSizes, + const AllToAllOptions& opts = AllToAllOptions()) override; + + c10::intrusive_ptr alltoall( + std::vector& outputTensors, + std::vector& inputTensors, + const AllToAllOptions& opts = AllToAllOptions()) override; + + c10::intrusive_ptr send( + std::vector& tensors, + int dstRank, + int tag) override; + + c10::intrusive_ptr recv( + std::vector& tensors, + int srcRank, + int tag) override; + + void groupStart(); + + void groupEnd(); + + void groupEndNonblocking(std::shared_ptr comm); + + c10::intrusive_ptr gather( + std::vector>& outputTensors, + std::vector& inputTensors, + const GatherOptions& opts = GatherOptions()) override; + + c10::intrusive_ptr scatter( + std::vector& outputTensors, + std::vector>& inputTensors, + const ScatterOptions& opts = ScatterOptions()) override; + + // Unsupported Ops + c10::intrusive_ptr recvAnysource( + std::vector& tensors, + int tag) override; + + // Agrees on an initial sequence number for the whole group by having rank 0 + // create it and broadcast it to other ranks using the store. + void setSequenceNumberForGroup() override; + + // Retrieves the current sequence number for the whole group, which should be + // in sync. If the returned number is not consistent across the group, it + // may indicate that there is some sort of collective desynchronization. + uint64_t getSequenceNumberForGroup() override; + + // Return the total number of splits the communicators held by this process + // group have performed. Counts ncclCommCreateFromRanks() for ncclx v2.21.5+ + uint64_t getCommSplitCounter() const; + + void registerOnCompletionHook( + std::function)>&& hook) override; + void waitForPendingWorks() override; + + void enableCollectivesTiming() override; + + // Helper function for iteratively aborting communicators in the provided map + void abortCommsFromMap( + std::unordered_map>& ncclCommsMap, + std::optional abortReason); + + c10::intrusive_ptr initIntraNodeComm(); + + // Provides an API to abort the ProcessGroup (similar to ncclCommAbort) + // instead of relying on ProcessGroupNCCL destructor. + // return true if abort is successful, otherwise false + bool abort(std::optional abortReason = std::nullopt); + + void shutdown(std::optional reason = std::nullopt); + + void eagerConnectSingleDevice(at::Device device) override; + + void performNocolorSplit(at::Device device); + + // This method adds a temporary extension for the timeout period, + // applying to all collectives between the calling of this API and + // the completion of the first collective on the GPU. While this feature + // provides flexibility in specific scenarios, it introduces statefulness + // to timeout setting. Therefore, it is advisable to use this API sparingly + // and consider alternative approaches, such as directly setting the timeout + // or utilizing a barrier collective (one can set any timeout to the barrier), + // whenever feasible. + void addEphemeralTimeout(const std::chrono::milliseconds& timeout); + + // This function is only intended for testing purposes because we don't + // want to expose the `WorkNCCL` via pybind. It verifies whether the + // `opTimeout_` of the provided WorkNCCL instance is the same as the specified + // timeout. + bool verifyWorkTimeoutForTest( + const c10::intrusive_ptr work, + const std::chrono::milliseconds& timeout); + + protected: + // Helper that broadcasts nccl unique ID to all ranks through the store + void broadcastUniqueNCCLID( + ncclUniqueId* ncclID, + bool isSingleP2POp, + const std::string& devicesKey, + int p2pRank); + + // Helper that either looks up the cached NCCL communicators or creates + // a new set of NCCL communicators as a cache entry + std::shared_ptr getNCCLComm( + const std::string& deviceKey, + at::Device& device, + OpType opType, + int p2pRank = 0, + bool isSendRecvSelf = false); + + // Wrapper method which can be overridden for tests. + virtual std::exception_ptr checkForNCCLErrors( + std::shared_ptr& ncclComm); + + // Ensure thaht if record is True, the work obj will be enqueued via + // workEnqueue + virtual c10::intrusive_ptr initWork( + at::Device& device, + int rank, + OpType opType, + const char* profilingTitle = nullptr, + const std::vector& inputs = {}, + const std::vector& outputs = {}, + bool record = false); + + // In the timeout case and we will dump debug info such as the NCCL flight + // recorder to storage. Down the road, if we have more complicated or blocking + // operations, we might need to use a side thread to do it. + bool dumpDebuggingInfo(); + + private: + int globalRankStart; + int globalRankStride; + + // Helper that encapsulates work shared across all collective communication + // primitives. The callbacks have the following signatures: + // + // ncclResult_t fn(at::Tensor& input, at::Tensor& output, + // ncclComm_t, at::cuda::CUDAStream&); + // void {pre,post}(std::vector); + template + c10::intrusive_ptr collective( + at::Tensor& input, + at::Tensor& output, + Fn fn, + OpType opType, + const char* profilingTitle = nullptr, + bool avoidRecordStreams = false, + bool nanCheck = true); + + template + c10::intrusive_ptr collective( + at::Tensor& input, + at::Tensor& output, + Fn fn, + PreProcess pre, + PostProcess post, + OpType opType, + const char* profilingTitle = nullptr, + bool avoidRecordStreams = false, + bool nanCheck = true); + + template + c10::intrusive_ptr collective( + std::vector& inputs, + std::vector& outputs, + Fn fn, + PreProcess pre, + PostProcess post, + OpType opType, + const char* profilingTitle = nullptr, + bool avoidRecordStreams = false, + bool nanCheck = true); + + template + c10::intrusive_ptr collectiveCoalesced( + std::vector& input, + std::vector& output, + Fn fn, + OpType opType, + const char* profilingTitle = nullptr, + bool avoidRecordStreams = false); + + // Helper that encapsulates work shared across point-to-point communication + // primitives. It is the same structure as the helper used for collective + // communication primitives. + template + c10::intrusive_ptr pointToPoint( + at::Tensor& tensor, + Fn fn, + int peer, + OpType opType, + const char* profilingTitle = nullptr); + + template + c10::intrusive_ptr pointToPoint( + at::Tensor& tensor, + Fn fn, + int peer, + OpType opType, + PreProcess pre, + PostProcess post, + const char* profilingTitle); + + c10::intrusive_ptr allreduce_impl( + at::Tensor& tensor, + const AllreduceOptions& opts = AllreduceOptions()); + + // Checks for NCCL errors on each of the communicators and returns an + // appropriate exception_ptr (nullptr if no errors). + static std::exception_ptr checkForNCCLErrorsInternal( + std::shared_ptr& ncclComm); + + // Function that runs as part of a separate thread and checks for errors on + // NCCL communicators. We need a separate thread to check for NCCL errors + // since we can't rely on the user calling certain methods like wait(), + // isCompleted() etc. to detect and remediate errors. In addition to this, we + // need a mechanism to safely abort and remove NCCL communicators from our + // cache. This can be done cleanly by having a thread for the ProcessGroupNCCL + // class. Attempting to modify the communicator cache from the WorkNCCL class + // might run into issues with object lifetime since the ProcessGroupNCCL + // object might get destroyed before the WorkNCCL object. + void ncclCommWatchdog(); + + // Return the CUDA device most likely associated with this backend. + // If we aren't bound to a specific device, there is no strict + // guarantee that this heuristic is the correct assignment of ranks + // to GPUs that Python layers use, but in practice it tends to be. + // Fortunately we don't rely on this for correctness of any tensor + // operations, just for ancillary uses like barriers. + at::Device guessDeviceForRank() const; + + // Destroys initialized NCCL communicators in devNCCLComMap_ given by input + // key. Throws if there are no communicators to destroy. Also removes + // communicators from the cache and clears used device indices. + void destroyNCCLComms(const std::string& devNCCLCommMapKey); + + // Watchdog's inside loop. + // Takes care of cleaning up completed work, and aborting upon failure or + // timeout. + void watchdogHandler(); + + void runHookLoop(); + + // Desync debug helper + void logWorkStart(WorkNCCL& work); + + // Desync debug helper + void logWorkEnd(WorkNCCL& work); + + // Generates a prefix that is unique to this process group and rank, for + // disambiguating logs + std::string createLogPrefix() const; + + // Returns the unique prefix created in createLogPrefix + const std::string& logPrefix() const; + + // Returns the global rank of the device. This function assumes that users + // always create a default global process group(PG) which includes all + // devices. It is called in the constructor of ProcessGroupNCCL, so it always + // return the rank_ of the the very first PG created, aka, default global PG. + const int& globalRank() const; + + // Returns the global ranks of a PG. + const std::vector& groupRanks() const; + + // Util function to assign timeout to each work. + void assignTimeoutToWork( + const c10::intrusive_ptr& work, + const c10::intrusive_ptr& option); + + protected: + // Function that runs as part of a separate thread aside from watchdog + // thread because we need to check the heartbeat from watchdog thread + // so that when we get stuck in some NCCL/CUDA calls, + // we can dump the debugging information and abort the process. + virtual void heartbeatMonitor(); + + // Function that directly trigger std::abort so that the whole process + // gets terminated. + virtual void terminateProcess(std::string errMsg); + + // A helper function to wait for a future to complete or timeout. + void waitForFutureOrTimeout( + std::future& fut, + const std::chrono::milliseconds& timeOutMilSec, + const std::string& futDescription, + bool throwException = false, + bool log = false); + + // When watchdog timeout, this function will be called and return debug info + // for users. For now we only get information from retrieveDesyncReport. + // We are working on enabling more useful debug information for watchdog + // timeout. + virtual std::string getNCCLWatchdogDebugInfo(); + + std::string getNCCLWatchdogTimeoutErrorMsg(const std::string& extraMsg); + + std::string getNCCLWatchdogTimeoutExitMsg(const std::string& exitReason); + + static const int64_t kWatchdogThreadSleepMillis; + + // The store is used to broadcast the NCCL unique ID of rank 0. This store + // comes with prefix and it is different across ProcessGroup NCCL instances + // (aka, different ProcessGroups). + c10::intrusive_ptr store_; + + // Reference to the store without prefix so that keys are same across all + // ProcessGroup NCCL instances and (key, value) pairs written to the store are + // global. + c10::intrusive_ptr globalStore_; + + bool storeError_{false}; + + // The lock which protects the write/read of + // ephemeralTimeoutActive_/ephemeralTimeoutInflight_. + // TODO(fduwjj): We need to have an audit on all mutexes we are adding here. + // And consolidate them if possible. + std::mutex mtxTimeoutExtension_; + + // The ephemeral timeout added on top of existing timeout for works issued + // before first work finishes. + std::chrono::milliseconds ephemeralTimeoutActive_ = + std::chrono::milliseconds(0); + + // The ephemeral timeout addition which has been already applied to work. + std::chrono::milliseconds ephemeralTimeoutInflight_ = + std::chrono::milliseconds(0); + + const c10::intrusive_ptr options_; + + // The number of NCCL communicators that have been created during + // the lifetime of this process group. This sequence number is + // used to scope keys used in the store. + uint64_t ncclCommCounter_{0}; + + // The store keys to trace the last NCCL collective kernel CUDA events - start + // event and end event respectively. These are used to do desync root cause + // analysis. + const std::string traceKeyStart_; + const std::string traceKeyEnd_; - coalescedDevice_.set_index(-1); - coalescedComm_ = nullptr; - coalescing_state_ |= CoalActive; - groupStart(); -} - -// `optype` is for specifying a composite optype, such as ALLGATHER and -// REDUCE_SCATTER -c10::intrusive_ptr ProcessGroupNCCL::endCoalescing(OpType optype) { - if (coalescedComm_ == nullptr) { - // There is no actual work being coalesced, return here - groupEnd(); - coalescing_state_ = 0; - return nullptr; - } - TORCH_CHECK( - coalescedDevice_.index() >= 0, - "Somthing went wrong. Did you call end_coalescing before start_coalescing?"); - - // `coalescedComm_` should have same set of comms across collectives - auto comm = coalescedComm_; - // `coalescedDevice_` should have same set of devices across collectives - auto device = coalescedDevice_; - - // `getKeyFromDevice` is how we get keys for both collectives and batch P2P - const auto key = getKeyFromDevice(device); - auto ncclStream = ncclStreams_.at(key); - - // Create Work object - c10::cuda::CaptureStatus capture_status = - c10::cuda::currentStreamCaptureStatusMayInitCtx(); - bool enqueue = - (coalescing_state_) && capture_status == c10::cuda::CaptureStatus::None; - auto work = - initWork(device, rank_, optype, "nccl:coalesced", {}, {}, enqueue); - work->ncclComm_ = comm; - work->blockingWait_ = blockingWait_; - work->avoidRecordStreams_ = avoidRecordStreams_; - work->store_ = store_; - assignTimeoutToWork(work, options_); - - // Record start before ncclGroupEnd - if (work->timingEnabled_) { - work->ncclStartEvent_->record(ncclStream); - } + // The NCCL communicator that the process group has cached. + // + // For collective operations: + // The key is a list of GPU devices that an operation is operating on + // The GPU devices are stored in a device sequence and the cache NCCL + // communicator is associated with this GPU device sequence + // + // e.g. If the process group op only uses device 0, then the value of + // the used device string stored (value of the hashmap) would be "0". + // + // If the process group op uses device 0 - 7 and the each tensor of the + // input tensor list is on device, 0, 1, 2, 3, 4, 5, 6, 7 separately, + // then the value of the used device string (key) stored would be + // "0,1,2,3,4,5,6,7" + // + // If the process group op uses device 0 - 7 and the each tensor of the + // input tensor list is on device, 0, 4, 5, 6, 7, 1, 2, 3 separately, + // then the value of the used device string stored would be + // "0,4,5,6,7,1,2,3" + // + // Note that the order of the device for the tensor list matters. + // + // For point-to-point operations: + // The key is a string of my current rank and the peer process rank. + // e.g. If process 1 and process 2 are involved in a point-to-point + // communication, the key will be "1:2" on both processes. Note: this is for + // the scenario where there is only 1 GPU per process. When it comes to + // multiple GPUs per process, this part may need to redesigned. + // TODO: we probably need a separte map for P2P comms + std::unordered_map> devNCCLCommMap_; - if (nccl_use_nonblocking()) { - groupEndNonblocking(comm); - } else { - groupEnd(); - } + // The NCCL communicators currently in process of being initialized. + std::unordered_map> + inInitializationCommMap_; - // Record end after ncclGroupEnd - // TODO(eqy): is this still necessary if avoidRecordStreams_ is set? - work->ncclEndEvent_->record(ncclStream); + // Mutex to guard maps like devNCCLCommMap_. + std::mutex mutex_; - if (avoidRecordStreams_) { - // other functions expect an initialized ptr if avoidRecordStreams_ is set - work->stashed_for_allocator_safety_ = - std::make_shared>(); - } + // Heartbeat of watchdog thread. + std::atomic_uint64_t heartbeat_; - // Notify graphs before we check the capture status preemptively - at::cuda::CUDAGraph::inc_pending_event_queries(); + // The time interval used for deciding whether there is no watchdog heartbeat. + int heartbeatTimeoutInSec_; - if (enqueue) { - workEnqueue(work); - } else { - at::cuda::CUDAGraph::dec_pending_event_queries(); - } + // timeout for the dump to finish. + int waitTimeoutDumpInMilSec_; - coalescing_state_ = 0; - coalescedComm_ = nullptr; - return work; -} - -c10::intrusive_ptr ProcessGroupNCCL::endCoalescing() { - // Default OpType to COALESCED if not specified - return endCoalescing(OpType::COALESCED); -} - -template -c10::intrusive_ptr ProcessGroupNCCL::collective( - std::vector& inputs, - std::vector& outputs, - Fn fn, - PreProcess pre, - PostProcess post, - OpType opType, - const char* profilingTitle, - bool avoidRecordStreams, - bool nanCheck) { - // Environment setting by the user may add onto collective call's option - avoidRecordStreams |= avoidRecordStreams_; - nanCheck &= enableNanCheck_; - - c10::cuda::CaptureStatus capture_status = - c10::cuda::currentStreamCaptureStatusMayInitCtx(); - errorIfCapturingNonCapturableNCCL(capture_status); - - // Bump collective counter - seqCollective_++; - op_id_++; - - auto device = getDevice(inputs[0]); - const auto key = getKeyFromDevice(device); - auto ncclComm = getNCCLComm(key, device, opType); - - if (coalescing_state_ & CoalActive) { - coalescing_state_ |= CoalColl; - if (coalescedDevice_.index() < 0) { - coalescedDevice_ = device; - } else { - TORCH_CHECK( - coalescedDevice_.index() == device.index(), MULTI_DEVICE_ERROR_MSG); - } - if (coalescedComm_ == nullptr) { - coalescedComm_ = ncclComm; - } else { - TORCH_CHECK(coalescedComm_ == ncclComm, MULTI_DEVICE_ERROR_MSG); - } - } + // Interval of check coordinated signals in ProcessGroupNCCL from other ranks + // e.g., trigger the dump of the debugging info for timeout when notified. + int coordCheckIntervalMilSec_; - // Used many times below, so we stash the unordered_map lookup - auto ncclStream = ncclStreams_.at(key); + // Size of ring buffer where we store NCCL Traces for debugging. + int ncclTraceBufferSize_; - // First let NCCL streams wait for input tensors allocation streams - syncStream(device, ncclEvents_[key], ncclStream); + // We gate the heartbeat monitor thread so that we can roll it out gradually. + std::atomic monitorThreadEnabled_; - bool enqueue = - !coalescing_state_ && capture_status == c10::cuda::CaptureStatus::None; - auto work = - initWork(device, rank_, opType, profilingTitle, inputs, outputs, enqueue); + // We gate the cudaEventCache so that we can roll it out gradually. + std::atomic cudaEventCacheEnabled_; - // Store references to outputs to be used by WorkNCCL::result and operator<<. - work->outputs_ = std::make_shared>(outputs); + // Monitor thread which checks the heartbeat of Watchdog thread. + // If the monitor thread finds there is no heartbeat, it will dump debug info + // and then kill the watchdog thread to avoid hang. + std::thread ncclHeartbeatMonitorThread_; - if (avoidRecordStreams) { - work->stashed_for_allocator_safety_ = - std::make_shared>(inputs); - } + // Watchdog thread which looks for errors on the cached NCCL communicators. + std::thread ncclCommWatchdogThread_; - at::cuda::OptionalCUDAGuard gpuGuard(device); + std::thread onCompletionHookThread_; - if (nanCheck) { - for (const auto& input : inputs) { - checkForNan(input, ncclStream); - } - } + // Whether or not we should terminate the watchdog and workCleanup threads. + std::atomic terminateProcessGroup_; - // Start event should only be recorded before the ncclGroupStart() - if (work->timingEnabled_) { - work->ncclStartEvent_->record(ncclStream); - } + // Whether or not we should terminate the heartbeat monitoring threads. + std::atomic terminateHeartbeatMonitorThread_; - pre(ncclStream, work); + // Whether we are in the shutdown mode when we are trying to get debug info, + // such as desync report. + std::atomic collectiveDebugInfoMode_; - ncclComm_t comm = ncclComm->getNcclComm(); + // Whether there are hooks pending to be fired + std::atomic hasPendingHooks_; - // Both `inputs' and `outputs' are created on a worker stream and used in - // different ncclStreams. Hence, both must record the ncclStream to - // prevent being freed before the collective finishes. + // This is the signal from watchdog threads to indicate whether the monitor + // thread should dump. Making it static so that it is accessiable from all the + // PGs. With this flag, monitor thread would dump debug info under any one of + // the three conditions: // - // We only record `inputs' here, and leave recording `outputs' to `fn' for - // operations where `inputs' and `outputs' are not the same. + // 1: watchdog thread of any PG detects a collective timeout. + // 2: timeout signal is received from other ranks through tcpstore. + // 3: current PG's watchdog heartbeat timeout occurs. // - // See [Sync Streams]. - if (!avoidRecordStreams) { - for (const auto& input : inputs) { - if (!input.is_sparse()) { - c10::cuda::CUDACachingAllocator::recordStream( - input.storage().data_ptr(), ncclStream); - } else { - // for sparse input case record streams on both index and value - // tensors - c10::cuda::CUDACachingAllocator::recordStream( - input.values().storage().data_ptr(), ncclStream); - c10::cuda::CUDACachingAllocator::recordStream( - input.indices().storage().data_ptr(), ncclStream); - } - } - } + // Note that only the monitor thread from PG0 will dump the debug info for + // case one and two so that the debug info is only dumped once. + static std::atomic shouldDump_; -// Not all collectives have the same signature, e.g, all-reduce take in a Tensor -// as the input and output while all-to-all take in a vector of Tensors as input -// and output. Because we define the signature of the fn to take only single -// tensor as input and output, we need to do a hack to get the first element in -// the vector and pass it to fn. -// TODO: we should clean up this in future (by either entirely removing lambda's -// or removing input and output from lambda's signature). -#ifndef NCCL_HAS_COMM_NONBLOCKING - C10D_NCCL_CHECK( - fn(inputs[0], outputs[0], comm, ncclStream), - ncclComm->getNcclCommFailureReason()); -#else - C10D_NCCL_CHECK_TIMEOUT( - fn(inputs[0], outputs[0], comm, ncclStream), - comm, - ncclComm->getNcclCommFailureReason()); -#endif + // Mutex to Guard workMetaList_ + std::mutex workMetaListMutex_; - post(ncclStream, work); + // Mutex to Guard monitorWakeUpCV_ + std::mutex monitorMutex_; - // End event should only be recorded after the ncclGroupEnd() - if (!coalescing_state_) { - work->ncclEndEvent_->record(ncclStream); - } - work->ncclComm_ = ncclComm; - - { - c10::cuda::CUDAMultiStreamGuard streamGuard(ncclStream); - std::vector devices{device}; - work->future_ = c10::make_intrusive( - c10::ListType::create(c10::TensorType::get()), devices); - - // Add a callback that runs profiling end callbacks. wrapCallback() in CUDA - // future blocks the stream this callback runs on the corresponding - // ncclEndEvents_ ensuring appropriate synchronization. - if (work->recordFunctionEndCallback_) { - work->future_->addCallback( - [work](at::ivalue::Future& /* unused */) { - work->recordFunctionEndCallback_(); - }, - // uses_future = false allows us to skip synchronization in - // ivalue::Future, but is only valid as long as the lambda doesn't use - // the "Future" argument. - /*uses_future=*/false); - } - work->future_->markCompleted(at::IValue(*work->outputs_)); - } + bool writeDebugInfo_ = false; - // Set appropriate work parameters. - work->blockingWait_ = blockingWait_; - work->avoidRecordStreams_ = avoidRecordStreams; - work->store_ = store_; - assignTimeoutToWork(work, options_); - // Record size info for debug. We only record the size on the first device as - // multi-device per process is deprecated - work->numelIn_ = 0; - work->numelOut_ = 0; - for (const auto& input : inputs) { - work->numelIn_ += input.numel(); - } - for (const auto& output : outputs) { - work->numelOut_ += output.numel(); - } + // Condition Variable for watchdog thread sleep + std::condition_variable workMetaListCV_; - // Notify graphs before we check the capture status preemptively - at::cuda::CUDAGraph::inc_pending_event_queries(); - if (enqueue) { - workEnqueue(work); - } else { - at::cuda::CUDAGraph::dec_pending_event_queries(); - } + // Condition Variable for monitor thread to wake up early + std::condition_variable monitorWakeUpCV_; - return work; -} - -template -c10::intrusive_ptr ProcessGroupNCCL::collectiveCoalesced( - std::vector& inputs, - std::vector& outputs, - Fn fn, - OpType opType, - const char* profilingTitle, - bool avoidRecordStreams) { - // Environment setting by the user may add onto collective call's option - avoidRecordStreams |= avoidRecordStreams_; - c10::cuda::CaptureStatus capture_status = - c10::cuda::currentStreamCaptureStatusMayInitCtx(); - errorIfCapturingNonCapturableNCCL(capture_status); - - // Bump collective counter - seqCollective_++; - - // For coalescingManager collectives, there is no individual c++ call per - // collective so there is no flight record and we increment seq*_ and op_id_ - // together. Compare this to startCoalesing/endCoalescing flow where we - // increment seq_ once per group and increment op_id_ once per indvidual - // operation within the group - op_id_++; - - // Currently, the API permits one scenario where inputs.size() and - // outputs.size() are > 0. - // 1. If the call was a _coalesced call, all inputs must be on the same - // device. - // The group of nccl calls applies the collective separately to each input, - // but the group as a whole should be efficient, and might even execute as - // a single fused kernel. - auto device = getDevice(inputs[0]); - const auto key = getKeyFromDevice(device); - auto ncclComm = getNCCLComm(key, device, opType); - - if (coalescing_state_ & CoalActive) { - coalescing_state_ |= CoalColl; - if (coalescedDevice_.index() < 0) { - coalescedDevice_ = device; - } else { - TORCH_CHECK( - coalescedDevice_.index() == device.index(), MULTI_DEVICE_ERROR_MSG); - } - if (coalescedComm_ == nullptr) { - coalescedComm_ = ncclComm; - } else { - TORCH_CHECK(coalescedComm_ == ncclComm, MULTI_DEVICE_ERROR_MSG); - } - } + // Vector to Store WorkNCCL pointers + std::list workMetaList_; - // Used many times below, so we stash the unordered_map lookup - auto ncclStream = ncclStreams_.at(key); + std::chrono::time_point lastWorkListUpdateTime_; - // First let NCCL streams wait for input tensors allocation streams - syncStream(device, ncclEvents_[key], ncclStream); + // Mutex to Guard workMetaList_ + std::mutex completedWorkListMutex_; - auto work = initWork( - device, rank_, opType, profilingTitle, inputs, outputs, /*record=*/true); + // Condition Variable for watchdog thread sleep + std::condition_variable completedWorkListCV_; - // Store references to outputs to be used by WorkNCCL::result and operator<<. - work->outputs_ = std::make_shared>(outputs); + std::list completedWorkList_; - if (avoidRecordStreams) { - work->stashed_for_allocator_safety_ = - std::make_shared>(inputs); - } + // Add Work Pointer to workVector + void workEnqueue(c10::intrusive_ptr); - at::cuda::OptionalCUDAGuard gpuGuard(device); + // The CUDA streams used by NCCL kernels + std::unordered_map ncclStreams_; - // Start event should only be recorded before the ncclGroupStart() (which - // happens inside AutoNcclGroup guard below) - if (work->timingEnabled_) { - work->ncclStartEvent_->record(ncclStream); - } + // The CUDA events used to sync NCCL streams + std::unordered_map ncclEvents_; - ncclComm_t comm = ncclComm->getNcclComm(); - -// TODO(kwen2501): this should be moved to c10d tests, to qualify a NCCL -// upgrade. Once a NCCL version is qualified, this code should not be needed at -// runtime. -#ifdef PGNCCL_ENABLE_HASH - if (enableCollecticeHashDebug_.load()) { - auto numel = getTensorsNumel(inputs); - auto hashValue = hashTensors(inputs); - PRINT_COLLECTIVE_HASH_SIGNATURE( - "input", opTypeToString(opType), numel, hashValue); - } -#endif + // Device Indexes used for all collectives in this group + std::set usedDeviceIdxs_; - { - torch::cuda::nccl::AutoNcclGroup nccl_group_guard( - comm, nccl_use_nonblocking()); - for (const auto i : c10::irange(inputs.size())) { - // Both `inputs' and `outputs' are created on a worker stream and used in - // different ncclStreams. Hence, both must record the ncclStream to - // prevent being freed before the collective finishes. - // - // We only record `inputs' here, and leave recording `outputs' to `fn' for - // operations where `inputs' and `outputs' are not the same. - // - // See [Sync Streams]. - if (!avoidRecordStreams) { - if (!inputs[i].is_sparse()) { - c10::cuda::CUDACachingAllocator::recordStream( - inputs[i].storage().data_ptr(), ncclStream); - } else { - // for sparse input case record streams on both index and value - // tensors - c10::cuda::CUDACachingAllocator::recordStream( - inputs[i].values().storage().data_ptr(), ncclStream); - c10::cuda::CUDACachingAllocator::recordStream( - inputs[i].indices().storage().data_ptr(), ncclStream); - } - } -#ifndef NCCL_HAS_COMM_NONBLOCKING - C10D_NCCL_CHECK( - fn(inputs[i], outputs[i], comm, ncclStream), - ncclComm->getNcclCommFailureReason()); -#else - C10D_NCCL_CHECK_TIMEOUT( - fn(inputs[i], outputs[i], comm, ncclStream), - comm, - ncclComm->getNcclCommFailureReason()); -#endif - } - } + // Flag to denote if a coalescing groupStart/groupEnd block is active + int coalescing_state_ = 0; - work->ncclEndEvent_->record(ncclStream); - work->ncclComm_ = ncclComm; - - { - c10::cuda::CUDAMultiStreamGuard streamGuard(ncclStream); - std::vector devices{device}; - work->future_ = c10::make_intrusive( - c10::ListType::create(c10::TensorType::get()), devices); - - // Add a callback that runs profiling end callbacks. wrapCallback() in CUDA - // future blocks the stream this callback runs on the corresponding - // ncclEndEvents_ ensuring appropriate synchronization. - if (work->recordFunctionEndCallback_) { - work->future_->addCallback( - [work](at::ivalue::Future& /* unused */) { - work->recordFunctionEndCallback_(); - }, - // uses_future = false allows us to skip synchronization in - // ivalue::Future, but is only valid as long as the lambda doesn't use - // the "Future" argument. - /*uses_future=*/false); - } - work->future_->markCompleted(at::IValue(*work->outputs_)); - } + // Stores device indexes for all collectives run inside a coalescing block + at::Device coalescedDevice_ = at::Device("cuda"); - // Set appropriate work parameters. - work->blockingWait_ = blockingWait_; - work->avoidRecordStreams_ = avoidRecordStreams; - work->store_ = store_; - assignTimeoutToWork(work, options_); - // Record size info for debug. We only record the size on the first device as - // multi-device per process is deprecated - work->numelIn_ = inputs[0].numel(); - work->numelOut_ = outputs[0].numel(); - - /* Note [cuda graph capture and workEnqueue] - - Normal behavior of the C10D watchdog is to query cuda events on work objects - periodically, but when cuda graph recording is active these event queries - would crash or mess up the recording. - - To ensure we do not enqueue a work object to the watchdog when cuda graph - capture is active, we use a one-way sync. We increment a flag pre-emptively, - indicating our intent to enqueue a work object. Then we check capture_status - to see if (a) capturing is already in progress (we cannot enqueue in this - case), (b) capturing hasn't started yet, so we can trust that no capture will - start (since a pre-condition of starting a capture is to check the event query - count is 0). - - If we are not able to enqueue the work due to capture-in-progress, we finally - decrement the counter. - - For this reason we cannot easily move the increment inside workEnqueue unless - we also change the semantic of workEnqueue to 'maybeWorkEnqueue'. - - TODO: - - Is our design for flight recorder safe in this context? are we recording - any FR events during cudagraph capture? if so, they won't be safe to poll for - completion status. - */ - at::cuda::CUDAGraph::inc_pending_event_queries(); - if (capture_status == c10::cuda::CaptureStatus::None) { - workEnqueue(work); - } else { - at::cuda::CUDAGraph::dec_pending_event_queries(); - } - // TODO(whc) if the work isn't enqueued, I don't feel great about returning - // it, since interactions with it by usercode won't behave normally - they - // won't observe work completion, for instance. Will this lead to silent - // problems during capture? - return work; -} - -template -c10::intrusive_ptr ProcessGroupNCCL::pointToPoint( - at::Tensor& tensor, - Fn fn, - int peer, - OpType opType, - PreProcess pre, - PostProcess post, - const char* profilingTitle) { - // avoidRecordStreams_ note: - // send, recv, and irecv should be ok with avoidRecordStreams, - // However, for isend, I don't think the API requires the user - // to wait() on the returned handle, so ProcessGroupNCCL can't know - // when it's safe to release the input back to the allocator, - // and the present call has no way to know it's not an isend. - // Therefore, we warn and fall back to the typical recordStream logic: - if (avoidRecordStreams_) { - TORCH_WARN_ONCE( - "TORCH_NCCL_AVOID_RECORD_STREAMS=1 has no effect for point-to-point " - "collectives."); - } + // Stores communicators for all collectives run inside a coalescing block + std::shared_ptr coalescedComm_ = nullptr; - auto device = getDevice(tensor); - std::string key; - int p2pRank = 0, p2pTargetRank = 0; - bool isSendRecvSelf = false; - // For batch_isend_irecv, ncclGroupStart() would be called upfront - bool batchP2P = ncclActiveGroupCounter_ > 0; - if (batchP2P) { - // For batch P2P, we need to treat it like a collective when selecting - // communicator, because other ranks can call into this batch other than my - // rank and my peer - key = getKeyFromDevice(device); - p2pRank = rank_; - p2pTargetRank = peer; - } else { - // For single P2P, preserve the old two-rank behavior (to avoid perf diff) - key = getKeySendRecv(rank_, peer); - p2pRank = rank_ <= peer ? 0 : 1; - isSendRecvSelf = rank_ == peer; - p2pTargetRank = isSendRecvSelf ? 0 : 1 - p2pRank; - - if (!coalescing_state_) { - // Bump P2P sequence number. Don't do so if it's a batch P2P, it will be - // bumped in `startCoalescing`. - seqP2P_++; - } - } + // map from the key: "group name + pg counter (ID)" to the + // unique NCCL ID count. This needs to be group and pg specific + // + // For each process group, we need a uniform unique NCCL ID counter to ensure + // that NCCL operation in this process group can be completed successfully. + // Since each process group ID belongs to a group name, the key to this map + // is a combination of group name and ProcessGroupNCCL ID. + static std::unordered_map pgUniqueNCCLIDCnt_; - // Bump the logical operation counter regardless of whether this op is - // coalesced or individual - op_id_++; + // map from group name to the pg counter (ID) within that group + // + // For each group with the "group name" (which is the key), we need to + // keep track of a unique process group ID when creating a new + // ProcessGroupNCCL for this "group name". Therefore, the value of this + // map keeps the unique ProcessGroupNCCL's ID for a specific group with + // the "group name". The reason we need a per-group process group ID counter + // is that different group can have different ranks and we need ensure that + // each group has its own uniform process group ID for all its ranks. + static std::unordered_map processGroupCounterMap_; - auto ncclComm = getNCCLComm(key, device, opType, p2pRank, isSendRecvSelf); + // Whether or not wait() and synchronize() are blocking operations that wait + // for the operation to complete. + bool blockingWait_ = false; - if (coalescing_state_ & CoalActive) { - coalescing_state_ |= CoalP2P; - if (coalescedDevice_.index() < 0) { - coalescedDevice_ = device; - } else { - TORCH_CHECK( - coalescedDevice_.index() == device.index(), MULTI_DEVICE_ERROR_MSG); - } - if (coalescedComm_ == nullptr) { - coalescedComm_ = ncclComm; - } else { - TORCH_CHECK(coalescedComm_ == ncclComm, MULTI_DEVICE_ERROR_MSG); - } - } + // Whether or not to hook the cache allocator to register all allocated + // tensors + bool useTensorRegisterAllocatorHook_ = false; - // Used many times below, so we stash the unordered_map lookup - auto ncclStream = ncclStreams_.at(key); - // First let NCCL streams wait for input tensors allocation streams - syncStream(device, ncclEvents_[key], ncclStream); - - // Work itself will create the CUDA events on all GPUs of tensors - c10::intrusive_ptr work; - if (coalescing_state_) { - // When coalescing, we record events per op that lack timing/state - // information becuase there is no 'work' associated with them, and then - // later in endCoalescing we record a 'coalesced' Work which has - // timing/state updates via watchdog thread, but lacks op metadata such as - // input/output sizes and profilingTitle per-op in the group. - auto trace_id = NCCLTraceBuffer::get()->record( - local_id_, - std::make_tuple(pg_uid_, pg_desc_), - seqCollective_, - seqP2P_, - op_id_, - profilingTitle, - {tensor}, - {tensor}, - nullptr, - nullptr, - options_->timeout, - pgStatus_, - /*isP2P=*/true); - // TODO(whc) if we want to make the per-p2p-op flightrecorder entries get - // their timings/states updated by proxy when the Work obj representing the - // coalesce group gets its update, we could accumulate these trace_ids - // together and ask FlightRecorder to take the update from one Work and - // apply it to multiple entries - (void)trace_id; - } else { - // Store references to outputs to be used by WorkNCCL::result and - // operator<<. Note that these outputs are only valid for recv(), as send() - // does not modify the inputs but we still create these outputs for use - // cases such as profiling. - - work = initWork( - device, rank_, opType, profilingTitle, {tensor}, {}, /*record=*/false); - // This bypasses something in Work() that crashes if {tensor} is given as - // output, not sure what - work->outputs_ = std::make_shared>(); - work->outputs_->push_back(tensor); - // TODO(whc) because we don't pass output {tensor} to initWork, we tell - // initWork to not record, and then we manually call record passing all the - // information it wants. - work->trace_id_ = NCCLTraceBuffer::get()->record( - local_id_, - std::make_tuple(pg_uid_, pg_desc_), - seqCollective_, - seqP2P_, - op_id_, - profilingTitle, - {tensor}, - {tensor}, - work->ncclStartEvent_.get(), - work->ncclEndEvent_.get(), - options_->timeout, - pgStatus_, - /*isP2P=*/true); - } + // Whether or not the workCleanupThread is used to perform async error + // handling. + ErrorHandlingMode asyncErrorHandling_ = NoHandling; - // is gpuGuard needed for the if block below, or can i swap them - at::cuda::OptionalCUDAGuard gpuGuard(device); + // Whether or not to enable timeout root cause analysis. + bool desyncDebug_; - // Only check for NaN for send ops, for recv ops `tensor` can be a random - // placeholder - if (enableNanCheck_ && opType == OpType::SEND) { - checkForNan(tensor, ncclStream); - } + // Whether or not to dump debug info on exception including both watchdog + // timeout and nccl errors. + bool dumpOnTimeoutOrEx_; - if (!coalescing_state_) { - // Start event should only be recorded before the ncclGroupStart() - if (work->timingEnabled_) { - work->ncclStartEvent_->record(ncclStream); - } + // Whether or not to sleep after an exception is thrown in the watchdog. + bool sleepAfterException_; - pre(ncclStream, work); - } + // Whether or not to enable nan check for input tensors to collectives. + bool enableNanCheck_; - // Both send tensor and recv tensor are created on a worker stream and used - // in different ncclStreams. Hence, both must record the ncclStream to - // prevent being freed before the collective finishes. - // - // See [Sync Streams]. - c10::cuda::CUDACachingAllocator::recordStream( - tensor.storage().data_ptr(), ncclStream); + // Whether or not to print C++ stack traces to logs on unclean shutdown. + bool logCppStackOnUncleanShutdown_; - // This part seems common to both p2p and coalesced-p2p usage? - ncclComm_t comm_ = ncclComm->getNcclComm(); + // Whether or not to create start CUDAEvent and enable timing for start + // and end events. Note that enableTiming_ is always true if desyncDebug_ + // is set to true. + std::atomic enableTiming_; -#ifndef NCCL_HAS_COMM_NONBLOCKING - C10D_NCCL_CHECK( - fn(tensor, comm_, ncclStream, p2pTargetRank), - ncclComm->getNcclCommFailureReason()); -#else - C10D_NCCL_CHECK_TIMEOUT( - fn(tensor, comm_, ncclStream, p2pTargetRank), - ncclComm->getNcclComm(), - ncclComm->getNcclCommFailureReason()); -#endif + // Flag to enable the print of hash value of input/output of collectives for + // verification. + std::atomic enableCollecticeHashDebug_; - if (!coalescing_state_) { - post(ncclStream); - - // End event should only be recorded after the ncclGroupEnd() - work->ncclEndEvent_->record(ncclStream); - work->ncclComm_ = ncclComm; - work->blockingWait_ = blockingWait_; - work->store_ = store_; - assignTimeoutToWork(work, options_); - // Record size info for debug. We only record the size on the first device - // as multi-device per process is deprecated - work->numelIn_ = work->numelOut_ = tensor.numel(); - - // Future only needs to be created and marked completed with outputs for - // recv(), but still create future for use cases such as profiling even for - // send(). - { - c10::cuda::CUDAMultiStreamGuard streamGuard(ncclStream); - std::vector devices{device}; - work->future_ = c10::make_intrusive( - c10::ListType::create(c10::TensorType::get()), devices); - work->future_->markCompleted(at::IValue(*work->outputs_)); - } + // Whether or not TORCH_NCCL_AVOID_RECORD_STREAMS was set + bool avoidRecordStreams_ = false; - // Add a callback that runs profiling end callbacks. wrapCallback() in CUDA - // future blocks the stream this callback runs on the corresponding - // ncclEndEvents_ ensuring appropriate synchronization. - if (work->recordFunctionEndCallback_) { - work->future_->addCallback( - [work](at::ivalue::Future& /* unused */) { - work->recordFunctionEndCallback_(); - }, - // uses_future = false allows us to skip synchronization in - // ivalue::Future, but is only valid as long as the lambda doesn't use - // the "Future" argument. - /*uses_future=*/false); - } - } + // Whether the NCCL watchdog should rethrow CUDA errors. + bool rethrowCUDAErrors_ = false; - // Enqueue P2P op so that it can be cancelled by NCCL watchdog - c10::cuda::CaptureStatus capture_status = - c10::cuda::currentStreamCaptureStatusMayInitCtx(); + // Set of communicators that this process group has aborted and their + // ncclUniqueId has been written to the store. We don't need a lock + // for this map since only the watchdog thread accesses this set. The + // set contains the string representation of ncclUniqueId. + std::unordered_set abortedComms_; - // Notify graphs before we check the capture status preemptively - at::cuda::CUDAGraph::inc_pending_event_queries(); + // The number of active ncclGroupStart() calls. This counter will be increased + // by 1 when ncclGroupStart() is called and decreased by 1 when ncclGroupEnd() + // is called. + static thread_local uint64_t ncclActiveGroupCounter_; - if (!coalescing_state_ && capture_status == c10::cuda::CaptureStatus::None) { - workEnqueue(work); - return work; - } else { - at::cuda::CUDAGraph::dec_pending_event_queries(); - return nullptr; - } -} - -template -c10::intrusive_ptr ProcessGroupNCCL::collective( - at::Tensor& input, - at::Tensor& output, - Fn fn, - PreProcess pre, - PostProcess post, - OpType opType, - const char* profilingTitle, - bool avoidRecordStreams, - bool nanCheck) { - auto inputs = std::vector{input}; - auto outputs = std::vector{output}; - return collective( - inputs, - outputs, - fn, - pre, - post, - opType, - profilingTitle, - avoidRecordStreams, - nanCheck); -} - -template -c10::intrusive_ptr ProcessGroupNCCL::collective( - at::Tensor& input, - at::Tensor& output, - Fn fn, - OpType opType, - const char* profilingTitle, - bool avoidRecordStreams, - bool nanCheck) { - auto inputs = std::vector{input}; - auto outputs = std::vector{output}; - return collective( - inputs, - outputs, - fn, - [](at::cuda::CUDAStream&, - c10::intrusive_ptr& work) {}, - [](at::cuda::CUDAStream&, - c10::intrusive_ptr& work) {}, - opType, - profilingTitle, - avoidRecordStreams, - nanCheck); -} - -template -c10::intrusive_ptr ProcessGroupNCCL::pointToPoint( - at::Tensor& tensor, - Fn fn, - int peer, - OpType opType, - const char* profilingTitle) { - return pointToPoint( - tensor, - fn, - peer, - opType, - [](at::cuda::CUDAStream&, - c10::intrusive_ptr& work) {}, - [](at::cuda::CUDAStream&) {}, - profilingTitle); -} - -c10::intrusive_ptr ProcessGroupNCCL::allreduce_sparse( - std::vector& tensors, - const AllreduceOptions& opts) { - TORCH_CHECK(tensors.size() == 1, MULTI_DEVICE_ERROR_MSG); - auto tensor = tensors.back(); - TORCH_CHECK( - !isFloat8Type(tensor.scalar_type()), - "Float8 dtypes are not currenlty supported for NCCL reductions"); -#ifdef IS_NCCLX - tensor = tensor.coalesce(); - at::Tensor outputTensor = - torch::zeros(tensor.sizes(), tensor.options().layout(torch::kStrided)); - auto work = collective( - tensor, - outputTensor, - [&](at::Tensor& input, - at::Tensor& output, - ncclComm_t comm, - at::cuda::CUDAStream& stream) { - auto ncclDataType = getNcclDataType(input.scalar_type()); - auto ncclReduceOp = - getNcclReduceOp(opts.reduceOp, input, ncclDataType, comm); - - size_t num_elements = output.numel(); - auto indices = input.indices(); - auto sizes = input.sizes(); - int colSize = sizes[1]; - auto rows = indices[0]; - size_t blockCount = rows.sizes()[0]; - auto recvIndices = indices[0] * colSize; - - // prevent output and recvIndices from being freed - c10::cuda::CUDACachingAllocator::recordStream( - output.storage().data_ptr(), stream); - c10::cuda::CUDACachingAllocator::recordStream( - recvIndices.storage().data_ptr(), stream); - auto result = ncclAllReduceSparseBlock( - input._values().data_ptr(), // sendbuff - recvIndices.data_ptr(), // recv_indices - blockCount, // block_count - colSize, // block_length - output.data_ptr(), // recvbuff - output.numel(), // recv_count - ncclDataType, - ncclReduceOp, - comm, - stream.stream()); - return result; - }, - [](at::cuda::CUDAStream& ncclStream, - c10::intrusive_ptr& work) {}, - [&](at::cuda::CUDAStream& ncclStream, - c10::intrusive_ptr& work) { - // Convert output tensors to sparse and back into tensors. - at::cuda::CUDAStreamGuard guard(ncclStream); - if (opts.sparseIndices.has_value()) { - tensor = at::sparse_coo_tensor( - opts.sparseIndices.value(), outputTensor, tensor.sizes()); - } else { - tensor = outputTensor.to_sparse(); - } - }, - OpType::_ALLREDUCE_SPARSE, - "nccl:all_reduce_sparse"); - return work; -#else - // If the nccl branch is not "exp" then we just error - C10_THROW_ERROR( - Error, - "NCCL does not support all_reduce with sparse tensors. Please use dense tensors instead."); -#endif -} - -c10::intrusive_ptr ProcessGroupNCCL::allreduce_impl( - at::Tensor& tensor, - const AllreduceOptions& opts) { - return collective( - tensor, - tensor, - [&](at::Tensor& input, - at::Tensor& output, - ncclComm_t comm, - at::cuda::CUDAStream& stream) { - auto ncclDataType = getNcclDataType(input.scalar_type()); - auto ncclReduceOp = - getNcclReduceOp(opts.reduceOp, input, ncclDataType, comm); - return ncclAllReduce( - input.data_ptr(), - output.data_ptr(), - input.numel(), - ncclDataType, - ncclReduceOp, - comm, - stream.stream()); - }, - OpType::ALLREDUCE, - "nccl:all_reduce"); -} - -c10::intrusive_ptr ProcessGroupNCCL::allreduce( - std::vector& tensors, - const AllreduceOptions& opts) { - TORCH_CHECK(tensors.size() == 1, MULTI_DEVICE_ERROR_MSG); - auto tensor = tensors.back(); - if (tensor.is_complex()) { - TORCH_CHECK( - complexViewAsRealAllowed(opts.reduceOp), - "all_reduce does not support", - opts.reduceOp, - "on complex tensors"); - tensor = at::view_as_real(tensor); - } - check_gpu_single_tensor(tensor); - - if (intraNodeComm_ != nullptr && opts.reduceOp == ReduceOp::SUM) { - using namespace intra_node_comm; - auto algo = intraNodeComm_->selectAllReduceAlgo(tensor); - if (algo != intra_node_comm::AllReduceAlgo::NONE) { - intraNodeComm_->allReduce(tensor, algo); - return c10::make_intrusive(); - } - } - TORCH_CHECK( - !isFloat8Type(tensor.scalar_type()), - "Float8 dtypes are not currenlty supported for NCCL reductions"); - // @lint-ignore CLANGTIDY - RECORD_PARAM_COMMS_DATA( - static_cast( - this->getSequenceNumberForGroup() + 1), // seq + 1 to match collective - std::make_tuple(pg_uid_, pg_desc_), // PG name tuple - tensors, // inputTensors - tensors, // outputTensors - rank_, // rank - "allreduce", // collective name - tensor.numel(), // inNelems - tensor.numel(), // outNelems - tensor.scalar_type(), // dType - std::vector(), // inSplitSizes - std::vector(), // outSplitSizes - globalRankStart, // globalRankStart - globalRankStride, // globalRankStride - this->getSize()); // worldSize - - // avoidRecordStreams_ note: collective() will stash tensors. - return allreduce_impl(tensor, opts); -} - -c10::intrusive_ptr ProcessGroupNCCL::allreduce_coalesced( - std::vector& tensors, - const AllreduceCoalescedOptions& opts) { - auto total_numel = check_gpu_tensors_same_device(tensors); - TORCH_CHECK( - !isFloat8Type(tensors.back().scalar_type()), - "Float8 dtypes are not currenlty supported for NCCL reductions"); - - // @lint-ignore CLANGTIDY - RECORD_PARAM_COMMS_DATA( - static_cast( - this->getSequenceNumberForGroup() + 1), // seq + 1 to match collective - std::make_tuple(pg_uid_, pg_desc_), // PG name tuple - tensors, // inputTensors - tensors, // outputTensors - rank_, // rank - "allreduce_coalesced", // collective name - total_numel, // inNelems - total_numel, // outNelems - tensors[0].scalar_type(), // dType - // I'm not sure what in,outSplitSizes mean here. - std::vector(), // inSplitSizes - std::vector(), // outSplitSizes - globalRankStart, // globalRankStart - globalRankStride, // globalRankStride - this->getSize()); // worldSize - - // avoidRecordStreams_ note: collective() will stash tensors. - return collectiveCoalesced( - tensors, - tensors, - [&](at::Tensor& input, - at::Tensor& output, - ncclComm_t comm, - at::cuda::CUDAStream& stream) { - auto ncclDataType = getNcclDataType(input.scalar_type()); - auto ncclReduceOp = - getNcclReduceOp(opts.reduceOp, input, ncclDataType, comm); - return ncclAllReduce( - input.data_ptr(), - output.data_ptr(), - input.numel(), - ncclDataType, - ncclReduceOp, - comm, - stream.stream()); - }, - OpType::COALESCED, - "nccl:allreduce_coalesced"); -} - -c10::intrusive_ptr ProcessGroupNCCL::broadcast( - std::vector& tensors, - const BroadcastOptions& opts) { - TORCH_CHECK(tensors.size() == 1, MULTI_DEVICE_ERROR_MSG); - auto tensor = tensors.back(); - if (tensor.is_complex()) { - tensor = at::view_as_real(tensor); - } - check_gpu_single_tensor(tensor); - - // @lint-ignore CLANGTIDY - RECORD_PARAM_COMMS_DATA( - static_cast( - this->getSequenceNumberForGroup() + 1), // seq + 1 to match collective - std::make_tuple(pg_uid_, pg_desc_), // PG name tuple - tensors, // inputTensors - tensors, // outputTensors - opts.rootRank, // root rank - "broadcast", // collective name - tensor.numel(), // inNelems - tensor.numel(), // outNelems - tensor.scalar_type(), // dType - std::vector(), // inSplitSizes - std::vector(), // outSplitSizes - globalRankStart, // globalRankStart - globalRankStride, // globalRankStride - this->getSize()); // worldSize - - // avoidRecordStreams_ note: collective() will stash tensors. - bool avoidRecordStreams = avoidRecordStreams_ || (!opts.asyncOp); - - const auto root = opts.rootRank + opts.rootTensor; - bool nanCheck = (root == rank_); - - return collective( - tensor, - tensor, - [&](at::Tensor& input, - at::Tensor& output, - ncclComm_t comm, - at::cuda::CUDAStream& stream) { - return ncclBcast( - input.data_ptr(), - input.numel(), - getNcclDataType(input.scalar_type()), - root, - comm, - stream.stream()); - }, - OpType::BROADCAST, - "nccl:broadcast", - avoidRecordStreams, - nanCheck); -} - -// _broadcast_oop adds an out-of-place broadcast in PGNCCL -// Custom collectives may be implemented by coalescing broadcast operations -// One use-case is implementing a vector all_gather (all_gather_v) -// where unevenly sized inputs are gathered among participating ranks -// Since all_gather provides an out-of-place API, an all_gather_v -// semantic implemented inside pg_nccl.all_gather also needs to support -// out-of-place, for which an out-of-place broadcast is required to be added -c10::intrusive_ptr ProcessGroupNCCL::_broadcast_oop( - at::Tensor& outputTensor, - at::Tensor& inputTensor, - const BroadcastOptions& opts) { - if (outputTensor.numel() != inputTensor.numel()) { - C10_THROW_ERROR( - ValueError, - "Tensor input and output of _broadcast_oop must have the same number of elements "); - } - const auto root = opts.rootRank + opts.rootTensor; - bool nanCheck = (root == rank_); - return collective( - inputTensor, - outputTensor, - [&](at::Tensor& input, - at::Tensor& output, - ncclComm_t comm, - at::cuda::CUDAStream& stream) { - return ncclBroadcast( - input.data_ptr(), - output.data_ptr(), - input.numel(), - getNcclDataType(input.scalar_type()), - root, - comm, - stream.stream()); - }, - OpType::BROADCAST, - "nccl:_broadcast_oop", - /*avoidRecordStreams=*/false, - nanCheck); -} - -c10::intrusive_ptr ProcessGroupNCCL::reduce( - std::vector& tensors, - const ReduceOptions& opts) { - TORCH_CHECK(tensors.size() == 1, MULTI_DEVICE_ERROR_MSG); - // @lint-ignore CLANGTIDY - auto tensor = tensors.back(); - if (tensor.is_complex()) { - TORCH_CHECK( - complexViewAsRealAllowed(opts.reduceOp), - "reduce does not support", - opts.reduceOp, - "on complex tensors"); - tensor = at::view_as_real(tensor); - } - check_gpu_single_tensor(tensor); - RECORD_PARAM_COMMS_DATA( - static_cast( - this->getSequenceNumberForGroup() + 1), // seq + 1 to match collective - std::make_tuple(pg_uid_, pg_desc_), // PG name tuple - tensors, // inputTensors - tensors, // outputTensors - opts.rootRank, // root rank - "reduce", // collective name - tensor.numel(), // inNelems - tensor.numel(), // outNelems - tensor.scalar_type(), // dType - std::vector(), // inSplitSizes - std::vector(), // outSplitSizes - globalRankStart, // globalRankStart - globalRankStride, // globalRankStride - this->getSize()); // worldSize - - // avoidRecordStreams_ note: collective() will stash tensors. - return collective( - tensor, - tensor, - [&](at::Tensor& input, - at::Tensor& output, - ncclComm_t comm, - at::cuda::CUDAStream& stream) { - const auto root = opts.rootRank + opts.rootTensor; - auto ncclDataType = getNcclDataType(input.scalar_type()); - auto ncclReduceOp = - getNcclReduceOp(opts.reduceOp, input, ncclDataType, comm); - return ncclReduce( - input.data_ptr(), - output.data_ptr(), - input.numel(), - ncclDataType, - ncclReduceOp, - root, - comm, - stream.stream()); - }, - OpType::REDUCE, - "nccl:reduce"); -} - -// _reduce_oop exposes an out-of-place reduce from PGNCCL -// Custom collectives may be implemented by coalescing reduce operations -// One use-case is implementing a vector reduce_scatter (reduce_scatter_v) -// where inputs are reduced and scattered unevenly among participating ranks -// Since reduce_scatter provides an out-of-place API, a reduce_scatter_v -// semantic implemented inside pg_nccl.reduce_scatter also needs to support -// out-of-place, for which an out-of-place reduce is required to be added -c10::intrusive_ptr ProcessGroupNCCL::_reduce_oop( - at::Tensor& outputTensor, - at::Tensor& inputTensor, - const ReduceOptions& opts) { - if (outputTensor.numel() != inputTensor.numel()) { - C10_THROW_ERROR( - ValueError, - "Tensor input and output of _reduce_oop must have the same number of elements "); - } - return collective( - inputTensor, - outputTensor, - [&](at::Tensor& input, - at::Tensor& output, - ncclComm_t comm, - at::cuda::CUDAStream& stream) { - const auto root = opts.rootRank + opts.rootTensor; - const auto ncclDataType = getNcclDataType(input.scalar_type()); - const auto ncclReduceOp = - getNcclReduceOp(opts.reduceOp, input, ncclDataType, comm); - return ncclReduce( - input.data_ptr(), - output.data_ptr(), - input.numel(), - ncclDataType, - ncclReduceOp, - (int)root, - comm, - stream.stream()); - }, - OpType::REDUCE, - "nccl:_reduce_oop"); -} - -c10::intrusive_ptr ProcessGroupNCCL::allgather( - std::vector>& outputTensors, - std::vector& inputTensors, - const AllgatherOptions& opts) { - TORCH_CHECK(inputTensors.size() == 1, MULTI_DEVICE_ERROR_MSG); - // @lint-ignore CLANGTIDY - auto inputTensor = inputTensors.back(); - check_gpu_single_tensor(inputTensor); - // @lint-ignore CLANGTIDY - auto outputTensors_ = outputTensors.back(); - - RECORD_PARAM_COMMS_DATA( - static_cast( - this->getSequenceNumberForGroup() + 1), // seq + 1 to match collective - std::make_tuple(pg_uid_, pg_desc_), // PG name tuple - inputTensors, // inputTensors - outputTensors, // outputTensors - rank_, // rank - "all_gather", // collective name - inputTensor.numel(), // inNelems - inputTensor.numel() * // outNelems - this->getSize(), - inputTensor.scalar_type(), // dType - std::vector(), // inSplitSizes - std::vector(), // outSplitSize - globalRankStart, // globalRankStart - globalRankStride, // globalRankStride - this->getSize()); // worldSize - - bool same_size = check_same_size(outputTensors_); - if (same_size) { - // Flatten a vector of tensors into a single, stacked tensor. - at::Tensor outputFlattened = newLikeFlat(outputTensors_); - - return collective( - inputTensor, - outputFlattened, - [&](at::Tensor& input, - at::Tensor& output, - ncclComm_t comm, - at::cuda::CUDAStream& stream) { - if (!avoidRecordStreams_) { - c10::cuda::CUDACachingAllocator::recordStream( - output.storage().data_ptr(), stream); - } - return ncclAllGather( - input.data_ptr(), - output.data_ptr(), - input.numel(), - getNcclDataType(input.scalar_type()), - comm, - stream.stream()); - }, - [](at::cuda::CUDAStream& ncclStream, - c10::intrusive_ptr& work) { - // avoidRecordStreams_ note: We actually don't need to stash anything - // here. - // - inputTensors is stashed onto work->stashed_for_allocator_safety_ - // in collective(). - // - outputFlattened is stashed onto work->outputs_ in collective(). - // - User-facing outputTensors should be held by the user until after - // waiting on work_, or the call makes no sense. - // So all participating tensors are accounted for, and won't be - // released back to their allocation streams until after work_ is - // waited on. - }, - [&](at::cuda::CUDAStream& ncclStream, - c10::intrusive_ptr& work) { - // Copy the flattened output tensors to the outputs. - at::cuda::CUDAStreamGuard guard(ncclStream); - for (const auto j : c10::irange(outputTensors_.size())) { - // See [Sync Streams]. - if (!avoidRecordStreams_) { - c10::cuda::CUDACachingAllocator::recordStream( - outputTensors_[j].storage().data_ptr(), ncclStream); - } - outputTensors_[j].copy_(outputFlattened[j], true); - } - }, - OpType::ALLGATHER, - "nccl:all_gather"); - } else { - const auto num_reduces = outputTensors_.size(); - startCoalescing(); - for (const int i : c10::irange(num_reduces)) { - auto& output = outputTensors_[i]; - auto& input = (i == rank_) ? inputTensor : output; - auto broadcastOpts = BroadcastOptions{ - static_cast(i), static_cast(0), opts.timeout}; - _broadcast_oop(output, input, broadcastOpts); - } - auto work = endCoalescing(OpType::ALLGATHER); - return work; - } -} - -c10::intrusive_ptr ProcessGroupNCCL::allgather_coalesced( - std::vector>& /* unused */, - std::vector& /* unused */, - const AllgatherOptions& /* unused */) { - C10_THROW_ERROR( - NotImplementedError, - "ProcessGroupNCCL does not support allgather_coalesced"); -} - -c10::intrusive_ptr ProcessGroupNCCL::allgather_into_tensor_coalesced( - std::vector& outputs, - std::vector& inputs, - const AllgatherOptions& opts) { - return collectiveCoalesced( - inputs, - outputs, - [&](at::Tensor& input, - at::Tensor& output, - ncclComm_t comm, - at::cuda::CUDAStream& stream) { - return ncclAllGather( - input.data_ptr(), - output.data_ptr(), - input.numel(), - getNcclDataType(input.scalar_type()), - comm, - stream.stream()); - }, - OpType::COALESCED, - "nccl:all_gather_into_tensor_coalesced"); -} - -c10::intrusive_ptr ProcessGroupNCCL::reduce_scatter( - std::vector& outputTensors, - std::vector>& inputTensors, - const ReduceScatterOptions& opts) { - TORCH_CHECK(outputTensors.size() == 1, MULTI_DEVICE_ERROR_MSG); - // @lint-ignore CLANGTIDY - auto outputTensor = outputTensors.back(); - check_gpu_single_tensor(outputTensor); - // @lint-ignore CLANGTIDY - auto inputTensors_ = inputTensors.back(); - TORCH_CHECK( - !isFloat8Type(outputTensor.scalar_type()), - "Float8 dtypes are not currenlty supported for NCCL reductions"); - - RECORD_PARAM_COMMS_DATA( - static_cast( - this->getSequenceNumberForGroup() + 1), // seq + 1 to match collective - std::make_tuple(pg_uid_, pg_desc_), // PG name tuple - inputTensors, // inputTensors - outputTensors, // outputTensors - rank_, // rank - "reduce_scatter", // collective name - outputTensor.numel() * this->getSize(), // inNelems - outputTensor.numel(), // outNelems - outputTensor.scalar_type(), // dType - std::vector(), // inSplitSizes - std::vector(), // outSplitSizes - globalRankStart, // globalRankStart - globalRankStride, // globalRankStride - this->getSize()); // worldSize - - bool same_size = check_same_size(inputTensors_); - if (same_size) { - // Flatten a vector of tensors into a single, stacked tensor. - at::Tensor inputFlattened = newLikeFlat(inputTensors_); - - return collective( - inputFlattened, - outputTensor, - [&](at::Tensor& input, - at::Tensor& output, - ncclComm_t comm, - at::cuda::CUDAStream& stream) { - if (!avoidRecordStreams_) { - c10::cuda::CUDACachingAllocator::recordStream( - output.storage().data_ptr(), stream); - } - const auto ncclDataType = getNcclDataType(input.scalar_type()); - const auto ncclReduceOp = - getNcclReduceOp(opts.reduceOp, input, ncclDataType, comm); - return ncclReduceScatter( - input.data_ptr(), - output.data_ptr(), - output.numel(), - ncclDataType, - ncclReduceOp, - comm, - stream.stream()); - }, - [&](at::cuda::CUDAStream& ncclStream, - c10::intrusive_ptr& work) { - if (avoidRecordStreams_) { - // We only need to stash inputTensors. - // - inputFlattened is stashed onto - // work->stashed_for_allocator_safety_ - // in collective(). - // - User-facing outputTensors is stashed onto work->outputs_ in - // collective(), - // and should also be held by the user until after waiting on - // work_. - auto& v = work->stashed_for_allocator_safety_; - v->insert(v->end(), inputTensors_.begin(), inputTensors_.end()); - } - - // Copy the input tensors to the flattened inputs. - at::cuda::CUDAStreamGuard guard(ncclStream); - for (const auto j : c10::irange(inputTensors_.size())) { - // See [Sync Streams]. - if (!avoidRecordStreams_) { - c10::cuda::CUDACachingAllocator::recordStream( - inputTensors_[j].storage().data_ptr(), ncclStream); - } - inputFlattened[j].copy_(inputTensors_[j], true); - } - }, - [&](at::cuda::CUDAStream&, - c10::intrusive_ptr& work) {}, - OpType::REDUCE_SCATTER, - "nccl:reduce_scatter"); - } else { - const auto num_reduces = inputTensors_.size(); - startCoalescing(); - for (const int i : c10::irange(num_reduces)) { - auto& input = inputTensors_[i]; - auto& output = (i == rank_) ? outputTensor : input; - auto reduceOpts = ReduceOptions{ - opts.reduceOp, - static_cast(i), - static_cast(0), - opts.timeout}; - _reduce_oop(output, input, reduceOpts); - } - auto work = endCoalescing(OpType::REDUCE_SCATTER); - return work; - } -} - -c10::intrusive_ptr ProcessGroupNCCL::_reduce_scatter_base( - at::Tensor& outputTensor, - at::Tensor& inputTensor, - const ReduceScatterOptions& opts) { - if (inputTensor.dtype() != outputTensor.dtype()) { - C10_THROW_ERROR( - TypeError, "input tensor must be the same type as the output tensor."); - } + // Counting for the sequential number of NCCL collective call. + // (specifically, how many actual kernels we launched, which differs from + // op_id_ when coalescing is enabled) + uint64_t seqCollective_{0}; - if (inputTensor.numel() != outputTensor.numel() * size_) { - C10_THROW_ERROR( - ValueError, - "input tensor must be the same size as output size times world size"); - } + // Counting for the sequential number of NCCL P2P calls. + uint64_t seqP2P_{0}; - // @lint-ignore CLANGTIDY - const auto& tensor = outputTensor; - TORCH_CHECK( - !isFloat8Type(tensor.scalar_type()), - "Float8 dtypes are not currenlty supported for NCCL reductions"); - RECORD_PARAM_COMMS_DATA( - static_cast( - this->getSequenceNumberForGroup() + 1), // seq + 1 to match collective - std::make_tuple(pg_uid_, pg_desc_), // PG name tuple - inputTensor, // inputTensor - outputTensor, // outputTensor - rank_, // rank - "_reduce_scatter_base", // collective name - inputTensor.numel(), // inNelems - tensor.numel(), // outNelems - tensor.scalar_type(), // dtype - std::vector(), // inSplitSizes - std::vector(), // outSplitSizes - globalRankStart, // globalRankStart - globalRankStride, // globalRankStride - this->getSize()); // worldSize - - // avoidRecordStreams_ note: collective() will stash inputs and outputs. - // Note 2: for asyncOp = false, we don't want to record streams because we - // know that the NCCL stream will join back to the "current" stream right - // after this op. So we might just as well keep the stream ownership of the - // input/output tensors unchanged. The benefit would be that the - // allocation/free of the tensors would look deterministic to the "current" - // stream so that the caching allocator can reuse memory pool for this stream - // in a clever way. This setting is added for libraries like FSDP which uses - // `reduce_scatter_tensor`. - bool avoidRecordStreams = avoidRecordStreams_ || (!opts.asyncOp); - - return collective( - inputTensor, - outputTensor, - [&](at::Tensor& input, - at::Tensor& output, - ncclComm_t comm, - at::cuda::CUDAStream& stream) { - if (!avoidRecordStreams) { - c10::cuda::CUDACachingAllocator::recordStream( - output.storage().data_ptr(), stream); - } - auto ncclDataType = getNcclDataType(input.scalar_type()); - auto ncclReduceOp = - getNcclReduceOp(opts.reduceOp, input, ncclDataType, comm); - return ncclReduceScatter( - input.data_ptr(), - output.data_ptr(), - output.numel(), - ncclDataType, - ncclReduceOp, - comm, - stream.stream()); - }, - OpType::_REDUCE_SCATTER_BASE, - "nccl:_reduce_scatter_base", - avoidRecordStreams); -} - -c10::intrusive_ptr ProcessGroupNCCL::reduce_scatter_tensor_coalesced( - std::vector& outputs, - std::vector& inputs, - const ReduceScatterOptions& opts) { - TORCH_CHECK( - !isFloat8Type(inputs.back().scalar_type()), - "Float8 dtypes are not currenlty supported for NCCL reductions"); - return collectiveCoalesced( - inputs, - outputs, - [&](at::Tensor& input, - at::Tensor& output, - ncclComm_t comm, - at::cuda::CUDAStream& stream) { - if (!avoidRecordStreams_) { - c10::cuda::CUDACachingAllocator::recordStream( - output.storage().data_ptr(), stream); - } - auto ncclDataType = getNcclDataType(input.scalar_type()); - auto ncclReduceOp = - getNcclReduceOp(opts.reduceOp, input, ncclDataType, comm); - return ncclReduceScatter( - input.data_ptr(), - output.data_ptr(), - output.numel(), - ncclDataType, - ncclReduceOp, - comm, - stream.stream()); - }, - OpType::COALESCED, - "nccl:reduce_scatter_tensor_coalesced"); -} - -c10::intrusive_ptr ProcessGroupNCCL::barrier(const BarrierOptions& opts) { - RECORD_PARAM_COMMS( - static_cast( - this->getSequenceNumberForGroup() + 1), // seq + 1 to match collective - std::make_tuple(pg_uid_, pg_desc_), // PG name tuple - rank_, // rank - "barrier", // collective name - 0, // inNelems - 0, // outNelems - at::kByte, // dType - std::vector(), // inSplitSizes - std::vector(), // outSplitSizes - globalRankStart, // globalRankStart - globalRankStride, // globalRankStride - this->getSize()); // worldSize - - // Device to use for barrier - int barDevIdx = -1; - - // Select device to use for barrier - // 1st choice: Use user defined GPU device ids if provided - if (!opts.device_ids.empty()) { - // Use the first device id because PG NCCL is single-device now - barDevIdx = opts.device_ids[0]; - } else if (getBoundDeviceId()) { - // 2nd choice: Use the bound GPU device id if available. - // Bounded device id can be passed to `init_process_group`. - barDevIdx = (*getBoundDeviceId()).index(); - } else if (!usedDeviceIdxs_.empty()) { - // 3rd choice: infer the device id from the used device ids. - barDevIdx = *usedDeviceIdxs_.begin(); - } else { - // This means there is not yet a NCCL collective being called - // Here we have to use the best guesses and will use a single GPU to call - // allreduce to achieve barrier. - // In case the multiple processes fall into the same node, we use rank to - // ensure that each process is on a different GPU - // Note: it is better to use global rank because the group-local rank can be - // offset wrt the device id if intra-node GPUs are sharded into multiple - // dimensions. - barDevIdx = static_cast(globalRank() % localDeviceCount_); - LOG(WARNING) - << logPrefix() - << c10::str( - " using GPU ", - barDevIdx, - " to perform barrier as devices used by this process are currently unknown. ", - "This can potentially cause a hang if this rank to GPU mapping is incorrect.", - "Specify device_ids in barrier() to force use of a particular device,", - "or call init_process_group() with a device_id."); - } + // Incrementing counter for logical operations (collective or p2p) issued on + // the ProcessGroup + uint64_t op_id_{0}; - TORCH_CHECK_WITH( - ValueError, - barDevIdx >= 0, - "Failed to infer a GPU device id to perform barrier. "); - auto barDevice = at::Device(at::DeviceType::CUDA, barDevIdx); - - // Create a dummy tensor on the device - // Note: we use zeros() instead of empty() to prevent barrier from triggering - // alarm when NaN checker is enabled. - at::Tensor barrierTensor = - at::zeros({1}, at::TensorOptions().device(barDevice).dtype(at::kFloat)); - - // All reduce to achieve the barrier - auto work = allreduce_impl(barrierTensor); - - // Work will take over barrierTensors - auto ncclWork = dynamic_cast(work.get()); - TORCH_CHECK(ncclWork); - ncclWork->barrierTensor_ = std::move(barrierTensor); - return work; -} - -c10::intrusive_ptr ProcessGroupNCCL::alltoall_base( - at::Tensor& outputTensor, - at::Tensor& inputTensor, - std::vector& outputSplitSizes, - std::vector& inputSplitSizes, - const AllToAllOptions& /* unused */) { - check_gpu_single_tensor(outputTensor, true); - check_gpu_single_tensor(inputTensor, true); - if (outputSplitSizes.size() == 0 && inputSplitSizes.size() == 0) { - RECORD_PARAM_COMMS_DATA( - static_cast( - this->getSequenceNumberForGroup() + - 1), // seq + 1 to match collective - std::make_tuple(pg_uid_, pg_desc_), // PG name tuple - inputTensor, // inputTensor - outputTensor, // outputTensor - rank_, // rank - "all_to_all", // collective name - inputTensor.numel(), // inNelems - outputTensor.numel(), // outNelems - inputTensor.scalar_type(), // dType - std::vector(), // inSplitSizes - std::vector(), // outSplitSizes - globalRankStart, // globalRankStart - globalRankStride, // globalRankStride - this->getSize()); // worldSize - - // avoidRecordStreams_ note: collective() will stash inputTensors and - // outputTensors. - return collective( - inputTensor, - outputTensor, - [&](at::Tensor& input, - at::Tensor& output, - ncclComm_t comm, - at::cuda::CUDAStream& stream) { - // See [Sync Streams]. - if (!avoidRecordStreams_) { - c10::cuda::CUDACachingAllocator::recordStream( - output.storage().data_ptr(), stream); - } - torch::cuda::nccl::all2all_single_equal_split( - input, output, this->getSize(), comm, stream); - return ncclSuccess; - }, - OpType::ALLTOALL_BASE, - "nccl:all_to_all"); - } else { - c10d::checkSplitSizes(inputSplitSizes, inputTensor, size_); - c10d::checkSplitSizes(outputSplitSizes, outputTensor, size_); - - RECORD_PARAM_COMMS_DATA( - static_cast( - this->getSequenceNumberForGroup() + - 1), // seq + 1 to match collective - std::make_tuple(pg_uid_, pg_desc_), // PG name tuple - inputTensor, // inputTensor - outputTensor, // outputTensor - rank_, // rank - "all_to_allv", // collective name - inputTensor.numel(), // inNelems - outputTensor.numel(), // outNelems - inputTensor.scalar_type(), // dType - inputSplitSizes, // inSplitSizes - outputSplitSizes, // outSplitSizes - globalRankStart, // globalRankStart - globalRankStride, // globalRankStride - this->getSize()); // worldSize - - // avoidRecordStreams_ note: collective() will stash inputTensors and - // outputTensors. - return collective( - inputTensor, - outputTensor, - [&](at::Tensor& input, - at::Tensor& output, - ncclComm_t comm, - at::cuda::CUDAStream& stream) { - std::vector send_lengths(size_); - std::vector recv_lengths(size_); - std::vector send_offsets(size_); - std::vector recv_offsets(size_); - c10d::computeLengthsAndOffsets( - inputSplitSizes, input, &send_lengths, &send_offsets); - c10d::computeLengthsAndOffsets( - outputSplitSizes, output, &recv_lengths, &recv_offsets); - // See [Sync Streams]. - if (!avoidRecordStreams_) { - c10::cuda::CUDACachingAllocator::recordStream( - output.storage().data_ptr(), stream); - } - torch::cuda::nccl::all2all_single_unequal_split( - input.data_ptr(), - send_lengths.data(), - send_offsets.data(), - output.data_ptr(), - recv_lengths.data(), - recv_offsets.data(), - input.element_size(), - input.scalar_type(), - comm, - stream); - return ncclSuccess; - }, - OpType::ALLTOALL_BASE, - "nccl:all_to_all"); - } -} - -c10::intrusive_ptr ProcessGroupNCCL::alltoall( - std::vector& outputTensors, - std::vector& inputTensors, - const AllToAllOptions& /* unused */) { - std::vector inSplitSizes; - std::vector outSplitSizes; - int64_t total_numel = 0; - - auto device = outputTensors[0].device(); - for (const auto r : c10::irange(outputTensors.size())) { - check_gpu_single_tensor(outputTensors[r], true); - check_gpu_single_tensor(inputTensors[r], true); - TORCH_CHECK( - device == outputTensors[r].device() && - device == inputTensors[r].device(), - "Tensors must be on the same device") - inSplitSizes.push_back(inputTensors[r].numel()); - outSplitSizes.push_back(outputTensors[r].numel()); - total_numel += inputTensors[r].numel(); - } + std::exception_ptr watchDogException_ = nullptr; - RECORD_PARAM_COMMS_DATA( - static_cast( - this->getSequenceNumberForGroup() + 1), // seq + 1 to match collective - std::make_tuple(pg_uid_, pg_desc_), // PG name tuple - inputTensors, // inputTensors - outputTensors, // outputTensors - rank_, // rank - "all_to_all", // collective name - total_numel, // inNelems - total_numel, // outNelems - inputTensors.front().scalar_type(), // dType - inSplitSizes, // inSplitSizes - outSplitSizes, // outSplitSizes - globalRankStart, // globalRankStart - globalRankStride, // globalRankStride - this->getSize()); // worldSize - - return collective( - inputTensors, - outputTensors, - [&](at::Tensor& /* unused */, - at::Tensor& /* unused */, - ncclComm_t comm, - at::cuda::CUDAStream& stream) { - torch::cuda::nccl::all2all(outputTensors, inputTensors, comm, stream); - return ncclSuccess; - }, - [&](at::cuda::CUDAStream&, - c10::intrusive_ptr& work) { - if (avoidRecordStreams_) { - // inputTensor0 and outputTensor0 are stashed redundantly by - // collective(), but that's ok. - auto& v = work->stashed_for_allocator_safety_; - v->insert(v->end(), inputTensors.begin(), inputTensors.end()); - v->insert(v->end(), outputTensors.begin(), outputTensors.end()); - } - }, - [](at::cuda::CUDAStream&, - c10::intrusive_ptr& work) {}, - OpType::ALLTOALL, - "nccl:all_to_all"); -} - -c10::intrusive_ptr ProcessGroupNCCL::send( - std::vector& tensors, - int dstRank, - int /* unused */) { - TORCH_CHECK(tensors.size() == 1, MULTI_DEVICE_ERROR_MSG); - // @lint-ignore CLANGTIDY - auto tensor = tensors.back(); - check_gpu_single_tensor(tensor, true); - - RECORD_PARAM_COMMS_DATA( - static_cast( - this->getSequenceNumberForGroup() + 1), // seq + 1 to match collective - std::make_tuple(pg_uid_, pg_desc_), // PG name tuple - tensors, // inputTensors - tensors, // outputTensors - dstRank, // dst rank - "send", // collective name - tensor.numel(), // inNelems - tensor.numel(), // outNelems - tensor.scalar_type(), // dType - std::vector(), // inSplitSizes - std::vector(), // outSplitSizes - globalRankStart, // globalRankStart - globalRankStride, // globalRankStride - this->getSize()); // worldSize - - auto ret = pointToPoint( - tensor, - [&](at::Tensor& input, - ncclComm_t comm, - at::cuda::CUDAStream& stream, - int dst) { - torch::cuda::nccl::send(input, comm, stream, dst); - return ncclSuccess; - }, - dstRank, - OpType::SEND, - c10::str("nccl:send ", rank_, "->", dstRank).c_str()); - return ret; -} - -c10::intrusive_ptr ProcessGroupNCCL::recv( - std::vector& tensors, - int srcRank, - int /* unused */) { - TORCH_CHECK(tensors.size() == 1, MULTI_DEVICE_ERROR_MSG); - // @lint-ignore CLANGTIDY - auto tensor = tensors.back(); - check_gpu_single_tensor(tensor, true); - - RECORD_PARAM_COMMS_DATA( - static_cast( - this->getSequenceNumberForGroup() + 1), // seq + 1 to match collective - std::make_tuple(pg_uid_, pg_desc_), // PG name tuple - tensors, // inputTensors - tensors, // outputTensors - srcRank, // src rank - "recv", // collective name - tensor.numel(), // inNelems - tensor.numel(), // outNelems - tensor.scalar_type(), // dType - std::vector(), // inSplitSizes - std::vector(), // outSplitSizes - globalRankStart, // globalRankStart - globalRankStride, // globalRankStride - this->getSize()); // worldSize - - auto ret = pointToPoint( - tensor, - [&](at::Tensor& output, - ncclComm_t comm, - at::cuda::CUDAStream& stream, - int src) { - torch::cuda::nccl::recv(output, comm, stream, src); - return ncclSuccess; - }, - srcRank, - OpType::RECV, - c10::str("nccl:recv ", rank_, "<-", srcRank).c_str()); - return ret; -} - -void ProcessGroupNCCL::groupStart() { - C10D_NCCL_CHECK(ncclGroupStart(), std::nullopt); - ++ncclActiveGroupCounter_; -} - -void ProcessGroupNCCL::groupEnd() { - C10D_NCCL_CHECK(ncclGroupEnd(), std::nullopt); - --ncclActiveGroupCounter_; -} - -void ProcessGroupNCCL::groupEndNonblocking(std::shared_ptr comm) { -#ifndef NCCL_HAS_COMM_NONBLOCKING - C10D_NCCL_CHECK(ncclGroupEnd(), std::nullopt); -#else - if (!nccl_use_nonblocking()) { - C10D_NCCL_CHECK(ncclGroupEnd(), std::nullopt); - } else { - C10D_NCCL_CHECK_TIMEOUT_GROUPEND(ncclGroupEnd(), comm, std::nullopt); - } -#endif - --ncclActiveGroupCounter_; -} - -c10::intrusive_ptr ProcessGroupNCCL::gather( - std::vector>& outputTensors, - std::vector& inputTensors, - const GatherOptions& opts) { - static auto invalidArgument = [](const std::string& msg) { - C10_THROW_ERROR(ValueError, "ProcessGroupNCCL::gather: " + msg); - }; + // The number of ProcessGroupNCCL created on the current rank. + size_t local_id_; - assertRootRank(invalidArgument, opts.rootRank, size_); - - TORCH_CHECK(inputTensors.size() == 1, MULTI_DEVICE_ERROR_MSG); - // @lint-ignore CLANGTIDY - auto inputTensor = inputTensors.back(); - - std::vector outputs; - - if (getRank() == opts.rootRank) { - if (outputTensors.size() != 1) { - std::stringstream ss; - ss << "requires a single-element output list containing a list with " - << getSize() << " tensors."; - invalidArgument(ss.str()); - } else if (outputTensors[0].size() != static_cast(getSize())) { - std::stringstream ss; - ss << "Incorrect output list size " << outputTensors[0].size() - << ". Output list size should be " << getSize() - << ", same as size of the process group."; - invalidArgument(ss.str()); - } + std::string logPrefix_; - const auto& options = inputTensor.options(); - const auto& sizes = inputTensor.sizes(); - assertTypeAndSizesMatch(invalidArgument, outputTensors[0], options, sizes); - outputs = outputTensors[0]; - } else { - // if not in the root rank, initialize outputs as empty list - if (outputTensors.size() != 0) { - invalidArgument("requires empty output on non-root"); - } - outputs = {}; - // append a empty tensor to the list, we don't use it but the - // `collective` template function requires it to invoke its function - outputs.emplace_back(); - } + c10::intrusive_ptr intraNodeComm_; - RECORD_PARAM_COMMS_DATA( - static_cast( - this->getSequenceNumberForGroup() + 1), // seq + 1 to match collective - std::make_tuple(pg_uid_, pg_desc_), // PG name tuple - inputTensors, // inputTensors - outputTensors, // outputTensors - opts.rootRank, // root rank - "gather", // collective name - inputTensor.numel(), // inNelems - inputTensor.numel() * this->getSize(), // outNelems - inputTensor.scalar_type(), // dType - std::vector(), // inSplitSizes - std::vector(), // outSplitSize - globalRankStart, // globalRankStart - globalRankStride, // globalRankStride - this->getSize()); // worldSize - - // avoidRecordStreams_ note: collective() will stash inputTensors and - // outputs, which == outputTensors[0] on the root rank where it matters. - - auto inputs = std::vector{inputTensor}; - return collective( - inputs, - outputs, // just to fit the collective interface - [&](at::Tensor& /* unused */, - at::Tensor& /* unused */, - ncclComm_t comm, - at::cuda::CUDAStream& stream) { - const auto root = opts.rootRank; - if (getRank() == root) { - if (!avoidRecordStreams_) { - for (auto output : outputs) { - c10::cuda::CUDACachingAllocator::recordStream( - output.storage().data_ptr(), stream); - } - } - } - torch::cuda::nccl::gather(inputTensor, outputs, comm, stream, root); - return ncclSuccess; - }, - [](at::cuda::CUDAStream&, - c10::intrusive_ptr& work) {}, - [](at::cuda::CUDAStream&, - c10::intrusive_ptr& work) {}, - OpType::GATHER, - "nccl:gather"); -} - -c10::intrusive_ptr ProcessGroupNCCL::scatter( - std::vector& outputTensors, - std::vector>& inputTensors, - const ScatterOptions& opts) { - static auto invalidArgument = [](const std::string& msg) { - C10_THROW_ERROR(ValueError, "ProcessGroupNCCL::scatter: " + msg); - }; + // Number of devices on this node. + int localDeviceCount_{0}; - assertRootRank(invalidArgument, opts.rootRank, size_); - - TORCH_CHECK(outputTensors.size() == 1, MULTI_DEVICE_ERROR_MSG); - auto outputTensor = outputTensors.back(); - - std::vector inputs; - - if (getRank() == opts.rootRank) { - if (inputTensors.size() != 1) { - std::stringstream ss; - ss << "requires a single-element input list containing a list with " - << getSize() << " tensors."; - invalidArgument(ss.str()); - } else if (inputTensors[0].size() != static_cast(getSize())) { - std::stringstream ss; - ss << "Incorrect input list size " << inputTensors[0].size() - << ". Input list size should be " << getSize() - << ", same as size of the process group."; - invalidArgument(ss.str()); - } + std::shared_ptr pgStatus_ = + std::make_shared(); +}; - const auto& options = outputTensor.options(); - const auto& sizes = outputTensor.sizes(); - assertTypeAndSizesMatch(invalidArgument, inputTensors[0], options, sizes); - inputs = inputTensors[0]; - } else { - // if not in the root rank, initialize inputTensors as empty place holder - // with an empty list - if (inputTensors.size() != 0) { - invalidArgument("requires empty input on non-root"); - } - inputs = {}; - // append a empty tensor to the list, we don't use it but the - // `collective` template function requires it to invoke its function - inputs.emplace_back(); - } +// Dumps the NCCL comm traces and additional information about the Process +// Group. +TORCH_API std::string dump_nccl_trace( + bool includeCollectives, + bool includeStackTraces, + bool onlyActive); - RECORD_PARAM_COMMS_DATA( - static_cast( - this->getSequenceNumberForGroup() + 1), // seq + 1 to match collective - std::make_tuple(pg_uid_, pg_desc_), // PG name tuple - inputTensors, // inputTensors - outputTensors, // outputTensors - opts.rootRank, // root rank - "scatter", // collective name - outputTensor.numel() * this->getSize(), // inNelems - outputTensor.numel(), // outNelems - outputTensor.scalar_type(), // dType - std::vector(), // inSplitSizes - std::vector(), // outSplitSize - globalRankStart, // globalRankStart - globalRankStride, // globalRankStride - this->getSize()); // worldSize - - // avoidRecordStreams_ note: collective() will stash outputTensors and - // inputs, which == inputTensors[0] on the root rank where it matters. - bool avoidRecordStreams = avoidRecordStreams_ || (!opts.asyncOp); - - const auto root = opts.rootRank; - bool nanCheck = (rank_ == root); - - auto outputs = std::vector{outputTensor}; - return collective( - outputs, - inputs, // just to fit the collective interface - [&](at::Tensor& /* unused */, - at::Tensor& /* unused */, - ncclComm_t comm, - at::cuda::CUDAStream& stream) { - if (getRank() == root) { - if (!avoidRecordStreams) { - for (auto input : inputs) { - c10::cuda::CUDACachingAllocator::recordStream( - input.storage().data_ptr(), stream); - } - } - } - torch::cuda::nccl::scatter(inputs, outputTensor, comm, stream, root); - return ncclSuccess; - }, - [](at::cuda::CUDAStream&, - c10::intrusive_ptr& work) {}, - [](at::cuda::CUDAStream&, - c10::intrusive_ptr& work) {}, - OpType::SCATTER, - "nccl:scatter", - avoidRecordStreams, - nanCheck); -} - -c10::intrusive_ptr ProcessGroupNCCL::recvAnysource( - std::vector& /* unused */, - int /* unused */) { - C10_THROW_ERROR( - NotImplementedError, "ProcessGroupNCCL does not support recvAnysource"); -} - -c10::intrusive_ptr ProcessGroupNCCL::_allgather_base( - at::Tensor& output_tensor, - at::Tensor& input_tensor, - const AllgatherOptions& opts) { - check_gpu_single_tensor(input_tensor); - check_gpu_single_tensor(output_tensor); - - if (input_tensor.dtype() != output_tensor.dtype()) { - C10_THROW_ERROR( - TypeError, "output tensor must have the same type as input tensor"); - } +// Dumps the NCCL comm traces and additional information about the Process +// Group in JSON formatted string. +// We don't include stack traces in JSON format as it is far too much data. +TORCH_API std::string dump_nccl_trace_json( + bool includeCollectives, + bool onlyActive); - if (input_tensor.numel() * size_ != output_tensor.numel()) { - C10_THROW_ERROR( - ValueError, - "output tensor size must be equal to world_size times input tensor size"); - } +// Gets a mutable reference to a global optional function.Heartbeat Monitor +// will use this function to dump traces, if available. Inside fbcode, we +// store a function here that uses an internal tool for process tracing +TORCH_API std::optional< + std::function)>>& +get_cpp_trace_dumper(); - RECORD_PARAM_COMMS_DATA( - static_cast( - this->getSequenceNumberForGroup() + 1), // seq + 1 to match collective - std::make_tuple(pg_uid_, pg_desc_), // PG name tuple - input_tensor, // inputTensors - output_tensor, // outputTensors - rank_, // rank - "_allgather_base", // collective name - input_tensor.numel(), // inNelems - output_tensor.numel(), // outNelems - output_tensor.scalar_type(), // dType - std::vector(), // inSplitSizes - std::vector(), // outSplitSize - globalRankStart, // globalRankStart - globalRankStride, // globalRankStride - this->getSize()); // worldSize - - // avoidRecordStreams_ note: collective() will stash inputs and outputs. - // Note 2: for asyncOp = false, we don't want to record streams because we - // know that the NCCL stream will join back to the "current" stream right - // after this op. So we might just as well keep the stream ownership of the - // input/output tensors unchanged. The benefit would be that the - // allocation/free of the tensors would look deterministic to the "current" - // stream so that the caching allocator can reuse memory pool for this stream - // in a clever way. This setting is added for libraries like FSDP which uses - // `all_gather_into_tensor`. - bool avoidRecordStreams = avoidRecordStreams_ || (!opts.asyncOp); - - return collective( - input_tensor, - output_tensor, - [&](at::Tensor& input, - at::Tensor& output, - ncclComm_t comm, - at::cuda::CUDAStream& stream) { - if (!avoidRecordStreams) { - c10::cuda::CUDACachingAllocator::recordStream( - output.storage().data_ptr(), stream); - } - return ncclAllGather( - input.data_ptr(), - output.data_ptr(), - input.numel(), - getNcclDataType(input.scalar_type()), - comm, - stream.stream()); - }, - OpType::_ALLGATHER_BASE, - "nccl:_all_gather_base", - avoidRecordStreams); -} +// Similar to get_cpp_trace_dumper, this stores a function defined in +// torch-python layer that lets us check whether the GIL can be acquired, +// helpful for instrumenting in cases where a hang was observed. +typedef bool (*gil_checker_t)(); +TORCH_API gil_checker_t& get_gil_checker(); } // namespace c10d #endif // USE_C10D_NCCL \ No newline at end of file diff --git a/torch/csrc/distributed/c10d/init.cpp b/torch/csrc/distributed/c10d/init.cpp index 60d2a9680e814e..0f7792e64e5faa 100644 --- a/torch/csrc/distributed/c10d/init.cpp +++ b/torch/csrc/distributed/c10d/init.cpp @@ -1813,8 +1813,7 @@ communication mechanism. py::init< const c10::intrusive_ptr<::c10d::Store>&, int, - int, - c10::intrusive_ptr<::c10d::ProcessGroup::Options>>(), + int>(), py::call_guard()) .def("rank", &::c10d::ProcessGroup::getRank) .def("size", &::c10d::ProcessGroup::getSize) @@ -1824,7 +1823,6 @@ communication mechanism. "_backend_id", &::c10d::ProcessGroup::getBackendID, py::arg("backend_type")) - .def_property_readonly("options", &::c10d::ProcessGroup::getOptions) .def( "broadcast", &::c10d::ProcessGroup::broadcast, @@ -2134,6 +2132,14 @@ communication mechanism. }, py::arg("device"), py::call_guard()) + .def( + "_set_default_backend", + [](const c10::intrusive_ptr<::c10d::ProcessGroup>& self, + const ::c10d::ProcessGroup::BackendType& backendType) { + return self->setDefaultBackend(backendType); + }, + py::arg("backend_type"), + py::call_guard()) .def( "_register_on_completion_hook", [](const c10::intrusive_ptr<::c10d::ProcessGroup>& self, @@ -2237,27 +2243,6 @@ The hook must have the following signature: .value("CUSTOM", ::c10d::ProcessGroup::BackendType::CUSTOM) .export_values(); - // base ProcessGroup::Options binding - auto processGroupOptions = - intrusive_ptr_class_<::c10d::ProcessGroup::Options>( - processGroup, - "Options", - R"( -Base class for all processes group options implementations, such as the nccl -options :class:`~torch.distributed.ProcessGroupNCCL.Options`). -)") - .def( - py::init([](const std::string& backend, - const std::chrono::milliseconds& timeout) { - return c10::make_intrusive<::c10d::ProcessGroup::Options>( - backend, timeout); - }), - py::arg("backend"), - py::arg("timeout") = kProcessGroupDefaultTimeout, - py::call_guard()) - .def_readonly("backend", &::c10d::ProcessGroup::Options::backend) - .def_readwrite("_timeout", &::c10d::ProcessGroup::Options::timeout); - // TODO: The collection definitions handles direct instantiation of // ProcessGroup subclasses (e.g. dist.ProcessGroupGloo). This is not supported // and should be removed once all tests are transitioned @@ -2557,6 +2542,29 @@ options :class:`~torch.distributed.ProcessGroupNCCL.Options`). &::c10d::Backend::endCoalescing, py::call_guard()); + // base Backend::Options binding + // TODO: Maybe we can consider how to merge this with + // `DistributedBackendOptions`. + auto backendOptions = + intrusive_ptr_class_<::c10d::Backend::Options>( + backend, + "Options", + R"( +Base class for all backend options implementations, such as the nccl +options :class:`~torch.distributed.ProcessGroupNCCL.Options`). +)") + .def( + py::init([](const std::string& backend, + const std::chrono::milliseconds& timeout) { + return c10::make_intrusive<::c10d::Backend::Options>( + backend, timeout); + }), + py::arg("backend"), + py::arg("timeout") = kProcessGroupDefaultTimeout, + py::call_guard()) + .def_readonly("backend", &::c10d::Backend::Options::backend) + .def_readwrite("_timeout", &::c10d::Backend::Options::timeout); + #ifdef USE_C10D_GLOO static const std::string GLOO_SOCKET_IFNAME_ENV = "GLOO_SOCKET_IFNAME"; @@ -2568,7 +2576,7 @@ options :class:`~torch.distributed.ProcessGroupNCCL.Options`). shared_ptr_class_<::gloo::transport::Device>(processGroupGloo, "Device"); intrusive_ptr_class_<::c10d::ProcessGroupGloo::Options>( - processGroupGloo, "_Options", processGroupOptions) + processGroupGloo, "_Options", backendOptions) .def(py::init<>()) .def_readwrite("_devices", &::c10d::ProcessGroupGloo::Options::devices) .def_readwrite("_threads", &::c10d::ProcessGroupGloo::Options::threads); @@ -2795,7 +2803,7 @@ for details. intrusive_ptr_class_<::c10d::ProcessGroupNCCL::Options>( processGroupNCCL, "Options", - processGroupOptions, + backendOptions, R"( ProcessGroup options for the NCCL backend diff --git a/torch/distributed/device_mesh.py b/torch/distributed/device_mesh.py index d8d33ab7ebfbea..d08bcfefc50e5f 100644 --- a/torch/distributed/device_mesh.py +++ b/torch/distributed/device_mesh.py @@ -36,6 +36,7 @@ def _init_device_mesh_stub(): else: + from torch._C._distributed_c10d import Backend as C10dBackend from torch.distributed.distributed_c10d import ( _find_pg_by_ranks_and_tag, _get_default_group, @@ -66,7 +67,7 @@ def __init__(self) -> None: self.mesh_stack: List[DeviceMesh] = [] self.child_to_root_mapping: Dict[DeviceMesh, DeviceMesh] = {} self.mesh_dim_group_options: Dict[ - int, Tuple[str, Optional[ProcessGroup.Options]] + int, Tuple[str, Optional[C10dBackend.Options]] ] = {} self.root_to_flatten_mapping: Dict[DeviceMesh, Dict[str, DeviceMesh]] = {} # Record flatten mesh name to its mesh dim index in root mesh. @@ -279,7 +280,7 @@ def _set_mesh_dim_group_options( self, dim: int, backend: str, - pg_options: Optional[ProcessGroup.Options] = None, + pg_options: Optional[C10dBackend.Options] = None, ) -> None: self.mesh_dim_group_options[dim] = (backend, pg_options) diff --git a/torch/distributed/distributed_c10d.py b/torch/distributed/distributed_c10d.py index bf3566338fd3ea..2d9357bbd15a44 100644 --- a/torch/distributed/distributed_c10d.py +++ b/torch/distributed/distributed_c10d.py @@ -280,6 +280,7 @@ class Backend(str): NCCL: ProcessGroup.BackendType.NCCL, XCCL: ProcessGroup.BackendType.XCCL, UCC: ProcessGroup.BackendType.UCC, + MPI: ProcessGroup.BackendType.MPI, } def __new__(cls, name: str): @@ -1554,7 +1555,7 @@ def init_process_group( backend, store, group_name, - pg_options=pg_options, + backend_options=pg_options, timeout=timeout, device_id=device_id, group_desc="default_pg", @@ -1651,7 +1652,7 @@ def _new_process_group_helper( backend, store, group_name, - pg_options=None, + backend_options=None, timeout=None, pg_tag=None, device_id=None, @@ -1730,11 +1731,17 @@ def _new_process_group_helper( return GroupMember.NON_GROUP_MEMBER, None prefix_store = PrefixStore(f"{group_name}/", store) - base_pg_options = ProcessGroup.Options(backend=str(backend)) - base_pg_options._timeout = timeout + # The backend for PG will be set later based on what's inside BackendConfig + # and timeout are set in each backend's option. pg: ProcessGroup = ProcessGroup( - prefix_store, group_rank, group_size, base_pg_options + prefix_store, + group_rank, + group_size, ) + # Set the default backend when only single backend is passed in. + if "," not in str(backend) and ":" not in str(backend): + assert backend in Backend.backend_type_map, f"Unknown backend type {backend}" + pg._set_default_backend(Backend.backend_type_map[backend]) if device_id: pg.bound_device_id = device_id backend_config = BackendConfig(backend) @@ -1761,8 +1768,8 @@ def _new_process_group_helper( backend_prefix_store, backend_class.rank(), backend_class.size(), - base_pg_options, ) + pg._set_default_backend(backend_type) elif backend_str == Backend.GLOO: # TODO: remove this check after lazy initialization is supported # if pg_options is not None: @@ -1774,28 +1781,30 @@ def _new_process_group_helper( elif backend_str == Backend.NCCL: if not is_nccl_available(): raise RuntimeError("Distributed package doesn't have NCCL built in") - if pg_options is not None: + if backend_options is not None: assert isinstance( - pg_options, ProcessGroupNCCL.Options - ), "Expected pg_options argument to be of type ProcessGroupNCCL.Options" - if pg_options._timeout != timeout: + backend_options, ProcessGroupNCCL.Options + ), "Expected backend_options argument to be of type ProcessGroupNCCL.Options" + if backend_options._timeout != timeout: warnings.warn( - "pg_options._timeout was specified, " + "backend_options._timeout was specified, " "but timeout kwarg has a default value that will always override it. " ) else: - # default pg_options for NCCL - pg_options = ProcessGroupNCCL.Options() - pg_options.is_high_priority_stream = False - pg_options._timeout = timeout + # default backend_options for NCCL + backend_options = ProcessGroupNCCL.Options() + backend_options.is_high_priority_stream = False + backend_options._timeout = timeout if split_from: - pg_options.split_from = split_from - pg_options.split_color = _process_group_color(global_ranks_in_group) - pg_options.global_ranks_in_group = global_ranks_in_group - pg_options.group_name = group_name + backend_options.split_from = split_from + backend_options.split_color = _process_group_color( + global_ranks_in_group + ) + backend_options.global_ranks_in_group = global_ranks_in_group + backend_options.group_name = group_name backend_class = ProcessGroupNCCL( - backend_prefix_store, group_rank, group_size, pg_options + backend_prefix_store, group_rank, group_size, backend_options ) backend_type = ProcessGroup.BackendType.NCCL elif backend_str == Backend.UCC and is_ucc_available(): @@ -1841,7 +1850,7 @@ def _new_process_group_helper( dist_backend_opts.group_id = group_name dist_backend_opts.global_ranks_in_group = global_ranks_in_group - backend_class = creator_fn(dist_backend_opts, pg_options) + backend_class = creator_fn(dist_backend_opts, backend_options) # Set sequence numbers for gloo and nccl backends. if backend_str == Backend.GLOO: @@ -4476,12 +4485,15 @@ def split_group( global_ranks_in_my_group = [parent_group_to_global_ranks[rank] for rank in my_group] prefix_store = PrefixStore(f"{group_name}/", default_store) - base_pg_options = ProcessGroup.Options(backend=str(backend)) - base_pg_options._timeout = timeout + # We register the backend after initializing and timeout is set in pg_options. pg: ProcessGroup = ProcessGroup( - prefix_store, group_rank, len(my_group), base_pg_options + prefix_store, + group_rank, + len(my_group), ) + backend_type = ProcessGroup.BackendType.NCCL pg.bound_device_id = device_id + pg._set_default_backend(backend_type) pg_options._timeout = timeout pg_options.split_from = parent_backend @@ -4491,7 +4503,6 @@ def split_group( backend_class = ProcessGroupNCCL( prefix_store, group_rank, len(my_group), pg_options ) - backend_type = ProcessGroup.BackendType.NCCL backend_class._set_sequence_number_for_group() pg._register_backend(torch.device("cuda"), backend_type, backend_class) @@ -4614,7 +4625,7 @@ def _new_group_with_tag( ranks=None, timeout=None, backend=None, - pg_options=None, + backend_options=None, pg_tag=None, use_local_synchronization=False, group_desc=None, @@ -4689,7 +4700,7 @@ def _new_group_with_tag( backend, default_store, group_name, - pg_options=pg_options, + backend_options=backend_options, timeout=timeout, pg_tag=pg_tag, device_id=device_id, From 0d8d3be7d759c3642b77a4243e7298a522adf9ab Mon Sep 17 00:00:00 2001 From: Kiuk Chung Date: Mon, 16 Sep 2024 20:07:29 +0000 Subject: [PATCH 0949/1018] [torch/multiprocessing] Use multiprocessing.reduction.register ForkingPickler.register to register custom tensor and storage reductions (#135030) Right now `multiprocessing.reduction.register()` is simply an alias to `multiprocessing.reduction.ForkingPickler.register()` https://github.com/python/cpython/blame/main/Lib/multiprocessing/reduction.py#L56, but the top-level `register()` function exposes less of the internal details of `multiprocessing.reduction` module. Pull Request resolved: https://github.com/pytorch/pytorch/pull/135030 Approved by: https://github.com/albanD --- torch/multiprocessing/reductions.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/torch/multiprocessing/reductions.py b/torch/multiprocessing/reductions.py index c02ad3dc2362ba..4e4539396f8323 100644 --- a/torch/multiprocessing/reductions.py +++ b/torch/multiprocessing/reductions.py @@ -2,7 +2,7 @@ import multiprocessing import os import threading -from multiprocessing.reduction import ForkingPickler +from multiprocessing import reduction from multiprocessing.util import register_after_fork from typing import Union @@ -626,22 +626,22 @@ def reduce_storage(storage): def init_reductions(): - ForkingPickler.register(torch.cuda.Event, reduce_event) + reduction.register(torch.cuda.Event, reduce_event) for t in torch._storage_classes: if t.__name__ == "UntypedStorage": - ForkingPickler.register(t, reduce_storage) + reduction.register(t, reduce_storage) else: - ForkingPickler.register(t, reduce_typed_storage_child) + reduction.register(t, reduce_typed_storage_child) - ForkingPickler.register(torch.storage.TypedStorage, reduce_typed_storage) + reduction.register(torch.storage.TypedStorage, reduce_typed_storage) for t in torch._tensor_classes: - ForkingPickler.register(t, reduce_tensor) + reduction.register(t, reduce_tensor) # TODO: Maybe this should be in tensor_classes? :) - ForkingPickler.register(torch.Tensor, reduce_tensor) + reduction.register(torch.Tensor, reduce_tensor) from torch.nn.parameter import Parameter - ForkingPickler.register(Parameter, reduce_tensor) + reduction.register(Parameter, reduce_tensor) From 52bad417664534af27a3bc11d5c8a7d5bd57fb69 Mon Sep 17 00:00:00 2001 From: Dan Johnson Date: Mon, 16 Sep 2024 09:52:47 -0700 Subject: [PATCH 0950/1018] Use ncclAlltoAllv and ncclAlltoAll API when supported (#134499) NCCL does not have an api for ncclAllToAll and ncclAllToAllv, so PyTorch does point to point send/recv. Expose this API if it is supported. Differential Revision: [D61683836](https://our.internmc.facebook.com/intern/diff/D61683836/) Pull Request resolved: https://github.com/pytorch/pytorch/pull/134499 Approved by: https://github.com/shuqiangzhang, https://github.com/eqy --- torch/csrc/cuda/nccl.cpp | 64 +++++++++++++++++++++++++++++++++++++++- 1 file changed, 63 insertions(+), 1 deletion(-) diff --git a/torch/csrc/cuda/nccl.cpp b/torch/csrc/cuda/nccl.cpp index afc904f2562d1b..22bcb44109a49d 100644 --- a/torch/csrc/cuda/nccl.cpp +++ b/torch/csrc/cuda/nccl.cpp @@ -832,7 +832,10 @@ void all2all_single_equal_split( const auto* sendbuff = reinterpret_cast(input.const_data_ptr()); auto* recvbuff = reinterpret_cast(output.data_ptr()); auto comm = to_nccl_comm(_comm); -#if defined(USE_ROCM) +#if defined(USE_ROCM) || defined(NCCL_ALLTOALL_SUPPORTED) + // NCCL_ALLTOALL_SUPPORTED is used so NCCL can differentiate send/recv + // operations issued as a part of the collective (e.g. alltoall) vs those + // inside traditional p2p operations. NCCL_CHECK(ncclAllToAll(sendbuff, recvbuff, count, type, comm, stream)); #else NCCL_CHECK(ncclCommCount(comm, &numranks)); @@ -877,6 +880,21 @@ void all2all_single_unequal_split( auto type = to_nccl_data_type(_type); auto comm = to_nccl_comm(_comm); +#ifdef NCCL_ALLTOALLV_SUPPORTED + // NCCL_ALLTOALLV_SUPPORTED is used so NCCL can differentiate send/recv + // operations issued as a part of the collective (e.g. alltoallv) vs those + // inside traditional p2p operations. + NCCL_CHECK(ncclAllToAllv( + sendbuff, + sendcounts, + senddispls, + recvbuff, + recvcounts, + recvdispls, + type, + comm, + stream.stream())); +#else int numranks; NCCL_CHECK(ncclCommCount(comm, &numranks)); NCCL_CHECK(ncclGroupStart()); @@ -905,6 +923,7 @@ void all2all_single_unequal_split( #else NCCL_CHECK_TIMEOUT(ncclGroupEnd(), _comm); #endif +#endif #else AT_ERROR("all2all is only supported for NCCL lib version >= 2.7.0"); #endif @@ -924,6 +943,48 @@ void all2all( using namespace torch::cuda::nccl::detail; auto comm = to_nccl_comm(_comm); +#ifdef NCCL_ALLTOALLV_SUPPORTED + // NCCL_ALLTOALLV_SUPPORTED is used so NCCL can differentiate send/recv + // operations issued as a part of the collective (e.g. alltoallv) vs those + // inside traditional p2p operations. + TORCH_INTERNAL_ASSERT( + outputTensors.size() == inputTensors.size(), + "number of input tensors is not equal to number of output tensors"); + std::vector sendCounts(inputTensors.size()); + std::vector sendDisps(inputTensors.size()); + std::vector recvCounts(outputTensors.size()); + std::vector recvDisps(outputTensors.size()); + uintptr_t sendBase = reinterpret_cast(inputTensors[0].data_ptr()); + uintptr_t recvBase = reinterpret_cast(outputTensors[0].data_ptr()); + size_t dtypeSize = inputTensors.front().element_size(); + + for (const auto r : c10::irange(outputTensors.size())) { + sendCounts[r] = inputTensors[r].numel(); + auto sendOffset = + reinterpret_cast(inputTensors[r].data_ptr()) - sendBase; + TORCH_INTERNAL_ASSERT( + sendOffset % dtypeSize == 0, + "sendOffset is not divisible by dtypeSize"); + sendDisps[r] = sendOffset / dtypeSize; + recvCounts[r] = outputTensors[r].numel(); + auto recvOffset = + reinterpret_cast(outputTensors[r].data_ptr()) - recvBase; + TORCH_INTERNAL_ASSERT( + recvOffset % dtypeSize == 0, + "recvOffset is not divisible by dtypeSize"); + recvDisps[r] = recvOffset / dtypeSize; + } + NCCL_CHECK(ncclAllToAllv( + inputTensors[0].data_ptr(), + sendCounts.data(), + sendDisps.data(), + outputTensors[0].data_ptr(), + recvCounts.data(), + recvDisps.data(), + to_nccl_data_type(inputTensors.front()), + comm, + stream.stream())); +#else NCCL_CHECK(ncclGroupStart()); for (const auto r : c10::irange(outputTensors.size())) { at::Tensor& input = inputTensors[r]; @@ -953,6 +1014,7 @@ void all2all( #else NCCL_CHECK_TIMEOUT(ncclGroupEnd(), _comm); #endif +#endif #else AT_ERROR("all2all is only supported for NCCL lib version >= 2.7.0"); #endif From 3b683d0e963c5079590b349f18eed991c36a92b5 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Mon, 16 Sep 2024 20:26:35 +0000 Subject: [PATCH 0951/1018] Revert "[dynamo] Fix support for classmethod(property(...)) (#134968)" This reverts commit c64ae601ba9eb3ad2cd3402a14f6ac83c0ab7eba. Reverted https://github.com/pytorch/pytorch/pull/134968 on behalf of https://github.com/jeanschmidt due to Breaking internal signals, we need to skip the new tests on py3.10 ([comment](https://github.com/pytorch/pytorch/pull/134968#issuecomment-2353909010)) --- test/dynamo/test_repros.py | 79 +------------------------ torch/_dynamo/variables/constant.py | 2 - torch/_dynamo/variables/user_defined.py | 4 -- 3 files changed, 1 insertion(+), 84 deletions(-) diff --git a/test/dynamo/test_repros.py b/test/dynamo/test_repros.py index bd661990402622..e1e1dd1837bef4 100644 --- a/test/dynamo/test_repros.py +++ b/test/dynamo/test_repros.py @@ -14,14 +14,13 @@ import itertools import os import random -import sys import unittest import warnings import weakref from abc import ABC from collections import namedtuple from copy import deepcopy -from enum import Enum, IntEnum +from enum import Enum from functools import wraps from typing import Any, Dict, Iterator, List, Tuple from unittest import mock @@ -4547,82 +4546,6 @@ def f(*args): f(*args) self.assertEqual(num_compiles, 1) - @unittest.skipIf(sys.version_info < (3, 9), "requires python 3.9+") - def test_issue134451(self): - class BoundingBox2DIndex(IntEnum): - _X = 0 - _Y = 1 - _HEADING = 2 - _LENGTH = 3 - _WIDTH = 4 - - @classmethod - def size(cls): - return 5 - - @classmethod - @property - def X(cls): - return cls._X - - @classmethod - @property - def Y(cls): - return cls._Y - - @classmethod - @property - def HEADING(cls): - return cls._HEADING - - @classmethod - @property - def LENGTH(cls): - return cls._LENGTH - - @classmethod - @property - def WIDTH(cls): - return cls._WIDTH - - @classmethod - @property - def POINT(cls): - # assumes X, Y have subsequent indices - return slice(cls._X, cls._Y + 1) - - @classmethod - @property - def STATE_SE2(cls): - # assumes X, Y, HEADING have subsequent indices - return slice(cls._X, cls._HEADING + 1) - - class SimpleModel(nn.Module): - def __init__(self): - super().__init__() - self._mlp_states = nn.Sequential( - nn.Linear(10, 20), - nn.ReLU(), - nn.Linear(20, BoundingBox2DIndex.size()), - ) - - def forward(self, x): - agent_states = self._mlp_states(x) - agent_states[..., BoundingBox2DIndex.POINT] = ( - agent_states[..., BoundingBox2DIndex.POINT].tanh() * 32 - ) - agent_states[..., BoundingBox2DIndex.HEADING] = ( - agent_states[..., BoundingBox2DIndex.HEADING].tanh() * torch.pi - ) - return agent_states - - model = SimpleModel().eval() - input_tensor = torch.randn(1, 10, dtype=torch.float32) - opt = torch.compile(model.eval(), backend="eager", fullgraph=True) - actual = opt(input_tensor) - expected = model(input_tensor) - self.assertEqual(actual, expected) - def test_invalid_seq_unpack(self): def myfn(arg): (a, b) = arg diff --git a/torch/_dynamo/variables/constant.py b/torch/_dynamo/variables/constant.py index de357cf8094f3a..65de0aab6ef3c7 100644 --- a/torch/_dynamo/variables/constant.py +++ b/torch/_dynamo/variables/constant.py @@ -222,8 +222,6 @@ def create(cls, cls_type, value_vt, options): unimplemented("Enum variable is constructed with non constant values") def as_proxy(self): - if isinstance(self.value, int): - return int(self.value) # convert IntEnum to a normal int return self.value def __str__(self) -> str: diff --git a/torch/_dynamo/variables/user_defined.py b/torch/_dynamo/variables/user_defined.py index 32057c838d74c8..14499b4d2e4bf7 100644 --- a/torch/_dynamo/variables/user_defined.py +++ b/torch/_dynamo/variables/user_defined.py @@ -193,10 +193,6 @@ def var_getattr(self, tx: "InstructionTranslator", name: str) -> "VariableTracke else: return SourcelessBuilder.create(tx, func) elif isinstance(obj, classmethod): - if isinstance(obj.__func__, property): - return variables.UserFunctionVariable(obj.__func__.fget).call_function( - tx, [self], {} - ) return variables.UserMethodVariable(obj.__func__, self, source=source) elif isinstance(obj, types.ClassMethodDescriptorType): # e.g.: inspect.getattr_static(dict, "fromkeys") From 97975290729371d95b5dd31c7e81084a5dcb3c9d Mon Sep 17 00:00:00 2001 From: Pearu Peterson Date: Mon, 16 Sep 2024 14:51:47 +0300 Subject: [PATCH 0952/1018] Add scaling arguments to bsr_dense_addmm (#136104) As in the title. Tackles https://github.com/pytorch/ao/pull/821/files#r1759821413 The PR assumes that the existing tuning parameters are good also when using scaling arguments. This needs to be verified as a follow-up task. Also, this PR redefines triton-contiguous tensors: the tensor must have strides not larger than 1. This will now allow zero strides that previously triggered `contiguous` call although the underlying memory buffer was contiguous. Re: "a considerable slow-down occurs because tensor data is copied element-wise rather than chunk-wise" - this note should refer to a code (torch or triton?) that implements the element/chunk-wise copy so that we could verify that allowing zero strides indeed would not trigger element-wise copies. Atm, the performance increase in ViT-H benchmarks (that involve using 0 strides) is an evidence that allowing zero strides does not lead to slow-downs. Pull Request resolved: https://github.com/pytorch/pytorch/pull/136104 Approved by: https://github.com/cpuhrsch --- test/test_sparse_csr.py | 45 +++++++--- torch/sparse/_triton_ops.py | 150 +++++++++++++++++++++++++------ torch/sparse/_triton_ops_meta.py | 20 ++++- 3 files changed, 173 insertions(+), 42 deletions(-) diff --git a/test/test_sparse_csr.py b/test/test_sparse_csr.py index 903b609c4fb236..f897fd041889f8 100644 --- a/test/test_sparse_csr.py +++ b/test/test_sparse_csr.py @@ -4027,6 +4027,7 @@ def test_TensorAsKey(self, device): @skipIfRocm @dtypes(torch.half, torch.bfloat16, torch.float, torch.int8) @dtypesIfCUDA(torch.half, *[torch.bfloat16] if SM80OrLater else [], torch.float, torch.int8) + @precisionOverride({torch.float16: 6e-1}) @unittest.skipIf(IS_FBCODE and IS_REMOTE_GPU, "Test requires Triton") def test_triton_kernel(self, op, device, dtype, blocksize): from torch.sparse._triton_ops import bsr_dense_addmm, bsr_dense_mm, _int_bsr_dense_addmm @@ -4039,15 +4040,24 @@ def bsr_dense_linear(input, weights, bias=None): operation = dict(bsr_dense_addmm=bsr_dense_addmm, bsr_dense_mm=bsr_dense_mm, bsr_dense_linear=bsr_dense_linear, _int_bsr_dense_addmm=_int_bsr_dense_addmm)[op] - def reference(input, mat1, mat2, beta=1, alpha=1, op=op): + def reference(input, mat1, mat2, beta=1, alpha=1, left_alpha=None, right_alpha=None, op=op): assert mat1.layout is torch.strided assert mat2.layout is torch.strided if dtype is torch.int8: if op == '_int_bsr_dense_addmm': - return beta * input + alpha * torch._int_mm(mat1, mat2) - # workaround RuntimeError: "addmm_cuda" not implemented for 'Char' - return beta * input + alpha * torch._int_mm(mat1, mat2).to(torch.int8) - return beta * input + alpha * (mat1 @ mat2) + mat12 = torch._int_mm(mat1, mat2) + else: + # workaround RuntimeError: "addmm_cuda" not implemented for 'Char' + mat12 = torch._int_mm(mat1, mat2).to(torch.int8) + else: + mat12 = mat1 @ mat2 + if alpha != 1: + mat12 *= alpha + if left_alpha is not None: + mat12 = left_alpha.reshape(*left_alpha.shape[:-1], -1, 1) * mat12 + if right_alpha is not None: + mat12 = mat12 * right_alpha.reshape(*right_alpha.shape[:-1], 1, -1) + return beta * input + mat12 if op == '_int_bsr_dense_addmm': # _int_bsr_dense_addmm is same as bsr_dense_addmm except @@ -4056,6 +4066,8 @@ def reference(input, mat1, mat2, beta=1, alpha=1, op=op): # definitions above and all other definitions below are # identical between _int_bsr_dense_addmm and # bsr_dense_addmm. + if dtype.is_floating_point or dtype.is_complex: + self.skipTest(f"Redundant test: {op} on {dtype} tensors") op = 'bsr_dense_addmm' def nc_copy(t, axes=(-1,)): @@ -4101,14 +4113,21 @@ def nc_copy(t, axes=(-1,)): blocks_per_row_lst = [1, 2] blocks_per_col_lst = [1, 2] result_cols_lst = [16, 32, 64] - for beta, alpha, sparsity, blocks_per_row, blocks_per_col, N in itertools.product( - beta_lst, alpha_lst, sparsity_lst, blocks_per_row_lst, blocks_per_col_lst, result_cols_lst): + has_left_alpha_lst = dict(bsr_dense_addmm=[False, True], bsr_dense_mm=[False], bsr_dense_linear=[False])[op] + has_right_alpha_lst = dict(bsr_dense_addmm=[False, True], bsr_dense_mm=[False], bsr_dense_linear=[False])[op] + high = 1.5 + int(dtype is torch.int8) + for beta, alpha, sparsity, blocks_per_row, blocks_per_col, N, has_left_alpha, has_right_alpha in itertools.product( + beta_lst, alpha_lst, sparsity_lst, blocks_per_row_lst, blocks_per_col_lst, result_cols_lst, + has_left_alpha_lst, has_right_alpha_lst): M = BM * blocks_per_row K = BK * blocks_per_col mat1 = create_blocked_tensor(0, M, K, (BM, BK), sparsity, dtype, device=device) bsr = mat1.to_sparse_bsr((BM, BK)) - mat2 = make_tensor(K, N, dtype=dtype, device=device, low=0.5, high=1.5) - input = make_tensor(M, N, dtype=dtype, device=device, low=0.5, high=1.5) + mat2 = make_tensor(K, N, dtype=dtype, device=device, low=0.5, high=high) + input = make_tensor(M, N, dtype=dtype, device=device, low=0.5, high=high) + + left_alpha = make_tensor(M, dtype=dtype, device=device, low=0.5, high=high) if has_left_alpha else None + right_alpha = make_tensor(N, dtype=dtype, device=device, low=0.5, high=high) if has_right_alpha else None if 0 and op == "bsr_dense_addmm": # Find optimal kernel parameters, the speed-up is @@ -4121,12 +4140,12 @@ def nc_copy(t, axes=(-1,)): meta = get_meta(op, key, version=(0, dtype, 0.5)) if meta is None: optimize_bsr_dense_addmm(M, K, N, BM, BK, beta=beta, alpha=alpha, dtype=dtype, sparsity=0.5) - meta = get_meta(op, key, version=(0, dtype, 0.5)) assert meta is not None dump() # this will update torch/sparse/_triton_ops_meta.py - expected = reference(input, mat1, mat2, beta=beta, alpha=alpha) - kwargs = dict(bsr_dense_addmm=dict(beta=beta, alpha=alpha), bsr_dense_mm={}, + expected = reference(input, mat1, mat2, beta=beta, alpha=alpha, left_alpha=left_alpha, right_alpha=right_alpha) + kwargs = dict(bsr_dense_addmm=dict(beta=beta, alpha=alpha, + left_alpha=left_alpha, right_alpha=right_alpha), bsr_dense_mm={}, bsr_dense_linear=dict(bias=input.transpose(-1, -2)))[op] args = dict(bsr_dense_addmm=(input, bsr, mat2), bsr_dense_mm=(bsr, mat2), @@ -4156,7 +4175,7 @@ def nc_copy(t, axes=(-1,)): if op in {'bsr_dense_addmm', 'bsr_dense_linear'}: args = dict(bsr_dense_addmm=(nc_input, bsr, nc_mat2), bsr_dense_linear=(nc_mat2.transpose(-1, -2), bsr))[op] - kwargs = dict(bsr_dense_addmm=dict(beta=beta, alpha=alpha), + kwargs = dict(bsr_dense_addmm=dict(beta=beta, alpha=alpha, left_alpha=left_alpha, right_alpha=right_alpha), bsr_dense_linear=dict(bias=nc_input.transpose(-1, -2)))[op] result = operation(*args, **kwargs) self.assertEqual(result, expected) diff --git a/torch/sparse/_triton_ops.py b/torch/sparse/_triton_ops.py index 091e91d37f604a..f919718bb5dc66 100644 --- a/torch/sparse/_triton_ops.py +++ b/torch/sparse/_triton_ops.py @@ -89,14 +89,14 @@ def make_triton_contiguous(t): """Return input as a triton-contiguous tensor. A triton-contiguous tensor is defined as a tensor that has strides - with minimal value equal to 1. + with minimal value smaller than or equal to 1. While triton kernels support triton-non-contiguous tensors (all - strides being greater than 1 or having 0 strides) arguments, a - considerable slow-down occurs because tensor data is copied - element-wise rather than chunk-wise. + strides being greater than 1) arguments, a considerable slow-down + occurs because tensor data is copied element-wise rather than + chunk-wise. Zero strides is assumed to not have this defect. """ - if min(t.stride()) != 1: + if min(t.stride()) > 1: # TODO: investigate if contiguity along other axes than the # last one can be beneficial for performance return t.contiguous() @@ -1097,6 +1097,8 @@ def _int_bsr_dense_addmm( *, beta=1, alpha=1, + left_alpha: Optional[torch.Tensor] = None, + right_alpha: Optional[torch.Tensor] = None, out: Optional[torch.Tensor] = None, skip_checks: bool = False, max_grid: Optional[Tuple[Optional[int], Optional[int], Optional[int]]] = None, @@ -1120,6 +1122,8 @@ def _int_bsr_dense_addmm( dense, beta=beta, alpha=alpha, + left_alpha=left_alpha, + right_alpha=right_alpha, out=out, skip_checks=skip_checks, max_grid=max_grid, @@ -1134,11 +1138,21 @@ def bsr_dense_addmm( *, beta=1, alpha=1, + left_alpha: Optional[torch.Tensor] = None, + right_alpha: Optional[torch.Tensor] = None, out: Optional[torch.Tensor] = None, skip_checks: bool = False, max_grid: Optional[Tuple[Optional[int], Optional[int], Optional[int]]] = None, meta: Optional[dict] = None, ): + """Compute + + out = beta * input + left_alpha.reshape(-1, 1) * (alpha * (bsr @ dense)) * right_alpha.reshape(1, -1) + + where left_alpha, right_alpha are (* + 1)-D tensors when + specified, otherwise, these are treated as tensors filled with + ones. + """ f_name = "bsr_dense_addmm" values = bsr.values() crow_indices = bsr.crow_indices() @@ -1150,8 +1164,8 @@ def bsr_dense_addmm( # todo: implement checks + original_batch_dims_broadcasted = broadcast_batch_dims(f_name, bsr, dense) if out is None: - original_batch_dims_broadcasted = broadcast_batch_dims(f_name, bsr, dense) out = dense.new_empty(original_batch_dims_broadcasted + (M, N)) if bsr._nnz() == 0 or alpha == 0 or N == 0 or M == 0 or K == 0: @@ -1163,6 +1177,30 @@ def bsr_dense_addmm( out.mul_(beta) return out + left_alpha_is_one = False + right_alpha_is_one = False + if left_alpha is None: + left_alpha_is_one = True + left_alpha = dense.new_empty(()).expand( + *original_batch_dims_broadcasted, M, N + ) # not referenced + else: + left_alpha = left_alpha.view(*original_batch_dims_broadcasted, M, 1).expand( + *original_batch_dims_broadcasted, M, N + ) + + if right_alpha is None: + right_alpha_is_one = True + right_alpha = dense.new_empty(()).expand( + *original_batch_dims_broadcasted, M, N + ) # not referenced + else: + right_alpha = right_alpha.view(*original_batch_dims_broadcasted, 1, N).expand( + *original_batch_dims_broadcasted, M, N + ) + assert left_alpha.stride()[-1] == 0 + assert right_alpha.stride()[-2] == 0 + if meta is None: sparsity = round(1 - bsr._nnz() * blocksize[0] * blocksize[1] / (M * K), 2) meta = bsr_dense_addmm_meta( @@ -1178,9 +1216,16 @@ def bsr_dense_addmm( ) out_backup = out - crow_indices, col_indices, values, input, dense, out = prepare_inputs( - bsr, input, dense, out - ) + ( + crow_indices, + col_indices, + values, + input, + dense, + left_alpha, + right_alpha, + out, + ) = prepare_inputs(bsr, input, dense, left_alpha, right_alpha, out) BM, BK = blocksize SPLIT_N = meta.get("SPLIT_N", N // BM) @@ -1191,6 +1236,9 @@ def bsr_dense_addmm( dense = tile_to_blocksize(dense, (BK, BN)) input = tile_to_blocksize(input, (BM, BN)) + left_alpha = tile_to_blocksize(left_alpha, (BM, BN)) + right_alpha = tile_to_blocksize(right_alpha, (BM, BN)) + dot_out_dtype = { torch.float16: tl.float32, torch.bfloat16: tl.float32, @@ -1216,6 +1264,8 @@ def bsr_dense_addmm( col_indices: (0, None, None), input: (0, -3, -4), dense: (0, -3, None), + left_alpha: (0, -3, -4), + right_alpha: (0, -3, -4), out: (0, -3, -4), } @@ -1229,6 +1279,8 @@ def kernel(grid, *sliced_tensors): beta_is_one=beta == 1, beta_is_nonzero=beta != 0, alpha_is_one=alpha == 1, + left_alpha_is_one=left_alpha_is_one, + right_alpha_is_one=right_alpha_is_one, BLOCKSIZE_ROW=BM, BLOCKSIZE_INNER=BK, BLOCKSIZE_COL=BN, @@ -2278,6 +2330,22 @@ def _bsr_strided_addmm_kernel( dense_row_block_stride, dense_col_block_stride, # dense epilogue + # left_alpha prologue + left_alpha_ptr, + left_alpha_batch_stride, + left_alpha_tiled_row_stride, + left_alpha_tiled_col_stride: tl.constexpr, + left_alpha_row_block_stride, + left_alpha_col_block_stride: tl.constexpr, + # left_alpha epilogue + # right_alpha prologue + right_alpha_ptr, + right_alpha_batch_stride, + right_alpha_tiled_row_stride: tl.constexpr, + right_alpha_tiled_col_stride, + right_alpha_row_block_stride: tl.constexpr, + right_alpha_col_block_stride, + # right_alpha epilogue # output prologue output_ptr, output_batch_stride, @@ -2291,6 +2359,8 @@ def _bsr_strided_addmm_kernel( beta_is_one: tl.constexpr, beta_is_nonzero: tl.constexpr, alpha_is_one: tl.constexpr, + left_alpha_is_one: tl.constexpr, + right_alpha_is_one: tl.constexpr, BLOCKSIZE_ROW: tl.constexpr, BLOCKSIZE_COL: tl.constexpr, BLOCKSIZE_INNER: tl.constexpr, @@ -2299,6 +2369,12 @@ def _bsr_strided_addmm_kernel( GROUP_SIZE_ROW: tl.constexpr, SPLIT_N: tl.constexpr, ): + # left/right_alpha tensors are originally (* + 1)-dimensional + assert left_alpha_tiled_col_stride == 0 + assert left_alpha_col_block_stride == 0 + assert right_alpha_tiled_row_stride == 0 + assert right_alpha_row_block_stride == 0 + batch_pid = tl.program_id(axis=2) row_block_pid = tl.program_id(axis=0) col_block_pid = tl.program_id(axis=1) @@ -2324,17 +2400,6 @@ def _bsr_strided_addmm_kernel( inner_block_arange = tl.arange(0, BLOCKSIZE_INNER) col_block_arange = tl.arange(0, BLOCKSIZE_COL) - if beta_is_nonzero: - # Pointers are set to exact write-to locations - input_ptrs = ( - input_ptr - + input_batch_stride * batch_pid - + input_tiled_row_stride * row_block_pid - + input_tiled_col_stride * col_block_pid - + input_row_block_stride * row_block_arange[:, None] - + input_col_block_stride * col_block_arange[None, :] - ) - # Pointers are set to the first block of the current row. values_block_ptrs = ( values_ptr @@ -2371,14 +2436,7 @@ def _bsr_strided_addmm_kernel( + col_indices_stride * nnz_offset ) - # alpha is never 0 - if beta_is_nonzero: - output_acc_block = tl.load(input_ptrs).to(acc_dtype) # type: ignore[possibly-undefined] - if not (beta_is_one and alpha_is_one): - beta_alpha = beta / alpha - output_acc_block *= beta_alpha - else: - output_acc_block = tl.zeros((BLOCKSIZE_ROW, BLOCKSIZE_COL), dtype=acc_dtype) + output_acc_block = tl.zeros((BLOCKSIZE_ROW, BLOCKSIZE_COL), dtype=acc_dtype) for _ in range(row_nnz): values_block = tl.load(values_block_ptrs) @@ -2402,6 +2460,42 @@ def _bsr_strided_addmm_kernel( if not alpha_is_one: output_acc_block *= alpha + if not left_alpha_is_one: + left_alpha_ptrs = ( + left_alpha_ptr + + left_alpha_batch_stride * batch_pid + + left_alpha_tiled_row_stride * row_block_pid + + left_alpha_tiled_col_stride * col_block_pid + + left_alpha_row_block_stride * row_block_arange[:, None] + + left_alpha_col_block_stride * col_block_arange[None, :] + ) + output_acc_block *= tl.load(left_alpha_ptrs) + + if not right_alpha_is_one: + right_alpha_ptrs = ( + right_alpha_ptr + + right_alpha_batch_stride * batch_pid + + right_alpha_tiled_row_stride * row_block_pid + + right_alpha_tiled_col_stride * col_block_pid + + right_alpha_row_block_stride * row_block_arange[:, None] + + right_alpha_col_block_stride * col_block_arange[None, :] + ) + output_acc_block *= tl.load(right_alpha_ptrs) + + if beta_is_nonzero: + input_ptrs = ( + input_ptr + + input_batch_stride * batch_pid + + input_tiled_row_stride * row_block_pid + + input_tiled_col_stride * col_block_pid + + input_row_block_stride * row_block_arange[:, None] + + input_col_block_stride * col_block_arange[None, :] + ) + if beta_is_one: + output_acc_block += tl.load(input_ptrs) + else: + output_acc_block += beta * tl.load(input_ptrs) + # write back the result tl.store(output_ptrs, output_acc_block.to(output_ptr.dtype.element_ty)) diff --git a/torch/sparse/_triton_ops_meta.py b/torch/sparse/_triton_ops_meta.py index c7234641886234..3d353c00cc21bd 100644 --- a/torch/sparse/_triton_ops_meta.py +++ b/torch/sparse/_triton_ops_meta.py @@ -599,6 +599,8 @@ def tune_bsr_dense_addmm( *, beta=1, alpha=1, + left_alpha=None, + right_alpha=None, out=None, store=False, verbose=False, @@ -658,7 +660,15 @@ def tune_bsr_dense_addmm( def bench(meta, input=input, bsr=bsr, dense=dense, alpha=alpha, out=out): def test_func(): return bsr_dense_addmm( - input, bsr, dense, beta=beta, alpha=alpha, meta=meta, out=out + input, + bsr, + dense, + beta=beta, + alpha=alpha, + left_alpha=left_alpha, + right_alpha=right_alpha, + meta=meta, + out=out, ) return triton.testing.do_bench(test_func, warmup=500, rep=100) @@ -723,6 +733,8 @@ def optimize_bsr_dense_addmm( bk, beta=1, alpha=1, + use_left_alpha=False, + use_right_alpha=False, dtype=torch.float16, device="cuda", sparsity=0.5, @@ -736,12 +748,18 @@ def optimize_bsr_dense_addmm( ).to_sparse_bsr((bm, bk)) dense = make_tensor(k, n, dtype=dtype, device=device) input = make_tensor(m, n, dtype=dtype, device=device) + left_alpha = make_tensor(m, dtype=dtype, device=device) if use_left_alpha else None + right_alpha = ( + make_tensor(n, dtype=dtype, device=device) if use_right_alpha else None + ) tune_bsr_dense_addmm( input, bsr, dense, beta=beta, alpha=alpha, + left_alpha=left_alpha, + right_alpha=right_alpha, store=True, force=force, verbose=verbose, From 163d433534656715998d6e13fadbb4b42461e8d9 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Mon, 16 Sep 2024 21:28:54 +0000 Subject: [PATCH 0953/1018] [ONNX] Treat CompositeImplicitAutograd ops as normal ops in decomp (#136153) Since https://github.com/pytorch/pytorch/pull/135080, the CompositeImplicitAutograd (CIA) ops are only decomposed when a decomp function is provided in a table. There is no longer a need to distinguish CIA ops like Upsample and preserve them explicitly. On the ONNX Script torchlib side I will unregister some ops from the following list to make sure some CIA ops are still decomposed. ``` , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/136153 Approved by: https://github.com/xadupre, https://github.com/gramalingam --- torch/onnx/_internal/exporter/_decomp.py | 45 ++++-------------------- 1 file changed, 6 insertions(+), 39 deletions(-) diff --git a/torch/onnx/_internal/exporter/_decomp.py b/torch/onnx/_internal/exporter/_decomp.py index 87e4925b06b6a0..de0dd0c0bcb73b 100644 --- a/torch/onnx/_internal/exporter/_decomp.py +++ b/torch/onnx/_internal/exporter/_decomp.py @@ -1,6 +1,7 @@ # mypy: allow-untyped-defs from __future__ import annotations +import itertools from typing import Callable, TYPE_CHECKING import torch @@ -38,33 +39,6 @@ def get_onnx_implemented_overloads( return registered_ops -def get_preserve_ops() -> set[torch._ops.OpOverload]: - """Return a set of CompositeImplicitAutograd ops that should be preserved.""" - aten = torch.ops.aten - # NOTE: Keep this list sorted - # NOTE: Do _not_ retain aten.linear as its decomposition is addmm, which is Gemm and is preferable for accuracy - return { - aten._upsample_bilinear2d_aa.default, - aten._upsample_nearest_exact1d.vec, - aten._upsample_nearest_exact2d.vec, - aten._upsample_nearest_exact3d.vec, - aten.group_norm.default, - aten.instance_norm.default, - aten.upsample_bilinear2d.default, - aten.upsample_bilinear2d.vec, - aten.upsample_linear1d.default, - aten.upsample_linear1d.vec, - aten.upsample_nearest1d.default, - aten.upsample_nearest1d.vec, - aten.upsample_nearest2d.default, - aten.upsample_nearest2d.vec, - aten.upsample_nearest3d.default, - aten.upsample_nearest3d.vec, - aten.upsample_trilinear3d.default, - aten.upsample_trilinear3d.vec, - } - - def create_onnx_friendly_decomposition_table( onnx_registered_ops: set[torch._ops.OperatorBase], ) -> dict[torch._ops.OperatorBase, Callable]: @@ -81,19 +55,12 @@ def create_onnx_friendly_decomposition_table( Dict[torch._ops.OperatorBase, Callable]: A dictionary that maps op overloads to their corresponding decomposition functions. """ - # This table contains all CIA decomps, so we should filter out ONNX supported CIAs from it - decomposition_table: dict[torch._ops.OperatorBase, Callable] = ( - torch._decomp._decomp_table_to_post_autograd_aten() - ) - can_preserve = get_preserve_ops().intersection(onnx_registered_ops) - for op in list(decomposition_table.keys()): - if op in can_preserve: - del decomposition_table[op] + decomposition_table: dict[torch._ops.OperatorBase, Callable] = {} - # NOTE: If we import torch._decomp, we will get RuntimeError: Only a single - # TORCH_LIBRARY can be used to register the namespace nvprims; please put all of your - # definitions in a single TORCH_LIBRARY block. - for op_overload, decomp_fn in torch._decomp.decomposition_table.items(): # type: ignore[attr-defined] + for op_overload, decomp_fn in itertools.chain( + torch._decomp._decomp_table_to_post_autograd_aten().items(), # type: ignore[attr-defined] + torch._decomp.decomposition_table.items(), # type: ignore[attr-defined] + ): # Skip decomposition for op_overload as long as that op_overload has a corresponding ONNX # symbolic function. # NOTE: Do not skip torch._refs decomps. They are fine because otherwise the model is From 4aead15f6ca16cbce1241a803f42f6489ff4a5c8 Mon Sep 17 00:00:00 2001 From: Laith Sakka Date: Sun, 15 Sep 2024 18:59:19 -0700 Subject: [PATCH 0954/1018] use csv extention for test report in order for it to be uploaded to s3 (#136128) Pull Request resolved: https://github.com/pytorch/pytorch/pull/136128 Approved by: https://github.com/clee2000 --- .ci/pytorch/test.sh | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.ci/pytorch/test.sh b/.ci/pytorch/test.sh index 9d2f2246ad7509..50aff1d55cd8dd 100755 --- a/.ci/pytorch/test.sh +++ b/.ci/pytorch/test.sh @@ -401,9 +401,9 @@ pr_time_benchmarks() { TEST_REPORTS_DIR=$(pwd)/test/test-reports mkdir -p "$TEST_REPORTS_DIR" - PYTHONPATH=$(pwd)/benchmarks/dynamo/pr_time_benchmarks source benchmarks/dynamo/pr_time_benchmarks/benchmark_runner.sh "$TEST_REPORTS_DIR/pr_time_benchmarks_after.txt" "benchmarks/dynamo/pr_time_benchmarks/benchmarks" + PYTHONPATH=$(pwd)/benchmarks/dynamo/pr_time_benchmarks source benchmarks/dynamo/pr_time_benchmarks/benchmark_runner.sh "$TEST_REPORTS_DIR/pr_time_benchmarks_results.csv" "benchmarks/dynamo/pr_time_benchmarks/benchmarks" echo "benchmark results on current PR: " - cat "$TEST_REPORTS_DIR/pr_time_benchmarks_after.txt" + cat "$TEST_REPORTS_DIR/pr_time_benchmarks_results.csv" } From e6aae779ddebae5836ad8ab7360f30b3ad430f3c Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Mon, 16 Sep 2024 23:59:56 +0000 Subject: [PATCH 0955/1018] Revert "fix compiled_autograd deadlock throw (#135795)" This reverts commit 00dc7d435652ad66e9d2feb2660928b632281a98. Reverted https://github.com/pytorch/pytorch/pull/135795 on behalf of https://github.com/facebook-github-bot due to Diff reverted internally ([comment](https://github.com/pytorch/pytorch/pull/135795#issuecomment-2354233619)) --- test/inductor/test_compiled_autograd.py | 5 ++++- torch/_dynamo/compiled_autograd.py | 4 ---- torch/autograd/graph.py | 7 +++---- 3 files changed, 7 insertions(+), 9 deletions(-) diff --git a/test/inductor/test_compiled_autograd.py b/test/inductor/test_compiled_autograd.py index 3e4fb849494936..cbe6a268c52275 100644 --- a/test/inductor/test_compiled_autograd.py +++ b/test/inductor/test_compiled_autograd.py @@ -2686,7 +2686,10 @@ def fn(x): inp = torch.rand(10, 10, requires_grad=True) out = torch.utils.checkpoint.checkpoint(fn, inp, use_reentrant=True) - with torch._dynamo.compiled_autograd.enable(torch.compile): + with self.assertRaisesRegex( + RuntimeError, + r"\(e.g. reentrant checkpointing\), this is not supported yet\.", + ), torch._dynamo.compiled_autograd.enable(torch.compile): out.backward() diff --git a/torch/_dynamo/compiled_autograd.py b/torch/_dynamo/compiled_autograd.py index 675f99e46fb78d..e7c5d2414f6e2b 100644 --- a/torch/_dynamo/compiled_autograd.py +++ b/torch/_dynamo/compiled_autograd.py @@ -525,10 +525,6 @@ def disable(): torch._C._dynamo.compiled_autograd.set_autograd_compiler(prior) -def maybe_disable_compiled_autograd(): - return disable() if in_compiled_autograd_region else contextlib.nullcontext() - - # return to starting state of a new process def reset() -> None: compiled_autograd_enable = False diff --git a/torch/autograd/graph.py b/torch/autograd/graph.py index 991b00e8b50cfb..4792ced32b2703 100644 --- a/torch/autograd/graph.py +++ b/torch/autograd/graph.py @@ -822,10 +822,9 @@ def _engine_run_backward( if attach_logging_hooks: unregister_hooks = _register_logging_hooks_on_whole_graph(t_outputs) try: - with torch._dynamo.compiled_autograd.maybe_disable_compiled_autograd(): - return Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass - t_outputs, *args, **kwargs - ) # Calls into the C++ engine to run the backward pass + return Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass + t_outputs, *args, **kwargs + ) # Calls into the C++ engine to run the backward pass finally: if attach_logging_hooks: unregister_hooks() # type: ignore[possibly-undefined] From a2d3e64a1334281a8e61beb27f415d4236d03b31 Mon Sep 17 00:00:00 2001 From: "Zhijing Li (Accelerator Enablement)" Date: Tue, 17 Sep 2024 01:06:10 +0000 Subject: [PATCH 0956/1018] Back out "Flip triton kernel default layout constraint to "needs_fixed_stride_order" (#135581)" (#136160) Test Plan: make train-hstu-cint-publish-bf16-tgif-local Differential Revision: D62766335 Pull Request resolved: https://github.com/pytorch/pytorch/pull/136160 Approved by: https://github.com/muchulee8 --- torch/_inductor/config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch/_inductor/config.py b/torch/_inductor/config.py index 99a25b13d2fd2a..04595ba4f0c738 100644 --- a/torch/_inductor/config.py +++ b/torch/_inductor/config.py @@ -74,7 +74,7 @@ def autotune_remote_cache_default() -> Optional[bool]: # The default layout constraint for user-defined triton kernels. # See "The default layout constraint for custom operators" for options. -triton_kernel_default_layout_constraint = "needs_fixed_stride_order" +triton_kernel_default_layout_constraint = "flexible_layout" # use cpp wrapper instead of python wrapper cpp_wrapper = os.environ.get("TORCHINDUCTOR_CPP_WRAPPER", "0") == "1" From 865ef8fa85d496be9f3a534f8788d0b34ecab320 Mon Sep 17 00:00:00 2001 From: Brian Hirsh Date: Fri, 13 Sep 2024 21:45:45 -0700 Subject: [PATCH 0957/1018] inductor: dont use default_dtype during rng functionalization (#136041) Fixes https://github.com/pytorch/pytorch/issues/119162 See context at https://github.com/pytorch/pytorch/issues/119162#issuecomment-2349849469 Pull Request resolved: https://github.com/pytorch/pytorch/pull/136041 Approved by: https://github.com/eellison --- test/dynamo/test_repros.py | 16 ++++++++++++++++ torch/_inductor/inductor_prims.py | 9 ++++++++- 2 files changed, 24 insertions(+), 1 deletion(-) diff --git a/test/dynamo/test_repros.py b/test/dynamo/test_repros.py index e1e1dd1837bef4..1dc42c4405df7a 100644 --- a/test/dynamo/test_repros.py +++ b/test/dynamo/test_repros.py @@ -5927,6 +5927,22 @@ def outer_func(x): res = compile_outer(x) self.assertEqual(ref, res) + # https://github.com/pytorch/pytorch/issues/119162 + def test_inductor_rng_default_dtype(self) -> None: + @torch.compile + def fn(): + tmp = torch.randn(4, 4, dtype=torch.bfloat16) + return tmp + + try: + old = torch.get_default_dtype() + torch.set_default_dtype(torch.bfloat16) + out = fn() + finally: + torch.set_default_dtype(old) + # output dtype should be float32 + self.assertEqual(out.dtype, torch.bfloat16) + instantiate_parametrized_tests(ReproTests) diff --git a/torch/_inductor/inductor_prims.py b/torch/_inductor/inductor_prims.py index 82a23d3e60cf10..caba77371aac00 100644 --- a/torch/_inductor/inductor_prims.py +++ b/torch/_inductor/inductor_prims.py @@ -68,9 +68,16 @@ def eager_force_stride(input_tensor: Tensor, stride) -> Tensor: lambda seeds, index: seeds[index], doc="Extract a single seed from the result of inductor_seeds()", ) +# inductor_random() doesn't accept a dtype. +# instead, its lowering always burns in float32, and conversions to a different type +# are explicit in the graph. We therefore need this impl (used during tracing) to hardcoded +# the dtype, so it always faithfully produces a float32 tensor during tracing, +# even if the default dtype is set to something else. random = make_prim( "inductor_random(SymInt[] size, Tensor seed, str mode) -> Tensor", - lambda size, seed, mode: getattr(torch, mode)(size, device=seed.device), + lambda size, seed, mode: getattr(torch, mode)( + size, device=seed.device, dtype=torch.float32 + ), doc="torch.rand()/torch.randn() using backend-specific RNG that can be fused", ) randint = make_prim( From aa3b05219bc829927a105ed1f5a767e825dd52ae Mon Sep 17 00:00:00 2001 From: Brian Hirsh Date: Fri, 13 Sep 2024 21:45:46 -0700 Subject: [PATCH 0958/1018] make view.dtype always return an alias (#136074) Fixes https://github.com/pytorch/pytorch/issues/136064 In the linked repro, this issue was that there was some code like this: ``` # x has dtype torch.float32 def f(x): y = x.view(torch.float32) y.copy_(...) ``` Where because `view.dtype` is implemented today to potentially directly return its input, we would end up directly clobbering the proxy for our graph input (replacing its FX proxy value from `arg0_1` to `view_1`). This is not desirable, because we have careful assertions in AOTDispatcher that mutations only ever happen on graph inputs - but this clobbering caused the mutation to appear, from the perspective of the FX graph, like it was happening on a view of the input. Why is this normally not a problem? Ordinarily, the `ADInplaceOrView` kernel for `view.dtype` will take the output of the view kernel, [and detach() it](https://github.com/pytorch/pytorch/blob/main/tools/autograd/gen_inplace_or_view_type.py#L466) (properly creating a fresh `TensorImpl`). This does **not** happen, though, if you are executing the kernel from with a `__torch_dispatch__` region: the `ADInplaceOrView` logic has already run above you, so that key will be in the TLS exclude set. This PR changes eager behavior - at first I considered trying to only change behavior under compile. But this problem isn't technically specific to PT2: if you ever rely on tensor identity from inside of a __torch_dispatch__ call, then we need to make sure the raw `view.dtype` kernel doesn't directly return the input. I am also making the assumption that "`view.dtype` no-op'ing when the dtype is the same" is not a case worth optimizing in eager mode, and that the overhead of the `TensorImpl` creation is relatively negligible. Pull Request resolved: https://github.com/pytorch/pytorch/pull/136074 Approved by: https://github.com/Skylion007, https://github.com/ezyang, https://github.com/albanD ghstack dependencies: #136041 --- aten/src/ATen/native/TensorConversions.cpp | 3 --- test/test_python_dispatch.py | 17 +++++++++++++++++ 2 files changed, 17 insertions(+), 3 deletions(-) diff --git a/aten/src/ATen/native/TensorConversions.cpp b/aten/src/ATen/native/TensorConversions.cpp index dc0c1054d16b05..22a576408bfbbb 100644 --- a/aten/src/ATen/native/TensorConversions.cpp +++ b/aten/src/ATen/native/TensorConversions.cpp @@ -772,9 +772,6 @@ inline SymDimVector compute_strides_for_view_dtype_upsize(SymIntArrayRef old_str } Tensor view_dtype(const Tensor& self, ScalarType dtype) { - if (self.scalar_type() == dtype) { - return self; - } const auto type_meta = c10::scalarTypeToTypeMeta(dtype); TORCH_CHECK(!self.is_conj(), "torch.Tensor.view is not supported for conjugate view tensors when converting to a different dtype."); diff --git a/test/test_python_dispatch.py b/test/test_python_dispatch.py index 7c938e302118c3..00e86b1f9977b2 100644 --- a/test/test_python_dispatch.py +++ b/test/test_python_dispatch.py @@ -1793,6 +1793,23 @@ def test_tolist_numpy_with_torch_dispatch_mode(self) -> None: with self.assertRaises(AssertionError): self.assertEqual(x, None) + # See https://github.com/pytorch/pytorch/issues/136064 + def test_view_returns_alias_under_torch_dispatch(self): + class MyMode(TorchDispatchMode): + def __init__(self, testcase): + self.testcase = testcase + + def __torch_dispatch__(self, func, types, args=(), kwargs=None): + out = func(*args, **kwargs) + if func == torch.ops.aten.view.dtype: + # view should return a fresh TensorImpl + self.testcase.assertTrue(out is not args[0]) + return out + + with MyMode(self): + x = torch.ones(4, dtype=torch.float32) + out = x.view(torch.float32) + def test_record_stream(self) -> None: class TestMode(TorchDispatchMode): def __init__(self, testcase): From bec366977909cb62ac23b1a4351ca7df5d71b868 Mon Sep 17 00:00:00 2001 From: wz337 Date: Mon, 16 Sep 2024 13:10:13 -0700 Subject: [PATCH 0959/1018] [DSD][EZ] Minor update in _state_dict_utils.py (#136165) Pull Request resolved: https://github.com/pytorch/pytorch/pull/136165 Approved by: https://github.com/kwen2501 ghstack dependencies: #135725, #135763 --- torch/distributed/_state_dict_utils.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/torch/distributed/_state_dict_utils.py b/torch/distributed/_state_dict_utils.py index 9520b195f9ed5c..be71f88cd52b8e 100644 --- a/torch/distributed/_state_dict_utils.py +++ b/torch/distributed/_state_dict_utils.py @@ -556,7 +556,10 @@ def _distribute_tensors( shape, offset = compute_local_shape_and_global_offset( full_tensor.shape, local_state.device_mesh, local_state.placements ) - slices = [slice(offset[i], shape[i] + offset[i]) for i in range(len(shape))] + slices = [ + slice(cur_offset, cur_offset + cur_shape) + for cur_shape, cur_offset in zip(shape, offset) + ] local_tensor = full_tensor[slices] local_state_dict[key] = DTensor.from_local( local_tensor, local_state.device_mesh, local_state.placements From 44340844ae3b3a81835896c48238cd39bad802f2 Mon Sep 17 00:00:00 2001 From: ankurneog Date: Tue, 17 Sep 2024 04:39:08 +0000 Subject: [PATCH 0960/1018] Update real device in FSDP state_dict_utils (#134994) ## Motivation The default device for tensor.device both for sharded as well as non sharded is set to cuda by default. Hence while checking the FSDP UTs we see the following errors. This change updates the actual device type based on the created tensor. ``` [rank3] File "/root/repos/pytorch-training-tests/tests/pytorch/v2.4.0/distributed_hpu/fsdp/test_fsdp_dtensor_state_dict.py", line 143, in test_dtensor_sharded_tensor_state_dict_identical [rank3] sharded_tensor_sd = ref_model.state_dict() [rank3] File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1944, in state_dict [rank3] hook_result = hook(self, destination, prefix, local_metadata) [rank3] File "/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py", line 116, in decorate_context [rank3] return func(*args, **kwargs) [rank3] File "/usr/local/lib/python3.10/dist-packages/torch/distributed/fsdp/_state_dict_utils.py", line 752, in _post_state_dict_hook [rank3] tensor.device, [rank3] File "/usr/local/lib/python3.10/dist-packages/typing_extensions.py", line 2853, in wrapper [rank3] return arg(*args, **kwargs) [rank3] File "/usr/local/lib/python3.10/dist-packages/torch/distributed/_shard/sharded_tensor/api.py", line 1152, in __torch_function__ [rank3] return dispatch(st_instance, func) [rank3] File "/usr/local/lib/python3.10/dist-packages/torch/distributed/_shard/sharded_tensor/api.py", line 1134, in dispatch [rank3] return _SHARDED_OPS[func](types, args, kwargs, st._process_group) [rank3] File "/usr/local/lib/python3.10/dist-packages/torch/distributed/_shard/op_registry_utils.py", line 33, in wrapper [rank3] return wrapped_func(types, args, kwargs, process_group) [rank3] File "/usr/local/lib/python3.10/dist-packages/torch/distributed/_shard/sharded_tensor/_ops/tensor_ops.py", line 52, in tensor_device [rank3] dev = torch.device(torch.cuda.current_device()) [rank3] File "/usr/local/lib/python3.10/dist-packages/torch/cuda/__init__.py", line 878, in current_device [rank3] _lazy_init() [rank3] File "/usr/local/lib/python3.10/dist-packages/torch/cuda/__init__.py", line 305, in _lazy_init [rank3] raise AssertionError("Torch not compiled with CUDA enabled") [rank3] AssertionError: Torch not compiled with CUDA enabled ```` Pull Request resolved: https://github.com/pytorch/pytorch/pull/134994 Approved by: https://github.com/fegin --- torch/distributed/fsdp/_state_dict_utils.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/torch/distributed/fsdp/_state_dict_utils.py b/torch/distributed/fsdp/_state_dict_utils.py index 06dfb00d71e898..f96872bfa6e7cf 100644 --- a/torch/distributed/fsdp/_state_dict_utils.py +++ b/torch/distributed/fsdp/_state_dict_utils.py @@ -730,13 +730,18 @@ def _post_state_dict_hook( for key, tensor in sorted(processed_state_dict.items()): if key.startswith(prefix) and isinstance(tensor, torch.Tensor): local_shape = tensor.shape + device = None if isinstance(tensor, ShardedTensor): local_shape = None shards = tensor.local_shards() if shards: local_shape = shards[0].tensor.shape + device = shards[0].tensor.device elif isinstance(tensor, DTensor): local_shape = tensor.to_local().shape + device = tensor.device + else: + device = tensor.device logger.info( "FQN=%s: type=%s, shape=%s, local_shape=%s, dtype=%s, device=%s", key, @@ -744,7 +749,7 @@ def _post_state_dict_hook( tensor.shape, local_shape, tensor.dtype, - tensor.device, + device, ) return processed_state_dict From 70427fb8de725843664a482779ad27313020101f Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Tue, 17 Sep 2024 13:00:01 +0000 Subject: [PATCH 0961/1018] Revert "Optimize dict reconstruct to not codegen untouched values (#134876)" This reverts commit a1a57a424dc992f4dc2d44bdc1e4e7e500881a9c. Reverted https://github.com/pytorch/pytorch/pull/134876 on behalf of https://github.com/jeanschmidt due to new introduced test test_reconstruct.py::ReconstructTest::test_functional_call_reconstruct is breaking internally. @zou3519 may you help get those changes merged back to main? ([comment](https://github.com/pytorch/pytorch/pull/134876#issuecomment-2355697685)) --- test/dynamo/test_reconstruct.py | 257 ----------------------------- torch/_dynamo/side_effects.py | 23 +-- torch/_dynamo/variables/builder.py | 4 +- torch/_dynamo/variables/dicts.py | 41 +---- 4 files changed, 11 insertions(+), 314 deletions(-) delete mode 100644 test/dynamo/test_reconstruct.py diff --git a/test/dynamo/test_reconstruct.py b/test/dynamo/test_reconstruct.py deleted file mode 100644 index 18a18b6151b527..00000000000000 --- a/test/dynamo/test_reconstruct.py +++ /dev/null @@ -1,257 +0,0 @@ -# Owner(s): ["module: dynamo"] - -import contextlib -import dis -import unittest -from typing import List - -import torch -import torch._dynamo.test_case - - -def _filter_instructions(instructions, opname): - return list(filter(lambda x: x.opname == opname, instructions)) - - -class ReconstructTest(torch._dynamo.test_case.TestCase): - @contextlib.contextmanager - def register_bytecode_hook(self, check_fn): - def hook(code, out_code): - check_fn(list(dis.get_instructions(out_code))) - return code - - torch._dynamo.reset() - handle = torch._dynamo.convert_frame.register_bytecode_hook(hook) - try: - yield - finally: - handle.remove() - - def test_ConstDict_optimize_reconstruct(self): - """ - Emit code to reconstruct only the key that changed - """ - - def hook(instructions: List[dis.Instruction]): - build_map = _filter_instructions(instructions, "BUILD_MAP") - self.assertEqual(len(build_map), 1) - # reconstruct only d[40] - self.assertEqual(build_map[0].argval, 1) - - def f(d, t): - d[1] = t - d[40] = t + 1 - - t = torch.randn(3, 4) - d = {1: t} - d_opt = d.copy() - f(d, t) - - with self.register_bytecode_hook(hook): - opt_f = torch._dynamo.optimize("eager", nopython=True)(f) - opt_f(d_opt, t) - self.assertEqual(d, d_opt) - - def test_ConstDict_pop_reconstruct(self): - """ - If something is pop'ed from the dict, we reconstruct everything - """ - - def hook(instructions: List[dis.Instruction]): - build_map = _filter_instructions(instructions, "BUILD_MAP") - self.assertEqual(len(build_map), 1) - # reconstruct everything - self.assertEqual(build_map[0].argval, 2) - - def f(d, t): - d.pop(2) - d[40] = t + 1 - - t = torch.randn(3, 4) - d = {1: t, 2: t + 1} - d_opt = d.copy() - - f(d, t) - - with self.register_bytecode_hook(hook): - opt_f = torch._dynamo.optimize("eager", nopython=True)(f) - opt_f(d_opt, t) - self.assertEqual(d, d_opt) - - @unittest.expectedFailure - def test_ConstDict_popitem_reconstruct(self): - """ - If something is pop'ed from the dict, we reconstruct everything - """ - - def hook(instructions: List[dis.Instruction]): - build_map = _filter_instructions(instructions, "BUILD_MAP") - self.assertEqual(len(build_map), 1) - # reconstruct everything - self.assertEqual(build_map[0].argval, 1) - - def f(d, t): - d.popitem() - - t = torch.randn(3, 4) - d = {1: t, 2: t + 1} - d_opt = d.copy() - - f(d, t) - - with self.register_bytecode_hook(hook): - opt_f = torch._dynamo.optimize("eager", nopython=True)(f) - opt_f(d_opt, t) - self.assertEqual(d, d_opt) - - def test_ConstDict_popitem_reconstruct_graph_break(self): - """ - If something is pop'ed from the dict, we reconstruct everything. - Calling dict.popitem will graph break. - """ - - def f(d, t): - d.popitem() - - t = torch.randn(3, 4) - d = {1: t, 2: t + 1} - d_opt = d.copy() - - f(d, t) - - opt_f = torch.compile(backend="eager")(f) - opt_f(d_opt, t) - self.assertEqual(d, d_opt) - - def test_ConstDict_del_reconstruct(self): - """ - If something is deleted from the dict, we reconstruct everything - """ - - def hook(instructions: List[dis.Instruction]): - build_map = _filter_instructions(instructions, "BUILD_MAP") - self.assertEqual(len(build_map), 1) - # reconstruct everything - self.assertEqual(build_map[0].argval, 2) - - def f(d, t): - del d[2] - d[40] = t + 1 - - t = torch.randn(3, 4) - d = {1: t, 2: t + 1} - d_opt = d.copy() - - f(d, t) - - with self.register_bytecode_hook(hook): - opt_f = torch._dynamo.optimize("eager", nopython=True)(f) - opt_f(d_opt, t) - self.assertEqual(d, d_opt) - - def test_ConstDict_get_reconstruct(self): - """ - dict.get shouldn't affect anything - """ - - def hook(instructions: List[dis.Instruction]): - build_map = _filter_instructions(instructions, "BUILD_MAP") - self.assertEqual(len(build_map), 1) - self.assertEqual(build_map[0].argval, 1) - load_const = _filter_instructions(instructions, "LOAD_CONST") - self.assertNotIn(123, load_const) - - def f(d, t): - d[456] = d.get(456) + t - - t = torch.randn(3, 4) - d = {123: t, 456: t + 1} - d_opt = d.copy() - - f(d, t) - - with self.register_bytecode_hook(hook): - opt_f = torch._dynamo.optimize("eager", nopython=True)(f) - opt_f(d_opt, t) - self.assertEqual(d, d_opt) - - def test_ConstDict_clear_reconstruct(self): - """ - If dict.clear() is used, we reconstruct everything - """ - - def hook(instructions: List[dis.Instruction]): - build_map = _filter_instructions(instructions, "BUILD_MAP") - self.assertEqual(len(build_map), 1) - # reconstruct everything - self.assertEqual(build_map[0].argval, 1) - - def f(d, t): - d.clear() - d[3] = t + 3 - - t = torch.randn(3, 4) - d = {1: t, 2: t + 1} - d_opt = d.copy() - - f(d, t) - - with self.register_bytecode_hook(hook): - opt_f = torch._dynamo.optimize("eager", nopython=True)(f) - opt_f(d_opt, t) - self.assertEqual(d, d_opt) - - def test_create_dict_reconstruct(self): - """ - If dict is created inside a function, everything needs to be reconstructed - """ - - def hook(instructions: List[dis.Instruction]): - build_map = _filter_instructions(instructions, "BUILD_MAP") - self.assertEqual(len(build_map), 1) - # reconstruct everything - self.assertEqual(build_map[0].argval, 2) - - def f(t): - return {1: t, 2: t + 1} - - t = torch.randn(3, 4) - d = f(t) - - with self.register_bytecode_hook(hook): - opt_f = torch._dynamo.optimize("eager", nopython=True)(f) - d_opt = opt_f(t) - self.assertEqual(d, d_opt) - - def test_functional_call_reconstruct(self): - """ - PyTorch shouldn't codegen any key/value when functional_call is used - """ - - def hook(instructions: List[dis.Instruction]): - build_map = _filter_instructions(instructions, "BUILD_MAP") - self.assertEqual(len(build_map), 1) - # don't reconstruct anything - self.assertEqual(build_map[0].argval, 0) - - m = torch.nn.Linear(3, 3) - new_bias = torch.randn(3) - new_weight = torch.randn(3, 3) - - def fn(new_weight, new_bias, x): - return torch.func.functional_call( - m, {"weight": new_weight, "bias": new_bias}, x - ) - - x = torch.randn(2, 3) - expected = torch.nn.functional.linear(x, new_weight, new_bias) - with self.register_bytecode_hook(hook): - opt_fn = torch._dynamo.optimize("eager", nopython=True)(fn) - got = opt_fn(new_weight, new_bias, x) - self.assertEqual(expected, got) - - -if __name__ == "__main__": - from torch._dynamo.test_case import run_tests - - run_tests() diff --git a/torch/_dynamo/side_effects.py b/torch/_dynamo/side_effects.py index 3b83eb17d407f4..01891dc9e196db 100644 --- a/torch/_dynamo/side_effects.py +++ b/torch/_dynamo/side_effects.py @@ -590,36 +590,21 @@ def codegen_update_mutated(self, cg: PyCodegen): ) elif isinstance(var, variables.ConstDictVariable): - # Reconstruct works as follow: - # (1) codegen(...) each pair of key/value - # (2) create a new dictionary with the pairs of key/values above - # (3) clear the original dictionary - # + only if a key was removed from the input dict - # (4) update the original dictionary with the dict created in (2) - cg(var.mutable_local.source) # type: ignore[attr-defined] cg.load_method("update") cg(var, allow_cache=False) - if var.should_reconstruct_all: - cg(var.mutable_local.source) # type: ignore[attr-defined] - cg.load_method("clear") + cg(var.mutable_local.source) # type: ignore[attr-defined] + cg.load_method("clear") suffixes.append( [ + *create_call_method(0), # clear + create_instruction("POP_TOP"), *create_call_method(1), # update create_instruction("POP_TOP"), ] ) - - if var.should_reconstruct_all: - suffixes.append( - [ - *create_call_method(0), # clear - create_instruction("POP_TOP"), - ] - ) - elif isinstance( var, variables.torch_function.TorchFunctionModeStackVariable ): diff --git a/torch/_dynamo/variables/builder.py b/torch/_dynamo/variables/builder.py index 41c5f4c004df98..193e227267ec3b 100644 --- a/torch/_dynamo/variables/builder.py +++ b/torch/_dynamo/variables/builder.py @@ -640,9 +640,7 @@ def build_key_value(i, k, v): source=self.source, ) else: - result = ConstDictVariable( - result, user_cls=type(value), source=self.source - ) + result = ConstDictVariable(result, type(value), source=self.source) return self.set_source_and_track_mutable(value, result) elif isinstance(value, torch.nn.Module): diff --git a/torch/_dynamo/variables/dicts.py b/torch/_dynamo/variables/dicts.py index 4626f68bb44c06..e8323ec6f70d76 100644 --- a/torch/_dynamo/variables/dicts.py +++ b/torch/_dynamo/variables/dicts.py @@ -14,7 +14,7 @@ from ..eval_frame import skip_code from ..exc import raise_observed_exception, unimplemented from ..guards import GuardBuilder, install_guard -from ..source import AttrSource, GetItemSource, is_from_local_source +from ..source import AttrSource, GetItemSource from ..utils import dict_keys, dict_values, istype, specialize_symnode from .base import MutableLocal, VariableTracker from .constant import ConstantVariable @@ -128,18 +128,8 @@ def __eq__(self, other: "ConstDictVariable._HashableTracker") -> bool: return Hashable._eq_impl(self.underlying_value, other) def __init__( - self, - items: Dict[VariableTracker, VariableTracker], - user_cls=dict, - **kwargs, + self, items: Dict[VariableTracker, VariableTracker], user_cls=dict, **kwargs ) -> None: - # .clone() pass these arguments in kwargs but they're recreated a few - # lines below - if "original_items" in kwargs: - kwargs.pop("original_items") - if "should_reconstruct_all" in kwargs: - kwargs.pop("should_reconstruct_all") - super().__init__(**kwargs) Hashable = ConstDictVariable._HashableTracker @@ -155,10 +145,6 @@ def make_hashable(key): return key if isinstance(key, Hashable) else Hashable(key) self.items = {make_hashable(x): v for x, v in items.items()} - # need to reconstruct everything if the dictionary is an intermediate value - # or if a pop/delitem was executed - self.should_reconstruct_all = not is_from_local_source(self.source) - self.original_items = items.copy() self.user_cls = user_cls def as_proxy(self): @@ -203,9 +189,6 @@ def len(self): ] ) - def _maybe_realize(self, item): - return item.realize() if item else item - def reconstruct(self, codegen): # instructions to load collections.OrderedDict if necessary if self.user_cls is collections.OrderedDict: @@ -218,29 +201,20 @@ def reconstruct(self, codegen): ) ) # instructions to build the dict keys and values - num_args = 0 for key, value in self.items.items(): - # We can safely call realize() here as it won't introduce any new guards - is_new_item = ( - self._maybe_realize(self.original_items.get(key.vt)) != value.realize() - ) - - if is_new_item or self.should_reconstruct_all: - codegen(key.vt) - codegen(value) - num_args += 1 - + codegen(key.vt) + codegen(value) # BUILD_MAP and calling collections.OrderedDict if necessary if self.user_cls is collections.OrderedDict: codegen.extend_output( [ - create_instruction("BUILD_MAP", arg=num_args), + create_instruction("BUILD_MAP", arg=len(self.items)), *create_call_function(1, False), ] ) # BUILD_MAP only if user_cls is dict else: - codegen.append_output(create_instruction("BUILD_MAP", arg=num_args)) + codegen.append_output(create_instruction("BUILD_MAP", arg=len(self.items))) def getitem_const_raise_exception_if_absent( self, tx: "InstructionTranslator", arg: VariableTracker @@ -314,7 +288,6 @@ def call_method( self.items[Hashable(args[0])] = args[1] return ConstantVariable.create(None) elif name == "__delitem__" and arg_hashable and self.mutable_local: - self.should_reconstruct_all = True tx.output.side_effects.mutation(self) self.items.__delitem__(Hashable(args[0])) return ConstantVariable.create(None) @@ -325,11 +298,9 @@ def call_method( else: return args[1] elif name == "pop" and arg_hashable and self.mutable_local: - self.should_reconstruct_all = True tx.output.side_effects.mutation(self) return self.items.pop(Hashable(args[0])) elif name == "clear": - self.should_reconstruct_all = True tx.output.side_effects.mutation(self) self.items.clear() return ConstantVariable.create(None) From b20e5aac5f2512b087eb32e773a483aef8147073 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Tue, 17 Sep 2024 13:39:07 +0000 Subject: [PATCH 0962/1018] Revert "Add decomposition for squeeze_copy (#130941)" This reverts commit c33b0580e6a702be0cd5be691b3b465da012aa34. Reverted https://github.com/pytorch/pytorch/pull/130941 on behalf of https://github.com/jeanschmidt due to Need to revert in order to be able to revert https://github.com/pytorch/pytorch/pull/130944, after fixing any merge conflicts, feel free to merge it back ([comment](https://github.com/pytorch/pytorch/pull/130941#issuecomment-2355831480)) --- test/distributed/_tensor/test_dtensor_ops.py | 1 - ...asDecompTest.test_has_decomposition.expect | 6 +++++ test/functorch/test_ops.py | 3 --- test/functorch/test_vmap.py | 1 - test/test_mps.py | 1 - tools/autograd/gen_variable_type.py | 1 - torch/_decomp/__init__.py | 1 - torch/_refs/__init__.py | 2 -- .../_internal/common_methods_invocations.py | 25 ------------------- 9 files changed, 6 insertions(+), 35 deletions(-) diff --git a/test/distributed/_tensor/test_dtensor_ops.py b/test/distributed/_tensor/test_dtensor_ops.py index d76fccee216d1f..c6fd0e01e5ffd9 100644 --- a/test/distributed/_tensor/test_dtensor_ops.py +++ b/test/distributed/_tensor/test_dtensor_ops.py @@ -427,7 +427,6 @@ def wrapped(fn): xfail("special.xlog1py"), xfail("special.zeta"), xfail("squeeze", "multiple"), - xfail("squeeze_copy"), xfail("signal.windows.bartlett"), xfail("signal.windows.blackman"), xfail("signal.windows.cosine"), diff --git a/test/expect/HasDecompTest.test_has_decomposition.expect b/test/expect/HasDecompTest.test_has_decomposition.expect index 2ecc1aef4c3954..357a5afefd4d3e 100644 --- a/test/expect/HasDecompTest.test_has_decomposition.expect +++ b/test/expect/HasDecompTest.test_has_decomposition.expect @@ -1285,6 +1285,12 @@ aten::split_copy.Tensor_out aten::squeeze_ aten::squeeze_.dim aten::squeeze_.dims +aten::squeeze_copy +aten::squeeze_copy.dim +aten::squeeze_copy.dim_out +aten::squeeze_copy.dims +aten::squeeze_copy.dims_out +aten::squeeze_copy.out aten::sspaddmm.out aten::t_ aten::to_mkldnn diff --git a/test/functorch/test_ops.py b/test/functorch/test_ops.py index 1c94fae00c98ce..9b04fb618ce616 100644 --- a/test/functorch/test_ops.py +++ b/test/functorch/test_ops.py @@ -1417,7 +1417,6 @@ def test_vmapjvpall(self, device, dtype, op): xfail("masked.cumprod", ""), xfail("permute_copy"), xfail("renorm"), # hit vmap fallback, which is disabled - xfail("squeeze_copy"), xfail("t_copy"), xfail("transpose_copy"), xfail("unsqueeze_copy"), @@ -1485,7 +1484,6 @@ def test(): xfail("put"), xfail("quantile"), xfail("renorm"), - xfail("squeeze_copy"), xfail("take"), xfail("tensor_split"), xfail("to_sparse"), @@ -1546,7 +1544,6 @@ def test(): xfail( "index_fill" ), # aten::_unique hit the vmap fallback which is currently disabled - xfail("squeeze_copy"), xfail("t_copy"), xfail("transpose_copy"), xfail("unsqueeze_copy"), diff --git a/test/functorch/test_vmap.py b/test/functorch/test_vmap.py index 3f369226183a9f..329ae0013bf293 100644 --- a/test/functorch/test_vmap.py +++ b/test/functorch/test_vmap.py @@ -4439,7 +4439,6 @@ def test_vmap_exhaustive(self, device, dtype, op): xfail("put"), xfail("quantile"), xfail("renorm"), - xfail("squeeze_copy"), xfail("resize_as_"), xfail("take"), xfail("tensor_split"), diff --git a/test/test_mps.py b/test/test_mps.py index d6e19e9eeb92e2..20dae361ba8fcf 100644 --- a/test/test_mps.py +++ b/test/test_mps.py @@ -337,7 +337,6 @@ def mps_ops_modifier(ops): 'split_with_sizes_copy', 'splitlist_args', 'squeeze', - 'squeeze_copy', 'squeezemultiple', 'sub', 'svd', diff --git a/tools/autograd/gen_variable_type.py b/tools/autograd/gen_variable_type.py index c456b127168baa..d1e9245cd72f9c 100644 --- a/tools/autograd/gen_variable_type.py +++ b/tools/autograd/gen_variable_type.py @@ -201,7 +201,6 @@ "permute", "permute_copy", "squeeze", - "squeeze_copy", "unsqueeze", "unsqueeze_copy", "resize", diff --git a/torch/_decomp/__init__.py b/torch/_decomp/__init__.py index 34e8935be1f22c..c21ee837471178 100644 --- a/torch/_decomp/__init__.py +++ b/torch/_decomp/__init__.py @@ -621,7 +621,6 @@ def _core_aten_decompositions_post_autograd() -> ( aten.special_xlog1py, aten.split.Tensor, aten.split_with_sizes_copy, - aten.squeeze_copy, aten.squeeze.default, aten.squeeze.dim, aten.std, diff --git a/torch/_refs/__init__.py b/torch/_refs/__init__.py index d6ebf7e55c2394..41b9031a304227 100644 --- a/torch/_refs/__init__.py +++ b/torch/_refs/__init__.py @@ -285,7 +285,6 @@ "stack", "swap_axes", # alias for transpose "squeeze", - "squeeze_copy", "t", "t_copy", "T", @@ -6338,7 +6337,6 @@ def select_scatter(x: TensorLikeType, src: TensorLikeType, dim: int, index: int) # no sparse support. See narrow_copy_sparse in core. narrow_copy = _make_copy_from_view(aten.narrow) permute_copy = _make_copy_from_view(aten.permute) -squeeze_copy = _make_copy_from_view(aten.squeeze) t_copy = _make_copy_from_view(aten.t) transpose_copy = _make_copy_from_view(aten.transpose) unsqueeze_copy = _make_copy_from_view(aten.unsqueeze) diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py index cc61f8bbc69863..55b3cd3c542171 100644 --- a/torch/testing/_internal/common_methods_invocations.py +++ b/torch/testing/_internal/common_methods_invocations.py @@ -19428,26 +19428,6 @@ def sample_inputs_alias_copy(op_info, device, dtype, requires_grad, **kwargs): # https://github.com/pytorch/pytorch/issues/66357 check_batched_forward_grad=False, sample_inputs_func=sample_inputs_squeeze_multiple), - OpInfo('squeeze_copy', - ref=_squeeze_ref, - dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.chalf), - supports_out=True, - assert_autodiffed=True, - supports_forward_ad=True, - supports_fwgrad_bwgrad=True, - # vmap does not support inplace views - check_inplace_batched_forward_grad=False, - # https://github.com/pytorch/pytorch/issues/66357 - check_batched_forward_grad=False, - sample_inputs_func=sample_inputs_squeeze, - skips=( - DecorateInfo( - unittest.expectedFailure, - 'TestJit', - 'test_variant_consistency_jit', - dtypes=(torch.float32,), - ), - )), UnaryUfuncInfo( 'fill', ref=_fill_np, @@ -23842,11 +23822,6 @@ def sample_inputs_alias_copy(op_info, device, dtype, requires_grad, **kwargs): "_refs.squeeze", torch_opinfo_name="squeeze", ), - PythonRefInfo( - "_refs.squeeze_copy", - torch_opinfo_name="squeeze_copy", - supports_out=True, - ), PythonRefInfo( "_refs.squeeze", torch_opinfo_name="squeeze", From da39dbe4459efe866b892f7dbe7354f06db47754 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Tue, 17 Sep 2024 13:42:55 +0000 Subject: [PATCH 0963/1018] Revert "Add decomposition for permute_copy (#130944)" This reverts commit ab9a7eadd34aee59fc67e29237610b7562cc4ff0. Reverted https://github.com/pytorch/pytorch/pull/130944 on behalf of https://github.com/jeanschmidt due to Broke internal signal executorch.backends.xnnpack.test.ops.permute.TestPermute, more details on D62737086. @eellison could you please help get this PR merged to main? ([comment](https://github.com/pytorch/pytorch/pull/130944#issuecomment-2355846394)) --- test/distributed/_tensor/test_dtensor_ops.py | 1 - ...HasDecompTest.test_has_decomposition.expect | 2 ++ test/functorch/test_ops.py | 2 -- test/functorch/test_vmap.py | 1 - test/test_mps.py | 1 - tools/autograd/gen_variable_type.py | 1 - torch/_decomp/__init__.py | 1 - torch/_refs/__init__.py | 2 -- .../_internal/common_methods_invocations.py | 18 ------------------ 9 files changed, 2 insertions(+), 27 deletions(-) diff --git a/test/distributed/_tensor/test_dtensor_ops.py b/test/distributed/_tensor/test_dtensor_ops.py index c6fd0e01e5ffd9..532bd3facae554 100644 --- a/test/distributed/_tensor/test_dtensor_ops.py +++ b/test/distributed/_tensor/test_dtensor_ops.py @@ -370,7 +370,6 @@ def wrapped(fn): xfail("ormqr"), xfail("ones"), xfail("pca_lowrank"), - xfail("permute_copy"), xfail("pinverse"), xfail("polar"), xfail("put"), diff --git a/test/expect/HasDecompTest.test_has_decomposition.expect b/test/expect/HasDecompTest.test_has_decomposition.expect index 357a5afefd4d3e..a759903d065591 100644 --- a/test/expect/HasDecompTest.test_has_decomposition.expect +++ b/test/expect/HasDecompTest.test_has_decomposition.expect @@ -1014,6 +1014,8 @@ aten::ones.names_out aten::ones.out aten::ormqr aten::ormqr.out +aten::permute_copy +aten::permute_copy.out aten::poisson aten::poisson.out aten::polygamma diff --git a/test/functorch/test_ops.py b/test/functorch/test_ops.py index 9b04fb618ce616..5cd18a072d8daf 100644 --- a/test/functorch/test_ops.py +++ b/test/functorch/test_ops.py @@ -1415,7 +1415,6 @@ def test_vmapjvpall(self, device, dtype, op): xfail("nn.functional.dropout3d", ""), xfail("as_strided_scatter", ""), xfail("masked.cumprod", ""), - xfail("permute_copy"), xfail("renorm"), # hit vmap fallback, which is disabled xfail("t_copy"), xfail("transpose_copy"), @@ -1480,7 +1479,6 @@ def test(): xfail("masked_select"), xfail("nanquantile"), xfail("ormqr"), - xfail("permute_copy"), xfail("put"), xfail("quantile"), xfail("renorm"), diff --git a/test/functorch/test_vmap.py b/test/functorch/test_vmap.py index 329ae0013bf293..1bfc31fe521bb5 100644 --- a/test/functorch/test_vmap.py +++ b/test/functorch/test_vmap.py @@ -4470,7 +4470,6 @@ def test_vmap_exhaustive(self, device, dtype, op): xfail("histc"), xfail("as_strided"), xfail("as_strided_copy"), - xfail("permute_copy"), xfail("t_copy"), xfail("unsqueeze_copy"), xfail("istft"), diff --git a/test/test_mps.py b/test/test_mps.py index 20dae361ba8fcf..32adede8dddca0 100644 --- a/test/test_mps.py +++ b/test/test_mps.py @@ -318,7 +318,6 @@ def mps_ops_modifier(ops): 'ones', 'outer', 'permute', - 'permute_copy', 'positive', 'randn', 'ravel', diff --git a/tools/autograd/gen_variable_type.py b/tools/autograd/gen_variable_type.py index d1e9245cd72f9c..fa6c578dea04a2 100644 --- a/tools/autograd/gen_variable_type.py +++ b/tools/autograd/gen_variable_type.py @@ -199,7 +199,6 @@ "transpose", "transpose_copy", "permute", - "permute_copy", "squeeze", "unsqueeze", "unsqueeze_copy", diff --git a/torch/_decomp/__init__.py b/torch/_decomp/__init__.py index c21ee837471178..a22289b75c4012 100644 --- a/torch/_decomp/__init__.py +++ b/torch/_decomp/__init__.py @@ -569,7 +569,6 @@ def _core_aten_decompositions_post_autograd() -> ( aten.norm, aten.ones, aten.ones_like, - aten.permute_copy, aten.pixel_shuffle, aten.pixel_unshuffle, aten._prelu_kernel, diff --git a/torch/_refs/__init__.py b/torch/_refs/__init__.py index 41b9031a304227..9b4d10cd5a6ad4 100644 --- a/torch/_refs/__init__.py +++ b/torch/_refs/__init__.py @@ -274,7 +274,6 @@ "native_group_norm", "native_layer_norm", "permute", - "permute_copy", "ravel", "repeat", "reshape", @@ -6336,7 +6335,6 @@ def select_scatter(x: TensorLikeType, src: TensorLikeType, dim: int, index: int) # TODO: This must return a sparse tensor if the input is sparse, but refs have # no sparse support. See narrow_copy_sparse in core. narrow_copy = _make_copy_from_view(aten.narrow) -permute_copy = _make_copy_from_view(aten.permute) t_copy = _make_copy_from_view(aten.t) transpose_copy = _make_copy_from_view(aten.transpose) unsqueeze_copy = _make_copy_from_view(aten.unsqueeze) diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py index 55b3cd3c542171..34ab66cd8077a5 100644 --- a/torch/testing/_internal/common_methods_invocations.py +++ b/torch/testing/_internal/common_methods_invocations.py @@ -16851,19 +16851,6 @@ def sample_inputs_alias_copy(op_info, device, dtype, requires_grad, **kwargs): supports_varargs=True, sample_inputs_func=sample_inputs_permute, reference_inputs_func=reference_inputs_permute), - OpInfo('permute_copy', - dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.chalf), - supports_out=True, - assert_autodiffed=True, - assert_jit_shape_analysis=True, - supports_forward_ad=True, - supports_fwgrad_bwgrad=True, - supports_varargs=False, # torch.permute is also not varargs - sample_inputs_func=sample_inputs_permute, - reference_inputs_func=reference_inputs_permute, - skips=( - DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit', dtypes=(torch.float32,)), - )), BinaryUfuncInfo('pow', dtypes=all_types_and_complex_and(torch.half, torch.bfloat16), dtypesIfCUDA=all_types_and_complex_and(torch.half, torch.bfloat16, torch.chalf), @@ -23767,11 +23754,6 @@ def sample_inputs_alias_copy(op_info, device, dtype, requires_grad, **kwargs): "_refs.permute", torch_opinfo_name="permute", ), - PythonRefInfo( - "_refs.permute_copy", - torch_opinfo_name="permute_copy", - supports_out=True, - ), ElementwiseUnaryPythonRefInfo( "_refs.rad2deg", torch_opinfo_name="rad2deg", From df1bf35ffe84e1743334b8ac4a1997582680725a Mon Sep 17 00:00:00 2001 From: "Edward Z. Yang" Date: Mon, 16 Sep 2024 20:32:17 -0700 Subject: [PATCH 0964/1018] Don't run reshape pattern match on dynamic shape size tensor (#136100) Signed-off-by: Edward Z. Yang Pull Request resolved: https://github.com/pytorch/pytorch/pull/136100 Approved by: https://github.com/mengluy0125 --- test/inductor/test_split_cat_fx_passes.py | 43 ++++++++++++++--------- torch/_inductor/fx_passes/split_cat.py | 6 ++++ 2 files changed, 32 insertions(+), 17 deletions(-) diff --git a/test/inductor/test_split_cat_fx_passes.py b/test/inductor/test_split_cat_fx_passes.py index 2c32b01be2d60f..3e775ef2de8e4b 100644 --- a/test/inductor/test_split_cat_fx_passes.py +++ b/test/inductor/test_split_cat_fx_passes.py @@ -27,7 +27,12 @@ def patch(f): class TestSplitCatFxPasses(TestCase): - @patch + @torch._inductor.config.patch( + pre_grad_fusion_options={ + "normalization_pass": {}, + }, + post_grad_fusion_options={}, + ) def test_split_normalization(self): def arg_only(x): return [torch.relu(s) for s in torch.split(x, 2, 1)] @@ -76,27 +81,31 @@ def unequal_split_cm(x): def cm_with_list(x): return [torch.relu(s) for s in x.split([16, 16], dim=-1)] + def normalize_reshape_with_dynamic_shape(x): + return x.reshape(4, 16) + args = [ torch.randn(2, 32), ] - for fn, expected_split_norm_count in [ - (arg_only, 1), - (arg_only_dim0, 1), - (kwarg1, 1), - (kwarg2, 1), - (kwarg3, 1), - (list_replace, 0), - (multi_split, 17), - (unequal_split, 1), - (arg_only_cm, 1), - (kwarg1_cm, 1), - (kwarg2_cm, 1), - (multi_split_cm, 17), - (unequal_split_cm, 1), - (cm_with_list, 1), + for fn, dynamic, expected_split_norm_count in [ + (arg_only, False, 1), + (arg_only_dim0, False, 1), + (kwarg1, False, 1), + (kwarg2, False, 1), + (kwarg3, False, 1), + (list_replace, False, 0), + (multi_split, False, 17), + (unequal_split, False, 1), + (arg_only_cm, False, 1), + (kwarg1_cm, False, 1), + (kwarg2_cm, False, 1), + (multi_split_cm, False, 17), + (unequal_split_cm, False, 1), + (cm_with_list, False, 1), + (normalize_reshape_with_dynamic_shape, True, 0), ]: expected = fn(*args) - actual = torch.compile(fn)(*args) + actual = torch.compile(fn, dynamic=dynamic)(*args) torch.testing.assert_close(actual, expected) self.assertEqual( diff --git a/torch/_inductor/fx_passes/split_cat.py b/torch/_inductor/fx_passes/split_cat.py index b098781677a84e..f850ecf6008c94 100644 --- a/torch/_inductor/fx_passes/split_cat.py +++ b/torch/_inductor/fx_passes/split_cat.py @@ -449,6 +449,12 @@ def normalize_reshape_default(match: Match, *args, **kwargs): return reshape_input = get_arg_value(reshape_node, 0) + from torch.fx.experimental.symbolic_shapes import free_symbols + + if free_symbols(reshape_node.meta["example_value"].shape): + log.debug("dynamic shape not supported: %s", reshape_node) + return + with match.graph.inserting_after(reshape_node): new_reshape_node = match.graph.call_function( torch.reshape, From 97108d037c623d5b81221f8273cddf4e19193ab7 Mon Sep 17 00:00:00 2001 From: Mauricio Villegas <5780272+mauvilsa@users.noreply.github.com> Date: Tue, 17 Sep 2024 15:45:15 +0000 Subject: [PATCH 0965/1018] Add back optim type hints that were lost when *.pyi files were removed (#136185) When stub files (`*.pyi`) were removed from `optim` (#125556, #125452), some types that existed are no longer available. This pull request adds them back. Just for reference, these types are used in `pytorch-lightning`'s `LightningCLI`. Command line interfaces are created automatically, and having type hints make them nicer. Pull Request resolved: https://github.com/pytorch/pytorch/pull/136185 Approved by: https://github.com/janeyx99 --- torch/optim/lr_scheduler.py | 89 ++++++++++++++++++++----------------- torch/optim/rmsprop.py | 4 +- torch/optim/sgd.py | 5 ++- 3 files changed, 53 insertions(+), 45 deletions(-) diff --git a/torch/optim/lr_scheduler.py b/torch/optim/lr_scheduler.py index 820a57f3e0b756..f0a10efefd12f5 100644 --- a/torch/optim/lr_scheduler.py +++ b/torch/optim/lr_scheduler.py @@ -92,7 +92,10 @@ class LRScheduler: _get_lr_called_within_step: bool = False def __init__( - self, optimizer: Optimizer, last_epoch=-1, verbose="deprecated" + self, + optimizer: Optimizer, + last_epoch: int = -1, + verbose="deprecated", ): # noqa: D107 # Attach optimizer if not isinstance(optimizer, Optimizer): @@ -319,7 +322,7 @@ def __init__( self, optimizer: Optimizer, lr_lambda: Union[Callable[[int], float], List[Callable[[int], float]]], - last_epoch=-1, + last_epoch: int = -1, verbose="deprecated", ): # noqa: D107 self.optimizer = optimizer @@ -419,7 +422,7 @@ def __init__( self, optimizer: Optimizer, lr_lambda: Union[Callable[[int], float], List[Callable[[int], float]]], - last_epoch=-1, + last_epoch: int = -1, verbose="deprecated", ): # noqa: D107 self.optimizer = optimizer @@ -523,8 +526,8 @@ def __init__( self, optimizer: Optimizer, step_size: int, - gamma=0.1, - last_epoch=-1, + gamma: float = 0.1, + last_epoch: int = -1, verbose="deprecated", ): # noqa: D107 self.step_size = step_size @@ -582,8 +585,8 @@ def __init__( self, optimizer: Optimizer, milestones: Iterable[int], - gamma=0.1, - last_epoch=-1, + gamma: float = 0.1, + last_epoch: int = -1, verbose="deprecated", ): # noqa: D107 self.milestones = Counter(milestones) @@ -648,9 +651,9 @@ class ConstantLR(LRScheduler): def __init__( self, optimizer: Optimizer, - factor=1.0 / 3, - total_iters=5, - last_epoch=-1, + factor: float = 1.0 / 3, + total_iters: int = 5, + last_epoch: int = -1, verbose="deprecated", ): # noqa: D107 if factor > 1.0 or factor < 0: @@ -726,10 +729,10 @@ class LinearLR(LRScheduler): def __init__( self, optimizer: Optimizer, - start_factor=1.0 / 3, - end_factor=1.0, - total_iters=5, - last_epoch=-1, + start_factor: float = 1.0 / 3, + end_factor: float = 1.0, + total_iters: int = 5, + last_epoch: int = -1, verbose="deprecated", ): # noqa: D107 if start_factor > 1.0 or start_factor <= 0: @@ -803,7 +806,11 @@ class ExponentialLR(LRScheduler): """ def __init__( - self, optimizer: Optimizer, gamma: float, last_epoch=-1, verbose="deprecated" + self, + optimizer: Optimizer, + gamma: float, + last_epoch: int = -1, + verbose="deprecated", ): # noqa: D107 self.gamma = gamma super().__init__(optimizer, last_epoch, verbose) @@ -859,7 +866,7 @@ def __init__( optimizer: Optimizer, schedulers: List[LRScheduler], milestones: List[int], - last_epoch=-1, + last_epoch: int = -1, verbose="deprecated", ): # noqa: D107 if len(schedulers) < 1: @@ -992,9 +999,9 @@ class PolynomialLR(LRScheduler): def __init__( self, optimizer: Optimizer, - total_iters=5, - power=1.0, - last_epoch=-1, + total_iters: int = 5, + power: float = 1.0, + last_epoch: int = -1, verbose="deprecated", ): # noqa: D107 self.total_iters = total_iters @@ -1074,8 +1081,8 @@ def __init__( self, optimizer: Optimizer, T_max: int, - eta_min=0.0, - last_epoch=-1, + eta_min: float = 0.0, + last_epoch: int = -1, verbose="deprecated", ): # noqa: D107 self.T_max = T_max @@ -1288,13 +1295,13 @@ def __init__( self, optimizer: Optimizer, mode: Literal["min", "max"] = "min", - factor=0.1, - patience=10, - threshold=1e-4, + factor: float = 0.1, + patience: int = 10, + threshold: float = 1e-4, threshold_mode: Literal["rel", "abs"] = "rel", - cooldown=0, + cooldown: int = 0, min_lr: Union[List[float], float] = 0, - eps=1e-8, + eps: float = 1e-8, verbose="deprecated", ): # noqa: D107 if factor >= 1.0: @@ -1525,16 +1532,16 @@ def __init__( optimizer: Optimizer, base_lr: Union[float, List[float]], max_lr: Union[float, List[float]], - step_size_up=2000, + step_size_up: int = 2000, step_size_down: Optional[int] = None, mode: Literal["triangular", "triangular2", "exp_range"] = "triangular", - gamma=1.0, + gamma: float = 1.0, scale_fn: Optional[Callable[[float], float]] = None, scale_mode: Literal["cycle", "iterations"] = "cycle", - cycle_momentum=True, - base_momentum=0.8, - max_momentum=0.9, - last_epoch=-1, + cycle_momentum: bool = True, + base_momentum: float = 0.8, + max_momentum: float = 0.9, + last_epoch: int = -1, verbose="deprecated", ): # noqa: D107 # Attach optimizer @@ -1740,9 +1747,9 @@ def __init__( self, optimizer: Optimizer, T_0: int, - T_mult=1, - eta_min=0.0, - last_epoch=-1, + T_mult: int = 1, + eta_min: float = 0.0, + last_epoch: int = -1, verbose="deprecated", ): # noqa: D107 if T_0 <= 0 or not isinstance(T_0, int): @@ -1959,15 +1966,15 @@ def __init__( total_steps: Optional[int] = None, epochs: Optional[int] = None, steps_per_epoch: Optional[int] = None, - pct_start=0.3, + pct_start: float = 0.3, anneal_strategy: Literal["cos", "linear"] = "cos", - cycle_momentum=True, + cycle_momentum: bool = True, base_momentum: Union[float, List[float]] = 0.85, max_momentum: Union[float, List[float]] = 0.95, - div_factor=25.0, - final_div_factor=1e4, - three_phase=False, - last_epoch=-1, + div_factor: float = 25.0, + final_div_factor: float = 1e4, + three_phase: bool = False, + last_epoch: int = -1, verbose="deprecated", ): # noqa: D107 # Validate optimizer diff --git a/torch/optim/rmsprop.py b/torch/optim/rmsprop.py index 9b77ad7fe3eeaf..876f4e1d697bf2 100644 --- a/torch/optim/rmsprop.py +++ b/torch/optim/rmsprop.py @@ -34,8 +34,8 @@ def __init__( eps: float = 1e-8, weight_decay: float = 0, momentum: float = 0, - centered=False, - capturable=False, + centered: bool = False, + capturable: bool = False, foreach: Optional[bool] = None, maximize: bool = False, differentiable: bool = False, diff --git a/torch/optim/sgd.py b/torch/optim/sgd.py index b270afce4d60a2..46af5ae77537ec 100644 --- a/torch/optim/sgd.py +++ b/torch/optim/sgd.py @@ -15,6 +15,7 @@ _use_grad_for_differentiable, DeviceDict, Optimizer, + ParamsT, ) @@ -24,12 +25,12 @@ class SGD(Optimizer): # noqa: D101 def __init__( self, - params, + params: ParamsT, lr: Union[float, Tensor] = 1e-3, momentum: float = 0, dampening: float = 0, weight_decay: float = 0, - nesterov=False, + nesterov: bool = False, *, maximize: bool = False, foreach: Optional[bool] = None, From 7c1ed259ff1943b55eff89fe14f94883fcdd46d3 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Tue, 17 Sep 2024 15:50:59 +0000 Subject: [PATCH 0966/1018] [ONNX] Fix numpy method to return the correct type (#136162) Previous implementation of the `numpy()` method returns `fp64` when the tensor is `fp32`. This is unexpected but seems to be caused by calling `__array__(dtype=None)` on the numpy array. I updated the implementation to implement the `numpy()` method explicitly and added tests to guard the behavior. This needs to be cherry-picked into torch 2.5 Pull Request resolved: https://github.com/pytorch/pytorch/pull/136162 Approved by: https://github.com/gramalingam, https://github.com/xadupre --- test/onnx/exporter/test_core.py | 75 ++++++++++++++++++++++++++ torch/onnx/_internal/exporter/_core.py | 18 +++++-- 2 files changed, 88 insertions(+), 5 deletions(-) create mode 100644 test/onnx/exporter/test_core.py diff --git a/test/onnx/exporter/test_core.py b/test/onnx/exporter/test_core.py new file mode 100644 index 00000000000000..fc776ecc673b5b --- /dev/null +++ b/test/onnx/exporter/test_core.py @@ -0,0 +1,75 @@ +# Owner(s): ["module: onnx"] +"""Unit tests for the _core module.""" + +from __future__ import annotations + +import numpy as np + +import torch +from torch.onnx._internal.exporter import _core +from torch.testing._internal import common_utils + + +@common_utils.instantiate_parametrized_tests +class TorchTensorTest(common_utils.TestCase): + @common_utils.parametrize( + "dtype, np_dtype", + [ + (torch.bfloat16, np.uint16), + (torch.bool, np.bool_), + (torch.complex128, np.complex128), + (torch.complex64, np.complex64), + (torch.float16, np.float16), + (torch.float32, np.float32), + (torch.float64, np.float64), + (torch.float8_e4m3fn, np.uint8), + (torch.float8_e4m3fnuz, np.uint8), + (torch.float8_e5m2, np.uint8), + (torch.float8_e5m2fnuz, np.uint8), + (torch.int16, np.int16), + (torch.int32, np.int32), + (torch.int64, np.int64), + (torch.int8, np.int8), + (torch.uint16, np.uint16), + (torch.uint32, np.uint32), + (torch.uint64, np.uint64), + (torch.uint8, np.uint8), + ], + ) + def test_numpy_returns_correct_dtype(self, dtype: torch.dtype, np_dtype): + tensor = _core.TorchTensor(torch.tensor([1], dtype=dtype)) + self.assertEqual(tensor.numpy().dtype, np_dtype) + self.assertEqual(tensor.__array__().dtype, np_dtype) + self.assertEqual(np.array(tensor).dtype, np_dtype) + + @common_utils.parametrize( + "dtype", + [ + (torch.bfloat16), + (torch.bool), + (torch.complex128), + (torch.complex64), + (torch.float16), + (torch.float32), + (torch.float64), + (torch.float8_e4m3fn), + (torch.float8_e4m3fnuz), + (torch.float8_e5m2), + (torch.float8_e5m2fnuz), + (torch.int16), + (torch.int32), + (torch.int64), + (torch.int8), + (torch.uint16), + (torch.uint32), + (torch.uint64), + (torch.uint8), + ], + ) + def test_tobytes(self, dtype: torch.dtype): + tensor = _core.TorchTensor(torch.tensor([1], dtype=dtype)) + self.assertEqual(tensor.tobytes(), tensor.numpy().tobytes()) + + +if __name__ == "__main__": + common_utils.run_tests() diff --git a/torch/onnx/_internal/exporter/_core.py b/torch/onnx/_internal/exporter/_core.py index 3f64b34816e187..7d49a654a9c00c 100644 --- a/torch/onnx/_internal/exporter/_core.py +++ b/torch/onnx/_internal/exporter/_core.py @@ -63,6 +63,9 @@ torch.int64: ir.DataType.INT64, torch.int8: ir.DataType.INT8, torch.uint8: ir.DataType.UINT8, + torch.uint16: ir.DataType.UINT16, + torch.uint32: ir.DataType.UINT32, + torch.uint64: ir.DataType.UINT64, } _BLUE = "\033[96m" _END = "\033[0m" @@ -97,11 +100,10 @@ def __init__(self, tensor: torch.Tensor, name: str | None = None): tensor, dtype=_torch_dtype_to_onnx_dtype(tensor.dtype), name=name ) - def __array__(self, dtype: Any = None) -> np.ndarray: - # numpy() calls __array__ in ir.Tensor + def numpy(self) -> np.ndarray: self.raw: torch.Tensor if self.dtype == ir.DataType.BFLOAT16: - return self.raw.view(torch.uint16).numpy(force=True).__array__(dtype) + return self.raw.view(torch.uint16).numpy(force=True) if self.dtype in { ir.DataType.FLOAT8E4M3FN, ir.DataType.FLOAT8E4M3FNUZ, @@ -109,8 +111,14 @@ def __array__(self, dtype: Any = None) -> np.ndarray: ir.DataType.FLOAT8E5M2FNUZ, }: # TODO: Use ml_dtypes - return self.raw.view(torch.uint8).numpy(force=True).__array__(dtype) - return self.raw.numpy(force=True).__array__(dtype) + return self.raw.view(torch.uint8).numpy(force=True) + return self.raw.numpy(force=True) + + def __array__(self, dtype: Any = None, copy: bool | None = None) -> np.ndarray: + del copy # Unused, but needed for the signature + if dtype is None: + return self.numpy() + return self.numpy().__array__(dtype) def tobytes(self) -> bytes: # Implement tobytes to support native PyTorch types so we can use types like bloat16 From e5f33a221b1f87de97feedffb80ca7579e8c7c43 Mon Sep 17 00:00:00 2001 From: Nikhil Gupta Date: Tue, 17 Sep 2024 16:50:17 +0000 Subject: [PATCH 0967/1018] [Fix]: Update CPUINFO submodule to fix support for NON-SVE ARM Hardware (#135857) Regression PR : https://github.com/pytorch/cpuinfo/pull/255 Change-Id: I56cec061072be11ec33ccb661114360b979fc7aa Pull Request resolved: https://github.com/pytorch/pytorch/pull/135857 Approved by: https://github.com/digantdesai, https://github.com/malfet --- third_party/cpuinfo | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/third_party/cpuinfo b/third_party/cpuinfo index fa1c679da8d19e..a5ff6df40ce528 160000 --- a/third_party/cpuinfo +++ b/third_party/cpuinfo @@ -1 +1 @@ -Subproject commit fa1c679da8d19e1d87f20175ae1ec10995cd3dd3 +Subproject commit a5ff6df40ce528721cfc310c7ed43946d77404d5 From 564b00022ba99a8b9bc50d4a2b6dab0c52863bba Mon Sep 17 00:00:00 2001 From: Xintong Hu Date: Tue, 17 Sep 2024 17:26:53 +0000 Subject: [PATCH 0968/1018] [PT2] Port merge_concats_pass to PT2 pre_grad passes (#135527) Summary: as title Test Plan: new UT Differential Revision: D62398390 Pull Request resolved: https://github.com/pytorch/pytorch/pull/135527 Approved by: https://github.com/frank-wei --- torch/_inductor/fx_passes/pre_grad.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/torch/_inductor/fx_passes/pre_grad.py b/torch/_inductor/fx_passes/pre_grad.py index eef058ab0b5aa2..bca3361962b07c 100644 --- a/torch/_inductor/fx_passes/pre_grad.py +++ b/torch/_inductor/fx_passes/pre_grad.py @@ -100,6 +100,10 @@ def stack_to_unsqueeze_pass(graph): return None +def merge_concats_pass(graph): + return None + + @init_once_fakemode def lazy_init(): from . import efficient_conv_bn_eval, split_cat # noqa: F401 # noqa: F401 @@ -180,6 +184,12 @@ def shape_prop(mod) -> None: example_inputs, "[Pre grad(predispatch IR)] Apply fuse_chunk_squeeze_cat_pass", ) + pass_execution_and_save( + merge_concats_pass, + gm, + example_inputs, + "[Pre grad(predispatch IR)] Apply merge_concats_pass", + ) pass_execution_and_save( fuse_split_linear_add_pass.apply, gm, From 7541614c663ae8f7fa3053da98049d171a5bef1e Mon Sep 17 00:00:00 2001 From: Trung Truong Date: Tue, 17 Sep 2024 17:42:56 +0000 Subject: [PATCH 0969/1018] [MTIA] Support torch.cuda.get_device_capability equivalent API on MTIA (#135889) Summary: Mirror `get_device_capability` on MTIA per https://fburl.com/gdoc/p4lo5avn At the moment, both the major and minor version are just 0 Test Plan: Unit test: `buck2 test //mtia/host_runtime/torch_mtia/tests:test_torch_mtia_api` https://www.internalfb.com/intern/testinfra/testconsole/testrun/1688850109958190/ Differential Revision: D62595296 Pull Request resolved: https://github.com/pytorch/pytorch/pull/135889 Approved by: https://github.com/egienvalue --- aten/src/ATen/detail/MTIAHooksInterface.h | 5 +++++ docs/source/mtia.rst | 1 + torch/_C/__init__.pyi.in | 1 + torch/csrc/mtia/Module.cpp | 6 ++++++ torch/mtia/__init__.py | 12 ++++++++++++ 5 files changed, 25 insertions(+) diff --git a/aten/src/ATen/detail/MTIAHooksInterface.h b/aten/src/ATen/detail/MTIAHooksInterface.h index cee4463c818504..1480436fb4f1d8 100644 --- a/aten/src/ATen/detail/MTIAHooksInterface.h +++ b/aten/src/ATen/detail/MTIAHooksInterface.h @@ -104,6 +104,11 @@ struct TORCH_API MTIAHooksInterface : AcceleratorHooksInterface { FAIL_MTIAHOOKS_FUNC(__func__); return nullptr; } + + virtual PyObject* getDeviceCapability(DeviceIndex device) const { + FAIL_MTIAHOOKS_FUNC(__func__); + return nullptr; + } }; struct TORCH_API MTIAHooksArgs {}; diff --git a/docs/source/mtia.rst b/docs/source/mtia.rst index 4b0e1a52269afc..548c6463d08c83 100644 --- a/docs/source/mtia.rst +++ b/docs/source/mtia.rst @@ -19,6 +19,7 @@ The MTIA backend is implemented out of the tree, only interfaces are be defined is_available is_initialized memory_stats + get_device_capability set_device set_stream stream diff --git a/torch/_C/__init__.pyi.in b/torch/_C/__init__.pyi.in index 338f616bff60aa..cb61d8dbb70170 100644 --- a/torch/_C/__init__.pyi.in +++ b/torch/_C/__init__.pyi.in @@ -1784,6 +1784,7 @@ def _mtia_getCurrentStream(device: _int) -> Stream: ... def _mtia_setCurrentStream(stream: Stream) -> None: ... def _mtia_getDefaultStream(device: _int) -> Stream: ... def _mtia_memoryStats(device: _int) -> Dict[str, Any]: ... +def _mtia_getDeviceCapability(device: _int) -> Tuple[_int, _int]: ... # Defined in torch/csrc/mps/Module.cpp diff --git a/torch/csrc/mtia/Module.cpp b/torch/csrc/mtia/Module.cpp index cf2af60a1572df..0de566f3cf10be 100644 --- a/torch/csrc/mtia/Module.cpp +++ b/torch/csrc/mtia/Module.cpp @@ -80,6 +80,12 @@ void initModule(PyObject* module) { at::detail::getMTIAHooks().memoryStats(device_index); return py::reinterpret_steal(raw_pyobject); }); + + m.def("_mtia_getDeviceCapability", [](c10::DeviceIndex device_index) { + PyObject* raw_pyobject = + at::detail::getMTIAHooks().getDeviceCapability(device_index); + return py::reinterpret_steal(raw_pyobject); + }); } } // namespace mtia diff --git a/torch/mtia/__init__.py b/torch/mtia/__init__.py index eaa107e973f6d9..af711d49260600 100644 --- a/torch/mtia/__init__.py +++ b/torch/mtia/__init__.py @@ -166,6 +166,17 @@ def memory_stats(device: Optional[_device_t] = None) -> Dict[str, Any]: return torch._C._mtia_memoryStats(_get_device_index(device, optional=True)) +def get_device_capability(device: Optional[_device_t] = None) -> Tuple[int, int]: + r"""Return capability of a given device as a tuple of (major version, minor version). + + Args: + device (torch.device or int, optional) selected device. Returns + statistics for the current device, given by current_device(), + if device is None (default). + """ + return torch._C._mtia_getDeviceCapability(_get_device_index(device, optional=True)) + + def set_stream(stream: Stream): r"""Set the current stream.This is a wrapper API to set the stream. Usage of this function is discouraged in favor of the ``stream`` @@ -323,6 +334,7 @@ def set_rng_state( "current_stream", "default_stream", "memory_stats", + "get_device_capability", "set_device", "set_stream", "stream", From 2e928a87f70033e39d407ba0ee9ab8fa6900adaa Mon Sep 17 00:00:00 2001 From: Nikita Shulga <2453524+malfet@users.noreply.github.com> Date: Tue, 17 Sep 2024 17:51:56 +0000 Subject: [PATCH 0970/1018] Delete links to non-existing `run_plan_mpi.cc` (#136204) That were deleted by https://github.com/pytorch/pytorch/pull/125092 Fixes https://github.com/pytorch/pytorch/issues/136199 Pull Request resolved: https://github.com/pytorch/pytorch/pull/136204 Approved by: https://github.com/albanD, https://github.com/seemethere --- binaries/CMakeLists.txt | 6 ------ 1 file changed, 6 deletions(-) diff --git a/binaries/CMakeLists.txt b/binaries/CMakeLists.txt index 273353128baafa..3405b1defb5c55 100644 --- a/binaries/CMakeLists.txt +++ b/binaries/CMakeLists.txt @@ -35,12 +35,6 @@ if(USE_ROCM) endif() -if(USE_MPI) - caffe2_binary_target("run_plan_mpi.cc") - target_link_libraries(run_plan_mpi ${MPI_CXX_LIBRARIES}) -endif() - - caffe2_binary_target("dump_operator_names.cc") caffe2_binary_target("optimize_for_mobile.cc") From 0d33cd9ea72349deb3e364ad1d111dc546e1f9c6 Mon Sep 17 00:00:00 2001 From: Joel Schlosser Date: Tue, 17 Sep 2024 10:36:16 -0400 Subject: [PATCH 0971/1018] Support rms_norm() for NJT (#135872) `rms_norm()` is a nice-to-have for ViT :) This PR: * SymInt-ifies `rms_norm()`, allowing NJT to use the same decomp. * Adds torch_function-based input validation logic for nested-specific stuff (no normalization supported over the ragged dim for now) on the python NJT side. * Adds multi-dim support (on non-ragged, non-batch dims) to `mean()` for NJT. Pull Request resolved: https://github.com/pytorch/pytorch/pull/135872 Approved by: https://github.com/mikaylagawarecki ghstack dependencies: #125947 --- .../functorch/BatchRulesDecompositions.cpp | 2 +- aten/src/ATen/native/layer_norm.cpp | 9 ++--- aten/src/ATen/native/layer_norm.h | 39 ++++++++++++++++++- aten/src/ATen/native/native_functions.yaml | 4 +- test/test_nestedtensor.py | 22 ++++++++--- torch/nested/_internal/ops.py | 36 ++++++++++++++--- .../_internal/opinfo/definitions/nested.py | 26 +++++++++++++ 7 files changed, 118 insertions(+), 20 deletions(-) diff --git a/aten/src/ATen/functorch/BatchRulesDecompositions.cpp b/aten/src/ATen/functorch/BatchRulesDecompositions.cpp index 21f17cff1f18f3..cca20e9e553e5c 100644 --- a/aten/src/ATen/functorch/BatchRulesDecompositions.cpp +++ b/aten/src/ATen/functorch/BatchRulesDecompositions.cpp @@ -230,7 +230,7 @@ TORCH_LIBRARY_IMPL(aten, FuncTorchBatchedDecomposition, m) { m.impl("reshape", native::reshape_symint); OP_DECOMPOSE(resolve_conj); OP_DECOMPOSE(resolve_neg); - OP_DECOMPOSE(rms_norm); + m.impl("rms_norm", native::rms_norm_symint); OP_DECOMPOSE(row_stack); OP_DECOMPOSE(rrelu); OP_DECOMPOSE(rrelu_); diff --git a/aten/src/ATen/native/layer_norm.cpp b/aten/src/ATen/native/layer_norm.cpp index 6e0abb14f8be91..c739547af9c1ab 100644 --- a/aten/src/ATen/native/layer_norm.cpp +++ b/aten/src/ATen/native/layer_norm.cpp @@ -264,18 +264,15 @@ std::tuple math_native_layer_norm( return std::make_tuple(out, mean, rstd); } -Tensor rms_norm( +Tensor rms_norm_symint( const Tensor& input, - IntArrayRef normalized_shape, + c10::SymIntArrayRef normalized_shape, const std::optional& weight_opt /* optional */, std::optional eps) { - // See [Note: hacky wrapper removal for optional tensor] c10::MaybeOwned weight_maybe_owned = at::borrow_from_optional_tensor(weight_opt); const Tensor& weight = *weight_maybe_owned; - auto bias_opt = std::optional(); - const Tensor& bias = *at::borrow_from_optional_tensor(bias_opt); - (void) _check_layer_norm_inputs(input, normalized_shape, weight, bias); + _check_rms_norm_inputs_symint(input, normalized_shape, weight); std::vector dims_to_reduce; for (const auto i : c10::irange(normalized_shape.size())) { diff --git a/aten/src/ATen/native/layer_norm.h b/aten/src/ATen/native/layer_norm.h index e35ccf8634bccb..ba2b356c0b0459 100644 --- a/aten/src/ATen/native/layer_norm.h +++ b/aten/src/ATen/native/layer_norm.h @@ -8,6 +8,41 @@ namespace at::native { namespace { +C10_ALWAYS_INLINE void _check_rms_norm_inputs_symint( + const Tensor& input, + c10::SymIntArrayRef normalized_shape, + const Tensor& weight /* optional */) { + + const int normalized_ndim = normalized_shape.size(); + TORCH_CHECK( + normalized_ndim >= 1, + "Expected normalized_shape to be at least 1-dimensional, i.e., ", + "containing at least one element, but got normalized_shape = ", + normalized_shape); + TORCH_CHECK( + !weight.defined() || weight.sym_sizes().equals(normalized_shape), + "Expected weight to be of same shape as normalized_shape, but got ", + "weight of shape ", + weight.sym_sizes(), + " and normalized_shape = ", + normalized_shape); + + const auto input_ndim = input.dim(); + const auto input_shape = input.sym_sizes(); + if (input_ndim < normalized_ndim || + !input_shape.slice(input_ndim - normalized_ndim) + .equals(normalized_shape)) { + std::stringstream ss; + ss << "Given normalized_shape=" << normalized_shape + << ", expected input with shape [*"; + for (auto size : normalized_shape) { + ss << ", " << size; + } + ss << "], but got input of size" << input_shape; + AT_ERROR(ss.str()); + } +} + C10_ALWAYS_INLINE std::pair _check_layer_norm_inputs( const Tensor& input, IntArrayRef normalized_shape, @@ -71,9 +106,9 @@ void layer_norm_cpu_out( int64_t M, int64_t N); -Tensor rms_norm( +Tensor rms_norm_symint( const Tensor& input, - IntArrayRef normalized_shape, + c10::SymIntArrayRef normalized_shape, const std::optional& weight_opt /* optional */, std::optional eps); diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index e05aa43b17417f..83d04c4a14c958 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -3289,7 +3289,9 @@ autogen: native_layer_norm_backward.out tags: core -- func: rms_norm(Tensor input, int[] normalized_shape, Tensor? weight=None, float? eps=None) -> Tensor +- func: rms_norm(Tensor input, SymInt[] normalized_shape, Tensor? weight=None, float? eps=None) -> Tensor + dispatch: + CompositeImplicitAutograd: rms_norm_symint - func: nan_to_num(Tensor self, float? nan=None, float? posinf=None, float? neginf=None) -> Tensor variants: function, method diff --git a/test/test_nestedtensor.py b/test/test_nestedtensor.py index 408c9b4559dd2c..c1b25313dbc24c 100644 --- a/test/test_nestedtensor.py +++ b/test/test_nestedtensor.py @@ -6,6 +6,7 @@ import math import sys import unittest +from contextlib import nullcontext from functools import partial from typing import Optional, Tuple @@ -5074,11 +5075,12 @@ def test_mean_dim_reduce_multiple_dims( ): """ Mean on NestedTensor fails when trying to reduce across multiple dimensions + only if the batch or ragged dims are included """ tensor_lists = self._get_example_tensor_lists( include_list_of_lists=False, include_requires_grad=components_require_grad ) - reduce_dims = ((0, 1), (2, 3), (2, 3, 4)) + reduce_dims = ((0, 1), (2, 3), (2, 3, 4), (0, 3), (1, 2)) for tensor_list, reduce_dim in itertools.product(tensor_lists, reduce_dims): nt = torch.nested.nested_tensor( @@ -5090,10 +5092,20 @@ def test_mean_dim_reduce_multiple_dims( ) if nt.dim() > reduce_dim[-1]: - with self.assertRaisesRegex( - RuntimeError, - "not supported across multiple dimensions for NestedTensor", - ): + ragged_or_batch_included = ( + nt._ragged_idx in reduce_dim or 0 in reduce_dim + ) + + context = ( + self.assertRaisesRegex( + RuntimeError, + "not supported across multiple dimensions for NestedTensor", + ) + if ragged_or_batch_included + else nullcontext() + ) + + with context: out = torch.mean(nt, dim=reduce_dim, keepdim=keepdim) @dtypes(torch.float32) diff --git a/torch/nested/_internal/ops.py b/torch/nested/_internal/ops.py index 9b7544ec2fc06d..14f227e0886230 100644 --- a/torch/nested/_internal/ops.py +++ b/torch/nested/_internal/ops.py @@ -346,6 +346,29 @@ def _flatten_sig(input, start_dim=0, end_dim=-1): return inp.reshape(*new_shape) + # Handle nested-specific input validation for CompositeImplicit rms_norm + if func.__name__ == "rms_norm": + + def _rms_norm_sig(input, normalized_shape, weight=None, eps=None): + pass + + _, new_kwargs = normalize_function( # type: ignore[misc] + _rms_norm_sig, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True + ) + + inp = new_kwargs.pop("input") + normalized_shape = new_kwargs.pop("normalized_shape") + + # can't normalize over the ragged dim (yet) + max_normalizable = inp.dim() - inp._ragged_idx - 1 + if len(normalized_shape) > max_normalizable: + raise ValueError( + "rms_norm(): Normalization over the ragged dim not supported for nested tensors" + ) + + with torch._C.DisableTorchFunctionSubclass(): + return func(*args, **kwargs) + raise NotImplementedError(func) @@ -395,7 +418,7 @@ def prim_layout_default(func, *args, **kwargs): def tensor_attr_unsupported_getter(func, *args, **kwargs): if func == torch.ops.aten.size.default: raise RuntimeError( - "NestedTensors does not support directly calling torch.ops.aten.size " + "NestedTensor does not support directly calling torch.ops.aten.size; " "please use `nested_tensor.size()` instead." ) @@ -1434,13 +1457,16 @@ def mean_dim(func, *args, **kwargs): func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True ) - if len(new_kwargs["dim"]) > 1: + inp = new_kwargs.pop("input") + + if len(new_kwargs["dim"]) > 1 and ( + inp._ragged_idx in new_kwargs["dim"] or 0 in new_kwargs["dim"] + ): raise RuntimeError( - "mean(): not supported across multiple dimensions for NestedTensor" + "mean(): not supported across multiple dimensions for NestedTensor " + "when either the batch dim or ragged dim is included" ) - inp = new_kwargs.pop("input") - ( new_kwargs["dim"], reduce_on_batch, diff --git a/torch/testing/_internal/opinfo/definitions/nested.py b/torch/testing/_internal/opinfo/definitions/nested.py index ea678c2e4f8750..9654037fa70ad1 100644 --- a/torch/testing/_internal/opinfo/definitions/nested.py +++ b/torch/testing/_internal/opinfo/definitions/nested.py @@ -249,6 +249,31 @@ def sample_inputs_masked_select( ) +def sample_inputs_nn_functional_rms_norm( + op_info, device, dtype, requires_grad, **kwargs +): + for njt in _sample_njts( + device=device, dtype=dtype, requires_grad=requires_grad, dims=[2, 3, 4] + ): + # normalize over non-ragged dims + for start_dim in range(2, njt.dim()): + normalized_shape = njt.shape[start_dim:] + weight = torch.randn( + normalized_shape, + device=device, + dtype=dtype, + requires_grad=requires_grad, + ) + + yield SampleInput( + njt, + kwargs={ + "normalized_shape": normalized_shape, + "weight": weight, + }, + ) + + sample_inputs_nn_functional_threshold = partial( sample_inputs_elementwise_njt_unary, op_kwargs={"threshold": float.fromhex("0x1.3ap-3"), "value": -9}, @@ -264,6 +289,7 @@ def sample_inputs_masked_select( njt_sample_inputs = { "clone": sample_inputs_clone, **{f"mvlgamma.mvlgamma_p_{p}": sample_inputs_mvl_gamma(p=1) for p in (1, 3, 5)}, + "nn.functional.rms_norm": sample_inputs_nn_functional_rms_norm, "nn.functional.threshold": sample_inputs_nn_functional_threshold, **{f"polygamma.polygamma_n_{n}": sample_inputs_polygamma_n(n=n) for n in range(5)}, "special.polygamma.special_polygamma_n_0": sample_inputs_special_polygamma_n(n=0), From 6c711deb0562fec4b91330c0ddd4382aa463ffd8 Mon Sep 17 00:00:00 2001 From: angelayi Date: Tue, 17 Sep 2024 18:13:14 +0000 Subject: [PATCH 0972/1018] [export] Deserialize args with python keyword names (#136036) Currently when we deserialize inputs to nodes, we deserialize arguments with default values as kwargs. So deserializing `aten.uniform`, which has the signature `uniform(Tensor(a!) self, float from=0, float to=1, *, Generator? generator=None) -> Tensor(a!)`, will get become `uniform(x, from=0, to=1)`. However, this fails when running in python because `from` is a python keyword. So the solution here is to not deserialize it as a kwarg. Pull Request resolved: https://github.com/pytorch/pytorch/pull/136036 Approved by: https://github.com/zhxchen17 --- test/export/test_serialize.py | 21 +++++++++++++++++++++ torch/_export/serde/serialize.py | 10 +++++++++- 2 files changed, 30 insertions(+), 1 deletion(-) diff --git a/test/export/test_serialize.py b/test/export/test_serialize.py index 1c31f64d2e0a2e..6f9fa464a86637 100644 --- a/test/export/test_serialize.py +++ b/test/export/test_serialize.py @@ -7,6 +7,7 @@ # Owner(s): ["oncall: export"] import copy import io +import math import tempfile import unittest import zipfile @@ -808,6 +809,26 @@ def f(x, y): self.check_graph(M(), inputs) + def test_arg_from(self): + class M(torch.nn.Module): + def __init__(self): + super().__init__() + self.register_buffer("compress_weight", torch.ones((10, 10))) + self.register_buffer("compress_bias", torch.ones(10)) + + def forward(self) -> None: + if self.compress_weight is None or self.compress_bias is None: + return + torch.nn.init.kaiming_uniform_(self.compress_weight, a=math.sqrt(5)) + fan_in, _ = torch.nn.init._calculate_fan_in_and_fan_out( + self.compress_weight + ) + bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0 + torch.nn.init.uniform_(self.compress_bias, -bound, bound) + + with torch.no_grad(): + self.check_graph(M(), ()) + def test_map(self): from functorch.experimental import control_flow diff --git a/torch/_export/serde/serialize.py b/torch/_export/serde/serialize.py index 44153ccc78eb42..33a08032479e33 100644 --- a/torch/_export/serde/serialize.py +++ b/torch/_export/serde/serialize.py @@ -7,6 +7,7 @@ import inspect import io import json +import keyword import logging import math import operator @@ -14,6 +15,7 @@ import typing import traceback +from collections import OrderedDict from contextlib import contextmanager from dataclasses import dataclass, field from enum import Enum @@ -1946,13 +1948,19 @@ def deserialize_inputs(self, target, serialized_node: Node): for input in serialized_node.inputs } args = [] - kwargs = {} + kwargs: OrderedDict[str, Any] = OrderedDict() for schema_arg in schema_args: is_positional = ( not schema_arg.has_default_value() and not schema_arg.kwarg_only ) if is_positional: args.append(actual_args[schema_arg.name]) + elif keyword.iskeyword(schema_arg.name): + assert not schema_arg.kwarg_only + if len(kwargs) > 0: + kwargs = OrderedDict() + args.extend(list(kwargs.values())) + args.append(actual_args[schema_arg.name]) else: if schema_arg.name in actual_args: kwargs[schema_arg.name] = actual_args[schema_arg.name] From 6fbb26bfc870d42c63f01aac4698635c2e29d82d Mon Sep 17 00:00:00 2001 From: Huanyu He Date: Tue, 17 Sep 2024 18:42:56 +0000 Subject: [PATCH 0973/1018] [TorchRec][PT2 IR][APF] short circuit the flatten/unflatten between EBC and KTRegroupAsDict modules (#136045) Summary: # context * for the root cause and background please refer to this [post](https://fb.workplace.com/groups/1028545332188949/permalink/1042204770823005/) * basica idea of this diff is to **short circuit the pytree flatten-unflatten function pairs** between two preserved modules, i.e., EBC/fpEBC and KTRegroupAsDict. NOTE: There could be multiple EBCs and one single KTRegroupAsDict as shown in the [pic](https://fburl.com/gslide/lcyt8eh3) {F1864810545} * short-circuiting the EBC-KTRegroupAsDict pairs are very special and a must in most of the cases due to the EBC key-order issue with distributed table lookup. * hide all the operations behind a control flag `short_circuit_pytree_ebc_regroup` to the torchrec main api call `decapsulate_ir_modules`, which should only be visible to the infra layer, not to the users. # details * The `_short_circuit_pytree_ebc_regroup` function finds all the EBCs/fpEBC and KTRegroupAsDict modules in an unflattened module. Retrieve their fqns and sort to in_fqns (regroup_fqns) and out_fqns (ebc_fqns). Because currently the fpEBC is swapped as a whole, so we do some extra fqn logic to filter out the EBC that belongs to an up-level fpEBC. * a util function `prune_pytree_flatten_unflatten` removes the in-coming and out-going pytree flatten/unflatten function calls in the graph module, based on the given fqns. WARNING: The flag `short_circuit_pytree_ebc_regroup` should be turned on if EBCs are used and EBC sharding is needed. Assertions are also added if can't find a `KTRegroupAsDict` module, or `finalize_interpreter_modules` is not `True`. # additional changes * absorb the `finalize_interpreter_modules` process inside the torchrec main api `decapsulate_ir_modules`. * set `graph.owning_module` in export.unflatten as required by the graph modification * add one more layer of `sparse_module` for closely mimicing the APF model structure. Test Plan: # run test * serializer ``` buck2 run fbcode//mode/opt fbcode//torchrec/ir/tests:test_serializer ``` * apf ``` buck2 run fbcode//mode/opt fbcode//aps_models/ads/gmp/tests/ne/e2e_deterministic_tests:gmp_e2e_ne_tests -- --filter-text 'test_mtml_instagram_model_562438350_single_gpu_with_ir' ``` * local mp run ``` ==== Finished E2E deterministic test for mtml_instagram_model_gmp_474023725_non_kjt_unary ==== finished test_mtml_instagram_model_562438350_single_gpu_with_ir Imports took: 6.0s! Profile with --import-profiler. --_ |""---__ Executed 1 example in 203.1s: |'.| || . """| Successful: 1 | || || /|\""-. | Failed: 0 | || || | | | Skipped: 0 | || || | \|/ | Not executed: 8 |."| || --"" '__| https://testslide.readthedocs.io/ --" |__---""" ``` Differential Revision: D62606738 Pull Request resolved: https://github.com/pytorch/pytorch/pull/136045 Approved by: https://github.com/angelayi --- torch/export/unflatten.py | 1 + 1 file changed, 1 insertion(+) diff --git a/torch/export/unflatten.py b/torch/export/unflatten.py index a655c48fd9f6ca..2aa12e43770ee8 100644 --- a/torch/export/unflatten.py +++ b/torch/export/unflatten.py @@ -202,6 +202,7 @@ def __init__( export_graph = deepcopy(export_module.graph) self.graph_signature = deepcopy(export_module.graph_signature) self.graph = torch.fx.Graph() + self.graph.owning_module = self self.module_call_graph = deepcopy(export_module.module_call_graph) self.flat_args_adapter = flat_args_adapter # Flag to indicate whether args have been adapted. From abfa305230d4077ca00f65463e8dbdacdb7e9894 Mon Sep 17 00:00:00 2001 From: eqy Date: Tue, 17 Sep 2024 18:50:09 +0000 Subject: [PATCH 0974/1018] [NCCL] Don't override `waitUntilInitialized`'s setting of `comm->initialized_` (#136155) #133630 sets `initialized_` to `true` which causes previous wait codepaths to skip necessary waits, see also #https://github.com/pytorch/pytorch/issues/136151 CC @shuqiangzhang @wconstab Pull Request resolved: https://github.com/pytorch/pytorch/pull/136155 Approved by: https://github.com/fduwjj, https://github.com/kwen2501, https://github.com/c-p-i-o, https://github.com/shuqiangzhang --- torch/csrc/distributed/c10d/NCCLUtils.cpp | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/torch/csrc/distributed/c10d/NCCLUtils.cpp b/torch/csrc/distributed/c10d/NCCLUtils.cpp index b11728e0ba8be7..47ace12db6c3fc 100644 --- a/torch/csrc/distributed/c10d/NCCLUtils.cpp +++ b/torch/csrc/distributed/c10d/NCCLUtils.cpp @@ -84,7 +84,9 @@ std::shared_ptr NCCLComm::split( std::nullopt); ++source->ncclCommSplitCounter_; comm->rank_ = rank; - comm->initialized_ = true; + if (!nccl_use_nonblocking()) { + comm->initialized_ = true; + } return comm; } #endif From fb7108e5ed7ac8a1ac8df61924e467936cdd7f1a Mon Sep 17 00:00:00 2001 From: Banit Agrawal Date: Tue, 17 Sep 2024 19:08:44 +0000 Subject: [PATCH 0975/1018] [PyTorch CUDA Allocator] Allow reuse of non-split blocks with better rounding (#136174) Summary: This diff adds an option to round the non-split blocks in caching allocator so that they can be reused without causing lots of fragmentation for large memory segments. For example, if we specify max_split memory size as 400MB, then all allocations more than 400MB will not be split. Lets say, we allocated some 1024MB blocks and these are cached in the allocator blocks. If we request a new 500MB block, we round it to nearest power-2-division, thats 512MB, we add default kLargeBuffer of 20MB, that will be 532MB and since 532MB is less than existing 1024MB block, the 1024MB will not be used for this allocation, instead a new 512MB block will be created. In this diff, we provide an option to cofigure the kLargeBuffer for rounding and expose as a configurable option, so 512MB + max_non_split_rounding_size and if thats greater than 1024MB, we will use te 1024MB and we wont create a new 512MB block using cudaMalloc. This option is added so that we can pre-allocate some large blocks so that we can reuse them as much as possible and we dont stall on calling cudaMalloc. Differential Revision: D62758758 Pull Request resolved: https://github.com/pytorch/pytorch/pull/136174 Approved by: https://github.com/zyan0 --- c10/cuda/CUDAAllocatorConfig.cpp | 25 +++++++++++++++++++++++++ c10/cuda/CUDAAllocatorConfig.h | 8 ++++++++ c10/cuda/CUDACachingAllocator.cpp | 3 ++- docs/source/notes/cuda.rst | 7 +++++++ 4 files changed, 42 insertions(+), 1 deletion(-) diff --git a/c10/cuda/CUDAAllocatorConfig.cpp b/c10/cuda/CUDAAllocatorConfig.cpp index 19aedb2cbb02f7..7b410bdd6ef557 100644 --- a/c10/cuda/CUDAAllocatorConfig.cpp +++ b/c10/cuda/CUDAAllocatorConfig.cpp @@ -12,6 +12,7 @@ constexpr size_t kRoundUpPowerOfTwoIntervals = 16; CUDAAllocatorConfig::CUDAAllocatorConfig() : m_max_split_size(std::numeric_limits::max()), + m_max_non_split_rounding_size(kLargeBuffer), m_garbage_collection_threshold(0), m_pinned_num_register_threads(1), m_expandable_segments(false), @@ -94,6 +95,27 @@ size_t CUDAAllocatorConfig::parseMaxSplitSize( return i; } +size_t CUDAAllocatorConfig::parseMaxNonSplitRoundingSize( + const std::vector& config, + size_t i) { + consumeToken(config, ++i, ':'); + constexpr int mb = 1024 * 1024; + if (++i < config.size()) { + size_t val1 = stoi(config[i]); + TORCH_CHECK( + val1 > kLargeBuffer / mb, + "CachingAllocator option max_non_split_rounding_mb too small, must be > ", + kLargeBuffer / mb, + ""); + val1 = std::max(val1, kLargeBuffer / mb); + val1 = std::min(val1, (std::numeric_limits::max() / mb)); + m_max_non_split_rounding_size = val1 * 1024 * 1024; + } else { + TORCH_CHECK(false, "Error, expecting max_non_split_rounding_mb value", ""); + } + return i; +} + size_t CUDAAllocatorConfig::parseGarbageCollectionThreshold( const std::vector& config, size_t i) { @@ -258,6 +280,9 @@ void CUDAAllocatorConfig::parseArgs(const char* env) { if (config_item_view == "max_split_size_mb") { i = parseMaxSplitSize(config, i); used_native_specific_option = true; + } else if (config_item_view == "max_non_split_rounding_mb") { + i = parseMaxNonSplitRoundingSize(config, i); + used_native_specific_option = true; } else if (config_item_view == "garbage_collection_threshold") { i = parseGarbageCollectionThreshold(config, i); used_native_specific_option = true; diff --git a/c10/cuda/CUDAAllocatorConfig.h b/c10/cuda/CUDAAllocatorConfig.h index 3106fc1b46baee..38adc4732e3d6d 100644 --- a/c10/cuda/CUDAAllocatorConfig.h +++ b/c10/cuda/CUDAAllocatorConfig.h @@ -63,6 +63,10 @@ class C10_CUDA_API CUDAAllocatorConfig { return instance().m_roundup_power2_divisions; } + static size_t max_non_split_rounding_size() { + return instance().m_max_non_split_rounding_size; + } + static std::string last_allocator_settings() { std::lock_guard lock( instance().m_last_allocator_settings_mutex); @@ -90,6 +94,9 @@ class C10_CUDA_API CUDAAllocatorConfig { size_t i, const char c); size_t parseMaxSplitSize(const std::vector& config, size_t i); + size_t parseMaxNonSplitRoundingSize( + const std::vector& config, + size_t i); size_t parseGarbageCollectionThreshold( const std::vector& config, size_t i); @@ -108,6 +115,7 @@ class C10_CUDA_API CUDAAllocatorConfig { size_t i); std::atomic m_max_split_size; + std::atomic m_max_non_split_rounding_size; std::vector m_roundup_power2_divisions; std::atomic m_garbage_collection_threshold; std::atomic m_pinned_num_register_threads; diff --git a/c10/cuda/CUDACachingAllocator.cpp b/c10/cuda/CUDACachingAllocator.cpp index 3da3c6d4f5d05a..a67a720717bb7f 100644 --- a/c10/cuda/CUDACachingAllocator.cpp +++ b/c10/cuda/CUDACachingAllocator.cpp @@ -2527,7 +2527,8 @@ class DeviceCachingAllocator { return false; // Allow oversized block size to be rounded up but within a limit if ((p.size() >= CUDAAllocatorConfig::max_split_size()) && - ((*it)->size >= p.size() + kLargeBuffer)) + ((*it)->size >= + p.size() + CUDAAllocatorConfig::max_non_split_rounding_size())) return false; p.block = *it; pool.blocks.erase(it); diff --git a/docs/source/notes/cuda.rst b/docs/source/notes/cuda.rst index 7d434bbbba64ce..c0b82adc7e073a 100644 --- a/docs/source/notes/cuda.rst +++ b/docs/source/notes/cuda.rst @@ -471,6 +471,13 @@ Available options: set the knob value to: [256:1,512:2,1024:4,>:8]. ``roundup_power2_divisions`` is only meaningful with ``backend:native``. With ``backend:cudaMallocAsync``, ``roundup_power2_divisions`` is ignored. +* ``max_non_split_rounding_mb`` will allow non-split blocks for better reuse, eg, + a 1024MB cached block can be re-used for a 512MB allocation request. In the default + case, we only allow up to 20MB of rounding of non-split blocks, so a 512MB block + can only be served with between 512-532 MB size block. If we set the value of this + option to 1024, it will alow 512-1536 MB size blocks to be used for a 512MB block + which increases reuse of larger blocks. This will also help in reducing the stalls + in avoiding expensive cudaMalloc calls. * ``garbage_collection_threshold`` helps actively reclaiming unused GPU memory to avoid triggering expensive sync-and-reclaim-all operation (release_cached_blocks), which can be unfavorable to latency-critical GPU applications (e.g., servers). From f6b55c51ba86b584b3f05b594b8a047861407aae Mon Sep 17 00:00:00 2001 From: Banit Agrawal Date: Tue, 17 Sep 2024 21:08:10 +0000 Subject: [PATCH 0976/1018] [PyTorch Pinned Allocator] Add support of background thread to process events (#135524) Summary: Currently we process events in the regular allocation path and we call cudaEventQuery to check on the events and this path can take some locks in libcuda driver. Its not entirely needed to do process events in the allocation path, we could move this to a background thread and keep processing events regularly and put the freed block to the free list. Differential Revision: D62396585 Pull Request resolved: https://github.com/pytorch/pytorch/pull/135524 Approved by: https://github.com/zyan0 --- aten/src/ATen/core/CachingHostAllocator.h | 153 +++++++++++++++----- aten/src/ATen/cuda/CachingHostAllocator.cpp | 5 + c10/cuda/CUDAAllocatorConfig.cpp | 20 +++ c10/cuda/CUDAAllocatorConfig.h | 8 + docs/source/notes/cuda.rst | 4 + 5 files changed, 154 insertions(+), 36 deletions(-) diff --git a/aten/src/ATen/core/CachingHostAllocator.h b/aten/src/ATen/core/CachingHostAllocator.h index 2ca90c914722c1..1d5fbacdcb8476 100644 --- a/aten/src/ATen/core/CachingHostAllocator.h +++ b/aten/src/ATen/core/CachingHostAllocator.h @@ -1,4 +1,6 @@ #include +#include +#include #include #include #include @@ -109,6 +111,17 @@ template < typename E, typename B = HostBlock> struct CachingHostAllocatorImpl { + CachingHostAllocatorImpl() { + // Launch the background thread and process events in a loop. + if (pinned_use_background_threads()) { + getBackgroundThreadPool()->run([&]() { + while (true) { + process_events(); + std::this_thread::sleep_for(std::chrono::microseconds(100)); + } + }); + } + } virtual ~CachingHostAllocatorImpl() = default; public: @@ -118,17 +131,34 @@ struct CachingHostAllocatorImpl { return {nullptr, nullptr}; } - process_events(); + // If we are using background threads, we can process events in the + // background. + if (!pinned_use_background_threads()) { + process_events(); + } + + // Round up the allocation to the nearest power of two to improve reuse. + // These power of two sizes are also used to index into the free list. + size_t roundSize = c10::llvm::PowerOf2Ceil(size); // First, try to allocate from the free list - auto* block = get_free_block(size); + auto* block = get_free_block(roundSize); if (block) { return {block->ptr_, reinterpret_cast(block)}; } - // Round up the allocation to the nearest power of two to improve reuse. - // These power of two sizes are also used to index into the free list. - size_t roundSize = c10::llvm::PowerOf2Ceil(size); + // Check in the recently freed blocks with pending events to see if we + // can reuse them. Call get_free_block again after processing events + if (pinned_use_background_threads()) { + process_events_for_specific_size(roundSize); + block = get_free_block(roundSize); + if (block) { + return {block->ptr_, reinterpret_cast(block)}; + } + } + + // Slow path: if we can't allocate from the cached free list, we need + // to create a new block. void* ptr = nullptr; allocate_host_memory(roundSize, &ptr); @@ -237,6 +267,10 @@ struct CachingHostAllocatorImpl { return c10::llvm::Log2_64_Ceil(size); } + virtual bool pinned_use_background_threads() { + return false; + } + virtual void copy_data(void* dest [[maybe_unused]], const void* src [[maybe_unused]], std::size_t count [[maybe_unused]]) const { TORCH_CHECK_NOT_IMPLEMENTED(false, "Not implemented for copy_data"); } @@ -261,6 +295,21 @@ struct CachingHostAllocatorImpl { } virtual void process_events() { + // process all events until the last unready event, not for specific size. + process_events_for_specific_size(-1); + } + + // If size is -1, process all events from backwards until the last unready + // event. Otherwise, process events for a specific size and on first ready block + // is found, add it to the free list and return. + virtual void process_events_for_specific_size(int64_t size) { + size_t event_count = 0; + size_t max_events = 0; + { + std::lock_guard g(events_mutex_); + max_events = events_.size(); + } + while (true) { // Avoid calling cudaEventDestroy while holding a mutex, so move // intermediate events out of the lock into this object. @@ -278,6 +327,25 @@ struct CachingHostAllocatorImpl { return; } + if (size != -1) { + if (event_count++ > max_events) { + { + std::lock_guard g(events_mutex_); + events_.push_front(std::move(*processed)); + } + return; + } + if (size != (int64_t)processed->second->size_) { + // if we are processing a specific size, and the size of the block + // doesn't match, we can't use it. + { + std::lock_guard g(events_mutex_); + events_.push_front(std::move(*processed)); + } + continue; + } + } + // otherwise, query the event { // now, see if we can handle this element @@ -286,9 +354,14 @@ struct CachingHostAllocatorImpl { // push the event onto the back if it's not ready. { std::lock_guard g(events_mutex_); - events_.push_back(std::move(*processed)); + if (size == -1) { + events_.push_back(std::move(*processed)); + return; + } else { + events_.push_front(std::move(*processed)); + continue; + } } - return; } } @@ -309,46 +382,54 @@ struct CachingHostAllocatorImpl { auto index = size_index(block->size_); std::lock_guard g(free_list_[index].mutex_); free_list_[index].list_.push_back(block); + if (size != -1) { + return; + } } } } - /* These following functions are runtime-related. */ - - // Allocate page-locked memory on the host. - virtual void allocate_host_memory(size_t size, void** ptr) { - TORCH_CHECK_NOT_IMPLEMENTED( - false, "Not implemented for allocate_host_memory"); + TaskThreadPool* getBackgroundThreadPool() { + static TaskThreadPool* pool = new TaskThreadPool(1); + return pool; } - // Free block and release the pointer contained in block. - virtual void free_block(B* block) { - TORCH_CHECK_NOT_IMPLEMENTED(false, "Not implemented for free_block"); - } + /* These following functions are runtime-related. */ - // Record an event on stream and store event into events. - virtual void record_stream(std::optional>& events, S stream) { - TORCH_CHECK_NOT_IMPLEMENTED(false, "Not implemented for record_stream"); - } + // Allocate page-locked memory on the host. + virtual void allocate_host_memory(size_t size, void** ptr) { + TORCH_CHECK_NOT_IMPLEMENTED( + false, "Not implemented for allocate_host_memory"); + } - // Query event if it is completed. - virtual bool query_event(E& event) { - TORCH_CHECK_NOT_IMPLEMENTED(false, "Not implemented for query_event"); - } + // Free block and release the pointer contained in block. + virtual void free_block(B* block) { + TORCH_CHECK_NOT_IMPLEMENTED(false, "Not implemented for free_block"); + } + + // Record an event on stream and store event into events. + virtual void record_stream(std::optional>& events, S stream) { + TORCH_CHECK_NOT_IMPLEMENTED(false, "Not implemented for record_stream"); + } - alignas(64) std::mutex blocks_mutex_; - ska::flat_hash_set blocks_; // block list - ska::flat_hash_map ptr_to_block_; + // Query event if it is completed. + virtual bool query_event(E& event) { + TORCH_CHECK_NOT_IMPLEMENTED(false, "Not implemented for query_event"); + } - // We keep free list as a vector of free lists, one for each power of two - // size. This allows us to quickly find a free block of the right size. - // We use deque to store per size free list and guard the list with its own - // mutex. - alignas(64) std::vector> free_list_ = std::vector>(MAX_SIZE_INDEX); + alignas(64) std::mutex blocks_mutex_; + ska::flat_hash_set blocks_; // block list + ska::flat_hash_map ptr_to_block_; - alignas(64) std::mutex events_mutex_; - std::deque> events_; // event queue paired with block -}; + // We keep free list as a vector of free lists, one for each power of two + // size. This allows us to quickly find a free block of the right size. + // We use deque to store per size free list and guard the list with its own + // mutex. + alignas(64) std::vector> free_list_ = std::vector>(MAX_SIZE_INDEX); + + alignas(64) std::mutex events_mutex_; + std::deque> events_; // event queue paired with block + }; template struct CachingHostAllocatorInterface : public at::Allocator { diff --git a/aten/src/ATen/cuda/CachingHostAllocator.cpp b/aten/src/ATen/cuda/CachingHostAllocator.cpp index b430fb984a426d..511a0c2884587f 100644 --- a/aten/src/ATen/cuda/CachingHostAllocator.cpp +++ b/aten/src/ATen/cuda/CachingHostAllocator.cpp @@ -123,6 +123,11 @@ struct CUDACachingHostAllocatorImpl return true; } + bool pinned_use_background_threads() override { + return c10::cuda::CUDACachingAllocator::CUDAAllocatorConfig:: + pinned_use_background_threads(); + } + EventPool::Event create_event_internal(DeviceIndex idx) { // Leak the event pool to avoid shutdown issue. static auto* event_pool = new EventPool(); diff --git a/c10/cuda/CUDAAllocatorConfig.cpp b/c10/cuda/CUDAAllocatorConfig.cpp index 7b410bdd6ef557..7c1c1e02644b25 100644 --- a/c10/cuda/CUDAAllocatorConfig.cpp +++ b/c10/cuda/CUDAAllocatorConfig.cpp @@ -18,6 +18,7 @@ CUDAAllocatorConfig::CUDAAllocatorConfig() m_expandable_segments(false), m_release_lock_on_cudamalloc(false), m_pinned_use_cuda_host_register(false), + m_pinned_use_background_threads(false), m_last_allocator_settings("") { m_roundup_power2_divisions.assign(kRoundUpPowerOfTwoIntervals, 0); } @@ -331,6 +332,9 @@ void CUDAAllocatorConfig::parseArgs(const char* env) { } else if (config_item_view == "pinned_num_register_threads") { i = parsePinnedNumRegisterThreads(config, i); used_native_specific_option = true; + } else if (config_item_view == "pinned_use_background_threads") { + i = parsePinnedUseBackgroundThreads(config, i); + used_native_specific_option = true; } else { TORCH_CHECK( false, "Unrecognized CachingAllocator option: ", config_item_view); @@ -388,6 +392,22 @@ size_t CUDAAllocatorConfig::parsePinnedNumRegisterThreads( return i; } +size_t CUDAAllocatorConfig::parsePinnedUseBackgroundThreads( + const std::vector& config, + size_t i) { + consumeToken(config, ++i, ':'); + if (++i < config.size()) { + TORCH_CHECK( + (config[i] == "True" || config[i] == "False"), + "Expected a single True/False argument for pinned_use_background_threads"); + m_pinned_use_background_threads = (config[i] == "True"); + } else { + TORCH_CHECK( + false, "Error, expecting pinned_use_background_threads value", ""); + } + return i; +} + // General caching allocator utilities void setAllocatorSettings(const std::string& env) { CUDACachingAllocator::CUDAAllocatorConfig::instance().parseArgs(env.c_str()); diff --git a/c10/cuda/CUDAAllocatorConfig.h b/c10/cuda/CUDAAllocatorConfig.h index 38adc4732e3d6d..876bffcf98527a 100644 --- a/c10/cuda/CUDAAllocatorConfig.h +++ b/c10/cuda/CUDAAllocatorConfig.h @@ -46,6 +46,10 @@ class C10_CUDA_API CUDAAllocatorConfig { return instance().m_pinned_num_register_threads; } + static bool pinned_use_background_threads() { + return instance().m_pinned_use_background_threads; + } + static size_t pinned_max_register_threads() { // Based on the benchmark results, we see better allocation performance // with 8 threads. However on future systems, we may need more threads @@ -113,6 +117,9 @@ class C10_CUDA_API CUDAAllocatorConfig { size_t parsePinnedNumRegisterThreads( const std::vector& config, size_t i); + size_t parsePinnedUseBackgroundThreads( + const std::vector& config, + size_t i); std::atomic m_max_split_size; std::atomic m_max_non_split_rounding_size; @@ -122,6 +129,7 @@ class C10_CUDA_API CUDAAllocatorConfig { std::atomic m_expandable_segments; std::atomic m_release_lock_on_cudamalloc; std::atomic m_pinned_use_cuda_host_register; + std::atomic m_pinned_use_background_threads; std::string m_last_allocator_settings; std::mutex m_last_allocator_settings_mutex; }; diff --git a/docs/source/notes/cuda.rst b/docs/source/notes/cuda.rst index c0b82adc7e073a..74d0c89387fa7e 100644 --- a/docs/source/notes/cuda.rst +++ b/docs/source/notes/cuda.rst @@ -534,6 +534,10 @@ Available options: allocation time of pinned memory. A good value for this option is 8 based on benchmarking results. + `pinned_use_background_threads` option is a boolean flag to enable background thread + for processing events. This avoids any slow path associated with querying/processing of + events in the fast allocation path. This feature is disabled by default. + .. note:: Some stats reported by the From 943303f9f78b162e99d7313e5196e13f976e9235 Mon Sep 17 00:00:00 2001 From: Nikita Shulga Date: Tue, 17 Sep 2024 21:53:31 +0000 Subject: [PATCH 0977/1018] [MPS] Fix 5D+ reductions over negative dimentions (#136198) This fixes bug introduced by https://github.com/pytorch/pytorch/pull/99856 that attempts to speed-up reduction for 5D+ tensor if trailing dimensions are all ones, but introduces crashes/off-by-one errors for wrapped dimensions Added regresion test case to `TestMPS.test_sum` Fixes https://github.com/pytorch/pytorch/issues/136132 Pull Request resolved: https://github.com/pytorch/pytorch/pull/136198 Approved by: https://github.com/albanD --- aten/src/ATen/native/mps/operations/ReduceOps.mm | 5 ++++- test/test_mps.py | 4 ++++ 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/aten/src/ATen/native/mps/operations/ReduceOps.mm b/aten/src/ATen/native/mps/operations/ReduceOps.mm index 6d372f96b7f90f..cd0e75d84dc3bc 100644 --- a/aten/src/ATen/native/mps/operations/ReduceOps.mm +++ b/aten/src/ATen/native/mps/operations/ReduceOps.mm @@ -153,13 +153,16 @@ static void reduction_out_mps(const Tensor& input_t, const std::string& func_name) { bool macOS13_3_plus = is_macos_13_or_newer(MacOSVersion::MACOS_VER_13_3_PLUS); MPS_CHECK_INT64_OP_SUPPORTED(input_t, macOS13_3_plus, func_name); + // NS: TODO: get rid of all those shenanigans and just call reduction_op with view tensor bool canSqueezeLastDim = true; IntArrayRef input_shape = input_t.sizes(); if (opt_dim.has_value()) { IntArrayRef dim = opt_dim.value(); for (const auto dim_val : dim) { auto wrap_dim = maybe_wrap_dim(dim_val, input_shape.size()); - if (wrap_dim >= 4) { + // canSqueeze logic is broken when dim is negative, it introduces off-by-one-erros or crashes + // See https://github.com/pytorch/pytorch/issues/136132#issuecomment-2354482608 + if (wrap_dim >= 4 || dim_val < 0) { canSqueezeLastDim = false; } TORCH_CHECK( diff --git a/test/test_mps.py b/test/test_mps.py index 32adede8dddca0..a342fa3425853e 100644 --- a/test/test_mps.py +++ b/test/test_mps.py @@ -5728,6 +5728,10 @@ def helper(n, c, h, w, dtype=torch.float32): helper(2, 8, 4, 5, dtype=torch.int32) helper(2, 8, 4, 5, dtype=torch.int64) helper(2, 8, 4, 5, dtype=torch.bool) + # Regression test for https://github.com/pytorch/pytorch/issues/136132 + x = torch.ones(2, 4, 1, 30, 1, device='mps').sum(dim=-2) + self.assertEqual(x.numel(), 8) + self.assertEqual(x.max().item(), 30.0) # Test forward prod def test_prod(self): From 4d8f1e340203caf8d8555045303aceb9f3d8b819 Mon Sep 17 00:00:00 2001 From: Chirag Pandya Date: Wed, 18 Sep 2024 00:55:01 +0000 Subject: [PATCH 0978/1018] [c10d] remove sleep from watchdogHandler (#135760) Summary: Remove sleep from the `watchdogHandler` function. This sleep unnecessary slows things down during a NCCL timeout. Flight recorder is configured to take a minute, at most, to dump out it's buffer. This sleep ends up waiting for `8` minutes before destroy is called. Test Plan: Unit tests. Differential Revision: D62529875 Pull Request resolved: https://github.com/pytorch/pytorch/pull/135760 Approved by: https://github.com/fduwjj, https://github.com/shuqiangzhang --- torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp | 15 +++++++++------ torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp | 10 ++++++++++ 2 files changed, 19 insertions(+), 6 deletions(-) diff --git a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp index d8bd66db1c0d47..a75c989d4d56a3 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp +++ b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp @@ -834,6 +834,7 @@ ProcessGroupNCCL::ProcessGroupNCCL( // both timeout and other errors. dumpOnTimeoutOrEx_ = getCvarBool(TORCH_NCCL_DUMP_ON_TIMEOUT, false) || (dist_debug_level_ >= DebugLevel::Detail); + sleepAfterException_ = getCvarBool(TORCH_NCCL_SLEEP_AFTER_EXCEPTION, false); // logging C++ stack isn't safe. Introduce a variable to control it. logCppStackOnUncleanShutdown_ = getCvarBool(TORCH_NCCL_LOG_CPP_STACK_ON_UNCLEAN_SHUTDOWN, true); @@ -1808,12 +1809,14 @@ void ProcessGroupNCCL::watchdogHandler() { } // signal the monitor thread on PG0 to start dumping shouldDump_.store(true); - // This sleep is used to give time for dumping before throwing - // exception - std::this_thread::sleep_for( - std::chrono::seconds(heartbeatTimeoutInSec_)); - LOG(INFO) << logPrefix() << "slept for " << heartbeatTimeoutInSec_ - << " giving time for flight recorder dumps to finish."; + if (sleepAfterException_) { + // This sleep is used to give time for dumping before throwing + // exception + std::this_thread::sleep_for( + std::chrono::seconds(heartbeatTimeoutInSec_)); + LOG(INFO) << logPrefix() << "slept for " << heartbeatTimeoutInSec_ + << " giving time for flight recorder dumps to finish."; + } } catch (const std::exception& e) { LOG(ERROR) << logPrefix() << "Failed to set dump signal in tcpstore. " diff --git a/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp b/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp index 39717696b0746e..2bca8992af445c 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp +++ b/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp @@ -63,6 +63,13 @@ static std::vector TORCH_NCCL_ASYNC_ERROR_HANDLING = { static std::vector TORCH_NCCL_DUMP_ON_TIMEOUT = { "TORCH_NCCL_DUMP_ON_TIMEOUT"}; +// TODO: remove this change after a safe rollout. +// Control whether we sleep after an exception is thrown. +// This change is temporary and is used to safely remove the current sleep that +// exists after an exception is thrown. +static std::vector TORCH_NCCL_SLEEP_AFTER_EXCEPTION = { + "TORCH_NCCL_SLEEP_AFTER_EXCEPTION"}; + // Control whether Desync Debug is enabled. This variable must be set // together with TORCH_NCCL_ASYNC_ERROR_HANDLING. static std::vector TORCH_NCCL_DESYNC_DEBUG = { @@ -1140,6 +1147,9 @@ class TORCH_API ProcessGroupNCCL : public Backend { // timeout and nccl errors. bool dumpOnTimeoutOrEx_; + // Whether or not to sleep after an exception is thrown in the watchdog. + bool sleepAfterException_; + // Whether or not to enable nan check for input tensors to collectives. bool enableNanCheck_; From 5127c4beaa9c69938cda63d1a650cddae0d4c87c Mon Sep 17 00:00:00 2001 From: leslie-fang-intel Date: Thu, 12 Sep 2024 22:47:03 -0700 Subject: [PATCH 0979/1018] [AO][Inductor] Enable WOQ fusion pattern with permute (#135928) **Summary** Fix https://github.com/pytorch/pytorch/issues/135831 and https://github.com/pytorch/ao/issues/890. The root cause of the numerical failure was that the customized woq-int8 kernel was not triggered due to changes in the pattern. After re-adding the fusion pattern, the accuracy check now passes. I will open a separate TorchAO PR to enable these unit tests in TorchAO. **Test Plan** ``` python test/inductor/test_mkldnn_pattern_matcher.py -k test_woq_int8 ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/135928 Approved by: https://github.com/jgong5, https://github.com/eellison --- test/inductor/test_mkldnn_pattern_matcher.py | 25 ++++++++++++++++---- torch/_inductor/constant_folding.py | 24 +++++++++++++++---- torch/_inductor/fx_passes/quantization.py | 22 +++++++++++++++++ 3 files changed, 62 insertions(+), 9 deletions(-) diff --git a/test/inductor/test_mkldnn_pattern_matcher.py b/test/inductor/test_mkldnn_pattern_matcher.py index 418f743c650108..259975996bd90e 100644 --- a/test/inductor/test_mkldnn_pattern_matcher.py +++ b/test/inductor/test_mkldnn_pattern_matcher.py @@ -2585,19 +2585,36 @@ def forward(self, x): @skipIfNoDynamoSupport def test_woq_int8(self): class M(torch.nn.Module): + def __init__(self, is_permute): + super().__init__() + self.is_permute = is_permute + def forward(self, x, weight, scales): - return torch.nn.functional.linear(x, weight.to(dtype=x.dtype)) * scales + if self.is_permute: + weight = weight.t() + m = torch.mm( + x.reshape(-1, x.shape[-1]), + weight.to(x.dtype), + ) + y = m * scales.to(m.dtype) + y = y.reshape(*x.shape[:-1], y.shape[-1]) + return y + else: + return ( + torch.nn.functional.linear(x, weight.to(dtype=x.dtype)) * scales + ) - mod = M().eval() x_shape = (1, 1, 256) - w_shape = (12, 256) s_shape = 12 x_strides = [ (256, 256, 1), # linear dispatching to mm (256, 32, 1), # linear dispatching to bmm ] - for x_stride in x_strides: + is_permutes = [False, True] + for x_stride, is_permute in itertools.product(x_strides, is_permutes): + mod = M(is_permute=is_permute).eval() x = torch.randn(x_shape, dtype=torch.bfloat16).as_strided(x_shape, x_stride) + w_shape = (12, 256) w = torch.randint(-128, 127, w_shape, dtype=torch.int8) s = torch.randn(s_shape, dtype=torch.bfloat16) diff --git a/torch/_inductor/constant_folding.py b/torch/_inductor/constant_folding.py index 25d4b872d7737f..09abe579b52045 100644 --- a/torch/_inductor/constant_folding.py +++ b/torch/_inductor/constant_folding.py @@ -86,13 +86,27 @@ def _deduce_value(self, node: torch.fx.Node) -> Any: return super().run_node(node) def is_impure(self, node: torch.fx.node.Node) -> bool: + def is_woq_int8_pattern(node: torch.fx.node.Node) -> bool: + return ( + node.target == torch.ops.prims.convert_element_type.default # type: ignore[return-value] + and isinstance(node.args[0], torch.fx.Node) + and "val" in node.args[0].meta + and node.args[0].meta["val"].dtype == torch.int8 # type: ignore[union-attr] + and node.args[1] == torch.bfloat16 + ) + if ( - node.target == torch.ops.prims.convert_element_type.default - and is_const_source(node.args[0], self.lifted_constants) # type: ignore[arg-type] - and node.args[0].meta["val"].dtype == torch.int8 # type: ignore[union-attr] - and node.args[1] == torch.bfloat16 + is_woq_int8_pattern(node) + or ( + node.target == torch.ops.aten.permute.default + and len(node.users) == 1 + and is_woq_int8_pattern(next(iter(node.users))) + ) + ) and is_const_source( + node.args[0], self.lifted_constants # type: ignore[arg-type] ): - # For int8_weight -> dq -> bf16_weight + # Case 1: int8_weight -> dq -> bf16_weight + # Case 2: int8_weight -> permute -> dq -> bf16_weight return True quant_registered = ( diff --git a/torch/_inductor/fx_passes/quantization.py b/torch/_inductor/fx_passes/quantization.py index 6e17e2ff456fc6..3c918d480704e9 100644 --- a/torch/_inductor/fx_passes/quantization.py +++ b/torch/_inductor/fx_passes/quantization.py @@ -1561,6 +1561,27 @@ def _register_woq_mm_int8_pattern3(): _register_woq_lowering(_woq_pattern, aten._weight_int8pack_mm.default, aten.reshape) +def _register_woq_mm_int8_pattern4(): + _woq_pattern = CallFunction( + aten.mul.Tensor, + CallFunction( + aten.mm.default, + KeywordArg("x"), + CallFunction( + prims.convert_element_type.default, + CallFunction( + aten.permute.default, + KeywordArg("weight"), + Arg(), + ), + Arg(), + ), + ), + KeywordArg("scales"), + ) + _register_woq_lowering(_woq_pattern, aten._weight_int8pack_mm.default, aten.reshape) + + def _register_quantization_lowerings(): _register_quantization_unary_fusion() _register_quantization_binary_fusion() @@ -1573,6 +1594,7 @@ def _register_woq_lowerings(): _register_woq_mm_int8_pattern1() _register_woq_mm_int8_pattern2() _register_woq_mm_int8_pattern3() + _register_woq_mm_int8_pattern4() def _is_valid_dequant_promotion_pattern(dtype=torch.float32): From 2781cb8c8c83d2d92e6ede39f2a97a03ec8e647e Mon Sep 17 00:00:00 2001 From: Nikita Shulga <2453524+malfet@users.noreply.github.com> Date: Wed, 18 Sep 2024 01:24:05 +0000 Subject: [PATCH 0980/1018] [BE] Make `NestedTensorTransformerFunctions.cu` compilable without warnings (#136222) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Before the change compilation produced following warnings: ``` /home/nshulga/git/pytorch/pytorch/aten/src/ATen/native/nested/cuda/NestedTensorTransformerFunctions.cu: In function ‘std::tuple > at::native::check_shape_and_partition_(const at::Tensor&, const std::vector&, const at::Tensor&)’: /home/nshulga/git/pytorch/pytorch/aten/src/ATen/native/nested/cuda/NestedTensorTransformerFunctions.cu:584:22: warning: comparison of integer expressions of different signedness: ‘const int’ and ‘const size_t’ {aka ‘const long unsigned int’} [-Wsign-compare] 584 | TORCH_CHECK(num_jagged_dim <= kStackArrayMaxDims); | ~~~~~~~~~~~~~~~^~~~~~~~~~~~~~~~~~~~~ /home/nshulga/git/pytorch/pytorch/aten/src/ATen/native/nested/cuda/NestedTensorTransformerFunctions.cu: In lambda function: /home/nshulga/git/pytorch/pytorch/aten/src/ATen/native/nested/cuda/NestedTensorTransformerFunctions.cu:1224:1061: warning: comparison of integer expressions of different signedness: ‘long unsigned int’ and ‘int’ [-Wsign-compare] 1224 | AT_DISPATCH_INDEX_TYPES( | ^ /home/nshulga/git/pytorch/pytorch/aten/src/ATen/native/nested/cuda/NestedTensorTransformerFunctions.cu: In lambda function: /home/nshulga/git/pytorch/pytorch/aten/src/ATen/native/nested/cuda/NestedTensorTransformerFunctions.cu:1224:1985: warning: comparison of integer expressions of different signedness: ‘long unsigned int’ and ‘int’ [-Wsign-compare] 1224 | AT_DISPATCH_INDEX_TYPES( | ^ /home/nshulga/git/pytorch/pytorch/aten/src/ATen/native/nested/cuda/NestedTensorTransformerFunctions.cu: In instantiation of ‘void at::native::jagged_dense_elementwise_jagged_output_opt_(const at::Tensor&, const std::vector&, const at::Tensor&, const at::Tensor&, F) [with scalar_t = c10::Half; F = __nv_dl_wrapper_t<__nv_dl_trailing_return_tag, std::optional), at::native::_fbgemm_dense_to_jagged_forward_symint, c10::Half, 1> >]’: /home/nshulga/git/pytorch/pytorch/aten/src/ATen/native/nested/cuda/NestedTensorTransformerFunctions.cu:1515:1: required from here /home/nshulga/git/pytorch/pytorch/aten/src/ATen/native/nested/cuda/NestedTensorTransformerFunctions.cu:1336:2006: warning: comparison of integer expressions of different signedness: ‘size_t’ {aka ‘long unsigned int’} and ‘int’ [-Wsign-compare] 1336 | AT_DISPATCH_INDEX_TYPES( | ^ /home/nshulga/git/pytorch/pytorch/aten/src/ATen/native/nested/cuda/NestedTensorTransformerFunctions.cu:1336:2113: warning: comparison of integer expressions of different signedness: ‘size_t’ {aka ‘long unsigned int’} and ‘int’ [-Wsign-compare] 1336 | AT_DISPATCH_INDEX_TYPES( | ^ ``` after it compiled without a warning Fixes #ISSUE_NUMBER Pull Request resolved: https://github.com/pytorch/pytorch/pull/136222 Approved by: https://github.com/PaliC, https://github.com/kit1980 --- .../native/nested/cuda/NestedTensorTransformerFunctions.cu | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/aten/src/ATen/native/nested/cuda/NestedTensorTransformerFunctions.cu b/aten/src/ATen/native/nested/cuda/NestedTensorTransformerFunctions.cu index 16ed9195c08884..3cd2b6836d0668 100644 --- a/aten/src/ATen/native/nested/cuda/NestedTensorTransformerFunctions.cu +++ b/aten/src/ATen/native/nested/cuda/NestedTensorTransformerFunctions.cu @@ -581,7 +581,7 @@ inline std::tuple> check_shape_and_partition_( StackArray jagged_dims_tensor; const int num_jagged_dim = dense_tensor.dim() - 2; - TORCH_CHECK(num_jagged_dim <= kStackArrayMaxDims); + TORCH_CHECK(num_jagged_dim <= static_cast(kStackArrayMaxDims)); jagged_dims_tensor.ndim = num_jagged_dim; std::memcpy( &(jagged_dims_tensor.vals[0]), @@ -1220,7 +1220,7 @@ inline bool jagged_dense_dense_elementwise_jagged_output_matches_opt( // MI100 has independent shared mem and L1 int used_shared_kb = shared_kb; #endif - int used_shared_bytes = used_shared_kb << 10; + auto used_shared_bytes = static_cast(used_shared_kb << 10); AT_DISPATCH_INDEX_TYPES( x_offsets[0].scalar_type(), "check_shared_memory", [&] { auto B = y_0_reshaped.size(0); @@ -1355,7 +1355,7 @@ void jagged_dense_elementwise_jagged_output_opt_( int used_shared_bytes = calc_used_shared_bytes(y_reshaped.get_device()); set_max_dynamic_shared_mem_size_for_opt_search_kernel(used_shared_bytes); C10_CUDA_KERNEL_LAUNCH_CHECK(); - TORCH_CHECK(dynamic_smem_size <= used_shared_bytes); + TORCH_CHECK(dynamic_smem_size <= static_cast(used_shared_bytes)); } dim3 threads_bs = dim3(1024, 1, 1); dim3 blocks_bs = dim3(div_round_up(nnz, threads_bs.x), 1, 1); From ab9a171ef5d675066265b1742df8be2bb75e063e Mon Sep 17 00:00:00 2001 From: Kiuk Chung Date: Wed, 18 Sep 2024 02:11:22 +0000 Subject: [PATCH 0981/1018] [torch/numpy][numpy2.0 compat] Additional changes for tests to run under numpy-2.0 (#136152) Continuation of https://github.com/pytorch/pytorch/pull/131909. This PR makes numpy tests compatible with numpy>=2.0.0. Specifically it deals with APIs that have been removed from numpy-2.0. Changes in this PR: 1. Use `numpy.exceptions.ComplexWarning` if `numpy.exceptions` namespace is present. In numpy-2.0 `numpy.ComplexWarning` has been removed in favor of using `numpy.exceptions.ComplexWarning` (see [numpy-2.0 migration guide](https://numpy.org/devdocs/numpy_2_0_migration_guide.html#changes-to-namespaces)). Note that `numpy.exceptions` was introduced in numpy-1.25.0 hence does not exist in numpy<=1.24.x. 2. Do the same for `numpy.exceptions.VisibleDeprecationWarning` 3. Use `np.sort(...,axis=0)` over `np.msort()`(`np.msort()` removed in numpy-2.0) 4. Use `np.pad()` over `np.lib.pad()` (`np.lib` removed in numpy-2.0) Pull Request resolved: https://github.com/pytorch/pytorch/pull/136152 Approved by: https://github.com/atalman --- test/test_sort_and_select.py | 4 ++-- test/torch_np/numpy_tests/core/test_einsum.py | 2 -- .../numpy_tests/core/test_indexing.py | 21 +++++++++++++++---- .../numpy_tests/core/test_multiarray.py | 13 ++++++++---- .../torch_np/numpy_tests/lib/test_arraypad.py | 2 +- 5 files changed, 29 insertions(+), 13 deletions(-) diff --git a/test/test_sort_and_select.py b/test/test_sort_and_select.py index df847543600ed2..aebfdaec0cb854 100644 --- a/test/test_sort_and_select.py +++ b/test/test_sort_and_select.py @@ -402,10 +402,10 @@ def test(shape): if tensor.size() != torch.Size([]): if dtype is torch.bfloat16: expected = torch.from_numpy( - np.msort(tensor.float().cpu().numpy()) + np.sort(tensor.float().cpu().numpy(), axis=0) ).bfloat16() else: - expected = torch.from_numpy(np.msort(tensor.cpu().numpy())) + expected = torch.from_numpy(np.sort(tensor.cpu().numpy(), axis=0)) else: expected = tensor # numpy.msort() does not support empty shapes tensor diff --git a/test/torch_np/numpy_tests/core/test_einsum.py b/test/torch_np/numpy_tests/core/test_einsum.py index 029e10cead4ab1..5432fb63d18000 100644 --- a/test/torch_np/numpy_tests/core/test_einsum.py +++ b/test/torch_np/numpy_tests/core/test_einsum.py @@ -440,8 +440,6 @@ def check_einsum_sums(self, dtype, do_opt=False): # Suppress the complex warnings for the 'as f8' tests with suppress_warnings() as sup: - # sup.filter(np.ComplexWarning) - # matvec(a,b) / a.dot(b) where a is matrix, b is vector for n in range(1, 17): a = np.arange(4 * n, dtype=dtype).reshape(4, n) diff --git a/test/torch_np/numpy_tests/core/test_indexing.py b/test/torch_np/numpy_tests/core/test_indexing.py index c5823dc30ff37c..d47f62dca505c9 100644 --- a/test/torch_np/numpy_tests/core/test_indexing.py +++ b/test/torch_np/numpy_tests/core/test_indexing.py @@ -602,16 +602,21 @@ def test_boolean_index_cast_assign(self): zero_array[bool_index] = np.array([1]) assert_equal(zero_array[0, 1], 1) + # np.ComplexWarning moved to np.exceptions in numpy>=2.0.0 + # np.exceptions only available in numpy>=1.25.0 + has_exceptions_ns = hasattr(np, "exceptions") + ComplexWarning = ( + np.exceptions.ComplexWarning if has_exceptions_ns else np.ComplexWarning + ) + # Fancy indexing works, although we get a cast warning. assert_warns( - np.ComplexWarning, zero_array.__setitem__, ([0], [1]), np.array([2 + 1j]) + ComplexWarning, zero_array.__setitem__, ([0], [1]), np.array([2 + 1j]) ) assert_equal(zero_array[0, 1], 2) # No complex part # Cast complex to float, throwing away the imaginary portion. - assert_warns( - np.ComplexWarning, zero_array.__setitem__, bool_index, np.array([1j]) - ) + assert_warns(ComplexWarning, zero_array.__setitem__, bool_index, np.array([1j])) assert_equal(zero_array[0, 1], 0) @@ -1012,6 +1017,14 @@ def test_multidim(self): # This is so that np.array(True) is not accepted in a full integer # index, when running the file separately. warnings.filterwarnings("error", "", DeprecationWarning) + # np.VisibleDeprecationWarning moved to np.exceptions in numpy>=2.0.0 + # np.exceptions only available in numpy>=1.25.0 + has_exceptions_ns = hasattr(np, "exceptions") + VisibleDeprecationWarning = ( + np.exceptions.VisibleDeprecationWarning + if has_exceptions_ns + else np.VisibleDeprecationWarning + ) warnings.filterwarnings("error", "", np.VisibleDeprecationWarning) def isskip(idx): diff --git a/test/torch_np/numpy_tests/core/test_multiarray.py b/test/torch_np/numpy_tests/core/test_multiarray.py index baed25a9b021c3..f0ef1f79beb542 100644 --- a/test/torch_np/numpy_tests/core/test_multiarray.py +++ b/test/torch_np/numpy_tests/core/test_multiarray.py @@ -5469,8 +5469,6 @@ def test_out_arg(self): out = np.zeros((5, 2), dtype=np.complex128) c = self.matmul(a, b, out=out) assert_(c is out) - # with suppress_warnings() as sup: - # sup.filter(np.ComplexWarning, '') c = c.astype(tgt.dtype) assert_array_equal(c, tgt) @@ -5852,9 +5850,16 @@ def test_complex_warning(self): x = np.array([1, 2]) y = np.array([1 - 2j, 1 + 2j]) + # np.ComplexWarning moved to np.exceptions in numpy>=2.0.0 + # np.exceptions only available in numpy>=1.25.0 + has_exceptions_ns = hasattr(np, "exceptions") + ComplexWarning = ( + np.exceptions.ComplexWarning if has_exceptions_ns else np.ComplexWarning + ) + with warnings.catch_warnings(): - warnings.simplefilter("error", np.ComplexWarning) - assert_raises(np.ComplexWarning, x.__setitem__, slice(None), y) + warnings.simplefilter("error", ComplexWarning) + assert_raises(ComplexWarning, x.__setitem__, slice(None), y) assert_equal(x, [1, 2]) diff --git a/test/torch_np/numpy_tests/lib/test_arraypad.py b/test/torch_np/numpy_tests/lib/test_arraypad.py index befa9d76ac467a..9b16e105db62f7 100644 --- a/test/torch_np/numpy_tests/lib/test_arraypad.py +++ b/test/torch_np/numpy_tests/lib/test_arraypad.py @@ -543,7 +543,7 @@ def test_check_constant_odd_pad_amount(self): @xpassIfTorchDynamo # (reason="tuple values") def test_check_constant_pad_2d(self): arr = np.arange(4).reshape(2, 2) - test = np.lib.pad( + test = np.pad( arr, ((1, 2), (1, 3)), mode="constant", constant_values=((1, 2), (3, 4)) ) expected = np.array( From 591bb41d61fa37fec936dd6ae009cd3b56866b15 Mon Sep 17 00:00:00 2001 From: Nikita Shulga <2453524+malfet@users.noreply.github.com> Date: Wed, 18 Sep 2024 03:10:21 +0000 Subject: [PATCH 0982/1018] [CI] Make linux-aarch64 shards actually running different tests (#136208) Non-functional sharding was introduced in https://github.com/pytorch/pytorch/pull/125255 but each shard in that case were running the same tests... Pull Request resolved: https://github.com/pytorch/pytorch/pull/136208 Approved by: https://github.com/seemethere, https://github.com/ZainRizvi, https://github.com/atalman --- .ci/pytorch/test.sh | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/.ci/pytorch/test.sh b/.ci/pytorch/test.sh index 50aff1d55cd8dd..b866ee8162e0ea 100755 --- a/.ci/pytorch/test.sh +++ b/.ci/pytorch/test.sh @@ -1383,14 +1383,16 @@ test_executorch() { assert_git_not_dirty } -test_linux_aarch64(){ +test_linux_aarch64() { python test/run_test.py --include test_modules test_mkldnn test_mkldnn_fusion test_openmp test_torch test_dynamic_shapes \ - test_transformers test_multiprocessing test_numpy_interop --verbose + test_transformers test_multiprocessing test_numpy_interop \ + --shard "$SHARD_NUMBER" "$NUM_TEST_SHARDS" --verbose # Dynamo tests python test/run_test.py --include dynamo/test_compile dynamo/test_backends dynamo/test_comptime dynamo/test_config \ dynamo/test_functions dynamo/test_fx_passes_pre_grad dynamo/test_interop dynamo/test_model_output dynamo/test_modules \ - dynamo/test_optimizers dynamo/test_recompile_ux dynamo/test_recompiles --verbose + dynamo/test_optimizers dynamo/test_recompile_ux dynamo/test_recompiles \ + --shard "$SHARD_NUMBER" "$NUM_TEST_SHARDS" --verbose # Inductor tests python test/run_test.py --include inductor/test_torchinductor inductor/test_benchmark_fusion inductor/test_codecache \ @@ -1400,7 +1402,8 @@ test_linux_aarch64(){ inductor/test_max_autotune inductor/test_memory_planning inductor/test_metrics inductor/test_multi_kernel inductor/test_pad_mm \ inductor/test_pattern_matcher inductor/test_perf inductor/test_profiler inductor/test_select_algorithm inductor/test_smoke \ inductor/test_split_cat_fx_passes inductor/test_standalone_compile inductor/test_torchinductor \ - inductor/test_torchinductor_codegen_dynamic_shapes inductor/test_torchinductor_dynamic_shapes --verbose + inductor/test_torchinductor_codegen_dynamic_shapes inductor/test_torchinductor_dynamic_shapes \ + --shard "$SHARD_NUMBER" "$NUM_TEST_SHARDS" --verbose } if ! [[ "${BUILD_ENVIRONMENT}" == *libtorch* || "${BUILD_ENVIRONMENT}" == *-bazel-* ]]; then From 1450dcd67fc1be5e8ec0cbe0393dc30ec24320b6 Mon Sep 17 00:00:00 2001 From: Jason Ansel Date: Mon, 16 Sep 2024 21:14:36 -0700 Subject: [PATCH 0983/1018] [dynamo] Fix support for classmethod(property(...)) (#134968) Fixes #134451 Pull Request resolved: https://github.com/pytorch/pytorch/pull/134968 Approved by: https://github.com/yanboliang --- test/dynamo/test_repros.py | 82 ++++++++++++++++++++++++- torch/_dynamo/variables/constant.py | 2 + torch/_dynamo/variables/user_defined.py | 4 ++ 3 files changed, 86 insertions(+), 2 deletions(-) diff --git a/test/dynamo/test_repros.py b/test/dynamo/test_repros.py index 1dc42c4405df7a..338da69b20107d 100644 --- a/test/dynamo/test_repros.py +++ b/test/dynamo/test_repros.py @@ -20,7 +20,7 @@ from abc import ABC from collections import namedtuple from copy import deepcopy -from enum import Enum +from enum import Enum, IntEnum from functools import wraps from typing import Any, Dict, Iterator, List, Tuple from unittest import mock @@ -457,7 +457,7 @@ def forward( if past_key_value is not None: assert ( len(past_key_value) == 2 - ), f"past_key_value should have 2 past states: keys and values. Got { len(past_key_value)} past states" + ), f"past_key_value should have 2 past states: keys and values. Got {len(past_key_value)} past states" real_seq_length += ( past_key_value[0].shape[2] if query_length is None else query_length ) @@ -4546,6 +4546,84 @@ def f(*args): f(*args) self.assertEqual(num_compiles, 1) + def test_issue134451(self): + class BoundingBox2DIndex(IntEnum): + _X = 0 + _Y = 1 + _HEADING = 2 + _LENGTH = 3 + _WIDTH = 4 + + @classmethod + def size(cls): + return 5 + + @classmethod + @property + def X(cls): + return cls._X + + @classmethod + @property + def Y(cls): + return cls._Y + + @classmethod + @property + def HEADING(cls): + return cls._HEADING + + @classmethod + @property + def LENGTH(cls): + return cls._LENGTH + + @classmethod + @property + def WIDTH(cls): + return cls._WIDTH + + @classmethod + @property + def POINT(cls): + # assumes X, Y have subsequent indices + return slice(cls._X, cls._Y + 1) + + @classmethod + @property + def STATE_SE2(cls): + # assumes X, Y, HEADING have subsequent indices + return slice(cls._X, cls._HEADING + 1) + + class SimpleModel(nn.Module): + def __init__(self): + super().__init__() + self._mlp_states = nn.Sequential( + nn.Linear(10, 20), + nn.ReLU(), + nn.Linear(20, BoundingBox2DIndex.size()), + ) + + def forward(self, x): + agent_states = self._mlp_states(x) + agent_states[..., BoundingBox2DIndex.POINT] = ( + agent_states[..., BoundingBox2DIndex.POINT].tanh() * 32 + ) + agent_states[..., BoundingBox2DIndex.HEADING] = ( + agent_states[..., BoundingBox2DIndex.HEADING].tanh() * torch.pi + ) + return agent_states + + model = SimpleModel().eval() + input_tensor = torch.randn(1, 10, dtype=torch.float32) + opt = torch.compile(model.eval(), backend="eager", fullgraph=True) + actual = opt(input_tensor) + try: + expected = model(input_tensor) + except Exception as e: + raise unittest.SkipTest("eager failed, requires Python>=3.12") from e + self.assertEqual(actual, expected) + def test_invalid_seq_unpack(self): def myfn(arg): (a, b) = arg diff --git a/torch/_dynamo/variables/constant.py b/torch/_dynamo/variables/constant.py index 65de0aab6ef3c7..de357cf8094f3a 100644 --- a/torch/_dynamo/variables/constant.py +++ b/torch/_dynamo/variables/constant.py @@ -222,6 +222,8 @@ def create(cls, cls_type, value_vt, options): unimplemented("Enum variable is constructed with non constant values") def as_proxy(self): + if isinstance(self.value, int): + return int(self.value) # convert IntEnum to a normal int return self.value def __str__(self) -> str: diff --git a/torch/_dynamo/variables/user_defined.py b/torch/_dynamo/variables/user_defined.py index 14499b4d2e4bf7..32057c838d74c8 100644 --- a/torch/_dynamo/variables/user_defined.py +++ b/torch/_dynamo/variables/user_defined.py @@ -193,6 +193,10 @@ def var_getattr(self, tx: "InstructionTranslator", name: str) -> "VariableTracke else: return SourcelessBuilder.create(tx, func) elif isinstance(obj, classmethod): + if isinstance(obj.__func__, property): + return variables.UserFunctionVariable(obj.__func__.fget).call_function( + tx, [self], {} + ) return variables.UserMethodVariable(obj.__func__, self, source=source) elif isinstance(obj, types.ClassMethodDescriptorType): # e.g.: inspect.getattr_static(dict, "fromkeys") From d81663c882d9290579b129b856b7025cee8b6d91 Mon Sep 17 00:00:00 2001 From: Menglu Yu Date: Wed, 18 Sep 2024 07:33:41 +0000 Subject: [PATCH 0984/1018] Reland D62220158 (#136213) Summary: We fix the unit test test_pad_mm and reland the diff Test Plan: See in D62220158 Differential Revision: D62891584 Pull Request resolved: https://github.com/pytorch/pytorch/pull/136213 Approved by: https://github.com/dshi7 --- test/inductor/test_pad_mm.py | 35 ++++++++++++++++++++++++++ torch/_inductor/fx_passes/pad_mm.py | 23 +++++++++++++++++ torch/_inductor/fx_passes/split_cat.py | 1 + 3 files changed, 59 insertions(+) diff --git a/test/inductor/test_pad_mm.py b/test/inductor/test_pad_mm.py index dbe1170994fdbe..0b34e174cc6688 100644 --- a/test/inductor/test_pad_mm.py +++ b/test/inductor/test_pad_mm.py @@ -9,6 +9,7 @@ get_pad_cache, get_padded_length, should_pad_common, + should_pad_mm_bf16, ) from torch._inductor.test_case import run_tests, TestCase from torch._inductor.utils import fresh_inductor_cache, is_big_gpu, run_and_get_code @@ -449,6 +450,40 @@ def mm(inps, b): repr(get_pad_cache().get_local_cache()) ) + @fresh_inductor_cache() + @inductor_config.patch( + post_grad_fusion_options={"pad_aten_mm_pass": {"k_threshold_to_pad": 8388608}} + ) + def test_pad_mm_bf16(self): + m = 2 + n = 13 + k = 15691904 + mat1 = torch.ones((m, k), device="cuda", dtype=torch.bfloat16) + mat2 = torch.ones((k, n), device="cuda", dtype=torch.bfloat16) + expected_alignment = get_alignment_size(mat1) + + assert expected_alignment == 8, "Alignment for bfloat16 should be 8" + assert should_pad_common( + mat1, mat2 + ), "This should pass the common padding criteria" + if torch.cuda.get_device_capability() < (9, 0): + assert should_pad_mm_bf16( + mat1.dtype, m, n, k + ), "This should pass the should_pad_mm_bf16 padding criteria" + + @torch.compile() + def mm(mat1, mat2): + return torch.mm(mat1, mat2) + + res2, (code,) = run_and_get_code(mm, mat1, mat2) + mm_expected_result = torch.mm(mat1, mat2) + # in call code, expect to see a single pad per input, and then we should see padded allocation for output + FileCheck().check("del async_compile").check_count( + ".run(", 2, exactly=True + ).check("empty_strided_cuda((8, 16)").run(code) + + assert torch.allclose(res2, mm_expected_result), "MM results are not identical" + if __name__ == "__main__": if HAS_CUDA: diff --git a/torch/_inductor/fx_passes/pad_mm.py b/torch/_inductor/fx_passes/pad_mm.py index 87450b34e7ee27..a00bad39747918 100644 --- a/torch/_inductor/fx_passes/pad_mm.py +++ b/torch/_inductor/fx_passes/pad_mm.py @@ -364,6 +364,23 @@ def should_pad(key: str, ori_time, pad_time) -> bool: return should_pad +def should_pad_mm_bf16(dtype, M, N, K): + # always force pad for mm with bf16 when the following are satisfied to avoid perf regression + large_k_threshold_to_pad = torch._inductor.config.post_grad_fusion_options[ + "pad_aten_mm_pass" + ].get("k_threshold_to_pad", 8388608) + if ( + dtype is torch.bfloat16 + and K > M + and K > N + and N % 2 == 1 + and K >= large_k_threshold_to_pad + and torch.cuda.get_device_capability() < (9, 0) + ): # doesnt repro on h100s: + return True + return False + + def should_pad_bench( match, mat1: Tensor, mat2: Tensor, op, input: Optional[Tensor] = None ) -> bool: @@ -410,6 +427,12 @@ def realize_symbols(ds): if torch._inductor.config.force_shape_pad: return True + if ( + "pad_aten_mm_pass" in torch._inductor.config.post_grad_fusion_options + and should_pad_mm_bf16(mat1.dtype, m, n, k) + ): + return True + if not has_triton(): return False diff --git a/torch/_inductor/fx_passes/split_cat.py b/torch/_inductor/fx_passes/split_cat.py index f850ecf6008c94..194d1d6dbaa798 100644 --- a/torch/_inductor/fx_passes/split_cat.py +++ b/torch/_inductor/fx_passes/split_cat.py @@ -65,6 +65,7 @@ "decompose_mm_pass", "unbind_stack_aten_pass", "shape_padding_multiplier", + "pad_aten_mm_pass", ] for pass_name in pre_grad_pass_names: From feee7dd3216f57f0254c1e5434ef218ca7ba9182 Mon Sep 17 00:00:00 2001 From: Prachi Gupta Date: Wed, 18 Sep 2024 11:01:23 +0000 Subject: [PATCH 0985/1018] Fix ROCm skip decorator for test_ddp_tp and multiprocess UTs (#136161) skip_if_rocm is used only in multiprocess case (when UT test class is a child of MultiProcessTestCase). Each individual process can exit with a skip code. If used for single process UT, it will cause the UT to fail as the process returns a non-zero exit code. Use skipIfRocm in single process UTs. To avoid the above confusion, this PR renamed skip_if_rocm to skip_if_rocm_multiprocess. Fixes #ISSUE_NUMBER Pull Request resolved: https://github.com/pytorch/pytorch/pull/136161 Approved by: https://github.com/jithunnair-amd, https://github.com/kwen2501, https://github.com/fegin --- .../test_replicate_with_compiler.py | 16 +++++++------- .../quantization/test_quantization.py | 6 ++--- .../optim/test_zero_redundancy_optimizer.py | 8 +++---- test/distributed/test_c10d_nccl.py | 22 +++++++++---------- torch/testing/_internal/common_distributed.py | 4 ++-- .../ddp_under_dist_autograd_test.py | 4 ++-- .../_internal/distributed/distributed_test.py | 10 ++++----- 7 files changed, 35 insertions(+), 35 deletions(-) diff --git a/test/distributed/_composable/test_replicate_with_compiler.py b/test/distributed/_composable/test_replicate_with_compiler.py index 14d9cc47271f59..e0df5aaff27f84 100644 --- a/test/distributed/_composable/test_replicate_with_compiler.py +++ b/test/distributed/_composable/test_replicate_with_compiler.py @@ -29,9 +29,9 @@ from torch.testing._internal.common_distributed import ( MultiProcessTestCase, skip_if_lt_x_gpu, - skip_if_rocm, + skip_if_rocm_multiprocess, ) -from torch.testing._internal.common_utils import run_tests +from torch.testing._internal.common_utils import run_tests, skipIfRocm from torch.testing._internal.distributed.fake_pg import FakeStore from torch.utils._triton import has_triton from torch.utils.checkpoint import checkpoint @@ -217,21 +217,21 @@ def test_compile_cpu_no_sync(self): self._test_compile(use_gpu=False, no_sync=True) @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") - @skip_if_rocm + @skip_if_rocm_multiprocess @skip_if_lt_x_gpu(2) @torch._inductor.config.patch(reorder_for_locality=False) def test_compile_gpu(self): self._test_compile(use_gpu=True, no_sync=False, checkpoint=False) @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") - @skip_if_rocm + @skip_if_rocm_multiprocess @skip_if_lt_x_gpu(2) @torch._inductor.config.patch(reorder_for_locality=False) def test_compile_gpu_ac(self): self._test_compile(use_gpu=True, no_sync=False, checkpoint=True) @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") - @skip_if_rocm + @skip_if_rocm_multiprocess @skip_if_lt_x_gpu(2) def test_compile_bf16(self): def setup(model, compiled_replicate_model, compiled_ddp_model) -> None: @@ -245,7 +245,7 @@ def setup(model, compiled_replicate_model, compiled_ddp_model) -> None: self._test_compile(use_gpu=True, no_sync=False, setup_func=setup) @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") - @skip_if_rocm + @skip_if_rocm_multiprocess @skip_if_lt_x_gpu(2) def test_compile_fp16(self): def setup(model, compiled_replicate_model, compiled_ddp_model) -> None: @@ -262,7 +262,7 @@ def setup(model, compiled_replicate_model, compiled_ddp_model) -> None: ) @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") - @skip_if_rocm + @skip_if_rocm_multiprocess @skip_if_lt_x_gpu(2) def test_compile_backward_only(self): self._test_compile(use_gpu=True, no_sync=False, no_compile_forward=True) @@ -386,7 +386,7 @@ def tearDown(self): dist.destroy_process_group() @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") - @skip_if_rocm + @skipIfRocm def test_ddp_tp(self): ref_model = Net() compiled_replicate_model = deepcopy(ref_model) diff --git a/test/distributed/algorithms/quantization/test_quantization.py b/test/distributed/algorithms/quantization/test_quantization.py index 0915b103d241cd..94a1c763474501 100644 --- a/test/distributed/algorithms/quantization/test_quantization.py +++ b/test/distributed/algorithms/quantization/test_quantization.py @@ -14,7 +14,7 @@ requires_gloo, requires_nccl, skip_if_lt_x_gpu, - skip_if_rocm, + skip_if_rocm_multiprocess, ) from torch.testing._internal.common_utils import ( NO_MULTIPROCESSING_SPAWN, @@ -112,7 +112,7 @@ def test_all_gather_bfp16(self): BACKEND != "nccl", "Only nccl backend supports all_to_all_fp16" ) @skip_if_lt_x_gpu(int(os.environ["WORLD_SIZE"])) - @skip_if_rocm + @skip_if_rocm_multiprocess def test_all_to_all_fp16(self): store = dist.FileStore(self.file_name, self.world_size) dist.init_process_group( @@ -137,7 +137,7 @@ def test_all_to_all_fp16(self): BACKEND != "nccl", "Only nccl backend supports all_to_all_fp16" ) @skip_if_lt_x_gpu(int(os.environ["WORLD_SIZE"])) - @skip_if_rocm + @skip_if_rocm_multiprocess def test_all_to_all_bfp16(self): store = dist.FileStore(self.file_name, self.world_size) dist.init_process_group( diff --git a/test/distributed/optim/test_zero_redundancy_optimizer.py b/test/distributed/optim/test_zero_redundancy_optimizer.py index a17fd7cfb263a3..67edb211b9f1ed 100644 --- a/test/distributed/optim/test_zero_redundancy_optimizer.py +++ b/test/distributed/optim/test_zero_redundancy_optimizer.py @@ -403,7 +403,7 @@ def _check_same_model_params( ) @common_distributed.skip_if_no_gpu - @common_distributed.skip_if_rocm + @common_distributed.skip_if_rocm_multiprocess def test_step(self): """Check that ZeroRedundancyOptimizer properly exposes the ``step()`` interface.""" @@ -443,7 +443,7 @@ def test_step(self): self.assertEqual(m.bias, m_zero.bias) @common_distributed.skip_if_no_gpu - @common_distributed.skip_if_rocm + @common_distributed.skip_if_rocm_multiprocess def test_step_with_closure(self): """Check that ZeroRedundancyOptimizer properly exposes the ``step(closure)`` interface.""" @@ -663,7 +663,7 @@ def test_multiple_param_groups(self): torch.testing.assert_close(layer1.bias, layer3.bias) @common_distributed.skip_if_no_gpu - @common_distributed.skip_if_rocm + @common_distributed.skip_if_rocm_multiprocess def test_collect_shards(self): """Check the state consolidation mechanism and the state dict exposed by ZeroRedundancyOptimizer.""" @@ -1383,7 +1383,7 @@ def _test_ddp_zero_overlap( @common_distributed.skip_if_win32() @common_distributed.requires_nccl() @common_distributed.skip_if_no_gpu - @common_distributed.skip_if_rocm + @common_distributed.skip_if_rocm_multiprocess @parametrize( "use_gpu", [True], diff --git a/test/distributed/test_c10d_nccl.py b/test/distributed/test_c10d_nccl.py index 3b1cd8e56d7e8c..78688b0e6a70cc 100644 --- a/test/distributed/test_c10d_nccl.py +++ b/test/distributed/test_c10d_nccl.py @@ -48,7 +48,7 @@ requires_nccl, requires_nccl_version, skip_if_lt_x_gpu, - skip_if_rocm, + skip_if_rocm_multiprocess, TEST_SKIPS, with_dist_debug_levels, with_nccl_blocking_wait, @@ -360,7 +360,7 @@ def test_close_pg(self): torch.float8_e5m2, ], ) - @skip_if_rocm + @skip_if_rocm_multiprocess def test_nan_assert(self, type): # Expecting a device-side error when NaN is detected os.environ["TORCH_NCCL_NAN_CHECK"] = "1" @@ -1724,7 +1724,7 @@ def test_grad_layout_1devicemodule_1replicaperprocess(self): @requires_nccl() @skip_if_lt_x_gpu(4) - @skip_if_rocm + @skip_if_rocm_multiprocess def test_grad_layout_2devicemodule(self): int_devices = gpus_for_rank(self.world_size)[self.rank][:2] dev0 = torch.device("cuda:" + str(int_devices[0])) @@ -2496,7 +2496,7 @@ def _run_all_reduce(self, pg): @requires_nccl() @requires_nccl_version((2, 4, 0), "Need NCCL 2.4+ for error checking") @skip_if_lt_x_gpu(3) - @skip_if_rocm + @skip_if_rocm_multiprocess @skip_but_pass_in_sandcastle("Test does not pass when run locally") def test_nccl_errors_nonblocking(self): # Note: we unset and restore TORCH_NCCL_ASYNC_ERROR_HANDLING for this test @@ -2558,7 +2558,7 @@ def _test_nccl_errors_blocking(self, func): @requires_nccl() @requires_nccl_version((2, 4, 0), "Need NCCL 2.4+ for error checking") @skip_if_lt_x_gpu(3) - @skip_if_rocm + @skip_if_rocm_multiprocess def test_nccl_errors_blocking_clean_exit(self): self._test_nccl_errors_blocking(lambda: sys.exit(0)) @@ -2566,7 +2566,7 @@ def test_nccl_errors_blocking_clean_exit(self): @requires_nccl() @requires_nccl_version((2, 4, 0), "Need NCCL 2.4+ for error checking") @skip_if_lt_x_gpu(3) - @skip_if_rocm + @skip_if_rocm_multiprocess def test_nccl_errors_blocking_nonzero_exit(self): self._test_nccl_errors_blocking(lambda: sys.exit(1)) @@ -2574,7 +2574,7 @@ def test_nccl_errors_blocking_nonzero_exit(self): @requires_nccl() @requires_nccl_version((2, 4, 0), "Need NCCL 2.4+ for error checking") @skip_if_lt_x_gpu(3) - @skip_if_rocm + @skip_if_rocm_multiprocess @skip_but_pass_in_sandcastle( "Frequently times out see https://github.com/pytorch/pytorch/issues/58920" ) @@ -2585,7 +2585,7 @@ def test_nccl_errors_blocking_abort(self): @requires_nccl() @requires_nccl_version((2, 4, 0), "Need NCCL 2.4+ for error checking") @skip_if_lt_x_gpu(3) - @skip_if_rocm + @skip_if_rocm_multiprocess def test_nccl_errors_blocking_sigkill(self): self._test_nccl_errors_blocking(lambda: os.kill(os.getpid(), signal.SIGKILL)) @@ -2593,7 +2593,7 @@ def test_nccl_errors_blocking_sigkill(self): @requires_nccl() @requires_nccl_version((2, 4, 0), "Need NCCL 2.4+ for error checking") @skip_if_lt_x_gpu(3) - @skip_if_rocm + @skip_if_rocm_multiprocess def test_nccl_errors_blocking_sigterm(self): self._test_nccl_errors_blocking(lambda: os.kill(os.getpid(), signal.SIGTERM)) @@ -2809,7 +2809,7 @@ def test_all_reduce_coalesced_manager_nccl(self): @requires_nccl() @skip_if_lt_x_gpu(2) - @skip_if_rocm + @skip_if_rocm_multiprocess def test_intra_node_comm_all_reduce(self): from torch._C._distributed_c10d import _get_intra_node_comm_usage_counter from torch.testing._internal.common_cuda import SM80OrLater @@ -4344,7 +4344,7 @@ def _check_return_codes(self, elapsed_time): @requires_nccl() @requires_nccl_version((2, 4, 0), "Need NCCL 2.4+ for error checking") @skip_if_lt_x_gpu(2) - @skip_if_rocm + @skip_if_rocm_multiprocess def test_nccl_errors_dump(self): os.environ["TORCH_NCCL_ASYNC_ERROR_HANDLING"] = "1" os.environ["TORCH_NCCL_TRACE_BUFFER_SIZE"] = "1000" diff --git a/torch/testing/_internal/common_distributed.py b/torch/testing/_internal/common_distributed.py index 26bdcce6103120..98b0f22d60bfd3 100644 --- a/torch/testing/_internal/common_distributed.py +++ b/torch/testing/_internal/common_distributed.py @@ -340,9 +340,9 @@ def requires_mpi(): ) -def skip_if_rocm(func): +def skip_if_rocm_multiprocess(func): """Skips a test for ROCm""" - func.skip_if_rocm = True + func.skip_if_rocm_multiprocess = True @wraps(func) def wrapper(*args, **kwargs): diff --git a/torch/testing/_internal/distributed/ddp_under_dist_autograd_test.py b/torch/testing/_internal/distributed/ddp_under_dist_autograd_test.py index de8b51e3defb98..00a67b20a73c7c 100644 --- a/torch/testing/_internal/distributed/ddp_under_dist_autograd_test.py +++ b/torch/testing/_internal/distributed/ddp_under_dist_autograd_test.py @@ -18,7 +18,7 @@ requires_gloo, requires_nccl, skip_if_lt_x_gpu, - skip_if_rocm, + skip_if_rocm_multiprocess, ) from torch.testing._internal.dist_utils import INIT_METHOD_TEMPLATE, dist_init from torch.testing._internal.distributed.rpc.rpc_agent_test_fixture import ( @@ -662,7 +662,7 @@ class CudaDdpComparisonTest(CommonDdpComparisonTest): @skip_if_lt_x_gpu(NUM_TRAINERS) @requires_nccl() @dist_init - @skip_if_rocm + @skip_if_rocm_multiprocess def test_ddp_dist_autograd_local_vs_remote_gpu(self): # Each trainer uses a different random seed. Otherwise, they are going # to have exactly the same initial model parameters, input, and diff --git a/torch/testing/_internal/distributed/distributed_test.py b/torch/testing/_internal/distributed/distributed_test.py index b20165cd98fac4..981c8e59580649 100644 --- a/torch/testing/_internal/distributed/distributed_test.py +++ b/torch/testing/_internal/distributed/distributed_test.py @@ -62,7 +62,7 @@ initialize_temp_directories, cleanup_temp_dir, simple_sparse_reduce_tests, - skip_if_rocm, + skip_if_rocm_multiprocess, skip_if_small_worldsize, skip_if_odd_worldsize, skip_if_lt_x_gpu, @@ -3936,7 +3936,7 @@ def test_all_to_all(self): @skip_but_pass_in_sandcastle_if( BACKEND != "nccl", "Only NCCL supports CUDA all_to_all" ) - @skip_if_rocm + @skip_if_rocm_multiprocess def test_all_to_all_cuda(self): group, group_id, rank = self._init_global_test() rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND) @@ -3952,7 +3952,7 @@ def test_all_to_all_complex(self): @skip_but_pass_in_sandcastle_if( BACKEND != "nccl", "Only NCCL supports CUDA all_to_all" ) - @skip_if_rocm + @skip_if_rocm_multiprocess def test_all_to_all_cuda_complex(self): group, group_id, rank = self._init_global_test() rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND) @@ -4020,7 +4020,7 @@ def test_all_to_all_group(self): BACKEND != "nccl", "Only Nccl supports CUDA all_to_all_single" ) @skip_if_small_worldsize - @skip_if_rocm + @skip_if_rocm_multiprocess def test_all_to_all_group_cuda(self): group, group_id, rank = self._init_group_test() rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND) @@ -4080,7 +4080,7 @@ def test_all_to_all_full_group(self): @skip_but_pass_in_sandcastle_if( BACKEND != "nccl", "Only NCCL supports CUDA all_to_all" ) - @skip_if_rocm + @skip_if_rocm_multiprocess def test_all_to_all_full_group_cuda(self): group, group_id, rank = self._init_full_group_test() rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND) From 4dd3a18ebbfc5e2acbf0acae8e8bd0077ff9c771 Mon Sep 17 00:00:00 2001 From: "Sun, Jiayi" Date: Thu, 12 Sep 2024 19:37:25 -0700 Subject: [PATCH 0986/1018] [Inductor] Increase multiplier to 3 for Inductor AMP FP16 benchmark correctness check (#135932) Fix https://github.com/pytorch/pytorch/issues/135657. Aligned with AMP BF16, using multiplier 3 for Inductor AMP FP16 benchmark correctness check Pull Request resolved: https://github.com/pytorch/pytorch/pull/135932 Approved by: https://github.com/CaoE, https://github.com/jgong5, https://github.com/jansel --- torch/_dynamo/utils.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/torch/_dynamo/utils.py b/torch/_dynamo/utils.py index bed3564ec15ce0..e086f11e174cd7 100644 --- a/torch/_dynamo/utils.py +++ b/torch/_dynamo/utils.py @@ -1787,7 +1787,9 @@ def to_tensor(t): # accuracy when comparing AMP with FP32 is within a difference of less than 0.1%. # Thus, it's possible that the correctness check failures for these models are # false alarms. We use multiplier of 3 instead of 2 to avoid these false alarms. - multiplier = 3.0 if res.dtype == torch.bfloat16 else 2.0 + multiplier = ( + 3.0 if res.dtype in (torch.float16, torch.bfloat16) else 2.0 + ) if use_larger_multiplier_for_smaller_tensor and ( fp64_ref.numel() <= 10 and tol >= 4 * 1e-2 From 091cb37e8ae7695051cae35c7e8a8444a2d099a3 Mon Sep 17 00:00:00 2001 From: Isuru Fernando Date: Mon, 16 Sep 2024 19:30:16 +0000 Subject: [PATCH 0987/1018] Fix fast_expand recursion error (#136163) Pull Request resolved: https://github.com/pytorch/pytorch/pull/136163 Approved by: https://github.com/ezyang --- torch/fx/experimental/symbolic_shapes.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch/fx/experimental/symbolic_shapes.py b/torch/fx/experimental/symbolic_shapes.py index 580d3ed741e3e4..51abd7f70a6869 100644 --- a/torch/fx/experimental/symbolic_shapes.py +++ b/torch/fx/experimental/symbolic_shapes.py @@ -1444,7 +1444,7 @@ def _fast_expand(expr: sympy.Expr) -> sympy.Expr: if expr.is_Pow: base, exp = expr.args - if exp.is_Integer: + if exp.is_Integer and base.is_Add: if exp > 1: return sympy.expand_multinomial(expr, deep=False) elif exp < 0: From 79f2421bf7dca98e32dfba657d4f69b519da54e8 Mon Sep 17 00:00:00 2001 From: CaoE Date: Wed, 18 Sep 2024 16:31:27 +0000 Subject: [PATCH 0988/1018] Add _addmm_activation to lower precision cast policy on AutocastCPU (#135936) Fixes #132613. Add `_addmm_activation` to lower precision cast policy on AutocastCPU. `_addmm_activation` https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/transformers/transformer.cpp#L39 of `transformer_encoder_layer_forward` may throw `RuntimeError: mat1 and mat2 must have the same dtype, but got BFloat16 and Float` when autocast is enabled, as `_native_multi_head_attention` is put in lower data type cast policy https://github.com/pytorch/pytorch/pull/107674 and `_addmm_activation` may encounter mixed data types. Pull Request resolved: https://github.com/pytorch/pytorch/pull/135936 Approved by: https://github.com/jgong5, https://github.com/ezyang --- aten/src/ATen/autocast_mode.cpp | 1 + torch/testing/_internal/autocast_test_lists.py | 1 + 2 files changed, 2 insertions(+) diff --git a/aten/src/ATen/autocast_mode.cpp b/aten/src/ATen/autocast_mode.cpp index 4766476d23377e..8ae66a30dcaf0f 100644 --- a/aten/src/ATen/autocast_mode.cpp +++ b/aten/src/ATen/autocast_mode.cpp @@ -336,6 +336,7 @@ TORCH_LIBRARY_IMPL(aten, AutocastCPU, m) { KERNEL_CPU(linalg_vecdot, lower_precision_fp) KERNEL_CPU(baddbmm, lower_precision_fp) KERNEL_CPU(addmm, lower_precision_fp) + KERNEL_CPU(_addmm_activation, lower_precision_fp) KERNEL_CPU(addbmm, lower_precision_fp) KERNEL_CPU(linear, lower_precision_fp) KERNEL_CPU(_convolution, deprecated, lower_precision_fp) diff --git a/torch/testing/_internal/autocast_test_lists.py b/torch/testing/_internal/autocast_test_lists.py index c9789f19c3dcff..2d18da71ec2bda 100644 --- a/torch/testing/_internal/autocast_test_lists.py +++ b/torch/testing/_internal/autocast_test_lists.py @@ -320,6 +320,7 @@ def __init__(self, dev): torch.randn((n, n, n), device=dev, dtype=torch.float32), torch.randn((n, n, n), device=dev, dtype=torch.float32))), ("addmm", mat1_fp32 + mat2_fp32 + mat3_fp32), + ("_addmm_activation", mat1_fp32 + mat2_fp32 + mat3_fp32, {"beta": 1, "alpha": 1, "use_gelu": True}), ("addbmm", mat0_fp32 + (torch.randn((n, n, n), device=dev, dtype=torch.float32), torch.randn((n, n, n), device=dev, dtype=torch.float32))), ("conv_tbc", (torch.randn((10, 7, 3), device=dev, dtype=torch.float32), From 1c33f7f3dbf524496a2744274cb20eb381081e96 Mon Sep 17 00:00:00 2001 From: Scott Wolchok Date: Tue, 17 Sep 2024 09:18:28 -0700 Subject: [PATCH 0989/1018] [PyTorch] Remove unnecessary include of c10/util/Exception.h in irange.h (#136202) Manually audited and can't figure out why this would be needed. Differential Revision: [D62879500](https://our.internmc.facebook.com/intern/diff/D62879500/) Pull Request resolved: https://github.com/pytorch/pytorch/pull/136202 Approved by: https://github.com/malfet --- c10/mobile/CPUProfilingAllocator.cpp | 1 + c10/util/irange.h | 1 - 2 files changed, 1 insertion(+), 1 deletion(-) diff --git a/c10/mobile/CPUProfilingAllocator.cpp b/c10/mobile/CPUProfilingAllocator.cpp index c655d7953a0b4f..2fc569135e267c 100644 --- a/c10/mobile/CPUProfilingAllocator.cpp +++ b/c10/mobile/CPUProfilingAllocator.cpp @@ -1,5 +1,6 @@ #include #include +#include #include #include diff --git a/c10/util/irange.h b/c10/util/irange.h index 3d7c607a1e9033..2719a82075cc96 100644 --- a/c10/util/irange.h +++ b/c10/util/irange.h @@ -2,7 +2,6 @@ #pragma once -#include #include #include From 645c4549125dbecdf7d7615222f97e258382ebb6 Mon Sep 17 00:00:00 2001 From: fduwjj Date: Tue, 17 Sep 2024 08:40:41 -0700 Subject: [PATCH 0990/1018] [c10d] Make test compatible for new pytest (#136158) Temporary fix to the issue in https://github.com/pytorch/pytorch/issues/127517. Short-term fix following CPython: https://github.com/python/cpython/blob/51aefc5bf907ddffaaf083ded0de773adcdf08c8/Lib/unittest/case.py#L419-L426 Differential Revision: [D62878083](https://our.internmc.facebook.com/intern/diff/D62878083) Pull Request resolved: https://github.com/pytorch/pytorch/pull/136158 Approved by: https://github.com/fegin --- torch/testing/_internal/common_distributed.py | 20 +++++++++++++++---- 1 file changed, 16 insertions(+), 4 deletions(-) diff --git a/torch/testing/_internal/common_distributed.py b/torch/testing/_internal/common_distributed.py index 98b0f22d60bfd3..9ec38c9ca671c2 100644 --- a/torch/testing/_internal/common_distributed.py +++ b/torch/testing/_internal/common_distributed.py @@ -561,8 +561,14 @@ def __init__(self, method_name: str = "runTest", methodName: str = "runTest") -> if methodName != "runTest": method_name = methodName super().__init__(method_name) - fn = getattr(self, method_name) - setattr(self, method_name, self.join_or_run(fn)) + try: + fn = getattr(self, method_name) + setattr(self, method_name, self.join_or_run(fn)) + except AttributeError as e: + if methodName != 'runTest': + # we allow instantiation with no explicit method name + # but not an *incorrect* or missing method name + raise ValueError(f"no such test method in {self.__class__}: {methodName}") from e def setUp(self) -> None: super().setUp() @@ -1014,8 +1020,14 @@ def __init__(self, method_name: str = "runTest", methodName: str = "runTest") -> if methodName != "runTest": method_name = methodName super().__init__(method_name) - fn = getattr(self, method_name) - setattr(self, method_name, self.join_or_run(fn)) + try: + fn = getattr(self, method_name) + setattr(self, method_name, self.join_or_run(fn)) + except AttributeError as e: + if methodName != 'runTest': + # we allow instantiation with no explicit method name + # but not an *incorrect* or missing method name + raise ValueError(f"no such test method in {self.__class__}: {methodName}") from e def perThreadSetUp(self): # super().setUp() # TestCase.setUp() calls torch.manual_seed() From 5ef2ad1b25e1045ac30aad534b8a1e73e812eea7 Mon Sep 17 00:00:00 2001 From: Jack Taylor Date: Wed, 18 Sep 2024 17:39:34 +0000 Subject: [PATCH 0991/1018] [ROCm] upgrade ROCm CI builds to py3.10 (#134108) Upgrade ROCm CI builds to py3.10 Pull Request resolved: https://github.com/pytorch/pytorch/pull/134108 Approved by: https://github.com/jeffdaily, https://github.com/jithunnair-amd, https://github.com/atalman --- .ci/docker/build.sh | 4 ++-- .github/workflows/inductor-rocm.yml | 18 +++++++++--------- .github/workflows/periodic.yml | 18 +++++++++--------- .github/workflows/pull.yml | 6 +++--- .github/workflows/rocm.yml | 18 +++++++++--------- .github/workflows/slow.yml | 18 +++++++++--------- .github/workflows/trunk.yml | 18 +++++++++--------- 7 files changed, 50 insertions(+), 50 deletions(-) diff --git a/.ci/docker/build.sh b/.ci/docker/build.sh index 61f7acfe7633fe..dedf591db38d01 100755 --- a/.ci/docker/build.sh +++ b/.ci/docker/build.sh @@ -286,7 +286,7 @@ case "$image" in TRITON=yes ;; pytorch-linux-focal-rocm-n-1-py3) - ANACONDA_PYTHON_VERSION=3.8 + ANACONDA_PYTHON_VERSION=3.10 GCC_VERSION=9 PROTOBUF=yes DB=yes @@ -297,7 +297,7 @@ case "$image" in TRITON=yes ;; pytorch-linux-focal-rocm-n-py3) - ANACONDA_PYTHON_VERSION=3.8 + ANACONDA_PYTHON_VERSION=3.10 GCC_VERSION=9 PROTOBUF=yes DB=yes diff --git a/.github/workflows/inductor-rocm.yml b/.github/workflows/inductor-rocm.yml index 01789ea8942fe5..dd26a0a70fe312 100644 --- a/.github/workflows/inductor-rocm.yml +++ b/.github/workflows/inductor-rocm.yml @@ -31,13 +31,13 @@ jobs: curr_branch: ${{ github.head_ref || github.ref_name }} curr_ref_type: ${{ github.ref_type }} - linux-focal-rocm6_1-py3_8-inductor-build: - name: rocm6.1-py3.8-inductor + linux-focal-rocm6_1-py3_10-inductor-build: + name: rocm6.1-py3.10-inductor uses: ./.github/workflows/_linux-build.yml needs: get-label-type with: runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - build-environment: linux-focal-rocm6.1-py3.8 + build-environment: linux-focal-rocm6.1-py3.10 docker-image-name: pytorch-linux-focal-rocm-n-py3 test-matrix: | { include: [ @@ -45,14 +45,14 @@ jobs: { config: "inductor", shard: 2, num_shards: 2, runner: "linux.rocm.gpu.2" }, ]} - linux-focal-rocm6_1-py3_8-inductor-test: + linux-focal-rocm6_1-py3_10-inductor-test: permissions: id-token: write contents: read - name: rocm6.1-py3.8-inductor + name: rocm6.1-py3.10-inductor uses: ./.github/workflows/_rocm-test.yml - needs: linux-focal-rocm6_1-py3_8-inductor-build + needs: linux-focal-rocm6_1-py3_10-inductor-build with: - build-environment: linux-focal-rocm6.1-py3.8 - docker-image: ${{ needs.linux-focal-rocm6_1-py3_8-inductor-build.outputs.docker-image }} - test-matrix: ${{ needs.linux-focal-rocm6_1-py3_8-inductor-build.outputs.test-matrix }} + build-environment: linux-focal-rocm6.1-py3.10 + docker-image: ${{ needs.linux-focal-rocm6_1-py3_10-inductor-build.outputs.docker-image }} + test-matrix: ${{ needs.linux-focal-rocm6_1-py3_10-inductor-build.outputs.test-matrix }} diff --git a/.github/workflows/periodic.yml b/.github/workflows/periodic.yml index 57a16a5d3b9fec..5fe1784e59f6d0 100644 --- a/.github/workflows/periodic.yml +++ b/.github/workflows/periodic.yml @@ -297,13 +297,13 @@ jobs: docker-image: ${{ needs.linux-vulkan-focal-py3_11-clang10-build.outputs.docker-image }} test-matrix: ${{ needs.linux-vulkan-focal-py3_11-clang10-build.outputs.test-matrix }} - linux-focal-rocm6_1-py3_8-build: - name: linux-focal-rocm6.1-py3.8 + linux-focal-rocm6_1-py3_10-build: + name: linux-focal-rocm6.1-py3.10 uses: ./.github/workflows/_linux-build.yml needs: get-label-type with: runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - build-environment: linux-focal-rocm6.1-py3.8 + build-environment: linux-focal-rocm6.1-py3.10 docker-image-name: pytorch-linux-focal-rocm-n-py3 test-matrix: | { include: [ @@ -312,19 +312,19 @@ jobs: { config: "distributed", shard: 3, num_shards: 3, runner: "linux.rocm.gpu" }, ]} - linux-focal-rocm6_1-py3_8-test: + linux-focal-rocm6_1-py3_10-test: permissions: id-token: write contents: read - name: linux-focal-rocm6.1-py3.8 + name: linux-focal-rocm6.1-py3.10 uses: ./.github/workflows/_rocm-test.yml needs: - - linux-focal-rocm6_1-py3_8-build + - linux-focal-rocm6_1-py3_10-build - target-determination with: - build-environment: linux-focal-rocm6.1-py3.8 - docker-image: ${{ needs.linux-focal-rocm6_1-py3_8-build.outputs.docker-image }} - test-matrix: ${{ needs.linux-focal-rocm6_1-py3_8-build.outputs.test-matrix }} + build-environment: linux-focal-rocm6.1-py3.10 + docker-image: ${{ needs.linux-focal-rocm6_1-py3_10-build.outputs.docker-image }} + test-matrix: ${{ needs.linux-focal-rocm6_1-py3_10-build.outputs.test-matrix }} linux-focal-cuda12_1-py3_10-gcc9-experimental-split-build: name: linux-focal-cuda12.1-py3.10-gcc9-experimental-split-build diff --git a/.github/workflows/pull.yml b/.github/workflows/pull.yml index 92c5b043358e6c..a7c17117a8c477 100644 --- a/.github/workflows/pull.yml +++ b/.github/workflows/pull.yml @@ -503,15 +503,15 @@ jobs: ]} secrets: inherit - linux-focal-rocm6_1-py3_8-build: + linux-focal-rocm6_1-py3_10-build: # don't run build twice on main if: github.event_name == 'pull_request' - name: linux-focal-rocm6.1-py3.8 + name: linux-focal-rocm6.1-py3.10 uses: ./.github/workflows/_linux-build.yml needs: get-label-type with: runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - build-environment: linux-focal-rocm6.1-py3.8 + build-environment: linux-focal-rocm6.1-py3.10 docker-image-name: pytorch-linux-focal-rocm-n-py3 sync-tag: rocm-build test-matrix: | diff --git a/.github/workflows/rocm.yml b/.github/workflows/rocm.yml index 781f2878b6eae6..76b42498333a9d 100644 --- a/.github/workflows/rocm.yml +++ b/.github/workflows/rocm.yml @@ -25,11 +25,11 @@ jobs: id-token: write contents: read - linux-focal-rocm6_1-py3_8-build: - name: linux-focal-rocm6.1-py3.8 + linux-focal-rocm6_1-py3_10-build: + name: linux-focal-rocm6.1-py3.10 uses: ./.github/workflows/_linux-build.yml with: - build-environment: linux-focal-rocm6.1-py3.8 + build-environment: linux-focal-rocm6.1-py3.10 docker-image-name: pytorch-linux-focal-rocm-n-py3 sync-tag: rocm-build test-matrix: | @@ -42,16 +42,16 @@ jobs: { config: "default", shard: 6, num_shards: 6, runner: "linux.rocm.gpu.2" }, ]} - linux-focal-rocm6_1-py3_8-test: + linux-focal-rocm6_1-py3_10-test: permissions: id-token: write contents: read - name: linux-focal-rocm6.1-py3.8 + name: linux-focal-rocm6.1-py3.10 uses: ./.github/workflows/_rocm-test.yml needs: - - linux-focal-rocm6_1-py3_8-build + - linux-focal-rocm6_1-py3_10-build - target-determination with: - build-environment: linux-focal-rocm6.1-py3.8 - docker-image: ${{ needs.linux-focal-rocm6_1-py3_8-build.outputs.docker-image }} - test-matrix: ${{ needs.linux-focal-rocm6_1-py3_8-build.outputs.test-matrix }} + build-environment: linux-focal-rocm6.1-py3.10 + docker-image: ${{ needs.linux-focal-rocm6_1-py3_10-build.outputs.docker-image }} + test-matrix: ${{ needs.linux-focal-rocm6_1-py3_10-build.outputs.test-matrix }} diff --git a/.github/workflows/slow.yml b/.github/workflows/slow.yml index ffad7e8ca3637a..426a4473cc034b 100644 --- a/.github/workflows/slow.yml +++ b/.github/workflows/slow.yml @@ -130,13 +130,13 @@ jobs: docker-image: ${{ needs.linux-focal-py3_9-clang10-build.outputs.docker-image }} test-matrix: ${{ needs.linux-focal-py3_9-clang10-build.outputs.test-matrix }} - linux-focal-rocm6_1-py3_8-build: - name: linux-focal-rocm6.1-py3.8 + linux-focal-rocm6_1-py3_10-build: + name: linux-focal-rocm6.1-py3.10 uses: ./.github/workflows/_linux-build.yml needs: get-label-type with: runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - build-environment: linux-focal-rocm6.1-py3.8 + build-environment: linux-focal-rocm6.1-py3.10 docker-image-name: pytorch-linux-focal-rocm-n-py3 test-matrix: | { include: [ @@ -144,19 +144,19 @@ jobs: { config: "slow", shard: 2, num_shards: 2, runner: "linux.rocm.gpu" }, ]} - linux-focal-rocm6_1-py3_8-test: + linux-focal-rocm6_1-py3_10-test: permissions: id-token: write contents: read - name: linux-focal-rocm6.1-py3.8 + name: linux-focal-rocm6.1-py3.10 uses: ./.github/workflows/_rocm-test.yml needs: - - linux-focal-rocm6_1-py3_8-build + - linux-focal-rocm6_1-py3_10-build - target-determination with: - build-environment: linux-focal-rocm6.1-py3.8 - docker-image: ${{ needs.linux-focal-rocm6_1-py3_8-build.outputs.docker-image }} - test-matrix: ${{ needs.linux-focal-rocm6_1-py3_8-build.outputs.test-matrix }} + build-environment: linux-focal-rocm6.1-py3.10 + docker-image: ${{ needs.linux-focal-rocm6_1-py3_10-build.outputs.docker-image }} + test-matrix: ${{ needs.linux-focal-rocm6_1-py3_10-build.outputs.test-matrix }} linux-jammy-py3_10-clang15-asan-build: name: linux-jammy-py3.10-clang15-asan diff --git a/.github/workflows/trunk.yml b/.github/workflows/trunk.yml index 2b31568aae5733..f5a812abcac551 100644 --- a/.github/workflows/trunk.yml +++ b/.github/workflows/trunk.yml @@ -223,13 +223,13 @@ jobs: cuda-version: "12.1" runner: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge.nonephemeral" - linux-focal-rocm6_1-py3_8-build: - name: linux-focal-rocm6.1-py3.8 + linux-focal-rocm6_1-py3_10-build: + name: linux-focal-rocm6.1-py3.10 uses: ./.github/workflows/_linux-build.yml needs: get-label-type with: runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - build-environment: linux-focal-rocm6.1-py3.8 + build-environment: linux-focal-rocm6.1-py3.10 docker-image-name: pytorch-linux-focal-rocm-n-py3 sync-tag: rocm-build test-matrix: | @@ -240,19 +240,19 @@ jobs: ]} secrets: inherit - linux-focal-rocm6_1-py3_8-test: + linux-focal-rocm6_1-py3_10-test: permissions: id-token: write contents: read - name: linux-focal-rocm6.1-py3.8 + name: linux-focal-rocm6.1-py3.10 uses: ./.github/workflows/_rocm-test.yml needs: - - linux-focal-rocm6_1-py3_8-build + - linux-focal-rocm6_1-py3_10-build - target-determination with: - build-environment: linux-focal-rocm6.1-py3.8 - docker-image: ${{ needs.linux-focal-rocm6_1-py3_8-build.outputs.docker-image }} - test-matrix: ${{ needs.linux-focal-rocm6_1-py3_8-build.outputs.test-matrix }} + build-environment: linux-focal-rocm6.1-py3.10 + docker-image: ${{ needs.linux-focal-rocm6_1-py3_10-build.outputs.docker-image }} + test-matrix: ${{ needs.linux-focal-rocm6_1-py3_10-build.outputs.test-matrix }} tests-to-include: "test_nn test_torch test_cuda test_ops test_unary_ufuncs test_binary_ufuncs test_autograd inductor/test_torchinductor distributed/test_c10d_common distributed/test_c10d_nccl" linux-focal-cuda12_4-py3_10-gcc9-experimental-split-build: From 68da9bbc81bdd14af76ff5c214bde6df5e92bf71 Mon Sep 17 00:00:00 2001 From: maajidkhann Date: Wed, 18 Sep 2024 18:59:10 +0000 Subject: [PATCH 0992/1018] Extending the Pytorch vec backend for SVE (ARM) (#119571) **Motivation:** In Pytorch, Aten vectorization supports multiple platforms, including x86 and Arm, as well as multiple data types. It provides a generic implementation of Vector (Vec) type that allows the programmer to write code packing various primitives (such as floats) within 256bit & 512bits registers. It can be extended to support other ISAs easily by adding more VecISA sub-classes. **Reference Link:** https://github.com/pytorch/pytorch/tree/main/aten/src/ATen/cpu/vec **This PR:** * Our goal with this contribution is to add support for SVE backend for Vec in the Aten vectorization for CPU backend which can be benefitted by any ARM architecture supported CPU's that supports SVE. * More about SVE ISA for ARM: [https://developer.arm.com/Architectures/Scalable Vector Extensions](https://developer.arm.com/Architectures/Scalable%20Vector%20Extensions) * We are using the ARM C Language Extensions for SVE (https://developer.arm.com/documentation/102699/0100/Optimizing-with-intrinsics ) to accelerate performance for various operators in the SVE backend for Vec. * Currently we are adding support only for SVE ISA with the vector length of 256 bits (SVE 256). In future, we plan to extend this SVE support for other vector lengths as well. Pull Request resolved: https://github.com/pytorch/pytorch/pull/119571 Approved by: https://github.com/malfet, https://github.com/snadampal Co-authored-by: Divya Kotadiya --- aten/src/ATen/CMakeLists.txt | 2 +- aten/src/ATen/Version.cpp | 5 + aten/src/ATen/cpu/vec/functional_base.h | 2 +- aten/src/ATen/cpu/vec/intrinsics.h | 8 + aten/src/ATen/cpu/vec/sve/sve_helper.h | 63 ++ aten/src/ATen/cpu/vec/sve/vec_common_sve.h | 176 ++++++ aten/src/ATen/cpu/vec/sve/vec_double.h | 505 ++++++++++++++++ aten/src/ATen/cpu/vec/sve/vec_float.h | 570 ++++++++++++++++++ aten/src/ATen/cpu/vec/sve/vec_int.h | 410 +++++++++++++ aten/src/ATen/cpu/vec/sve/vec_qint.h | 567 +++++++++++++++++ aten/src/ATen/cpu/vec/vec256/vec256.h | 6 +- .../src/ATen/cpu/vec/vec256/vec256_bfloat16.h | 2 +- aten/src/ATen/cpu/vec/vec256/vec256_qint.h | 2 +- aten/src/ATen/cpu/vec/vec_base.h | 2 +- .../ATen/native/BatchLinearAlgebraKernel.cpp | 16 + aten/src/ATen/native/DispatchStub.cpp | 70 ++- aten/src/ATen/native/DispatchStub.h | 34 +- aten/src/ATen/native/SegmentReduce.cpp | 8 + aten/src/ATen/native/cpu/LerpKernel.cpp | 2 +- aten/src/ATen/native/mkl/SpectralOps.cpp | 1 + .../native/sparse/FlattenIndicesKernel.cpp | 1 + .../SparseBinaryOpIntersectionKernel.cpp | 3 + .../ATen/native/transformers/attention.cpp | 1 + aten/src/ATen/test/vec_test_all_types.cpp | 12 + aten/src/ATen/test/vec_test_all_types.h | 3 + cmake/Codegen.cmake | 12 + cmake/Modules/FindARM.cmake | 78 +++ setup.py | 1 + torch/backends/cpu/__init__.py | 1 + 29 files changed, 2554 insertions(+), 9 deletions(-) create mode 100644 aten/src/ATen/cpu/vec/sve/sve_helper.h create mode 100644 aten/src/ATen/cpu/vec/sve/vec_common_sve.h create mode 100644 aten/src/ATen/cpu/vec/sve/vec_double.h create mode 100644 aten/src/ATen/cpu/vec/sve/vec_float.h create mode 100644 aten/src/ATen/cpu/vec/sve/vec_int.h create mode 100644 aten/src/ATen/cpu/vec/sve/vec_qint.h diff --git a/aten/src/ATen/CMakeLists.txt b/aten/src/ATen/CMakeLists.txt index 6d9152a4d07df2..1896530c0af6bc 100644 --- a/aten/src/ATen/CMakeLists.txt +++ b/aten/src/ATen/CMakeLists.txt @@ -54,7 +54,7 @@ if(NOT BUILD_LITE_INTERPRETER) endif() EXCLUDE(ATen_CORE_SRCS "${ATen_CORE_SRCS}" ${ATen_CORE_TEST_SRCS}) -file(GLOB base_h "*.h" "detail/*.h" "cpu/*.h" "cpu/vec/vec512/*.h" "cpu/vec/vec256/*.h" "cpu/vec/vec256/vsx/*.h" "cpu/vec/vec256/zarch/*.h" "cpu/vec/*.h" "quantized/*.h" "functorch/*.h") +file(GLOB base_h "*.h" "detail/*.h" "cpu/*.h" "cpu/vec/vec512/*.h" "cpu/vec/vec256/*.h" "cpu/vec/vec256/vsx/*.h" "cpu/vec/vec256/zarch/*.h" "cpu/vec/sve/*.h" "cpu/vec/*.h" "quantized/*.h" "functorch/*.h") file(GLOB base_cpp "*.cpp" "detail/*.cpp" "cpu/*.cpp" "functorch/*.cpp") file(GLOB cuda_h "cuda/*.h" "cuda/detail/*.h" "cuda/*.cuh" "cuda/detail/*.cuh" "cuda/tunable/*.cuh" "cuda/tunable/*.h") file(GLOB cuda_cpp "cuda/*.cpp" "cuda/detail/*.cpp" "cuda/tunable/*.cpp") diff --git a/aten/src/ATen/Version.cpp b/aten/src/ATen/Version.cpp index cf33d89e0814ea..f9a1ab0e1de8a7 100644 --- a/aten/src/ATen/Version.cpp +++ b/aten/src/ATen/Version.cpp @@ -105,6 +105,11 @@ std::string get_cpu_capability() { return "DEFAULT"; case native::CPUCapability::ZVECTOR: return "Z VECTOR"; +#elif defined(HAVE_SVE_CPU_DEFINITION) + case native::CPUCapability::DEFAULT: + return "DEFAULT"; + case native::CPUCapability::SVE256: + return "SVE256"; #else case native::CPUCapability::DEFAULT: return "NO AVX"; diff --git a/aten/src/ATen/cpu/vec/functional_base.h b/aten/src/ATen/cpu/vec/functional_base.h index 48d44dc42c33ce..e54440ed6eedd0 100644 --- a/aten/src/ATen/cpu/vec/functional_base.h +++ b/aten/src/ATen/cpu/vec/functional_base.h @@ -78,7 +78,7 @@ struct VecReduceAllSIMD { #endif // defined(CPU_CAPABILITY_AVX512) #endif // defined(__GNUC__) && (__GNUC__ > 5) && !defined(_MSC_VER) && !defined(C10_MOBILE) -#if defined(__aarch64__) && !defined(C10_MOBILE) && !defined(__CUDACC__) +#if defined(__aarch64__) && !defined(C10_MOBILE) && !defined(__CUDACC__) && !defined(CPU_CAPABILITY_SVE) template struct VecReduceAllSIMD { static inline float apply(const Op& vec_fun, const Vectorized& acc_vec) { diff --git a/aten/src/ATen/cpu/vec/intrinsics.h b/aten/src/ATen/cpu/vec/intrinsics.h index a82a8ef1a69457..48b18793b079e7 100644 --- a/aten/src/ATen/cpu/vec/intrinsics.h +++ b/aten/src/ATen/cpu/vec/intrinsics.h @@ -5,6 +5,10 @@ #elif defined(__clang__) && (defined(__ARM_NEON__) || defined(__aarch64__)) /* Clang-compatible compiler, targeting arm neon */ #include +#if defined(__ARM_FEATURE_SVE) +/* CLANG-compatible compiler, targeting ARM with SVE */ +#include +#endif #elif defined(_MSC_VER) /* Microsoft C/C++-compatible compiler */ #include @@ -17,6 +21,10 @@ #elif defined(__GNUC__) && (defined(__ARM_NEON__) || defined(__aarch64__)) /* GCC-compatible compiler, targeting ARM with NEON */ #include +#if defined(__ARM_FEATURE_SVE) +/* GCC-compatible compiler, targeting ARM with SVE */ +#include +#endif #if defined (MISSING_ARM_VLD1) #include #elif defined (MISSING_ARM_VST1) diff --git a/aten/src/ATen/cpu/vec/sve/sve_helper.h b/aten/src/ATen/cpu/vec/sve/sve_helper.h new file mode 100644 index 00000000000000..e511ebb52b2e90 --- /dev/null +++ b/aten/src/ATen/cpu/vec/sve/sve_helper.h @@ -0,0 +1,63 @@ +#pragma once + +#include + +#include + +#if defined(CPU_CAPABILITY_SVE) + +// Define the data type of VLS(vector-length specific). +typedef svbool_t vls_pred_t __attribute__((arm_sve_vector_bits(VECTOR_WIDTH * 8))); +typedef svint8_t vls_int8_t __attribute__((arm_sve_vector_bits(VECTOR_WIDTH * 8))); +typedef svint16_t vls_int16_t __attribute__((arm_sve_vector_bits(VECTOR_WIDTH * 8))); +typedef svint32_t vls_int32_t __attribute__((arm_sve_vector_bits(VECTOR_WIDTH * 8))); +typedef svint64_t vls_int64_t __attribute__((arm_sve_vector_bits(VECTOR_WIDTH * 8))); +typedef svuint8_t vls_uint8_t __attribute__((arm_sve_vector_bits(VECTOR_WIDTH * 8))); +typedef svuint16_t vls_uint16_t __attribute__((arm_sve_vector_bits(VECTOR_WIDTH * 8))); +typedef svuint32_t vls_uint32_t __attribute__((arm_sve_vector_bits(VECTOR_WIDTH * 8))); +typedef svuint64_t vls_uint64_t __attribute__((arm_sve_vector_bits(VECTOR_WIDTH * 8))); +typedef svfloat16_t vls_float16_t __attribute__((arm_sve_vector_bits(VECTOR_WIDTH * 8))); +typedef svfloat32_t vls_float32_t __attribute__((arm_sve_vector_bits(VECTOR_WIDTH * 8))); +typedef svfloat64_t vls_float64_t __attribute__((arm_sve_vector_bits(VECTOR_WIDTH * 8))); + +#define ptrue svptrue_b8() +#define ZERO_S8 svdup_n_s8(0) +#define ZERO_S16 svdup_n_s16(0) +#define ZERO_S32 svdup_n_s32(0) +#define ZERO_S64 svdup_n_s64(0) +#define ZERO_U8 svdup_n_u8(0) +#define ZERO_U16 svdup_n_u16(0) +#define ZERO_U32 svdup_n_u32(0) +#define ZERO_U64 svdup_n_u64(0) +#define ZERO_F16 svdup_n_f16(0.f) +#define ZERO_F32 svdup_n_f32(0.f) +#define ZERO_F64 svdup_n_f64(0.0) +#define ONE_S8 svdup_n_s8(1) +#define ONE_S16 svdup_n_s16(1) +#define ONE_S32 svdup_n_s32(1) +#define ONE_S64 svdup_n_s64(1) +#define ONE_U8 svdup_n_u8(1) +#define ONE_U16 svdup_n_u16(1) +#define ONE_U32 svdup_n_u32(1) +#define ONE_U64 svdup_n_u64(1) +#define ONE_F16 svdup_n_f16(1.f) +#define ONE_F32 svdup_n_f32(1.f) +#define ONE_F64 svdup_n_f64(1.0) +#define ALL_S8_TRUE_MASK svdup_n_s8(0xff) +#define ALL_S8_FALSE_MASK svdup_n_s8(0x0) +#define ALL_S16_TRUE_MASK svdup_n_s16(0xffff) +#define ALL_S16_FALSE_MASK svdup_n_s16(0x0) +#define ALL_S32_TRUE_MASK svdup_n_s32(0xffffffff) +#define ALL_S32_FALSE_MASK svdup_n_s32(0x0) +#define ALL_S64_TRUE_MASK svdup_n_s64(0xffffffffffffffff) +#define ALL_S64_FALSE_MASK svdup_n_s64(0x0) +#define ALL_U8_TRUE_MASK svdup_n_u8(0x01) +#define ALL_U8_FALSE_MASK svdup_n_u8(0x00) +#define ALL_F16_TRUE_MASK svreinterpret_f16_s16(ALL_S16_TRUE_MASK) +#define ALL_F16_FALSE_MASK svreinterpret_f16_s16(ALL_S16_FALSE_MASK) +#define ALL_F32_TRUE_MASK svreinterpret_f32_s32(ALL_S32_TRUE_MASK) +#define ALL_F32_FALSE_MASK svreinterpret_f32_s32(ALL_S32_FALSE_MASK) +#define ALL_F64_TRUE_MASK svreinterpret_f64_s64(ALL_S64_TRUE_MASK) +#define ALL_F64_FALSE_MASK svreinterpret_f64_s64(ALL_S64_FALSE_MASK) + +#endif // defined(CPU_CAPABILITY_SVE) diff --git a/aten/src/ATen/cpu/vec/sve/vec_common_sve.h b/aten/src/ATen/cpu/vec/sve/vec_common_sve.h new file mode 100644 index 00000000000000..6f572e16a4c1fa --- /dev/null +++ b/aten/src/ATen/cpu/vec/sve/vec_common_sve.h @@ -0,0 +1,176 @@ +#pragma once + +// DO NOT DEFINE STATIC DATA IN THIS HEADER! +// See Note [Do not compile initializers with SVE] + +#include + +#include +#include + +#if defined(CPU_CAPABILITY_SVE) +#include +#include +#include +#include +#endif + +namespace at { +namespace vec { +// Note [CPU_CAPABILITY namespace] +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +// This header, and all of its subheaders, will be compiled with +// different architecture flags for each supported set of vector +// intrinsics. So we need to make sure they aren't inadvertently +// linked together. We do this by declaring objects in an `inline +// namespace` which changes the name mangling, but can still be +// accessed as `at::vec`. +inline namespace CPU_CAPABILITY { + +#if defined(CPU_CAPABILITY_SVE) + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ CAST ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +template<> +inline Vectorized cast(const Vectorized& src) { + return svreinterpret_f32_f64(src); +} + +template<> +inline Vectorized cast(const Vectorized& src) { + return svreinterpret_f64_f32(src); +} + +#define DEFINE_FLOAT_INT_CAST(int_t, int_bit, float_t, float_bit) \ +template<> \ +inline Vectorized cast(const Vectorized& src) { \ + return svreinterpret_s##int_bit##_f##float_bit(src); \ +} \ +template<> \ +inline Vectorized cast(const Vectorized& src) { \ + return svreinterpret_f##float_bit##_s##int_bit(src); \ +} + +DEFINE_FLOAT_INT_CAST(int64_t, 64, double, 64) +DEFINE_FLOAT_INT_CAST(int32_t, 32, double, 64) +DEFINE_FLOAT_INT_CAST(int16_t, 16, double, 64) +DEFINE_FLOAT_INT_CAST(int64_t, 64, float, 32) +DEFINE_FLOAT_INT_CAST(int32_t, 32, float, 32) +DEFINE_FLOAT_INT_CAST(int16_t, 16, float, 32) + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ GATHER ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +template +std::enable_if_t> +inline gather(const double* base_addr, const Vectorized& vindex_) { + svint64_t vindex = svasrd_n_s64_x(ptrue, svmul_s64_x(ptrue, vindex_, svdup_n_s64(scale)), 3); + return svld1_gather_s64index_f64(ptrue, base_addr, vindex); +} + +template +std::enable_if_t> +inline gather(const float* base_addr, const Vectorized& vindex_) { + svint32_t vindex = svasrd_n_s32_x(ptrue, svmul_s32_x(ptrue, vindex_, svdup_n_s32(scale)), 2); + return svld1_gather_s32index_f32(ptrue, base_addr, vindex); +} + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ MASK GATHER ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +template +std::enable_if_t> +inline mask_gather(const Vectorized& src, const double* base_addr, + const Vectorized& vindex_, const Vectorized& mask_) { + svbool_t mask = svcmpeq_s64(ptrue, svreinterpret_s64_f64(mask_), + ALL_S64_TRUE_MASK); + svint64_t vindex = svasrd_n_s64_x(ptrue, svmul_s64_x(ptrue, vindex_, svdup_n_s64(scale)), 3); + return svsel_f64(mask, svld1_gather_s64index_f64(mask, base_addr, vindex), src); +} + +template +std::enable_if_t> +inline mask_gather(const Vectorized& src, const float* base_addr, + const Vectorized& vindex_, const Vectorized& mask_) { + svbool_t mask = svcmpeq_s32(ptrue, svreinterpret_s32_f32(mask_), + ALL_S32_TRUE_MASK); + svint32_t vindex = svasrd_n_s32_x(ptrue, svmul_s32_x(ptrue, vindex_, svdup_n_s32(scale)), 2); + return svsel_f32(mask, svld1_gather_s32index_f32(mask, base_addr, vindex), src); +} + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ CONVERT ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +// Only works for inputs in the range: [-2^51, 2^51] +// From: https://stackoverflow.com/a/41148578 +template<> +Vectorized +inline convert_to_int_of_same_size(const Vectorized &src) { + svfloat64_t x = svadd_f64_x(ptrue, src, svdup_n_f64(0x0018000000000000)); + return svsub_s64_x(ptrue, + svreinterpret_s64_f64(x), + svreinterpret_s64_f64(svdup_n_f64(0x0018000000000000))); +} + +template<> +Vectorized +inline convert_to_int_of_same_size(const Vectorized &src) { + return svcvt_s32_f32_x(ptrue, src); +} + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ INTERLEAVE ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +template <> +std::pair, Vectorized> +inline interleave2(const Vectorized& a, const Vectorized& b) { + // inputs: + // a = {a0, a1, a3, a3} + // b = {b0, b1, b2, b3} + // group cols crossing lanes: + // return {a0, b0, a1, b1} + // {a2, b2, a3, b3} + return std::make_pair(Vectorized(svzip1_f64(a, b)), + Vectorized(svzip2_f64(a, b))); +} + +template <> +std::pair, Vectorized> +inline interleave2(const Vectorized& a, const Vectorized& b) { + // inputs: + // a = {a0, a1, a2, a3, a4, a5, a6, a7} + // b = {b0, b1, b2, b3, b4, b5, b6, b7} + // group cols crossing lanes: + // return {a0, b0, a1, b1, a2, b2, a3, b3} + // {a4, b4, a5, b5, a6, b6, a7, b7} + return std::make_pair(Vectorized(svzip1_f32(a, b)), + Vectorized(svzip2_f32(a, b))); +} + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ DEINTERLEAVE ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +template <> +std::pair, Vectorized> +inline deinterleave2(const Vectorized& a, const Vectorized& b) { + // inputs: + // a = {a0, b0, a1, b1} + // b = {a2, b2, a3, b3} + // swap lanes: + // return {a0, a1, a2, a3} + // {b0, b1, b2, b3} + return std::make_pair(Vectorized(svuzp1_f64(a, b)), + Vectorized(svuzp2_f64(a, b))); +} + +template <> +std::pair, Vectorized> +inline deinterleave2(const Vectorized& a, const Vectorized& b) { + // inputs: + // a = {a0, b0, a1, b1, a2, b2, a3, b3} + // b = {a4, b4, a5, b5, a6, b6, a7, b7} + // swap lanes: + // return {a0, a1, a2, a3, a4, a5, a6, a7} + // {b0, b1, b2, b3, b4, b5, b6, b7} + return std::make_pair(Vectorized(svuzp1_f32(a, b)), + Vectorized(svuzp2_f32(a, b))); +} + +#endif // defined(CPU_CAPABILITY_SVE) + +}}} diff --git a/aten/src/ATen/cpu/vec/sve/vec_double.h b/aten/src/ATen/cpu/vec/sve/vec_double.h new file mode 100644 index 00000000000000..911e69da90d4c9 --- /dev/null +++ b/aten/src/ATen/cpu/vec/sve/vec_double.h @@ -0,0 +1,505 @@ +#pragma once + +#include +#include +#include +#include +#if defined(__aarch64__) && defined(AT_BUILD_ARM_VEC256_WITH_SLEEF) +#include +#define USE_SLEEF(sleef_code, non_sleef_code) sleef_code +#else +#define USE_SLEEF(sleef_code, non_sleef_code) non_sleef_code +#endif +namespace at { +namespace vec { +// Note [CPU_CAPABILITY namespace] +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +// This header, and all of its subheaders, will be compiled with +// different architecture flags for each supported set of vector +// intrinsics. So we need to make sure they aren't inadvertently +// linked together. We do this by declaring objects in an `inline +// namespace` which changes the name mangling, but can still be +// accessed as `at::vec`. +inline namespace CPU_CAPABILITY { + +#if defined(CPU_CAPABILITY_SVE) + +template <> class Vectorized { +private: + vls_float64_t values; +public: + using value_type = double; + using size_type = int; + static constexpr size_type size() { + return VECTOR_WIDTH / sizeof(double); + } + Vectorized() {} + Vectorized(svfloat64_t v) : values(v) {} + Vectorized(double val) { + values = svdup_n_f64(val); + } + template> + Vectorized(Args... vals) { + __at_align__ double buffer[size()] = { vals... }; + values = svld1_f64(ptrue, buffer); + } + operator svfloat64_t() const { + return values; + } + static Vectorized blendv(const Vectorized& a, const Vectorized& b, + const Vectorized& mask_) { + svbool_t mask = svcmpeq_s64(ptrue, svreinterpret_s64_f64(mask_), + ALL_S64_TRUE_MASK); + return svsel_f64(mask, b, a); + } + template + static Vectorized arange(double base = 0., step_t step = static_cast(1)) { + __at_align__ double buffer[size()]; + for (int64_t i = 0; i < size(); i++) { + buffer[i] = base + i * step; + } + return svld1_f64(ptrue, buffer); + } + static Vectorized set(const Vectorized& a, const Vectorized& b, + int64_t count = size()) { + if (count == 0) { + return a; + } else if (count < size()) { + return svsel_f64(svwhilelt_b64(0ull, count), b, a); + } + return b; + } + static Vectorized loadu(const void* ptr, int64_t count = size()) { + if (count == size()) + return svld1_f64(ptrue, reinterpret_cast(ptr)); + svbool_t pg = svwhilelt_b64(0ull, count); + return svld1_f64(pg, reinterpret_cast(ptr)); + } + void store(void* ptr, int64_t count = size()) const { + if (count == size()) { + svst1_f64(ptrue, reinterpret_cast(ptr), values); + } else { + svbool_t pg = svwhilelt_b64(0ull, count); + svst1_f64(pg, reinterpret_cast(ptr), values); + } + } + const double& operator[](int idx) const = delete; + double& operator[](int idx) = delete; + int64_t zero_mask() const { + // returns an integer mask where all zero elements are translated to 1-bit and others are translated to 0-bit + int64_t mask = 0; + __at_align__ int64_t mask_array[size()]; + + svbool_t svbool_mask = svcmpeq_f64(ptrue, values, ZERO_F64); + svst1_s64(ptrue, mask_array, svsel_s64(svbool_mask, + ALL_S64_TRUE_MASK, + ALL_S64_FALSE_MASK)); + for (int64_t i = 0; i < size(); ++i) { + if (mask_array[i]) mask |= (1ull << i); + } + return mask; + } + Vectorized isnan() const { + // NaN check + svbool_t mask = svcmpuo_f64(ptrue, values, ZERO_F64); + return svsel_f64(mask, ALL_F64_TRUE_MASK, ALL_F64_FALSE_MASK); + } + bool has_inf_nan() const { + return svptest_any(ptrue, svcmpuo_f64(ptrue, svsub_f64_x(ptrue, values, values), ZERO_F64)); + } + Vectorized map(double (*f)(double)) const { + __at_align__ double tmp[size()]; + store(tmp); + for (int64_t i = 0; i < size(); ++i) { + tmp[i] = f(tmp[i]); + } + return loadu(tmp); + } + Vectorized abs() const { + return svabs_f64_x(ptrue, values); + } + Vectorized angle() const { + const auto nan_vec = svdup_n_f64(NAN); + const auto nan_mask = svcmpuo_f64(ptrue, values, ZERO_F64); + const auto pi = svdup_n_f64(c10::pi); + + const auto neg_mask = svcmplt_f64(ptrue, values, ZERO_F64); + auto angle = svsel_f64(neg_mask, pi, ZERO_F64); + angle = svsel_f64(nan_mask, nan_vec, angle); + return angle; + } + Vectorized real() const { + return *this; + } + Vectorized imag() const { + return Vectorized(0.0); + } + Vectorized conj() const { + return *this; + } + Vectorized acos() const { + return USE_SLEEF(Vectorized(Sleef_acosdx_u10sve(values)),map(std::acos)); + } + Vectorized acosh() const { + return USE_SLEEF( Vectorized(Sleef_acoshdx_u10sve(values)),map(std::acosh)); + } + Vectorized asin() const { + return USE_SLEEF(Vectorized(Sleef_asindx_u10sve(values)),map(std::asin)); + } + Vectorized atan() const { + return USE_SLEEF(Vectorized(Sleef_atandx_u10sve(values)),map(std::atan)); + } + Vectorized atanh() const { + return USE_SLEEF(Vectorized(Sleef_atanhdx_u10sve(values)),map(std::atanh)); + } + Vectorized atan2(const Vectorized &b) const { + USE_SLEEF({return Vectorized(Sleef_atan2dx_u10sve(values, b));}, + { + __at_align__ double tmp[size()]; + __at_align__ double tmp_b[size()]; + store(tmp); + b.store(tmp_b); + for (int64_t i = 0; i < size(); i++) { + tmp[i] = std::atan2(tmp[i], tmp_b[i]); + } + return loadu(tmp); + } + ) + } + Vectorized copysign(const Vectorized &sign) const { + USE_SLEEF( {return Vectorized(Sleef_copysigndx_sve(values, sign));}, + { + __at_align__ double tmp[size()]; + __at_align__ double tmp_sign[size()]; + store(tmp); + sign.store(tmp_sign); + for (int64_t i = 0; i < size(); i++) { + tmp[i] = std::copysign(tmp[i], tmp_sign[i]); + } + return loadu(tmp); + } + ) + } + Vectorized erf() const { + return USE_SLEEF(Vectorized(Sleef_erfdx_u10sve(values)),map(std::erf)); + } + Vectorized erfc() const { + return USE_SLEEF(Vectorized(Sleef_erfcdx_u15sve(values)),map(std::erfc)); + } + Vectorized erfinv() const { + return map(calc_erfinv); + } + Vectorized exp() const { + return USE_SLEEF(Vectorized(Sleef_expdx_u10sve(values)),map(std::exp)); + } + Vectorized exp2() const { + return USE_SLEEF(Vectorized(Sleef_exp2dx_u10sve(values)),map(std::exp2)); + } + Vectorized expm1() const { + return USE_SLEEF(Vectorized(Sleef_expm1dx_u10sve(values)),map(std::expm1)); + } + Vectorized exp_u20() const { + return exp(); + } + Vectorized fmod(const Vectorized& q) const { + USE_SLEEF({return Vectorized(Sleef_fmoddx_sve(values, q));}, + { + __at_align__ double tmp[size()]; + __at_align__ double tmp_q[size()]; + store(tmp); + q.store(tmp_q); + for (int64_t i = 0; i < size(); i++) { + tmp[i] = std::fmod(tmp[i], tmp_q[i]); + } + return loadu(tmp); + } + ) + } + Vectorized hypot(const Vectorized &b) const { + USE_SLEEF({return Vectorized(Sleef_hypotdx_u05sve(values, b));}, + { + __at_align__ double tmp[size()]; + __at_align__ double tmp_b[size()]; + store(tmp); + b.store(tmp_b); + for (int64_t i = 0; i < size(); i++) { + tmp[i] = std::hypot(tmp[i], tmp_b[i]); + } + return loadu(tmp); + }) + } + Vectorized i0() const { + return map(calc_i0); + } + Vectorized i0e() const { + return map(calc_i0e); + } + Vectorized digamma() const { + return map(calc_digamma); + } + Vectorized igamma(const Vectorized &x) const { + __at_align__ double tmp[size()]; + __at_align__ double tmp_x[size()]; + store(tmp); + x.store(tmp_x); + for (int64_t i = 0; i < size(); i++) { + tmp[i] = calc_igamma(tmp[i], tmp_x[i]); + } + return loadu(tmp); + } + Vectorized igammac(const Vectorized &x) const { + __at_align__ double tmp[size()]; + __at_align__ double tmp_x[size()]; + store(tmp); + x.store(tmp_x); + for (int64_t i = 0; i < size(); i++) { + tmp[i] = calc_igammac(tmp[i], tmp_x[i]); + } + return loadu(tmp); + } + Vectorized nextafter(const Vectorized &b) const { + USE_SLEEF( + { + return Vectorized(Sleef_nextafterfx_sve(values, b)); + }, + { + __at_align__ double tmp[size()]; + __at_align__ double tmp_b[size()]; + store(tmp); + b.store(tmp_b); + for (int64_t i = 0; i < size(); ++i) { + tmp[i] = std::nextafter(tmp[i], tmp_b[i]); + } + return loadu(tmp); + } + ) + } + Vectorized log() const { + return USE_SLEEF(Vectorized(Sleef_logdx_u10sve(values)),map(std::log)); + } + Vectorized log2() const { + return USE_SLEEF(Vectorized(Sleef_log2dx_u10sve(values)),map(std::log2)); + } + Vectorized log10() const { + return USE_SLEEF(Vectorized(Sleef_log10dx_u10sve(values)),map(std::log10)); + } + Vectorized log1p() const { + return USE_SLEEF(Vectorized(Sleef_log1pdx_u10sve(values)),map(std::log1p)); + } + Vectorized frac() const; + Vectorized sin() const { + return USE_SLEEF( Vectorized(Sleef_sindx_u10sve(values)),map(std::sin)); + } + Vectorized sinh() const { + return USE_SLEEF(Vectorized(Sleef_sinhdx_u10sve(values)),map(std::sinh)); + } + Vectorized cos() const { + return USE_SLEEF(Vectorized(Sleef_cosdx_u10sve(values)),map(std::cos)); + } + Vectorized cosh() const { + return USE_SLEEF( Vectorized(Sleef_coshdx_u10sve(values)),map(std::cosh)); + } + Vectorized ceil() const { + return svrintp_f64_x(ptrue, values); + } + Vectorized floor() const { + return svrintm_f64_x(ptrue, values); + } + Vectorized neg() const { + return svneg_f64_x(ptrue, values); + } + Vectorized round() const { + return svrinti_f64_x(ptrue, values); + } + Vectorized tan() const { + return USE_SLEEF( Vectorized(Sleef_tandx_u10sve(values)),map(std::tan)); + } + Vectorized tanh() const { + return USE_SLEEF( Vectorized(Sleef_tanhdx_u10sve(values)),map(std::tanh)); + } + Vectorized trunc() const { + return svrintz_f64_x(ptrue, values); + } + Vectorized lgamma() const { + return USE_SLEEF( Vectorized(Sleef_lgammadx_u10sve(values)),map(std::lgamma)); + } + Vectorized sqrt() const { + return svsqrt_f64_x(ptrue, values); + } + Vectorized reciprocal() const { + return svdivr_f64_x(ptrue, values, ONE_F64); + } + Vectorized rsqrt() const { + return svdivr_f64_x(ptrue, svsqrt_f64_x(ptrue, values), ONE_F64); + } + Vectorized pow(const Vectorized &b) const { + USE_SLEEF( {return Vectorized(Sleef_powdx_u10sve(values, b));}, + { + __at_align__ double tmp[size()]; + __at_align__ double tmp_b[size()]; + store(tmp); + b.store(tmp_b); + for (int64_t i = 0; i < size(); i++) { + tmp[i] = std::pow(tmp[i], tmp_b[i]); + } + return loadu(tmp); + } + ) + } + // Comparison using the _CMP_**_OQ predicate. + // `O`: get false if an operand is NaN + // `Q`: do not raise if an operand is NaN + Vectorized operator==(const Vectorized& other) const { + svbool_t mask = svcmpeq_f64(ptrue, values, other); + return svsel_f64(mask, ALL_F64_TRUE_MASK, ALL_F64_FALSE_MASK); + } + + Vectorized operator!=(const Vectorized& other) const { + svbool_t mask = svcmpne_f64(ptrue, values, other); + return svsel_f64(mask, ALL_F64_TRUE_MASK, ALL_F64_FALSE_MASK); + } + + Vectorized operator<(const Vectorized& other) const { + svbool_t mask = svcmplt_f64(ptrue, values, other); + return svsel_f64(mask, ALL_F64_TRUE_MASK, ALL_F64_FALSE_MASK); + } + + Vectorized operator<=(const Vectorized& other) const { + svbool_t mask = svcmple_f64(ptrue, values, other); + return svsel_f64(mask, ALL_F64_TRUE_MASK, ALL_F64_FALSE_MASK); + } + + Vectorized operator>(const Vectorized& other) const { + svbool_t mask = svcmpgt_f64(ptrue, values, other); + return svsel_f64(mask, ALL_F64_TRUE_MASK, ALL_F64_FALSE_MASK); + } + + Vectorized operator>=(const Vectorized& other) const { + svbool_t mask = svcmpge_f64(ptrue, values, other); + return svsel_f64(mask, ALL_F64_TRUE_MASK, ALL_F64_FALSE_MASK); + } + + Vectorized eq(const Vectorized& other) const; + Vectorized ne(const Vectorized& other) const; + Vectorized gt(const Vectorized& other) const; + Vectorized ge(const Vectorized& other) const; + Vectorized lt(const Vectorized& other) const; + Vectorized le(const Vectorized& other) const; +}; + +template <> +Vectorized inline operator+(const Vectorized& a, const Vectorized& b) { + return svadd_f64_x(ptrue, a, b); +} + +template <> +Vectorized inline operator-(const Vectorized& a, const Vectorized& b) { + return svsub_f64_x(ptrue, a, b); +} + +template <> +Vectorized inline operator*(const Vectorized& a, const Vectorized& b) { + return svmul_f64_x(ptrue, a, b); +} + +template <> +Vectorized inline operator/(const Vectorized& a, const Vectorized& b) { + return svdiv_f64_x(ptrue, a, b); +} + +// frac. Implement this here so we can use subtraction +Vectorized inline Vectorized::frac() const { + return *this - this->trunc(); +} + +// Implements the IEEE 754 201X `maximum` operation, which propagates NaN if +// either input is a NaN. +template <> +Vectorized inline maximum(const Vectorized& a, const Vectorized& b) { + return svmax_f64_x(ptrue, a, b); +} + +// Implements the IEEE 754 201X `minimum` operation, which propagates NaN if +// either input is a NaN. +template <> +Vectorized inline minimum(const Vectorized& a, const Vectorized& b) { + return svmin_f64_x(ptrue, a, b); +} + +template <> +Vectorized inline clamp(const Vectorized& a, const Vectorized& min, const Vectorized& max) { + return svmin_f64_x(ptrue, max, svmax_f64_x(ptrue, min, a)); +} + +template <> +Vectorized inline clamp_max(const Vectorized& a, const Vectorized& max) { + return svmin_f64_x(ptrue, max, a); +} + +template <> +Vectorized inline clamp_min(const Vectorized& a, const Vectorized& min) { + return svmax_f64_x(ptrue, min, a); +} + +template <> +Vectorized inline operator&(const Vectorized& a, const Vectorized& b) { + return svreinterpret_f64_s64(svand_s64_x(ptrue, svreinterpret_s64_f64(a), svreinterpret_s64_f64(b))); +} + +template <> +Vectorized inline operator|(const Vectorized& a, const Vectorized& b) { + return svreinterpret_f64_s64(svorr_s64_x(ptrue, svreinterpret_s64_f64(a), svreinterpret_s64_f64(b))); +} + +template <> +Vectorized inline operator^(const Vectorized& a, const Vectorized& b) { + return svreinterpret_f64_s64(sveor_s64_x(ptrue, svreinterpret_s64_f64(a), svreinterpret_s64_f64(b))); +} + +Vectorized inline Vectorized::eq(const Vectorized& other) const { + return (*this == other) & Vectorized(1.0); +} + +Vectorized inline Vectorized::ne(const Vectorized& other) const { + return (*this != other) & Vectorized(1.0); +} + +Vectorized inline Vectorized::gt(const Vectorized& other) const { + return (*this > other) & Vectorized(1.0); +} + +Vectorized inline Vectorized::ge(const Vectorized& other) const { + return (*this >= other) & Vectorized(1.0); +} + +Vectorized inline Vectorized::lt(const Vectorized& other) const { + return (*this < other) & Vectorized(1.0); +} + +Vectorized inline Vectorized::le(const Vectorized& other) const { + return (*this <= other) & Vectorized(1.0); +} + +template <> +inline void convert(const double* src, double* dst, int64_t n) { + const int64_t fraction = n % Vectorized::size(); +#pragma unroll + for (int64_t i = 0; i < n - fraction; i += Vectorized::size()) { + svst1_f64(ptrue, dst + i, svldnt1_f64(ptrue, src + i)); + } +#pragma unroll + for (int64_t i = n - fraction; i < n; i += Vectorized::size()) { + svbool_t pg = svwhilelt_b64(i, n); + svst1_f64(pg, dst + i, svldnt1_f64(pg, src + i)); + } +} + +template <> +Vectorized inline fmadd(const Vectorized& a, const Vectorized& b, const Vectorized& c) { + return svmad_f64_x(ptrue, a, b, c); +} + +#endif // defined(CPU_CAPABILITY_SVE) + +}}} diff --git a/aten/src/ATen/cpu/vec/sve/vec_float.h b/aten/src/ATen/cpu/vec/sve/vec_float.h new file mode 100644 index 00000000000000..4da7cc53710006 --- /dev/null +++ b/aten/src/ATen/cpu/vec/sve/vec_float.h @@ -0,0 +1,570 @@ +#pragma once + +#include +#include +#include +#include +#if defined(__aarch64__) && defined(AT_BUILD_ARM_VEC256_WITH_SLEEF) +#include +#define USE_SLEEF(sleef_code, non_sleef_code) sleef_code +#else +#define USE_SLEEF(sleef_code, non_sleef_code) non_sleef_code +#endif +namespace at { +namespace vec { +// Note [CPU_CAPABILITY namespace] +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +// This header, and all of its subheaders, will be compiled with +// different architecture flags for each supported set of vector +// intrinsics. So we need to make sure they aren't inadvertently +// linked together. We do this by declaring objects in an `inline +// namespace` which changes the name mangling, but can still be +// accessed as `at::vec`. +inline namespace CPU_CAPABILITY { + +#if defined(CPU_CAPABILITY_SVE) + +template <> class Vectorized { +private: + vls_float32_t values; +public: + using value_type = float; + using size_type = int; + static constexpr size_type size() { + return VECTOR_WIDTH / sizeof(float); + } + Vectorized() {} + Vectorized(svfloat32_t v) : values(v) {} + Vectorized(float val) { + values = svdup_n_f32(val); + } + template> + Vectorized(Args... vals) { + __at_align__ float buffer[size()] = { vals... }; + values = svld1_f32(ptrue, buffer); + } + operator svfloat32_t() const { + return values; + } + static Vectorized blendv(const Vectorized& a, const Vectorized& b, + const Vectorized& mask_) { + svbool_t mask = svcmpeq_s32(ptrue, svreinterpret_s32_f32(mask_), + ALL_S32_TRUE_MASK); + return svsel_f32(mask, b, a); + } + template + static Vectorized arange(float base = 0.f, step_t step = static_cast(1)) { + __at_align__ float buffer[size()]; + for (int64_t i = 0; i < size(); i++) { + buffer[i] = base + i * step; + } + return svld1_f32(ptrue, buffer); + } + static Vectorized set(const Vectorized& a, const Vectorized& b, + int64_t count = size()) { + if (count == 0) { + return a; + } else if (count < size()) { + return svsel_f32(svwhilelt_b32(0ull, count), b, a); + } + return b; + } + static Vectorized loadu(const void* ptr, int64_t count = size()) { + if (count == size()) + return svld1_f32(ptrue, reinterpret_cast(ptr)); + svbool_t pg = svwhilelt_b32(0ull, count); + return svld1_f32(pg, reinterpret_cast(ptr)); + } + void store(void* ptr, int64_t count = size()) const { + if (count == size()) { + svst1_f32(ptrue, reinterpret_cast(ptr), values); + } else { + svbool_t pg = svwhilelt_b32(0ull, count); + svst1_f32(pg, reinterpret_cast(ptr), values); + } + } + const float& operator[](int idx) const = delete; + float& operator[](int idx) = delete; + int64_t zero_mask() const { + // returns an integer mask where all zero elements are translated to 1-bit and others are translated to 0-bit + int64_t mask = 0; + __at_align__ int32_t mask_array[size()]; + + svbool_t svbool_mask = svcmpeq_f32(ptrue, values, ZERO_F32); + svst1_s32(ptrue, mask_array, svsel_s32(svbool_mask, + ALL_S32_TRUE_MASK, + ALL_S32_FALSE_MASK)); + for (int64_t i = 0; i < size(); ++i) { + if (mask_array[i]) mask |= (1ull << i); + } + return mask; + } + Vectorized isnan() const { + // NaN check + svbool_t mask = svcmpuo_f32(ptrue, values, ZERO_F32); + return svsel_f32(mask, ALL_F32_TRUE_MASK, ALL_F32_FALSE_MASK); + } + bool has_inf_nan() const { + return svptest_any(ptrue, svcmpuo_f32(ptrue, svsub_f32_x(ptrue, values, values), ZERO_F32)); + } + Vectorized map(float (*f)(float)) const { + __at_align__ float tmp[size()]; + store(tmp); + for (int64_t i = 0; i < size(); ++i) { + tmp[i] = f(tmp[i]); + } + return loadu(tmp); + } + Vectorized abs() const { + return svabs_f32_x(ptrue, values); + } + Vectorized angle() const { + const auto nan_vec = svdup_n_f32(NAN); + const auto nan_mask = svcmpuo_f32(ptrue, values, ZERO_F32); + const auto pi = svdup_n_f32(c10::pi); + + const auto neg_mask = svcmplt_f32(ptrue, values, ZERO_F32); + auto angle = svsel_f32(neg_mask, pi, ZERO_F32); + angle = svsel_f32(nan_mask, nan_vec, angle); + return angle; + } + Vectorized real() const { + return values; + } + Vectorized imag() const { + return Vectorized(0.f); + } + Vectorized conj() const { + return values; + } + Vectorized acos() const { + return USE_SLEEF(Vectorized(Sleef_acosfx_u10sve(values)),map(std::acos)); + } + Vectorized acosh() const { + return USE_SLEEF(Vectorized(Sleef_acoshfx_u10sve(values)),map(std::acosh)); + } + Vectorized asin() const { + return USE_SLEEF(Vectorized(Sleef_asinfx_u10sve(values)),map(std::asin)); + } + Vectorized atan() const { + return USE_SLEEF(Vectorized(Sleef_atanfx_u10sve(values)),map(std::atan)); + } + Vectorized atanh() const { + return USE_SLEEF(Vectorized(Sleef_atanhfx_u10sve(values)),map(std::atanh)); + } + Vectorized atan2(const Vectorized &b) const { + USE_SLEEF({return Vectorized(Sleef_atan2fx_u10sve(values, b));}, + { + __at_align__ float tmp[size()]; + __at_align__ float tmp_b[size()]; + store(tmp); + b.store(tmp_b); + for (int64_t i = 0; i < size(); i++){ + tmp[i] = std::atan2(tmp[i], tmp_b[i]); + } + return loadu(tmp); + } + ) + } + Vectorized copysign(const Vectorized &sign) const { + + USE_SLEEF({return Vectorized(Sleef_copysignfx_sve(values, sign));}, + { + __at_align__ float tmp[size()]; + __at_align__ float tmp_sign[size()]; + store(tmp); + sign.store(tmp_sign); + for (int64_t i = 0; i < size(); ++i) { + tmp[i] = std::copysign(tmp[i], tmp_sign[i]); + } + return loadu(tmp); + }) + } + Vectorized erf() const { + return USE_SLEEF(Vectorized(Sleef_erffx_u10sve(values)),map(std::erf)); + } + Vectorized erfc() const { + return USE_SLEEF(Vectorized(Sleef_erfcfx_u15sve(values)),map(std::erfc)); + } + Vectorized erfinv() const { + return map(calc_erfinv); + } + Vectorized exp() const { + return USE_SLEEF(Vectorized(Sleef_expfx_u10sve(values)),map(std::exp)); + } + Vectorized exp2() const { + return USE_SLEEF(Vectorized(Sleef_exp2fx_u10sve(values)),map(std::exp2)); + } + Vectorized expm1() const { + return USE_SLEEF(Vectorized(Sleef_expm1fx_u10sve(values)),map(std::expm1)); + } + Vectorized exp_u20() const { + return exp(); + } + Vectorized fmod(const Vectorized& q) const { + USE_SLEEF({return Vectorized(Sleef_fmodfx_sve(values, q));}, + { + __at_align__ float tmp[size()]; + __at_align__ float tmp_q[size()]; + store(tmp); + q.store(tmp_q); + for (int64_t i = 0; i < size(); ++i) { + tmp[i] = std::fmod(tmp[i], tmp_q[i]); + } + return loadu(tmp); + }) + } + Vectorized hypot(const Vectorized &b) const { + USE_SLEEF( {return Vectorized(Sleef_hypotfx_u05sve(values, b));}, + { + __at_align__ float tmp[size()]; + __at_align__ float tmp_b[size()]; + store(tmp); + b.store(tmp_b); + for (int64_t i = 0; i < size(); i++) { + tmp[i] = std::hypot(tmp[i], tmp_b[i]); + } + return loadu(tmp); + } + ) + } + Vectorized i0() const { + return map(calc_i0); + } + Vectorized i0e() const { + return map(calc_i0e); + } + Vectorized digamma() const { + return map(calc_digamma); + } + Vectorized igamma(const Vectorized &x) const { + __at_align__ float tmp[size()]; + __at_align__ float tmp_x[size()]; + store(tmp); + x.store(tmp_x); + for (int64_t i = 0; i < size(); i++) { + tmp[i] = calc_igamma(tmp[i], tmp_x[i]); + } + return loadu(tmp); + } + Vectorized igammac(const Vectorized &x) const { + __at_align__ float tmp[size()]; + __at_align__ float tmp_x[size()]; + store(tmp); + x.store(tmp_x); + for (int64_t i = 0; i < size(); i++) { + tmp[i] = calc_igammac(tmp[i], tmp_x[i]); + } + return loadu(tmp); + } + Vectorized nextafter(const Vectorized &b) const { + USE_SLEEF( + { + return Vectorized(Sleef_nextafterfx_sve(values, b)); + }, + { + __at_align__ float tmp[size()]; + __at_align__ float tmp_b[size()]; + store(tmp); + b.store(tmp_b); + for (int64_t i = 0; i < size(); ++i) { + tmp[i] = std::nextafter(tmp[i], tmp_b[i]); + } + return loadu(tmp); + } + ) + } + Vectorized log() const { + return USE_SLEEF(Vectorized(Sleef_logfx_u10sve(values)),map(std::log)); + } + Vectorized log2() const { + return USE_SLEEF(Vectorized(Sleef_log2fx_u10sve(values)),map(std::log2)); + } + Vectorized log10() const { + return USE_SLEEF(Vectorized(Sleef_log10fx_u10sve(values)),map(std::log10)); + } + Vectorized log1p() const { + return USE_SLEEF(Vectorized(Sleef_log1pfx_u10sve(values)),map(std::log1p)); + } + Vectorized frac() const; + Vectorized sin() const { + return USE_SLEEF(Vectorized(Sleef_sinfx_u10sve(values)),map(std::sin)); + } + Vectorized sinh() const { + return USE_SLEEF(Vectorized(Sleef_sinhfx_u10sve(values)),map(std::sinh)); + } + Vectorized cos() const { + return USE_SLEEF(Vectorized(Sleef_cosfx_u10sve(values)),map(std::cos)); + } + Vectorized cosh() const { + return USE_SLEEF(Vectorized(Sleef_coshfx_u10sve(values)),map(std::cosh)); + } + Vectorized ceil() const { + return svrintp_f32_x(ptrue, values); + } + Vectorized floor() const { + return svrintm_f32_x(ptrue, values); + } + Vectorized neg() const { + return svneg_f32_x(ptrue, values); + } + Vectorized round() const { + return svrinti_f32_x(ptrue, values); + } + Vectorized tan() const { + return USE_SLEEF(Vectorized(Sleef_tanfx_u10sve(values)),map(std::tan)); + } + Vectorized tanh() const { + return USE_SLEEF(Vectorized(Sleef_tanhfx_u10sve(values)),map(std::tanh)); + } + Vectorized trunc() const { + return svrintz_f32_x(ptrue, values); + } + Vectorized lgamma() const { + return USE_SLEEF(Vectorized(Sleef_lgammafx_u10sve(values)),map(std::lgamma)); + } + Vectorized sqrt() const { + return svsqrt_f32_x(ptrue, values); + } + Vectorized reciprocal() const { + return svdivr_f32_x(ptrue, values, ONE_F32); + } + Vectorized rsqrt() const { + return svdivr_f32_x(ptrue, svsqrt_f32_x(ptrue, values), ONE_F32); + } + Vectorized pow(const Vectorized &b) const { + USE_SLEEF( {return Vectorized(Sleef_powfx_u10sve(values, b));}, + { + __at_align__ float tmp[size()]; + __at_align__ float tmp_b[size()]; + store(tmp); + b.store(tmp_b); + for (int64_t i = 0; i < size(); i++) { + tmp[i] = std::pow(tmp[i], tmp_b[i]); + } + return loadu(tmp); + } + ) + } + // Comparison using the _CMP_**_OQ predicate. + // `O`: get false if an operand is NaN + // `Q`: do not raise if an operand is NaN + Vectorized operator==(const Vectorized& other) const { + svbool_t mask = svcmpeq_f32(ptrue, values, other); + return svsel_f32(mask, ALL_F32_TRUE_MASK, ALL_F32_FALSE_MASK); + } + + Vectorized operator!=(const Vectorized& other) const { + svbool_t mask = svcmpne_f32(ptrue, values, other); + return svsel_f32(mask, ALL_F32_TRUE_MASK, ALL_F32_FALSE_MASK); + } + + Vectorized operator<(const Vectorized& other) const { + svbool_t mask = svcmplt_f32(ptrue, values, other); + return svsel_f32(mask, ALL_F32_TRUE_MASK, ALL_F32_FALSE_MASK); + } + + Vectorized operator<=(const Vectorized& other) const { + svbool_t mask = svcmple_f32(ptrue, values, other); + return svsel_f32(mask, ALL_F32_TRUE_MASK, ALL_F32_FALSE_MASK); + } + + Vectorized operator>(const Vectorized& other) const { + svbool_t mask = svcmpgt_f32(ptrue, values, other); + return svsel_f32(mask, ALL_F32_TRUE_MASK, ALL_F32_FALSE_MASK); + } + + Vectorized operator>=(const Vectorized& other) const { + svbool_t mask = svcmpge_f32(ptrue, values, other); + return svsel_f32(mask, ALL_F32_TRUE_MASK, ALL_F32_FALSE_MASK); + } + + Vectorized eq(const Vectorized& other) const; + Vectorized ne(const Vectorized& other) const; + Vectorized gt(const Vectorized& other) const; + Vectorized ge(const Vectorized& other) const; + Vectorized lt(const Vectorized& other) const; + Vectorized le(const Vectorized& other) const; +}; + +template <> +Vectorized inline operator+(const Vectorized& a, const Vectorized& b) { + return svadd_f32_x(ptrue, a, b); +} + +template <> +Vectorized inline operator-(const Vectorized& a, const Vectorized& b) { + return svsub_f32_x(ptrue, a, b); +} + +template <> +Vectorized inline operator*(const Vectorized& a, const Vectorized& b) { + return svmul_f32_x(ptrue, a, b); +} + +template <> +Vectorized inline operator/(const Vectorized& a, const Vectorized& b) { + return svdiv_f32_x(ptrue, a, b); +} + +// frac. Implement this here so we can use subtraction +Vectorized inline Vectorized::frac() const { + return *this - this->trunc(); +} + +// Implements the IEEE 754 201X `maximum` operation, which propagates NaN if +// either input is a NaN. +template <> +Vectorized inline maximum(const Vectorized& a, const Vectorized& b) { + return svmax_f32_x(ptrue, a, b); +} + +// Implements the IEEE 754 201X `minimum` operation, which propagates NaN if +// either input is a NaN. +template <> +Vectorized inline minimum(const Vectorized& a, const Vectorized& b) { + return svmin_f32_x(ptrue, a, b); +} + +template <> +Vectorized inline clamp(const Vectorized& a, const Vectorized& min, const Vectorized& max) { + return svmin_f32_x(ptrue, max, svmax_f32_x(ptrue, min, a)); +} + +template <> +Vectorized inline clamp_max(const Vectorized& a, const Vectorized& max) { + return svmin_f32_x(ptrue, max, a); +} + +template <> +Vectorized inline clamp_min(const Vectorized& a, const Vectorized& min) { + return svmax_f32_x(ptrue, min, a); +} + +template <> +Vectorized inline operator&(const Vectorized& a, const Vectorized& b) { + return svreinterpret_f32_s32(svand_s32_x(ptrue, svreinterpret_s32_f32(a), svreinterpret_s32_f32(b))); +} + +template <> +Vectorized inline operator|(const Vectorized& a, const Vectorized& b) { + return svreinterpret_f32_s32(svorr_s32_x(ptrue, svreinterpret_s32_f32(a), svreinterpret_s32_f32(b))); +} + +template <> +Vectorized inline operator^(const Vectorized& a, const Vectorized& b) { + return svreinterpret_f32_s32(sveor_s32_x(ptrue, svreinterpret_s32_f32(a), svreinterpret_s32_f32(b))); +} + +Vectorized inline Vectorized::eq(const Vectorized& other) const { + return (*this == other) & Vectorized(1.0f); +} + +Vectorized inline Vectorized::ne(const Vectorized& other) const { + return (*this != other) & Vectorized(1.0f); +} + +Vectorized inline Vectorized::gt(const Vectorized& other) const { + return (*this > other) & Vectorized(1.0f); +} + +Vectorized inline Vectorized::ge(const Vectorized& other) const { + return (*this >= other) & Vectorized(1.0f); +} + +Vectorized inline Vectorized::lt(const Vectorized& other) const { + return (*this < other) & Vectorized(1.0f); +} + +Vectorized inline Vectorized::le(const Vectorized& other) const { + return (*this <= other) & Vectorized(1.0f); +} + +template <> +inline void convert(const float* src, float* dst, int64_t n) { + const int64_t fraction = n % Vectorized::size(); +#pragma unroll + for (int64_t i = 0; i < n - fraction; i += Vectorized::size()) { + svst1_f32(ptrue, dst + i, svldnt1_f32(ptrue, src + i)); + } +#pragma unroll + for (int64_t i = n - fraction; i < n; i += Vectorized::size()) { + svbool_t pg = svwhilelt_b32(i, n); + svst1_f32(pg, dst + i, svldnt1_f32(pg, src + i)); + } +} + +template <> +inline void convert(const float *src, at::Half *dst, int64_t n) { + const int64_t fraction = n % Vectorized::size(); + svbool_t pg_16 = svwhilelt_b16(0ull, Vectorized::size()); + svbool_t pg_32 = svwhilelt_b32(0ull, Vectorized::size()); +#pragma unroll + for (int64_t i = 0; i < n - fraction; i += Vectorized::size()) { + svfloat16_t src_vec = svuzp1_f16(svcvt_f16_f32_x(ptrue, svldnt1_f32(pg_32, src + i)), + ZERO_F16); + svst1_f16(pg_16, reinterpret_cast(dst) + i, src_vec); + } +#pragma unroll + for (int64_t i = n - fraction; i < n; i += Vectorized::size()) { + pg_16 = svwhilelt_b16(i, n); + pg_32 = svwhilelt_b32(i, n); + svfloat16_t src_vec = svuzp1_f16(svcvt_f16_f32_x(ptrue, svldnt1_f32(pg_32, src + i)), + ZERO_F16); + svst1_f16(pg_16, reinterpret_cast(dst) + i, src_vec); + } +} + +template <> +inline void convert(const at::Half *src, float *dst, int64_t n) { + const int64_t fraction = n % Vectorized::size(); + svbool_t pg_16 = svwhilelt_b16(0ull, Vectorized::size()); + svbool_t pg_32 = svwhilelt_b32(0ull, Vectorized::size()); +#pragma unroll + for (int64_t i = 0; i < n - fraction; i += Vectorized::size()) { + svfloat16_t src_vec = svzip1_f16(svldnt1_f16(pg_16, reinterpret_cast(src) + i), + ZERO_F16); + svst1_f32(pg_32, dst + i, svcvt_f32_f16_x(ptrue, src_vec)); + } +#pragma unroll + for (int64_t i = n - fraction; i < n; i += Vectorized::size()) { + pg_16 = svwhilelt_b16(i, n); + pg_32 = svwhilelt_b32(i, n); + svfloat16_t src_vec = svzip1_f16(svldnt1_f16(pg_16, reinterpret_cast(src) + i), + ZERO_F16); + svst1_f32(pg_32, dst + i, svcvt_f32_f16_x(ptrue, src_vec)); + } +} + +template <> +inline void convert(const bool *src, float *dst, int64_t n) { + const int64_t fraction = n % Vectorized::size(); + svbool_t pg_8 = svwhilelt_b8(0ull, Vectorized::size()); + svbool_t pg_32 = svwhilelt_b32(0ull, Vectorized::size()); +#pragma unroll + for (int64_t i = 0; i < n - fraction; i += Vectorized::size()) { + svuint8_t src_vec_u8 = svldnt1_u8(pg_8, reinterpret_cast(src) + i); + svuint32_t src_vec_u32 = svunpklo_u32(svunpklo_u16(src_vec_u8)); + svbool_t mask = svcmpne_u32(pg_32, src_vec_u32, ZERO_U32); + svst1_f32(pg_32, dst + i, svsel_f32(mask, ONE_F32, ZERO_F32)); + } +#pragma unroll + for (int64_t i = n - fraction; i < n; i += Vectorized::size()) { + pg_8 = svwhilelt_b8(i, n); + pg_32 = svwhilelt_b32(i, n); + svuint8_t src_vec_u8 = svldnt1_u8(pg_8, reinterpret_cast(src) + i); + svuint32_t src_vec_u32 = svunpklo_u32(svunpklo_u16(src_vec_u8)); + svbool_t mask = svcmpne_u32(pg_32, src_vec_u32, ZERO_U32); + svst1_f32(pg_32, dst + i, svsel_f32(mask, ONE_F32, ZERO_F32)); + } +} + +template <> +Vectorized inline fmadd(const Vectorized& a, const Vectorized& b, const Vectorized& c) { + return svmad_f32_x(ptrue, a, b, c); +} + +#endif // defined(CPU_CAPABILITY_SVE) + +}}} diff --git a/aten/src/ATen/cpu/vec/sve/vec_int.h b/aten/src/ATen/cpu/vec/sve/vec_int.h new file mode 100644 index 00000000000000..6a081bd00a7514 --- /dev/null +++ b/aten/src/ATen/cpu/vec/sve/vec_int.h @@ -0,0 +1,410 @@ +#pragma once + +#include +#include +#include + +namespace at { +namespace vec { +// Note [CPU_CAPABILITY namespace] +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +// This header, and all of its subheaders, will be compiled with +// different architecture flags for each supported set of vector +// intrinsics. So we need to make sure they aren't inadvertently +// linked together. We do this by declaring objects in an `inline +// namespace` which changes the name mangling, but can still be +// accessed as `at::vec`. +inline namespace CPU_CAPABILITY { + +#if defined(CPU_CAPABILITY_SVE) + +#define VEC_INT_SVE_TEMPLATE(vl, bit) \ +template <> class Vectorized { \ +private: \ + vls_int##bit##_t values; \ +public: \ + using value_type = int##bit##_t; \ + using size_type = int; \ + static constexpr size_type size() { \ + return vl; \ + } \ + Vectorized() {} \ + Vectorized(svint##bit##_t v) : values(v) {} \ + Vectorized(int##bit##_t val) { \ + values = svdup_n_s##bit(val); \ + } \ + template> \ + Vectorized(Args... vals) { \ + __at_align__ int##bit##_t buffer[size()] = { vals... }; \ + values = svld1_s##bit(ptrue, buffer); \ + } \ + operator svint##bit##_t() const { \ + return values; \ + } \ + static Vectorized blendv(const Vectorized& a, \ + const Vectorized& b, \ + const Vectorized& mask_) { \ + svbool_t mask = svcmpeq_s##bit(ptrue, mask_, ALL_S##bit##_TRUE_MASK); \ + return svsel_s##bit(mask, b, a); \ + } \ + /* step sometimes requires a higher precision type (e.g., T=int, step_t=double) */ \ + template \ + static Vectorized arange(int##bit##_t base = 0, step_t step = static_cast(1)) { \ + __at_align__ int##bit##_t buffer[size()]; \ + for (int64_t i = 0; i < size(); i++) { \ + buffer[i] = base + i * step; \ + } \ + return svld1_s##bit(ptrue, buffer); \ + } \ + static Vectorized set(const Vectorized& a, \ + const Vectorized& b, \ + int##bit##_t count = size()) { \ + if (count == 0) { \ + return a; \ + } else if (count < size()) { \ + return svsel_s##bit(svwhilelt_b##bit(0ull, count), b, a); \ + } \ + return b; \ + } \ + static Vectorized loadu(const void* ptr, int64_t count = size()) { \ + if (count == size()) \ + return svld1_s##bit(ptrue, reinterpret_cast(ptr)); \ + svbool_t pg = svwhilelt_b##bit(0ull, count); \ + return svld1_s##bit(pg, reinterpret_cast(ptr)); \ + } \ + void store(void* ptr, int64_t count = size()) const { \ + if (count == size()) { \ + svst1_s##bit(ptrue, reinterpret_cast(ptr), values); \ + } else { \ + svbool_t pg = svwhilelt_b##bit(0ull, count); \ + svst1_s##bit(pg, reinterpret_cast(ptr), values); \ + } \ + } \ + const int##bit##_t& operator[](int idx) const = delete; \ + int##bit##_t& operator[](int idx) = delete; \ + Vectorized abs() const { \ + return svabs_s##bit##_x(ptrue, values); \ + } \ + Vectorized real() const { \ + return values; \ + } \ + Vectorized imag() const { \ + return svdup_n_s##bit(0); \ + } \ + Vectorized conj() const { \ + return values; \ + } \ + Vectorized frac() const; \ + Vectorized neg() const { \ + return svneg_s##bit##_x(ptrue, values); \ + } \ + Vectorized operator==(const Vectorized& other) const { \ + svbool_t mask = svcmpeq_s##bit(ptrue, values, other); \ + return svsel_s##bit(mask, ALL_S##bit##_TRUE_MASK, ALL_S##bit##_FALSE_MASK); \ + } \ + Vectorized operator!=(const Vectorized& other) const { \ + svbool_t mask = svcmpne_s##bit(ptrue, values, other); \ + return svsel_s##bit(mask, ALL_S##bit##_TRUE_MASK, ALL_S##bit##_FALSE_MASK); \ + } \ + Vectorized operator<(const Vectorized& other) const { \ + svbool_t mask = svcmplt_s##bit(ptrue, values, other); \ + return svsel_s##bit(mask, ALL_S##bit##_TRUE_MASK, ALL_S##bit##_FALSE_MASK); \ + } \ + Vectorized operator<=(const Vectorized& other) const { \ + svbool_t mask = svcmple_s##bit(ptrue, values, other); \ + return svsel_s##bit(mask, ALL_S##bit##_TRUE_MASK, ALL_S##bit##_FALSE_MASK); \ + } \ + Vectorized operator>(const Vectorized& other) const { \ + svbool_t mask = svcmpgt_s##bit(ptrue, values, other); \ + return svsel_s##bit(mask, ALL_S##bit##_TRUE_MASK, ALL_S##bit##_FALSE_MASK); \ + } \ + Vectorized operator>=(const Vectorized& other) const { \ + svbool_t mask = svcmpge_s##bit(ptrue, values, other); \ + return svsel_s##bit(mask, ALL_S##bit##_TRUE_MASK, ALL_S##bit##_FALSE_MASK); \ + } \ + Vectorized eq(const Vectorized& other) const; \ + Vectorized ne(const Vectorized& other) const; \ + Vectorized gt(const Vectorized& other) const; \ + Vectorized ge(const Vectorized& other) const; \ + Vectorized lt(const Vectorized& other) const; \ + Vectorized le(const Vectorized& other) const; \ +}; \ +template <> \ +Vectorized inline operator+(const Vectorized& a, \ + const Vectorized& b) { \ + return svadd_s##bit##_x(ptrue, a, b); \ +} \ +template <> \ +Vectorized inline operator-(const Vectorized& a, \ + const Vectorized& b) { \ + return svsub_s##bit##_x(ptrue, a, b); \ +} \ +template <> \ +Vectorized inline operator*(const Vectorized& a, \ + const Vectorized& b) { \ + return svmul_s##bit##_x(ptrue, a, b); \ +} \ +template <> \ +Vectorized inline maximum(const Vectorized& a, \ + const Vectorized& b) { \ + return svmax_s##bit##_x(ptrue, a, b); \ +} \ +template <> \ +Vectorized inline minimum(const Vectorized& a, \ + const Vectorized& b) { \ + return svmin_s##bit##_x(ptrue, a, b); \ +} \ +template <> \ +Vectorized inline clamp(const Vectorized& a, \ + const Vectorized& min, \ + const Vectorized& max) { \ + return svmin_s##bit##_x(ptrue, max, svmax_s##bit##_x(ptrue, min, a)); \ +} \ +template <> \ +Vectorized inline clamp_max(const Vectorized& a, \ + const Vectorized& max) { \ + return svmin_s##bit##_x(ptrue, max, a); \ +} \ +template <> \ +Vectorized inline clamp_min(const Vectorized& a, \ + const Vectorized& min) { \ + return svmax_s##bit##_x(ptrue, min, a); \ +} \ +template <> \ +Vectorized inline operator&(const Vectorized& a, \ + const Vectorized& b) { \ + return svand_s##bit##_x(ptrue, a, b); \ +} \ +template <> \ +Vectorized inline operator|(const Vectorized& a, \ + const Vectorized& b) { \ + return svorr_s##bit##_x(ptrue, a, b); \ +} \ +template <> \ +Vectorized inline operator^(const Vectorized& a, \ + const Vectorized& b) { \ + return sveor_s##bit##_x(ptrue, a, b); \ +} \ +template <> \ +inline Vectorized operator~(const Vectorized& a) { \ + return sveor_s##bit##_x(ptrue, a, svdup_n_s##bit(-1)); \ +} \ +Vectorized inline Vectorized::eq(const Vectorized& other) const { \ + return (*this == other) & Vectorized(1); \ +} \ +Vectorized inline Vectorized::ne(const Vectorized& other) const { \ + return (*this != other) & Vectorized(1); \ +} \ +Vectorized inline Vectorized::gt(const Vectorized& other) const { \ + return (*this > other) & Vectorized(1); \ +} \ +Vectorized inline Vectorized::ge(const Vectorized& other) const { \ + return (*this >= other) & Vectorized(1); \ +} \ +Vectorized inline Vectorized::lt(const Vectorized& other) const { \ + return (*this < other) & Vectorized(1); \ +} \ +Vectorized inline Vectorized::le(const Vectorized& other) const { \ + return (*this <= other) & Vectorized(1); \ +} + +VEC_INT_SVE_TEMPLATE(VECTOR_WIDTH / sizeof(int64_t), 64) +VEC_INT_SVE_TEMPLATE(VECTOR_WIDTH / sizeof(int32_t), 32) +VEC_INT_SVE_TEMPLATE(VECTOR_WIDTH / sizeof(int16_t), 16) +VEC_INT_SVE_TEMPLATE(VECTOR_WIDTH / sizeof(int8_t), 8) + +template +Vectorized inline intdiv_nosve(const Vectorized& a, const Vectorized& b) { + T values_a[Vectorized::size()]; + T values_b[Vectorized::size()]; + a.store(values_a); + b.store(values_b); + for (int i = 0; i != Vectorized::size(); i++) { + values_a[i] /= values_b[i]; + } + return Vectorized::loadu(values_a); +} + +template <> +Vectorized inline operator/(const Vectorized& a, const Vectorized& b) { + return svdiv_s64_x(ptrue, a, b); +} + +template <> +Vectorized inline operator/(const Vectorized& a, const Vectorized& b) { + return svdiv_s32_x(ptrue, a, b); +} + +template <> +Vectorized inline operator/(const Vectorized& a, const Vectorized& b) { + return intdiv_nosve(a, b); +} + +template <> +Vectorized inline operator/(const Vectorized& a, const Vectorized& b) { + return intdiv_nosve(a, b); +} + +template <> +inline void convert(const int32_t *src, int64_t *dst, int64_t n) { + const int64_t fraction = n % Vectorized::size(); + svbool_t pg_32 = svwhilelt_b32(0ull, Vectorized::size()); + svbool_t pg_64 = svwhilelt_b64(0ull, Vectorized::size()); +#pragma unroll + for (int64_t i = 0; i < n - fraction; i += Vectorized::size()) + svst1_s64(pg_64, dst + i, svunpklo_s64(svldnt1_s32(pg_32, src + i))); +#pragma unroll + for (int64_t i = n - fraction; i < n; i += Vectorized::size()) { + pg_32 = svwhilelt_b32(i, n); + pg_64 = svwhilelt_b64(i, n); + svst1_s64(pg_64, dst + i, svunpklo_s64(svldnt1_s32(pg_32, src + i))); + } +} + +template <> +inline void convert(const int64_t *src, float *dst, int64_t n) { + const int64_t fraction = n % Vectorized::size(); + svbool_t pg_32 = svwhilelt_b32(0ull, Vectorized::size()); + svbool_t pg_64 = svwhilelt_b64(0ull, Vectorized::size()); +#pragma unroll + for (int64_t i = 0; i < n - fraction; i += Vectorized::size()) { + svint64_t src_vec_s64 = svldnt1_s64(pg_64, src + i); + svfloat32_t src_vec_f32 = svuzp1_f32(svcvt_f32_s64_x(pg_64, src_vec_s64), ZERO_F32); + svst1_f32(pg_32, dst + i, src_vec_f32); + } +#pragma unroll + for (int64_t i = n - fraction; i < n; i += Vectorized::size()) { + pg_32 = svwhilelt_b32(i, n); + pg_64 = svwhilelt_b64(i, n); + svint64_t src_vec_s64 = svldnt1_s64(pg_64, src + i); + svfloat32_t src_vec_f32 = svuzp1_f32(svcvt_f32_s64_x(pg_64, src_vec_s64), ZERO_F32); + svst1_f32(pg_32, dst + i, src_vec_f32); + } +} + +template <> +inline void convert(const int32_t *src, float *dst, int64_t n) { + const int64_t fraction = n % Vectorized::size(); + svbool_t pg = svwhilelt_b32(0ull, Vectorized::size()); +#pragma unroll + for (int64_t i = 0; i < n - fraction; i += Vectorized::size()) { + svint32_t src_vec = svldnt1_s32(pg, src + i); + svst1_f32(pg, dst + i, svcvt_f32_s32_x(pg, src_vec)); + } +#pragma unroll + for (int64_t i = n - fraction; i < n; i += Vectorized::size()) { + pg = svwhilelt_b32(i, n); + svint32_t src_vec = svldnt1_s32(pg, src + i); + svst1_f32(pg, dst + i, svcvt_f32_s32_x(pg, src_vec)); + } +} + +template <> +inline void convert(const bool *src, int64_t *dst, int64_t n) { + const int64_t fraction = n % Vectorized::size(); + svbool_t pg_8 = svwhilelt_b8(0ull, Vectorized::size()); + svbool_t pg_64 = svwhilelt_b64(0ull, Vectorized::size()); +#pragma unroll + for (int64_t i = 0; i < n - fraction; i += Vectorized::size()) { + svuint8_t src_vec_u8 = svldnt1_u8(pg_8, reinterpret_cast(src) + i); + svuint64_t src_vec_u64 = svunpklo_u64(svunpklo_u32(svunpklo_u16(src_vec_u8))); + svbool_t mask = svcmpne_u64(pg_64, src_vec_u64, ZERO_U64); + svst1_s64(pg_64, dst + i, svsel_s64(mask, ONE_S64, ZERO_S64)); + } +#pragma unroll + for (int64_t i = n - fraction; i < n; i += Vectorized::size()) { + pg_8 = svwhilelt_b8(i, n); + pg_64 = svwhilelt_b64(i, n); + svuint8_t src_vec_u8 = svldnt1_u8(pg_8, reinterpret_cast(src) + i); + svuint64_t src_vec_u64 = svunpklo_u64(svunpklo_u32(svunpklo_u16(src_vec_u8))); + svbool_t mask = svcmpne_u64(pg_64, src_vec_u64, ZERO_U64); + svst1_s64(pg_64, dst + i, svsel_s64(mask, ONE_S64, ZERO_S64)); + } +} + +template <> +inline void convert(const bool *src, int32_t *dst, int64_t n) { + const int64_t fraction = n % Vectorized::size(); + svbool_t pg_8 = svwhilelt_b8(0ull, Vectorized::size()); + svbool_t pg_32 = svwhilelt_b32(0ull, Vectorized::size()); +#pragma unroll + for (int64_t i = 0; i < n - fraction; i += Vectorized::size()) { + svuint8_t src_vec_u8 = svldnt1_u8(pg_8, reinterpret_cast(src) + i); + svuint32_t src_vec_u32 = svunpklo_u32(svunpklo_u16(src_vec_u8)); + svbool_t mask = svcmpne_u32(pg_32, src_vec_u32, ZERO_U32); + svst1_s32(pg_32, dst + i, svsel_s32(mask, ONE_S32, ZERO_S32)); + } +#pragma unroll + for (int64_t i = n - fraction; i < n; i += Vectorized::size()) { + pg_8 = svwhilelt_b8(i, n); + pg_32 = svwhilelt_b32(i, n); + svuint8_t src_vec_u8 = svldnt1_u8(pg_8, reinterpret_cast(src) + i); + svuint32_t src_vec_u32 = svunpklo_u32(svunpklo_u16(src_vec_u8)); + svbool_t mask = svcmpne_u32(pg_32, src_vec_u32, ZERO_U32); + svst1_s32(pg_32, dst + i, svsel_s32(mask, ONE_S32, ZERO_S32)); + } +} + +template <> +inline void convert(const uint8_t *src, bool *dst, int64_t n) { + const int64_t fraction = n % Vectorized::size(); + svbool_t pg = svwhilelt_b8(0ull, Vectorized::size()); +#pragma unroll + for (int64_t i = 0; i < n - fraction; i += Vectorized::size()) { + svbool_t mask = svcmpne_u8(pg, svldnt1_u8(pg, src + i), ZERO_U8); + svst1_u8(pg, reinterpret_cast(dst) + i, + svsel_u8(mask, ALL_U8_TRUE_MASK, ALL_U8_FALSE_MASK)); + } +#pragma unroll + for (int64_t i = n - fraction; i < n; i += Vectorized::size()) { + pg = svwhilelt_b8(i, n); + svbool_t mask = svcmpne_u8(pg, svldnt1_u8(pg, src + i), ZERO_U8); + svst1_u8(pg, reinterpret_cast(dst) + i, + svsel_u8(mask, ALL_U8_TRUE_MASK, ALL_U8_FALSE_MASK)); + } +} + +template <> +Vectorized inline operator<<(const Vectorized& a, const Vectorized& b) { + return svlsl_s64_x(ptrue, a, svreinterpret_u64_s64(b)); +} + +template <> +Vectorized inline operator<<(const Vectorized& a, const Vectorized& b) { + return svlsl_s32_x(ptrue, a, svreinterpret_u32_s32(b)); +} + +template <> +Vectorized inline operator<<(const Vectorized& a, const Vectorized& b) { + return svlsl_s16_x(ptrue, a, svreinterpret_u16_s16(b)); +} + +template <> +Vectorized inline operator<<(const Vectorized& a, const Vectorized& b) { + return svlsl_s8_x(ptrue, a, svreinterpret_u8_s8(b)); +} + +template <> +Vectorized inline operator>>(const Vectorized& a, const Vectorized& b) { + return svasr_s64_x(ptrue, a, svreinterpret_u64_s64(b)); +} + +template <> +Vectorized inline operator>>(const Vectorized& a, const Vectorized& b) { + return svasr_s32_x(ptrue, a, svreinterpret_u32_s32(b)); +} + +template <> +Vectorized inline operator>>(const Vectorized& a, const Vectorized& b) { + return svasr_s16_x(ptrue, a, svreinterpret_u16_s16(b)); +} + +template <> +Vectorized inline operator>>(const Vectorized& a, const Vectorized& b) { + return svasr_s8_x(ptrue, a, svreinterpret_u8_s8(b)); +} + +#endif // defined(CPU_CAPABILITY_SVE) + +}}} diff --git a/aten/src/ATen/cpu/vec/sve/vec_qint.h b/aten/src/ATen/cpu/vec/sve/vec_qint.h new file mode 100644 index 00000000000000..7c49c041ddf2ff --- /dev/null +++ b/aten/src/ATen/cpu/vec/sve/vec_qint.h @@ -0,0 +1,567 @@ +#pragma once + +// DO NOT DEFINE STATIC DATA IN THIS HEADER! +// See Note [Do not compile initializers with SVE] + +#include +#include +#include +#include +#include +#include + +#include + +// This file defines Vectorized<> for the quantized types. +// +// +// Currently, we simply use these classes as efficient converters between +// the quantized types and Vectorized, usually in bandwidth-bound cases +// where doing the arithmetic in full-precision is acceptable (e.g. +// elementwise operators). +// +// +// Conversions are as follows: +// Vectorized -> 4x Vectorized +// Vectorized -> 4x Vectorized +// Vectorized -> 1x Vectorized +// +// The size of the returned float vector is specified by the special +// constexpr function float_num_vecs. The type of the value returned +// from dequantize (and expected as an argument to quantize) is +// specified by float_vec_return_type. +// +// When writing kernels with these vectors, it is expected that floating- +// point operations will be carried out in a loop over Vectorized::float_num_vecs +// iterations. + +namespace at { +namespace vec { +// Note [CPU_CAPABILITY namespace] +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +// This header, and all of its subheaders, will be compiled with +// different architecture flags for each supported set of vector +// intrinsics. So we need to make sure they aren't inadvertently +// linked together. We do this by declaring objects in an `inline +// namespace` which changes the name mangling, but can still be +// accessed as `at::vec`. +inline namespace CPU_CAPABILITY { + +#if defined(CPU_CAPABILITY_SVE) + +// NOTE: These are low-performance implementations that we fall back on +// if we are not building with SVE. This may not be an issue, because +// currently for quantization we assume the user has at least SVE +// installed, so these can simply act as a reference implementation. +// +// If in the future we relax this requirement (SVE+), we should probably +// revisit these implementations + +template < + typename T, + typename float_vec_return_type_, + typename int_vec_return_type_, + int size_> +struct VectorizedQuantizedConverter { + using size_type = int; + static constexpr size_type size() { + return size_; + } + + static constexpr int float_num_vecs() { + return size() / Vectorized::size(); + } + + static constexpr int int_num_vecs() { + return size() / Vectorized::size(); + } + + using float_vec_return_type = float_vec_return_type_; + using int_vec_return_type = int_vec_return_type_; + + using value_type = typename T::underlying; + std::array vals; + + VectorizedQuantizedConverter(T val) { + for (size_t i = 0; i < size(); ++i) { + vals[i] = val.val_; + } + } + + VectorizedQuantizedConverter(const void* ptr) { + memcpy(vals.data(), ptr, sizeof(value_type) * size()); + } + + void store(void* ptr, int count = size()) const { + memcpy(ptr, vals.data(), count * sizeof(value_type)); + } + + float_vec_return_type dequantize( + Vectorized scale, + Vectorized zero_point, + Vectorized scale_zp_premul) const { + float_vec_return_type rv; + float tmp_scale[Vectorized::size()]; + float tmp_zero_point[Vectorized::size()]; + scale.store(tmp_scale); + zero_point.store(tmp_zero_point); + for (int i = 0; i < float_num_vecs(); ++i) { + float tmp_vals[Vectorized::size()]; + for (int j = 0; j < Vectorized::size(); ++j) { + tmp_vals[j] = + at::native::dequantize_val(tmp_scale[j], tmp_zero_point[j], T(vals[Vectorized::size() * i + j])); + } + rv[i] = Vectorized::loadu(tmp_vals); + } + return rv; + } + + float_vec_return_type dequantize( + Vectorized scale, + Vectorized zero_point) const { + float_vec_return_type rv; + float tmp_scale[Vectorized::size()]; + float tmp_zero_point[Vectorized::size()]; + scale.store(tmp_scale); + zero_point.store(tmp_zero_point); + for (int i = 0; i < float_num_vecs(); ++i) { + float tmp_vals[Vectorized::size()]; + for (int j = 0; j < Vectorized::size(); ++j) { + tmp_vals[j] = + at::native::dequantize_val(tmp_scale[j], tmp_zero_point[j], T(vals[Vectorized::size() * i + j])); + } + rv[i] = Vectorized::loadu(tmp_vals); + } + return rv; + } + + protected: + VectorizedQuantizedConverter() {} +}; + +template <> +struct Vectorized : public VectorizedQuantizedConverter< + c10::qint32, + std::array, 1>, + std::array, 1>, + VECTOR_WIDTH / 4> { + Vectorized() + : VectorizedQuantizedConverter< + c10::qint32, + std::array, 1>, + std::array, 1>, + VECTOR_WIDTH / 4>() {} + Vectorized(c10::qint32 val) + : VectorizedQuantizedConverter< + c10::qint32, + std::array, 1>, + std::array, 1>, + VECTOR_WIDTH / 4>(val) {} + Vectorized(const void* ptr) + : VectorizedQuantizedConverter< + c10::qint32, + std::array, 1>, + std::array, 1>, + VECTOR_WIDTH / 4>(ptr) {} +#if 1 + static Vectorized loadu(const void* ptr) { + return Vectorized(ptr); + } + + static Vectorized loadu(const void* ptr, int64_t count) { + __at_align__ value_type tmp_values[size()]; + // Ensure uninitialized memory does not change the output value See https://github.com/pytorch/pytorch/issues/32502 + // for more details. We do not initialize arrays to zero using "={0}" because gcc would compile it to two + // instructions while a loop would be compiled to one instruction. + for (const auto i : c10::irange(size())) { + tmp_values[i] = 0; + } + std::memcpy(tmp_values, reinterpret_cast(ptr), count * sizeof(value_type)); + return loadu(tmp_values); + } +#else + static Vectorized loadu(const void* ptr, int64_t count = size()) { + if (count == size()) + return svld1_s32(ptrue, reinterpret_cast(ptr)); + svbool_t pg = svwhilelt_b32(0ull, count); + return svld1_s32(pg, reinterpret_cast(ptr)); + } +#endif + static Vectorized quantize( + const float_vec_return_type& rhs, + float scale, + int32_t zero_point, + float inverse_scale) { + std::array qvals; + std::array::size()> float_vals; + + for (int i = 0; i < float_num_vecs(); ++i) { + rhs[i].store(&float_vals[i * Vectorized::size()], Vectorized::size()); + } + + at::native::quantize_vec( + scale, + zero_point, + float_vals.data(), + (c10::qint32*)qvals.data(), + Vectorized::size() * float_num_vecs()); + + return Vectorized::loadu(qvals.data()); + } + + Vectorized maximum(Vectorized b) const { + Vectorized retval; + for (size_t i = 0; i < size(); ++i) { + retval.vals[i] = std::max(vals[i], b.vals[i]); + } + return retval; + } + + Vectorized minimum(Vectorized b) const { + Vectorized retval; + for (size_t i = 0; i < size(); ++i) { + retval.vals[i] = std::min(vals[i], b.vals[i]); + } + return retval; + } + + Vectorized relu(Vectorized zero_point) const { + return maximum(zero_point); + } + + + Vectorized relu6( + Vectorized zero_point, + Vectorized q_six) { + Vectorized retval; + for (size_t i = 0; i < size(); ++i) { + retval.vals[i] = std::min( + std::max(vals[i], zero_point.vals[i]), q_six.vals[i]); + } + return retval; + } + + int_vec_return_type widening_subtract(Vectorized b) const { + int_vec_return_type retval; + for (size_t i = 0; i < size(); ++i) { + retval[0].vals[i] = vals[i] - b.vals[i]; + } + return retval; + } + + static Vectorized requantize_from_int( + const int_vec_return_type& inp, + float multiplier, + int32_t zero_point) { + Vectorized retval; + for (size_t i = 0; i < size(); ++i) { + retval.vals[i] = + nearbyint(static_cast(inp[0].vals[i]) * multiplier) + + zero_point; + } + return retval; + } +}; + +template <> +Vectorized inline maximum(const Vectorized& a, const Vectorized& b) { + return a.maximum(b); +} + +template <> +Vectorized inline operator*( + const Vectorized& a, + const Vectorized& b) { + Vectorized retval; + for (size_t i = 0; i < std::decay_t::size(); ++i) { + retval.vals[i] = a.vals[i] * b.vals[i]; + } + return retval; +} + +template <> +Vectorized inline operator+( + const Vectorized& a, + const Vectorized& b) { + Vectorized retval; + for (size_t i = 0; i < std::decay_t::size(); ++i) { + retval.vals[i] = a.vals[i] + b.vals[i]; + } + return retval; +} + +template <> +struct Vectorized : public VectorizedQuantizedConverter< + c10::qint8, + std::array, 4>, + std::array, 4>, + VECTOR_WIDTH> { + Vectorized() + : VectorizedQuantizedConverter< + c10::qint8, + std::array, 4>, + std::array, 4>, + VECTOR_WIDTH>() {} + Vectorized(c10::qint8 val) + : VectorizedQuantizedConverter< + c10::qint8, + std::array, 4>, + std::array, 4>, + VECTOR_WIDTH>(val) {} + Vectorized(const void* ptr) + : VectorizedQuantizedConverter< + c10::qint8, + std::array, 4>, + std::array, 4>, + VECTOR_WIDTH>(ptr) {} + + static Vectorized loadu(const void* ptr) { + return Vectorized(ptr); + } + + static Vectorized loadu(const void* ptr, int64_t count) { + __at_align__ value_type tmp_values[size()]; + // Ensure uninitialized memory does not change the output value See https://github.com/pytorch/pytorch/issues/32502 + // for more details. We do not initialize arrays to zero using "={0}" because gcc would compile it to two + // instructions while a loop would be compiled to one instruction. + for (const auto i : c10::irange(size())) { + tmp_values[i] = 0; + } + std::memcpy(tmp_values, reinterpret_cast(ptr), count * sizeof(value_type)); + return loadu(tmp_values); + } + + static Vectorized quantize( + const float_vec_return_type& rhs, + float scale, + int32_t zero_point, + float inverse_scale) { + std::array qvals; + std::array::size()> float_vals; + + for (int i = 0; i < float_num_vecs(); ++i) { + rhs[i].store(&float_vals[i * Vectorized::size()], Vectorized::size()); + } + + at::native::quantize_vec( + scale, + zero_point, + float_vals.data(), + (c10::qint8*)qvals.data(), + Vectorized::size() * float_num_vecs()); + + return Vectorized::loadu(qvals.data()); + } + + Vectorized maximum(Vectorized b) const { + Vectorized retval; + for (size_t i = 0; i < size(); ++i) { + retval.vals[i] = std::max(vals[i], b.vals[i]); + } + return retval; + } + + Vectorized minimum(Vectorized b) const { + Vectorized retval; + for (size_t i = 0; i < size(); ++i) { + retval.vals[i] = std::min(vals[i], b.vals[i]); + } + return retval; + } + + Vectorized relu(Vectorized zero_point) const { + return maximum(zero_point); + } + + Vectorized relu6( + Vectorized zero_point, + Vectorized q_six) { + Vectorized retval; + for (size_t i = 0; i < size(); ++i) { + retval.vals[i] = std::min( + std::max(vals[i], zero_point.vals[i]), q_six.vals[i]); + } + return retval; + } + + int_vec_return_type widening_subtract(Vectorized b) const { + int_vec_return_type retval; + constexpr int elem_per_int_vec = size() / int_num_vecs(); + for (size_t i = 0; i < int_num_vecs(); ++i) { + for (size_t j = 0; j < elem_per_int_vec; ++j) { + retval[i].vals[j] = + static_cast(vals[i * elem_per_int_vec + j]) - + static_cast(b.vals[i * elem_per_int_vec + j]); + } + } + return retval; + } + static Vectorized requantize_from_int( + const int_vec_return_type& inp, + float multiplier, + int32_t zero_point) { + constexpr int elem_per_int_vec = size() / int_num_vecs(); + constexpr auto min_val = std::numeric_limits::min(); + constexpr auto max_val = std::numeric_limits::max(); + Vectorized retval; + for (size_t i = 0; i < int_num_vecs(); ++i) { + for (size_t j = 0; j < elem_per_int_vec; ++j) { + int32_t rounded = + nearbyint(static_cast(inp[i].vals[j]) * multiplier) + + zero_point; + retval.vals[i * elem_per_int_vec + j] = + std::min(std::max(rounded, min_val), max_val); + } + } + return retval; + } +}; + +template <> +Vectorized inline maximum(const Vectorized& a, const Vectorized& b) { + return a.maximum(b); +} + +template <> +struct Vectorized : public VectorizedQuantizedConverter< + c10::quint8, + std::array, 4>, + std::array, 4>, + VECTOR_WIDTH> { + Vectorized() + : VectorizedQuantizedConverter< + c10::quint8, + std::array, 4>, + std::array, 4>, + VECTOR_WIDTH>() {} + Vectorized(c10::quint8 val) + : VectorizedQuantizedConverter< + c10::quint8, + std::array, 4>, + std::array, 4>, + VECTOR_WIDTH>(val) {} + Vectorized(const void* ptr) + : VectorizedQuantizedConverter< + c10::quint8, + std::array, 4>, + std::array, 4>, + VECTOR_WIDTH>(ptr) {} +#if 1 + static Vectorized loadu(const void* ptr) { + return Vectorized(ptr); + } + + static Vectorized loadu(const void* ptr, int64_t count) { + __at_align__ value_type tmp_values[size()]; + // Ensure uninitialized memory does not change the output value See https://github.com/pytorch/pytorch/issues/32502 + // for more details. We do not initialize arrays to zero using "={0}" because gcc would compile it to two + // instructions while a loop would be compiled to one instruction. + for (const auto i : c10::irange(size())) { + tmp_values[i] = 0; + } + std::memcpy(tmp_values, reinterpret_cast(ptr), count * sizeof(value_type)); + return loadu(tmp_values); + } +#else + static Vectorized loadu(const void* ptr, int64_t count = size()) { + if (count == size()) + return svld1_u8(ptrue, reinterpret_cast(ptr)); + svbool_t pg = svwhilelt_b8(0ull, count); + return svld1_u8(pg, reinterpret_cast(ptr)); + } +#endif + static Vectorized quantize( + const float_vec_return_type& rhs, + float scale, + int32_t zero_point, + float inverse_scale) { + std::array qvals; + std::array::size()> float_vals; + + for (int i = 0; i < float_num_vecs(); ++i) { + rhs[i].store(&float_vals[i * Vectorized::size()], Vectorized::size()); + } + + at::native::quantize_vec( + scale, + zero_point, + float_vals.data(), + (c10::quint8*)qvals.data(), + Vectorized::size() * float_num_vecs()); + + return Vectorized::loadu(qvals.data()); + } + + Vectorized maximum(Vectorized b) const { + Vectorized retval; + for (size_t i = 0; i < size(); ++i) { + retval.vals[i] = std::max(vals[i], b.vals[i]); + } + return retval; + } + + Vectorized minimum(Vectorized b) const { + Vectorized retval; + for (size_t i = 0; i < size(); ++i) { + retval.vals[i] = std::min(vals[i], b.vals[i]); + } + return retval; + } + + Vectorized relu(Vectorized zero_point) const { + return maximum(zero_point); + } + + + Vectorized relu6( + Vectorized zero_point, + Vectorized q_six) { + Vectorized retval; + for (size_t i = 0; i < size(); ++i) { + retval.vals[i] = std::min( + std::max(vals[i], zero_point.vals[i]), q_six.vals[i]); + } + return retval; + } + + int_vec_return_type widening_subtract(Vectorized b) const { + int_vec_return_type retval; + constexpr int elem_per_int_vec = size() / int_num_vecs(); + for (size_t i = 0; i < int_num_vecs(); ++i) { + for (size_t j = 0; j < elem_per_int_vec; ++j) { + retval[i].vals[j] = + static_cast(vals[i * elem_per_int_vec + j]) - + static_cast(b.vals[i * elem_per_int_vec + j]); + } + } + return retval; + } + static Vectorized requantize_from_int( + const int_vec_return_type& inp, + float multiplier, + int32_t zero_point) { + constexpr int elem_per_int_vec = size() / int_num_vecs(); + constexpr auto min_val = std::numeric_limits::min(); + constexpr auto max_val = std::numeric_limits::max(); + Vectorized retval; + for (size_t i = 0; i < int_num_vecs(); ++i) { + for (size_t j = 0; j < elem_per_int_vec; ++j) { + int32_t rounded = + nearbyint(static_cast(inp[i].vals[j]) * multiplier) + + zero_point; + retval.vals[i * elem_per_int_vec + j] = + std::min(std::max(rounded, min_val), max_val); + } + } + return retval; + } +}; + +template <> +Vectorized inline maximum(const Vectorized& a, const Vectorized& b) { + return a.maximum(b); +} + +#endif // defined(CPU_CAPABILITY_SVE) + +}}} diff --git a/aten/src/ATen/cpu/vec/vec256/vec256.h b/aten/src/ATen/cpu/vec/vec256/vec256.h index da38b9d26d2188..68367b81bd8a03 100644 --- a/aten/src/ATen/cpu/vec/vec256/vec256.h +++ b/aten/src/ATen/cpu/vec/vec256/vec256.h @@ -7,9 +7,13 @@ #include #if !(defined(__VSX__) || defined(CPU_CAPABILITY_VSX) || defined(CPU_CAPABILITY_ZVECTOR)) -#include +#if defined(CPU_CAPABILITY_SVE256) +#include +#else #include #include +#endif +#include #include #include #include diff --git a/aten/src/ATen/cpu/vec/vec256/vec256_bfloat16.h b/aten/src/ATen/cpu/vec/vec256/vec256_bfloat16.h index e567c1925be840..12c11abb748dea 100644 --- a/aten/src/ATen/cpu/vec/vec256/vec256_bfloat16.h +++ b/aten/src/ATen/cpu/vec/vec256/vec256_bfloat16.h @@ -1097,7 +1097,7 @@ inline Vectorized convert_float_##name(const Vectorized& a, const V return Vectorized::loadu(arr2); \ } CONVERT_NON_VECTORIZED_INIT(BFloat16, bfloat16); -#if defined(__aarch64__) && !defined(C10_MOBILE) && !defined(__CUDACC__) +#if defined(__aarch64__) && !defined(C10_MOBILE) && !defined(__CUDACC__) && !defined(CPU_CAPABILITY_SVE256) inline std::tuple, Vectorized> convert_half_float(const Vectorized& a) { static_assert(Vectorized::size() == 2 * Vectorized::size()); #if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) diff --git a/aten/src/ATen/cpu/vec/vec256/vec256_qint.h b/aten/src/ATen/cpu/vec/vec256/vec256_qint.h index 8659fbb20f9665..3430e654d7f1f8 100644 --- a/aten/src/ATen/cpu/vec/vec256/vec256_qint.h +++ b/aten/src/ATen/cpu/vec/vec256/vec256_qint.h @@ -843,7 +843,7 @@ Vectorized inline maximum(const Vectorized& a, const V return a.maximum(b); } -#else +#elif !defined(CPU_CAPABILITY_SVE256) // NOTE: These are low-performance implementations that we fall back on // if we are not building with AVX2. This may not be an issue, because diff --git a/aten/src/ATen/cpu/vec/vec_base.h b/aten/src/ATen/cpu/vec/vec_base.h index c4d8b1ccf5de49..ba7865cb522f26 100644 --- a/aten/src/ATen/cpu/vec/vec_base.h +++ b/aten/src/ATen/cpu/vec/vec_base.h @@ -990,7 +990,7 @@ inline mask_gather(const Vectorized& src, T const* base_addr, buffer[i] = src_arr[i]; } } - mask = Vectorized(); // "zero out" mask + mask = Vectorized(static_cast(0)); // "zero out" mask return Vectorized::loadu(static_cast(buffer)); } diff --git a/aten/src/ATen/native/BatchLinearAlgebraKernel.cpp b/aten/src/ATen/native/BatchLinearAlgebraKernel.cpp index 5c3ce52df87d44..d61c9870f4c522 100644 --- a/aten/src/ATen/native/BatchLinearAlgebraKernel.cpp +++ b/aten/src/ATen/native/BatchLinearAlgebraKernel.cpp @@ -1140,87 +1140,103 @@ REGISTER_AVX512_DISPATCH(cholesky_stub, &cholesky_kernel); REGISTER_AVX2_DISPATCH(cholesky_stub, &cholesky_kernel); REGISTER_VSX_DISPATCH(cholesky_stub, &cholesky_kernel); REGISTER_ZVECTOR_DISPATCH(cholesky_stub, &cholesky_kernel); +REGISTER_SVE256_DISPATCH(cholesky_stub, &cholesky_kernel); REGISTER_ARCH_DISPATCH(cholesky_inverse_stub, DEFAULT, &cholesky_inverse_kernel_impl); REGISTER_AVX512_DISPATCH(cholesky_inverse_stub, &cholesky_inverse_kernel_impl); REGISTER_AVX2_DISPATCH(cholesky_inverse_stub, &cholesky_inverse_kernel_impl); REGISTER_VSX_DISPATCH(cholesky_inverse_stub, &cholesky_inverse_kernel_impl); REGISTER_ZVECTOR_DISPATCH(cholesky_inverse_stub, &cholesky_inverse_kernel_impl); +REGISTER_SVE256_DISPATCH(cholesky_inverse_stub, &cholesky_inverse_kernel_impl); REGISTER_ARCH_DISPATCH(linalg_eig_stub, DEFAULT, &linalg_eig_kernel); REGISTER_AVX512_DISPATCH(linalg_eig_stub, &linalg_eig_kernel); REGISTER_AVX2_DISPATCH(linalg_eig_stub, &linalg_eig_kernel); REGISTER_VSX_DISPATCH(linalg_eig_stub, &linalg_eig_kernel); REGISTER_ZVECTOR_DISPATCH(linalg_eig_stub, &linalg_eig_kernel); +REGISTER_SVE256_DISPATCH(linalg_eig_stub, &linalg_eig_kernel); REGISTER_ARCH_DISPATCH(linalg_eigh_stub, DEFAULT, &linalg_eigh_kernel); REGISTER_AVX512_DISPATCH(linalg_eigh_stub, &linalg_eigh_kernel); REGISTER_AVX2_DISPATCH(linalg_eigh_stub, &linalg_eigh_kernel); REGISTER_VSX_DISPATCH(linalg_eigh_stub, &linalg_eigh_kernel); REGISTER_ZVECTOR_DISPATCH(linalg_eigh_stub, &linalg_eigh_kernel); +REGISTER_SVE256_DISPATCH(linalg_eigh_stub, &linalg_eigh_kernel); REGISTER_ARCH_DISPATCH(geqrf_stub, DEFAULT, &geqrf_kernel); REGISTER_AVX512_DISPATCH(geqrf_stub, &geqrf_kernel); REGISTER_AVX2_DISPATCH(geqrf_stub, &geqrf_kernel); REGISTER_VSX_DISPATCH(geqrf_stub, &geqrf_kernel); REGISTER_ZVECTOR_DISPATCH(geqrf_stub, &geqrf_kernel); +REGISTER_SVE256_DISPATCH(geqrf_stub, &geqrf_kernel); REGISTER_ARCH_DISPATCH(orgqr_stub, DEFAULT, &orgqr_kernel_impl); REGISTER_AVX512_DISPATCH(orgqr_stub, &orgqr_kernel_impl); REGISTER_AVX2_DISPATCH(orgqr_stub, &orgqr_kernel_impl); REGISTER_VSX_DISPATCH(orgqr_stub, &orgqr_kernel_impl); REGISTER_ZVECTOR_DISPATCH(orgqr_stub, &orgqr_kernel_impl); +REGISTER_SVE256_DISPATCH(orgqr_stub, &orgqr_kernel_impl); REGISTER_ARCH_DISPATCH(ormqr_stub, DEFAULT, &ormqr_kernel); REGISTER_AVX512_DISPATCH(ormqr_stub, &ormqr_kernel); REGISTER_AVX2_DISPATCH(ormqr_stub, &ormqr_kernel); REGISTER_VSX_DISPATCH(ormqr_stub, &ormqr_kernel); REGISTER_ZVECTOR_DISPATCH(ormqr_stub, &ormqr_kernel); +REGISTER_SVE256_DISPATCH(ormqr_stub, &ormqr_kernel); REGISTER_ARCH_DISPATCH(lstsq_stub, DEFAULT, &lstsq_kernel); REGISTER_AVX512_DISPATCH(lstsq_stub, &lstsq_kernel); REGISTER_AVX2_DISPATCH(lstsq_stub, &lstsq_kernel); REGISTER_VSX_DISPATCH(lstsq_stub, &lstsq_kernel); REGISTER_ZVECTOR_DISPATCH(lstsq_stub, &lstsq_kernel); +REGISTER_SVE256_DISPATCH(lstsq_stub, &lstsq_kernel); REGISTER_ARCH_DISPATCH(triangular_solve_stub, DEFAULT, &triangular_solve_kernel); REGISTER_AVX512_DISPATCH(triangular_solve_stub, &triangular_solve_kernel); REGISTER_AVX2_DISPATCH(triangular_solve_stub, &triangular_solve_kernel); REGISTER_VSX_DISPATCH(triangular_solve_stub, &triangular_solve_kernel); REGISTER_ZVECTOR_DISPATCH(triangular_solve_stub, &triangular_solve_kernel); +REGISTER_SVE256_DISPATCH(triangular_solve_stub, &triangular_solve_kernel); REGISTER_ARCH_DISPATCH(lu_factor_stub, DEFAULT, &lu_factor_kernel); REGISTER_AVX512_DISPATCH(lu_factor_stub, &lu_factor_kernel); REGISTER_AVX2_DISPATCH(lu_factor_stub, &lu_factor_kernel); REGISTER_VSX_DISPATCH(lu_factor_stub, &lu_factor_kernel); REGISTER_ZVECTOR_DISPATCH(lu_factor_stub, &lu_factor_kernel); +REGISTER_SVE256_DISPATCH(lu_factor_stub, &lu_factor_kernel); REGISTER_ARCH_DISPATCH(ldl_factor_stub, DEFAULT, &ldl_factor_kernel); REGISTER_AVX512_DISPATCH(ldl_factor_stub, &ldl_factor_kernel); REGISTER_AVX2_DISPATCH(ldl_factor_stub, &ldl_factor_kernel); REGISTER_VSX_DISPATCH(ldl_factor_stub, &ldl_factor_kernel); REGISTER_ZVECTOR_DISPATCH(ldl_factor_stub, &ldl_factor_kernel); +REGISTER_SVE256_DISPATCH(ldl_factor_stub, &ldl_factor_kernel); REGISTER_ARCH_DISPATCH(ldl_solve_stub, DEFAULT, &ldl_solve_kernel); REGISTER_AVX512_DISPATCH(ldl_solve_stub, &ldl_solve_kernel); REGISTER_AVX2_DISPATCH(ldl_solve_stub, &ldl_solve_kernel); REGISTER_VSX_DISPATCH(ldl_solve_stub, &ldl_solve_kernel); REGISTER_ZVECTOR_DISPATCH(ldl_solve_stub, &ldl_solve_kernel); +REGISTER_SVE256_DISPATCH(ldl_solve_stub, &ldl_solve_kernel); + REGISTER_ARCH_DISPATCH(lu_solve_stub, DEFAULT, &lu_solve_kernel); REGISTER_AVX512_DISPATCH(lu_solve_stub, &lu_solve_kernel); REGISTER_AVX2_DISPATCH(lu_solve_stub, &lu_solve_kernel); REGISTER_VSX_DISPATCH(lu_solve_stub, &lu_solve_kernel); REGISTER_ZVECTOR_DISPATCH(lu_solve_stub, &lu_solve_kernel); +REGISTER_SVE256_DISPATCH(lu_solve_stub, &lu_solve_kernel); REGISTER_ARCH_DISPATCH(svd_stub, DEFAULT, &svd_kernel); REGISTER_AVX512_DISPATCH(svd_stub, &svd_kernel); REGISTER_AVX2_DISPATCH(svd_stub, &svd_kernel); REGISTER_VSX_DISPATCH(svd_stub, &svd_kernel); REGISTER_ZVECTOR_DISPATCH(svd_stub, &svd_kernel); +REGISTER_SVE256_DISPATCH(svd_stub, &svd_kernel); REGISTER_ARCH_DISPATCH(unpack_pivots_stub, DEFAULT, &unpack_pivots_cpu_kernel); REGISTER_AVX512_DISPATCH(unpack_pivots_stub, &unpack_pivots_cpu_kernel); REGISTER_AVX2_DISPATCH(unpack_pivots_stub, &unpack_pivots_cpu_kernel); REGISTER_VSX_DISPATCH(unpack_pivots_stub, &unpack_pivots_cpu_kernel); REGISTER_ZVECTOR_DISPATCH(unpack_pivots_stub, &unpack_pivots_cpu_kernel); +REGISTER_SVE256_DISPATCH(unpack_pivots_stub, &unpack_pivots_cpu_kernel); } // namespace at::native diff --git a/aten/src/ATen/native/DispatchStub.cpp b/aten/src/ATen/native/DispatchStub.cpp index 699f2e0fe2e27c..b57649c2632592 100644 --- a/aten/src/ATen/native/DispatchStub.cpp +++ b/aten/src/ATen/native/DispatchStub.cpp @@ -34,6 +34,17 @@ static CPUCapability compute_cpu_capability() { if (strcmp(envar, "zvector") == 0) { return CPUCapability::ZVECTOR; } +#elif defined(HAVE_SVE_CPU_DEFINITION) + int sve_vl = cpuinfo_get_max_arm_sve_length(); //Returns maximum SVE VL supported by your HW. +#ifdef HAVE_SVE256_CPU_DEFINITION + if (strcmp(envar, "sve256") == 0) { + if (sve_vl == 256) { + return CPUCapability::SVE256; + } + TORCH_WARN("SVE256 capability not available on hardware. Falling back to DEFAULT"); + return CPUCapability::DEFAULT; + } +#endif #else #ifdef HAVE_AVX512_CPU_DEFINITION if (strcmp(envar, "avx512") == 0) { @@ -52,7 +63,7 @@ static CPUCapability compute_cpu_capability() { TORCH_WARN("ignoring invalid value for ATEN_CPU_CAPABILITY: ", envar); } -#if !defined(__powerpc__) && !defined(__s390x__) +#if !defined(__powerpc__) && !defined(__s390x__) && !defined(HAVE_SVE_CPU_DEFINITION) if (cpuinfo_initialize()) { #if defined(HAVE_AVX512_CPU_DEFINITION) // GCC supports some AVX512 intrinsics such as _mm512_set_epi16 only in @@ -79,6 +90,23 @@ static CPUCapability compute_cpu_capability() { } #endif +#if defined(__linux__) && defined(HAVE_SVE_CPU_DEFINITION) + if (cpuinfo_initialize() && cpuinfo_has_arm_sve()) { + int sve_vl = cpuinfo_get_max_arm_sve_length(); //Returns maximum SVE VL supported by your HW. + if (sve_vl <= 0) { + // SVE is not supported on this system. + // Return the default CPU capability. + return CPUCapability::DEFAULT; + } + #ifdef HAVE_SVE256_CPU_DEFINITION + if (sve_vl == 256) { // Check for SVE256 + return CPUCapability::SVE256; + } + #endif + // Return the default CPU capability. + return CPUCapability::DEFAULT; + } +#endif #ifdef HAVE_VSX_CPU_DEFINITION return CPUCapability::VSX; #else @@ -106,6 +134,9 @@ DispatchResult DispatchStubImpl::try_get_call_ptr( #ifdef HAVE_ZVECTOR_CPU_DEFINITION , void *ZVECTOR #endif +#ifdef HAVE_SVE256_CPU_DEFINITION + , void *SVE256 +#endif ) { constexpr auto supported_devices = c10::array_of( c10::DeviceType::CPU, @@ -139,6 +170,9 @@ DispatchResult DispatchStubImpl::try_get_call_ptr( #endif #ifdef HAVE_ZVECTOR_CPU_DEFINITION , ZVECTOR +#endif +#ifdef HAVE_SVE256_CPU_DEFINITION + , SVE256 #endif ); if (!std::holds_alternative(result)) { @@ -191,6 +225,9 @@ void* DispatchStubImpl::get_call_ptr( #ifdef HAVE_ZVECTOR_CPU_DEFINITION , void *ZVECTOR #endif +#ifdef HAVE_SVE256_CPU_DEFINITION + , void *SVE256 +#endif ) { auto result = try_get_call_ptr( @@ -211,6 +248,10 @@ void* DispatchStubImpl::get_call_ptr( #ifdef HAVE_ZVECTOR_CPU_DEFINITION , ZVECTOR +#endif +#ifdef HAVE_SVE256_CPU_DEFINITION + , + SVE256 #endif ); if (std::holds_alternative(result)) { @@ -242,6 +283,9 @@ DispatchResult DispatchStubImpl::try_choose_cpu_impl( #endif #ifdef HAVE_ZVECTOR_CPU_DEFINITION , void *ZVECTOR +#endif +#ifdef HAVE_SVE256_CPU_DEFINITION + , void *SVE256 #endif ){ @@ -274,6 +318,16 @@ DispatchResult DispatchStubImpl::try_choose_cpu_impl( if (capability >= static_cast(CPUCapability::ZVECTOR)) { return ZVECTOR != nullptr ? DispatchResult(ZVECTOR) : ErrorType::MissingDeviceKernel; } +#endif +#ifdef HAVE_SVE256_CPU_DEFINITION + if (capability >= static_cast(CPUCapability::SVE256)) { + if (C10_UNLIKELY(!SVE256)) { + // dispatch to DEFAULT, since the SVE kernel is missing + return DEFAULT != nullptr ? DispatchResult(DEFAULT) : ErrorType::MissingDeviceKernel; + } else { + return DispatchResult(SVE256); + } + } #endif return DEFAULT != nullptr ? DispatchResult(DEFAULT) : ErrorType::MissingDeviceKernel; } @@ -292,6 +346,9 @@ void* DispatchStubImpl::choose_cpu_impl( #ifdef HAVE_ZVECTOR_CPU_DEFINITION , void *ZVECTOR #endif +#ifdef HAVE_SVE256_CPU_DEFINITION + , void *SVE256 +#endif ) { auto capability = static_cast(get_cpu_capability()); (void)capability; @@ -326,6 +383,17 @@ void* DispatchStubImpl::choose_cpu_impl( TORCH_INTERNAL_ASSERT(ZVECTOR, "DispatchStub: missing ZVECTOR kernel"); return ZVECTOR; } +#endif +#ifdef HAVE_SVE256_CPU_DEFINITION + if (capability >= static_cast(CPUCapability::SVE256)) { + if (C10_UNLIKELY(!SVE256)) { + // dispatch to DEFAULT, since the SVE kernel is missing + TORCH_INTERNAL_ASSERT(DEFAULT, "DispatchStub: missing default kernel"); + return DEFAULT; + } else { + return SVE256; + } + } #endif TORCH_INTERNAL_ASSERT(DEFAULT, "DispatchStub: missing default kernel"); return DEFAULT; diff --git a/aten/src/ATen/native/DispatchStub.h b/aten/src/ATen/native/DispatchStub.h index 641b779701d1e2..22a97ca9882b8e 100644 --- a/aten/src/ATen/native/DispatchStub.h +++ b/aten/src/ATen/native/DispatchStub.h @@ -64,6 +64,8 @@ enum class CPUCapability { VSX = 1, #elif defined(HAVE_ZVECTOR_CPU_DEFINITION) ZVECTOR = 1, +#elif defined(HAVE_SVE_CPU_DEFINITION) + SVE256 = 1, #else AVX2 = 1, AVX512 = 2, @@ -112,6 +114,9 @@ struct TORCH_API DispatchStubImpl { #endif #ifdef HAVE_ZVECTOR_CPU_DEFINITION , void *ZVECTOR +#endif +#ifdef HAVE_SVE256_CPU_DEFINITION + , void *SVE256 #endif ); @@ -130,6 +135,9 @@ struct TORCH_API DispatchStubImpl { #endif #ifdef HAVE_ZVECTOR_CPU_DEFINITION , void *ZVECTOR +#endif +#ifdef HAVE_SVE256_CPU_DEFINITION + , void *SVE256 #endif ); @@ -148,6 +156,9 @@ struct TORCH_API DispatchStubImpl { #endif #ifdef HAVE_ZVECTOR_CPU_DEFINITION , void *ZVECTOR +#endif +#ifdef HAVE_SVE256_CPU_DEFINITION + , void *SVE256 #endif ); @@ -169,6 +180,9 @@ struct TORCH_API DispatchStubImpl { #endif #ifdef HAVE_ZVECTOR_CPU_DEFINITION , void *ZVECTOR +#endif +#ifdef HAVE_SVE256_CPU_DEFINITION + , void *SVE256 #endif ); @@ -221,6 +235,9 @@ struct DispatchStub { #endif #ifdef HAVE_ZVECTOR_CPU_DEFINITION , reinterpret_cast(ZVECTOR) +#endif +#ifdef HAVE_SVE256_CPU_DEFINITION + , reinterpret_cast(SVE256) #endif ) ); @@ -275,6 +292,9 @@ struct DispatchStub { #endif #ifdef HAVE_ZVECTOR_CPU_DEFINITION , reinterpret_cast(ZVECTOR) +#endif +#ifdef HAVE_SVE256_CPU_DEFINITION + , reinterpret_cast(SVE256) #endif ); if (std::holds_alternative(result)){ @@ -296,6 +316,9 @@ struct DispatchStub { #ifdef HAVE_ZVECTOR_CPU_DEFINITION static TORCH_API FnPtr ZVECTOR; #endif +#ifdef HAVE_SVE256_CPU_DEFINITION + static TORCH_API FnPtr SVE256; +#endif private: DispatchStubImpl impl; }; @@ -387,6 +410,12 @@ struct RegisterPRIVATEUSE1Dispatch { #define REGISTER_ZVECTOR_DISPATCH(name, fn) #endif +#ifdef HAVE_SVE256_CPU_DEFINITION +#define REGISTER_SVE256_DISPATCH(name, fn) REGISTER_ARCH_DISPATCH(name, SVE256, fn) +#else +#define REGISTER_SVE256_DISPATCH(name, fn) +#endif + // Macro to register the same kernel for all CPU arch types. This is useful // if a kernel does not benefit from being recompiled across different arch types. #define REGISTER_ALL_CPU_DISPATCH(name, fn) \ @@ -394,7 +423,8 @@ struct RegisterPRIVATEUSE1Dispatch { REGISTER_AVX512_DISPATCH(name, fn) \ REGISTER_AVX2_DISPATCH(name, fn) \ REGISTER_VSX_DISPATCH(name, fn) \ - REGISTER_ZVECTOR_DISPATCH(name, fn) + REGISTER_ZVECTOR_DISPATCH(name, fn) \ + REGISTER_SVE256_DISPATCH(name, fn) #define REGISTER_NO_CPU_DISPATCH(name) \ REGISTER_ALL_CPU_DISPATCH(name, nullptr) @@ -432,12 +462,14 @@ struct RegisterPRIVATEUSE1Dispatch { #elif defined(CPU_CAPABILITY) // REGISTER_DISPATCH now dispatches an AVX512 kernel to nullptr but registers other dispatches. // ALSO_REGISTER_AVX512_DISPATCH should be used for ensuring AVX512 dispatch, among others. +// ALSO_REGISTER_SVE256_DISPATCH should be used for ensuring SVE256 dispatch, among others. #ifdef CPU_CAPABILITY_AVX512 #define REGISTER_DISPATCH(name, fn) REGISTER_ARCH_DISPATCH(name, CPU_CAPABILITY, ((void*)(fn) ? nullptr : nullptr)) #else #define REGISTER_DISPATCH(name, fn) REGISTER_ARCH_DISPATCH(name, CPU_CAPABILITY, fn) #endif #define ALSO_REGISTER_AVX512_DISPATCH(name, fn) REGISTER_ARCH_DISPATCH(name, CPU_CAPABILITY, fn) +#define ALSO_REGISTER_SVE256_DISPATCH(name, fn) REGISTER_ARCH_DISPATCH(name, CPU_CAPABILITY, fn) #endif } // namespace at::native diff --git a/aten/src/ATen/native/SegmentReduce.cpp b/aten/src/ATen/native/SegmentReduce.cpp index 0ab01bbe8c0bd2..08220869701407 100644 --- a/aten/src/ATen/native/SegmentReduce.cpp +++ b/aten/src/ATen/native/SegmentReduce.cpp @@ -466,6 +466,7 @@ REGISTER_AVX2_DISPATCH(_segment_reduce_lengths_stub, &_segment_reduce_lengths_cp REGISTER_AVX512_DISPATCH(_segment_reduce_lengths_stub, &_segment_reduce_lengths_cpu_kernel); REGISTER_VSX_DISPATCH(_segment_reduce_lengths_stub, &_segment_reduce_lengths_cpu_kernel); REGISTER_ZVECTOR_DISPATCH(_segment_reduce_lengths_stub, &_segment_reduce_lengths_cpu_kernel); +REGISTER_SVE256_DISPATCH(_segment_reduce_lengths_stub, &_segment_reduce_lengths_cpu_kernel); // offsets dispatches REGISTER_ARCH_DISPATCH( @@ -476,6 +477,7 @@ REGISTER_AVX2_DISPATCH(_segment_reduce_offsets_stub, &_segment_reduce_offsets_cp REGISTER_AVX512_DISPATCH(_segment_reduce_offsets_stub, &_segment_reduce_offsets_cpu_kernel); REGISTER_VSX_DISPATCH(_segment_reduce_offsets_stub, &_segment_reduce_offsets_cpu_kernel); REGISTER_ZVECTOR_DISPATCH(_segment_reduce_offsets_stub, &_segment_reduce_offsets_cpu_kernel); +REGISTER_SVE256_DISPATCH(_segment_reduce_offsets_stub, &_segment_reduce_offsets_cpu_kernel); // Currently some computation is being duplicated across forward and backward. // TODO: Cache indices in forward pass to re-use in backward @@ -546,6 +548,9 @@ REGISTER_VSX_DISPATCH( REGISTER_ZVECTOR_DISPATCH( _segment_reduce_lengths_backward_stub, &_segment_reduce_cpu_lengths_backward_kernel); +REGISTER_SVE256_DISPATCH( + _segment_reduce_lengths_backward_stub, + &_segment_reduce_cpu_lengths_backward_kernel); REGISTER_ARCH_DISPATCH( _segment_reduce_offsets_backward_stub, @@ -563,5 +568,8 @@ REGISTER_VSX_DISPATCH( REGISTER_ZVECTOR_DISPATCH( _segment_reduce_offsets_backward_stub, &_segment_reduce_cpu_offsets_backward_kernel); +REGISTER_SVE256_DISPATCH( + _segment_reduce_offsets_backward_stub, + &_segment_reduce_cpu_offsets_backward_kernel); } // namespace at::native diff --git a/aten/src/ATen/native/cpu/LerpKernel.cpp b/aten/src/ATen/native/cpu/LerpKernel.cpp index 7eaac38c21c8ad..d8b4259775d968 100644 --- a/aten/src/ATen/native/cpu/LerpKernel.cpp +++ b/aten/src/ATen/native/cpu/LerpKernel.cpp @@ -19,7 +19,7 @@ Vectorized is_lerp_weight_small(Vectorized weight) { // is_lerp_weight_small doesn't work for complex because z.abs() returns a // complex vector which can't be compared. Either implement it with z.abs_2_(), // or fallback to the scalar function. -#if !(defined(CPU_CAPABILITY_DEFAULT) || defined(_MSC_VER)) +#if !(defined(CPU_CAPABILITY_DEFAULT) || defined(_MSC_VER) || defined(CPU_CAPABILITY_SVE)) template Vectorized> is_lerp_weight_small(Vectorized> weight) { using vec_reg_t = decltype(weight.abs_2_()); diff --git a/aten/src/ATen/native/mkl/SpectralOps.cpp b/aten/src/ATen/native/mkl/SpectralOps.cpp index e26cfbf6d8ebaa..8ae620ed0028c1 100644 --- a/aten/src/ATen/native/mkl/SpectralOps.cpp +++ b/aten/src/ATen/native/mkl/SpectralOps.cpp @@ -165,6 +165,7 @@ REGISTER_AVX2_DISPATCH(fft_fill_with_conjugate_symmetry_stub, &_fft_fill_with_co REGISTER_AVX512_DISPATCH(fft_fill_with_conjugate_symmetry_stub, &_fft_fill_with_conjugate_symmetry_cpu_) REGISTER_ZVECTOR_DISPATCH(fft_fill_with_conjugate_symmetry_stub, &_fft_fill_with_conjugate_symmetry_cpu_) REGISTER_VSX_DISPATCH(fft_fill_with_conjugate_symmetry_stub, &_fft_fill_with_conjugate_symmetry_cpu_) +REGISTER_SVE256_DISPATCH(fft_fill_with_conjugate_symmetry_stub, &_fft_fill_with_conjugate_symmetry_cpu_) // _out variants can be shared between PocketFFT and MKL Tensor& _fft_r2c_mkl_out(const Tensor& self, IntArrayRef dim, int64_t normalization, diff --git a/aten/src/ATen/native/sparse/FlattenIndicesKernel.cpp b/aten/src/ATen/native/sparse/FlattenIndicesKernel.cpp index 947332635203ba..90d3e9cce6734f 100644 --- a/aten/src/ATen/native/sparse/FlattenIndicesKernel.cpp +++ b/aten/src/ATen/native/sparse/FlattenIndicesKernel.cpp @@ -27,5 +27,6 @@ REGISTER_AVX512_DISPATCH(flatten_indices_stub, &flatten_indices_cpu_kernel); REGISTER_AVX2_DISPATCH(flatten_indices_stub, &flatten_indices_cpu_kernel); REGISTER_VSX_DISPATCH(flatten_indices_stub, &flatten_indices_cpu_kernel); REGISTER_ZVECTOR_DISPATCH(flatten_indices_stub, &flatten_indices_cpu_kernel); +REGISTER_SVE256_DISPATCH(flatten_indices_stub, &flatten_indices_cpu_kernel); } // namespace at::native diff --git a/aten/src/ATen/native/sparse/SparseBinaryOpIntersectionKernel.cpp b/aten/src/ATen/native/sparse/SparseBinaryOpIntersectionKernel.cpp index 6f20b4e245c7a3..e86d5c46a795fa 100644 --- a/aten/src/ATen/native/sparse/SparseBinaryOpIntersectionKernel.cpp +++ b/aten/src/ATen/native/sparse/SparseBinaryOpIntersectionKernel.cpp @@ -161,16 +161,19 @@ REGISTER_AVX512_DISPATCH(mul_sparse_sparse_out_stub, &mul_sparse_sparse_out_cpu_ REGISTER_AVX2_DISPATCH(mul_sparse_sparse_out_stub, &mul_sparse_sparse_out_cpu_kernel); REGISTER_VSX_DISPATCH(mul_sparse_sparse_out_stub, &mul_sparse_sparse_out_cpu_kernel); REGISTER_ZVECTOR_DISPATCH(mul_sparse_sparse_out_stub, &mul_sparse_sparse_out_cpu_kernel); +REGISTER_SVE256_DISPATCH(mul_sparse_sparse_out_stub, &mul_sparse_sparse_out_cpu_kernel); REGISTER_ARCH_DISPATCH(sparse_mask_intersection_out_stub, DEFAULT, &sparse_mask_intersection_out_cpu_kernel); REGISTER_AVX512_DISPATCH(sparse_mask_intersection_out_stub, &sparse_mask_intersection_out_cpu_kernel); REGISTER_AVX2_DISPATCH(sparse_mask_intersection_out_stub, &sparse_mask_intersection_out_cpu_kernel); REGISTER_VSX_DISPATCH(sparse_mask_intersection_out_stub, &sparse_mask_intersection_out_cpu_kernel); REGISTER_ZVECTOR_DISPATCH(sparse_mask_intersection_out_stub, &sparse_mask_intersection_out_cpu_kernel); +REGISTER_SVE256_DISPATCH(sparse_mask_intersection_out_stub, &sparse_mask_intersection_out_cpu_kernel); REGISTER_ARCH_DISPATCH(sparse_mask_projection_out_stub, DEFAULT, &sparse_mask_projection_out_cpu_kernel); REGISTER_AVX512_DISPATCH(sparse_mask_projection_out_stub, &sparse_mask_projection_out_cpu_kernel); REGISTER_AVX2_DISPATCH(sparse_mask_projection_out_stub, &sparse_mask_projection_out_cpu_kernel); REGISTER_VSX_DISPATCH(sparse_mask_projection_out_stub, &sparse_mask_projection_out_cpu_kernel); REGISTER_ZVECTOR_DISPATCH(sparse_mask_projection_out_stub, &sparse_mask_projection_out_cpu_kernel); +REGISTER_SVE256_DISPATCH(sparse_mask_projection_out_stub, &sparse_mask_projection_out_cpu_kernel); } diff --git a/aten/src/ATen/native/transformers/attention.cpp b/aten/src/ATen/native/transformers/attention.cpp index c1054ccce07e48..d91955412fc588 100644 --- a/aten/src/ATen/native/transformers/attention.cpp +++ b/aten/src/ATen/native/transformers/attention.cpp @@ -449,6 +449,7 @@ REGISTER_AVX2_DISPATCH(_fused_sdp_choice_stub, &_fused_sdp_choice_cpp); REGISTER_AVX512_DISPATCH(_fused_sdp_choice_stub, &_fused_sdp_choice_cpp); REGISTER_VSX_DISPATCH(_fused_sdp_choice_stub, &_fused_sdp_choice_cpp); REGISTER_ZVECTOR_DISPATCH(_fused_sdp_choice_stub, &_fused_sdp_choice_cpp); +REGISTER_SVE256_DISPATCH(_fused_sdp_choice_stub, &_fused_sdp_choice_cpp); int64_t _fused_sdp_choice_meta( const Tensor& query_, diff --git a/aten/src/ATen/test/vec_test_all_types.cpp b/aten/src/ATen/test/vec_test_all_types.cpp index de1f70e97ca5a8..a2c8da12c446bf 100644 --- a/aten/src/ATen/test/vec_test_all_types.cpp +++ b/aten/src/ATen/test/vec_test_all_types.cpp @@ -992,6 +992,9 @@ namespace { blend_init(a, b); test_blendv(expected_val, a, b, mask); } +// NOTE: In this test, blend is not required to implement SVE Vectorized::set. +// so, this test is disabled for SVE. +#if !defined(CPU_CAPABILITY_SVE) TYPED_TEST(BitwiseFloatsAdditional2, Blend) { using vec = TypeParam; using VT = ValueType; @@ -1005,6 +1008,7 @@ namespace { constexpr int64_t power_sets = 1LL << (vec::size()); test_blend(expected_val, a, b); } +#endif template // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays) void test_set(VT expected_val[vec::size()], VT a[vec::size()], VT b[vec::size()], int64_t count){ @@ -1606,6 +1610,7 @@ namespace { ASSERT_TRUE(vec_pinf.has_inf_nan()) << "Test failed for positive Infinity\n"; ASSERT_TRUE(vec_ninf.has_inf_nan()) << "Test failed for negative Infinity\n"; } +#if !defined(CPU_CAPABILITY_SVE) TYPED_TEST(VecConvertTests, Convert) { using vec = TypeParam; using src_t = ValueType; @@ -1658,6 +1663,7 @@ namespace { TEST_CONVERT_TO(double); #undef TEST_CONVERT_TO } +#endif TYPED_TEST(VecMaskTests, MaskedLoad) { using vec = TypeParam; using src_t = ValueType; @@ -1716,6 +1722,7 @@ namespace { #undef TEST_MASK_LOAD #undef TEST_MASK_LOAD_N } +#if !defined(CPU_CAPABILITY_SVE) TYPED_TEST(VecMaskTests, MaskedCheck) { using VT = ValueType; using vec = TypeParam; @@ -1739,6 +1746,8 @@ namespace { #undef TEST_MASK_CHECK_N } +#endif +#if !defined(CPU_CAPABILITY_SVE) TYPED_TEST(VecMaskTests, ToFrom) { using vec = TypeParam; using VT = ValueType; @@ -1764,6 +1773,8 @@ namespace { << "Failure Details:\nTest Seed to reproduce: " << seed; } } +#endif +#if !defined(CPU_CAPABILITY_SVE) TYPED_TEST(VecMaskTests, Cast) { using vec = TypeParam; using src_t = ValueType; @@ -1808,6 +1819,7 @@ namespace { #undef TEST_MASK_CAST #undef TEST_MASK_CAST_N } +#endif #else #error GTEST does not have TYPED_TEST #endif diff --git a/aten/src/ATen/test/vec_test_all_types.h b/aten/src/ATen/test/vec_test_all_types.h index 98fda78885aecf..9215e9ff393f38 100644 --- a/aten/src/ATen/test/vec_test_all_types.h +++ b/aten/src/ATen/test/vec_test_all_types.h @@ -53,6 +53,9 @@ CACHE_ALIGN #define defined(CPU_CAPABILITY_AVX512) && (defined(__GNUC__) || defined(__GNUG__)) #undef CHECK_DEQUANT_WITH_LOW_PRECISION #define CHECK_WITH_FMA 1 +#elif defined(CPU_CAPABILITY_SVE) +#define CHECK_DEQUANT_WITH_LOW_PRECISION 1 +#define CHECK_WITH_FMA 1 #elif !defined(CPU_CAPABILITY_VSX) && !defined(CPU_CAPABILITY_AVX2) #undef CHECK_DEQUANT_WITH_LOW_PRECISION #undef CHECK_WITH_FMA diff --git a/cmake/Codegen.cmake b/cmake/Codegen.cmake index d221f5844f30d4..5e383d9715298d 100644 --- a/cmake/Codegen.cmake +++ b/cmake/Codegen.cmake @@ -322,6 +322,18 @@ if(INTERN_BUILD_ATEN_OPS) LIST(APPEND CPU_CAPABILITY_FLAGS "${OPT_FLAG} ${CXX_ZVECTOR_FLAGS}") endif(CXX_ZVECTOR_FOUND) + if(CXX_SVE_FOUND) + if(CXX_SVE256_FOUND) + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DHAVE_SVE_CPU_DEFINITION -DHAVE_SVE256_CPU_DEFINITION") + list(APPEND CPU_CAPABILITY_NAMES "SVE256") + if("${CMAKE_C_COMPILER_ID}" MATCHES "Clang") + list(APPEND CPU_CAPABILITY_FLAGS "${OPT_FLAG} -O2 -march=armv8.2-a+sve -DCPU_CAPABILITY_SVE -msve-vector-bits=256") + else() + list(APPEND CPU_CAPABILITY_FLAGS "${OPT_FLAG} -march=armv8.2-a+sve -DCPU_CAPABILITY_SVE -msve-vector-bits=256") + endif() + endif(CXX_SVE256_FOUND) + endif(CXX_SVE_FOUND) + list(LENGTH CPU_CAPABILITY_NAMES NUM_CPU_CAPABILITY_NAMES) math(EXPR NUM_CPU_CAPABILITY_NAMES "${NUM_CPU_CAPABILITY_NAMES}-1") diff --git a/cmake/Modules/FindARM.cmake b/cmake/Modules/FindARM.cmake index c7985d0111f024..340c2a64e0356e 100644 --- a/cmake/Modules/FindARM.cmake +++ b/cmake/Modules/FindARM.cmake @@ -22,6 +22,15 @@ IF(CMAKE_SYSTEM_NAME MATCHES "Linux") set(ASIMD_FOUND false CACHE BOOL "ASIMD/NEON available on host") ENDIF (ASIMD_TRUE) + #sve instruction can be found on the majority part of modern ARM processor + STRING(REGEX REPLACE "^.*(sve).*$" "\\1" SVE_THERE ${CPUINFO}) + STRING(COMPARE EQUAL "sve" "${SVE_THERE}" SVE_TRUE) + IF (SVE_TRUE) + set(SVE_FOUND true CACHE BOOL "SVE available on host") + ELSE (SVE_TRUE) + set(SVE_FOUND false CACHE BOOL "SVE available on host") + ENDIF (SVE_TRUE) + #Find the processor type (for now OMAP3 or OMAP4) STRING(REGEX REPLACE "^.*(OMAP3).*$" "\\1" OMAP3_THERE "${CPUINFO}") STRING(COMPARE EQUAL "OMAP3" "${OMAP3_THERE}" OMAP3_TRUE) @@ -79,3 +88,72 @@ if(NOT CORTEXA9_FOUND) MESSAGE(STATUS "No OMAP4 processor on this machine.") endif(NOT CORTEXA9_FOUND) mark_as_advanced(NEON_FOUND) + +#SVE support is availale is only for Linux OS. +IF(CMAKE_SYSTEM_NAME MATCHES "Linux") + # Include necessary modules for checking C and C++ source compilations + INCLUDE(CheckCSourceCompiles) + INCLUDE(CheckCXXSourceCompiles) + + # Test code for SVE support + SET(SVE_CODE " + #include + int main() + { + svfloat64_t a; + a = svdup_n_f64(0); + return 0; + } + ") + + # Macro to check for SVE instruction support + MACRO(CHECK_SVE lang type flags) + # Save the current state of required flags + SET(CMAKE_REQUIRED_FLAGS_SAVE ${CMAKE_REQUIRED_FLAGS}) + + # Set the flags necessary for compiling the test code with SVE support + SET(CMAKE_REQUIRED_FLAGS "${CMAKE_${lang}_FLAGS_INIT} ${flags}") + + # Check if the source code compiles with the given flags for the specified language (C or C++) + IF(lang STREQUAL "CXX") + CHECK_CXX_SOURCE_COMPILES("${SVE_CODE}" ${lang}_HAS_${type}) + ELSE() + CHECK_C_SOURCE_COMPILES("${SVE_CODE}" ${lang}_HAS_${type}) + ENDIF() + + # If the compilation test is successful, set appropriate variables indicating support + IF(${lang}_HAS_${type}) + set(${lang}_SVE_FOUND TRUE CACHE BOOL "SVE available on host") + SET(${lang}_${type}_FOUND TRUE CACHE BOOL "${lang} ${type} support") + SET(${lang}_${type}_FLAGS "${flags}" CACHE STRING "${lang} ${type} flags") + ENDIF() + + # Restore the original state of required flags + SET(CMAKE_REQUIRED_FLAGS ${CMAKE_REQUIRED_FLAGS_SAVE}) + + # If the compilation test fails, indicate that the support is not found + IF(NOT ${lang}_${type}_FOUND) + SET(${lang}_${type}_FOUND FALSE CACHE BOOL "${lang} ${type} support") + SET(${lang}_${type}_FLAGS "" CACHE STRING "${lang} ${type} flags") + ENDIF() + + # Mark the variables as advanced to hide them in the default CMake GUI + MARK_AS_ADVANCED(${lang}_${type}_FOUND ${lang}_${type}_FLAGS) + ENDMACRO() + + # Check for SVE256 vector length + CHECK_SVE(CXX "SVE256" "-march=armv8-a+sve -msve-vector-bits=256") + + # If SVE256 support is not found, set CXX_SVE_FOUND to FALSE and notify the user + if(NOT CXX_SVE256_FOUND) + set(CXX_SVE_FOUND FALSE CACHE BOOL "SVE not available on host") + message(STATUS "No SVE processor on this machine.") + else() + # If SVE256 support is found, set CXX_SVE_FOUND to TRUE and notify the user + set(CXX_SVE_FOUND TRUE CACHE BOOL "SVE available on host") + message(STATUS "SVE support detected.") + endif() + + # Mark the SVE support variable as advanced + mark_as_advanced(CXX_SVE_FOUND) +ENDIF(CMAKE_SYSTEM_NAME MATCHES "Linux") diff --git a/setup.py b/setup.py index b2cb4cd5926172..e9f5d2a579432c 100644 --- a/setup.py +++ b/setup.py @@ -1244,6 +1244,7 @@ def main(): "include/ATen/cpu/vec/vec256/zarch/*.h", "include/ATen/cpu/vec/vec512/*.h", "include/ATen/cpu/vec/*.h", + "include/ATen/cpu/vec/sve/*.h", "include/ATen/core/*.h", "include/ATen/cuda/*.cuh", "include/ATen/cuda/*.h", diff --git a/torch/backends/cpu/__init__.py b/torch/backends/cpu/__init__.py index cc8a174626bf17..82dc52cd4904c1 100644 --- a/torch/backends/cpu/__init__.py +++ b/torch/backends/cpu/__init__.py @@ -16,5 +16,6 @@ def get_cpu_capability() -> str: - "NO AVX" - "AVX2" - "AVX512" + - "SVE256" """ return torch._C._get_cpu_capability() From df249edbe5867309fedac3b695087a23fc4f51c1 Mon Sep 17 00:00:00 2001 From: eqy Date: Wed, 18 Sep 2024 19:42:10 +0000 Subject: [PATCH 0993/1018] [cuDNN][conv][A100] Bump tolerances for `vmap_autograd_grad` `conv2d` on A100 (#136178) Likely due to a cuDNN heuristics update Pull Request resolved: https://github.com/pytorch/pytorch/pull/136178 Approved by: https://github.com/Skylion007 --- test/functorch/test_ops.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/functorch/test_ops.py b/test/functorch/test_ops.py index 5cd18a072d8daf..93e8f23d1ea40e 100644 --- a/test/functorch/test_ops.py +++ b/test/functorch/test_ops.py @@ -2404,7 +2404,7 @@ def fn(input, weight, bias): tol1("nn.functional.conv3d", {torch.float32: tol(atol=5e-04, rtol=9e-03)}), tol1( "nn.functional.conv2d", - {torch.float32: tol(atol=3e-05, rtol=5e-06)}, + {torch.float32: tol(atol=5e-05, rtol=5e-05)}, device_type="cuda", ), tol1("svd_lowrank", {torch.float32: tol(atol=5e-05, rtol=5e-05)}), From 4814341438778928e29bdb4c60b832fb3361789d Mon Sep 17 00:00:00 2001 From: Atul Jangra Date: Wed, 18 Sep 2024 20:14:10 +0000 Subject: [PATCH 0994/1018] Add wait counter for nccl abort (#136067) Summary: Quite a few times, we see the NCCL PG abort taking too long. There's no easy way to measure this, so let's add a counter to measure this across the stack. This will help us measure how much time we take the NCCL abort. Test Plan: Unit tests Reviewed By: c-p-i-o Differential Revision: D62675010 Pull Request resolved: https://github.com/pytorch/pytorch/pull/136067 Approved by: https://github.com/fduwjj --- torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp | 2 ++ 1 file changed, 2 insertions(+) diff --git a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp index a75c989d4d56a3..8a7aefdc238c4e 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp +++ b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp @@ -1202,6 +1202,8 @@ void ProcessGroupNCCL::abortCommsFromMap( // Abort all communicators on this rank bool ProcessGroupNCCL::abort(std::optional abortReason) { + // This will log counter for how long the abort actually takes. + STATIC_SCOPED_WAIT_COUNTER(pytorch.ProcessGroupNCCL__abort); // Remove record from global ncclCommDevIdxMapMutex before aboarting, // so that a new cache segment would not register to already aborded // communicators. Note that ncclCommDevIdxMap is a global container which may From 01db0e830562e031deaf89e86f9a962390a32506 Mon Sep 17 00:00:00 2001 From: Isuru Fernando Date: Mon, 16 Sep 2024 17:48:53 +0000 Subject: [PATCH 0995/1018] Fix calling Add._from_args and Mul._from_args (#136143) Pull Request resolved: https://github.com/pytorch/pytorch/pull/136143 Approved by: https://github.com/ezyang --- torch/fx/experimental/symbolic_shapes.py | 48 +++++++++++++++++++++--- 1 file changed, 42 insertions(+), 6 deletions(-) diff --git a/torch/fx/experimental/symbolic_shapes.py b/torch/fx/experimental/symbolic_shapes.py index 51abd7f70a6869..2243ab02f7f669 100644 --- a/torch/fx/experimental/symbolic_shapes.py +++ b/torch/fx/experimental/symbolic_shapes.py @@ -386,6 +386,40 @@ def canonicalize_bool_expr(expr: SympyBoolean) -> SympyBoolean: expr = sympy.logic.boolalg.to_cnf(expr) return _canonicalize_bool_expr_impl(expr) + +def _sympy_from_args( + cls: type, + args: List[sympy.Expr], + sort: bool = True, + is_commutative: Optional[bool] = None, +) -> sympy.Expr: + if not args: + return cls.identity + # These args are already in canonical form, so we avoid calling + # Add(*args) to avoid expensive Add.flatten operation + if sort: + if cls is sympy.Add: + sort_fn = sympy.core.add._addsort + elif cls is sympy.Mul: + sort_fn = sympy.core.mul._mulsort + else: + raise ValueError(f"Unknown cls: {cls}") + + # we don't support non commutative with sort + assert is_commutative is True + if args[0].is_Number: + rest = args[1:] + sort_fn(rest) + return cls._from_args([args[0]] + rest, is_commutative=is_commutative) + else: + args = args.copy() + sort_fn(args) + return cls._from_args(args, is_commutative=is_commutative) + else: + # if the args are already sorted, we create directly + return cls._from_args(args, is_commutative=is_commutative) + + def _canonicalize_bool_expr_impl(expr: SympyBoolean) -> SympyBoolean: """ After canonicalization, we are guaranteed to have eliminated Ge/Gt relations @@ -416,8 +450,10 @@ def is_neg(t): neg.append(-term) else: pos.append(term) - lhs = sympy.Add._from_args(neg) - rhs = sympy.Add._from_args(pos) + # these are already sorted + rhs = _sympy_from_args(sympy.Add, pos, sort=False, is_commutative=True) + # the terms were changed, so needs a sorting + lhs = _sympy_from_args(sympy.Add, neg, sort=True, is_commutative=True) elif is_neg(rhs): # lhs == 0 lhs, rhs = -rhs, 0 @@ -448,12 +484,12 @@ def div_by_factor(x, factor): return x / factor elif x.is_Mul: if x.args[0] != factor: - args = (x.args[0] / factor, *x.args[1:]) + args = [x.args[0] / factor, *x.args[1:]] else: # Mul._from_args require a canonical list of args # so we remove the first arg (x.args[0] / factor) if it was 1 - args = x.args[1:] - return sympy.Mul._from_args(args, is_commutative=x.is_commutative) + args = list(x.args[1:]) + return _sympy_from_args(sympy.Mul, args, is_commutative=x.is_commutative) if expr.is_Add: atoms = expr.args @@ -461,7 +497,7 @@ def div_by_factor(x, factor): if factor == 1: return expr atoms = [div_by_factor(x, factor) for x in atoms] - return sympy.Add._from_args(atoms) + return _sympy_from_args(sympy.Add, atoms, sort=True, is_commutative=expr.is_commutative) elif expr.is_Integer: return sympy.One elif expr.is_Mul: From 3cd03e22f79994199ceee3c6ebc2c3500bae0362 Mon Sep 17 00:00:00 2001 From: Will Feng Date: Tue, 17 Sep 2024 23:13:36 -0700 Subject: [PATCH 0996/1018] [Traceable FSDP2] Minor refactor to traceable FSDP2 unit tests (#136219) Changes in this PR: - Monkey-patching `F.scaled_dot_product_attention` with a lambda seems to not work in some cases. This PR avoids using a lambda. - Running `fullgraph=True` and `fullgraph=False` in the same unit test seems to cause the two cases to interfere with each other and causes error. This PR splits them into two separate unit tests. - The checks in the unit tests might not work with compile cache. This PR turns off the cache in order to have a more predictable compile behavior to do unit test on. Test commands: - `pytest -rA test/distributed/_composable/fsdp/test_fully_shard_compile.py::TestFullyShardCompile::test_nested_fully_shard_backend_inductor_fullgraph_True` - `pytest -rA test/distributed/_composable/fsdp/test_fully_shard_compile.py::TestFullyShardCompile::test_nested_fully_shard_backend_inductor_fullgraph_False` - `pytest -rA test/distributed/_composable/fsdp/test_fully_shard_compile.py::TestFullyShardCompile::test_transformer_backend_inductor_fullgraph_True` - `pytest -rA test/distributed/_composable/fsdp/test_fully_shard_compile.py::TestFullyShardCompile::test_transformer_backend_inductor_fullgraph_False` Pull Request resolved: https://github.com/pytorch/pytorch/pull/136219 Approved by: https://github.com/yifuwang --- .../fsdp/test_fully_shard_compile.py | 120 ++++++++++++------ 1 file changed, 82 insertions(+), 38 deletions(-) diff --git a/test/distributed/_composable/fsdp/test_fully_shard_compile.py b/test/distributed/_composable/fsdp/test_fully_shard_compile.py index 4a4aac1f63fde7..ecd1e7a6d9aa27 100644 --- a/test/distributed/_composable/fsdp/test_fully_shard_compile.py +++ b/test/distributed/_composable/fsdp/test_fully_shard_compile.py @@ -44,6 +44,9 @@ def _is_fallback_op_in_snodes(snodes, op): return any(is_fallback_op(snode.node, op) for snode in snodes) +orig_F_scaled_dot_product_attention = F.scaled_dot_product_attention + + class TestFullyShardCompileCompute(FSDPTest): @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") @skip_if_lt_x_gpu(2) @@ -465,13 +468,21 @@ def test_eager(): res = run_iters(model, optim) return res + torch._dynamo.reset() + torch._dynamo.compiled_autograd.reset() with torch._dynamo.config.patch( + # NOTE: Setting fullgraph=False for forward (to allow graph-breaks) is a common scenario + # and in that case we need a standalone Compiled Autograd ctx that has fullgraph=True for backward. + # Hence here we explicitly set compiled_autograd=False and use the standalone Compiled Autograd ctx + # `maybe_compiled_autograd_ctx` created in `run_iters()`. + compiled_autograd=False, inline_inbuilt_nn_modules=True, skip_fsdp_hooks=False, ), torch._functorch.config.patch( + enable_autograd_cache=False, recompute_views=True, - cse=False, ), torch._inductor.config.patch( + force_disable_caches=True, reorder_for_compute_comm_overlap=True, reorder_for_compute_comm_overlap_passes=[ "sink_waits", @@ -623,8 +634,8 @@ def test_nested_fully_shard_backend_aot_eager_decomp_partition(self): @skipIfRocm @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") - def test_nested_fully_shard_backend_inductor(self): - for fullgraph in [True, False]: + def test_nested_fully_shard_backend_inductor_fullgraph_True(self): + for fullgraph in [True]: with self._reinplace_all_gather_with_optional_checks( fullgraph ), self._maybe_run_decide_global_ordering_of_comms_with_checks( @@ -650,8 +661,9 @@ def test_nested_fully_shard_backend_inductor(self): ) ) if fullgraph: - self.assertTrue( - len(triton_codes) == 2, + self.assertEqual( + len(triton_codes), + 2, "Expected two separate lowerings to Triton code, one from FWD graph and one from Compiled Autograd BWD graph", ) fwd_code = triton_codes[0] @@ -715,13 +727,24 @@ def test_nested_fully_shard_backend_inductor(self): file_check, **bwd_rs_block_info ) file_check.run(bwd_code) - else: - # TODO: when fullgraph=False and there is graph break in FWD graph, - # there are several recompiles, need to figure out why. - self.assertTrue( - len(triton_codes) > 2, - "Expected at least 3 separate lowerings to Triton code, which means at least 1 graph break in FWD graph", - ) + + @skipIfRocm + @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") + def test_nested_fully_shard_backend_inductor_fullgraph_False(self): + _, triton_codes = run_and_get_code( + lambda: self._test_traceable_fsdp( + *self._create_nested_fully_shard_factory_fns(fullgraph=False), + "inductor", + fullgraph=False, + ) + ) + # TODO: when fullgraph=False and there is graph break in FWD graph, + # there are several recompiles, need to figure out why. + self.assertGreater( + len(triton_codes), + 2, + "Expected at least 3 separate lowerings to Triton code, which means at least 1 graph break in FWD graph", + ) def _create_transformer_factory_fns( self, all_requires_grad, *, activation_checkpoint=False @@ -769,20 +792,18 @@ def input_creation_fn(): return model_init_fn, input_creation_fn def _maybe_add_graph_break_to_sdpa(self, fullgraph): - def _sdpa_with_graph_break(orig_fn, fullgraph, *args, **kwargs): - if not fullgraph: - torch._dynamo.graph_break() - return orig_fn(*args, **kwargs) - - return mock.patch.object( - F, - "scaled_dot_product_attention", - functools.partial( + def _sdpa_with_graph_break(*args, **kwargs): + torch._dynamo.graph_break() + return orig_F_scaled_dot_product_attention(*args, **kwargs) + + if not fullgraph: + return mock.patch.object( + F, + "scaled_dot_product_attention", _sdpa_with_graph_break, - F.scaled_dot_product_attention, - fullgraph, - ), - ) + ) + else: + return contextlib.nullcontext() @skipIfRocm @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") @@ -822,17 +843,14 @@ def test_transformer_backend_aot_eager_decomp_partition(self): @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") # TODO: native_dropout causes CUDA IMA error, need to figure out why @torch._inductor.config.patch(fallback_random=True) - def test_transformer_backend_inductor(self): - # TODO: enable fullgraph=False case + def test_transformer_backend_inductor_fullgraph_True(self): for fullgraph, all_requires_grad, activation_checkpoint in itertools.product( [True], [True, False], [True, False] ): log.warning( f"fullgraph={fullgraph}, all_requires_grad={all_requires_grad}, activation_checkpoint={activation_checkpoint}" # noqa: G004, G001 ) - with self._maybe_add_graph_break_to_sdpa( - fullgraph - ), self._reinplace_all_gather_with_optional_checks( + with self._reinplace_all_gather_with_optional_checks( fullgraph ), self._maybe_run_decide_global_ordering_of_comms_with_checks( fullgraph @@ -861,8 +879,9 @@ def test_transformer_backend_inductor(self): ) ) if fullgraph: - self.assertTrue( - len(triton_codes) == 2, + self.assertEqual( + len(triton_codes), + 2, "Expected two separate lowerings to Triton code, one from FWD graph and one from Compiled Autograd BWD graph", ) fwd_code = triton_codes[0] @@ -922,13 +941,38 @@ def test_transformer_backend_inductor(self): file_check, **bwd_rs_block_info ) file_check.run(bwd_code) - else: - # TODO: when fullgraph=False and there is graph break in FWD graph, - # there are several recompiles, need to figure out why. - self.assertTrue( - len(triton_codes) > 2, - "Expected at least 3 separate lowerings to Triton code, which means at least 1 graph break in FWD graph", + + @skipIfRocm + @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") + # TODO: native_dropout causes CUDA IMA error, need to figure out why + @torch._inductor.config.patch(fallback_random=True) + def test_transformer_backend_inductor_fullgraph_False(self): + fullgraph = False + # TODO: fix numerical issue in activation_checkpoint=True case + for all_requires_grad, activation_checkpoint in itertools.product( + [True, False], [False] + ): + log.warning( + f"fullgraph={fullgraph}, all_requires_grad={all_requires_grad}, activation_checkpoint={activation_checkpoint}" # noqa: G004, G001 + ) + with self._maybe_add_graph_break_to_sdpa(fullgraph): + _, triton_codes = run_and_get_code( + lambda: self._test_traceable_fsdp( + *self._create_transformer_factory_fns( + all_requires_grad=all_requires_grad, + activation_checkpoint=activation_checkpoint, + ), + "inductor", + fullgraph=fullgraph, + ) ) + # TODO: when fullgraph=False and there is graph break in FWD graph, + # there are several recompiles, need to figure out why. + self.assertGreater( + len(triton_codes), + 2, + "Expected at least 3 separate lowerings to Triton code, which means at least 1 graph break in FWD graph", + ) if __name__ == "__main__": From 09cb80c4b140ee8c240e3138d83903407ff4128d Mon Sep 17 00:00:00 2001 From: Siju Samuel Date: Wed, 18 Sep 2024 22:32:34 +0000 Subject: [PATCH 0997/1018] [dynamo]Remove stream hardcoding in dynamo VariableBuilder (#131763) Fixes #ISSUE_NUMBER Recent change from PR#123487 used torch.cuda.Stream directly and this causes failure for other backends. This PR will generalize the stream handling for all backends like cuda/hpu/xpu Pull Request resolved: https://github.com/pytorch/pytorch/pull/131763 Approved by: https://github.com/yanboliang, https://github.com/yf225 --- torch/_dynamo/variables/builder.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch/_dynamo/variables/builder.py b/torch/_dynamo/variables/builder.py index 193e227267ec3b..46e970cd85335f 100644 --- a/torch/_dynamo/variables/builder.py +++ b/torch/_dynamo/variables/builder.py @@ -826,7 +826,7 @@ def build_key_value(i, k, v): self.install_guards(GuardBuilder.ID_MATCH) stream_proxy = self.tx.output.create_proxy( "call_function", - torch.cuda.Stream, + type(value), (), { "stream_id": value.stream_id, From 0073d1043502589237a8186d9b1859a0617802cb Mon Sep 17 00:00:00 2001 From: Nikita Shulga <2453524+malfet@users.noreply.github.com> Date: Wed, 18 Sep 2024 22:44:43 +0000 Subject: [PATCH 0998/1018] [BE][MPS] Delete duplicated code in `View.mm` (#136295) After https://github.com/pytorch/pytorch/pull/135706 `getGatherScatterScalarType` returns exactly the same results as `scalarToMetalTypeString` , so delete the function and call `scalarToMetalTypeString` Pull Request resolved: https://github.com/pytorch/pytorch/pull/136295 Approved by: https://github.com/kit1980 --- aten/src/ATen/native/mps/operations/View.mm | 29 +++------------------ 1 file changed, 4 insertions(+), 25 deletions(-) diff --git a/aten/src/ATen/native/mps/operations/View.mm b/aten/src/ATen/native/mps/operations/View.mm index f1bc8ebf97026c..ba23956b5d32a4 100644 --- a/aten/src/ATen/native/mps/operations/View.mm +++ b/aten/src/ATen/native/mps/operations/View.mm @@ -729,27 +729,6 @@ static IntArrayRef updateTensorBaseShape(const Tensor& self) { return kernelName + "_kernel_" + std::to_string(dim == 0 ? 1 : dim); } -static const std::string& getGatherScatterScalarType(const Tensor& t) { - auto scalar_type = t.scalar_type(); - static std::unordered_map scalarToMetalType = { - {c10::ScalarType::Float, "float"}, - {c10::ScalarType::Half, "half"}, - {c10::ScalarType::BFloat16, "bfloat"}, - {c10::ScalarType::Long, "long"}, - {c10::ScalarType::Int, "int"}, - {c10::ScalarType::Short, "short"}, - {c10::ScalarType::Char, "char"}, - {c10::ScalarType::Byte, "uchar"}, - {c10::ScalarType::Bool, "bool"}, - {c10::ScalarType::ComplexFloat, "float2"}, - {c10::ScalarType::ComplexHalf, "half2"}, - }; - - auto it = scalarToMetalType.find(scalar_type); - TORCH_CHECK(it != scalarToMetalType.end(), "Unsupported type byte size: ", scalar_type); - return it->second; -} - static std::string genScatterGatherCvtFunc(const std::string& dtypeSrc, const std::string& dtypeDst, bool needsConj) { const bool srcComplex = dtypeSrc[dtypeSrc.size() - 1] == '2'; const bool dstComplex = dtypeDst[dtypeDst.size() - 1] == '2'; @@ -805,8 +784,8 @@ Tensor gatherViewTensor(const at::Tensor& src, at::Tensor& dst) { id computeEncoder = mpsStream->commandEncoder(); std::string functionName = getGatherScatterFunctionName(output.scalar_type(), output.dim(), /*needsScatter=*/false); id gatherPSO = getPipelineState(functionName, - getGatherScatterScalarType(src), - getGatherScatterScalarType(output), + scalarToMetalTypeString(src), + scalarToMetalTypeString(output), /*needsScatter=*/false, src.is_conj() != dst.is_conj()); @@ -862,8 +841,8 @@ Tensor gatherViewTensor(const at::Tensor& src, at::Tensor& dst) { std::string functionName = getGatherScatterFunctionName(output.scalar_type(), output.dim(), /*needsScatter=*/true); id scatterPSO = getPipelineState(functionName, - getGatherScatterScalarType(src), - getGatherScatterScalarType(output), + scalarToMetalTypeString(src), + scalarToMetalTypeString(output), /*needsScatter=*/true, src.is_conj() != output.is_conj()); From f39faaa0688ca69ef959f748436844070806c0fd Mon Sep 17 00:00:00 2001 From: Nikita Shulga <2453524+malfet@users.noreply.github.com> Date: Wed, 18 Sep 2024 23:38:31 +0000 Subject: [PATCH 0999/1018] [BE][MPS] Fix deprecation warnings on MacOS 15.0 (#136292) [reverseSquareRootWithTensor:](https://developer.apple.com/documentation/metalperformanceshadersgraph/mpsgraph/reversesquareroot(with:name:)?changes=__8&language=objc) were deprecated in favor of [reciprocalSquareRootWithTensor:](https://developer.apple.com/documentation/metalperformanceshadersgraph/mpsgraph/reciprocalsquareroot(_:name:)?changes=__8&language=objc) Without it, following warnings are generated if compiled on recently released MacOS Sequoia: ``` /Users/malfet/git/pytorch/pytorch/aten/src/ATen/native/mps/operations/Normalization.mm:720:35: warning: 'reverseSquareRootWithTensor:name:' is deprecated: first deprecated in macOS 15.0 [-Wdeprecated-declarations] 720 | rsqrtTensor = [mpsGraph reverseSquareRootWithTensor:varianceEpsTensor name:nil]; | ^~~~~~~~~~~~~~~~~~~~~~~~~~~ | reciprocalSquareRootWithTensor /Applications/Xcode.app/Contents/Developer/Platforms/MacOSX.platform/Developer/SDKs/MacOSX15.0.sdk/usr/include/c++/v1/__type_traits/invoke.h:341:10: note: in instantiation of function template specialization 'at::native::batch_norm_backward_mps(const Tensor &, const Tensor &, const std::optional &, const std::optional &, const std::optional &, const std::optional &, const std::optional &, bool, double, std::array)::(anonymous class)::operator()' requested here 341 | decltype(std::declval<_Fp>()(std::declval<_Args>()...)) | ^ /Applications/Xcode.app/Contents/Developer/Platforms/MacOSX.platform/Developer/SDKs/MacOSX15.0.sdk/usr/include/c++/v1/__type_traits/invoke.h:351:19: note: while substituting deduced template arguments into function template '__invoke' [with _Fp = (lambda at /Users/malfet/git/pytorch/pytorch/aten/src/ATen/native/mps/operations/Normalization.mm:623:68) &, _Args = ] 351 | static decltype(std::__invoke(std::declval<_XFp>(), std::declval<_XArgs>()...)) __try_call(int); | ^ /Applications/Xcode.app/Contents/Developer/Platforms/MacOSX.platform/Developer/SDKs/MacOSX15.0.sdk/usr/include/c++/v1/__type_traits/invoke.h:357:28: note: while substituting deduced template arguments into function template '__try_call' [with _XFp = (lambda at /Users/malfet/git/pytorch/pytorch/aten/src/ATen/native/mps/operations/Normalization.mm:623:68) &, _XArgs = (no value)] 357 | using _Result = decltype(__try_call<_Fp, _Args...>(0)); | ^ /Applications/Xcode.app/Contents/Developer/Platforms/MacOSX.platform/Developer/SDKs/MacOSX15.0.sdk/usr/include/c++/v1/__type_traits/conjunction.h:27:32: note: in instantiation of template class 'std::__invokable_r' requested here 27 | __expand_to_true<__enable_if_t<_Pred::value>...> __and_helper(int); | ^ /Applications/Xcode.app/Contents/Developer/Platforms/MacOSX.platform/Developer/SDKs/MacOSX15.0.sdk/usr/include/c++/v1/__type_traits/conjunction.h:38:39: note: while substituting explicitly-specified template arguments into function template '__and_helper' 38 | using _And _LIBCPP_NODEBUG = decltype(std::__and_helper<_Pred...>(0)); | ^ /Applications/Xcode.app/Contents/Developer/Platforms/MacOSX.platform/Developer/SDKs/MacOSX15.0.sdk/usr/include/c++/v1/__functional/function.h:828:20: note: (skipping 1 context in backtrace; use -ftemplate-backtrace-limit=0 to see all) 828 | bool = _And< _IsNotSame<__remove_cvref_t<_Fp>, function>, __invokable<_Fp, _ArgTypes...> >::value> | ^ /Applications/Xcode.app/Contents/Developer/Platforms/MacOSX.platform/Developer/SDKs/MacOSX15.0.sdk/usr/include/c++/v1/__functional/function.h:841:49: note: in instantiation of default argument for '__callable<(lambda at /Users/malfet/git/pytorch/pytorch/aten/src/ATen/native/mps/operations/Normalization.mm:623:68) &>' required here 841 | using _EnableIfLValueCallable = __enable_if_t<__callable<_Fp&>::value>; | ^~~~~~~~~~~~~~~~ /Applications/Xcode.app/Contents/Developer/Platforms/MacOSX.platform/Developer/SDKs/MacOSX15.0.sdk/usr/include/c++/v1/__functional/function.h:851:32: note: in instantiation of template type alias '_EnableIfLValueCallable' requested here 851 | template > | ^ /Applications/Xcode.app/Contents/Developer/Platforms/MacOSX.platform/Developer/SDKs/MacOSX15.0.sdk/usr/include/c++/v1/__functional/function.h:852:25: note: in instantiation of default argument for 'function<(lambda at /Users/malfet/git/pytorch/pytorch/aten/src/ATen/native/mps/operations/Normalization.mm:623:68)>' required here 852 | _LIBCPP_HIDE_FROM_ABI function(_Fp); | ^~~~~~~~~~~~~ /Users/malfet/git/pytorch/pytorch/aten/src/ATen/native/mps/operations/Normalization.mm:623:68: note: while substituting deduced template arguments into function template 'function' [with _Fp = (lambda at /Users/malfet/git/pytorch/pytorch/aten/src/ATen/native/mps/operations/Normalization.mm:623:68), $1 = (no value)] 623 | auto cachedGraph = LookUpOrCreateCachedGraph(key, [&](auto mpsGraph, auto newCachedGraph) { | ^ /Users/malfet/git/pytorch/pytorch/aten/src/ATen/native/mps/operations/Normalization.mm:623:24: note: while substituting deduced template arguments into function template 'LookUpOrCreateCachedGraph' [with T = CachedGraph] 623 | auto cachedGraph = LookUpOrCreateCachedGraph(key, [&](auto mpsGraph, auto newCachedGraph) { | ^ /Applications/Xcode.app/Contents/Developer/Platforms/MacOSX.platform/Developer/SDKs/MacOSX15.0.sdk/System/Library/Frameworks/MetalPerformanceShadersGraph.framework/Headers/MPSGraphArithmeticOps.h:123:1: note: 'reverseSquareRootWithTensor:name:' has been explicitly marked deprecated here 123 | -(MPSGraphTensor *) reverseSquareRootWithTensor:(MPSGraphTensor *) tensor | ^ /Users/malfet/git/pytorch/pytorch/aten/src/ATen/native/mps/operations/Normalization.mm:745:37: warning: 'reverseSquareRootWithTensor:name:' is deprecated: first deprecated in macOS 15.0 [-Wdeprecated-declarations] 745 | rsqrtTensor = [mpsGraph reverseSquareRootWithTensor:varianceEpsTensor name:nil]; | ^~~~~~~~~~~~~~~~~~~~~~~~~~~ | reciprocalSquareRootWithTensor /Applications/Xcode.app/Contents/Developer/Platforms/MacOSX.platform/Developer/SDKs/MacOSX15.0.sdk/System/Library/Frameworks/MetalPerformanceShadersGraph.framework/Headers/MPSGraphArithmeticOps.h:123:1: note: 'reverseSquareRootWithTensor:name:' has been explicitly marked deprecated here 123 | -(MPSGraphTensor *) reverseSquareRootWithTensor:(MPSGraphTensor *) tensor | ^ 2 warnings generated. ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/136292 Approved by: https://github.com/kit1980 --- aten/src/ATen/native/mps/operations/Normalization.mm | 8 ++++++++ aten/src/ATen/native/mps/operations/UnaryOps.mm | 4 ++++ 2 files changed, 12 insertions(+) diff --git a/aten/src/ATen/native/mps/operations/Normalization.mm b/aten/src/ATen/native/mps/operations/Normalization.mm index ee32bc3d2274a1..391422e77b5350 100644 --- a/aten/src/ATen/native/mps/operations/Normalization.mm +++ b/aten/src/ATen/native/mps/operations/Normalization.mm @@ -717,7 +717,11 @@ static string get_mem_string(c10::MemoryFormat memory_format) { MPSGraphTensor* varianceEpsTensor = [mpsGraph additionWithPrimaryTensor:runningVarTensor secondaryTensor:epsilonTensor name:nil]; +#ifdef __MAC_15_0 + rsqrtTensor = [mpsGraph reciprocalSquareRootWithTensor:varianceEpsTensor name:nil]; +#else rsqrtTensor = [mpsGraph reverseSquareRootWithTensor:varianceEpsTensor name:nil]; +#endif MPSGraphTensor* bnForwardTensor = [mpsGraph multiplicationWithPrimaryTensor:xMinusMean secondaryTensor:rsqrtTensor name:nil]; @@ -742,7 +746,11 @@ static string get_mem_string(c10::MemoryFormat memory_format) { MPSGraphTensor* varianceEpsTensor = [mpsGraph additionWithPrimaryTensor:runningVarTensor secondaryTensor:epsilonTensor name:nil]; +#ifdef __MAC_15_0 + rsqrtTensor = [mpsGraph reciprocalSquareRootWithTensor:varianceEpsTensor name:nil]; +#else rsqrtTensor = [mpsGraph reverseSquareRootWithTensor:varianceEpsTensor name:nil]; +#endif } gradInputTensor = [mpsGraph multiplicationWithPrimaryTensor:unitTensor secondaryTensor:rsqrtTensor name:nil]; diff --git a/aten/src/ATen/native/mps/operations/UnaryOps.mm b/aten/src/ATen/native/mps/operations/UnaryOps.mm index ea94faef98bdfe..334a056cddfb53 100644 --- a/aten/src/ATen/native/mps/operations/UnaryOps.mm +++ b/aten/src/ATen/native/mps/operations/UnaryOps.mm @@ -225,7 +225,11 @@ static void unary_op(const Tensor& self, CREATE_MPS_STRUCTURED_UNARY_TORCH_IMPL_FUNC(exp2_out_mps, exponentBase2) CREATE_MPS_STRUCTURED_UNARY_TORCH_IMPL_FUNC(reciprocal_out_mps, reciprocal) CREATE_MPS_STRUCTURED_UNARY_TORCH_IMPL_FUNC(sqrt_out_mps, squareRoot) +#ifdef __MAC_15_0 +CREATE_MPS_STRUCTURED_UNARY_TORCH_IMPL_FUNC(rsqrt_out_mps, reciprocalSquareRoot) +#else CREATE_MPS_STRUCTURED_UNARY_TORCH_IMPL_FUNC(rsqrt_out_mps, reverseSquareRoot) +#endif CREATE_MPS_STRUCTURED_UNARY_TORCH_IMPL_FUNC(neg_out_mps, negative) CREATE_MPS_STRUCTURED_UNARY_TORCH_IMPL_FUNC(log_out_mps, logarithm) CREATE_MPS_STRUCTURED_UNARY_TORCH_IMPL_FUNC(log10_out_mps, logarithmBase10) From 5ea1746618c81ffefe201c6aa8fdcdc71047ef27 Mon Sep 17 00:00:00 2001 From: Jerry Zhang Date: Wed, 18 Sep 2024 14:12:45 -0700 Subject: [PATCH 1000/1018] Add uint16 support for observer (#136238) Summary: att Test Plan: python test/test_quantization.py -k TestObserver Reviewers: Subscribers: Tasks: Tags: Differential Revision: [D62909821](https://our.internmc.facebook.com/intern/diff/D62909821) Pull Request resolved: https://github.com/pytorch/pytorch/pull/136238 Approved by: https://github.com/tarun292 --- .../quantization/core/test_workflow_module.py | 34 ++++++++++++++++--- torch/ao/quantization/observer.py | 3 ++ torch/ao/quantization/utils.py | 4 +++ 3 files changed, 36 insertions(+), 5 deletions(-) diff --git a/test/quantization/core/test_workflow_module.py b/test/quantization/core/test_workflow_module.py index 168c00920f623a..964dc051c3a434 100644 --- a/test/quantization/core/test_workflow_module.py +++ b/test/quantization/core/test_workflow_module.py @@ -67,14 +67,30 @@ NP_RANDOM_SEED = 19 tolerance = 1e-6 +# copy and modified from torch/ao/quantization/observer.py +_INT_DTYPES = ( + torch.qint8, + torch.quint8, + torch.quint4x2, + torch.qint32, + torch.int8, + torch.uint8, + torch.int16, + torch.int32, + torch.uint16, +) + class TestObserver(QuantizationTestCase): - @given(qdtype=st.sampled_from((torch.qint8, torch.quint8, torch.qint32)), + @given(qdtype=st.sampled_from(_INT_DTYPES), qscheme=st.sampled_from((torch.per_tensor_affine, torch.per_tensor_symmetric)), reduce_range=st.booleans()) def test_per_tensor_observers(self, qdtype, qscheme, reduce_range): # reduce_range cannot be true for symmetric quantization with uint8 if (qdtype == torch.quint8 and qscheme == torch.per_tensor_symmetric) or qdtype == torch.qint32: reduce_range = False + if qdtype == torch.quint4x2: + return + ObserverList = [MinMaxObserver(dtype=qdtype, qscheme=qscheme, reduce_range=reduce_range), MovingAverageMinMaxObserver(averaging_constant=0.5, dtype=qdtype, @@ -82,18 +98,23 @@ def test_per_tensor_observers(self, qdtype, qscheme, reduce_range): reduce_range=reduce_range)] def _get_ref_params(reduce_range, qscheme, dtype, input_scale, min_val, max_val): + assert dtype in _INT_DTYPES, "Not supported dtype: {dtype}, supported dtypes are {_INT_DTYPES}" eps = torch.tensor([tolerance]) - if dtype == torch.qint8: + if dtype in [torch.qint8, torch.int8]: if reduce_range: quant_min, quant_max = -64, 63 else: quant_min, quant_max = -128, 127 - elif dtype == torch.quint8: + elif dtype in [torch.quint8, torch.uint8]: if reduce_range: quant_min, quant_max = 0, 127 else: quant_min, quant_max = 0, 255 - elif dtype == torch.qint32: + elif dtype == torch.int16: + quant_min, quant_max = -1 * (2 ** 15), (2 ** 15) - 1 + elif dtype == torch.uint16: + quant_min, quant_max = 0, (2 ** 16) - 1 + elif dtype in [torch.qint32, torch.int32]: quant_min, quant_max = -1 * (2 ** 31), (2 ** 31) - 1 min_val_neg = torch.tensor([0.]) @@ -103,12 +124,15 @@ def _get_ref_params(reduce_range, qscheme, dtype, input_scale, min_val, max_val) if qscheme == torch.per_tensor_symmetric or qscheme == torch.per_channel_symmetric: scale = torch.max(-min_val_neg, max_val_pos) / (float(quant_max - quant_min) / 2) scale = torch.max(scale, eps) - if dtype == torch.quint8: + if dtype in [torch.quint8, torch.uint8]: zero_point = 128 + if dtype in [torch.uint16]: + zero_point = 2 ** 15 else: scale = torch.max((max_val_pos - min_val_neg) / float(quant_max - quant_min), eps) zero_point = quant_min - torch.round(min_val_neg / scale).to(torch.int) zero_point = torch.clamp(zero_point, quant_min, quant_max) + return scale, zero_point for myobs in ObserverList: diff --git a/torch/ao/quantization/observer.py b/torch/ao/quantization/observer.py index e26f03027116e2..9b4e26743ab677 100644 --- a/torch/ao/quantization/observer.py +++ b/torch/ao/quantization/observer.py @@ -253,6 +253,7 @@ def __init__( torch.int32, torch.float8_e5m2, torch.float8_e4m3fn, + torch.uint16, ) assert ( @@ -368,6 +369,8 @@ def _calculate_qparams( ) else: zero_point = zero_point.new_full(zero_point.size(), 128) + elif self.dtype in [torch.uint16]: + zero_point = zero_point.new_full(zero_point.size(), 2**15) elif self.qscheme == torch.per_channel_affine_float_qparams: scale = (max_val - min_val) / float(quant_max - quant_min) scale = torch.where(scale > self.eps, scale, torch.ones_like(scale)) diff --git a/torch/ao/quantization/utils.py b/torch/ao/quantization/utils.py index d243932a1ef82d..293e37ed456fb4 100644 --- a/torch/ao/quantization/utils.py +++ b/torch/ao/quantization/utils.py @@ -473,6 +473,10 @@ def calculate_qmin_qmax( quant_min, quant_max = 0, 255 elif dtype in [torch.qint32, torch.int32]: quant_min, quant_max = -1 * (2**31), (2**31) - 1 + elif dtype in [torch.uint16]: + quant_min, quant_max = 0, 2**16 - 1 + elif dtype in [torch.int16]: + quant_min, quant_max = -(2**15), 2**15 - 1 else: quant_min, quant_max = 0, 15 return quant_min, quant_max From 80970ab1c8483bdb7037d37d915e320c344f77a6 Mon Sep 17 00:00:00 2001 From: William Wen Date: Tue, 17 Sep 2024 00:48:24 +0000 Subject: [PATCH 1001/1018] [dynamo] fix crash in InspectSignatureVariable (#136010) Fix crash that was happening in https://github.com/pytorch/pytorch/issues/128095, because we were trying to extract a constant incorrectly. Pull Request resolved: https://github.com/pytorch/pytorch/pull/136010 Approved by: https://github.com/yanboliang, https://github.com/anijain2305, https://github.com/jansel --- torch/_dynamo/variables/misc.py | 25 ++++++++++++++----------- 1 file changed, 14 insertions(+), 11 deletions(-) diff --git a/torch/_dynamo/variables/misc.py b/torch/_dynamo/variables/misc.py index 2b62c0c1acc116..663ff5b20a6d0b 100644 --- a/torch/_dynamo/variables/misc.py +++ b/torch/_dynamo/variables/misc.py @@ -41,6 +41,7 @@ UserMethodVariable, wrap_bound_arg, ) +from .nn_module import UnspecializedNNModuleVariable from .user_defined import call_random_fn, is_standard_setattr, UserDefinedObjectVariable @@ -393,18 +394,20 @@ def __init__(self, inspected: VariableTracker, **kwargs) -> None: super().__init__(**kwargs) self.inspected = inspected + try: + if hasattr(self.inspected, "get_function"): + self.fn = self.inspected.get_function() + elif isinstance(self.inspected, UnspecializedNNModuleVariable): + self.fn = self.inspected.value + else: + self.fn = self.inspected.as_python_constant() + except NotImplementedError: + unimplemented("inspect.signature with non-constant function") + + self.signature = inspect.signature(self.fn) + self.parameters = list(self.signature.parameters.items()) if isinstance(self.inspected, UserMethodVariable): - self.fn = self.inspected.get_function() - self.signature = inspect.signature(self.fn) - self.parameters = list(self.signature.parameters.items())[1:] - elif isinstance(self.inspected, UserFunctionVariable): - self.fn = self.inspected.get_function() - self.signature = inspect.signature(self.fn) - self.parameters = list(self.signature.parameters.items()) - else: - self.fn = self.inspected.as_python_constant() - self.signature = inspect.signature(self.fn) - self.parameters = list(self.signature.parameters.items()) + self.parameters = self.parameters[1:] def var_getattr(self, tx: "InstructionTranslator", name: str) -> "VariableTracker": if name == "parameters": From 82d9db2dfa1b370a56453fc79169cab5f2500b60 Mon Sep 17 00:00:00 2001 From: Duygu Altinok Date: Thu, 19 Sep 2024 03:09:36 +0000 Subject: [PATCH 1002/1018] Add type checks for Tensor.add_ (#135864) Fixes #127049 There's already a meta func in `meta_registrations.py` for `add_` and `sub_` methods. I added a second meta function for error checking, i.e `int.add/sub_(float)` and `bool.add/sub_(other types)` . Also the corresponding test with Dynamo passes, removed `@xfailIfTorchDynamo`. Pull Request resolved: https://github.com/pytorch/pytorch/pull/135864 Approved by: https://github.com/williamwen42 --- test/test_type_promotion.py | 4 +--- torch/_meta_registrations.py | 41 ++++++++++++++++++++++++++++++++++++ 2 files changed, 42 insertions(+), 3 deletions(-) diff --git a/test/test_type_promotion.py b/test/test_type_promotion.py index 6a1ae247c14320..a4bbb8394da215 100644 --- a/test/test_type_promotion.py +++ b/test/test_type_promotion.py @@ -8,8 +8,7 @@ from torch.testing._internal.common_utils import (TestCase, run_tests, load_tests, make_tensor, TEST_NUMPY, set_default_dtype, torch_to_numpy_dtype_dict, - numpy_to_torch_dtype_dict, skipIfTorchDynamo, - xfailIfTorchDynamo) + numpy_to_torch_dtype_dict, skipIfTorchDynamo) from torch.testing._internal.common_device_type import (instantiate_device_type_tests, onlyNativeDeviceTypes, dtypes, onlyCPU, expectedFailureMeta, skipMeta) from torch.testing._internal.common_dtype import ( @@ -45,7 +44,6 @@ class TestTypePromotion(TestCase): # Promoting inplace would require re-allocating and copying the memory of the # tensor data, since element size could change. # https://github.com/pytorch/pytorch/issues/127049 - @xfailIfTorchDynamo @float_double_default_dtype def test_inplace(self, device): int_tensor = torch.ones([4, 4, 4], dtype=torch.int32, device=device) diff --git a/torch/_meta_registrations.py b/torch/_meta_registrations.py index 820c21218cb1c3..b67f28b75cc413 100644 --- a/torch/_meta_registrations.py +++ b/torch/_meta_registrations.py @@ -16,10 +16,12 @@ from torch._ops import OpOverload from torch._prims import _prim_elementwise_meta, ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND from torch._prims_common import ( + BoolLike, corresponding_complex_dtype, corresponding_real_dtype, elementwise_dtypes, ELEMENTWISE_TYPE_PROMOTION_KIND, + FloatLike, IntLike, make_contiguous_strides_for, Number, @@ -3563,6 +3565,45 @@ def meta_binop_inplace(self, other): ], ) def meta_binop_inplace_alpha(self, other, alpha=1): + """ + Some checks for inplace ops. + Checks for promotion rules for some dtypes. + int.add/sub_(float) and bool.add/sub_(others) are rejected. + Promoting in these in-place operations would require reallocating + and copying over elements, hence not allowed. + Checks for alpha param. + """ + + def is_integeric(arg): + if isinstance(arg, TensorLike): + return utils.is_integer_dtype(arg.dtype) + else: + return isinstance(arg, IntLike) + + def is_floatic(arg): + if isinstance(arg, TensorLike): + return utils.is_float_dtype(arg.dtype) + else: + return isinstance(arg, FloatLike) + + def is_booleanic(arg): + if isinstance(arg, TensorLike): + return utils.is_boolean_dtype(arg.dtype) + else: + return isinstance(arg, BoolLike) + + # Do not allow int+float->int in-place + if is_integeric(self) and is_floatic(other): + raise RuntimeError( + "Promotion of int.add/sub_(float) in in-place ops are not possible due to element size change." + ) + + # Do not allow bool+other->bool in-place + if is_booleanic(self) and not is_booleanic(other): + raise RuntimeError( + "Promotion of book.add/sub_(others) in in-place ops are not possible due to element size change." + ) + if isinstance(other, torch.Tensor): check_inplace_broadcast(self.shape, other.shape) return self From a8946732ed458013249ff68564174260b24059a4 Mon Sep 17 00:00:00 2001 From: Huy Do Date: Thu, 19 Sep 2024 04:17:06 +0000 Subject: [PATCH 1003/1018] XFAIL test_segfault (#136252) Fixes https://github.com/pytorch/pytorch/issues/128551 As this has been failing in trunk for a while and there is no owner yet to fix it properly. Pull Request resolved: https://github.com/pytorch/pytorch/pull/136252 Approved by: https://github.com/andrewkho --- test/test_dataloader.py | 3 +++ torch/testing/_internal/common_utils.py | 4 ++++ 2 files changed, 7 insertions(+) diff --git a/test/test_dataloader.py b/test/test_dataloader.py index 0cdba26a303cf0..6ece8afac855b9 100644 --- a/test/test_dataloader.py +++ b/test/test_dataloader.py @@ -40,6 +40,7 @@ TEST_WITH_ROCM, TEST_WITH_TSAN, TestCase, + xfailIfLinux, ) from torch.utils.data import ( _utils, @@ -1382,6 +1383,8 @@ def test_multiple_dataloaders(self): del loader1_it del loader2_it + # https://github.com/pytorch/pytorch/issues/128551 + @xfailIfLinux def test_segfault(self): p = ErrorTrackingProcess(target=_test_segfault) p.start() diff --git a/torch/testing/_internal/common_utils.py b/torch/testing/_internal/common_utils.py index 4fb13abb5b96d3..ea3485249bb020 100644 --- a/torch/testing/_internal/common_utils.py +++ b/torch/testing/_internal/common_utils.py @@ -1507,6 +1507,10 @@ def xfailIfTorchDynamo(func): return unittest.expectedFailure(func) if TEST_WITH_TORCHDYNAMO else func +def xfailIfLinux(func): + return unittest.expectedFailure(func) if IS_LINUX and not TEST_WITH_ROCM else func + + def skipIfTorchDynamo(msg="test doesn't currently work with dynamo"): """ Usage: From c4105553e660cc353f7d778c612f97b16ff5f2d0 Mon Sep 17 00:00:00 2001 From: Jan Wieczorek Date: Thu, 19 Sep 2024 11:52:16 +0000 Subject: [PATCH 1004/1018] Return unsafe_view instead of view from matmul when folding occurs (#134568) When tensor folding occurs during matmul operation returned tensor is a view. This can cause issues when matmul is used inside a custom function and such view is then returned as output. Then it cannot be modified inplace and causes errors. It can be especially problematic when after such function inplace allreduce is performed. Issue is resolved when unsafe_view is returned from matmul instead. This solution aligns matmul decomposition with eager implementation in such a way that a non view tensor is returned. Test included in this PR reproduces the issue. Pull Request resolved: https://github.com/pytorch/pytorch/pull/134568 Approved by: https://github.com/zou3519 --- test/dynamo/test_misc.py | 53 +++++++++++++++++++++++++++++++++ test/test_proxy_tensor.py | 4 +-- torch/_decomp/decompositions.py | 4 +-- 3 files changed, 57 insertions(+), 4 deletions(-) diff --git a/test/dynamo/test_misc.py b/test/dynamo/test_misc.py index 96047d3436ebe1..e546463059f0c8 100644 --- a/test/dynamo/test_misc.py +++ b/test/dynamo/test_misc.py @@ -11363,6 +11363,59 @@ def forward(self, x): opt_fn() +class TestCustomFunction(torch.testing._internal.common_utils.TestCase): + def test_autograd_function_with_matmul_folding_at_output(self): + """ + When tensor folding occurs during matmul operation returned tensor is a view. + This can cause issues when matmul is used inside a custom function + and such view is then returned as output. Then it cannot be modified inplace + and causes errors. + It can be especially problematic when after such function inplace allreduce + is performed. This test recreates this behaviour. + Issue is resolved when unsafe_view is returned from matmul instead. + """ + + class CustomFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, inp1, inp2): + ctx.save_for_backward(inp2) + ctx.output_shape = inp1.size() + return torch.matmul(inp1, inp2) + + @staticmethod + def backward(ctx, grad_output): + output_shape = ctx.output_shape + (inp2,) = ctx.saved_tensors + return ( + torch.mm(grad_output.squeeze(), inp2.t()).view(output_shape), + None, + ) + + def outer_function(inp1, inp2): + res = CustomFunction.apply(inp1, inp2) + res.add_(1.0) + return res.sum() + + def usual_function(inp1, inp2) -> torch.Tensor: + res = torch.matmul(inp1, inp2) + res.add_(1.0) + return res.sum() + + inp1_custom = torch.randn(4, 1, 2, requires_grad=True) + inp1_usual = inp1_custom.detach().clone().requires_grad_(True) + + inp2 = torch.randn(2, 4) + c_custom_func = torch.compile(outer_function) + c_usual_func = torch.compile(usual_function) + + result_custom = c_custom_func(inp1_custom, inp2) + result_custom.backward() + result_usual = c_usual_func(inp1_usual, inp2) + result_usual.backward() + + torch.allclose(inp1_custom.grad, inp1_usual.grad) + + if __name__ == "__main__": from torch._dynamo.test_case import run_tests diff --git a/test/test_proxy_tensor.py b/test/test_proxy_tensor.py index 1a35e6823a189e..255b177a10eda7 100644 --- a/test/test_proxy_tensor.py +++ b/test/test_proxy_tensor.py @@ -1358,8 +1358,8 @@ def forward(self, crop_camera_1, mask_1): mul_4 = sym_size_int * 3 view_3 = torch.ops.aten.view.default(view_2, [mul_4, 3]); view_2 = mul_4 = None mm = torch.ops.aten.mm.default(view_3, eye); view_3 = eye = None - view_4 = torch.ops.aten.view.default(mm, [sym_size_int, 3, 3]); mm = sym_size_int = None - index_put_ = torch.ops.aten.index_put_.default(crop_camera_1, [mask_1], view_4); crop_camera_1 = mask_1 = view_4 = index_put_ = None + _unsafe_view = torch.ops.aten._unsafe_view.default(mm, [sym_size_int, 3, 3]); mm = sym_size_int = None + index_put_ = torch.ops.aten.index_put_.default(crop_camera_1, [mask_1], _unsafe_view); crop_camera_1 = mask_1 = _unsafe_view = index_put_ = None return None""") # noqa: B950 def test_unbacked_slice(self): diff --git a/torch/_decomp/decompositions.py b/torch/_decomp/decompositions.py index ec092a5234096d..f3c09d762f33ae 100644 --- a/torch/_decomp/decompositions.py +++ b/torch/_decomp/decompositions.py @@ -4364,10 +4364,10 @@ def matmul(tensor1, tensor2, *, is_out=False): if t2_is_matrix: # This copies if we perform a 2D @ 3D and the first tensor requires_grad # See should_fold native/LinearAlgebra.cpp for why. - output = t1_folded.mm(t2).view(output_shape) + output = torch.ops.aten._unsafe_view(t1_folded.mm(t2), output_shape) return output.mT.contiguous() if transpose else output else: - return t1_folded.mv(t2).view(output_shape) + return torch.ops.aten._unsafe_view(t1_folded.mv(t2), output_shape) elif dim_tensor1 >= 1 and dim_tensor2 >= 1: # We are multiplying b1 x n x m1 by x2 x m2 x p (where b1 can be a list); From 1c813705651e00b4ebec81291b685f5d33b7093a Mon Sep 17 00:00:00 2001 From: Igor Sugak Date: Thu, 19 Sep 2024 12:40:36 +0000 Subject: [PATCH 1005/1018] [CODEMOD][caffe2] use npt.NDArray instead of np.ndarray in type annotations (#136288) Summary: To facilitate PSS-2 upgrade, this uses `ndt.NDArray` instead of `nd.ndarray` in type annotations. In Numpy-1.19 (PSS-1) it's an alias to `nd.ndarray` -- a noop. In Numpy-1.24, `ndt.NDArray` a proper generic type, and without this change uses of `nd.ndarray` generate this Pyre type error: ```counterexample Invalid type parameters [24]: Generic type `np.ndarray` expects 2 type parameters. ``` Test Plan: Sandcastle plus visual inspection Differential Revision: D62977370 Pull Request resolved: https://github.com/pytorch/pytorch/pull/136288 Approved by: https://github.com/kit1980 --- benchmarks/dynamo/common.py | 5 +++-- torch/ao/quantization/experimental/linear.py | 3 ++- torch/onnx/_internal/exporter/_core.py | 6 +++--- torch/onnx/verification.py | 3 ++- .../testing/_internal/common_methods_invocations.py | 13 +++++++------ torch/testing/_internal/opinfo/utils.py | 3 ++- torch/utils/tensorboard/_utils.py | 3 ++- 7 files changed, 21 insertions(+), 15 deletions(-) diff --git a/benchmarks/dynamo/common.py b/benchmarks/dynamo/common.py index 81ace9e9c1d197..104e59bc193a41 100644 --- a/benchmarks/dynamo/common.py +++ b/benchmarks/dynamo/common.py @@ -39,6 +39,7 @@ from unittest.mock import MagicMock import numpy as np +import numpy.typing as npt import pandas as pd import psutil import yaml @@ -1565,14 +1566,14 @@ def format_pt_inputs(self, pt_inputs: Any) -> Sequence[torch.Tensor]: def format_pt_outputs(self, pt_outputs: Any) -> Sequence[torch.Tensor]: ... - def adapt_pt_inputs_to_onnx(self, pt_inputs) -> Mapping[str, np.ndarray]: + def adapt_pt_inputs_to_onnx(self, pt_inputs) -> Mapping[str, npt.NDArray]: pt_inputs = self.format_pt_inputs(pt_inputs) return { ort_input.name: pt_input.cpu().numpy() for ort_input, pt_input in zip(self.onnx_session.get_inputs(), pt_inputs) } - def adapt_onnx_outputs_to_pt(self, onnx_outputs: List[np.ndarray]) -> Any: + def adapt_onnx_outputs_to_pt(self, onnx_outputs: List[npt.NDArray]) -> Any: pt_outputs = [ torch.from_numpy(onnx_output).to(current_device) for onnx_output in onnx_outputs diff --git a/torch/ao/quantization/experimental/linear.py b/torch/ao/quantization/experimental/linear.py index 34b0ca8e3921db..0093550472e0ca 100644 --- a/torch/ao/quantization/experimental/linear.py +++ b/torch/ao/quantization/experimental/linear.py @@ -1,5 +1,6 @@ # mypy: allow-untyped-defs import numpy as np +import numpy.typing as npt import torch from torch.ao.nn.quantized.modules.utils import WeightedQuantizedModule @@ -148,7 +149,7 @@ def forward(self, activation: torch.Tensor) -> torch.FloatTensor: weight_rows = self.weight_transposed.size()[0] weight_cols = self.weight_transposed.size()[1] - decomposed_weight: np.ndarray = np.empty( + decomposed_weight: npt.NDArray = np.empty( shape=(weight_rows, weight_cols), dtype=object ) for row in range(weight_rows): diff --git a/torch/onnx/_internal/exporter/_core.py b/torch/onnx/_internal/exporter/_core.py index 7d49a654a9c00c..09fae0ad2b88ee 100644 --- a/torch/onnx/_internal/exporter/_core.py +++ b/torch/onnx/_internal/exporter/_core.py @@ -42,7 +42,7 @@ if typing.TYPE_CHECKING: import os - import numpy as np + import numpy.typing as npt # Define utilities to convert PyTorch data types so users do not need to specify manually @@ -100,7 +100,7 @@ def __init__(self, tensor: torch.Tensor, name: str | None = None): tensor, dtype=_torch_dtype_to_onnx_dtype(tensor.dtype), name=name ) - def numpy(self) -> np.ndarray: + def numpy(self) -> npt.NDArray: self.raw: torch.Tensor if self.dtype == ir.DataType.BFLOAT16: return self.raw.view(torch.uint16).numpy(force=True) @@ -114,7 +114,7 @@ def numpy(self) -> np.ndarray: return self.raw.view(torch.uint8).numpy(force=True) return self.raw.numpy(force=True) - def __array__(self, dtype: Any = None, copy: bool | None = None) -> np.ndarray: + def __array__(self, dtype: Any = None, copy: bool | None = None) -> npt.NDArray: del copy # Unused, but needed for the signature if dtype is None: return self.numpy() diff --git a/torch/onnx/verification.py b/torch/onnx/verification.py index a21f1ffbba7783..f489252f5a7b27 100644 --- a/torch/onnx/verification.py +++ b/torch/onnx/verification.py @@ -21,6 +21,7 @@ from typing import Any, Callable, Collection, Mapping, Sequence, Tuple, Union import numpy as np +import numpy.typing as npt import torch import torch._C._onnx as _C_onnx @@ -98,7 +99,7 @@ def _flatten_tuples(elem): # TODO(justinchuby): Add type checking by narrowing down the return type when input is None -def _to_numpy(elem) -> list | np.ndarray: +def _to_numpy(elem) -> list | npt.NDArray: if isinstance(elem, torch.Tensor): if elem.requires_grad: return elem.detach().cpu().numpy() diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py index 34ab66cd8077a5..5f23ec4475544f 100644 --- a/torch/testing/_internal/common_methods_invocations.py +++ b/torch/testing/_internal/common_methods_invocations.py @@ -13,6 +13,7 @@ import torch import numpy as np +import numpy.typing as npt from torch import inf, nan from typing import Any, Dict, List, Tuple, Union, Sequence @@ -11342,7 +11343,7 @@ def _tanh_gelu_ref(X): return _gelu_ref(X) -def reference_one_hot(a: np.ndarray, num_classes: int = -1) -> np.ndarray: +def reference_one_hot(a: npt.NDArray, num_classes: int = -1) -> npt.NDArray: if num_classes == -1: num_classes = int(np.amax(a) + 1) @@ -11362,11 +11363,11 @@ def reference_mse_loss(input, target, reduction="mean"): return se -def reference_layer_norm(inp: np.ndarray, normalized_shape: Tuple[int], weight=None, bias=None, eps=1e-5): +def reference_layer_norm(inp: npt.NDArray, normalized_shape: Tuple[int], weight=None, bias=None, eps=1e-5): return reference_native_layer_norm(inp, normalized_shape, weight, bias, eps)[0] -def reference_native_layer_norm(inp: np.ndarray, normalized_shape: Tuple[int], weight, bias, eps): +def reference_native_layer_norm(inp: npt.NDArray, normalized_shape: Tuple[int], weight, bias, eps): feature_size = np.prod(normalized_shape) inp_view = inp.reshape(-1, feature_size) # type: ignore[call-overload] mean = inp_view.mean(axis=-1, keepdims=True) @@ -11383,7 +11384,7 @@ def reference_native_layer_norm(inp: np.ndarray, normalized_shape: Tuple[int], w return Y.reshape(*inp.shape), mean.reshape(stat_shape), (1.0 / np.sqrt(var + eps)).reshape(stat_shape) -def reference_rms_norm(inp: np.ndarray, normalized_shape: Tuple[int], weight=None, eps=None): +def reference_rms_norm(inp: npt.NDArray, normalized_shape: Tuple[int], weight=None, eps=None): if eps is None: eps = torch.finfo(numpy_to_torch_dtype(inp.dtype)).eps feature_size = np.prod(normalized_shape) @@ -11395,7 +11396,7 @@ def reference_rms_norm(inp: np.ndarray, normalized_shape: Tuple[int], weight=Non return Y.reshape(*inp.shape) -def reference_group_norm(inp: np.ndarray, num_groups: int, weight=None, bias=None, eps=1e-5): +def reference_group_norm(inp: npt.NDArray, num_groups: int, weight=None, bias=None, eps=1e-5): inp_view = inp if np.prod(inp.shape) != 0: inp_view = inp.reshape((inp.shape[0], num_groups, -1)) @@ -11481,7 +11482,7 @@ def reference_std_var(f): g = reference_reduction_numpy(f) @wraps(g) - def wrapper(x: np.ndarray, *args, **kwargs): + def wrapper(x: npt.NDArray, *args, **kwargs): assert not ('unbiased' in kwargs and 'correction' in kwargs) if 'unbiased' in kwargs: diff --git a/torch/testing/_internal/opinfo/utils.py b/torch/testing/_internal/opinfo/utils.py index 41973dc2c0518a..05468e10da2c90 100644 --- a/torch/testing/_internal/opinfo/utils.py +++ b/torch/testing/_internal/opinfo/utils.py @@ -6,6 +6,7 @@ from typing import Sequence import numpy as np +import numpy.typing as npt import torch from torch.testing._internal.common_cuda import TEST_CUDA @@ -206,7 +207,7 @@ def reference_reduction_numpy(f, supports_keepdims=True): """ @wraps(f) - def wrapper(x: np.ndarray, *args, **kwargs): + def wrapper(x: npt.NDArray, *args, **kwargs): # Copy keys into a set keys = set(kwargs.keys()) diff --git a/torch/utils/tensorboard/_utils.py b/torch/utils/tensorboard/_utils.py index 30984cfadf17fc..8acaf1696cb1f2 100644 --- a/torch/utils/tensorboard/_utils.py +++ b/torch/utils/tensorboard/_utils.py @@ -1,5 +1,6 @@ # mypy: allow-untyped-defs import numpy as np +import numpy.typing as npt # Functions for converting @@ -21,7 +22,7 @@ def figure_to_image(figures, close=True): def render_to_rgb(figure): canvas = plt_backend_agg.FigureCanvasAgg(figure) canvas.draw() - data: np.ndarray = np.frombuffer(canvas.buffer_rgba(), dtype=np.uint8) + data: npt.NDArray = np.frombuffer(canvas.buffer_rgba(), dtype=np.uint8) w, h = figure.canvas.get_width_height() image_hwc = data.reshape([h, w, 4])[:, :, 0:3] image_chw = np.moveaxis(image_hwc, source=2, destination=0) From f0c3fa32ab069c125ab102b6d027d56c5c35c23b Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Thu, 19 Sep 2024 12:44:54 +0000 Subject: [PATCH 1006/1018] Revert "Reland D62220158 (#136213)" This reverts commit 083c9149b75cd918b6fb2795050d7173923a3629. Reverted https://github.com/pytorch/pytorch/pull/136213 on behalf of https://github.com/jeanschmidt due to Seems to have introduced regressions in rocm signals ([comment](https://github.com/pytorch/pytorch/pull/136213#issuecomment-2360885064)) --- test/inductor/test_pad_mm.py | 35 -------------------------- torch/_inductor/fx_passes/pad_mm.py | 23 ----------------- torch/_inductor/fx_passes/split_cat.py | 1 - 3 files changed, 59 deletions(-) diff --git a/test/inductor/test_pad_mm.py b/test/inductor/test_pad_mm.py index 0b34e174cc6688..dbe1170994fdbe 100644 --- a/test/inductor/test_pad_mm.py +++ b/test/inductor/test_pad_mm.py @@ -9,7 +9,6 @@ get_pad_cache, get_padded_length, should_pad_common, - should_pad_mm_bf16, ) from torch._inductor.test_case import run_tests, TestCase from torch._inductor.utils import fresh_inductor_cache, is_big_gpu, run_and_get_code @@ -450,40 +449,6 @@ def mm(inps, b): repr(get_pad_cache().get_local_cache()) ) - @fresh_inductor_cache() - @inductor_config.patch( - post_grad_fusion_options={"pad_aten_mm_pass": {"k_threshold_to_pad": 8388608}} - ) - def test_pad_mm_bf16(self): - m = 2 - n = 13 - k = 15691904 - mat1 = torch.ones((m, k), device="cuda", dtype=torch.bfloat16) - mat2 = torch.ones((k, n), device="cuda", dtype=torch.bfloat16) - expected_alignment = get_alignment_size(mat1) - - assert expected_alignment == 8, "Alignment for bfloat16 should be 8" - assert should_pad_common( - mat1, mat2 - ), "This should pass the common padding criteria" - if torch.cuda.get_device_capability() < (9, 0): - assert should_pad_mm_bf16( - mat1.dtype, m, n, k - ), "This should pass the should_pad_mm_bf16 padding criteria" - - @torch.compile() - def mm(mat1, mat2): - return torch.mm(mat1, mat2) - - res2, (code,) = run_and_get_code(mm, mat1, mat2) - mm_expected_result = torch.mm(mat1, mat2) - # in call code, expect to see a single pad per input, and then we should see padded allocation for output - FileCheck().check("del async_compile").check_count( - ".run(", 2, exactly=True - ).check("empty_strided_cuda((8, 16)").run(code) - - assert torch.allclose(res2, mm_expected_result), "MM results are not identical" - if __name__ == "__main__": if HAS_CUDA: diff --git a/torch/_inductor/fx_passes/pad_mm.py b/torch/_inductor/fx_passes/pad_mm.py index a00bad39747918..87450b34e7ee27 100644 --- a/torch/_inductor/fx_passes/pad_mm.py +++ b/torch/_inductor/fx_passes/pad_mm.py @@ -364,23 +364,6 @@ def should_pad(key: str, ori_time, pad_time) -> bool: return should_pad -def should_pad_mm_bf16(dtype, M, N, K): - # always force pad for mm with bf16 when the following are satisfied to avoid perf regression - large_k_threshold_to_pad = torch._inductor.config.post_grad_fusion_options[ - "pad_aten_mm_pass" - ].get("k_threshold_to_pad", 8388608) - if ( - dtype is torch.bfloat16 - and K > M - and K > N - and N % 2 == 1 - and K >= large_k_threshold_to_pad - and torch.cuda.get_device_capability() < (9, 0) - ): # doesnt repro on h100s: - return True - return False - - def should_pad_bench( match, mat1: Tensor, mat2: Tensor, op, input: Optional[Tensor] = None ) -> bool: @@ -427,12 +410,6 @@ def realize_symbols(ds): if torch._inductor.config.force_shape_pad: return True - if ( - "pad_aten_mm_pass" in torch._inductor.config.post_grad_fusion_options - and should_pad_mm_bf16(mat1.dtype, m, n, k) - ): - return True - if not has_triton(): return False diff --git a/torch/_inductor/fx_passes/split_cat.py b/torch/_inductor/fx_passes/split_cat.py index 194d1d6dbaa798..f850ecf6008c94 100644 --- a/torch/_inductor/fx_passes/split_cat.py +++ b/torch/_inductor/fx_passes/split_cat.py @@ -65,7 +65,6 @@ "decompose_mm_pass", "unbind_stack_aten_pass", "shape_padding_multiplier", - "pad_aten_mm_pass", ] for pass_name in pre_grad_pass_names: From e6609f8d87dc1b4fa6dc27358d27f7b05d181587 Mon Sep 17 00:00:00 2001 From: Andrew Gu Date: Tue, 17 Sep 2024 23:00:11 -0700 Subject: [PATCH 1007/1018] [FSDP2] Fixed 2D mismatched grad placements (#136237) ``` CUDA_VISIBLE_DEVICES=2,3,6,7 pytest test/distributed/_composable/test_composability/test_2d_composability.py -k test_train_parity_2d_transformer ``` Differential Revision: [D62964658](https://our.internmc.facebook.com/intern/diff/D62964658) Pull Request resolved: https://github.com/pytorch/pytorch/pull/136237 Approved by: https://github.com/weifengpy --- .../test_2d_composability.py | 62 +++++++++++++++++++ .../_composable/fsdp/_fsdp_param.py | 9 +-- 2 files changed, 67 insertions(+), 4 deletions(-) diff --git a/test/distributed/_composable/test_composability/test_2d_composability.py b/test/distributed/_composable/test_composability/test_2d_composability.py index c8d2ac4d4dc257..83b0f8f2b5ac62 100644 --- a/test/distributed/_composable/test_composability/test_2d_composability.py +++ b/test/distributed/_composable/test_composability/test_2d_composability.py @@ -171,6 +171,68 @@ def _test_train_parity_2d_mlp( _optim.step() self.assertEqual(losses[0], losses[1]) + @skip_if_lt_x_gpu(2) + @skipIfRocm + def test_train_parity_2d_transformer(self): + torch.manual_seed(42) + model_args = ModelArgs(n_layers=3, dropout_p=0.0) + model = Transformer(model_args) + ref_model = copy.deepcopy(model).cuda() + ref_optim = torch.optim.AdamW(ref_model.parameters(), lr=1e-2) + + dp_size, tp_size = self.world_size // 2, 2 + global_mesh = init_device_mesh( + "cuda", (dp_size, tp_size), mesh_dim_names=("dp", "tp") + ) + model = Transformer.parallelize(model, global_mesh["tp"], use_seq_parallel=True) + + for layer in model.layers: + fully_shard(layer, mesh=global_mesh["dp"]) + fully_shard(model, mesh=global_mesh["dp"]) + optim = torch.optim.AdamW(model.parameters(), lr=1e-2) + + for param, ref_param in zip(model.parameters(), ref_model.parameters()): + full_param = param.full_tensor() + self.assertEqual(full_param, ref_param) + + torch.manual_seed(42 + global_mesh.get_local_rank("dp")) + inp = torch.randint(0, model_args.vocab_size, (2, 16), device="cuda") + for iter_idx in range(5): + ref_loss = ref_model(inp).sum() + loss = model(inp).sum() + self.assertEqual(ref_loss, loss) + ref_loss.backward() + loss.backward() + for param in ref_model.parameters(): + if param.grad is not None: + dist.all_reduce( + param.grad, + group=global_mesh.get_group("dp"), + op=dist.ReduceOp.AVG, + ) + + # Specially check the TP placement for `pos_embeddings.weight` and + # its which since the grad naturally has replicate placement, + # requiring FSDP to redistribute it to shard placement before FSDP + # runs its reduce-scatter + self.assertIsInstance(model.pos_embeddings.weight.placements[1], Shard) + self.assertIsInstance(model.pos_embeddings.weight.grad.placements[1], Shard) + for ref_param, (param_name, param) in zip( + ref_model.parameters(), model.named_parameters() + ): + full_grad = param.grad.full_tensor() + ref_grad = ref_param.grad + self.assertEqual(ref_param.grad, full_grad) + + ref_optim.step() + optim.step() + ref_optim.zero_grad() + optim.zero_grad() + + for param, ref_param in zip(model.parameters(), ref_model.parameters()): + full_param = param.full_tensor() + self.assertEqual(full_param, ref_param) + @skip_if_lt_x_gpu(2) @skipIfRocm def test_tp_with_fsdp_offloading(self): diff --git a/torch/distributed/_composable/fsdp/_fsdp_param.py b/torch/distributed/_composable/fsdp/_fsdp_param.py index 59c2ff7589d3a3..fef1c865b6b08b 100644 --- a/torch/distributed/_composable/fsdp/_fsdp_param.py +++ b/torch/distributed/_composable/fsdp/_fsdp_param.py @@ -695,10 +695,11 @@ def _get_grad_inner_tensor(self, grad: torch.Tensor) -> torch.Tensor: if isinstance(grad, AsyncCollectiveTensor): grad = grad.wait() assert isinstance(grad, DTensor), f"{type(grad)}" - if any(pl.is_partial() for pl in grad.placements): - placements = [ - Replicate() if pl.is_partial() else pl for pl in grad.placements - ] + placements = self._tp_spec.placements + if placements != grad.placements: + assert len(self._tp_spec.placements) == len( + grad.placements + ), f"{self._tp_spec=} {grad.placements=}" grad = grad.redistribute(placements=placements) grad = grad._local_tensor return grad From ee071e023a079f9463bc2d4742b4a78406ef0a0a Mon Sep 17 00:00:00 2001 From: James Wu Date: Thu, 19 Sep 2024 16:11:38 +0000 Subject: [PATCH 1008/1018] Log structured logging overhead to dynamo compile (kinda) (#136142) Summary: X-link: https://github.com/pytorch/benchmark/pull/2454 This adds structured logging overhead at a per compile basis to compilation metrics. To do so, we track the frame_id_frame_compile_id that trace_structured uses to categorize compiles, and use that as the key in our timing table. Implementation notes: - If there's times we call trace_structured without a compile id, the time won't be measured. Not really a good way around that today given the compile id framework of compilation metrics. Strobelight is still the best way to measure on a per job basis. - We don't actually measure the time it takes to log the compilation metrics itself. Fundamentally, it's not possible to log this properly if we're storing the logging number *in* compilation metrics, since there's no way to measure it before we do it(unless we want discrepancies between dynamo_compile and tlparse, which seems suboptimal). Hopefully for a large job, the cost of structured_logging compilation metrics itself is small. - I wanted to use frame_phase_timing here, but there's a bunch of ids to iron out, and I don't really want to deal with that headache. compilation_time_metrics is sort of what I want, but that isn't by frame/compile id, so it's also a bit off. Putting it into torch.logging as a separate thing so logging tracks its own overhead seems fine, though. Test Plan: Run benchmarks/nanogpt and staging logger. See that the new compilation metric is logged to the staged dynamo_compile table: https://fburl.com/scuba/logger_staging_jjwu_30582a48f1ff9cf5f4ac50a4c40af/xazjg5xq Note that the sum(structured_logging_overhead_s) / sum(entire_frame_compile_time) = 8.387 / 124.278 = 6%, which seems reasonable as the overhead for a small compilation like this. You can also look at samples for a more detailed log of this. Reviewed By: oulgen Differential Revision: D62643611 Pull Request resolved: https://github.com/pytorch/pytorch/pull/136142 Approved by: https://github.com/bobrenjc93 --- torch/_dynamo/convert_frame.py | 5 ++++ torch/_dynamo/utils.py | 11 +++++++++ torch/_logging/__init__.py | 1 + torch/_logging/_internal.py | 45 ++++++++++++++++++++++++++++++++++ 4 files changed, 62 insertions(+) diff --git a/torch/_dynamo/convert_frame.py b/torch/_dynamo/convert_frame.py index 171af02e564b36..1b71c42b9ac5a7 100644 --- a/torch/_dynamo/convert_frame.py +++ b/torch/_dynamo/convert_frame.py @@ -1028,6 +1028,10 @@ def format_guard_failures() -> str: possibly_missed_reinplacing_opportunities = None remote_cache_time_saved = None + structured_logging_overhead_s = ( + torch._logging.get_structured_logging_overhead() + ) + metrics = CompilationMetrics( str(compile_id), frame_key, @@ -1057,6 +1061,7 @@ def format_guard_failures() -> str: guarded_code is not None, possibly_missed_reinplacing_opportunities, remote_cache_time_saved, + structured_logging_overhead_s, ) record_compilation_metrics(metrics) torch._dynamo.callback_handler.run_end_callbacks() diff --git a/torch/_dynamo/utils.py b/torch/_dynamo/utils.py index e086f11e174cd7..7d34671b11a3c0 100644 --- a/torch/_dynamo/utils.py +++ b/torch/_dynamo/utils.py @@ -353,6 +353,9 @@ def dynamo_timed( inductor_compile_time = None code_gen_time = None remote_cache_time_saved = None + structured_logging_overhead_s = ( + torch._logging.get_structured_logging_overhead() + ) metrics = BwdCompilationMetrics( compile_id, inductor_compile_time, @@ -360,6 +363,7 @@ def dynamo_timed( fail_type, fail_reason, remote_cache_time_saved, + structured_logging_overhead_s, ) record_compilation_metrics(metrics) @@ -799,6 +803,7 @@ class CompilationMetrics: has_guarded_code: bool possibly_missed_reinplacing_opportunities: Optional[int] remote_cache_time_saved_s: Optional[float] + structured_logging_overhead_s: Optional[float] @dataclasses.dataclass @@ -809,6 +814,7 @@ class BwdCompilationMetrics: fail_type: Optional[str] fail_reason: Optional[str] remote_cache_time_saved_s: Optional[float] + structured_logging_overhead_s: Optional[float] DEFAULT_COMPILATION_METRICS_LIMIT = 64 @@ -834,6 +840,11 @@ def record_compilation_metrics( k: list(v) if isinstance(v, set) else v for k, v in dataclasses.asdict(compilation_metrics).items() }, + # NB: Because compilation metrics *includes* the logging overhead time, + # we can't both *measure* the logging overhead of compilation metrics + # without making it inconsistent with compilation metrics itself, so + # we ignore the (hopefully small) time spent logging compilation metrics + record_logging_overhead=False, ) if config.log_compilation_metrics: log_compilation_event(compilation_metrics) diff --git a/torch/_logging/__init__.py b/torch/_logging/__init__.py index 0531869ae2fdf3..5acf175c275222 100644 --- a/torch/_logging/__init__.py +++ b/torch/_logging/__init__.py @@ -9,6 +9,7 @@ from ._internal import ( _init_logs, DEFAULT_LOGGING, + get_structured_logging_overhead, getArtifactLogger, LazyString, set_logs, diff --git a/torch/_logging/_internal.py b/torch/_logging/_internal.py index f11af08e3b768d..f78396545c1f75 100644 --- a/torch/_logging/_internal.py +++ b/torch/_logging/_internal.py @@ -10,6 +10,8 @@ import re import sys import tempfile +import time +from collections import defaultdict from dataclasses import dataclass, field from importlib import __import__ from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union @@ -1091,6 +1093,42 @@ def __str__(self): return self.func(*self.args, **self.kwargs) +# Logs the time it takes to do structured logging by frame/compile id +# key is always {frame_id}_{frame_compile_id} +structured_logging_overhead: Dict[str, float] = defaultdict(float) + + +# Same principle as add_remote_cache_time_saved, but do it for structured logging +def add_structured_logging_overhead(time_spent: float) -> None: + global structured_logging_overhead + key = None + if (trace_id := torch._guards.CompileContext.current_trace_id()) is not None: + frame_id = trace_id.compile_id.frame_id + frame_compile_id = trace_id.compile_id.frame_compile_id + # Why not trace_id.attempt, like structured logging? + # We aggregate across all attempts because + # a compilation metric is logged per successful attempt + key = f"{frame_id}_{frame_compile_id}" + # TODO: deal with structured logging that occurs outside of specific compile ids + # It's hard to figure out where we would log that if we want it in compilation metrics + # itself. + if key is not None: + key = str(key) + structured_logging_overhead[key] += time_spent + + +def get_structured_logging_overhead() -> Optional[float]: + key = None + if (trace_id := torch._guards.CompileContext.current_trace_id()) is not None: + frame_id = trace_id.compile_id.frame_id + frame_compile_id = trace_id.compile_id.frame_compile_id + key = f"{frame_id}_{frame_compile_id}" + if key is not None: + return structured_logging_overhead.get(key) + else: + return None + + def trace_structured( name: str, # NB: metadata expected to be dict so adding more info is forward compatible @@ -1100,6 +1138,7 @@ def trace_structured( payload_fn: Callable[[], Optional[Union[str, object]]] = lambda: None, suppress_context: bool = False, expect_trace_id: bool = True, # Whether or not we expect to have a current trace id + record_logging_overhead: bool = True, # Whether or not to record the time spent on structured logging ): """ metadata is an arbitrary JSON compatible struct, but it's expected to not be @@ -1118,6 +1157,7 @@ def trace_structured( # trace_log never propagates and is ALWAYS DEBUG, so also check that there # are handlers instead of checking the log level if trace_log.handlers: + start_time = time.time_ns() record: Dict[str, object] = {} record[name] = metadata_fn() if not suppress_context: @@ -1156,6 +1196,11 @@ def trace_structured( ) log_trace_structured_event(name, record) + if record_logging_overhead: + # Convert to seconds from nanoseconds, add it to the frame compile total + structured_logging_overhead_s = (time.time_ns() - start_time) / 1e9 + add_structured_logging_overhead(structured_logging_overhead_s) + import torch._guards import torch._utils_internal From 3d38f28edc5bbb52f9e6ded026ae757e639c1d37 Mon Sep 17 00:00:00 2001 From: Bob Ren Date: Thu, 19 Sep 2024 07:40:29 -0700 Subject: [PATCH 1009/1018] Type _sympy/functions.py [1/n] (#136205) Signed-off-by: Bob Ren I was chatting with @jamesjwu about strategies to learn the code and he suggested adding types to some files. This stack of PRs adds types to _sympy/functions.py Pull Request resolved: https://github.com/pytorch/pytorch/pull/136205 Approved by: https://github.com/Skylion007, https://github.com/jamesjwu --- torch/utils/_sympy/functions.py | 99 +++++++++++++++++++++------------ 1 file changed, 63 insertions(+), 36 deletions(-) diff --git a/torch/utils/_sympy/functions.py b/torch/utils/_sympy/functions.py index ed597402d8b637..493352798eab8c 100644 --- a/torch/utils/_sympy/functions.py +++ b/torch/utils/_sympy/functions.py @@ -3,6 +3,17 @@ import math import operator import sys +from typing import ( + Any, + Callable, + Iterable, + List, + Optional, + SupportsFloat, + Tuple, + TypeVar, + Union, +) import sympy from sympy import S @@ -19,6 +30,8 @@ from .numbers import int_oo +_T = TypeVar("_T", bound=SupportsFloat) + # Portions of this file are adapted from the Sympy codebase, which was # licensed as follows: # @@ -76,9 +89,9 @@ ] -def _keep_float(f): +def _keep_float(f: Callable[..., _T]) -> Callable[..., sympy.Float]: @functools.wraps(f) - def inner(*args): + def inner(*args: Any) -> Union[_T, sympy.Float]: r = f(*args) if any(isinstance(a, sympy.Float) for a in args) and not isinstance( r, sympy.Float @@ -89,13 +102,13 @@ def inner(*args): return inner -def fuzzy_eq(x, y): +def fuzzy_eq(x: Optional[bool], y: Optional[bool]) -> Optional[bool]: if None in (x, y): return None return x == y -def simple_floordiv_gcd(p, q): +def simple_floordiv_gcd(p: sympy.Basic, q: sympy.Basic) -> sympy.Basic: """ Fast path for sympy.gcd, using a simple factoring strategy. @@ -112,23 +125,27 @@ def simple_floordiv_gcd(p, q): might be necessary to fall back on sympy.gcd. """ - def integer_coefficient(x): - integer_coefficients = [ + def integer_coefficient(x: sympy.Basic) -> int: + integer_coefficients: List[int] = [ abs(int(arg)) for arg in sympy.Mul.make_args(x) if isinstance(arg, (int, sympy.Integer)) ] return math.prod(integer_coefficients) - def integer_factor(expr): - integer_factors = map(integer_coefficient, sympy.Add.make_args(expr)) + def integer_factor(expr: sympy.Basic) -> int: + integer_factors: Iterable[int] = map( + integer_coefficient, sympy.Add.make_args(expr) + ) return functools.reduce(math.gcd, integer_factors) - gcd = math.gcd(integer_factor(p), integer_factor(q)) + gcd: int = math.gcd(integer_factor(p), integer_factor(q)) p, q = p / gcd, q / gcd - base_splits = list(map(sympy.Mul.make_args, sympy.Add.make_args(p))) - divisor_split = sympy.Mul.make_args(q) + base_splits: List[Tuple[sympy.Basic, ...]] = list( + map(sympy.Mul.make_args, sympy.Add.make_args(p)) + ) + divisor_split: Tuple[sympy.Basic, ...] = sympy.Mul.make_args(q) for x in divisor_split: if all(x in base_split for base_split in base_splits): gcd = gcd * x @@ -162,20 +179,19 @@ class FloorDiv(sympy.Function): NB: This is Python-style floor division, round to -Inf """ - nargs = (2,) - precedence = 50 # precedence of mul # noqa: F811 - - is_integer = True + nargs: Tuple[int, ...] = (2,) + precedence: int = 50 # precedence of mul # noqa: F811 + is_integer: bool = True @property - def base(self): + def base(self) -> sympy.Basic: return self.args[0] @property - def divisor(self): + def divisor(self) -> sympy.Basic: return self.args[1] - def _sympystr(self, printer): + def _sympystr(self, printer: sympy.printing.printer.Printer) -> str: base = printer.parenthesize(self.base, self.precedence) divisor = printer.parenthesize(self.divisor, self.precedence) return f"({base}//{divisor})" @@ -183,7 +199,7 @@ def _sympystr(self, printer): # Automatic evaluation. # https://docs.sympy.org/latest/guides/custom-functions.html#best-practices-for-eval @classmethod - def eval(cls, base, divisor): + def eval(cls, base: sympy.Basic, divisor: sympy.Basic) -> Union[sympy.Basic, None]: # python test/test_dynamic_shapes.py -k TestDimConstraints.test_dim_constraints_solve_full # Assert triggered by inequality solver # assert base.is_integer, base @@ -252,17 +268,21 @@ def eval(cls, base, divisor): except sympy.PolynomialError: pass # https://github.com/pytorch/pytorch/issues/108276 + return None + class ModularIndexing(sympy.Function): """ ModularIndexing(a, b, c) => (a // b) % c where % is the C modulus """ - nargs = (3,) - is_integer = True + nargs: Tuple[int, ...] = (3,) + is_integer: bool = True @classmethod - def eval(cls, base, divisor, modulus): + def eval( + cls, base: sympy.Basic, divisor: sympy.Basic, modulus: sympy.Basic + ) -> Optional[sympy.Basic]: if base == 0 or modulus == 1: return sympy.Integer(0) @@ -286,8 +306,8 @@ def eval(cls, base, divisor, modulus): pass # https://github.com/pytorch/pytorch/issues/108276 if isinstance(base, sympy.Add): - new_terms = [] - all_positive = True + new_terms: List[sympy.Basic] = [] + all_positive: bool = True for term in base.args: if sympy.gcd(term, modulus * divisor) != modulus * divisor: if (isinstance(term, sympy.Integer) and term < 0) or ( @@ -310,11 +330,13 @@ def eval(cls, base, divisor, modulus): if isinstance(base, FloorDiv): return ModularIndexing(base.args[0], base.args[1] * divisor, modulus) - def _eval_is_nonnegative(self): # type:ignore[override] + return None + + def _eval_is_nonnegative(self) -> Optional[bool]: p, q = self.args[:2] return fuzzy_eq(p.is_nonnegative, q.is_nonnegative) # type: ignore[attr-defined] - def _eval_is_positive(self): # type:ignore[override] + def _eval_is_positive(self) -> Optional[bool]: p, q = self.args[:2] return fuzzy_eq(p.is_positive, q.is_positive) # type: ignore[attr-defined] @@ -324,37 +346,40 @@ class Where(sympy.Function): Good ol' ternary operator """ - nargs = (3,) + nargs: Tuple[int, ...] = (3,) - def _eval_is_integer(self): + def _eval_is_integer(self) -> Optional[bool]: return True if self.args[1].is_integer and self.args[2].is_integer else None # type: ignore[attr-defined] - def _eval_is_nonnegative(self): # type:ignore[override] + def _eval_is_nonnegative(self) -> Optional[bool]: return ( True if self.args[1].is_nonnegative and self.args[2].is_nonnegative # type: ignore[attr-defined] else None ) - def _eval_is_positive(self): # type:ignore[override] + def _eval_is_positive(self) -> Optional[bool]: return True if self.args[1].is_positive and self.args[2].is_positive else None # type: ignore[attr-defined] @classmethod - def eval(cls, c, p, q): + def eval( + cls, c: sympy.Basic, p: sympy.Basic, q: sympy.Basic + ) -> Optional[sympy.Basic]: if c == sympy.true: return p elif c == sympy.false: return q + return None # Python-style modulus: take sign from RHS class PythonMod(sympy.Function): - nargs = (2,) + nargs: Tuple[int, ...] = (2,) - is_integer = True + is_integer: bool = True @classmethod - def eval(cls, p, q): + def eval(cls, p: sympy.Expr, q: sympy.Expr) -> Optional[sympy.Expr]: # python test/dynamo/test_export.py -k ExportTests.test_trivial_constraint # Triggered by sympy.solvers.inequalities.reduce_inequalities # assert p.is_integer, p @@ -396,11 +421,13 @@ def eval(cls, p, q): if sympy.Mod(p, q) == 0: return S.Zero + return None + # NB: args[1] for PythonMod - def _eval_is_nonnegative(self): # type:ignore[override] + def _eval_is_nonnegative(self) -> Optional[bool]: return True if self.args[1].is_positive else None # type: ignore[attr-defined] - def _eval_is_nonpositive(self): + def _eval_is_nonpositive(self) -> Optional[bool]: return True if self.args[1].is_negative else None # type: ignore[attr-defined] From f37b89568df34d2beb0e84bf71308d852a104691 Mon Sep 17 00:00:00 2001 From: Jerry Mannil Date: Thu, 19 Sep 2024 18:02:39 +0000 Subject: [PATCH 1010/1018] [ROCm] Enable Flex attention tests on AMD gpus (#136245) Pull Request resolved: https://github.com/pytorch/pytorch/pull/136245 Approved by: https://github.com/malfet --- test/inductor/test_flex_attention.py | 19 ++++++++----------- test/inductor/test_flex_decoding.py | 6 ++---- 2 files changed, 10 insertions(+), 15 deletions(-) diff --git a/test/inductor/test_flex_attention.py b/test/inductor/test_flex_attention.py index 5556d3ee4edb60..1cb6354275969e 100644 --- a/test/inductor/test_flex_attention.py +++ b/test/inductor/test_flex_attention.py @@ -30,26 +30,24 @@ from torch.testing import FileCheck from torch.testing._internal import common_utils from torch.testing._internal.common_cuda import PLATFORM_SUPPORTS_BF16, TEST_MULTIGPU -from torch.testing._internal.common_utils import skipIfRocm, TEST_WITH_ROCM +from torch.testing._internal.common_utils import TEST_WITH_ROCM from torch.utils._triton import has_triton # Skip tests if Triton is not available supported_platform = skipUnless( torch.cuda.is_available() - and torch.version.hip is None and has_triton() and torch.cuda.get_device_capability() >= (8, 0), "Requires CUDA and Triton", ) # Use this decorator only when hitting Triton bugs on H100 -running_on_a100_only = skipUnless( +running_on_a100_or_rocm_only = skipUnless( torch.cuda.is_available() - and torch.version.hip is None and has_triton() - and torch.cuda.get_device_capability() == (8, 0), - "Requires A100 and Triton", + and (torch.cuda.get_device_capability() == (8, 0) or torch.version.hip is not None), + "Requires (A100 or ROCm) and Triton", ) Tolerances = namedtuple("Tolerances", ["atol", "rtol"]) @@ -314,8 +312,6 @@ def run_test( V_D: int = D, block_mask: Optional[BlockMask] = None, ): - if TEST_WITH_ROCM and Q_H != KV_H: - self.skipTest("enable_gqa=True is unsupported on ROCM, for now") q = torch.randn( (Q_B, Q_H, Q_S, Q_D), dtype=dtype, device="cuda", requires_grad=True ) @@ -583,7 +579,7 @@ def run_automatic_dynamic_test( def test_builtin_score_mods(self, dtype: torch.dtype, score_mod: Callable): self.run_test(score_mod, dtype) - @running_on_a100_only + @running_on_a100_or_rocm_only @common_utils.parametrize("dtype", test_dtypes_fast) @common_utils.parametrize("score_mod", test_score_mods) def test_builtin_score_mods_seqlen_lt_default_sparse_block_size( @@ -597,7 +593,7 @@ def test_builtin_score_mods_seqlen_lt_default_sparse_block_size( ) self.run_test_with_call(attention, dtype, B, H, 64, D, B, H, 64, D) - @running_on_a100_only + @running_on_a100_or_rocm_only @common_utils.parametrize("dtype", test_dtypes_fast) @common_utils.parametrize("score_mod", test_score_mods) def test_builtin_score_mods_seqlen_lt_custom_sparse_block_size( @@ -1364,7 +1360,6 @@ def mask_mod(b, h, q, kv): self.run_test_with_call(attention) - @skipIfRocm @supported_platform def test_GQA_causal_mask(self): def mask_mod(b, h, q, kv): @@ -1876,6 +1871,8 @@ def test_force_write_lse(self): @common_utils.parametrize("backend", ["flex_attention", "flex_decode", "eager"]) def test_lse_masked_output(self, backend): if backend == "flex_decode": + if TEST_WITH_ROCM: + self.skipTest("backend=flex_decode is unsupported on ROCM, for now") kernel_options = {"FORCE_USE_FLEX_ATTENTION": False} flex_call = torch.compile(flex_attention, fullgraph=True) elif backend == "flex_attention": diff --git a/test/inductor/test_flex_decoding.py b/test/inductor/test_flex_decoding.py index 04ca1429bebf19..dde47f6a9a2673 100644 --- a/test/inductor/test_flex_decoding.py +++ b/test/inductor/test_flex_decoding.py @@ -21,7 +21,7 @@ from torch.testing import FileCheck from torch.testing._internal import common_utils from torch.testing._internal.common_cuda import PLATFORM_SUPPORTS_BF16 -from torch.testing._internal.common_utils import skipIfRocm, TEST_WITH_ROCM +from torch.testing._internal.common_utils import skipIfRocm from torch.utils._triton import has_triton @@ -278,8 +278,6 @@ def run_test( score_mod is not None or block_mask is not None ), "Must provide score_mod or block_mask" assert Q_H % KV_H == 0 - if TEST_WITH_ROCM and Q_H != KV_H: - self.skipTest("enable_gqa=True is unsupported on ROCM, for now") q = torch.randn( (Q_B, Q_H, Q_S, Q_D), dtype=dtype, @@ -570,6 +568,7 @@ def bias_mod(score, b, h, q, kv): self.run_test(bias_mod, dtype) + @skipIfRocm @supported_platform @common_utils.parametrize("dtype", test_dtypes_fast) def test_load_from_bias_head_seq_batch(self, dtype): @@ -825,7 +824,6 @@ def bias_mod(score, batch, head, token_q, token_kv): self.run_test(bias_mod) - @skipIfRocm @supported_platform def test_fully_masked_out_rows_0_check_gqa(self): # Ensure fully masked out rows won't cause NaNs. From 6252f21ad1303f9c1f58c80271976c0ae2c6a027 Mon Sep 17 00:00:00 2001 From: Shan19900305 Date: Thu, 19 Sep 2024 18:43:39 +0000 Subject: [PATCH 1011/1018] fix stride compare failed when size value equal to one in ForeachUtils.h (#134546) When size value equal to one, tensor strides value need be skipped to compare. @ezyang Pull Request resolved: https://github.com/pytorch/pytorch/pull/134546 Approved by: https://github.com/janeyx99 --- aten/src/ATen/native/ForeachUtils.h | 18 +++++++++++++++++- test/test_foreach.py | 18 ++++++++++++++++++ 2 files changed, 35 insertions(+), 1 deletion(-) diff --git a/aten/src/ATen/native/ForeachUtils.h b/aten/src/ATen/native/ForeachUtils.h index f5c0672402f361..a8fbe13b8da0b9 100644 --- a/aten/src/ATen/native/ForeachUtils.h +++ b/aten/src/ATen/native/ForeachUtils.h @@ -128,10 +128,26 @@ inline bool _check_tensors_share_device_and_dtype( // corresponding tensors in tensor lists have the same sizes and strides. inline bool _check_tensors_share_sizes_and_strides( ArrayRef tensorLists) { + auto is_diff_stride = [](const IntArrayRef& size, + const IntArrayRef& left_stride, + const IntArrayRef& right_stride) -> bool { + const size_t size_size = size.size(); + for (const auto dim : c10::irange(size_size)) { + if (size[dim] == 1) + continue; + if (left_stride[dim] != right_stride[dim]) { + return true; + } + } + return false; + }; for (const auto i : c10::irange(1, tensorLists.size())) { for (const auto j : c10::irange(tensorLists[0].size())) { if (tensorLists[0][j].sizes() != tensorLists[i][j].sizes() || - tensorLists[0][j].strides() != tensorLists[i][j].strides()) { + is_diff_stride( + tensorLists[0][j].sizes(), + tensorLists[0][j].strides(), + tensorLists[i][j].strides())) { return false; } } diff --git a/test/test_foreach.py b/test/test_foreach.py index d668cb30f24c8e..c91c88abcb4cbe 100644 --- a/test/test_foreach.py +++ b/test/test_foreach.py @@ -545,6 +545,24 @@ def test_add_scalar_with_empty_list_and_empty_tensor(self, device, dtype): # Regression test for https://github.com/pytorch/pytorch/issues/113156 torch._foreach_mul_(tensors, 1) + @onlyCUDA + @dtypes(torch.float32) + def test_foreach_check_stride_ignore_dims_of_one(self, device, dtype): + # default tensor stride is (9, 9, 3, 1). + tensor = torch.ones((2, 1, 3, 3), device=device, dtype=dtype) + strided_tensor = torch.ones( + (2, 1, 3, 3), device=device, dtype=dtype + ).as_strided((2, 1, 3, 3), (9, 1, 3, 1)) + left_inputs = [tensor, strided_tensor] + right_inputs = [strided_tensor, tensor] + compare_result = tensor + strided_tensor + foreach_add_check_ = ForeachFuncWrapper(torch._foreach_add) + out = foreach_add_check_( + (left_inputs, right_inputs), is_cuda=True, expect_fastpath=True + ) + for res in out: + self.assertEqual(res, compare_result) + @ops( filter(lambda op: op.supports_out, foreach_binary_op_db), dtypes=OpDTypes.supported, From 56873911d0fc432cf4a786ed6bcb214d0e9230ad Mon Sep 17 00:00:00 2001 From: Rachel Guo Date: Thu, 19 Sep 2024 18:51:57 +0000 Subject: [PATCH 1012/1018] [AOTI][Tooling][8/n] Add option to pinpoint kernel names in debug printer (#136182) Summary: Add a third mode where we only print kernel names without dumping any intermediate actual tensor value info. It can be helpful in quickly identifying the troublesome kernels in CUDA IMA issues. thanks ColinPeppler and henrylhtsang for this "feature request". Test Plan: The output can look like this if set the `AOT_INDUCTOR_DEBUG_INTERMEDIATE_VALUE_PRINTER=3`: {F1871629091} Differential Revision: D62791371 Pull Request resolved: https://github.com/pytorch/pytorch/pull/136182 Approved by: https://github.com/henrylhtsang --- torch/_inductor/codegen/debug_utils.py | 31 ++++++++++++++++++++++++-- torch/_inductor/config.py | 1 + 2 files changed, 30 insertions(+), 2 deletions(-) diff --git a/torch/_inductor/codegen/debug_utils.py b/torch/_inductor/codegen/debug_utils.py index 985678cadb285c..32d4d55e58545f 100644 --- a/torch/_inductor/codegen/debug_utils.py +++ b/torch/_inductor/codegen/debug_utils.py @@ -43,6 +43,9 @@ class IntermediateValueDebuggingLevel(Enum): SAVE_ONLY = "1" # LEVEL 2: Print all intermediate tensor values by default to the console. No debug saving will be performed. PRINT_ONLY = "2" + # LEVEL 3: Print all kernel names to the console only. No debug saving/printing for input tensor value info will be performed. + # This mode can be helpful in cases when you just want to pinpointing what kernel is running into a CUDA IMA issue, etc. + PRINT_KERNEL_NAMES_ONLY = "3" class DebugPrinterManager: @@ -104,6 +107,16 @@ def _perform_debug_print_or_save_helper( before_launch, arg_signatures=self.arg_signatures, ) + if ( + self.debug_printer_level + == IntermediateValueDebuggingLevel.PRINT_KERNEL_NAMES_ONLY + ): + # Print all kernel names to the console only + self.codegen_intermediate_tensor_value_print( + [], + self.kernel_name, + before_launch, + ) @functools.lru_cache # noqa: B019 def _get_debug_filtered_kernel_names(self) -> List[str]: @@ -191,6 +204,22 @@ def codegen_intermediate_tensor_value_print( before_launch=True, arg_signatures: Optional[List[type]] = None, ) -> None: + launch_prefix = "before_launch" if before_launch else "after_launch" + + # if the debug printing level is PRINT_KERNEL_NAMES_ONLY + # we only print the kernel name to the console + if ( + self.debug_printer_level + == IntermediateValueDebuggingLevel.PRINT_KERNEL_NAMES_ONLY + ): + if V.graph.cpp_wrapper: + if config.abi_compatible: + V.graph.wrapper_code.writeline( + f'printf("[ {launch_prefix}: {kernel_name} ]");' + ) + V.graph.wrapper_code.writeline('printf("\\n");') + return + for i, arg in enumerate(args_to_print): if arg_signatures is not None and not isinstance( arg_signatures[i], torch_dtype @@ -206,8 +235,6 @@ def codegen_intermediate_tensor_value_print( ): continue - launch_prefix = "before_launch" if before_launch else "after_launch" - if V.graph.cpp_wrapper: if config.abi_compatible: V.graph.wrapper_code.writeline( f'aoti_torch_print_tensor_handle({arg}, "{launch_prefix} - {kernel_name} - {arg}");' diff --git a/torch/_inductor/config.py b/torch/_inductor/config.py index 04595ba4f0c738..dd35755f966311 100644 --- a/torch/_inductor/config.py +++ b/torch/_inductor/config.py @@ -990,6 +990,7 @@ class aot_inductor: # 0: disable debug dumping # 1: enable saving intermediate tensor values # 2: enable printing intermediate tensor values + # 3: enable printing kernel names only (useful for pinpointing troublesome kernels) debug_intermediate_value_printer = os.environ.get( "AOT_INDUCTOR_DEBUG_INTERMEDIATE_VALUE_PRINTER", "0" ) From dd3c89f721ff3f59bfde811dff877a610ebe9b7c Mon Sep 17 00:00:00 2001 From: Laith Sakka Date: Tue, 17 Sep 2024 11:54:08 -0700 Subject: [PATCH 1013/1018] add basic_modules_ListOfLinears_inductor_gpu_force_shape_pad (#136175) Pull Request resolved: https://github.com/pytorch/pytorch/pull/136175 Approved by: https://github.com/ezyang --- .../benchmarks/basic_modules_benchmarks.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/benchmarks/dynamo/pr_time_benchmarks/benchmarks/basic_modules_benchmarks.py b/benchmarks/dynamo/pr_time_benchmarks/benchmarks/basic_modules_benchmarks.py index 65588cb3719dea..d3311b7458a97c 100644 --- a/benchmarks/dynamo/pr_time_benchmarks/benchmarks/basic_modules_benchmarks.py +++ b/benchmarks/dynamo/pr_time_benchmarks/benchmarks/basic_modules_benchmarks.py @@ -20,12 +20,15 @@ def forward(self, x): class Benchmark(BenchmarkBase): - def __init__(self, ModuleClass, backend, is_gpu=False, dynamic=False): + def __init__( + self, ModuleClass, backend, is_gpu=False, dynamic=False, force_shape_pad=False + ): self.ModuleClass = ModuleClass self.backend = backend self._name = ModuleClass.__name__ self._is_gpu = is_gpu self._dynamic = dynamic + self._force_shape_pad = force_shape_pad def name(self): prefix = f"basic_modules_{self._name}_{self.backend}" @@ -33,6 +36,8 @@ def name(self): prefix += "_dynamic" if self._is_gpu: prefix += "_gpu" + if self._force_shape_pad: + prefix += "_force_shape_pad" return prefix def _prepare_once(self): @@ -44,7 +49,9 @@ def _prepare(self): torch._dynamo.reset() def _work(self): - with fresh_inductor_cache(): + with fresh_inductor_cache(), torch._inductor.config.patch( + force_shape_pad=self._force_shape_pad + ): opt_m = torch.compile(backend=self.backend, dynamic=self._dynamic)( self.m.cuda() if self._is_gpu else self.m ) @@ -57,6 +64,8 @@ def main(): Benchmark(ListOfLinears, "eager"), Benchmark(ListOfLinears, "inductor"), Benchmark(ListOfLinears, "inductor", is_gpu=True), + Benchmark(ListOfLinears, "inductor", is_gpu=True), + Benchmark(ListOfLinears, "inductor", is_gpu=True, force_shape_pad=True), ] for b in benchmarks: b.enable_compile_time_instruction_count().collect_all().append_results( From 2d4a4782b745d8ecc8e92af6f944324c4581e82b Mon Sep 17 00:00:00 2001 From: cyy Date: Thu, 19 Sep 2024 19:24:40 +0000 Subject: [PATCH 1014/1018] [22/N] Fix clang-tidy warnings in jit (#134829) Follows #134537 Pull Request resolved: https://github.com/pytorch/pytorch/pull/134829 Approved by: https://github.com/ezyang --- torch/csrc/jit/api/module.cpp | 1 - torch/csrc/jit/backends/backend_detail.cpp | 10 ++---- .../jit/backends/nnapi/nnapi_backend_lib.cpp | 8 ++--- torch/csrc/jit/codegen/fuser/compiler.cpp | 2 +- torch/csrc/jit/codegen/fuser/compiler.h | 2 +- .../jit/codegen/fuser/cpu/fused_kernel.cpp | 10 ++---- .../compatibility/model_compatibility.cpp | 5 --- .../jit/passes/create_functional_graphs.cpp | 1 - torch/csrc/jit/passes/lower_tuples.cpp | 1 - .../quantization/insert_quant_dequant.cpp | 2 -- .../quantization/quantization_patterns.h | 7 ---- .../quantization/register_packed_params.cpp | 1 - .../csrc/jit/passes/utils/subgraph_utils.cpp | 1 - torch/csrc/jit/runtime/static/impl.cpp | 2 -- .../jit/serialization/export_bytecode.cpp | 1 - .../serialization/flatbuffer_serializer.cpp | 1 - torch/csrc/jit/tensorexpr/codegen.h | 7 ++-- torch/csrc/jit/tensorexpr/eval.cpp | 2 +- torch/csrc/jit/tensorexpr/expr.cpp | 2 +- torch/csrc/jit/tensorexpr/expr.h | 2 +- torch/csrc/jit/tensorexpr/ir_simplifier.cpp | 2 -- torch/csrc/jit/tensorexpr/kernel.cpp | 2 -- .../jit/tensorexpr/loopnest_randomization.cpp | 4 +-- .../jit/tensorexpr/operators/reduction.cpp | 1 - torch/csrc/jit/tensorexpr/stmt.h | 36 ++++++++----------- .../jit/tensorexpr/unique_name_manager.cpp | 2 +- 26 files changed, 34 insertions(+), 81 deletions(-) diff --git a/torch/csrc/jit/api/module.cpp b/torch/csrc/jit/api/module.cpp index 0dd9c76d5eb6ac..a44eccb601ba32 100644 --- a/torch/csrc/jit/api/module.cpp +++ b/torch/csrc/jit/api/module.cpp @@ -564,7 +564,6 @@ std::string Module::dump_to_str( std::stringstream parameters_ss; std::stringstream attributes_ss; std::stringstream methods_ss; - std::stringstream submodules_ss; for (const NameTensor& p : named_parameters(/*recurse=*/false)) { parameters_ss << p.name << " = "; diff --git a/torch/csrc/jit/backends/backend_detail.cpp b/torch/csrc/jit/backends/backend_detail.cpp index f6cb7b768edcec..de352f50ab503b 100644 --- a/torch/csrc/jit/backends/backend_detail.cpp +++ b/torch/csrc/jit/backends/backend_detail.cpp @@ -11,9 +11,7 @@ #include #include -namespace torch { -namespace jit { -namespace detail { +namespace torch::jit::detail { namespace { /* @@ -361,7 +359,7 @@ Module codegen_backend_module( wrapper_method_te.v("def_inputs", def_inputs); wrapper_method_te.v("fwd_inputs", fwd_inputs); - wrapper_methods.push_back(wrapper_method_ct.format(wrapper_method_te)); + wrapper_methods.emplace_back(wrapper_method_ct.format(wrapper_method_te)); // If the output type is a single element tuple then add an extra comma // to ensure the final output maintains this type. @@ -408,6 +406,4 @@ Module codegen_backend_module( return wrapper; } -} // namespace detail -} // namespace jit -} // namespace torch +} // namespace torch::jit::detail diff --git a/torch/csrc/jit/backends/nnapi/nnapi_backend_lib.cpp b/torch/csrc/jit/backends/nnapi/nnapi_backend_lib.cpp index ba4a2b25c23a78..a5a331d15c21c7 100644 --- a/torch/csrc/jit/backends/nnapi/nnapi_backend_lib.cpp +++ b/torch/csrc/jit/backends/nnapi/nnapi_backend_lib.cpp @@ -6,8 +6,7 @@ #include #include -namespace torch { -namespace jit { +namespace torch::jit { // Implementation of Android NNAPI Backend delegate @@ -107,7 +106,7 @@ class NnapiBackend : public PyTorchBackendInterface { // Runs once per model initialization // Cannot be moved to compile(), because init() requires actual inputs - void init(c10::IValue handle, c10::List inputs) { + void init(const c10::IValue& handle, const c10::List& inputs) { TORCH_CHECK(comp_ == nullptr); auto dict = handle.toGenericDict(); @@ -134,5 +133,4 @@ constexpr auto backend_name = "nnapi"; static auto cls = torch::jit::backend(backend_name); } // namespace -} // namespace jit -} // namespace torch +} // namespace torch::jit diff --git a/torch/csrc/jit/codegen/fuser/compiler.cpp b/torch/csrc/jit/codegen/fuser/compiler.cpp index c28a1b89f07699..3944e6f2df82f3 100644 --- a/torch/csrc/jit/codegen/fuser/compiler.cpp +++ b/torch/csrc/jit/codegen/fuser/compiler.cpp @@ -201,7 +201,7 @@ std::shared_ptr compileKernel( const KernelSpec& spec, const ArgSpec& arg_spec, const std::vector& map_size, - const at::Device device) { + const at::Device& device) { const std::vector& input_desc = arg_spec.descs(); auto graph = spec.graph()->copy(); diff --git a/torch/csrc/jit/codegen/fuser/compiler.h b/torch/csrc/jit/codegen/fuser/compiler.h index a998c6d459d370..d9c0005b9d2858 100644 --- a/torch/csrc/jit/codegen/fuser/compiler.h +++ b/torch/csrc/jit/codegen/fuser/compiler.h @@ -25,7 +25,7 @@ TORCH_API std::shared_ptr compileKernel( const KernelSpec& spec, const ArgSpec& arg_spec, const std::vector& map_size, - const at::Device device); + const at::Device& device); TORCH_API size_t nCompiledKernels(); diff --git a/torch/csrc/jit/codegen/fuser/cpu/fused_kernel.cpp b/torch/csrc/jit/codegen/fuser/cpu/fused_kernel.cpp index db9d57a679cb15..09624309d16cd1 100644 --- a/torch/csrc/jit/codegen/fuser/cpu/fused_kernel.cpp +++ b/torch/csrc/jit/codegen/fuser/cpu/fused_kernel.cpp @@ -11,10 +11,7 @@ #include #include -namespace torch { -namespace jit { -namespace fuser { -namespace cpu { +namespace torch::jit::fuser::cpu { #ifdef _MSC_VER static const std::string getTempPath() { @@ -357,7 +354,4 @@ static std::shared_ptr createFusionKernel( } RegisterFusionBackend reg(DeviceType::CPU, createFusionKernel); -} // namespace cpu -} // namespace fuser -} // namespace jit -} // namespace torch +} // namespace torch::jit::fuser::cpu diff --git a/torch/csrc/jit/mobile/compatibility/model_compatibility.cpp b/torch/csrc/jit/mobile/compatibility/model_compatibility.cpp index b8b1ca6adc0dc4..60f059cd927847 100644 --- a/torch/csrc/jit/mobile/compatibility/model_compatibility.cpp +++ b/torch/csrc/jit/mobile/compatibility/model_compatibility.cpp @@ -284,10 +284,6 @@ std::unordered_set _get_mobile_model_contained_types( std::unordered_set _get_mobile_model_contained_types( const std::vector& bytecode_ivalues) { std::unordered_set contained_types; - // To avoid parsing same type twice, declare $parsed_type_names_records and - // use type name (string, ex: "Dict[int, Tuple[Tensor, Tensor, Tensor]]") as - // the hash to record which types are parsed. - std::unordered_set parsed_type_names_records; for (const auto i : c10::irange(1, bytecode_ivalues.size())) { const auto& method_tuple = bytecode_ivalues.at(i).toTupleRef().elements(); auto type_table_tuple = @@ -299,7 +295,6 @@ std::unordered_set _get_mobile_model_contained_types( // for example: "Dict[int, Tuple[Tensor, Tensor, Tensor]]" std::vector type_name_list; for (const auto& type_definition : type_table) { - std::unordered_set type_tokens; std::string type_name = type_definition.toStringRef(); type_name_list.emplace_back(type_name); } diff --git a/torch/csrc/jit/passes/create_functional_graphs.cpp b/torch/csrc/jit/passes/create_functional_graphs.cpp index 036f30f916693d..86e9fa13893f61 100644 --- a/torch/csrc/jit/passes/create_functional_graphs.cpp +++ b/torch/csrc/jit/passes/create_functional_graphs.cpp @@ -80,7 +80,6 @@ struct FunctionalGraphSlicer { graph_->createWithSubgraph(prim::FunctionalGraph) ->insertBefore(block->return_node()); auto reverse_iter = block->nodes().reverse(); - std::vector graph_outputs; for (auto it = reverse_iter.begin(); it != reverse_iter.end();) { Node* n = *it++; diff --git a/torch/csrc/jit/passes/lower_tuples.cpp b/torch/csrc/jit/passes/lower_tuples.cpp index 0cca0d77f193e8..94610679e98e9d 100644 --- a/torch/csrc/jit/passes/lower_tuples.cpp +++ b/torch/csrc/jit/passes/lower_tuples.cpp @@ -40,7 +40,6 @@ static void flattenTupleInLoopParams(Node* n, size_t index) { Block* block = n->blocks().at(0); Node* block_node = n; - std::vector new_node_inputs = {}; auto new_construct_node = block->prependNode(block->owningGraph()->create(prim::TupleConstruct)); for (size_t j = 0; j < tt->elements().size(); ++j) { diff --git a/torch/csrc/jit/passes/quantization/insert_quant_dequant.cpp b/torch/csrc/jit/passes/quantization/insert_quant_dequant.cpp index 3e0b265ae70c8f..05c19bdb38a1fd 100644 --- a/torch/csrc/jit/passes/quantization/insert_quant_dequant.cpp +++ b/torch/csrc/jit/passes/quantization/insert_quant_dequant.cpp @@ -513,7 +513,6 @@ void ReplicateChooseQParamsQuantDequant(std::shared_ptr& graph) { Node* pattern_choose_qparam = choose_qparam_val->node(); std::vector nodes_to_rewrite; - std::vector choose_qparam_nodes_to_rewrite; for (const Match& match : matches) { Node* matched_dequantize = match.nodes_map.at(pattern_dequant); Node* matched_quantize = match.nodes_map.at(pattern_quant); @@ -1557,7 +1556,6 @@ QuantOpParams InsertQuantDeQuantHelper::insertCalculateQParams( "getQSchemeAndParamMap expects the corresponding observer for ", v->debugName(), " exists."); - std::vector qparams_graph_values; QuantOpParams quant_op_params; TORCH_CHECK( diff --git a/torch/csrc/jit/passes/quantization/quantization_patterns.h b/torch/csrc/jit/passes/quantization/quantization_patterns.h index 851548862dfc47..aeba208ea98e9c 100644 --- a/torch/csrc/jit/passes/quantization/quantization_patterns.h +++ b/torch/csrc/jit/passes/quantization/quantization_patterns.h @@ -808,13 +808,6 @@ graph(%a_quant, %alpha, %scale, %input_scale, %r_scale, %r_zero_point, %r_dtype) "%count_include_pad", "%divisor_override"}); - std::string common_general_value_op = R"( - %r_scale : float = aten::q_scale(%a_quant) - %r_zero_point : int = aten::q_zero_point(%a_quant) - %r_dtype : int = prim::dtype(%a_quant) - %r_quant = aten::quantize_per_tensor(%r, %r_scale, %r_zero_point, %r_dtype) - return (%r_quant) )"; - auto avg_pool3d = getInputTensorQParamOpFusionInfo( "aten::avg_pool3d", {"%kernel_size", diff --git a/torch/csrc/jit/passes/quantization/register_packed_params.cpp b/torch/csrc/jit/passes/quantization/register_packed_params.cpp index 1d7dcfe72eea56..c3696cdc5109cf 100644 --- a/torch/csrc/jit/passes/quantization/register_packed_params.cpp +++ b/torch/csrc/jit/passes/quantization/register_packed_params.cpp @@ -59,7 +59,6 @@ std::unordered_set RegisterPrePackParams( int64_t uid = 0; // int + method name gives unique identifier auto graph = m.get_method(method_name).graph(); std::stack blocks_to_visit; - std::unordered_set nodes_to_delete; blocks_to_visit.push(graph->block()); std::string attr_name_base = attr_prefix + "_" + method_name + "_ondevice_ptq_packed_weight_"; diff --git a/torch/csrc/jit/passes/utils/subgraph_utils.cpp b/torch/csrc/jit/passes/utils/subgraph_utils.cpp index f4dfc4ce99c940..0cc07a18c05eba 100644 --- a/torch/csrc/jit/passes/utils/subgraph_utils.cpp +++ b/torch/csrc/jit/passes/utils/subgraph_utils.cpp @@ -133,7 +133,6 @@ void mergeSubgraph(Node* mergeTo, Node* mergeFrom) { } ++it; - std::vector merged_nodes; while (it != end_it) { Node* node = *it; ++it; diff --git a/torch/csrc/jit/runtime/static/impl.cpp b/torch/csrc/jit/runtime/static/impl.cpp index 19e2e47498f39a..15f22bee7dfc0c 100644 --- a/torch/csrc/jit/runtime/static/impl.cpp +++ b/torch/csrc/jit/runtime/static/impl.cpp @@ -399,8 +399,6 @@ ManagedTensorRanges::ManagedTensorRanges( const AliasDb& alias_db, const c10::FastSet& managed_tensor_values) { const std::vector nodes(block.nodes().begin(), block.nodes().end()); - const c10::FastSet graph_inputs( - block.inputs().begin(), block.inputs().end()); const auto num_nodes = static_cast(nodes.size()); for (const auto i : c10::irange(num_nodes)) { diff --git a/torch/csrc/jit/serialization/export_bytecode.cpp b/torch/csrc/jit/serialization/export_bytecode.cpp index 07eca82df2df54..e5dbae392ccb47 100644 --- a/torch/csrc/jit/serialization/export_bytecode.cpp +++ b/torch/csrc/jit/serialization/export_bytecode.cpp @@ -149,7 +149,6 @@ mobile::Code compileGraphToMobileCode( // operator names std::vector method_names; - std::vector op_debug_handles; int next_new_op_index = 0; auto op_to_specified_args = code.op_to_num_specified_args(); diff --git a/torch/csrc/jit/serialization/flatbuffer_serializer.cpp b/torch/csrc/jit/serialization/flatbuffer_serializer.cpp index f201ce49bd7683..ee83ff78444f04 100644 --- a/torch/csrc/jit/serialization/flatbuffer_serializer.cpp +++ b/torch/csrc/jit/serialization/flatbuffer_serializer.cpp @@ -518,7 +518,6 @@ flatbuffers::Offset FlatbufferSerializer:: } else { size_t num_attr = class_ptr->numAttributes(); std::vector> names; - std::vector type_index; for (size_t i = 0; i < num_attr; ++i) { names.push_back(fbb.CreateSharedString(class_ptr->getAttributeName(i))); } diff --git a/torch/csrc/jit/tensorexpr/codegen.h b/torch/csrc/jit/tensorexpr/codegen.h index b818051bf14a88..e1a42cb1d45935 100644 --- a/torch/csrc/jit/tensorexpr/codegen.h +++ b/torch/csrc/jit/tensorexpr/codegen.h @@ -244,13 +244,12 @@ class RegisterCodeGen { RegisterCodeGenList& codegen_list = RegisterCodeGenList::GetInstance(); codegen_list.AddStmtFactoryMethod( name, - [](StmtPtr stmt, + [](const StmtPtr& stmt, const std::vector& params, at::Device device, const std::string& kernel_func_name) { - std::unique_ptr method( - new CodeGenType(stmt, params, device, kernel_func_name)); - return method; + return std::make_unique( + stmt, params, device, kernel_func_name); }); } }; diff --git a/torch/csrc/jit/tensorexpr/eval.cpp b/torch/csrc/jit/tensorexpr/eval.cpp index 58876fb6a5e75e..a3d2274a1eccf0 100644 --- a/torch/csrc/jit/tensorexpr/eval.cpp +++ b/torch/csrc/jit/tensorexpr/eval.cpp @@ -1044,7 +1044,7 @@ class SimpleIREvaluatorImpl : public IRVisitor { v->buffer_var()->name_hint()); } buffer_mapping_[b] = buffer->data(); - internal_buffers_.insert(std::make_pair(b, std::move(buffer))); + internal_buffers_.emplace(std::move(b), std::move(buffer)); } void visit(const PlacementAllocatePtr& v) override { diff --git a/torch/csrc/jit/tensorexpr/expr.cpp b/torch/csrc/jit/tensorexpr/expr.cpp index a498408a37c3e0..35c6ac03ce8dda 100644 --- a/torch/csrc/jit/tensorexpr/expr.cpp +++ b/torch/csrc/jit/tensorexpr/expr.cpp @@ -552,7 +552,7 @@ bool Buf::is_stride_one(int cur_dim) const { return exprEquals(strides_[cur_dim], alloc(1)); } -ExprHandle expr_to_vec(ExprHandle v, int lanes) { +ExprHandle expr_to_vec(const ExprHandle& v, int lanes) { if (lanes == 1) { return v; } else { diff --git a/torch/csrc/jit/tensorexpr/expr.h b/torch/csrc/jit/tensorexpr/expr.h index fbef2b134057e3..c5c41cd0a045ce 100644 --- a/torch/csrc/jit/tensorexpr/expr.h +++ b/torch/csrc/jit/tensorexpr/expr.h @@ -488,6 +488,6 @@ TORCH_API ExprHandle Relu(const ExprHandle& v1); TORCH_API ExprHandle ifThenElse(const ExprHandle& c, const ExprHandle& t, const ExprHandle& f); -TORCH_API ExprHandle expr_to_vec(ExprHandle v, int lanes); +TORCH_API ExprHandle expr_to_vec(const ExprHandle& v, int lanes); } // namespace torch::jit::tensorexpr diff --git a/torch/csrc/jit/tensorexpr/ir_simplifier.cpp b/torch/csrc/jit/tensorexpr/ir_simplifier.cpp index 76cb0b0daec0c4..cdd2c0e66bf7ab 100644 --- a/torch/csrc/jit/tensorexpr/ir_simplifier.cpp +++ b/torch/csrc/jit/tensorexpr/ir_simplifier.cpp @@ -2885,7 +2885,6 @@ ExprPtr SimplifierUnderContext::mutate(const DivPtr& v) { ExprPtr lhs = v->lhs(); ExprPtr rhs = v->rhs(); - std::ostringstream oss; if (auto ret = distributeDiv(lhs, rhs, var_bound_info_)) { GRAPH_DEBUG("SimplifierUnderContext: ", *v, " => ", *ret); return ret->accept_mutator(this); @@ -3005,7 +3004,6 @@ ExprPtr SimplifierUnderContext::mutate(const ModPtr& v) { ExprPtr lhs = v->lhs(); ExprPtr rhs = v->rhs(); - std::ostringstream oss; if (auto ret = distributeMod(lhs, rhs, var_bound_info_)) { GRAPH_DEBUG("SimplifierUnderContext: ", *v, " => ", *ret); return ret->accept_mutator(this); diff --git a/torch/csrc/jit/tensorexpr/kernel.cpp b/torch/csrc/jit/tensorexpr/kernel.cpp index 4720da2dc59fc9..fa6aa3d70f766e 100644 --- a/torch/csrc/jit/tensorexpr/kernel.cpp +++ b/torch/csrc/jit/tensorexpr/kernel.cpp @@ -984,7 +984,6 @@ TensorExprKernel::BackendType TensorExprKernel::inferBackendTypeFromDevice( // we use the debug names in printing cuda code, they need to be removed // of characters that can't be used in a variable identifier void TensorExprKernel::genInputDebugNames() { - std::unordered_map name_to_value; std::unordered_set name_set; std::unordered_map value_to_name; for (const torch::jit::Value* input : graph_->inputs()) { @@ -1747,7 +1746,6 @@ void TensorExprKernel::compile() { VarPtr v = t.buf()->base_handle(); scalars_[output] = VarHandle(v); block->append_stmt(t.stmt()); - std::vector dims; BufHandle buf( "scalar_" + sanitizeName(output->debugName()), {}, v->dtype()); StmtPtr store = Store::make(buf, {}, ExprHandle(v)); diff --git a/torch/csrc/jit/tensorexpr/loopnest_randomization.cpp b/torch/csrc/jit/tensorexpr/loopnest_randomization.cpp index cd16de19255803..343bb8e8f30df9 100644 --- a/torch/csrc/jit/tensorexpr/loopnest_randomization.cpp +++ b/torch/csrc/jit/tensorexpr/loopnest_randomization.cpp @@ -480,7 +480,7 @@ void loopnestRandomization(int64_t seed, LoopNest& l) { } int index = rand() % (int)all_nested_loops.size(); - auto nested_loops = all_nested_loops.at(index); + auto const& nested_loops = all_nested_loops.at(index); if (nested_loops.size() < 2) { break; } @@ -554,7 +554,7 @@ void loopnestRandomization(int64_t seed, LoopNest& l) { // Randomly pick a set of consecutive loops to flatten int index = rand() % (int)all_nested_loops.size(); - auto nested_loops = all_nested_loops.at(index); + auto const& nested_loops = all_nested_loops.at(index); // Generate a good history message std::vector indices; diff --git a/torch/csrc/jit/tensorexpr/operators/reduction.cpp b/torch/csrc/jit/tensorexpr/operators/reduction.cpp index 93030a1a8e045c..3e189e7acab9ac 100644 --- a/torch/csrc/jit/tensorexpr/operators/reduction.cpp +++ b/torch/csrc/jit/tensorexpr/operators/reduction.cpp @@ -146,7 +146,6 @@ Tensor computeMax( } BufHandle ResultBuf("max", outputShape, dtype); BufHandle InputBuf = std::get(inputs[0]); - std::vector max_dims_expr; auto max_dim = std::get(inputs[1]); auto keep_dim = std::get(inputs[2]); return Tensor( diff --git a/torch/csrc/jit/tensorexpr/stmt.h b/torch/csrc/jit/tensorexpr/stmt.h index fc871b73bbb3c4..62bde5c4c33ea1 100644 --- a/torch/csrc/jit/tensorexpr/stmt.h +++ b/torch/csrc/jit/tensorexpr/stmt.h @@ -84,51 +84,47 @@ class TORCH_API Block : public StmtNode { return stmts_.empty(); } - void prepend_stmt(StmtPtr s) { + void prepend_stmt(const StmtPtr& s) { if (s->get_parent()) { - throw malformed_input( - "Block prepend Stmt with existing parent", std::move(s)); + throw malformed_input("Block prepend Stmt with existing parent", s); } stmts_.push_front(s); set_parent(s, this); } - void append_stmt(StmtPtr s) { + void append_stmt(const StmtPtr& s) { if (s->get_parent()) { - throw malformed_input( - "Block append Stmt with existing parent", std::move(s)); + throw malformed_input("Block append Stmt with existing parent", s); } stmts_.push_back(s); set_parent(s, this); } - void insert_stmt_before(StmtPtr s, const StmtPtr& before) { + void insert_stmt_before(const StmtPtr& s, const StmtPtr& before) { if (s->get_parent()) { - throw malformed_input( - "Block append Stmt with existing parent", std::move(s)); + throw malformed_input("Block append Stmt with existing parent", s); } auto pos = std::find(stmts_.begin(), stmts_.end(), before); if (pos == stmts_.end()) { throw malformed_input( - "Inserting after statement that is not in block", std::move(s)); + "Inserting after statement that is not in block", s); } stmts_.insert(pos, s); set_parent(s, this); } - void insert_stmt_after(StmtPtr s, const StmtPtr& after) { + void insert_stmt_after(const StmtPtr& s, const StmtPtr& after) { if (s->get_parent()) { - throw malformed_input( - "Block append Stmt with existing parent", std::move(s)); + throw malformed_input("Block append Stmt with existing parent", s); } auto pos = std::find(stmts_.begin(), stmts_.end(), after); if (pos == stmts_.end()) { throw malformed_input( - "Inserting after statement that is not in block", std::move(s)); + "Inserting after statement that is not in block", s); } ++pos; @@ -137,10 +133,10 @@ class TORCH_API Block : public StmtNode { set_parent(s, this); } - bool replace_stmt(const StmtPtr& old_stmt, StmtPtr new_stmt) { + bool replace_stmt(const StmtPtr& old_stmt, const StmtPtr& new_stmt) { if (new_stmt->get_parent()) { throw malformed_input( - "Block replace Stmt with existing parent", std::move(new_stmt)); + "Block replace Stmt with existing parent", new_stmt); } auto pos = std::find(stmts_.begin(), stmts_.end(), old_stmt); @@ -157,10 +153,10 @@ class TORCH_API Block : public StmtNode { // Creates a new block by cloning `this` block and replacing the given // statement with a new statement. Note that `old_stmt` refers to a statement // in `this` block. If the `old_stmt` is not found, it will return `nullptr`. - BlockPtr clone_and_replace(const StmtPtr& old_stmt, StmtPtr new_stmt) { + BlockPtr clone_and_replace(const StmtPtr& old_stmt, const StmtPtr& new_stmt) { if (new_stmt->get_parent()) { throw malformed_input( - "Block replace Stmt with existing parent", std::move(new_stmt)); + "Block replace Stmt with existing parent", new_stmt); } std::vector stmts(stmts_.begin(), stmts_.end()); @@ -260,9 +256,7 @@ class TORCH_API Block : public StmtNode { StmtPtr p1_p = std::move(p1); while (p1_p) { if (BlockPtr b = to(p1_p)) { - if (b) { - enclosing.insert(b); - } + enclosing.insert(b); } p1_p = p1_p->get_parent(); } diff --git a/torch/csrc/jit/tensorexpr/unique_name_manager.cpp b/torch/csrc/jit/tensorexpr/unique_name_manager.cpp index f74c7eb37a9df6..8f344e001e31ff 100644 --- a/torch/csrc/jit/tensorexpr/unique_name_manager.cpp +++ b/torch/csrc/jit/tensorexpr/unique_name_manager.cpp @@ -31,7 +31,7 @@ const std::string& UniqueNameManager::get_unique_name(const VarPtr& v) { } if (all_unique_names_.count(unique_name) == 0) { all_unique_names_.insert(unique_name); - auto result = unique_name_mapping_.insert(std::make_pair(v, unique_name)); + auto result = unique_name_mapping_.emplace(v, unique_name); return result.first->second; } } From c574ab354726363ee1ccd8130fbee7f52dbaff0c Mon Sep 17 00:00:00 2001 From: Brian Hirsh Date: Wed, 18 Sep 2024 11:57:01 -0700 Subject: [PATCH 1015/1018] DTensor: dont hash symint tensor input in propagate_tensor_meta (#136266) This fixes a subset of issues for dynamic shapes + DTensor. It's pretty easy to run into other issues - it's likely that we need https://github.com/pytorch/pytorch/pull/125941 to land for DTensor + dynamic shapes to work more generally. I ended up writing a test that had dynamic shape inputs but not dynamic shape outputs in order to properly test this fix Pull Request resolved: https://github.com/pytorch/pytorch/pull/136266 Approved by: https://github.com/ezyang, https://github.com/yf225 --- .../_tensor/test_dtensor_compile.py | 19 +++++++++++++++++++ torch/distributed/tensor/_redistribute.py | 19 ++++++++++++++++--- torch/distributed/tensor/_sharding_prop.py | 11 ++++++++--- 3 files changed, 43 insertions(+), 6 deletions(-) diff --git a/test/distributed/_tensor/test_dtensor_compile.py b/test/distributed/_tensor/test_dtensor_compile.py index eaef0716ba30dc..3f4ddfce7813f2 100644 --- a/test/distributed/_tensor/test_dtensor_compile.py +++ b/test/distributed/_tensor/test_dtensor_compile.py @@ -176,6 +176,25 @@ def fn(x): res = opt_fn(x) self.assertEqual(res, ref) + def test_dtensor_dynamic(self): + mesh = DeviceMesh(self.device_type, torch.arange(self.world_size)) + + # test passing in DTensor as inputs/outputs and run some tensor computation + def fn(x): + return ( + torch.mul(x, x) + .redistribute(device_mesh=x.device_mesh, placements=[Replicate()]) + .to_local()[0] + ) + + x = DTensor.from_local(torch.rand(4, 4), mesh, [Shard(0)], run_check=False) + torch._dynamo.mark_dynamic(x, 0) + ref = fn(x) + + opt_fn = torch.compile(fn, backend="aot_eager", fullgraph=True) + res = opt_fn(x) + self.assertEqual(res, ref) + def test_dtensor_attribute_access_on_intermediate(self): mesh = DeviceMesh(self.device_type, torch.arange(self.world_size)) diff --git a/torch/distributed/tensor/_redistribute.py b/torch/distributed/tensor/_redistribute.py index 6b4e37e54744f4..88414081a1785b 100644 --- a/torch/distributed/tensor/_redistribute.py +++ b/torch/distributed/tensor/_redistribute.py @@ -27,8 +27,7 @@ class _TransformInfo(NamedTuple): logical_shape: List[int] -@lru_cache(maxsize=None) -def _gen_transform_infos( +def _gen_transform_infos_non_cached( src_spec: DTensorSpec, dst_spec: DTensorSpec, ) -> List[_TransformInfo]: @@ -146,6 +145,14 @@ def _gen_transform_infos( return transform_infos +@lru_cache(maxsize=None) +def _gen_transform_infos( + src_spec: DTensorSpec, + dst_spec: DTensorSpec, +) -> List[_TransformInfo]: + return _gen_transform_infos_non_cached(src_spec, dst_spec) + + def redistribute_local_tensor( local_tensor: torch.Tensor, current_spec: DTensorSpec, @@ -174,7 +181,13 @@ def redistribute_local_tensor( # which should be an empty tensor return local_tensor - transform_infos = _gen_transform_infos(current_spec, target_spec) + has_symints = any(isinstance(s, torch.SymInt) for s in current_spec.shape) or any( + isinstance(s, torch.SymInt) for s in target_spec.shape + ) + if has_symints: + transform_infos = _gen_transform_infos_non_cached(current_spec, target_spec) + else: + transform_infos = _gen_transform_infos(current_spec, target_spec) for transform_info in transform_infos: i = transform_info.mesh_dim diff --git a/torch/distributed/tensor/_sharding_prop.py b/torch/distributed/tensor/_sharding_prop.py index ec538d35ed66d4..2b87d79a342b29 100644 --- a/torch/distributed/tensor/_sharding_prop.py +++ b/torch/distributed/tensor/_sharding_prop.py @@ -104,8 +104,7 @@ def register_op_strategy( if schema_info is not None: self.op_to_schema_info[op_overload] = schema_info - @lru_cache # noqa: B019 - def _propagate_tensor_meta( + def _propagate_tensor_meta_non_cached( self, op_schema: OpSchema ) -> Union[None, TensorMeta, Sequence[Optional[TensorMeta]]]: """ @@ -150,6 +149,12 @@ def _propagate_tensor_meta( # if fake is not a tensor or tuple of tensor, return as none return None + @lru_cache # noqa: B019 + def _propagate_tensor_meta( + self, op_schema: OpSchema + ) -> Union[None, TensorMeta, Sequence[Optional[TensorMeta]]]: + return self._propagate_tensor_meta_non_cached(op_schema) + def _wrap_output_spec_tensor_meta( self, op: OpOverload, @@ -211,7 +216,7 @@ def propagate_op_sharding_non_cached(self, op_schema: OpSchema) -> OutputShardin if op_schema.op is aten._local_scalar_dense.default: return OutputSharding(None, op_schema) - out_tensor_meta = self._propagate_tensor_meta(op_schema) + out_tensor_meta = self._propagate_tensor_meta_non_cached(op_schema) def spec_to_strategy(spec: object) -> object: if isinstance(spec, DTensorSpec): From 0e9e7d34e7bf5f3357028aeb8d3ff61bc336cd55 Mon Sep 17 00:00:00 2001 From: Wei Wang Date: Thu, 19 Sep 2024 23:45:57 +0000 Subject: [PATCH 1016/1018] [CI][CUSPARSELT] Extend cusparselt installation script to support cuda 12.6 (#136321) To prepare for future cuda updates. Pull Request resolved: https://github.com/pytorch/pytorch/pull/136321 Approved by: https://github.com/Skylion007, https://github.com/eqy --- .ci/docker/common/install_cusparselt.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.ci/docker/common/install_cusparselt.sh b/.ci/docker/common/install_cusparselt.sh index b77684bdf9a05c..c4b3f3e02a7838 100644 --- a/.ci/docker/common/install_cusparselt.sh +++ b/.ci/docker/common/install_cusparselt.sh @@ -5,7 +5,7 @@ set -ex # cuSPARSELt license: https://docs.nvidia.com/cuda/cusparselt/license.html mkdir tmp_cusparselt && cd tmp_cusparselt -if [[ ${CUDA_VERSION:0:4} =~ ^12\.[2-4]$ ]]; then +if [[ ${CUDA_VERSION:0:4} =~ ^12\.[2-6]$ ]]; then arch_path='sbsa' export TARGETARCH=${TARGETARCH:-$(uname -m)} if [ ${TARGETARCH} = 'amd64' ] || [ "${TARGETARCH}" = 'x86_64' ]; then From fa5087b51d973523ff1f4685fc335949de8405cd Mon Sep 17 00:00:00 2001 From: Tristan Rice Date: Fri, 20 Sep 2024 00:53:21 +0000 Subject: [PATCH 1017/1018] TCPStoreLibUvBackend: trace operations (#136320) Summary: This logs all operations when tracing log level is enabled for the `TCPStoreLibUvBackend`. This is very useful for debugging collective operations when issues occur as it logs all hosts and the keys that they're modifying. To minimize total data we only log the keys and not the values This changes the C10D_* macros to be much more efficient -- previously we would always format the log string even if they would never be printed which is very wasteful for detailed tracing. This now gates them with an if statement to achieve the same behavior with no overhead Test Plan: ``` TORCH_DISTRIBUTED_DEBUG=DETAIL torchrun --nnodes 1 --nproc_per_node 1 --no-python /bin/bash -c "echo foo" ``` ``` I0919 09:26:52.352013 34271 TCPStore.cpp:285] [c10d - debug] The server has started on port = 29500. I0919 09:26:52.352246 34271 socket.cpp:783] [c10d - debug] The client socket will attempt to connect to an IPv6 address of (127.0.0.1, 29500). I0919 09:26:52.352241 36903 TCPStoreLibUvBackend.cpp:1173] [c10d - debug] Uv main loop running I0919 09:26:52.352308 34271 socket.cpp:854] [c10d - trace] The client socket is attempting to connect to [localhost]:29500. I0919 09:26:52.353633 34271 socket.cpp:945] [c10d] The client socket has connected to [localhost]:29500 on SocketImpl(fd=41, addr=[localhost]:45646, remote=[localhost]:29500). I0919 09:26:52.354422 34271 TCPStore.cpp:321] [c10d - debug] TCP client connected to host 127.0.0.1:29500 I0919 09:26:52.354558 36903 TCPStoreLibUvBackend.cpp:774] [c10d - trace] validate magic:1015412686 address:[localhost]:45646 I0919 09:26:52.354638 36903 TCPStoreLibUvBackend.cpp:789] [c10d - trace] ping nonce:34271 address:[localhost]:45646 I0919 09:26:52.356122 36903 TCPStoreLibUvBackend.cpp:866] [c10d - trace] add key:init/ val:1 address:[localhost]:45646 I0919 09:26:52.356308 36903 TCPStoreLibUvBackend.cpp:930] [c10d - trace] wait key_count:1 address:[localhost]:45646 I0919 09:26:52.356410 36903 TCPStoreLibUvBackend.cpp:846] [c10d - trace] get key:init/ address:[localhost]:45646 I0919 09:26:52.358688 36903 TCPStoreLibUvBackend.cpp:808] [c10d - trace] set key:/none/torchelastic/role_info/0 address:[localhost]:45646 I0919 09:26:52.360177 36903 TCPStoreLibUvBackend.cpp:930] [c10d - trace] wait key_count:1 address:[localhost]:45646 I0919 09:26:52.360296 36903 TCPStoreLibUvBackend.cpp:1004] [c10d - trace] multi_get key_count:1 address:[localhost]:45646 I0919 09:26:52.362076 36903 TCPStoreLibUvBackend.cpp:1036] [c10d - trace] multi_set key_count:1 address:[localhost]:45646 I0919 09:26:52.364001 36903 TCPStoreLibUvBackend.cpp:930] [c10d - trace] wait key_count:1 address:[localhost]:45646 I0919 09:26:52.364091 36903 TCPStoreLibUvBackend.cpp:846] [c10d - trace] get key:/none/torchelastic/assigned_ranks/0 address:[localhost]:45646 ``` Differential Revision: D62924454 Pull Request resolved: https://github.com/pytorch/pytorch/pull/136320 Approved by: https://github.com/c-p-i-o, https://github.com/XilunWu --- .../distributed/c10d/TCPStoreLibUvBackend.cpp | 53 +++++++++++- torch/csrc/distributed/c10d/logging.h | 37 ++++----- torch/csrc/distributed/c10d/socket.cpp | 81 ++++++++++--------- torch/csrc/distributed/c10d/socket.h | 2 - torch/csrc/distributed/c10d/socket_fmt.h | 32 ++++++++ 5 files changed, 141 insertions(+), 64 deletions(-) create mode 100644 torch/csrc/distributed/c10d/socket_fmt.h diff --git a/torch/csrc/distributed/c10d/TCPStoreLibUvBackend.cpp b/torch/csrc/distributed/c10d/TCPStoreLibUvBackend.cpp index 588827a3a422de..c3fa09ab38bef1 100644 --- a/torch/csrc/distributed/c10d/TCPStoreLibUvBackend.cpp +++ b/torch/csrc/distributed/c10d/TCPStoreLibUvBackend.cpp @@ -13,6 +13,7 @@ #include #include #include +#include #ifdef TORCH_USE_LIBUV #include @@ -89,6 +90,7 @@ class UvHandle : public c10::intrusive_ptr_target { class UvTcpSocket : public UvHandle { uv_tcp_t client{}; + std::string address_{"unknown"}; c10::intrusive_ptr iptr() { return c10::intrusive_ptr::reclaim_copy(this); @@ -143,7 +145,24 @@ class UvTcpSocket : public UvHandle { } } + const std::string& address() const { + return address_; + } + void startRead() { + struct ::sockaddr_storage addr {}; + int addrLen{sizeof(struct ::sockaddr_storage)}; + + if (int err = uv_tcp_getpeername( + &client, reinterpret_cast(&addr), &addrLen)) { + C10D_WARNING( + "The remote address of the client socket cannot be retrieved. err={}", + uv_strerror(err)); + } else { + address_ = + formatSockAddr(reinterpret_cast(&addr), addrLen); + } + int res = uv_read_start((uv_stream_t*)&client, alloc_buffer, read_callback); if (res) { C10D_WARNING( @@ -246,8 +265,8 @@ class UvTcpServer : public UvTcpSocket { ", message: ", uv_strerror(uv_res)); - uv_res = - uv_tcp_bind(res->unsafeGetSocket(), (const struct sockaddr*)&addr, 0); + uv_res = uv_tcp_bind( + res->unsafeGetSocket(), (const struct ::sockaddr*)&addr, 0); TORCH_CHECK( uv_res == 0, "The server socket has failed to bind. ", @@ -325,7 +344,7 @@ class UvTcpServer : public UvTcpSocket { if (uv_tcp_getsockname( (uv_tcp_t*)unsafeGetStream(), - reinterpret_cast(&addr_s), + reinterpret_cast<::sockaddr*>(&addr_s), &addr_len) != 0) { throw std::runtime_error( "The port number of the socket cannot be retrieved."); @@ -753,6 +772,8 @@ class UvClient : public UvTcpSocket { if (!stream.read_value(validateNumber)) return false; + C10D_TRACE("validate magic:{} address:{}", validateNumber, this->address()); + if (validateNumber != c10d::detail::validationMagicNumber) return false; return true; @@ -764,6 +785,8 @@ class UvClient : public UvTcpSocket { return false; } + C10D_TRACE("ping nonce:{} address:{}", nonce, this->address()); + StreamWriter sw(iptr()); sw.write_value(nonce); sw.send(); @@ -779,6 +802,8 @@ class UvClient : public UvTcpSocket { if (!stream.read_payload(newData)) return false; + C10D_TRACE("set key:{} address:{}", key, this->address()); + store->set(key, newData); return true; } @@ -796,6 +821,8 @@ class UvClient : public UvTcpSocket { if (!stream.read_payload(newValue)) return false; + C10D_TRACE("compareAndSet key:{} address:{}", key, this->address()); + auto res = store->compareAndSet(key, currentValue, newValue); StreamWriter sw(iptr()); sw.write_vector(res); @@ -809,6 +836,8 @@ class UvClient : public UvTcpSocket { if (!stream.read_key(key)) return false; + C10D_TRACE("get key:{} address:{}", key, this->address()); + const auto& data = store->get(key); StreamWriter sw(iptr()); sw.write_vector(data); @@ -825,6 +854,8 @@ class UvClient : public UvTcpSocket { if (!stream.read_value(addVal)) return false; + C10D_TRACE("add key:{} val:{} address:{}", key, addVal, this->address()); + addVal = store->add(key, addVal); StreamWriter sw(iptr()); sw.write_value(addVal); @@ -851,6 +882,8 @@ class UvClient : public UvTcpSocket { return false; } + C10D_TRACE("check key_count:{} address:{}", key_count, this->address()); + // Now we have received all the keys StreamWriter sw(iptr()); if (store->checkKeys(keys)) { @@ -881,6 +914,8 @@ class UvClient : public UvTcpSocket { return false; } + C10D_TRACE("wait key_count:{} address:{}", key_count, this->address()); + if (store->waitKeys(keys, iptr())) { StreamWriter sw(iptr()); sw.write1((uint8_t)WaitResponseType::STOP_WAITING); @@ -891,6 +926,8 @@ class UvClient : public UvTcpSocket { } bool parse_getnumkeys_command() { + C10D_TRACE("getnumkeys address:{}", this->address()); + StreamWriter sw(iptr()); sw.write_value(store->size()); sw.send(); @@ -903,6 +940,8 @@ class UvClient : public UvTcpSocket { if (!stream.read_key(key)) return false; + C10D_TRACE("delete key:{} address:{}", key, this->address()); + auto numDeleted = store->deleteKey(key); StreamWriter sw(iptr()); sw.write_value(numDeleted); @@ -922,6 +961,8 @@ class UvClient : public UvTcpSocket { return false; } + C10D_TRACE("append key:{} address:{}", key, this->address()); + store->append(key, data); return true; } @@ -939,6 +980,8 @@ class UvClient : public UvTcpSocket { ", max: ", MAX_KEY_COUNT); + C10D_TRACE("multi_get key_count:{} address:{}", key_count, this->address()); + StreamWriter sw(iptr()); for (const auto _ : c10::irange(key_count)) { (void)_; // Suppress unused variable warning @@ -967,6 +1010,8 @@ class UvClient : public UvTcpSocket { ", max: ", MAX_KEY_COUNT); + C10D_TRACE("multi_set key_count:{} address:{}", key_count, this->address()); + for (const auto _ : c10::irange(key_count)) { (void)_; // Suppress unused variable warning @@ -987,6 +1032,8 @@ class UvClient : public UvTcpSocket { bool parse_cancel_wait_command() { store->clearClientWaitState(iptr()); + C10D_TRACE("cancel_wait key_count:{} address:{}", this->address()); + StreamWriter sw(iptr()); sw.write1((uint8_t)WaitResponseType::WAIT_CANCELED); sw.send(); diff --git a/torch/csrc/distributed/c10d/logging.h b/torch/csrc/distributed/c10d/logging.h index b4df02e773807e..a7cc82f702eea2 100644 --- a/torch/csrc/distributed/c10d/logging.h +++ b/torch/csrc/distributed/c10d/logging.h @@ -27,25 +27,22 @@ std::string formatLogMessage(fmt::string_view fmt, T&&... args) { } // namespace detail } // namespace c10d -#define C10D_ERROR(...) \ - LOG_IF( \ - ERROR, c10d::detail::isLogLevelEnabled(c10d::detail::LogLevel::Error)) \ - << "[c10d] " << c10d::detail::formatLogMessage(__VA_ARGS__) +#define C10D_ERROR(...) \ + if (c10d::detail::isLogLevelEnabled(c10d::detail::LogLevel::Error)) \ + LOG(ERROR) << "[c10d] " << c10d::detail::formatLogMessage(__VA_ARGS__) #define C10D_WARNING(...) \ - LOG_IF( \ - WARNING, \ - c10d::detail::isLogLevelEnabled(c10d::detail::LogLevel::Warning)) \ - << "[c10d] " << c10d::detail::formatLogMessage(__VA_ARGS__) - -#define C10D_INFO(...) \ - LOG_IF(INFO, c10d::detail::isLogLevelEnabled(c10d::detail::LogLevel::Info)) \ - << "[c10d] " << c10d::detail::formatLogMessage(__VA_ARGS__) - -#define C10D_DEBUG(...) \ - LOG_IF(INFO, c10d::detail::isLogLevelEnabled(c10d::detail::LogLevel::Debug)) \ - << "[c10d - debug] " << c10d::detail::formatLogMessage(__VA_ARGS__) - -#define C10D_TRACE(...) \ - LOG_IF(INFO, c10d::detail::isLogLevelEnabled(c10d::detail::LogLevel::Trace)) \ - << "[c10d - trace] " << c10d::detail::formatLogMessage(__VA_ARGS__) + if (c10d::detail::isLogLevelEnabled(c10d::detail::LogLevel::Warning)) \ + LOG(WARNING) << "[c10d] " << c10d::detail::formatLogMessage(__VA_ARGS__) + +#define C10D_INFO(...) \ + if (c10d::detail::isLogLevelEnabled(c10d::detail::LogLevel::Info)) \ + LOG(INFO) << "[c10d] " << c10d::detail::formatLogMessage(__VA_ARGS__) + +#define C10D_DEBUG(...) \ + if (c10d::detail::isLogLevelEnabled(c10d::detail::LogLevel::Debug)) \ + LOG(INFO) << "[c10d - debug] " << c10d::detail::formatLogMessage(__VA_ARGS__) + +#define C10D_TRACE(...) \ + if (c10d::detail::isLogLevelEnabled(c10d::detail::LogLevel::Trace)) \ + LOG(INFO) << "[c10d - trace] " << c10d::detail::formatLogMessage(__VA_ARGS__) diff --git a/torch/csrc/distributed/c10d/socket.cpp b/torch/csrc/distributed/c10d/socket.cpp index f155f825284236..db4519d7b2ad38 100644 --- a/torch/csrc/distributed/c10d/socket.cpp +++ b/torch/csrc/distributed/c10d/socket.cpp @@ -37,6 +37,7 @@ C10_DIAGNOSTIC_POP() #include #include #include +#include #include @@ -193,6 +194,43 @@ class SocketImpl { Handle hnd_; const std::optional remote_; }; + +std::string formatSockAddr(const struct ::sockaddr* addr, socklen_t len) { + char host[NI_MAXHOST], port[NI_MAXSERV]; // NOLINT + + if (int err = ::getnameinfo( + addr, len, host, NI_MAXHOST, port, NI_MAXSERV, NI_NUMERICSERV)) { + C10D_WARNING( + "The hostname of the client socket cannot be retrieved. err={}", err); + + // if we can't resolve the hostname, display the IP address + if (addr->sa_family == AF_INET) { + struct sockaddr_in* psai = (struct sockaddr_in*)&addr; + char ip[INET_ADDRSTRLEN]; + if (inet_ntop(addr->sa_family, &(psai->sin_addr), ip, INET_ADDRSTRLEN) != + NULL) { + return fmt::format("{}:{}", ip, psai->sin_port); + } + } else if (addr->sa_family == AF_INET6) { + struct sockaddr_in6* psai = (struct sockaddr_in6*)&addr; + char ip[INET6_ADDRSTRLEN]; + if (inet_ntop( + addr->sa_family, &(psai->sin6_addr), ip, INET6_ADDRSTRLEN) != + NULL) { + return fmt::format("[{}]:{}", ip, psai->sin6_port); + } + } + + C10_THROW_ERROR( + DistNetworkError, + fmt::format( + "failed to format addr, unknown family={}", addr->sa_family)); + } + if (addr->sa_family == AF_INET) { + return fmt::format("{}:{}", host, port); + } + return fmt::format("[{}]:{}", host, port); +} } // namespace c10d::detail // @@ -208,45 +246,10 @@ struct formatter<::addrinfo> { template decltype(auto) format(const ::addrinfo& addr, FormatContext& ctx) const { - char host[NI_MAXHOST], port[NI_MAXSERV]; // NOLINT - - int r = ::getnameinfo( - addr.ai_addr, - addr.ai_addrlen, - host, - NI_MAXHOST, - port, - NI_MAXSERV, - NI_NUMERICSERV); - if (r != 0) { - // if we can't resolve the hostname, display the IP address - if (addr.ai_family == AF_INET) { - struct sockaddr_in* psai = (struct sockaddr_in*)addr.ai_addr; - char ip[INET_ADDRSTRLEN]; - if (inet_ntop(addr.ai_family, &(psai->sin_addr), ip, INET_ADDRSTRLEN) != - NULL) { - return fmt::format_to(ctx.out(), "{}:{}", ip, psai->sin_port); - } - } else if (addr.ai_family == AF_INET6) { - struct sockaddr_in6* psai = (struct sockaddr_in6*)addr.ai_addr; - char ip[INET6_ADDRSTRLEN]; - if (inet_ntop( - addr.ai_family, &(psai->sin6_addr), ip, INET6_ADDRSTRLEN) != - NULL) { - return fmt::format_to(ctx.out(), "[{}]:{}", ip, psai->sin6_port); - } - } - C10_THROW_ERROR( - DistNetworkError, - fmt::format( - "failed to format addr, unknown family={}", addr.ai_family)); - } - - if (addr.ai_addr->sa_family == AF_INET) { - return fmt::format_to(ctx.out(), "{}:{}", host, port); - } else { - return fmt::format_to(ctx.out(), "[{}]:{}", host, port); - } + return fmt::format_to( + ctx.out(), + "{}", + c10d::detail::formatSockAddr(addr.ai_addr, addr.ai_addrlen)); } }; diff --git a/torch/csrc/distributed/c10d/socket.h b/torch/csrc/distributed/c10d/socket.h index 1e42a53b0b530d..de9bd6989c290e 100644 --- a/torch/csrc/distributed/c10d/socket.h +++ b/torch/csrc/distributed/c10d/socket.h @@ -103,7 +103,5 @@ class Socket { std::unique_ptr impl_; }; - } // namespace detail - } // namespace c10d diff --git a/torch/csrc/distributed/c10d/socket_fmt.h b/torch/csrc/distributed/c10d/socket_fmt.h new file mode 100644 index 00000000000000..8c7832ebf933cd --- /dev/null +++ b/torch/csrc/distributed/c10d/socket_fmt.h @@ -0,0 +1,32 @@ +// (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + +#pragma once + +/* +This file should not be included from other .h files and only used in cpp files +as it exposes the underlying platform specific socket headers. +*/ + +#include + +#ifdef _WIN32 +#include + +#include +#include +#else +#include +#endif + +namespace c10d { +namespace detail { + +// Returns a human-readable representation of the given socket address. +std::string formatSockAddr(const struct ::sockaddr* addr, socklen_t len); + +} // namespace detail +} // namespace c10d From c9d6bb8554f08ff608d4f6f5a4392d30de110034 Mon Sep 17 00:00:00 2001 From: Felix Su Date: Fri, 20 Sep 2024 00:54:00 +0000 Subject: [PATCH 1018/1018] passing FileTimerRequests.to_json() to log_debug_info_for_expired_timers for a better debugging experience (#135913) Summary: The change involves passing the expired timers to the log_debug_info_for_expired_timers function after to_json() has been applied . This change is made to provide a better debugging experience for the user. Test Plan: unit tests Reviewed By: gag1jain Differential Revision: D62408767 Pull Request resolved: https://github.com/pytorch/pytorch/pull/135913 Approved by: https://github.com/gag1jain --- torch/distributed/elastic/timer/file_based_local_timer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch/distributed/elastic/timer/file_based_local_timer.py b/torch/distributed/elastic/timer/file_based_local_timer.py index 1c6d0d9a33064a..f762614a8e6c59 100644 --- a/torch/distributed/elastic/timer/file_based_local_timer.py +++ b/torch/distributed/elastic/timer/file_based_local_timer.py @@ -269,7 +269,7 @@ def _run_watchdog(self, fd: io.TextIOWrapper) -> None: log_debug_info_for_expired_timers( self._run_id, { - pid: self._get_scopes(expired_timers) + pid: [expired_timer.to_json() for expired_timer in expired_timers] for pid, expired_timers in all_expired_timers.items() }, )